[開発者向け]モデルから開発ログまで全て公開!スタンフォード「Marin」プロジェクトとJAXが拓く「オープン開発」の時代

目次

はじめに

 本稿では、Google Developers Blogに掲載された記事「Stanford’s Marin foundation model: The first fully open model developed using JAX」を主な情報源として、AI開発の透明性と再現性に新たな基準を打ち立てる、スタンフォード大学の画期的なプロジェクト「Marin」について詳しく解説します。

 このプロジェクトは、単にAIモデルを公開するだけでなく、その開発プロセス全体をオープンにする「オープン開発」という新しいアプローチを提唱しています。この野心的な試みを技術的に実現可能にしたのが、Googleが開発している機械学習フレームワーク「JAX」です。

参考記事

公式

要点

  • スタンフォード大学のMarinプロジェクトは、AIモデル、コード、データ、実験ログなど、開発の全プロセスを公開する「オープン開発」を提唱するものである。
  • このアプローチの目的は、AI研究における完全な再現性と科学的な透明性を確保し、研究者が互いの成果を検証し、その上で新たな研究を構築できる環境を育むことである。
  • プロジェクトの成功の鍵は、Googleの機械学習フレームワークJAXである。JAXが持つ「高速性」「大規模な並列処理能力」「ビット単位での再現性保証」といった特性が、この困難な挑戦を技術的に支えている。
  • このプロジェクトから生まれた最初の成果が「Marin-8B」モデルであり、開発に関わる全てが、商用利用も可能なApache 2.0ライセンスの下で公開されている。

詳細解説

「オープン」の新しい地平線:オープン開発とは?

 これまでAIの世界では、「オープンソース」や「オープンウェイト」といった言葉が使われてきました。オープンソースはソースコードが公開され、オープンウェイトは学習済みのモデル(重み)が公開されます。しかし、これらのモデルが「どのようにして作られたか」という訓練データや詳細な開発プロセスは、多くの場合ブラックボックスのままでした。

 Marinプロジェクトが提唱する「オープン開発」は、このレベルをさらに一歩進めるものです。以下の表のように、モデルの重みやソースコードだけでなく、どのようなデータで、どのような試行錯誤を経てそのモデルが完成したのか、その全行程をリアルタイムに近い形で公開します。これにより、第三者がその結果を完全に再現し、科学的な検証を行うことが可能になります。これは、AI開発における透明性と信頼性を飛躍的に高める試みです。

公開レベルソースコードモデルの重み訓練データ開発プロセスライセンス
オープン開発利用可能利用可能利用・文書化議論、成功、失敗がリアルタイムで公開寛容(Apache 2.0)
オープンソース利用可能利用可能利用・文書化レポートやリリースが時折公開寛容(Apache 2.0)
オープンウェイト限定的/不可利用可能利用不可様々(制限ありも)
クローズドソース利用不可利用不可利用不可プロプライエタリ

なぜJAXが選ばれたのか? Marinプロジェクトの技術的挑戦

 開発プロセスの完全な再現性を保証しながら、最先端の基盤モデルを効率的に開発するには、いくつかの重大な技術的課題を乗り越える必要がありました。Marinチームは、これらの課題を解決する上でJAXが最適なツールであると判断し、JAXを基盤とした新フレームワーク「Levanter」を構築しました。

1. 処理速度の最大化と効率化

  • 課題: AIの訓練では、膨大な計算を何度も繰り返します。Pythonのようなインタプリタ言語のままでは処理のオーバーヘッドが大きく、性能のボトルネックになります。
  • JAXによる解決策: JAXは@jax.jitという機能(デコレータ)を提供します。これにより、一連の計算処理(順伝播、損失計算、逆伝播、パラメータ更新)をまとめてコンパイルし、最適化された単一の機械語カーネルに変換します。これにより、Pythonのオーバーヘッドが解消され、ハードウェア(特にTPU)の性能を限界まで引き出すことができます。

