转载自keras中文文档
- Docs »
- keras新手指南 »
- 常见问题与解答
如何在多张GPU卡上使用Keras?
我们建议有多张GPU卡可用时,使用TnesorFlow后端。
有两种方法可以在多张GPU上运行一个模型:数据并行/设备并行
大多数情况下,你需要的很可能是“数据并行”
数据并行
数据并行将目标模型在多个设备上各复制一份,并使用每个设备上的复制品处理整个数据集的不同部分数据。Keras在
keras.utils.multi_gpu_model
中提供有内置函数,该函数可以产生任意模型的数据并行版本,最高支持在8片GPU上并行。 请参考utils中的multi_gpu_model文档。 下面是一个例子:(仅适用keras2.0+)from keras.utils import multi_gpu_model # Replicates `model` on 8 GPUs. # This assumes that your machine has 8 available GPUs. parallel_model = multi_gpu_model(model, gpus=8) parallel_model.compile(loss='categorical_crossentropy', optimizer='rmsprop') # This `fit` call will be distributed on 8 GPUs. # Since the batch size is 256, each GPU will process 32 samples. parallel_model.fit(x, y, epochs=20, batch_size=256)