Pytorch 如何实现常用正则化
作者:winycg 发布时间:2022-11-02 22:15:14
标签:Pytorch,正则化
Stochastic Depth
论文:Deep Networks with Stochastic Depth
本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。
对于上述的ResNet网络,模块越在后面被drop掉的概率越大。
作者直觉上认为前期提取的低阶特征会被用于后面的层。
第一个模块保留的概率为1,之后保留概率随着深度线性递减。
对一个模块的drop函数可以采用如下的方式实现:
def drop_connect(inputs, p, training):
""" Drop connect. """
if not training: return inputs # 测试阶段
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
# 以样本为单位生成模块是否被drop的01向量
binary_tensor = torch.floor(random_tensor)
# 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
output = inputs / keep_prob * binary_tensor
return output
在Pytorch建立的Module类中,具有forward函数
可以在forward函数中进行drop:
def forward(self, x):
x=...
if stride == 1 and in_planes == out_planes:
if drop_connect_rate:
x = drop_connect(x, p=drop_connect_rate, training=self.training)
x = x + inputs # skip connection
return x
主函数:
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
补充:pytorch中的L2正则化实现方法
搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?
方法如下:超级简单!
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)
torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。
注:
pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。
来源:https://winycg.blog.csdn.net/article/details/96361576
0
投稿
猜你喜欢
- language.xml 代码如下:<?xml version="1.0" encoding=
- 前言:在转换操作中,我们执行各种操作,例如更改系列的数据类型,将系列更改为列表等。为了执行转换操作,我们有各种有助于转换的功能,例如.ast
- 假如你正在运行使用MySQL的Web应用程序,那么你把密码或者其他敏感信息保存在应用程序里的机会就很大。保护这些数据免受黑客或者窥探者的获取
- 前言相信大家都应该有所体会,在平时经常会遇到处理 Excel 表格数据的情况,人工处理起来实在是太麻烦了,我们可以使用 Python 来解决
- 1. 选用适合的ORACLE优化器 ORACLE的优化器共有3种: a. RULE (基于规则) b. COST (基于成本) c. CHO
- 一、备份数据库1、打开SQL企业管理器,在控制台根目录中依次点开Microsoft SQL Server2、SQL Server组-->
- 一、介绍Django特点:具有完整的封装,开发者可以高效率的开发项目,Django将大部分的功能进行了封装,开发者只需要调用即可,如此,大大
- python中使用requests模块http请求时,发现中文参数不会自动的URL编码,并且没有找到类似urllib (python3)模块
- 如下所示:import numpy as np a=np.random.randint(0,10,size=[3,3,3])print(a)
- CSS网页布局开发中,会有很多小技巧,这里再扩展一下您所想要得到的知识,相信您会有很多收获!一、ul标签在Mozilla中默认是有paddi
- 小鸟(image)游戏展示代码展示import pygame,syspygame.init()#初始化操作#保存窗口大小width,heig
- 外网python2.7 虚拟环境中安装了 flask 模块,期望在内网使用,如何迁移外网的虚拟环境到内网呢?1 进入外网python虚拟环境
- 本文介绍在Anaconda环境下,安装Python中栅格、矢量等地理数据处理库GDAL的方法。  需要注
- 本文实例讲述了python获取本机mac地址和ip地址的方法。分享给大家供大家参考。具体如下:import sys, socketdef g
- 1.画最简单的直线图代码如下:import numpy as np import matplotlib.pyplot as plt x=[0
- 一、前言Python提供两种方法进行字符串格式化1、利用百分号来格式化字符串,现在Python已停止更新这种方法2、字符串的format方法
- 一、项目背景:为了回顾关于django的文件上传和分页功能,打算写一个微型的小说网站练练手。花了一个下午的时间,写了个小项目,发现其中其实遇
- python通过安装使用paramiko模块,将本地文件上传到服务器上import paramikoimport datetimeimpor
- python操作mongodb数据库# !/usr/bin/env python# -*- coding:utf-8 -*-"&q
- 本文实例为大家分享了python多线程下信号处理程序示例的具体代码,供大家参考,具体内容如下下面是一个网上转载的实现思路,经过验证,发现是可