在面向对象编程中,继承当然是子类继承父类。在这里博主基于python内置的魔法函数,突发奇想实现了子类继承一个已经被实例化的对象。其核心原理就是__getattribute__的妙用,__getattribute__可以拦截对象属性访问,通过__getattribute__对访问的属性进行拦截,当访问到的是重载后的方法时,返回子类的方法,当访问到的是未知的方法时返回被实例化对象的方法。
1、具体实现
这里以继承pytorch的model实例为实现示例:
from torch import nn
from torchvision import models
import numpy as np
import torch
class MyModel(object):
def __init__(self,model,train_type,eval_type):
super(MyModel, self).__init__()
self.model = model
self.train_type = train_type
self.eval_type = eval_type
#在这里实现forward 用法:output=MyModel(input_data)
def __call__(self,x):
print("MyModel.forward()")
return self.model.forward(x)
#在按照自己的需要,对train进行修改。比如,按指数下降的方式调整模型中dropout的drop_rate
def train(self):
self.model.train()
print("MyModel.train()")
#在按照自己的需要,对eval进行修改
def eval(self):
self.model.eval()
print("MyModel.eval()")
#使用属性拦截器,无法拦截object内置的其他魔法函数属性, 但是重载后的魔法函数属性是可以被拦截的
def __getattribute__(self,*args,**kwargs):
if args[0] in ["train","eval","model","train_type","eval_type"]:
#执行类自身的方法
return object.__getattribute__(self,*args,**kwargs)
else:
#执行model的方法
return self.model.__getattribute__(*args,**kwargs)
def __str__(self):
return str(self.model)
model = models.resnet18(pretrained=True)
my_model=MyModel(model,train_type="train_type",eval_type="train_type")
2、与“父类”的属性对比
可以看出,这样子实现的子类是可以继承父类的绝大部分属性的,只是针对于模型的一些特殊属性(内部的layer结构)无法被有效继承。但是,这样子并不影响模型的训练与使用
attr1=dir(model)
attr2=dir(my_model)
print("model attr nums:",len(attr1))
print("mymodel attr nums:",len(attr2))
sub=set(attr1)-set(attr2)
print("mymodel missing attr:",sub)
print("my_model attr:",attr2)
代码执行结果如下所示 model attr nums: 111 mymodel attr nums: 101 mymodel missing attr: {'bn1', 'layer2', 'layer1', 'avgpool', 'layer3', 'relu', 'fc', 'maxpool', 'layer4', 'conv1'} my_model attr: ['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_buffers', '_call_impl', '_forward_hooks', '_forward_impl', '_forward_pre_hooks', '_get_backward_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_pre_hooks', '_make_layer', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_norm_layer', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_save_to_state_dict', '_slow_forward', '_state_dict_hooks', '_version', 'add_module', 'apply', 'base_width', 'bfloat16', 'buffers', 'children', 'cpu', 'cuda', 'dilation', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'get_buffer', 'get_parameter', 'get_submodule', 'groups', 'half', 'inplanes', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_parameter', 'requires_grad_', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'xpu', 'zero_grad']
3、使用对比
这样子继承得来的子类在使用上与父类并没有任何区别
x=torch.ones((1,3,224,224))
my_model.eval()
y1=my_model(x)
y1=y1.detach().numpy()
print("mymodel:",y1.mean(),y1.max(),y1.min())
model.eval()
y2=model(x)
y2=y2.detach().numpy()
print("model:",y2.mean(),y2.max(),y2.min())
代码执行结果如下所示 MyModel.eval() MyModel.forward() mymodel: 6.187439e-06 4.4229083 -2.8364685 model: 6.187439e-06 4.4229083 -2.8364685