├── .github └── workflows │ └── black.yml ├── .gitignore ├── LICENSE ├── README.md ├── experiment ├── config.py ├── ip.py └── launch.py ├── requirements.txt ├── scaelum ├── __init__.py ├── builder │ ├── __init__.py │ ├── builder.py │ ├── module_wrapper.py │ └── sequential_wrapper.py ├── config │ ├── __init__.py │ └── config.py ├── dataset │ ├── __init__.py │ ├── bert_dataset.py │ ├── data_generator.py │ ├── dataset.py │ └── glue │ │ ├── __init__.py │ │ ├── file_utils.py │ │ ├── processor.py │ │ └── tokenization.py ├── dynamics │ ├── __init__.py │ ├── allocator.py │ ├── benchmarker.py │ ├── estimator.py │ ├── parameter_server.py │ ├── worker.py │ └── worker_manager.py ├── logger │ ├── __init__.py │ └── logger.py ├── model │ ├── __init__.py │ ├── bert.py │ ├── bert_layers.py │ ├── layers.py │ ├── rpc_model.py │ └── rpc_module.py ├── registry │ ├── __init__.py │ └── registry.py ├── runner │ ├── __init__.py │ ├── hooks.py │ ├── hooks_collection │ │ ├── __init__.py │ │ ├── checkpoint_hook.py │ │ ├── distributed_timer_helper_hook.py │ │ └── stop_hook.py │ └── runner.py ├── stimulator │ ├── __init__.py │ └── stimulator.py ├── timer │ ├── __init__.py │ └── timer.py ├── utils.py └── version.py └── setup.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: psf/black@stable 11 | with: 12 | options: "--check --verbose" 13 | src: "./scaelum" 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # torch data 131 | data/ 132 | .vscode/ 133 | dllb.log 134 | 135 | # macOS 136 | .DS_Store 137 | 138 | # JetBrains 139 | .idea/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021- HPC-AI Technology Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sky Computing 2 | 3 | ## Introduction 4 | 5 | Sky Computing is a load-balanced framework for federated learning model parallelism. It adaptively allocates model layers to devices based on their hardware specifications. Sky Computing outperforms the baseline method by 55% in training time when training a 160-layer BERT in a 64-node cluster.
Our paper can be found at https://arxiv.org/abs/2202.11836 6 | 7 | The concept of *sky computing* was first introduced by Dr. Katarzyna Keahey et al., who used the term to describe a cross-cloud compute pattern. Later, Prof. Stoica and Prof. Shenker generalized it to geo-distributed computing. Our project is based on their definition. [\[1\]](https://ieeexplore.ieee.org/abstract/document/5226615) [\[2\]](https://dl.acm.org/doi/abs/10.1145/3458336.3465301) 8 | 9 | ## Installation 10 | 11 | ```shell 12 | git clone git@github.com:hpcaitech/SkyComputing.git 13 | python -m pip install -r requirements.txt 14 | cd ./scaelum 15 | python -m pip install -v -e . 16 | ``` 17 | 18 | ## Experiment (using BERT) 19 | 20 | To benchmark Sky Computing, we prepared a demo that you can run on your cluster to train BERT. 21 | 22 | ### Prepare BERT model 23 | 24 | Bidirectional Encoder Representations from Transformers (aka [BERT](https://aclanthology.org/N19-1423/)) is one of the state-of-the-art deep learning models for Natural Language Processing. In our experiments, we use BERT to run a simple benchmark. 25 | 26 | ```shell 27 | cd $PROJECT 28 | mkdir -p BERT/model && cd BERT/model 29 | wget https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip 30 | unzip wwm_uncased_L-24_H-1024_A-16.zip 31 | ``` 32 | 33 | ### Prepare GLUE MNLI dataset 34 | 35 | The General Language Understanding Evaluation (aka [GLUE](https://gluebenchmark.com/)) benchmark is a collection of resources for training, evaluating, and analyzing natural language understanding systems. The Multi-Genre Natural Language Inference (aka [MNLI](https://cims.nyu.edu/~sbowman/multinli/)) task is one of the GLUE tasks; it is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. 36 | 37 | ```shell 38 | cd $PROJECT 39 | mkdir -p BERT/data && cd BERT/data 40 | wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/1502038877f6a88c225a34450793fbc3ea87eaba/download_glue_data.py 41 | python download_glue_data.py --data_dir ./glue_data --tasks MNLI 42 | ``` 43 | 44 | ### Configuration 45 | 46 | To run dllb on your cluster, you need to write a config file that contains the necessary information about training, e.g. model layers and useful environment variables. We have provided a well-commented [example](https://github.com/hpcaitech/SkyComputing/blob/main/experiment/config.py); here are some of the most important options: 47 | 48 | ```python 49 | # your project path 50 | PROJECT = os.getenv("PROJECT") 51 | 52 | # allocation type, valid values are even, optimal and dynamic 53 | ALLOCATE_TYPE = "even" 54 | 55 | # num of nodes (including the central server) 56 | CORE_NUM = 4 57 | ``` 58 | 59 | ### Run scripts 60 | 61 | [Slurm](https://www.schedmd.com/) is an open source, fault-tolerant, and highly scalable cluster management and job scheduling system for large and small Linux clusters. We used a Slurm script to run our experiment.
62 | 63 | ```shell 64 | #!/bin/sh 65 | 66 | #SBATCH --job-name=gpu16 # Job name 67 | #SBATCH -o gpu16.o%j # Name of stdout output file 68 | #SBATCH -e gpu16.e%j # Name of stderr error file 69 | #SBATCH -N 16 # Node numbers 70 | #SBATCH -n 16 # GPU numbers 71 | #SBATCH --time=02:00:00 # Run time (hh:mm:ss) 72 | 73 | # run 74 | python ./ip_addr.py > "./HOST" 75 | srun python ./launch.py -c "./experiment/config.py" 76 | ``` 77 | 78 | ## Citation 79 | 80 | ```tex 81 | @misc{zhu2022sky, 82 | title={Sky Computing: Accelerating Geo-distributed Computing in Federated Learning}, 83 | author={Jie Zhu and Shenggui Li and Yang You}, 84 | year={2022}, 85 | eprint={2202.11836}, 86 | archivePrefix={arXiv}, 87 | primaryClass={cs.LG} 88 | } 89 | ``` 90 | 91 | ## Reference 92 | 93 | ```tex 94 | @article{keahey2009sky, 95 | title={Sky computing}, 96 | author={Keahey, Katarzyna and Tsugawa, Mauricio and Matsunaga, Andrea and Fortes, Jose}, 97 | journal={IEEE Internet Computing}, 98 | volume={13}, 99 | number={5}, 100 | pages={43--51}, 101 | year={2009}, 102 | publisher={IEEE} 103 | } 104 | @inproceedings{stoica2021cloud, 105 | title={From cloud computing to sky computing}, 106 | author={Stoica, Ion and Shenker, Scott}, 107 | booktitle={Proceedings of the Workshop on Hot Topics in Operating Systems}, 108 | pages={26--32}, 109 | year={2021} 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /experiment/config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os.path as osp 3 | import os 4 | from scaelum.model.bert import BertConfig 5 | import numpy as np 6 | 7 | # project path 8 | PROJECT = os.getenv("PROJECT") 9 | 10 | # allocation type, valid values are optimal, even and dynamic 11 | ALLOCATE_TYPE = "even" 12 | 13 | # num of training node (including central server) 14 | CORE_NUM = 4 15 | 16 | # num of hidden layers 17 | LAYER_NUM = 10 18 | 19 | # SOME PARAMS 20 | __data_root = f"{PROJECT}/BERT/data/glue_data" 21 | __task = "MNLI" 22 | __model_root = f"{PROJECT}/BERT/model/wwm_uncased_L-24_H-1024_A-16" 23 | __config_file = osp.join(__model_root, "bert_config.json") 24 | __config = BertConfig.from_json_file(__config_file) 25 | 26 | __ENCODER = [ 27 | dict(layer_type="BertLayer_Head", config=__config.__dict__), 28 | dict(layer_type="BertLayer_Body", config=__config.__dict__), 29 | dict(layer_type="BertLayer_Tail", config=__config.__dict__), 30 | ] * LAYER_NUM # __config.num_hidden_layers 31 | 32 | __BERT_LAYERS = ( 33 | [ 34 | dict( 35 | layer_type="BertEmbeddings", 36 | config=__config.__dict__, 37 | ) 38 | ] 39 | + __ENCODER 40 | + [ 41 | dict(layer_type="BertPooler", config=__config.__dict__), 42 | dict( 43 | layer_type="BertTailForClassification", 44 | hidden_dropout_prob=__config.hidden_dropout_prob, 45 | hidden_size=__config.hidden_size, 46 | num_classes=3, 47 | ), 48 | ] 49 | ) 50 | 51 | # config for rpc initialization 52 | # will be replaced in SLURM job script 53 | rpc_config = dict( 54 | MASTER_ADDR="localhost", MASTER_PORT="29500", GLOO_SOCKET_IFNAME="ipogif0" 55 | ) 56 | 57 | # for runner logger hook 58 | __LOG_ROOT = f"{PROJECT}/logs/{CORE_NUM}nodes_{LAYER_NUM}layers/{ALLOCATE_TYPE}" 59 | logging_config = dict( 60 | mode="a", filename=osp.join(__LOG_ROOT, "allocation.log") # do not change 61 | ) 62 | 63 | # worker config 64 | worker_config = [] 65 | 66 | 67 | def get_slowdown(val, num): 68 | # generate reproducible random slowdown 69 | rng = np.random.default_rng(seed=35) 70 | rints = 
rng.integers(low=1, high=7, size=num + 1) 71 | return rints[val] 72 | 73 | WORKER_NUM = CORE_NUM - 1 74 | 75 | for i in range(1, WORKER_NUM + 1): 76 | worker_config.append( 77 | dict( 78 | name=f"gpu-{i}", 79 | server_config=dict( 80 | host="localhost", 81 | port="8001", 82 | ), 83 | extra_config=dict( 84 | # slowdown=get_slowdown(i, WORKER_NUM), 85 | slowdown=1, # TODO remove this method 86 | logging_config=dict( 87 | mode="a", # do not change 88 | filename=osp.join(__LOG_ROOT, f"node-{i}-train.log"), 89 | ), 90 | mem_limit=-1, 91 | cuda_device=0, 92 | module_to_cuda=True, 93 | output_to_cpu=True, 94 | timer_config=dict( 95 | root=__LOG_ROOT, 96 | ), 97 | ), 98 | ), 99 | ) 100 | 101 | # model config 102 | model_config = __BERT_LAYERS 103 | 104 | # dataset config 105 | data_config = dict( 106 | dataset_cfg=dict( 107 | type="GlueDataset", 108 | data_dir=osp.join(__data_root, __task), 109 | bert_model="large-uncased", 110 | vocab_file=osp.join(__model_root, "vocab.txt"), 111 | max_seq_length=128, 112 | do_lower_case=False, 113 | processor="mnli", 114 | ), 115 | dataloader_cfg=dict( 116 | batch_size=32, 117 | shuffle=True, 118 | num_workers=0, 119 | ), 120 | ) 121 | 122 | # dynamic allocation config 123 | allocator_config = dict( 124 | type=ALLOCATE_TYPE, 125 | benchmark_config=dict( 126 | model=dict( 127 | device="cpu", 128 | param_scale=2, 129 | data_generator_cfg=dict( 130 | generator_type="DataloaderGenerator", generator_cfg=data_config 131 | ), 132 | ), 133 | device=dict( 134 | model_config=[ 135 | dict( 136 | layer_type="Conv2d", 137 | in_channels=256, 138 | out_channels=256, 139 | kernel_size=3, 140 | padding=1, 141 | ) 142 | ] 143 | * 10, 144 | iterations=30, 145 | data_generator_cfg=dict( 146 | generator_type="RandomTensorGenerator", 147 | generator_cfg=dict(size=(32, 256, 64, 64)), 148 | ), 149 | ), 150 | ), 151 | ) 152 | 153 | # training config 154 | train_config = dict( 155 | optim_cfg=dict(optim_type="SGD", lr=0.001), 156 | loss_cfg=dict(type="CrossEntropyLoss"), 157 | runner_cfg=dict( 158 | max_epochs=1, 159 | max_iters=30, 160 | ), 161 | # for runner hook 162 | hook_config=[ 163 | dict(type="StopHook", root=__LOG_ROOT), 164 | dict(type="DistributedTimerHelperHook"), 165 | ], 166 | timer_config=dict(root=__LOG_ROOT), 167 | ) 168 | -------------------------------------------------------------------------------- /experiment/ip.py: -------------------------------------------------------------------------------- 1 | import ifaddr 2 | import os 3 | 4 | if os.getenv("SLURM_PROCID") == "0": 5 | adapters = ifaddr.get_adapters() 6 | 7 | for adapter in adapters: 8 | if adapter.nice_name == "ipogif0": 9 | for addr in adapter.ips: 10 | if ':' not in addr.ip: 11 | print(f"{addr.ip}") -------------------------------------------------------------------------------- /experiment/launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import argparse 5 | import os 6 | import os.path as osp 7 | import time 8 | 9 | import scaelum.utils as dutils 10 | import torch.distributed.rpc as rpc 11 | from scaelum import build_hook, Runner, RpcModel 12 | from scaelum.builder import build_dataloader_from_cfg, build_data_generator 13 | from scaelum.config import load_config 14 | from scaelum.dynamics import Allocator, WorkerManager, ParameterServer, ModelBenchmarker, DeviceBenchmarker 15 | from scaelum.logger import Logger 16 | from torch import optim 17 | from torch.distributed.optim import DistributedOptimizer 18
| 19 | 20 | def run_process(rank: int, 21 | world_size: int, 22 | rpc_config: dict, 23 | model_config: list = None, 24 | data_config: dict = None, 25 | logging_config: dict = None, 26 | allocator_config: dict = None, 27 | train_config: dict = None, 28 | worker_config: list = None, 29 | ): 30 | # set env var for rpc init 31 | for k, v in rpc_config.items(): 32 | os.environ[k] = str(v) 33 | 34 | # init rpc 35 | print('starting to initialize rpc on rank: {}'.format(rank)) 36 | worker_name = dutils.generate_worker_name(rank) 37 | rpc.init_rpc( 38 | name=worker_name, 39 | rank=rank, 40 | world_size=world_size, 41 | backend=rpc.BackendType.TENSORPIPE, 42 | rpc_backend_options=rpc.TensorPipeRpcBackendOptions( 43 | num_worker_threads=8, 44 | rpc_timeout=1200 # 600 second timeout 45 | ) 46 | ) 47 | print('rpc initialized on rank: {}'.format(rank)) 48 | 49 | if rank == 0: 50 | # clean previous logs 51 | log_workspace = osp.dirname(logging_config['filename']) 52 | 53 | if osp.exists(log_workspace): 54 | for f in os.listdir(log_workspace): 55 | os.remove(osp.join(log_workspace, f)) 56 | else: 57 | os.makedirs(log_workspace) 58 | 59 | # init logging 60 | logger = Logger(**logging_config) 61 | logger.info('logger initialized') 62 | 63 | # init worker manager 64 | worker_manager = WorkerManager() 65 | worker_manager.load_worker_pool_from_config(worker_config) 66 | 67 | # create parameter server 68 | parameter_server = ParameterServer(model_config) 69 | 70 | # create dataloder 71 | dataloader_cfg = data_config['dataloader_cfg'].copy() 72 | dataset_cfg = data_config['dataset_cfg'].copy() 73 | data_loader = build_dataloader_from_cfg( 74 | dataloader_cfg=dataloader_cfg, 75 | dataset_cfg=dataset_cfg) 76 | logger.info('created data loader') 77 | 78 | # benchmarking 79 | benchmark_cfg = allocator_config['benchmark_config'].copy() 80 | 81 | # build model benchmarker 82 | data_cfg_for_model = benchmark_cfg['model'].pop('data_generator_cfg') 83 | generator_type = data_cfg_for_model.pop('generator_type') 84 | data_generator_for_model = build_data_generator( 85 | module_name=generator_type, 86 | **data_cfg_for_model 87 | ) 88 | 89 | model_benchmarker = ModelBenchmarker( 90 | model_config=model_config, 91 | data_generator=data_generator_for_model, 92 | **benchmark_cfg['model'] 93 | ) 94 | 95 | # build device benchmarker 96 | data_cfg_for_device = benchmark_cfg['device'].pop('data_generator_cfg') 97 | generator_type = data_cfg_for_device.pop('generator_type') 98 | data_generator_for_device = build_data_generator( 99 | module_name=generator_type, 100 | **data_cfg_for_device 101 | ) 102 | 103 | device_benchmarker = DeviceBenchmarker( 104 | worker_manager=worker_manager, 105 | data_generator=data_generator_for_device, 106 | **benchmark_cfg['device'] 107 | ) 108 | 109 | # build allocator 110 | allocator = Allocator( 111 | model_cfg=model_config, 112 | worker_manager=worker_manager, 113 | model_benchmarker=model_benchmarker, 114 | device_benchmarker=device_benchmarker, 115 | ) 116 | 117 | allocation_success = True 118 | 119 | if allocator_config['type'] == "dynamic": 120 | try: 121 | worker_manager = allocator.dynamic_allocate() 122 | logger.info('dynamically allocated model layers based on benchmarking') 123 | 124 | del allocator, device_benchmarker, model_benchmarker, data_generator_for_model, data_generator_for_device 125 | except Exception: 126 | allocation_success = False 127 | elif allocator_config['type'] == "optimal": 128 | print("using optimal allocate") 129 | try: 130 | worker_manager = 
allocator.optimal_allocate() 131 | logger.info("use optimal strategy to allocate model layers, may take long time") 132 | del allocator, device_benchmarker, model_benchmarker, data_generator_for_model, data_generator_for_device 133 | except Exception: 134 | allocation_success = False 135 | else: 136 | worker_manager = allocator.even_allocate() 137 | logger.info('Evenly allocated model layers') 138 | del allocator 139 | 140 | # log worker workload 141 | for worker in worker_manager.worker_pool: 142 | logger.info('rank: {}, number of layers: {}'.format( 143 | worker.rank, len(worker.model_config))) 144 | 145 | if allocation_success: 146 | # create model 147 | model = RpcModel(worker_manager=worker_manager) 148 | logger.info('created model') 149 | 150 | # create optimizer 151 | # Build DistributedOptimizer. 152 | optim_mod = getattr(optim, train_config['optim_cfg'].pop('optim_type')) 153 | dist_optim = DistributedOptimizer( 154 | optim_mod, model.parameter_rrefs(), 155 | **train_config['optim_cfg']) 156 | logger.info('created distrubted optimizer') 157 | 158 | # build runner 159 | loss_cfg = train_config['loss_cfg'].copy() 160 | timer_cfg = train_config['timer_config'].copy() 161 | logging_cfg = logging_config.copy() 162 | 163 | runner = Runner( 164 | model=model, 165 | parameter_server=parameter_server, 166 | worker_manager=worker_manager, 167 | optimizer=dist_optim, 168 | loss_cfg=loss_cfg, 169 | timer_cfg=timer_cfg, 170 | logging_cfg=logging_cfg, 171 | **train_config['runner_cfg'] 172 | ) 173 | logger.info('created runner') 174 | 175 | for cfg in train_config['hook_config']: 176 | cfg_copy = cfg.copy() 177 | hook_name = cfg_copy.pop('type') 178 | hook = build_hook(hook_name, **cfg_copy) 179 | runner.register_hook(hook) 180 | logger.info('register hooks') 181 | 182 | runner.train(data_loader) 183 | else: 184 | time.sleep(30) 185 | rpc.shutdown() 186 | print('finish') 187 | 188 | 189 | def parse_args(): 190 | parser = argparse.ArgumentParser() 191 | parser.add_argument('-c', '--config', type=str, help='path to config file') 192 | # parser.add_argument('-n', '--ntasks', type=int) 193 | # parser.add_argument('-t', '--type', type=str) 194 | parser.add_argument('-p', '--port', type=int, default=29500) 195 | return parser.parse_args() 196 | 197 | 198 | if __name__ == '__main__': 199 | args = parse_args() 200 | config = load_config(args.config) 201 | 202 | # get individual config 203 | model_config = config.pop('model_config') 204 | rpc_config = config.pop('rpc_config') 205 | data_config = config.pop('data_config') 206 | logging_config = config.pop('logging_config') 207 | worker_config = config.pop('worker_config') 208 | allocator_config = config.pop('allocator_config') 209 | train_config = config.pop('train_config') 210 | 211 | # replace the rpc init environment variable 212 | host_file = "./HOST" 213 | with open(host_file, 'r') as f: 214 | host = f.readline().strip() 215 | 216 | rpc_config['MASTER_ADDR'] = host 217 | rpc_config['MASTER_PORT'] = args.port 218 | 219 | # set Language to avoid error on Frontera 220 | os.environ['LC_ALL'] = 'C.UTF-8' 221 | 222 | # get ibrun rank and world size 223 | rank = int(os.environ['SLURM_PROCID']) 224 | world_size = int(os.environ['SLURM_NPROCS']) 225 | 226 | run_process( 227 | rank=rank, 228 | world_size=world_size, 229 | rpc_config=rpc_config, 230 | model_config=model_config, 231 | data_config=data_config, 232 | logging_config=logging_config, 233 | allocator_config=allocator_config, 234 | train_config=train_config, 235 | worker_config=worker_config, 236 | ) 
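As a usage illustration (not part of the repository), the sketch below shows how `launch.py` could in principle be driven without Slurm: it takes its rank and world size from the `SLURM_PROCID` and `SLURM_NPROCS` environment variables, reads the master address from a `./HOST` file, and accepts `-c` for the config path and `-p` for the RPC port (default 29500). The address and process count here are illustrative assumptions; the default `experiment/config.py` expects `CORE_NUM = 4` processes, a CUDA device on each worker, and `GLOO_SOCKET_IFNAME` set to `ipogif0` (a Cray interface), so those values would need to match your machine for this to actually run.

```shell
# Hypothetical single-machine launch without Slurm (illustrative values only).
echo "127.0.0.1" > ./HOST                # launch.py reads the master address from ./HOST

for RANK in 0 1 2 3; do                  # CORE_NUM defaults to 4 in experiment/config.py
    SLURM_PROCID=$RANK SLURM_NPROCS=4 \
        python ./experiment/launch.py -c ./experiment/config.py -p 29500 &
done
wait
```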
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | tqdm 3 | boto3 4 | scikit-image 5 | fastapi 6 | uvicorn 7 | pthflops==0.3.5 8 | torch==1.6.0 9 | torchvision==0.7.0 10 | psutil 11 | pyroute2 12 | requests 13 | loguru 14 | pulp 15 | mip 16 | ifaddr -------------------------------------------------------------------------------- /scaelum/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import * 2 | from .config import * 3 | 4 | from .dataset import * 5 | from .dynamics import * 6 | from .logger import * 7 | from .model import * 8 | from .registry import * 9 | from .runner import * 10 | from .timer import * 11 | from .version import __version__ 12 | -------------------------------------------------------------------------------- /scaelum/builder/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import ( 2 | build_dataloader_from_cfg, 3 | build_from_registry, 4 | build_hook, 5 | build_layer, 6 | build_module_from_cfg, 7 | build_data_generator, 8 | ) 9 | from .module_wrapper import ModuleWrapper 10 | from .sequential_wrapper import SequentialWrapper 11 | -------------------------------------------------------------------------------- /scaelum/builder/builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from scaelum.registry import DATA_GENERATOR, DATASET, HOOKS, LAYER, Registry 6 | from torch.utils.data import DataLoader 7 | 8 | from .module_wrapper import ModuleWrapper 9 | from .sequential_wrapper import SequentialWrapper 10 | 11 | 12 | def build_from_registry(module_name: str, registry: Registry, *args, **kwargs): 13 | mod = registry.get_module(module_name) 14 | return mod(*args, **kwargs) 15 | 16 | 17 | def build_layer(module_name: str, *args, **kwargs): 18 | return build_from_registry(module_name, LAYER, *args, **kwargs) 19 | 20 | 21 | def build_hook(module_name: str, *args, **kwargs): 22 | return build_from_registry(module_name, HOOKS, *args, **kwargs) 23 | 24 | 25 | def build_data_generator(module_name: str, *args, **kwargs): 26 | return build_from_registry(module_name, DATA_GENERATOR, *args, **kwargs) 27 | 28 | 29 | def build_module_from_cfg(rank, model_cfg: list, module_wrapper_cfg: dict): 30 | layers = [] 31 | 32 | for layer_cfg in model_cfg: 33 | layer_cfg_copy = layer_cfg.copy() 34 | layer_type = layer_cfg_copy.pop("layer_type") 35 | layer = build_layer(layer_type, **layer_cfg_copy) 36 | layers.append(layer) 37 | module = SequentialWrapper(*layers) 38 | module_wrapper_cfg["record_forward_time"] = True 39 | module = ModuleWrapper(rank=rank, module=module, **module_wrapper_cfg) 40 | 41 | return module 42 | 43 | 44 | def build_dataloader_from_cfg(dataset_cfg, dataloader_cfg): 45 | dataset_type = dataset_cfg.pop("type") 46 | dataset = build_from_registry(dataset_type, DATASET, **dataset_cfg) 47 | dataloader = DataLoader(dataset, **dataloader_cfg) 48 | 49 | return dataloader 50 | -------------------------------------------------------------------------------- /scaelum/builder/module_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import subprocess 5 | import time 6 | 7 | import psutil 8 | import torch 9 | 
import torch.nn as nn 10 | from scaelum import utils as dutils 11 | from scaelum.logger import Logger 12 | from scaelum.timer import DistributedTimer 13 | 14 | from .sequential_wrapper import SequentialWrapper 15 | 16 | try: 17 | from torch.distributed.rpc import PyRRef as RpcRef 18 | except ImportError: 19 | from torch.distributed.rpc import RRef as RpcRef 20 | 21 | 22 | class ModuleWrapper(nn.Module): 23 | def __init__( 24 | self, 25 | rank: int, 26 | module: nn.Module, 27 | module_to_cuda: bool, 28 | output_to_cpu: bool, 29 | mem_limit: int, 30 | slowdown: int, 31 | timer_config: dict, 32 | logging_config: dict = None, 33 | cuda_device: int = -1, 34 | record_forward_time: bool = False, 35 | ): 36 | super(ModuleWrapper, self).__init__() 37 | # add basic config 38 | self._rank = rank 39 | self._module_to_cuda = module_to_cuda 40 | self._output_to_cpu = output_to_cpu 41 | self._slowdown = slowdown 42 | 43 | # add module 44 | assert isinstance( 45 | module, SequentialWrapper 46 | ), "The module is of type {}, but expected SequentialWrapper".format( 47 | type(module) 48 | ) 49 | self._module = module 50 | 51 | # add memory limit check 52 | assert ( 53 | mem_limit != 0 and mem_limit >= -1 54 | ), "mem_limit can only be set to -1 or positive number, if it is -1, it will automatically detect the GPU RAM based on cuda device." 55 | self._mem_limit = mem_limit 56 | 57 | # logger 58 | if logging_config: 59 | self._logger = Logger(**logging_config) 60 | else: 61 | self._logger = None 62 | 63 | # timer 64 | self._timer = DistributedTimer(**timer_config) 65 | 66 | # add slowdown to backward 67 | if logging_config: 68 | bwd_logger = Logger(**logging_config) 69 | else: 70 | bwd_logger = None 71 | 72 | self._slowdown_module = BackwardSlowdownModule( 73 | rank=rank, 74 | slowdown=slowdown, 75 | timer=self._timer, 76 | logger=bwd_logger, 77 | do_slowdown=True, 78 | ) 79 | 80 | self._output_slowdown_module = BackwardSlowdownModule( 81 | rank=rank, 82 | slowdown=slowdown, 83 | timer=self._timer, 84 | logger=bwd_logger, 85 | do_slowdown=False, 86 | ) 87 | 88 | # move from cpu to gpu if configured 89 | assert ( 90 | not module_to_cuda or cuda_device >= 0 91 | ), "GPU device index must be non-negative" 92 | if self._module_to_cuda: 93 | torch.cuda.set_device(cuda_device) 94 | self.cuda() 95 | self._gpu_index = cuda_device 96 | 97 | self._record_forward_time = record_forward_time 98 | if self._record_forward_time: 99 | self.forward_time = [] 100 | 101 | @property 102 | def rank(self): 103 | return self._rank 104 | 105 | @property 106 | def gpu_index(self): 107 | return self._gpu_index 108 | 109 | def _time_forward(self, *args): 110 | """ 111 | Time the forward pass and log into file if logger is present. 
Slow down 112 | by manually calling time.sleep(n) to simulate different computing power 113 | """ 114 | 115 | # preprocess 116 | args = self._process_data_before(args) 117 | 118 | # forward 119 | start = dutils.get_time() 120 | output = self._module(*args) 121 | end = dutils.get_time() 122 | 123 | # slowdown 124 | comp_time = end - start 125 | if self._slowdown > 0: 126 | time.sleep(comp_time * self._slowdown) 127 | 128 | # log 129 | real_end = dutils.get_time() 130 | if self._logger: 131 | self._logger.info( 132 | "forward time on rank {}: {}".format(self._rank, real_end - start) 133 | ) 134 | 135 | if self._record_forward_time: 136 | self.forward_time.append(real_end - start) 137 | 138 | # post processing 139 | output = self._process_data_after(output) 140 | return output 141 | 142 | def _convert_to_tuple(self, data): 143 | if isinstance(data, tuple) or isinstance(data, list): 144 | return data 145 | else: 146 | return (data,) 147 | 148 | def _process_data_before(self, args): 149 | # handle rpc ref 150 | if len(args) == 1 and isinstance(args[0], RpcRef): 151 | args = self._fetch_data_before(args[0]) 152 | args = [self._move_data_before(arg) for arg in args] 153 | args = [self._slowdown_backward(arg) for arg in args] 154 | return args 155 | 156 | def _process_data_after(self, output): 157 | output = self._convert_to_tuple(output) 158 | output = [self._slowdown_output_backward(arg) for arg in output] 159 | output = tuple(output) 160 | output = [self._move_data_after(data) for data in output] 161 | return output 162 | 163 | def _fetch_data_before(self, data): 164 | output = data.to_here() 165 | return output 166 | 167 | def _move_data_before(self, data): 168 | if self._module_to_cuda and isinstance(data, torch.Tensor): 169 | data = data.to("cuda:{}".format(self._gpu_index)) 170 | return data 171 | 172 | def _move_data_after(self, data): 173 | if self._output_to_cpu and isinstance(data, torch.Tensor): 174 | data = data.cpu() 175 | return data 176 | 177 | def _slowdown_backward(self, data): 178 | if isinstance(data, torch.Tensor): 179 | data = self._slowdown_module(data) 180 | return data 181 | 182 | def _slowdown_output_backward(self, data): 183 | if isinstance(data, torch.Tensor): 184 | data = self._output_slowdown_module(data) 185 | return data 186 | 187 | def detect_mem(self, destroy_module: bool): 188 | """ 189 | If mem_limit is -1, it means automatically detect the RAM. 190 | If it is a positive number, it means to use this set value as memory limit. 191 | 192 | The RAM should be in MB. 
193 | """ 194 | if destroy_module: 195 | # delete module to get more accurate memory reading 196 | del self._module 197 | if self._module_to_cuda: 198 | torch.cuda.empty_cache() 199 | 200 | if self._mem_limit > 0: 201 | return self._mem_limit 202 | elif self._mem_limit == -1: 203 | if self._module_to_cuda: 204 | return self._detect_gpu_ram() 205 | else: 206 | return self._detect_cpu_ram() 207 | else: 208 | raise ValueError("Invalid value {} for mem_limit".format(self._mem_limit)) 209 | 210 | def _detect_gpu_ram(self): 211 | # get available mem 212 | _output_to_list = lambda x: x.decode("ascii").split("\n")[:-1] 213 | COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv" 214 | memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:] 215 | memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] 216 | 217 | # minus 800 to avoid OOM 218 | mem = memory_free_values[self._gpu_index] - 500 219 | return mem 220 | 221 | def _detect_cpu_ram(self): 222 | avai_mem = psutil.virtual_memory().available 223 | avai_mem = avai_mem / 1024 / 1024 224 | return avai_mem 225 | 226 | def forward(self, *args): 227 | """ 228 | Args: 229 | *args: a tuple containing all the inputs 230 | 231 | Returns: a tuple of tensors 232 | 233 | """ 234 | if isinstance(self, RpcRef): 235 | self = self.to_here() 236 | output = self._time_forward(*args) 237 | return output 238 | 239 | 240 | class BackwardSlowdownFunction(torch.autograd.Function): 241 | @staticmethod 242 | def forward(ctx, feat, rank, slowdown, timer, logger, do_slowdown): 243 | ctx.slowdown = slowdown 244 | ctx.timer = timer 245 | ctx.logger = logger 246 | ctx.rank = rank 247 | ctx.do_slowdown = do_slowdown 248 | 249 | ctx.save_for_backward(feat) 250 | output = feat.clone() 251 | return output 252 | 253 | @staticmethod 254 | def backward(ctx, grad_output): 255 | # slowdown 256 | timer = ctx.timer 257 | logger = ctx.logger 258 | 259 | if ctx.do_slowdown: 260 | dutils.synchronize() 261 | timer.add_timestamp() 262 | 263 | backward_time = timer.get_prev_interval() 264 | 265 | if ctx.slowdown > 0: 266 | time.sleep(max(0, backward_time * ctx.slowdown)) 267 | 268 | # log 269 | if logger: 270 | logger.info( 271 | "backward time on rank {}: {}".format( 272 | ctx.rank, backward_time * (ctx.slowdown + 1) 273 | ) 274 | ) 275 | 276 | grad_input = grad_output.clone() 277 | 278 | dutils.synchronize() 279 | timer.add_timestamp() 280 | 281 | # the number of output of backward should be the same 282 | # as that of input of forward 283 | return grad_input, None, None, None, None, None, None 284 | 285 | 286 | class BackwardSlowdownModule(nn.Module): 287 | def __init__(self, rank, slowdown, timer, logger, do_slowdown): 288 | super().__init__() 289 | self.rank = rank 290 | self.slowdown = slowdown 291 | self.timer = timer 292 | self.logger = logger 293 | self.do_slowdown = do_slowdown 294 | self.func = BackwardSlowdownFunction.apply 295 | 296 | def forward(self, data): 297 | return self.func( 298 | data, self.rank, self.slowdown, self.timer, self.logger, self.do_slowdown 299 | ) 300 | -------------------------------------------------------------------------------- /scaelum/builder/sequential_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch.nn as nn 6 | 7 | 8 | class SequentialWrapper(nn.Sequential): 9 | """ 10 | A wrapper class for nn.Sequential so that the nn.Sequential 11 | can handle multiple inputs 12 | """ 
13 | 14 | def forward(self, *inputs): 15 | for module in self._modules.values(): 16 | if isinstance(inputs, tuple) or isinstance(inputs, list): 17 | inputs = module(*inputs) 18 | else: 19 | inputs = module(inputs) 20 | return inputs 21 | -------------------------------------------------------------------------------- /scaelum/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config, load_config 2 | -------------------------------------------------------------------------------- /scaelum/config/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import inspect 5 | import os.path as osp 6 | import sys 7 | from importlib.machinery import SourceFileLoader 8 | 9 | 10 | class Config(dict): 11 | """ 12 | Wrap a dictionary object so that we can access 13 | values as attributes 14 | """ 15 | 16 | def __missing__(self, name): 17 | raise KeyError(name) 18 | 19 | def __getattr__(self, name): 20 | value = super(Config, self).__getitem__(name) 21 | return value 22 | 23 | def __setattr__(self, name, value): 24 | super(Config, self).__setitem__(name, value) 25 | 26 | def update(self, config): 27 | for k, v in config.items(): 28 | self.__setattr__(k, v) 29 | return self 30 | 31 | @staticmethod 32 | def from_dict(data: dict): 33 | config = Config() 34 | 35 | for k, v in data.items(): 36 | config.__setattr__(k, v) 37 | return config 38 | 39 | 40 | def _py2dict(py_path: str): 41 | # pylint: disable=no-value-for-parameter 42 | """ 43 | Read python file as python dictionary 44 | """ 45 | 46 | assert py_path.endswith(".py") 47 | 48 | py_path = osp.abspath(py_path) 49 | parent_dir = osp.dirname(py_path) 50 | if parent_dir not in sys.path: 51 | sys.path.insert(0, parent_dir) 52 | 53 | module_name = osp.splitext(osp.basename(py_path))[0] 54 | source_file = SourceFileLoader(fullname=module_name, path=py_path) 55 | module = source_file.load_module() 56 | sys.path.pop(0) 57 | doc = { 58 | k: v 59 | for k, v in module.__dict__.items() 60 | if not k.startswith("__") and not inspect.ismodule(v) and not inspect.isclass(v) 61 | } 62 | del sys.modules[module_name] 63 | return doc 64 | 65 | 66 | def load_config(file_path: str): 67 | config_dict = _py2dict(file_path) 68 | config = Config(config_dict) 69 | 70 | base = config.pop("base", None) 71 | 72 | if base: 73 | base_config_path = osp.join(osp.dirname(file_path), base) 74 | base_config_dict = _py2dict(base_config_path) 75 | base_config = Config(base_config_dict) 76 | config = base_config.update(config) 77 | 78 | return config 79 | -------------------------------------------------------------------------------- /scaelum/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert_dataset import GlueDataset 2 | from .data_generator import BaseGenerator, DataloaderGenerator, RandomTensorGenerator 3 | from .dataset import * 4 | -------------------------------------------------------------------------------- /scaelum/dataset/bert_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import pickle 7 | 8 | import torch 9 | from scaelum.registry import DATASET 10 | from torch.utils.data import Dataset, TensorDataset 11 | 12 | from .glue.processor import PROCESSORS, convert_examples_to_features 13 | from 
.glue.tokenization import BertTokenizer 14 | 15 | 16 | @DATASET.register_module 17 | class GlueDataset(Dataset): 18 | def __init__( 19 | self, data_dir, bert_model, vocab_file, max_seq_length, do_lower_case, processor 20 | ): 21 | super().__init__() 22 | self.data_dir = data_dir 23 | self.bert_model = bert_model 24 | self.max_seq_length = max_seq_length 25 | self.do_lower_case = do_lower_case 26 | self.tokenizer = BertTokenizer( 27 | vocab_file, 28 | do_lower_case=do_lower_case, 29 | max_len=512, 30 | ) 31 | self.processor = PROCESSORS[processor]() 32 | self.dataset = self._build_dataset() 33 | 34 | def __getitem__(self, idx): 35 | items = self.dataset.__getitem__(idx) 36 | 37 | return items[:3], items[-1] 38 | 39 | def __len__(self): 40 | return self.dataset.__len__() 41 | 42 | def _get_train_features(self): 43 | cached_train_features_file = osp.join( 44 | self.data_dir, 45 | "{0}_{1}_{2}".format( 46 | list(filter(None, self.bert_model.split("/"))).pop(), 47 | str(self.max_seq_length), 48 | str(self.do_lower_case), 49 | ), 50 | ) 51 | train_features = None 52 | try: 53 | with open(cached_train_features_file, "rb") as reader: 54 | train_features = pickle.load(reader) 55 | except: 56 | print("Converting examples to features") 57 | train_examples = self.processor.get_train_examples(data_dir=self.data_dir) 58 | train_features, _ = convert_examples_to_features( 59 | train_examples, 60 | self.processor.get_labels(), 61 | self.max_seq_length, 62 | self.tokenizer, 63 | ) 64 | with open(cached_train_features_file, "wb") as writer: 65 | pickle.dump(train_features, writer) 66 | return train_features 67 | 68 | def _gen_tensor_dataset(self, features): 69 | all_input_ids = torch.tensor( 70 | [f.input_ids for f in features], 71 | dtype=torch.long, 72 | ) 73 | all_input_mask = torch.tensor( 74 | [f.input_mask for f in features], 75 | dtype=torch.long, 76 | ) 77 | all_segment_ids = torch.tensor( 78 | [f.segment_ids for f in features], 79 | dtype=torch.long, 80 | ) 81 | all_label_ids = torch.tensor( 82 | [f.label_id for f in features], 83 | dtype=torch.long, 84 | ) 85 | return TensorDataset( 86 | all_input_ids, 87 | all_input_mask, 88 | all_segment_ids, 89 | all_label_ids, 90 | ) 91 | 92 | def _build_dataset(self): 93 | features = self._get_train_features() 94 | return self._gen_tensor_dataset(features) 95 | -------------------------------------------------------------------------------- /scaelum/dataset/data_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | import abc 4 | 5 | import torch 6 | from scaelum.builder import build_dataloader_from_cfg 7 | from scaelum.registry import DATA_GENERATOR 8 | 9 | 10 | class BaseGenerator(object): 11 | __metaclass__ = abc.ABCMeta 12 | 13 | @abc.abstractmethod 14 | def generate(self): 15 | pass 16 | 17 | 18 | @DATA_GENERATOR.register_module 19 | class RandomTensorGenerator(BaseGenerator): 20 | def __init__(self, generator_cfg): 21 | self.generator_cfg = generator_cfg 22 | 23 | def generate(self): 24 | return torch.rand(**self.generator_cfg) 25 | 26 | 27 | @DATA_GENERATOR.register_module 28 | class DataloaderGenerator(BaseGenerator): 29 | def __init__(self, generator_cfg): 30 | self.dataloader = build_dataloader_from_cfg(**generator_cfg) 31 | self.generator = iter(self.dataloader) 32 | 33 | def generate(self): 34 | return next(iter(self.dataloader))[0] 35 | -------------------------------------------------------------------------------- /scaelum/dataset/dataset.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import random 6 | 7 | import torch 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from scaelum.registry import DATASET 11 | from torch.utils.data import Dataset 12 | 13 | 14 | @DATASET.register_module 15 | class RandomMlpDataset(Dataset): 16 | def __init__(self, num=1000, dim=1024): 17 | self.dim = dim 18 | self.data = torch.rand(num, dim) # , 224, 224) 19 | 20 | def __len__(self): 21 | return self.data.size(0) 22 | 23 | def __getitem__(self, idx): 24 | return self.data[idx], random.randint(0, self.dim - 1) 25 | 26 | 27 | @DATASET.register_module 28 | class CIFAR10Dataset(Dataset): 29 | def __init__(self, mean, std, *args, **kwargs): 30 | transform_train = transforms.Compose( 31 | [ 32 | # transforms.ToPILImage(), 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.RandomRotation(15), 36 | transforms.ToTensor(), 37 | transforms.Normalize(mean, std), 38 | ] 39 | ) 40 | self.cifar10dataset = torchvision.datasets.CIFAR10( 41 | transform=transform_train, *args, **kwargs 42 | ) 43 | 44 | def __len__(self): 45 | return self.cifar10dataset.__len__() 46 | 47 | def __getitem__(self, idx): 48 | return self.cifar10dataset.__getitem__(idx) 49 | -------------------------------------------------------------------------------- /scaelum/dataset/glue/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/SkyComputing/456c749d87f0fdda551635937ee083f41e6e340d/scaelum/dataset/glue/__init__.py -------------------------------------------------------------------------------- /scaelum/dataset/glue/file_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from __future__ import absolute_import, division, print_function, unicode_literals 6 | 7 | import json 8 | import logging 9 | import os 10 | import shutil 11 | import sys 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | from io import open 16 | 17 | import boto3 18 | import requests 19 | from botocore.exceptions import ClientError 20 | from tqdm import tqdm 21 | 22 | try: 23 | from urllib.parse import urlparse 24 | except ImportError: 25 | from urlparse import urlparse 26 | 27 | try: 28 | from pathlib import Path 29 | 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 31 | os.getenv( 32 | "PYTORCH_PRETRAINED_BERT_CACHE", Path.home() / ".pytorch_pretrained_bert" 33 | ) 34 | ) 35 | except AttributeError: 36 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( 37 | "PYTORCH_PRETRAINED_BERT_CACHE", 38 | os.path.join(os.path.expanduser("~"), ".pytorch_pretrained_bert"), 39 | ) 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode("utf-8") 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode("utf-8") 56 | etag_hash = sha256(etag_bytes) 57 | filename += "." 
+ etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + ".json" 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata["url"] 83 | etag = metadata["etag"] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ("http", "https", "s3"): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == "": 111 | # File, but it doesn't exist. 112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError( 116 | "unable to parse {} as a URL or as a local path".format(url_or_filename) 117 | ) 118 | 119 | 120 | def split_s3_path(url): 121 | """Split a full s3 path into the bucket name and path.""" 122 | parsed = urlparse(url) 123 | if not parsed.netloc or not parsed.path: 124 | raise ValueError("bad s3 path {}".format(url)) 125 | bucket_name = parsed.netloc 126 | s3_path = parsed.path 127 | # Remove '/' at beginning of path. 128 | if s3_path.startswith("/"): 129 | s3_path = s3_path[1:] 130 | return bucket_name, s3_path 131 | 132 | 133 | def s3_request(func): 134 | """ 135 | Wrapper function for s3 requests in order to create more helpful error 136 | messages. 
137 | """ 138 | 139 | @wraps(func) 140 | def wrapper(url, *args, **kwargs): 141 | try: 142 | return func(url, *args, **kwargs) 143 | except ClientError as exc: 144 | if int(exc.response["Error"]["Code"]) == 404: 145 | raise EnvironmentError("file {} not found".format(url)) 146 | else: 147 | raise 148 | 149 | return wrapper 150 | 151 | 152 | @s3_request 153 | def s3_etag(url): 154 | """Check ETag on S3 object.""" 155 | s3_resource = boto3.resource("s3") 156 | bucket_name, s3_path = split_s3_path(url) 157 | s3_object = s3_resource.Object(bucket_name, s3_path) 158 | return s3_object.e_tag 159 | 160 | 161 | @s3_request 162 | def s3_get(url, temp_file): 163 | """Pull a file directly from S3.""" 164 | s3_resource = boto3.resource("s3") 165 | bucket_name, s3_path = split_s3_path(url) 166 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 167 | 168 | 169 | def http_get(url, temp_file): 170 | req = requests.get(url, stream=True) 171 | content_length = req.headers.get("Content-Length") 172 | total = int(content_length) if content_length is not None else None 173 | progress = tqdm(unit="B", total=total) 174 | for chunk in req.iter_content(chunk_size=1024): 175 | if chunk: # filter out keep-alive new chunks 176 | progress.update(len(chunk)) 177 | temp_file.write(chunk) 178 | progress.close() 179 | 180 | 181 | def get_from_cache(url, cache_dir=None): 182 | """ 183 | Given a URL, look for the corresponding dataset in the local cache. 184 | If it's not there, download it. Then return the path to the cached file. 185 | """ 186 | if cache_dir is None: 187 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 188 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 189 | cache_dir = str(cache_dir) 190 | 191 | if not os.path.exists(cache_dir): 192 | os.makedirs(cache_dir) 193 | 194 | # Get eTag to add to filename, if it exists. 195 | if url.startswith("s3://"): 196 | etag = s3_etag(url) 197 | else: 198 | response = requests.head(url, allow_redirects=True) 199 | if response.status_code != 200: 200 | raise IOError( 201 | "HEAD request failed for url {} with status code {}".format( 202 | url, response.status_code 203 | ) 204 | ) 205 | etag = response.headers.get("ETag") 206 | 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | if not os.path.exists(cache_path): 213 | # Download to temporary file, then copy to cache dir once finished. 214 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
215 | with tempfile.NamedTemporaryFile() as temp_file: 216 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 217 | 218 | # GET file object 219 | if url.startswith("s3://"): 220 | s3_get(url, temp_file) 221 | else: 222 | http_get(url, temp_file) 223 | 224 | # we are copying the file before closing it, so flush to avoid truncation 225 | temp_file.flush() 226 | # shutil.copyfileobj() starts at the current position, so go to the start 227 | temp_file.seek(0) 228 | 229 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 230 | with open(cache_path, "wb") as cache_file: 231 | shutil.copyfileobj(temp_file, cache_file) 232 | 233 | logger.info("creating metadata file for %s", cache_path) 234 | meta = {"url": url, "etag": etag} 235 | meta_path = cache_path + ".json" 236 | with open(meta_path, "w", encoding="utf-8") as meta_file: 237 | json.dump(meta, meta_file) 238 | 239 | logger.info("removing temp file %s", temp_file.name) 240 | 241 | return cache_path 242 | 243 | 244 | def read_set_from_file(filename): 245 | """ 246 | Extract a de-duped collection (set) of text from a file. 247 | Expected file format is one item per line. 248 | """ 249 | collection = set() 250 | with open(filename, "r", encoding="utf-8") as file_: 251 | for line in file_: 252 | collection.add(line.rstrip()) 253 | return collection 254 | 255 | 256 | def get_file_extension(path, dot=True, lower=True): 257 | ext = os.path.splitext(path)[1] 258 | ext = ext if dot else ext[1:] 259 | return ext.lower() if lower else ext 260 | -------------------------------------------------------------------------------- /scaelum/dataset/glue/processor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import csv 6 | import os 7 | import sys 8 | 9 | 10 | class InputExample(object): 11 | """A single training/test example for simple sequence classification.""" 12 | 13 | def __init__(self, guid, text_a, text_b=None, label=None): 14 | """Constructs a InputExample. 15 | Args: 16 | guid: Unique id for the example. 17 | text_a: string. The untokenized text of the first sequence. For 18 | single sequence tasks, only this sequence must be specified. 19 | text_b: (Optional) string. The untokenized text of the second 20 | sequence. Only must be specified for sequence pair tasks. 21 | label: (Optional) string. The label of the example. This should be 22 | specified for train and dev examples, but not for test 23 | examples. 
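        Example (illustrative values only):
            InputExample(guid="train-1",
                         text_a="The company bought the startup.",
                         text_b="The startup was acquired.",
                         label="1")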
24 | """ 25 | self.guid = guid 26 | self.text_a = text_a 27 | self.text_b = text_b 28 | self.label = label 29 | 30 | 31 | class InputFeatures(object): 32 | """A single set of features of data.""" 33 | 34 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 35 | self.input_ids = input_ids 36 | self.input_mask = input_mask 37 | self.segment_ids = segment_ids 38 | self.label_id = label_id 39 | 40 | 41 | class DataProcessor(object): 42 | """Base class for data converters for sequence classification data sets.""" 43 | 44 | def get_train_examples(self, data_dir): 45 | """Gets a collection of `InputExample`s for the train set.""" 46 | raise NotImplementedError() 47 | 48 | def get_dev_examples(self, data_dir): 49 | """Gets a collection of `InputExample`s for the dev set.""" 50 | raise NotImplementedError() 51 | 52 | def get_labels(self): 53 | """Gets the list of labels for this data set.""" 54 | raise NotImplementedError() 55 | 56 | @classmethod 57 | def _read_tsv(cls, input_file, quotechar=None): 58 | """Reads a tab separated value file.""" 59 | with open(input_file, "r", encoding="utf-8") as f: 60 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 61 | lines = [] 62 | for line in reader: 63 | if sys.version_info[0] == 2: 64 | line = list(unicode(cell, "utf-8") for cell in line) 65 | lines.append(line) 66 | return lines 67 | 68 | 69 | class MrpcProcessor(DataProcessor): 70 | """Processor for the MRPC data set (GLUE version).""" 71 | 72 | def get_train_examples(self, data_dir): 73 | """See base class.""" 74 | return self._create_examples( 75 | self._read_tsv(os.path.join(data_dir, "train.tsv")), 76 | "train", 77 | ) 78 | 79 | def get_dev_examples(self, data_dir): 80 | """See base class.""" 81 | return self._create_examples( 82 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), 83 | "dev", 84 | ) 85 | 86 | def get_labels(self): 87 | """See base class.""" 88 | return ["0", "1"] 89 | 90 | def _create_examples(self, lines, set_type): 91 | """Creates examples for the training and dev sets.""" 92 | examples = [] 93 | for (i, line) in enumerate(lines): 94 | if i == 0: 95 | continue 96 | guid = "%s-%s" % (set_type, i) 97 | text_a = line[3] 98 | text_b = line[4] 99 | label = line[0] 100 | examples.append( 101 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) 102 | ) 103 | return examples 104 | 105 | 106 | class MnliProcessor(DataProcessor): 107 | """Processor for the MultiNLI data set (GLUE version).""" 108 | 109 | def get_train_examples(self, data_dir): 110 | """See base class.""" 111 | return self._create_examples( 112 | self._read_tsv(os.path.join(data_dir, "train.tsv")), 113 | "train", 114 | ) 115 | 116 | def get_dev_examples(self, data_dir): 117 | """See base class.""" 118 | return self._create_examples( 119 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 120 | "dev_matched", 121 | ) 122 | 123 | def get_labels(self): 124 | """See base class.""" 125 | return ["contradiction", "entailment", "neutral"] 126 | 127 | def _create_examples(self, lines, set_type): 128 | """Creates examples for the training and dev sets.""" 129 | examples = [] 130 | for (i, line) in enumerate(lines): 131 | if i == 0: 132 | continue 133 | guid = "%s-%s" % (set_type, line[0]) 134 | text_a = line[8] 135 | text_b = line[9] 136 | label = line[-1] 137 | examples.append( 138 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label) 139 | ) 140 | return examples 141 | 142 | 143 | class ColaProcessor(DataProcessor): 144 | """Processor for the CoLA data set (GLUE 
version).""" 145 | 146 | def get_train_examples(self, data_dir): 147 | """See base class.""" 148 | return self._create_examples( 149 | self._read_tsv(os.path.join(data_dir, "train.tsv")), 150 | "train", 151 | ) 152 | 153 | def get_dev_examples(self, data_dir): 154 | """See base class.""" 155 | return self._create_examples( 156 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), 157 | "dev", 158 | ) 159 | 160 | def get_labels(self): 161 | """See base class.""" 162 | return ["0", "1"] 163 | 164 | def _create_examples(self, lines, set_type): 165 | """Creates examples for the training and dev sets.""" 166 | examples = [] 167 | for (i, line) in enumerate(lines): 168 | guid = "%s-%s" % (set_type, i) 169 | text_a = line[3] 170 | label = line[1] 171 | examples.append( 172 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label) 173 | ) 174 | return examples 175 | 176 | 177 | class Sst2Processor(DataProcessor): 178 | """Processor for the CoLA data set (GLUE version).""" 179 | 180 | def get_train_examples(self, data_dir): 181 | """See base class.""" 182 | return self._create_examples( 183 | self._read_tsv(os.path.join(data_dir, "train.tsv")), 184 | "train", 185 | ) 186 | 187 | def get_dev_examples(self, data_dir): 188 | """See base class.""" 189 | return self._create_examples( 190 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), 191 | "dev", 192 | ) 193 | 194 | def get_labels(self): 195 | """See base class.""" 196 | return ["0", "1"] 197 | 198 | def _create_examples(self, lines, set_type): 199 | """Creates examples for the training and dev sets.""" 200 | examples = [] 201 | for (i, line) in enumerate(lines): 202 | if i == 0: 203 | continue 204 | guid = "%s-%s" % (set_type, i) 205 | text_a = line[0] 206 | label = line[1] 207 | examples.append( 208 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label) 209 | ) 210 | return examples 211 | 212 | 213 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 214 | """Loads a data file into a list of `InputBatch`s.""" 215 | 216 | label_map = {label: i for i, label in enumerate(label_list)} 217 | 218 | features = [] 219 | for (ex_index, example) in enumerate(examples): 220 | tokens_a = tokenizer.tokenize(example.text_a) 221 | 222 | tokens_b = None 223 | if example.text_b: 224 | tokens_b = tokenizer.tokenize(example.text_b) 225 | # Modifies `tokens_a` and `tokens_b` in place so that the total 226 | # length is less than the specified length. 227 | # Account for [CLS], [SEP], [SEP] with "- 3" 228 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 229 | else: 230 | # Account for [CLS] and [SEP] with "- 2" 231 | if len(tokens_a) > max_seq_length - 2: 232 | tokens_a = tokens_a[: (max_seq_length - 2)] 233 | 234 | # The convention in BERT is: 235 | # (a) For sequence pairs: 236 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 238 | # (b) For single sequences: 239 | # tokens: [CLS] the dog is hairy . [SEP] 240 | # type_ids: 0 0 0 0 0 0 0 241 | # 242 | # Where "type_ids" are used to indicate whether this is the first 243 | # sequence or the second sequence. The embedding vectors for `type=0` and 244 | # `type=1` were learned during pre-training and are added to the wordpiece 245 | # embedding vector (and position vector). This is not *strictly* necessary 246 | # since the [SEP] token unambigiously separates the sequences, but it makes 247 | # it easier for the model to learn the concept of sequences. 
248 | # 249 | # For classification tasks, the first vector (corresponding to [CLS]) is 250 | # used as as the "sentence vector". Note that this only makes sense because 251 | # the entire model is fine-tuned. 252 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 253 | segment_ids = [0] * len(tokens) 254 | 255 | if tokens_b: 256 | tokens += tokens_b + ["[SEP]"] 257 | segment_ids += [1] * (len(tokens_b) + 1) 258 | 259 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 260 | 261 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 262 | # tokens are attended to. 263 | input_mask = [1] * len(input_ids) 264 | 265 | # Zero-pad up to the sequence length. 266 | padding = [0] * (max_seq_length - len(input_ids)) 267 | input_ids += padding 268 | input_mask += padding 269 | segment_ids += padding 270 | 271 | assert len(input_ids) == max_seq_length 272 | assert len(input_mask) == max_seq_length 273 | assert len(segment_ids) == max_seq_length 274 | 275 | label_id = label_map[example.label] 276 | 277 | features.append( 278 | InputFeatures( 279 | input_ids=input_ids, 280 | input_mask=input_mask, 281 | segment_ids=segment_ids, 282 | label_id=label_id, 283 | ) 284 | ) 285 | return features, label_map 286 | 287 | 288 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 289 | """Truncates a sequence pair in place to the maximum length.""" 290 | 291 | # This is a simple heuristic which will always truncate the longer sequence 292 | # one token at a time. This makes more sense than truncating an equal percent 293 | # of tokens from each, since if one sequence is very short then each token 294 | # that's truncated likely contains more information than a longer sequence. 295 | while True: 296 | total_length = len(tokens_a) + len(tokens_b) 297 | if total_length <= max_length: 298 | break 299 | if len(tokens_a) > len(tokens_b): 300 | tokens_a.pop() 301 | else: 302 | tokens_b.pop() 303 | 304 | 305 | PROCESSORS = { 306 | "cola": ColaProcessor, 307 | "mnli": MnliProcessor, 308 | "mrpc": MrpcProcessor, 309 | "sst-2": Sst2Processor, 310 | } 311 | -------------------------------------------------------------------------------- /scaelum/dataset/glue/tokenization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from __future__ import absolute_import, division, print_function, unicode_literals 6 | 7 | import collections 8 | import logging 9 | import os 10 | import unicodedata 11 | from io import open 12 | 13 | import six 14 | 15 | from .file_utils import cached_path 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 20 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 21 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 22 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 23 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 24 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 25 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 26 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 27 | } 28 | 
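# Usage sketch (not part of this module's code): BertTokenizer.from_pretrained
# below resolves these names through cached_path(), e.g.
#   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# downloads and caches bert-base-uncased-vocab.txt, builds the token -> id map
# with load_vocab(), and caps max_len using the positional-embedding sizes in
# the next map.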
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 29 | "bert-base-uncased": 512, 30 | "bert-large-uncased": 512, 31 | "bert-base-cased": 512, 32 | "bert-large-cased": 512, 33 | "bert-base-multilingual-uncased": 512, 34 | "bert-base-multilingual-cased": 512, 35 | "bert-base-chinese": 512, 36 | } 37 | VOCAB_NAME = "vocab.txt" 38 | 39 | 40 | def convert_to_unicode(text): 41 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 42 | if six.PY3: 43 | if isinstance(text, str): 44 | return text 45 | elif isinstance(text, bytes): 46 | return text.decode("utf-8", "ignore") 47 | else: 48 | raise ValueError("Unsupported string type: %s" % (type(text))) 49 | elif six.PY2: 50 | if isinstance(text, str): 51 | return text.decode("utf-8", "ignore") 52 | elif isinstance(text, unicode): 53 | return text 54 | else: 55 | raise ValueError("Unsupported string type: %s" % (type(text))) 56 | else: 57 | raise ValueError("Not running on Python2 or Python 3?") 58 | 59 | 60 | def load_vocab(vocab_file): 61 | """Loads a vocabulary file into a dictionary.""" 62 | vocab = collections.OrderedDict() 63 | index = 0 64 | with open(vocab_file, "r", encoding="utf-8") as reader: 65 | while True: 66 | token = reader.readline() 67 | if not token: 68 | break 69 | token = token.strip() 70 | vocab[token] = index 71 | index += 1 72 | return vocab 73 | 74 | 75 | def whitespace_tokenize(text): 76 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 77 | text = text.strip() 78 | if not text: 79 | return [] 80 | tokens = text.split() 81 | return tokens 82 | 83 | 84 | class BertTokenizer(object): 85 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 86 | 87 | def __init__( 88 | self, 89 | vocab_file, 90 | do_lower_case=True, 91 | max_len=None, 92 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), 93 | ): 94 | if not os.path.isfile(vocab_file): 95 | raise ValueError( 96 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 97 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 98 | vocab_file 99 | ) 100 | ) 101 | self.vocab = load_vocab(vocab_file) 102 | self.ids_to_tokens = collections.OrderedDict( 103 | [(ids, tok) for tok, ids in self.vocab.items()] 104 | ) 105 | self.basic_tokenizer = BasicTokenizer( 106 | do_lower_case=do_lower_case, never_split=never_split 107 | ) 108 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 109 | self.max_len = max_len if max_len is not None else int(1e12) 110 | 111 | def tokenize(self, text): 112 | split_tokens = [] 113 | for token in self.basic_tokenizer.tokenize(text): 114 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 115 | split_tokens.append(sub_token) 116 | return split_tokens 117 | 118 | def convert_tokens_to_ids(self, tokens): 119 | """Converts a sequence of tokens into ids using the vocab.""" 120 | ids = [] 121 | for token in tokens: 122 | ids.append(self.vocab[token]) 123 | if len(ids) > self.max_len: 124 | raise ValueError( 125 | "Token indices sequence length is longer than the specified maximum " 126 | " sequence length for this BERT model ({} > {}). 
Running this" 127 | " sequence through BERT will result in indexing errors".format( 128 | len(ids), self.max_len 129 | ) 130 | ) 131 | return ids 132 | 133 | def convert_ids_to_tokens(self, ids): 134 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 135 | tokens = [] 136 | for i in ids: 137 | tokens.append(self.ids_to_tokens[i]) 138 | return tokens 139 | 140 | @classmethod 141 | def from_pretrained( 142 | cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs 143 | ): 144 | """ 145 | Instantiate a PreTrainedBertModel from a pre-trained model file. 146 | Download and cache the pre-trained model file if needed. 147 | """ 148 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 149 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 150 | else: 151 | vocab_file = pretrained_model_name_or_path 152 | if os.path.isdir(vocab_file): 153 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 154 | # redirect to the cache, if necessary 155 | try: 156 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 157 | except EnvironmentError: 158 | logger.error( 159 | "Model name '{}' was not found in model name list ({}). " 160 | "We assumed '{}' was a path or url but couldn't find any file " 161 | "associated to this path or url.".format( 162 | pretrained_model_name_or_path, 163 | ", ".join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 164 | vocab_file, 165 | ) 166 | ) 167 | return None 168 | if resolved_vocab_file == vocab_file: 169 | logger.info("loading vocabulary file {}".format(vocab_file)) 170 | else: 171 | logger.info( 172 | "loading vocabulary file {} from cache at {}".format( 173 | vocab_file, resolved_vocab_file 174 | ) 175 | ) 176 | if ( 177 | pretrained_model_name_or_path 178 | in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP 179 | ): 180 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 181 | # than the number of positional embeddings 182 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ 183 | pretrained_model_name_or_path 184 | ] 185 | kwargs["max_len"] = min(kwargs.get("max_len", int(1e12)), max_len) 186 | # Instantiate tokenizer. 187 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 188 | return tokenizer 189 | 190 | 191 | class BasicTokenizer(object): 192 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 193 | 194 | def __init__( 195 | self, 196 | do_lower_case=True, 197 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), 198 | ): 199 | """Constructs a BasicTokenizer. 200 | Args: 201 | do_lower_case: Whether to lower case the input. 202 | """ 203 | self.do_lower_case = do_lower_case 204 | self.never_split = never_split 205 | 206 | def tokenize(self, text): 207 | """Tokenizes a piece of text.""" 208 | text = self._clean_text(text) 209 | # This was added on November 1st, 2018 for the multilingual and Chinese 210 | # models. This is also applied to the English models now, but it doesn't 211 | # matter since the English models were not trained on any Chinese data 212 | # and generally don't have any Chinese data in them (there are Chinese 213 | # characters in the vocabulary because Wikipedia does have some Chinese 214 | # words in the English Wikipedia.). 
215 | text = self._tokenize_chinese_chars(text) 216 | orig_tokens = whitespace_tokenize(text) 217 | split_tokens = [] 218 | for token in orig_tokens: 219 | if self.do_lower_case and token not in self.never_split: 220 | token = token.lower() 221 | token = self._run_strip_accents(token) 222 | split_tokens.extend(self._run_split_on_punc(token)) 223 | 224 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 225 | return output_tokens 226 | 227 | def _run_strip_accents(self, text): 228 | """Strips accents from a piece of text.""" 229 | text = unicodedata.normalize("NFD", text) 230 | output = [] 231 | for char in text: 232 | cat = unicodedata.category(char) 233 | if cat == "Mn": 234 | continue 235 | output.append(char) 236 | return "".join(output) 237 | 238 | def _run_split_on_punc(self, text): 239 | """Splits punctuation on a piece of text.""" 240 | if text in self.never_split: 241 | return [text] 242 | chars = list(text) 243 | i = 0 244 | start_new_word = True 245 | output = [] 246 | while i < len(chars): 247 | char = chars[i] 248 | if _is_punctuation(char): 249 | output.append([char]) 250 | start_new_word = True 251 | else: 252 | if start_new_word: 253 | output.append([]) 254 | start_new_word = False 255 | output[-1].append(char) 256 | i += 1 257 | 258 | return ["".join(x) for x in output] 259 | 260 | def _tokenize_chinese_chars(self, text): 261 | """Adds whitespace around any CJK character.""" 262 | output = [] 263 | for char in text: 264 | cp = ord(char) 265 | if self._is_chinese_char(cp): 266 | output.append(" ") 267 | output.append(char) 268 | output.append(" ") 269 | else: 270 | output.append(char) 271 | return "".join(output) 272 | 273 | def _is_chinese_char(self, cp): 274 | """Checks whether CP is the codepoint of a CJK character.""" 275 | # This defines a "chinese character" as anything in the CJK Unicode block: 276 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 277 | # 278 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 279 | # despite its name. The modern Korean Hangul alphabet is a different block, 280 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 281 | # space-separated words, so they are not treated specially and handled 282 | # like the all of the other languages. 283 | if ( 284 | (cp >= 0x4E00 and cp <= 0x9FFF) 285 | or (cp >= 0x3400 and cp <= 0x4DBF) # 286 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 287 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 288 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 289 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 290 | or (cp >= 0xF900 and cp <= 0xFAFF) 291 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 292 | ): # 293 | return True 294 | 295 | return False 296 | 297 | def _clean_text(self, text): 298 | """Performs invalid character removal and whitespace cleanup on text.""" 299 | output = [] 300 | for char in text: 301 | cp = ord(char) 302 | if cp == 0 or cp == 0xFFFD or _is_control(char): 303 | continue 304 | if _is_whitespace(char): 305 | output.append(" ") 306 | else: 307 | output.append(char) 308 | return "".join(output) 309 | 310 | 311 | class WordpieceTokenizer(object): 312 | """Runs WordPiece tokenization.""" 313 | 314 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 315 | self.vocab = vocab 316 | self.unk_token = unk_token 317 | self.max_input_chars_per_word = max_input_chars_per_word 318 | 319 | def tokenize(self, text): 320 | """Tokenizes a piece of text into its word pieces. 
321 | This uses a greedy longest-match-first algorithm to perform tokenization 322 | using the given vocabulary. 323 | For example: 324 | input = "unaffable" 325 | output = ["un", "##aff", "##able"] 326 | Args: 327 | text: A single token or whitespace separated tokens. This should have 328 | already been passed through `BasicTokenizer`. 329 | Returns: 330 | A list of wordpiece tokens. 331 | """ 332 | 333 | output_tokens = [] 334 | for token in whitespace_tokenize(text): 335 | chars = list(token) 336 | if len(chars) > self.max_input_chars_per_word: 337 | output_tokens.append(self.unk_token) 338 | continue 339 | 340 | is_bad = False 341 | start = 0 342 | sub_tokens = [] 343 | while start < len(chars): 344 | end = len(chars) 345 | cur_substr = None 346 | while start < end: 347 | substr = "".join(chars[start:end]) 348 | if start > 0: 349 | substr = "##" + substr 350 | if substr in self.vocab: 351 | cur_substr = substr 352 | break 353 | end -= 1 354 | if cur_substr is None: 355 | is_bad = True 356 | break 357 | sub_tokens.append(cur_substr) 358 | start = end 359 | 360 | if is_bad: 361 | output_tokens.append(self.unk_token) 362 | else: 363 | output_tokens.extend(sub_tokens) 364 | return output_tokens 365 | 366 | 367 | def _is_whitespace(char): 368 | """Checks whether `chars` is a whitespace character.""" 369 | # \t, \n, and \r are technically contorl characters but we treat them 370 | # as whitespace since they are generally considered as such. 371 | if char == " " or char == "\t" or char == "\n" or char == "\r": 372 | return True 373 | cat = unicodedata.category(char) 374 | if cat == "Zs": 375 | return True 376 | return False 377 | 378 | 379 | def _is_control(char): 380 | """Checks whether `chars` is a control character.""" 381 | # These are technically control characters but we count them as whitespace 382 | # characters. 383 | if char == "\t" or char == "\n" or char == "\r": 384 | return False 385 | cat = unicodedata.category(char) 386 | if cat.startswith("C"): 387 | return True 388 | return False 389 | 390 | 391 | def _is_punctuation(char): 392 | """Checks whether `chars` is a punctuation character.""" 393 | cp = ord(char) 394 | # We treat all non-letter/number ASCII as punctuation. 395 | # Characters such as "^", "$", and "`" are not in the Unicode 396 | # Punctuation class but we treat them as punctuation anyways, for 397 | # consistency. 
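    # e.g. "$" (cp 36) and "+" (cp 43) fall inside the 33-47 range and are treated
    # as punctuation here, although Unicode classifies them as symbols (Sc and Sm).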
398 | if ( 399 | (cp >= 33 and cp <= 47) 400 | or (cp >= 58 and cp <= 64) 401 | or (cp >= 91 and cp <= 96) 402 | or (cp >= 123 and cp <= 126) 403 | ): 404 | return True 405 | cat = unicodedata.category(char) 406 | if cat.startswith("P"): 407 | return True 408 | return False 409 | -------------------------------------------------------------------------------- /scaelum/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from .allocator import Allocator 2 | from .benchmarker import ModelBenchmarker, DeviceBenchmarker 3 | from .estimator import Estimator 4 | from .parameter_server import ParameterServer 5 | from .worker import Worker 6 | from .worker_manager import WorkerManager 7 | -------------------------------------------------------------------------------- /scaelum/dynamics/allocator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import math 5 | from loguru import logger 6 | import pulp 7 | 8 | from .benchmarker import ModelBenchmarker, DeviceBenchmarker 9 | from .worker_manager import WorkerManager 10 | 11 | 12 | class Allocator(object): 13 | def __init__( 14 | self, 15 | model_cfg: dict, 16 | worker_manager: WorkerManager, 17 | model_benchmarker: ModelBenchmarker, 18 | device_benchmarker: DeviceBenchmarker, 19 | ): 20 | self._model_cfg = model_cfg 21 | self._worker_manager = worker_manager 22 | self._model_benchmarker = model_benchmarker 23 | self._device_benchmarker = device_benchmarker 24 | 25 | def optimal_allocate(self, max_time=300, threads=24): 26 | 27 | # benchmark 28 | worker_ranks, workers_performance = zip( 29 | *self._device_benchmarker.benchmark().items() 30 | ) 31 | lf, lm = self._model_benchmarker.benchmark() 32 | # logger.info(f"layers flops: {lf}") 33 | # logger.info(f"layers memories: {lm}") 34 | 35 | D = len(worker_ranks) 36 | L = len(lf) 37 | 38 | # parse the results 39 | worker_ranks = [int(item.lstrip("worker")) for item in worker_ranks] 40 | logger.info(f"worker ranks: {worker_ranks}") 41 | dt = [item["time"] for item in workers_performance] 42 | logger.info(f"worker time: {dt}") 43 | dm = [item["avai_mem"] for item in workers_performance] 44 | # logger.info(f"worker memory limit: {dm}") 45 | 46 | # solve problem 47 | model = pulp.LpProblem("optimal_allocate", pulp.LpMinimize) 48 | logger.info("set up MIP") 49 | 50 | # create variables 51 | x = pulp.LpVariable.matrix("x", (range(D), range(L)), cat=pulp.LpBinary) 52 | y = pulp.LpVariable.matrix("y", range(D), lowBound=0, upBound=L) 53 | z = pulp.LpVariable.matrix("z", range(D), lowBound=0, upBound=L) 54 | q = pulp.LpVariable("max_device_time") 55 | logger.info("added all variables") 56 | 57 | # add one feasible solution to pre-solve 58 | avg_num_layer = math.floor(L / D) 59 | num_remain_layer = L - avg_num_layer * D 60 | 61 | bias = 0 62 | for row, device in enumerate(x): 63 | start_idx = row * avg_num_layer + bias 64 | if num_remain_layer > 0: 65 | num_remain_layer -= 1 66 | bias += 1 67 | end_idx = row * avg_num_layer + avg_num_layer - 1 + bias 68 | z[row].setInitialValue(start_idx) 69 | y[row].setInitialValue(end_idx) 70 | for col, layer in enumerate(device): 71 | if start_idx <= col <= end_idx: 72 | layer.setInitialValue(1) 73 | else: 74 | layer.setInitialValue(0) 75 | logger.info("added one feasible solution") 76 | 77 | # objective function 78 | model.objective = q 79 | logger.info("add obj.") 80 | 81 | # add constraints 82 | 83 | # constraint 1 84 | for 
i in range(D): 85 | model += ( 86 | pulp.LpAffineExpression([(x[i][j], lm[j]) for j in range(L)]) <= dm[i] 87 | ) 88 | 89 | # constraint 2 and 3 90 | for i in range(D): 91 | for j in range(L): 92 | model += y[i] >= j * x[i][j] 93 | model += z[i] <= j * x[i][j] + (L + 1) * (1 - x[i][j]) 94 | 95 | # constraint 4 96 | for i in range(D): 97 | model += y[i] - z[i] <= pulp.lpSum(x[i][j] for j in range(L)) - 1 98 | 99 | # constraint 5 100 | for j in range(L): 101 | model += pulp.lpSum(x[i][j] for i in range(D)) == 1 102 | 103 | # constraint 6 104 | for i in range(D): 105 | model += q >= dt[i] * pulp.lpSum(x[i][j] * lf[j] for j in range(L)) 106 | 107 | logger.info("added all constraints") 108 | 109 | solver_list = pulp.listSolvers(onlyAvailable=True) 110 | 111 | if "GUROBI_CMD" in solver_list: 112 | logger.info("using gurobi as solver") 113 | model.solve( 114 | pulp.GUROBI_CMD( 115 | timeLimit=max_time, 116 | msg=True, 117 | gapRel=0.2, 118 | threads=threads, 119 | warmStart=True, 120 | ) 121 | ) 122 | else: 123 | logger.info("using CBC as solver") 124 | model.solve( 125 | pulp.PULP_CBC_CMD( 126 | timeLimit=max_time, 127 | msg=True, 128 | gapRel=0.2, 129 | threads=threads, 130 | warmStart=True, 131 | ) 132 | ) 133 | 134 | for i in z: 135 | print(i.value(), end=" ") 136 | print() 137 | for i in y: 138 | print(i.value(), end=" ") 139 | print() 140 | 141 | # allocate to 142 | partition = [] 143 | for i in range(D): 144 | info = { 145 | "rank": worker_ranks[i], 146 | "start": int(z[i].value()), 147 | "end": int(y[i].value()), 148 | } 149 | partition.append(info) 150 | # sort partition by idx 151 | partition.sort(key=lambda t: t["start"]) 152 | print(partition) 153 | 154 | for i, info in enumerate(partition): 155 | for worker in self._worker_manager.worker_pool: 156 | if info["rank"] == worker.rank: 157 | print(f"rank {worker.rank}", end=" ") 158 | layers = self._model_cfg[info["start"] : info["end"] + 1] 159 | print(f"has layer {info['start']} to layer {info['end']}", end=" ") 160 | worker.model_config = layers 161 | print("and set up new config") 162 | worker.order = i + 1 163 | print(f"rank {worker.rank}'s order: {worker.order}") 164 | 165 | # for i, rank in enumerate(worker_ranks): 166 | # for worker in self._worker_manager.worker_pool: 167 | # if worker.rank == rank: 168 | # layers = self._model_cfg[int(z[i].value()):int(y[i].value())] 169 | # print(f"rank {rank} has layer {int(z[i].value())} to {int(y[i].value())}") 170 | # worker.model_config = layers 171 | # worker.order = i + 1 172 | 173 | self._worker_manager.reset_rank_by_order() 174 | print("reset by order") 175 | 176 | for worker in self._worker_manager.worker_pool: 177 | print(worker.rank) 178 | 179 | return self._worker_manager 180 | 181 | def dynamic_allocate(self, break_iter=1000): 182 | """ 183 | Allocate the layers dynamically among the workers 184 | """ 185 | 186 | # get results 187 | worker_time_and_avai_mem = self._device_benchmarker.benchmark() 188 | layer_flops, layer_mem = self._model_benchmarker.benchmark() 189 | 190 | print("worker_time_and_avai_mem: {}".format(worker_time_and_avai_mem)) 191 | # print('layer_flops: {}'.format(layer_flops)) 192 | # print('layer_mem: {}'.format(layer_mem)) 193 | 194 | # parse the results 195 | worker_time_and_avai_mem = list(worker_time_and_avai_mem.items()) 196 | worker_ranks = [ 197 | int(item[0].lstrip("worker")) for item in worker_time_and_avai_mem 198 | ] 199 | worker_time = [item[1]["time"] for item in worker_time_and_avai_mem] 200 | worker_avai_mem = [item[1]["avai_mem"] for item in 
worker_time_and_avai_mem] 201 | 202 | # check if the smallest worker avai mem can hold the smallest layer 203 | assert min(worker_avai_mem) > min( 204 | layer_mem 205 | ), "The smallest worker has insufficient memory for smallest layer" 206 | 207 | # create partition index 208 | num_layer = len(layer_flops) 209 | num_worker = len(worker_ranks) 210 | avg_num_layers = math.floor(num_layer / num_worker) 211 | remainder = num_layer - avg_num_layers * num_worker 212 | num_layers_on_worker = [avg_num_layers] * num_worker 213 | 214 | for i in range(num_worker): 215 | if remainder > 0: 216 | num_layers_on_worker[i] += 1 217 | remainder -= 1 218 | else: 219 | break 220 | partition_idx = [0] + [ 221 | sum(num_layers_on_worker[:idx]) for idx in range(1, num_worker + 1) 222 | ] 223 | 224 | # partition based on benchmark results 225 | partition_idx = self._allocate_by_mem( 226 | worker_rank=worker_ranks, 227 | partition_idx=partition_idx, 228 | worker_avai_mem=worker_avai_mem, 229 | layer_mem=layer_mem, 230 | ) 231 | partition_idx = self._allocate_by_flops_time( 232 | worker_rank=worker_ranks, 233 | partition_idx=partition_idx, 234 | worker_time=worker_time, 235 | layer_flops=layer_flops, 236 | worker_avai_mem=worker_avai_mem, 237 | layer_mem=layer_mem, 238 | break_iter=break_iter, 239 | ) 240 | 241 | # allocate to configs 242 | for i, rank in enumerate(worker_ranks): 243 | for worker in self._worker_manager.worker_pool: 244 | if worker.rank == rank: 245 | print(f"rank {worker.rank}", end=" ") 246 | layers = self._model_cfg[partition_idx[i] : partition_idx[i + 1]] 247 | print( 248 | f"rank {rank} has layer {int(partition_idx[i])} to {partition_idx[i + 1]}" 249 | ) 250 | worker.model_config = layers 251 | worker.order = i + 1 252 | 253 | self._worker_manager.reset_rank_by_order() 254 | for worker in self._worker_manager.worker_pool: 255 | print(worker.rank, end=" ") 256 | 257 | return self._worker_manager 258 | 259 | def even_allocate(self): 260 | """ 261 | Allocate the layers equally among the workers based on the number of layers 262 | """ 263 | num_worker = len(self._worker_manager.worker_pool) 264 | num_layer = len(self._model_cfg) 265 | avg_num_layer = math.floor(num_layer / num_worker) 266 | num_remain_layer = num_layer - avg_num_layer * num_worker 267 | cur_layer_idx = 0 268 | 269 | for idx, worker in enumerate(self._worker_manager.worker_pool): 270 | if num_remain_layer > 0: 271 | num_remain_layer -= 1 272 | cur_num_layer = avg_num_layer + 1 273 | else: 274 | cur_num_layer = avg_num_layer 275 | 276 | layers = self._model_cfg[cur_layer_idx : cur_layer_idx + cur_num_layer] 277 | worker.model_config = layers 278 | cur_layer_idx = cur_layer_idx + cur_num_layer 279 | 280 | return self._worker_manager 281 | 282 | def _get_num_layers_on_worker(self, index, partition_idx): 283 | return partition_idx[index + 1] - partition_idx[index] 284 | 285 | def _is_last_worker(self, index, worker_rank): 286 | return index == len(worker_rank) - 1 287 | 288 | def _list_greater_than(self, l1, l2): 289 | for x, y in zip(l1, l2): 290 | if x < y: 291 | return False 292 | 293 | return True 294 | 295 | def _allocate_by_flops_time( 296 | self, 297 | worker_rank, 298 | partition_idx, 299 | worker_time, 300 | layer_flops, 301 | worker_avai_mem, 302 | layer_mem, 303 | break_iter, 304 | ): 305 | # normalize time results 306 | worker_time = [item / min(worker_time) for item in worker_time] 307 | 308 | # iteratively update partition index based on flops * time 309 | iter = 0 310 | while True: 311 | # calculate flops on each worker 
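            # Illustrative numbers (made up): with layer_flops = [4, 4, 2, 2],
            # partition_idx = [0, 2, 4] and normalized worker_time = [1.0, 2.0],
            # this evaluates to [(4 + 4) * 1.0, (2 + 2) * 2.0] = [8.0, 8.0]; the loop
            # below then nudges partition boundaries toward the average workload,
            # subject to each worker's available memory.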
312 | workers_flops_time_allocated = [ 313 | sum(layer_flops[partition_idx[j] : partition_idx[j + 1]]) 314 | * worker_time[j] 315 | for j in range(len(worker_rank)) 316 | ] 317 | 318 | # set the target flops * time on average 319 | target = sum(workers_flops_time_allocated) // len(worker_rank) 320 | 321 | old_partition_idx = partition_idx[:] 322 | 323 | for j in range(len(worker_rank) - 1): 324 | current_workload = ( 325 | sum(layer_flops[partition_idx[j] : partition_idx[j + 1]]) 326 | * worker_time[j] 327 | ) 328 | 329 | if ( 330 | current_workload < target 331 | and self._get_num_layers_on_worker(j + 1, partition_idx) > 1 332 | ): 333 | # add a layer if memory allows 334 | expected_ram_allocated = sum( 335 | layer_mem[partition_idx[j] : partition_idx[j + 1] + 1] 336 | ) 337 | if expected_ram_allocated < worker_avai_mem[j]: 338 | partition_idx[j + 1] += 1 339 | else: 340 | last_layer_workload_on_this_device = ( 341 | layer_flops[partition_idx[j + 1] - 1] * worker_time[j] 342 | ) 343 | workload_on_next_device = ( 344 | sum(layer_flops[partition_idx[j] : partition_idx[j + 1]]) 345 | * worker_time[j] 346 | ) 347 | 348 | if ( 349 | workload_on_next_device < target 350 | and current_workload 351 | > target + last_layer_workload_on_this_device 352 | and self._get_num_layers_on_worker(j, partition_idx) > 1 353 | ): 354 | next_worker_expected_ram_allocated = sum( 355 | layer_mem[partition_idx[j + 1] - 1 : partition_idx[j + 2]] 356 | ) 357 | if next_worker_expected_ram_allocated < worker_avai_mem[j + 1]: 358 | partition_idx[j + 1] -= 1 359 | 360 | if old_partition_idx == partition_idx: 361 | break 362 | 363 | iter += 1 364 | 365 | if iter == break_iter: 366 | break 367 | 368 | return partition_idx 369 | 370 | def _allocate_by_mem(self, worker_rank, partition_idx, worker_avai_mem, layer_mem): 371 | # flag for if allocation satisfy memory requirement 372 | mem_satisfy = False 373 | 374 | def _compute_mem_allocated(lm, pi, wr): 375 | return [sum(lm[pi[j] : pi[j + 1]]) for j in range(len(wr))] 376 | 377 | # iteratively update partition index based on mem_avai and mem_allocated 378 | while True: 379 | # calculate flops on each worker 380 | workers_mem_allocated = _compute_mem_allocated( 381 | layer_mem, partition_idx, worker_avai_mem 382 | ) 383 | 384 | # break the loop if mem allocated < avai mem on each worker 385 | if self._list_greater_than(worker_avai_mem, workers_mem_allocated): 386 | mem_satisfy = True 387 | break 388 | 389 | old_partition_idx = partition_idx[:] 390 | 391 | for j in range(len(worker_rank) - 1): 392 | while ( 393 | workers_mem_allocated[j] > worker_avai_mem[j] 394 | and partition_idx[j + 1] - partition_idx[j] > 1 395 | ): 396 | # remove a layer if memory is not enough 397 | partition_idx[j + 1] -= 1 398 | workers_mem_allocated = _compute_mem_allocated( 399 | layer_mem, partition_idx, worker_avai_mem 400 | ) 401 | 402 | if self._list_greater_than(worker_avai_mem, workers_mem_allocated): 403 | mem_satisfy = True 404 | break 405 | 406 | if mem_satisfy: 407 | break 408 | 409 | # add a layer if memory allows 410 | while ( 411 | workers_mem_allocated[j] < worker_avai_mem[j] 412 | and partition_idx[j + 2] - partition_idx[j + 1] > 1 413 | ): 414 | expected_ram_allocated = sum( 415 | layer_mem[partition_idx[j] : partition_idx[j + 1] + 1] 416 | ) 417 | if expected_ram_allocated < worker_avai_mem[j]: 418 | partition_idx[j + 1] += 1 419 | workers_mem_allocated = _compute_mem_allocated( 420 | layer_mem, partition_idx, worker_avai_mem 421 | ) 422 | else: 423 | break 424 | 425 | if 
self._list_greater_than(worker_avai_mem, workers_mem_allocated): 426 | mem_satisfy = True 427 | break 428 | 429 | if mem_satisfy: 430 | break 431 | 432 | if old_partition_idx == partition_idx: 433 | break 434 | 435 | if mem_satisfy: 436 | return partition_idx 437 | else: 438 | print(partition_idx) 439 | raise Exception("memory allocation failed") 440 | -------------------------------------------------------------------------------- /scaelum/dynamics/benchmarker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import abc 5 | import os 6 | import scaelum.utils as dutils 7 | import torch 8 | import torch.distributed.rpc as rpc 9 | from scaelum.builder import build_module_from_cfg, build_layer 10 | from scaelum.dataset import BaseGenerator 11 | from .estimator import Estimator 12 | from .worker_manager import WorkerManager 13 | 14 | if os.getenv("STIMULATE") is not None: 15 | from dllb.stimulator import Stimulator 16 | 17 | 18 | class BaseBenchmarker(object): 19 | """ 20 | A base class for benchmarking objects 21 | """ 22 | 23 | __metaclass__ = abc.ABCMeta 24 | 25 | @abc.abstractmethod 26 | def benchmark(self): 27 | raise NotImplementedError("not implemented yet") 28 | 29 | 30 | class DeviceBenchmarker(BaseBenchmarker): 31 | def __init__( 32 | self, 33 | worker_manager: WorkerManager, 34 | data_generator: BaseGenerator, 35 | model_config: dict, 36 | iterations: int, 37 | dtype: str = None, 38 | ): 39 | super(DeviceBenchmarker, self).__init__() 40 | self._worker_manager = worker_manager 41 | self._model_config = model_config 42 | self._data_generator = data_generator 43 | self._iterations = iterations 44 | self._dtype = dtype 45 | # If stimulate on HPC, then use stimulator to slowdown 46 | if os.getenv("STIMULATE") is not None: # TODO: Is the usage correct? 
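            # Assumption: the Stimulator emulates heterogeneous hardware on an
            # otherwise homogeneous cluster by providing per-rank slowdown factors;
            # benchmark() below multiplies the measured time by, and divides the
            # available memory by, those factors.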
47 | self._stimulator = Stimulator(self._worker_manager.size) 48 | 49 | @staticmethod 50 | def local_benchmark(rank, data, model_cfg, module_wrapper_cfg, iterations, dtype): 51 | """ 52 | A method to wrap the benchmarking code 53 | """ 54 | model = build_module_from_cfg( 55 | rank=rank, model_cfg=model_cfg, module_wrapper_cfg=module_wrapper_cfg 56 | ) 57 | 58 | device = next(model.parameters()).device 59 | # print( 60 | # 'rank : {}: device: {}, model gpu index: {}'.format(rank, device, model.gpu_index)) 61 | 62 | time = Estimator.benchmark_speed( 63 | model=model, data=data, device=device, iterations=iterations, dtype=dtype 64 | ) 65 | avai_mem = model.detect_mem(destroy_module=True) 66 | del model 67 | 68 | if torch.cuda.is_available(): 69 | torch.cuda.empty_cache() 70 | 71 | return time, avai_mem 72 | 73 | def benchmark(self): 74 | """ 75 | Run benchmarking to test the computational speed and available memory 76 | of the local or remote 77 | """ 78 | 79 | # init results 80 | results = dict() 81 | 82 | # get data 83 | data = self._data_generator.generate() 84 | 85 | # benchmarking on different devices 86 | result_queue = [] 87 | for worker in self._worker_manager.worker_pool: 88 | rank = worker.rank 89 | worker_name = dutils.generate_worker_name(rank) 90 | module_wrapper_cfg = worker.extra_config.copy() 91 | 92 | if rank == 0: 93 | # run locally 94 | time, avai_mem = self.local_benchmark( 95 | rank, 96 | data, 97 | self._model_config, 98 | module_wrapper_cfg, 99 | self._iterations, 100 | self._dtype, 101 | ) 102 | result_queue.append((worker_name, time, avai_mem)) 103 | else: 104 | res = rpc.rpc_async( 105 | to=worker_name, 106 | func=DeviceBenchmarker.local_benchmark, 107 | args=( 108 | rank, 109 | data, 110 | self._model_config, 111 | module_wrapper_cfg, 112 | self._iterations, 113 | self._dtype, 114 | ), 115 | ) 116 | result_queue.append((worker_name, res)) 117 | 118 | for res in result_queue: 119 | worker_name = res[0] 120 | 121 | if len(res) == 2: 122 | time, avai_mem = res[1].wait() 123 | else: 124 | time, avai_mem = res[1], res[2] 125 | 126 | if os.getenv("STIMULATE") is not None: 127 | # int(res[0].lstrip("worker")) get the original rank 128 | time *= self._stimulator.c_slowdown[int(res[0].lstrip("worker"))] 129 | avai_mem /= self._stimulator.m_slowdown[int(res[0].lstrip("worker"))] 130 | 131 | results[worker_name] = dict(time=time, avai_mem=avai_mem) 132 | 133 | return results 134 | 135 | 136 | class ModelBenchmarker(BaseBenchmarker): 137 | def __init__( 138 | self, 139 | model_config: dict, 140 | data_generator: BaseGenerator, 141 | device: str, 142 | dtype: str = None, 143 | param_scale: int = 2, 144 | ): 145 | super(ModelBenchmarker, self).__init__() 146 | self._model_config = model_config 147 | self._data_generator = data_generator 148 | self._device = device 149 | self._dtype = dtype 150 | self._param_scale = param_scale 151 | 152 | @property 153 | def model_config(self): 154 | return self._model_config 155 | 156 | def benchmark(self): 157 | flops_list = [] 158 | mem_list = [] 159 | 160 | # measure flops of each layer 161 | data = self._data_generator.generate() 162 | 163 | # NOTE: this only applies to BERT since the single machine cannot host such a large model and will cause OOM 164 | # TODO: Remove this is you wish to use this framework for other models 165 | num_encoder_layer = int((len(self._model_config) - 3) / 3) 166 | model_cfg = self._model_config[:4] + self._model_config[-2:] 167 | 168 | for idx, layer_cfg in enumerate(model_cfg): 169 | # build layer 170 | 
layer_cfg_copy = layer_cfg.copy() 171 | layer_type = layer_cfg_copy.pop("layer_type") 172 | layer = build_layer(layer_type, **layer_cfg_copy) 173 | 174 | # get flops and mem usage 175 | output, flops, mem_usage = Estimator.benchmark_model( 176 | model=layer, 177 | data=data, 178 | device=self._device, 179 | dtype=self._dtype, 180 | param_scale=self._param_scale, 181 | ) 182 | 183 | # remove layer to save RAM 184 | del layer 185 | if torch.cuda.is_available(): 186 | torch.cuda.empty_cache() 187 | 188 | # override input 189 | data = output 190 | 191 | # log 192 | flops_list.append(flops) 193 | mem_list.append(mem_usage) 194 | 195 | # NOTE: this only applies to BERT since the single machine cannot host such a large model and will cause OOM 196 | # TODO: Remove this is you wish to use this framework for other models 197 | flops_list = ( 198 | [flops_list[0]] + flops_list[1:4] * num_encoder_layer + flops_list[-2:] 199 | ) 200 | mem_list = [mem_list[0]] + mem_list[1:4] * num_encoder_layer + mem_list[-2:] 201 | return flops_list, mem_list 202 | -------------------------------------------------------------------------------- /scaelum/dynamics/estimator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from scaelum import utils 10 | from pthflops import count_ops 11 | 12 | 13 | class Estimator: 14 | @staticmethod 15 | def benchmark_speed(model, data, device, iterations, dtype=None): 16 | # convert to tuple 17 | data = Estimator._convert_to_tuple(data) 18 | 19 | # convert data type 20 | if dtype: 21 | data = Estimator._convert_dtype(data, dtype) 22 | 23 | # move to device 24 | data = Estimator._move_to_device(data, device) 25 | model = model.to(device) 26 | 27 | # measure forward time 28 | with torch.no_grad(): 29 | for i in range(iterations): 30 | model(*data) 31 | end = utils.get_time() 32 | total_time = sum(model.forward_time) 33 | # print('forward time on rank {}: {}'.format(model.rank, model.forward_time)) 34 | return total_time 35 | 36 | @staticmethod 37 | def _convert_dtype(data, dtype): 38 | assert isinstance(dtype, str) 39 | dtype = getattr(torch, dtype) 40 | data = [_data.to(dtype) for _data in data] 41 | return data 42 | 43 | @staticmethod 44 | def _move_to_device(data, device): 45 | data = [_data.to(device) for _data in data] 46 | return data 47 | 48 | @staticmethod 49 | def _convert_to_tuple(data): 50 | if isinstance(data, (list, tuple)): 51 | return data 52 | else: 53 | return (data,) 54 | 55 | @staticmethod 56 | def benchmark_model(model, data, device, dtype=None, param_scale=2): 57 | # convert to tuple 58 | data = Estimator._convert_to_tuple(data) 59 | 60 | # convert dtype 61 | if dtype: 62 | data = Estimator._convert_dtype(data, dtype) 63 | 64 | # move to device 65 | data = Estimator._move_to_device(data, device) 66 | model = model.to(device) 67 | 68 | # get output 69 | output = model(*data) 70 | 71 | flops = Estimator._calc_flops(model, data) 72 | mem_usage = Estimator._calc_memory_usage(model, data, param_scale) 73 | return output, flops, mem_usage 74 | 75 | @staticmethod 76 | def _calc_flops(model, data): 77 | # need to convert to tuple for multiple inputs for jit tracing 78 | if isinstance(data, list): 79 | data = tuple(data) 80 | assert isinstance(data, tuple) 81 | flops, _ = count_ops(model, data, print_readable=False) 82 | return flops 83 | 84 | @staticmethod 85 | def 
_calc_memory_usage(model, data, param_scale): 86 | assert isinstance(data, (list, tuple)) 87 | 88 | def register_hook(module): 89 | def hook(module, input, output): 90 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 91 | module_idx = len(summary) 92 | 93 | m_key = "%s-%i" % (class_name, module_idx + 1) 94 | summary[m_key] = OrderedDict() 95 | if isinstance(output, (list, tuple)): 96 | summary[m_key]["output_shape"] = [list(o.size()) for o in output] 97 | else: 98 | summary[m_key]["output_shape"] = [list(output.size())] 99 | 100 | # TODO: add reserved memory for CUDA and use a.nelement() * a.element_size() to calculate size 101 | params = 0 102 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 103 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 104 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 105 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 106 | summary[m_key]["nb_params"] = params 107 | 108 | if not isinstance(module, nn.Sequential) and not isinstance( 109 | module, nn.ModuleList 110 | ): 111 | hooks.append(module.register_forward_hook(hook)) 112 | 113 | # create properties 114 | summary = OrderedDict() 115 | hooks = [] 116 | 117 | # register hook 118 | model.apply(register_hook) 119 | 120 | # make a forward pass 121 | model(*data) 122 | 123 | # remove these hooks 124 | for h in hooks: 125 | h.remove() 126 | 127 | total_params = 0 128 | total_output = 0 129 | for layer in summary: 130 | # input_shape, output_shape, trainable, nb_params 131 | total_params += summary[layer]["nb_params"] 132 | total_output += sum( 133 | [ 134 | np.prod(output_shape) 135 | for output_shape in summary[layer]["output_shape"] 136 | ] 137 | ) 138 | 139 | # assume 4 bytes/number (float on cuda). 
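        # Worked example (hypothetical layer, not from this repo): a single
        # nn.Linear(1024, 1024) holds ~1.05M parameters, so with param_scale=2 the
        # estimate below gives total_params_size ~= 2 * 1.05e6 * 4 / 1024**2 ~= 8 MB;
        # the input and output terms are computed the same way from their shapes.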
140 | total_input_size = abs( 141 | sum([np.prod(_data.size()) * 4.0 / (1024**2.0) for _data in data]) 142 | ) 143 | total_output_size = abs( 144 | 2.0 * total_output * 4.0 / (1024**2.0) 145 | ) # x2 for gradients 146 | total_params_size = abs( 147 | param_scale * total_params * 4.0 / (1024**2.0) 148 | ) # x param_scale for backward 149 | total_size = total_params_size + total_output_size + total_input_size 150 | 151 | # return summary 152 | return total_size.item() 153 | -------------------------------------------------------------------------------- /scaelum/dynamics/parameter_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from collections import OrderedDict 6 | from typing import Dict 7 | 8 | import torch 9 | import torch.nn as nn 10 | from scaelum.builder import build_layer 11 | from torch import Tensor 12 | 13 | 14 | class ParameterServer(nn.Module): 15 | def __init__(self, model_config: list) -> None: 16 | super(ParameterServer, self).__init__() 17 | self._model_config = model_config 18 | self.module_list = nn.ModuleList() 19 | 20 | self._build_model() 21 | 22 | def _build_model(self) -> None: 23 | for cfg in self._model_config: 24 | cfg_copy = cfg.copy() 25 | layer_type = cfg_copy.pop("layer_type") 26 | layer = build_layer(layer_type, **cfg_copy) 27 | self.module_list.append(layer) 28 | 29 | def load_weights_from_file(self, checkpoint: str) -> None: 30 | self.module_list.load_state_dict(torch.load(checkpoint)) 31 | 32 | def save_weights_to_file(self, checkpoint: str) -> None: 33 | torch.save(self.module_list.state_dict(), checkpoint) 34 | 35 | def update_weights(self, state_dict: OrderedDict, idx: int) -> None: 36 | self.module_list[idx].load_state_dict(state_dict) 37 | 38 | def get_state_dict(self, idx: int) -> Dict[str, Tensor]: 39 | return self.module_list[idx].state_dict() 40 | -------------------------------------------------------------------------------- /scaelum/dynamics/worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import uuid 6 | 7 | 8 | class Worker(object): 9 | def __init__( 10 | self, 11 | rank: int, 12 | name: int, 13 | server_config: dict, 14 | worker_id: str = None, 15 | order: int = None, 16 | model_config: list = None, 17 | extra_config: dict = None, 18 | is_running: bool = False, 19 | ) -> None: 20 | 21 | self._rank = rank 22 | self._name = name 23 | self._is_running = is_running 24 | self._order = order 25 | if worker_id is None: 26 | self._worker_id = uuid.uuid4().__str__() 27 | else: 28 | self._worker_id = worker_id 29 | 30 | # configs 31 | self._server_config = server_config 32 | self._model_config = model_config 33 | self._extra_config = extra_config 34 | 35 | @property 36 | def rank(self) -> int: 37 | return self._rank 38 | 39 | @property 40 | def id(self) -> str: 41 | return self._worker_id 42 | 43 | @property 44 | def name(self) -> str: 45 | return self._name 46 | 47 | @property 48 | def model_config(self) -> list: 49 | return self._model_config 50 | 51 | @property 52 | def env_config(self) -> dict: 53 | return self._env_config 54 | 55 | @property 56 | def server_config(self) -> dict: 57 | return self._server_config 58 | 59 | @property 60 | def extra_config(self) -> dict: 61 | return self._extra_config 62 | 63 | @property 64 | def is_running(self) -> bool: 65 | return self._is_running 66 | 67 | @property 68 | def order(self) -> None: 
69 | return self._order 70 | 71 | @order.setter 72 | def order(self, ord: int) -> None: 73 | self._order = ord 74 | 75 | @is_running.setter 76 | def is_running(self, status: bool) -> None: 77 | self._is_running = status 78 | 79 | @model_config.setter 80 | def model_config(self, config: list) -> None: 81 | self._model_config = config 82 | 83 | @rank.setter 84 | def rank(self, rank) -> None: 85 | self._rank = rank 86 | 87 | def serialize(self): 88 | return self.__dict__ 89 | 90 | @staticmethod 91 | def deserialize(data: dict): 92 | kwargs = dict() 93 | 94 | for k, v in data.items(): 95 | kwargs[k.lstrip("_")] = v 96 | 97 | return Worker(**kwargs) 98 | -------------------------------------------------------------------------------- /scaelum/dynamics/worker_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from .worker import Worker 5 | 6 | 7 | class WorkerManager(object): 8 | def __init__(self): 9 | self._worker_pool = [] 10 | 11 | @property 12 | def size(self): 13 | return len(self._worker_pool) 14 | 15 | @property 16 | def worker_pool(self): 17 | return self._worker_pool 18 | 19 | def get_by_id(self, id_str: str, allow_not_found: bool = False) -> Worker: 20 | for worker in self._worker_pool: 21 | if worker.id == id_str: 22 | return worker 23 | 24 | if not allow_not_found: 25 | raise LookupError( 26 | "Worker with id {} is not found in the worker pool".format(id_str) 27 | ) 28 | else: 29 | return None 30 | 31 | def load_worker_pool_from_config(self, config: dict) -> None: 32 | for i, worker_config in enumerate(config): 33 | worker = Worker(rank=i + 1, **worker_config) # rank 0 is reserved for host 34 | self._worker_pool.append(worker) 35 | 36 | def assign_model_to_worker(self, rank: int, model_config: dict) -> None: 37 | for worker in self._worker_pool: 38 | if worker.rank == rank: 39 | worker.model_config(model_config) 40 | return 41 | 42 | raise LookupError( 43 | "Worker with rank {} is not found in the worker pool".format(rank) 44 | ) 45 | 46 | def add_worker(self, worker_id: str, worker_config: dict) -> None: 47 | rank = len(self._worker_pool) + 1 48 | worker = Worker(rank=rank, worker_id=worker_id, **worker_config) 49 | self._worker_pool.append(worker) 50 | 51 | def _allocate_rank(self) -> None: 52 | for i, worker in enumerate(self._worker_pool): 53 | worker.rank = i + 1 # rank 0 is reserved for host 54 | 55 | def remove_worker_by_id(self, id_str: str) -> None: 56 | worker = self.get_by_id(id_str) 57 | assert not worker.is_running, "Worker {} is still running".format(id_str) 58 | 59 | self._worker_pool.remove(worker) 60 | self._allocate_rank() 61 | 62 | def reset_rank_by_order(self): 63 | self._worker_pool.sort(key=lambda x: x.order) 64 | self._allocate_rank() 65 | 66 | def serialize(self): 67 | res = [] 68 | for worker in self._worker_pool: 69 | res.append(worker.serialize()) 70 | return res 71 | 72 | @staticmethod 73 | def deserialize(data: list): 74 | worker_manager = WorkerManager() 75 | 76 | for worker_data in data: 77 | worker = Worker.deserialize(worker_data) 78 | worker_manager.worker_pool.append(worker) 79 | return worker_manager 80 | -------------------------------------------------------------------------------- /scaelum/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | -------------------------------------------------------------------------------- /scaelum/logger/logger.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | class Logger: 5 | def __init__(self, filename: str, mode: str = "a"): 6 | self.file = open(file=filename, mode=mode) 7 | 8 | def info(self, message: str) -> None: 9 | new_line = "INFO - {} - {}\n".format(datetime.now(), message) 10 | self._write(new_line) 11 | 12 | def _write(self, message: str) -> None: 13 | self.file.write(message) 14 | self.file.flush() 15 | -------------------------------------------------------------------------------- /scaelum/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import * 2 | from .bert_layers import * 3 | from .layers import * 4 | from .rpc_model import RpcModel 5 | from .rpc_module import BaseModule, LocalModule, RemoteModule 6 | -------------------------------------------------------------------------------- /scaelum/model/bert.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import sys 4 | 5 | 6 | class BertConfig(dict): 7 | """Configuration class to store the configuration of a `BertModel`.""" 8 | 9 | def __init__( 10 | self, 11 | vocab_size_or_config_json_file, 12 | hidden_size=768, 13 | num_hidden_layers=12, 14 | num_attention_heads=12, 15 | intermediate_size=3072, 16 | hidden_act="gelu", 17 | hidden_dropout_prob=0.1, 18 | attention_probs_dropout_prob=0.1, 19 | max_position_embeddings=512, 20 | type_vocab_size=2, 21 | initializer_range=0.02, 22 | output_all_encoded_layers=False, 23 | ): 24 | """Constructs BertConfig. 25 | Args: 26 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 27 | hidden_size: Size of the encoder layers and the pooler layer. 28 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 29 | num_attention_heads: Number of attention heads for each attention layer in 30 | the Transformer encoder. 31 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 32 | layer in the Transformer encoder. 33 | hidden_act: The non-linear activation function (function or string) in the 34 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 35 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 36 | layers in the embeddings, encoder, and pooler. 37 | attention_probs_dropout_prob: The dropout ratio for the attention 38 | probabilities. 39 | max_position_embeddings: The maximum sequence length that this model might 40 | ever be used with. Typically set this to something large just in case 41 | (e.g., 512 or 1024 or 2048). 42 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 43 | `BertModel`. 44 | initializer_range: The sttdev of the truncated_normal_initializer for 45 | initializing all weight matrices. 
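                output_all_encoded_layers: Whether the encoder returns the hidden states of
                    every encoder layer (True) or only those of the final layer (False).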
46 | """ 47 | if isinstance(vocab_size_or_config_json_file, str) or ( 48 | sys.version_info[0] == 2 49 | and isinstance(vocab_size_or_config_json_file, unicode) 50 | ): 51 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 52 | json_config = json.loads(reader.read()) 53 | for key, value in json_config.items(): 54 | self.__dict__[key] = value 55 | elif isinstance(vocab_size_or_config_json_file, int): 56 | self.vocab_size = vocab_size_or_config_json_file 57 | self.hidden_size = hidden_size 58 | self.num_hidden_layers = num_hidden_layers 59 | self.num_attention_heads = num_attention_heads 60 | self.hidden_act = hidden_act 61 | self.intermediate_size = intermediate_size 62 | self.hidden_dropout_prob = hidden_dropout_prob 63 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 64 | self.max_position_embeddings = max_position_embeddings 65 | self.type_vocab_size = type_vocab_size 66 | self.initializer_range = initializer_range 67 | self.output_all_encoded_layers = output_all_encoded_layers 68 | else: 69 | raise ValueError( 70 | "First argument must be either a vocabulary size (int)" 71 | "or the path to a pretrained model config file (str)" 72 | ) 73 | 74 | @classmethod 75 | def from_dict(cls, json_object): 76 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 77 | config = BertConfig(vocab_size_or_config_json_file=-1) 78 | for key, value in json_object.items(): 79 | config.__dict__[key] = value 80 | return config 81 | 82 | @classmethod 83 | def from_json_file(cls, json_file): 84 | """Constructs a `BertConfig` from a json file of parameters.""" 85 | with open(json_file, "r", encoding="utf-8") as reader: 86 | text = reader.read() 87 | return cls.from_dict(json.loads(text)) 88 | 89 | def __repr__(self): 90 | return str(self.to_json_string()) 91 | 92 | def to_dict(self): 93 | """Serializes this instance to a Python dictionary.""" 94 | output = copy.deepcopy(self.__dict__) 95 | return output 96 | 97 | def to_json_string(self): 98 | """Serializes this instance to a JSON string.""" 99 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 100 | 101 | 102 | # class BertPreTrainedModel(nn.Module): 103 | # """ An abstract class to handle weights initialization and 104 | # a simple interface for dowloading and loading pretrained models. 105 | # """ 106 | 107 | # def __init__(self, config, *inputs, **kwargs): 108 | # super(BertPreTrainedModel, self).__init__() 109 | # if not isinstance(config, BertConfig): 110 | # raise ValueError( 111 | # "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 112 | # "To create a model from a Google pretrained model use " 113 | # "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 114 | # self.__class__.__name__, self.__class__.__name__ 115 | # )) 116 | # self.config = config 117 | 118 | # def init_bert_weights(self, module): 119 | # """ Initialize the weights. 
120 | # """ 121 | # if isinstance(module, (nn.Linear, nn.Embedding)): 122 | # # Slightly different from the TF version which uses truncated_normal for initialization 123 | # # cf https://github.com/pytorch/pytorch/pull/5617 124 | # module.weight.data.normal_( 125 | # mean=0.0, std=self.config.initializer_range) 126 | # elif isinstance(module, BertLayerNorm): 127 | # module.bias.data.zero_() 128 | # module.weight.data.fill_(1.0) 129 | # if isinstance(module, nn.Linear) and module.bias is not None: 130 | # module.bias.data.zero_() 131 | 132 | # def checkpoint_activations(self, val): 133 | # def _apply_flag(module): 134 | # if hasattr(module, "_checkpoint_activations"): 135 | # module._checkpoint_activations = val 136 | # self.apply(_apply_flag) 137 | 138 | 139 | # class BertModel(BertPreTrainedModel): 140 | # """BERT model ("Bidirectional Embedding Representations from a Transformer"). 141 | # Params: 142 | # config: a BertConfig class instance with the configuration to build a new model 143 | # Inputs: 144 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 145 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 146 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 147 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 148 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 149 | # a `sentence B` token (see BERT paper for more details). 150 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 151 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 152 | # input sequence length in the current batch. It's the mask that we typically use for attention when 153 | # a batch has varying length sentences. 154 | # Outputs: Tuple of (encoded_layers, pooled_output) 155 | # `encoded_layers`: controled by `output_all_encoded_layers` argument: 156 | # - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 157 | # of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 158 | # encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 159 | # - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 160 | # to the last attention block of shape [batch_size, sequence_length, hidden_size], 161 | # `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 162 | # classifier pretrained on top of the hidden state associated to the first character of the 163 | # input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 
164 | # Example usage: 165 | # ```python 166 | # # Already been converted into WordPiece token ids 167 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 168 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 169 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 170 | # config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 171 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 172 | # model = modeling.BertModel(config=config) 173 | # all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 174 | # ``` 175 | # """ 176 | 177 | # def __init__(self, config): 178 | # super(BertModel, self).__init__(config) 179 | # self.embeddings = BertEmbeddings(config) 180 | # self.encoder = BertEncoder(config) 181 | # self.pooler = BertPooler(config) 182 | # self.apply(self.init_bert_weights) 183 | # self.output_all_encoded_layers = config.output_all_encoded_layers 184 | 185 | # def forward(self, input_ids, token_type_ids, attention_mask): 186 | # # We create a 3D attention mask from a 2D tensor mask. 187 | # # Sizes are [batch_size, 1, 1, to_seq_length] 188 | # # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 189 | # # this attention mask is more simple than the triangular masking of causal attention 190 | # # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 191 | # extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 192 | 193 | # # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 194 | # # masked positions, this operation will create a tensor which is 0.0 for 195 | # # positions we want to attend and -10000.0 for masked positions. 196 | # # Since we are adding it to the raw scores before the softmax, this is 197 | # # effectively the same as removing these entirely. 198 | # extended_attention_mask = extended_attention_mask.to( 199 | # dtype=self.embeddings.word_embeddings.weight.dtype) # fp16 compatibility 200 | # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 201 | 202 | # embedding_output = self.embeddings(input_ids, token_type_ids) 203 | # encoded_layers = self.encoder( 204 | # embedding_output, extended_attention_mask) 205 | # sequence_output = encoded_layers[-1] 206 | # pooled_output = self.pooler(sequence_output) 207 | # if not self.output_all_encoded_layers: 208 | # encoded_layers = encoded_layers[-1:] 209 | # return encoded_layers, pooled_output 210 | 211 | 212 | # class BertForSequenceClassification(BertPreTrainedModel): 213 | # """BERT model for classification. 214 | # This module is composed of the BERT model with a linear layer on top of 215 | # the pooled output. 216 | # Params: 217 | # `config`: a BertConfig class instance with the configuration to build a new model. 218 | # `num_labels`: the number of classes for the classifier. Default = 2. 219 | # Inputs: 220 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 221 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 222 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 223 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 224 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 225 | # a `sentence B` token (see BERT paper for more details). 
226 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 227 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 228 | # input sequence length in the current batch. It's the mask that we typically use for attention when 229 | # a batch has varying length sentences. 230 | # `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 231 | # with indices selected in [0, ..., num_labels]. 232 | # Outputs: 233 | # if `labels` is not `None`: 234 | # Outputs the CrossEntropy classification loss of the output with the labels. 235 | # if `labels` is `None`: 236 | # Outputs the classification logits of shape [batch_size, num_labels]. 237 | # Example usage: 238 | # ```python 239 | # # Already been converted into WordPiece token ids 240 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 241 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 242 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 243 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 244 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 245 | # num_labels = 2 246 | # model = BertForSequenceClassification(config, num_labels) 247 | # logits = model(input_ids, token_type_ids, input_mask) 248 | # ``` 249 | # """ 250 | 251 | # def __init__(self, config, num_labels): 252 | # super(BertForSequenceClassification, self).__init__(config) 253 | # self.num_labels = num_labels 254 | # self.bert = BertModel(config) 255 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 256 | # self.classifier = nn.Linear(config.hidden_size, num_labels) 257 | # self.apply(self.init_bert_weights) 258 | 259 | # def forward(self, input_ids, token_type_ids=None, attention_mask=None): 260 | # _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 261 | # pooled_output = self.dropout(pooled_output) 262 | # return self.classifier(pooled_output) 263 | -------------------------------------------------------------------------------- /scaelum/model/bert_layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """PyTorch BERT model.""" 4 | 5 | from __future__ import absolute_import, division, print_function, unicode_literals 6 | 7 | import math 8 | import sys 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.nn.init as init 13 | from scaelum.registry import LAYER 14 | from torch import nn 15 | from torch.nn import Module 16 | from torch.nn.parameter import Parameter 17 | 18 | from .bert import BertConfig 19 | 20 | 21 | def gelu(x): 22 | return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) 23 | 24 | 25 | # used only for triton inference 26 | 27 | 28 | def bias_gelu(bias, y): 29 | x = bias + y 30 | return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) 31 | 32 | 33 | # used specifically for training since torch.nn.functional.gelu breaks ONNX export 34 | 35 | 36 | def bias_gelu_training(bias, y): 37 | x = bias + y 38 | return torch.nn.functional.gelu(x) # Breaks ONNX export 39 | 40 | 41 | def bias_tanh(bias, y): 42 | x = bias + y 43 | return torch.tanh(x) 44 | 45 | 46 | def swish(x): 47 | return x * torch.sigmoid(x) 48 | 49 | 50 | # torch.nn.functional.gelu(x) # Breaks ONNX export 51 | ACT2FN = { 52 | "gelu": gelu, 53 | "bias_gelu": bias_gelu, 54 | "bias_tanh": bias_tanh, 55 | "relu": torch.nn.functional.relu, 56 | "swish": swish, 57 | } 58 | 59 | 60 | class 
LinearActivation(Module): 61 | r"""Fused Linear and activation Module.""" 62 | __constants__ = ["bias"] 63 | 64 | def __init__(self, in_features, out_features, act="gelu", bias=True): 65 | super(LinearActivation, self).__init__() 66 | self.in_features = in_features 67 | self.out_features = out_features 68 | self.act_fn = nn.Identity() # 69 | self.biased_act_fn = None # 70 | # 71 | self.bias = None 72 | # For TorchScript 73 | if isinstance(act, str) or ( 74 | sys.version_info[0] == 2 and isinstance(act, unicode) 75 | ): 76 | if bias and not "bias" in act: # compatibility 77 | act = "bias_" + act # 78 | # 79 | self.biased_act_fn = ACT2FN[act] 80 | 81 | else: 82 | self.act_fn = ACT2FN[act] 83 | else: 84 | self.act_fn = act 85 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 86 | if bias: 87 | self.bias = Parameter(torch.Tensor(out_features)) 88 | else: 89 | self.register_parameter("bias", None) 90 | self.reset_parameters() 91 | 92 | def reset_parameters(self): 93 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 94 | if self.bias is not None: 95 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 96 | bound = 1 / math.sqrt(fan_in) 97 | init.uniform_(self.bias, -bound, bound) 98 | 99 | def forward(self, input): 100 | if not self.bias is None: 101 | return self.biased_act_fn(self.bias, F.linear(input, self.weight, None)) 102 | else: 103 | return self.act_fn(F.linear(input, self.weight, self.bias)) 104 | 105 | def extra_repr(self): 106 | return "in_features={}, out_features={}, bias={}".format( 107 | self.in_features, self.out_features, self.bias is not None 108 | ) 109 | 110 | 111 | class BertNonFusedLayerNorm(nn.Module): 112 | def __init__(self, hidden_size, eps=1e-12): 113 | """Construct a layernorm module in the TF style (epsilon inside the square root).""" 114 | super(BertNonFusedLayerNorm, self).__init__() 115 | self.weight = nn.Parameter(torch.ones(hidden_size)) 116 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 117 | self.variance_epsilon = eps 118 | 119 | def forward(self, x): 120 | u = x.mean(-1, keepdim=True) 121 | s = x - u 122 | s = s * s 123 | s = s.mean(-1, keepdim=True) 124 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 125 | return self.weight * x + self.bias 126 | 127 | 128 | try: 129 | import apex 130 | 131 | # apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm') 132 | import apex.normalization 133 | from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction 134 | 135 | # apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward') 136 | # BertLayerNorm = apex.normalization.FusedLayerNorm 137 | APEX_IS_AVAILABLE = True 138 | except ImportError: 139 | # BertLayerNorm = BertNonFusedLayerNorm 140 | APEX_IS_AVAILABLE = False 141 | 142 | 143 | class BertLayerNorm(Module): 144 | def __init__(self, hidden_size, eps=1e-12): 145 | super(BertLayerNorm, self).__init__() 146 | self.shape = torch.Size((hidden_size,)) 147 | self.eps = eps 148 | self.weight = nn.Parameter(torch.ones(hidden_size)) 149 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 150 | self.apex_enabled = APEX_IS_AVAILABLE 151 | 152 | @torch.jit.unused 153 | def fused_layer_norm(self, x): 154 | return FusedLayerNormAffineFunction.apply( 155 | x, self.weight, self.bias, self.shape, self.eps 156 | ) 157 | 158 | def forward(self, x): 159 | if self.apex_enabled and not torch.jit.is_scripting(): 160 | x = self.fused_layer_norm(x) 161 | else: 162 | u = x.mean(-1, keepdim=True) 163 | s = x - u 164 | s = s * s 165 
| s = s.mean(-1, keepdim=True) 166 | x = (x - u) / torch.sqrt(s + self.eps) 167 | x = self.weight * x + self.bias 168 | return x 169 | 170 | 171 | @LAYER.register_module 172 | class BertEmbeddings(nn.Module): 173 | """Construct the embeddings from word, position and token_type embeddings.""" 174 | 175 | def __init__(self, config): 176 | super().__init__() 177 | config = BertConfig.from_dict(config) 178 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 179 | self.position_embeddings = nn.Embedding( 180 | config.max_position_embeddings, config.hidden_size 181 | ) 182 | self.token_type_embeddings = nn.Embedding( 183 | config.type_vocab_size, config.hidden_size 184 | ) 185 | 186 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 187 | # any TensorFlow checkpoint file 188 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 189 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 190 | 191 | def forward(self, input_ids, token_type_ids, attention_mask): 192 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 193 | extended_attention_mask = extended_attention_mask.to( 194 | dtype=self.word_embeddings.weight.dtype 195 | ) 196 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 197 | 198 | seq_length = input_ids.size(1) 199 | position_ids = torch.arange( 200 | seq_length, dtype=torch.long, device=input_ids.device 201 | ) 202 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 203 | 204 | words_embeddings = self.word_embeddings(input_ids) 205 | position_embeddings = self.position_embeddings(position_ids) 206 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 207 | 208 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 209 | embeddings = self.LayerNorm(embeddings) 210 | embeddings = self.dropout(embeddings) 211 | 212 | return embeddings, extended_attention_mask 213 | 214 | 215 | class BertSelfAttention(nn.Module): 216 | def __init__(self, config): 217 | super(BertSelfAttention, self).__init__() 218 | if config.hidden_size % config.num_attention_heads != 0: 219 | raise ValueError( 220 | "The hidden size (%d) is not a multiple of the number of attention " 221 | "heads (%d)" % (config.hidden_size, config.num_attention_heads) 222 | ) 223 | self.num_attention_heads = config.num_attention_heads 224 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 225 | self.all_head_size = self.num_attention_heads * self.attention_head_size 226 | 227 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 228 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 229 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 230 | 231 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 232 | 233 | def transpose_for_scores(self, x): 234 | new_x_shape = x.size()[:-1] + ( 235 | self.num_attention_heads, 236 | self.attention_head_size, 237 | ) 238 | x = torch.reshape(x, new_x_shape) 239 | return x.permute(0, 2, 1, 3) 240 | 241 | def transpose_key_for_scores(self, x): 242 | new_x_shape = x.size()[:-1] + ( 243 | self.num_attention_heads, 244 | self.attention_head_size, 245 | ) 246 | x = torch.reshape(x, new_x_shape) 247 | return x.permute(0, 2, 3, 1) 248 | 249 | def forward(self, hidden_states, attention_mask): 250 | mixed_query_layer = self.query(hidden_states) 251 | mixed_key_layer = self.key(hidden_states) 252 | mixed_value_layer = self.value(hidden_states) 253 | 254 | 
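        # The projections above are reshaped into multi-head form below: queries and
        # values become [batch, num_heads, seq_len, head_size], while keys are
        # transposed to [batch, num_heads, head_size, seq_len] so that the matmul
        # that follows yields raw attention scores of shape
        # [batch, num_heads, seq_len, seq_len], scaled by sqrt(head_size).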
query_layer = self.transpose_for_scores(mixed_query_layer) 255 | key_layer = self.transpose_key_for_scores(mixed_key_layer) 256 | value_layer = self.transpose_for_scores(mixed_value_layer) 257 | 258 | # Take the dot product between "query" and "key" to get the raw attention scores. 259 | attention_scores = torch.matmul(query_layer, key_layer) 260 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 261 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 262 | attention_scores = attention_scores + attention_mask 263 | 264 | # Normalize the attention scores to probabilities. 265 | attention_probs = F.softmax(attention_scores, dim=-1) 266 | 267 | # This is actually dropping out entire tokens to attend to, which might 268 | # seem a bit unusual, but is taken from the original Transformer paper. 269 | attention_probs = self.dropout(attention_probs) 270 | 271 | context_layer = torch.matmul(attention_probs, value_layer) 272 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 273 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 274 | context_layer = torch.reshape(context_layer, new_context_layer_shape) 275 | return context_layer 276 | 277 | 278 | class BertSelfOutput(nn.Module): 279 | def __init__(self, config): 280 | super(BertSelfOutput, self).__init__() 281 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 282 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 283 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 284 | 285 | def forward(self, hidden_states, input_tensor): 286 | hidden_states = self.dense(hidden_states) 287 | hidden_states = self.dropout(hidden_states) 288 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 289 | return hidden_states 290 | 291 | 292 | class BertAttention(nn.Module): 293 | def __init__(self, config): 294 | super(BertAttention, self).__init__() 295 | self.self = BertSelfAttention(config) 296 | self.output = BertSelfOutput(config) 297 | 298 | def forward(self, input_tensor, attention_mask): 299 | self_output = self.self(input_tensor, attention_mask) 300 | attention_output = self.output(self_output, input_tensor) 301 | return attention_output 302 | 303 | 304 | class BertIntermediate(nn.Module): 305 | def __init__(self, config): 306 | super(BertIntermediate, self).__init__() 307 | self.dense_act = LinearActivation( 308 | config.hidden_size, config.intermediate_size, act=config.hidden_act 309 | ) 310 | 311 | def forward(self, hidden_states): 312 | hidden_states = self.dense_act(hidden_states) 313 | return hidden_states 314 | 315 | 316 | class BertOutput(nn.Module): 317 | def __init__(self, config): 318 | super(BertOutput, self).__init__() 319 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 320 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 321 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 322 | 323 | def forward(self, hidden_states, input_tensor): 324 | hidden_states = self.dense(hidden_states) 325 | hidden_states = self.dropout(hidden_states) 326 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 327 | return hidden_states 328 | 329 | 330 | @LAYER.register_module 331 | class BertLayer_Head(nn.Module): 332 | def __init__(self, config): 333 | super().__init__() 334 | config = BertConfig.from_dict(config) 335 | self.attention = BertAttention(config) 336 | 337 | def forward(self, hidden_states, attention_mask): 338 | attention_output = 
self.attention(hidden_states, attention_mask) 339 | return attention_output, attention_mask 340 | 341 | 342 | @LAYER.register_module 343 | class BertLayer_Body(nn.Module): 344 | def __init__(self, config): 345 | super().__init__() 346 | config = BertConfig.from_dict(config) 347 | self.intermediate = BertIntermediate(config) 348 | 349 | def forward(self, attention_output, attention_mask): 350 | intermediate_output = self.intermediate(attention_output) 351 | return intermediate_output, attention_output, attention_mask 352 | 353 | 354 | @LAYER.register_module 355 | class BertLayer_Tail(nn.Module): 356 | def __init__(self, config): 357 | super().__init__() 358 | config = BertConfig.from_dict(config) 359 | self.output = BertOutput(config) 360 | 361 | def forward(self, intermediate_output, attention_output, attention_mask): 362 | layer_output = self.output(intermediate_output, attention_output) 363 | return layer_output, attention_mask 364 | 365 | 366 | @LAYER.register_module 367 | class BertTailForClassification(nn.Module): 368 | def __init__(self, hidden_dropout_prob, hidden_size, num_classes): 369 | super().__init__() 370 | self.num_classes = num_classes 371 | self.dropout = nn.Dropout(hidden_dropout_prob) 372 | self.classifier = nn.Linear(hidden_size, num_classes) 373 | 374 | def forward(self, logits): 375 | logits = self.dropout(logits) 376 | logits = self.classifier(logits) 377 | logits = logits.view(-1, self.num_classes) 378 | return logits 379 | 380 | 381 | @LAYER.register_module 382 | class BertPooler(nn.Module): 383 | def __init__(self, config): 384 | super().__init__() 385 | config = BertConfig.from_dict(config) 386 | self.dense_act = LinearActivation( 387 | config.hidden_size, config.hidden_size, act="tanh" 388 | ) 389 | 390 | def forward(self, hidden_states, attention_mask): 391 | # We "pool" the model by simply taking the hidden state corresponding 392 | # to the first token. 393 | first_token_tensor = hidden_states[:, 0] 394 | pooled_output = self.dense_act(first_token_tensor) 395 | return pooled_output 396 | 397 | 398 | # class BertPredictionHeadTransform(nn.Module): 399 | # def __init__(self, config): 400 | # super(BertPredictionHeadTransform, self).__init__() 401 | # self.dense_act = LinearActivation( 402 | # config.hidden_size, config.hidden_size, act=config.hidden_act) 403 | # self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 404 | 405 | # def forward(self, hidden_states): 406 | # hidden_states = self.dense_act(hidden_states) 407 | # hidden_states = self.LayerNorm(hidden_states) 408 | # return hidden_states 409 | 410 | 411 | # class BertLMPredictionHead(nn.Module): 412 | # def __init__(self, config, bert_model_embedding_weights): 413 | # super(BertLMPredictionHead, self).__init__() 414 | # self.transform = BertPredictionHeadTransform(config) 415 | 416 | # # The output weights are the same as the input embeddings, but there is 417 | # # an output-only bias for each token. 
418 | # self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 419 | # bert_model_embedding_weights.size(0), 420 | # bias=False) 421 | # self.decoder.weight = bert_model_embedding_weights 422 | # self.bias = nn.Parameter(torch.zeros( 423 | # bert_model_embedding_weights.size(0))) 424 | 425 | # def forward(self, hidden_states): 426 | # hidden_states = self.transform(hidden_states) 427 | # hidden_states = self.decoder(hidden_states) + self.bias 428 | # return hidden_states 429 | 430 | 431 | # class BertOnlyMLMHead(nn.Module): 432 | # def __init__(self, config, bert_model_embedding_weights): 433 | # super(BertOnlyMLMHead, self).__init__() 434 | # self.predictions = BertLMPredictionHead( 435 | # config, bert_model_embedding_weights) 436 | 437 | # def forward(self, sequence_output): 438 | # prediction_scores = self.predictions(sequence_output) 439 | # return prediction_scores 440 | 441 | 442 | # class BertOnlyNSPHead(nn.Module): 443 | # def __init__(self, config): 444 | # super(BertOnlyNSPHead, self).__init__() 445 | # self.seq_relationship = nn.Linear(config.hidden_size, 2) 446 | 447 | # def forward(self, pooled_output): 448 | # seq_relationship_score = self.seq_relationship(pooled_output) 449 | # return seq_relationship_score 450 | 451 | 452 | # class BertPreTrainingHeads(nn.Module): 453 | # def __init__(self, config, bert_model_embedding_weights): 454 | # super(BertPreTrainingHeads, self).__init__() 455 | # self.predictions = BertLMPredictionHead( 456 | # config, bert_model_embedding_weights) 457 | # self.seq_relationship = nn.Linear(config.hidden_size, 2) 458 | 459 | # def forward(self, sequence_output, pooled_output): 460 | # prediction_scores = self.predictions(sequence_output) 461 | # seq_relationship_score = self.seq_relationship(pooled_output) 462 | # return prediction_scores, seq_relationship_score 463 | 464 | 465 | # class BertForPreTraining(BertPreTrainedModel): 466 | # """BERT model with pre-training heads. 467 | # This module comprises the BERT model followed by the two pre-training heads: 468 | # - the masked language modeling head, and 469 | # - the next sentence classification head. 470 | # Params: 471 | # config: a BertConfig class instance with the configuration to build a new model. 472 | # Inputs: 473 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 474 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 475 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 476 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 477 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 478 | # a `sentence B` token (see BERT paper for more details). 479 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 480 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 481 | # input sequence length in the current batch. It's the mask that we typically use for attention when 482 | # a batch has varying length sentences. 483 | # `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 484 | # with indices selected in [-1, 0, ..., vocab_size]. 
All labels set to -1 are ignored (masked), the loss 485 | # is only computed for the labels set in [0, ..., vocab_size] 486 | # `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 487 | # with indices selected in [0, 1]. 488 | # 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 489 | # Outputs: 490 | # if `masked_lm_labels` and `next_sentence_label` are not `None`: 491 | # Outputs the total_loss which is the sum of the masked language modeling loss and the next 492 | # sentence classification loss. 493 | # if `masked_lm_labels` or `next_sentence_label` is `None`: 494 | # Outputs a tuple comprising 495 | # - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 496 | # - the next sentence classification logits of shape [batch_size, 2]. 497 | # Example usage: 498 | # ```python 499 | # # Already been converted into WordPiece token ids 500 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 501 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 502 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 503 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 504 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 505 | # model = BertForPreTraining(config) 506 | # masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 507 | # ``` 508 | # """ 509 | 510 | # def __init__(self, config): 511 | # super(BertForPreTraining, self).__init__(config) 512 | # self.bert = BertModel(config) 513 | # self.cls = BertPreTrainingHeads( 514 | # config, self.bert.embeddings.word_embeddings.weight) 515 | # self.apply(self.init_bert_weights) 516 | 517 | # def forward(self, input_ids, token_type_ids, attention_mask): 518 | # encoded_layers, pooled_output = self.bert( 519 | # input_ids, token_type_ids, attention_mask) 520 | # sequence_output = encoded_layers[-1] 521 | # prediction_scores, seq_relationship_score = self.cls( 522 | # sequence_output, pooled_output) 523 | 524 | # return prediction_scores, seq_relationship_score 525 | 526 | 527 | # class BertForMaskedLM(BertPreTrainedModel): 528 | # """BERT model with the masked language modeling head. 529 | # This module comprises the BERT model followed by the masked language modeling head. 530 | # Params: 531 | # config: a BertConfig class instance with the configuration to build a new model. 532 | # Inputs: 533 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 534 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 535 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 536 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 537 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 538 | # a `sentence B` token (see BERT paper for more details). 539 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 540 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 541 | # input sequence length in the current batch. It's the mask that we typically use for attention when 542 | # a batch has varying length sentences. 
543 | # `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 544 | # with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 545 | # is only computed for the labels set in [0, ..., vocab_size] 546 | # Outputs: 547 | # if `masked_lm_labels` is not `None`: 548 | # Outputs the masked language modeling loss. 549 | # if `masked_lm_labels` is `None`: 550 | # Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 551 | # Example usage: 552 | # ```python 553 | # # Already been converted into WordPiece token ids 554 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 555 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 556 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 557 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 558 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 559 | # model = BertForMaskedLM(config) 560 | # masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 561 | # ``` 562 | # """ 563 | 564 | # def __init__(self, config): 565 | # super(BertForMaskedLM, self).__init__(config) 566 | # self.bert = BertModel(config) 567 | # self.cls = BertOnlyMLMHead( 568 | # config, self.bert.embeddings.word_embeddings.weight) 569 | # self.apply(self.init_bert_weights) 570 | 571 | # def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): 572 | # encoded_layers, _ = self.bert( 573 | # input_ids, token_type_ids, attention_mask) 574 | # sequence_output = encoded_layers[-1] 575 | # prediction_scores = self.cls(sequence_output) 576 | 577 | # if masked_lm_labels is not None: 578 | # loss_fct = CrossEntropyLoss(ignore_index=-1) 579 | # masked_lm_loss = loss_fct( 580 | # prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 581 | # return masked_lm_loss 582 | # else: 583 | # return prediction_scores 584 | 585 | 586 | # class BertForNextSentencePrediction(BertPreTrainedModel): 587 | # """BERT model with next sentence prediction head. 588 | # This module comprises the BERT model followed by the next sentence classification head. 589 | # Params: 590 | # config: a BertConfig class instance with the configuration to build a new model. 591 | # Inputs: 592 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 593 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 594 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 595 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 596 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 597 | # a `sentence B` token (see BERT paper for more details). 598 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 599 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 600 | # input sequence length in the current batch. It's the mask that we typically use for attention when 601 | # a batch has varying length sentences. 602 | # `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 603 | # with indices selected in [0, 1]. 604 | # 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 
605 | # Outputs: 606 | # if `next_sentence_label` is not `None`: 607 | # Outputs the total_loss which is the sum of the masked language modeling loss and the next 608 | # sentence classification loss. 609 | # if `next_sentence_label` is `None`: 610 | # Outputs the next sentence classification logits of shape [batch_size, 2]. 611 | # Example usage: 612 | # ```python 613 | # # Already been converted into WordPiece token ids 614 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 615 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 616 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 617 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 618 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 619 | # model = BertForNextSentencePrediction(config) 620 | # seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 621 | # ``` 622 | # """ 623 | 624 | # def __init__(self, config): 625 | # super(BertForNextSentencePrediction, self).__init__(config) 626 | # self.bert = BertModel(config) 627 | # self.cls = BertOnlyNSPHead(config) 628 | # self.apply(self.init_bert_weights) 629 | 630 | # def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): 631 | # _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 632 | # seq_relationship_score = self.cls(pooled_output) 633 | 634 | # if next_sentence_label is not None: 635 | # loss_fct = CrossEntropyLoss(ignore_index=-1) 636 | # next_sentence_loss = loss_fct( 637 | # seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 638 | # return next_sentence_loss 639 | # else: 640 | # return seq_relationship_score 641 | 642 | 643 | # class BertForSequenceClassification(BertPreTrainedModel): 644 | # """BERT model for classification. 645 | # This module is composed of the BERT model with a linear layer on top of 646 | # the pooled output. 647 | # Params: 648 | # `config`: a BertConfig class instance with the configuration to build a new model. 649 | # `num_labels`: the number of classes for the classifier. Default = 2. 650 | # Inputs: 651 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 652 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 653 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 654 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 655 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 656 | # a `sentence B` token (see BERT paper for more details). 657 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 658 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 659 | # input sequence length in the current batch. It's the mask that we typically use for attention when 660 | # a batch has varying length sentences. 661 | # `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 662 | # with indices selected in [0, ..., num_labels]. 663 | # Outputs: 664 | # if `labels` is not `None`: 665 | # Outputs the CrossEntropy classification loss of the output with the labels. 666 | # if `labels` is `None`: 667 | # Outputs the classification logits of shape [batch_size, num_labels]. 
668 | # Example usage: 669 | # ```python 670 | # # Already been converted into WordPiece token ids 671 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 672 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 673 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 674 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 675 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 676 | # num_labels = 2 677 | # model = BertForSequenceClassification(config, num_labels) 678 | # logits = model(input_ids, token_type_ids, input_mask) 679 | # ``` 680 | # """ 681 | 682 | # def __init__(self, config, num_labels): 683 | # super(BertForSequenceClassification, self).__init__(config) 684 | # self.num_labels = num_labels 685 | # self.bert = BertModel(config) 686 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 687 | # self.classifier = nn.Linear(config.hidden_size, num_labels) 688 | # self.apply(self.init_bert_weights) 689 | 690 | # def forward(self, input_ids, token_type_ids=None, attention_mask=None): 691 | # _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 692 | # pooled_output = self.dropout(pooled_output) 693 | # return self.classifier(pooled_output) 694 | 695 | 696 | # class BertForMultipleChoice(BertPreTrainedModel): 697 | # """BERT model for multiple choice tasks. 698 | # This module is composed of the BERT model with a linear layer on top of 699 | # the pooled output. 700 | # Params: 701 | # `config`: a BertConfig class instance with the configuration to build a new model. 702 | # `num_choices`: the number of classes for the classifier. Default = 2. 703 | # Inputs: 704 | # `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 705 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 706 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 707 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 708 | # with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 709 | # and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 710 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 711 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 712 | # input sequence length in the current batch. It's the mask that we typically use for attention when 713 | # a batch has varying length sentences. 714 | # `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 715 | # with indices selected in [0, ..., num_choices]. 716 | # Outputs: 717 | # if `labels` is not `None`: 718 | # Outputs the CrossEntropy classification loss of the output with the labels. 719 | # if `labels` is `None`: 720 | # Outputs the classification logits of shape [batch_size, num_labels]. 
721 | # Example usage: 722 | # ```python 723 | # # Already been converted into WordPiece token ids 724 | # input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 725 | # input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 726 | # token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 727 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 728 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 729 | # num_choices = 2 730 | # model = BertForMultipleChoice(config, num_choices) 731 | # logits = model(input_ids, token_type_ids, input_mask) 732 | # ``` 733 | # """ 734 | 735 | # def __init__(self, config, num_choices): 736 | # super(BertForMultipleChoice, self).__init__(config) 737 | # self.num_choices = num_choices 738 | # self.bert = BertModel(config) 739 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 740 | # self.classifier = nn.Linear(config.hidden_size, 1) 741 | # self.apply(self.init_bert_weights) 742 | 743 | # def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 744 | # flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 745 | # flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 746 | # flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) 747 | # _, pooled_output = self.bert( 748 | # flat_input_ids, flat_token_type_ids, flat_attention_mask) 749 | # pooled_output = self.dropout(pooled_output) 750 | # logits = self.classifier(pooled_output) 751 | # reshaped_logits = logits.view(-1, self.num_choices) 752 | 753 | # if labels is not None: 754 | # loss_fct = CrossEntropyLoss() 755 | # loss = loss_fct(reshaped_logits, labels) 756 | # return loss 757 | # else: 758 | # return reshaped_logits 759 | 760 | 761 | # class BertForTokenClassification(BertPreTrainedModel): 762 | # """BERT model for token-level classification. 763 | # This module is composed of the BERT model with a linear layer on top of 764 | # the full hidden state of the last layer. 765 | # Params: 766 | # `config`: a BertConfig class instance with the configuration to build a new model. 767 | # `num_labels`: the number of classes for the classifier. Default = 2. 768 | # Inputs: 769 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 770 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 771 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 772 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 773 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 774 | # a `sentence B` token (see BERT paper for more details). 775 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 776 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 777 | # input sequence length in the current batch. It's the mask that we typically use for attention when 778 | # a batch has varying length sentences. 779 | # `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 780 | # with indices selected in [0, ..., num_labels]. 781 | # Outputs: 782 | # if `labels` is not `None`: 783 | # Outputs the CrossEntropy classification loss of the output with the labels. 
784 | # if `labels` is `None`: 785 | # Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 786 | # Example usage: 787 | # ```python 788 | # # Already been converted into WordPiece token ids 789 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 790 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 791 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 792 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 793 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 794 | # num_labels = 2 795 | # model = BertForTokenClassification(config, num_labels) 796 | # logits = model(input_ids, token_type_ids, input_mask) 797 | # ``` 798 | # """ 799 | 800 | # def __init__(self, config, num_labels): 801 | # super(BertForTokenClassification, self).__init__(config) 802 | # self.num_labels = num_labels 803 | # self.bert = BertModel(config) 804 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 805 | # self.classifier = nn.Linear(config.hidden_size, num_labels) 806 | # self.apply(self.init_bert_weights) 807 | 808 | # def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 809 | # encoded_layers, _ = self.bert( 810 | # input_ids, token_type_ids, attention_mask) 811 | # sequence_output = encoded_layers[-1] 812 | # sequence_output = self.dropout(sequence_output) 813 | # logits = self.classifier(sequence_output) 814 | 815 | # if labels is not None: 816 | # loss_fct = CrossEntropyLoss() 817 | # # Only keep active parts of the loss 818 | # if attention_mask is not None: 819 | # active_loss = attention_mask.view(-1) == 1 820 | # active_logits = logits.view(-1, self.num_labels)[active_loss] 821 | # active_labels = labels.view(-1)[active_loss] 822 | # loss = loss_fct(active_logits, active_labels) 823 | # else: 824 | # loss = loss_fct( 825 | # logits.view(-1, self.num_labels), labels.view(-1)) 826 | # return loss 827 | # else: 828 | # return logits 829 | 830 | 831 | # class BertForQuestionAnswering(BertPreTrainedModel): 832 | # """BERT model for Question Answering (span extraction). 833 | # This module is composed of the BERT model with a linear layer on top of 834 | # the sequence output that computes start_logits and end_logits 835 | # Params: 836 | # `config`: a BertConfig class instance with the configuration to build a new model. 837 | # Inputs: 838 | # `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 839 | # with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 840 | # `extract_features.py`, `run_classifier.py` and `run_squad.py`) 841 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 842 | # types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 843 | # a `sentence B` token (see BERT paper for more details). 844 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 845 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 846 | # input sequence length in the current batch. It's the mask that we typically use for attention when 847 | # a batch has varying length sentences. 848 | # Outputs: 849 | # Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 850 | # position tokens of shape [batch_size, sequence_length]. 
851 | # Example usage: 852 | # ```python 853 | # # Already been converted into WordPiece token ids 854 | # input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 855 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 856 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 857 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 858 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 859 | # model = BertForQuestionAnswering(config) 860 | # start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 861 | # ``` 862 | # """ 863 | 864 | # def __init__(self, config): 865 | # super(BertForQuestionAnswering, self).__init__(config) 866 | # self.bert = BertModel(config) 867 | # # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 868 | # # self.dropout = nn.Dropout(config.hidden_dropout_prob) 869 | # self.qa_outputs = nn.Linear(config.hidden_size, 2) 870 | # self.apply(self.init_bert_weights) 871 | 872 | # def forward(self, input_ids, token_type_ids, attention_mask): 873 | # encoded_layers, _ = self.bert( 874 | # input_ids, token_type_ids, attention_mask) 875 | # sequence_output = encoded_layers[-1] 876 | # logits = self.qa_outputs(sequence_output) 877 | # start_logits, end_logits = logits.split(1, dim=-1) 878 | # start_logits = start_logits.squeeze(-1) 879 | # end_logits = end_logits.squeeze(-1) 880 | # return start_logits, end_logits 881 | -------------------------------------------------------------------------------- /scaelum/model/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from scaelum.registry import LAYER 4 | 5 | 6 | @LAYER.register_module 7 | class BasicBlock(nn.Module): 8 | """Basic Block for resnet 18 and resnet 34""" 9 | 10 | # BasicBlock and BottleNeck block 11 | # have different output size 12 | # we use class attribute expansion 13 | # to distinct 14 | expansion = 1 15 | 16 | def __init__(self, in_channels, out_channels, stride=1): 17 | super().__init__() 18 | 19 | # residual function 20 | self.residual_function = nn.Sequential( 21 | nn.Conv2d( 22 | in_channels, 23 | out_channels, 24 | kernel_size=3, 25 | stride=stride, 26 | padding=1, 27 | bias=False, 28 | ), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d( 32 | out_channels, 33 | out_channels * BasicBlock.expansion, 34 | kernel_size=3, 35 | padding=1, 36 | bias=False, 37 | ), 38 | nn.BatchNorm2d(out_channels * BasicBlock.expansion), 39 | ) 40 | 41 | # shortcut 42 | self.shortcut = nn.Sequential() 43 | 44 | # the shortcut output dimension is not the same with residual function 45 | # use 1*1 convolution to match the dimension 46 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 47 | self.shortcut = nn.Sequential( 48 | nn.Conv2d( 49 | in_channels, 50 | out_channels * BasicBlock.expansion, 51 | kernel_size=1, 52 | stride=stride, 53 | bias=False, 54 | ), 55 | nn.BatchNorm2d(out_channels * BasicBlock.expansion), 56 | ) 57 | 58 | def forward(self, x): 59 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 60 | 61 | 62 | @LAYER.register_module 63 | class BottleNeck(nn.Module): 64 | """Residual block for resnet over 50 layers""" 65 | 66 | expansion = 4 67 | 68 | def __init__(self, in_channels, out_channels, stride=1): 69 | super().__init__() 70 | expansion = LAYER.get_module(self.__class__.__name__).expansion 71 | 72 | 
self.residual_function = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 74 | nn.BatchNorm2d(out_channels), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d( 77 | out_channels, 78 | out_channels, 79 | stride=stride, 80 | kernel_size=3, 81 | padding=1, 82 | bias=False, 83 | ), 84 | nn.BatchNorm2d(out_channels), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d( 87 | out_channels, out_channels * expansion, kernel_size=1, bias=False 88 | ), 89 | nn.BatchNorm2d(out_channels * expansion), 90 | ) 91 | 92 | self.shortcut = nn.Sequential() 93 | 94 | if stride != 1 or in_channels != out_channels * expansion: 95 | self.shortcut = nn.Sequential( 96 | nn.Conv2d( 97 | in_channels, 98 | out_channels * expansion, 99 | stride=stride, 100 | kernel_size=1, 101 | bias=False, 102 | ), 103 | nn.BatchNorm2d(out_channels * expansion), 104 | ) 105 | 106 | def forward(self, x): 107 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 108 | 109 | 110 | @LAYER.register_module 111 | class ResLayer(nn.Module): 112 | def __init__(self, block_name, in_channels, out_channels, num_blocks, stride): 113 | super().__init__() 114 | self.block = LAYER.get_module(block_name) 115 | self.in_channels = in_channels 116 | self.out_channels = out_channels 117 | self.num_blocks = num_blocks 118 | self.stride = stride 119 | self.layers = self._make_layer() 120 | 121 | def _make_layer(self): 122 | """make resnet layers(by layer i didnt mean this 'layer' was the 123 | same as a neuron netowork layer, ex. conv layer), one layer may 124 | contain more than one residual block 125 | 126 | Args: 127 | block: block type, basic block or bottle neck block 128 | out_channels: output depth channel number of this layer 129 | num_blocks: how many blocks per layer 130 | stride: the stride of the first block of this layer 131 | 132 | Return: 133 | return a resnet layer 134 | """ 135 | 136 | # we have num_block blocks per layer, the first block 137 | # could be 1 or 2, other blocks would always be 1 138 | strides = [self.stride] + [1] * (self.num_blocks - 1) 139 | layers = [] 140 | for stride in strides: 141 | layers.append(self.block(self.in_channels, self.out_channels, stride)) 142 | self.in_channels = self.out_channels * self.block.expansion 143 | 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, *args, **kwargs): 147 | return self.layers(*args, **kwargs) 148 | 149 | 150 | @LAYER.register_module 151 | class ResTail(nn.Module): 152 | def __init__(self, in_channels, block_name, num_classes): 153 | super().__init__() 154 | block = LAYER.get_module(block_name) 155 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc = nn.Linear(in_channels * block.expansion, num_classes) 157 | 158 | def forward(self, *args, **kwargs): 159 | output = self.avg_pool(*args, **kwargs) 160 | output = output.view(output.size(0), -1) 161 | output = self.fc(output) 162 | 163 | return output 164 | 165 | 166 | @LAYER.register_module 167 | class ResHead(nn.Module): 168 | def __init__(self, in_channels, out_channels): 169 | super().__init__() 170 | self.conv1 = nn.Sequential( 171 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 172 | nn.BatchNorm2d(out_channels), 173 | nn.ReLU(inplace=True), 174 | ) 175 | 176 | def forward(self, *args, **kwargs): 177 | return self.conv1(*args, **kwargs) 178 | 179 | 180 | @LAYER.register_module 181 | class ResNet(nn.Module): 182 | def __init__(self, block, num_block, num_classes=100): 183 | super().__init__() 184 | 185 | self.in_channels = 64 186 | 
187 | self.conv1 = nn.Sequential( 188 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 189 | nn.BatchNorm2d(64), 190 | nn.ReLU(inplace=True), 191 | ) 192 | # we use a different inputsize than the original paper 193 | # so conv2_x's stride is 1 194 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 195 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 196 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 197 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 198 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 199 | self.fc = nn.Linear(512 * block.expansion, num_classes) 200 | 201 | def _make_layer(self, block, out_channels, num_blocks, stride): 202 | """make resnet layers(by layer i didnt mean this 'layer' was the 203 | same as a neuron netowork layer, ex. conv layer), one layer may 204 | contain more than one residual block 205 | 206 | Args: 207 | block: block type, basic block or bottle neck block 208 | out_channels: output depth channel number of this layer 209 | num_blocks: how many blocks per layer 210 | stride: the stride of the first block of this layer 211 | 212 | Return: 213 | return a resnet layer 214 | """ 215 | 216 | # we have num_block blocks per layer, the first block 217 | # could be 1 or 2, other blocks would always be 1 218 | strides = [stride] + [1] * (num_blocks - 1) 219 | layers = [] 220 | for stride in strides: 221 | layers.append(block(self.in_channels, out_channels, stride)) 222 | self.in_channels = out_channels * block.expansion 223 | 224 | return nn.Sequential(*layers) 225 | 226 | def forward(self, x): 227 | output = self.conv1(x) 228 | output = self.conv2_x(output) 229 | output = self.conv3_x(output) 230 | output = self.conv4_x(output) 231 | output = self.conv5_x(output) 232 | output = self.avg_pool(output) 233 | output = output.view(output.size(0), -1) 234 | output = self.fc(output) 235 | 236 | return output 237 | 238 | 239 | def resnet18(): 240 | """return a ResNet 18 object""" 241 | return ResNet(BasicBlock, [2, 2, 2, 2]) 242 | 243 | 244 | def resnet34(): 245 | """return a ResNet 34 object""" 246 | return ResNet(BasicBlock, [3, 4, 6, 3]) 247 | 248 | 249 | def resnet50(): 250 | """return a ResNet 50 object""" 251 | return ResNet(BottleNeck, [3, 4, 6, 3]) 252 | 253 | 254 | def resnet101(): 255 | """return a ResNet 101 object""" 256 | return ResNet(BottleNeck, [3, 4, 23, 3]) 257 | 258 | 259 | def resnet152(): 260 | """return a ResNet 152 object""" 261 | return ResNet(BottleNeck, [3, 8, 36, 3]) 262 | -------------------------------------------------------------------------------- /scaelum/model/rpc_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch.nn as nn 6 | from scaelum.dynamics import WorkerManager 7 | 8 | from .rpc_module import LocalModule, RemoteModule 9 | 10 | try: 11 | from torch.distributed.rpc import PyRRef as RpcRef 12 | except ImportError: 13 | from torch.distributed.rpc import RRef as RpcRef 14 | 15 | 16 | class RpcModel(nn.Module): 17 | def __init__(self, worker_manager: WorkerManager): 18 | super(RpcModel, self).__init__() 19 | self.worker_manager = worker_manager 20 | self.model = self._build_model() 21 | assert isinstance(self.model, nn.ModuleList), "model must be iterable" 22 | 23 | def _build_model(self): 24 | # init model 25 | model = nn.ModuleList() 26 | 27 | for worker in self.worker_manager.worker_pool: 28 | if worker.rank == 0: 29 | module = LocalModule( 30 | 
rank=worker.rank, 31 | model_cfg=worker.model_config, 32 | sequential_wrapper_cfg=worker.extra_config, 33 | ) 34 | else: 35 | module = RemoteModule( 36 | rank=worker.rank, 37 | model_cfg=worker.model_config, 38 | sequential_wrapper_cfg=worker.extra_config, 39 | ) 40 | model.append(module) 41 | 42 | return model 43 | 44 | def forward(self, *args): 45 | # Handle input in the case of List[List[Tensor]] 46 | if len(args) == 1 and isinstance(args[0], (list, tuple)): 47 | args = args[0] 48 | 49 | for idx, layer in enumerate(self.model): 50 | args = layer(*args) 51 | 52 | result = args[0] 53 | if isinstance(result, RpcRef): 54 | result = result.to_here()[0] 55 | return result 56 | 57 | def parameter_rrefs(self): 58 | remote_params = [] 59 | 60 | for layer in self.model: 61 | remote_params.extend(layer.parameter_rrefs()) 62 | 63 | return remote_params 64 | -------------------------------------------------------------------------------- /scaelum/model/rpc_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from typing import List, Dict 5 | 6 | import torch.distributed.rpc as rpc 7 | import torch.nn as nn 8 | from scaelum import utils 9 | from scaelum.builder import ModuleWrapper, build_module_from_cfg 10 | 11 | 12 | class BaseModule(nn.Module): 13 | def __init__( 14 | self, 15 | rank, 16 | model_cfg, 17 | sequential_wrapper_cfg, 18 | ): 19 | super(BaseModule, self).__init__() 20 | self.rank = rank 21 | self.model_cfg = model_cfg 22 | self.sequential_wrapper_cfg = sequential_wrapper_cfg 23 | self.module = self._build_module() 24 | 25 | def _forward(self, *args): 26 | raise Exception("not implemented") 27 | 28 | def forward(self, *args, **kwargs): 29 | output = self._forward(*args, **kwargs) 30 | return output 31 | 32 | def _build_module(self): 33 | raise NotImplementedError("not implemented") 34 | 35 | def load_weights(self, state_dict: List[Dict]) -> None: 36 | raise NotImplementedError("not implemented") 37 | 38 | def get_state_dict(self) -> List[Dict]: 39 | raise NotImplementedError("not implemented") 40 | 41 | def parameter_rrefs(self) -> List: 42 | raise NotImplementedError("not implemented") 43 | 44 | 45 | class LocalModule(BaseModule): 46 | def _forward(self, *args): 47 | res = self.module(*args) 48 | 49 | if isinstance(res, tuple) or isinstance(res, list): 50 | return res 51 | else: 52 | return (res,) 53 | 54 | def _build_module(self): 55 | module = build_module_from_cfg( 56 | rank=self.rank, 57 | model_cfg=self.model_cfg, 58 | module_wrapper_cfg=self.sequential_wrapper_cfg, 59 | ) 60 | return module 61 | 62 | def load_weights(self, state_dict: List[Dict]) -> None: 63 | utils.load_weights(self.module, state_dict) 64 | self.module._move_module_to_cuda() 65 | 66 | def get_state_dict(self) -> List[Dict]: 67 | return utils.get_state_dict(self.module) 68 | 69 | def parameter_rrefs(self) -> List: 70 | return utils.parameter_rrefs(self.module) 71 | 72 | 73 | class RemoteModule(BaseModule): 74 | def _forward(self, *args): 75 | # must return as a tuple for consistency 76 | res = rpc.remote( 77 | "worker{}".format(self.rank), 78 | ModuleWrapper.forward, 79 | args=[self.module] + list(args), 80 | ) 81 | return (res,) 82 | 83 | def _build_module(self): 84 | module = rpc.remote( 85 | "worker{}".format(self.rank), 86 | build_module_from_cfg, 87 | args=(self.rank, self.model_cfg, self.sequential_wrapper_cfg), 88 | ) 89 | return module 90 | 91 | def load_weights(self, state_dict: List[Dict]) -> None: 92 | 
utils.remote_method(utils.load_weights, self.module, state_dict) 93 | utils.remote_method(ModuleWrapper._move_module_to_cuda, self.module) 94 | 95 | def get_state_dict(self) -> List[Dict]: 96 | return utils.remote_method(utils.get_state_dict, self.module) 97 | 98 | def parameter_rrefs(self) -> List: 99 | return utils.remote_method(utils.parameter_rrefs, self.module) 100 | -------------------------------------------------------------------------------- /scaelum/registry/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import DATA_GENERATOR, DATASET, HOOKS, LAYER, Registry 2 | -------------------------------------------------------------------------------- /scaelum/registry/registry.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch.nn as nn 6 | 7 | 8 | class Registry(object): 9 | def __init__(self, name: str): 10 | self.name = name 11 | self._registry = dict() 12 | 13 | def register_module(self, module_class): 14 | module_name = module_class.__name__ 15 | assert module_name not in self._registry 16 | self._registry[module_name] = module_class 17 | 18 | def get_module(self, module_name: str, include_torch=True): 19 | if module_name in self._registry: 20 | return self._registry[module_name] 21 | elif include_torch and hasattr(nn, module_name): 22 | return getattr(nn, module_name) 23 | else: 24 | raise NameError("Module {} not found".format(module_name)) 25 | 26 | 27 | LAYER = Registry("layer") 28 | DATASET = Registry("dataset") 29 | HOOKS = Registry("hook") 30 | DATA_GENERATOR = Registry("data_generator") 31 | -------------------------------------------------------------------------------- /scaelum/runner/__init__.py: -------------------------------------------------------------------------------- 1 | from .hooks import Hook 2 | from .hooks_collection import * 3 | from .runner import Runner 4 | -------------------------------------------------------------------------------- /scaelum/runner/hooks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | class Hook(object): 6 | def before_run(self, runner): 7 | pass 8 | 9 | def after_run(self, runner): 10 | pass 11 | 12 | def before_epoch(self, runner): 13 | pass 14 | 15 | def after_epoch(self, runner): 16 | pass 17 | 18 | def before_iter(self, runner): 19 | pass 20 | 21 | def after_iter(self, runner): 22 | pass 23 | 24 | def before_train_epoch(self, runner): 25 | self.before_epoch(runner) 26 | 27 | def before_val_epoch(self, runner): 28 | self.before_epoch(runner) 29 | 30 | def after_train_epoch(self, runner): 31 | self.after_epoch(runner) 32 | 33 | def after_val_epoch(self, runner): 34 | self.after_epoch(runner) 35 | 36 | def before_train_iter(self, runner): 37 | self.before_iter(runner) 38 | 39 | def before_val_iter(self, runner): 40 | self.before_iter(runner) 41 | 42 | def after_train_iter(self, runner): 43 | self.after_iter(runner) 44 | 45 | def after_val_iter(self, runner): 46 | self.after_iter(runner) 47 | 48 | def every_n_epochs(self, runner, n): 49 | return (runner.epoch + 1) % n == 0 if n > 0 else False 50 | 51 | def every_n_inner_iters(self, runner, n): 52 | return (runner.inner_iter + 1) % n == 0 if n > 0 else False 53 | 54 | def every_n_iters(self, runner, n): 55 | return (runner.iter + 1) % n == 0 if n > 0 else False 56 | 57 | def end_of_epoch(self, runner): 58 
| return runner.inner_iter + 1 == len(runner.data_loader) 59 | -------------------------------------------------------------------------------- /scaelum/runner/hooks_collection/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint_hook import CheckpointHook 2 | from .distributed_timer_helper_hook import DistributedTimerHelperHook 3 | from .stop_hook import StopHook 4 | -------------------------------------------------------------------------------- /scaelum/runner/hooks_collection/checkpoint_hook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os 6 | import os.path as osp 7 | 8 | from scaelum.registry import HOOKS 9 | 10 | from ..hooks import Hook 11 | 12 | 13 | @HOOKS.register_module 14 | class CheckpointHook(Hook): 15 | def __init__( 16 | self, 17 | load_checkpoint_from: str = None, 18 | save_path: str = None, 19 | save_interval: int = None, 20 | ): 21 | self._load_checkpoint_from = load_checkpoint_from 22 | self._save_interval = save_interval 23 | self._save_path = save_path 24 | 25 | def before_run(self, runner): 26 | if self._load_checkpoint_from: 27 | runner.parameter_server.load_weights_from_file(self._load_checkpoint_from) 28 | cur_layer = 0 29 | rpc_model = runner.model 30 | 31 | for idx, module in enumerate(rpc_model.model): 32 | # get the number of layers on this worker 33 | num_layers = len(runner.worker_manager.worker_pool[idx].model_config) 34 | state_dict_list = [] 35 | 36 | # get the state dict from parameter server 37 | for layer_idx in range(cur_layer, cur_layer + num_layers): 38 | state_dict = runner.parameter_server.get_state_dict(layer_idx) 39 | state_dict_list.append(state_dict) 40 | cur_layer += num_layers 41 | 42 | # load the weights onto to the module 43 | module.load_weights(state_dict_list) 44 | 45 | def after_epoch(self, runner): 46 | if not osp.exists(self._save_path): 47 | os.mkdir(self._save_path) 48 | 49 | if self.every_n_epochs(runner, self._save_interval): 50 | # gather weights from workers 51 | rpc_model = runner.model 52 | all_state_dict = [] 53 | 54 | for module in rpc_model.model: 55 | state_dict = module.get_state_dict() 56 | all_state_dict.extend(state_dict) 57 | 58 | # update the weights in parameter server 59 | for i in range(len(all_state_dict)): 60 | try: 61 | runner.parameter_server.update_weights(all_state_dict[i], i) 62 | except: 63 | raise Exception( 64 | "have {} state dicts, have {} layers, error occurs at {}".format( 65 | len(all_state_dict), 66 | len(runner.parameter_server.module_list), 67 | i, 68 | ) 69 | ) 70 | 71 | # save weights 72 | epoch = runner.epoch 73 | file_name = osp.join(self._save_path, "epoch_{}.pth".format(epoch)) 74 | runner.parameter_server.save_weights_to_file(file_name) 75 | -------------------------------------------------------------------------------- /scaelum/runner/hooks_collection/distributed_timer_helper_hook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from scaelum.registry import HOOKS 6 | 7 | from ..hooks import Hook 8 | 9 | 10 | @HOOKS.register_module 11 | class DistributedTimerHelperHook(Hook): 12 | def before_run(self, runner): 13 | runner._timer.clean_prev_file() 14 | 15 | def after_run(self, runner): 16 | runner._timer.clean_prev_file() 17 | -------------------------------------------------------------------------------- 
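The hook files above (and StopHook below) all follow the same pattern: subclass `Hook`, override only the callbacks you need, and register the class with the `HOOKS` registry. Below is a minimal sketch of a custom hook, assuming the `scaelum` package imports cleanly in your environment; the class name and its print output are illustrative and do not exist in the repository.

```python
from scaelum.registry import HOOKS
from scaelum.runner import Hook


@HOOKS.register_module
class IterationPrintHook(Hook):
    """Print a short progress message every `interval` training iterations (illustrative only)."""

    def __init__(self, interval: int = 10):
        self._interval = interval

    def after_train_iter(self, runner):
        # every_n_iters is inherited from the Hook base class shown above
        if self.every_n_iters(runner, self._interval):
            print("epoch {}, iter {}".format(runner.epoch, runner.iter))


# Note: Registry.register_module stores the class but does not return it,
# so retrieve the class through the registry rather than by its module-level name.
hook = HOOKS.get_module("IterationPrintHook")(interval=50)
# runner.register_hook(hook)  # attach to the Runner defined in scaelum/runner/runner.py
```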
/scaelum/runner/hooks_collection/stop_hook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import os 5 | import os.path as osp 6 | 7 | from scaelum.registry import HOOKS 8 | 9 | from ..hooks import Hook 10 | 11 | 12 | @HOOKS.register_module 13 | class StopHook(Hook): 14 | def __init__(self, root="/tmp"): 15 | super().__init__() 16 | self.file_path = osp.join(root, "stop_flag.txt") 17 | 18 | def after_iter(self, runner): 19 | with open(self.file_path, "r") as f: 20 | flag = f.readline().strip() 21 | 22 | if flag == "1": 23 | runner.iter = runner.max_iters + 1 24 | runner.epoch = runner.max_epochs + 1 25 | 26 | def before_run(self, runner): 27 | with open(self.file_path, "w") as f: 28 | f.write("0") 29 | 30 | def after_run(self, runner): 31 | if osp.exists(self.file_path): 32 | os.remove(self.file_path) 33 | 34 | @staticmethod 35 | def stop(root): 36 | file_path = osp.join(root, "stop_flag.txt") 37 | with open(file_path, "w") as f: 38 | f.write("1") 39 | -------------------------------------------------------------------------------- /scaelum/runner/runner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch.distributed.autograd as dist_autograd 6 | import torch.nn as nn 7 | from scaelum import utils as dutils, WorkerManager 8 | from scaelum.dynamics import ParameterServer 9 | from scaelum.logger import Logger 10 | from scaelum.runner import Hook 11 | from scaelum.timer import DistributedTimer 12 | from torch.distributed.optim import DistributedOptimizer 13 | 14 | 15 | class Runner: 16 | def __init__( 17 | self, 18 | model: nn.Module, 19 | parameter_server: ParameterServer, 20 | worker_manager: WorkerManager, 21 | optimizer: DistributedOptimizer, 22 | max_epochs: int, 23 | max_iters: int, 24 | loss_cfg: dict, 25 | timer_cfg: dict, 26 | logging_cfg: dict, 27 | ): 28 | # model and optimizer instance 29 | self.model = model 30 | self.worker_manager = worker_manager 31 | self.parameter_server = parameter_server 32 | self.optimizer = optimizer 33 | 34 | # param 35 | self._hooks = [] 36 | self._epoch = 0 37 | self._iter = 0 38 | self._inner_iter = 0 39 | self._max_epoch = max_epochs 40 | self._max_iter = max_iters 41 | 42 | # logger 43 | self._logging_config = logging_cfg 44 | self._logger = Logger(**logging_cfg) 45 | 46 | # timer 47 | self._timer_config = timer_cfg 48 | self._timer = DistributedTimer(**timer_cfg) 49 | 50 | # build loss 51 | loss_name = loss_cfg.pop("type") 52 | self.loss_function = getattr(nn, loss_name)(**loss_cfg) 53 | 54 | @property 55 | def hooks(self): 56 | """list[:obj:`Hook`]: A list of registered hooks.""" 57 | return self._hooks 58 | 59 | @property 60 | def epoch(self): 61 | """int: Current epoch.""" 62 | return self._epoch 63 | 64 | @epoch.setter 65 | def epoch(self, i): 66 | self._epoch = i 67 | 68 | @property 69 | def iter(self): 70 | """int: Current iteration.""" 71 | return self._iter 72 | 73 | @iter.setter 74 | def iter(self, i): 75 | self._iter = i 76 | 77 | @property 78 | def inner_iter(self): 79 | """int: Iteration in an epoch.""" 80 | return self._inner_iter 81 | 82 | @property 83 | def max_epochs(self): 84 | """int: Maximum training epochs.""" 85 | return self._max_epochs 86 | 87 | @property 88 | def max_iter(self): 89 | """int: Maximum training iterations.""" 90 | return self._max_iter 91 | 92 | def register_hook(self, hook): 93 | assert 
isinstance(hook, Hook) 94 | self._hooks.append(hook) 95 | 96 | def _call_hook(self, fn_name): 97 | """Call all hooks. 98 | Args: 99 | fn_name (str): The function name in each hook to be called, such as 100 | "before_train_epoch". 101 | """ 102 | for hook in self._hooks: 103 | getattr(hook, fn_name)(self) 104 | 105 | def train(self, data_loader): 106 | # set model to train 107 | self.model.train(True) 108 | self._call_hook("before_run") 109 | 110 | # train by epoch 111 | while self.epoch < self._max_epoch: 112 | # call hook func 113 | self._call_hook("before_train_epoch") 114 | 115 | # train by iter 116 | for batch_index, (data, labels) in enumerate(data_loader): 117 | 118 | # break if max iter is exceeded 119 | if self._iter > self._max_iter: 120 | break 121 | 122 | self._logger.info("epoch: {}, iter: {}".format(self.epoch, self.iter)) 123 | 124 | # call hook func 125 | self._call_hook("before_train_iter") 126 | 127 | with dist_autograd.context() as context_id: 128 | # forward 129 | fwd_start = dutils.get_time() 130 | outputs = self.model(data) 131 | loss = self.loss_function(outputs, labels) 132 | fwd_end = dutils.get_time() 133 | 134 | # Backward pass (run distributed autograd). 135 | bwd_start = dutils.get_time() 136 | self._timer.add_timestamp() 137 | dist_autograd.backward(context_id, [loss]) 138 | bwd_end = dutils.get_time() 139 | self.optimizer.step(context_id) 140 | step_end = dutils.get_time() 141 | 142 | # log time 143 | self._logger.info("forward time: {}".format(fwd_end - fwd_start)) 144 | self._logger.info("backward time: {}".format(bwd_end - bwd_start)) 145 | self._logger.info("step time: {}".format(step_end - bwd_end)) 146 | 147 | # update iter 148 | self._iter += 1 149 | self._call_hook("after_train_iter") 150 | 151 | # update epoch 152 | self._epoch += 1 153 | self._call_hook("after_train_epoch") 154 | 155 | # finish training 156 | self._call_hook("after_run") 157 | -------------------------------------------------------------------------------- /scaelum/stimulator/__init__.py: -------------------------------------------------------------------------------- 1 | from .stimulator import Stimulator 2 | -------------------------------------------------------------------------------- /scaelum/stimulator/stimulator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Stimulator(object): 5 | def __init__(self, worker_num: int): 6 | self.worker_num = worker_num 7 | # generate random slowdown [1, 3) for memory usage 8 | m_rng = np.random.default_rng(seed=22) 9 | self.m_slowdown = 2 * m_rng.random((worker_num + 1,)) + 1 10 | # generate random slowdown [1, 2) for network usage 11 | n_rng = np.random.default_rng(seed=32) 12 | self.n_slowdown = n_rng.random((worker_num + 1,)) + 1 13 | # generate random slowdown [1, 4) for computing power 14 | c_rng = np.random.default_rng(seed=32) 15 | self.c_slowdown = c_rng.random((worker_num + 1,)) + 1 16 | 17 | def memory_slowdown(self, worker_id: int) -> float: 18 | return self.m_slowdown[worker_id] 19 | 20 | def compute_slowdown(self, worker_id: int) -> float: 21 | return self.c_slowdown[worker_id] 22 | 23 | def network_stimulate(self, worker_id: int) -> float: 24 | return self.n_slowdown[worker_id] 25 | -------------------------------------------------------------------------------- /scaelum/timer/__init__.py: -------------------------------------------------------------------------------- 1 | from .timer import DistributedTimer 2 | 
-------------------------------------------------------------------------------- /scaelum/timer/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os 6 | import os.path as osp 7 | import time 8 | 9 | 10 | class DistributedTimer: 11 | def __init__(self, root="/tmp"): 12 | self.file_path = osp.join(root, "dist_timer.txt") 13 | 14 | def clean_prev_file(self): 15 | if osp.exists(self.file_path): 16 | os.remove(self.file_path) 17 | 18 | def add_timestamp(self): 19 | with open(self.file_path, "a") as f: 20 | new_line = "timestamp: {}\n".format(time.time()) 21 | f.write(new_line) 22 | 23 | def get_prev_interval(self): 24 | with open(self.file_path, "r") as f: 25 | all_lines = f.readlines() 26 | start_time = float(all_lines[-2].split(":")[-1]) 27 | end_time = float(all_lines[-1].split(":")[-1]) 28 | 29 | return end_time - start_time 30 | -------------------------------------------------------------------------------- /scaelum/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import time 6 | from collections import OrderedDict 7 | from typing import List, Dict 8 | 9 | import torch 10 | import torch.distributed.rpc as rpc 11 | import torch.nn as nn 12 | 13 | # global params 14 | GPU = torch.cuda.is_available() 15 | 16 | 17 | def synchronize(): 18 | if GPU: 19 | torch.cuda.synchronize() 20 | 21 | 22 | def get_time(): 23 | synchronize() 24 | return time.time() 25 | 26 | 27 | def call_method(method, rref, *args, **kwargs): 28 | return method(rref.local_value(), *args, **kwargs) 29 | 30 | 31 | def remote_method(method, rref, *args, **kwargs): 32 | args = [method, rref] + list(args) 33 | return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs) 34 | 35 | 36 | def load_weights(model: nn.Module, state_dict: List[Dict]): 37 | # model.modules() gives a generator 38 | # the first element in the list is the whole module 39 | # thus exclude it when loading weights 40 | modules = list(model.modules())[1] 41 | 42 | error_msg = "Weights do not match the model, model has {} modules while state dict has {}".format( 43 | len(modules), len(state_dict) 44 | ) 45 | assert len(modules) == len(state_dict), error_msg 46 | 47 | # load weights 48 | for idx, mod in enumerate(modules): 49 | mod.load_state_dict(state_dict[idx]) 50 | 51 | 52 | def get_state_dict(model: nn.Module) -> List[Dict]: 53 | modules = list(model.modules())[1] 54 | module_weights = [mod.state_dict() for mod in modules] 55 | module_weights = [weights_to_cpu(weights) for weights in module_weights] 56 | return module_weights 57 | 58 | 59 | def weights_to_cpu(state_dict): 60 | state_dict_cpu = OrderedDict() 61 | for key, val in state_dict.items(): 62 | state_dict_cpu[key] = val.cpu() 63 | return state_dict_cpu 64 | 65 | 66 | def parameter_rrefs(module): 67 | param_rrefs = [] 68 | for param in module.parameters(): 69 | param_rrefs.append(rpc.RRef(param)) 70 | return param_rrefs 71 | 72 | 73 | def count_params(model, to_console=False): 74 | num_params = sum(p.numel() for p in model.parameters()) / 1000000.0 75 | num_grad_params = ( 76 | sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000.0 77 | ) 78 | 79 | if to_console: 80 | print("Number of parameters: {:.5g} M".format(num_params)) 81 | print("Number of parameters requiring grad: {:.5g} M".format(num_grad_params)) 82 | 83 | return num_params, 
num_grad_params 84 | 85 | 86 | def generate_worker_name(rank): 87 | return "worker{}".format(rank) 88 | -------------------------------------------------------------------------------- /scaelum/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | from setuptools import find_packages, setup 6 | 7 | 8 | def readme(): 9 | with open('README.md', encoding='utf-8') as f: 10 | content = f.read() 11 | return content 12 | 13 | 14 | def read_requirements(): 15 | with open('requirements.txt', 'r') as f: 16 | content = f.readlines() 17 | return content 18 | 19 | 20 | if __name__ == '__main__': 21 | setup(name='scaelum', 22 | version='0.0.1', 23 | description='Accelerating Geo-distributed Computing in Federated Learning', 24 | long_description=readme(), 25 | long_description_content_type="text/markdown", 26 | author='HPC-AI Technology Inc.', 27 | author_email='contact@hpcaitech.com', 28 | url='https://github.com/hpcaitech/SkyComputing', 29 | keywords='Python, scripts', 30 | packages=find_packages(), 31 | classifiers=[ 32 | 'Development Status :: 4 - Beta', 33 | 'License :: OSI Approved :: Apache Software License', 34 | 'Operating System :: OS Independent', 35 | 'Programming Language :: Python :: 3', 36 | 'Programming Language :: Python :: 3.5', 37 | 'Programming Language :: Python :: 3.6', 38 | 'Programming Language :: Python :: 3.7', 39 | ], 40 | license='Apache License 2.0', 41 | install_requires=read_requirements(), 42 | zip_safe=False, 43 | ) 44 | --------------------------------------------------------------------------------
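As a wrap-up, here is a minimal sketch of how the `LAYER` registry and the blocks in `scaelum/model/layers.py` fit together, assuming the package and its dependencies are installed; the toy head/stage/tail pipeline and the variable names are illustrative only, not part of the repository.

```python
import torch
import torch.nn as nn

import scaelum.model.layers  # noqa: F401 -- runs the @LAYER.register_module decorators
from scaelum.registry import LAYER
from scaelum.utils import count_params

# Registered scaelum layers are looked up by class name ...
head = LAYER.get_module("ResHead")(in_channels=3, out_channels=64)
stage = LAYER.get_module("ResLayer")(
    block_name="BottleNeck", in_channels=64, out_channels=64, num_blocks=2, stride=1
)
tail = LAYER.get_module("ResTail")(in_channels=64, block_name="BottleNeck", num_classes=100)

# ... and unregistered names fall back to torch.nn (include_torch=True by default),
# so plain PyTorch modules can be mixed into the same model configuration.
assert LAYER.get_module("ReLU") is nn.ReLU

x = torch.randn(2, 3, 32, 32)
out = tail(stage(head(x)))
print(out.shape)  # torch.Size([2, 100])

count_params(nn.Sequential(head, stage, tail), to_console=True)
```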