Pytorch反向求导更新网络参数的方法
作者:tsq292978891 发布时间:2021-02-07 11:48:52
标签:Pytorch,求导,参数
方法一:手动计算变量的梯度,然后更新梯度
import torch
from torch.autograd import Variable
# 定义参数
w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True)
# 定义输出
d = torch.mean(w1)
# 反向求导
d.backward()
# 定义学习率等参数
lr = 0.001
# 手动更新参数
w1.data.zero_() # BP求导更新参数之前,需先对导数置0
w1.data.sub_(lr*w1.grad.data)
一个网络中通常有很多变量,如果按照上述的方法手动求导,然后更新参数,是很麻烦的,这个时候可以调用torch.optim
方法二:使用torch.optim
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
# 这里假设我们定义了一个网络,为net
steps = 10000
# 定义一个optim对象
optimizer = optim.SGD(net.parameters(), lr = 0.01)
# 在for循环中更新参数
for i in range(steps):
optimizer.zero_grad() # 对网络中参数当前的导数置0
output = net(input) # 网络前向计算
loss = criterion(output, target) # 计算损失
loss.backward() #得到模型中参数对当前输入的梯度
optimizer.step() # 更新参数
注意:torch.optim只用于参数更新和对参数的梯度置0,不能计算参数的梯度,在使用torch.optim进行参数更新之前,需要写前向与反向传播求导的代码
来源:https://blog.csdn.net/tsq292978891/article/details/79333707


猜你喜欢
- 在python中,有很多用于生成基于JS的百度开源的数据可视化图表 Echarts 的类库。设置的图样都非常漂亮,小编之前研究过很多图示,用
- Django2.0中编写models类下的ForeignKeybook = models.ForeignKey('BookInfo&
- 贪吃蛇游戏是经典手机游戏,既简单又耐玩。通过控制蛇头方向吃蛋,使得蛇变长,从而获得积分。在诺基亚时代,风靡整个手机界,今天我们来看看另类的,
- 通过status命令,查看Slow queries这一项,如果值长时间>0,说明有查询执行时间过长以下为引用的内容:mysql>
- 平时我们在使用MySQL数据库的时候经常会因为操作失误造成数据丢失,MySQL数据库备份可以帮助我们避免由于各种原因造成的数据丢失或着数据库
- 第一步、下载压缩包下载社区版的 MySQL,根据需求下载对应版本,其中有最小安装版本。具体各个版本的区别,可以上网查询,链接MySQL ::
- 上一篇博客主要介绍了决策树的原理,这篇主要介绍他的实现,代码环境python 3.4,实现的是ID3算法,首先为了后面matplotlib的
- 现在我主要教大家如何去实战,做一个简易的知乎日报API 首先你要熟悉django的基本用法,会写模型,会写视图函数,会配置url。1.配置字
- Goland 项目创建goland2020.3 及以上 IDE,默认创建的 go 项目 就是使用 gomod 管理!goland2020.3
- onchange在用于文本框输入框时,有一个明显的不足. 事件不会随着文字的输入而触发,而是等到文本框失去焦点(onblur)时才会触发.
- Python版本 实现了比之前的xxftp更多更完善的功能 1、继续支持多用户 2、继续支持虚拟目录 3、增加支持用户根目录以及映射虚拟目录
- 加了三个验证漏洞以及四个getshell方法# /usr/bin/env python3# -*- coding: utf-8 -*-# @
- Python assert 语句,又称断言语句,可以看做是功能缩小版的 if 语句,它用于判断某个表达式的值,如果值为真,则程序可以继续往下
- 刚才帮一位朋友做跳转的时候做的,为了获取完整的url地址,还是花了那么点时间不过现在看来,原来是那么简单,没有网上那么多复杂的东东,相信一定
- 什么是粘包问题最近在使用Golang编写Socket层,发现有时候接收端会一次读到多个数据包的问题。于是通过查阅资料,发现这个就是传说中的T
- 前言我们在开发应用是经常会需要用到一些数据的存储,存储的方式有多种,使用数据库是一种比较受大家欢迎的方式。但是对于一些小型的应用,如一些移动
- 此篇文章整理新手编写代码常见的一些错误,有些错误是粗心的错误,但对于新手而已,会折腾很长时间才搞定,所以在此总结下我遇到的一些问题。希望帮助
- 最近在接触一个Django项目,使用的是fbv( function-base views )模式,看起来特别不舒服,项目中有一个模型类117
- 有时你需临时搭建一个简单的 Web Server,但你又不想去安装 Apache、Nginx 等这类功能较复杂的 HTTP 服务程序时。这时
- 一、concurrent模块的介绍concurrent.futures模块提供了高度封装的异步调用接口ThreadPoolExecutor: