PyTorch如何创建自己的数据集
作者:ZQ_ZHU 发布时间:2022-10-17 05:22:17
PyTorch创建自己的数据集
图片文件在同一的文件夹下
思路是继承 torch.utils.data.Dataset,并重点重写其 __getitem__方法,示例代码如下:
class ImageFolder(Dataset):
def __init__(self, folder_path):
self.files = sorted(glob.glob('%s/*.*' % folder_path))
def __getitem__(self, index):
path = self.files[index % len(self.files)]
img = np.array(Image.open(path))
h, w, c = img.shape
pad = ((40, 40), (4, 4), (0, 0))
# img = np.pad(img, pad, 'constant', constant_values=0) / 255
img = np.pad(img, pad, mode='edge') / 255.0
img = torch.from_numpy(img).float()
patches = np.reshape(img, (3, 10, 128, 11, 128))
patches = np.transpose(patches, (0, 1, 3, 2, 4))
return img, patches, path
def __len__(self):
return len(self.files)
图片文件在不同的文件夹下
比如我们有数据如下:
─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg
此时我们只需要将以上代码稍作修改即可,修改的代码如下:
self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))
其他代码不变。
pytorch常用数据集的使用
对于pytorch数据集的使用,示例代码如下:
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torchvision import transforms
import torchvision
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset_transform = Compose([transforms.ToTensor()])
# 关于官方数据集的使用还是关键要看pytorch的官方文档
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True)
# 查看测试数据集中的第一个数据
# print(test_set[0])
# 查看测试数据集中的分类情况
# print(test_set.classes)
#
# 取出第一个数据中的图片(img)和分类结果(target)
# img,target = test_set[0]
# 查看图片数据的类型
# print(img)
# print(target)
# 输出类别
# print(test_set.classes[target])
# 查看图片
# img.show()
# 使用tensorboard显示tensor数据类型的图片
writer = SummaryWriter("logs")
for i in range(10):
# 取出数据中的图片(img)和分类结果(target)
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
上述代码运行结果在tensorboard可视化:
代码
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
常用参数讲解
root
:根目录,存放数据集的位置train
:若为True,则划分为训练数据集,若为False,则划分为测试数据集transform
:指定输入数据集处理方式download
:若为True,则会将数据集下载到root指定的目录下,否则不会下载
官方文档对参数的解释:
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
注意:
关于官方数据集的使用还是关键要看pytorch的官方文档
下载数据集的细节之处:知道下载链接(下载链接可以在源码中查看)之后可以不用使用代码下载了,使用迅雷来下载可能会更快。
要学会使用Pycharm中的ctrl+p和ctrl+alt这两个快捷键
pytorch官网
pytorch官方数据集(下载数据集方法)
来源:https://blog.csdn.net/zzq060143/article/details/88382919


猜你喜欢
- Python DataFrame 如何设置列表字段/元素类型?比如笔者想将列表的两个字段由float64设置为int64,那么就要用到Dat
- 1.理论只要两个表的公共字段有匹配值,就将这两个表中的记录组合起来。个人理解:以一个共同的字段求两个表中符合要求的交集,并将每个表符合要求的
- 拆包是指将一个结构中的数据拆分为多个单独变量中。以元组为例:>>> a = ('windows', 10,
- 以下插件是我在项目中经常使用的jQuery插件,不见得是最好的,但是我目前接触到的jQuery插件中最适合我的。01. jQuery.Fle
- 项目中用到了限流,受限于一些实现方式上的东西,手撕了一个简单的服务端限流器。服务端限流和客户端限流的区别,简单来说就是:1)服务端限流对接口
- 本文实例讲述了Python3实现的判断回文链表算法。分享给大家供大家参考,具体如下:问题:请判断一个链表是否为回文链表。方案一:指针法cla
- 1.什么是并发编程并发编程是实现多任务协同处理,改善系统性能的方式。Python中实现并发编程主要依靠进程(Process):进程是计算机中
- 前言如今的玩家们在无聊的时候会玩些什么游戏呢?王者还是吃鸡是最多的选择。但在80、90年代的时候多是一些很简单的游戏:《超级玛丽》、《蜘蛛纸
- 一、如何实现可迭代对象和迭代器对象?实际案例某软件要求从网络抓取各个城市气味信息,并其次显示:北京: 15 ~ 20 天津: 17 ~ 22
- 问题你需要将数字格式化后输出,并控制数字的位数、对齐、千位分隔符和其他的细节。解决方案格式化输出单个数字的时候,可以使用内置的 format
- 引言 咱们公司从事的是信息安全涉密应用的一些项目研发一共有分为三步,相比较于一般
- 之前写了一个matlab的,越用越觉得麻烦,如果不同数据集要改类别数目,而且运行速度慢。所以重新写了一个Python的,直接读取xml文件夹
- springboot配置文件抽离,便于服务器读取对应配置文件,避免项目频繁更改配置文件,影响项目的调试与发布1.创建统一配置中心项目coni
- 昨天微信小程序(应用号)内测的消息把整个技术社区炸开了锅,我也忍不住跟了几波,可惜没有内测资格,听闻破解版出来了,今天早上就着原来的项目资源
- 一、zmial发送邮件zmial是第三方库,需进行安装pip install zmail完成后,来给发一封邮件subject:标题conte
- 昨天晚上才发现已经出了jQuery的1.3版本,于是下载下来,把原来一个兄弟翻译的1.2.6的文档移植到了1.3中,点击这里可
- numpy.flip(m, axis=None)Reverse the order of elements in an array alon
- 官方教程http://www.kuitao8.com/demo/20140224/1/bootstrap-multiselect-maste
- 使用场景对手机号码进行地域分析,需要查询归属地;问题描述针对数据集比较大的情况,通过脚本来处理,使用多线程的方法来加快查询速度pool =
- Python 列表理解及使用方法列表是最常用的Python最常用的数据类型,它和其它序列一样,可以进行包括索引,切片,加,乘,检查成员的操作