PyTorch策略梯度算法详情
作者:??盼小辉丶??? 发布时间:2022-12-20 14:35:12
0. 前言
本节中,我们使用策略梯度算法解决 CartPole
问题。虽然在这个简单问题中,使用随机搜索策略和爬山算法就足够了。但是,我们可以使用这个简单问题来更专注的学习策略梯度算法,并在之后的学习中使用此算法解决更加复杂的问题。
1. 策略梯度算法
策略梯度算法通过记录回合中的所有时间步并基于回合结束时与这些时间步相关联的奖励来更新权重训练智能体。使智能体遍历整个回合然后基于获得的奖励更新策略的技术称为蒙特卡洛策略梯度。
在策略梯度算法中,模型权重在每个回合结束时沿梯度方向移动。关于梯度的计算,我们将在下一节中详细解释。此外,在每一时间步中,基于当前状态和权重计算的概率得到策略,并从中采样一个动作。与随机搜索和爬山算法(通过采取确定性动作以获得更高的得分)相反,它不再确定地采取动作。因此,策略从确定性转变为随机性。例如,如果向左的动作和向右的动作的概率为 [0.8,0.2]
,则表示有 80%
的概率选择向左的动作,但这并不意味着一定会选择向左的动作。
2. 使用策略梯度算法解决CartPole问题
在本节中,我们将学习使用 PyTorch
实现策略梯度算法了。 导入所需的库,创建 CartPole
环境实例,并计算状态空间和动作空间的尺寸:
import gym
import torch
import matplotlib.pyplot as plt
env = gym.make('CartPole-v0')
n_state = env.observation_space.shape[0]
print(n_state)
n_action = env.action_space.n
print(n_action)
定义 run_episode
函数,在此函数中,根据给定输入权重的情况下模拟一回合 CartPole
游戏,并返回奖励和计算出的梯度。在每个时间步中执行以下操作:
根据当前状态和输入权重计算两个动作的概率
probs
根据结果概率采样一个动作
action
以概率作为输入计算
softmax
函数的导数d_softmax
,由于只需要计算与选定动作相关的导数,因此:
\frac {\partial p_i} {\partial z_j} = p_i(1-p_j), i=j∂zj∂pi=pi(1−pj),i=j
将所得的导数
d_softmax
除以概率probs
,以得与策略相关的对数导数d_log
根据链式法则计算权重的梯度
grad
:
\frac {dy}{dx}=\frac{dy}{du}\cdot\frac{du}{dx}dxdy=dudy⋅dxdu
记录得到的梯度
grad
执行动作,累积奖励并更新状态
def run_episode(env, weight):
state = env.reset()
grads = []
total_reward = 0
is_done = False
while not is_done:
state = torch.from_numpy(state).float()
# 根据当前状态和输入权重计算两个动作的概率 probs
z = torch.matmul(state, weight)
probs = torch.nn.Softmax(dim=0)(z)
# 根据结果概率采样一个动作 action
action = int(torch.bernoulli(probs[1]).item())
# 以概率作为输入计算 softmax 函数的导数 d_softmax
d_softmax = torch.diag(probs) - probs.view(-1, 1) * probs
# 计算与策略相关的对数导数d_log
d_log = d_softmax[action] / probs[action]
# 计算权重的梯度grad
grad = state.view(-1, 1) * d_log
grads.append(grad)
state, reward, is_done, _ = env.step(action)
total_reward += reward
if is_done:
break
return total_reward, grads
回合完成后,返回在此回合中获得的总奖励以及在各个时间步中计算的梯度信息,用于之后更新权重。
接下来,定义要运行的回合数,在每个回合中调用 run_episode
函数,并初始化权重以及用于记录每个回合总奖励的变量:
n_episode = 1000
weight = torch.rand(n_state, n_action)
total_rewards = []
在每个回合结束后,使用计算出的梯度来更新权重。对于回合中的每个时间步,权重都根据学习率、计算出的梯度和智能体在剩余时间步中的获得的总奖励进行更新。
我们知道在回合终止之前,每一时间步的奖励都是 1
。因此,我们用于计算每个时间步策略梯度的未来奖励是剩余的时间步数。在每个回合之后,我们使用随机梯度上升方法将梯度乘以未来奖励来更新权重。这样,一个回合中经历的时间步越长,权重的更新幅度就越大,这将增加获得更大总奖励的机会。我们设定学习率为 0.001
:
learning_rate = 0.001
for e in range(n_episode):
total_reward, gradients = run_episode(env, weight)
print('Episode {}: {}'.format(e + 1, total_reward))
for i, gradient in enumerate(gradients):
weight += learning_rate * gradient * (total_reward - i)
total_rewards.append(total_reward)
然后,我们计算通过策略梯度算法获得的平均总奖励:
print('Average total reward over {} episode: {}'.format(n_episode, sum(total_rewards)/n_episode))
我们可以绘制每个回合的总奖励变化情况,如下所示:
plt.plot(total_rewards)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.show()
在上图中,我们可以看到奖励会随着训练回合的增加呈现出上升趋势,然后能够在最大值处稳定。我们还可以看到,即使在收敛之后,奖励也会振荡,这是由于策略梯度算法是一种随机策略算法。
最后,我们查看学习到策略在 1000
个新回合中的性能表现,并计算平均奖励:
n_episode_eval = 1000
total_rewards_eval = []
for e in range(n_episode_eval):
total_reward, _ = run_episode(env, weight)
print('Episode {}: {}'.format(e+1, total_reward))
total_rewards_eval.append(total_reward)
print('Average total reward over {} episode: {}'.format(n_episode_eval, sum(total_rewards_eval)/n_episode_eval))
# Average total reward over 1000 episode: 200
进行测试后,可以看到回合的平均奖励接近最大值 200
。可以多次测试训练后的模型,得到的平均奖励较为稳定。正如我们一开始所说的那样,对于诸如 CartPole
之类的简单环境,策略梯度算法可能大材小用,但它为我们解决更加复杂的问题奠定了基础。
来源:https://juejin.cn/post/7118954918198640654


