原
pytorch模型参数理解备忘
模型结构
def downsample_basic_block(x, planes, stride):
    """Type-A shortcut: subsample ``x`` spatially and zero-pad its channels.

    Used when the residual branch changes resolution/width but the shortcut
    must stay parameter-free: average-pool with kernel 1 and ``stride``
    (i.e. plain subsampling), then concatenate zero channels so the result
    has ``planes`` channels and can be added to the main branch.

    Args:
        x: 5-D input tensor (N, C, D, H, W) with C <= planes.
        planes: channel count of the returned tensor.
        stride: subsampling factor applied to D, H and W.

    Returns:
        Tensor of shape (N, planes, D', H', W') on the same device/dtype
        as ``x``.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    # new_zeros allocates on the same device/dtype as `out`, replacing the
    # legacy `isinstance(out.data, torch.cuda.FloatTensor)` check and the
    # deprecated Variable wrapper.
    zero_pads = out.new_zeros(out.size(0), planes - out.size(1),
                              out.size(2), out.size(3), out.size(4))
    # NOTE(review): the original wrapped torch.cat([out.data, ...]) in a
    # Variable, which detached the shortcut from autograd; gradients now
    # flow through the shortcut — confirm this is the desired behavior.
    return torch.cat([out, zero_pads], dim=1)
class ResNeXtBottleneck(nn.Module):
expansion = 2
def __init__(self, inplanes, planes, cardinality, stride=1,
downsample=None,conv3d_bias=True):
super(ResNeXtBottleneck, self).__init__()
mid_planes = cardinality * int(planes / 32)
self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=conv3d_bias)
self.bn1 = nn.BatchNorm3d(mid_planes)
self.conv2 = nn.Conv3d(
mid_planes,
mid_planes,
kernel_size=3,
stride=stride,
padding=1,
groups=cardinality,
bias=conv3d_bias)
self.bn2 = nn.BatchNorm3d(mid_planes)
self.conv3 = nn.Conv3d(
mid_planes, planes * self.expansion, kernel_size=1, bias=conv3d_bias)
self.bn3 = nn.BatchNorm3d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNeXt(nn.Module):
    """3-D ResNeXt video classifier.

    Consumes clips shaped (N, 3, sample_duration, sample_size, sample_size)
    and returns (N, num_classes) logits.

    Args:
        block: residual block class (e.g. ``ResNeXtBottleneck``); must expose
            an ``expansion`` class attribute and accept
            (inplanes, planes, cardinality, stride, downsample, conv3d_bias).
        layers: per-stage block counts for layer1..layer4.
        sample_size: spatial side length of the input clip.
        sample_duration: number of frames in the input clip.
        shortcut_type: 'A' for the parameter-free zero-padding shortcut,
            anything else for a 1x1x1 conv + BatchNorm projection.
        cardinality: group count for the grouped convolutions.
        num_classes: output dimension of the final linear layer.
        conv3d_bias: whether every Conv3d carries a bias term.
    """

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 cardinality=32,
                 num_classes=400,
                 conv3d_bias=True):
        # Set before _make_layer runs, which reads both attributes.
        self.conv3d_bias = conv3d_bias
        self.inplanes = 64
        super(ResNeXt, self).__init__()
        self.conv1 = nn.Conv3d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=conv3d_bias)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type,
                                       cardinality)
        self.layer2 = self._make_layer(
            block, 256, layers[1], shortcut_type, cardinality, stride=2)
        self.layer3 = self._make_layer(
            block, 512, layers[2], shortcut_type, cardinality, stride=2)
        self.layer4 = self._make_layer(
            block, 1024, layers[3], shortcut_type, cardinality, stride=2)
        # Assumed overall reduction: 16x in time, 32x in space.
        # NOTE(review): conv1/maxpool here stride by 2 in time as well, so
        # the 16x temporal factor only matches typical durations — confirm
        # for non-default sample_duration.
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)
        # layer4 emits 1024 * expansion == cardinality * 32 * expansion chans.
        self.fc = nn.Linear(cardinality * 32 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # kaiming_normal_ initializes in place; the original
                # re-assigned its return value to m.weight, which is redundant.
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    shortcut_type,
                    cardinality,
                    stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may downsample; subsequent blocks keep
        resolution and width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                # Parameter-free shortcut: subsample + zero-pad channels.
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                # Learned projection shortcut.
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=self.conv3d_bias),
                    nn.BatchNorm3d(planes * block.expansion))
        layers = [
            block(self.inplanes, planes, cardinality, stride, downsample,
                  conv3d_bias=self.conv3d_bias)
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, cardinality,
                      conv3d_bias=self.conv3d_bias))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the four stages, global pooling and the classifier."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Flatten (N, C, 1, 1, 1) -> (N, C) for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
# Load a checkpoint that was saved as a dict (epoch/arch/state_dict/optimizer),
# so the actual weights live under the 'state_dict' key.
pretrain = torch.load(opt.pretrain_path)
# Sanity check: the checkpoint must match the requested architecture.
assert opt.arch == pretrain['arch']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in pretrain['state_dict'].items():
    print(k)
    # The checkpoint was trained under torch.nn.DataParallel, which prefixes
    # every parameter name with 'module.'; strip those 7 characters so the
    # keys match a single-GPU model.
    name = k[7:]
    print(name)  # remove `module.`
    new_state_dict[name] = v
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
程序结果
这里之所以需要pretrain['state_dict']
而不是直接使用model.load_state_dict(torch.load(opt.pretrain_path))
是因为保存模型的时候不但保存了参数,还有周期,结构等信息。
# A checkpoint bundles more than the weights, which is why loading must index
# pretrain['state_dict'] instead of feeding the whole file to
# model.load_state_dict().
states = {
    'epoch': epoch + 1,  # next epoch to resume from
    'arch': opt.arch,  # architecture name, verified on reload
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(states, save_file_path)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
name = k[7:]
去掉每一个参数名的前七个字符,因为下载的预训练模型是在torch.nn.DataParallel
分布式下训练的,而我只有单卡,所以需要去掉参数名前面的module,再load。
optimizer.state_dict()
有state
和param_groups
两个key,其中param_groups
的value如下所示。
# named_parameters() yields (name, tensor) pairs in registration order.
for k, v in model.named_parameters():
    print(k)
- 1
- 2
# parameters() yields only the parameter tensors, without their names.
for i in model.parameters():
    print(i)
- 1
- 2
# Submodules are attributes; printing one shows its repr.
print(model.fc)
- 1
返回
Linear(in_features=2048, out_features=90, bias=True)
- 1
print(model.layer4)  # prints the repr of the entire fourth stage
- 1
总结:首先构建计算图,返回model。如果只想要一致的学习率,只需要在optimizer的第一个参数里写model.parameters()。model.parameters()返回的参数是有固定顺序的,此时len(optimizer.param_groups)等于1。如果想分别设定学习率等参数,可以如下设置:因为model.parameters()的遍历顺序固定,所以这里列表添加的{'params': v, 'lr': 0.0}都不需要带参数名,剩下没有逐组定义的参数(如momentum)就按照optimizer构造时的默认设置分配给每个参数组。
# Build per-parameter optimizer groups for fine-tuning: parameters belonging
# to layer{ft_begin_index}..layer4 and fc train at the optimizer's default
# learning rate; all earlier parameters are frozen with lr 0.0.  Options not
# set per-group (e.g. momentum) fall back to the optimizer's defaults.
ft_module_names = []  # was used uninitialized in the original snippet
for i in range(opt.ft_begin_index, 5):
    ft_module_names.append('layer{}'.format(i))
ft_module_names.append('fc')

parameters = []
# Bug fix: model.parameters() yields bare tensors, so `for k, v in ...`
# cannot unpack; model.named_parameters() supplies the (name, tensor) pairs
# whose name is matched against ft_module_names below.
for k, v in model.named_parameters():
    print(k)
    for ft_module in ft_module_names:
        if ft_module in k:
            parameters.append({'params': v})
            break
    else:
        # for/else: no ft_module matched this parameter name -> freeze it.
        parameters.append({'params': v, 'lr': 0.0})
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
模型结构