浅析PyTorch中nn.Module的使用
作者:Steven·简谈 发布时间:2021-10-29 14:04:53
标签:PyTorch,nn.Module
torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作
torch.nn.Module 是所有神经网络单元的基类
查看源码
初始化部分:
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
属性解释:
_parameters:字典,保存用户直接设置的 Parameter
_modules:子 module,即子类构造函数中的内容
_buffers:缓存
_backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
training:判断值来决定前向传播策略
方法定义:
def forward(self, *input):
raise NotImplementedError
没有实际内容,用于被子类的 forward() 方法覆盖
且 forward 方法在 __call__ 方法中被调用:
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
...
...
实例展示
简单搭建:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = nn.Linear(n_feature, n_hidden)
self.out = nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.out(x)
return x
Net 类继承了 torch 的 Module 和 __init__ 功能
hidden 是隐藏层线性输出
out 是输出层线性输出
打印出网络的结构:
>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
(hidden): Linear(in_features=10, out_features=30, bias=True)
(out): Linear(in_features=30, out_features=15, bias=True)
)
来源:https://blog.csdn.net/weixin_44613063/article/details/90297299


猜你喜欢
- Gtid + Mha +Binlog server配置:1:测试环境OS:CentOS 6.5Mysql:5.6.28Mha:0.56192
- 代码如下import matplotlib.pyplot as pltimport numpy as npdef test4(): &nbs
- 准备工作创建表use [test1]gocreate table [dbo].[student]( [id] [int] ide
- 前言和Word、Excel承载数据的能力相比,PPT的应用重点在于表演。比如一场发布会、一场演说、一次产品展示、一次客户沟通&hel
- 本文实例分析了JS重载实现方法。分享给大家供大家参考,具体如下:重载是面向对象语言里很重要的一个特性,JS中没有真正的重载,是模拟出来的(因
- 回文利用python 自带的翻转 函数 reversed()def is_plalindrome(string): return
- CSS Sprites技术不新鲜,早在2005年 CSS Zengarden 的园主 Dave Shea 就在 ALA
- 环境:【wind2003[open Tftp server] + virtualbox:ubuntn10 server】tftp
- JS实现轮播图实现结果图:需求:1 根据图片动态添加小圆点 2 目标移动到小圆点轮播图片 3 鼠标离开图片,定时轮播图片;鼠标在图片上时暂停
- 引言:2020年12月20python宣布适配苹果m1芯片,这意味着python3.9.0可以不经过rosetta转化,以原生的方式运行在最
- 简介testing是 Go 语言标准库自带的测试库。在 Go 语言中编写测试很简单,只需要遵循 Go 测试的几个约定,与编写正常的 Go 代
- 我的环境,Windows10,Python3.6.3查询了很多有关资料,发现都是Python2版本操作Word文件的,所以就写了这篇短小的文
- 本文实例讲述了Python3.5运算符操作。分享给大家供大家参考,具体如下:1、运算符的分类2、算术运算符示例代码:#!/usr/bin/e
- 概述在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如A,B,C,
- CONVERT函数用于将值转换为指定的数据类型或字符集1.转换指定字符集CONVERT函数用于将字符串expr的字符集变成transcodi
- 1,使用mysqldump时报错(1064),这个是因为mysqldump版本太低与当前数据库版本不一致导致的。mysqldump: Cou
- 作者:敖士伟 Email:ikmb@163.com 转载注明作者 说明: 1、js根据表单元素class属性,把表单元素的name和valu
- 1 集合集合可以使用大括号({})或者set()函数进行创建,但是创建一个空集合必须使用set()函数,而不能用{},大括号是用来创建一个空
- 好不容易有个周末,不能闲着,趁着这个时间安装sql server2016正式版,下载那个安装包都用了一个星期安装包可以从这里下载:http:
- 具体代码如下所示:import smtplib, email, os, timefrom email.mime.multipart impo