使用sklearn对多分类的每个类别进行指标评价操作
作者:山阴少年 发布时间:2022-04-20 17:19:39
标签:sklearn,多分类,类别,指标评价
今天晚上,笔者接到客户的一个需要,那就是:对多分类结果的每个类别进行指标评价,也就是需要输出每个类型的精确率(precision),召回率(recall)以及F1值(F1-score)。
对于这个需求,我们可以用sklearn来解决,方法并没有难,笔者在此仅做记录,供自己以后以及读者参考。
我们模拟的数据如下:
y_true = ['北京', '上海', '成都', '成都', '上海', '北京', '上海', '成都', '北京', '上海']
y_pred = ['北京', '上海', '成都', '上海', '成都', '成都', '上海', '成都', '北京', '上海']
其中y_true为真实数据,y_pred为多分类后的模拟数据。使用sklearn.metrics中的classification_report即可实现对多分类的每个类别进行指标评价。
示例的Python代码如下:
# -*- coding: utf-8 -*-
from sklearn.metrics import classification_report
y_true = ['北京', '上海', '成都', '成都', '上海', '北京', '上海', '成都', '北京', '上海']
y_pred = ['北京', '上海', '成都', '上海', '成都', '成都', '上海', '成都', '北京', '上海']
t = classification_report(y_true, y_pred, target_names=['北京', '上海', '成都'])
print(t)
输出结果如下:
precision recall f1-score support
北京 0.75 0.75 0.75 4
上海 1.00 0.67 0.80 3
成都 0.50 0.67 0.57 3
accuracy 0.70 10
macro avg 0.75 0.69 0.71 10
weighted avg 0.75 0.70 0.71 10
需要注意的是,输出的结果数据类型为str,如果需要使用该输出结果,则可将该方法中的output_dict参数设置为True,此时输出的结果如下:
{‘北京': {‘precision': 0.75, ‘recall': 0.75, ‘f1-score': 0.75, ‘support': 4},
‘上海': {‘precision': 1.0, ‘recall': 0.6666666666666666, ‘f1-score': 0.8, ‘support': 3},
‘成都': {‘precision': 0.5, ‘recall': 0.6666666666666666, ‘f1-score': 0.5714285714285715, ‘support': 3},
‘accuracy': 0.7,
‘macro avg': {‘precision': 0.75, ‘recall': 0.6944444444444443, ‘f1-score': 0.7071428571428572, ‘support': 10},
‘weighted avg': {‘precision': 0.75, ‘recall': 0.7, ‘f1-score': 0.7114285714285715, ‘support': 10}}
使用confusion_matrix方法可以输出该多分类问题的混淆矩阵,代码如下:
from sklearn.metrics import confusion_matrix
y_true = ['北京', '上海', '成都', '成都', '上海', '北京', '上海', '成都', '北京', '上海']
y_pred = ['北京', '上海', '成都', '上海', '成都', '成都', '上海', '成都', '北京', '上海']
print(confusion_matrix(y_true, y_pred, labels = ['北京', '上海', '成都']))
输出结果如下:
[[2 0 1]
[0 3 1]
[0 1 2]]
为了将该混淆矩阵绘制成图片,可使用如下的Python代码:
# -*- coding: utf-8 -*-
# author: Jclian91
# place: Daxing Beijing
# time: 2019-11-14 21:52
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import matplotlib as mpl
# 支持中文字体显示, 使用于Mac系统
zhfont=mpl.font_manager.FontProperties(fname="/Library/Fonts/Songti.ttc")
y_true = ['北京', '上海', '成都', '成都', '上海', '北京', '上海', '成都', '北京', '上海']
y_pred = ['北京', '上海', '成都', '上海', '成都', '成都', '上海', '成都', '北京', '上海']
classes = ['北京', '上海', '成都']
confusion = confusion_matrix(y_true, y_pred)
# 绘制热度图
plt.imshow(confusion, cmap=plt.cm.Greens)
indices = range(len(confusion))
plt.xticks(indices, classes, fontproperties=zhfont)
plt.yticks(indices, classes, fontproperties=zhfont)
plt.colorbar()
plt.xlabel('y_pred')
plt.ylabel('y_true')
# 显示数据
for first_index in range(len(confusion)):
for second_index in range(len(confusion[first_index])):
plt.text(first_index, second_index, confusion[first_index][second_index])
# 显示图片
plt.show()
生成的混淆矩阵图片如下:
补充知识:python Sklearn实现xgboost的二分类和多分类
二分类:
train2.txt的格式如下:
import numpy as np
import pandas as pd
import sklearn
from sklearn.cross_validation import train_test_split,cross_val_score
from xgboost.sklearn import XGBClassifier
from sklearn.metrics import precision_score,roc_auc_score
min_max_scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(-1,1))
resultX = []
resultY = []
with open("./train_data/train2.txt",'r') as rf:
train_lines = rf.readlines()
for train_line in train_lines:
train_line_temp = train_line.split(",")
train_line_temp = map(float, train_line_temp)
line_x = train_line_temp[1:-1]
line_y = train_line_temp[-1]
resultX.append(line_x)
resultY.append(line_y)
X = np.array(resultX)
Y = np.array(resultY)
X = min_max_scaler.fit_transform(X)
X_train,X_test, Y_train, Y_test = train_test_split(X,Y,test_size=0.3)
xgbc = XGBClassifier()
xgbc.fit(X_train,Y_train)
pre_test = xgbc.predict(X_test)
auc_score = roc_auc_score(Y_test,pre_test)
pre_score = precision_score(Y_test,pre_test)
print("xgb_auc_score:",auc_score)
print("xgb_pre_score:",pre_score)
多分类:有19种分类其中正常0,异常1~18种。数据格式如下:
# -*- coding:utf-8 -*-
from sklearn import datasets
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.cross_validation import train_test_split,cross_val_score
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from xgboost.sklearn import XGBClassifier
import sklearn
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import precision_score,roc_auc_score
min_max_scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(-1,1))
resultX = []
resultY = []
with open("../train_data/train_multi_class.txt",'r') as rf:
train_lines = rf.readlines()
for train_line in train_lines:
train_line_temp = train_line.split(",")
train_line_temp = map(float, train_line_temp) # 转化为浮点数
line_x = train_line_temp[1:-1]
line_y = train_line_temp[-1]
resultX.append(line_x)
resultY.append(line_y)
X = np.array(resultX)
Y = np.array(resultY)
#fit_transform(partData)对部分数据先拟合fit,找到该part的整体指标,如均值、方差、最大值最小值等等(根据具体转换的目的),然后对该partData进行转换transform,从而实现数据的标准化、归一化等等。。
X = min_max_scaler.fit_transform(X)
#通过OneHotEncoder函数将Y值离散化成19维,例如3离散成000000···100
Y = OneHotEncoder(sparse = False).fit_transform(Y.reshape(-1,1))
X_train,X_test, Y_train, Y_test = train_test_split(X,Y,test_size=0.2)
model = OneVsRestClassifier(XGBClassifier(),n_jobs=2)
clf = model.fit(X_train, Y_train)
pre_Y = clf.predict(X_test)
test_auc2 = roc_auc_score(Y_test,pre_Y)#验证集上的auc值
print ("xgb_muliclass_auc:",test_auc2)
来源:https://blog.csdn.net/jclian91/article/details/103074506


