pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
作者:青盏 发布时间:2021-05-20 07:01:01
标签:pytorch,层权重,hook,中间层
如下所示:
#获取模型权重
for k, v in model_2.state_dict().iteritems():
print("Layer {}".format(k))
print(v)
#获取模型权重
for layer in model_2.modules():
if isinstance(layer, nn.Linear):
print(layer.weight)
#将一个模型权重载入另一个模型
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
load = torch.load('/home/huangqk/.torch/models/vgg19-dcbb9e9d.pth')
load_state = {k: v for k, v in load.items() if k not in ['classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']}
model_state = model.state_dict()
model_state.update(load_state)
model.load_state_dict(model_state)
return model
# 对特定层注入hook
def hook_layers(model):
def hook_function(module, inputs, outputs):
recreate_image(inputs[0])
print(model.features._modules)
first_layer = list(model.features._modules.items())[0][1]
first_layer.register_forward_hook(hook_function)
#获取层
x = someinput
for l in vgg.features.modules():
x = l(x)
modulelist = list(vgg.features.modules())
for l in modulelist[:5]:
x = l(x)
keep = x
for l in modulelist[5:]:
x = l(x)
# 提取vgg模型的中间层输出
# coding:utf8
import torch
import torch.nn as nn
from torchvision.models import vgg16
from collections import namedtuple
class Vgg16(torch.nn.Module):
def __init__(self):
super(Vgg16, self).__init__()
features = list(vgg16(pretrained=True).features)[:23]
# features的第3,8,15,22层分别是: relu1_2,relu2_2,relu3_3,relu4_3
self.features = nn.ModuleList(features).eval()
def forward(self, x):
results = []
for ii, model in enumerate(self.features):
x = model(x)
if ii in {3, 8, 15, 22}:
results.append(x)
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
return vgg_outputs(*results)
来源:https://blog.csdn.net/qq_16234613/article/details/80217851


猜你喜欢
- 一、事务:事务是逻辑上的一组操作,要么都成功,要么都失败!——————————————————————————————————1、SQL执行
- plt.imshow(image)无法显示图片的解决使用plt.imshow()发现不能显示图片,加了plt.show()也还是不能显示先引
- 最近在写博客,刚好写到用户注册注销模块,觉得这一方面还是挺有趣的。当尝试掀开 Django 的源代码时一切 API 就不会变得那么摸不着。顺
- 在python中,通过导入random库,就能使用randint 和 randrange这两个方法来产生随机整数。那这两个方法的区别在于什么
- Golang调度机制最近抽空研究、整理了一下Golang调度机制,学习了其他大牛的文章。把自己的理解写下来。如有错误,请指正!!!golan
- 用于匹配的正则表达式为 :([1-9]\d*\.?\d*)|(0\.\d*[1-9])([1-9] :匹配1~9的数字;\d :匹配数字,包
- 游戏规则用pygame动画实现神庙逃亡类似的小游戏,当玩家移动的时候躲避 * ,如果 * 命中玩家或者名字龙都会减速,玩家躲避 * 使更多的 * 打
- 代码如下:Function splitx(strs1 As String, strs2 A
- 什么是序列化与反序列化这里引入微软对序列化的解释:序列化是指将对象转换成字节流,从而存储对象或将对象传输到内存、数据库或文件的过程。 它的主
- 报错代码使用cmd查看电脑显卡的信息,调用nvidia-smi查看显卡使用情况报错如下:'nvidia-smi' 不是内部或
- 代码如下def PI(n): pi=0 for k in range(n): pi +=
- 实例如下所示:import pandas as pdimport reimport mathdframe1 = pd.read_excel(
- 简介:使用python的过程中肯定少不了读取文件的操作,传统的形式是使用 直接打开、然后在操作、然后再关闭,这样代码量稍微大些不说,一旦在操
- 在模板中往往要加载静态文件,如CSS, JavaScript,图片等。那么这些文件在django中如何才能正确加载呢?首先要在setting
- 在js中调用asp文件的方法很简单,我们可以用在静态页面的点击数统计,虽然静态页面不支持asp程序,但是我们可以使用js调用,来变相的达到这
- 引言python编程时,一部分人习惯将实现同一个功能的代码放在同一个文件;使用这些代码只需要import就可以了;下面看一个例子。testM
- python是免费的么?python是免费的,也就是开源的。编程软件的盈利方式就是你使用它, 用的人越多越值钱。注:Python 是一个高层
- python 绘制拟合曲线并加指定点标识import osimport numpy as npfrom scipy import logfr
- xml(可扩展标记语言)看起来可能像某种w3c标准——现在没有什么实际影响,即使以后能派上用场,也是很久以后的事。但实际上,它现在已经得到了
- 通过 register_shutdown_function 方法,可以让我们设置一个当执行关闭时可以被调用的另一个函数。也就是说,当我们的脚