画pytorch模型图,以及参数计算的方法
作者:月落乌啼silence 发布时间:2023-09-25 09:12:58
标签:pytorch,模型图,参数
刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。
首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。
但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。
话不多说,上代码吧。
import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
)
self.out = nn.Linear(32*7*7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # (batch, 32*7*7)
out = self.out(x)
return out
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
if __name__ == '__main__':
net = CNN()
x = Variable(torch.randn(1, 1, 28, 28))
y = net(x)
g = make_dot(y)
g.view()
params = list(net.parameters())
k = 0
for i in params:
l = 1
print("该层的结构:" + str(list(i.size())))
for j in i.size():
l *= j
print("该层参数和:" + str(l))
k = k + l
print("总参数数量和:" + str(k))
模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc
大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。
net = CNN()
x = Variable(torch.randn(1, 1, 28, 28))
y = net(x)
g = make_dot(y)
g.view()
最后输出该网络的各种参数。
该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938
来源:https://blog.csdn.net/qq_18293213/article/details/79047742


猜你喜欢
- 目录前言前期准备数据的选择与获取分词筛选与可视化总结前言”数据可视化“这个话题,相信大家并不陌生,在一些平台,经常可以看到一些动态条形图的视
- $forceUpdate()的使用在Vue官方文档中指出,$forceUpdate具有强制刷新的作用。那在vue框架中,如果data中有一个
- 没有使用队列,也没有线程池还在学习只是多线程 #coding:utf8 import urllib2,sys,re import threa
- 这篇文章主要介绍了python已协程方式处理任务实现过程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 本文为大家分享了macOS下mysql 8.0.16 安装配置教程,供大家参考,具体内容如下1、安装程序下载打开MySQL 官网选择 DOW
- juypter notebook中直接使用log_device_placement=True打印不出来device信息# Creates a
- 利用layui制作与众不同的感谢表单,表格layui极大的提高了前端开发效率,它极具个性的样式等等都非常吸引人,接下来我将为大家展示如何利用
- 本文实例讲述了Vue常用传值方式、父传子、子传父及非父子。分享给大家供大家参考,具体如下:父组件向子组件传值是利用props子组件中的注意事
- 面对不断成长的用户,跟随用户的脚步齐步向前,做引起共鸣的改变,去除低龄化的设计,用成熟稳重的心态面对用户。QQBanner自2006 年推出
- Lambda函数,即Lambda 表达式(lambda expression),是一个匿名函数(不存在函数名的函数),Lambda表达式基于
- 最近 W3C 一口气推出 7 个 HTML 工作草案,涵盖了 HTML5,HTML RDF,HTML Microdata,HTM
- 介绍shutil 名字来源于 shell utilities,有学习或了解过Linux的人应该都对 shell 不陌生,可以借此来记忆模块的
- Fuko Masked 是 Kaloyan Tsvetkov 的一个小型PHP库,用于通过用编辑后的元素替换列入黑名单的元素来屏蔽敏感数据。
- 如下所示:<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional
- 正则表达式,又称正规表示法、常规表示法(英语:Regular Expression,在代码中常简写为regex、regexp或RE),计算机
- 前言最近在维护项目的python项目代码,项目使用了 python 的日志模块 logging, 设定了保存的日志数目, 不过没有生效,还要
- 本篇文章将在项目中引入 typescript,以及手动搭建一个用于测试组件库组件 Vue3 项目因为我们是使用 Vite+Ts 开发的是 V
- 1、fit和fit_generator的区别首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然
- 本文实例为大家分享了Python/C++实现字符串逆序的具体代码,供大家参考,具体内容如下题目描述:将字符串逆序输出Python实现一:借助
- Python实现Mysql数据统计的实例代码如下所示:import pymysqlimport xlwtexcel=xlwt.Workboo