tensorflow学习笔记之简单的神经网络训练和测试
作者:denny402 发布时间:2021-02-18 20:28:14
标签:tensorflow,神经网络
本文实例为大家分享了用简单的神经网络来训练和测试的具体代码,供大家参考,具体内容如下
刚开始学习tf时,我们从简单的地方开始。卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始。
神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输出层。
数据从输入层输入,在隐藏层进行加权变换,最后在输出层进行输出。输出的时候,我们可以使用softmax回归,输出属于每个类别的概率值。借用极客学院的图表示如下:
其中,x1,x2,x3为输入数据,经过运算后,得到三个数据属于某个类别的概率值y1,y2,y3. 用简单的公式表示如下:
在训练过程中,我们将真实的结果和预测的结果相比(交叉熵比较法),会得到一个残差。公式如下:
y是我们预测的概率值,y'是实际的值。这个残差越小越好,我们可以使用梯度下降法,不停地改变W和b的值,使得残差逐渐变小,最后收敛到最小值。这样训练就完成了,我们就得到了一个模型(W和b的最优化值)。
完整代码如下:
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_actual = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10])) #初始化权值W
b = tf.Variable(tf.zeros([10])) #初始化偏置项b
y_predict = tf.nn.softmax(tf.matmul(x,W) + b) #加权变换并进行softmax回归,得到预测概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indies=1)) #求交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #用梯度下降法使得残差最小
correct_prediction = tf.equal(tf.argmax(y_predict,1), tf.argmax(y_actual,1)) #在测试阶段,测试准确度计算
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) #多个批次的准确度均值
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
for i in range(1000): #训练阶段,迭代1000次
batch_xs, batch_ys = mnist.train.next_batch(100) #按批次训练,每批100行数据
sess.run(train_step, feed_dict={x: batch_xs, y_actual: batch_ys}) #执行训练
if(i%100==0): #每训练100次,测试一次
print "accuracy:",sess.run(accuracy, feed_dict={x: mnist.test.images, y_actual: mnist.test.labels})
每训练100次,测试一次,随着训练次数的增加,测试精度也在增加。训练结束后,1W行数据测试的平均精度为91%左右,不是太高,肯定没有CNN高。
来源:http://www.cnblogs.com/denny402/p/5852983.html
0
投稿
猜你喜欢
- 一、前言何谓动态导入模块,就是说模块的导入可以根据我们的需求动态的去导入,不是像一般的在代码文件开头固定的导入所需的模块。何谓反射机制,利用
- 导言GridView是由一组字段(Field)组成的,它们都指定的了来自DataSource中的什么属性需要用到自己的输出呈现中。最简单的字
- 前言本文主要给大家介绍了Go语言中函数new与make的使用和区别,关于Go语言中new和make是内建的两个函数,主要用来创建分配类型内存
- 1、HKEY_LOCAL_MACHINE\SYSTEM\ControlSet001\Services\Eventlog\Applicatio
- Python的线程操作在旧版本中使用的是thread模块,在Python27和Python3中引入了threading模块,同时thread
- PHP fprintf() 函数实例把一些文本写入到名为 "test.txt" 的文本文件:<?php $numb
- my.ini文件[mysqld]max_allowed_packet = 10M
- 本文实例为大家分享了Golang实现文件传输的具体代码,供大家参考,具体内容如下借助TCP完成文件的传输,基本思路如下:1、发送方(客户端)
- 介绍今天有个不正经的需求,就是要快速做一个restful api的性能测试,要求测试在海量作业数据的情况下客户端分页获取所有作业的性能。因为
- 如提取第1行,第2列的值:df.iloc[[0],[1]]则会返回一个df,即有字段名和行号。如果用values属性取值:df.iloc[[
- 做一个将本地图片上传到mysql数据库的小实例,顺便也下载下来到桌面检测是否上传成功。在写代码之前得先在数据库中建立image表,用来存储图
- SQL中的单记录函数 1.ASCII 返回与指定的字符对应的十进制数; SQL> select ascii('A')
- 安装tf2onnx以及onnxruntimepip install onnxruntimepip install tf2onnxtf 转为o
- 今天有一位同学给了我一个excel文件,要求读取某些行,某些列,然后我试着做了一个demo,这里分享出来,希望能帮到大家:首先安装xlrd:
- 使用re, urllib, threading多线程抓取天涯帖子内容,设置url为需抓取的天涯帖子的第一页,设置file_name为下载后的
- python2.7安装opencv-python很慢且总是失败当直接使用pip安装opencv-python时,且总是报错,找了好久,发现是
- 为什么要写这篇文章其实是因为最近学到了python的property装饰器的相关知识,刚开始学得云里雾里,于是乎,看了许多相关博客,不巧,大
- 美餐每天发一个用Excel汇总的就餐数据,我们把它导入到数据库后,行政办公服务用它和公司内的就餐数据进行比对查重。初始实现是单线程,和imp
- IE 的弹窗常用的有两种,不外乎是 window.open 与 window.showModalDialog,前者兼容性好,后者
- 需求:1.大量csv文件,以数字命名,如1.csv、2.cvs等;2.逐个打开,对csv文件中的某一列进行格式修改;3.将更改后的内容写入新