pytorch中dataloader 的sampler 参数详解
作者:mingqian_chu 发布时间:2023-09-16 21:00:13
1. dataloader() 初始化函数
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
其中几个常用的参数:
dataset 数据集,map-style and iterable-style 可以用index取值的对象、
batch_size 大小
shuffle 取batch是否随机取, 默认为False
sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
num_workers 参与工作的线程数collate_fn 对取出的batch进行处理
drop_last 对最后不足batchsize的数据的处理方法
下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系
2. shuffle 与sample之间的关系
当我们sampler有输入时,shuffle的值就没有意义,
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
当dataset类型是map style时, shuffle其实就是改变sampler的取值
shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
shuffle为True时,sampler是RandomSampler, 就是按随机取样
3. sample 的定义方法
3.1 sampler 参数的使用
sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。
我们可以看下自带的RandomSampler类中最重要的iter函数
def __iter__(self):
n = len(self.data_source)
# dataset的长度, 按顺序索引
if self.replacement:# 对应的replace参数
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。
其实还有一些细节需要注意理解:
比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
比如,迭代器和生成器的使用, 以及区别
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
BatchSampler的生成过程:
# 略去类的初始化
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
就是按batch_size从sampler中读取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系
如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
意思就是b
atch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有
4. batch 生成过程
每个batch都是由迭代器产生的:
# DataLoader中iter的部分
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index()
data = self._dataset_fetcher.fetch(index)
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
来源:https://blog.csdn.net/chumingqian/article/details/126625724
猜你喜欢
- 首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层
- 本文主要是用PyTorch来实现一个简单的回归任务。 编辑器:spyder1.引入相应的包及生成伪数据import torchimport
- 我们都知道 vue 中可以使用 modal 来实现 input 内容数据的双向绑定。小程序好像没有提供相应的方法支持,就需要我们自己写了。原
- python时间处理月份加减第三方模块 :python-dateutil安装方式:pip install python-dateutil实例
- 前言:HTML5和CSS3的时代到来了,新版2011版淘宝网首页已全部使用HTML5,拥抱变化才是王道。为之漫笔翻译的很好,看了一遍后,感觉
- 以下针对Ubuntu系统,Windows系统没有测试过。Ubuntu中默认就安装有Python 2.x和Python 3.x,默认情况下py
- 下面是我写的NumericStepper:谢谢 果果 和 Rimifon , 我对代码进行了完善, 支持自适应小数位数:
- 前几天安装Python的时候没有装上pip工具,所以只能现在手动安装了。首先,访问https://bootstrap.pypa.io/get
- 利用python的sftp实现文件上传,可以是文件,也可以是文件夹。版本Python2.7.13 应该不用pip安装更多的插件,都是自带的不
- 介绍在本文中,你将学习如何使用 Python 构建人脸识别系统。人脸识别比人脸检测更进一步。在人脸检测中,我们只检测人脸在图像中的位置,但在
- 这篇文章主要介绍了python基于event实现线程间通信控制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 目录1.按照一列数值进行排序1.1按照五缺失值的一列进行排序1.1.1升序排列1.1.2 降序排列1.2按照有缺失值的一列进行排序1.2.1
- 大家在学习python中,经常会使用到K-Means和图片压缩的,我们在此给大家分享一下K-Means和图片压缩的方法和原理,喜欢的朋友收藏
- 很多年以前,面对上古时代遗留的 HTML 发出的腐臭,我捂住鼻子唉声叹气。刚练熟 web 标准的我,恨不得寝其尸食其肉,把一切推翻重来。但经
- Div的浮动+循环(描述的不清楚,请看图)在设计和布局的时候,碰到图片循环问题,碰到间距问题,怎么样让循环的图片每行的起始点跟上边的titl
- 前言大家好,我是小张~记得小时候,家里只有一个钟表用来看时间(含有时针、分针、秒针的那种),挂在墙上哒哒哒响个不停,现在生活条件好了、基本人
- 英文版见:http://dflying.dflying.net/.../98_web_standard_and_aspnet__part1_
- 我认为在ASP中最好的办法是用编程实现定时刷新Cache,也就是说给Application中储存的设一个过期时间。当然,在ASP中Appli
- 下面就来说说解决方案吧~import osimport syscurPath = os.path.abspath(os.path.dirna
- 对于一个Dict:test_dict = {1:5, 2:4, 3:3, 4:2, 5:1}想要求key值大于等于3的所有项:print({