pytorch制作自己的LMDB数据操作示例
作者:团长sama 发布时间:2023-05-24 11:51:27
标签:pytorch,LMDB,数据操作
本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
前言
记录下pytorch里如何使用lmdb的code,自用
制作部分的Code
code就是ASTER里数据制作部分的代码改了点。aster_train.txt里面就是图片的完整路径,每行一个;图片同目录下有同名的txt,里面记着jpg的标签
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to a non-empty image.

    Args:
        imageBin: Raw encoded image bytes (e.g. the content of an image
            file), or None.

    Returns:
        False when the buffer is None or empty, when OpenCV cannot decode
        it, or when the decoded image has zero area; True otherwise.
    """
    # Nothing to decode: reject missing and empty buffers up front.
    if not imageBin:
        return False
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # Guard against a failed decode before touching ``img.shape``.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0
def writeCache(env, cache):
    """Store every entry of ``cache`` in ``env`` inside one write transaction.

    Keys are text and are encoded with ``str.encode()`` before being stored;
    values are handed to ``txn.put`` unchanged.
    """
    with env.begin(write=True) as txn:
        for key in cache:
            txn.put(key.encode(), cache[key])
def _is_difficult(word):
assert isinstance(word, str)
return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """Create an LMDB dataset for CRNN training.

    Accepted samples are numbered from 1 in the order they are written and
    stored under the keys ``image-%09d`` and ``label-%09d`` (plus
    ``lexicon-%09d`` when ``lexiconList`` is given). The number of stored
    samples is written under the key ``num-samples``. Samples with an empty
    label, a missing image file, or (when ``checkValid`` is true) an image
    that fails validation are skipped with a message.

    Args:
        outputPath: LMDB output path.
        imagePathList: list of image paths.
        labelList: list of corresponding ground-truth texts.
        lexiconList: (optional) list of lexicon lists, indexed like
            ``imagePathList``.
        checkValid: if true, check the validity of every image.
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    # map_size is the maximum database size in bytes: 2**40 bytes = 1 TiB.
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if len(label) == 0:
                continue
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue

            # Everything stored in the database is binary data.
            imageKey = 'image-%09d' % cnt  # zero-padded to 9 digits
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                # Encoded like the label so that every cached value is bytes.
                cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
            # Flush to disk every 1000 samples to bound the in-memory cache.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        # cnt is one past the last sample number actually written.
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even if a sample fails part-way through.
        env.close()
    print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
    """Load image paths and their labels from an index file.

    ``txt_path`` lists one image path per line. For every image, the label
    is the first line of the ``.txt`` file that sits next to it and shares
    its base name. Images without such a label file are skipped.

    Args:
        txt_path: Path of the index file.

    Returns:
        A ``(jpg_list, txt_content_list)`` tuple of equal-length lists: the
        kept image paths and the stripped label text for each of them.

    Raises:
        UnicodeDecodeError: if a label file cannot be decoded; the offending
            path is printed before the error is re-raised.
    """
    jpg_list = []
    txt_content_list = []
    with open(txt_path, 'r') as index_file:
        for line in index_file:
            jpg = line.strip()
            if not jpg:
                continue
            # Swap only the file extension. A plain str.replace would also
            # rewrite '.jpg' inside directory names and would leave other
            # extensions untouched, so the image itself was read as the label.
            label_path = os.path.splitext(jpg)[0] + '.txt'
            if not os.path.exists(label_path):
                continue
            with open(label_path, 'r') as label_file:
                try:
                    first_line = label_file.readline()
                except UnicodeDecodeError:
                    print(label_path)
                    raise
            jpg_list.append(jpg)
            txt_content_list.append(first_line.strip())
    return jpg_list, txt_content_list
if __name__ == "__main__":
    # Index file (one image path per line) and the LMDB output directory.
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    # get_sample_list returns (image paths, labels), which are passed on as
    # the second and third positional arguments of createDataset.
    createDataset(lmdb_output_path, *get_sample_list(txt_path))
读取部分
这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__
from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
from config import get_args

# Ask PIL to load image files whose data is truncated.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Command-line options are parsed once, when this module is imported.
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
    # moxing is a distributed framework; it is only imported for remote
    # runs and can be skipped when reading this code.
    import moxing as mox
