用tensorflow搭建CNN的方法
作者:freedom098 发布时间:2021-05-29 08:51:17
CNN(Convolutional Neural Networks) 卷积神经网络简单讲就是把一个图片的数据传递给CNN,原涂层是由RGB组成,然后CNN把它的厚度加厚,长宽变小,每做一层都这样被拉长,最后形成一个分类器
在 CNN 中有几个重要的概念:
stride
padding
pooling
stride,就是每跨多少步抽取信息。每一块抽取一部分信息,长宽就缩减,但是厚度增加。抽取的各个小块儿,再把它们合并起来,就变成一个压缩后的立方体。
padding,抽取的方式有两种,一种是抽取后的长和宽缩减,另一种是抽取后的长和宽和原来的一样。
pooling,就是当跨步比较大的时候,它会漏掉一些重要的信息,为了解决这样的问题,就加上一层叫pooling,事先把这些必要的信息存储起来,然后再变成压缩后的层
利用tensorflow搭建CNN,也就是卷积神经网络是一件很简单的事情,笔者按照官方教程中使用MNIST手写数字识别为例展开代码,整个程序也基本与官方例程一致,不过在比较容易迷惑的地方加入了注释,有一定的机器学习或者卷积神经网络制式的人都应该可以迅速领会到代码的含义。
#encoding=utf-8
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1) #截断正态分布,此函数原型为尺寸、均值、标准差
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
return tf.Variable(initial)
def conv2d(x,W):
return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME') # strides第0位和第3为一定为1,剩下的是卷积的横向和纵向步长
def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize = [1,2,2,1],strides=[1,2,2,1],padding='SAME')# 参数同上,ksize是池化块的大小
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
# 图像转化为一个四维张量,第一个参数代表样本数量,-1表示不定,第二三参数代表图像尺寸,最后一个参数代表图像通道数
x_image = tf.reshape(x,[-1,28,28,1])
# 第一层卷积加池化
w_conv1 = weight_variable([5,5,1,32]) # 第一二参数值得卷积核尺寸大小,即patch,第三个参数是图像通道数,第四个参数是卷积核的数目,代表会出现多少个卷积特征
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
# 第二层卷积加池化
w_conv2 = weight_variable([5,5,32,64]) # 多通道卷积,卷积出64个特征
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
# 原图像尺寸28*28,第一轮图像缩小为14*14,共有32张,第二轮后图像缩小为7*7,共有64张
w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64]) # 展开,第一个参数为样本数量,-1未知
f_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
# dropout操作,减少过拟合
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(f_fc1,keep_prob)
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) # 定义交叉熵为loss函数
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 调用优化器优化
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
for i in range(2000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g"%(i, train_accuracy)
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print "test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images[0:500], y_: mnist.test.labels[0:500], keep_prob: 1.0})
在程序中主要注意这么几点:
1、维度问题,由于我们tensorflow基于的是张量这样一个概念,张量其实就是维度扩展的矩阵,因此维度特别重要,而且维度也是很容易使人迷惑的地方。
2、卷积问题,卷积核不只是二维的,多通道卷积时卷积核就是三维的
3、最后进行检验的时候,如果一次性加载出所有的验证集,出现了内存爆掉的情况,由于是使用的是云端的服务器,可能内存小一些,如果内存够用可以直接全部加载上看结果
4、这个程序原始版本迭代次数设置了20000次,这个次数大约要训练数个小时(在不使用GPU的情况下),这个次数可以按照要求更改。
来源:http://blog.csdn.net/freedom098/article/details/51918825


猜你喜欢
- 译者 | 豌豆花下猫声明 :本文获得原作者授权翻译,转载请保留原文出处,请勿用于商业或非法用途。有许许多多文章写了 Python 中的许多很
- 使用SQL SERVER的[导入]功能,便可将access数据转换,但要注意原来的'自增字段'需要修改,将相应字段标识修改为
- overflow:hidden 用在div上时很好用,但直接用在td上,好像没有任何效果。td中的文本过长时依然自动换了一行像下面这要设定一
- 前言版本:windows 10.0python 3.8多重继承在Python数字比较与类结构中有简略提到类,且在Python中类的mro与继
- 第一题:ASP中,VBScript的唯一的数据类型是什么?第二题:在ASP中,VBScript有多种控制程序流程语句,如If…Then, S
- 简介在SQL SERVER中,数据库在硬盘上的存储方式和普通文件在Windows中的存储方式没有什么不同,仅仅是几个文件而已.SQL SER
- 本文实例讲述了微信小程序列表渲染功能之列表下拉刷新及上拉加载的实现方法。分享给大家供大家参考,具体如下:微信小程序为2017年1月9日打下了
- 前言Vuex 和 Pinia 是用于管理 Vue.js 应用程序状态的标准 Vue.js 库。让我们比较一下他们的代码、语言、功能和社区支持
- default-character-set=gbk #或gb2312,big5,utf8 然后重新启动mysql 运行->servic
- 场景:有一个多层嵌套的列表如:[[23],[3,3],[22,22],1,123,[[123,a],2]] 拆分成:def splitlis
- 前言今天就来理一理session、cookie、token这三者之间的关系!1.为什么会有它们?我们都知道 HTTP 协议是无状态的,所谓的
- ADO也提供更有效率方法来取得数据。GetRows 方法传回一个二维的数组变量,每一行对应Recordset中的一笔记录,且每
- 利用SocketServer模块来实现网络客户端与服务器并发连接非阻塞通信。首先,先了解下SocketServer模块中可供使用的类:Bas
- 1,reload 方法该方法强迫浏览器刷新当前页面。语法:location.reload([bForceGet])参数: bForceGet
- 安装cuda 我强调下 这个需要注意版本问题的注意 (个人的想法,安装思路,仅供参考)pytorch 需要注意这个现在支持的版本.根据这个支
- 描述filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。该接收
- 前言orztop是一款实时show full processlist的工具,我们可以实时看到数据库有哪些线程,执行哪些语句等。工具使用方便简
- 前言这几年对运维人员来说最大的变化可能就是公有云的出现了,我相信可能很多小伙伴公司业务就跑在公有云上, 因为公司业务关系,我个人
- 先要明白Fscanf的工作原理Fscanf在遇到\n才结束遇到\r时就会把\r替换成0这就有个问题,要注意自己的文本换行符是什么,在Wind
- IE8主页http://www.microsoft.com/windows/products/winfamily/ie/ie8/defaul