pytorch对可变长度序列的处理方法详解
作者:深度学习1 发布时间:2022-11-11 23:19:39
主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。
1、torch.nn.utils.rnn.PackedSequence()
NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。
PackedSequence对象包括:
一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)
一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]
用pack_padded_sequence函数来构造这个对象非常的简单:
如何构造一个PackedSequence对象(batch_first = True)
PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。
2、torch.nn.utils.rnn.pack_padded_sequence()
这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)
输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。
Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。
NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。
参数说明:
input (Variable) – 变长序列 被填充后的 batch
lengths (list[int]) – Variable 中 每个序列的长度。
batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size。
返回值:
一个PackedSequence 对象。
3、torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence。
上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。
返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。
Batch中的元素将会以它们长度的逆序排列。
参数说明:
sequence (PackedSequence) – 将要被填充的 batch
batch_first (bool, optional) – 如果为True,返回的数据的格式为 B×T×*。
返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。
例子:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
#forward
out, _ = rnn(pack, h0)
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)
输出:
111 (Variable containing:
(0 ,.,.) =
0.5406 0.3584
-0.1403 0.0308
(1 ,.,.) =
-0.6855 -0.9307
0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])
来源:104.116.116.112.58.47.47.119.119.119.46.99.110.98.108.111.103.115.46.99.111.109.47.108.105.110.100.97.120.105.110.47.112.47.56.48.53.50.48.52.51.46.104.116.109.108.
猜你喜欢
- 背景形态学处理方法是基于对二进制图像进行处理的,卷积核决定图像处理后的效果;形态学的处理哦本质上相当于对图像做前处理,提取出有用的特征,以便
- 本文实例为大家分享了python实现简单的飞机大战的具体代码,供大家参考,具体内容如下制作初衷这几天闲来没事干,就想起来好长时间没做过游戏了
- 随着软件项目进入“维护模式”,对可读性和编码标准的要求很容易落空(甚至从一开始就没有建立过那些标准)。然而,在代码库中保持一致的代码风格和测
- Python爬虫包 BeautifulSoup 递归抓取实例详解概要:爬虫的主要目的就是为了沿着网络抓取需要的内容。它们的本质是
- 1. 从键盘输入一个整数,求 100 除以它的商,并显示输出。要求对从键盘输入的数值进行异常处理。try: n=i
- 普通MySQL运行,数据量和访问量不大的话,是足够快的,但是当数据量和访问量剧增的时候,那么就会明显发现MySQL很慢,甚至do
- 需求: 一台机器上有多个网卡, 如何访问指定的 URL 时使用指定的网卡发送数据呢?$ curl --interface eth0 www.
- 在SQL SERVER中,你可能需要获得当前日期和计算一些其他的日期,例如,你的程序可能需要判断一个月的第一天或者最后一天。你们大部分人大概
- 一、本文使用的第三方包和工具python 3.8 谷歌浏览器selenium(3.141.0)(pip install
- Python是一门简单易学的编程语言。阅读好的Python程序感觉就像阅读英语,尽管是非常严格的英语。Python的这种伪代码特性是其最大强
- 1、下载LineNumber.pyhttp://idlex.sourceforge.net/extensions.html2、配置方法(1)
- *args 和 **kwargs首先,要知道的是并不是必须写成*args和**kwargs。 只有变量前⾯的*才是必须的。所以,你也可以写成
- 我就废话不多说了,大家还是直接看代码吧~import torch.nn as nnimport torch.nn.functional as
- 关于JavaSctipt的兼容性,最懒的办法就是用jQuery的工具函数。尽量不要用那些什么ECMAScript之类的函数,因为很多浏览器都
- 对于英文不行我来说使用英文版PyCharm实在是太难受了,网上好多汉化补丁都是网友提供了,下面为大家介绍一种PyCharm官方中文语言包汉化
- 有如下的一堆mac地址,需要更改成一定格式,如mac='902B345FB021'改为mac='90-2B-34-5
- 本文以一个简单的实例讲述了python实现斐波那契数列数列递归函数的方法,代码精简易懂。分享给大家供大家参考之用。主要函数代码如下:def
- 今天我们用python和python的工具包pygame来编写一个贪吃蛇的小游戏贪吃蛇游戏功能介绍贪吃蛇的游戏规则如下:通过上下左右键或者W
- 今天发现一个使用python写的管理cisco设备的小框架tratto,可以用来批量执行命令。下载后主要有3个文件:Systems.py 定
- 如下所示:L = ['adam', 'Lisa', 'bart', 'Paul