利用Pytorch实现获取特征图的方法详解
作者:拜阳 发布时间:2023-09-11 16:16:02
简单加载官方预训练模型
torchvision.models预定义了很多公开的模型结构
如果pretrained参数设置为False,那么仅仅设定模型结构;如果设置为True,那么会启动一个下载流程,下载预训练参数
如果只想调用模型,不想训练,那么设置model.eval()和model.requires_grad_(False)
想查看模型参数可以使用modules和named_modules,其中named_modules是一个长度为2的tuple,第一个变量是name,第二个变量是module本身。
# -*- coding: utf-8 -*-
from torch import nn
from torchvision import models
# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model.eval()
model.requires_grad_(False)
# get model component
features = model.features
modules = features.modules()
named_modules = features.named_modules()
# print modules
for module in modules:
if isinstance(module, nn.Conv2d):
weight = module.weight
bias = module.bias
print(module, weight.shape, bias.shape,
weight.requires_grad, bias.requires_grad)
elif isinstance(module, nn.ReLU):
print(module)
print()
for named_module in named_modules:
name = named_module[0]
module = named_module[1]
if isinstance(module, nn.Conv2d):
weight = module.weight
bias = module.bias
print(name, module, weight.shape, bias.shape,
weight.requires_grad, bias.requires_grad)
elif isinstance(module, nn.ReLU):
print(name, module)
图片预处理
使用opencv和pil读图都可以使用transforms.ToTensor()把原本[H, W, 3]的数据转成[3, H, W]的tensor。但opencv要注意把数据改成RGB顺序。
vgg系列模型需要做normalization,建议配合torchvision.transforms来实现。
mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
参考:https://pytorch.org/hub/pytorch_vision_vgg/
# -*- coding: utf-8 -*-
from PIL import Image
import cv2
import torch
from torchvision import transforms
# transforms for preprocess
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# load image using cv2
image_cv2 = cv2.imread('lena_std.bmp')
image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = preprocess(image_cv2)
# load image using pil
image_pil = Image.open('lena_std.bmp')
image_pil = preprocess(image_pil)
# check whether image_cv2 and image_pil are same
print(torch.all(image_cv2 == image_pil))
print(image_cv2.shape, image_pil.shape)
提取单个特征图
如果只提取单层特征图,可以把模型截断,以节省算力和显存消耗。
下面索引之所以有+1是因为pytorch预训练模型里面第一个索引的module总是完整模块结构,第二个才开始子模块。
# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms
# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1] # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')
print(model)
# load and preprocess image
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0) # add batch dimension
inputs = inputs.cuda()
# forward
output = model(inputs)
print(output.shape)
提取多个特征图
第一种方式:逐层运行model,如果碰到了需要保存的feature map就存下来。
第二种方式:使用register_forward_hook,使用这种方式需要用一个类把feature map以成员变量的形式缓存下来。
两种方式的运行效率差不多
第一种方式简单直观,但是只能处理类似VGG这种没有跨层连接的网络;第二种方式更加通用。
# -*- coding: utf-8 -*-
from PIL import Image
import torch
from torchvision import models
from torchvision import transforms
# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1] # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')
# check module name
for named_module in model.named_modules():
name = named_module[0]
module = named_module[1]
print('-------- %s --------' % name)
print(module)
print()
# load and preprocess image
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0) # add batch dimension
inputs = inputs.cuda()
# forward - 1
layers = [2, 7, 8, 9, 16]
layers = sorted(set(layers))
feature_maps = {}
feature = inputs
for i in range(max(layers) + 1):
feature = model[i](feature)
if i in layers:
feature_maps[i] = feature
for key in feature_maps:
print(key, feature_maps.get(key).shape)
# forward - 2
class FeatureHook:
def __init__(self, module):
self.inputs = None
self.output = None
self.hook = module.register_forward_hook(self.get_features)
def get_features(self, module, inputs, output):
self.inputs = inputs
self.output = output
layer_names = ['2', '7', '8', '9', '16']
hook_modules = []
for named_module in model.named_modules():
name = named_module[0]
module = named_module[1]
if name in layer_names:
hook_modules.append(module)
hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
print(feature.shape)
# check correctness
for i, layer in enumerate(layers):
feature1 = feature_maps.get(layer)
feature2 = features[i]
print(torch.all(feature1 == feature2))
使用第二种方式(register_forward_hook),resnet特征图也可以顺利拿到。
而由于resnet的model已经不可以用model[i]的形式索引,所以无法使用第一种方式。
# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms
# load model. If pretrained is True, there will be a downloading process
model = models.resnet18(pretrained=True)
model.eval()
model.requires_grad_(False)
model.to('cuda')
# check module name
for named_module in model.named_modules():
name = named_module[0]
module = named_module[1]
print('-------- %s --------' % name)
print(module)
print()
# load and preprocess image
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0) # add batch dimension
inputs = inputs.cuda()
class FeatureHook:
def __init__(self, module):
self.inputs = None
self.output = None
self.hook = module.register_forward_hook(self.get_features)
def get_features(self, module, inputs, output):
self.inputs = inputs
self.output = output
layer_names = [
'conv1',
'layer1.0.relu',
'layer2.0.conv1'
]
hook_modules = []
for named_module in model.named_modules():
name = named_module[0]
module = named_module[1]
if name in layer_names:
hook_modules.append(module)
hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
print(feature.shape)
问题来了,resnet这种类型的网络结构怎么截断?
使用如下命令就可以,print查看需要截断到哪里,然后用nn.Sequential重组即可。
需注意重组后网络的module_name会发生变化。
print(list(model.children())
model = torch.nn.Sequential(*list(model.children())[:6])
来源:https://blog.csdn.net/bby1987/article/details/126636160
猜你喜欢
- 片头Python看了差不多三四天吧,基本上给基础看差不多了。写个管理系统吧,后续不出意外SQL、文件存储版本都会更。学习Python感想:
- 一、标签语法由%}和 {% 来定义的,例如:{%tag%} {%endtag%},完整的标签有开始就有结束,如条件语句,有条件判断的开始,也
- 存在问题:jupyter代码无法在pycharm中运行原因:工作文件和安装文件不统一引起的解决方案:pycharm中新建工程项目时,要将图中
- From Python正则表达式re.match(pattern, string, flags=0)尝试从字符串起始位置匹配一个模式;如果不
- 一、整体合并团队协作中,开发人员A、B、C分别在dev上进行功能开发,并push代码到远端dev上。当测试人员需要对功能进行测试的时候,我们
- datetime64与unix时间戳互转在用pandas处理数据时,经常要处理一些时间类型数据,经常把pandas时间类型与datetime
- 一、概述:用来描述或者匹配一系列符合某个语句规则的字符串二、单个符号1、英文句点.符号:匹配单个任意字符。表达式t.o 可以匹配:tno,t
- <style type="text/css"> <!-- body,td,th {
- 本文实例讲述了PHP函数extension_loaded()用法。分享给大家供大家参考。具体分析如下:extension_loaded —
- 一.基本概念事务是指满足ACID特性的的一组操作,可以通过Commit提交事务,也可以也可以通过Rollback进行回滚。会存在中间态和一致
- 概括、从python1.6开始就可以处理unicode字符了。 一、几种常见的编码格式。 1.1、ascii,用1个字节表示。 1.2、UT
- 本文实例讲述了Python面向对象程序设计之私有变量,私有方法原理与用法。分享给大家供大家参考,具体如下:私有变量,私有方法:python的
- 一 MySQL的内部组件结构大体来说,MySQL 可以分为 Server 层和存储引擎层两部分。1.1 service层主要包括连接器、查询
- 在某些特殊情况下,我们的 Python 脚本需要调用父目录下的其他模块。例如:在编写 GNE 的测试用例时,有一个脚本 generate_n
- 本文实例讲述了Python实现批量读取图片并存入mongodb数据库的方法。分享给大家供大家参考,具体如下:我的图片放在E:\image\中
- 一、什么是jieba库jieba是优秀的中文分词第三方库,由于中文文本之间每个汉字都是连续书写的,我们需要通过特定的手段来获得其中的每个词组
- 一、正则表达式的特殊字符介绍正则表达式^ 匹配行首 &nb
- 本文实例为大家分享了JS作用域链的相关内容,供大家参考,具体内容如下1、所有全局变量和函数都是作为window对象的属性和方法创建的。2、在
- 一、 排序的基本使用在查询数据时,如果没有使用排序操作,默认情况下SQL会按元组添加的顺序来排列查询结果。在SQL中,使用关键字 ORDER
- 直接使用word文档已经难不倒大家了,有没有想过用python构建一个word文档写点文章呢?当然这个文章的框架需要我们用代码一点点的建立,