PyTorch数据读取的实现示例
作者:YXHPY 发布时间:2022-01-31 04:15:48
前言
PyTorch
作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch
内置的数据读取模块吧
模块介绍
pandas 用于方便操作含有字符串的表文件,如csv
zipfile python内置的文件解压包
cv2 用于图片处理的模块,读入的图片模块为BGR,N H W C
torchvision.transforms 用于图片的操作库,比如随机裁剪、缩放、模糊等等,可用于数据的增广,但也不仅限于内置的图片操作,也可以自行进行图片数据的操作,这章也会讲解
torch.utils.data.Dataset torch内置的对象类型
torch.utils.data.DataLoader 和Dataset配合使用可以实现数据的加速读取和随机读取等等功能
import zipfile # 解压
import pandas as pd # 操作数据
import os # 操作文件或文件夹
import cv2 # 图像操作库
import matplotlib.pyplot as plt # 图像展示库
from torch.utils.data import Dataset # PyTorch内置对象
from torchvision import transforms # 图像增广转换库 PyTorch内置
import torch
初步读取数据
数据下载到此处
我们先初步编写一个脚本来实现图片的展示
# 解压文件到指定目录
def unzip_file(root_path, filename):
full_path = os.path.join(root_path, filename)
file = zipfile.ZipFile(full_path)
file.extractall(root_path)
unzip_file(root_path, zip_filename)
# 读入csv文件
face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))
# pandas读出的数据如想要操作索引 使用iloc
image_name = face_landmarks.iloc[:,0]
landmarks = face_landmarks.iloc[:,1:]
# 展示
def show_face(extract_path, image_file, face_landmark):
plt.imshow(plt.imread(os.path.join(extract_path, image_file)), cmap='gray')
point_x = face_landmark.to_numpy()[0::2]
point_y = face_landmark.to_numpy()[1::2]
plt.scatter(point_x, point_y, c='r', s=6)
show_face(extract_path, image_name.iloc[1], landmarks.iloc[1])
使用内置库来实现
实现MyDataset
使用内置库是我们的代码更加的规范,并且可读性也大大增加
继承Dataset,需要我们实现的有两个地方:
实现
__len__
返回数据的长度,实例化调用len()
时返回__getitem__
给定数据的索引返回对应索引的数据如:a[0]transform
数据的额外操作时调用
class FaceDataset(Dataset):
def __init__(self, extract_path, csv_filename, transform=None):
super(FaceDataset, self).__init__()
self.extract_path = extract_path
self.csv_filename = csv_filename
self.transform = transform
self.face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))
def __len__(self):
return len(self.face_landmarks)
def __getitem__(self, idx):
image_name = self.face_landmarks.iloc[idx,0]
landmarks = self.face_landmarks.iloc[idx,1:].astype('float32')
point_x = landmarks.to_numpy()[0::2]
point_y = landmarks.to_numpy()[1::2]
image = plt.imread(os.path.join(self.extract_path, image_name))
sample = {'image':image, 'point_x':point_x, 'point_y':point_y}
if self.transform is not None:
sample = self.transform(sample)
return sample
测试功能是否正常
face_dataset = FaceDataset(extract_path, csv_filename)
sample = face_dataset[0]
plt.imshow(sample['image'], cmap='gray')
plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
plt.title('face')
实现自己的数据处理模块
内置的在torchvision.transforms
模块下,由于我们的数据结构不能满足内置模块的要求,我们就必须自己实现
图片的缩放,由于缩放后人脸的标注位置也应该发生对应的变化,所以要自己实现对应的变化
class Rescale(object):
def __init__(self, out_size):
assert isinstance(out_size,tuple) or isinstance(out_size,int), 'out size isinstance int or tuple'
self.out_size = out_size
def __call__(self, sample):
image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
new_h, new_w = self.out_size if isinstance(self.out_size,tuple) else (self.out_size, self.out_size)
new_image = cv2.resize(image,(new_w, new_h))
h, w = image.shape[0:2]
new_y = new_h / h * point_y
new_x = new_w / w * point_x
return {'image':new_image, 'point_x':new_x, 'point_y':new_y}
将数据转换为torch
认识的数据格式因此,就必须转换为tensor
注意
: cv2
和matplotlib
读出的图片默认的shape为N H W C
,而torch
默认接受的是N C H W
因此使用tanspose
转换维度,torch
转换多维度使用permute
class ToTensor(object):
def __call__(self, sample):
image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
new_image = image.transpose((2,0,1))
return {'image':torch.from_numpy(new_image), 'point_x':torch.from_numpy(point_x), 'point_y':torch.from_numpy(point_y)}
测试
transform = transforms.Compose([Rescale((1024, 512)), ToTensor()])
face_dataset = FaceDataset(extract_path, csv_filename, transform=transform)
sample = face_dataset[0]
plt.imshow(sample['image'].permute((1,2,0)), cmap='gray')
plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
plt.title('face')
使用Torch内置的loader加速读取数据
data_loader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=0)
for i in data_loader:
print(i['image'].shape)
break
torch.Size([4, 3, 1024, 512])
注意
: windows
环境尽量不使用num_workers
会发生报错
来源:https://blog.csdn.net/weixin_42263486/article/details/108295120


