├── .gitattributes
├── .gitignore
├── README.md
├── configs
│   └── gpt_2_L_6
│       └── config.json
├── deepspeed_configs
│   └── zero_stage1_config.json
├── fit_scale_loss_prediction.py
├── metric
│   ├── __init__.py
│   ├── accuracy
│   │   └── accuracy.py
│   ├── glue
│   │   └── glue.py
│   └── perplexity
│       └── perplexity.py
├── modeling
│   ├── __init__.py
│   ├── deepspeed_mup.py
│   ├── initialize_with_mup.py
│   ├── lm_mup.py
│   ├── modeling_gpt2_mup.py
│   ├── mup_utils.py
│   └── utils.py
├── mup_trainer.py
├── requirements.txt
├── res
│   └── final_data
│       └── test
│           ├── current_data_args.json
│           ├── lm
│           │   ├── dataset.arrow
│           │   ├── dataset_info.json
│           │   └── state.json
│           ├── mt
│           │   ├── dataset.arrow
│           │   ├── dataset_info.json
│           │   └── state.json
│           └── tokenizer
│               ├── added_tokens.json
│               ├── merges.txt
│               ├── special_tokens_map.json
│               ├── tokenizer_config.json
│               └── vocab.json
├── run_eval_ppl_loss_pred.sh
├── run_eval_ppl_mup.py
├── run_grid_search_pair_wise_mup.sh
├── run_train_gpt_mup_from_scratch.py
├── utils
│   ├── __init__.py
│   ├── files.py
│   ├── stat.py
│   └── torchs.py
└── visualize_lr_landscape.py

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

##################################################################
# for mac os
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

##################################################################
# for python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# Pyre type checker
.pyre/

# Project files
test_*.py
/logs/*
/coord_check_mup/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Mu-scaling: Loss Prediction via Maximal Update Parametrization

We show that Maximal Update Parametrization (Mup) itself provides a model sequence that fits a modified scaling law and enables accurate loss prediction.

Mu-scaling paper: https://arxiv.org/abs/2304.06875

This implementation is based on [Huggingface](https://github.com/huggingface/transformers) and [MuTransformers](https://github.com/microsoft/mutransformers), with modifications to improve stability and to support Deepspeed.

## Quick Start

### 1. Environment Setup

You can use conda or another tool to manage your Python environment; we recommend conda for simplicity.

```
conda create -n mu_scaling python=3.8
conda activate mu_scaling
pip install -r requirements.txt
```

If you are in China, you can run `pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple` instead of `pip install -r requirements.txt` to speed up installation.

### 2. Data Preparation

Preprocess datasets for causal language modeling following the Huggingface [instructions](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). We also provide an example of processed data in `res/final_data/test`.

### 3. Train GPT-2 with Mup

```bash
sh run_grid_search_pair_wise_mup.sh
```

### 4. Plot Loss Landscape

If Mup works correctly, the loss basins for different widths should be aligned.

```bash
python visualize_lr_landscape.py
```

### 5. Fit Scaling Laws

Record the training loss of each width on the same data at the same training step, then run

```bash
python fit_scale_loss_prediction.py
```

### 6. Evaluation

If you would like to run on evaluation data, we suggest training all the models for more steps, and then running

```bash
sh run_eval_ppl_loss_pred.sh
```

## References

If this project helps you, please star and cite us, thanks!
62 | ``` 63 | @article{DBLP:journals/corr/abs-2304-06875, 64 | author = {Yiqun Yao and Yequan Wang}, 65 | title = {Research without Re-search: Maximal Update Parametrization Yields Accurate Loss Prediction across Scales}, 66 | journal = {CoRR}, 67 | volume = {abs/2304.06875}, 68 | year = {2023} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /configs/gpt_2_L_6/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "attn_pdrop": 0.1, 7 | "bos_token_id": 50256, 8 | "cls_token_id": 50259, 9 | "embd_pdrop": 0.1, 10 | "eos_token_id": 50256, 11 | "initializer_range": 0.02, 12 | "layer_norm_epsilon": 1e-05, 13 | "model_type": "gpt2", 14 | "n_ctx": 1024, 15 | "n_embd": 256, 16 | "n_head": 4, 17 | "n_inner": null, 18 | "n_layer": 6, 19 | "n_positions": 1024, 20 | "pad_token_id": 50257, 21 | "reorder_and_upcast_attn": false, 22 | "resid_pdrop": 0.1, 23 | "scale_attn_by_inverse_layer_idx": false, 24 | "scale_attn_weights": true, 25 | "sep_token_id": 50258, 26 | "summary_activation": null, 27 | "summary_first_dropout": 0.1, 28 | "summary_proj_to_labels": true, 29 | "summary_type": "cls_index", 30 | "summary_use_proj": true, 31 | "task_specific_params": { 32 | "text-generation": { 33 | "do_sample": true, 34 | "max_length": 50 35 | } 36 | }, 37 | "transformers_version": "4.25.1", 38 | "use_cache": true, 39 | "vocab_size": 50261 40 | } 41 | -------------------------------------------------------------------------------- /deepspeed_configs/zero_stage1_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": "auto", 14 | "weight_decay": "auto", 15 | "torch_adam": true, 16 | "adam_w_mode": true 17 | } 18 | }, 19 | "scheduler": { 20 | "type": "WarmupDecayLR", 21 | "params": { 22 | "warmup_min_lr": "auto", 23 | "warmup_max_lr": "auto", 24 | "warmup_num_steps": "auto", 25 | "total_num_steps": "auto" 26 | } 27 | }, 28 | "zero_optimization": { 29 | "stage": 1, 30 | "allgather_partitions": true, 31 | "allgather_bucket_size": 2e8, 32 | "overlap_comm": true, 33 | "reduce_scatter": true, 34 | "reduce_bucket_size": "auto", 35 | "contiguous_gradients": true 36 | }, 37 | "gradient_accumulation_steps": "auto", 38 | "gradient_clipping": "auto", 39 | "steps_per_print": 2000, 40 | "train_batch_size": "auto", 41 | "train_micro_batch_size_per_gpu": "auto", 42 | "wall_clock_breakdown": false 43 | } -------------------------------------------------------------------------------- /fit_scale_loss_prediction.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import curve_fit 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | # model widths 6 | shapes = [128, 256, 384, 512, 640, 768, 896, 1024, 2048, 3072] 7 | # number of parameters in millions 8 | x = [8.53, 21.56, 39.09, 61.12, 87.65, 118.68, 154.21, 194.24, 676.48, 1446.72] 9 | # training loss 10 | train_loss = [3.9181,3.6054,3.4431,3.3542,3.2902,3.2478,3.224,3.1797,3.0854, 3.0425] 11 | 12 | y = np.array(train_loss) 13 | x_sample = range(int(min(x)) - 1,int(max(x)) + 5,1) 14 | 15 | # power law for fitting 16 | def func(x, a, b,c): 17 | 
return a * np.power(x, b) + c 18 | 19 | def curve_fit_one_line(x, y, num_pred, x_sample): 20 | popt, pcov = curve_fit(func, x[:10-num_pred], y[:10-num_pred], p0=[1, -1, 3], maxfev=5000) 21 | a = popt[0] 22 | b = popt[1] 23 | c = popt[2] 24 | print(a, b, c) 25 | print(np.sqrt(np.diag(pcov))) 26 | 27 | yvals = func(x_sample, a, b, c) 28 | return popt, np.sqrt(np.diag(pcov)), yvals 29 | 30 | # Number of models used for prediction. Results for prediction are not used in fitting curves 31 | num_pred=3 32 | popt, perr, yvals = curve_fit_one_line(x, y, num_pred, x_sample) 33 | 34 | plot1 = plt.scatter(x[:10-num_pred], y[:10-num_pred], s=15, c='c') 35 | plot1 = plt.scatter(x[10-num_pred:], y[10-num_pred:], s=30, c='c', marker="*") 36 | plot2 = plt.plot(x_sample, yvals, 'c', ls="--", label='(7.5e-4, 0.04, 6.0)') 37 | 38 | for _x, _y, _s in zip(x, y, shapes): 39 | plt.text(_x * 0.99, _y * 0.99, f"{_s}", horizontalalignment='right', verticalalignment='top') 40 | plt.legend() 41 | #plt.xscale("log") 42 | plt.xlabel("model size / M") 43 | plt.ylabel("train_loss @ 20k") 44 | plt.savefig(f"train_loss_prediction_20k_check.png") 45 | 46 | 47 | -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import evaluate 3 | 4 | cur_path = os.path.dirname(__file__) 5 | 6 | 7 | def my_evaluate_load(path, **kwargs): 8 | # 首先尝试从本地加载 9 | if os.path.isdir(path) or os.path.isfile(path): 10 | fun_eval = evaluate.load(path, **kwargs) 11 | else: 12 | try: 13 | local_path = os.path.abspath(os.path.join(cur_path, path)) 14 | print(f'Load `{path}` From `{local_path}`') 15 | fun_eval = evaluate.load(local_path, **kwargs) 16 | except: 17 | print(f'Load `{path}` From `hub`') 18 | fun_eval = evaluate.load(path, **kwargs) 19 | 20 | return fun_eval 21 | -------------------------------------------------------------------------------- /metric/accuracy/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Accuracy metric.""" 15 | 16 | import datasets 17 | from sklearn.metrics import accuracy_score 18 | 19 | import evaluate 20 | 21 | 22 | _DESCRIPTION = """ 23 | Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with: 24 | Accuracy = (TP + TN) / (TP + TN + FP + FN) 25 | Where: 26 | TP: True positive 27 | TN: True negative 28 | FP: False positive 29 | FN: False negative 30 | """ 31 | 32 | 33 | _KWARGS_DESCRIPTION = """ 34 | Args: 35 | predictions (`list` of `int`): Predicted labels. 36 | references (`list` of `int`): Ground truth labels. 37 | normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. 
Defaults to True. 38 | sample_weight (`list` of `float`): Sample weights Defaults to None. 39 | 40 | Returns: 41 | accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy. 42 | 43 | Examples: 44 | 45 | Example 1-A simple example 46 | >>> accuracy_metric = evaluate.load("accuracy") 47 | >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0]) 48 | >>> print(results) 49 | {'accuracy': 0.5} 50 | 51 | Example 2-The same as Example 1, except with `normalize` set to `False`. 52 | >>> accuracy_metric = evaluate.load("accuracy") 53 | >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False) 54 | >>> print(results) 55 | {'accuracy': 3.0} 56 | 57 | Example 3-The same as Example 1, except with `sample_weight` set. 58 | >>> accuracy_metric = evaluate.load("accuracy") 59 | >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4]) 60 | >>> print(results) 61 | {'accuracy': 0.8778625954198473} 62 | """ 63 | 64 | 65 | _CITATION = """ 66 | @article{scikit-learn, 67 | title={Scikit-learn: Machine Learning in {P}ython}, 68 | author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. 69 | and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. 70 | and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and 71 | Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, 72 | journal={Journal of Machine Learning Research}, 73 | volume={12}, 74 | pages={2825--2830}, 75 | year={2011} 76 | } 77 | """ 78 | 79 | 80 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 81 | class Accuracy(evaluate.Metric): 82 | def _info(self): 83 | return evaluate.MetricInfo( 84 | description=_DESCRIPTION, 85 | citation=_CITATION, 86 | inputs_description=_KWARGS_DESCRIPTION, 87 | features=datasets.Features( 88 | { 89 | "predictions": datasets.Sequence(datasets.Value("int32")), 90 | "references": datasets.Sequence(datasets.Value("int32")), 91 | } 92 | if self.config_name == "multilabel" 93 | else { 94 | "predictions": datasets.Value("int32"), 95 | "references": datasets.Value("int32"), 96 | } 97 | ), 98 | reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"], 99 | ) 100 | 101 | def _compute(self, predictions, references, normalize=True, sample_weight=None): 102 | return { 103 | "accuracy": float( 104 | accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight) 105 | ) 106 | } 107 | -------------------------------------------------------------------------------- /metric/glue/glue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Evaluate Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ GLUE benchmark metric. """ 15 | 16 | import datasets 17 | from scipy.stats import pearsonr, spearmanr 18 | from sklearn.metrics import f1_score, matthews_corrcoef 19 | 20 | import evaluate 21 | 22 | 23 | _CITATION = """\ 24 | @inproceedings{wang2019glue, 25 | title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding}, 26 | author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.}, 27 | note={In the Proceedings of ICLR.}, 28 | year={2019} 29 | } 30 | """ 31 | 32 | _DESCRIPTION = """\ 33 | GLUE, the General Language Understanding Evaluation benchmark 34 | (https://gluebenchmark.com/) is a collection of resources for training, 35 | evaluating, and analyzing natural language understanding systems. 36 | """ 37 | 38 | _KWARGS_DESCRIPTION = """ 39 | Compute GLUE evaluation metric associated to each GLUE dataset. 40 | Args: 41 | predictions: list of predictions to score. 42 | Each translation should be tokenized into a list of tokens. 43 | references: list of lists of references for each translation. 44 | Each reference should be tokenized into a list of tokens. 45 | Returns: depending on the GLUE subset, one or several of: 46 | "accuracy": Accuracy 47 | "f1": F1 score 48 | "pearson": Pearson Correlation 49 | "spearmanr": Spearman Correlation 50 | "matthews_correlation": Matthew Correlation 51 | Examples: 52 | 53 | >>> glue_metric = evaluate.load('glue', 'sst2') # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"] 54 | >>> references = [0, 1] 55 | >>> predictions = [0, 1] 56 | >>> results = glue_metric.compute(predictions=predictions, references=references) 57 | >>> print(results) 58 | {'accuracy': 1.0} 59 | 60 | >>> glue_metric = evaluate.load('glue', 'mrpc') # 'mrpc' or 'qqp' 61 | >>> references = [0, 1] 62 | >>> predictions = [0, 1] 63 | >>> results = glue_metric.compute(predictions=predictions, references=references) 64 | >>> print(results) 65 | {'accuracy': 1.0, 'f1': 1.0} 66 | 67 | >>> glue_metric = evaluate.load('glue', 'stsb') 68 | >>> references = [0., 1., 2., 3., 4., 5.] 69 | >>> predictions = [0., 1., 2., 3., 4., 5.] 
70 | >>> results = glue_metric.compute(predictions=predictions, references=references) 71 | >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)}) 72 | {'pearson': 1.0, 'spearmanr': 1.0} 73 | 74 | >>> glue_metric = evaluate.load('glue', 'cola') 75 | >>> references = [0, 1] 76 | >>> predictions = [0, 1] 77 | >>> results = glue_metric.compute(predictions=predictions, references=references) 78 | >>> print(results) 79 | {'matthews_correlation': 1.0} 80 | """ 81 | 82 | 83 | def simple_accuracy(preds, labels): 84 | return float((preds == labels).mean()) 85 | 86 | 87 | def acc_and_f1(preds, labels): 88 | acc = simple_accuracy(preds, labels) 89 | f1 = float(f1_score(y_true=labels, y_pred=preds)) 90 | return { 91 | "accuracy": acc, 92 | "f1": f1, 93 | } 94 | 95 | 96 | def pearson_and_spearman(preds, labels): 97 | pearson_corr = float(pearsonr(preds, labels)[0]) 98 | spearman_corr = float(spearmanr(preds, labels)[0]) 99 | return { 100 | "pearson": pearson_corr, 101 | "spearmanr": spearman_corr, 102 | } 103 | 104 | 105 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 106 | class Glue(evaluate.Metric): 107 | def _info(self): 108 | if self.config_name not in [ 109 | "sst2", 110 | "mnli", 111 | "mnli_mismatched", 112 | "mnli_matched", 113 | "cola", 114 | "stsb", 115 | "mrpc", 116 | "qqp", 117 | "qnli", 118 | "rte", 119 | "wnli", 120 | "hans", 121 | ]: 122 | raise KeyError( 123 | "You should supply a configuration name selected in " 124 | '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' 125 | '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' 126 | ) 127 | return evaluate.MetricInfo( 128 | description=_DESCRIPTION, 129 | citation=_CITATION, 130 | inputs_description=_KWARGS_DESCRIPTION, 131 | features=datasets.Features( 132 | { 133 | "predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"), 134 | "references": datasets.Value("int64" if self.config_name != "stsb" else "float32"), 135 | } 136 | ), 137 | codebase_urls=[], 138 | reference_urls=[], 139 | format="numpy", 140 | ) 141 | 142 | def _compute(self, predictions, references): 143 | if self.config_name == "cola": 144 | return {"matthews_correlation": matthews_corrcoef(references, predictions)} 145 | elif self.config_name == "stsb": 146 | return pearson_and_spearman(predictions, references) 147 | elif self.config_name in ["mrpc", "qqp"]: 148 | return acc_and_f1(predictions, references) 149 | elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]: 150 | return {"accuracy": simple_accuracy(predictions, references)} 151 | else: 152 | raise KeyError( 153 | "You should supply a configuration name selected in " 154 | '["sst2", "mnli", "mnli_mismatched", "mnli_matched", ' 155 | '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]' 156 | ) 157 | -------------------------------------------------------------------------------- /metric/perplexity/perplexity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Perplexity Metric.""" 15 | 16 | import datasets 17 | import numpy as np 18 | import torch 19 | from torch.nn import CrossEntropyLoss 20 | from transformers import AutoModelForCausalLM, AutoTokenizer 21 | 22 | import evaluate 23 | from evaluate import logging 24 | 25 | 26 | _CITATION = """\ 27 | 28 | """ 29 | 30 | _DESCRIPTION = """ 31 | Perplexity (PPL) is one of the most common metrics for evaluating language models. 32 | It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`. 33 | 34 | For more information, see https://huggingface.co/docs/transformers/perplexity 35 | """ 36 | 37 | _KWARGS_DESCRIPTION = """ 38 | Args: 39 | model_id (str): model used for calculating Perplexity 40 | NOTE: Perplexity can only be calculated for causal language models. 41 | This includes models such as gpt2, causal variations of bert, 42 | causal versions of t5, and more (the full list can be found 43 | in the AutoModelForCausalLM documentation here: 44 | https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM ) 45 | 46 | predictions (list of str): input text, each separate text snippet 47 | is one list entry. 48 | batch_size (int): the batch size to run texts through the model. Defaults to 16. 49 | add_start_token (bool): whether to add the start token to the texts, 50 | so the perplexity can include the probability of the first word. Defaults to True. 51 | device (str): device to run on, defaults to 'cuda' when available 52 | Returns: 53 | perplexity: dictionary containing the perplexity scores for the texts 54 | in the input list, as well as the mean perplexity. If one of the input texts is 55 | longer than the max input length of the model, then it is truncated to the 56 | max length for the perplexity computation. 57 | Examples: 58 | Example 1: 59 | >>> perplexity = evaluate.load("perplexity", module_type="metric") 60 | >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"] 61 | >>> results = perplexity.compute(model_id='gpt2', 62 | ... add_start_token=False, 63 | ... predictions=input_texts) # doctest:+ELLIPSIS 64 | >>> print(list(results.keys())) 65 | ['perplexities', 'mean_perplexity'] 66 | >>> print(round(results["mean_perplexity"], 0)) 67 | 647.0 68 | >>> print(round(results["perplexities"][0], 0)) 69 | 32.0 70 | 71 | Example 2: 72 | >>> from datasets import load_dataset 73 | >>> perplexity = evaluate.load("perplexity", module_type="metric") 74 | >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP 75 | >>> input_texts = [s for s in input_texts if s!=''] 76 | >>> results = perplexity.compute(model_id='gpt2', 77 | ... 
predictions=input_texts) 78 | >>> print(list(results.keys())) 79 | ['perplexities', 'mean_perplexity'] 80 | >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP 81 | 576.76 82 | >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP 83 | 889.28 84 | """ 85 | 86 | 87 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 88 | class Perplexity(evaluate.Metric): 89 | def _info(self): 90 | return evaluate.MetricInfo( 91 | module_type="metric", 92 | description=_DESCRIPTION, 93 | citation=_CITATION, 94 | inputs_description=_KWARGS_DESCRIPTION, 95 | features=datasets.Features( 96 | { 97 | "predictions": datasets.Value("string"), 98 | } 99 | ), 100 | reference_urls=["https://huggingface.co/docs/transformers/perplexity"], 101 | ) 102 | 103 | def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None): 104 | 105 | if device is not None: 106 | assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu." 107 | if device == "gpu": 108 | device = "cuda" 109 | else: 110 | device = "cuda" if torch.cuda.is_available() else "cpu" 111 | 112 | model = AutoModelForCausalLM.from_pretrained(model_id) 113 | model = model.to(device) 114 | 115 | tokenizer = AutoTokenizer.from_pretrained(model_id) 116 | 117 | # if batch_size > 1 (which generally leads to padding being required), and 118 | # if there is not an already assigned pad_token, assign an existing 119 | # special token to also be the padding token 120 | if tokenizer.pad_token is None and batch_size > 1: 121 | existing_special_tokens = list(tokenizer.special_tokens_map_extended.values()) 122 | # check that the model already has at least one special token defined 123 | assert ( 124 | len(existing_special_tokens) > 0 125 | ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1." 126 | # assign one of the special tokens to also be the pad token 127 | tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]}) 128 | 129 | if add_start_token: 130 | # leave room for token to be added: 131 | assert ( 132 | tokenizer.bos_token is not None 133 | ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False" 134 | max_tokenized_len = model.config.max_length - 1 135 | else: 136 | max_tokenized_len = model.config.max_length 137 | 138 | encodings = tokenizer( 139 | predictions, 140 | add_special_tokens=False, 141 | padding=True, 142 | truncation=True, 143 | max_length=max_tokenized_len, 144 | return_tensors="pt", 145 | return_attention_mask=True, 146 | ).to(device) 147 | 148 | encoded_texts = encodings["input_ids"] 149 | attn_masks = encodings["attention_mask"] 150 | 151 | # check that each input is long enough: 152 | if add_start_token: 153 | assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long." 154 | else: 155 | assert torch.all( 156 | torch.ge(attn_masks.sum(1), 2) 157 | ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings." 
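# Summary of the loop below: texts are processed in batches; each batch is
# optionally prefixed with the BOS token and run through the model, and the
# logits/labels are shifted by one position so that the prediction at position
# t is scored against the token at position t + 1. Per-sequence perplexity is
#     ppl_i = exp( sum_t CE(logits_{i,t}, labels_{i,t+1}) * mask_{i,t+1} / sum_t mask_{i,t+1} ),
# i.e. the exponentiated, attention-mask-weighted mean token cross-entropy.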
158 | 159 | ppls = [] 160 | loss_fct = CrossEntropyLoss(reduction="none") 161 | 162 | for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)): 163 | end_index = min(start_index + batch_size, len(encoded_texts)) 164 | encoded_batch = encoded_texts[start_index:end_index] 165 | attn_mask = attn_masks[start_index:end_index] 166 | 167 | if add_start_token: 168 | bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device) 169 | encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1) 170 | attn_mask = torch.cat( 171 | [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1 172 | ) 173 | 174 | labels = encoded_batch 175 | 176 | with torch.no_grad(): 177 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits 178 | 179 | shift_logits = out_logits[..., :-1, :].contiguous() 180 | shift_labels = labels[..., 1:].contiguous() 181 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous() 182 | 183 | perplexity_batch = torch.exp( 184 | (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) 185 | / shift_attention_mask_batch.sum(1) 186 | ) 187 | 188 | ppls += perplexity_batch.tolist() 189 | 190 | return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)} 191 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # from .nlm import NLMGPT2Model, NLMForSequenceClassification 2 | from .lm_mup import MupGPT2Model 3 | -------------------------------------------------------------------------------- /modeling/deepspeed_mup.py: -------------------------------------------------------------------------------- 1 | # This modifies transformers/deepspeed.py,Re-group parameters according to mup AdamW rules. 2 | """ 3 | Integration with Deepspeed 4 | """ 5 | 6 | import importlib.util 7 | import weakref 8 | from copy import deepcopy 9 | from functools import partialmethod 10 | from collections import defaultdict 11 | 12 | from transformers.dependency_versions_check import dep_version_check 13 | from transformers.utils import is_accelerate_available, is_torch_available, logging 14 | 15 | 16 | 17 | if is_torch_available(): 18 | import torch 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | 23 | def is_deepspeed_available(): 24 | return importlib.util.find_spec("deepspeed") is not None 25 | 26 | 27 | if is_accelerate_available() and is_deepspeed_available(): 28 | from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig 29 | else: 30 | # Inherits from a dummy `object` if accelerate is not available, so that python succeeds to import this file. 31 | # Deepspeed glue code will never inherit this dummy object as it checks if accelerate is available. 32 | from builtins import object as DeepSpeedConfig 33 | 34 | 35 | class HfDeepSpeedConfig(DeepSpeedConfig): 36 | """ 37 | This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage. 38 | A `weakref` of this object is stored in the module's globals to be able to access the config from areas where 39 | things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore 40 | it's important that this object remains alive while the program is still running. 41 | [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. 
That subclass has logic to sync the configuration 42 | with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic 43 | the DeepSpeed configuration is not modified in any way. 44 | Args: 45 | config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict. 46 | """ 47 | 48 | def __init__(self, config_file_or_dict): 49 | # set global weakref object 50 | set_hf_deepspeed_config(self) 51 | dep_version_check("accelerate") 52 | dep_version_check("deepspeed") 53 | super().__init__(config_file_or_dict) 54 | 55 | 56 | class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig): 57 | """ 58 | The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the 59 | same lifespan as the latter. 60 | """ 61 | 62 | def __init__(self, config_file_or_dict): 63 | super().__init__(config_file_or_dict) 64 | self._dtype = None 65 | self.mismatches = [] 66 | 67 | def dtype(self): 68 | if self._dtype is None: 69 | raise ValueError("trainer_config_process() wasn't called yet to tell dtype") 70 | return self._dtype 71 | 72 | def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True): 73 | """ 74 | A utility method that massages the config file and can optionally verify that the values match. 75 | 1. Replace "auto" values with `TrainingArguments` value. 76 | 2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer 77 | config values and if mismatched add the entry to `self.mismatched` - will assert during 78 | `trainer_config_finalize` for one or more mismatches. 79 | """ 80 | config, ds_key = self.find_config_node(ds_key_long) 81 | if config is None: 82 | return 83 | 84 | if config.get(ds_key) == "auto": 85 | config[ds_key] = hf_val 86 | return 87 | 88 | if not must_match: 89 | return 90 | 91 | ds_val = config.get(ds_key) 92 | if ds_val is not None and ds_val != hf_val: 93 | self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}") 94 | 95 | fill_only = partialmethod(fill_match, must_match=False) 96 | 97 | def trainer_config_process(self, args): 98 | """ 99 | Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object 100 | creation. 
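For example, with the provided `deepspeed_configs/zero_stage1_config.json`, the `"auto"` values for
`optimizer.params.lr` and `train_micro_batch_size_per_gpu` are filled from `args.learning_rate` and
`args.per_device_train_batch_size`, and `train_batch_size` is computed as
`world_size * per_device_train_batch_size * gradient_accumulation_steps`.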
101 | """ 102 | # DeepSpeed does: 103 | # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps 104 | train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps 105 | self.fill_match( 106 | "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size" 107 | ) 108 | self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps") 109 | self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)") 110 | self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm") 111 | 112 | self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate") 113 | self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2") 114 | self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon") 115 | self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay") 116 | 117 | self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg 118 | self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate") 119 | # total_num_steps - will get set in trainer_config_finalize 120 | 121 | # fp16 122 | if args.fp16 or args.fp16_full_eval: 123 | fp16_backend = "apex" if args.fp16_backend == "apex" else "amp" 124 | else: 125 | fp16_backend = None 126 | 127 | # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set 128 | # any here unless the user did the work 129 | self.fill_match( 130 | "fp16.enabled", 131 | ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"), 132 | "fp16|fp16_full_eval+fp16_backend(amp)", 133 | ) 134 | 135 | # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any 136 | # ZeRO features 137 | self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)") 138 | self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level") 139 | 140 | self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval") 141 | 142 | # deepspeed's default mode is fp16 unless there is a config that says differently 143 | if self.is_true("bf16.enabled"): 144 | self._dtype = torch.bfloat16 145 | elif self.is_false("fp16.enabled"): 146 | self._dtype = torch.float32 147 | else: 148 | self._dtype = torch.float16 149 | 150 | def trainer_config_finalize(self, args, model, num_training_steps): 151 | """ 152 | This stage is run after we have the model and know num_training_steps. 153 | Now we can complete the configuration process. 
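For example, `zero_optimization.reduce_bucket_size` is filled with `hidden_size * hidden_size`, and
`scheduler.params.total_num_steps` and `scheduler.params.warmup_num_steps` are filled from the
computed number of training steps.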
154 | """ 155 | # zero 156 | hidden_size = model.config.hidden_size 157 | self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size) 158 | if self.is_zero3(): 159 | # automatically assign the optimal config values based on model config 160 | self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size) 161 | self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size) 162 | 163 | # scheduler 164 | self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)") 165 | self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps") 166 | 167 | if len(self.mismatches) > 0: 168 | mismatches = "\n".join(self.mismatches) 169 | raise ValueError( 170 | "Please correct the following DeepSpeed config values that mismatch TrainingArguments" 171 | f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'." 172 | ) 173 | 174 | 175 | # keep the config object global to be able to access it anywhere during TrainingArguments life-cycle 176 | _hf_deepspeed_config_weak_ref = None 177 | 178 | 179 | def set_hf_deepspeed_config(hf_deepspeed_config_obj): 180 | # this is a special weakref global object to allow us to get to Deepspeed config from APIs 181 | # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain. 182 | global _hf_deepspeed_config_weak_ref 183 | # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed) 184 | _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj) 185 | 186 | 187 | def unset_hf_deepspeed_config(): 188 | # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method 189 | global _hf_deepspeed_config_weak_ref 190 | _hf_deepspeed_config_weak_ref = None 191 | 192 | 193 | def is_deepspeed_zero3_enabled(): 194 | if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: 195 | return _hf_deepspeed_config_weak_ref().is_zero3() 196 | else: 197 | return False 198 | 199 | 200 | def deepspeed_config(): 201 | if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None: 202 | return _hf_deepspeed_config_weak_ref().config 203 | else: 204 | return None 205 | 206 | 207 | def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps): 208 | """ 209 | A convenience wrapper that deals with optimizer and lr scheduler configuration. 210 | """ 211 | config = hf_deepspeed_config.config 212 | 213 | # Optimizer + Scheduler 214 | # Currently supported combos: 215 | # 1. DS scheduler + DS optimizer: Yes 216 | # 2. HF scheduler + HF optimizer: Yes 217 | # 3. DS scheduler + HF optimizer: Yes 218 | # 4. HF scheduler + DS optimizer: Yes 219 | # 220 | # Unless Offload is enabled in which case it's: 221 | # 1. DS scheduler + DS optimizer: Yes 222 | # 2. HF scheduler + HF optimizer: Mostly* 223 | # 3. DS scheduler + HF optimizer: Mostly* 224 | # 4. HF scheduler + DS optimizer: Yes 225 | # 226 | # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB) 227 | 228 | optimizer = None 229 | if "optimizer" in config: 230 | if args.adafactor: 231 | raise ValueError( 232 | "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. " 233 | "Only one optimizer can be configured." 
234 | ) 235 | else: 236 | if hf_deepspeed_config.is_offload(): 237 | logger.info( 238 | "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the" 239 | " custom optimizer has both CPU and GPU implementation (except LAMB)" 240 | ) 241 | 242 | # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch. 243 | # But trainer uses AdamW by default. 244 | optimizer = trainer.create_optimizer() 245 | # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer` 246 | config["zero_allow_untested_optimizer"] = True 247 | 248 | def _lr_scheduler_callable(optimizer): 249 | return trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) 250 | 251 | lr_scheduler = None 252 | if "scheduler" not in config: 253 | if optimizer is None: 254 | # Optimizer is not available, so use callable to defer lr_scheduler creation to DS init 255 | lr_scheduler = _lr_scheduler_callable 256 | else: 257 | lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) 258 | 259 | return optimizer, lr_scheduler 260 | 261 | 262 | def process_param_groups(params, **kwargs): 263 | # params: dict {name: params} 264 | param_groups = [{'params': params}] 265 | for param_group in param_groups: 266 | if 'lr' not in param_group: 267 | param_group['lr'] = kwargs['lr'] 268 | if 'weight_decay' not in param_group: 269 | param_group['weight_decay'] = kwargs.get('weight_decay', 0.) 270 | return param_groups 271 | 272 | 273 | def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False): 274 | """ 275 | Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. 276 | If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made. 
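When `args.use_mup` is set, parameters are additionally re-grouped before `deepspeed.initialize` is called:
matrix-like weights (`c_attn.weight`, `c_proj.weight`, `c_fc.weight`) have their learning rate divided by
`args.width_mult_for_weights` (and, since decoupled weight decay is disabled here, their weight decay multiplied
by it), and the scheduler's min/max/delta learning rates are rescaled to match after initialization.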
277 | Args: 278 | trainer: Trainer object 279 | num_training_steps: per single gpu 280 | resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load 281 | inference: launch in inference mode (no optimizer and no lr scheduler) 282 | Returns: model, optimizer, lr_scheduler 283 | We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on: 284 | https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it 285 | can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612 286 | """ 287 | import deepspeed 288 | from deepspeed.utils import logger as ds_logger 289 | 290 | model = trainer.model 291 | args = trainer.args 292 | 293 | if hasattr(trainer, "hf_deepspeed_config_orig"): 294 | hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig) 295 | else: 296 | hf_deepspeed_config = args.hf_deepspeed_config 297 | trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config) 298 | 299 | # resume config update - some bits like `model` and `num_training_steps` only become available during train 300 | hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) 301 | config = hf_deepspeed_config.config 302 | 303 | # set the Deepspeed log level consistent with the Trainer 304 | ds_logger.setLevel(args.get_process_log_level()) 305 | 306 | if inference: 307 | # only Z3 makes sense for the inference 308 | if not hf_deepspeed_config.is_zero3(): 309 | raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config") 310 | 311 | # in case the training config is re-used for inference 312 | hf_deepspeed_config.del_config_sub_tree("optimizer") 313 | hf_deepspeed_config.del_config_sub_tree("lr_scheduler") 314 | optimizer, lr_scheduler = None, None 315 | model_parameters = None 316 | else: 317 | trainer.optimizer = None # important for when deepspeed_init is used as re-init 318 | optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps) 319 | model_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 320 | 321 | # keep for quick debug: 322 | # from pprint import pprint; pprint(config) 323 | 324 | if args.use_mup: 325 | ############# begin mup re-grouping ############# 326 | model_parameters = dict(model.named_parameters()) 327 | new_param_groups = [] 328 | width_mult_group = [] 329 | decoupled_wd = False 330 | matrix_like_key_words = ["c_attn.weight", "c_proj.weight", "c_fc.weight", "c_proj.weight"] 331 | for param_group in process_param_groups(model_parameters, lr=trainer.args.learning_rate): 332 | # param_group: {"params": dict<{name:params}>, "lr": float, "weight_decay": float} 333 | # For every existing param group, we split into several new groups 334 | def new_group(): 335 | new_g = {k:v for k, v in param_group.items() if k != 'params'} 336 | new_g['params'] = [] 337 | return new_g 338 | # The matrix-like weights might need multiple groups since weights 339 | # might have different width multipliers 340 | matrix_like_p = defaultdict(new_group) # key is width_mult 341 | vector_like_p = new_group() 342 | for n, p in param_group['params'].items(): 343 | if not p.requires_grad: 344 | continue 345 | if any([key in n for key in matrix_like_key_words]): 346 | matrix_like_p[args.width_mult_for_weights]['params'].append(p) 347 | else: 348 | vector_like_p['params'].append(p) 349 | 350 | for width_mult, group 
in matrix_like_p.items(): 351 | # Scale learning rate and weight decay accordingly 352 | print(width_mult) 353 | group['lr'] /= width_mult 354 | if not decoupled_wd: 355 | group['weight_decay'] *= width_mult 356 | new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p]) 357 | width_mult_group.extend(list(matrix_like_p.keys()) + [1.0]) 358 | 359 | model_parameters = new_param_groups 360 | 361 | ######## end MuP ########## 362 | 363 | kwargs = dict( 364 | model=model, 365 | model_parameters=model_parameters, 366 | config_params=config, 367 | optimizer=optimizer, 368 | lr_scheduler=lr_scheduler, 369 | ) 370 | 371 | deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) 372 | 373 | # Must reset all min_lrs, max_lrs, delta_lrs due to Deepspeed Implementation of Schedulers. 374 | # Temporarily, solve this with post-processing 375 | if args.use_mup: 376 | assert len(width_mult_group) == len(lr_scheduler.delta_lrs) 377 | for i in range(len(width_mult_group)): 378 | lr_scheduler.min_lrs[i] /= width_mult_group[i] 379 | lr_scheduler.max_lrs[i] /= width_mult_group[i] 380 | lr_scheduler.delta_lrs[i] /= width_mult_group[i] 381 | 382 | print(f"modified min_lrs:{lr_scheduler.min_lrs}, max_lrs:{lr_scheduler.max_lrs},\ 383 | delta_lrs:{lr_scheduler.delta_lrs}") 384 | 385 | if resume_from_checkpoint is not None: 386 | 387 | # it's possible that the user is trying to resume from model_path, which doesn't necessarily 388 | # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's 389 | # a resume from a checkpoint and not just a local pretrained weight. So we check here if the 390 | # path contains what looks like a deepspeed checkpoint 391 | import glob 392 | 393 | deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*")) 394 | 395 | if len(deepspeed_checkpoint_dirs) > 0: 396 | logger.info(f"Attempting to resume from {resume_from_checkpoint}") 397 | # this magically updates self.optimizer and self.lr_scheduler 398 | load_path, _ = deepspeed_engine.load_checkpoint( 399 | resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True 400 | ) 401 | if load_path is None: 402 | raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}") 403 | else: 404 | logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing") 405 | 406 | return deepspeed_engine, optimizer, lr_scheduler -------------------------------------------------------------------------------- /modeling/initialize_with_mup.py: -------------------------------------------------------------------------------- 1 | from modeling.mup_utils import get_lazy_model_from_scratch 2 | import math 3 | import copy 4 | 5 | def mup_init_from_scratch(config, training_args, model_args, logger): 6 | logger.info(f'Loading Mup model from scratch') 7 | mup = training_args.use_mup 8 | size_per_head = training_args.size_per_head 9 | 10 | config.output_mult = training_args.output_mult 11 | config.initializer_range = training_args.initializer_range 12 | 13 | # maybe reset dropout to zero 14 | if training_args.unified_dropout is not None: 15 | print(f"resetting dropout={training_args.unified_dropout}") 16 | config.attn_pdrop = training_args.unified_dropout 17 | config.embd_pdrop = training_args.unified_dropout 18 | config.resid_pdrop = training_args.unified_dropout 19 | config.summary_first_dropout = training_args.unified_dropout 20 | 21 | config_base = copy.deepcopy(config) 22 | 23 | 
logger.info(f"Generating proxy model for HP tuning") 24 | config_hp_search = copy.deepcopy(config_base) 25 | config_hp_search.attn_mult = float(math.sqrt(size_per_head)) if mup else None 26 | config_hp_search.n_embd = training_args.hp_tune_actual_width 27 | config_hp_search.n_head = int(config_hp_search.n_embd / training_args.size_per_head) 28 | 29 | model_f = get_lazy_model_from_scratch(config=config_hp_search, 30 | mup=mup, 31 | readout_zero_init=mup and training_args.readout_zero_init, 32 | query_zero_init=mup and training_args.query_zero_init, 33 | input_mult=training_args.output_mult, 34 | width_mult_for_weights=training_args.width_mult_for_weights) 35 | model = model_f() 36 | model.transformer.input_mult = training_args.output_mult 37 | assert model.transformer.input_mult is not None 38 | 39 | return model -------------------------------------------------------------------------------- /modeling/lm_mup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import functional as F 3 | from torch.nn import CrossEntropyLoss 4 | from typing import Optional, Tuple, Union 5 | from transformers import logging 6 | from .modeling_gpt2_mup import GPT2PreTrainedModel, GPT2Model 7 | from .utils import MuReadout 8 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 9 | 10 | from transformers.modeling_outputs import ( 11 | CausalLMOutputWithCrossAttentions, 12 | SequenceClassifierOutputWithPast, 13 | ) 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | class MupGPT2Model(GPT2PreTrainedModel): 18 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 19 | 20 | def __init__(self, config): 21 | super().__init__(config) 22 | self.transformer = GPT2Model(config) 23 | # self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 24 | ### muP: swap nn.Linear with MuReadout 25 | self.lm_head = MuReadout(config.n_embd, config.vocab_size, bias=False, 26 | output_mult=config.output_mult, 27 | width_mult=config.width_mult_for_weights) 28 | 29 | # Model parallel 30 | self.model_parallel = False 31 | self.device_map = None 32 | 33 | # Initialize weights and apply final processing 34 | self.post_init() 35 | 36 | def parallelize(self, device_map=None): 37 | self.device_map = ( 38 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 39 | if device_map is None 40 | else device_map 41 | ) 42 | assert_device_map(self.device_map, len(self.transformer.h)) 43 | self.transformer.parallelize(self.device_map) 44 | self.lm_head = self.lm_head.to(self.transformer.first_device) 45 | self.model_parallel = True 46 | 47 | def deparallelize(self): 48 | self.transformer.deparallelize() 49 | self.transformer = self.transformer.to("cpu") 50 | self.lm_head = self.lm_head.to("cpu") 51 | self.model_parallel = False 52 | torch.cuda.empty_cache() 53 | 54 | def get_output_embeddings(self): 55 | return self.lm_head 56 | 57 | def set_output_embeddings(self, new_embeddings): 58 | self.lm_head = new_embeddings 59 | 60 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 61 | token_type_ids = kwargs.get("token_type_ids", None) 62 | # only last token for inputs_ids if past is defined in kwargs 63 | if past: 64 | input_ids = input_ids[:, -1].unsqueeze(-1) 65 | if token_type_ids is not None: 66 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 67 | 68 | attention_mask = kwargs.get("attention_mask", None) 69 | position_ids = 
kwargs.get("position_ids", None) 70 | 71 | if attention_mask is not None and position_ids is None: 72 | # create position_ids on the fly for batch generation 73 | position_ids = attention_mask.long().cumsum(-1) - 1 74 | position_ids.masked_fill_(attention_mask == 0, 1) 75 | if past: 76 | position_ids = position_ids[:, -1].unsqueeze(-1) 77 | else: 78 | position_ids = None 79 | return { 80 | "input_ids": input_ids, 81 | "past_key_values": past, 82 | "use_cache": kwargs.get("use_cache"), 83 | "position_ids": position_ids, 84 | "attention_mask": attention_mask, 85 | "token_type_ids": token_type_ids, 86 | } 87 | 88 | def forward_for_lm( 89 | self, 90 | input_ids: Optional[torch.LongTensor] = None, 91 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 92 | attention_mask: Optional[torch.FloatTensor] = None, 93 | token_type_ids: Optional[torch.LongTensor] = None, 94 | position_ids: Optional[torch.LongTensor] = None, 95 | head_mask: Optional[torch.FloatTensor] = None, 96 | inputs_embeds: Optional[torch.FloatTensor] = None, 97 | encoder_hidden_states: Optional[torch.Tensor] = None, 98 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 99 | labels: Optional[torch.LongTensor] = None, 100 | use_cache: Optional[bool] = None, 101 | output_attentions: Optional[bool] = None, 102 | output_hidden_states: Optional[bool] = None, 103 | return_dict: Optional[bool] = None, 104 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 105 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 106 | 107 | transformer_outputs = self.transformer( 108 | input_ids, 109 | past_key_values=past_key_values, 110 | attention_mask=attention_mask, 111 | token_type_ids=token_type_ids, 112 | position_ids=position_ids, 113 | head_mask=head_mask, 114 | inputs_embeds=inputs_embeds, 115 | encoder_hidden_states=encoder_hidden_states, 116 | encoder_attention_mask=encoder_attention_mask, 117 | use_cache=use_cache, 118 | output_attentions=output_attentions, 119 | output_hidden_states=output_hidden_states, 120 | return_dict=return_dict, 121 | ) 122 | hidden_states = transformer_outputs[0] 123 | 124 | # Set device for model parallelism 125 | if self.model_parallel: 126 | torch.cuda.set_device(self.transformer.first_device) 127 | hidden_states = hidden_states.to(self.lm_head.weight.device) 128 | 129 | lm_logits = self.lm_head(hidden_states) 130 | 131 | loss = None 132 | if labels is not None: 133 | # Shift so that tokens < n predict n 134 | shift_logits = lm_logits[..., :-1, :].contiguous() 135 | shift_labels = labels[..., 1:].contiguous() 136 | # Flatten the tokens 137 | loss_fct = CrossEntropyLoss() 138 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 139 | 140 | 141 | if not return_dict: 142 | output = (lm_logits,) + transformer_outputs[1:] 143 | return ((loss,) + output) if loss is not None else output 144 | 145 | return CausalLMOutputWithCrossAttentions( 146 | loss=loss, 147 | logits=lm_logits, 148 | past_key_values=transformer_outputs.past_key_values, 149 | hidden_states=transformer_outputs.hidden_states, 150 | attentions=transformer_outputs.attentions, 151 | cross_attentions=transformer_outputs.cross_attentions, 152 | ) 153 | 154 | def forward( 155 | self, 156 | input_ids: Optional[torch.LongTensor] = None, 157 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 158 | attention_mask: Optional[torch.FloatTensor] = None, 159 | token_type_ids: Optional[torch.LongTensor] = None, 160 | position_ids: Optional[torch.LongTensor] = None, 
161 |         head_mask: Optional[torch.FloatTensor] = None,
162 |         inputs_embeds: Optional[torch.FloatTensor] = None,
163 |         encoder_hidden_states: Optional[torch.Tensor] = None,
164 |         encoder_attention_mask: Optional[torch.FloatTensor] = None,
165 |         tid: Optional[torch.LongTensor] = None,
166 |         length: Optional[torch.LongTensor] = None,
167 |         labels: Optional[torch.LongTensor] = None,
168 |         use_cache: Optional[bool] = None,
169 |         output_attentions: Optional[bool] = None,
170 |         output_hidden_states: Optional[bool] = None,
171 |         return_dict: Optional[bool] = None,
172 |         train_type: Optional[str] = 'lm',
173 |     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast]:
174 |         r"""
175 |         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
176 |             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
177 |             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
178 |             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
179 |         """
180 |         if train_type == 'lm':
181 |             return self.forward_for_lm(input_ids,
182 |                                        past_key_values,
183 |                                        attention_mask,
184 |                                        token_type_ids,
185 |                                        position_ids,
186 |                                        head_mask,
187 |                                        inputs_embeds,
188 |                                        encoder_hidden_states,
189 |                                        encoder_attention_mask,
190 |                                        labels,
191 |                                        use_cache,
192 |                                        output_attentions,
193 |                                        output_hidden_states,
194 |                                        return_dict)
195 |         else:
196 |             raise NotImplementedError(
197 |                 'Unknown `train_type` = \'%s\'' % train_type
198 |             )
199 | 
200 |     @staticmethod
201 |     def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
202 |         """
203 |         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
204 |         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
205 |         beam_idx at every generation step.
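        For example, with GPT-2 each `past_state` has shape `(batch_size * num_beams, num_heads, seq_len, head_dim)`,
        so `index_select(0, beam_idx)` only reorders the batch/beam dimension and leaves the cached keys and values
        themselves unchanged.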
206 | """ 207 | return tuple( 208 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 209 | for layer_past in past 210 | ) 211 | 212 | 213 | if __name__ == '__main__': 214 | pass 215 | -------------------------------------------------------------------------------- /modeling/modeling_gpt2_mup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import math 4 | import os 5 | from typing import Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.utils.checkpoint 9 | from torch import nn 10 | from torch.cuda.amp import autocast 11 | 12 | from transformers.activations import ACT2FN 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutputWithPastAndCrossAttentions, 15 | ) 16 | from transformers.modeling_utils import PreTrainedModel 17 | from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer 18 | from transformers.utils import ( 19 | logging, 20 | ) 21 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 22 | from transformers import GPT2Config 23 | 24 | logger = logging.get_logger(__name__) 25 | 26 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 27 | "gpt2", 28 | "gpt2-medium", 29 | "gpt2-large", 30 | "gpt2-xl", 31 | "distilgpt2", 32 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 33 | ] 34 | 35 | 36 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): 37 | """Load tf checkpoints in a pytorch model""" 38 | try: 39 | import re 40 | 41 | import tensorflow as tf 42 | except ImportError: 43 | logger.error( 44 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 45 | "https://www.tensorflow.org/install/ for installation instructions." 
46 | ) 47 | raise 48 | tf_path = os.path.abspath(gpt2_checkpoint_path) 49 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 50 | # Load weights from TF model 51 | init_vars = tf.train.list_variables(tf_path) 52 | names = [] 53 | arrays = [] 54 | for name, shape in init_vars: 55 | logger.info(f"Loading TF weight {name} with shape {shape}") 56 | array = tf.train.load_variable(tf_path, name) 57 | names.append(name) 58 | arrays.append(array.squeeze()) 59 | 60 | for name, array in zip(names, arrays): 61 | name = name[6:] # skip "model/" 62 | name = name.split("/") 63 | pointer = model 64 | for m_name in name: 65 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 66 | scope_names = re.split(r"(\d+)", m_name) 67 | else: 68 | scope_names = [m_name] 69 | if scope_names[0] == "w" or scope_names[0] == "g": 70 | pointer = getattr(pointer, "weight") 71 | elif scope_names[0] == "b": 72 | pointer = getattr(pointer, "bias") 73 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 74 | pointer = getattr(pointer, scope_names[0]) 75 | pointer = getattr(pointer, "weight") 76 | else: 77 | pointer = getattr(pointer, scope_names[0]) 78 | if len(scope_names) >= 2: 79 | num = int(scope_names[1]) 80 | pointer = pointer[num] 81 | try: 82 | assert ( 83 | pointer.shape == array.shape 84 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 85 | except AssertionError as e: 86 | e.args += (pointer.shape, array.shape) 87 | raise 88 | logger.info(f"Initialize PyTorch weight {name}") 89 | pointer.data = torch.from_numpy(array) 90 | return model 91 | 92 | 93 | class GPT2Attention(nn.Module): 94 | def __init__(self, config, is_cross_attention=False, layer_idx=None): 95 | super().__init__() 96 | 97 | max_positions = config.max_position_embeddings 98 | self.register_buffer( 99 | "bias", 100 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 101 | 1, 1, max_positions, max_positions 102 | ), 103 | ) 104 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 105 | 106 | self.embed_dim = config.hidden_size 107 | self.num_heads = config.num_attention_heads 108 | self.head_dim = self.embed_dim // self.num_heads 109 | self.split_size = self.embed_dim 110 | if self.head_dim * self.num_heads != self.embed_dim: 111 | raise ValueError( 112 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 113 | f" {self.num_heads})." 
114 | ) 115 | 116 | self.scale_attn_weights = config.scale_attn_weights 117 | self.is_cross_attention = is_cross_attention 118 | 119 | ############## muP 120 | self.attn_mult = config.attn_mult 121 | ############## end muP 122 | 123 | # Layer-wise attention scaling, reordering, and upcasting 124 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 125 | self.layer_idx = layer_idx 126 | self.reorder_and_upcast_attn = config.reorder_and_upcast_attn 127 | 128 | if self.is_cross_attention: 129 | self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) 130 | self.q_attn = Conv1D(self.embed_dim, self.embed_dim) 131 | else: 132 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) 133 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim) 134 | 135 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 136 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 137 | 138 | self.pruned_heads = set() 139 | 140 | def prune_heads(self, heads): 141 | if len(heads) == 0: 142 | return 143 | heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) 144 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 145 | 146 | # Prune conv1d layers 147 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 148 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 149 | 150 | # Update hyper params 151 | self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) 152 | self.num_heads = self.num_heads - len(heads) 153 | self.pruned_heads = self.pruned_heads.union(heads) 154 | 155 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 156 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 157 | 158 | ### muP 159 | if self.scale_attn_weights: 160 | ### muP: attn scaling 161 | if self.attn_mult is not None: 162 | attn_weights = attn_weights * self.attn_mult / float(value.size(-1)) 163 | else: 164 | attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) 165 | ### end muP 166 | 167 | # Layer-wise attention scaling 168 | if self.scale_attn_by_inverse_layer_idx: 169 | attn_weights = attn_weights / float(self.layer_idx + 1) 170 | 171 | if not self.is_cross_attention: 172 | # if only "normal" attention layer implements causal mask 173 | query_length, key_length = query.size(-2), key.size(-2) 174 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) 175 | mask_value = torch.finfo(attn_weights.dtype).min 176 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
177 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 178 | mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 179 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 180 | 181 | if attention_mask is not None: 182 | # Apply the attention mask 183 | attn_weights = attn_weights + attention_mask 184 | 185 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 186 | 187 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise 188 | attn_weights = attn_weights.type(value.dtype) 189 | attn_weights = self.attn_dropout(attn_weights) 190 | 191 | # Mask heads if we want to 192 | if head_mask is not None: 193 | attn_weights = attn_weights * head_mask 194 | 195 | attn_output = torch.matmul(attn_weights, value) 196 | 197 | return attn_output, attn_weights 198 | 199 | def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): 200 | # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) 201 | bsz, num_heads, q_seq_len, dk = query.size() 202 | _, _, k_seq_len, _ = key.size() 203 | 204 | # Preallocate attn_weights for `baddbmm` 205 | attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) 206 | 207 | # Compute Scale Factor 208 | scale_factor = 1.0 209 | if self.scale_attn_weights: 210 | scale_factor /= float(value.size(-1)) ** 0.5 211 | 212 | if self.scale_attn_by_inverse_layer_idx: 213 | scale_factor /= float(self.layer_idx + 1) 214 | 215 | # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) 216 | with autocast(enabled=False): 217 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) 218 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) 219 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 220 | 221 | if not self.is_cross_attention: 222 | # if only "normal" attention layer implements causal mask 223 | query_length, key_length = query.size(-2), key.size(-2) 224 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 225 | mask_value = torch.finfo(attn_weights.dtype).min 226 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
227 | # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 228 | mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) 229 | attn_weights = torch.where(causal_mask, attn_weights, mask_value) 230 | 231 | if attention_mask is not None: 232 | # Apply the attention mask 233 | attn_weights = attn_weights + attention_mask 234 | 235 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 236 | 237 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise 238 | if attn_weights.dtype != torch.float32: 239 | raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") 240 | attn_weights = attn_weights.type(value.dtype) 241 | attn_weights = self.attn_dropout(attn_weights) 242 | 243 | # Mask heads if we want to 244 | if head_mask is not None: 245 | attn_weights = attn_weights * head_mask 246 | 247 | attn_output = torch.matmul(attn_weights, value) 248 | 249 | return attn_output, attn_weights 250 | 251 | def _split_heads(self, tensor, num_heads, attn_head_size): 252 | """ 253 | Splits hidden_size dim into attn_head_size and num_heads 254 | """ 255 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 256 | tensor = tensor.view(new_shape) 257 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 258 | 259 | def _merge_heads(self, tensor, num_heads, attn_head_size): 260 | """ 261 | Merges attn_head_size dim and num_attn_heads dim into hidden_size 262 | """ 263 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 264 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) 265 | return tensor.view(new_shape) 266 | 267 | def forward( 268 | self, 269 | hidden_states: Optional[Tuple[torch.FloatTensor]], 270 | layer_past: Optional[Tuple[torch.Tensor]] = None, 271 | attention_mask: Optional[torch.FloatTensor] = None, 272 | head_mask: Optional[torch.FloatTensor] = None, 273 | encoder_hidden_states: Optional[torch.Tensor] = None, 274 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 275 | use_cache: Optional[bool] = False, 276 | output_attentions: Optional[bool] = False, 277 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: 278 | if encoder_hidden_states is not None: 279 | if not hasattr(self, "q_attn"): 280 | raise ValueError( 281 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 282 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
283 | ) 284 | 285 | query = self.q_attn(hidden_states) 286 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 287 | attention_mask = encoder_attention_mask 288 | else: 289 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 290 | 291 | query = self._split_heads(query, self.num_heads, self.head_dim) 292 | key = self._split_heads(key, self.num_heads, self.head_dim) 293 | value = self._split_heads(value, self.num_heads, self.head_dim) 294 | 295 | if layer_past is not None: 296 | past_key, past_value = layer_past 297 | key = torch.cat((past_key, key), dim=-2) 298 | value = torch.cat((past_value, value), dim=-2) 299 | 300 | if use_cache is True: 301 | present = (key, value) 302 | else: 303 | present = None 304 | 305 | if self.reorder_and_upcast_attn: 306 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 307 | else: 308 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 309 | 310 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 311 | attn_output = self.c_proj(attn_output) 312 | attn_output = self.resid_dropout(attn_output) 313 | 314 | outputs = (attn_output, present) 315 | if output_attentions: 316 | outputs += (attn_weights,) 317 | 318 | return outputs # a, present, (attentions) 319 | 320 | 321 | class GPT2MLP(nn.Module): 322 | def __init__(self, intermediate_size, config): 323 | super().__init__() 324 | embed_dim = config.hidden_size 325 | self.c_fc = Conv1D(intermediate_size, embed_dim) 326 | self.c_proj = Conv1D(embed_dim, intermediate_size) 327 | self.act = ACT2FN[config.activation_function] 328 | self.dropout = nn.Dropout(config.resid_pdrop) 329 | 330 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 331 | hidden_states = self.c_fc(hidden_states) 332 | hidden_states = self.act(hidden_states) 333 | hidden_states = self.c_proj(hidden_states) 334 | hidden_states = self.dropout(hidden_states) 335 | return hidden_states 336 | 337 | 338 | class GPT2Block(nn.Module): 339 | def __init__(self, config, layer_idx=None): 340 | super().__init__() 341 | hidden_size = config.hidden_size 342 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 343 | 344 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 345 | self.attn = GPT2Attention(config, layer_idx=layer_idx) 346 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 347 | 348 | if config.add_cross_attention: 349 | self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) 350 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 351 | 352 | self.mlp = GPT2MLP(inner_dim, config) 353 | 354 | def forward( 355 | self, 356 | hidden_states: Optional[Tuple[torch.FloatTensor]], 357 | layer_past: Optional[Tuple[torch.Tensor]] = None, 358 | attention_mask: Optional[torch.FloatTensor] = None, 359 | head_mask: Optional[torch.FloatTensor] = None, 360 | encoder_hidden_states: Optional[torch.Tensor] = None, 361 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 362 | use_cache: Optional[bool] = False, 363 | output_attentions: Optional[bool] = False, 364 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: 365 | residual = hidden_states 366 | hidden_states = self.ln_1(hidden_states) 367 | attn_outputs = self.attn( 368 | hidden_states, 369 | layer_past=layer_past, 370 | 
attention_mask=attention_mask, 371 | head_mask=head_mask, 372 | use_cache=use_cache, 373 | output_attentions=output_attentions, 374 | ) 375 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 376 | outputs = attn_outputs[1:] 377 | # residual connection 378 | hidden_states = attn_output + residual 379 | 380 | if encoder_hidden_states is not None: 381 | # add one self-attention block for cross-attention 382 | if not hasattr(self, "crossattention"): 383 | raise ValueError( 384 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 385 | "cross-attention layers by setting `config.add_cross_attention=True`" 386 | ) 387 | residual = hidden_states 388 | hidden_states = self.ln_cross_attn(hidden_states) 389 | cross_attn_outputs = self.crossattention( 390 | hidden_states, 391 | attention_mask=attention_mask, 392 | head_mask=head_mask, 393 | encoder_hidden_states=encoder_hidden_states, 394 | encoder_attention_mask=encoder_attention_mask, 395 | output_attentions=output_attentions, 396 | ) 397 | attn_output = cross_attn_outputs[0] 398 | # residual connection 399 | hidden_states = residual + attn_output 400 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 401 | 402 | residual = hidden_states 403 | hidden_states = self.ln_2(hidden_states) 404 | feed_forward_hidden_states = self.mlp(hidden_states) 405 | # residual connection 406 | hidden_states = residual + feed_forward_hidden_states 407 | 408 | if use_cache: 409 | outputs = (hidden_states,) + outputs 410 | else: 411 | outputs = (hidden_states,) + outputs[1:] 412 | 413 | return outputs # hidden_states, present, (attentions, cross_attentions) 414 | 415 | 416 | class GPT2PreTrainedModel(PreTrainedModel): 417 | """ 418 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 419 | models. 420 | """ 421 | 422 | config_class = GPT2Config 423 | load_tf_weights = load_tf_weights_in_gpt2 424 | base_model_prefix = "transformer" 425 | is_parallelizable = True 426 | supports_gradient_checkpointing = True 427 | _no_split_modules = ["GPT2Block"] 428 | 429 | def __init__(self, *inputs, **kwargs): 430 | super().__init__(*inputs, **kwargs) 431 | 432 | def _init_weights(self, module): 433 | return 434 | 435 | def _init_all_weights_for_mup(self, readout_zero_init=False, 436 | query_zero_init=False, input_mult=1.0, width_mult_for_weights=1.0): 437 | """ Initialize all weights according to Mup rules. 438 | Should be called after instantiation of the model.""" 439 | 440 | def _exact(p, mean, var): 441 | # Experimental improvements: make final mean/variance exactly equal to the target. 442 | # Sometimes large variance in normal_() causes the actual mean to deviate from expected. 
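            # The line below first standardizes the freshly sampled tensor to zero mean and unit variance, then
            # rescales and shifts it to hit (mean, var) exactly. This assumes the tensor already has nonzero
            # empirical variance, which holds here because `_exact` is only called right after a `normal_()`
            # draw with a positive std.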
443 |             p.data = (p.data - torch.mean(p, dtype=torch.float32).item()) / math.sqrt(torch.var(p).item()) * math.sqrt(var) + mean
444 | 
445 |         ref_mean_var_dict = {}
446 |         for k in dict(self.named_parameters()).keys():
447 |             if "ln" in k or "bias" in k:
448 |                 ref_mean_var_dict[k] = (1.0, 0.0) if "weight" in k else (0.0, 0.0)
449 |             else:
450 |                 ref_mean_var_dict[k] = (0.0, self.config.initializer_range ** 2)  # embedding and matrix-like params
451 | 
452 |         ############################################
453 | 
454 |         for n, p in self.named_parameters():
455 |             if n == "transformer.wte.weight":
456 |                 if readout_zero_init:
457 |                     p.data.zero_()
458 |                 else:
459 |                     # vector-like, keep the same in all widths
460 |                     p.data.normal_(mean=ref_mean_var_dict[n][0],
461 |                                    std=math.sqrt(ref_mean_var_dict[n][1]))
462 |                     _exact(p, ref_mean_var_dict[n][0], ref_mean_var_dict[n][1])
463 |                 if self.transformer.wte.padding_idx is not None:
464 |                     p.data[self.transformer.wte.padding_idx].zero_()
465 |             elif n == "transformer.wpe.weight":
466 |                 # position embedding, vec-like, keep constant in all widths
467 |                 p.data.normal_(mean=ref_mean_var_dict[n][0], std=math.sqrt(ref_mean_var_dict[n][1]))
468 |                 _exact(p, ref_mean_var_dict[n][0], ref_mean_var_dict[n][1])
469 |                 if self.transformer.wpe.padding_idx is not None:
470 |                     p.data[self.transformer.wpe.padding_idx].zero_()
471 |             elif "lm_head" in n:
472 |                 # do nothing. lm_head.weight == transformer.wte.weight
473 |                 pass
474 |             elif "ln" in n or "bias" in n:
475 |                 # layernorm, vec-like
476 |                 if "weight" in n:
477 |                     p.data.fill_(1.0)
478 |                 else:
479 |                     p.data.zero_()
480 |             else:
481 |                 assert "weight" in n
482 |                 # all truly matrix-like parameters
483 |                 p.data.normal_(mean=0.0, std=math.sqrt(ref_mean_var_dict[n][1] / width_mult_for_weights))
484 |                 scaled_var = ref_mean_var_dict[n][1] / width_mult_for_weights
485 |                 _exact(p, 0.0, scaled_var)
486 | 
487 |                 if ("c_attn.weight" in n) and query_zero_init:
488 |                     # this makes Mup alignment more accurate
489 |                     _, fanout = p.shape
490 |                     assert fanout % 3 == 0
491 |                     p.data[:, :fanout//3] = 0
492 | 
493 |                 depth_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
494 |                 if "c_proj" in n and "weight" in n:
495 |                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
496 |                     p.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / width_mult_for_weights))
497 |                     scaled_var = depth_std ** 2 / width_mult_for_weights
498 |                     _exact(p, 0.0, scaled_var)
499 | 
500 |     def _set_gradient_checkpointing(self, module, value=False):
501 |         if isinstance(module, GPT2Model):
502 |             module.gradient_checkpointing = value
503 | 
504 | 
505 | class GPT2Model(GPT2PreTrainedModel):
506 |     _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
507 | 
508 |     def __init__(self, config):
509 |         super().__init__(config)
510 | 
511 |         self.embed_dim = config.hidden_size
512 | 
513 |         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
514 |         # for mup ####
515 |         self.input_mult = None  # assigned outside
516 |         ##############
517 | 
518 |         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
519 | 
520 |         self.drop = nn.Dropout(config.embd_pdrop)
521 |         self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
522 |         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
523 | 
524 |         # Model parallel
525 |         self.model_parallel = False
526 |         self.device_map = None
527 |         self.gradient_checkpointing = False
528 | 
529 |         # Initialize weights and apply final processing
530 |         self.post_init()
531 | 
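    # Illustrative construction sketch (a minimal example, not the project's training entry point; the
    # concrete values below are placeholders):
    #
    #     config = GPT2Config(n_embd=256, n_layer=6, n_head=8)
    #     config.attn_mult = 8.0          # read by GPT2Attention for muP attention scaling
    #     model = GPT2Model(config)
    #     model.input_mult = 2.0          # embedding multiplier applied to wte outputs in forward()
    #
    # Note that `_init_weights` is intentionally a no-op; the muP-aware (re-)initialization is performed by
    # `_init_all_weights_for_mup`, which is called on the language-model wrapper that owns this module as
    # `.transformer` (see modeling/lm_mup.py and modeling/mup_utils.py).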
532 | def parallelize(self, device_map=None): 533 | # Check validity of device_map 534 | self.device_map = ( 535 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 536 | ) 537 | assert_device_map(self.device_map, len(self.h)) 538 | self.model_parallel = True 539 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 540 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 541 | self.wte = self.wte.to(self.first_device) 542 | self.wpe = self.wpe.to(self.first_device) 543 | # Load onto devices 544 | for k, v in self.device_map.items(): 545 | for block in v: 546 | cuda_device = "cuda:" + str(k) 547 | self.h[block] = self.h[block].to(cuda_device) 548 | # ln_f to last 549 | self.ln_f = self.ln_f.to(self.last_device) 550 | 551 | 552 | def deparallelize(self): 553 | self.model_parallel = False 554 | self.device_map = None 555 | self.first_device = "cpu" 556 | self.last_device = "cpu" 557 | self.wte = self.wte.to("cpu") 558 | self.wpe = self.wpe.to("cpu") 559 | for index in range(len(self.h)): 560 | self.h[index] = self.h[index].to("cpu") 561 | self.ln_f = self.ln_f.to("cpu") 562 | torch.cuda.empty_cache() 563 | 564 | def get_input_embeddings(self): 565 | return self.wte 566 | 567 | def set_input_embeddings(self, new_embeddings): 568 | self.wte = new_embeddings 569 | 570 | def _prune_heads(self, heads_to_prune): 571 | """ 572 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 573 | """ 574 | for layer, heads in heads_to_prune.items(): 575 | self.h[layer].attn.prune_heads(heads) 576 | 577 | def forward( 578 | self, 579 | input_ids: Optional[torch.LongTensor] = None, 580 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 581 | attention_mask: Optional[torch.FloatTensor] = None, 582 | token_type_ids: Optional[torch.LongTensor] = None, 583 | position_ids: Optional[torch.LongTensor] = None, 584 | head_mask: Optional[torch.FloatTensor] = None, 585 | inputs_embeds: Optional[torch.FloatTensor] = None, 586 | encoder_hidden_states: Optional[torch.Tensor] = None, 587 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 588 | use_cache: Optional[bool] = None, 589 | output_attentions: Optional[bool] = None, 590 | output_hidden_states: Optional[bool] = None, 591 | return_dict: Optional[bool] = None, 592 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 593 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 594 | output_hidden_states = ( 595 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 596 | ) 597 | use_cache = use_cache if use_cache is not None else self.config.use_cache 598 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 599 | 600 | if input_ids is not None and inputs_embeds is not None: 601 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 602 | elif input_ids is not None: 603 | input_shape = input_ids.size() 604 | input_ids = input_ids.view(-1, input_shape[-1]) 605 | batch_size = input_ids.shape[0] 606 | elif inputs_embeds is not None: 607 | input_shape = inputs_embeds.size()[:-1] 608 | batch_size = inputs_embeds.shape[0] 609 | else: 610 | raise ValueError("You have to specify either input_ids or inputs_embeds") 611 | 612 | device = input_ids.device if input_ids is not None else inputs_embeds.device 613 | 614 | if 
token_type_ids is not None: 615 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 616 | if position_ids is not None: 617 | position_ids = position_ids.view(-1, input_shape[-1]) 618 | 619 | if past_key_values is None: 620 | past_length = 0 621 | past_key_values = tuple([None] * len(self.h)) 622 | else: 623 | past_length = past_key_values[0][0].size(-2) 624 | if position_ids is None: 625 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 626 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 627 | 628 | # GPT2Attention mask. 629 | if attention_mask is not None: 630 | if batch_size <= 0: 631 | raise ValueError("batch_size has to be defined and > 0") 632 | attention_mask = attention_mask.view(batch_size, -1) 633 | # We create a 3D attention mask from a 2D tensor mask. 634 | # Sizes are [batch_size, 1, 1, to_seq_length] 635 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 636 | # this attention mask is more simple than the triangular masking of causal attention 637 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 638 | attention_mask = attention_mask[:, None, None, :] 639 | 640 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 641 | # masked positions, this operation will create a tensor which is 0.0 for 642 | # positions we want to attend and the dtype's smallest value for masked positions. 643 | # Since we are adding it to the raw scores before the softmax, this is 644 | # effectively the same as removing these entirely. 645 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 646 | attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min 647 | 648 | # If a 2D or 3D attention mask is provided for the cross-attention 649 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 650 | if self.config.add_cross_attention and encoder_hidden_states is not None: 651 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 652 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 653 | if encoder_attention_mask is None: 654 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 655 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 656 | else: 657 | encoder_attention_mask = None 658 | 659 | # Prepare head mask if needed 660 | # 1.0 in head_mask indicate we keep the head 661 | # attention_probs has shape bsz x n_heads x N x N 662 | # head_mask has shape n_layer x batch x n_heads x N x N 663 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 664 | 665 | if inputs_embeds is None: 666 | inputs_embeds = self.wte(input_ids) 667 | if self.input_mult is not None: 668 | inputs_embeds = inputs_embeds * self.input_mult 669 | 670 | 671 | position_embeds = self.wpe(position_ids) 672 | hidden_states = inputs_embeds + position_embeds 673 | 674 | if token_type_ids is not None: 675 | token_type_embeds = self.wte(token_type_ids) 676 | hidden_states = hidden_states + token_type_embeds 677 | 678 | hidden_states = self.drop(hidden_states) 679 | 680 | output_shape = input_shape + (hidden_states.size(-1),) 681 | 682 | presents = () if use_cache else None 683 | all_self_attentions = () if output_attentions else None 684 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 685 | all_hidden_states = () if output_hidden_states else None 686 | for i, 
(block, layer_past) in enumerate(zip(self.h, past_key_values)): 687 | 688 | # Model parallel 689 | if self.model_parallel: 690 | torch.cuda.set_device(hidden_states.device) 691 | # Ensure layer_past is on same device as hidden_states (might not be correct) 692 | if layer_past is not None: 693 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 694 | # Ensure that attention_mask is always on the same device as hidden_states 695 | if attention_mask is not None: 696 | attention_mask = attention_mask.to(hidden_states.device) 697 | if isinstance(head_mask, torch.Tensor): 698 | head_mask = head_mask.to(hidden_states.device) 699 | if output_hidden_states: 700 | all_hidden_states = all_hidden_states + (hidden_states,) 701 | 702 | if self.gradient_checkpointing and self.training: 703 | 704 | if use_cache: 705 | logger.warning( 706 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 707 | ) 708 | use_cache = False 709 | 710 | def create_custom_forward(module): 711 | def custom_forward(*inputs): 712 | # None for past_key_value 713 | return module(*inputs, use_cache, output_attentions) 714 | 715 | return custom_forward 716 | 717 | outputs = torch.utils.checkpoint.checkpoint( 718 | create_custom_forward(block), 719 | hidden_states, 720 | None, 721 | attention_mask, 722 | head_mask[i], 723 | encoder_hidden_states, 724 | encoder_attention_mask, 725 | ) 726 | else: 727 | outputs = block( 728 | hidden_states, 729 | layer_past=layer_past, 730 | attention_mask=attention_mask, 731 | head_mask=head_mask[i], 732 | encoder_hidden_states=encoder_hidden_states, 733 | encoder_attention_mask=encoder_attention_mask, 734 | use_cache=use_cache, 735 | output_attentions=output_attentions, 736 | ) 737 | 738 | hidden_states = outputs[0] 739 | if use_cache is True: 740 | presents = presents + (outputs[1],) 741 | 742 | if output_attentions: 743 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 744 | if self.config.add_cross_attention: 745 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 746 | 747 | # Model Parallel: If it's the last layer for that device, put things on the next device 748 | if self.model_parallel: 749 | for k, v in self.device_map.items(): 750 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 751 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 752 | 753 | hidden_states = self.ln_f(hidden_states) 754 | 755 | hidden_states = hidden_states.view(output_shape) 756 | # Add last hidden state 757 | if output_hidden_states: 758 | all_hidden_states = all_hidden_states + (hidden_states,) 759 | 760 | if not return_dict: 761 | return tuple( 762 | v 763 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 764 | if v is not None 765 | ) 766 | 767 | return BaseModelOutputWithPastAndCrossAttentions( 768 | last_hidden_state=hidden_states, 769 | past_key_values=presents, 770 | hidden_states=all_hidden_states, 771 | attentions=all_self_attentions, 772 | cross_attentions=all_cross_attentions, 773 | ) 774 | -------------------------------------------------------------------------------- /modeling/mup_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 
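# muP training helpers. A rough end-to-end sketch of how these utilities fit together (the hyper-parameter
# values below are placeholders, and the per-width learning-rate re-grouping normally done by
# MupTrainer.create_optimizer is omitted):
#
#     model_fn = get_lazy_model_from_scratch(config, mup=True, width_mult_for_weights=4.0)
#     model = model_fn()
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     scheduler = get_linear_schedule_with_inverse_log_warmup(optimizer, num_warmup_steps=1000,
#                                                             num_training_steps=100000)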
2 | 3 | from functools import partial 4 | import math 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | import seaborn as sns 9 | import datasets 10 | from transformers import default_data_collator 11 | 12 | from modeling.lm_mup import MupGPT2Model 13 | 14 | from torch.optim.lr_scheduler import LambdaLR 15 | 16 | sns.set() 17 | 18 | def get_dataloader(arch, signature_columns): 19 | # **************************************************************************************************** 20 | # Load data 21 | # **************************************************************************************************** 22 | train_dataset = None 23 | final_lm_dir = "add-your-own-path" 24 | max_lm_train_samples = None 25 | if final_lm_dir is not None: 26 | print(f'From {final_lm_dir} / {max_lm_train_samples}') 27 | train_dataset = datasets.load_from_disk(final_lm_dir) 28 | if max_lm_train_samples is not None: 29 | train_dataset = train_dataset.select(range(max_lm_train_samples)) 30 | print(f'{train_dataset}') 31 | 32 | assert train_dataset is not None 33 | 34 | lm_dataloader = DataLoader( 35 | train_dataset, 36 | collate_fn=default_data_collator, 37 | batch_size=8, 38 | num_workers=10, 39 | pin_memory=True, 40 | ) 41 | 42 | return lm_dataloader 43 | 44 | 45 | def get_lazy_model_from_scratch(config, mup=True, 46 | readout_zero_init=True, query_zero_init=True, input_mult=1.0, width_mult_for_weights=1.0): 47 | def f(): 48 | config_in = config 49 | model = MupGPT2Model._from_config(config_in) 50 | if mup: 51 | model._init_all_weights_for_mup(readout_zero_init, query_zero_init, input_mult, width_mult_for_weights) 52 | return model 53 | 54 | return f 55 | 56 | 57 | def _get_linear_schedule_with_inverse_log_warmup_lr_lambda(current_step: int, 58 | *, num_warmup_steps: int, num_training_steps: int): 59 | if current_step < num_warmup_steps: 60 | return math.log(current_step + 1) * 1.0 / math.log(float(max(1, num_warmup_steps)) + 1e-7) 61 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 62 | 63 | def get_linear_schedule_with_inverse_log_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 64 | lr_lambda = partial( 65 | _get_linear_schedule_with_inverse_log_warmup_lr_lambda, 66 | num_warmup_steps=num_warmup_steps, 67 | num_training_steps=num_training_steps, 68 | ) 69 | return LambdaLR(optimizer, lr_lambda, last_epoch) 70 | -------------------------------------------------------------------------------- /modeling/utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear 2 | 3 | class MuReadout(Linear): 4 | '''Drop-in replacement for all output linear layers. 5 | 6 | An "output" linear layer is one that maps from a width dimension (e.g., 7 | `d_model` in a Transformer) to a non-width dimension (e.g., vocab size). 8 | 9 | This layer implements the version of μP with a 1/width multiplier and a 10 | constant variance initialization for both weights and biases. 
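
    Illustrative example (values are placeholders): with `width_mult=4.0` and
    `output_mult=1.0`, `forward(x)` computes `Linear.forward(1.0 * x / 4.0)`,
    i.e. the readout is scaled down by 1/width while its initialization
    variance is kept width-independent.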
11 | ''' 12 | def __init__(self, *args, readout_zero_init=False, output_mult=1.0, width_mult=1.0, **kwargs): 13 | self.output_mult = output_mult 14 | self.readout_zero_init = readout_zero_init 15 | self.width_mult_val = width_mult 16 | super().__init__(*args, **kwargs) 17 | 18 | def width_mult(self): 19 | return self.width_mult_val 20 | 21 | def reset_parameters(self) -> None: 22 | if self.readout_zero_init: 23 | self.weight.data[:] = 0 24 | if self.bias is not None: 25 | self.bias.data[:] = 0 26 | else: 27 | super().reset_parameters() 28 | 29 | def forward(self, x): 30 | return super().forward( 31 | self.output_mult * x / self.width_mult()) 32 | 33 | 34 | if __name__ == '__main__': 35 | pass 36 | -------------------------------------------------------------------------------- /mup_trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import shutil 4 | import sys 5 | import time 6 | import torch.distributed as dist 7 | from typing import Callable, Dict, List, Optional, Tuple, Union 8 | from tqdm.auto import tqdm 9 | from collections import defaultdict 10 | 11 | import torch 12 | from torch import nn 13 | from torch.utils.data import DataLoader, Dataset, RandomSampler 14 | from torch.utils.data.distributed import DistributedSampler 15 | 16 | from transformers import Trainer 17 | from transformers.trainer_callback import TrainerCallback 18 | from transformers.modeling_utils import PreTrainedModel 19 | from transformers.training_args import TrainingArguments 20 | from transformers.data.data_collator import DataCollator 21 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 22 | from transformers.training_args import TrainingArguments 23 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11 24 | from transformers.utils import ( 25 | is_apex_available, 26 | is_datasets_available, 27 | is_sagemaker_mp_enabled, 28 | is_torch_tpu_available, 29 | logging, 30 | ) 31 | from transformers.trainer_pt_utils import ( 32 | IterableDatasetShard, 33 | get_parameter_names, 34 | ) 35 | from transformers.trainer_utils import ( 36 | EvalPrediction, 37 | HPSearchBackend, 38 | ShardedDDPOption, 39 | TrainOutput, 40 | has_length, 41 | speed_metrics, 42 | ) 43 | from transformers.trainer_callback import ( 44 | TrainerCallback, 45 | TrainerState, 46 | ) 47 | from transformers.integrations import ( 48 | hp_params, 49 | ) 50 | from transformers.data.data_collator import DataCollator 51 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 52 | # for mup 53 | from modeling.deepspeed_mup import deepspeed_init, process_param_groups 54 | from transformers.optimization import get_scheduler 55 | from modeling.mup_utils import get_linear_schedule_with_inverse_log_warmup 56 | 57 | 58 | if is_datasets_available(): 59 | import datasets 60 | 61 | if is_torch_tpu_available(check_device=False): 62 | import torch_xla.core.xla_model as xm 63 | import torch_xla.debug.metrics as met 64 | import torch_xla.distributed.parallel_loader as pl 65 | 66 | if is_apex_available(): 67 | from apex import amp 68 | 69 | logger = logging.get_logger(__name__) 70 | 71 | # Name of the files used for checkpointing 72 | TRAINING_ARGS_NAME = "training_args.bin" 73 | TRAINER_STATE_NAME = "trainer_state.json" 74 | OPTIMIZER_NAME = "optimizer.pt" 75 | SCHEDULER_NAME = "scheduler.pt" 76 | SCALER_NAME = "scaler.pt" 77 | 78 | 79 | class MupTrainer(Trainer): 80 | def __init__( 81 | self, 82 | model: 
Union[PreTrainedModel, nn.Module] = None, 83 | args: TrainingArguments = None, 84 | data_collator: Optional[DataCollator] = None, 85 | train_dataset: Optional[Dataset] = None, 86 | eval_dataset: Optional[Dataset] = None, 87 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 88 | model_init: Callable[[], PreTrainedModel] = None, 89 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 90 | callbacks: Optional[List[TrainerCallback]] = None, 91 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 92 | preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, 93 | ): 94 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, 95 | optimizers, preprocess_logits_for_metrics) 96 | 97 | def _inner_training_loop( 98 | self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None 99 | ): 100 | self._train_batch_size = batch_size 101 | train_dataloader = self.get_train_dataloader() 102 | 103 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size 104 | 105 | len_dataloader = None 106 | if has_length(train_dataloader): 107 | len_dataloader = len(train_dataloader) 108 | num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps 109 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 110 | num_examples = self.num_examples(train_dataloader) 111 | if args.max_steps > 0: 112 | max_steps = args.max_steps 113 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( 114 | args.max_steps % num_update_steps_per_epoch > 0 115 | ) 116 | # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's 117 | # the best we can do. 118 | num_train_samples = args.max_steps * total_train_batch_size 119 | else: 120 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) 121 | num_train_epochs = math.ceil(args.num_train_epochs) 122 | num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs 123 | elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size 124 | max_steps = args.max_steps 125 | # Setting a very large number of epochs so we go as many times as necessary over the iterator. 126 | num_train_epochs = sys.maxsize 127 | num_update_steps_per_epoch = max_steps 128 | num_examples = total_train_batch_size * args.max_steps 129 | num_train_samples = args.max_steps * total_train_batch_size 130 | else: 131 | raise ValueError( 132 | "args.max_steps must be set to a positive value if dataloader does not have a length, was" 133 | f" {args.max_steps}" 134 | ) 135 | 136 | if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: 137 | if self.args.n_gpu > 1: 138 | # nn.DataParallel(model) replicates the model, creating new variables and module 139 | # references registered here no longer work on other gpus, breaking the module 140 | raise ValueError( 141 | "Currently --debug underflow_overflow is not supported under DP. Please use DDP" 142 | " (torch.distributed.launch)." 
143 | ) 144 | else: 145 | debug_overflow = DebugUnderflowOverflow(self.model) # noqa 146 | 147 | delay_optimizer_creation = ( 148 | self.sharded_ddp is not None 149 | and self.sharded_ddp != ShardedDDPOption.SIMPLE 150 | or is_sagemaker_mp_enabled() 151 | or self.fsdp is not None 152 | ) 153 | if args.deepspeed: 154 | deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( 155 | self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint 156 | ) 157 | self.model = deepspeed_engine.module 158 | self.model_wrapped = deepspeed_engine 159 | self.deepspeed = deepspeed_engine 160 | self.optimizer = optimizer 161 | self.lr_scheduler = lr_scheduler 162 | elif not delay_optimizer_creation: 163 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 164 | 165 | self.state = TrainerState() 166 | self.state.is_hyper_param_search = trial is not None 167 | 168 | # Activate gradient checkpointing if needed 169 | if args.gradient_checkpointing: 170 | self.model.gradient_checkpointing_enable() 171 | 172 | model = self._wrap_model(self.model_wrapped) 173 | 174 | if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: 175 | self._load_from_checkpoint(resume_from_checkpoint, model) 176 | 177 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 178 | if model is not self.model: 179 | self.model_wrapped = model 180 | 181 | if delay_optimizer_creation: 182 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 183 | 184 | # Check if saved optimizer or scheduler states exist 185 | self._load_optimizer_and_scheduler(resume_from_checkpoint) 186 | 187 | # important: at this point: 188 | # self.model is the Transformers Model 189 | # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. 190 | 191 | # Train! 192 | logger.info("***** Running training *****") 193 | logger.info(f" Num examples = {num_examples}") 194 | logger.info(f" Num Epochs = {num_train_epochs}") 195 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 196 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") 197 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 198 | logger.info(f" Total optimization steps = {max_steps}") 199 | logger.info( 200 | f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}" 201 | ) 202 | 203 | self.state.epoch = 0 204 | start_time = time.time() 205 | epochs_trained = 0 206 | steps_trained_in_current_epoch = 0 207 | steps_trained_progress_bar = None 208 | 209 | # Check if continuing training from a checkpoint 210 | if resume_from_checkpoint is not None and os.path.isfile( 211 | os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) 212 | ): 213 | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) 214 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 215 | if not args.ignore_data_skip: 216 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) 217 | steps_trained_in_current_epoch *= args.gradient_accumulation_steps 218 | else: 219 | steps_trained_in_current_epoch = 0 220 | 221 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 222 | logger.info(f" Continuing training from epoch {epochs_trained}") 223 | logger.info(f" Continuing training from global step {self.state.global_step}") 224 | if not args.ignore_data_skip: 225 | logger.info( 226 | f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " 227 | "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " 228 | "flag to your launch command, but you will resume the training on data already seen by your model." 229 | ) 230 | if self.is_local_process_zero() and not args.disable_tqdm: 231 | steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) 232 | steps_trained_progress_bar.set_description("Skipping the first batches") 233 | 234 | # Update the references 235 | self.callback_handler.model = self.model 236 | self.callback_handler.optimizer = self.optimizer 237 | self.callback_handler.lr_scheduler = self.lr_scheduler 238 | self.callback_handler.train_dataloader = train_dataloader 239 | if self.hp_name is not None and self._trial is not None: 240 | # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial 241 | # parameter to Train when using DDP. 242 | self.state.trial_name = self.hp_name(self._trial) 243 | if trial is not None: 244 | assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial 245 | self.state.trial_params = hp_params(assignments) 246 | else: 247 | self.state.trial_params = None 248 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 249 | # to set this after the load. 
250 | self.state.max_steps = max_steps 251 | self.state.num_train_epochs = num_train_epochs 252 | self.state.is_local_process_zero = self.is_local_process_zero() 253 | self.state.is_world_process_zero = self.is_world_process_zero() 254 | 255 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 256 | tr_loss = torch.tensor(0.0).to(args.device) 257 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 258 | self._total_loss_scalar = 0.0 259 | self._globalstep_last_logged = self.state.global_step 260 | model.zero_grad() 261 | 262 | self.control = self.callback_handler.on_train_begin(args, self.state, self.control) 263 | 264 | self.state.lm_batch_count = 0 265 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 266 | if not args.ignore_data_skip: 267 | for epoch in range(epochs_trained): 268 | is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( 269 | train_dataloader.sampler, RandomSampler 270 | ) 271 | if is_torch_less_than_1_11 or not is_random_sampler: 272 | # We just need to begin an iteration to create the randomization of the sampler. 273 | # That was before PyTorch 1.11 however... 274 | for _ in train_dataloader: 275 | break 276 | else: 277 | # Otherwise we need to call the whooooole sampler cause there is some random operation added 278 | # AT THE VERY END! 279 | _ = list(train_dataloader.sampler) 280 | 281 | for epoch in range(epochs_trained, num_train_epochs): 282 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 283 | train_dataloader.sampler.set_epoch(epoch) 284 | elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): 285 | train_dataloader.dataset.set_epoch(epoch) 286 | 287 | if is_torch_tpu_available(): 288 | parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) 289 | epoch_iterator = parallel_loader 290 | else: 291 | epoch_iterator = train_dataloader 292 | 293 | # Reset the past mems state at the beginning of each epoch if necessary. 
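            # (`past_index` defaults to -1 and is only set >= 0 for models that return "mems", e.g.
            # Transformer-XL/XLNet; for the GPT-2 models trained here it is normally left unused.)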
294 | if args.past_index >= 0: 295 | self._past = None 296 | 297 | steps_in_epoch = ( 298 | len(epoch_iterator) 299 | if len_dataloader is not None 300 | else args.max_steps * args.gradient_accumulation_steps 301 | ) 302 | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) 303 | 304 | if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: 305 | self._load_rng_state(resume_from_checkpoint) 306 | 307 | step = -1 308 | for step, inputs in enumerate(epoch_iterator): 309 | if self.args.exit_steps is not None and step > self.args.exit_steps: 310 | exit() 311 | # Skip past any already trained steps if resuming training 312 | if steps_trained_in_current_epoch > 0: 313 | steps_trained_in_current_epoch -= 1 314 | if steps_trained_progress_bar is not None: 315 | steps_trained_progress_bar.update(1) 316 | if steps_trained_in_current_epoch == 0: 317 | self._load_rng_state(resume_from_checkpoint) 318 | continue 319 | elif steps_trained_progress_bar is not None: 320 | steps_trained_progress_bar.close() 321 | steps_trained_progress_bar = None 322 | 323 | if step % args.gradient_accumulation_steps == 0: 324 | self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 325 | 326 | # LM 327 | if ( 328 | ((step + 1) % args.gradient_accumulation_steps != 0) 329 | and args.local_rank != -1 330 | and args._no_sync_in_gradient_accumulation 331 | ): 332 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. 333 | with model.no_sync(): 334 | tr_loss_step = self.training_step(model, inputs) 335 | else: 336 | tr_loss_step = self.training_step(model, inputs) 337 | 338 | 339 | # works only when deepspeed is off 340 | if (self.state.global_step + 1) % self.args.logging_steps == 0: 341 | for k, v in model.module.named_parameters(): 342 | # torch.abs(logits).mean(dtype=torch.float32).item() 343 | if v.grad is not None: 344 | self.grad_mean_scalar_dict[k] = torch.abs(v.grad).mean(dtype=torch.float32).item() 345 | 346 | if ( 347 | args.logging_nan_inf_filter 348 | and not is_torch_tpu_available() 349 | and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) 350 | ): 351 | # if loss is nan or inf simply add the average of previous logged losses 352 | tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) 353 | else: 354 | tr_loss += tr_loss_step 355 | 356 | self.current_flos += float(self.floating_point_ops(inputs)) 357 | 358 | 359 | # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 360 | if self.deepspeed: 361 | self.deepspeed.step() 362 | 363 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 364 | # last step in epoch but step is always smaller than gradient_accumulation_steps 365 | steps_in_epoch <= args.gradient_accumulation_steps 366 | and (step + 1) == steps_in_epoch 367 | ): 368 | # Gradient clipping 369 | if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: 370 | # deepspeed does its own clipping 371 | 372 | if self.do_grad_scaling: 373 | # Reduce gradients first for XLA 374 | if is_torch_tpu_available(): 375 | gradients = xm._fetch_gradients(self.optimizer) 376 | xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) 377 | # AMP: gradients need unscaling 378 | self.scaler.unscale_(self.optimizer) 379 | 380 | if is_sagemaker_mp_enabled() and args.fp16: 381 | self.optimizer.clip_master_grads(args.max_grad_norm) 382 | elif 
hasattr(self.optimizer, "clip_grad_norm"): 383 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 384 | self.optimizer.clip_grad_norm(args.max_grad_norm) 385 | elif hasattr(model, "clip_grad_norm_"): 386 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 387 | model.clip_grad_norm_(args.max_grad_norm) 388 | else: 389 | # Revert to normal clipping otherwise, handling Apex or full precision 390 | nn.utils.clip_grad_norm_( 391 | amp.master_params(self.optimizer) if self.use_apex else model.parameters(), 392 | args.max_grad_norm, 393 | ) 394 | 395 | # Optimizer step 396 | optimizer_was_run = True 397 | if self.deepspeed: 398 | pass # called outside the loop 399 | elif is_torch_tpu_available(): 400 | if self.do_grad_scaling: 401 | self.scaler.step(self.optimizer) 402 | self.scaler.update() 403 | else: 404 | xm.optimizer_step(self.optimizer) 405 | elif self.do_grad_scaling: 406 | scale_before = self.scaler.get_scale() 407 | self.scaler.step(self.optimizer) 408 | self.scaler.update() 409 | scale_after = self.scaler.get_scale() 410 | optimizer_was_run = scale_before <= scale_after 411 | else: 412 | self.optimizer.step() 413 | 414 | if optimizer_was_run and not self.deepspeed: 415 | self.lr_scheduler.step() 416 | 417 | model.zero_grad() 418 | self.state.global_step += 1 419 | self.state.epoch = epoch + (step + 1) / steps_in_epoch 420 | self.control = self.callback_handler.on_step_end(args, self.state, self.control) 421 | 422 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) 423 | else: 424 | self.control = self.callback_handler.on_substep_end(args, self.state, self.control) 425 | 426 | if self.control.should_epoch_stop or self.control.should_training_stop: 427 | break 428 | if step < 0: 429 | logger.warning( 430 | "There seems to be not a single sample in your epoch_iterator, stopping training at step" 431 | f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" 432 | f" num_steps ({max_steps}) higher than the number of available samples." 433 | ) 434 | self.control.should_training_stop = True 435 | 436 | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) 437 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) 438 | 439 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug: 440 | if is_torch_tpu_available(): 441 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 442 | xm.master_print(met.metrics_report()) 443 | else: 444 | logger.warning( 445 | "You enabled PyTorch/XLA debug metrics but you don't have a TPU " 446 | "configured. Check your training configuration if this is unexpected." 447 | ) 448 | if self.control.should_training_stop: 449 | break 450 | 451 | if args.past_index and hasattr(self, "_past"): 452 | # Clean the state at the end of training 453 | delattr(self, "_past") 454 | 455 | logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") 456 | if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: 457 | # Wait for everyone to get here so we are sur the model has been saved by process 0. 
458 | if is_torch_tpu_available(): 459 | xm.rendezvous("load_best_model_at_end") 460 | elif args.local_rank != -1: 461 | dist.barrier() 462 | elif is_sagemaker_mp_enabled(): 463 | smp.barrier() 464 | 465 | self._load_best_model() 466 | 467 | # add remaining tr_loss 468 | self._total_loss_scalar += tr_loss.item() 469 | train_loss = self._total_loss_scalar / self.state.global_step 470 | 471 | metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) 472 | self.store_flos() 473 | metrics["total_flos"] = self.state.total_flos 474 | metrics["train_loss"] = train_loss 475 | 476 | self.is_in_train = False 477 | 478 | self._memory_tracker.stop_and_update_metrics(metrics) 479 | 480 | self.log(metrics) 481 | 482 | run_dir = self._get_output_dir(trial) 483 | checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) 484 | 485 | # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint. 486 | if self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: 487 | for checkpoint in checkpoints_sorted: 488 | if checkpoint != self.state.best_model_checkpoint: 489 | logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") 490 | shutil.rmtree(checkpoint) 491 | 492 | self.control = self.callback_handler.on_train_end(args, self.state, self.control) 493 | 494 | return TrainOutput(self.state.global_step, train_loss, metrics) 495 | 496 | 497 | ############ added for Mup grad check without deepspeed. ############### 498 | def create_optimizer(self): 499 | """ 500 | This is triggered if deepspeed is off. 501 | Mup re-grouping of optimizer parameters. 502 | """ 503 | opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model 504 | 505 | if self.optimizer is None: 506 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 507 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 508 | optimizer_grouped_parameters = [ 509 | { 510 | "params": [ 511 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 512 | ], 513 | "weight_decay": self.args.weight_decay, 514 | }, 515 | { 516 | "params": [ 517 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 518 | ], 519 | "weight_decay": 0.0, 520 | }, 521 | ] 522 | 523 | print("use_mup:") 524 | print(self.args.use_mup) 525 | 526 | if self.args.use_mup: 527 | new_param_groups = [] 528 | decoupled_wd = False 529 | for param_group in process_param_groups(optimizer_grouped_parameters, lr=self.args.learning_rate): 530 | # For every existing param group, we split into several new groups 531 | def new_group(): 532 | new_g = {k:v for k, v in param_group.items() if k != 'params'} 533 | new_g['params'] = [] 534 | return new_g 535 | # The matrix-like weights might need multiple groups since weights 536 | # might have different width multipliers 537 | matrix_like_p = defaultdict(new_group) # key is width_mult 538 | vector_like_p = new_group() 539 | for p in param_group['params']: 540 | assert hasattr(p, 'infshape'), ( 541 | f'A parameter with shape {p.shape} does not have `infshape` attribute. 
' 542 | 'Did you forget to call `mup.set_base_shapes` on the model?') 543 | if p.infshape.ninf() == 2: 544 | matrix_like_p[p.infshape.width_mult()]['params'].append(p) 545 | elif p.infshape.ninf() > 2: 546 | raise NotImplementedError('more than 2 inf dimensions') 547 | else: 548 | vector_like_p['params'].append(p) 549 | 550 | for width_mult, group in matrix_like_p.items(): 551 | # Scale learning rate and weight decay accordingly 552 | print(width_mult) 553 | group['lr'] /= width_mult 554 | if not decoupled_wd: 555 | group['weight_decay'] *= width_mult 556 | 557 | new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p]) 558 | 559 | optimizer_grouped_parameters = new_param_groups 560 | 561 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 562 | 563 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 564 | self.optimizer = OSS( 565 | params=optimizer_grouped_parameters, 566 | optim=optimizer_cls, 567 | **optimizer_kwargs, 568 | ) 569 | else: 570 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 571 | if optimizer_cls.__name__ == "Adam8bit": 572 | import bitsandbytes 573 | 574 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 575 | 576 | skipped = 0 577 | for module in opt_model.modules(): 578 | if isinstance(module, nn.Embedding): 579 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 580 | print(f"skipped {module}: {skipped/2**20}M params") 581 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 582 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 583 | print(f"skipped: {skipped/2**20}M params") 584 | 585 | if is_sagemaker_mp_enabled(): 586 | self.optimizer = smp.DistributedOptimizer(self.optimizer) 587 | 588 | return self.optimizer 589 | 590 | 591 | def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): 592 | """ 593 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or 594 | passed as an argument. 595 | Args: 596 | num_training_steps (int): The number of training steps to do. 
597 | """ 598 | if self.lr_scheduler is None: 599 | if self.args.log_warmup: 600 | self.lr_scheduler = get_linear_schedule_with_inverse_log_warmup( 601 | optimizer=self.optimizer if optimizer is None else optimizer, 602 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps), 603 | num_training_steps=num_training_steps, 604 | ) 605 | else: 606 | self.lr_scheduler = get_scheduler( 607 | self.args.lr_scheduler_type, 608 | optimizer=self.optimizer if optimizer is None else optimizer, 609 | num_warmup_steps=self.args.get_warmup_steps(num_training_steps), 610 | num_training_steps=num_training_steps, 611 | ) 612 | return self.lr_scheduler 613 | 614 | 615 | def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): 616 | if self.control.should_log: 617 | if is_torch_tpu_available(): 618 | xm.mark_step() 619 | 620 | logs: Dict[str, float] = {} 621 | 622 | # all_gather + mean() to get average loss over all processes 623 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item() 624 | 625 | # reset tr_loss to zero 626 | tr_loss -= tr_loss 627 | 628 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 629 | logs["learning_rate"] = self._get_learning_rate() 630 | 631 | self._total_loss_scalar += tr_loss_scalar 632 | self._globalstep_last_logged = self.state.global_step 633 | self.store_flos() 634 | 635 | self.log(logs) 636 | 637 | metrics = None 638 | if self.control.should_evaluate: 639 | if isinstance(self.eval_dataset, dict): 640 | for eval_dataset_name, eval_dataset in self.eval_dataset.items(): 641 | metrics = self.evaluate( 642 | eval_dataset=eval_dataset, 643 | ignore_keys=ignore_keys_for_eval, 644 | metric_key_prefix=f"eval_{eval_dataset_name}", 645 | ) 646 | else: 647 | metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) 648 | self._report_to_hp_search(trial, self.state.global_step, metrics) 649 | 650 | if self.control.should_save: 651 | self._save_checkpoint(model, trial, metrics=metrics) 652 | self.control = self.callback_handler.on_save(self.args, self.state, self.control) 653 | 654 | def log(self, logs: Dict[str, float]) -> None: 655 | """ 656 | Log `logs` on the various objects watching training. 657 | 658 | Subclass and override this method to inject custom behavior. 659 | 660 | Args: 661 | logs (`Dict[str, float]`): 662 | The values to log. 
663 | """ 664 | if self.state.epoch is not None: 665 | logs["epoch"] = round(self.state.epoch, 2) 666 | 667 | output = {**logs, **{"step": self.state.global_step}} 668 | self.state.log_history.append(output) 669 | self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.3 2 | aiosignal==1.3.1 3 | anykeystore==0.2 4 | async-timeout==4.0.2 5 | attrs==22.1.0 6 | charset-normalizer==2.1.1 7 | cryptacular==1.6.2 8 | datasets==2.7.1 9 | defusedxml==0.7.1 10 | dill==0.3.6 11 | evaluate==0.4.0 12 | filelock==3.8.2 13 | frozenlist==1.3.3 14 | fsspec==2022.11.0 15 | greenlet==2.0.1 16 | huggingface-hub==0.11.1 17 | hupper==1.10.3 18 | idna==3.4 19 | install==1.3.5 20 | joblib==1.2.0 21 | MarkupSafe==2.1.1 22 | multidict==6.0.3 23 | multiprocess==0.70.14 24 | numpy==1.23.5 25 | nvidia-cublas-cu11==11.10.3.66 26 | nvidia-cuda-nvrtc-cu11==11.7.99 27 | nvidia-cuda-runtime-cu11==11.7.99 28 | nvidia-cudnn-cu11==8.5.0.96 29 | oauthlib==3.2.2 30 | packaging==22.0 31 | pandas==1.5.2 32 | PasteDeploy==3.0.1 33 | pbkdf2==1.3 34 | plaster==1.1.2 35 | plaster-pastedeploy==1.0.1 36 | protobuf==4.21.11 37 | psutil==5.9.4 38 | pyarrow==10.0.1 39 | pynvml==11.4.1 40 | pyramid==2.0 41 | python-dateutil==2.8.2 42 | python3-openid==3.2.0 43 | pytz==2022.6 44 | PyYAML==6.0 45 | regex==2022.10.31 46 | repoze.sendmail==4.4.1 47 | requests==2.28.1 48 | requests-oauthlib==1.3.1 49 | responses==0.18.0 50 | scikit-learn==1.2.0 51 | scipy==1.9.3 52 | sentencepiece==0.1.97 53 | six==1.16.0 54 | SQLAlchemy==1.4.45 55 | threadpoolctl==3.1.0 56 | tokenizers==0.13.2 57 | torch==1.13.0 58 | tqdm==4.64.1 59 | transaction==3.0.1 60 | transformers==4.25.1 61 | translationstring==1.4 62 | typing_extensions==4.4.0 63 | urllib3==1.26.13 64 | venusian==3.0.0 65 | WebOb==1.8.7 66 | WTForms==3.0.1 67 | wtforms-recaptcha==0.3.2 68 | xxhash==3.1.0 69 | yarl==1.8.2 70 | zope.deprecation==4.4.0 71 | zope.interface==5.5.2 72 | zope.sqlalchemy==1.6 73 | -------------------------------------------------------------------------------- /res/final_data/test/current_data_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "lm_dataset_name": "wikitext|wikitext-2-raw-v1;glue|mrpc|sentence1", 3 | "lm_disk_dataset_dir": null, 4 | "lm_train_file": null, 5 | "lm_train_dir": "./test_data/owt2|json|text", 6 | "mt_train_file": "./test_data/mt/mt_test1.txt;./test_data/mt/mt_test2.txt", 7 | "mt_train_dir": null, 8 | "model_name_or_path": "gpt2", 9 | "final_train_dir": "./res/final_data/test", 10 | "final_lm_dir": "/home/yaoyiqun/llm_mup/nlm_dev/res/final_data/test/lm", 11 | "final_mt_dir": "/home/yaoyiqun/llm_mup/nlm_dev/res/final_data/test/mt", 12 | "final_tokenize_dir": "/home/yaoyiqun/llm_mup/nlm_dev/res/final_data/test/tokenizer", 13 | "block_size": 1024, 14 | "preprocessing_num_workers": 10, 15 | "cache_dir": null, 16 | "min_text_length": 5, 17 | "do_lm_process": true, 18 | "do_mt_process": true 19 | } -------------------------------------------------------------------------------- /res/final_data/test/lm/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cofe-ai/Mu-scaling/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/res/final_data/test/lm/dataset.arrow 
-------------------------------------------------------------------------------- /res/final_data/test/lm/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "@misc{merity2016pointer,\n title={Pointer Sentinel Mixture Models},\n author={Stephen Merity and Caiming Xiong and James Bradbury and Richard Socher},\n year={2016},\n eprint={1609.07843},\n archivePrefix={arXiv},\n primaryClass={cs.CL}\n}\n\n\n@inproceedings{dolan2005automatically,\n title={Automatically constructing a corpus of sentential paraphrases},\n author={Dolan, William B and Brockett, Chris},\n booktitle={Proceedings of the Third International Workshop on Paraphrasing (IWP2005)},\n year={2005}\n}\n@inproceedings{wang2019glue,\n title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},\n author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},\n note={In the Proceedings of ICLR.},\n year={2019}\n}", 3 | "description": "The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified\n Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike\n License.\n\n\nGLUE, the General Language Understanding Evaluation benchmark\n(https://gluebenchmark.com/) is a collection of resources for training,\nevaluating, and analyzing natural language understanding systems.", 4 | "features": { 5 | "input_ids": { 6 | "feature": { 7 | "dtype": "int32", 8 | "_type": "Value" 9 | }, 10 | "_type": "Sequence" 11 | }, 12 | "attention_mask": { 13 | "feature": { 14 | "dtype": "int8", 15 | "_type": "Value" 16 | }, 17 | "_type": "Sequence" 18 | }, 19 | "labels": { 20 | "feature": { 21 | "dtype": "int64", 22 | "_type": "Value" 23 | }, 24 | "_type": "Sequence" 25 | } 26 | }, 27 | "homepage": "https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/\n\nhttps://www.microsoft.com/en-us/download/details.aspx?id=52398", 28 | "license": "Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)" 29 | } -------------------------------------------------------------------------------- /res/final_data/test/lm/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "ceb1db52bc48c783", 8 | "_format_columns": [ 9 | "attention_mask", 10 | "input_ids", 11 | "labels" 12 | ], 13 | "_format_kwargs": {}, 14 | "_format_type": null, 15 | "_indexes": {}, 16 | "_output_all_columns": false, 17 | "_split": null 18 | } -------------------------------------------------------------------------------- /res/final_data/test/mt/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cofe-ai/Mu-scaling/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/res/final_data/test/mt/dataset.arrow -------------------------------------------------------------------------------- /res/final_data/test/mt/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "tid": { 6 | "dtype": "int64", 7 | "_type": "Value" 8 | }, 9 | "label": { 10 | "dtype": "int64", 11 | "_type": "Value" 12 | }, 13 | "tmplt_id": { 14 | "dtype": "int64", 15 | "_type": "Value" 16 | }, 17 | "input_ids": { 18 | 
"feature": { 19 | "dtype": "int32", 20 | "_type": "Value" 21 | }, 22 | "_type": "Sequence" 23 | }, 24 | "attention_mask": { 25 | "feature": { 26 | "dtype": "int8", 27 | "_type": "Value" 28 | }, 29 | "_type": "Sequence" 30 | }, 31 | "length": { 32 | "dtype": "int64", 33 | "_type": "Value" 34 | }, 35 | "split": { 36 | "dtype": "string", 37 | "_type": "Value" 38 | }, 39 | "opt_count": { 40 | "dtype": "int64", 41 | "_type": "Value" 42 | } 43 | }, 44 | "homepage": "", 45 | "license": "" 46 | } -------------------------------------------------------------------------------- /res/final_data/test/mt/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "a2286959a5a699c3", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /res/final_data/test/tokenizer/added_tokens.json: -------------------------------------------------------------------------------- 1 | { 2 | "[cls]": 50259, 3 | "[pad]": 50257, 4 | "[sep]": 50258, 5 | "[tsk]": 50260 6 | } 7 | -------------------------------------------------------------------------------- /res/final_data/test/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "additional_special_tokens": [ 3 | "[tsk]" 4 | ], 5 | "bos_token": { 6 | "content": "<|endoftext|>", 7 | "lstrip": false, 8 | "normalized": true, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "cls_token": "[cls]", 13 | "eos_token": { 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "pad_token": "[pad]", 21 | "sep_token": "[sep]", 22 | "unk_token": { 23 | "content": "<|endoftext|>", 24 | "lstrip": false, 25 | "normalized": true, 26 | "rstrip": false, 27 | "single_word": false 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /res/final_data/test/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": false, 3 | "add_prefix_space": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "<|endoftext|>", 7 | "lstrip": false, 8 | "normalized": true, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 1024, 22 | "name_or_path": "gpt2", 23 | "pad_token": null, 24 | "special_tokens_map_file": null, 25 | "tokenizer_class": "GPT2Tokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /run_eval_ppl_loss_pred.sh: -------------------------------------------------------------------------------- 1 | # WikiText2 : wikitext & wikitext-2-raw-v1 2 | CUDA_VISIBLE_DEVICES=1 3 | #export HF_DATASETS_OFFLINE=1 4 | params="0.06_2e-3_4" 5 | for width in 128 256 384 1024 6 | do 7 | python run_eval_ppl_mup.py \ 8 | 
--cache_dir /your/huggingface/cache/dir \ 9 | --is_ours \ 10 | --dataset_path wikitext \ 11 | --dataset_name wikitext-2-raw-v1 \ 12 | --model_name_or_path res/output/test_standard_mup_loss_pred_20k/${params}/width_${width}/checkpoint-20000 \ 13 | > logs/eval/20k_${params}_${width}.txt 2>&1 14 | done -------------------------------------------------------------------------------- /run_eval_ppl_mup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.nn import CrossEntropyLoss 4 | from datasets import load_dataset, DatasetDict 5 | from tqdm import tqdm 6 | from transformers import (AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer) 7 | from modeling.lm_mup import MupGPT2Model 8 | from mup import set_base_shapes 9 | 10 | parser = argparse.ArgumentParser() 11 | # Load Model 12 | parser.add_argument("--model_name_or_path", type=str, default='gpt2') 13 | parser.add_argument("--is_ours", default=False, action='store_true') 14 | 15 | # Dataset 16 | parser.add_argument("--dataset_path", type=str, default=None) 17 | parser.add_argument("--dataset_name", type=str, default=None) 18 | parser.add_argument("--dataset_split_name", type=str, default=None) 19 | parser.add_argument("--dataset_feature_name", type=str, default='text') 20 | parser.add_argument("--is_disk_data", default=False, action='store_true') 21 | 22 | # Others 23 | parser.add_argument("--stride", type=int, default=None) 24 | parser.add_argument("--cache_dir", type=str, default=None) 25 | parser.add_argument("--data_dir", type=str, default=None) 26 | params = parser.parse_args() 27 | 28 | device = "cuda" if torch.cuda.is_available() else "cpu" 29 | 30 | class cheat_infshape(object): 31 | def __init__(self, width_mult): 32 | self.width_mult_ = width_mult 33 | def width_mult(self): 34 | return self.width_mult_ 35 | 36 | 37 | # Load Model 38 | model = None 39 | tokenizer = None 40 | if params.is_ours: 41 | model = MupGPT2Model.from_pretrained(params.model_name_or_path).to(device) 42 | model.transformer.input_mult = model.config.output_mult 43 | model.lm_head.weight.infshape = cheat_infshape(model.config.n_embd / 256) 44 | print(f"self.lm_head.output_mult:{model.lm_head.output_mult}") 45 | print(f"self.transformer.input_mult:{model.transformer.input_mult}") 46 | print(f"self.lm.width_mult:{model.lm_head.width_mult()}") 47 | tokenizer = GPT2Tokenizer.from_pretrained(params.model_name_or_path) 48 | 49 | else: 50 | model = AutoModelForCausalLM.from_pretrained(params.model_name_or_path, cache_dir="/share/project/lixiang/cache").to(device) 51 | tokenizer = AutoTokenizer.from_pretrained(params.model_name_or_path, cache_dir="/share/project/lixiang/cache") 52 | assert model is not None and tokenizer is not None 53 | 54 | # Load Data 55 | raw_datasets = None 56 | if params.is_disk_data: 57 | raw_datasets = DatasetDict.load_from_disk(params.dataset_path) 58 | else: 59 | raw_datasets = load_dataset(params.dataset_path, params.dataset_name, 60 | cache_dir=params.cache_dir, 61 | data_dir=params.cache_dir 62 | ) 63 | assert raw_datasets is not None 64 | 65 | # Preprocessing 66 | target_dataset = None 67 | if params.dataset_split_name is not None: 68 | target_dataset = params.dataset_split_name 69 | elif "test" in raw_datasets.keys(): 70 | target_dataset = "test" 71 | elif "validation" in raw_datasets.keys(): 72 | target_dataset = "validation" 73 | elif "train" in raw_datasets.keys(): 74 | target_dataset = "train" 75 | assert target_dataset is not None 76 | target_dataset = 
"train" 77 | print(f'dataset: {params.dataset_path} - {params.dataset_name} - {target_dataset}') 78 | dataset_test = raw_datasets[target_dataset] 79 | encodings = tokenizer("\n\n".join(dataset_test[params.dataset_feature_name]), return_tensors="pt") 80 | 81 | # Eval 82 | ignore_index = CrossEntropyLoss().ignore_index 83 | if params.model_name_or_path.startswith('xlnet'): 84 | max_length = 1024 85 | else: 86 | max_length = model.config.n_positions 87 | # stride = max_length // 2 88 | stride = params.stride if params.stride is not None else max_length 89 | seq_len = encodings.input_ids.size(1) 90 | 91 | nlls = [] 92 | prev_end_loc = 0 93 | for begin_loc in tqdm(range(0, seq_len, stride)): 94 | end_loc = min(begin_loc + max_length, seq_len) 95 | trg_len = end_loc - prev_end_loc # may be different from stride on last loop 96 | input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) 97 | attention_mask = encodings.attention_mask[:, begin_loc:end_loc].to(device) 98 | target_ids = input_ids.clone() 99 | target_ids[:, :-trg_len] = ignore_index 100 | 101 | with torch.no_grad(): 102 | outputs = model(input_ids, attention_mask=attention_mask, labels=target_ids) 103 | neg_log_likelihood = outputs.loss * trg_len 104 | 105 | nlls.append(neg_log_likelihood) 106 | 107 | prev_end_loc = end_loc 108 | if end_loc == seq_len: 109 | break 110 | 111 | # ppl = torch.exp(torch.stack(nlls).sum() / seq_len) 112 | ppl = torch.stack(nlls).sum() / seq_len 113 | n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) 114 | print(f"model = {n_params / 2 ** 20:.2f}M params; ppl = {ppl}") 115 | -------------------------------------------------------------------------------- /run_grid_search_pair_wise_mup.sh: -------------------------------------------------------------------------------- 1 | warmup_ratio=0.01 2 | exit_steps=20000 3 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7, 4 | for output_mult in 5 5 | do 6 | for lr in 5e-5 1e-4 3e-4 1e-3 3e-3 7 | do 8 | for initializer_range in 0.05 9 | do 10 | for hp_tune_actual_width in 256 512 11 | do 12 | python -m torch.distributed.launch --nproc_per_node=8 \ 13 | --nnodes=1 \ 14 | run_train_gpt_mup_from_scratch.py \ 15 | --model_name_or_path gpt2 \ 16 | --model_load_pretrained False \ 17 | --config_name ./configs/gpt_2_L_6 \ 18 | --output_dir ./res/output/pair_wise_L6/${lr}_${output_mult}_${initializer_range}/width_${hp_tune_actual_width} \ 19 | --final_train_dir "/path/to/your/data" \ 20 | --overwrite_output_dir \ 21 | --num_train_epochs 0.1 \ 22 | --per_device_train_batch_size 6 \ 23 | --warmup_ratio ${warmup_ratio} \ 24 | --ddp_timeout 1000000 \ 25 | --logging_steps 100 \ 26 | --save_steps 5000 \ 27 | --save_total_limit 20 \ 28 | --learning_rate ${lr} \ 29 | --hp_tune_base_width 256 \ 30 | --size_per_head 64 \ 31 | --hp_tune_actual_width ${hp_tune_actual_width} \ 32 | --output_mult ${output_mult} \ 33 | --initializer_range ${initializer_range} \ 34 | --log_warmup True \ 35 | --unified_dropout 0.0 \ 36 | --exit_steps ${exit_steps} \ 37 | --deepspeed ./deepspeed_configs/zero_stage1_config.json \ 38 | > logs/pair_wise_L6_test_release_${lr}_${output_mult}_${initializer_range}_width_${hp_tune_actual_width}.txt 2>&1 39 | 40 | sleep 60 41 | done 42 | done 43 | done 44 | done -------------------------------------------------------------------------------- /run_train_gpt_mup_from_scratch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from dataclasses import 
dataclass, field
5 | from typing import Optional
6 | 
7 | import datasets
8 | from transformers.deepspeed import is_deepspeed_zero3_enabled
9 | 
10 | import transformers
11 | from transformers import (
12 |     HfArgumentParser,
13 |     TrainingArguments,
14 |     default_data_collator,
15 |     set_seed,
16 |     GPT2Tokenizer
17 | )
18 | from transformers.trainer_utils import get_last_checkpoint
19 | 
20 | from transformers import GPT2Config
21 | from modeling.initialize_with_mup import mup_init_from_scratch
22 | from mup_trainer import MupTrainer
23 | from utils import concat_path
24 | 
25 | logger = logging.getLogger(__name__)
26 | 
27 | 
28 | @dataclass
29 | class ModelArguments:
30 |     """
31 |     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
32 |     """
33 | 
34 |     model_name_or_path: Optional[str] = field(
35 |         default=None,
36 |         metadata={
37 |             "help": (
38 |                 "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
39 |             )
40 |         },
41 |     )
42 |     config_name: Optional[str] = field(
43 |         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
44 |     )
45 |     cache_dir: Optional[str] = field(
46 |         default=None,
47 |         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
48 |     )
49 |     use_fast_tokenizer: bool = field(
50 |         default=True,
51 |         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
52 |     )
53 |     model_revision: str = field(
54 |         default="main",
55 |         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
56 |     )
57 |     model_load_pretrained: bool = field(
58 |         default=False,
59 |         metadata={"help": "Whether to load checkpoints even if `model_name_or_path` is provided"},
60 |     )
61 | 
62 |     def __post_init__(self):
63 |         if self.config_name is None and self.model_name_or_path is None:
64 |             raise ValueError(
65 |                 "--config_name or --model_name_or_path must be set"
66 |             )
67 | 
68 | 
69 | @dataclass
70 | class DataTrainingArguments:
71 |     final_train_dir: Optional[str] = field(default=None, metadata={
72 |         "help": "training data path"
73 |     })
74 |     final_lm_dir: Optional[str] = field(default=None, metadata={
75 |         "help": "path to language model data"
76 |     })
77 |     final_tokenize_dir: Optional[str] = field(default=None, metadata={
78 |         "help": "tokenizer path"
79 |     })
80 |     max_lm_train_samples: Optional[int] = field(default=None, metadata={
81 |         "help": "maximum number of LM training samples to use"
82 |     })
83 | 
84 |     def __post_init__(self):
85 |         if self.final_train_dir is not None:
86 |             if self.final_lm_dir is None:
87 |                 self.final_lm_dir = concat_path(self.final_train_dir, 'lm')
88 | 
89 |             if self.final_tokenize_dir is None:
90 |                 self.final_tokenize_dir = concat_path(self.final_train_dir, 'tokenizer')
91 | 
92 | 
93 | @dataclass
94 | class MyTrainingArguments(TrainingArguments):
95 |     hp_tune_base_width: Optional[int] = field(default=256, metadata={"help": "muP base width; parameterization is rescaled relative to this width"})
96 |     size_per_head: Optional[int] = field(default=128, metadata={"help": "width of each attention head; kept fixed during muP rescaling by default"})
97 |     hp_tune_actual_width: Optional[int] = field(default=768, metadata={"help": "actual model width whose hyperparameters are tuned under muP"})
98 |     output_mult: Optional[float] = field(default=1.0, metadata={"help": "output-layer multiplier, a tunable hyperparameter; in the current scheme the checkpoint's vocab (readout) logits are divided by this value."})
99 |     initializer_range: Optional[float] = field(default=0.02, metadata={"help": "initialization standard deviation; overrides the config"})
100 |     log_warmup: Optional[bool] = field(default=False, metadata={"help": "whether to use log warmup when DeepSpeed is off"})
101 |     unified_dropout: Optional[float] = field(default=None, metadata={"help": "if not None, set every dropout layer to this value; mainly used to zero out dropout"})
102 |     use_mup: Optional[bool] = field(default=True, metadata={"help": "muP switch, on by default; turn it off manually to run baseline (standard-parameterization) experiments"})
103 |     exit_steps: Optional[int] = field(default=None, metadata={"help": "manually set the step at which training exits"})
104 |     readout_zero_init: Optional[bool] = field(default=True, metadata={"help": "whether the vocab readout layer is zero-initialized"})
105 |     query_zero_init: Optional[bool] = field(default=True, metadata={"help": "whether the query matrix is zero-initialized"})
106 |     is_training_ckpt_self: Optional[bool] = field(default=False, metadata={"help": "whether this is a full training run resumed from a checkpoint of the same size"})
107 | 
108 | 
109 | def main():
110 | 
111 |     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments))
112 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
113 |         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
114 |     else:
115 |         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
116 | 
117 |     # Setup logging
118 |     logging.basicConfig(
119 |         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
120 |         datefmt="%m/%d/%Y %H:%M:%S",
121 |         handlers=[logging.StreamHandler(sys.stdout)],
122 |     )
123 | 
124 |     log_level = training_args.get_process_log_level()
125 |     logger.setLevel(log_level)
126 |     datasets.utils.logging.set_verbosity(log_level)
127 |     transformers.utils.logging.set_verbosity(log_level)
128 |     transformers.utils.logging.enable_default_handler()
129 |     transformers.utils.logging.enable_explicit_format()
130 | 
131 |     # Log on each process the small summary:
132 |     logger.warning(
133 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
134 |         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
135 |     )
136 |     logger.info(f"Training/evaluation parameters {training_args}")
137 | 
138 |     # Detecting last checkpoint.
139 |     last_checkpoint = None
140 |     if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
141 |         last_checkpoint = get_last_checkpoint(training_args.output_dir)
142 |         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
143 |             raise ValueError(
144 |                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
145 |                 "Use --overwrite_output_dir to overcome."
146 |             )
147 |         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
148 |             logger.info(
149 |                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
150 |                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
151 |             )
152 | 
153 |     # Set seed before initializing model.
154 | set_seed(training_args.seed) 155 | 156 | # **************************************************************************************************** 157 | # tokenizer 158 | # **************************************************************************************************** 159 | tokenizer = None 160 | # look for data directory first, then configuration 161 | if data_args.final_tokenize_dir is not None: 162 | logger.info(f'loading tokenizer from PREPROCESSED DATA: {data_args.final_tokenize_dir}') 163 | tokenizer = GPT2Tokenizer.from_pretrained(data_args.final_tokenize_dir, cache_dir=model_args.cache_dir) 164 | elif model_args.config_name is not None: 165 | logger.info(f'loading tokenizer from CONFIG: {model_args.config_name}') 166 | tokenizer = GPT2Tokenizer.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) 167 | else: 168 | raise ValueError( 169 | f"FAILED to load tokenizer, please provide one of --final_tokenize_dir, --config_name" 170 | ) 171 | 172 | # **************************************************************************************************** 173 | # Load Config 174 | # **************************************************************************************************** 175 | config_kwargs = { 176 | "cache_dir": model_args.cache_dir, 177 | "revision": model_args.model_revision 178 | } 179 | if model_args.config_name: 180 | try: 181 | config = GPT2Config.from_pretrained(model_args.config_name, **config_kwargs) 182 | except OSError as e: 183 | if model_args.model_name_or_path: 184 | logger.warning( 185 | f"failed to load from {model_args.config_name}. " 186 | f"constructing with: {model_args.model_name_or_path} " 187 | ) 188 | config_kwargs['vocab_size'] = len(tokenizer) 189 | config_kwargs['pad_token_id'] = tokenizer.pad_token_id 190 | config_kwargs['cls_token_id'] = tokenizer.cls_token_id 191 | config_kwargs['sep_token_id'] = tokenizer.sep_token_id 192 | config = GPT2Config.from_pretrained(model_args.model_name_or_path, **config_kwargs) 193 | 194 | config.save_pretrained(model_args.config_name) 195 | else: 196 | raise ValueError( 197 | f"Must provide one of: --model_name_or_path, --num_multi_task_labels" 198 | ) 199 | 200 | elif model_args.model_name_or_path : 201 | config_kwargs['vocab_size'] = len(tokenizer) 202 | config_kwargs['pad_token_id'] = tokenizer.pad_token_id 203 | config_kwargs['cls_token_id'] = tokenizer.cls_token_id 204 | config_kwargs['sep_token_id'] = tokenizer.sep_token_id 205 | config = GPT2Config.from_pretrained(model_args.model_name_or_path, **config_kwargs) 206 | else: 207 | raise ValueError( 208 | f"--config_name, --model_name_or_path, --num_multi_task_labels" 209 | ) 210 | 211 | # **************************************************************************************************** 212 | # Model Init 213 | # **************************************************************************************************** 214 | 215 | ### Initialize Model with Mup ### 216 | training_args.width_mult_for_weights = (float(training_args.hp_tune_actual_width) / training_args.hp_tune_base_width) if training_args.use_mup else 1.0 217 | 218 | logger.info(f"width_mult_for_weights: {training_args.width_mult_for_weights}") 219 | config.width_mult_for_weights = training_args.width_mult_for_weights 220 | 221 | # Mup only supports training from scratch 222 | assert model_args.model_load_pretrained == False 223 | model = mup_init_from_scratch(config=config, training_args=training_args, 224 | model_args=model_args, logger=logger) 225 | 226 | 
################################# 227 | 228 | assert len(tokenizer) == model.config.vocab_size 229 | 230 | if is_deepspeed_zero3_enabled(): 231 | n_params = 0 232 | n_partitioned_params = 0 233 | for p in model.parameters(): 234 | if p.ds_tensor is not None: 235 | n_params += p.ds_numel 236 | n_partitioned_params += p.ds_tensor.numel() 237 | logger.info( 238 | f"My MSG: Training new model - Total size={n_params / 2 ** 20:.2f}M params. Total partitioned size={n_partitioned_params / 2 ** 20:.2f}M params ") 239 | else: 240 | n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()) 241 | logger.info(f"My MSG: Training new model - Total size={n_params / 2 ** 20:.2f}M params") 242 | 243 | # **************************************************************************************************** 244 | # Load Data 245 | # **************************************************************************************************** 246 | train_dataset = None 247 | if data_args.final_lm_dir is not None: 248 | logger.info(f'Loading LM data from: {data_args.final_lm_dir} / {data_args.max_lm_train_samples}') 249 | train_dataset = datasets.load_from_disk(data_args.final_lm_dir) 250 | if data_args.max_lm_train_samples is not None: 251 | train_dataset = train_dataset.select(range(data_args.max_lm_train_samples)) 252 | logger.info(f'{train_dataset}') 253 | 254 | # **************************************************************************************************** 255 | # Training Process 256 | # **************************************************************************************************** 257 | trainer = MupTrainer( 258 | model=model, 259 | args=training_args, 260 | train_dataset=train_dataset, 261 | eval_dataset=None, 262 | tokenizer=tokenizer, 263 | data_collator=default_data_collator, 264 | compute_metrics=None, 265 | preprocess_logits_for_metrics=None, 266 | ) 267 | 268 | # Training 269 | checkpoint = None 270 | if training_args.resume_from_checkpoint is not None: 271 | checkpoint = training_args.resume_from_checkpoint 272 | elif last_checkpoint is not None: 273 | checkpoint = last_checkpoint 274 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 275 | 276 | trainer.save_model() 277 | 278 | metrics = train_result.metrics 279 | max_lm_train_samples = ( 280 | data_args.max_lm_train_samples if data_args.max_lm_train_samples is not None else len(train_dataset) 281 | ) 282 | metrics["lm_train_samples"] = min(max_lm_train_samples, len(train_dataset)) 283 | trainer.log_metrics("train", metrics) 284 | trainer.save_metrics("train", metrics) 285 | trainer.save_state() 286 | 287 | kwargs = { 288 | "dataset": ', '.join([str(data_args.final_lm_dir), str(data_args.final_lm_dir)]) 289 | } 290 | trainer.create_model_card(**kwargs) 291 | 292 | 293 | if __name__ == "__main__": 294 | main() 295 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .files import * 2 | from .torchs import * 3 | from .stat import * 4 | -------------------------------------------------------------------------------- /utils/files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | 5 | PARDIR = os.path.pardir 6 | 7 | 8 | def define_dir(*args, make_sure_exists=True): 9 | folder = os.path.abspath(os.path.join(*args)) 10 | if not os.path.exists(folder) and make_sure_exists: 11 | 
os.makedirs(folder) 12 | # assert os.path.isdir(folder), '[ERROR] \'%s\' is not a folder path!' % folder 13 | return folder 14 | 15 | 16 | def concat_path(*args): 17 | return os.path.abspath(os.path.join(*args)) 18 | 19 | 20 | def cur_dir_abspath(cur_file): 21 | return os.path.abspath(os.path.dirname(cur_file)) 22 | 23 | 24 | def del_file(file_path): 25 | if os.path.isfile(file_path): 26 | os.remove(file_path) 27 | 28 | 29 | def find_all_file_paths(src_dir): 30 | if not os.path.isdir(src_dir): 31 | raise FileNotFoundError(src_dir) 32 | for dir_cur, _, file_names in os.walk(src_dir, topdown=False): 33 | for file_name in file_names: 34 | if file_name.startswith('.'): 35 | continue 36 | yield os.path.join(dir_cur, file_name), file_name 37 | 38 | 39 | def find_cur_file_paths(src_dir): 40 | if not os.path.isdir(src_dir): 41 | raise FileNotFoundError(src_dir) 42 | for file_name in os.listdir(src_dir): 43 | file_path = os.path.join(src_dir, file_name) 44 | if os.path.isfile(file_path): 45 | yield file_path, file_name 46 | 47 | 48 | def get_file_paths(src_dir, sub_dir=True): 49 | iter_fun = find_all_file_paths if sub_dir else find_cur_file_paths 50 | file_infos = [x for x in iter_fun(src_dir)] 51 | if file_infos: 52 | file_paths, file_names = zip(*[x for x in iter_fun(src_dir)]) 53 | return list(file_paths), list(file_names) 54 | else: 55 | return [], [] 56 | 57 | 58 | def load_text_file(file_path) -> str: 59 | with open(file_path, 'r', encoding='UTF-8') as f: 60 | data = f.read() 61 | return data 62 | 63 | 64 | def save_text_file(file_path, data: str): 65 | with open(file_path, 'w', encoding='UTF-8') as f: 66 | f.write(data) 67 | 68 | 69 | def load_text_file_by_line(file_path): 70 | with open(file_path, 'r', encoding='UTF-8') as f: 71 | data = [token.replace('\n', '').replace('\r', '') for token in f.readlines()] 72 | return [x for x in data if x] 73 | 74 | 75 | def save_text_file_by_line(file_path, data: list): 76 | with open(file_path, 'w', encoding='UTF-8') as f: 77 | f.write('\n'.join(data)) 78 | 79 | 80 | def load_json_file(file_path): 81 | with open(file_path, 'r', encoding='UTF-8') as f: 82 | org_data = json.load(f) 83 | return org_data 84 | 85 | 86 | def save_json_file(file_path, data, indent=None): 87 | with open(file_path, 'w', encoding='UTF-8') as f: 88 | json.dump(data, f, ensure_ascii=False, indent=indent) 89 | 90 | 91 | def load_json_file_by_line(file_path): 92 | return [json.loads(line) for line in load_text_file_by_line(file_path)] 93 | 94 | 95 | def save_json_file_by_line(file_path, data: list): 96 | save_text_file_by_line(file_path, [json.dumps(x, ensure_ascii=False) for x in data]) 97 | 98 | 99 | def load_data_file(file_path: str): 100 | usrs, prds, labels, docs = [], [], [], [] 101 | for line in load_text_file_by_line(file_path): 102 | items = line.split('\t\t') 103 | usrs.append(items[0]) 104 | prds.append(items[1]) 105 | labels.append(int(items[2]) - 1) 106 | docs.append([sent.strip().split(' ') for sent in items[3][0:-1].split('')]) 107 | return usrs, prds, labels, docs 108 | 109 | 110 | def save_by_pickle(file_path, data): 111 | with open(file_path, 'wb') as f: 112 | pickle.dump(data, f) 113 | 114 | 115 | def load_by_pickle(file_path): 116 | with open(file_path, 'rb') as f: 117 | data = pickle.load(f) 118 | return data 119 | 120 | 121 | def get_file_paths_from_path(file_path, find_sub_dir=False): 122 | file_paths = [] 123 | if isinstance(file_path, list): 124 | pass 125 | 126 | elif isinstance(file_path, str): 127 | if os.path.isfile(file_path): 128 | 
file_paths.append(file_path) 129 | elif os.path.isdir(file_path): 130 | for fp, _ in (find_all_file_paths if find_sub_dir else find_cur_file_paths)(file_path): 131 | file_paths.append(fp) 132 | else: 133 | raise FileExistsError('Can not find path: \'%s\'' % file_path) 134 | 135 | else: 136 | raise TypeError('Unkown type `file_path` -> %s' % str(type(file_path))) 137 | 138 | file_paths = sorted(file_paths, reverse=False) 139 | return file_paths 140 | 141 | 142 | if __name__ == '__main__': 143 | pass 144 | -------------------------------------------------------------------------------- /utils/stat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .files import load_text_file_by_line 3 | 4 | 5 | def stat_params(model, show_detail=False): 6 | total = [] 7 | for name, param in model.named_parameters(): 8 | p_size = param.nelement() 9 | total.append(p_size) 10 | if show_detail: 11 | print('%-50s %-30s %d' % (name, str(tuple(param.size())), p_size)) 12 | # print("Number of parameter: %.2fM" % (sum(total) / 1e6)) 13 | print("Number of parameter: %.2fM" % (sum(total) / (2 ** 20))) 14 | 15 | 16 | def stat_dataset_zh(file_path, sentence_sep='[SOS]'): 17 | raw_texts = load_text_file_by_line(file_path) 18 | dataset_size = len(raw_texts) 19 | print('All Data: %d' % dataset_size) 20 | 21 | doc_lengths, sent_lengths, num_sent_in_docs = [], [], [] 22 | for doc in raw_texts: 23 | sents = doc.split(sentence_sep) 24 | doc_length = 0 25 | for sent in sents: 26 | doc_length += len(sent) 27 | sent_lengths.append(len(sent)) 28 | doc_lengths.append(doc_length) 29 | num_sent_in_docs.append(len(sents)) 30 | 31 | print('Doc AVG Length: %d' % np.average(doc_lengths)) 32 | print('Sent AVG Length: %d' % np.average(sent_lengths)) 33 | print('Sent AVG Count : %d' % np.average(num_sent_in_docs)) 34 | 35 | 36 | def stat_dataset_by_tokenized(texts: list): 37 | num_texts = len(texts) 38 | print('Docs Num: %d' % num_texts) 39 | text_lengths = [len(tokens) for tokens in texts] 40 | print('Length: %d' % np.average(text_lengths)) 41 | 42 | 43 | if __name__ == '__main__': 44 | pass 45 | -------------------------------------------------------------------------------- /utils/torchs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pynvml 5 | import random 6 | 7 | 8 | def get_gpus_meminfo(): 9 | try: 10 | pynvml.nvmlInit() 11 | handles = [pynvml.nvmlDeviceGetHandleByIndex(idx) for idx in range(pynvml.nvmlDeviceGetCount())] 12 | gpus_free = [pynvml.nvmlDeviceGetMemoryInfo(handle).free for handle in handles] 13 | gpus_idx = np.argsort(gpus_free)[::-1].tolist() 14 | gpus_free = [gpus_free[idx] for idx in gpus_idx] 15 | except Exception: 16 | gpus_free, gpus_idx = [], [] 17 | return gpus_idx, gpus_free 18 | 19 | 20 | def get_best_device(): 21 | device_idx = None 22 | if torch.cuda.is_available(): 23 | gpus, _ = get_gpus_meminfo() 24 | if gpus: 25 | device_idx = gpus[0] 26 | return device_idx 27 | 28 | 29 | def cuda_is_available(): 30 | return torch.cuda.is_available() 31 | 32 | 33 | def set_global_rand_seed(seed): 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | os.environ['PYTHONHASHSEED'] = str(seed) 37 | 38 | torch.backends.cudnn.benchmark = False 39 | torch.backends.cudnn.deterministic = True 40 | 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 
--------------------------------------------------------------------------------
/visualize_lr_landscape.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import os
4 | import json
5 | 
6 | res_dir = "res/output/"
7 | run_name = "your-own-run-name"  # defined by --output_dir in your scripts
8 | checkpoint_num = 20000
9 | all_results = {}
10 | all_runs_list = [r for r in os.listdir(res_dir + run_name) if "misc" not in r]
11 | print(all_runs_list)
12 | for single_run in all_runs_list:
13 |     for w in os.listdir(res_dir + run_name + "/" + single_run):
14 |         if w not in all_results.keys():
15 |             all_results[w] = []
16 |         full_name = single_run
17 |         target_file = os.path.join(res_dir, run_name, single_run, w,
18 |                                    f"checkpoint-{checkpoint_num}", "trainer_state.json")
19 |         if os.path.isfile(target_file):
20 |             with open(target_file, "r") as f_d:
21 |                 trainer_state = json.load(f_d)
22 |                 lm_loss = trainer_state["log_history"][-1]["loss"]
23 |                 mt_loss = trainer_state["log_history"][-1]["mt_loss"] if "mt_loss" in trainer_state["log_history"][-1] else -1
24 |                 all_results[w].append((full_name, lm_loss, mt_loss))
25 | 
26 | for w, l in all_results.items():
27 |     print(w)
28 |     print("sort by lm_loss:")
29 |     for name, lm_loss, mt_loss in sorted(l, key=lambda x: x[1], reverse=False):
30 |         print(f"{name}\t{lm_loss}\t{mt_loss}")
31 | 
32 | # One-dimensional dependency curves: loss as a function of learning rate, one curve per width
33 | loss_type = ["LM"]
34 | for i in range(1):
35 |     for w, l in all_results.items():
36 |         l_1 = sorted(l, key=lambda x: float(x[0].split("_")[0]), reverse=False)
37 |         x = [float(data[0].split("_")[0]) for data in l_1]
38 |         y = [data[i+1] for data in l_1]
39 |         plt.plot(x, y, label=w)
40 |         # for _x, _y in zip(x, y):
41 |         #     plt.text(_x, _y, (_x, _y))
42 |     plt.xlabel("lr")
43 |     plt.ylabel("loss")
44 |     plt.xscale("log")
45 |     plt.yscale("log")
46 |     plt.legend()
47 |     plt.savefig(f"logs/pics/{run_name}_{loss_type[i]}_step_{checkpoint_num}.png")
48 |     plt.clf()
--------------------------------------------------------------------------------
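
Note: below is a minimal sketch of how the final loss of a single grid-search run can be inspected by hand, mirroring what visualize_lr_landscape.py does when it collects the last log_history entry from each checkpoint's trainer_state.json. The paths are illustrative placeholders only, following the ${lr}_${output_mult}_${initializer_range}/width_${width} layout written by run_grid_search_pair_wise_mup.sh; substitute your own --output_dir and checkpoint step.

import json
import os

# Placeholder paths (assumptions, not part of the repo): one grid-search run at width 256.
run_dir = "res/output/pair_wise_L6/3e-4_5_0.05/width_256"
state_file = os.path.join(run_dir, "checkpoint-20000", "trainer_state.json")

with open(state_file, "r", encoding="UTF-8") as f:
    trainer_state = json.load(f)

# The landscape script reads only the last log_history entry, which holds the
# most recently logged training loss (and learning rate) for that run.
last_log = trainer_state["log_history"][-1]
print(last_log["loss"], last_log.get("learning_rate"))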