使用pytorch提取卷积神经网络的特征图可视化
作者:落樱弥城 发布时间:2023-02-01 20:32:30
前言
文章中的代码是参考基于Pytorch的特征图提取编写的代码本身很简单这里只做简单的描述。
1. 效果图
先看效果图(第一张是原图,后面的都是相应的特征图,这里使用的网络是resnet50,需要注意的是下面图片显示的特征图是经过放大后的图,原图是比较小的图,因为太小不利于我们观察):
2. 完整代码
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = {}
for name, module in self.submodule._modules.items():
if "fc" in name:
x = x.view(x.size(0), -1)
x = module(x)
print(name)
if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
outputs[name] = x
return outputs
def get_picture(pic_name, transform):
img = skimage.io.imread(pic_name)
img = skimage.transform.resize(img, (256, 256))
img = np.asarray(img, dtype=np.float32)
return transform(img)
def make_dirs(path):
if os.path.exists(path) is False:
os.makedirs(path)
def get_feature():
pic_dir = './images/2.jpg'
transform = transforms.ToTensor()
img = get_picture(pic_dir, transform)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 插入维度
img = img.unsqueeze(0)
img = img.to(device)
net = models.resnet101().to(device)
net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
exact_list = None
dst = './feautures'
therd_size = 256
myexactor = FeatureExtractor(net, exact_list)
outs = myexactor(img)
for k, v in outs.items():
features = v[0]
iter_range = features.shape[0]
for i in range(iter_range):
#plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
if 'fc' in k:
continue
feature = features.data.numpy()
feature_img = feature[i,:,:]
feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
dst_path = os.path.join(dst, k)
make_dirs(dst_path)
feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
if feature_img.shape[0] < therd_size:
tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
tmp_img = feature_img.copy()
tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation = cv2.INTER_NEAREST)
cv2.imwrite(tmp_file, tmp_img)
dst_file = os.path.join(dst_path, str(i) + '.png')
cv2.imwrite(dst_file, feature_img)
if __name__ == '__main__':
get_feature()
3. 代码说明
下面的模块是根据所指定的模型筛选出指定层的特征图输出,如果未指定也就是extracted_layers是None则以字典的形式输出全部的特征图,另外因为全连接层本身是一维的没必要输出因此进行了过滤。
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = {}
for name, module in self.submodule._modules.items():
if "fc" in name:
x = x.view(x.size(0), -1)
x = module(x)
print(name)
if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
outputs[name] = x
return outputs
这段主要是存储图片,为每个层创建一个文件夹将特征图以JET的colormap进行按顺序存储到该文件夹,并且如果特征图过小也会对特征图放大同时存储原始图和放大后的图。
for k, v in outs.items():
features = v[0]
iter_range = features.shape[0]
for i in range(iter_range):
#plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
if 'fc' in k:
continue
feature = features.data.numpy()
feature_img = feature[i,:,:]
feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
dst_path = os.path.join(dst, k)
make_dirs(dst_path)
feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
if feature_img.shape[0] < therd_size:
tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
tmp_img = feature_img.copy()
tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation = cv2.INTER_NEAREST)
cv2.imwrite(tmp_file, tmp_img)
dst_file = os.path.join(dst_path, str(i) + '.png')
cv2.imwrite(dst_file, feature_img)
这里主要是一些参数,比如要提取的网络,网络的权重,要提取的层,指定的图像放大的大小,存储路径等等。
net = models.resnet101().to(device)
net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
exact_list = None#['conv1']
dst = './feautures'
therd_size = 256
4. 可视化梯度,feature
上面的办法只是简单的将经过网络计算的图片的输出的feature进行图片,github上有将CNN的梯度等全部进行可视化的代码:pytorch-cnn-visualizations,需要注意的是如果只是简单的替换成自己的网络可能无法运行,大概率会报model没有features或者classifier等错误,这两个是进行分类网络定义时的Sequential,其实就是索引网络的每一层,自己稍微修改用model.children()
等方法进行替换即可,我自己修改之后得到的代码grayondream-pytorch-visualization(本来想稍微封装一下成为一个更加通用的结构,暂时没时间以后再说吧!),下面是效果图:
来源:https://blog.csdn.net/GrayOnDream/article/details/99090247
猜你喜欢
- 一般开发,SQL Server的数据库所有者为dbo.但是为了安全,有时候可能把它换成其它的名称,所有者变换不是很方便.这里列出两种供参考
- python实现四舍五入""" 四舍五入 :param
- 说到Javascript的类继承,就必然离不开原型链,但只通过原型链实现的继承有着不少缺陷。无参数类继承的问题先看一段示例代码,实现B继承于
- 这个格式是我自创的,经常有人问我为什么,这里做个简单总结:1、分类,一个模块或者同类功能定义为一类定义,每类定义之间用段落隔开。2、分级,每
- 比如有下面一段代码: for i in range(10): print ("%s" % (f_list[i].name
- Python 中方法的缺省参数问题在Python中可以缺省给方法制定缺省值,但是这个缺省值在某些情况下确是和我们预期不太一致的&he
- # 基础版,不依赖环境import timeimport base64import hashlibclass Token_hander():
- 简介如果你经常网上冲浪,这样参差不齐的多栏布局,是不是很眼熟啊?类似的布局,似乎一夜之间出现在国内外大大小小的网站上,比如 Pinteres
- 一、MySQL修改密码方法总结首先要说明一点的是:一般情况下,修改MySQL密码是需要有mysql里的root权限的,这样一般用户是无法更改
- 为什么会讲 MRO?在讲多继承的时候,有讲到, 当继承的多个父类拥有同名属性、方法,子类对象调用该属性、方法时会调用哪个父类的属性、方法呢?
- 有时候在一个页面用到收放功能的时候时,总有一个虚线框在触发收放的功能按钮上,显得特别刺眼,那如何去除这个虚线框呢?在IE下,需要在标签 a
- 全局变量与局部变量# num1是全局变量num1 = 1# num2是局部变量def func():num2 = 2在函数外(且不在函数里)
- Douglas Crockford是JavaScript开发社区最知名的权威,是JSON、JSLint、JSMin和ADSafe之父,是《J
- Python 做为一个脚本语言,可以很方便地写各种工具。当你在服务端要运行一个工具或服务时,输入参数似乎是一种硬需(当然你也可以通过配置文件
- 图片人脸检测#coding=utf-8import cv2import dlibpath = "img/meinv.png&quo
- ConfigParser模块在Python3修改为configparser,这个模块定义了一个ConfigeParser类,该类的作用是让配
- MYSQL在一个字段值前面加字符窜,如下:member 表名card 字段名update member SET card = '00
- 为了提高Asp程序的性能,人们常常将经常使用的数据缓存在 Application,但是你修改了数据库后怎么让application更新呢,本
- 最近在学习python爬虫,使用requests的时候遇到了不少的问题,比如说在requests中如何使用cookies进行登录验证,这可以
- 目录前言前期准备数据的选择与获取分词筛选与可视化总结前言”数据可视化“这个话题,相信大家并不陌生,在一些平台,经常可以看到一些动态条形图的视