关于pytorch中网络loss传播和参数更新的理解
作者:少年木 发布时间:2023-08-06 05:29:09
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。
TensorFlow: 228--->266
Keras: 42--->56
Pytorch: 87--->252
在使用pytorch中,自己有一些思考,如下:
1. loss计算和反向传播
import torch.nn as nn
criterion = nn.MSELoss().cuda()
output = model(input)
loss = criterion(output, target)
loss.backward()
通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;
最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。
需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。
2. 网络参数更新
在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。
优化算法中,简单的有一阶优化算法:
其中 就是通常说的学习率,
是函数的梯度;
自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。
在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
注意:
1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。
2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:
optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。
3. loss计算和参数更新
import torch.nn as nn
import torch
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward() # 误差反向传播计算参数梯度
optimizer.step() # 通过梯度做一步参数更新
来源:https://blog.csdn.net/yangzhengzheng95/article/details/85268896


猜你喜欢
- 1. 前言文章主要围绕着以下三个问题:group by的作用where与having的区别表的连接分为哪些,分别是什么作用2. 表的设计在创
- 实际线上的场景比较复杂,当时涉及了truncate, delete 两个操作,经确认丢数据差不多7万多行,等停下来时,差不多又有共计1万多行
- 前言最近找几个老友准备聊天发现几个已经被删除好友名单,做为潜水党多年的我已经不知道成为多少人的黑名单,但是好友列表却依然有不是好友的名单,面
- 1、TransBigData简介TransBigData是一个为交通时空大数据处理、分析和可视化而开发的Python包。TransBigDa
- 整个重装步骤大致分四个步骤进行,第一步,备份原mysql中的所有数据库。第二步,完全卸载mysql第三步,下载安装新版mysql第四步,导入
- 1.GAN简述在GAN中,有两个模型,一个是生成模型,用于生成样本,一个是判别模型,用于判断样本是真还是假。但由于在GAN中,使用的JS散度
- 有的小伙伴对于枚举的理解很模糊,其实我们可以把它看成一个数量的大管家,对其中的每一个数进行检查,保证里面的数字都没有重复的,这就是枚举的用法
- 本文针对SQL 2016 正式版安装过程进行梳理总结,帮助大家顺利安装SQL 2016,具体内容如下1.点击全新安装2.接着就是下一步,下一
- 1、jieba库基本介绍(1)、jieba库概述jieba是优秀的中文分词第三方库- 中文文本需要通过分词获得单个的词语- jieba是优秀
- 考虑到数据安全问题,准备把服务器上的数据库迁移到刚刚挂载的云硬盘上,研究一下,这个方法是最靠谱的,分享之!首先建立数据库即将迁移到的目录mk
- 移动端适配满足多个查询时的优先级: 请注意,可以同时满足多个查询,并且它们都将由mergeOption合并,mergeOption稍后由me
- 服务器响应HTTP的类型ContentType大全,使用方法:<% Response.ContentType =&
- 本文记录了MySQL下载安装详细教程,供大家参考,具体内容如下1.下载MySQL数据库可以访问官方网站:2.点击DOWNLOADS模块下的C
- 本文实例讲述了Python实现压缩文件夹与解压缩zip文件的方法。分享给大家供大家参考,具体如下:直接上代码#coding=utf-8#甄码
- 写在之前首先是写在之前的一些建议:首先是关于这本书,我真的认为他是将神经网络里非常棒的一本书,但你也需要注意,如果你真的想自己动手去实现,那
- 部署apache服务的步骤:准备环境:关闭防火墙 :service iptables stop设置开机关闭防火墙:chkconfig ipt
- 本文实例讲述了php实现的日历程序。分享给大家供大家参考。具体如下:<?php /* * php 输出日历程序 */ header(&
- 前言我们在上一期学习了关于Python 迭代器Iterator详情相关的概念,满足迭代器需要符合两个条件实现__iter__()方
- 本文实例为大家分享了ajax实现无刷新上传文件功能的具体代码,供大家参考,具体内容如下详细代码如下<!DOCTYPE HTML>
- 第一章:日志管理 1.forcing log switchessql> alter system switch logfile;2.f