pytorch打印网络结构的实例
作者:每天都要深度学习 发布时间:2023-11-04 15:15:51
标签:pytorch,打印,网络结构
最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。
(1)安装环境:graphviz
conda install -n pytorch python-graphviz
或:
sudo apt-get install graphviz
或者从官网下载,按此教程。
(2)生成网络结构的代码:
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
(3)打印网络结构:
import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
class CNN(nn.module):
def __init__(self):
******
def forward(self,x):
******
return out
*****************************
def make_dot(): #复制上面的代码
*****************************
if __name__ == '__main__':
net = CNN()
x = Variable(torch.randn(1, 1, 1024,1024))
y = net(x)
g = make_dot(y)
g.view()
params = list(net.parameters())
k = 0
for i in params:
l = 1
print("该层的结构:" + str(list(i.size())))
for j in i.size():
l *= j
print("该层参数和:" + str(l))
k = k + l
print("总参数数量和:" + str(k))
(4)结果展示(例如这是一个resnet block类型的网络):
来源:https://blog.csdn.net/Lucifer_zzq/article/details/80657513
0
投稿
猜你喜欢
- 目前,很多互联网应用程序都提供了全文搜索功能,用户可以使用一个词或者词语片断作为查询项目来定位匹配的记录。在后台,这些程序使用在一个SELE
- 本文也是开发项目中的一个小经验Tip,虽然很简单,但对很多朋友也有小帮助。我们实际工程中,可能遇到开发环境、预上线环境、线上环境等环境场景,
- 在网页中经常见到两类不同的按钮。一类表示当前所示的状态,一类表示将要进行的动作。(如下图) 那么,同样是icon类的按钮,为什么有
- 本文实例为大家分享了Django文件上传与下载的具体代码,供大家参考,具体内容如下文件上传1.新建django项目,创建应用stu: pyt
- 参数strSQL 要导出的SQL查询语句strFields 字段名称列表,如果为空字符,则使用SQL语句中的字段名用法示例:1:export
- 前言观前提醒:因为是代码控制统计,所以操作每一个步骤都很重要,否则就会报错。操作步骤1.将在线编辑文档导入本地。为了方便代码处理,将导出的e
- IDE(集成开发环境)或换句话说PHP编辑器是开发人员在构建移动或Web应用必不可少的工具。在这篇文章中,我们将讨论有关PHP编辑器并分享5
- 源码下载地址网络构建一、什么是SRGANSRGAN出自论文Photo-Realistic Single Image Super-Resolu
- 目录简单的验证码简单的登录页面我们经常在登录一个网站,或者注册的时候需要输入一个验证码,有时候觉得很烦,因为有些验证码不仅复杂还看不清,许多
- 推荐的国内镜像站[ 个人推荐清华大学pypi镜像站(https://mirrors.tuna.tsinghua.edu.cn/help/py
- XML的未来 现在你已经知道XML。确实,结构有点复杂,而且DTD有各种可以定义文档可以包含的内容的选项。但还不只这些。考虑一个数据交换对其
- 使用Python 2.7 + pywin32 + wxpython开发每隔一段时间检测一下服务是否停止,如果停止尝试启动服务。进行服务停止日
- 前言为了满足用户渠道推广分析和用户账号绑定等场景的需要,公众平台提供了生成带参数二维码的接口。使用该接口可以获得多个带不同场景值的二维码,用
- 耦合两个或以上的体系或两种运动形式间相互作用而彼此影响以至于联合起来的现象。在软件工程中,对象之间的耦合度就是对象之间的依赖性,对象之间的耦
- 如下所示:def list_dict_duplicate_removal(): data_list = [{"a&qu
- 背景在注册或者登陆场景下,经常会遇到需要输入图片验证码的情况,最经典的就是12306买火车票。图片验证码的破解还是有一定难度的,而且如果配合
- 1.关系模型序列化1.1 什么是序列化?什么是反序列化?序列化的意思是把字典的形式转化成Json格式。当我们展示数据的时候需要使用。反序列化
- function MakeUrl($arr){  
- 对于任何一个开发项目来说最大的错误可能就是没有计划。最近,有些人认为开始前无需计划,一个优秀的开发者需要的是随机应变。我敢肯定这样的做法最后
- 实现效果:实现代码import numpy as npfrom skimage import img_as_floatimport matp