2022/08/10

TPU Research Cloud体験記 with JAX/Flax

30日間TPUを使うことができるTPU Research Cloud (TRC)に参加していたので、TRCや使っていたJAX/Flaxについて書いておこうと思います。

TPU Research Cloud (TRC)

TRCは、もともとはTensorflow Research Cloudと呼ばれていたもので、Googleの所有するクラウドTPUを学生や研究者などに一定期間使わせてくれるものです。

TRCへの参加は、ホームページにある申請フォームから申請が可能です。
自分の場合は、フォームに「JAXでモデルの追実装をやります」のような感じで申請したところ、特に追加で作文などはなく7~8時間程度で承認メールが来ました。

その後、GCPでTRC用のプロジェクトを作り、プロジェクトIDをメールで送信すればTRC開始となります。 貸し出されるTPUはGCPで使う形なため、ある程度GCPについて知っているほうがスムーズに使えるかもしれません。しかし、公式ドキュメントがよくまとまっているので全く知らない場合でも大丈夫だと思います。

今回、TRCではTPUとJAX/Flaxを使ってVITSJVSで学習して、日本語TTSを学習させていました。 以下にサンプル音声を貼っておきます。

学習データ量がやはり少ないので、濁音や外来語などの品質が低いようです。しかし、十分な品質で再現できたと思います。

TPU

簡単にTPUについて説明しておきます。TPUはTensor Processing Unitの略で、行列演算に特化したアクセラレーターです。シストリックアレイによって効率的に行列計算を行えるため、行列計算が大部分を占める深層学習に向いています。GPUと比べるとTPUのほうが演算自体は速いようですが、結局現在の計算処理においてはデータIOが無視できないため、概ねTPU・GPU間で大きく速度差は無いようです。
話は逸れますが、近年の新しいNVIDIA GPUにおいてTensorCoreや半精度での計算が猛プッシュされる背景には、計算速度が帯域幅をどれだけ稼げるかにかかっている面があるためだと言えます。TPUも計算はbfloat16という半精度浮動小数点演算で行うようになっており、これからの深層学習においては混合精度学習がデフォルトとなる時代がそのうちくるのかも知れません。

TPU on GCP

Cloud TPUでは、Cloud TPU VMとCloud TPU Podの2つの提供形態があります。 この2つの違いについては、公式ドキュメントに書いてあります。

ここで、TRCではあくまでTPUのみについての貸し出しになるので、GCP上で別にCPUやメモリ・ストレージなどを使った分についてはユーザー負担となります。TPU Podを使う場合、このように1つのユーザーVMから、gRPCで各TPUノードを制御する形になっているため、このユーザーVMの料金については負担する必要があります。一方、TPU VMは、TPUデバイスに物理的につながっているVMが使用可能であり、このVMに積まれている高性能なCPU・大容量のメモリ・ある程度のストレージはTPUの貸出に含まれるようなので、特に料金を負担せずに使うことができます。よって、複数のTPUノードを使いたいというようなケース以外ではTPU VMを用いるのが良いと思います。

自分は、学習ログや学習データについてはCloud Storageで保存していたので、Cloud Storage代で約400円だけかかりました。 Cloud Storageを使うことで、前処理済みのデータを複数TPU VMで共有できるため、かなり便利です。Cloud Storage バケットとの接続はこのドキュメントにかかれています。
バケットの操作はgsutilを使ってやることもできますが、Cloud Storage Fuseを使ってマウントしてしまったほうが楽です。

TPUを使う

以下のドキュメントやチュートリアルに従えば簡単に使えると思います。

注意点としては、割り当てられるTPUのゾーンは決まっているため、それ以外のゾーンのTPUを使ってしまうと料金がかかってしまいます。 以下のように、v3、v2によってゾーンも違うため、gcloud compute tpus tpu-vm createする場合は--zone=zoneのゾーン指定に注意してください。

TPUの割り当て

リソースが無い場合は割当に失敗してしまう場合が多々あります。特にTPU v3は大人気のようで、There is no more capacity in the zone \"europe-west4-a\"をめちゃくちゃ見ることになると思います。祈りながらcreateを連打してください。

