├── .dockerignore ├── .gitignore ├── .style.yapf ├── LICENSE ├── README.md ├── benchmarks ├── make_training_seqlen_plots.py ├── paddle_train.py └── torch_train.py ├── docker ├── build.sh ├── interactive.sh ├── ngc_paddle.Dockerfile └── ngc_pyt.Dockerfile ├── docs └── images │ ├── binning.gif │ ├── binning_perf.gif │ ├── preprocess_perf.gif │ └── summary.gif ├── examples ├── local_example.sh └── slurm_example.sub ├── lddl ├── __init__.py ├── dask │ ├── __init__.py │ ├── bart │ │ ├── __init__.py │ │ └── pretrain.py │ ├── bert │ │ ├── DASK_LICENSE.txt │ │ ├── __init__.py │ │ ├── binning.py │ │ └── pretrain.py │ ├── load_balance.py │ └── readers.py ├── download │ ├── __init__.py │ ├── books.py │ ├── common_crawl.py │ ├── openwebtext.py │ ├── utils.py │ └── wikipedia.py ├── paddle │ ├── __init__.py │ ├── bert.py │ ├── dataloader.py │ ├── datasets.py │ ├── log.py │ └── utils.py ├── random.py ├── torch │ ├── __init__.py │ ├── bert.py │ ├── dataloader.py │ ├── datasets.py │ ├── log.py │ └── utils.py ├── torch_mp │ ├── __init__.py │ ├── bert.py │ ├── dataloader.py │ ├── datasets.py │ ├── log.py │ └── utils.py ├── types.py └── utils.py └── setup.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Dask cache 141 | dask-worker-space/ 142 | 143 | # Data downloaded and generated when running the examples. 144 | data/ 145 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Dask cache 141 | dask-worker-space/ 142 | 143 | # Data downloaded and generated when running the examples. 144 | data/ 145 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | indent_width = 2 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | SPDX-License-Identifier: MIT 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a 5 | copy of this software and associated documentation files (the "Software"), 6 | to deal in the Software without restriction, including without limitation 7 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | and/or sell copies of the Software, and to permit persons to whom the 9 | Software is furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /benchmarks/make_training_seqlen_plots.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import argparse 26 | import os 27 | import random 28 | import numpy as np 29 | import matplotlib.pyplot as plt 30 | 31 | from lddl.utils import expand_outdir_and_mkdir, get_all_files_paths_under 32 | 33 | 34 | def collect_data(args): 35 | npzs = [ 36 | fp for fp in get_all_files_paths_under(args.in_dir) 37 | if 'lens_' in fp and os.path.splitext(fp)[1] == '.npz' 38 | ] 39 | min_lens, max_lens = {}, {} 40 | seq_len_hist, padded_zero_hist = None, None 41 | for npz in npzs: 42 | rank = int(os.path.splitext(os.path.basename(npz))[0].split('_')[1]) 43 | with np.load(npz) as data: 44 | min_lens[rank] = data['min_lens'] 45 | max_lens[rank] = data['max_lens'] 46 | if seq_len_hist is None: 47 | seq_len_hist = data['seq_len_hist'] 48 | else: 49 | seq_len_hist += data['seq_len_hist'] 50 | if padded_zero_hist is None: 51 | padded_zero_hist = data['padded_zero_hist'] 52 | else: 53 | padded_zero_hist += data['padded_zero_hist'] 54 | assert max_lens[rank].shape == min_lens[rank].shape 55 | 56 | return min_lens, max_lens, seq_len_hist, padded_zero_hist 57 | 58 | 59 | def plot_rank_diff(args, min_lens, max_lens): 60 | """ Make sure the diff between min seq lens and max seq lens is smaller than 61 | the bin size. 62 | 63 | min_lens and max_lens are dict[int] -> np.array that map rank number to 64 | the list of max and min seq lens of all training iterations. 65 | """ 66 | rank_arrays = [] 67 | diffs = [] 68 | ranks = list(sorted(min_lens.keys())) 69 | for rank in ranks: 70 | diffs.append(max_lens[rank] - min_lens[rank]) 71 | rank_arrays.append(np.full(min_lens[rank].shape, rank, dtype=np.uint16)) 72 | rank_arrays = np.concatenate(rank_arrays) 73 | diffs = np.concatenate(diffs) 74 | plt.scatter(rank_arrays, diffs, s=0.1) 75 | plt.xlabel('rank') 76 | plt.xticks(ranks) 77 | plt.ylabel('diff') 78 | plt.yticks(np.arange(0, diffs.max() + 1, 1)) 79 | plt.title('rank vs. diff') 80 | plt.grid() 81 | plt.savefig(os.path.join(args.out_dir, 'rank_dist.png')) 82 | plt.close() 83 | 84 | 85 | def plot_min_max_lens(args, min_lens, max_lens): 86 | """ Make sure the min and max seq lens are limited by the bin size. 87 | """ 88 | ranks = list(sorted(min_lens.keys())) 89 | for rank in ranks: 90 | plt.scatter(min_lens[rank], max_lens[rank], s=0.1) 91 | plt.xlabel('min_lens') 92 | plt.xticks(np.arange(0, min_lens[rank].max() + args.bin_size, 93 | args.bin_size)) 94 | plt.ylabel('max_lens') 95 | plt.yticks(np.arange(0, max_lens[rank].max() + args.bin_size, 96 | args.bin_size)) 97 | plt.title('min_lens vs. max_lens') 98 | plt.grid() 99 | plt.savefig(os.path.join(args.out_dir, 'min_max_lens_{}.png'.format(rank))) 100 | plt.close() 101 | 102 | 103 | def plot_global_diff(args, min_lens, max_lens): 104 | """ Make sure that each rank chooses the same bin in each iteration. 
105 | """ 106 | ranks = list(sorted(min_lens.keys())) 107 | global_min_lens = np.stack([min_lens[rank] for rank in ranks], axis=-1) 108 | global_max_lens = np.stack([max_lens[rank] for rank in ranks], axis=-1) 109 | diffs = global_max_lens.max(axis=-1) - global_min_lens.min(axis=-1) 110 | plt.scatter(np.full(diffs.shape, 0, dtype=np.uint8), diffs, s=0.1) 111 | plt.xticks([0]) 112 | plt.ylabel('diff') 113 | plt.yticks(np.arange(0, diffs.max() + 1, 1)) 114 | plt.title('global diff') 115 | plt.grid() 116 | plt.savefig(os.path.join(args.out_dir, 'global_diff.png')) 117 | plt.close() 118 | 119 | 120 | def plot_seq_len_hist(args, seq_len_hist): 121 | hist = [] 122 | xticks = [] 123 | for start in range(1, seq_len_hist.shape[0], args.seq_len_hist_bin): 124 | n = 0 125 | for seq_len in range(start, start + args.seq_len_hist_bin): 126 | n += seq_len_hist[seq_len] 127 | hist.append(n) 128 | xticks.append('{}-{}'.format(start, start + args.seq_len_hist_bin - 1)) 129 | plt.figure(figsize=(20, 5)) 130 | plt.bar(xticks, hist) 131 | plt.xlabel('seq_lens') 132 | plt.ylabel('# Samples') 133 | plt.title('Sequence Length Histogram') 134 | plt.grid() 135 | plt.savefig(os.path.join(args.out_dir, 'seq_len_hist.png')) 136 | plt.close() 137 | 138 | 139 | def plot_padded_zero_hist(args, padded_zero_hist): 140 | plt.bar(np.arange(0, len(padded_zero_hist)), padded_zero_hist) 141 | plt.xlabel('# zeros in a sequence') 142 | plt.ylabel('# samples') 143 | plt.title('# zeros in a sequence vs. # samples') 144 | plt.grid() 145 | plt.savefig(os.path.join(args.out_dir, 'padded_zero_hist.png')) 146 | plt.close() 147 | 148 | 149 | def hist_sum(hist): 150 | s = 0 151 | for v in range(hist.shape[0]): 152 | s += v * hist[v] 153 | return s 154 | 155 | 156 | def calculate_padded_zero_ratio(padded_zero_hist, seq_len_hist): 157 | num_zeros = hist_sum(padded_zero_hist) 158 | num_tokens = hist_sum(seq_len_hist) 159 | print('padded_zeros : tokens = {} : {} = {} : 1'.format( 160 | num_zeros, num_tokens, num_zeros / num_tokens)) 161 | 162 | 163 | def main(args): 164 | args.out_dir = expand_outdir_and_mkdir(args.out_dir) 165 | min_lens, max_lens, seq_len_hist, padded_zero_hist = collect_data(args) 166 | plot_rank_diff(args, min_lens, max_lens) 167 | plot_min_max_lens(args, min_lens, max_lens) 168 | plot_global_diff(args, min_lens, max_lens) 169 | plot_seq_len_hist(args, seq_len_hist) 170 | plot_padded_zero_hist(args, padded_zero_hist) 171 | calculate_padded_zero_ratio(padded_zero_hist, seq_len_hist) 172 | 173 | 174 | def attach_args(parser=argparse.ArgumentParser()): 175 | parser.add_argument('--in-dir', type=str, required=True) 176 | parser.add_argument('--out-dir', type=str, default="./fig") 177 | parser.add_argument('--bin-size', type=int, default=32) 178 | parser.add_argument('--seq-len-hist-bin', type=int, default=32) 179 | 180 | return parser 181 | 182 | 183 | if __name__ == "__main__": 184 | main(attach_args().parse_args()) 185 | -------------------------------------------------------------------------------- /benchmarks/paddle_train.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import argparse 26 | import logging 27 | import numpy as np 28 | import os 29 | import time 30 | from transformers import BertTokenizerFast 31 | import paddle.distributed as dist 32 | 33 | from lddl.paddle import get_bert_pretrain_data_loader 34 | from lddl.paddle.utils import barrier, get_rank, get_world_size 35 | from lddl.utils import mkdir 36 | 37 | 38 | def get_batch_seq_lens(attention_mask): 39 | return attention_mask.sum(axis=1) 40 | 41 | 42 | class AverageMeter: 43 | """ 44 | Computes and stores the average and current value 45 | """ 46 | 47 | def __init__(self, warmup=0, keep=False): 48 | self.reset() 49 | self.warmup = warmup 50 | self.keep = keep 51 | 52 | def reset(self): 53 | self.val = 0 54 | self.avg = 0 55 | self.max = float('-inf') 56 | self.min = float('inf') 57 | self.sum = 0 58 | self.count = 0 59 | self.iters = 0 60 | self.vals = [] 61 | 62 | def update(self, val, n=1): 63 | self.iters += 1 64 | self.val = val 65 | 66 | if self.iters > self.warmup: 67 | self.sum += val * n 68 | self.max = max(val, self.max) 69 | self.min = min(val, self.min) 70 | self.count += n 71 | self.avg = self.sum / self.count 72 | if self.keep: 73 | self.vals.append(val) 74 | 75 | 76 | class Histogram: 77 | """ 78 | Computes and stores the histogram of values. 
79 | """ 80 | 81 | def __init__(self): 82 | self.hist = np.zeros((1,), dtype=np.uint64) 83 | 84 | def update(self, val, n=1): 85 | if val >= self.hist.shape[0]: 86 | new_hist = np.zeros((val + 1,), dtype=np.uint64) 87 | new_hist[:self.hist.shape[0]] = self.hist[:] 88 | self.hist = new_hist 89 | self.hist[val] += n 90 | 91 | def update_with_tensor(self, t): 92 | for v in t.flatten().tolist(): 93 | self.update(v) 94 | 95 | 96 | def main(args): 97 | 98 | dist.init_parallel_env() 99 | 100 | world_size = get_world_size() 101 | if get_rank() == 0 and args.seq_len_dir is not None: 102 | mkdir(args.seq_len_dir) 103 | 104 | loader = get_bert_pretrain_data_loader( 105 | args.path, 106 | shuffle_buffer_size=args.shuffle_buffer_size, 107 | shuffle_buffer_warmup_factor=args.shuffle_buffer_warmup_factor, 108 | vocab_file=args.vocab_file, 109 | data_loader_kwargs={ 110 | 'batch_size': args.batch_size, 111 | 'num_workers': args.workers, 112 | 'prefetch_factor': args.prefetch 113 | }, 114 | mlm_probability=args.mlm_probability, 115 | base_seed=args.seed, 116 | log_dir=args.log_dir, 117 | log_level=getattr(logging, args.log_level), 118 | return_raw_samples=args.debug, 119 | start_epoch=args.start_epoch, 120 | sequence_length_alignment=args.sequence_length_alignment, 121 | ignore_index=args.ignore_index, 122 | ) 123 | if os.path.isfile(args.vocab_file): 124 | test_tokenizer = BertTokenizerFast(args.vocab_file) 125 | else: 126 | test_tokenizer = BertTokenizerFast.from_pretrained(args.vocab_file) 127 | 128 | meter = AverageMeter(warmup=args.warmup) 129 | 130 | lens_shape = (args.epochs, min(len(loader), args.iters_per_epoch)) 131 | min_lens, max_lens, batch_sizes, padded_lens = ( 132 | np.zeros(lens_shape, dtype=np.uint16), 133 | np.zeros(lens_shape, dtype=np.uint16), 134 | np.zeros(lens_shape, dtype=np.uint16), 135 | np.zeros(lens_shape, dtype=np.uint16), 136 | ) 137 | seq_len_hist = Histogram() 138 | padded_zero_hist = Histogram() 139 | 140 | step = 0 141 | for epoch in range(args.start_epoch, args.start_epoch + args.epochs): 142 | barrier() 143 | epoch_timer_start = time.time() 144 | batch_timer_start = time.time() 145 | total_samples = 0 146 | for i, data in enumerate(loader): 147 | step += 1 148 | if not args.debug: 149 | (input_ids, token_type_ids, attention_mask, masked_lm_labels, 150 | next_sentence_labels) = ( 151 | data['input_ids'], 152 | data['token_type_ids'], 153 | data['attention_mask'], 154 | data['masked_lm_labels'], 155 | data['next_sentence_labels'], 156 | ) 157 | 158 | batch_timer_stop = time.time() 159 | elapsed = batch_timer_stop - batch_timer_start 160 | meter.update(elapsed) 161 | 162 | if args.debug: 163 | current_samples = len(data[0]) * world_size 164 | else: 165 | current_samples = input_ids.shape[0] * world_size 166 | # mask shape: [batch, 1, 1, seq_len] -> [batch, seq_len] 167 | assert attention_mask.dim() == 4 168 | attention_mask = attention_mask.squeeze(axis=[1, 2]) 169 | assert input_ids.shape == token_type_ids.shape 170 | assert input_ids.shape == attention_mask.shape 171 | assert input_ids.shape == masked_lm_labels.shape 172 | # next_sentence_labels shape: [batch, 1] 173 | assert next_sentence_labels.dim() == 2 174 | assert next_sentence_labels.shape[1] == 1 175 | assert input_ids.shape[0] == next_sentence_labels.shape[0] 176 | seq_lens = get_batch_seq_lens(attention_mask) 177 | seq_len_hist.update_with_tensor(seq_lens) 178 | ( 179 | min_lens[epoch - args.start_epoch, i], 180 | max_lens[epoch - args.start_epoch, i], 181 | ) = seq_lens.min(), seq_lens.max() 182 | 
batch_sizes[epoch - args.start_epoch, i] = input_ids.shape[0] 183 | padded_lens[epoch - args.start_epoch, i] = input_ids.shape[1] 184 | padded_zero_hist.update_with_tensor(input_ids.shape[1] - seq_lens) 185 | 186 | total_samples += current_samples 187 | current_throughput = current_samples / elapsed 188 | if (i + 1) % args.log_freq == 0 and get_rank() == 0: 189 | avg_throughput = total_samples / meter.sum 190 | print('avg_throughput={}, avg_latency={} ms, ' 191 | 'min_latency={} ms, max_latency={} ms, ' 192 | 'current_throughput={}, current_latency={} ms'.format( 193 | avg_throughput, 194 | meter.avg * 1000, 195 | meter.min * 1000, 196 | meter.max * 1000, 197 | current_throughput, 198 | elapsed * 1000, 199 | )) 200 | if args.debug: 201 | print('len(data[0])={}'.format(len(data[0]))) 202 | print('sample=({} {} - {})'.format( 203 | data[0][0], 204 | data[1][0], 205 | data[2][0], 206 | )) 207 | else: 208 | print("Min length={} Max length={} Diff={}".format( 209 | min_lens[epoch - args.start_epoch, i], 210 | max_lens[epoch - args.start_epoch, i], 211 | max_lens[epoch - args.start_epoch, i] - 212 | min_lens[epoch - args.start_epoch, i], 213 | )) 214 | print('input_ids.shape={}'.format(input_ids.shape)) 215 | print('input_ids[0]={}'.format(input_ids[0])) 216 | print('convert_ids_to_tokens(input_ids[0])={}'.format( 217 | test_tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))) 218 | print('token_type_ids[0]={}'.format(token_type_ids[0])) 219 | print('attention_mask[0]={}'.format(attention_mask[0])) 220 | print('masked_lm_labels[0]={}'.format(masked_lm_labels[0])) 221 | print('next_sentence_labels[0]={}'.format(next_sentence_labels[0])) 222 | mask = masked_lm_labels[0] != args.ignore_index 223 | print(f"mask: {mask}") 224 | for i in range(0, mask.shape[0]): 225 | if mask[i]: 226 | input_ids[0, i] = masked_lm_labels[0, i] 227 | print('original sequence={}'.format( 228 | test_tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))) 229 | barrier() 230 | batch_timer_start = time.time() 231 | if step >= args.iters_per_epoch: 232 | break 233 | epoch_timer_stop = time.time() 234 | epoch_elapsed = epoch_timer_stop - epoch_timer_start 235 | if get_rank() == 0: 236 | avg_throughput = total_samples / meter.sum 237 | print('epoch={}, epoch_elapsed={}, avg_throughput={}, ' 238 | 'total_samples={}'.format( 239 | epoch, 240 | epoch_elapsed, 241 | avg_throughput, 242 | total_samples, 243 | )) 244 | assert meter.iters == min(len(loader), args.iters_per_epoch) 245 | meter.reset() 246 | 247 | if args.seq_len_dir is not None: 248 | # Save the sequence lengths to file 249 | np.savez_compressed( 250 | os.path.join(args.seq_len_dir, 'lens_{}.npz'.format(get_rank())), 251 | min_lens=min_lens, 252 | max_lens=max_lens, 253 | batch_sizes=batch_sizes, 254 | padded_lens=padded_lens, 255 | seq_len_hist=seq_len_hist.hist, 256 | padded_zero_hist=padded_zero_hist.hist, 257 | ) 258 | 259 | 260 | def attach_args(parser=argparse.ArgumentParser()): 261 | parser.add_argument('--path', type=str, required=True) 262 | parser.add_argument('--batch-size', type=int, default=64) 263 | parser.add_argument('--workers', type=int, default=4) 264 | parser.add_argument('--warmup', type=int, default=0) 265 | parser.add_argument('--epochs', type=int, default=2) 266 | parser.add_argument('--iters-per-epoch', type=int, default=float('inf')) 267 | parser.add_argument('--prefetch', type=int, default=2) 268 | parser.add_argument('--mlm-probability', type=float, default=0.15) 269 | parser.add_argument('--shuffle-buffer-size', type=int, 
default=16384) 270 | parser.add_argument('--shuffle-buffer-warmup-factor', type=int, default=16) 271 | parser.add_argument('--vocab-file', type=str, required=True) 272 | parser.add_argument('--seed', type=int, default=127) 273 | parser.add_argument('--start-epoch', type=int, default=0) 274 | parser.add_argument('--debug', action='store_true', default=False) 275 | parser.add_argument('--log-freq', type=int, default=1000) 276 | parser.add_argument('--log-dir', type=str, default=None) 277 | parser.add_argument( 278 | '--log-level', 279 | type=str, 280 | choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], 281 | default='WARNING', 282 | ) 283 | parser.add_argument('--seq-len-dir', type=str, default=None) 284 | parser.add_argument('--sequence-length-alignment', type=int, default=8) 285 | parser.add_argument('--ignore-index', type=int, default=-1) 286 | return parser 287 | 288 | 289 | if __name__ == '__main__': 290 | main(attach_args().parse_args()) 291 | -------------------------------------------------------------------------------- /benchmarks/torch_train.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import argparse 26 | import logging 27 | import numpy as np 28 | import os 29 | import random 30 | import time 31 | import torch 32 | from transformers import BertTokenizerFast 33 | 34 | from lddl.torch import get_bert_pretrain_data_loader 35 | from lddl.torch.utils import barrier, get_rank 36 | from lddl.utils import mkdir 37 | 38 | 39 | def get_batch_seq_lens(attention_mask): 40 | return attention_mask.sum(dim=1).int() 41 | 42 | 43 | class AverageMeter: 44 | """ 45 | Computes and stores the average and current value 46 | """ 47 | 48 | def __init__(self, warmup=0, keep=False): 49 | self.reset() 50 | self.warmup = warmup 51 | self.keep = keep 52 | 53 | def reset(self): 54 | self.val = 0 55 | self.avg = 0 56 | self.max = float('-inf') 57 | self.min = float('inf') 58 | self.sum = 0 59 | self.count = 0 60 | self.iters = 0 61 | self.vals = [] 62 | 63 | def update(self, val, n=1): 64 | self.iters += 1 65 | self.val = val 66 | 67 | if self.iters > self.warmup: 68 | self.sum += val * n 69 | self.max = max(val, self.max) 70 | self.min = min(val, self.min) 71 | self.count += n 72 | self.avg = self.sum / self.count 73 | if self.keep: 74 | self.vals.append(val) 75 | 76 | 77 | class Histogram: 78 | """ 79 | Computes and stores the histogram of values. 80 | """ 81 | 82 | def __init__(self): 83 | self.hist = np.zeros((1,), dtype=np.uint64) 84 | 85 | def update(self, val, n=1): 86 | if val >= self.hist.shape[0]: 87 | new_hist = np.zeros((val + 1,), dtype=np.uint64) 88 | new_hist[:self.hist.shape[0]] = self.hist[:] 89 | self.hist = new_hist 90 | self.hist[val] += n 91 | 92 | def update_with_tensor(self, t): 93 | for v in t.flatten().tolist(): 94 | self.update(v) 95 | 96 | 97 | def main(args): 98 | torch.cuda.set_device(args.local_rank) 99 | world_size = int(os.getenv('WORLD_SIZE', 1)) 100 | if world_size > 1: 101 | torch.distributed.init_process_group( 102 | backend='nccl', 103 | init_method='env://', 104 | ) 105 | 106 | if get_rank() == 0 and args.seq_len_dir is not None: 107 | mkdir(args.seq_len_dir) 108 | 109 | loader = get_bert_pretrain_data_loader( 110 | args.path, 111 | local_rank=args.local_rank, 112 | shuffle_buffer_size=args.shuffle_buffer_size, 113 | shuffle_buffer_warmup_factor=args.shuffle_buffer_warmup_factor, 114 | vocab_file=args.vocab_file, 115 | data_loader_kwargs={ 116 | 'batch_size': args.batch_size, 117 | 'num_workers': args.workers, 118 | 'prefetch_factor': args.prefetch 119 | }, 120 | mlm_probability=args.mlm_probability, 121 | base_seed=args.seed, 122 | log_dir=args.log_dir, 123 | log_level=getattr(logging, args.log_level), 124 | return_raw_samples=args.debug, 125 | start_epoch=args.start_epoch, 126 | sequence_length_alignment=args.sequence_length_alignment, 127 | ignore_index=args.ignore_index, 128 | ) 129 | if os.path.isfile(args.vocab_file): 130 | test_tokenizer = BertTokenizerFast(args.vocab_file) 131 | else: 132 | test_tokenizer = BertTokenizerFast.from_pretrained(args.vocab_file) 133 | 134 | meter = AverageMeter(warmup=args.warmup) 135 | 136 | lens_shape = (args.epochs, min(len(loader), args.iters_per_epoch)) 137 | min_lens, max_lens, batch_sizes, padded_lens = ( 138 | np.zeros(lens_shape, dtype=np.uint16), 139 | np.zeros(lens_shape, dtype=np.uint16), 140 | np.zeros(lens_shape, dtype=np.uint16), 141 | np.zeros(lens_shape, dtype=np.uint16), 142 | ) 143 | seq_len_hist = Histogram() 144 | padded_zero_hist = Histogram() 145 | 146 | for epoch in range(args.start_epoch, args.start_epoch + args.epochs): 147 | barrier() 148 | epoch_timer_start = time.time() 149 | 
batch_timer_start = time.time() 150 | total_samples = 0 151 | for i, data in enumerate(loader): 152 | if i >= args.iters_per_epoch: 153 | break 154 | if not args.debug: 155 | (input_ids, token_type_ids, attention_mask, labels, 156 | next_sentence_labels) = ( 157 | data['input_ids'], 158 | data['token_type_ids'], 159 | data['attention_mask'], 160 | data['labels'], 161 | data['next_sentence_labels'], 162 | ) 163 | batch_timer_stop = time.time() 164 | elapsed = batch_timer_stop - batch_timer_start 165 | meter.update(elapsed) 166 | 167 | if args.debug: 168 | current_samples = len(data[0]) * world_size 169 | else: 170 | current_samples = input_ids.size(0) * world_size 171 | assert input_ids.size() == token_type_ids.size() 172 | assert input_ids.size() == attention_mask.size() 173 | assert input_ids.size() == labels.size() 174 | assert next_sentence_labels.dim() == 1 175 | assert input_ids.size(0) == next_sentence_labels.size(0) 176 | seq_lens = get_batch_seq_lens(attention_mask) 177 | seq_len_hist.update_with_tensor(seq_lens) 178 | ( 179 | min_lens[epoch - args.start_epoch, i], 180 | max_lens[epoch - args.start_epoch, i], 181 | ) = seq_lens.min(), seq_lens.max() 182 | batch_sizes[epoch - args.start_epoch, i] = input_ids.size(0) 183 | padded_lens[epoch - args.start_epoch, i] = input_ids.size(1) 184 | padded_zero_hist.update_with_tensor(input_ids.size(1) - seq_lens) 185 | 186 | total_samples += current_samples 187 | current_throughput = current_samples / elapsed 188 | if (i + 1) % args.log_freq == 0 and get_rank() == 0: 189 | avg_throughput = total_samples / meter.sum 190 | print('avg_throughput={}, avg_latency={} ms, ' 191 | 'min_latency={} ms, max_latency={} ms, ' 192 | 'current_throughput={}, current_latency={} ms'.format( 193 | avg_throughput, 194 | meter.avg * 1000, 195 | meter.min * 1000, 196 | meter.max * 1000, 197 | current_throughput, 198 | elapsed * 1000, 199 | )) 200 | if args.debug: 201 | print('len(data[0])={}'.format(len(data[0]))) 202 | print('sample=({} {} - {})'.format( 203 | data[0][0], 204 | data[1][0], 205 | data[2][0], 206 | )) 207 | else: 208 | print("Min length={} Max length={} Diff={}".format( 209 | min_lens[epoch - args.start_epoch, i], 210 | max_lens[epoch - args.start_epoch, i], 211 | max_lens[epoch - args.start_epoch, i] - 212 | min_lens[epoch - args.start_epoch, i], 213 | )) 214 | print('input_ids.size()={}'.format(input_ids.size())) 215 | print('input_ids[0]={}'.format(input_ids[0])) 216 | print('convert_ids_to_tokens(input_ids[0])={}'.format( 217 | test_tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))) 218 | print('token_type_ids[0]={}'.format(token_type_ids[0])) 219 | print('attention_mask[0]={}'.format(attention_mask[0])) 220 | print('labels[0]={}'.format(labels[0])) 221 | print('next_sentence_labels[0]={}'.format(next_sentence_labels[0])) 222 | mask = labels[0] != args.ignore_index 223 | input_ids[0, mask] = labels[0, mask] 224 | print('original sequence={}'.format( 225 | test_tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))) 226 | barrier() 227 | batch_timer_start = time.time() 228 | epoch_timer_stop = time.time() 229 | epoch_elapsed = epoch_timer_stop - epoch_timer_start 230 | if args.local_rank == 0: 231 | avg_throughput = total_samples / meter.sum 232 | print('epoch={}, epoch_elapsed={}, avg_throughput={}, ' 233 | 'total_samples={}'.format( 234 | epoch, 235 | epoch_elapsed, 236 | avg_throughput, 237 | total_samples, 238 | )) 239 | assert meter.iters == min(len(loader), args.iters_per_epoch) 240 | meter.reset() 241 | 242 | if args.seq_len_dir is 
not None: 243 | # Save the sequence lengths to file 244 | np.savez_compressed( 245 | os.path.join(args.seq_len_dir, 'lens_{}.npz'.format(get_rank())), 246 | min_lens=min_lens, 247 | max_lens=max_lens, 248 | batch_sizes=batch_sizes, 249 | padded_lens=padded_lens, 250 | seq_len_hist=seq_len_hist.hist, 251 | padded_zero_hist=padded_zero_hist.hist, 252 | ) 253 | 254 | 255 | def attach_args(parser=argparse.ArgumentParser()): 256 | parser.add_argument('--path', type=str, required=True) 257 | parser.add_argument('--batch-size', type=int, default=64) 258 | parser.add_argument('--workers', type=int, default=4) 259 | parser.add_argument('--warmup', type=int, default=0) 260 | parser.add_argument('--epochs', type=int, default=2) 261 | parser.add_argument('--iters-per-epoch', type=int, default=float('inf')) 262 | parser.add_argument('--prefetch', type=int, default=2) 263 | parser.add_argument( 264 | '--local_rank', 265 | type=int, 266 | default=os.getenv('LOCAL_RANK', 0), 267 | ) 268 | parser.add_argument('--mlm-probability', type=float, default=0.15) 269 | parser.add_argument('--shuffle-buffer-size', type=int, default=16384) 270 | parser.add_argument('--shuffle-buffer-warmup-factor', type=int, default=16) 271 | parser.add_argument('--vocab-file', type=str, required=True) 272 | parser.add_argument('--seed', type=int, default=127) 273 | parser.add_argument('--start-epoch', type=int, default=0) 274 | parser.add_argument('--debug', action='store_true', default=False) 275 | parser.add_argument('--log-freq', type=int, default=1000) 276 | parser.add_argument('--log-dir', type=str, default=None) 277 | parser.add_argument( 278 | '--log-level', 279 | type=str, 280 | choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], 281 | default='WARNING', 282 | ) 283 | parser.add_argument('--seq-len-dir', type=str, default=None) 284 | parser.add_argument('--sequence-length-alignment', type=int, default=8) 285 | parser.add_argument('--ignore-index', type=int, default=-1) 286 | return parser 287 | 288 | 289 | if __name__ == '__main__': 290 | main(attach_args().parse_args()) 291 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_IMAGE=${1:-"ngc_pyt"} 4 | TAG=${2:-"21.11-py3"} 5 | URL=${3:-"lddl:latest"} 6 | PUSH=${4:-"none"} # 'push' or 'none' 7 | 8 | set -e 9 | 10 | docker build \ 11 | -f docker/${BASE_IMAGE}.Dockerfile \ 12 | --network=host \ 13 | --rm \ 14 | -t ${URL} \ 15 | --build-arg TAG=${TAG} \ 16 | . 17 | 18 | if [ "${PUSH}" == "push" ]; then 19 | docker push ${URL} 20 | elif [ "${PUSH}" == "none" ]; then 21 | echo "Keep the built image locally." 22 | else 23 | echo "Invalid \${PUSH} option: ${PUSH} !" 
24 | exit 1 25 | fi 26 | -------------------------------------------------------------------------------- /docker/interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MOUNTS=$1 4 | CMD=${2:-"bash"} 5 | IMAGE=${3:-"lddl"} 6 | GPUS=${4:-"all"} 7 | 8 | docker run \ 9 | --gpus \"device=${GPUS}\" \ 10 | --init \ 11 | -it \ 12 | --rm \ 13 | --network=host \ 14 | --ipc=host \ 15 | -v $PWD:/workspace/lddl \ 16 | ${MOUNTS} \ 17 | ${IMAGE} \ 18 | ${CMD} 19 | -------------------------------------------------------------------------------- /docker/ngc_paddle.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG TAG 2 | # Import a NGC PaddlePaddle container as the base image. 3 | # For more information on NGC PaddlePaddle containers, please visit: 4 | # https://catalog.ngc.nvidia.com/orgs/nvidia/containers/paddlepaddle 5 | FROM nvcr.io/nvidia/paddlepaddle:${TAG} 6 | 7 | ENV LANG C.UTF-8 8 | ENV LC_ALL C.UTF-8 9 | 10 | RUN apt-get update -qq && \ 11 | apt-get install -y git vim tmux && \ 12 | rm -rf /var/cache/apk/* 13 | 14 | RUN apt-get install -y libjemalloc-dev 15 | 16 | # Copy the lddl source code to /workspace/lddl in the image, then install. 17 | WORKDIR /workspace/lddl 18 | ADD . . 19 | RUN pip install ./ 20 | RUN pip install h5py pandas==1.5.2 21 | RUN pip install git+https://github.com/NVIDIA/dllogger#egg=dllogger 22 | 23 | # Download the NLTK model data. 24 | RUN python -m nltk.downloader punkt 25 | -------------------------------------------------------------------------------- /docker/ngc_pyt.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG TAG 2 | # Import a NGC PyTorch container as the base image. 3 | # For more information on NGC PyTorch containers, please visit: 4 | # https://ngc.nvidia.com/catalog/containers/nvidia:pytorch 5 | FROM nvcr.io/nvidia/pytorch:${TAG} 6 | 7 | ENV LANG C.UTF-8 8 | ENV LC_ALL C.UTF-8 9 | 10 | RUN apt-get update -qq && \ 11 | apt-get install -y git vim tmux && \ 12 | rm -rf /var/cache/apk/* 13 | 14 | RUN apt-get install -y libjemalloc-dev 15 | 16 | # Copy the lddl source code to /workspace/lddl in the image, then install. 17 | WORKDIR /workspace/lddl 18 | ADD . . 19 | RUN pip install ./ 20 | 21 | # Download the NLTK model data. 
22 | RUN python -m nltk.downloader punkt 23 | -------------------------------------------------------------------------------- /docs/images/binning.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/docs/images/binning.gif -------------------------------------------------------------------------------- /docs/images/binning_perf.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/docs/images/binning_perf.gif -------------------------------------------------------------------------------- /docs/images/preprocess_perf.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/docs/images/preprocess_perf.gif -------------------------------------------------------------------------------- /docs/images/summary.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/docs/images/summary.gif -------------------------------------------------------------------------------- /examples/local_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 4 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 5 | # SPDX-License-Identifier: MIT 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a 8 | # copy of this software and associated documentation files (the "Software"), 9 | # to deal in the Software without restriction, including without limitation 10 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 11 | # and/or sell copies of the Software, and to permit persons to whom the 12 | # Software is furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in 15 | # all copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 20 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 22 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | # DEALINGS IN THE SOFTWARE. 24 | 25 | # 26 | # This bash script demonstrates how to use LDDL end-to-end (i.e., from 27 | # downloading the raw dataset to loading the input batches during training) on 28 | # a local machine for (mock) BERT Phase 2 pretraining with static masking and 29 | # sequence binning enabled. 30 | 31 | set -eux 32 | 33 | # Build a NGC PyTorch container image that has lddl installed. 34 | bash docker/build.sh 35 | 36 | # Create a directory to store data. 37 | mkdir -p data/ 38 | 39 | # Download the Wikipedia dump. 
40 | readonly wikipedia_path=data/wikipedia 41 | bash docker/interactive.sh "" "download_wikipedia --outdir ${wikipedia_path}" 42 | 43 | # Download the vocab file from NVIDIA Deep Learning Examples (but you can 44 | # certainly get it from other sources as well). 45 | readonly vocab_source_url=https://raw.githubusercontent.com/NVIDIA/DeepLearningExamples/master/PyTorch/LanguageModeling/BERT/vocab/vocab 46 | mkdir -p data/vocab/ 47 | readonly vocab_path=data/vocab/bert-en-uncased.txt 48 | wget ${vocab_source_url} -O ${vocab_path} 49 | 50 | # Run the LDDL preprocessor for BERT Phase 2 pretraining with static masking and 51 | # sequence binning enabled (where the bin size is 64). 52 | readonly num_shards=4096 53 | readonly bin_size=64 54 | readonly jemalloc_path=/usr/lib/x86_64-linux-gnu/libjemalloc.so 55 | readonly pretrain_input_path=data/bert/pretrain/phase2/bin_size_${bin_size}/ 56 | bash docker/interactive.sh "" " \ 57 | mpirun \ 58 | --oversubscribe \ 59 | --allow-run-as-root \ 60 | -np $(nproc) \ 61 | -x LD_PRELOAD=${jemalloc_path} \ 62 | preprocess_bert_pretrain \ 63 | --schedule mpi \ 64 | --vocab-file ${vocab_path} \ 65 | --wikipedia ${wikipedia_path}/source/ \ 66 | --sink ${pretrain_input_path} \ 67 | --target-seq-length 512 \ 68 | --num-blocks ${num_shards} \ 69 | --bin-size ${bin_size} \ 70 | --masking " 71 | 72 | # Run the LDDL load balancer to balance the parquet shards generated by the LDDL 73 | # preprocessor. 74 | bash docker/interactive.sh "" " \ 75 | mpirun \ 76 | --oversubscribe \ 77 | --allow-run-as-root \ 78 | -np $(nproc) \ 79 | balance_dask_output \ 80 | --indir ${pretrain_input_path} \ 81 | --num-shards ${num_shards} " 82 | 83 | # Run a mock PyTorch training script that loads the input from the balanced 84 | # parquet shards using the LDDL data loader. 85 | # Once these training processes are up and running (as you can see from the 86 | # stdout printing), they simply emulate training and you can kill them at any time. 87 | readonly sequence_length_distribution_path=data/experiments/phase2/bin_size_${bin_size}/ 88 | bash docker/interactive.sh "" " \ 89 | python -m torch.distributed.launch --nproc_per_node=2 \ 90 | benchmarks/torch_train.py \ 91 | --path ${pretrain_input_path} \ 92 | --vocab-file ${vocab_path} " 93 | -------------------------------------------------------------------------------- /examples/slurm_example.sub: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --exclusive 3 | #SBATCH --mem=0 4 | #SBATCH --overcommit 5 | #SBATCH --parsable 6 | 7 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 8 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 9 | # SPDX-License-Identifier: MIT 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a 12 | # copy of this software and associated documentation files (the "Software"), 13 | # to deal in the Software without restriction, including without limitation 14 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 15 | # and/or sell copies of the Software, and to permit persons to whom the 16 | # Software is furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 
20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 24 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 26 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 27 | # DEALINGS IN THE SOFTWARE. 28 | 29 | 30 | # 31 | # This Slurm script demonstrates how to use the LDDL preprocessor, load balancer 32 | # and data loader and scale them to multi-nodes on Slurm clusters for (mock) 33 | # BERT Phase 2 pretraining with static masking and sequence binning enabled. 34 | # 35 | 36 | set -eux 37 | 38 | # 39 | # The following configurations might need to be customized based on the setup 40 | # of the Slurm cluster you are using. 41 | # 42 | 43 | # The URL of the container image built via `bash docker/build.sh`. 44 | # For example, if you build the container image by 45 | # `bash docker/build.sh ngc_pyt 21.11-py3 lddl:latest push`, 46 | # then the URL would be "lddl:latest": 47 | readonly docker_image=${DOCKER_IMAGE:-"lddl:latest"} 48 | 49 | # Create a directory to store data. 50 | mkdir -p data/ 51 | 52 | # Assume the Wikipedia dump is already downloaded and moved to the following 53 | # location in the NFS of your Slurm cluster. 54 | # 55 | # Please refer to examples/local_example.sh on how to use the LDDL downloader 56 | # to download the Wikipedia dump. 57 | readonly wikipedia_path=data/wikipedia 58 | 59 | # Download the vocab file from NVIDIA Deep Learning Examples (but you can 60 | # certainly get it from other sources as well). 61 | readonly vocab_source_url=https://raw.githubusercontent.com/NVIDIA/DeepLearningExamples/master/PyTorch/LanguageModeling/BERT/vocab/vocab 62 | mkdir -p data/vocab/ 63 | readonly vocab_path=data/vocab/bert-en-uncased.txt 64 | wget ${vocab_source_url} -O ${vocab_path} 65 | 66 | # Run the LDDL preprocessor for BERT Phase 2 pretraining with static masking and 67 | # sequence binning enabled (where the bin size is 64). 68 | readonly mounts=$(realpath data/):/workspace/lddl/data 69 | readonly workdir=/workspace/lddl 70 | readonly num_shards=4096 71 | readonly bin_size=64 72 | readonly tasks_per_node=128 73 | readonly pretrain_input_path=data/bert/pretrain/phase2/bin_size_${bin_size}/ 74 | srun \ 75 | -l \ 76 | --mpi=pmix \ 77 | --container-image="${docker_image}" \ 78 | --container-mounts="${mounts}" \ 79 | --container-workdir=${workdir} \ 80 | --ntasks-per-node=${tasks_per_node} \ 81 | --export=ALL,LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so \ 82 | preprocess_bert_pretrain \ 83 | --schedule mpi \ 84 | --vocab-file ${vocab_path} \ 85 | --wikipedia ${wikipedia_path}/source/ \ 86 | --sink ${pretrain_input_path} \ 87 | --target-seq-length 512 \ 88 | --num-blocks ${num_shards} \ 89 | --bin-size ${bin_size} \ 90 | --masking 91 | 92 | # Run the LDDL load balancer to balance the parquet shards generated by the LDDL 93 | # preprocessor. 
94 | srun \ 95 | -l \ 96 | --mpi=pmix \ 97 | --container-image="${docker_image}" \ 98 | --container-mounts="${mounts}" \ 99 | --container-workdir=${workdir} \ 100 | --ntasks-per-node=${tasks_per_node} \ 101 | balance_dask_output \ 102 | --indir ${pretrain_input_path} \ 103 | --num-shards ${num_shards} 104 | 105 | # Run a mock PyTorch training script that loads the input from the balanced 106 | # parquet shards using the LDDL data loader. 107 | # Once these training processes are up and running (as you can see from the 108 | # stdout printing), they simply emulate training and you can kill them at any time. 109 | readonly gpus_per_node=8 110 | srun \ 111 | -l \ 112 | --container-image="${docker_image}" \ 113 | --container-mounts="${mounts}" \ 114 | --container-workdir=${workdir} \ 115 | --ntasks-per-node=${gpus_per_node} \ 116 | python benchmarks/torch_train.py \ 117 | --path ${pretrain_input_path} \ 118 | --vocab-file ${vocab_path} 119 | -------------------------------------------------------------------------------- /lddl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/lddl/__init__.py -------------------------------------------------------------------------------- /lddl/dask/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/lddl/dask/__init__.py -------------------------------------------------------------------------------- /lddl/dask/bart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/lddl/dask/bart/__init__.py -------------------------------------------------------------------------------- /lddl/dask/bart/pretrain.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import argparse 26 | import dask 27 | import dask.bag as db 28 | import dask.distributed 29 | import functools 30 | import nltk 31 | import os 32 | import pyarrow as pa 33 | import time 34 | 35 | from lddl.dask.readers import (read_open_webtext, read_wikipedia, read_books, 36 | read_common_crawl, estimate_block_size) 37 | from lddl.utils import expand_outdir_and_mkdir 38 | from lddl.download.utils import parse_str_of_num_bytes 39 | 40 | 41 | def _get_sequences(wikipedia_path=None, 42 | books_path=None, 43 | common_crawl_path=None, 44 | open_webtext_path=None, 45 | wikipedia_lang='en', 46 | target_seq_length=128, 47 | short_seq_prob=0.1, 48 | blocksize=None, 49 | num_blocks=None): 50 | if num_blocks is not None: 51 | if blocksize is not None: 52 | raise ValueError('Only one of num_blocks or blocksize needs to be set!') 53 | blocksize = estimate_block_size( 54 | (wikipedia_path, books_path, common_crawl_path, open_webtext_path), 55 | num_blocks, 56 | ) 57 | bags = [] 58 | if wikipedia_path is not None: 59 | bags.append( 60 | read_wikipedia( 61 | wikipedia_path, 62 | lang=wikipedia_lang, 63 | blocksize=blocksize, 64 | )) 65 | if books_path is not None: 66 | bags.append(read_books( 67 | books_path, 68 | blocksize=blocksize, 69 | )) 70 | if common_crawl_path is not None: 71 | bags.append(read_common_crawl( 72 | common_crawl_path, 73 | blocksize=blocksize, 74 | )) 75 | 76 | if open_webtext_path is not None: 77 | bags.append(read_open_webtext( 78 | open_webtext_path, 79 | blocksize=blocksize, 80 | )) 81 | 82 | def _segment(article): 83 | return filter( 84 | None, 85 | map(lambda s: s.strip(), nltk.tokenize.sent_tokenize(article)), 86 | ) 87 | 88 | def _aggregate_sentences(sentences): 89 | # Cutting sentences into chunks that are close to target_seq_length 90 | # results is in the format of 91 | # [ 92 | # { 93 | # 'sentences': [sent1, sent2], 94 | # 'num_tokens': [num_tokens1, num_tokens2], 95 | # }, 96 | # { 97 | # 'sentences': [sent1, sent2, sent3], 98 | # 'num_tokens': [num_tokens1, num_tokens2, num_tokens3], 99 | # }, 100 | # { 101 | # 'sentences': [sent1], 102 | # 'num_tokens': [num_tokens1], 103 | # }, 104 | # ... 
105 | # ] 106 | results = [] 107 | # Excluding [CLS], [SEP], [SEP] 108 | target_length = target_seq_length - 3 109 | chunk = "" 110 | num_tokens = 0 111 | for sentence in sentences: 112 | chunk += " " + sentence 113 | num_tokens += len(list(sentence.split())) 114 | if num_tokens >= target_length: 115 | results.append({ 116 | 'sentences': chunk, 117 | 'num_tokens': num_tokens, 118 | 'target_length': target_length, 119 | }) 120 | chunk = "" 121 | num_tokens = 0 122 | if num_tokens > 0: 123 | results.append({ 124 | 'sentences': chunk, 125 | 'num_tokens': num_tokens, 126 | 'target_length': target_length, 127 | }) 128 | return results 129 | 130 | def _generate_sequences(article): 131 | return _aggregate_sentences(_segment(article)) 132 | 133 | return db.concat(bags).map(_generate_sequences).flatten() 134 | 135 | 136 | def save(pairs, path, output_format='parquet'): 137 | if output_format == 'parquet': 138 | pairs.to_dataframe(meta={ 139 | 'sentences': str, 140 | }).to_parquet( 141 | path, 142 | engine='pyarrow', 143 | write_index=False, 144 | schema={ 145 | 'sentences': pa.string(), 146 | }, 147 | ) 148 | elif output_format == 'txt': 149 | pairs = pairs.map(lambda p: '{}'.format(p['sentences'],)).to_textfiles( 150 | os.path.join(path, '*.txt')) 151 | else: 152 | raise ValueError('Format {} not supported!'.format(output_format)) 153 | 154 | 155 | def main(args): 156 | 157 | if args.schedule == 'mpi': 158 | from dask_mpi import initialize 159 | initialize() 160 | client = dask.distributed.Client() 161 | else: 162 | client = dask.distributed.Client( 163 | n_workers=args.local_n_workers, 164 | threads_per_worker=args.local_threads_per_worker, 165 | ) 166 | 167 | nltk.download('punkt') 168 | 169 | tic = time.perf_counter() 170 | sequences = _get_sequences( 171 | wikipedia_path=args.wikipedia, 172 | books_path=args.books, 173 | common_crawl_path=args.common_crawl, 174 | open_webtext_path=args.open_webtext, 175 | wikipedia_lang=args.wikipedia_lang, 176 | target_seq_length=args.target_seq_length, 177 | short_seq_prob=args.short_seq_prob, 178 | blocksize=args.block_size, 179 | num_blocks=args.num_blocks, 180 | ) 181 | 182 | args.sink = expand_outdir_and_mkdir(args.sink) 183 | save(sequences, args.sink, output_format=args.output_format) 184 | print('Running the dask pipeline took {} s'.format(time.perf_counter() - tic)) 185 | 186 | 187 | def attach_args( 188 | parser=argparse.ArgumentParser('BART pretrain dataset dask pipeline')): 189 | parser.add_argument( 190 | '--schedule', 191 | type=str, 192 | default='mpi', 193 | choices=['mpi', 'local'], 194 | help='how the dask pipeline is scheduled', 195 | ) 196 | parser.add_argument( 197 | '--local-n-workers', 198 | type=int, 199 | default=os.cpu_count(), 200 | help='number of worker processes for the local cluster; ' 201 | 'only used when --schedule=local', 202 | ) 203 | parser.add_argument( 204 | '--local-threads-per-worker', 205 | type=int, 206 | default=1, 207 | help='number of Python user-level threads per worker process for the ' 208 | 'local cluster; only used when --schedule=local', 209 | ) 210 | parser.add_argument( 211 | '--wikipedia', 212 | type=str, 213 | default=None, 214 | help='path to the Wikipedia corpus', 215 | ) 216 | parser.add_argument( 217 | '--books', 218 | type=str, 219 | default=None, 220 | help='path to the Toronto books corpus', 221 | ) 222 | parser.add_argument( 223 | '--common-crawl', 224 | type=str, 225 | default=None, 226 | help='path to the Common Crawl news corpus', 227 | ) 228 | parser.add_argument( 229 | '--open-webtext', 230 
| type=str, 231 | default=None, 232 | help='path to the Open WebText Corpus', 233 | ) 234 | parser.add_argument( 235 | '--sink', 236 | type=str, 237 | default=None, 238 | required=True, 239 | help='path to the dir to store output files', 240 | ) 241 | parser.add_argument( 242 | '--output-format', 243 | type=str, 244 | default='parquet', 245 | choices=['parquet', 'txt'], 246 | help='output file format', 247 | ) 248 | parser.add_argument( 249 | '--wikipedia-lang', 250 | type=str, 251 | default='en', 252 | choices=['en', 'zh'], 253 | help='wikipedia language type', 254 | ) 255 | parser.add_argument( 256 | '--target-seq-length', 257 | type=int, 258 | default=128, 259 | help='target sequence length', 260 | ) 261 | parser.add_argument( 262 | '--short-seq-prob', 263 | type=float, 264 | default=0.1, 265 | help='probability to use sequences shorter than --target-seq-length', 266 | ) 267 | parser.add_argument( 268 | '--block-size', 269 | type=functools.partial(parse_str_of_num_bytes, return_str=False), 270 | default=None, 271 | metavar='n[KMG]', 272 | help='The size of each output parquet/txt shard. Since Dask cannot ' 273 | 'guarantee perfect load balance, this value is only used as an estimate. ' 274 | 'Only one of --block-size and --num-blocks needs to be set, since one ' 275 | 'value can be derived from the other. Default: {}'.format(None), 276 | ) 277 | parser.add_argument( 278 | '--num-blocks', 279 | type=int, 280 | default=None, 281 | help='The total number of the output parquet/txt shards. Since Dask ' 282 | 'cannot guarantee perfect load balance, this value is only used as an ' 283 | 'estimate. Only one of --block-size or --num-blocks needs to be set, ' 284 | 'since one value can be derived from the other. Default: {}'.format(None), 285 | ) 286 | return parser 287 | 288 | 289 | def console_script(): 290 | main(attach_args().parse_args()) 291 | -------------------------------------------------------------------------------- /lddl/dask/bert/DASK_LICENSE.txt: -------------------------------------------------------------------------------- 1 | This library contains modified code from the Dask library 2 | (https://github.com/dask/dask). The Dask license is below. 3 | 4 | BSD 3-Clause License 5 | 6 | Copyright (c) 2014, Anaconda, Inc. and contributors 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | -------------------------------------------------------------------------------- /lddl/dask/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/lddl/dask/bert/__init__.py -------------------------------------------------------------------------------- /lddl/dask/bert/binning.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | from dask.highlevelgraph import HighLevelGraph 26 | 27 | # to_dataframe 28 | from dask.base import tokenize 29 | from dask.bag.core import reify 30 | import itertools 31 | 32 | # to_parquet 33 | import tlz as toolz 34 | from fsspec.core import get_fs_token_paths 35 | from dask.base import compute_as_if_collection 36 | from dask.delayed import Delayed 37 | from dask.utils import apply 38 | from dask.dataframe.core import Scalar 39 | from dask.dataframe.io.parquet.core import get_engine 40 | from dask.dataframe.io.parquet.arrow import _index_in_schema 41 | import pyarrow.parquet as pq 42 | try: 43 | import snappy 44 | 45 | snappy.compress 46 | except (ImportError, AttributeError): 47 | snappy = None 48 | 49 | NONE_LABEL = "__null_dask_index__" 50 | 51 | # to_textfiles 52 | import io 53 | import uuid 54 | from dask.bytes import open_files 55 | from dask.utils import ensure_unicode, ensure_bytes, system_encoding 56 | from contextlib import ExitStack 57 | 58 | # 59 | # dataframes 60 | # 61 | 62 | 63 | def _to_dataframe_binned(seq, columns, dtypes, bin_size, nbins): 64 | import pandas as pd 65 | 66 | seq = reify(seq) 67 | if not isinstance(seq, list): 68 | seq = list(seq) 69 | 70 | seqs = [[] for _ in range(nbins)] 71 | for i, iseq in enumerate(seq): 72 | seq_len = iseq['num_tokens'] 73 | bin_id = (seq_len - 1) // bin_size 74 | bin_id = nbins - 1 if bin_id > nbins - 1 else bin_id 75 | seqs[bin_id].append(iseq) 76 | 77 | dfl = list( 78 | map( 79 | lambda l: pd.DataFrame( 80 | l, 81 | columns=list(columns), 82 | ).astype(dtypes, copy=False), 83 | seqs, 84 | )) 85 | 86 | dfs = pd.concat(dfl, keys=list(map(str, list(range(nbins))))) 87 | 88 | # Add a bin_id column 89 | dfs['bin_id'] = list( 90 | itertools.chain.from_iterable( 91 | [[i] * len(bingrp) for i, bingrp in enumerate(seqs)])) 92 | 93 | return dfs 94 | 95 | 96 | def to_dataframe_binned(self, bin_size, nbins, meta=None, columns=None): 97 | import pandas as pd 98 | import dask.dataframe as dd 99 | 100 | if meta is None: 101 | head = self.take(1, warn=False) 102 | if len(head) == 0: 103 | raise ValueError("`dask.bag.Bag.to_dataframe` failed to " 104 | "properly infer metadata, please pass in " 105 | "metadata via the `meta` keyword") 106 | meta_nobin = pd.DataFrame(list(head), columns=columns) 107 | elif columns is not None: 108 | raise ValueError("Can't specify both `meta` and `columns`") 109 | else: 110 | meta_nobin = dd.utils.make_meta(meta, parent_meta=pd.DataFrame()) 111 | # Serializing the columns and dtypes is much smaller than serializing 112 | # the empty frame 113 | cols = list(meta_nobin.columns) 114 | dtypes = meta_nobin.dtypes.to_dict() 115 | name = "to_dataframe-binned-" + tokenize(self, cols, dtypes) 116 | dsk = self.__dask_optimize__(self.dask, self.__dask_keys__()) 117 | 118 | for i in range(self.npartitions): 119 | dsk[(name, i)] = (_to_dataframe_binned, (self.name, i), cols, dtypes, 120 | bin_size, nbins) 121 | 122 | # Update the meta 123 | meta['bin_id'] = int 124 | meta = dd.utils.make_meta(meta, parent_meta=pd.DataFrame()) 125 | 126 | divisions = [None] * (self.npartitions + 1) 127 | return dd.DataFrame(dsk, name, meta, divisions) 128 | 129 | 130 | # 131 | # parquet files 132 | # 133 | 134 | 135 | def to_parquet_binned( 136 | df, 137 | path, 138 | nbins, 139 | engine="auto", 140 | compression="default", 141 | write_index=True, 142 | append=False, 143 | overwrite=False, 144 | ignore_divisions=False, 145 | partition_on=None, 146 | storage_options=None, 147 | custom_metadata=None, 148 | 
write_metadata_file=True, 149 | compute=True, 150 | compute_kwargs=None, 151 | schema=None, 152 | **kwargs, 153 | ): 154 | compute_kwargs = compute_kwargs or {} 155 | 156 | if compression == "default": 157 | if snappy is not None: 158 | compression = "snappy" 159 | else: 160 | compression = None 161 | 162 | partition_on = partition_on or [] 163 | if isinstance(partition_on, str): 164 | partition_on = [partition_on] 165 | 166 | if set(partition_on) - set(df.columns): 167 | raise ValueError("Partitioning on non-existent column. " 168 | "partition_on=%s ." 169 | "columns=%s" % (str(partition_on), str(list(df.columns)))) 170 | 171 | if isinstance(engine, str): 172 | engine = get_engine(engine) 173 | 174 | if hasattr(path, "name"): 175 | path = stringify_path(path) 176 | fs, _, _ = get_fs_token_paths(path, 177 | mode="wb", 178 | storage_options=storage_options) 179 | # Trim any protocol information from the path before forwarding 180 | path = fs._strip_protocol(path) 181 | 182 | if overwrite: 183 | if isinstance(fs, LocalFileSystem): 184 | working_dir = fs.expand_path(".")[0] 185 | if path.rstrip("/") == working_dir.rstrip("/"): 186 | raise ValueError( 187 | "Cannot clear the contents of the current working directory!") 188 | if append: 189 | raise ValueError("Cannot use both `overwrite=True` and `append=True`!") 190 | if fs.exists(path) and fs.isdir(path): 191 | # Only remove path contents if 192 | # (1) The path exists 193 | # (2) The path is a directory 194 | # (3) The path is not the current working directory 195 | fs.rm(path, recursive=True) 196 | 197 | # Save divisions and corresponding index name. This is necessary, 198 | # because we may be resetting the index to write the file 199 | division_info = {"divisions": df.divisions, "name": df.index.name} 200 | if division_info["name"] is None: 201 | # As of 0.24.2, pandas will rename an index with name=None 202 | # when df.reset_index() is called. The default name is "index", 203 | # but dask will always change the name to the NONE_LABEL constant 204 | if NONE_LABEL not in df.columns: 205 | division_info["name"] = NONE_LABEL 206 | elif write_index: 207 | raise ValueError( 208 | "Index must have a name if __null_dask_index__ is a column.") 209 | else: 210 | warnings.warn("If read back by Dask, column named __null_dask_index__ " 211 | "will be set to the index (and renamed to None).") 212 | 213 | # There are some "resrved" names that may be used as the default column 214 | # name after resetting the index. However, we don't want to treat it as 215 | # a "special" name if the string is already used as a "real" column name. 216 | reserved_names = [] 217 | for name in ["index", "level_0"]: 218 | if name not in df.columns: 219 | reserved_names.append(name) 220 | 221 | # If write_index==True (default), reset the index and record the 222 | # name of the original index in `index_cols` (we will set the name 223 | # to the NONE_LABEL constant if it is originally `None`). 224 | # `fastparquet` will use `index_cols` to specify the index column(s) 225 | # in the metadata. `pyarrow` will revert the `reset_index` call 226 | # below if `index_cols` is populated (because pyarrow will want to handle 227 | # index preservation itself). 
For both engines, the column index 228 | # will be written to "pandas metadata" if write_index=True 229 | index_cols = [] 230 | if write_index: 231 | real_cols = set(df.columns) 232 | none_index = list(df._meta.index.names) == [None] 233 | df = df.reset_index() 234 | if none_index: 235 | df.columns = [ 236 | c if c not in reserved_names else NONE_LABEL for c in df.columns 237 | ] 238 | index_cols = [c for c in set(df.columns) - real_cols] 239 | else: 240 | # Not writing index - might as well drop it 241 | df = df.reset_index(drop=True) 242 | 243 | _to_parquet_kwargs = { 244 | "engine", 245 | "compression", 246 | "write_index", 247 | "append", 248 | "ignore_divisions", 249 | "partition_on", 250 | "storage_options", 251 | "write_metadata_file", 252 | "compute", 253 | } 254 | kwargs_pass = {k: v for k, v in kwargs.items() if k not in _to_parquet_kwargs} 255 | 256 | # Engine-specific initialization steps to write the dataset. 257 | # Possibly create parquet metadata, and load existing stuff if appending 258 | meta, schema, i_offset = engine.initialize_write( 259 | df, 260 | fs, 261 | path, 262 | append=append, 263 | ignore_divisions=ignore_divisions, 264 | partition_on=partition_on, 265 | division_info=division_info, 266 | index_cols=index_cols, 267 | schema=schema, 268 | **kwargs_pass, 269 | ) 270 | 271 | # Use i_offset and df.npartitions to define file-name list 272 | filenames = [ 273 | "part.%i.parquet" % (i + i_offset) for i in range(df.npartitions) 274 | ] 275 | 276 | # Construct IO graph 277 | dsk = {} 278 | name = "to-parquet-binned" + tokenize( 279 | df, 280 | fs, 281 | path, 282 | append, 283 | ignore_divisions, 284 | partition_on, 285 | division_info, 286 | index_cols, 287 | schema, 288 | ) 289 | part_tasks = [] 290 | kwargs_pass["fmd"] = meta 291 | kwargs_pass["compression"] = compression 292 | kwargs_pass["index_cols"] = index_cols 293 | kwargs_pass["schema"] = schema 294 | if custom_metadata: 295 | if b"pandas" in custom_metadata.keys(): 296 | raise ValueError( 297 | "User-defined key/value metadata (custom_metadata) can not " 298 | "contain a b'pandas' key. 
This key is reserved by Pandas, " 299 | "and overwriting the corresponding value can render the " 300 | "entire dataset unreadable.") 301 | kwargs_pass["custom_metadata"] = custom_metadata 302 | # Override write_partition to write binned parquet files 303 | engine.write_partition = write_partition_binned 304 | for d, filename in enumerate(filenames): 305 | dsk[(name, d)] = ( 306 | apply, 307 | engine.write_partition, 308 | [ 309 | engine, 310 | (df._name, d), 311 | path, 312 | fs, 313 | filename, 314 | partition_on, 315 | write_metadata_file, 316 | nbins, 317 | ], 318 | toolz.merge(kwargs_pass, {"head": True}) if d == 0 else kwargs_pass, 319 | ) 320 | part_tasks.append((name, d)) 321 | 322 | final_name = "metadata-" + name 323 | # Collect metadata and write _metadata 324 | 325 | if write_metadata_file: 326 | dsk[(final_name, 0)] = ( 327 | apply, 328 | engine.write_metadata, 329 | [ 330 | part_tasks, 331 | meta, 332 | fs, 333 | path, 334 | ], 335 | { 336 | "append": append, 337 | "compression": compression 338 | }, 339 | ) 340 | else: 341 | dsk[(final_name, 0)] = (lambda x: None, part_tasks) 342 | 343 | graph = HighLevelGraph.from_collections(final_name, dsk, dependencies=[df]) 344 | out = Delayed(name, graph) 345 | 346 | if compute: 347 | return compute_as_if_collection(Scalar, graph, [(final_name, 0)], 348 | **compute_kwargs) 349 | else: 350 | return Scalar(graph, final_name, "") 351 | 352 | 353 | def write_partition_binned( 354 | cls, 355 | df, 356 | path, 357 | fs, 358 | filename, 359 | partition_on, 360 | return_metadata, 361 | nbins, 362 | fmd=None, 363 | compression=None, 364 | index_cols=None, 365 | schema=None, 366 | head=False, 367 | custom_metadata=None, 368 | **kwargs, 369 | ): 370 | _meta = None 371 | preserve_index = False 372 | if _index_in_schema(index_cols, schema): 373 | df.set_index(index_cols, inplace=True) 374 | preserve_index = True 375 | else: 376 | index_cols = [] 377 | 378 | for ibin in range(nbins): 379 | 380 | dff = df[df.bin_id == ibin] 381 | 382 | filename_b = "%s_%d" % (filename, ibin) 383 | 384 | t = cls._pandas_to_arrow_table( 385 | dff, 386 | preserve_index=preserve_index, 387 | schema=schema, 388 | ) 389 | if custom_metadata: 390 | _md = t.schema.metadata 391 | _md.update(custom_metadata) 392 | t = t.replace_schema_metadata(metadata=_md) 393 | 394 | if partition_on: 395 | md_list = _write_partitioned( 396 | t, 397 | path, 398 | filename_b, 399 | partition_on, 400 | fs, 401 | index_cols=index_cols, 402 | compression=compression, 403 | **kwargs, 404 | ) 405 | if md_list: 406 | _meta = md_list[0] 407 | for i in range(1, len(md_list)): 408 | _append_row_groups(_meta, md_list[i]) 409 | else: 410 | md_list = [] 411 | with fs.open(fs.sep.join([path, filename_b]), "wb") as fil: 412 | pq.write_table( 413 | t, 414 | fil, 415 | compression=compression, 416 | metadata_collector=md_list, 417 | **kwargs, 418 | ) 419 | if md_list: 420 | _meta = md_list[0] 421 | _meta.set_file_path(filename) 422 | 423 | # Return the schema needed to write the metadata 424 | if return_metadata: 425 | d = {"meta": _meta} 426 | if head: 427 | # Only return schema if this is the "head" partition 428 | d["schema"] = t.schema 429 | return [d] 430 | else: 431 | return [] 432 | 433 | 434 | # 435 | # text files 436 | # 437 | 438 | 439 | class file_namer(object): 440 | 441 | def __init__(self, bin_size, nbins, prefix=""): 442 | self.__bin_size = bin_size 443 | self.__nbins = nbins 444 | self.__prefix = prefix 445 | 446 | def name_function(self, i): 447 | num = i // self.__nbins 448 | bin_val = i % 
self.__nbins 449 | return '%s%d_%d' % (self.__prefix, num, bin_val) 450 | 451 | 452 | def _to_textfiles_chunk_binned(data, lazy_files, last_endline, bin_size): 453 | nbins = len(lazy_files) 454 | with ExitStack() as stack: 455 | fs = [stack.enter_context(lazy_file) for lazy_file in lazy_files] 456 | if isinstance(fs[0], io.TextIOWrapper): 457 | endline = "\n" 458 | ensure = ensure_unicode 459 | else: 460 | endline = b"\n" 461 | ensure = ensure_bytes 462 | starteds = [False] * nbins 463 | for d in data: 464 | # Assuming the last whitespace-separated token contains the number of tokens. 465 | seq_len = int(d.split()[-1]) 466 | bin_id = (seq_len - 1) // bin_size 467 | bin_id = nbins - 1 if bin_id > nbins - 1 else bin_id 468 | if starteds[bin_id]: 469 | fs[bin_id].write(endline) 470 | else: 471 | starteds[bin_id] = True 472 | fs[bin_id].write(ensure(d)) 473 | if last_endline: 474 | for f in fs: 475 | f.write(endline) 476 | 477 | 478 | def to_textfiles_binned(b, 479 | path, 480 | bin_size=64, 481 | nbins=8, 482 | compression="infer", 483 | encoding=system_encoding, 484 | compute=True, 485 | storage_options=None, 486 | last_endline=False, 487 | **kwargs): 488 | 489 | mode = "wb" if encoding is None else "wt" 490 | files = open_files(path, 491 | compression=compression, 492 | mode=mode, 493 | encoding=encoding, 494 | name_function=file_namer(bin_size, nbins).name_function, 495 | num=b.npartitions * nbins, 496 | **(storage_options or {})) 497 | 498 | name = "to-textfiles-binned-" + uuid.uuid4().hex 499 | dsk = {(name, i): (_to_textfiles_chunk_binned, (b.name, i), 500 | files[k:k + nbins], last_endline, bin_size) 501 | for i, k in enumerate(range(0, len(files), nbins))} 502 | graph = HighLevelGraph.from_collections(name, dsk, dependencies=[b]) 503 | out = type(b)(graph, name, b.npartitions) 504 | 505 | if compute: 506 | out.compute(**kwargs) 507 | return [f.path for f in files] 508 | else: 509 | return out.to_delayed() 510 | -------------------------------------------------------------------------------- /lddl/dask/load_balance.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE.
23 | # 24 | 25 | import argparse 26 | import json 27 | import numpy as np 28 | import os 29 | import pyarrow as pa 30 | import pyarrow.parquet as pq 31 | import time 32 | from mpi4py import MPI 33 | 34 | from lddl.types import File 35 | from lddl.utils import (get_all_files_paths_under, expand_outdir_and_mkdir, 36 | get_all_parquets_under, get_all_bin_ids, 37 | get_file_paths_for_bin_id, get_num_samples_of_parquet, 38 | attach_bool_arg) 39 | 40 | 41 | class Shard: 42 | 43 | def __init__(self, idx, input_files, outdir, keep_orig=True, postfix=''): 44 | self.idx = idx 45 | self._input_files = input_files 46 | self._outdir = outdir 47 | self._keep_orig = keep_orig 48 | self._postfix = postfix 49 | 50 | self._output_file = None 51 | 52 | @property 53 | def num_samples(self): 54 | n = 0 55 | if self._input_files is not None: 56 | for input_file in self._input_files: 57 | n += input_file.num_samples 58 | if self._output_file is not None: 59 | n += self._output_file.num_samples 60 | return n 61 | 62 | def __repr__(self): 63 | return ('Shard(idx={}, input_files={}, outdir={}, keep_orig={}, ' 64 | 'postfix={}, output_file={})'.format( 65 | self.idx, 66 | self._input_files, 67 | self._outdir, 68 | self._keep_orig, 69 | self._postfix, 70 | self._output_file, 71 | )) 72 | 73 | def _read_table(self, path): 74 | table = pq.read_table(path) 75 | if not self._keep_orig: # Only keep the read table in memory. 76 | os.remove(path) 77 | return table 78 | 79 | def _read_table_from_file(self, f): 80 | table = self._read_table(f.path) 81 | assert f.num_samples == len(table) 82 | return table 83 | 84 | def _store(self, num_samples, table=None): 85 | if table is not None: 86 | assert num_samples == len(table) 87 | if self._output_file is None: 88 | self._output_file = File( 89 | os.path.join( 90 | self._outdir, 91 | 'shard-{}.parquet{}'.format(self.idx, self._postfix), 92 | ), 93 | 0, 94 | ) 95 | else: 96 | if table is not None: 97 | table = pa.concat_tables([ 98 | self._read_table_from_file(self._output_file), 99 | table, 100 | ]) 101 | self._output_file.num_samples += num_samples 102 | if table is not None: 103 | assert self._output_file.num_samples == len(table) 104 | pq.write_table(table, self._output_file.path) 105 | 106 | def _load(self, num_samples, return_table=False): 107 | if return_table: 108 | tables = [] 109 | while num_samples > 0: 110 | if len(self._input_files) > 0: 111 | load_file = self._input_files.pop() 112 | else: 113 | load_file = self._output_file 114 | self._output_file = None 115 | load_num_samples = min(load_file.num_samples, num_samples) 116 | if return_table: 117 | load_table = self._read_table_from_file(load_file) 118 | tables.append(load_table.slice(length=load_num_samples)) 119 | if load_num_samples < load_file.num_samples: 120 | self._store( 121 | load_file.num_samples - load_num_samples, 122 | table=load_table.slice( 123 | offset=load_num_samples) if return_table else None, 124 | ) 125 | num_samples -= load_num_samples 126 | if return_table: 127 | return pa.concat_tables(tables) 128 | 129 | def balance(larger_shard, smaller_shard, idx): 130 | assert larger_shard.num_samples > smaller_shard.num_samples 131 | num_samples_to_transfer = ( 132 | larger_shard.num_samples - 133 | (larger_shard.num_samples + smaller_shard.num_samples) // 2) 134 | smaller_shard._store( 135 | num_samples_to_transfer, 136 | table=larger_shard._load( 137 | num_samples_to_transfer, 138 | return_table=(idx % get_world_size() == get_rank()), 139 | ), 140 | ) 141 | 142 | def flush(self, idx): 143 | if idx % 
get_world_size() == get_rank(): 144 | input_tables = [] 145 | num_samples_to_flush = 0 146 | while len(self._input_files) > 0: 147 | input_file = self._input_files.pop() 148 | num_samples_to_flush += input_file.num_samples 149 | if idx % get_world_size() == get_rank(): 150 | input_tables.append(self._read_table_from_file(input_file)) 151 | if num_samples_to_flush > 0: 152 | self._store( 153 | num_samples_to_flush, 154 | table=(pa.concat_tables(input_tables) if 155 | (idx % get_world_size() == get_rank()) else None), 156 | ) 157 | 158 | 159 | class Progress: 160 | 161 | def __init__(self, shards): 162 | num_shards = len(shards) 163 | total_num_samples = sum((s.num_samples for s in shards)) 164 | base_num_samples_per_shard = total_num_samples // num_shards 165 | self._targets = { 166 | base_num_samples_per_shard: num_shards - total_num_samples % num_shards, 167 | base_num_samples_per_shard + 1: total_num_samples % num_shards, 168 | } 169 | self._ready_shards = [] 170 | 171 | def __repr__(self): 172 | s = [ 173 | 'Progress(', 174 | ' Remaining:', 175 | ] 176 | s += [ 177 | ' {} shards with {} samples per shard'.format(v, k) 178 | for k, v in self._targets.items() 179 | ] 180 | s += [ 181 | ' Ready:', 182 | ' {} shards'.format(len(self._ready_shards)), 183 | ')', 184 | ] 185 | return '\n'.join(s) 186 | 187 | def completed(self): 188 | return sum(self._targets.values()) == 0 189 | 190 | def report(self, shards): 191 | smaller_shards, larger_shards = [], [] 192 | for shard in shards: 193 | if shard.num_samples in self._targets: 194 | self._targets[shard.num_samples] -= 1 195 | self._ready_shards.append(shard) 196 | if self._targets[shard.num_samples] == 0: 197 | del self._targets[shard.num_samples] 198 | else: 199 | if shard.num_samples < min(self._targets.keys()): 200 | smaller_shards.append(shard) 201 | else: 202 | larger_shards.append(shard) 203 | return smaller_shards, larger_shards 204 | 205 | @property 206 | def ready_shards(self): 207 | return self._ready_shards 208 | 209 | 210 | def get_world_size(): 211 | return MPI.COMM_WORLD.Get_size() 212 | 213 | 214 | def get_rank(): 215 | return MPI.COMM_WORLD.Get_rank() 216 | 217 | 218 | def barrier(): 219 | return MPI.COMM_WORLD.barrier() 220 | 221 | 222 | def allreduce(array, op=MPI.SUM): 223 | MPI.COMM_WORLD.Allreduce(MPI.IN_PLACE, array, op=op) 224 | 225 | 226 | def _build_files(file_paths): 227 | # Get the number of samples for each file in a collectively distributed 228 | # approach. 
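# Each rank counts samples only for a strided subset of the files (indices rank, rank + world_size, ...); the allreduce below sums the per-rank partial arrays so that every rank ends up with the sample counts of all files.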
229 | all_files_num_samples = np.zeros((len(file_paths),), dtype=np.uint64) 230 | for file_idx in range(get_rank(), len(file_paths), get_world_size()): 231 | all_files_num_samples[file_idx] = get_num_samples_of_parquet( 232 | file_paths[file_idx]) 233 | allreduce(all_files_num_samples) 234 | return sorted( 235 | [ 236 | File(path, num_samples) for (path, num_samples) in zip( 237 | file_paths, 238 | all_files_num_samples.tolist(), 239 | ) 240 | ], 241 | key=lambda f: f.num_samples, 242 | ) 243 | 244 | 245 | def _build_shards(files, num_shards, outdir, keep_orig=True, postfix=''): 246 | return [ 247 | Shard( 248 | idx, 249 | files[idx::num_shards] if idx < len(files) else None, 250 | outdir, 251 | keep_orig=keep_orig, 252 | postfix=postfix, 253 | ) for idx in range(num_shards) 254 | ] 255 | 256 | 257 | def _calculate_mean_std_num_samples(shards): 258 | num_samples = [shard.num_samples for shard in shards] 259 | if len(num_samples) > 0: 260 | return np.mean(num_samples), np.std(num_samples) 261 | else: 262 | return np.NAN, np.NAN 263 | 264 | 265 | def attach_args(parser=argparse.ArgumentParser(""" 266 | LDDL Load Balancer for the parquet shards generated by the LDDL Preprocessor 267 | 268 | Assume the set of parquet shards generated by the LDDL Preprocessor is P, for 269 | any two parquet shards a and b in P, the LDDL load balancer makes sure that the 270 | numbers of samples in a and b differ *at most* by 1. In other words, the LDDL 271 | load balancer "balances" the number of samples among the parquet shards. 272 | 273 | MPI is used to scale the LDDL load balancer to multi-processes and multi-nodes. 274 | MPI can be accessed in various ways. For example, we can access MPI via mpirun: 275 | $ mpirun -c <num_processes> --oversubscribe --allow-run-as-root \\ 276 | balance_dask_output ... 277 | We can also access MPI via SLURM in a HPC cluster: 278 | $ srun -l --mpi=pmix --ntasks-per-node=<num_processes_per_node> \\ 279 | balance_dask_output ... 280 | """)): 281 | parser.add_argument( 282 | '--indir', 283 | type=str, 284 | required=True, 285 | help='The path to the directory that contains the parquet shards ' 286 | 'generated by the LDDL Preprocessor.', 287 | ) 288 | parser.add_argument( 289 | '--outdir', 290 | type=str, 291 | default=None, 292 | help="The path where the balanced parquet shards will be stored. If " 293 | "unspecified, the balanced parquet shards will be stored in the " 294 | "directory of '--indir'.", 295 | ) 296 | parser.add_argument( 297 | '--num-shards', 298 | type=int, 299 | required=True, 300 | help='The total number of shards that should be balanced into.', 301 | ) 302 | parser.add_argument( 303 | '--bin-ids', 304 | type=int, 305 | nargs='*', 306 | default=None, 307 | help='The bin IDs to perform load balance on (if binning is enabled). If ' 308 | 'unspecified, load balance will be performed on all bins.', 309 | ) 310 | attach_bool_arg( 311 | parser, 312 | 'keep-orig', 313 | default=False, 314 | help_str="If '--keep-orig' is specified, the original unbalanced parquet " 315 | "shards are kept.
By default, those original unbalanced parquet shards " 316 | "are deleted after the balanced shards are generated.", 317 | ) 318 | return parser 319 | 320 | 321 | def _balance(file_paths, num_shards, outdir, keep_orig=True, postfix=''): 322 | files = _build_files(file_paths) 323 | shards = _build_shards( 324 | files, 325 | num_shards, 326 | outdir, 327 | keep_orig=keep_orig, 328 | postfix=postfix, 329 | ) 330 | if get_rank() == 0: 331 | print('Balancing the following {} files into {} shards:'.format( 332 | len(files), num_shards)) 333 | print('SUM(files.num_samples) = {}, SUM(shards.num_samples) = {}'.format( 334 | sum((f.num_samples for f in files)), 335 | sum((s.num_samples for s in shards)), 336 | )) 337 | progress = Progress(shards) 338 | if get_rank() == 0: 339 | print('Begin with {}'.format(progress)) 340 | iteration = 0 341 | while not progress.completed(): 342 | smaller_shards, larger_shards = progress.report(shards) 343 | if get_rank() == 0: 344 | print('iteration {}, {}, left {}, right {}'.format( 345 | iteration, 346 | progress, 347 | _calculate_mean_std_num_samples(smaller_shards), 348 | _calculate_mean_std_num_samples(larger_shards), 349 | )) 350 | smaller_shards = list( 351 | sorted(smaller_shards, key=lambda shard: shard.num_samples)) 352 | larger_shards = list( 353 | sorted( 354 | larger_shards, 355 | key=lambda shard: shard.num_samples, 356 | reverse=True, 357 | )) 358 | num_pairs = min(len(smaller_shards), len(larger_shards)) 359 | for i, (smaller_shard, larger_shard) in enumerate( 360 | zip(smaller_shards[:num_pairs], larger_shards[:num_pairs])): 361 | larger_shard.balance(smaller_shard, i) 362 | barrier() 363 | shards = smaller_shards + larger_shards 364 | iteration += 1 365 | 366 | [shard.flush(i) for i, shard in enumerate(progress.ready_shards)] 367 | if get_rank() == 0: 368 | print('Done!') 369 | return progress.ready_shards 370 | 371 | 372 | def _store_num_samples_per_shard(shards, outdir): 373 | num_samples_per_shard = { 374 | os.path.basename(shard._output_file.path): shard._output_file.num_samples 375 | for shard in shards 376 | } 377 | with open(os.path.join(outdir, '.num_samples.json'), 'w') as f: 378 | json.dump(num_samples_per_shard, f) 379 | 380 | 381 | def main(args): 382 | 383 | if args.outdir is None: 384 | args.outdir = args.indir 385 | else: 386 | args.outdir = expand_outdir_and_mkdir(args.outdir) 387 | 388 | file_paths = get_all_parquets_under(args.indir) 389 | if args.bin_ids is None: 390 | bin_ids = get_all_bin_ids(file_paths) 391 | if len(bin_ids) > 0: 392 | args.bin_ids = bin_ids 393 | ready_shards = [] 394 | if args.bin_ids is None: 395 | if get_rank() == 0: 396 | print('Load balancing for unbinned files ...') 397 | ready_shards.extend( 398 | _balance(file_paths, 399 | args.num_shards, 400 | args.outdir, 401 | keep_orig=args.keep_orig)) 402 | else: 403 | if get_rank() == 0: 404 | print('Load balancing for bin_ids = {} ...'.format(args.bin_ids)) 405 | for bin_id in args.bin_ids: 406 | if get_rank() == 0: 407 | print('Balancing bin_id = {} ...'.format(bin_id)) 408 | file_paths_current_bin = get_file_paths_for_bin_id(file_paths, bin_id) 409 | ready_shards.extend( 410 | _balance( 411 | file_paths_current_bin, 412 | args.num_shards, 413 | args.outdir, 414 | keep_orig=args.keep_orig, 415 | postfix='_{}'.format(bin_id), 416 | )) 417 | if get_rank() == 0: 418 | _store_num_samples_per_shard(ready_shards, args.outdir) 419 | 420 | 421 | def console_script(): 422 | tic = time.perf_counter() 423 | main(attach_args().parse_args()) 424 | if get_rank() == 0: 425 
| print('Load balancing took {} s!'.format(time.perf_counter() - tic)) 426 | 427 | 428 | def generate_num_samples_cache(): 429 | parser = argparse.ArgumentParser( 430 | 'Generate .num_samples.json for the balanced parquets.') 431 | parser.add_argument( 432 | '--indir', 433 | type=str, 434 | default=None, 435 | help='path to the dir that contains the balanced shards', 436 | ) 437 | args = parser.parse_args() 438 | file_paths = get_all_parquets_under(args.indir) 439 | # Get the number of samples for each file in a collectively distributed 440 | # approach. 441 | all_files_num_samples = np.zeros((len(file_paths),), dtype=np.uint64) 442 | for file_idx in range(get_rank(), len(file_paths), get_world_size()): 443 | all_files_num_samples[file_idx] = get_num_samples_of_parquet( 444 | file_paths[file_idx]) 445 | allreduce(all_files_num_samples) 446 | all_files_num_samples = all_files_num_samples.tolist() 447 | with open(os.path.join(args.indir, '.num_samples.json'), 'w') as nsf: 448 | json.dump( 449 | { 450 | os.path.basename(file_paths[file_idx]): 451 | all_files_num_samples[file_idx] 452 | for file_idx in range(len(file_paths)) 453 | }, 454 | nsf, 455 | ) 456 | -------------------------------------------------------------------------------- /lddl/dask/readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import dask.bag as db 26 | import nltk 27 | import os 28 | import random 29 | 30 | 31 | def _filter_empty_strs(bag_strs): 32 | return bag_strs.map(lambda s: s.strip()).filter(lambda s: len(s) > 0) 33 | 34 | 35 | def _find_files_under(path, extensions={'.txt'}): 36 | all_files = [] 37 | for current_dir, sub_dirs, file_names in os.walk(path): 38 | for file_name in file_names: 39 | if os.path.splitext(file_name)[1] in extensions: 40 | all_files.append(os.path.join(current_dir, file_name)) 41 | return list(sorted(all_files)) 42 | 43 | 44 | def _total_bytes_of(files): 45 | return sum(map(os.path.getsize, files)) 46 | 47 | 48 | def estimate_block_size(paths, num_blocks): 49 | total_bytes = 0 50 | for p in paths: 51 | if p is None: 52 | continue 53 | total_bytes += _total_bytes_of(_find_files_under(p)) 54 | print('total_bytes = {}, num_blocks = {}'.format(total_bytes, num_blocks)) 55 | block_size = round(total_bytes / num_blocks) 56 | print('block_size = {} bytes'.format(block_size)) 57 | return block_size 58 | 59 | 60 | def _read_bag_of_text( 61 | path, 62 | blocksize=None, 63 | sample_ratio=1.0, 64 | sample_seed=12345, 65 | ): 66 | input_files = _find_files_under(path) 67 | bag_strs = db.read_text(input_files, blocksize=blocksize) 68 | bag_strs = _filter_empty_strs(bag_strs) 69 | if sample_ratio < 1.0: 70 | bag_strs = bag_strs.random_sample(sample_ratio, random_state=sample_seed) 71 | return bag_strs 72 | 73 | 74 | def read_wikipedia( 75 | path, 76 | lang='en', 77 | blocksize=None, 78 | sample_ratio=1.0, 79 | sample_seed=12345, 80 | ): 81 | return _read_bag_of_text( 82 | os.path.join(path, lang), 83 | blocksize=blocksize, 84 | sample_ratio=sample_ratio, 85 | sample_seed=sample_seed, 86 | ) 87 | 88 | 89 | def read_books( 90 | path, 91 | blocksize=None, 92 | sample_ratio=1.0, 93 | sample_seed=12345, 94 | ): 95 | return _read_bag_of_text( 96 | path, 97 | blocksize=blocksize, 98 | sample_ratio=sample_ratio, 99 | sample_seed=sample_seed, 100 | ) 101 | 102 | 103 | def read_common_crawl( 104 | path, 105 | blocksize=None, 106 | sample_ratio=1.0, 107 | sample_seed=12345, 108 | ): 109 | return _read_bag_of_text( 110 | path, 111 | blocksize=blocksize, 112 | sample_ratio=sample_ratio, 113 | sample_seed=sample_seed, 114 | ) 115 | 116 | 117 | def read_open_webtext( 118 | path, 119 | blocksize=None, 120 | sample_ratio=1.0, 121 | sample_seed=12345, 122 | ): 123 | return _read_bag_of_text( 124 | path, 125 | blocksize=blocksize, 126 | sample_ratio=sample_ratio, 127 | sample_seed=sample_seed, 128 | ) 129 | 130 | 131 | def split_id_text(raw_text): 132 | # The first token is the document id. 133 | i = 0 134 | while i < len(raw_text) and not raw_text[i].isspace(): 135 | i += 1 136 | return raw_text[:i], raw_text[i + 1:] 137 | -------------------------------------------------------------------------------- /lddl/download/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/LDDL/e4006d2bad94113a39ba534d4e8af81db3c40642/lddl/download/__init__.py -------------------------------------------------------------------------------- /lddl/download/books.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import argparse 26 | import functools 27 | import multiprocessing 28 | import os 29 | import subprocess 30 | import tqdm 31 | 32 | from .utils import download, parse_str_of_num_bytes 33 | from lddl.utils import (expand_outdir_and_mkdir, mkdir, 34 | get_all_files_paths_under, attach_bool_arg) 35 | 36 | 37 | def _get_url(): 38 | return 'https://the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz' 39 | 40 | 41 | def attach_args(parser=argparse.ArgumentParser(""" 42 | Books Downloader performs the following steps: 43 | - Step 1: Download the compressed bookscorpus from {} into the directory 44 | specified by the --outdir flag. 45 | - Step 2: Unzip the compressed bookscorpus into raw text files of individual 46 | books. 47 | - Step 3: Shard the books into text shards in the 'source' subdirectory under 48 | the directory specified by the --outdir flag. The text shards under the 49 | `source` subdirectory can then be used as the input to the LDDL preprocessor. 50 | All steps are executed by default. Each step, before it starts, expects the 51 | previous steps already finish. You can turn Step 1 off by --no-download, turn 52 | Step 2 off by --no-unzip, and turn Step 3 off by --no-shard. 53 | 54 | Examples: 55 | 56 | # Download the compressed bookscorpus into books/books1.tar.gz : 57 | $ download_books --no-unzip --no-shard 58 | $ tree books/ # tree can be installed via `sudo apt install tree`. 59 | books/ 60 | └── books1.tar.gz 61 | 62 | # Unzip books/books1.tar.gz into individual books: 63 | $ download_books --no-download --no-shard 64 | $ tree books/ 65 | books/ 66 | ├── books1 67 | │   ├── 2020-08-27-epub_urls.txt 68 | │   └── epubtxt 69 | │   ├── 1000-lines-magic-sequence.epub.txt 70 | │   ├── 1000-yards-john-milton-1.epub.txt 71 | │   ... 72 | │   └── zorana-confessions-of-a-small-town-super-villain.epub.txt 73 | ├── books1.tar.gz 74 | ├── tar.err 75 | └── tar.out 76 | 77 | # Shard the books into text shards under books/source which can be read by 78 | # the LDDL preprocessor as input. 79 | $ download_books --no-download --no-unzip 80 | $ tree books/ 81 | books/ 82 | ├── books1 83 | │   ├── 2020-08-27-epub_urls.txt 84 | │   └── epubtxt 85 | │   ├── 1000-lines-magic-sequence.epub.txt 86 | │   ├── 1000-yards-john-milton-1.epub.txt 87 | │   ... 
88 | │   └── zorana-confessions-of-a-small-town-super-villain.epub.txt 89 | ├── books1.tar.gz 90 | ├── source 91 | │   ├── 0.txt 92 | │   ... 93 | │   └── 9.txt 94 | ├── tar.err 95 | └── tar.out 96 | # books/source is the input to the LDDL preprocessor. 97 | 98 | # Or, we could do all 3 steps together: 99 | $ download_books --outdir books/ 100 | """.format(_get_url()))): 101 | parser.add_argument( 102 | '--outdir', 103 | type=str, 104 | default=None, 105 | required=True, 106 | help='Path to the output directory. This directory will be created if not' 107 | ' already existed.', 108 | ) 109 | defaults = { 110 | '--download-chunk-size': 16 * 1024 * 1024, 111 | '--num-shards': 10, 112 | '--shard-num-processes': os.cpu_count(), 113 | } 114 | attach_bool_arg( 115 | parser, 116 | 'download', 117 | default=True, 118 | help_str='--download is set by default. To skip Step 1, explicitly set ' 119 | '--no-download.', 120 | ) 121 | attach_bool_arg( 122 | parser, 123 | 'unzip', 124 | default=True, 125 | help_str='--unzip is set by default. To skip Step 2, explicitly set ' 126 | '--no-unzip.', 127 | ) 128 | attach_bool_arg( 129 | parser, 130 | 'shard', 131 | default=True, 132 | help_str='--shard is set by default. To skip Step 3, explicitly set ' 133 | '--no-shard.', 134 | ) 135 | parser.add_argument( 136 | '--download-chunk-size', 137 | type=functools.partial(parse_str_of_num_bytes, return_str=False), 138 | default=defaults['--download-chunk-size'], 139 | metavar="n[KMG]", 140 | help='The downloading will be performed in a streaming way by looping ' 141 | 'over the following steps: (i) transfer a small chunk of data over the ' 142 | 'network into the host memory, (ii) write this chunk onto disk. This flag' 143 | ' indicates the chunk size. Default: {}'.format( 144 | defaults['--download-chunk-size']), 145 | ) 146 | parser.add_argument( 147 | '--num-shards', 148 | type=int, 149 | default=defaults['--num-shards'], 150 | help='The number of text shards into which the books are aggregated. ' 151 | 'Default: {}'.format(defaults['--num-shards']), 152 | ) 153 | parser.add_argument( 154 | '--shard-num-processes', 155 | type=int, 156 | default=defaults['--shard-num-processes'], 157 | help='The number of processes used to shard all books. ' 158 | 'Default: {}'.format(defaults['--shard-num-processes']), 159 | ) 160 | return parser 161 | 162 | 163 | def _shard_book(shard): 164 | shard_path, books = shard 165 | with open(shard_path, 'w', newline='\n') as shard_file: 166 | one_line_books = [] 167 | for book in books: 168 | with open(book, 'r', encoding='utf-8-sig', newline='\n') as book_file: 169 | book_lines = (bl.strip() for bl in book_file) 170 | book_lines = [bl for bl in book_lines if len(bl) > 0] 171 | # The first token is the name of the book. 
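# Each book therefore becomes a single line in the shard file: the book's file name (without extension) followed by all of its non-empty lines joined by spaces, so downstream readers can recover the document id from the first whitespace-separated token.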
172 | book_name = os.path.splitext(os.path.basename(book))[0] 173 | one_line_books.append(' '.join([book_name] + book_lines)) 174 | shard_file.write('\n'.join(one_line_books)) 175 | 176 | 177 | def _shard_books(books_dir, shards_dir, num_shards, num_processes): 178 | book_paths = [ 179 | f for f in get_all_files_paths_under(books_dir) 180 | if os.path.splitext(f)[1] == '.txt' 181 | ] 182 | shards = [( 183 | os.path.join(shards_dir, '{}.txt'.format(shard_idx)), 184 | book_paths[shard_idx::num_shards], 185 | ) for shard_idx in range(num_shards)] 186 | with multiprocessing.Pool(num_processes) as p: 187 | list(tqdm.tqdm(p.imap(_shard_book, shards), total=len(shards))) 188 | 189 | 190 | def main(args): 191 | args.outdir = expand_outdir_and_mkdir(args.outdir) 192 | target_path = os.path.join(args.outdir, 'books1.tar.gz') 193 | if args.download: 194 | download( 195 | _get_url(), 196 | target_path, 197 | chunk_size=args.download_chunk_size, 198 | ) 199 | if args.unzip: 200 | print('Unzipping {} ...'.format(target_path)) 201 | out_path = os.path.join(args.outdir, 'tar.out') 202 | err_path = os.path.join(args.outdir, 'tar.err') 203 | try: 204 | subprocess.run( 205 | ['tar', '-xvzf', target_path, '-C', args.outdir], 206 | check=True, 207 | stdout=open(out_path, 'w'), 208 | stderr=open(err_path, 'w'), 209 | ) 210 | except subprocess.CalledProcessError as e: 211 | print(e, 'Please check {} and {}'.format(out_path, err_path)) 212 | raise 213 | if args.shard: 214 | books_dir = os.path.join(args.outdir, 'books1', 'epubtxt') 215 | print('Sharding {} ...'.format(books_dir)) 216 | dask_source_path = os.path.join(args.outdir, 'source') 217 | mkdir(dask_source_path) 218 | _shard_books( 219 | books_dir, 220 | dask_source_path, 221 | args.num_shards, 222 | args.shard_num_processes, 223 | ) 224 | print('Dask source prepared at {} !'.format(dask_source_path)) 225 | 226 | 227 | def console_script(): 228 | main(attach_args().parse_args()) 229 | -------------------------------------------------------------------------------- /lddl/download/openwebtext.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import argparse 26 | import functools 27 | import multiprocessing 28 | import os 29 | import shutil 30 | from glob import glob 31 | import subprocess 32 | import tqdm 33 | import gdown 34 | 35 | from lddl.utils import ( 36 | expand_outdir_and_mkdir, 37 | mkdir, 38 | get_all_files_paths_under, 39 | attach_bool_arg, 40 | ) 41 | 42 | 43 | def attach_args(parser=argparse.ArgumentParser(""" 44 | OpenWebTextCorpus Downloader performs the following steps: 45 | - Step 1: Download OpenWebTextCorpus 46 | (https://skylion007.github.io/OpenWebTextCorpus/) 47 | from provided google drive url and extract the raw text of the articles to 48 | the directory specified by the --outdir flag. 49 | - Step 2: Prepare and aggregate the raw text into text shards in the 'source' 50 | subdirectory under the directory specified by the --outdir flag. The text 51 | shards under the 'source' subdirectory can then be used as the input to the 52 | LDDL preprocessor. 53 | All steps are executed by default. Each step, before it starts, expects the 54 | previous steps already finish. You can turn Step 1 off by --no-download, and 55 | turn Step 2 off by --no-unzip and --no-shard. 56 | """)): 57 | parser.add_argument( 58 | '--outdir', 59 | type=str, 60 | default=None, 61 | required=True, 62 | help='path to the output dir', 63 | ) 64 | attach_bool_arg( 65 | parser, 66 | 'download', 67 | default=True, 68 | help_str='--download is set by default. To skip download, explicitly set ' 69 | '--no-download.', 70 | ) 71 | attach_bool_arg( 72 | parser, 73 | 'unzip', 74 | default=True, 75 | help_str='--unzip is set by default. To skip unzip, explicitly set ' 76 | '--no-unzip.', 77 | ) 78 | attach_bool_arg( 79 | parser, 80 | 'shard', 81 | default=True, 82 | help_str='--shard is set by default. To skip shard, explicitly set ' 83 | '--no-shard.', 84 | ) 85 | parser.add_argument( 86 | '--num-shards', 87 | type=int, 88 | default=32, 89 | help='number of shards', 90 | ) 91 | parser.add_argument( 92 | '--shard-num-processes', 93 | type=int, 94 | default=os.cpu_count(), 95 | help='num of processes used to shard OpenWebTextCorpus', 96 | ) 97 | parser.add_argument( 98 | '--url', 99 | type=str, 100 | default='https://drive.google.com/uc?id=1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx', 101 | help='the google drive url of OpenWebTextCorpus', 102 | ) 103 | return parser 104 | 105 | 106 | def _shard_pages(shard): 107 | shard_path, pages = shard 108 | with open(shard_path, 'w', newline='\n') as shard_file: 109 | one_line_pages = [] 110 | for page in pages: 111 | text_paths = [ 112 | f for f in get_all_files_paths_under(page) 113 | if os.path.splitext(f)[1] == '.txt' 114 | ] 115 | page_lines = [] 116 | for text in text_paths: 117 | with open(text, 'r', encoding='utf-8-sig', newline='\n') as page_file: 118 | sub_page_lines = (pg.strip() for pg in page_file) 119 | sub_page_lines = [pg for pg in sub_page_lines if len(pg) > 0] 120 | page_lines.extend(sub_page_lines) 121 | # The first token is the name of the page. 
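# As with the books shards, each page becomes a single line in the shard file: the page (subset) name followed by all of its non-empty text lines joined by spaces; the first whitespace-separated token again serves as the document id.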
122 | page_name = os.path.splitext(os.path.basename(page))[0] 123 | one_line_pages.append(' '.join([page_name] + page_lines)) 124 | shard_file.write('\n'.join(one_line_pages)) 125 | 126 | 127 | def unzip_subset(subset, text_dir): 128 | try: 129 | subdir_name = subset.split('.xz')[0].split('/')[-1] 130 | tmpdir_name = os.path.join('/tmp', subdir_name) 131 | subdir_name = os.path.join(text_dir, subdir_name) 132 | mkdir(subdir_name) 133 | mkdir(tmpdir_name) 134 | out_path = os.path.join(tmpdir_name, 'tar.out') 135 | err_path = os.path.join(tmpdir_name, 'tar.err') 136 | subprocess.run( 137 | ['tar', '-xvf', subset, '-C', subdir_name], 138 | check=True, 139 | stdout=open(out_path, 'w'), 140 | stderr=open(err_path, 'w'), 141 | ) 142 | shutil.rmtree(tmpdir_name) 143 | except subprocess.CalledProcessError as e: 144 | print(e, 'Please check {} and {}'.format(out_path, err_path)) 145 | raise 146 | 147 | 148 | def unzip_merge_txt(openweb_dir, text_dir, num_processes): 149 | subset_paths = [ 150 | f for f in get_all_files_paths_under(openweb_dir) 151 | if os.path.splitext(f)[1] == '.xz' 152 | ] 153 | with multiprocessing.Pool(num_processes) as p: 154 | list( 155 | tqdm.tqdm(p.map(functools.partial(unzip_subset, text_dir=text_dir), 156 | subset_paths), 157 | total=len(subset_paths))) 158 | 159 | 160 | def _shard_openwebs(text_dir, shards_dir, num_shards, num_processes): 161 | dir_paths = [d for d in glob(text_dir + '/*')] 162 | shards = [( 163 | os.path.join(shards_dir, '{}.txt'.format(shard_idx)), 164 | dir_paths[shard_idx::num_shards], 165 | ) for shard_idx in range(num_shards)] 166 | with multiprocessing.Pool(num_processes) as p: 167 | list(tqdm.tqdm(p.imap(_shard_pages, shards), total=len(shards))) 168 | 169 | 170 | def main(args): 171 | args.outdir = expand_outdir_and_mkdir(args.outdir) 172 | target_path = os.path.join(args.outdir, 'openwebtext.tar.xz') 173 | if args.download: 174 | gdown.download(args.url, target_path, quiet=False) 175 | if args.unzip: 176 | print('Unzipping {} ...'.format(target_path)) 177 | out_path = os.path.join(args.outdir, 'tar.out') 178 | err_path = os.path.join(args.outdir, 'tar.err') 179 | try: 180 | subprocess.run( 181 | ['tar', '-xvf', target_path, '-C', args.outdir], 182 | check=True, 183 | stdout=open(out_path, 'w'), 184 | stderr=open(err_path, 'w'), 185 | ) 186 | except subprocess.CalledProcessError as e: 187 | print(e, 'Please check {} and {}'.format(out_path, err_path)) 188 | raise 189 | openweb_dir = os.path.join(args.outdir, 'openwebtext') 190 | text_dir = os.path.join(args.outdir, 'txt') 191 | mkdir(text_dir) 192 | unzip_merge_txt(openweb_dir, text_dir, args.shard_num_processes) 193 | 194 | if args.shard: 195 | text_dir = os.path.join(args.outdir, 'txt') 196 | print('Sharding {} ...'.format(text_dir)) 197 | dask_source_path = os.path.join(args.outdir, 'source') 198 | mkdir(dask_source_path) 199 | _shard_openwebs( 200 | text_dir, 201 | dask_source_path, 202 | args.num_shards, 203 | args.shard_num_processes, 204 | ) 205 | print('Dask source prepared at {} !'.format(dask_source_path)) 206 | 207 | 208 | def console_script(): 209 | main(attach_args().parse_args()) 210 | -------------------------------------------------------------------------------- /lddl/download/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import os 26 | import requests 27 | import tqdm 28 | 29 | 30 | def download(url, path, chunk_size=16 * 1024 * 1024): 31 | with requests.get(url, stream=True) as r: 32 | r.raise_for_status() 33 | total_size = int(r.headers.get('content-length', 0)) 34 | progress_bar = tqdm.tqdm(total=total_size, unit='Bytes', unit_scale=True) 35 | with open(path, 'wb') as f: 36 | for chunk in r.iter_content(chunk_size=chunk_size): 37 | progress_bar.update(len(chunk)) 38 | f.write(chunk) 39 | progress_bar.close() 40 | 41 | 42 | def parse_str_of_num_bytes(s, return_str=False): 43 | try: 44 | power = 'kmg'.find(s[-1].lower()) + 1 45 | size = float(s[:-1]) * 1024**power 46 | except ValueError: 47 | raise ValueError('Invalid size: {}'.format(s)) 48 | if return_str: 49 | return s 50 | else: 51 | return int(size) 52 | -------------------------------------------------------------------------------- /lddl/download/wikipedia.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | #
24 | 
25 | import argparse
26 | import functools
27 | import multiprocessing
28 | import os
29 | import subprocess
30 | import tqdm
31 | import xml.etree.ElementTree as ET
32 | 
33 | from .utils import download, parse_str_of_num_bytes
34 | from lddl.utils import (expand_outdir_and_mkdir, mkdir,
35 |                         get_all_files_paths_under, attach_bool_arg)
36 | 
37 | 
38 | def _get_url(lang):
39 |   assert lang in {'en', 'zh'}
40 |   return ('https://dumps.wikimedia.org/{lang}wiki/latest'
41 |           '/{lang}wiki-latest-pages-articles.xml.bz2'.format(lang=lang))
42 | 
43 | 
44 | def _get_download_target_filename(lang):
45 |   return 'wikicorpus-{}.xml.bz2'.format(lang)
46 | 
47 | 
48 | def _prepare_one_shard(shard):
49 |   source_shard_path, extract_shard_path = shard
50 |   articles = []
51 |   with open(extract_shard_path, 'r', newline='\n') as extract_shard_file:
52 |     article_open = None
53 |     article_lines = []
54 |     for line in extract_shard_file:
55 |       if '<doc id=' in line:
56 |         article_open = line
57 |       elif '</doc>' in line:
58 |         article_id = 'wiki-' + ET.fromstring(article_open + line).attrib['id']
59 |         article_open = None
60 |         # article_lines[0] is the title
61 |         if len(article_lines) > 1:
62 |           # The first token is the article id.
63 |           articles.append(' '.join([article_id] + article_lines[1:]))
64 |         article_lines = []
65 |       else:
66 |         if article_open:
67 |           line = line.strip()
68 |           if len(line) > 0:
69 |             article_lines.append(line.strip())
70 | 
71 |   if len(articles) > 0:
72 |     print('{} -> {}'.format(extract_shard_path, source_shard_path))
73 |     with open(source_shard_path, 'w', newline='\n') as source_shard_file:
74 |       source_shard_file.write('\n'.join(articles))
75 | 
76 | 
77 | def _prepare_dask_source(extract_path, dask_source_path, num_processes):
78 |   extracted_shards_paths = [
79 |       p for p in get_all_files_paths_under(extract_path) if 'wiki_' in p
80 |   ]
81 |   shards = [(os.path.join(dask_source_path, '{}.txt'.format(i)), esp)
82 |             for i, esp in enumerate(extracted_shards_paths)]
83 | 
84 |   with multiprocessing.Pool(num_processes) as p:
85 |     list(tqdm.tqdm(p.imap(_prepare_one_shard, shards), total=len(shards)))
86 | 
87 | 
88 | def _download_and_extract(
89 |     lang='en',
90 |     to_download=True,
91 |     to_extract=True,
92 |     to_prepare_source=True,
93 |     download_chunk_size=16 * 1024 * 1024,
94 |     extract_shard_size='128M',
95 |     outdir=None,
96 |     num_processes=os.cpu_count(),
97 | ):
98 |   if lang not in {'en', 'zh'}:
99 |     raise ValueError('Language {} not supported!'.format(lang))
100 | 
101 |   url = _get_url(lang)
102 |   target_filename = _get_download_target_filename(lang)
103 |   target_path = os.path.join(outdir, target_filename)
104 | 
105 |   if to_download:
106 |     download(url, target_path, chunk_size=download_chunk_size)
107 | 
108 |   extract_path = os.path.join(outdir, 'extracted', lang)
109 |   if to_extract:
110 |     mkdir(extract_path)
111 |     print('Extracting {} ...'.format(target_path))
112 |     subprocess.run(
113 |         [
114 |             'python',
115 |             '-m',
116 |             'wikiextractor.WikiExtractor',
117 |             target_path,
118 |             '--output',
119 |             extract_path,
120 |             '--bytes',
121 |             extract_shard_size,
122 |             '--processes',
123 |             str(num_processes),
124 |         ],
125 |         check=True,
126 |         stdout=open(os.path.join(extract_path, 'WikiExtractor.out'), 'w'),
127 |         stderr=open(os.path.join(extract_path, 'WikiExtractor.err'), 'w'),
128 |     )
129 | 
130 |   if to_prepare_source:
131 |     print('Preparing dask source from {} ...'.format(extract_path))
132 |     dask_source_path = os.path.join(outdir, 'source', lang)
133 |     mkdir(dask_source_path)
134 |     _prepare_dask_source(extract_path, dask_source_path, num_processes)
135 |     print('Dask source 
prepared at {} !'.format(dask_source_path)) 136 | 137 | 138 | def attach_args(parser=argparse.ArgumentParser(""" 139 | Wikipedia Downloader performs the following steps: 140 | - Step 1: Download the Wikipedia dumps from {} into the directory specified by 141 | the --outdir flag. 142 | - Step 2: Extract the raw text from the Wikipedia dumps which are originally in 143 | the XML format. 144 | - Step 3: Prepare and aggregate the raw text into text shards in the 'source' 145 | subdirectory under the directory specified by the --outdir flag. The text 146 | shards under the 'source' subdirectory can then be used as the input to the 147 | LDDL preprocessor. 148 | All steps are executed by default. Each step, before it starts, expects the 149 | previous steps already finish. You can turn Step 1 off by --no-download, turn 150 | Step 2 off by --no-extract, and turn Step 3 off by --no-prepare-source. 151 | 152 | Examples: 153 | 154 | # Download the English Wikipedia dumps into wikipedia/wikicorpus-en.xml.bz2 : 155 | $ download_wikipedia --outdir wikipedia/ --no-extract --no-prepare-source 156 | $ tree wikipedia/ # tree can be installed via `sudo apt install tree`. 157 | wikipedia/ 158 | └── wikicorpus-en.xml.bz2 159 | 160 | # Extract the raw text from the English Wikipedia dumps: 161 | $ download_wikipedia --outdir wikipedia/ --no-download --no-prepare-source 162 | $ tree wikipedia/ 163 | wikipedia/ 164 | ├── extracted 165 | │   └── en 166 | │   ├── AA 167 | │   │   ├── wiki_00 168 | │   │   ├── wiki_01 169 | │   │   ... 170 | │   │   └── wiki_30 171 | │   ├── WikiExtractor.err 172 | │   └── WikiExtractor.out 173 | └── wikicorpus-en.xml.bz2 174 | 175 | # Prepare and aggregate the raw text into text shards under wikipedia/source 176 | # which can be read by the LDDL preprocessor as input: 177 | $ download_wikipedia --outdir wikipedia/ --no-download --no-extract 178 | $ tree wikipedia/ 179 | wikipedia/ 180 | ├── extracted 181 | │   └── en 182 | │   ├── AA 183 | │   │   ├── wiki_00 184 | │   │   ├── wiki_01 185 | │   │   ... 186 | │   │   └── wiki_30 187 | │   ├── WikiExtractor.err 188 | │   └── WikiExtractor.out 189 | ├── source 190 | │   └── en 191 | │   ├── 0.txt 192 | │   ├── 1.txt 193 | │   ... 194 | │   └── 30.txt 195 | └── wikicorpus-en.xml.bz2 196 | # wikipedia/source/ is the input to the LDDL preprocessor. 197 | 198 | # Or, we could do all 3 steps together: 199 | $ download_wikipedia --outdir wikipedia/ 200 | """.format(_get_url('en')))): 201 | parser.add_argument( 202 | '--outdir', 203 | type=str, 204 | default=None, 205 | required=True, 206 | help='Path to the output directory. This directory will be created if not' 207 | ' already existed.', 208 | ) 209 | defaults = { 210 | '--langs': ['en'], 211 | '--download-chunk-size': 16 * 1024 * 1024, 212 | '--extract-shard-size': '512M', 213 | '--num-processes': os.cpu_count(), 214 | } 215 | parser.add_argument( 216 | '--langs', 217 | default=defaults['--langs'], 218 | nargs='+', 219 | choices=['en', 'zh'], 220 | help='Language of the wikipedia dumps to download. Default: {}'.format( 221 | defaults['--langs']), 222 | ) 223 | attach_bool_arg( 224 | parser, 225 | 'download', 226 | default=True, 227 | help_str='--download is set by default. To skip Step 1, explicitly set ' 228 | '--no-download.', 229 | ) 230 | attach_bool_arg( 231 | parser, 232 | 'extract', 233 | default=True, 234 | help_str='--extract is set by default. 
To skip Step 2, explicitly set ' 235 | '--no-extract.') 236 | attach_bool_arg( 237 | parser, 238 | 'prepare-source', 239 | default=True, 240 | help_str='--prepare-source is set by default. To skip Step 3, explicitly ' 241 | 'set --no-prepare-source.', 242 | ) 243 | parser.add_argument( 244 | '--download-chunk-size', 245 | type=functools.partial(parse_str_of_num_bytes, return_str=False), 246 | default=defaults['--download-chunk-size'], 247 | metavar="n[KMG]", 248 | help='The downloading will be performed in a streaming way by looping ' 249 | 'over the following steps: (i) transfer a small chunk of data over the ' 250 | 'network into the host memory, (ii) write this chunk onto disk. This flag' 251 | ' indicates the chunk size. Default: {}'.format( 252 | defaults['--download-chunk-size']), 253 | ) 254 | parser.add_argument( 255 | '--extract-shard-size', 256 | type=functools.partial(parse_str_of_num_bytes, return_str=True), 257 | default=defaults['--extract-shard-size'], 258 | metavar="n[KMG]", 259 | help='The size of each text shard. Default: {}'.format( 260 | defaults['--extract-shard-size']), 261 | ) 262 | parser.add_argument( 263 | '--num-processes', 264 | type=int, 265 | default=os.cpu_count(), 266 | help='Num of processes to use. Default: {}'.format( 267 | defaults['--num-processes']), 268 | ) 269 | return parser 270 | 271 | 272 | def main(args): 273 | args.outdir = expand_outdir_and_mkdir(args.outdir) 274 | for lang in args.langs: 275 | _download_and_extract( 276 | lang=lang, 277 | to_download=args.download, 278 | to_extract=args.extract, 279 | to_prepare_source=args.prepare_source, 280 | download_chunk_size=args.download_chunk_size, 281 | extract_shard_size=args.extract_shard_size, 282 | outdir=args.outdir, 283 | num_processes=args.num_processes, 284 | ) 285 | 286 | 287 | def console_script(): 288 | main(attach_args().parse_args()) 289 | -------------------------------------------------------------------------------- /lddl/paddle/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import get_bert_pretrain_data_loader 2 | -------------------------------------------------------------------------------- /lddl/paddle/dataloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL
19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22 | # DEALINGS IN THE SOFTWARE.
23 | #
24 | 
25 | import random
26 | import paddle
27 | 
28 | from lddl.random import choices
29 | from .datasets import ParquetDataset
30 | 
31 | 
32 | class Binned:
33 | 
34 |   def __init__(self, dataloaders, base_seed=12345, start_epoch=0, logger=None):
35 |     self._dataloaders = dataloaders
36 | 
37 |     self._base_seed = base_seed
38 |     self._epoch = start_epoch - 1
39 | 
40 |     self._logger = logger
41 | 
42 |     self._world_rng_state = None
43 | 
44 |   def _init_rng_states(self):
45 |     orig_rng_state = random.getstate()
46 | 
47 |     random.seed(self._base_seed + self._epoch)
48 |     self._world_rng_state = random.getstate()
49 | 
50 |     random.setstate(orig_rng_state)
51 | 
52 |   def _init_iter(self):
53 |     self._init_rng_states()
54 |     num_samples_remaining = [len(dl.dataset) for dl in self._dataloaders]
55 |     dataiters = [iter(dl) for dl in self._dataloaders]
56 |     return num_samples_remaining, dataiters
57 | 
58 |   def __len__(self):
59 |     return sum((len(dl) for dl in self._dataloaders))
60 | 
61 |   def _get_batch_size(self, batch):
62 |     raise NotImplementedError('Binned is an abstract class!')
63 | 
64 |   def _choices(self, population, weights=None, cum_weights=None, k=1):
65 |     c, self._world_rng_state = choices(
66 |         population,
67 |         weights=weights,
68 |         cum_weights=cum_weights,
69 |         k=k,
70 |         rng_state=self._world_rng_state,
71 |     )
72 |     return c
73 | 
74 |   def __iter__(self):
75 |     self._epoch += 1
76 |     num_samples_remaining, dataiters = self._init_iter()
77 | 
78 |     for i in range(len(self)):
79 |       bin_id = self._choices(
80 |           list(range(len(dataiters))),
81 |           weights=num_samples_remaining,
82 |           k=1,
83 |       )[0]
84 |       self._logger.to('rank').info('{}-th iteration selects bin_id = {}'.format(
85 |           i, bin_id))
86 |       assert num_samples_remaining[bin_id] > 0
87 |       batch = next(dataiters[bin_id])
88 |       num_samples_remaining[bin_id] -= self._get_batch_size(batch)
89 |       yield batch
90 | 
91 |     assert sum((nsr for nsr in num_samples_remaining)) == 0
92 | 
93 | 
94 | class DataLoader(paddle.io.DataLoader):
95 | 
96 |   def __len__(self):
97 |     if isinstance(self.dataset, ParquetDataset):
98 |       num_workers_per_rank = max(self.num_workers, 1)
99 |       num_files_per_worker = self.dataset.num_files_per_rank // num_workers_per_rank
100 |       num_samples_per_worker = self.dataset.num_samples_per_file * num_files_per_worker
101 |       num_batches_per_worker = (
102 |           (num_samples_per_worker - 1) // self.batch_size + 1)
103 |       return num_batches_per_worker * num_workers_per_rank
104 |     else:
105 |       return super().__len__()
106 | 
--------------------------------------------------------------------------------
/lddl/paddle/datasets.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import json 26 | import numpy as np 27 | import os 28 | import pyarrow.parquet as pq 29 | import random 30 | import paddle 31 | 32 | from paddle.io import IterableDataset, get_worker_info 33 | try: 34 | from paddle.base.framework import in_dygraph_mode 35 | except ImportError: 36 | from paddle.fluid.framework import in_dygraph_mode 37 | 38 | from lddl.types import File 39 | from lddl.utils import get_num_samples_of_parquet 40 | from lddl.random import randrange, shuffle, sample 41 | from .utils import (get_rank, get_local_rank, get_world_size, 42 | get_nproc_per_node, get_num_nodes, get_node_rank, 43 | all_reduce_in_static_mode) 44 | 45 | 46 | class ShuffleBuffer: 47 | 48 | def __init__( 49 | self, 50 | files, 51 | max_num_samples_to_yield, 52 | decode_record_batch, 53 | size, 54 | warmup_factor, 55 | logger, 56 | rng_state, 57 | ): 58 | num_samples_wasted = (sum( 59 | (f.num_samples for f in files)) - max_num_samples_to_yield) 60 | assert 0 <= num_samples_wasted <= len(files) 61 | 62 | self._files = files 63 | self._max_num_samples_to_yield = max_num_samples_to_yield 64 | self._decode_record_batch = decode_record_batch 65 | self._size = size 66 | self._warmup_factor = warmup_factor 67 | self._logger = logger 68 | self._rng_state = rng_state 69 | 70 | @property 71 | def num_samples(self): 72 | return sum((f.num_samples for f in self._files)) 73 | 74 | def _randrange(self, stop): 75 | n, self._rng_state = randrange(stop, rng_state=self._rng_state) 76 | return n 77 | 78 | def _shuffle(self, x): 79 | self._rng_state = shuffle(x, rng_state=self._rng_state) 80 | 81 | def __iter__(self): 82 | buffer = [] 83 | num_samples_to_yield = min( 84 | self._max_num_samples_to_yield, 85 | sum((f.num_samples for f in self._files)), 86 | ) 87 | remaining_num_samples = num_samples_to_yield 88 | 89 | for f in self._files: 90 | self._logger.to('worker').info('Reading {}'.format(f.path)) 91 | for b in pq.read_table(f.path).to_batches(): 92 | for isample in self._decode_record_batch(b): 93 | if remaining_num_samples <= 0: 94 | return 95 | if (len(buffer) 96 | >= min(self._size, 97 | (num_samples_to_yield - remaining_num_samples + 1) * 98 | self._warmup_factor)): 99 | replace_idx = self._randrange(len(buffer)) 100 | yield buffer[replace_idx] 101 | buffer[replace_idx] = isample 102 | remaining_num_samples -= 1 103 | else: 104 | 
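          # Warmup note: until the buffer has grown to
          # min(size, (num_yielded + 1) * warmup_factor), samples are only
          # absorbed here and nothing is yielded, so early draws come from a
          # reasonably well-mixed buffer.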
buffer.append(isample) 105 | self._shuffle(buffer) 106 | for isample in buffer: 107 | if remaining_num_samples <= 0: 108 | return 109 | yield isample 110 | remaining_num_samples -= 1 111 | 112 | 113 | class ParquetDataset(IterableDataset): 114 | 115 | def __init__( 116 | self, 117 | file_paths, 118 | transform=lambda x: x, 119 | shuffle_buffer_size=16384, 120 | shuffle_buffer_warmup_factor=16, 121 | base_seed=12345, 122 | logger=None, 123 | start_epoch=0, 124 | ): 125 | super().__init__() 126 | self._transform = transform 127 | self._local_rank = get_local_rank 128 | self._shuffle_buffer_size = shuffle_buffer_size 129 | self._shuffle_buffer_warmup_factor = shuffle_buffer_warmup_factor 130 | self._base_seed = base_seed 131 | 132 | self._rank = get_rank() 133 | self._world_size = get_world_size() 134 | self._nproc_per_node = get_nproc_per_node() 135 | self._num_nodes = get_num_nodes() 136 | self._node_rank = get_node_rank() 137 | 138 | self._epoch = start_epoch - 1 139 | 140 | self._logger = logger 141 | 142 | assert len(file_paths) % self._num_nodes == 0 143 | assert len(file_paths) % self._world_size == 0 144 | self._files = self._get_files(file_paths) 145 | max_num_samples_per_file = max((f.num_samples for f in self._files)) 146 | min_num_samples_per_file = min((f.num_samples for f in self._files)) 147 | assert min_num_samples_per_file in { 148 | max_num_samples_per_file - 1, 149 | max_num_samples_per_file, 150 | } 151 | self._num_samples_per_file = min_num_samples_per_file 152 | total_num_samples = sum((f.num_samples for f in self._files)) 153 | num_samples_lost = (total_num_samples - 154 | self._num_samples_per_file * len(self._files)) 155 | self._logger.to('node').warning('lost {}/{}={}% samples in total'.format( 156 | num_samples_lost, 157 | total_num_samples, 158 | num_samples_lost / total_num_samples * 100, 159 | )) 160 | 161 | self._world_rng_state = None 162 | self._worker_rng_state = None 163 | 164 | def _get_files(self, file_paths): 165 | if in_dygraph_mode(): 166 | all_files_num_samples = paddle.zeros((len(file_paths),), dtype='int64') 167 | else: 168 | all_files_num_samples = np.zeros((len(file_paths),), dtype=np.int64) 169 | # Figure out how many samples in each file. 170 | num_samples_cache = {} # Map dirname to the dict of {basename: num_samples} 171 | for idx in range(self._rank, len(file_paths), self._world_size): 172 | fp = file_paths[idx] 173 | dn = os.path.dirname(fp) 174 | bn = os.path.basename(fp) 175 | # Load the num_samples cache file if it exists. 176 | if dn not in num_samples_cache: 177 | nsfp = os.path.join(dn, '.num_samples.json') 178 | try: 179 | with open(nsfp, 'r') as nsf: 180 | num_samples_cache[dn] = json.load(nsf) 181 | except Exception as e: 182 | self._logger.to('rank').warning('failed to load {}: {}'.format( 183 | nsfp, e)) 184 | # Mark that the num_samples cache file doesn't exist for this 185 | # directory. 186 | num_samples_cache[dn] = None 187 | if num_samples_cache[dn] is not None and bn in num_samples_cache[dn]: 188 | all_files_num_samples[idx] = num_samples_cache[dn][bn] 189 | else: 190 | # Find out num_samples by loading the parquet table. 191 | all_files_num_samples[idx] = get_num_samples_of_parquet(fp) 192 | if self._world_size > 1: 193 | # Sync. accross all ranks. 
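      # In dygraph mode the counts tensor can be all-reduced directly; in
      # static mode a small startup/main program is built and executed instead
      # (see all_reduce_in_static_mode in lddl/paddle/utils.py).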
194 | if in_dygraph_mode(): 195 | paddle.distributed.all_reduce( 196 | all_files_num_samples, 197 | op=paddle.distributed.ReduceOp.SUM, 198 | ) 199 | else: 200 | all_files_num_samples = all_reduce_in_static_mode( 201 | all_files_num_samples, paddle.distributed.ReduceOp.SUM) 202 | all_files_num_samples = all_files_num_samples.tolist() 203 | return [File(fp, ns) for fp, ns in zip(file_paths, all_files_num_samples)] 204 | 205 | def __len__(self): 206 | """ This function only returns how many samples per rank will be yielded 207 | by this dataset. 208 | 209 | Note that, len(dataloader), where dataloader is a PaddlePaddle DataLoader 210 | wrapping this dataset, does NOT return the accurate number of batches. This 211 | is because, when (num_samples_per_file * num_files_per_worker) is not 212 | divisible by batch_size, each worker is going to generate a partial batch 213 | at the very end. 214 | 215 | However, PaddlePaddle DataLoader's __len__ only divide the number returned from 216 | this function by batch_size, which would be smaller than the actual number 217 | of batches by at most (num_workers - 1). 218 | 219 | We need to patch PaddlePaddle DataLoader function for this function to behave 220 | correctly. 221 | """ 222 | return self._num_samples_per_file * len(self._files) // self._world_size 223 | 224 | @property 225 | def num_samples_per_file(self): 226 | return self._num_samples_per_file 227 | 228 | @property 229 | def num_files_per_rank(self): 230 | return len(self._files) // self._world_size 231 | 232 | def _decode_record_batch(self, b): 233 | raise NotImplementedError('ParquetDataset is an abstract/interface class!') 234 | 235 | def _world_identical_sample(self, population, k, counts=None): 236 | s, self._world_rng_state = sample( 237 | population, 238 | k, 239 | rng_state=self._world_rng_state, 240 | ) 241 | return s 242 | 243 | def _init_worker(self): 244 | worker_info = get_worker_info() 245 | if worker_info is None: 246 | num_workers_per_rank = 1 247 | worker_rank = 0 248 | else: 249 | num_workers_per_rank = worker_info.num_workers 250 | worker_rank = worker_info.id 251 | assert (len(self._files) % (self._world_size * num_workers_per_rank) == 0) 252 | self._logger.init_for_worker(worker_rank) 253 | return worker_rank, num_workers_per_rank 254 | 255 | def _init_rng_states(self, worker_rank, num_workers_per_rank): 256 | orig_rng_state = random.getstate() 257 | 258 | random.seed(self._base_seed + self._epoch) 259 | self._world_rng_state = random.getstate() 260 | 261 | random.seed(self._base_seed + 262 | (self._epoch * self._world_size + self._rank) * 263 | num_workers_per_rank + worker_rank) 264 | self._worker_rng_state = random.getstate() 265 | 266 | random.setstate(orig_rng_state) 267 | 268 | def __iter__(self): 269 | self._epoch += 1 270 | 271 | worker_rank, num_workers_per_rank = self._init_worker() 272 | self._init_rng_states(worker_rank, num_workers_per_rank) 273 | 274 | files = self._world_identical_sample(self._files, k=len(self._files)) 275 | self._logger.to('node').warning('epoch = {}'.format(self._epoch)) 276 | self._logger.to('worker').info( 277 | '\n'.join(['files('] + [' {}'.format(f) for f in files] + [')'])) 278 | 279 | rank_files = files[self._rank::self._world_size] 280 | worker_files = rank_files[worker_rank::num_workers_per_rank] 281 | self._logger.to('worker').info( 282 | '\n'.join(['worker_files('] + [' {}'.format(f) for f in worker_files] + 283 | [')'])) 284 | sb = ShuffleBuffer( 285 | worker_files, 286 | self._num_samples_per_file * len(worker_files), 287 | 
lambda b: self._decode_record_batch(b), 288 | self._shuffle_buffer_size, 289 | self._shuffle_buffer_warmup_factor, 290 | self._logger, 291 | self._worker_rng_state, 292 | ) 293 | for isample in iter(sb): 294 | yield self._transform(isample) 295 | -------------------------------------------------------------------------------- /lddl/paddle/log.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import logging 26 | import os 27 | import pathlib 28 | from .utils import (get_local_rank, get_node_rank) 29 | 30 | 31 | def _get_logger_name(node_rank, local_rank=None, worker_rank=None): 32 | if local_rank is None and worker_rank is None: 33 | return 'node-{}'.format(node_rank) 34 | elif worker_rank is None: 35 | return 'node-{}_local-{}'.format(node_rank, local_rank) 36 | else: 37 | return 'node-{}_local-{}_worker-{}'.format(node_rank, local_rank, 38 | worker_rank) 39 | 40 | 41 | class DummyLogger: 42 | 43 | def debug(self, msg, *args, **kwargs): 44 | pass 45 | 46 | def info(self, msg, *args, **kwargs): 47 | pass 48 | 49 | def warning(self, msg, *args, **kwargs): 50 | pass 51 | 52 | def error(self, msg, *args, **kwargs): 53 | pass 54 | 55 | def critical(self, msg, *args, **kwargs): 56 | pass 57 | 58 | def log(self, msg, *args, **kwargs): 59 | pass 60 | 61 | def exception(self, msg, *args, **kwargs): 62 | pass 63 | 64 | 65 | class DatasetLogger: 66 | 67 | def __init__( 68 | self, 69 | log_dir=None, 70 | log_level=logging.INFO, 71 | ): 72 | self._log_dir = log_dir 73 | self._node_rank = get_node_rank() 74 | self._local_rank = get_local_rank() 75 | self._worker_rank = None 76 | self._log_level = log_level 77 | 78 | if log_dir is not None: 79 | pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) 80 | # Create node level logger. 81 | if self._local_rank == 0: 82 | self._create_logger(_get_logger_name(self._node_rank)) 83 | # Create local_rank level logger. 
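    # Worker-level loggers are created lazily via init_for_worker(), which the
    # dataset calls once it is being iterated inside a dataloader worker
    # process.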
84 | self._create_logger( 85 | _get_logger_name(self._node_rank, local_rank=self._local_rank)) 86 | 87 | def _create_logger(self, name): 88 | logger = logging.getLogger(name) 89 | fmt = logging.Formatter( 90 | 'LDDL - %(asctime)s - %(filename)s:%(lineno)d:%(funcName)s - %(name)s ' 91 | '- %(levelname)s : %(message)s') 92 | stream_handler = logging.StreamHandler() 93 | stream_handler.setFormatter(fmt) 94 | logger.addHandler(stream_handler) 95 | if self._log_dir is not None: 96 | path = os.path.join(self._log_dir, '{}.txt'.format(name)) 97 | file_handler = logging.FileHandler(path) 98 | file_handler.setFormatter(fmt) 99 | logger.addHandler(file_handler) 100 | logger.setLevel(self._log_level) 101 | return logger 102 | 103 | def init_for_worker(self, worker_rank): 104 | if self._worker_rank is None: 105 | self._worker_rank = worker_rank 106 | self._create_logger( 107 | _get_logger_name( 108 | self._node_rank, 109 | local_rank=self._local_rank, 110 | worker_rank=worker_rank, 111 | )) 112 | 113 | def to(self, which): 114 | assert which in {'node', 'rank', 'worker'} 115 | if which == 'node': 116 | if (self._local_rank == 0 and 117 | (self._worker_rank is None or self._worker_rank == 0)): 118 | return logging.getLogger(_get_logger_name(self._node_rank)) 119 | else: 120 | return DummyLogger() 121 | elif which == 'rank': 122 | if self._worker_rank is None or self._worker_rank == 0: 123 | return logging.getLogger( 124 | _get_logger_name(self._node_rank, local_rank=self._local_rank)) 125 | else: 126 | return DummyLogger() 127 | else: # which == 'worker' 128 | return logging.getLogger( 129 | _get_logger_name( 130 | self._node_rank, 131 | local_rank=self._local_rank, 132 | worker_rank=self._worker_rank, 133 | )) 134 | -------------------------------------------------------------------------------- /lddl/paddle/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import os 26 | import paddle 27 | try: 28 | from paddle.base.framework import in_dygraph_mode 29 | from paddle.base import unique_name, core 30 | except ImportError: 31 | from paddle.fluid.framework import in_dygraph_mode 32 | from paddle.fluid import unique_name, core 33 | from paddle.distributed.fleet.base.private_helper_function import wait_server_ready 34 | 35 | 36 | def get_rank(): 37 | return int(os.getenv("PADDLE_TRAINER_ID", "0")) 38 | 39 | 40 | def get_local_rank(): 41 | return int(os.getenv('PADDLE_RANK_IN_NODE', '0')) 42 | 43 | 44 | def get_world_size(): 45 | return int(os.getenv('PADDLE_TRAINERS_NUM', '1')) 46 | 47 | 48 | def barrier(): 49 | if get_world_size() > 1: 50 | paddle.distributed.barrier() 51 | 52 | 53 | def get_endpoints(): 54 | endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS') 55 | return endpoints.split(",") 56 | 57 | 58 | def get_current_endpoint(): 59 | return os.getenv("PADDLE_CURRENT_ENDPOINT") 60 | 61 | 62 | def get_other_endpoints(): 63 | other_endpoints = get_endpoints()[:] 64 | current_endpoint = get_current_endpoint() 65 | other_endpoints.remove(current_endpoint) 66 | return other_endpoints 67 | 68 | 69 | def get_num_nodes(): 70 | # paddle_local_size = int(os.getenv('PADDLE_LOCAL_SIZE', '-1')) 71 | endpoints = get_endpoints()[:] 72 | ips = set() 73 | for endpoint in endpoints: 74 | ip = endpoint.split(":")[0] 75 | ips.add(ip) 76 | return len(ips) 77 | 78 | 79 | def get_nproc_per_node(): 80 | return get_world_size() // get_num_nodes() 81 | 82 | 83 | def get_node_rank(): 84 | """ This assume the training processes are launched via 85 | paddle.distributed.launch.py. Therefore, the ordering scheme of 86 | rank -> (node_rank, local_rank) mapping is: 87 | 0 -> (0, 0) 88 | 1 -> (0, 1) 89 | ... 90 | nproc_per_node -> (1, 0) 91 | nproc_per_node+1 -> (1, 1) 92 | ... 
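      For example, in a hypothetical 2-node job with 4 trainers per node
      (world size 8), rank 5 maps to (node_rank, local_rank) = (1, 1), since
      node_rank = 5 // nproc_per_node = 5 // 4 = 1.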
93 | """ 94 | nproc_per_node = get_nproc_per_node() 95 | node_rank = get_rank() // nproc_per_node 96 | return node_rank 97 | 98 | 99 | def all_reduce_in_static_mode(local_tensor, reduce_op): 100 | assert not in_dygraph_mode(), "this function can only be used in static mode" 101 | rank = get_rank() 102 | local_rank = get_local_rank() 103 | nranks = get_world_size() 104 | current_endpoint = get_current_endpoint() 105 | other_endpoints = get_other_endpoints() 106 | device = paddle.set_device("gpu") 107 | if rank == 0: 108 | wait_server_ready(other_endpoints) 109 | 110 | startup_program = paddle.static.Program() 111 | main_program = paddle.static.Program() 112 | exe = paddle.static.Executor(device) 113 | 114 | block = startup_program.global_block() 115 | nccl_id_var = block.create_var( 116 | name=unique_name.generate('nccl_id'), 117 | persistable=True, 118 | type=core.VarDesc.VarType.RAW, 119 | ) 120 | 121 | block.append_op( 122 | type='c_gen_nccl_id', 123 | inputs={}, 124 | outputs={'Out': nccl_id_var}, 125 | attrs={ 126 | 'rank': rank, 127 | 'endpoint': current_endpoint, 128 | 'other_endpoints': other_endpoints, 129 | }, 130 | ) 131 | 132 | block.append_op( 133 | type='c_comm_init', 134 | inputs={'X': nccl_id_var}, 135 | outputs={}, 136 | attrs={ 137 | 'nranks': nranks, 138 | 'rank': rank, 139 | 'ring_id': 0 140 | }, 141 | ) 142 | 143 | with paddle.static.program_guard(main_program, startup_program): 144 | data = paddle.static.data(name='local_value', shape=[-1], dtype='int64') 145 | paddle.distributed.all_reduce(data, op=reduce_op) 146 | 147 | exe.run(startup_program) 148 | results = exe.run(main_program, 149 | feed={'local_value': local_tensor}, 150 | fetch_list=[data.name]) 151 | return results[0] 152 | -------------------------------------------------------------------------------- /lddl/random.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import random 26 | 27 | 28 | def _swap_rng_state(new_state): 29 | old_state = random.getstate() 30 | random.setstate(new_state) 31 | return old_state 32 | 33 | 34 | def randrange(stop, rng_state=None): 35 | orig_rng_state = _swap_rng_state(rng_state) 36 | n = random.randrange(stop) 37 | return n, _swap_rng_state(orig_rng_state) 38 | 39 | 40 | def shuffle(x, rng_state=None): 41 | orig_rng_state = _swap_rng_state(rng_state) 42 | random.shuffle(x) 43 | return _swap_rng_state(orig_rng_state) 44 | 45 | 46 | def sample(population, k, rng_state=None): 47 | orig_rng_state = _swap_rng_state(rng_state) 48 | s = random.sample(population, k) 49 | return s, _swap_rng_state(orig_rng_state) 50 | 51 | 52 | def choices(population, weights=None, cum_weights=None, k=1, rng_state=None): 53 | orig_rng_state = _swap_rng_state(rng_state) 54 | c = random.choices(population, weights=weights, cum_weights=cum_weights, k=k) 55 | return c, _swap_rng_state(orig_rng_state) 56 | -------------------------------------------------------------------------------- /lddl/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import get_bert_pretrain_data_loader 2 | -------------------------------------------------------------------------------- /lddl/torch/dataloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | #
24 | 
25 | import random
26 | import torch
27 | 
28 | from lddl.random import choices
29 | from .datasets import ParquetDataset
30 | 
31 | 
32 | class Binned:
33 | 
34 |   def __init__(self, dataloaders, base_seed=12345, start_epoch=0, logger=None):
35 |     self._dataloaders = dataloaders
36 | 
37 |     self._base_seed = base_seed
38 |     self._epoch = start_epoch - 1
39 | 
40 |     self._logger = logger
41 | 
42 |     self._world_rng_state = None
43 | 
44 |   def _init_rng_states(self):
45 |     orig_rng_state = random.getstate()
46 | 
47 |     random.seed(self._base_seed + self._epoch)
48 |     self._world_rng_state = random.getstate()
49 | 
50 |     random.setstate(orig_rng_state)
51 | 
52 |   def _init_iter(self):
53 |     self._init_rng_states()
54 |     num_samples_remaining = [len(dl.dataset) for dl in self._dataloaders]
55 |     dataiters = [iter(dl) for dl in self._dataloaders]
56 |     return num_samples_remaining, dataiters
57 | 
58 |   def __len__(self):
59 |     return sum((len(dl) for dl in self._dataloaders))
60 | 
61 |   def _get_batch_size(self, batch):
62 |     raise NotImplementedError('Binned is an abstract class!')
63 | 
64 |   def _choices(self, population, weights=None, cum_weights=None, k=1):
65 |     c, self._world_rng_state = choices(
66 |         population,
67 |         weights=weights,
68 |         cum_weights=cum_weights,
69 |         k=k,
70 |         rng_state=self._world_rng_state,
71 |     )
72 |     return c
73 | 
74 |   def __iter__(self):
75 |     self._epoch += 1
76 |     num_samples_remaining, dataiters = self._init_iter()
77 | 
78 |     for i in range(len(self)):
79 |       bin_id = self._choices(
80 |           list(range(len(dataiters))),
81 |           weights=num_samples_remaining,
82 |           k=1,
83 |       )[0]
84 |       self._logger.to('rank').info('{}-th iteration selects bin_id = {}'.format(
85 |           i, bin_id))
86 |       assert num_samples_remaining[bin_id] > 0
87 |       batch = next(dataiters[bin_id])
88 |       num_samples_remaining[bin_id] -= self._get_batch_size(batch)
89 |       yield batch
90 | 
91 |     assert sum((nsr for nsr in num_samples_remaining)) == 0
92 | 
93 | 
94 | class DataLoader(torch.utils.data.DataLoader):
95 | 
96 |   def __len__(self):
97 |     if isinstance(self.dataset, ParquetDataset):
98 |       num_workers_per_rank = max(self.num_workers, 1)
99 |       num_files_per_worker = self.dataset.num_files_per_rank // num_workers_per_rank
100 |       num_samples_per_worker = self.dataset.num_samples_per_file * num_files_per_worker
101 |       num_batches_per_worker = (
102 |           (num_samples_per_worker - 1) // self.batch_size + 1)
103 |       return num_batches_per_worker * num_workers_per_rank
104 |     else:
105 |       return super().__len__()
106 | 
--------------------------------------------------------------------------------
/lddl/torch/datasets.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | # SPDX-License-Identifier: MIT
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a
7 | # copy of this software and associated documentation files (the "Software"),
8 | # to deal in the Software without restriction, including without limitation
9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
10 | # and/or sell copies of the Software, and to permit persons to whom the
11 | # Software is furnished to do so, subject to the following conditions:
12 | #
13 | # The above copyright notice and this permission notice shall be included in
14 | # all copies or substantial portions of the Software. 
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import itertools 26 | import json 27 | import logging 28 | import numpy as np 29 | import os 30 | import pathlib 31 | import pyarrow.parquet as pq 32 | import random 33 | import torch 34 | import warnings 35 | 36 | from torch.utils.data import IterableDataset 37 | from torch.utils.data import get_worker_info 38 | 39 | from lddl.types import File 40 | from lddl.utils import get_num_samples_of_parquet 41 | from lddl.random import randrange, shuffle, sample 42 | from .utils import (get_rank, get_world_size, get_nproc_per_node, get_num_nodes, 43 | get_node_rank) 44 | 45 | 46 | class ShuffleBuffer: 47 | 48 | def __init__( 49 | self, 50 | files, 51 | max_num_samples_to_yield, 52 | decode_record_batch, 53 | size, 54 | warmup_factor, 55 | logger, 56 | rng_state, 57 | ): 58 | num_samples_wasted = (sum( 59 | (f.num_samples for f in files)) - max_num_samples_to_yield) 60 | assert 0 <= num_samples_wasted <= len(files) 61 | 62 | self._files = files 63 | self._max_num_samples_to_yield = max_num_samples_to_yield 64 | self._decode_record_batch = decode_record_batch 65 | self._size = size 66 | self._warmup_factor = warmup_factor 67 | self._logger = logger 68 | self._rng_state = rng_state 69 | 70 | @property 71 | def num_samples(self): 72 | return sum((f.num_samples for f in self._files)) 73 | 74 | def _randrange(self, stop): 75 | n, self._rng_state = randrange(stop, rng_state=self._rng_state) 76 | return n 77 | 78 | def _shuffle(self, x): 79 | self._rng_state = shuffle(x, rng_state=self._rng_state) 80 | 81 | def __iter__(self): 82 | buffer = [] 83 | num_samples_to_yield = min( 84 | self._max_num_samples_to_yield, 85 | sum((f.num_samples for f in self._files)), 86 | ) 87 | remaining_num_samples = num_samples_to_yield 88 | 89 | for f in self._files: 90 | self._logger.to('worker').info('Reading {}'.format(f.path)) 91 | for b in pq.read_table(f.path).to_batches(): 92 | for sample in self._decode_record_batch(b): 93 | if remaining_num_samples <= 0: 94 | return 95 | if (len(buffer) >= min( 96 | self._size, (num_samples_to_yield - remaining_num_samples + 1) * 97 | self._warmup_factor)): 98 | replace_idx = self._randrange(len(buffer)) 99 | yield buffer[replace_idx] 100 | buffer[replace_idx] = sample 101 | remaining_num_samples -= 1 102 | else: 103 | buffer.append(sample) 104 | self._shuffle(buffer) 105 | for sample in buffer: 106 | if remaining_num_samples <= 0: 107 | return 108 | yield sample 109 | remaining_num_samples -= 1 110 | 111 | 112 | class ParquetDataset(IterableDataset): 113 | 114 | def __init__( 115 | self, 116 | file_paths, 117 | transform=lambda x: x, 118 | local_rank=0, 119 | shuffle_buffer_size=16384, 120 | shuffle_buffer_warmup_factor=16, 121 | base_seed=12345, 122 | logger=None, 123 | start_epoch=0, 124 | ): 125 | super().__init__() 126 | self._transform = transform 127 | self._local_rank = local_rank 128 | self._shuffle_buffer_size = shuffle_buffer_size 129 | self._shuffle_buffer_warmup_factor = shuffle_buffer_warmup_factor 130 | self._base_seed = base_seed 131 | 132 | 
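    # Query the distributed topology up front; the shard files are partitioned
    # evenly across nodes, ranks and dataloader workers, so these counts must
    # divide len(file_paths) (see the assertions below).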
self._rank = get_rank() 133 | self._world_size = get_world_size() 134 | self._nproc_per_node = get_nproc_per_node(local_rank) 135 | self._num_nodes = get_num_nodes(nproc_per_node=self._nproc_per_node) 136 | self._node_rank = get_node_rank(nproc_per_node=self._nproc_per_node) 137 | 138 | self._epoch = start_epoch - 1 139 | 140 | self._logger = logger 141 | 142 | assert len(file_paths) % self._num_nodes == 0 143 | assert len(file_paths) % self._world_size == 0 144 | self._files = self._get_files(file_paths) 145 | max_num_samples_per_file = max((f.num_samples for f in self._files)) 146 | min_num_samples_per_file = min((f.num_samples for f in self._files)) 147 | assert min_num_samples_per_file + 1 == max_num_samples_per_file 148 | self._num_samples_per_file = min_num_samples_per_file 149 | total_num_samples = sum((f.num_samples for f in self._files)) 150 | num_samples_lost = (total_num_samples - 151 | self._num_samples_per_file * len(self._files)) 152 | self._logger.to('node').warning('lost {}/{}={}% samples in total'.format( 153 | num_samples_lost, 154 | total_num_samples, 155 | num_samples_lost / total_num_samples * 100, 156 | )) 157 | 158 | self._world_rng_state = None 159 | self._worker_rng_state = None 160 | 161 | def _get_files(self, file_paths): 162 | all_files_num_samples = torch.zeros((len(file_paths),), dtype=torch.long) 163 | if self._world_size > 1 and torch.distributed.get_backend() == 'nccl': 164 | all_files_num_samples = all_files_num_samples.to('cuda') 165 | # Figure out how many samples in each file. 166 | num_samples_cache = {} # Map dirname to the dict of {basename: num_samples} 167 | for idx in range(self._rank, len(file_paths), self._world_size): 168 | fp = file_paths[idx] 169 | dn = os.path.dirname(fp) 170 | bn = os.path.basename(fp) 171 | # Load the num_samples cache file if it exists. 172 | if dn not in num_samples_cache: 173 | nsfp = os.path.join(dn, '.num_samples.json') 174 | try: 175 | with open(nsfp, 'r') as nsf: 176 | num_samples_cache[dn] = json.load(nsf) 177 | except Exception as e: 178 | self._logger.to('rank').warning('failed to load {}: {}'.format( 179 | nsfp, e)) 180 | # Mark that the num_samples cache file doesn't exist for this 181 | # directory. 182 | num_samples_cache[dn] = None 183 | if num_samples_cache[dn] is not None and bn in num_samples_cache[dn]: 184 | all_files_num_samples[idx] = num_samples_cache[dn][bn] 185 | else: 186 | # Find out num_samples by loading the parquet table. 187 | all_files_num_samples[idx] = get_num_samples_of_parquet(fp) 188 | if self._world_size > 1: 189 | # Sync. accross all ranks. 190 | torch.distributed.all_reduce( 191 | all_files_num_samples, 192 | op=torch.distributed.ReduceOp.SUM, 193 | ) 194 | all_files_num_samples = all_files_num_samples.tolist() 195 | return [File(fp, ns) for fp, ns in zip(file_paths, all_files_num_samples)] 196 | 197 | def __len__(self): 198 | """ This function only returns how many samples per rank will be yielded 199 | by this dataset. 200 | 201 | Note that, len(dataloader), where dataloader is a PyTorch DataLoader 202 | wrapping this dataset, does NOT return the accurate number of batches. This 203 | is because, when (num_samples_per_file * num_files_per_worker) is not 204 | divisible by batch_size, each worker is going to generate a partial batch 205 | at the very end. 206 | 207 | However, PyTorch DataLoader's __len__ only divide the number returned from 208 | this function by batch_size, which would be smaller than the actual number 209 | of batches by at most (num_workers - 1). 
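    For example (hypothetical numbers): with 3 workers per rank, 1 file per
    worker, 100 samples per file and batch_size=64, each worker emits
    ceil(100 / 64) = 2 batches, i.e. 6 batches per rank, whereas dividing the
    300 samples returned here by 64 gives at most 5.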
210 | 211 | We need to patch PyTorch DataLoader function for this function to behave 212 | correctly. 213 | """ 214 | return self._num_samples_per_file * len(self._files) // self._world_size 215 | 216 | @property 217 | def num_samples_per_file(self): 218 | return self._num_samples_per_file 219 | 220 | @property 221 | def num_files_per_rank(self): 222 | return len(self._files) // self._world_size 223 | 224 | def _decode_record_batch(self, b): 225 | raise NotImplementedError('ParquetDataset is an abstract/interface class!') 226 | 227 | def _world_identical_sample(self, population, k, counts=None): 228 | s, self._world_rng_state = sample( 229 | population, 230 | k, 231 | rng_state=self._world_rng_state, 232 | ) 233 | return s 234 | 235 | def _init_worker(self): 236 | worker_info = get_worker_info() 237 | if worker_info is None: 238 | num_workers_per_rank = 1 239 | worker_rank = 0 240 | else: 241 | num_workers_per_rank = worker_info.num_workers 242 | worker_rank = worker_info.id 243 | assert (len(self._files) % (self._world_size * num_workers_per_rank) == 0) 244 | self._logger.init_for_worker(worker_rank) 245 | return worker_rank, num_workers_per_rank 246 | 247 | def _init_rng_states(self, worker_rank, num_workers_per_rank): 248 | orig_rng_state = random.getstate() 249 | 250 | random.seed(self._base_seed + self._epoch) 251 | self._world_rng_state = random.getstate() 252 | 253 | random.seed(self._base_seed + 254 | (self._epoch * self._world_size + self._rank) * 255 | num_workers_per_rank + worker_rank) 256 | self._worker_rng_state = random.getstate() 257 | 258 | random.setstate(orig_rng_state) 259 | 260 | def __iter__(self): 261 | self._epoch += 1 262 | 263 | worker_rank, num_workers_per_rank = self._init_worker() 264 | self._init_rng_states(worker_rank, num_workers_per_rank) 265 | 266 | files = self._world_identical_sample(self._files, k=len(self._files)) 267 | self._logger.to('node').warning('epoch = {}'.format(self._epoch)) 268 | self._logger.to('worker').info( 269 | '\n'.join(['files('] + [' {}'.format(f) for f in files] + [')'])) 270 | 271 | rank_files = files[self._rank::self._world_size] 272 | worker_files = rank_files[worker_rank::num_workers_per_rank] 273 | self._logger.to('worker').info( 274 | '\n'.join(['worker_files('] + [' {}'.format(f) for f in worker_files] + 275 | [')'])) 276 | sb = ShuffleBuffer( 277 | worker_files, 278 | self._num_samples_per_file * len(worker_files), 279 | lambda b: self._decode_record_batch(b), 280 | self._shuffle_buffer_size, 281 | self._shuffle_buffer_warmup_factor, 282 | self._logger, 283 | self._worker_rng_state, 284 | ) 285 | for sample in iter(sb): 286 | yield self._transform(sample) 287 | -------------------------------------------------------------------------------- /lddl/torch/log.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import logging 26 | import os 27 | import pathlib 28 | 29 | 30 | def _get_logger_name(node_rank, local_rank=None, worker_rank=None): 31 | if local_rank is None and worker_rank is None: 32 | return 'node-{}'.format(node_rank) 33 | elif worker_rank is None: 34 | return 'node-{}_local-{}'.format(node_rank, local_rank) 35 | else: 36 | return 'node-{}_local-{}_worker-{}'.format(node_rank, local_rank, 37 | worker_rank) 38 | 39 | 40 | class DummyLogger: 41 | 42 | def debug(self, msg, *args, **kwargs): 43 | pass 44 | 45 | def info(self, msg, *args, **kwargs): 46 | pass 47 | 48 | def warning(self, msg, *args, **kwargs): 49 | pass 50 | 51 | def error(self, msg, *args, **kwargs): 52 | pass 53 | 54 | def critical(self, msg, *args, **kwargs): 55 | pass 56 | 57 | def log(self, msg, *args, **kwargs): 58 | pass 59 | 60 | def exception(self, msg, *args, **kwargs): 61 | pass 62 | 63 | 64 | class DatasetLogger: 65 | 66 | def __init__( 67 | self, 68 | log_dir=None, 69 | node_rank=0, 70 | local_rank=0, 71 | log_level=logging.INFO, 72 | ): 73 | self._log_dir = log_dir 74 | self._node_rank = node_rank 75 | self._local_rank = local_rank 76 | self._worker_rank = None 77 | self._log_level = log_level 78 | 79 | if log_dir is not None: 80 | pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) 81 | # Create node level logger. 82 | if local_rank == 0: 83 | self._create_logger(_get_logger_name(node_rank)) 84 | # Create local_rank level logger. 
85 | self._create_logger(_get_logger_name(node_rank, local_rank=local_rank)) 86 | 87 | def _create_logger(self, name): 88 | logger = logging.getLogger(name) 89 | fmt = logging.Formatter( 90 | 'LDDL - %(asctime)s - %(filename)s:%(lineno)d:%(funcName)s - %(name)s ' 91 | '- %(levelname)s : %(message)s') 92 | stream_handler = logging.StreamHandler() 93 | stream_handler.setFormatter(fmt) 94 | logger.addHandler(stream_handler) 95 | if self._log_dir is not None: 96 | path = os.path.join(self._log_dir, '{}.txt'.format(name)) 97 | file_handler = logging.FileHandler(path) 98 | file_handler.setFormatter(fmt) 99 | logger.addHandler(file_handler) 100 | logger.setLevel(self._log_level) 101 | return logger 102 | 103 | def init_for_worker(self, worker_rank): 104 | if self._worker_rank is None: 105 | self._worker_rank = worker_rank 106 | self._create_logger( 107 | _get_logger_name( 108 | self._node_rank, 109 | local_rank=self._local_rank, 110 | worker_rank=worker_rank, 111 | )) 112 | 113 | def to(self, which): 114 | assert which in {'node', 'rank', 'worker'} 115 | if which == 'node': 116 | if (self._local_rank == 0 and 117 | (self._worker_rank is None or self._worker_rank == 0)): 118 | return logging.getLogger(_get_logger_name(self._node_rank)) 119 | else: 120 | return DummyLogger() 121 | elif which == 'rank': 122 | if self._worker_rank is None or self._worker_rank == 0: 123 | return logging.getLogger( 124 | _get_logger_name(self._node_rank, local_rank=self._local_rank)) 125 | else: 126 | return DummyLogger() 127 | else: # which == 'worker' 128 | return logging.getLogger( 129 | _get_logger_name( 130 | self._node_rank, 131 | local_rank=self._local_rank, 132 | worker_rank=self._worker_rank, 133 | )) 134 | -------------------------------------------------------------------------------- /lddl/torch/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import torch 26 | 27 | 28 | def barrier(): 29 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 30 | torch.distributed.barrier() 31 | 32 | 33 | def get_rank(): 34 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 35 | rank = torch.distributed.get_rank() 36 | else: 37 | rank = 0 38 | return rank 39 | 40 | 41 | def get_world_size(): 42 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 43 | world_size = torch.distributed.get_world_size() 44 | else: 45 | world_size = 1 46 | return world_size 47 | 48 | 49 | def get_nproc_per_node(local_rank): 50 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 51 | max_local_rank = torch.tensor( 52 | local_rank, 53 | device='cuda' if torch.distributed.get_backend() == 'nccl' else 'cpu', 54 | ) 55 | torch.distributed.all_reduce( 56 | max_local_rank, 57 | op=torch.distributed.ReduceOp.MAX, 58 | ) 59 | nproc_per_node = max_local_rank.item() + 1 60 | else: 61 | nproc_per_node = 1 62 | return nproc_per_node 63 | 64 | 65 | def get_num_nodes(local_rank=None, nproc_per_node=None): 66 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 67 | if nproc_per_node is None: 68 | assert local_rank is not None 69 | nproc_per_node = get_nproc_per_node(local_rank) 70 | num_nodes = get_world_size() // nproc_per_node 71 | else: 72 | num_nodes = 1 73 | return num_nodes 74 | 75 | 76 | def get_node_rank(local_rank=None, nproc_per_node=None): 77 | """ This assume the training processes are launched via 78 | torch.distributed.launch.py. Therefore, the ordering scheme of 79 | rank -> (node_rank, local_rank) mapping is: 80 | 0 -> (0, 0) 81 | 1 -> (0, 1) 82 | ... 83 | nproc_per_node -> (1, 0) 84 | nproc_per_node+1 -> (1, 1) 85 | ... 86 | """ 87 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 88 | if nproc_per_node is None: 89 | assert local_rank is not None 90 | nproc_per_node = get_nproc_per_node(local_rank) 91 | node_rank = get_rank() // nproc_per_node 92 | else: 93 | node_rank = 0 94 | return node_rank 95 | -------------------------------------------------------------------------------- /lddl/torch_mp/__init__.py: -------------------------------------------------------------------------------- 1 | from .bert import get_bert_pretrain_data_loader 2 | -------------------------------------------------------------------------------- /lddl/torch_mp/dataloader.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 
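The helpers in lddl/torch/utils.py above rely on the contiguous rank ordering produced by torch.distributed.launch, where consecutive global ranks fill up one node before moving to the next. A minimal sketch of that mapping (pure arithmetic, no process group needed), assuming a hypothetical 2-node job with 4 processes per node:

# Illustrative only: the rank -> (node_rank, local_rank) ordering that
# get_node_rank() and get_num_nodes() assume, for world_size=8, nproc_per_node=4.
nproc_per_node = 4
for rank in range(8):
    node_rank = rank // nproc_per_node   # what get_node_rank() computes
    local_rank = rank % nproc_per_node
    print(rank, '->', (node_rank, local_rank))
# 0 -> (0, 0), 1 -> (0, 1), ..., 4 -> (1, 0), 5 -> (1, 1), ...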
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import random 26 | import torch 27 | 28 | from lddl.random import choices 29 | from .datasets import ParquetDataset 30 | from .utils import get_rank 31 | 32 | 33 | class Binned: 34 | 35 | def __init__(self, 36 | dataloaders, 37 | base_seed=12345, 38 | start_epoch=0, 39 | global_batch_size=64, 40 | logger=None): 41 | self._dataloaders = dataloaders 42 | 43 | self._base_seed = base_seed 44 | self._epoch = start_epoch - 1 45 | 46 | self._logger = logger 47 | 48 | self._world_rng_state = None 49 | self.current_iteration = 0 50 | self.global_batch_size = global_batch_size 51 | self.bin_id = None 52 | self.global_batch = [] 53 | 54 | def _init_rng_states(self): 55 | orig_rng_state = random.getstate() 56 | 57 | random.seed(self._base_seed + self._epoch) 58 | self._world_rng_state = random.getstate() 59 | 60 | random.setstate(orig_rng_state) 61 | 62 | def _init_iter(self): 63 | self._init_rng_states() 64 | num_samples_remaining = [len(dl.dataset) for dl in self._dataloaders] 65 | dataiters = [iter(dl) for dl in self._dataloaders] 66 | return num_samples_remaining, dataiters 67 | 68 | def __len__(self): 69 | return sum((len(dl) for dl in self._dataloaders)) 70 | 71 | def _get_batch_size(self, batch): 72 | raise NotImplementedError('Binned is an abstract class!') 73 | 74 | def _choices(self, population, weights=None, cum_weights=None, k=1): 75 | c, self._world_rng_state = choices( 76 | population, 77 | weights=weights, 78 | cum_weights=cum_weights, 79 | k=k, 80 | rng_state=self._world_rng_state, 81 | ) 82 | return c 83 | 84 | def get_samples_seen_datasets(self, samples_seen, batch_size): 85 | num_samples_remaining, dataiters = self._init_iter() 86 | # Skip epochs that have already been seen 87 | self._epoch = samples_seen // sum(num_samples_remaining) 88 | samples_seen = samples_seen % sum(num_samples_remaining) 89 | self._init_rng_states() 90 | if samples_seen > 0: 91 | bins_samples_seen = [0] * len(self._dataloaders) 92 | while samples_seen > 0: 93 | bin_id = self._choices( 94 | list(range(len(self._dataloaders))), 95 | weights=num_samples_remaining, 96 | k=1, 97 | )[0] 98 | num_samples_remaining[bin_id] -= self.global_batch_size 99 | bins_samples_seen[bin_id] += self.global_batch_size 100 | samples_seen -= self.global_batch_size 101 | return bins_samples_seen, self._epoch 102 | 103 | def set_next(self): 104 | # At the end of the epoch setting Global_batch to None to let iterator know we are done 105 | if max(self.num_samples_remaining) <= self.global_batch_size: 106 | self.global_batch = None 107 | else: 108 | if self.global_batch == []: 109 | self.bin_id = self._choices( 110 | list(range(len(self.dataiters))), 111 | weights=self.num_samples_remaining, 112 | k=1, 113 | )[0] 114 | self.global_batch = next(self.dataiters[self.bin_id]) 115 | self.num_samples_remaining[self.bin_id] -= self.global_batch_size 116 | self.current_iteration += 1 117 | 118 | def get_seqlen(self): 119 | return self.global_batch[0]['text'].shape[1] 120 | 121 | def __next__(self): 122 | if 
self.global_batch is None: 123 | raise StopIteration 124 | else: 125 | sample = self.global_batch.pop() 126 | self.set_next() 127 | return sample 128 | 129 | def __iter__(self): 130 | self._epoch += 1 131 | self.num_samples_remaining, self.dataiters = self._init_iter() 132 | self.set_next() 133 | return self 134 | 135 | 136 | class DataLoader(torch.utils.data.DataLoader): 137 | 138 | def __len__(self): 139 | if isinstance(self.dataset, ParquetDataset): 140 | num_workers_per_rank = max(self.num_workers, 1) 141 | num_files_per_worker = self.dataset.num_files_per_rank // num_workers_per_rank 142 | num_samples_per_worker = self.dataset.num_samples_per_file * num_files_per_worker 143 | num_batches_per_worker = ( 144 | (num_samples_per_worker - 1) // self.batch_size + 1) 145 | return num_batches_per_worker * num_workers_per_rank 146 | else: 147 | return super().__len__() -------------------------------------------------------------------------------- /lddl/torch_mp/datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE.
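The Binned wrapper above draws each global batch from one sequence-length bin, picking the bin with probability proportional to how many samples it still has so that all bins drain at roughly the same rate across the epoch. A rough sketch of that selection step using the standard library (the real code routes through lddl.random.choices so every rank makes the identical pick):

import random

# Hypothetical remaining-sample counts for three sequence-length bins.
num_samples_remaining = [5000, 3000, 500]
rng = random.Random(12345)  # the real seed is derived from base_seed + epoch

bin_id = rng.choices(range(len(num_samples_remaining)),
                     weights=num_samples_remaining,
                     k=1)[0]
# Bins with more data left are proportionally more likely to be chosen,
# so no bin is exhausted long before the others.
print('next global batch comes from bin', bin_id)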
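The patched DataLoader.__len__ above exists because, with an iterable ParquetDataset, every worker rounds its own share of samples up to a whole number of batches. A worked example with made-up numbers (2 workers per rank, 3 files per worker, 100 samples per file, batch size 32):

# Illustrative arithmetic only; none of these numbers come from the repo.
num_workers_per_rank = 2
num_files_per_worker = 3
num_samples_per_file = 100
batch_size = 32

num_samples_per_worker = num_samples_per_file * num_files_per_worker     # 300
num_batches_per_worker = (num_samples_per_worker - 1) // batch_size + 1  # ceil(300 / 32) = 10
print(num_batches_per_worker * num_workers_per_rank)                     # 20 batches per rank

per_rank_total = num_samples_per_worker * num_workers_per_rank           # 600
print(-(-per_rank_total // batch_size))  # 19: a plain ceil over the rank total misses a partial batch

The gap is at most num_workers_per_rank - 1 partial batches, which is exactly the discrepancy described in the ParquetDataset.__len__ docstring below.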
23 | # 24 | 25 | import itertools 26 | import json 27 | import logging 28 | import numpy as np 29 | import os 30 | import pathlib 31 | import pyarrow.parquet as pq 32 | import random 33 | import torch 34 | import warnings 35 | 36 | from torch.utils.data import IterableDataset 37 | from torch.utils.data import get_worker_info 38 | 39 | from lddl.types import File 40 | from lddl.utils import get_num_samples_of_parquet 41 | from lddl.random import randrange, shuffle, sample 42 | from .utils import ( 43 | get_rank, 44 | get_world_size, 45 | get_nproc_per_node, 46 | get_num_nodes, 47 | get_node_rank, 48 | get_dp_size, 49 | ) 50 | 51 | 52 | class ShuffleBuffer: 53 | 54 | def __init__(self, files, max_num_samples_to_yield, decode_record_batch, size, 55 | warmup_factor, logger, rng_state, samples_seen): 56 | num_samples_wasted = (sum( 57 | (f.num_samples for f in files)) - max_num_samples_to_yield) 58 | assert 0 <= num_samples_wasted <= len(files) 59 | 60 | self._files = files 61 | self._max_num_samples_to_yield = max_num_samples_to_yield 62 | self._decode_record_batch = decode_record_batch 63 | self._size = size 64 | self._warmup_factor = warmup_factor 65 | self._logger = logger 66 | self._rng_state = rng_state 67 | self.samples_seen = samples_seen 68 | 69 | @property 70 | def num_samples(self): 71 | return sum((f.num_samples for f in self._files)) 72 | 73 | def _randrange(self, stop): 74 | n, self._rng_state = randrange(stop, rng_state=self._rng_state) 75 | return n 76 | 77 | def _shuffle(self, x): 78 | self._rng_state = shuffle(x, rng_state=self._rng_state) 79 | 80 | def __iter__(self): 81 | buffer = [] 82 | num_samples_to_yield = min( 83 | self._max_num_samples_to_yield, 84 | sum((f.num_samples for f in self._files)) - self.samples_seen, 85 | ) 86 | remaining_num_samples = num_samples_to_yield 87 | for f in self._files: 88 | self._logger.to('worker').info('Reading {}'.format(f.path)) 89 | if self.samples_seen > 0: 90 | len_par = f.num_samples 91 | # Skip entire parquet if possible 92 | if len_par < self.samples_seen: 93 | self.samples_seen -= len_par 94 | continue 95 | pq_table = pq.read_table(f.path) 96 | if self.samples_seen > 0: 97 | pq_table = pq_table.slice(self.samples_seen) 98 | self.samples_seen = 0 99 | 100 | for b in pq_table.to_batches(): 101 | for sample in self._decode_record_batch(b): 102 | if remaining_num_samples <= 0: 103 | return 104 | if (len(buffer) 105 | >= min(self._size, 106 | (num_samples_to_yield - remaining_num_samples + 1) * 107 | self._warmup_factor)): 108 | replace_idx = self._randrange(len(buffer)) 109 | yield buffer[replace_idx] 110 | buffer[replace_idx] = sample 111 | remaining_num_samples -= 1 112 | else: 113 | buffer.append(sample) 114 | self._shuffle(buffer) 115 | for sample in buffer: 116 | if remaining_num_samples <= 0: 117 | return 118 | yield sample 119 | remaining_num_samples -= 1 120 | 121 | 122 | class ParquetDataset(IterableDataset): 123 | 124 | def __init__( 125 | self, 126 | file_paths, 127 | samples_seen=0, 128 | transform=lambda x: x, 129 | local_rank=0, 130 | dp_rank=0, 131 | shuffle_buffer_size=16384, 132 | shuffle_buffer_warmup_factor=16, 133 | base_seed=12345, 134 | logger=None, 135 | start_epoch=0, 136 | ): 137 | super().__init__() 138 | self._transform = transform 139 | self._local_rank = local_rank 140 | self.dp_rank = dp_rank 141 | self._shuffle_buffer_size = shuffle_buffer_size 142 | self._shuffle_buffer_warmup_factor = shuffle_buffer_warmup_factor 143 | self._base_seed = base_seed 144 | 145 | self._rank = get_rank() 146 | 
self._world_size = get_world_size() 147 | self._nproc_per_node = get_nproc_per_node(local_rank) 148 | self._num_dp_groups = get_dp_size(dp_rank) 149 | self._num_nodes = get_num_nodes(nproc_per_node=self._nproc_per_node) 150 | self._node_rank = get_node_rank(nproc_per_node=self._nproc_per_node) 151 | 152 | self._epoch = start_epoch - 1 153 | 154 | self._logger = logger 155 | self.samples_seen = samples_seen 156 | 157 | assert len(file_paths) % self._num_nodes == 0 158 | assert len(file_paths) % self._world_size == 0 159 | self._files = self._get_files(file_paths) 160 | max_num_samples_per_file = max((f.num_samples for f in self._files)) 161 | min_num_samples_per_file = min((f.num_samples for f in self._files)) 162 | assert min_num_samples_per_file + 1 == max_num_samples_per_file 163 | self._num_samples_per_file = min_num_samples_per_file 164 | total_num_samples = sum((f.num_samples for f in self._files)) 165 | num_samples_lost = (total_num_samples - 166 | self._num_samples_per_file * len(self._files)) 167 | self._logger.to('node').warning('lost {}/{}={}% samples in total'.format( 168 | num_samples_lost, 169 | total_num_samples, 170 | num_samples_lost / total_num_samples * 100, 171 | )) 172 | 173 | self._world_rng_state = None 174 | self._worker_rng_state = None 175 | 176 | def _get_files(self, file_paths): 177 | all_files_num_samples = torch.zeros((len(file_paths),), dtype=torch.long) 178 | if self._world_size > 1 and torch.distributed.get_backend() == 'nccl': 179 | all_files_num_samples = all_files_num_samples.to('cuda') 180 | # Figure out how many samples in each file. 181 | num_samples_cache = {} # Map dirname to the dict of {basename: num_samples} 182 | 183 | for idx in range(self._rank, len(file_paths), self._world_size): 184 | fp = file_paths[idx] 185 | dn = os.path.dirname(fp) 186 | bn = os.path.basename(fp) 187 | # Load the num_samples cache file if it exists. 188 | if dn not in num_samples_cache: 189 | nsfp = os.path.join(dn, '.num_samples.json') 190 | try: 191 | with open(nsfp, 'r') as nsf: 192 | num_samples_cache[dn] = json.load(nsf) 193 | except Exception as e: 194 | self._logger.to('rank').warning('failed to load {}: {}'.format( 195 | nsfp, e)) 196 | # Mark that the num_samples cache file doesn't exist for this 197 | # directory. 198 | num_samples_cache[dn] = None 199 | if num_samples_cache[dn] is not None and bn in num_samples_cache[dn]: 200 | all_files_num_samples[idx] = num_samples_cache[dn][bn] 201 | else: 202 | # Find out num_samples by loading the parquet table. 203 | all_files_num_samples[idx] = get_num_samples_of_parquet(fp) 204 | if self._world_size > 1: 205 | # Sync. accross all ranks. 206 | torch.distributed.all_reduce( 207 | all_files_num_samples, 208 | op=torch.distributed.ReduceOp.SUM, 209 | ) 210 | all_files_num_samples = all_files_num_samples.tolist() 211 | return [File(fp, ns) for fp, ns in zip(file_paths, all_files_num_samples)] 212 | 213 | def __len__(self): 214 | """ This function only returns how many samples per rank will be yielded 215 | by this dataset. 216 | 217 | Note that, len(dataloader), where dataloader is a PyTorch DataLoader 218 | wrapping this dataset, does NOT return the accurate number of batches. This 219 | is because, when (num_samples_per_file * num_files_per_worker) is not 220 | divisible by batch_size, each worker is going to generate a partial batch 221 | at the very end. 
222 | 223 | However, PyTorch DataLoader's __len__ only divides the number returned from 224 | this function by batch_size, which would be smaller than the actual number 225 | of batches by at most (num_workers - 1). 226 | 227 | We need to patch PyTorch DataLoader's __len__ for it to report the correct 228 | number of batches. 229 | """ 230 | return (self._num_samples_per_file * len(self._files) // 231 | self._num_dp_groups) - self.samples_seen 232 | 233 | @property 234 | def num_samples_per_file(self): 235 | return self._num_samples_per_file 236 | 237 | @property 238 | def num_files_per_rank(self): 239 | return len(self._files) // self._num_dp_groups 240 | 241 | def _decode_record_batch(self, b): 242 | raise NotImplementedError('ParquetDataset is an abstract/interface class!') 243 | 244 | def _world_identical_sample(self, population, k, counts=None): 245 | s, self._world_rng_state = sample( 246 | population, 247 | k, 248 | rng_state=self._world_rng_state, 249 | ) 250 | return s 251 | 252 | def _init_worker(self): 253 | worker_info = get_worker_info() 254 | if worker_info is None: 255 | num_workers_per_rank = 1 256 | worker_rank = 0 257 | else: 258 | num_workers_per_rank = worker_info.num_workers 259 | worker_rank = worker_info.id 260 | assert (len(self._files) % (self._world_size * num_workers_per_rank) == 0) 261 | self._logger.init_for_worker(worker_rank) 262 | return worker_rank, num_workers_per_rank 263 | 264 | def _init_rng_states(self, worker_rank, num_workers_per_rank): 265 | orig_rng_state = random.getstate() 266 | 267 | random.seed(self._base_seed + self._epoch) 268 | self._world_rng_state = random.getstate() 269 | 270 | worker_seed_num = self._base_seed + ( 271 | self._epoch * self._world_size + 272 | self.dp_rank) * num_workers_per_rank + worker_rank 273 | random.seed(worker_seed_num) 274 | self._worker_rng_state = random.getstate() 275 | 276 | random.setstate(orig_rng_state) 277 | 278 | def __iter__(self): 279 | self._epoch += 1 280 | 281 | worker_rank, num_workers_per_rank = self._init_worker() 282 | self._init_rng_states(worker_rank, num_workers_per_rank) 283 | 284 | files = self._world_identical_sample(self._files, k=len(self._files)) 285 | self._logger.to('node').warning('epoch = {}'.format(self._epoch)) 286 | 287 | rank_files = files[self.dp_rank::self._num_dp_groups] 288 | worker_files = rank_files[worker_rank::num_workers_per_rank] 289 | 290 | self.sb = ShuffleBuffer(worker_files, 291 | self._num_samples_per_file * len(worker_files), 292 | lambda b: self._decode_record_batch(b), 293 | self._shuffle_buffer_size, 294 | self._shuffle_buffer_warmup_factor, self._logger, 295 | self._worker_rng_state, self.samples_seen) 296 | for sample in iter(self.sb): 297 | sample = self._transform(sample) 298 | yield sample 299 | self.samples_seen = 0 300 | -------------------------------------------------------------------------------- /lddl/torch_mp/log.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
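ShuffleBuffer above is an approximate streaming shuffle: samples fill a buffer whose effective size grows with the warmup factor, and once the buffer is full each incoming sample evicts a randomly chosen resident, which is what gets yielded. A stripped-down sketch of the core idea (not the LDDL class itself; it ignores the warmup schedule and the samples_seen resume logic):

import random

def streaming_shuffle(samples, buffer_size, rng):
    """Approximately shuffle an iterable using a fixed-size buffer."""
    buffer = []
    for s in samples:
        if len(buffer) >= buffer_size:
            i = rng.randrange(len(buffer))
            yield buffer[i]   # emit a random resident...
            buffer[i] = s     # ...and let the new sample take its slot
        else:
            buffer.append(s)
    rng.shuffle(buffer)       # flush the remainder in shuffled order
    yield from buffer

print(list(streaming_shuffle(range(10), buffer_size=4, rng=random.Random(0))))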
4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import logging 26 | import os 27 | import pathlib 28 | 29 | 30 | def _get_logger_name(node_rank, local_rank=None, worker_rank=None): 31 | if local_rank is None and worker_rank is None: 32 | return 'node-{}'.format(node_rank) 33 | elif worker_rank is None: 34 | return 'node-{}_local-{}'.format(node_rank, local_rank) 35 | else: 36 | return 'node-{}_local-{}_worker-{}'.format(node_rank, local_rank, 37 | worker_rank) 38 | 39 | 40 | class DummyLogger: 41 | 42 | def debug(self, msg, *args, **kwargs): 43 | pass 44 | 45 | def info(self, msg, *args, **kwargs): 46 | pass 47 | 48 | def warning(self, msg, *args, **kwargs): 49 | pass 50 | 51 | def error(self, msg, *args, **kwargs): 52 | pass 53 | 54 | def critical(self, msg, *args, **kwargs): 55 | pass 56 | 57 | def log(self, msg, *args, **kwargs): 58 | pass 59 | 60 | def exception(self, msg, *args, **kwargs): 61 | pass 62 | 63 | 64 | class DatasetLogger: 65 | 66 | def __init__( 67 | self, 68 | log_dir=None, 69 | node_rank=0, 70 | local_rank=0, 71 | log_level=logging.INFO, 72 | ): 73 | self._log_dir = log_dir 74 | self._node_rank = node_rank 75 | self._local_rank = local_rank 76 | self._worker_rank = None 77 | self._log_level = log_level 78 | 79 | if log_dir is not None: 80 | pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) 81 | # Create node level logger. 82 | if local_rank == 0: 83 | self._create_logger(_get_logger_name(node_rank)) 84 | # Create local_rank level logger. 
85 | self._create_logger(_get_logger_name(node_rank, local_rank=local_rank)) 86 | 87 | def _create_logger(self, name): 88 | logger = logging.getLogger(name) 89 | fmt = logging.Formatter( 90 | 'LDDL - %(asctime)s - %(filename)s:%(lineno)d:%(funcName)s - %(name)s ' 91 | '- %(levelname)s : %(message)s') 92 | stream_handler = logging.StreamHandler() 93 | stream_handler.setFormatter(fmt) 94 | logger.addHandler(stream_handler) 95 | if self._log_dir is not None: 96 | path = os.path.join(self._log_dir, '{}.txt'.format(name)) 97 | file_handler = logging.FileHandler(path) 98 | file_handler.setFormatter(fmt) 99 | logger.addHandler(file_handler) 100 | logger.setLevel(self._log_level) 101 | return logger 102 | 103 | def init_for_worker(self, worker_rank): 104 | if self._worker_rank is None: 105 | self._worker_rank = worker_rank 106 | self._create_logger( 107 | _get_logger_name( 108 | self._node_rank, 109 | local_rank=self._local_rank, 110 | worker_rank=worker_rank, 111 | )) 112 | 113 | def to(self, which): 114 | assert which in {'node', 'rank', 'worker'} 115 | if which == 'node': 116 | if (self._local_rank == 0 and 117 | (self._worker_rank is None or self._worker_rank == 0)): 118 | return logging.getLogger(_get_logger_name(self._node_rank)) 119 | else: 120 | return DummyLogger() 121 | elif which == 'rank': 122 | if self._worker_rank is None or self._worker_rank == 0: 123 | return logging.getLogger( 124 | _get_logger_name(self._node_rank, local_rank=self._local_rank)) 125 | else: 126 | return DummyLogger() 127 | else: # which == 'worker' 128 | return logging.getLogger( 129 | _get_logger_name( 130 | self._node_rank, 131 | local_rank=self._local_rank, 132 | worker_rank=self._worker_rank, 133 | )) 134 | -------------------------------------------------------------------------------- /lddl/torch_mp/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | import torch 26 | 27 | 28 | def barrier(): 29 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 30 | torch.distributed.barrier() 31 | 32 | 33 | def get_dp_size(dp_rank): 34 | """ 35 | This helper function will return how many data parallel groups we have in our 36 | system. 
37 | """ 38 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 39 | max_dp_rank = torch.tensor( 40 | dp_rank, 41 | device='cuda' if torch.distributed.get_backend() == 'nccl' else 'cpu', 42 | ) 43 | torch.distributed.all_reduce( 44 | max_dp_rank, 45 | op=torch.distributed.ReduceOp.MAX, 46 | ) 47 | dp_size = max_dp_rank.item() + 1 48 | else: 49 | dp_size = 1 50 | return dp_size 51 | 52 | 53 | def get_rank(): 54 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 55 | rank = torch.distributed.get_rank() 56 | else: 57 | rank = 0 58 | return rank 59 | 60 | 61 | def get_world_size(): 62 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 63 | world_size = torch.distributed.get_world_size() 64 | else: 65 | world_size = 1 66 | return world_size 67 | 68 | 69 | def get_nproc_per_node(local_rank): 70 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 71 | max_local_rank = torch.tensor( 72 | local_rank, 73 | device='cuda' if torch.distributed.get_backend() == 'nccl' else 'cpu', 74 | ) 75 | torch.distributed.all_reduce( 76 | max_local_rank, 77 | op=torch.distributed.ReduceOp.MAX, 78 | ) 79 | nproc_per_node = max_local_rank.item() + 1 80 | else: 81 | nproc_per_node = 1 82 | return nproc_per_node 83 | 84 | 85 | def get_num_nodes(local_rank=None, nproc_per_node=None): 86 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 87 | if nproc_per_node is None: 88 | assert local_rank is not None 89 | nproc_per_node = get_nproc_per_node(local_rank) 90 | num_nodes = get_world_size() // nproc_per_node 91 | else: 92 | num_nodes = 1 93 | return num_nodes 94 | 95 | 96 | def get_node_rank(local_rank=None, nproc_per_node=None): 97 | """ This assume the training processes are launched via 98 | torch.distributed.launch.py. Therefore, the ordering scheme of 99 | rank -> (node_rank, local_rank) mapping is: 100 | 0 -> (0, 0) 101 | 1 -> (0, 1) 102 | ... 103 | nproc_per_node -> (1, 0) 104 | nproc_per_node+1 -> (1, 1) 105 | ... 106 | """ 107 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 108 | if nproc_per_node is None: 109 | assert local_rank is not None 110 | nproc_per_node = get_nproc_per_node(local_rank) 111 | node_rank = get_rank() // nproc_per_node 112 | else: 113 | node_rank = 0 114 | return node_rank 115 | -------------------------------------------------------------------------------- /lddl/types.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 
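get_dp_size above mirrors get_nproc_per_node: every process contributes its dp_rank, a MAX all-reduce finds the largest index, and adding 1 turns that index into a count, assuming dp_rank identifies the data-parallel group a process belongs to. A toy illustration of what the all-reduce computes, without a process group:

# Hypothetical 8-rank job in which pairs of model-parallel ranks share a dp_rank.
dp_rank_of_each_process = [0, 0, 1, 1, 2, 2, 3, 3]
dp_size = max(dp_rank_of_each_process) + 1   # 4, the value get_dp_size() would return
print(dp_size)

ParquetDataset then shards files across these dp_size data-parallel groups (see num_files_per_rank above) instead of across the full world size.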
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | 26 | class File: 27 | 28 | def __init__(self, path, num_samples): 29 | self.path = path 30 | self.num_samples = num_samples 31 | 32 | def __repr__(self): 33 | return 'File(path={}, num_samples={})'.format(self.path, self.num_samples) 34 | -------------------------------------------------------------------------------- /lddl/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 
23 | # 24 | 25 | import os 26 | import io 27 | import numpy as np 28 | import pathlib 29 | import pyarrow.parquet as pq 30 | 31 | 32 | def mkdir(d): 33 | pathlib.Path(d).mkdir(parents=True, exist_ok=True) 34 | 35 | 36 | def expand_outdir_and_mkdir(outdir): 37 | outdir = os.path.abspath(os.path.expanduser(outdir)) 38 | mkdir(outdir) 39 | return outdir 40 | 41 | 42 | def get_all_files_paths_under(root): 43 | return ( 44 | os.path.join(r, f) for r, subdirs, files in os.walk(root) for f in files) 45 | 46 | 47 | def get_all_parquets_under(path): 48 | return sorted([ 49 | p for p in get_all_files_paths_under(path) 50 | if '.parquet' in os.path.splitext(p)[1] 51 | ]) 52 | 53 | 54 | def get_all_bin_ids(file_paths): 55 | 56 | def is_binned_parquet(p): 57 | return '_' in os.path.splitext(p)[1] 58 | 59 | def get_bin_id(p): 60 | return int(os.path.splitext(p)[1].split('_')[-1]) 61 | 62 | bin_ids = list( 63 | sorted(set((get_bin_id(p) for p in file_paths if is_binned_parquet(p))))) 64 | for a, e in zip(bin_ids, range(len(bin_ids))): 65 | if a != e: 66 | raise ValueError('bin id must be contiguous integers starting from 0!') 67 | return bin_ids 68 | 69 | 70 | def get_file_paths_for_bin_id(file_paths, bin_id): 71 | return [ 72 | p for p in file_paths 73 | if '.parquet_{}'.format(bin_id) == os.path.splitext(p)[1] 74 | ] 75 | 76 | 77 | def get_num_samples_of_parquet(path): 78 | return len(pq.read_table(path)) 79 | 80 | 81 | def attach_bool_arg(parser, flag_name, default=False, help_str=None): 82 | attr_name = flag_name.replace('-', '_') 83 | parser.add_argument( 84 | '--{}'.format(flag_name), 85 | dest=attr_name, 86 | action='store_true', 87 | help=flag_name.replace('-', ' ') if help_str is None else help_str, 88 | ) 89 | parser.add_argument( 90 | '--no-{}'.format(flag_name), 91 | dest=attr_name, 92 | action='store_false', 93 | help=flag_name.replace('-', ' ') if help_str is None else help_str, 94 | ) 95 | parser.set_defaults(**{attr_name: default}) 96 | 97 | 98 | def serialize_np_array(a): 99 | memfile = io.BytesIO() 100 | np.save(memfile, a) 101 | memfile.seek(0) 102 | return memfile.read() 103 | 104 | 105 | def deserialize_np_array(b): 106 | memfile = io.BytesIO() 107 | memfile.write(b) 108 | memfile.seek(0) 109 | return np.load(memfile) 110 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 3 | # Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # SPDX-License-Identifier: MIT 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a 7 | # copy of this software and associated documentation files (the "Software"), 8 | # to deal in the Software without restriction, including without limitation 9 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 10 | # and/or sell copies of the Software, and to permit persons to whom the 11 | # Software is furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in 14 | # all copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
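The bin helpers in lddl/utils.py above key everything off the file extension: a binned shard ends in .parquet_<bin_id>, get_all_bin_ids collects the distinct ids and insists they are contiguous from 0, and get_file_paths_for_bin_id filters out a single bin. A small sketch with made-up file names:

from lddl.utils import get_all_bin_ids, get_file_paths_for_bin_id

# Hypothetical output layout of a binned preprocessing run.
paths = ['out/part-0.parquet_0', 'out/part-0.parquet_1',
         'out/part-1.parquet_0', 'out/part-1.parquet_1']
print(get_all_bin_ids(paths))               # [0, 1]
print(get_file_paths_for_bin_id(paths, 1))  # ['out/part-0.parquet_1', 'out/part-1.parquet_1']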
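attach_bool_arg above registers a paired --flag/--no-flag switch that shares a single destination, with the default applied via set_defaults. A quick illustration with a hypothetical flag name:

import argparse
from lddl.utils import attach_bool_arg

parser = argparse.ArgumentParser()
attach_bool_arg(parser, 'shuffle-files', default=True)  # 'shuffle-files' is made up

print(parser.parse_args([]).shuffle_files)                      # True (the default)
print(parser.parse_args(['--no-shuffle-files']).shuffle_files)  # False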
IN NO EVENT SHALL 19 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 22 | # DEALINGS IN THE SOFTWARE. 23 | # 24 | 25 | from setuptools import setup, find_packages 26 | import pathlib 27 | 28 | here = pathlib.Path(__file__).parent.resolve() 29 | 30 | long_description = (here / 'README.md').read_text(encoding='utf-8') 31 | 32 | setup( 33 | name='lddl', 34 | version='0.1.0', 35 | description= 36 | 'Language Datasets and Data Loaders for NVIDIA Deep Learning Examples', 37 | long_description=long_description, 38 | long_description_content_type='text/markdown', 39 | url='github.com/NVIDIA/DeepLearningExamples/tools/lddl', 40 | author='Shang Wang', 41 | author_email='shangw@nvidia.com', 42 | classifiers=[ 43 | 'Development Status :: 3 - Alpha', 44 | 'Programming Language :: Python :: 3 :: Only', 45 | ], 46 | packages=find_packages(), 47 | python_requires='>=3.6', 48 | install_requires=[ 49 | 'dask[complete]==2021.7.1', 50 | 'distributed==2021.7.1', 51 | 'dask-mpi==2021.11.0', 52 | 'bokeh==2.4.3', 53 | 'pyarrow==14.0.1', 54 | 'mpi4py==3.1.3', 55 | 'transformers==4.16.2', 56 | 'wikiextractor==3.0.6', 57 | 'news-please @ git+https://github.com/fhamborg/news-please.git@3b7d9fdfeb148ef73f393bb2f2557e6bd878a09f', 58 | 'cchardet==2.1.7', 59 | 'awscli>=1.22.55', 60 | 'wikiextractor @ git+https://github.com/attardi/wikiextractor.git@v3.0.6', 61 | 'gdown==4.5.3', 62 | ], 63 | entry_points={ 64 | 'console_scripts': [ 65 | 'download_wikipedia=lddl.download.wikipedia:console_script', 66 | 'download_books=lddl.download.books:console_script', 67 | 'download_common_crawl=lddl.download.common_crawl:console_script', 68 | 'download_open_webtext=lddl.download.openwebtext:console_script', 69 | 'preprocess_bert_pretrain=lddl.dask.bert.pretrain:console_script', 70 | 'preprocess_bart_pretrain=lddl.dask.bart.pretrain:console_script', 71 | 'balance_dask_output=lddl.dask.load_balance:console_script', 72 | 'generate_num_samples_cache=lddl.dask.load_balance:generate_num_samples_cache', 73 | ], 74 | }, 75 | ) 76 | --------------------------------------------------------------------------------
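One operational note tying setup.py back to lddl/torch_mp/datasets.py above: ParquetDataset._get_files avoids opening every shard by first looking for a .num_samples.json file in each parquet directory, a plain JSON object mapping shard basenames to row counts, and only falls back to get_num_samples_of_parquet when the cache is missing. Judging by its name, the generate_num_samples_cache console script registered above is the intended producer of that file (an assumption; only the entry point is visible here). A hypothetical cache written by hand would look like:

import json

# File names, counts and the directory below are made up for illustration.
cache = {'part-0.parquet_0': 16384, 'part-1.parquet_0': 16384}
with open('/path/to/parquets/.num_samples.json', 'w') as f:
    json.dump(cache, f)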