├── .github └── ISSUE_TEMPLATE │ ├── bug_report.yml │ └── config.yml ├── .gitignore ├── LICENSE ├── README.md ├── example └── 7B.yaml ├── finetune ├── __init__.py ├── args.py ├── checkpointing.py ├── data │ ├── __init__.py │ ├── args.py │ ├── data_loader.py │ ├── dataset.py │ ├── exceptions.py │ └── tokenize.py ├── distributed.py ├── eval.py ├── loss.py ├── mixed_precision.py ├── monitoring │ ├── __init__.py │ ├── metrics_logger.py │ └── utils.py ├── utils.py └── wrapped_model.py ├── model ├── __init__.py ├── args.py ├── lora.py ├── moe.py ├── rope.py └── transformer.py ├── pyproject.toml ├── requirements.dev.txt ├── requirements.txt ├── tests ├── __init__.py ├── test_checkpointer.py ├── test_data.py ├── test_mixed_precision.py ├── test_model.py ├── test_train_loop.py └── test_utils.py ├── train.py ├── tutorials └── mistral_finetune_7b.ipynb └── utils ├── __init__.py ├── extend_model_vocab.py ├── merge_lora.py ├── reformat_data.py ├── reformat_data_glaive.py └── validate_data.py /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report related to mistral-finetune 2 | description: Submit a bug report that's related to mistral-finetune 3 | title: '[BUG: ' 4 | labels: ['bug', 'triage'] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! 10 | - type: textarea 11 | id: python-vv 12 | attributes: 13 | label: Python Version 14 | description: Run `python -VV` from your virtual environment 15 | placeholder: Copy-paste the output (no need for backticks, will be formatted into code automatically) 16 | render: shell 17 | validations: 18 | required: true 19 | - type: textarea 20 | id: pip-freeze 21 | attributes: 22 | label: Pip Freeze 23 | description: Run `pip freeze` from your virtual environment 24 | placeholder: Copy-paste the output (no need for backticks, will be formatted into code automatically) 25 | render: shell 26 | validations: 27 | required: true 28 | - type: textarea 29 | id: reproduction-steps 30 | attributes: 31 | label: Reproduction Steps 32 | description: Provide a clear and concise description of the steps that lead to your issue. Make sure that the steps contain reproducible code snippet. If we cannot reproduce the problem we cannot act upon the issue. 33 | placeholder: | 34 | 1. First step... 35 | 2. Step 2... 36 | ... 37 | validations: 38 | required: true 39 | - type: textarea 40 | id: expected-behavior 41 | attributes: 42 | label: Expected Behavior 43 | description: Explain briefly what you expected to happen. 44 | validations: 45 | required: true 46 | - type: textarea 47 | id: additional-context 48 | attributes: 49 | label: Additional Context 50 | description: Add any context about your problem that you deem relevant. 51 | - type: textarea 52 | id: suggested-solutions 53 | attributes: 54 | label: Suggested Solutions 55 | description: Please list any solutions you recommend we consider. 
56 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Documentation 4 | url: https://docs.mistral.ai 5 | about: Developer documentation for the Mistral AI platform 6 | - name: Discord 7 | url: https://discord.com/invite/mistralai 8 | about: Chat with the Mistral community 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | pre-commit-cache/ 19 | reference_copy 20 | reference 21 | .git_last_commit 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | .pip-cache/ 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Ruff 149 | .ruff_cache 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | .idea/ 166 | 167 | # VSCode 168 | .vscode/ 169 | 170 | # Eval logs 171 | eval/eval-* 172 | 173 | # Symlinks 174 | tests/fixtures 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /example/7B.yaml: -------------------------------------------------------------------------------- 1 | # data 2 | data: 3 | instruct_data: "" # Fill 4 | data: "" # Optionally fill with pretraining data 5 | eval_instruct_data: "" # Optionally fill 6 | 7 | # model 8 | model_id_or_path: "" # Change to downloaded path 9 | lora: 10 | rank: 64 11 | 12 | # optim 13 | seq_len: 32768 14 | batch_size: 1 15 | max_steps: 300 16 | optim: 17 | lr: 6.e-5 18 | weight_decay: 0.1 19 | pct_start: 0.05 20 | 21 | # other 22 | seed: 0 23 | log_freq: 1 24 | eval_freq: 100 25 | no_eval: False 26 | ckpt_freq: 100 27 | 28 | save_adapters: True # save only trained LoRA adapters. Set to `False` to merge LoRA adapter into the base model and save full fine-tuned model 29 | 30 | run_dir: "" # Fill 31 | 32 | wandb: 33 | project: "" # your wandb project name 34 | run_name: "" # your wandb run name 35 | key: "" # your wandb api key 36 | offline: False 37 | -------------------------------------------------------------------------------- /finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-finetune/656df1c94c80ca9703ebc471c9f106c9b7a0bfa7/finetune/__init__.py -------------------------------------------------------------------------------- /finetune/args.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass, field 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | from simple_parsing.helpers import Serializable 8 | 9 | from model.args import LoraArgs 10 | 11 | from .data.args import DataArgs 12 | 13 | 14 | @dataclass 15 | class OptimArgs(Serializable): 16 | lr: float = 1e-4 17 | weight_decay: float = 0.1 18 | pct_start: float = 0.05 19 | 20 | 21 | @dataclass 22 | class WandbArgs(Serializable): 23 | project: Optional[str] = None # Fill this argument to use wandb. 24 | offline: bool = False 25 | key: Optional[str] = None 26 | run_name: Optional[str] = None 27 | 28 | def __post_init__(self) -> None: 29 | if self.project is not None: 30 | try: 31 | import wandb # noqa: F401 32 | except ImportError: 33 | raise ImportError("`wandb` not installed. Either make sure `wandb` is installed or set `wandb:project` to None.") 34 | 35 | if len(self.project) == 0: 36 | raise ValueError("`wandb.project` must not be an empty string.") 37 | 38 | @dataclass 39 | class MLFlowArgs(Serializable): 40 | tracking_uri: Optional[str] = None 41 | experiment_name: Optional[str] = None 42 | 43 | def __post_init__(self) -> None: 44 | if self.tracking_uri is not None: 45 | try: 46 | import mlflow # noqa: F401 47 | except ImportError: 48 | raise ImportError("`mlflow` not installed. Either make sure `mlflow` is installed or set `mlflow.tracking_uri` to None.") 49 | 50 | if self.experiment_name is None: 51 | raise ValueError("If `mlflow.tracking_uri` is set, `mlflow.experiment_name` must be set as well.") 52 | 53 | 54 | 55 | @dataclass 56 | class TrainArgs(Serializable): 57 | data: DataArgs 58 | 59 | # if specified, instruct_tokenizer and model will be loaded 60 | model_id_or_path: str # Path to the directory containing the initial model or model id: "mistral-small" 61 | 62 | run_dir: str # Path to the directory where everything will be saved. It needs to be empty. 63 | # Name of the wandb run, if None it will be set to the name of the run_dir. 
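# Configuration note (comment-form sketch): the top-level keys of example/7B.yaml above (data,
# model_id_or_path, run_dir, optim, seq_len, batch_size, max_steps, wandb, ...) map onto the
# fields of this dataclass, which is loaded from the YAML via simple_parsing's Serializable
# helpers. A typical multi-GPU launch might look like the line below; the exact flags are an
# assumption, so check the repository README:
#   torchrun --nproc-per-node 8 -m train example/7B.yaml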
64 | 65 | optim: OptimArgs = field(default_factory=OptimArgs) 66 | seed: int = 0 67 | # Number of steps to accumulate gradients before doing an optimizer step. 68 | num_microbatches: int = 1 69 | 70 | seq_len: int = 2048 # Number of tokens per batch per device. 71 | batch_size: int = 1 72 | max_norm: float = 1.0 # Gradient clipping. 73 | max_steps: int = 100 # Number of training steps. 74 | log_freq: int = 1 # Number of steps between each logging. 75 | 76 | # Number of steps between each checkpoint saving. If inferior to 1, only the last checkpoint will be saved. 77 | ckpt_freq: int = 0 78 | save_adapters: bool = True 79 | # If True, no checkpoint will be saved. This is useful for development. 80 | no_ckpt: bool = False 81 | num_ckpt_keep: Optional[int] = 3 82 | eval_freq: int = 0 83 | no_eval: bool = True 84 | 85 | # Efficiency 86 | # Determines whether gradient checkpointing should be utilized or not during the training process. Gradient checkpointing can be beneficial in reducing memory usage at the cost of slightly longer training times. 87 | checkpoint: bool = True 88 | 89 | world_size: Optional[int] = field(init=False, default=None) 90 | 91 | # logging 92 | wandb: WandbArgs = field(default_factory=WandbArgs) 93 | mlflow: MLFlowArgs = field(default_factory=MLFlowArgs) 94 | 95 | # LoRA 96 | lora: Optional[LoraArgs] = field(default_factory=LoraArgs) 97 | 98 | def __post_init__(self) -> None: 99 | assert getattr(self, "world_size", None) is None 100 | self.world_size = int(os.environ.get("WORLD_SIZE", -1)) 101 | 102 | if self.wandb.offline: 103 | command = f"cd {self.run_dir}; wandb sync --sync-all" 104 | logging.info(f"to sync wandb offline, run: {command}") 105 | 106 | assert self.num_microbatches >= 1 107 | 108 | assert self.num_ckpt_keep is None or self.num_ckpt_keep >= 1 109 | 110 | if self.model_id_or_path is not None: 111 | Path(self.model_id_or_path).exists() 112 | 113 | if not self.save_adapters: 114 | logging.warning( 115 | "You have disabled `save_adapters` and are thus merging the trained LoRA checkpoint into the base model upon checkpointing. This might lead to OOM errors - make sure you have enough CPU and GPU memory." 
116 | ) 117 | -------------------------------------------------------------------------------- /finetune/checkpointing.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import shutil 4 | from pathlib import Path 5 | from typing import Dict, List, Optional, Union 6 | 7 | import safetensors.torch 8 | import torch 9 | from mistral_common.tokens.tokenizers.sentencepiece import ( 10 | InstructTokenizerBase, 11 | SentencePieceTokenizer, 12 | ) 13 | from torch.distributed import barrier 14 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel 15 | 16 | from model.transformer import LoRALinear 17 | 18 | from .distributed import get_rank, get_world_size 19 | from .utils import TrainState 20 | 21 | logger = logging.getLogger("checkpointing") 22 | 23 | 24 | def main_logger_info(message: str) -> None: 25 | if get_rank() == 0: 26 | logger.info(message) 27 | 28 | 29 | class Checkpointer: 30 | """A class to save PyTorch model and optimizer states""" 31 | 32 | def __init__( 33 | self, 34 | model: FullyShardedDataParallel, 35 | state: TrainState, 36 | run_dir: Union[Path, str], 37 | optimizer: Optional[torch.optim.Optimizer] = None, 38 | num_ckpt_keep: Optional[int] = None, 39 | ): 40 | self.model = model 41 | self.optimizer = optimizer 42 | self.state = state 43 | self.run_dir = Path(run_dir) 44 | self.rank = get_rank() 45 | self.num_ckpt_keep = num_ckpt_keep 46 | 47 | @property 48 | def ckpt_dir(self) -> Path: 49 | return self.run_dir / "checkpoints" 50 | 51 | @property 52 | def dst_dir(self) -> Path: 53 | return self.ckpt_dir / f"checkpoint_{self.state.step:06d}" / "consolidated" 54 | 55 | @staticmethod 56 | def consolidated_path( 57 | ckpt_dir: Path, use_safetensors: bool, save_only_lora: Optional[bool] = False 58 | ) -> Path: 59 | suffix = "safetensors" if use_safetensors else "00.pth" 60 | prefix = "lora" if save_only_lora else "consolidated" 61 | 62 | return ckpt_dir / f"{prefix}.{suffix}" 63 | 64 | @staticmethod 65 | def _tmp(ckpt_dir: Path) -> Path: 66 | return ckpt_dir.with_name(f"tmp.{ckpt_dir.name}") 67 | 68 | def write_params_info(self, tmp_dst: Path): 69 | params_path = tmp_dst / "params.json" 70 | with open(params_path, "w") as f: 71 | model_args = self.model.args.to_dict() 72 | 73 | f.write(json.dumps(model_args, indent=4)) 74 | 75 | def delete_old_ckpts(self) -> List[Path]: 76 | all_saved_ckpts = [d for d in self.ckpt_dir.iterdir() if d.is_dir()] 77 | 78 | # Sort directories by creation time (oldest to newest) 79 | all_saved_ckpts.sort(key=lambda x: x.stat().st_ctime, reverse=True) 80 | 81 | ckpts_to_delete = all_saved_ckpts[self.num_ckpt_keep :] 82 | 83 | for ckpt_to_delete in ckpts_to_delete: 84 | try: 85 | shutil.rmtree(ckpt_to_delete) 86 | main_logger_info(f"Deleted ckpt: {ckpt_to_delete}") 87 | except OSError as e: 88 | main_logger_info(f"Error deleting directory {ckpt_to_delete}: {e}") 89 | 90 | return ckpts_to_delete 91 | 92 | @staticmethod 93 | def get_lora_states(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 94 | return {k: v for k, v in state_dict.items() if "lora" in k} 95 | 96 | @staticmethod 97 | def get_non_lora_states( 98 | state_dict: Dict[str, torch.Tensor], 99 | ) -> Dict[str, torch.Tensor]: 100 | return { 101 | k: v 102 | for k, v in state_dict.items() 103 | if not any(l_key in k for l_key in ["lora", "frozen"]) 104 | } 105 | 106 | @torch.no_grad() 107 | def retrieve_save_states( 108 | self, save_only_lora: bool, save_dtype: torch.dtype 109 | ) 
-> Dict[str, torch.Tensor]: 110 | if save_only_lora: 111 | assert ( 112 | self.model.args.lora.enable 113 | ), "Cannot save LoRA checkpoint as LoRA training is not enabled." 114 | 115 | # remove all potential hooks 116 | for module in self.model.modules(): 117 | if isinstance(module, LoRALinear) and hasattr(module, "_merge_lora_handle"): 118 | module._merge_lora_handle.remove() # type: ignore 119 | 120 | # merge weights if we don't just save LoRA 121 | if not save_only_lora: 122 | 123 | def merge_lora( 124 | m: torch.nn.Module, 125 | destination: Dict[str, torch.Tensor], 126 | prefix: str, 127 | *args, 128 | ): 129 | weight = m.merge_weight() # type: ignore 130 | destination[prefix + "weight"] = weight 131 | 132 | for module in self.model.modules(): 133 | if isinstance(module, LoRALinear): 134 | module._merge_lora_handle = module._register_state_dict_hook( 135 | merge_lora 136 | ) 137 | 138 | offload_to_cpu = get_world_size() > 1 139 | if save_only_lora: 140 | 141 | def is_trainable_fsdp( 142 | module: Union[torch.nn.Module, FullyShardedDataParallel], 143 | ): 144 | is_fsdp = isinstance(module, FullyShardedDataParallel) 145 | all_params_have_grads = is_fsdp and all( 146 | p.requires_grad is True for p in module.parameters() 147 | ) 148 | 149 | # need to make sure only lowest fsdp wrap is used 150 | is_leaf_node = is_fsdp and len(list(module.module.children())) == 0 # type: ignore 151 | 152 | return is_fsdp and all_params_have_grads and is_leaf_node 153 | 154 | # extract all modules with only trainable weights 155 | modules = { 156 | k: m for k, m in self.model.named_modules() if is_trainable_fsdp(m) 157 | } 158 | 159 | states = {} 160 | for key, module in modules.items(): 161 | assert isinstance( 162 | module, FullyShardedDataParallel 163 | ), "`module` should be an instance of `FullyShardedDataParallel`" 164 | parent_prefix = key.replace("_fsdp_wrapped_module.", "").replace( 165 | "_checkpoint_wrapped_module.", "" 166 | ) 167 | with module.summon_full_params( 168 | module, writeback=True, offload_to_cpu=offload_to_cpu 169 | ): 170 | states.update( 171 | { 172 | f"{parent_prefix}.{k}": v.to(dtype=save_dtype) 173 | for k, v in module.state_dict().items() 174 | } 175 | ) 176 | else: 177 | # make sure you have enough CPU RAM available to save the full model 178 | assert isinstance( 179 | self.model, FullyShardedDataParallel 180 | ), "`self.model` should be an instance of `FullyShardedDataParallel`" 181 | with self.model.summon_full_params( 182 | self.model, writeback=True, offload_to_cpu=offload_to_cpu 183 | ): 184 | states = self.get_non_lora_states(self.model.state_dict()) 185 | states = {k: v.to(dtype=save_dtype) for k, v in states.items()} 186 | 187 | states = dict(sorted(states.items())) 188 | return states 189 | 190 | @staticmethod 191 | def save_tokenizer(instruct_tokenizer: InstructTokenizerBase, tmp_dst: Path): 192 | if isinstance(instruct_tokenizer.tokenizer, SentencePieceTokenizer): 193 | serialized_spm = ( 194 | instruct_tokenizer.tokenizer._model.serialized_model_proto() 195 | ) # type: ignore 196 | 197 | tokenizer_path = tmp_dst / "tokenizer.model.v3" 198 | 199 | with open(tokenizer_path, "wb") as f: 200 | f.write(serialized_spm) 201 | else: 202 | path = instruct_tokenizer.tokenizer._path 203 | assert path is not None 204 | shutil.copy(path, tmp_dst / "tekken.json") 205 | 206 | @torch.no_grad() 207 | def save_checkpoint( 208 | self, 209 | save_only_lora: bool, 210 | dtype: torch.dtype = torch.float16, 211 | instruct_tokenizer: Optional[InstructTokenizerBase] = None, 212 | ): 213 
| tmp_dst = self._tmp(self.dst_dir) 214 | main_logger_info( 215 | f"Dumping checkpoint in {self.dst_dir} using tmp name: {tmp_dst.name}" 216 | ) 217 | 218 | assert not self.dst_dir.exists(), f"dst exists {self.dst_dir}" 219 | tmp_dst.mkdir(parents=True, exist_ok=True) 220 | 221 | states: Dict[str, torch.Tensor] = self.retrieve_save_states( 222 | save_only_lora, dtype 223 | ) 224 | 225 | barrier() 226 | 227 | if self.rank == 0: 228 | # save checkpoint in tmp path 229 | safetensors.torch.save_file( 230 | states, 231 | self.consolidated_path( 232 | tmp_dst, use_safetensors=True, save_only_lora=save_only_lora 233 | ), # always use safetensors for checkpointing 234 | ) 235 | 236 | self.write_params_info(tmp_dst) 237 | 238 | # save tokenizer 239 | if instruct_tokenizer is not None: 240 | self.save_tokenizer(instruct_tokenizer, tmp_dst) 241 | 242 | assert not self.dst_dir.exists(), f"should not happen! {self.dst_dir}" 243 | tmp_dst.rename(self.dst_dir) 244 | 245 | logger.info( 246 | f"Done dumping checkpoint in {self.dst_dir} for step: {self.state.step}" 247 | ) 248 | 249 | # delete last n checkpoints 250 | if self.num_ckpt_keep is not None: 251 | ckpts_to_delete = self.delete_old_ckpts() 252 | logger.info( 253 | f"Done deleting checkpoints {', '.join([str(c) for c in ckpts_to_delete])}" 254 | ) 255 | 256 | main_logger_info("Done!") 257 | -------------------------------------------------------------------------------- /finetune/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-finetune/656df1c94c80ca9703ebc471c9f106c9b7a0bfa7/finetune/data/__init__.py -------------------------------------------------------------------------------- /finetune/data/args.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | 4 | from simple_parsing.helpers import Serializable 5 | 6 | logger = logging.getLogger("data") 7 | 8 | 9 | @dataclass() 10 | class InstructArgs(Serializable): 11 | shuffle: bool = True 12 | 13 | # For function calling training examples only the last tool call 14 | # of the assistant message can be used for training. Therefore, 15 | # we chunk longer function calling conversations into multiple 16 | # training samples to not lose any data. E.g.: 17 | # [[ 18 | # UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1 19 | # UserMessage_2, AssistantToolCallMessage_2, ToolMessage_2, AssistantMessage_2 20 | # ]] 21 | # => is chunked into two training samples: 22 | # [[ 23 | # UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1 24 | # ], 25 | # [ 26 | # UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1 27 | # UserMessage_2, AssistantToolCallMessage_2, ToolMessage_2, AssistantMessage_2 28 | # ]] 29 | # NOTE: Only if your data is already pre-chunked should this argument be set to False 30 | dynamic_chunk_fn_call: bool = True 31 | 32 | 33 | @dataclass() 34 | class DataArgs(Serializable): 35 | # The data arguments `data` and `instruct_data` are a string in the format 36 | # "data_source_dir_1:weight_1,data_source_dir_2:weight_2,...". The weight 37 | # will be used to sample the data sources. If the sum of the weights is 38 | # not 1 when concatenating the two arguments `data` and `instruct_data`, 39 | # it will be normalized. The data sources folders must contain jsonl files. 
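# For example (hypothetical paths), instruct_data="/data/chat_a:0.7,/data/chat_b:0.3" draws
# training sequences from the first source with probability 0.7 and from the second with 0.3
# after the weights are normalized.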
40 | # If the value is an empty string, no data will be used for the corresponding 41 | # data type. 42 | data: str = ( 43 | "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "text" key. See Readme for more details. Can be left empty. 44 | ) 45 | shuffle: bool = False 46 | instruct_data: str = ( 47 | "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "interactions" key. See Readme for more details. Can be left empty. 48 | ) 49 | eval_instruct_data: str = ( 50 | "" # Each line in the jsonl files inside the data source directories must be a dictionary with a "interactions" key. See Readme for more details. Can be left empty. 51 | ) 52 | instruct: InstructArgs = field(default_factory=InstructArgs) 53 | 54 | def __post_init__(self) -> None: 55 | if ( 56 | self.instruct.shuffle is False 57 | and self.instruct.dynamic_chunk_fn_call is True 58 | ): 59 | raise ValueError( 60 | "Make sure to either enable `data.instruct.shuffle=True` or `data.instruct.dynamic_chunk_fn_call=False`. Dynamic chunking is only possible if data is loaded and shuffled before training." 61 | ) 62 | -------------------------------------------------------------------------------- /finetune/data/data_loader.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any, Iterator, List, Optional 3 | 4 | import numpy as np 5 | from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase 6 | 7 | from .args import DataArgs 8 | from .dataset import build_dataset 9 | 10 | 11 | @dataclasses.dataclass 12 | class Batch: 13 | x: np.ndarray 14 | y: np.ndarray 15 | sizes: List[int] 16 | y_mask: Optional[np.ndarray] = None 17 | is_pad_only: bool = False 18 | 19 | def __post_init__(self): 20 | assert self.x.ndim == 1 21 | assert self.x.shape == self.y.shape 22 | assert self.x.dtype == np.int64 23 | assert self.y.dtype == np.int64 24 | assert isinstance(self.sizes, list) 25 | assert sum(self.sizes) == self.x.size == self.y.size 26 | 27 | if self.y_mask is not None: 28 | assert self.y_mask.size == self.y.size, (self.y_mask.shape, self.y.shape) 29 | assert self.y_mask.dtype == bool 30 | assert sum(self.sizes) == self.y_mask.size 31 | assert not self.y_mask.all() 32 | assert self.y_mask.any() 33 | 34 | if self.is_pad_only: 35 | assert np.sum(np.abs(self.y)) == 0 36 | assert np.sum(np.abs(self.x)) == 0 37 | assert self.y_mask is None 38 | # create all 0's mask for pad samples 39 | self.y_mask = np.zeros_like(self.x) 40 | 41 | 42 | 43 | 44 | @dataclasses.dataclass 45 | class BatchList: 46 | x: List[List[int]] = dataclasses.field(default_factory=list) 47 | y: List[List[int]] = dataclasses.field(default_factory=list) 48 | sizes: List[List[int]] = dataclasses.field(default_factory=list) 49 | y_mask: List[List[bool]] = dataclasses.field(default_factory=list) 50 | 51 | def __post_init__(self): 52 | assert self.x == [], "`BatchList` has to be empty at init." 53 | assert self.y == [], "`BatchList` has to be empty at init." 54 | assert self.sizes == [], "`BatchList` has to be empty at init." 55 | assert self.y_mask == [], "`BatchList` has to be empty at init." 
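# A small worked example with hypothetical token ids: add(x=[1, 2], y=[2, 3], sizes=[2],
# y_mask=[True, True]) followed by add(x=[4], y=[5], sizes=[1], y_mask=[True]) and then
# create_batch() produces Batch(x=[1, 2, 4], y=[2, 3, 5], sizes=[2, 1]) with y_mask=None,
# since the mask is dropped whenever every position is True.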
56 | 57 | def __len__(self) -> int: 58 | return len(self.x) 59 | 60 | def add(self, x: List[int], y: List[int], sizes: List[int], y_mask: List[bool]): 61 | self.x.append(x) 62 | self.y.append(y) 63 | self.sizes.append(sizes) 64 | self.y_mask.append(y_mask) 65 | 66 | def empty(self): 67 | self.x = [] 68 | self.y = [] 69 | self.sizes = [] 70 | self.y_mask = [] 71 | 72 | @staticmethod 73 | def flatten_to_numpy(list_of_lists: List[List[Any]], dtype: type) -> np.ndarray: 74 | return np.array([el for sublist in list_of_lists for el in sublist], dtype=dtype) 75 | 76 | def create_batch(self) -> Batch: 77 | x_np: np.ndarray = self.flatten_to_numpy(self.x, dtype=np.int64) 78 | y_np: np.ndarray = self.flatten_to_numpy(self.y, dtype=np.int64) 79 | sizes = sum(self.sizes, []) # noqa 80 | 81 | y_mask_flatten = self.flatten_to_numpy(self.y_mask, dtype=bool) 82 | y_mask_np: Optional[np.ndarray] = None if y_mask_flatten.all() else y_mask_flatten 83 | 84 | return Batch(x_np, y_np, sizes, y_mask_np) 85 | 86 | 87 | 88 | 89 | def build_data_loader( 90 | instruct_tokenizer: InstructTokenizerBase, 91 | args: DataArgs, 92 | batch_size: int, 93 | seq_len: int, 94 | seed: Optional[int], 95 | rank: int, 96 | world_size: int, 97 | is_eval: bool, 98 | ) -> Iterator[Batch]: 99 | pretrain_data = args.data if not is_eval else "" 100 | instruct_data = args.instruct_data if not is_eval else args.eval_instruct_data 101 | 102 | dataset = build_dataset( 103 | pretrain_data=pretrain_data, 104 | instruct_data=instruct_data, 105 | instruct_args=args.instruct, 106 | instruct_tokenizer=instruct_tokenizer, 107 | seq_len=seq_len, 108 | seed=seed, 109 | rank=rank, 110 | world_size=world_size, 111 | is_eval=is_eval, 112 | shuffle_pretrain=args.shuffle, 113 | ) 114 | 115 | batch_list = BatchList() 116 | for sample in dataset: 117 | assert all(s >= 0 for s in sample.sizes) 118 | 119 | batch_list.add(sample.x, sample.y, sample.sizes, sample.mask) 120 | 121 | if len(batch_list) == batch_size: 122 | batch: Batch = batch_list.create_batch() 123 | yield batch 124 | 125 | batch_list.empty() 126 | 127 | -------------------------------------------------------------------------------- /finetune/data/dataset.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import itertools 3 | import json 4 | import logging 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union 8 | 9 | import numpy as np 10 | import torch.distributed as dist 11 | from mistral_common.protocol.instruct.messages import ( 12 | FinetuningAssistantMessage, 13 | SystemMessage, 14 | ) 15 | from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase 16 | 17 | from finetune.distributed import get_rank 18 | 19 | from .args import InstructArgs 20 | from .tokenize import ( 21 | Mask, 22 | SampleType, 23 | TokenSample, 24 | TrainingInstructSample, 25 | build_instruct_sample, 26 | encode, 27 | ) 28 | 29 | logger = logging.getLogger("dataset") 30 | 31 | 32 | _LOADED_DATASETS: Dict[Path, List[TokenSample]] = {} 33 | 34 | 35 | def main_logger_info(message: str) -> None: 36 | if dist.is_initialized() and get_rank() == 0: 37 | logger.info(message) 38 | 39 | 40 | def load_file(path: Path, world_size: int, rank: int) -> List[str]: 41 | lines = [] 42 | with path.open() as f: 43 | for idx, line in enumerate(f): 44 | if not idx % world_size == rank: 45 | continue 46 | lines.append(line) 47 | return lines 48 | 49 | 50 | def 
maybe_load_local_dataset( 51 | path: Path, chunk: bool, rank: int, world_size: int, instruct_tokenizer: InstructTokenizerBase, sample_type: SampleType 52 | ) -> List[TokenSample]: 53 | global _LOADED_DATASETS 54 | 55 | if path in _LOADED_DATASETS: 56 | return _LOADED_DATASETS[path] 57 | 58 | main_logger_info(f"Loading {path} ...") 59 | lines: List[str] = load_file(path, rank=rank, world_size=world_size) 60 | 61 | if chunk: 62 | lines += maybe_chunk_lines(lines) 63 | 64 | tokens_list: List[TokenSample] = [] 65 | for line in lines: 66 | data = json.loads(line) 67 | 68 | token_sample: TokenSample = encode( 69 | data, 70 | instruct_tokenizer=instruct_tokenizer, 71 | as_type=sample_type, 72 | ) 73 | tokens_list.append(token_sample) 74 | 75 | main_logger_info(f"{path} loaded and tokenized.") 76 | _LOADED_DATASETS[path] = tokens_list 77 | 78 | return _LOADED_DATASETS[path] 79 | 80 | 81 | @dataclass 82 | class DataDir: 83 | path: Path 84 | sample_type: SampleType 85 | 86 | @property 87 | def jsonl_files(self): 88 | assert self.path.exists(), f"Make sure that {self.path} exists" 89 | jsonl_files = list(self.path.rglob("*jsonl")) 90 | assert ( 91 | len(jsonl_files) > 0 92 | ), f"{self.path} does not seem to have any files ending with '.jsonl'" 93 | return jsonl_files 94 | 95 | 96 | @dataclass 97 | class DataFile: 98 | path: Path 99 | sample_type: SampleType 100 | 101 | @property 102 | def jsonl_files(self): 103 | assert self.path.exists(), f"Make sure that {self.path} exists" 104 | return [self.path] 105 | 106 | 107 | def parse_data_sources( 108 | pretrain_data: str, 109 | instruct_data: str, 110 | ) -> Tuple[List[Union[DataDir, DataFile]], List[float]]: 111 | seen: Set[str] = set() 112 | sources: List[Union[DataDir, DataFile]] = [] 113 | weights: List[float] = [] 114 | for sample_sources, sample_type in [ 115 | (pretrain_data, SampleType.PRETRAIN), 116 | (instruct_data, SampleType.INSTRUCT), 117 | ]: 118 | for source in sample_sources.strip().split(","): 119 | if not source: 120 | continue 121 | 122 | source_items = source.strip().split(":") 123 | if len(source_items) == 1: 124 | path_ = source_items[0] 125 | weight = 1.0 126 | elif len(source_items) == 2: 127 | path_, weight_ = source_items 128 | weight = float(weight_) 129 | else: 130 | raise ValueError( 131 | f"{source} is not correctly formatted. Make sure to format each data source as : or just " 132 | ) 133 | 134 | assert ( 135 | path_ not in seen 136 | ), f"{path_} seems to be duplicated. Make sure to only add it once." 137 | assert ( 138 | weight > 0 139 | ), f"Make sure to define strictly positive data sampling weights, not {weight}" 140 | 141 | data: Union[DataDir, DataFile] 142 | if Path(path_).is_dir(): 143 | data = DataDir(path=Path(path_), sample_type=sample_type) 144 | elif Path(path_).is_file(): 145 | data = DataFile(path=Path(path_), sample_type=sample_type) 146 | else: 147 | raise FileNotFoundError( 148 | f"The path {path_} does not exist. Make sure {path_} is either a file or directory that contains training data." 149 | ) 150 | 151 | sources.append(data) 152 | weights.append(weight) 153 | 154 | seen.add(path_) 155 | 156 | sum_weights = sum(weights) 157 | n_weights = [weight / sum_weights for weight in weights] 158 | 159 | assert min(n_weights) > 0 160 | assert ( 161 | abs(1 - sum(n_weights)) < 1e-8 162 | ), f"Defined data sampling weights {weights} must sum to 1." 
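# For instance, raw weights of 2.0 and 6.0 are normalized to 0.25 and 0.75 above, so this
# check holds whenever every weight is strictly positive.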
163 | return sources, n_weights 164 | 165 | 166 | @dataclasses.dataclass() 167 | class SequenceMaskAndSizes: 168 | """ 169 | Concatenation of samples to reach a given size 170 | """ 171 | 172 | x: List[int] 173 | y: List[int] 174 | mask: Mask 175 | sizes: List[int] 176 | 177 | def __post_init__(self): 178 | assert sum(self.sizes) == len(self.x) == len(self.y) == len(self.mask) 179 | 180 | 181 | def sequence_iterator( 182 | ds_it: Iterator[TokenSample], 183 | seq_len: int, 184 | is_finite: bool, 185 | ) -> Iterator[SequenceMaskAndSizes]: 186 | """ 187 | Creates sequences of length `seq_len` from the dataset iterator by concatenating samples. 188 | """ 189 | x_buffer: List[int] = [] 190 | y_buffer: List[int] = [] 191 | mask_buffer: Mask = [] 192 | 193 | sizes: List[int] = [] 194 | n_missing = seq_len 195 | for sample in ds_it: 196 | assert 0 <= len(x_buffer) < seq_len, len(x_buffer) 197 | assert n_missing == seq_len - len( 198 | x_buffer 199 | ), f"n_missing: {n_missing} | seq_len - len(x_buffer) {seq_len - len(x_buffer)}" 200 | 201 | tokens, mask = sample.tokens, sample.masks[1:] 202 | x, y = tokens[:-1], tokens[1:] 203 | cur_pos = 0 204 | 205 | while cur_pos < len(x): 206 | size = len(x[cur_pos : cur_pos + n_missing]) 207 | 208 | curr_mask = mask[cur_pos : cur_pos + n_missing] 209 | if not any(curr_mask): 210 | cur_pos += size 211 | # we have a sequence with a mask filled with False 212 | continue 213 | 214 | x_buffer.extend(x[cur_pos : cur_pos + n_missing]) 215 | y_buffer.extend(y[cur_pos : cur_pos + n_missing]) 216 | mask_buffer.extend(curr_mask) 217 | n_missing -= size 218 | sizes.append(size) 219 | 220 | cur_pos += size 221 | 222 | if n_missing == 0: 223 | assert len(mask_buffer) == len(x_buffer) == seq_len == len(y_buffer) 224 | assert sum(sizes) == seq_len 225 | # we don't want to yield sequences with a mask filled with False 226 | if any(mask_buffer): 227 | yield SequenceMaskAndSizes( 228 | x=x_buffer, 229 | y=y_buffer, 230 | mask=mask_buffer, 231 | sizes=sizes, 232 | ) 233 | x_buffer, y_buffer = [], [] 234 | mask_buffer = [] 235 | sizes = [] 236 | n_missing = seq_len 237 | 238 | if is_finite: 239 | # if dataloader is in eval, pad to seq length 240 | if any(mask_buffer): 241 | mask_buffer.extend(n_missing * [False]) 242 | x_buffer.extend(n_missing * [0]) 243 | y_buffer.extend(n_missing * [0]) 244 | sizes.append(n_missing) 245 | 246 | yield SequenceMaskAndSizes( 247 | x=x_buffer, 248 | y=y_buffer, 249 | mask=mask_buffer, 250 | sizes=sizes, 251 | ) 252 | 253 | 254 | def build_dataset( 255 | pretrain_data: str, 256 | instruct_data: str, 257 | instruct_args: InstructArgs, 258 | instruct_tokenizer: InstructTokenizerBase, 259 | seq_len: int, 260 | seed: Optional[int], 261 | rank: int, 262 | world_size: int, 263 | is_eval: bool, 264 | shuffle_pretrain: bool = False, 265 | ) -> Iterator[SequenceMaskAndSizes]: 266 | sources, probabilities = parse_data_sources( 267 | pretrain_data=pretrain_data, instruct_data=instruct_data 268 | ) 269 | 270 | def do_shuffle(source: Union[DataDir, DataFile]) -> bool: 271 | shuffle = { 272 | SampleType.PRETRAIN: shuffle_pretrain, 273 | SampleType.INSTRUCT: instruct_args.shuffle, 274 | }[source.sample_type] 275 | 276 | return not is_eval and shuffle 277 | 278 | dataset_iterators = [ 279 | get_dataset_iterator( 280 | source, 281 | instruct_args=instruct_args, 282 | instruct_tokenizer=instruct_tokenizer, 283 | rank=rank, 284 | world_size=world_size, 285 | is_finite=is_eval, 286 | seed=seed, 287 | shuffle_at_epoch=do_shuffle(source), 288 | ) 289 | for source in 
sources 290 | ] 291 | 292 | sequence_iterators = [ 293 | sequence_iterator( 294 | ds_it=it, 295 | seq_len=seq_len, 296 | is_finite=is_eval, 297 | ) 298 | for it in dataset_iterators 299 | ] 300 | 301 | if is_eval: 302 | combined_iterator = itertools.chain.from_iterable(sequence_iterators) 303 | else: 304 | # make sure random_seed is different per rank and original seed 305 | random_seed = np.array((seed, rank)) 306 | rng = np.random.RandomState(seed=random_seed) 307 | combined_iterator = interleave_iterators( 308 | sequence_iterators, probabilities=probabilities, rng=rng 309 | ) 310 | 311 | return combined_iterator 312 | 313 | 314 | def get_rng(seed: int, rank: int) -> np.random.RandomState: 315 | random_seed = np.array((seed, rank)) 316 | rng = np.random.RandomState(seed=random_seed) 317 | return rng 318 | 319 | 320 | def get_dataset_iterator( 321 | source: Union[DataDir, DataFile], 322 | instruct_args: InstructArgs, 323 | instruct_tokenizer: InstructTokenizerBase, 324 | rank: int, 325 | world_size: int, 326 | is_finite: bool, 327 | seed: Optional[int], 328 | shuffle_at_epoch: bool, 329 | ) -> Iterator[TokenSample]: 330 | jsonl_files = source.jsonl_files 331 | rng: Optional[np.random.RandomState] = ( 332 | get_rng(seed, rank) if seed is not None else None 333 | ) 334 | 335 | chunk_dataset = ( 336 | instruct_args.dynamic_chunk_fn_call 337 | and source.sample_type == SampleType.INSTRUCT 338 | ) 339 | 340 | if not is_finite: 341 | # train mode 342 | while True: 343 | for jsonl_file in jsonl_files: 344 | if shuffle_at_epoch: 345 | assert rng is not None, "`seed` has to be passed when shuffling" 346 | # will preload all data into RAM, shuffle and yield 347 | yield from preload_and_yield( 348 | jsonl_file, 349 | chunk_dataset=chunk_dataset, 350 | rank=rank, 351 | world_size=world_size, 352 | rng=rng, 353 | instruct_tokenizer=instruct_tokenizer, 354 | sample_type=source.sample_type, 355 | ) 356 | else: 357 | # will read data on-the-fly and yield 358 | main_logger_info(f"Lazily loading {jsonl_file} ...") 359 | yield from lazy_load_and_yield( 360 | jsonl_file, 361 | rank=rank, 362 | world_size=world_size, 363 | instruct_tokenizer=instruct_tokenizer, 364 | sample_type=source.sample_type, 365 | ) 366 | else: 367 | # eval mode 368 | for jsonl_file in jsonl_files: 369 | # No need to shuffle for eval 370 | yield from lazy_load_and_yield( 371 | jsonl_file, 372 | rank=rank, 373 | world_size=world_size, 374 | instruct_tokenizer=instruct_tokenizer, 375 | sample_type=source.sample_type, 376 | ) 377 | 378 | 379 | def preload_and_yield( 380 | jsonl_file: Path, 381 | chunk_dataset: bool, 382 | rank: int, 383 | world_size: int, 384 | rng: np.random.RandomState, 385 | instruct_tokenizer: InstructTokenizerBase, 386 | sample_type: SampleType, 387 | ) -> Iterator[TokenSample]: 388 | # only instruct data has to be chunked 389 | # load dataset if not already loaded. Make sure to only load 1/world_size dataset 390 | tokens_list = maybe_load_local_dataset( 391 | jsonl_file, chunk=chunk_dataset, rank=rank, world_size=world_size, instruct_tokenizer=instruct_tokenizer, sample_type=sample_type 392 | ) 393 | 394 | if sample_type == SampleType.PRETRAIN: 395 | assert chunk_dataset is False, "Pretrain data should not have chunking enabled." 
396 | 397 | main_logger_info(f"Shuffling {jsonl_file} ...") 398 | rng.shuffle(tokens_list) # type: ignore 399 | 400 | for token_sample in tokens_list: 401 | yield token_sample 402 | 403 | def lazy_load_and_yield( 404 | jsonl_file: Path, 405 | rank: int, 406 | world_size: int, 407 | instruct_tokenizer: InstructTokenizerBase, 408 | sample_type: SampleType, 409 | ): 410 | with jsonl_file.open() as file_handle: 411 | for idx, line in enumerate(file_handle): 412 | if not idx % world_size == rank: 413 | continue 414 | 415 | data = json.loads(line) 416 | yield encode( 417 | data, 418 | instruct_tokenizer=instruct_tokenizer, 419 | as_type=sample_type, 420 | ) 421 | 422 | 423 | def maybe_chunk_lines(lines: List[str]) -> List[str]: 424 | extra_lines: List[str] = [] 425 | for line in lines: 426 | data = json.loads(line) 427 | # multi-turn fn call data will be chunked and shorter conversations are added additionally 428 | maybe_chunked_lines = maybe_chunk_data(data) 429 | extra_lines.extend([json.dumps(line) for line in maybe_chunked_lines]) 430 | 431 | return extra_lines 432 | 433 | 434 | def maybe_chunk_data(data: Dict[str, Any]) -> List[Dict[str, Any]]: 435 | # think about always allowing both open-ai and non-open-ai data 436 | sample = build_instruct_sample(data) 437 | 438 | def num_assistant_messages(sample: TrainingInstructSample) -> int: 439 | return len( 440 | [m for m in sample.messages if isinstance(m, FinetuningAssistantMessage)] 441 | ) 442 | 443 | chunk_data = [] 444 | while sample.only_last is True and num_assistant_messages(sample) > 1: 445 | assert sample == build_instruct_sample(sample.dict()) 446 | last_message = sample.messages.pop() 447 | 448 | # 1. First pop until and including last assistant message 449 | system_message = None 450 | while not isinstance(last_message, FinetuningAssistantMessage): 451 | last_message = sample.messages.pop() 452 | if isinstance(last_message, SystemMessage): 453 | system_message = last_message 454 | 455 | # 2. Second pop until and excluding last assistant message 456 | prev_last_message = sample.messages[-1] 457 | while not isinstance(prev_last_message, FinetuningAssistantMessage): 458 | last_message = sample.messages.pop() 459 | if isinstance(last_message, SystemMessage): 460 | system_message = last_message 461 | 462 | prev_last_message = sample.messages[-1] 463 | 464 | # if system_message is not None, append again 465 | if system_message is not None: 466 | sample.messages.append(system_message) 467 | chunk_data.append(sample.dict()) 468 | 469 | return chunk_data 470 | 471 | 472 | def interleave_iterators(iterators: List[Iterator], probabilities, rng): 473 | while True: 474 | it_id = rng.choice(range(len(iterators)), p=probabilities) 475 | yield next(iterators[it_id]) 476 | -------------------------------------------------------------------------------- /finetune/data/exceptions.py: -------------------------------------------------------------------------------- 1 | class MessageFormatError(Exception): 2 | def __init__(self, message, data): 3 | self._message = message 4 | self._begin_data = data[:20] 5 | super().__init__() 6 | 7 | def __str__(self): 8 | return f"A message starting with {self._begin_data} is incorrectly formatted." 
+ self._message 9 | 10 | 11 | class ToolCallFormatError(Exception): 12 | def __init__(self, message, data): 13 | self._message = message 14 | self._begin_data = data[:20] 15 | super().__init__() 16 | 17 | def __str__(self): 18 | return f"A tool call assistant message starting with {self._begin_data} of the conversation is incorrectly formatted. " + self._message 19 | 20 | 21 | class FunctionFormatError(Exception): 22 | def __init__(self, message, data): 23 | self._message = message 24 | self._begin_data = data[:20] 25 | super().__init__() 26 | 27 | def __str__(self): 28 | return ( 29 | f"A function of the conversation starting with {self._begin_data} is incorrectly formatted. " 30 | + self._message 31 | ) 32 | 33 | 34 | class ConversationFormatError(Exception): 35 | def __init__(self, message, data): 36 | self._message = message 37 | self._begin_data = data[:20] 38 | super().__init__() 39 | 40 | def __str__(self): 41 | return ( 42 | f"A conversation starting with {self._begin_data} is incorrectly formatted. " + self._message 43 | ) 44 | 45 | 46 | class UnrecognizedRoleError(Exception): 47 | def __init__(self, role, allowed_roles): 48 | self._role = role 49 | self._allowed_roles = allowed_roles 50 | super().__init__() 51 | 52 | def __str__(self): 53 | return ( 54 | f"The following role: {self._role} is not recognized in line: {self.line} of the dataset {self.dataset}. Make sure that each role is one of {self._allowed_roles}" 55 | + self._message 56 | ) 57 | -------------------------------------------------------------------------------- /finetune/data/tokenize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | from mistral_common.protocol.instruct.messages import ( 7 | FinetuningAssistantMessage, 8 | Roles, 9 | SystemMessage, 10 | ToolMessage, 11 | UserMessage, 12 | ) 13 | from mistral_common.protocol.instruct.tool_calls import ( 14 | Function, 15 | FunctionCall, 16 | Tool, 17 | ToolCall, 18 | ) 19 | from mistral_common.protocol.instruct.validator import ( 20 | MistralRequestValidatorV3, 21 | ValidationMode, 22 | ) 23 | from mistral_common.tokens.instruct.request import InstructRequest 24 | from mistral_common.tokens.tokenizers.base import Tokenizer 25 | from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase 26 | 27 | from .exceptions import ( 28 | ConversationFormatError, 29 | FunctionFormatError, 30 | MessageFormatError, 31 | ToolCallFormatError, 32 | UnrecognizedRoleError, 33 | ) 34 | 35 | logger = logging.getLogger("tokenize") 36 | 37 | Sequence = List[int] 38 | Mask = List[bool] 39 | 40 | 41 | class TrainingInstructSample(InstructRequest): 42 | available_tools: Optional[List[Tool]] = None 43 | only_last: bool = False 44 | 45 | 46 | @dataclass() 47 | class TokenSample: 48 | tokens: Sequence 49 | masks: Mask 50 | 51 | 52 | class SampleType(str, Enum): 53 | PRETRAIN = "pretrain" 54 | INSTRUCT = "instruct" 55 | 56 | 57 | def encode( 58 | data: Dict[str, Any], 59 | instruct_tokenizer: InstructTokenizerBase, 60 | as_type: SampleType, 61 | ) -> TokenSample: 62 | sample: Union[str, TrainingInstructSample] 63 | if as_type == SampleType.PRETRAIN: 64 | sample = get_pretrain_sample(data) 65 | elif as_type == SampleType.INSTRUCT: 66 | sample = build_instruct_sample(data) 67 | 68 | return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer) 69 | 70 | 71 | def get_pretrain_sample(data: 
Dict[str, Any]) -> str: 72 | content_keys = ["text", "content"] 73 | assert not all( 74 | k in data for k in content_keys 75 | ), "Make sure to have either 'text' or 'content' in your data. Not both." 76 | assert any( 77 | data.get(k) is not None for k in content_keys 78 | ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}" 79 | 80 | # get first non-None value 81 | sample = None 82 | for key in content_keys: 83 | sample = data[key] if key in data else sample 84 | 85 | assert isinstance(sample, str), sample 86 | 87 | return sample 88 | 89 | 90 | def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample: 91 | messages: List[Union[SystemMessage, UserMessage, FinetuningAssistantMessage, ToolMessage]] = [] 92 | # optional data fields that might be set 93 | available_tools: Optional[List[Tool]] = data.get("available_tools") 94 | system_prompt = data.get("system_prompt") 95 | 96 | messages_keys = ["messages", "interactions"] 97 | content_keys = ["content", "text"] # both are accepted 98 | allowed_roles = [role.value for role in Roles] 99 | 100 | if not any(messages_key in data for messages_key in messages_keys): 101 | err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'." 102 | raise ConversationFormatError(err, str(data)) 103 | 104 | if all(messages_key in data for messages_key in messages_keys): 105 | err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two." 106 | raise ConversationFormatError(err, str(data)) 107 | 108 | # get first non-None value 109 | data_messages: Optional[List[Dict[str, Any]]] = None 110 | for key in messages_keys: 111 | data_messages = data[key] if key in data else data_messages 112 | 113 | assert data_messages is not None, "data_messages can't be None" 114 | 115 | if "available_tools" in data and "tools" in data: 116 | err = "The conversation contains both an `available_tools` and `tools` key. You can only have one." 117 | raise ConversationFormatError(err, str(data)) 118 | 119 | if data.get("tools", None) is not None and len(data["tools"]) > 0: 120 | available_tools = _parse_available_tools(data["tools"]) 121 | elif ( 122 | data.get("available_tools", None) is not None 123 | and len(data["available_tools"]) > 0 124 | ): 125 | available_tools = _parse_available_tools(data["available_tools"]) 126 | 127 | for data_message in data_messages: 128 | is_tool_call = data_message.get("tool_calls") is not None 129 | 130 | if "role" not in data_message: 131 | err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'." 132 | raise MessageFormatError(err, str(data)) 133 | 134 | role = data_message["role"] 135 | 136 | if all(key in data_message for key in content_keys): 137 | err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two." 138 | raise MessageFormatError(err, str(data)) 139 | 140 | content: Optional[str] = None 141 | for key in content_keys: 142 | content = content if content is not None else data_message.get(key) 143 | 144 | # non-function call message must have content 145 | if not is_tool_call and content is None: 146 | err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys." 
147 | raise MessageFormatError(err, str(data)) 148 | 149 | if role not in allowed_roles: 150 | raise UnrecognizedRoleError(role, allowed_roles) 151 | 152 | if data_message["role"] == "user": 153 | assert content is not None 154 | messages.append(UserMessage(content=content)) 155 | elif data_message["role"] == "assistant": 156 | tool_calls: Optional[List[ToolCall]] = None 157 | 158 | if is_tool_call: 159 | tool_calls = _parse_tool_calls(data_message["tool_calls"]) 160 | 161 | weight = data_message.get("weight") 162 | messages.append( 163 | FinetuningAssistantMessage( 164 | content=content, tool_calls=tool_calls, weight=weight 165 | ) 166 | ) 167 | elif data_message["role"] == "system": 168 | if system_prompt is not None: 169 | err = "Multiple messages with role 'system' encountered. Only one is allowed." 170 | raise MessageFormatError(err, str(data)) 171 | 172 | system_prompt = content 173 | elif data_message["role"] == "tool": 174 | assert content is not None 175 | tool_message = _parse_tool_message(content, data_message) 176 | messages.append(tool_message) 177 | 178 | # validate created messages 179 | validator = MistralRequestValidatorV3(ValidationMode.finetuning) 180 | validator.validate_messages(messages) 181 | validator._validate_tools(available_tools or []) 182 | 183 | # whether to train only on last assistant message 184 | only_last = data.get("only_last", False) or available_tools is not None 185 | 186 | return TrainingInstructSample( 187 | messages=messages, 188 | system_prompt=system_prompt, 189 | available_tools=available_tools, 190 | only_last=only_last, 191 | ) 192 | 193 | 194 | def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]: 195 | available_tools = [] 196 | for tool in tools: 197 | if "function" not in tool: 198 | raise FunctionFormatError( 199 | "A tool dict does not have a 'function' key.", str(tool) 200 | ) 201 | 202 | func_data = tool["function"] 203 | 204 | for key in ["name", "description", "parameters"]: 205 | if key not in func_data: 206 | raise FunctionFormatError( 207 | f"A function dict does not have a {key} key.", str(func_data) 208 | ) 209 | 210 | if not isinstance(func_data["parameters"], dict): 211 | raise FunctionFormatError( 212 | f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. 
If the function has no parameters pass an empty dict ", str(func_data) 213 | ) 214 | 215 | description = func_data["description"] 216 | function = Function( 217 | name=func_data["name"], 218 | description=description, 219 | parameters=func_data["parameters"], 220 | ) 221 | 222 | available_tools.append(Tool(function=function)) 223 | return available_tools 224 | 225 | 226 | def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]: 227 | for key in ["id", "function"]: 228 | if not all(key in call for call in calls): 229 | err = f"A tool call of an assistant message does not have a {key} key" 230 | raise ToolCallFormatError(err, str(calls)) 231 | 232 | for key in ["name", "arguments"]: 233 | if not all(key in call["function"] for call in calls): 234 | err = ( 235 | f"A tool call function of an assistant message does not have a {key} key" 236 | ) 237 | raise ToolCallFormatError(err, str(calls)) 238 | 239 | if not all(isinstance(call["function"]["arguments"], str) for call in calls): 240 | err = "A tool call function of an assistant message does not have a 'arguments' key of type str" 241 | raise ToolCallFormatError(err, str(calls)) 242 | 243 | tool_calls = [ 244 | ToolCall( 245 | id=call["id"], 246 | function=FunctionCall( 247 | name=call["function"]["name"], 248 | arguments=call["function"]["arguments"], 249 | ), 250 | ) 251 | for call in calls 252 | ] 253 | return tool_calls 254 | 255 | 256 | def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage: 257 | if "tool_call_id" not in data_message: 258 | err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'." 259 | raise MessageFormatError(err, str(data_message)) 260 | 261 | call_id = data_message["tool_call_id"] 262 | # name is deprecated in v3, but we'll add it nevertheless for now 263 | name = data_message.get("name") 264 | 265 | return ToolMessage(content=content, tool_call_id=call_id, name=name) 266 | 267 | 268 | def tokenize( 269 | sample: Union[str, TrainingInstructSample], 270 | instruct_tokenizer: InstructTokenizerBase, 271 | ) -> TokenSample: 272 | if isinstance(sample, str): 273 | tokenizer: Tokenizer = instruct_tokenizer.tokenizer 274 | return tokenize_pretrain(sample, tokenizer) 275 | elif isinstance(sample, TrainingInstructSample): 276 | return tokenize_instruct(sample, instruct_tokenizer) 277 | 278 | raise ValueError( 279 | f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}." 
280 | ) 281 | 282 | 283 | def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample: 284 | tokens = tokenizer.encode(sample, bos=True, eos=True) 285 | masks = [True] * len(tokens) 286 | return TokenSample(tokens, masks) 287 | 288 | 289 | def tokenize_instruct( 290 | sample: TrainingInstructSample, 291 | instruct_tokenizer: InstructTokenizerBase, 292 | ) -> TokenSample: 293 | tokens: List[int] = instruct_tokenizer.start() 294 | masks: List[bool] = [False] 295 | 296 | mask_all_but_last = sample.only_last 297 | 298 | # find first and last user message 299 | user_messages = [ 300 | i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage) 301 | ] 302 | first_user_idx = user_messages[0] if user_messages else -1 303 | last_user_idx = user_messages[-1] if user_messages else -1 304 | 305 | for msg_idx, message in enumerate(sample.messages): 306 | if isinstance(message, UserMessage): 307 | curr_tokens = instruct_tokenizer.encode_user_message( 308 | message, 309 | available_tools=sample.available_tools, 310 | is_last=msg_idx == last_user_idx, 311 | is_first=msg_idx == first_user_idx, 312 | system_prompt=sample.system_prompt, 313 | ) 314 | if isinstance(curr_tokens, tuple): 315 | # Versions of mistral_common>1.3.4 return a tuple of tokens (text), tokens (image), spans (image) 316 | curr_tokens = curr_tokens[0] 317 | 318 | curr_masks = [False] * len(curr_tokens) # only predict bot answers 319 | elif isinstance(message, ToolMessage): 320 | curr_tokens = instruct_tokenizer.encode_tool_message( 321 | message, is_before_last_user_message=msg_idx < last_user_idx 322 | ) 323 | curr_masks = [False] * len(curr_tokens) # only predict bot answers 324 | elif isinstance(message, FinetuningAssistantMessage): 325 | is_last_message = msg_idx == (len(sample.messages) - 1) 326 | 327 | # we don't want to predict a random call id 328 | message = maybe_remove_call_id(message, is_last_message=is_last_message) 329 | 330 | curr_tokens = instruct_tokenizer.encode_assistant_message( 331 | message, is_before_last_user_message=False 332 | ) 333 | 334 | is_weighted = message.weight is None or message.weight == 1 335 | is_relevant = (not mask_all_but_last) or is_last_message 336 | if is_weighted and is_relevant: 337 | curr_masks = [True] * len(curr_tokens) # only predict bot answers 338 | else: 339 | # in function calling we only backprop through last message 340 | curr_masks = [False] * len(curr_tokens) 341 | 342 | tokens.extend(curr_tokens) 343 | masks.extend(curr_masks) 344 | 345 | return TokenSample(tokens, masks) 346 | 347 | 348 | def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool): 349 | if message.tool_calls is None or not is_last_message: 350 | return message 351 | 352 | # remove call id 353 | message.tool_calls = [ 354 | ToolCall(function=call.function) for call in message.tool_calls 355 | ] 356 | 357 | return message 358 | -------------------------------------------------------------------------------- /finetune/distributed.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from functools import lru_cache 4 | from typing import List, Union 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | logger = logging.getLogger("distributed") 10 | 11 | BACKEND = "nccl" 12 | 13 | 14 | @lru_cache() 15 | def get_rank() -> int: 16 | return dist.get_rank() 17 | 18 | 19 | @lru_cache() 20 | def get_world_size() -> int: 21 | return dist.get_world_size() 22 | 23 | 24 | def visible_devices() -> 
List[int]: 25 | return [int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",")] 26 | 27 | 28 | def set_device(): 29 | logger.info(f"torch.cuda.device_count: {torch.cuda.device_count()}") 30 | logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") 31 | logger.info(f"local rank: {int(os.environ['LOCAL_RANK'])}") 32 | 33 | assert torch.cuda.is_available() 34 | 35 | assert len(visible_devices()) == torch.cuda.device_count() 36 | 37 | if torch.cuda.device_count() == 1: 38 | # gpus-per-task set to 1 39 | torch.cuda.set_device(0) 40 | return 41 | 42 | local_rank = int(os.environ["LOCAL_RANK"]) 43 | logger.info(f"Set cuda device to {local_rank}") 44 | 45 | assert 0 <= local_rank < torch.cuda.device_count(), ( 46 | local_rank, 47 | torch.cuda.device_count(), 48 | ) 49 | torch.cuda.set_device(local_rank) 50 | 51 | 52 | def avg_aggregate(metric: Union[float, int]) -> Union[float, int]: 53 | buffer = torch.tensor([metric], dtype=torch.float32, device="cuda") 54 | dist.all_reduce(buffer, op=dist.ReduceOp.SUM) 55 | return buffer[0].item() / get_world_size() 56 | 57 | 58 | def is_torchrun() -> bool: 59 | return "TORCHELASTIC_RESTART_COUNT" in os.environ 60 | -------------------------------------------------------------------------------- /finetune/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch.cuda 6 | import torch.distributed as dist 7 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel 8 | 9 | from .data.data_loader import Batch 10 | from .distributed import get_rank, get_world_size 11 | from .loss import compute_loss_with_mask 12 | from .utils import TrainState 13 | 14 | logger = logging.getLogger("eval") 15 | 16 | 17 | def main_logger_info(message: str) -> None: 18 | if get_rank() == 0: 19 | logger.info(message) 20 | 21 | 22 | def evaluate( 23 | model: FullyShardedDataParallel, 24 | batches: List[Batch], 25 | state: TrainState, 26 | ): 27 | # Create fake samples to make FSDP happy for unbalanced data 28 | num_samples = torch.tensor([len(batches)], device="cuda", dtype=torch.long) 29 | all_num_samples = [torch.zeros_like(num_samples) for _ in range(get_world_size())] 30 | 31 | torch.distributed.all_gather(all_num_samples, num_samples) 32 | 33 | total_num_samples = int(torch.tensor(all_num_samples).sum().item()) 34 | max_num_samples = int(torch.tensor(all_num_samples).max().item()) 35 | 36 | for _ in range(max_num_samples - int(num_samples.item())): 37 | pad_x = np.zeros_like(batches[-1].x) 38 | pad_y = np.zeros_like(batches[-1].y) 39 | pad_sizes = batches[-1].sizes.copy() 40 | 41 | pad_batch = Batch(pad_x, pad_y, pad_sizes, is_pad_only=True) 42 | batches.append(pad_batch) 43 | 44 | # eval mode! 45 | model.eval() 46 | 47 | eval_loss = torch.tensor(0.0).cuda() 48 | main_logger_info("Start eval...") 49 | for batch in batches: 50 | x = torch.from_numpy(batch.x).cuda() 51 | y = torch.from_numpy(batch.y).cuda() 52 | y_mask = ( 53 | torch.from_numpy(batch.y_mask).cuda() if batch.y_mask is not None else None 54 | ) 55 | 56 | with torch.no_grad(): 57 | output = model( 58 | input_ids=x, 59 | seqlens=batch.sizes, 60 | ) 61 | 62 | if not batch.is_pad_only: 63 | eval_loss += compute_loss_with_mask(output, y, y_mask) 64 | 65 | assert batch.is_pad_only or y.abs().sum() != 0, "Pad sample is used to compute loss." 
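        # The pad-only batches appended above exist solely so that every rank runs the
        # same number of forward passes (FSDP collectives would otherwise hang on ranks
        # with fewer eval batches); their loss is skipped here, and the final average
        # divides by total_num_samples, which was gathered before the padding was added.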
66 | 67 | # sum loss 68 | main_logger_info("Eval finished!") 69 | 70 | dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) 71 | eval_loss /= total_num_samples 72 | 73 | state.this_eval_loss = eval_loss.item() 74 | state.this_eval_perplexity = (2**eval_loss).item() 75 | 76 | # train mode! 77 | model.train() 78 | -------------------------------------------------------------------------------- /finetune/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def compute_loss_with_mask( 8 | logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor] 9 | ): 10 | if target_mask is None: 11 | return F.cross_entropy(logits, target, reduction="mean") 12 | 13 | mb_loss = F.cross_entropy(logits, target, reduction="none") 14 | mb_loss = torch.sum(mb_loss * target_mask) / torch.sum(target_mask) 15 | 16 | return mb_loss 17 | -------------------------------------------------------------------------------- /finetune/mixed_precision.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | 5 | 6 | def prepare_mixed_precision( 7 | params: Iterable[torch.nn.Parameter], 8 | param_dtype: torch.dtype, 9 | optim_dtype: torch.dtype, 10 | ): 11 | """Appends a freshly allocated fp32 tensor copy of all params to parameters that can be updated.""" 12 | with torch.no_grad(): 13 | for p in params: 14 | if p.requires_grad: 15 | # Mixed precision: let's save a fp32 param tensor to each params that require a grad 16 | p._mp_param = torch.empty_like(p, dtype=optim_dtype) # type: ignore 17 | p._mp_param.copy_(p.to(optim_dtype)) # type: ignore 18 | 19 | p.data = p.data.to(param_dtype) 20 | 21 | 22 | def upcast_mixed_precision( 23 | params: Iterable[torch.nn.Parameter], optim_dtype: torch.dtype 24 | ): 25 | """Make sure to run this function BEFORE optimizer.step() so that all weights and optimizer states are updated in fp32 in .step()""" 26 | with torch.no_grad(): 27 | for p in params: 28 | if p.requires_grad and p.grad is not None: 29 | # store original tensor in p._temp 30 | p._temp = p.data # type: ignore 31 | # upcast data for the optimizer step 32 | p.data = p._mp_param # type: ignore 33 | p.grad = p.grad.to(optim_dtype) 34 | 35 | 36 | def downcast_mixed_precision( 37 | params: Iterable[torch.nn.Parameter], param_dtype: torch.dtype 38 | ): 39 | """Make sure to run this function AFTER optimizer.step() as optimizer.step() will update data underlying p.data and p._mp_param pointers""" 40 | with torch.no_grad(): 41 | for p in params: 42 | if p.requires_grad and p.grad is not None: 43 | # copy fp32 weights into bfloat16 tensor 44 | p._temp.copy_(p.data) # type: ignore 45 | # set _temp again to the data tensor 46 | p.data = p._temp # type: ignore 47 | p.grad = p.grad.to(param_dtype) 48 | -------------------------------------------------------------------------------- /finetune/monitoring/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-finetune/656df1c94c80ca9703ebc471c9f106c9b7a0bfa7/finetune/monitoring/__init__.py -------------------------------------------------------------------------------- /finetune/monitoring/metrics_logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from datetime import 
datetime, timedelta 5 | from pathlib import Path 6 | from typing import Any, Dict, Optional, Union 7 | 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | from finetune.args import MLFlowArgs, TrainArgs, WandbArgs 11 | from finetune.utils import TrainState 12 | 13 | logger = logging.getLogger("metrics_logger") 14 | 15 | GB = 1024**3 16 | 17 | 18 | def get_train_logs( 19 | state: TrainState, 20 | loss: float, 21 | lr: float, 22 | peak_allocated_mem: float, 23 | allocated_mem: float, 24 | train_args: TrainArgs, 25 | ) -> Dict[str, Union[float, int]]: 26 | metrics = { 27 | "lr": lr, 28 | "step": state.step, 29 | "loss": loss, 30 | "percent_done": 100 * state.step / train_args.max_steps, 31 | "peak_allocated_mem": peak_allocated_mem / GB, 32 | "allocated_mem": allocated_mem / GB, 33 | "wps": state.wps, 34 | "avg_wps": state.avg_wps, 35 | "eta_in_seconds": state.eta, 36 | } 37 | 38 | return metrics 39 | 40 | 41 | def get_eval_logs( 42 | step: int, 43 | train_loss: float, 44 | perplexity: Optional[float], 45 | eval_loss: Optional[float], 46 | ) -> Dict[str, Union[float, int]]: 47 | eval_dict = {"step": step, "train_loss": train_loss} 48 | 49 | if perplexity is not None: 50 | eval_dict["perplexity"] = perplexity 51 | 52 | if eval_loss is not None: 53 | eval_dict["eval_loss"] = eval_loss 54 | return eval_dict 55 | 56 | 57 | def train_log_msg( 58 | state: TrainState, logs: Dict[str, Union[float, int]], loss: float 59 | ) -> str: 60 | metrics: Dict[str, Union[float, int, datetime]] = dict(logs) # shallow copy 61 | metrics.pop("eta_in_seconds") 62 | 63 | metrics["eta"] = datetime.now() + timedelta(seconds=state.eta) 64 | metrics["step"] = state.step 65 | metrics["loss"] = loss 66 | 67 | parts = [] 68 | for key, fmt, new_name in [ 69 | ("step", "06", None), 70 | ("percent_done", "03.1f", "done (%)"), 71 | ("loss", ".3f", None), 72 | ("lr", ".1e", None), 73 | ("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"), 74 | ("allocated_mem", ".1f", "alloc_mem (GB)"), 75 | ("wps", ".1f", "words_per_second"), 76 | ("avg_wps", ".1f", "avg_words_per_second"), 77 | ("eta", "%Y-%m-%d %H:%M:%S", "ETA"), 78 | ]: 79 | name = key if new_name is None else new_name 80 | try: 81 | parts.append(f"{name}: {metrics[key]:>{fmt}}") 82 | except KeyError: 83 | logger.error(f"{key} not found in {sorted(metrics.keys())}") 84 | raise 85 | 86 | return " - ".join(parts) 87 | 88 | 89 | def eval_log_msg(logs: Dict[str, Union[float, int]]) -> str: 90 | parts = [] 91 | for key, fmt, new_name in [ 92 | ("step", "06", None), 93 | ("perplexity", ".3f", "eval_perplexity"), 94 | ("eval_loss", ".3f", None), 95 | ("train_loss", ".3f", None), 96 | ]: 97 | name = key if new_name is None else new_name 98 | if key in logs: 99 | parts.append(f"{name}: {logs[key]:>{fmt}}") 100 | 101 | return " - ".join(parts) 102 | 103 | 104 | class MetricsLogger: 105 | def __init__( 106 | self, 107 | dst_dir: Path, 108 | tag: str, 109 | is_master: bool, 110 | wandb_args: WandbArgs, 111 | mlflow_args: MLFlowArgs, 112 | config: Optional[Dict[str, Any]] = None, 113 | ): 114 | self.dst_dir = dst_dir 115 | self.tag = tag 116 | self.is_master = is_master 117 | self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl" 118 | self.tb_dir = dst_dir / "tb" 119 | self.summary_writer: Optional[SummaryWriter] = None 120 | 121 | if not self.is_master: 122 | return 123 | 124 | filename_suffix = f".{tag}" 125 | self.tb_dir.mkdir(exist_ok=True) 126 | self.summary_writer = SummaryWriter( 127 | log_dir=str(self.tb_dir), 128 | max_queue=1000, 129 | filename_suffix=filename_suffix, 
130 | ) 131 | self.is_wandb = wandb_args.project is not None 132 | self.is_mlflow = mlflow_args.tracking_uri is not None 133 | 134 | if self.is_wandb: 135 | import wandb 136 | 137 | if wandb_args.key is not None: 138 | wandb.login(key=wandb_args.key) 139 | if wandb_args.offline: 140 | os.environ["WANDB_MODE"] = "offline" 141 | if wandb.run is None: 142 | logger.info("initializing wandb") 143 | wandb.init( 144 | config=config, 145 | dir=dst_dir, 146 | project=wandb_args.project, 147 | job_type="training", 148 | name=wandb_args.run_name or dst_dir.name, 149 | resume=False, 150 | ) 151 | 152 | self.wandb_log = wandb.log 153 | 154 | if self.is_mlflow: 155 | import mlflow 156 | 157 | mlflow.set_tracking_uri(mlflow_args.tracking_uri) 158 | mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name) 159 | 160 | if tag == "train": 161 | mlflow.start_run() 162 | 163 | self.mlflow_log = mlflow.log_metric 164 | 165 | def log(self, metrics: Dict[str, Union[float, int]], step: int): 166 | if not self.is_master: 167 | return 168 | 169 | metrics_to_ignore = {"step"} 170 | assert self.summary_writer is not None 171 | for key, value in metrics.items(): 172 | if key in metrics_to_ignore: 173 | continue 174 | assert isinstance(value, (int, float)), (key, value) 175 | self.summary_writer.add_scalar( 176 | tag=f"{self.tag}.{key}", scalar_value=value, global_step=step 177 | ) 178 | 179 | if self.is_mlflow: 180 | self.mlflow_log(f"{self.tag}.{key}", value, step=step) 181 | 182 | if self.is_wandb: 183 | # grouping in wandb is done with / 184 | self.wandb_log( 185 | { 186 | f"{self.tag}/{key}": value 187 | for key, value in metrics.items() 188 | if key not in metrics_to_ignore 189 | }, 190 | step=step, 191 | ) 192 | 193 | metrics_: Dict[str, Any] = dict(metrics) # shallow copy 194 | if "step" in metrics_: 195 | assert step == metrics_["step"] 196 | else: 197 | metrics_["step"] = step 198 | metrics_["at"] = datetime.utcnow().isoformat() 199 | with self.jsonl_path.open("a") as fp: 200 | fp.write(f"{json.dumps(metrics_)}\n") 201 | 202 | def close(self): 203 | if not self.is_master: 204 | return 205 | 206 | if self.summary_writer is not None: 207 | self.summary_writer.close() 208 | self.summary_writer = None 209 | 210 | if self.is_wandb: 211 | import wandb 212 | 213 | # to be sure we are not hanging while finishing 214 | wandb.finish() 215 | 216 | if self.is_mlflow: 217 | import mlflow 218 | 219 | mlflow.end_run() 220 | 221 | def __del__(self): 222 | if self.summary_writer is not None: 223 | raise RuntimeError( 224 | "MetricsLogger not closed properly! You should " 225 | "make sure the close() method is called!" 
226 | ) -------------------------------------------------------------------------------- /finetune/monitoring/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import sys 4 | import time 5 | 6 | 7 | class DeltaTimeFormatter(logging.Formatter): 8 | def format(self, record): 9 | delta = datetime.timedelta( 10 | seconds=int(record.relativeCreated / 1000) 11 | ) # no milliseconds 12 | record.delta = delta 13 | return super().format(record) 14 | 15 | 16 | def set_logger(level: int = logging.INFO): 17 | root = logging.getLogger() 18 | root.handlers.clear() 19 | root.setLevel(level) 20 | tz, *_ = time.tzname 21 | 22 | LOGFORMAT = "%(asctime)s - %(delta)s - %(name)s - %(levelname)s - %(message)s" 23 | TIMEFORMAT = f"%Y-%m-%d %H:%M:%S ({tz})" 24 | formatter = DeltaTimeFormatter(LOGFORMAT, TIMEFORMAT) 25 | 26 | handler = logging.StreamHandler(sys.stdout) 27 | handler.setLevel(level) 28 | handler.setFormatter(formatter) 29 | root.addHandler(handler) 30 | 31 | handler = logging.StreamHandler(sys.stderr) 32 | handler.setLevel(logging.WARNING) 33 | handler.setFormatter(formatter) 34 | root.addHandler(handler) 35 | -------------------------------------------------------------------------------- /finetune/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import dataclasses 3 | import datetime 4 | import logging 5 | import time 6 | from typing import Optional, Protocol 7 | 8 | import torch 9 | 10 | logger = logging.getLogger("utils") 11 | 12 | 13 | @dataclasses.dataclass 14 | class TrainState: 15 | max_steps: int 16 | step: int = 0 17 | elapsed_time: float = 0.0 18 | n_seen_tokens: int = 0 19 | this_step_time: float = 0.0 20 | begin_step_time: float = 0.0 21 | this_eval_perplexity: Optional[float] = None 22 | this_eval_loss: Optional[float] = None 23 | 24 | def start_step(self): 25 | self.step += 1 26 | self.begin_step_time = time.time() 27 | 28 | def end_step(self, n_batch_tokens: int): 29 | self.this_step_time = time.time() - self.begin_step_time 30 | self.this_step_tokens = n_batch_tokens 31 | 32 | self.elapsed_time += self.this_step_time 33 | self.n_seen_tokens += self.this_step_tokens 34 | 35 | self.begin_step_time = time.time() 36 | 37 | @property 38 | def wps(self): 39 | return self.this_step_tokens / self.this_step_time 40 | 41 | @property 42 | def avg_wps(self): 43 | return self.n_seen_tokens / self.elapsed_time 44 | 45 | @property 46 | def eta(self): 47 | steps_left = self.max_steps - self.step 48 | avg_time_per_step = self.elapsed_time / self.step 49 | 50 | return steps_left * avg_time_per_step 51 | 52 | 53 | def set_random_seed(seed: int) -> None: 54 | """Set random seed for reproducibility.""" 55 | torch.manual_seed(seed) 56 | torch.cuda.manual_seed(seed) 57 | 58 | 59 | class Closable(Protocol): 60 | def close(self): 61 | pass 62 | 63 | 64 | @contextlib.contextmanager 65 | def logged_closing(thing: Closable, name: str): 66 | """ 67 | Logging the closing to be sure something is not hanging at exit time 68 | """ 69 | try: 70 | setattr(thing, "wrapped_by_closing", True) 71 | yield 72 | finally: 73 | logger.info(f"Closing: {name}") 74 | try: 75 | thing.close() 76 | except Exception: 77 | logger.error(f"Error while closing {name}!") 78 | raise 79 | logger.info(f"Closed: {name}") 80 | 81 | 82 | def now_as_str() -> str: 83 | return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 84 | 
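Note on the mixed-precision helpers above: `prepare_mixed_precision` attaches an fp32 master copy (`_mp_param`) to every trainable parameter, and the upcast/downcast pair must bracket `optimizer.step()` exactly as their docstrings state. The sketch below shows that ordering in isolation. It is illustrative only, not code from this repository: the toy `nn.Linear` model, the AdamW settings and the synthetic loss are placeholder assumptions, and it assumes a single CUDA device.

import torch

from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.utils import TrainState

# Placeholder model/optimizer; the real training loop uses an FSDP-wrapped Transformer.
model = torch.nn.Linear(8, 8, bias=False).cuda().to(torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
state = TrainState(max_steps=3)

# Allocate fp32 master copies next to the bf16 working parameters.
prepare_mixed_precision(
    model.parameters(), param_dtype=torch.bfloat16, optim_dtype=torch.float32
)

while state.step < state.max_steps:
    state.start_step()
    optimizer.zero_grad()

    x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()  # synthetic loss, for illustration only
    loss.backward()

    # Swap the fp32 master weights in *before* the optimizer update ...
    upcast_mixed_precision(model.parameters(), optim_dtype=torch.float32)
    optimizer.step()
    # ... and restore the bf16 working copies *after* it.
    downcast_mixed_precision(model.parameters(), param_dtype=torch.bfloat16)

    state.end_step(n_batch_tokens=x.numel())

The same ordering is exercised end-to-end, under FSDP, in `tests/test_mixed_precision.py` further below.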
-------------------------------------------------------------------------------- /finetune/wrapped_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import logging 4 | import math 5 | from pathlib import Path 6 | from typing import Callable, Union 7 | 8 | import safetensors 9 | import torch 10 | import torch.distributed.fsdp.wrap as torch_wrap 11 | from torch.distributed.fsdp import BackwardPrefetch 12 | from torch.distributed.fsdp.api import ShardingStrategy 13 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel 14 | 15 | from model.args import ModelArgs, MoeArgs 16 | from model.transformer import Transformer, TransformerBlock 17 | 18 | from .args import LoraArgs 19 | from .checkpointing import Checkpointer 20 | from .distributed import ( 21 | get_rank, 22 | get_world_size, 23 | ) 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def main_logger_info(message: str) -> None: 29 | if get_rank() == 0: 30 | logger.info(message) 31 | 32 | 33 | def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]: 34 | """ 35 | This function instantiates the FSDP wrap policy. 36 | - Each Transformers block becomes its own FSDP group so that only a single Transformer block is sharded at a time 37 | - If LoRA is enabled, we additionally create separate FSDP sub-groups for every trainable and non-trainable parameter group 38 | since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html 39 | """ 40 | 41 | # Each transformer block becomes a FSDP group, each being sharded separately 42 | transformer_block_wrap_policy = functools.partial( 43 | torch_wrap.transformer_auto_wrap_policy, 44 | transformer_layer_cls=(TransformerBlock,), 45 | ) 46 | 47 | if not is_lora: 48 | return transformer_block_wrap_policy 49 | 50 | def fsdp_lora_policy_fn(module): 51 | return all(p.requires_grad for p in module.parameters()) 52 | 53 | # For LoRA training, trainable and non-trainable parameters need to be put into 54 | # different FSDP groups 55 | fsdp_lora_policy = functools.partial( 56 | torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn 57 | ) 58 | 59 | policies = [fsdp_lora_policy, transformer_block_wrap_policy] 60 | 61 | return functools.partial(torch_wrap._or_policy, policies=policies) 62 | 63 | 64 | def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]): 65 | world_size = get_world_size() 66 | 67 | num_params = world_size * sum(p.numel() for p in model.parameters()) 68 | num_train_params = world_size * sum( 69 | p.numel() for p in model.parameters() if p.requires_grad 70 | ) 71 | 72 | main_logger_info( 73 | f"{num_train_params:,.0f} out of {num_params:,.0f} parameters are finetuned ({num_train_params / num_params * 100:.2f}%)." 74 | ) 75 | 76 | 77 | def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype): 78 | """ 79 | Initialize LoRA layers with Kaiming uniform and zeros. 
80 | See original paper for more info: https://arxiv.org/abs/2106.09685 and 81 | original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122 82 | """ 83 | for m_name, module in model.named_modules(): 84 | if all(p.is_meta for p in module.parameters()): 85 | for p_name, param in module.named_parameters(): 86 | module._parameters[p_name] = torch.nn.Parameter( 87 | torch.empty_like(param, device="cpu", dtype=param_dtype) 88 | ) 89 | param = module._parameters[p_name] 90 | 91 | if m_name.split(".")[-1] == "lora_A": 92 | torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5)) 93 | elif m_name.split(".")[-1] == "lora_B": 94 | torch.nn.init.zeros_(param) 95 | else: 96 | raise ValueError("Only Lora layers should be randomly initialized.") 97 | 98 | 99 | def load_args(folder: Path, lora: LoraArgs) -> ModelArgs: 100 | with open(folder / "params.json", "r") as f: 101 | args = json.loads(f.read()) 102 | 103 | model_args = ModelArgs( 104 | lora=lora, 105 | dim=args["dim"], 106 | n_layers=args["n_layers"], 107 | head_dim=args["head_dim"], 108 | hidden_dim=args["hidden_dim"], 109 | n_heads=args["n_heads"], 110 | n_kv_heads=args["n_kv_heads"], 111 | norm_eps=args["norm_eps"], 112 | vocab_size=args["vocab_size"], 113 | ) 114 | 115 | if args.get("rope_theta") is not None: 116 | model_args.rope_theta = args["rope_theta"] 117 | 118 | if args.get("moe") is not None: 119 | model_args.moe = MoeArgs(**args["moe"]) 120 | 121 | return model_args 122 | 123 | 124 | def load_model( 125 | folder: Path, 126 | lora: LoraArgs, 127 | checkpoint: bool, 128 | param_dtype: torch.dtype, 129 | ) -> FullyShardedDataParallel: 130 | model_args = load_args(folder, lora) 131 | 132 | if model_args.vocab_size == 32000: 133 | raise ValueError( 134 | f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`." 
135 | ) 136 | 137 | assert ( 138 | model_args.vocab_size >= 32768 139 | ), "Make sure to use a model with a vocab size of at least 32768" 140 | 141 | with torch.device("meta"): 142 | model = Transformer(args=model_args, checkpoint=checkpoint) 143 | 144 | if get_rank() == 0: 145 | state_dict = load_state_dict(folder, dtype=param_dtype) 146 | 147 | model.load_state_dict(state_dict, assign=True) # type: ignore 148 | logger.info("Loaded model on cpu!") 149 | 150 | if lora.enable: 151 | logger.info("Initializing lora layers ...") 152 | # initialize LoRA layers 153 | initialize_lora_parameters(model, param_dtype) 154 | 155 | assert not any( 156 | p.is_meta for p in model.parameters() 157 | ), "All parameters should be initialized by now" 158 | assert all( 159 | p.dtype == param_dtype for p in model.parameters() 160 | ), f"All parameters should be on {param_dtype}" 161 | 162 | logger.info("Finished initialization!") 163 | param_init_fn = None 164 | else: 165 | 166 | def param_init_fn(m): 167 | m.to_empty(device=torch.cuda.current_device(), recurse=False) 168 | m.to(param_dtype) 169 | 170 | assert all( 171 | p.is_meta for p in model.parameters() 172 | ), "All parameters should be on meta" 173 | 174 | torch.distributed.barrier() 175 | 176 | # only finetune LoRA parameters and freeze before wrapping 177 | if lora.enable: 178 | for name, param in model.named_parameters(): 179 | if "lora" in name: 180 | param.requires_grad = True 181 | else: 182 | param.requires_grad = False 183 | 184 | auto_wrap_policy = get_fsdp_policy(model_args.lora.enable) 185 | 186 | main_logger_info(f"Sharding model over {get_world_size()} GPUs ...") 187 | 188 | wrapped_model = FullyShardedDataParallel( 189 | model, 190 | sharding_strategy=ShardingStrategy.FULL_SHARD, 191 | auto_wrap_policy=auto_wrap_policy, 192 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 193 | limit_all_gathers=True, 194 | device_id=torch.cuda.current_device(), 195 | sync_module_states=True, 196 | param_init_fn=param_init_fn, 197 | ) 198 | main_logger_info("Model sharded!") 199 | 200 | log_train_params(wrapped_model) 201 | 202 | return wrapped_model 203 | 204 | 205 | @torch.no_grad() 206 | def load_state_dict(path: Path, dtype: torch.dtype): 207 | assert path.is_dir(), path 208 | 209 | this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True) 210 | this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False) 211 | 212 | assert ( 213 | this_safetensors_path.exists() or this_torch_path.exists() 214 | ), f"Either {this_safetensors_path} or {this_torch_path} must exist." 215 | assert not ( 216 | this_safetensors_path.exists() and this_torch_path.exists() 217 | ), f"Only one of {this_safetensors_path} or {this_torch_path} should exist." 
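    # Exactly one consolidated checkpoint format may be present (safetensors or
    # torch), so the branch below simply loads whichever file exists and then
    # casts every tensor to the requested dtype before the state dict is assigned
    # to the meta-initialized model.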
218 | 219 | if this_safetensors_path.exists(): 220 | logger.info(f"Reloading model from {this_safetensors_path} ...") 221 | model_state_dict = safetensors.torch.load_file(this_safetensors_path) 222 | else: 223 | logger.info(f"Reloading model from {this_torch_path} ...") 224 | model_state_dict = torch.load(this_torch_path) 225 | 226 | logger.info(f"Converting model to dtype {dtype} ...") 227 | 228 | for k, v in model_state_dict.items(): 229 | model_state_dict[k] = v.to(dtype) 230 | 231 | return model_state_dict 232 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-finetune/656df1c94c80ca9703ebc471c9f106c9b7a0bfa7/model/__init__.py -------------------------------------------------------------------------------- /model/args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | from simple_parsing.helpers import Serializable 5 | 6 | 7 | @dataclass 8 | class LoraArgs(Serializable): 9 | enable: bool = True 10 | rank: int = 16 11 | dropout: float = 0.0 12 | scaling: float = 2.0 13 | 14 | def __post_init__(self): 15 | if self.enable: 16 | assert self.rank > 0 17 | assert self.scaling > 0.0 18 | 19 | 20 | @dataclass 21 | class MoeArgs(Serializable): 22 | num_experts: int = 8 23 | num_experts_per_tok: int = 2 24 | 25 | 26 | @dataclass 27 | class ModelArgs(Serializable): 28 | dim: int 29 | n_layers: int 30 | head_dim: int 31 | hidden_dim: int 32 | n_heads: int 33 | n_kv_heads: int 34 | norm_eps: float 35 | vocab_size: int 36 | rope_theta: float = 10000.0 37 | 38 | lora: LoraArgs = field(default_factory=LoraArgs) 39 | moe: Optional[MoeArgs] = None 40 | -------------------------------------------------------------------------------- /model/lora.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class LoRALinear(nn.Module): 8 | """ 9 | Implementation of: 10 | - LoRA: https://arxiv.org/abs/2106.09685 11 | 12 | Notes: 13 | - Freezing is handled at the network level, not the layer level. 14 | - Scaling factor controls relative importance of LoRA skip 15 | connection versus original frozen weight. General guidance is 16 | to keep it to 2.0 and sweep over learning rate when changing 17 | the rank. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | in_features: int, 23 | out_features: int, 24 | rank: int, 25 | scaling: float, 26 | dropout: float, 27 | bias: bool = False, 28 | ): 29 | super().__init__() 30 | 31 | self.in_features = in_features 32 | self.out_features = out_features 33 | assert not bias 34 | self.bias = bias 35 | self.rank = rank 36 | self.scaling = scaling 37 | 38 | self.dropout = nn.Dropout(p=dropout) 39 | 40 | self.lora_A = nn.Linear( 41 | self.in_features, 42 | self.rank, 43 | bias=self.bias, 44 | ) 45 | self.lora_B = nn.Linear( 46 | self.rank, 47 | self.out_features, 48 | bias=self.bias, 49 | ) 50 | 51 | self.frozen_W = nn.Linear(self.in_features, self.out_features, bias=self.bias) 52 | 53 | # make sure no LoRA weights are marked as "missing" in load_state_dict 54 | def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple): 55 | # empty missing keys in place 56 | incompatible_keys.missing_keys[:] = [] # type: ignore 57 | 58 | self.register_load_state_dict_post_hook(ignore_missing_keys) 59 | 60 | def merge_weight(self): 61 | with torch.no_grad(): 62 | down_weight = self.lora_A.weight 63 | up_weight = self.lora_B.weight 64 | 65 | weight = up_weight.mm(down_weight) * self.scaling 66 | 67 | weight += self.frozen_W.weight 68 | return weight 69 | 70 | def _load_from_state_dict( 71 | self, 72 | state_dict, 73 | prefix, 74 | local_metadata, 75 | strict, 76 | missing_keys, 77 | unexpected_keys, 78 | error_msgs, 79 | ): 80 | key_name = prefix + "weight" 81 | 82 | # full checkpoint 83 | if key_name in state_dict: 84 | w_ref = state_dict[key_name] 85 | 86 | # load frozen weights 87 | self.frozen_W.load_state_dict({"weight": w_ref}, assign=True) 88 | 89 | def forward(self, x: torch.Tensor): 90 | lora = self.lora_B(self.lora_A(self.dropout(x))) 91 | return self.frozen_W(x) + lora * self.scaling 92 | 93 | def __repr__(self) -> str: 94 | return "{}Linear(in_features={}, out_features={}, r={}, dropout={})".format( 95 | "LoRA", self.in_features, self.out_features, self.rank, self.dropout.p 96 | ) 97 | -------------------------------------------------------------------------------- /model/moe.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .args import MoeArgs 7 | 8 | 9 | class MoeLayer(nn.Module): 10 | def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): 11 | super().__init__() 12 | assert len(experts) > 0 13 | self.experts = nn.ModuleList(experts) 14 | self.gate = gate 15 | self.args = moe_args 16 | 17 | def forward(self, inputs: torch.Tensor): 18 | gate_logits = self.gate(inputs) 19 | weights, selected_experts = torch.topk( 20 | gate_logits, self.args.num_experts_per_tok 21 | ) 22 | weights = torch.nn.functional.softmax(weights, dim=1, dtype=torch.float).to( 23 | inputs.dtype 24 | ) 25 | results = torch.zeros_like(inputs) 26 | for i, expert in enumerate(self.experts): 27 | batch_idx, nth_expert = torch.where(selected_experts == i) 28 | results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( 29 | inputs[batch_idx] 30 | ) 31 | return results 32 | -------------------------------------------------------------------------------- /model/rope.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | 6 | def precompute_freqs_cis( 7 | dim: int, end: int, theta: float, device: Optional[torch.device] = None 8 | ) -> 
torch.Tensor: 9 | freqs = 1.0 / ( 10 | theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim) 11 | ) 12 | t = torch.arange(end, device=freqs.device) # type: ignore 13 | freqs = torch.outer(t, freqs).float() # type: ignore 14 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 15 | 16 | 17 | def apply_rotary_emb( 18 | xq: torch.Tensor, 19 | xk: torch.Tensor, 20 | freqs_cis: torch.Tensor, 21 | ) -> Tuple[torch.Tensor, torch.Tensor]: 22 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 23 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 24 | freqs_cis = freqs_cis[:, None, :] 25 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) 26 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2) 27 | return xq_out.type_as(xq), xk_out.type_as(xk) 28 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from functools import partial, reduce 3 | from typing import Iterable, List, Optional, Union 4 | 5 | import torch 6 | import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as torch_ckpt 7 | import torch.nn as nn 8 | from xformers.ops.fmha import memory_efficient_attention 9 | from xformers.ops.fmha.attn_bias import AttentionBias, BlockDiagonalCausalMask 10 | 11 | from .args import ModelArgs 12 | from .lora import LoRALinear 13 | from .moe import MoeLayer 14 | from .rope import apply_rotary_emb, precompute_freqs_cis 15 | 16 | 17 | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int): 18 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) 19 | values = torch.repeat_interleave(values, repeats=repeats, dim=dim) 20 | return keys, values 21 | 22 | 23 | def maybe_lora_layer( 24 | args: ModelArgs, rank: Optional[int] = None 25 | ) -> Union[partial[LoRALinear], type[nn.Linear]]: 26 | MaybeLora: Union[partial[LoRALinear], type[nn.Linear]] 27 | if not args.lora.enable: 28 | return nn.Linear 29 | 30 | rank = rank or args.lora.rank 31 | scaling = args.lora.scaling 32 | dropout = args.lora.dropout 33 | 34 | MaybeLora = partial( 35 | LoRALinear, 36 | rank=rank, 37 | scaling=scaling, 38 | dropout=dropout, 39 | ) 40 | 41 | return MaybeLora 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, args: ModelArgs): 46 | super().__init__() 47 | self.args = args 48 | 49 | self.n_heads: int = args.n_heads 50 | self.n_kv_heads: int = args.n_kv_heads 51 | self.head_dim: int = args.head_dim 52 | 53 | self.repeats = self.n_heads // self.n_kv_heads 54 | 55 | self.scale = self.args.head_dim**-0.5 56 | 57 | MaybeLora = maybe_lora_layer(args) 58 | 59 | self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False) 60 | self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) 61 | self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False) 62 | 63 | self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False) 64 | 65 | def forward( 66 | self, 67 | x: torch.Tensor, 68 | freqs_cis: torch.Tensor, 69 | mask: AttentionBias, 70 | ) -> torch.Tensor: 71 | seqlen_sum, _ = x.shape 72 | 73 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 74 | 75 | xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim) 76 | xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim) 77 | xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim) 78 | xq, xk = apply_rotary_emb(xq, xk, 
freqs_cis=freqs_cis) 79 | 80 | key, val = xk, xv 81 | 82 | # Repeat keys and values to match number of query heads 83 | key, val = repeat_kv(key, val, self.repeats, dim=1) 84 | 85 | # xformers requires (B=1, S, H, D) 86 | xq, key, val = xq[None, ...], key[None, ...], val[None, ...] 87 | output = memory_efficient_attention(xq, key, val, mask) 88 | 89 | return self.wo(output.view(seqlen_sum, -1)) 90 | 91 | 92 | class FeedForward(nn.Module): 93 | def __init__(self, args: ModelArgs): 94 | super().__init__() 95 | 96 | MaybeLora = maybe_lora_layer(args) 97 | self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False) 98 | self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False) 99 | self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False) 100 | 101 | def forward(self, x) -> torch.Tensor: 102 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) 103 | 104 | 105 | class RMSNorm(torch.nn.Module): 106 | def __init__(self, dim: int, eps: float = 1e-6): 107 | super().__init__() 108 | self.eps = eps 109 | self.weight = nn.Parameter(torch.ones(dim)) 110 | 111 | def _norm(self, x): 112 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 113 | 114 | def forward(self, x): 115 | output = self._norm(x.float()).type_as(x) 116 | return output * self.weight 117 | 118 | 119 | class TransformerBlock(nn.Module): 120 | def __init__(self, args: ModelArgs): 121 | super().__init__() 122 | self.n_heads = args.n_heads 123 | self.dim = args.dim 124 | self.attention = Attention(args) 125 | 126 | self.feed_forward: Union[MoeLayer, FeedForward] 127 | if args.moe is not None: 128 | self.feed_forward = MoeLayer( 129 | experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], 130 | gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), 131 | moe_args=args.moe, 132 | ) 133 | else: 134 | self.feed_forward = FeedForward(args=args) 135 | 136 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 137 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 138 | self.args = args 139 | 140 | def forward( 141 | self, 142 | x: torch.Tensor, 143 | freqs_cis: torch.Tensor, 144 | att_mask: AttentionBias, 145 | ) -> torch.Tensor: 146 | r = self.attention(self.attention_norm(x), freqs_cis, att_mask) 147 | h = x + r 148 | 149 | r = self.feed_forward(self.ffn_norm(h)) 150 | out = h + r 151 | 152 | return out 153 | 154 | 155 | class Transformer(nn.Module): 156 | def __init__(self, args: ModelArgs, checkpoint: bool = False): 157 | super().__init__() 158 | self.args = args 159 | self.vocab_size = args.vocab_size 160 | self.n_layers = args.n_layers 161 | assert self.vocab_size > 0 162 | self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) 163 | self.layers = torch.nn.ModuleList() 164 | for _ in range(args.n_layers): 165 | block: torch.nn.Module = TransformerBlock(args=args) 166 | if checkpoint: 167 | # activate gradient checkpointing as, see: https://pytorch.org/docs/stable/checkpoint.html 168 | non_reentrant_wrapper = partial( 169 | torch_ckpt.checkpoint_wrapper, 170 | checkpoint_impl=torch_ckpt.CheckpointImpl.NO_REENTRANT, 171 | ) 172 | block = non_reentrant_wrapper(block) 173 | 174 | self.layers.append(block) 175 | 176 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 177 | 178 | self.output = torch.nn.Linear( 179 | args.dim, 180 | args.vocab_size, 181 | bias=False, 182 | ) 183 | 184 | # set lazily 185 | self._freqs_cis = None 186 | 187 | @property 188 | def dtype(self) -> torch.dtype: 189 | return self.tok_embeddings.weight.dtype 190 | 191 | @property 192 | def device(self) -> 
torch.device: 193 | return self.tok_embeddings.weight.device 194 | 195 | @property 196 | def freqs_cis(self): 197 | # lazy init 198 | device = next(iter(self.parameters())).device 199 | if self._freqs_cis is None: 200 | self._freqs_cis = precompute_freqs_cis( 201 | self.args.head_dim, 128_000, theta=self.args.rope_theta, device=device 202 | ) 203 | 204 | return self._freqs_cis 205 | 206 | def forward( 207 | self, 208 | input_ids: torch.Tensor, 209 | seqlens: List[int], 210 | ) -> torch.Tensor: 211 | assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0]) 212 | 213 | h = self.tok_embeddings(input_ids) 214 | positions = positions_from_sizes(seqlens, self.freqs_cis.device) 215 | att_mask = BlockDiagonalCausalMask.from_seqlens(seqlens) 216 | 217 | freqs_cis = self.freqs_cis[positions].to(device=h.device) 218 | 219 | for layer in self.layers: 220 | h = layer(h, freqs_cis, att_mask) 221 | 222 | return self.output(self.norm(h)).float() 223 | 224 | 225 | def positions_from_sizes(sizes: Iterable[int], device): 226 | return torch.tensor( 227 | reduce(operator.iadd, [list(range(s)) for s in sizes], []), 228 | dtype=torch.long, 229 | device=device, 230 | ) 231 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | include = '\.pyi?$' 3 | exclude = ''' 4 | /( 5 | \.git 6 | | \.hg 7 | | \.mypy_cache 8 | | \.tox 9 | | \.venv 10 | | _build 11 | | buck-out 12 | | build 13 | | dist 14 | )/ 15 | ''' 16 | line-length = 88 17 | 18 | [tool.isort] 19 | known_third_party = [] 20 | line_length = 88 21 | 22 | [tool.mypy] 23 | python_version = "3.9" 24 | ignore_missing_imports = true 25 | follow_imports = "normal" 26 | 27 | [tool.ruff] 28 | lint.select = ["E", "F", "W", "C90", "I", "C4", "PT", "RSE", "TID", "TCH", "FLY", "NPY", "RUF", "T10"] 29 | line-length = 88 30 | lint.ignore = ["E501"] 31 | target-version = "py38" 32 | exclude = [".git","__pycache__",".vscode",".mypy_cache"] 33 | 34 | [tool.ruff.lint.mccabe] 35 | max-complexity = 20 36 | 37 | [tool.pytest.ini_options] 38 | python_classes = [ 39 | "!TestConfig", 40 | ] 41 | filterwarnings = [ 42 | "error", 43 | "ignore::DeprecationWarning", 44 | "ignore::ImportWarning", 45 | "ignore::DeprecationWarning:pkg_resources", 46 | "ignore::DeprecationWarning:wandb", 47 | "ignore::DeprecationWarning:torch.utils.tensorboard", 48 | "ignore::UserWarning", 49 | ] 50 | testpaths = [] 51 | addopts = "--ignore=third_party --ignore=tokens" 52 | -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | mypy==1.4.0 2 | mypy-extensions==1.0.0 3 | pre-commit 4 | ruff==0.0.282 5 | pytest 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | simple-parsing 3 | pyyaml 4 | mistral-common>=1.3.1 5 | safetensors 6 | tensorboard 7 | tqdm 8 | 9 | torch==2.2 10 | triton==2.2 11 | xformers==0.0.24 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-finetune/656df1c94c80ca9703ebc471c9f106c9b7a0bfa7/tests/__init__.py 
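Note on `model/lora.py` above: at training time the low-rank update is kept as a separate skip connection, `frozen_W(x) + scaling * lora_B(lora_A(x))`, and `merge_weight()` folds it back into a single dense matrix, `frozen_W.weight + scaling * (lora_B.weight @ lora_A.weight)`. The snippet below is an illustrative sketch (toy CPU shapes chosen here, not taken from the repository) checking that the merged matrix reproduces the layer's forward pass when dropout is 0.

import torch

from model.lora import LoRALinear

torch.manual_seed(0)

# Toy shapes; dropout is set to 0 so forward() is deterministic.
layer = LoRALinear(in_features=16, out_features=32, rank=4, scaling=2.0, dropout=0.0)

x = torch.randn(3, 16)
out_lora = layer(x)                                           # frozen_W(x) + scaling * B(A(x))
out_merged = torch.nn.functional.linear(x, layer.merge_weight())  # x @ (W + scaling * B @ A)^T

assert torch.allclose(out_lora, out_merged, atol=1e-5)

This is the same kind of folding that `merge_checkpoints` from `utils/merge_lora.py` is expected to apply to whole checkpoints, and `tests/test_checkpointer.py` below compares the merged result against a non-LoRA model saved through the FSDP checkpointer.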
-------------------------------------------------------------------------------- /tests/test_checkpointer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from finetune.args import LoraArgs 7 | from finetune.checkpointing import Checkpointer 8 | from finetune.utils import TrainState 9 | from finetune.wrapped_model import load_model 10 | from tests.test_utils import MODEL_PATH, is_float_equal, setup_mp_test_dist 11 | from utils.merge_lora import merge_checkpoints 12 | 13 | from .test_utils import spawn_for_all_world_sizes 14 | 15 | # fmt: off 16 | EXPECTED_NON_LORA_KEYS = sorted(['layers.0.attention.wk.weight', 'layers.0.attention.wo.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wv.weight', 'layers.0.attention_norm.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.ffn_norm.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wo.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wv.weight', 'layers.1.attention_norm.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.ffn_norm.weight', 'norm.weight', 'output.weight', 'tok_embeddings.weight']) 17 | EXPECTED_LORA_KEYS = sorted(['layers.0.attention.wq.lora_A.weight', 'layers.0.attention.wq.lora_B.weight', 'layers.0.attention.wk.lora_A.weight', 'layers.0.attention.wk.lora_B.weight', 'layers.0.attention.wv.lora_A.weight', 'layers.0.attention.wv.lora_B.weight', 'layers.0.attention.wo.lora_A.weight', 'layers.0.attention.wo.lora_B.weight', 'layers.0.feed_forward.w1.lora_A.weight', 'layers.0.feed_forward.w1.lora_B.weight', 'layers.0.feed_forward.w2.lora_A.weight', 'layers.0.feed_forward.w2.lora_B.weight', 'layers.0.feed_forward.w3.lora_A.weight', 'layers.0.feed_forward.w3.lora_B.weight', 'layers.1.attention.wq.lora_A.weight', 'layers.1.attention.wq.lora_B.weight', 'layers.1.attention.wk.lora_A.weight', 'layers.1.attention.wk.lora_B.weight', 'layers.1.attention.wv.lora_A.weight', 'layers.1.attention.wv.lora_B.weight', 'layers.1.attention.wo.lora_A.weight', 'layers.1.attention.wo.lora_B.weight', 'layers.1.feed_forward.w1.lora_A.weight', 'layers.1.feed_forward.w1.lora_B.weight', 'layers.1.feed_forward.w2.lora_A.weight', 'layers.1.feed_forward.w2.lora_B.weight', 'layers.1.feed_forward.w3.lora_A.weight', 'layers.1.feed_forward.w3.lora_B.weight']) 18 | # fmt: on 19 | 20 | 21 | @pytest.mark.parametrize( 22 | ("world_size", "save_only_lora", "enable_lora"), 23 | [ 24 | (1, False, False), 25 | (2, False, False), 26 | (1, False, True), 27 | (2, False, True), 28 | (1, True, True), 29 | (2, True, True), # this is the most important test! 
- FSDP only LORA 30 | ], 31 | ) 32 | def test_states_retrieval(world_size, enable_lora, save_only_lora): 33 | spawn_for_all_world_sizes( 34 | _check_states_retrieval, 35 | world_sizes=[world_size], 36 | args=[enable_lora, save_only_lora], 37 | deterministic=True, 38 | ) 39 | 40 | 41 | def _check_states_retrieval( 42 | rank: int, 43 | world_size: int, 44 | filename: str, 45 | filename_rpc: str, 46 | enable_lora: bool, 47 | save_only_lora: bool, 48 | ): 49 | model_parallel = 1 50 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 51 | 52 | folder = Path(MODEL_PATH) 53 | model = load_model( 54 | folder=folder, 55 | lora=LoraArgs(enable=enable_lora), 56 | checkpoint=True, 57 | param_dtype=torch.bfloat16, 58 | ) 59 | 60 | # mock a train state that has done three steps 61 | step = 3 62 | state = TrainState(max_steps=10, step=step) # 10 is just a dummy value here 63 | 64 | # mock run_dir as we won't save anything in this test 65 | run_dir = Path("dir") 66 | use_sf = True 67 | 68 | checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None) 69 | prefix = "lora" if enable_lora else "consolidated" 70 | 71 | assert checkpointer.dst_dir == Path( 72 | f"dir/checkpoints/checkpoint_00000{step}/consolidated" 73 | ), checkpointer.dst_dir 74 | assert checkpointer.consolidated_path( 75 | checkpointer.dst_dir, use_sf, save_only_lora=enable_lora 76 | ) == Path( 77 | f"dir/checkpoints/checkpoint_00000{step}/consolidated/{prefix}.safetensors" 78 | ), checkpointer.consolidated_path( 79 | checkpointer.dst_dir, use_sf, save_only_lora=enable_lora 80 | ) 81 | 82 | # increase step by one 83 | state.start_step() 84 | 85 | assert checkpointer.dst_dir == Path( 86 | f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated" 87 | ), checkpointer.dst_dir 88 | assert checkpointer.consolidated_path( 89 | checkpointer.dst_dir, use_sf, save_only_lora=enable_lora 90 | ) == Path( 91 | f"dir/checkpoints/checkpoint_00000{step + 1}/consolidated/{prefix}.safetensors" 92 | ), checkpointer.consolidated_path( 93 | checkpointer.dst_dir, use_sf, save_only_lora=enable_lora 94 | ) 95 | 96 | assert all("lora" in k for k in EXPECTED_LORA_KEYS), EXPECTED_LORA_KEYS 97 | 98 | for save_dtype in [torch.float16, torch.bfloat16, torch.float32]: 99 | 100 | save_dict = checkpointer.retrieve_save_states( 101 | save_only_lora=save_only_lora, save_dtype=save_dtype 102 | ) 103 | 104 | for k, v in save_dict.items(): 105 | assert v.dtype == save_dtype, f"{k}: v.dtype" 106 | 107 | if save_only_lora: 108 | assert sorted(save_dict.keys()) == EXPECTED_LORA_KEYS, save_dict.keys() 109 | else: 110 | assert sorted(save_dict.keys()) == EXPECTED_NON_LORA_KEYS, save_dict.keys() 111 | 112 | EXPECTED_NON_LORA_VALUES = 34909.7500 113 | 114 | EXPECTED_LORA_VALUES = 984.4179840087891 115 | 116 | values_sum = sum(v.abs().float().sum().item() for v in save_dict.values()) 117 | 118 | if save_only_lora: 119 | assert is_float_equal( 120 | values_sum, EXPECTED_LORA_VALUES, 5e-1 121 | ), f"{values_sum} for {save_dtype}" 122 | else: 123 | assert is_float_equal( 124 | values_sum, EXPECTED_NON_LORA_VALUES, 1e-1 125 | ), f"{values_sum} for {save_dtype}" 126 | 127 | 128 | @pytest.mark.parametrize("world_size", [1, 2]) 129 | def test_lora_merge_equal(world_size): 130 | spawn_for_all_world_sizes( 131 | _check_lora_merge_equal, 132 | world_sizes=[world_size], 133 | deterministic=True, 134 | ) 135 | 136 | 137 | def _check_lora_merge_equal( 138 | rank: int, world_size: int, filename: str, filename_rpc: str 139 | ): 140 | model_parallel = 1 141 | 
enable_lora = True 142 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 143 | 144 | 145 | 146 | folder = Path(MODEL_PATH) 147 | 148 | step = 3 149 | state = TrainState(max_steps=10, step=step) # 10 is just a dummy value here 150 | run_dir = Path("dir") 151 | 152 | non_lora_model = load_model( 153 | folder=folder, 154 | lora=LoraArgs(enable=False), 155 | checkpoint=True, 156 | param_dtype=torch.bfloat16, 157 | ) 158 | 159 | non_lora_checkpointer = Checkpointer( 160 | non_lora_model, state, run_dir=run_dir, num_ckpt_keep=None 161 | ) 162 | orig_model = non_lora_checkpointer.retrieve_save_states( 163 | save_only_lora=False, save_dtype=torch.float32 164 | ) 165 | 166 | scaling = 2.0 167 | 168 | model = load_model( 169 | folder=folder, 170 | lora=LoraArgs(enable=enable_lora, scaling=scaling), 171 | checkpoint=True, 172 | param_dtype=torch.bfloat16, 173 | ) 174 | 175 | state_dict = model.state_dict() 176 | state_dict = {k: v + 0.01 if "lora" in k else v for k, v in state_dict.items()} 177 | model.load_state_dict(state_dict) 178 | 179 | # mock a train state that has done three steps 180 | checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None) 181 | 182 | for save_dtype in [torch.float16, torch.bfloat16, torch.float32]: 183 | model_dict = { 184 | k: torch.empty_like(v).copy_(v).to(save_dtype) 185 | for k, v in orig_model.items() 186 | } 187 | merged_save_dict = checkpointer.retrieve_save_states( 188 | save_only_lora=False, save_dtype=save_dtype 189 | ) 190 | 191 | lora_save_dict = checkpointer.retrieve_save_states( 192 | save_only_lora=True, save_dtype=save_dtype 193 | ) 194 | 195 | merge_checkpoints( 196 | model_dict, lora_save_dict, scaling=scaling, save_dtype=save_dtype 197 | ) 198 | 199 | for k in model_dict.keys(): 200 | assert torch.allclose( 201 | model_dict[k].cpu(), merged_save_dict[k].cpu(), atol=1e-3, rtol=1e-3 202 | ), k 203 | 204 | for k in model_dict.keys(): 205 | # make sure that merged model differs from orig model 206 | if "attention" in k or "feed_forward" in k: 207 | assert not torch.allclose( 208 | orig_model[k].to(save_dtype).cpu(), 209 | merged_save_dict[k].cpu(), 210 | atol=1e-3, 211 | rtol=1e-3, 212 | ), k 213 | -------------------------------------------------------------------------------- /tests/test_mixed_precision.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from finetune.args import LoraArgs 7 | from finetune.loss import compute_loss_with_mask 8 | from finetune.mixed_precision import ( 9 | downcast_mixed_precision, 10 | prepare_mixed_precision, 11 | upcast_mixed_precision, 12 | ) 13 | from finetune.wrapped_model import load_model 14 | from tests.test_utils import MODEL_PATH, get_dataloader, setup_mp_test_dist 15 | 16 | from .test_utils import spawn_for_all_world_sizes 17 | 18 | 19 | @pytest.mark.parametrize( 20 | ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)] 21 | ) 22 | def test_mixed_precision(world_size, enable_lora): 23 | spawn_for_all_world_sizes( 24 | _check_mixed_precision, 25 | world_sizes=[world_size], 26 | args=[enable_lora], 27 | deterministic=True, 28 | ) 29 | 30 | 31 | def _check_mixed_precision( 32 | rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool 33 | ): 34 | model_parallel = 1 35 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 36 | seq_len = 100 37 | 38 | folder = Path(MODEL_PATH) 39 | # mixed
precision 40 | param_dtype = torch.bfloat16 41 | optim_dtype = torch.float32 42 | 43 | model = load_model( 44 | folder=folder, 45 | lora=LoraArgs(enable=enable_lora), 46 | checkpoint=True, 47 | param_dtype=param_dtype, 48 | ) 49 | 50 | optimizer = torch.optim.AdamW(model.parameters()) 51 | 52 | # initialize mixed precision training for TP 53 | prepare_mixed_precision( 54 | model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype 55 | ) 56 | 57 | data_loader = get_dataloader(seq_len=seq_len) 58 | 59 | # ensure every parameter that requires a grad has a _mp_param of optim_dtype precision 60 | for param in model.parameters(): 61 | assert param.dtype == param_dtype 62 | if param.requires_grad: 63 | assert param._mp_param.dtype == optim_dtype 64 | assert ( 65 | param._mp_param.tolist() == param.data.to(optim_dtype).tolist() 66 | ), "mp param has to match param in optim dtype precision" 67 | else: 68 | assert not hasattr(param, "_mp_param") 69 | 70 | # test three train steps 71 | for _ in range(3): 72 | 73 | optimizer.zero_grad() 74 | 75 | # micro-batching 76 | for _ in range(2): 77 | batch = next(data_loader) 78 | 79 | x = torch.from_numpy(batch.x).cuda(non_blocking=True) 80 | y = torch.from_numpy(batch.y).cuda(non_blocking=True) 81 | y_mask = ( 82 | torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 83 | if batch.y_mask is not None 84 | else None 85 | ) 86 | 87 | output = model( 88 | input_ids=x, 89 | seqlens=batch.sizes, 90 | ) 91 | 92 | mb_loss = compute_loss_with_mask(output, y, y_mask) 93 | mb_loss.backward() 94 | 95 | upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype) 96 | 97 | # ensure all params are upcasted correctly and mp param equals param 98 | param_sum = 0 99 | for param in model.parameters(): 100 | if param.requires_grad: 101 | assert param.dtype == optim_dtype, param.dtype 102 | assert ( 103 | param._mp_param.tolist() == param.data.tolist() 104 | ), "mp param and param should point to the same data" 105 | assert param.grad.dtype == optim_dtype 106 | assert param._temp.dtype == param_dtype 107 | param_sum += param.data.float().abs().sum() 108 | else: 109 | assert param.dtype == param_dtype 110 | 111 | optimizer.step() 112 | 113 | # ensure that after optimizer step params are still in optim dtype precision 114 | new_param_sum = 0 115 | for param in model.parameters(): 116 | if param.requires_grad: 117 | assert param.dtype == optim_dtype 118 | assert param._mp_param.dtype == optim_dtype 119 | assert param.grad.dtype == optim_dtype 120 | new_param_sum += param.data.float().abs().sum() 121 | else: 122 | assert param.dtype == param_dtype 123 | 124 | assert new_param_sum != param_sum, "Make sure parameters are updated" 125 | 126 | downcast_mixed_precision(model.parameters(), param_dtype=param_dtype) 127 | 128 | # ensure that before new forward pass params are downcast to param dtype 129 | for param in model.parameters(): 130 | assert param.dtype == param_dtype 131 | if param.requires_grad: 132 | assert param._mp_param.dtype == optim_dtype 133 | assert param.grad.dtype == param_dtype 134 | assert ( 135 | param._mp_param.to(param_dtype).tolist() == param.data.tolist() 136 | ), "mp param has to match param in optim dtype precision" 137 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | from typing import Dict 4 | 5 | import pytest 6 | import torch 7 | 8 | from 
finetune.args import LoraArgs 9 | from finetune.checkpointing import Checkpointer 10 | from finetune.loss import compute_loss_with_mask 11 | from finetune.mixed_precision import ( 12 | downcast_mixed_precision, 13 | prepare_mixed_precision, 14 | upcast_mixed_precision, 15 | ) 16 | from finetune.utils import TrainState 17 | from finetune.wrapped_model import load_model 18 | from model.transformer import ( 19 | LoRALinear, 20 | ) 21 | from tests.test_utils import ( 22 | MODEL_PATH, 23 | get_dataloader, 24 | is_float_equal, 25 | setup_mp_test_dist, 26 | ) 27 | 28 | from .test_utils import spawn_for_all_world_sizes 29 | 30 | torch.backends.cudnn.deterministic = True # use deterministic algorithms 31 | torch.backends.cudnn.benchmark = False # disable cuDNN benchmark 32 | 33 | 34 | @pytest.mark.parametrize( 35 | ("world_size", "enable_lora", "dtype"), 36 | [ 37 | (1, False, torch.float32), 38 | (1, True, torch.float32), 39 | (2, False, torch.float32), 40 | (2, True, torch.float32), 41 | (1, False, torch.bfloat16), 42 | (1, True, torch.bfloat16), 43 | (2, False, torch.bfloat16), 44 | (2, True, torch.bfloat16), 45 | ], 46 | ) 47 | def test_weights_loading(world_size, enable_lora, dtype): 48 | spawn_for_all_world_sizes( 49 | _check_weights_loading, 50 | world_sizes=[world_size], 51 | args=[enable_lora, dtype], 52 | deterministic=True, 53 | ) 54 | 55 | 56 | def _check_weights_loading( 57 | rank: int, 58 | world_size: int, 59 | filename: str, 60 | filename_rpc: str, 61 | enable_lora: bool, 62 | dtype: torch.dtype, 63 | ): 64 | model_parallel = 1 65 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 66 | 67 | folder = Path(MODEL_PATH) 68 | model = load_model( 69 | folder=folder, 70 | lora=LoraArgs(enable=enable_lora), 71 | checkpoint=True, 72 | param_dtype=dtype, 73 | ) 74 | 75 | # add hook so that LoRA weights are automatically merged: 76 | def register_merge_lora_hook(m: torch.nn.Module): 77 | def merge_lora( 78 | m: torch.nn.Module, destination: Dict[str, torch.Tensor], prefix: str, *args 79 | ): 80 | weight = m.merge_weight() 81 | destination[prefix + "weight"] = weight 82 | 83 | if isinstance(m, LoRALinear): 84 | m._merge_lora_handle = m._register_state_dict_hook(merge_lora) 85 | 86 | model.apply(register_merge_lora_hook) 87 | 88 | if world_size > 1: 89 | with model.summon_full_params(model, writeback=True): 90 | states = { 91 | k: v 92 | for k, v in model.state_dict().items() 93 | if "lora" not in k and "frozen" not in k 94 | } 95 | else: 96 | states = { 97 | k: v 98 | for k, v in model.state_dict().items() 99 | if "lora" not in k and "frozen" not in k 100 | } 101 | 102 | EXP_PARAM_SUM = 308.9932 if dtype == torch.float32 else 308.0 103 | params = sum([v.sum() for v in states.values()]).item() 104 | 105 | # LoRA is equal to no LoRA as LoRA weights should be init to 0 106 | assert is_float_equal(params, EXP_PARAM_SUM), params 107 | 108 | if enable_lora: 109 | lora_B_params = [ 110 | v.float().abs().sum() for k, v in model.named_parameters() if "lora_B" in k 111 | ] 112 | 113 | assert len(lora_B_params) > 0 114 | assert sum(lora_B_params) == 0, "Lora_B should always be zero init" 115 | 116 | lora_A_params = [ 117 | v.float().abs().sum() for k, v in model.named_parameters() if "lora_A" in k 118 | ] 119 | 120 | assert len(lora_A_params) > 0 121 | assert sum(lora_A_params) > 0, "Lora_A should init to non-zero values" 122 | 123 | 124 | @pytest.mark.parametrize( 125 | ("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)] 126 | ) 127 | def 
test_fsdp_logits_and_loss(world_size, enable_lora): 128 | spawn_for_all_world_sizes( 129 | _check_fsdp_logits_and_loss, 130 | world_sizes=[world_size], 131 | args=[enable_lora], 132 | deterministic=True, 133 | ) 134 | 135 | 136 | def _check_fsdp_logits_and_loss( 137 | rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool 138 | ): 139 | model_parallel = 1 140 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 141 | seq_len = 100 142 | 143 | folder = Path(MODEL_PATH) 144 | model = load_model( 145 | folder=folder, 146 | lora=LoraArgs(enable=enable_lora), 147 | checkpoint=True, 148 | param_dtype=torch.bfloat16, 149 | ) 150 | # By setting equal rank and world_size we can assure that both ranks see the same data and hence the average 151 | data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2) 152 | 153 | batch = next(data_loader) 154 | 155 | x = torch.from_numpy(batch.x).cuda(non_blocking=True) 156 | y = torch.from_numpy(batch.y).cuda(non_blocking=True) 157 | y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 158 | 159 | # forward / backward 160 | output = model( 161 | input_ids=x, 162 | seqlens=batch.sizes, 163 | ) 164 | 165 | # check logits 166 | # logits should be the same for LoRA and non-LoRA 167 | assert output.shape == (seq_len, model.args.vocab_size) 168 | output_sum = output.abs().float().sum().item() 169 | 170 | EXP_OUTPUT_WORLD_1 = 162617.625 171 | 172 | assert is_float_equal(output_sum, EXP_OUTPUT_WORLD_1, precision=1e1), output_sum 173 | 174 | # check loss is the same for all 175 | # loss should be the same for LoRA and non-LoRA 176 | mb_loss = compute_loss_with_mask(output, y, y_mask) 177 | 178 | EXPECTED_LOSS = 10.408413887023926 179 | 180 | assert is_float_equal(mb_loss.item(), EXPECTED_LOSS), mb_loss.item() 181 | 182 | 183 | @pytest.mark.parametrize( 184 | ("world_size", "dtype"), 185 | [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)], 186 | ) 187 | def test_fsdp_grads_non_lora(world_size, dtype): 188 | spawn_for_all_world_sizes( 189 | _check_fsdp_grads_non_lora, 190 | world_sizes=[world_size], 191 | deterministic=True, 192 | args=[dtype], 193 | ) 194 | 195 | 196 | def _check_fsdp_grads_non_lora( 197 | rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype 198 | ): 199 | model_parallel = 1 200 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 201 | seq_len = 2048 202 | 203 | folder = Path(MODEL_PATH) 204 | model = load_model( 205 | folder=folder, 206 | lora=LoraArgs(enable=False), 207 | checkpoint=True, 208 | param_dtype=dtype, 209 | ) 210 | # same world_size to check for equality 211 | data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2) 212 | 213 | batch = next(data_loader) 214 | 215 | x = torch.from_numpy(batch.x).cuda(non_blocking=True) 216 | y = torch.from_numpy(batch.y).cuda(non_blocking=True) 217 | y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 218 | 219 | # forward / backward 220 | output = model( 221 | input_ids=x, 222 | seqlens=batch.sizes, 223 | ) 224 | 225 | mb_loss = compute_loss_with_mask(output, y, y_mask) 226 | mb_loss.backward() 227 | 228 | num_grad_params = sum([p.grad.numel() for p in model.parameters()]) 229 | 230 | assert (4301120 // world_size) == num_grad_params, num_grad_params 231 | 232 | torch.distributed.barrier() 233 | 234 | sharded_flat_grads = sum( 235 | [p.grad.float().abs().sum().item() for p in model.parameters()] 236 | ) 237 | 238 | print(f"{rank}: 
{world_size}: {dtype} = {sharded_flat_grads}") 239 | 240 | EXP_GRAD_WORLD_2_RANK_0 = 95.45827150344849 241 | EXP_GRAD_WORLD_2_RANK_1 = 86.09188461303711 242 | EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1 243 | 244 | if world_size == 1: 245 | assert is_float_equal( 246 | sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1 247 | ), sharded_flat_grads 248 | elif world_size == 2 and rank == 0: 249 | assert is_float_equal( 250 | sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1 251 | ), sharded_flat_grads 252 | elif world_size == 2 and rank == 1: 253 | assert is_float_equal( 254 | sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1 255 | ), sharded_flat_grads 256 | 257 | 258 | @pytest.mark.parametrize( 259 | ("world_size", "dtype"), 260 | [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)], 261 | ) 262 | def test_fsdp_grads_lora(world_size, dtype): 263 | spawn_for_all_world_sizes( 264 | _check_fsdp_grads_lora, 265 | world_sizes=[world_size], 266 | deterministic=True, 267 | args=[dtype], 268 | ) 269 | 270 | 271 | def _check_fsdp_grads_lora( 272 | rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype 273 | ): 274 | model_parallel = 1 275 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 276 | seq_len = 2048 277 | 278 | folder = Path(MODEL_PATH) 279 | model = load_model( 280 | folder=folder, 281 | lora=LoraArgs(enable=True), 282 | checkpoint=True, 283 | param_dtype=dtype, 284 | ) 285 | # same world_size to check for equality 286 | data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2) 287 | 288 | batch = next(data_loader) 289 | 290 | x = torch.from_numpy(batch.x).cuda(non_blocking=True) 291 | y = torch.from_numpy(batch.y).cuda(non_blocking=True) 292 | y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 293 | 294 | # forward / backward 295 | output = model( 296 | input_ids=x, 297 | seqlens=batch.sizes, 298 | ) 299 | 300 | mb_loss = compute_loss_with_mask(output, y, y_mask) 301 | mb_loss.backward() 302 | 303 | num_grad_params = sum( 304 | [p.grad.numel() for p in model.parameters() if p.grad is not None] 305 | ) 306 | 307 | assert (40960 // world_size) == num_grad_params, num_grad_params 308 | 309 | torch.distributed.barrier() 310 | 311 | sharded_flat_grads = sum( 312 | [ 313 | p.grad.float().abs().sum().item() 314 | for p in model.parameters() 315 | if p.grad is not None 316 | ] 317 | ) 318 | 319 | print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}") 320 | 321 | EXP_GRAD_WORLD_2_RANK_0 = 3.0742580661177635 322 | EXP_GRAD_WORLD_2_RANK_1 = 3.074301045779139 323 | EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1 324 | 325 | if world_size == 1: 326 | assert is_float_equal( 327 | sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1 328 | ), sharded_flat_grads 329 | elif world_size == 2 and rank == 0: 330 | assert is_float_equal( 331 | sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1 332 | ), sharded_flat_grads 333 | elif world_size == 2 and rank == 1: 334 | assert is_float_equal( 335 | sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1 336 | ), sharded_flat_grads 337 | 338 | 339 | @pytest.mark.parametrize( 340 | ("world_size", "dtype"), 341 | [(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)], 342 | ) 343 | def test_grad_update_lora(world_size, dtype): 344 | spawn_for_all_world_sizes( 345 | _check_grad_update_lora, 346 | world_sizes=[world_size], 347 | args=[dtype], 348 | deterministic=True, 349 | ) 350 | 351 | 352 | def 
_check_grad_update_lora( 353 | rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype 354 | ): 355 | model_parallel = 1 356 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 357 | seq_len = 1000 358 | 359 | folder = Path(MODEL_PATH) 360 | model = load_model( 361 | folder=folder, 362 | lora=LoraArgs(enable=True), 363 | checkpoint=True, 364 | param_dtype=dtype, 365 | ) 366 | optimizer = torch.optim.AdamW(model.parameters()) 367 | 368 | data_loader = get_dataloader(seq_len=seq_len) 369 | 370 | batch = next(data_loader) 371 | 372 | x = torch.from_numpy(batch.x).cuda(non_blocking=True) 373 | y = torch.from_numpy(batch.y).cuda(non_blocking=True) 374 | y_mask = ( 375 | torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 376 | if batch.y_mask is not None 377 | else None 378 | ) 379 | 380 | # forward / backward 381 | output = model( 382 | input_ids=x, 383 | seqlens=batch.sizes, 384 | ) 385 | 386 | mb_loss = compute_loss_with_mask(output, y, y_mask) 387 | mb_loss.backward() 388 | 389 | lora_weight_sum = 0 390 | non_lora_weight_sum = 0 391 | for name, param in model.named_parameters(): 392 | if "lora" in name or "norm" in name: 393 | assert param.grad is not None, name 394 | lora_weight_sum += param.data.float().abs().sum() 395 | else: 396 | assert param.grad is None, name 397 | non_lora_weight_sum += param.data.float().abs().sum() 398 | 399 | # update weights 400 | optimizer.step() 401 | 402 | new_lora_weight_sum = 0 403 | new_non_lora_weight_sum = 0 404 | for name, param in model.named_parameters(): 405 | if "lora" in name or "norm" in name: 406 | assert param.grad is not None, name 407 | new_lora_weight_sum += param.data.float().abs().sum() 408 | else: 409 | assert param.grad is None, name 410 | new_non_lora_weight_sum += param.data.float().abs().sum() 411 | 412 | # make sure that LoRA weights changed, but non-LoRA weights stayed the same 413 | assert not is_float_equal( 414 | new_lora_weight_sum, lora_weight_sum, 1e-4 415 | ), f"New: {new_lora_weight_sum}, Old: {lora_weight_sum}" 416 | assert is_float_equal( 417 | new_non_lora_weight_sum, non_lora_weight_sum, 1e-4 418 | ), f"New: {new_non_lora_weight_sum}, Old: {non_lora_weight_sum}" 419 | 420 | 421 | @pytest.mark.parametrize( 422 | ("enable_lora", "param_dtype"), 423 | [ 424 | (False, torch.float32), 425 | (True, torch.float32), 426 | (False, torch.bfloat16), 427 | (True, torch.bfloat16), 428 | ], 429 | ) 430 | def test_grads_fsdp_mp(enable_lora, param_dtype): 431 | with tempfile.TemporaryDirectory() as tmpdirname: 432 | for world_size in [1, 2]: 433 | spawn_for_all_world_sizes( 434 | _check_grads_fsdp_mp, 435 | world_sizes=[world_size], 436 | deterministic=True, 437 | args=[tmpdirname, enable_lora, param_dtype], 438 | ) 439 | 440 | w1_sd = torch.load(Path(tmpdirname) / Path("params_w1.pt"), map_location="cpu") 441 | w2_sd = torch.load(Path(tmpdirname) / Path("params_w2.pt"), map_location="cpu") 442 | 443 | for k in w1_sd.keys(): 444 | assert w1_sd[k].shape == w2_sd[k].shape, k 445 | atol = 10 if param_dtype == torch.float32 else 100 446 | assert (w1_sd[k] - w2_sd[k]).sum().abs().item() < atol 447 | 448 | 449 | def _check_grads_fsdp_mp( 450 | rank: int, 451 | world_size: int, 452 | filename: str, 453 | filename_rpc: str, 454 | tmpdirname: str, 455 | enable_lora: bool, 456 | param_dtype: torch.dtype, 457 | ): 458 | model_parallel = 1 459 | setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0) 460 | seq_len = 4096 461 | 462 | optim_dtype = torch.float32 463 | 464 | folder = 
Path(MODEL_PATH) 465 | model = load_model( 466 | folder=folder, 467 | lora=LoraArgs(enable=enable_lora), 468 | checkpoint=True, 469 | param_dtype=param_dtype, 470 | ) 471 | 472 | # high learning rate to show differences 473 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.1) 474 | 475 | # mock a train state that has done three steps 476 | steps = 4 477 | state = TrainState(max_steps=steps) 478 | 479 | # mock run_dir as we won't save anything in this test 480 | run_dir = Path(tmpdirname) 481 | 482 | checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None) 483 | 484 | # make sure the same data is seen 485 | dataloaders = [ 486 | get_dataloader(seq_len=seq_len, rank=rank + i, world_size=2) 487 | for i in range(2 - world_size + 1) 488 | ] 489 | 490 | prepare_mixed_precision( 491 | model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype 492 | ) 493 | 494 | for _ in range(steps): 495 | state.start_step() 496 | optimizer.zero_grad() 497 | 498 | for data_loader in dataloaders: 499 | torch.manual_seed(0) 500 | batch = next(data_loader) 501 | 502 | x = torch.from_numpy(batch.x).cuda() 503 | y = torch.from_numpy(batch.y).cuda() 504 | y_mask = ( 505 | torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 506 | if batch.y_mask is not None 507 | else None 508 | ) 509 | 510 | # forward / backward 511 | output = model( 512 | input_ids=x, 513 | seqlens=batch.sizes, 514 | ) 515 | 516 | mb_loss = compute_loss_with_mask(output, y, y_mask) 517 | mb_loss.backward() 518 | 519 | assert model.params[0].dtype == param_dtype 520 | 521 | print(f"rank: {rank}, world_size: {world_size}, x: {x.abs().sum()}") 522 | print(f"rank: {rank}, world_size: {world_size}, y: {y.abs().sum()}") 523 | print(f"rank: {rank}, world_size: {world_size}, x shape: {x.shape}") 524 | 525 | if y_mask is not None: 526 | print( 527 | f"rank: {rank}, world_size: {world_size}, y_mask: {y_mask.abs().sum()}" 528 | ) 529 | print(f"rank: {rank}, world_size: {world_size}, loss: {mb_loss}") 530 | 531 | for p in model.parameters(): 532 | if p.requires_grad: 533 | assert p.grad is not None 534 | p.grad.div_(len(dataloaders)) 535 | 536 | max_norm = 1.0 537 | model.clip_grad_norm_(max_norm=max_norm) 538 | 539 | upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype) 540 | 541 | optimizer.step() 542 | 543 | downcast_mixed_precision(model.parameters(), param_dtype=param_dtype) 544 | 545 | save_dict = checkpointer.retrieve_save_states( 546 | save_only_lora=enable_lora, save_dtype=torch.float32 547 | ) 548 | 549 | path = "params_w1.pt" if world_size == 1 else "params_w2.pt" 550 | torch.save(save_dict, Path(tmpdirname) / Path(path)) 551 | -------------------------------------------------------------------------------- /tests/test_train_loop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from contextlib import ExitStack 4 | from pathlib import Path 5 | 6 | import pytest 7 | import safetensors 8 | import torch 9 | 10 | from finetune.args import LoraArgs, OptimArgs, TrainArgs 11 | from finetune.data.args import DataArgs, InstructArgs 12 | from tests.test_utils import DATA_PATH, EVAL_DATA_PATH, MODEL_PATH, setup_mp_test_dist 13 | from train import _train 14 | 15 | from .test_utils import spawn_for_all_world_sizes 16 | 17 | 18 | def file_size_and_md5(file_path): 19 | # Check if the file exists 20 | if not os.path.isfile(file_path): 21 | return "Error: File not found" 22 | 23 | # Get the size of the file 24 | file_size = 
os.path.getsize(file_path) 25 | 26 | # Open the file in binary mode 27 | state_dict = safetensors.torch.load_file(file_path) 28 | md5_sum = sum(v.abs().sum().item() for v in state_dict.values()) 29 | 30 | return file_size, md5_sum 31 | 32 | 33 | @pytest.mark.parametrize("enable_lora", [False, True]) 34 | def test_integration(enable_lora): 35 | torch.backends.cudnn.deterministic = True # use deterministic algorithms 36 | torch.backends.cudnn.benchmark = False # disable cuDNN benchmark 37 | 38 | instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) 39 | 40 | data_args = DataArgs( 41 | data="", 42 | instruct_data=DATA_PATH, 43 | eval_instruct_data=EVAL_DATA_PATH, 44 | instruct=instruct, 45 | ) 46 | 47 | model_path = MODEL_PATH 48 | optim_args = OptimArgs(lr=0.01, weight_decay=0.1, pct_start=0.0) 49 | 50 | with tempfile.TemporaryDirectory() as tmpdirname: 51 | args = TrainArgs( 52 | data=data_args, 53 | model_id_or_path=model_path, 54 | run_dir=tmpdirname, 55 | seed=0, 56 | optim=optim_args, 57 | max_steps=4, 58 | num_microbatches=1, 59 | lora=LoraArgs(enable=enable_lora), 60 | save_adapters=enable_lora, 61 | checkpoint=True, 62 | no_eval=False, 63 | ) 64 | spawn_for_all_world_sizes( 65 | _run_dummy_train, 66 | world_sizes=[2], 67 | deterministic=True, 68 | args=[args], 69 | ) 70 | 71 | prefix = "lora" if enable_lora else "consolidated" 72 | ckpt_path = Path(tmpdirname) / Path( 73 | f"checkpoints/checkpoint_00000{args.max_steps}/consolidated/{prefix}.safetensors" 74 | ) 75 | assert ckpt_path.exists() 76 | 77 | file_size, hash = file_size_and_md5(ckpt_path) 78 | 79 | EXPECTED_FILE_SIZE = [8604200, 84760] 80 | 81 | EXPECTED_HASH = [50515.5, 1296.875] 82 | 83 | assert file_size == EXPECTED_FILE_SIZE[enable_lora], file_size 84 | assert abs(hash - EXPECTED_HASH[enable_lora]) < 1e-2, hash 85 | 86 | 87 | def _run_dummy_train( 88 | rank: int, world_size: int, filename: str, filename_rpc: str, args: TrainArgs 89 | ): 90 | setup_mp_test_dist(rank, world_size, filename, 1, seed=0) 91 | 92 | with ExitStack() as exit_stack: 93 | _train(args, exit_stack) 94 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from datetime import timedelta 4 | from typing import Any, Callable, Dict, List, Optional 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 9 | 10 | from finetune.data.args import DataArgs, InstructArgs 11 | from finetune.data.data_loader import build_data_loader 12 | from finetune.distributed import get_rank, get_world_size 13 | from finetune.utils import set_random_seed 14 | 15 | 16 | def is_float_equal(a, b, precision=5e-3): 17 | return abs(a - b) < precision 18 | 19 | 20 | MODEL_PATH = os.getenv("DUMMY_MODEL") 21 | assert MODEL_PATH != "", "Provide a path to a dummy model" 22 | DATA_PATH = "tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1" 23 | EVAL_DATA_PATH = "tests/fixtures/sample_instruct.jsonl,tests/fixtures/sample_instruct_2.jsonl,tests/fixtures/sample_instruct_3.jsonl" 24 | 25 | 26 | # Model parallel group that the current rank belongs to. 27 | _MODEL_PARALLEL_GROUP = None 28 | # Data parallel group that the current rank belongs to. 29 | _DATA_PARALLEL_GROUP = None 30 | # Pipeline parallel group that the current rank belongs to. 
31 | _PIPELINE_PARALLEL_GROUP = None 32 | 33 | _PIPELINE_PARALLEL_RANKS = None 34 | 35 | 36 | def rmf(filename: str) -> None: 37 | """Remove a file like rm -f.""" 38 | try: 39 | os.remove(filename) 40 | except FileNotFoundError: 41 | pass 42 | 43 | 44 | def runner( 45 | rank: int, test_func: Callable, deterministic: bool = False, *args: List[Any], **kwargs: Dict[str, Any] 46 | ) -> None: 47 | # At this point we're in a new process, torch options need to be set again 48 | if deterministic: 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | torch.manual_seed(1357) 52 | 53 | test_func(rank, *args, **kwargs) 54 | 55 | 56 | def spawn_for_all_world_sizes( 57 | test_func: Callable, world_sizes: List[int] = [], args: Any = [], deterministic: bool = False 58 | ) -> None: 59 | for world_size in world_sizes: 60 | _, filename = tempfile.mkstemp() 61 | _, filename_rpc = tempfile.mkstemp() 62 | 63 | try: 64 | torch.multiprocessing.spawn( 65 | runner, 66 | args=(test_func, deterministic, world_size, filename, filename_rpc, *args), 67 | nprocs=world_size, 68 | join=True, 69 | ) 70 | finally: 71 | rmf(filename) 72 | rmf(filename_rpc) 73 | 74 | def initialize_model_parallel( 75 | model_parallel_size_: int, 76 | pipeline_length: int = 1, 77 | *, 78 | model_parallel_backend: Optional[str] = None, 79 | pipeline_backend: Optional[str] = None, 80 | ddp_backend: Optional[str] = None 81 | ) -> None: 82 | """ 83 | Initialize model data parallel groups. 84 | 85 | Arguments: 86 | model_parallel_size: number of GPUs used to parallelize model. 87 | 88 | Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we 89 | use 2 GPUs to parallelize the model. The present function will 90 | create 4 model parallel groups and 2 data parallel groups as: 91 | 4 model parallel groups: 92 | [g0, g1], [g2, g3], [g4, g5], [g6, g7] 93 | 2 data parallel groups: 94 | [g0, g2, g4, g6], [g1, g3, g5, g7] 95 | Note that for efficiency, the caller should make sure adjacent ranks 96 | are on the same DGX box. For example if we are using 2 DGX-1 boxes 97 | with a total of 16 GPUs, rank 0 to 7 belong to the first box and 98 | ranks 8 to 15 belong to the second box. 99 | """ 100 | # Get world size and rank. Ensure some consistencies. 101 | assert torch.distributed.is_initialized() 102 | world_size = torch.distributed.get_world_size() 103 | model_parallel_size = int(min(model_parallel_size_, world_size)) 104 | rank = torch.distributed.get_rank() 105 | 106 | data_parallel_size = int(world_size / (model_parallel_size * pipeline_length)) 107 | 108 | if torch.distributed.get_rank() == 0: 109 | print("> initializing model parallel with size {}".format(model_parallel_size_)) 110 | print("> initializing ddp with size {}".format(data_parallel_size)) 111 | print("> initializing pipeline with size {}".format(pipeline_length)) 112 | 113 | groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size) 114 | 115 | found = torch.where(groups == rank) 116 | assert all(len(x) == 1 for x in found) 117 | found = [x[0] for x in found] 118 | 119 | # Build the data parallel groups. 
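# `groups` has shape (data_parallel_size, pipeline_length, model_parallel_size);
# a rank's data-parallel group is the column groups[:, j, k] that shares its
# pipeline index j and model-parallel index k (the indices recovered in `found` above).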
120 | global _DATA_PARALLEL_GROUP 121 | assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" 122 | for j in range(pipeline_length): 123 | for k in range(model_parallel_size): 124 | group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend) 125 | if j == found[1] and k == found[2]: 126 | _DATA_PARALLEL_GROUP = group 127 | 128 | # Build the model parallel groups. 129 | global _MODEL_PARALLEL_GROUP 130 | assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" 131 | for i in range(data_parallel_size): 132 | for j in range(pipeline_length): 133 | group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend) 134 | if i == found[0] and j == found[1]: 135 | _MODEL_PARALLEL_GROUP = group 136 | 137 | global _PIPELINE_PARALLEL_GROUP 138 | assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized" 139 | global _PIPELINE_PARALLEL_RANKS 140 | assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized" 141 | for i in range(data_parallel_size): 142 | for k in range(model_parallel_size): 143 | ranks = groups[i, :, k].tolist() 144 | group = torch.distributed.new_group(ranks, backend=pipeline_backend) 145 | if i == found[0] and k == found[2]: 146 | _PIPELINE_PARALLEL_GROUP = group 147 | _PIPELINE_PARALLEL_RANKS = ranks 148 | 149 | 150 | def setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0): 151 | dist_init_for_testing(rank, world_size, filename) 152 | torch.cuda.set_device(rank) 153 | 154 | # Init NCCL 155 | backend = "nccl" 156 | initialize_model_parallel( 157 | model_parallel, 158 | model_parallel_backend=backend, 159 | pipeline_backend=backend, 160 | ddp_backend=backend, 161 | ) 162 | 163 | set_random_seed(seed) 164 | 165 | if torch.cuda.is_available(): 166 | torch.set_default_tensor_type(torch.cuda.FloatTensor) # type: ignore 167 | 168 | 169 | def dist_init_for_testing( 170 | rank: int, world_size: int, filename: str, filename_rpc: str = "", timeout: int = 30 171 | ): 172 | """ 173 | Same than fairscale testing.dist_init but without rpc 174 | 175 | filename_rpc is here to keep same signature than fairscale init 176 | """ 177 | 178 | os.environ["WORLD_SIZE"] = str(world_size) 179 | os.environ["RANK"] = str(rank) 180 | url = "file://" + filename 181 | 182 | backend = "nccl" if torch.cuda.is_available() else "gloo" 183 | 184 | if backend == "nccl" and torch.cuda.device_count() < world_size: 185 | raise RuntimeError( 186 | f"Requested world size {world_size} cannot be reached on this machine, not enough GPUs {torch.cuda.device_count()}" 187 | ) 188 | 189 | dist.init_process_group( 190 | backend=backend, 191 | rank=rank, 192 | world_size=world_size, 193 | init_method=url, 194 | timeout=timedelta(seconds=timeout), 195 | ) 196 | 197 | 198 | def get_dataloader( 199 | seed: int = 0, 200 | seq_len: int = 10000, 201 | rank: Optional[int] = None, 202 | world_size: Optional[int] = None, 203 | ): 204 | batch_size = 1 205 | rank = rank if rank is not None else get_rank() 206 | world_size = world_size if world_size is not None else get_world_size() 207 | 208 | instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer 209 | 210 | instruct = InstructArgs(shuffle=False, dynamic_chunk_fn_call=False) 211 | 212 | data_args = DataArgs( 213 | data="", 214 | instruct_data="tests/fixtures/sample_instruct.jsonl:.1,tests/fixtures/sample_instruct_2.jsonl:.1,tests/fixtures/sample_instruct_3.jsonl:.1", 215 | instruct=instruct, 216 | ) 217 | 
data_loader = build_data_loader( 218 | instruct_tokenizer, 219 | data_args, 220 | batch_size, 221 | seq_len, 222 | seed=seed, 223 | rank=rank, 224 | world_size=world_size, 225 | is_eval=False, 226 | ) 227 | return data_loader 228 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import os 4 | import pprint 5 | from contextlib import ExitStack 6 | from pathlib import Path 7 | from typing import TYPE_CHECKING 8 | 9 | import fire 10 | import torch.cuda 11 | import torch.distributed as dist 12 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 13 | from torch.optim import AdamW, lr_scheduler 14 | 15 | from finetune.args import TrainArgs 16 | from finetune.checkpointing import Checkpointer 17 | from finetune.data.data_loader import build_data_loader 18 | from finetune.distributed import ( 19 | BACKEND, 20 | avg_aggregate, 21 | get_rank, 22 | get_world_size, 23 | is_torchrun, 24 | set_device, 25 | ) 26 | from finetune.eval import evaluate 27 | from finetune.loss import compute_loss_with_mask 28 | from finetune.mixed_precision import ( 29 | downcast_mixed_precision, 30 | prepare_mixed_precision, 31 | upcast_mixed_precision, 32 | ) 33 | from finetune.monitoring.metrics_logger import ( 34 | MetricsLogger, 35 | eval_log_msg, 36 | get_eval_logs, 37 | get_train_logs, 38 | train_log_msg, 39 | ) 40 | from finetune.monitoring.utils import set_logger 41 | from finetune.utils import ( 42 | TrainState, 43 | logged_closing, 44 | set_random_seed, 45 | ) 46 | from finetune.wrapped_model import load_model, load_args 47 | 48 | if TYPE_CHECKING: 49 | from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase 50 | 51 | logger = logging.getLogger("train") 52 | 53 | 54 | def main_logger_info(message: str) -> None: 55 | if get_rank() == 0: 56 | logger.info(message) 57 | 58 | 59 | def train(config: str): 60 | args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False) 61 | print(f"args: {args}") 62 | set_logger(logging.INFO) 63 | 64 | with ExitStack() as exit_stack: 65 | _train(args, exit_stack) 66 | logger.info("Closed everything!") 67 | 68 | 69 | def _train( 70 | args: TrainArgs, 71 | exit_stack: ExitStack, 72 | ): 73 | # 1. Initial setup and checks 74 | set_random_seed(args.seed) 75 | 76 | # Init NCCL 77 | if "LOCAL_RANK" in os.environ: 78 | set_device() 79 | logger.info("Going to init comms...") 80 | 81 | dist.init_process_group(backend=BACKEND) 82 | else: 83 | logger.error( 84 | "PyTorch environment is not correctly initialized. This message should only be displayed when testing." 85 | ) 86 | 87 | # 2. Init run dir 88 | main_logger_info(f"Run dir: {args.run_dir}") 89 | run_dir = Path(args.run_dir) 90 | 91 | if is_torchrun(): 92 | if run_dir.exists(): 93 | raise RuntimeError( 94 | f"Run dir {run_dir} already exists. Make sure to either rename `run_dir` or remove {run_dir}." 95 | ) 96 | 97 | dist.barrier() 98 | run_dir.mkdir(exist_ok=True, parents=True) 99 | 100 | args_path = run_dir / "args.yaml" 101 | if not args_path.exists(): 102 | args.save(args_path) 103 | 104 | main_logger_info(f"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}") 105 | 106 | # 3. 
Get loggers 107 | metrics_logger: MetricsLogger = MetricsLogger( 108 | run_dir, 109 | tag="train", 110 | is_master=get_rank() == 0, 111 | wandb_args=args.wandb, 112 | mlflow_args=args.mlflow, 113 | config=dataclasses.asdict(args), 114 | ) 115 | exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger")) 116 | 117 | eval_logger: MetricsLogger = MetricsLogger( 118 | run_dir, 119 | tag="eval", 120 | is_master=get_rank() == 0, 121 | wandb_args=args.wandb, 122 | mlflow_args=args.mlflow, 123 | config=dataclasses.asdict(args), 124 | ) 125 | exit_stack.enter_context(logged_closing(eval_logger, "eval_logger")) 126 | 127 | # 5. Potentially download model 128 | if Path(args.model_id_or_path).is_dir(): 129 | model_folder = Path(args.model_id_or_path) 130 | else: 131 | raise ValueError( 132 | "Invalid folder path. Please set `args.initial_model` to a valid folder path." 133 | ) 134 | 135 | # 6. Load function calling instruct tokenizer 136 | vocab_size = load_args(model_folder, args.lora).vocab_size 137 | is_tekken = vocab_size > 32768 138 | 139 | instruct_tokenizer: InstructTokenizerBase = MistralTokenizer.v3( 140 | is_tekken=is_tekken 141 | ).instruct_tokenizer # type: ignore 142 | 143 | # 7. Load data loaders 144 | data_loader = build_data_loader( 145 | instruct_tokenizer=instruct_tokenizer, 146 | args=args.data, 147 | seq_len=args.seq_len, 148 | batch_size=args.batch_size, 149 | seed=args.seed, 150 | rank=get_rank(), # DDP rank 151 | world_size=get_world_size(), # DDP world_size 152 | is_eval=False, 153 | ) 154 | 155 | if not args.no_eval: 156 | assert ( 157 | args.data.eval_instruct_data != "" 158 | ), "Either set `no_eval` to True or provide evaluation samples under `data.eval_instruct_data`" 159 | 160 | eval_data_loader = build_data_loader( 161 | instruct_tokenizer=instruct_tokenizer, 162 | args=args.data, 163 | seq_len=args.seq_len, 164 | batch_size=args.batch_size, 165 | seed=None, 166 | rank=get_rank(), # DDP rank 167 | world_size=get_world_size(), # DDP world_size 168 | is_eval=True, 169 | ) 170 | # pre-load all eval tokens 171 | eval_batches = list(eval_data_loader) 172 | 173 | # 8. Load model 174 | # Define mixed precision 175 | param_dtype = torch.bfloat16 176 | optim_dtype = torch.float32 177 | 178 | assert args.lora is not None, "`args.lora` should be set to a valid value." 179 | 180 | model = load_model( 181 | folder=model_folder, 182 | lora=args.lora, 183 | checkpoint=args.checkpoint, 184 | param_dtype=param_dtype, 185 | ) 186 | 187 | # 9. Load optimizer 188 | optimizer = AdamW( 189 | model.parameters(), 190 | lr=args.optim.lr, 191 | betas=(0.9, 0.95), 192 | eps=1e-08, 193 | weight_decay=args.optim.weight_decay, 194 | ) 195 | 196 | scheduler = lr_scheduler.OneCycleLR( 197 | optimizer, 198 | max_lr=args.optim.lr, 199 | total_steps=args.max_steps, 200 | pct_start=args.optim.pct_start, 201 | ) 202 | 203 | state = TrainState(args.max_steps) 204 | 205 | # 10. Initialize checkpointer 206 | checkpointer = Checkpointer( 207 | model=model, 208 | state=state, 209 | run_dir=run_dir, 210 | optimizer=optimizer, 211 | num_ckpt_keep=args.num_ckpt_keep, 212 | ) 213 | # 11. Prepare mixed precision 214 | prepare_mixed_precision( 215 | model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype 216 | ) 217 | 218 | # 12. train! 
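# The training loop below does manual gradient accumulation with mixed precision:
# each of the `num_microbatches` micro-batch losses is backpropagated in
# param_dtype (bfloat16), accumulated gradients are divided by the number of
# micro-batches, parameters and gradients are upcast to optim_dtype (float32)
# for gradient clipping and the AdamW step, and everything is downcast again
# before the next forward/backward pass.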
219 | model.train() 220 | torch.cuda.empty_cache() 221 | 222 | while state.step < args.max_steps: 223 | state.start_step() 224 | is_last_step = state.step == args.max_steps 225 | 226 | optimizer.zero_grad() 227 | 228 | loss = torch.tensor([0.0], device="cuda") 229 | n_batch_tokens: int = 0 230 | 231 | for i in range(args.num_microbatches): 232 | # batch 233 | batch = next(data_loader) 234 | 235 | x = torch.from_numpy(batch.x).cuda(non_blocking=True) 236 | y = torch.from_numpy(batch.y).cuda(non_blocking=True) 237 | y_mask = ( 238 | torch.from_numpy(batch.y_mask).cuda(non_blocking=True) 239 | if batch.y_mask is not None 240 | else None 241 | ) 242 | 243 | # forward / backward 244 | output = model( 245 | input_ids=x, 246 | seqlens=batch.sizes, 247 | ) 248 | mb_loss = compute_loss_with_mask(output, y, y_mask) 249 | 250 | mb_loss.backward() 251 | 252 | loss += mb_loss.detach() 253 | n_batch_tokens += x.numel() 254 | 255 | if i < args.num_microbatches - 1: 256 | # synchronize CUDA to re-run backward 257 | assert args.num_microbatches > 1 # should not happen 258 | torch.cuda.synchronize() 259 | 260 | if args.num_microbatches > 1: 261 | loss /= args.num_microbatches 262 | for p in model.parameters(): 263 | if p.requires_grad: 264 | assert p.grad is not None 265 | p.grad.div_(args.num_microbatches) 266 | 267 | # upcast params for optimizer update 268 | upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype) 269 | 270 | # clip grad norm 271 | model.clip_grad_norm_(max_norm=args.max_norm) 272 | 273 | # optimizer step 274 | optimizer.step() 275 | 276 | # downcast params for forward & backward 277 | downcast_mixed_precision(model.parameters(), param_dtype=param_dtype) 278 | 279 | last_lr = scheduler.get_last_lr()[0] 280 | scheduler.step() 281 | 282 | # Host sync 283 | loss_item = loss.item() 284 | avg_loss = avg_aggregate(loss_item) 285 | 286 | if not args.no_eval and ( 287 | (args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step 288 | ): 289 | # write perplexity to state 290 | evaluate(model, eval_batches, state) 291 | 292 | eval_logs = get_eval_logs( 293 | state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss 294 | ) 295 | 296 | main_logger_info(eval_log_msg(eval_logs)) 297 | eval_logger.log(eval_logs, step=state.step) 298 | 299 | # Timing 300 | state.end_step(n_batch_tokens) 301 | 302 | if state.step % args.log_freq == 0: 303 | train_logs = get_train_logs( 304 | state, 305 | avg_loss, 306 | last_lr, 307 | torch.cuda.max_memory_allocated(), 308 | torch.cuda.memory_allocated(), 309 | args, 310 | ) 311 | main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss)) 312 | metrics_logger.log(train_logs, step=state.step) 313 | 314 | if not args.no_ckpt and ( 315 | (args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step 316 | ): 317 | checkpointer.save_checkpoint( 318 | save_only_lora=args.save_adapters, 319 | dtype=param_dtype, 320 | instruct_tokenizer=instruct_tokenizer, 321 | ) 322 | 323 | main_logger_info("done!") 324 | 325 | 326 | if __name__ == "__main__": 327 | """See README.md for usage.""" 328 | fire.Fire(train) 329 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-finetune/656df1c94c80ca9703ebc471c9f106c9b7a0bfa7/utils/__init__.py -------------------------------------------------------------------------------- /utils/extend_model_vocab.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | from pathlib import Path 6 | 7 | import torch 8 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 9 | from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer 10 | 11 | from model.args import ModelArgs 12 | 13 | FIRST_PIECE_ID = 3 14 | OLD_VOCAB_SIZE = 32000 15 | NEW_VOCAB_SIZE = 32768 16 | 17 | 18 | def extend_model(original_model: Path, extended_model: Path): 19 | original_ckpt = torch.load(str(original_model / "consolidated.00.pth"), mmap=True) 20 | model_args = ModelArgs.load(str(original_model / "params.json")) 21 | 22 | original_vocab_size = model_args.vocab_size 23 | assert ( 24 | original_vocab_size == OLD_VOCAB_SIZE 25 | ), f"Original vocab size {original_vocab_size} is not equal to 32000. Can only extend models with vocab_size of 32000" 26 | 27 | if not extended_model.exists(): 28 | os.makedirs(extended_model, exist_ok=True) 29 | print(f"Created empty directory {extended_model}.") 30 | 31 | assert not list( 32 | extended_model.iterdir() 33 | ), f"Make sure {extended_model} is empty" 34 | 35 | # Load and check tokenizers 36 | mistral_tokenizer = MistralTokenizer.v3() 37 | tokenizer: SentencePieceTokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer 38 | 39 | new_vocab_size = tokenizer.n_words 40 | assert ( 41 | new_vocab_size == 32768 42 | ), f"New Tokenizer has vocab_size: {new_vocab_size} but has to be equal to 32768. Make sure to pass a v2 or v3 tokenizer file" 43 | 44 | vocabulary_delta = new_vocab_size - original_vocab_size 45 | 46 | # Check that 0...FIRST_PIECE_ID-1 are UNK + control characters and FIRST_PIECE_ID is the first piece 47 | assert tokenizer._model.id_to_piece(vocabulary_delta + FIRST_PIECE_ID) == "<0x00>" 48 | assert tokenizer._model.id_to_piece(FIRST_PIECE_ID - 1) == "" 49 | 50 | assert isinstance(tokenizer, SentencePieceTokenizer) 51 | 52 | original_embeddings = original_ckpt["tok_embeddings.weight"] 53 | 54 | assert ( 55 | original_vocab_size == original_embeddings.shape[0] 56 | ), f"Original vocab size {original_vocab_size} is not equal to original embeddings shape {original_embeddings.shape[0]}." 57 | 58 | dim = original_embeddings.shape[1] 59 | 60 | # Extend embeddings 61 | extended_embeddings = torch.zeros( 62 | tokenizer.n_words, dim, dtype=original_embeddings.dtype 63 | ) 64 | extended_embeddings[:original_vocab_size] = original_embeddings 65 | extended_embeddings[:FIRST_PIECE_ID] = original_embeddings[:FIRST_PIECE_ID] 66 | extended_embeddings[FIRST_PIECE_ID + vocabulary_delta :] = original_embeddings[ 67 | FIRST_PIECE_ID: 68 | ] 69 | 70 | # randomly initialize new tokens 71 | extended_tokens = torch.empty( 72 | vocabulary_delta, dim, dtype=original_embeddings.dtype 73 | ) 74 | torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim)) 75 | 76 | extended_embeddings[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = ( 77 | extended_tokens 78 | ) 79 | 80 | # Extend output 81 | original_output = original_ckpt["output.weight"] 82 | assert ( 83 | original_output.shape[0] == original_vocab_size 84 | ), f"Original output shape {original_output.shape[0]} is not equal to {original_vocab_size}." 85 | assert ( 86 | original_output.shape[1] == dim 87 | ), f"Original output dim {original_output.shape[1]} is not equal to embedding dim {dim}." 
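# The output matrix below is extended with the same layout used for the embeddings
# above: rows [0, FIRST_PIECE_ID) keep the special/control tokens, the
# vocabulary_delta new rows are inserted at [FIRST_PIECE_ID, FIRST_PIECE_ID + vocabulary_delta)
# with Normal(0, 1/sqrt(dim)) initialization, and the original vocabulary rows
# are shifted to the tail of the matrix.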
88 | 89 | assert ( 90 | original_output.dtype == original_embeddings.dtype 91 | ), f"Original output and embeddings have different dtypes: {original_output.dtype} vs {original_embeddings.dtype}." 92 | 93 | extended_output = torch.zeros(tokenizer.n_words, dim, dtype=original_output.dtype) 94 | extended_output[:FIRST_PIECE_ID] = original_output[:FIRST_PIECE_ID] 95 | extended_output[FIRST_PIECE_ID + vocabulary_delta :] = original_output[ 96 | FIRST_PIECE_ID: 97 | ] 98 | 99 | # randomly initialize new tokens 100 | extended_tokens = torch.empty(vocabulary_delta, dim, dtype=original_output.dtype) 101 | torch.nn.init.normal_(extended_tokens, std=1 / math.sqrt(dim)) 102 | 103 | extended_output[FIRST_PIECE_ID : FIRST_PIECE_ID + vocabulary_delta] = ( 104 | extended_tokens 105 | ) 106 | 107 | original_ckpt["tok_embeddings.weight"] = extended_embeddings 108 | original_ckpt["output.weight"] = extended_output 109 | 110 | new_ckpt_path = extended_model / "consolidated.00.pth" 111 | print(f"Exporting extended model to {extended_model} ...") 112 | torch.save(original_ckpt, new_ckpt_path) 113 | 114 | params_path = extended_model / "params.json" 115 | with open(params_path, "w") as f: 116 | model_dict = model_args.to_dict() 117 | del model_dict["lora"] 118 | if model_dict["moe"] is None: 119 | del model_dict["moe"] 120 | model_dict["vocab_size"] = new_vocab_size 121 | 122 | f.write(json.dumps(model_dict, indent=4)) 123 | 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser( 127 | description="Extend a model using the specified original model, extended model, and tokenizer paths." 128 | ) 129 | parser.add_argument( 130 | "--original_model_ckpt", type=Path, help="Path to the original model folder." 131 | ) 132 | parser.add_argument( 133 | "--extended_model_ckpt", type=Path, help="Path to the extended model file." 134 | ) 135 | args = parser.parse_args() 136 | 137 | extend_model( 138 | original_model=args.original_model_ckpt, 139 | extended_model=args.extended_model_ckpt, 140 | ) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /utils/merge_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Dict, Optional 3 | 4 | import safetensors.torch 5 | import torch 6 | import tqdm 7 | 8 | 9 | def merge_checkpoints( 10 | model_checkpoint: Dict[str, torch.Tensor], 11 | lora_checkpoint: Dict[str, torch.Tensor], 12 | scaling: float, 13 | save_dtype: Optional[torch.dtype] = None, 14 | ): 15 | save_dtype = save_dtype or next(iter(lora_checkpoint.values())).dtype 16 | print(f"Merging to {save_dtype} precision...") 17 | 18 | keys_to_update = [ 19 | key for key in lora_checkpoint.keys() if "norm" in key or "lora_A" in key 20 | ] 21 | assert any( 22 | "lora_A" in k or "lora_B" in k for k in keys_to_update 23 | ), "No `lora` keys found in your checkpoint. Check that `lora_ckpt` is correct." 
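# The loop below folds each LoRA adapter pair back into its frozen base weight,
#     W_merged = W + scaling * (lora_B @ lora_A)
# while "norm" weights (trained directly during LoRA fine-tuning rather than
# through low-rank adapters) are taken as-is from the LoRA checkpoint.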
24 | 25 | for key in tqdm.tqdm(keys_to_update): 26 | if "norm" in key: 27 | model_checkpoint[key] = lora_checkpoint[key].to("cpu") 28 | else: 29 | weight_name = key.replace("lora_A.weight", "weight") 30 | 31 | lora_A_weight = lora_checkpoint[key].to("cuda") 32 | lora_B_weight = lora_checkpoint[key.replace("lora_A", "lora_B")].to("cuda") 33 | 34 | weight = lora_B_weight.mm(lora_A_weight) * scaling 35 | weight += model_checkpoint[weight_name].to("cuda") 36 | weight = weight.to(save_dtype) 37 | 38 | model_checkpoint[weight_name] = weight.to("cpu") 39 | 40 | # cast all tensors to save dtype 41 | for key in tqdm.tqdm(model_checkpoint.keys()): 42 | model_checkpoint[key] = model_checkpoint[key].to(save_dtype) 43 | 44 | 45 | def load(filename: str): 46 | if filename.endswith(".safetensors"): 47 | return safetensors.torch.load_file(filename) 48 | else: 49 | return torch.load(filename) 50 | 51 | 52 | def main(args): 53 | model_checkpoint = load(args.initial_model_ckpt) 54 | lora_checkpoint = load(args.lora_ckpt) 55 | 56 | merge_checkpoints(model_checkpoint, lora_checkpoint, args.scaling) 57 | 58 | safetensors.torch.save_file(model_checkpoint, args.dump_ckpt) 59 | 60 | print(f"Merged checkpoint saved to {args.dump_ckpt}") 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser( 65 | description="Merge a LoRA checkpoint into a model checkpoint." 66 | ) 67 | parser.add_argument( 68 | "--initial_model_ckpt", 69 | type=str, 70 | required=True, 71 | help="Path to the model checkpoint.", 72 | ) 73 | parser.add_argument( 74 | "--lora_ckpt", type=str, required=True, help="Path to the LoRA checkpoint." 75 | ) 76 | parser.add_argument( 77 | "--dump_ckpt", 78 | type=str, 79 | required=True, 80 | help="Path to save the merged checkpoint.", 81 | ) 82 | parser.add_argument( 83 | "--scaling", 84 | type=float, 85 | default=2.0, 86 | help="Scaling factor for the LoRA checkpoint. 
Default is 2.0.", 87 | ) 88 | 89 | args = parser.parse_args() 90 | main(args) 91 | -------------------------------------------------------------------------------- /utils/reformat_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import os 5 | import random 6 | import string 7 | 8 | 9 | def reformat_jsonl(input_file): 10 | output_file = input_file + ".tmp" 11 | 12 | content_keys = ["content", "text"] 13 | with open(input_file, "r") as infile, open(output_file, "w") as outfile: 14 | for idx, line in enumerate(infile): 15 | data = json.loads(line) 16 | if "functions" in data: 17 | data["tools"] = [{"function": func} for func in data["functions"]] 18 | del data["functions"] 19 | 20 | skip_sample = False 21 | if "messages" in data: 22 | for i, msg in enumerate(data["messages"]): 23 | if "function_call" in msg: 24 | if "content" in msg: 25 | assert msg["content"] == "" 26 | del msg["content"] 27 | 28 | arguments = json.loads(msg["function_call"]["arguments"]) 29 | msg["function_call"]["arguments"] = json.dumps(arguments) 30 | 31 | msg["tool_calls"] = [{"function": msg.pop("function_call")}] 32 | 33 | for key in content_keys: 34 | if key in msg and msg[key] == "": 35 | if "tool_calls" in msg: 36 | del msg[key] 37 | print( 38 | f"Delete empty '{key}' field in tool call message in line {idx}" 39 | ) 40 | 41 | if all(msg.get(key) in ["", None] for key in content_keys): 42 | # conversation is incorrect 43 | skip_sample = True 44 | 45 | if msg["role"] in ["function", "tool"]: 46 | msg["role"] = "tool" 47 | if "tool_call_id" not in msg: 48 | msg["tool_call_id"] = "".join( 49 | random.choices( 50 | string.ascii_letters + string.digits, k=9 51 | ) 52 | ) 53 | 54 | # make sure prev 55 | if data["messages"][i - 1]["role"] == "assistant": 56 | prev_msg = data["messages"][i - 1] 57 | if "tool_calls" in prev_msg: 58 | tool_name = prev_msg["tool_calls"][0]["function"][ 59 | "name" 60 | ] 61 | 62 | assert tool_name == msg["name"] 63 | prev_msg["tool_calls"][0]["id"] = msg["tool_call_id"] 64 | 65 | # make sure last message is an assistant message 66 | while ( 67 | len(data["messages"]) > 0 68 | and data["messages"][-1]["role"] != "assistant" 69 | ): 70 | data["messages"].pop() 71 | 72 | if len(data["messages"]) == 0: 73 | skip_sample = True 74 | 75 | if not skip_sample: 76 | outfile.write(json.dumps(data) + "\n") 77 | else: 78 | print(f"Skipped {idx}th sample") 79 | 80 | os.rename(output_file, input_file) 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser(description="Reformat a JSONL file.") 85 | parser.add_argument("file", type=str, help="The input JSONL file") 86 | 87 | args = parser.parse_args() 88 | reformat_jsonl(args.file) 89 | -------------------------------------------------------------------------------- /utils/reformat_data_glaive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import json 4 | import os 5 | import random 6 | import string 7 | 8 | 9 | def reformat_jsonl(input_file): # noqa: C901 10 | output_file = os.path.splitext(input_file)[0] + "_reformatted.jsonl" 11 | skipped_samples = [] 12 | 13 | with open(input_file, "r") as infile, open(output_file, "w") as outfile: 14 | for i, line in enumerate(infile): 15 | reformat_data = True 16 | data = json.loads(line) 17 | 18 | # Extract function description 19 | try: 20 | function_desc = json.loads(data["function_description"]) 21 
| except json.decoder.JSONDecodeError: 22 | function_desc = ( 23 | data["function_description"].replace("\n", "").replace("}{", "},{").replace("\\t", "") 24 | ) 25 | function_desc = "[{" + function_desc[1:-1] + "}]" 26 | function_desc = json.loads(function_desc) 27 | 28 | function_desc = function_desc if isinstance(function_desc, list) else [function_desc] 29 | 30 | # Reformat tools section 31 | if len(function_desc) == 1 and function_desc[0] == {}: 32 | tools = None 33 | else: 34 | tools = [] 35 | for f in function_desc: 36 | if f["parameters"] is None: 37 | f["parameters"] = {} 38 | tools.append({"type": "function", "function": f}) 39 | 40 | messages = [] 41 | 42 | # Process conversations 43 | for idx, msg in enumerate(data["conversations"]): 44 | role = msg["from"] 45 | content = msg["value"] 46 | 47 | if role == "system": 48 | messages.append( 49 | {"role": "system", "content": content.split(" -")[0]} 50 | ) 51 | elif role == "human": 52 | messages.append({"role": "user", "content": content}) 53 | elif role == "function-call": 54 | try: 55 | function_call = json.loads(content) 56 | except json.decoder.JSONDecodeError: 57 | content = content.replace("'", "").replace("\\", "'") 58 | try: 59 | function_call = json.loads(content) 60 | except: # noqa: E722 61 | skipped_samples.append(str(i)) 62 | reformat_data = False 63 | break 64 | 65 | if not isinstance(function_call, list): 66 | function_calls = [function_call] 67 | else: 68 | function_calls = function_call 69 | 70 | tool_calls = [] 71 | for function_call in function_calls: 72 | assert not isinstance(function_call, list) 73 | tool_call_id = "".join( 74 | random.choices(string.ascii_letters + string.digits, k=9) 75 | ) 76 | 77 | if "arguments" in function_call and not isinstance(function_call["arguments"], str): 78 | function_call["arguments"] = str(function_call["arguments"]) 79 | elif "arguments" not in function_call: 80 | function_call["arguments"] = "" 81 | 82 | tool_calls.append({"id": tool_call_id, "type": "function", "function": function_call}) 83 | 84 | messages.append( 85 | { 86 | "role": "assistant", 87 | "tool_calls": tool_calls 88 | } 89 | ) 90 | elif role == "function-response": 91 | if "tool_calls" not in messages[-1]: 92 | skipped_samples.append(str(i)) 93 | reformat_data = False 94 | break 95 | 96 | assert len(messages[-1]["tool_calls"]) == 1 97 | tool_call_id = messages[-1]["tool_calls"][0]["id"] 98 | messages.append( 99 | { 100 | "role": "tool", 101 | "content": content, 102 | "tool_call_id": tool_call_id, 103 | } 104 | ) 105 | elif role == "gpt": 106 | messages.append({"role": "assistant", "content": content}) 107 | 108 | output_data = {"messages": messages} 109 | 110 | if tools is not None: 111 | output_data["tools"] = tools 112 | 113 | if reformat_data: 114 | outfile.write(json.dumps(output_data) + "\n") 115 | 116 | os.rename(output_file, input_file) 117 | print( 118 | f"Skipped {len(skipped_samples)} samples ({len(skipped_samples) / (i + 1):.2%}).
The following samples are incorrectly formatted: \n\n {', '.join(skipped_samples)}" 119 | ) 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser(description="Reformat a JSONL file.") 124 | parser.add_argument("file", type=str, help="The input JSONL file") 125 | 126 | args = parser.parse_args() 127 | reformat_jsonl(args.file) 128 | -------------------------------------------------------------------------------- /utils/validate_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import Dict 5 | 6 | from mistral_common.exceptions import ( 7 | InvalidAssistantMessageException, 8 | InvalidFunctionCallException, 9 | InvalidMessageStructureException, 10 | InvalidToolSchemaException, 11 | TokenizerException, 12 | ) 13 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 14 | from tqdm import tqdm 15 | 16 | from finetune.args import TrainArgs 17 | from finetune.data.dataset import parse_data_sources 18 | from finetune.data.tokenize import ( 19 | ConversationFormatError, 20 | FunctionFormatError, 21 | MessageFormatError, 22 | SampleType, 23 | ToolCallFormatError, 24 | UnrecognizedRoleError, 25 | build_instruct_sample, 26 | get_pretrain_sample, 27 | tokenize, 28 | ) 29 | 30 | NUM_GPUS = 8 31 | 32 | # EXPECTED WPS for batch_size = 32768 per GPU on H100 33 | EXPECTED_WPS = { 34 | "open-mistral-7b": 5720, 35 | "open-mixtral-8x7b": 2966, 36 | "open-mixtral-8x22b": 1007, 37 | "mistral-large-latest": 567, 38 | "open-mistral-nemo": 3337, 39 | } 40 | 41 | MIN_NUM_JSONL_LINES = 10 42 | MAX_NUM_JSONL_LINES = 10_000_000 43 | 44 | MIN_BYTES = 1_000 45 | MAX_BYTES = 10_000_000_000 # roughly 10 GB 46 | 47 | 48 | def convert_seconds_to_hms(seconds: float) -> str: 49 | hours = int(seconds // 3600) 50 | seconds %= 3600 51 | minutes = int(seconds // 60) 52 | seconds %= 60 53 | 54 | seconds = int(seconds) 55 | return f"{hours:02d}:{minutes:02d}:{seconds:02d}" 56 | 57 | 58 | def verify_size(jsonl_file: Path): 59 | print(f"Verifying {jsonl_file} ...") 60 | with jsonl_file.open() as f: 61 | num_lines = 0 62 | num_bytes = 0 63 | for line in f: 64 | num_lines += 1 65 | num_bytes += len(line) 66 | 67 | if num_lines < MIN_NUM_JSONL_LINES: 68 | raise ValueError( 69 | f"{jsonl_file} has only {num_lines} conversations, which is less than the minimum number of conversations required per dataset file: {MIN_NUM_JSONL_LINES}. Please make sure that each dataset has at least {MIN_NUM_JSONL_LINES} conversations." 70 | ) 71 | elif num_bytes < MIN_BYTES: 72 | raise ValueError( 73 | f"{jsonl_file} has only {num_bytes} bytes, which is less than the minimum number of bytes required per dataset file: {MIN_BYTES}. Please make sure that each dataset has at least {MIN_BYTES} bytes." 74 | ) 75 | elif num_lines > MAX_NUM_JSONL_LINES: 76 | raise ValueError( 77 | f"{jsonl_file} has {num_lines} conversations, which is more than the maximum number of conversations allowed per dataset file: {MAX_NUM_JSONL_LINES}. Please make sure that each dataset has no more than {MAX_NUM_JSONL_LINES} conversations." 78 | ) 79 | elif num_bytes > MAX_BYTES: 80 | raise ValueError( 81 | f"{jsonl_file} has {num_bytes} bytes, which is more than the maximum number of bytes allowed per dataset file: {MAX_BYTES}. Please make sure that each dataset has no more than {MAX_BYTES} bytes." 82 | ) 83 | 84 | print( 85 | f"Dataset {jsonl_file} is valid. Dataset has {num_lines} conversations amounting to a total of {num_bytes} bytes."
86 | ) 87 | 88 | 89 | def get_train_stats( 90 | num_tokens: Dict[str, int], 91 | datasets_proportion: Dict[str, float], 92 | train_args: TrainArgs, 93 | return_type: str, 94 | ): 95 | dataset_tokens = sum(num_tokens.values()) 96 | batch_size = train_args.batch_size * train_args.seq_len * NUM_GPUS 97 | 98 | if Path(train_args.model_id_or_path).is_dir(): 99 | params_config = json.load( 100 | (Path(train_args.model_id_or_path) / "params.json").open() 101 | ) 102 | 103 | if params_config["dim"] == 4096 and params_config.get("moe") is None: 104 | model_id = "open-mistral-7b" 105 | elif params_config["dim"] == 4096 and params_config.get("moe") is not None: 106 | model_id = "open-mixtral-8x7b" 107 | elif params_config["dim"] == 6144: 108 | model_id = "open-mixtral-8x22b" 109 | elif params_config["dim"] == 12288: 110 | model_id = "mistral-large-latest" 111 | elif params_config["dim"] == 5120: 112 | model_id = "open-mistral-nemo" 113 | else: 114 | raise ValueError("Provided model folder seems incorrect.") 115 | else: 116 | model_id = train_args.model_id_or_path 117 | 118 | wps = EXPECTED_WPS[model_id] 119 | 120 | if return_type == "expected": 121 | train_tokens = train_args.max_steps * batch_size 122 | max_steps = train_args.max_steps 123 | num_epochs = train_tokens / dataset_tokens 124 | elif return_type == "recommended": 125 | num_epochs = 3 126 | max_steps = int(sum(num_tokens.values()) / batch_size * num_epochs) 127 | train_tokens = max_steps * batch_size 128 | else: 129 | raise ValueError( 130 | f"`return_type` is {return_type}, but has to be one of ['expected', 'recommended']" 131 | ) 132 | 133 | expected_time_in_sec = train_tokens / NUM_GPUS / wps 134 | 135 | # Add 5min buffer for loading/init/ckpt/eval 136 | expected_time_in_sec += 300 137 | 138 | train_tokens_per_dataset = { 139 | k: (train_tokens * v) for k, v in datasets_proportion.items() 140 | } 141 | 142 | return { 143 | "eta": convert_seconds_to_hms(expected_time_in_sec), 144 | "data_tokens": dataset_tokens, 145 | "train_tokens": train_tokens, 146 | "epochs": f"{num_epochs:.2f}", 147 | "max_steps": max_steps, 148 | "data_tokens_per_dataset": {k: f"{v:.1f}" for k, v in num_tokens.items()}, 149 | "train_tokens_per_dataset": { 150 | k: f"{v:.1f}" for k, v in train_tokens_per_dataset.items() 151 | }, 152 | "epochs_per_dataset": { 153 | k: f"{(train_tokens_per_dataset[k] / num_tokens[k]):.1f}" 154 | for k in num_tokens.keys() 155 | }, 156 | } 157 | 158 | 159 | def main(args): 160 | train_args = TrainArgs.load(args.train_yaml) 161 | 162 | yaml_data_errors = [] 163 | conversation_format_errors = [] 164 | message_format_errors = [] 165 | tokenization_errors = [] 166 | 167 | # Check if pretrain can be loaded 168 | # train_pretrain_data = train_args.data.data 169 | data = [("train", train_args.data.data, train_args.data.instruct_data)] 170 | 171 | if train_args.data.eval_instruct_data != "": 172 | data.append(("eval", "", train_args.data.eval_instruct_data)) 173 | 174 | EXPECTED_WPS.keys() 175 | 176 | instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer 177 | 178 | for name, pretrain_file, instruct_file in data: 179 | datasets, weights = parse_data_sources(pretrain_file, instruct_file) 180 | data_types = [d.sample_type for d in datasets] 181 | datasets = [str(d.path) for d in datasets] 182 | 183 | datasets_proportion = dict(zip(datasets, weights)) 184 | num_tokens = {k: 0 for k in datasets_proportion} 185 | 186 | for data_type, dataset in tqdm(zip(data_types, datasets)): 187 | # verify_size(Path(dataset)) 188 | print(f"Validating 
{dataset} ...") 189 | 190 | corrected_dataset = dataset + ".corrected" 191 | correct_lines = [] 192 | 193 | sub_yaml_data_errors = [] 194 | sub_conversation_format_errors = [] 195 | sub_message_format_errors = [] 196 | sub_tokenization_errors = [] 197 | 198 | # Load the dataset 199 | with open(dataset, "r", encoding="utf-8") as f: 200 | lines = f.readlines() 201 | for idx, line in tqdm(enumerate(lines), total=len(lines)): 202 | try: 203 | data = json.loads(line) 204 | except ValueError as e: 205 | prefix = f"The data in line {idx + 1} of dataset {dataset} is incorrectly formatted." 206 | sub_yaml_data_errors.append(prefix + str(e)) 207 | continue 208 | 209 | if data_type == SampleType.PRETRAIN: 210 | # TODO(Patrick) - Get good error message 211 | sample = get_pretrain_sample(data) 212 | else: 213 | try: 214 | sample = build_instruct_sample(data) 215 | 216 | except ( 217 | ConversationFormatError, 218 | UnrecognizedRoleError, 219 | MessageFormatError, 220 | ToolCallFormatError, 221 | FunctionFormatError, 222 | InvalidAssistantMessageException, 223 | InvalidFunctionCallException, 224 | InvalidMessageStructureException, 225 | InvalidToolSchemaException, 226 | ) as e: 227 | prefix = f"The data in line {idx + 1} of dataset {dataset} is incorrectly formatted." 228 | if isinstance( 229 | e, (ConversationFormatError, FunctionFormatError) 230 | ): 231 | sub_conversation_format_errors.append(prefix + str(e)) 232 | elif isinstance( 233 | e, 234 | ( 235 | MessageFormatError, 236 | UnrecognizedRoleError, 237 | ToolCallFormatError, 238 | ), 239 | ): 240 | sub_message_format_errors.append(prefix + str(e)) 241 | if isinstance( 242 | e, 243 | ( 244 | InvalidFunctionCallException, 245 | InvalidMessageStructureException, 246 | InvalidAssistantMessageException, 247 | InvalidToolSchemaException, 248 | ), 249 | ): 250 | sub_conversation_format_errors.append(prefix + str(e)) 251 | 252 | continue 253 | try: 254 | tokens = tokenize(sample, instruct_tokenizer).tokens 255 | except TokenizerException as e: 256 | error_message = ( 257 | f"The data in line {idx + 1} of dataset {dataset} could not be tokenized. 
" 258 | + str(e) 259 | ) 260 | sub_tokenization_errors.append(error_message) 261 | 262 | correct_lines.append(line) 263 | num_tokens[dataset] += len(tokens) 264 | 265 | is_sub_error = ( 266 | len( 267 | sub_yaml_data_errors 268 | + sub_conversation_format_errors 269 | + sub_message_format_errors 270 | + sub_tokenization_errors 271 | ) 272 | > 0 273 | ) 274 | if is_sub_error and args.create_corrected: 275 | with open(corrected_dataset, "w", encoding="utf-8") as f: 276 | for line in correct_lines: 277 | f.write(line) 278 | 279 | print(f"Saved {corrected_dataset}.") 280 | elif args.create_corrected: 281 | print(f"No error in {dataset} - no need to create a corrected version.") 282 | 283 | yaml_data_errors.extend(sub_yaml_data_errors) 284 | conversation_format_errors.extend(sub_conversation_format_errors) 285 | message_format_errors.extend(sub_message_format_errors) 286 | tokenization_errors.extend(sub_tokenization_errors) 287 | 288 | is_error = ( 289 | len( 290 | yaml_data_errors 291 | + conversation_format_errors 292 | + message_format_errors 293 | + tokenization_errors 294 | ) 295 | > 0 296 | ) 297 | if is_error: 298 | all_yaml_data_errors = "\n".join(yaml_data_errors) 299 | all_conversation_format_errors = "\n".join(conversation_format_errors) 300 | all_message_format_errors = "\n".join(message_format_errors) 301 | all_tokenization_errors = "\n".join(tokenization_errors) 302 | error_report = f""" 303 | Data error report 304 | ----------------------- \n 305 | The passed datasets contains some errors as listed below. Please make sure to fix these errors in order to start training. 306 | 307 | YAML data load errors: \n\n {all_yaml_data_errors} \n\n 308 | Conversation format errors: \n\n {all_conversation_format_errors} \n\n 309 | Message format errors: \n\n {all_message_format_errors} \n\n 310 | Tokenization errors: \n\n {all_tokenization_errors} \n\n 311 | """ 312 | if args.save_reports: 313 | with open(args.error_report_txt, "w") as f: 314 | f.write(error_report) 315 | 316 | print(error_report) 317 | else: 318 | print("No errors! 
Data is correctly formatted!") 319 | 320 | if name == "train" and not is_error: 321 | expected_stats = get_train_stats( 322 | num_tokens, datasets_proportion, train_args, return_type="expected" 323 | ) 324 | stats = { 325 | "expected": expected_stats, 326 | } 327 | 328 | filenames = ( 329 | f"{instruct_file}" 330 | if pretrain_file == "" 331 | else f"{instruct_file} and {pretrain_file}" 332 | ) 333 | 334 | print( 335 | f"Stats for {filenames} \n {20 * '-'} \n {json.dumps(stats, indent=4)}" 336 | ) 337 | 338 | if args.save_reports: 339 | if name == "train": 340 | with open(args.train_stats_json, "w") as file: 341 | json.dump(stats, file, indent=4) 342 | 343 | 344 | if __name__ == "__main__": 345 | parser = argparse.ArgumentParser(description="Validate jsonl.") 346 | parser.add_argument( 347 | "--train_yaml", 348 | type=str, 349 | help="Path to the data file", 350 | ) 351 | parser.add_argument( 352 | "--error_report_txt", 353 | type=str, 354 | default="data_errors.txt", 355 | help="Path to the error report.", 356 | ) 357 | parser.add_argument( 358 | "--train_stats_json", 359 | type=str, 360 | default="train_stats.json", 361 | help="Path to training statistics json file.", 362 | ) 363 | parser.add_argument( 364 | "--save_reports", action="store_true", help="Save reports to disk" 365 | ) 366 | parser.add_argument( 367 | "--create_corrected", 368 | action="store_true", 369 | help="Skip faulty lines and append all correct lines to `.corrected` datasets.", 370 | ) 371 | args = parser.parse_args() 372 | main(args) 373 | --------------------------------------------------------------------------------