Python 实现一个全连接的神经网络
作者:??zidea???? 发布时间:2021-01-20 05:46:42
前言
在这篇文章中,准备用 Python 从头开始实现一个全连接的神经网络。你可能会问,为什么需要自己实现,有很多库和框架可以为我们做这件事,比如 Tensorflow、Pytorch 等。这里只想说只有自己亲手实现了,才是自己的。
想到今天自己从接触到从事与神经网络相关工作已经多少 2、3 年了,其中也尝试用 tensorflow 或 pytorch 框架去实现一些经典网络。不过对于反向传播背后机制还是比较模糊。
梯度
梯度是函数上升最快方向,最快的方向也就是说这个方向函数形状很陡峭,那么也是函数下降最快的方向。
虽然关于一些理论、梯度消失和结点饱和可以输出一个 1、2、3 但是深究还是没有底气,毕竟没有自己动手去实现过一个反向传播和完整训练过程。所以感觉还是浮在表面,知其所以然而。
因为最近有一段空闲时间、所以利用这段休息时间将要把这部分知识整理一下、深入了解了解
类型 | 符号 | 说明 | 表达式 | 维度 |
---|---|---|---|---|
标量 | n^LnL | 表示第 L 层神经元的数量 | ||
向量 | B^LBL | 表示第 L 层偏置 | n^L \times 1nL×1 | |
矩阵 | W^LWL | 表示第 L 层的权重 | n^L \times n^LnL×nL | |
向量 | Z^LZL | 表示第 L 层输入到激活函数的值 | Z^L=W^LA^{(L-1)} + B^LZL=WLA(L−1)+BL | n^L \times 1nL×1 |
向量 | A^LAL | 表示第 L 层输出值 | A^L = \sigma(Z^L)AL=σ(ZL) | n^L \times 1nL×1 |
我们大家可能都了解训练神经网络的过程,就是更新网络参数,更新的方向是降低损失函数值。也就是将学习问题转换为了一个优化的问题。那么如何更新参数呢?我们需要计算参与训练参数相对于损失函数的导数,然后求解梯度,然后使用梯度下降法来更新参数,迭代这个过程,可以找到一个最佳的解决方案来最小化损失函数。
我们知道反向传播主要就是用来结算损失函数相对于权重和偏置的导数
可能已经听到或读到了,很多关于在网络通过反向传播来传递误差的信息。然后根据神经元的 w 和 b 对偏差贡献的大小。也就是将误差分配到每一个神经元上。 但这里的误差(error)是什么意思呢?这个误差的确切的定义又是什么?答案是这些误差是由每一层神经网络所贡献的,而且某一层的误差是后继层误差基础上分摊的,网络中第 层的误差用
来表示。
反向传播是基于 4 个基本方程的,通过这些方程来计算误差和损失函数,这里将这 4 个方程一一列出
关于如何解读这个 4 个方程,随后想用一期分享来说明。
class NeuralNetwork(object):
def __init__(self):
pass
def forward(self,x):
# 返回前向传播的 Z 也就是 w 和 b 线性组合,输入激活函数前的值
# 返回激活函数输出值 A
# z_s , a_s
pass
def backward(self,y,z_s,a_s):
#返回前向传播中学习参数的导数 dw db
pass
def train(self,x,y,batch_size=10,epochs=100,lr=0.001):
pass
我们都是神经网络学习过程,也就是训练过程。主要分为两个阶段前向传播和后向传播
在前向传播函数中,主要计算传播的 Z 和 A,关于 Z 和 A 具体是什么请参见前面表格
在反向传播中计算可学习变量 w 和 b 的导数
def __init__(self,layers = [2 , 10, 1], activations=['sigmoid', 'sigmoid']):
assert(len(layers) == len(activations)+1)
self.layers = layers
self.activations = activations
self.weights = []
self.biases = []
for i in range(len(layers)-1):
self.weights.append(np.random.randn(layers[i+1], layers[i]))
self.biases.append(np.random.randn(layers[i+1], 1))
layers 参数用于指定每一层神经元的个数
activations 为每一层指定激活函数,也就是
来简单读解一下代码 assert(len(layers) == len(activations)+1)
for i in range(len(layers)-1):
self.weights.append(np.random.randn(layers[i+1], layers[i]))
self.biases.append(np.random.randn(layers[i+1], 1))
因为权重连接每一个层神经元的 w 和 b ,也就两两层之间的方程,上面代码是对
前向传播
def feedforward(self, x):
# 返回前向传播的值
a = np.copy(x)
z_s = []
a_s = [a]
for i in range(len(self.weights)):
activation_function = self.getActivationFunction(self.activations[i])
z_s.append(self.weights[i].dot(a) + self.biases[i])
a = activation_function(z_s[-1])
a_s.append(a)
return (z_s, a_s)
这里激活函数,这个函数返回值是一个函数,在 python 用 lambda
来返回一个函数,这里简答留下一个伏笔,随后会对其进行修改。
@staticmethod
def getActivationFunction(name):
if(name == 'sigmoid'):
return lambda x : np.exp(x)/(1+np.exp(x))
elif(name == 'linear'):
return lambda x : x
elif(name == 'relu'):
def relu(x):
y = np.copy(x)
y[y<0] = 0
return y
return relu
else:
print('Unknown activation function. linear is used')
return lambda x: x
[@staticmethod]
def getDerivitiveActivationFunction(name):
if(name == 'sigmoid'):
sig = lambda x : np.exp(x)/(1+np.exp(x))
return lambda x :sig(x)*(1-sig(x))
elif(name == 'linear'):
return lambda x: 1
elif(name == 'relu'):
def relu_diff(x):
y = np.copy(x)
y[y>=0] = 1
y[y<0] = 0
return y
return relu_diff
else:
print('Unknown activation function. linear is used')
return lambda x: 1
反向传播
这是本次分享重点
def backpropagation(self,y, z_s, a_s):
dw = [] # dC/dW
db = [] # dC/dB
deltas = [None] * len(self.weights) # delta = dC/dZ 计算每一层的误差
# 最后一层误差
deltas[-1] = ((y-a_s[-1])*(self.getDerivitiveActivationFunction(self.activations[-1]))(z_s[-1]))
# 反向传播
for i in reversed(range(len(deltas)-1)):
deltas[i] = self.weights[i+1].T.dot(deltas[i+1])*(self.getDerivitiveActivationFunction(self.activations[i])(z_s[i]))
#a= [print(d.shape) for d in deltas]
batch_size = y.shape[1]
db = [d.dot(np.ones((batch_size,1)))/float(batch_size) for d in deltas]
dw = [d.dot(a_s[i].T)/float(batch_size) for i,d in enumerate(deltas)]
# 返回权重(weight)矩阵 and 偏置向量(biases)
return dw, db
首先计算最后一层误差根据 BP1 等式可以得到下面的式子
deltas[-1] = ((y-a_s[-1])*(self.getDerivitiveActivationFunction(self.activations[-1]))(z_s[-1]))
接下来基于上一层的 误差来计算当前层
batch_size = y.shape[1]
db = [d.dot(np.ones((batch_size,1)))/float(batch_size) for d in deltas]
dw = [d.dot(a_s[i].T)/float(batch_size) for i,d in enumerate(deltas)]
开始训练
def train(self, x, y, batch_size=10, epochs=100, lr = 0.01):
# update weights and biases based on the output
for e in range(epochs):
i=0
while(i<len(y)):
x_batch = x[i:i+batch_size]
y_batch = y[i:i+batch_size]
i = i+batch_size
z_s, a_s = self.feedforward(x_batch)
dw, db = self.backpropagation(y_batch, z_s, a_s)
self.weights = [w+lr*dweight for w,dweight in zip(self.weights, dw)]
self.biases = [w+lr*dbias for w,dbias in zip(self.biases, db)]
# print("loss = {}".format(np.linalg.norm(a_s[-1]-y_batch) ))
来源:https://juejin.cn/post/7113071079257538567


