不止于ZeRO:BMTrain技术原理浅析

作者:赵威霖、韩旭
2022-06-20 18:17:00

前期我们发起了 CPM-Live 开源大模型直播训练,与现有的大模型训练使用百余张显卡相比,我们实现了 8 张 A100 显卡 训练百亿大模型。这优异效果的背后基于的是 大模型高效训练工具 BMTrain 和 模型仓库 ModelCenter。与现有框架相比,BMTrain 能够实现大模型的低资源、高效训练,并且简单易用,便于开发者上手。

支撑起 BMTrain 优异性能表现的是其采用的多项分布式训练优化技术,它们共同解决了大模型训练过程中的 显存占用 问题。为了深刻理解这一关键问题,我们不妨分析一下模型训练过程中的显存占用情况。

模型训练中的显存占用主要包括:模型参数、模型梯度、优化器状态、运算中间变量。以下图为例,训练过程中的显存占用包括一份模型参数以及对应的一份梯度,比较常用的 Adam 会保留两倍参数量的优化器参数,除此之外还有一些运算的中间变量。

根据上述分析,对于一个百亿参数大模型,模型参数约 20G,训练过程中需要占用的显存就会超过 80G,在每一张显卡中都完整地维护这些内容,显存是远远不够的。这就需要我们采用相关分布式训练技术,进行模型训练的显存优化。

为解决这一关键问题,在 BMTrain 中,我们通过 数据并行 降低运算中间变量显存占比、增大吞吐量,通过ZeRO降低模型参数、模型梯度、优化器状态的显存占比,通过Optimizer Offload将优化器状态卸载到内存上,通过Checkpointing算子融合避免储存运算的中间变量,最后使用通信计算重叠进一步降低整套系统时间花费。

综合使用这些技术,BMTrain 可以实现 单张消费级显卡全参数微调 BERT-Large8 台 A100 小集群训练 GPT-3,在超大规模模型训练场景下与 DeepSpeed 等框架相比最多可节省90%的算力成本。想了解这些技术具体的细节吗?本文来带你一探究竟!

 

背景知识

分布式训练的核心精神是切割,将数据、参数等诸多要素切割到不同计算节点上进行运算。有切割就有合并,不同节点之间会频繁通信以同步及汇总计算结果。

这里我们简单介绍 5 个基本通信算子,这是分布式训练框架的重要基础(以四张显卡为例,由 rank0 到 rank3 表示):

01 Broadcast

张量位于某张显卡中,广播后,每张显卡都会获得一个同样的张量。

02 Reduce

每张显卡中存有一个张量,将这些张量进行如求和、取max等计算后,其结果被置于指定的某张显卡上。

03 All Reduce

每张显卡中存有一个张量,使用它们进行相关计算后的结果被置于所有的显卡上,各张显卡上得到的结果相同。

04 Reduce Scatter

每张显卡中存有一个大小为 4d 的张量,张量之间进行计算后的结果被平均切分为 4 份,每份的大小为 d,分别置于 4 张显卡上。

05 All Gather

每张显卡中存有一个大小为 d 的张量,收集后,张量拼接的结果 (大小为 4d) 被置于所有的显卡上,各张显卡上得到的结果相同。

分布式训练

一种典型的分布式训练方法是使用数据并行,然而对于大模型来说,仅通过数据并行进行显存优化是远远不够的,我们需要更进一步地进行切割。进一步优化的技术主要来自两大技术路线:在算子层面进行切割的 模型并行流水线并行技术 以及在显存上进行切割的 ZeRO技术。在BMTrain中,我们采用了 数据并行 ZeRO技术 来进行模型的分布式训练,并将陆续支持模型并行与流水线并行。

数据并行

数据并行通过减小每张显卡上需要处理的 batch 大小来减少模型的运行中间变量。具体来说,假设有n张显卡,那么每张显卡可以只去处理 batch_sizen{\text{batch_size}\over n} 的数据,最后将各张显卡计算得到的梯度进行求和 ( all-reduce ) 即可。在这种方式中,每张显卡都会获得完整的梯度信息,最后每一张显卡上分别执行优化器的 step。

