TensorFlow学习笔记12:CIFAR-10数据集图片分类
CIFAR-10数据集简介
CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10个类别的RGB彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。
TensenFlow实现
数据准备
首先去官方库下载cifar10.py
以及cifar10_input.py
文件来下载CIFAR-10数据集二进制文件,以及读取文件内容。
GitHub地址:https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10
运行下载数据集函数1
2from tensorflow.models.tutorials.image.cifar10 import cifar10
cifar10.maybe_download_and_extract()
数据集文件默认下载在./tmp/cifar10_data
文件下,可以将其移动到自己的工程文件夹下
定义函数
首先定义一个权重初始化函数1
2
3
4
5
6def variable_with_weight_loss(shape,std,w1):
var = tf.Variable(tf.truncated_normal(shape,stddev=std),dtype=tf.float32)
if w1 is not None:
weight_loss = tf.multiply(tf.nn.l2_loss(var),w1,name="weight_loss")
tf.add_to_collection("losses",weight_loss)
return var
然后定义一个损失函数1
2
3
4
5
6
7def loss_func(logits,labels):
labels = tf.cast(labels,tf.int32)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
labels=labels,name="cross_entropy_per_example")
cross_entropy_mean = tf.reduce_mean(tf.reduce_sum(cross_entropy))
tf.add_to_collection("losses",cross_entropy_mean)
return tf.add_n(tf.get_collection("losses"),name="total_loss")
读取数据
1 | #设置每次训练的数据大小 |
网络结构
1 | #定义模型的输入和输出数据 |
训练和优化
1 | #设置最大迭代次数 |
训练过程
1 | #开始训练 |
测试过程
1 | #计算测试集上的准确率 |