tensorflow使用CNN分析mnist手写体数字数据集
作者:Dillon2015 发布时间:2021-07-20 20:29:35
标签:tensorflow,mnist,数据集
本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下
import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
return tf.Variable(tf.random_normal(shape, stddev=0.01))
w = init_weights([3, 3, 1, 32]) # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64]) # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128]) # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625]) # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
# 第一组卷积层及池化层,最后dropout一些神经元
l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
# l1a shape=(?, 28, 28, 32)
l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# l1 shape=(?, 14, 14, 32)
l1 = tf.nn.dropout(l1, p_keep_conv)
# 第二组卷积层及池化层,最后dropout一些神经元
l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
# l2a shape=(?, 14, 14, 64)
l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# l2 shape=(?, 7, 7, 64)
l2 = tf.nn.dropout(l2, p_keep_conv)
# 第三组卷积层及池化层,最后dropout一些神经元
l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
# l3a shape=(?, 7, 7, 128)
l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# l3 shape=(?, 4, 4, 128)
l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
l3 = tf.nn.dropout(l3, p_keep_conv)
# 全连接层,最后dropout一些神经元
l4 = tf.nn.relu(tf.matmul(l3, w4))
l4 = tf.nn.dropout(l4, p_keep_hidden)
# 输出层
pyx = tf.matmul(l4, w_o)
return pyx #返回预测值
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
# you need to initialize all variables
tf. global_variables_initializer().run()
for i in range(100):
training_batch = zip(range(0, len(trX), batch_size),
range(batch_size, len(trX)+1, batch_size))
for start, end in training_batch:
sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
p_keep_conv: 0.8, p_keep_hidden: 0.5})
test_indices = np.arange(len(teX)) # Get A Test Batch
np.random.shuffle(test_indices)
test_indices = test_indices[0:test_size]
print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
sess.run(predict_op, feed_dict={X: teX[test_indices],
p_keep_conv: 1.0,
p_keep_hidden: 1.0})))
来源:https://blog.csdn.net/Dillon2015/article/details/79068653


猜你喜欢
- 在vue-cli3中,公共文件夹由static变成了public先把要访问的json放到public文件夹下使用axios的get方法获取,
- 本文实例讲述了php与javascript正则匹配中文的方法。分享给大家供大家参考,具体如下:php中正则匹配utf-8中文: (重点是:[
- GetRef 函数 返回一个指向一过程的引用,此过程可绑定某事件。 Set object.eventname = GetRef(procna
- 本文实例讲述了JS实现屏蔽网页右键复制及ctrl+c复制的方法。分享给大家供大家参考,具体如下:老是有些网站会屏蔽你的鼠标右键或者用快捷键复
- frm文件和ibd文件简介 在MySQL中,如果我们使用了默认的存储引擎innodb创建一张表,那么在文件夹下面就会
- 通过python与ffmpeg结合使用,可生成进行视频点播、直播的压力测试脚本。可支持不同类型的视频流,比如rtmp或者hls形式。 通过如
- 今天我们将介绍处理大量数据时非常方便的工具。我不会只告诉您可能在手册中找到的一般信息,而是分享一些我发现的小技巧,例如tqdm与 multi
- swiper是我之前做前端页面会用到的一个插件,我自己认为是非常好用的。swiper提供了形式多种多样、适应各个终端的轮播图效果。本文是小编
- 本文实例讲述了python实现的用于搜索文件并进行内容替换的类。分享给大家供大家参考。具体实现方法如下:#!/usr/bin/python
- 总结了一下自己工作中使用到的注释书写规范,没有什么技术含量,只是用于统一制作方式,方便维护。包含了“区域注释”、“单行注释”、“注释层级”和
- Python 环境下文件的读取问题,请参见拙文 Python基础之文件读取的讲解这是一道著名的 Python 面试题,考察的问题是,Pyth
- 为 Web页指定 DOCTYPE 会影响浏览器呈现页的方式。Internet Explorer、Mozilla Firefox 和 Oper
- fixtures调用其他fixtures及fixture复用性 pytest最大的优点之一就是它非常灵活。它可以将复杂的测试需求简
- 一、Pytorch修改预训练模型时遇到key不匹配最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保
- 本文实例演示了Python生成pdf文件的方法,是比较实用的功能,主要包含2个文件。具体实现方法如下:pdf.py文件如下:#!/usr/b
- 之前我们在入门jdbc的时候,常用这种方法连接数据库:package util;import java.sql.Connection;imp
- 背景小程序在网络层提供的API是能够完成一个程序与服务端交互的完整链路,但需要大量的定制化代码,才能实现请求拦截和响应拦截,不太符合大多数开
- 变量什么是变量?变量是在程序运行时,能存储计算结果或能表示值得抽象概念。简单地说,变量就是在程序运行时,记录数据用的变量定义格式:变量名称=
- 前奏为了能操作数据库, 首先我们要有一个数据库, 所以要首先安装Mysql, 然后创建一个测试数据库python_test用以后面的测试使用
- 最近在看python的算法书,之前在年前买的书,一直在工作间隙的时候,学习充电,终于看到这本书,但是确实又有点难,感觉作者写的代码太炫技 了