基于MSELoss()与CrossEntropyLoss()的区别详解
作者:Foneone 发布时间:2022-05-17 19:18:27
标签:MSELoss,CrossEntropyLoss
基于pytorch来讲
MSELoss()多用于回归问题,也可以用于one_hotted编码形式,
CrossEntropyLoss()名字为交叉熵损失函数,不用于one_hotted编码形式
MSELoss()要求batch_x与batch_y的tensor都是FloatTensor类型
CrossEntropyLoss()要求batch_x为Float,batch_y为LongTensor类型
(1)CrossEntropyLoss() 举例说明:
比如二分类问题,最后一层输出的为2个值,比如下面的代码:
class CNN (nn.Module ) :
def __init__ ( self , hidden_size1 , output_size , dropout_p) :
super ( CNN , self ).__init__ ( )
self.hidden_size1 = hidden_size1
self.output_size = output_size
self.dropout_p = dropout_p
self.conv1 = nn.Conv1d ( 1,8,3,padding =1)
self.fc1 = nn.Linear (8*500, self.hidden_size1 )
self.out = nn.Linear (self.hidden_size1,self.output_size )
def forward ( self , encoder_outputs ) :
cnn_out = F.max_pool1d ( F.relu (self.conv1(encoder_outputs)),2)
cnn_out = F.dropout ( cnn_out ,self.dropout_p) #加一个dropout
cnn_out = cnn_out.view (-1,8*500)
output_1 = torch.tanh ( self.fc1 ( cnn_out ) )
output = self.out ( ouput_1)
return output
最后的输出结果为:
上面一个tensor为output结果,下面为target,没有使用one_hotted编码。
训练过程如下:
cnn_optimizer = torch.optim.SGD(cnn.parameters(),learning_rate,momentum=0.9,\
weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
def train ( input_variable , target_variable , cnn , cnn_optimizer , criterion ) :
cnn_output = cnn( input_variable )
print(cnn_output)
print(target_variable)
loss = criterion ( cnn_output , target_variable)
cnn_optimizer.zero_grad ()
loss.backward( )
cnn_optimizer.step( )
#print('loss: ',loss.item())
return loss.item() #返回损失
说明CrossEntropyLoss()是output两位为one_hotted编码形式,但target不是one_hotted编码形式。
(2)MSELoss() 举例说明:
网络结构不变,但是标签是one_hotted编码形式。下面的图仅做说明,网络结构不太对,出来的预测也不太对。
如果target不是one_hotted编码形式会报错,报的错误如下。
目前自己理解的两者的区别,就是这样的,至于多分类问题是不是也是样的有待考察。
来源:https://blog.csdn.net/foneone/article/details/90127707
0
投稿
猜你喜欢
- 获取带有中文参数的url内容对于中文的参数如果不进行编码的话,python的urllib2直接处理会报错,我们可以先将中文转换成utf- 8
- 前言随着Python 3.8的发布,赋值表达式运算符(也称为海象运算符)也发布了。运算符使值的赋值可以传递到表达式中。这通常会使语句数减少一
- 本文介绍了prototype.js常用函数及其使用方法例子说明函数名
- 一、cv2.contourArea起初使用该函数的时候看不懂返回的面积,有0有负数的,于是研究了一下。opencv计算轮廓内面积函数使用的是
- 想要使用多个CPU核心来进行测试,可以使用 -n 参数( 或者 --numprocesses)(使用8个核心来跑测试用例)pytest -n
- 因为比较简单,我就不说什么了。一看就明白的!1.sql防注入函数Function ChkStr(InString) &
- 页面自动刷新代码大全,基本上所有要求自动刷新页面的代码都有,大家可以自由发挥做出完美的页面。 1)10表示间隔10秒刷 ...页面自动刷新代
- 大家可能都不大熟悉关于pdb这个模块,实际上就是python中的内置模块,主要作用于命令行调试代码,下面我们将通过是哪个小结给大家详细介绍下
- PHP mysqli_stmt_init() 函数初始化声明并返回 mysqli_stmt_prepare() 使用的对象:<?php
- 思路有些混乱,希望大家能理解我的意思。看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白
- 平时经常看php的错误日志,很少有机会去自己动手写日志,看了王健的《最佳日志实践》觉得写一个清晰明了,结构分明的日志还是非常有必要的。在写日
- 提高性能有如下方法1、Cython,用于合并python和c语言静态编译泛型2、IPython.parallel,用于在本地或者集群上并行执
- 今天主题是实现并发服务器,实现方法有多种版本,先从简单的单进程代码实现到多进程,多线程的实现,最终引入一些高级模块来实现并发TCP服务器。说
- 前言相信大家在工作无聊时,总想掏出手机,看看微博热搜在讨论什么有趣的话题,但又不方便直接打开微博浏览,今天就和大家分享一个有趣的小爬虫,定时
- 前言没有用过的东西,没有深刻理解的东西很难说自己会,而且被别人一问必然破绽百出。虽然之前有接触过python协程的概念,但是只是走马观花,这
- 导语:举例:Python做一个根据后缀名整理文件的工具,先来看看效果:自动整理前:自动整理后:这样看起来就好很多了。1.准备开始之前,你要确
- 何为标准化:在数据分析之前,我们通常需要先将数据标准化(normalization),利用标准化后的数据进行数据分析。数据标准化也就是统计数
- Q. How can I restrict access to my SQL Server so that it only allows c
- 方法1:import sysprint(sys.argv)得到文件当前绝对路径字符串的一个列表['D:/pycharm/Practi
- 使用OpenCV和Python查找图片差异flyfish方法1 均方误差的算法(Mean Squared Error , MSE)下面的一些