ウエハースケールエンジン向けにSimulated Annealingをパラレルテンパリングで実装しました

2024年10月30日

皆さんこんにちは!CS-2上にアニーリングを実装する記事の第4回です!

本記事では、Cerebras Wafer Scale Engine (WSE)にSimulated Annealing(SA)をパラレルテンパリングで実装する方法を紹介します。

本記事の内容はこれらの記事の続きとなります。

  1. ウエハースケール計算エンジンWSE-2においてSimulated-Annealingを実装しました
  2. ウエハースケールエンジンにSimulated Annealingを分散並列実装しCS-2実機で動作確認しました
  3. ウエハースケールエンジン向けSimulated Annealingを複数タイルによる並列化で実装しました

今回は前回の記事で行ったウエハースケール上での並列化をさらに発展させ、パラレルテンパリングというアルゴリズムを適用した実装を作成しました。

まだ前回の記事を読んでいない方は、そちらから読むことをおすすめします。 もし前回の記事をお読みになった方も、再度読んでみるとよりこの記事を理解しやすくなると思います。

前回までのおさらいと今回の内容

前回の記事では複数のPEを連携させたブロックを複数個配置して、PEに収まらない問題を並列で解けるようにしました。

今回は、『パラレルテンパリング』というアルゴリズムを追加します。 『パラレルテンパリング』とはアニーリングを並列で行うときにというアルゴリズムで、並列で実行している各実行の条件の調整を行うことで独立して試行を行うより収束が速くなるアルゴリズムです。

パラレルテンパリングとは

前提

まず、Simulated annealingの復習をしましょう。Simulated annealingで解く問題はQUBOモデルの最小化問題で、以下の式で\(s\)を変えつつ最小の\(f(s)\)を求める問題と言えます。

\[f(s)=-\sum_{i\le j}Q_{ij}s_{i}s_{j}\]

プログラムの挙動としては\(s\)をランダムに選び、\(s\)を反転させてみて、\(f(s)\)が小さくなるなら採用。\(f(s)\)が大きくなるなら確率的に採用します。

確率的に採用するときの確率\(E\)は『温度』と呼ばれるパラメータを使いつつ概ね以下の数式で決定します。なお、この方針はメトロポリス法と呼ばれます。

\[E =\exp(- f(s) / t)\]

温度\(t\)が高ければ高いほど\(f(s)\)が大きくなる\(s\)の反転も採用されます。

ここで温度の決め方ですが、初期値から徐々に小さくするという方針で実装しています。

パラレルテンパリング

パラレルテンパリングはメトロポリス法を並列で行うときに使えるアルゴリズムです。

Simulated annealingを並列で複数個同時に実行する状況を考えます。

並列に計算する実行ごとに初期温度を変えておき、一定の条件で温度を交換するアルゴリズムがパラレルテンパリングです。うまく温度を交換することで適切な温度で各実行を行うことができるようになります。

温度を交換するかどうかの判定は以下の確率\(p\)で行います。

\[p = \exp\left(\frac{f_1(x) – f_2(x)}{1 / t_1 – 1 / t_2}\right)\]

ざっくり説明すると、『\(f(x)\)が大きい側が大きな温度になる』といった感じです。

今回の実装内容

