Pytorch 如何加速Dataloader提升数据读取速度
作者:MKFMIKU 发布时间:2023-11-12 00:00:32
标签:Pytorch,Dataloader,数据,读取速度
在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。
很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。
这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:
如何解决这种问题呢?
在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:
https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256
class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.preload()
def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,
那么对我们自己的训练代码只需要做下面的改造:
training_data_loader = DataLoader(
dataset=train_dataset,
num_workers=opts.threads,
batch_size=opts.batchSize,
pin_memory=True,
shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
# 训练代码
#-------------升级后---------
data, label = prefetcher.next()
iteration = 0
while data is not None:
iteration += 1
# 训练代码
data, label = prefetcher.next()
这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:
当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~
补充:Pytorch设置多线程进行dataloader时影响GPU运行
使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。
以载入cocodataset为例
DataLoader
dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
(config["img_w"], config["img_h"]),
is_training=True),
batch_size=config["batch_size"],
shuffle=True, num_workers=32, pin_memory=True)
numworkers就是指定多少线程的参数,原为32。
检查GPU是否运行该程序
查看运行在gpu上的所有程序:
fuser -v /dev/nvidia*
如果没有返回,则该程序并没有在GPU上运行
指定GPU运行
将num_workers改成0即可
来源:https://zhuanlan.zhihu.com/p/66145913


猜你喜欢
- 在pycharm中我们有时需要切换python的版本,这里需要注意的是我们是在PyCharm中的Preferences中切换的,在File的
- 引言将对象的状态信息转换为可以存储或传输的形式的过程叫作序列化类似地从序列化后的数据转换成相对应的对象叫作 反序列化本文介绍 Python
- 本文只有代码,介绍了有关GUI界面的学生信息管理系统的实现。已经过调试没有很大问题。如有错误,还请批评指正。1.导入tkinter模块imp
- 1.pickle 写: 以写方式打开一个文件描述符,调用pickle.dump把对象写进去 &
- Python 中有 while 和 for 两种循环机制,其中 while 循环
- Todo清单需要实现的功能有添加任务、删除任务、编辑任务,操作要关联数据库。任务需要绑定用户,部门。用户需要绑定部门。{#自己编写一个基类模
- <img :onerror="errpic" class="customerHead" :sr
- 如何提高SQL Server数据库的性能,该从哪里入手呢?笔者认为,该遵循从外到内的顺序,来改善数据库的运行性能。如下图: 第一层
- 严正声明:本文仅限于技术讨论,严禁用于其他用途。基础知识socket通信模块:针对TCP/IP协议簇进行的程序封装,在Windows/Lin
- 一、并行复制的背景首先,为什么会有并行复制这个概念呢?1. DBA都应该知道,MySQL的复制是基于binlog的。 2. My
- pymysql的executemany使用在使用pymysql的executemany方法时,需要注意的几个问题1、在写sql语句时,不管字
- 本文内容会引起杀毒软件的莫名兴奋,建议先安抚杀毒软件,让杀毒软件先休息一下再继续操作安装python3.6转exe会遇到很多问题,其中部分是
- 每次安装总是有些不同,这次用这种方式尝试一下,也记录一下。1、首先需要去下载rpm包:镜像地址:http://mysql.mirrors.p
- 第一、几种常用方法读取TXT文档:urlopen()读取PDF文档:pdfminer3k第二、乱码问题(1)、from urllib.req
- rss.asp格式的 下面代码保存为rss.asp 代码如下:<!--#include file="conn.as
- python是一款简单易用的编程语言,特别是其第三方库,能够方便我们快速进入工作,但其第三方库的安装困扰很多人.现在安装python时,已经
- 前言报错如下:Could not open JDBC Connection for transaction; nested exceptio
- 引言欢迎来到我们的系列博客《Python全景系列》!在这个系列中,我们将带领你从Python的基础知识开始,一步步深入到高级话题,帮助你掌握
- 简介本文分享的实例代码主要通过python语言实现批量替换页眉页脚的操作功能,具体如下。代码#!/usr/bin/env python# -
- 爬取一些网站下指定的内容,一般来说可以用xpath来直接从网页上来获取,但是当我们获取的内容不唯一的时候我们无法选择,我们所需要的、所指定的