Pytorch模型定义与深度学习自查手册
作者:冬于 发布时间:2023-02-11 18:30:27
定义神经网络
继承nn.Module类;
初始化函数__init__:网络层设计;
forward函数:模型运行逻辑。
class NeuralNetwork(nn.Module): |
权重初始化
pytorch中的权值初始化
方法1:net.apply(weights_init)
def weights_init(m): |
方法2:在网络初始化的时候进行参数初始化
使用net.modules()遍历模型中的网络层的类型;
对其中的m层的weigth.data(tensor)部分进行初始化操作。
class Model(nn.Module): |
常用的操作
利用nn.Parameter()设计新的层
import torch |
nn.Flatten
展平输入的张量: 28x28 -> 784
input = torch.randn(32, 1, 5, 5) |
nn.Sequential
一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
net = nn.Sequential( |
常用的层
全连接层nn.Linear()
torch.nn.Linear(in_features, out_features, bias=True,device=None, dtype=None)
in_features: 输入维度 |
m = nn.Linear(20, 30) |
torch.nn.Dropout
''' |
卷积torch.nn.ConvNd()
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
in_channels(int) – 输入信号的通道 |
input: (N,C_in,H_in,W_in) N为批次,C_in即为in_channels,即一批内输入二维数据个数,H_in是二维数据行数,W_in是二维数据的列数
output: (N,C_out,H_out,W_out) N为批次,C_out即为out_channels,即一批内输出二维数据个数,H_out是二维数据行数,W_out是二维数据的列数
conv2 = nn.Conv2d( |
池化
最大池化torch.nn.MaxPoolNd()
torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
kernel_size- 窗口大小 |
max2=torch.nn.MaxPool2d(3,1,0,1) |
均值池化torch.nn.AvgPoolNd()
kernel_size - 池化窗口大小 |
torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) |
反池化
是池化的一个“逆”过程,但“逆”只是通过上采样恢复到原来的尺寸,像素值是不能恢复成原来一模一样,因为像最大池化是不可逆的,除最大值之外的像素都已经丢弃了。
最大值反池化nn.MaxUnpool2d()
功能:对二维图像进行最大值池化上采样
参数:
kernel_size- 窗口大小 |
torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0) |
img_tensor=torch.Tensor(16,5,32,32) |
组合池化
组合池化同时利用最大值池化与均值池化两种的优势而引申的一种池化策略。常见组合策略有两种:Cat与Add。其代码描述如下:
def add_avgmax_pool2d(x, output_size=1): |
正则化层
Transformer相关Normalization方式
Normalization Layers
BatchNorm
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) |
参数:
num_features: 来自期望输入的特征数,该期望输入的大小为batch_size × num_features [× width],和之前输入卷积层的channel位的维度数目相同 |
# With Learnable Parameters |
LayerNorm
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True) |
参数:
normalized_shape: 输入尺寸 |
LayerNorm
就是对(2, 2,4), 后面这一部分进行整个的标准化。可以理解为对整个图像进行标准化。
x_test = np.array([[[1,2,-1,1],[3,4,-2,2]], |
InstanceNorm
torch.nn.InstanceNorm1d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False) |
参数:
num_features: 来自期望输入的特征数,该期望输入的大小为batch_size x num_features [x width] |
InstanceNorm
就是对(2, 2, 4)最后这一部分进行Norm。
x_test = np.array([[[1,2,-1,1],[3,4,-2,2]], |
GroupNorm
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True) |
参数:
num_groups:需要划分为的groups |
当GroupNorm中group的数量是1的时候, 是与上面的LayerNorm是等价的。
x_test = np.array([[[1,2,-1,1],[3,4,-2,2]], |
当GroupNorm中num_groups的数量等于num_channel的数量,与InstanceNorm等价。
# Separate 2 channels into 2 groups (equivalent with InstanceNorm) |
激活函数
参考资料:GELU 激活函数
Pytorch激活函数及优缺点比较
torch.nn.GELU
bert源码给出的GELU代码pytorch版本表示如下:
def gelu(input_tensor): |
torch.nn.ELU(alpha=1.0,inplace=False)
def elu(x,alpha=1.0,inplace=False): |
α是超参数,默认为1.0
torch.nn.LeakyReLU(negative_slope=0.01,inplace=False)
def LeakyReLU(x,negative_slope=0.01,inplace=False):
return max(0,x)+negative_slope∗min(0,x)
其中 negative_slope是超参数,控制x为负数时斜率的角度,默认为1e-2
torch.nn.PReLU(num_parameters=1,init=0.25)
def PReLU(x,num_parameters=1,init=0.25):
return max(0,x)+init∗min(0,x)
其中a 是一个可学习的参数,当不带参数调用时,即nn.PReLU(),在所有的输入通道上使用同一个a,当带参数调用时,即nn.PReLU(nChannels),在每一个通道上学习一个单独的a。
注意:当为了获得好的performance学习一个a时,不要使用weight decay。
num_parameters:要学习的a的个数,默认1
init:a的初始值,默认0.25
torch.nn.ReLU(inplace=False)
CNN中最常用ReLu。
def ReLU(x,inplace=False):
return max(0,x)
torch.nn.ReLU6(inplace=False)
def ReLU6(x,inplace=False):
return min(max(0,x),6)
torch.nn.SELU(inplace=False)
def SELU(x,inplace=False):
alpha=1.6732632423543772848170429916717
scale=1.0507009873554804934193349852946
return scale∗(max(0,*x*)+min(0,alpha∗(exp(x)−1)))
torch.nn.CELU(alpha=1.0,inplace=False)
def CELU(x,alpha=1.0,inplace=False):
return max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
其中α 默认为1.0
torch.nn.Sigmoid
def Sigmoid(x):
return 1/(np.exp(-x)+1)
torch.nn.LogSigmoid
def LogSigmoid(x):
return np.log(1/(np.exp(-x)+1))
torch.nn.Tanh
def Tanh(x):
return (np.exp(x)-np.exp(-x))/(np.exp(x)+np.exp(-x))
torch.nn.Tanhshrink
def Tanhshrink(x):
return x-(np.exp(x)-np.exp(-x))/(np.exp(x)+np.exp(-x))
torch.nn.Softplus(beta=1,threshold=20)
该函数可以看作是ReLu的平滑近似。
def Softplus(x,beta=1,threshold=20):
return np.log(1+np.exp(beta*x))/beta
torch.nn.Softshrink(lambd=0.5)
λ的值默认设置为0.5
def Softshrink(x,lambd=0.5):
if x>lambd:return x-lambd
elif x<-lambd:return x+lambd
else:return 0
nn.Softmax
m = nn.Softmax(dim=1)
input = torch.randn(2, 3)
output = m(input)
参考资料
Pytorch激活函数及优缺点比较
PyTorch快速入门教程二(线性回归以及logistic回归)
Pytorch全连接网络
pytorch系列之nn.Sequential讲解
来源:https://blog.csdn.net/weixin_43243315/article/details/121657881
猜你喜欢
- 在用HTML(HyperText Markup Language,超文本链接标示语言)语言编写Web页面时,由于所用的Web浏览器对HTML
- 一、CAN报文简介CAN是控制器局域网络(Controller Area Network, CAN)的简称,是由以研发和生产汽车电子产品著称
- 本文实现利用python的socketserver这个强大的模块实现套接字的并发,具体内容如下目录结构如下:测试文件请放在server_fi
- 前言最近用Django写项目的时候用到了数据的传递,一窍不通,查了点资料。记录一下。水平不高,瓜不保熟。 从两方面来说:从后端传递
- 这篇文章主要介绍了用python写测试数据文件过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- PyCharm 具备一般 IDE 的功能,比如,调试、语法高亮、项目管理、代码跳转、智能提示、自动完成、单元测试、版本控制…另外,PyCha
- 一,利用键盘响应,在不刷新本页面的情况下验证表单输入是否合法用户通过onkeydown和onkeyup事件来触发响应事件。使用方法和oncl
- # -*- coding: utf8 -*- #! python print(repr("测试报警,xxxx是大猪头".
- 看看上一篇《javascript设计模式交流(一)Singleton Pattern》本文将讨论Prototype Pattern的js实现
- 淡入淡出图片轮换轮播效果,可以做新闻图片推荐需要的拿去用,效果预览请点击运行代码相关效果推荐:迅雷首页新闻图片轮播效果js源码 <!D
- 1.MS SCRIPT ENCODE基本上没什么用了,一段JS就可以破解2.封装成DLL比较可行的方法,有通过VB封装成DLL的例子,而且无
- 密码算法程序设计实践选的SHA-1。在写的过程中遇到一丢丢关于python移位的问题,记录一下。SHA-1其中第一步需要填充消息。简单阐述一
- 搭建lnmp完lnmp环境后,测试时出现502报错,看到这个问题,我立刻想到是php-fpm没有起来,但是我用 ps -ef | grep
- 很简单,只需建立一个worksheet和Excel相关的信息就可以了具体代码见下:<%set xlApp =&nb
- 今日大致浏览了一下《High Performance Web Sites》。本书的中文版是《高性能网站建设指南》。本书另有对其中个别问题深入
- 在讲样式表开发管理之前,我想插播一个小知识。前几天看web标准设计组里,看到龍佑康同学问到关于 block 和 inline 的区别。记得以
- 对比起Cookie,Session 是存储在服务器端的会话,相对安全,并且不像 Cookie 那样有存储长度限制。由于 Session 是以
- 一、使用ddt和data装饰器的大致框架如下,每个test_开头的方法,代表一条测试用例from ddt import ddt,dataim
- <?php /********************************************** *&n
- 如何正确理解和使用Command、Connection和 Recordset三个对象?我知道它们都是连接数据库的“好手”,但在编程的具体应用