python tensorflow基于cnn实现手写数字识别
作者:LN_IOS 发布时间:2023-05-09 06:22:06
标签:python,tensorflow,数字识别
一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()
"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值
# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
# 截尾正态分布,stddev是正态分布的标准偏差
initial = tf.truncated_normal(shape=shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
# 卷积核池化,步长为1,0边距
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])
# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])
# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)
"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())
# 迭代优化模型
for i in range(2000):
# 每次取50个样本进行训练
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={
x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
print("step %d, training accuracy %g" % (i, train_accuracy))
train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
做了2000次迭代,在测试集上的识别精度能够到0.9772……
来源:http://blog.csdn.net/LN_IOS/article/details/77966232


猜你喜欢
- 快速回顾一下RabbitMQ服务器的安装:sudo apt-get install rabbitmq-serverPython使用Rabbi
- 双向数据绑定指的是当对象的属性发生变化时能够同时改变对应的UI,反之亦然。换句话说,如果我们有一个user对象,这个对象有一个name属性,
- 如下所示:def trans_data_to_pair(self,data,index): contents=[
- 1.Apache2.2\conf\httpd.conf中释放: Include conf/extra/httpd-vhosts.conf(去
- PL/SQL是由Oracle公司对标准SQL进行扩展,专用于Oracle数据库中程序设计的专用语言,属第三代过程式程序设计语言。从Oracl
- 本文实例讲述了C#使用Socket快速判断数据库连接是否正常的方法。分享给大家供大家参考。具体分析如下:大家在做项目的时候,一般都是和数据库
- 一段时间以来,发现有很多人XHTML都不会用,不光是普通的初学者,有的程序员都不是很清楚该怎么写这个XHTML,我这里呢算是把一些常见的应用
- 原则, 以datetime为中心, 起点或中转, 转化为目标对象, 涵盖了大多数业务场景中需要的日期转换处理步骤:1. 掌握几种对象及其关系
- 在python中json分别由列表和字典组成,本文主要介绍python中字典与json相互转换的方法。使用json.dumps可以把字典转成
- 一、Python解释器 安装Windows平台下载地址 https://www.python.org/ftp/python/3.9.5/py
- 对大家推荐很好使用的MySql节点系统,像让大家对MySql节点系统有所了解,然后对MySql节点系统全面讲解介绍,希望对大家有用在向大家详
- 写在前面我的 CUDA 版本是什么? 这个问题本身就是有问题的,因为没有搞清楚cuda的分类这里的 CUDA 说的是 Driver CUDA
- os.system()在shell中执行一条命令。函数原型如下:它是最简单的调用系统应用的方式,下面是一个例子:import osimpor
- 引言反射是通过实体对象获取反射对象(Value、Type),然后可以操作相应的方法。在某些情况下,我们可能并不知道变量的具体类型,这时候就可
- scikit-learn 是基于 Python 语言的机器学习工具简单高效的数据挖掘和数据分析工具可供大家在各种环境中重复使用建立在 Num
- php魔术方法在php类保留方法中以 “__”两个下划线开头的函数称为魔术方法,我的理解为php类设
- 句柄(handle)是C++程序设计中经常提及的一个术语。它并不是一种具体的、固定不变的数据类型或实体,而是代表了程序设计中的一个广义的概念
- 因为 GAE 在国内访问不便,所以平时有一些小应用,我都会放在 SAE 上面, 虽然 SAE 还有很多缺陷,但算是上手比较容易的一个了,最起
- 在PHP中,我们不能用const直接定义数组常量,但是const可以定义字符串常量,结合eval()函数使字符串常量能执行。所以,我们可以用
- 应用一:有时候我们想把一个 list 或者 dict 传递给 javascript,处理后显示到网页上,比如要用 js 进行可视化的数据。请