tensorflow实现softma识别MNIST
作者:freedom098 发布时间:2021-02-17 22:32:56
标签:tensorflow,softma,MNIST
识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。
这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。
误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。
另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。
代码如下:
#encoding=utf-8
__author__ = 'freedom'
import tensorflow as tf
def loadMNIST():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
return mnist
def softmax(mnist,rate=0.01,batchSize=50,epoch=20):
n = 784 # 向量的维度数目
m = None # 样本数,这里可以获取,也可以不获取
c = 10 # 类别数目
x = tf.placeholder(tf.float32,[m,n])
y = tf.placeholder(tf.float32,[m,c])
w = tf.Variable(tf.zeros([n,c]))
b = tf.Variable(tf.zeros([c]))
pred= tf.nn.softmax(tf.matmul(x,w)+b)
loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
opt = tf.train.GradientDescentOptimizer(rate).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for index in range(epoch):
avgLoss = 0
batchNum = int(mnist.train.num_examples/batchSize)
for batch in range(batchNum):
batch_x,batch_y = mnist.train.next_batch(batchSize)
_,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y})
avgLoss += Loss
avgLoss /= batchNum
print 'every epoch average loss is ',avgLoss
right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(right,tf.float32))
print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels}))
if __name__ == "__main__":
mnist = loadMNIST()
softmax(mnist)
来源:http://blog.csdn.net/freedom098/article/details/52116813


猜你喜欢
- 在学习linear regression时经常处理的数据一般多是矩阵或者n维向量的数据形式,所以必须对矩阵有一定的认识基础。numpy中创建
- 1.背景最近使用Pytest中的fixture和conftest时,遇到需要在conftest中的setup和teardown方法里传递参数
- php 生成短网址 原理: 1.将原网址做crc32校验,得到校验码。 2.使用sprintf('%u') 将校验码转为无符
- BULK INSERT以用户指定的格式复制一个数据文件至数据库表或视图中。 语法:BULK INSERT [ [ 'database
- 现在,我们已经把一个Web App的框架完全搭建好了,从后端的API到前端的MVVM,流程已经跑通了。在继续工作前,注意到每次修改Pytho
- PyQt5工具栏控件QToolBar介绍QToolBar控件是由文本按钮,图标或其他小控件按钮组成的可移动面板,通常位于菜单栏下方QTool
- 当然我们可以在后台中获取参数的值,然后在前台js代码中获取变量的值,具体做法请参考我的这篇文章:JavaScript获取后台C#变量以及调用
- 本文介绍如何利用带进度条的ASP无组件实现断点续传下载大文件。<%@LANGUAGE="VBSCRIPT"&nbs
- 本文实例为大家分享了python手写均值滤波的具体代码,供大家参考,具体内容如下原理与卷积类似,设置一个n*n的滤波模板,滤波模板内的值累加
- 在开始后面的内容之前,先来解释一下urllib2中的两个个方法:info / geturl urlopen返回的应答对象respo
- 熵值法也称熵权法,是学术研究,及实际应用中的一种常用且有效的编制指标的方法。1.简单理解 信息熵机器学习中的决策树算法是对信息熵的一种典型的
- 使用python画图,发现生成的图片在console里。不仅感觉很别扭,很多功能也没法实现(比如希望在一幅图里画两条曲线)。想像matlab
- 一、数据降维机器学习中的维度就是特征的数量,降维即减少特征数量。降维方式有:特征选择、主成分分析。1.特征选择当出现以下情况时,可选择该方式
- 很多介绍 根据日志等级打印不同颜色 的文章都是介绍的Ideolog , 但是我个人还是倾向于 Grep Console , 你可以在配置界面
- 在已知DICOM和三维模型对应掩膜的情况下,计算三维模型的体积。思路:1、计算每个体素的体积。每个体素为长方体,x,y为PixelSpaci
- 如果您刚刚开始接触网页设计,是不是经常发生这样的问题呢?做好的网页在自己机器上可以正常浏览,而把页面传到服务器上就总是出现看不到图片,css
- 要想在不宽裕的页面展现丰富的内容,现在通用的做法使用tab,在一块区域通过tab切换来更换该区域的内容。这篇文章分析了tab设计很在理,今天
- 在thoughtbot,我们用Ruby和Rails工作,但通常我们总是尝试使用最合适的语言或者框架来解决问题。我最近一直在探索机器学习技术,
- 1,下载Yii,站点:http://www.yiiframework.com/download/注意版本,这里是根据Yii1来的,如果是Yi
- 当用到socket来进行网络程序开发时,大多数情况下会遇到中文字符的发送与接收,这时若对发送的字符串用默认的方式进行处理,则一般会得到一堆乱