- 采用数据并行策略,原模型训练需要的运算中间变量被划分到不同显卡中。图中以八卡并行为例,后面各图也采用相同的设定

模型并行

模型并行技术尝试将模型计算进行切割。以全连接层为例,对于计算 yA=WA×BxB\mathbf{y}_A=W_{A\times B} \mathbf{x}_B ,通过将参数矩阵分解为n个小矩阵 WAn×B(i)W_{{A\over n}\times B}^{(i)} ,每张显卡上计算 yAn(i)=WAn×B(i)xB\mathbf{y}_{A\over n}^{(i)}=W_{{A\over n}\times B}^{(i)} \mathbf{x}_B , 然后通过 all-gather 通信即可获得完整的结果 yA\mathbf{y}_A 。在这种方法中,各张显卡均处理同一批次的数据,在计算时进行合作。

- 采用模型并行策略,模型参数被划分到不同的显卡中

与模型并行类似的一种解决思路是流水线并行,也是尝试对训练计算进行切分。相比于模型并行中对transformer 模型进行纵向的计算切分,流水线并行则将不同层的 transformer block 计算划分到不同的显卡上。

 

ZeRO

在实际训练中,优化器 ( 如 Adam ) 状态占用的显存要比参数和梯度二者加起来还要多,因此 ZeRO(Zero Redundancy Optimizer,零冗余优化器)技术首次提出对优化器状态进行切分,每张显卡上只负责优化器状态对应的部分参数的更新。训练策略上,ZeRO 基于数据并行,不同的数据被划分到不同的显卡上进行计算。根据对优化器状态、梯度、参数划分程度的不同,ZeRO 技术包含 ZeRO-1/2/3 三个层次。

ZeRO-1

因为 ZeRO 基于数据并行,首先需要通过 all-gather 操作获取完整的模型参数更新结果,随后每张显卡根据自己的数据和模型参数完成对应的前向传播和反向传播。在整个过程中,梯度和参数均完整地保留在每张卡上,随后对梯度进行 reduce-scatter,每张卡根据自己所划分的优化器状态和梯度来计算对应部分的模型参数。

- 基于 ZeRO-1 和数据并行,优化器状态和运算中间变量被划分到不同的显卡中

ZeRO-2

ZeRO-2 在 ZeRO-1 的基础上进一步对梯度进行划分。注意,由于在反传的过程中,不需要始终保留完整的梯度,在计算当前层梯度时,只需要后一层输入的梯度。因此在反传的过程中,对于不参与后续反传计算的梯度,可以立即 reduce-scatter 划分到多块卡上,这样在训练过程中,梯度在每块卡上的显存占用,就变为原先的 1/n1 / n 了。反传结束后,每块卡再根据部分的梯度和优化器状态,计算得到更新后的模型参数,最后再将更新后的参数使用 all-gather 同步到其他的显卡上。

- 基于 ZeRO-2 和数据并行,梯度、优化器状态和运算中间变量被划分到不同的显卡中

ZeRO-3

而 ZeRO-3 技术,则是更进一步将模型参数部分进行切分。由于每张显卡只有一部分的优化器状态,只更新一部分的参数,一个很直观的思路就是每张显卡上只维护优化器需要更新的那一部分参数。然而,在模型的计算过程中,还是需要完整的模型参数。因而在 ZeRO-3 中,模型中的每个模块在计算之前,都需要通过一次 all-gather 操作将参数恢复完整,并在前向计算结束后再将模型参数释放掉。进行反传时,再重新使用 all-gather 获取参数计算梯度并使用 reduce-scatter 划分梯度,如下图。

通过使用 ZeRO-3 优化,训练相关的所有信息均被切碎分散到不同的显卡上,让每张显卡上的显存占用都被降低到极致,使得每张显卡上可以容下更大的 batch_size,更充分地利用计算核心,带来更大的模型吞吐,同时将训练模型所需的显卡数量降至最低。

- 基于ZeRO-3和数据并行,参数、梯度、优化器状态和运算中间变量被划分到不同的显卡中

