├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── RELEASE.md ├── WORKSPACE ├── praxis ├── AUTHORS ├── BUILD ├── __init__.py ├── asserts.py ├── asserts_test.py ├── base_hyperparams.py ├── base_hyperparams_test.py ├── base_input.py ├── base_input_test.py ├── base_layer.py ├── base_layer_test.py ├── base_model.py ├── beam_search.py ├── beam_search_test.py ├── build-visibility.bzl ├── contrib │ └── gpu │ │ └── scripts_gpu │ │ ├── lora_layers.py │ │ └── models.py ├── decoder_hparams.py ├── decoder_utils.py ├── decoder_utils_test.py ├── fiddle_tags.py ├── flat_beam_search.py ├── flat_beam_search_test.py ├── flax_utils.py ├── gshard_utils.py ├── layers │ ├── BUILD │ ├── __init__.py │ ├── activations.py │ ├── activations_test.py │ ├── adapters.py │ ├── adapters_test.py │ ├── attentions.py │ ├── attentions_test.py │ ├── augmentations.py │ ├── augmentations_test.py │ ├── base_ops.py │ ├── bregman.py │ ├── bregman_test.py │ ├── chain │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── chain.py │ │ ├── chain_extensions.py │ │ └── chain_test.py │ ├── checkpoint_policy.py │ ├── chunk.py │ ├── chunk_test.py │ ├── conformers.py │ ├── conformers_test.py │ ├── convolutions.py │ ├── convolutions_test.py │ ├── ctc_objectives.py │ ├── ctc_objectives_test.py │ ├── einsum.py │ ├── einsum_test.py │ ├── embedding_softmax.py │ ├── embedding_softmax_test.py │ ├── flax_adapter.py │ ├── flax_adapter_test.py │ ├── frnn.py │ ├── frnn_test.py │ ├── glam.py │ ├── gpu_fast_attention.py │ ├── grok.py │ ├── grouped_query_attention.py │ ├── grouped_query_attention_test.py │ ├── injection │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── fp8_nvidia_gpu.py │ │ └── fp8_nvidia_gpu_test.py │ ├── linears.py │ ├── linears_test.py │ ├── losses.py │ ├── mobilenet.py │ ├── models.py │ ├── models_test.py │ ├── multi_query_attention.py │ ├── multi_query_attention_test.py │ ├── ngrammer.py │ ├── ngrammer_test.py │ ├── normalizations.py │ ├── normalizations_test.py │ ├── pipeline.py │ ├── pipeline_test.py │ ├── poolings.py │ ├── poolings_test.py │ ├── quantization │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── attentions.py │ │ ├── attentions_test.py │ │ ├── automl_select.py │ │ ├── automl_select_test.py │ │ ├── conformers.py │ │ ├── conformers_test.py │ │ ├── convolutions.py │ │ ├── convolutions_test.py │ │ ├── einsum.py │ │ ├── einsum_test.py │ │ ├── embedding_softmax.py │ │ ├── embedding_softmax_test.py │ │ ├── linears.py │ │ ├── linears_test.py │ │ ├── multi_query_attention.py │ │ ├── multi_query_attention_test.py │ │ ├── ngrammer.py │ │ ├── ngrammer_test.py │ │ ├── operations.py │ │ ├── operations_test.py │ │ ├── optimization.py │ │ ├── optimization_test.py │ │ ├── overflow_check.py │ │ ├── overflow_check_test.py │ │ ├── quantization_hparams.py │ │ ├── quantization_test.py │ │ ├── quantize.py │ │ ├── quantize_test.py │ │ ├── quantizer.py │ │ ├── quantizer_test.py │ │ ├── searchable.py │ │ ├── searchable_test.py │ │ ├── sparsity │ │ │ ├── BUILD │ │ │ ├── __init__.py │ │ │ ├── attentions_test.py │ │ │ ├── linears_test.py │ │ │ ├── sparsifier.py │ │ │ ├── sparsifier_test.py │ │ │ ├── sparsity.py │ │ │ ├── sparsity_hparams.py │ │ │ ├── sparsity_modes.py │ │ │ └── sparsity_test.py │ │ ├── tests │ │ │ ├── attention_projection_aqt_test.py │ │ │ ├── attention_projection_fq_test.py │ │ │ ├── attention_projection_ptq_test.py │ │ │ ├── combined_qkv_projection_aqt_test.py │ │ │ ├── combined_qkv_projection_fq_test.py │ │ │ ├── combined_qkv_projection_ptq_test.py │ │ │ ├── dotproduct_attention_aqt_test.py │ │ │ ├── linear_ptq_test.py │ │ │ ├── 
one_headed_attention_projection_aqt_test.py │ │ │ └── test_util.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── quantizer.py │ ├── quantizer_objectives.py │ ├── quantizer_objectives_test.py │ ├── quantizer_test.py │ ├── repeats.py │ ├── repeats_test.py │ ├── resnets.py │ ├── rnn_cell.py │ ├── rnn_cell_test.py │ ├── searchable.py │ ├── searchable_test.py │ ├── sedd │ │ └── sedd.ipynb │ ├── sequential.py │ ├── sequential_test.py │ ├── sharding.py │ ├── shared_layers_test.py │ ├── spectrum_augmenter.py │ ├── spectrum_augmenter_test.py │ ├── ssm.py │ ├── ssm_test.py │ ├── ssm_transformers.py │ ├── ssm_transformers_test.py │ ├── stats.py │ ├── stats_test.py │ ├── stochastics.py │ ├── stochastics_test.py │ ├── test_layers.py │ ├── transformer_models.py │ ├── transformer_models_encoder_decoder_test.py │ ├── transformer_models_test.py │ ├── transformers.py │ ├── transformers_test.py │ ├── vanillanets.py │ ├── vanillanets_test.py │ ├── video │ │ ├── BUILD │ │ ├── __init__.py │ │ ├── enc_dec_3dcnn.py │ │ ├── enc_dec_3dcnn_test.py │ │ ├── losses.py │ │ ├── losses_test.py │ │ ├── quantizer.py │ │ ├── quantizer_test.py │ │ ├── vqvae.py │ │ └── vqvae_test.py │ ├── vits.py │ └── vits_test.py ├── lazy_loader.py ├── lingvo_lib.py ├── metric_utils.py ├── optimizer_prefix_vectorization.py ├── optimizer_prefix_vectorization_test.py ├── optimizers.py ├── optimizers_test.py ├── pax_fiddle.py ├── pax_fiddle_test.py ├── pip_package │ ├── build_pip_pkg.sh │ ├── cloudbuild-postsubmit.yaml │ ├── cloudbuild-presubmit.yaml │ ├── cloudbuild-release.yaml │ ├── postsubmit.Dockerfile │ ├── prepare_release.sh │ ├── presubmit.Dockerfile │ ├── release.Dockerfile │ └── requirements.txt ├── praxis.bzl ├── py_utils.py ├── py_utils_test.py ├── pytypes.py ├── pytypes_test.py ├── sample_decode.py ├── sample_decode_test.py ├── schedules.py ├── schedules_test.py ├── test_utils.py ├── token_samplers.py ├── token_samplers_test.py ├── trees.py └── trees_test.py ├── requirements.in └── setup.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Praxis 2 | 3 | What is Praxis? 
4 | Praxis is the layer library for [Pax](https://github.com/google/paxml/). While Praxis is optimized for ML at scale, Praxis has a goal to be usable by other JAX-based ML projects. 5 | 6 | 7 | Some examples of layers to be folded into Praxis are in the [praxis/layers/](https://github.com/google/praxis/tree/main/praxis/layers) directory. 8 | 9 | 10 | Copyright 2022 Google LLC 11 | 12 | Licensed under the Apache License, Version 2.0 (the "License"); 13 | you may not use this file except in compliance with the License. 14 | You may obtain a copy of the License at 15 | 16 | https://www.apache.org/licenses/LICENSE-2.0 17 | 18 | Unless required by applicable law or agreed to in writing, software 19 | distributed under the License is distributed on an "AS IS" BASIS, 20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | See the License for the specific language governing permissions and 22 | limitations under the License. 23 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Version: 1.4.0 2 | ## Major Features and Improvements 3 | ## Breaking changes 4 | ## Deprecations 5 | ## Note 6 | * Version: 1.4.0 7 | * Build Date: 20240408 8 | * Praxis commit: 794623a65934cac98aaa962afcbf9d40a51c02f8 9 | # Version: 1.3.1 10 | ## Major Features and Improvements 11 | ## Breaking changes 12 | ## Deprecations 13 | ## Note 14 | * Version: 1.3.1 15 | * Build Date: 20240220 16 | * Praxis commit: a935384da88ec739d48696c2d0a55657a183de00 17 | # Version: 1.3.0 18 | ## Major Features and Improvements 19 | ## Breaking changes 20 | ## Deprecations 21 | ## Note 22 | * Version: 1.3.0 23 | * Build Date: 20240216 24 | * Praxis commit: ca8ad2ba8be5b77092a8e5f0a0774a39cc3ad766 25 | # Version: 1.2.0 26 | ## Major Features and Improvements 27 | ## Breaking changes 28 | ## Deprecations 29 | ## Note 30 | * Version: 1.2.0 31 | * Build Date: 20231016 32 | * Praxis commit: 7bd63412bf86a68e09fcd9455f76a4909d19377e 33 | # Version: 1.1.0 34 | ## Major Features and Improvements 35 | * Move to python 3.10 as the minimal python requirement (previously on python 3.8). 36 | * Add various quantization support. 37 | * Add support for fine-grained weight sparsity in praxis layers. 38 | * Add support for Direct Preference Optimization (DPO). 39 | ## Note 40 | * Version: 1.1.0 41 | * Build Date: 20230824 42 | * Praxis commit: 2a7d0407871502a1d79dcd0e01411e73f1d15d36 43 | # Version: 1.0.0 44 | ## Major Features and Improvements 45 | * **Fiddle** - Praxis layers and BaseParameterizable are now configured with [Fiddle](https://github.com/google/fiddle), a Python-first configuration library. Fiddle reduces boilerplate, and adds productivity features including history tracking, graphviz visualization, support for aliasing objects, and more. 46 | * **CLI Experiment and Data Injectability** - Enable Pax users to select which experiments to run without the need to recompile for each experiment. Using a CLI interface based on Fiddle, users can override subsets of the experiment’s canonical dataset. 47 | * **CLU Metrics** - Praxis has adopted CLU metrics as its standard metric interface. This allows other Jax/Flax codebases that have CLU metrics to use them in Praxis. 48 | * **Flax Interoperability** - Praxis now supports shape inference, __call__ for forward propagation, and has adopted Linen’s AxisMetadata for its mid-level sharding APIs. 
These changes improve interoperability with other Flax-based libraries such as T5X. 49 | ## Note 50 | * Version: 1.0.0 51 | * Build Date: 20230329 52 | * Praxis commit: 621c2ca7bfcd0e21ea118a3d8e40e29b48313c0c 53 | # Version: 0.4.0 54 | ## Note 55 | * Version: 0.4.0 56 | * Build Date: 20230329 57 | * Praxis commit: 621c2ca7bfcd0e21ea118a3d8e40e29b48313c0c 58 | # Version: 0.3.0 59 | ## Major Features and Improvements 60 | * Fiddle migration 61 | * Improve numerical stability when using bfloat16 62 | * Improve and add new functionalities to decoding algorithms 63 | * Improve quantization support and add quantization aware training 64 | * Improve streaming support 65 | * Move learners / sgf and train_states modules to paxml 66 | * Misc renaming / API updates for consistency 67 | ## Note 68 | * Version: 0.3.0 69 | * Build Date: 20230201 70 | * Praxis commit: 9e1d13d888ac18a567e249ddb41e6b1bd1fe505a 71 | # Version: 0.2.1 72 | ## Note 73 | * Version: 0.2.1 74 | * Build Date: 20221121 75 | * Praxis commit: f7e98026c1c5ecbc6e4aff175621d443fa37fcf2 76 | # Version: 0.2.0 77 | ## Major Features and Improvements 78 | * Preparatory work for Fiddle integration 79 | * Support for Flax shape inference 80 | * Support for Jax Array 81 | * Optimizer additions and improvements: 82 | - HeroLion 83 | - ShardedAdagrad 84 | - ShardedStaticAccumulator optimizer wrapper to do a fixed number of gradient 85 | accumulations 86 | - Shampoo improvements 87 | - Fix for multi-optimizer following the introduction of optax.MaskedNode 88 | - Improve sanitization of NaNs/Infs gradients during training 89 | * Decoding 90 | - Add support for ExtendNSteps 91 | - Add beam search support for sequence models 92 | - Set prefix_lengths by input_indicator for PrefixLM 93 | - Move decode post-processing tensors into host memory 94 | * Summaries 95 | - Add support for verbosity level 96 | - Add more knobs to the learner to control summary generation 97 | ## Deprecations 98 | * Disallow hparams override in setup() 99 | * Hparams and layer names must now be distinct 100 | ## Note 101 | * Version: 0.2.0 102 | * Build Date: 20221114 103 | * Praxis commit: 413da1ad8148f27faebca119f8c5deedca66228b 104 | # Version: 0.1.0 105 | ## Major Features and Improvements 106 | ## Breaking changes 107 | ## Deprecations 108 | ## Note 109 | * Version: 0.1.0 110 | * Build Date: 20220702 111 | * Commit: 112 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Workspace file for Praxis.""" 17 | 18 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 19 | 20 | http_archive( 21 | name = "bazel_skylib", 22 | urls = [ 23 | "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", 24 | "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", 25 | ], 26 | sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728", 27 | ) 28 | 29 | load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") #buildifier: disable=load-on-top 30 | 31 | http_archive( 32 | name = "com_google_protobuf", # v3.19.4 // 2022-01-2 33 | sha256 = "3bd7828aa5af4b13b99c191e8b1e884ebfa9ad371b0ce264605d347f135d2568", 34 | strip_prefix = "protobuf-3.19.4", 35 | url = "https://github.com/protocolbuffers/protobuf/archive/v3.19.4.tar.gz", 36 | ) 37 | 38 | bazel_skylib_workspace() 39 | 40 | http_archive( 41 | name = "rules_python", 42 | sha256 = "cdf6b84084aad8f10bf20b46b77cb48d83c319ebe6458a18e9d2cebf57807cdd", 43 | strip_prefix = "rules_python-0.8.1", 44 | url = "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.8.1.tar.gz", 45 | ) 46 | 47 | http_archive( 48 | name = "zlib", 49 | build_file = "@com_google_protobuf//:third_party/zlib.BUILD", 50 | sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", 51 | strip_prefix = "zlib-1.2.11", 52 | urls = [ 53 | "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz", 54 | "https://zlib.net/zlib-1.2.11.tar.gz", 55 | ], 56 | ) 57 | -------------------------------------------------------------------------------- /praxis/AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Pax's significant contributors. 2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 7 | Google LLC 8 | NVIDIA Corporation -------------------------------------------------------------------------------- /praxis/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /praxis/build-visibility.bzl: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Open source BUILD visibility used for Praxis.""" 17 | 18 | JAX_VISIBILITY = ["//visibility:public"] 19 | -------------------------------------------------------------------------------- /praxis/contrib/gpu/scripts_gpu/lora_layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from jax import numpy as jnp 17 | from praxis import base_layer 18 | from praxis import pax_fiddle 19 | from praxis import pytypes 20 | from praxis.layers.attentions import AttentionProjection, CombinedQKVProjectionLayer 21 | from praxis.layers.linears import Linear 22 | 23 | 24 | WeightInit = base_layer.WeightInit 25 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] 26 | template_field = base_layer.template_field 27 | WeightHParams = base_layer.WeightHParams 28 | JTensor = pytypes.JTensor 29 | 30 | 31 | class LoraTheta(base_layer.Theta): 32 | 33 | def __init__(self, module): 34 | self.module = module 35 | 36 | def _lora_initialized(self): 37 | if ( 38 | self.module.has_variable("params", "lora_a") 39 | and self.module.has_variable("params", "lora_b") 40 | and "lora_a" in self.module._weight_hparams 41 | and "lora_b" in self.module._weight_hparams 42 | ): 43 | return True 44 | else: 45 | return False 46 | 47 | def _lorafy_var(self, var): 48 | lora_a = super().__getattr__("lora_a") 49 | lora_b = super().__getattr__("lora_b") 50 | new_var = self.module.einsum("...dr,...nr->...dn", lora_a, lora_b) 51 | new_var = jnp.reshape(new_var, var.shape) 52 | new_var += var 53 | return new_var 54 | 55 | def __getattr__(self, k): 56 | var = super().__getattr__(k) 57 | if not self._lora_initialized(): 58 | return var 59 | 60 | if k == "w": 61 | return self._lorafy_var(var) 62 | 63 | return var 64 | 65 | def __getitem__(self, k): 66 | var = super().__getattr__(k) 67 | if not self._lora_initialized(): 68 | return var 69 | 70 | if k == "w": 71 | return self._lorafy_var(var) 72 | 73 | return var 74 | 75 | 76 | class LoraThetaDescriptor: 77 | """Dot syntax accession descriptor.""" 78 | 79 | def __get__(self, obj, objtype=None): 80 | return LoraTheta(obj) 81 | 82 | 83 | class LoraLinear(Linear): 84 | rank: int = 0 85 | lora_init: WeightInit | None = None 86 | theta = LoraThetaDescriptor() 87 | 88 | def setup(self) -> None: 89 | lora_init = self.lora_init if self.lora_init else self.weight_init 90 | 91 | super().setup() 92 | self.create_variable( 93 | "lora_a", 94 | WeightHParams( 95 | 
shape=[self.input_dims, self.rank], 96 | init=lora_init, 97 | mesh_shape=self.mesh_shape, 98 | tensor_split_dims_mapping=[None, None], 99 | ), 100 | ) 101 | self.create_variable( 102 | "lora_b", 103 | WeightHParams( 104 | shape=[self.output_dims, self.rank], 105 | init=WeightInit.Constant(scale=0.0), 106 | mesh_shape=self.mesh_shape, 107 | tensor_split_dims_mapping=[None, None], 108 | ), 109 | ) 110 | 111 | 112 | class LoraAttentionProjection(AttentionProjection): 113 | rank: int = 0 114 | lora_init: WeightInit | None = None 115 | theta = LoraThetaDescriptor() 116 | 117 | def setup(self) -> None: 118 | super().setup() 119 | w_weight_params = self._weight_hparams["w"] 120 | lora_init = self.lora_init if self.lora_init else w_weight_params.init 121 | 122 | self.create_variable( 123 | "lora_a", 124 | WeightHParams( 125 | shape=[self.input_dim, self.rank], 126 | init=lora_init, 127 | mesh_shape=self.mesh_shape, 128 | tensor_split_dims_mapping=[ 129 | None, 130 | None, 131 | ], 132 | ), 133 | ) 134 | self.create_variable( 135 | "lora_b", 136 | WeightHParams( 137 | shape=[self.dim_per_head * self.num_heads, self.rank], 138 | init=WeightInit.Constant(scale=0.0), 139 | mesh_shape=self.mesh_shape, 140 | tensor_split_dims_mapping=[ 141 | None, 142 | None, 143 | ], 144 | ), 145 | ) 146 | 147 | 148 | class LoraCombinedQKVProjection(CombinedQKVProjectionLayer): 149 | rank: int = 0 150 | lora_init: WeightInit | None = None 151 | theta = LoraThetaDescriptor() 152 | 153 | def setup(self) -> None: 154 | super().setup() 155 | w_weight_params = self._weight_hparams["w"] 156 | lora_init = self.lora_init if self.lora_init else w_weight_params.init 157 | 158 | self.create_variable( 159 | "lora_a", 160 | WeightHParams( 161 | shape=[3, self.input_dim, self.rank], 162 | init=lora_init, 163 | mesh_shape=self.mesh_shape, 164 | tensor_split_dims_mapping=[None, None, None], 165 | ), 166 | ) 167 | self.create_variable( 168 | "lora_b", 169 | WeightHParams( 170 | shape=[3, self.dim_per_head * self.num_heads, self.rank], 171 | init=WeightInit.Constant(scale=0.0), 172 | mesh_shape=self.mesh_shape, 173 | tensor_split_dims_mapping=[None, None, None], 174 | ), 175 | ) 176 | -------------------------------------------------------------------------------- /praxis/fiddle_tags.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Declares Fiddle tags for Praxis. 17 | 18 | Fiddle tags are essentially named collections of configuration values. They can 19 | diverge, but can also all be set at once to the same value. 20 | """ 21 | 22 | import fiddle as fdl 23 | 24 | 25 | class BaseDType(fdl.Tag): 26 | """Base JAX DType. 27 | 28 | Usually you want ParamsDType or ActivationDType, but we provide a base class 29 | here in case the user wants to quickly set all DTypes. 
30 | """ 31 | 32 | 33 | class ParamsDType(BaseDType): 34 | """DType for parameters.""" 35 | 36 | 37 | class ActivationDType(BaseDType): 38 | """DType for parameters.""" 39 | 40 | 41 | class WeightInit(fdl.Tag): 42 | """Weight initializer class. 43 | 44 | Tagged values should generally be base_layer.WeightInit. 45 | """ 46 | 47 | 48 | class ParamsSplitDimsMapping(fdl.Tag): 49 | """SplitDimsMapping for parameters.""" 50 | 51 | 52 | class ActivationSplitDimsMapping(fdl.Tag): 53 | """SplitDimsMapping for activations.""" 54 | 55 | 56 | class DropoutRate(fdl.Tag): 57 | """Tag for dropout rates.""" 58 | -------------------------------------------------------------------------------- /praxis/flat_beam_search_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Unit tests for flat_beam_search.""" 17 | 18 | from absl.testing import absltest 19 | from jax import numpy as jnp 20 | import numpy as np 21 | from praxis import flat_beam_search 22 | from praxis import test_utils 23 | 24 | 25 | class FlatBeamSearchHelperTest(test_utils.TestCase): 26 | 27 | def test_update_mask_without_time_step(self): 28 | beam_mask = jnp.array( 29 | [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]], 30 | dtype=jnp.float32) 31 | hyp_id = jnp.array([[0, 3, 0, 1]], jnp.float32) 32 | update_beam_mask = flat_beam_search.update_beam_mask( 33 | beam_mask, hyp_id, time_step=None) 34 | self.assertArraysEqual( 35 | update_beam_mask, 36 | np.array([[[1, 0, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]], 37 | dtype=np.float32)) 38 | 39 | def test_update_mask_without_step2(self): 40 | beam_mask = jnp.array( 41 | [[ 42 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 43 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0], 44 | [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], 45 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], 46 | ]], 47 | dtype=jnp.float32, 48 | ) 49 | hyp_id = jnp.array([[3, 3, 0, 1]], jnp.float32) 50 | update_beam_mask = flat_beam_search.update_beam_mask( 51 | beam_mask, hyp_id, time_step=None) 52 | self.assertArraysEqual( 53 | update_beam_mask, 54 | np.array([[[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], 55 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], 56 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 57 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0]]], 58 | dtype=np.float32)) 59 | 60 | def test_update_mask_with_step(self): 61 | beam_mask = jnp.array( 62 | [[ 63 | [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 64 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0], 65 | [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], 66 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], 67 | ]], 68 | dtype=jnp.float32, 69 | ) 70 | hyp_id = jnp.array([[3, 3, 0, 1]], jnp.float32) 71 | update_beam_mask = flat_beam_search.update_beam_mask( 72 | beam_mask, hyp_id, time_step=2) 73 | self.assertArraysEqual( 74 | update_beam_mask, 75 | np.array([[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0], 76 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], 77 | [1, 0, 0, 0, 1, 0, 
0, 0, 0, 0, 1, 0], 78 | [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1]]], 79 | dtype=np.float32)) 80 | 81 | def test_get_final_output_ids(self): 82 | beam_mask = jnp.array( 83 | [[[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0, 1, 0], 84 | [1, 0, 0, 0, 1, 0, 0, 0], [0, 0, 1, 0, 0, 1, 0, 0]]], 85 | dtype=jnp.int32) 86 | output_ids = jnp.array([[[0, 5], [1, 6], [2, 7], [3, 8]]], dtype=np.int32) 87 | final_output_ids = flat_beam_search.get_final_output_ids( 88 | beam_mask, output_ids) 89 | 90 | self.assertArraysEqual( 91 | final_output_ids, 92 | np.array([[[2, 8], [2, 7], [0, 5], [2, 6]]], dtype=np.int32)) 93 | 94 | def test_update_topk_scores_with_eos(self): 95 | end_mask = jnp.array([[[0, 1], [0, 1]]], dtype=np.float32) 96 | cur_mask = jnp.array([[[1, 0], [0, 1]]], dtype=np.float32) 97 | 98 | end_scores = jnp.array([[0, 1]], dtype=np.float32) 99 | cur_scores = jnp.array([[2, 3]], dtype=np.float32) 100 | 101 | end_scores_norm = jnp.array([[2, 0]], dtype=np.float32) 102 | cur_scores_norm = jnp.array([[3, 1]], dtype=np.float32) 103 | 104 | (output_mask, output_scores, 105 | output_scores_norm) = flat_beam_search.update_topk_scores_with_eos( 106 | (end_mask, end_scores, end_scores_norm), 107 | (cur_mask, cur_scores, cur_scores_norm)) 108 | 109 | self.assertArraysEqual(output_mask, 110 | np.array([[[1, 0], [0, 1]]], dtype=np.float32)) 111 | self.assertArraysEqual(output_scores, np.array([[2, 0]], dtype=np.float32)) 112 | self.assertArraysEqual(output_scores_norm, 113 | np.array([[3, 2]], dtype=np.float32)) 114 | 115 | 116 | if __name__ == '__main__': 117 | absltest.main() 118 | -------------------------------------------------------------------------------- /praxis/layers/activations_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for activations.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from jax import numpy as jnp 21 | from praxis import base_layer 22 | from praxis import pax_fiddle 23 | from praxis import test_utils 24 | from praxis.layers import activations 25 | 26 | 27 | class ActivationsTest(test_utils.TestCase): 28 | 29 | def _run(self, p, inputs, outputs): 30 | self.assertEqual(outputs, base_layer.instantiate(p.set(name='n'))(inputs)) 31 | 32 | def test_Exp(self): 33 | inputs = jnp.array([1.0]) 34 | self._run(pax_fiddle.Config(activations.Exp), inputs, jnp.exp(inputs)) 35 | 36 | def test_Softplus(self): 37 | inputs = jnp.array([1.0]) 38 | self._run( 39 | pax_fiddle.Config(activations.Softplus), inputs, jax.nn.softplus(inputs) 40 | ) 41 | 42 | def test_get_subclass_by_name(self): 43 | layer_cls = activations.BaseActivation.get_subclass_by_name('swish') 44 | self.assertIs(layer_cls, activations.Swish) 45 | 46 | layer_cls = activations.BaseActivation.get_subclass_by_name('cubed_relu') 47 | self.assertIs(layer_cls, activations.CubedReLU) 48 | 49 | def test_get_subclass_by_name_ambiguous(self): 50 | class Exp(activations.BaseActivation): # pylint: disable=unused-variable 51 | pass 52 | 53 | with self.assertRaises(KeyError): 54 | activations.BaseActivation.get_subclass_by_name('exp') 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /praxis/layers/base_ops.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Op wrappers that can be used instead of native JAX ops to allow injection of custom ops. 17 | 18 | This is useful for quantization, sparsity and possibly other techniques. 19 | """ 20 | 21 | import jax.numpy as jnp 22 | from praxis import base_layer 23 | from praxis import pytypes 24 | 25 | JTensor = pytypes.JTensor 26 | 27 | 28 | # These wrappers allow (with a help of fiddle) to inject custom ops 29 | # implementations into various Pax layers. 30 | # Base ops injection are useful if one wants to inject quantized or sparse op. 31 | # The main benefit is a reduction of the need for layer forking. 32 | # 33 | # It also allows these custom ops to use state e.g.: 34 | # to use variables for calibration stats and 35 | # to use random numbers e.g.: for stochastic rounding. 
36 | 37 | 38 | class EinsumOp(base_layer.BaseLayer): 39 | """Wrapper around jnp.einsum used in standard Pax layers.""" 40 | 41 | def __call__(self, equation: str, *args: JTensor) -> JTensor: 42 | return jnp.einsum(equation, *args) 43 | 44 | 45 | class EinsumGatedOp(base_layer.BaseLayer): 46 | """Wrapper around two jnp.einsum for gated FFN.""" 47 | 48 | def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]: 49 | assert len(args) == 3 50 | x, k, k_gated = args 51 | y = jnp.einsum(equation, x, k) 52 | y_gated = jnp.einsum(equation, x, k_gated) 53 | return y, y_gated 54 | 55 | 56 | class ArrayLookup(base_layer.BaseLayer): 57 | """Wrapper around array indexing as used in embedding lookup.""" 58 | 59 | def __call__(self, x: JTensor, idx) -> JTensor: 60 | return x[idx] 61 | -------------------------------------------------------------------------------- /praxis/layers/bregman_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for Praxis Bregman PCA layer.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | from praxis import base_layer 24 | from praxis import pax_fiddle 25 | from praxis import py_utils 26 | from praxis import test_utils 27 | from praxis.layers import bregman 28 | 29 | instantiate = base_layer.instantiate 30 | to_np = test_utils.to_np 31 | NON_TRAINABLE = base_layer.NON_TRAINABLE 32 | 33 | 34 | class BregmanTest(test_utils.TestCase): 35 | 36 | def setUp(self): 37 | super().setUp() 38 | np.random.seed(123456) 39 | 40 | @parameterized.parameters( 41 | ('IDENTITY', 0.99, 0.0, 0.0, True), 42 | ('IDENTITY', 1.0, 0.01, 0.0, True), 43 | ('IDENTITY', 1.0, 0.01, 0.01, False), 44 | ('IDENTITY', 0.99, 0.01, 0.01, True), 45 | ('LEAKY_RELU', 0.99, 0.0, 0.0, False), 46 | ('LEAKY_RELU', 1.0, 0.01, 0.0, True), 47 | ('LEAKY_RELU', 1.0, 0.01, 0.01, False), 48 | ('LEAKY_RELU', 0.99, 0.01, 0.01, True), 49 | ('SOFTMAX', 0.99, 0.0, 0.0, True), 50 | ('SOFTMAX', 1.0, 0.01, 0.0, True), 51 | ('SOFTMAX', 1.0, 0.01, 0.01, False), 52 | ('SOFTMAX', 0.99, 0.01, 0.01, False), 53 | ) 54 | def test_bregman_layer( 55 | self, 56 | activation, 57 | mean_beta, 58 | coefficients_lr, 59 | components_lr, 60 | constant_lr_schedule, 61 | ): 62 | """Tests layer construction and the expected outputs.""" 63 | activation_type = getattr(bregman.ActivationType, activation) 64 | p = pax_fiddle.Config( 65 | bregman.BregmanPCA, 66 | name='bregman_pca', 67 | num_components=3, 68 | input_dims=[8, 10], 69 | activation_type=activation_type, 70 | negative_slope=0.1, 71 | mean_beta=mean_beta, 72 | coefficients_lr=coefficients_lr, 73 | coefficients_beta=0.9, 74 | coefficients_steps=20, 75 | components_lr=components_lr, 76 | components_beta=0.9, 77 | start_step=0, 78 | end_step=1, 79 | 
constant_lr_schedule=constant_lr_schedule, 80 | ) 81 | layer = instantiate(p) 82 | if activation == 'SOFTMAX': 83 | npy_input = np.random.random([16] + p.input_dims).astype('float32') 84 | npy_input = npy_input / np.sum(npy_input, axis=-1, keepdims=True) 85 | else: 86 | npy_input = np.random.normal(1.0, 0.5, [16] + p.input_dims).astype( 87 | 'float32' 88 | ) 89 | inputs = jnp.asarray(npy_input) 90 | with base_layer.JaxContext.new_context(): 91 | prng_key = jax.random.PRNGKey(seed=123) 92 | initial_vars = layer.init(prng_key, inputs) 93 | 94 | @jax.jit 95 | def comp(theta, inputs): 96 | with base_layer.JaxContext.new_context(): 97 | return layer.apply(theta, inputs, mutable=[NON_TRAINABLE]) 98 | 99 | (outputs, coefficients), updated_vars = comp(initial_vars, inputs) 100 | self.assertAllClose(outputs, inputs, atol=1e-5) 101 | 102 | with base_layer.JaxContext.new_context(): 103 | layer = layer.bind(initial_vars, mutable=[NON_TRAINABLE]) 104 | initial_vars = py_utils.NestedMap.FromNestedDict( 105 | initial_vars['non_trainable'] 106 | ) 107 | init_components = initial_vars.components 108 | init_mean = initial_vars.mean 109 | mean = updated_vars[NON_TRAINABLE]['mean'] 110 | components = updated_vars[NON_TRAINABLE]['components'] 111 | init_loss = layer.bregman_loss_fn( 112 | jnp.zeros_like(coefficients), init_components, init_mean, inputs 113 | ) 114 | final_loss = layer.bregman_loss_fn(coefficients, components, mean, inputs) 115 | self.assertLess(final_loss, init_loss) 116 | 117 | representations = layer.reconstruct(coefficients) 118 | self.assertEqual(representations.shape, inputs.shape) 119 | 120 | def test_pca_convergence(self): 121 | """Tests whether the gradients are zero at the solution.""" 122 | p = pax_fiddle.Config( 123 | bregman.BregmanPCA, 124 | name='bregman_pca', 125 | num_components=3, 126 | input_dims=[3], 127 | activation_type=bregman.ActivationType.IDENTITY, 128 | start_step=0, 129 | end_step=1, 130 | ) 131 | layer = instantiate(p) 132 | npy_input = np.random.normal(1.0, 0.5, [16] + p.input_dims).astype( 133 | 'float32' 134 | ) 135 | inputs = jnp.asarray(npy_input) 136 | 137 | with base_layer.JaxContext.new_context(): 138 | prng_key = jax.random.PRNGKey(seed=123) 139 | initial_vars = layer.init(prng_key, inputs) 140 | layer = layer.bind(initial_vars, mutable=[NON_TRAINABLE]) 141 | mean = jnp.zeros((1, 3)) 142 | components = jnp.eye(3) 143 | coefficients_grad = layer.coefficients_grad_fn( 144 | inputs, components, mean, inputs 145 | ) 146 | components_grad = layer.components_grad_fn(inputs, components, mean, inputs) 147 | self.assertAllClose( 148 | coefficients_grad, jnp.zeros_like(coefficients_grad), atol=1e-5 149 | ) 150 | self.assertAllClose( 151 | components_grad, jnp.zeros_like(components_grad), atol=1e-5 152 | ) 153 | 154 | 155 | if __name__ == '__main__': 156 | absltest.main() 157 | -------------------------------------------------------------------------------- /praxis/layers/chain/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: Chain and utilities. 17 | # The public API is defined in __init__.py. 18 | 19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test") 21 | 22 | package(default_visibility = JAX_VISIBILITY) 23 | 24 | pytype_strict_library( 25 | name = "chain", 26 | srcs = ["__init__.py"], 27 | deps = [ 28 | ":chain_extensions", 29 | ":chain_lib", 30 | ], 31 | ) 32 | 33 | pytype_strict_library( 34 | name = "chain_lib", 35 | srcs = ["chain.py"], 36 | deps = ["//praxis:base_layer"], 37 | ) 38 | 39 | pytype_strict_library( 40 | name = "chain_extensions", 41 | srcs = ["chain_extensions.py"], 42 | deps = [ 43 | ":chain_lib", 44 | # Implicit absl.logging dependency. 45 | # Implicit jax dependency. 46 | "//praxis:base_layer", 47 | "//praxis:pax_fiddle", 48 | "//praxis:py_utils", 49 | "//praxis/layers:activations", 50 | "//praxis/layers:linears", 51 | "//praxis/layers:repeats", 52 | ], 53 | ) 54 | 55 | pytype_strict_test( 56 | name = "chain_test", 57 | srcs = ["chain_test.py"], 58 | deps = [ 59 | ":chain", 60 | # Implicit absl.testing.absltest dependency. 61 | # Implicit absl.testing.parameterized dependency. 62 | # Implicit upb python proto dependency. 63 | # Implicit jax dependency. 64 | # Implicit numpy dependency. 65 | "//praxis:base_hyperparams", 66 | "//praxis:base_layer", 67 | "//praxis:pax_fiddle", 68 | "//praxis:py_utils", 69 | "//praxis:test_utils", 70 | "//praxis/layers:activations", 71 | ], 72 | ) 73 | -------------------------------------------------------------------------------- /praxis/layers/chain/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Expose public api of Chain and utilities.""" 17 | 18 | from praxis.layers.chain.chain import call_chain 19 | from praxis.layers.chain.chain import Chain 20 | from praxis.layers.chain.chain_extensions import add_residual 21 | from praxis.layers.chain.chain_extensions import apply_padding 22 | from praxis.layers.chain.chain_extensions import chain 23 | from praxis.layers.chain.chain_extensions import copy_n_times 24 | from praxis.layers.chain.chain_extensions import dict_to_args 25 | from praxis.layers.chain.chain_extensions import feed_forward 26 | from praxis.layers.chain.chain_extensions import full_like 27 | from praxis.layers.chain.chain_extensions import kwargs_with_name 28 | from praxis.layers.chain.chain_extensions import log_args 29 | from praxis.layers.chain.chain_extensions import repeat 30 | -------------------------------------------------------------------------------- /praxis/layers/chain/chain.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """`Chain` is a utility for chaining layers. 17 | 18 | In its most typical use `Chain` the output from the previous layer is given as 19 | input to the next. But `Chain` extends and generalizes combination of layers by 20 | treating `*args` as a stack; popping arguments and pushing outputs as it 21 | iterates through the layers. 22 | """ 23 | 24 | import inspect 25 | from typing import Any, Callable, Sequence 26 | 27 | from praxis import base_layer 28 | 29 | 30 | class Chain(base_layer.BaseLayer): 31 | """A utility layer for chaining a list of layers. 32 | 33 | `Chain` is similar to `Sequential` but extends and generalizes combination of 34 | layers by treating `*args` as a stack; popping required arguments and pushing 35 | outputs as it iterates through the layers. 36 | 37 | Specifically when the layers' signatures are different the leading arguments 38 | are replaced with the outputs from the previous layer: 39 | outputs = layer(args[:num_args]) 40 | args = outputs + args[len(outputs):] # replace leading arguments 41 | 42 | That is, `*args` is processed like a stack (leading arguments are poped and 43 | replaced). This stack logic is implemented in `call_chained_layer()`. 44 | 45 | `**kwargs` are passed straight-through the chain; unmodified. Further kwargs 46 | are only passed to layers whose `__call__()` specifies the existence of such 47 | an argument. 48 | 49 | Attributes: 50 | preserve_adopted_names: Flax module naming property (don't change). 51 | layers: The list of layers. 
52 | """ 53 | 54 | preserve_adopted_names: bool = True 55 | layers: Sequence[base_layer.BaseLayer] | None = None 56 | 57 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 58 | """Calls layers one-by-one and returns the last chain output.""" 59 | return self.call_chain(*args, **kwargs)[-1] 60 | 61 | def call_chain(self, *args: Any, **kwargs: Any) -> Sequence[Any]: 62 | """Calls layers one-by-one and returns a list with each layer's output.""" 63 | return call_chain(self.layers, *args, **kwargs) 64 | 65 | 66 | def call_chain( 67 | layers: Sequence[Callable[..., Any]], *args: Any, **kwargs: Any 68 | ) -> Sequence[Any]: 69 | """Calls layers one-by-one and returns a list with each layer's output. 70 | 71 | Wraps `call_chained_layer()` which implements the core logic of chain 72 | argument passing. 73 | 74 | Args: 75 | layers: The sequence of `BaseLayer`s. 76 | *args: The argument stack. 77 | **kwargs: The optional kwargs. 78 | 79 | Returns: 80 | A list with the output from each layer. 81 | """ 82 | outputs = [] 83 | args_stack = args 84 | for i, l in enumerate([l for l in layers if l is not None]): 85 | try: 86 | layer_outs = call_chained_layer(l, *args_stack, **kwargs) 87 | except Exception as e: 88 | raise type(e)( 89 | str(e) + f' Layer index={i} name={_name_attr(l)} args: {args_stack}' 90 | ) from e 91 | 92 | outputs.append(layer_outs) 93 | args_stack = ensure_tuple(layer_outs) 94 | 95 | if not outputs: 96 | return [args if len(args) > 1 else args[0]] 97 | 98 | return outputs 99 | 100 | 101 | def call_chained_layer( 102 | layer: Callable[..., Any], *args: Any, **kwargs: Any 103 | ) -> Any: 104 | """Passes required arguments and matching kwargs to `layer`. 105 | 106 | This is the main function implementing the argument stack; it passes 107 | required arguments to the layer and all matching kwargs. Other arguments are 108 | withheld. 109 | 110 | Args: 111 | layer: The layer to call. 112 | *args: The arguments to pass to required parameters. 113 | **kwargs: The kwargs to pass to existing kwargs. 114 | 115 | Returns: 116 | The new argument stack with leading arguments replaced: 117 | outs = layer(args[:num_args]) 118 | outputs = outs + args[len(outs):] 119 | """ 120 | signature = inspect.signature(layer.__call__) 121 | num_args = len(args) 122 | count = 0 123 | is_variadic = False 124 | for name, p in signature.parameters.items(): 125 | if name in kwargs.keys(): 126 | break 127 | if p.default != inspect.Signature.empty: 128 | break 129 | # Pass everything if variadic. 
130 | if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): 131 | is_variadic = True 132 | break 133 | count += 1 134 | if count > num_args: 135 | raise ValueError( 136 | f'Layer name={_name_attr(layer)} has too many args {count} > {num_args}' 137 | f' signature={signature}' 138 | ) 139 | matching_kwargs = { 140 | k: v for k, v in kwargs.items() if k in signature.parameters.keys() 141 | } 142 | 143 | if is_variadic: 144 | outs = layer(*args, **kwargs) 145 | else: 146 | outs = layer(*args[:count], **matching_kwargs) 147 | 148 | outs = list(ensure_tuple(outs)) if outs is not None else [] 149 | outputs = outs + list(args[len(outs) :]) 150 | return tuple(outputs) if len(outputs) > 1 else outputs[0] 151 | 152 | 153 | def ensure_tuple(x: Any) -> tuple[Any, ...]: 154 | """Ensures that `x` is a tuple.""" 155 | return x if isinstance(x, tuple) else (x,) 156 | 157 | 158 | def _name_attr(layer: Callable[..., Any]) -> str: 159 | """Returns the `name` attribute if it exists.""" 160 | return layer.name if hasattr(layer, 'name') else '' 161 | -------------------------------------------------------------------------------- /praxis/layers/checkpoint_policy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Gradient checkpoint policies that are supported by the `checkpoint` transform.""" 17 | 18 | import enum 19 | import jax 20 | 21 | 22 | @enum.unique 23 | class AutodiffCheckpointType(str, enum.Enum): 24 | """jax.checkpoint policy types.""" 25 | 26 | SAVE_NOTHING = 'save_nothing' 27 | SAVE_UNET_ALL_CONV = 'save_unet_all_conv' 28 | SAVE_UNET_CONV = 'save_unet_conv' 29 | SAVE_EVERYTHING = 'save_everything' 30 | SAVE_QKV_OUT_PROJ = 'save_qkv_out_proj' 31 | SAVE_OUT_PROJ = 'save_out_proj' 32 | SAVE_CONTEXT = 'save_context' 33 | SAVE_CONTEXT_AND_OUT_PROJ = 'save_encoded_and_out_proj' 34 | SAVE_DOT_ONLY = 'save_dot_only' 35 | SAVE_DOT_WITH_NO_BATCH_DIM = 'save_dot_with_no_batch_dims' 36 | SAVE_DOT_FOR_MLPERF_200B = 'save_dot_for_mlperf_200b' 37 | SAVE_ITERATION_INPUT = 'save_iteration_input' 38 | SAVE_TRANSFORMER_LAYER_OUTPUT = 'save_transformer_layer_output' 39 | SAVE_QUANTIZED = 'save_quantized' 40 | SAVE_QKV_OUT_PROJ_SEPARATE = 'save_qkv_out_proj_separate' 41 | SAVE_DOT_EXCEPT_LOGITS_FFN1 = 'save_dot_except_logits_ffn1' 42 | SAVE_DOT_EXCEPT_LOGITS = 'save_dot_except_logits' 43 | 44 | 45 | def custom_policy(checkpoint_policy: AutodiffCheckpointType): 46 | """Returns a JAX Autodiff checkpointing policy from the enum value.""" 47 | # TODO(zhangqiaorjc): Configure custom checkpoint policy in expt config 48 | # without introducing enum. 
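  # Usage note (assumed, not from this file): the returned policy is meant to
  # be passed to jax.checkpoint / jax.remat, e.g.
  #   fn = jax.checkpoint(fn, policy=custom_policy(
  #       AutodiffCheckpointType.SAVE_DOT_ONLY))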
49 | if checkpoint_policy == AutodiffCheckpointType.SAVE_EVERYTHING: 50 | return jax.checkpoint_policies.everything_saveable 51 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_ONLY: 52 | return jax.checkpoint_policies.checkpoint_dots 53 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_WITH_NO_BATCH_DIM: 54 | return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims 55 | if checkpoint_policy == AutodiffCheckpointType.SAVE_QKV_OUT_PROJ: 56 | return jax.checkpoint_policies.save_only_these_names( 57 | 'combined_qkv_proj', 'out_proj' 58 | ) 59 | if checkpoint_policy == AutodiffCheckpointType.SAVE_QKV_OUT_PROJ_SEPARATE: 60 | return jax.checkpoint_policies.save_only_these_names( 61 | 'query_proj', 'value_proj', 'key_proj', 'out_proj' 62 | ) 63 | if checkpoint_policy == AutodiffCheckpointType.SAVE_CONTEXT: 64 | return jax.checkpoint_policies.save_only_these_names('context') 65 | if checkpoint_policy == AutodiffCheckpointType.SAVE_OUT_PROJ: 66 | return jax.checkpoint_policies.save_only_these_names('out_proj') 67 | if checkpoint_policy == AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ: 68 | return jax.checkpoint_policies.save_only_these_names('context', 'out_proj') 69 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_FOR_MLPERF_200B: 70 | return jax.checkpoint_policies.save_only_these_names( 71 | 'combined_qkv_proj', 72 | 'query_proj', 73 | 'value_proj', 74 | 'key_proj', 75 | 'context', 76 | 'out_proj', 77 | ) 78 | if checkpoint_policy == AutodiffCheckpointType.SAVE_QUANTIZED: 79 | return jax.checkpoint_policies.save_only_these_names( 80 | 'context', 81 | 'out_proj', 82 | 'combined_qkv_proj', 83 | 'qlhs', 84 | 'qrhs', 85 | 'lhs_scale', 86 | 'rhs_scale', 87 | ) 88 | if checkpoint_policy == AutodiffCheckpointType.SAVE_ITERATION_INPUT: 89 | return jax.checkpoint_policies.save_only_these_names('iteration_input') 90 | if checkpoint_policy == AutodiffCheckpointType.SAVE_TRANSFORMER_LAYER_OUTPUT: 91 | return jax.checkpoint_policies.save_only_these_names( 92 | 'transformer_layer_out' 93 | ) 94 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1: 95 | return jax.checkpoint_policies.save_only_these_names( 96 | 'combined_qkv_proj', 97 | 'query_proj', 98 | 'value_proj', 99 | 'key_proj', 100 | 'context', 101 | 'out_proj', 102 | 'ffn2', 103 | ) 104 | if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS: 105 | return jax.checkpoint_policies.save_only_these_names( 106 | 'combined_qkv_proj', 107 | 'query_proj', 108 | 'value_proj', 109 | 'key_proj', 110 | 'context', 111 | 'out_proj', 112 | 'ffn1', 113 | 'ffn2', 114 | ) 115 | if checkpoint_policy == AutodiffCheckpointType.SAVE_UNET_CONV: 116 | return jax.checkpoint_policies.save_only_these_names('conv_out') 117 | if checkpoint_policy == AutodiffCheckpointType.SAVE_UNET_ALL_CONV: 118 | return jax.checkpoint_policies.save_only_these_names( 119 | 'conv_0', 'conv_1', 'conv_out' 120 | ) 121 | assert checkpoint_policy == AutodiffCheckpointType.SAVE_NOTHING 122 | return jax.checkpoint_policies.nothing_saveable 123 | -------------------------------------------------------------------------------- /praxis/layers/chunk.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Chunk functions.""" 17 | 18 | import jax 19 | from jax import numpy as jnp 20 | from praxis import pytypes 21 | 22 | JTensor = pytypes.JTensor 23 | 24 | 25 | def chunk(x: JTensor, *, chunk_size: int, axis: int = 1) -> JTensor: 26 | """Splits the x tensor into chunks of size `chunk_size`. 27 | 28 | Args: 29 | x: JTensor of shape [..., L, ...]. 30 | chunk_size: int, the chunk size. 31 | axis: axis for chunking. 32 | 33 | Returns: 34 | JTensor of shape [L/C, ..., C, ...]. 35 | """ 36 | # Maybe pad x up to a multiple of chunk_size. 37 | extra_pad = -x.shape[axis] % chunk_size 38 | if extra_pad: 39 | # Pad with jnp.pad's default pad value (0). 40 | x = jnp.pad(x, [[0, extra_pad if i == axis else 0] for i in range(x.ndim)]) 41 | 42 | x_shape = jnp.shape(x) 43 | num_chunks = x_shape[axis] // chunk_size 44 | # Reshape to [..., L/C, C, ...]. 45 | new_shape = x_shape[:axis] + (num_chunks, chunk_size) + x_shape[axis + 1 :] 46 | chunk_x = jnp.reshape(x, new_shape) 47 | # [L/C, ..., C, ...] 48 | return jnp.moveaxis(chunk_x, axis, 0) 49 | 50 | 51 | def unchunk( 52 | chunk_x: JTensor, axis: int = 1, seqlen: int | None = None 53 | ) -> JTensor: 54 | """Reshapes the x tensor from chunks to the original shape. 55 | 56 | Args: 57 | chunk_x: JTensor of shape [L/C, ..., C, ...]. 58 | axis: axis for chunking. 59 | seqlen: int, the original sequence length. 60 | 61 | Returns: 62 | JTensor of shape [..., L, ...]. 63 | """ 64 | chunk_size = chunk_x.shape[axis + 1] 65 | # [..., L/C, C, ...] 66 | chunk_x = jnp.moveaxis(chunk_x, 0, axis) 67 | # [..., L, ...] 68 | new_shape = chunk_x.shape[:axis] + (-1,) + chunk_x.shape[axis + 2 :] 69 | x = jnp.reshape(chunk_x, new_shape) 70 | if seqlen is not None and seqlen % chunk_size: 71 | x = jax.lax.slice_in_dim(x, 0, seqlen, axis=axis) 72 | return x 73 | -------------------------------------------------------------------------------- /praxis/layers/chunk_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | """Tests for Praxis Chunk functions.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import numpy as np 21 | from praxis import test_utils 22 | from praxis.layers import chunk 23 | 24 | 25 | class ChunkTest(test_utils.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | np.random.seed(123456) 30 | 31 | @parameterized.parameters( 32 | ([3, 4, 5], 0, 2, [2, 2, 4, 5]), 33 | ([3, 5], 1, 2, [3, 3, 2]), 34 | ([3, 4, 2, 8], 1, 3, [2, 3, 3, 2, 8]), 35 | ([3, 4, 2, 8], 3, 4, [2, 3, 4, 2, 4]), 36 | ) 37 | def test_chunk(self, in_shape, axis, chunk_size, chunk_shape): 38 | x = np.random.normal(1.0, 0.5, in_shape) 39 | chunk_x = chunk.chunk(x, chunk_size=chunk_size, axis=axis) 40 | self.assertArraysEqual(chunk_x.shape, chunk_shape) 41 | 42 | out_x = chunk.unchunk(chunk_x, axis=axis, seqlen=x.shape[axis]) 43 | self.assertAllClose(x, out_x) 44 | 45 | 46 | if __name__ == '__main__': 47 | absltest.main() 48 | -------------------------------------------------------------------------------- /praxis/layers/einsum.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A layer that computes an Einsum with a weight, and optionally adds a bias.""" 17 | 18 | from typing import Sequence 19 | 20 | from praxis import base_layer 21 | from praxis import pax_fiddle 22 | from praxis import pytypes 23 | from praxis.layers import base_ops 24 | 25 | JTensor = pytypes.JTensor 26 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] 27 | template_field = base_layer.template_field 28 | 29 | 30 | class Einsum(base_layer.BaseLayer): 31 | """Layer that computes an einsum and maybe a bias. 32 | 33 | The fan-in, fan-out and bias dimensions are inferred from the einsum equation. 34 | If bias is used, the fan-out dims must appear at the end of the output tensor. 35 | 36 | Attributes: 37 | eqn: Einsum equation. It should be in the format of input,w->output. E.g., 38 | '...d,df->...f'. 39 | w_shape: Weight shape. 40 | use_bias: Whether to add a bias. 41 | einsum_op_tpl: The op definition that implements einsum. Enables injection. 42 | """ 43 | eqn: str = '' 44 | w_shape: Sequence[int] = () 45 | use_bias: bool = False 46 | einsum_op_tpl: LayerTpl = template_field(base_ops.EinsumOp) 47 | 48 | def setup(self) -> None: 49 | operands, out = self.eqn.split('->') 50 | x, w = operands.split(',') 51 | assert '.' 
not in w 52 | fan_in = sorted(w.index(d) for d in (set(x) - set(out))) 53 | fan_out = sorted(w.index(d) for d in (set(out) - set(x))) 54 | w_sharding = self.weight_split_dims_mapping.wt 55 | pc = base_layer.WeightHParams( 56 | shape=self.w_shape, 57 | fan_in_axes=fan_in, 58 | fan_out_axes=fan_out, 59 | mesh_shape=self.mesh_shape, 60 | tensor_split_dims_mapping=w_sharding, 61 | ) 62 | self.create_variable('w', pc) 63 | if self.use_bias: 64 | out_bias_dims = sorted(out.index(d) for d in (set(out) - set(x))) 65 | # Fan-out dims must be at the end of `out`. 66 | assert all(d >= len(out) - len(out_bias_dims) for d in out_bias_dims) 67 | bias_shape = [self.w_shape[w.index(out[d])] for d in out_bias_dims] 68 | if w_sharding is not None: 69 | b_sharding = [w_sharding[w.index(out[d])] for d in out_bias_dims] 70 | else: 71 | b_sharding = None 72 | pc_bias = base_layer.WeightHParams( 73 | shape=bias_shape, 74 | init=base_layer.WeightInit.Constant(0.0), 75 | mesh_shape=self.mesh_shape, 76 | tensor_split_dims_mapping=b_sharding, 77 | ) 78 | self.create_variable('b', pc_bias) 79 | self.create_child('einsum', self.einsum_op_tpl.clone()) 80 | 81 | def __call__(self, inputs: JTensor) -> JTensor: 82 | """Computes the einsum and maybe bias. 83 | 84 | Args: 85 | inputs: A JTensor of shape as described in the equation. 86 | 87 | Returns: 88 | The result of the einsum with maybe a bias added. 89 | """ 90 | ret = self.einsum(self.eqn, inputs, self.theta.w) 91 | if self.use_bias: 92 | ret += self.theta.b 93 | return base_layer.maybe_shard( 94 | ret, self.activation_split_dims_mapping.out, self.mesh_axis_names 95 | ) 96 | -------------------------------------------------------------------------------- /praxis/layers/einsum_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for Praxis Einsum layers.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import numpy as np 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import test_utils 25 | from praxis.layers import einsum 26 | 27 | instantiate = base_layer.instantiate 28 | 29 | 30 | class EinsumTest(test_utils.TestCase): 31 | 32 | def setUp(self): 33 | super().setUp() 34 | np.random.seed(123456) 35 | 36 | @parameterized.parameters( 37 | ('...d,df->...f', [3, 4, 5], [5, 8], False), 38 | ('...d,dnh->...nh', [3, 5], [5, 2, 8], True), 39 | ('blnh,dnh->bld', [3, 4, 2, 8], [5, 2, 8], False), 40 | ('...nh,hdn->...d', [3, 4, 2, 8], [8, 5, 2], True), 41 | ) 42 | def test_einsum(self, eqn, in_shape, w_shape, use_bias): 43 | p = pax_fiddle.Config( 44 | einsum.Einsum, 45 | name='einsum', 46 | eqn=eqn, 47 | w_shape=w_shape, 48 | use_bias=use_bias, 49 | ) 50 | layer = instantiate(p) 51 | inputs = np.random.normal(1.0, 0.5, in_shape).astype( 52 | 'float32' 53 | ) 54 | prng_key = jax.random.PRNGKey(seed=123) 55 | initial_vars = layer.init(prng_key, inputs) 56 | if use_bias: 57 | # Make sure the bias is non-zero. 58 | initial_vars['params']['b'] = np.random.normal( 59 | 1.0, 0.5, initial_vars['params']['b'].shape 60 | ).astype('float32') 61 | outputs = layer.apply(initial_vars, inputs) 62 | np_outputs = np.einsum(eqn, inputs, initial_vars['params']['w']) 63 | if use_bias: 64 | np_outputs += initial_vars['params']['b'] 65 | self.assertAllClose(outputs, np_outputs, atol=1e-6) 66 | 67 | 68 | if __name__ == '__main__': 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /praxis/layers/injection/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: 17 | # Praxis layers. The public API is defined in __init__.py. 18 | 19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test") 21 | 22 | package(default_visibility = JAX_VISIBILITY) 23 | 24 | licenses(["notice"]) 25 | 26 | pytype_strict_library( 27 | name = "fp8_nvidia_gpu", 28 | srcs = ["fp8_nvidia_gpu.py"], 29 | deps = [ 30 | # Implicit flax.core dependency. 31 | # Implicit jax dependency. 32 | "//praxis:base_layer", 33 | "//praxis:pax_fiddle", 34 | "//praxis:pytypes", 35 | "//praxis/layers", 36 | ], 37 | ) 38 | 39 | pytype_strict_test( 40 | name = "fp8_nvidia_gpu_test", 41 | srcs = ["fp8_nvidia_gpu_test.py"], 42 | deps = [ 43 | ":fp8_nvidia_gpu", 44 | # Implicit absl.testing.absltest dependency. 45 | # Implicit absl.testing.parameterized dependency. 46 | # Implicit flax.core dependency. 47 | # Implicit upb python proto dependency. 48 | # Implicit jax dependency. 
49 | "//praxis:base_layer", 50 | "//praxis:pax_fiddle", 51 | "//praxis:test_utils", 52 | "//praxis/layers:linears", 53 | "//praxis/layers:pipeline", 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /praxis/layers/injection/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Exposes the public layer functionalities for injection.""" 17 | 18 | from praxis.layers.injection import fp8_nvidia_gpu 19 | -------------------------------------------------------------------------------- /praxis/layers/losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Loss functions.""" 17 | 18 | from jax import numpy as jnp 19 | from praxis import base_layer 20 | from praxis import pytypes 21 | 22 | from jax_bitempered_loss import loss 23 | 24 | WeightHParams = base_layer.WeightHParams 25 | WeightInit = base_layer.WeightInit 26 | JTensor = pytypes.JTensor 27 | 28 | 29 | class BiTemperedLoss(base_layer.BaseLayer): 30 | """Bi-tempered logitstic loss. 31 | 32 | Bi-Tempered logistic loss is a generalized softmax cross-entropy loss function 33 | with a bounded loss value per sample and a heavy-tail softmax probability 34 | function. Temperature t1 < 1.0 controls the boundedness and t2 > 1.0 controls 35 | the tail heaviness of the softmax probabilities. 36 | 37 | Temperature pairs (1, 1) correspond to CE loss. The schedule sets the 38 | temperatures to (1, 1) initially and transitions to (t1, t2) linearly over 39 | the interval [start_step, end_step]. If end_step == 0, then the temperatures 40 | are set to (t1, t2) throughout. 41 | 42 | Source: https://bit.ly/3jSol8T 43 | 44 | Attributes: 45 | t1: Temperature 1 (log). 46 | t2: Temperature 2 (exp). 47 | label_smoothing: Label smoothing. 48 | start_step: Step number to start transitioning from CE loss. 49 | end_step: Step number to reach the final temperature pairs (t1, t2). When 50 | end_step == 0, the temperatures are set to (t1, t2) throughout. 
51 | """ 52 | t1: float = 1.0 53 | t2: float = 1.0 54 | label_smoothing: float = 0.0 55 | start_step: int = 0 56 | end_step: int = 0 57 | 58 | def setup(self) -> None: 59 | """Initialize the step variable.""" 60 | assert self.end_step >= self.start_step 61 | count = WeightHParams( 62 | shape=[], 63 | init=WeightInit.Constant(0.0), 64 | dtype=jnp.float32, 65 | collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC]) 66 | self.create_variable('count', count, trainable=False) 67 | 68 | def temperature_schedule(self, count: JTensor) -> JTensor: 69 | """Temperature schedule. 70 | 71 | The temperatures will be set to the final values if end_step == 0. 72 | 73 | Args: 74 | count: Step number. 75 | 76 | Returns: 77 | Base schedule. 78 | """ 79 | count = jnp.array(count).astype(jnp.float32) 80 | schedule = jnp.where( 81 | jnp.logical_and(self.end_step > 0, count < self.end_step), 1.0, 0.0 82 | ) 83 | schedule = jnp.where( 84 | count >= self.start_step, 85 | jnp.maximum( 86 | 1.0 87 | - (count - self.start_step) 88 | / jnp.maximum(self.end_step - self.start_step, 1.0), 89 | 0.0, 90 | ), 91 | schedule, 92 | ) 93 | return schedule 94 | 95 | def __call__(self, logits: JTensor, labels: JTensor) -> JTensor: 96 | """Applies bi-tempered loss. 97 | 98 | Args: 99 | logits: The logits JTensor. Shaped [..., num_classes]. 100 | labels: The one-hot labels JTensor. Shaped [..., num_classes]. 101 | 102 | Returns: 103 | Loss values. Shaped either [...] or same as logits/labels but without the 104 | last dimension of size `num_classes`. 105 | """ 106 | base_schedule = 0.0 107 | if not self.do_eval: 108 | count = self.get_var('count') 109 | self.update_var('count', count + 1.0) 110 | base_schedule = self.temperature_schedule(count) 111 | t1 = 1.0 * base_schedule + self.t1 * (1.0 - base_schedule) 112 | t2 = 1.0 * base_schedule + self.t2 * (1.0 - base_schedule) 113 | loss_vals = loss.bi_tempered_logistic_loss( 114 | logits, labels, t1, t2, label_smoothing=self.label_smoothing 115 | ) 116 | return loss_vals 117 | -------------------------------------------------------------------------------- /praxis/layers/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Exposes the public layer functionalities.""" 17 | 18 | from praxis.layers.quantization.attentions import AttentionProjection 19 | from praxis.layers.quantization.attentions import AttentionProjectionLoRA 20 | from praxis.layers.quantization.attentions import CombinedQKVProjectionLayer 21 | from praxis.layers.quantization.attentions import DotProductAttention 22 | from praxis.layers.quantization.conformers import DotProductAttentionWithContext 23 | from praxis.layers.quantization.convolutions import Conv2D 24 | from praxis.layers.quantization.einsum import Einsum 25 | from praxis.layers.quantization.embedding_softmax import Embedding 26 | from praxis.layers.quantization.embedding_softmax import NClassMajorSharedEmbeddingSoftmax 27 | from praxis.layers.quantization.embedding_softmax import SharedEmbeddingSoftmax 28 | from praxis.layers.quantization.linears import Linear 29 | from praxis.layers.quantization.linears import LinearLoRA 30 | from praxis.layers.quantization.multi_query_attention import OneHeadedAttentionProjection 31 | from praxis.layers.quantization.ngrammer import Ngrammer 32 | from praxis.layers.quantization.ngrammer import VQNgrammer 33 | from praxis.layers.quantization.operations import einsum 34 | from praxis.layers.quantization.overflow_check import AttentionProjectionOverflowCheck, CombinedQKVProjectionLayerOverflowCheck, FeedForwardOverflowCheck, OneHeadedAttentionProjectionOverflowCheck 35 | from praxis.layers.quantization.searchable import SearchableAttentionProjection 36 | from praxis.layers.quantization.searchable import SearchableCombinedQKVProjectionLayer 37 | from praxis.layers.quantization.searchable import SearchableLinear 38 | -------------------------------------------------------------------------------- /praxis/layers/quantization/automl_select.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Praxis AutoMLSelect layer to switch branches according to AutoML decisions.""" 17 | 18 | from typing import Sequence 19 | 20 | from flax import linen as nn 21 | import jax.numpy as jnp 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import pytypes 25 | 26 | JTensor = pytypes.JTensor 27 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] 28 | template_field = pax_fiddle.template_field 29 | WeightInit = base_layer.WeightInit 30 | WeightHParams = base_layer.WeightHParams 31 | 32 | 33 | # TODO(yunn): Move BaseAutoMLSelect to efficient NAS package for Praxis 34 | class BaseAutoMLSelect(base_layer.BaseLayer): 35 | """Praxis AutoMLSelect layer to switch branches according to AutoML decisions. 36 | 37 | Attributes: 38 | search_options_tpl: a sequence of layers, with each represents a branch. 39 | They layers must have the same shapes of input and output. 
40 | """ 41 | 42 | search_options_tpl: Sequence[LayerTpl] | None = template_field(None) 43 | 44 | def setup(self) -> None: 45 | if not self.search_options_tpl: 46 | raise AttributeError('Must set at least one search option.') 47 | decision = WeightHParams( 48 | shape=[], 49 | init=WeightInit.Constant(0), 50 | dtype=jnp.uint8, 51 | mesh_shape=self.mesh_shape 52 | ) 53 | rl_variables = WeightHParams( 54 | shape=[len(self.search_options_tpl)], 55 | init=WeightInit.Constant(0), 56 | dtype=jnp.float32, 57 | mesh_shape=self.mesh_shape 58 | ) 59 | self.create_children('search_options', self.search_options_tpl) 60 | self.create_variable('decision', decision, trainable=False) 61 | self.create_variable('rl_variables', rl_variables, trainable=False) 62 | 63 | 64 | class AutoMLSelect(BaseAutoMLSelect): 65 | """An AutoMLSelect that switches between different quantizer.""" 66 | 67 | def __call__( 68 | self, 69 | x: JTensor, 70 | contract_dims: int | Sequence[int], 71 | squeeze_scale=True, 72 | quantized_dtype: jnp.dtype | None = None, 73 | ) -> tuple[JTensor, JTensor, JTensor | None]: 74 | def branch_fn(i): 75 | def quantize_fn(mdl, inputs): 76 | return mdl.search_options[i].quantize( 77 | inputs, contract_dims, squeeze_scale, quantized_dtype 78 | ) 79 | 80 | return quantize_fn 81 | 82 | branches = [branch_fn(i) for i in range(len(self.search_options))] 83 | return nn.switch(self.get_var('decision'), branches, self, x) 84 | 85 | def quantize( 86 | self, 87 | x: JTensor, 88 | contract_dims: int | Sequence[int], 89 | squeeze_scale=True, 90 | quantized_dtype: jnp.dtype | None = None, 91 | ) -> tuple[JTensor, JTensor, JTensor | None]: 92 | return self.__call__(x, contract_dims, squeeze_scale, quantized_dtype) 93 | -------------------------------------------------------------------------------- /praxis/layers/quantization/automl_select_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for praxis.layers.quantization.automl_select.""" 17 | 18 | from absl.testing import absltest 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from praxis import base_hyperparams 24 | from praxis import pax_fiddle 25 | from praxis import test_utils 26 | from praxis.layers.quantization import automl_select 27 | from praxis.layers.quantization import quantization_hparams 28 | from praxis.layers.quantization import quantizer 29 | 30 | 31 | ActQuantizationParams = quantization_hparams.ActQuantizationParams 32 | instantiate = base_hyperparams.instantiate 33 | 34 | 35 | class AutomlSelectTest(test_utils.TestCase): 36 | 37 | def test_automl_select(self): 38 | p = pax_fiddle.Config(automl_select.AutoMLSelect) 39 | p.search_options_tpl = [ 40 | quantizer.create_tensor_quantizer( 41 | 'int4', ActQuantizationParams(precision=4) 42 | ), 43 | quantizer.create_tensor_quantizer( 44 | 'int8', ActQuantizationParams(precision=8) 45 | ), 46 | ] 47 | 48 | m = instantiate(p) 49 | x = jnp.ones((1, 3)) 50 | m_vars = m.init(jax.random.PRNGKey(0), x, 0) 51 | m_vars['non_trainable']['decision'] = 0 52 | q_x1, q_s1, _ = m.apply(m_vars, x, 0) 53 | m_vars['non_trainable']['decision'] = 1 54 | q_x2, q_s2, _ = m.apply(m_vars, x, 0) 55 | self.assertNotAllClose(q_x1, q_x2) 56 | self.assertNotAllClose(q_s1, q_s2) 57 | self.assertIn('rl_variables', m_vars['non_trainable']) 58 | 59 | 60 | if __name__ == '__main__': 61 | absltest.main() 62 | -------------------------------------------------------------------------------- /praxis/layers/quantization/conformers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Quantized Conformer Layers.""" 17 | 18 | import jax.numpy as jnp 19 | from praxis import pytypes 20 | from praxis.layers import attentions 21 | from praxis.layers.quantization import attentions as qattentions 22 | 23 | JTensor = pytypes.JTensor 24 | 25 | 26 | class DotProductAttentionWithContext(qattentions.DotProductAttention): 27 | """Dot-product attention with given left and right context. 28 | 29 | It covers several use cases: 30 | 1 global self attention when left_context=right_context=None 31 | 2 local self attention when left_context!=None and right_context!=None 32 | 3 hybrid self attention when left_context or right_context is None 33 | 34 | For use cases (2,3) it will use emulated local self attention. 35 | For use case (2) it is more efficient to use LocalSelfAttention. 36 | """ 37 | 38 | left_context: int | None = None 39 | right_context: int | None = None 40 | 41 | def _dot_atten( 42 | self, 43 | query: JTensor, 44 | key: JTensor, 45 | value: JTensor, 46 | atten_mask: JTensor, 47 | relative_bias: JTensor | None = None, 48 | ) -> tuple[JTensor, JTensor]: 49 | """Main attention function. 50 | 51 | Args: 52 | query: JTensor of shape [B, T, N, H]. 53 | key: JTensor of shape [B, S, N, H]. 
54 | value: JTensor of shape [B, S, N, H]. 55 | atten_mask: JTensor of shape [1|B, 1, 1|T, S] which is a mask that is 56 | applied to prevent attention between unwanted pairs. This has already 57 | been converted into large negative logits. Note that the first and third 58 | dimension allow size 1 if the mask is shared by every item in the batch 59 | or every token in the target sequence. 60 | relative_bias: Relative bias of shape [B, N, T, S]. 61 | 62 | Returns: 63 | encoded: JTensor of shape [B, T, N, H]. 64 | atten_probs: JTensor of shape [B, N, T, S]. 65 | """ 66 | time_size = query.shape[1] 67 | 68 | if self.left_context is not None or self.right_context is not None: 69 | input_atten_mask = atten_mask 70 | atten_mask = attentions.limited_context_mask( 71 | self.left_context, self.right_context, time_size 72 | ) 73 | atten_mask = jnp.minimum(atten_mask, input_atten_mask) 74 | return super()._dot_atten(query, key, value, atten_mask, relative_bias) 75 | -------------------------------------------------------------------------------- /praxis/layers/quantization/conformers_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for quantized attentions.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | import numpy as np 21 | from praxis import base_layer 22 | from praxis import pax_fiddle 23 | from praxis import test_utils 24 | from praxis.layers import attentions 25 | from praxis.layers import conformers 26 | from praxis.layers.quantization import conformers as qconformers 27 | from praxis.layers.quantization import quantization_hparams 28 | 29 | QuantizationParams = quantization_hparams.QuantizationParams 30 | QuantizationMode = quantization_hparams.QuantizationMode 31 | QuantizationType = quantization_hparams.QuantizationType 32 | instantiate = base_layer.instantiate 33 | 34 | 35 | class DotProductAttentionWithContextSyncTest(test_utils.TestCase): 36 | 37 | def setUp(self): 38 | super().setUp() 39 | np.random.seed(123456) 40 | 41 | def test_dot_product_attention_with_context(self): 42 | batch_size = 2 43 | source_max_length = 16 44 | input_dim = 32 45 | 46 | atten_f_p = pax_fiddle.Config( 47 | conformers.DotProductAttentionWithContext, name='atten_f' 48 | ) 49 | atten_q_p = pax_fiddle.Config( 50 | qconformers.DotProductAttentionWithContext, 51 | name='atten_q', 52 | quantization=QuantizationParams( 53 | quantization_type=QuantizationType.AQT, 54 | mode=QuantizationMode.TRAINING, 55 | # Test using 23 bits to minimize the quantization error and test 56 | # for numerical correctness. 
57 | act_params=quantization_hparams.ActQuantizationParams(precision=23), 58 | weight_params=quantization_hparams.WeightQuantizationParams(), 59 | ), 60 | ) 61 | for p in [atten_f_p, atten_q_p]: 62 | p.input_dim = input_dim 63 | p.hidden_dim = 16 64 | p.left_context = 3 65 | p.right_context = 5 66 | 67 | atten_f = instantiate(atten_f_p) 68 | atten_q = instantiate(atten_q_p) 69 | 70 | query_vec = np.random.normal( 71 | size=[batch_size, source_max_length, input_dim] 72 | ).astype(np.float32) 73 | key_vec = query_vec 74 | value_vec = query_vec 75 | atten_mask = attentions.causal_mask(query_vec) 76 | 77 | with base_layer.JaxContext.new_context(): 78 | initial_vars = atten_f.init( 79 | jax.random.PRNGKey(0), 80 | query_vec, 81 | key_vec, 82 | value_vec, 83 | atten_mask, 84 | ) 85 | fprop_out_f, _ = atten_f.apply( 86 | initial_vars, 87 | query_vec, 88 | key_vec, 89 | value_vec, 90 | atten_mask, 91 | method=atten_f.__call__, 92 | ) 93 | fprop_out_q, _ = atten_q.apply( 94 | initial_vars, 95 | query_vec, 96 | key_vec, 97 | value_vec, 98 | atten_mask, 99 | method=atten_q.__call__, 100 | ) 101 | self.assertAllClose(fprop_out_f, fprop_out_q) 102 | 103 | 104 | if __name__ == '__main__': 105 | absltest.main() 106 | -------------------------------------------------------------------------------- /praxis/layers/quantization/convolutions.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convolutional layers.""" 17 | 18 | from jax import numpy as jnp 19 | from praxis import base_layer 20 | from praxis import pytypes 21 | from praxis.layers import convolutions 22 | from praxis.layers import normalizations 23 | from praxis.layers.quantization import quantization_hparams 24 | from praxis.layers.quantization import quantizer 25 | 26 | WeightHParams = base_layer.WeightHParams 27 | QuantizationParams = quantization_hparams.QuantizationParams 28 | instance_field = base_layer.instance_field 29 | JTensor = pytypes.JTensor 30 | 31 | 32 | class Conv2D(convolutions.Conv2D, quantizer.QuantizationLayer): # pytype: disable=signature-mismatch 33 | """Conv2D with support of SAME/VALID paddings.""" 34 | 35 | _PACK_4BIT_DIM = 0 36 | 37 | def setup(self) -> None: 38 | self.check_dimensions() 39 | wp = self.weight_split_dims_mapping 40 | pc = WeightHParams( 41 | shape=self.filter_shape, 42 | mesh_shape=self.mesh_shape, 43 | tensor_split_dims_mapping=wp.wt, 44 | dtype=self.dtype, 45 | init=self.kernel_init, 46 | ) 47 | self.set_up_weights( 48 | weight_name='w', 49 | weight_params=pc, 50 | scale_shape=[self.filter_shape[-1]], 51 | ) 52 | 53 | if self.bias: 54 | self.create_variable( 55 | 'b', 56 | WeightHParams( 57 | shape=[self.filter_shape[-1]], 58 | init=self.bias_init, 59 | dtype=self.dtype, 60 | ), 61 | ) 62 | 63 | wn = self.weight_norm_tpl.clone().set(dim=self.filter_shape[-1]) 64 | self.weight_norm: normalizations.BaseNormalization 65 | self.create_child('weight_norm', wn) 66 | 67 | def __call__(self, inputs: JTensor) -> JTensor: 68 | """FProp that supports strided, dilated convolution, depthwise convolution. 69 | 70 | Args: 71 | inputs: Input sequence of shape [B, H, W, D_in], also known more popularly 72 | as NHWC format. 73 | 74 | Returns: 75 | Output sequence after applying convolutions of shape [B, H', W', D_out]. 76 | Note that if the padding is SAME and there is no dilation and striding, 77 | then H' = H and W' = W. 78 | """ 79 | # Check if the feature_group_count is compatible with the inputs and filter 80 | # For more information see XLA docs on ConvWithGeneralPadding below 81 | # https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution 82 | if inputs.shape[3] % self.filter_shape[2] != 0: 83 | raise ValueError( 84 | f'Input features {inputs.shape[3]} must be a' 85 | f'multiple of filter input dim {self.filter_shape[2]} ' 86 | f'(Input shape: {inputs.shape}, ' 87 | f'filter shape: {self.filter_shape}).' 88 | ) 89 | # feature group count is D_in // filter input dim 90 | feature_group_count = inputs.shape[3] // self.filter_shape[2] 91 | if self.filter_shape[3] % feature_group_count != 0: 92 | raise ValueError( 93 | f'Filter output dim {self.filter_shape[3]} must be a ' 94 | f'multiple of feature group count {feature_group_count} ' 95 | f'(Input shape: {inputs.shape}, ' 96 | f'filter shape: {self.filter_shape}).' 
97 | ) 98 | padding = self._compute_padding(inputs.shape) 99 | inputs = self._shard_bhwc(inputs.astype(self.fprop_dtype)) 100 | 101 | # The `dimension_numbers=('NHWC', 'HWIO', 'NHWC')` is to be consistent 102 | # with tf.conv2d, see e.g. 103 | # https://github.com/google/jax/blob/main/jax/_src/lax/lax.py#L622 104 | dimension_numbers = ('NHWC', 'HWIO', 'NHWC') 105 | outputs = quantizer.quantized_conv( 106 | self, inputs, padding, dimension_numbers, feature_group_count 107 | ) 108 | outputs = self._shard_bhwc(outputs) 109 | if self.bias: 110 | outputs += jnp.reshape(self.theta.b, (1,) * (outputs.ndim - 1) + (-1,)) 111 | return outputs 112 | -------------------------------------------------------------------------------- /praxis/layers/quantization/convolutions_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for Quantized Praxis convolutional layers.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | from jax import numpy as jnp 22 | import numpy as np 23 | from praxis import base_layer 24 | from praxis import pax_fiddle 25 | from praxis import test_utils 26 | from praxis.layers import convolutions 27 | from praxis.layers.quantization import convolutions as qconvolutions 28 | from praxis.layers.quantization import operations 29 | from praxis.layers.quantization import quantization_hparams 30 | 31 | instantiate = base_layer.instantiate 32 | QuantizationParams = quantization_hparams.QuantizationParams 33 | QuantizationMode = quantization_hparams.QuantizationMode 34 | 35 | PARAMS = base_layer.PARAMS 36 | NON_TRAINABLE = base_layer.NON_TRAINABLE 37 | 38 | 39 | class ConvolutionsTest(test_utils.TestCase): 40 | 41 | def setUp(self): 42 | super().setUp() 43 | np.random.seed(123456) 44 | 45 | @parameterized.parameters( 46 | ((5, 4, 24, 36), (1, 1), (1, 1), False, [2, 16, 36, 72]), 47 | ((2, 4, 16, 8), (2, 2), (1, 1), False, [2, 16, 32, 128]), 48 | ((4, 8, 16, 32), (1, 1), (1, 1), False, [2, 16, 32, 64]), 49 | ((2, 8, 16, 32), (1, 1), (2, 2), False, [2, 16, 32, 64]), 50 | ((2, 8, 16, 32), (2, 2), (2, 2), False, [2, 16, 32, 64]), 51 | ((2, 8, 16, 32), (1, 1), (2, 1), False, [2, 16, 32, 64]), 52 | ((2, 8, 16, 32), (2, 2), (2, 1), False, [2, 16, 32, 64]), 53 | ((2, 8, 16, 32), (1, 1), (2, 2), True, [2, 16, 32, 64]), 54 | ((2, 8, 16, 32), (2, 2), (2, 2), True, [2, 16, 32, 64]), 55 | ) 56 | def test_conv2d_layer_same_padding( 57 | self, 58 | filter_shape, 59 | filter_stride, 60 | dilations, 61 | tf_equivalent_padding, 62 | input_shape, 63 | ): 64 | npy_inputs = np.random.normal(0.0, 0.5, input_shape).astype('float32') 65 | inputs = jnp.asarray(npy_inputs) 66 | prng_key = jax.random.PRNGKey(seed=123) 67 | 68 | # Float layer. 
69 | p = pax_fiddle.Config( 70 | convolutions.Conv2D, 71 | name='jax_conv2d', 72 | filter_shape=filter_shape, 73 | filter_stride=filter_stride, 74 | dilations=dilations, 75 | tf_equivalent_padding=tf_equivalent_padding, 76 | padding='SAME', 77 | ) 78 | conv_layer = instantiate(p) 79 | initial_vars = conv_layer.init(prng_key, inputs) 80 | outputs = conv_layer.apply(initial_vars, inputs) 81 | 82 | # Quantized layer. 83 | qp = pax_fiddle.Config( 84 | qconvolutions.Conv2D, 85 | name='jax_conv2_q', 86 | filter_shape=filter_shape, 87 | filter_stride=filter_stride, 88 | dilations=dilations, 89 | tf_equivalent_padding=tf_equivalent_padding, 90 | padding='SAME', 91 | quantization=QuantizationParams(mode=QuantizationMode.INFERENCE) 92 | ) 93 | qconv_layer = instantiate(qp) 94 | qinitial_vars = qconv_layer.init(prng_key, inputs) 95 | # TODO(jianlijianli): Use quantize_weight() once it's implemented. 96 | qweight, qscale, _ = operations.reduce_precision( 97 | initial_vars['params']['w'], [0, 1, 2] 98 | ) 99 | qinitial_vars['params']['w'] = qweight 100 | qinitial_vars['params']['w_quantized_scale'] = jnp.squeeze(qscale) 101 | qoutput = qconv_layer.apply(qinitial_vars, inputs) 102 | 103 | self.assertAllClose(qoutput, outputs, rtol=1e-02, atol=1e-02) 104 | 105 | 106 | if __name__ == '__main__': 107 | absltest.main() 108 | -------------------------------------------------------------------------------- /praxis/layers/quantization/einsum.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A layer that computes an Einsum with a weight, and optionally adds a bias.""" 17 | 18 | from typing import Sequence 19 | 20 | from praxis import base_layer 21 | from praxis import pax_fiddle 22 | from praxis import pytypes 23 | from praxis.layers.quantization import quantizer 24 | 25 | JTensor = pytypes.JTensor 26 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] 27 | template_field = base_layer.template_field 28 | 29 | 30 | class Einsum(quantizer.QuantizationLayer): 31 | """Layer that computes an einsum and maybe a bias. 32 | 33 | The fan-in, fan-out and bias dimensions are inferred from the einsum equation. 34 | If bias is used, the fan-out dims must appear at the end of the output tensor. 35 | 36 | Attributes: 37 | eqn: Einsum equation. It should be in the format of input,w->output. E.g., 38 | '...d,df->...f'. 39 | w_shape: Weight shape. 40 | use_bias: Whether to add a bias. 41 | """ 42 | eqn: str = '' 43 | w_shape: Sequence[int] = () 44 | use_bias: bool = False 45 | _PACK_4BIT_DIM = 0 46 | 47 | def setup(self) -> None: 48 | operands, out = self.eqn.split('->') 49 | x, w = operands.split(',') 50 | assert '.' 
not in w 51 | fan_in = sorted(w.index(d) for d in (set(x) - set(out))) 52 | fan_out = sorted(w.index(d) for d in (set(out) - set(x))) 53 | w_sharding = self.weight_split_dims_mapping.wt 54 | pc = base_layer.WeightHParams( 55 | shape=self.w_shape, 56 | fan_in_axes=fan_in, 57 | fan_out_axes=fan_out, 58 | mesh_shape=self.mesh_shape, 59 | tensor_split_dims_mapping=w_sharding, 60 | ) 61 | out_bias_dims = sorted(out.index(d) for d in (set(out) - set(x))) 62 | bias_shape = [self.w_shape[w.index(out[d])] for d in out_bias_dims] 63 | self.set_up_weights( 64 | weight_name='w', 65 | weight_params=pc, 66 | scale_shape=bias_shape, 67 | ) 68 | if self.use_bias: 69 | # Fan-out dims must be at the end of `out`. 70 | assert all(d >= len(out) - len(out_bias_dims) for d in out_bias_dims) 71 | if w_sharding is not None: 72 | b_sharding = [w_sharding[w.index(out[d])] for d in out_bias_dims] 73 | else: 74 | b_sharding = None 75 | pc_bias = base_layer.WeightHParams( 76 | shape=bias_shape, 77 | init=base_layer.WeightInit.Constant(0.0), 78 | mesh_shape=self.mesh_shape, 79 | tensor_split_dims_mapping=b_sharding, 80 | ) 81 | self.create_variable('b', pc_bias) 82 | 83 | def __call__(self, inputs: JTensor) -> JTensor: 84 | """Computes the einsum and maybe bias. 85 | 86 | Args: 87 | inputs: A JTensor of shape as described in the equation. 88 | 89 | Returns: 90 | The result of the einsum with maybe a bias added. 91 | """ 92 | ret = self.quantized_einsum( 93 | eqn=self.eqn, 94 | x=inputs, 95 | w=self.theta.w, 96 | reshape=[], 97 | ) 98 | if self.use_bias: 99 | ret += self.theta.b 100 | return base_layer.maybe_shard( 101 | ret, self.activation_split_dims_mapping.out, self.mesh_axis_names 102 | ) 103 | -------------------------------------------------------------------------------- /praxis/layers/quantization/einsum_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
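# Why the reference computation in the test that follows multiplies the float
# einsum result by the stored scale: under weight-only symmetric quantization,
# w ~= q * s with s broadcast over the fan-out dims, and einsum is linear in
# w, so einsum(eqn, x, q * s) == einsum(eqn, x, q) * s. A small numpy check
# (hypothetical values):
import numpy as np

q = np.arange(6, dtype=np.int8).reshape(3, 2)   # quantized weight
s = np.array([0.5, 2.0], dtype=np.float32)      # one scale per output channel
x = np.ones((4, 3), dtype=np.float32)
assert np.allclose(
    np.einsum('bd,df->bf', x, q * s),
    np.einsum('bd,df->bf', x, q) * s,
)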
15 | 16 | """Tests for Quantized Praxis Einsum layer.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import numpy as np 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import test_utils 25 | from praxis.layers.quantization import einsum 26 | 27 | instantiate = base_layer.instantiate 28 | 29 | 30 | class EinsumTest(test_utils.TestCase): 31 | 32 | def setUp(self): 33 | super().setUp() 34 | np.random.seed(123456) 35 | 36 | @parameterized.parameters( 37 | ('...d,df->...f', [3, 4, 5], [5, 8], False), 38 | ('...d,dnh->...nh', [3, 5], [5, 2, 8], True), 39 | ('blnh,dnh->bld', [3, 4, 2, 8], [5, 2, 8], False), 40 | ('...nh,hdn->...d', [3, 4, 2, 8], [8, 5, 2], True), 41 | ) 42 | def test_einsum(self, eqn, in_shape, w_shape, use_bias): 43 | p = pax_fiddle.Config( 44 | einsum.Einsum, 45 | name='einsum', 46 | eqn=eqn, 47 | w_shape=w_shape, 48 | use_bias=use_bias, 49 | ) 50 | layer = instantiate(p) 51 | inputs = np.random.normal(1.0, 0.5, in_shape).astype('float32') 52 | prng_key = jax.random.PRNGKey(seed=123) 53 | initial_vars = layer.init(prng_key, inputs) 54 | qw = np.arange(np.prod(w_shape)).astype('int8').reshape(w_shape) 55 | initial_vars['params']['w'] = qw 56 | if use_bias: 57 | # Make sure the bias is non-zero. 58 | initial_vars['params']['b'] = np.random.normal( 59 | 1.0, 0.5, initial_vars['params']['b'].shape 60 | ).astype('float32') 61 | outputs = layer.apply(initial_vars, inputs) 62 | np_outputs = np.multiply( 63 | np.einsum(eqn, inputs, initial_vars['params']['w']), 64 | initial_vars['params']['w_quantized_scale'], 65 | ) 66 | if use_bias: 67 | np_outputs += initial_vars['params']['b'] 68 | self.assertAllClose(outputs, np_outputs, atol=1e-6) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /praxis/layers/quantization/multi_query_attention_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for Praxis quantized multi-query attention layers.""" 17 | 18 | from absl import logging 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | import numpy as np 23 | from praxis import base_layer 24 | from praxis import pax_fiddle 25 | from praxis import test_utils 26 | from praxis.layers import multi_query_attention as mqa_f 27 | from praxis.layers.quantization import multi_query_attention as mqa_q 28 | from praxis.layers.quantization import quantization_hparams 29 | from praxis.layers.quantization import utils 30 | 31 | instantiate = base_layer.instantiate 32 | QuantizationParams = quantization_hparams.QuantizationParams 33 | QuantizationMode = quantization_hparams.QuantizationMode 34 | QuantizationType = quantization_hparams.QuantizationType 35 | WeightQuantizationParams = quantization_hparams.WeightQuantizationParams 36 | 37 | 38 | # TODO(b/277357920): add mutiqueryattention quantization sync test. 39 | 40 | class MultiQueryAttentionTest(test_utils.TestCase): 41 | 42 | def setUp(self): 43 | super().setUp() 44 | np.random.seed(123456) 45 | 46 | @parameterized.named_parameters( 47 | dict(testcase_name='PTQ', quantization_type=QuantizationType.PTQ), 48 | dict(testcase_name='AQT', quantization_type=QuantizationType.AQT), 49 | ) 50 | def test_one_headed_projection_shape(self, quantization_type): 51 | test_layer_p = pax_fiddle.Config( 52 | mqa_q.OneHeadedAttentionProjection, 53 | name='mh', 54 | input_dim=16, 55 | output_dim=5, 56 | quantization=QuantizationParams( 57 | quantization_type=quantization_type, 58 | mode=QuantizationMode.TRAINING, 59 | weight_params=quantization_hparams.WeightQuantizationParams(), 60 | ), 61 | ) 62 | layer = instantiate(test_layer_p) 63 | 64 | inputs = np.random.normal(1.5, 2.0, [5, 16]).astype(np.float32) 65 | 66 | prng_key = jax.random.PRNGKey(seed=123) 67 | prng_key, init_key = jax.random.split(prng_key) 68 | initial_vars = layer.init(init_key, inputs) 69 | logging.info('initial_vars: %s', initial_vars) 70 | 71 | jax_out = layer.apply(initial_vars, inputs) 72 | self.assertSequenceEqual(jax_out.shape, [5, 5]) 73 | 74 | @parameterized.product( 75 | quantization_type=[QuantizationType.PTQ, QuantizationType.AQT], 76 | use_symmetric=[True, False], 77 | precision=[8, 4], 78 | ) 79 | def test_one_headed_projection_quantized( 80 | self, quantization_type, use_symmetric, precision 81 | ): 82 | input_dim = 16 83 | output_dim = 3 84 | batch = 5 85 | p_f = pax_fiddle.Config( 86 | mqa_f.OneHeadedAttentionProjection, 87 | name='ohp_f', 88 | input_dim=input_dim, 89 | output_dim=output_dim, 90 | ) 91 | p_q = pax_fiddle.Config( 92 | mqa_q.OneHeadedAttentionProjection, 93 | name='ohp_q', 94 | input_dim=input_dim, 95 | output_dim=output_dim, 96 | quantization=QuantizationParams( 97 | quantization_type=QuantizationType.PTQ, 98 | mode=QuantizationMode.INFERENCE, 99 | weight_params=WeightQuantizationParams( 100 | use_symmetric=use_symmetric, precision=precision 101 | ), 102 | ), 103 | ) 104 | 105 | inputs = np.random.normal(1.5, 2.0, [batch, input_dim]).astype(np.float32) 106 | quantized_weight_range = (-8, 7) if precision == 4 else (-128, 127) 107 | quantized_weight = np.random.randint( 108 | *quantized_weight_range, (input_dim, output_dim), dtype=np.int8 109 | ) 110 | w_scale = np.array([0.5, 2.0, 3.3], dtype=np.float32) 111 | if use_symmetric: 112 | weight_rescaled = quantized_weight * w_scale 113 | else: 114 | w_zp = np.array([-10.0, 10.0, -2.5], dtype=np.float32) 115 | weight_rescaled = quantized_weight * 
w_scale - w_zp 116 | 117 | ohp_f = instantiate(p_f) 118 | ohp_q = instantiate(p_q) 119 | 120 | with base_layer.JaxContext.new_context(): 121 | prng_key = jax.random.PRNGKey(seed=123) 122 | initial_vars_f = ohp_f.init(prng_key, inputs) 123 | initial_vars_q = ohp_q.init(prng_key, inputs) 124 | initial_vars_f['params']['w'] = weight_rescaled 125 | initial_vars_q['params']['w'] = ( 126 | utils.pack_4bit(quantized_weight, pack_dim=0) 127 | if precision == 4 128 | else quantized_weight 129 | ) 130 | initial_vars_q['params']['w_quantized_scale'] = w_scale 131 | if not use_symmetric: 132 | initial_vars_q['params']['w_quantized_zp'] = w_zp 133 | outputs_f = ohp_f.apply(initial_vars_f, inputs) 134 | outputs_q = ohp_q.apply(initial_vars_q, inputs) 135 | self.assertAllClose(outputs_f, outputs_q) 136 | 137 | 138 | if __name__ == '__main__': 139 | absltest.main() 140 | -------------------------------------------------------------------------------- /praxis/layers/quantization/ngrammer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Quantized Ngrammer layers.""" 17 | 18 | from jax import numpy as jnp 19 | from praxis import base_layer 20 | from praxis import pax_fiddle 21 | from praxis.layers import ngrammer 22 | from praxis.layers import normalizations 23 | from praxis.layers.quantization import embedding_softmax as quantized_embedding_softmax 24 | from praxis.layers.quantization import quantization_hparams 25 | 26 | QuantizationParams = quantization_hparams.QuantizationParams 27 | WeightHParams = base_layer.WeightHParams 28 | instance_field = base_layer.instance_field 29 | 30 | 31 | class Ngrammer(ngrammer.Ngrammer): 32 | """Quantized Ngrammer.""" 33 | 34 | quantization: QuantizationParams = instance_field(QuantizationParams) 35 | 36 | def setup(self) -> None: 37 | """Constructs an instance which looks up ngrams.""" 38 | 39 | if self.concat_ngrams: 40 | # The ngram_emb_dim must be smaller than dim_per_head. 41 | assert self.ngram_emb_dim <= self.dim_per_head 42 | else: 43 | # If not concatenating ngram embeddings, check the dims are compatible. 44 | assert self.ngram_emb_dim == self.dim_per_head 45 | 46 | # Create a separate layer norm per head for embedding normalization. 47 | # Create a separate layer norm per head for ngram embedding normalization. 48 | emb_layer_norm_p = [] 49 | ngram_emb_layer_norm_p = [] 50 | ngram_emb_table_p = [] 51 | for i in range(self.num_heads): 52 | layer_norm_p = pax_fiddle.Config(normalizations.LayerNorm).clone() 53 | layer_norm_p.dim = self.dim_per_head 54 | layer_norm_p.name = f'layer_norm_{i}' 55 | 56 | emb_layer_norm_p.append(layer_norm_p) 57 | ngram_layer_norm_p = pax_fiddle.Config(normalizations.LayerNorm).clone() 58 | ngram_layer_norm_p.dim = self.ngram_emb_dim 59 | ngram_emb_layer_norm_p.append(ngram_layer_norm_p) 60 | 61 | # Create embedding table for ngram lookup. 
62 | embedding_p = pax_fiddle.Config( 63 | quantized_embedding_softmax.Embedding, 64 | quantization=self.quantization, 65 | ).clone() 66 | embedding_p.name = f'embedding_{i}' 67 | embedding_p.num_classes = self.ngram_vocab_size 68 | embedding_p.input_dims = self.ngram_emb_dim 69 | embedding_p.params_init = self.params_init 70 | # Copy sharding annotations. 71 | embedding_p.weight_split_dims_mapping = self.weight_split_dims_mapping 72 | ngram_emb_table_p.append(embedding_p) 73 | 74 | self.create_children('emb_layer_norm', emb_layer_norm_p) 75 | self.create_children('ngram_layer_norm', ngram_emb_layer_norm_p) 76 | self.create_children('ngram_table', ngram_emb_table_p) 77 | 78 | 79 | class VQNgrammer(ngrammer.VQNgrammer): 80 | """Quantized VQNgrammer.""" 81 | 82 | quantization: QuantizationParams = instance_field(QuantizationParams) 83 | 84 | def setup(self) -> None: 85 | """Constructs a VQ layer and an N-grammer layer.""" 86 | 87 | if self.concat_ngrams: 88 | # The ngram_emb_dim must be smaller than dim_per_head. 89 | assert self.ngram_emb_dim <= self.dim_per_head 90 | else: 91 | # If not concatenating ngram embeddings, check the dims are compatible. 92 | assert self.ngram_emb_dim == self.dim_per_head 93 | 94 | # Create VQ layer. 95 | vq_layer_p = pax_fiddle.Config( 96 | ngrammer.VectorQuantization, 97 | num_clusters=self.num_clusters, 98 | num_heads=self.num_heads, 99 | dim_per_head=self.dim_per_head, 100 | decay=self.decay, 101 | epsilon=self.epsilon, 102 | params_init=self.params_init, 103 | ) 104 | self.create_child('vq_layer', vq_layer_p) 105 | 106 | # Create the input id to cluster id cache. 107 | if self.unigram_vocab_size: 108 | input_id_to_cluster_id_cache = WeightHParams( 109 | shape=[self.unigram_vocab_size, self.num_heads], 110 | dtype=jnp.int32, 111 | init=base_layer.WeightInit.Constant(0), 112 | ) 113 | self.create_variable( 114 | 'input_id_to_cluster_id_cache', 115 | input_id_to_cluster_id_cache, 116 | trainable=False, 117 | ) 118 | 119 | # Create N-gram lookup layer. 120 | ngram_layer_p = pax_fiddle.Config( 121 | Ngrammer, # Quantized Ngrammer. 122 | quantization=self.quantization, 123 | ngram_vocab_size=self.ngram_vocab_size, 124 | unigram_vocab_size=self.num_clusters, 125 | ngram_emb_dim=self.ngram_emb_dim, 126 | concat_ngrams=self.concat_ngrams, 127 | num_heads=self.num_heads, 128 | dim_per_head=self.dim_per_head, 129 | params_init=self.params_init, 130 | weight_split_dims_mapping=self.weight_split_dims_mapping, 131 | ) 132 | self.create_child('ngram_layer', ngram_layer_p) 133 | -------------------------------------------------------------------------------- /praxis/layers/quantization/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for quantization optimizations.""" 17 | 18 | from absl.testing import absltest 19 | from jax import numpy as jnp 20 | import numpy as np 21 | from praxis import test_utils 22 | from praxis.layers.quantization import optimization 23 | 24 | 25 | class OptimizationTest(test_utils.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | np.random.seed(123456) 30 | 31 | def test_optimization_int8(self): 32 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]]) 33 | bound = jnp.array([[3.0], [4.0]]) 34 | ret = optimization.get_best_bound(t, bound, -128.0, 127.0) 35 | expected = jnp.array([[3.0], [4.0]]) 36 | self.assertArraysEqual(ret, expected) 37 | 38 | def test_optimization_int4(self): 39 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]]) 40 | bound = jnp.array([[3.0], [4.0]]) 41 | ret = optimization.get_best_bound(t, bound, -8.0, 7.0) 42 | expected = jnp.array([[3.0], [4.0]]) 43 | self.assertArraysEqual(ret, expected) 44 | 45 | def test_optimization_int4_p2(self): 46 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]]) 47 | bound = jnp.array([[3.0], [4.0]]) 48 | ret = optimization.get_best_bound(t, bound, -8.0, 7.0, 2.0) 49 | expected = jnp.array([[2.85], [3.8]]) 50 | self.assertArraysEqual(ret, expected) 51 | 52 | def test_optimization_int4_per_channel(self): 53 | t = jnp.array([[1.0, 2.0, 3.0], [4.0, 1.0, 2.0]]) 54 | bound = jnp.max(t, axis=0, keepdims=True) 55 | best_bound_per_channel = optimization.get_best_bound( 56 | t, bound, -8.0, 7.0, 2.0, per_channel=True 57 | ) 58 | expected_best_bound_per_channel = jnp.array([[4.0, 1.9, 3.0]]) 59 | self.assertArraysEqual( 60 | best_bound_per_channel, expected_best_bound_per_channel 61 | ) 62 | best_bound_per_tensor = optimization.get_best_bound( 63 | t, bound, -8.0, 7.0, 2.0 64 | ) 65 | expected_best_bound_per_tensor = jnp.array([[4.0, 2.0, 3.0]]) 66 | self.assertArraysEqual( 67 | best_bound_per_tensor, expected_best_bound_per_tensor 68 | ) 69 | 70 | 71 | if __name__ == '__main__': 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /praxis/layers/quantization/searchable.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Layers for Mixed-Precision Quantization Search.""" 17 | 18 | import dataclasses 19 | from typing import Any, Sequence 20 | 21 | from praxis import base_layer 22 | from praxis import pax_fiddle 23 | from praxis.layers.quantization import attentions 24 | from praxis.layers.quantization import automl_select 25 | from praxis.layers.quantization import linears 26 | from praxis.layers.quantization import quantization_hparams 27 | from praxis.layers.quantization import quantizer 28 | 29 | QuantizationParams = quantization_hparams.QuantizationParams 30 | template_field = pax_fiddle.template_field 31 | 32 | _SUPPORTED_PRECISIONS = [2, 4, 8] 33 | 34 | 35 | def _build_quantization_templates( 36 | precisions: Sequence[int], base_qhparams: QuantizationParams 37 | ): 38 | """Build templates for precision search from the base quantization hparams.""" 39 | act_qtzr_tmpls = [] 40 | weight_qtzr_tmpls = [] 41 | 42 | if not precisions: 43 | raise AttributeError('Must specify at least 1 precisions.') 44 | 45 | for precision in precisions: # pylint: disable=not-an-iterable 46 | if precision not in _SUPPORTED_PRECISIONS: 47 | raise AttributeError('Precision %d is not supported.' % precision) 48 | # if not quantizing activations, then don't set precision to indicate. 49 | if base_qhparams.act_params and base_qhparams.act_params.precision: 50 | act_qtzr_tmpl = quantizer.create_tensor_quantizer( 51 | f'act_quantizer_int{str(precision)}', base_qhparams.act_params 52 | ) 53 | act_qtzr_tmpl.precision = precision 54 | act_qtzr_tmpls.append(act_qtzr_tmpl) 55 | weight_qtzr_tmpl = quantizer.create_tensor_quantizer( 56 | f'weight_quantizer_int{str(precision)}', base_qhparams.weight_params 57 | ) 58 | weight_qtzr_tmpl.precision = precision 59 | weight_qtzr_tmpls.append(weight_qtzr_tmpl) 60 | return act_qtzr_tmpls, weight_qtzr_tmpls 61 | 62 | 63 | class SearchableQuantizedLayer(base_layer.BaseLayer): 64 | """A template to make the precision of quantizers searchable.""" 65 | precisions: Sequence[int] = dataclasses.field(default_factory=list) 66 | 67 | def create_tensor_quantizers(self): 68 | act_qtzr_tmpls, weight_qtzr_tmpls = _build_quantization_templates( 69 | self.precisions, self.quantization 70 | ) 71 | # Use AutoMLSelect as quantizer 72 | if act_qtzr_tmpls: 73 | self.create_child( 74 | 'act_quantizer', 75 | pax_fiddle.Config( 76 | automl_select.AutoMLSelect, 77 | search_options_tpl=act_qtzr_tmpls, 78 | ), 79 | ) 80 | else: 81 | self.create_child( 82 | 'act_quantizer', 83 | quantizer.create_tensor_quantizer( 84 | 'aqt_quantizer', self.quantization.act_params 85 | ), 86 | ) 87 | self.create_child( 88 | 'weight_quantizer', 89 | pax_fiddle.Config( 90 | automl_select.AutoMLSelect, 91 | search_options_tpl=weight_qtzr_tmpls, 92 | ), 93 | ) 94 | 95 | # This is for fixing the return type mismatch during multiple inheritance. 
96 |   def replace(self, **kwargs: Any):
97 |     return super().replace(**kwargs)
98 | 
99 | 
100 | class SearchableLinear(SearchableQuantizedLayer, linears.Linear):
101 |   """Quantized Linear layer with searchable precision."""
102 | 
103 |   pass
104 | 
105 | 
106 | class SearchableAttentionProjection(
107 |     SearchableQuantizedLayer, attentions.AttentionProjection
108 | ):
109 |   """Layer that computes multi-head projection with searchable precision."""
110 | 
111 |   pass
112 | 
113 | 
114 | class SearchableCombinedQKVProjectionLayer(
115 |     SearchableQuantizedLayer, attentions.CombinedQKVProjectionLayer
116 | ):
117 |   """Layer that computes QKV projection with searchable precision."""
118 |   pass
119 | 
-------------------------------------------------------------------------------- /praxis/layers/quantization/searchable_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Tests for praxis.layers.quantization.searchable."""
17 | 
18 | from absl.testing import absltest
19 | 
20 | import jax
21 | import jax.numpy as jnp
22 | 
23 | from praxis import base_hyperparams
24 | from praxis import base_layer
25 | from praxis import pax_fiddle
26 | from praxis import test_utils
27 | from praxis.layers.quantization import quantization_hparams
28 | from praxis.layers.quantization import searchable
29 | 
30 | jax.config.update('jax_threefry_partitionable', False)
31 | 
32 | instantiate = base_hyperparams.instantiate
33 | QuantizationType = quantization_hparams.QuantizationType
34 | 
35 | 
36 | def _run_option(model, model_vars, inputs, act_decision, weight_decision):
37 |   if 'act_quantizer' in model_vars['non_trainable']:
38 |     model_vars['non_trainable']['act_quantizer']['decision'] = act_decision
39 |   model_vars['non_trainable']['weight_quantizer']['decision'] = weight_decision
40 |   return model.apply(model_vars, inputs)
41 | 
42 | 
43 | class SearchableTest(test_utils.TestCase):
44 | 
45 |   def setUp(self):
46 |     super().setUp()
47 |     self.quantization_tpl = quantization_hparams.QuantizationParams(
48 |         quantization_type=quantization_hparams.QuantizationType.AQT,
49 |         mode=quantization_hparams.QuantizationMode.TRAINING,
50 |         act_params=quantization_hparams.ActQuantizationParams(),
51 |         weight_params=quantization_hparams.WeightQuantizationParams(),
52 |     )
53 | 
54 |   def _test_common(self, p, x):
55 |     m = instantiate(p)
56 | 
57 |     with base_layer.JaxContext.new_context():
58 |       m_vars = m.init(jax.random.PRNGKey(0), x)
59 |       a4w4 = _run_option(m, m_vars, x, 0, 0)
60 |       a4w8 = _run_option(m, m_vars, x, 0, 1)
61 |       a8w4 = _run_option(m, m_vars, x, 1, 0)
62 |       a8w8 = _run_option(m, m_vars, x, 1, 1)
63 | 
64 |     self.assertAllClose(a4w4, a4w8, rtol=0.05, atol=0.05)
65 |     self.assertAllClose(a4w4, a8w4, rtol=0.05, atol=0.05)
66 |     self.assertAllClose(a4w8, a8w8, rtol=0.05, atol=0.05)
67 | 
68 |   def test_searchable_linear(self):
69 |     p = pax_fiddle.Config(
70 | searchable.SearchableLinear, 71 | quantization=self.quantization_tpl, 72 | input_dims=3, 73 | output_dims=1, 74 | precisions=[4, 8], 75 | ) 76 | self._test_common(p, jnp.ones((1, 3), dtype=p.dtype)) 77 | 78 | def test_searchable_linear_without_act_quant(self): 79 | self.quantization_tpl.act_params.precision = None 80 | p = pax_fiddle.Config( 81 | searchable.SearchableLinear, 82 | quantization=self.quantization_tpl, 83 | input_dims=3, 84 | output_dims=1, 85 | precisions=[4, 8], 86 | ) 87 | m = instantiate(p) 88 | m_vars = m.init(jax.random.PRNGKey(0), jnp.ones((1, 3), dtype=p.dtype)) 89 | self.assertNotIn('act_quantizer', m_vars['non_trainable']) 90 | 91 | def test_attention_projections(self): 92 | p = pax_fiddle.Config( 93 | searchable.SearchableAttentionProjection, 94 | quantization=self.quantization_tpl, 95 | input_dim=8, 96 | num_heads=2, 97 | dim_per_head=3, 98 | is_output_projection=True, 99 | precisions=[4, 8], 100 | ) 101 | self._test_common(p, jnp.ones((4, 2, 3), dtype=p.dtype)) 102 | 103 | def test_combined_qkv_projections(self): 104 | p = pax_fiddle.Config( 105 | searchable.SearchableCombinedQKVProjectionLayer, 106 | quantization=self.quantization_tpl, 107 | input_dim=8, 108 | num_heads=3, 109 | dim_per_head=2, 110 | precisions=[4, 8], 111 | ) 112 | self._test_common(p, jnp.ones((4, 8), dtype=p.dtype)) 113 | 114 | 115 | if __name__ == '__main__': 116 | absltest.main() 117 | -------------------------------------------------------------------------------- /praxis/layers/quantization/sparsity/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Description: 17 | # Sparsity related layers. The public API is defined in __init__.py. 18 | 19 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test") 21 | 22 | package(default_visibility = JAX_VISIBILITY) 23 | 24 | pytype_strict_library( 25 | name = "layers", 26 | srcs = ["__init__.py"], 27 | ) 28 | 29 | pytype_strict_library( 30 | name = "sparsity_hparams", 31 | srcs = ["sparsity_hparams.py"], 32 | deps = [":sparsity_modes"], 33 | ) 34 | 35 | pytype_strict_library( 36 | name = "sparsity", 37 | srcs = ["sparsity.py"], 38 | deps = [ 39 | ":sparsity_hparams", 40 | # Implicit jax dependency. 41 | ], 42 | ) 43 | 44 | pytype_strict_test( 45 | name = "sparsity_test", 46 | srcs = ["sparsity_test.py"], 47 | deps = [ 48 | ":sparsity", 49 | ":sparsity_hparams", 50 | # Implicit absl.testing.absltest dependency. 51 | # Implicit absl.testing.parameterized dependency. 52 | # Implicit upb python proto dependency. 53 | # Implicit jax dependency. 54 | # Implicit numpy dependency. 55 | ], 56 | ) 57 | 58 | pytype_strict_library( 59 | name = "sparsity_modes", 60 | srcs = ["sparsity_modes.py"], 61 | deps = [ 62 | # Implicit jax dependency. 
63 | "//praxis:pytypes", 64 | ], 65 | ) 66 | 67 | pytype_strict_test( 68 | name = "linears_test", 69 | srcs = ["linears_test.py"], 70 | deps = [ 71 | ":sparsity_hparams", 72 | ":sparsity_modes", 73 | # Implicit absl.logging dependency. 74 | # Implicit absl.testing.absltest dependency. 75 | # Implicit absl.testing.parameterized dependency. 76 | # Implicit upb python proto dependency. 77 | # Implicit jax dependency. 78 | # Implicit numpy dependency. 79 | "//praxis:base_layer", 80 | "//praxis:pax_fiddle", 81 | "//praxis:test_utils", 82 | "//praxis/layers:linears", 83 | "//praxis/layers/quantization:linears", 84 | ], 85 | ) 86 | 87 | pytype_strict_test( 88 | name = "attentions_test", 89 | srcs = ["attentions_test.py"], 90 | deps = [ 91 | ":sparsity_hparams", 92 | ":sparsity_modes", 93 | # Implicit absl.testing.absltest dependency. 94 | # Implicit absl.testing.parameterized dependency. 95 | # Implicit upb python proto dependency. 96 | # Implicit jax dependency. 97 | # Implicit numpy dependency. 98 | "//praxis:base_layer", 99 | "//praxis:pax_fiddle", 100 | "//praxis:py_utils", 101 | "//praxis:test_utils", 102 | "//praxis/layers:attentions", 103 | "//praxis/layers/quantization:attentions", 104 | ], 105 | ) 106 | 107 | pytype_strict_library( 108 | name = "sparsifier", 109 | srcs = ["sparsifier.py"], 110 | deps = [ 111 | ":sparsity", 112 | ":sparsity_hparams", 113 | ":sparsity_modes", 114 | # Implicit jax dependency. 115 | "//praxis:base_layer", 116 | "//praxis:pytypes", 117 | ], 118 | ) 119 | 120 | pytype_strict_test( 121 | name = "sparsifier_test", 122 | srcs = ["sparsifier_test.py"], 123 | deps = [ 124 | ":sparsifier", 125 | ":sparsity_hparams", 126 | ":sparsity_modes", 127 | # Implicit absl.testing.absltest dependency. 128 | # Implicit absl.testing.parameterized dependency. 129 | # Implicit upb python proto dependency. 130 | # Implicit jax dependency. 131 | # Implicit numpy dependency. 132 | "//praxis:base_layer", 133 | "//praxis:pax_fiddle", 134 | "//praxis:test_utils", 135 | "//praxis/layers:linears", 136 | "//praxis/layers/quantization:layers", 137 | ], 138 | ) 139 | -------------------------------------------------------------------------------- /praxis/layers/quantization/sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Exposes the public layer functionalities.""" 17 | -------------------------------------------------------------------------------- /praxis/layers/quantizer_objectives.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """A set of objective functions for building quantizers (e.g., VQ-VAE)."""
17 | 
18 | import jax
19 | import jax.numpy as jnp
20 | from praxis import pytypes
21 | 
22 | JTensor = pytypes.JTensor
23 | 
24 | 
25 | def scatter_nd(indices, updates, shape):  # tf.scatter_nd-style scatter-add.
26 |   zeros = jnp.zeros(shape, updates.dtype)
27 |   key = tuple(jnp.moveaxis(indices, -1, 0))  # One index array per dimension.
28 |   return zeros.at[key].add(updates)
29 | 
30 | 
31 | def batch_pplx_entropy_from_codes(
32 |     codes: JTensor,
33 |     num_classes: int,
34 |     *,
35 |     paddings: JTensor | None = None,
36 |     data_parallel_axis: str | None = None
37 | ):
38 |   """Calculates pplx and entropy from probs within batch.
39 | 
40 |   Args:
41 |     codes: [..., num_groups] with values in [0, num_classes).
42 |     num_classes: A Python int.
43 |     paddings: [...], 0/1 value tensor.
44 |     data_parallel_axis: If this is set, we sum over replicas to get the
45 |       combined statistics.
46 | 
47 |   Returns:
48 |     A tuple of 3 tensors:
49 |     - pplx: scalar, avg_across_groups(avg(non-padded samples of a group))
50 |     - entropy: scalar, avg_across_groups(avg(non-padded samples of a group))
51 |     - histogram: [g, c], code word counts.
52 |   """
53 |   rank = len(codes.shape)
54 |   assert rank is not None
55 |   assert rank >= 2
56 | 
57 |   is_in_pmap = data_parallel_axis is not None
58 | 
59 |   codes = codes.astype(jnp.int32)
60 |   if paddings is None:
61 |     paddings = jnp.zeros_like(codes[..., 0], dtype=codes.dtype)
62 |   else:
63 |     paddings = paddings.astype(codes.dtype)
64 | 
65 |   num_groups = codes.shape[-1]
66 |   # [?, g]
67 |   codes = jnp.reshape(codes, [-1, num_groups])
68 |   paddings = jnp.reshape(paddings, [-1, 1])
69 |   paddings = jnp.broadcast_to(paddings, codes.shape)
70 | 
71 |   # [g]
72 |   indices_offset = jnp.arange(
73 |       start=0, stop=num_groups * num_classes, step=num_classes, dtype=jnp.int32)
74 |   # [?, g]
75 |   indices = codes + indices_offset
76 |   # [? * g, 1]
77 |   indices = jnp.reshape(indices, [-1])[:, jnp.newaxis]
78 | 
79 |   # [? * g]
80 |   mask = (1.0 - paddings).astype(jnp.float32)
81 |   values = jnp.reshape(mask, [-1])
82 | 
83 |   # [g * c]
84 |   histogram = scatter_nd(indices, values, [num_groups * num_classes])
85 |   normalizer = jnp.sum(values) / num_groups
86 | 
87 |   if is_in_pmap:
88 |     histogram = jax.lax.psum(histogram, data_parallel_axis)
89 |     normalizer = jax.lax.psum(normalizer, data_parallel_axis)
90 |   # [g, c]
91 |   histogram = jnp.reshape(histogram, [num_groups, num_classes])
92 | 
93 |   # [g, c]
94 |   probs = histogram / jnp.maximum(normalizer, 1.0)
95 |   log_probs = jnp.log(jnp.maximum(1.0e-30, probs))
96 |   # [g]
97 |   sum_plogp = jnp.sum(log_probs * probs, -1)
98 |   pplx = jnp.mean(jnp.exp(-sum_plogp))
99 |   entropy = jnp.log(pplx)
100 |   return pplx, entropy, histogram
101 | 
102 | 
103 | def batch_codebook_coverage(
104 |     codes: JTensor,
105 |     num_classes: int,
106 |     *,
107 |     paddings: JTensor,
108 |     data_parallel_axis: str | None = None
109 | ):
110 |   """Computes codebook coverage within a batch.
111 | 
112 |   Args:
113 |     codes: [..., num_groups], values are in [0, num_classes).
114 |     num_classes: A Python int.
115 |     paddings: [...], 0/1 value tensor.
116 |     data_parallel_axis: If set, psum() is applied over this axis.
117 | 
118 |   Returns:
119 |     A scalar JTensor, avg coverage across groups.
120 |   """
121 |   # [num_groups, num_classes]
122 |   _, _, histogram = batch_pplx_entropy_from_codes(
123 |       codes,
124 |       num_classes,
125 |       paddings=paddings,
126 |       data_parallel_axis=data_parallel_axis)
127 |   onehot = jnp.greater(histogram, 0).astype(jnp.float32)
128 |   avg_num_covered_words = jnp.mean(jnp.sum(onehot, -1))
129 |   return avg_num_covered_words / num_classes
130 | 
-------------------------------------------------------------------------------- /praxis/layers/quantizer_objectives_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Tests for quantizer_objectives."""
17 | 
18 | from absl.testing import absltest
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from praxis import test_utils
22 | from praxis.layers import quantizer_objectives
23 | 
24 | 
25 | class CodebookObjectivesTest(test_utils.TestCase):
26 |   codes = np.array([
27 |       [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [0, 0]],
28 |       [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [0, 0]],
29 |   ])
30 | 
31 |   paddings = np.array([[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1]])
32 |   entropy = 1.609
33 |   pplx = 5.000
34 |   num_classes = 11
35 | 
36 |   def test_batch_pplx_entropy_from_codes(self):
37 |     pplx, entropy, _ = quantizer_objectives.batch_pplx_entropy_from_codes(
38 |         codes=jnp.array(self.codes),
39 |         num_classes=self.num_classes,
40 |         paddings=jnp.array(self.paddings),
41 |     )
42 | 
43 |     self.assertAlmostEqual(
44 |         np.array(pplx),
45 |         np.array(self.pplx),
46 |         delta=1e-3,
47 |         msg='PPLX is not the same',
48 |     )
49 |     self.assertAlmostEqual(
50 |         np.array(entropy),
51 |         np.array(self.entropy),
52 |         delta=1e-3,
53 |         msg='Entropy is not the same',
54 |     )
55 | 
56 | 
57 | if __name__ == '__main__':
58 |   absltest.main()
59 | 
-------------------------------------------------------------------------------- /praxis/layers/quantizer_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
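# Worked check for the constants in quantizer_objectives_test.py above: per
# group, the non-padded frames carry 5 distinct code values with equal counts,
# so each probability is 1/5, entropy = -sum(p * log(p)) = ln(5) ~= 1.609, and
# pplx = exp(entropy) = 5.000.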
15 | 16 | """Tests for quantizer.""" 17 | 18 | from absl.testing import absltest 19 | from praxis import pax_fiddle 20 | from absl.testing import parameterized 21 | 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | from praxis import base_layer 26 | from praxis import test_utils 27 | from praxis.layers import quantizer 28 | 29 | instantiate = base_layer.instantiate 30 | 31 | 32 | class VectorQuantizerTest(test_utils.TestCase): 33 | 34 | W = np.array([[0.116230249, 0.0104732513, -0.409445882, -0.153374314], 35 | [-0.0672334433, -0.430877686, -0.280010223, 0.394074917], 36 | [-0.360892653, -0.153173685, -0.45321393, -0.176380157], 37 | [0.406187773, 0.304340839, 0.439772606, 0.368542314]]) 38 | 39 | def _GetParams(self, num_classes, latent_dim): 40 | return pax_fiddle.Config( 41 | quantizer.VectorQuantizer, 42 | name='vq', 43 | normalize_latent_vector=True, 44 | normalize_codebook=True, 45 | num_latent_classes=num_classes, 46 | latent_dim=latent_dim, 47 | beta=0.1, 48 | ) 49 | 50 | def testBase(self): 51 | num_classes = 4 52 | latent_dim = 4 53 | 54 | b, t = 2, 4 55 | np.random.seed(2021) 56 | z = np.random.rand(b, t, latent_dim).astype(np.float32) 57 | paddings = np.zeros((b, t)).astype(np.float32) 58 | 59 | vq_p = self._GetParams(num_classes, latent_dim) 60 | vq = instantiate(vq_p) 61 | vq_theta = vq.init(jax.random.PRNGKey(1), z, paddings) 62 | vq_theta['params']['w'] = jnp.expand_dims(self.W, 1) 63 | out = vq.apply(vq_theta, z, paddings) 64 | 65 | with self.subTest('test_shape'): 66 | self.assertEqual((b, t, latent_dim), out.z_q.shape) 67 | self.assertEqual((b, t, 1), out.z_codes.shape) 68 | self.assertEqual((b, t, 1, num_classes), out.z_onehot.shape) 69 | with self.subTest('test_z_q'): 70 | self.assertAllClose(15.861525, np.sum(out.z_q)) 71 | with self.subTest('test_z_codes'): 72 | self.assertEqual(24, np.sum(out.z_codes)) 73 | with self.subTest('test_codebook_coverage'): 74 | self.assertEqual(0.25, np.sum(out.codebook_coverage)) 75 | with self.subTest('test_pplx'): 76 | self.assertEqual(1.0, out.pplx) 77 | with self.subTest('test_entropy'): 78 | self.assertAllClose(0., out.entropy) 79 | 80 | def testNCodebooks(self): 81 | num_classes = 4 82 | latent_dim = 4 83 | num_groups = 2 84 | 85 | b, t = 2, 4 86 | np.random.seed(2021) 87 | z = np.random.rand(b, t, latent_dim).astype(np.float32) 88 | paddings = np.zeros((b, t)).astype(np.float32) 89 | 90 | vq_p = self._GetParams(num_classes, latent_dim) 91 | vq_p.num_groups = num_groups 92 | vq = instantiate(vq_p) 93 | vq_theta = vq.init(jax.random.PRNGKey(1), z, paddings) 94 | out = vq.apply(vq_theta, z, paddings) 95 | 96 | with self.subTest('test_shape'): 97 | self.assertEqual((b, t, latent_dim), out.z_q.shape) 98 | self.assertEqual((b, t, num_groups), out.z_codes.shape) 99 | self.assertEqual((b, t, num_groups, num_classes), out.z_onehot.shape) 100 | 101 | 102 | class RandomVectorQuantizerTest(test_utils.TestCase): 103 | 104 | @parameterized.parameters( 105 | (2, 4, 20, 16, 4, 1), 106 | (3, 4, 20, 32, 4, 2), 107 | (4, 7, 16, 256, 20, 8), 108 | ) 109 | def testBase(self, b, t, latent_dim, projection_dim, num_classes, num_groups): 110 | np.random.seed(2022) 111 | z = np.random.rand(b, t, latent_dim).astype(np.float32) 112 | paddings = np.zeros((b, t)).astype(np.float32) 113 | 114 | rq = pax_fiddle.Config( 115 | quantizer.RandomVectorQuantizer, 116 | name='vq', 117 | num_latent_classes=num_classes, 118 | num_groups=num_groups, 119 | latent_dim=latent_dim, 120 | projection_dim=projection_dim, 121 | ) 122 | rq = 
instantiate(rq)
123 |     rq_theta = rq.init(jax.random.PRNGKey(1), z, paddings)
124 |     out = rq.apply(rq_theta, z, paddings)
125 |     self.assertEqual((b, t, projection_dim), out.z_q.shape)
126 |     self.assertEqual((b, t, num_groups), out.z_codes.shape)
127 |     self.assertEqual((b, t, num_groups, num_classes), out.z_onehot.shape)
128 | 
129 | if __name__ == '__main__':
130 |   absltest.main()
131 | 
-------------------------------------------------------------------------------- /praxis/layers/searchable.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Layers for Platform-aware AutoML Search."""
17 | 
18 | from typing import Any, Sequence
19 | 
20 | from flax import linen as nn
21 | import jax.numpy as jnp
22 | from praxis import base_layer
23 | from praxis import pax_fiddle
24 | from praxis import pytypes
25 | 
26 | JTensor = pytypes.JTensor
27 | LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
28 | NestedJTensor = pytypes.NestedJTensor
29 | template_field = pax_fiddle.template_field
30 | WeightInit = base_layer.WeightInit
31 | WeightHParams = base_layer.WeightHParams
32 | 
33 | template_field = pax_fiddle.template_field
34 | 
35 | 
36 | class AutoMLSelect(base_layer.BaseLayer):
37 |   """AutoMLSelect layer to switch branches according to AutoML decisions.
38 | 
39 |   Attributes:
40 |     search_options_tpl: A sequence of layer templates, each representing a
41 |       branch. The layers must have the same input and output shapes.
42 |   """
43 | 
44 |   search_options_tpl: Sequence[LayerTpl] | None = template_field(None)
45 | 
46 |   def setup(self) -> None:
47 |     if not self.search_options_tpl:
48 |       raise AttributeError('Must set at least one search option.')
49 |     decision = WeightHParams(
50 |         shape=[],
51 |         init=WeightInit.Constant(0),
52 |         dtype=jnp.uint8,
53 |         mesh_shape=self.mesh_shape
54 |     )
55 |     self.create_children('search_options', self.search_options_tpl)
56 |     self.create_variable('decision', decision, trainable=False)
57 | 
58 |   def __call__(
59 |       self,
60 |       x: JTensor,
61 |       *args: Any,
62 |       **kwargs: Any,
63 |   ) -> NestedJTensor:
64 |     def branch_fn(i):
65 |       return lambda mdl, x: mdl.search_options[i](x)
66 | 
67 |     branches = [branch_fn(i) for i in range(len(self.search_options))]
68 | 
69 |     if self.is_mutable_collection('params'):
70 |       for branch in branches:
71 |         _ = branch(self, x)
72 | 
73 |     return nn.switch(self.get_var('decision'), branches, self, x)
74 | 
-------------------------------------------------------------------------------- /praxis/layers/searchable_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for praxis.layers.searchable.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | import jax.numpy as jnp 21 | from praxis import base_hyperparams 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import test_utils 25 | from praxis.layers import linears 26 | from praxis.layers import searchable 27 | 28 | instantiate = base_hyperparams.instantiate 29 | weight_init = base_layer.WeightInit 30 | 31 | 32 | class AutomlSelectTest(test_utils.TestCase): 33 | 34 | def test_automl_select(self): 35 | p = pax_fiddle.Config(searchable.AutoMLSelect) 36 | p.search_options_tpl = [ 37 | pax_fiddle.Config( 38 | linears.Linear, 39 | name='jax_ffn0', 40 | weight_init=weight_init.Constant(0.0), 41 | input_dims=1, 42 | output_dims=1), 43 | pax_fiddle.Config( 44 | linears.Linear, 45 | name='jax_ffn1', 46 | weight_init=weight_init.Constant(1.0), 47 | input_dims=1, 48 | output_dims=1), 49 | ] 50 | 51 | m = instantiate(p) 52 | x = jnp.ones(1, dtype=jnp.float32) 53 | m_vars = m.init(jax.random.PRNGKey(0), x) 54 | m_vars['non_trainable']['decision'] = 0 55 | x1 = m.apply(m_vars, x) 56 | m_vars['non_trainable']['decision'] = 1 57 | x2 = m.apply(m_vars, x) 58 | self.assertAllClose(x1, jnp.zeros(1, dtype=jnp.float32)) 59 | self.assertAllClose(x2, jnp.ones(1, dtype=jnp.float32)) 60 | 61 | 62 | if __name__ == '__main__': 63 | absltest.main() 64 | -------------------------------------------------------------------------------- /praxis/layers/sequential.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
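# Example wiring for the Sequential layer defined below (mirrors
# sequential_test.py; illustrative sketch only):
#
#   seq_p = pax_fiddle.Config(
#       Sequential, layers=[specaug_p, feature_p, encoder_p, decodingprep_p])
#   seq = base_layer.instantiate(seq_p)
#   out = seq.apply(initial_vars, features, paddings)
#
# Tuple outputs are splatted into the next layer as positional arguments,
# dict/NestedMap outputs as keyword arguments, and anything else is passed as
# a single positional argument.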
15 | 16 | """A sequential stack of layers for Pax.""" 17 | 18 | from typing import Any, Callable, Sequence 19 | 20 | from praxis import base_layer 21 | 22 | 23 | class Sequential(base_layer.BaseLayer): 24 | """Applies a linear chain of Modules.""" 25 | layers: Sequence[Callable[..., Any]] | None = None 26 | 27 | def __call__(self, *args, **kwargs): 28 | if not self.layers: 29 | raise ValueError(f'Empty Sequential module {self.name}.') 30 | 31 | outputs = self.layers[0](*args, **kwargs) 32 | for layer in self.layers[1:]: 33 | if isinstance(outputs, tuple): 34 | outputs = layer(*outputs) 35 | elif isinstance(outputs, dict): 36 | outputs = layer(**outputs) 37 | else: 38 | outputs = layer(outputs) 39 | return outputs 40 | -------------------------------------------------------------------------------- /praxis/layers/sequential_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Unittest for Sequential.""" 17 | 18 | from absl.testing import absltest 19 | import jax 20 | from jax import numpy as jnp 21 | 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import py_utils 25 | from praxis import pytypes 26 | from praxis import test_utils 27 | from praxis.layers.sequential import Sequential 28 | 29 | JTensor = pytypes.JTensor 30 | NestedMap = py_utils.NestedMap 31 | instantiate = base_layer.instantiate 32 | 33 | PARAMS = base_layer.PARAMS 34 | RANDOM = base_layer.RANDOM 35 | AUX_LOSS = base_layer.AUX_LOSS 36 | SUMMARIES = base_layer.SUMMARIES 37 | NON_TRAINABLE = base_layer.NON_TRAINABLE 38 | NON_PAX_VAR_COLLECTION = base_layer.NON_PAX_VAR_COLLECTION 39 | DECODE_CACHE = base_layer.DECODE_CACHE 40 | DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST 41 | 42 | 43 | class Feature(base_layer.BaseLayer): 44 | def __call__(self, features: JTensor, paddings: JTensor) -> NestedMap: 45 | return NestedMap(features=features, paddings=paddings) 46 | 47 | 48 | class SpecAugment(base_layer.BaseLayer): 49 | def __call__(self, features: JTensor, paddings: JTensor) -> NestedMap: 50 | return NestedMap(features=features, paddings=paddings) 51 | 52 | 53 | class Encoder(base_layer.BaseLayer): 54 | def __call__(self, features: JTensor, paddings: JTensor) -> NestedMap: 55 | return NestedMap(embeddings=features, paddings=paddings) 56 | 57 | 58 | class DecodingPrep(base_layer.BaseLayer): 59 | def __call__(self, embeddings: JTensor, paddings: JTensor) -> NestedMap: 60 | return NestedMap(embeddings=embeddings, paddings=paddings) 61 | 62 | 63 | class SequentialTest(test_utils.TestCase): 64 | 65 | def test_simple_sequence(self): 66 | specaug_p = pax_fiddle.Config(SpecAugment, name='specaugment') 67 | feature_p = pax_fiddle.Config(Feature, name='feature') 68 | encoder_p = pax_fiddle.Config(Encoder, name='encoder') 69 | decodingprep_p = pax_fiddle.Config(DecodingPrep, name='decodingprep') 
70 | 71 | sequence_p = pax_fiddle.Config( 72 | Sequential, layers=[specaug_p, feature_p, encoder_p, decodingprep_p]) 73 | sequence = instantiate(sequence_p) 74 | 75 | features = jax.random.normal(jax.random.PRNGKey(123), shape=[2, 1, 4]) 76 | paddings = jnp.zeros(shape=[2, 1]) 77 | 78 | k1 = jax.random.PRNGKey(123) 79 | k2 = jax.random.PRNGKey(456) 80 | with base_layer.JaxContext.new_context(): 81 | initial_vars = sequence.init( 82 | rngs={RANDOM: k1, PARAMS: k2}, 83 | mutable=DEFAULT_INIT_MUTABLE_LIST, 84 | features=features, 85 | paddings=paddings) 86 | 87 | outputs = sequence.apply(initial_vars, features, paddings) 88 | 89 | self.assertIn('embeddings', outputs) 90 | self.assertIn('paddings', outputs) 91 | self.assertAllClose(outputs.embeddings, features) 92 | 93 | 94 | if __name__ == '__main__': 95 | absltest.main() 96 | 97 | -------------------------------------------------------------------------------- /praxis/layers/sharding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Sharding utilities.""" 17 | 18 | import math 19 | import re 20 | from typing import Sequence 21 | 22 | import jax 23 | from jax.interpreters import pxla 24 | from praxis import base_layer 25 | from praxis import py_utils 26 | 27 | DimSharding = str | Sequence[str] | None 28 | Sharding = Sequence[DimSharding] | None 29 | 30 | 31 | def derive(s: Sharding, eqn: str) -> Sharding: 32 | """Derives a sharding based on an equation `original->derived`. 33 | 34 | Each letter in original and derived represents a named dimension, and the 35 | derivation is done by matching dimension names. E.g., with s=('x', 'y') and 36 | eqn="ab->cbda", the result will be (None, 'y', None, 'x'). 37 | 38 | Args: 39 | s: Source sharding. 40 | eqn: Derivation equation with named dimensions. 41 | 42 | Returns: 43 | The derived sharding. 44 | """ 45 | if s is None: 46 | return None 47 | pieces = eqn.split('->') 48 | assert len(pieces) == 2, eqn 49 | original, derived = pieces 50 | 51 | return tuple(s[original.index(d)] if d in original else None for d in derived) 52 | 53 | 54 | def shard(x: jax.Array, s: Sharding, eqn: str | None = None) -> jax.Array: 55 | """Annotates x with a sharding based on an optional equation. 56 | 57 | If equation is not specified, apply just jax.lax.with_sharding_constraint. 58 | 59 | In equation `original->derived`, each letter in original and derived 60 | represents a named dimension, and the derivation is done by matching 61 | dimension names. 62 | 63 | Each dim in `derived` can also be ?, which means unconstrained. 64 | 65 | E.g., with s=('x', 'y') and eqn="ab->cb?a", the derived sharding will be 66 | (None, 'y', unconstrained, 'x'). 67 | 68 | In `derived`, there can also be a group of consecutive dims marked optional, 69 | which are represented as dims inside `[]`. 
The tensor x can have either all of
70 |   these dims, or none of them.
71 | 
72 |   E.g., with s=('x', 'y', 'z') and eqn="abc->[ab]c", the derived sharding will
73 |   be ('x', 'y', 'z') for a 3D tensor x, and ('z',) for a 1D tensor x.
74 | 
75 |   Args:
76 |     x: The tensor to annotate.
77 |     s: Source sharding.
78 |     eqn: Derivation equation with named dimensions.
79 | 
80 |   Returns:
81 |     The input tensor annotated with the derived sharding.
82 |   """
83 |   if s is None or not py_utils.global_mesh_defined():
84 |     return x
85 | 
86 |   if eqn is not None:
87 |     original, derived = eqn.split('->')
88 |     if '[' in derived:
89 |       l, optional, r = re.split(r'\[|\]', derived)
90 | 
91 |       if x.ndim == len(l) + len(r):
92 |         derived = l + r
93 |       elif x.ndim == len(l) + len(optional) + len(r):
94 |         derived = l + optional + r
95 |       else:
96 |         raise ValueError(f'Given {derived=} is incompatible with {x=}')
97 | 
98 |     s = derive(s, f'{original}->{derived}')
99 |     assert s is not None
100 |     s = list(s)
101 |     for i, p in enumerate(derived):
102 |       if p == '?':
103 |         s[i] = jax.sharding.PartitionSpec.UNCONSTRAINED
104 | 
105 |   partition_spec = jax.sharding.PartitionSpec(*s)
106 | 
107 |   # If mesh_axes_transpose exists in the current context, device axes will be
108 |   # remapped according to the transpose rules.
109 |   partition_spec = base_layer.maybe_transpose_mesh_axes(partition_spec)
110 | 
111 |   return jax.lax.with_sharding_constraint(x, partition_spec)
112 | 
113 | 
114 | def get_dim_sharding(s: Sharding, dim: int) -> DimSharding:
115 |   """Returns the sharding on one dimension."""
116 |   if s is None:
117 |     return None
118 |   return s[dim]
119 | 
120 | 
121 | def shard_one_dim(x: jax.Array, s: DimSharding, dim: int) -> jax.Array:
122 |   """Annotates x on one dim while other dims are unconstrained."""
123 |   perm = '?' * dim + 'd' + '?' * (x.ndim - dim - 1)
124 |   return shard(x, (s,), 'd->' + perm)
125 | 
126 | 
127 | def num_shards_on_dim(dim_sharding: DimSharding) -> int:
128 |   """Returns the number of shards on one dimension in a sharding."""
129 |   mesh = pxla.thread_resources.env.physical_mesh
130 |   axis_sizes = dict(zip(mesh.axis_names, mesh.devices.shape))
131 | 
132 |   mapping = None
133 |   if base_layer.JaxContext.has_context():
134 |     mapping = base_layer.cur_jax_context().hparams.mesh_axes_transpose
135 | 
136 |   match dim_sharding:
137 |     case None:
138 |       return 1
139 |     case str():
140 |       return axis_sizes.get(
141 |           base_layer.transpose_one_axis(dim_sharding, mapping), 1
142 |       )
143 |     case _:
144 |       assert isinstance(dim_sharding, Sequence)
145 |       return math.prod(
146 |           axis_sizes.get(base_layer.transpose_one_axis(axis, mapping), 1)
147 |           for axis in dim_sharding
148 |       )
149 | 
-------------------------------------------------------------------------------- /praxis/layers/ssm_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Pax Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
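# Worked example for sharding.derive defined above (illustrative): with
# s = ('replica', 'mdl') and eqn = 'bd->bnd', the dim 'n' does not appear in
# the original 'bd', so the derived sharding is ('replica', None, 'mdl').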
15 | 16 | """Tests for Praxis SSM layers.""" 17 | 18 | from absl import logging 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | from jax import numpy as jnp 23 | import numpy as np 24 | from praxis import base_layer 25 | from praxis import pax_fiddle 26 | from praxis import py_utils 27 | from praxis import test_utils 28 | from praxis.layers import ssm 29 | import tensorflow.compat.v2 as tf 30 | 31 | to_np = test_utils.to_np 32 | to_tf_nmap = test_utils.to_tf_nmap 33 | instantiate = base_layer.instantiate 34 | 35 | 36 | class SSMTest(test_utils.TestCase): 37 | 38 | def setUp(self): 39 | super().setUp() 40 | np.random.seed(123456) 41 | tf.random.set_seed(123) 42 | 43 | @parameterized.named_parameters( 44 | { 45 | 'testcase_name': 'ss4d-1d-legs', 46 | 'hippo_type': 'ss4d-1d-legs', 47 | }, 48 | { 49 | 'testcase_name': 'ss4d-1d-lagt', 50 | 'hippo_type': 'ss4d-1d-lagt', 51 | }, 52 | { 53 | 'testcase_name': 'ss4d-1d', 54 | 'hippo_type': 'ss4d-1d', 55 | }, 56 | ) 57 | def test_s4d_layer(self, hippo_type): 58 | p = pax_fiddle.Config( 59 | ssm.SSM, 60 | name='ssm', 61 | nheads=5, 62 | dim=1, 63 | l_max=2, 64 | decode_num_samples=1, 65 | step_size=1., 66 | hippo_type=hippo_type 67 | ) 68 | s4d = instantiate(p) 69 | npy_input = np.random.normal(1.0, 0.5, 70 | [2, p.l_max, p.dim]).astype('float32') 71 | inputs = jnp.asarray(npy_input) 72 | prng_key = jax.random.PRNGKey(seed=123) 73 | initial_vars = s4d.init(prng_key, inputs) 74 | 75 | # Test convolution/fft. 76 | outputs, ssm_state = s4d.apply( 77 | initial_vars, inputs, mutable=[base_layer.DECODE_CACHE]) 78 | logging.info('outputs = %s', outputs) 79 | 80 | # Test extend_step. 81 | out_step = [] 82 | updated_vars = py_utils.merge_dict(ssm_state, initial_vars) 83 | logging.info('init_vars w state = %s', updated_vars) 84 | for i in range(p.l_max): 85 | out, ssm_state = s4d.apply( 86 | updated_vars, inputs[:, i, :], method=s4d.extend_step, 87 | mutable=[base_layer.DECODE_CACHE]) 88 | logging.info('outputs = %s', out) 89 | logging.info('ssm_states = %s', ssm_state) 90 | updated_vars['decoder_cache'] = ssm_state['decoder_cache'] 91 | out_step.append(out) 92 | 93 | out_step = jnp.stack(out_step, axis=1) 94 | 95 | # Make sure the convolution/fft gets the same results as extend_step. 96 | self.assertAllClose(to_np(outputs), to_np(out_step), atol=1e-6) 97 | 98 | 99 | class S5Test(test_utils.TestCase): 100 | 101 | def setUp(self): 102 | super().setUp() 103 | np.random.seed(123456) 104 | tf.random.set_seed(123) 105 | 106 | def test_s5_layer(self): 107 | p = pax_fiddle.Config( 108 | ssm.S5, 109 | name='ssm', 110 | nheads=5, 111 | dim=1, 112 | l_max=2, 113 | decode_num_samples=1, 114 | step_size=1.0, 115 | ) 116 | s5 = instantiate(p) 117 | npy_input = np.random.normal(1.0, 0.5, [2, p.l_max, p.dim]).astype( 118 | 'float32' 119 | ) 120 | inputs = jnp.asarray(npy_input) 121 | prng_key = jax.random.PRNGKey(seed=123) 122 | initial_vars = s5.init(prng_key, inputs) 123 | 124 | # Test convolution/parallel_scan. 125 | outputs, ssm_state = s5.apply( 126 | initial_vars, inputs, mutable=[base_layer.DECODE_CACHE] 127 | ) 128 | logging.info('outputs = %s', outputs) 129 | 130 | # Test extend_step. 
131 | out_step = [] 132 | updated_vars = py_utils.merge_dict(ssm_state, initial_vars) 133 | logging.info('init_vars w state = %s', updated_vars) 134 | for i in range(p.l_max): 135 | out, ssm_state = s5.apply( 136 | updated_vars, 137 | inputs[:, i, :], 138 | method=s5.extend_step, 139 | mutable=[base_layer.DECODE_CACHE], 140 | ) 141 | logging.info('outputs = %s', out) 142 | logging.info('ssm_states = %s', ssm_state) 143 | updated_vars['decoder_cache'] = ssm_state['decoder_cache'] 144 | out_step.append(out) 145 | 146 | out_step = jnp.stack(out_step, axis=1) 147 | 148 | # Make sure the convolution gets the same results as extend_step. 149 | self.assertAllClose(to_np(outputs), to_np(out_step), atol=1e-6) 150 | 151 | 152 | if __name__ == '__main__': 153 | absltest.main() 154 | -------------------------------------------------------------------------------- /praxis/layers/stats.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Compute stats of tensors mostly for monitoring purposes.""" 17 | 18 | from jax import numpy as jnp 19 | from praxis import py_utils 20 | from praxis import pytypes 21 | 22 | JTensor = pytypes.JTensor 23 | NestedMap = py_utils.NestedMap 24 | 25 | 26 | def compute_stats(inputs: JTensor, padding: JTensor | None = None) -> NestedMap: 27 | """Computes various stats over the valid data points in inputs.""" 28 | # Let's compute stats in fp32 29 | inputs = inputs.astype(jnp.float32) 30 | if padding is None: 31 | padding = jnp.zeros_like(inputs) 32 | assert inputs.ndim == padding.ndim, f'{inputs.shape}, {padding.shape}' 33 | mask = 1.0 - padding 34 | 35 | sum_v = jnp.sum(inputs * mask) 36 | count_v = jnp.sum(jnp.ones_like(inputs) * mask) 37 | mean_v = sum_v / jnp.maximum(1.0, count_v) 38 | sum_v_squared = jnp.sum(jnp.square((inputs - mean_v) * mask)) 39 | std_v = jnp.sqrt(sum_v_squared / jnp.maximum(1.0, count_v)) 40 | max_v = jnp.max(jnp.abs(inputs * mask)) 41 | 42 | return NestedMap(mean_v=mean_v, std_v=std_v, max_v=max_v) 43 | -------------------------------------------------------------------------------- /praxis/layers/stats_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
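# Worked check for test_compute_stats below: the padding masks out the value
# 5, so the valid entries are [0, 1, 2, 3, 4]; mean = 10 / 5 = 2, variance =
# (4 + 1 + 0 + 1 + 4) / 5 = 2, std = sqrt(2) ~= 1.414214, and max |x| = 4.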
15 | 16 | """Tests for Praxis attention layers.""" 17 | 18 | from absl.testing import absltest 19 | from jax import numpy as jnp 20 | import numpy as np 21 | from praxis import test_utils 22 | from praxis.layers import stats 23 | 24 | 25 | class StatsTest(test_utils.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | np.random.seed(12345687) 30 | 31 | def test_compute_stats(self): 32 | inputs = jnp.array([[0., 1., 2.], [3., 4., 5.]]) 33 | padding = jnp.array([[0., 0., 0.], [0., 0., 1.]]) 34 | inputs_stats = stats.compute_stats(inputs, padding) 35 | self.assertAllClose( 36 | [2., 1.414214, 4.], 37 | [inputs_stats.mean_v, inputs_stats.std_v, inputs_stats.max_v]) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /praxis/layers/stochastics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Stochastic layers.""" 17 | 18 | import numbers 19 | from typing import Sequence 20 | 21 | import jax 22 | from jax import numpy as jnp 23 | from praxis import base_layer 24 | from praxis import py_utils 25 | from praxis import pytypes 26 | 27 | NestedMap = py_utils.NestedMap 28 | JTensor = pytypes.JTensor 29 | 30 | 31 | class Dropout(base_layer.BaseLayer): 32 | """Apply dropout during training. 33 | 34 | Attributes: 35 | keep_prob: Keep probability. 36 | noise_shape: A 1-D list of type `int32`, representing the shape for randomly 37 | generated keep/drop flags. Note that this noise_shape is unknown, when 38 | building layer params. 39 | noise_shape_broadcast_dims: A list of dimension where the noise shape is 40 | broadcasted. For example, noise_shape = [n, h, w, 1] when 41 | noise_shape_broadcast_dims=[-1]. 42 | dropout_at_eval: Whether or not to also perform dropout at eval time. We 43 | typically want to replace dropout by expectation during eval. However, in 44 | certain cases E(f(x)) != f(E(x)), and replacing dropout by its expectation 45 | during eval leads to worse quality. 46 | transpose_qk: Whether or not to transpose the sequence length dimensions 47 | correspond to query and key in the generated random numbers to avoid 48 | expensive copies. 
49 | """ 50 | keep_prob: float = 1.0 51 | noise_shape: Sequence[int] | None = None 52 | noise_shape_broadcast_dims: Sequence[int] | None = None 53 | dropout_at_eval: bool = False 54 | transpose_qk: bool | None = False 55 | 56 | def _dropout(self, inputs: JTensor, noise_shape: list[int]) -> JTensor: 57 | if noise_shape is None: 58 | noise_shape = inputs.shape 59 | prng_key = self.next_prng_key() 60 | keep_prob = self.keep_prob 61 | assert keep_prob > 0.0 62 | random_nums = keep_prob + jax.random.uniform( 63 | prng_key, noise_shape, inputs.dtype, minval=0.0, maxval=1.0) 64 | transpose = ( 65 | len(noise_shape) == 4 66 | and noise_shape[3] == noise_shape[2] 67 | and self.transpose_qk 68 | ) 69 | if transpose: 70 | random_nums = jnp.transpose(random_nums, (0, 1, 3, 2)) 71 | binary_mask = jnp.floor(random_nums) 72 | return inputs * binary_mask / keep_prob 73 | 74 | def __call__(self, inputs: JTensor) -> JTensor: 75 | """Applies dropout to inputs. 76 | 77 | Args: 78 | inputs: The inputs JTensor. 79 | 80 | Returns: 81 | inputs with dropout applied at training time. 82 | """ 83 | if isinstance(self.keep_prob, numbers.Real) and self.keep_prob == 1.0: 84 | return inputs 85 | 86 | if self.do_eval and not self.dropout_at_eval: 87 | return inputs 88 | 89 | if not self.noise_shape_broadcast_dims: 90 | noise_shape = self.noise_shape 91 | else: 92 | noise_shape = self.noise_shape or list(inputs.shape) 93 | if not isinstance(noise_shape, list): 94 | noise_shape = list(noise_shape) 95 | for dim in self.noise_shape_broadcast_dims: 96 | if dim >= len(noise_shape): 97 | raise ValueError('Invalid broadcasted dim {}'.format(dim)) 98 | noise_shape[dim] = 1 99 | 100 | ret = self._dropout(inputs, noise_shape) 101 | return ret 102 | 103 | 104 | class StochasticResidual(base_layer.BaseLayer): 105 | """Stochastic residual layer that randomly drops the residual branch. 106 | 107 | Attributes: 108 | residual_weight: Residual weight with which to add the reisdual back to the 109 | input. 110 | survival_prob: Survival probability of the residual branch while dropping 111 | out. 112 | """ 113 | residual_weight: float = 1.0 114 | survival_prob: float = 1.0 115 | 116 | def _drop_connect(self, inputs: JTensor) -> JTensor: 117 | """Drops the entire residual layer with given survival probability. 118 | 119 | Args: 120 | inputs: input `.JTensor` which is on the residual branch which is dropped. 121 | 122 | Returns: 123 | Dropped out inputs. 124 | """ 125 | if self.do_eval: 126 | return inputs 127 | 128 | # Compute tensor. 129 | prng_key = self.next_prng_key() 130 | batch_size = inputs.shape[0] 131 | shape = [batch_size] + [1] * (len(inputs.shape) - 1) 132 | random_tensor = self.survival_prob + jax.random.uniform( 133 | prng_key, shape, dtype=inputs.dtype 134 | ) 135 | binary_tensor = jnp.floor(random_tensor) 136 | # Unlike conventional way that multiply survival_prob at test time, here we 137 | # divide survival_prob at training time, such that no additional compute is 138 | # needed at test time. 139 | output = inputs / self.survival_prob * binary_tensor 140 | return output 141 | 142 | def __call__(self, inputs: JTensor, residual: JTensor) -> JTensor: 143 | """Returns inputs + residual with stochastic dropout. 144 | 145 | Args: 146 | inputs: input `.JTensor`. 147 | residual: residual `.JTensor` which is added to input with dropout. 148 | 149 | Returns: 150 | Output `.JTensor` which is residual added to inputs with dropout. 
151 | """ 152 | return inputs + self.residual_weight * self._drop_connect(residual) 153 | -------------------------------------------------------------------------------- /praxis/layers/stochastics_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for Praxis stochastic layers.""" 17 | 18 | from absl import logging 19 | from absl.testing import absltest 20 | import jax 21 | from jax import numpy as jnp 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import test_utils 25 | from praxis.layers import stochastics 26 | 27 | jax.config.update('jax_threefry_partitionable', False) 28 | 29 | instantiate = base_layer.instantiate 30 | 31 | 32 | class StochasticsTest(test_utils.TestCase): 33 | 34 | def test_dropout_layer01(self): 35 | test_layer_p = pax_fiddle.Config( 36 | stochastics.Dropout, name='dropout', keep_prob=0.8 37 | ) 38 | layer = instantiate(test_layer_p) 39 | 40 | inputs = jnp.ones([10, 1000], dtype=jnp.bfloat16) 41 | 42 | with base_layer.JaxContext.new_context(): 43 | prng_key = jax.random.PRNGKey(seed=12346) 44 | prng_key, init_key = jax.random.split(prng_key) 45 | prng_key, dropout_k1, dropout_k2 = jax.random.split(prng_key, 3) 46 | initial_vars = layer.init({ 47 | 'random': dropout_k1, 48 | 'params': init_key 49 | }, inputs) 50 | logging.info('initial_vars: %s', initial_vars) 51 | output1 = layer.apply(initial_vars, inputs, rngs={'random': dropout_k1}) 52 | output2 = layer.apply(initial_vars, inputs, rngs={'random': dropout_k2}) 53 | 54 | out1_sum = jnp.sum(output1) 55 | out2_sum = jnp.sum(output2) 56 | out1_nonzero = jnp.sum(output1 > 0.0) 57 | out2_nonzero = jnp.sum(output2 > 0.0) 58 | 59 | logging.info('out1_sum: %s', out1_sum) 60 | logging.info('out2_sum: %s', out2_sum) 61 | logging.info('out1_nonzero: %s', out1_nonzero) 62 | logging.info('out2_nonzero: %s', out2_nonzero) 63 | 64 | finfo = jnp.finfo(inputs.dtype) 65 | nmant = finfo.nmant 66 | self.assertEqual(9984.0, out1_sum) 67 | if nmant < 8: 68 | self.assertEqual(9984.0, out2_sum) 69 | self.assertEqual(7983.0, out1_nonzero) 70 | self.assertEqual(7964.0, out2_nonzero) 71 | else: 72 | self.assertEqual(9920.0, out2_sum) 73 | self.assertEqual(8000.0, out1_nonzero) 74 | self.assertEqual(7952.0, out2_nonzero) 75 | 76 | def test_dropout_layer_02(self): 77 | test_layer_p = pax_fiddle.Config( 78 | stochastics.Dropout, 79 | name='dropout', 80 | keep_prob=0.8, 81 | noise_shape=[10, 6, 8], 82 | noise_shape_broadcast_dims=[2], 83 | ) 84 | layer = instantiate(test_layer_p) 85 | 86 | inputs = jnp.ones([2, 10, 6, 8], dtype=jnp.bfloat16) 87 | 88 | with base_layer.JaxContext.new_context(): 89 | prng_key = jax.random.PRNGKey(seed=12346) 90 | prng_key, init_key = jax.random.split(prng_key) 91 | prng_key, compute_key = jax.random.split(prng_key) 92 | initial_vars = layer.init({ 93 | 'random': compute_key, 94 | 
'params': init_key 95 | }, inputs) 96 | logging.info('initial_vars: %s', initial_vars) 97 | output1 = layer.apply(initial_vars, inputs, rngs={'random': compute_key}) 98 | 99 | out1_sum = jnp.sum(output1) 100 | out1_nonzero = jnp.sum(output1 > 0.0) 101 | 102 | logging.info('out1_sum: %s', out1_sum) 103 | logging.info('out1_nonzero: %s', out1_nonzero) 104 | 105 | self.assertEqual(980, out1_sum) 106 | self.assertEqual(784, out1_nonzero) 107 | 108 | def test_dropout_layer_03(self): 109 | test_layer_p = pax_fiddle.Config( 110 | stochastics.Dropout, 111 | name='dropout', 112 | keep_prob=0.8, 113 | noise_shape_broadcast_dims=[0, 3], 114 | ) 115 | layer = instantiate(test_layer_p) 116 | 117 | inputs = jnp.ones([2, 10, 6, 8], dtype=jnp.bfloat16) 118 | 119 | with base_layer.JaxContext.new_context(): 120 | prng_key = jax.random.PRNGKey(seed=12346) 121 | prng_key, init_key = jax.random.split(prng_key) 122 | prng_key, compute_key = jax.random.split(prng_key) 123 | initial_vars = layer.init({ 124 | 'random': compute_key, 125 | 'params': init_key 126 | }, inputs) 127 | logging.info('initial_vars: %s', initial_vars) 128 | 129 | output1 = layer.apply(initial_vars, inputs, rngs={'random': compute_key}) 130 | 131 | out1_sum = jnp.sum(output1) 132 | out1_nonzero = jnp.sum(output1 > 0.0) 133 | 134 | logging.info('out1_sum: %s', out1_sum) 135 | logging.info('out1_nonzero: %s', out1_nonzero) 136 | 137 | self.assertEqual(980, out1_sum) 138 | self.assertEqual(784, out1_nonzero) 139 | 140 | 141 | if __name__ == '__main__': 142 | absltest.main() 143 | -------------------------------------------------------------------------------- /praxis/layers/vanillanets_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for vanillanets.""" 17 | 18 | from absl.testing import absltest 19 | from praxis import pax_fiddle 20 | from absl.testing import parameterized 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | from praxis import base_layer 25 | from praxis.layers import poolings 26 | from praxis.layers import vanillanets 27 | 28 | instantiate = base_layer.instantiate 29 | 30 | 31 | class VanillanetsTest(parameterized.TestCase): 32 | 33 | def setUp(self): 34 | super().setUp() 35 | np.random.seed(123456) 36 | 37 | @parameterized.parameters( 38 | (3, 1, 'RELU', [2, 16, 36, 72], 72), 39 | (4, 2, 'TANH', [4, 12, 12, 16], 32), 40 | (4, 2, 'RELU', [4, 12, 12, 16], 8), 41 | (5, 1, 'NONE', [4, 12, 12, 16], 8), 42 | ) 43 | def test_vanilla_block(self, kernel_size, stride, activation, input_shape, 44 | output_dim): 45 | input_dim = input_shape[-1] 46 | p = pax_fiddle.Config( 47 | vanillanets.VanillaBlock, 48 | name='vanilla_block', 49 | input_dim=input_dim, 50 | output_dim=output_dim, 51 | kernel_size=kernel_size, 52 | stride=stride, 53 | ) 54 | resnet_layer = instantiate(p) 55 | npy_inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32') 56 | inputs = jnp.asarray(npy_inputs) 57 | 58 | with base_layer.JaxContext.new_context(): 59 | prng_key = jax.random.PRNGKey(seed=123) 60 | initial_vars = resnet_layer.init(prng_key, inputs) 61 | output = resnet_layer.apply(initial_vars, inputs) 62 | 63 | @parameterized.parameters( 64 | ([1, 4, 4, 3], None), 65 | ([1, 4, 4, 3], [1, 2]), 66 | ) 67 | def test_vanilla_net(self, input_shape, spatial_pooling_dims): 68 | p = vanillanets.VanillaNet.HParamsVanillaNet5().set( 69 | name='vanillanet', output_spatial_pooling_params=spatial_pooling_dims) 70 | if spatial_pooling_dims is not None: 71 | p.output_spatial_pooling_params = pax_fiddle.Config( 72 | poolings.GlobalPooling, pooling_dims=spatial_pooling_dims 73 | ) 74 | vanillanet_layer = instantiate(p) 75 | npy_inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32') 76 | inputs = jnp.asarray(npy_inputs) 77 | 78 | with base_layer.JaxContext.new_context(): 79 | prng_key = jax.random.PRNGKey(seed=123) 80 | initial_vars = vanillanet_layer.init(prng_key, inputs) 81 | output = vanillanet_layer.apply(initial_vars, inputs) 82 | 83 | 84 | if __name__ == '__main__': 85 | absltest.main() 86 | -------------------------------------------------------------------------------- /praxis/layers/video/BUILD: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | load("//praxis:build-visibility.bzl", "JAX_VISIBILITY") 17 | 18 | # Description: 19 | # video related layers. The public API is defined in __init__.py. 
20 | load("//praxis:praxis.bzl", "pytype_strict_library", "pytype_strict_test") 21 | 22 | package(default_visibility = JAX_VISIBILITY) 23 | 24 | pytype_strict_library( 25 | name = "enc_dec_3dcnn", 26 | srcs = [ 27 | "enc_dec_3dcnn.py", 28 | ], 29 | deps = [ 30 | # Implicit jax dependency. 31 | # Implicit numpy dependency. 32 | "//praxis:base_layer", 33 | "//praxis:pax_fiddle", 34 | "//praxis:py_utils", 35 | "//praxis:pytypes", 36 | "//praxis/layers:activations", 37 | "//praxis/layers:convolutions", 38 | "//praxis/layers:linears", 39 | "//praxis/layers:normalizations", 40 | ], 41 | ) 42 | 43 | pytype_strict_test( 44 | name = "enc_dec_3dcnn_test", 45 | srcs = [ 46 | "enc_dec_3dcnn_test.py", 47 | ], 48 | deps = [ 49 | ":enc_dec_3dcnn", 50 | # Implicit absl.logging dependency. 51 | # Implicit absl.testing.absltest dependency. 52 | # Implicit jax dependency. 53 | "//praxis:base_layer", 54 | "//praxis:pax_fiddle", 55 | "//praxis:test_utils", 56 | ], 57 | ) 58 | 59 | pytype_strict_library( 60 | name = "losses", 61 | srcs = [ 62 | "losses.py", 63 | ], 64 | deps = [ 65 | # Implicit jax dependency. 66 | "//praxis:base_layer", 67 | "//praxis:base_model", 68 | "//praxis:py_utils", 69 | "//praxis:pytypes", 70 | ], 71 | ) 72 | 73 | pytype_strict_test( 74 | name = "losses_test", 75 | srcs = [ 76 | "losses_test.py", 77 | ], 78 | deps = [ 79 | ":losses", 80 | ":vqvae", 81 | # Implicit absl.testing.absltest dependency. 82 | # Implicit absl.testing.parameterized dependency. 83 | # Implicit jax dependency. 84 | # Implicit numpy dependency. 85 | "//praxis:base_layer", 86 | "//praxis:pax_fiddle", 87 | "//praxis:py_utils", 88 | "//praxis:test_utils", 89 | ], 90 | ) 91 | 92 | pytype_strict_library( 93 | name = "quantizer", 94 | srcs = [ 95 | "quantizer.py", 96 | ], 97 | deps = [ 98 | # Implicit jax dependency. 99 | # Implicit numpy dependency. 100 | "//praxis:base_layer", 101 | "//praxis:py_utils", 102 | "//praxis:pytypes", 103 | ], 104 | ) 105 | 106 | pytype_strict_test( 107 | name = "quantizer_test", 108 | srcs = [ 109 | "quantizer_test.py", 110 | ], 111 | deps = [ 112 | ":quantizer", 113 | # Implicit absl.testing.absltest dependency. 114 | # Implicit absl.testing.parameterized dependency. 115 | # Implicit jax dependency. 116 | "//praxis:base_layer", 117 | "//praxis:pax_fiddle", 118 | "//praxis:py_utils", 119 | "//praxis:test_utils", 120 | ], 121 | ) 122 | 123 | pytype_strict_library( 124 | name = "vqvae", 125 | srcs = [ 126 | "vqvae.py", 127 | ], 128 | deps = [ 129 | ":enc_dec_3dcnn", 130 | ":quantizer", 131 | # Implicit jax dependency. 132 | "//praxis:base_layer", 133 | "//praxis:base_model", 134 | "//praxis:pax_fiddle", 135 | "//praxis:py_utils", 136 | "//praxis:pytypes", 137 | "//praxis/layers:activations", 138 | "//praxis/layers:convolutions", 139 | "//praxis/layers:linears", 140 | ], 141 | ) 142 | 143 | pytype_strict_test( 144 | name = "vqvae_test", 145 | srcs = [ 146 | "vqvae_test.py", 147 | ], 148 | deps = [ 149 | ":quantizer", 150 | ":vqvae", 151 | # Implicit absl.testing.absltest dependency. 152 | # Implicit absl.testing.parameterized dependency. 153 | # Implicit jax dependency. 154 | "//praxis:base_layer", 155 | "//praxis:pax_fiddle", 156 | "//praxis:py_utils", 157 | "//praxis:test_utils", 158 | ], 159 | ) 160 | -------------------------------------------------------------------------------- /praxis/layers/video/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Exposes the public layer functionalities for video.""" 17 | 18 | from praxis.layers.video import enc_dec_3dcnn 19 | from praxis.layers.video import losses 20 | from praxis.layers.video import quantizer 21 | from praxis.layers.video import vqvae 22 | -------------------------------------------------------------------------------- /praxis/layers/video/enc_dec_3dcnn_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl import logging 17 | from absl.testing import absltest 18 | import jax 19 | from jax import numpy as jnp 20 | from praxis import base_layer 21 | from praxis import pax_fiddle 22 | from praxis import test_utils 23 | from praxis.layers.video import enc_dec_3dcnn 24 | 25 | 26 | class EncDec3dcnnTest(test_utils.TestCase): 27 | 28 | def test_depth_to_space(self): 29 | x = jnp.ones((2, 5, 3, 3, 32)) 30 | y = enc_dec_3dcnn.depth_to_space(x, 2, 4) 31 | self.assertEqual(y.shape, (2, 10, 6, 6, 4)) 32 | 33 | def test_gan_res_block(self): 34 | prng_key, _ = jax.random.split(jax.random.PRNGKey(1234)) 35 | pax_x = jax.random.normal(prng_key, (1, 17, 256, 257, 32)) 36 | res_block_p = pax_fiddle.Config( 37 | enc_dec_3dcnn.DiscriminatorResBlock, 38 | name='res_block', 39 | input_dim=32, 40 | output_dim=32, 41 | ) 42 | pax_layer = base_layer.instantiate(res_block_p) 43 | init_vars = pax_layer.init(prng_key, pax_x) 44 | logging.info( 45 | 'init_vars: %s', jax.tree_util.tree_map(lambda x: x.shape, init_vars) 46 | ) 47 | pax_y = pax_layer.apply(init_vars, pax_x) 48 | self.assertEqual(pax_y.shape, (1, 9, 128, 129, 32)) 49 | 50 | def test_res_block(self): 51 | prng_key, _ = jax.random.split(jax.random.PRNGKey(1234)) 52 | pax_x = jax.random.normal(prng_key, (1, 5, 3, 3, 32)) 53 | res_block_p = pax_fiddle.Config( 54 | enc_dec_3dcnn.ResBlock, 55 | name='res_block', 56 | input_dim=32, 57 | output_dim=64, 58 | use_conv_shortcut=True, 59 | ) 60 | pax_layer = base_layer.instantiate(res_block_p) 61 | init_vars = pax_layer.init(prng_key, pax_x) 62 | logging.info( 63 | 'init_vars: %s', jax.tree_util.tree_map(lambda x: x.shape, init_vars) 64 | ) 65 | pax_y = pax_layer.apply(init_vars, pax_x) 66 | self.assertEqual(pax_y.shape, pax_x.shape[:-1] + (64,)) 67 | 68 | 69 | if __name__ == '__main__': 70 | absltest.main() 71 | 
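The shape contract exercised by test_depth_to_space above can be made concrete with a small sketch. The following is not the library's implementation (the exact axis interleaving inside enc_dec_3dcnn.depth_to_space may differ); it is a minimal stand-in, assuming the second argument is a block size applied to the time/height/width axes and the third is the output channel count:

    import jax.numpy as jnp

    def depth_to_space_3d(x, block_size, out_channels):
      # Rearranges channel groups into (time, height, width) blocks:
      # (B, T, H, W, C) -> (B, T*s, H*s, W*s, out_channels), C == out_channels * s**3.
      b, t, h, w, c = x.shape
      s = block_size
      assert c == out_channels * s * s * s
      x = x.reshape(b, t, h, w, s, s, s, out_channels)
      # Interleave each block axis with its matching temporal/spatial axis.
      x = x.transpose(0, 1, 4, 2, 5, 3, 6, 7)
      return x.reshape(b, t * s, h * s, w * s, out_channels)

    x = jnp.ones((2, 5, 3, 3, 32))
    assert depth_to_space_3d(x, 2, 4).shape == (2, 10, 6, 6, 4)  # matches the test above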
-------------------------------------------------------------------------------- /praxis/layers/video/losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Loss functions for vqvae/vqgan models.""" 17 | 18 | from collections.abc import Callable 19 | import jax 20 | import jax.numpy as jnp 21 | from praxis import base_layer 22 | from praxis import base_model 23 | from praxis import py_utils 24 | from praxis import pytypes 25 | 26 | JTensor = pytypes.JTensor 27 | 28 | 29 | def r1_gradient_penalty( 30 | inputs: JTensor, 31 | logits_fn: Callable[[JTensor], JTensor], 32 | grad_penalty_cost: float = 10.0, 33 | ) -> tuple[JTensor, JTensor]: 34 | """Calculates the gradient penalty loss to regularize the discriminator. 35 | 36 | Args: 37 | inputs: A tensor of image inputs. 38 | logits_fn: A function that takes inputs and returns logits. 39 | grad_penalty_cost: scalar weight for the gradient penalty loss. 40 | 41 | Returns: 42 | A tuple of logits and the gradient penalty. 43 | """ 44 | out, vjp_fn = jax.vjp(logits_fn, inputs, has_aux=False) 45 | # Check if jax.value_and_grad is more efficient than jax.vjp at scale. 46 | grad = vjp_fn(jnp.ones_like(out))[0] 47 | flattened_grad = jnp.asarray(grad.reshape((inputs.shape[0], -1)), jnp.float32) 48 | penalty = ( 49 | jnp.mean(jnp.sum(jnp.square(flattened_grad), axis=-1)) * grad_penalty_cost 50 | ) 51 | return out, penalty 52 | 53 | 54 | def _discriminator_loss(logits_real: JTensor, logits_fake: JTensor) -> JTensor: 55 | """Calculates non-saturating discriminator loss.""" 56 | d_loss_real = jax.nn.softplus(-logits_real) 57 | d_loss_fake = jax.nn.softplus(logits_fake) 58 | return jnp.mean(d_loss_real) + jnp.mean(d_loss_fake) 59 | 60 | 61 | def _generator_loss(logits_fake): 62 | """Calculates non-saturating generator loss.""" 63 | return jnp.mean(jax.nn.softplus(-logits_fake)) 64 | 65 | 66 | class VQGANLoss(base_layer.BaseLayer): 67 | """Loss layer for VQGAN.""" 68 | 69 | g_adversarial_loss_weight: float = 0.1 70 | reconstruction_loss_weight: float = 5.0 71 | polyak_decay: float = 0.999 72 | lecam_weight: float = 0.001 73 | 74 | def lecam_loss(self, real_pred: JTensor, fake_pred: JTensor) -> JTensor: 75 | """Calculates the LeCam loss. 76 | 77 | Described in https://arxiv.org/abs/2104.03310 78 | 79 | Args: 80 | real_pred: scalar, predictions for the real samples. 81 | fake_pred: scalar, predictions for the reconstructed (fake) samples. 82 | 83 | Returns: 84 | LeCam regularization loss (scalar).
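Written out, mirroring the EMA variables this layer maintains: lecam = mean(relu(real_pred - ema_fake_pred)**2) + mean(relu(ema_real_pred - fake_pred)**2), where ema_fake_pred and ema_real_pred are Polyak (EMA) averages of the discriminator predictions.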
85 | """ 86 | ema_fake_pred = self.get_var('ema_fake_pred') 87 | ema_real_pred = self.get_var('ema_real_pred') 88 | return jnp.mean( 89 | jnp.power(jax.nn.relu(real_pred - ema_fake_pred), 2) 90 | ) + jnp.mean(jnp.power(jax.nn.relu(ema_real_pred - fake_pred), 2)) 91 | 92 | def setup(self): 93 | """Constructs this jax module and registers variables.""" 94 | decay_factor_hparams = base_layer.WeightHParams( 95 | shape=[], 96 | init=base_layer.WeightInit.Constant(0.0), 97 | dtype=jnp.float32, 98 | collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC], 99 | ) 100 | 101 | self.create_variable('ema_real_pred', decay_factor_hparams, trainable=False) 102 | self.create_variable('ema_fake_pred', decay_factor_hparams, trainable=False) 103 | 104 | def __call__( 105 | self, predictions: base_model.Predictions, input_batch: py_utils.NestedMap 106 | ) -> py_utils.NestedMap: 107 | original_video = input_batch.video 108 | reconstructed = predictions['reconstructed'] 109 | logits_real = predictions['logits_real'] 110 | logits_fake = predictions['logits_fake'] 111 | real_pred = jnp.mean(logits_real) 112 | fake_pred = jnp.mean(logits_fake) 113 | 114 | ema_fake_pred = self.get_var('ema_fake_pred') 115 | ema_real_pred = self.get_var('ema_real_pred') 116 | ema_fake_pred = ( 117 | fake_pred * (1 - self.polyak_decay) + ema_fake_pred * self.polyak_decay 118 | ) 119 | ema_real_pred = ( 120 | real_pred * (1 - self.polyak_decay) + ema_real_pred * self.polyak_decay 121 | ) 122 | self.update_var('ema_fake_pred', ema_fake_pred) 123 | self.update_var('ema_real_pred', ema_real_pred) 124 | 125 | losses = py_utils.NestedMap() 126 | losses.grad_penalty = predictions['r1_gradient_penalty'] 127 | losses.lecam_loss = ( 128 | self.lecam_loss(logits_real, logits_fake) * self.lecam_weight 129 | ) 130 | 131 | losses.d_adversarial_loss = _discriminator_loss(logits_real, logits_fake) 132 | losses.g_adversarial_loss = ( 133 | _generator_loss(logits_fake) * self.g_adversarial_loss_weight 134 | ) 135 | 136 | diff = jnp.asarray(original_video - reconstructed, jnp.float32) 137 | 138 | losses.reconstruction_loss = ( 139 | jnp.mean(jnp.square(diff)) * self.reconstruction_loss_weight 140 | ) 141 | losses.perceptual_loss = jnp.array(0.0, dtype=jnp.float32) 142 | if self.do_eval: 143 | losses.quantizer_loss = jnp.zeros_like(losses.reconstruction_loss) 144 | else: 145 | losses.quantizer_loss = predictions['quantizer_loss'] 146 | losses.d_loss = ( 147 | losses.d_adversarial_loss + losses.grad_penalty + losses.lecam_loss 148 | ) 149 | losses.g_loss = ( 150 | losses.reconstruction_loss 151 | + losses.g_adversarial_loss 152 | + losses.perceptual_loss 153 | + losses.quantizer_loss 154 | ) 155 | return losses 156 | -------------------------------------------------------------------------------- /praxis/layers/video/losses_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import functools 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import jax 21 | import numpy as np 22 | from praxis import base_layer 23 | from praxis import pax_fiddle 24 | from praxis import py_utils 25 | from praxis import test_utils 26 | from praxis.layers.video import losses 27 | from praxis.layers.video import vqvae 28 | 29 | 30 | class LossesTest(test_utils.TestCase): 31 | 32 | def test_r1_gradient_penalty(self): 33 | prng_key = jax.random.PRNGKey(seed=123) 34 | x = jax.random.normal(prng_key, (2, 5, 16, 16, 3)) 35 | # Create a pax layer and get the output from the random input. 36 | p = pax_fiddle.Config( 37 | vqvae.Discriminator, 38 | name='magvit', 39 | num_frames=5, 40 | image_height=16, 41 | image_width=16, 42 | filters=32, 43 | channel_multipliers=(2, 4), 44 | ) 45 | context_p = base_layer.JaxContext.HParams(do_eval=False) 46 | with base_layer.JaxContext.new_context(hparams=context_p): 47 | pax_layer = base_layer.instantiate(p) 48 | pax_vars = pax_layer.init(prng_key, x) 49 | logit_fn = functools.partial(pax_layer.apply, pax_vars) 50 | logits, penalty = losses.r1_gradient_penalty(x, logit_fn) 51 | self.assertEqual(logits.shape, (2, 1)) 52 | self.assertEqual(penalty.shape, ()) 53 | 54 | @parameterized.parameters(True, False) 55 | def test_vqgan_loss(self, do_eval): 56 | batch_size, num_frames, height, width, channels = 2, 5, 128, 128, 3 57 | video_shape = (batch_size, num_frames, height, width, channels) 58 | np.random.seed(12345) 59 | input_batch = py_utils.NestedMap( 60 | video=np.random.randint(0, 255, size=video_shape) 61 | ) 62 | predictions = py_utils.NestedMap( 63 | reconstructed=np.random.normal(size=video_shape), 64 | logits_real=np.random.normal(size=(batch_size, 1)), 65 | logits_fake=np.random.normal(size=(batch_size, 1)), 66 | quantizer_loss=np.random.normal(size=[]), 67 | r1_gradient_penalty=np.random.normal(size=[]), 68 | ) 69 | 70 | loss_p = pax_fiddle.Config( 71 | losses.VQGANLoss, 72 | name='loss', 73 | ) 74 | loss_layer = loss_p.Instantiate() 75 | prng_key = jax.random.PRNGKey(seed=123) 76 | context_p = base_layer.JaxContext.HParams(do_eval=do_eval) 77 | with base_layer.JaxContext.new_context(hparams=context_p): 78 | init_vars = loss_layer.init(prng_key, predictions, input_batch) 79 | loss_dict, updated_vars = loss_layer.apply( 80 | init_vars, predictions, input_batch, mutable=base_layer.NON_TRAINABLE 81 | ) 82 | for loss in loss_dict.values(): 83 | self.assertEqual((), loss.shape) 84 | self.assertNotEqual( 85 | updated_vars[base_layer.NON_TRAINABLE]['ema_fake_pred'], 0.0 86 | ) 87 | self.assertNotEqual( 88 | updated_vars[base_layer.NON_TRAINABLE]['ema_real_pred'], 0.0 89 | ) 90 | 91 | 92 | if __name__ == '__main__': 93 | absltest.main() 94 | -------------------------------------------------------------------------------- /praxis/layers/video/quantizer_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import jax 19 | from praxis import base_layer 20 | from praxis import pax_fiddle 21 | from praxis import test_utils 22 | from praxis.layers.video import quantizer 23 | 24 | jax.config.update("jax_threefry_partitionable", False) 25 | 26 | 27 | class QuantizerTest(test_utils.TestCase): 28 | 29 | @parameterized.parameters( 30 | (8, True, 0, 0, 0), 31 | (8, False, -0.331989, 0.10211816, -0.43410715), 32 | (16, True, 0, 0, 0), 33 | (16, False, -0.3675179, 0.1005669, -0.46808478), 34 | ) 35 | def test_encode_decode_id_and_loss( 36 | self, embedding_dim, do_eval, quantizer_loss, e_latent_loss, entropy_loss 37 | ): 38 | prng_key = jax.random.PRNGKey(seed=123) 39 | x = jax.random.normal(prng_key, (1, 3, 6, 6, embedding_dim)) 40 | config = pax_fiddle.Config( 41 | quantizer.LookupFreeQuantizer, 42 | name="quantizer", 43 | embedding_dim=embedding_dim, 44 | ) 45 | context_p = base_layer.JaxContext.HParams(do_eval=do_eval) 46 | with base_layer.JaxContext.new_context(hparams=context_p): 47 | lookup_free_quantizer = base_layer.instantiate(config) 48 | pax_vars = lookup_free_quantizer.init(prng_key, x) 49 | self.assertEmpty(pax_vars) 50 | value, result_dict = lookup_free_quantizer.apply(pax_vars, x) 51 | self.assertEqual(value.shape, x.shape) 52 | self.assertAllClose(x, result_dict.raw) 53 | 54 | if not do_eval: 55 | ids = result_dict.encoding_indices 56 | decoded = lookup_free_quantizer.decode_ids(ids) 57 | self.assertAllClose(decoded, result_dict.encodings) 58 | self.assertEqual(result_dict.quantizer_loss.shape, ()) 59 | self.assertEqual(result_dict.e_latent_loss.shape, ()) 60 | self.assertEqual(result_dict.entropy_loss.shape, ()) 61 | self.assertAllClose(result_dict.quantizer_loss, quantizer_loss) 62 | self.assertAllClose(result_dict.e_latent_loss, e_latent_loss) 63 | self.assertAllClose(result_dict.entropy_loss, entropy_loss) 64 | 65 | 66 | if __name__ == "__main__": 67 | absltest.main() 68 | -------------------------------------------------------------------------------- /praxis/layers/video/vqvae_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | from absl.testing import absltest 17 | import jax 18 | from praxis import base_layer 19 | from praxis import pax_fiddle 20 | from praxis import test_utils 21 | from praxis.layers.video import quantizer 22 | from praxis.layers.video import vqvae 23 | 24 | jax.config.update("jax_threefry_partitionable", False) 25 | 26 | 27 | class VQVAETest(test_utils.TestCase): 28 | 29 | def test_auto_encoder(self): 30 | prng_key = jax.random.PRNGKey(seed=123) 31 | x = jax.random.normal(prng_key, (1, 5, 16, 16, 3)) 32 | config = pax_fiddle.Config( 33 | vqvae.VQVaeModel, 34 | name="vqvae", 35 | encoder_tpl=pax_fiddle.Config( 36 | vqvae.Encoder, 37 | name="encoder", 38 | filters=128, 39 | input_dim=3, 40 | embedding_dim=8, 41 | num_res_blocks=4, 42 | temporal_downsample=(False, True, True), 43 | channel_multipliers=(1, 2, 2, 4), 44 | ), 45 | decoder_tpl=pax_fiddle.Config( 46 | vqvae.Decoder, 47 | name="decoder", 48 | filters=128, 49 | embedding_dim=8, 50 | output_dim=3, 51 | num_res_blocks=4, 52 | temporal_downsample=(False, True, True), 53 | channel_multipliers=(1, 2, 2, 4), 54 | ), 55 | quantizer_tpl=pax_fiddle.Config( 56 | quantizer.LookupFreeQuantizer, 57 | name="quantizer", 58 | embedding_dim=8, 59 | ), 60 | ) 61 | vqvae_model = base_layer.instantiate(config) 62 | pax_vars = vqvae_model.init(prng_key, x) 63 | value = vqvae_model.apply(pax_vars, x) 64 | self.assertIsInstance(value, tuple) 65 | self.assertLen(value, 2) 66 | self.assertEqual(value[0].shape, x.shape) 67 | self.assertAllClose(value[1].quantizer_loss, -0.09399976) 68 | self.assertAllClose(value[1].e_latent_loss, 0.07907465) 69 | self.assertAllClose(value[1].entropy_loss, -0.17307441) 70 | 71 | def test_discriminator(self): 72 | prng_key = jax.random.PRNGKey(seed=123) 73 | x = jax.random.normal(prng_key, (1, 5, 16, 16, 3)) 74 | config = pax_fiddle.Config( 75 | vqvae.Discriminator, 76 | name="discriminator", 77 | filters=128, 78 | num_frames=5, 79 | image_height=16, 80 | image_width=16, 81 | input_dim=3, 82 | blur_filter_size=3, 83 | channel_multipliers=(2, 4, 4, 4), 84 | ) 85 | discriminator = base_layer.instantiate(config) 86 | pax_vars = discriminator.init(prng_key, x) 87 | value = discriminator.apply(pax_vars, x) 88 | self.assertAllClose(value, -4.26183e-05) 89 | 90 | 91 | if __name__ == "__main__": 92 | absltest.main() 93 | -------------------------------------------------------------------------------- /praxis/lazy_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A LazyLoader class.""" 17 | 18 | import importlib 19 | import types 20 | from absl import logging 21 | 22 | 23 | class LazyLoader(types.ModuleType): 24 | """Lazily import a module, mainly to avoid pulling in large dependencies. 
25 | 26 | `contrib` and `ffmpeg` are examples of modules that are large and not always 27 | needed, and this allows them to only be loaded when they are used. 28 | """ 29 | 30 | # The lint error here is incorrect. 31 | def __init__(self, local_name, parent_module_globals, name, warning=None): 32 | self._local_name = local_name 33 | self._parent_module_globals = parent_module_globals 34 | self._warning = warning 35 | 36 | # These members allow doctest to correctly process this module member 37 | # without triggering self._load(). self._load() mutates 38 | # parent_module_globals and triggers a 'dict mutated during iteration' error from doctest.py. 39 | # - for from_module() 40 | self.__module__ = name.rsplit(".", 1)[0] 41 | # - for is_routine() 42 | self.__wrapped__ = None 43 | 44 | super(LazyLoader, self).__init__(name) 45 | 46 | def _load(self): 47 | """Load the module and insert it into the parent's globals.""" 48 | # Import the target module and insert it into the parent's namespace 49 | module = importlib.import_module(self.__name__) 50 | self._parent_module_globals[self._local_name] = module 51 | 52 | # Emit a warning if one was specified 53 | if self._warning: 54 | logging.warning(self._warning) 55 | # Make sure to only warn once. 56 | self._warning = None 57 | 58 | # Update this object's dict so that if someone keeps a reference to the 59 | # LazyLoader, lookups are efficient (__getattr__ is only called on lookups 60 | # that fail). 61 | self.__dict__.update(module.__dict__) 62 | 63 | return module 64 | 65 | def __getattr__(self, item): 66 | module = self._load() 67 | return getattr(module, item) 68 | 69 | def __repr__(self): 70 | # Careful not to trigger _load, since repr may be called in very 71 | # sensitive places. 72 | return f"<LazyLoader {self.__name__}>" 73 | 74 | def __dir__(self): 75 | module = self._load() 76 | return dir(module) 77 | -------------------------------------------------------------------------------- /praxis/lingvo_lib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Exposes public Lingvo symbols used by Praxis. 17 | 18 | Lingvo symbols must *not* be imported elsewhere in Praxis library code (except 19 | in test-related modules). 20 | """ 21 | 22 | from lingvo.core import cluster 23 | from lingvo.core import cluster_factory 24 | from lingvo.core import hyperparams 25 | from lingvo.core import nested_map 26 | from praxis import lazy_loader 27 | 28 | # datasource is slow to import (because it imports TF), so we do it lazily. 29 | datasource = lazy_loader.LazyLoader( 30 | 'datasource', globals(), 'lingvo.core.datasource' 31 | ) 32 | 33 | # Note: These are only used by LingvoInputAdaptor, and may possibly be moved 34 | # outside of Praxis in the future.
35 | current_cluster = cluster_factory.Current 36 | infeed_context_scope = cluster.InfeedContextScope 37 | 38 | # Core data-structure for aggregating tensors. 39 | NestedMap = nested_map.NestedMap 40 | 41 | # Note: HParams-related classes. This may possibly be removed post-Fiddle 42 | # migration. 43 | InstantiableParams = hyperparams.InstantiableParams 44 | HParams = hyperparams.Params 45 | -------------------------------------------------------------------------------- /praxis/metric_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utility functions for computing metrics.""" 17 | 18 | import jax 19 | from jax import numpy as jnp 20 | 21 | 22 | def top_k_accuracy( 23 | top_k: int, 24 | logits: jax.Array, 25 | label_ids: jax.Array | None = None, 26 | label_probs: jax.Array | None = None, 27 | weights: jax.Array | None = None, 28 | ) -> jax.Array: 29 | """Computes the top-k accuracy given the logits and labels. 30 | 31 | Args: 32 | top_k: An int scalar, specifying the value of top-k. 33 | logits: A [..., C] float tensor corresponding to the logits. 34 | label_ids: A [...] int vector corresponding to the class labels. At least 35 | one of label_ids and label_probs must be provided. 36 | label_probs: A [..., C] float vector corresponding to the class 37 | probabilities. Must be provided if label_ids is None. 38 | weights: A [...] float vector corresponding to the weight to assign to each 39 | example. 40 | 41 | Returns: 42 | The top-k accuracy represented as a `JTensor`. 43 | 44 | Raises: 45 | ValueError if neither `label_ids` nor `label_probs` is provided. 46 | """ 47 | if label_ids is None and label_probs is None: 48 | raise ValueError("One of label_ids and label_probs should be given.") 49 | if label_ids is None: 50 | label_ids = jnp.argmax(label_probs, axis=-1) 51 | if weights is None: 52 | weights = jnp.ones(logits.shape[:-1]) 53 | 54 | values, _ = jax.lax.top_k(logits, k=top_k) 55 | threshold = jnp.min(values, axis=-1) 56 | 57 | # Reshape logits to [-1, C]. 58 | logits_reshaped = jnp.reshape(logits, [-1, logits.shape[-1]]) 59 | 60 | # Reshape label_ids to [-1, 1]. 61 | label_ids_reshaped = jnp.reshape(label_ids, [-1, 1]) 62 | logits_slice = jnp.take_along_axis( 63 | logits_reshaped, label_ids_reshaped, axis=-1 64 | )[..., 0] 65 | 66 | # Reshape logits_slice back to original shape to be compatible with weights.
67 | logits_slice = jnp.reshape(logits_slice, label_ids.shape) 68 | correct = jnp.greater_equal(logits_slice, threshold) 69 | correct_sum = jnp.sum(correct * weights) 70 | all_sum = jnp.maximum(1.0, jnp.sum(weights)) 71 | return correct_sum / all_sum 72 | -------------------------------------------------------------------------------- /praxis/optimizer_prefix_vectorization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Unit tests for prefix vectorization in optimizers.""" 17 | 18 | import typing 19 | 20 | from absl import logging 21 | from absl.testing import absltest 22 | import jax 23 | from jax import numpy as jnp 24 | import optax 25 | from praxis import base_layer 26 | from praxis import optimizer_prefix_vectorization as opt_vec 27 | from praxis import test_utils 28 | 29 | 30 | class OptimizerPrefixVectorizationTest(test_utils.TestCase): 31 | 32 | def test_vectorized_prefix_with_tree_map_params(self): 33 | def _opt_init(params): 34 | # Reduction over each variable. Behavior will depend on vectorization. 35 | logging.info(f'Init called with params {params}') 36 | return jax.tree.map(jnp.sum, params) 37 | 38 | def _opt_update(updates, state, params): 39 | del params 40 | return jax.tree.map(lambda u, s: u + s, updates, state), state 41 | 42 | grad_tx = optax.GradientTransformationExtraArgs( 43 | init=_opt_init, update=_opt_update 44 | ) 45 | 46 | grads = base_layer.NestedMap( 47 | a=jnp.array([1, 2], dtype=jnp.float32), 48 | b=jnp.array([1, 2], dtype=jnp.float32), 49 | c=jnp.array([[1, 2], [3, 4]], dtype=jnp.float32), 50 | ) 51 | variables = grads.copy() 52 | a_var_param = base_layer.WeightHParams(()) 53 | a_var_param.repeat_prefix = [2] 54 | a_var_param.repeat_prefix_split_dims_mapping = [-1] 55 | b_var_param = base_layer.WeightHParams((2,)) 56 | c_var_param = base_layer.WeightHParams(()) 57 | c_var_param.repeat_prefix = [2, 2] 58 | c_var_param.repeat_prefix_split_dims_mapping = [('data', 'mdl'), None] 59 | var_hparams = base_layer.NestedMap( 60 | a=a_var_param, b=b_var_param, c=c_var_param 61 | ) 62 | 63 | grad_tx = opt_vec.get_transformations_with_vectorized_repeat_prefix( 64 | grad_tx, var_hparams 65 | ) 66 | 67 | state = grad_tx.init(variables) 68 | logging.info(state) 69 | opt_states_pspec = opt_vec.partition_params(grad_tx, var_hparams, state) 70 | logging.info('opt_states_pspec=%s', opt_states_pspec) 71 | # Computed update is 0 + state, and state is sum of each variable. 72 | update, _ = grad_tx.update( 73 | jax.tree.map(jnp.zeros_like, variables), state, variables 74 | ) 75 | # Variables a and c are scalars excluding the prefix, so the update must be 76 | # equal to the initial variable values. 
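# (With vectorized prefixes, _opt_init's jnp.sum runs once per repeat-prefix
# slice; summing a scalar slice returns the slice itself, so the state equals
# the variable and a zero update plus the state reproduces it.)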
77 | update = typing.cast(base_layer.NestedMap, update) 78 | variables = typing.cast(base_layer.NestedMap, variables) 79 | self.assertAllClose(update.a, variables.a) 80 | self.assertAllClose(update.c, variables.c) 81 | # b is not vectorized, so the update equals the sum reduction of the initial 82 | # variable value. 83 | self.assertAllClose( 84 | update.b, jnp.zeros_like(variables.b) + jnp.sum(variables.b) 85 | ) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /praxis/pip_package/build_pip_pkg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" 6 | 7 | export PYTHON_VERSION="${PYTHON_VERSION:-3}" 8 | export PYTHON_MINOR_VERSION="${PYTHON_MINOR_VERSION}" 9 | export DEST="${WHEEL_FOLDER:-/tmp/wheels}" 10 | 11 | if [[ -z "${PYTHON_MINOR_VERSION}" ]]; then 12 | PYTHON="python${PYTHON_VERSION}" 13 | else 14 | PYTHON="python${PYTHON_VERSION}.${PYTHON_MINOR_VERSION}" 15 | fi 16 | 17 | function main() { 18 | ${PYTHON} setup.py bdist_wheel 19 | 20 | if [ ! -d "${DEST}" ]; then 21 | mkdir -p "${DEST}" 22 | fi 23 | 24 | cp dist/*.whl "${DEST}" 25 | echo $(date) : "=== Output wheel file is in: ${DEST}" 26 | } 27 | 28 | main "$@" 29 | -------------------------------------------------------------------------------- /praxis/pip_package/cloudbuild-postsubmit.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '--build-arg', 'image_name=${_IMAGE_NAME}', 6 | '-f', 'praxis/pip_package/postsubmit.Dockerfile', '.' 7 | ] 8 | 9 | substitutions: 10 | _PYTHON_VERSION: '3.10' 11 | _RELEASE_VERSION: 'nightly' # or rX.Y 12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 13 | options: 14 | dynamic_substitutions: true 15 | substitution_option: 'ALLOW_LOOSE' 16 | machineType: E2_HIGHCPU_32 17 | timeout: 1800s 18 | -------------------------------------------------------------------------------- /praxis/pip_package/cloudbuild-presubmit.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '--build-arg', 'image_name=${_IMAGE_NAME}', 6 | '-f', 'praxis/pip_package/presubmit.Dockerfile', '.' 
7 | ] 8 | 9 | substitutions: 10 | _PYTHON_VERSION: '3.10' 11 | _RELEASE_VERSION: 'nightly' # or rX.Y 12 | _IMAGE_NAME: 'paxml_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 13 | options: 14 | dynamic_substitutions: true 15 | substitution_option: 'ALLOW_LOOSE' 16 | machineType: E2_HIGHCPU_32 17 | timeout: 1800s 18 | -------------------------------------------------------------------------------- /praxis/pip_package/cloudbuild-release.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'gcr.io/cloud-builders/docker' 3 | args: [ 4 | 'build', 5 | '-t', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}', 6 | '-f', 'praxis/pip_package/release.Dockerfile', '.', 7 | '--build-arg', 'wheel_folder=${_WHEEL_FOLDER}', 8 | ] 9 | timeout: 3600s 10 | - name: 'gcr.io/cloud-builders/docker' 11 | args: ['push', '--all-tags', 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}'] 12 | timeout: 1800s 13 | - name: 'gcr.io/${PROJECT_ID}/${_IMAGE_NAME}' 14 | entrypoint: 'bash' 15 | args: ['-c', 'mv ${_WHEEL_FOLDER}/*.whl .'] 16 | 17 | substitutions: 18 | _PYTHON_VERSION: '3.10' 19 | _RELEASE_VERSION: '1.4.0' # or rX.Y 20 | _IMAGE_NAME: 'praxis_${_RELEASE_VERSION}_${_PYTHON_VERSION}' 21 | _WHEEL_FOLDER: '/tmp/wheels' 22 | options: 23 | dynamic_substitutions: true 24 | substitution_option: 'ALLOW_LOOSE' 25 | machineType: E2_HIGHCPU_8 26 | timeout: 5400s 27 | artifacts: 28 | objects: 29 | location: 'gs://pax-on-cloud-tpu-project/wheels/$(date -u +%Y%m%d)-praxis-${_RELEASE_VERSION}' 30 | paths: ['/**/*.whl'] 31 | -------------------------------------------------------------------------------- /praxis/pip_package/postsubmit.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG image_name 2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest" 3 | FROM $base_image 4 | 5 | RUN rm -rf /praxis 6 | COPY . /praxis 7 | RUN pip3 uninstall -y fiddle 8 | RUN pip3 uninstall -y flax 9 | RUN pip3 uninstall -y jax 10 | RUN pip3 install --no-deps -r /praxis/praxis/pip_package/requirements.txt 11 | 12 | RUN cd /praxis && bazel build ... 13 | 14 | RUN cd /praxis && \ 15 | bazel test \ 16 | --test_output=all \ 17 | --test_verbose_timeout_warnings \ 18 | -- \ 19 | praxis/... 
\ 20 | -praxis/layers:attentions_test \ 21 | -praxis/layers:convolutions_test \ 22 | -praxis/layers:ctc_objectives_test \ 23 | -praxis/layers:embedding_softmax_test \ 24 | -praxis/layers:models_test \ 25 | -praxis/layers:ngrammer_test \ 26 | -praxis/layers:normalizations_test \ 27 | -praxis/layers:rnn_cell_test \ 28 | -praxis/layers:transformer_models_encoder_decoder_test \ 29 | -praxis/layers:transformer_models_test \ 30 | -praxis/layers:transformers_test 31 | 32 | WORKDIR / 33 | 34 | CMD ["/bin/bash"] 35 | -------------------------------------------------------------------------------- /praxis/pip_package/prepare_release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script prepares a new release by: 4 | # 1) updating the version number in setup.py and cloudbuild-release.yaml 5 | # 2) adding a new section in RELEASE.md with the version and corresponding commit 6 | 7 | set -e -x 8 | 9 | function print_help_and_exit { 10 | echo "Usage: prepare_release.sh -x <praxis_version> -d <build_date>" 11 | echo "e.g.: bash prepare_release.sh -x 0.2.0 -d 20221114" 12 | exit 0 13 | } 14 | 15 | while getopts "hd:x:" opt; do 16 | case $opt in 17 | x) 18 | PRAXIS_VERSION=${OPTARG} 19 | ;; 20 | d) 21 | BUILD_DATE=${OPTARG} 22 | ;; 23 | *) 24 | print_help_and_exit 25 | ;; 26 | esac 27 | done 28 | 29 | RELEASE_NOTE="../RELEASE.md" 30 | RELEASE_NOTE_NEW="release_new.md" 31 | 32 | if [[ -z "$BUILD_DATE" ]]; then 33 | echo "Build date is required!" 34 | exit 1 35 | fi 36 | 37 | if [[ -z "$PRAXIS_VERSION" ]]; then 38 | echo "praxis version is required!" 39 | exit 1 40 | fi 41 | 42 | echo "Build date: "$BUILD_DATE 43 | echo "PRAXIS version: "$PRAXIS_VERSION 44 | 45 | sed -i "s/version='[0-9.]*'/version='$PRAXIS_VERSION'/" setup.py 46 | sed -i "s/_RELEASE_VERSION: '[0-9.]*'/_RELEASE_VERSION: '$PRAXIS_VERSION'/" cloudbuild-release.yaml 47 | gsutil cp gs://pax-on-cloud-tpu-project/wheels/"$BUILD_DATE"/praxis_commit.txt ./ 48 | PRAXIS_COMMIT=$(<praxis_commit.txt) 52 | echo "# Version: $PRAXIS_VERSION" > $RELEASE_NOTE_NEW 53 | echo "## Major Features and Improvements" >> $RELEASE_NOTE_NEW 54 | echo "## Breaking changes" >> $RELEASE_NOTE_NEW 55 | echo "## Deprecations" >> $RELEASE_NOTE_NEW 56 | echo "## Note" >> $RELEASE_NOTE_NEW 57 | echo "* Version: $PRAXIS_VERSION" >> $RELEASE_NOTE_NEW 58 | echo "* Build Date: $BUILD_DATE" >> $RELEASE_NOTE_NEW 59 | echo "* Praxis commit: $PRAXIS_COMMIT" >> $RELEASE_NOTE_NEW 60 | RELEASE_NOTE_TMP="RELEASE.tmp.md" 61 | cat $RELEASE_NOTE_NEW $RELEASE_NOTE >> $RELEASE_NOTE_TMP 62 | rm $RELEASE_NOTE_NEW 63 | rm $RELEASE_NOTE 64 | mv $RELEASE_NOTE_TMP $RELEASE_NOTE 65 | -------------------------------------------------------------------------------- /praxis/pip_package/presubmit.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG image_name 2 | ARG base_image="gcr.io/pax-on-cloud-project/${image_name}:latest" 3 | FROM $base_image 4 | 5 | RUN rm -rf /praxis 6 | COPY . /praxis 7 | RUN pip3 uninstall -y fiddle 8 | RUN pip3 uninstall -y flax 9 | RUN pip3 uninstall -y jax 10 | RUN pip3 install --no-deps -r /praxis/praxis/pip_package/requirements.txt 11 | RUN cd /praxis && bazel build ... 12 | # TODO: enable -praxis/layers:normalizations_test once the new Lingvo pip package is released 13 | # RUN cd /praxis && bazel test --test_output=all --test_verbose_timeout_warnings -- praxis/...
-praxis/layers:transformer_models_test -praxis/layers:ngrammer_test -praxis/layers:attentions_test -praxis/layers:transformers_test -praxis/layers:models_test -praxis/layers:convolutions_test -praxis/layers:embedding_softmax_test -praxis/layers:ctc_objectives_test -praxis/layers:normalizations_test 14 | # RUN cd /praxis && bazel test --test_output=all --test_verbose_timeout_warnings -- praxis:asserts_test praxis:base_hyperparams_test 15 | RUN cd /praxis && \ 16 | bazel test \ 17 | --test_output=all \ 18 | --test_verbose_timeout_warnings \ 19 | -- \ 20 | praxis/... \ 21 | -praxis/layers:attentions_test \ 22 | -praxis/layers:convolutions_test \ 23 | -praxis/layers:ctc_objectives_test \ 24 | -praxis/layers:embedding_softmax_test \ 25 | -praxis/layers:models_test \ 26 | -praxis/layers:ngrammer_test \ 27 | -praxis/layers:normalizations_test \ 28 | -praxis/layers:rnn_cell_test \ 29 | -praxis/layers:transformer_models_encoder_decoder_test \ 30 | -praxis/layers:transformer_models_test \ 31 | -praxis/layers:transformers_test 32 | 33 | 34 | WORKDIR / 35 | 36 | CMD ["/bin/bash"] 37 | -------------------------------------------------------------------------------- /praxis/pip_package/release.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG cpu_base_image="ubuntu:22.04" 2 | ARG base_image=$cpu_base_image 3 | FROM $base_image 4 | 5 | LABEL maintainer="Pax team <pax-dev@google.com>" 6 | 7 | # Re-declare args because the args declared before FROM can't be used in any 8 | # instruction after a FROM. 9 | ARG cpu_base_image="ubuntu:22.04" 10 | ARG base_image=$cpu_base_image 11 | ARG wheel_folder 12 | ENV WHEEL_FOLDER $wheel_folder 13 | ENV PYTHON_VERSION="3" 14 | ENV PYTHON_MINOR_VERSION="10" 15 | 16 | # Pick up some TF dependencies 17 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends software-properties-common 18 | RUN apt update && DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends \ 19 | build-essential \ 20 | curl \ 21 | git \ 22 | pkg-config \ 23 | rename \ 24 | rsync \ 25 | unzip \ 26 | vim \ 27 | && \ 28 | apt-get clean && \ 29 | rm -rf /var/lib/apt/lists/* 30 | 31 | # Install python 3.10 32 | RUN apt-get update && apt-get install -y \ 33 | python3 python3-dev python3-pip python3-venv && \ 34 | rm -rf /var/lib/apt/lists/* && \ 35 | python3.10 -m pip install pip --upgrade && \ 36 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 0 37 | 38 | # Make python3.10 the default python version 39 | RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 0 40 | 41 | ARG bazel_version=5.1.1 42 | # This is to install bazel, for development purposes. 43 | ENV BAZEL_VERSION ${bazel_version} 44 | RUN mkdir /bazel && \ 45 | cd /bazel && \ 46 | curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ 47 | curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ 48 | chmod +x bazel-*.sh && \ 49 | ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ 50 | cd / && \ 51 | rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh 52 | 53 | COPY . 
/praxis 54 | RUN mkdir $WHEEL_FOLDER 55 | RUN sed -i 's/ @ git.*//g' /praxis/requirements.in 56 | RUN pip3 install -r /praxis/requirements.in 57 | 58 | RUN cd /praxis && bazel build ... 59 | 60 | RUN cd praxis && \ 61 | bazel test \ 62 | --test_output=all \ 63 | --test_verbose_timeout_warnings \ 64 | -- \ 65 | praxis/... \ 66 | -praxis/layers:attentions_test \ 67 | -praxis/layers:convolutions_test \ 68 | -praxis/layers:ctc_objectives_test \ 69 | -praxis/layers:embedding_softmax_test \ 70 | -praxis/layers:models_test \ 71 | -praxis/layers:ngrammer_test \ 72 | -praxis/layers:normalizations_test \ 73 | -praxis/layers:transformer_models_test \ 74 | -praxis/layers:transformers_test 75 | 76 | RUN cd praxis && bash praxis/pip_package/build_pip_pkg.sh 77 | 78 | WORKDIR / 79 | 80 | CMD ["/bin/bash"] 81 | -------------------------------------------------------------------------------- /praxis/praxis.bzl: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implements custom rules for Praxis.""" 17 | 18 | # Placeholder to use until bazel supports pytype_library. 19 | def pytype_strict_library(name, **kwargs): 20 | native.py_library(name = name, **kwargs) 21 | 22 | # Placeholder to use until bazel supports pytype_strict_contrib_test. 23 | def pytype_strict_test(name, **kwargs): 24 | native.py_test(name = name, **kwargs) 25 | -------------------------------------------------------------------------------- /praxis/pytypes_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Praxis pytype tests.""" 17 | 18 | from absl.testing import absltest 19 | from jax import tree_util as jtu 20 | from praxis import pytypes 21 | from praxis import test_utils 22 | 23 | 24 | class TreesTest(test_utils.TestCase): 25 | 26 | def test_nestedmap_paths(self): 27 | """Ensure NestedMap is registered as a pytree_node correctly.""" 28 | tree = pytypes.NestedMap( 29 | a=pytypes.NestedMap(x=0, y=1, zz=2), 30 | b=pytypes.NestedMap(z=1), 31 | ) 32 | dict_tree = {'a': {'x': 0, 'y': 1, 'zz': 2}, 'b': {'z': 1}} 33 | self.assertSequenceEqual( 34 | jtu.tree_leaves_with_path(tree), 35 | [ 36 | ((jtu.DictKey(key='a'), jtu.DictKey(key='x')), 0), 37 | ((jtu.DictKey(key='a'), jtu.DictKey(key='y')), 1), 38 | ((jtu.DictKey(key='a'), jtu.DictKey(key='zz')), 2), 39 | ((jtu.DictKey(key='b'), jtu.DictKey(key='z')), 1), 40 | ], 41 | ) 42 | self.assertSequenceEqual( 43 | jtu.tree_leaves_with_path(dict_tree), jtu.tree_leaves_with_path(tree) 44 | ) 45 | 46 | 47 | if __name__ == '__main__': 48 | absltest.main() 49 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | # To update requirements.txt for praxis and paxml, run: 2 | # cd ../../paxml/pip_package && bash ./compile_requirements.sh 3 | 4 | absl-py 5 | # enforce newer chex version to avoid deprecated jax syntax 6 | chex>=0.1.85 7 | clu @ git+https://github.com/google/CommonLoopUtils 8 | einops 9 | etils 10 | fiddle @ git+https://github.com/google/fiddle 11 | flax @ git+https://github.com/google/flax 12 | jax @ git+https://github.com/google/jax 13 | jax-bitempered-loss 14 | jaxtyping 15 | lingvo 16 | numpy<2 17 | optax 18 | optax-shampoo 19 | opt-einsum 20 | # ensure sentencepiece is compataible with protobuf 3.19.6 21 | sentencepiece==0.1.99 22 | tensorflow-datasets==4.8.3 23 | tensorflow-metadata==1.12.0 24 | tensorflow-text~=2.9.0 25 | tensorflow~=2.9.2 26 | tfds-nightly==4.8.3.dev202303280045 27 | typeguard 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Pax Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Setup.py file for praxis.""" 17 | 18 | import os 19 | from setuptools import find_packages 20 | from setuptools import setup 21 | 22 | # Set this envvar to avoid installing packages from head that can overwrite 23 | # existing installs of those packages, e.g., jax 24 | SKIP_HEAD_INSTALLS = os.environ.get('SKIP_HEAD_INSTALLS', '') 25 | 26 | def _get_requirements(): 27 | """Parses requirements.txt file.""" 28 | install_requires_tmp = [] 29 | with open( 30 | os.path.join(os.path.dirname(__file__), './requirements.in'), 'r' 31 | ) as f: 32 | for line in f: 33 | package_name = line.strip() 34 | # Skip empty line or comments starting with "#". 
35 | if ( 36 | not package_name 37 | or package_name[0] == '#' 38 | or (' @ ' in package_name and SKIP_HEAD_INSTALLS) 39 | ): 40 | continue 41 | else: 42 | install_requires_tmp.append(package_name) 43 | return install_requires_tmp 44 | 45 | 46 | install_requires = _get_requirements() 47 | 48 | setup( 49 | name='praxis', 50 | version='1.4.0', 51 | description=( 52 | 'Functionality such as layers for building neural networks in Jax.' 53 | ), 54 | author='PAX team', 55 | author_email='pax-dev@google.com', 56 | packages=find_packages(), 57 | python_requires='>=3.10', 58 | install_requires=install_requires, 59 | url='https://github.com/google/praxis', 60 | license='Apache-2.0', 61 | classifiers=[ 62 | 'Programming Language :: Python :: 3.10', 63 | 'Programming Language :: Python :: 3.11', 64 | ], 65 | zip_safe=False, 66 | ) 67 | --------------------------------------------------------------------------------
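As a closing illustration of the requirement-filtering rule in _get_requirements above, here is a minimal, self-contained sketch (the requirement lines are made up for the example):

    import os

    os.environ['SKIP_HEAD_INSTALLS'] = '1'
    skip_head_installs = os.environ.get('SKIP_HEAD_INSTALLS', '')

    requirement_lines = [
        '# a comment',
        '',
        'numpy<2',
        'jax @ git+https://github.com/google/jax',
    ]
    kept = []
    for line in requirement_lines:
      package_name = line.strip()
      # Same filter as in setup.py: drop blanks, comments, and (with
      # SKIP_HEAD_INSTALLS set) VCS-pinned "pkg @ ..." entries.
      if (not package_name
          or package_name[0] == '#'
          or (' @ ' in package_name and skip_head_installs)):
        continue
      kept.append(package_name)
    assert kept == ['numpy<2']  # the git-pinned jax entry is skipped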