介绍
想知道mxnet在训练过程或者验证过程中,如何通过iterator提供数据
几个问题:
- 如何构造iterator
- 从iterator中获取数据时,data_batch = next(iterator),输入输出是什么
分析mxnet自带的mx.io.NDArrayIter,看如何把一个NDArray转化为一个可以用于module.fit() 的 iterator
用于测试的代码,使用一个MLP学习mnist
'''
Loading Data
'''
import mxnet as mx
from collections import OrderedDict
from mxnet.ndarray import array
# get_mnist() downloads MNIST and returns a dict of numpy arrays (shapes below).
mnist = mx.test_utils.get_mnist()# dict
#'train_data' ndarray ,shape<class 'tuple'> (60000,1,28,28)
#'train_label' ndarray ,shape<class 'tuple'> (60000,)
#'test_data' ndarray ,shape<class 'tuple'> (10000,1,28,28)
#'test_label' ndarray ,shape<class 'tuple'> (10000,)
# Fix the seed
mx.random.seed(42)
# Set the compute context, GPU is available otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
batch_size = 100
# Wrap the in-memory arrays in NDArrayIter; only the training split is shuffled.
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
'''
Training
'''
'''
这里的名字'data'不能改,对应于mx.io.NDArrayIter的defaltname参数就是'data',往后看就明白了
也可以改着看看bug信息
'''
# (EN) The symbol name 'data' must not be changed: it has to match NDArrayIter's
# default data_name parameter, which is 'data' (see __init__ below). Rename it
# to observe the resulting mismatch error.
data = mx.sym.var('data')
# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)
# The first fully-connected layer and the corresponding activation function
fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
# The second fully-connected layer and the corresponding activation function
fc2 = mx.sym.FullyConnected(data=act1, num_hidden = 64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# MNIST has 10 classes
fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
# NOTE(review): name='softmax' presumably yields the label symbol
# 'softmax_label', matching NDArrayIter's default label_name — confirm.
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
import logging
logging.getLogger().setLevel(logging.DEBUG) # logging to stdout
# create a trainable module on compute context
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
mlp_model.fit(train_iter, # train data
              eval_data=val_iter, # validation data
              optimizer='sgd', # use SGD to train
              optimizer_params={'learning_rate':0.1}, # use fixed learning rate
              eval_metric='acc', # report accuracy during training
              batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
              num_epoch=10) # train for at most 10 dataset passes
看 mx.io.NDArrayIter.__init__()
def __init__(self, data, label=None, batch_size=1, shuffle=False,
             last_batch_handle='pad', data_name='data',
             label_name='softmax_label'):
    # Normalize (data, label) into lists of (name, array) pairs, optionally
    # shuffle them, then set up the batching bookkeeping (idx, cursor, counts).
    super(NDArrayIter, self).__init__(batch_size)
    '''统一输入的格式为list(tuple(key,val),tuple(key,val)……)'''
    '''划重点!!!这个key和executor里的symbol对应的'''
    # (EN) _init_data canonicalizes the input into list[(key, array)];
    # the keys must match the input symbol names used by the executor.
    self.data = _init_data(data, allow_empty=False, default_name=data_name)
    self.label = _init_data(label, allow_empty=True, default_name=label_name)
    # Sparse CSR inputs are only supported when last_batch_handle='discard'.
    if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
            (last_batch_handle != 'discard')):
        raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                  " with `last_batch_handle` set to `discard`.")
    '''
    self.idx 是一个arrange生成的list,list大小为self.data[0][1].shape[0]
    从shape[0]可以看出输入的ndarray格式必须是num_data*data_instance
    shuffle data 打乱数据
    '''
    # (EN) self.idx is an arange over axis 0 of the first data array, so the
    # input must be laid out as (num_examples, ...). With shuffle=True, data
    # and labels are permuted by the same random index order.
    if shuffle:
        tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
        self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
        self.data = _shuffle(self.data, self.idx)
        self.label = _shuffle(self.label, self.idx)
    else:
        self.idx = np.arange(self.data[0][1].shape[0])
    '''如果选择'discard',则把输入的数据裁剪为batch_size的整数倍'''
    # (EN) 'discard': trim idx down to a multiple of batch_size.
    # batching
    if last_batch_handle == 'discard':
        new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
        self.idx = self.idx[:new_n]
    '''把data和label关联成一个list=[data_0_ndarray,data_1_ndarray,……,label_0_ndarray,]'''
    # (EN) Flatten data arrays followed by label arrays into one list.
    self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
    '''输入输出一共有多少个ndarray mnist-2 '''
    self.num_source = len(self.data_list)  # total array count (mnist: 2)
    '''数据量mnist-60000 for train'''
    self.num_data = self.idx.shape[0]  # number of examples (mnist train: 60000)
    '''batch_size 不能大于总数据量'''
    assert self.num_data >= batch_size, \
        "batch_size needs to be smaller than data size."
    '''定义一个光标'''
    # Cursor starts one batch before 0 so the first iter_next() (which does
    # cursor += batch_size) lands on index 0.
    self.cursor = -batch_size
    self.batch_size = batch_size
    '''最后不够一个batch_size时的处理方法'pad' or 'discard' '''
    # How to handle a final partial batch: 'pad' or 'discard'.
    self.last_batch_handle = last_batch_handle
