TensorFlow MNIST手写数据集的实现方法
作者:Baby-Lily 发布时间:2022-12-19 19:45:02
标签:TensorFlow,MNIST,数据集
MNIST数据集介绍
MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。
MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。
在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。
标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。
下面定义了一个简单的网络对数据集进行训练,代码如下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
avg_cost += c / total_batch
if (epoch + 1) % display_step == 0:
print('epoch= ', epoch+1, ' cost= ', avg_cost)
print('finished')
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_path=save_path+'mnist.cpkt')
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
output = tf.argmax(pred, 1)
batch_xs, batch_ys = mnist.test.next_batch(2)
outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
print(outputval)
im = batch_xs[0]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()
im = batch_xs[1]
im = im.reshape(-1, 28)
plt.imshow(im, cmap='gray')
plt.show()
总结
以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!
来源:https://www.cnblogs.com/baby-lily/p/10961482.html


猜你喜欢
- 打桩测试当我们在编写单元测试的时候,有时我们非常想 mock 掉其中一个方法,但是这个方法又没有接口去定义和实现(无法用 gith
- 通常程序会被编写为一个顺序执行并完成一个独立任务的代码。如果没有特别的需求,最好总是这样写代码,因为这种类型的程序通常很容易写,也很容易维护
- 说明1、Matplotlib函数可以绘制图形,使用plot函数绘制曲线。2、需要将200个点的x坐标和Y坐标分别以序列的形式输入plot函数
- 前面学习过Meanshift算法,在观察这个结果标记时,会发现有这样一个问题,如下图:汽车比较远时,用一个很小的窗口就可以把它框住,这是符合
- 一、Python urllib 模块是什么urllib 模块是 Python 标准库,其价值在于抓取网络上的 URL 资源,入门爬
- 使用PyQt5开发图形界面,里面使用日期框,这里把这个QDateEdit组件命名为:beginDatefrom PyQt5.QtCore i
- 一、ADO.Net数据库连接字符串1、OdbcConnection(System.Data.Odbc)(1)SQL Sever标准安全:&q
- django实现多种支付方式'''#思路我们希望,通过插拔的方式来实现多方式登录,比如新增一种支付方式,那么只要在项
- 1、update delete insert 这种语句都需要commit或者直接在连接数据库的时候加上autocommit=Trueimpo
- 目录1安装loguru|2loguru简单使用|3loguru保留日志文件|4loguru字符串输出|5loguru封装类,可以直接拿去用!
- <!doctype><html><head><title>新闻图片轮换类</title
- 目录:分析和设计组件编码实现和算法用 Ant 构建组件测试 JavaScript 组件我们走到哪儿了?前两期思考了太多东西,你是否已有倦意?
- 前言前面已经讲述了如何获取股票的k线数据,今天我们来分析一下股票的资金流入情况,股票的上涨和下跌都是由资金推动的,这其中的北上资金就是一个风
- 最近在做移动端项目时,需要实现一个列表页面的每一项item向左滑动时出现相应的删除按钮,本来想着直接使用zepto的touch.js插件,因
- python有很多有趣的库,其中wxpy是连接微信的接口,具体可以查看官方文档。可以实现自动操作,wxpy 支持 Python 3.4-3.
- 一、图像噪声图像噪声是图像在获取或者传输过程中受到随机信号干扰,妨碍人们对图像理解及分析处理的信号。很多时候将图像看作随机过程,因而描述噪声
- 关于端口复用一个套接字不能同时绑定多个端口,如果客户端想绑定端口号,一定要调用发送信息函数之前绑定( bind )端口,因为在发送信息函数(
- 摘 要:本文讨论了Visual Basic应用程序访问SQL Server数据库的几种常用的方法,分别说明了每种方法的内部机理并给出了每种方
- 本文实例为大家分享了python实现转盘效果的具体代码,供大家参考,具体内容如下#抽奖 面向对象版本import tkinterimport
- Servermanager启动连接数据库错误运行mgrstart.bat报错如下解决办法:修改C:\Siemens\Teamcenter12