class NetWrapper(nn.Module):
    """Wrap a base network and expose a chosen hidden layer's activation
    as the representation, optionally followed by a projection MLP.

    A forward hook on the target layer captures its (flattened) output
    during a normal forward pass. The projector is created lazily on
    first use, sized from the captured hidden dimension.
    """

    def __init__(self, net, projection_size, projection_hidden_size, layer=-2, use_simsiam_mlp=False):
        super().__init__()
        self.net = net
        # Layer to hook: a module name (str, looked up in named_modules)
        # or an index into net.children() (int). -1 means "use net's output".
        self.layer = layer
        self.projector = None  # built lazily by _get_projector (singleton decorator)
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size
        self.use_simsiam_mlp = use_simsiam_mlp
        # Captured hidden activations keyed by device — presumably to support
        # DataParallel-style replicas running on several devices at once.
        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        """Resolve self.layer to a module of self.net, or None if not found."""
        if isinstance(self.layer, str):
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        if isinstance(self.layer, int):
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        # Forward hook: stash the flattened layer output, keyed by the
        # device of the layer's input tensor.
        device = input[0].device
        self.hidden[device] = flatten(output)

    def _register_hook(self):
        """Attach the forward hook to the target layer (guarded by hook_registered)."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        """Lazily build the projection MLP sized from `hidden`, on its device/dtype."""
        _, dim = hidden.shape
        create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
        """Run the net on x and return the hooked layer's flattened activation.

        When layer == -1 the network's own output is the representation,
        so no hook is needed.
        """
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        # .pop with a default (not indexing) so a missing entry reaches the
        # assert below with a clear message instead of raising a bare KeyError.
        hidden = self.hidden.pop(x.device, None)
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection=True):
        """Return (projection, representation), or just the representation
        when return_projection is False.
        """
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation
# Note: _hook(self, _, input, output) and _register_hook(self) work together —
# the registered forward hook is invoked implicitly during net(x) and stores
# the hooked layer's output in the self.hidden dict, keyed by device.
# Source: blog.csdn.net/nyist_yangguang/article/details/128314317