__getattribute__的妙用,实现对象实例的“继承“与方法重载

在面向对象编程中,继承当然是子类继承父类。在这里博主基于python内置的魔法函数,突发奇想实现了子类继承一个已经被实例化的对象。其核心原理就是__getattribute__的妙用,__getattribute__可以拦截对象属性访问,通过__getattribute__对访问的属性进行拦截,当访问到的是重载后的方法时,返回子类的方法,当访问到的是未知的方法时返回被实例化对象的方法。

1、具体实现

这里以继承pytorch的model实例为实现示例:

from torch import nn
from torchvision import models
import numpy as np
import torch
class MyModel(object):
    def __init__(self,model,train_type,eval_type):
        super(MyModel, self).__init__()
        self.model = model
        self.train_type = train_type
        self.eval_type = eval_type
    #在这里实现forward 用法:output=MyModel(input_data)
    def __call__(self,x):
        print("MyModel.forward()")
        return self.model.forward(x)
        
    #在按照自己的需要,对train进行修改。比如,按指数下降的方式调整模型中dropout的drop_rate
    def train(self):
        self.model.train()
        print("MyModel.train()")
            
    #在按照自己的需要,对eval进行修改
    def eval(self):
        self.model.eval()
        print("MyModel.eval()")
            
    #使用属性拦截器,无法拦截object内置的其他魔法函数属性, 但是重载后的魔法函数属性是可以被拦截的      
    def __getattribute__(self,*args,**kwargs):
        if args[0] in ["train","eval","model","train_type","eval_type"]:
            #执行类自身的方法
            return object.__getattribute__(self,*args,**kwargs)
        else:
            #执行model的方法
            return self.model.__getattribute__(*args,**kwargs)
        
    def __str__(self):
        return str(self.model)
    
model = models.resnet18(pretrained=True)
my_model=MyModel(model,train_type="train_type",eval_type="train_type")

2、与“父类”的属性对比

可以看出,这样子实现的子类是可以继承父类的绝大部分属性的,只是针对于模型的一些特殊属性(内部的layer结构)无法被有效继承。但是,这样子并不影响模型的训练与使用

attr1=dir(model)
attr2=dir(my_model)
print("model attr nums:",len(attr1))
print("mymodel attr nums:",len(attr2))
sub=set(attr1)-set(attr2)
print("mymodel missing attr:",sub)
print("my_model attr:",attr2)
代码执行结果如下所示
model attr nums: 111
mymodel attr nums: 101
mymodel missing attr: {'bn1', 'layer2', 'layer1', 'avgpool', 'layer3', 'relu', 'fc', 'maxpool', 'layer4', 'conv1'}
my_model attr: ['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_buffers', '_call_impl', '_forward_hooks', '_forward_impl', '_forward_pre_hooks', '_get_backward_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_pre_hooks', '_make_layer', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_norm_layer', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_save_to_state_dict', '_slow_forward', '_state_dict_hooks', '_version', 'add_module', 'apply', 'base_width', 'bfloat16', 'buffers', 'children', 'cpu', 'cuda', 'dilation', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'get_buffer', 'get_parameter', 'get_submodule', 'groups', 'half', 'inplanes', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_parameter', 'requires_grad_', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'xpu', 'zero_grad']

3、使用对比

这样子继承得来的子类在使用上与父类并没有任何区别

x=torch.ones((1,3,224,224))

my_model.eval()
y1=my_model(x)
y1=y1.detach().numpy()
print("mymodel:",y1.mean(),y1.max(),y1.min())

model.eval()
y2=model(x)
y2=y2.detach().numpy()
print("model:",y2.mean(),y2.max(),y2.min())
代码执行结果如下所示
MyModel.eval()
MyModel.forward()
mymodel: 6.187439e-06 4.4229083 -2.8364685
model: 6.187439e-06 4.4229083 -2.8364685

猜你喜欢

转载自blog.csdn.net/a486259/article/details/124010097