拡散サンプリングの高速化:数値計算手法とスコアベースモデルの出会い
拡散モデルは素晴らしい画像を生成するが、遅い。1枚の画像に数十回の逐次ニューラルネットワーク評価が必要だ——各評価はU-Netを通じた完全なフォワードパスだ。DPM-Solver++はそれを合理的な品質で10〜20ステップまで削減し、現在のState of the Artだ。でも、科学計算コミュニティが数十年使ってきた技術を借りることで、さらに改善できるとしたら?
まさにそれをテストするフレームワークを構築している。実験はまだ実行していない(GPUクラスタの順番待ち中)が、コードは書けていて、ベースラインもセットアップされている。何を試みているか、なぜそれに勝ち目があると思うかを説明したい。
すべての拡散モデルの中に隠れているODE
拡散モデルからのサンプリングは、数学的に言えば常微分方程式(ODE)を解くことだ。Song et al. (2021) の確率フローODEはこの形を取る:
ここで はノイズを予測するニューラルネットワーク、 はノイズスケジュール、 は時刻 におけるノイズレベルだ。 での純粋なガウスノイズから始めて まで積分すると、クリーンな画像が得られる。
標準的なアプローチ——DDIM、Euler、Heun——はこれを汎用ODEとして扱い、汎用積分器を投げる。それは機能するが、この特定の方程式の構造について重要なことを無視している。
線形・非線形の分離
ODEをもう一度見てほしい。きれいに2つの部分に分解できる:
線形部分 (ここで は状態のスカラー係数)は厳密な解析解を持つ——近似不要。非線形部分 はニューラルネットワークが住む場所で、コストのかかる部分だ。
DPM-Solver++はすでにある程度これを利用している。対数SNR座標 では、時刻 から への厳密解は:
最初の項は線形部分を正確に解く。すべての数値誤差は非線形部分のその積分を近似することから来る。問題は:もっとうまく近似できないか?
指数積分器:数学に仕事をさせる
指数時間差分(ETD)法は流体力学と硬いODE文献から来ている(Cox & Matthews 2002、Hochbruck & Ostermann 2010)。核心的なアイデア:ODEが の形を持つなら、 を近似せず——正確に解いて、 だけを近似する。
ETD1:一次指数積分器
最も単純なバージョンはステップあたりネットワーク評価を1回行い、指数積分器理論から 関数を使う。対数SNR座標での更新は:
ここで は対数SNR空間でのステップ、 はステップ でのネットワークによるデータ予測だ。指数減衰が線形部分を解析的に処理する。非線形部分の積分だけが近似され、しかもそれさえも指数の構造を使っている——単純な矩形則ではない。
ステップあたりネットワーク評価1回。DDIMと同じコスト。でも誤差は純粋に非線形の近似に集中していて、線形と非線形の両方の項に広がっていない。
ETD2:より高精度のプレディクタ・コレクタ
ETD2は現在と前のデータ予測を使って積分項のより良い求積を構築する:
これは重み付きの組み合わせだ——非線形部分に対する台形則的な求積で、線形部分は正確なまま。局所打ち切り誤差は から に落ちる。 が大きい10ステップサンプリングでは、その差が重要だ: で 対 は10倍の差がある。
コストは依然としてステップあたりネットワーク評価1回だ(前の予測を再利用する)。ETD2はETD1と比べて本質的に無料で高次の精度を得る。最初のステップは再利用できる前の予測がないためETD1にフォールバックする。
チェビシェフ時間スケジュール:どこでステップを踏むかが、どう踏むかと同じくらい重要
これは掘り下げ始めたときに驚いたことだ。タイムステップの配置の選択——ネットワークをどの 値で評価するか——は積分手法自体と同じくらい重要になりうる。
拡散ODEはどこでも同じように難しいわけではない。(純粋なノイズ)付近では、スコア関数は滑らかでほぼガウス分布だ。大きなステップは安全。(クリーンなデータ)付近では、スコアは細かい画像の詳細をエンコードし、急速に変化する。ODEは硬くなり、小さなステップが必要になる。誤差バジェットのほとんどが 付近の最後の区間で費やされる。
Karras/EDMスケジュールはクリーンデータ端付近にステップをクラスタリングするべき乗則スペーシング()でこれに対処する。DPM-Solver++は対数SNRでの等間隔を使う。両方ともヒューリスティックだ。
チェビシェフノードは別物だ。近似理論において、これらは区間上で証明上最適な補間点だ——Lebesgue定数を最小化し、Rungeの現象(等間隔多項式補間で起こる壊滅的な振動)を回避する。ノードは:
区間の両端にクラスタリングされ、拡散ODEの構造とよく合致する:ガウス事前分布を離れる最初のところと細かい詳細を解決する最後のところの両方で積分が難しい。
追加のニューラルネットワーク評価なし。追加の計算なし。ただ、もともと踏もうとしていたステップのスマートな配置だ。Karrasスケジュールは直感で設計され経験的に検証された。チェビシェフノードは定理から来る。実際にどちらが勝つか興味がある——ステップ配置が誤差を支配する非常に低いNFE(5〜7ステップ)でチェビシェフが良い結果を出すと予想している。
Richardson外挿:2回実行して誤差をキャンセル
Richardson外挿は数値解析で最も古いトリックの一つで、拡散サンプリングに適用されていないことが少し意外だった。アイデアは拍子抜けするほど単純だ。
次の手法を ステップで実行して結果 を得たとする。誤差はおよそ (未知の定数 で)。今、同じ手法を ステップで実行する(ステップサイズ )。誤差はおよそ 。2つの方程式と2つの未知数( と )があるので、正確な答えを解ける:
DDIM()の場合:。ETD2()の場合:。外挿された結果の誤差は だ——1次の精度が得られた。
コストは合計 回のネットワーク評価(粗い実行で 回、細かい実行で 回)。1次余分のためのコストが3倍。常に価値があるわけではない——でもボーナスがある:粗い結果と細かい結果がすでに一致している場合(相対差が許容値以下)、粗い結果が収束していることが分かる。外挿をスキップして 回の評価を節約できる。これにより、コスト上限付きの適応的な品質管理ができる。
実装では早期終了チェックを追加した: なら、すぐに を返す。高NFEでこれが確実に発火し、Richardsonサンプラーは外挿器ではなく収束検出器になる。
エンジニアリング:Tritonカーネルと混合精度
アルゴリズムの改善は一つの軸だ。生のエンジニアリングは別の軸だ。両方を測定したいが、別々に——それらを混在させると、実際には「アルゴリズムで1.1倍 + CUDAトリックで1.8倍」なのに「2倍速い」と主張する論文になってしまう。
融合Tritonカーネル
ETD更新ステップはいくつかのelement-wise演算を含む:
out = decay * x + coeff * x0 # ETD1
out = decay * x + c0 * x0 + c1 * x0p # ETD2
out = (scale * fine - coarse) / (scale - 1) # Richardson結合バニラPyTorchでは、各演算が中間テンソルを作り、グローバルGPUメモリから読み書きする。32x32のCIFAR画像では無視できるが、512x512や1024x1024の潜在拡散では、メモリトラフィックが蓄積する。
これらを単一のTritonカーネルに融合することで中間アロケーションがなくなる。1つのカーネルが全入力を読み、結果を計算し、1回書き込む。ETD1カーネルは「2回ロード + 1つの中間 + 1回ストア」から「2回ロード + 1回ストア」になる。ETD2は「3回ロード + 2つの中間 + 1回ストア」から「3回ロード + 1回ストア」になる。
これらのカーネルはTritonが使えない場合にPyTorchに優雅にフォールバックする——パブリックAPIは None を返し、呼び出し元は非融合パスを使う。
混合精度
拡散ODEは時刻によって異なる精度要件を持つ。 付近では信号ノイズ比が低くスコアは滑らか——FP16で十分。 付近では細かい詳細を解決していてODEが硬い——FP32が重要。単純な閾値( はFP16、それ以外はFP32)が品質損失を最小限に抑えながら、速度向上の大部分をキャプチャするはずだ。
torch.compile
コンパイルされたDDIMサンプラーはモデル呼び出しを torch.compile でラップしてカーネル融合とグラフ最適化を行う。これはアルゴリズムの選択と直交していて、どのサンプラーとも組み合わせられる。
公平な比較:2テーブル戦略
これが最も正しくやりたい部分だ。ML論文はアルゴリズムとエンジニアリングの改善を単一のベンチマークに混在させがちで、実際に何が助けになっているか分からなくなる。
2つの別々の評価テーブルを使う:
テーブル1:アルゴリズムの公平性(FID対NFE)。 全手法が同じPyTorchコードパス、Tritonなし、compileなし、混合精度なしを使う。唯一の変数はサンプリングアルゴリズムとスケジュール。これが答える問い:「 回のニューラルネットワーク評価の固定バジェットで、どの手法が最も良い画像を生成するか?」
テーブル2:エンジニアリングパフォーマンス(FID対wall-clock時間)。 すべての手法がCUDAのトリックを全部使う——Tritonカーネル、torch.compile、混合精度。これが答える問い:「実際のハードウェアで、実際には何が最速か?」
両テーブルのベースラインは diffusers ライブラリのDPM-Solver++ (2M)で、最も広く使われているSOTAサンプラーだ。CIFAR-10 32x32での公表FID:5 NFEで〜5.0、10 NFEで〜3.5、20 NFEで〜3.0。
テストする全手法を同じNFE値(5、10、20)でこのベースラインに対して評価する。自分の手法が良く見えるNFEを選ぶようなことはしない。
何を見つけると思うか
予測について正直に書く。
アルゴリズムの最良の賭け:チェビシェフスケジュール + ETD2。 ETD2は無料で高次の精度を得る(前の予測を再利用)し、チェビシェフノードはほぼ最適なステップ配置を提供するはずだ。NFE=5で、この組み合わせがDPM-Solver++を0.5〜1.0 FIDポイント上回ると推測する。理論的な正当化はクリーンだ:正確な線形積分(ETD)+最適な補間ノード(チェビシェフ)+高次の求積(ETD2の台形則)。各ピースが異なる誤差の源を攻撃する。
Richardson外挿:有用だが高コスト。 低NFE(5〜7)では3倍のコストが実用的ではない。NFE=15〜20では、収束検出が早期終了できて、10ステップで十分だったことを効果的に検証できるかもしれない。速度ツールというより品質保証ツールと見ている。
Karrasスケジュールは打ち負かすのが難しい。 スケジュールはまさにこの種のモデルで経験的に調整されたものだ。チェビシェフには理論的な裏付けがあるが、この特定の問題のために設計されたわけではない。正直なところ、チェビシェフが補間ノードの理論的最適性が最も重要な非常に低いNFE(5ステップ)で勝ち、Karrasが経験的なチューニングが実を結ぶ中程度のNFE(10〜20)で勝つと予想する。
Tritonカーネル:CIFARでは限界的、大きな画像では意味がある。 32x32画像では、ネットワークのフォワードパスがあまりにも支配的で、更新ステップの融合はほぼ誤差範囲内だ。でもカーネルは64x64の潜在空間でStable Diffusionをテストする日のために書かれた。
組み合わせが個々の技術よりも重要になる。 最良の結果はおそらく:チェビシェフスケジュール + ETD2 + での早期停止 + torch.compile のような組み合わせだろう。各ピースが控えめな改善をもたらすが、直交していて積み重なる。
私が間違えるかもしれない点: DPM-Solver++はすでに対数SNR空間で正確な線形積分を行っている(命題4.1)。それがやっていることとETD手法がやっていることの差は私が思うより小さいかもしれない。DPM-Solver++のマルチステップ再利用もETD2の単純な2点求積よりも洗練されている。DPM-Solver++のベースラインが全NFE値でETD2+チェビシェフと0.2 FID以内に収まるなら、ストーリーは「数値手法がML手法を打ち負かした」ではなく「DPM-Solver++がすでにほとんどの果実を見つけていた」になる。それもそれで有効な発見だ。
これが本当は何についてなのか
より深い問いは、ETD2がDPM-Solver++を0.3 FIDポイント上回るかどうかではない。ODEソルバーコミュニティと拡散モデルコミュニティが、もっと互いに話し合わないことでパフォーマンスを無駄にしているかどうかだ。
DPM-Solver++は拡散モデルフレームワークの第一原理から導出された。指数積分器は硬いODEフレームワークの第一原理から導出された。両者は驚くほど似た更新則に至っている——どちらも線形部分を正確に解き、どちらも非線形積分を近似する。でも異なる知的伝統から来て、ステップ配置、誤差推定、次数選択について異なる選択をする。
実験が示すとなら、うまく選ばれたスケジュール+指数積分器が低NFEでDPM-Solver++に匹敵または勝る、それはクロスポリネーションを検証する。そうでなければ、DPM-Solver++のモデル固有の設計選択(データ予測パラメータ化、ダイナミックスレッショルディング、マルチステップ再利用)が汎用数値フレームワークよりずっと多くの仕事をしていることを教えてくれる。
どちらにしても、何かを学ぶ。GPUクラスタの順番待ちが私と答えの間にある唯一のものだ。