2. 大規模並列処理の複雑さを克服

  • 課題: 基盤モデルの訓練には、数千個ものアクセラレータ(TPUやGPU)を同時に使う必要があります。どのデータをどのチップに割り当て、どのように通信させるかを手動で管理するのは非常に複雑で、コードが難解になる原因でした。
  • JAXによる解決策: JAXは、複数のデバイスで同じプログラムを並列実行するSPMD(Single-Program, Multiple-Data)という手法を自然にサポートします。開発者は複雑な通信処理を意識することなく、ロジックの記述に集中できます。さらに、Marinチームが開発したHaliaxというライブラリを使うことで、テンソルの次元に「バッチ」や「埋め込み」といった具体的な名前を付けられるようになり、コードの可読性と安全性が劇的に向上しました。

3. コスト効率と耐障害性の高いクラスタ管理

  • 課題: 大規模な計算リソースを常に確保するのは高コストです。そのため、より安価な「プリエンプティブル(中断される可能性のある)インスタンス」を有効活用する必要がありましたが、いつ中断されるか分からない環境で安定して訓練を続けるのは困難でした。
  • JAXとクラウド技術による解決策: Google CloudのTPU Multislice技術を活用し、物理的に離れた複数のTPUを論理的に一つの巨大なクラスタとして扱えるようにしました。さらに、オーケストレーションツールRayを組み合わせることで、一部のTPUが中断されても訓練ジョブ全体が停止することなく、自動的に復旧し、継続できる堅牢なシステムを構築しました。

4. 科学的信頼性の根幹:完全な再現性の実現

  • 課題: Marinプロジェクトの核となる目標は、誰でも結果を検証できる科学を実現することです。そのためには、訓練を中断・再開したり、異なるマシン構成に移行したりしても、計算結果がビット単位で完全に一致する必要がありました。
  • JAXによる解決策: JAXは、その設計思想の段階から決定論的な動作を重視しています。例えば、乱数生成器もデフォルトで再現可能なものが使われます。この強力な再現性保証のおかげで、Marin-8Bの訓練中、異なる種類のTPUハードウェア間を移行しながらも、ビット単位での完全な再現性を維持することに成功しました。

Marin-8Bモデルの航海

 これらの技術基盤の上に構築されたのが、80.3億パラメータを持つLlamaスタイルのトランスフォーマーモデル「Marin-8B」です。その訓練プロセスは「Tootsie」と名付けられ、一直線に進んだわけではありません。12兆トークンを超える学習の過程で、より質の高いデータを組み込んだり、学習率などのハイパーパラメータを調整したりと、試行錯誤を繰り返しながら適応的に進められました

 この「 messy(ごちゃごちゃした) 」とも言える現実のプロセスをありのままに公開すること自体が、後続の研究者にとって非常に価値のある教材となります。そして、このような複雑な変更を経てもなお、JAXとLevanterのスタックが完全な再現性を保ったという事実は、その技術の堅牢性を何よりも雄弁に物語っています。

まとめ

 本稿では、スタンフォード大学のMarinプロジェクトと、それを支えるGoogleのJAXについて解説しました。

 Marinプロジェクトは、AIモデルの開発プロセス全体を公開する「オープン開発」という野心的なビジョンを掲げ、AI研究における透明性と再現性の新たなスタンダードを提示しています。この試みは、誰もが結果を検証し、信頼し、その上に新たな知を積み重ねていける、より健全なエコシステムの構築に向けた重要な一歩です。

 そして、そのビジョンを実現可能にしたのが、JAXの卓越した性能、スケーラビリティ、そして再現性への強いこだわりでした。この「開かれた研究室」への参加は、誰にでも推奨されています。Marinのモデルを利用すること、研究に貢献すること、そしてJAXエコシステム全体を盛り上げていくことで、より透明で信頼性の高いAIの未来を共に築いていくことができるといえるでしょう。

この記事が気に入ったら
フォローしてね!

  • URLをコピーしました!
  • URLをコピーしました!
目次