├── .gitattributes ├── .github └── workflows │ ├── publish.yaml │ └── tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── JAX FP8 matmul tutorial.ipynb ├── PyTorch FP8 matmul tutorial.ipynb ├── img │ └── fp-formats.webp └── operators.md ├── examples ├── cifar10 │ ├── cifar10_training.py │ ├── cifar10_training_with_optax.py │ └── dataset_cifar10.py ├── mnist │ ├── datasets.py │ ├── flax │ │ ├── README.md │ │ ├── configs │ │ │ ├── __init__.py │ │ │ └── default.py │ │ ├── main.py │ │ ├── requirements.txt │ │ └── train.py │ ├── mnist_classifier_from_scratch.py │ ├── mnist_classifier_from_scratch_fp8.py │ └── mnist_classifier_mlp_flax.py └── scalify-quickstart.ipynb ├── jax_scalify ├── __init__.py ├── core │ ├── __init__.py │ ├── datatype.py │ ├── debug.py │ ├── interpreters.py │ ├── pow2.py │ ├── typing.py │ └── utils.py ├── lax │ ├── __init__.py │ ├── base_scaling_primitives.py │ ├── scaled_ops_common.py │ └── scaled_ops_l2.py ├── ops │ ├── __init__.py │ ├── cast.py │ ├── debug.py │ ├── rescaling.py │ └── utils.py ├── quantization │ ├── __init__.py │ └── scale.py ├── tree │ ├── __init__.py │ └── tree_util.py └── utils │ ├── __init__.py │ └── hlo.py ├── pyproject.toml ├── setup.cfg ├── test-requirements.txt └── tests ├── core ├── test_datatype.py ├── test_interpreter.py ├── test_pow2.py └── test_utils.py ├── lax ├── test_base_scaling_primitives.py ├── test_numpy_integration.py ├── test_scaled_ops_common.py ├── test_scaled_ops_l2.py └── test_scipy_integration.py ├── ops ├── test_cast.py ├── test_debug.py └── test_rescaling.py ├── quantization └── test_scale.py ├── tree └── test_tree_util.py └── utils └── test_hlo.py /.gitattributes: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/.github/workflows/publish.yaml -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/.github/workflows/tests.yaml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/.gitignore -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/.pre-commit-config.yaml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/README.md -------------------------------------------------------------------------------- /docs/JAX FP8 matmul tutorial.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/docs/JAX FP8 matmul tutorial.ipynb -------------------------------------------------------------------------------- /docs/PyTorch FP8 matmul tutorial.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/docs/PyTorch FP8 matmul tutorial.ipynb -------------------------------------------------------------------------------- /docs/img/fp-formats.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/docs/img/fp-formats.webp -------------------------------------------------------------------------------- /docs/operators.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/docs/operators.md -------------------------------------------------------------------------------- /examples/cifar10/cifar10_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/cifar10/cifar10_training.py -------------------------------------------------------------------------------- /examples/cifar10/cifar10_training_with_optax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/cifar10/cifar10_training_with_optax.py -------------------------------------------------------------------------------- /examples/cifar10/dataset_cifar10.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/cifar10/dataset_cifar10.py -------------------------------------------------------------------------------- /examples/mnist/datasets.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/datasets.py -------------------------------------------------------------------------------- /examples/mnist/flax/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/flax/README.md -------------------------------------------------------------------------------- /examples/mnist/flax/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/mnist/flax/configs/default.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/flax/configs/default.py -------------------------------------------------------------------------------- /examples/mnist/flax/main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/flax/main.py -------------------------------------------------------------------------------- /examples/mnist/flax/requirements.txt: -------------------------------------------------------------------------------- 1 | clu 2 | flax 3 | ml-collections 4 | optax 5 | tensorflow-datasets 6 | -------------------------------------------------------------------------------- /examples/mnist/flax/train.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/flax/train.py -------------------------------------------------------------------------------- /examples/mnist/mnist_classifier_from_scratch.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/mnist_classifier_from_scratch.py -------------------------------------------------------------------------------- /examples/mnist/mnist_classifier_from_scratch_fp8.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/mnist_classifier_from_scratch_fp8.py -------------------------------------------------------------------------------- /examples/mnist/mnist_classifier_mlp_flax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/mnist/mnist_classifier_mlp_flax.py -------------------------------------------------------------------------------- /examples/scalify-quickstart.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/examples/scalify-quickstart.ipynb -------------------------------------------------------------------------------- /jax_scalify/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/__init__.py -------------------------------------------------------------------------------- /jax_scalify/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/__init__.py -------------------------------------------------------------------------------- /jax_scalify/core/datatype.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/datatype.py -------------------------------------------------------------------------------- /jax_scalify/core/debug.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/debug.py -------------------------------------------------------------------------------- /jax_scalify/core/interpreters.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/interpreters.py -------------------------------------------------------------------------------- /jax_scalify/core/pow2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/pow2.py -------------------------------------------------------------------------------- /jax_scalify/core/typing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/typing.py -------------------------------------------------------------------------------- /jax_scalify/core/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/core/utils.py -------------------------------------------------------------------------------- /jax_scalify/lax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/lax/__init__.py -------------------------------------------------------------------------------- /jax_scalify/lax/base_scaling_primitives.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/lax/base_scaling_primitives.py -------------------------------------------------------------------------------- /jax_scalify/lax/scaled_ops_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/lax/scaled_ops_common.py -------------------------------------------------------------------------------- /jax_scalify/lax/scaled_ops_l2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/lax/scaled_ops_l2.py -------------------------------------------------------------------------------- /jax_scalify/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/ops/__init__.py -------------------------------------------------------------------------------- /jax_scalify/ops/cast.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/ops/cast.py -------------------------------------------------------------------------------- /jax_scalify/ops/debug.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/ops/debug.py -------------------------------------------------------------------------------- /jax_scalify/ops/rescaling.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/ops/rescaling.py -------------------------------------------------------------------------------- /jax_scalify/ops/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/ops/utils.py -------------------------------------------------------------------------------- /jax_scalify/quantization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/quantization/__init__.py -------------------------------------------------------------------------------- /jax_scalify/quantization/scale.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/quantization/scale.py -------------------------------------------------------------------------------- /jax_scalify/tree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/tree/__init__.py -------------------------------------------------------------------------------- /jax_scalify/tree/tree_util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/tree/tree_util.py -------------------------------------------------------------------------------- /jax_scalify/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/utils/__init__.py -------------------------------------------------------------------------------- /jax_scalify/utils/hlo.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/jax_scalify/utils/hlo.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/pyproject.toml -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/setup.cfg -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | -------------------------------------------------------------------------------- /tests/core/test_datatype.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/core/test_datatype.py -------------------------------------------------------------------------------- /tests/core/test_interpreter.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/core/test_interpreter.py -------------------------------------------------------------------------------- /tests/core/test_pow2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/core/test_pow2.py -------------------------------------------------------------------------------- /tests/core/test_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/core/test_utils.py -------------------------------------------------------------------------------- /tests/lax/test_base_scaling_primitives.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/lax/test_base_scaling_primitives.py -------------------------------------------------------------------------------- /tests/lax/test_numpy_integration.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/lax/test_numpy_integration.py -------------------------------------------------------------------------------- /tests/lax/test_scaled_ops_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/lax/test_scaled_ops_common.py -------------------------------------------------------------------------------- /tests/lax/test_scaled_ops_l2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/lax/test_scaled_ops_l2.py -------------------------------------------------------------------------------- /tests/lax/test_scipy_integration.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/lax/test_scipy_integration.py -------------------------------------------------------------------------------- /tests/ops/test_cast.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/ops/test_cast.py -------------------------------------------------------------------------------- /tests/ops/test_debug.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/ops/test_debug.py -------------------------------------------------------------------------------- /tests/ops/test_rescaling.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/ops/test_rescaling.py -------------------------------------------------------------------------------- /tests/quantization/test_scale.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/quantization/test_scale.py -------------------------------------------------------------------------------- /tests/tree/test_tree_util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/tree/test_tree_util.py -------------------------------------------------------------------------------- /tests/utils/test_hlo.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/HEAD/tests/utils/test_hlo.py --------------------------------------------------------------------------------