LOST IN BLUE

2022/08/10

TPU Research Cloud体験記 with JAX/Flax

30 日間 TPU を使うことができるTPU Research Cloud (TRC)に参加していたので、TRC や使っていた JAX/Flax について書いておこうと思います。

TPU Research Cloud (TRC)

TRCは、もともとはTensorflow Research Cloudと呼ばれていたもので、Google の所有するクラウド TPU を学生や研究者などに一定期間使わせてくれるものです。

TRC への参加は、ホームページにある申請フォームから申請が可能です。
自分の場合は、フォームに「JAX でモデルの追実装をやります」のような感じで申請したところ、特に追加で作文などはなく 7~8 時間程度で承認メールが来ました。

その後、GCP で TRC 用のプロジェクトを作り、プロジェクト ID をメールで送信すれば TRC 開始となります。
貸し出される TPU は GCP で使う形なため、ある程度 GCP について知っているほうがスムーズに使えるかもしれません。しかし、公式ドキュメントがよくまとまっているので全く知らない場合でも大丈夫だと思います。

今回、TRC では TPU と JAX/Flax を使ってVITSJVSで学習して、日本語 TTS を学習させていました。
以下にサンプル音声を貼っておきます。

学習データ量がやはり少ないので、濁音や外来語などの品質が低いようです。しかし、十分な品質で再現できたと思います。

TPU

簡単に TPU について説明しておきます。TPU は Tensor Processing Unit の略で、行列演算に特化したアクセラレーターです。シストリックアレイによって効率的に行列計算を行えるため、行列計算が大部分を占める深層学習に向いています。GPU と比べると TPU のほうが演算自体は速いようですが、結局現在の計算処理においてはデータ IO が無視できないため、概ね TPU・GPU 間で大きく速度差は無いようです。
話は逸れますが、近年の新しい NVIDIA GPU において TensorCore や半精度での計算が猛プッシュされる背景には、計算速度が帯域幅をどれだけ稼げるかにかかっている面があるためだと言えます。TPU も計算は bfloat16 という半精度浮動小数点演算で行うようになっており、これからの深層学習においては混合精度学習がデフォルトとなる時代がそのうちくるのかも知れません。

TPU on GCP

Cloud TPU では、Cloud TPU VM と Cloud TPU Pod の 2 つの提供形態があります。
この 2 つの違いについては、公式ドキュメントに書いてあります。

ここで、TRC ではあくまで TPU のみについての貸し出しになるので、GCP 上で別に CPU やメモリ・ストレージなどを使った分についてはユーザー負担となります。TPU Pod を使う場合、このように1 つのユーザー VM から、gRPC で各 TPU ノードを制御する形になっているため、このユーザー VM の料金については負担する必要があります。一方、TPU VM は、TPU デバイスに物理的につながっている VM が使用可能であり、この VM に積まれている高性能な CPU・大容量のメモリ・ある程度のストレージは TPU の貸出に含まれるようなので、特に料金を負担せずに使うことができます。よって、複数の TPU ノードを使いたいというようなケース以外では TPU VM を用いるのが良いと思います。

自分は、学習ログや学習データについてはCloud Storageで保存していたので、Cloud Storage 代で約 400 円だけかかりました。
Cloud Storage を使うことで、前処理済みのデータを複数 TPU VM で共有できるため、かなり便利です。Cloud Storage バケットとの接続はこのドキュメントにかかれています。
バケットの操作はgsutilを使ってやることもできますが、Cloud Storage Fuseを使ってマウントしてしまったほうが楽です。

TPU を使う

以下のドキュメントやチュートリアルに従えば簡単に使えると思います。

注意点としては、割り当てられる TPU のゾーンは決まっているため、それ以外のゾーンの TPU を使ってしまうと料金がかかってしまいます。
以下のように、v3、v2 によってゾーンも違うため、gcloud compute tpus tpu-vm createする場合は--zone=zoneのゾーン指定に注意してください。

TPUの割り当て

