pytorch中的优化器optimizer.param_groups用法
作者:我是天才很好 发布时间:2023-02-12 08:10:08
标签:pytorch,优化器
optimizer.param_groups
: 是长度为2的list,其中的元素是2个字典;
optimizer.param_groups[0]
: 长度为6的字典,包括[‘amsgrad', ‘params', ‘lr', ‘betas', ‘weight_decay', ‘eps']这6个参数;
optimizer.param_groups[1]
: 好像是表示优化器的状态的一个字典;
import torch
import torch.optim as optimh2
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0}]
Per the docs, the add_param_group method accepts a param_group parameter that is a dict. Example of use:h2import torch
import torch.optim as optimh2
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)
givesh2[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0}]
nowh2o.add_param_group({'params': w2})
print(o.param_groups)
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0},
{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[-0.0560, 0.4585, -0.7589],
[-0.1994, 0.4557, 0.5648],
[-0.1280, -0.0333, -1.1886]])],
'weight_decay': 0}]
# 动态修改学习率
for param_group in optimizer.param_groups:
param_group["lr"] = lr
# 得到学习率optimizer.param_groups[0]["lr"] h2# print('查看optimizer.param_groups结构:')
# i_list=[i for i in optimizer.param_groups[0].keys()]
# print(i_list)
['amsgrad', 'params', 'lr', 'betas', 'weight_decay', 'eps']
补充:pytorch中的优化器总结
以SGD优化器为例:
# -*- coding: utf-8 -*-
#@Time :2019/7/3 22:31
#@Author :XiaoMa
from torch import nn as nn
import torch as t
from torch.autograd import Variable as V
#定义一个LeNet网络
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.features=nn.Sequential(
nn.Conv2d(3,6,5),
nn.ReLU(),
nn.MaxPool2d(2,2),
nn.Conv2d(6,16,5),
nn.ReLU(),
nn.MaxPool2d(2,3)
)
self.classifier=nn.Sequential(\
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
def forward(self, x):
x=self.features(x)
x=x.view(-1,16*5*5)
x=self.classifier(x)
return x
net=Net()
from torch import optim #优化器
optimizer=optim.SGD(params=net.parameters(),lr=1)
optimizer.zero_grad() #梯度清零,相当于net.zero_grad()
input=V(t.randn(1,3,32,32))
output=net(input)
output.backward(output) #fake backward
optimizer.step() #执行优化
#为不同子网络设置不同的学习率,在finetune中经常用到
#如果对某个参数不指定学习率,就使用默认学习率
optimizer=optim.SGD(
[{'param':net.features.parameters()}, #学习率为1e-5
{'param':net.classifier.parameters(),'lr':1e-2}],lr=1e-5
)
#只为两个全连接层设置较大的学习率,其余层的学习率较小
special_layers=nn.ModuleList([net.classifier[0],net.classifier[3]])
special_layers_params=list(map(id,special_layers.parameters()))
base_params=filter(lambda p:id(p) not in special_layers_params,net.parameters())
optimizer=t.optim.SGD([
{'param':base_params},
{'param':special_layers.parameters(),'lr':0.01}
],lr=0.001)
调整学习率主要有两种做法。
一种是修改optimizer.param_groups中对应的学习率,另一种是新建优化器(更简单也是更推荐的做法),由于optimizer十分轻量级,构建开销很小,故可以构建新的optimizer。
但是新建优化器会重新初始化动量等状态信息,这对使用动量的优化器来说(如自带的momentum的sgd),可能会造成损失函数在收敛过程中出现震荡。
如:
#调整学习率,新建一个optimizer
old_lr=0.1
optimizer=optim.SGD([
{'param':net.features.parameters()},
{'param':net.classifiers.parameters(),'lr':old_lr*0.5}],lr=1e-5)
来源:https://wstchhwp.blog.csdn.net/article/details/108490956
0
投稿
猜你喜欢
- Oracle是世界上用得最多的数据库之一,活动服务器网页(ASP)是一种被广泛用于创建 * 页的功能强大的服务器端脚本语言。许多ASP开发人
- 这个问题困扰了我很长很长的时间,在跨域获取数据的时候就要用到服务器端的对象,以前一直用的是Msxml.XMLHTTP。但是问题太多了,特别严
- 首先数据库里需要有一个自动编号字段(ID)。然后第一次访问的时候,取出所有记录,定制好每页的记录数PageSize,计算出页数,然后根据页数
- 编号标准宗地编码(landCode)所在区段编码(sectCode)1131001BG001G0012131001BG002G0013131
- 很多文章都有提到关于使用phpExcel实现Excel数据的导入导出,大部分文章都差不多,或者就是转载的,都会出现一些问题,下面是本人研究p
- 本文实例为大家分享了React实现表格选取的具体代码,供大家参考,具体内容如下在工作中,遇到一个需求,在表格中实现类似于Excel选中一片区
- 数据库复制:简单来说,数据库复制就是由两台服务器,主服务器和备份服务器,主服务器修改后,备份服务器自动修改。复制的模式有两种:推送模式和请求
- 1.弹启一个全屏窗口 <html> <body onload="win
- 在python中,它也有这个含义,不过有点区别的是,“当...时候”这个条件成立在一段范围或者时间间隔内,从而在这段时间间隔内让python
- 对象Javascript 根本上是和对象相关的。数组是对象。函数是对象。对象是对象。那什么是对象呢?对象是名-值对的集合。名是字符串,值可以
- 使用ewebeditor作为后台编辑器时,尤其是一个页面中使用多次该编辑器时,在提交数据时,可能会遇到数据被重复提交的情况。搜索找来一些解决
- 用下列代码判断表单提交到服务器的数据是否有谈话内容,如果没有的话就不作处理了:if len(usersays)<>0&
- asp如何获知页面上的图象的实际尺寸大小?见下面的两个asp文件:<!--#include virtual="/i
- 导语大家好,我是栗子同学!今天给大家分享一个好玩的东西让时光倒流——当当当,其实就是让视频倒放而已正
- <P><HTML><HEAD><TITLE>javascriptboy</TITLE&
- 在用 Javascript 验证表单(form)中的单选框(radio)是否选中时,很多新手都会遇到问题,原因是 radio 和普通的文本框
- 清除浮动一个凡是做页面的人都会遇到的一个东西,但是是否大家都能够清楚的知道,全方位的了解呢?于是一闲下来了马上写了这样的一篇文章,不能讲面面
- 有个文本文件,需要替换里面的一个词,用python来完成,我是这样写的:def modify_text(): with open('
- 如下所示:# coding:utf-8import shapefilew = shapefile.Writer()w.autoBalance
- 1、引言小丝:鱼哥, 请教你个问题。小鱼:你觉得你得问题,是正儿八经的吗?小丝:那必须的, 人都正经,何况问题呢?小鱼:那可不敢说, 你得问