keras多显卡训练方式
作者:深夜虫鸣 发布时间:2022-05-01 02:50:21
使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。
要使用多张显卡,需要按如下步骤:
(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model
(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:
model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model
通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。
(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:
model_parallel.fit(...) #注意是model_parallel
(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:
model.save(...) #注意是model,不是model_parallel
如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:
class OwnCheckpoint(keras.callbacks.Callback):
def __init__(self,model):
self.model_to_save=model
def on_epoch_end(self,epoch,logs=None): #这里logs必须写
self.model_to_save.save('model_advanced/model_%d.h5' % epoch)
定以后具体使用如下:
checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])
这样就没问题了!
补充知识:keras.fit_generator及多卡训练记录
1.环境问题
使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6
2.代码
2.1
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'
2.2 自定义generator函数
def img_image_generator(path_img, path_lab, batch_size, data_list):
while True:
# 'train_list.csv'
file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
file_list = [i[0] for i in file_list]
cnt = 0
X = []
Y1 = []
for file_i in file_list:
x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
x = x.astype('float32')
x /= 255.
y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
y = y.astype('float32')
y /= 255.
X.append(x.reshape(256, 256, 1))
Y1.append(y.reshape(256, 256, 1))
cnt += 1
if cnt == batch_size:
cnt = 0
yield (np.array(X), [np.array(Y1), np.array(Y1)])
X = []
Y1 = []
2.3 函数调用及训练
generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)
3. 多卡训练
3.1 复制model
model_parallel = multi_gpu_model(model, gpus=2)
3.2 checkpoint 定义
class ParallelModelCheckpoint(ModelCheckpoint):
def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
save_best_only=False, save_weights_only=False, mode='auto', period=1):
self.single_model = model
super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
def set_model(self, model):
super(ParallelModelCheckpoint, self).set_model(self.single_model)
使用
model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')
3.3 注意的问题
保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存
来源:https://blog.csdn.net/u010122972/article/details/84784245


猜你喜欢
- python中通过引用计数来回收垃圾对象,在某些环形数据结构(树,图……),存在对象间的循环引用,比如树的父节点引用子节点,子节点同时引用父
- 背景我们先来看看MySQL 8.0的事务提交的大致流程以上流程,是MySQL8.0对WAL原则的一种实现,这个流程意味着,任何一个事务的提交
- 千图成像也就是用N张图片组成一张图片的效果。制作方法有很多的,最常见的如用ps、懒人图云、foto-mosaik-edda这些制作。千图成像
- 关于target="_blank"去留的问题在网上已经被反复争议很多次了。有的说要留,有的说要去掉。主张留的一方主要是考
- 在 JavaScritp 中使用计时事件是很容易的,两个关键方法是:setTimeout()未来的某时执行代码clearTimeout()取
- 一、概述本文将介绍如何使用python3给企业微信发送消息。我的环境是linux + python3.6.10。二、python脚本#!/u
- 例如:我们在百度中搜索 词典网,则网址后面的参数就是http://www.baidu.com/s?cl=3&wd=%B4%CA%B5
- 上次遇到一个需要打包下载批量图片的问题,找了一下发现这个好方法,记录一下。首先新建一个zipfile打包类:<?phpclass zi
- 本文实例讲述了django框架中ajax的使用及避开CSRF 验证的方式。分享给大家供大家参考,具体如下:ajax(Asynchronous
- 使用TensorFlow模块时,弹出错误Your CPU supports instructions that this TensorFlo
- python的scipy.stats模块是连续型随机变量的公共方法,可以产生随机数,通常是以正态分布作为scipy.stats的基本使用方法
- 简介集合对象 set 是由具有唯一性的可哈希对象组成的无序多项集,如 list 不能哈希因此,不能作为 set 的一项。set 的常见用途包
- 一、题目要求用原生Python实现knn分类算法。二、题目分析数据来源:鸢尾花数据集(见附录Iris.txt)数据集包含150个数据集,分为
- 从本文开始,本系列将介绍python简单案例并进行代码展示,本文的案例是利用pandas库实现读取csv文件并按照列的从小到大进行排序。前言
- cron是什么cron的意思就是:计划任务,说白了就是定时任务。我和系统约个时间,你在几点几分几秒或者每隔几分钟跑一个任务(job),就那么
- 很多网站需要将好的会员号留着,或用于日后的盈利。实现方法不是本文讨论范围,本文仅列出用于检测靓号类型的一些正则。靓号检测:主要可以检测连号(
- 与其它大多数语言一样,Python 也拥有 for 循环。你到现在还未曾看到它们的唯一原因就是,Python 在其它太多的方面表现出色,通常
- JDBC连接MySQL数据库关键的四个步骤1、查找驱动程序MySQL目前提供的Java驱动程序为Connection/J,可以从MySQL官
- 在正式的生产环境中,我们常常会需要监控服务器的状态,以保证公司整个业务的正常运转,常常我们会用到像nagios、zabbix这类工具进行实时
- 问题描述:在使用Vue框架开发时,在函数中改变了页面中的某个值,在函数中查看是修改成功了,但在页面中没有及时刷新改变后的值;解决:运用 th