网络编程
位置:首页>> 网络编程>> Python编程>> pytorch分类模型绘制混淆矩阵以及可视化详解

pytorch分类模型绘制混淆矩阵以及可视化详解

作者:王延凯的博客  发布时间:2023-01-17 17:35:43 

标签:pytorch,混淆矩阵,可视化

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
# 使用torch.no_grad()可以显著降低测试用例的GPU占用
   with torch.no_grad():
       for step, (imgs, targets) in enumerate(test_loader):
           # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
           # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
           targets = targets.squeeze()  # [50,1] ----->  [50]

# 将变量转为gpu
           targets = targets.cuda()
           imgs = imgs.cuda()
           # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())

out = model(imgs)
           #记录混淆矩阵参数
           conf_matrix = confusion_matrix(out, targets, conf_matrix)
           conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
   preds = torch.argmax(preds, 1)
   for p, t in zip(preds, labels):
       conf_matrix[p, t] += 1
   return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
print(conf_matrix)

# 获取每种Emotion的识别准确率
print("每种情感总个数:",per_kinds)
print("每种情感预测正确的个数:",corrects)
print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
   for y in range(Emotion_kinds):
       # 注意这里的matrix[y, x]不是matrix[x, y]
       info = int(conf_matrix[y, x])
       plt.text(x, y, info,
                verticalalignment='center',
                horizontalalignment='center',
                color="white" if info > thresh else "black")

plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

来源:https://blog.csdn.net/weixin_38468077/article/details/121671139

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com