tensorflow使用神经网络实现mnist分类
作者:Missayaa 发布时间:2023-07-05 10:19:13
标签:tensorflow,神经网络,mnist分类
本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下
只有两层的神经网络,直接上代码
#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)
#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10
#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])
stddev = 0.1
#定义权重
weights = {
'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))
}
#定义偏置
biases = {
'b1':tf.Variable(tf.random_normal([n_hidden_1])),
'b2':tf.Variable(tf.random_normal([n_hidden_2])),
'out':tf.Variable(tf.random_normal([n_classes])),
}
print("Network is Ready")
#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
return (tf.matmul(layer2,_weights['out'])+_biases['out'])
#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")
#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4
#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
#优化
for epoch in range(training_epochs):
avg_cost=0.
total_batch = int(mnist.train.num_examples/batch_size)
for i in range(total_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
feeds = {x:batch_xs,y:batch_ys}
sess.run(optm,feed_dict = feeds)
avg_cost += sess.run(cost,feed_dict=feeds)
avg_cost = avg_cost/total_batch
if (epoch+1) % display_step ==0:
print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
feeds = {x:batch_xs,y:batch_ys}
train_acc = sess.run(accr,feed_dict = feeds)
print("Train accuracy:%.3f"%(train_acc))
feeds = {x:mnist.test.images,y:mnist.test.labels}
test_acc = sess.run(accr,feed_dict = feeds)
print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")
程序部分运行结果如下:
Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished
来源:https://blog.csdn.net/Missayaaa/article/details/80065319


猜你喜欢
- 一、方框滤波 方框滤波是均值滤波的一种形式。在均值滤波中,滤波结果的像素值是任意一个点的邻域平均值,等于各邻域像素值之和的均值,而在方框
- 一、前言数据库的数据量达到一定程度之后,为避免带来系统性能上的瓶颈。需要进行数据的处理,采用的手段是分区、分片、分库、分表。二、分片(类似分
- 集中式工作流(不常用)集中式工作流像SVN一样,以中央仓库作为项目所有修改的单点实体。所有修改都提交到 Master分支上。这种方式与 SV
- 前言如果你从事大数据工作,用Python的Pandas库时会发现很多惊喜。Pandas在数据科学和分析领域扮演越来越重要的角色,尤其是对于从
- 做手机整机测试的,肯定有开关机的需求,关机,几分钟后再开机(一直循环操作测试,就是不能重启);这个需求在关机后就没有办法开机了,任何脚本命令
- 本文借用HTML的css语法,将样式表应用到窗口部件。这里只是个简单的例子,实际上样式表的语法很丰富。以下类似于css: StyleShee
- 如何用ASP发送带附件的邮件?请问如何用CDONTS组件发送带附件的邮件? 见下列代码:<%&nb
- 一. 什么是模块(module)?在实际应用中,有时程序所要实现功能比较复杂,代码量也很大。若把所有的代码都存储在一个文件中,则不利于代码的
- 手动安装Anaconda环境变量安装 Anaconda后,在命令行执行python命令或conda命令会报错无法找到此时就需要我们手动添加环
- 首先介绍一下什么是桑葚图?桑基图(Sankey diagram),即桑基能量分流图,也叫桑基能量平衡图。它是一种特定类型的流程图,图中延伸的
- 当用户的页面需要动态加载iframe 时, 如果iframe的src中包传中文参数会出现编码错误;必须加编码,然后再解码。 编码:encod
- 摘要在机器视觉中,对于图像的处理有时候因为放置的原因导致ROI区域倾斜,这个时候我们会想办法把它纠正为正确的角度视角来,方便下一步的布局分析
- 我是以Python开门的,我还是觉得Python也可以进行地形三维可视化,当然这里需要借助第三方库,so,我就来介绍:Python一个很重要
- 为什么要 mock?后台接口还没完成,但前端要用到接口我想篡改后台接口的结果,测试一些分支逻辑起步本篇文章会使用到 swr、axios、vi
- 文章摘要本文简单说明了FLV文件的格式,以此为出发点,使用 Python 实现FLV视频的拼接。一.FLV文件格式关于FLV文件格式的解析网
- GoroutineGoroutine 是 Golang 提供的一种轻量级线程,我们通常称之为「协程」,相比较线程,创建一个协程的成本是很低的
- 楔子上一篇文章我们探讨了 GIL 的原理,以及如何释放 GIL 实现并行,做法是将函数声明为 nogil,然后使用 with nogil 上
- 现在实际的情况是这样的:客户有一台打卡机,员工打卡的信息全部储存在打卡机的Access数据库里面,现在客户引入了一种新的管理系统,需要将Ac
- 箱线图箱线图一般用来展现数据的分布,如上下四分位值、中位数等,也可以直观地展示异常点。Matplotlib提供了boxplot()函数绘制箱
- math模块# 数学相关模块import mathr = math.floor(3.2) # 向下取整print(r)r = math.ce