keras K.function获取某层的输出操作
作者:脱贫&&脱单&&不脱发 发布时间:2023-03-11 15:10:21
如下所示:
from keras import backend as K
from keras.models import load_model
models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]
加载训练好并保存的网络模型
加载数据(图像),并将数据处理成array形式
指定输出层
将处理后的数据输入,然后获取输出
其中,K.function有两种不同的写法:
1. 获取名为layer_name的层的输出
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])#指定输出层的名称
2. 获取第n层的输出
layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output])#指定输出层的序号(层号从0开始)
另外,需要注意的是,书写不规范会导致报错:
报错:
TypeError: inputs to a TensorFlow backend function should be a list or tuple
将该句:
f1 = layer_1(image_arr)[0]
修改为:
f1 = layer_1([image_arr])[0]
补充知识:keras.backend.function()
如下所示:
def function(inputs, outputs, updates=None, **kwargs):
"""Instantiates a Keras function.
Arguments:
inputs: List of placeholder tensors.
outputs: List of output tensors.
updates: List of update ops.
**kwargs: Passed to `tf.Session.run`.
Returns:
Output values as Numpy arrays.
Raises:
ValueError: if invalid kwargs are passed in.
"""
if kwargs:
for key in kwargs:
if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
key not in tf_inspect.getargspec(Function.__init__)[0]):
msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
'backend') % key
raise ValueError(msg)
return Function(inputs, outputs, updates=updates, **kwargs)
这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。
我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。
class Function(object):
"""Runs a computation graph.
Arguments:
inputs: Feed placeholders to the computation graph.
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: a name to help users identify what this function does.
"""
def __init__(self, inputs, outputs, updates=None, name=None,
**session_kwargs):
updates = updates or []
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` to a TensorFlow backend function '
'should be a list or tuple.')
if not isinstance(outputs, (list, tuple)):
raise TypeError('`outputs` of a TensorFlow backend function '
'should be a list or tuple.')
if not isinstance(updates, (list, tuple)):
raise TypeError('`updates` in a TensorFlow backend function '
'should be a list or tuple.')
self.inputs = list(inputs)
self.outputs = list(outputs)
with ops.control_dependencies(self.outputs):
updates_ops = []
for update in updates:
if isinstance(update, tuple):
p, new_p = update
updates_ops.append(state_ops.assign(p, new_p))
else:
# assumed already an op
updates_ops.append(update)
self.updates_op = control_flow_ops.group(*updates_ops)
self.name = name
self.session_kwargs = session_kwargs
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
feed_dict = {}
for tensor, value in zip(self.inputs, inputs):
if is_sparse(tensor):
sparse_coo = value.tocoo()
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
np.expand_dims(sparse_coo.col, 1)), 1)
value = (indices, sparse_coo.data, sparse_coo.shape)
feed_dict[tensor] = value
session = get_session()
updated = session.run(
self.outputs + [self.updates_op],
feed_dict=feed_dict,
**self.session_kwargs)
return updated[:len(self.outputs)]
所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。
来源:https://blog.csdn.net/qq_37974048/article/details/102727653


猜你喜欢
- 我就废话不多说了,大家还是直接看代码吧!# 在setting设置外键'OPTIONS': { "in
- 今天要说的是一个高速视频流的采集和传输的问题,我不是研究这一块的,没有使用什么算法,仅仅是兴趣导致我很想搞懂这个问题.  
- 首先定义好样式,利用v-for中的index值,然后绑定样式来实现隔行变色效果。以下为完整代码,很简单,但也是个技巧。<!DOCTYP
- 什么是fixture在一个测试过程中,fixture主要提供以下功能:为测试提供上下文,比如环境变量,数据集(dataset),提供数据,数
- df.groupby() 之后按照特定顺序输出,方便后续作图,或者跟其他df对比作图。## 构造 pd.DataFramepatient_i
- 基础介绍今天我跟大家把我理解的这一块全面的介绍下,配有sql语句送给大家。首先来给大家做个这一块的介绍:1,自连接说到底就是多张表都是同一张
- 我们已经知道,null 没有任何的属性值,并且无法获取其实体(existence)值。所以 null.property 返回的是错误(err
- 在这个abc.php文件中写入如下代码。<?php phpinfo(); ?>你将会看到一个网页,网页内容通常,如下图所示:用中
- 目录一、Python字典1.什么是字典2.字典的创建方式2.1 通过其他字典创建2.2 通过关键字参数创建2.3 通过键值对的序列创建2.4
- <% pagenum=55'指定打印行数 %> <HTML> <HEAD> <
- 我们知道,正则表达式是一个处理字符串中很实用的技巧。然而,即便是Javascript写的很厉害的程序猿,有时也会忘掉正则表达式的语法,从而使
- 问题一: 在anconda里面如何创建新的python环境(也就是更换新的python版本)1.先打开anconda软件,创建需要的环境2.
- 昨天我突发奇想,想用display:inline来实现三列的布局可是搞了半天就是不行。但是理论上是可以的呀(后来才发现是不理解的不深刻,我的
- 我就废话不多说了,大家还是直接看代码吧try: s = socket.socket() s.bind(('127.0.0.1'
- 本文实例讲述了Python数据预处理之数据规范化。分享给大家供大家参考,具体如下:数据规范化为了消除指标之间的量纲和取值范围差异的影响,需要
- 通常说到外键,只会提到“外键的目的是确定资料的参考完整性(referential integrity)。”,但是外键具体包含哪些动作和含义呢
- 这两天在整理一些文章,但是文件夹中每个文章没有序号会看起来很乱,所以想着能不能用Python写一个小脚本。于是乎,参考了多方资料,简单写了下
- asp之家注:本文介绍了使用asp来获取access数据库中的一条随机记录的方法,简单实用,相信对初学者有所帮助,根据这个方法其实我们可以实
- 变量(variable)是Python语言中一个非常重要的概念。变量的主要作用就是为Python程序中的某个值起一个名字。类似于“张三”、“
- 数据备份与还原第二篇,具体如下基础概念:备份,将当前已有的数据或记录另存一份;还原,将数据恢复到备份时的状态。为什么要进行数据的备份与还原?