python使用梯度下降和牛顿法寻找Rosenbrock函数最小值实例
作者:SpringHerald 发布时间:2022-09-10 20:01:20
标签:python,梯度下降,牛顿法,Rosenbrock
Rosenbrock函数的定义如下:
其函数图像如下:
我分别使用梯度下降法和牛顿法做了寻找Rosenbrock函数的实验。
梯度下降
梯度下降的更新公式:
图中蓝色的点为起点,橙色的曲线(实际上是折线)是寻找最小值点的轨迹,终点(最小值点)为 (1,1)(1,1)。
梯度下降用了约5000次才找到最小值点。
我选择的迭代步长 α=0.002α=0.002,αα 没有办法取的太大,当为0.003时就会发生振荡:
牛顿法
牛顿法的更新公式:
Hessian矩阵中的每一个二阶偏导我是用手算算出来的。
牛顿法只迭代了约5次就找到了函数的最小值点。
下面贴出两个实验的代码。
梯度下降:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import ticker
def f(x, y):
return (1 - x) ** 2 + 100 * (y - x * x) ** 2
def H(x, y):
return np.matrix([[1200 * x * x - 400 * y + 2, -400 * x],
[-400 * x, 200]])
def grad(x, y):
return np.matrix([[2 * x - 2 + 400 * x * (x * x - y)],
[200 * (y - x * x)]])
def delta_grad(x, y):
g = grad(x, y)
alpha = 0.002
delta = alpha * g
return delta
# ----- 绘制等高线 -----
# 数据数目
n = 256
# 定义x, y
x = np.linspace(-1, 1.1, n)
y = np.linspace(-0.1, 1.1, n)
# 生成网格数据
X, Y = np.meshgrid(x, y)
plt.figure()
# 填充等高线的颜色, 8是等高线分为几部分
plt.contourf(X, Y, f(X, Y), 5, alpha=0, cmap=plt.cm.hot)
# 绘制等高线
C = plt.contour(X, Y, f(X, Y), 8, locator=ticker.LogLocator(), colors='black', linewidth=0.01)
# 绘制等高线数据
plt.clabel(C, inline=True, fontsize=10)
# ---------------------
x = np.matrix([[-0.2],
[0.4]])
tol = 0.00001
xv = [x[0, 0]]
yv = [x[1, 0]]
plt.plot(x[0, 0], x[1, 0], marker='o')
for t in range(6000):
delta = delta_grad(x[0, 0], x[1, 0])
if abs(delta[0, 0]) < tol and abs(delta[1, 0]) < tol:
break
x = x - delta
xv.append(x[0, 0])
yv.append(x[1, 0])
plt.plot(xv, yv, label='track')
# plt.plot(xv, yv, label='track', marker='o')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Gradient for Rosenbrock Function')
plt.legend()
plt.show()
牛顿法:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import ticker
def f(x, y):
return (1 - x) ** 2 + 100 * (y - x * x) ** 2
def H(x, y):
return np.matrix([[1200 * x * x - 400 * y + 2, -400 * x],
[-400 * x, 200]])
def grad(x, y):
return np.matrix([[2 * x - 2 + 400 * x * (x * x - y)],
[200 * (y - x * x)]])
def delta_newton(x, y):
alpha = 1.0
delta = alpha * H(x, y).I * grad(x, y)
return delta
# ----- 绘制等高线 -----
# 数据数目
n = 256
# 定义x, y
x = np.linspace(-1, 1.1, n)
y = np.linspace(-1, 1.1, n)
# 生成网格数据
X, Y = np.meshgrid(x, y)
plt.figure()
# 填充等高线的颜色, 8是等高线分为几部分
plt.contourf(X, Y, f(X, Y), 5, alpha=0, cmap=plt.cm.hot)
# 绘制等高线
C = plt.contour(X, Y, f(X, Y), 8, locator=ticker.LogLocator(), colors='black', linewidth=0.01)
# 绘制等高线数据
plt.clabel(C, inline=True, fontsize=10)
# ---------------------
x = np.matrix([[-0.3],
[0.4]])
tol = 0.00001
xv = [x[0, 0]]
yv = [x[1, 0]]
plt.plot(x[0, 0], x[1, 0], marker='o')
for t in range(100):
delta = delta_newton(x[0, 0], x[1, 0])
if abs(delta[0, 0]) < tol and abs(delta[1, 0]) < tol:
break
x = x - delta
xv.append(x[0, 0])
yv.append(x[1, 0])
plt.plot(xv, yv, label='track')
# plt.plot(xv, yv, label='track', marker='o')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Newton\'s Method for Rosenbrock Function')
plt.legend()
plt.show()
来源:https://blog.csdn.net/bu2_int/article/details/81333737


猜你喜欢
- 起步在django框架中,用的是 pytz 库处理时区问题,所以我也尝试用这个库来处理。但发现了一个奇怪的问题:import datetim
- 解决办法: 1.新建一个同名的数据库(数据文件与原来的要一致) 2.再停掉sql server(注意不要分离数据库) 3.用原数据库的数据文
- 提示: 利用单表简单查询和多表高级查询技能,并且根据查询要求灵活使用内连接查询、外连接查询或子查询等。同时还利用内连接查询的两种格式、三种外
- 引言目前Python2和Python3存在版本上的不兼容性,这里将列举dict中的问题之一。下面话不多说,来看看详细的介绍:1. Pytho
- 今天学习CI框架过程中遇到个问题: A PHP Error was encountered Severity: Notice Message
- 1. TokuFT file system space is really low and access is restricte
- 难道真的要我破解一个么?算了,正好试试我的Python水平。 python版 #coding: gbk import httplib, ur
- 可扩展标记语言 (XML) 是用于描述数据集内容以及应如何将数据输出到设备上或如何在 Web 页上显示数据的语言。标记语言的创建来源于出版商
- 尽量避免使用DOM。当需要反复使用DOM时,先把对DOM的引用存到JavaScript本地变量里再使用。使用设置innerHTML的方法来替
- 敲了一个错误的mysql命令, 想取消怎么办? 如果用ctrl + c, 就直接退出了。怎么办呢?来看看:mysql> show ta
- 本文实例为大家分享了Vue+Websocket简单实现聊天功能的具体代码,供大家参考,具体内容如下效果图:聊天室此篇文章是针对Websock
- JavaScript中的标识符的命名有以下规则:由字母、数字、$、_组成以字母、$、_开头不可以使用保留字!!!要有意义!!!!!!!标识符
- PHP asXML()函数实例格式化 XML(版本 1.0)中的 SimpleXML 对象的数据:<?php $note=<&l
- MySQL Version确认(版本确认)的几个方法1.SHOW VARIABLES LIKE 'VERSION';mysq
- 一旦获得MySQL服务器的连接,需要选择一个特定的数据库工作。这是因为MySQL服务器可能有一个以上的数据库。从命令提示符,选择MySQL数
- 问题你想创建一个内嵌变量的字符串,变量被它的值所表示的字符串替换掉。解决方案Python并没有对在字符串中简单替换变量值提供直接的支持。 但
- upload.htm <html><head><title>网站维护 -
- 很多人会把Primary Key和聚集索引搞混起来,或者认为这是同一个东西。这个概念是非常错误的。 主键是一个约束(constraint),
- # -*- coding: utf-8 -*-import sysimport MySQLdbreload(sys)sys.setdefau
- 在数据传递时,需要先编解码;常用的方式是JSON编解码(参见《golang之JSON处理》)。但有时却需要读取部分字段后,才能知道具体类型,