Training a network on multiple GPUs in Keras

During training:

from keras.utils import multi_gpu_model
 
# Replicates `model` on 8 GPUs.
# This assumes that your machine has 8 available GPUs.
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
                       optimizer='rmsprop')
 
# This `fit` call will be distributed on 8 GPUs.
# Since the batch size is 256, each GPU will process 32 samples.
parallel_model.fit(x, y, epochs=20, batch_size=256)
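
The snippet above assumes that model, x, and y already exist. For reference, a minimal self-contained sketch of the same flow; the toy architecture, the dummy data, and the file name weights.h5 are placeholders rather than anything from the original post:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import multi_gpu_model

# Build the single-GPU template model as usual.
model = Sequential([
    Dense(64, activation='relu', input_shape=(100,)),
    Dense(10, activation='softmax'),
])

# Wrap it so each batch is split across the 8 GPUs.
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
                       optimizer='rmsprop')

# Dummy data just to make the example runnable.
x = np.random.random((256, 100))
y = np.random.random((256, 10))

parallel_model.fit(x, y, epochs=1, batch_size=256)

# Save the trained weights from the parallel wrapper.
parallel_model.save_weights('weights.h5')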

When calling load_weights, you also need to define parallel_model in the same way as above; otherwise you will get an error:

ValueError: You are trying to load a weight file containing 1 layers into a model with 6 layers.
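
A minimal sketch of the loading step under that assumption, where build_model() and the file names are placeholders for your own model-building code and checkpoint:

from keras.utils import multi_gpu_model

# Rebuild the same single-GPU architecture first.
model = build_model()  # placeholder for your own model definition

# Wrap it exactly as during training, then load into the wrapper.
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.load_weights('weights.h5')

# The template `model` shares its weights with the wrapper, so it can
# now be used (or re-saved) for single-GPU inference.
model.save_weights('single_gpu_weights.h5')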


Reposted from blog.csdn.net/weixin_43486780/article/details/105257804