PyTorch: reducing GPU memory consumption, optimizing memory usage, and computing the size of a model's intermediate variables

1. The model is as follows:
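The module listing below was produced by enumerating every submodule of the network. model.modules() yields the model itself first (index 0) and then recurses into every child, which is why nested modules appear once inside their parent and again on their own line. A minimal sketch, assuming a built net object (see section 2 for build_net):

# Print every submodule with its index in model.modules(); these indices i
# are the ones the hard-coded ranges in modelsize() (section 3) refer to.
for i, m in enumerate(net.modules()):
    print("i is {0}, m is {1}".format(i, m))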

i is 0, m is RFBNet(
  (base): ModuleList(
    (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace)
    (19): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
  )
  (Norm): BasicRFB_a(
    (branch0): Sequential(
      (0): BasicConv(
        (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (1): BasicConv(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
    )
    (branch1): Sequential(
      (0): BasicConv(
        (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (1): BasicConv(
        (conv): Conv2d(48, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (2): BasicConv(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
    )
    (branch2): Sequential(
      (0): BasicConv(
        (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (1): BasicConv(
        (conv): Conv2d(48, 48, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (2): BasicConv(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv(
        (conv): Conv2d(192, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(24, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (1): BasicConv(
        (conv): Conv2d(24, 36, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
        (bn): BatchNorm2d(36, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (2): BasicConv(
        (conv): Conv2d(36, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (3): BasicConv(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
    )
    (ConvLinear): BasicConv(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
    (shortcut): BasicConv(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
    (relu): ReLU()
  )
  (loc): ModuleList(
    (0): Conv2d(192, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conf): ModuleList(
    (0): Conv2d(192, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (softmax): Softmax()
)


i is 1, m is ModuleList(
  (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace)
  (7): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): ReLU(inplace)
  (19): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace)
  (21): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace)
)


i is 2, m is Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 3, m is ReLU(inplace)
i is 4, m is Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 5, m is ReLU(inplace)
i is 6, m is MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
i is 7, m is Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 8, m is ReLU(inplace)
i is 9, m is Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 10, m is ReLU(inplace)
i is 11, m is MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
i is 12, m is Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 13, m is ReLU(inplace)
i is 14, m is Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 15, m is ReLU(inplace)
i is 16, m is Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 17, m is ReLU(inplace)
i is 18, m is MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
i is 19, m is Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 20, m is ReLU(inplace)
i is 21, m is Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 22, m is ReLU(inplace)
i is 23, m is Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 24, m is ReLU(inplace)


i is 25, m is BasicRFB_a(
  (branch0): Sequential(
    (0): BasicConv(
      (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (1): BasicConv(
      (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
  )
  (branch1): Sequential(
    (0): BasicConv(
      (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (1): BasicConv(
      (conv): Conv2d(48, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (2): BasicConv(
      (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
  )
  (branch2): Sequential(
    (0): BasicConv(
      (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (1): BasicConv(
      (conv): Conv2d(48, 48, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (2): BasicConv(
      (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
  )
  (branch3): Sequential(
    (0): BasicConv(
      (conv): Conv2d(192, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(24, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (1): BasicConv(
      (conv): Conv2d(24, 36, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
      (bn): BatchNorm2d(36, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (2): BasicConv(
      (conv): Conv2d(36, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (3): BasicConv(
      (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
      (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    )
  )
  (ConvLinear): BasicConv(
    (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
  (shortcut): BasicConv(
    (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
  (relu): ReLU()
)


i is 26, m is Sequential(
  (0): BasicConv(
    (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (1): BasicConv(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
)
i is 27, m is BasicConv(
  (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 28, m is Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
i is 29, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 30, m is ReLU(inplace)
i is 31, m is BasicConv(
  (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
i is 32, m is Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
i is 33, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)

i is 34, m is Sequential(
  (0): BasicConv(
    (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (1): BasicConv(
    (conv): Conv2d(48, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (2): BasicConv(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
)
i is 35, m is BasicConv(
  (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 36, m is Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
i is 37, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 38, m is ReLU(inplace)
i is 39, m is BasicConv(
  (conv): Conv2d(48, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 40, m is Conv2d(48, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
i is 41, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 42, m is ReLU(inplace)
i is 43, m is BasicConv(
  (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
i is 44, m is Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
i is 45, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)


i is 46, m is Sequential(
  (0): BasicConv(
    (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (1): BasicConv(
    (conv): Conv2d(48, 48, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (2): BasicConv(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
)
i is 47, m is BasicConv(
  (conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 48, m is Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
i is 49, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 50, m is ReLU(inplace)
i is 51, m is BasicConv(
  (conv): Conv2d(48, 48, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 52, m is Conv2d(48, 48, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
i is 53, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 54, m is ReLU(inplace)
i is 55, m is BasicConv(
  (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
i is 56, m is Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
i is 57, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 58, m is Sequential(
  (0): BasicConv(
    (conv): Conv2d(192, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(24, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (1): BasicConv(
    (conv): Conv2d(24, 36, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
    (bn): BatchNorm2d(36, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (2): BasicConv(
    (conv): Conv2d(36, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
  )
  (3): BasicConv(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
)
i is 59, m is BasicConv(
  (conv): Conv2d(192, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(24, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 60, m is Conv2d(192, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
i is 61, m is BatchNorm2d(24, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 62, m is ReLU(inplace)
i is 63, m is BasicConv(
  (conv): Conv2d(24, 36, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
  (bn): BatchNorm2d(36, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 64, m is Conv2d(24, 36, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)
i is 65, m is BatchNorm2d(36, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 66, m is ReLU(inplace)
i is 67, m is BasicConv(
  (conv): Conv2d(36, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
)
i is 68, m is Conv2d(36, 48, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)
i is 69, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 70, m is ReLU(inplace)
i is 71, m is BasicConv(
  (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
  (bn): BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
i is 72, m is Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
i is 73, m is BatchNorm2d(48, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)

i is 74, m is BasicConv(
  (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
i is 75, m is Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
i is 76, m is BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
i is 77, m is BasicConv(
  (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
i is 78, m is Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
i is 79, m is BatchNorm2d(192, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)

i is 80, m is ReLU()
i is 81, m is ModuleList(
  (0): Conv2d(192, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
i is 82, m is Conv2d(192, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 83, m is ModuleList(
  (0): Conv2d(192, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
i is 84, m is Conv2d(192, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
i is 85, m is Softmax()

2. Model code

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from cell_detection.layers import *
import torchvision.transforms as transforms
import torchvision.models as models
import torch.backends.cudnn as cudnn
import os
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None
    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
class BasicRFB(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, scale = 0.1, visual = 1):
        super(BasicRFB, self).__init__()
        self.scale = scale
        self.out_channels = out_planes
        inter_planes = in_planes // 8
        self.branch0 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
                #BasicConv(inter_planes, (inter_planes//2)*3, kernel_size=(1,3), stride=1, padding=(0,1)),
                BasicConv(inter_planes, 2*inter_planes, kernel_size=(3,3), stride=stride, padding=(1,1)),
                BasicConv(2*inter_planes, 2*inter_planes, kernel_size=3, stride=1, padding=visual+1, dilation=visual+1, relu=False)
                )
        self.branch1 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
                BasicConv(inter_planes, (inter_planes//2)*3, kernel_size=3, stride=1, padding=1),
                BasicConv((inter_planes//2)*3, 2*inter_planes, kernel_size=3, stride=stride, padding=1),
                BasicConv(2*inter_planes, 2*inter_planes, kernel_size=3, stride=1, padding=2*visual+1, dilation=2*visual+1, relu=False)
                )
        self.ConvLinear = BasicConv(4*inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
        self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
        self.relu = nn.ReLU(inplace=False)
    def forward(self,x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0,x1),1)
        out = self.ConvLinear(out)
        short = self.shortcut(x)
        out = out*self.scale + short  # the two paths (out and short) are computed separately, then linearly combined
        out = self.relu(out)
        return out
class BasicRFB_a(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, scale = 0.1):
        super(BasicRFB_a, self).__init__()
        self.scale = scale
        self.out_channels = out_planes
        inter_planes = in_planes //4
        self.branch0 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
                BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1,relu=False)
                )
        self.branch1 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
                BasicConv(inter_planes, inter_planes, kernel_size=(3,1), stride=1, padding=(1,0)),
                BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
                )
        self.branch2 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
                BasicConv(inter_planes, inter_planes, kernel_size=(1,3), stride=stride, padding=(0,1)),
                BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
                )
        '''
        self.branch3 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
                BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1),
                BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
                )
        '''
        self.branch3 = nn.Sequential(
                BasicConv(in_planes, inter_planes//2, kernel_size=1, stride=1),
                BasicConv(inter_planes//2, (inter_planes//4)*3, kernel_size=(1,3), stride=1, padding=(0,1)),
                BasicConv((inter_planes//4)*3, inter_planes, kernel_size=(3,1), stride=stride, padding=(1,0)),
                BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=5, dilation=5, relu=False)
                )
        self.ConvLinear = BasicConv(4*inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
        self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
        self.relu = nn.ReLU(inplace=False)
    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
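        # each branch outputs inter_planes = in_planes//4 channels, so the
        # concatenation below restores 4*inter_planes channels for ConvLinear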
        out = torch.cat((x0,x1,x2,x3),1)
        out = self.ConvLinear(out)
        short = self.shortcut(x)
        out = out*self.scale + short
        out = self.relu(out)
        return out
class RFBNet(nn.Module):
    """RFB Net for object detection
    The network is based on the SSD architecture.
    Each multibox layer branches into
        1) conv2d for class conf scores
        2) conv2d for localization predictions
        3) associated priorbox layer to produce default bounding
           boxes specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1711.07767.pdf for more details on RFB Net.
    Args:
        phase: (string) Can be "test" or "train"
        base: VGG16 layers for input, size of either 300 or 512
        extras: extra layers that feed to multibox loc and conf layers
        head: "multibox head" consists of loc and conf conv layers
    """
    def __init__(self, phase, size, base, extras, head, num_classes):
        super(RFBNet, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.size = size
        if size == 300:
            self.indicator = 3
        elif size == 512:
            self.indicator = 5
        else:
            print("Error: Sorry only SSD300 and SSD512 are supported!")
            return
        # vgg network
        self.base = nn.ModuleList(base)  # base, extras and head correspond to the three values returned by multibox()
        # conv_4
        self.Norm = BasicRFB_a(192, 192, stride = 1, scale=1.0)
        # self.extras = nn.ModuleList(extras[0:2])  # only the conv4_3 feature map is used
        self.loc = nn.ModuleList(head[0][0:1])
        self.conf = nn.ModuleList(head[1][0:1])
        if self.phase == 'test':
            self.softmax = nn.Softmax()
    def forward(self, x):
        """Applies network layers and ops on input image(s) x.
        Args:
            x: input image or batch of images. Shape: [batch,1,size,size] (this variant takes single-channel input).
        Return:
            Depending on phase:
            test:
                list of concat outputs from:
                    1: softmax layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*4]
                    3: priorbox layers, Shape: [2,num_priors*4]
        """
        sources = list()
        loc = list()
        conf = list()
        # apply vgg up to conv4_3 relu
        for k in range(23):
            x = self.base[k](x)
        s = self.Norm(x)  # the BasicRFB_a branch
        sources.append(s)
        # apply multibox head to source layers
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())  # permute moves the channel dimension to the end
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())
        #print([o.size() for o in loc])
        feature_map_size = [(o.size()[1], o.size()[2]) for o in conf]
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)  # flatten each feature map into a vector
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        if self.phase == "test":
            output = (
                loc.view(loc.size(0), -1, 4),                   # loc preds
                self.softmax(conf.view(-1, self.num_classes)),  # conf preds
            )
        else:
            output = (
                loc.view(loc.size(0), -1, 4),
                conf.view(conf.size(0), -1, self.num_classes),
            )
        return output, feature_map_size
    def load_weights(self, base_file):
        other, ext = os.path.splitext(base_file)
        if ext == '.pkl' or ext == '.pth':
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')
# This function is derived from torchvision VGG make_layers()
# https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
def vgg(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]  # 'C' would normally use ceil_mode=True when pooling; here it is treated the same as 'M'
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    # from here on this differs from the standard torchvision VGG
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    conv6 = nn.Conv2d(192, 384, kernel_size=3, padding=6, dilation=6)  # dilated convolution with dilation=6
    conv7 = nn.Conv2d(384, 384, kernel_size=1)
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers
base = {
    '512': [24, 24, 'M', 48, 48, 'M', 96, 96, 96, 'C', 192, 192, 192, 'M',
            192, 192, 192],
}
def add_extras(size, cfg, i, batch_norm=False):
    # Extra layers added to VGG for feature scaling
    layers = []
    in_channels = i
    flag = False
    for k, v in enumerate(cfg):
        if in_channels != 'S':
            if v == 'S':
                if in_channels == 96 and size == 512:
                    layers += [BasicRFB(in_channels, cfg[k+1], stride=2, scale = 1.0, visual=1)]
                else:
                    layers += [BasicRFB(in_channels, cfg[k+1], stride=2, scale = 1.0, visual=2)]
            else:
                layers += [BasicRFB(in_channels, v, scale = 1.0, visual=2)]
        in_channels = v
    if size == 512:
        layers += [BasicConv(96,48,kernel_size=1,stride=1)]
        layers += [BasicConv(48,96,kernel_size=4,stride=1,padding=1)]
    elif size ==300:
        layers += [BasicConv(256,128,kernel_size=1,stride=1)]
        layers += [BasicConv(128,256,kernel_size=3,stride=1)]
        layers += [BasicConv(256,128,kernel_size=1,stride=1)]
        layers += [BasicConv(128,256,kernel_size=3,stride=1)]
    else:
        print("Error: Sorry only RFBNet300 and RFBNet512 are supported!")
        return
    return layers
extras = {
    '512': [384, 'S', 192, 'S', 96, 'S', 96,'S',96],
}
def multibox(size, vgg, extra_layers, cfg, num_classes):
    """
    主要是用来生成loc_layers和conf_layers
    :param size: image size
    :param vgg: 传入的vgg网络,list类型,包含各个层
    :param extra_layers: list类型
    :param cfg: [6,6,6,6,4,4]list, mbox中定义的
    :param num_classes:
    :return:
    """
    loc_layers = []
    conf_layers = []
    vgg_source = [-2]
    for k, v in enumerate(vgg_source):
        if k == 0:
            loc_layers += [nn.Conv2d(192,
                                 cfg[k] * 4, kernel_size=3, padding=1)]
            conf_layers +=[nn.Conv2d(192,
                                 cfg[k] * num_classes, kernel_size=3, padding=1)]
        else:
            loc_layers += [nn.Conv2d(vgg[v].out_channels,
                                 cfg[k] * 4, kernel_size=3, padding=1)]
            conf_layers += [nn.Conv2d(vgg[v].out_channels,
                        cfg[k] * num_classes, kernel_size=3, padding=1)]
    i = 1
    indicator = 0
    if size == 300:
        indicator = 3
    elif size == 512:
        indicator = 5
    else:
        print("Error: Sorry only RFBNet300 and RFBNet512 are supported!")
        return
    for k, v in enumerate(extra_layers):
        if k < indicator or k%2== 0:
            loc_layers += [nn.Conv2d(v.out_channels, cfg[i]
                                 * 4, kernel_size=3, padding=1)]
            conf_layers += [nn.Conv2d(v.out_channels, cfg[i]
                                  * num_classes, kernel_size=3, padding=1)]
            i +=1
    return vgg, extra_layers, (loc_layers, conf_layers)
mbox = {
    '512': [5, 5, 5, 5, 5, 4, 4], # number of boxes per feature map location
}
def build_net(phase, size=300, num_classes=21):
    if phase != "test" and phase != "train":
        print("Error: Phase not recognized")
        return
    if size != 300 and size != 512:
        print("Error: Sorry only RFBNet300 and RFBNet512 are supported!")
        return
    return RFBNet(phase, size, *multibox(size, vgg(base[str(size)], 1),
                                add_extras(size, extras[str(size)], 384),
                                mbox[str(size)], num_classes), num_classes)
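
As a usage sketch: the loc/conf heads printed in section 1 have 20 and 25 output channels, which matches mbox['512'][0] = 5 boxes per location with 5*4 = 20 localization offsets and 5*num_classes = 25 confidence scores, i.e. num_classes = 5. The exact arguments the author used are an assumption:

# Hypothetical build; num_classes=5 is inferred from the printed head shapes.
net = build_net("train", size=512, num_classes=5)
print(net)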

3. Computing the GPU memory occupied by the model parameters and by the intermediate variables
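
The function below makes two measurements. Parameter memory: every float32 value takes type_size = 4 bytes, so a single Conv2d(192, 192, kernel_size=3) with bias, for example, holds 192*192*3*3 + 192 = 331,968 parameters, i.e. 331968 * 4 / 1000 / 1000 ≈ 1.33 MB. Intermediate-variable memory: the modules are replayed one at a time on a dummy input and every output size is accumulated; inplace ReLUs are skipped because they allocate no new tensor, and wrapper modules (Sequential, BasicConv) are skipped so that each conv/bn output is counted exactly once. Because the BasicRFB_a block is branchy rather than a flat chain, the index ranges below are hard-coded to follow the module listing from section 1.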

import numpy as np
import torch
import torch.nn as nn

def modelsize(model, input, type_size=4):
    # Total parameter count: the product of each parameter tensor's shape,
    # summed over all parameters.
    para = sum(np.prod(p.size()) for p in model.parameters())
    print('model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))
    input_ = input.clone()
    input_.requires_grad_(requires_grad=False)
    print("model.modules is: ", model.modules())
    print("model.modules length is: ", len(list(model.modules())))
    mods = list(model.modules())
    out_sizes = []
    print("***************************************")
    for i in range(0, len(mods)):
        m = mods[i]
        #print("i is {0}, m is {1}".format(i, m))
    print("***************************************")
    for i in range(2, 24):
        m = mods[i]
        #print("i is {0}, m is {1}".format(i, m))
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out = m(input_)
        out_sizes.append(np.array(out.size()))
        input_ = out
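    # The four BasicRFB_a branches all read the same conv4_3 feature map,
    # so each branch restarts from a clone of input_. The indices skipped
    # with continue (31, 39, 43, 51, 55, 63, 67, 71) are BasicConv wrapper
    # modules, which would otherwise re-run their children a second time.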
    input_0 = input_.clone()
    for i in range(28, 34):
        if i == 31:
            continue
        m = mods[i]
        #print("i is {0}, m is {1}".format(i, m))
        #print("input_shape size is: ", input_0.shape)
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out_0 = m(input_0)
        out_sizes.append(np.array(out_0.size()))
        input_0 = out_0
        
        
    input_1 = input_.clone()
    for i in range(36, 46):
        if i == 39 or i == 43 :
            continue
        m = mods[i]
        #print("i is {0}, m is {1}".format(i, m))
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out_1 = m(input_1)
        out_sizes.append(np.array(out_1.size()))
        input_1 = out_1
    input_2 = input_.clone()
    for i in range(48, 58):
        if i == 51 or i == 55 :
            continue
        m = mods[i]
        #print("i is {0}, m is {1}".format(i, m))
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out_2 = m(input_2)
        out_sizes.append(np.array(out_2.size()))
        input_2 = out_2
    input_3 = input_.clone()
    for i in range(60, 74):
        if i == 63 or i == 67 or i == 71:
            continue
        m = mods[i]
        print("i is {0}, m is {1}".format(i, m))
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out_3 = m(input_3)
        out_sizes.append(np.array(out_3.size()))
        input_3 = out_3
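    # Concatenate the four branch outputs (4 x 48 = 192 channels), mirroring
    # torch.cat in BasicRFB_a.forward, then replay ConvLinear (mods 75-76),
    # the shortcut convolution (mods 78-79) and the final non-inplace ReLU (80).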
    input_RFB = torch.cat((input_0, input_1, input_2, input_3),1)
    
    for i in range(75, 81):
        if i == 77:
            continue
        m = mods[i]
        print("i is {0}, m is {1}".format(i, m))
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out_RFB = m(input_RFB)
        out_sizes.append(np.array(out_RFB.size()))
        input_RFB = out_RFB
    total_nums = sum(np.prod(np.array(s)) for s in out_sizes)
    print('Model {} : intermediate variables: {:3f} M (without backward)'.format(model._get_name(), total_nums * type_size / 1000 / 1000))
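
A usage sketch, assuming the single-channel 512-input variant whose module listing appears in section 1 (the hard-coded index ranges inside modelsize only match that exact structure, and num_classes=5 is inferred from the printed head shapes):

# Hypothetical call: a batch of one 512x512 single-channel image.
net = build_net("train", size=512, num_classes=5)
dummy = torch.randn(1, 1, 512, 512)
modelsize(net, dummy, type_size=4)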

Reference blogs:

        https://blog.csdn.net/jacke121/article/details/81329679

        https://github.com/Oldpan/Pytorch-Memory-Utils

Reposted from blog.csdn.net/jishuqianjin/article/details/86502073