pytorch dataset实战案例之读取数据集的代码
作者:半岛铁子_ 发布时间:2023-10-06 23:51:01
标签:pytorch,dataset,读取数据集
概述
最近在跑一篇图像修复论文的代码,配置好环境之后开始运行,发现数据一直加载不进去。
害,还是得看人家代码咋写的,一句一句看逻辑,准能找出问题。通读dataset后,发现了问题所在,终于成功加载了数据集。
项目结构与代码
项目结构
主要的目的就是从数据集中读取到彩 * 像和掩码图像。
代码
代码中涉及到torch.transforms、合并路径等知识点,我在代码中都进行了详细的注释,路径要对照着项目结构,如果自己用的话要根据项目结构去将相对路径改过来。dataset.py
:当前的工作路径:…\OT-GAN-for-Inpainting-master\src\data
import os
import math
import numpy as np
from glob import glob
from random import shuffle
from PIL import Image, ImageFilter
import torch
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class InpaintingData(Dataset):
def __init__(self, args):
super(Dataset, self).__init__() # 继承Dataset的父类的初始化函数
self.w = self.h = args.image_size # 通过args传入新的属性---图像的w和h
self.mask_type = args.mask_type # 通过args传入新的属性---mask_type
# image and mask
self.image_path = [] #创建image_path的数组
for ext in ['*.jpg', '*.png']: # 获取每一个后缀为.jpg或者.png的图片,为ext
# 将dir_image、data_train和ext拼接作为图片的路径,并将其存入到数组image_path之中,glob()获取一个lsit集合
self.image_path.extend(glob(os.path.join(args.dir_image, args.data_train, ext)))
self.mask_path = glob(os.path.join(args.dir_mask, args.mask_type, '*.png')) #拼接dir_mask、mask_type和路径下所有的.png作为mask_path
# augmentation
self.img_trans = transforms.Compose([ #接收一个 transforms方法的list为参数,将这些操作组合到一起,返回一个新的tranforms
transforms.RandomResizedCrop(args.image_size), #随机随机长宽比裁剪,大小为image_size
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), #改变图像的亮度、对比度、饱和度和色调。
transforms.ToTensor()]) # 转为tensor,并归一化至[0-1]
self.mask_trans = transforms.Compose([
transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST), #将输入图像调整为给定的大小,interpolation是插值方式,此处是默认值NEAREST
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.RandomRotation( #随机旋转
(0, 45), interpolation=transforms.InterpolationMode.NEAREST), #(0, 45)是角度
])
def __len__(self): # __len__和__getitem__DataSet类必须实现的静态方法
return len(self.image_path)
def __getitem__(self, index):
# load image
image = Image.open(self.image_path[index]).convert('RGB') #获取图像,并将其转化为RGB(3x8位像素)模式
filename = os.path.basename(self.image_path[index]) #获取图片的路径
if self.mask_type == 'pconv': #如果mask_type为pconv
index = np.random.randint(0, len(self.mask_path)) #随机从mask_path中获取一个下标
mask = Image.open(self.mask_path[index]) #根据下标获取mask图片
mask = mask.convert('L') #将mask图片转化为L(8位像素的黑白图片,0表示黑,255表示白)模式
else: # 构造mask,有mask数据集的话就运行不到这里
mask = np.zeros((self.h, self.w)).astype(np.uint8) #构造与h和w一样大的图片,都用0填充,并将其转换为uint8
mask[self.h // 4:self.h // 4 * 3, self.w // 4:self.w // 4 * 3] = 1
mask = Image.fromarray(m).convert('L')
# augment
image = self.img_trans(image) * 2. - 1. # 数据标准化,将输出限定在一定的范围
mask = F.to_tensor(self.mask_trans(mask)) # 将转化后的mask图像转化为tensor
return image, mask, filename #返回
if __name__ == '__main__':
from attrdict import AttrDict
args = {
'dir_image': '../../examples/logos',
'data_train': 'image',
'dir_mask': '../../examples/logos/mask',
'mask_type': 'pconv',
'image_size': 512
}
args = AttrDict(args) # 将上面定义的参数传入AttrDict()作为新参数
data = InpaintingData(args) #创建InpaintingData对象
print(len(data), len(data.mask_path)) #输出data的长度,mask的长度
img, mask, filename = data[0] # 获取第一张图片
print(img.size(), mask.size(), filename) #打印上述信息
输出:
再Debug一下看:
如下图所示,执行玩加载数据的代码之后,已经成功获取到数据
来源:https://blog.csdn.net/hshudoudou/article/details/127431107
0
投稿
猜你喜欢
- 官方文档介绍链接:append方法介绍DataFrame.append(other, ignore_index=False, verify_
- 网站开发时经常需要在某个页面需要实现对大量图片的浏览,如果考虑流量的话,大可以像pconline一样每个页面只显示一张图片,让用户每看一张图
- numpy的delete是可以删除数组的整行和整列的,下面简单介绍和举例说明delete函数用法:numpy.delete(arr, obj
- Perl的特殊符号@ 数组 &nb
- python class(object)的含义在python2中有区别,在Python3中已经没有区别:object为默认类,表示继承关系c
- 记得以前写过一篇文章 php有效的过滤html标签,js代码,css样式标签: <?php $str = preg_replace(
- 从PHP的5.4.0版本开始,PHP提供了一种全新的代码复用的概念,那就是Trait。Trait其字面意思是”特性”、”特点”,我们可以理解
- python是很容易上手的编程语言,但是有些时候使用python编写的程序并不能保证其运行速度(例如:while 和 for),这个时候我们
- 网页设计中,内容组织恐怕是最至关重要、最影响设计品质的方面了。如何将信息组织到好的布局中,是一个网站的基础,并且应该在考虑外观之前就决定好。
- # -*- coding:utf-8 -*-__author__ = 'walkskyer'import osimport
- 每次查询分析器寻找路径时,并不会每一次都去统计索引中包含的行数,值的范围等,而是根据一定条件创建和更新这些信息后保存到数据库中,这也就是所谓
- 你们要的3D太阳系图片上传之后不知为何帧率降低了许多。。。日地月三体所谓三体,就是三个物体在重力作用下的运动。由于三点共面,所以三个质点仅在
- 互联网是一个飞速发展的行业,任何的止步不前都会导致被淘汰,只是时间早晚的问题,所以一个公司的学习与创新能力是非常重要的,特别是对于一个年轻的
- 多值运动,也就是对于某个对象来说,不仅仅只是其中一个属性值在变化,而是好多个,比如宽,高,字体,透明度等等同时变化当然了,多值运动会产生一个
- <!doctype><html><head><title>新闻图片轮换类</title
- 闭包account=0def atm(num,flag): global account  
- 问题提出最近,使用tqdm模块,对于大文件的阅读进行进度监控。然而我发现有个问题,即在tqdm模块使用一定没错的情况下,进度条死活打印不出来
- 为了让鼠标移到小图上显示大图,我利用鼠标事件新建了一个层来显示大图.当然之前最好得到XY坐标取得当前鼠标的X,Y坐标:function&nb
- 最近的答题赢钱很火爆,我也参与了几次,有些题目确实很难答,但是10秒钟的时间根本不够百度的,所以写了个辅助挂,这样可以出现题目时自动百度,这
- 如何制作K线图?也不难,代码和说明见下:<%@ Language=VBScript %><%Respo