更快的扩散采样:数值方法与基于分数的模型的碰撞
扩散模型能生成令人惊叹的图像,但速度很慢。生成一张图像需要几十次串行的神经网络前向传播——每次都是完整的 U-Net 前向传播。DPM-Solver++ 把这个数字降到了 10-20 步并保持合理的质量,是目前的最先进方法。但如果我们能借用科学计算社区几十年来使用的技术,是否能做得更好?
我一直在构建一个框架来测试这个问题。实验还没有运行(在等 GPU 集群时间),但代码已经写好,基线已经设置好,我想讲讲我在尝试什么,以及为什么我认为其中一些想法确实有胜算。
每个扩散模型内部隐藏的 ODE
从数学上讲,从扩散模型中采样就是求解一个常微分方程 (ODE)。Song 等人(2021)的概率流 ODE (probability flow ODE) 的形式如下:
其中 是预测噪声的神经网络, 是噪声调度, 是时间 时的噪声水平。你从 时的纯高斯噪声开始,积分到 ,得到一张干净的图像。
标准方法——DDIM、Euler、Heun——把这当作一个通用 ODE,用通用积分器求解。这可行,但忽略了这个特定方程结构中的关键信息。
线性-非线性分解
再看一眼这个 ODE。它可以干净地分解为两部分:
线性部分 ,其中 是一个标量系数乘以状态量。它有精确的解析解——不需要任何近似。非线性部分 是神经网络所在的地方,也是计算代价昂贵的部分。
DPM-Solver++ 已经在一定程度上利用了这一点。在对数信噪比 (log-SNR) 坐标 下,从时刻 到 的精确解为:
第一项精确求解了线性部分。所有数值误差都来自于对非线性部分那个积分的近似。问题是:我们能近似得更好吗?
指数积分器 (Exponential Integrators):让数学来干活
指数时间差分 (ETD) 方法来自流体动力学和刚性 ODE 文献(Cox & Matthews 2002,Hochbruck & Ostermann 2010)。核心思想是:如果你的 ODE 形如 ,就不要去近似 ——精确求解它,只近似 。
ETD1:一阶指数积分器
最简单的版本每步求一次网络,并使用指数积分器理论中的 -函数。在 log-SNR 坐标下,更新公式为:
其中 是 log-SNR 空间中的步长, 是步骤 时网络的数据预测。指数衰减部分解析处理线性部分。只有非线性部分的积分被近似,而且即便是这个近似也利用了指数结构——它不是简单的矩形法则。
每步一次网络求值。与 DDIM 成本相同。但误差完全集中在非线性近似上,不会分散到线性和非线性两部分。
ETD2:用于更高精度的预测-校正器
ETD2 使用当前和之前的数据预测,为积分项构建更好的数值积分:
这是一个加权组合——对非线性部分做类梯形积分,而线性部分保持精确。局部截断误差从 降到 。对于步长 较大的 10 步采样,这个差异很重要:在 时, 对比 是十倍的差距。
成本仍然是每步一次网络求值(它重用了之前的预测),所以 ETD2 相比 ETD1 几乎免费地获得了更高阶的精度。第一步回退到 ETD1,因为没有之前的预测可以重用。
Chebyshev 时间调度:步在哪里踩,和怎么踩一样重要
有一件事让我在深入研究时感到惊讶。时间步的选择——你在哪些 值上对网络求值——可能和积分方法本身一样重要。
扩散 ODE 在不同位置难度不同。在 (纯噪声)附近,分数函数光滑,几乎是高斯的。大步长是安全的。在 (干净数据)附近,分数编码了精细的图像细节,变化迅速。ODE 变得刚性 (stiff),需要小步长。大部分误差预算都花在了靠近 的最后一段。
Karras/EDM 调度用幂律间距()来处理这个问题,在干净数据端附近聚集步长。DPM-Solver++ 在 log-SNR 上使用均匀间距。两者都是启发式的。
Chebyshev 节点是另一回事。在逼近论中,它们是区间上可证明的最优插值点——它们最小化 Lebesgue 常数,避免 Runge 现象(均匀间距多项式插值时出现的灾难性振荡)。节点为:
它们在区间两端聚集,这与扩散 ODE 的结构非常契合:在非常靠近开始(离开高斯先验)和非常靠近结束(分辨精细细节)的地方,积分都是棘手的。
不需要额外的神经网络求值,不需要额外的计算。只是把你本来就要踩的步放在更聪明的位置。Karras 调度是靠直觉设计、靠实验验证的。Chebyshev 节点来自定理。我很好奇实践中哪个更好——我的猜测是 Chebyshev 在非常低的 NFE(5-7 步)时表现更好,因为那时步位置对误差的影响最为主导。
Richardson 外推 (Richardson Extrapolation):跑两次,消除误差
Richardson 外推是数值分析中最古老的技巧之一,我有点惊讶它还没有被应用到扩散采样中。这个想法简单得令人缴械。
假设你用一个 阶方法跑 步,得到结果 。误差大约是 ,其中 是未知常数。现在用相同的方法跑 步(步长 )。误差大约是 。你有两个方程两个未知数( 和 ),所以可以解出精确答案:
对于 DDIM():。对于 ETD2():。外推结果的误差是 ——你提升了一阶精度。
代价是总共 次网络求值(粗粒度跑 次,细粒度跑 次)。3 倍代价换来一阶精度的提升。不总是值得——但有个额外好处:如果粗粒度和细粒度结果已经吻合(相对差异低于某个容差),你就知道粗粒度结果已经收敛。可以跳过外推,省下那 次求值。这给你一个有上界代价的自适应质量控制。
在实现中,我加了一个提前退出检查:如果 ,立即返回 。在高 NFE 时这会可靠地触发,Richardson 采样器就变成了一个收敛检测器而不是外推器。
工程:Triton 核和混合精度
算法改进是一个维度。纯工程是另一个维度。我想分别测量这两者——把它们混在一起是那种声称"快了 2 倍"实际上是"算法好了 1.1 倍 + CUDA 技巧贡献了 1.8 倍"的论文的常见做法。
融合 Triton 核
ETD 更新步骤涉及几个逐元素操作:
out = decay * x + coeff * x0 # ETD1
out = decay * x + c0 * x0 + c1 * x0p # ETD2
out = (scale * fine - coarse) / (scale - 1) # Richardson 合并在原始 PyTorch 中,每个操作都会创建一个中间张量,从 GPU 全局显存中读写。对于 32×32 的 CIFAR 图像这可以忽略不计,但对于 512×512 或 1024×1024 的潜扩散模型,显存带宽的开销就积累起来了。
把这些融合成单个 Triton 核,可以消除中间内存分配。一个核读取所有输入,计算结果,写入一次。ETD1 核从"2 次读取 + 1 个中间结果 + 1 次写入"变成"2 次读取 + 1 次写入"。对于 ETD2,"3 次读取 + 2 个中间结果 + 1 次写入"变成"3 次读取 + 1 次写入"。
这些核在 Triton 不可用时会优雅地降级到 PyTorch——公共 API 返回 None,调用方使用非融合路径。
混合精度
扩散 ODE 在不同时刻有不同的精度需求。在 附近,信噪比低,分数函数光滑——FP16 足够。在 附近,你在分辨精细细节,ODE 是刚性的——FP32 很重要。一个简单的阈值( 用 FP16,否则用 FP32)应该能以最小的质量损失捕获大部分加速。
torch.compile
编译后的 DDIM 采样器用 torch.compile 包装模型调用,用于核融合和图优化。这与算法选择正交,应该能与任何采样器叠加。
公平对比:双表策略
这是我最在乎搞对的部分。机器学习论文经常在单个基准测试里混合算法改进和工程改进,让人根本无法判断到底是什么在起作用。
我使用两个独立的评估表:
表 1:算法公平性(FID vs NFE)。 所有方法使用相同的 PyTorch 代码路径,没有 Triton,没有 compile,没有混合精度。唯一的变量是采样算法和调度。这回答了:"在固定 次神经网络求值的预算下,哪种方法生成的图像最好?"
表 2:工程性能(FID vs 墙钟时间)。 每种方法都得到所有 CUDA 技巧——Triton 核、torch.compile、混合精度。这回答了:"在实际硬件上,实践中哪个最快?"
两个表的基线都是通过 diffusers 库的 DPM-Solver++(2M),这是使用最广泛的最先进采样器。CIFAR-10 32×32 上已发表的 FID:5 NFE 约 5.0,10 NFE 约 3.5,20 NFE 约 3.0。
我测试的每种方法都在相同的 NFE 值(5、10、20)下与这个基线对比。不挑对自己方法有利的 NFE 值。
我预期会发现什么
我坦诚地说出我的预测。
算法上最佳的赌注:Chebyshev 调度 + ETD2。 ETD2 几乎免费地获得更高阶精度(重用之前的预测),Chebyshev 节点应该给出接近最优的步位置。在 NFE=5 时,我猜这个组合能比 DPM-Solver++ 好 0.5-1.0 个 FID 点。理论支撑是清晰的:精确线性积分(ETD)+ 最优插值节点(Chebyshev)+ 高阶数值积分(ETD2 的梯形法则)。每个部分攻击不同的误差来源。
Richardson 外推:有用但昂贵。 在低 NFE(5-7)时,3 倍的代价让它不实际。在 NFE=15-20 时,收敛检测可能让它提前退出,有效验证 10 步已经足够了。我更多地把它看作质量保证工具,而不是速度工具。
Karras 调度会很难打败。 调度是在完全相同类型的模型上通过实验调出来的。Chebyshev 有理论支撑,但不是为这个特定问题设计的。我的诚实预期是 Chebyshev 在非常低的 NFE(5 步)时胜出,因为那里插值节点的理论最优性影响最大;Karras 在中等 NFE(10-20)时胜出,因为经验调参在那里发挥作用。
Triton 核:在 CIFAR 上边际收益,在更大图像上有意义。 对于 32×32 图像,网络前向传播是如此主导,以至于融合更新步骤几乎察觉不到。但这些核是为未来在 Stable Diffusion 64×64 潜变量上测试时准备的。
组合会比任何单项技术更重要。 最好的结果可能是类似:Chebyshev 调度 + ETD2 + 在 提前停止 + torch.compile。每个部分贡献适度的改进,但它们正交可叠加。
我可能错在哪里: DPM-Solver++ 已经在 log-SNR 空间做精确线性积分(他们的命题 4.1)。它与 ETD 方法之间的差距可能比我想的要小。DPM-Solver++ 中多步重用也比 ETD2 简单的两点积分更复杂。如果 DPM-Solver++ 基线在所有 NFE 值上都和 ETD2+Chebyshev 相差不超过 0.2 个 FID,故事就会变成"DPM-Solver++ 已经榨干了大部分汁水",而不是"数值方法打败了机器学习方法"。这也是一个有效的发现。
这真正是关于什么的
更深层的问题不是 ETD2 是否能比 DPM-Solver++ 好 0.3 个 FID 点。而是 ODE 求解器社区和扩散模型社区是否因为彼此交流不够而在性能上留下了桌上的钱。
DPM-Solver++ 从扩散模型框架的第一原理推导出来。指数积分器从刚性 ODE 框架的第一原理推导出来。它们得出了惊人相似的更新规则——都精确求解线性部分,都近似非线性积分。但它们来自不同的学术传统,在步位置、误差估计和阶数选择上做出了不同的选择。
如果实验表明精心选择的调度 + 指数积分器在低 NFE 时能媲美或超过 DPM-Solver++,这就验证了跨领域交叉的价值。如果做不到,这告诉我们 DPM-Solver++ 的模型专用设计选择(数据预测参数化、动态阈值、多步重用)承担了比通用数值框架更多的重量。
无论哪种结果,我们都能学到东西。GPU 集群队列是我和答案之间唯一的障碍。
