pytorch sampler对数据进行采样的实现
作者:蓝鲸123 发布时间:2023-02-09 20:05:40
标签:pytorch,sampler,数据,采样
PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。
下面举例说明。
from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)
from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
print(weights)
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
输出:
[2, 2, 1, 1, 2, 1, 1, 2]
[1, 1, 0]
[1, 0, 0]
[0, 0, 1]
github 地址:
https://github.com/WebLearning17/CommonTool
来源:https://blog.csdn.net/TH_NUM/article/details/80877772
0
投稿
猜你喜欢
- 本文实例讲述了python实现的批量分析xml标签中各个类别个数功能。分享给大家供大家参考,具体如下:文章目录需要个脚本分析下各个目标的数目
- 目录一、线程基础以及守护进程二、线程锁(互斥锁)三、线程锁(递归锁)四、死锁五、队列六、相关面试题七、判断数据是否安全八、进程池 &
- 一、整数python2中整形可以分为一般整形和长整形,但是在python3中,两者以及合二为一了,只有整形。python中的整形是具有无限精
- 在进行ASP网站开发时,有时需在客户端调用MSSQL数据库的数据进行打印,若调用数据量小,可以通过在客户端运用FileSystemObjec
- 环境:python2.7+django1.91、先下载django-sutipip install django-suit2、配置项目打开s
- 1. 正则表达式的应用在给用户发送消息时通常情况会有相同的消息模板,但其中部分信息跟用户相关,因此需要对消息模板中的变量部分进行替换。而对于
- 即使在urlencode之前str.decode(“cp936″).encode(“utf-8″)做了编码转换也是没用的。后来查询手册查到一
- 这里给大家分享一段使用PHP Socket 编程模拟Http post和get请求的代码,非常的实用,结尾部分我们再讨论下php模拟http
- 对象Python 中,一切皆对象。每个对象由:标识(identity)、类型(type)、value(值)组成。1. 标识用于唯一标识对象,
- 字符串字符串常用操作拼接字符串拼接字符串需要使用‘+’运算符可完成对多个字符串的拼接。如str =
- 如下所示:#coding utf-8a=0.001 #定义收敛步长xd=1 #定义寻找步
- 环境: 开发的IDE:JBuilderX 使用的数据库:MS Sql Server 2000 使用的数据库驱动:JSQL Driver(JD
- 一:取字符串中第几个字符print "Hello"[0] 表示输出字符串中第一个字符print "Hello&
- 1. 简述我们在用scrapy爬取数据时,首先就要明确我们要爬取什么数据。scrapy提供了Item对象这种简单的容器,我们可以通过Item
- 为了更直观的了解prometheus如何工作,本文使用prometheus的python库来做一些相应的测试。python库的github地
- Ruby 是一门通用的语言,不仅仅是一门应用于WEB开发的语言,但 Ruby 在WEB应用及WEB工具中的开发是最常见的。使用Ruby您不仅
- 前言Matplotlib的可以把很多张图画到一个显示界面,在作对比分析的时候非常有用。对应的有plt的subplot和figure的add_
- 一、利用Google API生成二维码Google提供了较为完善的二维码生成接口,调用API接口很简单,以下是调用代码:$urlToEnco
- 如下所示:try: f =open("D:/1.txt",'r') f.clos
- 一些简单的代码简化下面是一个简单示例,它说明了 jQuery 对代码的影响。要执行一些真正简单和常见的任务,比方说为页面的某一区域中的每个链