keras小技巧——获取某一个网络层的输出方式
作者:LoveMIss-Y 发布时间:2023-08-20 12:56:47
前言:
keras默认提供了如何获取某一个层的某一个节点的输出,但是没有提供如何获取某一个层的输出的接口,所以有时候我们需要获取某一个层的输出,则需要自己编写代码,但是鉴于keras高层封装的特性,编写起来实际上很简单,本文提供两种常见的方法来实现,基于上一篇文章的模型和代码: keras自定义回调函数查看训练的loss和accuracy
一、模型加载以及各个层的信息查看
从前面的定义可知,参见上一篇文章,一共定义了8个网络层,定义如下:
model.add(Convolution2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=(img_rows, img_cols, 1), activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(n_classes, activation='softmax'))
这里每一个层都没有起名字,实际上最好给每一个层取一个名字,所以这里就使用索引来访问层,如下:
for index in range(8):
layer=model.get_layer(index=index)
# layer=model.layers[index] # 这样获取每一个层也是一样的
print(model)
'''运行结果如下:
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
'''
当然由于 model.laters是一个列表,所以可以一次性打印出所有的层信息,即
print(model.layers) # 打印出所有的层
二、模型的加载
准备测试数据
# 训练参数
learning_rate = 0.001
epochs = 10
batch_size = 128
n_classes = 10
# 定义图像维度reshape
img_rows, img_cols = 28, 28
# 加载keras中的mnist数据集 分为60,000个训练集,10,000个测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 将图片转化为(samples,width,height,channels)的格式
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
# 将X_train, X_test的数据格式转为float32
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# 将X_train, X_test归一化0-1
x_train /= 255
x_test /= 255
# 输出0-9转换为ont-hot形式
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)
模型的加载
model=keras.models.load_model('./models/lenet5_weight.h5')
注意事项:
keras的每一个层有一个input和output属性,但是它是只针对单节点的层而言的哦,否则就不需要我们再自己编写输出函数了,
如果一个层具有单个节点 (i.e. 如果它不是共享层), 你可以得到它的输入张量、输出张量、输入尺寸和输出尺寸:
layer.input
layer.output
layer.input_shape
layer.output_shape
如果层有多个节点 (参见: 层节点和共享层的概念), 您可以使用以下函数:
layer.get_input_at(node_index)
layer.get_output_at(node_index)
layer.get_input_shape_at(node_index)
layer.get_output_shape_at(node_index)
三、获取某一个层的输出的方法定义
3.1 第一种实现方法
def get_output_function(model,output_layer_index):
'''
model: 要保存的模型
output_layer_index:要获取的那一个层的索引
'''
vector_funcrion=K.function([model.layers[0].input],[model.layers[output_layer_index].output])
def inner(input_data):
vector=vector_funcrion([input_data])[0]
return vector
return inner
# 现在仅仅测试一张图片
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
get_feature=get_output_function(model,6) # 该函数的返回值依然是一个函数哦,获取第6层输出
feature=get_feature(x) # 相当于调用 定义在里面的inner函数
print(feature)
'''运行结果为
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
-0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
-0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
-0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
-0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
-0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''
但是上面的实现方法似乎不是很简单,还有更加简单的方法,思想来源与keras中,可以将整个模型model也当成是层layer来处理,实现如下面。
3.2 第二种实现方法
import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
model=keras.models.load_model('./models/lenet5_weight.h5')
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
# 将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
feature=layer_model.predict(x)
print(feature)
'''运行结果为:
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
-0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
-0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
-0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
-0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
-0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''
可见和上面的结果是一样的,
总结:
由于keras的层与模型之间实际上的转化关系,所以提供了非常灵活的输出方法,推荐使用第二种方法获得某一个层的输出。总结为以下几个主要的步骤(四步走):
import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
# 第一步:准备输入数据
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
# 第二步:加载已经训练的模型
model=keras.models.load_model('./models/lenet5_weight.h5')
# 第三步:将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
# 第四步:调用新建的“曾模型”的predict方法,得到模型的输出
feature=layer_model.predict(x)
print(feature)
来源:https://blog.csdn.net/qq_27825451/article/details/93378950


猜你喜欢
- 网页编程中,在与数据库打交道的时候我们经常会碰到乱码的经常。本文就将介绍一种ASP读取MySQL数据库出现乱码的解决办法。情景再现:使用My
- 用pycharm开发时,在导入自己写的python文件时出现模块名爆红的情况,而且后面每次调用文件里的函数都没有没有提示,必须自己手动输入,
- 加上设置字符编码的方法:response.setHeader("charset","gb2312")
- 本文实例讲述了python实现根据主机名字获得所有ip地址的方法。分享给大家供大家参考。具体实现方法如下:# -*- coding: utf
- 1.如何构建应用框架一般来说构建应用框架包含3个部分:命令行参数解析配置文件解析应用的命令行框架:需要具备 Help 功能、需要能够解析命令
- 关于文件加载及处理1、检查python关于文件加载及处理方式文件路径是否存在,如果不存在就创建此路径。#如果不存在路径,就创建一个这样的路径
- 1 概述1.1 贪心算法贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并不从整体最优考虑,它所作出的选择只是在某种意义上的局部最优选
- 把函数作为参数的用法比较直观:def func(a, b): return a + bdef test(f, a, b): print f
- 最近开发小程序,需要做一个导航,导航可以通过template写出来,但是这个项目需要在导航中处理一些逻辑,做成组件更方便些。首先新建head
- 跟着节奏继续来探索fixtures的灵活性。一、一个测试函数/fixture一次请求多个fixture在测试函数和fixture函数中,每一
- 本文实例讲述了python实现统计代码行数的方法。分享给大家供大家参考。具体实现方法如下:'''Author: li
- 本文实例讲述了Flask框架实现的前端RSA加密与后端Python解密功能。分享给大家供大家参考,具体如下:前言在使用 Flask 开发用户
- 不多说,我们直接上源码:# -*- coding:UTF-8 -*-'''实现文件打包、上传与校验Created o
- 第一次写ASP类,实现功能:分段统计程序执行时间,输出统计表等.程序代码:Class ccClsProcessTimeRecord
- 简介本文主要介绍如何通过pyplot来绘制函数图。主要绘制函数如下: - 一元一次函数 - 一元二次函数 - 指数函数 - 自然对数函数 -
- 下载MySQL-8.0.23点击下载:mysql-8.0.23-linux-glibc2.12-x86_64.tar.xz解压MySQL的安
- 阅读目录什么是前端代码异常 window.onerror写一个js报错的上报库注意点:缺点:在平时的工作,js报错是比较常见的一个
- 本文是通过深度学习框架keras来做SQL注入特征识别, 不过虽然用了keras,但是大部分还是普通的神经网络,只是外加了一些规则化、dro
- PyQt5切换按钮控件QPushButton简介QAbstractButton类为抽象类,不能实例化,必须由其他的按钮类继承QAbstrac
- 简洁的隐藏垂直菜单在hover时将内容展开。这样的效果在JS里有很多个版本,但这个可以说是绝无仅有的CSS版本。此菜单可以在IE5.5,IE