聊一聊我们最近开源的 Ling 2.0 原生 FP8 混合精度训练

TL;DR

本次开源的 Ling 2.0 原生使用 FP8 精度进行训练,在不断追求 FP8 极致性价比的过程中,我们有了以下成果:

  1. 几乎无损的模型效果:更细的量化粒度能够避免 outlier 对整体量化精度的影响,从而保障了 loss 的收敛趋势与 bf16 训练一致。此外,细粒度量化也有利于 backward 使用更高精度的 FP8 E4M3 而非 E5M2,进一步降低了整体数值精度的损失。
  2. 更优的框架执行效率:FP8 带来的显存优势,能够为模型切分、重计算等手段带来更大的发挥空间,从而进一步提升整体的训练吞吐量。在 FP8 tile/block wise scaling 的基础上,我们进一步引入了 FP8 optimizer、FP8 on-demand transpose weight、FP8 padding routing map 等优化技术。最终,在 8/16/32 x 80G GPUs 的训练性能测试中,相比于 LLaMA 3.1 8B 与 Qwen3 8B,Ling-mini-2.0 在开启 MTP 时可获得 30~60% 吞吐提升,关闭 MTP 时则达到 90-120% 吞吐提升。

 



 

近年来,大模型技术发展迅速,模型参数规模与训练数据量呈指数级增长,对计算资源、内存带宽和能耗提出了空前挑战。降低计算与存储开销成为提升模型研发效率、降低推理部署成本的关键路径。低精度训练在降低显存占用的同时也能够提供更高的计算效率,但更少的 bit 位无法精确表示数值导致的 loss 收敛慢、榜单效果降低则是需要解决的新问题。

本次开源的 Ling 2.0 系列模型原生使用 FP8 混合精度训练,我们也同步开源了全套 FP8 训练方案。这是业界首个在 MoE 模型上支持 FP8 混合精度训练的完整开源方案,开箱即用。

 

如何缓解 FP8 精度劣势

精度担忧从何而来?

FP8,即使用 8 个 bit 位来表示一个浮点数。相比于 FP32/FP16/BF16,8 bit 意味着在浮点数的指数位(exponent)或尾数位(mantissa)需要做一些取舍,这直接影响了 FP8 能够表达的数值范围以及数值数量。

如下图,当前社区广泛使用 FP8 E4M3 或 FP8 E5M2 这两种 FP8 数值类型,其中 E4M3 相比于 E5M2 有着多一位尾数来提升精度、但少一位指数而导致更小的动态范围。

相比于 BF16(E8M7),FP8 更小的动态范围以及更低的数值精度会导致训练过程中出现 loss 收敛速度降低(算不准)、所产出模型评测效果变弱等问题。因此,想要表示模型训练过程中的weight/activations/grad 等 tensor 信息,通常需要 FP8 tensor 与 scale tensor 协作来完成。其中 scale tensor 是一个远小于 FP8 tensor 尺寸但具有更大动态范围的矩阵/实数。

我们将一个原始 BF16 tensor 转为 FP8 tensor + scale tensor 的过程称为“量化”,将其逆过程称为“反量化”。在这样的模式下,现有模型中的 BF16 Linear 层即可使用如下图所示的方案转换为 FP8 Linear。

 

哪里引入了误差?

在低精度训练的探索中,我们将低精度所导致的误差来源归因为以下几方面:

  1. FP8 量化误差(矩阵非零元素在量化后变为 0 的现象)
  2. FP8 量化失真(矩阵相近数值在量化后相等的现象)
  3. FP8 gemm 精度差异

其中,在我们看来,使用 fp32 精度维护 accumulate(累加器)的 gemm 算子在第 3 点中所导致的误差可忽略不计,真正影响训练收敛性的根因主要在于 1、2。因此,为了安全护航 Ling 2.0 系列模型的训练,我们开发了 FP8 量化失真/下溢以及高精度重算误差的监控,在出现精度预警时及时推送相关信息并介入处理。

 

Ling 2.0 训练方案

粗粒度量化(per-tensor/channel wise scaling)

