使用pytorch提取卷积神经网络的特征图可视化
作者:落樱弥城 发布时间:2023-02-01 20:32:30
前言
文章中的代码是参考基于Pytorch的特征图提取编写的代码本身很简单这里只做简单的描述。
1. 效果图
先看效果图(第一张是原图,后面的都是相应的特征图,这里使用的网络是resnet50,需要注意的是下面图片显示的特征图是经过放大后的图,原图是比较小的图,因为太小不利于我们观察):
2. 完整代码
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = {}
for name, module in self.submodule._modules.items():
if "fc" in name:
x = x.view(x.size(0), -1)
x = module(x)
print(name)
if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
outputs[name] = x
return outputs
def get_picture(pic_name, transform):
img = skimage.io.imread(pic_name)
img = skimage.transform.resize(img, (256, 256))
img = np.asarray(img, dtype=np.float32)
return transform(img)
def make_dirs(path):
if os.path.exists(path) is False:
os.makedirs(path)
def get_feature():
pic_dir = './images/2.jpg'
transform = transforms.ToTensor()
img = get_picture(pic_dir, transform)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 插入维度
img = img.unsqueeze(0)
img = img.to(device)
net = models.resnet101().to(device)
net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
exact_list = None
dst = './feautures'
therd_size = 256
myexactor = FeatureExtractor(net, exact_list)
outs = myexactor(img)
for k, v in outs.items():
features = v[0]
iter_range = features.shape[0]
for i in range(iter_range):
#plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
if 'fc' in k:
continue
feature = features.data.numpy()
feature_img = feature[i,:,:]
feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
dst_path = os.path.join(dst, k)
make_dirs(dst_path)
feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
if feature_img.shape[0] < therd_size:
tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
tmp_img = feature_img.copy()
tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation = cv2.INTER_NEAREST)
cv2.imwrite(tmp_file, tmp_img)
dst_file = os.path.join(dst_path, str(i) + '.png')
cv2.imwrite(dst_file, feature_img)
if __name__ == '__main__':
get_feature()
3. 代码说明
下面的模块是根据所指定的模型筛选出指定层的特征图输出,如果未指定也就是extracted_layers是None则以字典的形式输出全部的特征图,另外因为全连接层本身是一维的没必要输出因此进行了过滤。
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = {}
for name, module in self.submodule._modules.items():
if "fc" in name:
x = x.view(x.size(0), -1)
x = module(x)
print(name)
if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
outputs[name] = x
return outputs
这段主要是存储图片,为每个层创建一个文件夹将特征图以JET的colormap进行按顺序存储到该文件夹,并且如果特征图过小也会对特征图放大同时存储原始图和放大后的图。
for k, v in outs.items():
features = v[0]
iter_range = features.shape[0]
for i in range(iter_range):
#plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
if 'fc' in k:
continue
feature = features.data.numpy()
feature_img = feature[i,:,:]
feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
dst_path = os.path.join(dst, k)
make_dirs(dst_path)
feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
if feature_img.shape[0] < therd_size:
tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
tmp_img = feature_img.copy()
tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation = cv2.INTER_NEAREST)
cv2.imwrite(tmp_file, tmp_img)
dst_file = os.path.join(dst_path, str(i) + '.png')
cv2.imwrite(dst_file, feature_img)
这里主要是一些参数,比如要提取的网络,网络的权重,要提取的层,指定的图像放大的大小,存储路径等等。
net = models.resnet101().to(device)
net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
exact_list = None#['conv1']
dst = './feautures'
therd_size = 256
4. 可视化梯度,feature
上面的办法只是简单的将经过网络计算的图片的输出的feature进行图片,github上有将CNN的梯度等全部进行可视化的代码:pytorch-cnn-visualizations,需要注意的是如果只是简单的替换成自己的网络可能无法运行,大概率会报model没有features或者classifier等错误,这两个是进行分类网络定义时的Sequential,其实就是索引网络的每一层,自己稍微修改用model.children()
等方法进行替换即可,我自己修改之后得到的代码grayondream-pytorch-visualization(本来想稍微封装一下成为一个更加通用的结构,暂时没时间以后再说吧!),下面是效果图:
来源:https://blog.csdn.net/GrayOnDream/article/details/99090247


猜你喜欢
- 一、导入所需的库import randomimport cv2from matplotlib import pyplot as pltimp
- 废话不多说 上语句:查询锁表语句:select object_name,machine,s.sid,s.serial#from v$lock
- 回滚段管理一直是ORACLE数据库管理的一个难题,本文通过实例介绍ORACLE回滚段的概念,用法和规划及问题的解决。 回滚段概述 回滚段用于
- 举例为大家介绍如何运用命令行实现MySQL导出导入数据库一、命令行导出数据库1.进入MySQL目录下的bin文件夹:cd MySQL中到bi
- 一、PK(主键约束)1、什么是主键?在了解主键之前,先了解一下什么是关键字关键字:在表中具有唯一性的字段,比如一个人的身份证号,学号。一个表
- 最近各地中小学都在开展线上教学,有些不自觉的小朋友们用电脑在线学习的时候会趁家长不在的时候偷偷玩游戏、看漫画。本程序screenshot.p
- 对于小型站点,使用七牛云存储的免费配额已足够为站点提供稳定、快速的存储服务七牛云存储已有Python SDK,对它进行简单封装后,就可以直接
- 学完了Python脚本接口自动化之后,一直没有对该框架做总结,今天终于试着来做一份总结了。框架结构如下图:来说一下每个目录的作用:Confi
- 本文实例讲述了Django框架登录加上验证码校验实现验证功能。分享给大家供大家参考,具体如下:验证码生成函数pip install Pill
- 今天来说一下如何判断字典中是否存在某个key,一般有两种通用做法,下面为大家来分别讲解一下:第一种方法:使用自带函数实现。在python的字
- 本文实例讲述了python中readline判断文件读取结束的方法。分享给大家供大家参考。具体分析如下:大家知道,python中按行读取文件
- 对于英文不行我来说使用英文版PyCharm实在是太难受了,网上好多汉化补丁都是网友提供了,下面为大家介绍一种PyCharm官方中文语言包汉化
- 最近用到了mysql5.7的json字段的检索查询,发现挺好用的,记录一下笔记我们有一个日志表,里面的data字段是保存不同对象的json数
- 提起数据库,第一个想到的公司,一般都会是Oracle。该公司成立于1977年,最初是一家专门开发数据库的公司。Oracle在数据库领域一直处
- 一、前言介绍xlrd:可以对xlsx、xls、xlsm文件进行读操作且效率高。xlwt:主要对xls文件进行写操作且效率高,但是不能执行xl
- 本文实例为大家分享了JSP使用commons-fileupload实现文件上传代码,供大家参考,具体内容如下1、准备:将commons-fi
- 已经有很多年不使用SQLServer了,毕竟商业版本是个收费的,安装也不容易。最近因为想带领学生学习做个练习性的项目,参考了.net下的pe
- 测试需求 为了更好的测试你的ASP程序,你首先需要决定你的程序将来需要面对多大的压力。简单的说,压力或负载可以分解成以下数字:· 最低用户数
- UDP 套接字是可以使用 connect 系统调用连接到指定的地址的。从此以后,这个套接字只会接收来自这个地址的数据,而且可以使用 send
- 最近有用到对存储过程(procedure)重命名的功能,在网上找了一下资料都没有讲到在mysql中是如何实现的,当然可以删掉再重建,但是应该