python神经网络tensorflow利用训练好的模型进行预测
作者:Bubbliiiing 发布时间:2022-09-27 17:33:17
标签:python,神经网络,tensorflow,模型预测,训练好的模型
学习前言
在神经网络学习中slim常用函数与如何训练、保存模型文章里已经讲述了如何使用slim训练出来一个模型,这篇文章将会讲述如何预测。
载入模型思路
载入模型的过程主要分为以下四步:
1、建立会话Session;
2、将img_input的placeholder传入网络,建立网络结构;
3、初始化所有变量;
4、利用saver对象restore载入所有参数。
这里要注意的重点是,在利用saver对象restore载入所有参数之前,必须要建立网络结构,因为网络结构对应着cpkt文件中的参数。
(网络层具有对应的名称scope。)
实现代码
在运行实验代码前,可以直接下载代码,因为存在许多依赖的文件
import tensorflow as tf
import numpy as np
from nets import Net
from tensorflow.examples.tutorials.mnist import input_data
def compute_accuracy(x_data,y_data):
global prediction
y_pre = sess.run(prediction,feed_dict={img_input:x_data})
correct_prediction = tf.equal(tf.arg_max(y_data,1),tf.arg_max(y_pre,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
result = sess.run(accuracy,feed_dict = {img_input:x_data})
return result
mnist = input_data.read_data_sets("MNIST_data",one_hot = "true")
slim = tf.contrib.slim
# img_input的placeholder
img_input = tf.placeholder(tf.float32, shape = (None, 784))
img_reshape = tf.reshape(img_input,shape = (-1,28,28,1))
# 载入模型
sess = tf.Session()
Conv_Net = Net.Conv_Net()
# 将img_input的placeholder传入网络
prediction = Conv_Net.net(img_reshape)
# 载入模型
ckpt_filename = './logs/model.ckpt-20000'
# 初始化所有变量
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# 恢复
saver.restore(sess, ckpt_filename)
print(compute_accuracy(mnist.test.images,mnist.test.labels))
运行结果为:
0.9921
来源:https://blog.csdn.net/weixin_44791964/article/details/102584474


猜你喜欢
- 阅读上一篇:FrontPage XP设计教程4——Css样式表的应用表单在网站的制作过程中是比较常见的,举个简单的例子,我们在申请免费电子信
- 解决问题: 不使用for计算两组、多个矩形两两间的iou使用numpy广播的方法,在python程序中并不建议使用for语句,python中
- 最小生成树的Prim算法也是贪心算法的一大经典应用。Prim算法的特点是时刻维护一棵树,算法不断加边,加的过程始终是一棵树。Prim算法过程
- FSO中除了可以对驱动器、文件夹的操作以外,功能最强大的就是对文件的操作了。它可以用来记数、内容管理、搜索还可生成动态HTML页面等等。一、
- 今天来给大家推荐一个Python当中超级好用的内置函数,那便是lambda方法,本篇教程大致和大家分享什么是lambda函数lambda函数
- open(filename,mode,buffer) 其中第一个参数是要打开的文件的文件名,必选;第二个是打开方式,可选;第三个为缓冲区,可
- 1. ASCII 返回与指定的字符对应的十进制数; SQL> select ascii(A) A,ascii(a) a,as
- 1.sonarqube是一款代码分析的工具,通过soanrScanner扫描后的数据传递给sonarqube进行分析2.sonarqube社
- 1,创建测试表CREATE TABLE `testsign` ( `userid` int(5) DEFAULT NULL, `user
- Scrapy批量运行爬虫文件的两种方法:1、使用CrawProcess实现https://doc.scrapy.org/en/latest/
- 之前写了个python脚本用selenium+phantomjs爬新帖子,在循环拉取页面的过程中,phantomjs总是block住,使用W
- 在Python中,有许多用于发送HTTP请求的库,其中最受欢迎的是requests、aiohttp和httpx。这三个库的性能和功能各不相同
- 问题分析在关闭数据库的命令发现mysql关不了,提示Warning: World-writable config file '/et
- 研究(2)中讨论了栅格系统的基础知识。这一篇将集中探讨栅格系统的粒度问题。(注:如非特别指明,栅格系统均指24列960栅格系统)淘宝的首页(
- 求f(x) = sin(x)/x 的不定积分和负无穷到正无穷的定积分sin(x)/x 的不定积分是信号函数sig ,负无穷到正无穷的定积分为
- 可能是因为编译太简单了,golang 并没有一个官方的构建工具(类似于 java 的 maven 和 gradle之类的),但是除了编译,我
- 在网站建设中,分类算法的应用非常的普遍。在设计一个电子商店时,要涉及到商品分类;在设计发布系统时,要涉及到栏目或者频道分类;在设计软件下载这
- 1.DQL类型的SQL语句基本概述DQL类型的SQL语言全称为Data Query Language,中文名称为数据查询语言,主要是用来查询
- python中index()、find()方法,具体内容如下:index() 方法检测字符串中是否包含子字符串 str ,如果指定 beg(
- 虽然并非你编写的每个 Python 程序都要求一个严格的性能分析,但是让人放心的是,当问题发生的时候,Python 生态圈有各种各样的工具可