基于Tensorflow的MNIST手写数字识别分类
作者:qq_40579095 发布时间:2023-12-01 11:35:18
标签:Tensorflow,MNIST,数字识别
本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下
代码如下:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time
IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
#全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
#输入数据
with tf.name_scope('x'):
x = tf.placeholder(
dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
#收集x图像的会总数据
with tf.name_scope('x_summary'):
shaped_image_batch = tf.reshape(
tensor = x,
shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
name = 'shaped_image_batch')
tf.summary.image(name = 'image_summary',
tensor = shaped_image_batch,
max_outputs = 10)
with tf.name_scope('y_'):
y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))
with tf.name_scope('hidden_layer'):
with tf.name_scope('hidden_arg'):
#隐层模型参数
with tf.name_scope('hid_w'):
hid_w = tf.Variable(
tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
name = 'hidden_w')
#添加获取隐层权重统计值汇总数据的汇总操作
tf.summary.histogram(name = 'weights', values = hid_w)
with tf.name_scope('hid_b'):
hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
name = 'hidden_b')
#隐层输出
with tf.name_scope('relu'):
hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
with tf.name_scope('softmax_arg'):
#softmax层参数
with tf.name_scope('sm_w'):
sm_w = tf.Variable(
tf.truncated_normal(shape = (hidden_unit, output_nums)),
name = 'softmax_w')
#添加获取softmax层权重统计值汇总数据的汇总操作
tf.summary.histogram(name = 'weights', values = sm_w)
with tf.name_scope('sm_b'):
sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32),
name = 'softmax_b')
#softmax层的输出
with tf.name_scope('softmax'):
y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
#梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
#使用交叉熵代价函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
#添加获取交叉熵的汇总操作
tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
with tf.name_scope('train'):
#若不使用同步训练机制,使用Adam优化器
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
#单步训练操作,
train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}
with tf.name_scope('accuracy'):
prediction = tf.equal(tf.argmax(input = y, axis = 1),
tf.argmax(input = y_, axis = 1))
accuracy = tf.reduce_mean(
input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
with open(logdir + '/metadata.tsv', 'w') as f:
label = np.nonzero(test_label)[1]
for i in range(test_data_size):
f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
config = projector.ProjectorConfig()
embeddings = config.embeddings.add()
embeddings.tensor_name = 'embedding:0'
embeddings.metadata_path = logdir + '/metadata.tsv'
projector.visualize_embeddings(writer, config)
#聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
allow_soft_placement = True,
log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
#创建FileWriter实例
writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
#初始化全局变量
sess.run(tf.global_variables_initializer())
time_begin = time.time()
print('Training begin time: %f' % time_begin)
while True:
#加载训练批数据
batch_x, batch_y = mnist.train.next_batch(batch_size)
train_feed = {x:batch_x, y_:batch_y}
loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
step = global_step.eval()
#如果step为100的整数倍
if step % 100 == 0:
now = time.time()
print('%f: global_step = %d, loss = %f' % (
now, step, loss))
#向事件文件中添加汇总数据
writer.add_summary(summary = summary, global_step = step)
#若大于等于训练总步数,退出训练
if step >= train_steps:
break
time_end = time.time()
print('Training end time: %f' % time_end)
print('Training time: %f' % (time_end - time_begin))
#测试模型精度
test_accuracy = sess.run(accuracy, feed_dict = test_feed)
print('accuracy: %f' % test_accuracy)
saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
CreateMedaDataFile()
CreateProjectorConfig()
#关闭FileWriter
writer.close()
来源:https://blog.csdn.net/qq_40579095/article/details/88804019


猜你喜欢
- 题目:轮盘分为三部分: 一等奖, 二等奖和三等奖;轮盘转的时候是随机的,如果范围在[0,0.08)之间,代表一等奖,如果范围在[0.08,0
- 问题描述:想要去掉图像背景,只保留中心部分目标:1.利用ITK-SNAP制作二值化标签(即mask)2.利用软件ITK-SNAP把一幅图像中
- 一、代码注释介绍注释就是对代码的解释和说明,其目的是让人们能够更加轻松地了解代码。注释是编写程序时,写程序的人给一个语句、程序段、函数等的解
- 先给大家展示下效果图:向下滑动网页的时候能够自动加载图片并显示。盛放图片的盒子模型如下:<div class="box&qu
- MySQL Version确认(版本确认)的几个方法1.SHOW VARIABLES LIKE 'VERSION';mysq
- 在深度学习中,模型的输入size通常是正方形尺寸的,比如300 x 300这样.直接resize的话,会把图像拉的变形.通常我们希望resi
- 一、运算符算术运算符:+ - * / 可以在select 语句中使用连接运算符:|| select deptno|| dname from
- 可以写一个函数: 主要是使用正则来判断。另外输入字符是空的话,使用"-"来替换。CREATE FUNCTION [dbo
- 1、远程登录到linux上,使用到的模块paramiko#远程登陆操作系统def ssh(sys_ip,username,password,
- 缩进Python最具特色的是用缩进来标明成块的代码。我下面以if选择结构来举例。if后面跟随条件,如果条件成立,则执行归属于if的一个代码块
- 本文实例讲述了Python实现将数据框数据写入mongodb及mysql数据库的方法。分享给大家供大家参考,具体如下:主要内容:1、数据框数
- 导言Python官方文档对于内置函数的介绍较为简略,但这些内置函数在日常工作中却扮演着不可或缺的角色。为了更加便捷地使用和查阅这些函数,笔者
- 首先忠心感谢凌宇5942给我的帮助!在他的启迪下我发现了另一种实现flash透明背景的办法,愿与大家共同探讨:凌宇5942告知的解决办法:在
- 出差做PPT,要放一些图片上去,原图太大必须resize,十几张图片懒得一一处理了,最近正好在学python,最好的学习方式就是使用,于是写
- 安装去http://www.mysql.com/downloads/, 选择最下方的MySQL Community Edition,点击My
- 一、手指触屏,利用touchstart和touchend计算前后滑动距离,判断是上拉还是下滑。二、js中距离:pageY、clientY、o
- 1 前言上篇文章Python爬虫获取基金列表我们已经讲述了如何从基金网站上获取基金的列表信息。这一骗我们延续上一篇,继续分享如何抓取基金的基
- 目录一、前言二、Json.loads与eval 性能对比1. eval2. json.loads一、前言最近发现一些小伙伴使用eval来处理
- 在 Python 中一切都是对象。如果要在 Python 中表示一个对象,除了定义 class 外还有哪些方式
- api文档 https://sms-activate.org/cn/api2要使用SMSActivateAPI库从sms-activate.