python实现随机梯度下降(SGD)
作者:芳草碧连天lc 发布时间:2021-04-15 19:41:20
标签:python,梯度下降,SGD
使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):
def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
if test_data:
n_test = len(test_data)#有多少个测试集
n = len(training_data)
for j in xrange(epochs):
random.shuffle(training_data)
mini_batches = [
training_data[k:k+mini_batch_size]
for k in xrange(0,n,mini_batch_size)]
for mini_batch in mini_batches:
self.update_mini_batch(mini_batch, eta)
if test_data:
print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
else:
print "Epoch {0} complete".format(j)
其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。
def update_mini_batch(self, mini_batch,eta):
nabla_b = [np.zeros(b.shape) for b in self.biases]
nabla_w = [np.zeros(w.shape) for w in self.weights]
for x,y in mini_batch:
delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
#最终更新权重为
self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]
这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。
来源:http://blog.csdn.net/leichaoaizhaojie/article/details/56840328


猜你喜欢
- CreateOrUpdate 是业务开发中很常见的场景,我们支持用户对某个业务实体进行创建/配置。希望实现的 repository 接口要达
- transpose() 这个函数如果括号内不带参数,就相当于转置,和.T效果一样,而今天主要来讲解其带参数。我们看如下一个numpy的数组:
- 突然发现, pycharm 2020.2都出来了哈, 现在jetbrain团队对中文用户也比较友好, 比以前更加适合小白了再就是很多类似的教
- 1.首先安装 PyPDF2 库:pip install PyPDF22.然后保存下面文件(已带注释,具体实现请自己思考)import osi
- # -*- coding: UTF-8 -*-from __future__ import unicode_literalsimport I
- 题目:一个txt文件中已知数据格式为:C4DC4D/mayaC4DC4D/suC4D/max/AE统计每个字段出现的次数,比如C4D、may
- 问题产生:pycharm→settings→Project interpreter→下载matplotlib包运行代码,出现以下提示:找不到
- python绘制横向水平柱状条形图Bar,供大家参考,具体内容如下import matplotlibimport randomimport
- 这篇文章主要介绍了通过Kettle自定义jar包供javascript使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参
- 在 Go 语言中,map 是一种非常常见的数据类型,它可以用于快速地检索数据。Go 语言中的 map 与其他编程语言中的类似的数据类型相比,
- 私有变量表示方法在变量前加上两个下划线的是私有变量。class Teacher(): def __init__(self,nam
- 之前mysql用着好着,可是今天在启动mysql后输入密码出现了闪退,在任务管理器中发现mysql服务没有启动,当手动启动时提示拒绝访问。在
- 项目结构:运行效果:=========================================================代码部
- 1 Series线性的数据结构, series是一个一维数组Pandas 会默然用0到n-1来作为series的index, 但也可以自己指
- 中国,美国,英国3国时间js同步动态显示,对于做企业网站的朋友相信用的到,特别是做英文网站的朋友,加上这一段代码会给你的网站增色不少!本文j
- Python import .pyd文件时会搜索sys.path列表中的路径运行import xxx.pyd1. 'ImportEr
- 本文实例讲述了PHP中Static(静态)关键字功能与用法。分享给大家供大家参考,具体如下:1、什么是static?static 是C++中
- 整理文档,搜刮出一个使用Vue.Js结合Jquery Ajax加载数据的两种方式的代码,稍微整理精简一下做下分享。废话不多说,直接上代码ht
- 最近看ECShop到网上找资料,发现好多说明ECShop的文件结构不全面,于是想自己弄个出来。但这是个无聊耗时的工作,自己就写了个Pytho
- 在PHP中,我们不能用const直接定义数组常量,但是const可以定义字符串常量,结合eval()函数使字符串常量能执行。所以,我们可以用