pytorch 数据加载性能对比分析
作者:ShellCollector 发布时间:2022-04-17 04:22:22
传统方式需要10s,dat方式需要0.6s
import os
import time
import torch
import random
from common.coco_dataset import COCODataset
def gen_data(batch_size,data_path,target_path):
os.makedirs(target_path,exist_ok=True)
dataloader = torch.utils.data.DataLoader(COCODataset(data_path,
(352, 352),
is_training=False, is_scene=True),
batch_size=batch_size,
shuffle=False, num_workers=0, pin_memory=False,
drop_last=True) # DataLoader
start = time.time()
for step, samples in enumerate(dataloader):
images, labels, image_paths = samples["image"], samples["label"], samples["img_path"]
print("time", images.size(0), time.time() - start)
start = time.time()
# torch.save(samples,target_path+ '/' + str(step) + '.dat')
print(step)
def cat_100(target_path,batch_size=100):
paths = os.listdir(target_path)
li = [i for i in range(len(paths))]
random.shuffle(li)
images = []
labels = []
image_paths = []
start = time.time()
for i in range(len(paths)):
samples = torch.load(target_path + str(li[i]) + ".dat")
image, label, image_path = samples["image"], samples["label"], samples["img_path"]
images.append(image.cuda())
labels.append(label.cuda())
image_paths.append(image_path)
if i % batch_size == batch_size - 1:
images = torch.cat((images), 0)
print("time", images.size(0), time.time() - start)
images = []
labels = []
image_paths = []
start = time.time()
i += 1
if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
batch_size=320
# target_path='d:/test_1000/'
target_path='d:\img_2/'
data_path = r'D:\dataset\origin_all_datas\_2train'
gen_data(batch_size,data_path,target_path)
# get_data(target_path,batch_size)
# cat_100(target_path,batch_size)
这个读取数据也比较快:320 batch_size 450ms
def cat_100(target_path,batch_size=100):
paths = os.listdir(target_path)
li = [i for i in range(len(paths))]
random.shuffle(li)
images = []
labels = []
image_paths = []
start = time.time()
for i in range(len(paths)):
samples = torch.load(target_path + str(li[i]) + ".dat")
image, label, image_path = samples["image"], samples["label"], samples["img_path"]
images.append(image)#.cuda())
labels.append(label)#.cuda())
image_paths.append(image_path)
if i % batch_size < batch_size - 1:
i += 1
continue
i += 1
images = torch.cat(([image.cuda() for image in images]), 0)
print("time", images.size(0), time.time() - start)
images = []
labels = []
image_paths = []
start = time.time()
补充:pytorch数据加载和处理问题解决方案
最近跟着pytorch中文文档学习遇到一些小问题,已经解决,在此对这些错误进行记录:
在读取数据集时报错:
AttributeError: 'Series' object has no attribute 'as_matrix'
在显示图片是时报错:
ValueError: Masked arrays must be 1-D
显示单张图片时figure一闪而过
在显示多张散点图的时候报错:
TypeError: show_landmarks() got an unexpected keyword argument 'image'
解决方案
主要问题在这一行: 最终目的是将Series转为Matrix,即调用np.mat即可完成。
修改前
landmarks =landmarks_frame.iloc[n, 1:].as_matrix()
修改后
landmarks =np.mat(landmarks_frame.iloc[n, 1:])
打散点的x和y坐标应该均为向量或列表,故将landmarks后使用tolist()方法即可
修改前
plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
修改后
plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')
前面使用plt.ion()打开交互模式,则后面在plt.show()之前一定要加上plt.ioff()。这里直接加到函数里面,避免每次plt.show()之前都用plt.ioff()
修改前
def show_landmarks(imgs,landmarks):
'''显示带有地标的图片'''
plt.imshow(imgs)
plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点
plt.pause(1)#绘图窗口延时
修改后
def show_landmarks(imgs,landmarks):
'''显示带有地标的图片'''
plt.imshow(imgs)
plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点
plt.pause(1)#绘图窗口延时
plt.ioff()
网上说对于字典类型的sample可通过 **sample的方式获取每个键下的值,但是会报错,于是把输入写的详细一点,就成功了。
修改前
show_landmarks(**sample)
修改后
show_landmarks(sample['image'],sample['landmarks'])
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/jacke121/article/details/85236561


猜你喜欢
- 本文实例讲述了C# Ado.net读取SQLServer数据库存储过程列表及参数信息的方法。分享给大家供大家参考,具体如下:得到数据库存储过
- 一、安装go get github.com/sirupsen/logrus二、使用1、当做标准库使用logrus实现了标准库log的方法,可
- 问题描述最近产品提出一个需求,说是做表格呈现统计数据,不过数据源是来自两个地方的,所以需要做两个表格去呈现数据,同时在表格最后统计数据。效果
- 在MySQL中,A LEFT JOIN B join_condition执行过程如下:· 根据表A和A依赖的所有表设置表B。· 根据LEFT
- 最近的一些疫情信息很让人揪心,为了方便大家掌握疫情信息,在空闲之余做了一个关于 nCoV 的疫情监控小助手。主要的功能是通过企业微信的 We
- 今天我们来探索python中大部分的异常报错首先异常是什么,异常白话解释就是不正常,程序里面一般是指程序员输入的格式不规范,或者需求的参数类
- 一、介绍对于Visual Studio Code开发工具,有一款优秀的GoLang插件,它的主页为:https://github.com/m
- <html xmlns="http://www.w3.org/1999/xhtml"><head>
- 给密码加密是什么:用户注册的密码一般网站管理人员会利用md5方法加密,这种加密方法的好处是它是单向加密的,也就是说,你只有在提前知道某一串密
- 查看python搜索包的路径的实现方法:python搜索包的路径存储在sys.path下查看方法:import syssys.path临时添
- Liwu_Items表,CreateTime列建立聚集索引 第一种,sqlserver2005特有的分页语法 代码如下:declare @p
- python中字典和列表的使用,在数据处理中应该是最常用的,这两个熟练后基本可以应付大部分场景了。不过网上的基础教程只告诉你列表、字典是什么
- 1.定义在某些情况下,一个类的对象是有限且固定的,比如季节类,它只有 4 个对象;再比如行星类,目前只有 8 个对象。这种实例有限且固定的类
- 1、卓越亚马逊的首页轮换图片,每刷新一次,都是随机不同的顺序显示,这样的设计解决了对于较多图片轮换而靠后的图片信息很少被看到的问题,这点对于
- 前言让我的电脑认识我,我的电脑只有认识我,才配称之为我的电脑!今天,我们用Python实现简单的人脸识别技术!Python里,简单的人脸识别
- 简介虽然使用Explain不能够马上调优我们的SQL,它也不能给予我们一些调整建议,但是它能够让我们了解MySQL 优化器是如何执行SQL
- 我们都知道如何上传单个文件,但如果有大量文件或大量数据,这就扎心了,可能会变得单调。因此目前想到一种办法,将文件压缩成zip包,然后再解压到
- 目录1.随机取小数:2.整数的随机选取:3.随机列表取数,元素打乱:总结1.随机取小数:import randomprint(random.
- 无论你在linux上娱乐还是工作,这对你而言都是一个使用python来编程的很好的机会。回到大学我希望他们教我的是Python而不是Java
- 1.os.system函数wget 是一个下载软件的程序,如果已经下载好该软件,可以用py调用该软件。假如该软件目录在d:\tools\wg