联邦学习_王树森_视频整理


视频链接: 联邦学习:技术角度的讲解(中文)Introduction to Federated Learning_哔哩哔哩_bilibili

联邦学习是一种特殊的分布式机器学习,联邦学习有很实际的应用。

存在的应用问题

1.很多用户装了 Google 的APP,这些移动端的 APP 会收集和产生数据。Google 想要建立一个机器学习模型,使用移动端的数据来训练该模型。APP 把数据收集起来,发送到 Google 云端,然后 Google 在自己的集群上来训练模型。但是如果有这样一个限制条件,用户觉得自己照片的信息是隐私,不允许 Google 把数据传到云上,那么 Google 该怎么来训练这个机器学习模型呢?

2.每家医院都有自己的数据,可以训练自己的医疗诊断模型,但是每家的数据都不多,训练效果不好。最简单的办法就是这些医院把数据整合起来,放到一个服务器上,可以是其中一家医院的服务器,也可以是一个第三方公司的服务器。在服务器上就可以把模型训练出来了。但事实上,不管是医院、银行还是保险公司,都不能轻易地把用户数据交给别人。这可能违反公司的规定,也可能违反法律,总之不可能把每一家的数据都集中在一个服务器上。

联邦学习

什么是联邦学习?

我前面的课上讲了并行算法或者叫做分布式算法,其中有一个编程模型叫做 parameter server,系统里面有一个或几个节点作为server,其他的节点作为worker。server 和 worker 之间可以通信,通信的方式叫做 message passing, server 可以给 worker 发消息, worker 可以给 server 发消息。我们可以用这种系统来训练,最小二乘回归也可以训练神经网络。

训练的过程中,计算几乎都是 worker 做的, server 端存储模型参数,并且更新模型参数。训练最小二乘或者神经网络需要做梯度下降或随机梯度下降,想要让算法收敛,需要重复很多轮梯度下降,每一轮都要重复这样几步操作。

首先 worker 节点向 server 索要模型参数, server 会把最新的模型参数发给worker,这一步需要通信,通信复杂度就是模型参数的数量;
然后 worker 就用本地的数据,用最新的模型参数来算出本地的梯度或者随机梯度,这一步是不需要通信的, worker 只需要在本地做计算就好;
worker 节点算出梯度后,把梯度发给 server,这一步又需要通信,梯度的维度和参数的维度是一样的,所以这一步的通信复杂度还是等于模型参数的数量;
server 收到梯度后,用梯度来更新模型参数,比如做一次梯度下降,或者随机梯度下降、RMSprop 、Adam 等算法来更新参数,这样就完成了一次迭代。

总结一下,每次迭代需要两次通信, server 把模型参数发给worker, worker 把梯度发给 server 。
计算主要是由 worker 做的, worker 计算梯度,这个计算量比较大, server 端用梯度更新模型,这个计算量很小,只需要算两个向量或者矩阵的差就好了。

回顾一开始的两个应用,一个是 Google 希望从用户移动设备的数据中进行机器学习,另一个应用是几个医院希望用医疗数据来训练模型。困难都是数据和隐私的问题,数据没有办法被集中放在一起。但是你想用分布式机器学习,这个问题不就解决了吗?你把每个用户的手机当做一个 worker 节点,让这些节点做计算,算出梯度,然后把梯度发给server,这样 server 不就可以学出来一个模型了吗?而且没有违反隐私限制,数据根本没有离开 worker 节点, server 是看不到用户数据的。用分布式机器学习,我一开始讲的两个问题都解决了。

**Federated learning 联邦学习其实就是一种分布式机器学习。**虽然联邦学习被炒得很热,但是它跟分布式机器学习没有本质区别。联邦学习是 15 年提出来的,联邦学习的应用很新,但是方法上没有太多新意,一直到 17 年他们的文章才发表在 AISTATS 会议上。联邦学习带来很多有意思的问题,值得学术界来解决,但是其本质完全就是分布式机器学习。

联邦学习这个名字是怎么来的?联邦是由很多个邦或者州组成的,每个邦或者州都有很高的自治权组成的,组成的联邦就是个比较松散的政府。这很像我之前说的两个应用,手机用户或者医院,他们可以参与分布式学习,但是控制权还在用户自己手上,这就很类似于联邦,联邦学习就是这样一个比喻。

