Pytorch实现将label变成one hot编码的两种方式
作者:咆哮的阿杰 发布时间:2021-08-26 08:45:00
由于Pytorch不像TensorFlow有谷歌巨头做维护,很多功能并没有很高级的封装,比如说没有tf.one_hot函数。
本篇介绍将一个mini batch的label向量变成形状为[batch size, class numbers]的one hot编码的两种方法,涉及到
tensor.scatter_
tensor.index_select
前言
本文将针对全连接网络和全卷积网络输出的形式不同,将one hot编码分两种情况。
第一种针对网络输出是二维,即全连接层的输出形式, [Batchsize, Num_class]
第二种针对输出是四维特征图,即分割网络的输出形式,[Batchsize, Num_class, H,W]
先将第一种情况
使用scatter_获得one hot 编码
我相信在CSDN上找这个函数用法的人都是看不懂官方介绍的,所以我不会像其他地方那样,搬官方教程,我也是琢磨了很久才看懂这个函数,但函数声明还是要看看的。
tensor.scatter_(dim, index, src)
dim
: 指定了覆盖数据是从哪个轴作为依据。后面再详细解释。值的范围是从0到 sum(tensor.shape)-1index
: 告诉函数要将src中对应的值放到tensor的哪个位置。index的shape要和src一致,或者src可以通过广播机制实现shape一致。src
: 保存了想用来覆盖tensor的值
我们先看一个例子,例子从别的博客copy过来,但我会做更加详细的介绍。觉得讲得好请留言作为鼓励。
>>> x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
注意到dim为0,代表以第一个维度作为依托。index是一个二维数组。
[0,1,2,0,0]
[2,0,0,1,2]
那么我们要覆盖tensor的位置有10个,分别为
[0,0];[1,1];[2,2];[0,3];[0,4]
[2,0];[0,1];[0,2];[1,3];[2,4]
dim指定了index我们要将index的值作为哪一个轴的值。其他轴就是按照0到max shape -1变化罢了。比如说dim为0,那么index的值都作为坐标的第一个位置的值,另一个位置从0到4变换。
你们可以验证下,是不是这10个位置被覆盖了。10个位置的第一个轴是index的数字,第二个数字是index中的列数,从0到4。
要覆盖的位置有了,那么用什么值覆盖呢?别忘了我们的index的维度和src是一样的。index中选择什么位置的坐标,就对应用src对应的位置的值代替。
比如说要代替tensor中[0,0]的值,index中[0,0]就是第0行第0列对应的位置,那我们用src第0行第0列的值代替tensor的值。大家可以去验证一下。
我们看看下面的的情况,如果dim为1呢。
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
先分析一下
dim为1,那么index的值都作为坐标的第2个位置的值,第一个位置的值应该从0到1变化。
所以要被代替的位置有
[0,2];[1,3]
而[0,2]的位置要填入的值为1.23,[1,3]要填入的值为1.23。(广播机制将1.23这个标量扩展到了shape为(2,1))
好的,函数用法知道了。我们现在看看如何用该函数将label编码为one hot编码。
首先设想一个batch size为8的label。有10类,所以label中的数字应该是从0到9的。
import torch as t
import numpy as np
batch_size = 8
class_num = 10
label = np.random.randint(0,class_num,size=(batch_size,1))
label = t.LongTensor(label)
我们就获得了一个label,shape是(8,1),必须是2维。如果是(8,)下面的内容会报错的。
y_one_hot = t.zeros(batch_size,class_num).scatter_(1,label,1)
print(y_one_hot)
'''
tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])
'''
搞定。下面我们看下面一种方法。
使用tensor.index_select获得one hot编码
还是先看下index_select的用法。
tensor.index_select( dim, index, out=None)
dim
: 指定按什么维度取tensor中的向量index
: 是一个一维的张量。描述了按照dim维度取出tensor对应的index值的向量。
我们不看例子了,直接看方法,以此为例。
ones = torch.sparse.torch.eye(class_num)
return ones.index_select(0,label)
这里的label是一维的向量,不是二维的。因为index制定了必须是一维的
先生成一个单位矩阵,尺寸是[class_num, class_num]。
dim为0,以为这按照行来取tensor的向量。具体取哪一行呢,就是label中的值了。
这时我们应该也明白为啥这两行代码能实现one hot编码了吧。
如果label是[ 1,3,0],有四类。那我们得到就是
[0,1,0,0]
[0,0,0,1]
[1,0,0,0]
第二种针对分割网络的one_hot编码
对于分割类任务,网络的GT肯定是二维数组,而不是像分类任务那样的一维数组了。而对于分割任务,我们将其视作很多个像素值的分类任务,将ground truth 直接 reshape为向量形式,然后用上面的方法转为one hot编码,然后再reshape回来。核心是不变的。
下面举个例子。
import torch
import numpy as np
gt = np.random.randint(0,5, size=[15,15]) #先生成一个15*15的label,值在5以内,意思是5类分割任务
gt = torch.LongTensor(gt)
def get_one_hot(label, N):
size = list(label.size())
label = label.view(-1) # reshape 为向量
ones = torch.sparse.torch.eye(N)
ones = ones.index_select(0, label) # 用上面的办法转为换one hot
size.append(N) # 把类别输目添到size的尾后,准备reshape回原来的尺寸
return ones.view(*size)
gt_one_hot = get_one_hot(gt, 5)
print(gt_one_hot)
print(gt_one_hot.shape)
print(gt_one_hot.argmax(-1) == gt) # 判断one hot 转换方式是否正确,全是1就是正确的
另外注意,在Pytorch中,如果要和网络输出的特征图一起计算loss,还要把上面输出的one hot编码的最后一个维度使用permute转到通道维度上。
来源:https://blog.csdn.net/qq_34914551/article/details/88700334
猜你喜欢
- 一、环境Ubuntu 16.04tensorflow 1.4.0keras 2.1.3二、训练数据时报错:ValueError: Error
- 系统环境:Win10 64位MySQL版本:mysql-5.7.18-winX64部署的步骤就是按照网上说的:1:修改环境变量path,增加
- QTimer控件介绍如果在应用程序中周期性地进行某项操作,比如周期性的检测主机的cpu值,则需要用到QTimer定时器,QTimer类提供了
- 一. ADO.NET的定义ADO.NET来源于COM组件库ADO(即ActiveX Data Objects),是微软公司新一代.NET数据
- 一、项目介绍爬取网址:CSDN首页的Python、Java、前端、架构以及数据库栏目。简单分析其各自的URL不难发现,都是https://w
- 本文内容会引起杀毒软件的莫名兴奋,建议先安抚杀毒软件,让杀毒软件先休息一下再继续操作安装python3.6转exe会遇到很多问题,其中部分是
- SQL标准定义了4类隔离级别,包括了一些具体规则,用来限定事务内外的哪些改变是可见的,哪些是不可见的。低级别的隔离级一般支持更高
- startswith()方法Python startswith() 方法用于检查字符串是否是以指定子字符串开头如果是则返回 True,否则返
- 在javascript中原型(prototype)定义了特定类型的所有实例都可以访问的属性和方法,很多些情况下需要重新对原型中的属性赋值,如
- ASP从发布至今已经7年了,使用ASP技术已经相当成熟,自从微软推出了ASP.NET之后就逐渐停止了对ASP版本的更新。但是由于有很多人仍然
- 在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加
- 首先,来说一下对话框: 对话框在Windows应用程序中使用非常普遍,许多应用程序的设定,与用户交互需要通过对话框来进行,因此对话框是Win
- SQL Server TEXT、NTEXT字段拆分的问题引用的内容:SET NOCOUNT ON CREATE 
- 有助于效率的类型选择1、使你的数据尽可能小最基本的优化之一是使你的数据(和索引)在磁盘上(并且在内存中)占据的空间尽可能小。这能给出巨大的改
- 本文实例讲述了python操作mongodb根据_id查询数据的实现方法。分享给大家供大家参考。具体分析如下:_id是mongodb自动生成
- Base64是网络上最常见的用于传输8Bit字节码的编码方式之一,是一种基于64个可打印字符来表示二进制数据的方法。通过http传输图片常常
- 如下所示:def ref_txt_demo(): f = open('1.txt', 'r') data =
- 前言cookie使用最多的地方想必是保存用户的账号与密码,可以避免用户每次登录时都要重新输入1.vue中cookie的安装在终端中输入命令n
- Go令牌Go程序包括各种令牌和令牌可以是一个关键字,一个标识符,常量,字符串文字或符号。例如,下面的Go语句由六个令牌:fmt.Printl
- 本文实例讲述了python3实现的zip格式压缩文件夹操作。分享给大家供大家参考,具体如下:思路:先把第一级目录中的文件进行遍历,如果是文件