猜你喜欢
- python 利用pywifi模块实现连接网络破解wifi密码实时监控网络,具体内容如下:import pywififrom pywifi
- 前言学完语法和正在学习语法的时候,我们可以在空闲的时候,写几个简单的小项目,今天我们就用最基础的语法看两个实战语法练习猜数字游戏项目游戏说明
- 功能:返回字符、二进制、文本或图像表达式的一部分语法:SUBSTRING ( expression, start, length ) 1、s
- 示例1:pycallclass.cpp:#include <iostream>using namespace std;typed
- 前言Python是C语言实现的,因此Python对象在C语言层面应该是一个结构体 ,组织对象占用的内存。 不同类型的对象,数据及行为均可能不
- Python是一种面向对象的语言,但它不像C++一样把标准类都封装到库中,而是进行了进一步的封装,语言本身就集成一些类和函数,比如print
- 前言大家做开发的应该都知道,在开发程序中很重要的一点是测试,我们如何保证代码的质量,如何保证每个函数是可运行,运行结果是正确的,又如何保证写
- 在ASP的实际操作中,总会发生这样的情况,如在银行,从我的帐户往费文华的帐户划款,我的帐户显示已经划出,但因银行的系统出现故障,导致费文华帐
- Python3实现旋转数组的3种算法下面是Python3实现的旋转数组的3种算法。一、题目给定一个数组,将数组中的元素向右移动 k 个位置,
- 前端JS中使用XMLHttpRequest 2上传图片到服务器,PC端和大部分手机上都正常,但在少部分安卓手机上上传失败,服务器上查看图片,
- Go(又称Golang)是Google开发的一种静态强类型、编译型、并发型,并具有垃圾回收功能的编程语言。下载Go语言开发包大家可以在Go语
- 想使用正则表达式来获取一段文本中的任意字符,写出如下匹配规则: (.*) 结果运行之后才发现,无法获得换行之后的文本。于是查了一下手册,才发
- 问题背景两张表一张是用户表a(主键是int类型),一张是用户具体信息表b(用户表id字段是varchar类型)。因为要显示用户及用户信息,所
- 一、备份数据库1、打开SQL企业管理器,在控制台根目录中依次点开Microsoft SQL Server2、SQL Server组-->
- 本文实例讲述了python基于pygame实现响应游戏中事件的方法。分享给大家供大家参考,具体如下:先看一下我做的demo效果:当玩家按下键
- Microsoft Visual C++ 14.0 is required. Get it with “Microsof
- 一个写给别人的小代码顺便也贴上来这是一个滑动展示用的小容器通过鼠标移动和离开触发滑动效果<!DOCTYPE html PUBLIC &
- 一、读取Excel中的数据安装xlrd 只能读取Excel内容pip install xlrd==1.2.0xlrd库的open_workb
- 这是我上一篇关于安全的文章的其中一节。这是一个众所周知的事实,对你运行中的网站的MySQL数据库备份是极为重要的只需按照下面3步做,一切都在
- 概述在python中,以单下划线开头的(_a)的代表不能直接访问的类属性,需通过类提供的接口进行访问,不能用“from