浅谈pytorch grad_fn以及权重梯度不更新的问题
作者:端木亽 发布时间:2022-10-30 00:00:18
标签:pytorch,grad,fn,权重,梯度
前提:我训练的是二分类网络,使用语言为pytorch
Varibale包含三个属性:
data:存储了Tensor,是本体的数据
grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
grad_fn:指向Function对象,用于反向传播的梯度计算之用
在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。
百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:
train_pred = Variable(train_pred.float(), requires_grad=True)`
这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。
我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。
询问同门后发现问题不在这里。
计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。
train_pred = model(data)
train_pred = torch.max(train_pred, 1)[1].data.squeeze()
train_pred = Variable(train_pred.float(), requires_grad=False)
train_loss = F.binary_cross_entropy(validation_pred.float(), target)
train_loss.backward()
对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。
重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。
最终修改代码如下:
for batch_idx, (data, target) in enumerate(train_loader):
# Get Samples
label = target.view(target.size(0), 1).long()
target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
model.zero_grad()
# Predict
train_pred = model(data)
train_loss = F.binary_cross_entropy(train_pred, target_onehot)
train_loss.backward()
optimizer.step()
来源:https://blog.csdn.net/duanmuji/article/details/85217338


猜你喜欢
- tkinter库Canvas操作三个实例实例一:涂鸦import tkinter as tkimport pyautogui as agfr
- 在修改后的 《闲谈 Web 图片服务器》 一文中也提及了"IE 浏览器的连接数问题",这也是个有趣的话题。值得补充记录一
- myPhoneBook2.py#!/usr/bin/python# -*- coding: utf-8 -*-import reclass
- 最近在部署Azure虚拟机的时候,一直访问不了网络数据库,一搜资料才知道,Azure默认是不打开入网规则的,需要手动设置。在 Windows
- 在python项目中,我们经常会用到lambda,那么lambda是什么呢,有什么作用,下面我们开始介绍1、可以使用lambda关键字创建匿
- 1、关于参数的区别实例方法:定义实例方法是最少有一个形参 ---> 实例对象,通常用 self类方法:定义类方法的时候最少有一个形参
- 牛顿法求平方根原理计算机常用循环来计算F的平方根.从某个猜测的x值开始,根据x^2与F的近似度来调整x,产生一个更好的猜测:x -= (x
- 开发环境:Pycharm 2018.3 + Anaconda3(5.3.0) + Python 3.7.1 + Numpy 1.15.4在此
- 1、设置web.config文件。以下为引用的内容:<system.web> ...... <globalization
- 1.os 库基本介绍os库提供通用的、基本的操作系统交互功能。三大操作系统:windowsMac OSLinuxos 库是python标准库
- 迄今为止,导出/导入工具集仍是跨多个平台转移数据所需劳动强度最小的首选实用工具,尽管人们常常抱怨它速度太慢。导入只是将每条记录从导出转储文件
- 1、选取最适用的字段属性 MySQL可以很好的支持大数据量的存取,但是一般说来,数据库中的表越小,在它上面执行的查询也就会越快。因此,在创建
- 客户端HTTP请求URL只是标识资源的位置,而HTTP是用来提交和获取资源。客户端发送一个HTTP请求到服务器的请求消息,包括以下格式:请求
- 1、简单应用代码如下:#!/usr/bin/env python# -*- coding: utf-8 -*-# @File : jieba
- 环境系统:win10cpu:i7-6700HQgpu:gtx965mpython : 3.6pytorch :0.3数据下载来源自Sasan
- show内容展示尝试用微信小程序的template组件实现。同时,尝试页面间转跳时传参,在目标页面引入模板文件实现 写的更少,做的更多 篇幅
- 现在市场上的OA基本上可归结为两大阵营,即php阵营和java阵营。但对接触Oa不久的用户来说,看到的往往只是它们的表相,只是明显的价格差异
- vue 中的 $slot以前一直不知到这个东西,后来发现 vue api 中 藏着很多的 很神奇的 api,比如这个具名插槽很好理解,但是那
- 目录前言一、首先二、接下来1.对照人脸获取2. 通过算法建立对照模型3.识别前言今天,我们用Python实现简单的人脸识别技术!Python
- 1、判断多个条件的语句,if为真则执行if后面的语句。2、如果elif是真的,则执行elif,后面的代码块不执行。3、如果if和elif不满