联邦学习跟传统分布式学习有什么区别?

1. Users have control over their device and data.

  • Mapreduce 等传统分布式系统中, worker 受 server 的控制,接受 server 的指令, server 甚至可以发送一个叫 shuffle 指令,让 worker 之间交换数据,把数据打散。传统分布式机器学习很像中国的政治体制,上面有对各省的绝对控制权。
  • 联邦学习跟传统分布式学习的第一个区别就是 server 对用户设备和数据有多大的控制权。**用户对自己的设备和数据有绝对的控制权。**用户可以随时让自己的设备停止参与计算和通信。这就像是联邦的组成部分,每一个邦都有很强的自治权。

2. Worker nodes are unstale.

  • 传统分布式机器学习的设备往往都是在机房里连着高速宽带, 24 小时开机,有专人维护,非常稳定。

  • 参与联邦学习的 worker 节点往往都是手机、iPad、智能家居这样不稳定的设备。

  • 传统分布式,计算的节点几乎都是相同型号的机器和处理器,计算性能几乎都是一样的。长话短说,联邦学习的 worker 节点不稳定,这对分布式计算造成了困难。

  • 联邦学习用户的设备也各不相同。这些设备的计算能力各不相同,有快有慢。要让设备同时开始做计算, iPhone 10 已经算了好几轮了, iPhone 5 一轮还没算完。不管是同步算法还是异步算法,结点有快有慢,都会造成严重的问题。

3. Communication cost is higher than computation cost.

  • 传统分布式系统中,节点直接拿网线连起来,或者接入了高速宽带, 24 小时开机,连接非常稳定。
  • 联邦学习里面 worker 节点往往都是手机、iPad,这些设备跟服务器的连接都是远程连接,甚至设备可能跟服务器不在一个国家,所以带宽很低,网络延迟很高,发送几千万个模型参数不可能几百毫秒就完成,能几秒完成就很不错了。因此,通信代价非常大,通信代价远大于计算的代价。

4. Data stored on worker nodes are not IID.

  • 在传统的分布式机器学习中,数据的划分通常是均匀的、随机打乱的。如果用每个节点上的数据去算一个统计量,比如说均值方差,会发现每个节点上的统计量都差不多。如果能把数据随机在节点之间打乱,shuffle一下,数据就能成为独立同分布,这样非常有利于设计高效的算法。
  • 联邦学习不能假设数据是独立同分布的。每个手机用户的数据统计性质肯定不一样,因为用户的习惯不一样。比如说我经常拿手机拍拍风景,有些妹子喜欢自拍,所以我们两个手机相册的图片的统计性质就完全不一样。由于数据不是独立同分布,很多已有的减少通讯次数的算法不再适用。

5. The amount of data is servely imbalanced.

  • 有的用户几天拍不了一张照片,有的用户每天拍几十张照片,这两个手机上照片的数量差了上百倍。worker节点上的数据集有的大,有的小,建模和计算都会出问题。
  • 建模:如果给每张图片相同的权重,那么学出来的模型几乎取决于重度用户,拍照少的用户就被忽略了。如果给每个用户相同的权重,这样学出来的模型对重度用户的效果不太好。建模很麻烦。
  • 计算:由于负载不平衡,计算的时候也是个问题。计算时间不一样,一个已经算了 100 个 epoch 了,另一个用户可能连一个 epoch 还没算完。传统分布式计算都是要做负载平衡的,但是联邦学习没有办法做负载均衡,不能把一个用户的数据转移到另一个用户的手机上去。

Research Direction

1. Communication-Efficiency

由于 2 和 3 这两个原因,设计联邦学习算法的时候,最重要的是减少通信次数。因为 worker 节点不稳定, server 发一个请求, worker 未必会立刻响应,万一手机关机了,或者没连Wifi,手机又不会响应。再者,传几千万个模型参数也有点慢。假如一个算法需要 1000 次通信才能收敛,另一个只要 100 次通信就能收敛。我们肯定想要用通信 100 次的算法。

因此,联邦学习最重要的研究方向就是如何降低通信次数,哪怕让计算量大很多,只要能减少通信次数就是值得的。已经有很多算法可以降低通信次数,这些算法的理念都是多做计算,少做通信

