pytorch制作自己的LMDB数据操作示例
作者:团长sama 发布时间:2023-05-24 11:51:27
标签:pytorch,LMDB,数据操作
本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
前言
记录下pytorch里如何使用lmdb的code,自用
制作部分的Code
code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
if imageBin is None:
return False
imageBuf = np.fromstring(imageBin, dtype=np.uint8)
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
imgH, imgW = img.shape[0], img.shape[1]
if imgH * imgW == 0:
return False
return True
def writeCache(env, cache):
with env.begin(write=True) as txn:
for k, v in cache.items():
txn.put(k.encode(), v)
def _is_difficult(word):
assert isinstance(word, str)
return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
"""
Create LMDB dataset for CRNN training.
ARGS:
outputPath : LMDB output path
imagePathList : list of image path
labelList : list of corresponding groundtruth texts
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
assert(len(imagePathList) == len(labelList))
nSamples = len(imagePathList)
env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
cache = {}
cnt = 1
for i in range(nSamples):
imagePath = imagePathList[i]
label = labelList[i]
if len(label) == 0:
continue
if not os.path.exists(imagePath):
print('%s does not exist' % imagePath)
continue
with open(imagePath, 'rb') as f:
imageBin = f.read()
if checkValid:
if not checkImageIsValid(imageBin):
print('%s is not a valid image' % imagePath)
continue
#数据库中都是二进制数据
imageKey = 'image-%09d' % cnt#9位数不足填零
labelKey = 'label-%09d' % cnt
cache[imageKey] = imageBin
cache[labelKey] = label.encode()
if lexiconList:
lexiconKey = 'lexicon-%09d' % cnt
cache[lexiconKey] = ' '.join(lexiconList[i])
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
cnt += 1
nSamples = cnt-1
cache['num-samples'] = str(nSamples).encode()
writeCache(env, cache)
print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
with open(txt_path,'r') as fr:
jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
txt_content_list=[]
for jpg in jpg_list:
label_path=jpg.replace('.jpg','.txt')
with open(label_path,'r') as fr:
try:
str_tmp=fr.readline()
except UnicodeDecodeError as e:
print(label_path)
raise(e)
txt_content_list.append(str_tmp.strip())
return jpg_list,txt_content_list
if __name__ == "__main__":
txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
imagePathList,labelList=get_sample_list(txt_path)
createDataset(lmdb_output_path, imagePathList, labelList)
读取部分
这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__
from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
import moxing as mox
#moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
def __init__(self, root, voc_type, max_len, num_samples, transform=None):
super(LmdbDataset, self).__init__()
if global_args.run_on_remote:
dataset_name = os.path.basename(root)
data_cache_url = "/cache/%s" % dataset_name
if not os.path.exists(data_cache_url):
os.makedirs(data_cache_url)
if mox.file.exists(root):
mox.file.copy_parallel(root, data_cache_url)
else:
raise ValueError("%s not exists!" % root)
self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
else:
self.env = lmdb.open(root, max_readers=32, readonly=True)
assert self.env is not None, "cannot create lmdb from %s" % root
self.txn = self.env.begin()
self.voc_type = voc_type
self.transform = transform
self.max_len = max_len
self.nSamples = int(self.txn.get(b"num-samples"))
self.nSamples = min(self.nSamples, num_samples)
assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
self.EOS = 'EOS'
self.PADDING = 'PADDING'
self.UNKNOWN = 'UNKNOWN'
self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
self.char2id = dict(zip(self.voc, range(len(self.voc))))
self.id2char = dict(zip(range(len(self.voc)), self.voc))
self.rec_num_classes = len(self.voc)
self.lowercase = (voc_type == 'LOWERCASE')
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
index += 1
img_key = b'image-%09d' % index
imgbuf = self.txn.get(img_key)
#由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
try:
img = Image.open(buf).convert('RGB')
# img = Image.open(buf).convert('L')
# img = img.convert('RGB')
except IOError:
print('Corrupted image for %d' % index)
return self[index + 1]
# reconition labels
label_key = b'label-%09d' % index
word = self.txn.get(label_key).decode()
if self.lowercase:
word = word.lower()
## fill with the padding token
label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
label_list = []
for char in word:
if char in self.char2id:
label_list.append(self.char2id[char])
else:
## add the unknown token
print('{0} is out of vocabulary.'.format(char))
label_list.append(self.char2id[self.UNKNOWN])
## add a stop token
label_list = label_list + [self.char2id[self.EOS]]
assert len(label_list) <= self.max_len
label[:len(label_list)] = np.array(label_list)
if len(label) <= 0:
return self[index + 1]
# label length
label_len = len(label_list)
if self.transform is not None:
img = self.transform(img)
return img, label, label_len
希望本文所述对大家Python程序设计有所帮助。
来源:https://blog.csdn.net/sinat_24899403/article/details/102795355


猜你喜欢
- 使用环境在cmd模式下输入 mysql --version (查看mysql安装的版本).完整的命令可以通过mysql --help来获取.
- 前言说起这个事情吧也相对来说比较尴尬,对于一个技术来说忘记密码然后找回密码都是相当简单的一个事情,但是在生产环境中没有保存记录只能是自己的失
- 如何通过PHP实现Des加密算法代码实例注:php7以上不支持了,因为php7去掉了某些函数, 另外变量的{}要改为[]<?phpcl
- prompt命令可以在mysql提示符中显示当前用户、数据库、时间等信息mysql -uroot -p --prompt="\\u
- 本文实例讲述了python config文件的读写操作。分享给大家供大家参考,具体如下:1、设置配置文件[mysql]host = 1234
- 问题:不同版本提交的城市文件夹数量固定,怎样确定本版本成果中缺少了哪些城市?背景:已有参照文件作为标准,利用取差集的方法#-*- codin
- 描述: 日志按日期、大小回滚代码:# -*- coding: utf-8 -*-import osimport logging.handle
- 文件名字处理文件名字得看业务要求。不需要保留原始名字,则随机生成名字,拼接上白名单校验过的后缀即可。反之要谨慎处理://允许上传的后缀白名单
- 在matplotlib官网看到了第三方库numpngw的简介,利用该库作为插件可以辅助matplotlib生成png动画。numpngw概述
- 简介网上流传的部分可以百度关键词“Python”和“word”后查看文章学习,以下内容为个人实践,修正了不能运行出错的情况。代码示例impo
- 约定:import pandas as pdimport numpy as npReIndex重新索引reindex()是pandas对象的
- 引言“深入认识Python内建类型”这部分的内容会从源码角度为大家介绍Python中各种常用的内建类
- 本文实例分析了Python中的异常处理try/except/finally/raise用法。分享给大家供大家参考,具体如下:异常发生在程序执
- isset(PHP 3, PHP 4, PHP 5 )isset -- 检测变量是否设置描述bool isset ( mixed var [
- 1.常用数据结构之列表我们先给大家一个编程任务,将一颗色子掷6000次,统计每个点数出现的次数。这个任务对大家来说应该是非常简单的,我们可以
- 假设有一名为"addnewuser"的存储过程,其内容如下:Create PROCEDURE dbo
- 一、什么是RequestsRequests 是Python语编写,基于urllib,采Apache2 Licensed开源协议的 HTTP
- property 和 attribute非常容易混淆,两个单词的中文翻译也都非常相近(property:属性,attribute:特性),但
- jqGrid是一个优秀的基于jQuery的DataGrid框架,想必大伙儿也不陌生,网上基于ASP的资料很少,我提供一个,数据格式是json
- 本文实例讲述了python打开url并按指定块读取网页内容的方法。分享给大家供大家参考。具体实现方法如下:import urllibpage