TensorFlow实现Softmax回归模型
作者:marsjhao 发布时间:2023-07-14 19:07:51
一、概述及完整代码
对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型.
先看完整代码:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
#构建计算图
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#在会话sess中启动图
sess = tf.InteractiveSession() #创建InteractiveSession对象
tf.global_variables_initializer().run() #全局参数初始化器
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train_step.run({x: batch_xs, y_: batch_ys})
#测试验证阶段
#沿着第1条轴方向取y和y_的最大值的索引并判断是否相等
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
#转换bool型tensor为float32型tensor并求平均即得到正确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
二、详细解读
首先看一下使用TensorFlow进行算法设计训练的核心步骤
1.定义算法公式,也就是神经网络forward时的计算;
2.定义loss,选定优化器,并制定优化器优化loss;
3.在训练集上迭代训练算法模型;
4.在测试集或验证集上对训练得到的模型进行准确率评测.
首先创建一个Placeholder,即输入张量数据的地方,第一个参数是数据类型dtype,第二个参数是tensor的形状shape.接下来创建SoftmaxRegression模型中的weights(W)和biases(b)的Variable对象,不同于存储数据的tensor一旦使用掉就会消失,Variable在模型训练迭代中是持久存在的,并且在每轮迭代中被更新Variable初始化可以是常量或随机值.接下来实现模型算法y = softmax(Wx + b),TensorFlow语言只需要一行代码,tf.nn包含了大量神经网络的组件,头tf.matmul是矩阵乘法函数.TensorFlow将模型中的forward和backward的内容都自动实现,只要定义好loss,训练的时候会自动求导并进行梯度下降,完成对模型参数的自动学习.定义损失函数lossfunction来描述分类精度,对于多分类问题通常使用cross-entropy交叉熵.先定义一个placeholder输入真实的label,tf.reduce_sum和tf.reduce_mean的功能分别是求和和求平均.构造完损失函数cross-entropy后,再定义一个优化算法即可开始训练.我们采用随机梯度下降SGD,定义好后TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降,而给我们提供的是一个封装好的优化器,只需要每轮迭代时feed数据给它就好.设置好学习率.
构造阶段完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象或InteractiveSession对象, 如果无任何创建参数, 会话构造器将启动默认图.创建InteractiveSession对象会这个Session注册为默认的Session,之后的运算也默认跑在这个Session里面,不同Session之间的数据和运算应该是相互独立的.下一步使用TensorFlow的全局参数初始化器tf.global_variables_initializer病直接执行它的run方法(这个全局参数初始化器应该是1.0.0版本中的新特性,在之前0.10.0版本测试不通过).
至此,以上定义的所有公式其实只是Computation Graph,代码执行到这时,计算还没有实际发生,只有等调用run方法并feed数据时计算才真正执行.
随后一步,就可以开始迭代地执行训练操作train_step.这里每次都从训练集中随机抽取100条样本构成一个mini-batch,并feed给placeholder.
完成迭代训练后,就可以对模型的准确率进行验证.比较y和y_在各个测试样本中最大值所在的索引,然后转换为float32型tensor后求平均即可得到正确率.多次测试后得到在测试集上的正确率为92%左右.还是比较理想的结果.
三、其他补充
1.Sesssion类和InteractiveSession类
对于product =tf.matmul(matrix1, matrix2),调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回矩阵乘法 op 的输出.整个执行过程是自动化的, 会话负责传递op 所需的全部输入. op 通常是并发执行的.函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op)的执行.返回值 'result' 是一个 numpy的`ndarray`对象.
Session 对象在使用完后需要关闭以释放资源sess.close(). 除了显式调用 close 外, 也可以使用"with" 代码块 来自动完成关闭动作.
with tf.Session() as sess:
result = sess.run([product])
print result
为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用InteractiveSession代替 Session 类, 使用 Tensor.eval()和 Operation.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.
# 进入一个交互式 TensorFlow 会话.
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.Variable([1.0, 2.0])
a = tf.constant([3.0, 3.0])
# 使用初始化器 initializer op 的 run() 方法初始化 'x'
x.initializer.run()
# 增加一个减法 sub op, 从 'x' 减去 'a'. 运行减法 op, 输出结果
sub = tf.sub(x, a)
print sub.eval()
# ==> [-2. -1.]
2.tf.reduce_sum
首先,tf.reduce_X一系列运算操作(operation)是实现对一个tensor各种减少维度的数学计算.
tf.reduce_sum(input_tensor, reduction_indices=None,keep_dims=False, name=None)
运算功能:沿着给定维度reduction_indices的方向降低input_tensor的维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.
演示代码:
# 'x' is [[1, 1, 1]
# [1, 1, 1]]
tf.reduce_sum(x) ==> 6
tf.reduce_sum(x, 0) ==> [2, 2, 2]
tf.reduce_sum(x, 1) ==> [3, 3]
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
tf.reduce_sum(x, [0, 1]) ==> 6
3.tf.reduce_mean
tf.reduce_mean(input_tensor, reduction_indices=None,keep_dims=False, name=None)
运算功能:将input_tensor沿着给定维度reduction_indices减少维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.
演示代码:
# 'x' is [[1., 1. ]
# [2., 2.]]
tf.reduce_mean(x) ==> 1.5
tf.reduce_mean(x, 0) ==> [1.5, 1.5]
tf.reduce_mean(x, 1) ==> [1., 2.]
4.tf.argmax
tf.argmax(input, dimension, name=None)
运算功能:返回input在指定维度下的最大值的索引.返回类型为int64.
来源:http://blog.csdn.net/marsjhao/article/details/63709327
猜你喜欢
- 首先恭喜月影,当然希望好书大卖!原文提供了样章下载1.1M,pdf格式的。如果大家想下载可以访问源地址:http://bbs.51js.co
- 今天网页调试的时候在线订单出现错误:Server 对象 错误 'ASP 0178
- 当我们在使用pycharm时,输入特殊的关键字会有提示,然后按enter就可以自动补全,如果我们经常需要输出重复的代码时,能否也利用这种方法
- 废话不多说,直接上代码/** * lhgcalendar时间插件限制只能选择三个月 * @d 获取到的开始时间 * @m 要限制的时间的长度
- 最近在做Python 的项目,特地整理了下 Python 序列的方法。序列se
- 简介表单的操作是Web程序开发中最核心的模块之一,绝大多数的动态交互功能都是通过表单的形式实现的。本文会教大家实现简单的表单操作。普通表单提
- 前言本文是美团一位大佬写的,还不错拿出来和大家分享下,代码中嵌套在html中sql语句是java框架的写法,理解其sql要执行的语句即可。背
- 一、聚合函数聚合函数:又叫组函数,用来对表中的数据进行统计和计算,结合group by分组使用,用于统计和计算分组数据常用聚合函数count
- <'% '************************************************
- 视频观看视频Pygame模块之pygame.draw本文将主要介绍Pygame的draw模块,主要内容翻译自pygame的官方文档pygam
- 前言在开始本文之前,首先要知道Python中对象包含的三个基本要素,分别是:id(身份标识)、python type()(数据类型)和val
- Access允许您在数据库表中包含附件。通过利用微软的对象链接和嵌入(OLE)技术,您可以将照片、图表、文档及其他文件存储在您的Access
- 最终效果如下图,右侧灰边看相对位置,版权所有谨防假冒:去年曾针对有时间先后的翻页记录了思考片段。之后没来得及调整一直是默认和插件并用,虽然难
- Mybatis插入mysql报主键重复的问题首先思路是这样的,先去数据表里面去找有没有这个主键的数据(如果有会有返回值,如果没有则返回nul
- 在查询语句中使用 NOLOCK 和 READPAST 处理一个数据库死锁的异常时候,其中一个建议就是使用 NOLOCK 或者 READPAS
- 示例一:直接编写AJAX 实现。 客户端: 代码如下:<!DOCTYPE html PUBLIC &qu
- Javascript刷新页面的几种方法: 1. history.go(0) 2. location.reload() 3. location
- 目录1. 配置Python环境变量2. 安装Python编辑器,并在其中配置Python3. 安装控制包uiautomator2,和其它辅助
- 1:事件机制共享队列:利用消息机制在两个队列中,通过传递消息,实现可以控制的生产者消费者问题要求:readthread读时,writethr
- 在使用Python库时,常常会用到matplotlib.pyplot绘图,本文介绍在PyCharm及Jupyter Notebook页面中控