pytorch 如何把图像数据集进行划分成train,test和val
作者:l8947943 发布时间:2023-12-26 15:28:10
标签:pytorch,划分,数据集,train,test,val
1、手上目前拥有数据集是一大坨,没有train,test,val的划分
如图所示
2、目录结构:
|---data
|---dslr
|---images
|---back_pack
|---a.jpg
|---b.jpg
...
3、转换后的格式如图
目录结构为:
|---datanews
|---dslr
|---images
|---test
|---train
|---valid
|---back_pack
|---a.jpg
|---b.jpg
...
4、代码如下:
4.1 先创建同样结构的层级结构
4.2 然后讲原始数据按照比例划分
4.3 移入到对应的文件目录里面
import os, random, shutil
def make_dir(source, target):
'''
创建和源文件相似的文件路径函数
:param source: 源文件位置
:param target: 目标文件位置
'''
dir_names = os.listdir(source)
for names in dir_names:
for i in ['train', 'valid', 'test']:
path = target + '/' + i + '/' + names
if not os.path.exists(path):
os.makedirs(path)
def divideTrainValiTest(source, target):
'''
创建和源文件相似的文件路径
:param source: 源文件位置
:param target: 目标文件位置
'''
# 得到源文件下的种类
pic_name = os.listdir(source)
# 对于每一类里的数据进行操作
for classes in pic_name:
# 得到这一种类的图片的名字
pic_classes_name = os.listdir(os.path.join(source, classes))
random.shuffle(pic_classes_name)
# 按照8:1:1比例划分
train_list = pic_classes_name[0:int(0.8 * len(pic_classes_name))]
valid_list = pic_classes_name[int(0.8 * len(pic_classes_name)):int(0.9 * len(pic_classes_name))]
test_list = pic_classes_name[int(0.9 * len(pic_classes_name)):]
# 对于每个图片,移入到对应的文件夹里面
for train_pic in train_list:
shutil.copyfile(source + '/' + classes + '/' + train_pic, target + '/train/' + classes + '/' + train_pic)
for validation_pic in valid_list:
shutil.copyfile(source + '/' + classes + '/' + validation_pic,
target + '/valid/' + classes + '/' + validation_pic)
for test_pic in test_list:
shutil.copyfile(source + '/' + classes + '/' + test_pic, target + '/test/' + classes + '/' + test_pic)
if __name__ == '__main__':
filepath = r'../data/dslr/images'
dist = r'../datanews/dslr/images'
make_dir(filepath, dist)
divideTrainValiTest(filepath, dist)
补充:pytorch中数据集的划分方法及eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误原因
在使用pytorch框架时,难免需要对数据集进行训练集和验证集的划分,一般使用sklearn.model_selection中的train_test_split方法
该方法使用如下:
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
traindata = np.load(train_path) # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
x_tra, x_val, y_tra, y_val = train_test_split(train_data, train_label_data, test_size=0.1, random_state=0) # 训练集和验证集使用9:1
x_tra = Variable(torch.from_numpy(x_tra))
x_tra = x_tra.float()
y_tra = Variable(torch.from_numpy(y_tra))
y_tra = y_tra.float()
x_val = Variable(torch.from_numpy(x_val))
x_val = x_val.float()
y_val = Variable(torch.from_numpy(y_val))
y_val = y_val.float()
# 训练集的DataLoader
traindataset = torch.utils.data.TensorDataset(x_tra, y_tra)
trainloader = DataLoader(dataset=traindataset, num_workers=opt.threads, batch_size=8, shuffle=True)
# 验证集的DataLoader
validataset = torch.utils.data.TensorDataset(x_val, y_val)
valiloader = DataLoader(dataset=validataset, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
注意:如果按照如下方式使用,就会报eError: take(): argument 'index' (position 1) must be Tensor, not numpy.ndarray错误
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.autograd import Variable
from torch.utils.data import DataLoader
traindata = np.load(train_path) # image_num * W * H
trainlabel = np.load(train_label_path)
train_data = traindata[:, np.newaxis, ...]
train_label_data = trainlabel[:, np.newaxis, ...]
x_train = Variable(torch.from_numpy(train_data))
x_train = x_train.float()
y_train = Variable(torch.from_numpy(train_label_data))
y_train = y_train.float()
# 将原始的训练数据集分为训练集和验证集,后面就可以使用早停机制
x_tra, x_val, y_tra, y_val = train_test_split(x_train, y_train, test_size=0.1) # 训练集和验证集使用9:1
报错原因:
train_test_split方法接受的x_train,y_train格式应该为numpy.ndarray 而不应该是Tensor,这点需要注意。
来源:https://blog.csdn.net/l8947943/article/details/105696192


猜你喜欢
- 本文介绍了目前6种比较常用的进度条,让大家都能直观地看到脚本运行最新的进展情况1.普通进度条在代码迭代运行中可以自己进行统计计算,并使用格式
- 如下所示:import urllib,json,requestsurl = 'http://127.0.0.1:8000/accou
- 本文实例讲述了JS实现获取数组中最大值或最小值功能。分享给大家供大家参考,具体如下:方法一://最小值Array.prototype.min
- 前言前面几个章节我们学习了对于普通文件的操作,比如说文件的创建、复制粘贴、裁剪粘贴、文件名的重命名、删除等等。另外还学习了一些基本练习,如何
- 目录Ⅰ. 简介Ⅱ. 注意事项Ⅲ. 使用方法Ⅳ. 教程首先spring自带了mongodb的orm,spring data mongodb,但
- 前言相关性分析算是很多算法以及建模的基础知识之一了,十分经典。关于许多特征关联关系以及相关趋势都可以利用相关性分析计算表达。其中常见的相关性
- 常用的 random 模块方法import random# random.random()用于生成一个 0 到 1 的随机浮点数: 0 &l
- Python Json使用本篇主要介绍一下 python 中 json的使用 如何把 dict转成json 、object 转成json 、
- Sql server中常用的几个数据类型: binary 固定长度的二进制数据,其最大长度为 8,000 个字节。 varbinary 可变
- JavaScript 中 typeof 和 instanceof 常用来判断一个变量是否为空,或者是什么类型的。但它们之间还是有区别的:ty
- 每个进行过较大型的ASP-Web应用程序设计的开发人员大概都有如下的经历:ASP代码与页面HTML混淆难分,业务逻辑与显示方式绞合,使得代码
- MySQL数据库中如何修改root用户的密码呢?下面总结了修改root用户密码的一些方法1: 使用set password语句修改mysql
- 前言mysql模块(项目地址为https://github.com/mysqljs/mysql)是一个开源的、JavaScript编写的My
- series: [{ &nbs
- 概述递归函数即直接或间接调用自身的函数,且递归过程中必须有一个明确的递归结束条件,称为递归出口。递归极其强大一点就是能够遍历任意的,不可预知
- 使用go mod之后,想要在goland中有代码提示,有两种方式,一种是使用gopath下的goimport工具,另一种是使用gomod自身
- 使用游标实现declare @id1 int,@oldid int,@e_REcordid int ,@Olde_REcordid intD
- 本文实例讲述了JS获取日期的方法。分享给大家供大家参考,具体如下:原理很简单,一天的时间的毫秒数是1000*60*60*24,前n天的日期就
- 首先让我们来看看有关 Perl 面向对象编程的三个基本定义:1. 一个“对象”是指一个“有办法知道它是属于哪个类”的简单引用。(
- python安装教程和Pycharm安装详细教程,分享给大家。首先我们来安装python1、首先进入网站下载:点击打开链接(或自己输入网址h