pytorch 模型可视化的例子
作者:我~ 发布时间:2023-06-13 08:24:34
标签:pytorch,模型,可视化
如下所示:
一. visualize.py
from graphviz import Digraph
import torch
from torch.autograd import Variable
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
二. 使用步骤
import torch
from torch.autograd import Variable
from models import *
from visualize import make_dot
x = Variable(torch.rand(1, 3, 256, 256))
model = GeneratorUNet()
y = model(x)
g = make_dot(y)
g.view()
三. 效果展示
来源:https://blog.csdn.net/weixin_42445501/article/details/81221362
0
投稿
猜你喜欢
- 以图像处理见长的微软Live实验室,最近发布了一款新作:Pivot。装完启动后的第一印象就是一款浏览器,和IE、FF、Chrome又不太一样
- socket文件:当用Unix域套接字方式进行连接时需要的文件。pid文件:MySQL实例的进程ID文件。1.pid-file介绍MySQL
- OUTPUT是SQL SERVER2005的新特性,可以从数据修改语句中返回输出,可以看作是"返回结果的DML"。INS
- 一、需求 + 最终实现注:只是前端实现1. 需求需求来源是因为有一个做嵌入式 C/C++的 * 做了一个远程计算器。 需求是要求支持输入一个四
- Django 教程Python下有许多款不同的 Web 框架。Django是重量级选手中最有代表性的一位。许多成功的网站和APP都基于Dja
- 插件下载:blueideasearch.xpi首先第一步 说一下怎么样查看firefox插件的源码, 就我上边写的那个东西,把它下载下来.将
- 前言老早就看到新闻员工通过人脸识别监控老板来摸鱼。有时候摸鱼太入迷了,经常在上班时间玩其他的东西被老板看到。自从在咸鱼上淘了一个树莓派3b,
- plotly 的 Python 软件包是一个开源的代码库,它基于 plot.js,而后者基于 d3.js。我们实际使用的则是一个对 plot
- 构造查询条件worm是一款方便易用的Go语言ORM库。worm支Model方式(持结构体字段映射)、原生SQL以及SQLBuilder三种模
- 本文实例讲述了JS小游戏的仙剑翻牌源码,是一款非常优秀的游戏源码。分享给大家供大家参考。具体如下:一、游戏介绍:这是一个翻牌配对游戏,共十关
- Pytorch:Conv2d卷积前后尺寸Conv2d参数尺寸变化卷积前的尺寸为(N,C,W,H) ,卷积后尺寸为(N,F,W_n,H_n)W
- 这个问题我在给新云CMS升级时遇到了,按照升级步骤做完,后台登录时,出现“HTTP 错误 500.100 - 内部服务器错误 - ASP 错
- 在网站建设中,分类算法的应用非常的普遍。在设计一个电子商店时,要涉及到商品分类;在设计发布系统时,要涉及到栏目或者频道分类;在设计软件下载这
- 在python3.5下安装好matplotlib后,准备显示一张图片测试一下,但是控制台报错说需要安装python3-tk,我天
- 本文实例讲述了Python3实现并发检验代理池地址的方法。分享给大家供大家参考,具体如下:#encoding=utf-8#author: w
- 程序运行,产生如下结果,然后进程终止,导致这一结果的原因很有可能是内存 * 。当两个较大的 (e.g., 10000*10000 维)ndar
- 读取问题如下所示,我们在文本中写了一个问题,然后将其读取出来。“黄河远上白云间,一片孤城万仞山。”的作者是谁?王之涣李白白居易杜甫file
- 准备软件:1. J2SDK(1.5.0): jdk-1_5_0-linux-i586-rpm.bin2. Apache(2.0.53): h
- 1、前言 MySQL 是完全网络化的跨平台关系型数据库系统,同时是具有客户机/服务器体系结构的分布式数据库管理系统。它具有功能强、使用简便、
- 在python3中使用dict.keys()返回的不在是list类型了,也不支持索引,我们可以看一下下面这张图片那么我们应该怎么办呢,其实解