不过在 ZeRO 的原论文中指出, ZeRO-3 增加了额外的一次参数通信时间(即反向传播时的 all-gather ),因此会引入额外的通信开销,在部分场景下性能不及 ZeRO-2 和模型并行。为了减少额外通信量带来的效率损失,我们还额外引入了通信计算重叠的策略,这将在后面被介绍到。根据我们的实现,实验结果表明 ZeRO-3 在 NVLink+IB 的环境下训练超大规模模型较联合使用 ZeRO-2 和模型并行的方案会带来更大的计算吞吐量提升。

显存优化

除了上述分布式训练方法外,BMTrain还通过 Optimizer Offload 和 Checkpointing 技术进一步减少冗余的显存占用,并以牺牲最少的通信代价为前提,做到了在极致显存优化下仍然能高效率地训练。

Optimizer Offload

Optimizer Offload 是指将优化器状态从 GPU 卸载到 CPU 上,从而进一步节省显存。我们以 Adam 优化器为例介绍为什么需要将优化器的参数卸载。

在 Adam 中,优化器需要维护梯度的移动平均以及梯度平方的移动平均:

mt=β1mt−1+(1−β1)gtvt=β2vt−1+(1−β2)gt2\begin{aligned} m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t\\ v_t &= \beta_2 v_{t-1} + (1-\beta_2) g_t^2\\ \end{aligned}

正如前文所示,与模型参数相比, Adam 优化器需要至少两份的显存占用量,这在混合精度训练中是一笔非常大的开销。通过使用 ZeRO-3 的梯度切分,每张计算卡上的需要处理的梯度信息大幅减少,将这一部分 GPU 计算卸载至 CPU 上产生的通信需求较小,同时 CPU 处理这样切分后的梯度也不会特别吃力。据此,我们付出了极小量的额外开销就将显存开销降低至原本的一半左右。

- 图 Optimizer Offload 技术

Checkpointing

Checkpointing 技术是一项很早就被提出,用于优化神经网络模型训练时计算图开销的方法。这种方法在 Transformers 等结构的模型训练中,能够起到非常明显的作用。目前主流的 Transformers 模型由大量的全连接层组成,我们以全连接层为例进行计算图的显存分析。

Forward: y=WxBackward: ∇x=WT∇y∇W=(∇y)xT\begin{aligned} \text{Forward: } & \mathbf{y} = W \mathbf{x}\\ \text{Backward: } & \nabla \mathbf{x} = W^T\nabla y \\ & \nabla W = (\nabla \mathbf{y}) \mathbf{x}^T \end{aligned}

为了能够在反向传播中计算梯度,需要在正向传播时记录下参数矩阵 WW 与输入 xx ,这两部分参数随着正向传播逐层累积,消耗了非常多的显存。

因此,我们使用 Checkpointing 技术(也称为亚线性内存优化),其核心方式是通过时间换空间,我们在模型各层之间设置检查点,只记录每一层模型的输入向量。在反向传播时,根据最近的 checkpoint 重新计算该层的局部计算图。

框架实现的优化

除了上述显存优化技术外,BMTrain 还在具体实现上进行优化,以期得到更好的加速效果。

混合精度

传统模型使用单精度参数进行训练,在大模型训练中,我们可以通过使用半精度参数来降低参数量并节省运算时间。具体实现上,BMTrain 在正向传播和反向传播的过程中均使用半精度进行计算,并在优化器中维护单精度的模型参数和优化器参数。

使用混合精度的另一个好处在于能够更好地利用显卡中的 tensor core。较新的显卡在 CUDA core 之外,还设置了专门用于张量运算的核心 tensor core,利用 tensor core 将为程序带来进一步的性能提升。使用混合精度训练能够更好地利用 tensor core 特性,从而为训练过程进一步加速。

算子融合

为了进一步提升性能,我们在 CPU 和 GPU 层面均进行了算子层面的实现优化。在 CPU 上,我们使用多线程 + SIMD(单指令流多数据流) 的 CPU 编程方式,对 Offload 至 CPU 计算的 Adam 优化器进行 CPU 上的计算加速,使其不会成为系统的性能瓶颈。在 GPU 上,我们使用算子融合的方式,将 Softmax 与 NLLLoss 算子合二为一,减小了中间结果的显存占用。

