对Pytorch神经网络初始化kaiming分布详解
作者:winycg 发布时间:2022-08-27 01:39:50
函数的增益值
torch.nn.init.calculate_gain(nonlinearity, param=None)
提供了对非线性函数增益值的计算。
增益值gain是一个比例值,来调控输入数量级和输出数量级之间的关系。
fan_in和fan_out
pytorch计算fan_in和fan_out的源码
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.ndimension()
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed
for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor.size(1)
fan_out = tensor.size(0)
else:
num_input_fmaps = tensor.size(1)
num_output_fmaps = tensor.size(0)
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
xavier分布
xavier分布解析:https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/
假设使用的是sigmoid函数。当权重值(值指的是绝对值)过小,输入值每经过网络层,方差都会减少,每一层的加权和很小,在sigmoid函数0附件的区域相当于线性函数,失去了DNN的非线性性。
当权重的值过大,输入值经过每一层后方差会迅速上升,每层的输出值将会很大,此时每层的梯度将会趋近于0.
xavier初始化可以使得输入值x x x<math><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math>x方差经过网络层后的输出值y y y<math><semantics><mrow><mi>y</mi></mrow><annotation encoding="application/x-tex">y</annotation></semantics></math>y方差不变。
(1)xavier的均匀分布
torch.nn.init.xavier_uniform_(tensor, gain=1)
也称为Glorot initialization。
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
(2) xavier正态分布
torch.nn.init.xavier_normal_(tensor, gain=1)
也称为Glorot initialization。
kaiming分布
Xavier在tanh中表现的很好,但在Relu激活函数中表现的很差,所何凯明提出了针对于relu的初始化方法。pytorch默认使用kaiming正态分布初始化卷积层参数。
(1) kaiming均匀分布
torch.nn.init.kaiming_uniform_
(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
也被称为 He initialization。
a – the negative slope of the rectifier used after this layer (0 for ReLU by default).激活函数的负斜率,
mode – either ‘fan_in' (default) or ‘fan_out'. Choosing fan_in preserves the magnitude of the variance of the weights in the forward pass. Choosing fan_out preserves the magnitudes in the backwards
pass.默认为fan_in模式,fan_in可以保持前向传播的权重方差的数量级,fan_out可以保持反向传播的权重方差的数量级。
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
(2) kaiming正态分布
torch.nn.init.kaiming_normal_
(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
也被称为 He initialization。
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
来源:https://blog.csdn.net/winycg/article/details/86649832
猜你喜欢
- open函数你必须先用Python内置的open()函数打开一个文件,创建一个file对象,相关的辅助方法才可以调用它进行读写。语法:fil
- 本月第一天日期SELECT FirstDayOfCurrentMonth = dateadd(mm,datediff(mm,0,getdat
- 前言本文主要给大家介绍了关于golang解析网页利器goquery使用的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介
- 一、需求描述web 自动化测试/python爬虫往往会遇到扫码登录的情况,不是所有的网站都支持用户密码登录,遇到这种扫码登录的情况会阻碍我们
- 如何用HtmlEncode来显示Unicode? 见下:<%@ Language=VBS
- 本文实例讲述了JavaScript使用正则表达式获取全部分组内容的方法。分享给大家供大家参考,具体如下:1. 需要使用正则表达式的exec2
- 学习目的:掌握下拉列表框的用法,并理解AutoPostBack属性; 理解IsPoskBack及用法; 初识DataTable的
- 本文介绍了python opencv之SIFT算法示例,分享给大家,具体如下:目标:学习SIFT算法的概念 学习在图像中查找SIFT关键的和
- asp之家注:有时候我们想让程序运行变慢下来,asp中该怎么做呢?原理很简单就是在运行程序前运行一段无关紧要的程序就可以了,要实现加长程序的
- 在跨业务、跨网站发送数据或者业务升级的时候,我们有的时候需要指定发送数据的编码方式,比如页面是utf-8编码的,而发送出去的数据却是GB23
- 代码如下:<% str = request("str") reg 
- 原因是:It looks like you need to flush stdout periodically (e.g. sys.stdo
- 这篇论坛文章(赛迪网技术社区)主要介绍了数据仓库基本报表制作过程中的SQL写法,详细内容请参考下文:在数据仓库的基本报表制作过程中,通常会使
- PHP hex2bin() 函数实例把十六进制值转换为 ASCII 字符:<?php echo hex2bin("48656
- 从毕业实习算起,从事可用性方面的工作到现在已经5年了。在此记录笔者的一些所见所想,和大家讨论分享一下。用户研究在“以用户为中心”的界面设计方
- PHP SESSION 的存储Session会话存储方式PHP将session以文件的形式存储服务器的文件中,session.save_pa
- 一、使用ddt和data装饰器的大致框架如下,每个test_开头的方法,代表一条测试用例from ddt import ddt,dataim
- 线程实现Python中线程有两种方式:函数或者用类来包装线程对象。threading模块中包含了丰富的多线程支持功能:threading.c
- (1) os.system仅仅在一个子终端运行系统命令,而不能获取命令执行后的返回信息system(command) -> exit_
- 有关JS中字符串的相关文章,现在网上大概不计其数了。这里我不想再就这个问题做过多的论述,只是对几种方式的实现在各种浏览器中的执行效率进行对比