tensorflow实现softma识别MNIST
作者:freedom098 发布时间:2021-02-17 22:32:56
标签:tensorflow,softma,MNIST
识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。
这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。
误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。
另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。
代码如下:
#encoding=utf-8
__author__ = 'freedom'
import tensorflow as tf
def loadMNIST():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
return mnist
def softmax(mnist,rate=0.01,batchSize=50,epoch=20):
n = 784 # 向量的维度数目
m = None # 样本数,这里可以获取,也可以不获取
c = 10 # 类别数目
x = tf.placeholder(tf.float32,[m,n])
y = tf.placeholder(tf.float32,[m,c])
w = tf.Variable(tf.zeros([n,c]))
b = tf.Variable(tf.zeros([c]))
pred= tf.nn.softmax(tf.matmul(x,w)+b)
loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
opt = tf.train.GradientDescentOptimizer(rate).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for index in range(epoch):
avgLoss = 0
batchNum = int(mnist.train.num_examples/batchSize)
for batch in range(batchNum):
batch_x,batch_y = mnist.train.next_batch(batchSize)
_,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y})
avgLoss += Loss
avgLoss /= batchNum
print 'every epoch average loss is ',avgLoss
right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(right,tf.float32))
print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels}))
if __name__ == "__main__":
mnist = loadMNIST()
softmax(mnist)
来源:http://blog.csdn.net/freedom098/article/details/52116813
0
投稿
猜你喜欢
- 从Request对象中获取数据我们在第三章讲述View的函数时已经介绍过HttpRequest对象了,但当时并没有讲太多。 让我们回忆下:每
- 前端的小伙伴们在babel等的加持下,已经可以愉快的使用es6来写代码了。然后对于服务端的nodejs就有点坑爹了,虽然原生支持了es6,但
- 自从HTML5能为我们的新网页带来更高效洁净的代码而得到更多的关注,然而唯一能让IE识别那些新元素(如<article>)的途径
- 在函数参数中乱用表达式作为默认值Python允许给一个函数的某个参数设置默认值以使该参数成为一个可选参数。尽管这是这门语言很棒的一个功能,但
- 在本节中,我们将详细介绍 Python 标准库中的 json 模块。JSON(JavaScript Objec
- 本文实例讲述了Go语言生成随机数的方法。分享给大家供大家参考。具体实现方法如下:golang生成随机数可以使用math/rand包packa
- 动态加载JavaScript文件和CSS资源为Web前端开发提供了巨大的灵活性,同时也实现了lazy load和按需加载,相比XMLHttp
- 熟悉SQL的人都知道,完成同一个任务,SQL可能有多种写法,但不同写法的查询性能可能会有天壤之别,本文列举出五个查询优化的方法,当然,优化的
- 前言在Python中定义函数,可以用必选参数、默认参数、可变参数和关键字参数,这4种参数都可以一起使用,或者只用其中某些,但是请注意,参数定
- 一、简介 XML(eXtensible Markup Languag
- 一、什么是 Postman(前世今生)Postman 诞生于 2013 年,一开始只是 Abhinav Asthana 着手于解决 API
- 本文实例讲述了PHP接口多继承及tarits实现多继承效果的方法。分享给大家供大家参考,具体如下:接口多继承在PHP的面向对象中,接口可以继
- 在我们开始之前,一定要注意这篇文章只针对Windows用户!对于那些使用Windows的人来说,这是一个有趣的想法。如果您想使用python
- 本文实例讲述了python根据路径导入模块的方法,分享给大家供大家参考。具体方法如下:常规做法如下:import sys sys.path.
- expect脚本expect是什么expect是一个免费的编程工具,用来实现自动的交互式任务,而无需人为干预。说白了,expect就是一套用
- 如何在页面中快捷地添加翻页按钮? 先编写一个nextprev.inc文件,再将代码<
- CUDA的线程与块GPU从计算逻辑来讲,可以认为是一个高并行度的计算阵列,我们可以想象成一个二维的像围棋棋盘一样的网格,每一个格子都可以执行
- 最近需要用Python写一个简易通讯录,但是对于数据存储很发愁。大家都知道,使用 Python 中的列表和字典进行存储数据是很不靠谱的,所以
- 一,前言今天做的东西,还算可以,修改了若干个bug,自己又写成功的写了几个bug。增加了一个功能——
- 其实这个东西没什么技术含量,就是给大家提供一个给表格加滚动条的思路。运行代码框<html><head><tit