keras多显卡训练方式
作者:深夜虫鸣 发布时间:2022-05-01 02:50:21
使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。
要使用多张显卡,需要按如下步骤:
(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model
(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:
model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model
通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。
(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:
model_parallel.fit(...) #注意是model_parallel
(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:
model.save(...) #注意是model,不是model_parallel
如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:
class OwnCheckpoint(keras.callbacks.Callback):
def __init__(self,model):
self.model_to_save=model
def on_epoch_end(self,epoch,logs=None): #这里logs必须写
self.model_to_save.save('model_advanced/model_%d.h5' % epoch)
定以后具体使用如下:
checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])
这样就没问题了!
补充知识:keras.fit_generator及多卡训练记录
1.环境问题
使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6
2.代码
2.1
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'
2.2 自定义generator函数
def img_image_generator(path_img, path_lab, batch_size, data_list):
while True:
# 'train_list.csv'
file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
file_list = [i[0] for i in file_list]
cnt = 0
X = []
Y1 = []
for file_i in file_list:
x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
x = x.astype('float32')
x /= 255.
y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
y = y.astype('float32')
y /= 255.
X.append(x.reshape(256, 256, 1))
Y1.append(y.reshape(256, 256, 1))
cnt += 1
if cnt == batch_size:
cnt = 0
yield (np.array(X), [np.array(Y1), np.array(Y1)])
X = []
Y1 = []
2.3 函数调用及训练
generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)
3. 多卡训练
3.1 复制model
model_parallel = multi_gpu_model(model, gpus=2)
3.2 checkpoint 定义
class ParallelModelCheckpoint(ModelCheckpoint):
def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
save_best_only=False, save_weights_only=False, mode='auto', period=1):
self.single_model = model
super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
def set_model(self, model):
super(ParallelModelCheckpoint, self).set_model(self.single_model)
使用
model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')
3.3 注意的问题
保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存
来源:https://blog.csdn.net/u010122972/article/details/84784245
猜你喜欢
- 如果你正在负责一个基于SQL Server的项目,或者你刚刚接触SQL Server,你都有可能要面临一些数据库性能的问题,这篇文章会为你提
- 0.引言利用python开发,借助Dlib库捕获摄像头中的人脸,提取人脸特征,通过计算欧氏距离来和预存的人脸特征进行对比,达到人脸识别的目的
- 为了顺利的开发一个多语言的国际化J2EE程序,需要修改数据库字符集,我的做法如下:安装 MySq时选择字符集为UTF-8修改MySql安装目
- 这篇文章主要介绍了如何使用Python多线程测试并发漏洞,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- python画分布图代码示例:# encoding=utf-8import matplotlib.pyplot as pltfrom pyl
- XHTML规范中有一条标准就是“每个XHTML标签都有一个结束标记”。那么对于HTML中原来不带结束标记的元素,则在该结束前加上“/”来关闭
- 无限循环如果条件判断语句永远为 true,循环将会无限的执行下去。如下实例#!/usr/bin/python# -*- coding: UT
- 这篇文章主要介绍了Python远程开发环境部署与调试过程图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 本文列出了初学网页编程中常用到的一些代码和一些技巧,简单实用,您一定用得到。1、oncontextmenu="window.eve
- 一、将对象转为json字符串json.dumps:将 Python 对象编码成 JSON 字符串json.loads:将已编码的 JSON
- golang 原生 http 库已经可以很方便地实现一个 http server 了,但对于复杂的 web 服务来说,路由解析,请求参数解析
- 今天来说一下,有些刚刚接触python的朋友,在使用pip install安装python 第三方库的过程中会出现网速很慢,或者是安装下载到
- 本文实例讲述了python复制文件的方法。分享给大家供大家参考。具体分析如下:这里涉及Python复制文件在实际操作方案中的实际应用以及Py
- 代理服务是一种复杂的技术,具有很多可配置的移动组件。详细信息如下:信息信息是指在服务代理应用程序中调用的基本信息单元。对于服务代理来说,信息
- 问题:SQL Server 2005中如何利用xml拆分字符串序列?解答:下文中介绍的方法比替换为select union all方法更为见
- 上次亚马逊的商品信息都获取到了,自然要看一下评论的部分。用户的评论能直观的反映当前商品值不值得购买,亚马逊的评分信息也能获取到做一个评分的权
- 前言图片的本质就是大量像素在二维平面上的组合,每个像素点用数字化方式记录颜色。可以直观的想象,一张图片就是一个巨大的电子栅格,每个格子内有一
- Go Gin 实现文件的上传下载流读取文件上传routerrouter.POST("/resources/common/uploa
- 正题: 1.1 javascript的灵活性 面向对象对象的Javascript编程模式:1、可以保存状态 2、具有对象内部才能调用的方法
- 什么是Firebug从事了数年的Web开发工作,越来越觉得现在对WEB开发有了更高的要求。要写出漂亮的HTML代码;要编写精致的CSS样式表