from collections import OrderedDict
import torch
from models.faceland_d import FaceLanndInference_d
def _strip_prefix(key, prefix):
    """Return *key* without a leading *prefix*, or *key* unchanged if absent."""
    return key[len(prefix):] if key.startswith(prefix) else key


def _accumulator_dtype(dtype):
    """Pick a wide dtype in which tensors of *dtype* can be summed precisely."""
    return torch.complex128 if dtype.is_complex else torch.float64


def average_state_dicts(state_dicts, prefix="module."):
    """Return the element-wise mean of an iterable of state dicts.

    Keys starting with *prefix* (by default ``module.``, the prefix found in
    checkpoints saved from a model wrapped in ``nn.DataParallel``) have it
    stripped, so wrapped and unwrapped checkpoints can be mixed.

    Values are summed in float64 (complex128 for complex tensors) and the
    mean is cast back to each key's original dtype. Non-floating tensors
    (integer counters, bool masks) are rounded to the nearest value before
    the cast instead of being truncated. The input is consumed one state
    dict at a time, so a generator that loads checkpoints lazily avoids
    holding every checkpoint in memory at once. Input tensors are not
    modified.

    Args:
        state_dicts: Iterable of mappings from parameter name to tensor.
        prefix: Key prefix to strip before matching keys across dicts.

    Returns:
        An ``OrderedDict`` with the stripped keys (in the order of the first
        state dict) mapped to the averaged tensors.

    Raises:
        ValueError: If *state_dicts* is empty, if stripping *prefix* makes two
            keys of one dict collide, or if the dicts disagree on their key
            sets or tensor shapes.
    """
    sums = OrderedDict()
    dtypes = {}
    count = 0
    for state_dict in state_dicts:
        normalized = OrderedDict(
            (_strip_prefix(key, prefix), value) for key, value in state_dict.items()
        )
        if len(normalized) != len(state_dict):
            # e.g. both "module.w" and "w" present: they would be double-counted.
            raise ValueError(
                f"state dict #{count} has keys that collide after stripping {prefix!r}"
            )
        if count == 0:
            for key, value in normalized.items():
                dtypes[key] = value.dtype
                # copy=True so in-place accumulation never touches the input.
                sums[key] = value.detach().to(_accumulator_dtype(value.dtype), copy=True)
        else:
            if normalized.keys() != sums.keys():
                # Averaging a key that only some checkpoints have would divide
                # its sum by the wrong count, so refuse instead.
                missing = sorted(sums.keys() - normalized.keys())
                extra = sorted(normalized.keys() - sums.keys())
                raise ValueError(
                    f"state dict #{count} key mismatch: "
                    f"missing {missing}, unexpected {extra}"
                )
            for key, value in normalized.items():
                total = sums[key]
                if value.shape != total.shape:
                    raise ValueError(
                        f"state dict #{count} shape mismatch for {key!r}: "
                        f"{tuple(value.shape)} vs {tuple(total.shape)}"
                    )
                total += value.detach().to(device=total.device, dtype=total.dtype)
        count += 1
    if count == 0:
        raise ValueError("need at least one state dict to average")

    averaged = OrderedDict()
    for key, total in sums.items():
        mean = total / count
        dtype = dtypes[key]
        if not (dtype.is_floating_point or dtype.is_complex):
            mean = mean.round()
        averaged[key] = mean.to(dtype)
    return averaged


if __name__ == '__main__':
    model = FaceLanndInference_d()
    model_paths = ["./weights_d/0.0680_slim128_epoch_52.pth",
                   "./weights_d/0.0680_slim128_epoch_52.pth"]
    if model_paths:
        # Load onto the CPU so checkpoints saved from different devices can be
        # summed together; the generator loads one checkpoint per iteration.
        averaged = average_state_dicts(
            torch.load(path, map_location="cpu") for path in model_paths
        )
        # Every averaged key must exist in the model; fail loudly rather than
        # let strict=False drop it silently.
        model_keys = set(model.state_dict())
        unknown = [key for key in averaged if key not in model_keys]
        if unknown:
            raise KeyError(f"checkpoint keys not found in the model: {unknown}")
        # strict=False leaves model entries absent from the checkpoints at
        # their current values; report them instead of ignoring them.
        result = model.load_state_dict(averaged, strict=False)
        if result.missing_keys:
            print(f"warning: model keys not in the checkpoints: {result.missing_keys}")
        torch.save(model.state_dict(), "new_weight.pth")
# pytorch 多个模型 求平均 (PyTorch: average the weights of multiple models)
# 转载自 (reprinted from) blog.csdn.net/jacke121/article/details/131365177