早期 FP8 训练方案通常采用 per-tensor wise scaling 或 per-channel wise scaling 的量化方式,其量化 scale 易受到矩阵离群点的影响而导致量化误差加剧,进而影响训练的 loss 收敛。在大规模数据训练后期,往往会出现 loss 收敛速度变慢甚至 loss 上涨的问题,且该现象难以通过提升精度的方式快速进行矫正,这使得此类 FP8 方案难以保障超大规模 LLM 模型的训练。

 

细粒度量化(per-tile/block wise scaling)

参考 Deepseek v3,Ling-mini-2.0 同样采用更细粒度的 tile/block wise scaling 方案,该方案通过对矩阵进行分块(tile/block)并为每块分别维护各自的 scale,很大程度避免了离群点对于全局量化误差的影响,这使得在 FP8 精度下训练超大规模 LLM 模型成为可能。

 

Ling-mini-2.0上的对比实验

BF16 v.s. FP8 loss 差异

基于 Ling 2.0 的 FP8 混合精度训练方案,下图展示了我们在 Ling-mini-2.0 架构下使用 BF16 与 FP8 精度训练的 loss 差异。FP8 相比于 BF16 在 loss 稳定收敛后全程保持在 0.001 左右的相对误差,且 loss 差异无上涨趋势。

 

FP8 下溢/失真 监控指标

从 FP8 下溢/失真监控来看,activations 与 grad 在训练全程均保持着很低的下溢比例,这保障了正向传播(forward)与反向传播(backward,dx)FP8 计算结果的可靠性。

而对于反向传播过程中计算 dw 所依赖的 grad (transpose),监控显示出尾层有较高的量化误差。基于该现象,我们从许多角度完成了分析与验证,判断该误差不会对模型训练造成明显的影响,且 dw 本身处于反向传播过程中的叶子结点,量化误差不会随反向传播而层层累积。

PS:

➡ FP8 作为 Ling Team 首次在语言模型下的低精度训练尝试,每个人心中都有类似这样的疑虑:“FP8 是否真的不影响模型效果、下溢比例上涨是否意味着模型要训坏了、FP8 训练太久了要不要用 BF16 回测一段时间看是否真的没问题”。

⬅ 在我们的工作中,除了确保 loss diff 在“历史实验”下是安全的,也需要确保“当下实验”在无基线的训练过程中依然也是安全的。因此,我们对 loss、榜单等指标与 FP8 下溢/失真指标 进行相关性分析,以及对于所提取的 下溢/失真 严重的 tensor 进行高精度 dw 重算与比对,以上分析得出的“健康训练”逐渐平复了大家心中的疑虑。

量化下溢(越小越好) 量化失真(越大越好)
activations
grad
grad (transpose)

 

如何用好 FP8 性能优势

Motivation

在有了精度可靠的训练方案后,我们开始着眼于 FP8 所能带来的性能优势。考虑到大部分个人用户的计算资源都不会很充裕,我们希望通过利用 FP8 更低显存占用 的优势,结合 CPU overhead 优化 两方面来提供一种有效的混合精度训练技术。进而让用户在有限计算资源(8~32卡 80G 显存 GPU)下对 Ling-mini-2.0 进行加训时,可获得比 10B 以下的 dense 模型吞吐更大的训练性能:

  • 更低的显存占用:相比于 BF16,FP8 所节省一半的显存可以为 micro batch size / TP、PP 切分/重计算 等策略带来更大的调整空间。而更细粒度的 FP8 量化方法能够保障 LLM 安全平稳的训练,也意味着现有网络的 weight、activations、grad、optimizer 等数值均有可能在 FP8 精度下做低损的“压缩”与“解压缩”,从而可以提供一种时间换空间的手段。
  • CPU overhead 优化:现有 Megatron + Transformer Engine 训练框架下使用 FP8 recipe 进行训练会带来许多额外的操作,例如 FP8 量化/反量化、FP8 padding/unpadding,更多的 assert 校验等,这些都是潜在影响整体训练效率的因素。

 

更优的训练效率

FP8 Optimizer