class LmdbDataset(data.Dataset):
    """Text-recognition dataset read from an LMDB database.

    The database must contain the key ``num-samples`` and, for every
    1-based sample number, the keys ``image-%09d`` (encoded image bytes)
    and ``label-%09d`` (label text).

    Args:
        root: Path of the LMDB directory. When ``global_args.run_on_remote``
            is set, it is first copied to ``/cache/<basename of root>`` with
            moxing and opened from there.
        voc_type: One of 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS',
            'DIGITS'; passed to ``get_vocabulary``.
        max_len: Length of the padded label vector, EOS token included.
        num_samples: Upper bound on the number of samples exposed.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, voc_type, max_len, num_samples, transform=None):
        super(LmdbDataset, self).__init__()

        if global_args.run_on_remote:
            dataset_name = os.path.basename(root)
            data_cache_url = "/cache/%s" % dataset_name
            if not os.path.exists(data_cache_url):
                os.makedirs(data_cache_url)
            if mox.file.exists(root):
                mox.file.copy_parallel(root, data_cache_url)
            else:
                raise ValueError("%s not exists!" % root)
            self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
        else:
            self.env = lmdb.open(root, max_readers=32, readonly=True)

        assert self.env is not None, "cannot create lmdb from %s" % root
        self.txn = self.env.begin()

        self.voc_type = voc_type
        self.transform = transform
        self.max_len = max_len
        # Never expose more samples than the database holds.
        self.nSamples = int(self.txn.get(b"num-samples"))
        self.nSamples = min(self.nSamples, num_samples)

        assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS', 'DIGITS']
        self.EOS = 'EOS'
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'
        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        self.char2id = dict(zip(self.voc, range(len(self.voc))))
        self.id2char = dict(zip(range(len(self.voc)), self.voc))

        self.rec_num_classes = len(self.voc)
        self.lowercase = (voc_type == 'LOWERCASE')

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.nSamples

    def _encode_label(self, word):
        """Convert ``word`` into a padded id vector.

        Characters missing from ``char2id`` are reported and mapped to the
        UNKNOWN id; an EOS id is appended after the last character.

        Returns:
            ``(label, label_len)``: an integer array of length ``max_len``
            filled with the PADDING id after the encoded ids, and the number
            of encoded ids (EOS included).
        """
        label_list = []
        for char in word:
            if char in self.char2id:
                label_list.append(self.char2id[char])
            else:
                # Add the unknown token.
                print('{0} is out of vocabulary.'.format(char))
                label_list.append(self.char2id[self.UNKNOWN])
        # Add a stop token.
        label_list = label_list + [self.char2id[self.EOS]]
        assert len(label_list) <= self.max_len

        # Fill with the padding token. The builtin ``int`` is what the
        # removed ``np.int`` alias pointed at.
        label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=int)
        label[:len(label_list)] = np.array(label_list)
        return label, len(label_list)

    def __getitem__(self, index):
        """Return ``(img, label, label_len)`` for the 0-based ``index``.

        ``img`` is the RGB image (after ``transform`` when one is set),
        ``label`` the padded id vector and ``label_len`` the number of
        non-padding ids. If the stored image cannot be opened, the next
        sample (wrapping to the first one) is returned instead.
        """
        # Valid indices are 0 .. len(self) - 1.
        assert 0 <= index < len(self), 'index range error'
        # Database keys are numbered from 1.
        key_index = index + 1
        img_key = b'image-%09d' % key_index
        imgbuf = self.txn.get(img_key)

        # Image.open needs a file-like object, so wrap the raw bytes in one.
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('RGB')
        except IOError:
            print('Corrupted image for %d' % key_index)
            # ``key_index`` equals the 0-based index of the following sample;
            # the modulo wraps the last sample around to the first.
            return self[key_index % len(self)]

        # Recognition label.
        label_key = b'label-%09d' % key_index
        word = self.txn.get(label_key).decode()
        if self.lowercase:
            word = word.lower()
        label, label_len = self._encode_label(word)

        if self.transform is not None:
            img = self.transform(img)
        return img, label, label_len
希望本文所述对大家Python程序设计有所帮助。
来源:https://blog.csdn.net/sinat_24899403/article/details/102795355
0
投稿
猜你喜欢
- //********************************************************************
- 将汉字转为拼音,可以用于批量汉字注音、文字排序、拼音检索文字等常见场景。现在互联网上有许多拼音转换工具,基于Python的开源模块也不少,今
- 效果图如下:图1(头像图片剪成圆形的,其他为透明)图2(给图片的4个角加椭圆)以前没处理过,处理起来真是有点费力呀。用到的模块:import
- innewDropList = [9,10,11,12,22,50,51,60,61]newDB = newDB[newDB['gr
- 今天来给大家推荐一个Python当中超级好用的内置函数,那便是lambda方法,本篇教程大致和大家分享什么是lambda函数lambda函数
- 近期Github开源了一款基于Python开发、名为Textshot的截图工具,刚开源不到半个月已经500+Star。这两天抽空看了一下Te
- 序 号前 缀使用的变量/范围或数据类型1a or arrArray2b or blnBoolean3bytByte4
- golang 字符串 int uint int64 uint64 互转字符串 转 intintNum, _ = strconv.Atoi(i
- 这可是个综合性的问题,看看下面对文件操作的集大成代码:<% 'Set file i/
- 这篇博客将介绍如何通过OpenCV中图像修复的技术——cv2.inpaint() 去除旧照片中的小噪音、笔划等。并提供一个可交互式的程序,利
- 本文为大家分享了python实现学生管理系统的具体代码,供大家参考,具体内容如下1.0版本学生管理系统''' 1.添
- 代码class Shuxing(): def __init__(self, size = 10): s
- 前言 1. 概述共享坐标轴就是几幅子图之间共享x轴或y轴,这一部分主要了解如何在利用matplotlib制图时共享坐标轴。pyplot.s
- HTTP请求方法GET:请求指定的页面信息,并返回实体主体。HEAD:类似于get请求,只不过返回的响应中没有具体的内容,用于获取报头POS
- 1. 用SimpleITK读取dicom序列:import SimpleITK as sitkimport numpy as npimg_p
- 对于web开来说,用户登陆、注册、文件上传等是最基础的功能,针对不同的web框架,相关的文章非常多,但搜索之后发现大多都不具有完整性,对于想
- 本文实例为大家分享了python OpenCV来表示USB摄像头画面的具体代码,供大家参考,具体内容如下确认Python版本$ python
- 导言:到目前为止,我们的教程围绕的是text数据。然而,很多应用程序既需要处理text数据,也需要处理二进制数据。比如招聘网站可能需要用户上
- 当成功安装了PHP,MYSQL后,我们一般要安装phpMyAdmin来管理你的mysql。本文介绍了phpMyAdmin 2.10.2的配置
- 简介观察者模式是行为型模式的一种,定义了对象间一对多的关系。当对象的状态发生变化时候,依赖于它的对象会得到通知。适用场景类似触发钩子事件,可