PyTorch加载自己的数据集实例详解
作者:飞谷云人工智能 发布时间:2022-07-29 14:10:36
数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。
数据集存放大致有以下两种方式:
(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg
root/cat_dog/cat.02.jpg
........................
root/cat_dog/dog.01.jpg
root/cat_dog/dog.02.jpg
......................
(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
................
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
..................
1.1 对第1种数据集的处理步骤
(1)生成包含各文件名的列表(List)
(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码
(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。
(4)使用torch.utils.data.DataLoader加载数据集Dataset.
1.2 实例详解
以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。
1.2.1 数据集结构
所有数据集在cat-dog目录下:
.\cat_dog\cat.01.jpg
.\cat_dog\cat.02.jpg
.\cat_dog\cat.03.jpg
....................
.\cat_dog\dog.01.jpg
.\cat_dog\dog.02.jpg
....................
1.2.2 导入需要用到的模块
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import oimport torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")
1.2.3定义加载自定义数据的类
class MyDataset(Dataset): #继承Dataset
def __init__(self, path_dir, transform=None): #初始化一些属性
self.path_dir = path_dir #文件路径,如'.\data\cat-dog'
self.transform = transform #对图形进行处理,如标准化、截取、转换等
self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中
def __len__(self):#返回整个数据集的大小
return len(self.images)
def __getitem__(self,index):#根据索引index返回图像及标签
image_index = self.images[index]#根据索引获取图像文件名称
img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录
img = Image.open(img_path).convert('RGB')# 读取图像
# 根据目录名称获取图像标签(cat或dog)
label = img_path.split('\\')[-1].split('.')[0]
#把字符转换为数字cat-0,dog-1
label = 1 if 'dog' in label else 0
if self.transform is not None:
img = self.transform(img)
return img,label
1.2.4 实例化类
dataset = MyDataset('.\data\cat-dog',transform=None)
img, label = dataset[0] #将启动魔法方法__getitem__(0)
print(type(img))
<class 'PIL.Image.Image'>
1.2.5 查看图像形状
i=1
for img, label in dataset:
if i
img的形状(500, 374),label的值0
img的形状(300, 280),label的值0
img的形状(489, 499),label的值0
img的形状(431, 410),label的值0
img的形状(300, 224),label的值0
从上面返回样本的形状来看:
(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。
(2)返回样本的数值较大,未归一化至[-1, 1]
为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可
1.2.6 对图像数据进行处理
这里使用torchvision中的transforms模块
from torchvision import transforms as T
transform = T.Compose([
T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
T.CenterCrop(224), # 从图片中间切出224*224的图片
T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])
1.2.7查看处理后的数据
dataset = MyDataset('.\data\cat-dog',transform=transform)
for img, label in dataset:
print("图像img的形状{},标签label的值{}".format(img.shape, label))
print("图像数据预处理后:\n",img)
break
图像img的形状torch.Size([3, 224, 224]),标签label的值0
图像数据预处理后:
tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],
[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],
[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],
...,
[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],
[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],
[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],
[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],
[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],
[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],
...,
[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],
[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],
[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],
[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],
[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],
[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],
...,
[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],
[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],
[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])
由此可知,数据已标准化、规范化。
1.2.8对数据集进行批量加载
使用DataLoader模块,对数据集dataset进行批量加载
#使用DataLoader加载数据
dataloader = DataLoader(dataset,batch_size=4,shuffle=True)
for batch_datas, batch_labels in dataloader:
print(batch_datas.size(),batch_labels.size())
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([2, 3, 224, 224]) torch.Size([2])
1.2.9随机查看一个批次的图像
import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# 显示图像
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 随机获取部分训练数据
dataiter = iter(dataloader)
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))
2 对第2种数据集的处理
处理这种情况比较简单,可分为2步:
(1)使用datasets.ImageFolder读取、处理图像。
(2)使用.data.DataLoader批量加载数据集,示例如下:
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
来源:http://www.feiguyunai.com/index.php/2020/03/17/python-dl-pytorch-custom_data/


猜你喜欢
- 今天我给大家分享的是在不刷新页面的前提下,使用PHP+jQuery+Ajax实现多图片上传的效果。用户只需要点击选择要上传的图片,然后图片自
- 使用select @@identity 得到刚插入数据的ID1.适用于所有 ADO 版本<%Dim loConn, 
- 需求:将查询的两列数据导出到excel中 1.选择数据库,右键任务→导出数据,打开导入导出向导,单击下一步2.在打开的SQ
- 在自己的网站上更新文章时一个比较常见的问题是:文章插图太宽,使整个网页都变形了。如果对每个插图都先进行缩放再插入的话,太麻烦了。我前段时间写
- 前言大家都知道,英文的分词由于单词间是以空格进行分隔的,所以分词要相对的容易些,而中文就不同了,中文中一个句子的分隔就是以字为单位的了,而所
- 一、什么是Golang?Golang(又称Go)是一种由谷歌公司开发的编程语言。它是一种静态类型、编译型、并发型语言,被设计用于构建高效、可
- 一、PK(主键约束)1、什么是主键?在了解主键之前,先了解一下什么是关键字关键字:在表中具有唯一性的字段,比如一个人的身份证号,学号。一个表
- “Be conservative in what you send; be liberal in what you accept. &nbs
- CSSer与其他IT职位一样,在找工作的时候,都会面临着面试官提出的问题,或者给出的试卷。一、超链接点击过后hover样式就不出现的问题?被
- 对于跟我一样,自学javascript且没有其他语言学习经验的人来说,一开始的时候,javascript的调试也是一个比较大的难点,很多基础
- 引言反射是通过实体对象获取反射对象(Value、Type),然后可以操作相应的方法。在某些情况下,我们可能并不知道变量的具体类型,这时候就可
- 在pytorch的CNN代码中经常会看到x.view(x.size(0), -1)首先,在pytorch中的view()函数就是用来改变te
- 目录1、字典的定义字典和列表的区别:字典的基本使用2、循环遍历3、字符串的定义4、字符串的常用操作字符串 查找和替换字符串 文本对齐演练去除
- 只有pd模型文件, 打印所有节点from tensorflow.python.framework import tensor_utilfro
- 不知道有没有人碰到过这样恶心的问题:两张表连接查询并limit,SQL效率很高,但是加上order by以后,语句的执行时间变的巨长,效率巨
- 在使用ORACLE的过程过,我们会经常遇到一些ORACLE产生的错误,对于初学者而言,这些错误可能有点模糊,而且可能一时不知怎么去处理产生的
- 有些小伙伴跟小编讨论了python中使用多线程原理的问题,就聊到了关于python多线程的弊端问题,这点可能在使用的过程中大家会能感觉到。而
- js实现点击掉落特效 先看看效果图 话不多说代码<!DOCTYPE HTML><html><head
- 1、查看数据库的字符集数据库的字符集必须和Linux下设置的环境变量一致,不然会有乱码。以下两个sql语句都可以查到:select * fr
- 今天重新研究了下VB里面的ScriptControl组件,发现asp里面也能调用。研究了下方法,后来和lcx讨论了下。得到了如下代码,在此感