TRC・TPUを使う上で、調べても原因などが出ずに困惑したのは上のThere is no more capacity in the zone \"europe-west4-a\ぐらいで、それ以外は公式ドキュメントを読めば普通に使うことができました。

JAX/Flax

今回、TPUと併せて使っていたのがGoogleが開発している深層学習フレームワークのFlaxになります。Flaxは、自動微分ライブラリのJAXに基づいています。JAXは、あくまで自動微分ライブラリなため、深層学習モデルを構築・学習するのに必要なことをFlaxが担っている形になります。

JAX/Flaxの特徴として、XLAというドメイン固有コンパイラを用いたjitによって学習が非常に速いことが挙げられます。JAX/FlaxがGoogleによって作られたものであるためTPUとも相性がよく、PyTorchやTensorflowではTPU上で動かすにはいくつか変更が必要ですが、Flaxはほぼ変更不要です。その他にも、関数型言語の影響を受けているため書き方もPyTorchやTensorflowとは異なったものとなっていたり、vmapなど便利な機能があるなど、新進気鋭の深層学習フレームワークです。

詳細な紹介は、公式のJAX QuickstartFlax Getting Started、ブログ記事などに任せてここではTRCで使った中での感想を書きたいと思います。

TPU・JAX/Flaxの良いところ・厳しいところ

良いところ

  • 速度

やはり第一にJAX/Flax・TPU共に速いことが挙げられます。学習を短くできるのはコスト削減にも繋がりますし、トライアンドエラーの回数も増やせるので非常にありがたいです。
TPUはバッチサイズをなるべく大きくする必要はありますが、TPU v3で概ね混合精度のA100と同等以上の速さで動いていました。比較対象のA100が研究室のサーバーなので、DGXなどであればA100は更にはやい可能性はあります。
上でも書きましたが、TPUは半精度浮動小数点で計算するようになっているため非常に速いです。現在のGPUではfloat16を用いた混合精度が主流ですが、lossのスケーリングが必要であったりNaNが出てしまうなどがありました。実際、今回もGPUでのデバッグ時に混合精度だとNaNが出てしまっていました。TPUはダイナミックレンジが大きいbfloat16で計算するため、特に混合精度が原因のNaNは出ること無く学習が可能でした。
JAX/Flaxについても、1epoch目はjitが走るため遅くなってしまいますが、それ以降は圧倒的な速さで動いてくれます。今回使ったコードはデバッグのためGPU上でも動かしていましたが、GPUの使用率も非常に高いまま安定して動いていました。同じコードを動かしたことはありませんが、似たようなコードではPyTorchではGPU使用率50~90%をよく行き来したりするのに対し、JAX/FlaxではGPU使用率はほぼ99%に張り付いており結構感動しました。
ただし、CPUにおいてはその限りではありません。XLAは主にGPU・TPUへの最適化に重きを置いているため、CPUにおいては必ず速くなるとは限りません。何ならjitコンパイルが終わらないこともあります。

  • 設計

これはJAX/Flaxに対してですが、JAX/Flaxはフレームワークとして後発ということもあり、主観的な設計や書き方などが洗練されてるように思います。numpy互換のAPIがありながら、自動ベクトル化のjax.vmapやデバイス並列化のjax.pmapなどの優れた機能があり、パラメーターも別に変数として扱うことによって内部でのような処理をしているか悩むといったことがありません。特に特徴的なのが乱数の扱いで、JAXでは明示的な乱数生成器を利用することで、再現性を保証しています。PyTorchでもJAXをインスパイアしたfunctorchが開発されており、JAX/Flaxの設計はかなり良いものであると思います。

  • 公式ドキュメント

TPUについては、上で書いた通りドキュメントを読めば一通り動かせるようになっています。特にドライバーやCUDAのようなライブラリのセットアップは無いため、GPUをクラウドで借りるときと同じような感覚で使えるようになっています。ライブラリによっては、TPU用にコードを変更する必要があったりもしますが、JAX/Flaxであればjax.pmapを加えるだけでほぼそのまま動きます。

JAX/Flaxについても、公式ドキュメントが非常によく整備されています。

