论文精讲 | 基于昇思MindSpore实现多域原型对比学习下的泛化联邦原型学习

作者:李锐锋

论文标题

Rethinking Federated Learning with Domain Shift: A Prototype View

论文来源

CVPR 2023

论文链接

https://openaccess.thecvf.com/content/CVPR2023/papers/Huang_Rethinking_Federated_Learning_With_Domain_Shift_A_Prototype_View_CVPR_2023_paper.pdf

代码链接

https://github.com/yuhangchen0/FPL_MS

昇思MindSpore作为一个开源的AI框架,为产学研和开发人员带来端边云全场景协同、极简开发、极致性能,超大规模AI预训练、极简开发、安全可信的体验,2020.3.28开源来已超过5百万的下载量,昇思MindSpore已支持数百+AI顶会论文,走入Top100+高校教学,通过HMS在5000+App上商用,拥有数量众多的开发者,在AI计算中心,金融、智能制造、金融、云、无线、数通、能源、消费者1+8+N、智能汽车等端边云车全场景逐步广泛应用,是Gitee指数最高的开源软件。欢迎大家参与开源贡献、套件、模型众智、行业创新与应用、算法创新、学术合作、AI书籍合作等,贡献您在云侧、端侧、边侧以及安全领域的应用案例。

在科技界、学术界和工业界对昇思MindSpore的广泛支持下,基于昇思MindSpore的AI论文2023年在所有AI框架中占比7%,连续两年进入全球第二,感谢CAAI和各位高校老师支持,我们一起继续努力做好AI科研创新。昇思MindSpore社区支持顶级会议论文研究,持续构建原创AI成果。我会不定期挑选一些优秀的论文来推送和解读,希望更多的产学研专家跟昇思MindSpore合作,一起推动原创AI研究,昇思MindSpore社区会持续支撑好AI创新和AI应用,本文是昇思MindSpore AI顶会论文系列第18篇,我选择了来自武汉大学计算机学院的叶茫老师团队的一篇论文解读,感谢各位专家教授同学的投稿。

昇思MindSpore旨在实现易开发、高效执行、全场景覆盖三大目标。通过使用体验,昇思MindSpore这一深度学习框架的发展速度飞快,它的各类API的设计都在朝着更合理、更完整、更强大的方向不断优化。此外,昇思不断涌现的各类开发工具也在辅助这一生态圈营造更加便捷强大的开发手段,例如MindSpore Insight,它可以将模型架构以图的形式呈现出来,也可以动态监控模型运行时各个指标和参数的变化,使开发过程更加方便。

01

研究背景

在数字化的世界中,数据隐私和安全性成为了日益关注的核心议题。正是在这样的背景下,联邦学习应运而生,作为一种保护数据隐私的分布式机器学习方法,其核心思想是让多个设备或服务器共同训练一个模型,而无需共享原始数据。这种方法可以应对多台移动设备上的机器学习任务,特别是在数据隐私和安全性需求高的情况下。

联邦学习中有个重要的待解决的问题:数据异构性。通常指的是参与学习的各节点(例如设备、服务器或组织)持有的数据可能存在巨大的差异。这些差异可能涉及数据的分布、质量、数量以及特征类型等多个方面。数据异构性问题在联邦学习中尤为重要,因为它可能直接影响模型的学习效果和泛化能力。

本文指出,针对数据异构,现有的解决方案主要关注来自同一领域的所有私有数据。当分布式数据来源于不同的领域时,私有模型在其他领域(存在领域偏移)上容易展现出退化的性能,而且全局信号无法捕获丰富公平的领域信息。因此,作者期望在联邦学习过程中通过优化的全局模型能够稳定地在多个领域上提供泛化性能。

在本文中,作者提出了一种针对领域偏移下的联邦学习的“联邦原型学习”(FPL)。核心思想是构建集群原型和无偏原型,提供丰富的领域知识和公平的收敛目标。一方面,将样本嵌入远离来自不同类别的群集原型,更接近于相同语义的群集原型。另一方面,引入了一致性正则化,以使本地实例与相应的无偏原型对齐。

论文基于昇思MindSpore进行框架开发和实验,Digits和Office Caltech任务等实验结果证明了所提出的解决方案的有效性和关键模块的高效性。

02

团队介绍

