Pytorch DataLoader 变长数据处理方式
作者:HappyCtest 发布时间:2022-08-06 23:07:50
标签:Pytorch,DataLoader,变长数据
关于Pytorch中怎么自定义Dataset数据集类、怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述。
现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的。
解决方法是重写DataLoader的collate_fn,具体方法如下:
# 假如每一个样本为:
sample = {
# 一个句子中各个词的id
'token_list' : [5, 2, 4, 1, 9, 8],
# 结果y
'label' : 5,
}
# 重写collate_fn函数,其输入为一个batch的sample数据
def collate_fn(batch):
# 因为token_list是一个变长的数据,所以需要用一个list来装这个batch的token_list
token_lists = [item['token_list'] for item in batch]
# 每个label是一个int,我们把这个batch中的label也全取出来,重新组装
labels = [item['label'] for item in batch]
# 把labels转换成Tensor
labels = torch.Tensor(labels)
return {
'token_list': token_lists,
'label': labels,
}
# 在使用DataLoader加载数据时,注意collate_fn参数传入的是重写的函数
DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
使用以上方法,可以保证DataLoader能Load出一个batch的数据,load出来的东西就是重写的collate_fn函数最后return出来的字典。
来源:https://blog.csdn.net/HappyCtest/article/details/88872651
0
投稿
猜你喜欢
- 方案一:利用selenium+phantomjs * 面浏览器的形式访问网站,再获取cookie值:from selenium import
- 一、前言最近本都是开开心心的打开电脑写一些祖传BUG但一个报错阻碍了我写BUG的进度!这年代还有能阻碍我写BUG的报错???二、解决过程一个
- 阅读上一篇:FrontPage2002简明教程七:HTML在FrontPage中的应用 FrontPage 2002比起以前版本的FronP
- 前言HI,好久不见,今天是关闭朋友圈的第60天,我是野蛮成长的AC-Asteroid。人生苦短,我用Python,通过短短两周时间自学,从基
- 译注:前两天看到一篇不错的英文文章,叫做 How browsers work,该文概要的介绍了浏览器从头到尾的工作机制,包括HTML等的解析
- 这篇文章主要介绍了基于Python批量生成指定尺寸缩略图代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值
- 音频文件放入和.py文件同级的目录下import winsound # 导入此模块实现声音播放功能import time # 导入此模块,获
- 需求背景:进行分值计算。如下图,如果只是一两个还好说,写写判断,但是如果有几十个,几百个,会不会惨不忍睹。而且,下面的还是三种情况。例如:解
- 这些年来,我发现许多开发者对于何时使用数据操纵语言(DML)触发器与何时使用约束感到迷惑。许多时候,如果没有正确应用这两个对象,就会造成问题
- 兼容当前HTML/XHTML文档是否有DTD声明:以下为程序代码:var xtop = document.documentElement.s
- 安装刚接触Pillow的朋友先来看一下Pillow的安装方法,在这里我们以Mac OS环境为例:(1)、使用 pip 安装 Python 库
- 可实现功能:1.随机生成一个整数。2.随机生成任意范围内的一个整数。3.随机生成指定长度的整数组4.随机生成指定长度的任意范围的整数组5.随
- 1 前言前面的文章中我们已经获取到了基金的阶段变动信息和ETF信息的获取,那么在本章中,我们将继续前面的内容,获取基金的价格信息,并且把之前
- 目录最终版本过程借鉴代码思考urllib.request和requestsBeautifulSoup优化处理总结代码复制可直接使用,记得pi
- 本文实例讲述了django框架模型层功能、组成与用法。分享给大家供大家参考,具体如下:Django models是Django框架自定义的一
- 单表操作增加数据auther_obj = {"auther_name":"崔皓然","au
- 前言项目中大量用到图片加载,由于图片太大,加载速度很慢,因此需要对文件进行统一压缩一:导入包from PIL import Imageimp
- MySQL 一级防范检查列表以下是加固你的 Mysql 服务器安全所要做的工作的重要参考:Securing MySQL: step-by-s
- 1. 非 matlab v7.3 files 读写import scipy.io as sioimport numpy# matFile 读
- 这个代表显示宽度整数列的显示宽度与mysql需要用多少个字符来显示该列数值,与该整数需要的存储空间的大小都没有关系