pytorch 可视化feature map的示例代码
作者:牛丸4 发布时间:2021-10-21 13:35:49
标签:pytorch,可视化,feature,map
之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的。
不多说了,直接看代码:
import torch
from torch.autograd import Variable
import torch.nn as nn
import pickle
from sys import path
path.append('/residual model path')
import residual_model
from residual_model import Residual_Model
model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))
class myNet(nn.Module):
def __init__(self,pretrained_model,layers):
super(myNet,self).__init__()
self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
self.net2 = nn.Sequential(*list(pretrained_model.children())[:layers[1]])
self.net3 = nn.Sequential(*list(pretrained_model.children())[:layers[2]])
def forward(self,x):
out1 = self.net1(x)
out2 = self.net(out1)
out3 = self.net(out2)
return out1,out2,out3
def get_features(pretrained_model, x, layers = [3, 4, 9]): ## get_features 其实很简单
'''
1.首先import model
2.将weights load 进model
3.熟悉model的每一层的位置,提前知道要输出feature map的网络层是处于网络的那一层
4.直接将test_x输入网络,*list(model.chidren())是用来提取网络的每一层的结构的。net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) ,就是第三层前的所有层。
'''
net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
# print net1
out1 = net1(x)
net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]])
# print net2
out2 = net2(out1)
#net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]])
#out3 = net3(out2)
return out1, out2
with open('test.pickle','rb') as f:
data = pickle.load(f)
x = data['test_mains'][0]
x = Variable(torch.from_numpy(x)).view(1,1,128,1) ## test_x必须为Varibable
#x = Variable(torch.randn(1,1,128,1))
if torch.cuda.is_available():
x = x.cuda() # 如果模型的训练是用cuda加速的话,输入的变量也必须是cuda加速的,两个必须是对应的,网络的参数weight都是用cuda加速的,不然会报错
model = model.cuda()
output1,output2 = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output1.shape:',output1.shape)
print('output2.shape:',output2.shape)
#print('output3.shape:',output3.shape)
output_1 = torch.squeeze(output2,dim = 0)
output_1_arr = output_1.data.cpu().numpy() # 得到的cuda加速的输出不能直接转变成numpy格式的,当时根据报错的信息首先将变量转换为cpu的,然后转换为numpy的格式
output_1_arr = output_1_arr.reshape([output_1_arr.shape[0],output_1_arr.shape[1]])
来源:https://blog.csdn.net/baidu_36161077/article/details/81388221


猜你喜欢
- 1.腾讯企业邮箱SMTP服务器地址:smtp.exmail.qq.com,ssl端口为:4652.确保腾讯企业邮箱中开启了SMTP服务:3.
- 初步认识k-means翻译过来就是K均值聚类算法,其目的是将样本分割为k个簇,而这个k则是KMeans中最重要的参数:n_clusters,
- 我自己测试一下,很多字符变成了 ‘?'。数据库连接已经是使用了 utf8 字符集:define("MYSQL_ENCODE
- 此前带领小组成员主导过一个百万行代码上位机项目的重构工作,分析项目中存在的问题做了些针对性的优化,整个重构工作持续了一年半之久。主要针对以下
- 基本功能:能够实现学生成绩相关信息的输入、输出、查找、删除、修改等功能;(使用数据库对数据进行存取)输入并存储学生的信息:通过输入学生的学号
- 一、获取安装包最近的版本为0.4.12,下载地址:http://sourceforge.net/projects/sysbench/二、编译
- 1.创建mysql存储过程,这是个复杂查询加上了判断,比较复杂CREATE PROCEDURE searchAllList (IN trad
- 本文实例讲述了Python实现栈和队列的简单操作方法。分享给大家供大家参考,具体如下:先简单的了解一下数据结构里面的栈和堆:栈和队列是两种基
- 1、创建数组 var array = new Array(); var array = new Array(size);//指定数组的长度
- 例1:#!/usr/bin/perluse strict; use warnings;my $test = "asdf"
- 多表连接的基本语法多表连接,就是将几张表拼接为一张表,然后进行查询select 字段1, 字段2, ...from 表1 {inner|li
- 现在网上出现了很多在线换底色的网页版工具是这么做的呢?其实用Python就可以实现。环境要求Python3 numpy函数库 opencv库
- facebook的信息架构设计,是目前为止互联网上我见过的最合理的信息架构。每次培训,我基本都需要拿20分钟左右的时间来解析它,包括老的、新
- 先看看单条 SQL 语句的分页 SQL 吧。 方法1: 适用于 SQL Server 2000/2005 代码如下:SELECT TOP 页
- 本文实例讲述了JS实现运动缓冲效果的封装函数。分享给大家供大家参考,具体如下:之前经常写运动函数,要写好多好多,后来想办法封装起来。(运动缓
- 本文介绍基于Python语言中gdal模块,对遥感影像数据进行栅格读取与计算,同时基于QA波段对像元加以筛选、掩膜的操作。本文所要实现的需求
- 简单计数器代码如下所示:<% Set fs = CreateObject("Scri
- 本文使用的是最新的FCKeditor 2.3.1版本 官方网站下载: http://ckeditor.com/download[建议直接在官
- 是的,这仅仅是一个PPT文档,由Anna Debenham上传至slideshare。幻灯片的标题叫做《CSS nuggets》,嗯,很好的
- 什么是Flask?Flask是一个用Python编写的Web应用程序框架,Flask是python的web框架,最大的特征是轻便,让开发者自