MXNet：实现自己的 DataIter

版权声明: https://blog.csdn.net/qq_35606924/article/details/80033850

实现深度学习模型时，MXNet 自带的 DataIter 有时不能满足需求，这时可以继承 mx.io.DataIter，自己实现一个。


用 MXNet 编写自定义（custom）DataIter

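代码比较简单，直接上代码。下面的例子包装了一个 mx.image.ImageIter，并在 next() 中重新组织每个 batch 的标签；注意 provide_label 声明的标签应当与 next() 实际返回的标签一一对应。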

import mxnet as mx

class custom_iter(mx.io.DataIter):
    """Wrap another DataIter and re-emit its first label under chosen names.

    Data is forwarded unchanged from the wrapped iterator. For labels, the
    first label array of each wrapped batch is repeated once per entry in
    ``label_names``. ``provide_label`` is built from the same names, so the
    declared labels and the labels actually returned by ``next()`` always
    agree in number and order.

    Parameters
    ----------
    data_iter : mx.io.DataIter
        Iterator to wrap. Must expose ``batch_size``, ``provide_data``,
        ``provide_label``, ``reset``, ``hard_reset`` and ``next``.
    label_names : sequence of str, optional
        Names under which the label is exposed. Defaults to
        ``('softmax_label',)``. For a network with a second loss head, pass
        e.g. ``('softmax_label', 'other_loss_label')``; this reproduces the
        two-label batches this class used to emit unconditionally.
    """

    def __init__(self, data_iter, label_names=('softmax_label',)):
        super(custom_iter, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size
        # Store as a tuple so a one-shot iterable passed by the caller is not
        # exhausted after the first provide_label / next call.
        self.label_names = tuple(label_names)

    @property
    def provide_data(self):
        """Data descriptors, passed through from the wrapped iterator."""
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        """Return one ``(name, shape)`` pair per entry in ``label_names``.

        Every pair reuses the shape of the wrapped iterator's first label.
        """
        # Index [1] assumes each entry is a (name, shape)-like pair, as
        # mx.io.DataDesc is.
        label_shape = self.data_iter.provide_label[0][1]
        return [(name, label_shape) for name in self.label_names]

    def hard_reset(self):
        """Delegate to the wrapped iterator's ``hard_reset``."""
        self.data_iter.hard_reset()

    def reset(self):
        """Delegate to the wrapped iterator's ``reset``."""
        self.data_iter.reset()

    def next(self):
        """Return the next batch with its label repeated per ``label_names``.

        StopIteration raised by the wrapped iterator propagates unchanged.
        """
        batch = self.data_iter.next()
        label = batch.label[0]
        # Emit exactly as many labels as provide_label declares. Previously
        # two labels were always returned while only one was declared.
        return mx.io.DataBatch(data=batch.data,
                               label=[label] * len(self.label_names),
                               pad=batch.pad, index=batch.index)



import numpy as np
# PCA eigenvalues / eigenvectors of the RGB pixel distribution, consumed by
# mx.image.LightingAug in the training pipeline below.
# NOTE(review): these look like the ImageNet statistics (the same numbers as
# mx.image.CreateAugmenter's defaults) — recompute if the dataset's colour
# distribution differs.
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([
    [-0.5675, 0.7192, 0.4009],
    [-0.5808, -0.0045, -0.8140],
    [-0.5836, -0.6948, 0.4203],
])

# Side length (pixels) of the square network input, and the CHW data shape
# handed to mx.image.ImageIter.
shape_ = 112
shape = (3, shape_, shape_)

# NOTE: every image is force-resized to exactly shape_ x shape_ before
# cropping, so the crops below keep the whole image. To get real cropping,
# resize to something larger first, e.g. mx.image.ResizeAug(size=shape_ + 32).

# Evaluation pipeline: deterministic resize followed by a centre crop.
aug_list_test = [
    mx.image.ForceResizeAug(size=(shape_, shape_)),
    mx.image.CenterCropAug((shape_, shape_)),
]

# Training pipeline: resize, random crop and flip, then cast to float so the
# colour/lighting augmenters operate on floating-point pixels.
aug_list_train = [
    mx.image.ForceResizeAug(size=(shape_, shape_)),
    mx.image.RandomCropAug((shape_, shape_)),
    mx.image.HorizontalFlipAug(0.5),
    mx.image.CastAug(),
    mx.image.ColorJitterAug(0.0, 0.1, 0.1),   # brightness, contrast, saturation
    mx.image.HueJitterAug(0.5),
    mx.image.LightingAug(0.1, eigval, eigvec),  # PCA lighting noise
]

def get_iterator(batch_size, train_list='/you/path/train.lst',
                 val_list='/you/path/val.lst', path_root=''):
    """Return train and val iterators for training.

    Builds two ``mx.image.ImageIter`` instances from image-list files and
    wraps each in ``custom_iter``.

    Parameters
    ----------
    batch_size : int
        Batch size for both iterators.
    train_list : str, optional
        Path to the training ``.lst`` file. The default is the original
        placeholder path and must be replaced for real use.
    val_list : str, optional
        Path to the validation ``.lst`` file. This is also a placeholder by
        default.
    path_root : str, optional
        Passed to ``ImageIter`` as ``path_root`` for both iterators.
        Defaults to ``''``.

    Returns
    -------
    tuple
        ``(custom_iter(train_iter), custom_iter(val_iter))``. The training
        iterator shuffles and uses ``aug_list_train``. The validation
        iterator does not shuffle and uses ``aug_list_test``.
    """
    train_iter = mx.image.ImageIter(batch_size=batch_size,
                                    data_shape=shape,
                                    label_width=1,
                                    aug_list=aug_list_train,
                                    shuffle=True,
                                    path_root=path_root,
                                    path_imglist=train_list)
    val_iter = mx.image.ImageIter(batch_size=batch_size,
                                  data_shape=shape,
                                  label_width=1,
                                  shuffle=False,
                                  aug_list=aug_list_test,
                                  path_root=path_root,
                                  path_imglist=val_list)

    return (custom_iter(train_iter), custom_iter(val_iter))

猜你喜欢

转载自blog.csdn.net/qq_35606924/article/details/80033850