Pytorch模型定义与深度学习自查手册
作者:冬于 发布时间:2023-02-11 18:30:27
定义神经网络
继承nn.Module类;
初始化函数__init__:网络层设计;
forward函数:模型运行逻辑。
class NeuralNetwork(nn.Module): |
权重初始化
pytorch中的权值初始化
方法1:net.apply(weights_init)
def weights_init(m): |
方法2:在网络初始化的时候进行参数初始化
使用net.modules()遍历模型中的网络层的类型;
对其中的m层的weigth.data(tensor)部分进行初始化操作。
class Model(nn.Module): |
常用的操作
利用nn.Parameter()设计新的层
import torch |
nn.Flatten
展平输入的张量: 28x28 -> 784
input = torch.randn(32, 1, 5, 5) |
nn.Sequential
一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
net = nn.Sequential( |
常用的层
全连接层nn.Linear()
torch.nn.Linear(in_features, out_features, bias=True,device=None, dtype=None)
in_features: 输入维度 |
m = nn.Linear(20, 30) |
torch.nn.Dropout
''' |
卷积torch.nn.ConvNd()
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
in_channels(int) – 输入信号的通道 |
input: (N,C_in,H_in,W_in) N为批次,C_in即为in_channels,即一批内输入二维数据个数,H_in是二维数据行数,W_in是二维数据的列数
output: (N,C_out,H_out,W_out) N为批次,C_out即为out_channels,即一批内输出二维数据个数,H_out是二维数据行数,W_out是二维数据的列数
conv2 = nn.Conv2d( |
池化
最大池化torch.nn.MaxPoolNd()
torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
kernel_size- 窗口大小 |
max2=torch.nn.MaxPool2d(3,1,0,1) |
均值池化torch.nn.AvgPoolNd()
kernel_size - 池化窗口大小 |
torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) |
反池化
是池化的一个“逆”过程,但“逆”只是通过上采样恢复到原来的尺寸,像素值是不能恢复成原来一模一样,因为像最大池化是不可逆的,除最大值之外的像素都已经丢弃了。
最大值反池化nn.MaxUnpool2d()
功能:对二维图像进行最大值池化上采样
参数:
kernel_size- 窗口大小 |
torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0) |
img_tensor=torch.Tensor(16,5,32,32) |
组合池化
组合池化同时利用最大值池化与均值池化两种的优势而引申的一种池化策略。常见组合策略有两种:Cat与Add。其代码描述如下:
def add_avgmax_pool2d(x, output_size=1): |
正则化层
Transformer相关Normalization方式
Normalization Layers
BatchNorm
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) |
参数:
num_features: 来自期望输入的特征数,该期望输入的大小为batch_size × num_features [× width],和之前输入卷积层的channel位的维度数目相同 |
# With Learnable Parameters |
LayerNorm
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True) |
参数:
normalized_shape: 输入尺寸 |
LayerNorm
就是对(2, 2,4), 后面这一部分进行整个的标准化。可以理解为对整个图像进行标准化。
x_test = np.array([[[1,2,-1,1],[3,4,-2,2]], |
InstanceNorm
torch.nn.InstanceNorm1d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False) |
参数:
num_features: 来自期望输入的特征数,该期望输入的大小为batch_size x num_features [x width] |
InstanceNorm
就是对(2, 2, 4)最后这一部分进行Norm。
x_test = np.array([[[1,2,-1,1],[3,4,-2,2]], |
GroupNorm
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True) |
参数:
num_groups:需要划分为的groups |
当GroupNorm中group的数量是1的时候, 是与上面的LayerNorm是等价的。
x_test = np.array([[[1,2,-1,1],[3,4,-2,2]], |
当GroupNorm中num_groups的数量等于num_channel的数量,与InstanceNorm等价。
# Separate 2 channels into 2 groups (equivalent with InstanceNorm) |
激活函数
参考资料:GELU 激活函数
Pytorch激活函数及优缺点比较
torch.nn.GELU
bert源码给出的GELU代码pytorch版本表示如下:
def gelu(input_tensor): |
torch.nn.ELU(alpha=1.0,inplace=False)
def elu(x,alpha=1.0,inplace=False): |
α是超参数,默认为1.0
torch.nn.LeakyReLU(negative_slope=0.01,inplace=False)
def LeakyReLU(x,negative_slope=0.01,inplace=False):
return max(0,x)+negative_slope∗min(0,x)
其中 negative_slope是超参数,控制x为负数时斜率的角度,默认为1e-2
torch.nn.PReLU(num_parameters=1,init=0.25)
def PReLU(x,num_parameters=1,init=0.25):
return max(0,x)+init∗min(0,x)
其中a 是一个可学习的参数,当不带参数调用时,即nn.PReLU(),在所有的输入通道上使用同一个a,当带参数调用时,即nn.PReLU(nChannels),在每一个通道上学习一个单独的a。
注意:当为了获得好的performance学习一个a时,不要使用weight decay。
num_parameters:要学习的a的个数,默认1
init:a的初始值,默认0.25
torch.nn.ReLU(inplace=False)
CNN中最常用ReLu。
def ReLU(x,inplace=False):
return max(0,x)
torch.nn.ReLU6(inplace=False)
def ReLU6(x,inplace=False):
return min(max(0,x),6)
torch.nn.SELU(inplace=False)
def SELU(x,inplace=False):
alpha=1.6732632423543772848170429916717
scale=1.0507009873554804934193349852946
return scale∗(max(0,*x*)+min(0,alpha∗(exp(x)−1)))
torch.nn.CELU(alpha=1.0,inplace=False)
def CELU(x,alpha=1.0,inplace=False):
return max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
其中α 默认为1.0
torch.nn.Sigmoid
def Sigmoid(x):
return 1/(np.exp(-x)+1)
torch.nn.LogSigmoid
def LogSigmoid(x):
return np.log(1/(np.exp(-x)+1))
torch.nn.Tanh
def Tanh(x):
return (np.exp(x)-np.exp(-x))/(np.exp(x)+np.exp(-x))
torch.nn.Tanhshrink
def Tanhshrink(x):
return x-(np.exp(x)-np.exp(-x))/(np.exp(x)+np.exp(-x))
torch.nn.Softplus(beta=1,threshold=20)
该函数可以看作是ReLu的平滑近似。
def Softplus(x,beta=1,threshold=20):
return np.log(1+np.exp(beta*x))/beta
torch.nn.Softshrink(lambd=0.5)
λ的值默认设置为0.5
def Softshrink(x,lambd=0.5):
if x>lambd:return x-lambd
elif x<-lambd:return x+lambd
else:return 0
nn.Softmax
m = nn.Softmax(dim=1)
input = torch.randn(2, 3)
output = m(input)
参考资料
Pytorch激活函数及优缺点比较
PyTorch快速入门教程二(线性回归以及logistic回归)
Pytorch全连接网络
pytorch系列之nn.Sequential讲解
来源:https://blog.csdn.net/weixin_43243315/article/details/121657881


猜你喜欢
- 我希望大家看到该标题就能让想象到它的功能: 1、WITH TEMPL
- 查看数据库状态:service mysqld status 启动数据库:service mysqld start&
- 本文实例为大家分享了python实现抠图的具体代码,供大家参考,具体内容如下其中使用了opencv中的grabcut方法直接上代码# enc
- 写这篇文章的缘由是我使用 reqeusts 库请求接口的时候, 直接使用请求参数里的 json 字段发送数据, 但是服务器无法识别我发送的数
- 相信每一个 javascript 学习者,都会去了解 JS 的各种基本数据类型,数组就是数据的组合,这是一个很基本也十分简单的概念,他的内容
- 字符编码我们已经讲过了,字符串也是一种数据类型,但是,字符串比较特殊的是还有一个编码问题。因为计算机只能处理数字,如果要处理文本,就必须先把
- 关于Java和Mysql 8.0.18版本的连接方式,供大家参考,具体内容如下1.官网下载mysql-server.(Connector/J
- 环境准备数据库版本:MySQL 5.7.20-log建表 SQLDROP TABLE IF EXISTS `t_ware_sale_stat
- 如下所示:from win32com.client import constantsimport osimport win32com.cli
- 一、前言 需求是获取某个时间范围内每小时数据和上小时数据的差值以及比率
- kaggle是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,在这上面有非常多的好项目、好资源可供机器学习、
- 数据表DROP TABLE IF EXISTS tb_score;CREATE TABLE tb_score( i
- 数组是一种有序的集合,可随时添加、删除其中的元素book = ['xiao zhu pei qi','xiao ji
- 本文详细解说了MySQL Order By Rand()效率优化的方案,并给出了优化的思路过程,是篇不可多得的MySQL Order By
- 这两天看了下某位大神的github,知道他对算法比较感兴趣,看了其中的一个计算数字的步数算法,感觉这个有点意思,所以就自己实现了一个。算法描
- Ctrl+N 按文件名搜索py文件ctrl+n可以搜索py文件勾选上面这个框可以搜索工程以外的文件Ctrl+shift+N 按文件名搜索所有
- 目录outputoutput.pathoutput.publicPathwebpack-dev-server中的publicPathhtml
- assert(断言)用于判断一个表达式,在表达式条件为 false 的时候触发异常。断言可以在条件不满足程序运行的情况下直接返回错误,而不必
- 增加字段alter table docdsp add dspcode char(200)删除字段ALTER TABLE tabl
- 前言目前在做vue的项目,用到了子组件依赖其父组件的数据,进行子组件的相关请求和页面数据展示,父组件渲染需要子组件通知更新父组件的state