关于pytorch处理类别不平衡的问题
作者:NAAE 发布时间:2023-04-08 19:11:56
标签:pytorch,类别,不平衡
当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。
下面的代码展示了如何使用WeightedRandomSampler来完成抽样。
numDataPoints = 1000
data_dim = 5
bs = 100
# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
np.ones(int(numDataPoints * 0.1), dtype=np.int32)))
print 'target train 0/1: {}/{}'.format(
len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))
class_sample_count = np.array(
[len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
train_dataset, batch_size=bs, num_workers=1, sampler=sampler)
for i, (data, target) in enumerate(train_loader):
print "batch index {}, 0/1: {}/{}".format(
i,
len(np.where(target.numpy() == 0)[0]),
len(np.where(target.numpy() == 1)[0]))
核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:
class_sample_count = np.array(
[len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
来源:https://blog.csdn.net/zcgyq/article/details/83087997
0
投稿
猜你喜欢
- PDOStatement::fetchColumnPDOStatement::fetchColumn — 从结果集中的下一行返回单独的一列。
- 前言很多人会使用postman工具,或者熟悉python,但不一定会使用python来编写测试用例脚本,postman里面可以完整的将pyt
- 对于每个类型拥有的值范围以及并且指定日期何时间值的有效格式的描述见7.3.6 日期和时间类型。 1、这里是一个使用日期函数的例子。
- 前两天有一位网友问我一个关于Javascript中++操作符的问题,他的代码大致是这样的ADS.addEvent(window,'c
- 需求:主线程开启了多个线程去干活,每个线程需要完成的时间不同,但是在干完活以后都要通知给主线程下面上代码:#!/usr/bin/python
- 一个简单的实现class NaiveFilter():'''Filter Messages from keyword
- 代码如下:DECLARE @c INT DECLARE @c2 INT SELECT @c = COUNT(1) FROM dbo.Spli
- 1.Airbus Ship Detection Challengeurl: https://www.kaggle.com/comp
- 数据驱动模式的测试好处相比普通模式的测试就显而易见了吧!使用数据驱动的模式,可以根据业务分解测试数据,只需定义变量,使用外部或者自定义的数据
- 为了自定义一个模板标签,你需要告诉Django当遇到你的标签时怎样进行这个过程。当Django编译一个模板时,它将原始模板分成一个个 节点
- 数据模型==对象模型Python官方文档说法是“Python数据模型”,大多数Python书籍作者说法是“Python对象模型”,它们是一个
- 1.方法方法描述bbox(item, column=None)返回指定item的框选范围,或者单元格的框选范围column( cid, op
- 在python中进行两个整数相除的时候,在默认情况下都是只能够得到整数的值,而在需要进行对除所得的结果进行精确地求值时,想在运算后即得到浮点
- “没 Javascript 就会死”的页面通常都会加入 noscript 标签用于提示用户开启脚本支持。 然而在 IE8 下,如果在 nos
- 你在使用pandas处理DataFrame中是否遇到过如下这类问题?我们需要删除某一列所有元素中含有固定字符元素所在的行,比如下面的例子:&
- 从今天开始起,基督山将和大家一起进入ASP.net 诸多程序的学习中,老实说,.net到底是法宝还是垃圾,我们拭目以待。有任何问题,联络基督
- 前言:多态的实现必须满足两个前提条件1.继承:多态一定是发生在子类和父类之间2.重写:多态子类重写了父类的方法记住这两点再结合代码示例有助于
- 本文实例为大家分享了Python实现五子棋游戏的具体代码,供大家参考,具体内容如下class CheckerBoard(): &
- 如果是报名培训班的话,学习的速度可能会更快一些,毕竟是自己花钱了。自学python爬虫方法:首先要掌握一些有关爬虫的基础知识,基本的要知道什
- 凯撒密码的原理:计算并输出偏移量为3的凯撒密码的结果注意:密文是大写字母,在变换加密之前把明文字母都替换为大写字母def casar(mes