Pytorch技巧:DataLoader的collate_fn参数使用详解
作者:jmjackyrj 发布时间:2023-12-11 00:20:48
标签:Pytorch,DataLoader,collate,fn
DataLoader完整的参数表如下:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
DataLoader在数据集上提供单进程或多进程的迭代器
几个关键的参数意思:
- shuffle:设置为True的时候,每个世代都会打乱数据集
- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留
一个测试的例子
import torch
import torch.utils.data as Data
import numpy as np
test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])
inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))
torch_dataset = Data.TensorDataset(inputing,target)
batch = 3
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=batch, # 批大小
# 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
collate_fn=lambda x:(
torch.cat(
[x[i][j].unsqueeze(0) for i in range(len(x))], 0
).unsqueeze(0) for j in range(len(x[0]))
)
)
for (i,j) in loader:
print(i)
print(j)
输出结果:
tensor([[[ 0, 1, 2],
[ 1, 2, 3],
[ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
[ 1],
[ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
[ 4, 5, 6],
[ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
[ 4],
[ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
[ 7, 8, 9],
[ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
[ 7],
[ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)
如果不要collate_fn的值,输出变成
tensor([[ 0, 1, 2],
[ 1, 2, 3],
[ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
[ 1],
[ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
[ 4, 5, 6],
[ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
[ 4],
[ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
[ 7, 8, 9],
[ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
[ 7],
[ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)
所以collate_fn就是使结果多一维。
看看collate_fn的值是什么意思。我们把它改为如下
collate_fn=lambda x:x
并输出
for i in loader:
print(i)
得到结果
[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]
每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:
collate_fn=lambda x:(
torch.cat(
[x[i][j].unsqueeze(0) for i in range(len(x))], 0
).unsqueeze(0) for j in range(len(x[0]))
)
j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。
来源:https://blog.csdn.net/weixin_42028364/article/details/81675021


猜你喜欢
- 总的来说,提高应用程序性能的最好的方法是发现应用的瓶径之所在,和数据库进行交互的性能无疑是决定应用程序性能的重要环节之一。因为ADO是当前最
- 本文实例为大家分享了python3实现点餐系统的具体代码,供大家参考,具体内容如下题目: 某餐厅外卖每天更新菜品,
- 阅读上一篇:打造设计你自己的字体 Ⅱ永远都在寻觅字体设计的灵感。夏天过后,我买了一套便宜的书法钢笔,说服自己,它会让我的鸡爬字产生脱胎换骨的
- tensorflow下设置使用某一块GPU(从0开始编号):import osos.environ["CUDA_DEVICE_OR
- lighttpd (http://www.djangoproject.com/r/lighttpd/) 是一个轻量级的Web服务器,通常被用
- 本文实例讲述了JS小游戏的仙剑翻牌源码,是一款非常优秀的游戏源码。分享给大家供大家参考。具体如下:一、游戏介绍:这是一个翻牌配对游戏,共十关
- 然而这里不打算对某种存储引擎的实现细节进行描述,也不打算介绍各种存储引擎的优缺点,只是描述一下mysql如何处理binlog,并澄清几个容易
- 用python另一个抢票神器,你get到了吗?2017年时间飞逝,转眼间距离2018年春节还有不到1个月的时间,还在为抢不到火车票发愁吗?作
- 目录分页后添加删除功能实现模态框编辑内容完整代码笔记利用layui框架实现分页:layui实现完整表格分页:自己实现分页:# name: m
- 这篇文章主要介绍了python Opencv计算图像相似度过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 下面是一份在 HTML 4 Strict 和 XHTML 1.0 Strict 下必须遵守的标签嵌套规则,比如你不能在 <a>
- 1. 作用将类方法转换为类属性,可以用 . 直接获取属性值或者对属性进行赋值2.实现方式使用property类来实现,也可以使用proper
- 游戏截图动态演示源码分享state/tool.pyimport osimport jsonfrom abc import abstractm
- 我们将要来学习python的重要概念迭代和迭代器,通过简单实用的例子如列表迭代器和xrange。可迭代一个对象,物理或者虚拟存储的序列。li
- 从这里开始我的博客,后台数据库是什么?没错,就是MySQL,服务器端使用的脚本就是PHP,整个框架使用的是WordPress。PHP和MyS
- 多表查询使用单个select 语句从多个表格中取出相关的查询结果,多表连接通常是建立在有相互关系的父子表上;1交叉连接第一个表格的所有行 乘
- 一、io包中接口的好处和优势1.1拷贝数据的函数io.Copy(dst Writer, src Reader)io.CopyBuffer(d
- 废话不多说了直接给大家贴代码了,具体代码如下所示:$('#myModal').on('shown', fun
- 使用cpu和gpu的区别在Tensorflow中使用gpu和cpu是有很大的差别的。在小数据集的情况下,cpu和gpu的性能差别不大。不过在
- 前言最近在重构一个复选框组件,原型是select2这个jQuery插件, 有兴趣的可以去搜下,封装的很好,但是并不能满足业务所有需求,最要命