リソースが無い場合は割当に失敗してしまう場合が多々あります。特に TPU v3 は大人気のようで、There is no more capacity in the zone \"europe-west4-a\"をめちゃくちゃ見ることになると思います。祈りながらcreateを連打してください。

TRC・TPU を使う上で、調べても原因などが出ずに困惑したのは上のThere is no more capacity in the zone \"europe-west4-a\ぐらいで、それ以外は公式ドキュメントを読めば普通に使うことができました。

JAX/Flax

今回、TPU と併せて使っていたのが Google が開発している深層学習フレームワークのFlaxになります。Flax は、自動微分ライブラリのJAXに基づいています。JAX は、あくまで自動微分ライブラリなため、深層学習モデルを構築・学習するのに必要なことを Flax が担っている形になります。

JAX/Flax の特徴として、XLA というドメイン固有コンパイラを用いた jit によって学習が非常に速いことが挙げられます。JAX/Flax が Google によって作られたものであるため TPU とも相性がよく、PyTorch や Tensorflow では TPU 上で動かすにはいくつか変更が必要ですが、Flax はほぼ変更不要です。その他にも、関数型言語の影響を受けているため書き方も PyTorch や Tensorflow とは異なったものとなっていたり、vmapなど便利な機能があるなど、新進気鋭の深層学習フレームワークです。

詳細な紹介は、公式のJAX QuickstartFlax Getting Started、ブログ記事などに任せてここでは TRC で使った中での感想を書きたいと思います。

TPU・JAX/Flax の良いところ・厳しいところ

良いところ

やはり第一に JAX/Flax・TPU 共に速いことが挙げられます。学習を短くできるのはコスト削減にも繋がりますし、トライアンドエラーの回数も増やせるので非常にありがたいです。
TPU はバッチサイズをなるべく大きくする必要はありますが、TPU v3 で概ね混合精度の A100 と同等以上の速さで動いていました。比較対象の A100 が研究室のサーバーなので、DGX などであれば A100 は更にはやい可能性はあります。
上でも書きましたが、TPU は半精度浮動小数点で計算するようになっているため非常に速いです。現在の GPU では float16 を用いた混合精度が主流ですが、loss のスケーリングが必要であったり NaN が出てしまうなどがありました。実際、今回も GPU でのデバッグ時に混合精度だと NaN が出てしまっていました。TPU はダイナミックレンジが大きい bfloat16 で計算するため、特に混合精度が原因の NaN は出ること無く学習が可能でした。
JAX/Flax についても、1epoch 目は jit が走るため遅くなってしまいますが、それ以降は圧倒的な速さで動いてくれます。今回使ったコードはデバッグのため GPU 上でも動かしていましたが、GPU の使用率も非常に高いまま安定して動いていました。同じコードを動かしたことはありませんが、似たようなコードでは PyTorch では GPU 使用率 50~90%をよく行き来したりするのに対し、JAX/Flax では GPU 使用率はほぼ 99%に張り付いており結構感動しました。
ただし、CPU においてはその限りではありません。XLA は主に GPU・TPU への最適化に重きを置いているため、CPU においては必ず速くなるとは限りません。何なら jit コンパイルが終わらないこともあります。

これは JAX/Flax に対してですが、JAX/Flax はフレームワークとして後発ということもあり、主観的な設計や書き方などが洗練されてるように思います。numpy 互換の API がありながら、自動ベクトル化のjax.vmapやデバイス並列化のjax.pmapなどの優れた機能があり、パラメーターも別に変数として扱うことによって内部でのような処理をしているか悩むといったことがありません。特に特徴的なのが乱数の扱いで、JAX では明示的な乱数生成器を利用することで、再現性を保証しています。PyTorch でも JAX をインスパイアしたfunctorchが開発されており、JAX/Flax の設計はかなり良いものであると思います。

TPU については、上で書いた通りドキュメントを読めば一通り動かせるようになっています。特にドライバーや CUDA のようなライブラリのセットアップは無いため、GPU をクラウドで借りるときと同じような感覚で使えるようになっています。ライブラリによっては、TPU 用にコードを変更する必要があったりもしますが、JAX/Flax であればjax.pmapを加えるだけでほぼそのまま動きます。

