Python深度学习pyTorch权重衰减与L2范数正则化解析
作者:算法菜鸟飞高高 发布时间:2021-03-18 11:39:01
标签:深度学习,pyTorch,权重,范数正则化
下面进行一个高维线性实验
假设我们的真实方程是:
假设feature数200,训练样本和测试样本各20个
模拟数据集
num_train,num_test = 10,10
num_features = 200
true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01
true_b = torch.tensor(0.5)
samples = torch.normal(0,1,(num_train+num_test,num_features))
noise = torch.normal(0,0.01,(num_train+num_test,1))
labels = samples.matmul(true_w) + true_b + noise
train_samples, train_labels= samples[:num_train],labels[:num_train]
test_samples, test_labels = samples[num_train:],labels[num_train:]
定义带正则项的loss function
def loss_function(predict,label,w,lambd):
loss = (predict - label) ** 2
loss = loss.mean() + lambd * (w**2).mean()
return loss
画图的方法
def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
plt.figure(figsize=(3,3))
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.semilogy(x_val,y_val)
if x2_val and y2_val:
plt.semilogy(x2_val,y2_val)
plt.legend(legend)
plt.show()
拟合和画图
def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
b = torch.tensor(0.,requires_grad=True)
optimizer = torch.optim.Adam([w,b],lr=0.05)
train_loss = []
test_loss = []
for epoch in range(num_epoch):
predict = train_samples.matmul(w) + b
epoch_train_loss = loss_function(predict,train_labels,w,lambd)
optimizer.zero_grad()
epoch_train_loss.backward()
optimizer.step()
test_predict = test_sapmles.matmul(w) + b
epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
train_loss.append(epoch_train_loss.item())
test_loss.append(epoch_test_loss.item())
semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])
以上就是Python深度学习pyTorch权重衰减与L2范数正则化解析的详细内容,更多关于Python pyTorch权重与L2范数正则化的资料请关注脚本之家其它相关文章!
来源:https://blog.csdn.net/qq_43152622/article/details/116937183


猜你喜欢
- 需求背景一个统计接口,前端需要返回两个数组,一个是0-23的小时计数,一个是各小时对应的统计数。思路 直接使用group by查询要统计的表
- 核心代码:#!/usr/bin/python#-*- coding:gbk -*-#设置源文件输出格式import sysimport ge
- 需求分析项目上需要用到手机号前7位,判断号码是否合法,还有归属地查询。旧的数据是几年前了太久了,打算用python爬虫重新爬一份单线程版本#
- 处理下拉列表需要使用selenium中的工具类Select,常用方法如下:示例网站:http://sahitest.com/demo示例场景
- 早上看了一个贴子,是一个哥们推广自己一个智能的数据库备份系统,他总结了数据库备份过程中所有可能出错的情况,可以借鉴。如果你做DBA时间不长,
- 本文实例讲述了vue动态组件和v-once指令。分享给大家供大家参考,具体如下:点击按钮时,自动切换两个组件<component :i
- 本文主要研究的是Python对内存的使用(深浅拷贝)的相关问题,具体介绍如下。浅拷贝就是对引用的拷贝(只拷贝父对象) 深拷贝就是对对象的资源
- pyqtgraph是Python平台上一种功能强大的2D/3D绘图库,相对于matplotlib库,由于其在内部实现方式上,使用了高速计算的
- 本文实例为大家分享了python将两张图片生成全景图片的具体代码,供大家参考,具体内容如下1、全景图片的介绍全景图通过广角的表现手段以及绘画
- 痛点json 是当前最常用的数据传输格式之一,纯文本,容易使用,方便阅读,在通信过程中大量被使用。 你是否遇到过json中某个字段
- table单元格新增行并编辑,具体内容如下需要bootstrap.min.css —— [ Bootstrap ]jquery-1.8.2.
- <?php /******************************************** *&nb
- 前言Python中的 True和 False总是让人困惑,一不小心就会用错,本文总结了三个易错点,分别是逻辑取反、if条件式和pandas.
- 生成全局ID的方法很多, 这里记录下一种简单的方案: 利用mysql的自增id生成全局唯一ID.1. 创建一张只需要两个字段的表:CREAT
- Python函数的设计规范1、Python函数设计时具备耦合性和聚合性1)、耦合性:(1).尽可能通过参数接受输入,以及通过return产生
- 一、前言别问我为啥题目是英文,因为…高大上(bushi。刷视频的时候偶然刷到了一个关于生日悖论的,当场就觉得不可思议,
- 要找到最早的活动事务,可以使用DBCC OPENTRAN命令。详细用法见MSDN:http://msdn.microsoft.com/zh-
- 一、多层前向神经网络多层前向神经网络由三部分组成:输出层、隐藏层、输出层,每层由单元组成;输入层由训练集的实例特征向量传入,经过连接结点的权
- ddt 是第三方模块,需安装, pip install ddtDDT包含类的装饰器ddt和两个方法装饰器data(直接输入测试数据)通常情况
- 1 random.choicepython random模块的choice方法随机选择某个元素foo = ['a',