最近传言Chat GPT也是MoE,想起之前做profile的经历,记录一下。

MoE稀疏地被激活特定的线性层,并且需要分别对不同的worker scatter,计算完成后再gather,需要两次同步。虽然FC只是非常小的计算量,在MoE的体量下可以忽略负载不均衡的问题,但是两次同步之间的计算量较大,当整个通信组的变大的时候,不同rank之间的速度极差也有概率变大,在64/128卡的时候几乎可以闲置接近一半。

如果想让这个overhead减小,我能想到的方案有这几个:

  1. 将整个训练的group切分成多个,每个subgroup分别进行route,之后定期同步不同的subgroup。(有点像联邦学习)问题是每个subgroup的batch size都被限制住了,这样会不会造成不稳定的情况?从之前简单摸FL的经历来看,检测、分类模型都会有退化现象。
  2. 能不能通信的时候算一些别的有意义的事情?比如将experts设置成一个分支,通信的时候把另外的不需要通信的分支算完,差不多之后通信也都完成了。
  3. 把现在torch的if/else代码改到cuda上

后两者都只能减少最少的overhead,而第一个又会带来精度问题,好奇OpenAI是怎么scale到数千个worker的。