Python利用scikit-learn实现近邻算法分类的示例详解
作者:吃肉的小馒头 发布时间:2021-01-09 18:43:44
标签:Python,scikit-learn,近邻算法
scikit-learn库
scikit-learn已经封装好很多数据挖掘的算法
现介绍数据挖掘框架的搭建方法
1.转换器(Transformer)用于数据预处理,数据转换
2.流水线(Pipeline)组合数据挖掘流程,方便再次使用(封装)
3.估计器(Estimator)用于分类,聚类,回归分析(各种算法对象)
所有的估计器都有下面2个函数
fit() 训练
用法:estimator.fit(X_train, y_train)
estimator = KNeighborsClassifier() 是scikit-learn算法对象
X_train = dataset.data 是numpy数组
y_train = dataset.target 是numpy数组
predict() 预测
用法:estimator.predict(X_test)
estimator = KNeighborsClassifier() 是scikit-learn算法对象
X_test = dataset.data 是numpy数组
示例
%matplotlib inline
# Ionosphere数据集
# https://archive.ics.uci.edu/ml/machine-learning-databases/ionosphere/
# 下载ionosphere.data和ionosphere.names文件,放在 ./data/Ionosphere/ 目录下
import os
home_folder = os.path.expanduser("~")
print(home_folder) # home目录
# Change this to the location of your dataset
home_folder = "." # 改为当前目录
data_folder = os.path.join(home_folder, "data")
print(data_folder)
data_filename = os.path.join(data_folder, "ionosphere.data")
print(data_filename)
import csv
import numpy as np
# Size taken from the dataset and is known已知数据集形状
X = np.zeros((351, 34), dtype='float')
y = np.zeros((351,), dtype='bool')
with open(data_filename, 'r') as input_file:
reader = csv.reader(input_file)
for i, row in enumerate(reader):
# Get the data, converting each item to a float
data = [float(datum) for datum in row[:-1]]
# Set the appropriate row in our dataset用真实数据覆盖掉初始化的0
X[i] = data
# 1 if the class is 'g', 0 otherwise
y[i] = row[-1] == 'g' # 相当于if row[-1]=='g': y[i]=1 else: y[i]=0
# 数据预处理
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=14)
print("训练集数据有 {} 条".format(X_train.shape[0]))
print("测试集数据有 {} 条".format(X_test.shape[0]))
print("每条数据有 {} 个features".format(X_train.shape[1]))
输出:
训练集数据有 263 条
测试集数据有 88 条
每条数据有 34 个features
# 实例化算法对象->训练->预测->评价
from sklearn.neighbors import KNeighborsClassifier
estimator = KNeighborsClassifier()
estimator.fit(X_train, y_train)
y_predicted = estimator.predict(X_test)
accuracy = np.mean(y_test == y_predicted) * 100
print("准确率 {0:.1f}%".format(accuracy))
# 其他评价方式
from sklearn.cross_validation import cross_val_score
scores = cross_val_score(estimator, X, y, scoring='accuracy')
average_accuracy = np.mean(scores) * 100
print("平均准确率 {0:.1f}%".format(average_accuracy))
avg_scores = []
all_scores = []
parameter_values = list(range(1, 21)) # Including 20
for n_neighbors in parameter_values:
estimator = KNeighborsClassifier(n_neighbors=n_neighbors)
scores = cross_val_score(estimator, X, y, scoring='accuracy')
avg_scores.append(np.mean(scores))
all_scores.append(scores)
输出:
准确率 86.4%
平均准确率 82.3%
from matplotlib import pyplot as plt
plt.figure(figsize=(32,20))
plt.plot(parameter_values, avg_scores, '-o', linewidth=5, markersize=24)
#plt.axis([0, max(parameter_values), 0, 1.0])
for parameter, scores in zip(parameter_values, all_scores):
n_scores = len(scores)
plt.plot([parameter] * n_scores, scores, '-o')
plt.plot(parameter_values, all_scores, 'bx')
from collections import defaultdict
all_scores = defaultdict(list)
parameter_values = list(range(1, 21)) # Including 20
for n_neighbors in parameter_values:
for i in range(100):
estimator = KNeighborsClassifier(n_neighbors=n_neighbors)
scores = cross_val_score(estimator, X, y, scoring='accuracy', cv=10)
all_scores[n_neighbors].append(scores)
for parameter in parameter_values:
scores = all_scores[parameter]
n_scores = len(scores)
plt.plot([parameter] * n_scores, scores, '-o')
plt.plot(parameter_values, avg_scores, '-o')
来源:https://blog.csdn.net/qq_42034590/article/details/129243282
0
投稿
猜你喜欢
- Firefox 3.5已经发布了几个月了,且已经历5次小幅更新。而基于Gecko 1.9.2的Firefox 3.6也已经开发数月,现在已经
- (1)最近真是郁闷,在Myeclipse中使用DB Browser但出现以下问题:(2)然后赶紧百度,求大神解决,主要的解决方法试一下几种:
- 本文实例讲述了Python实现的特征提取操作。分享给大家供大家参考,具体如下:# -*- coding: utf-8 -*-"&q
- 1.多选按钮的方法以下为常用的方法:方法描述deselect()清除多选按钮选中选项。flash()在激活状态颜色和正常颜色之间闪烁几次多选
- 一、计数排序计数排序(Counting sort)是一种稳定的排序算法算法的步骤如下:找出待排序的数组中最大和最小的元素统计数组中每个值为i
- 本文实例讲述了python使用socket向客户端发送数据的方法。分享给大家供大家参考。具体如下:import socket, syspor
- 1、HKEY_LOCAL_MACHINE\SYSTEM\ControlSet001\Services\Eventlog\Applicatio
- 前言Python是C语言实现的,因此Python对象在C语言层面应该是一个结构体 ,组织对象占用的内存。 不同类型的对象,数据及行为均可能不
- 1、块级作用域想想此时运行下面的程序会有输出吗?执行会成功吗?#块级作用域if 1 == 1: name = "lzl"
- 代码如下:<script type=text/javascript src=http://fw.qq.com/ipaddress>
- 如果你使用过大部分,那么你的ASP功力应该是非常高的了ADO对象(太常用了):ConnectionCommandRecordSetRecor
- 一、为什么使用TFRecord?正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上
- MySQL使用环境变量TMPDIR的值作为保存临时文件的目录的路径名。如果未设置TMPDIR,MySQL将使用系统的默认值,通常为/tmp、
- 数据库(DataBase,DB)是一个长期存储在计算机内的、有组织的、有共享的、统一管理的数据集合。通俗地说,数据库就是一个按照数据结构来组
- 以下的文章主要是介绍MySQL5创建存储过程的实例演示,MySQL5创建存储在实际操作中应用的频率还是很高的,以下就是MySQL5创建存储过
- 五、XML带来的好处 (1)更有意义的搜索 数据可被XML唯一的标识。没有XML,搜索软件必须了解每个数据库是如何构建的。这实际上是不可能的
- 本文实例讲述了php 多继承的几种常见实现方法。分享给大家供大家参考,具体如下:class Parent1 { function
- 从PJBlog 2.7开始,验证码的功能就很好很强大了,但是同时也给手工输入带来了不小的麻烦——经常输错。之前我写了一个《自己写的一个PJB
- 如下所示:# -*- coding: utf-8 -*-#简述:一个整数,它加上100和加上268后都是一个完全平方数#提问:请问该数是多少
- js对文字进行编码涉及3个函数:escape,encodeURI,encodeURIComponent,相应3个解码函数:unescape,