Pytorch中accuracy和loss的计算知识点总结
作者:嶙羽 发布时间:2023-06-25 10:57:32
标签:Pytorch,accuracy,loss
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。
给出实例
def train(train_loader, model, criteon, optimizer, epoch):
train_loss = 0
train_acc = 0
num_correct= 0
for step, (x,y) in enumerate(train_loader):
# x: [b, 3, 224, 224], y: [b]
x, y = x.to(device), y.to(device)
model.train()
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += float(loss.item())
train_losses.append(train_loss)
pred = logits.argmax(dim=1)
num_correct += torch.eq(pred, y).sum().float().item()
logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
return num_correct/len(train_loader.dataset), train_loss/len(train_loader)
首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。
那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;
而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的
label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是
minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。
来源:https://www.cnblogs.com/yqpy/p/11497199.html


猜你喜欢
- USE Demo GO /* 将表Code的列String中的值提取放到Record表中 String 中字符类型为 dsddddd,222
- mysqlslap常用参数说明–auto-generate-sql 由系统自动生成sql脚本进行测试–auto-generate-sql-a
- 获取Tensor的维数>>> import tensorflow as tf>>> tf.__versi
- DataFrame对象的创建,修改,合并import pandas as pdimport numpy as np创建DataFrame对象
- 一、存储引擎上节我们最后说到,SQL 的执行计划是执行器组件调用存储引擎的接口来完成的。那我们可以理解为:MySQL 这个数据库管理系统是依
- 在之前写过一篇使用python爬虫爬取电影天堂资源的文章,重点是如何解析页面和提高爬虫的效率。由于电影天堂上的资源获取权限是所有人都一样的,
- 1.元组的概念Python中的元组和列表很相似,元组也是Python语言提供的内置数据结构之一,可以在代码中直接使用。元组和列表就像是一个孪
- 首先要了解为什么用连接池,连接池能为你解决什么问题连接池主要的作用:1、减少与数据服务器建立TCP连接三次握手及连接关闭四次挥手的开销,从而
- 1、查询时间区间日期列表,不会由于数据表数据影响select a.date from ( select curda
- 考虑到女友的安全问题,就做了一个app实现定位和服务器实现转发的东西。刚学python,竟没想到用对象编程会更加方便,全程过程式开发,代码有
- 一、用HTTP头信息 也就是用PHP的HEADER函数。PHP里的HEADER函数的作用就是向浏览器发出由HTTP协议规定的本来应该通过WE
- 2021年7月1日,官方正式发布了1.0Datatable版本。1.0版本支持windows和linux,以及Macos。 具体文档可以见:
- 1、直接贴代码#!C:/Python27#coding=utf-8from selenium import webdriverfrom se
- 方法一:单表导入(1)打开"SQL Server 外围应用配置器"-->"功能的外围应用配置器"
- 说明1、如果数据集是高维度的,选择谱聚类是子空间的一种。2、如果数据量是中小型的,比如在100W条以内,K均值会是更好的选择;如果数据量超过
- 本文实例为大家分享了python爬取淘宝商品的具体代码,供大家参考,具体内容如下import requests as req import
- 1、选取最适用的字段属性MySQL 可以很好的支持大数据量的存取,但是一般说来,数据库中的表越小,在它上面执行的查询也就会越快。因此,在创建
- 需求是需要用python往 SqlServer中的image类型字段中插入二进制图片核心代码,研究好几个小时的代码:安装pywin32,ad
- 目录1. 什么是竞态2. 如何消除竞态3. Go 提供的并发工具3.1 互斥锁3.2 读写互斥锁3.3 Once3.4 竞态检测器4. 小结
- 本文实例讲述了bootstrap-table后端分页功能。分享给大家供大家参考,具体如下:使用bootstrap-table实现后台分页插件