このブログは、株式会社フィックスターズのエンジニアが、あらゆるテーマについて自由に書いているブログです。
皆さんこんにちは!CS-2上にアニーリングを実装する記事の第4回です!
本記事では、Cerebras Wafer Scale Engine (WSE)にSimulated Annealing(SA)をパラレルテンパリングで実装する方法を紹介します。
本記事の内容はこれらの記事の続きとなります。
今回は前回の記事で行ったウエハースケール上での並列化をさらに発展させ、パラレルテンパリングというアルゴリズムを適用した実装を作成しました。
まだ前回の記事を読んでいない方は、そちらから読むことをおすすめします。 もし前回の記事をお読みになった方も、再度読んでみるとよりこの記事を理解しやすくなると思います。
前回の記事では複数のPEを連携させたブロックを複数個配置して、PEに収まらない問題を並列で解けるようにしました。
今回は、『パラレルテンパリング』というアルゴリズムを追加します。 『パラレルテンパリング』とはアニーリングを並列で行うときにというアルゴリズムで、並列で実行している各実行の条件の調整を行うことで独立して試行を行うより収束が速くなるアルゴリズムです。
まず、Simulated annealingの復習をしましょう。Simulated annealingで解く問題はQUBOモデルの最小化問題で、以下の式で\(s\)を変えつつ最小の\(f(s)\)を求める問題と言えます。
\[f(s)=-\sum_{i\le j}Q_{ij}s_{i}s_{j}\]
プログラムの挙動としては\(s\)をランダムに選び、\(s\)を反転させてみて、\(f(s)\)が小さくなるなら採用。\(f(s)\)が大きくなるなら確率的に採用します。
確率的に採用するときの確率\(E\)は『温度』と呼ばれるパラメータを使いつつ概ね以下の数式で決定します。なお、この方針はメトロポリス法と呼ばれます。
\[E =\exp(- f(s) / t)\]
温度\(t\)が高ければ高いほど\(f(s)\)が大きくなる\(s\)の反転も採用されます。
ここで温度の決め方ですが、初期値から徐々に小さくするという方針で実装しています。
パラレルテンパリングはメトロポリス法を並列で行うときに使えるアルゴリズムです。
Simulated annealingを並列で複数個同時に実行する状況を考えます。
並列に計算する実行ごとに初期温度を変えておき、一定の条件で温度を交換するアルゴリズムがパラレルテンパリングです。うまく温度を交換することで適切な温度で各実行を行うことができるようになります。
温度を交換するかどうかの判定は以下の確率\(p\)で行います。
\[p = \exp\left(\frac{f_1(x) – f_2(x)}{1 / t_1 – 1 / t_2}\right)\]
ざっくり説明すると、『\(f(x)\)が大きい側が大きな温度になる』といった感じです。
処理の流れは以下になります
全ての処理をWSE上で実行することも可能ですが、今回の実装では扱いやすさと『色』の制約のためにホスト側でパラレルテンパリングの処理を実装しました。
色に関しては最初の記事 を参照してください。実は、使える色の数に強い制限があり、実装上の工夫が必要でした。 タイル内部の通信と全体の通信をうまく制御したり、性能測定やデバッグための情報を収集したりするための通信時に色をうまくやりくりする必要がありました。
今回のコードもこれまでと同様にGitHubにて公開するので、ご覧ください。尚、最初のInitial commitから順番にこれまでの各記事に対応しています。
実際にパラレルテンパリングにより温度交換が実施されているかシミュレータを使って確認してみます。 実験条件は以下のようになっています。
\(Q\)のサイズ | タイル構成 | 1タイル内のPE構成 | イテレーション数 | テンパリング周期 |
---|---|---|---|---|
\(50\times50\) | \(4\times4\) | \(2\times2\) | 512 | 32 |
この条件の元でホスト側のpythonコードに情報出力用のコードを入れて、実際にシミュレータを走らせてみます。
※出力に処理時間の数値もありますが、シミュレータでの処理時間である点にご注意下さい。
$ ./commands.sh -c config/parallel_tempering.toml
['sdk_debug_shell', 'compile', 'src/layout.csl', '--fabric-dims=15,14', '--fabric-offsets=4,1', '--params=Num:50', '--params=block_height:2', '--params=block_width:2', '--params=grid_height:4', '--params=grid_width:4', '--params=trace_buffer_size:0', '--params=collector_buffer_size:128', '--params=enable_simprint:1', '--params=MEMCPYH2D_DATA_1_ID:0', '--params=MEMCPYD2H_DATA_1_ID:1', '-o=out', '--memcpy', '--channels=1', '--max-parallelism=8']
merged_params={'Num': 50, 'block_height': 2, 'block_width': 2, 'grid_height': 4, 'grid_width': 4, 'trace_buffer_size': 0, 'collector_buffer_size': 128, 'enable_simprint': '1', 'MEMCPYH2D_DATA_1_ID': '0', 'MEMCPYD2H_DATA_1_ID': '1', 'max_iters': 512, 'log2_swap_interval': 5, 'time_constant': 1000, 'log_init_temperature': 32768, 'iterations_per_collect': 8, 'suppress_simfab_trace': True}
Q_triu=array([-0.6297765 , -0.0222319 , -0.54266036, ..., -0.7425307 ,
0.26843616, 0.78073215], dtype=float32)
started
runner.load 0.078647059s
runner.run 0.037856007s
init 0.000213681s
Send runtime parameter : max_iters=512 (I)
Send runtime parameter : log2_swap_interval=5 (I)
Send runtime parameter : time_constant=1000 (I)
Send runtime parameter : log_init_temperature=32768 (I)
Send runtime parameter : iterations_per_collect=8 (I)
memcpy_h2d 0.000330926s
processing: 6%|█ | 31/512 [00:33<08:47, 1.10s/it]
[i=0/15] swapped 6 pairs
processing: 12%|██ | 63/512 [00:47<05:39, 1.32it/s]
[i=1/15] swapped 2 pairs
processing: 19%|███ | 95/512 [01:01<04:27, 1.56it/s]
[i=2/15] swapped 3 pairs
processing: 25%|████ | 127/512 [01:14<03:46, 1.70it/s]
[i=3/15] swapped 3 pairs
processing: 31%|█████ | 159/512 [01:28<03:15, 1.81it/s]
[i=4/15] swapped 2 pairs
processing: 37%|██████ | 191/512 [01:41<02:50, 1.88it/s]
[i=5/15] swapped 1 pairs
processing: 44%|███████ | 223/512 [01:54<02:29, 1.94it/s]
[i=6/15] swapped 0 pairs
processing: 50%|████████ | 255/512 [02:08<02:09, 1.99it/s]
[i=7/15] swapped 1 pairs
processing: 56%|█████████ | 287/512 [02:22<01:51, 2.02it/s]
[i=8/15] swapped 3 pairs
processing: 62%|██████████ | 319/512 [02:35<01:34, 2.05it/s]
[i=9/15] swapped 1 pairs
processing: 69%|███████████ | 351/512 [02:49<01:17, 2.07it/s]
[i=10/15] swapped 2 pairs
processing: 75%|████████████ | 383/512 [03:03<01:01, 2.09it/s]
[i=11/15] swapped 1 pairs
processing: 81%|█████████████ | 415/512 [03:16<00:45, 2.11it/s]
[i=12/15] swapped 0 pairs
processing: 87%|██████████████ | 447/512 [03:30<00:30, 2.12it/s]
[i=13/15] swapped 0 pairs
processing: 94%|███████████████ | 479/512 [03:44<00:15, 2.14it/s]
[i=14/15] swapped 3 pairs
processing: 100%|████████████████| 511/512 [03:57<00:00, 2.15it/s]
[i=15/15] swapped 1 pairs
processing: 100%|████████████████| 511/512 [03:57<00:00, 2.15it/s]
swap_temperature 237.985667534s
memcpy_d2h 0.50405135s
memcpy_d2h 0.037933778s
Loading traces... Please wait a few minutes.
[COLLECTOR]Detect collector enabled. Loading collector info...
[COLLECTOR]Loading info from the collection row (1/grid_height=4)...
[COLLECTOR]Loading info from the collection row (2/grid_height=4)...
[COLLECTOR]Loading info from the collection row (3/grid_height=4)...
[COLLECTOR]Loading info from the collection row (4/grid_height=4)...
[COLLECTOR]Loading completed.
[COLLECTOR]STATISTICS FILE : /home/ubuntu/cerebras_ws/cerebras_sa/log/yyyymmdd-xxxxxx/statistics.json
Complete load trace (D2H).
SIM_STATS : /home/ubuntu/cerebras_ws/cerebras_sa/log/yyyymmdd-xxxxxx/sim_stats.json
load_trace_and_stop 0.037933778s
total 243.714701108s
best_s=array([1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1,
1, 1, 1, 0, 1, 1], dtype=int32)
min_energy_wse=-66.581566 (in WSE)
min_energy=-66.58155961334705 (in Python)
opt_s = -66.582, [1 1 1 0 1 0 1 1 1 0 1 1 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 1 1 0 1 0
0 1 0 1 0 1 1 1 1 1 0 1 1]
best_s = -66.582, [1 1 1 0 1 0 1 1 1 0 1 1 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 0 1 1 0 0 1 1 0 1 0
0 1 0 1 0 1 1 1 1 1 0 1 1]
OK
テンパリング周期である32イテレーションのタイミングでホスト側でパラレルテンパリングを行っている様子がログから見て取れますね!
この時の、温度・エネルギーの状態についても実際に見てみましょう。比較としてパラレルテンパリングを無効化した際の結果も同時にお見せします。
実行中の各タイルにおける温度
・エネルギー
・SAのフリップ確率(8周期毎)
をプロットしたものが次のようになります。
パラレルテンパリング無 | パラレルテンパリング有 |
---|---|
大きい画像なので、詳細を確認したい場合は右クリックして別タブで開いて拡大するのをおすすめします。
パラレルテンパリング無の左図では、各タイル上段の温度が単調減少しているのに対して、パラレルテンパリング有の右図では不規則な温度変化となっているのがわかります。
あとは、実機を使って、このパラレルテンパリングを実際に動作させてみたいと思います。 実機上での結果については、また別のブログで紹介しますのでお楽しみ下さい!
今回の記事では、パラレルテンパリングを導入してSAを解く実装について述べました。
今後は、このパラレルテンパリングを導入したSAを使って、東京エレクトロンデバイス様ご協力の元で実際に実機上での評価を行っていきます。
最大カット問題のベンチマークとして有名なGsetを利用する予定です。またブログで説明していく予定なので、ぜひ楽しみにしていてください!
このプロジェクトは計算機好きな有志が自己研鑽の一部として、業務時間外(たまに時間内)で好き勝手に開発しています。
また、株式会社フィックスターズでは一緒に働く仲間を募集しています。 WSEのみならず様々なプロセッサでのプログラミング・高速化に興味がある方は、ぜひ採用ページよりご応募ください。
keisuke.kimura in Livox Mid-360をROS1/ROS2で動かしてみた
Sorry for the delay in replying. I have done SLAM (FAST_LIO) with Livox MID360, but for various reasons I have not be...
Miya in ウエハースケールエンジン向けSimulated Annealingを複数タイルによる並列化で実装しました
作成されたプロファイラがとても良さそうです :) ぜひ詳細を書いていただきたいです!...
Deivaprakash in Livox Mid-360をROS1/ROS2で動かしてみた
Hey guys myself deiva from India currently i am working in this Livox MID360 and eager to knwo whether you have done the...
岩崎システム設計 岩崎 満 in Alveo U50で10G Ethernetを試してみる
仕事の都合で、検索を行い、御社サイトにたどりつきました。 内容は大変参考になりま...
Prabuddhi Wariyapperuma in Livox Mid-360をROS1/ROS2で動かしてみた
This issue was sorted....