Python深度学习pyTorch权重衰减与L2范数正则化解析
作者:算法菜鸟飞高高 发布时间:2021-03-18 11:39:01
标签:深度学习,pyTorch,权重,范数正则化
下面进行一个高维线性实验
假设我们的真实方程是:
假设feature数200,训练样本和测试样本各20个
模拟数据集
num_train,num_test = 10,10
num_features = 200
true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01
true_b = torch.tensor(0.5)
samples = torch.normal(0,1,(num_train+num_test,num_features))
noise = torch.normal(0,0.01,(num_train+num_test,1))
labels = samples.matmul(true_w) + true_b + noise
train_samples, train_labels= samples[:num_train],labels[:num_train]
test_samples, test_labels = samples[num_train:],labels[num_train:]
定义带正则项的loss function
def loss_function(predict,label,w,lambd):
loss = (predict - label) ** 2
loss = loss.mean() + lambd * (w**2).mean()
return loss
画图的方法
def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
plt.figure(figsize=(3,3))
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.semilogy(x_val,y_val)
if x2_val and y2_val:
plt.semilogy(x2_val,y2_val)
plt.legend(legend)
plt.show()
拟合和画图
def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
b = torch.tensor(0.,requires_grad=True)
optimizer = torch.optim.Adam([w,b],lr=0.05)
train_loss = []
test_loss = []
for epoch in range(num_epoch):
predict = train_samples.matmul(w) + b
epoch_train_loss = loss_function(predict,train_labels,w,lambd)
optimizer.zero_grad()
epoch_train_loss.backward()
optimizer.step()
test_predict = test_sapmles.matmul(w) + b
epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
train_loss.append(epoch_train_loss.item())
test_loss.append(epoch_test_loss.item())
semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])
以上就是Python深度学习pyTorch权重衰减与L2范数正则化解析的详细内容,更多关于Python pyTorch权重与L2范数正则化的资料请关注脚本之家其它相关文章!
来源:https://blog.csdn.net/qq_43152622/article/details/116937183
0
投稿
猜你喜欢
- 在大学,有很多喜欢的课是需要抢的。但是,这个课的人数和座位都是有限的,今天这个教程教你如何抢到座位,有座位了还怕听不到课吗?赶紧学起来吧,真
- 今天有点囧a=['XXXX_game.sql', 'XXXX_game_sp.sql', 'XXXX
- <style> #L { position:absolute; color:
- 本文主要介绍了Python利用numpy实现三层神经网络的示例代码,分享给大家,具体如下:其实神经网络很好实现,稍微有点基础的基本都可以实现
- 一、功能实现对学生对个人信息的增删查改实现后台对所有学生信息的操作二、平台windows+pycharm(python开发工具)三、逻辑框图
- 本期给大家讲解的函数都不陌生,大家都遇到使用过,但是不要轻易觉得简单去学习,因为往往看似简单的东西,从一个方面深入下收都是一大堆的东西,千万
- 下面直接上代码留存,方便以后查阅复用。# -*- coding: utf-8 -*- #作者:LeniyTsan#时间:2014-07-17
- 如何让图片自动缩放以适合界面大小,拿出你的Editplus,打开c_function.asp文件,找到UBBCode函数,在第417行有如下
- lambda 语法lambda 函数的语法只包含一个语句,表现形式如下:lambda [arg1 [,arg2,.....argn]]:ex
- 以下的文章主要介绍的是MySQL 查询缓存的实际应用代码以及查看MySQL 查询缓存的大小 ,碎片整理,清除缓存以及监视MySQL 查询缓存
- PyQt5选项卡控件QTabWidget简介QTabWidget控件提供了一个选项卡和一个页面区域,默认显示第一个选项卡的页面,通过单击各选
- 这篇文章主要介绍了用python写测试数据文件过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 直方图介绍直方图(Histogram),又称质量分布图,是一种统计报告图,由一系列高度不等的纵向条纹或线段表示数据分布的情况。 一般用横轴表
- 问题最近在研究图学习,在用networkx库绘图的时候发现问题。'''author:zhengtime:2020.1
- 前言我们的游戏资源处理工具是Python实现的,功能包括csv解析,UI材质处理,动画资源解析、批处理,Androd&iOS自动打包
- 前言日常使用python经常要对文本进行处理,无论是爬虫的数据解析,还是大数据的文本清洗,还是普通文件的处理,都是要用到字符串. Pytho
- <script>var d = '2013-07-21';var nd = d.replace(new RegE
- 1、yield,将函数变为 generator (生成器)例如:斐波那契数列def fib(num): a, b, c = 1,
- K-近邻算法概述简单地说, k-近邻算法采用测量不同特征值之间的距离方法进行分类。k-近邻算法优点:精度高、对异常值不敏感、无数据输入假定。
- 前言:以往看到我博客的小伙伴可能都知道,我的前言一般都是吐槽和讲废话环节,哈哈哈哈。今天难得休息,最近可真是太忙了,博主已经连续一年都在99