详解PyTorch预定义数据集类datasets.ImageFolder使用方法
作者:实力 发布时间:2022-01-30 20:15:07
datasets.ImageFolder是PyTorch提供的一个预定义数据集类,用于处理图像数据。它可以方便地将一组图像加载到内存中,并为每个图像分配标签。
数据集准备和目录结构
要使用datasets.ImageFolder,我们需要准备好一个包含图像数据的目录,并按照以下方式进行组织:
root/
class1/
img1.jpg
img2.jpg
...
class2/
img1.jpg
img2.jpg
...
...
其中,root代表数据集根目录,class1、class2等代表不同的分类标签,img1、img2等代表图像文件名。每个类别(也称为标签)应该有一个单独的子目录,子目录中包含这个类别的所有图像文件。同时,每个图像文件在对应的子目录下,以其文件名作为其类别标签。这种目录组织方式可以让我们轻松获取图像和对应的标签信息。
加载数据集
完成数据集准备之后,我们就可以使用datasets.ImageFolder来加载它了。下面是一个示例代码:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = "/path/to/data"
transforms = transforms.Compose([
transforms.Resize(size=(224, 224)),
transforms.ToTensor(),
])
dataset = datasets.ImageFolder(root=data_dir, transform=transforms)
在这个例子中,我们首先导入datasets和transforms模块,然后指定数据集的根目录data_dir。接下来,我们定义一个 transforms 对象,它将图像转换为PyTorch张量,并调整大小为(224, 224)。
最后,我们使用datasets.ImageFolder来加载图像数据集。ImageFolder类需要两个参数:root 和 transform。root是数据集根目录;transform指定对每个图像应该执行的预处理操作,例如调整大小、裁剪、翻转等。
数据集划分
对于机器学习任务,我们通常需要将数据集划分成训练集、验证集和测试集。在PyTorch中,我们可以使用torch.utils.data.random_split函数来完成数据集的划分。下面是一个示例代码:
from torch.utils.data import DataLoader, random_split
# Split the dataset into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# Split train dataset into train and validation sets
val_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - val_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
在这个例子中,我们先使用random_split函数将原始数据集划分为训练集和测试集,在这里80%的数据用于训练,20%的数据用于测试。然后,我们再次使用random_split函数将训练集划分为训练集和验证集,其中80%的数据用于训练,20%的数据用于验证。
数据加载器
最后,我们可以使用数据加载器(DataLoader)来加载数据集。数据加载器负责将图像数据和标签封装成批量,并提供多线程方式加载数据以加速训练过程。下面是一个示例代码:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
在这里,我们创建了三个数据加载器train_loader、val_loader 和 test_loader,它们分别对应训练集、验证集和测试集。batch_size参数指定了每个批次的大小,shuffle参数表示是否随机化输入数据(在训练集中设置为True,在验证集和测试集中设置为False)。
来源:https://juejin.cn/post/7223988948069302329


猜你喜欢
- 在pandas中,经常对数据进行处理 而导致数据索引顺序混乱,从而影响数据读取、插入等。小笔总结了以下几种重置索引的方法:import pa
- 以前的服务器,由于内存的价格过高,一般配置的内存不是很多,超过4GB的当然就不多了.现在的服务器,配置超过4GB就很多,在配作SQL 数据库
- login <?php require "../include/DBClass.php"; $usern
- 什么是deferdefer用来声明一个延迟函数,把这个函数放入到一个栈上, 当外部的包含方法return之前,返回参数到调用方法之前调用,也
- sys模块sys模块是与python解释器交互的一个接口sys.argv 命令行参数List,第一个元素是程序本身路径sys.
- 总结常用基本点如下: 1、触发器有两种类型:数据定义语言触发器(DDL触发器)和数据操纵语言触发器(DML触发器)。 DDL触发器:在用户对
- 本文主要内容:聚类算法的特点聚类算法样本间的属性(包括,有序属性、无序属性)度量标准聚类的常见算法,原型聚类(主要论述K均值聚类),层次聚类
- 一、简单介绍正则表达式是一种小型的、高度专业化的编程语言,并不是python * 有的,是许多编程语言中基础而又重要的一部分。在python中
- 目录一、两个模块二、SMTP端口三、四大步骤1、构造邮件内容2、连接邮件服务器3、登陆邮件服务器4、发送邮件四、常用场景1、纯文本邮件2、发
- 先来看个例子:#-*- coding:utf8 -*-s = u'中文截取's.decode('utf8')
- 本文实例讲述了python实现下载指定网址所有图片的方法。分享给大家供大家参考。具体实现方法如下:#coding=utf-8#downloa
- 现在,我们已经把一个Web App的框架完全搭建好了,从后端的API到前端的MVVM,流程已经跑通了。在继续工作前,注意到每次修改Pytho
- 在 Python 中,一般情况下我们可能直接用自带的 logging 模块来记录日志,包括我之前的时候也是一样。在使用时我们需要配置一些 H
- 本文为大家分享了pygame游戏之旅的第7篇,供大家参考,具体内容如下对car和障碍的宽高进行比较然后打印即可:if y < thin
- css加载器在webpack中,所有的资源(js文件、css文件、模板文件,图片文件等等)都被看成是一个模块,因此多有的资源都是可以被加载的
- 1.where语法和用法(1)语法:where <criteria> 即where <查询条件>具体查询语句:sel
- 文件上传是所有UI自动化测试都要面对的一个头疼问题,今天博主在这里给大家分享下自己处理文件上传的经验,希望能够帮助到广大被文件上传坑住的se
- Python文件遍历os.walk()与os.listdir()在图片处理过程中,样本数据的组织是个常见的问题,样本组织好了,后面数据转换、
- 在ASP中,如何创建DSN? 见下:<HTML><HEAD><META&n
- 调试程序的过程中,发现通过os.path.join拼接的路径出现了反斜杠directory1='/opt/apps/upgradeP