本文共 2383 字,大约阅读时间需要 7 分钟。
import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltfrom tensorflow.examples.tutorials.mnist import input_datatrainimg = mnist.train.imagestrainlabel = mnist.train.labelstestimg = mnist.test.imagestestlabel = mnist.test.labelsprint(trainlabel[0]) #输出第一行label值
运行结果:
判别样本是否为对应类别,是则为1,否则为0,完成10分类任务,故上述样本的类别为9。#初始化变量x = tf.placeholder("float",[None,784]) #placeholder(先占位,不复制),[None,784]样本的个数(无限大),每个样本的特征(784个像素点)y = tf.placeholder("float",[None,10])#样本的类别(10个)W = tf.Variable(tf.zeros([784,10]))#每个特征(784个像素点)对应输出10个分类值b = tf.Variable(tf.zeros([10]))#逻辑回归模型(softmax完成多分类任务)actv = tf.nn.softmax(tf.matmul(x,W)+b)#计算属于正确类别的概率值#计算损失值(预测值与真实值间的均方差)cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))#采用梯度下降优化参数(W,b),最小化损失值learning_rate=0.01 optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #学习率为0.01#optimizer = tf.train.GradientDescentOptimizer(0.01)#学习率为0.01#最小化损失值#optm = optimizer.minimize(cost)#预测值,equal返回的值是布尔类型#argmax返回矩阵中最大元素的索引,0,代表列方向;1代表行方向 pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1)) #准确率accr = tf.reduce_mean(tf.cast(pred,"float")) #cast进行类型转化 (true为1,false为0) init_op = tf.global_variables_initializer() #定义全局变量training_epochs = 50 #将所有样本迭代50次batch_size = 100 #每次迭代选择样本的个数 display_step =5 #每进行5个epoch进行一次展示 with tf.Session() as sess: sess.run(init_op) for epoch in range(training_epochs): avg_cost =0.0 # 初始化损失值 num_batch = int(mnist.train.num_examples/batch_size) for i in range(num_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) #以batch为单位逐次进行 sess.run(optm,feed_dict={x: batch_xs,y: batch_ys}) #给x,y赋值 feeds={x: batch_xs,y: batch_ys} avg_cost +=sess.run(cost,feed_dict= feeds)/num_batch #显示 if epoch % display_step == 0: feeds_train = {x: batch_xs,y: batch_ys} feeds_test = {x:mnist.test.images,y: mnist.test.labels} train_acc = sess.run(accr,feed_dict= feeds_train) test_acc = sess.run(accr,feed_dict= feeds_test) print("Epoch: %03d/%03d cost:%.9f trian_acc: %.3f test_acc: %.3f" % (epoch,training_epochs,avg_cost,train_acc,test_acc)) print("Done")
运行结果:
由上图可知训练集的准确率为91%,测试集的准确率为91.8%。转载地址:http://dohwi.baihongyu.com/