Pytorch框架之one_hot编码函数解读
作者:NULL 发布时间:2023-02-16 11:34:05
Pytorch one_hot编码函数解读
one_hot编码定义
在一个给定的向量中,按照设定的最值–可以是向量中包含的最大值(作为最高分类数),有也可以是自定义的最大值,设计one_hot编码的长度:最大值+1【详见举的例子吧】。
然后按照最大值创建一个1*(最大值+1)的维度大小的全零零向量:[0, 0, 0, …] => 共最大值+1对应的个数
接着按照向量中的值,从第0位开始索引,将向量中值对应的位置设置为1,其他保持为0.
eg:
假设设定one_hot长度为4(最大值) –
且当前向量中值为1对应的one_hot编码:
[0, 1, 0, 0]
当前向量中值为2对应的one_hot编码:
[0, 0, 1, 0]
eg:
假设设定one_hot长度为6(等价最大值+1) –
且当前向量中值为4对应的one_hot编码:
[0, 0, 0, 0, 1, 0]
当前向量中值为2对应的one_hot编码:
[0, 0, 1, 0, 0, 0]
eg:
targets = [4, 1, 0, 3] => max_value=4=>one_hot的长度为(4+1)
假设设定one_hot长度为5(最大值) –
且当前向量中值为4对应的one_hot编码:
[0, 0, 0, 0, 1]
当前向量中值为1对应的one_hot编码:
[0, 1, 0, 0, 0]
Pytorch中one_hot转换
import torch
targets = torch.tensor([5, 3, 2, 1])
targets_to_one_hot = torch.nn.functional.one_hot(targets) # 默认按照targets其中的最大值+1作为one_hot编码的长度
# result:
# tensor(
# [0, 0, 0, 0, 0, 1],
# [0, 0, 0, 1, 0, 0],
# [0, 0, 1, 0, 0, 0],
# [0, 1, 0, 0, 0, 0]
#)
targets_to_one_hot = torch.nn.functional.one_hot(targets, num_classes=7) 3# 指定one_hot编码长度为7
# result:
# tensor(
# [0, 0, 0, 0, 0, 1, 0],
# [0, 0, 0, 1, 0, 0, 0],
# [0, 0, 1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0, 0, 0]
#)
总结:one_hot编码主要用于分类时,作为一个类别的编码–方便判别与相关计算;
1. 如同类别数统计,只需要将one_hot编码相加得到一个一维向量就知道了一批数据中所有类别的预测或真实的分布情况;
2. 相比于预测出具体的类别数–43等,用向量可以使用向量相关的算法进行时间上的优化等等
Pytorch变量类型转换及one_hot编码表示
生成张量
y = torch.empty(3, dtype=torch.long).random_(5)
y = torch.Tensor(2,3).random_(10)
y = torch.randn(3,4).random_(10)
查看类型
y.type
y.dtype
类型转化
tensor.long()/int()/float()
long(),int(),float() 实现类型的转化
One_hot编码表示
def one_hot(y):
'''
y: (N)的一维tensor,值为每个样本的类别
out:
y_onehot: 转换为one_hot 编码格式
'''
y = y.view(-1, 1)
# y_onehot = torch.FloatTensor(3, 5)
# y_onehot.zero_()
y_onehot = torch.zeros(3,5) # 等价于上面
y_onehot.scatter_(1, y, 1)
return y_onehot
y = torch.empty(3, dtype=torch.long).random_(5) #标签
res = one_hot(y) # 转化为One_hot类型
# One_hot类型标签转化为整数型列表的两种方法
h = torch.argmax(res,dim=1)
_,h1 = res.max(dim=1)
expand()函数
这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。
a=torch.randn(1,1,3,768)
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])
repeat()函数
沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。
a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])
来源:https://blog.csdn.net/weixin_44604887/article/details/109523281


猜你喜欢
- 本节内容深浅拷贝循环方式字典常用方法总结一、深浅拷贝列表、元组、字典(以及其他)对于列表、元组和字典而言,进行赋值(=)、浅拷贝(copy)
- 平面设计 常用尺寸 三折页广告 标准尺寸: (A4)210mm x 285mm普通宣传册 标准尺寸: (A4)210mm x 285mm文件
- 前言在搜集了很多文本语料之后,会开始漫长的数据清洗过程,通常要不断迭代。1. 问题描述有些文本数据中,会包含一些特殊符号。猜想可能是从某些富
- 最近有需求是,需要把对方提供的ftp地址上的图片获取到本地服务器,原先计划想着是用shell 操作,因为shell 本身也支持ftp的命令
- 一、格式化输入和输出1.从终端获取用户的输入fmt.Scanf 空格作为分隔符,占位符和格式化输出的一致fmt.Scan 从终端获取用户的输
- 本文实例为大家分享了微信小程序实现侧边导航栏的具体代码,供大家参考,具体内容如下效果图wxml<view class='pro
- 功能:实现网页内容的即时编辑,增加页面的可用性、交互性。方法1:直接通过textarea标签实现,请运行下边代码:<!DOCTYPE
- 配置要求:IIS(win2000 server 自带)、Java 2 SDK 1.4.2 (或更高版本)、Tomcat Web Server
- 不知道算不算DW4的大BUG. DW4实际的运行如下: 读注册表中HKEY_CURRENT_USER/
- 代码如下所示:import osimport requestsimport datetimefrom Crypto.Cipher impor
- 前言这里先说明一下,网上很多人说阿里规定500w数据就要分库分表。实际上,这个500w并不是定义死的,而是与MySQL的配置以及机器的硬件有
- 简介Elasticsearch 是一个分布式可扩展的实时搜索和分析引擎,一个建立在全文搜索引擎 Apache Lucene&trad
- EXCEL数据上传到SQL SERVER中的方法需要注意到三点!注意点一:要把EXCEL数据上传到SQL SERVER中必须提前把EXCEL
- 背景故事2022虎年将至,值此新春佳节之际,各大社区更是你争我赶纷纷发起春节征文活动正当我一筹莫展之际,几位粉丝朋友们的小请求点醒了我:对呀
- Git 恢复到之前版本1. 应用场景进行了错误提交,需要将代码回退至某个版本;或者需要检出某个版本的代码,再切换回最新版本。2. 解决方法2
- 记录一下:# Three loss functionscategory_predict1 = Dense(100, activation=&
- SQL Server具有强大的复制功能,除了将数据和数据库对象从一个数据库复制并准确分发的另一个数据库中,还要实行数据库之间的同步。SQL
- 一、Python 下载Python是运行的环境,必不可少,如果你是Linux系统的话,不用安装,自带了Python。首先我们打开浏览器搜索P
- 背景我们先来看看MySQL 8.0的事务提交的大致流程以上流程,是MySQL8.0对WAL原则的一种实现,这个流程意味着,任何一个事务的提交
- 前几天要写一个东西里面有用到读文件的。 可是我不想用FSO,我怕有的空间不支持。 &nbs