Pytorch 如何实现常用正则化
作者:winycg 发布时间:2022-11-02 22:15:14
标签:Pytorch,正则化
Stochastic Depth
论文:Deep Networks with Stochastic Depth
本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。
对于上述的ResNet网络,模块越在后面被drop掉的概率越大。
作者直觉上认为前期提取的低阶特征会被用于后面的层。
第一个模块保留的概率为1,之后保留概率随着深度线性递减。
对一个模块的drop函数可以采用如下的方式实现:
def drop_connect(inputs, p, training):
""" Drop connect. """
if not training: return inputs # 测试阶段
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
# 以样本为单位生成模块是否被drop的01向量
binary_tensor = torch.floor(random_tensor)
# 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
output = inputs / keep_prob * binary_tensor
return output
在Pytorch建立的Module类中,具有forward函数
可以在forward函数中进行drop:
def forward(self, x):
x=...
if stride == 1 and in_planes == out_planes:
if drop_connect_rate:
x = drop_connect(x, p=drop_connect_rate, training=self.training)
x = x + inputs # skip connection
return x
主函数:
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
补充:pytorch中的L2正则化实现方法
搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?
方法如下:超级简单!
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)
torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。
注:
pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。
来源:https://winycg.blog.csdn.net/article/details/96361576


猜你喜欢
- 在服务器部署时,往往都是在后台运行。当程序发生特定的错误时,我希望能够在日志中查询。因此这里熟悉以下 logging 模块的用法。loggi
- 在进行数据抓取时,经常会遇到IP被限制的情况,常见的解决方案是搭建 * 池,或购买IP代理的服务。除此之外,还有一个另外的方法就是使用家里
- 一、集中式vs分布式1.Subversion属于集中式的版本控制系统集中式的版本控制系统都有一个单一的集中管理的服务器,保存所有文件的修订版
- 一、前言经过前面的两篇文章,整体工作已经完成了2/3了,剩下的1/3,将会在本片文章提及前面两步文章链接python实战之德州扑克第一步-发
- 用Nginx做为代理服务器, 后端为 apache2. 设置允许上传最大为100M的文件. 1. Nginx配置: http { .....
- 如下所示:class Login(QMainWindow): """登录窗口""
- 问题出现原因python里numpy默认的是浅拷贝,即拷贝的是对象内存地址,导致两个数据结构共用一个内存地址。结果是修改拷贝的值的时候原对象
- 请问如何实现复合查询?我们用下面的代码来实现动态生成查询条件,动态显示结果的复合查询。set database to databasenam
- 请问鼠标移过去就出现二级菜单代码怎么写啊 <head><style type="tex
- 我们知道 Golang 切片(slice) 在容量不足的情况下会进行扩容,扩容的原理是怎样的呢?是不是每次扩一倍?下面我们结合源码来告诉你答
- Python跑循环时内存泄露今天在用Tensorflow跑回归做测试时,仅仅需要循环四千多次 (补充说一句,我在个人PC上跑的)。运行以后,
- 废话不多说了,直接给大家贴代码了。-- create functioncreate function [dbo].[fnXmlToJson]
- 1。总体概要kNN算法已经在上一篇博客中说明。对于要处理手写体数字,需要处理的点主要包括: (1)图片的预处理:将png,jpg等格式的图片
- Python打包分发工具setuptools:曾经 Python 的分发工具是 distutils,但它无法定义包之间的依赖关系。setup
- SQLSTATESQL SERVER 驱动程序错误描述 HY000所有绑定列都是只读的。必须是可升级的列,以使用 SQLSetPo
- 1. 代码讲解1.1 导库import os.pathfrom os import listdirimport numpy as npimp
- 对于一般的图像提取轮廓,介绍了一个很好的方法,但是对于有噪声的图像,并不能很好地捕获到目标物体。比如对于我的鼠标,提取的轮廓效果并不好,因为
- 摘要:本篇博客介绍了YOLOv5车牌识别的理论基础,包括目标检测的概念、YOLO系列的发展历程、YOLOv5的网络结构和损失函数等。通过深入
- 首先让我们来看看有关 Perl 面向对象编程的三个基本定义:1. 一个“对象”是指一个“有办法知道它是属于哪个类”的简单引用。(
- 情境还原: 公司一项目新上线,刚上线的第2天,在后台发现数据库服务器与IIS服务器的网络IO出现瓶颈,1GB的网络带宽,占用了70%-100