pytorch dataset实战案例之读取数据集的代码
作者:半岛铁子_ 发布时间:2023-10-06 23:51:01
标签:pytorch,dataset,读取数据集
概述
最近在跑一篇图像修复论文的代码,配置好环境之后开始运行,发现数据一直加载不进去。
害,还是得看人家代码咋写的,一句一句看逻辑,准能找出问题。通读dataset后,发现了问题所在,终于成功加载了数据集。
项目结构与代码
项目结构
主要的目的就是从数据集中读取到彩色图像和掩码图像。
代码
代码中涉及到torch.transforms、合并路径等知识点,我在代码中都进行了详细的注释,路径要对照着项目结构,如果自己用的话要根据项目结构去将相对路径改过来。dataset.py
:当前的工作路径:…\OT-GAN-for-Inpainting-master\src\data
import os
import math
import numpy as np
from glob import glob
from random import shuffle
from PIL import Image, ImageFilter
import torch
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class InpaintingData(Dataset):
def __init__(self, args):
super(Dataset, self).__init__() # 继承Dataset的父类的初始化函数
self.w = self.h = args.image_size # 通过args传入新的属性---图像的w和h
self.mask_type = args.mask_type # 通过args传入新的属性---mask_type
# image and mask
self.image_path = [] #创建image_path的数组
for ext in ['*.jpg', '*.png']: # 获取每一个后缀为.jpg或者.png的图片,为ext
# 将dir_image、data_train和ext拼接作为图片的路径,并将其存入到数组image_path之中,glob()获取一个lsit集合
self.image_path.extend(glob(os.path.join(args.dir_image, args.data_train, ext)))
self.mask_path = glob(os.path.join(args.dir_mask, args.mask_type, '*.png')) #拼接dir_mask、mask_type和路径下所有的.png作为mask_path
# augmentation
self.img_trans = transforms.Compose([ #接收一个 transforms方法的list为参数,将这些操作组合到一起,返回一个新的tranforms
transforms.RandomResizedCrop(args.image_size), #随机随机长宽比裁剪,大小为image_size
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), #改变图像的亮度、对比度、饱和度和色调。
transforms.ToTensor()]) # 转为tensor,并归一化至[0-1]
self.mask_trans = transforms.Compose([
transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST), #将输入图像调整为给定的大小,interpolation是插值方式,此处是默认值NEAREST
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.RandomRotation( #随机旋转
(0, 45), interpolation=transforms.InterpolationMode.NEAREST), #(0, 45)是角度
])
def __len__(self): # __len__和__getitem__DataSet类必须实现的静态方法
return len(self.image_path)
def __getitem__(self, index):
# load image
image = Image.open(self.image_path[index]).convert('RGB') #获取图像,并将其转化为RGB(3x8位像素)模式
filename = os.path.basename(self.image_path[index]) #获取图片的路径
if self.mask_type == 'pconv': #如果mask_type为pconv
index = np.random.randint(0, len(self.mask_path)) #随机从mask_path中获取一个下标
mask = Image.open(self.mask_path[index]) #根据下标获取mask图片
mask = mask.convert('L') #将mask图片转化为L(8位像素的黑白图片,0表示黑,255表示白)模式
else: # 构造mask,有mask数据集的话就运行不到这里
mask = np.zeros((self.h, self.w)).astype(np.uint8) #构造与h和w一样大的图片,都用0填充,并将其转换为uint8
mask[self.h // 4:self.h // 4 * 3, self.w // 4:self.w // 4 * 3] = 1
mask = Image.fromarray(m).convert('L')
# augment
image = self.img_trans(image) * 2. - 1. # 数据标准化,将输出限定在一定的范围
mask = F.to_tensor(self.mask_trans(mask)) # 将转化后的mask图像转化为tensor
return image, mask, filename #返回
if __name__ == '__main__':
from attrdict import AttrDict
args = {
'dir_image': '../../examples/logos',
'data_train': 'image',
'dir_mask': '../../examples/logos/mask',
'mask_type': 'pconv',
'image_size': 512
}
args = AttrDict(args) # 将上面定义的参数传入AttrDict()作为新参数
data = InpaintingData(args) #创建InpaintingData对象
print(len(data), len(data.mask_path)) #输出data的长度,mask的长度
img, mask, filename = data[0] # 获取第一张图片
print(img.size(), mask.size(), filename) #打印上述信息
输出:
再Debug一下看:
如下图所示,执行玩加载数据的代码之后,已经成功获取到数据
来源:https://blog.csdn.net/hshudoudou/article/details/127431107
0
投稿
猜你喜欢
- os 包 和 bufio 包Go 标准库的 os 包,为我们提供很多操作文件的函数,如 Open(name) 打开文件、Create(nam
- 目录Counter类创建计数值的访问与缺失的键计数器的更新键的删除elements()most_common([n])fromkeys浅拷贝
- 昨天在网上看到一个防采集软件,说采集只访问当前网页,不会访问网页的图片、JS等,今天突然想到,通过动态程序和Js访问分别记录访问者的IP,然
- python等待10秒执行下一命令的方法:首先导入时间(time)模块;然后在需要等待执行的命令前调用sleep()方法,并在方法的括号里将
- 本文实例讲述了ES6正则表达式和字符串正则方法。分享给大家供大家参考,具体如下:RegExp构造函数在ES5中,RegExp构造函数的参数有
- 下面通过一段代码给大家介绍php参数过滤class mysafe{ public $logname; public $isshwomsg;
- SMTP协议首先了解SMTP(简单邮件传输协议),邮件传送代理程序使用SMTP协议来发送电邮到接收者的邮件服务器。SMTP协议只能用来发送邮
- 概述在进行网站爬取数据的时候,会发现很多网站都进行了反爬虫的处理,如JS加密,Ajax加密,反Debug等方法,通过请求获取数据和页面展示的
- 体系结构 Microsoft按照客户/服务器体系结构的分布进行操作。这种方法产生不必要的代价和复杂性。在Internet中,Oracle已经
- vue-amap是对高德地图JS API进行封装的、适用于vue项目的地图组件库。在笔者开发的很多项目中都有用到,相比直接使用高德地图JS
- 逻辑判断与逻辑语句对于─件事情正确与否(真假的判断) √ X根据判断的结果做不同的事情,就是我们的逻辑业务对于条件满足的判断语句,就是条件语
- mysql的root账户,我在连接时通常用的是localhost或127.0.0.1,公司的测试服务器上的mysql也是localhost所
- 类的定义# class是定义类的关键字,ClassName为类的名称class ClassName:# 在这里写其他内容passclass
- 周期置换密码参考教材:《现代密码学教程》P47 3.1.2加密解密过程周期置换密码是将明文p串按固定长度m分组.然后对每组中的子串按1,2&
- 1 圆点选择选项设置效果展示代码参考#!/usr/bin/python# -*- coding:utf-8 -*-import sysfro
- windows系统下安装Pyinstallercmd下输入指令pip install PyInstallerPyinstaller的使用进入
- 1.什么是getters在介绍state中我们了解到,在Store仓库里,state就是用来存放数据,若是对数据进行处理输出,比如数据要过滤
- 实际中,很多数据都是存为txt文件、csv文件等,但是在程序中处理的时候numpy数组或列表是最方便的。本文简单介绍读入txt文件以及将之转
- 在使用element-ui的时候,有一个常用的组件,那就是el-popover,但是element-ui官方文档中样式跟用法都比较局限,在使
- 本文实例讲述了wxPython使用系统剪切板的方法。分享给大家供大家参考。具体如下:程序运行效果如下图所示:主要代码如下:import wx