├── LICENSE
├── README.md
├── code-exercises
│   ├── 01 - JAX AI Stack.ipynb
│   ├── 02 - NumPy and JAX NumPy.ipynb
│   ├── 03 - Intro to NNX for PyTorch Users.ipynb
│   ├── 04 - MNIST example.ipynb
│   ├── 05 - Chex_ JAX & Flax NNX Reliability.ipynb
│   ├── 06 - Debugging JAX and Flax NNX.ipynb
│   ├── 07 - Grain for data loading.ipynb
│   ├── 08 - Orbax for checkpointing.ipynb
│   ├── 09 - Serving JAX with vLLM_SGLang.ipynb
│   ├── 10 - Sharding & Parallelism.ipynb
│   └── 11 - Optax optimizers.ipynb
├── quick-references
│   ├── Chex Quick Reference.pdf
│   ├── Flax NNX & JAX Quick Reference.pdf
│   ├── Grain Quick Reference for JAX_Flax NNX.pdf
│   ├── JAX & Flax NNX Debugging Quick Reference.pdf
│   ├── JAX AI Stack Quick Reference.pdf
│   ├── JAX NumPy Quick Reference.pdf
│   ├── JAX Serving with vLLM & SGLang.pdf
│   ├── JAX Sharding & Parallelism with Flax NNX.pdf
│   ├── Optax & Flax NNX Quick Reference.pdf
│   └── Orbax & Flax NNX Checkpointing_ Quick Reference.pdf
└── slide-decks
    ├── 0 - Why JAX and Flax NNX.pdf
    ├── 0 - Why JAX and Flax NNX.pptx
    ├── 1 - JAX AI Stack.pdf
    ├── 1 - JAX AI Stack.pptx
    ├── 10 - Sharding & Parallelism.pdf
    ├── 10 - Sharding & Parallelism.pptx
    ├── 11 - Optax optimizers.pdf
    ├── 11 - Optax optimizers.pptx
    ├── 12 - Conclusion.pdf
    ├── 12 - Conclusion.pptx
    ├── 2 - NumPy and JAX NumPy.pdf
    ├── 2 - NumPy and JAX NumPy.pptx
    ├── 3 - Intro to NNX for PyTorch users.pdf
    ├── 3 - Intro to NNX for PyTorch users.pptx
    ├── 4 - MNIST example.pdf
    ├── 4 - MNIST example.pptx
    ├── 5 - Chex_ JAX & Flax NNX Reliability.pdf
    ├── 5 - Chex_ JAX & Flax NNX Reliability.pptx
    ├── 6 - Debugging JAX and Flax NNX.pdf
    ├── 6 - Debugging JAX and Flax NNX.pptx
    ├── 7 - Grain for data loading.pdf
    ├── 7 - Grain for data loading.pptx
    ├── 8 - Orbax for checkpointing.pdf
    ├── 8 - Orbax for checkpointing.pptx
    ├── 9 - Serving JAX with vLLM_SGLang.pdf
    └── 9 - Serving JAX with vLLM_SGLang.pptx

