网络编程
位置:首页>> 网络编程>> Python编程>> Python sklearn预测评估指标混淆矩阵计算示例详解

Python sklearn预测评估指标混淆矩阵计算示例详解

作者:fanstuck  发布时间:2023-12-19 23:39:21 

标签:Python,sklearn,预测评估,混淆矩阵,计算

前言

很多时候需要对自己模型进行性能评估,对于一些理论上面的知识我想基本不用说明太多,关于校验模型准确度的指标主要有混淆矩阵、准确率、精确率、召回率、F1 score。另外还有P-R曲线以及AUC/ROC,这些我都有写过相应的理论和具体理论过程:

机器学习:性能度量篇-Python利用鸢尾花数据绘制ROC和AUC曲线

机器学习:性能度量篇-Python利用鸢尾花数据绘制P-R曲线

 这里我们主要进行实践利用sklearn快速实现模型数据校验,完成基础指标计算。

混淆矩阵

查准率(precision)与查全率(recall)是对于需求在信息检索、Web搜索等应用评估性能度量适应度高的检测数值。对于二分类问题,可将真实类别与算法预测类别的组合划分为真正例(ture positive)、 * 例(false positive)、真反例(true negative)、假反例(false negative)四种情形。显然TP+FP+TN+FN=样例总数。分类结果为混淆矩阵:

Python sklearn预测评估指标混淆矩阵计算示例详解

以分类模型中最简单的二分类为例,对于这种问题,我们的模型最终需要判断样本的结果是0还是1,或者说是positive还是negative。 因此,我们就能得到这样四个基础指标,我称他们是一级指标(最底层的):

  • 真实值是positive,模型认为是positive的数量(True Positive=TP)

  • 真实值是positive,模型认为是negative的数量(False Negative=FN):这就是统计学上的第二类错误(Type II Error)

  • 真实值是negative,模型认为是positive的数量(False Positive=FP):这就是统计学上的第一类错误(Type I Error)

  • 真实值是negative,模型认为是negative的数量(True Negative=TN)

预测性分类模型,肯定是希望越准越好。那么,对应到混淆矩阵中,那肯定是希望TP与TN的数量大,而FP与FN的数量小。所以当我们得到了模型的混淆矩阵后,就需要去看有多少观测值在第二、四象限对应的位置,这里的数值越多越好;反之,在第一、三象限对应位置出现的观测值肯定是越少越好。

python代码

混淆矩阵一般来说可以有三种实现展示方法,需要前置计算出混淆矩阵数据,这一点使用sklearn就可以实现:

from sklearn.metrics import confusion_matrix
y_true =df_evaluation.state_y
y_pred =df_evaluation.state_x
cm= confusion_matrix(y_true, y_pred,labels=[2,3,4,5])

其中cm就是计算出来的混淆矩阵:

Python sklearn预测评估指标混淆矩阵计算示例详解

利用sklearn的confusion_matrix函数就可以实现,这里将该函数的参数铺开一下:

sklearn.metrics.confusion_matrix(y_true,
                                y_pred,
                                *,
                                labels=None,
                                sample_weight=None,
                                normalize=None)

参数说明:

  • y_true:对比真值

  • y_pred: 预测值

  • labels:索引矩阵的标签列表。这可用于重新排序或选择标签的子集。如果给定“无”,则按排序顺序使用在y_true或y_pred中至少出现一次的值。

  • sample_weight:样本权重

  • normalize:在真(行)、预测(列)条件或所有总体上规范化混淆矩阵。如果“无”,则混淆矩阵将不会被归一化。

得到了混淆矩阵接下来进行数据可视化就好了,这里有三种实现形式,其中matplotlib和seaborn实现方法是一样的,都是热力图实现,另外sklearn自带一个ConfusionMatrixDisplay也可以直接实现热力。 第一种matplotlib/seaborn:

import seaborn as sns
import matplotlib.pyplot as plt
labels=[2,3,4,5]
sns.heatmap(cm,annot=True ,fmt="d",xticklabels=labels,yticklabels=labels)
plt.title('confusion matrix')  # 标题
plt.xlabel('Predict lable')  # x轴
plt.ylabel('True lable')  # y轴
plt.show()

Python sklearn预测评估指标混淆矩阵计算示例详解

第二种ConfusionMatrixDisplay:

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(
   include_values=True,            
   cmap="viridis",                
   ax=None,                        
   xticks_rotation="horizontal",  
   values_format="d"              
)
plt.show()