猜你喜欢
- 在处理数据和进行机器学习的时候,遇到了大量需要处理的时间序列。比如说:数据库读取的str和time的转化,还有time的差值计算。总结一下p
- 背景开工前我就觉得有什么不太对劲,感觉要背锅。这可不,上班第三天就捅锅了。我们有个了不起的后台程序,可以动态加载模块,并以线程方式运行,通过
- 注释:在大多数的情况下,修改MySQL是需要有mysql里的root权限的,所以一般用户无法更改密码,除非请求管理员。方法1使用phpmya
- 内容摘要:通常的,ASP中表单提交的数据一般被写入数据库。然而,如果你想让发送数据更为简便易行,那么,可以将它书写为XML文件格式。这种方式
- progress库安装和介绍1.安装progress库progress是Python第三方库,在终端执行 pip 命令安装。pip inst
- 最近因为数学建模3天速成Python,然后做了一道网络的题,要画网络图。在网上找了一些,发现都是一些很基础的丑陋红点图,并且关于网络的一些算
- 平面设计 常用尺寸 三折页广告 标准尺寸: (A4)210mm x 285mm普通宣传册 标准尺寸: (A4)210mm x 285mm文件
- 我们可以用DataFrame的apply函数实现对多列,多行的操作。需要记住的是,参数axis设为1是对列进行操作,参数axis设为0是对行
- JavaScript定义函数的三种实现方法【1】正常方法function print(msg){ document.write(
- 一直很想就搜索结果页写一些心得文章出来,甚至连目录都整理好了可是就是一直没有动手。因为总是觉得还差很多东西需要研究需要分析需要验证。最近也组
- 最近正好在寻求一种Python的数据库ORM (Object Relational Mapper),SQLAlchemy (项目主页)这个开
- 什么是科赫曲线科赫曲线是de Rham曲线的特例。给定线段AB,科赫曲线可以由以下步骤生成: 将线段分成三等份(AC,CD,DB) 以CD为
- 1.列表(本部分内容出入官方文档)对于这个功能,微信小程序中并没有提供类似于Android中listview性质的控件,所以我们需要使用 w
- 本文实例讲述了Python3中的真除和Floor除法用法。分享给大家供大家参考,具体如下:在Python3中,除法运算有两种,一种是真除,一
- python常见的错误有1.NameError变量名错误2.IndentationError代码缩进错误3.AttributeError对象
- 上一篇介绍了 HTML5 中 Canvas 的路径,这篇将要介绍一下 Canvas&nbs
- 前言随着我们不断地在一个文件中添加新的功能, 就会使得文件变得很长。 即便使用了继承,也抑制不住类的成长。为了解决这一问题,我们可以将类存储
- 循环可以用来重复执行某条语句,直到某个条件得到满足或遍历所有元素。1 for循环是for循环,可以把集合数据类型list、tuple、dic
- 在将string类型的数据类型转换为spark rdd时,一直报这个错,StructType can not accept object %
- 本文实例讲述了Python爬虫实现爬取百度百科词条功能。分享给大家供大家参考,具体如下:爬虫是一个自动提取网页的程序,它为搜索引擎从万维网上