pytorch分类模型绘制混淆矩阵以及可视化详解
作者:王延凯的博客 发布时间:2023-01-17 17:35:43
标签:pytorch,混淆矩阵,可视化
Step 1. 获取混淆矩阵
#首先定义一个 分类数*分类数 的空混淆矩阵
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
# 使用torch.no_grad()可以显著降低测试用例的GPU占用
with torch.no_grad():
for step, (imgs, targets) in enumerate(test_loader):
# imgs: torch.Size([50, 3, 200, 200]) torch.FloatTensor
# targets: torch.Size([50, 1]), torch.LongTensor 多了一维,所以我们要把其去掉
targets = targets.squeeze() # [50,1] -----> [50]
# 将变量转为gpu
targets = targets.cuda()
imgs = imgs.cuda()
# print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
out = model(imgs)
#记录混淆矩阵参数
conf_matrix = confusion_matrix(out, targets, conf_matrix)
conf_matrix=conf_matrix.cpu()
混淆矩阵的求取用到了confusion_matrix函数,其定义如下:
def confusion_matrix(preds, labels, conf_matrix):
preds = torch.argmax(preds, 1)
for p, t in zip(preds, labels):
conf_matrix[p, t] += 1
return conf_matrix
在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:
conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数
print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
print(conf_matrix)
# 获取每种Emotion的识别准确率
print("每种情感总个数:",per_kinds)
print("每种情感预测正确的个数:",corrects)
print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))
执行此步的输出结果如下所示:
Step 2. 混淆矩阵可视化
对上边求得的混淆矩阵可视化
# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签
# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)
# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
for y in range(Emotion_kinds):
# 注意这里的matrix[y, x]不是matrix[x, y]
info = int(conf_matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()
来源:https://blog.csdn.net/weixin_38468077/article/details/121671139
0
投稿
猜你喜欢
- 本文实例讲述了Python实现对一个函数应用多个装饰器的方法。分享给大家供大家参考,具体如下:下面的例子展示了对一个函数应用多个装饰器,可以
- PYTHON 操作 XML读取XML文件关于XML的介绍<data> 与 </data> 是一对标签的开始与结束&l
- Python-apply(lambda x: )使用def instant_order_deal(plat, special_product
- 为你的网站,博客等添加rss聚合功能,给出rss.asp和rss.xml两种的聚合代码看过的朋友可帮忙顶哦,这些代码都是第一次发的,外面很多
- 字典的本质就是 hash 表,hash 表就是通过 key 找到其 value ,平均情况下你只需要花费 O(1) 的时间复杂度即可以完成对
- 一、异常处理在程序开发中如果遇到一些 不可预知的错误 或 你懒得做一些判断 时,可以选择用异常处理来做。import requestswhi
- 零、SQLAlchemy是什么?SQLAlchemy的官网上写着它的介绍文字:SQLAlchemy is the Python SQL to
- 文章介绍OpenCV 库中包含很多运算函数,这里着重介绍按位运算的基本原理并举例说明。本篇文章中主要涉及到的函数有:按位与:bitwise_
- 简单低级的爬虫速度快,伪装度低,如果没有反爬机制,它们可以很快的抓取大量数据,甚至因为请求过多,造成服务器不能正常工作。而伪装度高的爬虫爬取
- 描述super() 函数用于调用下一个父类(超类)并返回该父类实例的方法。super 是用来解决多重继承问题的,直接用类名调用父类方法在使用
- 一、函数解释在torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:换句话说,就是需要传入5个参数,
- 题目文件scores.csv包含十位学生的成绩单,表头是"姓名 语文 数学 英语"。请编程完成下述功能。1)计算每位学生
- 视图(View)“视图”主要指我们送到Web浏览器的最终结果??比如我们的脚本生成的HTML。当说到视图时,很多人想到的是模版,但是把模板方
- 本文实例讲解了python实现两个程序之间通信的方法,具体方法如下:该实例采用socket实现,与socket网络编程不一样的是socket
- 在写论文时,如果是菜鸟级别,可能不会花太多时间去学latex,直接用word去写,但是这有一个问题,当我们用其他工具画完实验彩 * 时,放到w
- 简单邮件传输协议(SMTP)是一种协议,用于处理在电子邮件服务器之间发送电子邮件和路由电子邮件。Python提供了smtplib模块,该模块
- 我就废话不多说啦,还是直接看代码吧!list1 = [1,2,3,4]a,b,c,d = list1则a = 1b =2等这种方式只有当左边
- 公司客户在使用网站后台编辑添加修改内容时,经常是直接从word文档里复制内容到编辑器里后就提交。结果是在内容显示页面上是五花八门的样式,有时
- 数据驱动模式的测试好处相比普通模式的测试就显而易见了吧!使用数据驱动的模式,可以根据业务分解测试数据,只需定义变量,使用外部或者自定义的数据
- 起因在公司搭建了套webpack多页面应用脚手架,开始用着很爽,解决了既想使用Vue的模块化开发,又想做多页打包上线管理的初衷,但是随着业务