猜你喜欢
- 前言在Django中有大量的通用类视图,例如ListView,DetailView,CreateView,UpdateView等等,将所有重
- 前言最近在看element-ui的源码,发现了一个这样的属性:inject.遂查看官网provider/injectprovider/inj
- 发现这个也是偶然,在测试的时候发现的,因此问题还发现一个bug。蛮有意思~ 假如输入http://www.aspxhome.com的话,在
- Python装饰器用法Python的装饰器是个好东西,它能干很多事情。但对于新手,它看起来似乎没那么简单。但事实上,装饰器本身也只是个函数。
- Python input 等待键盘输入,超时选择默认值,释放input,之后重新进入等待键盘输入状态,直到用户输入可用数据。一、调用 fun
- 相对于 Ajax,服务端 XMLHTTP 就是在服务端使用 XMLHttpRequest 对象了。虽然说,在服务端使用异步请求是比较不方便的
- 开放平台的API接口调用需要限制其频率,以节约服务器资源和避免恶意的频繁调用使用自定义频率限制组件:utils/thottle.pyclas
- 前言最近在B站上看到一个漂亮的仙女姐姐跳舞视频,循环看了亿遍又亿遍,久久不能离开!看着仙紫小姐姐的蹦迪视频,除了一键三连还能做什么?突发奇想
- 看到这个先思考,自己怎么输出他?为什么它有颜色?特殊符号去哪找?特殊符号在符号大全找 符号大全http://www.fhdq.net/任务1
- 目录一、_用于临时变量1.1 REPL1.2 for循环中的_1.3 元组拆包中的_1.4 国际化函数1.5 大数字表示形式二、var_用于
- Python获取当前时间_获取格式化时间:Python获取当前时间:使用 time.time( ) 获取到距离1970年1月1日的秒数(浮点
- 代码如下:---这是一个人事系统中的示例,要求记录一下员工的缺勤情况 ---1.要在表中记录一下缺勤计分,是对经常缺勤者的一种处
- 由于新云CMS系统,网站底部“版权信息”字段在数据库中是“文本”类型,有250个字符的限制。想在这里给加网站统计代码,因为字数限制的原因,就
- 结合网上的资料,自己亲自的去安装了一次MySQL,安装版本是win7x64 5.7.16。在安装过程中出现并解决了如下问题:1.“MySQL
- 前言我们在使用vue-cli启动项目的时候npm run dev便可以启动我们的项目了,通常我们的请求地址是以localhost:8080来
- 登录、注销和登录限制:登录在使用authenticate进行验证后,如果验证通过了。那么会返回一个user对象,拿到user对象后,可以使用
- 方案一func md5V(str string) string { h := md5.New() &n
- 一个SELECT查询中的LIKE语句来执行这种查询,尽管这种方法可行,但对于全文查找而言,这是一种效率极端低下的方法,尤其在处理大量数据的时
- 目录1. 反向引用_命名分组2. 正则函数小提示:总结1. 反向引用_命名分组# ### 反向引用import restrvar = &qu
- mysql取json字符串字段下的某个键的值要求:mysql版本5.7及以上SELECT JSON_EXTRACT('{"