分からないことが出てきたときは、公式ドキュメントを探せば大概見つかります。JAX - The Sharp BitsはPyTorchなどとは異なるJAX特有の特徴や、NaNのデバッグ方法などが紹介されており、導入のための良いガイドとなっています。

厳いところ

  • 高い

TPUをTRCによって今回は無料で使うことができましたが、TPU v3-8であれば1時間8ドル、プリエンプティブで2.4ドルが必要になります。プリエンプティブで5日間使えたとしても、約300ドルとなるため、デバッグや試行錯誤に使うのはなかなか厳しいと感じました。これでもV100やA100よりは安いようですが、検証段階ではTesla T4といった安めのGPUインスタンスやオンプレマシンを使うのが良いと思いました。TPUは、ファインチューニングといった繰り返し何回も学習するような場合や、うまく行った小さいモデルをスケールアップして巨大なモデルで長時間学習するような場合に向いているでしょう。

  • 公式以外のドキュメントや実装が少ない

公式ドキュメントはよく整備されているものの、PyTorch with GPUに比べると圧倒的に公式以外の情報は少ないです。そのため、公式ドキュメントでは解決できなかった問題は自力でどうにかすることがかなり求められました。例えば今回、モデルの途中のmatmulの演算でNaNが出る現象について、NaNデバッグしても解決せず、いろいろ調べても情報は見つかりませんでした。この原因としては学習の初期では計算が不安定だったようで、はじめの数epochをmatmulで行っていた部分を等価な処理に置き換えることでNaNは出なくなりました。同じコードをGPUで計算した場合は特にこの現象は見られなかったため、ライブラリやデバイスがまだ発展途上の部分もあると感じました。しかし、公式の例で様々なモデルが追実装されてうまく動いているため、大体のケースではこのような問題には出くわさないと思います。
また、新参のライブラリなため論文の著者実装としてはPyTorchなどに比べるとやはり用いられていません。そのため、新しく出た論文のアイデアなどを自力でJAX/Flaxに移植することが求められる場面は多いと思います。

  • ライブラリ側で勝手にいい感じにはやってくれない

Flaxは関数型ベースなため、例えばパラメーターについてもインスタンス変数の形では保持しません。何をやっているのかが明確な一方、裏を返せばライブラリ側でいい感じにやってくれないということでもあります。例えば、Batch NormalizationであればPyTorchで意識するのはせいぜい.eval().train()程度ですが、Flaxではバッチ統計量も変数として自分で管理し、学習で得られたバッチ統計量を使用するのか、そのバッチで求めたものを使うのかといった引数なども自分で用意する必要があります。何をやっているのかが分かりやすいといったメリットはありますが、ボイラープレートを省略できないのは面倒かもしれません。

  • コンパイラの気持ちになる必要がある

JAX/Flaxはjitがあるので速いのですが、当然コンパイルできるようなモデルを書いてあげる必要があります。例えば、以下のようなコードはエラーとなります。

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

また、入力のshapeが変わったりするとその度にjitが走ってしまうため、NLPではバッチの中の最大長ではなく、データセットの中の最大長でpaddingしてあげるようにするなど、jitの特性を理解したコードを書くことが求められます。Tensorflow v1を書いたことがある人にとっては、懐かしい感覚を覚えるかもしれません。
これらの挙動についても以下の公式ドキュメントで詳しく書いてあるので、参考にすると良いと思います。

また、余程変わった処理を書かない限りはなかなか起きませんが、jitのコンパイルに非常に時間がかかるということも起きます。例えば、JAXにおいてforは単純に展開されてしまうため、なるべくベクトルやテンソル単位での処理に置き換えることが望ましいです。
例えば、以下の2つの関数は同じ計算結果になりますが、JAXの中間表現はxとyが長さ128の場合、Aが5122行に対してBが1行になります。

def A(x, y):
    for i in range(x.shape[0]):
        x = x.at[i].set(x[i] + y[i])

    return x

def B(x, y):
    return x + y

変わった処理をしない限りそこまで問題になることは無いとは思いますが、頭の片隅で意識はしておいたほうが無難です。
今回自分がやったことの中では変わった処理が必要だったので、できるだけ最適化しても、CPUにおいてはjitコンパイルが終わらず、GPUやTPUでは10~15分程度jitのコンパイル時間がかかりました。

