Perceptron Basics and a Detailed Python Implementation Walkthrough
Author: 沙克的世界 Published: 2023-11-07 16:24:35
Tags: perceptron, principles, python, implementation
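This is a simple version, written along the lines of Li Hang's "Statistical Learning Methods" (《统计学习方法》).
The data is the well-known iris dataset that ships with sklearn, and the optimization is solved with SGD.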
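The update rule behind this can be tried on a tiny data set first. What follows is a minimal sketch, assuming the original form of the perceptron algorithm: repeatedly pick a misclassified point and move the hyperplane toward it. The helper name `fit_perceptron`, the `max_updates` cap, and the three toy points are illustrative assumptions and are not part of the listing below. The sketch also treats a point lying exactly on the boundary as misclassified, which differs slightly from the `activation` check used in the listing.

```python
import numpy as np

def fit_perceptron(X, y, eta=1.0, max_updates=1000):
    """Original-form perceptron: update on the first misclassified point.

    Hypothetical helper for illustration only.
    """
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(max_updates):
        # A point is misclassified when y_i * (w . x_i + b) <= 0.
        margins = y * (X @ w + b)
        wrong = np.flatnonzero(margins <= 0)
        if wrong.size == 0:
            break  # every point is on the correct side
        i = wrong[0]
        # Stochastic gradient step on the single chosen point.
        w += eta * y[i] * X[i]
        b += eta * y[i]
    return w, b

# Toy data chosen for illustration: two positive points, one negative point.
X = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])
y = np.array([1, 1, -1])
w, b = fit_perceptron(X, y)
print(w, b)
```

On linearly separable data like this the loop stops after a finite number of updates; on non-separable data it would run until `max_updates` is reached.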
Preprocessing adds a standardization step.
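Standardization here means rescaling each feature column to zero mean and unit variance. A minimal vectorized sketch, assuming NumPy's default population standard deviation (`ddof=0`), which is what `ndarray.std()` computes; the `standardize` helper and the sample values are illustrative only.

```python
import numpy as np

def standardize(features):
    """Return a copy with each column shifted to mean 0 and scaled to std 1."""
    features = np.asarray(features, dtype=float)
    mean = features.mean(axis=0)   # per-column mean
    std = features.std(axis=0)     # per-column population std (ddof=0)
    return (features - mean) / std

# A few illustrative rows with two features each.
sample = np.array([[5.1, 1.4], [4.9, 1.4], [7.0, 4.7], [6.4, 4.5]])
sample_std = standardize(sample)
```

Note that the listing below computes the mean and standard deviation over the whole data set before splitting it; a stricter evaluation would compute them on the training split only and reuse them for the test split.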
'''
perceptron classifier
created on 2019.9.14
author: vince
'''
import numpy
import logging
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
class Perceptron:
    '''
    perceptron classifier

    Attributes
    w: 1d-array = weights after training (w[0] is the bias)
    l: 1d-array = number of misclassifications in each iteration
    '''

    def __init__(self, eta=0.01, iter_num=50, batch_size=1):
        '''
        eta: float = learning rate (between 0.0 and 1.0).
        iter_num: int = number of iterations; each iteration performs
            at most one weight update.
        batch_size: int = gradient descent batch size,
            if batch_size == 1, use SGD;
            if batch_size == 0, use BGD;
            else MBGD.
            NOTE: train() only implements the single-sample (SGD) update;
            batch_size is stored but not used.
        '''
        self.eta = eta
        self.iter_num = iter_num
        self.batch_size = batch_size

    def train(self, X, Y):
        '''
        fit the weights on the training data.
        X: {array-like}, shape=[n_samples, n_features] = training vectors,
            where n_samples is the number of training samples and
            n_features is the number of features.
        Y: {array-like}, shape=[n_samples] = target values (+1 or -1).
        '''
        self.w = numpy.zeros(1 + X.shape[1])
        self.l = numpy.zeros(self.iter_num)
        for iter_index in range(self.iter_num):
            # count the samples misclassified by the current weights
            for sample_index in range(X.shape[0]):
                if self.activation(X[sample_index]) != Y[sample_index]:
                    logging.debug("%s: pred(%s), label(%s), %s, %s" % (sample_index,
                        self.net_input(X[sample_index]), Y[sample_index],
                        X[sample_index, 0], X[sample_index, 1]))
                    self.l[iter_index] += 1
            # update on the first misclassified sample only, then stop
            for sample_index in range(X.shape[0]):
                if self.activation(X[sample_index]) != Y[sample_index]:
                    self.w[0] += self.eta * Y[sample_index]
                    self.w[1:] += self.eta * numpy.dot(X[sample_index], Y[sample_index])
                    break
            # this log line assumes exactly two features (w[1] and w[2])
            logging.info("iter %s: %s, %s, %s, %s" %
                (iter_index, self.w[0], self.w[1], self.w[2], self.l[iter_index]))

    def activation(self, x):
        # sign function: +1 when the net input is >= 0, otherwise -1
        return numpy.where(self.net_input(x) >= 0.0, 1, -1)

    def net_input(self, x):
        # w . x + b, with the bias stored in w[0]
        return numpy.dot(x, self.w[1:]) + self.w[0]

    def predict(self, x):
        return self.activation(x)

def main():
    logging.basicConfig(level=logging.INFO,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S')

    iris = load_iris()
    # first 99 samples (classes 0 and 1 only);
    # columns 0 and 2 are sepal length and petal length
    features = iris.data[:99, [0, 2]]
    # standardization: zero mean and unit variance per feature
    features_std = numpy.copy(features)
    for i in range(features.shape[1]):
        features_std[:, i] = (features_std[:, i] - features[:, i].mean()) / features[:, i].std()
    # map class 0 to -1 and class 1 to +1
    labels = numpy.where(iris.target[:99] == 0, -1, 1)

    # 2/3 data for training, 1/3 data for testing
    train_features, test_features, train_labels, test_labels = train_test_split(
        features_std, labels, test_size=0.33, random_state=23323)
    logging.info("train set shape:%s" % (str(train_features.shape)))

    p = Perceptron()
    p.train(train_features, train_labels)

    # predict() works on a whole matrix, so no per-sample loop is needed
    test_predict = p.predict(test_features)
    score = accuracy_score(test_labels, test_predict)
    logging.info("The accuracy score is: %s " % (str(score)))

    # plot the training samples and the learned decision boundary
    x_min, x_max = train_features[:, 0].min() - 1, train_features[:, 0].max() + 1
    y_min, y_max = train_features[:, 1].min() - 1, train_features[:, 1].max() + 1
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel("sepal length (standardized)")
    plt.ylabel("petal length (standardized)")
    plt.scatter(train_features[:, 0], train_features[:, 1], c=train_labels, marker='o', s=10)
    # boundary w0 + w1*x + w2*y = 0, rewritten as y = k*x + d (assumes w2 != 0)
    k = - p.w[1] / p.w[2]
    d = - p.w[0] / p.w[2]
    plt.plot([x_min, x_max], [k * x_min + d, k * x_max + d], "go-")
    plt.show()


if __name__ == "__main__":
    main()
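
The line drawn at the end of the listing comes from solving the boundary equation w0 + w1*x + w2*y = 0 for y, which gives the slope k = -w1/w2 and the intercept d = -w0/w2. A minimal sketch of that step, assuming a two-feature weight vector laid out as [bias, w1, w2] like `self.w` in the class; `boundary_line` is a hypothetical helper, and it raises an error for w2 == 0 (a vertical boundary), a case the listing does not guard against.

```python
def boundary_line(w):
    """Return (slope, intercept) of the line w[0] + w[1]*x + w[2]*y = 0."""
    bias, w1, w2 = w
    if w2 == 0:
        raise ValueError("vertical boundary: y cannot be written as a function of x")
    return -w1 / w2, -bias / w2

# Illustrative weights: the boundary x + y - 3 = 0.
k, d = boundary_line([-3.0, 1.0, 1.0])

# Any point (x, k*x + d) should satisfy the boundary equation.
x = 2.0
residual = -3.0 + 1.0 * x + 1.0 * (k * x + d)
```

Two points are enough to draw the line, which is why the listing only evaluates it at `x_min` and `x_max`.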
Source: https://www.cnblogs.com/thsss/p/11519846.html

