├── .github └── workflows │ ├── lint.yml │ ├── nightly.yml │ ├── release.yml │ ├── scorecard.yml │ ├── static.yml │ └── unit-tests.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dev-requirements.txt ├── docs ├── docs │ ├── assets │ │ ├── image-trojan.png │ │ ├── jax_profiler.png │ │ ├── jax_profiler2.png │ │ └── torchax.png │ ├── index.md │ ├── javascripts │ │ └── mathjax.js │ ├── tutorials │ │ ├── distributed_array.py │ │ └── trainingyt.py │ └── user_guide │ │ ├── get-started.md │ │ └── how-it-works.md └── mkdocs.yml ├── examples ├── README.md ├── __init__.py ├── _diffusion.py ├── _grad_of_attention.py ├── basic_training.py ├── basic_training_jax.py ├── eager_mode.py ├── lightning_training.py ├── requirements.txt ├── train_gpt │ └── requirements.txt └── train_llama_torchtitan │ ├── Dockerfile │ ├── README.md │ ├── __init__.py │ ├── helper.py │ ├── splash_attn.py │ └── train_llama.py ├── format.sh ├── pyproject.toml ├── scripts └── update_to_nightly_version.py ├── test-requirements.txt ├── test ├── __init__.py ├── base_test_util.py ├── gemma │ ├── __init__.py │ ├── config.py │ ├── model.py │ ├── test_gemma.py │ └── tokenizer.py ├── test_amp.py ├── test_checkpoint.py ├── test_context.py ├── test_conv.py ├── test_core_aten_ops.py ├── test_embedding.py ├── test_exports.py ├── test_flax.py ├── test_functions.py ├── test_image.py ├── test_interop.py ├── test_jittable_module.py ├── test_libraries.py ├── test_misc.py ├── test_mutations.py ├── test_ops.py ├── test_symbolic_shapes.py ├── test_threading.py ├── test_train.py ├── test_unbounded_dynamism.py ├── test_util.py └── test_view.py ├── test_dist ├── test_mesh_util.py └── test_to_device.py └── torchax ├── CONTRIBUTING.md ├── __init__.py ├── amp.py ├── checkpoint.py ├── config.py ├── decompositions.py ├── device_module.py ├── export.py ├── flax.py ├── interop.py ├── mesh_util.py ├── ops ├── __init__.py ├── jaten.py ├── jax_reimplement.py ├── jc10d.py ├── jimage.py ├── jlibrary.py ├── jtorch.py ├── jtorchvision_nms.py ├── mappings.py ├── op_base.py └── ops_registry.py ├── tensor.py ├── train.py ├── types.py ├── util.py └── view.py /.github/workflows/lint.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.github/workflows/lint.yml -------------------------------------------------------------------------------- /.github/workflows/nightly.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.github/workflows/nightly.yml -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.github/workflows/release.yml -------------------------------------------------------------------------------- /.github/workflows/scorecard.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.github/workflows/scorecard.yml -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.github/workflows/static.yml -------------------------------------------------------------------------------- /.github/workflows/unit-tests.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.github/workflows/unit-tests.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/.gitignore -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/README.md -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/dev-requirements.txt -------------------------------------------------------------------------------- /docs/docs/assets/image-trojan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/assets/image-trojan.png -------------------------------------------------------------------------------- /docs/docs/assets/jax_profiler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/assets/jax_profiler.png -------------------------------------------------------------------------------- /docs/docs/assets/jax_profiler2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/assets/jax_profiler2.png -------------------------------------------------------------------------------- /docs/docs/assets/torchax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/assets/torchax.png -------------------------------------------------------------------------------- /docs/docs/index.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/index.md -------------------------------------------------------------------------------- /docs/docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/javascripts/mathjax.js -------------------------------------------------------------------------------- /docs/docs/tutorials/distributed_array.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/tutorials/distributed_array.py -------------------------------------------------------------------------------- /docs/docs/tutorials/trainingyt.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/tutorials/trainingyt.py -------------------------------------------------------------------------------- /docs/docs/user_guide/get-started.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/user_guide/get-started.md -------------------------------------------------------------------------------- /docs/docs/user_guide/how-it-works.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/docs/user_guide/how-it-works.md -------------------------------------------------------------------------------- /docs/mkdocs.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/docs/mkdocs.yml -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/README.md -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/__init__.py -------------------------------------------------------------------------------- /examples/_diffusion.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/_diffusion.py -------------------------------------------------------------------------------- /examples/_grad_of_attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/_grad_of_attention.py -------------------------------------------------------------------------------- /examples/basic_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/basic_training.py -------------------------------------------------------------------------------- /examples/basic_training_jax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/basic_training_jax.py -------------------------------------------------------------------------------- /examples/eager_mode.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/eager_mode.py -------------------------------------------------------------------------------- /examples/lightning_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/lightning_training.py -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | matplotlib 3 | optax -------------------------------------------------------------------------------- /examples/train_gpt/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_gpt/requirements.txt -------------------------------------------------------------------------------- /examples/train_llama_torchtitan/Dockerfile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_llama_torchtitan/Dockerfile -------------------------------------------------------------------------------- /examples/train_llama_torchtitan/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_llama_torchtitan/README.md -------------------------------------------------------------------------------- /examples/train_llama_torchtitan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_llama_torchtitan/__init__.py -------------------------------------------------------------------------------- /examples/train_llama_torchtitan/helper.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_llama_torchtitan/helper.py -------------------------------------------------------------------------------- /examples/train_llama_torchtitan/splash_attn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_llama_torchtitan/splash_attn.py -------------------------------------------------------------------------------- /examples/train_llama_torchtitan/train_llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/examples/train_llama_torchtitan/train_llama.py -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/format.sh -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/pyproject.toml -------------------------------------------------------------------------------- /scripts/update_to_nightly_version.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/scripts/update_to_nightly_version.py -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test-requirements.txt -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/__init__.py -------------------------------------------------------------------------------- /test/base_test_util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/base_test_util.py -------------------------------------------------------------------------------- /test/gemma/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/gemma/__init__.py -------------------------------------------------------------------------------- /test/gemma/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/gemma/config.py -------------------------------------------------------------------------------- /test/gemma/model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/gemma/model.py -------------------------------------------------------------------------------- /test/gemma/test_gemma.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/gemma/test_gemma.py -------------------------------------------------------------------------------- /test/gemma/tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/gemma/tokenizer.py -------------------------------------------------------------------------------- /test/test_amp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_amp.py -------------------------------------------------------------------------------- /test/test_checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_checkpoint.py -------------------------------------------------------------------------------- /test/test_context.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_context.py -------------------------------------------------------------------------------- /test/test_conv.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_conv.py -------------------------------------------------------------------------------- /test/test_core_aten_ops.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_core_aten_ops.py -------------------------------------------------------------------------------- /test/test_embedding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_embedding.py -------------------------------------------------------------------------------- /test/test_exports.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_exports.py -------------------------------------------------------------------------------- /test/test_flax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_flax.py -------------------------------------------------------------------------------- /test/test_functions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_functions.py -------------------------------------------------------------------------------- /test/test_image.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_image.py -------------------------------------------------------------------------------- /test/test_interop.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_interop.py -------------------------------------------------------------------------------- /test/test_jittable_module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_jittable_module.py -------------------------------------------------------------------------------- /test/test_libraries.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_libraries.py -------------------------------------------------------------------------------- /test/test_misc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_misc.py -------------------------------------------------------------------------------- /test/test_mutations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_mutations.py -------------------------------------------------------------------------------- /test/test_ops.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_ops.py -------------------------------------------------------------------------------- /test/test_symbolic_shapes.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_symbolic_shapes.py -------------------------------------------------------------------------------- /test/test_threading.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_threading.py -------------------------------------------------------------------------------- /test/test_train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_train.py -------------------------------------------------------------------------------- /test/test_unbounded_dynamism.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_unbounded_dynamism.py -------------------------------------------------------------------------------- /test/test_util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_util.py -------------------------------------------------------------------------------- /test/test_view.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test/test_view.py -------------------------------------------------------------------------------- /test_dist/test_mesh_util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test_dist/test_mesh_util.py -------------------------------------------------------------------------------- /test_dist/test_to_device.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/test_dist/test_to_device.py -------------------------------------------------------------------------------- /torchax/CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/CONTRIBUTING.md -------------------------------------------------------------------------------- /torchax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/__init__.py -------------------------------------------------------------------------------- /torchax/amp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/amp.py -------------------------------------------------------------------------------- /torchax/checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/checkpoint.py -------------------------------------------------------------------------------- /torchax/config.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/config.py -------------------------------------------------------------------------------- /torchax/decompositions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/decompositions.py -------------------------------------------------------------------------------- /torchax/device_module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/device_module.py -------------------------------------------------------------------------------- /torchax/export.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/export.py -------------------------------------------------------------------------------- /torchax/flax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/flax.py -------------------------------------------------------------------------------- /torchax/interop.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/interop.py -------------------------------------------------------------------------------- /torchax/mesh_util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/mesh_util.py -------------------------------------------------------------------------------- /torchax/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/__init__.py -------------------------------------------------------------------------------- /torchax/ops/jaten.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jaten.py -------------------------------------------------------------------------------- /torchax/ops/jax_reimplement.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jax_reimplement.py -------------------------------------------------------------------------------- /torchax/ops/jc10d.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jc10d.py -------------------------------------------------------------------------------- /torchax/ops/jimage.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jimage.py -------------------------------------------------------------------------------- /torchax/ops/jlibrary.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jlibrary.py -------------------------------------------------------------------------------- /torchax/ops/jtorch.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jtorch.py -------------------------------------------------------------------------------- /torchax/ops/jtorchvision_nms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/jtorchvision_nms.py -------------------------------------------------------------------------------- /torchax/ops/mappings.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/mappings.py -------------------------------------------------------------------------------- /torchax/ops/op_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/op_base.py -------------------------------------------------------------------------------- /torchax/ops/ops_registry.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/ops/ops_registry.py -------------------------------------------------------------------------------- /torchax/tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/tensor.py -------------------------------------------------------------------------------- /torchax/train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/train.py -------------------------------------------------------------------------------- /torchax/types.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/types.py -------------------------------------------------------------------------------- /torchax/util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/util.py -------------------------------------------------------------------------------- /torchax/view.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/torchax/HEAD/torchax/view.py --------------------------------------------------------------------------------