pytorch __init__、forward与__call__的用法小结
作者:时光碎了天 发布时间:2023-09-04 13:20:47
1.介绍
当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?
1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的
2)forward是表示一个前向传播,构建网络层的先后运算步骤
3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?
当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。
在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
2.代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels):
super(Net, self).__init__()
self.conv0 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
torch.nn.LeakyReLU())
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
def forward(self, x):
x = self.conv0(x)
x = self.conv1(x)
return x
class Net(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels):
super(Net, self).__init__()
self.conv0 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
torch.nn.LeakyReLU())
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
def __call__(self, x):
x = self.conv0(x)
x = self.conv1(x)
return x
补充:torch/nn目录结构以及__init__.py
torch/nn目录结构以及init.py
torch/nn目录结构
__init__.py:
from .modules import *
#nn.modules 导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py 定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init 导入init.py 参数初始化
from . import utils
#nn.utils 导入utils目录下内容 官网api下nn.utils下api
对于backends, functional.py, _functions 需要在代码前重新Import
例如我们常用的
import torch.nn.functional as F 就是导入了functional.py
backends和_functions是functional.py实现各种函数时所用到的。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/u013289254/article/details/103826591
猜你喜欢
- 如下所示:#-*- coding: utf-8 -*-#code:myhaspl@qq.com#12-1.pyimport sysreloa
- 使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭
- 初入计算机视觉遇到的一些坑1.pytorch中转tensorx=np.random.randint(10,100,(10,10,10))x=
- Python下一切皆对象,每个对象都有多个属性(attribute),Python对属性有一套统一的管理方案。__dict__与dir()的
- 最近在研究tensorflow自带的例程speech_command,顺便学习tensorflow的一些基本用法。其中tensorboard
- 一、简介Python是一门功能强大的高级脚本语言,它的强大不仅表现在其自身的功能上,而且还表现在其良好的可扩展性上,正因如此,Python已
- 前阵子刚完成一个B/S架构的学校办公系统,体会就是表太多,文件太多,而每个文件中类似的操作(代码)也太多了,例如学生信息和教师信息操作,st
- 本文实例讲述了Python操作串口的方法。分享给大家供大家参考。具体如下:首先需确保安装了serial模块,如果没安装的话就安装一下pyth
- PHP implode() 函数实例把数组元素组合为一个字符串:<?php $arr = array('Hello',
- 描述的意思是HTML为中心的前端开发也差不多是web标准的意思。1.HTML是基础2.CSS依靠选择符提供视觉;3.Javascript 依
- 实例一:题目:有四个数字:1、2、3、4,能组成多少个互不相同且无重复数字的三位数?各是多少?程序分析:可填在百位、十位、个位的数字都是1、
- 1.使用Docker安装Elasticsearch及其扩展获取镜像,可以通过网络pullsudo docker image pull del
- 【原文地址】 Fixes for Common VS 2008 and .NET 3.5 Beta2 Issu
- Celery (芹菜)是基于Python开发的分布式任务队列。它支持使用任务队列的方式在分布的机器/进程/线程上执行任务调度。架
- 前言其实Python 的列表(list)内部实现是一个数组,也就是一个线性表。在列表中查找元素可以使用 list.index() 方法,其时
- 目录Mock概念Mock类简单的例子体验下 Mock 的功能特点一个相对正式的 Mock 例子一个完整的测试例子断言方法Mock概念mock
- 使用torchvision来进行图片的数据增广数据增强就是增强一个已有数据集,使得有更多的多样性。对于图片数据来说,就是改变图片的颜色和形状
- 通过优化CSS代码,减小对系统资源的占用。自己整理出几个能减少系统资源占用的CSS写法,要优化网站的页面加载速度,这些注意点不能忽视!一、尽
- 如下所示:# -*-coding:utf-8-*-from pandas import DataFrameimport pandas as
- 一个出错的例子#coding:utf-8s = u'中文'f = open("test.txt",&qu