pytorch_pretrained_bert如何将tensorflow模型转化为pytorch模型
作者:乐清sss 发布时间:2022-04-18 18:07:58
pytorch_pretrained_bert将tensorflow模型转化为pytorch模型
BERT仓库里的模型是TensorFlow版本的,需要进行相应的转换才能在pytorch中使用
在Google BERT仓库里下载需要的模型,这里使用的是中文预训练模型(chinese_L-12_H-768_A_12)
下载chinese_L-12_H-768_A-12.zip后解压,里面有5个文件
chinese_L-12_H-768_A-12.zip后解压,里面有5个文件
bert_config.json
bert_model.ckpt.data-00000-of-00001
bert_model.ckpt.index
bert_model.ckpt.meta
vocab.txt
使用bert仓库里的convert_bert_original_tf_checkpoint_to_pytorch.py将此模型转化为pytorch版本的,这里我的文件夹位置为:D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12,替换为自己的即可
python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\bert_model.ckpt --bert_config_file D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\bert_config.json --pytorch_dump_path D:\Work\BISHE\BERT-Dureader\data\chinese_L-12_H-768_A-12\pytorch_model.bin
注:这里让我疑惑的是模型有5个文件,为什么转化的时候使用的是bert_model.ckpt,而且这个文件也不存在呀,是我对TensorFlow的模型不太熟悉,查阅资料之后将5个文件的作用说明如下:
$ tree chinese_L-12_H-768_A-12/
chinese_L-12_H-768_A-12/
├── bert_config.json <- 模型配置文件
├── bert_model.ckpt.data-00000-of-00001 <- 保存断点文件列表,可以用来迅速查找最近一次的断点文件
├── bert_model.ckpt.index <- 为数据文件提供索引,存储的核心内容是以tensor name为键以BundleEntry为值的表格entries,BundleEntry主要内容是权值的类型、形状、偏移、校验和等信息。
├── bert_model.ckpt.meta <- 是MetaGraphDef序列化的二进制文件,保存了网络结构相关的数据,包括graph_def和saver_def等
└── vocab.txt <- 模型词汇表文件
0 directories, 5 files
在调用模型时使用chinese_L-12_H-768_A-12\bert_model.ckpt即可。
TensorFlow 读取ckpt文件中的tensor,将ckpt模型转为pytorch模型
想用MobileNet V1训练自己的数据,发现pytorch没有MobileNet V1的预训练权重,只好先下载TensorFlow的预训练权重,再转成pytorch模型。
读取ckpt中的Tensor名称以及Tensor值
TensorFlow的MobileNet V1预训练权重文件如下:
解压完文件后,发现没有.ckpt文件,文件名只需'./my_model/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt'这样写就行。
写一半发现Tensor名称好难对应起来。希望能给大家一个参考,也希望大家多多支持脚本之家
来源:https://blog.csdn.net/sunyueqinghit/article/details/103458365


猜你喜欢
- strIn 为 输入的Email地址字符串变量 返回为true或falsereturn Regex.IsMatch(strIn, @&quo
- 环境:Anaconda自带的编译器——Spyder最近才开使用conda,发现conda 就是 yyds,爱啦~一、Tensor(张量)im
- 1.彻底弄懂CSS盒子模式一(DIV布局快速入门) 2.彻底弄懂CSS盒子模式二(导航栏实例) 4.彻底弄懂CSS盒子模式四(绝对定位和相对
- #!/usr/bin/python#-*-coding:utf-8-*-# JCrawler# Author: Jam <810441
- 一、创建一个项目如果这是你第一次使用Django,那么你必须进行一些初始设置。也就是通过自动生成代码来建立一个Django项目--一个Dja
- PHP addslashes() 函数实例在每个双引号(")前添加反斜杠:<?php $str = addslashes(&
- (以下内容部分内容参考了http://adomas.org/javascript-mouse-wheel/ )之前js 仿Photoshop
- 简单使用csv.DictReader()方法示例代码1:import csvf = open('sample','r
- 本文实例讲述了Python实现基于C/S架构的聊天室功能。分享给大家供大家参考,具体如下:一、课程介绍1.简介本次项目课是实现简单聊天室程序
- 定义列表和其他类型的列表稍有不同,它由两部分组成:名称和定义。DT 指定名称,为内联元素。DD 指定定义,为块级元素。标准属性id, cla
- Go语言拼接URL路径有多种方法建议用ResolveReference。JoinPathJoinPath会把多个多个路径合并成一个路径,并且
- 通用用法但上图的字段名,类型需要根据不同接口填写,如某服务接口:因而对应的上传代码如下:# 输出参数:请求响应报文import reques
- 前言上篇文章,讲了经典卷积神经网络-resnet,这篇文章通过resnet网络,做一些具体的事情。一、技术介绍总的来说,第一步首先要加载数据
- 前几天在“CSS那些事儿”的群中,一位读者朋友(小土豆)问我书中提到首字下沉的时候为什么要增加一个清除浮动。当时我自己一时迷惑了,为什么呢,
- This application failed to start because it could not find or load the
- MySQL Community Server 8.0.29安装教程,供大家参考,具体内容如下一、简要说明仅安装MySQL服务器步骤二、前期准
- 前段时间用C语言做了个字符版的推箱子,着实是比较简陋。正好最近用到了Python,然后想着用Python做一个图形界面的推箱子。这回可没有C
- 使用场景当项目越来越庞大之后,不可避免的要拆分成多个子模块,我们希望各个子模块有独立的版本管理,并且由专门的人去维护,这时候我们就要用到gi
- Python PyTorch深度学习框架PyTorch是一个基于Python的深度学习框架,它支持使用CPU和GPU进行高效的神经网络训练。
- sys.dm_io_pending_io_requests可以返回当前IO Pending的状态,对于SQL Server 中每个挂起的I/