pytorch MSELoss计算平均的实现方法
作者:sunrise_ccx 发布时间:2021-07-31 18:44:15
给定损失函数的输入y,pred,shape均为bxc。
若设定loss_fn = torch.nn.MSELoss(reduction='mean'),最终的输出值其实是(y - pred)每个元素数字的平方之和除以(bxc),也就是在batch和特征维度上都取了平均。
如果只想在batch上做平均,可以这样写:
loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)
补充:PyTorch中MSELoss的使用
参数
torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')
size_average和reduce在当前版本的pytorch已经不建议使用了,只设置reduction就行了。
reduction的可选参数有:'none' 、'mean' 、'sum'
reduction='none'
:求所有对应位置的差的平方,返回的仍然是一个和原来形状一样的矩阵。
reduction='mean'
:求所有对应位置差的平方的均值,返回的是一个标量。
reduction='sum'
:求所有对应位置差的平方的和,返回的是一个标量。
更多可查看官方文档
举例
首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:
y_pre = torch.Tensor([[1, 2, 3],
[2, 1, 3],
[3, 1, 2]])
y_label = torch.Tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
如果reduction='none':
criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)
则输出:
tensor([[0., 4., 9.],
[4., 0., 9.],
[9., 1., 1.]])
如果reduction='mean':
criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)
则输出:
tensor(4.1111)
如果reduction='sum':
criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)
则输出:
tensor(37.)
在反向传播时的使用
一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。
这里要注意的是,只有标量才能执行backward()函数,因此在反向传播中reduction不能设为'none'。
但具体设置为'sum'还是'mean'都是可以的。
若设置为'sum',则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。
若设置为'mean',则相比'sum'相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/qq_27061325/article/details/96130824
猜你喜欢
- 0. dockerfile命令FROM # 基础镜像,一切从这里开始构建MAINTAINER # 镜像是谁写的,姓名+邮箱RUN# 镜像构建
- 一、问题的提出随着互连网的发展,网站的数量以惊人的数字增加。网站的作用除了给广大网友们提供信息资讯服务外,还应该成为网友们上传与下载文件的场
- 目标打包Python selenium 自动化脚本(如下run.py文件)为exe执行文件,使之可以直接在未安装python环境的windo
- PIL(Python Image Library)是python的第三方图像处理库,但是由于其强大的功能与众多的使用人数,几乎已
- jqGrid是一个优秀的基于jQuery的DataGrid框架,想必大伙儿也不陌生,网上基于ASP的资料很少,我提供一个,数据格式是json
- TKinter库,Python 的 GUI 库非常多,之所以选择 Tkinter,一是最为简单,二是自带库,不需下载安装,随时使用,跨平台兼
- Function Comma(str)If Not(IsNumeric(str)) Or 
- 使用windows API使用PIL中的ImageGrab模块下面对两者的特点和用法进行详细解释。一、Python调用windows API
- 这里直接给出第一个版本的直接实现:import osimport numpy as npfrom sklearn.cluster impor
- 1.闭包的定义和使用当返回的内部函数使用了外部函数的变量就形成了闭包闭包可以对外部函数的变量进行保存,还可以提高代码的可重用性实现闭包的标准
- 一、为何使用Tkinter而非PyQt众所周知,在Python中创建图形界面程序有很多种的选择,其中PyQt和wxPython都是很热门的模
- binascii模块用法binascii模块用于在二进制和ASCII之间转换>> import binascii# 将binar
- 前言前几话主要讲解关于使用golang进行单元测试,在单元测试的上一层就是接口测试,本节主要讲使用golang进行接口测试,其中主要以htt
- 流程,通俗来讲,就是许多人,在做一系列的事情时,怎样相互协调,安排好这一系列事情的先后顺序,有什么事先的约定,需要达到怎样的预期目标。在UE
- 一、通信方式进程彼此之间互相隔离,要实现进程间通信(IPC),multiprocessing模块主要通过队列方式队列:队列类似于一条管道,元
- 说到JavaScript中声明变量的几种方法也就是var、let、const了,let和const是es6中新增的命令。那么它们之间有什么区
- 前言我们可以给视图函数加装饰器来判断是用户是否登录,把没有登录的用户请求跳转到登录页面等等。我们通过给几个特定视图函数加装饰器实现了这个需求
- 良好的编程习惯是每个程序员都应该具备的工作素质,在我的软件生涯中屡屡发现一些程序员的身上总有这样或者那样的坏毛病。这些毛病在一些从业时间不是
- 在http规则中用404来表示某个页面不能访问,一般来说,网站的404错误页面都是IIS或APACHE默认的页面,千篇一律,非常单调。由于可
- 十个免费的web前端开发工具网络技术发展迅速,部分技术难以保持每年都有新的工具出现,这同时也意味着许多旧的工具倒在了新技术的发展之路上。前端