pytorch 实现cross entropy损失函数计算方式
作者:HawardScut 发布时间:2022-03-18 00:45:50
标签:pytorch,nn.MSELoss,损失函数
均方损失函数:
这里 loss, x, y 的维度是一样的,可以是向量或者矩阵,i 是下标。
很多的 loss 函数都有 size_average 和 reduce 两个布尔类型的参数。因为一般损失函数都是直接计算 batch 的数据,因此返回的 loss 结果都是维度为 (batch_size, ) 的向量。
(1)如果 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss
(2)如果 reduce = True,那么 loss 返回的是标量
a)如果 size_average = True,返回 loss.mean();
b)如果 size_average = False,返回 loss.sum();
注意:默认情况下, reduce = True,size_average = True
import torch
import numpy as np
1、返回向量
loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,5]])
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))
这里将Variable类型统一为float()(tensor类型也是调用xxx.float())
loss = loss_fn(input.float(), target.float())
print(loss)
tensor([[ 1., 1.],
[ 1., 1.]])
2、返回平均值
a=np.array([[1,2],[3,4]])
b=np.array([[2,3],[4,4]])
loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))
loss = loss_fn(input.float(), target.float())
print(loss)
tensor(0.7500)
来源:https://blog.csdn.net/hao5335156/article/details/81029791


猜你喜欢
- 1.输入命令 mysqld --skip-grant-tables (前提关闭mysql.exe的进程 net stop mys
- 我就废话不多说了,直接上代码吧!import pandas as pdimport numpy as npimport matplotlib
- 为了实现项目中的搜索功能,我们使用的是全文检索框架haystack+搜索引擎whoosh+中文分词包jieba安装和配置安装所需包pip i
- 问题描述使用 Navicat 导入之前转储好的 sql 文件,报错错误原因在信息日志当中往上翻,发现没有选择数据库,所以报错的原因就是没有提
- 本文实例为大家分享了tkinter+pygame+spider实现音乐播放器,供大家参考,具体内容如下1.确定页面SongSheet&nbs
- MaxPooling1D和GlobalMaxPooling1D区别import tensorflow as tffrom tensorflo
- 本文实例为大家分享了python rsync服务器之间文件夹同步的具体代码,供大家参考,具体内容如下About rsync配置两
- 爬虫要想爬的好,IP代理少不了。。现在网站基本都有些反爬措施,访问速度稍微快点,就会发现IP被封,不然就是提交验证。下面就两种常用的模块来讲
- 本文实例为大家分享了pygame实现贪吃蛇游戏的具体代码,供大家参考,具体内容如下为了简化起见,游戏素材暂定为两张简单的图片(文中用的是30
- 这将为我们的团队节省每天重复的数据处理时间......简介如果你目前在一个数据或商业智能团队工作,你的任务之一可能是制作一些每日、每周或每月
- 1、MySQL下载1.1下载MySQL8.0.26安装与卸载的完整步骤记录MySQL关是一种关系数据库管理系统,所使用的 SQL 语言是用于
- 以下为SQL SERVER7.0以上版本的字段类型说明。SQL SERVER6.5的字段类型说明请参考SQL SERVER提供的说明。bit
- Go的三种安装方式Go有多种安装方式,你可以选择自己喜欢的。这里我们介绍三种最常见的安装方式:1.Go源码安装:这是一种标准的软件安装方式。
- 新建项目如下图,比如sigma目录是我要上传的项目,在six-sigma目录下新建三个文件,分别是LICENSE也就是开源协议,README
- 效果图:css:<style type="text/css"> /* 带复选框的下拉框 */ ul li{
- 一、项目介绍爬取网址:CSDN首页的Python、Java、前端、架构以及数据库栏目。简单分析其各自的URL不难发现,都是https://w
- 晚上突然间看到大猫的头像在闪动,速度打开一看,发现他问,以前我写button标签的时候有没有写type属性,老实的我只有诚实地告诉他,我没写
- --1. 创建表,添加测试数据 CREATE TABLE tb(id int, [value] varchar(10)) INSERT tb
- 本文实例讲述了Python scipy的二维图像卷积运算与图像模糊处理操作。分享给大家供大家参考,具体如下:二维图像卷积运算一 代码impo
- 实例如下所示:#!/usr/bin/python# -*- coding: UTF-8 -*-import smtplibimport em