论文第一作者黄文柯目前硕博连读于武汉大学(2021-至今),导师为杜博教授和叶茫教授。本科毕业于武汉大学,主要研究方向包括联邦学习,图学习,金融科技等,目前在CVPR、IJCAI、ACM MM 国际顶级会议上以第一作者发表论文4篇。研究生期间获得国泰君安奖学金、优秀研究生等称号。曾在阿里巴巴集团、微软亚洲研究院等担任研究实习生。

论文通讯作者叶茫是武汉大学计算机学院教授、博士生导师,国家级高层次青年人才,中国科协青年托举人才。曾任阿联酋起源人工智能研究院研究科学家和美国哥伦比亚大学访问学者。主要研究方向计算机视觉、多媒体检索、联邦学习等,发表国际期刊会议论文 80 余篇,ESI 高被引论文 10 篇,谷歌学术引用 5600 余次。担任CVPR24、ACM MM23等学术会议领域主席。主持湖北省重点研发计划、国家自然科学基金面上项目等科研项目。获谷歌优秀奖学金、国际计算机视觉顶会 ICCV2021无人机目标重识别赛道冠军、2021-2022年斯坦福排行榜 “全球前2%顶尖科学家”、2022年百度AI华人青年学者等荣誉。。

研究团队MARS是由叶茫教授指导的专注研究监控视频行人/行为分析、无监督/半监督学习、跨模态理解与推理、联邦学习。

03

论文简介

3.1 介绍

基于前述的研究背景,本文提出联邦原型学习(Federated Prototype Learning),用于解决联邦多域泛化问题:私有数据来源于不同领域,不同的客户端存在差异较大的特征分布,由于本地模型会过度拟合本地分布,私有模型无法在其他领域表现好。比如说,一个在灰度图像MNIST上训练的本地模型A,在被服务器端聚合后,不能在另一个比如彩色图像SVHN数据集的客户端中表现正常,因为这个本地模型A无法学习到SVHN的领域信息,导致性能退化。

由于全局信号无法表征多个领域的知识信息,并且可能偏向主导领域的信息,导致泛化能力下降。为了让模型学习到丰富的多领域知识,使用共有信号提供多个领域信息提升泛化能力,本文提出利用集群原型表征不同领域信息,利用对比学习,增强不同领域相同类别共性,增强不同类别差异性,称为集群原型对比学习(Cluster Prototypes Contrastive Learning);为了避免朝潜在主导域优化,提升在少数域上的能力,本文利无偏原型提供公平稳定的信息,称为无偏原型一致性正则化(Unbiased Prototypes Consistent Regularization)。

3.2 方法

3.2.1 准备

联邦学习

在典型的联邦学习设置中,存在图片个参与者与其对应的私有数据,表示为:

图片

其中,图片表示本地数据规模。在异构联邦学习环境下,条件特征分布图片在各个参与者之间会变化,即便图片是一致的,这导致了领域偏移。定义领域偏移为:

图片

这意味着,在私有数据中存在领域偏移。具体来说,对于同一标签空间,不同参与者间存在独特的特征分布。

图片图1 本地客户端数据来源域不同,差异较大

此外,所有参与者达成共识,共享一个具有相同架构的模型。这个模型可以视为两个主要部分:特征提取器和分类器。特征提取器,记为图片,将样本x编码为特征空间图片中的一个维特征向量,表示为:

图片

分类器将特征映射为logits输出图片 ,在后续公式中,图片代表分类的类别。优化目标是,通过联邦学习过程学习一个在多个领域中都有良好性能的可泛化全局模型。

特征原型

为了实现后续的原型相关方法,本文首先构建了原型的定义:

图片

其中图片表示第图片个client的标签为图片的原型,通过计算第图片个client的标签为图片的所有样本的特征向量的平均值得到,直观表示这个client的标签图片所表现的领域信息。

如果先不考虑本文方法,最一般的方法就是直接平均所有client的标签的领域信息图片,让所有客户端学习这个信息,约束本地客户端更新:

图片

其中图片就表示整个联邦系统中标签为图片的所有来源于不同领域的样本的平均领域信息。然而,这样的全局存在偏见,他既无法正确的表示各个领域信息,也可能偏向主导领域,忽略少数域,如图2a。

图片图2 不同种类原型表示

3.2.2 集群原型对比学习

为了解决全局原型存在的问题,本文首先使用FINCH方法进行无监督聚类,将广泛的领域知识(每个样本的特征向量)无监督地进行分离,如此一来,来自于不同域的样本由于各自特征向量存在差异,不同的领域将被聚类为不同的集群,然后对同一集群内计算这个集群的原型,如图2b,防止多域之间相互平均后远离所有有用的领域知识。

