Pytorch中accuracy和loss的计算知识点总结
作者:嶙羽 发布时间:2023-06-25 10:57:32
标签:Pytorch,accuracy,loss
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。
给出实例
def train(train_loader, model, criteon, optimizer, epoch):
train_loss = 0
train_acc = 0
num_correct= 0
for step, (x,y) in enumerate(train_loader):
# x: [b, 3, 224, 224], y: [b]
x, y = x.to(device), y.to(device)
model.train()
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += float(loss.item())
train_losses.append(train_loss)
pred = logits.argmax(dim=1)
num_correct += torch.eq(pred, y).sum().float().item()
logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
return num_correct/len(train_loader.dataset), train_loss/len(train_loader)
首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。
那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;
而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的
label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是
minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。
来源:https://www.cnblogs.com/yqpy/p/11497199.html
0
投稿
猜你喜欢
- 邹建 2004.4 代码如下:/*--调用示例 exec p_lockinfo1 --*/ alter proc p_lockinfo1
- 这个使用起来很简单,以前需要的时候在网上找的,用了感觉还不错,具体的看演示就明白了。,这个可以保留你文章中的html标记,需要你修改的就是下
- ORA-00600:internal error code,arguments:[num],[?],[?],[?],[?] 产生原因:这种错
- 1、首先停止正在运行的MySQL进程 Linux下,运行 killall -TERM mysqld Windows下,如果写成服务的 可以运
- 这是写给web设计者和前端开发者的教程,我们将演示如何使用Photoshop创建按钮的sprite图,然后是如何使用jQurey打造动态渐变
- 下面演示了,当asp程序发生错误时,屏蔽系统默认的错误显示,而显示自定义的错误信息。<%@ LANGUAGE="V
- 前言相信大家都玩过斗地主,规则就不再介绍了。直接上一张朋友圈看到的残局图:这道题我刚看到时,曾尝试用手工来破解,每次都以为找到了农民的必胜策
- 前言在讲解如何解决migrate报错原因前,我们先要了解migrate做了什么事情,migrate:将新生成的迁移脚本。映射到数据库中。创建
- 闭包内容:匿名函数:能够完成简单的功能,传递这个函数的引用,只有功能普通函数:能够完成复杂的功能,传递这个函数的引用,只有功能闭包:能够完成
- 在CSS样式中,dl部分只是简单的把内外边距设置为0,dd部分有一个clear属性需要特别注意。当某个元素的属性设置float浮动时,它所在
- 前言最近因为工作的需要,在写一些python脚本,总是使用print来打印信息感觉很low,所以抽空研究了一下python的logging库
- 1、代码from aip import AipFaceimport cv2import timeimport base64from PIL
- 这是借鉴了一位兄弟的代码,然后进行修改的,原来代码存在问题,用了2小时,自己修改,终于画出了滑稽脸,也算是对于今天学的turtle绘画库的一
- 本文实例讲述了asp.net实现图片以二进制流输出的两种方法。分享给大家供大家参考,具体如下:方法一:System.IO.MemoryStr
- 最近在老家找工作,无奈老家工作真心太少,也没什么面试机会,不过之前面试一家公司,提了一个有意思的需求,检测河面没有有什么船只之类的物体,我当
- 上一篇:微软建议的ASP性能优化28条守则(8)技巧 28:阅读资源链接下面是一些与性能有关的出色的资源链接。如果您想了解有关信息,请阅读
- 文本的排版依据语言的不同会有一些格式上的要求,比如简体中文中类似逗号、分号等标点符号不会出现在一行的开头,对于英文来讲就是一个完整单词不会在
- 序列化是将对象状态转换为可保持或传输的格式的过程。与序列化相对的是反序列化,它将流转换为对象。这两个过程结合起来,可以轻松地存储和传输数据方
- 最近在写的一个django小项目需要实现用户上传图片的功能,使用到了七牛云存储,特此记录下来。这里我使用的七牛python SDK 版本是7
- Oracle的逻辑运算符也是用在SQL语句中必不可少的因素,一共有三个逻辑运算符意义and双值运算符,如果左右两个条件都为真,则得到的值就为