对pytorch中不定长序列补齐的操作
作者:XJTU-Qidong 发布时间:2022-03-24 17:33:04
第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的.
以下给出两种思路:
第一种思路是比较容易想到的, 就是对一个batch的样本进行遍历, 然后使用np.pad对每一个样本进行补齐.
for unit in data:
mask = np.zeros(max_length)
s_len = len(unit[0]) # calculate the length of sequence in each unit
mask[: s_len] = 1
unit[0] = np.pad(unit[0], (0, max_length - s_len), 'constant', constant_values=(0, 0))
mask_batch.append(mask)
但是这种方法在batch size很大的情况下会很慢, 因为使用for循环进行了遍历. 我在实际用的时候, 当batch_size=128时, 一个batch的加载时间甚至是一个batch训练时间的几倍!
因此, 我想到如何并行地对序列进行补齐. 第二种方法的思路就是使用torch中自带的pad_sequence来并行补齐.
batch_sequence = list(map(lambda x: torch.tensor(x[findex]), x_data))
batch_data[feat] = torch.nn.utils.rnn.pad_sequence(batch_sequence).T
可以看到这里使用pad_sequence一次性对整个batch进行补齐. 下面对这个函数进行详细说明.
pad_sequence详解
from torch.utils.rnn import pad_sequence
a = torch.ones(10)
b = torch.ones(6)
c = torch.ones(20)
abc = pad_sequence([a,b,c]) # shape(20, 3)
注意这个函数接收的是一个元素为tensor的列表, 而不是tensor.
最终, 这个函数会将所有tensor转换为tensor矩阵#shape(max_length, batch_size). 因此, 在使用完后通常还需要转置一下.
补充:PyTorch中用于RNN变长序列填充函数的简单使用
1、PyTorch中RNN变长序列的问题
RNN在处理变长序列时有它的优势。在分批处理变长序列问题时,每个序列的长度往往不会完全相等,因此针对一个batch中序列长度不一的情况,需要对某些序列进行PAD(填充)操作,使得一个batch内的序列长度相等。
PyTorch中的pack_padded_sequence和pad_packed_sequence可处理上述问题,以下用一个示例演示这两个函数的简单使用方法。
2、填充函数简介
“压缩”函数:用于将填充后的序列tensor进行压缩,方便RNN处理
pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
(1)input->被“压缩”的tensor,维度一般为[batch_size,_max_seq_len[,embedding_size]]或者[max_seq_len,batch_size[,embedding_size]]
若input维度为:[batch_size,_max_seq_len[,embedding_size]]
要将batch_first设置为True,这表示input的第一个维度为batch的数量
若input维度为:[max_seq_len,batch_size[,embedding_size]]
要将batch_first设置为False(默认值),这表示input的第一个维度不是batch的数量
(2)lengths->lengths参数表示一个batch中序列真实长度,类型为列表,在例子中详细说明
(3)batch_first->表示batch的数量是否在input的第一维度,默认值为False
(4)enforce_sorted->input中的会自动按照lengths的情况进行排序,默认值为
“解压”函数:该函数与"压缩函数"相对应,经“压缩函数”处理的输入经过RNN得到的最终结果可以利用该函数进行“解压”
pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
(1)sequence->压缩函数处理过的input经RNN后得到的结果
(2)batch_first->与“压缩”函数中的batch_first一致
(3)padding_value->序列进行填充时使用的索引,默认为0
(4)total_length->暂略
3、PyTorch代码示例
代码如下(示例):
# Create by leslie_miao on 2020/11/1
import torch
import torch.nn as nn
d_model = 10 # 词嵌入的维度
hidden_size = 20 # lstm隐藏层单元数量
layer_num = 1 # lstm层数
# 输入inputs,维度为[batch_size,max_seq_len]=[3,4],其中0代表填充
# 该input包含3个序列,每个序列的真实长度分别为: 4 3 2
inputs = torch.tensor([[1,2,3,4],[1,2,3,0],[1,2,0,0]])
embedding = nn.Embedding(5,d_model)
# 获取词嵌入后的inputs 当前inputs的维度为[batch_size,max_seq_len,d_model]=[3,4,10]
inputs = embedding(inputs)
# 查看inputs的维度
print(inputs.size())
# print: torch.Size([3, 4, 10])
# 利用“压缩”函数对inputs进行压缩处理,[4,3,2]分别为inputs中序列的真实长度,batch_first=True表示inputs的第一维是batch_size
inputs = nn.utils.rnn.pack_padded_sequence(inputs,lengths=[4,3,2],batch_first=True)
# 查看经“压缩”函数处理过的inputs的维度
print(inputs[0].size())
# print: torch.Size([9, 10])
# 定义RNN网络
network = nn.LSTM(input_size=d_model,hidden_size=hidden_size,batch_first=True,num_layers=layer_num)
# 初始化RNN相关门参数
c_0 = torch.zeros((layer_num,3,hidden_size))
h_0 = torch.zeros((layer_num,3,hidden_size)) # [rnn层数,batch_size,hidden_size]
# inputs经过RNN网络后得到的结果outputs
output,(h_n,c_n) = network(inputs,(h_0,c_0))
#查看未经“解压函数”处理的outputs维度
print(output[0].size())
# print: torch.Size([9, 20])
# 利用“解压函数”对outputs进行解压操作,其中batch_first设置与“压缩函数相同”,padding_value为0
output = nn.utils.rnn.pad_packed_sequence(output,batch_first=True,padding_value=0)
# 查看经“解压函数”处理的outputs维度
print(output[0].size())
# print:torch.Size([3, 4, 20])
来源:https://blog.csdn.net/dong_liuqi/article/details/114670932
猜你喜欢
- 场景游戏里有很多关卡(可能有几百个了),理论上每次发布到外网前都要遍历各关卡看看会不会有异常,上次就有玩家在打某个关卡时卡住不动了,如果每个
- 一、Pytest概念Pytest 是 Python 的一种单元测试框架,与 Python 自带的 unittest 测试框架类似,但是比 u
- <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN&
- 顺序执行顺序执行是我们比较熟悉的工作模式,类似俗称流水账编程。所有不含分支、循环和goto语言,并且每一递归调用的Go函数一般都是顺序执行的
- 我就废话不多说了,大家还是直接看代码吧!import matplotlib.pyplot as pltimport numpy as npf
- 一维线性拟合数据为y=4x+5加上噪音结果:import numpy as npfrom mpl_toolkits.mplot3d impo
- 一 Process对象的join方法在主进程运行过程中如果想并发地执行其他的任务,我们可以开启子进程,此时主进程的任务与子进程的任务分两种情
- 要实现标题的功能,总共分四步:1.创建html错误页2.配置settings3.编写视图4.配置url我的开发环境:django1.10.3
- 对一个有向无环图(Directed Acyclic Graph简称DAG)G进行拓扑排序,是将G中所有顶点排成一个线性序列,使得图中任意一对
- 创建表书籍模型: 书籍有书名和出版日期,一本书可能会有多个作者,一个作者也可以写多本书,所以作者和书籍的关系就是多对多的关联关系(many-
- 发送端可以不停的发送新文件,接收端可以不停的接收新文件。例如:发送端输入:e:\visio.rar,接收端会默认保存为 e:\new_vis
- Eric A. Meyer 对基于 Web 标准的 CSS 与 HTML 绝非一知半解,他是这个领域杰出的专家,曾写过不少 CSS 方面的书
- 接触Python时间不长,对有些知识点,掌握的不是很扎实,我个人比较崇尚不管学习什么东西,首先一定回去把基础打的非常扎实了,再往高处走。今天
- 本文实例讲述了Go语言中的range用法。分享给大家供大家参考。具体如下:for 循环的 range 格式可以对 slice 或者 map
- 处理办法,删除该文件,或清空该文件内容;我的处理是清空后,再设置该文件权限为Everyone拒绝访问。
- 每次在操作数据库的时候最烦的就是根据表单提交的内容写sql语句,特别是字段比较多的时候很麻烦,动不动就容易写错。所以我就写了下面的生成sql
- 一、正则表达式–元字符re 模块使 Python 语言拥有全部的正则表达式功能1. 数量词# 提取大小写字母混合的单词import rea
- 概要:本文主要描述XHTML中相对定位和绝对定位各自的本质、用法、区别和两者之间的关系。以及使用CSS的Left、Right、Top、Bot
- 假设我们有一个非常简单的Post模型,它将是一个图像及其描述,from django.db import modelsclass Post(
- /* 判断指定的内容是否为空,若为空则弹出 警告框 */ function isEmpty(theValue, strMsg){ if(th