图片

如上公式中,图片表示标签图片的被聚类后的图片个集群原型的集合。

基于此,本文通过新增损失项来实现集群原型对比学习。对于属于图片内某个样本图片,其特征向量为图片 ,本文通过对比学习,尽可能拉进该样本与具有相同语义的属于图片的所有原型的距离,或者说提高相同标签下来源于不同域的相似度,与此同时尽可能降低与所有不属于图片(记为图片)的原型相似度,通过此方法,在本地更新时学习不同领域的丰富知识,提高泛化能力。作者定义样本特征向量与原型的相似度为:

图片

然后构建实现集群原型对比学习的损失项:

图片

为什么这种方法有效呢?作者给出下面的分析:

图片

最小化这个损失函数相当于将样本特征向量紧密地拉近到其分配的正集群原型图片,并将特征向量远离其他的负原型图片。这不仅对多种领域失真保持不变性,还增强了语义的扩散性质,保证特征空间既具有泛化性又具有区分性,从而在联邦学习中获得满意的泛化性能。

3.2.3无偏原型一致性正则化

由于集群原型为领域转移下的可塑性带来了多样的领域知识,但由于非监督聚类方法,每次通信时都会动态生成集群原型,且其规模都在变化。因此,集群原型在不同的通信时代都不能提供稳定的收敛方向。本文提出第二个方法,通过构建公平稳定的无偏原型,约束多个集群原型与无偏原型的距离,保证持续的多域公平。

具体来说,将已经被聚类后的相同标签下的多个集群原型进行平均,表示该标签下的无偏收敛目标图片,如图2c。

图片

本文引入第二个损失项,利用一致性正则化项将样本的特征向量拉近到相应的无偏原型图片,提供一个相对公平且稳定的优化点,从而解决收敛不稳定性的问题:

图片

3.2.4 整体算法

除了上述两种损失想,再加上常规的模型训练使用的交叉熵损失函数,作为本文提出的联邦原型学习的损失函数:

图片

学习过程:

图片

本文算法

图片

04

实验结果

4.1 与State-of-the-art的实验结果对比

本文在Digits和Office Caltech数据集下进行测试,前者是数字的4种相同标签不同数据来源的数据集,后者是真实世界的4种相同标签不同数据来源的数据集。实验表明所提出的FPL不论是在单个领域上的性能还是多个领域上的平均性能,都优于当前SOTA。

图片

 

4.2 消融实验

图片

可以看出大部分情况下CPCL和UPCR共同作用能产生更好的性能。

图片

比较两种方法采用普通的全局原型和所提出的原型所展示出的实验效果,表明集群原型和无偏原型的有效性。

4.3 昇思MindSpore代码展示

本框架基于昇思MindSpore进行开发。

4.3.1 昇思MindSpore实现集群原型对比学习

def calculate_infonce(self, f_now, label, all_f, all_global_protos_keys):
        pos_indices = 0
        neg_indices = []
        for i, k in enumerate(all_global_protos_keys):
            if k == label.item():
                pos_indices = i
            else:
                neg_indices.append(i)

        f_pos = Tensor(all_f[pos_indices][0]).reshape(1,512)
        f_neg = ops.cat([Tensor(all_f[i]).reshape(-1, 512) for i in neg_indices], axis=0)
        #aaa
        f_proto = ops.cat((f_pos, f_neg), axis=0)
        f_now = f_now.reshape(1,512)

        f_now_np = f_now.asnumpy()
        f_proto_np = f_proto.asnumpy()
        def cosine_similarity_numpy(vec_a, vec_b):
            dot_product = np.dot(vec_a, vec_b.T)
            norm_a = np.linalg.norm(vec_a, axis=1, keepdims=True)
            norm_b = np.linalg.norm(vec_b, axis=1)
            return dot_product / (norm_a * norm_b)
        l_np = cosine_similarity_numpy(f_now_np, f_proto_np)
        l = Tensor(l_np)

        #l = ops.cosine_similarity(f_now, f_proto, dim=1)
        l = ops.div(l, self.infoNCET)

        exp_l = ops.exp(l).reshape(1, -1)

        pos_num = f_pos.shape[0]
        neg_num = f_neg.shape[0]
        pos_mask = Tensor([1] * pos_num + [0] * neg_num).reshape(1, -1)

        pos_l = exp_l * pos_mask
        sum_pos_l = ops.sum(pos_l, dim=1)
        sum_exp_l = ops.sum(exp_l, dim=1)
        infonce_loss = -ops.log(sum_pos_l / sum_exp_l)
        return Tensor(infonce_loss)

