Pytorch框架之one_hot编码函数解读
作者:NULL 发布时间:2023-02-16 11:34:05
Pytorch one_hot编码函数解读
one_hot编码定义
在一个给定的向量中,按照设定的最值–可以是向量中包含的最大值(作为最高分类数),有也可以是自定义的最大值,设计one_hot编码的长度:最大值+1【详见举的例子吧】。
然后按照最大值创建一个1*(最大值+1)的维度大小的全零零向量:[0, 0, 0, …] => 共最大值+1对应的个数
接着按照向量中的值,从第0位开始索引,将向量中值对应的位置设置为1,其他保持为0.
eg:
假设设定one_hot长度为4(最大值) –
且当前向量中值为1对应的one_hot编码:
[0, 1, 0, 0]
当前向量中值为2对应的one_hot编码:
[0, 0, 1, 0]
eg:
假设设定one_hot长度为6(等价最大值+1) –
且当前向量中值为4对应的one_hot编码:
[0, 0, 0, 0, 1, 0]
当前向量中值为2对应的one_hot编码:
[0, 0, 1, 0, 0, 0]
eg:
targets = [4, 1, 0, 3] => max_value=4=>one_hot的长度为(4+1)
假设设定one_hot长度为5(最大值) –
且当前向量中值为4对应的one_hot编码:
[0, 0, 0, 0, 1]
当前向量中值为1对应的one_hot编码:
[0, 1, 0, 0, 0]
Pytorch中one_hot转换
import torch
targets = torch.tensor([5, 3, 2, 1])
targets_to_one_hot = torch.nn.functional.one_hot(targets) # 默认按照targets其中的最大值+1作为one_hot编码的长度
# result:
# tensor(
# [0, 0, 0, 0, 0, 1],
# [0, 0, 0, 1, 0, 0],
# [0, 0, 1, 0, 0, 0],
# [0, 1, 0, 0, 0, 0]
#)
targets_to_one_hot = torch.nn.functional.one_hot(targets, num_classes=7) 3# 指定one_hot编码长度为7
# result:
# tensor(
# [0, 0, 0, 0, 0, 1, 0],
# [0, 0, 0, 1, 0, 0, 0],
# [0, 0, 1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0, 0, 0]
#)
总结:one_hot编码主要用于分类时,作为一个类别的编码–方便判别与相关计算;
1. 如同类别数统计,只需要将one_hot编码相加得到一个一维向量就知道了一批数据中所有类别的预测或真实的分布情况;
2. 相比于预测出具体的类别数–43等,用向量可以使用向量相关的算法进行时间上的优化等等
Pytorch变量类型转换及one_hot编码表示
生成张量
y = torch.empty(3, dtype=torch.long).random_(5)
y = torch.Tensor(2,3).random_(10)
y = torch.randn(3,4).random_(10)
查看类型
y.type
y.dtype
类型转化
tensor.long()/int()/float()
long(),int(),float() 实现类型的转化
One_hot编码表示
def one_hot(y):
'''
y: (N)的一维tensor,值为每个样本的类别
out:
y_onehot: 转换为one_hot 编码格式
'''
y = y.view(-1, 1)
# y_onehot = torch.FloatTensor(3, 5)
# y_onehot.zero_()
y_onehot = torch.zeros(3,5) # 等价于上面
y_onehot.scatter_(1, y, 1)
return y_onehot
y = torch.empty(3, dtype=torch.long).random_(5) #标签
res = one_hot(y) # 转化为One_hot类型
# One_hot类型标签转化为整数型列表的两种方法
h = torch.argmax(res,dim=1)
_,h1 = res.max(dim=1)
expand()函数
这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。
a=torch.randn(1,1,3,768)
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])
repeat()函数
沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。
a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])
来源:https://blog.csdn.net/weixin_44604887/article/details/109523281
猜你喜欢
- 1. 需要的库, redispip install redis2. 连接Redisimport redisclass RedisCtrl(o
- 一、什么是字典树在自然语言处理中,字符串集合常用字典树存储,这是一种字符串上的树形数据结构。字典树中每条边都对应一个字,从根节点往下的路径构
- 【导读】亚马逊的 Alexa 的巨大成功已经证明:在不远的将来,实现一定程度上的语音支持将成为日常科技的基本要求。整合了语音识别的 Pyth
- 本文将介绍使用Dreamweaver来制作滑动菜单的方法,言归正传,废话少说。准备工作如下: 1. 在dw中新建一个空白文档(或者打开你要添
- 序言哈喽兄弟们,今天来实现一个Python采集视频、弹幕、评论与一体的小软件。平常咱们都是直接代码运行,不过今天我们做成软件,这样的话,咱们
- Python处理json字符串中的非法双引号工作中数据清洗时遇到以下情况:a = '{"地区": "湖
- 叨逼叨首先,介绍一下 pdb 调试,pdb 是 python 的一个内置模块,用于命令行来调试 Python 代码。或许你会说,现在用 Py
- 1.安装vscode和python3.7(安装路径在:E:\Python\Python37);2.打开vscode,在左下角点击设置图标选择
- 请问如何从ASP连接到Oracle Server?可用下面的代码进行连接: <%@ Lan
- 外部数据导入导入excel文件pandas导入excel用read_excel()方法:import pandas as pdexcel_f
- ASP通过XMLDom在服务器端操作XML文件的主要方法和实现对于小数据量,xml文件在检索更新上于ACCESS有很多优势。我曾经测试过不用
- 使用Python可视化Pygal包来生成可缩放的矢量图形文件!对于在尺寸不同的屏幕上显示图标,它们将自动缩放以适合观看者的屏幕,如果以在线的
- 内容摘要: Request和Response这两个对象是ASP所提供的内置对象中最常用的两个。在浏览器(或其他用户代理)和Web服
- 本文实例讲述了go语言睡眠排序算法。分享给大家供大家参考。具体分析如下:睡眠排序算法是一个天才程序员发明的,想法很简单,就是针对数组里的不同
- 格式化字符串漏洞覆盖大数字时,如果选择一次性输出大数字个字节来进行覆盖,会很久很久,或者直接报错中断,所以来搞个攻防世界高手区的题目来总结一
- 以下针对Ubuntu系统,Windows系统没有测试过。Ubuntu中默认就安装有Python 2.x和Python 3.x,默认情况下py
- 一、用HTTP头信息 也就是用PHP的HEADER函数。PHP里的HEADER函数的作用就是向浏览器发出由HTTP协议规定的本来应该通过WE
- 1,为了获取视频,你应该创建一个 VideoCapture 对象。他的参数可以是设备的索引号,或者是一个视频文件。设备索引号就是在指定要使用
- 本文实例为大家分享了python+pygame实现坦克大战的具体代码,供大家参考,具体内容如下一、首先导入pygame库二、源码分享#cod
- 一、数据类型1.数据类型的判断Number => int float complex bool容器 => str list tu