解决Pytorch训练过程中loss不下降的问题
作者:猫猬兽 发布时间:2023-03-01 09:30:22
标签:Pytorch,loss,不下降
在使用Pytorch进行神经网络训练时,有时会遇到训练学习率不下降的问题。出现这种问题的可能原因有很多,包括学习率过小,数据没有进行Normalization等。不过除了这些常规的原因,还有一种难以发现的原因:在计算loss时数据维数不匹配。
下面是我的代码:
loss_function = torch.nn.MSE_loss()
optimizer.zero_grad()
output = model(x_train)
loss = loss_function(output, y_train)
loss.backward()
optimizer.step()
要特别注意计算loss时网络输出值output和真实值y_train的维数必须完全匹配,否则训练误差不下降,无法训练。这种错误在训练一维数据时很容易忽略,要十分注意。
来源:https://blog.csdn.net/yyb19951015/article/details/88779171


猜你喜欢
- 今天介绍Python当中十大可视化工具,每一个都独具特色,惊艳一方。MatplotlibMatplotlib 是 Python 的一个绘图库
- 安全公司 Imperva Cloud WAF 保护了全球超过10万个网站,并且每天观察到大约10亿次攻击。他们每天都会检测到成千上万种黑客工
- m3u8原理当我们在网页播放视频时,网页向服务器发起一个以.m3u8结尾的连接请求,服务器会将具体的.ts文件链接路径发送给网页,网页接收这
- DatePart 的语法是 DatePart(interval, date),用以取 date 的某部分。 interval yyyy:da
- 1、如何快速找到多个字典中的公共键(key)实际案例:西班牙足球甲级联赛,每轮球员进球统计:第一轮:{'苏亚雷斯': 1,
- 在查询凭证、审核凭证时出现“列前缀tempdb.无效: 未指定表名”的错误提示,怎么解决?原因:是因为SQL2000无法识别计算机名称中”-
- 具体方法:(推荐教程:mysql数据库学习教程)查看表被锁状态# 查询哪些表锁了show OPEN TABLE
- 用phpMyAdmin时在导入和导出MySQL5数据时,有一个SQL compatibility mode选项,其可选值为NONE、ANSI
- 前言最近在做文本统计,用 Python 实现,遇到了一个比较有意思的难题——如何保存统计结果。直接写入内存实在是放不下,十几个小时后内存耗尽
- 调用sklearn的model_selection时,发现sklearn中没有model_selection的模块。经过检查,发现anaco
- 本文实例讲述了django框架模板语言使用方法。分享给大家供大家参考,具体如下:模板功能作用:生成html界面内容,模版致力于界面如何显示,
- 一、实验环境1.Windows7x64_SP12.anaconda3.7 + python3.7(anaconda集成,不需单独安装)3.p
- pyqtgraph是Python平台上一种功能强大的2D/3D绘图库,相对于matplotlib库,由于其在内部实现方式上,使用了高速计算的
- Go+ 语言的安装和环境配置有些复杂,官方教程也没有写的很详细。通过控制台编写和运行 Go+ 程序很不方便。本文从零开始,详细介绍 Go+
- 首先声明,在这组里我是个绝对的菜鸟。再次声明,小爝这个菜鸟在“网页设计”这个圈里混了快1年了。 摘要:我知道我有多少底,所以我在总结我的成长
- 成天都要与样式打交道的朋友,相信对CSS选择符(CSS Selectors)都不会陌生。不过对于刚接触或者还不是很熟悉css的朋友来说,能够
- 作为让高中生心脏骤停的四个字,对于高考之后的人来说可谓刻骨铭心,所以定义不再赘述,直接撸图,其标准方程分别为在Python中,绘制动图需要用
- 前言在python中,print是重要的输出语句,让我们更方便的知道程序的运行状况,但是这样还不够,我们也可以用print来给周围的小伙伴秀
- Git修改已提交的commit注释两种情况:修改最后一次注释1、在命令行输入如下命令,然后回车:git commit --amend2、在命
- 原图代码 src = cv2.imread("28.png") gray_src = cv2.c