设计算法的基本想法:worker 节点拿到模型参数之后,在本地做很多计算,这样就可以得到比梯度更好的下降方向。然后 worker 把这个更好的下降方向传给 server。 server 用这个下降方向来更新模型参数。由于这个下降方向比梯度更好,所以可以让收敛更快。比方说原本用梯度下降需要 1000 次迭代才能收敛。由于新的算法算出的下降方向更好,所以现在做 100 次迭代算法就能收敛了。迭代次数少了 10 倍,所以通信次数也就少了 10 倍。

在联邦学习的例子中,可以让用户设备在空闲而且充电的时候做计算,即使手机本地的计算量增加了很多,也不会影响用户体验。

并行梯度下降

并行梯度下降,每一个 worker 节点上都有一部分数据,这叫做数据并发(data parallelism)。每一轮迭代开始的时候, server 把最新的模型参数发送给 worker 节点,我们来看一下这个 worker 节点每一轮都做什么样的操作:

OzbXNi.png

假设这个设备是第 i i i 个 worker 节点,它做这样的通信和计算:

  • 从服务器接收到最新的模型参数 w \mathbf{w} w
  • w \mathbf{w} w和worker节点本地的数据算出一个梯度 g i g_i gi
  • 最后就把梯度 g i g_i gi 发送给 server。
OzbuLX.png
  • server 接收到所有worker发来的梯度之后,
  • 把这些梯度 g 1 g_1 g1 一直到 g m g_m gm 全都加起来,得到了梯度 g g g
  • 然后 server 做一次梯度下降,更新参数 w \mathbf{w} w。这里的 α \alpha α 是步长 step size 或者叫 learning rate 学习率。

然后系统就可以进行下一轮迭代。系统把最新的参数 w 发给worker, worker 算梯度,把梯度传回server,然后 server 再来更新一次参数,不断重复。这个过程很多很多次,最后算法会收敛。

Federated Averaging Algorithm

这个算法跟并行梯度下降不太一样。Federated averaging 是一种 communication efficient algorithm。FedAvg用更少的通信次数就能达到收敛。

每一轮迭代的第一步跟前面一样,还是 server 把参数发给 worker 节点,但是每个 worker 节点所做的就跟之前不太一样了, worker 节点接收到模型参数 w \mathbf{w} w 之后,重复a、b操作。

OzbwPj.png
  • 首先用 w \mathbf{w} w 和本地数据去算一个梯度 g g g,然后在本地做梯度下降,本地更新参数 w \mathbf{w} w,这里的 α \alpha α 叫做步长或学习率。把a、b这两个步骤重复几个epoch。这里 epoch 的意思是说把本地数据全都扫一遍叫做一个epoch,扫 5 遍就是 5 个epoch。重复 AB 这个步骤 1- 5 个 epoch 就好了,重复太多次也不好。
  • 重复a、b这两步很多次之后。假设这个设备是第 i i i 个节点,把最终本地得到的参数 w \mathbf{w} w 记作是 w ~ i \tilde {\mathbf{w}}_i w~i,然后把 w ~ i \tilde {\mathbf{w}}_i w~i发送给server,这样这个节点就完成了计算。

为什么要这样在本地做很多次计算,把收到的模型参数 w \mathbf{w} w改进很多次之后再把它发回server?

OzkovY.png

这样做在两次通信之间可以把参数做很大的改进,而不仅仅是一次梯度下降。

  • worker 节点在本地做了很多次梯度下降,把模型参数从 w \mathbf{w} w 变成了本地的 w ~ i \tilde {\mathbf{w}}_i w~i ,worker 就把这些算出来的 w ~ i \tilde {\mathbf{w}}_i w~i 发送给server。

  • server 接收到全部的 w ~ i \tilde {\mathbf{w}}_i w~i 之后就对他们做一个平均或者是加权平均。把这 m 个 worker 节点的输出 w ~ 1 \tilde {\mathbf{w}}_1 w~1,一直到 w ~ m \tilde {\mathbf{w}}_m w~m 加起来除以m,这样得到的平均就作为新的模型参数 w \mathbf{w} w。下一轮迭代的时候再把这个新的 w \mathbf{w} w 发送给所有的 worker 节点。

实验对比 FedAvg 和 Grad Descent

