网络编程
位置:首页>> 网络编程>> Python编程>> PyTorch 解决Dataset和Dataloader遇到的问题

PyTorch 解决Dataset和Dataloader遇到的问题

作者:xgbm_k  发布时间:2023-10-14 04:37:50 

标签:PyTorch,Dataset,Dataloader

今天在使用PyTorch中Dataset遇到了一个问题。先看代码


class psDataset(Dataset):
 def __init__(self, x, y, transforms = None):
   super(Dataset, self).__init__()
   self.x = x
   self.y = y
   if transforms == None:
     self.transforms = Compose([Resize((224, 224)), ToTensor()])
   else:
     self.transforms = transforms

def __len__(self):
   return len(self.x)

def __getitem__(self, idx):
   img = Image.open(self.x[idx])
   img = self.transforms(img)    
   return img, torch.tensor([[self.y[idx]]])

结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897

Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)

。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。

Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完

整代码如下:


class psDataset(Dataset):
 def __init__(self, x, y, transforms = None):
   super(Dataset, self).__init__()
   self.x = x
   self.y = y
   if transforms == None:
     self.transforms = Compose([Resize((224, 224)), ToTensor()])
   else:
     self.transforms = transforms

def __len__(self):
   return len(self.x)

def __getitem__(self, idx):
   img = Image.open(self.x[idx])
   img = img.convert("RGB")
   img = self.transforms(img)    
   return img, torch.tensor([[self.y[idx]]])

来源:https://blog.csdn.net/xgbm_k/article/details/84067245

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com