基于python的BP神经网络及异或实现过程解析
作者:沙克的世界 发布时间:2021-10-29 00:02:01
标签:python,bp,神经网络
BP神经网络是最简单的神经网络模型了,三层能够模拟非线性函数效果。
难点:
如何确定初始化参数?
如何确定隐含层节点数量?
迭代多少次?如何更快收敛?
如何获得全局最优解?
'''
neural networks
created on 2019.9.24
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt
'''
neural network
'''
class NeuralNetwork:
def __init__(self, layer_nums, iter_num = 10000, batch_size = 1):
self.__ILI = 0;
self.__HLI = 1;
self.__OLI = 2;
self.__TLN = 3;
if len(layer_nums) != self.__TLN:
raise Exception("layer_nums length must be 3");
self.__layer_nums = layer_nums; #array [layer0_num, layer1_num ...layerN_num]
self.__iter_num = iter_num;
self.__batch_size = batch_size;
def train(self, X, Y):
X = numpy.array(X);
Y = numpy.array(Y);
self.L = [];
#initialize parameters
self.__weight = [];
self.__bias = [];
self.__step_len = [];
for layer_index in range(1, self.__TLN):
self.__weight.append(numpy.random.rand(self.__layer_nums[layer_index - 1], self.__layer_nums[layer_index]) * 2 - 1.0);
self.__bias.append(numpy.random.rand(self.__layer_nums[layer_index]) * 2 - 1.0);
self.__step_len.append(0.3);
logging.info("bias:%s" % (self.__bias));
logging.info("weight:%s" % (self.__weight));
for iter_index in range(self.__iter_num):
sample_index = random.randint(0, len(X) - 1);
logging.debug("-----round:%s, select sample %s-----" % (iter_index, sample_index));
output = self.forward_pass(X[sample_index]);
g = (-output[2] + Y[sample_index]) * self.activation_drive(output[2]);
logging.debug("g:%s" % (g));
for j in range(len(output[1])):
self.__weight[1][j] += self.__step_len[1] * g * output[1][j];
self.__bias[1] -= self.__step_len[1] * g;
e = [];
for i in range(self.__layer_nums[self.__HLI]):
e.append(numpy.dot(g, self.__weight[1][i]) * self.activation_drive(output[1][i]));
e = numpy.array(e);
logging.debug("e:%s" % (e));
for j in range(len(output[0])):
self.__weight[0][j] += self.__step_len[0] * e * output[0][j];
self.__bias[0] -= self.__step_len[0] * e;
l = 0;
for i in range(len(X)):
predictions = self.forward_pass(X[i])[2];
l += 0.5 * numpy.sum((predictions - Y[i]) ** 2);
l /= len(X);
self.L.append(l);
logging.debug("bias:%s" % (self.__bias));
logging.debug("weight:%s" % (self.__weight));
logging.debug("loss:%s" % (l));
logging.info("bias:%s" % (self.__bias));
logging.info("weight:%s" % (self.__weight));
logging.info("L:%s" % (self.L));
def activation(self, z):
return (1.0 / (1.0 + numpy.exp(-z)));
def activation_drive(self, y):
return y * (1.0 - y);
def forward_pass(self, x):
data = numpy.copy(x);
result = [];
result.append(data);
for layer_index in range(self.__TLN - 1):
data = self.activation(numpy.dot(data, self.__weight[layer_index]) - self.__bias[layer_index]);
result.append(data);
return numpy.array(result);
def predict(self, x):
return self.forward_pass(x)[self.__OLI];
def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S');
logging.info("trainning begin.");
nn = NeuralNetwork([2, 2, 1]);
X = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]]);
Y = numpy.array([0, 1, 0, 1]);
nn.train(X, Y);
logging.info("trainning end. predict begin.");
for x in X:
print(x, nn.predict(x));
plt.plot(nn.L)
plt.show();
if __name__ == "__main__":
main();
具体收敛效果
来源:https://www.cnblogs.com/thsss/p/11588433.html


猜你喜欢
- 今天来分享python学习的一个小例子,使用python暴力破解mysql数据库,实现方式是通过UI类库tkinter实现可视化面板效果,在
- 在Oracle数据库中,如何查找,定位一张表最后一次的DML操作的时间呢? 方式有三种,不过都有一些局限性,下面简单的解析、总结一下。1:使
- 很喜欢Python这门语言。在看过语法后学习了Django 这个 Web 开发框架。算是对 Python 有些熟悉了。不过对里面很多东西还是
- 1 简介二进制日志,记录对数据发生或潜在发生更改的SQL语句,并以二进制形式保存在磁盘。2 Binlog 的作用主要作用:复制、恢复和审计。
- This is a {t}. {name}是一个很强大的字符串模板解析方法。它接受三个参数,分别是{args.text},{args.obj
- SQL注入语句有时候会使用替换查询技术,就是让原有的查询语句查不到结果出错,而让自己构造的查询语句执行,并把执行结果代替原有查询语句查询结果
- 1、切片使用切片来实现列表的倒序排序,mylist[start:end:step],不改变原列表。#!/usr/bin/env python
- 现在大多数Centos6.x版本的系统python都是2.x,现因开发需求需要安装前端代码的构建工具glue,故必须要做python版本的升
- 程序运行效率程序的运行效率分为两种:第一种是时间效率,第二种是空间效率。时间效率被称为时间复杂度,而空间效率被称作空间复杂度。时间复杂度主要
- 最近在碰到有同学问我,vue父组件怎么使用外部对象,具体例子如下:有组件a:<div @click="onClick&quo
- 本文实例讲述了Python实现字符串与数组相互转换功能。分享给大家供大家参考,具体如下:字符串转数组str = '1,2,3'
- 零、本讲学习目标了解面向对象编程思想掌握类和对象的定义和使用了解Python中的对象一、面向对象(一)程序员“面向对象”在现实世界中存在各种
- 反射指的是运行时动态的获取变量的相关信息1. reflect 包类型是变量,类别是常量reflect.TypeOf,获取变量的类型,返回re
- 实例如下所示:#!/usr/bin/python# -*- coding: UTF-8 -*-import reimport urllib,
- 问题描述当前环境win10,python_3.6.1,64位。在windows下,在dos中运行pip install Scrapy报错:b
- 创建Dataframe主要是使用pandas中的DataFrame函数,其核心就是第一个参数:data,传入原始数据,因此我们可以据此给出六
- mysql安装好经常发现无法正常启动碰到最多的是error 2003的错误,以下为解决方法: mysqld -nt -remove mysq
- 在生活和工作中,我们每个人每天都在和时间打交道:早上什么时候起床?地铁几分钟来一趟?中午什么时候开始午休?明天是星期几?距离上次买衣服已经2
- Python 是一种美丽的语言,它简单易用却非常强大。但你真的会用 Python 的所有功能吗?任何编程语言的高级特征通常都是通过大量的使用
- 在最新版的pandas中(不知道之前的版本有没有这个问题),当我们对具有多层次索引的对象做切片或者通过df[bool_list]的方式索引的