_init_data 的作用是把输入的data统一格式,因为这个初始化data输入支持多种类型numpy.ndarray/mxnet.ndarray/h5py.Dataset输入,可以是单个的这些类型数据,也可能是他们的list输入
- 比如输入一个mxnet.ndarray
- 输出的格式为list[tuple(str{'data'},mxnet.ndarray)](单个输入时 key 就是 default_name,如'data')
- 如果输入一个list:[mxnet.ndarray,mxnet.ndarray]
- 输出格式为list[tuple(str{'_0_data'},mxnet.ndarray),tuple(str{'_1_data'},mxnet.ndarray)]
def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form.

    Accepts a single array (numpy.ndarray, NDArray, or h5py.Dataset), a list
    of such arrays, or a dict mapping names to arrays, and returns a sorted
    list of (name, NDArray-or-h5py.Dataset) tuples. A single array is keyed
    by ``default_name``; a list of N arrays is keyed '_0_<name>' .. '_%d_<name>'.

    Raises TypeError if the input (or any dict value) is not a supported type.
    """
    assert (data is not None) or allow_empty
    if data is None:
        data = []
    # A single bare array becomes a one-element list.
    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    # Turn the list into an OrderedDict with deterministic key names.
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            # Single input: the key is exactly default_name, so it matches the
            # network's input symbol (e.g. 'data' / 'softmax_label').
            # (Fix: this assignment was swallowed into a comment in the
            # transcription, which made single-array input raise TypeError.)
            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
        else:
            # Multiple inputs: keys are '_%d_%s' % (i, default_name),
            # e.g. '_0_data', '_1_data', ...
            data = OrderedDict(  # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                        "a list of them or dict with them as values")
    # Coerce non-NDArray values (e.g. numpy arrays, lists) to mxnet NDArray;
    # h5py datasets are kept as-is so they can be sliced lazily.
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                data[k] = array(v)
            except Exception:
                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")
    # Return list[tuple(key, val)], sorted by key for a stable order.
    return list(sorted(data.items()))
初始化完之后,在mod.fit()当中使用到该iter的几个成员API:
- mx.io.NDArrayIter.provide_data # property,用于module 或者 executor初始化
- mx.io.NDArrayIter.provide_label # property,用于module 或者 executor初始化,和上面同步
- iter(mx.io.NDArrayIter) #得到一个迭代器,用于每次训练获取数据batch
- mx.io.NDArrayIter.reset() #训练时每个epoch 结束时 reset一次
@property
def provide_data(self):
    """The name and shape of data provided by this iterator.

    Returns a list of ``DataDesc`` (a namedtuple of name/shape/dtype), one per
    entry in ``self.data``, which is a list of ``(name, array)`` tuples — the
    name matching a network input symbol, the array shaped
    ``(num_examples, ...)``. Each descriptor replaces the first axis with
    ``batch_size``, i.e. shape ``(batch_size,) + array.shape[1:]``.

    (Fix: the explanation originally sat as a bare string inside the list
    literal, immediately before ``DataDesc(...)`` — a SyntaxError.)
    """
    return [
        DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
        for k, v in self.data
    ]
Module / Executor 初始化时用到mx.io.NDArrayIter
mxnet.Module.fit().Module.bind()使用到了 mx.io.NDArrayIter.provide_data() 和mx.io.NDArrayIter.provide_label()
self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
for_training=True, force_rebind=force_rebind)
'''解析数据描述子'''
self._data_shapes, self._label_shapes = _parse_data_desc(
    self.data_names,    # 本例中为 ['data']
    self.label_names,   # 本例中为 ['softmax_label'](NDArrayIter 的默认 label_name)
    data_shapes=train_data.provide_data,
    label_shapes=train_data.provide_label)
'''这个数据解析器干两件事:
把data attributes 转成DataDesc 格式
检查输入的'数据属性表中的名字'是否和'网络输入symbol的名字'匹配
'''
def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
    """Normalize data/label attributes to DataDesc and validate their names.

    Each entry that is not already a DataDesc is expanded into one; the
    resulting descriptors are checked against the expected symbol names.
    Returns the (possibly converted) ``(data_shapes, label_shapes)`` pair.
    """
    def to_descs(shapes):
        # Promote plain tuples to DataDesc, pass existing DataDesc through.
        return [s if isinstance(s, DataDesc) else DataDesc(*s) for s in shapes]

    data_shapes = to_descs(data_shapes)
    _check_names_match(data_names, data_shapes, 'data', True)
    if label_shapes is None:
        # No labels supplied: still verify the (empty) name list.
        _check_names_match(label_names, [], 'label', False)
    else:
        label_shapes = to_descs(label_shapes)
        _check_names_match(label_names, label_shapes, 'label', False)
    return data_shapes, label_shapes
'''然后得到的数据描述子用于Executor group类的初始化'''
self._exec_group = DataParallelExecutorGroup(self._symbol, self._context,
self._work_load_list, self._data_shapes,
self._label_shapes, self._param_names,
for_training, inputs_need_grad,
shared_group, logger=self.logger,
fixed_param_names=self._fixed_param_names,
grad_req=grad_req, group2ctxs=self._group2ctxs,
state_names=self._state_names)
'''初始化里在最后一行代码用到这些数据描述子'''#shared_group =None
self.bind_exec(data_shapes, label_shapes, shared_group)
'''继续搜,上面的处理是为了多GPU并行处理而设定的,在这里把每一个GPU负责的batch,分给每一个Executor,并把这些Executor收集起来'''
self.execs.append(self._bind_ith_exec(i, data_shapes_i, label_shapes_i,
shared_group))
'''继续 这里是Module.bind()的终点,通过simple_bind得到一个Executor'''
def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
    """Internal utility function to bind the i-th executor.

    Folds the data (and optional label) descriptors into a single
    name->shape map and name->dtype map, then calls the Symbol's
    ``simple_bind`` on this executor's context, optionally sharing memory
    with the matching executor of ``shared_group``.
    """
    dev = self.contexts[i]
    prev_exec = shared_group.execs[i] if shared_group is not None else None
    arg_buffer = self.shared_data_arrays[i]

    # One dict for shapes, one for dtypes, covering data then labels.
    shape_map = dict(data_shapes)
    dtype_map = {d.name: d.dtype for d in data_shapes}
    if label_shapes is not None:
        shape_map.update(dict(label_shapes))
        dtype_map.update({d.name: d.dtype for d in label_shapes})

    executor = self.symbol.simple_bind(ctx=dev, grad_req=self.grad_req,
                                       type_dict=dtype_map,
                                       shared_arg_names=self.param_names,
                                       shared_exec=prev_exec,
                                       group2ctx=self.group2ctxs[i],
                                       shared_buffer=arg_buffer,
                                       **shape_map)
    # Accumulate allocated bytes; assumes the 3rd-from-last line of
    # debug_str() carries the byte count as its 2nd token — TODO confirm.
    self._total_exec_bytes += int(executor.debug_str().split('\n')[-3].split()[1])
    return executor
Module / Executor 训练时用到mx.io.NDArrayIter
这里重点关注每个epoch中,fit()如何调用mx.io.NDArrayIter提供数据用于训练的
# Excerpt from Module.fit(): how each epoch pulls batches from the iterator.
for epoch in range(begin_epoch, num_epoch):
    tic = time.time()
    eval_metric.reset()
    nbatch = 0
    '''iter()是用于把一个可迭代的非iter对象变成迭代器,如list
    本blog中对train_data没有任何改变
    因为传入的train_data本身就是一个迭代器
    '''
    # (EN) iter() would wrap a plain iterable (e.g. a list); NDArrayIter is
    # already an iterator, so iter(train_data) returns it unchanged.
    data_iter = iter(train_data)
    '''
    print(type(data_iter ),type(train_iter))
    <class 'mxnet.io.NDArrayIter'> <class 'mxnet.io.NDArrayIter'>
    '''
    '''初始化一个标识,用于检测iter是否到尾了'''
    end_of_batch = False  # (EN) flag: set once the iterator is exhausted
    '''获得一个batch大小的data,<class 'mxnet.io.DataBatch'>
    DataBatch 下有这两个重要的成员变量
    data <class 'list<class mx.NDArray>'>
    label <class 'list<class mx.NDArray>'>
    '''
    # (EN) next() returns a DataBatch whose .data / .label are lists of NDArray.
    next_data_batch = next(data_iter)
    '''
    def next(self):
    #调用iter_next()判断迭代器是否到尾了
    if self.iter_next():
    ''''''
    return DataBatch(data=self.getdata(), label=self.getlabel(), \
    pad=self.getpad(), index=None)
    else:
    raise StopIteration
    #self.cursor 初始化为-self.batch_size
    #这是为了在取数据时用到的cursor指向需要取的数据,如第一次next().getdata()时cursor = 0
    #每次调用next()则自加self.batch_size
    #如果记录值小于data的长度,则返回真,否则返回假
    def iter_next(self):
    self.cursor += self.batch_size
    return self.cursor < self.num_data
    '''
    while not end_of_batch:
        data_batch = next_data_batch
        if monitor is not None:
            monitor.tic()
        self.forward_backward(data_batch)
        self.update()
        '''处理完参数更新后,获取新的batch'''
        # (EN) After the parameter update, prefetch the next batch.
        try:
            # pre fetch next batch
            next_data_batch = next(data_iter)
            self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
        # ('''except 和 raise对应''' — kept as a comment here, since a statement
        # cannot sit between the try body and except:) this except pairs with
        # the `raise StopIteration` inside NDArrayIter.next().
        except StopIteration:
            end_of_batch = True
重点看看iter.next().getdata()
def _getdata(self, data_source):
    """Load data from underlying arrays, internal use only.

    data_source is a list of (name, array) tuples (self.data or self.label);
    returns a list with one batch-sized slice per source array.
    """
    assert(self.cursor < self.num_data), "DataIter needs reset."
    '''判断iter中剩下的数据是否够一个batch'''
    # (EN) Is there a full batch left before the end of the data?
    if self.cursor + self.batch_size <= self.num_data:
        # ('''data_source = self.data ...''' — kept as a comment, since a bare
        # string inside the list literal would be a SyntaxError:)
        # data_source is list[tuple(str_name, ndarray_data)]; take the
        # [cursor : cursor + batch_size] slice of every member's array.
        return [
            # np.ndarray or NDArray case
            x[1][self.cursor:self.cursor + self.batch_size]
            if isinstance(x[1], (np.ndarray, NDArray)) else
            # h5py (only supports indices in increasing order)
            array(x[1][sorted(self.idx[
                self.cursor:self.cursor + self.batch_size])][[
                    list(self.idx[self.cursor:
                                  self.cursor + self.batch_size]).index(i)
                    for i in sorted(self.idx[
                        self.cursor:self.cursor + self.batch_size])
                ]]) for x in data_source
        ]
    else:
        # Partial final batch: pad by wrapping around to the front.
        pad = self.batch_size - self.num_data + self.cursor
        return [
            # np.ndarray or NDArray case
            concatenate([x[1][self.cursor:], x[1][:pad]])
            if isinstance(x[1], (np.ndarray, NDArray)) else
            # h5py (only supports indices in increasing order)
            concatenate([
                array(x[1][sorted(self.idx[self.cursor:])][[
                    list(self.idx[self.cursor:]).index(i)
                    for i in sorted(self.idx[self.cursor:])
                ]]),
                array(x[1][sorted(self.idx[:pad])][[
                    list(self.idx[:pad]).index(i)
                    for i in sorted(self.idx[:pad])
                ]])
            ]) for x in data_source
        ]
总结
- DataIter初始化时给定数据和数据名称要和网络输入的symbol名称对应
- DataIter用于网络初始化和网络训练测试