一応、あまりいないとは思いますがjitが終わらなくなった場合の調査方法について書いておきます。
以下のtrain_step関数のjitに非常に時間がかかるとします。

@jax.jit
def train_step(batch):
    y = model_function(batch)
    loss = loss_function(y)

    return loss

batch = hoge
train_step(batch) # <= jitが走る

jax.make_jaxprは、JAXの中間表現を出力してくれるので、その内容をダンプします。

@jax.jit
def train_step(batch):
    y = model_function(batch)
    loss = loss_function(y)

    return loss

batch = hoge
print(jax.make_jaxpr(train_step)(batch))
# train_step(batch) # <= jitが走る

その内容をvscodeなどで開き、以下のようにアウトラインで見ると一定のパターンが繰り返されているような箇所があります。

jaxの中間表現のアウトライン

その部分のxla_callなど目立つシンボルを見ると、以下のように同じような処理が繰り返されています。

    detk:f32[4,1210] detl:f32[] detm:f32[] detn:f32[] deto:f32[] = xla_call[
      ** 中略 **
        in (deus, 0.0, 0.0, 0.0, 0.0) }
      name=jvp(vmap(jvp(_pad)))
    ] detj inf
      ** 中略 **
    deuz:f32[4,1210] deva:f32[] devb:f32[] devc:f32[] devd:f32[] = xla_call[
      ** 中略 **
        in (dewh, 0.0, 0.0, 0.0, 0.0) }
      name=jvp(vmap(jvp(_pad)))
    ] deuy inf

この場合は、xla_callで繰り返しjvp(vmap(jvp(_pad)))が呼び出されているので、この処理はどこかのforループの中のpadの処理によって、不必要に長い中間表現が生成されてしまっていると見当をつけることができます。

最初に自分が単純に書いたコードだと、中間表現は約40万行ほどになっていました。それを最適化して約9万行まで減ったので、jitが遅すぎたり終わらなくなった場合には一度JAXが吐いている中間表現を見てあげると良いと思います。

  • 他のライブラリに依存する面がある

JAX/Flaxには、そもそもサポートしない機能やまだ実装されていない機能などがあります。 例えば、データ入出力について公式ドキュメントで「The world doesn’t need yet another data loading library」といっており、PyTorchのDataLoaderやTensorflowのtf.dataなどを用います。また、エッジ方面も主にTensorflowの形式にモデルを変換することで行うなど、あくまでJAXは自動微分ライブラリ、Flaxは深層学習モデルをJAXで学習させるためのライブラリという割り切りがあるように思います。
一つのライブラリで全てを終わらせたかったり、他の余計なライブラリを入れたくないという場合には向いていないかもしれません。

まとめ

ここまで、TPU・JAX/Flaxを使った中で感じた良いところや厳しいところを挙げてきましたが、最後にまとめとしてどういう人にTPU・JAX/Flaxが向いているかを書いておこうと思います。

TPUに向いている人

  • ファインチューニングなどの同じ学習を繰り返し行う必要のある人
  • 大きめのモデルを効率よく学習させたい人

JAX/Flaxに向いている人

  • なるべく速くモデルを学習させたい人
  • TPUを使いたい人
  • モダンな書き方で書きたい人
  • 関数型プログラミングに抵抗がない人
  • 実装力のある人
  • エッジというよりは手元やサーバーで推論させる人
  • 内部実装に詳しい人に手軽に質問できる人

これらの向いている人の条件をバッチリGoogleの社員は満たしていそうであり、設計思想や向き・不向きなどがものすごく腑に落ちました。

感想

今回、TRCでTPU・JAX/Flaxについて一通り触れてみましたが、とても良かったです。TPUは予算的に厳しいですが、JAX/Flaxはこれからも使っていこうと思います。
また、TRCはあまり知名度が無いのか、情報やブログなども少なかったです。個人ではなかなか用意するのが難しい豊富な計算資源を使って大きめなモデルも学習できる機会としてもっと広まると良いと思いました。
最後になりますが、今回の貴重な機会を頂けたお礼をTRC teamへ申し上げます。