/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/README.md
--------------------------------------------------------------------------------
/code-exercises/01 - JAX AI Stack.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/01 - JAX AI Stack.ipynb
--------------------------------------------------------------------------------
/code-exercises/02 - NumPy and JAX NumPy.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/02 - NumPy and JAX NumPy.ipynb
--------------------------------------------------------------------------------
/code-exercises/03 - Intro to NNX for PyTorch Users.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/03 - Intro to NNX for PyTorch Users.ipynb
--------------------------------------------------------------------------------
/code-exercises/04 - MNIST example.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/04 - MNIST example.ipynb
--------------------------------------------------------------------------------
/code-exercises/05 - Chex_ JAX & Flax NNX Reliability.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/05 - Chex_ JAX & Flax NNX Reliability.ipynb
--------------------------------------------------------------------------------
/code-exercises/06 - Debugging JAX and Flax NNX.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/06 - Debugging JAX and Flax NNX.ipynb
--------------------------------------------------------------------------------
/code-exercises/07 - Grain for data loading.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/07 - Grain for data loading.ipynb
--------------------------------------------------------------------------------
/code-exercises/08 - Orbax for checkpointing.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/08 - Orbax for checkpointing.ipynb
--------------------------------------------------------------------------------
/code-exercises/09 - Serving JAX with vLLM_SGLang.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/09 - Serving JAX with vLLM_SGLang.ipynb
--------------------------------------------------------------------------------
/code-exercises/10 - Sharding & Parallelism.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/10 - Sharding & Parallelism.ipynb
--------------------------------------------------------------------------------
/code-exercises/11 - Optax optimizers.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/code-exercises/11 - Optax optimizers.ipynb
--------------------------------------------------------------------------------
/quick-references/Chex Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/Chex Quick Reference.pdf
--------------------------------------------------------------------------------
/quick-references/Flax NNX & JAX Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/Flax NNX & JAX Quick Reference.pdf
--------------------------------------------------------------------------------
/quick-references/Grain Quick Reference for JAX_Flax NNX.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/Grain Quick Reference for JAX_Flax NNX.pdf
--------------------------------------------------------------------------------
/quick-references/JAX & Flax NNX Debugging Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/JAX & Flax NNX Debugging Quick Reference.pdf
--------------------------------------------------------------------------------
/quick-references/JAX AI Stack Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/JAX AI Stack Quick Reference.pdf
--------------------------------------------------------------------------------
/quick-references/JAX NumPy Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/JAX NumPy Quick Reference.pdf
--------------------------------------------------------------------------------
/quick-references/JAX Serving with vLLM & SGLang.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/JAX Serving with vLLM & SGLang.pdf
--------------------------------------------------------------------------------
/quick-references/JAX Sharding & Parallelism with Flax NNX.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/JAX Sharding & Parallelism with Flax NNX.pdf
--------------------------------------------------------------------------------
/quick-references/Optax & Flax NNX Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/Optax & Flax NNX Quick Reference.pdf
--------------------------------------------------------------------------------
/quick-references/Orbax & Flax NNX Checkpointing_ Quick Reference.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/quick-references/Orbax & Flax NNX Checkpointing_ Quick Reference.pdf
--------------------------------------------------------------------------------
/slide-decks/0 - Why JAX and Flax NNX.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/0 - Why JAX and Flax NNX.pdf
--------------------------------------------------------------------------------
/slide-decks/0 - Why JAX and Flax NNX.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/0 - Why JAX and Flax NNX.pptx
--------------------------------------------------------------------------------
/slide-decks/1 - JAX AI Stack.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/1 - JAX AI Stack.pdf
--------------------------------------------------------------------------------
/slide-decks/1 - JAX AI Stack.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/1 - JAX AI Stack.pptx
--------------------------------------------------------------------------------
/slide-decks/10 - Sharding & Parallelism.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/10 - Sharding & Parallelism.pdf
--------------------------------------------------------------------------------
/slide-decks/10 - Sharding & Parallelism.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/10 - Sharding & Parallelism.pptx
--------------------------------------------------------------------------------
/slide-decks/11 - Optax optimizers.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/11 - Optax optimizers.pdf
--------------------------------------------------------------------------------
/slide-decks/11 - Optax optimizers.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/11 - Optax optimizers.pptx
--------------------------------------------------------------------------------
/slide-decks/12 - Conclusion.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/12 - Conclusion.pdf
--------------------------------------------------------------------------------
/slide-decks/12 - Conclusion.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/12 - Conclusion.pptx
--------------------------------------------------------------------------------
/slide-decks/2 - NumPy and JAX NumPy.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/2 - NumPy and JAX NumPy.pdf
--------------------------------------------------------------------------------
/slide-decks/2 - NumPy and JAX NumPy.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/2 - NumPy and JAX NumPy.pptx
--------------------------------------------------------------------------------
/slide-decks/3 - Intro to NNX for PyTorch users.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/3 - Intro to NNX for PyTorch users.pdf
--------------------------------------------------------------------------------
/slide-decks/3 - Intro to NNX for PyTorch users.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/3 - Intro to NNX for PyTorch users.pptx
--------------------------------------------------------------------------------
/slide-decks/4 - MNIST example.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/4 - MNIST example.pdf
--------------------------------------------------------------------------------
/slide-decks/4 - MNIST example.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/4 - MNIST example.pptx
--------------------------------------------------------------------------------
/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pdf
--------------------------------------------------------------------------------
/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pptx
--------------------------------------------------------------------------------
/slide-decks/6 - Debugging JAX and Flax NNX.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/6 - Debugging JAX and Flax NNX.pdf
--------------------------------------------------------------------------------
/slide-decks/6 - Debugging JAX and Flax NNX.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/6 - Debugging JAX and Flax NNX.pptx
--------------------------------------------------------------------------------
/slide-decks/7 - Grain for data loading.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/7 - Grain for data loading.pdf
--------------------------------------------------------------------------------
/slide-decks/7 - Grain for data loading.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/7 - Grain for data loading.pptx
--------------------------------------------------------------------------------
/slide-decks/8 - Orbax for checkpointing.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/8 - Orbax for checkpointing.pdf
--------------------------------------------------------------------------------
/slide-decks/8 - Orbax for checkpointing.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/8 - Orbax for checkpointing.pptx
--------------------------------------------------------------------------------
/slide-decks/9 - Serving JAX with vLLM_SGLang.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/9 - Serving JAX with vLLM_SGLang.pdf
--------------------------------------------------------------------------------
/slide-decks/9 - Serving JAX with vLLM_SGLang.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rcrowe-google/Learning-JAX/HEAD/slide-decks/9 - Serving JAX with vLLM_SGLang.pptx
--------------------------------------------------------------------------------