联邦学习: 统一数据协作和隐私保护的技术解决之道

将门创投
关注

上图展示了分布式机器学习的基本框架:数据和模型的切分,单机间的计算协同,单机计算结果的合并;数据和模型的切分:如何处理的大数据或大模型进行切分,在多个机器上做并行训练。单机间的计算协同:在把这些数据和模型放到多个计算节点之后就,怎样实现不同机器之间的通信和同步,使得它们可以协作把机器学习模型训练好。单机计算结果的合并:当每个计算节点都能够训练出一个局部模型之后,怎样把这些局部模型做聚合,最终形成一个统一的机器学习模型。

关于数据切分,训练中所涉及的数据和模型规模巨大,需要基于分布式的机器学习平台,部署了数十个乃至数百个并行运行的计算节点对模型做训练。

第一种模式是数据并行化训练,不同的机器有同一个模型的多个副本,每个机器分配到不同的数据,然后将所有机器的计算结果按照某种方式合并。常见的方法基于“随机采样”,随机采样可以尽量保证每台机器上的局部训练数据与原始数据是独立同分布的。但是也有明显缺点,例如计算复杂度比较高,训练样本未被选中,导致训练样本浪费。

还有一种方法是全局置乱切分的方法,该方法将训练数据进行随机置乱,然后将打乱后的数据顺序划分为相应的小份,随后将这些小份数据分配到各个工作节点上。相比于随机采样方法,其计算复杂度比全局随机采样要小很多,而且置乱切分能保留每一个样本,直观上对样本利用更充分,同时和有放回的随机采样在收敛率上是基本一致的。

数据并行化的目标是将数据集均等地分配到系统的各个节点 (node),其中每个节点都有模型的一个副本。每个节点都会处理该数据集的一个不同子集并更新其本地权重集。这些本地权重会在整个集群中共享,从而通过一个累积算法计算出一个新的全局权重集。这些全局权重又会被分配至所有节点,然后节点会在此基础上处理下一批数据。数据并行化是应用最为广泛的并行策略,但随着数据并行训练设备数量的增加,设备之间的通信开销也在增长。  另一种模式是模型并行: 系统中的不同机器(GPU/CPU等)负责模型的不同部分。常见的场景例如,神经网络模型的规模比较大,无法存储于本地内存,则需要对模型进行划分,不同网络层被分配到不同的机器,或者同一层内部的不同参数被分配到不同机器。对于神经网络这种高度非线性的结构,各个工作节点不能相对独立地完成对自己负责的参数训练和更新,必须依赖与其他工作节点的协作。常用的模型并行方法有,横向按层划分、纵向跨层划分和模型随机划分。
除去上面两类并行方式,还有一类是混合并行 (Hybrid parallelism),在一个集群中,既有模型并行,又有数据并行。例如,在最近的Optimizing Multi-GPU Parallelization Strategies for Deep Learning Training,就提到如何如何使用混合并行的方法(在每个数据并行的基础上,加入多个设备,进行模型并行)从而实现更好的加速。

分布式机器学习系统的另一个挑战是单机间的计算协同。

同步通信 vs 异步通信

通信的步调是需要考虑的问题之一,在机器学习过程中,不同机器的数据大小,机器效能,训练速度会有差异,有的机器训练速度比较快,有的机器训练速度比较慢。如果采用同步通信的方式,其他机器需要等最慢的机器完成计算,才能往前继续训练,导致整个进程受集群里最慢的机器的严重影响。

为了实现高效的分布式机器学习的效率,异步通信被广泛关注和使用。在异步通信过程中,每台机器完成本地训练之后就把局部模型、局部梯度或模型更新推送到全局模型上去,并继续本地的训练过程,而不去等待其他的机器。还有通信的拓扑结构:常见的有,基于参数服务器的通信拓扑结构和基于流程图的通信拓扑结构;

另外一点是通信的频率:通信越频繁,通信的代价就会越高,可能会降低训练的速度。常见优化包括模型压缩,模型量化,随机丢弃等等。

除了通信以外,单机的计算结果合并,并聚合成整体模型,也是一个非常具有挑战的一个问题。在业界里常用的方式有参数平均,集成模型等。例如参数平均就是把各个不同的局部模型做简单的参数平均。参数平均是最简单的一种数据并行化。若采用参数平均法,训练的过程如下所示:

参数平均比较简单,但是不适用于非凸问题时,例如深度学习。可以考虑使用模型集成。

通过上面的介绍,我们对分布式机器学习技术有了一定的认识。实际上,联邦学习本质上也是一种分布式机器学习技术/框架。或者说,是某种数据并行化训练。在了解分布式机器学习后,联邦学习将不再神秘。

各个参与方就如同不同的worker,每方保留自己的底层数据,并且无需分享这些数据给其他worker。多个worker携带自己的数据,在加密形态的前提下共建模型,提升AI模型的效果。

最早在 2016 年由谷歌和爱丁堡大学学者提出,原本用于解决安卓手机终端用户在本地更新模型的问题,Google案例中的联邦学习过程:1. 设备端下载当前模型;2. 基于本地数据来更新模型;3. 云端整合多方更新,迭代模型。

从上图我们可以看到,客户端(移动设备)负责模型更新和上载,需要具备一定的计算资源。同时,这里还有很多优化点,包括训练时间,训练频率等。

联邦学习和传统分布式系统的差异

正如我们前面所说,联邦学习本质上也是一种分布式机器学习技术/框架,但是他和和传统分布式系统还是存在明显差异:Server与Worker的关系:分布式系统中Server对Worker是完全所有权/控制权,但是在联邦学习中是是相对松散的,并没有所有权/控制权。这也导致了额外的挑战。不同设备数据的差异性:分布式机器学习的每个节点基本上是均衡的,而在联邦学习的架构里,Worker节点与节点之间的数据差距是无法保证的。在联邦学习架构中,不同worker,还有server不是在局域网内部,需要消耗大量的网络,并需要考虑网络的稳定性。

声明: 本文由入驻OFweek维科号的作者撰写,观点仅代表作者本人,不代表OFweek立场。如有侵权或其他问题,请联系举报。
侵权投诉

下载OFweek,一手掌握高科技全行业资讯

还不是OFweek会员,马上注册
打开app,查看更多精彩资讯 >
  • 长按识别二维码
  • 进入OFweek阅读全文
长按图片进行保存