在pytorch中计算准确率,召回率和F1值的操作
作者:coding_zhang 发布时间:2022-02-13 18:06:40
标签:pytorch,准确率,召回率,F1
看代码吧~
predict = output.argmax(dim = 1)
confusion_matrix =torch.zeros(2,2)
for t, p in zip(predict.view(-1), target.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
a_p =(confusion_matrix.diag() / confusion_matrix.sum(1))[0]
b_p = (confusion_matrix.diag() / confusion_matrix.sum(1))[1]
a_r =(confusion_matrix.diag() / confusion_matrix.sum(0))[0]
b_r = (confusion_matrix.diag() / confusion_matrix.sum(0))[1]
补充:pytorch 查全率 recall 查准率 precision F1调和平均 准确率 accuracy
看代码吧~
def eval():
net.eval()
test_loss = 0
correct = 0
total = 0
classnum = 9
target_num = torch.zeros((1,classnum))
predict_num = torch.zeros((1,classnum))
acc_num = torch.zeros((1,classnum))
for batch_idx, (inputs, targets) in enumerate(testloader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
inputs, targets = Variable(inputs, volatile=True), Variable(targets)
outputs = net(inputs)
loss = criterion(outputs, targets)
# loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
test_loss += loss.data[0]
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()
pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)
predict_num += pre_mask.sum(0)
tar_mask = torch.zeros(outputs.size()).scatter_(1, targets.data.cpu().view(-1, 1), 1.)
target_num += tar_mask.sum(0)
acc_mask = pre_mask*tar_mask
acc_num += acc_mask.sum(0)
recall = acc_num/target_num
precision = acc_num/predict_num
F1 = 2*recall*precision/(recall+precision)
accuracy = acc_num.sum(1)/target_num.sum(1)
#精度调整
recall = (recall.numpy()[0]*100).round(3)
precision = (precision.numpy()[0]*100).round(3)
F1 = (F1.numpy()[0]*100).round(3)
accuracy = (accuracy.numpy()[0]*100).round(3)
# 打印格式方便复制
print('recall'," ".join('%s' % id for id in recall))
print('precision'," ".join('%s' % id for id in precision))
print('F1'," ".join('%s' % id for id in F1))
print('accuracy',accuracy)
补充:Python scikit-learn,分类模型的评估,精确率和召回率,classification_report
分类模型的评估标准一般最常见使用的是准确率(estimator.score()),即预测结果正确的百分比。
混淆矩阵:
准确率是相对所有分类结果;精确率、召回率、F1-score是相对于某一个分类的预测评估标准。
精确率(Precision):预测结果为正例样本中真实为正例的比例(查的准)()
召回率(Recall):真实为正例的样本中预测结果为正例的比例(查的全)()
分类的其他评估标准:F1-score,反映了模型的稳健型
demo.py(分类评估,精确率、召回率、F1-score,classification_report):
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report
# 加载数据集 从scikit-learn官网下载新闻数据集(共20个类别)
news = fetch_20newsgroups(subset='all') # all表示下载训练集和测试集
# 进行数据分割 (划分训练集和测试集)
x_train, x_test, y_train, y_test = train_test_split(news.data, news.target, test_size=0.25)
# 对数据集进行特征抽取 (进行特征提取,将新闻文档转化成特征词重要性的数字矩阵)
tf = TfidfVectorizer() # tf-idf表示特征词的重要性
# 以训练集数据统计特征词的重要性 (从训练集数据中提取特征词)
x_train = tf.fit_transform(x_train)
print(tf.get_feature_names()) # ["condensed", "condescend", ...]
x_test = tf.transform(x_test) # 不需要重新fit()数据,直接按照训练集提取的特征词进行重要性统计。
# 进行朴素贝叶斯算法的预测
mlt = MultinomialNB(alpha=1.0) # alpha表示拉普拉斯平滑系数,默认1
print(x_train.toarray()) # toarray() 将稀疏矩阵以稠密矩阵的形式显示。
'''
[[ 0. 0. 0. ..., 0.04234873 0. 0. ]
[ 0. 0. 0. ..., 0. 0. 0. ]
...,
[ 0. 0.03934786 0. ..., 0. 0. 0. ]
'''
mlt.fit(x_train, y_train) # 填充训练集数据
# 预测类别
y_predict = mlt.predict(x_test)
print("预测的文章类别为:", y_predict) # [4 18 8 ..., 15 15 4]
# 准确率
print("准确率为:", mlt.score(x_test, y_test)) # 0.853565365025
print("每个类别的精确率和召回率:", classification_report(y_test, y_predict, target_names=news.target_names))
'''
precision recall f1-score support
alt.atheism 0.86 0.66 0.75 207
comp.graphics 0.85 0.75 0.80 238
sport.baseball 0.96 0.94 0.95 253
...,
'''
召回率的意义(应用场景):产品的不合格率(不想漏掉任何一个不合格的产品,查全);癌症预测(不想漏掉任何一个癌症患者)
来源:https://blog.csdn.net/coding_zhang/article/details/89537810


猜你喜欢
- 目标在本章中,将了解:如何生成OpenCV-Python bindings如何将新的OpenCV模块扩展到PythonOpenCV-Pyth
- 传统方式需要10s,dat方式需要0.6simport osimport timeimport torchimport randomfrom
- python一直被病垢运行速度太慢,但是实际上python的执行效率并不慢,慢的是python用的解释器Cpython运行效率太差。“一行代
- 新建一个Spring Initializr项目2.把pom.xml文件中的oracle依赖换成自己的oracle版本依赖:原来的:现在的:
- 一 MySQL WorkbenchMySQL Workbench提供DBAs和developers一个集成工具环境:1)数据库设计和建模2)
- 目录项目地址:1) 启动方法2) web查看方法3) 功能说明:4) 展示:代码项目地址:https://github.com/guodon
- 1. document.form.item 问题 (1)现有问题:现有代码中存在许多 document.formName.item(&quo
- 1 类继承Python 是面向对象的编程语言,因此支持面向对象的三大特性之一:继承。继承是代码重用的一种途径,Python 中的继承就像现实
- hanxiaolian 为了躲避 lake2 ASP站长管理助手而写.. 一.绕过lake2 Asp木马扫描的小马 代码如下:<%&n
- 一、开发接口的作用1、mock接口:模拟一些接口。有一些有关联的接口,在别的接口没有开发好的时候,需要用这个接口,就可以写一个假接口,返回想
- Fucklt.py 使用了最先进的技术能够使你的代码不管里面有什么样的错误,你只管 FuckIt,程序就能"正常"执行,
- 将.ppm格式的图片转换成.jpg格式的图像,除了通过软件转换,还可以使用python脚本直接转换,so easy!!!from PIL i
- 所需库的安装很多人问Pytorch要怎么可视化,于是决定搞一篇。tensorboardX==2.0tensorflow==1.13.2由于t
- 一 在写之前 最好指定python的路径:#!/usr/bin/pythonpython 在linux中需要添加编码方式:以免出现中文乱码#
- 1、PandasPython Data Analysis Library 或 pandas 是基于NumPy 的一种工具,相当于这是Pyth
- 俺比较笨,对太专业的书一直不感冒,看了就想睡觉。最近李明同学传了本“大话设计模式”电子版。偶然翻了翻,感觉还满通俗的,正适合我这样的懒人学习
- 说起模板引擎,很多人会认为这是后台的东西(如PHP的Smarty、Java的Velocity),跟前端没有关系。然而,随着前端的逻辑变得越来
- 上一次很多朋友写文字屏蔽说到要用正则表达,其实不是我不想用(我正则用得不是很多,看过我之前爬虫的都知道,我直接用BeautifulSoup的
- <html><head><meta http-equiv="Content-T
- ETL的考虑 做 数据仓库系统,ETL是关键的一环。说大了,ETL是数据整合解决