Tensorflow实现在训练好的模型上进行测试
作者:非典型废言 发布时间:2022-10-04 07:17:00
标签:Tensorflow,训练,模型,测试
Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单;第二种是训练过程和调用模型过程分别在两个py文件中。本文将讲解第二种方法。
模型的保存
tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解:
#网络结构
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)
x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#损失函数与优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.save(sess,"save/model.ckpt")
train_step.run({x: train_x, y_: train_y})
以上代码就完成了模型的保存,值得注意的是下面这行代码
tf.add_to_collection('network-output', y)
这行代码保存了神经网络的输出,这个在后面使用导入模型过程中起到关键作用。
模型的导入
模型训练并保存后就可以导入来评估模型在测试集上的表现,网上很多文章只用简单的四则运算来做例子,让人看的头大。还是先上代码:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./model.ckpt.meta')
saver.restore(sess, './model.ckpt')# .data文件
pred = tf.get_collection('network-output')[0]
graph = tf.get_default_graph()
x = graph.get_operation_by_name('x').outputs[0]
y_ = graph.get_operation_by_name('y_').outputs[0]
y = sess.run(pred, feed_dict={x: test_x, y_: test_y})
讲解一下关键的代码,首先是pred = tf.get_collection('pred_network')[0],这行代码获得训练过程中网络输出的“接口”,简单理解就是,通过tf.get_collection() 这个方法获取了整个网络结构。获得网络结构后我们就需要喂它对应的数据y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在训练过程中我们的输入是
x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
因此导入模型后所需的输入也要与之对应可使用以下代码获得:
x = graph.get_operation_by_name('x').outputs[0]
y_ = graph.get_operation_by_name('y_').outputs[0]
使用模型的最后一步就是输入测试集,然后按照训练好的网络进行评估
sess.run(pred, feed_dict={x: test_x, y_: test_y})
理解下这行代码,sess.run() 的函数原型为
run(fetches, feed_dict=None, options=None, run_metadata=None)
Tensorflow对 feed_dict 执行fetches操作,因此在导入模型后的运算就是,按照训练的网络计算测试输入的数据。
来源:https://blog.csdn.net/sinat_35821976/article/details/80765145


猜你喜欢
- 本文实例讲述了Python比较文件夹比另一同名文件夹多出的文件并复制出来的方法。分享给大家供大家参考。具体如下:这个东东本来是做来给公司数据
- computedcomputed只接收一个getter函数1、getter必须有返回值2、computed返回一个只读响应式ref对象 (只
- 创建Deque序列:from collections import dequed = deque()Deque提供了类似list的操作方法:
- 举例如下: 数据表为DemoTable,字段有id, condition1,condition2,condition3,condition4
- 【OpenCV】⚠️高手勿入! 半小时学会基本操作⚠️边界填充概述OpenCV 是一个跨平台的计算机视觉库, 支持多语言, 功能强大. 今天
- 这篇文章主要介绍了Python JSON编解码方式原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- 1、纯粹的截取字符串function cutstr(thestr1,strlen) dim l,t,c&nbs
- 将cdb_pms表subject字段中的Welcom to替换成 欢迎光临 UPDATE `cdb_pms` SET `subject` =
- 自定义函数参数传递为 字符串格式 ,传递方式1:用this传递 2:引号缺省 3:转义字符(html中 " 代表"
- websocketWebsocket只是一个网络通信协议就像 http、ftp等都是网络通信的协议;不要多想;相对于HTTP这种非持久的协议
- SQL UNIQUE 约束 UNIQUE 约束唯一标识数据库表中的每条记录。 UNIQUE 和 PRIMARY KEY 约束均为列或列集合提
- 便携文档格式 (PDF) 是由 Adobe 开发的格式,主要用于呈现可打印的文档,其中包含有 pixel-perfect 格式,嵌入字体以及
- @Test public void test33() {
- 之前写过一篇 MySQL通过自定义函数的方式,递归查询树结构,从MySQL 8.0 开始终于支持了递归查询的语法CTE首先了解一下什么是 C
- 写了一段时间java切回写python偶尔会出现一些小麻烦,比如:在java中自定义对象变成json串很简单,调用一个方法就行,但同样的转换
- 网页过渡是指当浏览者进入或离开网页时,页面呈现的不同的刷新效果,比如卷动、百叶窗等。这样你的网页看起来
- 作为一个.net后台开发的程序猿,博客里既然大多都是前端相关的博文。是不是该考虑换方向了,转前端开发得了 ...小小吐槽一下,近期受该不该跳
- 今天在工作中遇到了一个问题,需要按时间查询,可是查询出来的结果显示的不正确。举个例子来说,要查找出2007-10-12至2007-10-31
- 与其他的大型数据库例如 Oracle、DB2、SQL Server等相比,MySQL 自有它的不足之处,但是这丝毫也没有减少它受欢迎的程度。
- 虽然我们一直使用书籍搜索的示例表单,并将起改进的很完美,但是这还是相当的简陋: 只包含一个字段,q。这简单的例子,我们不需要使用Django