# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import logging
import os
def _setup_logging(args, rank):
    """Configure root logging for this worker and log the parsed arguments.

    If ``args.log_file`` is set, records are written to
    ``os.path.join(args.log_dir, args.log_file)`` (the directory is created
    when missing); otherwise ``logging.basicConfig`` is used. Every record is
    prefixed with ``Node[<rank>]``.

    Parameters
    ----------
    args : argparse.Namespace-like
        Must support ``'log_file' in args``; reads ``log_file`` and ``log_dir``.
    rank : int
        Worker rank shown in the log prefix.
    """
    head = '%(asctime)-15s Node[' + str(rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, args.log_file)
        # An empty log_dir means "current directory"; creating '' would fail.
        # makedirs (unlike mkdir) also handles nested directories.
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        handler.setFormatter(logging.Formatter(head))
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)


def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
    """Train ``network`` with ``mx.mod.Module`` and SGD, configured from ``args``.

    Parameters
    ----------
    args : argparse.Namespace-like
        Training options. Reads ``kv_store``, ``log_file``/``log_dir``,
        ``model_prefix``, ``load_epoch``, ``save_model_prefix``, ``gpus``
        (comma-separated device ids or None), ``num_examples``, ``batch_size``,
        ``lr``, ``num_epochs``, and optionally ``lr_factor``/``lr_factor_epoch``
        and ``clip_gradient``.
    network : symbol
        Passed as the symbol to ``mx.mod.Module``.
    data_loader : callable
        Called as ``data_loader(args, kv)``; must return ``(train, val)``.
    eval_metrics : optional
        Passed to ``Module.fit`` as ``eval_metric``. Defaults to accuracy plus
        top-5/10/20 accuracy.
    batch_end_callback : callable or list of callables, optional
        Extra per-batch callbacks. A ``Speedometer`` reporting every 50 batches
        is always appended; a caller-supplied list is not modified.

    Raises
    ------
    ValueError
        If ``args.load_epoch`` is set but ``args.model_prefix`` is None.
    """
    kv = mx.kvstore.create(args.kv_store)
    _setup_logging(args, kv.rank)

    # Each worker reads/writes its own checkpoints: "<prefix>-<rank>".
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % kv.rank

    # Extra keyword arguments for Module.fit (only names it actually accepts).
    fit_args = {}
    if args.load_epoch is not None:
        if model_prefix is None:
            raise ValueError('model_prefix is required when load_epoch is set')
        # load_checkpoint returns (symbol, arg_params, aux_params); the stored
        # symbol is ignored because `network` is trained instead.
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.load_epoch)
        fit_args = {'arg_params': arg_params,
                    'aux_params': aux_params,
                    'begin_epoch': args.load_epoch}

    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = (None if save_model_prefix is None
                  else mx.callback.do_checkpoint(save_model_prefix))

    (train, val) = data_loader(args, kv)

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # Batches (optimizer updates) per epoch on this worker.
    epoch_size = args.num_examples / args.batch_size
    if args.kv_store == 'dist_sync':
        epoch_size /= kv.num_workers

    # lr_scheduler / clip_gradient are Optimizer options, not Module.fit
    # arguments, so they must travel through optimizer_params.
    optimizer_params = {'learning_rate': args.lr, 'momentum': 0.9, 'wd': 0.00001}
    if ('lr_factor' in args and args.lr_factor is not None
            and args.lr_factor < 1):
        optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(epoch_size * args.lr_factor_epoch), 1),
            factor=args.lr_factor)
        if args.load_epoch is not None:
            # Continue the decay schedule from the resumed epoch instead of
            # restarting it at update 0.
            optimizer_params['begin_num_update'] = int(epoch_size * args.load_epoch)
    if 'clip_gradient' in args and args.clip_gradient is not None:
        optimizer_params['clip_gradient'] = args.clip_gradient

    # Disable the kvstore for single-device local training.
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    mod = mx.mod.Module(network, context=devs)

    if eval_metrics is None:
        # top_k_accuracy only allows top_k > 1; plain accuracy covers top-1.
        eval_metrics = ['accuracy'] + [
            mx.metric.create('top_k_accuracy', top_k=top_k)
            for top_k in (5, 10, 20)]

    if batch_end_callback is None:
        callbacks = []
    elif isinstance(batch_end_callback, list):
        callbacks = list(batch_end_callback)  # copy: don't mutate caller's list
    else:
        callbacks = [batch_end_callback]
    callbacks.append(mx.callback.Speedometer(args.batch_size, 50))

    mod.fit(train_data=train, eval_metric=eval_metrics, eval_data=val,
            optimizer='sgd', optimizer_params=optimizer_params,
            num_epoch=args.num_epochs, batch_end_callback=callbacks,
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            kvstore=kv, epoch_end_callback=checkpoint, **fit_args)
# train_model.py
# Reposted from blog.csdn.net/tony2278/article/details/87866166