猜你喜欢
- 1.背景最近使用Pytest中的fixture和conftest时,遇到需要在conftest中的setup和teardown方法里传递参数
- python读取npz/npy文件npz和npy文件都可以直接使用numpy读写。import numpy as npac = np.loa
- golang这个语言用起来和java、 c#之类语言差不多,和c/c++差别比较大,有自动管理内存机制,省心省力。然而,如果写golang真
- 最近要做一个图像生成的课题,在网上找了一个混合的数据集。这个数据集中一共有360个文件夹,然后文件夹中有6-9张不等的照片,我的目标就是编写
- FULLTEXT以前使用查找时都是以 %关键字% 进行模糊查询结果的,这种查询方式有一些缺点,比如不能查询多个列必须手动添加条件以实现,效率
- Python信息抽取之乱码解决办法就事论事,直说自己遇到的情况,和我不一样的路过吧,一样的就看看吧信息抓取,用python,beautifu
- 一、所用知识点:1. for循环与if判断的结合2. %s占位符的使用3. 辅助标志的使用(标志位)4. break的使用二、代码示例:
- Python 提供了多个图形开发界面的库。Tkinter就是其中之一。 Tkinter 模块(Tk 接口)是 Python 的标准 Tk G
- 1、凸包检测与凸缺陷定义凸包是将最外层的点连接起来构成的凸多边形,它能包含点击中所有的点。物体的凸包检测常应用在物体识别、手势识别及边界检测
- Python中str is not callable问题详解及解决办法问题提出: 在Python的代码,在运行过程中
- 使用BootstrapValidator进行注册校验和登录错误提示,具体内容如下1、介绍在AdminEAP框架中,使用了BootstrapV
- 只有mdf文件的数据库附加失败的修复 附加时报如下错误: 服务器: 消息 1813,级别 16,状态 2,行 1 未能打开新数据库 '
- 原理:print() 函数会把内容放到内存中, 内存中的内容并不一定能够及时刷新显示到屏幕中(应该是要满足某个条件,这个条件现在还不清楚)。
- MYSQL有不同类型的日志文件(各自存储了不同类型的日志),从它们当中可以查询到MYSQL里都做了些什么,对于MYSQL的管理工作,这些日志
- 在开始聊我在阿里四个月的网页推广设计之前,我想先来说说我对平面设计和网页设计的认识。它们之间的交集。1.它们都是集艺术创作、电脑技术和数字技
- <html> <head> <title>biyuan给大家拜年了!</title> <
- 一、实现划词功能说是划词翻译,实际上我们是通过获取用户的剪切板内容,通过一系列的操作得到的。首先呢,我们就先实现如何获取剪切板内容的程序首先
- 最近稍稍有点空闲时间,于是重新温习了一下之前学习过的python基础。废话不多说,记录一下自己的所得。首先,安装什么的不在本人的温习范围,另
- 科讯5.0 标签和之前版本变化不大,如果用老版本的科讯,可以参考这个标签使用。相关文章:新云4.0 模板通用标签说明 标签清单:======
- HTML文件其实就是由一组尖括号构成的标签组织起来的,每一对尖括号形式一个标签,标签之间存在上下关系,形成标签树;XPath 使用路径表达式