├── .github ├── dependabot.yml └── workflows │ ├── lint.yaml │ ├── nightly.yaml │ ├── release.yml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── RELEASING.md ├── docs ├── requirements.txt └── source │ ├── JAX_Vision_transformer.ipynb │ ├── JAX_Vision_transformer.md │ ├── JAX_basic_text_classification.ipynb │ ├── JAX_basic_text_classification.md │ ├── JAX_examples_image_segmentation.ipynb │ ├── JAX_examples_image_segmentation.md │ ├── JAX_for_LLM_pretraining.ipynb │ ├── JAX_for_LLM_pretraining.md │ ├── JAX_for_PyTorch_users.ipynb │ ├── JAX_for_PyTorch_users.md │ ├── JAX_image_captioning.ipynb │ ├── JAX_image_captioning.md │ ├── JAX_machine_translation.ipynb │ ├── JAX_machine_translation.md │ ├── JAX_porting_PyTorch_model.ipynb │ ├── JAX_porting_PyTorch_model.md │ ├── JAX_time_series_classification.ipynb │ ├── JAX_time_series_classification.md │ ├── JAX_transformer_text_classification.ipynb │ ├── JAX_transformer_text_classification.md │ ├── JAX_visualizing_models_metrics.ipynb │ ├── JAX_visualizing_models_metrics.md │ ├── _static │ ├── ai-stack-logo.svg │ ├── css │ │ └── custom.css │ ├── favicon.png │ └── images │ │ ├── compiler.svg │ │ ├── hardware.svg │ │ ├── hero-radial.svg │ │ ├── hero.svg │ │ ├── loss_acc_example.png │ │ ├── model_display_example.png │ │ ├── nnx_display_example.png │ │ ├── testsheet_start_example.png │ │ ├── testsheets_500_3000.png │ │ ├── training_data_example.png │ │ ├── unetr_architecture.png │ │ └── what-is-the-jax-ai-stack.svg │ ├── _templates │ └── navbar-top.html │ ├── blog.md │ ├── conf.py │ ├── contributing.md │ ├── data_loaders.md │ ├── data_loaders_on_cpu_with_jax.ipynb │ ├── data_loaders_on_cpu_with_jax.md │ ├── data_loaders_on_gpu_with_jax.ipynb │ ├── data_loaders_on_gpu_with_jax.md │ ├── digits_diffusion_model.ipynb │ ├── digits_diffusion_model.md │ ├── digits_vae.ipynb │ ├── digits_vae.md │ ├── events.md │ ├── getting_started.md │ ├── index.html │ ├── index.rst │ ├── install.md │ ├── more_learning.md │ ├── neural_net_basics.ipynb │ ├── neural_net_basics.md │ ├── news.md │ ├── pytorch_users.md │ ├── tutorials.md │ └── videos.md ├── jax_ai_stack ├── __init__.py └── tests │ ├── __init__.py │ ├── test_chex.py │ ├── test_nnx_with_grain.py │ ├── test_nnx_with_optax.py │ ├── test_nnx_with_orbax.py │ ├── test_nnx_with_tfds.py │ └── test_tf_with_orbax_export.py └── pyproject.toml /.github/dependabot.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.github/dependabot.yml -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.github/workflows/lint.yaml -------------------------------------------------------------------------------- /.github/workflows/nightly.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.github/workflows/nightly.yaml -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.github/workflows/release.yml -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.github/workflows/test.yaml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.gitignore -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.pre-commit-config.yaml -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/.readthedocs.yaml -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/CHANGELOG.md -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/README.md -------------------------------------------------------------------------------- /RELEASING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/RELEASING.md -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/requirements.txt -------------------------------------------------------------------------------- /docs/source/JAX_Vision_transformer.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_Vision_transformer.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_Vision_transformer.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_Vision_transformer.md -------------------------------------------------------------------------------- /docs/source/JAX_basic_text_classification.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_basic_text_classification.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_basic_text_classification.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_basic_text_classification.md -------------------------------------------------------------------------------- /docs/source/JAX_examples_image_segmentation.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_examples_image_segmentation.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_examples_image_segmentation.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_examples_image_segmentation.md -------------------------------------------------------------------------------- /docs/source/JAX_for_LLM_pretraining.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_for_LLM_pretraining.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_for_LLM_pretraining.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_for_LLM_pretraining.md -------------------------------------------------------------------------------- /docs/source/JAX_for_PyTorch_users.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_for_PyTorch_users.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_for_PyTorch_users.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_for_PyTorch_users.md -------------------------------------------------------------------------------- /docs/source/JAX_image_captioning.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_image_captioning.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_image_captioning.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_image_captioning.md -------------------------------------------------------------------------------- /docs/source/JAX_machine_translation.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_machine_translation.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_machine_translation.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_machine_translation.md -------------------------------------------------------------------------------- /docs/source/JAX_porting_PyTorch_model.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_porting_PyTorch_model.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_porting_PyTorch_model.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_porting_PyTorch_model.md -------------------------------------------------------------------------------- /docs/source/JAX_time_series_classification.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_time_series_classification.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_time_series_classification.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_time_series_classification.md -------------------------------------------------------------------------------- /docs/source/JAX_transformer_text_classification.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_transformer_text_classification.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_transformer_text_classification.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_transformer_text_classification.md -------------------------------------------------------------------------------- /docs/source/JAX_visualizing_models_metrics.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_visualizing_models_metrics.ipynb -------------------------------------------------------------------------------- /docs/source/JAX_visualizing_models_metrics.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/JAX_visualizing_models_metrics.md -------------------------------------------------------------------------------- /docs/source/_static/ai-stack-logo.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/ai-stack-logo.svg -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/css/custom.css -------------------------------------------------------------------------------- /docs/source/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/favicon.png -------------------------------------------------------------------------------- /docs/source/_static/images/compiler.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/compiler.svg -------------------------------------------------------------------------------- /docs/source/_static/images/hardware.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/hardware.svg -------------------------------------------------------------------------------- /docs/source/_static/images/hero-radial.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/hero-radial.svg -------------------------------------------------------------------------------- /docs/source/_static/images/hero.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/hero.svg -------------------------------------------------------------------------------- /docs/source/_static/images/loss_acc_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/loss_acc_example.png -------------------------------------------------------------------------------- /docs/source/_static/images/model_display_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/model_display_example.png -------------------------------------------------------------------------------- /docs/source/_static/images/nnx_display_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/nnx_display_example.png -------------------------------------------------------------------------------- /docs/source/_static/images/testsheet_start_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/testsheet_start_example.png -------------------------------------------------------------------------------- /docs/source/_static/images/testsheets_500_3000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/testsheets_500_3000.png -------------------------------------------------------------------------------- /docs/source/_static/images/training_data_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/training_data_example.png -------------------------------------------------------------------------------- /docs/source/_static/images/unetr_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/unetr_architecture.png -------------------------------------------------------------------------------- /docs/source/_static/images/what-is-the-jax-ai-stack.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_static/images/what-is-the-jax-ai-stack.svg -------------------------------------------------------------------------------- /docs/source/_templates/navbar-top.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/_templates/navbar-top.html -------------------------------------------------------------------------------- /docs/source/blog.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/blog.md -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/conf.py -------------------------------------------------------------------------------- /docs/source/contributing.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/contributing.md -------------------------------------------------------------------------------- /docs/source/data_loaders.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/data_loaders.md -------------------------------------------------------------------------------- /docs/source/data_loaders_on_cpu_with_jax.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/data_loaders_on_cpu_with_jax.ipynb -------------------------------------------------------------------------------- /docs/source/data_loaders_on_cpu_with_jax.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/data_loaders_on_cpu_with_jax.md -------------------------------------------------------------------------------- /docs/source/data_loaders_on_gpu_with_jax.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/data_loaders_on_gpu_with_jax.ipynb -------------------------------------------------------------------------------- /docs/source/data_loaders_on_gpu_with_jax.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/data_loaders_on_gpu_with_jax.md -------------------------------------------------------------------------------- /docs/source/digits_diffusion_model.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/digits_diffusion_model.ipynb -------------------------------------------------------------------------------- /docs/source/digits_diffusion_model.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/digits_diffusion_model.md -------------------------------------------------------------------------------- /docs/source/digits_vae.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/digits_vae.ipynb -------------------------------------------------------------------------------- /docs/source/digits_vae.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/digits_vae.md -------------------------------------------------------------------------------- /docs/source/events.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/events.md -------------------------------------------------------------------------------- /docs/source/getting_started.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/getting_started.md -------------------------------------------------------------------------------- /docs/source/index.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/index.html -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/index.rst -------------------------------------------------------------------------------- /docs/source/install.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/install.md -------------------------------------------------------------------------------- /docs/source/more_learning.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/more_learning.md -------------------------------------------------------------------------------- /docs/source/neural_net_basics.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/neural_net_basics.ipynb -------------------------------------------------------------------------------- /docs/source/neural_net_basics.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/neural_net_basics.md -------------------------------------------------------------------------------- /docs/source/news.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/news.md -------------------------------------------------------------------------------- /docs/source/pytorch_users.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/pytorch_users.md -------------------------------------------------------------------------------- /docs/source/tutorials.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/tutorials.md -------------------------------------------------------------------------------- /docs/source/videos.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/docs/source/videos.md -------------------------------------------------------------------------------- /jax_ai_stack/__init__.py: -------------------------------------------------------------------------------- 1 | """jax-ai-stack metapackage.""" 2 | 3 | __version__ = "2025.10.28" 4 | -------------------------------------------------------------------------------- /jax_ai_stack/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /jax_ai_stack/tests/test_chex.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/jax_ai_stack/tests/test_chex.py -------------------------------------------------------------------------------- /jax_ai_stack/tests/test_nnx_with_grain.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/jax_ai_stack/tests/test_nnx_with_grain.py -------------------------------------------------------------------------------- /jax_ai_stack/tests/test_nnx_with_optax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/jax_ai_stack/tests/test_nnx_with_optax.py -------------------------------------------------------------------------------- /jax_ai_stack/tests/test_nnx_with_orbax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/jax_ai_stack/tests/test_nnx_with_orbax.py -------------------------------------------------------------------------------- /jax_ai_stack/tests/test_nnx_with_tfds.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/jax_ai_stack/tests/test_nnx_with_tfds.py -------------------------------------------------------------------------------- /jax_ai_stack/tests/test_tf_with_orbax_export.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/jax_ai_stack/tests/test_tf_with_orbax_export.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jax-ml/jax-ai-stack/HEAD/pyproject.toml --------------------------------------------------------------------------------