Ozkqdv.png
  • 横轴:通信次数,纵轴:损失函数,值越小越好。

  • Fedavg 算法的收敛曲线在下面。这说明 FedAvg 让 loss 下降得更快,用相同次数的通信, FedAvg 收敛要快一些,这正是 FedAvg 有用的原因。FedAvg 被设计出来就是为了使用更少次数的通信达到收敛。两次通信之间, FedAvg 让每个 worker 节点做大量本地计算,以牺牲计算量为代价换取更少的通信次数

OzkTNc.png
  • 横轴:number of epochs,所有的 worker 都把自己本地数据扫一遍为一个epoch,因此 epoch 可以用来衡量计算量的多少。
  • FedAvg 的收敛曲线在上面。相同epochs下,FedAvg 的损失函数大于梯度下降的损失函数,这意味着让worker节点做相同的计算量,那么 FedAvg 的收敛比梯度下降要慢。
  • FedAvg 减少了通信量,但是增加了 worker 节点的计算量,这就是以牺牲计算量为代价换取减少通信量。但是,联邦学习中的计算代价小,通信代价大,所以 FedAvg 这种算法还是很有用的。

我们证明了 FedAvg 收敛而且不需要独立同分布的假设,理论上保证了可以使用 FedAvg 做联邦学习。

2. Privacy 隐私保护

回顾一下分布式机器学习或者是联邦学习的架构和算法:

  • 梯度下降算法中,server 把模型参数发给worker,worker算梯度,然后把梯度传回server。
  • FedAvg 不是把梯度传回server,而是把 worker 节点本地算出的参数模型传给server。

用梯度下降或者 FedAvg 这样的方式学习,被传来传去的只有模型参数和梯度,用户的数据并没有离开用户的设备。这样看来分布式学习和联邦学习都是安全的,用户的数据隐私被保护了

实际上,随机梯度是用户本地一个 batch 的数据算出来的,算梯度的时候其实就是用一个函数把用户的数据做了个函数变换,把数据映射到了梯度。虽然数据没有离开用户设备,但是梯度被传出去了。**梯度几乎携带了数据所有的信息,所以使用梯度是可以反推出来数据的。**用户的数据被间接泄露出去。

说到隐私保护,所有人都能想到用 differential privacy 来保护隐私。 differential privacy 其实就是加噪声,通常是往梯度里加噪声,往模型参数里加噪声也可以。可惜实验证明加噪声是不行的,要是噪声不够强,还是能反推出来用户的数据。要是噪声太强了,学习的过程就继续不下去了,损失函数不再往下降。

往梯度或者模型里加噪声,收敛的速度和最后测试的准确度都会变差。加到噪声越多,机器学习的效果越差。联邦学习的隐私泄露很容易,但是想保护隐私就比较困难。

3. Adversarial Robustness

第三个研究方向就是联邦学习的鲁棒性,让联邦学习可以抵御拜占庭错误和恶意的攻击。

什么是拜占庭错误? 针对分布式系统中出现异常节点的情况。拜占庭将军问题是个比喻,简单的说就是我们中出了个叛徒,我们一起商量工程,但是有个叛徒在忽悠我们,让我们去执行错误的战术,导致我们最后被团灭。

这是个传统的分布式系统的问题,要是有个节点故障了,但是没有挂掉这个故障节点就会给其他节点发送错误的信息,有可能把整个都给带到沟里去。联邦学习也存在拜占庭将军问题,因为真的会出叛徒,要是有个节点故意使坏,把自己的数据和标签做修改,那么传给 server 的梯度就是有害的,可能会让学到的模型犯错误,参考文献一提出了 data poisoning attack。文章发表在 NIPS 2018 上,意思就是把一部分作为训练数据的图片做一些小幅的修改,但这些修改是精心设计出来的扰动,加了扰动之后,这个图片样本就变成了毒药,可以用来对模型下毒。如果训练模型的时候用到这些毒药,模型就会犯一种很特别的错误。

Ozkz9I.png
  • 参考文献[1]是针对普遍的深度学习的,这种 data poisoning attack 用到联邦学习上,这是很容易做到的。worker节点只要能看到模型参数,就能把自己本地的图片变成毒药。worker节点用毒药计算出来梯度,然后给传回去,会让 server 上的模型犯某种特定的错误,或者留一个后门。

  • 参考文献[2]设计一种 model poisoning attack,专门针对分布式机器学习。就是把本地的数据标签换成错的,用正确的图片和错误的标签来计算梯度方向,再把梯度方向发给一个server。把标签换成错的,肯定会干扰分布式学习。