通信计算重叠

上文中提到,ZeRO3 技术将引入额外的通信时间,我们采用通信计算策略来进行通信时间的优化。以反向传播为例,由于使用了 ZeRO-3 技术,需要将切碎至各个计算卡上的模型进行临时的重组装(对应图中的 Gather );而在反向传播 ( 对应图中的 Calculate ) 之后,我们还需要将得到的局部梯度重新切碎至不同的计算卡上(对应图中的 Scatter )。我们通过不同的 CUDA stream 区分不同的操作,让运算和通信得以同时运行,通过大量的计算时间隐藏通信的时间开销。

 

- 图 通信计算重叠

性能展示

综合使用上述技术,BMTrain 在大模型训练上效果出色,在不同规模的算力条件下均有较好的性能表现。

在单卡 2080Ti 上,BMTrain 可以实现 transformers 库无法实现的3 亿参数 BERT-Large微调。

在单卡 V100 上,BMTrain 训练 3 亿参数 BERT-Large 较 transformers 实现能够提高约 20 倍 batch size2.5 倍吞吐量

在单机 8 卡 A100 环境下,BMTrain 训练 130 亿参数的 GPT 较 Deepspeed / veGiantModel 实现能够提高约 4 倍 batch size1.6 倍吞吐量

在多机 8 卡 A100 环境下,BMTrain 可以使用较少 GPU 训练 1750 亿参数的 GPT-3,性能详见下表:

使用 BMTrain,64 张 A100 跑完 GPT-3 的 300B token 大概需要 2 年,服务器与显卡租金大约 900 万人民币左右。根据我们的实验估算,使用 128 张 A100 时,单卡吞吐量可以提升 2.5 倍以上,6 个月可以跑完 GPT-3,服务器租金大约 500 万人民币左右。虽然训练出 GPT-3 的成本依然高昂,但与 GPT-3 的 1200 万美元相比,成本仍然 节约了 90%以上

未来展望

本章我们主要介绍 BMTrain 中的基础加速算法,未来 BMTrain 将持续关注大模型的高效训练和性能优化,并开展工具包的进一步优化与升级,相关技术报告也会持续公开发布。同时,我们诚挚欢迎感兴趣的研究人员与开发者加入我们的开源社区,参与相关的研究交流、技术研讨与工具开发,共同为大模型的落地与应用添砖加瓦!

关注我们

▶ 官方网站:openbmb.org

▶ 交流QQ群:735930538

▶ 启智社区:git.openi.org.cn/OpenBM

▶ GitHub:github.com/OpenBMB

▶ 微博:weibo.cn/OpenBMB

▶ Twitter:twitter.com/OpenBMB

 

附录 参考文献

1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He.ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.

2. Zhengda Bian, Hongxin Liu, Boxiang Wang, et al.Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training.

3. Adam Paszke, Sam Gross, Francisco Massa, et al.PyTorch: An Imperative Style, High-Performance Deep Learning Library.

4. Zhengyan Zhang, Xu Han, Hao Zhou, et al.CPM: A Large-scale Generative Chinese Pre-trained Language Model.

5. Zhengyan Zhang, Yuxian Gu, Xu Han, et al.CPM-2: Large-scale Cost-efficient Pre-trained Language Models.

6. Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.

7. Colin Raffel, Noam Shazeer, Adam Roberts, et al.T5: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.

8. Alec Radford, Jeffrey Wu, Rewon Child, et al.GPT2: Language Models are Unsupervised Multitask Learners.

9. Ben Wang and Aran Komatsuzaki, et al.GPT-J from EleutherAI released in the repo mesh-transformer-jax.

10. Diederik P. Kingma, Jimmy Ba.Adam: A Method for Stochastic Optimization.

11. Yang You, Jing Li, Sashank Reddi, et al.Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.

12. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, et al.1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed.

13. NCCL:docs.nvidia.com/deeplea