処理の流れは以下になります

  1. 複数のタイルでSimulated annealingを一定回数実行
  2. 各タイルのエネルギー\(E\)と温度\(T\)をホストに転送 全タイルからEとTを回収
  3. 温度を交換するか判定し、新温度\(T'\)を返却(パラレルテンパリング) 温度交換の判定
    新温度T'の返却
  4. ある程度の回数実行したら5.へ、そうでなければ1.に戻る
  5. 結果をホストに転送し、ホスト側でタイル間のエネルギーを確認し、最終的な結果を確定

全ての処理をWSE上で実行することも可能ですが、今回の実装では扱いやすさと『色』の制約のためにホスト側でパラレルテンパリングの処理を実装しました。

色に関しては最初の記事 を参照してください。実は、使える色の数に強い制限があり、実装上の工夫が必要でした。 タイル内部の通信と全体の通信をうまく制御したり、性能測定やデバッグための情報を収集したりするための通信時に色をうまくやりくりする必要がありました。

今回のコードもこれまでと同様にGitHubにて公開するので、ご覧ください。尚、最初のInitial commitから順番にこれまでの各記事に対応しています。

挙動の確認

実際にパラレルテンパリングにより温度交換が実施されているかシミュレータを使って確認してみます。 実験条件は以下のようになっています。

\(Q\)のサイズ タイル構成 1タイル内のPE構成 イテレーション数 テンパリング周期
\(50\times50\) \(4\times4\) \(2\times2\) 512 32

この条件の元でホスト側のpythonコードに情報出力用のコードを入れて、実際にシミュレータを走らせてみます。

※出力に処理時間の数値もありますが、シミュレータでの処理時間である点にご注意下さい。

$ ./commands.sh -c config/parallel_tempering.toml
['sdk_debug_shell', 'compile', 'src/layout.csl', '--fabric-dims=15,14', '--fabric-offsets=4,1', '--params=Num:50', '--params=block_height:2', '--params=block_width:2', '--params=grid_height:4', '--params=grid_width:4', '--params=trace_buffer_size:0', '--params=collector_buffer_size:128', '--params=enable_simprint:1', '--params=MEMCPYH2D_DATA_1_ID:0', '--params=MEMCPYD2H_DATA_1_ID:1', '-o=out', '--memcpy', '--channels=1', '--max-parallelism=8']
merged_params={'Num': 50, 'block_height': 2, 'block_width': 2, 'grid_height': 4, 'grid_width': 4, 'trace_buffer_size': 0, 'collector_buffer_size': 128, 'enable_simprint': '1', 'MEMCPYH2D_DATA_1_ID': '0', 'MEMCPYD2H_DATA_1_ID': '1', 'max_iters': 512, 'log2_swap_interval': 5, 'time_constant': 1000, 'log_init_temperature': 32768, 'iterations_per_collect': 8, 'suppress_simfab_trace': True}
Q_triu=array([-0.6297765 , -0.0222319 , -0.54266036, ..., -0.7425307 ,
        0.26843616,  0.78073215], dtype=float32)
started
runner.load 0.078647059s
runner.run 0.037856007s
init 0.000213681s
Send runtime parameter : max_iters=512 (I)
Send runtime parameter : log2_swap_interval=5 (I)
Send runtime parameter : time_constant=1000 (I)
Send runtime parameter : log_init_temperature=32768 (I)
Send runtime parameter : iterations_per_collect=8 (I)
memcpy_h2d 0.000330926s
processing:   6%|█               | 31/512 [00:33<08:47,  1.10s/it]
[i=0/15] swapped 6 pairs
processing:  12%|██              | 63/512 [00:47<05:39,  1.32it/s]
[i=1/15] swapped 2 pairs
processing:  19%|███             | 95/512 [01:01<04:27,  1.56it/s]
[i=2/15] swapped 3 pairs
processing:  25%|████            | 127/512 [01:14<03:46,  1.70it/s]
[i=3/15] swapped 3 pairs
processing:  31%|█████           | 159/512 [01:28<03:15,  1.81it/s]
[i=4/15] swapped 2 pairs
processing:  37%|██████          | 191/512 [01:41<02:50,  1.88it/s]
[i=5/15] swapped 1 pairs
processing:  44%|███████         | 223/512 [01:54<02:29,  1.94it/s]
[i=6/15] swapped 0 pairs
processing:  50%|████████        | 255/512 [02:08<02:09,  1.99it/s]
[i=7/15] swapped 1 pairs
processing:  56%|█████████       | 287/512 [02:22<01:51,  2.02it/s]
[i=8/15] swapped 3 pairs
processing:  62%|██████████      | 319/512 [02:35<01:34,  2.05it/s]
[i=9/15] swapped 1 pairs
processing:  69%|███████████     | 351/512 [02:49<01:17,  2.07it/s]
[i=10/15] swapped 2 pairs
processing:  75%|████████████    | 383/512 [03:03<01:01,  2.09it/s]
[i=11/15] swapped 1 pairs
processing:  81%|█████████████   | 415/512 [03:16<00:45,  2.11it/s]
[i=12/15] swapped 0 pairs
processing:  87%|██████████████  | 447/512 [03:30<00:30,  2.12it/s]
[i=13/15] swapped 0 pairs
processing:  94%|███████████████ | 479/512 [03:44<00:15,  2.14it/s]
[i=14/15] swapped 3 pairs
processing: 100%|████████████████| 511/512 [03:57<00:00,  2.15it/s]
[i=15/15] swapped 1 pairs
processing: 100%|████████████████| 511/512 [03:57<00:00,  2.15it/s]
swap_temperature 237.985667534s
memcpy_d2h 0.50405135s
memcpy_d2h 0.037933778s
Loading traces... Please wait a few minutes.
[COLLECTOR]Detect collector enabled. Loading collector info...
[COLLECTOR]Loading info from the collection row (1/grid_height=4)...
[COLLECTOR]Loading info from the collection row (2/grid_height=4)...
[COLLECTOR]Loading info from the collection row (3/grid_height=4)...
[COLLECTOR]Loading info from the collection row (4/grid_height=4)...
[COLLECTOR]Loading completed.
[COLLECTOR]STATISTICS FILE : /home/ubuntu/cerebras_ws/cerebras_sa/log/yyyymmdd-xxxxxx/statistics.json
Complete load trace (D2H).
SIM_STATS : /home/ubuntu/cerebras_ws/cerebras_sa/log/yyyymmdd-xxxxxx/sim_stats.json
load_trace_and_stop 0.037933778s
total 243.714701108s
best_s=array([1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1,
       1, 1, 1, 0, 1, 1], dtype=int32)
min_energy_wse=-66.581566 (in WSE)
min_energy=-66.58155961334705 (in Python)
opt_s   = -66.582, [1 1 1 0 1 0 1 1 1 0 1 1 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 1 1 0 1 0
 0 1 0 1 0 1 1 1 1 1 0 1 1]
best_s  = -66.582, [1 1 1 0 1 0 1 1 1 0 1 1 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 1 1 0 1 0
 0 1 0 1 0 1 1 1 1 1 0 1 1]
OK

テンパリング周期である32イテレーションのタイミングでホスト側でパラレルテンパリングを行っている様子がログから見て取れますね!

この時の、温度・エネルギーの状態についても実際に見てみましょう。比較としてパラレルテンパリングを無効化した際の結果も同時にお見せします。

実行中の各タイルにおける温度エネルギーSAのフリップ確率(8周期毎)をプロットしたものが次のようになります。

パラレルテンパリング無 パラレルテンパリング有
パラレルテンパリング無 パラレルテンパリング有

大きい画像なので、詳細を確認したい場合は右クリックして別タブで開いて拡大するのをおすすめします。

パラレルテンパリング無の左図では、各タイル上段の温度が単調減少しているのに対して、パラレルテンパリング有の右図では不規則な温度変化となっているのがわかります。

左上タイルにおけるテンパリング前後の比較



あとは、実機を使って、このパラレルテンパリングを実際に動作させてみたいと思います。 実機上での結果については、また別のブログで紹介しますのでお楽しみ下さい!

終わりに

今回の記事では、パラレルテンパリングを導入してSAを解く実装について述べました。

今後は、このパラレルテンパリングを導入したSAを使って、東京エレクトロンデバイス様ご協力の元で実際に実機上での評価を行っていきます。

最大カット問題のベンチマークとして有名なGsetを利用する予定です。またブログで説明していく予定なので、ぜひ楽しみにしていてください!

開発に関わったメンバー (アルファベット順)

このプロジェクトは計算機好きな有志が自己研鑽の一部として、業務時間外(たまに時間内)で好き勝手に開発しています。

  • Hikaru Takayashiki
  • Hiroki Nishimoto
  • Keisuke Kimura
  • Naoki Yoshifuji
  • Toru Fukaya
  • Yoshiki Imaizumi
  • Yuki Ito

また、株式会社フィックスターズでは一緒に働く仲間を募集しています。 WSEのみならず様々なプロセッサでのプログラミング・高速化に興味がある方は、ぜひ採用ページよりご応募ください。

About Author

keisuke.kimura

Leave a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください

Recent Comments

Social Media