pytorch分类模型绘制混淆矩阵以及可视化详解
作者:王延凯的博客 发布时间:2023-01-17 17:35:43
标签:pytorch,混淆矩阵,可视化
Step 1. 获取混淆矩阵
#首先定义一个 分类数*分类数 的空混淆矩阵
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
# 使用torch.no_grad()可以显著降低测试用例的GPU占用
with torch.no_grad():
for step, (imgs, targets) in enumerate(test_loader):
# imgs: torch.Size([50, 3, 200, 200]) torch.FloatTensor
# targets: torch.Size([50, 1]), torch.LongTensor 多了一维,所以我们要把其去掉
targets = targets.squeeze() # [50,1] -----> [50]
# 将变量转为gpu
targets = targets.cuda()
imgs = imgs.cuda()
# print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
out = model(imgs)
#记录混淆矩阵参数
conf_matrix = confusion_matrix(out, targets, conf_matrix)
conf_matrix=conf_matrix.cpu()
混淆矩阵的求取用到了confusion_matrix函数,其定义如下:
def confusion_matrix(preds, labels, conf_matrix):
preds = torch.argmax(preds, 1)
for p, t in zip(preds, labels):
conf_matrix[p, t] += 1
return conf_matrix
在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:
conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数
print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
print(conf_matrix)
# 获取每种Emotion的识别准确率
print("每种情感总个数:",per_kinds)
print("每种情感预测正确的个数:",corrects)
print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))
执行此步的输出结果如下所示:
Step 2. 混淆矩阵可视化
对上边求得的混淆矩阵可视化
# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签
# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)
# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
for y in range(Emotion_kinds):
# 注意这里的matrix[y, x]不是matrix[x, y]
info = int(conf_matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()
来源:https://blog.csdn.net/weixin_38468077/article/details/121671139


猜你喜欢
- 有一道题: 比较两个列表范围,如果包含的话,返回TRUE,否则FALSE。 详细题目如下:Create a function, this f
- models.py:from django.db import models # 出版社class Publisher(models.Mod
- 前言嗨,彦祖们,不会过圣诞了还是一个人吧?今天我们来讲一下如何用python来画一个圣诞树,学会就快给那个她发过去吧,我的朋友圈已经让圣诞树
- 想要asp能连接mysql数据库需要安装MySQL ODBC 3.51 驱动 http://www.jb51.net/softs/19910
- 今天我们整理了ip地址和身份证的javascript验证方法。虽然ip地址和身份证的验证不是很经常会遇到,但是大家也可以研究一下js代码,里
- 消息/事件机制是几乎所有开发语言都有的机制,并不是deviceone的独创,在某些语言称之为消息(Event),有些地方称之为(Messag
- 换了N种字符串连接的方法,终于连接上去了。 共享下用的 Provider=SQLOLEDB.1; User ID=sa; Password=
- 在本教程中,我将指导您如何编写代码,以使用具有基于表单的身份验证的Spring安全API来保护Spring Boot应用程序中的网页。用户详
- 效果图:实现代码如下view<canvas id="radar-canvas" class="radar
- 1、开启Mysql慢查询1.1、查看慢查询相关配置show variables like 'slow_query_log%'
- 简单实现了一个在函数执行出现异常时自动重试的装饰器,支持控制最多重试次数,每次重试间隔,每次重试间隔时间递增。最新的代码可以访问从githu
- 一、SQL Server 和SSMS的安装1. SQL的安装下载地址:SQL Server。进入下载地址选择Developer或者Expre
- 对于题目中提出的问题,可以拆分来一步步解决。在 MySQL 中 KEY 和 INDEX 是同义。那这个问题就可以简化为 PRIMARY KE
- 如下所示:distances = np.sqrt(np.sum(np.asarray(airportPosition - x_vals)**
- 实例如下所示:# -*- coding:utf-8 -*- #os模块中包含很多操作文件和目录的函数 import os #获取目标文件夹的
- 哎,以前写博文的时候没注意,有些图片用QQ来截取,获得的图片文件名都是类似于QQ截图20120926174732-300×15.png的形式
- 在使用numpy数组的过程中时常会出现nan或者inf的元素,可能会造成数值计算时的一些错误。这里提供一个numpy库函数的用法,使nan和
- 最近在工作中遇到一个问题,就是有一个功能希望在各种服务器上实现,而服务器上的系统版本可能都不一样,有的是 CentOS 6.x,有的是 Ce
- 如何下载:我先去MySQL首页下载最新版本的MySQL-链接:https://www.mysql.com/downloads/进入此界面下载
- 需求:前端获取到摄像头信息,通过模型来进行判断人像是否在镜头中,镜头是否有被遮挡。实现步骤:1、通过video标签来展示摄像头中的内容2、通