PyTorch实现重写/改写Dataset并载入Dataloader
作者:全员鳄鱼 发布时间:2023-10-31 17:19:35
标签:PyTorch,重写,改写,Dataset
前言
众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets
自带的MNIST、CIFAR-10数据集,一般流程为:
# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)
但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?
我们可以通过改写torch.utils.data.Dataset
中的__getitem__
和__len__
来载入我们自己的数据集。__getitem__
获取数据集中的数据,__len__
获取整个数据集的长度(即个数)。
改写
采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
plt.ion() # interactive mode
torch.utils.data.Dataset
是一个抽象类,我们自己的数据集需要继承Dataset
,然后改写上述两个函数:
class ImageLoader(Dataset):
def __init__(self, file_path, transform=None):
super(ImageLoader,self).__init__()
self.file_path = file_path
self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
self.image_names = os.listdir(self.file_path) # 文件名的列表
def __getitem__(self,idx):
image = self.image_names[idx]
image = io.imread(os.path.join(self.file_path,image))
# if self.transform:
# image= self.transform(image)
return image
def __len__(self):
return len(self.image_names)
# 设置自己存放的数据集位置,并plot展示
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__() # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()
得到的图片输出:
得到的数据输出,:
array([[[ 66, 59, 53],
[ 66, 59, 53],
[ 66, 59, 53],
...,
[ 59, 54, 48],
[ 59, 54, 48],
[ 59, 54, 48]],
...,
[153, 141, 129],
[158, 146, 134],
[158, 146, 134]]], dtype=uint8)
上面看到dytpe=uint8
,实际进行训练的时候,常常需要更改成float
的数据类型。可以使用:
# 直接改成pytorch中的tensor下的float格式
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()
改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)
载入到Dataloader
中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。
train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()
来源:https://blog.csdn.net/qq_38372240/article/details/107322677


猜你喜欢
- 前言:前面我们提到了Python数据类型中的内置数值类型与字符串类型。今天学习一下Python的序列数据类型,要知道的是在Python中没有
- 本文实例讲述了JavaScript常用数学函数用法。分享给大家供大家参考,具体如下:一、代码<script language=&quo
- 前言在MySQL中,并不是你建立了索引,并且你在SQL中使用到了该列,MySQL就肯定会使用到那些索引的,有一些情况很可能在你不知不觉中,你
- JavaScript 中的并没有提供像 VBScript 里的 DateAdd 方法用于日
- 做了一个Python的小项目。利用了一点python的可视化技巧,做出烟花绽放的效果,文章的灵感来自网络上一位大神。一.编译环境Pychar
- 实现效果图如下:当我点击 + 按钮时,会添加一行输入框组;当点击 - 按钮时,会删除这一行输入框组html代码如下:<div clas
- 如何通过PHP实现Des加密算法代码实例注:php7以上不支持了,因为php7去掉了某些函数, 另外变量的{}要改为[]<?phpcl
- 前言不知道什么是版本库的,扇自己两个大嘴巴;知道但不用的,扇自己四个大嘴巴。快扇去。你真扇了?那你是个大傻瓜。扇什么扇,有扇自己的功夫,还不
- 最近接触到微服务框架go-zero,翻看了整个框架代码,发现结构清晰、代码简洁,所以决定阅读源码学习下,本次阅读的源码位于core/sync
- JavaScript/Dom中有很多很零碎的东西,让人总是感觉理解的有些“朦胧”,因此,有时候还是应该总结一下,对于Event对象,前两天看
- 在进行python数据分析的时候,首先要进行数据预处理。有时候不得不处理一些非数值类别的数据,嗯, 今天要说的就是面对这些数据该如何处理。目
- 目录一、简单文本类型数据二、复杂型表格提取三、图片型表格提取大家好,从PDF中提取信息是办公场景中经常需要用到的操作,也是经常又读者在后台问
- 在进行接口自动化测试时,有好多接口都基于登陆接口的响应值来关联进行操作的,在次之前试了很多方法,都没有成功,其实很简单用session来做。
- 本文实例讲述了php下pdo的mysql事务处理用法。分享给大家供大家参考。具体分析如下:php+mysql事务处理的几个步骤:1.关闭自动
- 说到客户端数据存储,可能第一时间想到的是cookies,这是一种网站常见的存储数据的方法。它的最大优点是兼容性好,几乎所有浏览器都具有这个功
- 本文实例讲述了Go语言模拟while语句实现无限循环的方法。分享给大家供大家参考。具体实现方法如下:这段代码把for语句当成C语言里的whi
- 虚拟环境管理创建虚拟环境#默认路径下创建虚拟环境conda create -n pythonVirtual python=x.x # -n:
- 本文实例讲述了Python实现的多叉树寻找最短路径算法。分享给大家供大家参考,具体如下:多叉树的最短路径:思想: &n
- 学习python爬虫时遇到了一个问题,书上有示例如下:import reline='Cats are smarter than do
- Python 通过pip安装Django详细介绍经过前面的 Python 包管理工具的学习,接下来我们就要基于前面的知识,来配置 Djang