使用K.function()调试keras操作
作者:庞加莱 发布时间:2022-08-03 05:07:31
Keras的底层库使用Theano或TensorFlow,这两个库也称为Keras的后端。无论是Theano还是TensorFlow,都需要提前定义好网络的结构,也就是常说的“计算图”。
在运行前需要对计算图编译,然后才能输出结果。那这里面主要有两个问题,第一是这个图结构在运行中不能任意更改,比如说计算图中有一个隐含层,神经元的数量是100,你想动态的修改这个隐含层神经元的数量那是不可以的;第二是调试困难,keras没有内置的调试工具,所以计算图的中间结果是很难看到的,一旦最终输出跟预想不一致,很难找到问题所在。
这里谈一谈本人调试keras的一些经验:
分阶段构建你的神经网络
不要一口气把整个网络全部写完,这样很难保证中间结果的正确性。加如一个CNN文本分类模型是这样的(如下代码),应该在加了Embedding层后,停止,打印一下中间结果,看看跟embedding向量能不能对上,输出的shape对不对。对上了再进行下一步操作。
有的人觉得这样很浪费时间,但是除非你能一遍写对,否则你将花上5倍的时间发现错误。
# model parameters:
embedding_dims = 50
cnn_filters = 100
cnn_kernel_size = 5
dense_hidden_dims = 200
model = Sequential()
model.add(Embedding(nb_words,embedding_dims,input_length=maxlen))
model.add(Dropout(0.5))
model.add(Conv1D(cnn_filters, cnn_kernel_size,padding='valid', activation='relu'))
model.add(GlobalMaxPooling1D())
model.add(Dense(dense_hidden_dims))
model.add(Dropout(0.5))
model.add(Activation('relu'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
使用K.function()函数打印中间结果
function函数可以接收传入数据,并返回一个numpy数组。使用这个函数我们可以方便地看到中间结果,尤其对于变长输入的Input。
下面是官方关于function的文档。
function
keras.backend.function(inputs, outputs, updates=None)
实例化 Keras 函数。
参数
inputs: 占位符张量列表。
outputs: 输出张量列表。
updates: 更新操作列表。
**kwargs: 需要传递给 tf.Session.run 的参数。
返回
输出值为 Numpy 数组。
异常
ValueError: 如果无效的 kwargs 被传入。
example
下面这个例子是打印一个LSTM层的中间结果,值得注意的是这个LSTM的sequence是变长的,可以看到输出的结果sequence长度分别是64和128
import keras.backend as K
from keras.layers import LSTM, Input
import numpy as np
I = Input(shape=(None, 200))
lstm = LSTM(20, return_sequences=True)
f = K.function(inputs=[I], outputs=[lstm(I)])
data1 = np.random.random(size=(2, 64, 200))
print(f([data1])[0].shape)
data2 = np.random.random(size=(2, 128, 200))
print(f([data2])[0].shape)
K.clear_session()
# (2, 64, 20)
# (2, 128, 20)
其他的调试技巧
有频繁张量变换操作的,如dot, mat, reshape等等,记得加一行形状变化的注释,如(100, 128)--> (100, 64)
可以使用tensorboard查看网络的参数情况
确保你的数据没有问题,很多时候输出不对不是神经网络有问题,而是数据有问题
来源:https://blog.csdn.net/u010960155/article/details/92176492


猜你喜欢
- 秉承MVC架构的思想,CI中的所有控制器都需要经过单点入口文件index.php(默认)来加载调用。也就是说,在默认情况下,所有CI开发项目
- 开发环境说明:Python 35Pytorch 0.2CPU/GPU均可1、LSTM简介人类在进行学习时,往往不总是零开始,学习物理你会有数
- 元组是不可变的Python对象序列。元组的序列就像列表。唯一的区别是,元组不能被改变,即元组是不可被修改。元组使用小括号,而列表
- 错误日志安装时出现了如下错误:SQL Server 2005 安装错误码29503。产品: Microsoft SQL Server 200
- .游标方式 1 DECLARE @Data NVARCHAR(max) SET @Data='1,tanw;2,keen
- 只有pd模型文件, 打印所有节点from tensorflow.python.framework import tensor_utilfro
- 故障状况:php网站连接mysql失败,但在命令行下通过mysql命令可登录并正常操作。解决方案:1、命令行下登录mysql,执行以下命令:
- 找到自己的mysql数据库的安装位置,如下 C:\Program Files\MySQL\MySQL Server 5.1,在它里面有个的m
- 进程什么是进程进程指的是一个程序的运行过程,或者说一个正在执行的程序所以说进程一种虚拟的概念,该虚拟概念起源操作系统一个CPU 同一时刻只能
- 静态方法:将下面的代码复制到<body>~</body>内 程序代码 <table cellpadd
- 关于mysql效率优化一般通过以下两种方式定位执行效率较低的sql语句。通过慢查询日志定位那些执行效率较低的 SQL 语句,用 --log-
- 如下所示:# 导入模块import win32guiwin = win32gui.FindWindow(None, u'张三'
- 需求在4*4的图片中,比较外围黑色像素点和内圈黑色像素点个数的大小将图片分类如上图图片外围黑色像素点5个大于内圈黑色像素点1个分为0类反之1
- 本文转自公众号:"算法与编程之美"1、问题描述Python中数据类型有列表,元组,字典,队列,栈,树等等。像列表,元组这
- Update 语句Update 语句用于修改表中的数据。语法:UPDATE 表名称 SET 列名称 = 新值 WHERE 列名称 = 某值P
- sorted函数sorted(iterable,key,reverse)iterable 待排序的可迭代对象key 对应的是个函数, 该函数
- 通过Python操作注册表有两种方式,第一种是通过Python的内置模块 _winreg;另一种方式就是Win32 Extension Fo
- 网页设计中的脏、乱、差,是我们在设计过程中常会遇到的问题。通常"脏"是由对色彩使用不当所产生的,而色彩使用不当产生的不好
- jQuery的选择器可谓异常强大,没有什么DOM里的任何数据能逃出它的掌心,这点是我非常喜欢的,以前获取NODE要用getElementBy
- 求N的阶乘本题要求编写程序,计算N的阶乘。输入格式:输入在一行中给出一个正整数 N。输出格式:在一行中按照“produc