Python 利用 PyTorch 实现绘制 ROC 与 PR 曲线图
作者:Vertira 发布时间:2022-09-20 03:24:20
标签:Python,Pytorch,ROC,PR,曲线图
PyTorch 多分类模型绘制 ROC、PR 曲线(代码亲测可用)
ROC曲线
示例代码
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.datasets import ImageFolder
from utils.transform import get_transform_for_test
from senet.se_resnet import FineTuneSEResnet50
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
# --- Dataset / checkpoint configuration -------------------------------------
num_class = 113  # number of classes
data_root = r'D:\TJU\GBDB\set113\set113_images\test1'  # test-set root directory
test_weights_path = r"C:\Users\admin\Desktop\fsdownload\epoch_0278_top1_70.565_'checkpoint.pth.tar'"  # trained model checkpoint
# Normalisation statistics for the test transform:
# mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]

# --- Device configuration ---------------------------------------------------
# Expose only physical GPU 0 to CUDA; it is then addressed as "cuda:0".
os.environ.update(CUDA_VISIBLE_DEVICES="0")
gpu = "cuda:0"
def test(model, test_path):
    """Evaluate ``model`` on the test set and plot multi-class ROC curves.

    Loads the checkpoint at ``test_path`` into ``model``, scores every image
    found under ``<data_root>/test_images`` and draws one ROC curve per class
    plus the micro- and macro-averaged curves. The figure is written to
    ``set113_roc.jpg`` and then shown.

    Args:
        model: network to evaluate. The checkpoint's ``'state_dict'`` entry is
            loaded into it, and inference runs on the device the model's
            parameters already live on.
        test_path: path of the checkpoint file passed to ``torch.load``.
    """
    # Build the test dataset and loader.
    test_dir = os.path.join(data_root, 'test_images')
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005],
                                            var=[0.14589554, 0.17054074, 0.18254866])
    test_dataset = ImageFolder(test_dir, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)

    # Follow the model's device instead of hard-coding .cuda(), so the same
    # code also runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # map_location keeps the load working when the checkpoint was saved from
    # a different device than the one used here.
    checkpoint = torch.load(test_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    score_list = []  # predicted scores, one (nclass,) row per sample
    label_list = []  # ground-truth class indices
    # Inference only: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))  # (batchsize, nclass)
            # NOTE: the raw network outputs are used as scores; no softmax is
            # applied.
            score_list.extend(outputs.cpu().numpy())
            label_list.extend(labels.numpy())
    score_array = np.array(score_list)

    # Convert the integer labels to one-hot form.
    label_array = np.asarray(label_list, dtype=np.int64)
    label_onehot = np.zeros((label_array.shape[0], num_class), dtype=np.float32)
    label_onehot[np.arange(label_array.shape[0]), label_array] = 1

    print("score_array:", score_array.shape)  # (num_samples, num_class)
    print("label_onehot:", label_onehot.shape)  # (num_samples, num_class)

    # Per-class FPR / TPR / AUC via scikit-learn (one-vs-rest).
    fpr_dict = {}
    tpr_dict = {}
    roc_auc_dict = {}
    for i in range(num_class):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
        roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])

    # Micro average: pool every (sample, class) decision into one binary problem.
    fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel())
    roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])

    # Macro average: first aggregate all false positive rates ...
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)]))
    # ... then interpolate every per-class ROC curve at those points.
    # np.interp is used because `scipy.interp` is no longer available in
    # current SciPy releases; the file's `from scipy import interp` line
    # should be removed as well.
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(num_class):
        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
    # Finally average and compute the AUC.
    mean_tpr /= num_class
    fpr_dict["macro"] = all_fpr
    tpr_dict["macro"] = mean_tpr
    roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])

    # Plot the averaged curves followed by one curve per class.
    plt.figure()
    lw = 2
    plt.plot(fpr_dict["micro"], tpr_dict["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr_dict["macro"], tpr_dict["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(num_class), colors):
        plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig('set113_roc.jpg')
    plt.show()
if __name__ == '__main__':
    # Build the FineTuneSEResnet50 model, move it to the configured device,
    # then run the ROC evaluation against the saved checkpoint.
    device = torch.device(gpu)
    seresnet = FineTuneSEResnet50(num_class=num_class).to(device)
    test(seresnet, test_weights_path)
运行结果:
PR曲线
示例代码
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.datasets import ImageFolder
from utils.transform import get_transform_for_test
from senet.se_resnet import FineTuneSEResnet50
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
# --- Dataset / checkpoint configuration -------------------------------------
num_class = 113  # number of classes
data_root = r'D:\TJU\GBDB\set113\set113_images\test1'  # test-set root directory
test_weights_path = r"C:\Users\admin\Desktop\fsdownload\epoch_0278_top1_70.565_'checkpoint.pth.tar'"  # trained model checkpoint
# Normalisation statistics for the test transform:
# mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]

# --- Device configuration ---------------------------------------------------
# Expose only physical GPU 0 to CUDA; it is then addressed as "cuda:0".
os.environ.update(CUDA_VISIBLE_DEVICES="0")
gpu = "cuda:0"
def test(model, test_path):
    """Evaluate ``model`` on the test set and plot the micro-averaged PR curve.

    Loads the checkpoint at ``test_path`` into ``model``, scores every image
    found under ``<data_root>/test_images``, prints per-class precision/recall
    array shapes and average precision, and saves the micro-averaged
    precision-recall curve to ``set113_pr_curve.jpg``.

    Args:
        model: network to evaluate. The checkpoint's ``'state_dict'`` entry is
            loaded into it, and inference runs on the device the model's
            parameters already live on.
        test_path: path of the checkpoint file passed to ``torch.load``.
    """
    # Build the test dataset and loader.
    test_dir = os.path.join(data_root, 'test_images')
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005],
                                            var=[0.14589554, 0.17054074, 0.18254866])
    test_dataset = ImageFolder(test_dir, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)

    # Follow the model's device instead of hard-coding .cuda(), so the same
    # code also runs on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # map_location keeps the load working when the checkpoint was saved from
    # a different device than the one used here.
    checkpoint = torch.load(test_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    score_list = []  # predicted scores, one (nclass,) row per sample
    label_list = []  # ground-truth class indices
    # Inference only: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))  # (batchsize, nclass)
            # NOTE: the raw network outputs are used as scores; no softmax is
            # applied.
            score_list.extend(outputs.cpu().numpy())
            label_list.extend(labels.numpy())
    score_array = np.array(score_list)

    # Convert the integer labels to one-hot form.
    label_array = np.asarray(label_list, dtype=np.int64)
    label_onehot = np.zeros((label_array.shape[0], num_class), dtype=np.float32)
    label_onehot[np.arange(label_array.shape[0]), label_array] = 1

    print("score_array:", score_array.shape)  # (num_samples, num_class) raw scores
    print("label_onehot:", label_onehot.shape)  # (num_samples, num_class) one-hot

    # Per-class precision / recall / average precision via scikit-learn.
    precision_dict = {}
    recall_dict = {}
    average_precision_dict = {}
    for i in range(num_class):
        precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score_array[:, i])
        average_precision_dict[i] = average_precision_score(label_onehot[:, i], score_array[:, i])
        print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i])

    # Micro average: pool every (sample, class) decision into one binary problem.
    precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(),
                                                                              score_array.ravel())
    average_precision_dict["micro"] = average_precision_score(label_onehot, score_array, average="micro")
    print('Average precision score, micro-averaged over all classes: {0:0.2f}'.format(average_precision_dict["micro"]))

    # Plot the micro-averaged PR curve and save it (the figure is not shown).
    plt.figure()
    plt.step(recall_dict['micro'], precision_dict['micro'], where='post')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(
        'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
        .format(average_precision_dict["micro"]))
    plt.savefig("set113_pr_curve.jpg")
