├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_ZH.md ├── adalomo ├── README.md ├── README_ZH.md ├── further-pretraining │ ├── evaluate.py │ └── train.py └── instruction-tuning │ └── train.py ├── assets ├── LOMO.png ├── adalomo_algorithm.png └── hook_func.png ├── lomo ├── README.md ├── README_ZH.md ├── config │ ├── args_lomo.yaml │ ├── args_lomo_lora.yaml │ ├── ds_config.json │ └── ds_config_lora.json ├── log │ ├── __init__.py │ ├── handler.py │ ├── highlighter.py │ ├── logger.py │ └── print.py ├── run.sh └── src │ ├── arguments.py │ ├── lomo.py │ ├── lomo_lora_trainer.py │ ├── lomo_trainer.py │ ├── merge_llama_with_lora.py │ ├── mydatasets.py │ ├── prompts.py │ ├── train_lomo.py │ ├── train_lomo_lora.py │ └── utils.py ├── lomo_optim ├── __init__.py ├── adalomo.py └── lomo.py └── pyproject.toml /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.x' 29 | 30 | - name: Install dependencies 31 | run: pip install build twine 32 | - name: Build package 33 | run: python -m build 34 | 35 | - name: Publish 36 | env: 37 | TWINE_USERNAME: __token__ 38 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 39 | run: twine upload dist/* 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | lomo_optim.egg-info 3 | dist 4 | build 5 | __pycache__ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 OpenLMLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [**English**](./README.md) | [**中文**](./README_ZH.md) 2 | 3 | This is the implementation for [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://arxiv.org/pdf/2306.09782.pdf) 4 | and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/pdf/2310.10195.pdf). 5 | 6 | # News 7 | - LOMO and AdaLomo were integrated in [`transformers`](https://huggingface.co/docs/transformers/main/en/trainer#lomo-optimizer) and [`accelerate`](https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.lomo_backward). 8 | - PyPI package `lomo-optim` was released. 9 | - LOMO and AdaLomo were integrated in [`CoLLiE`](https://github.com/OpenMOSS/collie) library, which supports Collaborative Training of Large Language Models in an Efficient Way. 10 | 11 | # Usage 12 | You can install `lomo-optim` from PyPI using pip. 13 | 14 | ```bash 15 | pip install lomo-optim 16 | ``` 17 | 18 | Then, import `Lomo` or `AdaLomo`. 
19 | 20 | ```python 21 | from lomo_optim import Lomo 22 | from lomo_optim import AdaLomo 23 | ``` 24 | 25 | The usage of `Lomo` and `AdaLomo` is similar to, but not the same as, that of PyTorch's optimizers 26 | ([example](https://github.com/OpenMOSS/CoLLiE/blob/726ec80d263c1e1c56344dfde5b3c24897daa94d/collie/controller/trainer.py#L469)). 27 | We recommend using `AdaLomo` without `gradnorm` for better performance and higher throughput. 28 | 29 | # LOMO: LOw-Memory Optimization 30 | 31 | In this work, we propose a new optimizer, **LO**w-Memory **O**ptimization (**LOMO**), which fuses the gradient computation and the parameter update in one step to reduce memory usage. 32 | Our approach enables the full parameter fine-tuning of a 7B model on a single RTX 3090, or 33 | a 65B model on a single machine with 8×RTX 3090, each with 24GB of memory. 34 | 35 | ![LOMO](https://raw.githubusercontent.com/OpenLMLab/LOMO/main/assets/LOMO.png) 36 | 37 | ## Implementation 38 | ![Hook function](https://raw.githubusercontent.com/OpenLMLab/LOMO/main/assets/hook_func.png) 39 | Our implementation relies on injecting hook functions into PyTorch's backward pass. As depicted in the figure, we register a customized hook function for each parameter. When the gradient of a parameter is computed (prior to writing it to the .grad attribute), its corresponding hook function is invoked. For more information about hook functions and the backward pass of the autograd graph, please refer to [PyTorch's documentation](https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution). In summary, during the backward pass, we go through a tensor and its grad_fn, write the gradient into the .grad attribute, and then move on to the next tensor. 40 | 41 | Our customized hook function scans all the parameters, updates any parameter whose .grad attribute is not empty, and then clears and frees that .grad attribute.
Since the hook function for a parameter is called before its .grad attribute is set, the .grad attribute of the last parameter in the autograd graph is not ready when the last hook function is invoked. Therefore, we perform an additional scan to update the last parameter. 42 | 43 | The code for LOMO is in the [lomo](lomo) folder. 44 | 45 | # AdaLomo: Low-memory Optimization with Adaptive Learning Rate 46 | 47 | In this work, we examine the distinctions between the LOMO and Adam optimization techniques and introduce AdaLomo, which provides an adaptive learning rate for each parameter and utilizes grouped update normalization while maintaining memory efficiency. 48 | AdaLomo achieves results comparable to AdamW in both instruction-tuning and further pre-training with a smaller memory footprint. 49 | 50 | ![AdaLomo](https://raw.githubusercontent.com/OpenLMLab/LOMO/main/assets/adalomo_algorithm.png) 51 | 52 | The code for AdaLomo is in the [adalomo](adalomo) folder. 53 | 54 | ## Citation 55 | ```text 56 | @article{lv2023full, 57 | title={Full Parameter Fine-tuning for Large Language Models with Limited Resources}, 58 | author={Lv, Kai and Yang, Yuqing and Liu, Tengxiao and Gao, Qinghui and Guo, Qipeng and Qiu, Xipeng}, 59 | journal={arXiv preprint arXiv:2306.09782}, 60 | year={2023} 61 | } 62 | @article{lv2023adalomo, 63 | title={AdaLomo: Low-memory Optimization with Adaptive Learning Rate}, 64 | author={Lv, Kai and Yan, Hang and Guo, Qipeng and Lv, Haijun and Qiu, Xipeng}, 65 | journal={arXiv preprint arXiv:2310.10195}, 66 | year={2023} 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 | [**English**](./README.md) | [**中文**](./README_ZH.md) 2 | 3 | 论文 [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://arxiv.org/pdf/2306.09782.pdf) 和 [AdaLomo: Low-memory Optimization with Adaptive Learning
Rate](https://arxiv.org/pdf/2310.10195.pdf) 的实现. 4 | 5 | # 新闻 6 | - LOMO 和 AdaLomo 集成到了 [`transformers`](https://huggingface.co/docs/transformers/main/en/trainer#lomo-optimizer) 和 [`accelerate`](https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.lomo_backward) 中. 7 | - 发布了 PyPI 包 `lomo-optim`. 8 | - LOMO 和 AdaLomo 已经集成到了 [`CoLLiE`](https://github.com/OpenLMLab/collie) (Collaborative Training of Large Language Models in an Efficient Way) 中。 9 | 10 | # Usage 11 | 可以使用 pip 从 PyPI 安装 `lomo-optim` 包。 12 | 13 | ```bash 14 | pip install lomo-optim 15 | ``` 16 | 17 | 然后,从 `lomo_optim` 中导入 `Lomo` 或 `AdaLomo` 18 | 19 | ```python 20 | from lomo_optim import Lomo 21 | from lomo_optim import AdaLomo 22 | ``` 23 | `Lomo` 和 `AdaLomo` 的使用方法与 PyTorch 的优化器类似,但不完全相同([示例](https://github.com/OpenMOSS/CoLLiE/blob/726ec80d263c1e1c56344dfde5b3c24897daa94d/collie/controller/trainer.py#L469))。 24 | 推荐使用 `AdaLomo` 并且不加 `gradnorm` 来获得更好的性能同时维持更高的吞吐量。 25 | 26 | # LOMO: LOw-Memory Optimization 27 | 28 | 在这个工作中,我们提出了一个新的优化器,**LO**w-Memory **O**ptimization (**LOMO**),它将梯度计算和参数更新融合在一步中,以减少内存使用。 29 | 我们的方法使得在单张 RTX 3090 上可以进行 7B 模型的全参数微调,或者在单个 8×RTX 3090 的机器上可以进行 65B 模型的全参数微调(RTX 3090 的内存为 24GB)。 30 | 31 | ![LOMO](assets/LOMO.png) 32 | 33 | ## 实现 34 | ![Hook function](assets/hook_func.png) 35 | 我们通过在PyTorch的反向传播过程中注入钩子函数实现我们的方法。如图中所示,我们为模型的每一个参数都注册了自定义的钩子函数。当一个参数的梯度计算完毕之后(但还没有写入到.grad),它对应的钩子函数就被调用了。更多关于钩子函数和反向传播的介绍可以参考[PyTorch的官方文档](https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution)。简而言之,反向过程会从一个张量到它的梯度函数,然后把梯度写入.grad,再传递到下一个张量。 36 | 37 | 我们的自定义钩子函数会扫描所有的参数,如果发现有.grad不为空的参数就进行更新,然后清空并释放相应的.grad。因为一个参数的钩子函数会在它的.grad还未被赋值时调用,整个求导图的最后一个参数的钩子函数调用时,它的.grad还不可用。因此,我们额外进行一次扫描来更新最后一个参数。 38 | 39 | LOMO的代码在 [lomo](lomo) 文件夹中。 40 | 41 | # AdaLomo: Low-memory Optimization with Adaptive Learning Rate 42 | 43 | 在这个工作中,我们研究了LOMO和Adam优化技术之间的区别,并介绍了AdaLomo,它为每个参数提供自适应的学习率,并在保持内存效率的同时利用了分组更新归一化。 44 | 
AdaLomo在指令微调和继续预训练中实现了与AdamW相当的结果,但占用的显存更少。 45 | 46 | ![AdaLomo](assets/adalomo_algorithm.png) 47 | 48 | AdaLomo的代码在 [adalomo](adalomo) 文件夹中。 49 | 50 | ## 引用 51 | ```text 52 | @article{lv2023full, 53 | title={Full Parameter Fine-tuning for Large Language Models with Limited Resources}, 54 | author={Lv, Kai and Yang, Yuqing and Liu, Tengxiao and Gao, Qinghui and Guo, Qipeng and Qiu, Xipeng}, 55 | journal={arXiv preprint arXiv:2306.09782}, 56 | year={2023} 57 | } 58 | @article{lv2023adalomo, 59 | title={AdaLomo: Low-memory Optimization with Adaptive Learning Rate}, 60 | author={Lv, Kai and Yan, Hang and Guo, Qipeng and Lv, Haijun and Qiu, Xipeng}, 61 | journal={arXiv preprint arXiv:2310.10195}, 62 | year={2023} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /adalomo/README.md: -------------------------------------------------------------------------------- 1 | [**English**](./README.md) | [**中文**](./README_ZH.md) 2 | 3 | # AdaLomo: Low-memory Optimization with Adaptive Learning Rate 4 | 5 | This is the code for [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/pdf/2310.10195.pdf). 6 | 7 | In this work, we examine the distinctions between the LOMO and Adam optimization techniques and introduce AdaLomo, which provides an adaptive learning rate for each parameter and utilizes grouped update normalization while maintaining memory efficiency. 8 | AdaLomo achieves results comparable to AdamW in both instruction-tuning and further pre-training with a smaller memory footprint. 9 | 10 | ![AdaLomo](../assets/adalomo_algorithm.png) 11 | 12 | ## Dependencies 13 | ```shell 14 | collie-lm 15 | ``` 16 | 17 | AdaLomo is implemented at [https://github.com/OpenLMLab/collie/blob/dev/collie/optim/adalomo.py](https://github.com/OpenLMLab/collie/blob/dev/collie/optim/adalomo.py).
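The memory saving behind AdaLomo's adaptive learning rates comes from a factored (Adafactor-style) estimate of the gradient's second moment: for an m×n weight matrix it stores one row statistic and one column statistic, O(m + n) values, instead of the full m×n moment. The pure-Python sketch below is an illustration of that factorization only, not the collie implementation (which additionally keeps exponential moving averages across steps and applies grouped update normalization):

```python
# Illustrative sketch only -- NOT the collie/lomo_optim implementation.
# A factored second-moment estimate for an m x n gradient matrix: a row
# statistic r (length m) and a column statistic c (length n) approximate the
# per-element second moment v_ij, so memory is O(m + n) instead of O(m * n).
import math


def factored_second_moment(grad):
    """Row/column mean-of-squares of a gradient given as a list of rows."""
    m, n = len(grad), len(grad[0])
    r = [sum(g * g for g in row) / n for row in grad]                    # r_i
    c = [sum(grad[i][j] ** 2 for i in range(m)) / m for j in range(n)]  # c_j
    return r, c


def adaptive_update(grad, r, c, eps=1e-30):
    """Update direction g_ij / sqrt(v_ij), with v_ij = r_i * c_j / mean(r)."""
    mean_r = sum(r) / len(r)
    return [
        [grad[i][j] / math.sqrt(r[i] * c[j] / (mean_r + eps) + eps)
         for j in range(len(c))]
        for i in range(len(r))
    ]
```

For a rank-one gradient matrix the factored estimate recovers the elementwise second moment exactly; in general it is a low-rank approximation. The actual optimizer then scales this direction by the learning rate and normalizes update norms per parameter group.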
18 | 19 | ## Instruction-tuning 20 | We use Alpaca-GPT4 as our training dataset, which is available at https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data.json. 21 | 22 | ### Download the dataset 23 | ```shell 24 | cd instruction-tuning 25 | wget https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json 26 | ``` 27 | 28 | ### Training 29 | ```shell 30 | torchrun --nproc_per_node=8 train.py --optim adalomo --model_size 7b 31 | ``` 32 | 33 | ### Evaluation 34 | The evaluation is based on opencompass. Below are the steps for a quick installation. 35 | ```shell 36 | conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y 37 | conda activate opencompass 38 | git clone https://github.com/KaiLv69/opencompass opencompass 39 | cd opencompass 40 | pip install -e . 41 | ``` 42 | Below are the steps for evaluation. 43 | ```shell 44 | python run.py configs/eval_collie.py -r 45 | ``` 46 | `-r` resumes a previous evaluation run. 47 | 48 | You may refer to `opencompass/configs/eval_collie.py` for more details. 49 | 50 | ## Further pre-training 51 | 52 | ### Get dataset 53 | 54 | Download the Python subset of StarCoder and set the path in `get_dataset()` in `further-pretraining/train.py`.
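`get_dataset()` in `further-pretraining/train.py` builds its dataset by reading `.jsonl` shards line by line and keeping each record's `content` field. A minimal sketch of that loading step (the shard path you pass in is your local copy; `load_jsonl_shard` is an illustrative helper, not part of the repo):

```python
# Sketch of the shard-loading step in get_dataset(): each line of a
# StarCoder-style .jsonl shard is a JSON object whose "content" field holds
# one source file's text.
import json


def load_jsonl_shard(path, limit=None):
    """Read up to `limit` records into the {'text': ...} format that
    CollieDatasetForTraining expects."""
    records = []
    with open(path, "r", encoding="utf8") as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            records.append({"text": json.loads(line)["content"]})
    return records
```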
55 | 56 | ### Training 57 | ```shell 58 | torchrun --nproc_per_node=8 train.py --optim adalomo --model_size 7b 59 | ``` 60 | -------------------------------------------------------------------------------- /adalomo/README_ZH.md: -------------------------------------------------------------------------------- 1 | [**English**](./README.md) | [**中文**](./README_ZH.md) 2 | 3 | # AdaLomo: Low-memory Optimization with Adaptive Learning Rate 4 | 5 | 这是 [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://arxiv.org/pdf/2310.10195.pdf) 的代码。 6 | 7 | 在这个工作中,我们研究了 LOMO 和 Adam 优化技术之间的区别,并提出了 AdaLomo,它为每个参数提供自适应学习率,并利用分组更新归一化来保持内存效率。 8 | AdaLomo 在指令微调和进一步预训练中的结果与 AdamW 相当,但内存占用更少。 9 | 10 | ![AdaLomo](../assets/adalomo_algorithm.png) 11 | 12 | ## 依赖 13 | ```shell 14 | collie-lm 15 | ``` 16 | 17 | AdaLomo 在 [https://github.com/OpenLMLab/collie/blob/dev/collie/optim/adalomo.py](https://github.com/OpenLMLab/collie/blob/dev/collie/optim/adalomo.py) 中实现。 18 | 19 | ## 指令微调 20 | 我们使用 Alpaca-GPT4 作为我们的训练数据集,该数据集可在 https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/data/alpaca_gpt4_data.json 获取。 21 | 22 | ### 下载数据集 23 | ```shell 24 | cd instruction-tuning 25 | wget https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json 26 | ``` 27 | 28 | ### 训练 29 | ```shell 30 | torchrun --nproc_per_node=8 train.py --optim adalomo --model_size 7b 31 | ``` 32 | 33 | ### 评估 34 | 评估基于 opencompass。以下是快速安装步骤。 35 | ```shell 36 | conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y 37 | conda activate opencompass 38 | git clone https://github.com/KaiLv69/opencompass opencompass 39 | cd opencompass 40 | pip install -e .
41 | ``` 42 | 以下是评估步骤。 43 | ```shell 44 | python run.py configs/eval_collie.py -r 45 | ``` 46 | `-r` 用于恢复之前的评估过程。 47 | 48 | 您可以参考 `opencompass/configs/eval_collie.py` 了解更多细节。 49 | 50 | ## 继续预训练 51 | 52 | ### 获取数据集 53 | 54 | 下载 StarCoder 的 python 子集,并在 `further-pretraining/train.py` 的 `get_dataset()` 中设置路径。 55 | 56 | ### 训练 57 | ```shell 58 | torchrun --nproc_per_node=8 train.py --optim adalomo --model_size 7b 59 | ``` 60 | -------------------------------------------------------------------------------- /adalomo/further-pretraining/evaluate.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Tuple, Any, Dict 2 | 3 | import torch 4 | from peft import PeftModel 5 | from torch import nn 6 | import torch.distributed as dist 7 | from collie import ColliePadder, GPTLMLoss, auto_param_call, BaseMetric 8 | from collie.module import PipelineModel 9 | from collie.controller.evaluator import Evaluator 10 | 11 | 12 | class EvaluatorForPretraining(Evaluator): 13 | def __init__(self, 14 | loss_fn: Callable = GPTLMLoss(), 15 | collate_fn: Optional[Callable] = ColliePadder(), 16 | *args, 17 | **kwargs): 18 | self.loss_fn = loss_fn 19 | super().__init__(collate_fn=collate_fn, *args, **kwargs) 20 | 21 | @staticmethod 22 | @torch.no_grad() 23 | def eval_fn(evaluator, batch: Dict) -> Any: 24 | """一次验证的基本单元 25 | 26 | :param evaluator: 验证器 27 | :param batch: 一个 batch 的数据,类型为 ``Dict``,格式为: 28 | 29 | ..
code-block:: 30 | { 31 | "input_ids": torch.tensor([[1, 100, 100, 2]]), 32 | "labels": torch.tensor([[1, 100, 100, 2]]), 33 | } 34 | 35 | :return: 一次验证的结果,为 `Dict` 类型,该结果会被传入 `metric` 的 `update` 方法中 36 | """ 37 | # concat prompt labels for p-tuning 38 | if evaluator.config.peft_config and evaluator.config.peft_config.peft_type in ["PROMPT_TUNING", "P_TUNING"]: 39 | batch_size = batch["input_ids"].shape[0] 40 | if "labels" in batch.keys(): 41 | prefix_labels = torch.full((batch_size, evaluator.config.peft_config.num_virtual_tokens), -100).to( 42 | batch["labels"].device) 43 | batch["labels"] = torch.cat((prefix_labels, batch["labels"]), dim=1) 44 | if evaluator.config.pp_size > 1: 45 | if isinstance(evaluator.engine.module, PipelineModel): 46 | evaluator.engine.module.forward_type = "eval" 47 | if isinstance(evaluator.engine.module, PeftModel) and isinstance(evaluator.engine.module.get_base_model(), 48 | PipelineModel): 49 | evaluator.engine.module.get_base_model().forward_type = "eval" 50 | outputs = evaluator.engine.module(**batch) 51 | else: 52 | outputs = evaluator.engine(**batch) 53 | loss = auto_param_call(evaluator.loss_fn, {**batch, **outputs}, 54 | signature_fn=evaluator.loss_fn.forward if isinstance(evaluator.loss_fn, 55 | nn.Module) else evaluator.loss_fn) 56 | ppl = torch.exp(loss) 57 | 58 | # calculate acc 59 | pred = torch.argmax(outputs["logits"], dim=-1)[..., :-1].contiguous() # bs seq_len 60 | shifted_labels = batch['labels'][..., 1:].contiguous().to(pred.device) 61 | valid_mask = (shifted_labels != -100) 62 | correct = (pred == shifted_labels) & valid_mask 63 | correct_count = correct.float().sum().item() 64 | total = valid_mask.float().sum().item() 65 | return { 66 | "ppl": ppl.detach().clone().view(1, ).cuda(), 67 | "correct": correct_count, 68 | "total": total 69 | } 70 | 71 | 72 | class AccMetric(BaseMetric): 73 | def __init__(self, gather_result: bool = False) -> None: 74 | super().__init__(gather_result) 75 | self.total = 0 76 | self.correct 
= 0 77 | 78 | def reset(self): 79 | self.total = 0 80 | self.correct = 0 81 | 82 | def get_metric(self) -> Optional[Dict]: 83 | return {'acc': round(self.correct / (self.total + 1e-12), 6), "total": self.total, "correct": self.correct} 84 | 85 | def update(self, result: Dict): 86 | self.total += result['total'] 87 | self.correct += result["correct"] 88 | 89 | def gather(self, result): 90 | if self.trainer.config.dp_size > 1: 91 | group = self.trainer.engine.mpu.get_data_parallel_group() 92 | for key in result.keys(): 93 | if key in ["total", "correct"]: 94 | gather_list = [None for _ in range(self.trainer.config.dp_size)] 95 | dist.all_gather_object(gather_list, result[key], group=group) 96 | result[key] = sum(gather_list) 97 | return result 98 | -------------------------------------------------------------------------------- /adalomo/further-pretraining/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | from transformers import get_cosine_schedule_with_warmup, LlamaTokenizer 6 | from collie import EvalMonitor, PPLMetric, AccuracyMetric, EvaluatorForPerplexity, Callback 7 | from collie.config import CollieConfig 8 | from collie.controller.trainer import Trainer 9 | from collie.module import GPTLMLoss 10 | from collie.log import logger 11 | from collie.optim import AdaLomo 12 | from collie.models.llama.model import LlamaForCausalLM 13 | from collie.utils.monitor import StepTimeMonitor, TGSMonitor, MemoryMonitor, LossMonitor, LRMonitor 14 | from collie.data import CollieDatasetForTraining 15 | 16 | from evaluate import EvaluatorForPretraining, AccMetric 17 | 18 | 19 | def get_model(): 20 | model_name = f'huggyllama/llama-{args.model_size}' 21 | 22 | config = CollieConfig.from_pretrained(model_name) 23 | 24 | config.tp_size = args.tp 25 | config.dp_size = args.dp 26 | config.pp_size = args.pp 27 | config.train_epochs = args.train_epochs 28 | 
config.train_micro_batch_size = args.micro_batch 29 | config.eval_batch_size = 2 * args.micro_batch 30 | config.eval_per_n_steps = 100 31 | 32 | config.ds_config = { 33 | "fp16": {"enabled": True}, 34 | "monitor_config": { 35 | "enabled": True, 36 | "tag": f"{args.optim}-lr-{args.lr}_epoch-{config.train_epochs}_tp{args.tp}_dp{args.dp}_bs{args.micro_batch}", 37 | "wandb": { 38 | "enabled": True, 39 | "team": "collie_exp", 40 | "project": "adalomo", 41 | "group": f"llama-{args.model_size}-{args.domain}", 42 | } 43 | }, 44 | "zero_optimization": {"stage": 0 if args.dp == 1 else 3}, 45 | "zero_allow_untested_optimizer": True, 46 | } 47 | 48 | model = LlamaForCausalLM.from_pretrained(model_name, config) 49 | model.set_cache(False) 50 | tokenizer = LlamaTokenizer.from_pretrained(model_name, padding_side="left") 51 | tokenizer.pad_token = tokenizer.bos_token 52 | return model, tokenizer, config 53 | 54 | 55 | def get_dataset(tokenizer): 56 | train_dataset = [] 57 | eval_dataset = [] 58 | eval_num = 2000 59 | if args.domain == "python": 60 | for i in range(10): 61 | logger.info(f"Loading train dataset {i}") 62 | data_path = f"/path_to_starcoder/train/code/python/train-000{i:02}-of-00059_train.jsonl" 63 | with open(data_path, "r") as f: 64 | for line in f: 65 | train_dataset.append({'text': json.loads(line)['content']}) 66 | data_path = f"/path_to_starcoder/train/code/python/train-00010-of-00059_train.jsonl" 67 | num = 0 68 | with open(data_path, "r") as f: 69 | for line in f: 70 | num += 1 71 | if num > eval_num: 72 | break 73 | eval_dataset.append({'text': json.loads(line)['content']}) 74 | elif args.domain == "cn": 75 | data_path = f"/path_to_baidu-baike/merged-1.jsonl" 76 | data_num = 2e6 77 | num = 0 78 | with open(data_path, "r") as f: 79 | for line in f: 80 | num += 1 81 | if num > data_num: 82 | if num < data_num + eval_num: 83 | eval_dataset.append({'text': json.loads(line)['content']}) 84 | continue 85 | else: 86 | break 87 | train_dataset.append({'text': 
json.loads(line)['content']}) 88 | else: 89 | raise ValueError(f"domain {args.domain} not supported") 90 | 91 | train_dataset = CollieDatasetForTraining(train_dataset, tokenizer=tokenizer, max_length=2048) 92 | eval_dataset = CollieDatasetForTraining(eval_dataset, tokenizer=tokenizer, max_length=2048) 93 | return train_dataset, eval_dataset 94 | 95 | 96 | def train(): 97 | model, tokenizer, config = get_model() 98 | train_dataset, eval_dataset = get_dataset(tokenizer) 99 | if args.optim == "adalomo": 100 | optimizer = AdaLomo( 101 | model, 102 | lr=args.lr, 103 | loss_scale=2 ** 10, 104 | ) 105 | elif args.optim == "adamw": 106 | optimizer = torch.optim.AdamW( 107 | model.parameters(), 108 | betas=(0.9, 0.95), 109 | lr=args.lr, 110 | ) 111 | else: 112 | raise ValueError(f"optim {args.optim} not support") 113 | 114 | total_step = (len(train_dataset) * config.train_epochs) // (args.micro_batch * args.dp) 115 | lr_scheduler = get_cosine_schedule_with_warmup( 116 | optimizer, 117 | num_warmup_steps=int(total_step * 0.03), 118 | num_training_steps=total_step 119 | ) 120 | 121 | evaluator = EvaluatorForPretraining( 122 | model=model, 123 | config=config, 124 | dataset=eval_dataset, 125 | monitors=[ 126 | EvalMonitor(config) 127 | ], 128 | metrics={ 129 | 'ppl': PPLMetric(gather_result=True), 130 | 'acc': AccMetric(gather_result=True), 131 | } 132 | ) 133 | 134 | monitors = [ 135 | StepTimeMonitor(config), 136 | TGSMonitor(config), 137 | MemoryMonitor(config), 138 | LossMonitor(config), 139 | LRMonitor(config) 140 | ] 141 | 142 | trainer = Trainer( 143 | model=model, 144 | config=config, 145 | loss_fn=GPTLMLoss(-100), 146 | optimizer=optimizer, 147 | train_dataset=train_dataset, 148 | monitors=monitors, 149 | lr_scheduler=lr_scheduler, 150 | evaluators=[evaluator], 151 | ) 152 | 153 | trainer.train() 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--pp", default=1, type=int) 159 | parser.add_argument("--tp", 
default=1, type=int) 160 | parser.add_argument("--dp", default=8, type=int) 161 | parser.add_argument("--model_size", default="7b", type=str) 162 | parser.add_argument("--micro_batch", default=16, type=int) 163 | parser.add_argument("--train_epochs", default=1, type=int) 164 | parser.add_argument("--lr", default=1e-3, type=float) 165 | 166 | parser.add_argument("--optim", default="adalomo", type=str, choices=["adalomo", "adamw"]) 167 | parser.add_argument("--domain", default="python", type=str) 168 | 169 | args = parser.parse_args() 170 | train() 171 | -------------------------------------------------------------------------------- /adalomo/instruction-tuning/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | from peft import LoraConfig, TaskType 6 | from transformers import get_linear_schedule_with_warmup, LlamaTokenizer 7 | from collie.config import CollieConfig 8 | from collie.controller.trainer import Trainer 9 | from collie.module import GPTLMLoss 10 | from collie.log import logger 11 | from collie.optim import AdaLomo 12 | from collie.models.llama.model import LlamaForCausalLM 13 | from collie.utils.monitor import StepTimeMonitor, TGSMonitor, MemoryMonitor, LossMonitor, LRMonitor 14 | from collie.data import CollieDatasetForTraining 15 | 16 | 17 | def get_model(): 18 | model_name = f'huggyllama/llama-{args.model_size}' 19 | config = CollieConfig.from_pretrained(model_name) 20 | config.tp_size = args.tp 21 | config.dp_size = args.dp 22 | config.pp_size = 1 23 | config.train_epochs = args.train_epochs 24 | config.train_micro_batch_size = args.micro_batch 25 | config.gradient_accumulation_steps = 1 26 | config.ds_config = { 27 | "fp16": {"enabled": True}, 28 | "monitor_config": { 29 | "enabled": True, 30 | "tag": f"{args.optim}-lr-{args.lr}_epoch-{config.train_epochs}_tp{args.tp}_dp{args.dp}_bs{args.micro_batch}", 31 | "wandb": { 32 | "enabled": True, 33 | 
"team": "collie_exp", 34 | "project": "adalomo", 35 | "group": f"llama-{args.model_size}-alpaca", 36 | } 37 | }, 38 | "zero_optimization": {"stage": 0 if args.dp == 1 else 3}, 39 | } 40 | if args.optim == "lora": 41 | config.peft_config = LoraConfig( 42 | r=8, 43 | lora_alpha=16, 44 | lora_dropout=0.05, 45 | # target_modules=["q_proj", "v_proj"], 46 | target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], 47 | bias="none", 48 | task_type=TaskType.CAUSAL_LM 49 | ) 50 | 51 | model = LlamaForCausalLM.from_pretrained(model_name, config=config) 52 | tokenizer = LlamaTokenizer.from_pretrained(model_name, padding_side='left') 53 | tokenizer.pad_token = tokenizer.bos_token 54 | return model, tokenizer, config 55 | 56 | 57 | def get_dataset(tokenizer): 58 | template_input = ( 59 | "Below is an instruction that describes a task, paired with an input that provides further context. " 60 | "Write a response that appropriately completes the request.\n\n" 61 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 62 | ) 63 | template_no_input = ( 64 | "Below is an instruction that describes a task. 
" 65 | "Write a response that appropriately completes the request.\n\n" 66 | "### Instruction:\n{instruction}\n\n### Response:" 67 | ) 68 | 69 | with open("alpaca_gpt4_data.json", "r", encoding="utf8") as fp: 70 | json_dataset = json.load(fp) 71 | train_dataset = [ 72 | { 73 | "input": template_input.format_map(example) if example.get("input", "") != "" 74 | else template_no_input.format_map(example), 75 | "output": example['output'] + tokenizer.eos_token 76 | } for example in json_dataset 77 | ] 78 | train_dataset = CollieDatasetForTraining(train_dataset, tokenizer=tokenizer, max_length=2048) 79 | logger.info(f"Train dataset len: {len(train_dataset)}\nTrain dataset[0]: {train_dataset[0]}") 80 | return train_dataset 81 | 82 | 83 | def train(): 84 | model, tokenizer, config = get_model() 85 | train_dataset = get_dataset(tokenizer) 86 | if args.optim == "adalomo": 87 | optimizer = AdaLomo( 88 | model, 89 | lr=args.lr, 90 | loss_scale=2 ** 10, 91 | ) 92 | elif args.optim == "adamw": 93 | optimizer = torch.optim.AdamW( 94 | model.parameters(), 95 | betas=(0.9, 0.95), 96 | lr=args.lr, 97 | ) 98 | elif args.optim == "lora": 99 | optimizer = torch.optim.AdamW( 100 | filter(lambda p: p.requires_grad, model.parameters()), 101 | betas=(0.9, 0.95), 102 | lr=args.lr 103 | ) 104 | else: 105 | raise ValueError(f"optim {args.optim} not support") 106 | 107 | total_step = (len(train_dataset) * config.train_epochs) // (args.micro_batch * args.dp) 108 | lr_scheduler = get_linear_schedule_with_warmup( 109 | optimizer, 110 | num_warmup_steps=int(total_step * 0.03), 111 | num_training_steps=total_step 112 | ) 113 | 114 | monitors = [ 115 | StepTimeMonitor(config), 116 | TGSMonitor(config), 117 | MemoryMonitor(config), 118 | LossMonitor(config), 119 | LRMonitor(config) 120 | ] 121 | 122 | trainer = Trainer( 123 | model=model, 124 | config=config, 125 | loss_fn=GPTLMLoss(-100), 126 | optimizer=optimizer, 127 | train_dataset=train_dataset, 128 | monitors=monitors, 129 | 
lr_scheduler=lr_scheduler, 130 | ) 131 | trainer.train() 132 | logger.info("Save Model") 133 | save_path = f"./llama-{args.model_size}/{args.optim}_lr-{args.lr}_epoch-{args.train_epochs}_tp{args.tp}_dp{args.dp}_bs{args.micro_batch}" 134 | trainer.save_model(save_path) 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument("--pp", default=1, type=int) 140 | parser.add_argument("--tp", default=1, type=int) 141 | parser.add_argument("--dp", default=8, type=int) 142 | parser.add_argument("--model_size", default="7b", type=str) 143 | parser.add_argument("--micro_batch", default=16, type=int) 144 | parser.add_argument("--train_epochs", default=3, type=int) 145 | parser.add_argument("--lr", default=5e-4, type=float) 146 | parser.add_argument("--optim", default="adalomo", type=str, choices=["adalomo", "adamw", "lora"]) 147 | args = parser.parse_args() 148 | train() 149 | -------------------------------------------------------------------------------- /assets/LOMO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenLMLab/LOMO/45d4bac16c642f101ed00db7244cacf28e5dde15/assets/LOMO.png -------------------------------------------------------------------------------- /assets/adalomo_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenLMLab/LOMO/45d4bac16c642f101ed00db7244cacf28e5dde15/assets/adalomo_algorithm.png -------------------------------------------------------------------------------- /assets/hook_func.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenLMLab/LOMO/45d4bac16c642f101ed00db7244cacf28e5dde15/assets/hook_func.png -------------------------------------------------------------------------------- /lomo/README.md: 
-------------------------------------------------------------------------------- 1 | [**English**](./README.md) | [**中文**](./README_ZH.md) 2 | 3 | # LOMO: LOw-Memory Optimization 4 | 5 | This is the implementation for [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://arxiv.org/pdf/2306.09782.pdf). 6 | 7 | In this work, we propose a new optimizer, **LO**w-Memory **O**ptimization (**LOMO**), which fuses the gradient computation and the parameter update in one step to reduce memory usage. 8 | Our approach enables the full parameter fine-tuning of a 7B model on a single RTX 3090, or 9 | a 65B model on a single machine with 8×RTX 3090, each with 24GB memory. 10 | 11 | ![LOMO](../assets/LOMO.png) 12 | 13 | ## Dependencies 14 | ```shell 15 | torch 16 | deepspeed 17 | transformers 18 | peft 19 | wandb 20 | ``` 21 | Only PyTorch is strictly required; the other packages are used to reproduce our paper results. 22 | 23 | ## Run the code 24 | 25 | We provide code for fine-tuning Large Language Models (LLMs) using three different approaches: **LOMO**, **LoRA**, and **LoRA + LOMO**. 26 | 27 | 1. For full parameter fine-tuning using LOMO, the implementation is in `src/lomo_trainer.py`, and you can run: 28 | ```shell 29 | deepspeed --master_port "$port" --include localhost:"$CUDA_VISIBLE_DEVICES" src/train_lomo.py config/args_lomo.yaml 30 | ``` 31 | 32 | 2. For LoRA and LoRA + LOMO, the implementation is in `src/lomo_lora_trainer.py`, and you can run: 33 | ```shell 34 | deepspeed --master_port "$port" --include localhost:"$CUDA_VISIBLE_DEVICES" src/train_lomo_lora.py config/args_lomo_lora.yaml 35 | ``` 36 | In the code, we have included the `lora_only` argument in `src/arguments.py`, which controls whether to use LoRA only or LoRA + LOMO. Please note that when `lora_only` is set to `True`, the arguments related to LOMO have no effect. 37 | 38 | In addition, we provide a simple `run.sh` script for convenience.
You can execute the code using the following command: 39 | ```shell 40 | bash run.sh 41 | ``` 42 | 43 | For data processing, we currently only provide the six datasets of SuperGLUE mentioned in the paper. If you wish to use new datasets, please modify the `Dataset` and `DataCollator` accordingly. 44 | 45 | For evaluation, we currently only provide the `eval_step` code for [multiple-choice QA](https://github.com/OpenLMLab/LOMO/blob/91cc71387d0a576c000a7dc568543c4ef22401db/src/lomo_trainer.py#L259-L276) and [generation](https://github.com/OpenLMLab/LOMO/blob/91cc71387d0a576c000a7dc568543c4ef22401db/src/lomo_trainer.py#L278-L297) tasks. If you have other requirements, please modify the `eval_step` code in the `LOMOTrainer` or `LOMOLoRATrainer` accordingly and provide the necessary `compute_metrics` to the trainer. 46 | 47 | ## Reproduce our results 48 | We provide the sampled datasets used in our experiments [here](https://drive.google.com/drive/folders/1zV7sXvU7YHKWyS3fYV0yyi7FyTjIpEuO?usp=sharing). 49 | Due to the limited computational resources, we reported the highest results obtained from experiments conducted with the same random seed (`42`). 50 | We acknowledge this limitation in our work and plan to conduct repeated experiments in the next version to address it. 51 | 52 | > Feel free to raise issues if you have any questions. 53 | 54 | ## Implementation 55 | ![Hook function](../assets/hook_func.png) 56 | Our implementation relies on injecting hook functions into PyTorch's backward pass. As depicted in the figure, we register a customized hook function for each parameter. When the gradient of a parameter is computed (but before it is written to the .grad attribute), its corresponding hook function is invoked. For more information about hook functions and the backward pass of the autograd graph, please refer to [PyTorch's documentation](https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution).
In summary, during the backward pass, the autograd engine moves from a tensor through its grad_fn, writes the gradient into the .grad attribute, and then passes on to the next tensor. 57 | 58 | Our customized hook function scans all the parameters, updates each parameter whose .grad attribute is not empty, and then clears and frees that .grad attribute. Since the hook function for a parameter is called before its .grad attribute is set, the .grad attribute of the last parameter in the autograd graph is not ready when the last hook function is invoked. Therefore, we perform an additional scan to update the last parameter. 59 | 60 | ## Citation 61 | ```text 62 | @article{lv2023full, 63 | title={Full Parameter Fine-tuning for Large Language Models with Limited Resources}, 64 | author={Lv, Kai and Yang, Yuqing and Liu, Tengxiao and Gao, Qinghui and Guo, Qipeng and Qiu, Xipeng}, 65 | journal={arXiv preprint arXiv:2306.09782}, 66 | year={2023} 67 | } 68 | ``` -------------------------------------------------------------------------------- /lomo/README_ZH.md: -------------------------------------------------------------------------------- 1 | [**English**](./README.md) | [**中文**](./README_ZH.md) 2 | 3 | # LOMO: LOw-Memory Optimization 4 | 5 | 论文 [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://arxiv.org/pdf/2306.09782.pdf) 的实现. 6 | 7 | 在这个工作中,我们提出了一个新的优化器,**LO**w-Memory **O**ptimization (**LOMO**),它将梯度计算和参数更新融合在一步中,以减少内存使用。 8 | 我们的方法使得在单张 RTX 3090 上可以进行 7B 模型的全参数微调,或者在单个 8×RTX 3090 的机器上可以进行 65B 模型的全参数微调(RTX 3090 的内存为 24GB)。 9 | 10 | ![LOMO](../assets/LOMO.png) 11 | 12 | ## 依赖 13 | ```shell 14 | torch 15 | deepspeed 16 | transformers 17 | peft 18 | wandb 19 | ``` 20 | LOMO本身只依赖 PyTorch,其他依赖用于复现我们的论文结果。 21 | 22 | ## 执行代码 23 | 24 | 我们提供了三种不同方法来微调大型语言模型(LLM)的代码:**LOMO**,**LoRA**和**LoRA + LOMO**。 25 | 26 | 1.
对于使用LOMO进行完全参数微调的实现位于`src/lomo_trainer.py`,您可以运行以下命令: 27 | ```shell 28 | deepspeed --master_port "$port" --include localhost:"$CUDA_VISIBLE_DEVICES" src/train_lomo.py config/args_lomo.yaml 29 | ``` 30 | 31 | 2. 对于LoRA和LoRA + LOMO的实现位于`src/lomo_lora_trainer.py`,您可以运行以下命令: 32 | ```shell 33 | deepspeed --master_port "$port" --include localhost:"$CUDA_VISIBLE_DEVICES" src/train_lomo_lora.py config/args_lomo_lora.yaml 34 | ``` 35 | 在代码中,我们在`src/arguments.py`中包含了`lora_only`参数,用于控制是否仅使用LoRA或LoRA + LOMO。请注意,当将`lora_only`设置为`True`时,与LOMO相关的参数将不起作用。 36 | 37 | 此外,我们还提供了一个简单的`run.sh`脚本以方便使用。您可以使用以下命令执行代码: 38 | ```shell 39 | bash run.sh 40 | ``` 41 | 42 | 对于数据处理,我们目前只提供了论文中提到的SuperGLUE的六个数据集。如果您希望使用新的数据集,请相应地修改`Dataset`和`DataCollator`。 43 | 44 | 对于评估,我们目前仅为[多项选择问答(multiple-choice QA)](https://github.com/OpenLMLab/LOMO/blob/91cc71387d0a576c000a7dc568543c4ef22401db/src/lomo_trainer.py#L259-L276)和[生成(generation)](https://github.com/OpenLMLab/LOMO/blob/91cc71387d0a576c000a7dc568543c4ef22401db/src/lomo_trainer.py#L278-L297)任务提供了`eval_step`代码。如果您有其他需求,请相应地修改`LOMOTrainer`或`LOMOLoRATrainer`中的`eval_step`代码,并为训练器提供必要的`compute_metrics`函数。 45 | 46 | ## 复现我们的结果 47 | 我们在[这里](https://drive.google.com/drive/folders/1zV7sXvU7YHKWyS3fYV0yyi7FyTjIpEuO?usp=sharing)提供了我们实验中使用的采样数据集。 48 | 由于计算资源有限,我们报告了使用相同随机种子(`42`)进行的实验中获得的最高结果。 49 | 我们在我们的工作中承认了这个限制,并计划在下一个版本中进行重复实验来解决这个问题。 50 | 51 | > 如果您有任何问题,请随时提出问题。 52 | 53 | ## 实现 54 | ![Hook function](../assets/hook_func.png) 55 | 我们通过在PyTorch的反向传播过程中注入钩子函数实现我们的方法。如图中所示,我们为模型的每一个参数都注册了自定义的钩子函数。当一个参数的梯度计算完毕之后(但还没有写入到.grad),它对应的钩子函数就被调用了。更多关于钩子函数和反向传播的介绍可以参考[PyTorch的官方文档](https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution)。简而言之,反向过程会从一个张量到它的梯度函数,然后把梯度写入.grad,再传递到下一个张量。 56 | 57 | 我们的自定义钩子函数会扫描所有的参数,如果发现有.grad不为空的参数就进行更新,然后清空并释放相应的.grad。因为一个参数的钩子函数会在它的.grad还未被赋值时调用,整个求导图的最后一个参数的钩子函数调用时,它的.grad还不可用。因此,我们额外进行一次扫描来更新最后一个参数。 58 | 59 | ## 引用 60 | ```text 61 | @article{lv2023full, 62 | title={Full Parameter Fine-tuning for Large Language 
Models with Limited Resources}, 63 | author={Lv, Kai and Yang, Yuqing and Liu, Tengxiao and Gao, Qinghui and Guo, Qipeng and Qiu, Xipeng}, 64 | journal={arXiv preprint arXiv:2306.09782}, 65 | year={2023} 66 | } 67 | ``` -------------------------------------------------------------------------------- /lomo/config/args_lomo.yaml: -------------------------------------------------------------------------------- 1 | # model 2 | model_name_or_path: '/remote-home/share/llama_hf/7B' 3 | # data 4 | dataset_name: 'wic' 5 | refresh: false 6 | data_tag: 'base' 7 | train_on_inputs: false 8 | data_max_length: 1024 9 | # training 10 | # trainer 11 | tag: 'lomo' 12 | output_dir: 'outputs' 13 | overwrite_output_dir: true 14 | deepspeed: 'config/ds_config.json' 15 | do_train: true 16 | do_eval: true 17 | evaluation_strategy: 'epoch' 18 | per_device_train_batch_size: 16 19 | per_device_eval_batch_size: 16 20 | learning_rate: 0.03 21 | weight_decay: 0 22 | num_train_epochs: 10 23 | lr_scheduler_type: 'linear' 24 | warmup: 0.1 25 | clip_grad_norm: 1.0 26 | # please set `resume_from_checkpoint` to load checkpoints. 27 | #resume_from_checkpoint: 'outputs/wic_7B_lomo/output_adamw_hf_lr0.03_bs16_warmup0.1_clipnorm1.0/checkpoint-0' 28 | # please set `save_strategy` (`no`, `epoch`, `steps`) and `save_total_limit` (the max amount of checkpoints) to save checkpoints. 
29 | save_strategy: 'no' 30 | save_total_limit: 0 31 | seed: 42 32 | #bf16: true 33 | remove_unused_columns: false 34 | load_best_model_at_end: false 35 | metric_for_best_model: 'acc' 36 | group_by_length: false 37 | #report_to: 'wandb' 38 | dataloader_pin_memory: false 39 | gradient_checkpointing: true 40 | predict_with_generate: false -------------------------------------------------------------------------------- /lomo/config/args_lomo_lora.yaml: -------------------------------------------------------------------------------- 1 | # model 2 | model_name_or_path: '/remote-home/share/llama_hf/7B' 3 | # data 4 | dataset_name: 'wic' 5 | refresh: false 6 | data_tag: 'base' 7 | train_on_inputs: false 8 | data_max_length: 1024 9 | # training 10 | # trainer 11 | peft_type: 'lora' 12 | lora_only: false 13 | hf_learning_rate: 0.0005 14 | hf_weight_decay: 0 15 | hf_lr_scheduler_type: 'linear' 16 | hf_warmup: 0.05 17 | tag: 'lora-qv-r2-lomo' 18 | output_dir: 'outputs' 19 | overwrite_output_dir: true 20 | deepspeed: 'config/ds_config_lora.json' 21 | do_train: true 22 | do_eval: true 23 | evaluation_strategy: 'epoch' 24 | per_device_train_batch_size: 16 25 | per_device_eval_batch_size: 16 26 | learning_rate: 0.005 27 | weight_decay: 0 28 | num_train_epochs: 10 29 | lr_scheduler_type: 'linear' 30 | warmup: 0.05 31 | clip_grad_norm: 1.0 32 | #clip_grad_value: 1.0 33 | #clip_loss_value: 5.0 34 | log_level: 'info' 35 | logging_steps: 1 36 | # please set `resume_from_checkpoint` to load checkpoints. check `merge_llama_with_lora.py` first. 37 | #resume_from_checkpoint: 'outputs/wic_7B_lora-qv-r2-lomo/output_lr0.005_bs16_warmup0.05_clipnorm1.0/checkpoint-0/merge_weights' 38 | # please set `save_strategy` (`no`, `epoch`, `steps`) and `save_total_limit` (the max amount of checkpoints) to save checkpoints. 
39 | save_strategy: 'no' 40 | save_total_limit: 0 41 | seed: 42 42 | #bf16: true 43 | remove_unused_columns: false 44 | load_best_model_at_end: false 45 | metric_for_best_model: 'acc' 46 | optim: 'sgd' 47 | group_by_length: false 48 | #report_to: 'wandb' 49 | dataloader_pin_memory: false 50 | gradient_checkpointing: true 51 | predict_with_generate: false 52 | lora_r: 2 -------------------------------------------------------------------------------- /lomo/config/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "bf16": { 4 | "enabled": false 5 | }, 6 | "fp16": { 7 | "enabled": true 8 | }, 9 | "zero_allow_untested_optimizer": true, 10 | "zero_force_ds_cpu_optimizer": false, 11 | 12 | "zero_optimization": { 13 | "stage": 3, 14 | "overlap_comm": true, 15 | "contiguous_gradients": true, 16 | "sub_group_size": 1e8, 17 | "stage3_max_live_parameters": 1e8, 18 | "stage3_max_reuse_distance": 1e8, 19 | "stage3_gather_16bit_weights_on_model_save": true 20 | }, 21 | 22 | 23 | "gradient_accumulation_steps": 1, 24 | "steps_per_print": 2000, 25 | "train_micro_batch_size_per_gpu": 2, 26 | "wall_clock_breakdown": false 27 | } -------------------------------------------------------------------------------- /lomo/config/ds_config_lora.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "bf16": { 4 | "enabled": true 5 | }, 6 | "fp16": { 7 | "enabled": false 8 | }, 9 | "zero_allow_untested_optimizer": true, 10 | "zero_force_ds_cpu_optimizer": false, 11 | 12 | "zero_optimization": { 13 | "stage": 3, 14 | "overlap_comm": true, 15 | "contiguous_gradients": true, 16 | "sub_group_size": 1e8, 17 | "stage3_max_live_parameters": 1e8, 18 | "stage3_max_reuse_distance": 1e8, 19 | "stage3_gather_16bit_weights_on_model_save": true 20 | }, 21 | 22 | 23 | "gradient_accumulation_steps": 1, 24 | "steps_per_print": 2000, 25 | "train_micro_batch_size_per_gpu": 2, 26 | "wall_clock_breakdown": false 
27 | } -------------------------------------------------------------------------------- /lomo/log/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'logger', 3 | "print" 4 | ] 5 | 6 | from .logger import logger 7 | from .print import print 8 | 9 | -------------------------------------------------------------------------------- /lomo/log/handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from logging import getLevelName 4 | 5 | try: 6 | from tqdm.auto import tqdm 7 | except ImportError: 8 | tqdm = None 9 | 10 | __all__ = [] 11 | 12 | if tqdm is not None: 13 | class TqdmLoggingHandler(logging.Handler): 14 | def __init__(self, level=logging.INFO): 15 | super().__init__(level) 16 | 17 | def emit(self, record): 18 | try: 19 | msg = self.format(record) 20 | tqdm.write(msg) 21 | self.flush() 22 | except (KeyboardInterrupt, SystemExit): 23 | raise 24 | except: 25 | self.handleError(record) 26 | else: 27 | class TqdmLoggingHandler(logging.StreamHandler): 28 | def __init__(self, level=logging.INFO): 29 | super().__init__(sys.stdout) 30 | self.setLevel(level) 31 | 32 | 33 | class StdoutStreamHandler(logging.StreamHandler): 34 | """ 35 | 重载 StreamHandler 使得替换 sys.stdout 的时候能够生效。 36 | 37 | """ 38 | def __init__(self): 39 | super(StdoutStreamHandler, self).__init__() 40 | 41 | def flush(self): 42 | """ 43 | Flushes the stream. 44 | """ 45 | self.acquire() 46 | try: 47 | sys.stdout.flush() 48 | finally: 49 | self.release() 50 | 51 | def emit(self, record): 52 | """ 53 | Emit a record. 54 | 55 | If a formatter is specified, it is used to format the record. 56 | The record is then written to the stream with a trailing newline. If 57 | exception information is present, it is formatted using 58 | traceback.print_exception and appended to the stream. 
If the stream 59 | has an 'encoding' attribute, it is used to determine how to do the 60 | output to the stream. 61 | """ 62 | try: 63 | msg = self.format(record) 64 | stream = sys.stdout 65 | # issue 35046: merged two stream.writes into one. 66 | stream.write(msg + self.terminator) 67 | self.flush() 68 | except RecursionError: # See issue 36272 69 | raise 70 | except Exception: 71 | self.handleError(record) 72 | 73 | def setStream(self, stream): 74 | """ 75 | Sets the StreamHandler's stream to the specified value, 76 | if it is different. 77 | 78 | Returns the old stream, if the stream was changed, or None 79 | if it wasn't. 80 | """ 81 | raise RuntimeError("Cannot set the stream of StdoutStreamHandler.") 82 | 83 | def __repr__(self): 84 | level = getLevelName(self.level) 85 | name = getattr(sys.stdout, 'name', '') 86 | # bpo-36015: name can be an int 87 | name = str(name) 88 | if name: 89 | name += ' ' 90 | return '<%s %s(%s)>' % (self.__class__.__name__, name, level) 91 | -------------------------------------------------------------------------------- /lomo/log/highlighter.py: -------------------------------------------------------------------------------- 1 | from rich.highlighter import Highlighter 2 | 3 | __all__ = [] 4 | 5 | class ColorHighlighter(Highlighter): 6 | def __init__(self, color='black'): 7 | self.color = color 8 | 9 | def highlight(self, text): 10 | text.stylize(self.color) -------------------------------------------------------------------------------- /lomo/log/logger.py: -------------------------------------------------------------------------------- 1 | r""" 2 | :class:`Logger` 是记录日志的模块,**logger** 封装了 logging 模块的 Logger, 3 | 具体使用方式与直接使用 :class:`logging.Logger` 相同,同时也新增一些简单好用的API 4 | 5 | 使用方式:: 6 | 7 | # logger 可以和 logging.Logger 一样使用 8 | logger.info('your msg') 9 | logger.error('your msg') 10 | 11 | # logger 新增的API 12 | # 将日志输出到文件,以及输出的日志等级 13 | logger.add_file('/path/to/log', level='INFO') 14 | # 定义在命令行中的显示格式和日志等级 15 | logger.set_stdout('tqdm',
level='WARN') 16 | # 仅警告一次 17 | logger.warning_once('your msg') 18 | # 分布式训练下,仅在 rank 0 输出警告 19 | logger.rank_zero_warning('your msg') 20 | 21 | """ 22 | 23 | 24 | import logging 25 | import logging.config 26 | from logging import DEBUG, ERROR, INFO, WARNING, CRITICAL, raiseExceptions 27 | import os 28 | import sys 29 | import warnings 30 | from pathlib import Path 31 | from typing import Optional, Union 32 | from rich.logging import RichHandler 33 | import datetime 34 | import torch 35 | 36 | __all__ = [ 37 | 'logger' 38 | ] 39 | 40 | from .handler import StdoutStreamHandler, TqdmLoggingHandler 41 | 42 | 43 | ROOT_NAME = 'LOMO' 44 | 45 | 46 | class LoggerSingleton(type): 47 | _instances = {} 48 | 49 | def __call__(cls, *args, **kwargs): 50 | if cls not in cls._instances: 51 | cls._instances[cls] = super(LoggerSingleton, cls).__call__(*args, **kwargs) 52 | return cls._instances[cls] 53 | 54 | 55 | class LOMOLogger(logging.Logger, metaclass=LoggerSingleton): 56 | def __init__(self, name): 57 | super().__init__(name) 58 | self._warning_msgs = set() 59 | 60 | def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, 61 | mode: str = "w"): 62 | """ 63 | 将日志输出到 path 中。 64 | 65 | :param path: 若 path 为文件路径(通过 path 是否包含后缀判定 path 是否表示文件名,例如 output.log 会被认为是文件,而 66 | output 则认为是文件夹)则直接写入到给定文件中;如果判定为文件夹,则是在该文件夹下以 时间戳 创建一个日志文件。 67 | :param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量"LOMO_LOG_LEVEL'进行 68 | 设置。 69 | :param remove_other_handlers: 是否移除其它 handler ,如果移除,则terminal中将不会有 log 输出。 70 | :param mode: 可选为['w', 'a'],如果传入的 path 是存在的文件,'w' 会覆盖原有内容 'a' 则会在文件结尾处继续添加。 71 | :return: 72 | """ 73 | r"""添加日志输出文件和输出级别""" 74 | if level == 'AUTO': 75 | level = parse_level() 76 | return _add_file_handler(self, path, level, remove_other_handlers, mode) 77 | 78 | def set_stdout(self, stdout: str = 'raw', level: str = 'AUTO'): 79 | """ 80 | 设置 log 的 terminal 输出形式。 81 | 82 | :param stdout: 可选['rich', 'naive', 
'raw', 'none']。 83 | :param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量"LOMO_LOG_LEVEL'进行 84 | 设置。 85 | :return: 86 | """ 87 | r"""设置标准输出格式和输出级别""" 88 | if level == 'AUTO': 89 | level = parse_level() 90 | return _set_stdout_handler(self, stdout, level) 91 | 92 | def debug(self, msg, *args, **kwargs): 93 | """ 94 | Delegate a debug call to the underlying log. 95 | """ 96 | if self.isEnabledFor(DEBUG): 97 | kwargs = self._add_rank_info(kwargs) 98 | self._log(DEBUG, msg, args, **kwargs) 99 | 100 | def info(self, msg, *args, **kwargs): 101 | """ 102 | Delegate an info call to the underlying log. 103 | """ 104 | if self.isEnabledFor(INFO): 105 | kwargs = self._add_rank_info(kwargs) 106 | self._log(INFO, msg, args, **kwargs) 107 | 108 | def warning(self, msg, *args, **kwargs): 109 | """ 110 | Delegate a warning call to the underlying log. 111 | """ 112 | if self.isEnabledFor(WARNING): 113 | kwargs = self._add_rank_info(kwargs) 114 | self._log(WARNING, msg, args, **kwargs) 115 | 116 | def warning_once(self, msg, *args, **kwargs): 117 | """ 118 | 相同的 warning 内容只会 warning 一次 119 | 120 | :param msg: 121 | :param args: 122 | :param kwargs: 123 | :return: 124 | """ 125 | if msg not in self._warning_msgs: 126 | if self.isEnabledFor(WARNING): 127 | kwargs = self._add_rank_info(kwargs) 128 | self._log(WARNING, msg, args, **kwargs) 129 | self._warning_msgs.add(msg) 130 | 131 | def rank_zero_warning(self, msg, *args, once=False, **kwargs): 132 | """ 133 | 只在 rank 0 上 warning 。 134 | 135 | :param msg: 136 | :param args: 137 | :param once: 是否只 warning 一次 138 | :param kwargs: 139 | :return: 140 | """ 141 | if os.environ.get('LOCAL_RANK', 0) == 0: 142 | if once: 143 | if msg in self._warning_msgs: 144 | return 145 | self._warning_msgs.add(msg) 146 | 147 | if self.isEnabledFor(WARNING): 148 | kwargs = self._add_rank_info(kwargs) 149 | self._log(WARNING, msg, args, **kwargs) 150 | 151 | def warn(self, msg, *args, **kwargs): 152 | if self.isEnabledFor(WARNING): 
153 | kwargs = self._add_rank_info(kwargs) 154 | self._log(WARNING, msg, args, **kwargs) 155 | 156 | def error(self, msg, *args, **kwargs): 157 | """ 158 | Delegate an error call to the underlying log. 159 | """ 160 | if self.isEnabledFor(ERROR): 161 | kwargs = self._add_rank_info(kwargs) 162 | self._log(ERROR, msg, args, **kwargs) 163 | 164 | def exception(self, msg, *args, exc_info=True, **kwargs): 165 | """ 166 | Delegate an exception call to the underlying log. 167 | """ 168 | kwargs = self._add_rank_info(kwargs) 169 | self.error(msg, *args, exc_info=exc_info, **kwargs) 170 | 171 | def critical(self, msg, *args, **kwargs): 172 | """ 173 | Delegate a critical call to the underlying log. 174 | """ 175 | if self.isEnabledFor(CRITICAL): 176 | kwargs = self._add_rank_info(kwargs) 177 | self._log(CRITICAL, msg, args, **kwargs) 178 | 179 | def log(self, level, msg, *args, **kwargs): 180 | """ 181 | Delegate a log call to the underlying log, after adding 182 | contextual information from this adapter instance. 
183 | """ 184 | if not isinstance(level, int): 185 | if raiseExceptions: 186 | raise TypeError("level must be an integer") 187 | else: 188 | return 189 | if self.isEnabledFor(level): 190 | kwargs = self._add_rank_info(kwargs) 191 | self._log(level, msg, args, **kwargs) 192 | 193 | def _add_rank_info(self, kwargs): 194 | if torch.distributed.is_initialized(): 195 | extra = kwargs.get('extra', {}) 196 | extra.update({"rank": int(os.environ.get('LOCAL_RANK', 0))}) 197 | kwargs["extra"] = extra 198 | return kwargs 199 | 200 | def setLevel(self, level) -> None: 201 | """ 202 | 设置当前 logger 以及其 handler 的 log 级别 203 | 204 | :param level: 205 | :return: 206 | """ 207 | if isinstance(level, str): 208 | level = level.upper() 209 | super().setLevel(level) 210 | for handler in self.handlers: 211 | handler.setLevel(level) 212 | 213 | def _set_distributed(self): 214 | """ 215 | 在 LOMO 拉起进程的时候,调用一下这个方法,使得能够输出 rank 信息 216 | 217 | :return: 218 | """ 219 | for handler in self.handlers: 220 | if isinstance(handler, logging.FileHandler): 221 | formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s', 222 | datefmt='%Y/%m/%d %H:%M:%S') 223 | else: 224 | formatter = logging.Formatter('Rank: %(rank)s - %(message)s') 225 | handler.setFormatter(formatter) 226 | 227 | 228 | def _get_level(level): 229 | if not isinstance(level, int): 230 | level = level.lower() 231 | level = {'info': logging.INFO, 'debug': logging.DEBUG, 232 | 'warn': logging.WARN, 'warning': logging.WARNING, 233 | 'error': logging.ERROR}[level] 234 | return level 235 | 236 | 237 | def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] = None, level: str = 'INFO', 238 | remove_other_handlers: bool = False, mode: str = "w"): 239 | if path is None: 240 | path = Path.cwd() 241 | if isinstance(path, str): 242 | path = Path(path) 243 | if not isinstance(path, Path): 244 | raise TypeError("Parameter `path` can only be `str` or `pathlib.Path` type.") 245 | 
if not path.exists(): 246 | head, tail = os.path.splitext(path) 247 | if tail == '': # 说明没有后缀,理解为是一个folder 248 | path.mkdir(parents=True, exist_ok=True) 249 | else: 250 | # 主进程会帮助我们创建文件夹,但是由于主从进程几乎是同步的,因此到这里时子进程也会尝试创建文件夹,即使主进程会做这件事情; 251 | dirname = os.path.dirname(path) 252 | os.makedirs(dirname, exist_ok=True) 253 | if path.is_dir(): 254 | path = path.joinpath(os.environ.get('LOGGING_TIME', f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}") + '.log') 255 | 256 | if not isinstance(remove_other_handlers, bool): 257 | raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") 258 | 259 | if not isinstance(mode, str): 260 | raise TypeError("Parameter `mode` can only be `str` type.") 261 | if mode not in {"w", "a"}: 262 | raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').") 263 | 264 | for h in _logger.handlers: 265 | if isinstance(h, logging.FileHandler): 266 | if os.path.abspath(path) == h.baseFilename: 267 | # file path already added 268 | return 269 | 270 | # File Handler 271 | if int(os.environ.get('LOCAL_RANK', 0)) == 0: 272 | if os.path.exists(path): 273 | assert os.path.isfile(path) 274 | warnings.warn('log already exists in {}'.format(path)) 275 | 276 | dirname = os.path.abspath(os.path.dirname(path)) 277 | os.makedirs(dirname, exist_ok=True) 278 | 279 | # 这里只要检测到是分布式训练,我们就将 mode 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 280 | # 覆盖掉原文件,而是会接着上一次的 log 继续添加; 281 | # 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; 282 | # if torch.distributed.is_initialized():# and int(os.environ.get(LOMO_GLOBAL_RANK, 0)) != 0: 283 | # mode = "a" 284 | 285 | file_handler = logging.FileHandler(path, mode=mode) 286 | logger.info(f"Writing log to file:{os.path.abspath(path)}") 287 | file_handler.setLevel(_get_level(level)) 288 | 289 | if torch.distributed.is_initialized(): 290 | file_formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s -
[%(levelname)s] - %(message)s', 291 | datefmt='%Y/%m/%d %H:%M:%S') 292 | else: 293 | file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s', 294 | datefmt='%Y/%m/%d %H:%M:%S') 295 | 296 | file_handler.setFormatter(file_formatter) 297 | _logger.addHandler(file_handler) 298 | 299 | if remove_other_handlers: 300 | _need_remove_handlers = [] 301 | for i, h in enumerate(_logger.handlers): 302 | if not isinstance(h, logging.FileHandler): 303 | _need_remove_handlers.append(h) 304 | for handler in _need_remove_handlers: 305 | _logger.removeHandler(handler) 306 | 307 | return file_handler 308 | 309 | 310 | def _set_stdout_handler(_logger, stdout='raw', level='INFO'): 311 | level = _get_level(level) 312 | supported_stdout = ['none', 'raw', 'tqdm', 'naive', 'rich'] 313 | if stdout not in supported_stdout: 314 | raise ValueError('stdout must in one of {}'.format(supported_stdout)) 315 | # make sure to initialize _logger only once 316 | stream_handler = None 317 | _handlers = (logging.StreamHandler, TqdmLoggingHandler, StdoutStreamHandler, RichHandler) 318 | for i, h in enumerate(_logger.handlers): 319 | if isinstance(h, _handlers): 320 | stream_handler = h 321 | break 322 | if stream_handler is not None: 323 | _logger.removeHandler(stream_handler) 324 | del stream_handler 325 | 326 | # Stream Handler 327 | if stdout == 'raw': 328 | stream_handler = StdoutStreamHandler() 329 | elif stdout == 'rich': 330 | stream_handler = RichHandler(level=level, log_time_format="[%X]") 331 | elif stdout == 'naive': 332 | stream_handler = logging.StreamHandler(sys.stdout) 333 | elif stdout == 'tqdm': 334 | stream_handler = TqdmLoggingHandler(level) 335 | else: 336 | stream_handler = None 337 | 338 | if stream_handler is not None: 339 | if torch.distributed.is_initialized(): 340 | stream_formatter = logging.Formatter('Rank: %(rank)s - %(message)s') 341 | else: 342 | stream_formatter = logging.Formatter('%(message)s') 343 | stream_handler.setLevel(level) 
344 | stream_handler.setFormatter(stream_formatter) 345 | _logger.addHandler(stream_handler) 346 | 347 | return stream_handler 348 | 349 | 350 | def _init_logger(path=None, stdout='rich', level='INFO'): 351 | r"""initialize _logger""" 352 | level = _get_level(level) 353 | 354 | logger = LOMOLogger(ROOT_NAME) 355 | 356 | logger.propagate = False 357 | 358 | _set_stdout_handler(logger, stdout, level) 359 | 360 | # File Handler 361 | if path is not None: 362 | _add_file_handler(logger, path, level) 363 | 364 | logger.setLevel(level) 365 | 366 | return logger 367 | 368 | 369 | def parse_level(): 370 | level = 'WARNING' if int(os.environ.get('LOCAL_RANK', 0)) != 0 else "INFO" 371 | return level 372 | 373 | 374 | logger = _init_logger(path=None, stdout='rich', level=parse_level()) 375 | logger.debug("The environment variables are as following:") 376 | logger.debug(os.environ) 377 | -------------------------------------------------------------------------------- /lomo/log/print.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'print' 3 | ] 4 | from logging import INFO 5 | from .logger import logger 6 | 7 | 8 | def print(*args, sep=' ', end='\n', file=None, flush=False): 9 | """ 10 | 用来重定向 print 函数至 logger.info 的函数。 11 | 12 | :param args: 需要打印的内容 13 | :param sep: 存在多个输入时,使用的间隔。 14 | :param end: 该参数在当前设置无意义,因为结尾一定会被加入 ``'\\\\n'`` 。 15 | :param file: 该参数无意义。 16 | :param flush: 该参数无意义。 17 | :return: 18 | """ 19 | line = sep.join(map(str, args)) 20 | if logger.isEnabledFor(INFO): 21 | kwargs = logger._add_rank_info({}) 22 | logger._log(INFO, line, None, **kwargs) 23 | -------------------------------------------------------------------------------- /lomo/run.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | port=$(shuf -i25000-30000 -n1) 3 | 4 | # for full parameter fine-tuning using LOMO 5 | deepspeed --master_port "$port" --include localhost:0 src/train_lomo.py 
config/args_lomo.yaml 6 | 7 | # for LoRA + LOMO 8 | #deepspeed --master_port "$port" --include localhost:0 src/train_lomo_lora.py config/args_lomo_lora.yaml 9 | -------------------------------------------------------------------------------- /lomo/src/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from transformers import Seq2SeqTrainingArguments 4 | 5 | 6 | @dataclass 7 | class ModelArguments: 8 | model_name_or_path: Optional[str] = field(default="llama-7B") 9 | cache_dir: Optional[str] = field(default='../llama/checkpoint') 10 | # llama_dir: Optional[str] = field(default='/remote-home/klv/exps/MossOn3090/llama') 11 | 12 | 13 | @dataclass 14 | class DataArguments: 15 | data_dir: str = field(default='data') 16 | dataset_name: str = field(default='openbookqa') 17 | refresh: bool = field(default=False, metadata={"help": "Whether to refresh the data."}) 18 | 19 | data_tag: str = field(default='src') 20 | prompt_type: str = field(default='natural', metadata={"help": "The type of prompt, including [natural, brown]."}) 21 | train_on_inputs: bool = field(default=False, metadata={"help": "Whether to train on input."}) 22 | data_max_length: int = field(default=1024) 23 | few_shot_size: int = field(default=-1) 24 | in_context_learning: bool = field(default=False, metadata={"help": "Whether to use in-context learning."}) 25 | 26 | 27 | @dataclass 28 | class MyTrainingArguments(Seq2SeqTrainingArguments): 29 | tag: str = field(default=None, metadata={"help": "Tag for the experiment."}) 30 | 31 | predict_with_generate: bool = field(default=False, metadata={"help": "Whether to use generate for prediction."}) 32 | 33 | clip_grad_norm: float = field(default=None, metadata={ 34 | "help": "Maximum gradient normalized value (for gradient clipping)."}) # recommend 1.0 35 | clip_grad_value: float = field(default=None, metadata={"help": "Maximum gradient value 
(for gradient clipping)."}) 36 | clip_loss_value: float = field(default=None, 37 | metadata={"help": "Maximum loss value (for token loss clipping)."}) # recommend 5.0 38 | warmup: float = field(default=0.0, 39 | metadata={"help": "The number of warmup steps (int) or the warmup ratio (float)."}) 40 | 41 | max_length: int = field(default=20, metadata={"help": "The maximum length of the sequence to be generated."}) 42 | max_new_tokens: int = field(default=None, metadata={ 43 | "help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}) 44 | do_sample: bool = field(default=False, 45 | metadata={"help": "Whether or not to use sampling ; use greedy decoding otherwise."}) 46 | temperature: float = field(default=1.0, 47 | metadata={"help": "The value used to modulate the next token probabilities."}) 48 | top_k: int = field(default=50, metadata={ 49 | "help": "If set to int > 0, only the top k tokens with the highest probability will be considered for generation."}) 50 | top_p: float = field(default=1.0, metadata={ 51 | "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation."}) 52 | typical_p: float = field(default=1.0, metadata={ 53 | "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation."}) 54 | repetition_penalty: float = field(default=1.0, metadata={ 55 | "help": "The parameter for repetition penalty. 1.0 means no penalty. 
See this paper for more details: https://arxiv.org/pdf/1909.05858.pdf"}) 56 | 57 | length_normalization: bool = field(default=True, metadata={"help": "Whether to normalize the loss by the length of the input."}) 58 | unconditional_normalization: bool = field(default=False, metadata={"help": "Whether to normalize the loss by the unconditional (context-free) likelihood of the target."}) 59 | 60 | hf_learning_rate: float = field(default=5e-4, metadata={"help": "The learning rate for the HF optimizer."}) 61 | hf_weight_decay: float = field(default=0.0, metadata={"help": "The weight decay for the HF optimizer."}) 62 | hf_lr_scheduler_type: str = field(default='linear', metadata={"help": "The lr scheduler type for the HF optimizer."}) 63 | hf_warmup: int = field(default=0, metadata={"help": "The warmup steps for the HF optimizer."}) 64 | 65 | # lora hyperparams 66 | peft_type: str = field(default=None, metadata={ 67 | "help": "The type of PEFT, including [lora, prefix-tuning, prompt-tuning, p-tuning]."}) 68 | lora_r: int = field(default=8, metadata={"help": "LoRA attention dimension."}) 69 | lora_alpha: int = field(default=16, metadata={"help": "The alpha parameter for LoRA scaling."}) 70 | lora_dropout: float = field(default=0.05, metadata={"help": "The dropout probability for LoRA layers."}) 71 | lora_only: bool = field(default=False, metadata={"help": "Whether to use LoRA without LOMO."}) 72 | -------------------------------------------------------------------------------- /lomo/src/lomo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.optim import Optimizer 4 | import torch.distributed as dist 5 | 6 | from src.utils import DynamicLossScaler 7 | 8 | 9 | class LOMO(Optimizer): 10 | """ 11 | A custom optimizer, LOMO, that fuses gradient computation and parameter update for distributed training. 12 | 13 | It implements two gradient-update functions, :meth:`fuse_update` and :meth:`fuse_update_zero3`, used in the non-ZeRO and ZeRO settings, respectively. 14 | 15 | :param model: the model to be optimized 16 | :param lr: learning rate, defaults to 1e-3 17 | :param clip_grad_norm: norm threshold for gradient clipping 18 | 19 | .. 
note:: 20 | 21 | clip_grad_norm must be positive 22 | 23 | :param clip_grad_value: value threshold for gradient clipping 24 | """ 25 | 26 | def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None): 27 | self.model = model 28 | self.lr = lr 29 | self.local_rank = int(os.environ["LOCAL_RANK"]) 30 | self.world_size = dist.get_world_size() 31 | self.clip_grad_norm = clip_grad_norm 32 | self.clip_grad_value = clip_grad_value 33 | 34 | # for grad norm 35 | if self.clip_grad_norm is not None and self.clip_grad_norm <= 0: 36 | raise ValueError(f"clip_grad_norm should be positive, got {self.clip_grad_norm}.") 37 | self.gather_norm = False 38 | self.grad_norms = [] 39 | self.clip_coef = None 40 | 41 | # check if zero3 is enabled 42 | p0 = list(self.model.parameters())[0] 43 | if hasattr(p0, 'ds_tensor'): # zero3 is enabled 44 | self.grad_func = self.fuse_update_zero3() 45 | else: 46 | self.grad_func = self.fuse_update() 47 | # check if fp16 is enabled 48 | if p0.dtype == torch.float16: 49 | self.loss_scaler = DynamicLossScaler( 50 | init_scale=2 ** 16, 51 | ) # TODO: add args 52 | if self.clip_grad_norm is None: 53 | raise ValueError( 54 | "Loss scaling must be used together with grad norm (clip_grad_norm) to get better performance." 
55 | ) 56 | else: 57 | self.loss_scaler = None 58 | 59 | # register the hook function, which will be called during the backward pass 60 | for n, p in self.model.named_parameters(): 61 | if p.requires_grad: 62 | p.register_hook(self.grad_func) 63 | defaults = dict(lr=lr, clip_grad_norm=clip_grad_norm, clip_grad_value=clip_grad_value) 64 | super(LOMO, self).__init__(self.model.parameters(), defaults) 65 | 66 | def fuse_update(self): 67 | """ 68 | Update the model parameters in the non-ZeRO setting. 69 | 70 | :return: func, a closure that applies the in-place gradient update 71 | """ 72 | 73 | def func(x): 74 | """ 75 | Closure that applies the in-place gradient update. 76 | """ 77 | with torch.no_grad(): 78 | for n, p in self.model.named_parameters(): 79 | if p.requires_grad and p.grad is not None: 80 | if self.loss_scaler: 81 | if self.loss_scaler.has_overflow_serial or self.loss_scaler._has_inf_or_nan(p.grad): 82 | # if overflow is detected, drop the gradient 83 | p.grad = None 84 | self.loss_scaler.has_overflow_serial = True 85 | break 86 | grad_fp32 = p.grad.to(torch.float32) 87 | p.grad = None 88 | if self.loss_scaler: 89 | grad_fp32.div_(self.loss_scaler.loss_scale) 90 | if self.gather_norm: 91 | # we adopt two backward passes, for gradient norm computation and parameter update, respectively. 
92 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 93 | else: 94 | if self.clip_grad_value is not None and self.clip_grad_value > 0: 95 | # Clipping gradients by their value 96 | grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value) 97 | if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None: 98 | # Normalize the gradient according to its norm (computed in another pass) 99 | grad_fp32.mul_(self.clip_coef) 100 | p_fp32 = p.data.to(torch.float32) 101 | p_fp32.add_(grad_fp32, alpha=-self.lr) 102 | p.data.copy_(p_fp32) 103 | 104 | return x 105 | 106 | return func 107 | 108 | def fuse_update_zero3(self): 109 | """ 110 | Update the model parameters in the ZeRO setting. 111 | 112 | :return: func, a closure that applies the in-place gradient update. 113 | """ 114 | def func(x): 115 | with torch.no_grad(): 116 | for n, p in self.model.named_parameters(): 117 | if p.grad is not None: 118 | torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False) 119 | if self.loss_scaler: 120 | if self.loss_scaler.has_overflow_serial or self.loss_scaler._has_inf_or_nan(p.grad): 121 | # if overflow is detected, drop the gradient 122 | p.grad = None 123 | self.loss_scaler.has_overflow_serial = True 124 | break 125 | 126 | grad_fp32 = p.grad.to(torch.float32) 127 | p.grad = None 128 | param_fp32 = p.ds_tensor.to(torch.float32) 129 | if self.loss_scaler: 130 | grad_fp32.div_(self.loss_scaler.loss_scale) 131 | 132 | if self.gather_norm: 133 | # we adopt two backward passes, for gradient norm computation and parameter update, respectively. 
134 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 135 | else: # update param 136 | one_dim_grad_fp32 = grad_fp32.view(-1) 137 | partition_size = p.ds_tensor.numel() 138 | start = partition_size * self.local_rank 139 | end = min(start + partition_size, grad_fp32.numel()) 140 | partitioned_grad_fp32 = one_dim_grad_fp32.narrow(0, start, end - start) 141 | 142 | if self.clip_grad_value is not None: 143 | # Clipping gradients by their value 144 | partitioned_grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value) 145 | if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None: 146 | # Normalize the gradient according to its norm (computed in another pass) 147 | partitioned_grad_fp32.mul_(self.clip_coef) 148 | 149 | partitioned_p = param_fp32.narrow(0, 0, end - start) 150 | partitioned_p.add_(partitioned_grad_fp32, alpha=-self.lr) 151 | p.ds_tensor[ : end - start] = partitioned_p 152 | return x 153 | 154 | return func 155 | 156 | def fused_backward(self, loss, lr): 157 | """ 158 | Run one backward pass and update the model parameters in place. 159 | 160 | :param loss: the loss value of the model 161 | :param lr: learning rate 162 | """ 163 | self.lr = lr 164 | # Users need to call grad_norm() themselves before calling fused_backward() 165 | if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is None: 166 | raise ValueError( 167 | "clip_grad_norm is not None, but clip_coef is None. " 168 | "Please call optimizer.grad_norm() before optimizer.fused_backward()." 169 | ) 170 | if self.loss_scaler: 171 | loss = loss * self.loss_scaler.loss_scale 172 | loss.backward() 173 | # update the last parameter, since its gradient is not yet ready when its hook function is called 174 | # the argument of grad_func is just a placeholder, and it can be anything. 
175 | self.grad_func(0) 176 | 177 | def grad_norm(self, loss): 178 | """ 179 | Compute the norm of the gradients. 180 | 181 | :param loss: the loss value of the model 182 | """ 183 | self.gather_norm = True 184 | self.grad_norms = [] 185 | if self.loss_scaler: 186 | self.loss_scaler.has_overflow_serial = False 187 | loss = loss * self.loss_scaler.loss_scale 188 | loss.backward(retain_graph=True) 189 | # update the last parameter, since its gradient is not yet ready when its hook function is called 190 | # the argument of grad_func is just a placeholder, and it can be anything. 191 | self.grad_func(0) 192 | 193 | if self.loss_scaler and self.loss_scaler.has_overflow_serial: 194 | self.loss_scaler.update_scale(overflow=True) 195 | with torch.no_grad(): # clear gradients 196 | for n, p in self.model.named_parameters(): 197 | p.grad = None 198 | return 199 | 200 | 201 | with torch.no_grad(): 202 | # The norm is computed over all gradients together, as if they were 203 | # concatenated into a single vector. Gradients are modified in-place. 
204 | self.grad_norms = torch.stack(self.grad_norms) 205 | 206 | total_norm = torch.norm(self.grad_norms, 2.0) 207 | self.clip_coef = float(self.clip_grad_norm) / (total_norm + 1e-6) 208 | self.clip_coef = torch.clamp(self.clip_coef, max=1.0) 209 | self.gather_norm = False 210 | -------------------------------------------------------------------------------- /lomo/src/lomo_lora_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import operator 4 | from collections import OrderedDict 5 | from itertools import chain 6 | from pathlib import Path 7 | import shutil 8 | 9 | import tqdm 10 | import torch 11 | from torch.nn import CrossEntropyLoss 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | from transformers.trainer_pt_utils import DistributedLengthGroupedSampler, SequentialDistributedSampler, nested_numpify 14 | from transformers.trainer_utils import has_length, seed_worker 15 | from transformers import GenerationConfig 16 | from transformers.optimization import AdamW, get_scheduler 17 | 18 | try: 19 | import deepspeed 20 | from deepspeed import comm as dist 21 | from deepspeed.accelerator import get_accelerator 22 | except ImportError: 23 | pass 24 | 25 | from src.utils import LearningRateScheduler, WandbLogger, get_loss 26 | from log import print 27 | from peft import get_peft_model_state_dict 28 | 29 | 30 | class LOMOLoRATrainer: 31 | def __init__( 32 | self, 33 | model, 34 | training_args, 35 | data_collator, 36 | train_dataset, 37 | eval_dataset, 38 | tokenizer, 39 | compute_metrics, 40 | optimizers=None, 41 | ): 42 | self.training_args = training_args 43 | if self.training_args.world_size > 1: 44 | raise NotImplementedError("Distributed training for LOMO+LoRA is not supported yet.") 45 | self.train_dataset = train_dataset 46 | self.eval_dataset = eval_dataset 47 | self.tokenizer = tokenizer 48 | self.wandb = WandbLogger(training_args) 49 | self.allow_print = 
self.training_args.local_rank in [0, -1] 50 | if self.training_args.do_eval: 51 | self.metrics = {} 52 | self.compute_metrics = compute_metrics 53 | 54 | # get train_dataloader and eval_dataloader 55 | if isinstance(data_collator, dict): 56 | assert 'train' in data_collator and 'eval' in data_collator, "data_collator should be a dict with keys 'train' and 'eval'." 57 | self.train_data_collator = data_collator['train'] 58 | if self.training_args.do_eval: 59 | self.eval_data_collator = data_collator['eval'] 60 | else: 61 | self.train_data_collator = self.eval_data_collator = data_collator 62 | self.train_dataloader = self.get_train_dataloader() 63 | if self.training_args.do_eval: 64 | if isinstance(self.eval_dataset, dict): 65 | self.eval_dataloader = {} 66 | for prefix in self.eval_dataset.keys(): 67 | self.eval_dataloader[prefix] = self.get_eval_dataloader(self.eval_dataset[prefix]) 68 | else: 69 | self.eval_dataloader = self.get_eval_dataloader() 70 | 71 | # setup learning rate 72 | self.num_steps_per_epoch = len(self.train_dataloader) 73 | self.global_step = 1 74 | self.n_steps = self.num_steps_per_epoch * self.training_args.num_train_epochs 75 | self.lr_scheduler = LearningRateScheduler(learning_rate=self.training_args.learning_rate, 76 | warmup=self.training_args.warmup, 77 | schedule=self.training_args.lr_scheduler_type, 78 | n_steps=self.n_steps) 79 | self.lr = 0 80 | # for grad norm 81 | self.gather_norm = False 82 | self.grad_norms = [] 83 | self.clip_coef = None 84 | 85 | hf_optimizer = None 86 | hf_lr_scheduler = None 87 | if self.training_args.do_train: 88 | hf_optimizer = AdamW(optimizers['model_parameters'], lr=training_args.hf_learning_rate, 89 | weight_decay=training_args.hf_weight_decay) 90 | hf_lr_scheduler = get_scheduler(training_args.hf_lr_scheduler_type, 91 | optimizer=hf_optimizer, 92 | num_warmup_steps=training_args.hf_warmup * self.n_steps if training_args.hf_warmup < 1 else training_args.hf_warmup, 93 | num_training_steps=self.n_steps) 94 | 
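At this point the trainer holds a standard AdamW (plus scheduler) for the LoRA parameters, while the base parameters are about to get per-parameter backward hooks (`inplace_grad` below) that update them in place and free their gradients immediately. The toy sketch below illustrates that hybrid split on CPU, without DeepSpeed or ZeRO partitioning; the module and names (`ToyLoRALinear`, `fused_update`) are illustrative, not part of this repo:

```python
# Toy sketch of the LOMO + LoRA split: base weights are updated in place by a
# backward hook (plain SGD here), while "lora_" parameters go through AdamW.
import torch
from torch import nn
from torch.optim import AdamW

torch.manual_seed(0)

class ToyLoRALinear(nn.Module):
    def __init__(self, dim=4, r=2):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))       # "base" param
        self.lora_A = nn.Parameter(torch.randn(r, dim) * 0.01)  # LoRA params
        self.lora_B = nn.Parameter(torch.zeros(dim, r))

    def forward(self, x):
        return x @ (self.weight + self.lora_B @ self.lora_A).t()

model = ToyLoRALinear()
lr = 0.1

def fused_update(grad_placeholder):
    # Hook body: update every base gradient that is already accumulated,
    # then free it, so full-model gradients never coexist in memory.
    with torch.no_grad():
        for n, p in model.named_parameters():
            if "lora_" not in n and p.grad is not None:
                p.data.add_(p.grad, alpha=-lr)  # plain SGD step, in place
                p.grad = None                   # release memory immediately
    return grad_placeholder

for n, p in model.named_parameters():
    if "lora_" not in n and p.requires_grad:
        p.register_hook(fused_update)

opt = AdamW([p for n, p in model.named_parameters() if "lora_" in n], lr=1e-2)

x = torch.randn(8, 4)
before = model.weight.detach().clone()
loss = model(x).pow(2).mean()
loss.backward()
fused_update(None)  # flush the last parameter, mirroring self.grad_func(0)
opt.step(); opt.zero_grad()

assert not torch.allclose(before, model.weight)  # base weight moved in place
assert model.weight.grad is None                 # and its gradient was freed
```

The manual `fused_update(None)` call mirrors `self.grad_func(0)` in the trainers: a parameter's own hook fires before its gradient is accumulated into `.grad`, so the last parameter must be flushed explicitly after `backward()`.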
95 | if 'deepspeed' not in sys.modules: 96 | raise ModuleNotFoundError( 97 | "Detected DeepSpeed is not installed. See https://github.com/microsoft/DeepSpeed") 98 | 99 | # Initialize deepspeed engine 100 | self.model, self.peft_optimizer, _, self.peft_lr_scheduler = deepspeed.initialize( 101 | config=training_args.deepspeed, 102 | model=model, 103 | model_parameters=optimizers['model_parameters'] if self.training_args.do_train else None, 104 | optimizer=hf_optimizer, 105 | lr_scheduler=hf_lr_scheduler 106 | ) 107 | 108 | if not self.training_args.lora_only: 109 | # register inplace grad hook 110 | self.grad_func = self.inplace_grad() 111 | for n, p in model.named_parameters(): 112 | if "lora_" not in n and p.requires_grad: 113 | p.register_hook(self.grad_func) 114 | 115 | # self.dummy_optimizer = DeepSpeedZeRoOffload( 116 | # self.model.module, 117 | # timers=self.model.timers if self.model.wall_clock_breakdown() else None, 118 | # ds_config=self.model.config, 119 | # overlap_comm=self.model.zero_overlap_comm(), 120 | # prefetch_bucket_size=self.model.zero_prefetch_bucket_size(), 121 | # max_reuse_distance=self.model.zero_max_reuse_distance(), 122 | # max_live_parameters=self.model.zero_max_live_parameters(), 123 | # param_persistence_threshold=self.model.zero_param_persistence_threshold(), 124 | # model_persistence_threshold=self.model.zero_model_persistence_threshold(), 125 | # offload_param_config=self.model.zero_offload_param(), 126 | # mpu=self.model.mpu 127 | # ) 128 | 129 | get_accelerator().empty_cache() 130 | 131 | def inplace_grad(self): 132 | # An approximation of in-place grad update under zero3 of deepspeed 133 | def func(x): 134 | with torch.no_grad(): 135 | for n, p in self.model.named_parameters(): 136 | if "lora_" in n: 137 | continue 138 | 139 | if p.grad is not None: 140 | torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False) 141 | if self.gather_norm: 142 | grad_fp32 = p.grad.detach().clone().to(torch.float32) 
143 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 144 | p.grad = None 145 | else: 146 | one_dim_grad = p.grad.view(-1) 147 | partition_size = p.ds_tensor.numel() 148 | start = partition_size * self.training_args.local_rank 149 | end = start + partition_size 150 | 151 | if end > p.grad.numel(): 152 | partitioned_grad = one_dim_grad.narrow(0, start, p.grad.numel() - start) 153 | # partitioned_grad = torch.cat([partitioned_grad, torch.zeros(end - p.grad.numel()).cuda()]) 154 | partitioned_p = p.ds_tensor.narrow(0, 0, p.grad.numel() - start) 155 | partitioned_grad_fp32 = partitioned_grad.detach().clone().to(torch.float32) 156 | partitioned_p_fp32 = partitioned_p.detach().clone().to(torch.float32) 157 | if self.training_args.clip_grad_value is not None: 158 | # Gradients are modified in-place. 159 | partitioned_grad_fp32.clamp_(min=-self.training_args.clip_grad_value, 160 | max=self.training_args.clip_grad_value) 161 | if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0 and self.clip_coef is not None: 162 | partitioned_grad_fp32.mul_(self.clip_coef) 163 | partitioned_p_fp32.add_(partitioned_grad_fp32, alpha=-self.lr) 164 | partitioned_p.copy_(partitioned_p_fp32) 165 | else: 166 | partitioned_grad = one_dim_grad.narrow(0, start, partition_size) 167 | partitioned_grad_fp32 = partitioned_grad.detach().clone().to(torch.float32) 168 | if self.training_args.clip_grad_value is not None: 169 | # Gradients are modified in-place. 
170 | partitioned_grad_fp32.clamp_(min=-self.training_args.clip_grad_value, 171 | max=self.training_args.clip_grad_value) 172 | if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0 and self.clip_coef is not None: 173 | partitioned_grad_fp32.mul_(self.clip_coef) 174 | ds_tensor_fp32 = p.ds_tensor.detach().clone().to(torch.float32) 175 | ds_tensor_fp32.add_(partitioned_grad_fp32, alpha=-self.lr) 176 | p.ds_tensor.copy_(ds_tensor_fp32) 177 | p.grad = None 178 | return x 179 | 180 | return func 181 | 182 | def train(self): 183 | for epoch in range(self.training_args.num_train_epochs): 184 | print(f"***** Running Training *****") 185 | print(f" Num examples: {len(self.train_dataset)}") 186 | print(f" Num Epochs: {self.training_args.num_train_epochs}") 187 | print(f" Current Epoch: {epoch}") 188 | print(f" Batch Size: {self.training_args.per_device_train_batch_size}") 189 | if self.allow_print: 190 | self.wandb.log({'train/epoch': epoch}, step=self.global_step) 191 | self.train_dataloader.sampler.set_epoch(epoch) 192 | 193 | with tqdm.tqdm(self.train_dataloader, disable=not self.allow_print) as tqb: 194 | for step, batch in enumerate(tqb, start=1): 195 | self.model.train() 196 | outs = self.model( 197 | input_ids=batch['input_ids'].cuda(), 198 | attention_mask=batch['attention_mask'].cuda(), 199 | ) 200 | loss = get_loss(outs.logits, batch['labels'], self.training_args.clip_loss_value) 201 | 202 | # update the learning rate 203 | if self.training_args.lora_only: 204 | self.global_step = self.num_steps_per_epoch * epoch + step 205 | 206 | loss = self.model.backward(loss) 207 | self.model.step() # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 208 | else: 209 | self.global_step = self.num_steps_per_epoch * epoch + step 210 | self.lr = self.lr_scheduler.step(self.global_step) 211 | 212 | if self.training_args.clip_grad_norm is not None and 
self.training_args.clip_grad_norm > 0: 213 | self.gather_norm = True 214 | self.grad_norms = [] 215 | 216 | self.model.backward(loss) 217 | # flush the last parameter, whose gradient is not yet ready when its hook fires 218 | self.grad_func(0) 219 | # self.model.optimizer._get_param_coordinator(training=True).reset_step() 220 | # self.dummy_optimizer.get_param_coordinator(training=True).reset_step() 221 | 222 | with torch.no_grad(): 223 | # The norm is computed over all gradients together, as if they were 224 | # concatenated into a single vector. Gradients are modified in-place. 225 | self.grad_norms = torch.stack(self.grad_norms) 226 | device = torch.device(f"cuda:{self.training_args.local_rank}") 227 | all_grad_norms = torch.zeros(self.training_args.world_size * self.grad_norms.shape[0], 228 | dtype=self.grad_norms.dtype, device=device) 229 | torch.distributed.all_gather_into_tensor(all_grad_norms, self.grad_norms) 230 | 231 | total_norm = torch.norm(all_grad_norms, 2.0) 232 | self.clip_coef = float(self.training_args.clip_grad_norm) / (total_norm + 1e-6) 233 | self.clip_coef = torch.clamp(self.clip_coef, max=1.0) 234 | self.gather_norm = False 235 | 236 | # second forward pass 237 | outs = self.model( 238 | input_ids=batch['input_ids'].cuda(), 239 | attention_mask=batch['attention_mask'].cuda(), 240 | ) 241 | loss = get_loss(outs.logits, batch['labels'], self.training_args.clip_loss_value) 242 | 243 | # update peft params 244 | loss = self.model.backward(loss) 245 | self.grad_func(0) 246 | self.model.step() # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 247 | 248 | tqb.set_postfix({'loss': loss.item()}) 249 | if self.allow_print: 250 | self.wandb.log( 251 | { 252 | 'train/loss': loss.item(), 253 | 'train/learning_rate': self.lr, 254 | 'train/hf_learning_rate': self.model.get_lr()[0], 255 | 'train/global_step': self.global_step, 256 | }, 257 | step=self.global_step 258 | ) 259 | 260 | if 
self.training_args.save_strategy == 'steps' and self.global_step % self.training_args.save_steps == 0: 261 | self.save_model(self.global_step) 262 | 263 | if self.training_args.do_eval and self.training_args.evaluation_strategy == 'steps' and \ 264 | self.global_step % self.training_args.eval_steps == 0: 265 | if isinstance(self.eval_dataset, dict): 266 | for prefix in self.eval_dataset.keys(): 267 | assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." 268 | self.eval(self.global_step, epoch, self.eval_dataset[prefix], 269 | self.eval_dataloader[prefix], prefix) 270 | else: 271 | self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') 272 | 273 | if self.training_args.save_strategy == 'epoch': 274 | self.save_model(epoch) 275 | 276 | if self.training_args.do_eval and self.training_args.evaluation_strategy == 'epoch': 277 | if isinstance(self.eval_dataset, dict): 278 | for prefix in self.eval_dataset.keys(): 279 | assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." 280 | self.eval(self.global_step, epoch, self.eval_dataset[prefix], self.eval_dataloader[prefix], 281 | prefix) 282 | else: 283 | self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') 284 | 285 | def eval( 286 | self, 287 | step: int, 288 | epoch: int, 289 | dataset: torch.utils.data.Dataset, 290 | dataloader: DataLoader, 291 | eval_prefix: str 292 | ): 293 | r""" 294 | Shared by both eval(validation) and predict(test). 295 | This method will be called by the trainer to evaluate the model. 
296 | """ 297 | print(f"***** Running {eval_prefix} *****") 298 | print(f" Num examples: {len(dataset)}") 299 | print(f" Current Epoch: {epoch}") 300 | print(f" Batch size: {self.training_args.per_device_eval_batch_size}") 301 | 302 | with tqdm.tqdm(dataloader, disable=not self.allow_print) as tqb: 303 | all_preds = None 304 | self.model.eval() 305 | for batch in tqb: 306 | with torch.no_grad(): 307 | if self.training_args.predict_with_generate: 308 | pred = self.generate_step(batch) 309 | else: 310 | pred = self.eval_step(batch) 311 | all_preds = pred if all_preds is None else all_preds + pred 312 | 313 | all_preds_gather = [None for _ in range(self.training_args.world_size)] 314 | torch.distributed.all_gather_object(all_preds_gather, all_preds) 315 | all_pred_merged = list(chain(*all_preds_gather)) 316 | 317 | result = self.compute_metrics(all_pred_merged, dataset, eval_prefix) 318 | result = {f"{eval_prefix}/{k}": v for k, v in result.items()} 319 | prefix_metric_for_best_model = f'{eval_prefix}/{self.training_args.metric_for_best_model}' 320 | result_value = result[prefix_metric_for_best_model] 321 | 322 | if self.allow_print: 323 | print(f'epoch: {epoch}, step: {step}, {self.training_args.metric_for_best_model}: {result_value}') 324 | self.wandb.log(result, step=step) 325 | 326 | if self.is_better(result, prefix_metric_for_best_model): 327 | self.wandb.set_summary(f'{eval_prefix}/best_{self.training_args.metric_for_best_model}', result_value) 328 | self.wandb.set_summary(f'{eval_prefix}/best_epoch', epoch) 329 | self.wandb.set_summary(f'{eval_prefix}/best_step', step) 330 | self.metrics[prefix_metric_for_best_model] = result_value 331 | 332 | def eval_step(self, batch): 333 | """ 334 | used for classification or multi-choice qa tasks in eval() 335 | """ 336 | outs = self.model(batch['input_ids'].cuda(), batch['attention_mask'].cuda()) 337 | # Shift so that tokens < n predict n 338 | shift_logits = outs.logits[..., :-1, :].contiguous() 339 | shift_labels = 
batch['labels'][..., 1:].cuda().contiguous() 340 | # Flatten the tokens 341 | loss_fct = CrossEntropyLoss(reduction='none') 342 | loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), 343 | shift_labels.view(-1)).view_as(shift_labels) 344 | loss = loss.mean(dim=1) 345 | group_loss = loss.split(batch['split_size']) 346 | preds = torch.stack([torch.argmin(l) for l in group_loss], dim=0) 347 | 348 | preds = nested_numpify(preds) 349 | return preds.tolist() 350 | 351 | def generate_step(self, batch): 352 | """ 353 | used for generation tasks in eval() 354 | """ 355 | self.model.eval() 356 | generation_config = GenerationConfig(max_length=self.training_args.max_length, 357 | max_new_tokens=self.training_args.max_new_tokens, 358 | do_sample=self.training_args.do_sample, 359 | temperature=self.training_args.temperature, 360 | top_k=self.training_args.top_k, 361 | top_p=self.training_args.top_p, 362 | typical_p=self.training_args.typical_p, 363 | repetition_penalty=self.training_args.repetition_penalty, ) 364 | logits = self.model.generate( 365 | inputs=batch['input_ids'].cuda(), 366 | generation_config=generation_config 367 | ) 368 | predictions = logits.detach().cpu().numpy() 369 | pred_texts = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) 370 | return pred_texts 371 | 372 | def is_better(self, result_dict, key): 373 | """ 374 | Check whether the result in ``result_dict`` is better than the best metric recorded so far. 375 | 376 | :param result_dict: dict of metric results; ``key`` selects the metric to compare 377 | """ 378 | op = operator.gt if self.training_args.greater_is_better else operator.lt 379 | return ( 380 | key not in self.metrics or op(result_dict[key], self.metrics[key]) 381 | ) 382 | 383 | def get_train_sampler(self): 384 | if self.train_dataset is None or not has_length(self.train_dataset): 385 | return None 386 | 387 | # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with 388 | # `self.training_args.seed`) if data_seed isn't provided. 
389 | # Further on in this method, we default to `self.training_args.seed` instead. 390 | seed = self.training_args.data_seed if self.training_args.data_seed is not None else self.training_args.seed 391 | 392 | if self.training_args.group_by_length: 393 | return DistributedLengthGroupedSampler( 394 | self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps, 395 | dataset=self.train_dataset, 396 | num_replicas=self.training_args.world_size, 397 | rank=self.training_args.local_rank, 398 | lengths=None, 399 | model_input_name="input_ids", 400 | seed=seed, 401 | ) 402 | else: 403 | return DistributedSampler( 404 | self.train_dataset, 405 | num_replicas=self.training_args.world_size, 406 | rank=self.training_args.local_rank, 407 | seed=seed 408 | ) 409 | 410 | def get_train_dataloader(self): 411 | """ 412 | Returns the training [`~torch.utils.data.DataLoader`]. 413 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 414 | training if necessary) otherwise. 415 | Subclass and override this method if you want to inject some custom behavior. 
416 | """ 417 | if self.train_dataset is None: 418 | raise ValueError("Trainer: training requires a train_dataset.") 419 | 420 | data_collator = self.train_data_collator 421 | train_sampler = self.get_train_sampler() 422 | 423 | return DataLoader( 424 | self.train_dataset, 425 | batch_size=self.training_args.per_device_train_batch_size, 426 | sampler=train_sampler, 427 | collate_fn=data_collator, 428 | drop_last=self.training_args.dataloader_drop_last, 429 | num_workers=self.training_args.dataloader_num_workers, 430 | pin_memory=self.training_args.dataloader_pin_memory, 431 | worker_init_fn=seed_worker, 432 | ) 433 | 434 | def get_eval_sampler(self, eval_dataset): 435 | return SequentialDistributedSampler( 436 | eval_dataset, 437 | num_replicas=self.training_args.world_size, 438 | rank=self.training_args.local_rank, 439 | # batch_size=self.training_args.per_device_eval_batch_size 440 | ) 441 | 442 | def get_eval_dataloader(self, eval_dataset=None): 443 | """ 444 | Returns the evaluation [`~torch.utils.data.DataLoader`]. 445 | 446 | Subclass and override this method if you want to inject some custom behavior. 447 | 448 | Args: 449 | eval_dataset (`torch.utils.data.Dataset`, *optional*): 450 | If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted 451 | by the `model.forward()` method are automatically removed. It must implement `__len__`. 
452 | """ 453 | if eval_dataset is None and self.eval_dataset is None: 454 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 455 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 456 | data_collator = self.eval_data_collator 457 | 458 | eval_sampler = self.get_eval_sampler(eval_dataset) 459 | 460 | return DataLoader( 461 | eval_dataset, 462 | sampler=eval_sampler, 463 | batch_size=self.training_args.per_device_eval_batch_size, 464 | collate_fn=data_collator, 465 | drop_last=self.training_args.dataloader_drop_last, 466 | num_workers=self.training_args.dataloader_num_workers, 467 | pin_memory=self.training_args.dataloader_pin_memory, 468 | ) 469 | 470 | def save_model(self, index): 471 | if self.training_args.local_rank in [-1, 0]: 472 | checkpoint_dir = sorted(Path(self.training_args.output_dir).glob("checkpoint-*")) 473 | if len(checkpoint_dir) >= self.training_args.save_total_limit: 474 | shutil.rmtree(checkpoint_dir[0], ignore_errors=True) 475 | torch.distributed.barrier() 476 | 477 | output_dir = os.path.join(self.training_args.output_dir, f"checkpoint-{index}") 478 | if not os.path.exists(output_dir): 479 | os.makedirs(output_dir, exist_ok=True) 480 | state_dict = OrderedDict() 481 | for n, p in self.model.module.named_parameters(): 482 | state_dict[n] = (p.ds_tensor.detach().cpu(), p.ds_numel, p.ds_shape) 483 | # save model shards 484 | if self.training_args.local_rank != 0: 485 | with open(os.path.join(output_dir, f'pytorch_model-{self.training_args.local_rank}.bin'), 'wb') as f: 486 | torch.save(state_dict, f) 487 | torch.distributed.barrier() 488 | # merge model shards 489 | if self.training_args.local_rank == 0: 490 | # save config 491 | self.model.module.config.save_pretrained(output_dir) 492 | for rank in range(1, self.training_args.world_size): 493 | with open(os.path.join(output_dir, f'pytorch_model-{rank}.bin'), 'rb') as f: 494 | state_dict_rank = torch.load(f) 495 | for n in state_dict_rank: 496 | 
state_dict[n] = ( 497 | torch.cat([state_dict[n][0], state_dict_rank[n][0]], dim=0), 498 | state_dict[n][1], 499 | state_dict[n][2] 500 | ) 501 | # remove shard files 502 | os.remove(os.path.join(output_dir, f'pytorch_model-{rank}.bin')) 503 | # reshape to original shape 504 | for n in state_dict: 505 | numel = state_dict[n][1] 506 | shape = state_dict[n][2] 507 | state_dict[n] = state_dict[n][0][:numel].view(shape) 508 | 509 | # save inv_freq for llama 510 | if self.model.module.config.model_type == "llama": 511 | num_layers = self.model.module.config.num_hidden_layers 512 | head_dim = self.model.module.config.hidden_size // self.model.module.config.num_attention_heads 513 | base = 10000.0 514 | inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) 515 | for layer in range(num_layers): 516 | state_dict[f'model.layers.{layer}.self_attn.rotary_emb.inv_freq'] = inv_freq 517 | 518 | with open(os.path.join(output_dir, f'pytorch_model.bin'), 'wb') as f: 519 | torch.save(state_dict, f) 520 | print(f"Save model to {output_dir}.") 521 | 522 | # save lora 523 | adapter_output_dir = os.path.join(output_dir, 'adapter_model') 524 | self.model.module.peft_config['default'].save_pretrained(adapter_output_dir) 525 | # if state dict is not what you expected, you can use the following code to get the state dict 526 | # engine_state_dict = self.model._zero3_consolidated_16bit_state_dict() 527 | lora_state_dict = get_peft_model_state_dict(self.model.module, state_dict) 528 | torch.save(lora_state_dict, os.path.join(adapter_output_dir, "adapter_model.bin")) 529 | print(f"Save adapter model at {adapter_output_dir}") 530 | torch.distributed.barrier() 531 | -------------------------------------------------------------------------------- /lomo/src/lomo_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import operator 4 | from collections import OrderedDict 5 | from itertools import chain 6 | 
from pathlib import Path 7 | import shutil 8 | 9 | import tqdm 10 | import torch 11 | from torch.nn import CrossEntropyLoss 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | from transformers.trainer_pt_utils import DistributedLengthGroupedSampler, SequentialDistributedSampler, nested_numpify 14 | from transformers.trainer_utils import has_length, seed_worker 15 | from transformers import GenerationConfig 16 | 17 | try: 18 | import deepspeed 19 | from deepspeed import comm as dist 20 | from deepspeed.accelerator import get_accelerator 21 | except ImportError: 22 | pass 23 | 24 | from src.utils import LearningRateScheduler, WandbLogger, DynamicLossScaler, get_loss 25 | from src.lomo import LOMO 26 | from log import print 27 | 28 | 29 | class LOMOTrainer: 30 | def __init__( 31 | self, 32 | model, 33 | training_args, 34 | data_collator, 35 | train_dataset, 36 | eval_dataset, 37 | tokenizer, 38 | compute_metrics, 39 | ): 40 | self.training_args = training_args 41 | self.train_dataset = train_dataset 42 | self.eval_dataset = eval_dataset 43 | self.tokenizer = tokenizer 44 | self.wandb = WandbLogger(training_args) 45 | self.allow_print = self.training_args.local_rank in [-1, 0] 46 | if self.training_args.do_eval: 47 | self.metrics = {} 48 | self.compute_metrics = compute_metrics 49 | 50 | if 'deepspeed' not in sys.modules: 51 | raise ModuleNotFoundError( 52 | "DeepSpeed is not installed. See https://github.com/microsoft/DeepSpeed") 53 | 54 | # Initialize deepspeed engine 55 | self.model, _, _, _ = deepspeed.initialize( 56 | config=training_args.deepspeed, 57 | model=model, 58 | ) 59 | 60 | # get train_dataloader and eval_dataloader 61 | if isinstance(data_collator, dict): 62 | assert 'train' in data_collator and 'eval' in data_collator, "data_collator should be a dict with keys 'train' and 'eval'."
63 | self.train_data_collator = data_collator['train'] 64 | if self.training_args.do_eval: 65 | self.eval_data_collator = data_collator['eval'] 66 | else: 67 | self.train_data_collator = self.eval_data_collator = data_collator 68 | self.train_dataloader = self.get_train_dataloader() 69 | if self.training_args.do_eval: 70 | if isinstance(self.eval_dataset, dict): 71 | self.eval_dataloader = {} 72 | for prefix in self.eval_dataset.keys(): 73 | self.eval_dataloader[prefix] = self.get_eval_dataloader(self.eval_dataset[prefix]) 74 | else: 75 | self.eval_dataloader = self.get_eval_dataloader() 76 | 77 | # setup learning rate 78 | self.num_steps_per_epoch = len(self.train_dataloader) 79 | self.global_step = 1 80 | self.n_steps = self.num_steps_per_epoch * self.training_args.num_train_epochs 81 | self.lr_scheduler = LearningRateScheduler(learning_rate=self.training_args.learning_rate, 82 | warmup=self.training_args.warmup, 83 | schedule=self.training_args.lr_scheduler_type, 84 | n_steps=self.n_steps) 85 | self.lr = 0 86 | 87 | self.optimizer = LOMO(model, self.lr, training_args.clip_grad_norm, training_args.clip_grad_value) 88 | 89 | get_accelerator().empty_cache() 90 | 91 | def train(self): 92 | for epoch in range(self.training_args.num_train_epochs): 93 | print(f"***** Running Training *****") 94 | print(f" Num examples: {len(self.train_dataset)}") 95 | print(f" Num Epochs: {self.training_args.num_train_epochs}") 96 | print(f" Current Epoch: {epoch}") 97 | print(f" Batch Size: {self.training_args.per_device_train_batch_size}") 98 | if self.allow_print: 99 | self.wandb.log({'train/epoch': epoch}, step=self.global_step) 100 | self.train_dataloader.sampler.set_epoch(epoch) 101 | 102 | with tqdm.tqdm(self.train_dataloader, disable=not self.allow_print) as tqb: 103 | for step, batch in enumerate(tqb, start=1): 104 | self.model.train() 105 | outs = self.model( 106 | input_ids=batch['input_ids'].cuda(), 107 | attention_mask=batch['attention_mask'].cuda(), 108 | ) 109 | loss = 
get_loss(outs.logits, batch['labels'], self.training_args.clip_loss_value) 110 | 111 | # update the learning rate 112 | self.global_step = self.num_steps_per_epoch * epoch + step 113 | self.lr = self.lr_scheduler.step(self.global_step) 114 | if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0: 115 | self.optimizer.grad_norm(loss) 116 | # self.gather_norm = True 117 | # self.grad_norms = [] 118 | # self.loss_scaler.has_overflow_serial = False 119 | # scaled_loss = loss * self.loss_scaler.loss_scale 120 | # 121 | # scaled_loss.backward() 122 | # # update the last one since the hook function will not be called for the last parameter 123 | # self.grad_func(0) 124 | 125 | if self.optimizer.loss_scaler and self.optimizer.loss_scaler.has_overflow_serial: 126 | print(f"Gradient overflow, skipping step {self.global_step}") 127 | # self.loss_scaler.update_scale(overflow=True) 128 | # with torch.no_grad(): 129 | # for n, p in self.model.named_parameters(): 130 | # p.grad = None 131 | self.model.optimizer.get_param_coordinator(training=True).reset_step() 132 | tqb.set_postfix({'loss': loss.item()}) 133 | if self.allow_print: 134 | self.wandb.log( 135 | { 136 | 'train/loss': loss.item(), 137 | 'train/learning_rate': self.lr, 138 | 'train/global_step': self.global_step, 139 | }, 140 | step=self.global_step 141 | ) 142 | continue 143 | 144 | # with torch.no_grad(): 145 | # # The norm is computed over all gradients together, as if they were 146 | # # concatenated into a single vector. Gradients are modified in-place. 
147 | # self.grad_norms = torch.stack(self.grad_norms) 148 | # # device = torch.device(f"cuda:{self.training_args.local_rank}") 149 | # # all_grad_norms = torch.zeros(self.training_args.world_size * self.grad_norms.shape[0], dtype=self.grad_norms.dtype, device=device) 150 | # # torch.distributed.all_gather_into_tensor(all_grad_norms, self.grad_norms) 151 | # 152 | # # total_norm = torch.norm(all_grad_norms, 2.0) / self.training_args.world_size 153 | # total_norm = torch.norm(self.grad_norms, 2.0) 154 | # self.clip_coef = float(self.training_args.clip_grad_norm) / (total_norm + 1e-6) 155 | # self.clip_coef = torch.clamp(self.clip_coef, max=1.0) 156 | # self.gather_norm = False 157 | else: 158 | self.model.optimizer.get_param_coordinator(training=True).reset_step() 159 | # second forward pass 160 | outs = self.model( 161 | input_ids=batch['input_ids'].cuda(), 162 | attention_mask=batch['attention_mask'].cuda(), 163 | ) 164 | loss = get_loss(outs.logits, batch['labels'], self.training_args.clip_loss_value) 165 | 166 | # scaled_loss = loss * self.loss_scaler.loss_scale 167 | # 168 | # scaled_loss.backward() 169 | # # update the last one since the hook function will not be called for the last parameter 170 | # self.grad_func(0) 171 | # self.loss_scaler.update_scale(overflow=False) 172 | self.optimizer.fused_backward(loss, self.lr) 173 | self.model.optimizer.get_param_coordinator(training=True).reset_step() 174 | 175 | tqb.set_postfix({'loss': loss.item()}) 176 | if self.allow_print: 177 | self.wandb.log( 178 | { 179 | 'train/loss': loss.item(), 180 | 'train/learning_rate': self.lr, 181 | 'train/global_step': self.global_step, 182 | }, 183 | step=self.global_step 184 | ) 185 | 186 | if self.training_args.save_strategy == 'steps' and self.global_step % self.training_args.save_steps == 0: 187 | self.save_model(self.global_step) 188 | 189 | if self.training_args.do_eval and self.training_args.evaluation_strategy == 'steps' and \ 190 | self.global_step % self.training_args.eval_steps 
== 0: 191 | if isinstance(self.eval_dataset, dict): 192 | for prefix in self.eval_dataset.keys(): 193 | assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." 194 | self.eval(self.global_step, epoch, self.eval_dataset[prefix], 195 | self.eval_dataloader[prefix], prefix) 196 | else: 197 | self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') 198 | 199 | if self.training_args.save_strategy == 'epoch': 200 | self.save_model(epoch) 201 | 202 | if self.training_args.do_eval and self.training_args.evaluation_strategy == 'epoch': 203 | if isinstance(self.eval_dataset, dict): 204 | for prefix in self.eval_dataset.keys(): 205 | assert prefix in self.eval_dataloader.keys(), "eval_dataset and eval_dataloader should have the same keys." 206 | self.eval(self.global_step, epoch, self.eval_dataset[prefix], self.eval_dataloader[prefix], 207 | prefix) 208 | else: 209 | self.eval(self.global_step, epoch, self.eval_dataset, self.eval_dataloader, 'eval') 210 | 211 | def eval( 212 | self, 213 | step: int, 214 | epoch: int, 215 | dataset: torch.utils.data.Dataset, 216 | dataloader: DataLoader, 217 | eval_prefix: str 218 | ): 219 | r""" 220 | Shared by both eval(validation) and predict(test). 221 | This method will be called by the trainer to evaluate the model. 
222 | """ 223 | print(f"***** Running {eval_prefix} *****") 224 | print(f" Num examples: {len(dataset)}") 225 | print(f" Current Epoch: {epoch}") 226 | print(f" Batch size: {self.training_args.per_device_eval_batch_size}") 227 | 228 | with tqdm.tqdm(dataloader, disable=not self.allow_print) as tqb: 229 | all_preds = None 230 | self.model.eval() 231 | for batch in tqb: 232 | with torch.no_grad(): 233 | if self.training_args.predict_with_generate: 234 | pred = self.generate_step(batch) 235 | else: 236 | pred = self.eval_step(batch) 237 | all_preds = pred if all_preds is None else all_preds + pred 238 | 239 | all_preds_gather = [None for _ in range(self.training_args.world_size)] 240 | torch.distributed.all_gather_object(all_preds_gather, all_preds) 241 | all_pred_merged = list(chain(*all_preds_gather)) 242 | 243 | result = self.compute_metrics(all_pred_merged, dataset, eval_prefix) 244 | result = {f"{eval_prefix}/{k}": v for k, v in result.items()} 245 | prefix_metric_for_best_model = f'{eval_prefix}/{self.training_args.metric_for_best_model}' 246 | result_value = result[prefix_metric_for_best_model] 247 | 248 | if self.allow_print: 249 | print(f'epoch: {epoch}, step: {step}, {self.training_args.metric_for_best_model}: {result_value}') 250 | self.wandb.log(result, step=step) 251 | 252 | if self.is_better(result, prefix_metric_for_best_model): 253 | self.wandb.set_summary(f'{eval_prefix}/best_{self.training_args.metric_for_best_model}', result_value) 254 | self.wandb.set_summary(f'{eval_prefix}/best_epoch', epoch) 255 | self.wandb.set_summary(f'{eval_prefix}/best_step', step) 256 | self.metrics[prefix_metric_for_best_model] = result_value 257 | 258 | def eval_step(self, batch): 259 | """ 260 | used for classification or multi-choice qa tasks in eval() 261 | """ 262 | outs = self.model(batch['input_ids'].cuda(), batch['attention_mask'].cuda()) 263 | # Shift so that tokens < n predict n 264 | shift_logits = outs.logits[..., :-1, :].contiguous() 265 | shift_labels = 
batch['labels'][..., 1:].cuda().contiguous() 266 | # Flatten the tokens 267 | loss_fct = CrossEntropyLoss(reduction='none') 268 | loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), 269 | shift_labels.view(-1)).view_as(shift_labels) 270 | loss = loss.mean(dim=1) 271 | group_loss = loss.split(batch['split_size']) 272 | preds = torch.stack([torch.argmin(l) for l in group_loss], dim=0) 273 | 274 | preds = nested_numpify(preds) 275 | return preds.tolist() 276 | 277 | def generate_step(self, batch): 278 | """ 279 | used for generation tasks in eval() 280 | """ 281 | self.model.eval() 282 | generation_config = GenerationConfig(max_length=self.training_args.max_length, 283 | max_new_tokens=self.training_args.max_new_tokens, 284 | do_sample=self.training_args.do_sample, 285 | temperature=self.training_args.temperature, 286 | top_k=self.training_args.top_k, 287 | top_p=self.training_args.top_p, 288 | typical_p=self.training_args.typical_p, 289 | repetition_penalty=self.training_args.repetition_penalty, ) 290 | logits = self.model.generate( 291 | inputs=batch['input_ids'].cuda(), 292 | generation_config=generation_config 293 | ) 294 | predictions = logits.detach().cpu().numpy() 295 | pred_texts = self.tokenizer.batch_decode(predictions, skip_special_tokens=True) 296 | return pred_texts 297 | 298 | def is_better(self, result_dict, key): 299 | """ 300 | Return whether ``result_dict[key]`` is better than the best value recorded so far. 301 | 302 | :param result_dict: metric dict produced by ``compute_metrics``. 303 | """ 304 | op = operator.gt if self.training_args.greater_is_better else operator.lt 305 | return ( 306 | key not in self.metrics or op(result_dict[key], self.metrics[key]) 307 | ) 308 | 309 | def get_train_sampler(self): 310 | if self.train_dataset is None or not has_length(self.train_dataset): 311 | return None 312 | 313 | # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with 314 | # `self.training_args.seed`) if data_seed isn't provided. 
315 | # Further on in this method, we default to `self.training_args.seed` instead. 316 | seed = self.training_args.data_seed if self.training_args.data_seed is not None else self.training_args.seed 317 | 318 | if self.training_args.group_by_length: 319 | return DistributedLengthGroupedSampler( 320 | self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps, 321 | dataset=self.train_dataset, 322 | num_replicas=self.training_args.world_size, 323 | rank=self.training_args.local_rank, 324 | lengths=None, 325 | model_input_name="input_ids", 326 | seed=seed, 327 | ) 328 | else: 329 | return DistributedSampler( 330 | self.train_dataset, 331 | num_replicas=self.training_args.world_size, 332 | rank=self.training_args.local_rank, 333 | seed=seed 334 | ) 335 | 336 | def get_train_dataloader(self): 337 | """ 338 | Returns the training [`~torch.utils.data.DataLoader`]. 339 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 340 | training if necessary) otherwise. 341 | Subclass and override this method if you want to inject some custom behavior. 
342 | """ 343 | if self.train_dataset is None: 344 | raise ValueError("Trainer: training requires a train_dataset.") 345 | 346 | data_collator = self.train_data_collator 347 | train_sampler = self.get_train_sampler() 348 | 349 | return DataLoader( 350 | self.train_dataset, 351 | batch_size=self.training_args.per_device_train_batch_size, 352 | sampler=train_sampler, 353 | collate_fn=data_collator, 354 | drop_last=self.training_args.dataloader_drop_last, 355 | num_workers=self.training_args.dataloader_num_workers, 356 | pin_memory=self.training_args.dataloader_pin_memory, 357 | worker_init_fn=seed_worker, 358 | ) 359 | 360 | def get_eval_sampler(self, eval_dataset): 361 | return SequentialDistributedSampler( 362 | eval_dataset, 363 | num_replicas=self.training_args.world_size, 364 | rank=self.training_args.local_rank, 365 | # batch_size=self.training_args.per_device_eval_batch_size 366 | ) 367 | 368 | def get_eval_dataloader(self, eval_dataset=None): 369 | """ 370 | Returns the evaluation [`~torch.utils.data.DataLoader`]. 371 | 372 | Subclass and override this method if you want to inject some custom behavior. 373 | 374 | Args: 375 | eval_dataset (`torch.utils.data.Dataset`, *optional*): 376 | If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted 377 | by the `model.forward()` method are automatically removed. It must implement `__len__`. 
378 | """ 379 | if eval_dataset is None and self.eval_dataset is None: 380 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 381 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 382 | data_collator = self.eval_data_collator 383 | 384 | eval_sampler = self.get_eval_sampler(eval_dataset) 385 | 386 | return DataLoader( 387 | eval_dataset, 388 | sampler=eval_sampler, 389 | batch_size=self.training_args.per_device_eval_batch_size, 390 | collate_fn=data_collator, 391 | drop_last=self.training_args.dataloader_drop_last, 392 | num_workers=self.training_args.dataloader_num_workers, 393 | pin_memory=self.training_args.dataloader_pin_memory, 394 | ) 395 | 396 | def save_model(self, index): 397 | if self.training_args.local_rank in [-1, 0]: 398 | checkpoint_dir = sorted(Path(self.training_args.output_dir).glob("checkpoint-*")) 399 | if len(checkpoint_dir) >= self.training_args.save_total_limit: 400 | shutil.rmtree(checkpoint_dir[0], ignore_errors=True) 401 | torch.distributed.barrier() 402 | 403 | output_dir = os.path.join(self.training_args.output_dir, f"checkpoint-{index}") 404 | if not os.path.exists(output_dir): 405 | os.makedirs(output_dir, exist_ok=True) 406 | state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None 407 | shared_params = {} 408 | 409 | # Prepare for checkpoint save by ensuring all parameters are partitioned 410 | self.model.optimizer.partition_all_parameters() 411 | 412 | for name, param in self.model.module.named_parameters(): 413 | with deepspeed.zero.GatheredParameters(param): 414 | if torch.distributed.get_rank() == 0: 415 | # can't rely on param.data_ptr() as it will be reused as weights gets 416 | # gathered and reduced, but param.ds_id is unique across all zero weights 417 | # (and shared params will have the same param.ds_id) 418 | if param.ds_id in shared_params: 419 | # shared weights 420 | state_dict[name] = state_dict[shared_params[param.ds_id]] 421 | else: 422 | 
state_dict[name] = param.detach().cpu() 423 | shared_params[param.ds_id] = name 424 | 425 | if len(self.model.optimizer.persistent_parameters) > 0: 426 | self.model.optimizer.persistent_parameters[0].all_gather(self.model.optimizer.persistent_parameters) 427 | 428 | if torch.distributed.get_rank() == 0: 429 | self.model.module.config.save_pretrained(output_dir) 430 | torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin')) 431 | print(f"Save model to {output_dir}") 432 | 433 | torch.distributed.barrier() 434 | -------------------------------------------------------------------------------- /lomo/src/merge_llama_with_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from peft import PeftModel 6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer 7 | 8 | ''' 9 | This is the code for merging the LoRA adapter with the base model. [ref] https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py 10 | 11 | To load `lora + lomo` checkpoint, please first run `python merge_llama_with_lora.py` to merge the weights. Then, set `resume_from_checkpoint` to the merged weights path. 
12 | ''' 13 | 14 | 15 | def apply_lora(model_name_or_path, output_path, lora_path): 16 | print(f"Loading the base model from {model_name_or_path}") 17 | base = AutoModelForCausalLM.from_pretrained( 18 | model_name_or_path, low_cpu_mem_usage=True 19 | ) 20 | # base_tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path) 21 | 22 | print(f"Loading the LoRA adapter from {lora_path}") 23 | 24 | lora_model = PeftModel.from_pretrained( 25 | base, 26 | lora_path, 27 | ) 28 | 29 | print("Applying the LoRA") 30 | model = lora_model.merge_and_unload() 31 | 32 | print(f"Saving the target model to {output_path}") 33 | model.save_pretrained(output_path) 34 | # base_tokenizer.save_pretrained(output_path) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | # parser.add_argument("--model_name_or_path", type=str, required=True) 40 | # parser.add_argument("--output_path", type=str, required=True) 41 | # parser.add_argument("--lora_path", type=str, required=True) 42 | parser.add_argument("--llama", action="store_true", required=True) 43 | 44 | # model_name_or_path = '/remote-home/share/llama_hf/7B' 45 | model_name_or_path = 'outputs/wic_7B_lora-qv-r2-lomo/output_lr0.005_bs16_warmup0.05_clipnorm1.0/checkpoint-0' 46 | # good path 47 | ckpt_path = 'outputs/wic_7B_lora-qv-r2-lomo/output_lr0.005_bs16_warmup0.05_clipnorm1.0/checkpoint-0' 48 | # ckpt_path = 'outputs/lora65b_checkpoint-2500' 49 | lora_path = os.path.join(ckpt_path, 'adapter_model') 50 | output_path = os.path.join(ckpt_path, 'merge_weights') 51 | 52 | # lora_path = 'outputs/belle_llama-7b_1w_zh_len1400_zero2_5e4/output_lora_adamw_hf_lr0.0005/checkpoint-1560/global_step1560/adapter_model' 53 | # output_path = 'outputs/belle_llama-7b_1w_zh_len1400_zero2_5e4/output_lora_adamw_hf_lr0.0005/checkpoint-1560/global_step1560/merge_weights' 54 | 55 | apply_lora(model_name_or_path, output_path, lora_path) 56 | -------------------------------------------------------------------------------- 
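The merge step above delegates to PEFT's `merge_and_unload()`. As a quick numerical illustration of what that merge does to each LoRA-adapted linear layer, here is a minimal sketch; the names `W`, `A`, `B`, `alpha`, `r` and all shapes are illustrative assumptions, not part of this repository. Folding the low-rank update `scaling * B @ A` into the frozen base weight leaves the layer's output unchanged, which is why the merged checkpoint no longer needs the adapter at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative shapes: a square linear layer with a rank-2 LoRA adapter.
d_out, d_in, r, alpha = 8, 8, 2, 16
W = rng.standard_normal((d_out, d_in))       # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01    # LoRA down-projection
B = rng.standard_normal((d_out, r)) * 0.01   # LoRA up-projection
scaling = alpha / r

x = rng.standard_normal((4, d_in))           # a batch of inputs

# During training, LoRA keeps W frozen and adds a low-rank correction on the fly.
y_adapter = x @ W.T + (x @ A.T @ B.T) * scaling

# Merging folds the correction into the base weight: W' = W + scaling * B @ A.
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.T

# Both formulations produce the same output.
assert np.allclose(y_adapter, y_merged)
```

Note that merging is lossy in one direction only: the adapter can be recovered from `W_merged - W`, but once the base checkpoint is overwritten the original `W` is gone, which is why the script writes the merged model to a separate `output_path`.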
/lomo/src/mydatasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | from tqdm import tqdm 5 | from typing import Callable, Any 6 | 7 | from datasets import load_dataset 8 | from dataclasses import dataclass 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | from log import print 14 | from prompts import QuestionPart, Exemplar, idx_to_ltr 15 | 16 | IGNORE_INDEX = -100 17 | REPRODUCIBILITY_SEED = 0 18 | 19 | 20 | class MyDataset(Dataset): 21 | def __init__(self, data_args, tokenizer, dataset_info, split): 22 | super().__init__() 23 | self.data_args = data_args 24 | self.tokenizer = tokenizer 25 | self.split = split 26 | self.sample_size = dataset_info.sample_size 27 | self.prompt_type = dataset_info.prompt_type 28 | 29 | save_dir = os.path.join(data_args.data_dir, data_args.dataset_name, data_args.data_tag) 30 | if not os.path.exists(save_dir): 31 | os.makedirs(save_dir, exist_ok=True) 32 | 33 | save_file = os.path.join(save_dir, f'{split}.pt') 34 | if data_args.refresh or not os.path.exists(save_file): 35 | dataset = load_dataset(dataset_info.path, name=dataset_info.name, split=split) 36 | self.data = self.process(dataset_info.extractor, dataset, save_file) 37 | else: 38 | print('Loading data from', save_file) 39 | self.data = torch.load(save_file) 40 | print('Data size:', len(self.data)) 41 | print('Data format:', self.data[0]) 42 | print('Max length:', max([len(d['input_ids']) for d in self.data])) if self.split == 'train' else \ 43 | print('Max length:', max([max([len(d) for d in dd['input_ids']]) for dd in self.data])) 44 | 45 | def process(self, extractor, dataset, save_file): 46 | data = [] 47 | for instance in tqdm(dataset): 48 | exemplar = Exemplar(**extractor(instance)) 49 | if self.prompt_type == 'brown': 50 | prompt = exemplar.get_brown_prompt() 51 | else: 52 | prompt = exemplar.get_natural_prompt() 53 | source = 
prompt['source'] 54 | 55 | targets = [] 56 | 57 | def _tokenize_fn(source, target): 58 | targets.append(target) 59 | example = f"{source}{target}" 60 | example_tokenized = self.tokenizer.encode(example, truncation=True, max_length=self.data_args.data_max_length) 61 | example_tokenized = example_tokenized + [self.tokenizer.eos_token_id] 62 | source_tokenized = self.tokenizer.encode(source) 63 | 64 | input_ids = example_tokenized 65 | labels = copy.deepcopy(input_ids) 66 | if not self.data_args.train_on_inputs: 67 | labels = np.array(labels) 68 | labels[:len(source_tokenized) - 1] = IGNORE_INDEX 69 | return input_ids, labels 70 | 71 | if self.split == 'train': 72 | input_ids, labels = _tokenize_fn(source, prompt['target']) 73 | else: 74 | input_ids = [] 75 | labels = [] 76 | for choice in prompt['choices']: 77 | op_input_ids, op_labels = _tokenize_fn(source, choice) 78 | input_ids.append(op_input_ids) 79 | labels.append(op_labels) 80 | 81 | data.append({'input_ids': input_ids, 82 | 'labels': labels, 83 | 'source': source, 84 | 'target': targets, 85 | 'answer': exemplar.answer_idx}) 86 | 87 | if self.sample_size > 0 and len(data) > self.sample_size: 88 | random.seed(REPRODUCIBILITY_SEED) 89 | possible_idxs = list(range(len(data))) 90 | sampled_idxs = random.sample(possible_idxs, self.sample_size) 91 | data = [data[i] for i in sampled_idxs] 92 | print(f'Sampled {self.sample_size} examples from {len(possible_idxs)} examples.') 93 | 94 | torch.save(data, save_file) 95 | print('Saving data to', save_file) 96 | return data 97 | 98 | def concat_exemplars(self, exemplars): 99 | exemplar_prompts = [f"{e['source']}{e['target'][0]}" for e in exemplars] 100 | exemplars = "\n\n".join(exemplar_prompts) 101 | return exemplars 102 | 103 | def __len__(self): 104 | return len(self.data) 105 | 106 | def __getitem__(self, idx): 107 | return { 108 | 'input_ids': self.data[idx]['input_ids'], 109 | 'labels': self.data[idx]['labels'] 110 | } 111 | 112 | 113 | @dataclass 114 | class 
DatasetInfo: 115 | path: str = None 116 | exemplar_split: str = None 117 | eval_split: str = None 118 | test_split: str = None 119 | extractor: Callable = None 120 | name: str = None 121 | data_dir: str = None 122 | sample_size: int = -1 123 | prompt_type: str = 'brown' 124 | 125 | 126 | def get_dataset_info(dataset_name): 127 | if dataset_name == 'boolq': 128 | return DatasetInfo( 129 | path="super_glue", 130 | name="boolq", 131 | exemplar_split="train", 132 | eval_split="validation", 133 | sample_size=1000, 134 | extractor=lambda row: { 135 | "parts": [ 136 | QuestionPart( 137 | f"{row['passage']} {row['question']}", 138 | ), 139 | ], 140 | "choices": [ 141 | 'No', 'Yes' 142 | ], 143 | "answer_idx": int(row["label"]) 144 | } 145 | ) 146 | # elif dataset_name == 'cb': 147 | # return DatasetInfo( 148 | # path="super_glue", 149 | # name="cb", 150 | # exemplar_split="train", 151 | # eval_split="validation", 152 | # sample_size=1000, 153 | # extractor=lambda row: { 154 | # "parts": [ 155 | # QuestionPart( 156 | # f"Suppose {row['premise']} Can we infer that \"{row['hypothesis']}\"? Yes, No, or Maybe?", 157 | # ), 158 | # ], 159 | # "choices": [ 160 | # 'Yes', 'No', 'Maybe' 161 | # ], 162 | # "answer_idx": int(row["label"]) 163 | # } 164 | # ) 165 | elif dataset_name == 'multirc': 166 | return DatasetInfo( 167 | path="super_glue", 168 | name="multirc", 169 | exemplar_split="train", 170 | eval_split="validation", 171 | sample_size=1000, 172 | extractor=lambda row: { 173 | "parts": [ 174 | QuestionPart( 175 | f"{row['paragraph']}", 176 | ), 177 | QuestionPart( 178 | f"{row['question']}", 179 | tag='Question' 180 | ), 181 | QuestionPart( 182 | f'I found this answer "{row["answer"]}". Is that correct? 
Yes or No?', 183 | ), 184 | ], 185 | "choices": [ 186 | 'No', 'Yes' 187 | ], 188 | "answer_idx": int(row["label"]) 189 | } 190 | ) 191 | elif dataset_name == 'rte': 192 | return DatasetInfo( 193 | path="super_glue", 194 | name="rte", 195 | exemplar_split="train", 196 | eval_split="validation", 197 | sample_size=1000, 198 | extractor=lambda row: { 199 | "parts": [ 200 | QuestionPart( 201 | f"{row['premise']}\nDoes this mean that \"{row['hypothesis']}\" is true? Yes or No?", 202 | ), 203 | ], 204 | "choices": [ 205 | 'Yes', 'No' 206 | ], 207 | "answer_idx": int(row["label"]) 208 | } 209 | ) 210 | elif dataset_name == 'wic': 211 | return DatasetInfo( 212 | path="super_glue", 213 | name="wic", 214 | exemplar_split="train", 215 | eval_split="validation", 216 | sample_size=1000, 217 | extractor=lambda row: { 218 | "parts": [ 219 | QuestionPart( 220 | f"Does the word \"{row['word']}\" have the same meaning in these two sentences? Yes, No?\n{row['sentence1']}\n{row['sentence2']}", 221 | ), 222 | ], 223 | "choices": [ 224 | 'No', 'Yes' 225 | ], 226 | "answer_idx": int(row["label"]) 227 | } 228 | ) 229 | elif dataset_name == 'wsc': 230 | return DatasetInfo( 231 | path="super_glue", 232 | name="wsc", 233 | exemplar_split="train", 234 | eval_split="validation", 235 | sample_size=1000, 236 | extractor=lambda row: { 237 | "parts": [ 238 | QuestionPart( 239 | f"{row['text']}\nIn the previous sentence, does the pronoun \"{row['span2_text']}\" refer to \"{row['span1_text']}\"? 
Yes or No?", 240 | ), 241 | ], 242 | "choices": [ 243 | 'No', 'Yes' 244 | ], 245 | "answer_idx": int(row["label"]) 246 | } 247 | ) 248 | elif dataset_name == 'copa': 249 | return DatasetInfo( 250 | path="super_glue", 251 | name="copa", 252 | exemplar_split="train", 253 | eval_split="validation", 254 | sample_size=1000, 255 | prompt_type='natural', 256 | extractor=lambda row: { 257 | "parts": [ 258 | QuestionPart( 259 | f"{row['premise']} so " if row['question'] == 'effect' else f"{row['premise']} because ", 260 | ), 261 | ], 262 | "choices": [ 263 | row['choice1'], row['choice2'] 264 | ], 265 | "answer_idx": int(row["label"]) 266 | } 267 | ) 268 | elif dataset_name == 'record': 269 | return DatasetInfo( 270 | path="super_glue", 271 | name="record", 272 | exemplar_split="train", 273 | eval_split="validation", 274 | sample_size=1000, 275 | extractor=process_record 276 | ) 277 | else: 278 | raise NotImplementedError 279 | 280 | 281 | def process_record(row): 282 | def record_clean_choices(row): 283 | if len(row['answers']) == 1: 284 | return row['entities'], row['entities'].index(row['answers'][0]) 285 | 286 | new_entities = [] 287 | for entity in row['entities']: 288 | if entity in row['answers'][1:]: 289 | continue 290 | new_entities.append(entity) 291 | return new_entities, new_entities.index(row['answers'][0]) 292 | 293 | choices, answer_idx = record_clean_choices(row) 294 | return { 295 | "parts": [ 296 | QuestionPart( 297 | "{}\n{}\nQuestion: What is the \"@placeholder\"?".format(row['passage'].replace('@highlight\n', '- '), row['query']), 298 | ), 299 | ], 300 | "choices": choices, 301 | "answer_idx": answer_idx 302 | } 303 | 304 | 305 | if __name__ == '__main__': 306 | from transformers import HfArgumentParser 307 | from arguments import ModelArguments, DataArguments 308 | from transformers import AutoTokenizer 309 | 310 | parser = HfArgumentParser((ModelArguments, DataArguments)) 311 | model_args, data_args = parser.parse_args_into_dataclasses() 312 | 
model_args.model_name_or_path = '/home/klv/llama_hf/7B' 313 | data_args.dataset_name = 'record' 314 | data_args.refresh = True 315 | data_args.data_tag = 'debug' 316 | data_args.train_on_inputs = False 317 | data_args.data_max_length = 512 318 | 319 | tokenizer = AutoTokenizer.from_pretrained( 320 | model_args.model_name_or_path, 321 | use_fast=False, 322 | padding_side='left' 323 | ) 324 | tokenizer.pad_token_id = 0 325 | 326 | dataset_info = get_dataset_info(data_args.dataset_name) 327 | train_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.exemplar_split) 328 | eval_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.eval_split) 329 | # test_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.test_split) 330 | 331 | 332 | 333 | -------------------------------------------------------------------------------- /lomo/src/prompts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import random 4 | 5 | 6 | def idx_to_ltr(idx): 7 | return chr(idx + ord("A")) 8 | 9 | 10 | @dataclass 11 | class QuestionPart: 12 | text: str 13 | tag: str = None 14 | 15 | def __str__(self): 16 | if self.tag is not None: 17 | return f"{self.tag}: {self.text}" 18 | else: 19 | return self.text 20 | 21 | 22 | @dataclass 23 | class Question: 24 | parts: list 25 | choices: list 26 | answer_idx: int 27 | task: str = None 28 | 29 | def get_n_choices(self): 30 | return len(self.choices) 31 | 32 | def get_answer_str(self): 33 | return self.choices[self.answer_idx] 34 | 35 | def _get_prompt(self, include_choices): 36 | prompt = "" 37 | for part in self.parts: 38 | prompt += f"{str(part)}\n" 39 | if include_choices: 40 | for i, choice in enumerate(self.choices): 41 | prompt += f"{idx_to_ltr(i)}. 
{choice}\n" 42 | return prompt 43 | 44 | def get_natural_prompt(self): 45 | return self._get_prompt(include_choices=True) 46 | 47 | def get_brown_prompt(self): 48 | return self._get_prompt(include_choices=False) 49 | 50 | def strong_shuffle(self): 51 | # This method shuffles choices such that choosing 52 | # the answer at the originally correct 53 | # index will mean getting the question wrong 54 | 55 | # For degenerate questions where all choices are the same 56 | if len(set(self.choices)) == 1: 57 | return 58 | 59 | answer_idx = self.answer_idx 60 | answer_str = self.get_answer_str() 61 | while self.choices[answer_idx] == answer_str: 62 | random.shuffle(self.choices) 63 | self.answer_idx = self.choices.index(answer_str) 64 | 65 | def permute_choices(self, perm): 66 | self.choices = [self.choices[i] for i in perm] 67 | self.answer_idx = perm.index(self.answer_idx) 68 | 69 | 70 | class Exemplar(Question): 71 | 72 | def get_natural_prompt(self): 73 | prompt = super().get_brown_prompt().strip('\n') 74 | # return f"{prompt} {self.get_answer_str()}" 75 | return { 76 | 'source': f"{prompt}", 77 | 'target': f"{self.get_answer_str()}", 78 | 'choices': self.choices 79 | } 80 | 81 | def get_brown_prompt(self): 82 | prompt = super().get_brown_prompt() 83 | # return f"{prompt} {self.get_answer_str()}" 84 | return { 85 | 'source': f"{prompt}Answer: ", 86 | 'target': f"{self.get_answer_str()}", 87 | 'choices': self.choices 88 | } 89 | -------------------------------------------------------------------------------- /lomo/src/train_lomo.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import sys 4 | 5 | import torch 6 | from transformers import HfArgumentParser 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 8 | from transformers import set_seed 9 | from dataclasses import asdict 10 | from transformers.deepspeed import HfDeepSpeedConfig 11 | import wandb 12 | # 
os.environ['WANDB_MODE'] = 'debug' 13 | 14 | python_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15 | print("PYTHON_PATH", python_path) 16 | sys.path.append(python_path) 17 | from log import print 18 | from arguments import ModelArguments, DataArguments, MyTrainingArguments 19 | from mydatasets import MyDataset, get_dataset_info 20 | from lomo_trainer import LOMOTrainer 21 | from utils import DataCollatorForCauselLM, EvalDataCollatorForCauselLM 22 | 23 | 24 | def compute_metrics(all_pred, eval_dataset, eval_prefix=None): 25 | golds = [ins['answer'] for ins in eval_dataset.data] 26 | preds = all_pred[:len(golds)] 27 | 28 | acc = round(sum([int(pred == gold) for pred, gold in zip(preds, golds)]) / len(golds), 6) 29 | result = {'acc': acc} 30 | return result 31 | 32 | 33 | def train(): 34 | # ========== 1. logs and args ========== 35 | torch.set_default_dtype(torch.float16) 36 | parser = HfArgumentParser((ModelArguments, DataArguments, MyTrainingArguments)) 37 | if sys.argv[-1].endswith(".yaml"): 38 | model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[-1])) 39 | else: 40 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 41 | set_seed(training_args.seed) 42 | 43 | model_name = model_args.model_name_or_path.split('/')[-1] 44 | tag_name = '_'.join([data_args.dataset_name, model_name, training_args.tag] if training_args.tag else [data_args.dataset_name, model_name]) 45 | hparam_name = 'output' 46 | if training_args.optim != 'sgd': 47 | hparam_name += '_' + training_args.optim 48 | if training_args.learning_rate != 5e-4: 49 | hparam_name += '_lr' + str(training_args.learning_rate) 50 | if training_args.per_device_train_batch_size != 8: 51 | hparam_name += '_bs' + str(training_args.per_device_train_batch_size) 52 | if training_args.lr_scheduler_type != 'linear': 53 | hparam_name += '_' + training_args.lr_scheduler_type 54 | if training_args.warmup != 0: 55 | hparam_name += 
'_warmup' + str(training_args.warmup) 56 | if training_args.clip_grad_norm and training_args.clip_grad_norm > 0: 57 | hparam_name += '_clipnorm' + str(training_args.clip_grad_norm) 58 | if training_args.clip_grad_value and training_args.clip_grad_value > 0: 59 | hparam_name += '_clipgrad' + str(training_args.clip_grad_value) 60 | if training_args.clip_loss_value and training_args.clip_loss_value > 0: 61 | hparam_name += '_cliploss' + str(training_args.clip_loss_value) 62 | # assert training_args.clip_grad_value is None or training_args.clip_loss_value is None 63 | training_args.output_dir = os.path.join('outputs', tag_name, hparam_name) 64 | 65 | if training_args.tag == 'debug': 66 | os.environ['WANDB_MODE'] = 'offline' 67 | if training_args.local_rank in [-1, 0]: 68 | wandb_config = copy.deepcopy(asdict(training_args)) 69 | wandb_config.update(asdict(model_args)) 70 | wandb_config.update(asdict(data_args)) 71 | wandb.init( 72 | project="collie", 73 | entity='collie_exp', 74 | name=tag_name if hparam_name == 'output' else '_'.join([tag_name, hparam_name.replace('output_', '')]), 75 | config=wandb_config 76 | ) 77 | 78 | # ========== 2. Load pretrained model and tokenizer. ========== 79 | ds_config = training_args.deepspeed 80 | dschf = HfDeepSpeedConfig(ds_config) 81 | config = AutoConfig.from_pretrained(model_args.model_name_or_path) 82 | config.gradient_checkpointing = training_args.gradient_checkpointing 83 | if training_args.resume_from_checkpoint is not None: 84 | print(f'Load checkpoint from {training_args.resume_from_checkpoint}.') 85 | model = AutoModelForCausalLM.from_pretrained( 86 | model_args.model_name_or_path if training_args.resume_from_checkpoint is None else training_args.resume_from_checkpoint, 87 | local_files_only=True, 88 | config=config, 89 | ) 90 | 91 | tokenizer = AutoTokenizer.from_pretrained( 92 | model_args.model_name_or_path, 93 | use_fast=False, 94 | padding_side='left' 95 | ) 96 | tokenizer.pad_token_id = 0 97 | 98 | # ========== 3. 
Preprocessing the datasets. ========== 99 | dataset_info = get_dataset_info(data_args.dataset_name) 100 | train_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.exemplar_split) 101 | eval_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.eval_split) 102 | 103 | # ========== 4. Initialize our Trainer. ========== 104 | trainer = LOMOTrainer( 105 | model=model, 106 | training_args=training_args, 107 | data_collator={'train': DataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left'), 108 | 'eval': EvalDataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left')}, 109 | train_dataset=train_dataset, 110 | eval_dataset=eval_dataset, 111 | tokenizer=tokenizer, 112 | compute_metrics=compute_metrics, 113 | ) 114 | if training_args.do_train: 115 | trainer.train() 116 | 117 | if training_args.do_eval: 118 | trainer.eval(trainer.global_step, 0, trainer.eval_dataset, trainer.eval_dataloader, 'test') 119 | 120 | 121 | if __name__ == "__main__": 122 | train() 123 | -------------------------------------------------------------------------------- /lomo/src/train_lomo_lora.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import sys 4 | 5 | import torch 6 | from transformers import HfArgumentParser 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 8 | from transformers import set_seed 9 | from dataclasses import asdict 10 | from transformers.deepspeed import HfDeepSpeedConfig 11 | from peft import get_peft_model, TaskType, LoraConfig 12 | import wandb 13 | # os.environ['WANDB_MODE'] = 'debug' 14 | 15 | python_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 16 | print("PYTHON_PATH", python_path) 17 | sys.path.append(python_path) 18 | from log import print 19 | from arguments import ModelArguments, DataArguments, MyTrainingArguments 20 | from mydatasets 
import MyDataset, get_dataset_info 21 | from lomo_lora_trainer import LOMOLoRATrainer 22 | from utils import DataCollatorForCauselLM, EvalDataCollatorForCauselLM 23 | 24 | 25 | def compute_metrics(all_pred, eval_dataset, eval_prefix=None): 26 | golds = [ins['answer'] for ins in eval_dataset.data] 27 | preds = all_pred[:len(golds)] 28 | 29 | acc = round(sum([int(pred == gold) for pred, gold in zip(preds, golds)]) / len(golds), 6) 30 | result = {'acc': acc} 31 | return result 32 | 33 | 34 | def train(): 35 | # ========== 1. logs and args ========== 36 | torch.set_default_dtype(torch.bfloat16) 37 | parser = HfArgumentParser((ModelArguments, DataArguments, MyTrainingArguments)) 38 | if sys.argv[-1].endswith(".yaml"): 39 | model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[-1])) 40 | else: 41 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 42 | set_seed(training_args.seed) 43 | 44 | model_name = model_args.model_name_or_path.split('/')[-1] 45 | tag_name = '_'.join([data_args.dataset_name, model_name, training_args.tag] if training_args.tag else [data_args.dataset_name, model_name]) 46 | hparam_name = 'output' 47 | if training_args.optim != 'sgd': 48 | hparam_name += '_' + training_args.optim 49 | if training_args.learning_rate != 5e-4: 50 | hparam_name += '_lr' + str(training_args.learning_rate) 51 | if training_args.per_device_train_batch_size != 8: 52 | hparam_name += '_bs' + str(training_args.per_device_train_batch_size) 53 | if training_args.lr_scheduler_type != 'linear': 54 | hparam_name += '_' + training_args.lr_scheduler_type 55 | if training_args.warmup != 0: 56 | hparam_name += '_warmup' + str(training_args.warmup) 57 | if training_args.clip_grad_norm and training_args.clip_grad_norm > 0: 58 | hparam_name += '_clipnorm' + str(training_args.clip_grad_norm) 59 | if training_args.clip_grad_value and training_args.clip_grad_value > 0: 60 | hparam_name += '_clipgrad' + 
str(training_args.clip_grad_value) 61 | if training_args.clip_loss_value and training_args.clip_loss_value > 0: 62 | hparam_name += '_cliploss' + str(training_args.clip_loss_value) 63 | # assert training_args.clip_grad_value is None or training_args.clip_loss_value is None 64 | training_args.output_dir = os.path.join('outputs', tag_name, hparam_name) 65 | 66 | if training_args.tag == 'debug': 67 | os.environ['WANDB_MODE'] = 'offline' 68 | if training_args.local_rank in [-1, 0]: 69 | wandb_config = copy.deepcopy(asdict(training_args)) 70 | wandb_config.update(asdict(model_args)) 71 | wandb_config.update(asdict(data_args)) 72 | wandb.init( 73 | project="collie", 74 | entity='collie_exp', 75 | name=tag_name if hparam_name == 'output' else '_'.join([tag_name, hparam_name.replace('output_', '')]), 76 | config=wandb_config 77 | ) 78 | 79 | # ========== 2. Load pretrained model and tokenizer. ========== 80 | ds_config = training_args.deepspeed 81 | dschf = HfDeepSpeedConfig(ds_config) 82 | config = AutoConfig.from_pretrained(model_args.model_name_or_path) 83 | config.gradient_checkpointing = training_args.gradient_checkpointing 84 | if training_args.resume_from_checkpoint is not None: 85 | print(f'Load checkpoint from {training_args.resume_from_checkpoint}.') 86 | assert not training_args.do_train, 'do not support resume training now.' 
87 | model = AutoModelForCausalLM.from_pretrained( 88 | model_args.model_name_or_path if training_args.resume_from_checkpoint is None else training_args.resume_from_checkpoint, 89 | local_files_only=True, 90 | config=config, 91 | ) 92 | 93 | tokenizer = AutoTokenizer.from_pretrained( 94 | model_args.model_name_or_path, 95 | use_fast=False, 96 | padding_side='left' 97 | ) 98 | tokenizer.pad_token_id = 0 99 | 100 | peft_params = [] 101 | non_peft_names = [] 102 | non_peft_params = [] 103 | if training_args.resume_from_checkpoint is None: 104 | for name, param in model.named_parameters(): 105 | if param.requires_grad is False: 106 | continue 107 | non_peft_names.append(name) 108 | non_peft_params.append(param) 109 | 110 | # use peft 111 | if training_args.peft_type is not None: 112 | print(f'Using peft.{training_args.peft_type}') 113 | if training_args.peft_type == 'lora': 114 | peft_config = LoraConfig( 115 | r=training_args.lora_r, 116 | lora_alpha=training_args.lora_alpha, 117 | target_modules=["q_proj", "v_proj"], 118 | # target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], 119 | lora_dropout=training_args.lora_dropout, 120 | bias="none", 121 | task_type=TaskType.CAUSAL_LM 122 | ) 123 | model.enable_input_require_grads() 124 | else: 125 | raise ValueError(f"Unknown PEFT type: {training_args.peft_type}") 126 | model = get_peft_model(model, peft_config) 127 | model.print_trainable_parameters() 128 | 129 | # unfreeze base model 130 | # parameter names after wrapping with peft: base_model.model.model.layers.23.self_attn.v_proj.weight 131 | # parameter names before wrapping: model.layers.23.self_attn.v_proj.weight 132 | for name, param in model.named_parameters(): 133 | if name.split('base_model.model.')[1] in non_peft_names: 134 | if not training_args.lora_only: 135 | param.requires_grad = True 136 | if "lora_" in name: 137 | peft_params.append(param) 138 | 139 | torch.cuda.empty_cache() 140 | 141 | # ========== 3. Preprocessing the datasets. 
========== 142 | dataset_info = get_dataset_info(data_args.dataset_name) 143 | train_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.exemplar_split) 144 | # if data_args.few_shot_size != -1: 145 | # # few_shot_indices = sample(range(len(train_dataset)), data_args.few_shot_size) 146 | # train_dataset = Subset(train_dataset, range(data_args.few_shot_size)) 147 | eval_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.eval_split) 148 | if dataset_info.test_split: 149 | test_dataset = MyDataset(data_args, tokenizer, dataset_info, split=dataset_info.test_split) 150 | eval_dataset = { 151 | # 'validation': eval_dataset, 152 | 'test': test_dataset 153 | } 154 | 155 | # ========== 4. Initialize our Trainer. ========== 156 | trainer = LOMOLoRATrainer( 157 | model=model, 158 | training_args=training_args, 159 | data_collator={'train': DataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left'), 160 | 'eval': EvalDataCollatorForCauselLM(tokenizer, max_length=data_args.data_max_length, padding_side='left')}, 161 | train_dataset=train_dataset, 162 | eval_dataset=eval_dataset, 163 | tokenizer=tokenizer, 164 | compute_metrics=compute_metrics, 165 | optimizers={'model_parameters': peft_params}, 166 | ) 167 | if training_args.do_train: 168 | trainer.train() 169 | 170 | if training_args.do_eval: 171 | trainer.eval(trainer.global_step, 0, trainer.eval_dataset, trainer.eval_dataloader, 'test') 172 | 173 | 174 | if __name__ == "__main__": 175 | train() 176 | -------------------------------------------------------------------------------- /lomo/src/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | from torch.nn import CrossEntropyLoss 6 | from transformers.utils import PaddingStrategy 7 | from transformers.trainer import * 8 | import wandb 9 | 10 | 11 | @dataclass 12 | class 
DataCollatorForCauselLM: 13 | """ 14 | Data collator that will dynamically pad the inputs received, as well as the labels. 15 | Args: 16 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 17 | The tokenizer used for encoding the data. 18 | model ([`PreTrainedModel`]): 19 | The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to 20 | prepare the *decoder_input_ids* 21 | This is useful when using *label_smoothing* to avoid calculating loss twice. 22 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 23 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 24 | among: 25 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single 26 | sequence is provided). 27 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 28 | acceptable input length for the model if that argument is not provided. 29 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 30 | max_length (`int`, *optional*): 31 | Maximum length of the returned list and optionally padding length (see above). 32 | pad_to_multiple_of (`int`, *optional*): 33 | If set will pad the sequence to a multiple of the provided value. 34 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 35 | 7.5 (Volta). 36 | label_pad_token_id (`int`, *optional*, defaults to -100): 37 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 38 | return_tensors (`str`): 39 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 
40 | """ 41 | 42 | tokenizer: Any 43 | model: Optional[Any] = None 44 | padding: Union[bool, str, PaddingStrategy] = True 45 | max_length: Optional[int] = None 46 | pad_to_multiple_of: Optional[int] = None 47 | label_pad_token_id: int = -100 48 | return_tensors: str = "pt" 49 | padding_side: str = 'right' 50 | 51 | def __call__(self, features, return_tensors=None): 52 | padding_side = self.padding_side 53 | 54 | # if return_tensors is None: 55 | # return_tensors = self.return_tensors 56 | labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None 57 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 58 | # same length to return tensors. 59 | if labels is not None: 60 | max_label_length = max(len(l) for l in labels) 61 | if self.pad_to_multiple_of is not None: 62 | max_label_length = ( 63 | (max_label_length + self.pad_to_multiple_of - 1) 64 | // self.pad_to_multiple_of 65 | * self.pad_to_multiple_of 66 | ) 67 | 68 | for feature in features: 69 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) 70 | if isinstance(feature["labels"], list): 71 | feature["labels"] = ( 72 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 73 | ) 74 | elif padding_side == "right": 75 | feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) 76 | else: 77 | feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) 78 | 79 | max_length = max(len(feature['input_ids']) for feature in features) 80 | if padding_side == 'right': 81 | input_ids = [feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) 82 | for feature in features] 83 | attention_mask = [[1] * len(feature['input_ids']) + [0] * (max_length - len(feature['input_ids'])) for 84 | feature in features] 85 | elif padding_side == 'left': 86 | input_ids = 
[[self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) + feature['input_ids'] 87 | for feature in features] 88 | attention_mask = [[0] * (max_length - len(feature['input_ids'])) + [1] * len(feature['input_ids']) for 89 | feature in features] 90 | else: 91 | raise ValueError("Invalid padding strategy:" + str(padding_side)) 92 | 93 | features = { 94 | 'input_ids': torch.tensor(input_ids).long(), 95 | 'attention_mask': torch.tensor(attention_mask).long(), 96 | 'labels': torch.tensor(np.array([feature['labels'] for feature in features])).long() 97 | } 98 | return features 99 | 100 | 101 | @dataclass 102 | class EvalDataCollatorForCauselLM: 103 | """ 104 | Data collator that will dynamically pad the inputs received, as well as the labels. 105 | Args: 106 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 107 | The tokenizer used for encoding the data. 108 | model ([`PreTrainedModel`]): 109 | The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to 110 | prepare the *decoder_input_ids* 111 | This is useful when using *label_smoothing* to avoid calculating loss twice. 112 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): 113 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 114 | among: 115 | - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single 116 | sequence is provided). 117 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 118 | acceptable input length for the model if that argument is not provided. 119 | - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). 120 | max_length (`int`, *optional*): 121 | Maximum length of the returned list and optionally padding length (see above). 
122 | pad_to_multiple_of (`int`, *optional*): 123 | If set will pad the sequence to a multiple of the provided value. 124 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 125 | 7.5 (Volta). 126 | label_pad_token_id (`int`, *optional*, defaults to -100): 127 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 128 | return_tensors (`str`): 129 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 130 | """ 131 | 132 | tokenizer: Any 133 | model: Optional[Any] = None 134 | padding: Union[bool, str, PaddingStrategy] = True 135 | max_length: Optional[int] = None 136 | pad_to_multiple_of: Optional[int] = None 137 | label_pad_token_id: int = -100 138 | return_tensors: str = "pt" 139 | padding_side: str = 'left' 140 | unconditional_normalization: bool = False 141 | 142 | def __call__(self, features, return_tensors=None): 143 | padding_side = self.padding_side 144 | 145 | split_size = [] 146 | new_features = [] 147 | assert "labels" in features[0].keys() 148 | for feature in features: 149 | split_size.append(len(feature["labels"])) 150 | for op_input_ids, op_labels in zip(feature["input_ids"], feature["labels"]): 151 | un_mask = np.zeros_like(op_labels) 152 | un_mask_index = np.where(op_labels == self.label_pad_token_id, 1, 0).sum() - 2 153 | un_mask[:un_mask_index] = 1 154 | new_features.append({"input_ids": op_input_ids, "labels": op_labels, "un_mask": un_mask}) 155 | 156 | labels = [feature["labels"] for feature in new_features] 157 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 158 | # same length to return tensors. 
159 | if labels is not None: 160 | max_label_length = max(len(l) for l in labels) 161 | if self.pad_to_multiple_of is not None: 162 | max_label_length = ( 163 | (max_label_length + self.pad_to_multiple_of - 1) 164 | // self.pad_to_multiple_of 165 | * self.pad_to_multiple_of 166 | ) 167 | 168 | for feature in new_features: 169 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) 170 | if isinstance(feature["labels"], list): 171 | feature["labels"] = ( 172 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 173 | ) 174 | elif padding_side == "right": 175 | feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) 176 | feature["un_mask"] = np.concatenate([feature["un_mask"], np.ones_like(remainder)]).astype(np.int64) 177 | else: 178 | feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) 179 | feature["un_mask"] = np.concatenate([np.ones_like(remainder), feature["un_mask"]]).astype(np.int64) 180 | 181 | max_length = max(len(feature['input_ids']) for feature in new_features) 182 | if padding_side == 'right': 183 | input_ids = [feature['input_ids'] + [self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) 184 | for feature in new_features] 185 | attention_mask = [[1] * len(feature['input_ids']) + [0] * (max_length - len(feature['input_ids'])) for 186 | feature in new_features] 187 | elif padding_side == 'left': 188 | input_ids = [[self.tokenizer.pad_token_id] * (max_length - len(feature['input_ids'])) + feature['input_ids'] 189 | for feature in new_features] 190 | attention_mask = [[0] * (max_length - len(feature['input_ids'])) + [1] * len(feature['input_ids']) for 191 | feature in new_features] 192 | else: 193 | raise ValueError("Invalid padding strategy:" + str(padding_side)) 194 | 195 | batched_features = { 196 | 'input_ids': torch.tensor(input_ids).long(), 197 | 'attention_mask': torch.tensor(attention_mask).long(), 198 
| 'labels': torch.tensor(np.array([feature['labels'] for feature in new_features])).long(), 199 | 'split_size': split_size 200 | } 201 | if self.unconditional_normalization: 202 | batched_features['un_mask'] = torch.tensor(np.array([feature['un_mask'] for feature in new_features])).bool() 203 | 204 | return batched_features 205 | 206 | 207 | class LearningRateScheduler: 208 | r""" 209 | Learning rate scheduler with warmup. 210 | 211 | :param warmup: if ``warmup`` is an integer, ``warmup`` stands for warmup steps, if ``warmup`` is a float, 212 | such as 0.1, then it stands for warmup_ratio. 213 | :param schedule: the learning rate will be adjusted according to ``schedule`` strategy, 214 | which can be: linear or constant. 215 | """ 216 | 217 | def __init__(self, 218 | warmup: float, 219 | schedule: str, 220 | learning_rate: float, 221 | n_steps: int = 0): 222 | 223 | self.warmup = max(warmup, 0.) 224 | self.schedule = schedule 225 | self.initial_lr = learning_rate 226 | 227 | if self.warmup > 1: 228 | self.warmup = self.warmup / n_steps 229 | self.t_steps = max(2, n_steps) 230 | 231 | if self.schedule == 'constant': 232 | self.get_lr = self._get_constant_lr 233 | elif self.schedule == 'linear': 234 | self.get_lr = self._get_linear_lr 235 | else: 236 | raise NotImplementedError("Only support 'linear', 'constant'.") 237 | 238 | def _get_constant_lr(self, progress): 239 | if progress < self.warmup: 240 | return progress / self.warmup 241 | return 1 242 | 243 | def _get_linear_lr(self, progress): 244 | if progress < self.warmup: 245 | return progress / self.warmup 246 | return max((progress - 1.) / (self.warmup - 1.), 0.) 
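The warmup and decay branches of `LearningRateScheduler` above are easy to sanity-check in isolation. A small illustrative sketch (not part of the repo; `linear_lr_factor` is a hypothetical stand-alone copy of the `_get_linear_lr` formula, where `progress = global_step / t_steps`):

```python
# Stand-alone copy of the schedule in LearningRateScheduler._get_linear_lr:
# the returned value multiplies the initial learning rate.
def linear_lr_factor(progress, warmup):
    if progress < warmup:
        return progress / warmup  # linear warmup from 0 up to 1
    # linear decay: 1.0 at the end of warmup, 0.0 at the end of training
    return max(0.0, (progress - 1.0) / (warmup - 1.0))

# With warmup=0.1 the multiplier peaks after 10% of the steps, then decays.
print([round(linear_lr_factor(p / 100, 0.1), 2) for p in (0, 5, 10, 55, 100)])
# [0.0, 0.5, 1.0, 0.5, 0.0]
```

Note that when `warmup > 1` the constructor divides it by `n_steps`, so an integer warmup-step count and a warmup ratio both end up as a fraction of total steps before these formulas are applied.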
247 | 248 | def step(self, global_step): 249 | progress = global_step / self.t_steps 250 | return self.initial_lr * self.get_lr(progress) 251 | 252 | 253 | class WandbLogger: 254 | """ 255 | A class that logs information with wandb. 256 | 257 | :param training_args: the arguments passed to the Trainer 258 | """ 259 | 260 | def __init__(self, training_args): 261 | self.training_args = training_args 262 | # report_to is a list 263 | self.able = "wandb" in getattr(training_args, "report_to", []) 264 | if self.able and 'wandb' not in sys.modules: 265 | raise ModuleNotFoundError( 266 | "Detected Wandb not installed while you have set " 267 | "`report_to=['wandb']` in your training config. Please " 268 | "either set `report_to` to another value or install wandb.") 269 | 270 | def init(self, *args, **kwargs): 271 | if self.able: 272 | wandb.init(*args, **kwargs) 273 | 274 | def log(self, *args, **kwargs): 275 | if self.able: 276 | wandb.log(*args, **kwargs) 277 | 278 | def set_summary(self, key, value): 279 | if self.able: 280 | wandb.run.summary[key] = value 281 | 282 | 283 | class DynamicLossScaler: 284 | def __init__(self, 285 | init_scale=2 ** 32, 286 | scale_factor=2., 287 | scale_window=1000, 288 | min_scale=1, 289 | delayed_shift=1, 290 | consecutive_hysteresis=False, 291 | raise_error_at_min_scale=True, 292 | dtype=torch.half): 293 | self.cur_scale = init_scale 294 | self.cur_iter = 0 295 | self.last_overflow_iter = -1 296 | self.scale_factor = scale_factor 297 | self.scale_window = scale_window 298 | self.min_scale = min_scale 299 | self.delayed_shift = delayed_shift 300 | self.cur_hysteresis = delayed_shift 301 | self.consecutive_hysteresis = consecutive_hysteresis 302 | self.raise_error_at_min_scale = raise_error_at_min_scale 303 | self.dtype = dtype 304 | self.has_overflow_serial = False 305 | 306 | @property 307 | def loss_scale(self): 308 | return self.cur_scale 309 | 310 | # `x` is a torch.Tensor 311 | def _has_inf_or_nan(self, x): 312 | try: 313 | # if x is half, the .float() incurs an additional deep 
copy, but it's necessary if 314 | # Pytorch's .sum() creates a one-element tensor of the same type as x 315 | # (which is true for some recent version of pytorch). 316 | cpu_sum = float(x.float().sum()) 317 | # More efficient version that can be used if .sum() returns a Python scalar 318 | # cpu_sum = float(x.sum()) 319 | except RuntimeError as instance: 320 | # We want to check if inst is actually an overflow exception. 321 | # RuntimeError could come from a different error. 322 | # If so, we still want the exception to propagate. 323 | if "value cannot be converted" not in instance.args[0]: 324 | raise 325 | return True 326 | else: 327 | if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum: 328 | return True 329 | return False 330 | 331 | # `overflow` is boolean indicating whether the gradient overflowed 332 | def update_scale(self, overflow): 333 | if overflow: 334 | # self.cur_scale /= self.scale_factor 335 | if self.delayed_shift == 1 or self.cur_hysteresis == 1: 336 | if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale: 337 | raise Exception( 338 | "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.") 339 | else: 340 | next_scale = max(self.cur_scale / self.scale_factor, self.min_scale) 341 | if torch.distributed.get_rank() == 0: 342 | overflow_msg = f"[deepspeed] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." 343 | if self.dtype == torch.half: 344 | overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}" 345 | print(overflow_msg) 346 | self.cur_scale = next_scale 347 | else: 348 | if torch.distributed.get_rank() == 0: 349 | overflow_msg = f"[deepspeed] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." 350 | if self.dtype == torch.half: 351 | overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. 
Reducing hysteresis to {self.cur_hysteresis - 1}" 352 | print(overflow_msg) 353 | self.cur_hysteresis -= 1 354 | self.last_overflow_iter = self.cur_iter 355 | else: 356 | if self.consecutive_hysteresis: 357 | if torch.distributed.get_rank() == 0: 358 | hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}" 359 | print(hysteresis_msg) 360 | self.cur_hysteresis = self.delayed_shift 361 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 362 | if not self.consecutive_hysteresis: 363 | self.cur_hysteresis = self.delayed_shift 364 | self.cur_scale *= self.scale_factor 365 | self.cur_iter += 1 366 | 367 | 368 | def get_loss(logits, labels, clip_loss_value=None): 369 | # Shift so that tokens < n predict n 370 | shift_logits = logits[..., :-1, :].contiguous() 371 | shift_labels = labels[:, 1:].contiguous() 372 | # Flatten the tokens 373 | if clip_loss_value is not None: 374 | loss_fct = CrossEntropyLoss(reduction='none') 375 | loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), 376 | shift_labels.view(-1).cuda()) 377 | loss.data.clamp_(min=-clip_loss_value, max=clip_loss_value) 378 | loss = loss.mean() 379 | else: 380 | loss_fct = CrossEntropyLoss() 381 | loss = loss_fct(shift_logits.view(shift_labels.shape[0] * shift_labels.shape[1], -1), 382 | shift_labels.view(-1).cuda()) 383 | return loss -------------------------------------------------------------------------------- /lomo_optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adalomo import AdaLomo 2 | from .lomo import Lomo 3 | 4 | __version__ = "0.1.1" 5 | __all__ = ["Lomo", "AdaLomo"] 6 | -------------------------------------------------------------------------------- /lomo_optim/adalomo.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from 
torch.optim import Optimizer 6 | 7 | try: 8 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 9 | except ImportError: 10 | from transformers.deepspeed import is_deepspeed_zero3_enabled 11 | 12 | from transformers.utils import logging 13 | 14 | 15 | class AdaLomo(Optimizer): 16 | """ 17 | A custom optimizer, AdaLomo, for gradient updates in distributed training. 18 | 19 | The class implements two gradient-update functions, :meth:`fuse_update` and :meth:`fuse_update_zero3`, used in the non-ZeRO and ZeRO-3 settings respectively. 20 | 21 | :param model: the model to be optimized 22 | :param lr: learning rate, defaults to 1e-3 23 | :param eps: regularization terms. eps[0] keeps the squared gradient from becoming too small, and eps[1] bounds the step size when the learning rate is scaled by the parameter RMS 24 | :param clip_threshold: threshold used when normalizing the update matrix 25 | :param decay_rate: decay rate of the running average of squared gradients 26 | :param clip_grad_norm: norm threshold for gradient clipping 27 | 28 | .. note:: 29 | 30 | clip_grad_norm must be positive 31 | :param clip_grad_value: value threshold for gradient clipping 32 | :param weight_decay: weight decay coefficient, defaults to 0.0 33 | :param loss_scale: loss-scaling factor; it can improve training precision, but an overly large value may produce NaN 34 | """ 35 | 36 | def __init__( 37 | self, 38 | model, 39 | lr=1e-3, 40 | loss_scale=2**10, 41 | eps=(1e-30, 1e-3), 42 | clip_threshold=1.0, 43 | decay_rate=-0.8, 44 | clip_grad_norm=None, 45 | clip_grad_value=None, 46 | weight_decay=0.0, 47 | ): 48 | self.model = model 49 | self.lr = lr 50 | self.clip_grad_norm = clip_grad_norm 51 | self.clip_grad_value = clip_grad_value 52 | self.weight_decay = weight_decay 53 | self.loss_scale = loss_scale 54 | if self.weight_decay > 0.0: 55 | self.do_weight_decay = True 56 | else: 57 | self.do_weight_decay = False 58 | self.eps = eps 59 | self.step_num = 0 60 | self.decay_rate = decay_rate 61 | self.clip_threshold = clip_threshold 62 | 63 | # for grad norm 64 | if self.clip_grad_norm is not None and self.clip_grad_norm <= 0: 65 | raise ValueError( 66 | f"clip_grad_norm should be positive, got {self.clip_grad_norm}." 
67 | ) 68 | self.gather_norm = False 69 | self.grad_norms = [] 70 | self.clip_coef = None 71 | 72 | # check if zero3 is enabled 73 | self.zero3_enabled = is_deepspeed_zero3_enabled() 74 | if self.zero3_enabled: # zero3 is enabled 75 | self.grad_func = self.fuse_update_zero3() 76 | else: 77 | self.grad_func = self.fuse_update() 78 | 79 | self.exp_avg_sq = {} 80 | self.exp_avg_sq_row = {} 81 | self.exp_avg_sq_col = {} 82 | 83 | # register hook function, which will be called through the backward process 84 | for n, p in self.model.named_parameters(): 85 | if self.zero3_enabled: 86 | if len(p.ds_shape) == 1: 87 | self.exp_avg_sq[n] = torch.zeros(p.ds_shape[0], dtype=torch.float32).cuda() 88 | else: 89 | self.exp_avg_sq_row[n] = torch.zeros(p.ds_shape[0], dtype=torch.float32).cuda() 90 | self.exp_avg_sq_col[n] = torch.zeros(p.ds_shape[1], dtype=torch.float32).cuda() 91 | else: 92 | if len(p.data.shape) == 1: 93 | self.exp_avg_sq[n] = torch.zeros(p.data.shape[0], dtype=torch.float32).cuda() 94 | else: 95 | self.exp_avg_sq_row[n] = torch.zeros(p.data.shape[0], dtype=torch.float32).cuda() 96 | self.exp_avg_sq_col[n] = torch.zeros(p.data.shape[1], dtype=torch.float32).cuda() 97 | 98 | if p.requires_grad: 99 | p.register_hook(self.grad_func) 100 | defaults = dict( 101 | lr=lr, 102 | eps=eps, 103 | weight_decay=weight_decay, 104 | clip_grad_norm=clip_grad_norm, 105 | clip_grad_value=clip_grad_value, 106 | ) 107 | super(AdaLomo, self).__init__(self.model.parameters(), defaults) 108 | 109 | @staticmethod 110 | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): 111 | # copy from fairseq's adafactor implementation: 112 | # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 113 | r_factor = ( 114 | (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) 115 | .rsqrt_() 116 | .unsqueeze(-1) 117 | ) 118 | c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() 119 | return torch.mul(r_factor, c_factor) 120 | 
121 | @staticmethod 122 | def _rms(tensor): 123 | return tensor.norm(2) / (tensor.numel() ** 0.5) 124 | 125 | def fuse_update(self): 126 | """ 127 | Update the model parameters in non-ZeRO mode. 128 | 129 | :return: func, a closure that applies the gradient update to the model parameters 130 | """ 131 | 132 | def func(x): 133 | """ 134 | Closure that applies the gradient update to the model parameters. 135 | """ 136 | with torch.no_grad(): 137 | for n, p in self.model.named_parameters(): 138 | if p.requires_grad and p.grad is not None: 139 | grad_fp32 = p.grad.to(torch.float32) 140 | p.grad = None 141 | if self.loss_scale: 142 | grad_fp32.div_(self.loss_scale) 143 | if self.gather_norm: 144 | # we adopt two backward passes for gradient norm computation and parameter update, respectively. 145 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 146 | else: 147 | # grad clip or norm 148 | if ( 149 | self.clip_grad_value is not None 150 | and self.clip_grad_value > 0 151 | ): 152 | # Clipping gradients by their value 153 | grad_fp32.clamp_( 154 | min=-self.clip_grad_value, max=self.clip_grad_value 155 | ) 156 | if ( 157 | self.clip_grad_norm is not None 158 | and self.clip_grad_norm > 0 159 | and self.clip_coef is not None 160 | ): 161 | # Normalize the gradient according to its norm (computed in another pass) 162 | grad_fp32.mul_(self.clip_coef) 163 | 164 | # To avoid math errors for edge cases 165 | if self.step_num == 0 and self.decay_rate < 0: 166 | decay_rate = - self.decay_rate 167 | else: 168 | decay_rate = self.decay_rate 169 | 170 | beta2t = 1.0 - math.pow(self.step_num, decay_rate) 171 | update = (grad_fp32**2) + self.eps[0] 172 | 173 | if len(p.data.shape) > 1: 174 | self.exp_avg_sq_row[n].mul_(beta2t).add_( 175 | update.mean(dim=-1), alpha=1.0 - beta2t 176 | ) 177 | self.exp_avg_sq_col[n].mul_(beta2t).add_( 178 | update.mean(dim=-2), alpha=1.0 - beta2t 179 | ) 180 | update = self._approx_sq_grad( 181 | self.exp_avg_sq_row[n], self.exp_avg_sq_col[n] 182 | ) 183 | update.mul_(grad_fp32) 184 | else: 185 | self.exp_avg_sq[n].mul_(beta2t).add_( 186 | update, alpha=1.0 - beta2t 187 | 
) 188 | update = self.exp_avg_sq[n].rsqrt().mul_(grad_fp32) 189 | 190 | update.div_( 191 | (self._rms(update) / self.clip_threshold).clamp_( 192 | min=1.0 193 | ) 194 | ) 195 | 196 | p_fp32 = p.data.to(torch.float32) 197 | p_rms = torch.norm(p_fp32, 2.0) / math.sqrt(p.numel()) 198 | lr = self.lr 199 | param_scale = max(self.eps[1], p_rms) 200 | lr = lr * param_scale 201 | 202 | if self.do_weight_decay: 203 | p_fp32.mul_(1.0 - lr * self.weight_decay) 204 | p_fp32.add_(update, alpha=-lr) 205 | p.data.copy_(p_fp32) 206 | 207 | return x 208 | 209 | return func 210 | 211 | def fuse_update_zero3(self): 212 | """ 213 | Update the model parameters in ZeRO-3 mode. 214 | 215 | :return: func, a closure that applies the gradient update to the model parameters. 216 | """ 217 | 218 | def func(x): 219 | with torch.no_grad(): 220 | for n, p in self.model.named_parameters(): 221 | if p.grad is not None: 222 | torch.distributed.all_reduce( 223 | p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False 224 | ) 225 | 226 | grad_fp32 = p.grad.to(torch.float32) 227 | p.grad = None 228 | if self.loss_scale: 229 | grad_fp32.div_(self.loss_scale) 230 | 231 | if self.gather_norm: 232 | # we adopt two backward passes for gradient norm computation and parameter update, respectively. 
233 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 234 | else: # update param 235 | partition_size = p.ds_tensor.numel() 236 | start = partition_size * dist.get_rank() 237 | end = min(start + partition_size, grad_fp32.numel()) 238 | 239 | if self.clip_grad_value is not None: 240 | # Clipping gradients by their value 241 | grad_fp32.clamp_( 242 | min=-self.clip_grad_value, max=self.clip_grad_value 243 | ) 244 | if ( 245 | self.clip_grad_norm is not None 246 | and self.clip_grad_norm > 0 247 | and self.clip_coef is not None 248 | ): 249 | # Normalize the gradient according to its norm (computed in another pass) 250 | grad_fp32.mul_(self.clip_coef) 251 | 252 | # To avoid math errors for edge cases 253 | if self.step_num == 0 and self.decay_rate < 0: 254 | decay_rate = - self.decay_rate 255 | else: 256 | decay_rate = self.decay_rate 257 | beta2t = 1.0 - math.pow(self.step_num, decay_rate) 258 | update = (grad_fp32**2) + self.eps[0] # TODO: switch to addcmul_ 259 | 260 | if len(p.ds_shape) > 1: 261 | self.exp_avg_sq_row[n].mul_(beta2t).add_( 262 | update.mean(dim=-1), alpha=1.0 - beta2t 263 | ) 264 | self.exp_avg_sq_col[n].mul_(beta2t).add_( 265 | update.mean(dim=-2), alpha=1.0 - beta2t 266 | ) 267 | update = self._approx_sq_grad( 268 | self.exp_avg_sq_row[n], self.exp_avg_sq_col[n] 269 | ) 270 | update.mul_(grad_fp32) 271 | else: 272 | self.exp_avg_sq[n].mul_(beta2t).add_( 273 | update, alpha=1.0 - beta2t 274 | ) 275 | update = self.exp_avg_sq[n].rsqrt().mul_(grad_fp32) 276 | 277 | update.div_( 278 | (self._rms(update) / self.clip_threshold).clamp_( 279 | min=1.0 280 | ) 281 | ) 282 | 283 | one_dim_update = update.view(-1) 284 | partitioned_update = one_dim_update.narrow( 285 | 0, start, end - start 286 | ) 287 | param_fp32 = p.ds_tensor.to(torch.float32) 288 | partitioned_p = param_fp32.narrow(0, 0, end - start) 289 | 290 | p_rms = torch.norm(partitioned_p, 2.0) ** 2 291 | dist.all_reduce(p_rms, op=torch.distributed.ReduceOp.SUM) 292 | p_rms = (p_rms / p.ds_numel).sqrt() 293 | 
294 | lr = self.lr 295 | param_scale = max(self.eps[1], p_rms) 296 | lr = lr * param_scale 297 | 298 | if self.do_weight_decay: 299 | partitioned_p.mul_(1.0 - lr * self.weight_decay) 300 | partitioned_p.add_(partitioned_update, alpha=-lr) 301 | p.ds_tensor[: end - start] = partitioned_p 302 | 303 | return x 304 | 305 | return func 306 | 307 | def fused_backward(self, loss, lr): 308 | """ 309 | Run one backward pass and apply the fused parameter update. 310 | 311 | :param loss: the model's loss value 312 | :param lr: learning rate 313 | """ 314 | self.lr = lr 315 | if self.loss_scale: 316 | loss = loss * self.loss_scale 317 | self.step_num += 1 318 | loss.backward() 319 | # update the last parameter, since the last parameter in the computation graph is not ready when the hook functions are called 320 | # the argument of grad_func is just a placeholder, and it can be anything. 321 | self.grad_func(0) 322 | 323 | def grad_norm(self, loss): 324 | """ 325 | Compute the norm of the gradients. 326 | 327 | :param loss: the model's loss value 328 | """ 329 | self.gather_norm = True 330 | self.grad_norms = [] 331 | if self.loss_scale: 332 | loss = loss * self.loss_scale 333 | loss.backward(retain_graph=True) 334 | # update the last parameter, since the last parameter in the computation graph is not ready when the hook functions are called 335 | # the argument of grad_func is just a placeholder, and it can be anything. 336 | self.grad_func(0) 337 | 338 | with torch.no_grad(): 339 | # The norm is computed over all gradients together, as if they were 340 | # concatenated into a single vector. Gradients are modified in-place. 
341 | self.grad_norms = torch.stack(self.grad_norms) 342 | 343 | total_norm = torch.norm(self.grad_norms, 2.0) 344 | self.clip_coef = float(self.clip_grad_norm) / (total_norm + 1e-6) 345 | self.clip_coef = torch.clamp(self.clip_coef, max=1.0) 346 | self.gather_norm = False 347 | -------------------------------------------------------------------------------- /lomo_optim/lomo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.optim import Optimizer 4 | 5 | try: 6 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 7 | except ImportError: 8 | from transformers.deepspeed import is_deepspeed_zero3_enabled 9 | 10 | from transformers.utils import logging 11 | 12 | logger = logging.get_logger() 13 | 14 | 15 | class Lomo(Optimizer): 16 | """ 17 | A custom optimizer, Lomo, for fused gradient updates in distributed training. 18 | 19 | The class implements two gradient-update functions, :meth:`fuse_update` and :meth:`fuse_update_zero3`, used in non-ZeRO and ZeRO mode respectively. 20 | 21 | :param model: the model to optimize 22 | :param lr: learning rate, defaults to 1e-3 23 | :param clip_grad_norm: norm threshold for gradient clipping 24 | 25 | .. note:: 26 | 27 | clip_grad_norm must be positive 28 | :param weight_decay: weight decay coefficient, defaults to 0.0 29 | :param clip_grad_value: value threshold for gradient clipping 30 | :param loss_scale_args: keyword arguments used to initialize :class:`DynamicLossScaler` 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model, 36 | lr=1e-3, 37 | clip_grad_norm=None, 38 | clip_grad_value=None, 39 | weight_decay=0.0, 40 | loss_scale_args={}, 41 | ): 42 | self.model = model 43 | self.lr = lr 44 | self.clip_grad_norm = clip_grad_norm 45 | self.clip_grad_value = clip_grad_value 46 | self.loss_scaler = None 47 | self.loss_scale_args = loss_scale_args 48 | self.weight_decay = weight_decay 49 | if self.weight_decay > 0.0: 50 | self.do_weight_decay = True 51 | else: 52 | self.do_weight_decay = False 53 | 54 | # for grad norm 55 | if self.clip_grad_norm is not None and self.clip_grad_norm <= 0: 56 | raise ValueError( 57 | f"clip_grad_norm should be positive, got {self.clip_grad_norm}." 58 | ) 59 | self.gather_norm = False 60 | self.grad_norms = [] 61 | self.clip_coef = None 62 | 63 | # check if zero3 is enabled 64 | self.zero3_enabled = is_deepspeed_zero3_enabled() 65 | if self.zero3_enabled: # zero3 is enabled 66 | self.grad_func = self.fuse_update_zero3() 67 | else: 68 | self.grad_func = self.fuse_update() 69 | self.first_backward = True # check bf16 or fp16 in the first backward 70 | 71 | # register hook function, which will be called through the backward process 72 | for n, p in self.model.named_parameters(): 73 | if p.requires_grad: 74 | p.register_hook(self.grad_func) 75 | defaults = dict( 76 | lr=lr, clip_grad_norm=clip_grad_norm, clip_grad_value=clip_grad_value 77 | ) 78 | super(Lomo, self).__init__(self.model.parameters(), defaults) 79 | 80 | def fuse_update(self): 81 | """ 82 | Update the model parameters in non-ZeRO mode. 83 | 84 | :return: func, a closure that applies the gradient update to the model parameters 85 | """ 86 | 87 | def func(x): 88 | """ 89 | Closure that applies the gradient update to the model parameters. 90 | """ 91 | with torch.no_grad(): 92 | for n, p in self.model.named_parameters(): 93 | if p.requires_grad and p.grad is not None: 94 | if self.loss_scaler
and ( 95 | self.loss_scaler.has_overflow_serial 96 | or self.loss_scaler._has_inf_or_nan(p.grad) 97 | ): 98 | # if an overflow is detected, drop the gradient 99 | p.grad = None 100 | self.loss_scaler.has_overflow_serial = True 101 | break 102 | grad_fp32 = p.grad.to(torch.float32) 103 | p.grad = None 104 | if self.loss_scaler: 105 | grad_fp32.div_(self.loss_scaler.loss_scale) 106 | if self.gather_norm: 107 | # we adopt two backward passes for gradient norm computation and parameter update, respectively. 108 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 109 | else: 110 | if ( 111 | self.clip_grad_value is not None 112 | and self.clip_grad_value > 0 113 | ): 114 | # Clipping gradients by their value 115 | grad_fp32.clamp_( 116 | min=-self.clip_grad_value, max=self.clip_grad_value 117 | ) 118 | if ( 119 | self.clip_grad_norm is not None 120 | and self.clip_grad_norm > 0 121 | and self.clip_coef is not None 122 | ): 123 | # Normalize the gradient according to its norm (computed in another pass) 124 | grad_fp32.mul_(self.clip_coef) 125 | p_fp32 = p.data.to(torch.float32) 126 | if self.do_weight_decay: 127 | p_fp32.mul_(1.0 - self.lr * self.weight_decay) 128 | p_fp32.add_(grad_fp32, alpha=-self.lr) 129 | p.data.copy_(p_fp32) 130 | 131 | return x 132 | 133 | return func 134 | 135 | def fuse_update_zero3(self): 136 | """ 137 | Update the model parameters in ZeRO-3 mode. 138 | 139 | :return: func, a closure that applies the gradient update to the model parameters. 140 | """ 141 | 142 | def func(x): 143 | with torch.no_grad(): 144 | for n, p in self.model.named_parameters(): 145 | if p.grad is not None: 146 | torch.distributed.all_reduce( 147 | p.grad, op=torch.distributed.ReduceOp.AVG, async_op=False 148 | ) 149 | if self.loss_scaler and ( 150 | self.loss_scaler.has_overflow_serial 151 | or self.loss_scaler._has_inf_or_nan(p.grad) 152 | ): 153 | # if an overflow is detected, drop the gradient 154 | p.grad = None 155 | self.loss_scaler.has_overflow_serial = True 156 | break 157 | 158 | grad_fp32 = p.grad.to(torch.float32) 159 | p.grad = None 
160 | param_fp32 = p.ds_tensor.to(torch.float32) 161 | if self.loss_scaler: 162 | grad_fp32.div_(self.loss_scaler.loss_scale) 163 | 164 | if self.gather_norm: 165 | # we adopt two backward passes for gradient norm computation and parameter update, respectively. 166 | self.grad_norms.append(torch.norm(grad_fp32, 2.0)) 167 | else: # update param 168 | one_dim_grad_fp32 = grad_fp32.view(-1) 169 | partition_size = p.ds_tensor.numel() 170 | start = partition_size * dist.get_rank() 171 | end = min(start + partition_size, grad_fp32.numel()) 172 | partitioned_grad_fp32 = one_dim_grad_fp32.narrow( 173 | 0, start, end - start 174 | ) 175 | 176 | if self.clip_grad_value is not None: 177 | # Clipping gradients by their value 178 | partitioned_grad_fp32.clamp_( 179 | min=-self.clip_grad_value, max=self.clip_grad_value 180 | ) 181 | if ( 182 | self.clip_grad_norm is not None 183 | and self.clip_grad_norm > 0 184 | and self.clip_coef is not None 185 | ): 186 | # Normalize the gradient according to its norm (computed in another pass) 187 | partitioned_grad_fp32.mul_(self.clip_coef) 188 | 189 | partitioned_p = param_fp32.narrow(0, 0, end - start) 190 | if self.do_weight_decay: 191 | partitioned_p.mul_(1.0 - self.lr * self.weight_decay) 192 | partitioned_p.add_(partitioned_grad_fp32, alpha=-self.lr) 193 | p.ds_tensor[: end - start] = partitioned_p 194 | return x 195 | 196 | return func 197 | 198 | def fused_backward(self, loss, lr): 199 | """ 200 | Run one backward pass and apply the fused parameter update. 201 | 202 | :param loss: the model's loss value 203 | :param lr: learning rate 204 | """ 205 | if self.first_backward: 206 | self.first_backward = False 207 | if loss.dtype == torch.float16: 208 | self.loss_scaler = DynamicLossScaler(**self.loss_scale_args) 209 | if self.clip_grad_norm is None: 210 | self.clip_grad_norm = 1.0 211 | logger.warning( 212 | "Loss scale is recommended to be used with grad norm to get better performance. " 213 | "Set grad norm to 1.0." 
214 | ) 215 | self.lr = lr 216 | # Users need to call grad_norm themselves before calling fused_backward 217 | if ( 218 | self.clip_grad_norm is not None 219 | and self.clip_grad_norm > 0 220 | and self.clip_coef is None 221 | ): 222 | raise ValueError( 223 | "clip_grad_norm is not None, but clip_coef is None. " 224 | "Please call optimizer.grad_norm() before optimizer.fused_backward()." 225 | ) 226 | if self.loss_scaler: 227 | loss = loss * self.loss_scaler.loss_scale 228 | loss.backward() 229 | # update the last parameter, since the last parameter in the computation graph is not ready when the hook functions are called 230 | # the argument of grad_func is just a placeholder, and it can be anything. 231 | self.grad_func(0) 232 | 233 | def grad_norm(self, loss): 234 | """ 235 | Compute the norm of the gradients. 236 | 237 | :param loss: the model's loss value 238 | """ 239 | if self.first_backward: 240 | self.first_backward = False 241 | if loss.dtype == torch.float16: 242 | self.loss_scaler = DynamicLossScaler(**self.loss_scale_args) 243 | 244 | self.gather_norm = True 245 | self.grad_norms = [] 246 | if self.loss_scaler: 247 | self.loss_scaler.has_overflow_serial = False 248 | loss = loss * self.loss_scaler.loss_scale 249 | loss.backward(retain_graph=True) 250 | # update the last parameter, since the last parameter in the computation graph is not ready when the hook functions are called 251 | # the argument of grad_func is just a placeholder, and it can be anything. 252 | self.grad_func(0) 253 | 254 | if self.loss_scaler and self.loss_scaler.has_overflow_serial: 255 | self.loss_scaler.update_scale(overflow=True) 256 | with torch.no_grad():  # clear gradients 257 | for n, p in self.model.named_parameters(): 258 | p.grad = None 259 | return 260 | 261 | with torch.no_grad(): 262 | # The norm is computed over all gradients together, as if they were 263 | # concatenated into a single vector. Gradients are modified in-place. 
264 | self.grad_norms = torch.stack(self.grad_norms) 265 | 266 | total_norm = torch.norm(self.grad_norms, 2.0) 267 | self.clip_coef = float(self.clip_grad_norm) / (total_norm + 1e-6) 268 | self.clip_coef = torch.clamp(self.clip_coef, max=1.0) 269 | self.gather_norm = False 270 | 271 | 272 | class DynamicLossScaler: 273 | """ 274 | Dynamic loss scaler that adjusts the loss scale during training. 275 | 276 | :param init_scale: initial loss scale 277 | :param scale_factor: factor by which the scale is increased or decreased 278 | :param scale_window: number of consecutive overflow-free iterations between scale increases 279 | :param min_scale: minimum loss scale, defaults to 1 280 | :param delayed_shift: hysteresis, defaults to 1 281 | :param consecutive_hysteresis: whether to enable consecutive hysteresis, defaults to False. If True, handling a gradient overflow is delayed by :attr:`delayed_shift` iterations. 282 | :param raise_error_at_min_scale: whether to raise an exception when the loss scale is already at its minimum, defaults to True 283 | :param dtype: data type, defaults to torch.half 284 | """ 285 | 286 | def __init__( 287 | self, 288 | init_scale=2**32, 289 | scale_factor=2.0, 290 | scale_window=1000, 291 | min_scale=1, 292 | delayed_shift=1, 293 | consecutive_hysteresis=False, 294 | raise_error_at_min_scale=True, 295 | dtype=torch.half, 296 | ): 297 | self.cur_scale = init_scale 298 | self.cur_iter = 0 299 | self.last_overflow_iter = -1 300 | self.scale_factor = scale_factor 301 | self.scale_window = scale_window 302 | self.min_scale = min_scale 303 | self.delayed_shift = delayed_shift 304 | self.cur_hysteresis = delayed_shift 305 | self.consecutive_hysteresis = consecutive_hysteresis 306 | self.raise_error_at_min_scale = raise_error_at_min_scale 307 | self.dtype = dtype 308 | self.has_overflow_serial = False 309 | 310 | @property 311 | def loss_scale(self): 312 | return self.cur_scale 313 | 314 | # `x` is a torch.Tensor 315 | def _has_inf_or_nan(self, x): 316 | try: 317 | # if x is half, the .float() incurs an additional deep copy, but it's necessary if 318 | # Pytorch's .sum() creates a one-element tensor of the same type as x 319 | # (which is true for some recent version of pytorch). 
320 | cpu_sum = float(x.float().sum()) 321 | # More efficient version that can be used if .sum() returns a Python scalar 322 | # cpu_sum = float(x.sum()) 323 | except RuntimeError as instance: 324 | # We want to check if inst is actually an overflow exception. 325 | # RuntimeError could come from a different error. 326 | # If so, we still want the exception to propagate. 327 | if "value cannot be converted" not in instance.args[0]: 328 | raise 329 | return True 330 | else: 331 | if cpu_sum in [float("inf"), -float("inf")] or cpu_sum != cpu_sum: 332 | return True 333 | return False 334 | 335 | # `overflow` is boolean indicating whether the gradient overflowed 336 | def update_scale(self, overflow): 337 | if overflow: 338 | # self.cur_scale /= self.scale_factor 339 | if self.delayed_shift == 1 or self.cur_hysteresis == 1: 340 | if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale: 341 | raise Exception( 342 | "Current loss scale already at minimum - cannot decrease scale anymore. Exiting run." 343 | ) 344 | else: 345 | next_scale = max(self.cur_scale / self.scale_factor, self.min_scale) 346 | if torch.distributed.get_rank() == 0: 347 | overflow_msg = f"[LOMO] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." 348 | if self.dtype == torch.half: 349 | overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, reducing to {int(next_scale)}" 350 | print(overflow_msg) 351 | self.cur_scale = next_scale 352 | else: 353 | if torch.distributed.get_rank() == 0: 354 | overflow_msg = f"[LOMO] OVERFLOW! Rank {torch.distributed.get_rank()} Skipping step." 355 | if self.dtype == torch.half: 356 | overflow_msg += f" Attempted loss scale: {int(self.cur_scale)}, but hysteresis is {self.cur_hysteresis}. 
Reducing hysteresis to {self.cur_hysteresis - 1}" 357 | print(overflow_msg) 358 | self.cur_hysteresis -= 1 359 | self.last_overflow_iter = self.cur_iter 360 | else: 361 | if self.consecutive_hysteresis: 362 | if torch.distributed.get_rank() == 0: 363 | hysteresis_msg = f"Consecutive hysteresis is enabled. Restoring hysteresis to {self.delayed_shift}" 364 | print(hysteresis_msg) 365 | self.cur_hysteresis = self.delayed_shift 366 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 367 | if not self.consecutive_hysteresis: 368 | self.cur_hysteresis = self.delayed_shift 369 | self.cur_scale *= self.scale_factor 370 | self.cur_iter += 1 371 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "lomo-optim" 7 | authors = [ 8 | {name = "Kai Lv", email = "klv21@m.fudan.edu.cn"}, 9 | ] 10 | description = "LOMO: LOw-Memory Optimization" 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | license = {file = "LICENSE"} 14 | classifiers = [ 15 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 16 | ] 17 | dependencies = [ 18 | "transformers", "torch" 19 | ] 20 | dynamic = ["version"] 21 | 22 | [project.urls] 23 | Homepage = "https://github.com/OpenLMLab/LOMO" 24 | Documentation = "https://openlmlab-collie.readthedocs.io/zh-cn/latest/api/generated/collie.optim.Lomo.html" 25 | Repository = "https://github.com/OpenLMLab/LOMO.git" 26 | 27 | [tool.setuptools] 28 | packages = ["lomo_optim"] 29 | 30 | [tool.setuptools.dynamic] 31 | version = {attr = "lomo_optim.__version__"} 32 | --------------------------------------------------------------------------------
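The dynamic loss-scaling schedule implemented by `DynamicLossScaler` in `lomo_optim/lomo.py` can be hard to follow from the hook-driven code alone. The sketch below re-implements just its core rule in plain Python: shrink the scale when an overflow is reported, and grow it again after `scale_window` consecutive overflow-free iterations. The hysteresis, overflow detection, and rank-0 logging are deliberately omitted, and `SimpleLossScaler` is a name invented here for illustration, not part of the package.

```python
# Minimal sketch of the scaling schedule in DynamicLossScaler.update_scale
# (hysteresis, _has_inf_or_nan, and distributed logging omitted).
class SimpleLossScaler:
    def __init__(self, init_scale=2**16, scale_factor=2.0, scale_window=1000, min_scale=1):
        self.cur_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            # Overflow: shrink the scale (never below min_scale); the step is skipped.
            self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
            # A full window without overflow: try a larger scale again.
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1


scaler = SimpleLossScaler(init_scale=1024, scale_window=2)
scaler.update_scale(overflow=True)   # overflow: 1024 -> 512.0
scaler.update_scale(overflow=False)  # only one clean iteration: unchanged
scaler.update_scale(overflow=False)  # window complete: 512.0 -> 1024.0
print(scaler.cur_scale)              # -> 1024.0
```

The point of the oscillation is to keep the scale as large as fp16 gradients allow: a large scale preserves small gradient values from underflowing, while the backoff on overflow keeps training from producing inf/NaN steps.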