4.3.2 昇思**MindSpore实现无偏原型一致性正则化**

def hierarchical_info_loss(self, f_now, label, mean_f, all_global_protos_keys):


        pos_indices = 0
        for i, k in enumerate(all_global_protos_keys):
            if k == label.item():
                pos_indices = i



        mean_f_pos = Tensor(mean_f[pos_indices])
        f_now = Tensor(f_now)

        cu_info_loss = self.loss_mse(f_now, mean_f_pos)

        return cu_info_loss

4.3.3 客户端本地模型训练

 def _train_net(self, index, net, train_loader):

        if len(self.global_protos) != 0:
            all_global_protos_keys = np.array(list(self.global_protos.keys()))
            all_f = []
            mean_f = []
            for protos_key in all_global_protos_keys:
                temp_f = self.global_protos[protos_key]
                all_f.append(copy.deepcopy(temp_f))
                mean_f.append(copy.deepcopy(np.mean(temp_f, axis=0)))
            all_f = [item.copy() for item in all_f]
            mean_f = [item.copy() for item in mean_f]
        else:
            all_f = []
            mean_f = []
            all_global_protos_keys = []        

        optimizer = nn.SGD(net.trainable_params(), learning_rate=self.local_lr, momentum=0.9, weight_decay=1e-5)
        criterion1 = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        criterion = CustomLoss(criterion1, self.loss2)
        self.loss_mse = mindspore.nn.MSELoss()
        train_net= nn.TrainOneStepCell(nn.WithLossCell(net,criterion), optimizer=optimizer)
        train_net.set_train(True)

        iterator = tqdm(range(self.local_epoch))
        for iter in iterator:

            agg_protos_label = {}
            for di in train_loader.create_dict_iterator():
                images = di["image"]
                labels = di["label"]

                #   train_net.set_train(False)
                f = net.features(images)
                #train_net.set_train(True)

                if len(self.global_protos) == 0:
                    loss_InfoNCE = 0 
                else:
                    i = 0
                    loss_InfoNCE = None

                    for label in labels:
                        if label in all_global_protos_keys:

                            f_now = f[i]
                            cu_info_loss = self.hierarchical_info_loss(f_now, label, mean_f, all_global_protos_keys)
                            xi_info_loss = self.calculate_infonce(f

05

总结与展望

在本文中,我们探讨了在异构联邦学习中领域转移下的泛化性和稳定性问题。我们的研究引入了一个简单而有效的联邦学习算法,即联邦原型学习(FPL)。我们利用原型(类的典型表示)来解决这两个问题,享受集群原型和无偏原型的互补优势:多样的领域知识和稳定的收敛信号。我们使用昇思MindSpore架构实现了FPL框架并展现其在效率和准确性上的优势。

在使用昇思MindSpore进行FPL框架开发中,我们注意到昇思MindSpore社区非常活跃,有许多华为开发者和使用者针对我们框架搭建中遇到的困难提供巨大帮助。不仅如此,借助昇思MindSpore提供的丰富的文档和教程以及社区中的实际案例和最佳实践,我们避免了许多潜在的陷阱,更快地达到了我们的研究目标。

90后程序员开发视频搬运软件、不到一年获利超 700 万,结局很刑! 谷歌证实裁员,涉及 Flutter、Dart 和 Python 团队 中国码农的“35岁魔咒” Xshell 8 开启 Beta 公测:支持 RDP 协议、可远程连接 Windows 10/11 ​MySQL 的第一个长期支持版 8.4 GA 开源日报 | 微软挤兑Chrome;阳痿中年的福报玩具;神秘AI能力太强被疑GPT-4.5;通义千问3个月开源8模型 Arc Browser for Windows 1.0 正式 GA Windows 10 市场份额达 70%,Windows 11 持续下滑 GitHub 发布 AI 原生开发工具 GitHub Copilot Workspace JAVA 下唯一一款搞定 OLTP+OLAP 的强类型查询这就是最好用的 ORM 相见恨晚
{{o.name}}
{{m.name}}

猜你喜欢

转载自my.oschina.net/u/4736317/blog/11072527