if __name__ == '__main__':
    # Build the FineTuneSEResnet50 model, move it to the configured device,
    # then run the PR-curve evaluation against the saved checkpoint.
    device = torch.device(gpu)
    seresnet = FineTuneSEResnet50(num_class=num_class).to(device)
    test(seresnet, test_weights_path)
运行结果:
来源:https://blog.csdn.net/Vertira/article/details/128482515


猜你喜欢
- 在缺失值填补上如果用前后的均值填补中间的均值,比如,0,空,1,我们希望中间填充0.5;或者0,空,空,1,我们希望中间填充0.33,0.6
- 前言我们在django-rest-framework 自定义swagger 文章中编写了接口, 调通了接口文档. 接口文档可以直接填写参数进
- 举例: 如:在字段名处输入:username,password,email,telphone 注意:不同的字段名用英文逗号隔开,且不支持星号
- 写在前面的话关于《交互设计实用指南》,我们最近收到很多朋友的反馈,有支持的也有批评的,在此一并感谢了,有你们的关注,我们才能走得更远。《交互
- 一、环境准备1.CentOS配置最好是用新克隆的虚拟机 ,虚拟机内存设置大一点(我设置的4G),配置网络,主机名,关闭防火墙,关闭selin
- var sss=(String.fromCharCode(127)); var xmlhttp =
- 模式特点:给定一个语言,定义它的文法的一种表示,并定义一个解释器,这个解释器使用该表示来解释语言中的句子。我们来看一下下面这样的程序结构:c
- Python编程时,经常需要跳过第一行读取文件内容。简单的做法是为每行设置一个line_num,然后判断line_num是否为1,如果不等于
- 最近在工作中涉及到判断服务器所在ip反馈程序使用情况的程序主要要求就是,本机或局域网调试程序时,不反馈其域名(localhost)或ip站长
- Pytorch调用forward()函数Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想
- Java timezone设置和mybatis连接数据库时区设置JVM时区设置springboot工程运行时,需要指定时区,这样获取的时间才
- 本文实例为大家分享了vue简单的图书管理具体代码,供大家参考,具体内容如下<table class="table table
- 工作过程中遇到一个Js从Cookies里面取值的需求,Js貌似没有现成的方法可以指定Key值获取Cookie里面对应的值,参阅网上的代码,简
- 最近学习了SSD算法,了解了其基本的实现思路,并通过SSD模型训练自己的模型。基本环境torch1.2.0Pillow8.2.0torchv
- 前言本节我们来讲讲并发中最常见的情况存在即更新,在并发中若未存在行记录则插入,此时未处理好极容易出现插入重复键情况,本文我们来介绍对并发中存
- 1、同级目录下调用若在程序 testone.py 中导入模块 testtwo.py , 则直接使用【import testtwo 或 fro
- 无论使用int还是varchar,对于Status的多选查询都是不易应对的。举例,常规思维下对CustomerStatus的Enum设置如下
- 一、home页使用frametemplate/home.html<!DOCTYPE html><html lang=&qu
- 爬取代理IP及测试是否可用很多人在爬虫时为了防止被封IP,所以就会去各大网站上查找免费的代理IP,由于不是每个IP地址都是有效的,如果要进去
- 一般的网站会有很多页面,面包屑导航可以大大改善用户寻找他们的路径的方法。就可用性而言,面包屑可以减少一个网站的用户返回上一级页面的操作次数,