Python sklearn预测评估指标混淆矩阵计算示例详解

这里我主要将一下ConfusionMatrixDisplay.plot()的可选参数:


plot(*,
    include_values=True,
    cmap='viridis',
    xticks_rotation='horizontal',
    values_format=None,
    ax=None,
    colorbar=True,
    im_kw=None,
    text_kw=None)

参数说明:

  • include_values:bool,default=True。包括混淆矩阵中的值。

  • cmap:str or matplotlib Colormap, default=’viridis’。matplotlib识别的颜色映射。

  • xticks_rotation: {‘vertical’, ‘horizontal’} or float, default=’horizontal’。旋转xtick标签。

  • values_format:str, default=None。混淆矩阵中值的格式规范。如果无,则格式规范为“d”或“.2g”,以较短者为准。

  • ax: matplotlib axes, default=None。要绘制的轴对象。如果为“无”,则创建新的图形和轴。

  • colorbar:bool, default=True。是否向绘图添加色条。

  • im_kw:dict, default=None。使用传递给matplotlib.pyplot.imshow调用的关键字进行读写。

  • text_kw:dict, default=None。使用传递给matplotlib.pyplot.text调用的关键字进行读写。

来源:https://juejin.cn/post/7184608711489355835

0
投稿

猜你喜欢

  • 利用python pyheatmap包绘制热力图,供大家参考,具体内容如下import matplotlib.pyplot as pltfr
  • 我们可能会出现这种情况,某个表原来设计不周全,导致表里面的数据数据重复,那么,如何对重复的数据进行删除呢?重复的数据可能有这样两种情况,第一
  • 代码如下:CREATE TABLE [dbo].[TbGuidTable]( [TableName] [varchar](50) NOT N
  • 1、我的源码在 /home/topsec/Documents/php-7.0.11 ,安装位置在 /usr/local/php7, php.
  • 小编今天写下关于后台管理员权限的分配自己的思路想法<?php /**reader * 小编的思想比较简单实现的功能
  • 仿windows选项卡或叫做tabpan以及tabpage,现在还有最新的进展譬如仿淘宝网导航菜单效果皆属于此类:运行代码框<scri
  • ISNULL     使用指定的替换值替换   NULL。   &nb
  •   “用户体验”作为舶来品在国内风靡已经有几个年头了,而且从目前情况来看仍旧会继续风靡一段时间。当某产品发布会上,发言人张口闭口就
  • 首先,我要在这里写上一些很官方的概念,意在说明面向对象是很具体化的,很实体的模式,不能让有些人看见&ldquo;对象&rdq
  • 链接的 target 属性怎么用 JS 来控制? 在HTML 4.0 Strict和XHTML 1.0 STRICT里不允许在<a&g
  • 因AJAX接受数据时服务器默认是采用UTF-8的编码形式进行传送,所以在很多GB2312中文网页中应用AJAX回传数据经常会发生中文乱码。解
  • 本文以修改用户名密码单元为案例,编写测试脚本。完成修改用户名密码模块单元测试。(ps.这个demo中登陆密码为“admin”)1. 打开浏览
  • 一图胜“十”言:SQL Server 数据库总结 一个大概的总结 经过一段时间的学习,也对数据库有了一些认识。 数据库基本是由表,关系,操作
  • PDOStatement::errorInfoPDOStatement::errorInfo — 获取跟上一次语句句柄操作相关的扩展错误信息
  • 为什么很多站长开始做英文网站,我想主要是原因是良好的互联网环境让大家更容易赚到钱,中小站长做英文网站大致为两类,一是电子商务的外贸网站,二是
  • SQL中的单记录函数1.ASCII返回与指定的字符对应的十进制数;SQL> select ascii('A') A,a
  • GetRepeatTimes(TheChar,TheString) 得到一个字符串在另一个字符串当中出现几次的函数(新)如:response
  • 在大家的日常python程序的编写过程中,都会有自己解决某个问题的解决办法,或者是在程序的调试过程中,用来帮助调试的程序公式。小编通过上万行
  • 这是一段点击复制的代码,现在我的页面里不仅有1个链接需要用到这段代码。请哪位好心人指教一下应该怎么用ID对应的方式来改写这段js,使它实现一
  • 目录1.技术背景2.问题复现3.解决思路4.总结概要1.技术背景笔者在执行一个Jax的任务中,又发现了一个奇怪的问题,就是明明只分配了很小的
手机版 网络编程 asp之家 www.aspxhome.com