基于pytorch 预训练的词向量用法详解
作者:kejizuiqianfang 发布时间:2021-04-04 21:35:37
如何在pytorch中使用word2vec训练好的词向量
torch.nn.Embedding()
这个方法是在pytorch中将词向量和词对应起来的一个方法. 一般情况下,如果我们直接使用下面的这种:
self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeding_dim)
num_embeddings=vocab_size 表示词汇量的大小
embedding_dim=embeding_dim 表示词向量的维度
这种情况下, 因为没有指定训练好的词向量, 所以embedding会帮咱们生成一个随机的词向量(但是在我刚刚测试的一个情感二分类问题中, 我发现好像用不用预训练的词向量, 结果差不多, 不过不排除是因为当时使用的模型比较简单, 导致一些特征根本就没提取出来).
如果我想使用word2vec预训练好的词向量该怎么做呢?
其实很简单,pytorch已经给我们提供好了接口
self.embedding.weight.data.copy_(torch.from_numpy(embeding_vector))
self.embedding.weight.requires_grad = False
上面两句代码的意思, 第一句就是导入词向量, 第二句表示的是在反向传播的时候, 不要对这些词向量进行求导更新. 我还看到有人会在优化器那里使用这样的代码:
# emotion_net是我定义的模型
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, emotion_net.parameters()), lr=1e-3, betas=(0.9, 0.99))
大概意思也是为了保证词向量不会被反向传播而更新, 具体有没有用我就不清楚了.
其实我感觉大家比较在意的其实应该是embeding_vector的形式, 下面我就介绍一下embeding_vector的形式
为了讲述方便, 这里定义出下面几个矩阵
embeding_vector:表示词向量,每行是一个词的词向量,有多少行就说明有多少单词
word_list:表示单词列表,里面就是单词
word_to_index:这个矩阵将word_list中的单词和embeding_vector中的位置对应起来
其实embeding_vector是一个numpy矩阵, 当然你看到了, 实际输入到pytorch的时候, 是需要转换成tensor类型的. 这个矩阵是什么样子的呢? 其中这个矩阵是 [vocab_size×embeding_dim] [vocab\_size \times embeding\_dim][vocab_size×embeding_dim] 的形式. 其中一共包含vocab_size vocab\_sizevocab_size 个单词, 每个单词的维度是 embed_dim embed\_dimembed_dim, 我们把这样一个矩阵输入就行了.
之后, 我们要做的其实就是将 word_to_index word\_to\_indexword_to_index 这个矩阵搞出来, 这里的单词转下标的矩阵, 就是联系 embeding_vector embeding\_vectorembeding_vector 和 word_list word\_listword_list 这两个矩阵的中间者. 我们在输入到torch.nn.Embedding中之前, 需要先通过 word_to_index word\_to\_indexword_to_index 将单词转换成 embeding_vector embeding\_vectorembeding_vector 的下标就可以了.
来源:https://blog.csdn.net/kejizuiqianfang/article/details/100825156


猜你喜欢
- 通过优化CSS代码,减小对系统资源的占用。自己整理出几个能减少系统资源占用的CSS写法,要优化网站的页面加载速度,这些注意点不能忽视!一、尽
- 简介日常开发中, 测试是不能缺少的.Go 标准库中有一个叫做 testing 的测试框架, 可以用于单元测试和性能测试.它是和命令 go t
- 插入记录时,影响插入速度的主要是索引、唯一性校验、一次插入记录条数等。根据这些情况,可以分别进行优化,本节将介绍优化插入记录速度的几种方法。
- 问:怎样解决MySQL 5.0.16的乱码问题?答:MySQL 5.0.16的乱码问题可以用下面的方法解决:1.设置phpMyAdminLa
- vendorvendor概念最早是由Keith提出,用来存放依赖包。在版本1.5出现。例如gb项目提供了一个名为gsftp的示例项目,它有一
- 目前还没有更好的方法来追写Excel,lorinnn在网上搜索到以及之后用到的方法就是使用第三方库xlutils来实现了这个功能,主体思想就
- 附加数据库出错:无法打开物理文件 "XXXXXXXXXXXXX"。操作系统错误 5:"5(拒绝访问。)&quo
- 工作中每天需要收集部门内的FR文件,发送给外部部门的同事帮忙上传,这么发了有大半年,昨天亮光一闪,为什么不做成自动化呢,于是用python实
- 1.前言有时候,我们需要把A库A1表某一部分或全部数据导出到B库B1表中,如果系统运维工程师没打通两个库链接,我们执行T-SQL是处理数据导
- 无限循环如果条件判断语句永远为 true,循环将会无限的执行下去,如下实例:#!/usr/bin/python# -*- coding: U
- OpenCV提供了很多简单的语句,实现复杂的功能,根据颜色和鼠标交互的基础语句,我们可以建立一个简单的画板。尽管它简单,但是制作的框架步骤不
- 双休日常常意味着很多休息时间。与其懒洋洋地坐在那里玩游戏,为何不学点新知识武装自己?本文中不会特定推荐哪种编程语言,但是会提供基于GitHu
- 前言在之前写过一篇博客"关系数据库如何快速查询表的记录数",里面介绍了使用sp_spaceused查看表的记录数是否正确
- 这篇文章主要介绍了python解析命令行参数的三种方法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- bootstrap中有javascript插件modal也就是对话框,加入拖拽功能,具体内容如下在使用modal时首选需要引用js<l
- 在js中调用asp文件的方法很简单,我们可以用在静态页面的点击数统计,虽然静态页面不支持asp程序,但是我们可以使用js调用,来变相的达到这
- 参考网址 https://www.jb51.net/article/29551.htmSELECT [StartDate] FROM [db
- MapPathMapPath 方法将指定的相对或虚拟路径映射到服务器上相应的物理目录上。语法Server.MapPath( Path ) 参
- 这个问题好像在各种数据库中都存在,该如何处理呢?一、SQL中:sql="CREATE TABLE phone&
- 先给大家介绍下Python 字符串前面加u,r,b,f的含义(字符串前缀)1、字符串前加 u例:u"我是含有中文字符组成的字符串。