pytorch collate_fn的基础与应用教程
作者:音程 发布时间:2021-06-03 02:55:57
作用
collate_fn:即用于collate的function,用于整理数据的函数。
说到整理数据,你当然要会用数据,即会用数据制作工具torch.utils.data.Dataset
,虽然我们今天谈的是torch.utils.data.DataLoader
,但是,其实:
这两个你如何定义;
从装载器dataloader中取数据后做什么处理;
模型的forward()中如何处理。
这三部分都是有机统一而不可分割的,一个地方改了,其他地方就要改。
emmm…,小小总结,collate_fn笼统的说就是用于整理数据,通常我们不需要使用,其应用的情形是:各个数据长度不一样的情况,比如第一张图片大小是28*28,第二张是50*50,这样的话就如果不自己写collate_fn,而使用默认的,就会报错。
原则
其实说起来,我们也没有什么原则,但是如今大多数做深度学习都是使用GPU,所以这个时候我们需要记住一个总则:只有tensor数据类型才能运行在GPU上,list和numpy都不可以。
从而,我们什么时候将我们的数据转化为tensor是一个问题,我的答案是前一节中的三个部分都可以来转化,只是我们大多数的人都习惯在部分一转化。
基础
dataset
我们必须先看看torch.utils.data.Dataset如何使用,以一个例子为例:
import torch.utils.data as Data
class mydataset(Data.Dataset):
def __init__(self,train_inputs,train_targets):#必须有
super(mydataset,self).__init__()
self.inputs=train_inputs
self.targets=train_targets
def __getitem__(self, index):#必须重写
return self.inputs[index],self.targets[index]
def __len__(self):#必须重写
return len(self.targets)
#构造训练数据
datax=torch.randn(4,3)#构造4个输入
datay=torch.empty(4).random_(2)#构造4个标签
#制作dataset
dataset=mydataset(datax,datay)
下面,可以对dataset进行一系列操作,这些操作返回的结果和你之前那个class的三个函数定义都息息相关。我想说,那三个函数非常自由,你想怎么定义就怎么定义,上述只是一种常见的而已,你可以定制一个特色的。
len(dataset)#调用了你上面定义的def __len__()那个函数
#4
dataset[0]#调用了你上面定义的def __getitem__()那个函数
#(tensor([-1.1426, -1.3239, 1.8372]), tensor(0.))
所以我再三强调的是上面的输出结果和你的定义有关,比如你完全可以把def __getitem__()改成:
def __getitem__(self, index):
return self.inputs[index]#不输出标签
那么,
dataset[0]#此时当然变化。
#tensor([-1.1426, -1.3239, 1.8372])
可以看到,是非常随便的,你随便定制就好。
dataloader
torch.utils.data.DataLoader
dataloader=Data.DataLoader(dataset,batch_size=2)
4个数据,batch_size=2,所以一共有2个batch。
collate_fn如果你不指定,会调用pytorch内部的,也就是说这个函数是一定会调用的,而且调用这个函数时pytorch会往这个函数里面传入一个参数batch。
def my_collate(batch):
return xxx
这个batch是什么?这个东西和你定义的dataset, batch_size息息相关。batch是一个列表[x,...,xx],长度就是batch_size,里面每一个元素是dataset的某一个元素,即dataset[i](我在上一节展示过dataset[0])。
在我们的例子中,由于我们没有对dataloader设置需要打乱数据,即shuffle=True,那么第1个batch就是前两个数据,如下:
print(datax)
print(datay)
batch=[dataset[0],dataset[1]]#所以才说和你dataset中get_item的定义有关。
print(batch)
对,你没有看错,上述代码展示的batch就会传入到pytorch默认的collate_fn中,然后经过默认的处理,输出如下:
it=iter(dataloader)
nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)
其实,上面就是我们常用的,经典的输出结果,即输入和标签是分开的,第一项是输入tensor,第二项是标签tensor,输入的维度变成了(batch_size,input_size)。
但是我们乍一看,将第一个batch变成上述输出结果很容易呀,我们也会!我们下面就来自己写一个collate_fn实现这个功能。
# a simple custom collate function, just to show the idea
# `batch` is a list of tuple where first element is input tensor and the second element is corresponding label
def my_collate(batch):
inputs=[data[0].tolist() for data in batch]
target = torch.tensor([data[1] for data in batch])
return [data, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
print(datax)
print(datay)
it=iter(dataloader)
nex=next(it)
print(nex)
这不就和默认的collate_fn的输出结果一样了嘛!无非就是默认的还把输入变成了tensor,标签变成了tensor,我上面是列表,我改就是了嘛!如下:
def my_collate(batch):
inputs=[data[0].tolist() for data in batch]
inputs=torch.tensor(inputs)
target =[data[1].tolist() for data in batch]
target=torch.tensor(target)
return [inputs, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader)
nex=next(it)
print(nex)
这下好了吧!
对了,作为彩蛋,告诉大家一个秘密:默认的collate_fn函数中有一些语句是转tensor以及tensor合并的操作,所以你的dataset如果没有设计成经典模式的话,使用默认的就容易报错,而我们自己会写collate_fn,当然就不存在这个问题啦。同时,给大家的一个经验就是,一般dataset是不会报错的,而是根据dataset制作dataloader的时候容易报错,因为默认collate_fn把dataset的类型限制得比较死。
应用情形
假设我们还是4个输入,但是维度不固定的。
a=[[1,2],[3,4,5],[1],[3,4,9]]
b=[1,0,0,1]
dataset=mydataset(a,b)
dataloader=Data.DataLoader(dataset,batch_size=2)
it=iter(dataloader)
nex=next(it)
nex
使用默认的collate_fn,直接报错,要求相同维度。
这个时候,我们可以使用自己的collate_fn,避免报错。
不过话说回来,我个人感受是:
在这里避免报错好像也没有什么用,因为大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘,所以:这里不报错,后面还是报错。如果后面解决这个问题的方法是:在不足维度上进行补0操作,那么我们为什么不在建立dataset之前先补好呢?所以,collate_fn这个东西的应用场景还是有限的。
来源:https://blog.csdn.net/qq_43391414/article/details/120462055


猜你喜欢
- 取行和列的几种常用方式:data[ 列名 ]: 取单列或多列,不能用连续方式取,也不能用于取行。data.列名: 只用于取单列,不能用于行。
- 之前想爬取一些淘宝的数据,后来发现需要登录,找了很多的资料,有个使用request的sessions加上cookie来登录的,cookie的
- 本文实例分析了python字符串连接方法。分享给大家供大家参考,具体如下:python字符串连接有几种方法,把大家可能用到的列出来,第一个方
- python中进行图表绘制的库主要有两个:matplotlib 和 pyecharts, 相比较而言:matplotlib中提供了BaseM
- 一、文件的操作流程第一,建立文件对象。第二,调用文件方法进行操作。第三,关闭文件。1、打开文件用python内置的open()函数打开一个文
- 本文实例讲述了Python3将jpg转为pdf文件的方法。分享给大家供大家参考,具体如下:#coding=utf-8#!/usr/bin/e
- 在你自己安装了一个新的MySQL服务器后,你需要为MySQL的root用户指定一个目录(缺省无口令),否则如果你忘记这点,你将你的MySQL
- 前提官网上提供了 Mac 和 Windows 上的安装包和 Linux 上安装需要的源码。下载地址如下:https://www.python
- 近期安装了python后,发现使用pycharm工具打开代码后发现代码下边会有波浪线的显示;但是该代码语句确实没有错误,通过查询发现了两种方
- 静态文件配置概述:静态文件交由Web服务器处理,Django本身不处理静态文件。简单的处理逻辑如下(以nginx为例):URI请求 --&g
- python3中的字符串是一种常见的数据类型。字符串有多种表现形式:单引号、双引号和三引号,且这些字符串的表现形式(单、双、三)都必须是成对
- 是因工作需要做的一个批量修改代码的小东西,拿出来与大家分享。 目前可以处理的文件类型:.asp .inc .htm .html
- MySQL字符串的拼接、截取、替换、查找位置。常用的字符串函数:函数说明CONCAT(s1,s2,...)返回连接参数产生的字符串,一个或多
- 作为酷爱编程的老程序员,实在按耐不下这个冲动,Python真的是太火了,不断撩拨我的心。我是对Python存有戒备之心的,想当年我基于Dru
- 决策树原理:从数据集中找出决定性的特征对数据集进行迭代划分,直到某个分支下的数据都属于同一类型,或者已经遍历了所有划分数据集的特征,停止决策
- SQL Server 2008我们也能从中体验到很多新的特性,但是对于SQL Server 2008安装,还是用图来说话比较好。本文将从SQ
- Vue.js是一个构建数据驱动的web界面的库。重点集中在MVVM模式的ViewModel层,因此非常容易与其它库或已有项目整合Vue.js
- Mootools 1.2手风琴(Accordion)教程原文地址:30 Days of Mootools 1.2 Tutoria
- 本文实例总结了python调用函数、类和文件操作。分享给大家供大家参考,具体如下:调用函数有三种方式一,导入整个模块(所有函数)导入 imp
- 格式化输出:format()format():把传统的%替换为{}来实现格式化输出1.使用位置参数:就是在字符串中把需要输出的变量值用{}来