├── .vscode ├── extensions.json └── settings.json ├── .gitignore ├── precondition ├── datamix_gemma │ ├── README.md │ ├── evals │ │ ├── __init__.py │ │ ├── eval.py │ │ └── crop.py │ ├── tokenizers │ │ ├── __init__.py │ │ └── gemma_tokenizer.py │ ├── dataset_builders │ │ ├── __init__.py │ │ ├── dataset_builder.py │ │ ├── preprocessed_dolly_dataset_builder.py │ │ ├── preprocessed_metamath_dataset_builder.py │ │ ├── preprocessed_open_orca_dataset_builder.py │ │ ├── preprocessed_codealpaca_dataset_builder.py │ │ ├── preprocessed_orca_math_dataset_builder.py │ │ ├── preprocessed_gsm8k_dataset_builder.py │ │ ├── preprocessed_wikipedia_dataset_builder.py │ │ ├── preprocessed_sciq_dataset_builder.py │ │ ├── mbpp_dataset_builder.py │ │ ├── mtnt_dataset_builder.py │ │ ├── orca_math_dataset_builder.py │ │ ├── open_orca_dataset_builder.py │ │ └── gsm8k_dataset_builder.py │ ├── training_batch_generators │ │ ├── __init__.py │ │ ├── training_batch_generator.py │ │ ├── vanilla_training_batch_generator.py │ │ ├── importance_weighting_training_batch_generator.py │ │ ├── fixed_dataset_importance_weighting_training_batch_generator.py │ │ ├── dartboard_importance_weighting_training_batch_generator.py │ │ └── dartboard_deterministic_training_batch_generator.py │ ├── cross_compile.py │ ├── finetune.py │ ├── random_baseline.py │ ├── finetuning_experiment.py │ ├── confusion_matrix_calc.py │ ├── finetune_eval_measurement.py │ ├── snr_calculation.py │ ├── Wikipedia_processing.ipynb │ └── deterministic_strategy_bandit_loop.py ├── tearfree │ ├── reallocation_test_data │ │ └── gnn_realloc.json │ ├── reallocation_test.py │ ├── praxis_shim.py │ ├── second_order.py │ ├── optimizer.py │ ├── reshaper_test.py │ ├── reshaper.py │ ├── momentum.py │ ├── optimizer_smoke_test.py │ ├── momentum_test.py │ └── optimizer_test.py ├── __init__.py ├── sm3_test.py ├── oco │ ├── datasets.py │ └── train.py ├── quantization_utils.py └── sm3.py ├── CHANGELOG.md ├── CONTRIBUTING.md ├── .github └── workflows │ └── pytest_and_autopublish.yml ├── pyproject.toml └── README.md /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.black-formatter" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | _pycache__/ 6 | .cache/ 7 | 8 | # Poetry, setuptools, PyPI distribution artifacts. 9 | /*.egg-info 10 | .eggs/ 11 | build/ 12 | dist/ 13 | poetry.lock 14 | 15 | # Tests 16 | .pytest_cache/ 17 | 18 | # Type checking 19 | .pytype/ 20 | 21 | # Other 22 | *.DS_Store 23 | 24 | # PyCharm 25 | .idea 26 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/README.md: -------------------------------------------------------------------------------- 1 | # `Datamix`: Data mixture selection for LLM fine-tuning 2 | 3 | Authors: Shivam Gupta, Vlad Feinberg, Xinyi Chen, Elad Hazan, Peter Bartlett 4 | 5 | This project studies data mixture selection strategies for LLM fine-tuning. The goal is to learn a distribution over k datasets, such that SFT using this distribution maximizes performance on a downstream evaluation. 
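As a minimal illustration of the setup (not the exact algorithm used by the strategies in `training_batch_generators/` or `deterministic_strategy_bandit_loop.py`), the sketch below samples each training example from the current mixture over the k datasets and then nudges the mixture with an exponentiated-gradient step driven by a per-dataset reward signal. The reward values, learning rate, and helper names are hypothetical placeholders.

```python
import numpy as np


def sample_dataset_indices(weights, batch_size, rng):
  """Pick, for each example in a batch, which of the k datasets it comes from."""
  return rng.choice(len(weights), size=batch_size, p=weights)


def exponentiated_gradient_step(weights, reward, lr=0.1):
  """Multiplicative-weights update; the result stays on the probability simplex."""
  new_weights = weights * np.exp(lr * reward)
  return new_weights / new_weights.sum()


rng = np.random.default_rng(0)
k = 4
weights = np.full(k, 1.0 / k)  # start from the uniform mixture
# Hypothetical per-dataset reward estimates; in the real pipeline these would
# come from fine-tuning on sampled batches and scoring a downstream eval.
reward = np.array([0.1, 0.5, 0.2, 0.8])

for _ in range(50):
  batch_datasets = sample_dataset_indices(weights, batch_size=16, rng=rng)
  # (a training batch would be assembled from `batch_datasets` here)
  weights = exponentiated_gradient_step(weights, reward)

print(weights.round(3))  # mass concentrates on the higher-reward datasets
```

Multiplicative updates keep the mixture weights on the probability simplex without an explicit projection step, which is why this family of updates is a natural fit for learning a distribution over datasets.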
-------------------------------------------------------------------------------- /precondition/tearfree/reallocation_test_data/gnn_realloc.json: -------------------------------------------------------------------------------- 1 | {"Dense_15": {"kernel": [278, 256]}, "Dense_17": {"kernel": [256, 128]}, "Dense_12": {"kernel": [290, 256]}, "Dense_3": {"kernel": [377, 256]}, "Dense_14": {"kernel": [285, 256]}, "Dense_2": {"kernel": [135, 256]}, "Dense_13": {"kernel": [439, 256]}, "Dense_9": {"kernel": [262, 256]}, "Dense_11": {"kernel": [267, 256]}, "Dense_7": {"kernel": [233, 256]}, "Dense_10": {"kernel": [298, 256]}, "Dense_0": {"kernel": [256, 0]}, "Dense_4": {"kernel": [256, 256]}, "Dense_6": {"kernel": [231, 256]}, "Dense_16": {"kernel": [248, 256]}, "Dense_5": {"kernel": [207, 256]}, "Dense_1": {"kernel": [62, 0]}, "Dense_8": {"kernel": [228, 256]}} 2 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/evals/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/training_batch_generators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /precondition/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """precondition API.""" 16 | 17 | # A new PyPI release will be pushed everytime `__version__` is increased 18 | # When changing this, also update the CHANGELOG.md 19 | __version__ = '0.3.0' 20 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.insertFinalNewline": true, 3 | "files.trimFinalNewlines": true, 4 | "files.trimTrailingWhitespace": true, 5 | "files.associations": { 6 | ".pylintrc": "ini" 7 | }, 8 | "python.testing.unittestEnabled": false, 9 | "python.testing.nosetestsEnabled": false, 10 | "python.testing.pytestEnabled": true, 11 | "python.linting.pylintUseMinimalCheckers": false, 12 | "[python]": { 13 | "editor.rulers": [80], 14 | "editor.tabSize": 2, 15 | "editor.defaultFormatter": "ms-python.black-formatter", 16 | "editor.formatOnSave": true, 17 | "editor.detectIndentation": false 18 | }, 19 | "python.formatting.provider": "none", 20 | "black-formatter.path": ["pyink"], 21 | "files.watcherExclude": { 22 | "**/.git/**": true 23 | }, 24 | "files.exclude": { 25 | "**/__pycache__": true, 26 | "**/.pytest_cache": true, 27 | "**/*.egg-info": true 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/evals/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base class for evaluation.""" 16 | 17 | class Eval: 18 | def __init__(self, model, tokenizer, vocab, eval_batch_size): 19 | self.model = model 20 | self.tokenizer = tokenizer 21 | self.vocab = vocab 22 | self.eval_batch_size = eval_batch_size 23 | 24 | def evaluate(self, params): 25 | raise NotImplementedError() 26 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | 23 | 24 | ## [Unreleased] 25 | 26 | ## [0.1.0] - 2022-01-01 27 | 28 | * Initial release 29 | 30 | [Unreleased]: https://github.com/google-research/precondition/compare/v0.1.0...HEAD 31 | [0.1.0]: https://github.com/google-research/precondition/releases/tag/v0.1.0 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/cross_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Cross compile the finetune code."""
16 | 
17 | from collections.abc import Sequence
18 | 
19 | from absl import app
20 | from absl import flags
21 | from jax.mock_backend import mock_backend
22 | from precondition.datamix_gemma import confusion_matrix_calc
23 | 
24 | 
25 | CROSS_COMPILE = flags.DEFINE_boolean('debug', True, 'Run in debug mode')
26 | 
27 | 
28 | def cross_compile():
29 |   mock_backend.use_mock_backend(
30 |       topology='8x8',
31 |       chip_config='default',
32 |   )
33 | 
34 | 
35 | def main(_: Sequence[str]) -> None:
36 |   cross_compile()
37 |   confusion_matrix_calc.confusion_matrix_calc()
38 |   #finetune_utils.finetune()
39 | 
40 | 
41 | if __name__ == '__main__':
42 |   app.run(main)
43 | 
--------------------------------------------------------------------------------
/precondition/datamix_gemma/training_batch_generators/training_batch_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The precondition Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """TrainingBatchGenerator."""
16 | 
17 | class TrainingBatchGenerator:
18 |   """TrainingBatchGenerator."""
19 | 
20 |   def __init__(self, train_ds_builders, batch_size, num_weights=2, num_iterations=100):
21 |     self.train_ds_builders = train_ds_builders
22 |     self.batch_size = batch_size
23 |     self.num_weights = num_weights
24 |     self.num_iterations = num_iterations
25 | 
26 |   #def prepare_for_training(self, weights_1, weights_2):
27 |   #  """Prepare for training."""
28 |   #  raise NotImplementedError()
29 |   def prepare_for_training(self, weights_list, new_unnormalized_weights):
30 |     """Prepare for training."""
31 |     raise NotImplementedError()
32 | 
33 |   def get_next_batch(self, index):
34 |     raise NotImplementedError()
35 | 
--------------------------------------------------------------------------------
/precondition/datamix_gemma/finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The precondition Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
16 | 17 | POOL=gdm 18 | ALLOC=brain-pton 19 | 20 | To run: 21 | xmanager launch experimental/brain_pton/datamix_gemma/finetuning_experiment.py 22 | -- --xm_default_build_flags= --xm_enable_build_isolation=false 23 | --xm_resource_alloc=${POOL}/${ALLOC} 24 | 25 | To cross compile: 26 | """ 27 | 28 | from collections.abc import Sequence 29 | 30 | from absl import flags 31 | from precondition.datamix_gemma import finetune_utils 32 | from python import app 33 | 34 | 35 | # os.environ['KERAS_BACKEND'] = 'jax' 36 | 37 | CROSS_COMPILE = flags.DEFINE_boolean('debug', False, 'Run in debug mode') 38 | 39 | 40 | def main(_: Sequence[str]) -> None: 41 | #confusion_matrix_calc.confusion_matrix_calc() 42 | #snr_calculation.snr_calculation() 43 | finetune_utils.finetune() 44 | #finetune_eval_measurement.finetune_eval_measurement() 45 | 46 | 47 | if __name__ == '__main__': 48 | app.run(main) 49 | -------------------------------------------------------------------------------- /.github/workflows/pytest_and_autopublish.yml: -------------------------------------------------------------------------------- 1 | name: Unittests & Auto-publish 2 | 3 | # Allow to trigger the workflow manually (e.g. when deps changes) 4 | on: [push, workflow_dispatch] 5 | 6 | jobs: 7 | pytest-job: 8 | runs-on: ubuntu-latest 9 | timeout-minutes: 30 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | # Install deps 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: 3.9 22 | cache: pip 23 | cache-dependency-path: '**/pyproject.toml' 24 | 25 | - run: pip --version 26 | - run: pip install -e .[dev] 27 | - run: pip freeze 28 | 29 | # Run tests (in parallel) 30 | - name: Run core tests 31 | run: pytest -vv -n auto 32 | 33 | # Auto-publish when version is increased 34 | publish-job: 35 | # Only try to publish if: 36 | # * Repo is self (prevents running from forks) 37 | # * Branch is `main` 38 | if: | 39 | github.repository == 'google-research/precondition' 40 | && github.ref == 'refs/heads/main' 41 | needs: pytest-job # Only publish after tests are successful 42 | runs-on: ubuntu-latest 43 | permissions: 44 | contents: write 45 | timeout-minutes: 30 46 | environment: pypi publish 47 | steps: 48 | # Publish the package (if local `__version__` > pip version) 49 | - uses: etils-actions/pypi-auto-publish@v1 50 | with: 51 | pypi-token: ${{ secrets.PYPI_API_TOKEN }} 52 | gh-token: ${{ secrets.GITHUB_TOKEN }} 53 | parse-changelog: true 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "precondition-opt" 3 | description = "Preconditioning optimizers." 
4 | readme = "README.md" 5 | requires-python = ">=3.8" 6 | license = {file = "LICENSE"} 7 | authors = [{name = "precondition authors", email="precondition-optimizers@google.com"}] 8 | classifiers = [ 9 | "Programming Language :: Python :: 3", 10 | "Programming Language :: Python :: 3 :: Only", 11 | "License :: OSI Approved :: Apache Software License", 12 | "Intended Audience :: Science/Research", 13 | ] 14 | keywords = [] 15 | 16 | # pip dependencies of the project 17 | dependencies = [ 18 | "chex", 19 | "flax", 20 | "optax", 21 | "jax", 22 | "numpy", 23 | "joblib", 24 | "absl-py", 25 | "scipy", 26 | "scikit-learn", 27 | "pandas", 28 | ] 29 | 30 | # This is set automatically by flit using `precondition.__version__` 31 | dynamic = ["version"] 32 | 33 | [project.urls] 34 | homepage = "https://github.com/google-research/precondition" 35 | repository = "https://github.com/google-research/precondition" 36 | # Other: `documentation`, `changelog` 37 | 38 | [project.optional-dependencies] 39 | # Development deps (unittest, linting, formating,...) 40 | # Installed through `pip install .[dev]` 41 | dev = [ 42 | "absl-py>=0.8.1", 43 | "pytest", 44 | "pytest-xdist", 45 | "pylint>=2.6.0", 46 | "pyink", 47 | ] 48 | 49 | [tool.pyink] 50 | # Formatting configuration to follow Google style-guide 51 | line-length = 80 52 | preview = true 53 | pyink-indentation = 2 54 | pyink-use-majority-quotes = true 55 | 56 | [build-system] 57 | requires = ["flit_core >=3.5,<4"] 58 | build-backend = "flit_core.buildapi" 59 | 60 | [tool.flit.module] 61 | name = "precondition" 62 | -------------------------------------------------------------------------------- /precondition/sm3_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Tests for sm3."""
16 | 
17 | from absl.testing import absltest
18 | import chex
19 | import jax
20 | import jax.numpy as jnp
21 | 
22 | from precondition import sm3
23 | 
24 | 
25 | class SM3Test(chex.TestCase):
26 | 
27 |   def setUp(self):
28 |     super().setUp()
29 |     self.init_params = (
30 |         jnp.array([[0.5, 0.5], [0.5, 0.5]]))
31 |     self.per_step_updates = (jnp.array([[0.1, -0.1], [0.01, 0.01]]))
32 | 
33 |   @chex.all_variants(with_pmap=False)
34 |   def test_sm3_basic(self):
35 |     params = self.init_params
36 | 
37 |     optim = sm3.sm3(0.1, 0.9, 0.999)
38 |     init_fn = self.variant(optim.init)
39 |     transform_fn = self.variant(optim.update)
40 | 
41 |     def _update(unused_batch):
42 |       return transform_fn(self.per_step_updates, state, params)
43 |     state = init_fn(params)
44 |     chex.assert_tree_all_finite(state)
45 |     pmap_fn = jax.pmap(_update, axis_name='batch')
46 | 
47 |     updates, state = pmap_fn(jnp.array([1.0]))
48 |     chex.assert_tree_all_finite((params, updates, state))
49 | 
50 | 
51 | if __name__ == '__main__':
52 |   absltest.main()
53 | 
--------------------------------------------------------------------------------
/precondition/datamix_gemma/random_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The precondition Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Random baseline."""
16 | 
17 | import copy
18 | 
19 | from absl import logging
20 | import jax
21 | import numpy as np
22 | from precondition.datamix_gemma import training_loop
23 | from precondition.datamix_gemma.evals import eval as eval_lib
24 | from precondition.datamix_gemma.training_batch_generators import training_batch_generator
25 | 
26 | 
27 | def random_simplex(n):
28 |   """Return uniformly random vector in the n-simplex."""
29 |   k = np.random.exponential(scale=1.0, size=n)
30 |   return k / np.sum(k)
31 | 
32 | 
33 | def random_baseline(
34 |     eval_obj: eval_lib.Eval,
35 |     train_obj: training_loop.TrainingLoop,
36 |     training_batch_generator_obj: training_batch_generator.TrainingBatchGenerator,
37 |     init_params,
38 |     num_iterations=100,
39 | ):
40 |   """Random baseline."""
41 |   for _ in range(num_iterations):
42 |     random_weights = random_simplex(len(training_batch_generator_obj.train_ds_builders))
43 |     cur_params = copy.deepcopy(init_params)
44 |     cur_params = jax.device_get(cur_params)
45 |     trained_params = train_obj.train_loop(
46 |         params={'params': cur_params}, get_next_batch_fn=training_batch_generator_obj.get_next_batch
47 |     )
48 |     score = eval_obj.evaluate(trained_params['params'])
49 |     logging.info(f'score: {score}')
50 |     for i in range(len(random_weights)):
51 |       logging.info(f'weights_{str(i)}: {random_weights[i]}')
52 | 
--------------------------------------------------------------------------------
/precondition/datamix_gemma/finetuning_experiment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The precondition Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Experiment to finetune GEMMA.""" 16 | 17 | from typing import Sequence 18 | 19 | from absl import flags 20 | from absl import logging 21 | from python import app 22 | import tensorflow as tf 23 | from xmanager import xm 24 | from xmanager import xm_abc 25 | from xmanager.contrib.internal import xm_jax 26 | 27 | 28 | CROSS_COMPILE = flags.DEFINE_boolean('debug', False, 'Run in debug mode') 29 | 30 | 31 | def main(_: Sequence[str]) -> None: 32 | """Launches finetuning experiment.""" 33 | title = 'finetune_gemma' 34 | logging.set_verbosity(logging.INFO) 35 | tf.config.set_soft_device_placement(True) 36 | 37 | executor = xm_abc.Borg( 38 | requirements=xm.JobRequirements( 39 | ), 40 | ) 41 | with xm_abc.create_experiment(experiment_title=title) as experiment: 42 | [executable] = experiment.package([ 43 | xm.bazel_binary( 44 | label='//experimental/brain_pton/datamix_gemma:finetune', 45 | executor_spec=xm_abc.Borg.Spec(), 46 | args=xm_jax.JaxFlags().flags(), 47 | bazel_args=xm_abc.bazel_args.tpu(), 48 | ), 49 | ]) 50 | experiment.add( 51 | xm.Job( 52 | executable, 53 | #args=exe_args, 54 | env_vars={'HF_DATASETS_CACHE': '/tmp/hf_datasets_cache', 55 | 'KERAS_BACKEND': 'jax'}, 56 | executor=executor, 57 | ) 58 | ) 59 | 60 | 61 | if __name__ == '__main__': 62 | app.run(main) 63 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base class for dataset builders.""" 16 | 17 | import chex 18 | import jax 19 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 20 | import tensorflow as tf 21 | 22 | 23 | @chex.dataclass(frozen=True) 24 | class TrainingInput: 25 | # Input tokens provided to model 26 | input_tokens: jax.Array 27 | 28 | # A mask that determines which tokens contribute to the target loss 29 | # calculation 30 | target_mask: jax.Array 31 | 32 | 33 | class DatasetBuilder: 34 | """Base class for dataset builders. 35 | 36 | This class provides the interface for dataset builders. 37 | """ 38 | 39 | def __init__(self, tokenizer: gemma_tokenizer.GemmaTokenizer, 40 | max_seq_len: int): 41 | """Constructor. 42 | 43 | Args: 44 | tokenizer: Gemma tokenizer to use. 
45 | max_seq_len: size of each sequence in a given batch. 46 | """ 47 | self._tokenizer = tokenizer 48 | self._max_seq_len = max_seq_len 49 | 50 | def _pad_up_to_max_len( 51 | self, input_tensor: tf.Tensor, pad_value: int | bool 52 | ) -> tf.Tensor: 53 | """Pads the given tensor up to max_seq_len.""" 54 | seq_len = tf.shape(input_tensor)[0] 55 | to_pad = tf.maximum(0, self._max_seq_len - seq_len) 56 | return tf.pad( 57 | input_tensor, 58 | [[0, to_pad]], 59 | mode='CONSTANT', 60 | constant_values=pad_value 61 | ) 62 | 63 | def get_train_dataset(self, batch_size: int, num_epochs: int): 64 | raise NotImplementedError() 65 | 66 | def get_validation_dataset(self, batch_size: int): 67 | raise NotImplementedError() 68 | -------------------------------------------------------------------------------- /precondition/tearfree/reallocation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Simple test case for memory reallocation function.""" 16 | 17 | import json 18 | import os 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | from jax import numpy as jnp 23 | from precondition.tearfree import reallocation 24 | 25 | 26 | def dict_almost_equal(dict1, dict2, delta=1): 27 | """Helper function.""" 28 | for key, value in dict1.items(): 29 | assert key in dict2, key 30 | if isinstance(value, dict): 31 | dict_almost_equal(value, dict2[key], delta) 32 | else: 33 | for i in range(len(value)): 34 | assert jnp.abs(value[i] - dict2[key][i]) <= delta 35 | 36 | 37 | class ReallocationTest(parameterized.TestCase): 38 | 39 | def test_create_redist_dict(self): 40 | chpt_path = '' 41 | data_dir = os.path.join( 42 | os.path.dirname(__file__), 'reallocation_test_data' 43 | ) 44 | realloc_path = os.path.join(data_dir, 'gnn_realloc.json') 45 | states_path = os.path.join(data_dir, 'states.json') 46 | with open(states_path, 'r') as f: 47 | states = tuple(json.load(f)) 48 | sketches = states[-1]['inner_state']['0']['direction']['1']['sketches'] 49 | for layer in sketches: 50 | tmp = sketches[layer]['kernel']['axes'] 51 | for axes in tmp: 52 | tmp[axes]['eigvals'] = jnp.array( 53 | tmp[axes]['eigvals'], dtype=jnp.float32 54 | ) 55 | states[-1]['inner_state']['0']['direction']['1']['sketches'][layer][ 56 | 'kernel' 57 | ]['axes'] = tmp 58 | realloc_result = reallocation.create_redist_dict( 59 | chpt_path, [-1], 'sketch_trace', False, 256, states 60 | ) 61 | with open(realloc_path, 'r') as f: 62 | realloc_dict = json.load(f) 63 | 64 | dict_almost_equal(realloc_result, realloc_dict, delta=1) 65 | 66 | if __name__ == '__main__': 67 | absltest.main() 68 | 69 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/training_batch_generators/vanilla_training_batch_generator.py: -------------------------------------------------------------------------------- 
1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """VanillaTrainingBatchGenerator.""" 16 | 17 | from absl import logging 18 | import numpy as np 19 | from precondition.datamix_gemma.training_batch_generators import training_batch_generator 20 | import tensorflow_datasets as tfds 21 | 22 | 23 | class VanillaTrainingBatchGenerator( 24 | training_batch_generator.TrainingBatchGenerator 25 | ): 26 | """VanillaTrainingBatchGenerator.""" 27 | 28 | def __init__(self, train_ds_builders, batch_size, num_weights=2, num_iterations=100): 29 | super().__init__(train_ds_builders, batch_size, num_weights, num_iterations) 30 | self.training_iters = [] 31 | for dataset_builder_obj in self.train_ds_builders: 32 | self.training_iters.append( 33 | iter( 34 | tfds.as_numpy( 35 | dataset_builder_obj.get_train_dataset( 36 | batch_size=batch_size, num_epochs=1 37 | ) 38 | ) 39 | ) 40 | ) 41 | self.weights_list = [] 42 | 43 | def prepare_for_training(self, weights_list, new_unnormalized_weights): 44 | """Prepare for training.""" 45 | self.weights_list = weights_list 46 | #gradient discount factor 47 | return 1 48 | 49 | def get_next_batch(self, index): 50 | weights = self.weights_list[index] 51 | input_tokens_batch = [] 52 | input_mask_batch = [] 53 | factors = [] 54 | for _ in range(self.batch_size): 55 | cur_ind = np.random.choice(len(self.training_iters), p=weights) 56 | logging.info(f'cur_ind: {cur_ind}') 57 | try: 58 | cur_example = next(self.training_iters[cur_ind]) 59 | except StopIteration: 60 | self.training_iters[cur_ind] = iter( 61 | tfds.as_numpy( 62 | self.train_ds_builders[cur_ind].get_train_dataset( 63 | batch_size=self.batch_size, num_epochs=1 64 | ) 65 | ) 66 | ) 67 | cur_example = next(self.training_iters[cur_ind]) 68 | input_tokens_batch.append(np.asarray([cur_example.input_tokens])) 69 | input_mask_batch.append(np.asarray([cur_example.target_mask])) 70 | factors.append(1) 71 | return factors, input_tokens_batch, input_mask_batch 72 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/tokenizers/gemma_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # We import JAX and some related packages. 16 | # Finally, we import Gemma. 
17 | import jax
18 | import jax.numpy as jnp
19 | from sentencepiece import sentencepiece_processor as spm
20 | # We will use tensorflow to handle the dataset
21 | import tensorflow as tf
22 | 
23 | 
24 | """Custom tokenizer for Gemma."""
25 | class GemmaTokenizer:
26 |   """Custom wrapper around a SentencePieceProcessor for tensorflow."""
27 | 
28 |   def __init__(self,
29 |                spm_processor: spm.SentencePieceProcessor):
30 |     self._spm_processor = spm_processor
31 | 
32 |   @property
33 |   def pad_id(self) -> int:
34 |     """Fast access to the pad id."""
35 |     return self._spm_processor.pad_id()
36 | 
37 |   def tokenize(self,
38 |                example: str | bytes,
39 |                prefix: str = '',
40 |                suffix: str = '',
41 |                add_eos: bool = True) -> jax.Array:
42 |     """Tokenization function.
43 | 
44 |     Args:
45 |       example: input string to tokenize.
46 |       prefix: prefix to add to the input string.
47 |       suffix: suffix to add to the input string.
48 |       add_eos: if True, add an end of sentence token at the end of the output
49 |         sequence.
50 |     Returns:
51 |       Tokens corresponding to the input string.
52 |     """
53 |     int_list = [self._spm_processor.bos_id()]
54 |     int_list.extend(self._spm_processor.EncodeAsIds(prefix))
55 |     int_list.extend(self._spm_processor.EncodeAsIds(example))
56 |     int_list.extend(self._spm_processor.EncodeAsIds(suffix))
57 |     if add_eos:
58 |       int_list.append(self._spm_processor.eos_id())
59 | 
60 |     #return tf.convert_to_tensor(int_list, dtype=tf.int32)
61 |     return jnp.array(int_list, dtype=jnp.int32)
62 | 
63 |   def tokenize_tf_op(self,
64 |                      str_tensor: tf.Tensor,
65 |                      prefix: str = '',
66 |                      suffix: str = '',
67 |                      add_eos: bool = True) -> tf.Tensor:
68 |     """TensorFlow operator for the tokenize function."""
69 |     encoded = tf.numpy_function(
70 |         self.tokenize,
71 |         [str_tensor, prefix, suffix, add_eos],
72 |         tf.int32)
73 |     encoded.set_shape([None])
74 |     return encoded
75 | 
76 |   # def to_string(self, tokens: jax.Array) -> str:
77 |   #   """Convert an array of tokens to a string."""
78 |   #   return self._spm_processor.EncodeIds(tokens.tolist())
79 | 
--------------------------------------------------------------------------------
/precondition/oco/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The precondition Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | """Module with sparse datasets on CNS.""" 16 | 17 | import dataclasses 18 | import os 19 | from typing import Callable 20 | 21 | from absl import flags 22 | import jax 23 | import jax.numpy as jnp 24 | import joblib 25 | import numpy as np 26 | import sklearn.datasets 27 | 28 | _DATA_DIR = flags.DEFINE_string( 29 | 'data_dir', 30 | None, 31 | 'load data: your directory needs to contain the benchmark datasets' 32 | ' (a9a, a9a.t, cifar10, cifar10.t, gisette_scale, gisette_scale.t,' 33 | ' where .t stands for the testing set)' 34 | ' in libsvm format, available at' 35 | ' https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/', 36 | ) 37 | 38 | SUPPORTED_DATASETS: list[str] = [ 39 | 'a9a', 40 | 'a9a.t', 41 | 'cifar10', 42 | 'cifar10.t', 43 | 'gisette_scale', 44 | 'gisette_scale.t', 45 | ] 46 | 47 | 48 | def _logistic_loss(w: jax.Array, x: jax.Array, y: jax.Array) -> jax.Array: 49 | """Compute logistic loss.""" 50 | # assumes y is binary 51 | pred = jnp.dot(w, x, precision=jax.lax.Precision.HIGHEST) 52 | lse = lambda x: jax.nn.logsumexp(jnp.array(x)) 53 | return y * lse([0, -pred]) + (1 - y) * lse([0, pred]) 54 | 55 | 56 | def incorrect(w: jax.Array, x: jax.Array, y: jax.Array) -> jax.Array: 57 | """Compute binary 0-1 loss.""" 58 | pred = jnp.dot(w, x, precision=jax.lax.Precision.HIGHEST) 59 | return (pred > 0) != (y > 0) 60 | 61 | 62 | Loss = Callable[[jax.Array, jax.Array, jax.Array], jax.Array] 63 | 64 | 65 | @dataclasses.dataclass 66 | class SimpleDataset: 67 | """Simple dense supervised learning dataset for linear learners.""" 68 | 69 | x: np.ndarray 70 | y: np.ndarray 71 | loss: Loss 72 | w_shape: tuple[int, ...] 73 | 74 | 75 | def _load_dataset_uncached(name: str) -> SimpleDataset: 76 | """Generate a dataset with an intercept added.""" 77 | assert name in SUPPORTED_DATASETS, name 78 | if not _DATA_DIR.value: 79 | raise ValueError('must specify directory where datasets are stored') 80 | filename = os.path.join(_DATA_DIR.value, name) 81 | with open(filename, 'rb') as f: 82 | x, y = sklearn.datasets.load_svmlight_file(f) 83 | 84 | x = x.todense() 85 | x = np.concatenate([x, np.ones((len(x), 1))], axis=1) 86 | y = y > 0 87 | return SimpleDataset(x, y, _logistic_loss, (x.shape[1],)) 88 | 89 | 90 | def load_dataset(name: str, cache: str = '/tmp/cache') -> SimpleDataset: 91 | memory = joblib.Memory(cache, verbose=0) 92 | cached_fn = memory.cache(_load_dataset_uncached) 93 | return cached_fn(name) 94 | -------------------------------------------------------------------------------- /precondition/tearfree/praxis_shim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Shim interfaces for praxis, to avoid circular dependencies.""" 16 | 17 | import dataclasses 18 | from typing import Any, NamedTuple, Union 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | import optax 23 | 24 | 25 | @dataclasses.dataclass(frozen=True) 26 | class ShardedGradientTransformation: 27 | """GradientTransformation that supports spmd.""" 28 | 29 | init: optax.TransformInitFn 30 | update: optax.TransformUpdateFn 31 | init_partition_spec: Any 32 | 33 | 34 | NestedHParams = Any 35 | 36 | 37 | class WeightHParams(NamedTuple): 38 | shape: list[int] 39 | init: Any 40 | dtype: jnp.dtype 41 | collections: Any 42 | tensor_split_dims_mapping: list[int] 43 | 44 | 45 | def sharded_chain( 46 | *args: Union[optax.GradientTransformation, ShardedGradientTransformation], 47 | ) -> ShardedGradientTransformation: 48 | """Chain as in praxis.optimizers.sharded_chain.""" 49 | 50 | def init_fn(params): 51 | return tuple(fn.init(params) for fn in args) 52 | 53 | def update_fn(updates, state, params=None): 54 | if len(args) != len(state): 55 | raise ValueError( 56 | 'The number of updates and states has to be the same in ' 57 | f'sharded chain. got {len(args)=}, {len(state)=}' 58 | ) 59 | 60 | new_state = [] 61 | for s, fn in zip(state, args): 62 | updates, new_s = fn.update(updates, s, params) 63 | # Some of the new states may have None instead of optax.MaskedNode. 64 | new_s = jax.tree.map( 65 | lambda x: optax.MaskedNode() if x is None else x, 66 | new_s, 67 | is_leaf=lambda x: x is None, 68 | ) 69 | new_state.append(new_s) 70 | return updates, tuple(new_state) 71 | 72 | def init_partition_spec_fn(mdl_vars): 73 | partition_specs = [] 74 | for fn in args: 75 | init_partition_spec = getattr(fn, 'init_partition_spec', None) 76 | if callable(init_partition_spec): 77 | nmap = init_partition_spec(mdl_vars) 78 | partition_specs.append(nmap) 79 | else: 80 | # Raise ValueError as we are attempting to sharded_chain an optimizer 81 | # that does not have an `init_partition_spec` method defined. 82 | raise ValueError( 83 | 'Attempting to use an optimizer in sharded_chain that ' 84 | 'does not have an init_partition_spec.' 85 | ) 86 | return optax.MaskedState(inner_state=tuple(partition_specs)) 87 | 88 | return ShardedGradientTransformation( 89 | init=init_fn, update=update_fn, init_partition_spec=init_partition_spec_fn 90 | ) 91 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_dolly_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the Dolly dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | dolly_path = '/home/xinyic/dolly/dolly_data.tfrecord' 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class PreprocessedDollyDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the Dolly dataset.""" 33 | 34 | def __init__( 35 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 36 | ): 37 | """Constructor. 38 | 39 | Args: 40 | tokenizer: Gemma tokenizer to use. 41 | max_seq_len: size of each sequence in a given batch. 42 | """ 43 | self._tokenizer = tokenizer 44 | self._base_data = tf.data.TFRecordDataset( 45 | [dolly_path], num_parallel_reads=tf.data.AUTOTUNE 46 | ) 47 | self._max_seq_len = max_seq_len 48 | 49 | def _to_training_input( 50 | self, 51 | input_tokens, 52 | target_mask, 53 | ): 54 | return dataset_builder.TrainingInput( # type: ignore 55 | input_tokens=input_tokens, # type:ignore 56 | target_mask=target_mask, # type:ignore 57 | ) # type: ignore 58 | 59 | 60 | def _decode_fn(self, record_bytes): 61 | parsed_features = tf.io.parse_example( 62 | record_bytes, 63 | { 64 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 65 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 66 | }, 67 | ) 68 | decoded = { 69 | 'input_tokens': tf.io.decode_raw( 70 | parsed_features['input_tokens'], out_type=tf.int32 71 | ), 72 | 'target_mask': tf.io.decode_raw( 73 | parsed_features['target_mask'], out_type=tf.bool 74 | ), 75 | } 76 | return { 77 | 'input_tokens': self._pad_up_to_max_len( 78 | decoded['input_tokens'], self._tokenizer.pad_id 79 | ), 80 | 'target_mask': self._pad_up_to_max_len( 81 | decoded['target_mask'], False 82 | ), 83 | } 84 | 85 | def get_train_dataset(self, batch_size: int, num_epochs: int): 86 | """Build the training dataset.""" 87 | ds = self._base_data.map( 88 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 89 | ) 90 | ds = ds.map( 91 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 92 | num_parallel_calls=tf.data.AUTOTUNE, 93 | ) 94 | return ds 95 | 96 | def get_validation_dataset(self, batch_size: int): 97 | ds = self._base_data.map( 98 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 99 | ) 100 | ds = ds.map( 101 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 102 | num_parallel_calls=tf.data.AUTOTUNE, 103 | ) 104 | return ds 105 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_metamath_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the MetaMath dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | metamath_path = '/home/xinyic/metamath/metamath_data.tfrecord' 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class PreprocessedMetaMathDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the MetaMath dataset.""" 33 | 34 | def __init__( 35 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 36 | ): 37 | """Constructor. 38 | 39 | Args: 40 | tokenizer: Gemma tokenizer to use. 41 | max_seq_len: size of each sequence in a given batch. 42 | """ 43 | self._tokenizer = tokenizer 44 | self._base_data = tf.data.TFRecordDataset( 45 | [metamath_path], num_parallel_reads=tf.data.AUTOTUNE 46 | ) 47 | self._max_seq_len = max_seq_len 48 | 49 | def _to_training_input( 50 | self, 51 | input_tokens, 52 | target_mask, 53 | ): 54 | return dataset_builder.TrainingInput( # type: ignore 55 | input_tokens=input_tokens, # type:ignore 56 | target_mask=target_mask, # type:ignore 57 | ) # type: ignore 58 | 59 | def _decode_fn(self, record_bytes): 60 | parsed_features = tf.io.parse_example( 61 | record_bytes, 62 | { 63 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 64 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 65 | }, 66 | ) 67 | decoded ={ 68 | 'input_tokens': tf.io.decode_raw( 69 | parsed_features['input_tokens'], out_type=tf.int32 70 | ), 71 | 'target_mask': tf.io.decode_raw( 72 | parsed_features['target_mask'], out_type=tf.bool 73 | ), 74 | } 75 | return { 76 | 'input_tokens': self._pad_up_to_max_len( 77 | decoded['input_tokens'], self._tokenizer.pad_id 78 | ), 79 | 'target_mask': self._pad_up_to_max_len( 80 | decoded['target_mask'], False 81 | ), 82 | } 83 | 84 | def get_train_dataset(self, batch_size: int, num_epochs: int): 85 | """Build the training dataset.""" 86 | ds = self._base_data.map( 87 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 88 | ) 89 | ds = ds.map( 90 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 91 | num_parallel_calls=tf.data.AUTOTUNE, 92 | ) 93 | return ds 94 | 95 | def get_validation_dataset(self, batch_size: int): 96 | ds = self._base_data.map( 97 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 98 | ) 99 | ds = ds.map( 100 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 101 | num_parallel_calls=tf.data.AUTOTUNE, 102 | ) 103 | return ds 104 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_open_orca_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the Open Orca dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | open_orca_path = '/home/shivguptashi/open_orca/open_orca_data.tfrecord' 25 | 26 | class DatasetSplit(Enum.Enum): 27 | TRAIN = 'train' 28 | 29 | 30 | class PreprocessedOpenOrcaDatasetBuilder(dataset_builder.DatasetBuilder): 31 | """Dataset builder for the Open Orca dataset.""" 32 | 33 | def __init__( 34 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 35 | ): 36 | """Constructor. 37 | 38 | Args: 39 | tokenizer: Gemma tokenizer to use. 40 | max_seq_len: size of each sequence in a given batch. 41 | """ 42 | self._tokenizer = tokenizer 43 | self._base_data = tf.data.TFRecordDataset( 44 | [open_orca_path], num_parallel_reads=tf.data.AUTOTUNE 45 | ) 46 | self._max_seq_len = max_seq_len 47 | 48 | def _to_training_input( 49 | self, 50 | input_tokens, 51 | target_mask, 52 | ): 53 | return dataset_builder.TrainingInput( # type: ignore 54 | input_tokens=input_tokens, # type:ignore 55 | target_mask=target_mask, # type:ignore 56 | ) # type: ignore 57 | 58 | def _decode_fn(self, record_bytes): 59 | parsed_features = tf.io.parse_example( 60 | record_bytes, 61 | { 62 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 63 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 64 | }, 65 | ) 66 | decoded = { 67 | 'input_tokens': tf.io.decode_raw( 68 | parsed_features['input_tokens'], out_type=tf.int32 69 | ), 70 | 'target_mask': tf.io.decode_raw( 71 | parsed_features['target_mask'], out_type=tf.bool 72 | ), 73 | } 74 | return { 75 | 'input_tokens': self._pad_up_to_max_len( 76 | decoded['input_tokens'], self._tokenizer.pad_id 77 | ), 78 | 'target_mask': self._pad_up_to_max_len( 79 | decoded['target_mask'], False 80 | ), 81 | } 82 | 83 | def get_train_dataset(self, batch_size: int, num_epochs: int): 84 | """Build the training dataset.""" 85 | ds = self._base_data.map( 86 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 87 | ) 88 | ds = ds.map( 89 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 90 | num_parallel_calls=tf.data.AUTOTUNE, 91 | ) 92 | return ds 93 | 94 | def get_validation_dataset(self, batch_size: int): 95 | ds = self._base_data.map( 96 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 97 | ) 98 | ds = ds.map( 99 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 100 | num_parallel_calls=tf.data.AUTOTUNE, 101 | ) 102 | return ds 103 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_codealpaca_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the CodeAlpaca dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | codealpaca_path = '/home/xinyic/codealpaca/codealpaca_data.tfrecord' 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class PreprocessedCodeAlpacaDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the CodeAlpaca dataset.""" 33 | 34 | def __init__( 35 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 36 | ): 37 | """Constructor. 38 | 39 | Args: 40 | tokenizer: Gemma tokenizer to use. 41 | max_seq_len: size of each sequence in a given batch. 42 | """ 43 | self._tokenizer = tokenizer 44 | self._base_data = tf.data.TFRecordDataset( 45 | [codealpaca_path], num_parallel_reads=tf.data.AUTOTUNE 46 | ) 47 | self._max_seq_len = max_seq_len 48 | 49 | def _to_training_input( 50 | self, 51 | input_tokens, 52 | target_mask, 53 | ): 54 | return dataset_builder.TrainingInput( # type: ignore 55 | input_tokens=input_tokens, # type:ignore 56 | target_mask=target_mask, # type:ignore 57 | ) # type: ignore 58 | 59 | def _decode_fn(self, record_bytes): 60 | parsed_features = tf.io.parse_example( 61 | record_bytes, 62 | { 63 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 64 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 65 | }, 66 | ) 67 | decoded = { 68 | 'input_tokens': tf.io.decode_raw( 69 | parsed_features['input_tokens'], out_type=tf.int32 70 | ), 71 | 'target_mask': tf.io.decode_raw( 72 | parsed_features['target_mask'], out_type=tf.bool 73 | ), 74 | } 75 | return { 76 | 'input_tokens': self._pad_up_to_max_len( 77 | decoded['input_tokens'], self._tokenizer.pad_id 78 | ), 79 | 'target_mask': self._pad_up_to_max_len( 80 | decoded['target_mask'], False 81 | ), 82 | } 83 | 84 | def get_train_dataset(self, batch_size: int, num_epochs: int): 85 | """Build the training dataset.""" 86 | ds = self._base_data.map( 87 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 88 | ) 89 | ds = ds.map( 90 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 91 | num_parallel_calls=tf.data.AUTOTUNE, 92 | ) 93 | return ds 94 | 95 | def get_validation_dataset(self, batch_size: int): 96 | ds = self._base_data.map( 97 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 98 | ) 99 | ds = ds.map( 100 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 101 | num_parallel_calls=tf.data.AUTOTUNE, 102 | ) 103 | return ds 104 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_orca_math_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the Preprocessed Orca Math dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | orca_math_path = '/home/shivguptashi/orca_math/orca_math_data.tfrecord' 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class PreprocessedOrcaMathDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the Open Orca dataset.""" 33 | 34 | def __init__( 35 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 36 | ): 37 | """Constructor. 38 | 39 | Args: 40 | tokenizer: Gemma tokenizer to use. 41 | max_seq_len: size of each sequence in a given batch. 42 | """ 43 | self._tokenizer = tokenizer 44 | self._base_data = tf.data.TFRecordDataset( 45 | [orca_math_path], num_parallel_reads=tf.data.AUTOTUNE 46 | ) 47 | self._max_seq_len = max_seq_len 48 | 49 | def _to_training_input( 50 | self, 51 | input_tokens, 52 | target_mask, 53 | ): 54 | return dataset_builder.TrainingInput( # type: ignore 55 | input_tokens=input_tokens, # type:ignore 56 | target_mask=target_mask, # type:ignore 57 | ) # type: ignore 58 | 59 | def get_train_dataset(self, batch_size: int, num_epochs: int): 60 | """Build the training dataset.""" 61 | ds = self._base_data.map( 62 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 63 | ) 64 | ds = ds.map( 65 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 66 | num_parallel_calls=tf.data.AUTOTUNE, 67 | ) 68 | return ds 69 | 70 | def _decode_fn(self, record_bytes): 71 | parsed_features = tf.io.parse_example( 72 | record_bytes, 73 | { 74 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 75 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 76 | }, 77 | ) 78 | decoded = { 79 | 'input_tokens': tf.io.decode_raw( 80 | parsed_features['input_tokens'], out_type=tf.int32 81 | ), 82 | 'target_mask': tf.io.decode_raw( 83 | parsed_features['target_mask'], out_type=tf.bool 84 | ), 85 | } 86 | return { 87 | 'input_tokens': self._pad_up_to_max_len( 88 | decoded['input_tokens'], self._tokenizer.pad_id 89 | ), 90 | 'target_mask': self._pad_up_to_max_len( 91 | decoded['target_mask'], False 92 | ), 93 | } 94 | 95 | def get_validation_dataset(self, batch_size: int): 96 | ds = self._base_data.map( 97 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 98 | ) 99 | ds = ds.map( 100 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 101 | num_parallel_calls=tf.data.AUTOTUNE, 102 | ) 103 | return ds 104 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_gsm8k_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the Preprocessed GSM8K dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | gsm8k_preprocessed_path = '/home/shivguptashi/gsm8k_train/gsm8k_train.tfrecord' 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class PreprocessedGSM8KDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the Preprocessed GSM8k dataset.""" 33 | 34 | def __init__( 35 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 36 | ): 37 | """Constructor. 38 | 39 | Args: 40 | tokenizer: Gemma tokenizer to use. 41 | max_seq_len: size of each sequence in a given batch. 42 | """ 43 | self._tokenizer = tokenizer 44 | self._base_data = tf.data.TFRecordDataset( 45 | [gsm8k_preprocessed_path], num_parallel_reads=tf.data.AUTOTUNE 46 | ) 47 | self._max_seq_len = max_seq_len 48 | 49 | def _to_training_input( 50 | self, 51 | input_tokens, 52 | target_mask, 53 | ): 54 | return dataset_builder.TrainingInput( # type: ignore 55 | input_tokens=input_tokens, # type:ignore 56 | target_mask=target_mask, # type:ignore 57 | ) # type: ignore 58 | 59 | 60 | def _decode_fn(self, record_bytes): 61 | parsed_features = tf.io.parse_example( 62 | record_bytes, 63 | { 64 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 65 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 66 | }, 67 | ) 68 | decoded = { 69 | 'input_tokens': tf.io.decode_raw( 70 | parsed_features['input_tokens'], out_type=tf.int32 71 | ), 72 | 'target_mask': tf.io.decode_raw( 73 | parsed_features['target_mask'], out_type=tf.bool 74 | ), 75 | } 76 | return { 77 | 'input_tokens': self._pad_up_to_max_len( 78 | decoded['input_tokens'], self._tokenizer.pad_id 79 | ), 80 | 'target_mask': self._pad_up_to_max_len( 81 | decoded['target_mask'], False 82 | ), 83 | } 84 | 85 | def get_train_dataset(self, batch_size: int, num_epochs: int): 86 | """Build the training dataset.""" 87 | ds = self._base_data.map( 88 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 89 | ) 90 | ds = ds.map( 91 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 92 | num_parallel_calls=tf.data.AUTOTUNE, 93 | ) 94 | return ds 95 | 96 | def get_validation_dataset(self, batch_size: int): 97 | ds = self._base_data.map( 98 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 99 | ) 100 | ds = ds.map( 101 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 102 | num_parallel_calls=tf.data.AUTOTUNE, 103 | ) 104 | return ds 105 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/training_batch_generators/importance_weighting_training_batch_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ImportanceWeightingTrainingBatchGenerator.""" 16 | 17 | import itertools 18 | 19 | from absl import logging 20 | import numpy as np 21 | from precondition.datamix_gemma.training_batch_generators import training_batch_generator 22 | import tensorflow_datasets as tfds 23 | 24 | 25 | class ImportanceWeightingTrainingBatchGenerator( 26 | training_batch_generator.TrainingBatchGenerator 27 | ): 28 | """ImportanceWeightingTrainingBatchGenerator.""" 29 | 30 | def __init__( 31 | self, train_ds_builders, batch_size, num_weights=2, num_iterations=100 32 | ): 33 | super().__init__(train_ds_builders, batch_size, num_weights, num_iterations) 34 | self.training_iters_lists = [] 35 | for _ in range(self.num_weights): 36 | self.training_iters_lists.append([]) 37 | 38 | for dataset_builder_obj in self.train_ds_builders: 39 | cur_iter = iter( 40 | tfds.as_numpy( 41 | dataset_builder_obj.get_train_dataset( 42 | batch_size=batch_size, num_epochs=1 43 | ) 44 | ) 45 | ) 46 | iter_list = itertools.tee(cur_iter, self.num_weights) 47 | for i in range(self.num_weights): 48 | self.training_iters_lists[i].append(iter_list[i]) 49 | #self.avg_weights = np.zeros(len(self.weights_list[0])) 50 | self.avg_weights = [] 51 | self.weights_list = [] 52 | self.sample_choices = [] 53 | 54 | def prepare_for_training(self, weights_list, new_unnormalized_weights): 55 | """Prepare for training.""" 56 | self.weights_list = weights_list 57 | self.avg_weights = np.zeros(len(self.weights_list[0])) 58 | for i in range(len(self.weights_list)): 59 | self.avg_weights += self.weights_list[i] 60 | self.avg_weights /= len(self.weights_list) 61 | 62 | logging.info(f'Avg weights: {self.avg_weights}') 63 | self.sample_choices = np.random.choice( 64 | len(self.avg_weights), 65 | size=self.batch_size, 66 | p=self.avg_weights, 67 | ) 68 | return 1 69 | 70 | def get_next_batch(self, index): 71 | logging.info('Getting next batch') 72 | training_iters = self.training_iters_lists[index] 73 | input_tokens_batch = [] 74 | input_mask_batch = [] 75 | factors = np.zeros(self.batch_size) 76 | for i in range(self.batch_size): 77 | sample_choice = self.sample_choices[i] 78 | try: 79 | cur_example = next(training_iters[sample_choice]) 80 | except StopIteration: 81 | training_iters[sample_choice] = iter( 82 | tfds.as_numpy( 83 | self.train_ds_builders[sample_choice].get_train_dataset( 84 | batch_size=self.batch_size, num_epochs=1 85 | ) 86 | ) 87 | ) 88 | cur_example = next(training_iters[sample_choice]) 89 | factors[i] = self.weights_list[index][sample_choice] / self.avg_weights[sample_choice] #pytype: disable=attribute-error 90 | input_tokens_batch.append(np.asarray([cur_example.input_tokens])) 91 | input_mask_batch.append(np.asarray([cur_example.target_mask])) 92 | factors *= len(factors)/np.sum(factors) 93 | 94 | return factors, input_tokens_batch, input_mask_batch 95 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_wikipedia_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dataset builder for the Wikipedia dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | wiki_path_prefix = '/home/shivguptashi/open_orca/wiki_tokenized' 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class PreprocessedWikipediaDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the Open Orca dataset.""" 33 | 34 | def __init__( 35 | self, 36 | tokenizer: gemma_tokenizer.GemmaTokenizer, 37 | max_seq_len: int, 38 | topic: int, 39 | ): 40 | """Constructor. 41 | 42 | Args: 43 | tokenizer: Gemma tokenizer to use. 44 | max_seq_len: size of each sequence in a given batch. 45 | topic: topic to use for training. 46 | """ 47 | self._tokenizer = tokenizer 48 | cur_wiki_path = wiki_path_prefix + '_topic_' + str(topic) + '.tfrecord' 49 | self._base_data = tf.data.TFRecordDataset( 50 | [cur_wiki_path], num_parallel_reads=tf.data.AUTOTUNE 51 | ) 52 | self._max_seq_len = max_seq_len 53 | 54 | def _to_training_input( 55 | self, 56 | input_tokens, 57 | target_mask, 58 | ): 59 | return dataset_builder.TrainingInput( # type: ignore 60 | input_tokens=input_tokens, # type:ignore 61 | target_mask=target_mask, # type:ignore 62 | ) # type: ignore 63 | 64 | def _decode_fn(self, record_bytes): 65 | parsed_features = tf.io.parse_example( 66 | record_bytes, 67 | { 68 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 69 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 70 | }, 71 | ) 72 | decoded = { 73 | 'input_tokens': tf.io.decode_raw( 74 | parsed_features['input_tokens'], out_type=tf.int32 75 | ), 76 | 'target_mask': tf.io.decode_raw( 77 | parsed_features['target_mask'], out_type=tf.bool 78 | ), 79 | } 80 | 81 | return { 82 | 'input_tokens': self._pad_up_to_max_len( 83 | decoded['input_tokens'], self._tokenizer.pad_id 84 | ), 85 | 'target_mask': self._pad_up_to_max_len( 86 | decoded['target_mask'], False 87 | ), 88 | } 89 | 90 | def get_train_dataset(self, batch_size: int, num_epochs: int): 91 | """Build the training dataset.""" 92 | ds = self._base_data.map( 93 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 94 | ) 95 | ds = ds.map( 96 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 97 | num_parallel_calls=tf.data.AUTOTUNE, 98 | ) 99 | return ds 100 | 101 | def get_validation_dataset(self, batch_size: int): 102 | ds = self._base_data.map( 103 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 104 | ) 105 | ds = ds.map( 106 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 107 | num_parallel_calls=tf.data.AUTOTUNE, 108 | ) 109 | return ds 110 | -------------------------------------------------------------------------------- /precondition/tearfree/second_order.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Various strategies for tracking second order statistics.""" 16 | 17 | import dataclasses 18 | import enum 19 | from typing import Optional 20 | 21 | import optax 22 | from precondition.tearfree import praxis_shim 23 | from precondition.tearfree import reshaper 24 | from precondition.tearfree import shampoo 25 | from precondition.tearfree import sketchy 26 | 27 | 28 | @enum.unique 29 | class SecondOrderType(enum.Enum): 30 | """Different second order covariance tracking methods.""" 31 | 32 | SHAMPOO = 'shampoo' 33 | SKETCHY = 'sketchy' 34 | 35 | 36 | @dataclasses.dataclass 37 | class Options: 38 | """Toggle which second order statistics to track. 39 | 40 | Attributes: 41 | merge_dims: Merges small dimensions, see `reshaper.Options.merge_dims`. 42 | second_order_type: Which optimizer to use for grafting updates. 43 | shampoo_options: Options for blocked shampoo. 44 | sketchy_options: Options for Sketchy. 45 | """ 46 | 47 | merge_dims: int = 1024 48 | second_order_type: SecondOrderType = SecondOrderType.SHAMPOO 49 | shampoo_options: Optional[shampoo.Options] = dataclasses.field( 50 | default_factory=shampoo.Options 51 | ) 52 | sketchy_options: Optional[sketchy.Options] = None 53 | 54 | 55 | def apply(options: Options) -> praxis_shim.ShardedGradientTransformation: 56 | """Generate the second order update from options.""" 57 | reshaper_options = _reshaper_options(options) 58 | merge_tx = reshaper.merge(reshaper_options) 59 | precond_tx = _update_stats_and_precondition(options) 60 | 61 | def wrap_init(params: optax.Params): 62 | reshaped_params, _ = merge_tx.update(params, merge_tx.init(params), params) 63 | return precond_tx.init(reshaped_params) 64 | 65 | # TODO(vladf): later, we'll need to wrap pspec as well. 
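# The preconditioner state has to be initialized against the merged/padded
# parameter shapes rather than the originals: by the time precond_tx.update
# runs inside the sharded chain below, merge_tx has already reshaped the
# gradients, so wrap_init routes params through the merge transform before
# calling precond_tx.init.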
66 | wrapped_precond_tx = praxis_shim.ShardedGradientTransformation( 67 | wrap_init, precond_tx.update, precond_tx.init_partition_spec 68 | ) 69 | 70 | return praxis_shim.sharded_chain( 71 | merge_tx, 72 | wrapped_precond_tx, 73 | reshaper.unmerge(reshaper_options), 74 | ) 75 | 76 | 77 | def _reshaper_options(options: Options) -> reshaper.Options: 78 | if options.second_order_type == SecondOrderType.SHAMPOO: 79 | assert options.shampoo_options 80 | block_size = options.shampoo_options.block_size 81 | return reshaper.Options(options.merge_dims, block_size) 82 | if options.second_order_type == SecondOrderType.SKETCHY: 83 | return reshaper.Options(options.merge_dims, 0) 84 | else: 85 | raise ValueError( 86 | 'unknown second order type {}'.format(options.second_order_type) 87 | ) 88 | 89 | 90 | def _update_stats_and_precondition( 91 | options: Options, 92 | ) -> praxis_shim.ShardedGradientTransformation: 93 | if options.second_order_type == SecondOrderType.SHAMPOO: 94 | assert options.shampoo_options 95 | return shampoo.apply(options.shampoo_options) 96 | if options.second_order_type == SecondOrderType.SKETCHY: 97 | assert options.sketchy_options 98 | return sketchy.apply(options.sketchy_options) 99 | else: 100 | raise ValueError( 101 | 'unknown second order type {}'.format(options.second_order_type) 102 | ) 103 | -------------------------------------------------------------------------------- /precondition/tearfree/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tearfree optimizer implementation. 16 | 17 | OOM making your eyes water? Try the Tearfree Shampoo optimizer. 18 | 19 | This module handles logic for 20 | 21 | 1. Statistics/preconditioner update frequency 22 | 2. Applying momentum 23 | 3. Combining grafting and preconditioning updates, applying grafting 24 | 4. Typical update procedures, like learning rate, momentum, etc. 25 | """ 26 | 27 | import dataclasses 28 | from typing import Union 29 | 30 | import chex 31 | import optax 32 | from precondition.tearfree import grafting 33 | from precondition.tearfree import momentum 34 | from precondition.tearfree import praxis_shim 35 | from precondition.tearfree import second_order 36 | 37 | 38 | @dataclasses.dataclass 39 | class TearfreeOptions: 40 | """Configuration dataclass for tearfree optimizer. 41 | 42 | Attributes: 43 | grafting_options: Grafting options to modify update norm (see 44 | `grafting.Options`). 45 | second_order_options: Second-order statistics tracking options (see 46 | `second_order.Options`). 47 | momentum_options: Momentum options (see `momentum.Options`). 
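    Example (a minimal sketch, assuming the returned transformation is driven
    like a standard optax `GradientTransformation`):

      tx = tearfree(1e-2, TearfreeOptions())
      state = tx.init(params)
      updates, state = tx.update(grads, state, params)
      params = optax.apply_updates(params, updates)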
48 | """ 49 | 50 | grafting_options: grafting.Options = dataclasses.field( 51 | default_factory=grafting.Options 52 | ) 53 | second_order_options: second_order.Options = dataclasses.field( 54 | default_factory=second_order.Options 55 | ) 56 | momentum_options: momentum.Options = dataclasses.field( 57 | default_factory=momentum.Options 58 | ) 59 | 60 | 61 | def tearfree( 62 | learning_rate: Union[chex.Numeric, optax.Schedule], 63 | options: TearfreeOptions, 64 | ) -> praxis_shim.ShardedGradientTransformation: 65 | """Tearfree optimizer, supports pjit and jit. 66 | 67 | Preconditioned, grafted updates with momentum. 68 | 69 | One key difference in the logic is to only use a single momentum between 70 | the graft and preconditioned update. `distributed_shampoo` keeps a separate 71 | `diagonal_momentum` buffer, but never uses it after preconditioning is 72 | active (it is not used to adjust the grafting norm). This implies (1) 73 | we save memory (only one momentum buffer), (2) we are identical to 74 | `distributed_shampoo` if there is no warmup or no preconditioning 75 | (`options.start_preconditioning_step` is inf or 0). 76 | 77 | Args: 78 | learning_rate: The learning rate value or schedule. Learning rate is 79 | "decoupled", i.e., we always apply it last to the update (after weight 80 | decay, after momentum, etc.). 81 | options: Tearfree optimizer options. 82 | 83 | Returns: 84 | The sharded gradient transformation corresponding to an updated, 85 | preconditioned gradient, times the negative learning rate. 86 | """ 87 | 88 | second_order_tx = second_order.apply(options.second_order_options) 89 | graft_tx = grafting.graft(options.grafting_options, second_order_tx) 90 | momentum_tx = momentum.apply(options.momentum_options) 91 | if callable(learning_rate): 92 | lr_tx = optax.scale_by_schedule(lambda x: -1.0 * learning_rate(x)) 93 | else: 94 | lr_tx = optax.scale(-1.0 * learning_rate) 95 | return praxis_shim.sharded_chain( 96 | graft_tx, 97 | momentum_tx, 98 | lr_tx, 99 | ) 100 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/preprocessed_sciq_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the Preprocessed SciQ dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | 23 | 24 | sciq_train_path = '/home/shivguptashi/sciq/sciq_train.tfrecord' 25 | sciq_validation_path = '/home/shivguptashi/sciq/sciq_validation.tfrecord' 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | VALIDATION = 'validation' 30 | 31 | 32 | class PreprocessedSciQDatasetBuilder(dataset_builder.DatasetBuilder): 33 | """Dataset builder for the SciQ dataset.""" 34 | 35 | def __init__( 36 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 37 | ): 38 | """Constructor. 39 | 40 | Args: 41 | tokenizer: Gemma tokenizer to use. 42 | max_seq_len: size of each sequence in a given batch. 43 | """ 44 | self._tokenizer = tokenizer 45 | self._train_data = tf.data.TFRecordDataset( 46 | [sciq_train_path], num_parallel_reads=tf.data.AUTOTUNE 47 | ) 48 | self._validation_data = tf.data.TFRecordDataset( 49 | [sciq_validation_path], num_parallel_reads=tf.data.AUTOTUNE 50 | ) 51 | self._max_seq_len = max_seq_len 52 | 53 | def _to_training_input( 54 | self, 55 | input_tokens, 56 | target_mask, 57 | ): 58 | return dataset_builder.TrainingInput( # type: ignore 59 | input_tokens=input_tokens, # type:ignore 60 | target_mask=target_mask, # type:ignore 61 | ) # type: ignore 62 | 63 | def _decode_fn(self, record_bytes): 64 | parsed_features = tf.io.parse_example( 65 | record_bytes, 66 | { 67 | 'input_tokens': tf.io.FixedLenFeature((), tf.string), 68 | 'target_mask': tf.io.FixedLenFeature((), tf.string), 69 | }, 70 | ) 71 | decoded = { 72 | 'input_tokens': tf.io.decode_raw( 73 | parsed_features['input_tokens'], out_type=tf.int32 74 | ), 75 | 'target_mask': tf.io.decode_raw( 76 | parsed_features['target_mask'], out_type=tf.bool 77 | ), 78 | } 79 | return { 80 | 'input_tokens': self._pad_up_to_max_len( 81 | decoded['input_tokens'], self._tokenizer.pad_id 82 | ), 83 | 'target_mask': self._pad_up_to_max_len( 84 | decoded['target_mask'], False 85 | ), 86 | } 87 | 88 | def get_train_dataset(self, batch_size: int, num_epochs: int): 89 | """Build the training dataset.""" 90 | ds = self._train_data.map( 91 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 92 | ) 93 | ds = ds.map( 94 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 95 | num_parallel_calls=tf.data.AUTOTUNE, 96 | ) 97 | return ds 98 | 99 | def get_validation_dataset(self, batch_size: int): 100 | ds = self._validation_data.map( 101 | self._decode_fn, num_parallel_calls=tf.data.AUTOTUNE 102 | ) 103 | ds = ds.map( 104 | lambda x: self._to_training_input(x['input_tokens'], x['target_mask']), 105 | num_parallel_calls=tf.data.AUTOTUNE, 106 | ) 107 | return ds 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `precondition`: Preconditioning Optimizers 2 | 3 | [![Unittests](https://github.com/google-research/precondition/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-research/precondition/actions/workflows/pytest_and_autopublish.yml) 4 | [![PyPI version](https://badge.fury.io/py/precondition-opt.svg)](https://badge.fury.io/py/precondition-opt) 5 | 6 | Installation (note package name is `precondition` but pypi distribution name is `precondition-opt`): 7 | ``` 8 | 
pip3 install -U precondition-opt 9 | ``` 10 | 11 | Currently, this contains several preconditioning optimizer implementations. Please refer to the citations below. 12 | 13 | Shampoo (`distributed_shampoo.py`) 14 | 15 | ``` 16 | @article{anil2020scalable, 17 | title={Scalable second order optimization for deep learning}, 18 | author={Anil, Rohan and Gupta, Vineet and Koren, Tomer and Regan, Kevin and Singer, Yoram}, 19 | journal={arXiv preprint arXiv:2002.09018}, 20 | year={2020} 21 | } 22 | ``` 23 | 24 | Sketchy (`distributed_shampoo.py`), logical reference implementation as a branch in Shampoo. 25 | ``` 26 | @article{feinberg2023sketchy, 27 | title={Sketchy: Memory-efficient Adaptive Regularization with Frequent Directions}, 28 | author={Feinberg, Vladimir and Chen, Xinyi and Sun, Y Jennifer and Anil, Rohan and Hazan, Elad}, 29 | journal={arXiv preprint arXiv:2302.03764}, 30 | year={2023} 31 | } 32 | ``` 33 | In Appendix A of the aforementioned paper, S-Adagrad is tested along with other optimization algorithms (including RFD-SON, Adagrad, OGD, Ada-FD, FD-SON) on three benchmark datasets (`a9a`, `cifar10`, `gisette_scale`). To recreate the results in Appendix A, first download the benchmark datasets (available at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/) to a local folder (e.g. `~/data`). For each `DATASET = {'a9a', 'cifar10', 'gisette'}`, run the following two steps. The first step runs a sweep over all the hyperparameter sets, and the second step plots a graph of the best set of hyperparameters from that sweep. Consistent with a fair allocation of the hyperparameter training budget, we tune delta and the learning rate, each over 7 points uniformly spaced from 1e-6 to 1 in logarithmic scale, for FD-SON and Ada-FD. For the rest of the optimization algorithms, where delta is taken to be 0, we tune the learning rate over 49 points uniformly spaced from 1e-6 to 1 in logarithmic scale. The commands are as follows: 34 | 35 | To run the sweep over different sets of hyperparameters (`...` stands in for the remaining hyperparameter values to sweep over), run: 36 | ``` 37 | python3 oco/sweep.py --data_dir='~/data' --dataset=DATASET --save_dir='/tmp/results' \ 38 | --lr=1 ... --lr=1e-6 --alg=ADA, --alg=OGD, --alg=S_ADA, --alg=RFD_SON 39 | ``` 40 | and 41 | ``` 42 | python3 oco/sweep.py --data_dir='~/data' --dataset=DATASET --save_dir='/tmp/results' --lr=1 ... --lr=1e-6 \ 43 | --delta=1 --delta=1e-6 --alg=ADA_FD, --alg=FD_SON 44 | ``` 45 | After running the above commands, the results will be saved in folders named with the timestamps at execution (e.g. '/tmp/results/YYYY-MM-DD' and '/tmp/results/yyyy-mm-dd'). 46 | 47 | To plot the best set of hyperparameters from the previous run, run: 48 | ``` 49 | python3 oco/sweep.py --data_dir='~/data' --dataset=DATASET --save_dir='/tmp/results' \ 50 | --use_best_from='/tmp/results/YYYY-MM-DD' --use_best_from='/tmp/results/yyyy-mm-dd' 51 | ``` 52 | For detailed documentation of the supported flags, run: 53 | ``` 54 | python3 oco/sweep.py --help 55 | ``` 56 | 57 | SM3 (`sm3.py`).
58 | ``` 59 | @article{anil2020scalable, 60 | title={Scalable second order optimization for deep learning}, 61 | author={Anil, Rohan and Gupta, Vineet and Koren, Tomer and Regan, Kevin and Singer, Yoram}, 62 | journal={arXiv preprint arXiv:2002.09018}, 63 | year={2020} 64 | } 65 | ``` 66 | 67 | This external repository was seeded from existing open-source work available at [this google-research repository](https://github.com/google-research/google-research/tree/master/scalable_shampoo). 68 | 69 | 70 | *This is not an officially supported Google product.* 71 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/training_batch_generators/fixed_dataset_importance_weighting_training_batch_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """FixedDatasetImportanceWeightingTrainingBatchGenerator.""" 16 | 17 | from absl import logging 18 | import numpy as np 19 | from precondition.datamix_gemma.training_batch_generators import training_batch_generator 20 | import tensorflow_datasets as tfds 21 | 22 | 23 | class FixedDatasetImportanceWeightingTrainingBatchGenerator( 24 | training_batch_generator.TrainingBatchGenerator 25 | ): 26 | """FixedDatasetImportanceWeightingTrainingBatchGenerator.""" 27 | 28 | def __init__(self, train_ds_builders, batch_size, num_weights=2, num_iterations=100): 29 | super().__init__(train_ds_builders, batch_size, num_weights, num_iterations) 30 | self.training_iters = [] 31 | for dataset_builder_obj in self.train_ds_builders: 32 | cur_iter = iter( 33 | tfds.as_numpy( 34 | dataset_builder_obj.get_train_dataset( 35 | batch_size=batch_size, num_epochs=1 36 | ) 37 | ) 38 | ) 39 | self.training_iters.append(cur_iter) 40 | self.examples = [] 41 | num_datasets = len(self.training_iters) 42 | self.sample_choices = np.random.choice( 43 | num_datasets, 44 | size=self.batch_size * self.num_iterations, 45 | p=np.ones(num_datasets)/num_datasets, 46 | ) 47 | for i in range(self.batch_size * num_iterations): 48 | try: 49 | self.examples.append(next(self.training_iters[self.sample_choices[i]])) 50 | except StopIteration: 51 | self.training_iters[self.sample_choices[i]] = iter( 52 | tfds.as_numpy( 53 | self.train_ds_builders[self.sample_choices[i]].get_train_dataset( 54 | batch_size=self.batch_size, num_epochs=1 55 | ) 56 | ) 57 | ) 58 | self.examples.append(next(self.training_iters[self.sample_choices[i]])) 59 | 60 | #self.input_tokens_batch = np.asarray([[example.input_tokens] for example in self.examples]) 61 | #self.input_mask_batch = np.asarray([[example.target_mask] for example in self.examples]) 62 | self.weights_list = [] 63 | self.indices = [] 64 | self.factors = [] 65 | #self.avg_weights = np.zeros(len(self.weights_list[0])) 66 | 67 | def prepare_for_training(self, weights_list, new_unnormalized_weights): 68 | """Prepare for training.""" 69 | 
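# The dataset index for every example in the fixed stream was drawn uniformly
# in __init__, so reweighting example j by weights_list[i][sample_choices[j]]
# (normalized below to have mean 1 over the stream) makes the fixed stream an
# importance-weighted stand-in for sampling under mixture i.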
self.weights_list = weights_list 70 | self.indices = [0 for _ in range(self.num_weights)] 71 | self.factors = [np.zeros(self.batch_size * self.num_iterations) for _ in range(self.num_weights)] 72 | for i in range(self.num_weights): 73 | for j in range(self.batch_size * self.num_iterations): 74 | self.factors[i][j] = self.weights_list[i][self.sample_choices[j]] 75 | self.factors = (self.factors[i] / np.sum(self.factors[i])) * len(self.factors[i]) 76 | return 1 77 | 78 | def get_next_batch(self, index): 79 | logging.info(f'Getting next batch, batch_size={self.batch_size}, weights_list={self.weights_list}') 80 | logging.info(f'sample choices len: {len(self.sample_choices)}, examples len: {len(self.examples)}') 81 | cur_factors = self.factors[index][self.indices[index]:(self.indices[index]+self.batch_size)] 82 | cur_examples = self.examples[self.indices[index]:(self.indices[index]+self.batch_size)] 83 | cur_input_tokens_batch = np.asarray([[example.input_tokens] for example in cur_examples]) 84 | cur_input_mask_batch = np.asarray([[example.target_mask] for example in cur_examples]) 85 | self.indices[index] += self.batch_size 86 | return cur_factors, cur_input_tokens_batch, cur_input_mask_batch 87 | -------------------------------------------------------------------------------- /precondition/oco/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for running an OCO algorithm on a dataset.""" 16 | 17 | import functools 18 | from typing import Callable, Optional 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | from precondition.oco import algorithms 24 | from precondition.oco import datasets 25 | 26 | LossAndGrad = Callable[ 27 | [jax.Array, jax.Array, jax.Array], tuple[jax.Array, jax.Array] 28 | ] 29 | 30 | 31 | @functools.partial( 32 | jax.jit, 33 | static_argnames=[ 34 | 'loss_and_grad', 35 | 'update_fn', 36 | # 'algorithm', 37 | # 'sketch_size', 38 | 'extra_loss', 39 | ], 40 | ) 41 | def _compiled_run_dataset( 42 | x: jax.Array, 43 | y: jax.Array, 44 | state: algorithms.State, 45 | obs_ixs: jax.Array, 46 | # delta: algorithms.RuntimeScalar, 47 | # lr: algorithms.RuntimeScalar, 48 | loss_and_grad: LossAndGrad, 49 | update_fn: algorithms.UpdateFn, 50 | # algorithm: algorithms.Algorithm, 51 | # sketch_size: int, 52 | extra_loss: Optional[datasets.Loss], 53 | ) -> algorithms.State: 54 | """Run an OCO algorithm, saving history at obs_ixs.""" 55 | 56 | # hparams = algorithms.HParams( 57 | # delta, 58 | # lr, 59 | # sketch_size, 60 | # algorithm, 61 | # ) 62 | 63 | # assume obs_ixs starts at 0 and ends at nrows-1 64 | # assume state has various keys pre-initialized, see below. 
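# process_row consumes one example: evaluate the loss and gradient at the
# current iterate state['w'], optionally accumulate the extra loss, let the
# algorithm's update_fn advance the state, then bump the running loss and the
# example counter state['n'].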
65 | 66 | def process_row(idx, state): # fori_loop index, so pylint: disable=unused-argument 67 | ix = state['n'] 68 | r = x[ix] 69 | f, g = loss_and_grad(state['w'], r, y[ix]) 70 | if extra_loss is not None: 71 | state['extra_loss'] += extra_loss(state['w'], r, y[ix]) 72 | state = update_fn(state, f, g) 73 | state['loss'] += f 74 | state['n'] += 1 75 | 76 | return state 77 | 78 | chunks = jnp.diff(obs_ixs, prepend=0) 79 | 80 | def scan_reduce(state, chunksize): 81 | state = jax.lax.fori_loop(0, chunksize, process_row, state) 82 | return state, state 83 | 84 | _, history = jax.lax.scan(scan_reduce, state, chunks) 85 | 86 | return history 87 | 88 | 89 | def run_dataset( 90 | dataset_name: str, 91 | num_obs: int, 92 | hparams: algorithms.HParams, 93 | extra_loss: Optional[datasets.Loss] = None, 94 | dataset_cache: str = '/tmp/cache', 95 | ) -> algorithms.State: 96 | """Run an OCO algorithm on a dataset, saving history at obs_ixs.""" 97 | assert num_obs >= 2 98 | 99 | dataset = datasets.load_dataset(dataset_name, dataset_cache) 100 | init_fn, update_fn = algorithms.generate_init_update(dataset.w_shape, hparams) 101 | 102 | obs_ixs = np.round( 103 | np.linspace(0, dataset.x.shape[0], num=num_obs, endpoint=True) 104 | ).astype(int) 105 | 106 | initial_state = init_fn() 107 | loss_and_grad = jax.value_and_grad(dataset.loss) 108 | 109 | assert 'loss' not in initial_state, list(initial_state) 110 | assert 'w' in initial_state, list(initial_state) 111 | assert 'n' not in initial_state, list(initial_state) 112 | initial_state['loss'] = jnp.array(0.0, dtype=jnp.float64) 113 | initial_state['n'] = 0 114 | if extra_loss is not None: 115 | initial_state['extra_loss'] = jnp.array(0.0, dtype=jnp.float64) 116 | 117 | # Unpack the compilable and static HParams separately. 118 | return _compiled_run_dataset( 119 | dataset.x, 120 | dataset.y, 121 | initial_state, 122 | obs_ixs, 123 | # hparams.delta, 124 | # hparams.lr, 125 | loss_and_grad, 126 | update_fn, 127 | # hparams.algorithm, 128 | # hparams.sketch_size, 129 | extra_loss, 130 | ) 131 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/confusion_matrix_calc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | 17 | from absl import logging 18 | # Finally, we import Gemma. 
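# The loop at the bottom of this file fine-tunes the base checkpoint on each
# candidate dataset in isolation and logs its GSM8K eval score, i.e. one row
# of a dataset-vs-eval score matrix.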
19 | import jax 20 | import numpy as np 21 | from precondition.datamix_gemma import finetune_utils 22 | from precondition.datamix_gemma.evals.gsm8k_eval import GSM8KEval 23 | from precondition.datamix_gemma.training_batch_generators import vanilla_training_batch_generator 24 | from precondition.datamix_gemma.training_loop import TrainingConfig 25 | from precondition.datamix_gemma.training_loop import TrainingLoop 26 | 27 | 28 | SEQ_SIZE = 1000 29 | VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"} 30 | CKPT_PATH = '/tfhub/prod/g_mini/2b_it_v1p1_orbax/1' 31 | VOCAB_PATH = '/home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model' 32 | 33 | def compute_eval_score(params, train_obj, training_batch_generator_obj, eval_obj): 34 | training_batch_generator_obj.prepare_for_training([1.,], [1.,]) 35 | trained_params = train_obj.train_loop( 36 | params={'params': params}, 37 | get_next_batch_fn=functools.partial( 38 | training_batch_generator_obj.get_next_batch, index=0 39 | ), 40 | ) 41 | eval_score = eval_obj.evaluate(trained_params['params']) 42 | return eval_score 43 | 44 | def confusion_matrix_calc(): 45 | model_2b, tokenizer, vocab, params = finetune_utils.setup_model() 46 | 47 | params = jax.tree_util.tree_map( 48 | lambda arr: jax.device_put( 49 | arr, jax.local_devices(backend='cpu')[0] 50 | ), 51 | params, 52 | ) 53 | train_batch_size = jax.local_device_count() 54 | logging.info('train_batch_size: %s', train_batch_size) 55 | mmlu_eval_batch_size = 8192 56 | mbpp_eval_batch_size = 1024 57 | 58 | 59 | training_cfg = TrainingConfig( 60 | learning_rate=1e-4, 61 | batch_size=train_batch_size, 62 | ) 63 | 64 | train_loop_obj = TrainingLoop( 65 | model=model_2b, 66 | pad_id=tokenizer.pad_id, 67 | training_cfg=training_cfg, 68 | num_training_steps=100, 69 | #optimization_alg='adam', 70 | ) 71 | 72 | all_dataset_builders = finetune_utils.get_dataset_builders(tokenizer) 73 | vanilla_training_batch_generator_obj = vanilla_training_batch_generator.VanillaTrainingBatchGenerator( 74 | train_ds_builders=[all_dataset_builders[0],], 75 | batch_size=jax.device_count(), 76 | num_weights=1, 77 | num_iterations=100, 78 | ) 79 | vanilla_training_batch_generator_obj.prepare_for_training([[1.,]], [[1.,]]) 80 | trained_params = train_loop_obj.train_loop( 81 | params={'params': params}, 82 | get_next_batch_fn=functools.partial( 83 | vanilla_training_batch_generator_obj.get_next_batch, index=0 84 | ), 85 | ) 86 | 87 | exit() 88 | 89 | gsm8k_eval_obj = GSM8KEval( 90 | model=model_2b, 91 | tokenizer=tokenizer, 92 | vocab=vocab, 93 | eval_batch_size=mmlu_eval_batch_size 94 | ) 95 | 96 | rng = np.random.default_rng(seed=0) 97 | num_iterations = 100 98 | scores = [] 99 | for i in range(len(all_dataset_builders)): 100 | vanilla_training_batch_generator_obj = vanilla_training_batch_generator.VanillaTrainingBatchGenerator( 101 | train_ds_builders=[all_dataset_builders[i],], 102 | batch_size=jax.device_count(), 103 | ) 104 | score = compute_eval_score( 105 | params, 106 | train_obj=train_loop_obj, 107 | training_batch_generator_obj=vanilla_training_batch_generator_obj, 108 | eval_obj=gsm8k_eval_obj, 109 | ) 110 | logging.info(f'Score at index {i}: {score}') 111 | scores.append(score) 112 | writer.close() 113 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/finetune_eval_measurement.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | 17 | from absl import logging 18 | # Finally, we import Gemma. 19 | import jax 20 | import numpy as np 21 | from precondition.datamix_gemma import finetune_utils 22 | from precondition.datamix_gemma.evals.gsm8k_eval import GSM8KEval 23 | from precondition.datamix_gemma.training_batch_generators import vanilla_training_batch_generator 24 | from precondition.datamix_gemma.training_loop import TrainingConfig 25 | from precondition.datamix_gemma.training_loop import TrainingLoop 26 | 27 | 28 | SEQ_SIZE = 1000 29 | VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"} 30 | CKPT_PATH = '/tfhub/prod/g_mini/2b_it_v1p1_orbax/1' 31 | VOCAB_PATH = '/home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model' 32 | 33 | def compute_eval_score(params, train_obj, training_batch_generator_obj, eval_obj): 34 | training_batch_generator_obj.prepare_for_training([1.,], [1.,]) 35 | trained_params = train_obj.train_loop( 36 | params={'params': params}, 37 | get_next_batch_fn=functools.partial( 38 | training_batch_generator_obj.get_next_batch, index=0 39 | ), 40 | ) 41 | eval_score = eval_obj.evaluate(trained_params['params']) 42 | return eval_score 43 | 44 | def finetune_eval_measurement(): 45 | model_2b, tokenizer, vocab, params = finetune_utils.setup_model() 46 | 47 | params = jax.tree_util.tree_map( 48 | lambda arr: jax.device_put( 49 | arr, jax.local_devices(backend='cpu')[0] 50 | ), 51 | params, 52 | ) 53 | train_batch_size = jax.local_device_count() 54 | logging.info('train_batch_size: %s', train_batch_size) 55 | mmlu_eval_batch_size = 2048 56 | 57 | 58 | training_cfg = TrainingConfig( 59 | learning_rate=1e-4, 60 | batch_size=train_batch_size, 61 | ) 62 | 63 | train_loop_obj = TrainingLoop( 64 | model=model_2b, 65 | pad_id=tokenizer.pad_id, 66 | training_cfg=training_cfg, 67 | num_training_steps=1, 68 | #optimization_alg='adam', 69 | ) 70 | 71 | all_dataset_builders = finetune_utils.get_dataset_builders(tokenizer) 72 | 73 | vanilla_training_batch_generator_obj = vanilla_training_batch_generator.VanillaTrainingBatchGenerator( 74 | train_ds_builders=[all_dataset_builders[0],], 75 | batch_size=jax.device_count(), 76 | num_weights=1, 77 | num_iterations=100, 78 | ) 79 | gsm8k_eval_obj = GSM8KEval( 80 | model=model_2b, 81 | tokenizer=tokenizer, 82 | vocab=vocab, 83 | eval_batch_size=mmlu_eval_batch_size 84 | ) 85 | vanilla_training_batch_generator_obj.prepare_for_training([[1.,]], [[1.,]]) 86 | trained_params = {'params': params} 87 | for i in range(100): 88 | trained_params = train_loop_obj.train_loop( 89 | params=trained_params, 90 | get_next_batch_fn=functools.partial( 91 | vanilla_training_batch_generator_obj.get_next_batch, index=0 92 | ), 93 | ) 94 | score = gsm8k_eval_obj.evaluate(trained_params['params']) 95 | logging.info(f'Score at index {i}: {score}') 96 | 97 | exit() 98 | 99 | gsm8k_eval_obj = GSM8KEval( 100 | model=model_2b, 101 | tokenizer=tokenizer, 102 | 
vocab=vocab, 103 | eval_batch_size=mmlu_eval_batch_size 104 | ) 105 | 106 | rng = np.random.default_rng(seed=0) 107 | num_iterations = 100 108 | scores = [] 109 | for i in range(len(all_dataset_builders)): 110 | vanilla_training_batch_generator_obj = vanilla_training_batch_generator.VanillaTrainingBatchGenerator( 111 | train_ds_builders=[all_dataset_builders[i],], 112 | batch_size=jax.device_count(), 113 | ) 114 | score = compute_eval_score( 115 | params, 116 | train_obj=train_loop_obj, 117 | training_batch_generator_obj=vanilla_training_batch_generator_obj, 118 | eval_obj=gsm8k_eval_obj, 119 | ) 120 | logging.info(f'Score at index {i}: {score}') 121 | scores.append(score) 122 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/mbpp_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MBPP dataset builder.""" 16 | 17 | import enum as Enum 18 | 19 | from precondition.datamix_gemma.dataset_builders import dataset_builder 20 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 21 | import tensorflow as tf 22 | import tensorflow_datasets as tfds 23 | 24 | 25 | class DatasetSplit(Enum.Enum): 26 | TRAIN = 'train' 27 | TEST = 'test' 28 | PROMPT = 'prompt' 29 | 30 | 31 | class MBPPDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the MBPP dataset.""" 33 | 34 | #BUFFER_SIZE_SHUFFLE = 10_000 35 | BUFFER_SIZE_SHUFFLE = 100 36 | 37 | def __init__( 38 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 39 | ): 40 | """Constructor. 41 | 42 | Args: 43 | tokenizer: Gemma tokenizer to use. 44 | max_seq_len: size of each sequence in a given batch. 
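    Note: the constructor also folds the MBPP 'prompt' split into a single
    few-shot prefix (task description, its tests, then the reference solution
    wrapped in [BEGIN]/[DONE]); `get_test_dataset` prepends this prefix to
    every evaluation prompt.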
45 | """ 46 | self._tokenizer = tokenizer 47 | self._base_data = { 48 | DatasetSplit.TRAIN: tfds.load( 49 | 'huggingface:mbpp/full', split='train' 50 | ), 51 | DatasetSplit.TEST: tfds.load( 52 | 'huggingface:mbpp/full', split='test' 53 | ), 54 | DatasetSplit.PROMPT: tfds.load( 55 | 'huggingface:mbpp/full', split='prompt' 56 | ), 57 | } 58 | self._max_seq_len = max_seq_len 59 | #train_ds = self._base_data[DatasetSplit.TEST] 60 | prompt_ds = self._base_data[DatasetSplit.PROMPT] 61 | #train_ds = train_ds.filter( 62 | # lambda x: 2 <= tf.cast(x['task_id'], tf.int32) <= 4 63 | #) 64 | prompt_ds = prompt_ds.map( 65 | lambda x: (x['text'], 66 | self._generate_tests_string(x['test_list']), x['code'])) 67 | prompt_ds = prompt_ds.map( 68 | lambda x, y, z: tf.py_function( 69 | self._generate_training_prompt, [x, y, z], [tf.string])) 70 | individual_prompts = [] 71 | for prompt in prompt_ds: 72 | individual_prompts.append(prompt[0].numpy().decode('utf-8')) 73 | self._train_prompt = '\n'.join(individual_prompts) 74 | 75 | def _generate_eval_prompt(self, prompt, tests_str): 76 | full_prompt = f'You are an expert Python programmer, and here is your task: {prompt.numpy().decode("utf-8")} Your code should pass these tests:\n\n{tests_str.numpy().decode("utf-8")}\n' # pylint: disable=line-too-long 77 | return full_prompt, tests_str.numpy().decode('utf-8') 78 | 79 | def _generate_training_prompt(self, prompt, tests_str, code): 80 | return f'You are an expert Python programmer, and here is your task: {prompt.numpy().decode("utf-8")} Your code should pass these tests:\n\n{tests_str.numpy().decode("utf-8")}\n[BEGIN]\n{code.numpy().decode("utf-8")}\n[DONE]' # pylint: disable=line-too-long 81 | 82 | def _generate_tests_string(self, tests_list): 83 | return tf.strings.reduce_join(tests_list, separator='\n') 84 | 85 | def _generate_full_eval_prompt(self, eval_prompt, tests_str): 86 | full_prompt = '\n'.join( 87 | [self._train_prompt, eval_prompt.numpy().decode('utf-8')] 88 | ) 89 | return full_prompt, tests_str.numpy().decode('utf-8') 90 | 91 | def get_test_dataset(self): 92 | ds = self._base_data[DatasetSplit.TEST].filter( 93 | lambda x: 11 <= x['task_id'] <= 510 94 | ) 95 | ds = ds.map( 96 | lambda x: (x['text'], self._generate_tests_string(x['test_list'])) 97 | ) 98 | ds = ds.map( 99 | lambda x, y: tf.py_function( 100 | self._generate_eval_prompt, [x, y], [tf.string, tf.string] 101 | ) 102 | ) 103 | ds = ds.map( 104 | lambda x, y: tf.py_function( 105 | self._generate_full_eval_prompt, [x, y], [tf.string, tf.string] 106 | ) 107 | ) 108 | ds = ds.map( 109 | lambda x, y: tf.py_function( 110 | self._generate_eval_prompt, [x, y], [tf.string, tf.string] 111 | ) 112 | ) 113 | ds = ds.map( 114 | lambda x, y: tf.py_function( 115 | self._generate_full_eval_prompt, [x, y], [tf.string, tf.string] 116 | ) 117 | ) 118 | return ds 119 | -------------------------------------------------------------------------------- /precondition/quantization_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper routines for quantization.""" 16 | 17 | from typing import Any 18 | 19 | import chex 20 | from flax import struct 21 | import jax.numpy as jnp 22 | 23 | 24 | # pylint:disable=no-value-for-parameter 25 | @struct.dataclass 26 | class QuantizedValue: 27 | """State associated with quantized value.""" 28 | quantized: chex.Array 29 | diagonal: chex.Array # Diagonal (if extract_diagonal is set) 30 | bucket_size: chex.Array 31 | quantized_dtype: jnp.dtype = struct.field( 32 | pytree_node=False) # Dtype for the quantized value. 33 | extract_diagonal: bool = struct.field( 34 | pytree_node=False) # In case its centered. 35 | shape: Any = struct.field(pytree_node=False) # Shape of the tensor. 36 | 37 | @classmethod 38 | def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): 39 | if isinstance(fvalue, list) and not fvalue: 40 | return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) # pytype: disable=wrong-arg-types # numpy-scalars 41 | quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( 42 | fvalue, quantized_dtype, extract_diagonal) 43 | return QuantizedValue(quantized, diagonal_fvalue, bucket_size, 44 | quantized_dtype, extract_diagonal, 45 | list(quantized.shape)) 46 | 47 | # Quantization is from Lingvo JAX optimizers. 48 | # We extend it for int16 quantization of PSD matrices. 49 | @classmethod 50 | def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): 51 | """Returns quantized value and the bucket.""" 52 | if quantized_dtype == jnp.float32: 53 | return fvalue, [], [] 54 | elif quantized_dtype == jnp.bfloat16: 55 | return fvalue.astype(jnp.bfloat16), [], [] 56 | 57 | float_dtype = fvalue.dtype 58 | if quantized_dtype == jnp.int8: 59 | # value -128 is not used. 60 | num_buckets = jnp.array(127.0, dtype=float_dtype) 61 | elif quantized_dtype == jnp.int16: 62 | # value -32768 is not used. 63 | num_buckets = jnp.array(32767.0, dtype=float_dtype) 64 | else: 65 | raise ValueError(f'Quantized dtype {quantized_dtype} not supported.') 66 | # max value is mapped to num_buckets 67 | 68 | if extract_diagonal and fvalue.ndim != 2: 69 | raise ValueError( 70 | f'Input array {fvalue} must be 2D to work with extract_diagonal.') 71 | 72 | diagonal_fvalue = [] 73 | if extract_diagonal: 74 | diagonal_fvalue = jnp.diag(fvalue) 75 | # Remove the diagonal entries. 76 | fvalue = fvalue - jnp.diag(diagonal_fvalue) 77 | 78 | # TODO(rohananil): Extend this by making use of information about the blocks 79 | # SM3 style which will be useful for diagonal statistics 80 | # We first decide the scale. 81 | if fvalue.ndim < 1: 82 | raise ValueError( 83 | f'Input array {fvalue} must have a strictly positive number of ' 84 | 'dimensions.') 85 | 86 | max_abs = jnp.max(jnp.abs(fvalue), axis=0) 87 | bucket_size = max_abs / num_buckets 88 | bs_expanded = bucket_size[jnp.newaxis, ...] 89 | # To avoid divide by 0.0 90 | bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded, 91 | jnp.ones_like(bs_expanded)) 92 | ratio = fvalue / bs_nonzero 93 | # We use rounding to remove bias. 
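# Illustrative numbers (not taken from any checked-in data): with int8
# quantization num_buckets is 127, so a column whose max_abs is 1.27 gets
# bucket_size 0.01; an entry of 0.635 yields ratio 63.5, rounds to 64, and
# dequantizes to 0.64 in to_float(), i.e. the error is at most bucket_size / 2.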
94 | quantized = jnp.round(ratio) 95 | return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size 96 | 97 | def to_float(self): 98 | """Returns the float value.""" 99 | if isinstance(self.quantized, list) and not self.quantized: 100 | return self.quantized 101 | 102 | if self.quantized_dtype == jnp.float32: 103 | return self.quantized 104 | 105 | if self.quantized_dtype == jnp.bfloat16: 106 | return self.quantized.astype(jnp.float32) 107 | 108 | float_dtype = self.bucket_size.dtype 109 | bucket_size = self.bucket_size[jnp.newaxis, ...] 110 | val = self.quantized.astype(float_dtype) * bucket_size 111 | if self.extract_diagonal: 112 | val += jnp.diag(self.diagonal) 113 | return val 114 | 115 | -------------------------------------------------------------------------------- /precondition/tearfree/reshaper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for momentum implementation.""" 16 | 17 | from typing import Sequence 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | from jax import numpy as jnp 23 | import numpy as np 24 | from precondition.tearfree import reshaper 25 | 26 | 27 | def _make_invalid_cases() -> Sequence[dict[str, ...]]: 28 | """Generate invalid cases which should throw.""" 29 | return [ 30 | { 31 | 'testcase_name': 'smallblock', 32 | 'invalid_options': reshaper.Options( 33 | block_size=1, 34 | ), 35 | }, 36 | { 37 | 'testcase_name': 'smallmerge', 38 | 'invalid_options': reshaper.Options( 39 | merge_dims=0, 40 | ), 41 | }, 42 | ] 43 | 44 | 45 | def _make_expected_shape_cases() -> Sequence[dict[str, ...]]: 46 | cases = [ 47 | {'in_shape': [4], 'merge': 2, 'block': 3, 'out_shape': [6]}, 48 | {'in_shape': [3], 'merge': 2, 'block': 3, 'out_shape': [3]}, 49 | {'in_shape': [1, 3, 1], 'merge': 2, 'block': 3, 'out_shape': [3]}, 50 | {'in_shape': [1, 3, 1], 'merge': 3, 'block': 3, 'out_shape': [3]}, 51 | {'in_shape': [1, 3, 1], 'merge': 3, 'block': 4, 'out_shape': [3]}, 52 | {'in_shape': [1, 3, 1, 2], 'merge': 2, 'block': 3, 'out_shape': [3, 2]}, 53 | {'in_shape': [4, 1, 5], 'merge': 2, 'block': 3, 'out_shape': [6, 6]}, 54 | {'in_shape': [1], 'merge': 2, 'block': 2, 'out_shape': []}, 55 | {'in_shape': [1, 1, 1], 'merge': 2, 'block': 2, 'out_shape': []}, 56 | {'in_shape': [1, 1, 1], 'merge': 2, 'block': 2, 'out_shape': []}, 57 | { 58 | 'in_shape': [3, 1, 5, 2, 2], 59 | 'merge': 4, 60 | 'block': 10, 61 | 'out_shape': [3, 5, 4], 62 | }, 63 | {'in_shape': [2, 3, 2], 'merge': 6, 'block': 10, 'out_shape': [6, 2]}, 64 | ] 65 | for case in cases[:]: 66 | if all(i <= case['block'] for i in case['in_shape']): 67 | block0 = case.copy() 68 | block0['block'] = 0 69 | cases.append(block0) 70 | return cases 71 | 72 | 73 | class ReshaperTest(parameterized.TestCase): 74 | """Basic test for shampoo implementation.""" 75 | 76 | 
@parameterized.named_parameters(_make_invalid_cases()) 77 | def test_invalid(self, invalid_options): 78 | with self.assertRaises(ValueError): 79 | reshaper.merge(invalid_options) 80 | 81 | @parameterized.parameters(_make_expected_shape_cases()) 82 | def test_expected_shape(self, in_shape, merge, block, out_shape): 83 | options = reshaper.Options(merge_dims=merge, block_size=block) 84 | init_fn, update_fn = reshaper.merge(options) 85 | init = jnp.zeros(in_shape) 86 | out, _ = update_fn(init, init_fn(None), init) 87 | self.assertSequenceEqual(out.shape, out_shape) 88 | 89 | @parameterized.parameters(_make_expected_shape_cases()) 90 | def test_inversion(self, in_shape, merge, block, out_shape): 91 | del out_shape 92 | options = reshaper.Options(merge_dims=merge, block_size=block) 93 | init_fn, update_fn = reshaper.merge(options) 94 | init = jax.random.normal(jax.random.PRNGKey(0), in_shape) 95 | out, _ = update_fn(init, init_fn(None), init) 96 | init_fn, update_fn = reshaper.unmerge(options) 97 | recover, _ = update_fn(out, init_fn(None), init) 98 | np.testing.assert_array_equal(init, recover) 99 | 100 | def test_tree(self): 101 | shapes = { 102 | 'w': [[{'b': (3, 2)}]], 103 | 'z': ( 104 | 1, 105 | 2, 106 | 1, 107 | ), 108 | } 109 | init = jax.tree.map( 110 | jnp.zeros, shapes, is_leaf=lambda x: isinstance(x, tuple) 111 | ) 112 | options = reshaper.Options(merge_dims=2, block_size=2) 113 | init_fn, update_fn = reshaper.merge(options) 114 | out, _ = update_fn(init, init_fn(None), init) 115 | out_shapes = jax.tree.map(lambda x: tuple(x.shape), out) 116 | expected_shapes = {'w': [[{'b': (4, 2)}]], 'z': (2,)} 117 | 118 | self.assertEqual(out_shapes, expected_shapes) 119 | 120 | 121 | if __name__ == '__main__': 122 | absltest.main() 123 | -------------------------------------------------------------------------------- /precondition/tearfree/reshaper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Parameter reshaping module.""" 16 | 17 | import dataclasses 18 | import functools 19 | 20 | import jax 21 | from jax import numpy as jnp 22 | import optax 23 | from precondition import distributed_shampoo 24 | 25 | 26 | @dataclasses.dataclass 27 | class Options: 28 | """Parameter reshaping options. 29 | 30 | Attributes: 31 | merge_dims: Collapse dimensions smaller than this number left-to-right, 32 | e.g., [3, 1, 5, 2, 2] becomes [3, 5, 4] with `merge_dims = 4`. Notice 33 | ordering, [2, 3, 2] becomes [6, 2] with `merge_dims = 6`, not its reverse. 34 | block_size: If nonzero, pads all dimensions larger than the block size to a 35 | multiple of the block size. 
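      For example (mirroring a case in reshaper_test.py), combined with
      merge_dims=2 a [4, 1, 5] parameter is first merged to [4, 5] and then
      padded to [6, 6] when block_size=3.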
36 | """ 37 | 38 | merge_dims: int = 1024 39 | block_size: int = 1024 40 | 41 | 42 | @dataclasses.dataclass 43 | class _Shapes: 44 | """Shape container.""" 45 | 46 | original_shape: list[int] 47 | merged_shape: list[int] 48 | padded_shape: list[int] 49 | 50 | 51 | def _derive_shapes(options: Options, param: jax.Array) -> _Shapes: 52 | """Derive desired shapes from options.""" 53 | merged = distributed_shampoo.merge_small_dims(param.shape, options.merge_dims) 54 | if merged == [1]: 55 | return _Shapes( 56 | original_shape=list(param.shape), 57 | merged_shape=[], 58 | padded_shape=[], 59 | ) 60 | if options.block_size == 0: 61 | padded = merged 62 | else: 63 | padded = [] 64 | for s in merged: 65 | if s >= options.block_size: 66 | s = (s + options.block_size - 1) // options.block_size 67 | s *= options.block_size 68 | padded.append(s) 69 | return _Shapes( 70 | original_shape=list(param.shape), 71 | merged_shape=merged, 72 | padded_shape=padded, 73 | ) 74 | 75 | 76 | def merge(options: Options) -> optax.GradientTransformation: 77 | """Merge and maybe pad gradients, leaving params alone.""" 78 | 79 | if options.merge_dims < 2: 80 | raise ValueError( 81 | 'merge_dims ({}) must be at least 2'.format(options.merge_dims) 82 | ) 83 | 84 | if options.block_size < 2 and options.block_size != 0: 85 | raise ValueError( 86 | 'block_size ({}) must be at least 2 (or 0 to disable)'.format( 87 | options.block_size 88 | ) 89 | ) 90 | 91 | def _merge(update: jax.Array, shapes: _Shapes) -> jax.Array: 92 | assert list(update.shape) == shapes.original_shape, (update.shape, shapes) 93 | merged = update.reshape(shapes.merged_shape) 94 | padding = [ 95 | (0, p - m) for p, m in zip(shapes.padded_shape, shapes.merged_shape) 96 | ] 97 | if padding and options.block_size > 0: 98 | return jnp.pad(merged, padding) 99 | return merged 100 | 101 | def update( 102 | updates: optax.Updates, 103 | state: optax.MaskedNode, 104 | params: optax.Params, 105 | ) -> tuple[optax.Updates, optax.MaskedNode]: 106 | shapes = jax.tree.map(functools.partial(_derive_shapes, options), params) 107 | new_updates = jax.tree.map(_merge, updates, shapes) 108 | return new_updates, state 109 | 110 | return optax.GradientTransformation(lambda _: optax.MaskedNode(), update) 111 | 112 | 113 | def unmerge(options: Options) -> optax.GradientTransformation: 114 | """Unmerge and unpad gradients, leaving params alone.""" 115 | 116 | def _unmerge(update: jax.Array, shapes: _Shapes) -> jax.Array: 117 | assert list(update.shape) == shapes.padded_shape, (update.shape, shapes) 118 | if options.block_size == 0: 119 | merged = update 120 | else: 121 | merged = update[tuple(slice(0, m) for m in shapes.merged_shape)] 122 | return merged.reshape(shapes.original_shape) 123 | 124 | def update( 125 | updates: optax.Updates, 126 | state: optax.MaskedNode, 127 | params: optax.Params, 128 | ) -> tuple[optax.Updates, optax.MaskedNode]: 129 | shapes = jax.tree.map(functools.partial(_derive_shapes, options), params) 130 | new_updates = jax.tree.map(_unmerge, updates, shapes) 131 | return new_updates, state 132 | 133 | return optax.GradientTransformation(lambda _: optax.MaskedNode(), update) 134 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/mtnt_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dataset builder for the MTNT dataset.""" 16 | 17 | import enum as Enum 18 | 19 | import jax.dlpack 20 | from precondition.datamix_gemma.dataset_builders import dataset_builder 21 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | 25 | 26 | class DatasetSplit(Enum.Enum): 27 | 28 | TRAIN = 'train' 29 | VALIDATION = 'validation' 30 | 31 | 32 | class MTNTDatasetBuilder(dataset_builder.DatasetBuilder): 33 | """Dataset builder for the MTNT dataset.""" 34 | 35 | N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811} 36 | 37 | BUFFER_SIZE_SHUFFLE = 10_000 38 | TRANSLATION_PREFIX = 'Translate this into French:\n' 39 | TRANSLATION_SUFFIX = '\n' 40 | 41 | def __init__( 42 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 43 | ): 44 | """Constructor. 45 | 46 | Args: 47 | tokenizer: Gemma tokenizer to use. 48 | max_seq_len: size of each sequence in a given batch. 49 | """ 50 | self._tokenizer = tokenizer 51 | self._base_data = { 52 | DatasetSplit.TRAIN: tfds.load('mtnt/en-fr', split='train'), 53 | DatasetSplit.VALIDATION: tfds.load('mtnt/en-fr', split='valid'), 54 | } 55 | self._max_seq_len = max_seq_len 56 | 57 | def _tokenize_source(self, example: tf.Tensor) -> tf.Tensor: 58 | """Tokenization function for the source.""" 59 | res = self._tokenizer.tokenize_tf_op( 60 | example, 61 | prefix=self.TRANSLATION_PREFIX, 62 | suffix=self.TRANSLATION_SUFFIX, 63 | add_eos=False, 64 | ) 65 | return res 66 | 67 | def _tokenize_destination(self, example: tf.Tensor): 68 | """Tokenization function for the French translation.""" 69 | return self._tokenizer.tokenize_tf_op(example, add_eos=True) 70 | 71 | def _to_training_input( 72 | self, 73 | src_tokens: jax.Array, 74 | dst_tokens: jax.Array, 75 | ): 76 | """Build a training input from a tuple of source and destination tokens.""" 77 | 78 | # The input sequence fed to the model is simply the concatenation of the 79 | # source and the destination. 80 | tokens = tf.concat([src_tokens, dst_tokens], axis=0) 81 | 82 | # To prevent the model from updating based on the source (input) 83 | # tokens, add a target mask to each input. 84 | q_mask = tf.zeros_like(src_tokens, dtype=tf.bool) 85 | a_mask = tf.ones_like(dst_tokens, dtype=tf.bool) 86 | mask = tf.concat([q_mask, a_mask], axis=0) 87 | 88 | # If the output tokens sequence is smaller than the target sequence size, 89 | # then pad it with pad tokens. 90 | tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id) 91 | 92 | # Don't want to perform the backward pass on the pad tokens. 
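    # Illustrative example (token counts assumed for exposition): for a
    # 3-token prompt and a 2-token translation with max_seq_len=8, the final
    # mask is [0, 0, 0, 1, 1, 0, 0, 0], so the loss sees only the translation
    # tokens and none of the padding appended below.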
93 | mask = self._pad_up_to_max_len(mask, False) 94 | return dataset_builder.TrainingInput( #type: ignore 95 | input_tokens=tokens, #type:ignore 96 | target_mask=mask, #type:ignore 97 | )# type: ignore 98 | 99 | def get_train_dataset(self, batch_size: int, num_epochs: int): 100 | """Build the training dataset.""" 101 | 102 | ds = self._base_data[DatasetSplit.TRAIN].map( 103 | lambda x: ( 104 | self._tokenize_source(x['src']), 105 | self._tokenize_destination(x['dst']), 106 | ) 107 | ) 108 | ds = ds.map(self._to_training_input) 109 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 110 | ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE) 111 | ds = ds.repeat(num_epochs) 112 | ds = ds.batch(batch_size, drop_remainder=True) 113 | return ds 114 | 115 | def get_validation_dataset(self, batch_size: int): 116 | """Build the validation dataset.""" 117 | 118 | ds = self._base_data[DatasetSplit.VALIDATION].map( 119 | lambda x: ( 120 | self._tokenize_source(x['src']), 121 | self._tokenize_destination(x['dst']), 122 | ) 123 | ) 124 | ds = ds.map(self._to_training_input) 125 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 126 | ds = ds.batch(batch_size, drop_remainder=True) 127 | return ds 128 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/orca_math_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Dataset builder for the Orca Math dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from absl import logging 20 | import jax.dlpack 21 | from precondition.datamix_gemma.dataset_builders import dataset_builder 22 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 23 | import tensorflow as tf 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class OrcaMathDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the Orca Math dataset.""" 33 | 34 | N_ITEMS = {DatasetSplit.TRAIN: 200035} 35 | 36 | #BUFFER_SIZE_SHUFFLE = 10_000 37 | BUFFER_SIZE_SHUFFLE = 100 38 | QUESTION_PREFIX = 'Question: \n' 39 | QUESTION_SUFFIX = '\n' 40 | #TRANSLATION_PREFIX = 'Translate this into French:\n' 41 | #TRANSLATION_SUFFIX = '\n' 42 | 43 | def __init__( 44 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 45 | ): 46 | """Constructor. 47 | 48 | Args: 49 | tokenizer: Gemma tokenizer to use. 50 | max_seq_len: size of each sequence in a given batch. 
51 | """ 52 | self._tokenizer = tokenizer 53 | self._base_data = { 54 | DatasetSplit.TRAIN: tfds.load( 55 | 'huggingface:microsoft__orca_math_word_problems_200k', split='train' 56 | ), 57 | } 58 | logging.info( 59 | 'orca math size: %s', 60 | self._base_data[DatasetSplit.TRAIN].cardinality().numpy(), 61 | ) 62 | self._max_seq_len = max_seq_len 63 | 64 | def _tokenize_question(self, example: tf.Tensor): 65 | """Tokenization function for the Question.""" 66 | return self._tokenizer.tokenize_tf_op( 67 | example, 68 | prefix=self.QUESTION_PREFIX, 69 | suffix=self.QUESTION_SUFFIX, 70 | add_eos=False, 71 | ) 72 | 73 | def _tokenize_response(self, example: tf.Tensor): 74 | """Tokenization function for the Response.""" 75 | return self._tokenizer.tokenize_tf_op( 76 | example, 77 | add_eos=True, 78 | ) 79 | 80 | def _to_training_input( 81 | self, 82 | question_tokens: jax.Array, 83 | answer_tokens: jax.Array, 84 | ): 85 | """Build a training input from a tuple of source and destination tokens.""" 86 | 87 | # The input sequence fed to the model is simply the concatenation of the 88 | # source and the destination. 89 | tokens = tf.concat( 90 | [question_tokens, answer_tokens], axis=0 91 | ) 92 | 93 | # To prevent the model from updating based on the source (input) 94 | # tokens, add a target mask to each input. 95 | question_mask = tf.zeros_like(question_tokens, dtype=tf.bool) 96 | answer_mask = tf.ones_like(answer_tokens, dtype=tf.bool) 97 | mask = tf.concat([question_mask, answer_mask], axis=0) 98 | 99 | # If the output tokens sequence is smaller than the target sequence size, 100 | # then pad it with pad tokens. 101 | tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id) 102 | 103 | # Don't want to perform the backward pass on the pad tokens. 104 | mask = self._pad_up_to_max_len(mask, False) 105 | return dataset_builder.TrainingInput( #type: ignore 106 | input_tokens=tokens, #type:ignore 107 | target_mask=mask, #type:ignore 108 | )# type: ignore 109 | 110 | def get_train_dataset(self, batch_size: int, num_epochs: int): 111 | """Build the training dataset.""" 112 | 113 | ds = self._base_data[DatasetSplit.TRAIN].map( 114 | lambda x: ( 115 | self._tokenize_question(x['question']), 116 | self._tokenize_response(x['answer']) 117 | ) 118 | ) 119 | ds = ds.map(self._to_training_input) 120 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 121 | ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE) 122 | #ds = ds.repeat(num_epochs) 123 | #ds = ds.batch(batch_size, drop_remainder=True) 124 | return ds 125 | 126 | def get_validation_dataset(self, batch_size: int): 127 | """Build the validation dataset.""" 128 | 129 | ds = self._base_data[DatasetSplit.TRAIN].map( 130 | lambda x: ( 131 | self._tokenize_question(x['question']), 132 | self._tokenize_response(x['answer']) 133 | ) 134 | ) 135 | ds = ds.map(self._to_training_input) 136 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 137 | return ds 138 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/snr_calculation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | SEQ_SIZE = 1000 18 | VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"} 19 | CKPT_PATH = '/tfhub/prod/g_mini/2b_it_v1p1_orbax/1' 20 | VOCAB_PATH = '/home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model' 21 | 22 | #def estimate_gradient(weights, delta, rng, params, train_obj, training_batch_generator_obj, eval_obj, writer): 23 | # cands = bandit_loop._generate_gaussian_candidates(weights, rng, delta=delta) 24 | # gradient_discount_factor = training_batch_generator_obj.prepare_for_training( 25 | # cands[0], cands[1] 26 | # ) 27 | # training_operations = [] 28 | # for cand_it in range(len(cands)): 29 | # cur_params = copy.deepcopy(params) 30 | # trained_params = train_obj.train_loop( 31 | # params={'params': cur_params}, 32 | # get_next_batch_fn=functools.partial( 33 | # training_batch_generator_obj.get_next_batch, index=cand_it 34 | # ), 35 | # ) 36 | # trained_params = jax.tree_util.tree_map( 37 | # lambda arr: jax.device_put( 38 | # arr, jax.local_devices(backend='cpu')[0] 39 | # ), 40 | # trained_params, 41 | # ) 42 | # training_operations.append(trained_params) 43 | # logging.info('Done training!') 44 | # scores = [] 45 | # for trained_params in training_operations: 46 | # trained_params = jax.device_get(trained_params) 47 | # scores.append( 48 | # eval_obj.evaluate(trained_params['params']) 49 | # ) 50 | # logging.info('[SCORES]: %s', scores) 51 | # for i in range(weights.shape[0]): 52 | # writer.write({'weights_' + str(i): weights[i]}) 53 | # writer.write({'average_score': (scores[0] + scores[1]) / 2.0}) 54 | # writer.write({'score_1': scores[0]}) 55 | # writer.write({'score_2': scores[1]}) 56 | # #grad = bandit_loop._compute_gradient_random_sign(*zip(cands, scores)) * gradient_discount_factor 57 | # grad = bandit_loop._compute_gradient(cands, delta, scores) 58 | # logging.info('[GRAD]: %s', grad) 59 | # return grad 60 | # 61 | #def snr_calculation(): 62 | # model_2b, tokenizer, vocab, params = finetune_utils.setup_model() 63 | # 64 | # params = jax.tree_util.tree_map( 65 | # lambda arr: jax.device_put( 66 | # arr, jax.local_devices(backend='cpu')[0] 67 | # ), 68 | # params, 69 | # ) 70 | # 71 | # train_batch_size = jax.local_device_count() 72 | # logging.info('train_batch_size: %s', train_batch_size) 73 | # mmlu_eval_batch_size = 8192 74 | # mbpp_eval_batch_size = 1024 75 | # 76 | # 77 | # training_cfg = TrainingConfig( 78 | # learning_rate=1e-4, 79 | # batch_size=train_batch_size, 80 | # ) 81 | # 82 | # train_loop_obj = TrainingLoop( 83 | # model=model_2b, 84 | # pad_id=tokenizer.pad_id, 85 | # training_cfg=training_cfg, 86 | # num_training_steps=100, 87 | # ) 88 | # 89 | # all_dataset_builders = finetune_utils.get_dataset_builders(tokenizer, [0, 2]) 90 | # importance_weighting_training_batch_generator_obj = importance_weighting_training_batch_generator.ImportanceWeightingTrainingBatchGenerator( 91 | # train_ds_builders=all_dataset_builders, 92 | # batch_size=jax.device_count(), 93 | # ) 94 | # 95 | # gsm8k_eval_obj = GSM8KEval( 96 | # model=model_2b, 97 | # tokenizer=tokenizer, 98 | # 
vocab=vocab, 99 | # eval_batch_size=mmlu_eval_batch_size 100 | # ) 101 | # 102 | # rng = np.random.default_rng(seed=0) 103 | # data_id = xdata.get_auto_data_id() 104 | # writer = xdata.bt.writer(data_id, 'scores') 105 | # num_iterations = 100 106 | # running_outer_sum = np.zeros((2, 2)) 107 | # state.running_sum = np.zeros(2) 108 | # state.running_outer_sum = np.zeros((2,2)) 109 | # for i in range(num_iterations): 110 | # ckpt.restore_or_save() 111 | # grad_estimate = estimate_gradient( 112 | # weights=np.array([0.5, 0.5]), 113 | # delta=0.0000001, 114 | # rng=rng, 115 | # params=params, 116 | # train_obj=train_loop_obj, 117 | # training_batch_generator_obj=importance_weighting_training_batch_generator_obj, 118 | # eval_obj=gsm8k_eval_obj, 119 | # writer=writer, 120 | # ) 121 | # grad_estimate = np.array(grad_estimate) 122 | # state.running_sum += grad_estimate 123 | # running_avg = state.running_sum / (i + 1) 124 | # state.running_outer_sum += np.outer(grad_estimate, grad_estimate) 125 | # running_cov_avg = (state.running_outer_sum/(i+1) - np.outer(running_avg, running_avg)) 126 | # #writer.write({'running_avg': running_avg, 'running_cov_avg': running_cov_avg, 'trace of running_cov_avg': np.trace(running_cov_avg)}) 127 | # if i > 10: 128 | # writer.write({'running_avg_0': running_avg[0]}) 129 | # writer.write({'running_avg_1': running_avg[1]}) 130 | # writer.write({'running_cov_avg_00': running_cov_avg[0, 0]}) 131 | # writer.write({'running_cov_avg_11': running_cov_avg[1, 1]}) 132 | # writer.write({'running grads SNR': (np.linalg.norm(running_avg) ** 2) / np.trace(running_cov_avg)}) 133 | # ckpt.save() 134 | # 135 | # writer.close() 136 | # -------------------------------------------------------------------------------- /precondition/datamix_gemma/training_batch_generators/dartboard_importance_weighting_training_batch_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
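# Overview of the scheme implemented below (a reading of the code, stated as a
# sketch rather than a specification): the generator pre-draws
# batch_size * num_iterations examples from the uniform mixture; on each
# prepare_for_training call it keeps the example in slot i (drawn from dataset
# d) with probability min(1, new_weight[d] / previous_avg_weight[d]) and
# otherwise redraws the slot from the updated average mixture. Per-example
# importance factors are weights_list[k][d] / avg_weight[d], renormalized to
# mean 1 over the run. Illustrative example (numbers assumed): halving a
# dataset's weight between rounds resamples roughly half of its slots.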
14 | 15 | """DartboardImportanceWeightingTrainingBatchGenerator.""" 16 | import copy 17 | import itertools 18 | from absl import logging 19 | 20 | import tensorflow_datasets as tfds 21 | import numpy as np 22 | from precondition.datamix_gemma.training_batch_generators import training_batch_generator 23 | 24 | 25 | class DartboardImportanceWeightingTrainingBatchGenerator( 26 | training_batch_generator.TrainingBatchGenerator 27 | ): 28 | def __init__(self, train_ds_builders, batch_size, num_weights=2, num_iterations=100): 29 | super().__init__(train_ds_builders, batch_size, num_weights, num_iterations) 30 | self.training_iters = [] 31 | for dataset_builder_obj in self.train_ds_builders: 32 | cur_iter = iter( 33 | tfds.as_numpy( 34 | dataset_builder_obj.get_train_dataset( 35 | batch_size=batch_size, num_epochs=1 36 | ) 37 | ) 38 | ) 39 | self.training_iters.append(cur_iter) 40 | self.num_datasets = len(self.train_ds_builders) 41 | self.sample_choices = np.random.choice( 42 | self.num_datasets, 43 | size=self.batch_size * num_iterations, 44 | p=np.ones(self.num_datasets)/self.num_datasets, 45 | ) 46 | self.examples = [] 47 | for i in range(self.batch_size * num_iterations): 48 | try: 49 | self.examples.append(next(self.training_iters[self.sample_choices[i]])) 50 | except StopIteration: 51 | self.training_iters[self.sample_choices[i]] = iter( 52 | tfds.as_numpy( 53 | self.train_ds_builders[self.sample_choices[i]].get_train_dataset( 54 | batch_size=self.batch_size, num_epochs=1 55 | ) 56 | ) 57 | ) 58 | self.examples.append(next(self.training_iters[self.sample_choices[i]])) 59 | self.unnormalized_weights = np.ones(self.num_datasets)/self.num_datasets 60 | self.avg_weights = np.ones(self.num_datasets)/self.num_datasets 61 | self.weights_list = [np.ones(self.num_datasets)/self.num_datasets for _ in range(num_weights)] 62 | self.indices = [] 63 | self.factors = [] 64 | 65 | def prepare_for_training(self, weights_list, new_unnormalized_weights): 66 | """Prepare for training.""" 67 | self.indices = [0 for _ in range(self.num_weights)] 68 | logging.info(f'new_unnormalized_weights: {new_unnormalized_weights}') 69 | logging.info(f'avg_weights: {self.avg_weights}') 70 | nochange_prob = new_unnormalized_weights / self.avg_weights 71 | nochange_prob = np.minimum(nochange_prob, 1) 72 | logging.info(f'nochange_prob: {nochange_prob}') 73 | self.unnormalized_weights = new_unnormalized_weights 74 | self.weights_list = weights_list 75 | self.avg_weights = np.zeros(len(self.weights_list[0])) 76 | for i in range(len(self.weights_list)): 77 | self.avg_weights += self.weights_list[i] 78 | self.avg_weights /= len(self.weights_list) 79 | for i in range(self.batch_size * self.num_iterations): 80 | change = np.random.choice(2, p=[nochange_prob[self.sample_choices[i]], 1-nochange_prob[self.sample_choices[i]]]) 81 | if change: 82 | self.sample_choices[i] = np.random.choice( 83 | len(self.avg_weights), 84 | p=self.avg_weights, 85 | ) 86 | try: 87 | self.examples[i] = next(self.training_iters[self.sample_choices[i]]) 88 | except StopIteration: 89 | self.training_iters[self.sample_choices[i]] = iter( 90 | tfds.as_numpy( 91 | self.train_ds_builders[self.sample_choices[i]].get_train_dataset( 92 | batch_size=self.batch_size, num_epochs=1 93 | ) 94 | ) 95 | ) 96 | self.examples[i] = next(self.training_iters[self.sample_choices[i]]) 97 | self.factors = [np.zeros(self.batch_size * self.num_iterations) for _ in range(self.num_weights)] 98 | for i in range(self.num_weights): 99 | for j in range(self.batch_size * 
self.num_iterations): 100 | self.factors[i][j] = self.weights_list[i][self.sample_choices[j]] / self.avg_weights[self.sample_choices[j]] 101 | self.factors[i] = (self.factors[i] / np.sum(self.factors[i])) * len(self.factors[i]) 102 | return 1 103 | 104 | def get_next_batch(self, index): 105 | logging.info(f'Getting next batch, batch_size={self.batch_size}, weights_list={self.weights_list}') 106 | cur_factors = self.factors[index][self.indices[index]:(self.indices[index]+self.batch_size)] 107 | cur_examples = self.examples[self.indices[index]:(self.indices[index]+self.batch_size)] 108 | cur_input_tokens_batch = np.asarray([[example.input_tokens] for example in cur_examples]) 109 | cur_input_mask_batch = np.asarray([[example.target_mask] for example in cur_examples]) 110 | self.indices[index] += self.batch_size 111 | logging.info(f'cur_factors: {cur_factors}') 112 | return cur_factors, cur_input_tokens_batch, cur_input_mask_batch 113 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/training_batch_generators/dartboard_deterministic_training_batch_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
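# Deterministic variant: it reuses the same keep-or-redraw resampling as the
# importance-weighting generator above, but get_next_batch returns unit
# factors over the shared example stream, while get_next_batch_special
# (presumably used for per-dataset gradient perturbations) returns batches
# drawn only from dataset `index`, scaled by delta / sqrt(batch_size), padding
# short batches by repeating earlier examples.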
14 | 15 | import copy 16 | import itertools 17 | from typing import cast 18 | from absl import logging 19 | 20 | import tensorflow_datasets as tfds 21 | import numpy as np 22 | 23 | class DartboardDeterministicTrainingBatchGenerator: 24 | def __init__(self, train_ds_builders, batch_size, num_weights=2, num_iterations=100): 25 | self.train_ds_builders = train_ds_builders 26 | self.batch_size = batch_size 27 | self.num_weights = num_weights 28 | self.num_iterations = num_iterations 29 | self.training_iters = [] 30 | for dataset_builder_obj in self.train_ds_builders: 31 | cur_iter = iter( 32 | tfds.as_numpy( 33 | dataset_builder_obj.get_train_dataset( 34 | batch_size=batch_size, num_epochs=1 35 | ) 36 | ) 37 | ) 38 | self.training_iters.append(cur_iter) 39 | self.num_datasets = len(self.train_ds_builders) 40 | self.sample_choices = np.random.choice( 41 | self.num_datasets, 42 | size=self.batch_size * num_iterations, 43 | p=np.ones(self.num_datasets)/self.num_datasets, 44 | ) 45 | self.examples = [] 46 | for i in range(self.batch_size * num_iterations): 47 | try: 48 | self.examples.append(next(self.training_iters[self.sample_choices[i]])) 49 | except StopIteration: 50 | self.training_iters[self.sample_choices[i]] = iter( 51 | tfds.as_numpy( 52 | self.train_ds_builders[self.sample_choices[i]].get_train_dataset( 53 | batch_size=self.batch_size, num_epochs=1 54 | ) 55 | ) 56 | ) 57 | self.examples.append(next(self.training_iters[self.sample_choices[i]])) 58 | self.unnormalized_weights = np.ones(self.num_datasets)/self.num_datasets 59 | self.avg_weights = np.ones(self.num_datasets)/self.num_datasets 60 | self.weights_list = [np.ones(self.num_datasets)/self.num_datasets for _ in range(num_weights)] 61 | self.avg_index = 0 62 | self.indices = [] 63 | 64 | def prepare_for_training(self, avg_weights, new_unnormalized_weights): 65 | """Prepare for training.""" 66 | self.indices = [0 for _ in range(self.num_weights)] 67 | logging.info(f'new_unnormalized_weights: {new_unnormalized_weights}') 68 | logging.info(f'avg_weights: {self.avg_weights}') 69 | nochange_prob = new_unnormalized_weights / self.avg_weights 70 | nochange_prob = np.minimum(nochange_prob, 1) 71 | logging.info(f'nochange_prob: {nochange_prob}') 72 | self.unnormalized_weights = new_unnormalized_weights 73 | self.avg_weights = avg_weights 74 | for i in range(self.batch_size * self.num_iterations): 75 | change = np.random.choice(2, p=[nochange_prob[self.sample_choices[i]], 1-nochange_prob[self.sample_choices[i]]]) 76 | if change: 77 | self.sample_choices[i] = np.random.choice( 78 | len(self.avg_weights), 79 | p=self.avg_weights, 80 | ) 81 | try: 82 | self.examples[i] = next(self.training_iters[self.sample_choices[i]]) 83 | except StopIteration: 84 | self.training_iters[self.sample_choices[i]] = iter( 85 | tfds.as_numpy( 86 | self.train_ds_builders[self.sample_choices[i]].get_train_dataset( 87 | batch_size=self.batch_size, num_epochs=1 88 | ) 89 | ) 90 | ) 91 | self.examples[i] = next(self.training_iters[self.sample_choices[i]]) 92 | return 1 93 | 94 | def get_next_batch(self): 95 | logging.info(f'Getting next batch, batch_size={self.batch_size}, weights_list={self.weights_list}') 96 | factors = np.ones(self.batch_size) 97 | cur_examples = self.examples[self.avg_index:(self.avg_index+self.batch_size)] 98 | cur_input_tokens_batch = np.asarray([[example.input_tokens] for example in cur_examples]) 99 | cur_input_mask_batch = np.asarray([[example.target_mask] for example in cur_examples]) 100 | self.avg_index += self.batch_size 101 | return 
factors, cur_input_tokens_batch, cur_input_mask_batch 102 | 103 | def get_next_batch_special(self, index, delta): 104 | logging.info(f'Getting next batch, batch_size={self.batch_size}, weights_list={self.weights_list}') 105 | cur_batch = [] 106 | while len(cur_batch) < self.batch_size and self.indices[index] < len(self.examples): 107 | if self.sample_choices[self.indices[index]] == index: 108 | cur_batch.append(self.examples[self.indices[index]]) 109 | self.indices[index] += 1 110 | if len(cur_batch) == 0: 111 | return False 112 | cur_ind = 0 113 | while(len(cur_batch) < self.batch_size): 114 | cur_batch.append(cur_batch[cur_ind]) 115 | cur_ind += 1 116 | cur_input_tokens_batch = np.asarray([[example.input_tokens] for example in cur_batch]) 117 | cur_input_mask_batch = np.asarray([[example.target_mask] for example in cur_batch]) 118 | 119 | factors = np.ones(self.batch_size) * delta/np.sqrt(self.batch_size) 120 | return factors, cur_input_tokens_batch, cur_input_mask_batch 121 | -------------------------------------------------------------------------------- /precondition/tearfree/momentum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Momentum configuration and transform.""" 16 | 17 | import copy 18 | import dataclasses 19 | from typing import Union 20 | 21 | import jax 22 | import optax 23 | from precondition.tearfree import praxis_shim 24 | 25 | 26 | @dataclasses.dataclass 27 | class Options: 28 | """Configuration dataclass for momentum. 29 | 30 | Notably, this class contains weight decay parameters. Why? 31 | 32 | In classical convex literature, Nesterov acceleration applied to gradient 33 | descent can be viewed as "revising" the last iterate's momentum based on 34 | the gradient we observe immediately after taking a momentum "gamble" 35 | (see viz, https://stats.stackexchange.com/a/191727). 36 | 37 | To maintain this interpretation exactly, we would need to go against 38 | the grain on how weight decay is implemented. Momentum must be the last* 39 | gradient transformation applied to the iterate, which would require the 40 | weight decay to be applied to the update before it's used to change 41 | the velocity (momentum's state, the first moment). 42 | 43 | In particular, AdamW and Adafactor suggest direct weight downscaling, 44 | excluding weight decay from the velocity accumulation. 45 | 46 | As a result, the true meaning of Nesterov acceleration here is better 47 | understood literally, described in its parameter doc. 48 | 49 | *Technically, some optimizers include the learning rate in the update used to 50 | update the velocity (e.g., Adafactor), but others apply the learning rate 51 | scaling last, after momentum (e.g., Adam). 
We can recover the former from the 52 | latter by dividing the decay by the root of the learning rate, so this 53 | particular "gradient transformation" shouldn't be viewed as affecting 54 | the Nesterov interpretation, up to tuning constants. 55 | 56 | Attributes: 57 | ema: If true, momentum is computed as an exponential moving 58 | average: `velocity(t+1) = decay * velocity(t) + (1 - decay) * update(t)` 59 | If false, then uses "trace" accumulation for momentum: 60 | `velocity(t+1) = decay * velocity(t) + update(t)`. Note that if the 61 | updates were the same (they aren't) then these would be the same up to a 62 | factor of `(1 - decay)`. This corresponds to distributed_shampoo argument 63 | `moving_average_for_momentum`. 64 | nesterov: Toggle for Nesterov acceleration. If false, then the new 65 | update `update'(t+1)` simply equals `velocity(t+1)`. If true, then 66 | `update'(t+1) = maybe_decay * update(t) + decay * velocity(t+1)`, where 67 | `maybe_decay` is `(1 - decay)` if `ema` and 1 otherwise. 68 | momentum_decay: The decay referred to in `ema` and `nesterov` formulas. 69 | weight_decay: Add `weight_decay * x(t)` to the `update(t)` value, where 70 | `x(t)` is the value of the current parameters. 71 | weight_decay_after_momentum: Whether weight decay addition is performed 72 | after the momentum transformation. 73 | """ 74 | 75 | ema: bool = False 76 | nesterov: bool = True 77 | momentum_decay: float = 0.9 78 | weight_decay: float = 0.0 79 | weight_decay_after_momentum: bool = True 80 | 81 | 82 | State = Union[optax.MaskedNode, optax.TraceState] 83 | 84 | 85 | def apply(options: Options) -> praxis_shim.ShardedGradientTransformation: 86 | """Generate the momentum update from options.""" 87 | _validate(options) 88 | 89 | momentum_transforms = [] 90 | if options.momentum_decay: 91 | if options.ema: 92 | momentum_transforms.append(optax.scale(1 - options.momentum_decay)) 93 | momentum_transforms.append( 94 | _sharded_trace(options.momentum_decay, options.nesterov) 95 | ) 96 | 97 | wd_transforms = [optax.add_decayed_weights(options.weight_decay)] * ( 98 | options.weight_decay > 0.0 99 | ) 100 | 101 | if options.weight_decay_after_momentum: 102 | transforms = momentum_transforms + wd_transforms 103 | else: 104 | transforms = wd_transforms + momentum_transforms 105 | 106 | return praxis_shim.sharded_chain(*transforms) 107 | 108 | 109 | def _validate(options: Options): 110 | """Raise ValueError if options are invalid.""" 111 | if not (0 <= options.momentum_decay <= 1): 112 | raise ValueError( 113 | 'momentum_decay ({}) must be in [0, 1]'.format(options.momentum_decay) 114 | ) 115 | 116 | if not (options.weight_decay >= 0): 117 | raise ValueError( 118 | 'weight_decay ({}) must be >= 0'.format(options.weight_decay) 119 | ) 120 | 121 | 122 | def _sharded_trace( 123 | momentum: float, nesterov: bool 124 | ) -> praxis_shim.ShardedGradientTransformation: 125 | """Extend optax's trace to allow sharding.""" 126 | trace = optax.trace(momentum, nesterov) 127 | 128 | def init_pspec_fn(mdl_params): 129 | def _opt_state_sharding_spec(var_hparams): 130 | s_var_hparams = copy.deepcopy(var_hparams) 131 | s_var_hparams.init = None 132 | return s_var_hparams 133 | 134 | mdl_sharding = jax.tree.map(_opt_state_sharding_spec, mdl_params) 135 | return optax.TraceState(trace=mdl_sharding) 136 | 137 | return praxis_shim.ShardedGradientTransformation( 138 | trace.init, trace.update, init_pspec_fn 139 | ) 140 | --------------------------------------------------------------------------------
/precondition/tearfree/optimizer_smoke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Smoke tests for tearfree. 16 | 17 | The smoke test uses CPU-based sharding to verify that, under a variety of 18 | settings, (1) the optimizer results in finite, not-nan gradients and (2) 19 | distributed computation options don't change the math. 20 | """ 21 | 22 | import copy 23 | from typing import Sequence, Union 24 | 25 | from absl.testing import absltest 26 | from absl.testing import parameterized 27 | import chex 28 | import jax 29 | from jax import numpy as jnp 30 | import numpy as np 31 | import optax 32 | from precondition.tearfree import grafting 33 | from precondition.tearfree import momentum 34 | from precondition.tearfree import optimizer 35 | from precondition.tearfree import second_order 36 | from precondition.tearfree import shampoo 37 | from precondition.tearfree import sketchy 38 | 39 | 40 | def _make_distributed_equality_cases() -> list[dict[str, ...]]: 41 | """Make test cases of options for optimizer checks.""" 42 | cases = [] 43 | 44 | # Basic options exercise all of shampoo, grafting after the first step. 
45 | basic_options = optimizer.TearfreeOptions( 46 | grafting_options=grafting.Options( 47 | grafting_type=grafting.GraftingType.RMSPROP, 48 | second_moment_decay=0.9, 49 | epsilon=1e-5, 50 | start_preconditioning_step=1, 51 | skip_preconditioning_any_dim_gt=4096, 52 | skip_preconditioning_rank1=False, 53 | ), 54 | second_order_options=second_order.Options( 55 | second_order_type=second_order.SecondOrderType.SHAMPOO, 56 | shampoo_options=shampoo.Options( 57 | block_size=1024, 58 | update_preconditioners_freq=1, 59 | update_statistics_freq=1, 60 | second_moment_decay=0.9, 61 | ), 62 | merge_dims=4096, 63 | ), 64 | momentum_options=momentum.Options( 65 | ema=True, 66 | nesterov=True, 67 | momentum_decay=0.5, 68 | weight_decay=0.0, 69 | weight_decay_after_momentum=True, 70 | ), 71 | ) 72 | 73 | basic_case = { 74 | 'testcase_name': 'basic', 75 | 'nsteps': 3, 76 | 'options': basic_options, 77 | 'lr': 0.1, 78 | 'shape': (4,), 79 | } 80 | cases.append(basic_case) 81 | 82 | case = copy.deepcopy(basic_case) 83 | case['lr'] = lambda x: 0.1 / (x + 1) 84 | case['testcase_name'] = 'schedule' 85 | cases.append(case) 86 | 87 | case = copy.deepcopy(basic_case) 88 | second_order_options = case['options'].second_order_options 89 | second_order_options.second_order_type = second_order.SecondOrderType.SKETCHY 90 | second_order_options.shampoo_options = None 91 | second_order_options.sketchy_options = sketchy.Options() 92 | case['testcase_name'] = 'sketchy' 93 | cases.append(case) 94 | 95 | case = copy.deepcopy(case) 96 | case['testcase_name'] += '_notrunc_lowrank' 97 | sketchy_options = case['options'].second_order_options.sketchy_options 98 | sketchy_options.truncate_numerical_noise = False 99 | sketchy_options.rank = 2 100 | cases.append(case) 101 | 102 | case = copy.deepcopy(basic_case) 103 | case['options'].grafting_options.grafting_type = ( 104 | grafting.GraftingType.ADAFACTOR 105 | ) 106 | case['testcase_name'] = 'adafactor' 107 | cases.append(case) 108 | 109 | # Need to test we at least parallelize the identical-to-tensor shapes 110 | # without any blocks. 
111 | # Additional variants: 112 | # wd 113 | # wd with decay before momentum 114 | # grid of nesterov/ema 115 | # exercise merge dims 2d doing a merge 116 | # exercise merge dims 3d with only one thing merged 117 | # skip preconditioning any dim gt activating 118 | # skip preconditioning any dim gt rank1 activating 119 | # update stats/precond every 2 (6 steps) 120 | # update stats/precond every 2/4 (6 steps) 121 | 122 | # Test block-wise parallelism for Shampoo 123 | 124 | return cases 125 | 126 | 127 | class OptimizerSmokeTest(parameterized.TestCase): 128 | """Basic test for optimizer configurations.""" 129 | 130 | def _unroll(self, options, shape, transform=None, lr=0.1, n=4): 131 | """Generate states and grad updates n times.""" 132 | rng = jax.random.PRNGKey(0) 133 | params = jnp.zeros(shape) 134 | grads = jax.random.normal(rng, (n, *shape)) 135 | 136 | if transform is not None: 137 | params = transform(params) 138 | grads = jnp.stack([transform(g) for g in grads]) 139 | 140 | tx = optimizer.tearfree(lr, options) 141 | 142 | init = tx.init(params) 143 | 144 | def reduce(state, grad): 145 | new_grad, new_state = tx.update(grad, state, params) 146 | return new_state, new_grad 147 | 148 | _, out_grads = jax.lax.scan(reduce, init, grads) 149 | return out_grads 150 | 151 | @parameterized.named_parameters(_make_distributed_equality_cases()) 152 | def test_distributed_equality( 153 | self, 154 | options: optimizer.TearfreeOptions, 155 | shape: Sequence[int], 156 | lr: Union[float, optax.Schedule], 157 | nsteps: int, 158 | ) -> None: 159 | single_core = self._unroll(options, shape, lr=lr, n=nsteps) 160 | multi_core = self._unroll(options, shape, lr=lr, n=nsteps) 161 | 162 | chex.assert_tree_all_finite(single_core) 163 | np.testing.assert_allclose(single_core, multi_core) 164 | 165 | 166 | if __name__ == '__main__': 167 | absltest.main() 168 | -------------------------------------------------------------------------------- /precondition/tearfree/momentum_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
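# Like the smoke test above, the tests in this file unroll a gradient
# transformation with jax.lax.scan: tx.init builds the state once, then scan
# threads it through tx.update over a stack of random gradients, collecting
# the transformed updates (and, where needed, the intermediate states).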
14 | 15 | """Tests for momentum implementation.""" 16 | 17 | import itertools 18 | from typing import Sequence 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import jax 23 | from jax import numpy as jnp 24 | import numpy as np 25 | import optax 26 | from precondition.tearfree import momentum 27 | 28 | jax.config.update('jax_threefry_partitionable', False) 29 | 30 | 31 | def _make_no_state_cases() -> Sequence[dict[str, ...]]: 32 | bools = [False, True] 33 | cases = [] 34 | for ema, nesterov, wd, wd_after in itertools.product( 35 | bools, bools, [0.0, 0.9], bools 36 | ): 37 | momentum_decay = 0.0 38 | options = momentum.Options( 39 | ema, 40 | nesterov, 41 | momentum_decay, 42 | wd, 43 | wd_after, 44 | ) 45 | cases.append({'options': options}) 46 | return cases 47 | 48 | 49 | def _make_invalid_cases() -> Sequence[dict[str, ...]]: 50 | """Generate invalid cases which should throw.""" 51 | return [ 52 | { 53 | 'testcase_name': 'momentum_neg', 54 | 'invalid_options': momentum.Options( 55 | momentum_decay=-1.0, 56 | ), 57 | }, 58 | { 59 | 'testcase_name': 'wd_neg', 60 | 'invalid_options': momentum.Options( 61 | weight_decay=-0.1, 62 | ), 63 | }, 64 | { 65 | 'testcase_name': 'momentum_large', 66 | 'invalid_options': momentum.Options( 67 | momentum_decay=1.1, 68 | ), 69 | }, 70 | ] 71 | 72 | 73 | class MomentumTest(parameterized.TestCase): 74 | """Basic test for momentum implementation.""" 75 | 76 | def _unroll(self, tx, n, extract=False, wd=0): 77 | """Generate states and grad updates n times.""" 78 | rng = jax.random.PRNGKey(0) 79 | params = jnp.ones((3,)) 80 | grads = jax.random.normal(rng, (n, 3)) + wd * params 81 | init = tx.init(params) 82 | 83 | def scan(state, grad): 84 | new_grad, new_state = tx.update(grad, state, params) 85 | return new_state, (new_state, new_grad) 86 | 87 | _, (states, out_grad) = jax.lax.scan(scan, init, grads) 88 | if not extract: 89 | return out_grad 90 | return self._extract_velocity(states), out_grad, grads 91 | 92 | def _check_equal(self, expected_tx, actual_tx, nsteps): 93 | expected_grads = self._unroll(expected_tx, nsteps) 94 | actual_grads = self._unroll(actual_tx, nsteps) 95 | np.testing.assert_allclose(expected_grads, actual_grads) 96 | 97 | @parameterized.parameters(0.1, 0.9, 0.99) 98 | def test_ema(self, decay): 99 | """Check that we simulate ema decay.""" 100 | options = momentum.Options(ema=True, nesterov=False, momentum_decay=decay) 101 | nsteps = 4 102 | actual = momentum.apply(options) 103 | expected = optax.ema(decay, debias=False) 104 | self._check_equal(expected, actual, nsteps) 105 | 106 | def _extract_velocity(self, state): 107 | """Asserts only velocity state exists, extracts it.""" 108 | flat = jax.tree_util.tree_flatten(state)[0] 109 | self.assertLen(flat, 1) 110 | return flat[0] 111 | 112 | @parameterized.parameters(itertools.product([False, True], repeat=2)) 113 | def test_wd_before_momentum(self, ema, nesterov): 114 | options = momentum.Options( 115 | ema=ema, 116 | nesterov=nesterov, 117 | momentum_decay=0.9, 118 | weight_decay=0.0, 119 | ) 120 | nsteps = 4 121 | tx = momentum.apply(options) 122 | expected_grads = self._unroll(tx, nsteps, wd=0.1) 123 | options = momentum.Options( 124 | ema=ema, 125 | nesterov=nesterov, 126 | momentum_decay=0.9, 127 | weight_decay=0.1, 128 | weight_decay_after_momentum=False, 129 | ) 130 | tx = momentum.apply(options) 131 | actual_grads = self._unroll(tx, nsteps) 132 | np.testing.assert_allclose(expected_grads, actual_grads) 133 | 134 | 
@parameterized.parameters(itertools.product([False, True], repeat=2)) 135 | def test_basic(self, ema, decay_after): 136 | wd = 0.1 if decay_after else 0.0 137 | if decay_after: 138 | return 139 | decay = 0.9 140 | options = momentum.Options( 141 | ema=ema, 142 | nesterov=True, 143 | momentum_decay=decay, 144 | weight_decay=wd, 145 | weight_decay_after_momentum=True, 146 | ) 147 | tx = momentum.apply(options) 148 | v, g, ig = self._unroll(tx, 2, extract=True) 149 | 150 | ev = jnp.zeros((3,)) 151 | factor = (1 - decay) if ema else 1.0 152 | ev += factor * ig[0] 153 | self.assertSequenceAlmostEqual(v[0], ev, msg=v) 154 | expected_grad = decay * ev + factor * ig[0] 155 | expected_grad += jnp.ones((3,)) * wd 156 | self.assertSequenceAlmostEqual(g[0], expected_grad) 157 | 158 | ev = ev * decay + factor * ig[1] 159 | self.assertSequenceAlmostEqual(v[1], ev, delta=1e-6) 160 | expected_grad = decay * ev + factor * ig[1] 161 | expected_grad += jnp.ones((3,)) * wd 162 | self.assertSequenceAlmostEqual(g[1], expected_grad, delta=1e-6) 163 | 164 | @parameterized.parameters(_make_no_state_cases()) 165 | def test_no_state(self, options): 166 | """Ensure no state is created when decay is 0.0.""" 167 | assert options.momentum_decay == 0.0 168 | tx = momentum.apply(options) 169 | state = tx.init(jnp.zeros((3,))) 170 | flat = jax.tree_util.tree_flatten(state)[0] 171 | self.assertEmpty(flat) 172 | 173 | @parameterized.named_parameters(_make_invalid_cases()) 174 | def test_invalid(self, invalid_options): 175 | with self.assertRaises(ValueError): 176 | momentum.apply(invalid_options) 177 | 178 | 179 | if __name__ == '__main__': 180 | absltest.main() 181 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/open_orca_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Dataset builder for the Open Orca dataset.""" 16 | 17 | import enum as Enum 18 | 19 | from absl import logging 20 | import jax.dlpack 21 | from precondition.datamix_gemma.dataset_builders import dataset_builder 22 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 23 | import tensorflow as tf 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | class DatasetSplit(Enum.Enum): 28 | TRAIN = 'train' 29 | 30 | 31 | class OpenOrcaDatasetBuilder(dataset_builder.DatasetBuilder): 32 | """Dataset builder for the Open Orca dataset.""" 33 | 34 | N_ITEMS = {DatasetSplit.TRAIN: 2914896} 35 | 36 | #BUFFER_SIZE_SHUFFLE = 10_000 37 | BUFFER_SIZE_SHUFFLE = 100 38 | SYSTEM_PREFIX = 'System: \n' 39 | SYSTEM_SUFFIX = '\n' 40 | QUESTION_PREFIX = 'Question: \n' 41 | QUESTION_SUFFIX = '\n' 42 | #TRANSLATION_PREFIX = 'Translate this into French:\n' 43 | #TRANSLATION_SUFFIX = '\n' 44 | 45 | def __init__( 46 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 47 | ): 48 | """Constructor. 49 | 50 | Args: 51 | tokenizer: Gemma tokenizer to use. 52 | max_seq_len: size of each sequence in a given batch. 53 | """ 54 | self._tokenizer = tokenizer 55 | self._base_data = { 56 | DatasetSplit.TRAIN: tfds.load( 57 | 'huggingface:open_orca__openorca', split='train' 58 | ), 59 | } 60 | logging.info( 61 | 'open orca size: %s', 62 | self._base_data[DatasetSplit.TRAIN].cardinality().numpy(), 63 | ) 64 | self._max_seq_len = max_seq_len 65 | 66 | def _tokenize_system(self, example: tf.Tensor) -> tf.Tensor: 67 | """Tokenization function for the system prompt.""" 68 | res = self._tokenizer.tokenize_tf_op( 69 | example, 70 | prefix=self.SYSTEM_PREFIX, 71 | suffix=self.SYSTEM_SUFFIX, 72 | add_eos=False, 73 | ) 74 | return res 75 | 76 | def _tokenize_question(self, example: tf.Tensor): 77 | """Tokenization function for the Question.""" 78 | return self._tokenizer.tokenize_tf_op( 79 | example, 80 | prefix=self.QUESTION_PREFIX, 81 | suffix=self.QUESTION_SUFFIX, 82 | add_eos=False, 83 | ) 84 | 85 | def _tokenize_response(self, example: tf.Tensor): 86 | """Tokenization function for the Response.""" 87 | return self._tokenizer.tokenize_tf_op( 88 | example, 89 | add_eos=True, 90 | ) 91 | 92 | def _to_training_input( 93 | self, 94 | system_tokens: jax.Array, 95 | question_tokens: jax.Array, 96 | response_tokens: jax.Array, 97 | ): 98 | """Build a training input from a tuple of source and destination tokens.""" 99 | 100 | # The input sequence fed to the model is simply the concatenation of the 101 | # source and the destination. 102 | tokens = tf.concat( 103 | [system_tokens, question_tokens, response_tokens], axis=0 104 | ) 105 | 106 | # To prevent the model from updating based on the source (input) 107 | # tokens, add a target mask to each input. 108 | system_mask = tf.zeros_like(system_tokens, dtype=tf.bool) 109 | question_mask = tf.zeros_like(question_tokens, dtype=tf.bool) 110 | response_mask = tf.ones_like(response_tokens, dtype=tf.bool) 111 | mask = tf.concat([system_mask, question_mask, response_mask], axis=0) 112 | 113 | # If the output tokens sequence is smaller than the target sequence size, 114 | # then pad it with pad tokens. 115 | tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id) 116 | 117 | # Don't want to perform the backward pass on the pad tokens. 
118 | mask = self._pad_up_to_max_len(mask, False) 119 | return dataset_builder.TrainingInput( #type: ignore 120 | input_tokens=tokens, #type:ignore 121 | target_mask=mask, #type:ignore 122 | )# type: ignore 123 | 124 | def get_train_dataset(self, batch_size: int, num_epochs: int): 125 | """Build the training dataset.""" 126 | 127 | ds = self._base_data[DatasetSplit.TRAIN].map( 128 | lambda x: ( 129 | self._tokenize_system(x['system_prompt']), 130 | self._tokenize_question(x['question']), 131 | self._tokenize_response(x['response']) 132 | ), 133 | num_parallel_calls=tf.data.AUTOTUNE, 134 | ) 135 | ds = ds.map(self._to_training_input, 136 | num_parallel_calls=tf.data.AUTOTUNE) 137 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 138 | ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE) 139 | #ds = ds.repeat(num_epochs) 140 | #ds = ds.batch(batch_size, drop_remainder=True) 141 | return ds 142 | 143 | def get_validation_dataset(self, batch_size: int): 144 | """Build the validation dataset.""" 145 | 146 | # Same steps as in `get_train_dataset`, but without shuffling and 147 | # repetition. 148 | # ds = self._base_data[DatasetSplit.VALIDATION].map( 149 | # lambda x: (self._tokenize_source(x['src']), 150 | # self._tokenize_destination(x['dst']))) 151 | ds = self._base_data[DatasetSplit.TRAIN].map( 152 | lambda x: ( 153 | self._tokenize_system(x['system_prompt']), 154 | self._tokenize_question(x['question']), 155 | self._tokenize_response(x['response']), 156 | ), 157 | num_parallel_calls=tf.data.AUTOTUNE, 158 | ) 159 | ds = ds.map(self._to_training_input, num_parallel_calls=tf.data.AUTOTUNE) 160 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 161 | # ds = ds.batch(batch_size, drop_remainder=True) 162 | return ds 163 | # ds = [self._to_training_input(x, y) for x, y in ds] 164 | # print('here3:', ds) 165 | # ds = [x for x in ds if tf.shape(x.input_tokens)[0] <= self._max_seq_len] 166 | # ds = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)] 167 | -------------------------------------------------------------------------------- /precondition/sm3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """SM3 Implementation.""" 16 | 17 | import functools 18 | from typing import Any, NamedTuple 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | import optax 24 | 25 | from precondition.quantization_utils import QuantizedValue 26 | 27 | 28 | class SM3State(NamedTuple): 29 | count: chex.Array 30 | stats: Any 31 | 32 | 33 | # Per parameter optimizer state used in data-parallel training. 
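# The diagonal statistics are stored as one vector per tensor dimension (see
# _get_expanded_shape below), i.e. O(sum of dims) rather than O(product of
# dims) memory, and the momentum is int8-quantized via QuantizedValue.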
34 | class ParameterStats(NamedTuple): 35 | """State associated to each parameter of the model being trained.""" 36 | diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner 37 | diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner 38 | 39 | 40 | def sm3( 41 | learning_rate, 42 | beta1=0.9, 43 | beta2=0.999, 44 | diagonal_epsilon=1e-10, 45 | weight_decay=0.0, 46 | normalize_grads=False): 47 | """SM3 optimizer. 48 | 49 | Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren, 50 | Yoram Singer 51 | 52 | https://arxiv.org/abs/1901.11150 53 | 54 | Args: 55 | learning_rate: the step size used to update the parameters. 56 | beta1: momentum parameter. 57 | beta2: second moment averaging parameter. 58 | diagonal_epsilon: epsilon for sm3 59 | weight_decay: the amount of weight decay regularization to apply. defaults 60 | to 0.0. 61 | normalize_grads: Whether to normalize grads. Author finds it useful when 62 | grads are high variance. 63 | 64 | Returns: 65 | a GradientTransformation. 66 | """ 67 | 68 | def _quantize_momentum(momentum_statistics): 69 | return QuantizedValue.from_float_value(momentum_statistics, jnp.int8) 70 | 71 | def init_fn(params): 72 | """Initialise the optimiser's state.""" 73 | 74 | def _init(param): 75 | accumulators = [jnp.zeros([s]) for s in param.shape] 76 | momentum = _quantize_momentum(jnp.zeros_like(param)) 77 | return ParameterStats(accumulators, momentum) # pytype: disable=wrong-arg-types # numpy-scalars 78 | 79 | return SM3State( 80 | count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) 81 | 82 | def _get_expanded_shape(shape, i): 83 | rank = len(shape) 84 | # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. 85 | # For eg: i = 1 returns [1, N, 1]. 86 | return [1] * i + [shape[i]] + [1] * (rank - i - 1) 87 | 88 | def _moving_averages(grad, accumulators): 89 | w = (1.0 - beta2) if beta2 != 1.0 else 1.0 90 | if grad.ndim < 2: 91 | return beta2 * accumulators[0] + w * grad**2 92 | else: 93 | min_accumulator = functools.reduce(jnp.minimum, accumulators) 94 | return beta2 * min_accumulator + w * grad**2 95 | 96 | def _moving_averages_momentum(grad, momentum): 97 | w = (1.0 - beta1) if beta1 != 1.0 else 1.0 98 | return beta1 * momentum.to_float() + w * grad 99 | 100 | def _sketch_diagonal_statistics(grad, updated_diagonal_statistics): 101 | all_diagonal_statistics = [] 102 | for i in range(grad.ndim): 103 | axes = list(range(i)) + list(range(i + 1, grad.ndim)) 104 | dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes) 105 | all_diagonal_statistics.append(dim_diagonal_statistics) 106 | if grad.ndim == 1: 107 | all_diagonal_statistics[0] = updated_diagonal_statistics 108 | return all_diagonal_statistics 109 | 110 | def update_fn(updates, state, params): 111 | stats = state.stats 112 | if normalize_grads: 113 | updates = jax.tree.map( 114 | lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates) 115 | # Reshape all vectors into N-d tensors to compute min over them. 
116 | # [n], [m] -> [n, 1], [1, m] 117 | expanded_diagonal_statistics = jax.tree.map( 118 | lambda grad, state: # pylint:disable=g-long-lambda 119 | [ 120 | jnp.reshape(state.diagonal_statistics[i], 121 | _get_expanded_shape(grad.shape, i)) 122 | for i in range(grad.ndim) 123 | ], 124 | updates, 125 | stats) 126 | 127 | # Compute new diagonal statistics 128 | new_diagonal_statistics = jax.tree.map(_moving_averages, updates, 129 | expanded_diagonal_statistics) 130 | 131 | # Compute preconditioners (1/sqrt(s)) where s is the statistics. 132 | new_preconditioners = jax.tree.map( 133 | lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics) 134 | preconditioned_grads = jax.tree.map(lambda g, p: g * p, updates, 135 | new_preconditioners) 136 | 137 | # Compute updated momentum (also handle quantization) 138 | updated_momentum = jax.tree.map( 139 | lambda preconditioned_grad, state: # pylint:disable=g-long-lambda 140 | _moving_averages_momentum(preconditioned_grad, state.diagonal_momentum), 141 | preconditioned_grads, 142 | stats) 143 | 144 | # Update diagonal statistics. 145 | updated_diagonal_statistics = jax.tree.map(_sketch_diagonal_statistics, 146 | updates, new_diagonal_statistics) 147 | 148 | # Update momentum. 149 | new_sm3_stats = jax.tree.map( 150 | lambda momentum, diagonal_stats: # pylint:disable=g-long-lambda 151 | ParameterStats(diagonal_stats, _quantize_momentum(momentum)), 152 | updated_momentum, 153 | updated_diagonal_statistics) 154 | 155 | # Apply weight decay 156 | updated_momentum_with_wd = updated_momentum 157 | if weight_decay > 0.0: 158 | updated_momentum_with_wd = jax.tree.map(lambda g, p: g + weight_decay * p, 159 | updated_momentum, params) 160 | 161 | lr = learning_rate 162 | if callable(learning_rate): 163 | lr = learning_rate(state.count) 164 | 165 | new_updates = jax.tree.map(lambda pg: -lr * pg, updated_momentum_with_wd) 166 | return new_updates, SM3State(count=state.count+1, stats=new_sm3_stats) 167 | 168 | return optax.GradientTransformation(init_fn, update_fn) 169 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/evals/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Byte pair encoding utilities (Adapted from the official GPT-2 GitHub repository).""" 16 | 17 | import functools 18 | import json 19 | import os 20 | 21 | import regex as re 22 | import requests 23 | import tqdm 24 | 25 | 26 | data_dir = '/home/shivguptashi/data' 27 | 28 | 29 | def _get_encoder(subdir): 30 | """Downloads the encoder and vocab to the subdir.""" 31 | print('Downloading encoder and vocab to ', subdir) 32 | for filename in ['encoder.json', 'vocab.bpe']: 33 | r = requests.get( 34 | 'https://openaipublic.blob.core.windows.net/gpt-2/' 35 | + subdir 36 | + '/' 37 | + filename, 38 | stream=True, 39 | ) 40 | with open(os.path.join(subdir, filename), 'wb') as f: 41 | file_size = int(r.headers['content-length']) 42 | chunk_size = 1000 43 | with tqdm.tqdm( 44 | ncols=100, 45 | desc='Fetching ' + filename, 46 | total=file_size, 47 | unit_scale=True, 48 | ) as pbar: 49 | # 1k for chunk_size, since Ethernet packet size is around 1500 50 | # bytes 51 | for chunk in r.iter_content(chunk_size=chunk_size): 52 | f.write(chunk) 53 | pbar.update(chunk_size) 54 | 55 | 56 | @functools.lru_cache() 57 | def bytes_to_unicode(): 58 | """Returns list of utf-8 byte and a corresponding list of unicode strings. 59 | 60 | The reversible bpe codes work on unicode strings. 61 | This means you need a large # of unicode characters in your vocab if you 62 | want to avoid UNKs. When you're at something like a 10B token dataset you 63 | end up needing around 5K for decent coverage. This is a signficant 64 | percentage of your normal, say, 32K bpe vocab. To avoid that, we want 65 | lookup tables between utf-8 bytes and unicode strings. And avoids mapping 66 | to whitespace/control characters the bpe code barfs on. 67 | """ 68 | bs = ( 69 | list(range(ord('!'), ord('~') + 1)) 70 | + list(range(ord('¡'), ord('¬') + 1)) 71 | + list(range(ord('®'), ord('ÿ') + 1)) 72 | ) 73 | cs = bs[:] 74 | n = 0 75 | for b in range(2**8): 76 | if b not in bs: 77 | bs.append(b) 78 | cs.append(2**8 + n) 79 | n += 1 80 | cs = [chr(n) for n in cs] 81 | return dict(zip(bs, cs)) 82 | 83 | 84 | def get_pairs(word): 85 | """Return set of symbol pairs in a word. 86 | 87 | Word is represented as tuple of symbols (symbols being variable-length 88 | strings). 89 | 90 | Args: 91 | word: A string. 92 | 93 | Returns: 94 | A set of symbol pairs. 
95 | """ 96 | pairs = set() 97 | prev_char = word[0] 98 | for char in word[1:]: 99 | pairs.add((prev_char, char)) 100 | prev_char = char 101 | return pairs 102 | 103 | 104 | class Encoder: 105 | """Encoder for byte pair encoding.""" 106 | 107 | def __init__(self, encoder, bpe_merges, errors='replace'): 108 | self.encoder = encoder 109 | self.decoder = {v: k for k, v in self.encoder.items()} 110 | self.errors = errors # how to handle errors in decoding 111 | self.byte_encoder = bytes_to_unicode() 112 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 113 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 114 | self.cache = {} 115 | 116 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized 117 | # versions of contractions 118 | self.pat = re.compile( 119 | r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 120 | ) 121 | 122 | def bpe(self, token): 123 | """Performs BPE on the given token.""" 124 | if token in self.cache: 125 | return self.cache[token] 126 | word = tuple(token) 127 | pairs = get_pairs(word) 128 | 129 | if not pairs: 130 | return token 131 | 132 | while True: 133 | bigram = min( 134 | pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')) 135 | ) 136 | if bigram not in self.bpe_ranks: 137 | break 138 | first, second = bigram 139 | new_word = [] 140 | i = 0 141 | while i < len(word): 142 | try: 143 | j = word.index(first, i) 144 | new_word.extend(word[i:j]) 145 | i = j 146 | except: #pylint: disable=bare-except 147 | new_word.extend(word[i:]) 148 | break 149 | 150 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 151 | new_word.append(first + second) 152 | i += 2 153 | else: 154 | new_word.append(word[i]) 155 | i += 1 156 | new_word = tuple(new_word) 157 | word = new_word 158 | if len(word) == 1: 159 | break 160 | else: 161 | pairs = get_pairs(word) 162 | word = ' '.join(word) 163 | self.cache[token] = word 164 | return word 165 | 166 | def encode(self, text): 167 | bpe_tokens = [] 168 | for token in re.findall(self.pat, text): 169 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 170 | bpe_tokens.extend( 171 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ') 172 | ) 173 | return bpe_tokens 174 | 175 | def decode(self, tokens): 176 | text = ''.join([self.decoder[token] for token in tokens]) 177 | text = bytearray([self.byte_decoder[c] for c in text]).decode( 178 | 'utf-8', errors=self.errors 179 | ) 180 | return text 181 | 182 | 183 | def get_encoder(model_name): 184 | """Returns the encoder for the given model.""" 185 | subdir = os.path.join('models', model_name) 186 | if not os.path.exists(subdir): 187 | os.makedirs(subdir) 188 | if not os.path.exists(os.path.join(subdir, 'encoder.json')): 189 | _get_encoder(subdir) 190 | 191 | subdir = subdir.replace('\\', '/') # needed for Windows 192 | 193 | with open(os.path.join(subdir, 'encoder.json'), 'r') as f: 194 | encoder = json.load(f) 195 | with open(os.path.join(subdir, 'vocab.bpe'), 'r', encoding='utf-8') as f: 196 | bpe_data = f.read() 197 | bpe_merges = [ 198 | tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1] 199 | ] 200 | return Encoder( 201 | encoder=encoder, 202 | bpe_merges=bpe_merges, 203 | ) 204 | 205 | 206 | enc = get_encoder('124M') 207 | 208 | 209 | def crop_prompt(prompt: str): 210 | """Crops the prompt to 2048 tokens.""" 211 | global enc # pylint: disable=global-variable-not-assigned 212 | 213 | cropped_prompt = 
enc.decode(enc.encode(prompt)[:2048]) 214 | return cropped_prompt 215 | 216 | 217 | def crop(s): 218 | """Crops the prompt to 2048 tokens.""" 219 | prompt = crop_prompt(s) 220 | return prompt 221 | -------------------------------------------------------------------------------- /precondition/tearfree/optimizer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for tearfree optimizer.""" 16 | 17 | import dataclasses 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | from jax import numpy as jnp 23 | import numpy as np 24 | import optax 25 | from precondition.tearfree import grafting 26 | from precondition.tearfree import momentum 27 | from precondition.tearfree import optimizer 28 | from precondition.tearfree import praxis_shim 29 | from precondition.tearfree import second_order 30 | from precondition.tearfree import shampoo 31 | 32 | 33 | class OptimizerTest(parameterized.TestCase): 34 | """Basic test for optimizer configurations.""" 35 | 36 | def setUp(self): 37 | super().setUp() 38 | jax.config.update('jax_debug_nans', True) 39 | 40 | def _unroll(self, options, shape, transform=None, lr=0.1, n=4): 41 | """Generate states and grad updates n times.""" 42 | rng = jax.random.PRNGKey(0) 43 | params = jnp.zeros(shape) 44 | grads = jax.random.normal(rng, (n, *shape)) 45 | 46 | if transform is not None: 47 | params = transform(params) 48 | grads = jnp.stack([transform(g) for g in grads]) 49 | 50 | if isinstance(options, optimizer.TearfreeOptions): 51 | tx = optimizer.tearfree(lr, options) 52 | else: 53 | tx = options 54 | init = tx.init(params) 55 | 56 | def reduce(state, grad): 57 | new_grad, new_state = tx.update(grad, state, params) 58 | return new_state, new_grad 59 | 60 | _, out_grads = jax.lax.scan(reduce, init, grads) 61 | return out_grads 62 | 63 | def _no_graft_no_momentum(self): 64 | return optimizer.TearfreeOptions( 65 | grafting_options=grafting.Options( 66 | grafting_type=grafting.GraftingType.NONE, 67 | second_moment_decay=0.0, 68 | skip_preconditioning_rank1=False, 69 | ), 70 | momentum_options=momentum.Options(momentum_decay=0.0), 71 | ) 72 | 73 | def test_merge_dims(self): 74 | shape = (2, 2) 75 | options = dataclasses.replace( 76 | self._no_graft_no_momentum(), 77 | second_order_options=second_order.Options(merge_dims=4), 78 | ) 79 | transform = lambda x: x.reshape(4) 80 | actual = self._unroll(options, shape) 81 | expected = self._unroll(options, shape, transform) 82 | np.testing.assert_allclose(actual.reshape(-1, 4), expected) 83 | 84 | def test_block_size(self): 85 | shape = (4,) 86 | options = dataclasses.replace( 87 | self._no_graft_no_momentum(), 88 | second_order_options=second_order.Options( 89 | shampoo_options=shampoo.Options(block_size=3) 90 | ), 91 | ) 92 | actual = self._unroll(options, shape) 93 | expected = 
self._unroll(options, shape) 94 | np.testing.assert_allclose(actual, expected) 95 | 96 | @parameterized.parameters( 97 | momentum.Options(), # Default is 0.9, active momentum. 98 | momentum.Options(momentum_decay=0.0), 99 | momentum.Options(weight_decay=0.01), 100 | momentum.Options(weight_decay=0.01, weight_decay_after_momentum=False), 101 | momentum.Options(nesterov=False), 102 | momentum.Options(ema=True), 103 | momentum.Options(ema=True, nesterov=True), 104 | ) 105 | def test_momentum_no_graft(self, momentum_options): 106 | shape = (4,) 107 | options = self._no_graft_no_momentum() 108 | options.momentum_options = momentum_options 109 | tx = praxis_shim.sharded_chain( 110 | second_order.apply(options.second_order_options), 111 | momentum.apply(momentum_options), 112 | optax.scale(-0.1), 113 | ) 114 | actual = self._unroll(options, shape) 115 | expected = self._unroll(tx, shape) 116 | np.testing.assert_allclose(actual, expected) 117 | 118 | def _grafting_tx( 119 | self, grafting_options 120 | ) -> praxis_shim.ShardedGradientTransformation: 121 | id_tx = optax.identity() 122 | id_tx_shard = praxis_shim.ShardedGradientTransformation( 123 | id_tx.init, 124 | id_tx.update, 125 | lambda _: optax.EmptyState(), 126 | ) 127 | return grafting.graft(grafting_options, id_tx_shard) 128 | 129 | def _grafting_tx_with_momentum( 130 | self, grafting_options, momentum_options, lr=0.1 131 | ): 132 | return praxis_shim.sharded_chain( 133 | self._grafting_tx(grafting_options), 134 | momentum.apply(momentum_options), 135 | optax.scale(-lr), 136 | ) 137 | 138 | @parameterized.parameters( 139 | grafting.Options(), 140 | grafting.Options( 141 | grafting_type=grafting.GraftingType.SGD, second_moment_decay=0.0 142 | ), 143 | grafting.Options(second_moment_decay=1.0), 144 | ) 145 | def test_momentum_yes_graft(self, grafting_options): 146 | shape = (4,) 147 | nsteps = 4 148 | options = self._no_graft_no_momentum() 149 | options.momentum_options.momentum_decay = 0.9 150 | options.grafting_options = grafting_options 151 | grafting_options.start_preconditioning_step = nsteps + 1 152 | grafting_options.skip_preconditioning_rank1 = False 153 | tx = self._grafting_tx_with_momentum( 154 | grafting_options, options.momentum_options 155 | ) 156 | expected = self._unroll(tx, shape, n=nsteps) 157 | actual = self._unroll(options, shape, n=nsteps) 158 | np.testing.assert_allclose(actual, expected) 159 | 160 | def _precondition_at(self, i): 161 | """Return optimizer with momentum, grafting, and start precon at step i.""" 162 | return optimizer.TearfreeOptions( 163 | grafting_options=grafting.Options( 164 | start_preconditioning_step=i, skip_preconditioning_rank1=False 165 | ) 166 | ) 167 | 168 | @parameterized.parameters( 169 | dict(shape=(1, 1, 1)), 170 | dict(shape=(1,)), 171 | dict(shape=tuple()), 172 | ) 173 | def test_scalar_is_grafting(self, shape): 174 | nsteps = 4 175 | options = self._precondition_at(2) 176 | tx = self._grafting_tx_with_momentum( 177 | options.grafting_options, options.momentum_options 178 | ) 179 | expected = self._unroll(tx, shape, n=nsteps) 180 | actual = self._unroll(options, shape, n=nsteps) 181 | np.testing.assert_allclose(actual, expected) 182 | 183 | def test_lr(self): 184 | shape = (3,) 185 | options = self._precondition_at(2) 186 | nsteps = 4 187 | 188 | def schedule(count): 189 | return (count + 1) * 0.1 190 | 191 | actual = self._unroll(options, shape, lr=schedule, n=nsteps) 192 | expected = self._unroll(options, shape, lr=0.1, n=nsteps) 193 | expected *= (jnp.arange(nsteps) + 1).reshape(-1, 
1) 194 | np.testing.assert_allclose(actual, expected) 195 | 196 | 197 | if __name__ == '__main__': 198 | absltest.main() 199 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/Wikipedia_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "id": "PE4ZRoyy4qGp" 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import tensorflow_datasets as tfds\n", 26 | "import csv\n", 27 | "import sys\n", 28 | "import tensorflow as tf\n", 29 | "\n", 30 | "wiki_tfds = tfds.load('wikipedia/20190301.en', split='train')" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "source": [ 36 | "sample = wiki_tfds.take(2)\n", 37 | "for x in sample:\n", 38 | " print(x)" 39 | ], 40 | "metadata": { 41 | "id": "UXNdsgd55QZs" 42 | }, 43 | "execution_count": null, 44 | "outputs": [] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "source": [ 49 | "csv.field_size_limit(sys.maxsize)\n", 50 | "\n", 51 | "joined_csv_file_path = 'home/shivguptashi/wikidata/joined_table.csv'\n", 52 | "id_to_info = {}\n", 53 | "title_to_info = {}\n", 54 | "with gfile.Open(joined_csv_file_path, 'r') as f:\n", 55 | " csvreader = csv.reader(f, delimiter=',')\n", 56 | " it = 0\n", 57 | " for row in csvreader:\n", 58 | " title = row[-7].strip('|')\n", 59 | " #print(title)\n", 60 | " title_to_info[title] = row\n", 61 | " it += 1\n", 62 | " if it % 1000 == 0:\n", 63 | " print(it)\n" 64 | ], 65 | "metadata": { 66 | "id": "J4WTWdSk6Wg3" 67 | }, 68 | "execution_count": null, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "source": [ 74 | "num_rows_in_info = len(title_to_info['Tatrapan'])\n", 75 | "print(num_rows_in_info)\n", 76 | "print(num_rows_in_info - 7 - 4)" 77 | ], 78 | "metadata": { 79 | "id": "ebUptuXE83TD" 80 | }, 81 | "execution_count": null, 82 | "outputs": [] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "source": [ 87 | "from operator import itemgetter\n", 88 | "\n", 89 | "def map_to_features(title, text):\n", 90 | " str_title = title.numpy().decode('utf-8')\n", 91 | " #try:\n", 92 | " if str_title not in title_to_info:\n", 93 | " return (title, text, -1)\n", 94 | " all_entries = [float(y) for y in title_to_info[str_title][4:num_rows_in_info-7]]\n", 95 | " index, _ = max(enumerate(all_entries), key=itemgetter(1))\n", 96 | " return (title, text, index)\n", 97 | " #except:\n", 98 | " # return (title, text, -1)\n", 99 | "\n", 100 | "wiki_tfds = wiki_tfds.map(lambda x: (x['title'], x['text']), num_parallel_calls=16)\n", 101 | "wiki_tfds = wiki_tfds.map(lambda x, y: tf.py_function(map_to_features, [x, y], [tf.string, tf.string, tf.int32]), num_parallel_calls=16)\n" 102 | ], 103 | "metadata": { 104 | "id": "Rc8U4zgE6iub" 105 | }, 106 | "execution_count": null, 107 | "outputs": [] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "source": [ 112 | "save_path = 'home/shivguptashi/wikidata/wiki_tfds_with_topic'\n", 113 | "tf.data.Dataset.save(wiki_tfds, save_path)\n", 114 | "#from functools import partial\n", 115 | "#def filter_f(x, i):\n", 116 | "# try:\n", 117 | "# str_title = x.numpy().decode('utf-8')\n", 118 | "# all_entries = [int(y) for y in 
title_to_info[str_title][4:num_rows_in_info-7]]\n", 119 | "# print(all_entries)\n", 120 | "# if int(title_to_info[str_title][i]) >= max(all_entries):\n", 121 | "# return True\n", 122 | "# return False\n", 123 | "# except:\n", 124 | "# return False\n", 125 | "#\n", 126 | "#categories_tfds = []\n", 127 | "#for i in range(4, num_rows_in_info - 7):\n", 128 | " #categories_tfds.append(wiki_tfds.filter(lambda x, y: tf.py_function(partial(filter_f, i=i), [x,], [tf.bool,])[0]))" 129 | ], 130 | "metadata": { 131 | "id": "KJI0Cq499oYv" 132 | }, 133 | "execution_count": null, 134 | "outputs": [] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "source": [ 139 | "sample_tfds = wiki_tfds.take(2)\n", 140 | "for x in sample_tfds:\n", 141 | " print(x[0])" 142 | ], 143 | "metadata": { 144 | "id": "TDlr690KwFLk" 145 | }, 146 | "execution_count": null, 147 | "outputs": [] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "source": [ 152 | "sample_category_tfds = categories_tfds[0].take(2)\n", 153 | "for x in sample_category_tfds:\n", 154 | " print(x)" 155 | ], 156 | "metadata": { 157 | "id": "5ARFZVx00J4U" 158 | }, 159 | "execution_count": null, 160 | "outputs": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "source": [ 165 | "\n", 166 | "wiki_tfds_with_topic = tf.data.Dataset.load(save_path)\n", 167 | "sample_tfds_with_topic = wiki_tfds_with_topic.take(10)\n", 168 | "sample_wiki_tfds_with_topic = sample_tfds_with_topic.filter(lambda x, y, z: tf.py_function(filter_f, [z], [tf.bool])[0])\n", 169 | "for x in sample_wiki_tfds_with_topic:\n", 170 | " print(x)" 171 | ], 172 | "metadata": { 173 | "id": "iim5buMZIfiU" 174 | }, 175 | "execution_count": null, 176 | "outputs": [] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "source": [ 181 | "from functools import partial\n", 182 | "def filter_f(topic_ind, i):\n", 183 | " if topic_ind == i:\n", 184 | " return True\n", 185 | " return False\n", 186 | "\n", 187 | "topic_wise_save_path = 'home/shivguptashi/wikidata/topic_wise_tfds'\n", 188 | "for i in range(58, num_rows_in_info - 7):\n", 189 | " ind = i - 4\n", 190 | " cur_topic_save_path = topic_wise_save_path + '_topic_' + str(ind)\n", 191 | " filtered_wiki_tfds = wiki_tfds_with_topic.filter(lambda x, y, z: tf.py_function(partial(filter_f, i=ind), [z], [tf.bool])[0])\n", 192 | " print('index:', ind)\n", 193 | " tf.data.Dataset.save(filtered_wiki_tfds, cur_topic_save_path)" 194 | ], 195 | "metadata": { 196 | "id": "6248Q9YpH_Em" 197 | }, 198 | "execution_count": null, 199 | "outputs": [] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "source": [ 204 | "print(num_rows_in_info)" 205 | ], 206 | "metadata": { 207 | "id": "1PItow2MGpIo" 208 | }, 209 | "execution_count": null, 210 | "outputs": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "source": [], 215 | "metadata": { 216 | "id": "SEyLBVkDo-FM" 217 | }, 218 | "execution_count": null, 219 | "outputs": [] 220 | } 221 | ] 222 | } -------------------------------------------------------------------------------- /precondition/datamix_gemma/dataset_builders/gsm8k_dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """GSM8k dataset builder.""" 16 | 17 | import enum as Enum 18 | 19 | import jax.dlpack 20 | from precondition.datamix_gemma.dataset_builders import dataset_builder 21 | from precondition.datamix_gemma.tokenizers import gemma_tokenizer 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | 25 | 26 | PREAMBLE = """As an expert problem solver solve step by step the following mathematical questions.""" 27 | 28 | # The default gsm8k prompt from the CoT paper 29 | # https://arxiv.org/pdf/2201.11903.pdf page 35. 30 | 31 | PROMPT = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? 32 | A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6. 33 | 34 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 35 | A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5. 36 | 37 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 38 | A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39. 39 | 40 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? 41 | A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8. 42 | 43 | Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? 44 | A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9. 45 | 46 | Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 47 | A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29. 48 | 49 | Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 50 | A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33. 51 | 52 | Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 53 | A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. 
The answer is 8.""" 54 | 55 | 56 | class DatasetSplit(Enum.Enum): 57 | TRAIN = 'train' 58 | TEST = 'test' 59 | 60 | 61 | class GSM8KDatasetBuilder(dataset_builder.DatasetBuilder): 62 | """Dataset builder for the GSM8k dataset.""" 63 | 64 | N_ITEMS = {DatasetSplit.TRAIN: 7473} 65 | 66 | #BUFFER_SIZE_SHUFFLE = 10_000 67 | BUFFER_SIZE_SHUFFLE = 100 68 | ANSWER_PREFIX = 'A: ' 69 | ANSWER_SUFFIX = '\n' 70 | QUESTION_PREFIX = 'Q: ' 71 | QUESTION_SUFFIX = '\n' 72 | #TRANSLATION_PREFIX = 'Translate this into French:\n' 73 | #TRANSLATION_SUFFIX = '\n' 74 | 75 | def __init__( 76 | self, tokenizer: gemma_tokenizer.GemmaTokenizer, max_seq_len: int 77 | ): 78 | """Constructor. 79 | 80 | Args: 81 | tokenizer: Gemma tokenizer to use. 82 | max_seq_len: size of each sequence in a given batch. 83 | """ 84 | self._tokenizer = tokenizer 85 | self._base_data = { 86 | DatasetSplit.TRAIN: tfds.load( 87 | 'huggingface:gsm8k/main', split='train' 88 | ), 89 | DatasetSplit.TEST: tfds.load( 90 | 'huggingface:gsm8k/main', split='test' 91 | ), 92 | } 93 | self._max_seq_len = max_seq_len 94 | 95 | def _tokenize_question(self, example: tf.Tensor): 96 | """Tokenization function for the Question.""" 97 | return self._tokenizer.tokenize_tf_op( 98 | example, 99 | prefix=self.QUESTION_PREFIX, 100 | suffix=self.QUESTION_SUFFIX, 101 | add_eos=False, 102 | ) 103 | 104 | def _tokenize_answer(self, example: tf.Tensor): 105 | """Tokenization function for the Response.""" 106 | return self._tokenizer.tokenize_tf_op( 107 | example, 108 | add_eos=True, 109 | ) 110 | 111 | def _to_training_input( 112 | self, 113 | question_tokens: jax.Array, 114 | answer_tokens: jax.Array, 115 | ): 116 | """Build a training input from a tuple of source and destination tokens.""" 117 | 118 | # The input sequence fed to the model is simply the concatenation of the 119 | # source and the destination. 120 | tokens = tf.concat( 121 | [question_tokens, answer_tokens], axis=0 122 | ) 123 | 124 | # To prevent the model from updating based on the source (input) 125 | # tokens, add a target mask to each input. 126 | question_mask = tf.zeros_like(question_tokens, dtype=tf.bool) 127 | answer_mask = tf.ones_like(answer_tokens, dtype=tf.bool) 128 | mask = tf.concat([question_mask, answer_mask], axis=0) 129 | 130 | # If the output tokens sequence is smaller than the target sequence size, 131 | # then pad it with pad tokens. 132 | tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id) 133 | 134 | # Don't want to perform the backward pass on the pad tokens. 
135 | mask = self._pad_up_to_max_len(mask, False) 136 | return dataset_builder.TrainingInput( #type: ignore 137 | input_tokens=tokens, #type:ignore 138 | target_mask=mask, #type:ignore 139 | )# type: ignore 140 | 141 | def get_train_dataset(self, batch_size: int, num_epochs: int): 142 | """Build the training dataset.""" 143 | 144 | ds = self._base_data[DatasetSplit.TRAIN].map( 145 | lambda x: ( 146 | self._tokenize_question(x['question']), 147 | self._tokenize_answer(x['answer']), 148 | ), 149 | num_parallel_calls=tf.data.AUTOTUNE, 150 | ) 151 | ds = ds.map(self._to_training_input, 152 | num_parallel_calls=tf.data.AUTOTUNE) 153 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 154 | ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE) 155 | #ds = ds.repeat(num_epochs) 156 | #ds = ds.batch(batch_size, drop_remainder=True) 157 | return ds 158 | 159 | def get_validation_dataset(self, batch_size: int): 160 | """Build the validation dataset.""" 161 | 162 | # Same steps as in `get_train_dataset`, but without shuffling and 163 | # repetition. 164 | # ds = self._base_data[DatasetSplit.VALIDATION].map( 165 | # lambda x: (self._tokenize_source(x['src']), 166 | # self._tokenize_destination(x['dst']))) 167 | ds = self._base_data[DatasetSplit.TEST].map( 168 | lambda x: ( 169 | self._tokenize_question(x['question']), 170 | self._tokenize_answer(x['answer']), 171 | ), 172 | num_parallel_calls=tf.data.AUTOTUNE, 173 | ) 174 | ds = ds.map( 175 | self._to_training_input, 176 | num_parallel_calls=tf.data.AUTOTUNE, 177 | ) 178 | ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len) 179 | # ds = ds.batch(batch_size, drop_remainder=True) 180 | return ds 181 | # ds = [self._to_training_input(x, y) for x, y in ds] 182 | # print('here3:', ds) 183 | # ds = [x for x in ds if tf.shape(x.input_tokens)[0] <= self._max_seq_len] 184 | # ds = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)] 185 | 186 | def get_question_answer_dataset(self): 187 | #ds = self._base_data[DatasetSplit.TEST] 188 | return self._base_data[DatasetSplit.TEST] 189 | -------------------------------------------------------------------------------- /precondition/datamix_gemma/deterministic_strategy_bandit_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The precondition Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Deterministic strategy bandit loop.""" 16 | 17 | import copy 18 | import functools 19 | 20 | from absl import logging 21 | import jax 22 | import numpy as np 23 | from precondition.datamix_gemma import bandit_loop 24 | from precondition.datamix_gemma import training_loop 25 | from precondition.datamix_gemma.evals import eval as eval_lib 26 | from precondition.datamix_gemma.training_batch_generators import dartboard_deterministic_training_batch_generator 27 | 28 | 29 | _TEST_FN_FLAG = False 30 | _STEP_SIZE = 0.1 31 | 32 | def run_deterministic_strategy_bandit_loop( 33 | eval_obj: eval_lib.Eval, 34 | train_obj: training_loop.TrainingLoop, 35 | training_batch_generator_obj: dartboard_deterministic_training_batch_generator.DartboardDeterministicTrainingBatchGenerator, 36 | init_weights=None, 37 | num_iterations=1000, 38 | step_size=0.001, 39 | delta=0.001, 40 | warm_start=False, 41 | init_params=None, 42 | static_weights=False, 43 | step_size_decay=False, 44 | step_size_decay_rate=0.95, 45 | momentum=False, 46 | momentum_beta=0.1, 47 | use_adagrad=False, 48 | use_adagrad_avg=False, 49 | use_adam=False, 50 | adam_beta1=0.9, 51 | adam_beta2=0.99, 52 | gradient_clipping=False, 53 | gradient_clipping_norm=30000, 54 | ): 55 | """Run the bandit loop. 56 | 57 | Args: 58 | eval_obj: the evaluation object. 59 | train_obj: the training object. 60 | init_weights: the initial weights. 61 | num_iterations: the number of iterations to run. 62 | step_size: the step size for the gradient update. 63 | delta: the magnitude of the perturbation for the gradient estimate. 64 | warm_start: whether to warm start the training. 65 | init_params: the initial parameters. 66 | static_weights: whether to use static weights. 67 | step_size_decay: whether to decay the step size. 68 | step_size_decay_rate: the rate of decay. 69 | momentum: whether to use momentum. 70 | momentum_beta: the beta for momentum. 71 | 72 | Returns: 73 | The final weights. 74 | """ 75 | adagrad_matrix= None 76 | adam_matrix = None 77 | adam_first_moment = None 78 | assert not (use_adagrad and use_adam) 79 | if use_adam: 80 | adam_matrix = np.ones(len(training_batch_generator_obj.train_ds_builders)) * 1 81 | adam_first_moment = np.zeros(len(training_batch_generator_obj.train_ds_builders)) 82 | elif use_adagrad: 83 | adagrad_matrix = np.ones(len(training_batch_generator_obj.train_ds_builders)) * 1e6 84 | init_weights = init_weights 85 | if init_weights is None: 86 | init_weights = np.ones( 87 | len(training_batch_generator_obj.train_ds_builders) 88 | ) / len(training_batch_generator_obj.train_ds_builders) 89 | momentum_vec = np.zeros(len(training_batch_generator_obj.train_ds_builders)) 90 | 91 | next_params = init_params 92 | #print(f'init eval score: {eval_obj.evaluate(params=next_params)}') 93 | #logging.info('Done running init_eval') 94 | 95 | weights = init_weights 96 | rng = np.random.default_rng(seed=0) 97 | 98 | unnormalized_weights = copy.deepcopy(weights) 99 | for it in range(num_iterations): 100 | if static_weights: 101 | weights = init_weights 102 | logging.info('[WEIGHTS]: %s', weights) 103 | #next_cands = _generate_candidates_random_sign(weights, rng, delta=delta) 104 | logging.info('Going to train!') 105 | #Prepare for training. 
106 | gradient_discount_factor = training_batch_generator_obj.prepare_for_training( 107 | weights, unnormalized_weights 108 | ) 109 | 110 | if not warm_start: 111 | cur_params = copy.deepcopy(init_params) 112 | else: 113 | cur_params = copy.deepcopy(next_params) 114 | training_operations = [] 115 | init_trained_params = train_obj.train_loop( 116 | params={'params': cur_params}, 117 | get_next_batch_fn=functools.partial( 118 | training_batch_generator_obj.get_next_batch 119 | ), 120 | ) 121 | init_trained_params = jax.tree_util.tree_map( 122 | lambda arr: jax.device_put( 123 | arr, jax.local_devices(backend='cpu')[0] 124 | ), 125 | init_trained_params, 126 | ) 127 | for i in range(len(training_batch_generator_obj.train_ds_builders)): 128 | trained_params = copy.deepcopy(init_trained_params) 129 | trained_params = train_obj.train_loop( 130 | params=trained_params, 131 | get_next_batch_fn=functools.partial( 132 | training_batch_generator_obj.get_next_batch_special, index=i, delta=delta 133 | ), 134 | ) 135 | trained_params = jax.tree_util.tree_map( 136 | lambda arr: jax.device_put( 137 | arr, jax.local_devices(backend='cpu')[0] 138 | ), 139 | trained_params, 140 | ) 141 | training_operations.append(trained_params) 142 | logging.info('Done training!') 143 | init_score = eval_obj.evaluate(init_trained_params['params']) 144 | scores = [] 145 | for trained_params in training_operations: 146 | trained_params = jax.device_get(trained_params) 147 | scores.append( 148 | eval_obj.evaluate(trained_params['params']) 149 | ) 150 | if warm_start: 151 | next_params = training_operations[0]['params'] 152 | logging.info('iteration: %d', it) 153 | logging.info('[SCORES]: %s', scores) 154 | for i in range(weights.shape[0]): 155 | logging.info(f'weights_{str(i)}: {weights[i]}') 156 | logging.info(f'average_score: {np.mean(scores)}') 157 | for i in range(len(scores)): 158 | logging.info(f'score_{str(i)}: {scores[i]}') 159 | grad = np.zeros(len(weights)) 160 | for i in range(len(weights)): 161 | grad[i] = (scores[i] - init_score) / delta  # One-sided finite-difference estimate. 162 | logging.info('[GRAD]: %s', grad) 163 | if momentum: 164 | momentum_vec = momentum_beta * momentum_vec + grad 165 | unnormalized_weights = bandit_loop._exponentiated_gradient(weights, momentum_vec, step_size) 166 | weights = unnormalized_weights/np.linalg.norm(unnormalized_weights, ord=1) 167 | elif use_adagrad: 168 | adagrad_matrix += grad * grad 169 | truncated_adagrad_matrix = np.maximum(adagrad_matrix, 1e-8) 170 | unnormalized_weights = bandit_loop._exponentiated_gradient(weights, grad / np.sqrt(truncated_adagrad_matrix), step_size) 171 | weights = unnormalized_weights/np.linalg.norm(unnormalized_weights, ord=1) 172 | for i in range(weights.shape[0]): 173 | logging.info(f'adagrad_matrix_{str(i)}: {adagrad_matrix[i]}') 174 | elif use_adam: 175 | adam_first_moment = adam_beta1 * adam_first_moment + (1 - adam_beta1) * grad 176 | bias_corrected_first_moment = adam_first_moment / (1 - adam_beta1 ** (it + 1)) 177 | logging.info(f'bias_corrected_first_moment: {bias_corrected_first_moment}') 178 | adam_matrix = (1 - adam_beta2) * grad * grad + adam_beta2 * adam_matrix 179 | logging.info(f'adam_matrix: {adam_matrix}') 180 | bias_corrected_adam_matrix = adam_matrix / (1 - adam_beta2 ** (it + 1)) 181 | truncated_bias_corrected_adam_matrix = np.maximum(bias_corrected_adam_matrix, 1e-8) 182 | unnormalized_weights = bandit_loop._exponentiated_gradient(weights, bias_corrected_first_moment / np.sqrt(truncated_bias_corrected_adam_matrix), step_size) 183 | weights = 
unnormalized_weights/np.linalg.norm(unnormalized_weights, ord=1) 184 | for i in range(weights.shape[0]): 185 | logging.info(f'adam_matrix_{str(i)}: {adam_matrix[i]}') 186 | logging.info(f'adam_first_moment_{str(i)}: {adam_first_moment[i]}') 187 | elif use_adagrad_avg: 188 | adagrad_matrix += np.square(grad) 189 | unnormalized_weights = bandit_loop._exponentiated_gradient(weights, grad / np.mean(np.sqrt(adagrad_matrix + 1e-8)), step_size) 190 | weights = unnormalized_weights/np.linalg.norm(unnormalized_weights, ord=1) 191 | for i in range(weights.shape[0]): 192 | logging.info(f'adagrad_matrix_{str(i)}: {adagrad_matrix[i]}') 193 | else: 194 | unnormalized_weights = bandit_loop._exponentiated_gradient(weights, grad, step_size) 195 | weights = unnormalized_weights/np.linalg.norm(unnormalized_weights, ord=1) 196 | if step_size_decay: 197 | step_size *= step_size_decay_rate 198 | 199 | 200 | return weights 201 | --------------------------------------------------------------------------------
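For quick reference, below is a minimal, hedged usage sketch for the `sm3` optimizer defined in precondition/sm3.py. The import path, the toy parameter pytree, and the loss function are illustrative assumptions; the only facts taken from the source are that `sm3()` returns a standard `optax.GradientTransformation` and that its updates already include the `-learning_rate` scaling, so they can be applied with `optax.apply_updates`.

# Sketch only: the toy shapes, loss, and import path are assumptions, not part of the library.
import jax
import jax.numpy as jnp
import optax
from precondition import sm3  # assumed import path for precondition/sm3.py

params = {'w': jnp.zeros((4, 3)), 'b': jnp.zeros((3,))}  # toy parameter pytree
tx = sm3.sm3(learning_rate=0.1, beta1=0.9, beta2=0.999)  # an optax.GradientTransformation
opt_state = tx.init(params)

def loss_fn(p, x):
  # Arbitrary toy loss, just to produce gradients for the example.
  return jnp.sum(jnp.square(x @ p['w'] + p['b']))

grads = jax.grad(loss_fn)(params, jnp.ones((2, 4)))
updates, opt_state = tx.update(grads, opt_state, params)  # updates already carry the -lr factor
params = optax.apply_updates(params, updates)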