自适应线性神经网络Adaline的python实现详解
作者:沙克的世界 发布时间:2023-11-03 03:57:40
标签:自适应,线性,神经网络,adaline,python
自适应线性神经网络Adaptive linear network, 是神经网络的入门级别网络。
相对于感知器,采用了f(z)=z的激活函数,属于连续函数。
代价函数为LMS函数,最小均方算法,Least mean square。
实现上,采用随机梯度下降,由于更新的随机性,运行多次结果是不同的。
'''
Adaline classifier
created on 2019.9.14
author: vince
'''
import pandas
import math
import numpy
import logging
import random
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
'''
Adaline classifier
Attributes
w: ld-array = weights after training
l: list = number of misclassification during each iteration
'''
class Adaline:
def __init__(self, eta = 0.001, iter_num = 500, batch_size = 1):
'''
eta: float = learning rate (between 0.0 and 1.0).
iter_num: int = iteration over the training dataset.
batch_size: int = gradient descent batch number,
if batch_size == 1, used SGD;
if batch_size == 0, use BGD;
else MBGD;
'''
self.eta = eta;
self.iter_num = iter_num;
self.batch_size = batch_size;
def train(self, X, Y):
'''
train training data.
X:{array-like}, shape=[n_samples, n_features] = Training vectors,
where n_samples is the number of training samples and
n_features is the number of features.
Y:{array-like}, share=[n_samples] = traget values.
'''
self.w = numpy.zeros(1 + X.shape[1]);
self.l = numpy.zeros(self.iter_num);
for iter_index in range(self.iter_num):
for rand_time in range(X.shape[0]):
sample_index = random.randint(0, X.shape[0] - 1);
if (self.activation(X[sample_index]) == Y[sample_index]):
continue;
output = self.net_input(X[sample_index]);
errors = Y[sample_index] - output;
self.w[0] += self.eta * errors;
self.w[1:] += self.eta * numpy.dot(errors, X[sample_index]);
break;
for sample_index in range(X.shape[0]):
self.l[iter_index] += (Y[sample_index] - self.net_input(X[sample_index])) ** 2 * 0.5;
logging.info("iter %s: w0(%s), w1(%s), w2(%s), l(%s)" %
(iter_index, self.w[0], self.w[1], self.w[2], self.l[iter_index]));
if iter_index > 1 and math.fabs(self.l[iter_index - 1] - self.l[iter_index]) < 0.0001:
break;
def activation(self, x):
return numpy.where(self.net_input(x) >= 0.0 , 1 , -1);
def net_input(self, x):
return numpy.dot(x, self.w[1:]) + self.w[0];
def predict(self, x):
return self.activation(x);
def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S');
iris = load_iris();
features = iris.data[:99, [0, 2]];
# normalization
features_std = numpy.copy(features);
for i in range(features.shape[1]):
features_std[:, i] = (features_std[:, i] - features[:, i].mean()) / features[:, i].std();
labels = numpy.where(iris.target[:99] == 0, -1, 1);
# 2/3 data from training, 1/3 data for testing
train_features, test_features, train_labels, test_labels = train_test_split(
features_std, labels, test_size = 0.33, random_state = 23323);
logging.info("train set shape:%s" % (str(train_features.shape)));
classifier = Adaline();
classifier.train(train_features, train_labels);
test_predict = numpy.array([]);
for feature in test_features:
predict_label = classifier.predict(feature);
test_predict = numpy.append(test_predict, predict_label);
score = accuracy_score(test_labels, test_predict);
logging.info("The accruacy score is: %s "% (str(score)));
#plot
x_min, x_max = train_features[:, 0].min() - 1, train_features[:, 0].max() + 1;
y_min, y_max = train_features[:, 1].min() - 1, train_features[:, 1].max() + 1;
plt.xlim(x_min, x_max);
plt.ylim(y_min, y_max);
plt.xlabel("width");
plt.ylabel("heigt");
plt.scatter(train_features[:, 0], train_features[:, 1], c = train_labels, marker = 'o', s = 10);
k = - classifier.w[1] / classifier.w[2];
d = - classifier.w[0] / classifier.w[2];
plt.plot([x_min, x_max], [k * x_min + d, k * x_max + d], "go-");
plt.show();
if __name__ == "__main__":
main();
来源:https://www.cnblogs.com/thsss/p/11520673.html
0
投稿
猜你喜欢
- 数据库系统是管理信息系统的核心,基于数据库的联机事务处理(OLTP)以及联机分析处理(OLAP)是银行、企业、政府等部门最为重要的计算机应用
- 常有人因为页面的面积问题,想在一个窄小的地方,显示一条条的信息,顺序往上滚动,在经典的BBS里,有一个随机上滚动的JS,好些人用不了,现在蛋
- 本文实例讲述了Python实现操纵控制windows注册表的方法。分享给大家供大家参考,具体如下:使用_winreg模块的话基本概念:KEY
- 一、数字类型所谓的“数字类”,就是指 DECIMAL 和 NUMERIC,它们是同一种类型。它严格的
- 在网络上看到的数字人整合动网论坛的方法都非常不全,站长们都是抄人家的,也不说明可不可用,提供下载的文件也不能下载.现在我提供一些信息。一、整
- 人生苦短,快学Python!上一周发了一篇文章《Python Tkinter图形工具使用方法及实例解析》,很多小伙伴都希望能多出点教程,今天
- 具体用法如下: 代码如下:-- ============================================= -- Autho
- 如何做一个计数器并让人家申请使用? 第一步:创建一个计数器(最简单的数字计数器,不是图片式的):&nbs
- 代码如下:---找出促销活动中销售额最高的职员 ---你刚在一家服装销售公司中找到了一份工作,此时经理要求你根据数据库中的两张表
- 安装官网下载http://ffmpeg.org/选择需要的版本在这个网址下载ffmpeg,https://github.com/BtbN/F
- 安装的依赖包flaskpymysqlflask_scriptflask_migrateflask_sqlalchemy创建Flask项目(项
- 今天逛论坛时看到有朋友问,是否有专门教Javascript的学校,这里想想把自己的一点建议和自己3年来的前端Javascript开发的经验跟
- 目录Python3 面向对象一丶面向对象技术简介对象可以包含任意数量和类型的数据。2.Python面向对象的三大特性一、继承 二、
- 问题你想在一个消息传输层如 sockets 、multiprocessing connections 或 ZeroMQ 的基础之上实现一个简
- 目录Python里的dict和set的效率有多高?字典中的散列表1.散列值和相等性散列表算法dict的实现及其导致的结果1.键必须死可散列的
- js对文字进行编码涉及3个函数:escape,encodeURI,encodeURIComponent,相应3个解码函数:unescape,
- 前言晚上回家闲来无事,想打开某直播平台,看看小姐姐直播。看着一个个多才多艺的小姐姐,眼花缭乱,好难抉择。究竟看哪个小姐姐直播好呢?今天我们就
- 信息架构的组件可以拆分成四类组织系统 如何组织信息,例如,依据主题或年代顺序。标签系统 如何表示信息,例如,科学术语(“Acer”)或通俗术
- jxdawei的blog:http://www.iwcn.net/本文讨论的是在web标准普及的形势下,网站程序员的定位以及如何与设计师配合
- 前言Logistic回归涉及到高等数学,线性代数,概率论,优化问题。本文尽量以最简单易懂的叙述方式,以少讲公式原理,多讲形象化案例为原则,给