TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集
作者:零尾 发布时间:2021-05-24 18:25:35
标签:TensorFlow,Softmax,MNIST,逻辑回归
基于MNIST数据集的逻辑回归模型做十分类任务
没有隐含层的Softmax Regression只能直接从图像的像素点推断是哪个数字,而没有特征抽象的过程。多层神经网络依靠隐含层,则可以组合出高阶特征,比如横线、竖线、圆圈等,之后可以将这些高阶特征或者说组件再组合成数字,就能实现精准的匹配和分类。
import tensorflow as tf
import numpy as np
import input_data
print('Download and Extract MNIST dataset')
mnist = input_data.read_data_sets('data/', one_hot=True) # one_hot=True意思是编码格式为01编码
print("tpye of 'mnist' is %s" % (type(mnist)))
print("number of train data is %d" % (mnist.train.num_examples))
print("number of test data is %d" % (mnist.test.num_examples))
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print("MNIST loaded")
"""
print("type of 'trainimg' is %s" % (type(trainimg)))
print("type of 'trainlabel' is %s" % (type(trainlabel)))
print("type of 'testimg' is %s" % (type(testimg)))
print("type of 'testlabel' is %s" % (type(testlabel)))
print("------------------------------------------------")
print("shape of 'trainimg' is %s" % (trainimg.shape,))
print("shape of 'trainlabel' is %s" % (trainlabel.shape,))
print("shape of 'testimg' is %s" % (testimg.shape,))
print("shape of 'testlabel' is %s" % (testlabel.shape,))
"""
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # None is for infinite
w = tf.Variable(tf.zeros([784, 10])) # 为了方便直接用0初始化,可以高斯初始化
b = tf.Variable(tf.zeros([10])) # 10分类的任务,10种label,所以只需要初始化10个b
pred = tf.nn.softmax(tf.matmul(x, w) + b) # 前向传播的预测值
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=[1])) # 交叉熵损失函数
optm = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) # tf.equal()对比预测值的索引和真实label的索引是否一样,一样返回True,不一样返回False
accr = tf.reduce_mean(tf.cast(corr, tf.float32))
init = tf.global_variables_initializer() # 全局参数初始化器
training_epochs = 100 # 所有样本迭代100次
batch_size = 100 # 每进行一次迭代选择100个样本
display_step = 5
# SESSION
sess = tf.Session() # 定义一个Session
sess.run(init) # 在sess里run一下初始化操作
# MINI-BATCH LEARNING
for epoch in range(training_epochs): # 每一个epoch进行循环
avg_cost = 0. # 刚开始损失值定义为0
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch): # 每一个batch进行选择
batch_xs, batch_ys = mnist.train.next_batch(batch_size) # 通过next_batch()就可以一个一个batch的拿数据,
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys}) # run一下用梯度下降进行求解,通过placeholder把x,y传进来
avg_cost += sess.run(cost, feed_dict={x: batch_xs, y:batch_ys})/num_batch
# DISPLAY
if epoch % display_step == 0: # display_step之前定义为5,这里每5个epoch打印一下
train_acc = sess.run(accr, feed_dict={x: batch_xs, y:batch_ys})
test_acc = sess.run(accr, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Epoch: %03d/%03d cost: %.9f TRAIN ACCURACY: %.3f TEST ACCURACY: %.3f"
% (epoch, training_epochs, avg_cost, train_acc, test_acc))
print("DONE")
迭代100次跑一下模型,最终,在测试集上可以达到92.2%的准确率,虽然还不错,但是还达不到实用的程度。手写数字的识别的主要应用场景是识别银行支票,如果准确率不够高,可能会引起严重的后果。
Epoch: 095/100 loss: 0.283259882 train_acc: 0.940 test_acc: 0.922
插一些知识点,关于tensorflow中一些函数的用法
sess = tf.InteractiveSession()
arr = np.array([[31, 23, 4, 24, 27, 34],
[18, 3, 25, 0, 6, 35],
[28, 14, 33, 22, 30, 8],
[13, 30, 21, 19, 7, 9],
[16, 1, 26, 32, 2, 29],
[17, 12, 5, 11, 10, 15]])
在tensorflow中打印要用.eval()
tf.rank(arr).eval() # 打印矩阵arr的维度
tf.shape(arr).eval() # 打印矩阵arr的大小
tf.argmax(arr, 0).eval() # 打印最大值的索引,参数0为按列求索引,1为按行求索引
来源:https://blog.csdn.net/lwplwf/article/details/60603746
0
投稿
猜你喜欢
- 前言针对于一维数组的存储方式,即(n,)存储为列向量一、创建一个array使用np.arange()创建一个一维数组,或者np.array(
- 在这个擦亮自己的眼睛去看SQL Server的系列中的第二篇中提过要写历史渊源,这里的历史主要描述的是数据库本身的历史与SQL Server
- 我就废话不多说了,大家还是直接看代码吧try: s = socket.socket() s.bind(('127.0.0.1'
- 本文实例讲述了Python使用crontab模块设置和清除定时任务操作。分享给大家供大家参考,具体如下:centos7下安装Python的p
- 使用 sorted() 函数使用 sorted() 函数对字典进行排序,将其转换为元组列表,再按照指定的键或者值进行排序。按照键排序的示例代
- Firefox 3.5已经发布了几个月了,且已经历5次小幅更新。而基于Gecko 1.9.2的Firefox 3.6也已经开发数月,现在已经
- 踩了很多坑,记录一下这次试验,本次测试环境:Linux centos7 64位。pyenv是一个python版本管理工具,它能够进行全局的p
- 最近在做编程练习,发现有些结果的值与答案相差较大,通过分析比较得出结论,大概过程如下:定义了一个计算损失的函数:def error(yhat
- 我是在做行人检测中需要将一段视频变为图片数据集,然后想将视频每秒钟的图片提取出来。语言:python所需要的库:cv2,numpy (自行安
- 相信很多程序员在调试代码时,都用过 print。代码少还好说,如果是大型项目,面对众多 print 的输出结果,可能要头大了。今天推荐一个
- 本文实例为大家分享了Python使用Pillow添加水印的具体代码,供大家参考,具体内容如下python数据分析得到的图片,并对照片添加水印
- asp 在线备份 恢复 sql server 数据库,对于远程没有提供sql server远程连接或打包下载的朋友是个临时解决方法,对于大数
- 注:所有文字,除注明网站类型外,其他均针对企业站点.请随时注意留言,若修改则会在首页提示文字里标注.若牵扯到业务方面的问题,我可能不会做过多
- 前言以下是我对python中编写脚本最重要的库之一pyautogui的学习整理,分享给大家希望有所帮助提示:我在初步使用pyautogui的
- 给定list,如何以空格/逗号等符号以分隔符输出呢?一般的,简单的for循环可以打印出list的内容:l=[1,2,3,4]for i in
- 一、MySQL修改密码方法总结首先要说明一点的是:一般情况下,修改MySQL密码是需要有mysql里的root权限的,这样一般用户是无法更改
- XPath(XML Path language)是一种处理XML文档段的语言。XSLT(Extensible Stylesheet Lang
- 2022-09-29shell操作:我在使用中是pycharm与数据库建立连接的一个工具。使用的环境:在此处是用在了虚拟环境中。使用场景:一
- Python中的[1:]意思是去掉列表中第一个元素(下标为0),去后面的元素进行操作,以一个示例题为例,用在遍历中统计个数:题:读入N名学生
- anaconda指的是一个开源的Python