在pytorch中动态调整优化器的学习率方式
作者:FesianXu 发布时间:2022-08-14 00:30:57
标签:pytorch,优化器,学习率
在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。
一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:
step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
for params_group in sgd_opt.param_groups:
params_group['lr'] = lr
return lr
只需要在每个train的epoch之前使用这个函数即可。
for epoch in range(60):
model.train()
adjust_lr(epoch)
for ind, each in enumerate(train_loader):
mat, label = each
...
补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取
需要调用的模块及整体Bi-lstm流程
import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
def __init__(self,d_model,embedding_matrix):
super(word_extract, self).__init__()
self.d_model=d_model
self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
self.embedding.weight.data.copy_(embedding_matrix)
self.embedding.weight.requires_grad=False
self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
self.linear=nn.Linear(2*self.lstm2.hidden_size,4)
def forward(self,x):
w_x=self.embedding(x)
first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
output_x=self.linear(second_x)
return output_x
将文本转换为数值形式
def trans_num(word2idx,text):
text_list=[]
for i in text:
s=i.rstrip().replace('\r','').replace('\n','').split(' ')
numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
text_list.append(numtext)
return text_list
将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中
def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
num2idx = {0: "_PAD"}
vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]
# 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
for i in range(len(vocab_list)):
word = vocab_list[i][0]
word2idx[word] = i + 1
num2idx[i + 1] = word
embeddings_matrix[i + 1] = vocab_list[i][1]
embeddings_matrix = torch.Tensor(embeddings_matrix)
return embeddings_matrix, word2idx, num2idx
训练过程
def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
optimizor = optim.Adam(model.parameters(), lr=learning_rate)
data = TensorDataset(x, y)
data = DataLoader(data, batch_size=batch_size)
for i in range(epoch):
for j, (per_x, per_y) in enumerate(data):
output_y = model(per_x)
loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
optimizor.zero_grad()
loss.backward()
optimizor.step()
arg_y=output_y.argmax(dim=2)
fit_correct=(arg_y==per_y).sum()
fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
print('##################################')
print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
val_output_y = model(val_x)
val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
arg_val_y=val_output_y.argmax(dim=2)
val_correct=(arg_val_y==val_y).sum()
val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
torch.save(model,'./extract_model.pkl')#保存模型
主函数部分
if __name__ =='__main__':
#生成词向量矩阵
word2vec = gensim.models.Word2Vec.load('./word2vec_model')
embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
#
train_data=pd.read_csv('./数据.csv')
x=list(train_data['文本'])
# 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
x=trans_num(word2idx,x)
#x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
# #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
#输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
#填充代码你自行编写,以下部分是针对我的数据集
x=keras.preprocessing.sequence.pad_sequences(
x,maxlen=60,value=0,padding='post',
)
y=list(train_data['BIO数值'])
y_text=[]
for i in y:
s=i.rstrip().split(' ')
numtext=[int(j) for j in s]
y_text.append(numtext)
y=y_text
y=keras.preprocessing.sequence.pad_sequences(
y,maxlen=60,value=3,padding='post',
)
# 将数据进行划分
fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
fit_x=torch.LongTensor(fit_x)
fit_y=torch.LongTensor(fit_y)
val_x=torch.LongTensor(val_x)
val_y=torch.LongTensor(val_y)
#开始应用
w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
pred_val_y=w_extract(val_x).argmax(dim=2)
来源:https://blog.csdn.net/LoseInVain/article/details/87858408


猜你喜欢
- 很多朋友在论坛和留言区域问mysql在什么情况下才需要进行分库分表,以及采用何种设计方式才是最优的选择,根据这些问题,小编为大家整理了关于M
- 前言当需要将多张图像拼接成一张更大的图像时,通常会用到图片拼接技术。这种技术在许多领域中都有广泛的应用,例如计算机视觉、图像处理、卫星图像、
- 事先在网上搜索了一大圈,头都大了,看到那么多文章写道在python里安装psycopg2的各种坑和各种麻烦,各种不成功。搜索了一下午,索性外
- 一,啥是Block Formatting Context当涉及到可视化布局的时候,Block Formatting Context提供了一个
- virtualenv与virtualenvwrapper当涉及到python项目开发时为了不污染全局环境,通常都会使用环境隔离管理工具vir
- 网上的很多PHP微信支付接入教程都颇为复杂,且需要配置和引入较多的文件,本人通过整理后给出一个单文件版的,希望可以给各位想接入微信支付的带来
- 在python中,如下代码结果一定不会让你吃惊:Python 3.3.2 (v3.3.2:d047928ae3f6, May 16 2013
- 一,最常见MYSQL最基本的分页方式:select * from content order by id desc limit 0, 10在
- 在知乎上看到这样一个问题:MySQL 查询 select * from table where id in (几百或几千个 id) 如何提高
- kruskal算法基本思路:先对边按权重从小到大排序,先选取权重最小的一条边,如果该边的两个节点均为不同的分量,则加入到最小生成树,否则计算
- 安装python分三个步骤:*下载python*安装python*检查是否安装成功1、下载Python(1)python下载地址https:
- 最近写了一些python3程序,四处能看到bytes类型,而它并不存在于python2中,这也是python3和python2显著区别之一。
- 在上一期python numpy 模块中对概述介绍了numpy 模块安装、使用方法、特点等入门知识。numpy 模块是一个开源的第三方Pyt
- 目录1. 字符串的翻转2. 判断字符串是不是回文串3. 单词大小写4. 字符串的拆分5. 字符串的合并6. 将元素进行重复7. 列表的拓展8
- 跨域当我们遇到请求后台接口遇到 Access-Control-Allow-Origin 时,那说明跨域了。跨域是因为浏览器的同源策略所导致,
- 之前使用django+mysql建立的一个站点,发现向数据库中写入中文字符时总会报错,尝试了修改settings文件和更改数据表的字符集后仍
- django路由和视图要了解django是如何运行的,首先要了解路由和视图两个概念,然后我们在项目中添加一些简单的路由和视图路由和视图的概念
- 把 Oracle 数据库从 RAC 集群迁移到单机环境一、系统环境1、源数据库db_name:hisdb SID:hisdb1、
- mysql时间序列间隔查询在时间序列处理中,采集到的数据保存在数据表中,采集的频率可能是固定间隔(10秒,1小时或者1天),但往往是不固定的
- 这篇文章主要介绍了Python安装whl文件过程图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友