├── mpu ├── tests │ ├── __init__.py │ ├── commons.py │ ├── test_data.py │ ├── test_initialize.py │ ├── test_cross_entropy.py │ └── test_random.py ├── __init__.py ├── utils.py ├── grads.py ├── data.py ├── mappings.py ├── cross_entropy.py ├── initialize.py └── layers.py ├── example.txt ├── data ├── __init__.py ├── test │ ├── test_preprocess_data.sh │ └── test_indexed_dataset.py ├── Makefile ├── samplers.py ├── bert_dataset.py └── gpt2_dataset.py ├── bpe_3w_new └── chinese_vocab.model ├── requirements.txt ├── scripts ├── zero-shot-ocnli.sh ├── zero-shot-tnews.sh ├── zero-shot-iflytek.sh └── generate_text.sh ├── model ├── __init__.py ├── model.py ├── gpt2_modeling.py └── distributed.py ├── fp16 ├── __init__.py ├── fp16util.py └── loss_scaler.py ├── LICENSE ├── data_utils ├── corpora.py ├── tokenization_gpt2.py ├── __init__.py ├── tf_dl.py ├── samplers.py ├── lazy_loader.py └── file_utils.py ├── .gitignore ├── README.md ├── change_mp.py ├── configure_data.py └── utils.py /mpu/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example.txt: -------------------------------------------------------------------------------- 1 | 中国的首都是北京 2 | 日本的首都是东京 3 | 美国的首都是 4 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import indexed_dataset 2 | -------------------------------------------------------------------------------- /bpe_3w_new/chinese_vocab.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TsinghuaAI/CPM-1-Generate/HEAD/bpe_3w_new/chinese_vocab.model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | nltk>=3.4 3 | numpy>=1.15.4 4 | pandas>=0.24.0 5 | boto3==1.11.11 6 | regex==2020.1.8 7 | sentencepiece 8 | jieba 9 | pybind11 10 | requests 11 | -------------------------------------------------------------------------------- /data/test/test_preprocess_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | IMPL=cached 4 | python ../preprocess_data.py \ 5 | --input test_samples.json \ 6 | --vocab vocab.txt \ 7 | --dataset-impl ${IMPL} \ 8 | --output-prefix test_samples_${IMPL} \ 9 | --workers 1 \ 10 | --log-interval 2 11 | -------------------------------------------------------------------------------- /data/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 2 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 3 | LIBNAME = helpers 4 | LIBEXT = $(shell python3-config --extension-suffix) 5 | 6 | default: $(LIBNAME)$(LIBEXT) 7 | 8 | %$(LIBEXT): %.cpp 9 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ 10 | -------------------------------------------------------------------------------- /scripts/zero-shot-ocnli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=$1 4 | # Large 5 | # MPSIZE=2 6 | # NLAYERS=32 7 | # NHIDDEN=2560 8 | # NATT=32 9 | # MAXSEQLEN=1024 10 | 11 | # Small 12 | MPSIZE=2 13 | NLAYERS=12 14 | NHIDDEN=768 15 | NATT=12 16 | MAXSEQLEN=1024 17 | 18 | CMD="python 
-m torch.distributed.launch --nproc_per_node 2 zero-shot-cls.py \ 19 | --model-parallel-size $MPSIZE \ 20 | --num-layers $NLAYERS \ 21 | --hidden-size $NHIDDEN \ 22 | --load $CHECKPOINT_PATH \ 23 | --num-attention-heads $NATT \ 24 | --seq-length $MAXSEQLEN \ 25 | --max-position-embeddings 1024 \ 26 | --fp16 \ 27 | --cache-dir cache \ 28 | --eval-data-path $2 \ 29 | --tokenizer-path bpe_3w_new/ \ 30 | --vocab-size 30000 \ 31 | --task ocnli " 32 | 33 | $CMD 34 | -------------------------------------------------------------------------------- /scripts/zero-shot-tnews.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=$1 4 | # Large 5 | # MPSIZE=2 6 | # NLAYERS=32 7 | # NHIDDEN=2560 8 | # NATT=32 9 | # MAXSEQLEN=1024 10 | 11 | # Small 12 | MPSIZE=1 13 | NLAYERS=12 14 | NHIDDEN=768 15 | NATT=12 16 | MAXSEQLEN=1024 17 | 18 | CMD="python -m torch.distributed.launch --nproc_per_node 8 zero-shot-cls.py \ 19 | --model-parallel-size $MPSIZE \ 20 | --num-layers $NLAYERS \ 21 | --hidden-size $NHIDDEN \ 22 | --load $CHECKPOINT_PATH \ 23 | --num-attention-heads $NATT \ 24 | --seq-length $MAXSEQLEN \ 25 | --max-position-embeddings 1024 \ 26 | --fp16 \ 27 | --cache-dir cache \ 28 | --eval-data-path $2 \ 29 | --tokenizer-path bpe_3w_new/ \ 30 | --vocab-size 30000 \ 31 | --task tnews " 32 | 33 | $CMD 34 | -------------------------------------------------------------------------------- /scripts/zero-shot-iflytek.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=$1 4 | # Large 5 | # MPSIZE=2 6 | # NLAYERS=32 7 | # NHIDDEN=2560 8 | # NATT=32 9 | # MAXSEQLEN=1024 10 | 11 | # Small 12 | MPSIZE=1 13 | NLAYERS=12 14 | NHIDDEN=768 15 | NATT=12 16 | MAXSEQLEN=1024 17 | 18 | CMD="python -m torch.distributed.launch --nproc_per_node 8 zero-shot-cls.py \ 19 | --model-parallel-size $MPSIZE \ 20 | --num-layers $NLAYERS \ 21 | --hidden-size $NHIDDEN \ 22 | --load $CHECKPOINT_PATH \ 23 | --num-attention-heads $NATT \ 24 | --seq-length $MAXSEQLEN \ 25 | --max-position-embeddings 1024 \ 26 | --fp16 \ 27 | --cache-dir cache \ 28 | --eval-data-path $2 \ 29 | --tokenizer-path bpe_3w_new/ \ 30 | --vocab-size 30000 \ 31 | --task iflytek \ 32 | --seed 5" 33 | 34 | $CMD 35 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from .distributed import * 17 | from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization 18 | from .gpt2_modeling import GPT2Model 19 | from .model import BertModel 20 | from .model import get_params_for_weight_decay_optimization 21 | -------------------------------------------------------------------------------- /scripts/generate_text.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHECKPOINT_PATH=$1 4 | MPSIZE=2 5 | NLAYERS=32 6 | NHIDDEN=2560 7 | NATT=32 8 | MAXSEQLEN=1024 9 | 10 | #SAMPLING ARGS 11 | TEMP=0.9 12 | #If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p 13 | TOPK=0 14 | TOPP=0 15 | 16 | CMD="python -m torch.distributed.launch --nproc_per_node $MPSIZE generate_samples.py \ 17 | --model-parallel-size $MPSIZE \ 18 | --num-layers $NLAYERS \ 19 | --hidden-size $NHIDDEN \ 20 | --load $CHECKPOINT_PATH \ 21 | --num-attention-heads $NATT \ 22 | --seq-length $MAXSEQLEN \ 23 | --max-position-embeddings 1024 \ 24 | --fp16 \ 25 | --cache-dir cache \ 26 | --out-seq-length 512 \ 27 | --temperature $TEMP \ 28 | --top_k $TOPK \ 29 | --top_p $TOPP \ 30 | --tokenizer-path bpe_3w_new/ \ 31 | --vocab-size 30000 " 32 | 33 | if [ ! -z $2 ]; then 34 | CMD+="--input-text $2" 35 | fi 36 | 37 | $CMD 38 | -------------------------------------------------------------------------------- /fp16/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from .fp16util import ( 16 | BN_convert_float, 17 | network_to_half, 18 | prep_param_lists, 19 | model_grads_to_master_grads, 20 | master_params_to_model_params, 21 | tofp16, 22 | to_python_float, 23 | clip_grad_norm, 24 | convert_module, 25 | convert_network, 26 | FP16Model, 27 | ) 28 | 29 | from .fp16 import * 30 | from .loss_scaler import * 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 THU-PLM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data_utils/corpora.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """several datasets with preset arguments""" 16 | from .datasets import json_dataset, csv_dataset 17 | import os 18 | 19 | class wikipedia(json_dataset): 20 | """ 21 | dataset for wikipedia with arguments configured for convenience 22 | 23 | command line usage: `--train-data wikipedia` 24 | """ 25 | PATH = 'data/wikipedia/wikidump_lines.json' 26 | assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py" 27 | def __init__(self, **kwargs): 28 | assert os.path.exists(wikipedia.PATH), \ 29 | wikipedia.assert_str 30 | if not kwargs: 31 | kwargs = {} 32 | kwargs['text_key'] = 'text' 33 | kwargs['loose_json'] = True 34 | super(wikipedia, self).__init__(wikipedia.PATH, **kwargs) 35 | 36 | 37 | class webtext(json_dataset): 38 | """ 39 | dataset for webtext with arguments configured for convenience 40 | 41 | command line usage: `--train-data webtext` 42 | """ 43 | PATH = '/data/private/zhangzhengyan/corpus/merge_new.json' 44 | assert_str = "make sure to set PATH for webtext data_utils/corpora.py" 45 | def __init__(self, **kwargs): 46 | assert os.path.exists(webtext.PATH), \ 47 | webtext.assert_str 48 | if not kwargs: 49 | kwargs = {} 50 | kwargs['text_key'] = 'text' 51 | kwargs['loose_json'] = True 52 | super(webtext, self).__init__(webtext.PATH, **kwargs) 53 | 54 | 55 | NAMED_CORPORA = { 56 | 'wikipedia': wikipedia, 57 | 'webtext': webtext, 58 | } 59 | -------------------------------------------------------------------------------- /mpu/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Model parallel utility interface.""" 17 | 18 | from .cross_entropy import vocab_parallel_cross_entropy 19 | 20 | from .data import broadcast_data 21 | 22 | from .grads import clip_grad_norm 23 | 24 | from .initialize import destroy_model_parallel 25 | from .initialize import get_data_parallel_group 26 | from .initialize import get_data_parallel_rank 27 | from .initialize import get_data_parallel_world_size 28 | from .initialize import get_model_parallel_group 29 | from .initialize import get_model_parallel_rank 30 | from .initialize import get_model_parallel_src_rank 31 | from .initialize import get_model_parallel_world_size 32 | from .initialize import initialize_model_parallel 33 | from .initialize import model_parallel_is_initialized 34 | 35 | from .layers import ColumnParallelLinear 36 | from .layers import ParallelEmbedding 37 | from .layers import RowParallelLinear 38 | from .layers import VocabParallelEmbedding 39 | 40 | from .mappings import copy_to_model_parallel_region 41 | from .mappings import gather_from_model_parallel_region 42 | from .mappings import reduce_from_model_parallel_region 43 | from .mappings import scatter_to_model_parallel_region 44 | 45 | from .random import checkpoint 46 | from .random import partition_activations_in_checkpoint 47 | from .random import get_cuda_rng_tracker 48 | from .random import model_parallel_cuda_manual_seed 49 | 50 | from .transformer import BertParallelSelfAttention 51 | from .transformer import BertParallelTransformerLayer 52 | from .transformer import GPT2ParallelTransformer 53 | from .transformer import LayerNorm 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /data_utils/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | import sentencepiece as spm 26 | import jieba 27 | 28 | try: 29 | from functools import lru_cache 30 | except ImportError: 31 | # Just a dummy decorator to get the checks to run on python2 32 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 33 | def lru_cache(): 34 | return lambda func: func 35 | 36 | from .file_utils import cached_path 37 | 38 | class GPT2Tokenizer(object): 39 | 40 | def __init__(self, vocab_file, model_file, max_len=None): 41 | self.max_len = max_len if max_len is not None else int(1e12) 42 | self.encoder = json.load(open(vocab_file)) 43 | self.decoder = {v:k for k,v in self.encoder.items()} 44 | 45 | self.sp = spm.SentencePieceProcessor(model_file=model_file) 46 | self.translator = str.maketrans(" \n", "\u2582\u2583") 47 | 48 | self.eod_id = self.encoder[''] 49 | 50 | @property 51 | def vocab_size(self): 52 | return len(self.encoder) 53 | 54 | def __len__(self): 55 | return len(self.encoder) + len(self.special_tokens) 56 | 57 | @property 58 | def eod(self): 59 | return self.eod_id 60 | 61 | def tokenize(self, text): 62 | """ Tokenize a string. 
""" 63 | seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)] 64 | new_seg = " ".join(seg_list) 65 | return self.sp.encode(new_seg) 66 | 67 | def encode(self, text): 68 | res = self.tokenize(text) 69 | return res 70 | 71 | def decode(self, tokens): 72 | text = self.sp.decode(tokens) 73 | text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n') 74 | return text 75 | 76 | -------------------------------------------------------------------------------- /mpu/tests/commons.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import os 18 | import random 19 | import numpy 20 | import torch 21 | 22 | import mpu 23 | 24 | 25 | class IdentityLayer(torch.nn.Module): 26 | def __init__(self, size, scale=1.0): 27 | super(IdentityLayer, self).__init__() 28 | self.weight = torch.nn.Parameter(scale * torch.randn(size)) 29 | def forward(self): 30 | return self.weight 31 | 32 | 33 | def set_random_seed(seed): 34 | """Set random seed for reproducability.""" 35 | random.seed(seed) 36 | numpy.random.seed(seed) 37 | torch.manual_seed(seed) 38 | mpu.model_parallel_cuda_manual_seed(seed) 39 | 40 | 41 | def initialize_distributed(backend='nccl'): 42 | """Initialize torch.distributed.""" 43 | # Get local rank in case it is provided. 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--local_rank', type=int, default=None, 46 | help='local rank passed from distributed launcher') 47 | args = parser.parse_args() 48 | local_rank = args.local_rank 49 | 50 | # Get rank and world size. 51 | rank = int(os.getenv('RANK', '0')) 52 | world_size = int(os.getenv("WORLD_SIZE", '1')) 53 | 54 | print('> initializing torch.distributed with local rank: {}, ' 55 | 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) 56 | 57 | # Set the device id. 58 | device = rank % torch.cuda.device_count() 59 | if local_rank is not None: 60 | device = local_rank 61 | torch.cuda.set_device(device) 62 | 63 | # Call the init process. 
64 | init_method = 'tcp://' 65 | master_ip = os.getenv('MASTER_ADDR', 'localhost') 66 | master_port = os.getenv('MASTER_PORT', '6000') 67 | init_method += master_ip + ':' + master_port 68 | torch.distributed.init_process_group( 69 | backend=backend, 70 | world_size=world_size, 71 | rank=rank, 72 | init_method=init_method) 73 | 74 | 75 | def print_separator(message): 76 | torch.distributed.barrier() 77 | filler_len = (78 - len(message)) // 2 78 | filler = '-' * filler_len 79 | string = '\n' + filler + ' {} '.format(message) + filler 80 | if torch.distributed.get_rank() == 0: 81 | print(string, flush=True) 82 | torch.distributed.barrier() 83 | -------------------------------------------------------------------------------- /mpu/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | 20 | def ensure_divisibility(numerator, denominator): 21 | """Ensure that numerator is divisible by the denominator.""" 22 | assert numerator % denominator == 0, '{} is not divisible by {}'.format( 23 | numerator, denominator) 24 | 25 | 26 | def divide(numerator, denominator): 27 | """Ensure that numerator is divisible by the denominator and return 28 | the division value.""" 29 | ensure_divisibility(numerator, denominator) 30 | return numerator // denominator 31 | 32 | 33 | def split_tensor_along_last_dim(tensor, num_partitions, 34 | contiguous_split_chunks=False): 35 | """Split a tensor along its last dimension. 36 | Arguments: 37 | tensor: input tensor. 38 | num_partitions: number of partitions to split the tensor 39 | contiguous_split_chunks: If True, make each chunk contiguous 40 | in memory. 41 | """ 42 | # Get the size and dimension. 43 | last_dim = tensor.dim() - 1 44 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 45 | # Split. 46 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 47 | # Note: torch.split does not create contiguous tensors by default. 
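# For example, a tensor of shape (4, 10) split into num_partitions=2 yields two
# views of shape (4, 5) that share storage with the input, which is why callers
# may request contiguous copies below.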
48 | if contiguous_split_chunks: 49 | return tuple(chunk.contiguous() for chunk in tensor_list) 50 | 51 | return tensor_list 52 | 53 | 54 | class VocabUtility: 55 | """Split the vocabulary into `world_size` chunks amd return the 56 | first and last index of the vocabulary belonging to the `rank` 57 | partition: Note that indecies in [fist, last)""" 58 | 59 | @staticmethod 60 | def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, 61 | rank, world_size): 62 | index_f = rank * per_partition_vocab_size 63 | index_l = index_f + per_partition_vocab_size 64 | return index_f, index_l 65 | 66 | @staticmethod 67 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): 68 | per_partition_vocab_size = divide(global_vocab_size, world_size) 69 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 70 | per_partition_vocab_size, rank, world_size) 71 | -------------------------------------------------------------------------------- /mpu/grads.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | 21 | import torch 22 | from torch._six import inf 23 | 24 | from .initialize import get_model_parallel_group 25 | from .initialize import get_model_parallel_rank 26 | 27 | 28 | def clip_grad_norm(parameters, max_norm, norm_type=2): 29 | """Clips gradient norm of an iterable of parameters. 30 | 31 | This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and 32 | added functionality to handle model parallel parameters. Note that 33 | the gradients are modified in place. 34 | 35 | Arguments: 36 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 37 | single Tensor that will have gradients normalized 38 | max_norm (float or int): max norm of the gradients 39 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 40 | infinity norm. 41 | 42 | Returns: 43 | Total norm of the parameters (viewed as a single vector). 44 | """ 45 | if isinstance(parameters, torch.Tensor): 46 | parameters = [parameters] 47 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 48 | max_norm = float(max_norm) 49 | norm_type = float(norm_type) 50 | if norm_type == inf: 51 | total_norm = max(p.grad.data.abs().max() for p in parameters) 52 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 53 | # Take max across all GPUs. 
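# For the infinity norm the global norm is the maximum of the per-GPU maxima,
# so a single MAX all-reduce over the model-parallel group is enough; the else
# branch below instead accumulates |g|^p locally, SUM-reduces across the group,
# and takes the (1/p)-th power.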
54 | torch.distributed.all_reduce(total_norm_cuda, 55 | op=torch.distributed.ReduceOp.MAX, 56 | group=get_model_parallel_group()) 57 | total_norm = total_norm_cuda[0].item() 58 | else: 59 | total_norm = 0 60 | for p in parameters: 61 | if p.model_parallel or (get_model_parallel_rank() == 0): 62 | param_norm = p.grad.data.norm(norm_type) 63 | total_norm += param_norm.item() ** norm_type 64 | # Sum across all model parallel GPUs. 65 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 66 | torch.distributed.all_reduce(total_norm_cuda, 67 | op=torch.distributed.ReduceOp.SUM, 68 | group=get_model_parallel_group()) 69 | total_norm = total_norm_cuda[0].item() ** (1. / norm_type) 70 | clip_coef = max_norm / (total_norm + 1e-6) 71 | if clip_coef < 1: 72 | for p in parameters: 73 | p.grad.data.mul_(clip_coef) 74 | return total_norm 75 | -------------------------------------------------------------------------------- /mpu/tests/test_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import functools 17 | import operator 18 | import sys 19 | sys.path.append("../..") 20 | 21 | import torch 22 | import mpu 23 | from mpu import data as data_utils 24 | 25 | from commons import initialize_distributed 26 | from commons import print_separator 27 | 28 | 29 | def test_boradcast_data(model_parallel_size): 30 | 31 | if torch.distributed.get_rank() == 0: 32 | print('> testing boradcast_data with model parallel size {} ...'. 
33 | format(model_parallel_size)) 34 | 35 | mpu.initialize_model_parallel(model_parallel_size) 36 | torch.manual_seed(1234 + mpu.get_data_parallel_rank()) 37 | model_parallel_size = mpu.get_model_parallel_world_size() 38 | 39 | key_size_t = {'key1': [7, 11], 40 | 'key2': [8, 2, 1], 41 | 'key3': [13], 42 | 'key4': [5, 1, 2], 43 | 'key5': [5, 12]} 44 | keys = list(key_size_t.keys()) 45 | 46 | data = {} 47 | data_t = {} 48 | for key in key_size_t: 49 | data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) 50 | data_t[key] = data[key].clone() 51 | data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) 52 | data_t['keyX'] = data['keyX'].clone() 53 | if mpu.get_model_parallel_rank() != 0: 54 | data = None 55 | 56 | data_utils._check_data_types(keys, data_t, torch.int64) 57 | key_size, key_numel, \ 58 | total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) 59 | for key in keys: 60 | assert key_size[key] == key_size_t[key] 61 | total_numel_t = 0 62 | for key in keys: 63 | target_size = functools.reduce(operator.mul, key_size_t[key], 1) 64 | assert key_numel[key] == target_size 65 | total_numel_t += target_size 66 | assert total_numel == total_numel_t 67 | 68 | data_b = data_utils.broadcast_data(keys, data, torch.int64) 69 | for key in keys: 70 | tensor = data_t[key].cuda() 71 | assert data_b[key].sub(tensor).abs().max() == 0 72 | 73 | # Reset groups 74 | mpu.destroy_model_parallel() 75 | 76 | torch.distributed.barrier() 77 | if torch.distributed.get_rank() == 0: 78 | print('>> passed the test :-)') 79 | 80 | 81 | if __name__ == '__main__': 82 | 83 | initialize_distributed() 84 | world_size = torch.distributed.get_world_size() 85 | 86 | model_parallel_size = 1 87 | while model_parallel_size <= world_size: 88 | print_separator('test test boradcast data') 89 | test_boradcast_data(model_parallel_size) 90 | model_parallel_size *= 2 91 | 92 | 93 | -------------------------------------------------------------------------------- /mpu/tests/test_initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import sys 17 | sys.path.append("../..") 18 | 19 | import torch 20 | import mpu 21 | 22 | from commons import initialize_distributed 23 | from commons import print_separator 24 | 25 | 26 | def test_initialize_model_parallel(model_parallel_size): 27 | 28 | if torch.distributed.get_rank() == 0: 29 | print('> testing initialize_model_parallel with size {} ...'.format( 30 | model_parallel_size)) 31 | model_parallel_size_ = min(model_parallel_size, 32 | torch.distributed.get_world_size()) 33 | assert not mpu.model_parallel_is_initialized() 34 | mpu.initialize_model_parallel(model_parallel_size_) 35 | assert mpu.model_parallel_is_initialized() 36 | 37 | # Checks. 
38 |     def check(group, world_size, rank):
39 |         assert world_size == torch.distributed.get_world_size(group=group)
40 |         assert rank == torch.distributed.get_rank(group=group)
41 | 
42 |     # Model parallel.
43 |     world_size = model_parallel_size_
44 |     rank = torch.distributed.get_rank() % model_parallel_size_
45 |     assert world_size == mpu.get_model_parallel_world_size()
46 |     assert rank == mpu.get_model_parallel_rank()
47 |     check(mpu.get_model_parallel_group(), world_size, rank)
48 | 
49 | 
50 |     # Data parallel.
51 |     world_size = torch.distributed.get_world_size() // model_parallel_size_
52 |     rank = torch.distributed.get_rank() // model_parallel_size
53 |     assert world_size == mpu.get_data_parallel_world_size()
54 |     assert rank == mpu.get_data_parallel_rank()
55 |     check(mpu.get_data_parallel_group(), world_size, rank)
56 | 
57 |     # Reset groups
58 |     mpu.destroy_model_parallel()
59 | 
60 |     torch.distributed.barrier()
61 |     if torch.distributed.get_rank() == 0:
62 |         print('>> passed the test :-)')
63 | 
64 | 
65 | def test_get_model_parallel_src_rank(model_parallel_size_):
66 | 
67 |     if torch.distributed.get_rank() == 0:
68 |         print('> testing get_model_parallel_src_rank with size {} ...'.format(
69 |             model_parallel_size_))
70 |     model_parallel_size = min(model_parallel_size_,
71 |                               torch.distributed.get_world_size())
72 |     assert not mpu.model_parallel_is_initialized()
73 |     mpu.initialize_model_parallel(model_parallel_size)
74 |     assert mpu.model_parallel_is_initialized()
75 | 
76 |     # Checks
77 |     src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
78 |     assert mpu.get_model_parallel_src_rank() == src_rank
79 | 
80 |     # Reset groups
81 |     mpu.destroy_model_parallel()
82 | 
83 |     torch.distributed.barrier()
84 |     if torch.distributed.get_rank() == 0:
85 |         print('>> passed the test :-)')
86 | 
87 | 
88 | if __name__ == '__main__':
89 | 
90 |     initialize_distributed()
91 |     world_size = torch.distributed.get_world_size()
92 |     model_parallel_size = 1
93 |     while model_parallel_size <= world_size:
94 |         print_separator('test initialize model parallel')
95 |         test_initialize_model_parallel(model_parallel_size)
96 |         print_separator('test model parallel source rank')
97 |         test_get_model_parallel_src_rank(model_parallel_size)
98 |         model_parallel_size *= 2
99 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # CPM-Generate
2 | 
3 | To promote the development of Chinese natural language processing research, this project provides the text-generation code of the **CPM-LM** (2.6B) model. It can be used for local text-generation tests and as a basis for further research on zero-shot / few-shot learning. [[Model download](https://model.baai.ac.cn/model-detail/100017)] [[Technical report](https://www.sciencedirect.com/science/article/pii/S266665102100019X)]
4 | 
5 | **If you want to run inference with CPM-1, we recommend the efficient inference toolkit [BMInf](https://github.com/OpenBMB/BMInf), which supports single-card inference on a GTX 1060 or better.**
6 | 
7 | ## Installation
8 | 
9 | First install PyTorch and the other base dependencies, then install [APEX](https://github.com/NVIDIA/apex#quick-start) for fp16 support:
10 | ```
11 | pip install -r requirements.txt
12 | git clone https://github.com/NVIDIA/apex
13 | cd apex
14 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
15 | ```
16 | 
17 | Because the APEX installation is error-prone, we have built a corresponding Docker image for quick environment setup. Install it with:
18 | ```
19 | docker pull dmye/cpm:v0
20 | ```
21 | A reference command for running the container:
22 | ```
23 | sudo docker run --gpus '"device=0,1"' -it -v <code_dir>:/CPM --name=cpm cpm:v0
24 | ```
25 | where `<code_dir>` is the directory containing this code; `-v` mounts it into the container.
26 | 
27 | Note: thanks to qhduan for providing TensorFlow-based [usage code](https://github.com/qhduan/CPM-LM-TF2) as an alternative to the PyTorch version.
28 | 
29 | ## Model
30 | 
31 | After downloading the model, the folder structure must be arranged as follows:
32 | ```
33 | .
34 | ├── 80000
35 | │   ├── mp_rank_00_model_states.pt
36 | │   └── mp_rank_01_model_states.pt
37 | └── latest_checkpointed_iteration.txt
38 | ```
39 | To verify the integrity of the downloaded files, their checksums are:
40 | ```
41 | SHA1
42 | 71d6b6ad4f47b46724eb82c05da8fb9175e62a7d 80000/mp_rank_00_model_states.pt
43 | 42aa247a262e2011fa5e276f1a8389fad6d80edc 80000/mp_rank_01_model_states.pt
44 | MD5
45 | f3f6d2f7d84c6a45290a31dabf79ddac 80000/mp_rank_00_model_states.pt
46 | b0e960be4b5226e759ae6fc5246f9160 80000/mp_rank_01_model_states.pt
47 | ```
48 | 
49 | ## Usage
50 | 
51 | Interactive command-line generation is provided:
52 | ```
53 | bash scripts/generate_text.sh /path/to/CPM
54 | ```
55 | If you do not want interactive input, add a second argument giving the location of the input text:
56 | ```
57 | bash scripts/generate_text.sh /path/to/CPM example.txt
58 | ```
59 | Running this script requires two GPUs, each using roughly 7 GB of GPU memory. The project is mainly adapted from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM); the overall model architecture is the same as GPT-2.
60 | 
61 | The default model-parallel size is 2. To change it, use `change_mp.py` and adjust `MPSIZE` in `generate_text.sh`. Example usage of `change_mp.py`:
62 | ```
63 | python change_mp.py /path/to/CPM MPSIZE
64 | ```
65 | Here `/path/to/CPM` is the model path and `MPSIZE` is an integer that can be 1 or a multiple of 2; the result is a new model stored at `/path/to/CPM_MPSIZE`.
66 | 
67 | ## Tokenization
68 | 
69 | The tokenization is implemented mainly in `data_utils/tokenization_gpt2.py`: the text is first segmented into words, and SentencePiece is then used to obtain the BPE result. Since SentencePiece cannot encode spaces and newlines effectively, we replace spaces and newlines in the text with `\u2582` and `\u2583` before BPE; when generating text, the produced `\u2582` and `\u2583` are correspondingly replaced back with spaces and newlines.
70 | 
71 | The corresponding [issue](https://kexue.fm/archives/7912) has been resolved.
72 | 
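As a minimal sketch of driving this pipeline from Python with the repository's `GPT2Tokenizer` (jieba segmentation, the space/newline substitution, then SentencePiece BPE); the JSON vocab filename under `bpe_3w_new/` is an assumption here, so check the actual files in that directory:
```
from data_utils.tokenization_gpt2 import GPT2Tokenizer

# The SentencePiece model ships as bpe_3w_new/chinese_vocab.model; the JSON
# vocab filename below is an assumed name -- adjust it to your checkout.
tokenizer = GPT2Tokenizer(vocab_file='bpe_3w_new/chinese_vocab.json',
                          model_file='bpe_3w_new/chinese_vocab.model')

ids = tokenizer.encode('中国的首都是北京')   # jieba cut -> space/newline mapping -> BPE ids
print(ids)
print(tokenizer.decode(ids))               # '\u2582'/'\u2583' mapped back to space/newline
```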
73 | ## Zero-shot Learning for Classification Tasks
74 | 
75 | Zero-shot evaluation scripts for three tasks are provided for reference: OCNLI, TNEWS and IFLYTEK ([data download link](https://github.com/CLUEbenchmark/CLUE)). The scripts are used as follows:
76 | ```
77 | # OCNLI
78 | bash scripts/zero-shot-ocnli.sh /path/to/CPM /path/to/dataset
79 | # TNEWS
80 | bash scripts/zero-shot-tnews.sh /path/to/CPM /path/to/dataset
81 | # IFLYTEK
82 | bash scripts/zero-shot-iflytek.sh /path/to/CPM /path/to/dataset
83 | ```
84 | 
85 | To run the TNEWS and IFLYTEK evaluation on the full label data, set `sampled_labels` to `True` in the data-loading functions (`load_iflytek_data` and `load_tnews_data`).
86 | 
87 | ## Smaller Models
88 | 
89 | - [CPM-Distill](https://github.com/TsinghuaAI/CPM-1-Distill) is a distilled version of the 2.6B (2.6 billion) parameter CPM-Large model, with 109M parameters
90 | 
91 | - [CPM-Generate-distill](https://huggingface.co/mymusise/CPM-Generate-distill) is a third-party implementation of `CPM-Distill` that supports `Pytorch` and `Tensorflow`
92 | 
93 | ## TODO
94 | 
95 | - ~~Docker image for the experiment environment~~
96 | - ~~Provide concrete usage templates for each task~~
97 | - ~~Release the technical report~~
98 | - ~~Dynamically adjustable model-parallel size~~
99 | - ~~Fine-tuning code~~
100 | - ~~Open-source the small-scale model weights used in the experiments~~
101 | 
102 | ## Citation
103 | 
104 | ```
105 | @article{cpm-v1,
106 |   title={CPM: A Large-scale Generative Chinese Pre-trained Language Model},
107 |   author={Zhang, Zhengyan and Han, Xu and Zhou, Hao and Ke, Pei and Gu, Yuxian and Ye, Deming and Qin, Yujia and Su, Yusheng and Ji, Haozhe and Guan, Jian and Qi, Fanchao and Wang, Xiaozhi and Zheng, Yanan and Zeng, Guoyang and Cao, Huanqi and Chen, Shengqi and Li, Daixuan and Sun, Zhenbo and Liu, Zhiyuan and Huang, Minlie and Han, Wentao and Tang, Jie and Li, Juanzi and Sun, Maosong},
108 |   year={2020}
109 | }
110 | ```
111 | 
-------------------------------------------------------------------------------- /model/model.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for wrapping BertModel.""" 17 | 18 | import torch 19 | 20 | from .modeling import BertConfig 21 | from .modeling import BertForPreTraining, BertForMaskedLM 22 | from .modeling import BertLayerNorm 23 | 24 | 25 | def get_params_for_weight_decay_optimization(module): 26 | 27 | weight_decay_params = {'params': []} 28 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 29 | for module_ in module.modules(): 30 | if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)): 31 | no_weight_decay_params['params'].extend( 32 | [p for p in list(module_._parameters.values()) 33 | if p is not None]) 34 | else: 35 | weight_decay_params['params'].extend( 36 | [p for n, p in list(module_._parameters.items()) 37 | if p is not None and n != 'bias']) 38 | no_weight_decay_params['params'].extend( 39 | [p for n, p in list(module_._parameters.items()) 40 | if p is not None and n == 'bias']) 41 | 42 | return weight_decay_params, no_weight_decay_params 43 | 44 | 45 | class BertModel(torch.nn.Module): 46 | 47 | def __init__(self, args): 48 | super(BertModel, self).__init__() 49 | if args.pretrained_bert: 50 | self.model = BertForPreTraining.from_pretrained( 51 | args.tokenizer_model_type, 52 | cache_dir=args.cache_dir, 53 | fp32_layernorm=args.fp32_layernorm, 54 | fp32_embedding=args.fp32_embedding, 55 | layernorm_epsilon=args.layernorm_epsilon) 56 | else: 57 | if args.intermediate_size is None: 58 | intermediate_size = 4 * args.hidden_size 59 | else: 60 | intermediate_size = args.intermediate_size 61 | self.config = BertConfig( 62 | args.tokenizer_num_tokens, 63 | hidden_size=args.hidden_size, 64 | num_hidden_layers=args.num_layers, 65 | num_attention_heads=args.num_attention_heads, 66 | intermediate_size=intermediate_size, 67 | hidden_dropout_prob=args.hidden_dropout, 68 | attention_probs_dropout_prob=args.attention_dropout, 69 | max_position_embeddings=args.max_position_embeddings, 70 | type_vocab_size=args.tokenizer_num_type_tokens, 71 | fp32_layernorm=args.fp32_layernorm, 72 | fp32_embedding=args.fp32_embedding, 73 | fp32_tokentypes=args.fp32_tokentypes, 74 | layernorm_epsilon=args.layernorm_epsilon, 75 | deep_init=args.deep_init) 76 | self.model = BertForPreTraining(self.config) 77 | 78 | def forward(self, input_tokens, token_type_ids=None, 79 | attention_mask=None, checkpoint_activations=False): 80 | return self.model( 81 | input_tokens, token_type_ids, attention_mask, 82 | checkpoint_activations=checkpoint_activations) 83 | 84 | def state_dict(self, destination=None, prefix='', keep_vars=False): 85 | return self.model.state_dict(destination=destination, prefix=prefix, 86 | keep_vars=keep_vars) 87 | 88 | def load_state_dict(self, state_dict, strict=True): 89 | return self.model.load_state_dict(state_dict, strict=strict) 90 | 91 | -------------------------------------------------------------------------------- /mpu/tests/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import sys 18 | sys.path.append("../..") 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | import mpu 23 | from mpu.cross_entropy import vocab_parallel_cross_entropy 24 | 25 | from commons import initialize_distributed 26 | from commons import print_separator 27 | from commons import IdentityLayer 28 | from commons import set_random_seed 29 | 30 | 31 | def torch_cross_entropy(batch_size, seq_length, vocab_size, 32 | logits_scale, seed): 33 | set_random_seed(seed) 34 | identity = IdentityLayer((batch_size, seq_length, vocab_size), 35 | scale=logits_scale).cuda() 36 | logits = identity() 37 | target = torch.cuda.LongTensor( 38 | size=(batch_size, seq_length)).random_(0, vocab_size) 39 | loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), 40 | target.view(-1), 41 | reduction='none').view_as(target).mean() 42 | loss.backward() 43 | return loss, identity.weight.grad 44 | 45 | 46 | def mpu_cross_entropy(batch_size, seq_length, vocab_size, 47 | logits_scale, seed): 48 | set_random_seed(seed) 49 | identity = IdentityLayer((batch_size, seq_length, vocab_size), 50 | scale=logits_scale).cuda() 51 | logits = identity() 52 | logits_parallel = mpu.scatter_to_model_parallel_region(logits) 53 | target = torch.cuda.LongTensor( 54 | size=(batch_size, seq_length)).random_(0, vocab_size) 55 | loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() 56 | loss.backward() 57 | return loss, identity.weight.grad 58 | 59 | 60 | def test_cross_entropy(model_parallel_size): 61 | 62 | if torch.distributed.get_rank() == 0: 63 | print('> testing cross entropy with model parallel size {} ...'. 
64 | format(model_parallel_size)) 65 | 66 | mpu.initialize_model_parallel(model_parallel_size) 67 | model_parallel_size = mpu.get_model_parallel_world_size() 68 | 69 | batch_size = 13 70 | seq_length = 17 71 | vocab_size_per_partition = 11 72 | logits_scale = 1000.0 73 | vocab_size = vocab_size_per_partition * model_parallel_size 74 | seed = 1234 75 | 76 | loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, 77 | vocab_size, logits_scale, 78 | seed) 79 | loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, 80 | vocab_size, logits_scale, 81 | seed) 82 | 83 | error = loss_torch.sub_(loss_mpu).abs().max() 84 | print(' max error in loss on global rank {}: {}'.format( 85 | torch.distributed.get_rank(), error)) 86 | assert error < 1.0e-6 87 | 88 | error = grad_torch.sub_(grad_mpu).abs().max() 89 | print(' max error in grad on global rank {}: {}'.format( 90 | torch.distributed.get_rank(), error)) 91 | assert error < 1.0e-6 92 | 93 | # Reset groups 94 | mpu.destroy_model_parallel() 95 | 96 | torch.distributed.barrier() 97 | if torch.distributed.get_rank() == 0: 98 | print('>> passed the test :-)') 99 | 100 | 101 | if __name__ == '__main__': 102 | 103 | initialize_distributed() 104 | world_size = torch.distributed.get_world_size() 105 | 106 | model_parallel_size = 1 107 | while model_parallel_size <= world_size: 108 | print_separator('test cross entropy') 109 | test_cross_entropy(model_parallel_size) 110 | model_parallel_size *= 2 111 | -------------------------------------------------------------------------------- /mpu/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import get_model_parallel_group 19 | from .initialize import get_model_parallel_rank 20 | from .initialize import get_model_parallel_src_rank 21 | 22 | 23 | _MAX_DATA_DIM = 4 24 | 25 | 26 | def _check_data_types(keys, data, target_dtype): 27 | """Check that all the keys have the same target data type.""" 28 | for key in keys: 29 | assert data[key].dtype == target_dtype, '{} has data type {} which '\ 30 | 'is different than {}'.format(key, data[key].dtype, target_dtype) 31 | 32 | 33 | def _build_key_size_numel_dictionaries(keys, data): 34 | """Build the size on rank 0 and broadcast.""" 35 | max_dim = _MAX_DATA_DIM 36 | sizes = [0 for _ in range(max_dim) for _ in keys] 37 | 38 | # Pack the sizes on rank zero. 39 | if get_model_parallel_rank() == 0: 40 | offset = 0 41 | for key in keys: 42 | assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' 43 | size = data[key].size() 44 | for i, s in enumerate(size): 45 | sizes[i + offset] = s 46 | offset += max_dim 47 | 48 | # Move to GPU and broadcast. 
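# For example, with two keys whose tensors have shapes (7, 11) and (13,),
# rank 0 packs sizes = [7, 11, 0, 0, 13, 0, 0, 0]: each key occupies a fixed
# _MAX_DATA_DIM-wide slot padded with zeros, so after the broadcast every rank
# can recover the shapes by reading each slot until it hits a zero.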
49 | sizes_cuda = torch.cuda.LongTensor(sizes) 50 | torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(), 51 | group=get_model_parallel_group()) 52 | 53 | # Move back to cpu and unpack. 54 | sizes_cpu = sizes_cuda.cpu() 55 | key_size = {} 56 | key_numel = {} 57 | total_numel = 0 58 | offset = 0 59 | for key in keys: 60 | i = 0 61 | size = [] 62 | numel = 1 63 | while sizes_cpu[offset + i] > 0: 64 | this_size = sizes_cpu[offset + i] 65 | size.append(this_size) 66 | numel *= this_size 67 | i += 1 68 | key_size[key] = size 69 | key_numel[key] = numel 70 | total_numel += numel 71 | offset += max_dim 72 | 73 | return key_size, key_numel, total_numel 74 | 75 | 76 | def broadcast_data(keys, data, datatype): 77 | """Broadcast data from rank zero of each model parallel group to the 78 | members of the same model parallel group. 79 | 80 | Arguments: 81 | keys: list of keys in the data disctionary to be broadcasted 82 | data: data dictionary of string keys and cpu tensor values. 83 | datatype: torch data type of all tensors in data associated 84 | with keys. 85 | """ 86 | # Build (key, size) and (key, number of elements) dictionaries along 87 | # with the total number of elements on all ranks. 88 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, 89 | data) 90 | 91 | # Pack on rank zero. 92 | if get_model_parallel_rank() == 0: 93 | # Check that all keys have the same data type. 94 | _check_data_types(keys, data, datatype) 95 | # Flatten the data associated with the keys 96 | flatten_data = torch.cat( 97 | [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() 98 | else: 99 | flatten_data = torch.empty(total_numel, 100 | device=torch.cuda.current_device(), 101 | dtype=datatype) 102 | 103 | # Boradcast 104 | torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(), 105 | group=get_model_parallel_group()) 106 | 107 | # Unpack 108 | output = {} 109 | offset = 0 110 | for key in keys: 111 | size = key_size[key] 112 | numel = key_numel[key] 113 | output[key] = flatten_data.narrow(0, offset, numel).view(size) 114 | offset += numel 115 | 116 | return output 117 | -------------------------------------------------------------------------------- /mpu/mappings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import get_model_parallel_group 19 | from .utils import split_tensor_along_last_dim 20 | 21 | 22 | def _reduce(input_): 23 | """All-reduce the the input tensor across model parallel group.""" 24 | group = get_model_parallel_group() 25 | 26 | # Bypass the function if we are using only 1 GPU. 27 | if torch.distributed.get_world_size(group=group) == 1: 28 | return input_ 29 | 30 | # All-reduce. 
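# all_reduce operates in place: after this call every rank in the
# model-parallel group holds the elementwise sum of the per-rank partial
# results (for example, the partial matmul outputs of a row-parallel linear
# layer).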
31 | torch.distributed.all_reduce(input_, group=group) 32 | 33 | return input_ 34 | 35 | 36 | def _split(input_): 37 | """Split the tensor along its last dimension and keep the 38 | corresponding slice.""" 39 | group = get_model_parallel_group() 40 | 41 | # Bypass the function if we are using only 1 GPU. 42 | if torch.distributed.get_world_size(group=group) == 1: 43 | return input_ 44 | 45 | # Split along last dimension. 46 | world_size = torch.distributed.get_world_size(group=group) 47 | input_list = split_tensor_along_last_dim(input_, world_size) 48 | 49 | # Note: torch.split does not create contiguous tensors by default. 50 | rank = torch.distributed.get_rank(group=group) 51 | output = input_list[rank].contiguous() 52 | 53 | return output 54 | 55 | 56 | def _gather(input_): 57 | """Gather tensors and concatinate along the last dimension.""" 58 | group = get_model_parallel_group() 59 | 60 | # Bypass the function if we are using only 1 GPU. 61 | if torch.distributed.get_world_size(group=group) == 1: 62 | return input_ 63 | 64 | # Size and dimension. 65 | last_dim = input_.dim() - 1 66 | rank = torch.distributed.get_rank(group=group) 67 | world_size = torch.distributed.get_world_size(group=group) 68 | 69 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 70 | tensor_list[rank] = input_ 71 | torch.distributed.all_gather(tensor_list, input_, group=group) 72 | 73 | # Note: torch.cat already creates a contiguous tensor. 74 | output = torch.cat(tensor_list, dim=last_dim).contiguous() 75 | 76 | return output 77 | 78 | 79 | class _CopyToModelParallelRegion(torch.autograd.Function): 80 | """Pass the input to the model parallel region.""" 81 | 82 | @staticmethod 83 | def forward(ctx, input_): 84 | return input_ 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output): 88 | return _reduce(grad_output) 89 | 90 | 91 | class _ReduceFromModelParallelRegion(torch.autograd.Function): 92 | """All-redcue the input from the model parallel region.""" 93 | 94 | @staticmethod 95 | def forward(ctx, input_): 96 | return _reduce(input_) 97 | 98 | @staticmethod 99 | def backward(ctx, grad_output): 100 | return grad_output 101 | 102 | 103 | class _ScatterToModelParallelRegion(torch.autograd.Function): 104 | """Split the input and keep only the corresponding chuck to the rank.""" 105 | 106 | @staticmethod 107 | def forward(ctx, input_): 108 | return _split(input_) 109 | 110 | @staticmethod 111 | def backward(ctx, grad_output): 112 | return _gather(grad_output) 113 | 114 | 115 | class _GatherFromModelParallelRegion(torch.autograd.Function): 116 | """Gather the input from model parallel region and concatinate.""" 117 | 118 | @staticmethod 119 | def forward(ctx, input_): 120 | return _gather(input_) 121 | 122 | @staticmethod 123 | def backward(ctx, grad_output): 124 | return _split(grad_output) 125 | 126 | 127 | # ----------------- 128 | # Helper functions. 
129 | # ----------------- 130 | 131 | def copy_to_model_parallel_region(input_): 132 | return _CopyToModelParallelRegion.apply(input_) 133 | 134 | def reduce_from_model_parallel_region(input_): 135 | return _ReduceFromModelParallelRegion.apply(input_) 136 | 137 | def scatter_to_model_parallel_region(input_): 138 | return _ScatterToModelParallelRegion.apply(input_) 139 | 140 | def gather_from_model_parallel_region(input_): 141 | return _GatherFromModelParallelRegion.apply(input_) 142 | -------------------------------------------------------------------------------- /mpu/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | from .initialize import get_model_parallel_group 20 | from .initialize import get_model_parallel_rank 21 | from .initialize import get_model_parallel_world_size 22 | from .utils import VocabUtility 23 | 24 | 25 | class _VocabParallelCrossEntropy(torch.autograd.Function): 26 | 27 | @staticmethod 28 | def forward(ctx, vocab_parallel_logits, target): 29 | 30 | # Copy so the input remains unchanged. 31 | logits = vocab_parallel_logits.clone() 32 | # Maximum value along vocab dimension across all GPUs. 33 | logits_max = torch.max(logits, dim=-1)[0] 34 | torch.distributed.all_reduce(logits_max, 35 | op=torch.distributed.ReduceOp.MAX, 36 | group=get_model_parallel_group()) 37 | # Subtract the maximum value. 38 | logits.sub_(logits_max.unsqueeze(dim=-1)) 39 | # Sum of exponential of logits along vocab dimension across all GPUs. 40 | exp_logits = logits.exp() 41 | sum_exp_logits = exp_logits.sum(dim=-1) 42 | torch.distributed.all_reduce(sum_exp_logits, 43 | op=torch.distributed.ReduceOp.SUM, 44 | group=get_model_parallel_group()) 45 | 46 | # Get the partition's vocab indecies 47 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size 48 | partition_vocab_size = vocab_parallel_logits.size()[-1] 49 | rank = get_model_parallel_rank() 50 | world_size = get_model_parallel_world_size() 51 | vocab_start_index, vocab_end_index = get_vocab_range( 52 | partition_vocab_size, rank, world_size) 53 | 54 | # Create a mask of valid vocab ids (1 means it needs to be masked). 55 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index) 56 | masked_target = target.clone() - vocab_start_index 57 | masked_target[target_mask] = 0 58 | 59 | # Get predicted-logits = logits[target]. 60 | # For Simplicity, we convert logits to a 2-D tensor with size 61 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 
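# Worked example: with 2 model-parallel ranks and a global vocab of 8, rank 0
# owns columns [0, 4) and rank 1 owns [4, 8). For a target id of 5, rank 0
# masks it out and contributes 0, while rank 1 reads its local column
# 5 - 4 = 1; the SUM all-reduce below then gives every rank the true target
# logit.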
62 | logits_2d = logits.view(-1, partition_vocab_size) 63 | masked_target_1d = masked_target.view(-1) 64 | arange_1d = torch.arange(start=0, end=logits_2d.size()[0], 65 | device=logits_2d.device) 66 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] 67 | predicted_logits = predicted_logits_1d.view_as(target) 68 | predicted_logits[target_mask] = 0.0 69 | # All reduce is needed to get the chunks from other GPUs. 70 | torch.distributed.all_reduce(predicted_logits, 71 | op=torch.distributed.ReduceOp.SUM, 72 | group=get_model_parallel_group()) 73 | 74 | # Loss = log(sum(exp(logits))) - predicted-logit. 75 | loss = torch.log(sum_exp_logits) - predicted_logits 76 | 77 | # Store softmax, target-mask and masked-target for backward pass. 78 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) 79 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) 80 | 81 | return loss 82 | 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | 86 | # Retreive tensors from the forward path. 87 | softmax, target_mask, masked_target_1d = ctx.saved_tensors 88 | 89 | # All the inputs have softmax as thier gradient. 90 | grad_input = softmax 91 | # For simplicity, work with the 2D gradient. 92 | partition_vocab_size = softmax.size()[-1] 93 | grad_2d = grad_input.view(-1, partition_vocab_size) 94 | 95 | # Add the gradient from matching classes. 96 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0], 97 | device=grad_2d.device) 98 | grad_2d[arange_1d, masked_target_1d] -= ( 99 | 1.0 - target_mask.view(-1).float()) 100 | 101 | # Finally elementwise multiplication with the output gradients. 102 | grad_input.mul_(grad_output.unsqueeze(dim=-1)) 103 | 104 | return grad_input, None 105 | 106 | 107 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target): 108 | """Helper function for the cross entropy.""" 109 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) 110 | -------------------------------------------------------------------------------- /data/test/test_indexed_dataset.py: -------------------------------------------------------------------------------- 1 | # This file isn't really a formal automated test, it's just a place to 2 | # put some code used during development and manual testing of 3 | # indexed_dataset. 
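# A typical manual invocation might look like this (paths and file names are
# placeholders for a real indexed-dataset prefix and tokenizer files):
#
#     python test_indexed_dataset.py \
#         --data /path/to/corpus_text_document \
#         --dataset-impl mmap \
#         --tokenizer-type GPT2BPETokenizer \
#         --vocab-file gpt2-vocab.json \
#         --merge-file gpt2-merges.txt \
#         --count 5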
4 | 5 | from megatron.data import indexed_dataset 6 | from megatron.tokenizer import build_tokenizer 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import torch 12 | 13 | script_dir = os.path.dirname(os.path.realpath(__file__)) 14 | sys.path.append(os.path.join(script_dir, "../../../")) 15 | 16 | 17 | def test_indexed_dataset(args): 18 | ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) 19 | tokenizer = build_tokenizer(args) 20 | print(len(ds.doc_idx)) 21 | print(len(ds)) 22 | print(ds.doc_idx[-1]) 23 | if ds.supports_prefetch: 24 | # just prefetch the whole thing in test (so assume it is small) 25 | ds.prefetch(range(len(ds))) 26 | if args.count > len(ds.doc_idx) - 1: 27 | args.count = len(ds.doc_idx) - 1 28 | 29 | for i in range(args.count): 30 | start = ds.doc_idx[i] 31 | end = ds.doc_idx[i + 1] 32 | ids = ds[start:end] 33 | print(f"Document {i}:") 34 | print("--------------") 35 | for s in ids: 36 | assert len(s) > 0 37 | l = s.data.tolist() 38 | text = tokenizer.detokenize(l) 39 | print(text) 40 | print("---") 41 | 42 | 43 | def test_indexed_dataset_get(args): 44 | ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) 45 | tokenizer = build_tokenizer(args) 46 | size = ds.sizes[0] 47 | print(f"size: {size}") 48 | full = ds.get(0) 49 | print(full) 50 | # print(tokenizer.detokenize(full.data.tolist())) 51 | print("---") 52 | end = ds.get(0, offset=size - 10) 53 | print(end) 54 | # print(tokenizer.detokenize(end.data.tolist())) 55 | 56 | start = ds.get(0, length=10) 57 | print(start) 58 | # print(tokenizer.detokenize(start.data.tolist())) 59 | 60 | part = ds.get(0, offset=2, length=8) 61 | print(part) 62 | # print(tokenizer.detokenize(part.data.tolist())) 63 | 64 | # def test_albert_dataset(args): 65 | # # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) 66 | # # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) 67 | # # ds = AlbertDataset(idataset, tokenizer) 68 | # ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, 69 | # args.epochs, args.max_num_samples, 70 | # args.masked_lm_prob, args.seq_length, 71 | # args.short_seq_prob, args.seed) 72 | # truncated = 0 73 | # total = 0 74 | # for i, s in enumerate(ds): 75 | # ids = s['text'] 76 | # tokens = ds.tokenizer.convert_ids_to_tokens(ids) 77 | # print(tokens) 78 | # if i >= args.count-1: 79 | # exit() 80 | 81 | 82 | def main(): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--data', type=str, help='prefix to data files') 85 | parser.add_argument('--dataset-impl', type=str, default='infer', 86 | choices=['lazy', 'cached', 'mmap', 'infer']) 87 | parser.add_argument('--count', type=int, default=10, 88 | help='Number of samples/documents to print') 89 | 90 | group = parser.add_argument_group(title='tokenizer') 91 | group.add_argument('--tokenizer-type', type=str, required=True, 92 | choices=['BertWordPieceLowerCase', 93 | 'GPT2BPETokenizer'], 94 | help='What type of tokenizer to use.') 95 | group.add_argument('--vocab-file', type=str, default=None, 96 | help='Path to the vocab file') 97 | group.add_argument('--merge-file', type=str, default=None, 98 | help='Path to the BPE merge file (if necessary).') 99 | 100 | parser.add_argument('--epochs', type=int, default=5, 101 | help='Number of epochs to plan for') 102 | parser.add_argument('--max-num-samples', type=int, default=None, 103 | help='Maximum number of samples to plan for') 104 | parser.add_argument('--masked-lm-prob', type=float, default=0.15, 105 | help='probability of masking 
tokens') 106 | parser.add_argument('--seq-length', type=int, default=512, 107 | help='maximum sequence length') 108 | parser.add_argument('--short-seq-prob', type=float, default=0.1, 109 | help='probability of creating a short sequence') 110 | parser.add_argument('--seed', type=int, default=1234, 111 | help='random seed') 112 | args = parser.parse_args() 113 | args.rank = 0 114 | args.make_vocab_size_divisible_by = 128 115 | args.model_parallel_size = 1 116 | 117 | if args.dataset_impl == "infer": 118 | args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) 119 | 120 | # test_albert_dataset(args) 121 | test_indexed_dataset_get(args) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /model/gpt2_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """GPT-2 model.""" 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | import mpu 22 | 23 | 24 | def init_method_normal(std=0.02): 25 | """Init method based on normal distribution. 26 | 27 | This is only used for embeddings. The transformer has its 28 | own initializer. 29 | """ 30 | def init_(tensor): 31 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 32 | return init_ 33 | 34 | 35 | class GPT2Model(torch.nn.Module): 36 | """GPT-2 Language model. 37 | 38 | The output of the forward method are the logits (parallel or 39 | serial depending on the `parallel_output` flag. 40 | """ 41 | 42 | def __init__(self, 43 | num_layers, 44 | vocab_size, 45 | hidden_size, 46 | num_attention_heads, 47 | embedding_dropout_prob, 48 | attention_dropout_prob, 49 | output_dropout_prob, 50 | max_sequence_length, 51 | checkpoint_activations, 52 | checkpoint_num_layers=1, 53 | parallel_output=True): 54 | 55 | super(GPT2Model, self).__init__() 56 | 57 | self.parallel_output = parallel_output 58 | 59 | init_method = init_method_normal(std=0.02) 60 | 61 | # Word embeddings (parallel). 62 | self.word_embeddings = mpu.VocabParallelEmbedding( 63 | vocab_size, hidden_size, init_method=init_method) 64 | 65 | # Position embedding (serial). 66 | self.position_embeddings = torch.nn.Embedding(max_sequence_length, 67 | hidden_size) 68 | # Initialize the position embeddings. 69 | init_method(self.position_embeddings.weight) 70 | 71 | # Embeddings dropout 72 | self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) 73 | 74 | # Transformer 75 | self.transformer = mpu.GPT2ParallelTransformer(num_layers, 76 | hidden_size, 77 | num_attention_heads, 78 | attention_dropout_prob, 79 | output_dropout_prob, 80 | checkpoint_activations, 81 | checkpoint_num_layers) 82 | 83 | def forward(self, input_ids, position_ids, attention_mask, past_key_values=None, use_cache=False): 84 | 85 | # Embeddings. 
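        # input_ids and position_ids are [batch, seq]; both lookups below return
        # [batch, seq, hidden_size]. The word-embedding table is sharded along the
        # vocabulary dimension across model-parallel ranks (VocabParallelEmbedding),
        # while the position-embedding table is replicated on every rank.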
86 | words_embeddings = self.word_embeddings(input_ids) 87 | position_embeddings = self.position_embeddings(position_ids) 88 | embeddings = words_embeddings + position_embeddings 89 | 90 | # Dropout. 91 | embeddings = self.embedding_dropout(embeddings) 92 | 93 | # Transformer. 94 | transformer_output, presents = self.transformer(embeddings, attention_mask, past_key_values=past_key_values, use_cache=use_cache) 95 | 96 | # Parallel logits. 97 | transformer_output_parallel = mpu.copy_to_model_parallel_region( 98 | transformer_output) 99 | logits_parallel = F.linear(transformer_output_parallel, 100 | self.word_embeddings.weight) 101 | 102 | if self.parallel_output: 103 | return logits_parallel, presents 104 | 105 | return mpu.gather_from_model_parallel_region(logits_parallel), presents 106 | 107 | 108 | def gpt2_get_params_for_weight_decay_optimization(module): 109 | 110 | weight_decay_params = {'params': []} 111 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 112 | for module_ in module.modules(): 113 | if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): 114 | no_weight_decay_params['params'].extend( 115 | [p for p in list(module_._parameters.values()) 116 | if p is not None]) 117 | else: 118 | weight_decay_params['params'].extend( 119 | [p for n, p in list(module_._parameters.items()) 120 | if p is not None and n != 'bias']) 121 | no_weight_decay_params['params'].extend( 122 | [p for n, p in list(module_._parameters.items()) 123 | if p is not None and n == 'bias']) 124 | 125 | return weight_decay_params, no_weight_decay_params 126 | -------------------------------------------------------------------------------- /model/distributed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
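# A lightweight alternative to torch.nn.parallel.DistributedDataParallel:
# parameters are broadcast over the data-parallel group at construction time and
# gradients are averaged through an explicit, per-dtype bucketed all-reduce.
# Since the backward-hook registration below is commented out, the training loop
# is expected to call allreduce_params() itself, roughly as in this hypothetical
# sketch:
#
#     model = DistributedDataParallel(gpt2_model)
#     logits, _ = model(input_ids, position_ids, attention_mask)
#     loss = compute_loss(logits)          # placeholder loss function
#     loss.backward()
#     model.allreduce_params()             # average grads across data-parallel ranks
#     optimizer.step()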
15 | 16 | import torch 17 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 18 | import torch.distributed as dist 19 | from torch.nn.modules import Module 20 | from torch.autograd import Variable 21 | 22 | import mpu 23 | 24 | class DistributedDataParallel(Module): 25 | 26 | def __init__(self, module): 27 | super(DistributedDataParallel, self).__init__() 28 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 29 | 30 | self.module = module 31 | self.data_parallel_group = mpu.get_data_parallel_group() 32 | src_rank = mpu.get_model_parallel_rank() 33 | for p in self.module.parameters(): 34 | if torch.is_tensor(p): 35 | dist.broadcast(p, src_rank, group=self.data_parallel_group) 36 | 37 | def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): 38 | if(self.needs_reduction): 39 | self.needs_reduction = False 40 | buckets = {} 41 | for name, param in self.module.named_parameters(): 42 | if param.requires_grad and param.grad is not None: 43 | tp = (param.data.type()) 44 | if tp not in buckets: 45 | buckets[tp] = [] 46 | buckets[tp].append(param) 47 | if self.warn_on_half: 48 | if torch.cuda.HalfTensor in buckets: 49 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 50 | " It is recommended to use the NCCL backend in this case.") 51 | self.warn_on_half = False 52 | for tp in buckets: 53 | bucket = buckets[tp] 54 | grads = [param.grad.data for param in bucket] 55 | coalesced = _flatten_dense_tensors(grads) 56 | if fp32_allreduce: 57 | coalesced = coalesced.float() 58 | if not no_scale and not reduce_after: 59 | coalesced /= dist.get_world_size(group=self.data_parallel_group) 60 | dist.all_reduce(coalesced, group=self.data_parallel_group) 61 | torch.cuda.synchronize() 62 | if not no_scale and reduce_after: 63 | coalesced /= dist.get_world_size(group=self.data_parallel_group) 64 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 65 | buf.copy_(synced) 66 | self.hook_handles = [] 67 | self.hooks = [] 68 | for param in list(self.module.parameters()): 69 | def allreduce_hook(*unused): 70 | Variable._execution_engine.queue_callback(allreduce_params) 71 | # handle = param.register_hook(allreduce_hook) 72 | #self.hooks.append(allreduce_hook) 73 | #self.hook_handles.append(handle) 74 | self.allreduce_params = allreduce_params 75 | 76 | def forward(self, *inputs, **kwargs): 77 | self.needs_reduction = True 78 | return self.module(*inputs, **kwargs) 79 | 80 | def state_dict(self, destination=None, prefix='', keep_vars=False): 81 | #[h.remove() for h in self.hook_handles] 82 | sd = self.module.state_dict(destination, prefix, keep_vars) 83 | # for handle, hook in zip(self.hook_handles, self.hooks): 84 | # d = handle.hooks_dict_ref() 85 | # d[handle.id] = hook 86 | 87 | return sd 88 | 89 | def load_state_dict(self, state_dict, strict=True): 90 | self.module.load_state_dict(state_dict, strict=strict) 91 | 92 | ''' 93 | def _sync_buffers(self): 94 | buffers = list(self.module._all_buffers()) 95 | if len(buffers) > 0: 96 | # cross-node buffer sync 97 | flat_buffers = _flatten_dense_tensors(buffers) 98 | dist.broadcast(flat_buffers, 0) 99 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 100 | buf.copy_(synced) 101 | def train(self, mode=True): 102 | # Clear NCCL communicator and CUDA event cache of the default group ID, 103 | # These cache will be recreated at the later call. This is currently a 104 | # work-around for a potential NCCL deadlock. 
105 | if dist._backend == dist.dist_backend.NCCL: 106 | dist._clear_group_cache() 107 | super(DistributedDataParallel, self).train(mode) 108 | self.module.train(mode) 109 | ''' 110 | 111 | -------------------------------------------------------------------------------- /mpu/initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """Model and data parallel groups.""" 18 | 19 | import torch 20 | 21 | from .utils import ensure_divisibility 22 | 23 | 24 | # Model parallel group that the current rank belongs to. 25 | _MODEL_PARALLEL_GROUP = None 26 | # Data parallel group that the current rank belongs to. 27 | _DATA_PARALLEL_GROUP = None 28 | 29 | 30 | def initialize_model_parallel(model_parallel_size_): 31 | """ 32 | Initialize model data parallel groups. 33 | 34 | Arguments: 35 | model_parallel_size: number of GPUs used to parallelize model. 36 | 37 | Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we 38 | use 2 GPUs to parallelize the model. The present function will 39 | create 4 model parallel groups and 2 data parallel grous as: 40 | 4 model parallel groups: 41 | [g0, g1], [g2, g3], [g4, g5], [g6, g7] 42 | 2 data parallel groups: 43 | [g0, g2, g4, g6], [g1, g3, g5, g7] 44 | Note that for efficiency, the caller should make sure adjacent ranks 45 | are on the same DGX box. For example if we are using 2 DGX-1 boxes 46 | with a total of 16 GPUs, rank 0 to 7 belong to the first box and 47 | ranks 8 to 15 belong to the second box. 48 | """ 49 | if torch.distributed.get_rank() == 0: 50 | print('> initializing model parallel with size {}'.format( 51 | model_parallel_size_)) 52 | # Get world size and rank. Ensure some consistencies. 53 | assert torch.distributed.is_initialized() 54 | world_size = torch.distributed.get_world_size() 55 | model_parallel_size = min(model_parallel_size_, world_size) 56 | ensure_divisibility(world_size, model_parallel_size) 57 | rank = torch.distributed.get_rank() 58 | 59 | # Build the data parallel groups. 60 | global _DATA_PARALLEL_GROUP 61 | assert _DATA_PARALLEL_GROUP is None, \ 62 | 'data parallel group is already initialized' 63 | for i in range(model_parallel_size): 64 | ranks = range(i, world_size, model_parallel_size) 65 | group = torch.distributed.new_group(ranks) 66 | if i == (rank % model_parallel_size): 67 | _DATA_PARALLEL_GROUP = group 68 | 69 | # Build the model parallel groups. 
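    # Worked example: with world_size = 8 and model_parallel_size = 2, the loop
    # above creates data-parallel groups [0, 2, 4, 6] and [1, 3, 5, 7], and the
    # loop below creates model-parallel groups [0, 1], [2, 3], [4, 5], [6, 7].
    # Rank 5, for instance, keeps the data group [1, 3, 5, 7] (since 5 % 2 == 1)
    # and the model group [4, 5] (since 5 // 2 == 2).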
70 | global _MODEL_PARALLEL_GROUP 71 | assert _MODEL_PARALLEL_GROUP is None, \ 72 | 'model parallel group is already initialized' 73 | for i in range(world_size // model_parallel_size): 74 | ranks = range(i * model_parallel_size, 75 | (i + 1) * model_parallel_size) 76 | group = torch.distributed.new_group(ranks) 77 | if i == (rank // model_parallel_size): 78 | _MODEL_PARALLEL_GROUP = group 79 | 80 | 81 | def model_parallel_is_initialized(): 82 | """Check if model and data parallel groups are initialized.""" 83 | if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: 84 | return False 85 | return True 86 | 87 | 88 | def get_model_parallel_group(): 89 | """Get the model parallel group the caller rank belongs to.""" 90 | assert _MODEL_PARALLEL_GROUP is not None, \ 91 | 'model parallel group is not initialized' 92 | return _MODEL_PARALLEL_GROUP 93 | 94 | 95 | def get_data_parallel_group(): 96 | """Get the data parallel group the caller rank belongs to.""" 97 | assert _DATA_PARALLEL_GROUP is not None, \ 98 | 'data parallel group is not initialized' 99 | return _DATA_PARALLEL_GROUP 100 | 101 | 102 | def get_model_parallel_world_size(): 103 | """Return world size for the model parallel group.""" 104 | return torch.distributed.get_world_size(group=get_model_parallel_group()) 105 | 106 | 107 | def get_model_parallel_rank(): 108 | """Return my rank for the model parallel group.""" 109 | return torch.distributed.get_rank(group=get_model_parallel_group()) 110 | 111 | 112 | def get_model_parallel_src_rank(): 113 | """Calculate the global rank corresponding to a local rank zeor 114 | in the model parallel group.""" 115 | global_rank = torch.distributed.get_rank() 116 | local_world_size = get_model_parallel_world_size() 117 | return (global_rank // local_world_size) * local_world_size 118 | 119 | 120 | def get_data_parallel_world_size(): 121 | """Return world size for the data parallel group.""" 122 | return torch.distributed.get_world_size(group=get_data_parallel_group()) 123 | 124 | 125 | def get_data_parallel_rank(): 126 | """Return my rank for the data parallel group.""" 127 | return torch.distributed.get_rank(group=get_data_parallel_group()) 128 | 129 | 130 | def destroy_model_parallel(): 131 | """Set the groups to none.""" 132 | global _MODEL_PARALLEL_GROUP 133 | _MODEL_PARALLEL_GROUP = None 134 | global _DATA_PARALLEL_GROUP 135 | _DATA_PARALLEL_GROUP = None 136 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
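# A hypothetical end-to-end use of make_dataset() defined below (file names and
# split proportions are placeholders, and the GPT2 tokenizer settings are only
# one possible configuration):
#
#     ds, tokenizer = make_dataset(
#         path=['my_corpus.json'], seq_length=1024, text_key='text',
#         label_key='label', lazy=True, split=[0.98, 0.01, 0.01],
#         tokenizer_type='GPT2BPETokenizer', ds_type='gpt2')
#
# Because that split is not a single partition, should_split() returns True and
# make_dataset returns a [train, val, test] list of GPT2Dataset objects along
# with the tokenizer.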
15 | """utils for creating datasets""" 16 | import os 17 | import math 18 | 19 | from .samplers import DistributedBatchSampler 20 | from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset 21 | from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader 22 | from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer 23 | from . import corpora 24 | 25 | TRAIN_DATA = 0 26 | VAL_DATA = 1 27 | TEST_DATA = 2 28 | 29 | def should_split(split): 30 | """ 31 | given split proportions checks if should split 32 | Examples: 33 | >>> should_split([10,0,0]) 34 | False 35 | >>> should_split([1,.1,.2]) 36 | True 37 | """ 38 | return max(split)/sum(split) != 1. 39 | 40 | def get_ext(path): 41 | """gets path extension""" 42 | return os.path.splitext(path)[1] 43 | 44 | def get_dataset(path, **kwargs): 45 | """gets dataset object based on keyword args and file at `path`""" 46 | if supported_corpus(path): 47 | return corpora.NAMED_CORPORA[path](**kwargs) 48 | ext = get_ext(path) 49 | if '.json' in ext: 50 | text = json_dataset(path, **kwargs) 51 | elif ext in ['.csv', '.tsv']: 52 | text = csv_dataset(path, **kwargs) 53 | else: 54 | raise NotImplementedError('data file type %s is not supported'%(ext)) 55 | return text 56 | 57 | def supported_corpus(corpus_name): 58 | """checks if corpus name is defined in `corpora.py`""" 59 | return corpus_name in corpora.NAMED_CORPORA 60 | 61 | def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.], 62 | delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None, 63 | tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None, 64 | model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None, **kwargs): 65 | """function to create datasets+tokenizers for common options""" 66 | if isinstance(process_fn, str): 67 | process_fn = eval(process_fn) 68 | if non_binary_cols is not None: 69 | # multilabel dataset support (only for csvs) 70 | label_key = non_binary_cols 71 | def get_dataset_from_path(path_): 72 | if lazy: 73 | # get lazily loaded dataset 74 | named_corpora = False 75 | if supported_corpus(path_): 76 | named_corpora = True 77 | name = path_ 78 | path_ = corpora.NAMED_CORPORA[path_].PATH 79 | if not exists_lazy(path_, data_type='data'): 80 | # create cached version of dataset for lazy loading if it doesn't exist 81 | text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, 82 | delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose) 83 | make_lazy(path_, text.X, data_type='data') 84 | text = lazy_array_loader(path_, data_type='data', map_fn=process_fn) 85 | else: 86 | # get dataset 87 | text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, 88 | delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn) 89 | return text 90 | # get one or multiple datasets and concatenate 91 | if isinstance(path, str): 92 | path = [path] 93 | datasets = [get_dataset_from_path(p) for p in path] 94 | if len(datasets) == 1: 95 | ds = datasets[0] 96 | else: 97 | ds = ConcatDataset(datasets) 98 | # make tokenizer for dataset 99 | if tokenizer is None: 100 | tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type, 101 | pad_token, character_converage, 
**kwargs) 102 | 103 | ds_type = '' 104 | if 'ds_type' in kwargs: 105 | ds_type = kwargs['ds_type'] 106 | ds.SetTokenizer(tokenizer) 107 | # Split dataset into train/val/test (and wrap bert dataset) 108 | if should_split(split): 109 | ds = split_ds(ds, split) 110 | if ds_type.lower() == 'bert': 111 | presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False 112 | ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds] 113 | elif ds_type.lower() == 'gpt2': 114 | ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds] 115 | else: 116 | if ds_type.lower() == 'bert': 117 | presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False 118 | ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences) 119 | elif ds_type.lower() == 'gpt2': 120 | ds = GPT2Dataset(ds, max_seq_len=seq_length) 121 | return ds, tokenizer 122 | -------------------------------------------------------------------------------- /data_utils/tf_dl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
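# Note: this loader relies on TensorFlow 1.x behaviour (tf.enable_eager_execution,
# tf.contrib.data) and will not run unmodified under TensorFlow 2.x. A hypothetical
# construction for BERT-style pretraining records might look like:
#
#     dl = TFRecordDataLoader(['train_000.tfrecord', 'train_001.tfrecord'],
#                             batch_size=32, max_seq_len=512,
#                             max_preds_per_seq=80, train=True)
#     for batch in dl:
#         # batch is a dict of torch tensors: 'text', 'types', 'is_random',
#         # 'pad_mask', 'mask', 'mask_labels'
#         ...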
15 | """PyTorch DataLoader for TFRecords""" 16 | 17 | import queue 18 | import threading 19 | 20 | import tensorflow as tf 21 | tf.enable_eager_execution() 22 | import torch 23 | import numpy as np 24 | 25 | class TFRecordDataLoader(object): 26 | def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False): 27 | assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords" 28 | tf.set_random_seed(seed) 29 | if isinstance(records, str): 30 | records = [records] 31 | 32 | self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64), 33 | "input_mask": tf.FixedLenFeature([max_seq_len], tf.int64), 34 | "segment_ids": tf.FixedLenFeature([max_seq_len], tf.int64), 35 | "masked_lm_positions": tf.FixedLenFeature([max_preds_per_seq], tf.int64), 36 | "masked_lm_ids": tf.FixedLenFeature([max_preds_per_seq], tf.int64), 37 | "masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32), 38 | "next_sentence_labels": tf.FixedLenFeature([1], tf.int64)}) 39 | 40 | #Instantiate dataset according to original BERT implementation 41 | if train: 42 | self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records)) 43 | self.dataset = self.dataset.repeat() 44 | self.dataset = self.dataset.shuffle(buffer_size=len(records)) 45 | 46 | # use sloppy tfrecord dataset 47 | self.dataset = self.dataset.apply( 48 | tf.contrib.data.parallel_interleave( 49 | tf.data.TFRecordDataset, 50 | sloppy=train, 51 | cycle_length=min(num_workers, len(records)))) 52 | self.dataset = self.dataset.shuffle(buffer_size=100) 53 | else: 54 | self.dataset = tf.data.TFRecordDataset(records) 55 | self.dataset = self.dataset.repeat() 56 | 57 | # Instantiate dataloader (do not drop remainder for eval) 58 | loader_args = {'batch_size': batch_size, 59 | 'num_parallel_batches': num_workers, 60 | 'drop_remainder': train} 61 | self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args)) 62 | self.threaded_dl = threaded_dl 63 | self.num_workers = num_workers 64 | 65 | def __iter__(self): 66 | if self.threaded_dl: 67 | data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers)) 68 | for item in data_iter: 69 | yield item 70 | else: 71 | data_iter = iter(self.dataloader) 72 | for item in data_iter: 73 | yield convert_tf_example_to_torch_tensors(item) 74 | 75 | class Record2Example(object): 76 | def __init__(self, feature_map): 77 | self.feature_map = feature_map 78 | 79 | def __call__(self, record): 80 | """Decodes a BERT TF record to a TF example.""" 81 | example = tf.parse_single_example(record, self.feature_map) 82 | for k, v in list(example.items()): 83 | if v.dtype == tf.int64: 84 | example[k] = tf.to_int32(v) 85 | return example 86 | 87 | def convert_tf_example_to_torch_tensors(example): 88 | item = {k: (v.numpy()) for k,v in example.items()} 89 | mask = np.zeros_like(item['input_ids']) 90 | mask_labels = np.ones_like(item['input_ids'])*-1 91 | for b, row in enumerate(item['masked_lm_positions'].astype(int)): 92 | for i, idx in enumerate(row): 93 | if item['masked_lm_weights'][b, i] != 0: 94 | mask[b, idx] = 1 95 | mask_labels[b, idx] = item['masked_lm_ids'][b, i] 96 | output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'], 97 | 'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels} 98 | return {k: torch.from_numpy(v) for k,v in output.items()} 99 | 100 | class 
MultiprocessLoader(object): 101 | def __init__(self, dataloader, num_workers=2): 102 | self.dl = dataloader 103 | self.queue_size = 2*num_workers 104 | 105 | def __iter__(self): 106 | output_queue = queue.Queue(self.queue_size) 107 | output_thread = threading.Thread(target=_multiproc_iter, 108 | args=(self.dl, output_queue)) 109 | output_thread.daemon = True 110 | output_thread.start() 111 | 112 | while output_thread.is_alive(): 113 | yield output_queue.get(block=True) 114 | else: 115 | print(RuntimeError('TF record data loader thread exited unexpectedly')) 116 | 117 | def _multiproc_iter(dl, output_queue): 118 | data_iter = iter(dl) 119 | for item in data_iter: 120 | tensors = convert_tf_example_to_torch_tensors(item) 121 | output_queue.put(tensors, block=True) -------------------------------------------------------------------------------- /change_mp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | import copy 5 | 6 | checkpoint = sys.argv[1] 7 | target_mp = int(sys.argv[2]) 8 | 9 | assert os.path.isdir(checkpoint) 10 | with open(os.path.join(checkpoint, 'latest_checkpointed_iteration.txt')) as fin: 11 | iteration = int(fin.read().strip()) 12 | 13 | checkpoint = os.path.join(checkpoint, str(iteration)) 14 | 15 | filenames = os.listdir(checkpoint) 16 | filenames = sorted(filenames, 17 | key=lambda x: int(x.split('_')[2])) 18 | filenames = [os.path.join(checkpoint, x) for x in filenames] 19 | 20 | if target_mp == len(filenames): 21 | print("MP size keeps the same.") 22 | exit(0) 23 | 24 | if sys.argv[1][-1] == '/': 25 | new_checkpoint = sys.argv[1][:-1] + '_MP' + sys.argv[2] 26 | else: 27 | new_checkpoint = sys.argv[1] + '_MP' + sys.argv[2] 28 | if not os.path.exists(new_checkpoint): 29 | os.mkdir(new_checkpoint) 30 | with open(os.path.join(new_checkpoint, 'latest_checkpointed_iteration.txt'), 'w') as fout: 31 | fout.write("{}\n".format(iteration)) 32 | new_checkpoint = os.path.join(new_checkpoint, str(iteration)) 33 | if not os.path.exists(new_checkpoint): 34 | os.mkdir(new_checkpoint) 35 | 36 | preserve_keys = [ 37 | "lr_scheduler", 38 | "skipped_steps", 39 | "global_steps", 40 | "global_samples", 41 | "dp_world_size", 42 | "iteration", 43 | "np_rng_state", 44 | "random_rng_state", 45 | "torch_rng_state", 46 | "cuda_rng_state", 47 | "rng_tracker_states", 48 | 49 | ] 50 | 51 | if target_mp < len(filenames): 52 | print("Decrease MP size.") 53 | assert len(filenames) % target_mp == 0 54 | ratio = len(filenames) // target_mp 55 | for i in range(target_mp): 56 | start = ratio * i 57 | end = ratio * (i+1) 58 | d = torch.load(filenames[start], 59 | map_location='cpu') 60 | for k in d.keys(): 61 | if k !='module': 62 | if k in preserve_keys: 63 | pass 64 | elif k == "mp_world_size": 65 | d[k] = target_mp 66 | else: 67 | d[k] = None 68 | for j in range(start+1, end): 69 | d_new = torch.load(filenames[j], 70 | map_location='cpu') 71 | for k, v in d_new['module'].items(): 72 | assert len(v.shape) < 3 73 | if len(v.shape) == 2 and 'position' not in k: 74 | if 'query' in k: 75 | size_1 = d['module'][k].shape[0] // 3 76 | size_2 = v.shape[0] // 3 77 | target = d['module'][k] 78 | d['module'][k] = torch.cat([ 79 | target[:size_1, :], v[:size_2, :], 80 | target[size_1:size_1*2, :], v[size_2:size_2*2, :], 81 | target[size_1*2:, :], v[size_2*2:, :]], 0) 82 | elif 'word' in k or 'h_to_4h' in k: 83 | d['module'][k] = torch.cat([d['module'][k], v], 0) 84 | else: 85 | d['module'][k] = torch.cat([d['module'][k], v], 1) 86 | 
if len(v.shape) == 1 and 'query_key_value' in k: 87 | size_1 = d['module'][k].shape[0] // 3 88 | size_2 = v.shape[0] // 3 89 | target = d['module'][k] 90 | d['module'][k] = torch.cat([ 91 | target[:size_1], v[:size_2], 92 | target[size_1:size_1*2], v[size_2:size_2*2], 93 | target[size_1*2:], v[size_2*2:]], 0) 94 | 95 | if len(v.shape) == 1 and 'dense_h_to_4h' in k: 96 | d['module'][k] = torch.cat([d['module'][k], v], 0) 97 | filename = os.path.join(new_checkpoint, "mp_rank_{:02d}_model_states.pt".format(i)) 98 | torch.save(d, filename) 99 | 100 | if target_mp > len(filenames): 101 | print("Increase MP size.") 102 | assert target_mp % len(filenames) == 0 103 | ratio = target_mp // len(filenames) 104 | for i in range(len(filenames)): 105 | start = ratio * i 106 | end = ratio * (i+1) 107 | d = torch.load(filenames[i], 108 | map_location='cpu') 109 | for j in range(start, end): 110 | d_new = {} 111 | shift = j - start 112 | for k, v in d.items(): 113 | if k != 'module': 114 | if k in preserve_keys: 115 | d_new[k] = copy.deepcopy(d[k]) 116 | elif k == "mp_world_size": 117 | d_new[k] = target_mp 118 | else: 119 | d_new[k] = None 120 | d_new['module'] = {} 121 | for k, v in d['module'].items(): 122 | assert len(v.shape) < 3 123 | if len(v.shape) == 2 and 'position' not in k: 124 | if 'query' in k: 125 | part = v.shape[0] // ratio // 3 126 | d_new['module'][k] = torch.cat([v[shift*part:(shift+1)*part, :], v[(shift+ratio)*part:(shift+1+ratio)*part, :], v[(shift+2*ratio)*part:(shift+1+2*ratio)*part, :]], 0) 127 | elif 'word' in k or 'h_to_4h' in k: 128 | part = v.shape[0] // ratio 129 | d_new['module'][k] = v[shift*part:(shift+1)*part, :] 130 | else: 131 | part = v.shape[1] // ratio 132 | d_new['module'][k] = v[:, shift*part:(shift+1)*part] 133 | elif len(v.shape) == 1 and 'dense_h_to_4h' in k: 134 | part = v.shape[0] // ratio 135 | d_new['module'][k] = v[shift*part:(shift+1)*part] 136 | elif len(v.shape) == 1 and 'query_key_value' in k: 137 | part = v.shape[0] // ratio // 3 138 | d_new['module'][k] = torch.cat([v[shift*part:(shift+1)*part], v[(shift+ratio)*part:(shift+1+ratio)*part], v[(shift+2*ratio)*part:(shift+1+2*ratio)*part]], 0) 139 | else: 140 | d_new['module'][k] = v 141 | filename = os.path.join(new_checkpoint, "mp_rank_{:02d}_model_states.pt".format(j)) 142 | torch.save(d_new, filename) 143 | -------------------------------------------------------------------------------- /data_utils/samplers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """batch samplers that work with either random or sequential data samplers""" 16 | import math 17 | import os 18 | import sys 19 | 20 | import torch 21 | from torch.utils import data 22 | import numpy as np 23 | 24 | class RandomSampler(data.sampler.Sampler): 25 | r""" 26 | Based off of pytorch RandomSampler and DistributedSampler. 
Essentially a RandomSampler, 27 | but this class lets the user set an epoch like DistributedSampler 28 | Samples elements randomly. If without replacement, then sample from a shuffled dataset. 29 | If with replacement, then user can specify ``num_samples`` to draw. 30 | Arguments: 31 | data_source (Dataset): dataset to sample from 32 | num_samples (int): number of samples to draw, default=len(dataset) 33 | replacement (bool): samples are drawn with replacement if ``True``, default=False 34 | """ 35 | 36 | def __init__(self, data_source, replacement=False, num_samples=None): 37 | self.data_source = data_source 38 | self.replacement = replacement 39 | self._num_samples = num_samples 40 | self.epoch = -1 41 | 42 | if self._num_samples is not None and replacement is False: 43 | raise ValueError("With replacement=False, num_samples should not be specified, " 44 | "since a random permute will be performed.") 45 | 46 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 47 | raise ValueError("num_samples should be a positive integer " 48 | "value, but got num_samples={}".format(self.num_samples)) 49 | if not isinstance(self.replacement, bool): 50 | raise ValueError("replacement should be a boolean value, but got " 51 | "replacement={}".format(self.replacement)) 52 | 53 | @property 54 | def num_samples(self): 55 | # dataset size might change at runtime 56 | if self._num_samples is None: 57 | return len(self.data_source) 58 | return self._num_samples 59 | 60 | def __iter__(self): 61 | n = len(self.data_source) 62 | g = torch.Generator() 63 | if self.epoch >= 0: 64 | g.manual_seed(self.epoch) 65 | if self.replacement: 66 | return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist()) 67 | return iter(torch.randperm(n, generator=g).tolist()) 68 | 69 | def __len__(self): 70 | return self.num_samples 71 | 72 | def set_epoch(self, epoch): 73 | self.epoch = epoch 74 | 75 | class DistributedBatchSampler(data.sampler.BatchSampler): 76 | """ 77 | similar to normal implementation of distributed sampler, except implementation is at the 78 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary 79 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. 
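    For example, with a global batch [0, 1, 2, 3, 4, 5, 6, 7] and world_size=2,
    _batch() hands rank 0 the contiguous slice [0, 1, 2, 3] and rank 1 the slice
    [4, 5, 6, 7].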
80 | """ 81 | def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False): 82 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) 83 | if rank == -1: 84 | assert False, 'should not be here' 85 | rank = torch.distributed.get_rank() 86 | self.rank = rank 87 | self.world_size = world_size 88 | self.sampler.wrap_around = 0 89 | self.wrap_around = 0 90 | self.wrap_last = wrap_last 91 | self.start_iter = 0 92 | 93 | def __iter__(self): 94 | batch = [] 95 | last_batch = None 96 | i = 0 97 | for idx in self.data_iterator(self.sampler, wrap_around=False): 98 | batch.append(idx) 99 | if len(batch) == self.batch_size: 100 | tbatch = self._batch(batch) 101 | if i >= self.start_iter: 102 | yield tbatch 103 | self.start_iter = 0 104 | i += 1 105 | last_batch = np.array(list(tbatch)) 106 | batch = [] 107 | batch_len = len(batch) 108 | if batch_len > 0 and not self.drop_last: 109 | if self.wrap_last: 110 | self.sampler.wrap_around -= (self.batch_size) 111 | self.wrap_around += (len(batch)) 112 | self.wrap_around %= self.batch_size 113 | if isinstance(self.sampler, TransposedSampler): 114 | for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)): 115 | if i == 0: 116 | continue 117 | batch.append(idx) 118 | new_batch_len = len(batch) 119 | if len(batch) == self.batch_size: 120 | break 121 | yield self._batch(batch) 122 | if self.wrap_last: 123 | self.sampler.wrap_around += self.batch_size 124 | 125 | def data_iterator(self, _iter, wrap_around=False): 126 | """iterates through data and handles wrap around""" 127 | for i, idx in enumerate(_iter): 128 | if i < self.wrap_around%self.batch_size: 129 | continue 130 | if wrap_around: 131 | self.wrap_around += 1 132 | self.wrap_around %= self.batch_size 133 | yield idx 134 | 135 | def _batch(self, batch): 136 | """extracts samples only pertaining to this worker's batch""" 137 | start = self.rank*self.batch_size//self.world_size 138 | end = (self.rank+1)*self.batch_size//self.world_size 139 | return batch[start:end] 140 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Batch samplers that work with either random or sequential data samplers.""" 17 | 18 | import torch 19 | from torch.utils import data 20 | 21 | 22 | class RandomSampler(data.sampler.Sampler): 23 | """Based off of pytorch RandomSampler and DistributedSampler. Essentially 24 | a RandomSampler, but this class lets the user set an epoch like 25 | DistributedSampler Samples elements randomly. If without replacement, then 26 | sample from a shuffled dataset. If with replacement, then user can 27 | specify ``num_samples`` to draw. 
28 | Arguments: 29 | data_source (Dataset): dataset to sample from 30 | num_samples (int): number of samples to draw, default=len(dataset) 31 | replacement (bool): samples are drawn with replacement if ``True``, 32 | default=False 33 | """ 34 | 35 | def __init__(self, data_source, replacement=False, num_samples=None): 36 | self.data_source = data_source 37 | self.replacement = replacement 38 | self._num_samples = num_samples 39 | self.epoch = -1 40 | 41 | if self._num_samples is not None and replacement is False: 42 | raise ValueError("With replacement=False, num_samples should not " 43 | "be specified, since a random permute will be " 44 | "performed.") 45 | 46 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 47 | raise ValueError("num_samples should be a positive integer " 48 | "value, but got num_samples={}".format( 49 | self.num_samples)) 50 | if not isinstance(self.replacement, bool): 51 | raise ValueError("replacement should be a boolean value, but got " 52 | "replacement={}".format(self.replacement)) 53 | 54 | @property 55 | def num_samples(self): 56 | # dataset size might change at runtime 57 | if self._num_samples is None: 58 | return len(self.data_source) 59 | return self._num_samples 60 | 61 | def __iter__(self): 62 | n = len(self.data_source) 63 | g = torch.Generator() 64 | if self.epoch >= 0: 65 | g.manual_seed(self.epoch) 66 | if self.replacement: 67 | return iter(torch.randint(high=n, size=(self.num_samples,), 68 | dtype=torch.int64, generator=g).tolist()) 69 | return iter(torch.randperm(n, generator=g).tolist()) 70 | 71 | def __len__(self): 72 | return self.num_samples 73 | 74 | def set_epoch(self, epoch): 75 | self.epoch = epoch 76 | 77 | 78 | class DistributedBatchSampler(data.sampler.BatchSampler): 79 | """Similar to normal implementation of distributed sampler, except 80 | implementation is at the batch sampler level, instead of just the 81 | sampler level. This allows wrapping of arbitrary data samplers 82 | (sequential, random, WeightedRandomSampler, etc.) with this batch 83 | sampler. 84 | 85 | The `interleave` argument specifies how to distribute a batch. A value 86 | of True combined with the above random sampler is equivalent to pytorch's 87 | torch.utils.data.distributed.DistributedSampler. 
88 | 89 | For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2 90 | specifying True will result in the following samples for each gpu: 91 | GPU0: [0,2,4,6] GPU1: [1,3,5,7] 92 | specifying False will result in the following samples: 93 | GPU0: [0,1,2,3] GPU1: [4,5,6,7]""" 94 | 95 | def __init__(self, sampler, batch_size, drop_last, rank=-1, 96 | world_size=2, wrap_last=False, interleave=False): 97 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, 98 | drop_last) 99 | if rank == -1: 100 | assert False, 'should not be here' 101 | rank = torch.distributed.get_rank() 102 | self.rank = rank 103 | self.world_size = world_size 104 | self.sampler.wrap_around = 0 105 | self.wrap_around = 0 106 | self.wrap_last = wrap_last 107 | self.start_iter = 0 108 | self.interleave = interleave 109 | 110 | def __iter__(self): 111 | batch = [] 112 | i = 0 113 | for idx in self.data_iterator(self.sampler, wrap_around=False): 114 | batch.append(idx) 115 | if len(batch) == self.batch_size: 116 | tbatch = self._batch(batch) 117 | if i >= self.start_iter: 118 | yield tbatch 119 | self.start_iter = 0 120 | i += 1 121 | batch = [] 122 | batch_len = len(batch) 123 | if batch_len > 0 and not self.drop_last: 124 | if self.wrap_last: 125 | self.sampler.wrap_around -= (self.batch_size) 126 | self.wrap_around += (len(batch)) 127 | self.wrap_around %= self.batch_size 128 | yield self._batch(batch) 129 | if self.wrap_last: 130 | self.sampler.wrap_around += self.batch_size 131 | 132 | def data_iterator(self, _iter, wrap_around=False): 133 | """iterates through data and handles wrap around""" 134 | for i, idx in enumerate(_iter): 135 | if i < self.wrap_around % self.batch_size: 136 | continue 137 | if wrap_around: 138 | self.wrap_around += 1 139 | self.wrap_around %= self.batch_size 140 | yield idx 141 | 142 | def _batch(self, batch): 143 | """extracts samples only pertaining to this worker's batch""" 144 | if self.interleave: 145 | return batch[self.rank:self.batch_size:self.world_size] 146 | start = self.rank * self.batch_size // self.world_size 147 | end = (self.rank + 1) * self.batch_size // self.world_size 148 | return batch[start:end] 149 | -------------------------------------------------------------------------------- /mpu/tests/test_random.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import sys 17 | sys.path.append("../..") 18 | 19 | import torch 20 | import mpu 21 | 22 | from commons import initialize_distributed 23 | from commons import print_separator 24 | 25 | 26 | def test_set_cuda_rng_state(model_parallel_size): 27 | 28 | if torch.distributed.get_rank() == 0: 29 | print('> testing set_rng_state with size {} ...'. 
30 | format(model_parallel_size)) 31 | 32 | mpu.initialize_model_parallel(model_parallel_size) 33 | model_parallel_size = mpu.get_model_parallel_world_size() 34 | 35 | size = 123 36 | seed = 1234 37 | torch.cuda.manual_seed(1234) 38 | tensor = torch.cuda.FloatTensor(size) 39 | 40 | # Get the state 41 | rng_state = torch.cuda.get_rng_state() 42 | rng_state_copy = rng_state.clone() 43 | 44 | # Do some stuff. 45 | for _ in range(5): 46 | torch.randn(size, out=tensor) 47 | result_1 = tensor.clone() 48 | 49 | assert rng_state.sub(rng_state_copy).max() == 0 50 | assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 51 | 52 | # State should be different. 53 | new_rng_state = torch.cuda.get_rng_state() 54 | max_diff = new_rng_state.sub(rng_state).max() 55 | print(' max diff in rng state (should be non-zero) on global rank {}: {}'. 56 | format(torch.distributed.get_rank(), max_diff)) 57 | assert max_diff > 0 58 | 59 | # Reset the rng state and do the same stuff. 60 | mpu.random._set_cuda_rng_state(rng_state) 61 | for _ in range(5): 62 | torch.randn(size, out=tensor) 63 | mpu.random._set_cuda_rng_state(rng_state) 64 | for _ in range(5): 65 | torch.randn(size, out=tensor) 66 | result_2 = tensor.clone() 67 | 68 | # Results should be the same 69 | error = result_2.sub(result_1).abs().max() 70 | print(' max error in generated tensors (should be zero) on ' 71 | 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) 72 | assert error < 1.0e-6 73 | 74 | # Input state should have remained intact. 75 | error = rng_state.sub(rng_state_copy).max() 76 | print(' max error in rng state (should be zero) on global rank {}: {}'. 77 | format(torch.distributed.get_rank(), error)) 78 | assert error == 0 79 | 80 | # Reset groups 81 | mpu.destroy_model_parallel() 82 | 83 | torch.distributed.barrier() 84 | if torch.distributed.get_rank() == 0: 85 | print('>> passed the test :-)') 86 | 87 | 88 | def test_cuda_rng_tracker(model_parallel_size): 89 | 90 | if torch.distributed.get_rank() == 0: 91 | print('> testing cuda rng tracker with size {} ...'. 92 | format(model_parallel_size)) 93 | 94 | mpu.initialize_model_parallel(model_parallel_size) 95 | model_parallel_size = mpu.get_model_parallel_world_size() 96 | 97 | seed_1 = 1234 98 | seed_2 = 4321 99 | size = [12, 21] 100 | tensor = torch.cuda.FloatTensor(size) 101 | 102 | # Set to seed_1 and generate two tensors. 103 | torch.cuda.manual_seed(seed_1) 104 | torch.randn(size, out=tensor) 105 | target_11 = tensor.clone() 106 | torch.randn(size, out=tensor) 107 | target_12 = tensor.clone() 108 | 109 | # Set to seed_2 and generate two tensors. 
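    # (The interleaving check further below relies on get_cuda_rng_tracker()
    # keeping an independent CUDA RNG stream per registered seed: draws inside
    # fork('test') advance only the seed_2 stream, while draws outside it advance
    # only the default seed_1 stream, so both sequences must match the reference
    # tensors generated here.)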
110 | torch.cuda.manual_seed(seed_2) 111 | torch.randn(size, out=tensor) 112 | target_21 = tensor.clone() 113 | torch.randn(size, out=tensor) 114 | target_22 = tensor.clone() 115 | 116 | # Now if we interleave seed_1 and seed_2, 117 | # we should still get the same tensors 118 | torch.cuda.manual_seed(seed_1) 119 | mpu.get_cuda_rng_tracker().add('test', seed_2) 120 | 121 | torch.randn(size, out=tensor) 122 | result_11 = tensor.clone() 123 | 124 | with mpu.get_cuda_rng_tracker().fork('test'): 125 | torch.randn(size, out=tensor) 126 | result_21 = tensor.clone() 127 | 128 | torch.randn(size, out=tensor) 129 | result_12 = tensor.clone() 130 | 131 | with mpu.get_cuda_rng_tracker().fork('test'): 132 | torch.randn(size, out=tensor) 133 | result_22 = tensor.clone() 134 | 135 | diff = result_11.sub(result_21).abs().max() 136 | diff = min(diff, result_12.sub(result_22).abs().max()) 137 | print(' max diff in generated tensors (should be non-zero) on ' 138 | 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) 139 | assert diff > 1.0e-6 140 | error = max(result_11.sub(target_11).abs().max(), 141 | result_12.sub(target_12).abs().max()) 142 | error = max(error, result_21.sub(target_21).abs().max()) 143 | error = max(error, result_22.sub(target_22).abs().max()) 144 | print(' max error in generated tensors (should be zero) on ' 145 | 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) 146 | assert error < 1.0e-6 147 | 148 | # Reset the tracker 149 | mpu.get_cuda_rng_tracker().reset() 150 | 151 | # Reset groups 152 | mpu.destroy_model_parallel() 153 | 154 | torch.distributed.barrier() 155 | if torch.distributed.get_rank() == 0: 156 | print('>> passed the test :-)') 157 | 158 | 159 | def test_model_parallel_cuda_manual_seed(model_parallel_size): 160 | 161 | if torch.distributed.get_rank() == 0: 162 | print('> testing model parallel cuda manual seed with size {} ...'. 
163 | format(model_parallel_size)) 164 | 165 | mpu.initialize_model_parallel(model_parallel_size) 166 | model_parallel_size = mpu.get_model_parallel_world_size() 167 | 168 | mpu.model_parallel_cuda_manual_seed(12345) 169 | assert torch.cuda.initial_seed() == 12345 170 | with mpu.get_cuda_rng_tracker().fork(): 171 | assert torch.cuda.initial_seed() == (12345 + 2718 + 172 | mpu.get_model_parallel_rank()) 173 | 174 | # Reset the tracker 175 | mpu.get_cuda_rng_tracker().reset() 176 | 177 | # Reset groups 178 | mpu.destroy_model_parallel() 179 | 180 | torch.distributed.barrier() 181 | if torch.distributed.get_rank() == 0: 182 | print('>> passed the test :-)') 183 | 184 | 185 | if __name__ == '__main__': 186 | 187 | initialize_distributed() 188 | world_size = torch.distributed.get_world_size() 189 | 190 | model_parallel_size = 1 191 | while model_parallel_size <= world_size: 192 | print_separator('test set rng state') 193 | test_set_cuda_rng_state(model_parallel_size) 194 | model_parallel_size *= 2 195 | 196 | model_parallel_size = 1 197 | while model_parallel_size <= world_size: 198 | print_separator('test cuda rng tracker') 199 | test_cuda_rng_tracker(model_parallel_size) 200 | model_parallel_size *= 2 201 | 202 | model_parallel_size = 1 203 | while model_parallel_size <= world_size: 204 | print_separator('test model parallel cuda manual seed') 205 | test_model_parallel_cuda_manual_seed(model_parallel_size) 206 | model_parallel_size *= 2 207 | 208 | -------------------------------------------------------------------------------- /data_utils/lazy_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """utils for loading text from disk""" 16 | import os 17 | import mmap 18 | import pickle as pkl 19 | import time 20 | from itertools import accumulate 21 | 22 | import torch 23 | from torch.multiprocessing import Lock 24 | 25 | def get_lazy_path(path): 26 | """ 27 | Gets directory path where lazy files are stored. 28 | """ 29 | return os.path.splitext(path)[0]+'.lazy' 30 | 31 | def exists_lazy(path, data_type='data'): 32 | """ 33 | Check if we've already made a lazy version of this file for the `data_type` field. 34 | """ 35 | if not os.path.exists(get_lazy_path(path)): 36 | return False 37 | contents = os.listdir(get_lazy_path(path)) 38 | if data_type not in contents: 39 | return False 40 | if data_type+'.len.pkl' not in contents: 41 | return False 42 | return True 43 | 44 | def make_lazy(path, strs, data_type='data'): 45 | """ 46 | Make lazy version of `data_type` field of the file. Byte offsets 47 | corresponding to data indices are stored in a `.len.pkl` data file. 
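    For example, given a list of strings `texts`, make_lazy('corpus.json', texts)
    writes the UTF-8 encoded strings to corpus.lazy/data and pickles the per-entry
    byte lengths to corpus.lazy/data.len.pkl; lazy_array_loader later accumulates
    those lengths into end offsets.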
48 | """ 49 | lazypath = get_lazy_path(path) 50 | if not os.path.exists(lazypath): 51 | os.makedirs(lazypath) 52 | datapath = os.path.join(lazypath, data_type) 53 | lenpath = os.path.join(lazypath, data_type+'.len.pkl') 54 | if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: 55 | with open(datapath, 'wb') as f: 56 | str_lens = [] 57 | str_cnt = 0 58 | for s in strs: 59 | if isinstance(s, dict): 60 | s = s['text'] 61 | encoded = s.encode('utf-8') 62 | f.write(encoded) 63 | str_cnt = len(encoded) 64 | str_lens.append(str_cnt) 65 | pkl.dump(str_lens, open(lenpath, 'wb')) 66 | else: 67 | while not os.path.exists(lenpath): 68 | time.sleep(1) 69 | 70 | def split_strings(strings, start, chr_lens): 71 | """ 72 | Split strings based on string lengths and given start. 73 | """ 74 | return [strings[i-start:j-start] for i, j in zip([start]+chr_lens[:-1], chr_lens)] 75 | 76 | class ProcessorTokenizer: 77 | """ 78 | callable class that runs a preprocessing, as well as tokenization step, 79 | on input text. 80 | """ 81 | def __init__(self, tokenizer, process_fn=None): 82 | self.tokenizer = tokenizer 83 | self.process_fn = process_fn 84 | 85 | def __call__(self, string): 86 | if self.tokenizer is not None: 87 | string = self.tokenizer(string, process_fn=self.process_fn) 88 | elif self.process_fn is not None: 89 | string = self.process_fn(string) 90 | return string 91 | 92 | class lazy_array_loader(object): 93 | """ 94 | Arguments: 95 | path: path to directory where array entries are concatenated into one big string file 96 | and the .len file are located 97 | data_type (str): Some datsets have multiple fields that are stored in different paths. 98 | `data_type` specifies which of these fields to load in this class 99 | mem_map (boolean): Specifies whether to memory map file `path` 100 | map_fn (callable): Fetched strings are passed through map_fn before being returned. 101 | 102 | Example of lazy loader directory structure: 103 | file.json 104 | file.lazy/ 105 | data_type1 106 | data_type1.len.pkl 107 | data_type2 108 | data_type2.len.pkl 109 | """ 110 | def __init__(self, path, data_type='data', mem_map=False, map_fn=None): 111 | lazypath = get_lazy_path(path) 112 | datapath = os.path.join(lazypath, data_type) 113 | #get file where array entries are concatenated into one big string 114 | self._file = open(datapath, 'rb') 115 | self.file = self._file 116 | #memory map file if necessary 117 | self.mem_map = mem_map 118 | if self.mem_map: 119 | self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ) 120 | lenpath = os.path.join(lazypath, data_type+'.len.pkl') 121 | self.lens = pkl.load(open(lenpath, 'rb')) 122 | self.ends = list(accumulate(self.lens)) 123 | self.dumb_ends = list(self.ends) 124 | self.read_lock = Lock() 125 | self.process_fn = map_fn 126 | self.map_fn = map_fn 127 | self._tokenizer = None 128 | 129 | def SetTokenizer(self, tokenizer): 130 | """ 131 | logic to set and remove (set to None) tokenizer. 132 | combines preprocessing/tokenization into one callable. 
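        Illustrative usage (assuming `tok` is a callable tokenizer): after
        loader.SetTokenizer(tok), __getitem__ passes each fetched string
        through ProcessorTokenizer(tok, process_fn) before returning it;
        GetTokenizer() returns the tokenizer currently set.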
133 | """ 134 | if tokenizer is None: 135 | if not hasattr(self, '_tokenizer'): 136 | self._tokenizer = tokenizer 137 | else: 138 | self._tokenizer = tokenizer 139 | self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn) 140 | 141 | def GetTokenizer(self): 142 | return self._tokenizer 143 | 144 | def __getitem__(self, index): 145 | """ 146 | read file and splice strings based on string ending array `self.ends` 147 | """ 148 | if not isinstance(index, slice): 149 | if index == 0: 150 | start = 0 151 | else: 152 | start = self.ends[index-1] 153 | end = self.ends[index] 154 | rtn = self.file_read(start, end) 155 | if self.map_fn is not None: 156 | return self.map_fn(rtn) 157 | else: 158 | # if slice, fetch strings with 1 diskread and then splice in memory 159 | chr_lens = self.ends[index] 160 | if index.start == 0 or index.start is None: 161 | start = 0 162 | else: 163 | start = self.ends[index.start-1] 164 | stop = chr_lens[-1] 165 | strings = self.file_read(start, stop) 166 | rtn = split_strings(strings, start, chr_lens) 167 | if self.map_fn is not None: 168 | return self.map_fn([s for s in rtn]) 169 | return rtn 170 | 171 | def __len__(self): 172 | return len(self.ends) 173 | 174 | def file_read(self, start=0, end=None): 175 | """read specified portion of file""" 176 | 177 | # atomic reads to avoid race conditions with multiprocess dataloader 178 | self.read_lock.acquire() 179 | # seek to start of file read 180 | self.file.seek(start) 181 | # read to end of file if no end point provided 182 | if end is None: 183 | rtn = self.file.read() 184 | #else read amount needed to reach end point 185 | else: 186 | rtn = self.file.read(end-start) 187 | self.read_lock.release() 188 | #TODO: @raulp figure out mem map byte string bug 189 | #if mem map'd need to decode byte string to string 190 | rtn = rtn.decode('utf-8', 'ignore') 191 | # rtn = str(rtn) 192 | if self.mem_map: 193 | rtn = rtn.decode('unicode_escape') 194 | return rtn 195 | 196 | -------------------------------------------------------------------------------- /fp16/fp16util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch.autograd import Variable 19 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 20 | 21 | import mpu 22 | 23 | 24 | class tofp16(nn.Module): 25 | """ 26 | Utility module that implements:: 27 | 28 | def forward(self, input): 29 | return input.half() 30 | """ 31 | 32 | def __init__(self): 33 | super(tofp16, self).__init__() 34 | 35 | def forward(self, input): 36 | return input.half() 37 | 38 | 39 | def BN_convert_float(module): 40 | """ 41 | Utility function for network_to_half(). 42 | 43 | Retained for legacy purposes. 
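    Recursively keeps affine BatchNorm layers in FP32 (their running statistics
    are numerically fragile in half precision) while network_to_half() converts
    the rest of the network to FP16.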
44 | """ 45 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: 46 | module.float() 47 | for child in module.children(): 48 | BN_convert_float(child) 49 | return module 50 | 51 | 52 | def network_to_half(network): 53 | """ 54 | Convert model to half precision in a batchnorm-safe way. 55 | 56 | Retained for legacy purposes. It is recommended to use FP16Model. 57 | """ 58 | return nn.Sequential(tofp16(), BN_convert_float(network.half())) 59 | 60 | 61 | def convert_module(module, dtype): 62 | """ 63 | Converts a module's immediate parameters and buffers to dtype. 64 | """ 65 | for param in module.parameters(recurse=False): 66 | if param is not None: 67 | if param.data.dtype.is_floating_point: 68 | param.data = param.data.to(dtype=dtype) 69 | if param._grad is not None and param._grad.data.dtype.is_floating_point: 70 | param._grad.data = param._grad.data.to(dtype=dtype) 71 | 72 | for buf in module.buffers(recurse=False): 73 | if buf is not None and buf.data.dtype.is_floating_point: 74 | buf.data = buf.data.to(dtype=dtype) 75 | 76 | 77 | def convert_network(network, dtype): 78 | """ 79 | Converts a network's parameters and buffers to dtype. 80 | """ 81 | for module in network.modules(): 82 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: 83 | continue 84 | convert_module(module, dtype) 85 | return network 86 | 87 | 88 | class FP16Model(nn.Module): 89 | """ 90 | Convert model to half precision in a batchnorm-safe way. 91 | """ 92 | 93 | def __init__(self, network): 94 | super(FP16Model, self).__init__() 95 | self.network = convert_network(network, dtype=torch.half) 96 | 97 | def forward(self, *inputs): 98 | inputs = tuple(t.half() for t in inputs) 99 | return self.network(*inputs) 100 | 101 | 102 | def backwards_debug_hook(grad): 103 | raise RuntimeError("master_params recieved a gradient in the backward pass!") 104 | 105 | def prep_param_lists(model, flat_master=False): 106 | """ 107 | Creates a list of FP32 master parameters for a given model, as in 108 | `Training Neural Networks with Mixed Precision: Real Examples`_. 109 | 110 | Args: 111 | model (torch.nn.Module): Existing Pytorch model 112 | flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. 113 | Returns: 114 | A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. 115 | 116 | Example:: 117 | 118 | model_params, master_params = prep_param_lists(model) 119 | 120 | .. warning:: 121 | Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. 122 | 123 | .. _`Training Neural Networks with Mixed Precision: Real Examples`: 124 | http://on-demand.gputechconf.com/gtc/2018/video/S81012/ 125 | """ 126 | model_params = [param for param in model.parameters() if param.requires_grad] 127 | 128 | if flat_master: 129 | # Give the user some more useful error messages 130 | try: 131 | # flatten_dense_tensors returns a contiguous flat array. 
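            # (i.e. all FP16 parameters are copied into one flat FP32 master buffer,
            # so the optimizer updates a single tensor rather than a list of params)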
132 | # http://pytorch.org/docs/master/_modules/torch/_utils.html 133 | master_params = _flatten_dense_tensors([param.data for param in model_params]).float() 134 | except: 135 | print("Error in prep_param_lists: model may contain a mixture of parameters " 136 | "of different types. Use flat_master=False, or use F16_Optimizer.") 137 | raise 138 | master_params = torch.nn.Parameter(master_params) 139 | master_params.requires_grad = True 140 | # master_params.register_hook(backwards_debug_hook) 141 | if master_params.grad is None: 142 | master_params.grad = master_params.new(*master_params.size()) 143 | return model_params, [master_params] 144 | else: 145 | master_params = [param.clone().float().detach() for param in model_params] 146 | for param in master_params: 147 | param.requires_grad = True 148 | return model_params, master_params 149 | 150 | 151 | def model_grads_to_master_grads(model_params, master_params, flat_master=False): 152 | """ 153 | Copy model gradients to master gradients. 154 | 155 | Args: 156 | model_params: List of model parameters created by :func:`prep_param_lists`. 157 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. 158 | """ 159 | if flat_master: 160 | # The flattening may incur one more deep copy than is necessary. 161 | master_params[0].grad.data.copy_( 162 | _flatten_dense_tensors([p.grad.data for p in model_params])) 163 | else: 164 | for model, master in zip(model_params, master_params): 165 | if model.grad is not None: 166 | if master.grad is None: 167 | master.grad = Variable(master.data.new(*master.data.size())) 168 | master.grad.data.copy_(model.grad.data) 169 | else: 170 | master.grad = None 171 | 172 | 173 | def master_params_to_model_params(model_params, master_params, flat_master=False): 174 | """ 175 | Copy master parameters to model parameters. 176 | 177 | Args: 178 | model_params: List of model parameters created by :func:`prep_param_lists`. 179 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. 180 | """ 181 | if flat_master: 182 | for model, master in zip(model_params, 183 | _unflatten_dense_tensors(master_params[0].data, model_params)): 184 | model.data.copy_(master) 185 | else: 186 | for model, master in zip(model_params, master_params): 187 | model.data.copy_(master.data) 188 | 189 | # Backward compatibility fixes 190 | 191 | def to_python_float(t): 192 | if hasattr(t, 'item'): 193 | return t.item() 194 | else: 195 | return t[0] 196 | 197 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 198 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 199 | 200 | clip_grad_norm = mpu.clip_grad_norm 201 | #elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4: 202 | # clip_grad_norm = torch.nn.utils.clip_grad_norm 203 | #else: 204 | # clip_grad_norm = torch.nn.utils.clip_grad_norm_ 205 | -------------------------------------------------------------------------------- /data_utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # This file is provided as is from: 2 | # https://github.com/huggingface/pytorch-pretrained-BERT 3 | # Please refer to their repository for copyright. 4 | 5 | """ 6 | Utilities for working with the local dataset cache. 
7 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 8 | Copyright by the AllenNLP authors. 9 | """ 10 | from __future__ import (absolute_import, division, print_function, unicode_literals) 11 | 12 | import json 13 | import logging 14 | import os 15 | import shutil 16 | import tempfile 17 | from functools import wraps 18 | from hashlib import sha256 19 | import sys 20 | from io import open 21 | 22 | import boto3 23 | import requests 24 | from botocore.exceptions import ClientError 25 | from tqdm import tqdm 26 | 27 | try: 28 | from urllib.parse import urlparse 29 | except ImportError: 30 | from urlparse import urlparse 31 | 32 | try: 33 | from pathlib import Path 34 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 35 | Path.home() / '.pytorch_pretrained_bert')) 36 | except (AttributeError, ImportError): 37 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 38 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 39 | 40 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | def url_to_filename(url, etag=None): 44 | """ 45 | Convert `url` into a hashed filename in a repeatable way. 46 | If `etag` is specified, append its hash to the url's, delimited 47 | by a period. 48 | """ 49 | url_bytes = url.encode('utf-8') 50 | url_hash = sha256(url_bytes) 51 | filename = url_hash.hexdigest() 52 | 53 | if etag: 54 | etag_bytes = etag.encode('utf-8') 55 | etag_hash = sha256(etag_bytes) 56 | filename += '.' + etag_hash.hexdigest() 57 | 58 | return filename 59 | 60 | 61 | def filename_to_url(filename, cache_dir=None): 62 | """ 63 | Return the url and etag (which may be ``None``) stored for `filename`. 64 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 65 | """ 66 | if cache_dir is None: 67 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 68 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 69 | cache_dir = str(cache_dir) 70 | 71 | cache_path = os.path.join(cache_dir, filename) 72 | if not os.path.exists(cache_path): 73 | raise EnvironmentError("file {} not found".format(cache_path)) 74 | 75 | meta_path = cache_path + '.json' 76 | if not os.path.exists(meta_path): 77 | raise EnvironmentError("file {} not found".format(meta_path)) 78 | 79 | with open(meta_path, encoding="utf-8") as meta_file: 80 | metadata = json.load(meta_file) 81 | url = metadata['url'] 82 | etag = metadata['etag'] 83 | 84 | return url, etag 85 | 86 | 87 | def cached_path(url_or_filename, cache_dir=None): 88 | """ 89 | Given something that might be a URL (or might be a local path), 90 | determine which. If it's a URL, download the file and cache it, and 91 | return the path to the cached file. If it's already a local path, 92 | make sure the file exists and then return the path. 93 | """ 94 | if cache_dir is None: 95 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 96 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 97 | url_or_filename = str(url_or_filename) 98 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 99 | cache_dir = str(cache_dir) 100 | 101 | parsed = urlparse(url_or_filename) 102 | 103 | if parsed.scheme in ('http', 'https', 's3'): 104 | # URL, so get it from the cache (downloading if necessary) 105 | return get_from_cache(url_or_filename, cache_dir) 106 | elif os.path.exists(url_or_filename): 107 | # File, and it exists. 108 | return url_or_filename 109 | elif parsed.scheme == '': 110 | # File, but it doesn't exist. 
111 | raise EnvironmentError("file {} not found".format(url_or_filename)) 112 | else: 113 | # Something unknown 114 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 115 | 116 | 117 | def split_s3_path(url): 118 | """Split a full s3 path into the bucket name and path.""" 119 | parsed = urlparse(url) 120 | if not parsed.netloc or not parsed.path: 121 | raise ValueError("bad s3 path {}".format(url)) 122 | bucket_name = parsed.netloc 123 | s3_path = parsed.path 124 | # Remove '/' at beginning of path. 125 | if s3_path.startswith("/"): 126 | s3_path = s3_path[1:] 127 | return bucket_name, s3_path 128 | 129 | 130 | def s3_request(func): 131 | """ 132 | Wrapper function for s3 requests in order to create more helpful error 133 | messages. 134 | """ 135 | 136 | @wraps(func) 137 | def wrapper(url, *args, **kwargs): 138 | try: 139 | return func(url, *args, **kwargs) 140 | except ClientError as exc: 141 | if int(exc.response["Error"]["Code"]) == 404: 142 | raise EnvironmentError("file {} not found".format(url)) 143 | else: 144 | raise 145 | 146 | return wrapper 147 | 148 | 149 | @s3_request 150 | def s3_etag(url): 151 | """Check ETag on S3 object.""" 152 | s3_resource = boto3.resource("s3") 153 | bucket_name, s3_path = split_s3_path(url) 154 | s3_object = s3_resource.Object(bucket_name, s3_path) 155 | return s3_object.e_tag 156 | 157 | 158 | @s3_request 159 | def s3_get(url, temp_file): 160 | """Pull a file directly from S3.""" 161 | s3_resource = boto3.resource("s3") 162 | bucket_name, s3_path = split_s3_path(url) 163 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 164 | 165 | 166 | def http_get(url, temp_file): 167 | req = requests.get(url, stream=True) 168 | content_length = req.headers.get('Content-Length') 169 | total = int(content_length) if content_length is not None else None 170 | progress = tqdm(unit="B", total=total) 171 | for chunk in req.iter_content(chunk_size=1024): 172 | if chunk: # filter out keep-alive new chunks 173 | progress.update(len(chunk)) 174 | temp_file.write(chunk) 175 | progress.close() 176 | 177 | 178 | def get_from_cache(url, cache_dir=None): 179 | """ 180 | Given a URL, look for the corresponding dataset in the local cache. 181 | If it's not there, download it. Then return the path to the cached file. 182 | """ 183 | if cache_dir is None: 184 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 185 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 186 | cache_dir = str(cache_dir) 187 | 188 | if not os.path.exists(cache_dir): 189 | os.makedirs(cache_dir) 190 | 191 | # Get eTag to add to filename, if it exists. 192 | if url.startswith("s3://"): 193 | etag = s3_etag(url) 194 | else: 195 | response = requests.head(url, allow_redirects=True) 196 | if response.status_code != 200: 197 | raise IOError("HEAD request failed for url {} with status code {}" 198 | .format(url, response.status_code)) 199 | etag = response.headers.get("ETag") 200 | 201 | filename = url_to_filename(url, etag) 202 | 203 | # get cache path to put the file 204 | cache_path = os.path.join(cache_dir, filename) 205 | 206 | if not os.path.exists(cache_path): 207 | # Download to temporary file, then copy to cache dir once finished. 208 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
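        # Flow: stream the download into a NamedTemporaryFile, flush and rewind it,
        # copy it to cache_path, then write a <cache_path>.json file recording the
        # url and etag so filename_to_url() can recover them later.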
209 | with tempfile.NamedTemporaryFile() as temp_file: 210 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 211 | 212 | # GET file object 213 | if url.startswith("s3://"): 214 | s3_get(url, temp_file) 215 | else: 216 | http_get(url, temp_file) 217 | 218 | # we are copying the file before closing it, so flush to avoid truncation 219 | temp_file.flush() 220 | # shutil.copyfileobj() starts at the current position, so go to the start 221 | temp_file.seek(0) 222 | 223 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 224 | with open(cache_path, 'wb') as cache_file: 225 | shutil.copyfileobj(temp_file, cache_file) 226 | 227 | logger.info("creating metadata file for %s", cache_path) 228 | meta = {'url': url, 'etag': etag} 229 | meta_path = cache_path + '.json' 230 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 231 | json.dump(meta, meta_file) 232 | 233 | logger.info("removing temp file %s", temp_file.name) 234 | 235 | return cache_path 236 | 237 | 238 | def read_set_from_file(filename): 239 | ''' 240 | Extract a de-duped collection (set) of text from a file. 241 | Expected file format is one item per line. 242 | ''' 243 | collection = set() 244 | with open(filename, 'r', encoding='utf-8') as file_: 245 | for line in file_: 246 | collection.add(line.rstrip()) 247 | return collection 248 | 249 | 250 | def get_file_extension(path, dot=True, lower=True): 251 | ext = os.path.splitext(path)[1] 252 | ext = ext if dot else ext[1:] 253 | return ext.lower() if lower else ext 254 | -------------------------------------------------------------------------------- /configure_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
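# Illustrative usage from a training script (a sketch; the `args` namespace is
# assumed to come from the caller's argument parser):
#
#     data_config = configure_data()
#     data_config.set_defaults(data_set_type='GPT2', transpose=False)
#     (train, val, test), tokenizer = data_config.apply(args)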
15 | 16 | """parses arguments and preps data loader""" 17 | 18 | import copy 19 | import torch 20 | import data_utils 21 | 22 | import mpu 23 | 24 | class DataConfig: 25 | 26 | def __init__(self, defaults={}): 27 | super(DataConfig, self).__init__() 28 | self.defaults = defaults 29 | 30 | def apply(self, args): 31 | if torch.distributed.get_rank() == 0: 32 | print('configuring data') 33 | self.apply_defaults(args) 34 | return make_loaders(args) 35 | 36 | def set_defaults(self, **kwargs): 37 | for k, v in kwargs.items(): 38 | self.defaults[k] = v 39 | 40 | def apply_defaults(self, args): 41 | for k, v in self.defaults.items(): 42 | k = k.replace('-', '_') 43 | if not hasattr(args, k): 44 | setattr(args, k, v) 45 | 46 | 47 | def make_data_loader(dataset, batch_size, args): 48 | 49 | shuffle = args.shuffle 50 | if shuffle: 51 | sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters) 52 | else: 53 | sampler = torch.utils.data.SequentialSampler(dataset) 54 | world_size = torch.distributed.get_world_size( 55 | group=mpu.get_data_parallel_group()) 56 | rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) 57 | distributed = world_size > 1 58 | drop_last = distributed 59 | 60 | if distributed: 61 | batch_sampler = data_utils.samplers.DistributedBatchSampler(sampler, 62 | batch_size, 63 | drop_last, 64 | rank, 65 | world_size) 66 | else: 67 | batch_sampler = torch.utils.data.BatchSampler(sampler, 68 | batch_size, 69 | drop_last) 70 | 71 | data_loader = torch.utils.data.DataLoader(dataset, 72 | batch_sampler=batch_sampler, 73 | num_workers=args.num_workers, 74 | pin_memory=True) 75 | 76 | return data_loader 77 | 78 | 79 | def make_tfrecord_loaders(args): 80 | """Load train/val/test dataset from shuffled TFRecords""" 81 | 82 | import data_utils.tf_dl 83 | data_set_args = {'batch_size': args.batch_size, 84 | 'max_seq_len': args.seq_length, 85 | 'max_preds_per_seq': args.max_preds_per_seq, 86 | 'train': True, 87 | 'num_workers': max(args.num_workers, 1), 88 | 'seed': args.seed + args.rank + 1, 89 | 'threaded_dl': args.num_workers > 0 90 | } 91 | train = data_utils.tf_dl.TFRecordDataLoader(args.train_data, 92 | **data_set_args) 93 | data_set_args['train'] = False 94 | if args.eval_seq_length is not None: 95 | data_set_args['max_seq_len'] = args.eval_seq_length 96 | if args.eval_max_preds_per_seq is not None: 97 | data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq 98 | valid = None 99 | if args.valid_data is not None: 100 | valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data, 101 | **data_set_args) 102 | test = None 103 | if args.test_data is not None: 104 | test = data_utils.tf_dl.TFRecordDataLoader(args.test_data, 105 | **data_set_args) 106 | tokenizer = data_utils.make_tokenizer(args.tokenizer_type, 107 | train, 108 | args.tokenizer_path, 109 | args.vocab_size, 110 | args.tokenizer_model_type, 111 | cache_dir=args.cache_dir) 112 | 113 | return (train, valid, test), tokenizer 114 | 115 | 116 | def make_loaders(args): 117 | """makes training/val/test""" 118 | 119 | if args.use_tfrecords: 120 | return make_tfrecord_loaders(args) 121 | world_size = torch.distributed.get_world_size( 122 | group=mpu.get_data_parallel_group()) 123 | batch_size = args.batch_size * world_size 124 | eval_batch_size = batch_size 125 | if args.eval_batch_size is not None: 126 | eval_batch_size = args.eval_batch_size * world_size 127 | seq_length = args.seq_length 128 | if seq_length < 0: 129 | seq_length = seq_length * world_size 130 | 
eval_seq_length = args.eval_seq_length 131 | if eval_seq_length is not None and eval_seq_length < 0: 132 | eval_seq_length = eval_seq_length * world_size 133 | split = get_split(args) 134 | data_set_args = { 135 | 'path': args.train_data, 136 | 'seq_length': seq_length, 137 | 'lazy': args.lazy_loader, 138 | 'delim': args.delim, 139 | 'text_key': args.text_key, 140 | 'label_key': 'label', 141 | 'non_binary_cols': None, 142 | 'ds_type': args.data_set_type, 143 | 'split': split, 144 | 'loose': args.loose_json, 145 | 'tokenizer_type': args.tokenizer_type, 146 | 'tokenizer_model_path': args.tokenizer_path, 147 | 'vocab_size': args.vocab_size, 148 | 'model_type': args.tokenizer_model_type, 149 | 'cache_dir': args.cache_dir, 150 | 'max_preds_per_seq': args.max_preds_per_seq, 151 | 'presplit_sentences': args.presplit_sentences} 152 | 153 | eval_set_args = copy.copy(data_set_args) 154 | eval_set_args['split'] = [1.] 155 | # if optional eval args were set then replace their 156 | # equivalent values in the arg dict 157 | if eval_seq_length: 158 | eval_set_args['seq_length'] = eval_seq_length 159 | if args.eval_max_preds_per_seq: 160 | eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq 161 | if args.eval_text_key is not None: 162 | eval_set_args['text_key'] = args.eval_text_key 163 | 164 | # make datasets splits and tokenizer 165 | train = None 166 | valid = None 167 | test = None 168 | 169 | if args.train_data is not None: 170 | train, tokenizer = data_utils.make_dataset(**data_set_args) 171 | if data_utils.should_split(split): 172 | train, valid, test = train 173 | eval_set_args['tokenizer'] = tokenizer 174 | 175 | # make training and val dataset if necessary 176 | if valid is None and args.valid_data is not None: 177 | eval_set_args['path'] = args.valid_data 178 | valid, tokenizer = data_utils.make_dataset(**eval_set_args) 179 | eval_set_args['tokenizer'] = tokenizer 180 | if test is None and args.test_data is not None: 181 | eval_set_args['path'] = args.test_data 182 | test, tokenizer = data_utils.make_dataset(**eval_set_args) 183 | 184 | # wrap datasets with data loader 185 | if train is not None and args.batch_size > 0: 186 | train = make_data_loader(train, batch_size, args) 187 | args.do_train = True 188 | else: 189 | args.do_train = False 190 | eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size 191 | if valid is not None: 192 | valid = make_data_loader(valid, eval_batch_size, args) 193 | args.do_valid = True 194 | else: 195 | args.do_valid = False 196 | if test is not None: 197 | test = make_data_loader(test, eval_batch_size, args) 198 | args.do_test = True 199 | else: 200 | args.do_test = False 201 | 202 | return (train, valid, test), tokenizer 203 | 204 | def get_split(args): 205 | """ 206 | Get dataset splits from comma separated string list 207 | """ 208 | splits = [] 209 | if args.split.find(',') != -1: 210 | splits = [float(s) for s in args.split.split(',')] 211 | elif args.split.find('/') != -1: 212 | splits = [float(s) for s in args.split.split('/')] 213 | else: 214 | splits = [float(args.split)] 215 | split_total = sum(splits) 216 | if split_total < 1.: 217 | splits.append(1-split_total) 218 | while len(splits) < 3: 219 | splits.append(0.) 220 | splits = splits[:3] 221 | if args.valid_data is not None: 222 | splits[1] = 0. 223 | if args.test_data is not None: 224 | splits[2] = 0. 
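    # When dedicated valid/test files are provided, their fractions are zeroed out
    # above and the remaining fractions are renormalized below to sum to 1.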
225 | final_sum = sum(splits) 226 | return [s/final_sum for s in splits] 227 | 228 | def configure_data(): 229 | 230 | """add cmdline flags for configuring datasets""" 231 | # These are options that are used by data_utils, but are either 232 | # deprecated or not meant to be exposed to the command line user. 233 | # These options are intneded to be set in code by specific scripts. 234 | defaults = { 235 | 'world_size': 1, 236 | 'rank': -1, 237 | 'persist_state': 0, 238 | 'lazy': False, 239 | 'transpose': False, 240 | 'data_set_type': 'supervised', 241 | 'seq_length': 256, 242 | 'eval_seq_length': 256, 243 | 'samples_per_shard': 100 244 | } 245 | 246 | return DataConfig(defaults=defaults) 247 | -------------------------------------------------------------------------------- /fp16/loss_scaler.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import mpu 18 | 19 | # item() is a recent addition, so this helps with backward compatibility. 20 | def to_python_float(t): 21 | if hasattr(t, 'item'): 22 | return t.item() 23 | else: 24 | return t[0] 25 | 26 | class LossScaler: 27 | """ 28 | Class that manages a static loss scale. This class is intended to interact with 29 | :class:`FP16_Optimizer`, and should not be directly manipulated by the user. 30 | 31 | Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to 32 | :class:`FP16_Optimizer`'s constructor. 33 | 34 | Args: 35 | scale (float, optional, default=1.0): The loss scale. 36 | """ 37 | 38 | def __init__(self, scale=1): 39 | self.cur_scale = scale 40 | 41 | # `params` is a list / generator of torch.Variable 42 | def has_overflow(self, params): 43 | return False 44 | 45 | # `x` is a torch.Tensor 46 | def _has_inf_or_nan(x): 47 | return False 48 | 49 | def update_scale(self, overflow): 50 | pass 51 | 52 | @property 53 | def loss_scale(self): 54 | return self.cur_scale 55 | 56 | def scale_gradient(self, module, grad_in, grad_out): 57 | return tuple(self.loss_scale * g for g in grad_in) 58 | 59 | def backward(self, loss, retain_graph=False): 60 | scaled_loss = loss*self.loss_scale 61 | scaled_loss.backward(retain_graph=retain_graph) 62 | 63 | class DynamicLossScaler: 64 | """ 65 | Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` 66 | indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of 67 | :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` 68 | operates, because the default options can be changed using the 69 | the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. 70 | 71 | Loss scaling is designed to combat the problem of underflowing gradients encountered at long 72 | times when training fp16 networks. 
Dynamic loss scaling begins by attempting a very high loss 73 | scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are 74 | encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has 75 | occurred. 76 | :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, 77 | and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. 78 | If a certain number of iterations occur without overflowing gradients detected, 79 | :class:`DynamicLossScaler` increases the loss scale once more. 80 | In this way :class:`DynamicLossScaler` attempts to "ride the edge" of 81 | always using the highest loss scale possible without incurring overflow. 82 | 83 | Args: 84 | init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` 85 | scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. 86 | scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. 87 | """ 88 | 89 | def __init__(self, 90 | init_scale=2**32, 91 | scale_factor=2., 92 | scale_window=1000, 93 | min_scale=1, 94 | delayed_shift=1, 95 | consecutive_hysteresis=False): 96 | self.cur_scale = init_scale 97 | self.cur_iter = 0 98 | self.last_overflow_iter = -1 99 | self.scale_factor = scale_factor 100 | self.scale_window = scale_window 101 | self.min_scale = min_scale 102 | self.delayed_shift = delayed_shift 103 | self.cur_hysteresis = delayed_shift 104 | self.consecutive_hysteresis = consecutive_hysteresis 105 | 106 | # `params` is a list / generator of torch.Variable 107 | def has_overflow_serial(self, params): 108 | for p in params: 109 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 110 | return True 111 | 112 | return False 113 | 114 | def has_overflow(self, params): 115 | overflow = self.has_overflow_serial(params) 116 | # Since each model parallel GPU carries only part of the model, 117 | # make sure overflow flag is synced across all the model parallel GPUs 118 | overflow_gpu = torch.cuda.ByteTensor([overflow]) 119 | torch.distributed.all_reduce(overflow_gpu, 120 | op=torch.distributed.ReduceOp.MAX, 121 | group=mpu.get_model_parallel_group()) 122 | overflow = overflow_gpu[0].item() 123 | return bool(overflow) 124 | 125 | 126 | # `x` is a torch.Tensor 127 | def _has_inf_or_nan(x): 128 | try: 129 | # if x is half, the .float() incurs an additional deep copy, but it's necessary if 130 | # Pytorch's .sum() creates a one-element tensor of the same type as x 131 | # (which is true for some recent version of pytorch). 132 | cpu_sum = float(x.float().sum()) 133 | # More efficient version that can be used if .sum() returns a Python scalar 134 | # cpu_sum = float(x.sum()) 135 | except RuntimeError as instance: 136 | # We want to check if inst is actually an overflow exception. 137 | # RuntimeError could come from a different error. 138 | # If so, we still want the exception to propagate. 
139 | if "value cannot be converted" not in instance.args[0]: 140 | raise 141 | return True 142 | else: 143 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 144 | return True 145 | return False 146 | 147 | # `overflow` is boolean indicating whether the gradient overflowed 148 | def update_scale(self, overflow): 149 | 150 | if not hasattr(self, 'min_scale'): 151 | self.min_scale = 1 152 | if not hasattr(self, 'delayed_shift'): 153 | self.delayed_shift = 1 154 | if not hasattr(self, 'cur_hysteresis'): 155 | self.cur_hysteresis = 1 156 | if not hasattr(self, 'consecutive_hysteresis'): 157 | self.consecutive_hysteresis = True 158 | if overflow: 159 | # self.cur_scale /= self.scale_factor 160 | if self.delayed_shift == 1 or self.cur_hysteresis == 1: 161 | self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale) 162 | else: 163 | self.cur_hysteresis -= 1 164 | self.last_overflow_iter = self.cur_iter 165 | else: 166 | if self.consecutive_hysteresis: 167 | self.cur_hysteresis = self.delayed_shift 168 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 169 | if not self.consecutive_hysteresis: 170 | self.cur_hysteresis = self.delayed_shift 171 | self.cur_scale *= self.scale_factor 172 | self.cur_iter += 1 173 | 174 | @property 175 | def loss_scale(self): 176 | return self.cur_scale 177 | 178 | def scale_gradient(self, module, grad_in, grad_out): 179 | return tuple(self.loss_scale * g for g in grad_in) 180 | 181 | def backward(self, loss, retain_graph=False): 182 | scaled_loss = loss*self.loss_scale 183 | scaled_loss.backward(retain_graph=retain_graph) 184 | 185 | ############################################################## 186 | # Example usage below here -- assuming it's in a separate file 187 | ############################################################## 188 | """ 189 | TO-DO separate out into an example. 190 | if __name__ == "__main__": 191 | import torch 192 | from torch.autograd import Variable 193 | from dynamic_loss_scaler import DynamicLossScaler 194 | 195 | # N is batch size; D_in is input dimension; 196 | # H is hidden dimension; D_out is output dimension. 197 | N, D_in, H, D_out = 64, 1000, 100, 10 198 | 199 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables. 200 | x = Variable(torch.randn(N, D_in), requires_grad=False) 201 | y = Variable(torch.randn(N, D_out), requires_grad=False) 202 | 203 | w1 = Variable(torch.randn(D_in, H), requires_grad=True) 204 | w2 = Variable(torch.randn(H, D_out), requires_grad=True) 205 | parameters = [w1, w2] 206 | 207 | learning_rate = 1e-6 208 | optimizer = torch.optim.SGD(parameters, lr=learning_rate) 209 | loss_scaler = DynamicLossScaler() 210 | 211 | for t in range(500): 212 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 213 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale 214 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) 215 | print('Iter {} scaled loss: {}'.format(t, loss.data[0])) 216 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) 217 | 218 | # Run backprop 219 | optimizer.zero_grad() 220 | loss.backward() 221 | 222 | # Check for overflow 223 | has_overflow = DynamicLossScaler.has_overflow(parameters) 224 | 225 | # If no overflow, unscale grad and update as usual 226 | if not has_overflow: 227 | for param in parameters: 228 | param.grad.data.mul_(1. 
/ loss_scaler.loss_scale) 229 | optimizer.step() 230 | # Otherwise, don't do anything -- i.e., skip iteration 231 | else: 232 | print('OVERFLOW!') 233 | 234 | # Update loss scale for next iteration 235 | loss_scaler.update_scale(has_overflow) 236 | 237 | """ 238 | -------------------------------------------------------------------------------- /data/bert_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """BERT Style dataset.""" 17 | 18 | import os 19 | import time 20 | 21 | import numpy as np 22 | import torch 23 | from torch.utils.data import Dataset 24 | 25 | #from megatron import get_tokenizer 26 | import mpu 27 | from data.dataset_utils import build_training_sample 28 | from data.indexed_dataset import make_dataset as make_indexed_dataset 29 | from utils import print_rank_0 30 | 31 | 32 | def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, 33 | train_valid_test_num_samples, 34 | max_seq_length, masked_lm_prob, 35 | short_seq_prob, seed, skip_warmup): 36 | 37 | # Indexed dataset. 38 | indexed_dataset = get_indexed_dataset_(data_prefix, 39 | data_impl, 40 | skip_warmup) 41 | 42 | # Get start and end indices of train/valid/test into doc-idx 43 | # Note that doc-idx is designed to be num-docs + 1 so we can 44 | # easily iterate over it. 45 | total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 46 | splits = get_train_valid_test_split_(splits_string, total_num_of_documents) 47 | 48 | # Print stats about the splits. 49 | print_rank_0(' > dataset split:') 50 | 51 | def print_split_stats(name, index): 52 | print_rank_0(' {}:'.format(name)) 53 | print_rank_0(' document indices in [{}, {}) total of {} ' 54 | 'documents'.format(splits[index], splits[index + 1], 55 | splits[index + 1] - splits[index])) 56 | start_index = indexed_dataset.doc_idx[splits[index]] 57 | end_index = indexed_dataset.doc_idx[splits[index + 1]] 58 | print_rank_0(' sentence indices in [{}, {}) total of {} ' 59 | 'sentences'.format(start_index, end_index, 60 | end_index - start_index)) 61 | print_split_stats('train', 0) 62 | print_split_stats('validation', 1) 63 | print_split_stats('test', 2) 64 | 65 | def build_dataset(index, name): 66 | dataset = None 67 | if splits[index + 1] > splits[index]: 68 | # Get the pointer to the original doc-idx so we can set it later. 69 | doc_idx_ptr = indexed_dataset.get_doc_idx() 70 | # Slice the doc-idx 71 | start_index = splits[index] 72 | # Add +1 so we can index into the dataset to get the upper bound. 73 | end_index = splits[index + 1] + 1 74 | # New doc_idx view. 75 | indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) 76 | # Build the dataset accordingly. 
77 | dataset = BertDataset( 78 | name=name, 79 | indexed_dataset=indexed_dataset, 80 | data_prefix=data_prefix, 81 | num_epochs=None, 82 | max_num_samples=train_valid_test_num_samples[index], 83 | masked_lm_prob=masked_lm_prob, 84 | max_seq_length=max_seq_length, 85 | short_seq_prob=short_seq_prob, 86 | seed=seed) 87 | # Set the original pointer so dataset remains the main dataset. 88 | indexed_dataset.set_doc_idx(doc_idx_ptr) 89 | # Checks. 90 | assert indexed_dataset.doc_idx[0] == 0 91 | assert indexed_dataset.doc_idx.shape[0] == \ 92 | (total_num_of_documents + 1) 93 | return dataset 94 | 95 | train_dataset = build_dataset(0, 'train') 96 | valid_dataset = build_dataset(1, 'valid') 97 | test_dataset = build_dataset(2, 'test') 98 | 99 | return (train_dataset, valid_dataset, test_dataset) 100 | 101 | 102 | class BertDataset(Dataset): 103 | 104 | def __init__(self, name, indexed_dataset, data_prefix, 105 | num_epochs, max_num_samples, masked_lm_prob, 106 | max_seq_length, short_seq_prob, seed): 107 | 108 | # Params to store. 109 | self.name = name 110 | self.seed = seed 111 | self.masked_lm_prob = masked_lm_prob 112 | self.max_seq_length = max_seq_length 113 | 114 | # Dataset. 115 | self.indexed_dataset = indexed_dataset 116 | 117 | # Build the samples mapping. 118 | self.samples_mapping = get_samples_mapping_(self.indexed_dataset, 119 | data_prefix, 120 | num_epochs, 121 | max_num_samples, 122 | self.max_seq_length, 123 | short_seq_prob, 124 | self.seed, 125 | self.name) 126 | 127 | # Vocab stuff. 128 | tokenizer = get_tokenizer() 129 | self.vocab_id_list = list(tokenizer.inv_vocab.keys()) 130 | self.vocab_id_to_token_dict = tokenizer.inv_vocab 131 | self.cls_id = tokenizer.cls 132 | self.sep_id = tokenizer.sep 133 | self.mask_id = tokenizer.mask 134 | self.pad_id = tokenizer.pad 135 | 136 | def __len__(self): 137 | return self.samples_mapping.shape[0] 138 | 139 | def __getitem__(self, idx): 140 | 141 | start_index, end_index, seq_length = self.samples_mapping[idx] 142 | sample = [] 143 | for index in range(start_index, end_index): 144 | sample.append(self.indexed_dataset[index]) 145 | # Note that this rng state should be numpy and not python since 146 | # python randint is inclusive whereas the numpy one is exclusive. 
147 | np_rng = np.random.RandomState(seed=(self.seed + idx)) 148 | return build_training_sample(sample, seq_length, 149 | self.max_seq_length, # needed for padding 150 | self.vocab_id_list, 151 | self.vocab_id_to_token_dict, 152 | self.cls_id, self.sep_id, 153 | self.mask_id, self.pad_id, 154 | self.masked_lm_prob, np_rng) 155 | 156 | 157 | def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): 158 | 159 | print_rank_0(' > building dataset index ...') 160 | 161 | start_time = time.time() 162 | indexed_dataset = make_indexed_dataset(data_prefix, 163 | data_impl, 164 | skip_warmup) 165 | assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] 166 | print_rank_0(' > finished creating indexed dataset in {:4f} ' 167 | 'seconds'.format(time.time() - start_time)) 168 | 169 | print_rank_0(' > indexed dataset stats:') 170 | print_rank_0(' number of documents: {}'.format( 171 | indexed_dataset.doc_idx.shape[0] - 1)) 172 | print_rank_0(' number of sentences: {}'.format( 173 | indexed_dataset.sizes.shape[0])) 174 | 175 | return indexed_dataset 176 | 177 | 178 | def get_train_valid_test_split_(splits_string, size): 179 | """ Get dataset splits from comma or '/' separated string list.""" 180 | 181 | splits = [] 182 | if splits_string.find(',') != -1: 183 | splits = [float(s) for s in splits_string.split(',')] 184 | elif splits_string.find('/') != -1: 185 | splits = [float(s) for s in splits_string.split('/')] 186 | else: 187 | splits = [float(splits_string)] 188 | while len(splits) < 3: 189 | splits.append(0.) 190 | splits = splits[:3] 191 | splits_sum = sum(splits) 192 | assert splits_sum > 0.0 193 | splits = [split / splits_sum for split in splits] 194 | splits_index = [0] 195 | for index, split in enumerate(splits): 196 | splits_index.append(splits_index[index] + 197 | int(round(split * float(size)))) 198 | diff = splits_index[-1] - size 199 | for index in range(1, len(splits_index)): 200 | splits_index[index] -= diff 201 | assert len(splits_index) == 4 202 | assert splits_index[-1] == size 203 | return splits_index 204 | 205 | 206 | def get_samples_mapping_(indexed_dataset, 207 | data_prefix, 208 | num_epochs, 209 | max_num_samples, 210 | max_seq_length, 211 | short_seq_prob, 212 | seed, 213 | name): 214 | if not num_epochs: 215 | if not max_num_samples: 216 | raise ValueError("Need to specify either max_num_samples " 217 | "or num_epochs") 218 | num_epochs = np.iinfo(np.int32).max - 1 219 | if not max_num_samples: 220 | max_num_samples = np.iinfo(np.int64).max - 1 221 | 222 | # Filename of the index mapping 223 | indexmap_filename = data_prefix 224 | indexmap_filename += '_{}_indexmap'.format(name) 225 | if num_epochs != (np.iinfo(np.int32).max - 1): 226 | indexmap_filename += '_{}ep'.format(num_epochs) 227 | if max_num_samples != (np.iinfo(np.int64).max - 1): 228 | indexmap_filename += '_{}mns'.format(max_num_samples) 229 | indexmap_filename += '_{}msl'.format(max_seq_length) 230 | indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) 231 | indexmap_filename += '_{}s'.format(seed) 232 | indexmap_filename += '.npy' 233 | 234 | # Build the indexed mapping if not exist. 235 | if torch.distributed.get_rank() == 0 and \ 236 | not os.path.isfile(indexmap_filename): 237 | print(' > WARNING: could not find index map file {}, building ' 238 | 'the indices on rank 0 ...'.format(indexmap_filename)) 239 | 240 | # Make sure the types match the helpers input types. 
241 | assert indexed_dataset.doc_idx.dtype == np.int64 242 | assert indexed_dataset.sizes.dtype == np.int32 243 | 244 | # Build samples mapping 245 | verbose = torch.distributed.get_rank() == 0 246 | start_time = time.time() 247 | print_rank_0(' > building samples index mapping for {} ...'.format( 248 | name)) 249 | # First compile and then import. 250 | from data.dataset_utils import compile_helper 251 | compile_helper() 252 | from data import helpers 253 | samples_mapping = helpers.build_mapping( 254 | indexed_dataset.doc_idx, 255 | indexed_dataset.sizes, 256 | num_epochs, 257 | max_num_samples, 258 | max_seq_length - 3, # account for added tokens 259 | short_seq_prob, 260 | seed, 261 | verbose) 262 | print_rank_0(' > done building samples index mapping') 263 | np.save(indexmap_filename, samples_mapping, allow_pickle=True) 264 | print_rank_0(' > saved the index mapping in {}'.format( 265 | indexmap_filename)) 266 | # Make sure all the ranks have built the mapping 267 | print_rank_0(' > elapsed time to build and save samples mapping ' 268 | '(seconds): {:4f}'.format( 269 | time.time() - start_time)) 270 | # This should be a barrier but nccl barrier assumes 271 | # device_index=rank which is not the case for model 272 | # parallel case 273 | counts = torch.cuda.LongTensor([1]) 274 | torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) 275 | assert counts[0].item() == torch.distributed.get_world_size( 276 | group=mpu.get_data_parallel_group()) 277 | 278 | # Load indexed dataset. 279 | print_rank_0(' > loading indexed mapping from {}'.format( 280 | indexmap_filename)) 281 | start_time = time.time() 282 | samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') 283 | print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( 284 | time.time() - start_time)) 285 | print_rank_0(' total number of samples: {}'.format( 286 | samples_mapping.shape[0])) 287 | 288 | return samples_mapping 289 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities for logging and serialization""" 17 | 18 | import os 19 | import random 20 | import time 21 | import numpy as np 22 | import torch 23 | 24 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 25 | from fp16 import FP16_Optimizer 26 | import mpu 27 | import model 28 | 29 | 30 | def print_rank_0(message): 31 | if torch.distributed.is_initialized(): 32 | if torch.distributed.get_rank() == 0: 33 | print(message, flush=True) 34 | else: 35 | print(message, flush=True) 36 | 37 | 38 | def print_args(args): 39 | """Print arguments.""" 40 | 41 | print('arguments:', flush=True) 42 | for arg in vars(args): 43 | dots = '.' 
* (29 - len(arg)) 44 | print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) 45 | 46 | 47 | def print_params_min_max_norm(optimizer, iteration): 48 | """Print min, max, and norm of all parameters.""" 49 | index = 0 50 | rank = torch.distributed.get_rank() 51 | string = 'iteration, rank, index, model-parallel,min, max, norm\n' 52 | optimizer_ = optimizer 53 | if isinstance(optimizer, FP16_Optimizer): 54 | optimizer_ = optimizer.optimizer 55 | for param_group in optimizer_.param_groups: 56 | for param in param_group['params']: 57 | index += 1 58 | min_ = param.data.min() 59 | max_ = param.data.max() 60 | norm = param.data.norm() 61 | string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( 62 | iteration, rank, index, int(param.model_parallel)) 63 | string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) 64 | print(string, flush=True) 65 | 66 | 67 | class Timers: 68 | """Group of timers.""" 69 | 70 | class Timer: 71 | """Timer.""" 72 | 73 | def __init__(self, name): 74 | self.name_ = name 75 | self.elapsed_ = 0.0 76 | self.started_ = False 77 | self.start_time = time.time() 78 | 79 | def start(self): 80 | """Start the timer.""" 81 | assert not self.started_, 'timer has already been started' 82 | torch.cuda.synchronize() 83 | self.start_time = time.time() 84 | self.started_ = True 85 | 86 | def stop(self): 87 | """Stop the timer.""" 88 | assert self.started_, 'timer is not started' 89 | torch.cuda.synchronize() 90 | self.elapsed_ += (time.time() - self.start_time) 91 | self.started_ = False 92 | 93 | def reset(self): 94 | """Reset timer.""" 95 | self.elapsed_ = 0.0 96 | self.started_ = False 97 | 98 | def elapsed(self, reset=True): 99 | """Calculate the elapsed time.""" 100 | started_ = self.started_ 101 | # If the timing in progress, end it first. 102 | if self.started_: 103 | self.stop() 104 | # Get the elapsed time. 105 | elapsed_ = self.elapsed_ 106 | # Reset the elapsed time 107 | if reset: 108 | self.reset() 109 | # If timing was in progress, set it back. 
110 | if started_: 111 | self.start() 112 | return elapsed_ 113 | 114 | def __init__(self): 115 | self.timers = {} 116 | 117 | def __call__(self, name): 118 | if name not in self.timers: 119 | self.timers[name] = self.Timer(name) 120 | return self.timers[name] 121 | 122 | def log(self, names, normalizer=1.0, reset=True): 123 | """Log a group of timers.""" 124 | assert normalizer > 0.0 125 | string = 'time (ms)' 126 | for name in names: 127 | elapsed_time = self.timers[name].elapsed( 128 | reset=reset) * 1000.0/ normalizer 129 | string += ' | {}: {:.2f}'.format(name, elapsed_time) 130 | print_rank_0(string) 131 | 132 | 133 | def report_memory(name): 134 | """Simple GPU memory report.""" 135 | 136 | mega_bytes = 1024.0 * 1024.0 137 | string = name + ' memory (MB)' 138 | string += ' | allocated: {}'.format( 139 | torch.cuda.memory_allocated() / mega_bytes) 140 | string += ' | max allocated: {}'.format( 141 | torch.cuda.max_memory_allocated() / mega_bytes) 142 | string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) 143 | string += ' | max cached: {}'.format( 144 | torch.cuda.max_memory_cached()/ mega_bytes) 145 | print_rank_0(string) 146 | 147 | 148 | def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False): 149 | if release: 150 | d = 'release' 151 | else: 152 | d = '{:d}'.format(iteration) 153 | if zero: 154 | dp_rank = mpu.get_data_parallel_rank() 155 | d += '_zero_dp_rank_{}'.format(dp_rank) 156 | return os.path.join(checkpoints_path, d, 157 | 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())) 158 | 159 | 160 | def ensure_directory_exists(filename): 161 | dirname = os.path.dirname(filename) 162 | if not os.path.exists(dirname): 163 | os.makedirs(dirname) 164 | 165 | 166 | def get_checkpoint_tracker_filename(checkpoints_path): 167 | return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') 168 | 169 | 170 | def save_zero_checkpoint(args, iteration, optimizer): 171 | zero_sd = {'iteration': iteration, 172 | 'optimizer_state_dict': optimizer.state_dict()} 173 | zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) 174 | ensure_directory_exists(zero_checkpoint_name) 175 | torch.save(zero_sd, zero_checkpoint_name) 176 | print(' successfully saved {}'.format(zero_checkpoint_name)) 177 | 178 | def save_checkpoint(iteration, model, optimizer, 179 | lr_scheduler, args): 180 | """Save a model checkpoint.""" 181 | # Only rank zer0 of the data parallel writes to the disk. 182 | if isinstance(model, torchDDP): 183 | model = model.module 184 | 185 | if mpu.get_data_parallel_rank() == 0: 186 | checkpoint_name = get_checkpoint_name(args.save, iteration) 187 | print('global rank {} is saving checkpoint at iteration {:7d} to {}'. 188 | format(torch.distributed.get_rank(), iteration, checkpoint_name)) 189 | 190 | sd = {} 191 | sd['iteration'] = iteration 192 | sd['model'] = model.state_dict() 193 | 194 | # Optimizer stuff. 195 | if not args.no_save_optim: 196 | if optimizer is not None: 197 | sd['optimizer'] = optimizer.state_dict() 198 | if lr_scheduler is not None: 199 | sd['lr_scheduler'] = lr_scheduler.state_dict() 200 | 201 | # rng states. 
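        # Saving the Python, NumPy, torch and CUDA RNG states (plus the model-parallel
        # RNG tracker) lets a resumed run reproduce the same random draws, e.g. dropout
        # masks, as an uninterrupted one.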
202 | if not args.no_save_rng: 203 | sd['random_rng_state'] = random.getstate() 204 | sd['np_rng_state'] = np.random.get_state() 205 | sd['torch_rng_state'] = torch.get_rng_state() 206 | sd['cuda_rng_state'] = torch.cuda.get_rng_state() 207 | sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() 208 | 209 | 210 | ensure_directory_exists(checkpoint_name) 211 | torch.save(sd, checkpoint_name) 212 | print(' successfully saved {}'.format(checkpoint_name)) 213 | 214 | # Wait so everyone is done (necessary) 215 | torch.distributed.barrier() 216 | # And update the latest iteration 217 | if torch.distributed.get_rank() == 0: 218 | tracker_filename = get_checkpoint_tracker_filename(args.save) 219 | with open(tracker_filename, 'w') as f: 220 | f.write(str(iteration)) 221 | # Wait so everyone is done (not necessary) 222 | torch.distributed.barrier() 223 | 224 | def save_ds_checkpoint(iteration, model, args): 225 | """Save a model checkpoint.""" 226 | 227 | sd = {} 228 | sd['iteration'] = iteration 229 | # rng states. 230 | if not args.no_save_rng: 231 | sd['random_rng_state'] = random.getstate() 232 | sd['np_rng_state'] = np.random.get_state() 233 | sd['torch_rng_state'] = torch.get_rng_state() 234 | sd['cuda_rng_state'] = torch.cuda.get_rng_state() 235 | sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() 236 | 237 | model.save_checkpoint(args.save, iteration, client_state = sd) 238 | 239 | 240 | def get_checkpoint_iteration(args): 241 | # Read the tracker file and set the iteration. 242 | tracker_filename = get_checkpoint_tracker_filename(args.load) 243 | if not os.path.isfile(tracker_filename): 244 | print_rank_0('WARNING: could not find the metadata file {} '.format( 245 | tracker_filename)) 246 | print_rank_0(' will not load any checkpoints and will start from ' 247 | 'random') 248 | return 0, False, False 249 | iteration = 0 250 | release = False 251 | with open(tracker_filename, 'r') as f: 252 | metastring = f.read().strip() 253 | try: 254 | iteration = int(metastring) 255 | except ValueError: 256 | release = metastring == 'release' 257 | if not release: 258 | print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( 259 | tracker_filename)) 260 | exit() 261 | 262 | assert iteration > 0 or release, 'error parsing metadata file {}'.format( 263 | tracker_filename) 264 | 265 | return iteration, release, True 266 | 267 | def load_checkpoint_model(model, args): 268 | """Load a model checkpoint.""" 269 | 270 | iteration, release, success = get_checkpoint_iteration(args) 271 | 272 | if not success: 273 | return 0 274 | 275 | # Checkpoint. 276 | checkpoint_name = get_checkpoint_name(args.load, iteration, release) 277 | 278 | if mpu.get_data_parallel_rank() == 0: 279 | print('global rank {} is loading checkpoint {}'.format( 280 | torch.distributed.get_rank(), checkpoint_name)) 281 | 282 | # Load the checkpoint. 283 | sd = torch.load(checkpoint_name, map_location='cpu') 284 | 285 | if isinstance(model, torchDDP): 286 | model = model.module 287 | 288 | # Model. 289 | try: 290 | model.load_state_dict(sd['module']) 291 | except KeyError: 292 | print_rank_0('A metadata file exists but unable to load model ' 293 | 'from checkpoint {}, exiting'.format(checkpoint_name)) 294 | exit() 295 | 296 | torch.distributed.barrier() 297 | if mpu.get_data_parallel_rank() == 0: 298 | print(' successfully loaded {}'.format(checkpoint_name)) 299 | 300 | return iteration 301 | 302 | def load_weights(src, dst, dst2src=False): 303 | """ 304 | Loads weights from src to dst via in place copy. 
305 | src is a huggingface gpt2model, while dst is one of our models. 306 | dst2src=True loads parameters from our models into huggingface's. 307 | ^dst2src is still untested 308 | """ 309 | conv_layer = 'Conv1D' in str(type(src)) 310 | for n, p in src.named_parameters(): 311 | if dst2src: 312 | data = dst._parameters[n].data 313 | load = p.data 314 | else: 315 | data = p.data 316 | load = dst._parameters[n].data 317 | if conv_layer and 'weight' in n: 318 | data = data.t().contiguous() 319 | load.copy_(data) 320 | # dst._parameters[n].data.copy_(data) 321 | 322 | def load_mlp(our, oai, dst2src=False): 323 | load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) 324 | load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) 325 | 326 | def load_attention(our, oai, dst2src=False): 327 | load_weights(oai.c_attn, our.query_key_value, dst2src) 328 | load_weights(oai.c_proj, our.dense, dst2src) 329 | 330 | def load_transformer_layer(our, oai, dst2src=False): 331 | load_weights(oai.ln_1, our.input_layernorm, dst2src) 332 | load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) 333 | load_mlp(our.mlp, oai.mlp, dst2src) 334 | load_attention(our.attention, oai.attn, dst2src) 335 | 336 | def move_weights(our, oai, dst2src=False): 337 | """ 338 | Loads weights from `oai` to `our` via in place copy. 339 | `oai` is a huggingface gpt2model, while `our` is one of our models. 340 | dst2src=True loads parameters from our models into huggingface's. 341 | ^dst2src=True is still untested 342 | """ 343 | # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): 344 | # our=our.module 345 | transformer_model = oai.transformer 346 | load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src) 347 | load_weights(transformer_model.wte, our.word_embeddings, dst2src) 348 | load_weights(transformer_model.wpe, our.position_embeddings, dst2src) 349 | 350 | for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): 351 | load_transformer_layer(our_layer, oai_layer, dst2src) 352 | -------------------------------------------------------------------------------- /data/gpt2_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """GPT2 style dataset.""" 17 | 18 | import os 19 | import time 20 | 21 | import numpy as np 22 | import torch 23 | 24 | from utils import print_rank_0 25 | import mpu 26 | from data.bert_dataset import get_train_valid_test_split_ 27 | from data.indexed_dataset import make_dataset as make_indexed_dataset 28 | 29 | 30 | def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, 31 | train_valid_test_num_samples, 32 | seq_length, seed, skip_warmup): 33 | """Build train, valid, and test datasets.""" 34 | 35 | # Indexed dataset. 
36 | indexed_dataset = get_indexed_dataset_(data_prefix, 37 | data_impl, 38 | skip_warmup) 39 | 40 | total_num_of_documents = indexed_dataset.sizes.shape[0] 41 | splits = get_train_valid_test_split_(splits_string, total_num_of_documents) 42 | 43 | # Print stats about the splits. 44 | print_rank_0(' > dataset split:') 45 | 46 | def print_split_stats(name, index): 47 | print_rank_0(' {}:'.format(name)) 48 | print_rank_0(' document indices in [{}, {}) total of {} ' 49 | 'documents'.format(splits[index], splits[index + 1], 50 | splits[index + 1] - splits[index])) 51 | print_split_stats('train', 0) 52 | print_split_stats('validation', 1) 53 | print_split_stats('test', 2) 54 | 55 | def build_dataset(index, name): 56 | dataset = None 57 | if splits[index + 1] > splits[index]: 58 | documents = np.arange(start=splits[index], stop=splits[index + 1], 59 | step=1, dtype=np.int32) 60 | dataset = GPT2Dataset(name, data_prefix, 61 | documents, indexed_dataset, 62 | train_valid_test_num_samples[index], 63 | seq_length, seed) 64 | return dataset 65 | 66 | train_dataset = build_dataset(0, 'train') 67 | valid_dataset = build_dataset(1, 'valid') 68 | test_dataset = build_dataset(2, 'test') 69 | 70 | return (train_dataset, valid_dataset, test_dataset) 71 | 72 | 73 | def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): 74 | """Build indexed dataset.""" 75 | print_rank_0(' > building dataset index ...') 76 | 77 | start_time = time.time() 78 | indexed_dataset = make_indexed_dataset(data_prefix, 79 | data_impl, 80 | skip_warmup) 81 | print_rank_0(' > finished creating indexed dataset in {:4f} ' 82 | 'seconds'.format(time.time() - start_time)) 83 | print_rank_0(' number of documents: {}'.format( 84 | indexed_dataset.sizes.shape[0])) 85 | 86 | return indexed_dataset 87 | 88 | 89 | class GPT2Dataset(torch.utils.data.Dataset): 90 | 91 | def __init__(self, name, data_prefix, documents, indexed_dataset, 92 | num_samples, seq_length, seed): 93 | 94 | self.name = name 95 | self.indexed_dataset = indexed_dataset 96 | 97 | # Checks 98 | assert np.min(documents) >= 0 99 | assert np.max(documents) < indexed_dataset.sizes.shape[0] 100 | 101 | # Build index mappings. 102 | self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( 103 | self.name, data_prefix, documents, self.indexed_dataset.sizes, 104 | num_samples, seq_length, seed) 105 | 106 | def __len__(self): 107 | # -1 is due to data structure used to retieve the index: 108 | # sample i --> [sample_idx[i], sample_idx[i+1]) 109 | return self.sample_idx.shape[0] - 1 110 | 111 | def __getitem__(self, idx): 112 | # Get the shuffled index. 113 | idx = self.shuffle_idx[idx] 114 | # Start and end documents and offsets. 115 | doc_index_f = self.sample_idx[idx][0] 116 | doc_index_l = self.sample_idx[idx + 1][0] 117 | offset_f = self.sample_idx[idx][1] 118 | offset_l = self.sample_idx[idx + 1][1] 119 | # If we are within the same document, just extract the chunk. 120 | if doc_index_f == doc_index_l: 121 | sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], 122 | offset=offset_f, 123 | length=offset_l - offset_f + 1) 124 | else: 125 | # Otherwise, get the rest of the initial document. 126 | sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], 127 | offset=offset_f)] 128 | # Loop over all in between documents and add the entire document. 129 | for i in range(doc_index_f + 1, doc_index_l): 130 | sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) 131 | # And finally add the relevant portion of last document. 
132 | sample_list.append(self.indexed_dataset.get( 133 | self.doc_idx[doc_index_l], 134 | length=offset_l + 1)) 135 | sample = np.concatenate(sample_list) 136 | 137 | return {'text': np.array(sample, dtype=np.int64)} 138 | 139 | 140 | def _build_index_mappings(name, data_prefix, documents, sizes, 141 | num_samples, seq_length, seed): 142 | """Build doc-idx, sample-idx, and shuffle-idx. 143 | doc-idx: is an array (ordered) of documents to be used in training. 144 | sample-idx: is the start document index and document offset for each 145 | training sample. 146 | shuffle-idx: maps the sample index into a random index into sample-idx. 147 | """ 148 | # Number of tokens in each epoch and number of required epochs. 149 | tokens_per_epoch = _num_tokens(documents, sizes) 150 | num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) 151 | # rng state 152 | np_rng = np.random.RandomState(seed=seed) 153 | 154 | # Filename of the index mappings. 155 | _filename = data_prefix 156 | _filename += '_{}_indexmap'.format(name) 157 | _filename += '_{}ns'.format(num_samples) 158 | _filename += '_{}sl'.format(seq_length) 159 | _filename += '_{}s'.format(seed) 160 | doc_idx_filename = _filename + '_doc_idx.npy' 161 | sample_idx_filename = _filename + '_sample_idx.npy' 162 | shuffle_idx_filename = _filename + '_shuffle_idx.npy' 163 | 164 | # Build the index mappings if they do not exist. 165 | if torch.distributed.get_rank() == 0: 166 | if (not os.path.isfile(doc_idx_filename)) or \ 167 | (not os.path.isfile(sample_idx_filename)) or \ 168 | (not os.path.isfile(shuffle_idx_filename)): 169 | 170 | print_rank_0(' > WARNING: could not find index map files, building ' 171 | 'the indices on rank 0 ...') 172 | # doc-idx. 173 | start_time = time.time() 174 | doc_idx = _build_doc_idx(documents, num_epochs, np_rng) 175 | np.save(doc_idx_filename, doc_idx, allow_pickle=True) 176 | print_rank_0(' > elapsed time to build and save doc-idx mapping ' 177 | '(seconds): {:4f}'.format(time.time() - start_time)) 178 | # sample-idx. 179 | start_time = time.time() 180 | # Use C++ implementation for speed. 181 | # First compile and then import. 182 | from data.dataset_utils import compile_helper 183 | compile_helper() 184 | from data import helpers 185 | assert doc_idx.dtype == np.int32 186 | assert sizes.dtype == np.int32 187 | sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, 188 | num_epochs, tokens_per_epoch) 189 | # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, 190 | # num_epochs, tokens_per_epoch) 191 | np.save(sample_idx_filename, sample_idx, allow_pickle=True) 192 | print_rank_0(' > elapsed time to build and save sample-idx mapping ' 193 | '(seconds): {:4f}'.format(time.time() - start_time)) 194 | # shuffle-idx. 
195 | start_time = time.time() 196 | # -1 is due to data structure used to retrieve the index: 197 | # sample i --> [sample_idx[i], sample_idx[i+1]) 198 | shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) 199 | np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) 200 | print_rank_0(' > elapsed time to build and save shuffle-idx mapping' 201 | ' (seconds): {:4f}'.format(time.time() - start_time)) 202 | 203 | # This should be a barrier but nccl barrier assumes 204 | # device_index=rank which is not the case for the 205 | # model-parallel case 206 | counts = torch.cuda.LongTensor([1]) 207 | torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) 208 | assert counts[0].item() == torch.distributed.get_world_size( 209 | group=mpu.get_data_parallel_group()) 210 | 211 | # Load mappings. 212 | start_time = time.time() 213 | print_rank_0(' > loading doc-idx mapping from {}'.format( 214 | doc_idx_filename)) 215 | doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') 216 | print_rank_0(' > loading sample-idx mapping from {}'.format( 217 | sample_idx_filename)) 218 | sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') 219 | print_rank_0(' > loading shuffle-idx mapping from {}'.format( 220 | shuffle_idx_filename)) 221 | shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') 222 | print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( 223 | time.time() - start_time)) 224 | print_rank_0(' total number of samples: {}'.format( 225 | sample_idx.shape[0])) 226 | print_rank_0(' total number of epochs: {}'.format(num_epochs)) 227 | 228 | return doc_idx, sample_idx, shuffle_idx 229 | 230 | 231 | def _num_tokens(documents, sizes): 232 | """Total number of tokens in the dataset.""" 233 | return np.sum(sizes[documents]) 234 | 235 | 236 | def _num_epochs(tokens_per_epoch, seq_length, num_samples): 237 | """Based on the number of samples and sequence length, calculate how many 238 | epochs will be needed.""" 239 | num_epochs = 0 240 | total_tokens = 0 241 | while True: 242 | num_epochs += 1 243 | total_tokens += tokens_per_epoch 244 | # -1 is because we need to retrieve seq_length + 1 tokens each time 245 | # but the last token will overlap with the first token of the next 246 | # sample except for the last sample. 247 | if ((total_tokens - 1) // seq_length) >= num_samples: 248 | return num_epochs 249 | 250 | 251 | def _build_doc_idx(documents, num_epochs, np_rng): 252 | """Build an array with length = number-of-epochs * number-of-documents. 253 | Each index is mapped to a corresponding document.""" 254 | doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] 255 | doc_idx[:] = documents 256 | doc_idx = doc_idx.reshape(-1) 257 | doc_idx = doc_idx.astype(np.int32) 258 | np_rng.shuffle(doc_idx) 259 | return doc_idx 260 | 261 | 262 | def _build_sample_idx(sizes, doc_idx, seq_length, 263 | num_epochs, tokens_per_epoch): 264 | """Sample index mapping is a 2D array with sizes 265 | [number-of-samples + 1, 2] where [..., 0] contains 266 | the index into `doc_idx` and [..., 1] is the 267 | starting offset in that document.""" 268 | 269 | # Total number of samples. For -1 see comments in `_num_epochs`. 270 | num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length 271 | sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) 272 | 273 | # Index into sample_idx. 274 | sample_index = 0 275 | # Index into doc_idx. 276 | doc_idx_index = 0 277 | # Beginning offset for each document. 
278 | doc_offset = 0 279 | # Start with first document and no offset. 280 | sample_idx[sample_index][0] = doc_idx_index 281 | sample_idx[sample_index][1] = doc_offset 282 | sample_index += 1 283 | while sample_index <= num_samples: 284 | # Start with a fresh sequence. 285 | remaining_seq_length = seq_length + 1 286 | while remaining_seq_length != 0: 287 | # Get the document length. 288 | doc_id = doc_idx[doc_idx_index] 289 | doc_length = sizes[doc_id] - doc_offset 290 | # And add it to the current sequence. 291 | remaining_seq_length -= doc_length 292 | # If we have more than a full sequence, adjust offset and set 293 | # remaining length to zero so we return from the while loop. 294 | # Note that -1 here is for the same reason we have -1 in 295 | # `_num_epochs` calculations. 296 | if remaining_seq_length <= 0: 297 | doc_offset += (remaining_seq_length + doc_length - 1) 298 | remaining_seq_length = 0 299 | else: 300 | # Otherwise, start from the begining of the next document. 301 | doc_idx_index += 1 302 | doc_offset = 0 303 | # Record the sequence. 304 | sample_idx[sample_index][0] = doc_idx_index 305 | sample_idx[sample_index][1] = doc_offset 306 | sample_index += 1 307 | 308 | return sample_idx 309 | 310 | 311 | def _build_shuffle_idx(size, np_rng): 312 | """Build the range [0, size) and shuffle.""" 313 | dtype_ = np.uint32 314 | if size >= (np.iinfo(np.uint32).max - 1): 315 | dtype_ = np.int64 316 | shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) 317 | np_rng.shuffle(shuffle_idx) 318 | return shuffle_idx 319 | -------------------------------------------------------------------------------- /mpu/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | 21 | import math 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | import torch.nn.init as init 26 | from torch.nn.parameter import Parameter 27 | 28 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 29 | 30 | from .initialize import get_model_parallel_rank 31 | from .initialize import get_model_parallel_world_size 32 | from .mappings import copy_to_model_parallel_region 33 | from .mappings import gather_from_model_parallel_region 34 | from .mappings import reduce_from_model_parallel_region 35 | from .mappings import scatter_to_model_parallel_region 36 | from .random import get_cuda_rng_tracker 37 | from .utils import divide 38 | from .utils import split_tensor_along_last_dim 39 | from .utils import VocabUtility 40 | 41 | 42 | def _initialize_affine_weight(weight, output_size, input_size, 43 | per_partition_size, partition_dim, init_method, 44 | stride=1, return_master_weight=False): 45 | """Initialize affine weight for model parallel. 
46 | 47 | Build the master weight on all processes and scatter 48 | the relevant chunk.""" 49 | # If we only use 1 process for model parallelism, bypass scatter. 50 | world_size = get_model_parallel_world_size() 51 | if world_size == 1: 52 | init_method(weight) 53 | if return_master_weight: 54 | return weight 55 | return None 56 | 57 | # Initialize master weight 58 | master_weight = torch.empty(output_size, input_size, 59 | dtype=weight.dtype, 60 | requires_grad=False) 61 | init_method(master_weight) 62 | 63 | # Split and copy 64 | per_partition_per_stride_size = divide(per_partition_size, stride) 65 | weight_list = torch.split(master_weight, per_partition_per_stride_size, 66 | dim=partition_dim) 67 | rank = get_model_parallel_rank() 68 | my_weight_list = weight_list[rank::world_size] 69 | 70 | with torch.no_grad(): 71 | torch.cat(my_weight_list, dim=partition_dim, out=weight) 72 | if return_master_weight: 73 | return master_weight 74 | return None 75 | 76 | 77 | class VocabParallelEmbedding(torch.nn.Module): 78 | """Embedding parallelized in the vocabulary dimension. 79 | 80 | This is mainly adapted from torch.nn.Embedding and all the default 81 | values are kept. 82 | Arguments: 83 | num_embeddings: vocabulary size. 84 | embedding_dim: size of hidden state. 85 | init_method: method to initialize weights. 86 | """ 87 | def __init__(self, num_embeddings, embedding_dim, 88 | init_method=init.xavier_normal_): 89 | super(VocabParallelEmbedding, self).__init__() 90 | # Keep the input dimensions. 91 | self.num_embeddings = num_embeddings 92 | self.embedding_dim = embedding_dim 93 | # Set the defaults for compatibility. 94 | self.padding_idx = None 95 | self.max_norm = None 96 | self.norm_type = 2. 97 | self.scale_grad_by_freq = False 98 | self.sparse = False 99 | self._weight = None 100 | # Divide the weight matrix along the vocabulary dimension. 101 | self.vocab_start_index, self.vocab_end_index = \ 102 | VocabUtility.vocab_range_from_global_vocab_size( 103 | self.num_embeddings, get_model_parallel_rank(), 104 | get_model_parallel_world_size()) 105 | self.num_embeddings_per_partition = self.vocab_end_index - \ 106 | self.vocab_start_index 107 | 108 | # Allocate weights. 109 | self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, 110 | self.embedding_dim)) 111 | self.weight.model_parallel = True 112 | # And initialize. 113 | _initialize_affine_weight( 114 | self.weight, self.num_embeddings, self.embedding_dim, 115 | self.num_embeddings_per_partition, 0, init_method) 116 | 117 | def forward(self, input_): 118 | # Build the mask. 119 | input_mask = (input_ < self.vocab_start_index) | \ 120 | (input_ >= self.vocab_end_index) 121 | # Mask the input. 122 | masked_input = input_.clone() - self.vocab_start_index 123 | masked_input[input_mask] = 0 124 | # Get the embeddings. 125 | output_parallel = F.embedding(masked_input, self.weight, 126 | self.padding_idx, self.max_norm, 127 | self.norm_type, self.scale_grad_by_freq, 128 | self.sparse) 129 | # Mask the output embedding. 130 | output_parallel[input_mask, :] = 0.0 131 | # Reduce across all the model parallel GPUs. 132 | output = reduce_from_model_parallel_region(output_parallel) 133 | return output 134 | 135 | 136 | class ParallelEmbedding(torch.nn.Module): 137 | """Embedding parallelized in the embedding dimension. 138 | 139 | This is mainly adapted from torch.nn.Embedding and all the default 140 | values are kept. 141 | Arguments: 142 | num_embeddings: vocabulary size. 143 | embedding_dim: size of hidden state. 
144 | init_method: method to initialize weights. 145 | """ 146 | def __init__(self, num_embeddings, embedding_dim, 147 | init_method=init.xavier_normal_, 148 | keep_master_weight_for_test=False): 149 | super(ParallelEmbedding, self).__init__() 150 | # Keep the input dimensions. 151 | self.num_embeddings = num_embeddings 152 | self.embedding_dim = embedding_dim 153 | # Set some defaults for compatibility. 154 | self.padding_idx = None 155 | self.max_norm = None 156 | self.norm_type = 2. 157 | self.scale_grad_by_freq = False 158 | self.sparse = False 159 | self._weight = None 160 | # Divide the weight matrix along the embedding dimension. 161 | world_size = get_model_parallel_world_size() 162 | self.embedding_dim_per_partition = divide(self.embedding_dim, 163 | world_size) 164 | 165 | # Allocate weights. 166 | self.weight = Parameter(torch.Tensor(self.num_embeddings, 167 | self.embedding_dim_per_partition)) 168 | self.weight.model_parallel = True 169 | # And initialize. 170 | _initialize_affine_weight( 171 | self.weight, self.num_embeddings, self.embedding_dim, 172 | self.embedding_dim_per_partition, 1, init_method, 173 | stride=1, return_master_weight=False) 174 | 175 | def forward(self, input_): 176 | input_parallel = copy_to_model_parallel_region(input_) 177 | output_parallel = F.embedding(input_parallel, self.weight, 178 | self.padding_idx, self.max_norm, 179 | self.norm_type, self.scale_grad_by_freq, 180 | self.sparse) 181 | output = gather_from_model_parallel_region(output_parallel) 182 | return output 183 | 184 | 185 | class ColumnParallelLinear(torch.nn.Module): 186 | """Linear layer with column parallelism. 187 | 188 | The linear layer is defined as Y = XA + b. A is parallelized along 189 | its second dimension as A = [A_1, ..., A_p]. 190 | 191 | Arguments: 192 | input_size: first dimension of matrix A. 193 | output_size: second dimension of matrix A. 194 | bias: If true, add bias. 195 | gather_output: If true, call all-gather on the output and make Y available 196 | to all GPUs, otherwise, every GPU will have its output 197 | which is Y_i = XA_i 198 | init_method: method to initialize weights. Note that bias is always set 199 | to zero. 200 | stride: For the strided linear layers. 201 | keep_master_weight_for_test: This was added for testing and should be 202 | set to False. It returns the master weights 203 | used for initialization. 204 | """ 205 | def __init__(self, input_size, output_size, bias=True, gather_output=True, 206 | init_method=init.xavier_normal_, stride=1, 207 | keep_master_weight_for_test=False): 208 | super(ColumnParallelLinear, self).__init__() 209 | 210 | # Keep input parameters 211 | self.input_size = input_size 212 | self.output_size = output_size 213 | self.gather_output = gather_output 214 | # Divide the weight matrix along the last dimension. 215 | world_size = get_model_parallel_world_size() 216 | self.output_size_per_partition = divide(output_size, world_size) 217 | 218 | # Parameters. 219 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 220 | # we allocate the transpose. 221 | self.weight = Parameter(torch.Tensor(self.output_size_per_partition, 222 | self.input_size)) 223 | self.weight.model_parallel = True 224 | if bias: 225 | self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) 226 | self.bias.model_parallel = True 227 | # Always initialize bias to zero. 228 | with torch.no_grad(): 229 | self.bias.zero_() 230 | else: 231 | self.register_parameter('bias', None) 232 | 233 | # Initialize weight. 
234 | self.master_weight = _initialize_affine_weight( 235 | self.weight, self.output_size, self.input_size, 236 | self.output_size_per_partition, 0, init_method, 237 | stride=stride, return_master_weight=keep_master_weight_for_test) 238 | 239 | def forward(self, input_): 240 | # Set up backprop all-reduce. 241 | input_parallel = copy_to_model_parallel_region(input_) 242 | # Matrix multiply. 243 | output_parallel = F.linear(input_parallel, self.weight, self.bias) 244 | if self.gather_output: 245 | # All-gather across the partitions. 246 | output = gather_from_model_parallel_region(output_parallel) 247 | else: 248 | output = output_parallel 249 | return output 250 | 251 | 252 | class RowParallelLinear(torch.nn.Module): 253 | """Linear layer with row parallelism. 254 | 255 | The linear layer is defined as Y = XA + b. A is parallelized along 256 | its first dimension and X along its second dimension as: 257 | - - 258 | | A_1 | 259 | | . | 260 | A = | . | X = [X_1, ..., X_p] 261 | | . | 262 | | A_p | 263 | - - 264 | Arguments: 265 | input_size: first dimension of matrix A. 266 | output_size: second dimension of matrix A. 267 | bias: If true, add bias. Note that bias is not parallelized. 268 | input_is_parallel: If true, we assume that the input is already 269 | split across the GPUs and we do not split 270 | again. 271 | init_method: method to initialize weights. Note that bias is always set 272 | to zero. 273 | stride: For the strided linear layers. 274 | keep_master_weight_for_test: This was added for testing and should be 275 | set to False. It returns the master weights 276 | used for initialization. 277 | """ 278 | def __init__(self, input_size, output_size, bias=True, 279 | input_is_parallel=False, 280 | init_method=init.xavier_normal_, stride=1, 281 | keep_master_weight_for_test=False): 282 | super(RowParallelLinear, self).__init__() 283 | 284 | # Keep input parameters 285 | self.input_size = input_size 286 | self.output_size = output_size 287 | self.input_is_parallel = input_is_parallel 288 | # Divide the weight matrix along the last dimension. 289 | world_size = get_model_parallel_world_size() 290 | self.input_size_per_partition = divide(input_size, world_size) 291 | 292 | # Parameters. 293 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 294 | # we allocate the transpose. 295 | self.weight = Parameter(torch.Tensor(self.output_size, 296 | self.input_size_per_partition)) 297 | self.weight.model_parallel = True 298 | if bias: 299 | self.bias = Parameter(torch.Tensor(self.output_size)) 300 | # Always initialize bias to zero. 301 | with torch.no_grad(): 302 | self.bias.zero_() 303 | else: 304 | self.register_parameter('bias', None) 305 | 306 | # Initialize weight. 307 | self.master_weight = _initialize_affine_weight( 308 | self.weight, self.output_size, self.input_size, 309 | self.input_size_per_partition, 1, init_method, 310 | stride=stride, return_master_weight=keep_master_weight_for_test) 311 | 312 | def forward(self, input_): 313 | # Set up backprop all-reduce. 314 | if self.input_is_parallel: 315 | input_parallel = input_ 316 | else: 317 | input_parallel = scatter_to_model_parallel_region(input_) 318 | # Matrix multiply. 319 | output_parallel = F.linear(input_parallel, self.weight) 320 | # All-reduce across all the partitions. 
321 | output_ = reduce_from_model_parallel_region(output_parallel) 322 | if self.bias is not None: 323 | output = output_ + self.bias 324 | else: 325 | output = output_ 326 | return output 327 | 328 | --------------------------------------------------------------------------------
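
A minimal usage sketch of the Timers helper defined in utils.py above. It assumes a CUDA device is available (both start() and stop() call torch.cuda.synchronize()) and that print_rank_0 can run, i.e. the script is single-process or has already initialized torch.distributed; the timed workload here is only a placeholder sleep, not code from the repository:

import time

from utils import Timers  # utils.py at the repository root

timers = Timers()
timers('sleep').start()   # synchronizes CUDA, then records the start time
time.sleep(0.1)           # stand-in for a forward/backward step
timers('sleep').stop()
timers.log(['sleep'])     # prints roughly: time (ms) | sleep: 100.xx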
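
For orientation, the checkpoint helpers in utils.py (get_checkpoint_name, get_checkpoint_tracker_filename, save_checkpoint, save_zero_checkpoint) produce an on-disk layout like the sketch below, assuming --save points at /path/to/ckpt, iteration 5000, model-parallel rank 0, and data-parallel rank 0; the paths are illustrative, not taken from the repository:

/path/to/ckpt/latest_checkpointed_iteration.txt                # written by global rank 0; holds the latest iteration number (the loader also accepts the literal string 'release')
/path/to/ckpt/5000/mp_rank_00_model_states.pt                  # model, optimizer, lr-scheduler and rng state; one file per model-parallel rank
/path/to/ckpt/5000_zero_dp_rank_0/mp_rank_00_model_states.pt   # optimizer shard written by save_zero_checkpoint (zero=True), one per data-parallel rank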
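
The index-mapping logic in data/gpt2_dataset.py is easiest to see on a tiny example. The sketch below is a pure-Python restatement of _build_sample_idx written for illustration only (training normally uses the compiled helpers.build_sample_idx); the function name and the toy numbers are not part of the repository:

import numpy as np

def build_sample_idx_reference(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
    # Restates _build_sample_idx: each row of sample_idx is (document index, offset)
    # marking where a training sample starts; sample i spans [sample_idx[i], sample_idx[i+1]).
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
    sample_idx[0] = (0, 0)              # first sample starts at document 0, offset 0
    sample_index, doc_idx_index, doc_offset = 1, 0, 0
    while sample_index <= num_samples:
        remaining = seq_length + 1      # each sample needs seq_length + 1 tokens
        while remaining != 0:
            doc_length = sizes[doc_idx[doc_idx_index]] - doc_offset
            remaining -= doc_length
            if remaining <= 0:          # the current document finishes this sample
                doc_offset += remaining + doc_length - 1
                remaining = 0
            else:                       # move on to the next document
                doc_idx_index += 1
                doc_offset = 0
        sample_idx[sample_index] = (doc_idx_index, doc_offset)
        sample_index += 1
    return sample_idx

# Three documents of 5, 3 and 8 tokens, one epoch, seq_length = 4:
print(build_sample_idx_reference(np.array([5, 3, 8]), np.array([0, 1, 2]),
                                 seq_length=4, num_epochs=1, tokens_per_epoch=16))
# [[0 0]
#  [0 4]
#  [2 0]
#  [2 4]]
# i.e. sample 0 is tokens 0..4 of document 0; sample 1 starts at token 4 of
# document 0 and runs through documents 1 and 2; sample 2 is tokens 0..4 of
# document 2. GPT2Dataset.__getitem__ stitches these chunks back together.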
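
Finally, a sketch of how ColumnParallelLinear and RowParallelLinear from mpu/layers.py compose. The class below is illustrative only: it assumes the model-parallel groups have been initialized (see mpu/initialize.py) and that 4 * hidden_size is divisible by the model-parallel world size; the attribute names mirror dense_h_to_4h / dense_4h_to_h as used by load_mlp in utils.py, but this particular module is not part of the repository. Because the column-parallel layer keeps its output sharded (gather_output=False) and the row-parallel layer consumes it as-is (input_is_parallel=True), the forward pass needs only the single all-reduce performed inside RowParallelLinear.

import torch
import torch.nn.functional as F

from mpu.layers import ColumnParallelLinear, RowParallelLinear

class ParallelMLP(torch.nn.Module):
    """Illustrative 2-layer MLP sharded across the model-parallel group."""
    def __init__(self, hidden_size):
        super().__init__()
        # Column-parallel: each rank holds a slice of the 4h output dimension.
        self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size,
                                                  gather_output=False)
        # Row-parallel: each rank consumes its slice and the results are all-reduced.
        self.dense_4h_to_h = RowParallelLinear(4 * hidden_size, hidden_size,
                                               input_is_parallel=True)

    def forward(self, hidden_states):
        return self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(hidden_states)))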