Pytorch BertModel的使用说明
作者:无聊的人生事无聊 发布时间:2023-06-28 09:30:37
基本介绍
环境: Python 3.5+, Pytorch 0.4.1/1.0.0
安装:
pip install pytorch-pretrained-bert
必需参数:
--data_dir: "str": 数据根目录.目录下放着,train.xxx/dev.xxx/test.xxx三个数据文件.
--vocab_dir: "str": 词库文件地址.
--bert_model: "str": 存放着bert预训练好的模型. 需要是一个gz文件, 如"..x/xx/bert-base-chinese.tar.gz ", 里面包含一个bert_config.json和pytorch_model.bin文件.
--task_name: "str": 用来选择对应数据集的参数,如"cola",对应着数据集.
--output_dir: "str": 模型预测结果和模型参数存储目录.
简单例子:
导入所需包
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
创建分词器
tokenizer = BertTokenizer.from_pretrained(--vocab_dir)
需要参数: --vocab_dir, 数据样式见此
拥有函数:
tokenize: 输入句子,根据--vocab_dir和贪心原则切词. 返回单词列表
convert_token_to_ids: 将切词后的列表转换为词库对应id列表.
convert_ids_to_tokens: 将id列表转换为单词列表.
text = '[CLS] 武松打老虎 [SEP] 你在哪 [SEP]'
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0,0,0,0, 1,1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
这里对标记符号的切词似乎有问题([cls]/[sep]), 而且中文bert是基于字级别编码的,因此切出来的都是一个一个汉字:
['[', 'cl', '##s', ']', '武', '松', '打', '老', '虎', '[', 'sep', ']', '你', '在', '哪', '[', 'sep', ']']
创建bert模型并加载预训练模型:
model = BertModel.from_pretrained(--bert_model)
放入GPU:
tokens_tensor = tokens_tensor.cuda()
segments_tensors = segments_tensors.cuda()
model.cuda()
前向传播:
encoded_layers, pooled_output= model(tokens_tensor, segments_tensors)
参数:
input_ids: (batch_size, sqe_len)代表输入实例的Tensor
token_type_ids=None: (batch_size, sqe_len)一个实例可以含有两个句子,这个相当于句子标记.
attention_mask=None: (batch_size*): 传入每个实例的长度,用于attention的mask.
output_all_encoded_layers=True: 控制是否输出所有encoder层的结果.
返回值:
encoded_layer:长度为num_hidden_layers的(batch_size, sequence_length,hidden_size)的Tensor.列表
pooled_output: (batch_size, hidden_size), 最后一层encoder的第一个词[CLS]经过Linear层和激活函数Tanh()后的Tensor. 其代表了句子信息
补充:pytorch使用Bert
主要分为以下几个步骤:
下载模型放到目录中
使用transformers中的BertModel,BertTokenizer来加载模型与分词器
使用tokenizer的encode和decode 函数分别编码与解码,注意参数add_special_tokens和skip_special_tokens
forward的输入是一个[batch_size, seq_length]的tensor,再需要注意的是attention_mask参数。
输出是一个tuple,tuple的第一个值是bert的最后一个transformer层的hidden_state,size是[batch_size, seq_length, hidden_size],也就是bert最后的输出,再用于下游的任务。
# -*- encoding: utf-8 -*-
import warnings
warnings.filterwarnings('ignore')
from transformers import BertModel, BertTokenizer, BertConfig
import os
from os.path import dirname, abspath
root_dir = dirname(dirname(dirname(abspath(__file__))))
import torch
# 把预训练的模型从官网下载下来放到目录中
pretrained_path = os.path.join(root_dir, 'pretrained/bert_zh')
# 从文件中加载bert模型
model = BertModel.from_pretrained(pretrained_path)
# 从bert目录中加载词典
tokenizer = BertTokenizer.from_pretrained(pretrained_path)
print(f'vocab size :{tokenizer.vocab_size}')
# 把'[PAD]'编码
print(tokenizer.encode('[PAD]'))
print(tokenizer.encode('[SEP]'))
# 把中文句子编码,默认加入了special tokens了,也就是句子开头加入了[CLS] 句子结尾加入了[SEP]
ids = tokenizer.encode("我是中国人", add_special_tokens=True)
# 从结果中看,101是[CLS]的id,而2769是"我"的id
# [101, 2769, 3221, 704, 1744, 782, 102]
print(ids)
# 把ids解码为中文,默认是没有跳过特殊字符的
print(tokenizer.decode([101, 2769, 3221, 704, 1744, 782, 102], skip_special_tokens=False))
# print(model)
inputs = torch.tensor(ids).unsqueeze(0)
# forward,result是一个tuple,第一个tensor是最后的hidden-state
result = model(torch.tensor(inputs))
# [1, 5, 768]
print(result[0].size())
# [1, 768]
print(result[1].size())
for name, parameter in model.named_parameters():
# 打印每一层,及每一层的参数
print(name)
# 每一层的参数默认都requires_grad=True的,参数是可以学习的
print(parameter.requires_grad)
# 如果只想训练第11层transformer的参数的话:
if '11' in name:
parameter.requires_grad = True
else:
parameter.requires_grad = False
print([p.requires_grad for name, p in model.named_parameters()])
添加atten_mask的方法:
其中101是[CLS],102是[SEP],0是[PAD]
>>> a
tensor([[101, 3, 4, 23, 11, 1, 102, 0, 0, 0]])
>>> notpad = a!=0
>>> notpad
tensor([[ True, True, True, True, True, True, True, False, False, False]])
>>> notcls = a!=101
>>> notcls
tensor([[False, True, True, True, True, True, True, True, True, True]])
>>> notsep = a!=102
>>> notsep
tensor([[ True, True, True, True, True, True, False, True, True, True]])
>>> mask = notpad & notcls & notsep
>>> mask
tensor([[False, True, True, True, True, True, False, False, False, False]])
>>>
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/Wangpeiyi9979/article/details/89191709
猜你喜欢
- 这里介绍了5中python获取window桌面路径的方法,获取这个路径有什么用呢?一般是将程序生成的文档输出到桌面便于查看编辑。前两个方法是
- language.xml 代码如下:<?xml version="1.0" encoding=
- 本文实例讲述了PHP 对象继承原理与简单用法。分享给大家供大家参考,具体如下:对象继承继承已为大家所熟知的一个程序设计特性,PHP 的对象模
- 如何自动更新导航栏?下面看看如何具体使用Content Linking组件: <&nbs
- 本文实例讲述了Python实现获取汉字偏旁部首的方法。分享给大家供大家参考,具体如下:功能介绍传入一个汉字,返回其偏旁部首字典分为本地字典与
- 本文实例讲述了Python图形绘制操作之正弦曲线实现方法。分享给大家供大家参考,具体如下:要画正弦曲线先设定一下x的取值范围,从0到2π。要
- 在网上查阅资料,发现很少用Python进行高斯函数的三维显示绘图的,原因可能是其图形显示太过怪异,没有MATLAB精细和直观。回顾一下二维高
- 一、简介基础知识:需要一定的html和css的语法知识基本概念:PHP(超文本预处理器)是一种通用开源脚本语言,在服务器上执行。PHP文件:
- 如下所示:# -*- coding:utf-8 -*-import sysreload(sys)sys.setdefaultencoding
- 一、“无”的哲学佛家讲究“因果报应”,有果必有应。此段看似与主题没有血缘关系,实际讲的是“因”。我个人比较喜欢老子的道家思想,并喜欢以其思想
- 我们到目前为止所谈到的SQL语句相对较为简单,如果再能通过标准的recordset循环查询,那么这些语句也能满足一些更复杂的要求。不过,何必
- 1. 二维(多维)数组降为一维数组方法1: reshape()+concatenate 函数,这个方法是间接法,利用 reshape() 函
- 前言:record类型,这是一种新引用类型,而不是类或结构。record与类不同,区别在于record类型使用基于值的相等性。例如:publ
- 以查询前20到30条为例,主键名为id 方法一: 先正查,再反查 select top 10 * from (select top 30 *
- 1)添加下面一句话到模型中for p in self.parameters(): p.requires_grad = False比如加载了r
- 如何在线查询本地机的文件?看看下面的例子,默认子目录与子虚拟目录为同一级别且名称一致,另我们使用了"http://intels.n
- 在JavaScript中,我们应该尽可能的用局部变量来代替全局变量,这句话所有人都知道,可是这句话是谁先说的?为什么要这么做?有什么根据么?
- 逆向最大匹配方法有正即有负,正向最大匹配算法大家可以参阅https://www.jb51.net/article/127404.htm逆向最
- 幸运草又名四叶草,一般指四叶的苜蓿、或车轴草。在十万株苜蓿草中,你可能只会发现一株是四叶草,机会率大约是十万分之一。因此四叶草是国际公认的幸
- fab命令好似结合我们编写的fabfile.py(其它文件名必须添加-f filename应用)来搭配使用的,部分命令行参数可以通过相应的方