有了攻击自然会有人研究防御。

  • Defense1是 server 检验 worker 传回来梯度到底好不好。 server 会拿某个 worker 传回来的梯度来更新模型参数,然后算一下用新的参数在测试集上的准确率。如果某个 worker 恶意传回错误的梯度,肯定会造成测试准确率的下降。但是 worker 上的数据跟 server 上存的数据统计分布不一样,联邦学习中的 server 是不能看到用户数据的。即使 worker 是诚实的,单个 worker 算出来的梯度也可能会让 validation accuracy 变差。只有把所有的 worker 的梯度平均起来,更新模型参数, validation accuracy 才能变得更好。我觉得这种防御不太适用于联邦学习。

  • Defense2 是 server 检验 worker 传回的梯度。假设数据是独立同分布的,那么所有 worker 算出来的梯度都不会差太远。比较每个 worker 传回的梯度,要是有个别梯度跟其他梯度差距很大,就可以认为相应的 worker 是异常的。因为联邦学习中的数据都不是独立同分布的,每个 worker 上的数据本身性质就很不一样,所以用不同数据算出来的梯度自然也很不一样,所以不太有用。

  • Defense3 是 Byzantine-tolerant aggregation,server 不对算出来的梯度做加权平均,而是用 median 等更稳定的方法来整合梯度,比如说把平均 mean 换成中位数median,但实际上也可能不work。这些文章都在假设数据是独立同分布的。

总而言之,攻击比较容易,但是防御很困难。

总结

  • 联邦学习的目标是让多个用户一起协同训练出一个模型,但是这些用户不共享数据,用户的数据不能离开用户本地。如果用户的数据被交给第三方,然后做一个并行的训练,这个不叫联邦学习。联邦学习很重要的一点就是要保护用户的数据隐私。
  • 联邦学习是一种分布式学习,但是联邦学习有一些特点,导致联邦学习比普通的分布式学习要困难。重点讨论了两个难点,一个是数据不是独立同分布,另一个是通信很慢。

联邦学习有哪些值得研究的问题?

OzzxBD.png
  1. 首先就是要解决算法效率的问题。由于用户移动设备的响应慢,通信速度慢,每一次都要花很长时间才能等到大部分用户把结果发回来。假设一个小时能等到一半用户响应,我们才能做一次迭代更新,那么一个星期就能迭代 100 来次。如果一个算法要迭代几千几万次才能收敛,这个算法就没有办法被用在联邦学习上,所以要设计 communication-efficient algorithms,争取几十次通信就能训练出模型。我重点讲了 federated averaging。这个算法是以牺牲更多计算为代价,以换取通信次数的减少。其实 communication-efficient algorithms 也不是很新的想法,至少七八年前就人在研究怎么样减少通信次数了。但是绝大多数的算法只能用在独立同分布的数据上,不能用在联邦学习的数据上。

  2. 第二个研究方向是保护隐私,阻止别人通过梯度或者模型参数来反向推断出用户数据。联邦学习的算法号称保护隐私,因为数据没有离开用户设备,server看不到用户端的数据,但遗憾的是已有的算法并不能保护隐私。用传出去的梯度或者模型参数可以反过来推断用户数据的性质。已经有几篇文章验证这种推断是可行的,而且很容易做到。攻击很容易,但是防御非常困难。

  3. 第三个研究方向就是算法鲁棒性,有多个用户参与学习,每个用户都不受控制,这就有可能出现所谓的拜占庭将军问题,就是我们中出了个叛徒,叛徒可以发些乱七八糟的错误信息给server,也可以精心设计出有毒的样本或者有毒的梯度来毒害模型。如果数据是独立同分布的,就有办法检测出异常的 worker 节点,可以发现叛徒也有办法抵御攻击。但遗憾的是,联邦学习的特点就是数据不是独立同分布,所以这些攻击很难被发现,也不容易防御。提高联邦学习的鲁棒性是个值得研究的方法,但并不容易做到。

猜你喜欢

转载自blog.csdn.net/qq_45670134/article/details/131673951