关于pytorch中网络loss传播和参数更新的理解
作者:少年木 发布时间:2023-08-06 05:29:09
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。
TensorFlow: 228--->266
Keras: 42--->56
Pytorch: 87--->252
在使用pytorch中,自己有一些思考,如下:
1. loss计算和反向传播
import torch.nn as nn
criterion = nn.MSELoss().cuda()
output = model(input)
loss = criterion(output, target)
loss.backward()
通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;
最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。
需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。
2. 网络参数更新
在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。
优化算法中,简单的有一阶优化算法:
其中 就是通常说的学习率, 是函数的梯度;
自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。
在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
注意:
1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。
2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:
optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。
3. loss计算和参数更新
import torch.nn as nn
import torch
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward() # 误差反向传播计算参数梯度
optimizer.step() # 通过梯度做一步参数更新
来源:https://blog.csdn.net/yangzhengzheng95/article/details/85268896
猜你喜欢
- 这个翻滚代码没有使用什么marquee或者其它位移方法,而是每隔一秒把列表最顶端的那个li删掉,把这个li里面的内容插入到最底端新生成的li
- RFC文档有很多,有时候在没有联网的情况下也想翻阅,只能下载一份留存本地了。看了看地址列表,大概是这个范围:http://www.netwo
- 首先说明下范围用Javascript来开发WEB页面的动画效果该思路同时考虑页面效率、SEO,如果数据大,也可以缓解后端压力。这个是程序设计
- 摘要:本文介绍了tensorflow的常用函数。1、tensorflow常用函数TensorFlow 将图形定义转换成分布式执行的操作, 以
- 阅读上一篇:什么是名字空间<meta http-equiv="Content-Type" co
- 对于xml2ddl项目,Freshmeat.org提供了一整套基于GNU或者GPL通用公共许可证下的Python程序。在一个运行的Pytho
- python help使用C:\Users\wusong>pythonPython 3.8.2rc1 (tags/v3.8.2rc1:
- 内容摘要:FCKeditor至今已经到了2.3.1版本了,对于国内的WEB开发者来说,也基本上都已经“闻风知多少”了,很多人将其融放到自己的
- 什么是 AOPAOP,就是面向切面编程,简单的说,就是动态地将代码切入到类的指定方法、指定位置上的编程思想就是面向切面的编程。我们管切入到指
- 本文实例为大家分享了python实现转盘效果的具体代码,供大家参考,具体内容如下#抽奖 面向对象版本import tkinterimport
- YUI3.2.0 的 transition 模块,通过使用 transition:end 事件实现在 transition 完成后执行其他操
- 很多人都将<数据库设计范式>作为数据库表结构设计“圣经”,认为只要按照这个范式需求设计,就能让设计出来的表结构足够优化,既能保证
- 本XML系列教程将分三部分发布,到最后一期我们将拥有一个功能全面,更加友好的XML菜单。本教程这个第一期涉及到了一些XML的基础知识。大家都
- 前言在浏览博客时,偶然看到了用python将汉字转为拼音的第三方包,但是在实现的过程中发现一些参数已经更新,现在将两种方法记录一下。xpin
- 这篇文章主要介绍了Python hashlib常见摘要算法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 自己前端开发中常用到的一些技巧及问题解决方法,会常更新,希望对前端路上的朋友有帮助。1、文章标题列表中日期居右显示的方法(提供了两种方法,使
- 假如一个页面中的文本采用的都是同样的字体、同样的字号、同样的颜色,做为读者的你能轻易的区分出哪里是标题,哪里是正文内容吗?所以通常情况下,设
- 记得上次电梯按钮讨论中有朋友提到日本的无序电梯,我没有太明白意思。除了各位大师提出的无厘头方案,也有不少超前的创意,好多都值得继续思考和探索
- 这篇文章主要介绍了python框架django项目部署相关知识详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 本文实例讲述了Python实现读取txt文件中的数据并绘制出图形操作。分享给大家供大家参考,具体如下:下面的是某一文本文件中的数据。6.11