关于pytorch处理类别不平衡的问题
作者:NAAE 发布时间:2023-04-08 19:11:56
标签:pytorch,类别,不平衡
当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。
下面的代码展示了如何使用WeightedRandomSampler来完成抽样。
numDataPoints = 1000
data_dim = 5
bs = 100
# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
np.ones(int(numDataPoints * 0.1), dtype=np.int32)))
print 'target train 0/1: {}/{}'.format(
len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))
class_sample_count = np.array(
[len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
train_dataset, batch_size=bs, num_workers=1, sampler=sampler)
for i, (data, target) in enumerate(train_loader):
print "batch index {}, 0/1: {}/{}".format(
i,
len(np.where(target.numpy() == 0)[0]),
len(np.where(target.numpy() == 1)[0]))
核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:
class_sample_count = np.array(
[len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
来源:https://blog.csdn.net/zcgyq/article/details/83087997


猜你喜欢
- 从而使得有些字符(尤其是宽字符)无法正确地显示,即不再是utf-8格式了。解决办法:打开输出文件时即指定编码格式,就不会出现输出文件打开以后
- 如果你有一堆 PPT 要做,他们的格式是一样的,只是填充的内容不一样,那你就可以使用 Python 来减轻你的负担。PPT 分为内容和格式,
- 此货很干,跟上脚步!!!Cookiecookie是什么东西?小饼干?能吃吗?简单来说就是你第一次用账号密码访问服务器服务器在你本机硬盘上设置
- 目录一.简介二.特色三.flask规模化四. flask Blueprint总结一.简介Flask是一个使用Python编写的轻量级Web应
- 前言针对一些特殊的需求,在项目里,需要将响应式数据变为普通原始类型数据,这种情况是有的在Vue里,能够将普通数据类型的数据变为响应式数据,同
- 有时候我们想要的数据合并结果是数据的轴向连接,在pandas中这可以通过concat来实现。操作的对象通常是Series。Ipython中的
- 简介:本文介绍了图像检索的三种实现方式,均用python完成,其中前两种基于直方图比较,哈希法基于像素分布。 检索方式是:提前导入图片库作为
- SQLyog是一款MySQL可视化工具,他可以将部分SQL操作通过图形化界面操作来完成,方便开发者更好的进行开发及数据库设计。在安装SQLy
- 1. 什么是XSLT 大家可能听说过XSL(eXtensible Stylesheet Language),XSL和我们这里说的XSLT从狭
- 本文实例为大家分享了OpenCV基于ORB算法实现角点检测的具体代码,供大家参考,具体内容如下ORB算法是FAST算法和BRIEF算法的结合
- #!/usr/bin/python #-*- encoding: utf-8 -*- import types class NotInteg
- 一扯上文化二字,总觉虚无缥缈、漫无边际,或者老气横秋,如何有趣地利用中华文化的思想和符号,结合现代的元素,使其成为有意思的传播手法,这个问题
- 前言Beautiful Soup是python的一个HTML或XML的解析库,我们可以用它来方便的从网页中提取数据,它拥有强大的API和多样
- 前言虽然标题是全站,但目前只做了等级 top 100 直播间的全天弹幕收集。弹幕收集系统基于之前的B 站直播弹幕姬 Python 版修改而来
- 1.首先检查自己的环境变量是否配置正确点击setting 点击 Python Interpreter点击Add Interpret
- 本文实例讲述了js简单实现Select互换数据的方法。分享给大家供大家参考。具体如下:这里基于javascript实现两个Select互换数
- Django视图函数执行,不在主线程中,直接loop = asyncio.new_event_loop() # 不能loop = async
- 在我们使用log模块输出日志时,经常会遇到log输出重复的问题,如下:先来看这个文件log.py的代码:代码示例:''
- python2.7,下面是跑在window上的,稍作修改就可以跑在linux上。实测win7和raspbian均可,且raspbian可以直
- 刚好前些天有人提到eval()与exec()这两个函数,所以就翻了下Python的文档。这里就来简单说一下这两个函数以及与它们相关的几个函数