pytorch sampler对数据进行采样的实现
作者:蓝鲸123 发布时间:2023-02-09 20:05:40
标签:pytorch,sampler,数据,采样
PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。
下面举例说明。
from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)
from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
print(weights)
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
输出:
[2, 2, 1, 1, 2, 1, 1, 2]
[1, 1, 0]
[1, 0, 0]
[0, 0, 1]
github 地址:
https://github.com/WebLearning17/CommonTool
来源:https://blog.csdn.net/TH_NUM/article/details/80877772


猜你喜欢
- #!/usr/bin/env python#coding=utf-8import osfrom pyinotify import Watch
- 在上一篇文章中,我们介绍了如何使用源码对TensorBoard进行编译教程,没有定制需求的可以直接使用pip进行安装。TensorBoard
- 简介:上文中已经介绍如何安装Pycharm已经环境变量的配置。现在软件已经安装成功,现在就开始动手做第一个Python项目。第一个“Hell
- 常用目标检测模型基本都是读取的PASCAL VOC格式的标签,下面代码用于生成VOC格式的代码,根据需要修改即可:from lxml imp
- Mysql 删除重复数据保留一条有效数据一、Mysql 删除重复数据,保留一条有效数据DELETE FROM SZ_Building WHE
- 你已经在上面取出w打头记录的例子中看到了LIKE的用法。LIKE判定词是一个非常有用的符号。不过,在很多情况下用了它可能会带给你太多的数据,
- 前言:枚举(enumeration)在许多编程语言中常被表示为一种基础的数据结构使用,枚举帮助组织一系列密切相关的成员到同一个群组机制下,一
- SQL Server中事务日志的作用:持续记录数据库所有的事务和这些事务对数据库所做的修改;一旦数据库出现灾难事件,就需要事务日志来进行近期
- 插入一条记录后,如何得到最新的自动增加ID?我们要用到SQL Server的@@IDENTITY。它能够记录下系统最近使用的一个IDENTI
- 概述Object.freeze(obj)可以冻结一个对象。一个被冻结的对象再也不能被修改;冻结了一个对象则不能向这个对象添加新的属性,不能删
- 载入库绘制表格我们需要用到python库中的matplotlib库import matplotlib.pyplot as plt一、折线图#
- python提取照片坐标信息的代码如下所示:from PIL import Imagefrom PIL.ExifTags import TA
- 一 背景 有赞的每个OLTP数据库实例上会设置一个sql-killer进程用于kill
- XML是一个精简的SGML,它将SGML的丰富功能与HTML的易用性结合到Web的应用中。XML保留了SGML的可扩展功能,这使XML从根本
- DbUtils是Javar的一个为简化JDBC操作类库commons-dbutils是Apache组织提供的一个开源JDBC工具类库,它是对
- 一、参数和共享引用:In [56]: def changer(a,b): ....: a=2 ....
- 请问如何用OleDbDataAdapter来对数据库进行删除、修改和添加?OleDbDataAdapter是DataSet和数据源之间建立联
- 怎么增大MySQL数据库连接数,MYSQL数据库安装完成后,默认连接数是100,流量稍微大一点的论坛或网站这个连接数是不够哟用
- 需要selenium控制的chrome向下滑动,自动加载一些内容,核心代码是:browser.execute_script("wi
- datetime日期时间类,主要熟悉API,时区的概念与语言无关。from datetime import datetime as dtdt