PyTorch预训练Bert模型的示例
作者:BLACK 发布时间:2021-11-12 14:31:39
本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。
transformers模块简介
transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。
本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。
代码实现
首先安装数据下载模块和transformers包。
pip install datasets
pip install transformers
使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。
from datasets import load_dataset
imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据
接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)
对原始数据进行编码,并且分批次(batch)
def preprocessing_func(examples):
return tokenizer(examples['text'],
padding=True,
truncation=True, max_length=300)
batch_size = 16
encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)
上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。
from transformers import Trainer, TrainingArguments
args = TrainingArguments(
'out',
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
learning_rate=5e-5,
evaluation_strategy='epoch',
num_train_epochs=10,
load_best_model_at_end=True,
)
trainer = Trainer(
model,
args=args,
train_dataset=encoded_data['train'],
eval_dataset=encoded_data['test'],
tokenizer=tokenizer
)
训练模型使用trainer对象的train方法
trainer.train()
评估模型使用trainer对象的evaluate方法
trainer.evaluate()
总结
本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载
来源:http://www.blackedu.vip/729/pytorch-yu-xun-lianbert-mo-xing/?utm_source=tuicool&utm_medium=referral


猜你喜欢
- 1、看机器配置,指三大件:cpu、内存、硬盘2、看mysql配置参数3、查系mysql行状态,可以用mysqlreport工具来查看4、查看
- 一、mock.js的使用mock.js的使用步骤① 下载依赖 npm install mock -d(开发环境使用)② 引入到main.js
- 今天我们整理了ip地址和身份证的javascript验证方法。虽然ip地址和身份证的验证不是很经常会遇到,但是大家也可以研究一下js代码,里
- 废话不多说,直接开干!抖音字符视频在今年火过一段时间。反正我是始终忘不了那段刘耕宏老师本草纲目的音乐…这一次自己也来实
- //清空form选择 function clearForm(id){ var formObj = document.getElementBy
- 对于三目运算符(ternary operator),python可以用conditional expressions来替代如对于x<5
- 本文实例为大家分享了Vue+Flask实现图片传输功能的具体代码,供大家参考,具体内容如下完整流程:1.图片转为formdata 传输到后端
- CrawlSpider作用:用于进行全站数据爬取CrawlSpider就是Spider的一个子类如何新建一个基于CrawlSpider的爬虫
- 目录什么是 JSON在哪里使用JSON基本的 JSON 语法如何在 Python 中处理 JSON 数据包含 JSON 模块使用 json.
- 从Python3.2引入的concurrent.futures模块,Python2.5以上需要在pypi中安装futures包。future
- 本文实例讲述了基于进程内通讯的python聊天室实现方法。分享给大家供大家参考。具体如下:#!/usr/bin/env python# Ad
- MongoDB是一个文档型数据库,是NOSQL家族中最重要的成员之一,以下代码封装了MongoDB的基本操作。MongoDBConfig.j
- python / 和 % 和 //(地板除)用于对数据进行除法运算。python中 // 和 / 和 %简介python中与除法相关的三个运
- 代码共享url: http://code.google.com/p/region-select-js/ 数据已经更新到中国统计局网站中的20
- Python偏函数Python偏函数和我们之前所学习的函数传参中的缺省参数有些类似,但是在实际应用中还是有所区别的,下面通过模拟一个场景一步
- 本文实例讲述了Python决策树和随机森林算法。分享给大家供大家参考,具体如下:决策树和随机森林都是常用的分类算法,它们的判断逻辑和人的思维
- 相关知识点:#key-value#字典是无序的,因为他没有下标,通过key找info={ 'stu01':"liu
- asp替换函数如下:Function ReplaceNoIgnoreCase(str,replStr) &n
- 一、CAN报文简介CAN是控制器局域网络(Controller Area Network, CAN)的简称,是由以研发和生产汽车电子产品著称
- 数据处理在现代企业运营中变得越来越重要,越来越关键,甚至会成为企业发展的一项瓶颈. 数据保护的重要性也不言而喻. 如果一个企业没有很好的数据