JAX/Flax についても、公式ドキュメントが非常によく整備されています。

分からないことが出てきたときは、公式ドキュメントを探せば大概見つかります。JAX - The Sharp Bitsは PyTorch などとは異なる JAX 特有の特徴や、NaN のデバッグ方法などが紹介されており、導入のための良いガイドとなっています。

厳いところ

TPU を TRC によって今回は無料で使うことができましたが、TPU v3-8 であれば 1 時間 8 ドル、プリエンプティブで 2.4 ドルが必要になります。プリエンプティブで 5 日間使えたとしても、約 300 ドルとなるため、デバッグや試行錯誤に使うのはなかなか厳しいと感じました。これでも V100 や A100 よりは安いようですが、検証段階では Tesla T4 といった安めの GPU インスタンスやオンプレマシンを使うのが良いと思いました。TPU は、ファインチューニングといった繰り返し何回も学習するような場合や、うまく行った小さいモデルをスケールアップして巨大なモデルで長時間学習するような場合に向いているでしょう。

公式ドキュメントはよく整備されているものの、PyTorch with GPU に比べると圧倒的に公式以外の情報は少ないです。そのため、公式ドキュメントでは解決できなかった問題は自力でどうにかすることがかなり求められました。例えば今回、モデルの途中の matmul の演算で NaN が出る現象について、NaN デバッグしても解決せず、いろいろ調べても情報は見つかりませんでした。この原因としては学習の初期では計算が不安定だったようで、はじめの数 epoch を matmul で行っていた部分を等価な処理に置き換えることで NaN は出なくなりました。同じコードを GPU で計算した場合は特にこの現象は見られなかったため、ライブラリやデバイスがまだ発展途上の部分もあると感じました。しかし、公式の例で様々なモデルが追実装されてうまく動いているため、大体のケースではこのような問題には出くわさないと思います。
また、新参のライブラリなため論文の著者実装としては PyTorch などに比べるとやはり用いられていません。そのため、新しく出た論文のアイデアなどを自力で JAX/Flax に移植することが求められる場面は多いと思います。

Flax は関数型ベースなため、例えばパラメーターについてもインスタンス変数の形では保持しません。何をやっているのかが明確な一方、裏を返せばライブラリ側でいい感じにやってくれないということでもあります。例えば、Batch Normalization であれば PyTorch で意識するのはせいぜい.eval().train()程度ですが、Flax ではバッチ統計量も変数として自分で管理し、学習で得られたバッチ統計量を使用するのか、そのバッチで求めたものを使うのかといった引数なども自分で用意する必要があります。何をやっているのかが分かりやすいといったメリットはありますが、ボイラープレートを省略できないのは面倒かもしれません。

JAX/Flax は jit があるので速いのですが、当然コンパイルできるようなモデルを書いてあげる必要があります。例えば、以下のようなコードはエラーとなります。

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

また、入力の shape が変わったりするとその度に jit が走ってしまうため、NLP ではバッチの中の最大長ではなく、データセットの中の最大長で padding してあげるようにするなど、jit の特性を理解したコードを書くことが求められます。Tensorflow v1 を書いたことがある人にとっては、懐かしい感覚を覚えるかもしれません。
これらの挙動についても以下の公式ドキュメントで詳しく書いてあるので、参考にすると良いと思います。

また、余程変わった処理を書かない限りはなかなか起きませんが、jit のコンパイルに非常に時間がかかるということも起きます。例えば、JAX において for は単純に展開されてしまうため、なるべくベクトルやテンソル単位での処理に置き換えることが望ましいです。
例えば、以下の 2 つの関数は同じ計算結果になりますが、JAX の中間表現は x と y が長さ 128 の場合、A が 5122 行に対して B が 1 行になります。

def A(x, y):
    for i in range(x.shape[0]):
        x = x.at[i].set(x[i] + y[i])

    return x

