pytorch中dataloader 的sampler 参数详解
作者:mingqian_chu 发布时间:2023-09-16 21:00:13
1. dataloader() 初始化函数
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
其中几个常用的参数:
dataset 数据集,map-style and iterable-style 可以用index取值的对象、
batch_size 大小
shuffle 取batch是否随机取, 默认为False
sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
num_workers 参与工作的线程数collate_fn 对取出的batch进行处理
drop_last 对最后不足batchsize的数据的处理方法
下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系
2. shuffle 与sample之间的关系
当我们sampler有输入时,shuffle的值就没有意义,
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
当dataset类型是map style时, shuffle其实就是改变sampler的取值
shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
shuffle为True时,sampler是RandomSampler, 就是按随机取样
3. sample 的定义方法
3.1 sampler 参数的使用
sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。
我们可以看下自带的RandomSampler类中最重要的iter函数
def __iter__(self):
n = len(self.data_source)
# dataset的长度, 按顺序索引
if self.replacement:# 对应的replace参数
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。
其实还有一些细节需要注意理解:
比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
比如,迭代器和生成器的使用, 以及区别
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
BatchSampler的生成过程:
# 略去类的初始化
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
就是按batch_size从sampler中读取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系
如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
意思就是b
atch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有
4. batch 生成过程
每个batch都是由迭代器产生的:
# DataLoader中iter的部分
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index()
data = self._dataset_fetcher.fetch(index)
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
来源:https://blog.csdn.net/chumingqian/article/details/126625724


猜你喜欢
- 老生常谈的问题,大部分人也不一定可以系统的理解。Javascript语言对继承实现的并不好,需要工程师自己去实现一套完整的继承机制。下面我们
- 本文实例讲述了Python 操作mysql数据库查询之fetchone(), fetchmany(), fetchall()用法。分享给大家
- 最近,接手的项目里,提供的数据文件格式简直让人看不下去,使用pandas打不开
- 用于序列化的两个模块json:用于字符串和Python数据类型间进行转换pickle: 用于python特有的类型和python的数据类型间
- 问题描述在spring-boot启动时,希望能执行相应的sql文件来初始化数据库。使用配置文件初始化数据库可以在spring-boot的配置
- 1 监听启动activity 信息命令adb shell logcat | grep START 可以查看apk包名和Activity名字=
- 现在只要是有关头像的框基本都是圆形的了,C#提供的PictureBox控键默认情况下是方形的非常大的影响美观PictureBox默认情况下比
- 在项目中安装mockjs在项目目录下执行以下安装命令npm install mockjs --save在Vue项目中使用mockjs的基本流
- 一、事件(EVENT)是干什么的 自MySQL5.1.6起,增加了一个非常有特色的功能 - 事件调度器(Event Scheduler),
- <table> <tr> &nb
- html的标签的属性,比如id、class、href需要动态传递参数,拼接字符串,查了一些资料,并没有找到合适的解决方法,琢磨了一上午,终于
- 一般语言都提供了按字典排序的API,比如跟微信公众平台对接时就需要用到字典排序。按字典排序有很多种算法,最容易想到的就是字符串搜索的方式,但
- 在Pydev能正常执行的脚本,在导出后在命令行执行,通常会报自己写的包导入时找不到。一:报错原因在PyDev中,test.py 中导入Tes
- 如果你有过Web编程的经验,那么或多或少都听说过或者使用过模板。简而言之,模板是可用于创建动态内容的文本文件。例如,你有一个网站导航栏的模板
- 分别针对ie和火狐分别作了对xml文档和xml字符串的解析,所有代码都注释掉了,想看哪部分功能,去掉注释就可以了。至于在ajax环境下解析x
- 背景项目中经常使用别人维护的模块,在git中使用子模块的功能能够大大提高开发效率。使用子模块后,不必负责子模块的维护,只需要在必要的时候同步
- 前言 FTP(File Transfer Protocol)是文件传输协议的简称。用于Internet上的控制文件的双向传输。同时,它也是一
- 1.实现的思路(1)首先使用一个处理画框的程序,将图片中的有车和无车的停车位给画出来,并且保存坐标(如果画错了,将鼠标移至要删除的框中,右击
- 灰度直方图概括了图像的灰度级信息,简单的来说就是每个灰度级图像中的像素个数以及占有率,创建直方图无外乎两个步骤,统计直方图数据,再用绘图库绘
- 如何做一个全面的探测器? 我们也可以做一个功能类似的探测器,见下:<Script lan