pytorch加载自己的图片数据集的2种方法详解
作者:_-周-_ 发布时间:2023-08-09 00:35:48
pytorch加载图片数据集有两种方法。
1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别
导入ImageFolder()包
from torchvision.datasets import ImageFolder
在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。
ImageFolder 加载数据集
1. 导入包和设置transform
import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
transforms = transforms.Compose([
transforms.Resize(256), # 将图片短边缩放至256,长宽比保持不变:
transforms.CenterCrop(224), #将图片从中心切剪成3*224*224大小的图片
transforms.ToTensor() #把图片进行归一化,并把数据转换成Tensor类型
])
2.加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。
path = r'D:\数据集\Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
images, labels = data
print(images.shape)
print(labels.shape)
break
使用pytorch提供的Dataset类创建自己的数据集。
具体步骤:
1. 首先要有一个txt文件, 这个文件格式是: 图片路径 标签. 这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。
2. 有了这个txt文件, 我们就可以在类里面构造我们的数据集.
2.1 把图片路径和图片标签分割开, 有两个列表, 一个列表是图片路径名, 一个列表是标签号, 有一点就是第 i 个图片列表和 第 i 个标签是对应的
3. 重写__len__方法 和 __getitem__方法
3.1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可
4.就可以构建数据集实例和加载数据集.
定义一个用来生成[ 图片路径 标签] 这样的txt文件函数
def make_txt(root, file_name, label):
path = os.path.join(root, file_name)
data = os.listdir(path)
f = open(path+'\\'+'f.txt', 'w')
for line in data:
f.write(line+' '+str(label)+'\n')
f.close()
#调用函数生成两个文件夹下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)
将连个txt文件合并成一个txt文件,表示数据集所有的图片和标签
def link_txt(file1, file2):
txt_list = []
path = r'D:\数据集\Flower_Orig_dataset\data.txt'
f = open(path, 'a')
f1 = open(file1, 'r')
data1 = f1.readlines()
for line in data1:
txt_list.append(line)
f2 = open(file2, 'r')
data2 = f2.readlines()
for line in data2:
txt_list.append(line)
for line in txt_list:
f.write(line)
f.close()
f1.close()
f2.close()
#调用函数, 将两个文件夹下的txt文件合并
file1 = r'D:\数据集\Flower_Orig_dataset\flower_orig\f.txt'
file2 = r'D:\数据集\Flower_Orig_dataset\sunflower\f.txt'
link_txt(file1=file1, file2=file2)
现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息
Dataset加载数据集
我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径
class MyDataset(Dataset):
def __init__(self, img_path, transform=None):
super(MyDataset, self).__init__()
self.root = img_path
self.txt_root = self.root + 'data.txt'
f = open(self.txt_root, 'r')
data = f.readlines()
imgs = []
labels = []
for line in data:
line = line.rstrip()
word = line.split()
imgs.append(os.path.join(self.root, word[1], word[0]))
labels.append(word[1])
self.img = imgs
self.label = labels
self.transform = transform
def __len__(self):
return len(self.label)
def __getitem__(self, item):
img = self.img[item]
label = self.label[item]
img = Image.open(img).convert('RGB')
#此时img是PIL.Image类型 label是str类型
if transforms is not None:
img = self.transform(img)
label = np.array(label).astype(np.int64)
label = torch.from_numpy(label)
return img, label
加载我们的数据集:
path = r'D:\数据集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)
data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
接下来我们就可以构建我们的网络架构:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3,16,3)
self.maxpool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(16,5,3)
self.relu = nn.ReLU()
self.fc1 = nn.Linear(55*55*5, 1200)
self.fc2 = nn.Linear(1200,64)
self.fc3 = nn.Linear(64,2)
def forward(self,x):
x = self.maxpool(self.relu(self.conv1(x))) #113
x = self.maxpool(self.relu(self.conv2(x))) #55
x = x.view(-1, self.num_flat_features(x))
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features
训练我们的网络:
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 10
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(data_loader):
images, label = data
out = model(images)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if(i+1)%10 == 0:
print('[%d %5d] loss: %.3f'%(epoch+1, i+1, running_loss/100))
running_loss = 0.0
print('finished train')
保存网络模型(这里不止是保存参数,还保存了网络结构)
#保存模型
torch.save(net, 'model_name.pth') #保存的是模型, 不止是w和b权重值
# 读取模型
model = torch.load('model_name.pth')
来源:https://blog.csdn.net/qq_53345829/article/details/124308515


猜你喜欢
- 前言pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytor
- 从微信小程序官方发布的公告中我们可获知:小程序体验版、开发版调用 wx.getUserInfo 接口,将无法弹出授权询问框,默认调用失败,需
- Wake-On-LAN简称WOL,是一种电源管理功能;如果存在网络活动,则允许设备将操作系统从待机或休眠模式中唤醒。许多主板厂商支持IBM提
- struts2.3.24 + spring4.1.6 + hibernate4.3.11+ mysql5.5.25开发环境搭建及相关说明。
- ES6添加了Promise对象,成功时在then中处理,失败则在catch中处理,但有时候,我们需要在无论成功或失败时都要做一些事,比如隐藏
- 本文实例讲述了Python Web编程之WSGI协议。分享给大家供大家参考,具体如下:WSGI简介Web框架和Wen服务器之间需要进行通信,
- 现在假如要写一个按照"标题",'内容','作者'等等进行针对性的选择,这时需要涉及到使用
- 前言本人做SSM项目的时候,在做删除功能时,发现找不到字段,在搜索了各种博客之后终于找到了解决办法一、报错Unknown column &a
- 面向对象编程时,都会遇到一个概念,类,python也有这个概念,下面我们通过代码来深入了解下。创建和使用类class Dog(): &nbs
- 我经常需要用Python与solr进行异步请求工作。这里有段代码阻塞在Solr http请求上, 直到第一个完成才会执行第二个请
- 上篇文章给大家介绍了MySQL 8.0.23 主要更新一览(新特征解读) ,感兴趣的朋友点击查看吧!最新版windows mysq
- 树莓派与arduino串口通信第一步:先设置硬件串口分配给GPIO串口输入sudo raspi-config命令进入树莓派系统配置界面,选择
- python可以简单优美,也很有趣,下面是收集的例子:1.一句话开始一个http的文件服务器:$ python -m SimpleHTTPS
- 1.limit函数的语法和用法(1)常用且简单的语法和用法①语法:limit n 即limit <参数>具体语法:select
- 实验1.1 列表a = [1, 2, 3, 4]for i in a: print(i)  
- 本文将详细介绍一下如何搭建深度学习所需要的实验环境.这个框架分为以下六个模块显卡简单理解这个就是我们常说的GPU,显卡的功能是一个专门做矩阵
- 以下为引用的内容: <html> <head> <title>不刷新页面查询的方法&
- 例子一:Python用WMI模块获取windowns系统的硬件信息:硬盘分区、使用情况,内存大小,CPU型号,当前运行的进程,自启动程序及位
- 绑定的值与规则指定的值一定要相同-------第一步:<el-form :model="ruleForm" :ru
- 1 与达尔文对话140年前,1858年7月1日,达尔文在英伦岛发表了自己有关自然选择的杰出论文。他提出,生物的发展规律是物竞天择。经过物竞,