PyTorch如何创建自己的数据集
作者:ZQ_ZHU 发布时间:2022-10-17 05:22:17
PyTorch创建自己的数据集
图片文件在同一的文件夹下
思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下:
class ImageFolder(Dataset):
def __init__(self, folder_path):
self.files = sorted(glob.glob('%s/*.*' % folder_path))
def __getitem__(self, index):
path = self.files[index % len(self.files)]
img = np.array(Image.open(path))
h, w, c = img.shape
pad = ((40, 40), (4, 4), (0, 0))
# img = np.pad(img, pad, 'constant', constant_values=0) / 255
img = np.pad(img, pad, mode='edge') / 255.0
img = torch.from_numpy(img).float()
patches = np.reshape(img, (3, 10, 128, 11, 128))
patches = np.transpose(patches, (0, 1, 3, 2, 4))
return img, patches, path
def __len__(self):
return len(self.files)
图片文件在不同的文件夹下
比如我们有数据如下:
─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg
此时我们只需要将以上代码稍作修改即可,修改的代码如下:
self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))
其他代码不变。
pytorch常用数据集的使用
对于pytorch数据集的使用,示例代码如下:
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torchvision import transforms
import torchvision
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset_transform = Compose([transforms.ToTensor()])
# 关于官方数据集的使用还是关键要看pytorch的官方文档
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True)
# 查看测试数据集中的第一个数据
# print(test_set[0])
# 查看测试数据集中的分类情况
# print(test_set.classes)
#
# 取出第一个数据中的图片(img)和分类结果(target)
# img,target = test_set[0]
# 查看图片数据的类型
# print(img)
# print(target)
# 输出类别
# print(test_set.classes[target])
# 查看图片
# img.show()
# 使用tensorboard显示tensor数据类型的图片
writer = SummaryWriter("logs")
for i in range(10):
# 取出数据中的图片(img)和分类结果(target)
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
上述代码运行结果在tensorboard可视化:
代码
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
常用参数讲解
root
:根目录,存放数据集的位置train
:若为True,则划分为训练数据集,若为False,则划分为测试数据集transform
:指定输入数据集处理方式download
:若为True,则会将数据集下载到root指定的目录下,否则不会下载
官方文档对参数的解释:
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
注意:
关于官方数据集的使用还是关键要看pytorch的官方文档
下载数据集的细节之处:知道下载链接(下载链接可以在源码中查看)之后可以不用使用代码下载了,使用迅雷来下载可能会更快。
要学会使用Pycharm中的ctrl+p和ctrl+alt这两个快捷键
pytorch官网
pytorch官方数据集(下载数据集方法)
来源:https://blog.csdn.net/zzq060143/article/details/88382919
猜你喜欢
- 首先,运行 Python 解释器,导入 re 模块并编译一个 RE:#!python Python 2.2.2 (#1, Feb 10 20
- 作用域:顾名思义,作用的范围。如果你是自学者,而且已经进军到函数这一部分了,那么就应当了解下Python的作用域。否则你可能会像我一样,总是
- 前言Go的错误处理这块是日常被大家吐槽较多的地方,我在工作中也观察到一些现象,比较严重的是在各层级的逻辑代码中对错误的处理有些重复。比如,有
- <script language="vbscript" runat="s
- 讲起学生成绩管理系统,从大一C语言的课程设计开始,到大二的C++课程设计都是这个题,最近在学树莓派,好像树莓派常用Python编程,于是学了
- 前言保留小数位是我们经常会碰到的问题,尤其是刷题过程中。那么在python中保留小数位的方法也非常多,但是笔者的原则就是什么简单用什么,因此
- 每周的《午间欢乐购》和《周末疯狂购》,已经成为视觉组的固定需求。从开始接触到现在5个月的时间里,思维也和这些小小banner逐渐碰撞出火花。
- 方法1: 单文件模块直接把文件拷贝到 $python_dir/Lib方法2: 多文件模块,文件内有setup.py文件在官网或者GitHub
- Oracle按不同时间分组统计的sql如下表table1: 日期(exportDate) &nbs
- 1.查看mysql上都有哪些库mysql> show databases \G***************************
- 由于asp中是使用双引号作为字符串的开始和结束标志的,单一个字符串中的双引号出现次数大于两个时,程序就有可能运行错误。asp中是怎么输出引号
- 今天分享 3 个 Python 编程小技巧,来看看你是否用过?1、如何按照字典的值的大小进行排序我们知道,字典的本质是哈希表,本身是无法排序
- 是时候了—— 在大部分情况下当用户输入密码时把它们用清晰的文字显示出来。一直以来,提供反馈、把系统状态形象化是最基本的可用性原则,当用户输入
- 不通过数据源名DSN也能访问Access数据库吗?代码如下:<% dim conn &nbs
- 从PJBlog 2.7开始,验证码的功能就很好很强大了,但是同时也给手工输入带来了不小的麻烦——经常输错。之前我写了一个《自己写的一个PJB
- 本文实例讲述了Python爬虫实现“盗取”微信好友信息的方法。分享给大家供大家参考,具体如下:刚起床,闲来无聊,找点事做,看了朋友圈一篇爬取
- 介绍pandas数据聚合和重组的相关知识,仅供参考。1GroupBy技术1.1简介简介:根据一个或多个键进行分组,每一组应用函数,再进行合并
- 问题:在使用mask_rcnn预测自己的数据集时,会出现下面错误:ResourceExhaustedError: OOM when allo
- asp按关键字查询XML的问题 '-------------------------------------------------
- mysql数据库没有增量备份的机制,当数据量太大的时候备份是一个很大的问题。还好mysql数据库提供了一种主从备份的机制,其实就是把主数据库