├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── RELEASE.md
├── WORKSPACE
├── praxis
│ ├── AUTHORS
│ ├── BUILD
│ ├── __init__.py
│ ├── asserts.py
│ ├── asserts_test.py
│ ├── base_hyperparams.py
│ ├── base_hyperparams_test.py
│ ├── base_input.py
│ ├── base_input_test.py
│ ├── base_layer.py
│ ├── base_layer_test.py
│ ├── base_model.py
│ ├── beam_search.py
│ ├── beam_search_test.py
│ ├── build-visibility.bzl
│ ├── contrib
│ │ └── gpu
│ │ │ └── scripts_gpu
│ │ │ │ ├── lora_layers.py
│ │ │ │ └── models.py
│ ├── decoder_hparams.py
│ ├── decoder_utils.py
│ ├── decoder_utils_test.py
│ ├── fiddle_tags.py
│ ├── flat_beam_search.py
│ ├── flat_beam_search_test.py
│ ├── flax_utils.py
│ ├── gshard_utils.py
│ ├── layers
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── activations_test.py
│ │ ├── adapters.py
│ │ ├── adapters_test.py
│ │ ├── attentions.py
│ │ ├── attentions_test.py
│ │ ├── augmentations.py
│ │ ├── augmentations_test.py
│ │ ├── base_ops.py
│ │ ├── bregman.py
│ │ ├── bregman_test.py
│ │ ├── chain
│ │ │ ├── BUILD
│ │ │ ├── __init__.py
│ │ │ ├── chain.py
│ │ │ ├── chain_extensions.py
│ │ │ └── chain_test.py
│ │ ├── checkpoint_policy.py
│ │ ├── chunk.py
│ │ ├── chunk_test.py
│ │ ├── conformers.py
│ │ ├── conformers_test.py
│ │ ├── convolutions.py
│ │ ├── convolutions_test.py
│ │ ├── ctc_objectives.py
│ │ ├── ctc_objectives_test.py
│ │ ├── einsum.py
│ │ ├── einsum_test.py
│ │ ├── embedding_softmax.py
│ │ ├── embedding_softmax_test.py
│ │ ├── flax_adapter.py
│ │ ├── flax_adapter_test.py
│ │ ├── frnn.py
│ │ ├── frnn_test.py
│ │ ├── glam.py
│ │ ├── gpu_fast_attention.py
│ │ ├── grok.py
│ │ ├── grouped_query_attention.py
│ │ ├── grouped_query_attention_test.py
│ │ ├── injection
│ │ │ ├── BUILD
│ │ │ ├── __init__.py
│ │ │ ├── fp8_nvidia_gpu.py
│ │ │ └── fp8_nvidia_gpu_test.py
│ │ ├── linears.py
│ │ ├── linears_test.py
│ │ ├── losses.py
│ │ ├── mobilenet.py
│ │ ├── models.py
│ │ ├── models_test.py
│ │ ├── multi_query_attention.py
│ │ ├── multi_query_attention_test.py
│ │ ├── ngrammer.py
│ │ ├── ngrammer_test.py
│ │ ├── normalizations.py
│ │ ├── normalizations_test.py
│ │ ├── pipeline.py
│ │ ├── pipeline_test.py
│ │ ├── poolings.py
│ │ ├── poolings_test.py
│ │ ├── quantization
│ │ │ ├── BUILD
│ │ │ ├── __init__.py
│ │ │ ├── attentions.py
│ │ │ ├── attentions_test.py
│ │ │ ├── automl_select.py
│ │ │ ├── automl_select_test.py
│ │ │ ├── conformers.py
│ │ │ ├── conformers_test.py
│ │ │ ├── convolutions.py
│ │ │ ├── convolutions_test.py
│ │ │ ├── einsum.py
│ │ │ ├── einsum_test.py
│ │ │ ├── embedding_softmax.py
│ │ │ ├── embedding_softmax_test.py
│ │ │ ├── linears.py
│ │ │ ├── linears_test.py
│ │ │ ├── multi_query_attention.py
│ │ │ ├── multi_query_attention_test.py
│ │ │ ├── ngrammer.py
│ │ │ ├── ngrammer_test.py
│ │ │ ├── operations.py
│ │ │ ├── operations_test.py
│ │ │ ├── optimization.py
│ │ │ ├── optimization_test.py
│ │ │ ├── overflow_check.py
│ │ │ ├── overflow_check_test.py
│ │ │ ├── quantization_hparams.py
│ │ │ ├── quantization_test.py
│ │ │ ├── quantize.py
│ │ │ ├── quantize_test.py
│ │ │ ├── quantizer.py
│ │ │ ├── quantizer_test.py
│ │ │ ├── searchable.py
│ │ │ ├── searchable_test.py
│ │ │ ├── sparsity
│ │ │ │ ├── BUILD
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attentions_test.py
│ │ │ │ ├── linears_test.py
│ │ │ │ ├── sparsifier.py
│ │ │ │ ├── sparsifier_test.py
│ │ │ │ ├── sparsity.py
│ │ │ │ ├── sparsity_hparams.py
│ │ │ │ ├── sparsity_modes.py
│ │ │ │ └── sparsity_test.py
│ │ │ ├── tests
│ │ │ │ ├── attention_projection_aqt_test.py
│ │ │ │ ├── attention_projection_fq_test.py
│ │ │ │ ├── attention_projection_ptq_test.py
│ │ │ │ ├── combined_qkv_projection_aqt_test.py
│ │ │ │ ├── combined_qkv_projection_fq_test.py
│ │ │ │ ├── combined_qkv_projection_ptq_test.py
│ │ │ │ ├── dotproduct_attention_aqt_test.py
│ │ │ │ ├── linear_ptq_test.py
│ │ │ │ ├── one_headed_attention_projection_aqt_test.py
│ │ │ │ └── test_util.py
│ │ │ ├── utils.py
│ │ │ └── utils_test.py
│ │ ├── quantizer.py
│ │ ├── quantizer_objectives.py
│ │ ├── quantizer_objectives_test.py
│ │ ├── quantizer_test.py
│ │ ├── repeats.py
│ │ ├── repeats_test.py
│ │ ├── resnets.py
│ │ ├── rnn_cell.py
│ │ ├── rnn_cell_test.py
│ │ ├── searchable.py
│ │ ├── searchable_test.py
│ │ ├── sedd
│ │ │ └── sedd.ipynb
│ │ ├── sequential.py
│ │ ├── sequential_test.py
│ │ ├── sharding.py
│ │ ├── shared_layers_test.py
│ │ ├── spectrum_augmenter.py
│ │ ├── spectrum_augmenter_test.py
│ │ ├── ssm.py
│ │ ├── ssm_test.py
│ │ ├── ssm_transformers.py
│ │ ├── ssm_transformers_test.py
│ │ ├── stats.py
│ │ ├── stats_test.py
│ │ ├── stochastics.py
│ │ ├── stochastics_test.py
│ │ ├── test_layers.py
│ │ ├── transformer_models.py
│ │ ├── transformer_models_encoder_decoder_test.py
│ │ ├── transformer_models_test.py
│ │ ├── transformers.py
│ │ ├── transformers_test.py
│ │ ├── vanillanets.py
│ │ ├── vanillanets_test.py
│ │ ├── video
│ │ │ ├── BUILD
│ │ │ ├── __init__.py
│ │ │ ├── enc_dec_3dcnn.py
│ │ │ ├── enc_dec_3dcnn_test.py
│ │ │ ├── losses.py
│ │ │ ├── losses_test.py
│ │ │ ├── quantizer.py
│ │ │ ├── quantizer_test.py
│ │ │ ├── vqvae.py
│ │ │ └── vqvae_test.py
│ │ ├── vits.py
│ │ └── vits_test.py
│ ├── lazy_loader.py
│ ├── lingvo_lib.py
│ ├── metric_utils.py
│ ├── optimizer_prefix_vectorization.py
│ ├── optimizer_prefix_vectorization_test.py
│ ├── optimizers.py
│ ├── optimizers_test.py
│ ├── pax_fiddle.py
│ ├── pax_fiddle_test.py
│ ├── pip_package
│ │ ├── build_pip_pkg.sh
│ │ ├── cloudbuild-postsubmit.yaml
│ │ ├── cloudbuild-presubmit.yaml
│ │ ├── cloudbuild-release.yaml
│ │ ├── postsubmit.Dockerfile
│ │ ├── prepare_release.sh
│ │ ├── presubmit.Dockerfile
│ │ ├── release.Dockerfile
│ │ └── requirements.txt
│ ├── praxis.bzl
│ ├── py_utils.py
│ ├── py_utils_test.py
│ ├── pytypes.py
│ ├── pytypes_test.py
│ ├── sample_decode.py
│ ├── sample_decode_test.py
│ ├── schedules.py
│ ├── schedules_test.py
│ ├── test_utils.py
│ ├── token_samplers.py
│ ├── token_samplers_test.py
│ ├── trees.py
│ └── trees_test.py
├── requirements.in
└── setup.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Praxis
2 |
3 | What is Praxis?
4 | Praxis is the layer library for [Pax](https://github.com/google/paxml/). While Praxis is optimized for ML at scale, it also aims to be usable by other JAX-based ML projects.
5 |
6 |
7 | Some examples of layers to be folded into Praxis are in the [praxis/layers/](https://github.com/google/praxis/tree/main/praxis/layers) directory.
8 |
9 |
10 | Copyright 2022 Google LLC
11 |
12 | Licensed under the Apache License, Version 2.0 (the "License");
13 | you may not use this file except in compliance with the License.
14 | You may obtain a copy of the License at
15 |
16 | https://www.apache.org/licenses/LICENSE-2.0
17 |
18 | Unless required by applicable law or agreed to in writing, software
19 | distributed under the License is distributed on an "AS IS" BASIS,
20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | See the License for the specific language governing permissions and
22 | limitations under the License.
23 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # Version: 1.4.0
2 | ## Major Features and Improvements
3 | ## Breaking changes
4 | ## Deprecations
5 | ## Note
6 | * Version: 1.4.0
7 | * Build Date: 20240408
8 | * Praxis commit: 794623a65934cac98aaa962afcbf9d40a51c02f8
9 | # Version: 1.3.1
10 | ## Major Features and Improvements
11 | ## Breaking changes
12 | ## Deprecations
13 | ## Note
14 | * Version: 1.3.1
15 | * Build Date: 20240220
16 | * Praxis commit: a935384da88ec739d48696c2d0a55657a183de00
17 | # Version: 1.3.0
18 | ## Major Features and Improvements
19 | ## Breaking changes
20 | ## Deprecations
21 | ## Note
22 | * Version: 1.3.0
23 | * Build Date: 20240216
24 | * Praxis commit: ca8ad2ba8be5b77092a8e5f0a0774a39cc3ad766
25 | # Version: 1.2.0
26 | ## Major Features and Improvements
27 | ## Breaking changes
28 | ## Deprecations
29 | ## Note
30 | * Version: 1.2.0
31 | * Build Date: 20231016
32 | * Praxis commit: 7bd63412bf86a68e09fcd9455f76a4909d19377e
33 | # Version: 1.1.0
34 | ## Major Features and Improvements
35 | * Move to Python 3.10 as the minimum Python requirement (previously Python 3.8).
36 | * Add various quantization support.
37 | * Add support for fine-grained weight sparsity in praxis layers.
38 | * Add support for Direct Preference Optimization (DPO).
39 | ## Note
40 | * Version: 1.1.0
41 | * Build Date: 20230824
42 | * Praxis commit: 2a7d0407871502a1d79dcd0e01411e73f1d15d36
43 | # Version: 1.0.0
44 | ## Major Features and Improvements
45 | * **Fiddle** - Praxis layers and BaseParameterizable are now configured with [Fiddle](https://github.com/google/fiddle), a Python-first configuration library. Fiddle reduces boilerplate, and adds productivity features including history tracking, graphviz visualization, support for aliasing objects, and more.
46 | * **CLI Experiment and Data Injectability** - Enable Pax users to select which experiments to run without the need to recompile for each experiment. Using a CLI interface based on Fiddle, users can override subsets of the experiment’s canonical dataset.
47 | * **CLU Metrics** - Praxis has adopted CLU metrics as its standard metric interface. This allows other Jax/Flax codebases that have CLU metrics to use them in Praxis.
48 | * **Flax Interoperability** - Praxis now supports shape inference, __call__ for forward propagation, and has adopted Linen’s AxisMetadata for its mid-level sharding APIs. These changes improve interoperability with other Flax-based libraries such as T5X.
49 | ## Note
50 | * Version: 1.0.0
51 | * Build Date: 20230329
52 | * Praxis commit: 621c2ca7bfcd0e21ea118a3d8e40e29b48313c0c
53 | # Version: 0.4.0
54 | ## Note
55 | * Version: 0.4.0
56 | * Build Date: 20230329
57 | * Praxis commit: 621c2ca7bfcd0e21ea118a3d8e40e29b48313c0c
58 | # Version: 0.3.0
59 | ## Major Features and Improvements
60 | * Fiddle migration
61 | * Improve numerical stability when using bfloat16
62 | * Improve and add new functionalities to decoding algorithms
63 | * Improve quantization support and add quantization aware training
64 | * Improve streaming support
65 | * Move learners / sgf and train_states modules to paxml
66 | * Misc renaming / API updates for consistency
67 | ## Note
68 | * Version: 0.3.0
69 | * Build Date: 20230201
70 | * Praxis commit: 9e1d13d888ac18a567e249ddb41e6b1bd1fe505a
71 | # Version: 0.2.1
72 | ## Note
73 | * Version: 0.2.1
74 | * Build Date: 20221121
75 | * Praxis commit: f7e98026c1c5ecbc6e4aff175621d443fa37fcf2
76 | # Version: 0.2.0
77 | ## Major Features and Improvements
78 | * Preparatory work for Fiddle integration
79 | * Support for Flax shape inference
80 | * Support for Jax Array
81 | * Optimizer additions and improvements:
82 | - HeroLion
83 | - ShardedAdagrad
84 | - ShardedStaticAccumulator optimizer wrapper to do a fixed number of gradient
85 | accumulations
86 | - Shampoo improvements
87 | - Fix for multi-optimizer following the introduction of optax.MaskedNode
88 | - Improve sanitization of NaNs/Infs gradients during training
89 | * Decoding
90 | - Add support for ExtendNSteps
91 | - Add beam search support for sequence models
92 | - Set prefix_lengths by input_indicator for PrefixLM
93 | - Move decode post-processing tensors into host memory
94 | * Summaries
95 | - Add support for verbosity level
96 | - Add more knobs to the learner to control summary generation
97 | ## Deprecations
98 | * Disallow hparams override in setup()
99 | * Hparams and layer names must now be distinct
100 | ## Note
101 | * Version: 0.2.0
102 | * Build Date: 20221114
103 | * Praxis commit: 413da1ad8148f27faebca119f8c5deedca66228b
104 | # Version: 0.1.0
105 | ## Major Features and Improvements
106 | ## Breaking changes
107 | ## Deprecations
108 | ## Note
109 | * Version: 0.1.0
110 | * Build Date: 20220702
111 | * Commit:
112 |
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Workspace file for Praxis."""
17 |
18 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
19 |
20 | http_archive(
21 | name = "bazel_skylib",
22 | urls = [
23 | "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz",
24 | "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz",
25 | ],
26 | sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728",
27 | )
28 |
29 | load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") #buildifier: disable=load-on-top
30 |
31 | http_archive(
32 | name = "com_google_protobuf", # v3.19.4 // 2022-01-2
33 | sha256 = "3bd7828aa5af4b13b99c191e8b1e884ebfa9ad371b0ce264605d347f135d2568",
34 | strip_prefix = "protobuf-3.19.4",
35 | url = "https://github.com/protocolbuffers/protobuf/archive/v3.19.4.tar.gz",
36 | )
37 |
38 | bazel_skylib_workspace()
39 |
40 | http_archive(
41 | name = "rules_python",
42 | sha256 = "cdf6b84084aad8f10bf20b46b77cb48d83c319ebe6458a18e9d2cebf57807cdd",
43 | strip_prefix = "rules_python-0.8.1",
44 | url = "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.8.1.tar.gz",
45 | )
46 |
47 | http_archive(
48 | name = "zlib",
49 | build_file = "@com_google_protobuf//:third_party/zlib.BUILD",
50 | sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
51 | strip_prefix = "zlib-1.2.11",
52 | urls = [
53 | "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
54 | "https://zlib.net/zlib-1.2.11.tar.gz",
55 | ],
56 | )
57 |
--------------------------------------------------------------------------------
/praxis/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of Pax's significant contributors.
2 | #
3 | # This does not necessarily list everyone who has contributed code,
4 | # especially since many employees of one corporation may be contributing.
5 | # To see the full list of contributors, see the revision history in
6 | # source control.
7 | Google LLC
8 | NVIDIA Corporation
--------------------------------------------------------------------------------
/praxis/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/praxis/build-visibility.bzl:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Open source BUILD visibility used for Praxis."""
17 |
18 | JAX_VISIBILITY = ["//visibility:public"]
19 |
--------------------------------------------------------------------------------
/praxis/contrib/gpu/scripts_gpu/lora_layers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from jax import numpy as jnp
17 | from praxis import base_layer
18 | from praxis import pax_fiddle
19 | from praxis import pytypes
20 | from praxis.layers.attentions import AttentionProjection, CombinedQKVProjectionLayer
21 | from praxis.layers.linears import Linear
22 |
23 |
24 | WeightInit = base_layer.WeightInit
25 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
26 | template_field = base_layer.template_field
27 | WeightHParams = base_layer.WeightHParams
28 | JTensor = pytypes.JTensor
29 |
30 |
31 | class LoraTheta(base_layer.Theta):
32 |
33 | def __init__(self, module):
34 | self.module = module
35 |
36 | def _lora_initialized(self):
37 | if (
38 | self.module.has_variable("params", "lora_a")
39 | and self.module.has_variable("params", "lora_b")
40 | and "lora_a" in self.module._weight_hparams
41 | and "lora_b" in self.module._weight_hparams
42 | ):
43 | return True
44 | else:
45 | return False
46 |
47 | def _lorafy_var(self, var):
48 | lora_a = super().__getattr__("lora_a")
49 | lora_b = super().__getattr__("lora_b")
50 | new_var = self.module.einsum("...dr,...nr->...dn", lora_a, lora_b)
51 | new_var = jnp.reshape(new_var, var.shape)
52 | new_var += var
53 | return new_var
54 |
55 | def __getattr__(self, k):
56 | var = super().__getattr__(k)
57 | if not self._lora_initialized():
58 | return var
59 |
60 | if k == "w":
61 | return self._lorafy_var(var)
62 |
63 | return var
64 |
65 | def __getitem__(self, k):
66 | var = super().__getattr__(k)
67 | if not self._lora_initialized():
68 | return var
69 |
70 | if k == "w":
71 | return self._lorafy_var(var)
72 |
73 | return var
74 |
75 |
76 | class LoraThetaDescriptor:
77 | """Dot syntax accession descriptor."""
78 |
79 | def __get__(self, obj, objtype=None):
80 | return LoraTheta(obj)
81 |
82 |
83 | class LoraLinear(Linear):
84 | rank: int = 0
85 | lora_init: WeightInit | None = None
86 | theta = LoraThetaDescriptor()
87 |
88 | def setup(self) -> None:
89 | lora_init = self.lora_init if self.lora_init else self.weight_init
90 |
91 | super().setup()
92 | self.create_variable(
93 | "lora_a",
94 | WeightHParams(
95 | shape=[self.input_dims, self.rank],
96 | init=lora_init,
97 | mesh_shape=self.mesh_shape,
98 | tensor_split_dims_mapping=[None, None],
99 | ),
100 | )
101 | self.create_variable(
102 | "lora_b",
103 | WeightHParams(
104 | shape=[self.output_dims, self.rank],
105 | init=WeightInit.Constant(scale=0.0),
106 | mesh_shape=self.mesh_shape,
107 | tensor_split_dims_mapping=[None, None],
108 | ),
109 | )
110 |
111 |
112 | class LoraAttentionProjection(AttentionProjection):
113 | rank: int = 0
114 | lora_init: WeightInit | None = None
115 | theta = LoraThetaDescriptor()
116 |
117 | def setup(self) -> None:
118 | super().setup()
119 | w_weight_params = self._weight_hparams["w"]
120 | lora_init = self.lora_init if self.lora_init else w_weight_params.init
121 |
122 | self.create_variable(
123 | "lora_a",
124 | WeightHParams(
125 | shape=[self.input_dim, self.rank],
126 | init=lora_init,
127 | mesh_shape=self.mesh_shape,
128 | tensor_split_dims_mapping=[
129 | None,
130 | None,
131 | ],
132 | ),
133 | )
134 | self.create_variable(
135 | "lora_b",
136 | WeightHParams(
137 | shape=[self.dim_per_head * self.num_heads, self.rank],
138 | init=WeightInit.Constant(scale=0.0),
139 | mesh_shape=self.mesh_shape,
140 | tensor_split_dims_mapping=[
141 | None,
142 | None,
143 | ],
144 | ),
145 | )
146 |
147 |
148 | class LoraCombinedQKVProjection(CombinedQKVProjectionLayer):
149 | rank: int = 0
150 | lora_init: WeightInit | None = None
151 | theta = LoraThetaDescriptor()
152 |
153 | def setup(self) -> None:
154 | super().setup()
155 | w_weight_params = self._weight_hparams["w"]
156 | lora_init = self.lora_init if self.lora_init else w_weight_params.init
157 |
158 | self.create_variable(
159 | "lora_a",
160 | WeightHParams(
161 | shape=[3, self.input_dim, self.rank],
162 | init=lora_init,
163 | mesh_shape=self.mesh_shape,
164 | tensor_split_dims_mapping=[None, None, None],
165 | ),
166 | )
167 | self.create_variable(
168 | "lora_b",
169 | WeightHParams(
170 | shape=[3, self.dim_per_head * self.num_heads, self.rank],
171 | init=WeightInit.Constant(scale=0.0),
172 | mesh_shape=self.mesh_shape,
173 | tensor_split_dims_mapping=[None, None, None],
174 | ),
175 | )
176 |
--------------------------------------------------------------------------------
/praxis/fiddle_tags.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Declares Fiddle tags for Praxis.
17 |
18 | Fiddle tags are essentially named collections of configuration values. They can
19 | diverge, but can also all be set at once to the same value.
20 | """
21 |
22 | import fiddle as fdl
23 |
24 |
25 | class BaseDType(fdl.Tag):
26 | """Base JAX DType.
27 |
28 | Usually you want ParamsDType or ActivationDType, but we provide a base class
29 | here in case the user wants to quickly set all DTypes.
30 | """
31 |
32 |
33 | class ParamsDType(BaseDType):
34 | """DType for parameters."""
35 |
36 |
37 | class ActivationDType(BaseDType):
38 | """DType for parameters."""
39 |
40 |
41 | class WeightInit(fdl.Tag):
42 | """Weight initializer class.
43 |
44 | Tagged values should generally be base_layer.WeightInit.
45 | """
46 |
47 |
48 | class ParamsSplitDimsMapping(fdl.Tag):
49 | """SplitDimsMapping for parameters."""
50 |
51 |
52 | class ActivationSplitDimsMapping(fdl.Tag):
53 | """SplitDimsMapping for activations."""
54 |
55 |
56 | class DropoutRate(fdl.Tag):
57 | """Tag for dropout rates."""
58 |
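
As an illustration of how these tags are consumed, here is a minimal sketch (not part of the repository): a config field is created as a tagged value, and every value carrying that tag can then be rewritten in one call. `MyDense` is a hypothetical dataclass, and the sketch assumes Fiddle's standard tagging APIs (`Tag.new()` and `fdl.set_tagged()`).

    import dataclasses
    from typing import Any

    import fiddle as fdl
    import jax.numpy as jnp
    from praxis import fiddle_tags

    @dataclasses.dataclass
    class MyDense:  # Hypothetical class, for illustration only.
      dtype: Any = jnp.float32

    # Tag the field, then retarget every ParamsDType-tagged value in one call.
    cfg = fdl.Config(
        MyDense, dtype=fiddle_tags.ParamsDType.new(default=jnp.float32))
    fdl.set_tagged(cfg, tag=fiddle_tags.ParamsDType, value=jnp.bfloat16)
    layer = fdl.build(cfg)  # layer.dtype is now jnp.bfloat16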
--------------------------------------------------------------------------------
/praxis/flat_beam_search_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unit tests for flat_beam_search."""
17 |
18 | from absl.testing import absltest
19 | from jax import numpy as jnp
20 | import numpy as np
21 | from praxis import flat_beam_search
22 | from praxis import test_utils
23 |
24 |
25 | class FlatBeamSearchHelperTest(test_utils.TestCase):
26 |
27 | def test_update_mask_without_time_step(self):
28 | beam_mask = jnp.array(
29 | [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]],
30 | dtype=jnp.float32)
31 | hyp_id = jnp.array([[0, 3, 0, 1]], jnp.float32)
32 | update_beam_mask = flat_beam_search.update_beam_mask(
33 | beam_mask, hyp_id, time_step=None)
34 | self.assertArraysEqual(
35 | update_beam_mask,
36 | np.array([[[1, 0, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]],
37 | dtype=np.float32))
38 |
39 | def test_update_mask_without_step2(self):
40 | beam_mask = jnp.array(
41 | [[
42 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
43 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
44 | [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
45 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
46 | ]],
47 | dtype=jnp.float32,
48 | )
49 | hyp_id = jnp.array([[3, 3, 0, 1]], jnp.float32)
50 | update_beam_mask = flat_beam_search.update_beam_mask(
51 | beam_mask, hyp_id, time_step=None)
52 | self.assertArraysEqual(
53 | update_beam_mask,
54 | np.array([[[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
55 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
56 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
57 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0]]],
58 | dtype=np.float32))
59 |
60 | def test_update_mask_with_step(self):
61 | beam_mask = jnp.array(
62 | [[
63 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
64 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
65 | [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
66 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
67 | ]],
68 | dtype=jnp.float32,
69 | )
70 | hyp_id = jnp.array([[3, 3, 0, 1]], jnp.float32)
71 | update_beam_mask = flat_beam_search.update_beam_mask(
72 | beam_mask, hyp_id, time_step=2)
73 | self.assertArraysEqual(
74 | update_beam_mask,
75 | np.array([[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0],
76 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0],
77 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
78 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1]]],
79 | dtype=np.float32))
80 |
81 | def test_get_final_output_ids(self):
82 | beam_mask = jnp.array(
83 | [[[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0, 1, 0],
84 | [1, 0, 0, 0, 1, 0, 0, 0], [0, 0, 1, 0, 0, 1, 0, 0]]],
85 | dtype=jnp.int32)
86 | output_ids = jnp.array([[[0, 5], [1, 6], [2, 7], [3, 8]]], dtype=np.int32)
87 | final_output_ids = flat_beam_search.get_final_output_ids(
88 | beam_mask, output_ids)
89 |
90 | self.assertArraysEqual(
91 | final_output_ids,
92 | np.array([[[2, 8], [2, 7], [0, 5], [2, 6]]], dtype=np.int32))
93 |
94 | def test_update_topk_scores_with_eos(self):
95 | end_mask = jnp.array([[[0, 1], [0, 1]]], dtype=np.float32)
96 | cur_mask = jnp.array([[[1, 0], [0, 1]]], dtype=np.float32)
97 |
98 | end_scores = jnp.array([[0, 1]], dtype=np.float32)
99 | cur_scores = jnp.array([[2, 3]], dtype=np.float32)
100 |
101 | end_scores_norm = jnp.array([[2, 0]], dtype=np.float32)
102 | cur_scores_norm = jnp.array([[3, 1]], dtype=np.float32)
103 |
104 | (output_mask, output_scores,
105 | output_scores_norm) = flat_beam_search.update_topk_scores_with_eos(
106 | (end_mask, end_scores, end_scores_norm),
107 | (cur_mask, cur_scores, cur_scores_norm))
108 |
109 | self.assertArraysEqual(output_mask,
110 | np.array([[[1, 0], [0, 1]]], dtype=np.float32))
111 | self.assertArraysEqual(output_scores, np.array([[2, 0]], dtype=np.float32))
112 | self.assertArraysEqual(output_scores_norm,
113 | np.array([[3, 2]], dtype=np.float32))
114 |
115 |
116 | if __name__ == '__main__':
117 | absltest.main()
118 |
--------------------------------------------------------------------------------
/praxis/layers/activations_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for activations."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from jax import numpy as jnp
21 | from praxis import base_layer
22 | from praxis import pax_fiddle
23 | from praxis import test_utils
24 | from praxis.layers import activations
25 |
26 |
27 | class ActivationsTest(test_utils.TestCase):
28 |
29 | def _run(self, p, inputs, outputs):
30 | self.assertEqual(outputs, base_layer.instantiate(p.set(name='n'))(inputs))
31 |
32 | def test_Exp(self):
33 | inputs = jnp.array([1.0])
34 | self._run(pax_fiddle.Config(activations.Exp), inputs, jnp.exp(inputs))
35 |
36 | def test_Softplus(self):
37 | inputs = jnp.array([1.0])
38 | self._run(
39 | pax_fiddle.Config(activations.Softplus), inputs, jax.nn.softplus(inputs)
40 | )
41 |
42 | def test_get_subclass_by_name(self):
43 | layer_cls = activations.BaseActivation.get_subclass_by_name('swish')
44 | self.assertIs(layer_cls, activations.Swish)
45 |
46 | layer_cls = activations.BaseActivation.get_subclass_by_name('cubed_relu')
47 | self.assertIs(layer_cls, activations.CubedReLU)
48 |
49 | def test_get_subclass_by_name_ambiguous(self):
50 | class Exp(activations.BaseActivation): # pylint: disable=unused-variable
51 | pass
52 |
53 | with self.assertRaises(KeyError):
54 | activations.BaseActivation.get_subclass_by_name('exp')
55 |
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/praxis/layers/base_ops.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Op wrappers that can be used instead of native JAX ops to allow injection of custom ops.
17 |
18 | This is useful for quantization, sparsity and possibly other techniques.
19 | """
20 |
21 | import jax.numpy as jnp
22 | from praxis import base_layer
23 | from praxis import pytypes
24 |
25 | JTensor = pytypes.JTensor
26 |
27 |
28 | # These wrappers make it possible (with the help of Fiddle) to inject custom
29 | # op implementations into various Pax layers.
30 | # Base op injection is useful if one wants to inject a quantized or sparse op.
31 | # The main benefit is a reduced need for layer forking.
32 | #
33 | # It also allows these custom ops to use state, e.g.
34 | # variables for calibration stats, and
35 | # random numbers, e.g. for stochastic rounding.
36 |
37 |
38 | class EinsumOp(base_layer.BaseLayer):
39 | """Wrapper around jnp.einsum used in standard Pax layers."""
40 |
41 | def __call__(self, equation: str, *args: JTensor) -> JTensor:
42 | return jnp.einsum(equation, *args)
43 |
44 |
45 | class EinsumGatedOp(base_layer.BaseLayer):
46 | """Wrapper around two jnp.einsum for gated FFN."""
47 |
48 | def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
49 | assert len(args) == 3
50 | x, k, k_gated = args
51 | y = jnp.einsum(equation, x, k)
52 | y_gated = jnp.einsum(equation, x, k_gated)
53 | return y, y_gated
54 |
55 |
56 | class ArrayLookup(base_layer.BaseLayer):
57 | """Wrapper around array indexing as used in embedding lookup."""
58 |
59 | def __call__(self, x: JTensor, idx) -> JTensor:
60 | return x[idx]
61 |
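
To make the injection pattern concrete, here is a minimal sketch (not part of this file). It assumes `praxis.layers.linears.Linear` exposes an `einsum_tpl` template field for its `base_ops.EinsumOp`; treat that field name as an assumption of this example.

    import jax.numpy as jnp
    from praxis import pax_fiddle
    from praxis.layers import base_ops
    from praxis.layers import linears

    class RoundingEinsum(base_ops.EinsumOp):
      """Einsum whose output is rounded; a stand-in for a quantized op."""

      def __call__(self, equation: str, *args):
        return jnp.round(jnp.einsum(equation, *args))

    # Inject the custom op via the template field (assumed name: einsum_tpl),
    # without forking the Linear layer itself.
    linear_cfg = pax_fiddle.Config(
        linears.Linear,
        name='ffw',
        input_dims=8,
        output_dims=8,
        einsum_tpl=pax_fiddle.Config(RoundingEinsum),
    )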
--------------------------------------------------------------------------------
/praxis/layers/bregman_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis Bregman PCA layer."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 | from praxis import base_layer
24 | from praxis import pax_fiddle
25 | from praxis import py_utils
26 | from praxis import test_utils
27 | from praxis.layers import bregman
28 |
29 | instantiate = base_layer.instantiate
30 | to_np = test_utils.to_np
31 | NON_TRAINABLE = base_layer.NON_TRAINABLE
32 |
33 |
34 | class BregmanTest(test_utils.TestCase):
35 |
36 | def setUp(self):
37 | super().setUp()
38 | np.random.seed(123456)
39 |
40 | @parameterized.parameters(
41 | ('IDENTITY', 0.99, 0.0, 0.0, True),
42 | ('IDENTITY', 1.0, 0.01, 0.0, True),
43 | ('IDENTITY', 1.0, 0.01, 0.01, False),
44 | ('IDENTITY', 0.99, 0.01, 0.01, True),
45 | ('LEAKY_RELU', 0.99, 0.0, 0.0, False),
46 | ('LEAKY_RELU', 1.0, 0.01, 0.0, True),
47 | ('LEAKY_RELU', 1.0, 0.01, 0.01, False),
48 | ('LEAKY_RELU', 0.99, 0.01, 0.01, True),
49 | ('SOFTMAX', 0.99, 0.0, 0.0, True),
50 | ('SOFTMAX', 1.0, 0.01, 0.0, True),
51 | ('SOFTMAX', 1.0, 0.01, 0.01, False),
52 | ('SOFTMAX', 0.99, 0.01, 0.01, False),
53 | )
54 | def test_bregman_layer(
55 | self,
56 | activation,
57 | mean_beta,
58 | coefficients_lr,
59 | components_lr,
60 | constant_lr_schedule,
61 | ):
62 | """Tests layer construction and the expected outputs."""
63 | activation_type = getattr(bregman.ActivationType, activation)
64 | p = pax_fiddle.Config(
65 | bregman.BregmanPCA,
66 | name='bregman_pca',
67 | num_components=3,
68 | input_dims=[8, 10],
69 | activation_type=activation_type,
70 | negative_slope=0.1,
71 | mean_beta=mean_beta,
72 | coefficients_lr=coefficients_lr,
73 | coefficients_beta=0.9,
74 | coefficients_steps=20,
75 | components_lr=components_lr,
76 | components_beta=0.9,
77 | start_step=0,
78 | end_step=1,
79 | constant_lr_schedule=constant_lr_schedule,
80 | )
81 | layer = instantiate(p)
82 | if activation == 'SOFTMAX':
83 | npy_input = np.random.random([16] + p.input_dims).astype('float32')
84 | npy_input = npy_input / np.sum(npy_input, axis=-1, keepdims=True)
85 | else:
86 | npy_input = np.random.normal(1.0, 0.5, [16] + p.input_dims).astype(
87 | 'float32'
88 | )
89 | inputs = jnp.asarray(npy_input)
90 | with base_layer.JaxContext.new_context():
91 | prng_key = jax.random.PRNGKey(seed=123)
92 | initial_vars = layer.init(prng_key, inputs)
93 |
94 | @jax.jit
95 | def comp(theta, inputs):
96 | with base_layer.JaxContext.new_context():
97 | return layer.apply(theta, inputs, mutable=[NON_TRAINABLE])
98 |
99 | (outputs, coefficients), updated_vars = comp(initial_vars, inputs)
100 | self.assertAllClose(outputs, inputs, atol=1e-5)
101 |
102 | with base_layer.JaxContext.new_context():
103 | layer = layer.bind(initial_vars, mutable=[NON_TRAINABLE])
104 | initial_vars = py_utils.NestedMap.FromNestedDict(
105 | initial_vars['non_trainable']
106 | )
107 | init_components = initial_vars.components
108 | init_mean = initial_vars.mean
109 | mean = updated_vars[NON_TRAINABLE]['mean']
110 | components = updated_vars[NON_TRAINABLE]['components']
111 | init_loss = layer.bregman_loss_fn(
112 | jnp.zeros_like(coefficients), init_components, init_mean, inputs
113 | )
114 | final_loss = layer.bregman_loss_fn(coefficients, components, mean, inputs)
115 | self.assertLess(final_loss, init_loss)
116 |
117 | representations = layer.reconstruct(coefficients)
118 | self.assertEqual(representations.shape, inputs.shape)
119 |
120 | def test_pca_convergence(self):
121 | """Tests whether the gradients are zero at the solution."""
122 | p = pax_fiddle.Config(
123 | bregman.BregmanPCA,
124 | name='bregman_pca',
125 | num_components=3,
126 | input_dims=[3],
127 | activation_type=bregman.ActivationType.IDENTITY,
128 | start_step=0,
129 | end_step=1,
130 | )
131 | layer = instantiate(p)
132 | npy_input = np.random.normal(1.0, 0.5, [16] + p.input_dims).astype(
133 | 'float32'
134 | )
135 | inputs = jnp.asarray(npy_input)
136 |
137 | with base_layer.JaxContext.new_context():
138 | prng_key = jax.random.PRNGKey(seed=123)
139 | initial_vars = layer.init(prng_key, inputs)
140 | layer = layer.bind(initial_vars, mutable=[NON_TRAINABLE])
141 | mean = jnp.zeros((1, 3))
142 | components = jnp.eye(3)
143 | coefficients_grad = layer.coefficients_grad_fn(
144 | inputs, components, mean, inputs
145 | )
146 | components_grad = layer.components_grad_fn(inputs, components, mean, inputs)
147 | self.assertAllClose(
148 | coefficients_grad, jnp.zeros_like(coefficients_grad), atol=1e-5
149 | )
150 | self.assertAllClose(
151 | components_grad, jnp.zeros_like(components_grad), atol=1e-5
152 | )
153 |
154 |
155 | if __name__ == '__main__':
156 | absltest.main()
157 |
--------------------------------------------------------------------------------
/praxis/layers/chain/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description: Chain and utilities.
17 | # The public API is defined in __init__.py.
18 |
19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test")
21 |
22 | package(default_visibility = JAX_VISIBILITY)
23 |
24 | pytype_strict_library(
25 | name = "chain",
26 | srcs = ["__init__.py"],
27 | deps = [
28 | ":chain_extensions",
29 | ":chain_lib",
30 | ],
31 | )
32 |
33 | pytype_strict_library(
34 | name = "chain_lib",
35 | srcs = ["chain.py"],
36 | deps = ["//praxis:base_layer"],
37 | )
38 |
39 | pytype_strict_library(
40 | name = "chain_extensions",
41 | srcs = ["chain_extensions.py"],
42 | deps = [
43 | ":chain_lib",
44 | # Implicit absl.logging dependency.
45 | # Implicit jax dependency.
46 | "//praxis:base_layer",
47 | "//praxis:pax_fiddle",
48 | "//praxis:py_utils",
49 | "//praxis/layers:activations",
50 | "//praxis/layers:linears",
51 | "//praxis/layers:repeats",
52 | ],
53 | )
54 |
55 | pytype_strict_test(
56 | name = "chain_test",
57 | srcs = ["chain_test.py"],
58 | deps = [
59 | ":chain",
60 | # Implicit absl.testing.absltest dependency.
61 | # Implicit absl.testing.parameterized dependency.
62 | # Implicit upb python proto dependency.
63 | # Implicit jax dependency.
64 | # Implicit numpy dependency.
65 | "//praxis:base_hyperparams",
66 | "//praxis:base_layer",
67 | "//praxis:pax_fiddle",
68 | "//praxis:py_utils",
69 | "//praxis:test_utils",
70 | "//praxis/layers:activations",
71 | ],
72 | )
73 |
--------------------------------------------------------------------------------
/praxis/layers/chain/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Expose public api of Chain and utilities."""
17 |
18 | from praxis.layers.chain.chain import call_chain
19 | from praxis.layers.chain.chain import Chain
20 | from praxis.layers.chain.chain_extensions import add_residual
21 | from praxis.layers.chain.chain_extensions import apply_padding
22 | from praxis.layers.chain.chain_extensions import chain
23 | from praxis.layers.chain.chain_extensions import copy_n_times
24 | from praxis.layers.chain.chain_extensions import dict_to_args
25 | from praxis.layers.chain.chain_extensions import feed_forward
26 | from praxis.layers.chain.chain_extensions import full_like
27 | from praxis.layers.chain.chain_extensions import kwargs_with_name
28 | from praxis.layers.chain.chain_extensions import log_args
29 | from praxis.layers.chain.chain_extensions import repeat
30 |
--------------------------------------------------------------------------------
/praxis/layers/chain/chain.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """`Chain` is a utility for chaining layers.
17 |
18 | In its most typical use of `Chain`, the output from the previous layer is
19 | given as input to the next. But `Chain` extends and generalizes the
20 | combination of layers by treating `*args` as a stack; popping arguments and
21 | pushing outputs as it iterates through the layers.
22 | """
23 |
24 | import inspect
25 | from typing import Any, Callable, Sequence
26 |
27 | from praxis import base_layer
28 |
29 |
30 | class Chain(base_layer.BaseLayer):
31 | """A utility layer for chaining a list of layers.
32 |
33 | `Chain` is similar to `Sequential` but extends and generalizes the combination
34 | of layers by treating `*args` as a stack; popping required arguments and
35 | pushing outputs as it iterates through the layers.
36 |
37 | Specifically, when the layers' signatures differ, the leading arguments are
38 | replaced with the outputs from the previous layer:
39 | outputs = layer(args[:num_args])
40 | args = outputs + args[len(outputs):] # replace leading arguments
41 |
42 | That is, `*args` is processed like a stack (leading arguments are popped and
43 | replaced). This stack logic is implemented in `call_chained_layer()`.
44 |
45 | `**kwargs` are passed straight through the chain, unmodified. Further, each
46 | kwarg is only passed to layers whose `__call__()` signature declares an
47 | argument with a matching name.
48 |
49 | Attributes:
50 | preserve_adopted_names: Flax module naming property (don't change).
51 | layers: The list of layers.
52 | """
53 |
54 | preserve_adopted_names: bool = True
55 | layers: Sequence[base_layer.BaseLayer] | None = None
56 |
57 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
58 | """Calls layers one-by-one and returns the last chain output."""
59 | return self.call_chain(*args, **kwargs)[-1]
60 |
61 | def call_chain(self, *args: Any, **kwargs: Any) -> Sequence[Any]:
62 | """Calls layers one-by-one and returns a list with each layer's output."""
63 | return call_chain(self.layers, *args, **kwargs)
64 |
65 |
66 | def call_chain(
67 | layers: Sequence[Callable[..., Any]], *args: Any, **kwargs: Any
68 | ) -> Sequence[Any]:
69 | """Calls layers one-by-one and returns a list with each layer's output.
70 |
71 | Wraps `call_chained_layer()` which implements the core logic of chain
72 | argument passing.
73 |
74 | Args:
75 | layers: The sequence of `BaseLayer`s.
76 | *args: The argument stack.
77 | **kwargs: The optional kwargs.
78 |
79 | Returns:
80 | A list with the output from each layer.
81 | """
82 | outputs = []
83 | args_stack = args
84 | for i, l in enumerate([l for l in layers if l is not None]):
85 | try:
86 | layer_outs = call_chained_layer(l, *args_stack, **kwargs)
87 | except Exception as e:
88 | raise type(e)(
89 | str(e) + f' Layer index={i} name={_name_attr(l)} args: {args_stack}'
90 | ) from e
91 |
92 | outputs.append(layer_outs)
93 | args_stack = ensure_tuple(layer_outs)
94 |
95 | if not outputs:
96 | return [args if len(args) > 1 else args[0]]
97 |
98 | return outputs
99 |
100 |
101 | def call_chained_layer(
102 | layer: Callable[..., Any], *args: Any, **kwargs: Any
103 | ) -> Any:
104 | """Passes required arguments and matching kwargs to `layer`.
105 |
106 | This is the main function implementing the argument stack; it passes
107 | required arguments to the layer and all matching kwargs. Other arguments are
108 | withheld.
109 |
110 | Args:
111 | layer: The layer to call.
112 | *args: The arguments to pass to required parameters.
113 | **kwargs: The kwargs to pass to existing kwargs.
114 |
115 | Returns:
116 | The new argument stack with leading arguments replaced:
117 | outs = layer(args[:num_args])
118 | outputs = outs + args[len(outs):]
119 | """
120 | signature = inspect.signature(layer.__call__)
121 | num_args = len(args)
122 | count = 0
123 | is_variadic = False
124 | for name, p in signature.parameters.items():
125 | if name in kwargs.keys():
126 | break
127 | if p.default != inspect.Signature.empty:
128 | break
129 | # Pass everything if variadic.
130 | if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
131 | is_variadic = True
132 | break
133 | count += 1
134 | if count > num_args:
135 | raise ValueError(
136 | f'Layer name={_name_attr(layer)} has too many args {count} > {num_args}'
137 | f' signature={signature}'
138 | )
139 | matching_kwargs = {
140 | k: v for k, v in kwargs.items() if k in signature.parameters.keys()
141 | }
142 |
143 | if is_variadic:
144 | outs = layer(*args, **kwargs)
145 | else:
146 | outs = layer(*args[:count], **matching_kwargs)
147 |
148 | outs = list(ensure_tuple(outs)) if outs is not None else []
149 | outputs = outs + list(args[len(outs) :])
150 | return tuple(outputs) if len(outputs) > 1 else outputs[0]
151 |
152 |
153 | def ensure_tuple(x: Any) -> tuple[Any, ...]:
154 | """Ensures that `x` is a tuple."""
155 | return x if isinstance(x, tuple) else (x,)
156 |
157 |
158 | def _name_attr(layer: Callable[..., Any]) -> str:
159 | """Returns the `name` attribute if it exists."""
160 | return layer.name if hasattr(layer, 'name') else ''
161 |
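
A worked sketch of the stack behavior described above (not part of this file). It uses plain callables in place of `BaseLayer` instances, which `call_chain` accepts since it only inspects `__call__` signatures:

    class AddOne:

      def __call__(self, x):
        return x + 1

    class Mul:

      def __call__(self, x, y):
        return x * y

    # The stack starts as (1, 10). AddOne consumes the leading argument and
    # pushes its output, leaving (2, 10). Mul then consumes both arguments;
    # its single output replaces the leading argument, leaving (20, 10).
    outs = call_chain([AddOne(), Mul()], 1, 10)
    assert outs == [(2, 10), (20, 10)]  # one stack snapshot per layer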
--------------------------------------------------------------------------------
/praxis/layers/checkpoint_policy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Gradient checkpoint policies that are supported by the `checkpoint` transform."""
17 |
18 | import enum
19 | import jax
20 |
21 |
22 | @enum.unique
23 | class AutodiffCheckpointType(str, enum.Enum):
24 | """jax.checkpoint policy types."""
25 |
26 | SAVE_NOTHING = 'save_nothing'
27 | SAVE_UNET_ALL_CONV = 'save_unet_all_conv'
28 | SAVE_UNET_CONV = 'save_unet_conv'
29 | SAVE_EVERYTHING = 'save_everything'
30 | SAVE_QKV_OUT_PROJ = 'save_qkv_out_proj'
31 | SAVE_OUT_PROJ = 'save_out_proj'
32 | SAVE_CONTEXT = 'save_context'
33 | SAVE_CONTEXT_AND_OUT_PROJ = 'save_encoded_and_out_proj'
34 | SAVE_DOT_ONLY = 'save_dot_only'
35 | SAVE_DOT_WITH_NO_BATCH_DIM = 'save_dot_with_no_batch_dims'
36 | SAVE_DOT_FOR_MLPERF_200B = 'save_dot_for_mlperf_200b'
37 | SAVE_ITERATION_INPUT = 'save_iteration_input'
38 | SAVE_TRANSFORMER_LAYER_OUTPUT = 'save_transformer_layer_output'
39 | SAVE_QUANTIZED = 'save_quantized'
40 | SAVE_QKV_OUT_PROJ_SEPARATE = 'save_qkv_out_proj_separate'
41 | SAVE_DOT_EXCEPT_LOGITS_FFN1 = 'save_dot_except_logits_ffn1'
42 | SAVE_DOT_EXCEPT_LOGITS = 'save_dot_except_logits'
43 |
44 |
45 | def custom_policy(checkpoint_policy: AutodiffCheckpointType):
46 | """Returns a JAX Autodiff checkpointing policy from the enum value."""
47 | # TODO(zhangqiaorjc): Configure custom checkpoint policy in expt config
48 | # without introducing enum.
49 | if checkpoint_policy == AutodiffCheckpointType.SAVE_EVERYTHING:
50 | return jax.checkpoint_policies.everything_saveable
51 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_ONLY:
52 | return jax.checkpoint_policies.checkpoint_dots
53 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_WITH_NO_BATCH_DIM:
54 | return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
55 | if checkpoint_policy == AutodiffCheckpointType.SAVE_QKV_OUT_PROJ:
56 | return jax.checkpoint_policies.save_only_these_names(
57 | 'combined_qkv_proj', 'out_proj'
58 | )
59 | if checkpoint_policy == AutodiffCheckpointType.SAVE_QKV_OUT_PROJ_SEPARATE:
60 | return jax.checkpoint_policies.save_only_these_names(
61 | 'query_proj', 'value_proj', 'key_proj', 'out_proj'
62 | )
63 | if checkpoint_policy == AutodiffCheckpointType.SAVE_CONTEXT:
64 | return jax.checkpoint_policies.save_only_these_names('context')
65 | if checkpoint_policy == AutodiffCheckpointType.SAVE_OUT_PROJ:
66 | return jax.checkpoint_policies.save_only_these_names('out_proj')
67 | if checkpoint_policy == AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ:
68 | return jax.checkpoint_policies.save_only_these_names('context', 'out_proj')
69 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_FOR_MLPERF_200B:
70 | return jax.checkpoint_policies.save_only_these_names(
71 | 'combined_qkv_proj',
72 | 'query_proj',
73 | 'value_proj',
74 | 'key_proj',
75 | 'context',
76 | 'out_proj',
77 | )
78 | if checkpoint_policy == AutodiffCheckpointType.SAVE_QUANTIZED:
79 | return jax.checkpoint_policies.save_only_these_names(
80 | 'context',
81 | 'out_proj',
82 | 'combined_qkv_proj',
83 | 'qlhs',
84 | 'qrhs',
85 | 'lhs_scale',
86 | 'rhs_scale',
87 | )
88 | if checkpoint_policy == AutodiffCheckpointType.SAVE_ITERATION_INPUT:
89 | return jax.checkpoint_policies.save_only_these_names('iteration_input')
90 | if checkpoint_policy == AutodiffCheckpointType.SAVE_TRANSFORMER_LAYER_OUTPUT:
91 | return jax.checkpoint_policies.save_only_these_names(
92 | 'transformer_layer_out'
93 | )
94 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1:
95 | return jax.checkpoint_policies.save_only_these_names(
96 | 'combined_qkv_proj',
97 | 'query_proj',
98 | 'value_proj',
99 | 'key_proj',
100 | 'context',
101 | 'out_proj',
102 | 'ffn2',
103 | )
104 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS:
105 | return jax.checkpoint_policies.save_only_these_names(
106 | 'combined_qkv_proj',
107 | 'query_proj',
108 | 'value_proj',
109 | 'key_proj',
110 | 'context',
111 | 'out_proj',
112 | 'ffn1',
113 | 'ffn2',
114 | )
115 | if checkpoint_policy == AutodiffCheckpointType.SAVE_UNET_CONV:
116 | return jax.checkpoint_policies.save_only_these_names('conv_out')
117 | if checkpoint_policy == AutodiffCheckpointType.SAVE_UNET_ALL_CONV:
118 | return jax.checkpoint_policies.save_only_these_names(
119 | 'conv_0', 'conv_1', 'conv_out'
120 | )
121 | assert checkpoint_policy == AutodiffCheckpointType.SAVE_NOTHING
122 | return jax.checkpoint_policies.nothing_saveable
123 |
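
A minimal sketch (not part of this file) of how a policy returned by `custom_policy()` plugs into `jax.checkpoint` (a.k.a. `jax.remat`); `f` is an arbitrary differentiable function chosen for illustration:

    import jax
    import jax.numpy as jnp

    def f(w, x):
      return jnp.sum(jnp.tanh(jnp.dot(x, w)))

    # Save only dot-product results; everything else is rematerialized on the
    # backward pass.
    f_remat = jax.checkpoint(
        f, policy=custom_policy(AutodiffCheckpointType.SAVE_DOT_ONLY))
    grads = jax.grad(f_remat)(jnp.ones((4, 4)), jnp.ones((2, 4)))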
--------------------------------------------------------------------------------
/praxis/layers/chunk.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Chunk functions."""
17 |
18 | import jax
19 | from jax import numpy as jnp
20 | from praxis import pytypes
21 |
22 | JTensor = pytypes.JTensor
23 |
24 |
25 | def chunk(x: JTensor, *, chunk_size: int, axis: int = 1) -> JTensor:
26 | """Splits the x tensor into chunks with `chunk_size`.
27 |
28 | Args:
29 | x: JTensor of shape [..., L, ...].
30 | chunk_size: int, the chunk size.
31 | axis: axis for chunking.
32 |
33 | Returns:
34 | JTensor of shape [L/C, ..., C, ...].
35 | """
36 | # Pad x up to a multiple of chunk_size if needed.
37 | extra_pad = -x.shape[axis] % chunk_size
38 | if extra_pad:
39 | # jnp.pad pads with zeros by default.
40 | x = jnp.pad(x, [[0, extra_pad if i == axis else 0] for i in range(x.ndim)])
41 |
42 | x_shape = jnp.shape(x)
43 | num_chunks = x_shape[axis] // chunk_size
44 | # Reshape to [..., L/C, C, ...].
45 | new_shape = x_shape[:axis] + (num_chunks, chunk_size) + x_shape[axis + 1 :]
46 | chunk_x = jnp.reshape(x, new_shape)
47 | # [L/C, ..., C, ...]
48 | return jnp.moveaxis(chunk_x, axis, 0)
49 |
50 |
51 | def unchunk(
52 | chunk_x: JTensor, axis: int = 1, seqlen: int | None = None
53 | ) -> JTensor:
54 | """Reshapes the x tensor from chunks to the original shape.
55 |
56 | Args:
57 | chunk_x: JTensor of shape [L/C, ..., C, ...].
58 | axis: axis for chunking.
59 | seqlen: int, the original sequence length.
60 |
61 | Returns:
62 | JTensor of shape [..., L, ...].
63 | """
64 | chunk_size = chunk_x.shape[axis + 1]
65 | # [..., L/C, C, ...]
66 | chunk_x = jnp.moveaxis(chunk_x, 0, axis)
67 | # [..., L, ...]
68 | new_shape = chunk_x.shape[:axis] + (-1,) + chunk_x.shape[axis + 2 :]
69 | x = jnp.reshape(chunk_x, new_shape)
70 | if seqlen is not None and seqlen % chunk_size:
71 | x = jax.lax.slice_in_dim(x, 0, seqlen, axis=axis)
72 | return x
73 |
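A round-trip sketch (not part of the source file) of the two functions above, using a sequence length that is not a multiple of the chunk size so that both the zero padding in chunk() and the `seqlen` trim in unchunk() are exercised:

import jax.numpy as jnp
from praxis.layers import chunk

x = jnp.arange(2 * 7 * 3, dtype=jnp.float32).reshape(2, 7, 3)
# Pads L=7 up to 8, then splits into 4 chunks of size 2: shape [4, 2, 2, 3].
cx = chunk.chunk(x, chunk_size=2, axis=1)
# Restores the original shape; seqlen=7 slices off the padding.
y = chunk.unchunk(cx, axis=1, seqlen=7)
assert y.shape == (2, 7, 3)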
--------------------------------------------------------------------------------
/praxis/layers/chunk_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis Chunk functions."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import numpy as np
21 | from praxis import test_utils
22 | from praxis.layers import chunk
23 |
24 |
25 | class ChunkTest(test_utils.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | np.random.seed(123456)
30 |
31 | @parameterized.parameters(
32 | ([3, 4, 5], 0, 2, [2, 2, 4, 5]),
33 | ([3, 5], 1, 2, [3, 3, 2]),
34 | ([3, 4, 2, 8], 1, 3, [2, 3, 3, 2, 8]),
35 | ([3, 4, 2, 8], 3, 4, [2, 3, 4, 2, 4]),
36 | )
37 | def test_chunk(self, in_shape, axis, chunk_size, chunk_shape):
38 | x = np.random.normal(1.0, 0.5, in_shape)
39 | chunk_x = chunk.chunk(x, chunk_size=chunk_size, axis=axis)
40 | self.assertArraysEqual(chunk_x.shape, chunk_shape)
41 |
42 | out_x = chunk.unchunk(chunk_x, axis=axis, seqlen=x.shape[axis])
43 | self.assertAllClose(x, out_x)
44 |
45 |
46 | if __name__ == '__main__':
47 | absltest.main()
48 |
--------------------------------------------------------------------------------
/praxis/layers/einsum.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A layer that computes an Einsum with a weight, and optionally adds a bias."""
17 |
18 | from typing import Sequence
19 |
20 | from praxis import base_layer
21 | from praxis import pax_fiddle
22 | from praxis import pytypes
23 | from praxis.layers import base_ops
24 |
25 | JTensor = pytypes.JTensor
26 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
27 | template_field = base_layer.template_field
28 |
29 |
30 | class Einsum(base_layer.BaseLayer):
31 | """Layer that computes an einsum and maybe a bias.
32 |
33 | The fan-in, fan-out and bias dimensions are inferred from the einsum equation.
34 | If bias is used, the fan-out dims must appear at the end of the output tensor.
35 |
36 | Attributes:
37 | eqn: Einsum equation. It should be in the format of input,w->output. E.g.,
38 | '...d,df->...f'.
39 | w_shape: Weight shape.
40 | use_bias: Whether to add a bias.
41 | einsum_op_tpl: The op definition that implements einsum. Enables injection.
42 | """
43 | eqn: str = ''
44 | w_shape: Sequence[int] = ()
45 | use_bias: bool = False
46 | einsum_op_tpl: LayerTpl = template_field(base_ops.EinsumOp)
47 |
48 | def setup(self) -> None:
49 | operands, out = self.eqn.split('->')
50 | x, w = operands.split(',')
51 | assert '.' not in w
52 | fan_in = sorted(w.index(d) for d in (set(x) - set(out)))
53 | fan_out = sorted(w.index(d) for d in (set(out) - set(x)))
54 | w_sharding = self.weight_split_dims_mapping.wt
55 | pc = base_layer.WeightHParams(
56 | shape=self.w_shape,
57 | fan_in_axes=fan_in,
58 | fan_out_axes=fan_out,
59 | mesh_shape=self.mesh_shape,
60 | tensor_split_dims_mapping=w_sharding,
61 | )
62 | self.create_variable('w', pc)
63 | if self.use_bias:
64 | out_bias_dims = sorted(out.index(d) for d in (set(out) - set(x)))
65 | # Fan-out dims must be at the end of `out`.
66 | assert all(d >= len(out) - len(out_bias_dims) for d in out_bias_dims)
67 | bias_shape = [self.w_shape[w.index(out[d])] for d in out_bias_dims]
68 | if w_sharding is not None:
69 | b_sharding = [w_sharding[w.index(out[d])] for d in out_bias_dims]
70 | else:
71 | b_sharding = None
72 | pc_bias = base_layer.WeightHParams(
73 | shape=bias_shape,
74 | init=base_layer.WeightInit.Constant(0.0),
75 | mesh_shape=self.mesh_shape,
76 | tensor_split_dims_mapping=b_sharding,
77 | )
78 | self.create_variable('b', pc_bias)
79 | self.create_child('einsum', self.einsum_op_tpl.clone())
80 |
81 | def __call__(self, inputs: JTensor) -> JTensor:
82 | """Computes the einsum and maybe bias.
83 |
84 | Args:
85 | inputs: A JTensor of shape as described in the equation.
86 |
87 | Returns:
88 | The result of the einsum with maybe a bias added.
89 | """
90 | ret = self.einsum(self.eqn, inputs, self.theta.w)
91 | if self.use_bias:
92 | ret += self.theta.b
93 | return base_layer.maybe_shard(
94 | ret, self.activation_split_dims_mapping.out, self.mesh_axis_names
95 | )
96 |
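A worked restatement (not part of the source file) of the shape inference in setup() for the equation '...d,dnh->...nh': labels of `w` that appear in the input but not the output are fan-in axes, and labels that appear in the output but not the input are fan-out axes, which also determine the bias shape when use_bias=True.

eqn = '...d,dnh->...nh'
operands, out = eqn.split('->')   # '...d,dnh' and '...nh'
x, w = operands.split(',')        # x = '...d', w = 'dnh'
fan_in = sorted(w.index(d) for d in (set(x) - set(out)))   # [0] -> 'd'
fan_out = sorted(w.index(d) for d in (set(out) - set(x)))  # [1, 2] -> 'n', 'h'
# With w_shape=[5, 2, 8] (as in the test below), the bias has shape [2, 8].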
--------------------------------------------------------------------------------
/praxis/layers/einsum_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis Einsum layers."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import numpy as np
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import test_utils
25 | from praxis.layers import einsum
26 |
27 | instantiate = base_layer.instantiate
28 |
29 |
30 | class EinsumTest(test_utils.TestCase):
31 |
32 | def setUp(self):
33 | super().setUp()
34 | np.random.seed(123456)
35 |
36 | @parameterized.parameters(
37 | ('...d,df->...f', [3, 4, 5], [5, 8], False),
38 | ('...d,dnh->...nh', [3, 5], [5, 2, 8], True),
39 | ('blnh,dnh->bld', [3, 4, 2, 8], [5, 2, 8], False),
40 | ('...nh,hdn->...d', [3, 4, 2, 8], [8, 5, 2], True),
41 | )
42 | def test_einsum(self, eqn, in_shape, w_shape, use_bias):
43 | p = pax_fiddle.Config(
44 | einsum.Einsum,
45 | name='einsum',
46 | eqn=eqn,
47 | w_shape=w_shape,
48 | use_bias=use_bias,
49 | )
50 | layer = instantiate(p)
51 | inputs = np.random.normal(1.0, 0.5, in_shape).astype(
52 | 'float32'
53 | )
54 | prng_key = jax.random.PRNGKey(seed=123)
55 | initial_vars = layer.init(prng_key, inputs)
56 | if use_bias:
57 | # Make sure the bias is non-zero.
58 | initial_vars['params']['b'] = np.random.normal(
59 | 1.0, 0.5, initial_vars['params']['b'].shape
60 | ).astype('float32')
61 | outputs = layer.apply(initial_vars, inputs)
62 | np_outputs = np.einsum(eqn, inputs, initial_vars['params']['w'])
63 | if use_bias:
64 | np_outputs += initial_vars['params']['b']
65 | self.assertAllClose(outputs, np_outputs, atol=1e-6)
66 |
67 |
68 | if __name__ == '__main__':
69 | absltest.main()
70 |
--------------------------------------------------------------------------------
/praxis/layers/injection/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Praxis layers. The public API is defined in __init__.py.
18 |
19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test")
21 |
22 | package(default_visibility = JAX_VISIBILITY)
23 |
24 | licenses(["notice"])
25 |
26 | pytype_strict_library(
27 | name = "fp8_nvidia_gpu",
28 | srcs = ["fp8_nvidia_gpu.py"],
29 | deps = [
30 | # Implicit flax.core dependency.
31 | # Implicit jax dependency.
32 | "//praxis:base_layer",
33 | "//praxis:pax_fiddle",
34 | "//praxis:pytypes",
35 | "//praxis/layers",
36 | ],
37 | )
38 |
39 | pytype_strict_test(
40 | name = "fp8_nvidia_gpu_test",
41 | srcs = ["fp8_nvidia_gpu_test.py"],
42 | deps = [
43 | ":fp8_nvidia_gpu",
44 | # Implicit absl.testing.absltest dependency.
45 | # Implicit absl.testing.parameterized dependency.
46 | # Implicit flax.core dependency.
47 | # Implicit upb python proto dependency.
48 | # Implicit jax dependency.
49 | "//praxis:base_layer",
50 | "//praxis:pax_fiddle",
51 | "//praxis:test_utils",
52 | "//praxis/layers:linears",
53 | "//praxis/layers:pipeline",
54 | ],
55 | )
56 |
--------------------------------------------------------------------------------
/praxis/layers/injection/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Exposes the public layer functionalities for injection."""
17 |
18 | from praxis.layers.injection import fp8_nvidia_gpu
19 |
--------------------------------------------------------------------------------
/praxis/layers/losses.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Loss functions."""
17 |
18 | from jax import numpy as jnp
19 | from praxis import base_layer
20 | from praxis import pytypes
21 |
22 | from jax_bitempered_loss import loss
23 |
24 | WeightHParams = base_layer.WeightHParams
25 | WeightInit = base_layer.WeightInit
26 | JTensor = pytypes.JTensor
27 |
28 |
29 | class BiTemperedLoss(base_layer.BaseLayer):
 30 |   """Bi-tempered logistic loss.
31 |
32 | Bi-Tempered logistic loss is a generalized softmax cross-entropy loss function
33 | with a bounded loss value per sample and a heavy-tail softmax probability
34 | function. Temperature t1 < 1.0 controls the boundedness and t2 > 1.0 controls
35 | the tail heaviness of the softmax probabilities.
36 |
 37 |   The temperature pair (1, 1) corresponds to CE loss. The schedule sets the
38 | temperatures to (1, 1) initially and transitions to (t1, t2) linearly over
39 | the interval [start_step, end_step]. If end_step == 0, then the temperatures
40 | are set to (t1, t2) throughout.
41 |
42 | Source: https://bit.ly/3jSol8T
43 |
44 | Attributes:
45 | t1: Temperature 1 (log).
46 | t2: Temperature 2 (exp).
47 | label_smoothing: Label smoothing.
48 | start_step: Step number to start transitioning from CE loss.
49 | end_step: Step number to reach the final temperature pairs (t1, t2). When
50 | end_step == 0, the temperatures are set to (t1, t2) throughout.
51 | """
52 | t1: float = 1.0
53 | t2: float = 1.0
54 | label_smoothing: float = 0.0
55 | start_step: int = 0
56 | end_step: int = 0
57 |
58 | def setup(self) -> None:
59 | """Initialize the step variable."""
60 | assert self.end_step >= self.start_step
61 | count = WeightHParams(
62 | shape=[],
63 | init=WeightInit.Constant(0.0),
64 | dtype=jnp.float32,
65 | collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC])
66 | self.create_variable('count', count, trainable=False)
67 |
68 | def temperature_schedule(self, count: JTensor) -> JTensor:
69 | """Temperature schedule.
70 |
71 | The temperatures will be set to the final values if end_step == 0.
72 |
73 | Args:
74 | count: Step number.
75 |
76 | Returns:
 77 |       Schedule weight in [0, 1]: 1.0 means CE temperatures (1, 1); 0.0 means (t1, t2).
78 | """
79 | count = jnp.array(count).astype(jnp.float32)
80 | schedule = jnp.where(
81 | jnp.logical_and(self.end_step > 0, count < self.end_step), 1.0, 0.0
82 | )
83 | schedule = jnp.where(
84 | count >= self.start_step,
85 | jnp.maximum(
86 | 1.0
87 | - (count - self.start_step)
88 | / jnp.maximum(self.end_step - self.start_step, 1.0),
89 | 0.0,
90 | ),
91 | schedule,
92 | )
93 | return schedule
94 |
95 | def __call__(self, logits: JTensor, labels: JTensor) -> JTensor:
96 | """Applies bi-tempered loss.
97 |
98 | Args:
99 | logits: The logits JTensor. Shaped [..., num_classes].
100 | labels: The one-hot labels JTensor. Shaped [..., num_classes].
101 |
102 | Returns:
 103 |       Loss values, shaped the same as logits/labels but without the last
 104 |       dimension of size `num_classes`.
105 | """
106 | base_schedule = 0.0
107 | if not self.do_eval:
108 | count = self.get_var('count')
109 | self.update_var('count', count + 1.0)
110 | base_schedule = self.temperature_schedule(count)
111 | t1 = 1.0 * base_schedule + self.t1 * (1.0 - base_schedule)
112 | t2 = 1.0 * base_schedule + self.t2 * (1.0 - base_schedule)
113 | loss_vals = loss.bi_tempered_logistic_loss(
114 | logits, labels, t1, t2, label_smoothing=self.label_smoothing
115 | )
116 | return loss_vals
117 |
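A plain-Python restatement (not part of the source file) of temperature_schedule() for the end_step > 0 case, with illustrative values start_step=100 and end_step=200:

def schedule(count, start_step=100, end_step=200):
  # Mirrors temperature_schedule() above when end_step > 0.
  if count < start_step:
    return 1.0
  return max(1.0 - (count - start_step) / max(end_step - start_step, 1), 0.0)

assert schedule(0) == 1.0    # pure CE: temperatures (1, 1)
assert schedule(150) == 0.5  # halfway: t1 = 0.5 * 1.0 + 0.5 * self.t1
assert schedule(300) == 0.0  # fully transitioned: temperatures (t1, t2)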
--------------------------------------------------------------------------------
/praxis/layers/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Exposes the public layer functionalities."""
17 |
18 | from praxis.layers.quantization.attentions import AttentionProjection
19 | from praxis.layers.quantization.attentions import AttentionProjectionLoRA
20 | from praxis.layers.quantization.attentions import CombinedQKVProjectionLayer
21 | from praxis.layers.quantization.attentions import DotProductAttention
22 | from praxis.layers.quantization.conformers import DotProductAttentionWithContext
23 | from praxis.layers.quantization.convolutions import Conv2D
24 | from praxis.layers.quantization.einsum import Einsum
25 | from praxis.layers.quantization.embedding_softmax import Embedding
26 | from praxis.layers.quantization.embedding_softmax import NClassMajorSharedEmbeddingSoftmax
27 | from praxis.layers.quantization.embedding_softmax import SharedEmbeddingSoftmax
28 | from praxis.layers.quantization.linears import Linear
29 | from praxis.layers.quantization.linears import LinearLoRA
30 | from praxis.layers.quantization.multi_query_attention import OneHeadedAttentionProjection
31 | from praxis.layers.quantization.ngrammer import Ngrammer
32 | from praxis.layers.quantization.ngrammer import VQNgrammer
33 | from praxis.layers.quantization.operations import einsum
34 | from praxis.layers.quantization.overflow_check import AttentionProjectionOverflowCheck, CombinedQKVProjectionLayerOverflowCheck, FeedForwardOverflowCheck, OneHeadedAttentionProjectionOverflowCheck
35 | from praxis.layers.quantization.searchable import SearchableAttentionProjection
36 | from praxis.layers.quantization.searchable import SearchableCombinedQKVProjectionLayer
37 | from praxis.layers.quantization.searchable import SearchableLinear
38 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/automl_select.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Praxis AutoMLSelect layer to switch branches according to AutoML decisions."""
17 |
18 | from typing import Sequence
19 |
20 | from flax import linen as nn
21 | import jax.numpy as jnp
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import pytypes
25 |
26 | JTensor = pytypes.JTensor
27 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
28 | template_field = pax_fiddle.template_field
29 | WeightInit = base_layer.WeightInit
30 | WeightHParams = base_layer.WeightHParams
31 |
32 |
33 | # TODO(yunn): Move BaseAutoMLSelect to efficient NAS package for Praxis
34 | class BaseAutoMLSelect(base_layer.BaseLayer):
35 | """Praxis AutoMLSelect layer to switch branches according to AutoML decisions.
36 |
37 | Attributes:
 38 |     search_options_tpl: a sequence of layer templates, each representing a
 39 |       branch. The layers must have the same input and output shapes.
40 | """
41 |
42 | search_options_tpl: Sequence[LayerTpl] | None = template_field(None)
43 |
44 | def setup(self) -> None:
45 | if not self.search_options_tpl:
46 | raise AttributeError('Must set at least one search option.')
47 | decision = WeightHParams(
48 | shape=[],
49 | init=WeightInit.Constant(0),
50 | dtype=jnp.uint8,
51 | mesh_shape=self.mesh_shape
52 | )
53 | rl_variables = WeightHParams(
54 | shape=[len(self.search_options_tpl)],
55 | init=WeightInit.Constant(0),
56 | dtype=jnp.float32,
57 | mesh_shape=self.mesh_shape
58 | )
59 | self.create_children('search_options', self.search_options_tpl)
60 | self.create_variable('decision', decision, trainable=False)
61 | self.create_variable('rl_variables', rl_variables, trainable=False)
62 |
63 |
64 | class AutoMLSelect(BaseAutoMLSelect):
 65 |   """An AutoMLSelect that switches between different quantizers."""
66 |
67 | def __call__(
68 | self,
69 | x: JTensor,
70 | contract_dims: int | Sequence[int],
71 | squeeze_scale=True,
72 | quantized_dtype: jnp.dtype | None = None,
73 | ) -> tuple[JTensor, JTensor, JTensor | None]:
74 | def branch_fn(i):
75 | def quantize_fn(mdl, inputs):
76 | return mdl.search_options[i].quantize(
77 | inputs, contract_dims, squeeze_scale, quantized_dtype
78 | )
79 |
80 | return quantize_fn
81 |
82 | branches = [branch_fn(i) for i in range(len(self.search_options))]
83 | return nn.switch(self.get_var('decision'), branches, self, x)
84 |
85 | def quantize(
86 | self,
87 | x: JTensor,
88 | contract_dims: int | Sequence[int],
89 | squeeze_scale=True,
90 | quantized_dtype: jnp.dtype | None = None,
91 | ) -> tuple[JTensor, JTensor, JTensor | None]:
92 | return self.__call__(x, contract_dims, squeeze_scale, quantized_dtype)
93 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/automl_select_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for praxis.layers.quantization.automl_select."""
17 |
18 | from absl.testing import absltest
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from praxis import base_hyperparams
24 | from praxis import pax_fiddle
25 | from praxis import test_utils
26 | from praxis.layers.quantization import automl_select
27 | from praxis.layers.quantization import quantization_hparams
28 | from praxis.layers.quantization import quantizer
29 |
30 |
31 | ActQuantizationParams = quantization_hparams.ActQuantizationParams
32 | instantiate = base_hyperparams.instantiate
33 |
34 |
35 | class AutomlSelectTest(test_utils.TestCase):
36 |
37 | def test_automl_select(self):
38 | p = pax_fiddle.Config(automl_select.AutoMLSelect)
39 | p.search_options_tpl = [
40 | quantizer.create_tensor_quantizer(
41 | 'int4', ActQuantizationParams(precision=4)
42 | ),
43 | quantizer.create_tensor_quantizer(
44 | 'int8', ActQuantizationParams(precision=8)
45 | ),
46 | ]
47 |
48 | m = instantiate(p)
49 | x = jnp.ones((1, 3))
50 | m_vars = m.init(jax.random.PRNGKey(0), x, 0)
51 | m_vars['non_trainable']['decision'] = 0
52 | q_x1, q_s1, _ = m.apply(m_vars, x, 0)
53 | m_vars['non_trainable']['decision'] = 1
54 | q_x2, q_s2, _ = m.apply(m_vars, x, 0)
55 | self.assertNotAllClose(q_x1, q_x2)
56 | self.assertNotAllClose(q_s1, q_s2)
57 | self.assertIn('rl_variables', m_vars['non_trainable'])
58 |
59 |
60 | if __name__ == '__main__':
61 | absltest.main()
62 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/conformers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Quantized Conformer Layers."""
17 |
18 | import jax.numpy as jnp
19 | from praxis import pytypes
20 | from praxis.layers import attentions
21 | from praxis.layers.quantization import attentions as qattentions
22 |
23 | JTensor = pytypes.JTensor
24 |
25 |
26 | class DotProductAttentionWithContext(qattentions.DotProductAttention):
27 | """Dot-product attention with given left and right context.
28 |
29 | It covers several use cases:
 30 |   1. global self-attention when left_context=right_context=None
 31 |   2. local self-attention when left_context!=None and right_context!=None
 32 |   3. hybrid self-attention when exactly one of left_context/right_context is None
33 |
34 | For use cases (2,3) it will use emulated local self attention.
35 | For use case (2) it is more efficient to use LocalSelfAttention.
36 | """
37 |
38 | left_context: int | None = None
39 | right_context: int | None = None
40 |
41 | def _dot_atten(
42 | self,
43 | query: JTensor,
44 | key: JTensor,
45 | value: JTensor,
46 | atten_mask: JTensor,
47 | relative_bias: JTensor | None = None,
48 | ) -> tuple[JTensor, JTensor]:
49 | """Main attention function.
50 |
51 | Args:
52 | query: JTensor of shape [B, T, N, H].
53 | key: JTensor of shape [B, S, N, H].
54 | value: JTensor of shape [B, S, N, H].
55 | atten_mask: JTensor of shape [1|B, 1, 1|T, S] which is a mask that is
56 | applied to prevent attention between unwanted pairs. This has already
57 | been converted into large negative logits. Note that the first and third
58 | dimension allow size 1 if the mask is shared by every item in the batch
59 | or every token in the target sequence.
60 | relative_bias: Relative bias of shape [B, N, T, S].
61 |
62 | Returns:
63 | encoded: JTensor of shape [B, T, N, H].
64 | atten_probs: JTensor of shape [B, N, T, S].
65 | """
66 | time_size = query.shape[1]
67 |
68 | if self.left_context is not None or self.right_context is not None:
69 | input_atten_mask = atten_mask
70 | atten_mask = attentions.limited_context_mask(
71 | self.left_context, self.right_context, time_size
72 | )
73 | atten_mask = jnp.minimum(atten_mask, input_atten_mask)
74 | return super()._dot_atten(query, key, value, atten_mask, relative_bias)
75 |
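A sketch (not part of the source file) of the mask merge performed in _dot_atten above, assuming the `attentions` helpers keep the call signatures used by the float layer: both masks hold large negative logits at disallowed positions, so jnp.minimum keeps the more restrictive of the two at each position.

import jax.numpy as jnp
from praxis.layers import attentions

T = 6
# Band mask allowing 2 tokens of left context and 1 token of right context.
band_mask = attentions.limited_context_mask(2, 1, T)
# Causal mask derived from a dummy [B, T, D] input, as in the test below.
causal = attentions.causal_mask(jnp.zeros([1, T, 8]))
merged = jnp.minimum(band_mask, causal)  # disallowed if either mask disallows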
--------------------------------------------------------------------------------
/praxis/layers/quantization/conformers_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
 16 | """Tests for quantized conformer attention layers."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | import numpy as np
21 | from praxis import base_layer
22 | from praxis import pax_fiddle
23 | from praxis import test_utils
24 | from praxis.layers import attentions
25 | from praxis.layers import conformers
26 | from praxis.layers.quantization import conformers as qconformers
27 | from praxis.layers.quantization import quantization_hparams
28 |
29 | QuantizationParams = quantization_hparams.QuantizationParams
30 | QuantizationMode = quantization_hparams.QuantizationMode
31 | QuantizationType = quantization_hparams.QuantizationType
32 | instantiate = base_layer.instantiate
33 |
34 |
 35 | class DotProductAttentionWithContextSyncTest(test_utils.TestCase):
36 |
37 | def setUp(self):
38 | super().setUp()
39 | np.random.seed(123456)
40 |
41 | def test_dot_product_attention_with_context(self):
42 | batch_size = 2
43 | source_max_length = 16
44 | input_dim = 32
45 |
46 | atten_f_p = pax_fiddle.Config(
47 | conformers.DotProductAttentionWithContext, name='atten_f'
48 | )
49 | atten_q_p = pax_fiddle.Config(
50 | qconformers.DotProductAttentionWithContext,
51 | name='atten_q',
52 | quantization=QuantizationParams(
53 | quantization_type=QuantizationType.AQT,
54 | mode=QuantizationMode.TRAINING,
55 | # Test using 23 bits to minimize the quantization error and test
56 | # for numerical correctness.
57 | act_params=quantization_hparams.ActQuantizationParams(precision=23),
58 | weight_params=quantization_hparams.WeightQuantizationParams(),
59 | ),
60 | )
61 | for p in [atten_f_p, atten_q_p]:
62 | p.input_dim = input_dim
63 | p.hidden_dim = 16
64 | p.left_context = 3
65 | p.right_context = 5
66 |
67 | atten_f = instantiate(atten_f_p)
68 | atten_q = instantiate(atten_q_p)
69 |
70 | query_vec = np.random.normal(
71 | size=[batch_size, source_max_length, input_dim]
72 | ).astype(np.float32)
73 | key_vec = query_vec
74 | value_vec = query_vec
75 | atten_mask = attentions.causal_mask(query_vec)
76 |
77 | with base_layer.JaxContext.new_context():
78 | initial_vars = atten_f.init(
79 | jax.random.PRNGKey(0),
80 | query_vec,
81 | key_vec,
82 | value_vec,
83 | atten_mask,
84 | )
85 | fprop_out_f, _ = atten_f.apply(
86 | initial_vars,
87 | query_vec,
88 | key_vec,
89 | value_vec,
90 | atten_mask,
91 | method=atten_f.__call__,
92 | )
93 | fprop_out_q, _ = atten_q.apply(
94 | initial_vars,
95 | query_vec,
96 | key_vec,
97 | value_vec,
98 | atten_mask,
99 | method=atten_q.__call__,
100 | )
101 | self.assertAllClose(fprop_out_f, fprop_out_q)
102 |
103 |
104 | if __name__ == '__main__':
105 | absltest.main()
106 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/convolutions.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
 16 | """Quantized convolutional layers."""
17 |
18 | from jax import numpy as jnp
19 | from praxis import base_layer
20 | from praxis import pytypes
21 | from praxis.layers import convolutions
22 | from praxis.layers import normalizations
23 | from praxis.layers.quantization import quantization_hparams
24 | from praxis.layers.quantization import quantizer
25 |
26 | WeightHParams = base_layer.WeightHParams
27 | QuantizationParams = quantization_hparams.QuantizationParams
28 | instance_field = base_layer.instance_field
29 | JTensor = pytypes.JTensor
30 |
31 |
32 | class Conv2D(convolutions.Conv2D, quantizer.QuantizationLayer): # pytype: disable=signature-mismatch
33 | """Conv2D with support of SAME/VALID paddings."""
34 |
35 | _PACK_4BIT_DIM = 0
36 |
37 | def setup(self) -> None:
38 | self.check_dimensions()
39 | wp = self.weight_split_dims_mapping
40 | pc = WeightHParams(
41 | shape=self.filter_shape,
42 | mesh_shape=self.mesh_shape,
43 | tensor_split_dims_mapping=wp.wt,
44 | dtype=self.dtype,
45 | init=self.kernel_init,
46 | )
47 | self.set_up_weights(
48 | weight_name='w',
49 | weight_params=pc,
50 | scale_shape=[self.filter_shape[-1]],
51 | )
52 |
53 | if self.bias:
54 | self.create_variable(
55 | 'b',
56 | WeightHParams(
57 | shape=[self.filter_shape[-1]],
58 | init=self.bias_init,
59 | dtype=self.dtype,
60 | ),
61 | )
62 |
63 | wn = self.weight_norm_tpl.clone().set(dim=self.filter_shape[-1])
64 | self.weight_norm: normalizations.BaseNormalization
65 | self.create_child('weight_norm', wn)
66 |
67 | def __call__(self, inputs: JTensor) -> JTensor:
 68 |     """FProp that supports strided, dilated, and depthwise convolutions.
69 |
70 | Args:
 71 |       inputs: Input sequence of shape [B, H, W, D_in], more popularly known
 72 |         as NHWC format.
73 |
74 | Returns:
75 | Output sequence after applying convolutions of shape [B, H', W', D_out].
76 | Note that if the padding is SAME and there is no dilation and striding,
77 | then H' = H and W' = W.
78 | """
79 | # Check if the feature_group_count is compatible with the inputs and filter
80 | # For more information see XLA docs on ConvWithGeneralPadding below
81 | # https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution
82 | if inputs.shape[3] % self.filter_shape[2] != 0:
83 | raise ValueError(
 84 |           f'Input features {inputs.shape[3]} must be a '
85 | f'multiple of filter input dim {self.filter_shape[2]} '
86 | f'(Input shape: {inputs.shape}, '
87 | f'filter shape: {self.filter_shape}).'
88 | )
89 | # feature group count is D_in // filter input dim
90 | feature_group_count = inputs.shape[3] // self.filter_shape[2]
91 | if self.filter_shape[3] % feature_group_count != 0:
92 | raise ValueError(
93 | f'Filter output dim {self.filter_shape[3]} must be a '
94 | f'multiple of feature group count {feature_group_count} '
95 | f'(Input shape: {inputs.shape}, '
96 | f'filter shape: {self.filter_shape}).'
97 | )
98 | padding = self._compute_padding(inputs.shape)
99 | inputs = self._shard_bhwc(inputs.astype(self.fprop_dtype))
100 |
101 | # The `dimension_numbers=('NHWC', 'HWIO', 'NHWC')` is to be consistent
102 | # with tf.conv2d, see e.g., see
103 | # https://github.com/google/jax/blob/main/jax/_src/lax/lax.py#L622
104 | dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
105 | outputs = quantizer.quantized_conv(
106 | self, inputs, padding, dimension_numbers, feature_group_count
107 | )
108 | outputs = self._shard_bhwc(outputs)
109 | if self.bias:
110 | outputs += jnp.reshape(self.theta.b, (1,) * (outputs.ndim - 1) + (-1,))
111 | return outputs
112 |
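A numeric sketch (not part of the source file) of the grouping checks in __call__, with hypothetical shapes chosen so that both divisibility conditions hold:

inputs_shape = (2, 8, 8, 6)   # NHWC: D_in = 6
filter_shape = (3, 3, 2, 6)   # HWIO: filter input dim = 2, D_out = 6
feature_group_count = inputs_shape[3] // filter_shape[2]  # 6 // 2 = 3
assert inputs_shape[3] % filter_shape[2] == 0    # first check passes
assert filter_shape[3] % feature_group_count == 0  # 6 % 3 == 0, second passes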
--------------------------------------------------------------------------------
/praxis/layers/quantization/convolutions_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Quantized Praxis convolutional layers."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | from jax import numpy as jnp
22 | import numpy as np
23 | from praxis import base_layer
24 | from praxis import pax_fiddle
25 | from praxis import test_utils
26 | from praxis.layers import convolutions
27 | from praxis.layers.quantization import convolutions as qconvolutions
28 | from praxis.layers.quantization import operations
29 | from praxis.layers.quantization import quantization_hparams
30 |
31 | instantiate = base_layer.instantiate
32 | QuantizationParams = quantization_hparams.QuantizationParams
33 | QuantizationMode = quantization_hparams.QuantizationMode
34 |
35 | PARAMS = base_layer.PARAMS
36 | NON_TRAINABLE = base_layer.NON_TRAINABLE
37 |
38 |
39 | class ConvolutionsTest(test_utils.TestCase):
40 |
41 | def setUp(self):
42 | super().setUp()
43 | np.random.seed(123456)
44 |
45 | @parameterized.parameters(
46 | ((5, 4, 24, 36), (1, 1), (1, 1), False, [2, 16, 36, 72]),
47 | ((2, 4, 16, 8), (2, 2), (1, 1), False, [2, 16, 32, 128]),
48 | ((4, 8, 16, 32), (1, 1), (1, 1), False, [2, 16, 32, 64]),
49 | ((2, 8, 16, 32), (1, 1), (2, 2), False, [2, 16, 32, 64]),
50 | ((2, 8, 16, 32), (2, 2), (2, 2), False, [2, 16, 32, 64]),
51 | ((2, 8, 16, 32), (1, 1), (2, 1), False, [2, 16, 32, 64]),
52 | ((2, 8, 16, 32), (2, 2), (2, 1), False, [2, 16, 32, 64]),
53 | ((2, 8, 16, 32), (1, 1), (2, 2), True, [2, 16, 32, 64]),
54 | ((2, 8, 16, 32), (2, 2), (2, 2), True, [2, 16, 32, 64]),
55 | )
56 | def test_conv2d_layer_same_padding(
57 | self,
58 | filter_shape,
59 | filter_stride,
60 | dilations,
61 | tf_equivalent_padding,
62 | input_shape,
63 | ):
64 | npy_inputs = np.random.normal(0.0, 0.5, input_shape).astype('float32')
65 | inputs = jnp.asarray(npy_inputs)
66 | prng_key = jax.random.PRNGKey(seed=123)
67 |
68 | # Float layer.
69 | p = pax_fiddle.Config(
70 | convolutions.Conv2D,
71 | name='jax_conv2d',
72 | filter_shape=filter_shape,
73 | filter_stride=filter_stride,
74 | dilations=dilations,
75 | tf_equivalent_padding=tf_equivalent_padding,
76 | padding='SAME',
77 | )
78 | conv_layer = instantiate(p)
79 | initial_vars = conv_layer.init(prng_key, inputs)
80 | outputs = conv_layer.apply(initial_vars, inputs)
81 |
82 | # Quantized layer.
83 | qp = pax_fiddle.Config(
84 | qconvolutions.Conv2D,
85 | name='jax_conv2_q',
86 | filter_shape=filter_shape,
87 | filter_stride=filter_stride,
88 | dilations=dilations,
89 | tf_equivalent_padding=tf_equivalent_padding,
90 | padding='SAME',
91 | quantization=QuantizationParams(mode=QuantizationMode.INFERENCE)
92 | )
93 | qconv_layer = instantiate(qp)
94 | qinitial_vars = qconv_layer.init(prng_key, inputs)
95 | # TODO(jianlijianli): Use quantize_weight() once it's implemented.
96 | qweight, qscale, _ = operations.reduce_precision(
97 | initial_vars['params']['w'], [0, 1, 2]
98 | )
99 | qinitial_vars['params']['w'] = qweight
100 | qinitial_vars['params']['w_quantized_scale'] = jnp.squeeze(qscale)
101 | qoutput = qconv_layer.apply(qinitial_vars, inputs)
102 |
103 | self.assertAllClose(qoutput, outputs, rtol=1e-02, atol=1e-02)
104 |
105 |
106 | if __name__ == '__main__':
107 | absltest.main()
108 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/einsum.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
 16 | """A quantized Einsum layer with a weight and an optional bias."""
17 |
18 | from typing import Sequence
19 |
20 | from praxis import base_layer
21 | from praxis import pax_fiddle
22 | from praxis import pytypes
23 | from praxis.layers.quantization import quantizer
24 |
25 | JTensor = pytypes.JTensor
26 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
27 | template_field = base_layer.template_field
28 |
29 |
30 | class Einsum(quantizer.QuantizationLayer):
31 | """Layer that computes an einsum and maybe a bias.
32 |
33 | The fan-in, fan-out and bias dimensions are inferred from the einsum equation.
34 | If bias is used, the fan-out dims must appear at the end of the output tensor.
35 |
36 | Attributes:
37 | eqn: Einsum equation. It should be in the format of input,w->output. E.g.,
38 | '...d,df->...f'.
39 | w_shape: Weight shape.
40 | use_bias: Whether to add a bias.
41 | """
42 | eqn: str = ''
43 | w_shape: Sequence[int] = ()
44 | use_bias: bool = False
45 | _PACK_4BIT_DIM = 0
46 |
47 | def setup(self) -> None:
48 | operands, out = self.eqn.split('->')
49 | x, w = operands.split(',')
50 | assert '.' not in w
51 | fan_in = sorted(w.index(d) for d in (set(x) - set(out)))
52 | fan_out = sorted(w.index(d) for d in (set(out) - set(x)))
53 | w_sharding = self.weight_split_dims_mapping.wt
54 | pc = base_layer.WeightHParams(
55 | shape=self.w_shape,
56 | fan_in_axes=fan_in,
57 | fan_out_axes=fan_out,
58 | mesh_shape=self.mesh_shape,
59 | tensor_split_dims_mapping=w_sharding,
60 | )
61 | out_bias_dims = sorted(out.index(d) for d in (set(out) - set(x)))
62 | bias_shape = [self.w_shape[w.index(out[d])] for d in out_bias_dims]
63 | self.set_up_weights(
64 | weight_name='w',
65 | weight_params=pc,
66 | scale_shape=bias_shape,
67 | )
68 | if self.use_bias:
69 | # Fan-out dims must be at the end of `out`.
70 | assert all(d >= len(out) - len(out_bias_dims) for d in out_bias_dims)
71 | if w_sharding is not None:
72 | b_sharding = [w_sharding[w.index(out[d])] for d in out_bias_dims]
73 | else:
74 | b_sharding = None
75 | pc_bias = base_layer.WeightHParams(
76 | shape=bias_shape,
77 | init=base_layer.WeightInit.Constant(0.0),
78 | mesh_shape=self.mesh_shape,
79 | tensor_split_dims_mapping=b_sharding,
80 | )
81 | self.create_variable('b', pc_bias)
82 |
83 | def __call__(self, inputs: JTensor) -> JTensor:
84 | """Computes the einsum and maybe bias.
85 |
86 | Args:
87 | inputs: A JTensor of shape as described in the equation.
88 |
89 | Returns:
90 | The result of the einsum with maybe a bias added.
91 | """
92 | ret = self.quantized_einsum(
93 | eqn=self.eqn,
94 | x=inputs,
95 | w=self.theta.w,
96 | reshape=[],
97 | )
98 | if self.use_bias:
99 | ret += self.theta.b
100 | return base_layer.maybe_shard(
101 | ret, self.activation_split_dims_mapping.out, self.mesh_axis_names
102 | )
103 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/einsum_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Quantized Praxis Einsum layer."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import numpy as np
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import test_utils
25 | from praxis.layers.quantization import einsum
26 |
27 | instantiate = base_layer.instantiate
28 |
29 |
30 | class EinsumTest(test_utils.TestCase):
31 |
32 | def setUp(self):
33 | super().setUp()
34 | np.random.seed(123456)
35 |
36 | @parameterized.parameters(
37 | ('...d,df->...f', [3, 4, 5], [5, 8], False),
38 | ('...d,dnh->...nh', [3, 5], [5, 2, 8], True),
39 | ('blnh,dnh->bld', [3, 4, 2, 8], [5, 2, 8], False),
40 | ('...nh,hdn->...d', [3, 4, 2, 8], [8, 5, 2], True),
41 | )
42 | def test_einsum(self, eqn, in_shape, w_shape, use_bias):
43 | p = pax_fiddle.Config(
44 | einsum.Einsum,
45 | name='einsum',
46 | eqn=eqn,
47 | w_shape=w_shape,
48 | use_bias=use_bias,
49 | )
50 | layer = instantiate(p)
51 | inputs = np.random.normal(1.0, 0.5, in_shape).astype('float32')
52 | prng_key = jax.random.PRNGKey(seed=123)
53 | initial_vars = layer.init(prng_key, inputs)
54 | qw = np.arange(np.prod(w_shape)).astype('int8').reshape(w_shape)
55 | initial_vars['params']['w'] = qw
56 | if use_bias:
57 | # Make sure the bias is non-zero.
58 | initial_vars['params']['b'] = np.random.normal(
59 | 1.0, 0.5, initial_vars['params']['b'].shape
60 | ).astype('float32')
61 | outputs = layer.apply(initial_vars, inputs)
62 | np_outputs = np.multiply(
63 | np.einsum(eqn, inputs, initial_vars['params']['w']),
64 | initial_vars['params']['w_quantized_scale'],
65 | )
66 | if use_bias:
67 | np_outputs += initial_vars['params']['b']
68 | self.assertAllClose(outputs, np_outputs, atol=1e-6)
69 |
70 |
71 | if __name__ == '__main__':
72 | absltest.main()
73 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/multi_query_attention_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis quantized multi-query attention layers."""
17 |
18 | from absl import logging
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import jax
22 | import numpy as np
23 | from praxis import base_layer
24 | from praxis import pax_fiddle
25 | from praxis import test_utils
26 | from praxis.layers import multi_query_attention as mqa_f
27 | from praxis.layers.quantization import multi_query_attention as mqa_q
28 | from praxis.layers.quantization import quantization_hparams
29 | from praxis.layers.quantization import utils
30 |
31 | instantiate = base_layer.instantiate
32 | QuantizationParams = quantization_hparams.QuantizationParams
33 | QuantizationMode = quantization_hparams.QuantizationMode
34 | QuantizationType = quantization_hparams.QuantizationType
35 | WeightQuantizationParams = quantization_hparams.WeightQuantizationParams
36 |
37 |
 38 | # TODO(b/277357920): add multi-query attention quantization sync test.
39 |
40 | class MultiQueryAttentionTest(test_utils.TestCase):
41 |
42 | def setUp(self):
43 | super().setUp()
44 | np.random.seed(123456)
45 |
46 | @parameterized.named_parameters(
47 | dict(testcase_name='PTQ', quantization_type=QuantizationType.PTQ),
48 | dict(testcase_name='AQT', quantization_type=QuantizationType.AQT),
49 | )
50 | def test_one_headed_projection_shape(self, quantization_type):
51 | test_layer_p = pax_fiddle.Config(
52 | mqa_q.OneHeadedAttentionProjection,
53 | name='mh',
54 | input_dim=16,
55 | output_dim=5,
56 | quantization=QuantizationParams(
57 | quantization_type=quantization_type,
58 | mode=QuantizationMode.TRAINING,
59 | weight_params=quantization_hparams.WeightQuantizationParams(),
60 | ),
61 | )
62 | layer = instantiate(test_layer_p)
63 |
64 | inputs = np.random.normal(1.5, 2.0, [5, 16]).astype(np.float32)
65 |
66 | prng_key = jax.random.PRNGKey(seed=123)
67 | prng_key, init_key = jax.random.split(prng_key)
68 | initial_vars = layer.init(init_key, inputs)
69 | logging.info('initial_vars: %s', initial_vars)
70 |
71 | jax_out = layer.apply(initial_vars, inputs)
72 | self.assertSequenceEqual(jax_out.shape, [5, 5])
73 |
74 | @parameterized.product(
75 | quantization_type=[QuantizationType.PTQ, QuantizationType.AQT],
76 | use_symmetric=[True, False],
77 | precision=[8, 4],
78 | )
79 | def test_one_headed_projection_quantized(
80 | self, quantization_type, use_symmetric, precision
81 | ):
82 | input_dim = 16
83 | output_dim = 3
84 | batch = 5
85 | p_f = pax_fiddle.Config(
86 | mqa_f.OneHeadedAttentionProjection,
87 | name='ohp_f',
88 | input_dim=input_dim,
89 | output_dim=output_dim,
90 | )
91 | p_q = pax_fiddle.Config(
92 | mqa_q.OneHeadedAttentionProjection,
93 | name='ohp_q',
94 | input_dim=input_dim,
95 | output_dim=output_dim,
96 | quantization=QuantizationParams(
 97 |             quantization_type=quantization_type,
98 | mode=QuantizationMode.INFERENCE,
99 | weight_params=WeightQuantizationParams(
100 | use_symmetric=use_symmetric, precision=precision
101 | ),
102 | ),
103 | )
104 |
105 | inputs = np.random.normal(1.5, 2.0, [batch, input_dim]).astype(np.float32)
106 | quantized_weight_range = (-8, 7) if precision == 4 else (-128, 127)
107 | quantized_weight = np.random.randint(
108 | *quantized_weight_range, (input_dim, output_dim), dtype=np.int8
109 | )
110 | w_scale = np.array([0.5, 2.0, 3.3], dtype=np.float32)
111 | if use_symmetric:
112 | weight_rescaled = quantized_weight * w_scale
113 | else:
114 | w_zp = np.array([-10.0, 10.0, -2.5], dtype=np.float32)
115 | weight_rescaled = quantized_weight * w_scale - w_zp
116 |
117 | ohp_f = instantiate(p_f)
118 | ohp_q = instantiate(p_q)
119 |
120 | with base_layer.JaxContext.new_context():
121 | prng_key = jax.random.PRNGKey(seed=123)
122 | initial_vars_f = ohp_f.init(prng_key, inputs)
123 | initial_vars_q = ohp_q.init(prng_key, inputs)
124 | initial_vars_f['params']['w'] = weight_rescaled
125 | initial_vars_q['params']['w'] = (
126 | utils.pack_4bit(quantized_weight, pack_dim=0)
127 | if precision == 4
128 | else quantized_weight
129 | )
130 | initial_vars_q['params']['w_quantized_scale'] = w_scale
131 | if not use_symmetric:
132 | initial_vars_q['params']['w_quantized_zp'] = w_zp
133 | outputs_f = ohp_f.apply(initial_vars_f, inputs)
134 | outputs_q = ohp_q.apply(initial_vars_q, inputs)
135 | self.assertAllClose(outputs_f, outputs_q)
136 |
137 |
138 | if __name__ == '__main__':
139 | absltest.main()
140 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/ngrammer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Quantized Ngrammer layers."""
17 |
18 | from jax import numpy as jnp
19 | from praxis import base_layer
20 | from praxis import pax_fiddle
21 | from praxis.layers import ngrammer
22 | from praxis.layers import normalizations
23 | from praxis.layers.quantization import embedding_softmax as quantized_embedding_softmax
24 | from praxis.layers.quantization import quantization_hparams
25 |
26 | QuantizationParams = quantization_hparams.QuantizationParams
27 | WeightHParams = base_layer.WeightHParams
28 | instance_field = base_layer.instance_field
29 |
30 |
31 | class Ngrammer(ngrammer.Ngrammer):
32 | """Quantized Ngrammer."""
33 |
34 | quantization: QuantizationParams = instance_field(QuantizationParams)
35 |
36 | def setup(self) -> None:
37 | """Constructs an instance which looks up ngrams."""
38 |
39 | if self.concat_ngrams:
 40 |       # The ngram_emb_dim must not exceed dim_per_head.
41 | assert self.ngram_emb_dim <= self.dim_per_head
42 | else:
43 | # If not concatenating ngram embeddings, check the dims are compatible.
44 | assert self.ngram_emb_dim == self.dim_per_head
45 |
46 | # Create a separate layer norm per head for embedding normalization.
47 | # Create a separate layer norm per head for ngram embedding normalization.
48 | emb_layer_norm_p = []
49 | ngram_emb_layer_norm_p = []
50 | ngram_emb_table_p = []
51 | for i in range(self.num_heads):
52 | layer_norm_p = pax_fiddle.Config(normalizations.LayerNorm).clone()
53 | layer_norm_p.dim = self.dim_per_head
54 | layer_norm_p.name = f'layer_norm_{i}'
55 |
56 | emb_layer_norm_p.append(layer_norm_p)
57 | ngram_layer_norm_p = pax_fiddle.Config(normalizations.LayerNorm).clone()
58 | ngram_layer_norm_p.dim = self.ngram_emb_dim
59 | ngram_emb_layer_norm_p.append(ngram_layer_norm_p)
60 |
61 | # Create embedding table for ngram lookup.
62 | embedding_p = pax_fiddle.Config(
63 | quantized_embedding_softmax.Embedding,
64 | quantization=self.quantization,
65 | ).clone()
66 | embedding_p.name = f'embedding_{i}'
67 | embedding_p.num_classes = self.ngram_vocab_size
68 | embedding_p.input_dims = self.ngram_emb_dim
69 | embedding_p.params_init = self.params_init
70 | # Copy sharding annotations.
71 | embedding_p.weight_split_dims_mapping = self.weight_split_dims_mapping
72 | ngram_emb_table_p.append(embedding_p)
73 |
74 | self.create_children('emb_layer_norm', emb_layer_norm_p)
75 | self.create_children('ngram_layer_norm', ngram_emb_layer_norm_p)
76 | self.create_children('ngram_table', ngram_emb_table_p)
77 |
78 |
79 | class VQNgrammer(ngrammer.VQNgrammer):
80 | """Quantized VQNgrammer."""
81 |
82 | quantization: QuantizationParams = instance_field(QuantizationParams)
83 |
84 | def setup(self) -> None:
85 | """Constructs a VQ layer and an N-grammer layer."""
86 |
87 | if self.concat_ngrams:
 88 |       # The ngram_emb_dim must not exceed dim_per_head.
89 | assert self.ngram_emb_dim <= self.dim_per_head
90 | else:
91 | # If not concatenating ngram embeddings, check the dims are compatible.
92 | assert self.ngram_emb_dim == self.dim_per_head
93 |
94 | # Create VQ layer.
95 | vq_layer_p = pax_fiddle.Config(
96 | ngrammer.VectorQuantization,
97 | num_clusters=self.num_clusters,
98 | num_heads=self.num_heads,
99 | dim_per_head=self.dim_per_head,
100 | decay=self.decay,
101 | epsilon=self.epsilon,
102 | params_init=self.params_init,
103 | )
104 | self.create_child('vq_layer', vq_layer_p)
105 |
106 | # Create the input id to cluster id cache.
107 | if self.unigram_vocab_size:
108 | input_id_to_cluster_id_cache = WeightHParams(
109 | shape=[self.unigram_vocab_size, self.num_heads],
110 | dtype=jnp.int32,
111 | init=base_layer.WeightInit.Constant(0),
112 | )
113 | self.create_variable(
114 | 'input_id_to_cluster_id_cache',
115 | input_id_to_cluster_id_cache,
116 | trainable=False,
117 | )
118 |
119 | # Create N-gram lookup layer.
120 | ngram_layer_p = pax_fiddle.Config(
121 | Ngrammer, # Quantized Ngrammer.
122 | quantization=self.quantization,
123 | ngram_vocab_size=self.ngram_vocab_size,
124 | unigram_vocab_size=self.num_clusters,
125 | ngram_emb_dim=self.ngram_emb_dim,
126 | concat_ngrams=self.concat_ngrams,
127 | num_heads=self.num_heads,
128 | dim_per_head=self.dim_per_head,
129 | params_init=self.params_init,
130 | weight_split_dims_mapping=self.weight_split_dims_mapping,
131 | )
132 | self.create_child('ngram_layer', ngram_layer_p)
133 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for quantization optimizations."""
17 |
18 | from absl.testing import absltest
19 | from jax import numpy as jnp
20 | import numpy as np
21 | from praxis import test_utils
22 | from praxis.layers.quantization import optimization
23 |
24 |
25 | class OptimizationTest(test_utils.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | np.random.seed(123456)
30 |
31 | def test_optimization_int8(self):
32 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]])
33 | bound = jnp.array([[3.0], [4.0]])
34 | ret = optimization.get_best_bound(t, bound, -128.0, 127.0)
35 | expected = jnp.array([[3.0], [4.0]])
36 | self.assertArraysEqual(ret, expected)
37 |
38 | def test_optimization_int4(self):
39 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]])
40 | bound = jnp.array([[3.0], [4.0]])
41 | ret = optimization.get_best_bound(t, bound, -8.0, 7.0)
42 | expected = jnp.array([[3.0], [4.0]])
43 | self.assertArraysEqual(ret, expected)
44 |
45 | def test_optimization_int4_p2(self):
46 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]])
47 | bound = jnp.array([[3.0], [4.0]])
48 | ret = optimization.get_best_bound(t, bound, -8.0, 7.0, 2.0)
49 | expected = jnp.array([[2.85], [3.8]])
50 | self.assertArraysEqual(ret, expected)
51 |
52 | def test_optimization_int4_per_channel(self):
53 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]])
54 | bound = jnp.max(t, axis=0, keepdims=True)
55 | best_bound_per_channel = optimization.get_best_bound(
56 | t, bound, -8.0, 7.0, 2.0, per_channel=True
57 | )
58 | expected_best_bound_per_channel = jnp.array([[4.0, 1.9, 3.0]])
59 | self.assertArraysEqual(
60 | best_bound_per_channel, expected_best_bound_per_channel
61 | )
62 | best_bound_per_tensor = optimization.get_best_bound(
63 | t, bound, -8.0, 7.0, 2.0
64 | )
65 | expected_best_bound_per_tensor = jnp.array([[4.0, 2.0, 3.0]])
66 | self.assertArraysEqual(
67 | best_bound_per_tensor, expected_best_bound_per_tensor
68 | )
69 |
70 |
71 | if __name__ == '__main__':
72 | absltest.main()
73 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/searchable.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Layers for Mixed-Precision Quantization Search."""
17 |
18 | import dataclasses
19 | from typing import Any, Sequence
20 |
21 | from praxis import base_layer
22 | from praxis import pax_fiddle
23 | from praxis.layers.quantization import attentions
24 | from praxis.layers.quantization import automl_select
25 | from praxis.layers.quantization import linears
26 | from praxis.layers.quantization import quantization_hparams
27 | from praxis.layers.quantization import quantizer
28 |
29 | QuantizationParams = quantization_hparams.QuantizationParams
30 | template_field = pax_fiddle.template_field
31 |
32 | _SUPPORTED_PRECISIONS = [2, 4, 8]
33 |
34 |
35 | def _build_quantization_templates(
36 | precisions: Sequence[int], base_qhparams: QuantizationParams
37 | ):
38 | """Build templates for precision search from the base quantization hparams."""
39 | act_qtzr_tmpls = []
40 | weight_qtzr_tmpls = []
41 |
42 | if not precisions:
43 | raise AttributeError('Must specify at least one precision.')
44 |
45 | for precision in precisions: # pylint: disable=not-an-iterable
46 | if precision not in _SUPPORTED_PRECISIONS:
47 | raise AttributeError('Precision %d is not supported.' % precision)
48 | # An unset act precision indicates activations are not quantized; skip their templates.
49 | if base_qhparams.act_params and base_qhparams.act_params.precision:
50 | act_qtzr_tmpl = quantizer.create_tensor_quantizer(
51 | f'act_quantizer_int{str(precision)}', base_qhparams.act_params
52 | )
53 | act_qtzr_tmpl.precision = precision
54 | act_qtzr_tmpls.append(act_qtzr_tmpl)
55 | weight_qtzr_tmpl = quantizer.create_tensor_quantizer(
56 | f'weight_quantizer_int{str(precision)}', base_qhparams.weight_params
57 | )
58 | weight_qtzr_tmpl.precision = precision
59 | weight_qtzr_tmpls.append(weight_qtzr_tmpl)
60 | return act_qtzr_tmpls, weight_qtzr_tmpls
61 |
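# A minimal sketch of the resulting templates, assuming AQT-style hparams
# with activation quantization enabled (precision set on act_params):
#
#   >>> qhparams = quantization_hparams.QuantizationParams(
#   ...     act_params=quantization_hparams.ActQuantizationParams(precision=8),
#   ...     weight_params=quantization_hparams.WeightQuantizationParams())
#   >>> acts, weights = _build_quantization_templates([4, 8], qhparams)
#   >>> [w.precision for w in weights]
#   [4, 8]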
62 |
63 | class SearchableQuantizedLayer(base_layer.BaseLayer):
64 | """A template to make the precision of quantizers searchable."""
65 | precisions: Sequence[int] = dataclasses.field(default_factory=list)
66 |
67 | def create_tensor_quantizers(self):
68 | act_qtzr_tmpls, weight_qtzr_tmpls = _build_quantization_templates(
69 | self.precisions, self.quantization
70 | )
71 | # Use AutoMLSelect as quantizer
72 | if act_qtzr_tmpls:
73 | self.create_child(
74 | 'act_quantizer',
75 | pax_fiddle.Config(
76 | automl_select.AutoMLSelect,
77 | search_options_tpl=act_qtzr_tmpls,
78 | ),
79 | )
80 | else:
81 | self.create_child(
82 | 'act_quantizer',
83 | quantizer.create_tensor_quantizer(
84 | 'aqt_quantizer', self.quantization.act_params
85 | ),
86 | )
87 | self.create_child(
88 | 'weight_quantizer',
89 | pax_fiddle.Config(
90 | automl_select.AutoMLSelect,
91 | search_options_tpl=weight_qtzr_tmpls,
92 | ),
93 | )
94 |
95 | # Delegates to super() to fix the return-type mismatch under multiple inheritance.
96 | def replace(self, **kwargs: Any):
97 | return super().replace(**kwargs)
98 |
99 |
100 | class SearchableLinear(SearchableQuantizedLayer, linears.Linear):
101 | """Quantized Linear layer with searchable precision."""
102 |
103 | pass
104 |
105 |
106 | class SearchableAttentionProjection(
107 | SearchableQuantizedLayer, attentions.AttentionProjection
108 | ):
109 | """Layer that computes multi-head projection with searchable precision."""
110 |
111 | pass
112 |
113 |
114 | class SearchableCombinedQKVProjectionLayer(
115 | SearchableQuantizedLayer, attentions.CombinedQKVProjectionLayer
116 | ):
117 | """Layer that computes QKV projection with searchable precision."""
118 | pass
119 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/searchable_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for praxis.layers.quantization.searchable."""
17 |
18 | from absl.testing import absltest
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from praxis import base_hyperparams
24 | from praxis import base_layer
25 | from praxis import pax_fiddle
26 | from praxis import test_utils
27 | from praxis.layers.quantization import quantization_hparams
28 | from praxis.layers.quantization import searchable
29 |
30 | jax.config.update('jax_threefry_partitionable', False)
31 |
32 | instantiate = base_hyperparams.instantiate
33 | QuantizationType = quantization_hparams.QuantizationType
34 |
35 |
36 | def _run_option(model, model_vars, inputs, act_decision, weight_decision):
37 | if 'act_quantizer' in model_vars['non_trainable']:
38 | model_vars['non_trainable']['act_quantizer']['decision'] = act_decision
39 | model_vars['non_trainable']['weight_quantizer']['decision'] = weight_decision
40 | return model.apply(model_vars, inputs)
41 |
42 |
43 | class SearchableTest(test_utils.TestCase):
44 |
45 | def setUp(self):
46 | super().setUp()
47 | self.quantization_tpl = quantization_hparams.QuantizationParams(
48 | quantization_type=quantization_hparams.QuantizationType.AQT,
49 | mode=quantization_hparams.QuantizationMode.TRAINING,
50 | act_params=quantization_hparams.ActQuantizationParams(),
51 | weight_params=quantization_hparams.WeightQuantizationParams(),
52 | )
53 |
54 | def _test_common(self, p, x):
55 | m = instantiate(p)
56 |
57 | with base_layer.JaxContext.new_context():
58 | m_vars = m.init(jax.random.PRNGKey(0), x)
59 | a4w4 = _run_option(m, m_vars, x, 0, 0)
60 | a4w8 = _run_option(m, m_vars, x, 0, 1)
61 | a8w4 = _run_option(m, m_vars, x, 1, 0)
62 | a8w8 = _run_option(m, m_vars, x, 1, 1)
63 |
64 | self.assertAllClose(a4w4, a4w8, rtol=0.05, atol=0.05)
65 | self.assertAllClose(a4w4, a8w4, rtol=0.05, atol=0.05)
66 | self.assertAllClose(a4w8, a8w8, rtol=0.05, atol=0.05)
67 |
68 | def test_searchable_linear(self):
69 | p = pax_fiddle.Config(
70 | searchable.SearchableLinear,
71 | quantization=self.quantization_tpl,
72 | input_dims=3,
73 | output_dims=1,
74 | precisions=[4, 8],
75 | )
76 | self._test_common(p, jnp.ones((1, 3), dtype=p.dtype))
77 |
78 | def test_searchable_linear_without_act_quant(self):
79 | self.quantization_tpl.act_params.precision = None
80 | p = pax_fiddle.Config(
81 | searchable.SearchableLinear,
82 | quantization=self.quantization_tpl,
83 | input_dims=3,
84 | output_dims=1,
85 | precisions=[4, 8],
86 | )
87 | m = instantiate(p)
88 | m_vars = m.init(jax.random.PRNGKey(0), jnp.ones((1, 3), dtype=p.dtype))
89 | self.assertNotIn('act_quantizer', m_vars['non_trainable'])
90 |
91 | def test_attention_projections(self):
92 | p = pax_fiddle.Config(
93 | searchable.SearchableAttentionProjection,
94 | quantization=self.quantization_tpl,
95 | input_dim=8,
96 | num_heads=2,
97 | dim_per_head=3,
98 | is_output_projection=True,
99 | precisions=[4, 8],
100 | )
101 | self._test_common(p, jnp.ones((4, 2, 3), dtype=p.dtype))
102 |
103 | def test_combined_qkv_projections(self):
104 | p = pax_fiddle.Config(
105 | searchable.SearchableCombinedQKVProjectionLayer,
106 | quantization=self.quantization_tpl,
107 | input_dim=8,
108 | num_heads=3,
109 | dim_per_head=2,
110 | precisions=[4, 8],
111 | )
112 | self._test_common(p, jnp.ones((4, 8), dtype=p.dtype))
113 |
114 |
115 | if __name__ == '__main__':
116 | absltest.main()
117 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/sparsity/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Description:
17 | # Sparsity related layers. The public API is defined in __init__.py.
18 |
19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test")
21 |
22 | package(default_visibility = JAX_VISIBILITY)
23 |
24 | pytype_strict_library(
25 | name = "layers",
26 | srcs = ["__init__.py"],
27 | )
28 |
29 | pytype_strict_library(
30 | name = "sparsity_hparams",
31 | srcs = ["sparsity_hparams.py"],
32 | deps = [":sparsity_modes"],
33 | )
34 |
35 | pytype_strict_library(
36 | name = "sparsity",
37 | srcs = ["sparsity.py"],
38 | deps = [
39 | ":sparsity_hparams",
40 | # Implicit jax dependency.
41 | ],
42 | )
43 |
44 | pytype_strict_test(
45 | name = "sparsity_test",
46 | srcs = ["sparsity_test.py"],
47 | deps = [
48 | ":sparsity",
49 | ":sparsity_hparams",
50 | # Implicit absl.testing.absltest dependency.
51 | # Implicit absl.testing.parameterized dependency.
52 | # Implicit upb python proto dependency.
53 | # Implicit jax dependency.
54 | # Implicit numpy dependency.
55 | ],
56 | )
57 |
58 | pytype_strict_library(
59 | name = "sparsity_modes",
60 | srcs = ["sparsity_modes.py"],
61 | deps = [
62 | # Implicit jax dependency.
63 | "//praxis:pytypes",
64 | ],
65 | )
66 |
67 | pytype_strict_test(
68 | name = "linears_test",
69 | srcs = ["linears_test.py"],
70 | deps = [
71 | ":sparsity_hparams",
72 | ":sparsity_modes",
73 | # Implicit absl.logging dependency.
74 | # Implicit absl.testing.absltest dependency.
75 | # Implicit absl.testing.parameterized dependency.
76 | # Implicit upb python proto dependency.
77 | # Implicit jax dependency.
78 | # Implicit numpy dependency.
79 | "//praxis:base_layer",
80 | "//praxis:pax_fiddle",
81 | "//praxis:test_utils",
82 | "//praxis/layers:linears",
83 | "//praxis/layers/quantization:linears",
84 | ],
85 | )
86 |
87 | pytype_strict_test(
88 | name = "attentions_test",
89 | srcs = ["attentions_test.py"],
90 | deps = [
91 | ":sparsity_hparams",
92 | ":sparsity_modes",
93 | # Implicit absl.testing.absltest dependency.
94 | # Implicit absl.testing.parameterized dependency.
95 | # Implicit upb python proto dependency.
96 | # Implicit jax dependency.
97 | # Implicit numpy dependency.
98 | "//praxis:base_layer",
99 | "//praxis:pax_fiddle",
100 | "//praxis:py_utils",
101 | "//praxis:test_utils",
102 | "//praxis/layers:attentions",
103 | "//praxis/layers/quantization:attentions",
104 | ],
105 | )
106 |
107 | pytype_strict_library(
108 | name = "sparsifier",
109 | srcs = ["sparsifier.py"],
110 | deps = [
111 | ":sparsity",
112 | ":sparsity_hparams",
113 | ":sparsity_modes",
114 | # Implicit jax dependency.
115 | "//praxis:base_layer",
116 | "//praxis:pytypes",
117 | ],
118 | )
119 |
120 | pytype_strict_test(
121 | name = "sparsifier_test",
122 | srcs = ["sparsifier_test.py"],
123 | deps = [
124 | ":sparsifier",
125 | ":sparsity_hparams",
126 | ":sparsity_modes",
127 | # Implicit absl.testing.absltest dependency.
128 | # Implicit absl.testing.parameterized dependency.
129 | # Implicit upb python proto dependency.
130 | # Implicit jax dependency.
131 | # Implicit numpy dependency.
132 | "//praxis:base_layer",
133 | "//praxis:pax_fiddle",
134 | "//praxis:test_utils",
135 | "//praxis/layers:linears",
136 | "//praxis/layers/quantization:layers",
137 | ],
138 | )
139 |
--------------------------------------------------------------------------------
/praxis/layers/quantization/sparsity/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Exposes the public layer functionalities."""
17 |
--------------------------------------------------------------------------------
/praxis/layers/quantizer_objectives.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A set of objective functions for building quantizers (e.g VQ-VAE)."""
17 |
18 | import jax
19 | import jax.numpy as jnp
20 | from praxis import pytypes
21 |
22 | JTensor = pytypes.JTensor
23 |
24 |
25 | def scatter_nd(indices, updates, shape):
26 | zeros = jnp.zeros(shape, updates.dtype)
27 | key = tuple(jnp.moveaxis(indices, -1, 0))
28 | return zeros.at[key].add(updates)
29 |
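# A doctest-style sketch of the scatter-add semantics above: duplicate
# indices accumulate, and entries that are never indexed stay zero.
#
#   >>> scatter_nd(jnp.array([[0], [2], [0]]), jnp.array([1., 2., 3.]), [4])
#   Array([4., 0., 2., 0.], dtype=float32)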
30 |
31 | def batch_pplx_entropy_from_codes(
32 | codes: JTensor,
33 | num_classes: int,
34 | *,
35 | paddings: JTensor | None = None,
36 | data_parallel_axis: str | None = None
37 | ):
38 | """Calculates pplx and entropy from probs within batch.
39 |
40 | Args:
41 | codes: [..., num_groups] with values in [0, num_classes).
42 | num_classes: A Python int.
43 | paddings: [...], 0/1 value tensor.
44 | data_parallel_axis: If set, statistics are summed across replicas over
45 | this axis to produce the combined result.
46 |
47 | Returns:
48 | A tuple of 3 tensors:
49 | - pplx: scalar, avg_across_groups(avg(non-padded samples of a group))
50 | - entropy: scalar, avg_across_groups(avg(non-padded samples of a group))
51 | - histogram: [g, c], code word counts.
52 | """
53 | rank = len(codes.shape)
54 | # len() always returns an int, so only the minimum rank needs checking.
55 | assert rank >= 2
56 |
57 | is_in_pmap = data_parallel_axis is not None
58 |
59 | codes = codes.astype(jnp.int32)
60 | if paddings is None:
61 | paddings = jnp.zeros_like(codes[..., 0], dtype=codes.dtype)
62 | else:
63 | paddings = paddings.astype(codes.dtype)
64 |
65 | num_groups = codes.shape[-1]
66 | # [?, g]
67 | codes = jnp.reshape(codes, [-1, num_groups])
68 | paddings = jnp.reshape(paddings, [-1, 1])
69 | paddings = jnp.broadcast_to(paddings, codes.shape)
70 |
71 | # [g]
72 | indices_offset = jnp.arange(
73 | start=0, stop=num_groups * num_classes, step=num_classes, dtype=jnp.int32)
74 | # [?, g]
75 | indices = codes + indices_offset
76 | # [? * g, 1]
77 | indices = jnp.reshape(indices, [-1])[:, jnp.newaxis]
78 |
79 | # [? * g]
80 | mask = (1.0 - paddings).astype(jnp.float32)
81 | values = jnp.reshape(mask, [-1])
82 |
83 | # [g * c]
84 | histogram = scatter_nd(indices, values, [num_groups * num_classes])
85 | normalizer = jnp.sum(values) / num_groups
86 |
87 | if is_in_pmap:
88 | histogram = jax.lax.psum(histogram, data_parallel_axis)
89 | normalizer = jax.lax.psum(normalizer, data_parallel_axis)
90 | # [g, c]
91 | histogram = jnp.reshape(histogram, [num_groups, num_classes])
92 |
93 | # [g, c]
94 | probs = histogram / jnp.maximum(normalizer, 1.0)
95 | log_probs = jnp.log(jnp.maximum(1.0e-30, probs))
96 | # [g]
97 | sum_plogp = jnp.sum(log_probs * probs, -1)
98 | pplx = jnp.mean(jnp.exp(-sum_plogp))
99 | entropy = jnp.log(pplx)
100 | return pplx, entropy, histogram
101 |
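# A worked sketch of the statistics above (assuming the module imports):
# codes drawn uniformly from k distinct values yield a flat histogram, so
# pplx == k and entropy == ln(k).
#
#   >>> codes = jnp.array([[0], [1], [2], [3]])  # 4 samples, 1 group.
#   >>> pplx, entropy, hist = batch_pplx_entropy_from_codes(codes, 5)
#   >>> pplx, entropy  # ~4.0 and ~1.386 (== ln(4)).
#   >>> hist           # [[1., 1., 1., 1., 0.]]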
102 |
103 | def batch_codebook_coverage(
104 | codes: JTensor,
105 | num_classes: int,
106 | *,
107 | paddings: JTensor,
108 | data_parallel_axis: str | None = None
109 | ):
110 | """Computes codebook coverage within a batch.
111 |
112 | Args:
113 | codes: [..., num_groups], values are in [0, num_classes).
114 | num_classes: A Python int.
115 | paddings: [...], 0/1 value tensor.
116 | data_parallel_axis: If set, psum() is applied over this axis.
117 |
118 | Returns:
119 | A scalar JTensor, avg coverage across groups.
120 | """
121 | # [num_groups, num_classes]
122 | _, _, histogram = batch_pplx_entropy_from_codes(
123 | codes,
124 | num_classes,
125 | paddings=paddings,
126 | data_parallel_axis=data_parallel_axis)
127 | onehot = jnp.greater(histogram, 0).astype(jnp.float32)
128 | avg_num_covered_words = jnp.mean(jnp.sum(onehot, -1))
129 | return avg_num_covered_words / num_classes
130 |
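# Continuing the sketch above: 4 distinct codes out of 5 classes with no
# padding leave 4 nonzero histogram entries, giving coverage 4 / 5 == 0.8.
#
#   >>> batch_codebook_coverage(
#   ...     jnp.array([[0], [1], [2], [3]]), 5, paddings=jnp.zeros([4]))
#   # ~0.8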
--------------------------------------------------------------------------------
/praxis/layers/quantizer_objectives_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for quantizer_objectives."""
17 |
18 | from absl.testing import absltest
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from praxis import test_utils
22 | from praxis.layers import quantizer_objectives
23 |
24 |
25 | class CodebookObjectivesTest(test_utils.TestCase):
26 | codes = np.array([
27 | [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [0, 0]],
28 | [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [0, 0]],
29 | ])
30 |
31 | paddings = np.array([[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1]])
32 | entropy = 1.609
33 | pplx = 5.000
34 | num_classes = 11
35 |
36 | def test_batch_pplx_entropy_from_codes(self):
37 | pplx, entropy, _ = quantizer_objectives.batch_pplx_entropy_from_codes(
38 | codes=jnp.array(self.codes),
39 | num_classes=self.num_classes,
40 | paddings=jnp.array(self.paddings),
41 | )
42 |
43 | self.assertAlmostEqual(
44 | np.array(pplx),
45 | np.array(self.pplx),
46 | delta=1e-3,
47 | msg='PPLX is not the same',
48 | )
49 | self.assertAlmostEqual(
50 | np.array(entropy),
51 | np.array(self.entropy),
52 | delta=1e-3,
53 | msg='Entropy is not the same',
54 | )
55 |
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/praxis/layers/quantizer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for quantizer."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 |
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy as np
24 | from praxis import base_layer
25 | from praxis import pax_fiddle
26 | from praxis import test_utils
27 | from praxis.layers import quantizer
28 |
29 | instantiate = base_layer.instantiate
30 |
31 |
32 | class VectorQuantizerTest(test_utils.TestCase):
33 |
34 | W = np.array([[0.116230249, 0.0104732513, -0.409445882, -0.153374314],
35 | [-0.0672334433, -0.430877686, -0.280010223, 0.394074917],
36 | [-0.360892653, -0.153173685, -0.45321393, -0.176380157],
37 | [0.406187773, 0.304340839, 0.439772606, 0.368542314]])
38 |
39 | def _GetParams(self, num_classes, latent_dim):
40 | return pax_fiddle.Config(
41 | quantizer.VectorQuantizer,
42 | name='vq',
43 | normalize_latent_vector=True,
44 | normalize_codebook=True,
45 | num_latent_classes=num_classes,
46 | latent_dim=latent_dim,
47 | beta=0.1,
48 | )
49 |
50 | def testBase(self):
51 | num_classes = 4
52 | latent_dim = 4
53 |
54 | b, t = 2, 4
55 | np.random.seed(2021)
56 | z = np.random.rand(b, t, latent_dim).astype(np.float32)
57 | paddings = np.zeros((b, t)).astype(np.float32)
58 |
59 | vq_p = self._GetParams(num_classes, latent_dim)
60 | vq = instantiate(vq_p)
61 | vq_theta = vq.init(jax.random.PRNGKey(1), z, paddings)
62 | vq_theta['params']['w'] = jnp.expand_dims(self.W, 1)
63 | out = vq.apply(vq_theta, z, paddings)
64 |
65 | with self.subTest('test_shape'):
66 | self.assertEqual((b, t, latent_dim), out.z_q.shape)
67 | self.assertEqual((b, t, 1), out.z_codes.shape)
68 | self.assertEqual((b, t, 1, num_classes), out.z_onehot.shape)
69 | with self.subTest('test_z_q'):
70 | self.assertAllClose(15.861525, np.sum(out.z_q))
71 | with self.subTest('test_z_codes'):
72 | self.assertEqual(24, np.sum(out.z_codes))
73 | with self.subTest('test_codebook_coverage'):
74 | self.assertEqual(0.25, np.sum(out.codebook_coverage))
75 | with self.subTest('test_pplx'):
76 | self.assertEqual(1.0, out.pplx)
77 | with self.subTest('test_entropy'):
78 | self.assertAllClose(0., out.entropy)
79 |
80 | def testNCodebooks(self):
81 | num_classes = 4
82 | latent_dim = 4
83 | num_groups = 2
84 |
85 | b, t = 2, 4
86 | np.random.seed(2021)
87 | z = np.random.rand(b, t, latent_dim).astype(np.float32)
88 | paddings = np.zeros((b, t)).astype(np.float32)
89 |
90 | vq_p = self._GetParams(num_classes, latent_dim)
91 | vq_p.num_groups = num_groups
92 | vq = instantiate(vq_p)
93 | vq_theta = vq.init(jax.random.PRNGKey(1), z, paddings)
94 | out = vq.apply(vq_theta, z, paddings)
95 |
96 | with self.subTest('test_shape'):
97 | self.assertEqual((b, t, latent_dim), out.z_q.shape)
98 | self.assertEqual((b, t, num_groups), out.z_codes.shape)
99 | self.assertEqual((b, t, num_groups, num_classes), out.z_onehot.shape)
100 |
101 |
102 | class RandomVectorQuantizerTest(test_utils.TestCase):
103 |
104 | @parameterized.parameters(
105 | (2, 4, 20, 16, 4, 1),
106 | (3, 4, 20, 32, 4, 2),
107 | (4, 7, 16, 256, 20, 8),
108 | )
109 | def testBase(self, b, t, latent_dim, projection_dim, num_classes, num_groups):
110 | np.random.seed(2022)
111 | z = np.random.rand(b, t, latent_dim).astype(np.float32)
112 | paddings = np.zeros((b, t)).astype(np.float32)
113 |
114 | rq = pax_fiddle.Config(
115 | quantizer.RandomVectorQuantizer,
116 | name='vq',
117 | num_latent_classes=num_classes,
118 | num_groups=num_groups,
119 | latent_dim=latent_dim,
120 | projection_dim=projection_dim,
121 | )
122 | rq = instantiate(rq)
123 | rq_theta = rq.init(jax.random.PRNGKey(1), z, paddings)
124 | out = rq.apply(rq_theta, z, paddings)
125 | self.assertEqual((b, t, projection_dim), out.z_q.shape)
126 | self.assertEqual((b, t, num_groups), out.z_codes.shape)
127 | self.assertEqual((b, t, num_groups, num_classes), out.z_onehot.shape)
128 |
129 | if __name__ == '__main__':
130 | absltest.main()
131 |
--------------------------------------------------------------------------------
/praxis/layers/searchable.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Layers for Platform-aware AutoML Search."""
17 |
18 | from typing import Any, Sequence
19 |
20 | from flax import linen as nn
21 | import jax.numpy as jnp
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import pytypes
25 |
26 | JTensor = pytypes.JTensor
27 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
28 | NestedJTensor = pytypes.NestedJTensor
29 | template_field = pax_fiddle.template_field
30 | WeightInit = base_layer.WeightInit
31 | WeightHParams = base_layer.WeightHParams
32 |
35 |
36 | class AutoMLSelect(base_layer.BaseLayer):
37 | """AutoMLSelect layer to switch branches according to AutoML decisions.
38 |
39 | Attributes:
40 | search_options_tpl: A sequence of layer templates, each of which represents
41 | a branch. The layers must have the same input and output shapes.
42 | """
43 |
44 | search_options_tpl: Sequence[LayerTpl] | None = template_field(None)
45 |
46 | def setup(self) -> None:
47 | if not self.search_options_tpl:
48 | raise AttributeError('Must set at least one search option.')
49 | decision = WeightHParams(
50 | shape=[],
51 | init=WeightInit.Constant(0),
52 | dtype=jnp.uint8,
53 | mesh_shape=self.mesh_shape
54 | )
55 | self.create_children('search_options', self.search_options_tpl)
56 | self.create_variable('decision', decision, trainable=False)
57 |
58 | def __call__(
59 | self,
60 | x: JTensor,
61 | *args: Any,
62 | **kwargs: Any,
63 | ) -> NestedJTensor:
64 | def branch_fn(i):
65 | return lambda mdl, x: mdl.search_options[i](x)
66 |
67 | branches = [branch_fn(i) for i in range(len(self.search_options))]
68 |
69 | if self.is_mutable_collection('params'):
70 | for branch in branches:
71 | _ = branch(self, x)
72 |
73 | return nn.switch(self.get_var('decision'), branches, self, x)
74 |
--------------------------------------------------------------------------------
/praxis/layers/searchable_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for praxis.layers.searchable."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | import jax.numpy as jnp
21 | from praxis import base_hyperparams
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import test_utils
25 | from praxis.layers import linears
26 | from praxis.layers import searchable
27 |
28 | instantiate = base_hyperparams.instantiate
29 | weight_init = base_layer.WeightInit
30 |
31 |
32 | class AutomlSelectTest(test_utils.TestCase):
33 |
34 | def test_automl_select(self):
35 | p = pax_fiddle.Config(searchable.AutoMLSelect)
36 | p.search_options_tpl = [
37 | pax_fiddle.Config(
38 | linears.Linear,
39 | name='jax_ffn0',
40 | weight_init=weight_init.Constant(0.0),
41 | input_dims=1,
42 | output_dims=1),
43 | pax_fiddle.Config(
44 | linears.Linear,
45 | name='jax_ffn1',
46 | weight_init=weight_init.Constant(1.0),
47 | input_dims=1,
48 | output_dims=1),
49 | ]
50 |
51 | m = instantiate(p)
52 | x = jnp.ones(1, dtype=jnp.float32)
53 | m_vars = m.init(jax.random.PRNGKey(0), x)
54 | m_vars['non_trainable']['decision'] = 0
55 | x1 = m.apply(m_vars, x)
56 | m_vars['non_trainable']['decision'] = 1
57 | x2 = m.apply(m_vars, x)
58 | self.assertAllClose(x1, jnp.zeros(1, dtype=jnp.float32))
59 | self.assertAllClose(x2, jnp.ones(1, dtype=jnp.float32))
60 |
61 |
62 | if __name__ == '__main__':
63 | absltest.main()
64 |
--------------------------------------------------------------------------------
/praxis/layers/sequential.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A sequential stack of layers for Pax."""
17 |
18 | from typing import Any, Callable, Sequence
19 |
20 | from praxis import base_layer
21 |
22 |
23 | class Sequential(base_layer.BaseLayer):
24 | """Applies a linear chain of Modules."""
25 | layers: Sequence[Callable[..., Any]] | None = None
26 |
27 | def __call__(self, *args, **kwargs):
28 | if not self.layers:
29 | raise ValueError(f'Empty Sequential module {self.name}.')
30 |
31 | outputs = self.layers[0](*args, **kwargs)
32 | for layer in self.layers[1:]:
33 | if isinstance(outputs, tuple):
34 | outputs = layer(*outputs)
35 | elif isinstance(outputs, dict):
36 | outputs = layer(**outputs)
37 | else:
38 | outputs = layer(outputs)
39 | return outputs
40 |
--------------------------------------------------------------------------------
/praxis/layers/sequential_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unittest for Sequential."""
17 |
18 | from absl.testing import absltest
19 | import jax
20 | from jax import numpy as jnp
21 |
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import py_utils
25 | from praxis import pytypes
26 | from praxis import test_utils
27 | from praxis.layers.sequential import Sequential
28 |
29 | JTensor = pytypes.JTensor
30 | NestedMap = py_utils.NestedMap
31 | instantiate = base_layer.instantiate
32 |
33 | PARAMS = base_layer.PARAMS
34 | RANDOM = base_layer.RANDOM
35 | AUX_LOSS = base_layer.AUX_LOSS
36 | SUMMARIES = base_layer.SUMMARIES
37 | NON_TRAINABLE = base_layer.NON_TRAINABLE
38 | NON_PAX_VAR_COLLECTION = base_layer.NON_PAX_VAR_COLLECTION
39 | DECODE_CACHE = base_layer.DECODE_CACHE
40 | DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST
41 |
42 |
43 | class Feature(base_layer.BaseLayer):
44 | def __call__(self, features: JTensor, paddings: JTensor) -> NestedMap:
45 | return NestedMap(features=features, paddings=paddings)
46 |
47 |
48 | class SpecAugment(base_layer.BaseLayer):
49 | def __call__(self, features: JTensor, paddings: JTensor) -> NestedMap:
50 | return NestedMap(features=features, paddings=paddings)
51 |
52 |
53 | class Encoder(base_layer.BaseLayer):
54 | def __call__(self, features: JTensor, paddings: JTensor) -> NestedMap:
55 | return NestedMap(embeddings=features, paddings=paddings)
56 |
57 |
58 | class DecodingPrep(base_layer.BaseLayer):
59 | def __call__(self, embeddings: JTensor, paddings: JTensor) -> NestedMap:
60 | return NestedMap(embeddings=embeddings, paddings=paddings)
61 |
62 |
63 | class SequentialTest(test_utils.TestCase):
64 |
65 | def test_simple_sequence(self):
66 | specaug_p = pax_fiddle.Config(SpecAugment, name='specaugment')
67 | feature_p = pax_fiddle.Config(Feature, name='feature')
68 | encoder_p = pax_fiddle.Config(Encoder, name='encoder')
69 | decodingprep_p = pax_fiddle.Config(DecodingPrep, name='decodingprep')
70 |
71 | sequence_p = pax_fiddle.Config(
72 | Sequential, layers=[specaug_p, feature_p, encoder_p, decodingprep_p])
73 | sequence = instantiate(sequence_p)
74 |
75 | features = jax.random.normal(jax.random.PRNGKey(123), shape=[2, 1, 4])
76 | paddings = jnp.zeros(shape=[2, 1])
77 |
78 | k1 = jax.random.PRNGKey(123)
79 | k2 = jax.random.PRNGKey(456)
80 | with base_layer.JaxContext.new_context():
81 | initial_vars = sequence.init(
82 | rngs={RANDOM: k1, PARAMS: k2},
83 | mutable=DEFAULT_INIT_MUTABLE_LIST,
84 | features=features,
85 | paddings=paddings)
86 |
87 | outputs = sequence.apply(initial_vars, features, paddings)
88 |
89 | self.assertIn('embeddings', outputs)
90 | self.assertIn('paddings', outputs)
91 | self.assertAllClose(outputs.embeddings, features)
92 |
93 |
94 | if __name__ == '__main__':
95 | absltest.main()
96 |
97 |
--------------------------------------------------------------------------------
/praxis/layers/sharding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Sharding utilities."""
17 |
18 | import math
19 | import re
20 | from typing import Sequence
21 |
22 | import jax
23 | from jax.interpreters import pxla
24 | from praxis import base_layer
25 | from praxis import py_utils
26 |
27 | DimSharding = str | Sequence[str] | None
28 | Sharding = Sequence[DimSharding] | None
29 |
30 |
31 | def derive(s: Sharding, eqn: str) -> Sharding:
32 | """Derives a sharding based on an equation `original->derived`.
33 |
34 | Each letter in original and derived represents a named dimension, and the
35 | derivation is done by matching dimension names. E.g., with s=('x', 'y') and
36 | eqn="ab->cbda", the result will be (None, 'y', None, 'x').
37 |
38 | Args:
39 | s: Source sharding.
40 | eqn: Derivation equation with named dimensions.
41 |
42 | Returns:
43 | The derived sharding.
44 | """
45 | if s is None:
46 | return None
47 | pieces = eqn.split('->')
48 | assert len(pieces) == 2, eqn
49 | original, derived = pieces
50 |
51 | return tuple(s[original.index(d)] if d in original else None for d in derived)
52 |
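# Doctest-style sketches of the name matching, worked from the definition:
#
#   >>> derive(('x', 'y'), 'ab->cbda')
#   (None, 'y', None, 'x')
#   >>> derive(('x', 'y'), 'ab->ba')
#   ('y', 'x')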
53 |
54 | def shard(x: jax.Array, s: Sharding, eqn: str | None = None) -> jax.Array:
55 | """Annotates x with a sharding based on an optional equation.
56 |
57 | If equation is not specified, apply just jax.lax.with_sharding_constraint.
58 |
59 | In equation `original->derived`, each letter in original and derived
60 | represents a named dimension, and the derivation is done by matching
61 | dimension names.
62 |
63 | Each dim in `derived` can also be ?, which means unconstrained.
64 |
65 | E.g., with s=('x', 'y') and eqn="ab->cb?a", the derived sharding will be
66 | (None, 'y', unconstrained, 'x').
67 |
68 | In `derived`, there can also be a group of consecutive dims marked optional,
69 | which are represented as dims inside `[]`. The tensor x can have either all of
70 | these dims, or none of them.
71 |
72 | E.g., with s=('x', 'y', 'z') and eqn="abc->[ab]c", the derived sharding will
73 | be ('x', 'y', 'z') for a 3D tensor x, and ('z',) for a 1D tensor x.
74 |
75 | Args:
76 | x: The tensor to annotate.
77 | s: Source sharding.
78 | eqn: Derivation equation with named dimensions.
79 |
80 | Returns:
81 | The tensor x annotated with the derived sharding constraint.
82 | """
83 | if s is None or not py_utils.global_mesh_defined():
84 | return x
85 |
86 | if eqn is not None:
87 | original, derived = eqn.split('->')
88 | if '[' in derived:
89 | l, optional, r = re.split(r'\[|\]', derived)
90 |
91 | if x.ndim == len(l) + len(r):
92 | derived = l + r
93 | elif x.ndim == len(l) + len(optional) + len(r):
94 | derived = l + optional + r
95 | else:
96 | raise ValueError(f'Given {derived=} is incompatible with {x=}')
97 |
98 | s = derive(s, f'{original}->{derived}')
99 | assert s is not None
100 | s = list(s)
101 | for i, p in enumerate(derived):
102 | if p == '?':
103 | s[i] = jax.sharding.PartitionSpec.UNCONSTRAINED
104 |
105 | partition_spec = jax.sharding.PartitionSpec(*s)
106 |
107 | # If mesh_axes_transpose exists in the current context, device axes will be
108 | # remapped according to the transpose rules.
109 | partition_spec = base_layer.maybe_transpose_mesh_axes(partition_spec)
110 |
111 | return jax.lax.with_sharding_constraint(x, partition_spec)
112 |
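# How the optional group '[ab]' resolves, traced with derive (no mesh is
# needed for this part of the logic): a 3-D input keeps the group and a
# 1-D input drops it; any '?' dim then becomes PartitionSpec.UNCONSTRAINED.
#
#   >>> derive(('x', 'y', 'z'), 'abc->abc')  # 3-D input keeps [ab].
#   ('x', 'y', 'z')
#   >>> derive(('x', 'y', 'z'), 'abc->c')    # 1-D input drops [ab].
#   ('z',)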
113 |
114 | def get_dim_sharding(s: Sharding, dim: int) -> DimSharding:
115 | """Returns the sharding on one dimension."""
116 | if s is None:
117 | return None
118 | return s[dim]
119 |
120 |
121 | def shard_one_dim(x: jax.Array, s: DimSharding, dim: int) -> jax.Array:
122 | """Annotates x on one dim while other dims are unconstrained."""
123 | perm = '?' * dim + 'd' + '?' * (x.ndim - dim - 1)
124 | return shard(x, (s,), 'd->' + perm)
125 |
126 |
127 | def num_shards_on_dim(dim_sharding: DimSharding) -> int:
128 | """Returns the number of shards on one dimension in a sharding."""
129 | mesh = pxla.thread_resources.env.physical_mesh
130 | axis_sizes = dict(zip(mesh.axis_names, mesh.devices.shape))
131 |
132 | mapping = None
133 | if base_layer.JaxContext.has_context():
134 | mapping = base_layer.cur_jax_context().hparams.mesh_axes_transpose
135 |
136 | match dim_sharding:
137 | case None:
138 | return 1
139 | case str():
140 | return axis_sizes.get(
141 | base_layer.transpose_one_axis(dim_sharding, mapping), 1
142 | )
143 | case _:
144 | assert isinstance(dim_sharding, Sequence)
145 | return math.prod(
146 | axis_sizes.get(base_layer.transpose_one_axis(axis, mapping), 1)
147 | for axis in dim_sharding
148 | )
149 |
--------------------------------------------------------------------------------
/praxis/layers/ssm_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis SSM layers."""
17 |
18 | from absl import logging
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import jax
22 | from jax import numpy as jnp
23 | import numpy as np
24 | from praxis import base_layer
25 | from praxis import pax_fiddle
26 | from praxis import py_utils
27 | from praxis import test_utils
28 | from praxis.layers import ssm
29 | import tensorflow.compat.v2 as tf
30 |
31 | to_np = test_utils.to_np
32 | to_tf_nmap = test_utils.to_tf_nmap
33 | instantiate = base_layer.instantiate
34 |
35 |
36 | class SSMTest(test_utils.TestCase):
37 |
38 | def setUp(self):
39 | super().setUp()
40 | np.random.seed(123456)
41 | tf.random.set_seed(123)
42 |
43 | @parameterized.named_parameters(
44 | {
45 | 'testcase_name': 'ss4d-1d-legs',
46 | 'hippo_type': 'ss4d-1d-legs',
47 | },
48 | {
49 | 'testcase_name': 'ss4d-1d-lagt',
50 | 'hippo_type': 'ss4d-1d-lagt',
51 | },
52 | {
53 | 'testcase_name': 'ss4d-1d',
54 | 'hippo_type': 'ss4d-1d',
55 | },
56 | )
57 | def test_s4d_layer(self, hippo_type):
58 | p = pax_fiddle.Config(
59 | ssm.SSM,
60 | name='ssm',
61 | nheads=5,
62 | dim=1,
63 | l_max=2,
64 | decode_num_samples=1,
65 | step_size=1.,
66 | hippo_type=hippo_type
67 | )
68 | s4d = instantiate(p)
69 | npy_input = np.random.normal(1.0, 0.5,
70 | [2, p.l_max, p.dim]).astype('float32')
71 | inputs = jnp.asarray(npy_input)
72 | prng_key = jax.random.PRNGKey(seed=123)
73 | initial_vars = s4d.init(prng_key, inputs)
74 |
75 | # Test convolution/fft.
76 | outputs, ssm_state = s4d.apply(
77 | initial_vars, inputs, mutable=[base_layer.DECODE_CACHE])
78 | logging.info('outputs = %s', outputs)
79 |
80 | # Test extend_step.
81 | out_step = []
82 | updated_vars = py_utils.merge_dict(ssm_state, initial_vars)
83 | logging.info('init_vars w state = %s', updated_vars)
84 | for i in range(p.l_max):
85 | out, ssm_state = s4d.apply(
86 | updated_vars, inputs[:, i, :], method=s4d.extend_step,
87 | mutable=[base_layer.DECODE_CACHE])
88 | logging.info('outputs = %s', out)
89 | logging.info('ssm_states = %s', ssm_state)
90 | updated_vars['decoder_cache'] = ssm_state['decoder_cache']
91 | out_step.append(out)
92 |
93 | out_step = jnp.stack(out_step, axis=1)
94 |
95 | # Make sure the convolution/fft gets the same results as extend_step.
96 | self.assertAllClose(to_np(outputs), to_np(out_step), atol=1e-6)
97 |
98 |
99 | class S5Test(test_utils.TestCase):
100 |
101 | def setUp(self):
102 | super().setUp()
103 | np.random.seed(123456)
104 | tf.random.set_seed(123)
105 |
106 | def test_s5_layer(self):
107 | p = pax_fiddle.Config(
108 | ssm.S5,
109 | name='ssm',
110 | nheads=5,
111 | dim=1,
112 | l_max=2,
113 | decode_num_samples=1,
114 | step_size=1.0,
115 | )
116 | s5 = instantiate(p)
117 | npy_input = np.random.normal(1.0, 0.5, [2, p.l_max, p.dim]).astype(
118 | 'float32'
119 | )
120 | inputs = jnp.asarray(npy_input)
121 | prng_key = jax.random.PRNGKey(seed=123)
122 | initial_vars = s5.init(prng_key, inputs)
123 |
124 | # Test convolution/parallel_scan.
125 | outputs, ssm_state = s5.apply(
126 | initial_vars, inputs, mutable=[base_layer.DECODE_CACHE]
127 | )
128 | logging.info('outputs = %s', outputs)
129 |
130 | # Test extend_step.
131 | out_step = []
132 | updated_vars = py_utils.merge_dict(ssm_state, initial_vars)
133 | logging.info('init_vars w state = %s', updated_vars)
134 | for i in range(p.l_max):
135 | out, ssm_state = s5.apply(
136 | updated_vars,
137 | inputs[:, i, :],
138 | method=s5.extend_step,
139 | mutable=[base_layer.DECODE_CACHE],
140 | )
141 | logging.info('outputs = %s', out)
142 | logging.info('ssm_states = %s', ssm_state)
143 | updated_vars['decoder_cache'] = ssm_state['decoder_cache']
144 | out_step.append(out)
145 |
146 | out_step = jnp.stack(out_step, axis=1)
147 |
148 | # Make sure the convolution gets the same results as extend_step.
149 | self.assertAllClose(to_np(outputs), to_np(out_step), atol=1e-6)
150 |
151 |
152 | if __name__ == '__main__':
153 | absltest.main()
154 |
--------------------------------------------------------------------------------
/praxis/layers/stats.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Compute stats of tensors mostly for monitoring purposes."""
17 |
18 | from jax import numpy as jnp
19 | from praxis import py_utils
20 | from praxis import pytypes
21 |
22 | JTensor = pytypes.JTensor
23 | NestedMap = py_utils.NestedMap
24 |
25 |
26 | def compute_stats(inputs: JTensor, padding: JTensor | None = None) -> NestedMap:
27 | """Computes various stats over the valid data points in inputs."""
28 | # Let's compute stats in fp32
29 | inputs = inputs.astype(jnp.float32)
30 | if padding is None:
31 | padding = jnp.zeros_like(inputs)
32 | assert inputs.ndim == padding.ndim, f'{inputs.shape}, {padding.shape}'
33 | mask = 1.0 - padding
34 |
35 | sum_v = jnp.sum(inputs * mask)
36 | count_v = jnp.sum(jnp.ones_like(inputs) * mask)
37 | mean_v = sum_v / jnp.maximum(1.0, count_v)
38 | sum_v_squared = jnp.sum(jnp.square((inputs - mean_v) * mask))
39 | std_v = jnp.sqrt(sum_v_squared / jnp.maximum(1.0, count_v))
40 | max_v = jnp.max(jnp.abs(inputs * mask))
41 |
42 | return NestedMap(mean_v=mean_v, std_v=std_v, max_v=max_v)
43 |
--------------------------------------------------------------------------------
/praxis/layers/stats_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis attention layers."""
17 |
18 | from absl.testing import absltest
19 | from jax import numpy as jnp
20 | import numpy as np
21 | from praxis import test_utils
22 | from praxis.layers import stats
23 |
24 |
25 | class StatsTest(test_utils.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | np.random.seed(12345687)
30 |
31 | def test_compute_stats(self):
32 | inputs = jnp.array([[0., 1., 2.], [3., 4., 5.]])
33 | padding = jnp.array([[0., 0., 0.], [0., 0., 1.]])
34 | inputs_stats = stats.compute_stats(inputs, padding)
35 | self.assertAllClose(
36 | [2., 1.414214, 4.],
37 | [inputs_stats.mean_v, inputs_stats.std_v, inputs_stats.max_v])
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/praxis/layers/stochastics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Stochastic layers."""
17 |
18 | import numbers
19 | from typing import Sequence
20 |
21 | import jax
22 | from jax import numpy as jnp
23 | from praxis import base_layer
24 | from praxis import py_utils
25 | from praxis import pytypes
26 |
27 | NestedMap = py_utils.NestedMap
28 | JTensor = pytypes.JTensor
29 |
30 |
31 | class Dropout(base_layer.BaseLayer):
32 | """Apply dropout during training.
33 |
34 | Attributes:
35 | keep_prob: Keep probability.
36 | noise_shape: A 1-D list of type `int32`, representing the shape for randomly
37 | generated keep/drop flags. Note that this noise_shape is unknown when
38 | building layer params.
39 | noise_shape_broadcast_dims: A list of dimensions where the noise shape is
40 | broadcasted. For example, noise_shape = [n, h, w, 1] when
41 | noise_shape_broadcast_dims=[-1].
42 | dropout_at_eval: Whether or not to also perform dropout at eval time. We
43 | typically want to replace dropout by expectation during eval. However, in
44 | certain cases E(f(x)) != f(E(x)), and replacing dropout by its expectation
45 | during eval leads to worse quality.
46 | transpose_qk: Whether or not to transpose the sequence length dimensions
47 | corresponding to query and key in the generated random numbers to avoid
48 | expensive copies.
49 | """
50 | keep_prob: float = 1.0
51 | noise_shape: Sequence[int] | None = None
52 | noise_shape_broadcast_dims: Sequence[int] | None = None
53 | dropout_at_eval: bool = False
54 | transpose_qk: bool | None = False
55 |
56 | def _dropout(self, inputs: JTensor, noise_shape: list[int] | None) -> JTensor:
57 | if noise_shape is None:
58 | noise_shape = inputs.shape
59 | prng_key = self.next_prng_key()
60 | keep_prob = self.keep_prob
61 | assert keep_prob > 0.0
62 | random_nums = keep_prob + jax.random.uniform(
63 | prng_key, noise_shape, inputs.dtype, minval=0.0, maxval=1.0)
64 | transpose = (
65 | len(noise_shape) == 4
66 | and noise_shape[3] == noise_shape[2]
67 | and self.transpose_qk
68 | )
69 | if transpose:
70 | random_nums = jnp.transpose(random_nums, (0, 1, 3, 2))
71 | binary_mask = jnp.floor(random_nums)
72 | return inputs * binary_mask / keep_prob
73 |
74 | def __call__(self, inputs: JTensor) -> JTensor:
75 | """Applies dropout to inputs.
76 |
77 | Args:
78 | inputs: The inputs JTensor.
79 |
80 | Returns:
81 | inputs with dropout applied at training time.
82 | """
83 | if isinstance(self.keep_prob, numbers.Real) and self.keep_prob == 1.0:
84 | return inputs
85 |
86 | if self.do_eval and not self.dropout_at_eval:
87 | return inputs
88 |
89 | if not self.noise_shape_broadcast_dims:
90 | noise_shape = self.noise_shape
91 | else:
92 | noise_shape = self.noise_shape or list(inputs.shape)
93 | if not isinstance(noise_shape, list):
94 | noise_shape = list(noise_shape)
95 | for dim in self.noise_shape_broadcast_dims:
96 | if dim >= len(noise_shape):
97 | raise ValueError('Invalid broadcasted dim {}'.format(dim))
98 | noise_shape[dim] = 1
99 |
100 | ret = self._dropout(inputs, noise_shape)
101 | return ret
102 |
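# Why floor(keep_prob + uniform) works as the keep mask (a sketch): for
# u ~ U[0, 1), keep_prob + u >= 1 holds with probability keep_prob, so the
# floor is 1 that often and 0 otherwise; dividing by keep_prob then keeps
# E[output] == inputs.
#
#   >>> u = jax.random.uniform(jax.random.PRNGKey(0), [1000000])
#   >>> jnp.floor(0.8 + u).mean()  # ~0.8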
103 |
104 | class StochasticResidual(base_layer.BaseLayer):
105 | """Stochastic residual layer that randomly drops the residual branch.
106 |
107 | Attributes:
108 | residual_weight: Residual weight with which to add the residual back to the
109 | input.
110 | survival_prob: Survival probability of the residual branch while dropping
111 | out.
112 | """
113 | residual_weight: float = 1.0
114 | survival_prob: float = 1.0
115 |
116 | def _drop_connect(self, inputs: JTensor) -> JTensor:
117 | """Drops the entire residual layer with given survival probability.
118 |
119 | Args:
120 | inputs: input `.JTensor` on the residual branch to be dropped.
121 |
122 | Returns:
123 | Dropped out inputs.
124 | """
125 | if self.do_eval:
126 | return inputs
127 |
128 | # Compute tensor.
129 | prng_key = self.next_prng_key()
130 | batch_size = inputs.shape[0]
131 | shape = [batch_size] + [1] * (len(inputs.shape) - 1)
132 | random_tensor = self.survival_prob + jax.random.uniform(
133 | prng_key, shape, dtype=inputs.dtype
134 | )
135 | binary_tensor = jnp.floor(random_tensor)
136 | # Unlike the conventional approach of multiplying by survival_prob at test
137 | # time, we divide by survival_prob at training time, so no additional
138 | # compute is needed at test time.
139 | output = inputs / self.survival_prob * binary_tensor
140 | return output
141 |
142 | def __call__(self, inputs: JTensor, residual: JTensor) -> JTensor:
143 | """Returns inputs + residual with stochastic dropout.
144 |
145 | Args:
146 | inputs: input `.JTensor`.
147 | residual: residual `.JTensor` which is added to input with dropout.
148 |
149 | Returns:
150 | Output `.JTensor` which is residual added to inputs with dropout.
151 | """
152 | return inputs + self.residual_weight * self._drop_connect(residual)
153 |
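# A sketch of the expectation-preserving scaling: binary_tensor is 1 with
# probability survival_prob, so E[inputs / survival_prob * binary_tensor]
# == inputs at training time, and eval returns the branch unscaled.
#
#   >>> p = 0.9
#   >>> u = jax.random.uniform(jax.random.PRNGKey(0), [1000000])
#   >>> (jnp.floor(p + u) / p).mean()  # ~1.0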
--------------------------------------------------------------------------------
/praxis/layers/stochastics_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Praxis stochastic layers."""
17 |
18 | from absl import logging
19 | from absl.testing import absltest
20 | import jax
21 | from jax import numpy as jnp
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import test_utils
25 | from praxis.layers import stochastics
26 |
27 | jax.config.update('jax_threefry_partitionable', False)
28 |
29 | instantiate = base_layer.instantiate
30 |
31 |
32 | class StochasticsTest(test_utils.TestCase):
33 |
34 |   def test_dropout_layer_01(self):
35 | test_layer_p = pax_fiddle.Config(
36 | stochastics.Dropout, name='dropout', keep_prob=0.8
37 | )
38 | layer = instantiate(test_layer_p)
39 |
40 | inputs = jnp.ones([10, 1000], dtype=jnp.bfloat16)
41 |
42 | with base_layer.JaxContext.new_context():
43 | prng_key = jax.random.PRNGKey(seed=12346)
44 | prng_key, init_key = jax.random.split(prng_key)
45 | prng_key, dropout_k1, dropout_k2 = jax.random.split(prng_key, 3)
46 | initial_vars = layer.init({
47 | 'random': dropout_k1,
48 | 'params': init_key
49 | }, inputs)
50 | logging.info('initial_vars: %s', initial_vars)
51 | output1 = layer.apply(initial_vars, inputs, rngs={'random': dropout_k1})
52 | output2 = layer.apply(initial_vars, inputs, rngs={'random': dropout_k2})
53 |
54 | out1_sum = jnp.sum(output1)
55 | out2_sum = jnp.sum(output2)
56 | out1_nonzero = jnp.sum(output1 > 0.0)
57 | out2_nonzero = jnp.sum(output2 > 0.0)
58 |
59 | logging.info('out1_sum: %s', out1_sum)
60 | logging.info('out2_sum: %s', out2_sum)
61 | logging.info('out1_nonzero: %s', out1_nonzero)
62 | logging.info('out2_nonzero: %s', out2_nonzero)
63 |
64 | finfo = jnp.finfo(inputs.dtype)
65 | nmant = finfo.nmant
66 | self.assertEqual(9984.0, out1_sum)
67 | if nmant < 8:
68 | self.assertEqual(9984.0, out2_sum)
69 | self.assertEqual(7983.0, out1_nonzero)
70 | self.assertEqual(7964.0, out2_nonzero)
71 | else:
72 | self.assertEqual(9920.0, out2_sum)
73 | self.assertEqual(8000.0, out1_nonzero)
74 | self.assertEqual(7952.0, out2_nonzero)
75 |
76 | def test_dropout_layer_02(self):
77 | test_layer_p = pax_fiddle.Config(
78 | stochastics.Dropout,
79 | name='dropout',
80 | keep_prob=0.8,
81 | noise_shape=[10, 6, 8],
82 | noise_shape_broadcast_dims=[2],
83 | )
84 | layer = instantiate(test_layer_p)
85 |
86 | inputs = jnp.ones([2, 10, 6, 8], dtype=jnp.bfloat16)
87 |
88 | with base_layer.JaxContext.new_context():
89 | prng_key = jax.random.PRNGKey(seed=12346)
90 | prng_key, init_key = jax.random.split(prng_key)
91 | prng_key, compute_key = jax.random.split(prng_key)
92 | initial_vars = layer.init({
93 | 'random': compute_key,
94 | 'params': init_key
95 | }, inputs)
96 | logging.info('initial_vars: %s', initial_vars)
97 | output1 = layer.apply(initial_vars, inputs, rngs={'random': compute_key})
98 |
99 | out1_sum = jnp.sum(output1)
100 | out1_nonzero = jnp.sum(output1 > 0.0)
101 |
102 | logging.info('out1_sum: %s', out1_sum)
103 | logging.info('out1_nonzero: %s', out1_nonzero)
104 |
105 | self.assertEqual(980, out1_sum)
106 | self.assertEqual(784, out1_nonzero)
107 |
108 | def test_dropout_layer_03(self):
109 | test_layer_p = pax_fiddle.Config(
110 | stochastics.Dropout,
111 | name='dropout',
112 | keep_prob=0.8,
113 | noise_shape_broadcast_dims=[0, 3],
114 | )
115 | layer = instantiate(test_layer_p)
116 |
117 | inputs = jnp.ones([2, 10, 6, 8], dtype=jnp.bfloat16)
118 |
119 | with base_layer.JaxContext.new_context():
120 | prng_key = jax.random.PRNGKey(seed=12346)
121 | prng_key, init_key = jax.random.split(prng_key)
122 | prng_key, compute_key = jax.random.split(prng_key)
123 | initial_vars = layer.init({
124 | 'random': compute_key,
125 | 'params': init_key
126 | }, inputs)
127 | logging.info('initial_vars: %s', initial_vars)
128 |
129 | output1 = layer.apply(initial_vars, inputs, rngs={'random': compute_key})
130 |
131 | out1_sum = jnp.sum(output1)
132 | out1_nonzero = jnp.sum(output1 > 0.0)
133 |
134 | logging.info('out1_sum: %s', out1_sum)
135 | logging.info('out1_nonzero: %s', out1_nonzero)
136 |
137 | self.assertEqual(980, out1_sum)
138 | self.assertEqual(784, out1_nonzero)
139 |
140 |
141 | if __name__ == '__main__':
142 | absltest.main()
143 |
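Note on the hard-coded expectations above: with `keep_prob=0.8` every kept element is rescaled to 1/0.8 = 1.25, so the sums and nonzero counts are tied together. In tests 02 and 03, 784 * 1.25 = 980 exactly; the bfloat16 constants in test 01 deviate from 1.25 * nonzero only through accumulated rounding, and all of them depend on the legacy threefry behavior pinned at the top of the file.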
--------------------------------------------------------------------------------
/praxis/layers/vanillanets_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for vanillanets."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 | from praxis import base_layer
24 | from praxis import pax_fiddle
25 | from praxis.layers import poolings
26 | from praxis.layers import vanillanets
27 |
28 | instantiate = base_layer.instantiate
29 |
30 |
31 | class VanillanetsTest(parameterized.TestCase):
32 |
33 | def setUp(self):
34 | super().setUp()
35 | np.random.seed(123456)
36 |
37 | @parameterized.parameters(
38 | (3, 1, 'RELU', [2, 16, 36, 72], 72),
39 | (4, 2, 'TANH', [4, 12, 12, 16], 32),
40 | (4, 2, 'RELU', [4, 12, 12, 16], 8),
41 | (5, 1, 'NONE', [4, 12, 12, 16], 8),
42 | )
43 | def test_vanilla_block(self, kernel_size, stride, activation, input_shape,
44 | output_dim):
45 | input_dim = input_shape[-1]
46 | p = pax_fiddle.Config(
47 | vanillanets.VanillaBlock,
48 | name='vanilla_block',
49 | input_dim=input_dim,
50 | output_dim=output_dim,
51 | kernel_size=kernel_size,
52 | stride=stride,
53 | )
54 | resnet_layer = instantiate(p)
55 | npy_inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32')
56 | inputs = jnp.asarray(npy_inputs)
57 |
58 | with base_layer.JaxContext.new_context():
59 | prng_key = jax.random.PRNGKey(seed=123)
60 | initial_vars = resnet_layer.init(prng_key, inputs)
61 | output = resnet_layer.apply(initial_vars, inputs)
62 |
63 | @parameterized.parameters(
64 | ([1, 4, 4, 3], None),
65 | ([1, 4, 4, 3], [1, 2]),
66 | )
67 | def test_vanilla_net(self, input_shape, spatial_pooling_dims):
68 | p = vanillanets.VanillaNet.HParamsVanillaNet5().set(
69 |         name='vanillanet', output_spatial_pooling_params=None)
70 | if spatial_pooling_dims is not None:
71 | p.output_spatial_pooling_params = pax_fiddle.Config(
72 | poolings.GlobalPooling, pooling_dims=spatial_pooling_dims
73 | )
74 | vanillanet_layer = instantiate(p)
75 | npy_inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32')
76 | inputs = jnp.asarray(npy_inputs)
77 |
78 | with base_layer.JaxContext.new_context():
79 | prng_key = jax.random.PRNGKey(seed=123)
80 | initial_vars = vanillanet_layer.init(prng_key, inputs)
81 | output = vanillanet_layer.apply(initial_vars, inputs)
82 |
83 |
84 | if __name__ == '__main__':
85 | absltest.main()
86 |
--------------------------------------------------------------------------------
/praxis/layers/video/BUILD:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY")
17 |
18 | # Description:
19 | #   Video-related layers. The public API is defined in __init__.py.
20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test")
21 |
22 | package(default_visibility = JAX_VISIBILITY)
23 |
24 | pytype_strict_library(
25 | name = "enc_dec_3dcnn",
26 | srcs = [
27 | "enc_dec_3dcnn.py",
28 | ],
29 | deps = [
30 | # Implicit jax dependency.
31 | # Implicit numpy dependency.
32 | "//praxis:base_layer",
33 | "//praxis:pax_fiddle",
34 | "//praxis:py_utils",
35 | "//praxis:pytypes",
36 | "//praxis/layers:activations",
37 | "//praxis/layers:convolutions",
38 | "//praxis/layers:linears",
39 | "//praxis/layers:normalizations",
40 | ],
41 | )
42 |
43 | pytype_strict_test(
44 | name = "enc_dec_3dcnn_test",
45 | srcs = [
46 | "enc_dec_3dcnn_test.py",
47 | ],
48 | deps = [
49 | ":enc_dec_3dcnn",
50 | # Implicit absl.logging dependency.
51 | # Implicit absl.testing.absltest dependency.
52 | # Implicit jax dependency.
53 | "//praxis:base_layer",
54 | "//praxis:pax_fiddle",
55 | "//praxis:test_utils",
56 | ],
57 | )
58 |
59 | pytype_strict_library(
60 | name = "losses",
61 | srcs = [
62 | "losses.py",
63 | ],
64 | deps = [
65 | # Implicit jax dependency.
66 | "//praxis:base_layer",
67 | "//praxis:base_model",
68 | "//praxis:py_utils",
69 | "//praxis:pytypes",
70 | ],
71 | )
72 |
73 | pytype_strict_test(
74 | name = "losses_test",
75 | srcs = [
76 | "losses_test.py",
77 | ],
78 | deps = [
79 | ":losses",
80 | ":vqvae",
81 | # Implicit absl.testing.absltest dependency.
82 | # Implicit absl.testing.parameterized dependency.
83 | # Implicit jax dependency.
84 | # Implicit numpy dependency.
85 | "//praxis:base_layer",
86 | "//praxis:pax_fiddle",
87 | "//praxis:py_utils",
88 | "//praxis:test_utils",
89 | ],
90 | )
91 |
92 | pytype_strict_library(
93 | name = "quantizer",
94 | srcs = [
95 | "quantizer.py",
96 | ],
97 | deps = [
98 | # Implicit jax dependency.
99 | # Implicit numpy dependency.
100 | "//praxis:base_layer",
101 | "//praxis:py_utils",
102 | "//praxis:pytypes",
103 | ],
104 | )
105 |
106 | pytype_strict_test(
107 | name = "quantizer_test",
108 | srcs = [
109 | "quantizer_test.py",
110 | ],
111 | deps = [
112 | ":quantizer",
113 | # Implicit absl.testing.absltest dependency.
114 | # Implicit absl.testing.parameterized dependency.
115 | # Implicit jax dependency.
116 | "//praxis:base_layer",
117 | "//praxis:pax_fiddle",
118 | "//praxis:py_utils",
119 | "//praxis:test_utils",
120 | ],
121 | )
122 |
123 | pytype_strict_library(
124 | name = "vqvae",
125 | srcs = [
126 | "vqvae.py",
127 | ],
128 | deps = [
129 | ":enc_dec_3dcnn",
130 | ":quantizer",
131 | # Implicit jax dependency.
132 | "//praxis:base_layer",
133 | "//praxis:base_model",
134 | "//praxis:pax_fiddle",
135 | "//praxis:py_utils",
136 | "//praxis:pytypes",
137 | "//praxis/layers:activations",
138 | "//praxis/layers:convolutions",
139 | "//praxis/layers:linears",
140 | ],
141 | )
142 |
143 | pytype_strict_test(
144 | name = "vqvae_test",
145 | srcs = [
146 | "vqvae_test.py",
147 | ],
148 | deps = [
149 | ":quantizer",
150 | ":vqvae",
151 | # Implicit absl.testing.absltest dependency.
152 | # Implicit absl.testing.parameterized dependency.
153 | # Implicit jax dependency.
154 | "//praxis:base_layer",
155 | "//praxis:pax_fiddle",
156 | "//praxis:py_utils",
157 | "//praxis:test_utils",
158 | ],
159 | )
160 |
--------------------------------------------------------------------------------
/praxis/layers/video/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Exposes the public layer functionalities for video."""
17 |
18 | from praxis.layers.video import enc_dec_3dcnn
19 | from praxis.layers.video import losses
20 | from praxis.layers.video import quantizer
21 | from praxis.layers.video import vqvae
22 |
--------------------------------------------------------------------------------
/praxis/layers/video/enc_dec_3dcnn_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from absl import logging
17 | from absl.testing import absltest
18 | import jax
19 | from jax import numpy as jnp
20 | from praxis import base_layer
21 | from praxis import pax_fiddle
22 | from praxis import test_utils
23 | from praxis.layers.video import enc_dec_3dcnn
24 |
25 |
26 | class EncDec3dcnnTest(test_utils.TestCase):
27 |
28 | def test_depth_to_space(self):
29 | x = jnp.ones((2, 5, 3, 3, 32))
30 | y = enc_dec_3dcnn.depth_to_space(x, 2, 4)
31 | self.assertEqual(y.shape, (2, 10, 6, 6, 4))
32 |
33 | def test_gan_res_block(self):
34 | prng_key, _ = jax.random.split(jax.random.PRNGKey(1234))
35 | pax_x = jax.random.normal(prng_key, (1, 17, 256, 257, 32))
36 | res_block_p = pax_fiddle.Config(
37 | enc_dec_3dcnn.DiscriminatorResBlock,
38 | name='res_block',
39 | input_dim=32,
40 | output_dim=32,
41 | )
42 | pax_layer = base_layer.instantiate(res_block_p)
43 | init_vars = pax_layer.init(prng_key, pax_x)
44 | logging.info(
45 | 'init_vars: %s', jax.tree_util.tree_map(lambda x: x.shape, init_vars)
46 | )
47 | pax_y = pax_layer.apply(init_vars, pax_x)
48 | self.assertEqual(pax_y.shape, (1, 9, 128, 129, 32))
49 |
50 | def test_res_block(self):
51 | prng_key, _ = jax.random.split(jax.random.PRNGKey(1234))
52 | pax_x = jax.random.normal(prng_key, (1, 5, 3, 3, 32))
53 | res_block_p = pax_fiddle.Config(
54 | enc_dec_3dcnn.ResBlock,
55 | name='res_block',
56 | input_dim=32,
57 | output_dim=64,
58 | use_conv_shortcut=True,
59 | )
60 | pax_layer = base_layer.instantiate(res_block_p)
61 | init_vars = pax_layer.init(prng_key, pax_x)
62 | logging.info(
63 | 'init_vars: %s', jax.tree_util.tree_map(lambda x: x.shape, init_vars)
64 | )
65 | pax_y = pax_layer.apply(init_vars, pax_x)
66 | self.assertEqual(pax_y.shape, pax_x.shape[:-1] + (64,))
67 |
68 |
69 | if __name__ == '__main__':
70 | absltest.main()
71 |
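For readers unfamiliar with a 3-D `depth_to_space`, here is a hedged sketch of the rearrangement consistent with the shape asserted in `test_depth_to_space` above ((2, 5, 3, 3, 32) -> (2, 10, 6, 6, 4), since 32 = 2*2*2*4); the actual `enc_dec_3dcnn.depth_to_space` may order the sub-blocks differently.

import jax.numpy as jnp

def depth_to_space_3d(x, block, out_channels):
  """Moves channel groups into (t, h, w); channels shrink by block**3."""
  b, t, h, w, c = x.shape
  assert c == block**3 * out_channels
  x = x.reshape(b, t, h, w, block, block, block, out_channels)
  x = x.transpose(0, 1, 4, 2, 5, 3, 6, 7)  # interleave blocks with t, h, w
  return x.reshape(b, t * block, h * block, w * block, out_channels)

print(depth_to_space_3d(jnp.ones((2, 5, 3, 3, 32)), 2, 4).shape)
# (2, 10, 6, 6, 4)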
--------------------------------------------------------------------------------
/praxis/layers/video/losses.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Loss functions for vqvae/vqgan models."""
17 |
18 | from collections.abc import Callable
19 | import jax
20 | import jax.numpy as jnp
21 | from praxis import base_layer
22 | from praxis import base_model
23 | from praxis import py_utils
24 | from praxis import pytypes
25 |
26 | JTensor = pytypes.JTensor
27 |
28 |
29 | def r1_gradient_penalty(
30 | inputs: JTensor,
31 | logits_fn: Callable[[JTensor], JTensor],
32 | grad_penalty_cost: float = 10.0,
33 | ) -> tuple[JTensor, JTensor]:
34 |   """Calculates the gradient penalty loss to regularize the discriminator.
35 |
36 | Args:
37 | inputs: A tensor of image inputs.
38 | logits_fn: A function that takes inputs and returns logits.
39 | grad_penalty_cost: scalar weight for the gradient penalty loss.
40 |
41 | Returns:
42 | A tuple of logits and the gradient penalty.
43 | """
44 | out, vjp_fn = jax.vjp(logits_fn, inputs, has_aux=False)
45 |   # TODO: check whether jax.value_and_grad is more efficient than jax.vjp at
46 |   # scale.
46 | grad = vjp_fn(jnp.ones_like(out))[0]
47 | flattened_grad = jnp.asarray(grad.reshape((inputs.shape[0], -1)), jnp.float32)
48 | penalty = (
49 | jnp.mean(jnp.sum(jnp.square(flattened_grad), axis=-1)) * grad_penalty_cost
50 | )
51 | return out, penalty
52 |
53 |
54 | def _discriminator_loss(logits_real: JTensor, logits_fake: JTensor) -> JTensor:
55 | """Calculates non-saturating discriminator loss."""
56 | d_loss_real = jax.nn.softplus(-logits_real)
57 | d_loss_fake = jax.nn.softplus(logits_fake)
58 | return jnp.mean(d_loss_real) + jnp.mean(d_loss_fake)
59 |
60 |
61 | def _generator_loss(logits_fake):
62 | """Calculates non-saturating generator loss."""
63 | return jnp.mean(jax.nn.softplus(-logits_fake))
64 |
65 |
66 | class VQGANLoss(base_layer.BaseLayer):
67 | """Loss layer for VQGAN."""
68 |
69 | g_adversarial_loss_weight: float = 0.1
70 | reconstruction_loss_weight: float = 5.0
71 | polyak_decay: float = 0.999
72 | lecam_weight: float = 0.001
73 |
74 | def lecam_loss(self, real_pred: JTensor, fake_pred: JTensor) -> JTensor:
75 | """Calculates lecam loss.
76 |
77 | Described in https://arxiv.org/abs/2104.03310
78 |
79 | Args:
80 | real_pred: scalar, predictions for the real samples.
81 |       fake_pred: scalar, predictions for the reconstructed (fake) samples.
82 |
83 | Returns:
84 | Lecam regularization loss (scalar).
85 | """
86 | ema_fake_pred = self.get_var('ema_fake_pred')
87 | ema_real_pred = self.get_var('ema_real_pred')
88 | return jnp.mean(
89 | jnp.power(jax.nn.relu(real_pred - ema_fake_pred), 2)
90 | ) + jnp.mean(jnp.power(jax.nn.relu(ema_real_pred - fake_pred), 2))
91 |
92 | def setup(self):
93 | """Constructs this jax module and registers variables."""
94 | decay_factor_hparams = base_layer.WeightHParams(
95 | shape=[],
96 | init=base_layer.WeightInit.Constant(0.0),
97 | dtype=jnp.float32,
98 | collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC],
99 | )
100 |
101 | self.create_variable('ema_real_pred', decay_factor_hparams, trainable=False)
102 | self.create_variable('ema_fake_pred', decay_factor_hparams, trainable=False)
103 |
104 | def __call__(
105 | self, predictions: base_model.Predictions, input_batch: py_utils.NestedMap
106 | ) -> py_utils.NestedMap:
107 | original_video = input_batch.video
108 | reconstructed = predictions['reconstructed']
109 | logits_real = predictions['logits_real']
110 | logits_fake = predictions['logits_fake']
111 | real_pred = jnp.mean(logits_real)
112 | fake_pred = jnp.mean(logits_fake)
113 |
114 | ema_fake_pred = self.get_var('ema_fake_pred')
115 | ema_real_pred = self.get_var('ema_real_pred')
116 | ema_fake_pred = (
117 | fake_pred * (1 - self.polyak_decay) + ema_fake_pred * self.polyak_decay
118 | )
119 | ema_real_pred = (
120 | real_pred * (1 - self.polyak_decay) + ema_real_pred * self.polyak_decay
121 | )
122 | self.update_var('ema_fake_pred', ema_fake_pred)
123 | self.update_var('ema_real_pred', ema_real_pred)
124 |
125 | losses = py_utils.NestedMap()
126 | losses.grad_penalty = predictions['r1_gradient_penalty']
127 | losses.lecam_loss = (
128 | self.lecam_loss(logits_real, logits_fake) * self.lecam_weight
129 | )
130 |
131 | losses.d_adversarial_loss = _discriminator_loss(logits_real, logits_fake)
132 | losses.g_adversarial_loss = (
133 | _generator_loss(logits_fake) * self.g_adversarial_loss_weight
134 | )
135 |
136 | diff = jnp.asarray(original_video - reconstructed, jnp.float32)
137 |
138 | losses.reconstruction_loss = (
139 | jnp.mean(jnp.square(diff)) * self.reconstruction_loss_weight
140 | )
141 | losses.perceptual_loss = jnp.array(0.0, dtype=jnp.float32)
142 | if self.do_eval:
143 | losses.quantizer_loss = jnp.zeros_like(losses.reconstruction_loss)
144 | else:
145 | losses.quantizer_loss = predictions['quantizer_loss']
146 | losses.d_loss = (
147 | losses.d_adversarial_loss + losses.grad_penalty + losses.lecam_loss
148 | )
149 | losses.g_loss = (
150 | losses.reconstruction_loss
151 | + losses.g_adversarial_loss
152 | + losses.perceptual_loss
153 | + losses.quantizer_loss
154 | )
155 | return losses
156 |
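A small usage sketch for `r1_gradient_penalty` with a toy discriminator (`toy_logits_fn` is a hypothetical stand-in, not a real model). For logits equal to the per-example sum of squares, the gradient is `2x`, so the penalty can be checked in closed form.

import jax.numpy as jnp

def toy_logits_fn(x):  # hypothetical stand-in for a real discriminator
  return jnp.sum(jnp.square(x), axis=(1, 2, 3, 4))

x = jnp.ones((2, 3, 4, 4, 3))
out, penalty = r1_gradient_penalty(x, toy_logits_fn)
# The gradient per example is 2 * ones over 3*4*4*3 = 144 elements, so
# ||grad||^2 = 4 * 144 = 576; times the default grad_penalty_cost of 10.
print(penalty)  # 5760.0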
--------------------------------------------------------------------------------
/praxis/layers/video/losses_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import functools
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import jax
21 | import numpy as np
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import py_utils
25 | from praxis import test_utils
26 | from praxis.layers.video import losses
27 | from praxis.layers.video import vqvae
28 |
29 |
30 | class LossesTest(test_utils.TestCase):
31 |
32 | def test_r1_gradient_penalty(self):
33 | prng_key = jax.random.PRNGKey(seed=123)
34 | x = jax.random.normal(prng_key, (2, 5, 16, 16, 3))
35 | # Create a pax layer and get the output from the random input.
36 | p = pax_fiddle.Config(
37 | vqvae.Discriminator,
38 | name='magvit',
39 | num_frames=5,
40 | image_height=16,
41 | image_width=16,
42 | filters=32,
43 | channel_multipliers=(2, 4),
44 | )
45 | context_p = base_layer.JaxContext.HParams(do_eval=False)
46 | with base_layer.JaxContext.new_context(hparams=context_p):
47 | pax_layer = base_layer.instantiate(p)
48 | pax_vars = pax_layer.init(prng_key, x)
49 | logit_fn = functools.partial(pax_layer.apply, pax_vars)
50 | logits, penalty = losses.r1_gradient_penalty(x, logit_fn)
51 | self.assertEqual(logits.shape, (2, 1))
52 | self.assertEqual(penalty.shape, ())
53 |
54 | @parameterized.parameters(True, False)
55 | def test_vqgan_loss(self, do_eval):
56 | batch_size, num_frames, height, width, channels = 2, 5, 128, 128, 3
57 | video_shape = (batch_size, num_frames, height, width, channels)
58 | np.random.seed(12345)
59 | input_batch = py_utils.NestedMap(
60 | video=np.random.randint(0, 255, size=video_shape)
61 | )
62 | predictions = py_utils.NestedMap(
63 | reconstructed=np.random.normal(size=video_shape),
64 | logits_real=np.random.normal(size=(batch_size, 1)),
65 | logits_fake=np.random.normal(size=(batch_size, 1)),
66 | quantizer_loss=np.random.normal(size=[]),
67 | r1_gradient_penalty=np.random.normal(size=[]),
68 | )
69 |
70 | loss_p = pax_fiddle.Config(
71 | losses.VQGANLoss,
72 | name='loss',
73 | )
74 | loss_layer = loss_p.Instantiate()
75 | prng_key = jax.random.PRNGKey(seed=123)
76 | context_p = base_layer.JaxContext.HParams(do_eval=do_eval)
77 | with base_layer.JaxContext.new_context(hparams=context_p):
78 | init_vars = loss_layer.init(prng_key, predictions, input_batch)
79 | loss_dict, updated_vars = loss_layer.apply(
80 | init_vars, predictions, input_batch, mutable=base_layer.NON_TRAINABLE
81 | )
82 | for loss in loss_dict.values():
83 | self.assertEqual((), loss.shape)
84 | self.assertNotEqual(
85 | updated_vars[base_layer.NON_TRAINABLE]['ema_fake_pred'], 0.0
86 | )
87 | self.assertNotEqual(
88 | updated_vars[base_layer.NON_TRAINABLE]['ema_real_pred'], 0.0
89 | )
90 |
91 |
92 | if __name__ == '__main__':
93 | absltest.main()
94 |
--------------------------------------------------------------------------------
/praxis/layers/video/quantizer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import jax
19 | from praxis import base_layer
20 | from praxis import pax_fiddle
21 | from praxis import test_utils
22 | from praxis.layers.video import quantizer
23 |
24 | jax.config.update("jax_threefry_partitionable", False)
25 |
26 |
27 | class QuantizerTest(test_utils.TestCase):
28 |
29 | @parameterized.parameters(
30 | (8, True, 0, 0, 0),
31 | (8, False, -0.331989, 0.10211816, -0.43410715),
32 | (16, True, 0, 0, 0),
33 | (16, False, -0.3675179, 0.1005669, -0.46808478),
34 | )
35 | def test_encode_decode_id_and_loss(
36 |       self, embedding_dim, do_eval, quantizer_loss, e_latent_loss, entropy_loss
37 | ):
38 | prng_key = jax.random.PRNGKey(seed=123)
39 | x = jax.random.normal(prng_key, (1, 3, 6, 6, embedding_dim))
40 | config = pax_fiddle.Config(
41 | quantizer.LookupFreeQuantizer,
42 | name="quantizer",
43 | embedding_dim=embedding_dim,
44 | )
45 | context_p = base_layer.JaxContext.HParams(do_eval=do_eval)
46 | with base_layer.JaxContext.new_context(hparams=context_p):
47 | lookup_free_quantizer = base_layer.instantiate(config)
48 | pax_vars = lookup_free_quantizer.init(prng_key, x)
49 | self.assertEmpty(pax_vars)
50 | value, result_dict = lookup_free_quantizer.apply(pax_vars, x)
51 | self.assertEqual(value.shape, x.shape)
52 | self.assertAllClose(x, result_dict.raw)
53 |
54 | if not do_eval:
55 | ids = result_dict.encoding_indices
56 | decoded = lookup_free_quantizer.decode_ids(ids)
57 | self.assertAllClose(decoded, result_dict.encodings)
58 | self.assertEqual(result_dict.quantizer_loss.shape, ())
59 | self.assertEqual(result_dict.e_latent_loss.shape, ())
60 | self.assertEqual(result_dict.entropy_loss.shape, ())
61 |       self.assertAllClose(result_dict.quantizer_loss, quantizer_loss)
62 | self.assertAllClose(result_dict.e_latent_loss, e_latent_loss)
63 | self.assertAllClose(result_dict.entropy_loss, entropy_loss)
64 |
65 |
66 | if __name__ == "__main__":
67 | absltest.main()
68 |
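The `LookupFreeQuantizer` exercised above appears to follow the lookup-free quantization (LFQ) idea from MAGVIT-v2 (https://arxiv.org/abs/2310.05737): each latent dimension is quantized to +/-1 by its sign, and the token id packs the sign bits, so no codebook lookup is needed. A conceptual sketch of that idea (an illustration only, not the praxis implementation):

import jax.numpy as jnp

def lfq_encode(x):                         # x: [..., d] real-valued latents
  bits = (x > 0).astype(jnp.int32)         # per-dimension sign bit
  return jnp.sum(bits * (2 ** jnp.arange(x.shape[-1])), axis=-1)

def lfq_decode(ids, d):
  bits = (ids[..., None] >> jnp.arange(d)) & 1
  return bits.astype(jnp.float32) * 2.0 - 1.0  # {0, 1} -> {-1, +1}

x = jnp.array([[0.3, -1.2, 0.7]])
print(lfq_encode(x))                 # [5]  (sign bits 1, 0, 1)
print(lfq_decode(lfq_encode(x), 3))  # [[ 1. -1.  1.]]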
--------------------------------------------------------------------------------
/praxis/layers/video/vqvae_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from absl.testing import absltest
17 | import jax
18 | from praxis import base_layer
19 | from praxis import pax_fiddle
20 | from praxis import test_utils
21 | from praxis.layers.video import quantizer
22 | from praxis.layers.video import vqvae
23 |
24 | jax.config.update("jax_threefry_partitionable", False)
25 |
26 |
27 | class VQVAETest(test_utils.TestCase):
28 |
29 | def test_auto_encoder(self):
30 | prng_key = jax.random.PRNGKey(seed=123)
31 | x = jax.random.normal(prng_key, (1, 5, 16, 16, 3))
32 | config = pax_fiddle.Config(
33 | vqvae.VQVaeModel,
34 | name="vqvae",
35 | encoder_tpl=pax_fiddle.Config(
36 | vqvae.Encoder,
37 | name="encoder",
38 | filters=128,
39 | input_dim=3,
40 | embedding_dim=8,
41 | num_res_blocks=4,
42 | temporal_downsample=(False, True, True),
43 | channel_multipliers=(1, 2, 2, 4),
44 | ),
45 | decoder_tpl=pax_fiddle.Config(
46 | vqvae.Decoder,
47 | name="decoder",
48 | filters=128,
49 | embedding_dim=8,
50 | output_dim=3,
51 | num_res_blocks=4,
52 | temporal_downsample=(False, True, True),
53 | channel_multipliers=(1, 2, 2, 4),
54 | ),
55 | quantizer_tpl=pax_fiddle.Config(
56 | quantizer.LookupFreeQuantizer,
57 | name="quantizer",
58 | embedding_dim=8,
59 | ),
60 | )
61 | vqvae_model = base_layer.instantiate(config)
62 | pax_vars = vqvae_model.init(prng_key, x)
63 | value = vqvae_model.apply(pax_vars, x)
64 | self.assertIsInstance(value, tuple)
65 | self.assertLen(value, 2)
66 | self.assertEqual(value[0].shape, x.shape)
67 | self.assertAllClose(value[1].quantizer_loss, -0.09399976)
68 | self.assertAllClose(value[1].e_latent_loss, 0.07907465)
69 | self.assertAllClose(value[1].entropy_loss, -0.17307441)
70 |
71 | def test_discriminator(self):
72 | prng_key = jax.random.PRNGKey(seed=123)
73 | x = jax.random.normal(prng_key, (1, 5, 16, 16, 3))
74 | config = pax_fiddle.Config(
75 | vqvae.Discriminator,
76 | name="discriminator",
77 | filters=128,
78 | num_frames=5,
79 | image_height=16,
80 | image_width=16,
81 | input_dim=3,
82 | blur_filter_size=3,
83 | channel_multipliers=(2, 4, 4, 4),
84 | )
85 | discriminator = base_layer.instantiate(config)
86 | pax_vars = discriminator.init(prng_key, x)
87 | value = discriminator.apply(pax_vars, x)
88 | self.assertAllClose(value, -4.26183e-05)
89 |
90 |
91 | if __name__ == "__main__":
92 | absltest.main()
93 |
--------------------------------------------------------------------------------
/praxis/lazy_loader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A LazyLoader class."""
17 |
18 | import importlib
19 | import types
20 | from absl import logging
21 |
22 |
23 | class LazyLoader(types.ModuleType):
24 | """Lazily import a module, mainly to avoid pulling in large dependencies.
25 |
26 |   `contrib` and `ffmpeg` are examples of modules that are large and not always
27 |   needed; this allows them to be loaded only when they are used.
28 | """
29 |
30 | # The lint error here is incorrect.
31 | def __init__(self, local_name, parent_module_globals, name, warning=None):
32 | self._local_name = local_name
33 | self._parent_module_globals = parent_module_globals
34 | self._warning = warning
35 |
36 |     # These members allow doctest to correctly process this module member
37 |     # without triggering self._load(), which mutates parent_module_globals and
38 |     # triggers a "dict mutated during iteration" error from doctest.py.
39 | # - for from_module()
40 | self.__module__ = name.rsplit(".", 1)[0]
41 | # - for is_routine()
42 | self.__wrapped__ = None
43 |
44 | super(LazyLoader, self).__init__(name)
45 |
46 | def _load(self):
47 | """Load the module and insert it into the parent's globals."""
48 | # Import the target module and insert it into the parent's namespace
49 | module = importlib.import_module(self.__name__)
50 | self._parent_module_globals[self._local_name] = module
51 |
52 | # Emit a warning if one was specified
53 | if self._warning:
54 | logging.warning(self._warning)
55 | # Make sure to only warn once.
56 | self._warning = None
57 |
58 | # Update this object's dict so that if someone keeps a reference to the
59 | # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
60 | # that fail).
61 | self.__dict__.update(module.__dict__)
62 |
63 | return module
64 |
65 | def __getattr__(self, item):
66 | module = self._load()
67 | return getattr(module, item)
68 |
69 | def __repr__(self):
70 |     # Careful not to trigger _load, since repr may be called in very
71 |     # sensitive places.
72 |     return f"<LazyLoader for module {self.__name__}>"
73 |
74 | def __dir__(self):
75 | module = self._load()
76 | return dir(module)
77 |
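Typical usage, mirroring `lingvo_lib.py` below: bind a heavy module lazily so importing the package stays cheap; the real import happens only on first attribute access.

# Nothing is imported at binding time:
datasource = LazyLoader('datasource', globals(), 'lingvo.core.datasource')
# The first attribute access triggers _load(), which imports the module and
# replaces this binding in the parent's globals with the real module:
# datasource.SomeClass  # (hypothetical attribute) imports lingvo.core.datasource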
--------------------------------------------------------------------------------
/praxis/lingvo_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Exposes public Lingvo symbols used by Praxis.
17 |
18 | Lingvo symbols must *not* be imported elsewhere in Praxis library code (except
19 | in test-related modules).
20 | """
21 |
22 | from lingvo.core import cluster
23 | from lingvo.core import cluster_factory
24 | from lingvo.core import hyperparams
25 | from lingvo.core import nested_map
26 | from praxis import lazy_loader
27 |
28 | # datasource is slow to import (because it imports TF), so we do it lazily.
29 | datasource = lazy_loader.LazyLoader(
30 | 'datasource', globals(), 'lingvo.core.datasource'
31 | )
32 |
33 | # Note: These are only used by LingvoInputAdaptor, and may possibly be moved
34 | # outside of Praxis in the future.
35 | current_cluster = cluster_factory.Current
36 | infeed_context_scope = cluster.InfeedContextScope
37 |
38 | # Core data-structure for aggregating tensors.
39 | NestedMap = nested_map.NestedMap
40 |
41 | # Note: HParams-related classes. This may possibly be removed post-Fiddle
42 | # migration.
43 | InstantiableParams = hyperparams.InstantiableParams
44 | HParams = hyperparams.Params
45 |
--------------------------------------------------------------------------------
/praxis/metric_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility functions for computing metrics."""
17 |
18 | import jax
19 | from jax import numpy as jnp
20 |
21 |
22 | def top_k_accuracy(
23 | top_k: int,
24 | logits: jax.Array,
25 | label_ids: jax.Array | None = None,
26 | label_probs: jax.Array | None = None,
27 | weights: jax.Array | None = None,
28 | ) -> jax.Array:
29 | """Computes the top-k accuracy given the logits and labels.
30 |
31 | Args:
32 | top_k: An int scalar, specifying the value of top-k.
33 | logits: A [..., C] float tensor corresponding to the logits.
34 | label_ids: A [...] int vector corresponding to the class labels. One of
35 |       label_ids and label_probs should be provided.
36 |     label_probs: A [..., C] float vector corresponding to the class
37 |       probabilities. Must be provided if label_ids is None.
38 | weights: A [...] float vector corresponding to the weight to assign to each
39 | example.
40 |
41 | Returns:
42 | The top-k accuracy represented as a `JTensor`.
43 |
44 | Raises:
45 |     ValueError if neither `label_ids` nor `label_probs` is provided.
46 | """
47 | if label_ids is None and label_probs is None:
48 | raise ValueError("One of label_ids and label_probs should be given.")
49 | if label_ids is None:
50 | label_ids = jnp.argmax(label_probs, axis=-1)
51 | if weights is None:
52 | weights = jnp.ones(logits.shape[:-1])
53 |
54 | values, _ = jax.lax.top_k(logits, k=top_k)
55 | threshold = jnp.min(values, axis=-1)
56 |
57 | # Reshape logits to [-1, C].
58 | logits_reshaped = jnp.reshape(logits, [-1, logits.shape[-1]])
59 |
60 | # Reshape label_ids to [-1, 1].
61 | label_ids_reshaped = jnp.reshape(label_ids, [-1, 1])
62 | logits_slice = jnp.take_along_axis(
63 | logits_reshaped, label_ids_reshaped, axis=-1
64 | )[..., 0]
65 |
66 | # Reshape logits_slice back to original shape to be compatible with weights.
67 | logits_slice = jnp.reshape(logits_slice, label_ids.shape)
68 | correct = jnp.greater_equal(logits_slice, threshold)
69 | correct_sum = jnp.sum(correct * weights)
70 | all_sum = jnp.maximum(1.0, jnp.sum(weights))
71 | return correct_sum / all_sum
72 |
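A small worked example for `top_k_accuracy` (a standalone sketch). Example 0's label has the highest logit, while example 1's label has only the second-highest, so top-1 accuracy is 0.5 and top-2 accuracy is 1.0:

from jax import numpy as jnp

logits = jnp.array([[0.1, 0.2, 0.9],
                    [0.5, 0.1, 0.7]])
label_ids = jnp.array([2, 0])
print(top_k_accuracy(1, logits, label_ids=label_ids))  # 0.5
print(top_k_accuracy(2, logits, label_ids=label_ids))  # 1.0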
--------------------------------------------------------------------------------
/praxis/optimizer_prefix_vectorization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unit tests for prefix vectorization in optimizers."""
17 |
18 | import typing
19 |
20 | from absl import logging
21 | from absl.testing import absltest
22 | import jax
23 | from jax import numpy as jnp
24 | import optax
25 | from praxis import base_layer
26 | from praxis import optimizer_prefix_vectorization as opt_vec
27 | from praxis import test_utils
28 |
29 |
30 | class OptimizerPrefixVectorizationTest(test_utils.TestCase):
31 |
32 | def test_vectorized_prefix_with_tree_map_params(self):
33 | def _opt_init(params):
34 | # Reduction over each variable. Behavior will depend on vectorization.
35 | logging.info(f'Init called with params {params}')
36 | return jax.tree.map(jnp.sum, params)
37 |
38 | def _opt_update(updates, state, params):
39 | del params
40 | return jax.tree.map(lambda u, s: u + s, updates, state), state
41 |
42 | grad_tx = optax.GradientTransformationExtraArgs(
43 | init=_opt_init, update=_opt_update
44 | )
45 |
46 | grads = base_layer.NestedMap(
47 | a=jnp.array([1, 2], dtype=jnp.float32),
48 | b=jnp.array([1, 2], dtype=jnp.float32),
49 | c=jnp.array([[1, 2], [3, 4]], dtype=jnp.float32),
50 | )
51 | variables = grads.copy()
52 | a_var_param = base_layer.WeightHParams(())
53 | a_var_param.repeat_prefix = [2]
54 | a_var_param.repeat_prefix_split_dims_mapping = [-1]
55 | b_var_param = base_layer.WeightHParams((2,))
56 | c_var_param = base_layer.WeightHParams(())
57 | c_var_param.repeat_prefix = [2, 2]
58 | c_var_param.repeat_prefix_split_dims_mapping = [('data', 'mdl'), None]
59 | var_hparams = base_layer.NestedMap(
60 | a=a_var_param, b=b_var_param, c=c_var_param
61 | )
62 |
63 | grad_tx = opt_vec.get_transformations_with_vectorized_repeat_prefix(
64 | grad_tx, var_hparams
65 | )
66 |
67 | state = grad_tx.init(variables)
68 | logging.info(state)
69 | opt_states_pspec = opt_vec.partition_params(grad_tx, var_hparams, state)
70 | logging.info('opt_states_pspec=%s', opt_states_pspec)
71 | # Computed update is 0 + state, and state is sum of each variable.
72 | update, _ = grad_tx.update(
73 | jax.tree.map(jnp.zeros_like, variables), state, variables
74 | )
75 | # Variables a and c are scalars excluding the prefix, so the update must be
76 | # equal to the initial variable values.
77 | update = typing.cast(base_layer.NestedMap, update)
78 | variables = typing.cast(base_layer.NestedMap, variables)
79 | self.assertAllClose(update.a, variables.a)
80 | self.assertAllClose(update.c, variables.c)
81 | # b is not vectorized, so the update equals the sum reduction of the initial
82 | # variable value.
83 | self.assertAllClose(
84 | update.b, jnp.zeros_like(variables.b) + jnp.sum(variables.b)
85 | )
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
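What "vectorized prefix" means for the assertions above, in plain jax terms (a hedged paraphrase, not the library code): for a variable with `repeat_prefix`, the optimizer's init/update behave as if vmapped over the prefix dims, so the sum-reduction init sees each per-prefix scalar separately, while the unvectorized variable `b` gets a single summed state.

import jax
import jax.numpy as jnp

a = jnp.array([1., 2.])      # scalar variable with repeat_prefix [2]
print(jax.vmap(jnp.sum)(a))  # [1. 2.] -- one optimizer state per prefix slice
print(jnp.sum(a))            # 3.0     -- a single state, as for variable `b`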
--------------------------------------------------------------------------------
/praxis/pip_package/build_pip_pkg.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e
4 |
5 | PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
6 |
7 | export PYTHON_VERSION="${PYTHON_VERSION:-3}"
8 | export PYTHON_MINOR_VERSION="${PYTHON_MINOR_VERSION}"
9 | export DEST="${WHEEL_FOLDER:-/tmp/wheels}"
10 |
11 | if [[ -z "${PYTHON_MINOR_VERSION}" ]]; then
12 | PYTHON="python${PYTHON_VERSION}"
13 | else
14 | PYTHON="python${PYTHON_VERSION}.${PYTHON_MINOR_VERSION}"
15 | fi
16 |
17 | function main() {
18 | ${PYTHON} setup.py bdist_wheel
19 |
20 | if [ ! -d "${DEST}" ]; then
21 | mkdir -p "${DEST}"
22 | fi
23 |
24 | cp dist/*.whl "${DEST}"
25 | echo $(date) : "=== Output wheel file is in: ${DEST}"
26 | }
27 |
28 | main "$@"
29 |
--------------------------------------------------------------------------------
/praxis/pip_package/cloudbuild-postsubmit.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '--build-arg', 'image_name=${_IMAGE_NAME}',
6 | '-f', 'praxis/pip_package/postsubmit.Dockerfile', '.'
7 | ]
8 |
9 | substitutions:
10 | _PYTHON_VERSION: '3.10'
11 | _RELEASE_VERSION: 'nightly' # or rX.Y
12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
13 | options:
14 | dynamic_substitutions: true
15 | substitution_option: 'ALLOW_LOOSE'
16 | machineType: E2_HIGHCPU_32
17 | timeout: 1800s
18 |
--------------------------------------------------------------------------------
/praxis/pip_package/cloudbuild-presubmit.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '--build-arg', 'image_name=${_IMAGE_NAME}',
6 | '-f', 'praxis/pip_package/presubmit.Dockerfile', '.'
7 | ]
8 |
9 | substitutions:
10 | _PYTHON_VERSION: '3.10'
11 | _RELEASE_VERSION: 'nightly' # or rX.Y
12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
13 | options:
14 | dynamic_substitutions: true
15 | substitution_option: 'ALLOW_LOOSE'
16 | machineType: E2_HIGHCPU_32
17 | timeout: 1800s
18 |
--------------------------------------------------------------------------------
/praxis/pip_package/cloudbuild-release.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: 'gcr.io/cloud-builders/docker'
3 | args: [
4 | 'build',
5 | '-t', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}',
6 | '-f', 'praxis/pip_package/release.Dockerfile', '.',
7 | '--build-arg', 'wheel_folder=${_WHEEL_FOLDER}',
8 | ]
9 | timeout: 3600s
10 | - name: 'gcr.io/cloud-builders/docker'
11 | args: ['push', '--all-tags', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}']
12 | timeout: 1800s
13 | - name: 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}'
14 | entrypoint: 'bash'
15 | args: ['-c', 'mv ${_WHEEL_FOLDER}/*.whl .']
16 |
17 | substitutions:
18 | _PYTHON_VERSION: '3.10'
19 | _RELEASE_VERSION: '1.4.0' # or rX.Y
20 | _IMAGE_NAME: 'praxis_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
21 | _WHEEL_FOLDER: '/tmp/wheels'
22 | options:
23 | dynamic_substitutions: true
24 | substitution_option: 'ALLOW_LOOSE'
25 | machineType: E2_HIGHCPU_8
26 | timeout: 5400s
27 | artifacts:
28 | objects:
29 | location: 'gs://pax-on-cloud-tpu-project/wheels/$(date -u +%Y%m%d)-praxis-${_RELEASE_VERSION}'
30 | paths: ['/**/*.whl']
31 |
--------------------------------------------------------------------------------
/praxis/pip_package/postsubmit.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG image_name
2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest"
3 | FROM $base_image
4 |
5 | RUN rm -rf /praxis
6 | COPY . /praxis
7 | RUN pip3 uninstall -y fiddle
8 | RUN pip3 uninstall -y flax
9 | RUN pip3 uninstall -y jax
10 | RUN pip3 install --no-deps -r /praxis/praxis/pip_package/requirements.txt
11 |
12 | RUN cd /praxis && bazel build ...
13 |
14 | RUN cd /praxis && \
15 | bazel test \
16 | --test_output=all \
17 | --test_verbose_timeout_warnings \
18 | -- \
19 | praxis/... \
20 | -praxis/layers:attentions_test \
21 | -praxis/layers:convolutions_test \
22 | -praxis/layers:ctc_objectives_test \
23 | -praxis/layers:embedding_softmax_test \
24 | -praxis/layers:models_test \
25 | -praxis/layers:ngrammer_test \
26 | -praxis/layers:normalizations_test \
27 | -praxis/layers:rnn_cell_test \
28 | -praxis/layers:transformer_models_encoder_decoder_test \
29 | -praxis/layers:transformer_models_test \
30 | -praxis/layers:transformers_test
31 |
32 | WORKDIR /
33 |
34 | CMD ["/bin/bash"]
35 |
--------------------------------------------------------------------------------
/praxis/pip_package/prepare_release.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script prepares a new release by:
4 | # 1) updating the version number in setup.py and cloudbuild-release.yaml
5 | # 2) adding a new section to RELEASE.md with the version and corresponding commit
6 |
7 | set -e -x
8 |
9 | function print_help_and_exit {
10 | echo "Usage: prepare_release.sh -x -d "
11 | echo "exp: bash prepare_release.sh -x 0.2.0 -d 20221114"
12 | exit 0
13 | }
14 |
15 | while getopts "hd:x:" opt; do
16 | case $opt in
17 | x)
18 | PRAXIS_VERSION=${OPTARG}
19 | ;;
20 | d)
21 | BUILD_DATE=${OPTARG}
22 | ;;
23 | *)
24 | print_help_and_exit
25 | ;;
26 | esac
27 | done
28 |
29 | RELEASE_NOTE="../RELEASE.md"
30 | RELEASE_NOTE_NEW="release_new.md"
31 |
32 | if [[ -z "$BUILD_DATE" ]]; then
33 | echo "Build date is required!"
34 | exit 1
35 | fi
36 |
37 | if [[ -z "$PRAXIS_VERSION" ]]; then
38 | echo "praxis version is required!"
39 | exit 1
40 | fi
41 |
42 | echo "Build date: "$BUILD_DATE
43 | echo "PRAXIS version: "$PRAXIS_VERSION
44 |
45 | sed -i "s/version='[0-9.]*'/version='$PRAXIS_VERSION'/" setup.py
46 | sed -i "s/_RELEASE_VERSION: '[0-9.]*'/_RELEASE_VERSION: '$PRAXIS_VERSION'/" cloudbuild-release.yaml
47 | gsutil cp gs://pax-on-cloud-tpu-project/wheels/"$BUILD_DATE"/praxis_commit.txt ./
48 | PRAXIS_COMMIT=$(<praxis_commit.txt)
49 | rm praxis_commit.txt
50 | echo "PRAXIS commit: "$PRAXIS_COMMIT
51 | 
52 | echo "# Version: $PRAXIS_VERSION" > $RELEASE_NOTE_NEW
53 | echo "## Major Features and Improvements" >> $RELEASE_NOTE_NEW
54 | echo "## Breaking changes" >> $RELEASE_NOTE_NEW
55 | echo "## Deprecations" >> $RELEASE_NOTE_NEW
56 | echo "## Note" >> $RELEASE_NOTE_NEW
57 | echo "* Version: $PRAXIS_VERSION" >> $RELEASE_NOTE_NEW
58 | echo "* Build Date: $BUILD_DATE" >> $RELEASE_NOTE_NEW
59 | echo "* Praxis commit: $PRAXIS_COMMIT" >> $RELEASE_NOTE_NEW
60 | RELEASE_NOTE_TMP="RELEASE.tmp.md"
61 | cat $RELEASE_NOTE_NEW $RELEASE_NOTE >> $RELEASE_NOTE_TMP
62 | rm $RELEASE_NOTE_NEW
63 | rm $RELEASE_NOTE
64 | mv $RELEASE_NOTE_TMP $RELEASE_NOTE
65 |
--------------------------------------------------------------------------------
/praxis/pip_package/presubmit.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG image_name
2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest"
3 | FROM $base_image
4 |
5 | RUN rm -rf /praxis
6 | COPY . /praxis
7 | RUN pip3 uninstall -y fiddle
8 | RUN pip3 uninstall -y flax
9 | RUN pip3 uninstall -y jax
10 | RUN pip3 install --no-deps -r /praxis/praxis/pip_package/requirements.txt
11 | RUN cd /praxis && bazel build ...
12 | # TODO: enable -praxis/layers:normalizations_test once the new Lingvo pip package is released
13 | # RUN cd /praxis && bazel test --test_output=all --test_verbose_timeout_warnings -- praxis/... -praxis/layers:transformer_models_test -praxis/layers:ngrammer_test -praxis/layers:attentions_test -praxis/layers:transformers_test -praxis/layers:models_test -praxis/layers:convolutions_test -praxis/layers:embedding_softmax_test -praxis/layers:ctc_objectives_test -praxis/layers:normalizations_test
14 | # RUN cd /praxis && bazel test --test_output=all --test_verbose_timeout_warnings -- praxis:asserts_test praxis:base_hyperparams_test
15 | RUN cd /praxis && \
16 | bazel test \
17 | --test_output=all \
18 | --test_verbose_timeout_warnings \
19 | -- \
20 | praxis/... \
21 | -praxis/layers:attentions_test \
22 | -praxis/layers:convolutions_test \
23 | -praxis/layers:ctc_objectives_test \
24 | -praxis/layers:embedding_softmax_test \
25 | -praxis/layers:models_test \
26 | -praxis/layers:ngrammer_test \
27 | -praxis/layers:normalizations_test \
28 | -praxis/layers:rnn_cell_test \
29 | -praxis/layers:transformer_models_encoder_decoder_test \
30 | -praxis/layers:transformer_models_test \
31 | -praxis/layers:transformers_test
32 |
33 |
34 | WORKDIR /
35 |
36 | CMD ["/bin/bash"]
37 |
--------------------------------------------------------------------------------
/praxis/pip_package/release.Dockerfile:
--------------------------------------------------------------------------------
1 | ARG cpu_base_image="ubuntu:22.04"
2 | ARG base_image=$cpu_base_image
3 | FROM $base_image
4 |
5 | LABEL maintainer="Pax team <pax-dev@google.com>"
6 |
7 | # Re-declare args because the args declared before FROM can't be used in any
8 | # instruction after a FROM.
9 | ARG cpu_base_image="ubuntu:22.04"
10 | ARG base_image=$cpu_base_image
11 | ARG wheel_folder
12 | ENV WHEEL_FOLDER $wheel_folder
13 | ENV PYTHON_VERSION="3"
14 | ENV PYTHON_MINOR_VERSION="10"
15 |
16 | # Pick up some TF dependencies
17 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends software-properties-common
18 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \
19 | build-essential \
20 | curl \
21 | git \
22 | pkg-config \
23 | rename \
24 | rsync \
25 | unzip \
26 | vim \
27 | && \
28 | apt-get clean && \
29 | rm -rf /var/lib/apt/lists/*
30 |
31 | # Install python 3.10
32 | RUN apt-get update && apt-get install -y \
33 | python3 python3-dev python3-pip python3-venv && \
34 | rm -rf /var/lib/apt/lists/* && \
35 | python3.10 -m pip install pip --upgrade && \
36 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 0
37 |
38 | # Make python3.10 the default python version
39 | RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 0
40 |
41 | ARG bazel_version=5.1.1
42 | # This is to install bazel, for development purposes.
43 | ENV BAZEL_VERSION ${bazel_version}
44 | RUN mkdir /bazel && \
45 | cd /bazel && \
46 | curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
47 | curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
48 | chmod +x bazel-*.sh && \
49 | ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
50 | cd / && \
51 | rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
52 |
53 | COPY . /praxis
54 | RUN mkdir $WHEEL_FOLDER
55 | RUN sed -i 's/ @ git.*//g' /praxis/requirements.in
56 | RUN pip3 install -r /praxis/requirements.in
57 |
58 | RUN cd /praxis && bazel build ...
59 |
60 | RUN cd /praxis && \
61 | bazel test \
62 | --test_output=all \
63 | --test_verbose_timeout_warnings \
64 | -- \
65 | praxis/... \
66 | -praxis/layers:attentions_test \
67 | -praxis/layers:convolutions_test \
68 | -praxis/layers:ctc_objectives_test \
69 | -praxis/layers:embedding_softmax_test \
70 | -praxis/layers:models_test \
71 | -praxis/layers:ngrammer_test \
72 | -praxis/layers:normalizations_test \
73 | -praxis/layers:transformer_models_test \
74 | -praxis/layers:transformers_test
75 |
76 | RUN cd /praxis && bash praxis/pip_package/build_pip_pkg.sh
77 |
78 | WORKDIR /
79 |
80 | CMD ["/bin/bash"]
81 |
--------------------------------------------------------------------------------
/praxis/praxis.bzl:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implements custom rules for Praxis."""
17 |
18 | # Placeholder to use until bazel supports pytype_library.
19 | def pytype_strict_library(name, **kwargs):
20 | native.py_library(name = name, **kwargs)
21 |
22 | # Placeholder to use until bazel supports pytype_strict_contrib_test.
23 | def pytype_strict_test(name, **kwargs):
24 | native.py_test(name = name, **kwargs)
25 |
--------------------------------------------------------------------------------
/praxis/pytypes_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Praxis pytype tests."""
17 |
18 | from absl.testing import absltest
19 | from jax import tree_util as jtu
20 | from praxis import pytypes
21 | from praxis import test_utils
22 |
23 |
24 | class TreesTest(test_utils.TestCase):
25 |
26 | def test_nestedmap_paths(self):
27 | """Ensure NestedMap is registered as a pytree_node correctly."""
28 | tree = pytypes.NestedMap(
29 | a=pytypes.NestedMap(x=0, y=1, zz=2),
30 | b=pytypes.NestedMap(z=1),
31 | )
32 | dict_tree = {'a': {'x': 0, 'y': 1, 'zz': 2}, 'b': {'z': 1}}
33 | self.assertSequenceEqual(
34 | jtu.tree_leaves_with_path(tree),
35 | [
36 | ((jtu.DictKey(key='a'), jtu.DictKey(key='x')), 0),
37 | ((jtu.DictKey(key='a'), jtu.DictKey(key='y')), 1),
38 | ((jtu.DictKey(key='a'), jtu.DictKey(key='zz')), 2),
39 | ((jtu.DictKey(key='b'), jtu.DictKey(key='z')), 1),
40 | ],
41 | )
42 | self.assertSequenceEqual(
43 | jtu.tree_leaves_with_path(dict_tree), jtu.tree_leaves_with_path(tree)
44 | )
45 |
46 |
47 | if __name__ == '__main__':
48 | absltest.main()
49 |
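
The practical upshot of that registration, shown as a hypothetical snippet (not part of the test file): jax.tree_util functions traverse a NestedMap exactly like the equivalent plain dict, and reconstruction preserves the NestedMap type, so attribute access still works on the result.

    # Hypothetical illustration -- not part of the repo.
    from jax import tree_util as jtu
    from praxis import pytypes

    tree = pytypes.NestedMap(a=pytypes.NestedMap(x=0, y=1))
    doubled = jtu.tree_map(lambda v: v * 2, tree)
    assert doubled.a.y == 2          # attribute access still works after tree_map
    assert doubled['a']['x'] == 0    # and so does dict-style access
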
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | # To update requirements.txt for praxis and paxml, run:
2 | # cd ../../paxml/pip_package && bash ./compile_requirements.sh
3 |
4 | absl-py
5 | # Enforce a newer chex version to avoid deprecated jax syntax.
6 | chex>=0.1.85
7 | clu @ git+https://github.com/google/CommonLoopUtils
8 | einops
9 | etils
10 | fiddle @ git+https://github.com/google/fiddle
11 | flax @ git+https://github.com/google/flax
12 | jax @ git+https://github.com/google/jax
13 | jax-bitempered-loss
14 | jaxtyping
15 | lingvo
16 | numpy<2
17 | optax
18 | optax-shampoo
19 | opt-einsum
20 | # Ensure sentencepiece is compatible with protobuf 3.19.6.
21 | sentencepiece==0.1.99
22 | tensorflow-datasets==4.8.3
23 | tensorflow-metadata==1.12.0
24 | tensorflow-text~=2.9.0
25 | tensorflow~=2.9.2
26 | tfds-nightly==4.8.3.dev202303280045
27 | typeguard
28 |
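
Note how the Dockerfile above consumes this file: its sed -i 's/ @ git.*//g' step strips the direct-from-head references before pip resolution, e.g.:

    flax @ git+https://github.com/google/flax   ->   flax
    jax @ git+https://github.com/google/jax     ->   jax

so inside the image those packages resolve from PyPI rather than GitHub head.
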
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup.py file for praxis."""
17 |
18 | import os
19 | from setuptools import find_packages
20 | from setuptools import setup
21 |
22 | # Set this envvar to skip installing packages from head ("pkg @ git+..."),
23 | # which could otherwise overwrite existing installs of those packages, e.g., jax.
24 | SKIP_HEAD_INSTALLS = os.environ.get('SKIP_HEAD_INSTALLS', '')
25 |
26 | def _get_requirements():
27 | """Parses requirements.txt file."""
28 | install_requires_tmp = []
29 | with open(
30 | os.path.join(os.path.dirname(__file__), './requirements.in'), 'r'
31 | ) as f:
32 | for line in f:
33 | package_name = line.strip()
34 | # Skip empty lines, comments, and (with SKIP_HEAD_INSTALLS set) "@ git" installs.
35 | if (
36 | not package_name
37 | or package_name[0] == '#'
38 | or (' @ ' in package_name and SKIP_HEAD_INSTALLS)
39 | ):
40 | continue
41 | else:
42 | install_requires_tmp.append(package_name)
43 | return install_requires_tmp
44 |
45 |
46 | install_requires = _get_requirements()
47 |
48 | setup(
49 | name='praxis',
50 | version='1.4.0',
51 | description=(
52 | 'Functionality such as layers for building neural networks in JAX.'
53 | ),
54 | author='PAX team',
55 | author_email='pax-dev@google.com',
56 | packages=find_packages(),
57 | python_requires='>=3.10',
58 | install_requires=install_requires,
59 | url='https://github.com/google/praxis',
60 | license='Apache-2.0',
61 | classifiers=[
62 | 'Programming Language :: Python :: 3.10',
63 | 'Programming Language :: Python :: 3.11',
64 | ],
65 | zip_safe=False,
66 | )
67 |
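
A usage sketch for SKIP_HEAD_INSTALLS (command lines are illustrative, and assume the environment variable is visible to the build step):

    # Hypothetical invocations:
    SKIP_HEAD_INSTALLS=1 pip install .   # drops the "pkg @ git+..." entries (jax, flax, ...)
    pip install .                        # default: also installs the head versions
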
--------------------------------------------------------------------------------