pytorch中的embedding词向量的使用方法
作者:乐且有仪 发布时间:2022-03-25 09:05:27
标签:pytorch,embedding,词向量
Embedding
词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就相当于是一个大矩阵,矩阵的每一行表示一个单词。
emdedding初始化
默认是随机初始化的
import torch
from torch import nn
from torch.autograd import Variable
# 定义词嵌入
embeds = nn.Embedding(2, 5) # 2 个单词,维度 5
# 得到词嵌入矩阵,开始是随机初始化的
torch.manual_seed(1)
embeds.weight
# 输出结果:
Parameter containing:
-0.8923 -0.0583 -0.1955 -0.9656 0.4224
0.2673 -0.4212 -0.5107 -1.5727 -0.1232
[torch.FloatTensor of size 2x5]
如果从使用已经训练好的词向量,则采用
pretrained_weight = np.array(args.pretrained_weight) # 已有词向量的numpy
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
embed的读取
读取一个向量。
注意参数只能是LongTensor型的
# 访问第 50 个词的词向量
embeds = nn.Embedding(100, 10)
embeds(Variable(torch.LongTensor([50])))
# 输出:
Variable containing:
0.6353 1.0526 1.2452 -1.8745 -0.1069 0.1979 0.4298 -0.3652 -0.7078 0.2642
[torch.FloatTensor of size 1x10]
读取多个向量。
输入为两个维度(batch的大小,每个batch的单词个数),输出则在两个维度上加上词向量的大小。
Input: LongTensor (N, W), N = mini-batch, W = number of indices to extract per mini-batch
Output: (N, W, embedding_dim)
见代码
# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# 每批取两组,每组四个单词
input = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))
a = embedding(input) # 输出2*4*3
a[0],a[1]
输出为:
(Variable containing:
-1.2603 0.4337 0.4181
0.4458 -0.1987 0.4971
-0.5783 1.3640 0.7588
0.4956 -0.2379 -0.7678
[torch.FloatTensor of size 4x3], Variable containing:
-0.5783 1.3640 0.7588
-0.5313 -0.3886 -0.6110
0.4458 -0.1987 0.4971
-1.3768 1.7323 0.4816
[torch.FloatTensor of size 4x3])
来源:https://blog.csdn.net/david0611/article/details/81090371


猜你喜欢
- 前言这部分已经折腾我两天了,还是没有头绪,可能还会折腾更久,最后在第三天上午解决问题,在一个不起眼的地方被坑了,jQuery加载的问题。会者
- 要想在ASP.NET项目中使用SQLite数据库,先需下载一个ADO.NET 2.0 SQLite Data Provider,下载地址为:
- 原因:使用git clone项目后,项目根路径是小写英文名称,比如cmdbapi,但是项目里面的import导入自己的相关包时,红色报错解决
- 正则表达式,又称正规表示法、常规表示法(英语:Regular Expression,在代码中常简写为regex、regexp或RE),计算机
- WSGI协议首先弄清下面几个概念:WSGI:全称是Web Server Gateway Interface,WSGI不是服务器,python
- 本文实例讲述了Python将名称映射到序列元素中的方法。分享给大家供大家参考,具体如下:问题:希望通过名称来访问元素,减少结构中对位置的依赖
- 前言废话不多说,直接开造。这里的话我们有两个目标,第一个是如何把一个2维图片上的点映射到3维空间。第二就是如何生成3D点云。当然实际上这是一
- MySQL性能优化在互联网公司MySQL的使用非常广泛,大家经常会有MySQL性能优化方面的需求。整理了一些在MySQL优化方面的实用技巧。
- 一、实战场景Flask 框架实现用户的注册,登录和登出。二、主要知识点flask_login 插件使用SQLAlchemy 基础操作用户基础
- MySQL支持的两种主要表存储格式MyISAM,InnoDB,上个月做个项目时,先使用了InnoDB,结果速度特别慢,1秒钟只能插入10几条
- reshape函数:改变数组的维数(注意不是shape大小)>>> e= np.arange(10)>>>
- 由于javascript无法获取img文件头数据,必须等待其加载完毕后才能获取真实的大小,所以lightbox类效果为了让图片居中显示,导致
- 往期学习:python数据类型: python数据结构:数据类型.python的输入输出: python数据结构之输入输出及控制和异常.py
- 在使用数据库的时候,难免要在使用过程中进行删除的操作,如果是使用int类型的字段,令其自增长,这是个最简单的办法,但是后果会有些不是你想要的
- ASP具备动态输出任一Office应用程序文件格式的功能。在开始编写代码之前,我们首先需要做的就是设置正确的文件类型,因为浏览器需要知道如何
- 该章节我们来学习一下在 Python 中去创建并使用多进程的方法,通过学习该章节,我们将可以通过创建多个进程来帮助我们提高脚本执行的效率。可
- 目录1.jupyter2.jupyter基础操作2.1windows更新pip库2.2jupyter安装2.3初次启动jupyter2.4设
- 引言Python 是一个强大的语言,提供了许多内置函数以帮助开发者编写高效、简洁的代码。在这篇文章中,我们将深入探讨三个内置函数:map、f
- 前言JS为什么要用ajax来提交在使用from提交时,浏览器会向服务器发送选中的文件的内容而不仅仅是发送文件名。为安全起见,即file-up
- 表分区是最近才知道的哦 ,以前自己做都是分表来实现上亿级别的数据了,下面我来给大家介绍一下mysql表分区创建与使用吧,希望对各位同学会有所