

class SpatialTransformer(nn.Module):
    """N-D Spatial Transformer.

    Warps ``src`` with a dense displacement field ``flow``: the flow is added
    to an identity sampling grid, the grid is rescaled to the [-1, 1]
    coordinate range, and ``src`` is resampled with
    ``torch.nn.functional.grid_sample``. Only 2-D and 3-D fields are
    supported, because ``grid_sample`` only accepts 4-D and 5-D inputs.

    Args:
        size: spatial size of the field, e.g. ``[128, 128]`` (2-D) or
            ``[D, H, W]`` (3-D). Must match ``flow.shape[2:]``.
        mode: interpolation mode forwarded to ``grid_sample``
            (e.g. ``'bilinear'`` or ``'nearest'``).
    """

    def __init__(self, size, mode='bilinear'):    # e.g. size = [128, 128]
        # nn.Module.__init__ must run before register_buffer(), which raises
        # AttributeError when the module's buffer registry does not exist yet.
        super().__init__()
        self.mode = mode

        # Identity sampling grid of shape (1, ndim, *size): channel k holds the
        # integer index along spatial axis k (meshgrid's default 'ij' indexing).
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        # arange yields int64; flow and grid_sample coordinates are float32.
        grid = grid.type(torch.FloatTensor)
        # A buffer (not a Parameter): follows .to()/.cuda(), is not trained.
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """Warp ``src`` by the displacement field ``flow``.

        Args:
            src: tensor of shape (B, C, *size) to be resampled.
            flow: displacement field of shape (B, ndim, *size) in voxel units;
                channel k is the displacement along spatial axis k (the same
                axis order as the identity grid built in ``__init__``).

        Returns:
            ``src`` resampled at ``grid + flow``; its spatial size is that of
            ``flow``.

        Raises:
            ValueError: if ``flow`` is not a 2-D or 3-D field.
        """
        shape = flow.shape[2:]
        ndim = len(shape)
        if ndim not in (2, 3):
            raise ValueError(
                f"SpatialTransformer supports 2-D or 3-D fields, got {ndim}-D"
            )

        new_locs = self.grid + flow

        # Rescale each axis from [0, extent - 1] to [-1, 1], the coordinate
        # range grid_sample uses with align_corners=True.
        for axis, extent in enumerate(shape):
            new_locs[:, axis, ...] = 2 * (new_locs[:, axis, ...] / (extent - 1) - 0.5)

        # grid_sample expects the coordinates in the LAST dimension, i.e.
        # (B, *size, ndim) instead of (B, ndim, *size) ...
        new_locs = new_locs.permute(0, *range(2, ndim + 2), 1)
        # ... ordered (x, y[, z]) where x indexes the last (width) axis. Our
        # channels follow the tensor axis order (row, col[, ...]), so reverse
        # them: [1, 0] for 2-D, [2, 1, 0] for 3-D.
        new_locs = new_locs[..., list(range(ndim - 1, -1, -1))]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)



# Step-by-step walkthrough of SpatialTransformer.__init__ with size = [128, 128].
# `size` and `self` are the __init__ argument and receiver, so these lines are
# an excerpt of the method, not a standalone script. The comments under each
# statement show the value it produces.

# One index vector per spatial axis: [0, 1, ..., s - 1].
vectors = [torch.arange(0, s) for s in size]
#[tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
#         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
#         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
#         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
#         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
#         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
#         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
#         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
#        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
#        126, 127]), tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
#         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
#         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
#         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
#         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
#         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
#         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
#         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
#        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
#        126, 127])]


# Coordinate grids: as printed below, grids[0][i, j] == i (row index) and
# grids[1][i, j] == j (column index).
grids = torch.meshgrid(vectors)
#(tensor([[  0,   0,   0,  ...,   0,   0,   0],
#        [  1,   1,   1,  ...,   1,   1,   1],
#        [  2,   2,   2,  ...,   2,   2,   2],
#        ...,
#        [125, 125, 125,  ..., 125, 125, 125],
#        [126, 126, 126,  ..., 126, 126, 126],
#        [127, 127, 127,  ..., 127, 127, 127]]), tensor([[  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127],
#        ...,
#        [  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127],
#        [  0,   1,   2,  ..., 125, 126, 127]]))


# Stack the per-axis grids into one (ndim, H, W) coordinate tensor.
grid = torch.stack(grids)   #          torch.Size([2, 128, 128])
#tensor([[[  0,   0,   0,  ...,   0,   0,   0],
#         [  1,   1,   1,  ...,   1,   1,   1],
#         [  2,   2,   2,  ...,   2,   2,   2],
#         ...,
#         [125, 125, 125,  ..., 125, 125, 125],
#         [126, 126, 126,  ..., 126, 126, 126],
#         [127, 127, 127,  ..., 127, 127, 127]],
#        [[  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127],
#         ...,
#         [  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127],
#         [  0,   1,   2,  ..., 125, 126, 127]]])


# Add a leading batch axis so the grid can be added to a (B, 2, H, W) flow.
grid = torch.unsqueeze(grid, 0)    #torch.Size([1, 2, 128, 128])   B,C,H,W
#tensor([[[[  0,   0,   0,  ...,   0,   0,   0],
#          [  1,   1,   1,  ...,   1,   1,   1],
#          [  2,   2,   2,  ...,   2,   2,   2],
#          ...,
#          [125, 125, 125,  ..., 125, 125, 125],
#          [126, 126, 126,  ..., 126, 126, 126],
#          [127, 127, 127,  ..., 127, 127, 127]],

#         [[  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127],
#          ...,
#          [  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127],
#          [  0,   1,   2,  ..., 125, 126, 127]]]])


# arange produced int64 indices; convert to float32 so the float flow can be
# added and the result used as grid_sample coordinates.
grid = grid.type(torch.FloatTensor)


# Register as a buffer: it moves with the module across devices but is not a
# trainable parameter.
self.register_buffer('grid', grid)


# 前向传播 forward（Forward pass: walkthrough of SpatialTransformer.forward）


# Excerpt from SpatialTransformer.forward: `self` and `flow` are the method's
# receiver and argument, so these lines are not runnable on their own.

# Displaced sampling locations: the identity grid of shape (1, 2, H, W) plus
# the flow field (the batch axis of size 1 broadcasts against flow's batch).
new_locs = self.grid + flow



# Spatial extent of the field, e.g. torch.Size([128, 128]) for 2-D input.
shape = flow.shape[2:]



# Rescale each coordinate channel from [0, shape[i] - 1] to [-1, 1], the
# coordinate range grid_sample uses with align_corners=True.
for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
# Earlier commented-out attempt, kept for reference. NOTE(review): it is not
# equivalent to the loop above. It loops over 128 positions instead of over the
# coordinate channels, and PyTorch rejects an index containing two `...`.
#        for i in range(128):   # my images are (128, 128); I constrain every pixel column to [-1, 1]
#            new_locs[..., i, ...] = 2 * (new_locs[..., i, ...] / 127 - 0.5)

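# 上面的 for 循环把网格坐标标准化到 [-1, 1]，这是 grid_sample 重采样所需的坐标范围；
# 下面的 if 语句则把坐标通道移到最后一维并反转通道顺序，因为 grid_sample 要求最后一维按 (x, y)（即先宽后高）排列。
# (The loop above normalizes the grid to [-1, 1] for resampling; the if statement below moves the
#  coordinate channels to the last dimension and reverses them, because grid_sample expects (x, y) = (width, height) order.)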


 if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]


return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)




