浅谈sklearn中predict与predict_proba区别
作者:GitzLiu 发布时间:2023-11-08 03:53:45
标签:sklearn,predict,proba
predict_proba 返回的是一个 n 行 k 列的数组,列是标签(有排序), 第 i 行 第 j 列上的数值是模型预测 第 i 个预测样本为某个标签的概率,并且每一行的概率和为1。
predict 直接返回的是预测 的标签。
具体见下面示例:
# conding :utf-8
from sklearn.linear_model import LogisticRegression
import numpy as np
x_train = np.array([[1,2,3],
[1,3,4],
[2,1,2],
[4,5,6],
[3,5,3],
[1,7,2]])
y_train = np.array([3, 3, 3, 2, 2, 2])
x_test = np.array([[2,2,2],
[3,2,6],
[1,7,4]])
clf = LogisticRegression()
clf.fit(x_train, y_train)
# 返回预测标签
print(clf.predict(x_test))
# 返回预测属于某标签的概率
print(clf.predict_proba(x_test))
# [2 3 2]
#
# [[0.56651809 0.43348191]
# [0.15598162 0.84401838]
# [0.86852502 0.13147498]]
# 分析结果:
# 标签是 2,3 共两个,所以predict_proba返回的为2列,且是排序的(第一列为标签2,第二列为标签3),
# 返回矩阵的行数是测试样本个数 因此为3行
# 预测[2,2,2]的标签是2的概率为0.56651809,3的概率为0.43348191
#
# 预测[3,2,6]的标签是2的概率为0.15598162,3的概率为0.84401838
#
# 预测[1,7,4]的标签是2的概率为0.86852502,3的概率为0.13147498
补充知识:sklearn中predict与predict_proba的识别结果不一致
今天训练了好久的决策树模型在测试的时候发现个bug,使用predict得到的结果居然不是predict_proba中最大数值的索引!因为脚本中需要模型的置信度,所以希望拿到predict_proba的类别概率。
经过胡乱分析发现predict_proba得到的维度比总类别数少了几个,经过测试发现就是这个造成的,即训练集中有部分类别样本数为0。这个问题比较隐蔽,记录一下方便天涯沦落人绕坑。
Tip:在sklearn的train_test_split中有一个参数可以强制测试集和训练集的数据分布一致,也就不会导致缺类别的问题。
来源:https://blog.csdn.net/GitzLiu/article/details/81952431


猜你喜欢
- 有2个不同的方法增加用户:通过使用GRANT语句或通过直接操作MySQL授权表。比较好的方法是使用GRANT语句,因为他们是更简明并且好像错
- 以前没怎么仔细的研究过ajax,只是用到了就直接拿过来用,发现了问题再找解决方法.以下是我在找解决问题的过程中的一点小小的总结. 一.谈Aj
- 本文实例为大家分享了python抖音表白神器,供大家参考,具体内容如下# -*- coding: utf-8 -*-import sysfr
- 由于本人在实际应用中遇到了有关 numpy.sum() 函数参数 axis 的问题,这里特来记录一下。也供大家参考。示例代码如下:impor
- 前言如果想分布式执行用例,用例设计必须遵循以下原则:1、用例之间都是独立的,2、用例a不要去依赖用例b3、用例执行没先后顺序,4、随机都能执
- 阅读本文大概需要3分钟关于函数和模块讲了这么久,我一直想用一个好玩有趣的小例子来总结一下,同时也作为实战练习一下。趣味编程其实是最好的学习途
- Go 语言 switch 语句switch 语句用于基于不同条件执行不同动作,每一个 case 分支都是唯一的,从上直下逐一测试,直到匹配为
- 练习一:假设你获取到了2017年内地电影票房前20的电影(列表a)和电影票房数据(列表b),那么如何更加直观的展示该数据?a = [&quo
- 前言大家好,我是辣条今天给大家带来几个实用的python脚本工具,原因不难猜这段时间我亲爱的女朋友呢给我整出点小花样,差点让我电脑GG了。我
- 前言由于学校科技立项的项目需要实现Android App端与PHP Web端的简单数据交互的实现,当前场景是Web端使用的是MySql数据库
- 分割单词将一个标识符分割成若干单词存进列表,便于后续命名法的转换先引入正则表达式包import re至于如何分割单词看个人喜好,如以常见分隔
- demo.py(装饰器,带参数的装饰器):def set_level(level_num): def set_func(func
- 在Pydev能正常执行的脚本,在导出后在命令行执行,通常会报自己写的包导入时找不到。一:报错原因在PyDev中,test.py 中导入Tes
- 本文主要介绍了pandas针对excel处理的实现,分享给大家,具体如下:读取文件import padasdf = pd.read_csv(
- 图像增强算子几何变换算子图像的几何变换又称为图像空间变换, 它将一幅图像中的坐标位置映射到另一幅图像中的新坐标位置。图像缩放缩放只是调整图像
- GreatSQL社区原创内容未经授权不得随意使用,转载请联系小编并注明来源。GreatSQL是MySQL的国产分支版本,使用上与MySQL一
- 一些小技巧1. 如何查出效率低的语句?在MySQL下,在启动参数中设置 --log-slow-queries=[文件名],就可以在指定的日志
- 合并与分割tf.concattf.concat可以帮助我们实现拼接操作.格式:tf.concat( values,
- gjsonGJSON 是一个Go包,它提供了一种从json文档中获取值的快速简单的方法。它具有单行检索、点符号路径、迭代和解析 json 行
- 本文实例讲述了Python注释、分支结构、循环结构、伪“选择结构”用法。分享给大家供大家参考,具体如下:注释:python使用#作为行注释符