def B(x, y):
    return x + y

変わった処理をしない限りそこまで問題になることは無いとは思いますが、頭の片隅で意識はしておいたほうが無難です。
今回自分がやったことの中では変わった処理が必要だったので、できるだけ最適化しても、CPU においては jit コンパイルが終わらず、GPU や TPU では 10~15 分程度 jit のコンパイル時間がかかりました。

一応、あまりいないとは思いますが jit が終わらなくなった場合の調査方法について書いておきます。
以下のtrain_step関数の jit に非常に時間がかかるとします。

@jax.jit
def train_step(batch):
    y = model_function(batch)
    loss = loss_function(y)

    return loss

batch = hoge
train_step(batch) # <= jitが走る

jax.make_jaxprは、JAX の中間表現を出力してくれるので、その内容をダンプします。

@jax.jit
def train_step(batch):
    y = model_function(batch)
    loss = loss_function(y)

    return loss

batch = hoge
print(jax.make_jaxpr(train_step)(batch))
# train_step(batch) # <= jitが走る

その内容を vscode などで開き、以下のようにアウトラインで見ると一定のパターンが繰り返されているような箇所があります。

jaxの中間表現のアウトライン

その部分のxla_callなど目立つシンボルを見ると、以下のように同じような処理が繰り返されています。

    detk:f32[4,1210] detl:f32[] detm:f32[] detn:f32[] deto:f32[] = xla_call[
      ** 中略 **
        in (deus, 0.0, 0.0, 0.0, 0.0) }
      name=jvp(vmap(jvp(_pad)))
    ] detj inf
      ** 中略 **
    deuz:f32[4,1210] deva:f32[] devb:f32[] devc:f32[] devd:f32[] = xla_call[
      ** 中略 **
        in (dewh, 0.0, 0.0, 0.0, 0.0) }
      name=jvp(vmap(jvp(_pad)))
    ] deuy inf

この場合は、xla_callで繰り返しjvp(vmap(jvp(_pad)))が呼び出されているので、この処理はどこかの for ループの中のpadの処理によって、不必要に長い中間表現が生成されてしまっていると見当をつけることができます。

最初に自分が単純に書いたコードだと、中間表現は約 40 万行ほどになっていました。それを最適化して約 9 万行まで減ったので、jit が遅すぎたり終わらなくなった場合には一度 JAX が吐いている中間表現を見てあげると良いと思います。

JAX/Flax には、そもそもサポートしない機能やまだ実装されていない機能などがあります。
例えば、データ入出力について公式ドキュメントで「The world doesn’t need yet another data loading library」といっており、PyTorch の DataLoader や Tensorflow の tf.data などを用います。また、エッジ方面も主に Tensorflow の形式にモデルを変換することで行うなど、あくまで JAX は自動微分ライブラリ、Flax は深層学習モデルを JAX で学習させるためのライブラリという割り切りがあるように思います。
一つのライブラリで全てを終わらせたかったり、他の余計なライブラリを入れたくないという場合には向いていないかもしれません。

まとめ

ここまで、TPU・JAX/Flax を使った中で感じた良いところや厳しいところを挙げてきましたが、最後にまとめとしてどういう人に TPU・JAX/Flax が向いているかを書いておこうと思います。

TPU に向いている人

JAX/Flax に向いている人

これらの向いている人の条件をバッチリ Google の社員は満たしていそうであり、設計思想や向き・不向きなどがものすごく腑に落ちました。

感想

今回、TRC で TPU・JAX/Flax について一通り触れてみましたが、とても良かったです。TPU は予算的に厳しいですが、JAX/Flax はこれからも使っていこうと思います。
また、TRC はあまり知名度が無いのか、情報やブログなども少なかったです。個人ではなかなか用意するのが難しい豊富な計算資源を使って大きめなモデルも学習できる機会としてもっと広まると良いと思いました。
最後になりますが、今回の貴重な機会を頂けたお礼を TRC team へ申し上げます。