Adam 优化器为每个参数维护两个额外的状态变量(一二阶动量),其显存占用是参数的 2 倍,当模型增大且卡数有限时会成为显存瓶颈。使用 FP8 表示一二阶动量,可以节省 75% 的优化器显存占用,大幅度的降低显存压力,使得在有限的硬件显存限制下,支持更大尺寸模型的训练或提升 micro batch size 来加速训练。

经验证,在细粒度“压缩”与“解压缩”的方式下,使用 FP8 精度维护一二阶动量对模型训练的收敛性和榜单效果几乎没有影响。这可能归因于一二阶动量对数值精度不敏感,量化误差可以被动量的平滑特性所缓解。

方案思路可参考:8-bit Optimizers via Block-wise Quantization

 

FP8 on-demand transpose weight

受限于 FP8 Tensor 较低的转置效率,原有 transformer engine 实现中为了提高效率,会额外缓存了一份 weight.T 用于 backward 的计算,导致整体并未节省 weight 所占显存。在该背景下,我们通过开发更快的转置算子,实现了按需转置,去掉了这部分额外的转置权重,从而节省了一半的 weight 显存占用。

该项优化已提交至社区:blockwise fp8 weight memory optimization: on-demand columnwise fp8 weight creation

 

FP8 padding routing map

FP8 gemm kernel 要求运算矩阵 shape 是 16 的倍数,而在 MoE 模型中很难保证分发到每个 expert 的 token 数量均符合要求。原有 Megatron 框架通过额外的 padding/unpadding layer 在解决该问题的同时引入了额外的 CPU overhead。为了减少不必要的操作耗时,我们通过编辑 routing map 的方式在路由分发前便保证了预期矩阵 shape 与 FP8 gemm kernel 的兼容性,而相应被编辑区域由于 router probs 为 0 同时也保证了整体计算结果的数学等价性,从而进一步提升了模型的训练效率。

该优化已合入 Megatron:FP8 padding optimization of MoE models by padding the routing map.

 

Ling-mini-2.0 实测性能

基于上述提到的 FP8 方案,在单机 8 卡场景下,相比于 BF16 基线可节省单卡显存占用约 14-16GB。利用空余的显存可以调大 micro batch size,从而提升整体训练的吞吐。在 8/16/32 * 80G GPUs 的性能测试中,相比于 LLaMA 3.1 8B 与 Qwen3 8B,Ling-mini-2.0 在开启 MTP 时可获得 30-60% 吞吐提升,关闭 MTP 时则达到 90-120% 吞吐提升。

Model 8 x 80G GPUs (GBS=128) 16 x 80G GPUs (GBS=256) 32 x 80G GPUs (GBS=512)
LLaMA 3.1 8B (baseline) 81222 161319 321403
Qwen3 8B 55775 (-31.33%) 109799 (-31.94%) 219943 (-31.57%)
Ling-mini-2.0 109532 (+34.86%) 221585 (+37.36%) 448726 (+39.61%)
Ling-mini-2.0 w/o MTP 128298 (+57.96%) 307264 (+90.47%) 611466 (+90.25%)

 

一些碎碎念

在低精度训练探索的过程中,我们需要有一种更具“性价比”的链路来替代 BF16。这里的性价比简单来讲,即:同等训练耗时下取得更优的 loss 或榜单。

虽然看起来在不同角度扣一扣总能有一些性能提升,但实际做下来,我们发现对性能的优化是一个长期且漫长的过程,若是精度有可观测的损失,那便需要更大且更长周期的优化投入才能抹平这些损失。因此,对于 Ling 系列模型我们的要求非常严格,即 “精度无损”,在无损的精度下再持续地优化性能。

在当下,得益于 3D 并行策略等技术的发展,任意尺寸基于 transformer 架构的语言模型,均可以在 TP/PP/CP 等并行策略下被装载进多张 GPU 中。而低精度又额外提供一种用时间换空间的手段,因此,在有限的资源限制下,如何权衡计算效率、通信效率、重计算等因素来最大程度地提升的训练吞吐,是我们需要持续探索的。


我想对千千说~