├── distill_bloom ├── dataset │ ├── megatron │ │ ├── mpu │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── commons.py │ │ │ │ ├── test_data.py │ │ │ │ ├── test_initialize.py │ │ │ │ ├── test_cross_entropy.py │ │ │ │ ├── test_random.py │ │ │ │ └── test_layers.py │ │ │ ├── utils.py │ │ │ ├── __init__.py │ │ │ ├── data.py │ │ │ ├── cross_entropy.py │ │ │ ├── mappings.py │ │ │ ├── random.py │ │ │ ├── initialize.py │ │ │ └── layers.py │ │ ├── Makefile │ │ └── helpers.cpp │ ├── dataloader.py │ ├── get_dataset.py │ ├── utils.py │ └── gpt_dataset.py ├── __init__.py ├── init_wrapper.py ├── arguments │ └── logging.py └── teacher-inference-script.py ├── Makefile ├── test_dataset.py ├── README.md └── teacher-inference-script.py /distill_bloom/dataset/megatron/mpu/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check_dirs := distill_bloom 2 | 3 | style: 4 | black --preview $(check_dirs) 5 | isort $(check_dirs) -------------------------------------------------------------------------------- /distill_bloom/__init__.py: -------------------------------------------------------------------------------- 1 | # Dataset imports 2 | from .arguments.arguments import parse_args 3 | from .dataset.get_dataset import build_train_val_test_dataset 4 | from .dataset.dataloader import DistributedDataset, DistributedDataLoader 5 | 6 | # Arguments import 7 | from .init_wrapper import DeepSpeedInitWrapper, print_rank0 -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/Makefile: -------------------------------------------------------------------------------- 1 | 2 | PYTHON3CONFIG := $(shell command -v python3-config 2> /dev/null) 3 | 4 | ifndef PYTHON3CONFIG 5 | $(error "python3-config is not available. Please install it. 
It may be in a python-dev or another package") 6 | endif 7 | 8 | CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color 9 | CPPFLAGS += $(shell python3 -m pybind11 --includes) 10 | LIBNAME = helpers 11 | LIBEXT = $(shell python3-config --extension-suffix) 12 | 13 | default: $(LIBNAME)$(LIBEXT) 14 | 15 | %$(LIBEXT): %.cpp 16 | $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ -------------------------------------------------------------------------------- /test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import deepspeed 3 | 4 | import torch.distributed as dist 5 | 6 | from distill_bloom import build_train_val_test_dataset 7 | from distill_bloom import parse_args 8 | 9 | 10 | args = parse_args() 11 | 12 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 13 | world_size = int(os.getenv("WORLD_SIZE", "1")) 14 | 15 | deepspeed.init_distributed("nccl") 16 | 17 | rank = dist.get_rank() 18 | 19 | if rank == 0: 20 | train_ds, val, test = build_train_val_test_dataset(args) 21 | print(f"The total dataset includes: {len(train_ds)} subsets") 22 | for i, train_data in enumerate(train_ds): 23 | print(f"Train dataset: {i} has {len(train_data)} samples") 24 | for data in train_data: 25 | print("Text: ", data['text']) 26 | break -------------------------------------------------------------------------------- /distill_bloom/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DistributedDataset(torch.utils.data.Dataset): 4 | r""" 5 | Wrapper for torch.utils.data.Dataset to make it distributed. 6 | 7 | Args: 8 | dataset (torch.utils.data.Dataset): Dataset to be distributed. 9 | rank (int): Rank of the current process. 10 | world_size (int): Number of processes in the distributed group. 11 | """ 12 | def __init__(self, dataset, rank, world_size): 13 | self.dataset = dataset 14 | 15 | self.current_dataset_index = 0 16 | self.current_dataset = dataset[self.current_dataset_index] 17 | 18 | self.rank = rank 19 | self.world_size = world_size 20 | 21 | def _update_dataset(self): 22 | self.current_dataset_index += 1 23 | if self.current_dataset_index >= len(self.dataset): 24 | self.current_dataset_index = 0 25 | self.current_dataset = self.dataset[self.current_dataset_index] 26 | 27 | def __getitem__(self, index): 28 | r""" 29 | Loads a unique sample from the dataset. 30 | First tries to load the sample from the current dataset. 31 | If the current dataset is exhausted, it moves to the next dataset. 32 | """ 33 | try: 34 | item = self.current_dataset[(index*self.world_size) + self.rank] 35 | except IndexError: 36 | self._update_dataset() 37 | item = self.current_dataset[(index*self.world_size) + self.rank] 38 | return item 39 | 40 | def __len__(self): 41 | r""" 42 | Returns the length of the dataset. It corresponds to the total 43 | lenght of all the datasets in the dataset list. 44 | """ 45 | total_length = list(map(lambda x: len(x), self.dataset)) 46 | return sum(total_length) 47 | 48 | class DistributedDataLoader(torch.utils.data.DataLoader): 49 | r""" 50 | Wrapper around torch.utils.data.DataLoader to support distributed training. 
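    Args:
        dataset: List of datasets, wrapped into a ``DistributedDataset`` so that
            each rank reads a disjoint, strided slice of the samples.
        rank (int): Rank of the current process.
        world_size (int): Number of processes in the distributed group.
        **kwargs: Extra keyword arguments forwarded to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``).

    Example (mirrors the call in ``teacher-inference-script.py``)::

        data_loader = DistributedDataLoader(train_ds, rank=rank,
                                            world_size=world_size, batch_size=1)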
51 | """ 52 | def __init__(self, dataset, rank, world_size, **kwargs): 53 | self.dataset = DistributedDataset(dataset, rank, world_size) 54 | super().__init__(self.dataset, **kwargs) 55 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | 20 | def ensure_divisibility(numerator, denominator): 21 | """Ensure that numerator is divisible by the denominator.""" 22 | assert numerator % denominator == 0, "{} is not divisible by {}".format( 23 | numerator, denominator 24 | ) 25 | 26 | 27 | def divide(numerator, denominator): 28 | """Ensure that numerator is divisible by the denominator and return 29 | the division value.""" 30 | ensure_divisibility(numerator, denominator) 31 | return numerator // denominator 32 | 33 | 34 | def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): 35 | """Split a tensor along its last dimension. 36 | Arguments: 37 | tensor: input tensor. 38 | num_partitions: number of partitions to split the tensor 39 | contiguous_split_chunks: If True, make each chunk contiguous 40 | in memory. 41 | """ 42 | # Get the size and dimension. 43 | last_dim = tensor.dim() - 1 44 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 45 | # Split. 46 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 47 | # Note: torch.split does not create contiguous tensors by default. 48 | if contiguous_split_chunks: 49 | return tuple(chunk.contiguous() for chunk in tensor_list) 50 | 51 | return tensor_list 52 | 53 | 54 | class VocabUtility: 55 | """Split the vocabulary into `world_size` chunks amd return the 56 | first and last index of the vocabulary belonging to the `rank` 57 | partition: Note that indecies in [fist, last)""" 58 | 59 | @staticmethod 60 | def vocab_range_from_per_partition_vocab_size( 61 | per_partition_vocab_size, rank, world_size 62 | ): 63 | index_f = rank * per_partition_vocab_size 64 | index_l = index_f + per_partition_vocab_size 65 | return index_f, index_l 66 | 67 | @staticmethod 68 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): 69 | per_partition_vocab_size = divide(global_vocab_size, world_size) 70 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 71 | per_partition_vocab_size, rank, world_size 72 | ) 73 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/tests/commons.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import os 18 | import random 19 | 20 | import mpu 21 | import numpy 22 | import torch 23 | 24 | 25 | class IdentityLayer(torch.nn.Module): 26 | def __init__(self, size, scale=1.0): 27 | super(IdentityLayer, self).__init__() 28 | self.weight = torch.nn.Parameter(scale * torch.randn(size)) 29 | 30 | def forward(self): 31 | return self.weight 32 | 33 | 34 | def set_random_seed(seed): 35 | """Set random seed for reproducability.""" 36 | random.seed(seed) 37 | numpy.random.seed(seed) 38 | torch.manual_seed(seed) 39 | mpu.model_parallel_cuda_manual_seed(seed) 40 | 41 | 42 | def initialize_distributed(backend="nccl"): 43 | """Initialize torch.distributed.""" 44 | # Get local rank in case it is provided. 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument( 47 | "--local_rank", 48 | type=int, 49 | default=None, 50 | help="local rank passed from distributed launcher", 51 | ) 52 | args = parser.parse_args() 53 | local_rank = args.local_rank 54 | 55 | # Get rank and world size. 56 | rank = int(os.getenv("RANK", "0")) 57 | world_size = int(os.getenv("WORLD_SIZE", "1")) 58 | 59 | print( 60 | "> initializing torch.distributed with local rank: {}, " 61 | "rank: {}, world size: {}".format(local_rank, rank, world_size) 62 | ) 63 | 64 | # Set the device id. 65 | device = rank % torch.cuda.device_count() 66 | if local_rank is not None: 67 | device = local_rank 68 | torch.cuda.set_device(device) 69 | 70 | # Call the init process. 71 | init_method = "tcp://" 72 | master_ip = os.getenv("MASTER_ADDR", "localhost") 73 | master_port = os.getenv("MASTER_PORT", "6000") 74 | init_method += master_ip + ":" + master_port 75 | torch.distributed.init_process_group( 76 | backend=backend, world_size=world_size, rank=rank, init_method=init_method 77 | ) 78 | 79 | 80 | def print_separator(message): 81 | torch.distributed.barrier() 82 | filler_len = (78 - len(message)) // 2 83 | filler = "-" * filler_len 84 | string = "\n" + filler + " {} ".format(message) + filler 85 | if torch.distributed.get_rank() == 0: 86 | print(string, flush=True) 87 | torch.distributed.barrier() 88 | -------------------------------------------------------------------------------- /distill_bloom/dataset/get_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | from .utils import build_dataset_group 4 | 5 | 6 | def build_train_val_test_dataset(args): 7 | r""" 8 | This function wraps all the dataset building functions from megatron. 
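    Args:
        args: Parsed command-line arguments. Either ``data_path`` or the
            ``(train|valid|test)_weighted_split_paths`` options must be provided,
            together with ``train_iters`` (or ``train_samples``), ``global_batch_size``,
            ``eval_interval``, ``eval_iters``, ``seq_length`` and ``seed``.

    Returns:
        A ``(train_ds, valid_ds, test_ds)`` tuple. With the weighted-split-paths
        options, each requested split is a list of dataset groups; splits that were
        not requested stay ``None``.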
9 | 10 | """ 11 | if args.train_samples: 12 | train_samples = args.train_samples 13 | else: 14 | train_samples = args.train_iters * args.global_batch_size 15 | eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters 16 | test_iters = args.eval_iters 17 | train_val_test_num_samples = [ 18 | train_samples, 19 | eval_iters * args.global_batch_size, 20 | test_iters * args.global_batch_size, 21 | ] 22 | 23 | train_ds, valid_ds, test_ds = None, None, None 24 | 25 | print("> building train, validation, and test datasets for GPT ...") 26 | # Option 1 of data loading using --data-path 27 | 28 | if args.data_path: 29 | train_ds, valid_ds, test_ds = build_train_valid_test_datasets( 30 | data_prefix=args.data_path, 31 | data_impl=args.data_impl, 32 | splits_string=args.split, 33 | train_valid_test_num_samples=train_val_test_num_samples, 34 | seq_length=args.seq_length, 35 | seed=args.seed, 36 | skip_warmup=(not args.mmap_warmup), 37 | ) 38 | # Option 2 of data loading using --(train|valid|test)-weighted-split-paths 39 | elif args.train_weighted_split_paths: 40 | assigned_train_valid_test = [] 41 | if args.train_weighted_split_paths is not None: 42 | train_ds = [] 43 | assigned_train_valid_test.append("train") 44 | if args.valid_weighted_split_paths is not None: 45 | valid_ds = [] 46 | assigned_train_valid_test.append("valid") 47 | if args.test_weighted_split_paths is not None: 48 | test_ds = [] 49 | assigned_train_valid_test.append("test") 50 | 51 | for s in assigned_train_valid_test: 52 | data_groups = zip( 53 | eval(f"args.{s}_weighted_split_paths"), 54 | eval(f"args.{s}_weighted_split_weights"), 55 | eval(f"args.{s}_weighted_split_splits"), 56 | eval(f"args.{s}_weighted_split_names"), 57 | ) 58 | for paths, weights, splits, name in data_groups: 59 | d = build_dataset_group( 60 | name, 61 | paths, 62 | weights, 63 | splits, 64 | args.data_impl, 65 | train_val_test_num_samples, 66 | args.seq_length, 67 | args.seed, 68 | (not args.mmap_warmup), 69 | train_valid_test=s, 70 | ) 71 | eval(f"{s}_ds").append(d) 72 | else: 73 | raise NotImplementedError("No dataloading argument passed") 74 | 75 | print("> finished creating GPT datasets ...") 76 | return train_ds, valid_ds, test_ds 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # distill-bloom-deepspeed 2 | 3 | Teacher - student distillation using DeepSpeed. 4 | This repository is partially based from [BLOOM DeepSpeed repository](https://github.com/huggingface/transformers-bloom-inference/tree/main/bloom-inference-scripts). We follow the same setup as the repository above 5 | 6 | ## Setup 7 | 8 | ```pip install transformers huggingface_hub==0.9.0``` 9 | ```pip install deepspeed>=0.7.3``` 10 | 11 | ## Install teacher checkpoints 12 | 13 | Install the DeepSpeed teacher checkpoints from [here]() to perform fast loading as described [here](https://github.com/huggingface/transformers-bloom-inference/tree/main/bloom-inference-scripts#run). Download them locally and follow the instructions below to run the training. 14 | 15 | ### Teacher inference 16 | 17 | We highly recommend to install the teacher and student weights locally, therefore to not have to re-install the weights again. 18 | After installing the teacher weights, run the following command to perform inference on the teacher model. 
19 | 20 | ``` 21 | deepspeed --num_gpus NUM_GPUS teacher-inference-script.py --teacher-model-path[PATH_TO_BLOOM] --train-weighted-split-paths-path [PATH_TO_DATA] --train-iters [TRAIN_ITERS] --global-batch-size [GLOBAL_BATCH_SIZE] --eval-iters [EVAL_ITERS] --seq-length [SEQ_LEN] 22 | ``` 23 | 24 | #### Processing the dataset 25 | 26 | ##### Download the dataset 27 | 28 | Here we use the dataset used to train the BLOOM model, that is available on Jean Zay. First, download the dataset that is available on a S3 bucket. The raw dataset consist of 1.6TB of numpy arrays. If you want to train our your custom dataset, please build your own dataloader structure. 29 | 30 | ##### Get the splits 31 | 32 | For now we recommend to get the splits by running the following command. 33 | 34 | ``` 35 | export DATAPATH=[PATH_TO_DATASET] 36 | git clone https://github.com/bigscience-workshop/bigscience.git 37 | cd bigscience/ 38 | python data/catalogue/load_ratios_meg_ds_format.py --dataset-ratios-path ./data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json --split train --output-meg-ds-ratio-file $DATAPATH/train.txt 39 | python data/catalogue/load_ratios_meg_ds_format.py --dataset-ratios-path ./data/catalogue/training_dataset_ratios_merged_nigercongo_v3.json --split val --output-meg-ds-ratio-file $DATAPATH/val.txt 40 | ``` 41 | 42 | ##### Test the data loading script 43 | 44 | ``` 45 | deepspeed --num_gpus 8 test.py --train-weighted-split-paths-path $DATAPATH/train.txt --train-iters 200 --global-batch-size 64 --eval-iters 20 --seq-length 2048 46 | ``` 47 | 48 | This test should output the lenght of the combined dataset as well as the total number of epochs. 49 | 50 | #### Training 51 | 52 | One the dataset is ready, we can start training the student model. 53 | 54 | 55 | ## Roadmap 56 | 57 | - [ ] Add support for teacher inference 58 | - [ ] Add support for student inference 59 | - [ ] Add support for communicating teacher logits to student node 60 | - [ ] Add support for student training (Ds-Zero) 61 | - [ ] Add support for distributed training (`hostfile`) 62 | - [x] Add support for loading Jean-Zay dataset 63 | - [ ] Add support for loading custom dataset 64 | 65 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/tests/test_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
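# This test exercises mpu.data.broadcast_data: rank 0 of each tensor-model-parallel
# group builds a dictionary of integer tensors and every other rank in the group
# must receive an exact copy.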
15 | 16 | import functools 17 | import operator 18 | import sys 19 | 20 | import mpu 21 | import torch 22 | from commons import initialize_distributed, print_separator 23 | from mpu import data as data_utils 24 | 25 | sys.path.append("../..") 26 | 27 | 28 | def test_broadcast_data(tensor_model_parallel_size): 29 | if torch.distributed.get_rank() == 0: 30 | print( 31 | "> testing broadcast_data with model parallel size {} ...".format( 32 | tensor_model_parallel_size 33 | ) 34 | ) 35 | 36 | mpu.initialize_model_parallel(tensor_model_parallel_size) 37 | torch.manual_seed(1234 + mpu.get_data_parallel_rank()) 38 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 39 | 40 | key_size_t = { 41 | "key1": [7, 11], 42 | "key2": [8, 2, 1], 43 | "key3": [13], 44 | "key4": [5, 1, 2], 45 | "key5": [5, 12], 46 | } 47 | keys = list(key_size_t.keys()) 48 | 49 | data = {} 50 | data_t = {} 51 | for key in key_size_t: 52 | data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) 53 | data_t[key] = data[key].clone() 54 | data["keyX"] = torch.FloatTensor(size=(5,)).random_(0, 1000) 55 | data_t["keyX"] = data["keyX"].clone() 56 | if mpu.get_tensor_model_parallel_rank() != 0: 57 | data = None 58 | 59 | data_utils._check_data_types(keys, data_t, torch.int64) 60 | key_size, key_numel, total_numel = data_utils._build_key_size_numel_dictionaries( 61 | keys, data 62 | ) 63 | for key in keys: 64 | assert key_size[key] == key_size_t[key] 65 | total_numel_t = 0 66 | for key in keys: 67 | target_size = functools.reduce(operator.mul, key_size_t[key], 1) 68 | assert key_numel[key] == target_size 69 | total_numel_t += target_size 70 | assert total_numel == total_numel_t 71 | 72 | data_b = data_utils.broadcast_data(keys, data, torch.int64) 73 | for key in keys: 74 | tensor = data_t[key].cuda() 75 | assert data_b[key].sub(tensor).abs().max() == 0 76 | 77 | # Reset groups 78 | mpu.destroy_tensor_model_parallel() 79 | 80 | torch.distributed.barrier() 81 | if torch.distributed.get_rank() == 0: 82 | print(">> passed the test :-)") 83 | 84 | 85 | if __name__ == "__main__": 86 | initialize_distributed() 87 | world_size = torch.distributed.get_world_size() 88 | 89 | tensor_model_parallel_size = 1 90 | while tensor_model_parallel_size <= world_size: 91 | print_separator("test test broadcast data") 92 | test_broadcast_data(tensor_model_parallel_size) 93 | tensor_model_parallel_size *= 2 94 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Model parallel utility interface.""" 17 | 18 | from .cross_entropy import vocab_parallel_cross_entropy 19 | from .data import broadcast_data 20 | from .initialize import (destroy_model_parallel, get_data_parallel_group, 21 | get_data_parallel_rank, get_data_parallel_world_size, 22 | get_embedding_group, get_model_parallel_group, 23 | get_model_parallel_rank, 24 | get_model_parallel_world_size, 25 | get_pipeline_model_parallel_first_rank, 26 | get_pipeline_model_parallel_group, 27 | get_pipeline_model_parallel_last_rank, 28 | get_pipeline_model_parallel_next_rank, 29 | get_pipeline_model_parallel_prev_rank, 30 | get_pipeline_model_parallel_rank, 31 | get_pipeline_model_parallel_world_size, 32 | get_tensor_model_parallel_group, 33 | get_tensor_model_parallel_rank, 34 | get_tensor_model_parallel_src_rank, 35 | get_tensor_model_parallel_world_size, 36 | get_virtual_pipeline_model_parallel_rank, 37 | initialize_model_parallel, is_pipeline_first_stage, 38 | is_pipeline_last_stage, is_unitialized, 39 | model_parallel_is_initialized, 40 | set_pipeline_model_parallel_rank, 41 | set_pipeline_model_parallel_world_size, 42 | set_tensor_model_parallel_rank, 43 | set_tensor_model_parallel_world_size, 44 | set_virtual_pipeline_model_parallel_rank) 45 | from .mappings import (copy_to_tensor_model_parallel_region, 46 | gather_from_tensor_model_parallel_region, 47 | reduce_from_tensor_model_parallel_region, 48 | scatter_to_tensor_model_parallel_region) 49 | from .utils import divide, split_tensor_along_last_dim 50 | 51 | # from .layers import ColumnParallelLinear 52 | # from .layers import RowParallelLinear 53 | # from .layers import VocabParallelEmbedding 54 | # from .layers import (set_tensor_model_parallel_attributes, 55 | # set_defaults_if_not_set_tensor_model_parallel_attributes, 56 | # copy_tensor_model_parallel_attributes) 57 | 58 | 59 | # from .random import checkpoint 60 | # from .random import get_cuda_rng_tracker 61 | # from .random import init_checkpointed_activations_memory_buffer 62 | # from .random import model_parallel_cuda_manual_seed 63 | # from .random import reset_checkpointed_activations_memory_buffer 64 | # from .random import gather_split_1d_tensor 65 | # from .random import split_tensor_into_1d_equal_chunks 66 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/tests/test_initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import sys 17 | 18 | import mpu 19 | import torch 20 | from commons import initialize_distributed, print_separator 21 | 22 | sys.path.append("../..") 23 | 24 | 25 | def test_initialize_model_parallel(tensor_model_parallel_size): 26 | if torch.distributed.get_rank() == 0: 27 | print( 28 | "> testing initialize_model_parallel with size {} ...".format( 29 | tensor_model_parallel_size 30 | ) 31 | ) 32 | tensor_model_parallel_size_ = min( 33 | tensor_model_parallel_size, torch.distributed.get_world_size() 34 | ) 35 | assert not mpu.model_parallel_is_initialized() 36 | mpu.initialize_model_parallel(tensor_model_parallel_size_) 37 | assert mpu.model_parallel_is_initialized() 38 | 39 | # Checks. 40 | def check(group, world_size, rank): 41 | assert world_size == torch.distributed.get_world_size(group=group) 42 | assert rank == torch.distributed.get_rank(group=group) 43 | 44 | # Model parallel. 45 | world_size = tensor_model_parallel_size_ 46 | rank = torch.distributed.get_rank() % tensor_model_parallel_size_ 47 | assert world_size == mpu.get_tensor_model_parallel_world_size() 48 | assert rank == mpu.get_tensor_model_parallel_rank() 49 | check(mpu.get_tensor_model_parallel_group(), world_size, rank) 50 | 51 | # Data parallel. 52 | world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_ 53 | rank = torch.distributed.get_rank() // tensor_model_parallel_size 54 | assert world_size == mpu.get_data_parallel_world_size() 55 | assert rank == mpu.get_data_parallel_rank() 56 | check(mpu.get_data_parallel_group(), world_size, rank) 57 | 58 | # Reset groups 59 | mpu.destroy_model_parallel() 60 | 61 | torch.distributed.barrier() 62 | if torch.distributed.get_rank() == 0: 63 | print(">> passed the test :-)") 64 | 65 | 66 | def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_): 67 | if torch.distributed.get_rank() == 0: 68 | print( 69 | "> testing get_tensor_model_parallel_src_rank with size {} ...".format( 70 | tensor_model_parallel_size_ 71 | ) 72 | ) 73 | tensor_model_parallel_size = min( 74 | tensor_model_parallel_size_, torch.distributed.get_world_size() 75 | ) 76 | assert not mpu.model_parallel_is_initialized() 77 | mpu.initialize_model_parallel(tensor_model_parallel_size) 78 | assert mpu.model_parallel_is_initialized() 79 | 80 | # Checks 81 | src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank() 82 | assert mpu.get_tensor_model_parallel_src_rank() == src_rank 83 | 84 | # Reset groups 85 | mpu.destroy_model_parallel() 86 | 87 | torch.distributed.barrier() 88 | if torch.distributed.get_rank() == 0: 89 | print(">> passed the test :-)") 90 | 91 | 92 | if __name__ == "__main__": 93 | initialize_distributed() 94 | world_size = torch.distributed.get_world_size() 95 | tensor_model_parallel_size = 1 96 | while tensor_model_parallel_size <= world_size: 97 | print_separator("test initialize model parallel") 98 | test_initialize_model_parallel(tensor_model_parallel_size) 99 | print_separator("test model parallel source rank") 100 | test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size) 101 | tensor_model_parallel_size *= 2 102 | -------------------------------------------------------------------------------- /teacher-inference-script.py: -------------------------------------------------------------------------------- 1 | # usage: 2 | # deepspeed --num_gpus 8 teacher-inference-script.py --name bigscience/bloom 3 | # 4 | # to run benchmarks: 5 | # deepspeed --num_gpus 8 teacher-inference-script.py --name bigscience/bloom 
--benchmark 6 | # 7 | 8 | 9 | # This is going to improve, but at the moment, the process is a bit cumbersome - we first use 10 | # 1. use Deepspeed-ZeRO to instantiate the model on GPUs, w/o loading the checkpoints, 11 | # 2. free the allocated storage 12 | # 3. start Deepspeed-Inference and only now load the checkpoint 13 | # 4. run generate 14 | # Done. 15 | # 16 | import gc 17 | import glob 18 | import io 19 | import json 20 | import math 21 | import os 22 | import time 23 | from pathlib import Path 24 | 25 | import deepspeed 26 | import torch 27 | import torch.distributed as dist 28 | from huggingface_hub import snapshot_download 29 | from transformers import AutoConfig, AutoModelForCausalLM 30 | from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock 31 | from transformers.utils import is_offline_mode 32 | 33 | from distill_bloom import build_train_val_test_dataset, DistributedDataset, DistributedDataLoader 34 | from distill_bloom import parse_args, DeepSpeedInitWrapper, print_rank0 35 | 36 | # Arguments 37 | 38 | args = parse_args() 39 | 40 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 41 | world_size = int(os.getenv("WORLD_SIZE", "1")) # World size is the number of GPUs 42 | 43 | deepspeed.init_distributed("nccl") 44 | 45 | rank = dist.get_rank() 46 | 47 | ## Check the args 48 | 49 | assert (world_size % args.global_batch_size) == 0, "micro_batch_size must be divisible by num_gpus" 50 | 51 | ds_init = DeepSpeedInitWrapper(args) 52 | ds_init.init_deepspeed_inference() 53 | model_name = ds_init.repo_root 54 | 55 | # Wait that all processes have correctly initiliazed DeepSpeed 56 | dist.barrier() 57 | 58 | 59 | print_rank0(f"*** Loading the model {model_name}") 60 | config = AutoConfig.from_pretrained(model_name) 61 | 62 | # Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load 63 | with deepspeed.OnDevice(dtype=ds_init.dtype, device="meta"): 64 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) 65 | model = model.eval() 66 | 67 | # checkpoints_json=None 68 | model = deepspeed.init_inference( 69 | model, 70 | mp_size=world_size, 71 | base_dir=ds_init.repo_root, 72 | dtype=getattr(torch, ds_init.infer_dtype), 73 | checkpoint=ds_init.checkpoints_json, 74 | **ds_init.kwargs, 75 | ) 76 | model = model.module 77 | 78 | # Dataset building - each rank will have a different shard of the dataset 79 | train_ds, _, _ = build_train_val_test_dataset(args) 80 | data_loader = DistributedDataLoader( 81 | train_ds, 82 | rank=rank, 83 | world_size=world_size, 84 | batch_size=1 85 | ) 86 | dist.barrier() 87 | 88 | def generate_logits(inputs): 89 | """returns a list of zipped inputs, outputs and number of new tokens""" 90 | inputs = inputs.to(torch.cuda.current_device()) 91 | outputs = model(inputs).logits 92 | 93 | return outputs 94 | 95 | def generate_logits_batch(data_loader): 96 | for batch in data_loader: 97 | inputs = batch['text'] 98 | # as a sanity check, I used to check that inputs are different for each rank 99 | inputs = inputs.to(torch.cuda.current_device()) 100 | outputs = model(inputs).logits 101 | 102 | # Here we leave the return statement for debugging purposes 103 | # But in practice at this point we would probably call 104 | # dist.barrier() and send the logits together with the input 105 | # to the student model 106 | return outputs 107 | 108 | 109 | # warmup is a must if measuring speed as it's when all the optimizations are performed 110 | # e.g. 
on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs 111 | print_rank0(f"*** Running generate warmup") 112 | # _ = generate_logits(inputs) 113 | _ = generate_logits_batch(data_loader) 114 | 115 | print_rank0(f"*** Running generate") 116 | t_generate_start = time.time() 117 | # generated = generate_logits(inputs) 118 | generated = generate_logits_batch(data_loader) 119 | print(rank, generated.shape) 120 | t_generate_span = time.time() - t_generate_start -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/tests/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import sys 18 | 19 | import mpu 20 | import torch 21 | import torch.nn.functional as F 22 | from commons import (IdentityLayer, initialize_distributed, print_separator, 23 | set_random_seed) 24 | from mpu.cross_entropy import vocab_parallel_cross_entropy 25 | 26 | sys.path.append("../..") 27 | 28 | 29 | def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): 30 | set_random_seed(seed) 31 | identity = IdentityLayer( 32 | (batch_size, seq_length, vocab_size), scale=logits_scale 33 | ).cuda() 34 | logits = identity() 35 | target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size) 36 | loss = ( 37 | F.cross_entropy( 38 | logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none" 39 | ) 40 | .view_as(target) 41 | .mean() 42 | ) 43 | loss.backward() 44 | return loss, identity.weight.grad 45 | 46 | 47 | def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): 48 | set_random_seed(seed) 49 | identity = IdentityLayer( 50 | (batch_size, seq_length, vocab_size), scale=logits_scale 51 | ).cuda() 52 | logits = identity() 53 | logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits) 54 | target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size) 55 | loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() 56 | loss.backward() 57 | return loss, identity.weight.grad 58 | 59 | 60 | def test_cross_entropy(tensor_model_parallel_size): 61 | if torch.distributed.get_rank() == 0: 62 | print( 63 | "> testing cross entropy with model parallel size {} ...".format( 64 | tensor_model_parallel_size 65 | ) 66 | ) 67 | 68 | mpu.initialize_model_parallel(tensor_model_parallel_size) 69 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 70 | 71 | batch_size = 13 72 | seq_length = 17 73 | vocab_size_per_partition = 11 74 | logits_scale = 1000.0 75 | vocab_size = vocab_size_per_partition * tensor_model_parallel_size 76 | seed = 1234 77 | 78 | loss_torch, grad_torch = torch_cross_entropy( 79 | batch_size, seq_length, vocab_size, logits_scale, seed 80 | ) 81 | loss_mpu, 
grad_mpu = mpu_cross_entropy( 82 | batch_size, seq_length, vocab_size, logits_scale, seed 83 | ) 84 | 85 | error = loss_torch.sub_(loss_mpu).abs().max() 86 | print( 87 | " max error in loss on global rank {}: {}".format( 88 | torch.distributed.get_rank(), error 89 | ) 90 | ) 91 | assert error < 1.0e-6 92 | 93 | error = grad_torch.sub_(grad_mpu).abs().max() 94 | print( 95 | " max error in grad on global rank {}: {}".format( 96 | torch.distributed.get_rank(), error 97 | ) 98 | ) 99 | assert error < 1.0e-6 100 | 101 | # Reset groups 102 | mpu.destroy_tensor_model_parallel() 103 | 104 | torch.distributed.barrier() 105 | if torch.distributed.get_rank() == 0: 106 | print(">> passed the test :-)") 107 | 108 | 109 | if __name__ == "__main__": 110 | initialize_distributed() 111 | world_size = torch.distributed.get_world_size() 112 | 113 | tensor_model_parallel_size = 1 114 | while tensor_model_parallel_size <= world_size: 115 | print_separator("test cross entropy") 116 | test_cross_entropy(tensor_model_parallel_size) 117 | tensor_model_parallel_size *= 2 118 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import (get_tensor_model_parallel_group, 19 | get_tensor_model_parallel_rank, 20 | get_tensor_model_parallel_src_rank) 21 | 22 | _MAX_DATA_DIM = 5 23 | 24 | 25 | def _check_data_types(keys, data, target_dtype): 26 | """Check that all the keys have the same target data type.""" 27 | for key in keys: 28 | assert ( 29 | data[key].dtype == target_dtype 30 | ), "{} has data type {} which is different than {}".format( 31 | key, data[key].dtype, target_dtype 32 | ) 33 | 34 | 35 | def _build_key_size_numel_dictionaries(keys, data): 36 | """Build the size on rank 0 and broadcast.""" 37 | max_dim = _MAX_DATA_DIM 38 | sizes = [0 for _ in range(max_dim) for _ in keys] 39 | 40 | # Pack the sizes on rank zero. 41 | if get_tensor_model_parallel_rank() == 0: 42 | offset = 0 43 | for key in keys: 44 | assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" 45 | size = data[key].size() 46 | for i, s in enumerate(size): 47 | sizes[i + offset] = s 48 | offset += max_dim 49 | 50 | # Move to GPU and broadcast. 51 | sizes_cuda = torch.cuda.LongTensor(sizes) 52 | torch.distributed.broadcast( 53 | sizes_cuda, 54 | get_tensor_model_parallel_src_rank(), 55 | group=get_tensor_model_parallel_group(), 56 | ) 57 | 58 | # Move back to cpu and unpack. 
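    # Each key owns a fixed block of max_dim entries in `sizes`; unused entries stay
    # zero, so the loop below reads a key's dimensions until it hits the first zero.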
59 | sizes_cpu = sizes_cuda.cpu() 60 | key_size = {} 61 | key_numel = {} 62 | total_numel = 0 63 | offset = 0 64 | for key in keys: 65 | i = 0 66 | size = [] 67 | numel = 1 68 | while sizes_cpu[offset + i] > 0: 69 | this_size = sizes_cpu[offset + i] 70 | size.append(this_size) 71 | numel *= this_size 72 | i += 1 73 | key_size[key] = size 74 | key_numel[key] = numel 75 | total_numel += numel 76 | offset += max_dim 77 | 78 | return key_size, key_numel, total_numel 79 | 80 | 81 | def broadcast_data(keys, data, datatype): 82 | """Broadcast data from rank zero of each model parallel group to the 83 | members of the same model parallel group. 84 | 85 | Arguments: 86 | keys: list of keys in the data disctionary to be broadcasted 87 | data: data dictionary of string keys and cpu tensor values. 88 | datatype: torch data type of all tensors in data associated 89 | with keys. 90 | """ 91 | # Build (key, size) and (key, number of elements) dictionaries along 92 | # with the total number of elements on all ranks. 93 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) 94 | 95 | # Pack on rank zero. 96 | if get_tensor_model_parallel_rank() == 0: 97 | # Check that all keys have the same data type. 98 | _check_data_types(keys, data, datatype) 99 | # Flatten the data associated with the keys 100 | flatten_data = torch.cat( 101 | [data[key].contiguous().view(-1) for key in keys], dim=0 102 | ).cuda() 103 | else: 104 | flatten_data = torch.empty( 105 | total_numel, device=torch.cuda.current_device(), dtype=datatype 106 | ) 107 | 108 | # Broadcast 109 | torch.distributed.broadcast( 110 | flatten_data, 111 | get_tensor_model_parallel_src_rank(), 112 | group=get_tensor_model_parallel_group(), 113 | ) 114 | 115 | # Unpack 116 | output = {} 117 | offset = 0 118 | for key in keys: 119 | size = key_size[key] 120 | numel = key_numel[key] 121 | output[key] = flatten_data.narrow(0, offset, numel).view(size) 122 | offset += numel 123 | 124 | return output 125 | -------------------------------------------------------------------------------- /distill_bloom/init_wrapper.py: -------------------------------------------------------------------------------- 1 | import io, json 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock 8 | 9 | class DeepSpeedInitWrapper(object): 10 | r""" 11 | This is a wrapper around DeepSpeed inference / training script initialisation. 12 | It is used to initialise the DeepSpeed engine and load the necessary variables 13 | to correctly load the model and run inference. 14 | 15 | Args: 16 | args (:obj:`argparse.Namespace`): 17 | The parsed arguments from the command line. This contains all the arguments for 18 | training and inference. The `model_path` argument is used to load the model from 19 | the specified path. 20 | """ 21 | def __init__(self, args): 22 | r""" 23 | We need to store the rank of the current process since `write_checkpoints` is 24 | called only on rank 0. 
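        Besides the rank, the wrapper keeps the teacher checkpoint path
        (``repo_root``), the name of the checkpoints index file
        (``checkpoints_json``) and the dtype used for DeepSpeed-Inference
        (``infer_dtype``).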
25 | """ 26 | self.rank = dist.get_rank() 27 | self.checkpoints_json = "checkpoints.json" 28 | self.repo_root = args.teacher_model_path 29 | self.infer_dtype = "float16" 30 | 31 | def init_deepspeed_inference(self): 32 | r""" 33 | This function is a wrapper around the first lines that are called inside 34 | https://github.com/huggingface/transformers-bloom-inference/blob/main/bloom-inference-scripts/bloom-ds-inference.py 35 | """ 36 | tp_presharded_models = [ 37 | "microsoft/bloom-deepspeed-inference-int8", 38 | "microsoft/bloom-deepspeed-inference-fp16", 39 | ] 40 | tp_presharded_mode = True if self.repo_root in tp_presharded_models else False 41 | 42 | 43 | # use one of these args to `init_inference` 44 | # 1. injection_policy is the slower version, but it's plain pytorch so it'll always work 45 | # 2. replace_with_kernel_inject is the faster one (fast fused kernels) 46 | kernel_inject = True 47 | # kernel_inject = False 48 | 49 | if kernel_inject: 50 | # XXX: for now ds-inference only works with fp16 51 | self.dtype = torch.float16 52 | else: 53 | self.dtype = torch.bfloat16 54 | 55 | if kernel_inject: 56 | self.kwargs = dict(replace_with_kernel_inject=True) 57 | else: 58 | self.kwargs = dict( 59 | injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")} 60 | ) 61 | 62 | if tp_presharded_mode: 63 | # tp presharded repos come with their own checkpoints config file 64 | checkpoints_json = os.path.join(self.repo_root, "ds_inference_config.json") 65 | else: 66 | # for normal bloom repo we need to write the checkpoints config file 67 | if self.rank == 0: 68 | write_checkponts_json(self.repo_root , self.rank, self.checkpoints_json) 69 | # dist.barrier() 70 | 71 | def print_rank0(*msg, rank=0): 72 | if rank != 0: 73 | return 74 | print(*msg) 75 | 76 | 77 | def get_checkpoint_files(model_name_or_path, rank=0,revision=None, force_offline=True): 78 | if not force_offline: 79 | # checks if online or not 80 | if is_offline_mode(): 81 | print_rank0("Offline mode: forcing local_files_only=True", rank) 82 | local_files_only = True 83 | else: 84 | local_files_only = False 85 | 86 | # loads files from hub 87 | cached_repo_dir = snapshot_download( 88 | model_name_or_path, 89 | allow_patterns=["*"], 90 | local_files_only=True, 91 | revision=revision, 92 | ) 93 | else: 94 | cached_repo_dir = model_name_or_path 95 | 96 | # extensions: .bin | .pt 97 | # creates a list of paths from all downloaded files in cache dir 98 | file_list = [ 99 | str(entry) 100 | for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") 101 | if entry.is_file() 102 | ] 103 | return file_list 104 | 105 | 106 | def write_checkponts_json(model_name, rank=0, checkpoints_json="checkpoints.json"): 107 | with io.open(checkpoints_json, "w", encoding="utf-8") as f: 108 | # checkpoint_files = glob.glob(f"{checkpoint_dir}/*bin") 109 | checkpoint_files = get_checkpoint_files(model_name, rank) 110 | 111 | # print("Checkpoint files:", checkpoint_files) 112 | 113 | data = {"type": "BLOOM", "checkpoints": checkpoint_files, "version": 1.0} 114 | 115 | json.dump(data, f) 116 | 117 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | from .initialize import (get_tensor_model_parallel_group, 20 | get_tensor_model_parallel_rank, 21 | get_tensor_model_parallel_world_size) 22 | from .utils import VocabUtility 23 | 24 | 25 | class _VocabParallelCrossEntropy(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, vocab_parallel_logits, target): 28 | # Maximum value along vocab dimension across all GPUs. 29 | logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] 30 | torch.distributed.all_reduce( 31 | logits_max, 32 | op=torch.distributed.ReduceOp.MAX, 33 | group=get_tensor_model_parallel_group(), 34 | ) 35 | # Subtract the maximum value. 36 | vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) 37 | 38 | # Get the partition's vocab indecies 39 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size 40 | partition_vocab_size = vocab_parallel_logits.size()[-1] 41 | rank = get_tensor_model_parallel_rank() 42 | world_size = get_tensor_model_parallel_world_size() 43 | vocab_start_index, vocab_end_index = get_vocab_range( 44 | partition_vocab_size, rank, world_size 45 | ) 46 | 47 | # Create a mask of valid vocab ids (1 means it needs to be masked). 48 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index) 49 | masked_target = target.clone() - vocab_start_index 50 | masked_target[target_mask] = 0 51 | 52 | # Get predicted-logits = logits[target]. 53 | # For Simplicity, we convert logits to a 2-D tensor with size 54 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 55 | logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) 56 | masked_target_1d = masked_target.view(-1) 57 | arange_1d = torch.arange( 58 | start=0, end=logits_2d.size()[0], device=logits_2d.device 59 | ) 60 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] 61 | predicted_logits_1d = predicted_logits_1d.clone().contiguous() 62 | predicted_logits = predicted_logits_1d.view_as(target) 63 | predicted_logits[target_mask] = 0.0 64 | # All reduce is needed to get the chunks from other GPUs. 65 | torch.distributed.all_reduce( 66 | predicted_logits, 67 | op=torch.distributed.ReduceOp.SUM, 68 | group=get_tensor_model_parallel_group(), 69 | ) 70 | 71 | # Sum of exponential of logits along vocab dimension across all GPUs. 72 | exp_logits = vocab_parallel_logits 73 | torch.exp(vocab_parallel_logits, out=exp_logits) 74 | sum_exp_logits = exp_logits.sum(dim=-1) 75 | torch.distributed.all_reduce( 76 | sum_exp_logits, 77 | op=torch.distributed.ReduceOp.SUM, 78 | group=get_tensor_model_parallel_group(), 79 | ) 80 | 81 | # Loss = log(sum(exp(logits))) - predicted-logit. 82 | loss = torch.log(sum_exp_logits) - predicted_logits 83 | 84 | # Store softmax, target-mask and masked-target for backward pass. 85 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) 86 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) 87 | 88 | return loss 89 | 90 | @staticmethod 91 | def backward(ctx, grad_output): 92 | # Retreive tensors from the forward path. 
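        # The gradient w.r.t. the logits is softmax(logits) minus the one-hot target;
        # the one-hot part is subtracted below only on the partition that owns the
        # target index (i.e. where target_mask is 0).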
93 | softmax, target_mask, masked_target_1d = ctx.saved_tensors 94 | 95 | # All the inputs have softmax as thier gradient. 96 | grad_input = softmax 97 | # For simplicity, work with the 2D gradient. 98 | partition_vocab_size = softmax.size()[-1] 99 | grad_2d = grad_input.view(-1, partition_vocab_size) 100 | 101 | # Add the gradient from matching classes. 102 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) 103 | grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() 104 | 105 | # Finally elementwise multiplication with the output gradients. 106 | grad_input.mul_(grad_output.unsqueeze(dim=-1)) 107 | 108 | return grad_input, None 109 | 110 | 111 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target): 112 | """Helper function for the cross entropy.""" 113 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) 114 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/mappings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import (get_tensor_model_parallel_group, 19 | get_tensor_model_parallel_rank, 20 | get_tensor_model_parallel_world_size) 21 | from .utils import split_tensor_along_last_dim 22 | 23 | 24 | def _reduce(input_): 25 | """All-reduce the the input tensor across model parallel group.""" 26 | 27 | # Bypass the function if we are using only 1 GPU. 28 | if get_tensor_model_parallel_world_size() == 1: 29 | return input_ 30 | 31 | # All-reduce. 32 | torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) 33 | 34 | return input_ 35 | 36 | 37 | def _split(input_): 38 | """Split the tensor along its last dimension and keep the 39 | corresponding slice.""" 40 | 41 | world_size = get_tensor_model_parallel_world_size() 42 | # Bypass the function if we are using only 1 GPU. 43 | if world_size == 1: 44 | return input_ 45 | 46 | # Split along last dimension. 47 | input_list = split_tensor_along_last_dim(input_, world_size) 48 | 49 | # Note: torch.split does not create contiguous tensors by default. 50 | rank = get_tensor_model_parallel_rank() 51 | output = input_list[rank].contiguous() 52 | 53 | return output 54 | 55 | 56 | def _gather(input_): 57 | """Gather tensors and concatinate along the last dimension.""" 58 | 59 | world_size = get_tensor_model_parallel_world_size() 60 | # Bypass the function if we are using only 1 GPU. 61 | if world_size == 1: 62 | return input_ 63 | 64 | # Size and dimension. 
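    # all_gather collects every rank's shard and the final concatenation along the
    # last dimension is the exact inverse of _split().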
65 | last_dim = input_.dim() - 1 66 | rank = get_tensor_model_parallel_rank() 67 | 68 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 69 | tensor_list[rank] = input_ 70 | torch.distributed.all_gather( 71 | tensor_list, input_, group=get_tensor_model_parallel_group() 72 | ) 73 | 74 | # Note: torch.cat already creates a contiguous tensor. 75 | output = torch.cat(tensor_list, dim=last_dim).contiguous() 76 | 77 | return output 78 | 79 | 80 | class _CopyToModelParallelRegion(torch.autograd.Function): 81 | """Pass the input to the model parallel region.""" 82 | 83 | @staticmethod 84 | def symbolic(graph, input_): 85 | return input_ 86 | 87 | @staticmethod 88 | def forward(ctx, input_): 89 | return input_ 90 | 91 | @staticmethod 92 | def backward(ctx, grad_output): 93 | return _reduce(grad_output) 94 | 95 | 96 | class _ReduceFromModelParallelRegion(torch.autograd.Function): 97 | """All-reduce the input from the model parallel region.""" 98 | 99 | @staticmethod 100 | def symbolic(graph, input_): 101 | return _reduce(input_) 102 | 103 | @staticmethod 104 | def forward(ctx, input_): 105 | return _reduce(input_) 106 | 107 | @staticmethod 108 | def backward(ctx, grad_output): 109 | return grad_output 110 | 111 | 112 | class _ScatterToModelParallelRegion(torch.autograd.Function): 113 | """Split the input and keep only the corresponding chuck to the rank.""" 114 | 115 | @staticmethod 116 | def symbolic(graph, input_): 117 | return _split(input_) 118 | 119 | @staticmethod 120 | def forward(ctx, input_): 121 | return _split(input_) 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | return _gather(grad_output) 126 | 127 | 128 | class _GatherFromModelParallelRegion(torch.autograd.Function): 129 | """Gather the input from model parallel region and concatinate.""" 130 | 131 | @staticmethod 132 | def symbolic(graph, input_): 133 | return _gather(input_) 134 | 135 | @staticmethod 136 | def forward(ctx, input_): 137 | return _gather(input_) 138 | 139 | @staticmethod 140 | def backward(ctx, grad_output): 141 | return _split(grad_output) 142 | 143 | 144 | # ----------------- 145 | # Helper functions. 146 | # ----------------- 147 | 148 | 149 | def copy_to_tensor_model_parallel_region(input_): 150 | return _CopyToModelParallelRegion.apply(input_) 151 | 152 | 153 | def reduce_from_tensor_model_parallel_region(input_): 154 | return _ReduceFromModelParallelRegion.apply(input_) 155 | 156 | 157 | def scatter_to_tensor_model_parallel_region(input_): 158 | return _ScatterToModelParallelRegion.apply(input_) 159 | 160 | 161 | def gather_from_tensor_model_parallel_region(input_): 162 | return _GatherFromModelParallelRegion.apply(input_) 163 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/tests/test_random.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import sys 17 | 18 | import mpu 19 | import torch 20 | from commons import initialize_distributed, print_separator 21 | 22 | sys.path.append("../..") 23 | 24 | 25 | def test_set_cuda_rng_state(tensor_model_parallel_size): 26 | if torch.distributed.get_rank() == 0: 27 | print( 28 | "> testing set_rng_state with size {} ...".format( 29 | tensor_model_parallel_size 30 | ) 31 | ) 32 | 33 | mpu.initialize_model_parallel(tensor_model_parallel_size) 34 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 35 | 36 | size = 123 37 | seed = 1234 38 | torch.cuda.manual_seed(1234) 39 | tensor = torch.cuda.FloatTensor(size) 40 | 41 | # Get the state 42 | rng_state = torch.cuda.get_rng_state() 43 | rng_state_copy = rng_state.clone() 44 | 45 | # Do some stuff. 46 | for _ in range(5): 47 | torch.randn(size, out=tensor) 48 | result_1 = tensor.clone() 49 | 50 | assert rng_state.sub(rng_state_copy).max() == 0 51 | assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 52 | 53 | # State should be different. 54 | new_rng_state = torch.cuda.get_rng_state() 55 | max_diff = new_rng_state.sub(rng_state).max() 56 | print( 57 | " max diff in rng state (should be non-zero) on global rank {}: {}".format( 58 | torch.distributed.get_rank(), max_diff 59 | ) 60 | ) 61 | assert max_diff > 0 62 | 63 | # Reset the rng state and do the same stuff. 64 | mpu.random._set_cuda_rng_state(rng_state) 65 | for _ in range(5): 66 | torch.randn(size, out=tensor) 67 | mpu.random._set_cuda_rng_state(rng_state) 68 | for _ in range(5): 69 | torch.randn(size, out=tensor) 70 | result_2 = tensor.clone() 71 | 72 | # Results should be the same 73 | error = result_2.sub(result_1).abs().max() 74 | print( 75 | " max error in generated tensors (should be zero) on " 76 | "global rank {}: {}".format(torch.distributed.get_rank(), error) 77 | ) 78 | assert error < 1.0e-6 79 | 80 | # Input state should have remained intact. 81 | error = rng_state.sub(rng_state_copy).max() 82 | print( 83 | " max error in rng state (should be zero) on global rank {}: {}".format( 84 | torch.distributed.get_rank(), error 85 | ) 86 | ) 87 | assert error == 0 88 | 89 | # Reset groups 90 | mpu.destroy_model_parallel() 91 | 92 | torch.distributed.barrier() 93 | if torch.distributed.get_rank() == 0: 94 | print(">> passed the test :-)") 95 | 96 | 97 | def test_cuda_rng_tracker(tensor_model_parallel_size): 98 | if torch.distributed.get_rank() == 0: 99 | print( 100 | "> testing cuda rng tracker with size {} ...".format( 101 | tensor_model_parallel_size 102 | ) 103 | ) 104 | 105 | mpu.initialize_model_parallel(tensor_model_parallel_size) 106 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 107 | 108 | seed_1 = 1234 109 | seed_2 = 4321 110 | size = [12, 21] 111 | tensor = torch.cuda.FloatTensor(size) 112 | 113 | # Set to seed_1 and generate two tensors. 114 | torch.cuda.manual_seed(seed_1) 115 | torch.randn(size, out=tensor) 116 | target_11 = tensor.clone() 117 | torch.randn(size, out=tensor) 118 | target_12 = tensor.clone() 119 | 120 | # Set to seed_2 and generate two tensors. 
121 | torch.cuda.manual_seed(seed_2) 122 | torch.randn(size, out=tensor) 123 | target_21 = tensor.clone() 124 | torch.randn(size, out=tensor) 125 | target_22 = tensor.clone() 126 | 127 | # Now if we interleave seed_1 and seed_2, 128 | # we should still get the same tensors 129 | torch.cuda.manual_seed(seed_1) 130 | mpu.get_cuda_rng_tracker().add("test", seed_2) 131 | 132 | torch.randn(size, out=tensor) 133 | result_11 = tensor.clone() 134 | 135 | with mpu.get_cuda_rng_tracker().fork("test"): 136 | torch.randn(size, out=tensor) 137 | result_21 = tensor.clone() 138 | 139 | torch.randn(size, out=tensor) 140 | result_12 = tensor.clone() 141 | 142 | with mpu.get_cuda_rng_tracker().fork("test"): 143 | torch.randn(size, out=tensor) 144 | result_22 = tensor.clone() 145 | 146 | diff = result_11.sub(result_21).abs().max() 147 | diff = min(diff, result_12.sub(result_22).abs().max()) 148 | print( 149 | " max diff in generated tensors (should be non-zero) on " 150 | "global rank {}: {}".format(torch.distributed.get_rank(), diff) 151 | ) 152 | assert diff > 1.0e-6 153 | error = max( 154 | result_11.sub(target_11).abs().max(), result_12.sub(target_12).abs().max() 155 | ) 156 | error = max(error, result_21.sub(target_21).abs().max()) 157 | error = max(error, result_22.sub(target_22).abs().max()) 158 | print( 159 | " max error in generated tensors (should be zero) on " 160 | "global rank {}: {}".format(torch.distributed.get_rank(), error) 161 | ) 162 | assert error < 1.0e-6 163 | 164 | # Reset the tracker 165 | mpu.get_cuda_rng_tracker().reset() 166 | 167 | # Reset groups 168 | mpu.destroy_model_parallel() 169 | 170 | torch.distributed.barrier() 171 | if torch.distributed.get_rank() == 0: 172 | print(">> passed the test :-)") 173 | 174 | 175 | def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size): 176 | if torch.distributed.get_rank() == 0: 177 | print( 178 | "> testing model parallel cuda manual seed with size {} ...".format( 179 | tensor_model_parallel_size 180 | ) 181 | ) 182 | 183 | mpu.initialize_model_parallel(tensor_model_parallel_size) 184 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 185 | 186 | mpu.model_parallel_cuda_manual_seed(12345) 187 | assert torch.cuda.initial_seed() == 12345 188 | with mpu.get_cuda_rng_tracker().fork(): 189 | assert torch.cuda.initial_seed() == ( 190 | 12345 + 2718 + mpu.get_tensor_model_parallel_rank() 191 | ) 192 | 193 | # Reset the tracker 194 | mpu.get_cuda_rng_tracker().reset() 195 | 196 | # Reset groups 197 | mpu.destroy_model_parallel() 198 | 199 | torch.distributed.barrier() 200 | if torch.distributed.get_rank() == 0: 201 | print(">> passed the test :-)") 202 | 203 | 204 | if __name__ == "__main__": 205 | initialize_distributed() 206 | world_size = torch.distributed.get_world_size() 207 | 208 | tensor_model_parallel_size = 1 209 | while tensor_model_parallel_size <= world_size: 210 | print_separator("test set rng state") 211 | test_set_cuda_rng_state(tensor_model_parallel_size) 212 | tensor_model_parallel_size *= 2 213 | 214 | tensor_model_parallel_size = 1 215 | while tensor_model_parallel_size <= world_size: 216 | print_separator("test cuda rng tracker") 217 | test_cuda_rng_tracker(tensor_model_parallel_size) 218 | tensor_model_parallel_size *= 2 219 | 220 | tensor_model_parallel_size = 1 221 | while tensor_model_parallel_size <= world_size: 222 | print_separator("test model parallel cuda manual seed") 223 | test_model_parallel_cuda_manual_seed(tensor_model_parallel_size) 224 | tensor_model_parallel_size *= 2 225 | 
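A minimal sketch of how the conjugate autograd functions defined in mpu/mappings.py above are typically composed into a column-parallel linear layer may help make the forward/backward pairing concrete. It assumes torch.distributed and mpu.initialize_model_parallel() have already been set up (for example by the test harness above); the class and parameter names are hypothetical and chosen only for illustration.

import torch
import torch.nn.functional as F
from distill_bloom.dataset.megatron.mpu.mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
)


class ToyColumnParallelLinear(torch.nn.Module):
    """Each tensor-parallel rank owns a slice of the output dimension."""

    def __init__(self, in_features, out_features_per_partition):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(out_features_per_partition, in_features)
        )
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # identity in the forward pass, all-reduce of gradients in the backward pass
        x = copy_to_tensor_model_parallel_region(x)
        # rank-local matmul against this rank's slice of the weight matrix
        local_out = F.linear(x, self.weight)
        # all-gather the partial outputs forward, split the gradient backward
        return gather_from_tensor_model_parallel_region(local_out)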
-------------------------------------------------------------------------------- /distill_bloom/arguments/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 Optuna, Hugging Face 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Logging utilities. """ 16 | 17 | import logging 18 | import os 19 | import sys 20 | import threading 21 | from functools import wraps 22 | from logging import CRITICAL # NOQA 23 | from logging import DEBUG # NOQA 24 | from logging import ERROR # NOQA 25 | from logging import FATAL # NOQA 26 | from logging import INFO # NOQA 27 | from logging import NOTSET # NOQA 28 | from logging import WARN # NOQA 29 | from logging import WARNING # NOQA 30 | from typing import Optional 31 | 32 | _lock = threading.Lock() 33 | _default_handler: Optional[logging.Handler] = None 34 | 35 | log_levels = { 36 | "debug": logging.DEBUG, 37 | "info": logging.INFO, 38 | "warning": logging.WARNING, 39 | "error": logging.ERROR, 40 | "critical": logging.CRITICAL, 41 | } 42 | 43 | _default_log_level = logging.WARNING 44 | 45 | 46 | def _get_default_logging_level(): 47 | """ 48 | If MEGATRON_DEEPSPEED_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is 49 | not - fall back to ``_default_log_level`` 50 | """ 51 | env_level_str = os.getenv("MEGATRON_DEEPSPEED_VERBOSITY", None) 52 | if env_level_str: 53 | if env_level_str in log_levels: 54 | return log_levels[env_level_str] 55 | else: 56 | logging.getLogger().warning( 57 | f"Unknown option MEGATRON_DEEPSPEED_VERBOSITY={env_level_str}, " 58 | f"has to be one of: { ', '.join(log_levels.keys()) }" 59 | ) 60 | return _default_log_level 61 | 62 | 63 | def _get_library_name() -> str: 64 | return __name__.split(".")[0] 65 | 66 | 67 | def _get_library_root_logger() -> logging.Logger: 68 | return logging.getLogger(_get_library_name()) 69 | 70 | 71 | def _configure_library_root_logger() -> None: 72 | global _default_handler 73 | 74 | with _lock: 75 | if _default_handler: 76 | # This library has already configured the library root logger. 77 | return 78 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 79 | _default_handler.flush = sys.stderr.flush 80 | 81 | # Apply our default configuration to the library root logger. 
82 | library_root_logger = _get_library_root_logger() 83 | library_root_logger.addHandler(_default_handler) 84 | library_root_logger.setLevel(_get_default_logging_level()) 85 | library_root_logger.propagate = False 86 | 87 | 88 | def _reset_library_root_logger() -> None: 89 | global _default_handler 90 | 91 | with _lock: 92 | if not _default_handler: 93 | return 94 | 95 | library_root_logger = _get_library_root_logger() 96 | library_root_logger.removeHandler(_default_handler) 97 | library_root_logger.setLevel(logging.NOTSET) 98 | _default_handler = None 99 | 100 | 101 | def get_log_levels_dict(): 102 | return log_levels 103 | 104 | 105 | def get_logger(name: Optional[str] = None) -> logging.Logger: 106 | """ 107 | Return a logger with the specified name. 108 | This function is not supposed to be directly accessed unless you are writing a custom transformers module. 109 | """ 110 | 111 | if name is None: 112 | name = _get_library_name() 113 | 114 | _configure_library_root_logger() 115 | return logging.getLogger(name) 116 | 117 | 118 | def get_verbosity() -> int: 119 | """ 120 | Return the current level for the 🤗 Transformers's root logger as an int. 121 | Returns: 122 | :obj:`int`: The logging level. 123 | .. note:: 124 | 🤗 Transformers has following logging levels: 125 | - 50: ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` 126 | - 40: ``transformers.logging.ERROR`` 127 | - 30: ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` 128 | - 20: ``transformers.logging.INFO`` 129 | - 10: ``transformers.logging.DEBUG`` 130 | """ 131 | 132 | _configure_library_root_logger() 133 | return _get_library_root_logger().getEffectiveLevel() 134 | 135 | 136 | def set_verbosity(verbosity: int) -> None: 137 | """ 138 | Set the verbosity level for the 🤗 Transformers's root logger. 
139 | Args: 140 | verbosity (:obj:`int`): 141 | Logging level, e.g., one of: 142 | - ``transformers.logging.CRITICAL`` or ``transformers.logging.FATAL`` 143 | - ``transformers.logging.ERROR`` 144 | - ``transformers.logging.WARNING`` or ``transformers.logging.WARN`` 145 | - ``transformers.logging.INFO`` 146 | - ``transformers.logging.DEBUG`` 147 | """ 148 | 149 | _configure_library_root_logger() 150 | _get_library_root_logger().setLevel(verbosity) 151 | 152 | 153 | def set_verbosity_info(): 154 | """Set the verbosity to the :obj:`INFO` level.""" 155 | return set_verbosity(INFO) 156 | 157 | 158 | def set_verbosity_warning(): 159 | """Set the verbosity to the :obj:`WARNING` level.""" 160 | return set_verbosity(WARNING) 161 | 162 | 163 | def set_verbosity_debug(): 164 | """Set the verbosity to the :obj:`DEBUG` level.""" 165 | return set_verbosity(DEBUG) 166 | 167 | 168 | def set_verbosity_error(): 169 | """Set the verbosity to the :obj:`ERROR` level.""" 170 | return set_verbosity(ERROR) 171 | 172 | 173 | def disable_default_handler() -> None: 174 | """Disable the default handler of the HuggingFace Transformers's root logger.""" 175 | 176 | _configure_library_root_logger() 177 | 178 | assert _default_handler is not None 179 | _get_library_root_logger().removeHandler(_default_handler) 180 | 181 | 182 | def enable_default_handler() -> None: 183 | """Enable the default handler of the HuggingFace Transformers's root logger.""" 184 | 185 | _configure_library_root_logger() 186 | 187 | assert _default_handler is not None 188 | _get_library_root_logger().addHandler(_default_handler) 189 | 190 | 191 | def add_handler(handler: logging.Handler) -> None: 192 | """adds a handler to the HuggingFace Transformers's root logger.""" 193 | 194 | _configure_library_root_logger() 195 | 196 | assert handler is not None 197 | _get_library_root_logger().addHandler(handler) 198 | 199 | 200 | def remove_handler(handler: logging.Handler) -> None: 201 | """removes given handler from the HuggingFace Transformers's root logger.""" 202 | 203 | _configure_library_root_logger() 204 | 205 | assert handler is not None and handler not in _get_library_root_logger().handlers 206 | _get_library_root_logger().removeHandler(handler) 207 | 208 | 209 | def disable_propagation() -> None: 210 | """ 211 | Disable propagation of the library log outputs. Note that log propagation is disabled by default. 212 | """ 213 | 214 | _configure_library_root_logger() 215 | _get_library_root_logger().propagate = False 216 | 217 | 218 | def enable_propagation() -> None: 219 | """ 220 | Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to 221 | prevent double logging if the root logger has been configured. 222 | """ 223 | 224 | _configure_library_root_logger() 225 | _get_library_root_logger().propagate = True 226 | 227 | 228 | def enable_explicit_format() -> None: 229 | """ 230 | Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows: 231 | :: 232 | [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE 233 | All handlers currently bound to the root logger are affected by this method. 
234 | """ 235 | handlers = _get_library_root_logger().handlers 236 | 237 | for handler in handlers: 238 | formatter = logging.Formatter( 239 | "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s" 240 | ) 241 | handler.setFormatter(formatter) 242 | 243 | 244 | def reset_format() -> None: 245 | """ 246 | Resets the formatting for HuggingFace Transformers's loggers. 247 | All handlers currently bound to the root logger are affected by this method. 248 | """ 249 | handlers = _get_library_root_logger().handlers 250 | 251 | for handler in handlers: 252 | handler.setFormatter(None) 253 | -------------------------------------------------------------------------------- /distill_bloom/teacher-inference-script.py: -------------------------------------------------------------------------------- 1 | # usage: 2 | # deepspeed --num_gpus 8 teacher-inference-script.py --name bigscience/bloom 3 | # 4 | # to run benchmarks: 5 | # deepspeed --num_gpus 8 teacher-inference-script.py --name bigscience/bloom --benchmark 6 | # 7 | 8 | 9 | # This is going to improve, but at the moment, the process is a bit cumbersome - we first use 10 | # 1. use Deepspeed-ZeRO to instantiate the model on GPUs, w/o loading the checkpoints, 11 | # 2. free the allocated storage 12 | # 3. start Deepspeed-Inference and only now load the checkpoint 13 | # 4. run generate 14 | # Done. 15 | # 16 | 17 | 18 | import gc 19 | import glob 20 | import io 21 | import json 22 | import math 23 | import os 24 | import time 25 | from argparse import ArgumentParser 26 | from pathlib import Path 27 | 28 | import deepspeed 29 | import torch 30 | import torch.distributed as dist 31 | from huggingface_hub import snapshot_download 32 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 33 | from transformers.models.bloom.modeling_bloom import BloomBlock as BloomBlock 34 | from transformers.utils import is_offline_mode 35 | 36 | # the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time. 
37 | tp_presharded_models = [ 38 | "microsoft/bloom-deepspeed-inference-int8", 39 | "microsoft/bloom-deepspeed-inference-fp16", 40 | ] 41 | 42 | t_start = time.time() 43 | 44 | num_tokens = 100 45 | 46 | parser = ArgumentParser() 47 | 48 | parser.add_argument("--name", required=True, type=str, help="model_name") 49 | parser.add_argument( 50 | "--dtype", 51 | type=str, 52 | help="float16 or int8", 53 | choices=["int8", "float16"], 54 | default="float16", 55 | ) 56 | parser.add_argument( 57 | "--local_rank", required=False, type=int, help="used by dist launchers" 58 | ) 59 | parser.add_argument("--batch_size", default=1, type=int, help="batch size") 60 | parser.add_argument( 61 | "--benchmark", action="store_true", help="additionally run benchmark" 62 | ) 63 | args = parser.parse_args() 64 | 65 | local_rank = int(os.getenv("LOCAL_RANK", "0")) 66 | world_size = int(os.getenv("WORLD_SIZE", "1")) 67 | 68 | deepspeed.init_distributed("nccl") 69 | rank = dist.get_rank() 70 | 71 | 72 | def print_rank0(*msg): 73 | if rank != 0: 74 | return 75 | print(*msg) 76 | 77 | 78 | ### Model loading and instantiating on GPUs 79 | 80 | 81 | def get_repo_root(model_name_or_path, revision=None): 82 | # checks if online or not 83 | if is_offline_mode(): 84 | print_rank0("Offline mode: forcing local_files_only=True") 85 | local_files_only = True 86 | else: 87 | local_files_only = False 88 | 89 | # loads files from hub 90 | cached_repo_dir = snapshot_download( 91 | model_name_or_path, 92 | allow_patterns=["*"], 93 | local_files_only=local_files_only, 94 | revision=revision, 95 | ) 96 | 97 | return cached_repo_dir 98 | 99 | 100 | def get_checkpoint_files(model_name_or_path, revision=None, force_offline=True): 101 | if not force_offline: 102 | # checks if online or not 103 | if is_offline_mode(): 104 | print_rank0("Offline mode: forcing local_files_only=True") 105 | local_files_only = True 106 | else: 107 | local_files_only = False 108 | 109 | # loads files from hub 110 | cached_repo_dir = snapshot_download( 111 | model_name_or_path, 112 | allow_patterns=["*"], 113 | local_files_only=True, 114 | revision=revision, 115 | ) 116 | else: 117 | cached_repo_dir = model_name_or_path 118 | 119 | # extensions: .bin | .pt 120 | # creates a list of paths from all downloaded files in cache dir 121 | file_list = [ 122 | str(entry) 123 | for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]") 124 | if entry.is_file() 125 | ] 126 | return file_list 127 | 128 | 129 | model_name = args.name 130 | infer_dtype = args.dtype 131 | 132 | tp_presharded_mode = True if model_name in tp_presharded_models else False 133 | 134 | # print(get_checkpoint_files(model_name)) 135 | 136 | print_rank0(f"*** Loading the model {model_name}") 137 | 138 | tokenizer = AutoTokenizer.from_pretrained(model_name) 139 | config = AutoConfig.from_pretrained(model_name) 140 | 141 | # XXX: can't automatically derive dtype via config's `from_pretrained` 142 | # dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16 143 | 144 | 145 | # use one of these args to `init_inference` 146 | # 1. injection_policy is the slower version, but it's plain pytorch so it'll always work 147 | # 2. 
replace_with_kernel_inject is the faster one (fast fused kernels) 148 | kernel_inject = True 149 | # kernel_inject = False 150 | 151 | if kernel_inject: 152 | # XXX: for now ds-inference only works with fp16 153 | dtype = torch.float16 154 | else: 155 | dtype = torch.bfloat16 156 | 157 | if args.benchmark: 158 | torch.cuda.empty_cache() 159 | gc.collect() 160 | deepspeed.runtime.utils.see_memory_usage("pre-from-pretrained", force=True) 161 | 162 | # Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load 163 | with deepspeed.OnDevice(dtype=dtype, device="meta"): 164 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) 165 | 166 | if args.benchmark: 167 | deepspeed.runtime.utils.see_memory_usage("post-from-pretrained", force=True) 168 | 169 | model = model.eval() 170 | 171 | if args.benchmark: 172 | torch.cuda.empty_cache() 173 | gc.collect() 174 | deepspeed.runtime.utils.see_memory_usage("post-init-ds-zero-init", force=True) 175 | 176 | ### Deepspeed-Inference Loading 177 | 178 | checkpoints_json = "checkpoints.json" 179 | 180 | 181 | def write_checkponts_json(): 182 | with io.open(checkpoints_json, "w", encoding="utf-8") as f: 183 | # checkpoint_files = glob.glob(f"{checkpoint_dir}/*bin") 184 | checkpoint_files = get_checkpoint_files(model_name) 185 | 186 | # print("Checkpoint files:", checkpoint_files) 187 | 188 | data = {"type": "BLOOM", "checkpoints": checkpoint_files, "version": 1.0} 189 | 190 | json.dump(data, f) 191 | 192 | 193 | if args.benchmark: 194 | torch.cuda.empty_cache() 195 | gc.collect() 196 | deepspeed.runtime.utils.see_memory_usage("pre-ds-inference-init", force=True) 197 | 198 | if kernel_inject: 199 | kwargs = dict(replace_with_kernel_inject=True) 200 | else: 201 | kwargs = dict( 202 | injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")} 203 | ) 204 | 205 | # TODO: this fails even if the model is present locally 206 | # repo_root = get_repo_root(model_name) 207 | repo_root = model_name 208 | 209 | if tp_presharded_mode: 210 | # tp presharded repos come with their own checkpoints config file 211 | checkpoints_json = os.path.join(repo_root, "ds_inference_config.json") 212 | else: 213 | # for normal bloom repo we need to write the checkpoints config file 214 | if rank == 0: 215 | write_checkponts_json() 216 | dist.barrier() 217 | 218 | # checkpoints_json=None 219 | model = deepspeed.init_inference( 220 | model, 221 | mp_size=world_size, 222 | base_dir=repo_root, 223 | dtype=getattr(torch, infer_dtype), 224 | checkpoint=checkpoints_json, 225 | **kwargs, 226 | ) 227 | 228 | if args.benchmark: 229 | torch.cuda.empty_cache() 230 | gc.collect() 231 | deepspeed.runtime.utils.see_memory_usage("post-ds-inference-init", force=True) 232 | 233 | 234 | model = model.module 235 | 236 | if args.benchmark: 237 | t_ready = time.time() 238 | 239 | 240 | ### Generate 241 | 242 | 243 | print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}") 244 | 245 | input_sentences = [ 246 | "DeepSpeed is a machine learning framework", 247 | "He is working on", 248 | "He has a", 249 | "He got all", 250 | "Everyone is happy and I can", 251 | "The new movie that got Oscar this year", 252 | "In the far far distance from our galaxy,", 253 | "Peace is the only way", 254 | ] 255 | 256 | if args.batch_size > len(input_sentences): 257 | # dynamically extend to support larger bs by repetition 258 | input_sentences *= math.ceil(args.batch_size / len(input_sentences)) 259 | 260 | generate_kwargs = 
dict(max_new_tokens=num_tokens, do_sample=False) 261 | 262 | 263 | print_rank0(f"Generate args {generate_kwargs}") 264 | 265 | inputs = input_sentences[: args.batch_size] 266 | 267 | 268 | def generate(): 269 | """Returns zipped inputs, outputs and number of generated new tokens""" 270 | 271 | input_tokens = tokenizer.batch_encode_plus( 272 | inputs, return_tensors="pt", padding=True 273 | ) 274 | for t in input_tokens: 275 | if torch.is_tensor(input_tokens[t]): 276 | input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) 277 | 278 | outputs = model.generate(**input_tokens, **generate_kwargs) 279 | 280 | input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids] 281 | output_tokens_lengths = [x.shape[0] for x in outputs] 282 | 283 | total_new_tokens = [ 284 | o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths) 285 | ] 286 | outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 287 | 288 | return zip(inputs, outputs, total_new_tokens) 289 | 290 | 291 | # warmup is a must if measuring speed as it's when all the optimizations are performed 292 | # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs 293 | print_rank0(f"*** Running generate warmup") 294 | _ = generate() 295 | 296 | print_rank0(f"*** Running generate") 297 | t_generate_start = time.time() 298 | generated = generate() 299 | t_generate_span = time.time() - t_generate_start 300 | for i, o, _ in generated: 301 | print_rank0(f"{'-'*60}\nin={i}\nout={o}\n") 302 | 303 | if args.benchmark: 304 | torch.cuda.empty_cache() 305 | gc.collect() 306 | deepspeed.runtime.utils.see_memory_usage("end-of-run", force=True) 307 | 308 | ### Benchmark 309 | 310 | # benchmark it! 311 | if args.benchmark: 312 | print_rank0(f"*** Running benchmark") 313 | 314 | # warm up 315 | for i in range(1): 316 | _ = generate() 317 | torch.cuda.synchronize() 318 | 319 | # benchmark 320 | t0 = time.time() 321 | cycles = 5 322 | total_new_tokens_generated = 0 323 | for i in range(cycles): 324 | generated = generate() 325 | total_new_tokens_generated += sum(new_tokens for _, _, new_tokens in generated) 326 | torch.cuda.synchronize() 327 | throughput = (time.time() - t0) / (total_new_tokens_generated) 328 | print_rank0( 329 | f""" 330 | *** Performance stats: 331 | Throughput per token including tokenize: {throughput*1000:.2f} msecs 332 | Start to ready to generate: {t_ready - t_start:.3f} secs 333 | Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs 334 | Start to finish: {t_ready - t_start + t_generate_span:.3f} secs 335 | """ 336 | ) 337 | -------------------------------------------------------------------------------- /distill_bloom/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import math, time 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .gpt_dataset import GPTDataset 7 | from .indexed_dataset import (IndexedCachedDataset, IndexedDataset, MMapIndexedDataset, 8 | create_doc_idx, data_file_path, index_file_path) 9 | 10 | 11 | def print_rank_0(message): 12 | """If distributed is initialized, print only on rank 0.""" 13 | if torch.distributed.is_initialized(): 14 | if torch.distributed.get_rank() == 0: 15 | print(message, flush=True) 16 | else: 17 | print(message, flush=True) 18 | 19 | 20 | def infer_dataset_impl(path): 21 | if IndexedDataset.exists(path): 22 | with open(index_file_path(path), "rb") as f: 23 | magic = f.read(8) 24 | if magic == IndexedDataset._HDR_MAGIC: 25 | return "cached"
26 | elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: 27 | return "mmap" 28 | else: 29 | return None 30 | else: 31 | print(f"Dataset does not exist: {path}") 32 | print( 33 | "Path should be a basename that both .idx and .bin can be appended to get" 34 | " full filenames." 35 | ) 36 | return None 37 | 38 | 39 | def get_train_valid_test_split_(splits_string, size): 40 | r""" 41 | Get dataset split indices from a comma- or '/'-separated list of up to three weights 42 | (e.g. `949,50,1`); the weights are normalized by their sum. 43 | 44 | Returns: 45 | A list of four cumulative indices `[0, train_end, valid_end, size]` delimiting the training, validation and test splits. 46 | """ 47 | splits = [] 48 | if splits_string.find(",") != -1: 49 | splits = [float(s) for s in splits_string.split(",")] 50 | elif splits_string.find("/") != -1: 51 | splits = [float(s) for s in splits_string.split("/")] 52 | else: 53 | splits = [float(splits_string)] 54 | while len(splits) < 3: 55 | splits.append(0.0) 56 | splits = splits[:3] 57 | splits_sum = sum(splits) 58 | assert splits_sum > 0.0 59 | splits = [split / splits_sum for split in splits] 60 | splits_index = [0] 61 | for index, split in enumerate(splits): 62 | splits_index.append(splits_index[index] + int(round(split * float(size)))) 63 | diff = splits_index[-1] - size 64 | for index in range(1, len(splits_index)): 65 | splits_index[index] -= diff 66 | assert len(splits_index) == 4 67 | assert splits_index[-1] == size 68 | return splits_index 69 | 70 | 71 | def analyze_data_prefix(data_prefix): 72 | # The data prefix should be in the format of: 73 | # weight-1, data-prefix-1, weight-2, data-prefix-2, .. 74 | assert len(data_prefix) % 2 == 0 75 | num_datasets = len(data_prefix) // 2 76 | weights = [0] * num_datasets 77 | prefixes = [0] * num_datasets 78 | for i in range(num_datasets): 79 | weights[i] = float(data_prefix[2 * i]) 80 | prefixes[i] = (data_prefix[2 * i + 1]).strip() 81 | # Normalize weights 82 | weight_sum = 0.0 83 | for weight in weights: 84 | weight_sum += weight 85 | assert weight_sum > 0.0 86 | weights = [weight / weight_sum for weight in weights] 87 | return prefixes, weights 88 | 89 | 90 | def get_split_by_range_(range_string, size): 91 | """Get dataset splits based on a range: 92 | range_string is in the form START:END where both values are fractions of the dataset, e.g. 0.2:0.8 93 | outputs an array of two values [start_index, end_index] 94 | """ 95 | # some checks that range is given in the correct form 96 | splits = [float(i) for i in range_string.split(":")] 97 | assert len(splits) == 2, "splits should be passed as start:end" 98 | assert splits[0] <= 1 and splits[1] <= 1 99 | splits_sum = sum(splits) 100 | assert splits_sum > 0.0 101 | splits_index = [round(s * float(size)) for s in splits] 102 | assert len(splits_index) == 2 103 | return splits_index 104 | 105 | 106 | def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): 107 | # Add 0.5% (the 1.005 factor) so in case the blending dataset does 108 | # not uniformly distribute the number of samples, we still have 109 | # samples left to feed to the network.
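    # Worked example of the oversampling below (illustrative; the corpus names and
    # sample counts are hypothetical). With
    #     data_prefix = ["0.7", "corpusA_text_document", "0.3", "corpusB_text_document"]
    #     train_valid_test_num_samples = [1000, 100, 10]
    # the normalized weights are [0.7, 0.3] and the per-dataset requests become
    #     corpusA: [ceil(1000*0.7*1.005), ceil(100*0.7*1.005), ceil(10*0.7*1.005)] = [704, 71, 8]
    #     corpusB: [ceil(1000*0.3*1.005), ceil(100*0.3*1.005), ceil(10*0.3*1.005)] = [302, 31, 4]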
110 | prefixes, weights = analyze_data_prefix(data_prefix) 111 | datasets_train_valid_test_num_samples = [] 112 | for weight in weights: 113 | datasets_train_valid_test_num_samples.append( 114 | [ 115 | int(math.ceil(val * weight * 1.005)) 116 | for val in train_valid_test_num_samples 117 | ] 118 | ) 119 | 120 | return prefixes, weights, datasets_train_valid_test_num_samples 121 | 122 | 123 | def build_dataset_group( 124 | dataset_group_name, 125 | paths, 126 | weights, 127 | splits, 128 | data_impl, 129 | train_valid_test_num_samples, 130 | seq_length, 131 | seed, 132 | skip_warmup, 133 | train_valid_test, 134 | ): 135 | """ 136 | Build a single dataset group corresponding to Option 2 of data loading see arguments.py 137 | a dataset group is passed on the following form 138 | GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 139 | or alternatively 140 | GIVEN_NAME PATH1 # for a single dataset to be used fully 141 | """ 142 | 143 | assert train_valid_test in ["train", "valid", "test"] 144 | 145 | # Single dataset. 146 | if len(paths) == 1: 147 | dataset = _build_single_datasets( 148 | paths[0], 149 | splits[0], 150 | data_impl, 151 | train_valid_test_num_samples, 152 | seq_length, 153 | seed, 154 | skip_warmup, 155 | dataset_group_name, 156 | train_valid_test, 157 | ) 158 | return dataset 159 | # Blending dataset. 160 | else: 161 | data_prefix = [] 162 | # data_prefix is on the shape: 163 | # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] 164 | for w, p in zip(weights, paths): 165 | data_prefix += [w, p] 166 | 167 | output = get_datasets_weights_and_num_samples( 168 | data_prefix, train_valid_test_num_samples 169 | ) 170 | prefixes, weights, datasets_train_valid_test_num_samples = output 171 | 172 | # Build individual datasets. 173 | datasets = [] 174 | for i in range(len(prefixes)): 175 | ds = _build_single_datasets( 176 | prefixes[i], 177 | splits[i], 178 | data_impl, 179 | datasets_train_valid_test_num_samples[i], 180 | seq_length, 181 | seed, 182 | skip_warmup, 183 | dataset_group_name, 184 | train_valid_test, 185 | ) 186 | 187 | datasets.append(ds) 188 | all_datasets = BlendableDataset(datasets, weights) 189 | 190 | return all_datasets 191 | 192 | 193 | def make_dataset(path, impl, skip_warmup=False): 194 | if not IndexedDataset.exists(path): 195 | print(f"Dataset does not exist: {path}") 196 | print( 197 | "Path should be a basename that both .idx and .bin can be appended to get" 198 | " full filenames." 
199 | ) 200 | return None 201 | if impl == "infer": 202 | impl = infer_dataset_impl(path) 203 | if impl == "lazy" and IndexedDataset.exists(path): 204 | return IndexedDataset(path) 205 | elif impl == "cached" and IndexedDataset.exists(path): 206 | return IndexedCachedDataset(path) 207 | elif impl == "mmap" and MMapIndexedDataset.exists(path): 208 | return MMapIndexedDataset(path, skip_warmup) 209 | print(f"Unknown dataset implementation: {impl}") 210 | return None 211 | 212 | 213 | def get_indexed_dataset_(path, data_impl, skip_warmup): 214 | """Build indexed dataset.""" 215 | print_rank_0(" > building dataset index ...") 216 | start_time = time.time() 217 | indexed_dataset = make_dataset(path, data_impl, skip_warmup) 218 | print_rank_0( 219 | " > finished creating indexed dataset in {:4f} seconds".format( 220 | time.time() - start_time 221 | ) 222 | ) 223 | print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) 224 | 225 | return indexed_dataset 226 | 227 | 228 | def build_dataset_group( 229 | dataset_group_name, 230 | paths, 231 | weights, 232 | splits, 233 | data_impl, 234 | train_valid_test_num_samples, 235 | seq_length, 236 | seed, 237 | skip_warmup, 238 | train_valid_test, 239 | ): 240 | """ 241 | Build a single dataset group corresponding to Option 2 of data loading see arguments.py 242 | a dataset group is passed on the following form 243 | GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 244 | or alternatively 245 | GIVEN_NAME PATH1 # for a single dataset to be used fully 246 | """ 247 | 248 | assert train_valid_test in ["train", "valid", "test"] 249 | 250 | # Single dataset. 251 | if len(paths) == 1: 252 | dataset = _build_single_datasets( 253 | paths[0], 254 | splits[0], 255 | data_impl, 256 | train_valid_test_num_samples, 257 | seq_length, 258 | seed, 259 | skip_warmup, 260 | dataset_group_name, 261 | train_valid_test, 262 | ) 263 | return dataset 264 | # Blending dataset. 265 | else: 266 | data_prefix = [] 267 | # data_prefix is on the shape: 268 | # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] 269 | for w, p in zip(weights, paths): 270 | data_prefix += [w, p] 271 | 272 | output = get_datasets_weights_and_num_samples( 273 | data_prefix, train_valid_test_num_samples 274 | ) 275 | prefixes, weights, datasets_train_valid_test_num_samples = output 276 | 277 | # Build individual datasets. 278 | datasets = [] 279 | for i in range(len(prefixes)): 280 | ds = _build_single_datasets( 281 | prefixes[i], 282 | splits[i], 283 | data_impl, 284 | datasets_train_valid_test_num_samples[i], 285 | seq_length, 286 | seed, 287 | skip_warmup, 288 | dataset_group_name, 289 | train_valid_test, 290 | ) 291 | 292 | datasets.append(ds) 293 | all_datasets = BlendableDataset(datasets, weights) 294 | 295 | return all_datasets 296 | 297 | 298 | def _build_single_datasets( 299 | data_prefix, 300 | range_string, 301 | data_impl, 302 | train_valid_test_num_samples, 303 | seq_length, 304 | seed, 305 | skip_warmup, 306 | dataset_group_name, 307 | train_valid_test, 308 | ): 309 | """Build a single dataset""" 310 | 311 | assert train_valid_test in ["train", "valid", "test"] 312 | index = ["train", "valid", "test"].index(train_valid_test) 313 | 314 | # Indexed dataset. 
315 | indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) 316 | 317 | total_num_of_documents = indexed_dataset.sizes.shape[0] 318 | # this corresponds to option2 for data loading on the form 319 | # WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3 320 | # splits here is an array of size 2 [start_index, end_index] 321 | splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents) 322 | 323 | # Print stats about the splits. 324 | print_rank_0(" > dataset split:") 325 | 326 | print_rank_0(" {}:".format(dataset_group_name)) 327 | print_rank_0( 328 | " document indices in [{}, {}) total of {} documents".format( 329 | splits[0], splits[1], splits[1] - splits[0] 330 | ) 331 | ) 332 | 333 | def build_dataset(name): 334 | dataset = None 335 | if splits[1] > splits[0]: 336 | documents = np.arange( 337 | start=splits[0], stop=splits[1], step=1, dtype=np.int32 338 | ) 339 | dataset = GPTDataset( 340 | name, 341 | data_prefix, 342 | documents, 343 | indexed_dataset, 344 | train_valid_test_num_samples[index], 345 | seq_length, 346 | seed, 347 | ) 348 | return dataset 349 | 350 | dataset = build_dataset(dataset_group_name) 351 | 352 | return dataset 353 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/random.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | import contextlib 21 | 22 | import torch 23 | from megatron import get_args 24 | from megatron.memory import allocate_mem_buff 25 | from torch import _C 26 | from torch.cuda import _lazy_call 27 | from torch.cuda import device as device_ctx_manager 28 | from torch.utils.checkpoint import detach_variable 29 | 30 | from .initialize import (get_data_parallel_rank, 31 | get_tensor_model_parallel_group, 32 | get_tensor_model_parallel_rank, 33 | get_tensor_model_parallel_world_size) 34 | 35 | # Default name for the model parallel rng tracker. 36 | _MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng" 37 | 38 | 39 | # Whether apply model parallelsim to checkpointed hidden states. 
40 | _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None 41 | 42 | 43 | def init_checkpointed_activations_memory_buffer(): 44 | """Initialize the memory buffer for the checkpointed activations.""" 45 | args = get_args() 46 | 47 | upper_bound_sequence_length = max( 48 | args.seq_length if args.seq_length is not None else 0, 49 | args.decoder_seq_length if args.decoder_seq_length is not None else 0, 50 | ) 51 | per_layer = ( 52 | args.micro_batch_size 53 | * upper_bound_sequence_length 54 | * args.hidden_size 55 | // args.tensor_model_parallel_size 56 | ) 57 | assert ( 58 | args.num_layers % args.checkpoint_num_layers == 0 59 | ), "number of layers is not divisible by checkpoint-num-layers" 60 | num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers 61 | numel = per_layer * num_checkpointer_layers 62 | dtype = torch.half 63 | if not args.fp16: 64 | dtype = torch.float 65 | 66 | global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER 67 | assert ( 68 | _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None 69 | ), "checkpointed activations memory buffer is already allocated." 70 | _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff( 71 | "checkpointed activations", numel, dtype, track_usage=False 72 | ) 73 | 74 | 75 | def reset_checkpointed_activations_memory_buffer(): 76 | """Reset the memory used for checkpointing.""" 77 | if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: 78 | _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset() 79 | 80 | 81 | def _set_cuda_rng_state(new_state, device=-1): 82 | """Sets the random number generator state of the current GPU. 83 | 84 | Argumentss: 85 | new_state (torch.ByteTensor): The desired state 86 | This function is adapted from PyTorch repo (torch.cuda.set_rng_state) 87 | with a single change: the input state is not cloned. Cloning caused 88 | major performance issues for +4 GPU cases. 
89 | """ 90 | if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState): 91 | # older PyTorch 92 | def cb(): 93 | with device_ctx_manager(device): 94 | _C._cuda_setRNGState(new_state) 95 | 96 | else: 97 | # newer PyTorch 98 | if device == -1: 99 | device = torch.device("cuda") 100 | elif isinstance(device, str): 101 | device = torch.device(device) 102 | elif isinstance(device, int): 103 | device = torch.device("cuda", device) 104 | 105 | def cb(): 106 | idx = device.index 107 | if idx is None: 108 | idx = torch.cuda.current_device() 109 | default_generator = torch.cuda.default_generators[idx] 110 | default_generator.set_state(new_state) 111 | 112 | _lazy_call(cb) 113 | 114 | 115 | def split_tensor_into_1d_equal_chunks(tensor): 116 | """Break a tensor into equal 1D chunks.""" 117 | data = tensor.view(-1) 118 | partition_size = torch.numel(data) // get_tensor_model_parallel_world_size() 119 | start_index = partition_size * get_tensor_model_parallel_rank() 120 | end_index = start_index + partition_size 121 | return data[start_index:end_index] 122 | 123 | 124 | def gather_split_1d_tensor(tensor): 125 | """Opposite of above function, gather values from model parallel ranks.""" 126 | world_size = get_tensor_model_parallel_world_size() 127 | numel = torch.numel(tensor) 128 | numel_gathered = world_size * numel 129 | gathered = torch.empty( 130 | numel_gathered, 131 | dtype=tensor.dtype, 132 | device=torch.cuda.current_device(), 133 | requires_grad=False, 134 | ) 135 | chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] 136 | torch.distributed.all_gather( 137 | chunks, tensor, group=get_tensor_model_parallel_group() 138 | ) 139 | return gathered 140 | 141 | 142 | class CudaRNGStatesTracker: 143 | """Tracker for the cuda RNG states. 144 | 145 | Using the `add` method, a cuda rng state is initialized based on 146 | the input `seed` and is assigned to `name`. Later, by forking the 147 | rng state, we can perform operations and return to our starting 148 | cuda state. 149 | """ 150 | 151 | def __init__(self): 152 | # Map from a string name to the cuda rng state. 153 | self.states_ = {} 154 | # Seeds are just for book keeping and ensure no seed is set twice. 155 | self.seeds_ = set() 156 | 157 | def reset(self): 158 | """Set to the initial state (no tracker).""" 159 | self.states_ = {} 160 | self.seeds_ = set() 161 | 162 | def get_states(self): 163 | """Get rng states. Copy the dictionary so we have direct 164 | pointers to the states, not just a pointer to the dictionary.""" 165 | states = {} 166 | for name in self.states_: 167 | states[name] = self.states_[name] 168 | return states 169 | 170 | def set_states(self, states): 171 | """Set the rng states. For efficiency purposes, we do not check 172 | the size of seed for compatibility.""" 173 | self.states_ = states 174 | 175 | def add(self, name, seed): 176 | """Track the rng state.""" 177 | # Check seed is not already used. 178 | if seed in self.seeds_: 179 | raise Exception("seed {} already exists".format(seed)) 180 | self.seeds_.add(seed) 181 | # Check that state is not already defined. 182 | if name in self.states_: 183 | raise Exception("cuda rng state {} already exists".format(name)) 184 | # Get the current rng state. 185 | orig_rng_state = torch.cuda.get_rng_state() 186 | # Set the new state and store it. 187 | torch.cuda.manual_seed(seed) 188 | self.states_[name] = torch.cuda.get_rng_state() 189 | # Reset rng state to what it was. 
190 | _set_cuda_rng_state(orig_rng_state) 191 | 192 | @contextlib.contextmanager 193 | def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): 194 | """Fork the cuda rng state, perform operations, and exit with 195 | the original state.""" 196 | # Check if we have added the state 197 | if name not in self.states_: 198 | print(name, self.states_) 199 | raise Exception("cuda rng state {} is not added".format(name)) 200 | # Store current rng state. 201 | orig_cuda_rng_state = torch.cuda.get_rng_state() 202 | # Set rng state to the desired one 203 | _set_cuda_rng_state(self.states_[name]) 204 | # Do the stuff we wanted to do. 205 | try: 206 | yield 207 | finally: 208 | # Update the current rng state for later use. 209 | self.states_[name] = torch.cuda.get_rng_state() 210 | # And set the state to the original state we started with. 211 | _set_cuda_rng_state(orig_cuda_rng_state) 212 | 213 | 214 | # RNG tracker object. 215 | _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() 216 | 217 | 218 | def get_cuda_rng_tracker(): 219 | """Get cuda rng tracker.""" 220 | return _CUDA_RNG_STATE_TRACKER 221 | 222 | 223 | def model_parallel_cuda_manual_seed(seed): 224 | """Initialize model parallel cuda seed. 225 | 226 | This function should be called after the model parallel is 227 | initialized. Also, no torch.cuda.manual_seed should be called 228 | after this function. Basically, this is replacement for that 229 | function. 230 | Two set of RNG states are tracked: 231 | default state: This is for data parallelism and is the same among a 232 | set of model parallel GPUs but different across 233 | different model paralle groups. This is used for 234 | example for dropout in the non-tensor-model-parallel regions. 235 | tensor-model-parallel state: This state is different among a set of model 236 | parallel GPUs, but the same across data parallel 237 | groups. This is used for example for dropout in 238 | model parallel regions. 239 | """ 240 | # 2718 is just for fun and any POSITIVE value will work. 241 | offset = seed + 2718 242 | tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() 243 | # Data parallel gets the original seed. 244 | data_parallel_seed = seed 245 | 246 | if torch.distributed.get_rank() == 0: 247 | print( 248 | "> initializing model parallel cuda seeds on global rank {}, " 249 | "model parallel rank {}, and data parallel rank {} with " 250 | "model parallel seed: {} and data parallel seed: {}".format( 251 | torch.distributed.get_rank(), 252 | get_tensor_model_parallel_rank(), 253 | get_data_parallel_rank(), 254 | tensor_model_parallel_seed, 255 | data_parallel_seed, 256 | ), 257 | flush=True, 258 | ) 259 | _CUDA_RNG_STATE_TRACKER.reset() 260 | # Set the default state. 261 | torch.cuda.manual_seed(data_parallel_seed) 262 | # and model parallel state. 263 | _CUDA_RNG_STATE_TRACKER.add( 264 | _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed 265 | ) 266 | 267 | 268 | class CheckpointFunction(torch.autograd.Function): 269 | """This function is adapted from torch.utils.checkpoint with 270 | two main changes: 271 | 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` 272 | 2) the states in the model parallel tracker are also properly 273 | tracked/set/reset. 274 | """ 275 | 276 | @staticmethod 277 | def forward(ctx, run_function, *args): 278 | ctx.run_function = run_function 279 | 280 | # Copy the rng states. 
281 | ctx.fwd_cpu_rng_state = torch.get_rng_state() 282 | ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() 283 | ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() 284 | 285 | with torch.no_grad(): 286 | outputs = run_function(*args) 287 | 288 | # Divide hidden states across model parallel group and only keep 289 | # the chunk corresponding to the current rank. 290 | if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: 291 | ctx.input_0_shape = args[0].data.shape 292 | args[0].data = split_tensor_into_1d_equal_chunks(args[0].data) 293 | args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data) 294 | 295 | # Store everything. 296 | ctx.save_for_backward(*args) 297 | 298 | return outputs 299 | 300 | @staticmethod 301 | def backward(ctx, *args): 302 | if not torch.autograd._is_checkpoint_valid(): 303 | raise RuntimeError( 304 | "Checkpointing is not compatible with .grad(), " 305 | "please use .backward() if possible" 306 | ) 307 | inputs = ctx.saved_tensors 308 | if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: 309 | inputs[0].data = gather_split_1d_tensor(inputs[0].data) 310 | inputs[0].data = inputs[0].data.view(ctx.input_0_shape) 311 | 312 | # Store the current states. 313 | bwd_cpu_rng_state = torch.get_rng_state() 314 | bwd_cuda_rng_state = torch.cuda.get_rng_state() 315 | bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() 316 | 317 | # Set the states to what it used to be before the forward pass. 318 | torch.set_rng_state(ctx.fwd_cpu_rng_state) 319 | _set_cuda_rng_state(ctx.fwd_cuda_rng_state) 320 | get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) 321 | 322 | # Compute the forward pass. 323 | detached_inputs = detach_variable(inputs) 324 | with torch.enable_grad(): 325 | outputs = ctx.run_function(*detached_inputs) 326 | 327 | # Set the states back to what it was at the start of this function. 328 | torch.set_rng_state(bwd_cpu_rng_state) 329 | _set_cuda_rng_state(bwd_cuda_rng_state) 330 | get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) 331 | 332 | if isinstance(outputs, torch.Tensor): 333 | outputs = (outputs,) 334 | torch.autograd.backward(outputs, args) 335 | grads = tuple( 336 | inp.grad if isinstance(inp, torch.Tensor) else inp 337 | for inp in detached_inputs 338 | ) 339 | return (None,) + grads 340 | 341 | 342 | def checkpoint(function, *args): 343 | """Checkpoint a model or part of the model. 
344 | This has been directly copied from torch.utils.checkpoint.""" 345 | return CheckpointFunction.apply(function, *args) 346 | -------------------------------------------------------------------------------- /distill_bloom/dataset/gpt_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .megatron import mpu 8 | 9 | 10 | def print_rank_0(message): 11 | """If distributed is initialized, print only on rank 0.""" 12 | if torch.distributed.is_initialized(): 13 | if torch.distributed.get_rank() == 0: 14 | print(message, flush=True) 15 | else: 16 | print(message, flush=True) 17 | 18 | 19 | class GPTDataset(torch.utils.data.Dataset): 20 | def __init__( 21 | self, 22 | name, 23 | data_prefix, 24 | documents, 25 | indexed_dataset, 26 | num_samples, 27 | seq_length, 28 | seed, 29 | ): 30 | self.name = name 31 | self.indexed_dataset = indexed_dataset 32 | 33 | # Checks 34 | assert np.min(documents) >= 0 35 | assert np.max(documents) < indexed_dataset.sizes.shape[0] 36 | 37 | # Build index mappings. 38 | self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( 39 | self.name, 40 | data_prefix, 41 | documents, 42 | self.indexed_dataset.sizes, 43 | num_samples, 44 | seq_length, 45 | seed, 46 | ) 47 | 48 | def __len__(self): 49 | # -1 is due to data structure used to retieve the index: 50 | # sample i --> [sample_idx[i], sample_idx[i+1]) 51 | return self.sample_idx.shape[0] - 1 52 | 53 | def __getitem__(self, idx): 54 | # Get the shuffled index. 55 | idx = self.shuffle_idx[idx] 56 | # Start and end documents and offsets. 57 | doc_index_f = self.sample_idx[idx][0] 58 | doc_index_l = self.sample_idx[idx + 1][0] 59 | offset_f = self.sample_idx[idx][1] 60 | offset_l = self.sample_idx[idx + 1][1] 61 | # If we are within the same document, just extract the chunk. 62 | if doc_index_f == doc_index_l: 63 | sample = self.indexed_dataset.get( 64 | self.doc_idx[doc_index_f], 65 | offset=offset_f, 66 | length=offset_l - offset_f + 1, 67 | ) 68 | else: 69 | # Otherwise, get the rest of the initial document. 70 | sample_list = [ 71 | self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) 72 | ] 73 | # Loop over all in between documents and add the entire document. 74 | for i in range(doc_index_f + 1, doc_index_l): 75 | sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) 76 | # And finally add the relevant portion of last document. 77 | sample_list.append( 78 | self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) 79 | ) 80 | sample = np.concatenate(sample_list) 81 | 82 | return {"text": np.array(sample, dtype=np.int64)} 83 | 84 | 85 | def _build_index_mappings( 86 | name, 87 | data_prefix, 88 | documents, 89 | sizes, 90 | num_samples, 91 | seq_length, 92 | seed, 93 | cutoff_last_epoch=0.95, 94 | ): 95 | """Build doc-idx, sample-idx, and shuffle-idx. 96 | doc-idx: is an array (ordered) of documents to be used in training. 97 | sample-idx: is the start document index and document offset for each 98 | training sample. 99 | shuffle-idx: maps the sample index into a random index into sample-idx. 100 | """ 101 | # Number of tokens in each epoch and number of required epochs. 102 | tokens_per_epoch = _num_tokens(documents, sizes) 103 | num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) 104 | # rng state 105 | np_rng = np.random.RandomState(seed=seed) 106 | 107 | # Filename of the index mappings. 
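    # Illustrative example of the cache-file naming built just below (the prefix and
    # parameter values here are hypothetical): with data_prefix="/data/c4_text_document",
    # name="train", num_samples=1000, seq_length=2048 and seed=1234 the three files are
    #     /data/c4_text_document_train_indexmap_1000ns_2048sl_1234s_doc_idx.npy
    #     /data/c4_text_document_train_indexmap_1000ns_2048sl_1234s_sample_idx.npy
    #     /data/c4_text_document_train_indexmap_1000ns_2048sl_1234s_shuffle_idx.npy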
108 | _filename = data_prefix 109 | _filename += "_{}_indexmap".format(name) 110 | _filename += "_{}ns".format(num_samples) 111 | _filename += "_{}sl".format(seq_length) 112 | _filename += "_{}s".format(seed) 113 | doc_idx_filename = _filename + "_doc_idx.npy" 114 | sample_idx_filename = _filename + "_sample_idx.npy" 115 | shuffle_idx_filename = _filename + "_shuffle_idx.npy" 116 | 117 | # Build the indexed mapping if not exist. 118 | if torch.distributed.get_rank() == 0: 119 | if ( 120 | (not os.path.isfile(doc_idx_filename)) 121 | or (not os.path.isfile(sample_idx_filename)) 122 | or (not os.path.isfile(shuffle_idx_filename)) 123 | ): 124 | print_rank_0( 125 | " > WARNING: could not find index map files, building " 126 | "the indices on rank 0 ..." 127 | ) 128 | 129 | # For the last epoch, decide whether include the entire epoch 130 | # in the global shuffle or not. 131 | 132 | # If we need only one epoch, then separating last epoch does 133 | # not mean anything. 134 | if num_epochs == 1: 135 | separate_last_epoch = False 136 | print( 137 | " > only one epoch required, setting separate_last_epoch to False", 138 | flush=True, 139 | ) 140 | 141 | else: 142 | # Get the number of samples for the last epoch 143 | num_samples_from_epochs_minus_one = ( 144 | (num_epochs - 1) * tokens_per_epoch - 1 145 | ) // seq_length 146 | last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one 147 | assert last_epoch_num_samples >= 0, ( 148 | f"last epoch number of samples {last_epoch_num_samples} should be" 149 | " non-negative." 150 | ) 151 | num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length 152 | assert last_epoch_num_samples <= num_samples_per_epoch, ( 153 | f"last epoch number of samples {last_epoch_num_samples} exceeded" 154 | f" max value {num_samples_per_epoch}." 155 | ) 156 | # If we have less than cutoff_last_epoch * samples_per_epoch of the samples for the last epoch, 157 | # seperate out the epoch and treat it differently. 158 | separate_last_epoch = last_epoch_num_samples < int( 159 | cutoff_last_epoch * num_samples_per_epoch 160 | ) 161 | if separate_last_epoch: 162 | string = ( 163 | " > last epoch number of samples ({}) is smaller " 164 | "than {}% of number of samples per epoch ({}), " 165 | "setting separate_last_epoch to True" 166 | ) 167 | else: 168 | string = ( 169 | " > last epoch number of samples ({}) is larger " 170 | "than {}% of number of samples per epoch ({}), " 171 | "setting separate_last_epoch to False" 172 | ) 173 | print( 174 | string.format( 175 | last_epoch_num_samples, 176 | cutoff_last_epoch * 100, 177 | num_samples_per_epoch, 178 | ), 179 | flush=True, 180 | ) 181 | 182 | # doc-idx. 183 | start_time = time.time() 184 | doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) 185 | np.save(doc_idx_filename, doc_idx, allow_pickle=True) 186 | print_rank_0( 187 | " > elasped time to build and save doc-idx mapping " 188 | "(seconds): {:4f}".format(time.time() - start_time) 189 | ) 190 | # sample-idx. 191 | start_time = time.time() 192 | # Use C++ implementation for speed. 193 | # First compile and then import. 
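            # Note (assumption): the `helpers` pybind11 extension imported below is presumably
            # built from distill_bloom/dataset/megatron/helpers.cpp with the Makefile in that
            # directory, e.g. roughly:
            #     cd distill_bloom/dataset/megatron && make
            # If the extension has not been built, the import raises ImportError.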
194 | from .megatron import helpers 195 | 196 | assert doc_idx.dtype == np.int32 197 | assert sizes.dtype == np.int32 198 | sample_idx = helpers.build_sample_idx( 199 | sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch 200 | ) 201 | # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, 202 | # num_epochs, tokens_per_epoch) 203 | np.save(sample_idx_filename, sample_idx, allow_pickle=True) 204 | print_rank_0( 205 | " > elapsed time to build and save sample-idx mapping " 206 | "(seconds): {:4f}".format(time.time() - start_time) 207 | ) 208 | # shuffle-idx. 209 | start_time = time.time() 210 | # -1 is due to data structure used to retrieve the index: 211 | # sample i --> [sample_idx[i], sample_idx[i+1]) 212 | if separate_last_epoch: 213 | num_samples_ = num_samples_from_epochs_minus_one 214 | else: 215 | num_samples_ = sample_idx.shape[0] - 1 216 | shuffle_idx = _build_shuffle_idx( 217 | num_samples_, sample_idx.shape[0] - 1, np_rng 218 | ) 219 | np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) 220 | print_rank_0( 221 | " > elapsed time to build and save shuffle-idx mapping" 222 | " (seconds): {:4f}".format(time.time() - start_time) 223 | ) 224 | 225 | # This should be a barrier but nccl barrier assumes 226 | # device_index=rank which is not the case for model 227 | # parallel case 228 | # counts = torch.cuda.LongTensor([1]) 229 | # torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) 230 | # torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) 231 | # assert counts[0].item() == ( 232 | # torch.distributed.get_world_size() // 233 | # torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) 234 | 235 | # Load mappings. 236 | start_time = time.time() 237 | print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) 238 | doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") 239 | print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename)) 240 | sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") 241 | print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) 242 | shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") 243 | print_rank_0( 244 | " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) 245 | ) 246 | print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) 247 | print_rank_0(" total number of epochs: {}".format(num_epochs)) 248 | 249 | return doc_idx, sample_idx, shuffle_idx 250 | 251 | 252 | def _num_tokens(documents, sizes): 253 | """Total number of tokens in the dataset.""" 254 | return np.sum(sizes[documents]) 255 | 256 | 257 | def _num_epochs(tokens_per_epoch, seq_length, num_samples): 258 | """Based on number of samples and sequence length, calculate how many 259 | epochs will be needed.""" 260 | num_epochs = 0 261 | total_tokens = 0 262 | while True: 263 | num_epochs += 1 264 | total_tokens += tokens_per_epoch 265 | # -1 is because we need to retrieve seq_length + 1 token each time 266 | # but the last token will overlap with the first token of the next 267 | # sample except for the last sample. 268 | if ((total_tokens - 1) // seq_length) >= num_samples: 269 | return num_epochs 270 | 271 | 272 | def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): 273 | """Build an array with length = number-of-epochs * number-of-documents.
274 | Each index is mapped to a corresponding document.""" 275 | if not separate_last_epoch or num_epochs == 1: 276 | doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] 277 | doc_idx[:] = documents 278 | doc_idx = doc_idx.reshape(-1) 279 | doc_idx = doc_idx.astype(np.int32) 280 | np_rng.shuffle(doc_idx) 281 | return doc_idx 282 | 283 | doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) 284 | doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) 285 | return np.concatenate((doc_idx_first, doc_idx_last)) 286 | 287 | 288 | def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): 289 | """Sample index mapping is a 2D array with sizes 290 | [number-of-samples + 1, 2] where [..., 0] contains 291 | the index into `doc_idx` and [..., 1] is the 292 | starting offset in that document.""" 293 | 294 | # Total number of samples. For -1 see comments in `_num_epochs`. 295 | num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length 296 | sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) 297 | 298 | # Index into sample_idx. 299 | sample_index = 0 300 | # Index into doc_idx. 301 | doc_idx_index = 0 302 | # Begining offset for each document. 303 | doc_offset = 0 304 | # Start with first document and no offset. 305 | sample_idx[sample_index][0] = doc_idx_index 306 | sample_idx[sample_index][1] = doc_offset 307 | sample_index += 1 308 | while sample_index <= num_samples: 309 | # Start with a fresh sequence. 310 | remaining_seq_length = seq_length + 1 311 | while remaining_seq_length != 0: 312 | # Get the document length. 313 | doc_id = doc_idx[doc_idx_index] 314 | doc_length = sizes[doc_id] - doc_offset 315 | # And add it to the current sequence. 316 | remaining_seq_length -= doc_length 317 | # If we have more than a full sequence, adjust offset and set 318 | # remaining length to zero so we return from the while loop. 319 | # Note that -1 here is for the same reason we have -1 in 320 | # `_num_epochs` calculations. 321 | if remaining_seq_length <= 0: 322 | doc_offset += remaining_seq_length + doc_length - 1 323 | remaining_seq_length = 0 324 | else: 325 | # Otherwise, start from the begining of the next document. 326 | doc_idx_index += 1 327 | doc_offset = 0 328 | # Record the sequence. 329 | sample_idx[sample_index][0] = doc_idx_index 330 | sample_idx[sample_index][1] = doc_offset 331 | sample_index += 1 332 | 333 | return sample_idx 334 | 335 | 336 | def _build_shuffle_idx(num_samples, total_size, np_rng): 337 | """Build the range [0, size) and shuffle.""" 338 | print( 339 | " > building shuffle index with split [0, {}) and [{}, {}) ...".format( 340 | num_samples, num_samples, total_size 341 | ), 342 | flush=True, 343 | ) 344 | 345 | dtype_ = np.uint32 346 | if total_size >= (np.iinfo(np.uint32).max - 1): 347 | dtype_ = np.int64 348 | 349 | shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) 350 | np_rng.shuffle(shuffle_idx_first) 351 | if num_samples == total_size: 352 | return shuffle_idx_first 353 | 354 | shuffle_idx_last = np.arange( 355 | start=num_samples, stop=total_size, step=1, dtype=dtype_ 356 | ) 357 | np_rng.shuffle(shuffle_idx_last) 358 | 359 | return np.concatenate((shuffle_idx_first, shuffle_idx_last)) 360 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. 
All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """Model and data parallel groups.""" 18 | 19 | import torch 20 | 21 | from .utils import ensure_divisibility 22 | 23 | # Intra-layer model parallel group that the current rank belongs to. 24 | _TENSOR_MODEL_PARALLEL_GROUP = None 25 | # Inter-layer model parallel group that the current rank belongs to. 26 | _PIPELINE_MODEL_PARALLEL_GROUP = None 27 | # Model parallel group (both intra- and pipeline) that the current rank belongs to. 28 | _MODEL_PARALLEL_GROUP = None 29 | # Embedding group. 30 | _EMBEDDING_GROUP = None 31 | # Data parallel group that the current rank belongs to. 32 | _DATA_PARALLEL_GROUP = None 33 | 34 | _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None 35 | _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None 36 | 37 | # These values enable us to change the mpu sizes on the fly. 38 | _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None 39 | _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None 40 | _MPU_TENSOR_MODEL_PARALLEL_RANK = None 41 | _MPU_PIPELINE_MODEL_PARALLEL_RANK = None 42 | 43 | # A list of global ranks for each pipeline group to ease calculation of the source 44 | # rank when broadcasting from the first or last pipeline stage 45 | _PIPELINE_GLOBAL_RANKS = None 46 | 47 | 48 | def is_unitialized(): 49 | """Useful for code segments that may be accessed with or without mpu initialization 50 | """ 51 | return _DATA_PARALLEL_GROUP is None 52 | 53 | 54 | def initialize_model_parallel( 55 | tensor_model_parallel_size_=1, 56 | pipeline_model_parallel_size_=1, 57 | virtual_pipeline_model_parallel_size_=None, 58 | ): 59 | """ 60 | Initialize model data parallel groups. 61 | 62 | Arguments: 63 | tensor_model_parallel_size: number of GPUs used to parallelize model tensor. 64 | pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. 65 | 66 | Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we 67 | use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize 68 | the model pipeline. The present function will 69 | create 8 tensor model-parallel groups, 4 pipeline model-parallel groups 70 | and 8 data-parallel groups as: 71 | 8 data_parallel groups: 72 | [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] 73 | 8 tensor model-parallel groups: 74 | [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] 75 | 4 pipeline model-parallel groups: 76 | [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] 77 | Note that for efficiency, the caller should make sure adjacent ranks 78 | are on the same DGX box. For example if we are using 2 DGX-1 boxes 79 | with a total of 16 GPUs, rank 0 to 7 belong to the first box and 80 | ranks 8 to 15 belong to the second box. 
81 | """ 82 | if torch.distributed.get_rank() == 0: 83 | print( 84 | "> initializing tensor model parallel with size {}".format( 85 | tensor_model_parallel_size_ 86 | ) 87 | ) 88 | print( 89 | "> initializing pipeline model parallel with size {}".format( 90 | pipeline_model_parallel_size_ 91 | ) 92 | ) 93 | # Get world size and rank. Ensure some consistencies. 94 | assert torch.distributed.is_initialized() 95 | world_size = torch.distributed.get_world_size() 96 | tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size) 97 | pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size) 98 | ensure_divisibility( 99 | world_size, tensor_model_parallel_size * pipeline_model_parallel_size 100 | ) 101 | data_parallel_size = world_size // ( 102 | tensor_model_parallel_size * pipeline_model_parallel_size 103 | ) 104 | 105 | num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size 106 | num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size 107 | num_data_parallel_groups = world_size // data_parallel_size 108 | 109 | if virtual_pipeline_model_parallel_size_ is not None: 110 | global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK 111 | global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE 112 | _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 113 | _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = ( 114 | virtual_pipeline_model_parallel_size_ 115 | ) 116 | 117 | rank = torch.distributed.get_rank() 118 | 119 | # Build the data-parallel groups. 120 | global _DATA_PARALLEL_GROUP 121 | assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" 122 | all_data_parallel_group_ranks = [] 123 | for i in range(pipeline_model_parallel_size): 124 | start_rank = i * num_pipeline_model_parallel_groups 125 | end_rank = (i + 1) * num_pipeline_model_parallel_groups 126 | for j in range(tensor_model_parallel_size): 127 | ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) 128 | all_data_parallel_group_ranks.append(list(ranks)) 129 | group = torch.distributed.new_group(ranks) 130 | if rank in ranks: 131 | _DATA_PARALLEL_GROUP = group 132 | 133 | # Build the model-parallel groups. 134 | global _MODEL_PARALLEL_GROUP 135 | assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" 136 | for i in range(data_parallel_size): 137 | ranks = [ 138 | data_parallel_group_ranks[i] 139 | for data_parallel_group_ranks in all_data_parallel_group_ranks 140 | ] 141 | group = torch.distributed.new_group(ranks) 142 | if rank in ranks: 143 | _MODEL_PARALLEL_GROUP = group 144 | 145 | # Build the tensor model-parallel groups. 146 | global _TENSOR_MODEL_PARALLEL_GROUP 147 | assert ( 148 | _TENSOR_MODEL_PARALLEL_GROUP is None 149 | ), "tensor model parallel group is already initialized" 150 | for i in range(num_tensor_model_parallel_groups): 151 | ranks = range( 152 | i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size 153 | ) 154 | group = torch.distributed.new_group(ranks) 155 | if rank in ranks: 156 | _TENSOR_MODEL_PARALLEL_GROUP = group 157 | 158 | # Build the pipeline model-parallel groups and embedding groups 159 | # (first and last rank in each pipeline model-parallel group). 
160 | global _PIPELINE_MODEL_PARALLEL_GROUP 161 | global _PIPELINE_GLOBAL_RANKS 162 | assert ( 163 | _PIPELINE_MODEL_PARALLEL_GROUP is None 164 | ), "pipeline model parallel group is already initialized" 165 | global _EMBEDDING_GROUP 166 | assert _EMBEDDING_GROUP is None, "embedding group is already initialized" 167 | for i in range(num_pipeline_model_parallel_groups): 168 | ranks = range(i, world_size, num_pipeline_model_parallel_groups) 169 | group = torch.distributed.new_group(ranks) 170 | if rank in ranks: 171 | _PIPELINE_MODEL_PARALLEL_GROUP = group 172 | _PIPELINE_GLOBAL_RANKS = ranks 173 | # Setup embedding group (to exchange gradients between 174 | # first and last stages). 175 | if len(ranks) > 1: 176 | embedding_ranks = [ranks[0], ranks[-1]] 177 | else: 178 | embedding_ranks = ranks 179 | group = torch.distributed.new_group(embedding_ranks) 180 | if rank in embedding_ranks: 181 | _EMBEDDING_GROUP = group 182 | 183 | 184 | def model_parallel_is_initialized(): 185 | """Check if model and data parallel groups are initialized.""" 186 | if ( 187 | _TENSOR_MODEL_PARALLEL_GROUP is None 188 | or _PIPELINE_MODEL_PARALLEL_GROUP is None 189 | or _DATA_PARALLEL_GROUP is None 190 | ): 191 | return False 192 | return True 193 | 194 | 195 | def get_model_parallel_group(): 196 | """Get the model parallel group the caller rank belongs to.""" 197 | assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" 198 | return _MODEL_PARALLEL_GROUP 199 | 200 | 201 | def get_tensor_model_parallel_group(): 202 | """Get the tensor model parallel group the caller rank belongs to.""" 203 | assert ( 204 | _TENSOR_MODEL_PARALLEL_GROUP is not None 205 | ), "intra_layer_model parallel group is not initialized" 206 | return _TENSOR_MODEL_PARALLEL_GROUP 207 | 208 | 209 | def get_pipeline_model_parallel_group(): 210 | """Get the pipeline model parallel group the caller rank belongs to.""" 211 | assert ( 212 | _PIPELINE_MODEL_PARALLEL_GROUP is not None 213 | ), "pipeline_model parallel group is not initialized" 214 | return _PIPELINE_MODEL_PARALLEL_GROUP 215 | 216 | 217 | def get_data_parallel_group(): 218 | """Get the data parallel group the caller rank belongs to.""" 219 | assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" 220 | return _DATA_PARALLEL_GROUP 221 | 222 | 223 | def get_embedding_group(): 224 | """Get the embedding group the caller rank belongs to.""" 225 | assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" 226 | return _EMBEDDING_GROUP 227 | 228 | 229 | def set_tensor_model_parallel_world_size(world_size): 230 | """Set the tensor model parallel size""" 231 | global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE 232 | _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size 233 | 234 | 235 | def set_pipeline_model_parallel_world_size(world_size): 236 | """Set the pipeline model parallel size""" 237 | global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE 238 | _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size 239 | 240 | 241 | def get_tensor_model_parallel_world_size(): 242 | """Return world size for the tensor model parallel group.""" 243 | global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE 244 | if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: 245 | return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE 246 | return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) 247 | 248 | 249 | def get_model_parallel_world_size(): 250 | assert ( 251 | get_pipeline_model_parallel_world_size() == 1 252 | ), "legacy get_model_parallel_world_size 
is only supported if PP is disabled" 253 | return get_tensor_model_parallel_world_size() 254 | 255 | 256 | def get_pipeline_model_parallel_world_size(): 257 | """Return world size for the pipeline model parallel group.""" 258 | global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE 259 | if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: 260 | return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE 261 | return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) 262 | 263 | 264 | def set_tensor_model_parallel_rank(rank): 265 | """Set tensor model parallel rank.""" 266 | global _MPU_TENSOR_MODEL_PARALLEL_RANK 267 | _MPU_TENSOR_MODEL_PARALLEL_RANK = rank 268 | 269 | 270 | def set_pipeline_model_parallel_rank(rank): 271 | """Set pipeline model parallel rank.""" 272 | global _MPU_PIPELINE_MODEL_PARALLEL_RANK 273 | _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank 274 | 275 | 276 | def get_tensor_model_parallel_rank(): 277 | """Return my rank for the tensor model parallel group.""" 278 | global _MPU_TENSOR_MODEL_PARALLEL_RANK 279 | if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: 280 | return _MPU_TENSOR_MODEL_PARALLEL_RANK 281 | return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) 282 | 283 | 284 | def get_model_parallel_rank(): 285 | assert ( 286 | get_pipeline_model_parallel_world_size() == 1 287 | ), "legacy get_model_parallel_rank is only supported if PP is disabled" 288 | return get_tensor_model_parallel_rank() 289 | 290 | 291 | def get_pipeline_model_parallel_rank(): 292 | """Return my rank for the pipeline model parallel group.""" 293 | global _MPU_PIPELINE_MODEL_PARALLEL_RANK 294 | if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: 295 | return _MPU_PIPELINE_MODEL_PARALLEL_RANK 296 | return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) 297 | 298 | 299 | def is_pipeline_first_stage(ignore_virtual=False): 300 | """Return True if in the first pipeline model-parallel stage, False otherwise.""" 301 | if not ignore_virtual: 302 | if ( 303 | get_virtual_pipeline_model_parallel_world_size() is not None 304 | and get_virtual_pipeline_model_parallel_rank() != 0 305 | ): 306 | return False 307 | return get_pipeline_model_parallel_rank() == 0 308 | 309 | 310 | def is_pipeline_last_stage(ignore_virtual=False): 311 | """Return True if in the last pipeline model-parallel stage, False otherwise.""" 312 | if not ignore_virtual: 313 | virtual_pipeline_model_parallel_world_size = ( 314 | get_virtual_pipeline_model_parallel_world_size() 315 | ) 316 | if ( 317 | virtual_pipeline_model_parallel_world_size is not None 318 | and get_virtual_pipeline_model_parallel_rank() 319 | != (virtual_pipeline_model_parallel_world_size - 1) 320 | ): 321 | return False 322 | return get_pipeline_model_parallel_rank() == ( 323 | get_pipeline_model_parallel_world_size() - 1 324 | ) 325 | 326 | 327 | def get_virtual_pipeline_model_parallel_rank(): 328 | """Return the virtual pipeline-parallel rank.""" 329 | global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK 330 | return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK 331 | 332 | 333 | def set_virtual_pipeline_model_parallel_rank(rank): 334 | """Set the virtual pipeline-parallel rank.""" 335 | global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK 336 | _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank 337 | 338 | 339 | def get_virtual_pipeline_model_parallel_world_size(): 340 | """Return the virtual pipeline-parallel world size.""" 341 | global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE 342 | return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE 343 | 344 | 345 | 
def get_tensor_model_parallel_src_rank(): 346 | """Calculate the global rank corresponding to the first local rank 347 | in the tensor model parallel group.""" 348 | global_rank = torch.distributed.get_rank() 349 | local_world_size = get_tensor_model_parallel_world_size() 350 | return (global_rank // local_world_size) * local_world_size 351 | 352 | 353 | def get_pipeline_model_parallel_first_rank(): 354 | assert ( 355 | _PIPELINE_GLOBAL_RANKS is not None 356 | ), "Pipeline parallel group is not initialized" 357 | return _PIPELINE_GLOBAL_RANKS[0] 358 | 359 | 360 | def get_pipeline_model_parallel_last_rank(): 361 | assert ( 362 | _PIPELINE_GLOBAL_RANKS is not None 363 | ), "Pipeline parallel group is not initialized" 364 | last_rank_local = get_pipeline_model_parallel_world_size() - 1 365 | return _PIPELINE_GLOBAL_RANKS[last_rank_local] 366 | 367 | 368 | def get_pipeline_model_parallel_next_rank(): 369 | assert ( 370 | _PIPELINE_GLOBAL_RANKS is not None 371 | ), "Pipeline parallel group is not initialized" 372 | rank_in_pipeline = get_pipeline_model_parallel_rank() 373 | world_size = get_pipeline_model_parallel_world_size() 374 | return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] 375 | 376 | 377 | def get_pipeline_model_parallel_prev_rank(): 378 | assert ( 379 | _PIPELINE_GLOBAL_RANKS is not None 380 | ), "Pipeline parallel group is not initialized" 381 | rank_in_pipeline = get_pipeline_model_parallel_rank() 382 | world_size = get_pipeline_model_parallel_world_size() 383 | return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] 384 | 385 | 386 | def get_data_parallel_world_size(): 387 | """Return world size for the data parallel group.""" 388 | return torch.distributed.get_world_size(group=get_data_parallel_group()) 389 | 390 | 391 | def get_data_parallel_rank(): 392 | """Return my rank for the data parallel group.""" 393 | return torch.distributed.get_rank(group=get_data_parallel_group()) 394 | 395 | 396 | def destroy_model_parallel(): 397 | """Set the groups to none.""" 398 | global _TENSOR_MODEL_PARALLEL_GROUP 399 | _TENSOR_MODEL_PARALLEL_GROUP = None 400 | global _PIPELINE_MODEL_PARALLEL_GROUP 401 | _PIPELINE_MODEL_PARALLEL_GROUP = None 402 | global _DATA_PARALLEL_GROUP 403 | _DATA_PARALLEL_GROUP = None 404 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | 21 | import math 22 | from functools import partial 23 | 24 | import deepspeed.runtime.activation_checkpointing.checkpointing as ds_checkpointing 25 | import torch 26 | import torch.nn.functional as F 27 | import torch.nn.init as init 28 | from megatron import get_args, mpu 29 | from torch.nn.parameter import Parameter 30 | 31 | from ..model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm 32 | from .initialize import (get_tensor_model_parallel_rank, 33 | get_tensor_model_parallel_world_size) 34 | from .mappings import (copy_to_tensor_model_parallel_region, 35 | gather_from_tensor_model_parallel_region, 36 | reduce_from_tensor_model_parallel_region, 37 | scatter_to_tensor_model_parallel_region) 38 | from .random import get_cuda_rng_tracker 39 | from .utils import VocabUtility, divide, split_tensor_along_last_dim 40 | 41 | _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { 42 | "tensor_model_parallel": False, 43 | "partition_dim": -1, 44 | "partition_stride": 1, 45 | } 46 | 47 | 48 | def param_is_not_tensor_parallel_duplicate(param): 49 | return ( 50 | hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel 51 | ) or (get_tensor_model_parallel_rank() == 0) 52 | 53 | 54 | def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): 55 | # Make sure the attributes are not set. 56 | for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: 57 | assert not hasattr(tensor, attribute) 58 | # Set the attributes. 59 | setattr(tensor, "tensor_model_parallel", is_parallel) 60 | setattr(tensor, "partition_dim", dim) 61 | setattr(tensor, "partition_stride", stride) 62 | 63 | 64 | def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): 65 | def maybe_set(attribute, value): 66 | if not hasattr(tensor, attribute): 67 | setattr(tensor, attribute, value) 68 | 69 | for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: 70 | maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) 71 | 72 | 73 | def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): 74 | def maybe_copy(attribute): 75 | if hasattr(source_tensor, attribute): 76 | setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) 77 | 78 | for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: 79 | maybe_copy(attribute) 80 | 81 | 82 | def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): 83 | """Initialize affine weight for model parallel on GPU.""" 84 | 85 | set_tensor_model_parallel_attributes( 86 | tensor=weight, is_parallel=True, dim=partition_dim, stride=stride 87 | ) 88 | 89 | if ds_checkpointing.is_configured(): 90 | global get_cuda_rng_tracker 91 | get_cuda_rng_tracker = ds_checkpointing.get_cuda_rng_tracker 92 | 93 | with get_cuda_rng_tracker().fork(): 94 | init_method(weight) 95 | 96 | 97 | def _initialize_affine_weight_cpu( 98 | weight, 99 | output_size, 100 | input_size, 101 | per_partition_size, 102 | partition_dim, 103 | init_method, 104 | stride=1, 105 | return_master_weight=False, 106 | ): 107 | """Initialize affine weight for model parallel. 
108 | 109 | Build the master weight on all processes and scatter 110 | the relevant chunk.""" 111 | 112 | set_tensor_model_parallel_attributes( 113 | tensor=weight, is_parallel=True, dim=partition_dim, stride=stride 114 | ) 115 | 116 | # Initialize master weight 117 | master_weight = torch.empty( 118 | output_size, input_size, dtype=torch.float, requires_grad=False 119 | ) 120 | init_method(master_weight) 121 | args = get_args() 122 | master_weight = master_weight.to(dtype=args.params_dtype) 123 | 124 | # Split and copy 125 | per_partition_per_stride_size = divide(per_partition_size, stride) 126 | weight_list = torch.split( 127 | master_weight, per_partition_per_stride_size, dim=partition_dim 128 | ) 129 | rank = get_tensor_model_parallel_rank() 130 | world_size = get_tensor_model_parallel_world_size() 131 | my_weight_list = weight_list[rank::world_size] 132 | 133 | with torch.no_grad(): 134 | torch.cat(my_weight_list, dim=partition_dim, out=weight) 135 | if return_master_weight: 136 | return master_weight 137 | return None 138 | 139 | 140 | def xavier_uniform_tensor_parallel_(tensor, gain=1.0, tp_degree=1): 141 | r""" 142 | This is a modified torch.nn.init.xavier_uniform_ with changes to support 143 | partitioned on the vocab size dim embedding with tensor parallel. 144 | 145 | Additional args: 146 | - tp_degree: degree of tensor parallel 147 | 148 | Note: the code assumes all partitions are equal in size 149 | """ 150 | # receptive_field_size=1 as dim==2, so we don't need init._calculate_fan_in_and_fan_out 151 | fan_out, fan_in = tensor.shape 152 | fan_out *= tp_degree # tp splits on num_embeddings dim 153 | 154 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 155 | a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 156 | 157 | return torch.nn.init._no_grad_uniform_(tensor, -a, a) 158 | 159 | 160 | class VocabParallelEmbedding(torch.nn.Module): 161 | """Embedding parallelized in the vocabulary dimension. 162 | 163 | This is mainly adapted from torch.nn.Embedding and all the default 164 | values are kept. 165 | Arguments: 166 | num_embeddings: vocabulary size. 167 | embedding_dim: size of hidden state. 168 | init_method: method to initialize weights. 169 | """ 170 | 171 | def __init__(self, num_embeddings, embedding_dim, init_method=init.xavier_normal_): 172 | super(VocabParallelEmbedding, self).__init__() 173 | # Keep the input dimensions. 174 | self.num_embeddings = num_embeddings 175 | self.embedding_dim = embedding_dim 176 | # Set the defaults for compatibility. 177 | self.padding_idx = None 178 | self.max_norm = None 179 | self.norm_type = 2.0 180 | self.scale_grad_by_freq = False 181 | self.sparse = False 182 | self._weight = None 183 | self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() 184 | # Divide the weight matrix along the vocabulary dimension. 185 | ( 186 | self.vocab_start_index, 187 | self.vocab_end_index, 188 | ) = VocabUtility.vocab_range_from_global_vocab_size( 189 | self.num_embeddings, 190 | get_tensor_model_parallel_rank(), 191 | self.tensor_model_parallel_size, 192 | ) 193 | self.num_embeddings_per_partition = ( 194 | self.vocab_end_index - self.vocab_start_index 195 | ) 196 | 197 | # Allocate weights and initialize. 198 | args = get_args() 199 | 200 | # only the first stage embedding runs this class' forward. The head's embedding does its own 201 | # thing, so don't waste memory allocating LN weights. 
202 | if mpu.is_pipeline_first_stage() and ( 203 | args.use_bnb_optimizer or args.embed_layernorm 204 | ): 205 | self.norm = LayerNorm(embedding_dim) 206 | 207 | if args.use_bnb_optimizer: 208 | # for BNB we ignore the passed init_method and use torch.nn.init.xavier_uniform_ 209 | # modified to calculate std on the unpartitioned embedding 210 | init_method = partial( 211 | xavier_uniform_tensor_parallel_, 212 | tp_degree=self.tensor_model_parallel_size, 213 | ) 214 | 215 | if args.use_cpu_initialization: 216 | self.weight = Parameter( 217 | torch.empty( 218 | self.num_embeddings_per_partition, 219 | self.embedding_dim, 220 | dtype=args.params_dtype, 221 | ) 222 | ) 223 | _initialize_affine_weight_cpu( 224 | self.weight, 225 | self.num_embeddings, 226 | self.embedding_dim, 227 | self.num_embeddings_per_partition, 228 | 0, 229 | init_method, 230 | ) 231 | else: 232 | self.weight = Parameter( 233 | torch.empty( 234 | self.num_embeddings_per_partition, 235 | self.embedding_dim, 236 | device=torch.cuda.current_device(), 237 | dtype=args.params_dtype, 238 | ) 239 | ) 240 | _initialize_affine_weight_gpu( 241 | self.weight, init_method, partition_dim=0, stride=1 242 | ) 243 | 244 | if args.use_bnb_optimizer: 245 | from bitsandbytes.optim import GlobalOptimManager 246 | 247 | GlobalOptimManager.get_instance().override_config( 248 | self.weight, "optim_bits", 32 249 | ) 250 | GlobalOptimManager.get_instance().register_parameters(self.weight) 251 | 252 | def forward(self, input_): 253 | if torch.any(input_ >= self.num_embeddings): 254 | raise ValueError( 255 | "There is an input id in the input that is greater than the highest" 256 | f" possible input id.\nInput: {input_}\nnum_embeddings:" 257 | f" {self.num_embeddings}" 258 | ) 259 | 260 | if self.tensor_model_parallel_size > 1: 261 | # Build the mask. 262 | input_mask = (input_ < self.vocab_start_index) | ( 263 | input_ >= self.vocab_end_index 264 | ) 265 | # Mask the input. 266 | masked_input = input_.clone() - self.vocab_start_index 267 | masked_input[input_mask] = 0 268 | else: 269 | # input_ is garanted to be in the range [0:self.vocab_end_index - self.vocab_start_index] thanks to the first check 270 | masked_input = input_ 271 | 272 | # Get the embeddings. 273 | output_parallel = F.embedding( 274 | masked_input, 275 | self.weight, 276 | self.padding_idx, 277 | self.max_norm, 278 | self.norm_type, 279 | self.scale_grad_by_freq, 280 | self.sparse, 281 | ) 282 | # Mask the output embedding. 283 | if self.tensor_model_parallel_size > 1: 284 | output_parallel[input_mask, :] = 0.0 285 | # Reduce across all the model parallel GPUs. 286 | output = reduce_from_tensor_model_parallel_region(output_parallel) 287 | 288 | if hasattr(self, "norm"): 289 | output = self.norm(output) 290 | 291 | return output 292 | 293 | 294 | class ColumnParallelLinear(torch.nn.Module): 295 | """Linear layer with column parallelism. 296 | 297 | The linear layer is defined as Y = XA + b. A is parallelized along 298 | its second dimension as A = [A_1, ..., A_p]. 299 | 300 | Arguments: 301 | input_size: first dimension of matrix A. 302 | output_size: second dimension of matrix A. 303 | bias: If true, add bias 304 | gather_output: If true, call all-gether on output and make Y avaiable 305 | to all GPUs, otherwise, every GPU will have its output 306 | which is Y_i = XA_i 307 | init_method: method to initialize weights. Note that bias is always set 308 | to zero. 309 | stride: For the strided linear layers. 
310 | keep_master_weight_for_test: This was added for testing and should be 311 | set to False. It returns the master weights 312 | used for initialization. 313 | skip_bias_add: This was added to enable performance optimations where bias 314 | can be fused with other elementwise operations. we skip 315 | adding bias but instead return it. 316 | """ 317 | 318 | def __init__( 319 | self, 320 | input_size, 321 | output_size, 322 | bias=True, 323 | gather_output=True, 324 | init_method=init.xavier_normal_, 325 | stride=1, 326 | keep_master_weight_for_test=False, 327 | skip_bias_add=False, 328 | ): 329 | super(ColumnParallelLinear, self).__init__() 330 | 331 | # Keep input parameters 332 | self.input_size = input_size 333 | self.output_size = output_size 334 | self.gather_output = gather_output 335 | # Divide the weight matrix along the last dimension. 336 | world_size = get_tensor_model_parallel_world_size() 337 | self.output_size_per_partition = divide(output_size, world_size) 338 | self.skip_bias_add = skip_bias_add 339 | 340 | # Parameters. 341 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 342 | # we allocate the transpose. 343 | # Initialize weight. 344 | args = get_args() 345 | if args.use_cpu_initialization: 346 | self.weight = Parameter( 347 | torch.empty( 348 | self.output_size_per_partition, 349 | self.input_size, 350 | dtype=args.params_dtype, 351 | ) 352 | ) 353 | self.master_weight = _initialize_affine_weight_cpu( 354 | self.weight, 355 | self.output_size, 356 | self.input_size, 357 | self.output_size_per_partition, 358 | 0, 359 | init_method, 360 | stride=stride, 361 | return_master_weight=keep_master_weight_for_test, 362 | ) 363 | else: 364 | self.weight = Parameter( 365 | torch.empty( 366 | self.output_size_per_partition, 367 | self.input_size, 368 | device=torch.cuda.current_device(), 369 | dtype=args.params_dtype, 370 | ) 371 | ) 372 | _initialize_affine_weight_gpu( 373 | self.weight, init_method, partition_dim=0, stride=stride 374 | ) 375 | 376 | if bias: 377 | if args.use_cpu_initialization: 378 | self.bias = Parameter( 379 | torch.empty(self.output_size_per_partition, dtype=args.params_dtype) 380 | ) 381 | else: 382 | self.bias = Parameter( 383 | torch.empty( 384 | self.output_size_per_partition, 385 | device=torch.cuda.current_device(), 386 | dtype=args.params_dtype, 387 | ) 388 | ) 389 | set_tensor_model_parallel_attributes(self.bias, True, 0, stride) 390 | # Always initialize bias to zero. 391 | with torch.no_grad(): 392 | self.bias.zero_() 393 | else: 394 | self.register_parameter("bias", None) 395 | 396 | def forward(self, input_): 397 | # Set up backprop all-reduce. 398 | input_parallel = copy_to_tensor_model_parallel_region(input_) 399 | # Matrix multiply. 400 | 401 | bias = self.bias if not self.skip_bias_add else None 402 | output_parallel = F.linear(input_parallel, self.weight, bias) 403 | if self.gather_output: 404 | # All-gather across the partitions. 405 | output = gather_from_tensor_model_parallel_region(output_parallel) 406 | else: 407 | output = output_parallel 408 | output_bias = self.bias if self.skip_bias_add else None 409 | return output, output_bias 410 | 411 | 412 | class RowParallelLinear(torch.nn.Module): 413 | """Linear layer with row parallelism. 414 | 415 | The linear layer is defined as Y = XA + b. A is parallelized along 416 | its first dimension and X along its second dimension as: 417 | - - 418 | | A_1 | 419 | | . | 420 | A = | . | X = [X_1, ..., X_p] 421 | | . 
| 422 | | A_p | 423 | - - 424 | Arguments: 425 | input_size: first dimension of matrix A. 426 | output_size: second dimension of matrix A. 427 | bias: If true, add bias. Note that bias is not parallelized. 428 | input_is_parallel: If true, we assume that the input is already 429 | split across the GPUs and we do not split 430 | again. 431 | init_method: method to initialize weights. Note that bias is always set 432 | to zero. 433 | stride: For the strided linear layers. 434 | keep_master_weight_for_test: This was added for testing and should be 435 | set to False. It returns the master weights 436 | used for initialization. 437 | skip_bias_add: This was added to enable performance optimations where bias 438 | can be fused with other elementwise operations. we skip 439 | adding bias but instead return it. 440 | """ 441 | 442 | def __init__( 443 | self, 444 | input_size, 445 | output_size, 446 | bias=True, 447 | input_is_parallel=False, 448 | init_method=init.xavier_normal_, 449 | stride=1, 450 | keep_master_weight_for_test=False, 451 | skip_bias_add=False, 452 | ): 453 | super(RowParallelLinear, self).__init__() 454 | 455 | # Keep input parameters 456 | self.input_size = input_size 457 | self.output_size = output_size 458 | self.input_is_parallel = input_is_parallel 459 | # Divide the weight matrix along the last dimension. 460 | world_size = get_tensor_model_parallel_world_size() 461 | self.input_size_per_partition = divide(input_size, world_size) 462 | self.skip_bias_add = skip_bias_add 463 | 464 | # Parameters. 465 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 466 | # we allocate the transpose. 467 | # Initialize weight. 468 | args = get_args() 469 | if args.use_cpu_initialization: 470 | self.weight = Parameter( 471 | torch.empty( 472 | self.output_size, 473 | self.input_size_per_partition, 474 | dtype=args.params_dtype, 475 | ) 476 | ) 477 | self.master_weight = _initialize_affine_weight_cpu( 478 | self.weight, 479 | self.output_size, 480 | self.input_size, 481 | self.input_size_per_partition, 482 | 1, 483 | init_method, 484 | stride=stride, 485 | return_master_weight=keep_master_weight_for_test, 486 | ) 487 | else: 488 | self.weight = Parameter( 489 | torch.empty( 490 | self.output_size, 491 | self.input_size_per_partition, 492 | device=torch.cuda.current_device(), 493 | dtype=args.params_dtype, 494 | ) 495 | ) 496 | _initialize_affine_weight_gpu( 497 | self.weight, init_method, partition_dim=1, stride=stride 498 | ) 499 | if bias: 500 | if args.use_cpu_initialization: 501 | self.bias = Parameter( 502 | torch.empty(self.output_size, dtype=args.params_dtype) 503 | ) 504 | else: 505 | self.bias = Parameter( 506 | torch.empty( 507 | self.output_size, 508 | device=torch.cuda.current_device(), 509 | dtype=args.params_dtype, 510 | ) 511 | ) 512 | # Always initialize bias to zero. 513 | with torch.no_grad(): 514 | self.bias.zero_() 515 | else: 516 | self.register_parameter("bias", None) 517 | 518 | self.bias_tp_auto_sync = args.sync_tp_duplicated_parameters 519 | 520 | def forward(self, input_): 521 | # Set up backprop all-reduce. 522 | if self.input_is_parallel: 523 | input_parallel = input_ 524 | else: 525 | input_parallel = scatter_to_tensor_model_parallel_region(input_) 526 | # Matrix multiply. 527 | output_parallel = F.linear(input_parallel, self.weight) 528 | # All-reduce across all the partitions. 
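        # The bias is not parallelized, so it is only added (below) after this
        # all-reduce; adding it to each rank's partial product instead would
        # effectively scale it by the tensor-parallel world size.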
529 | output_ = reduce_from_tensor_model_parallel_region(output_parallel) 530 | 531 | if self.bias_tp_auto_sync: 532 | torch.distributed.all_reduce( 533 | self.bias, 534 | op=torch.distributed.ReduceOp.AVG, 535 | group=mpu.get_tensor_model_parallel_group(), 536 | ) 537 | 538 | if not self.skip_bias_add: 539 | output = output_ + self.bias if self.bias is not None else output_ 540 | output_bias = None 541 | else: 542 | output = output_ 543 | output_bias = self.bias 544 | return output, output_bias 545 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/mpu/tests/test_layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import sys 18 | 19 | import mpu 20 | import torch 21 | import torch.nn.init as init 22 | from commons import initialize_distributed, print_separator, set_random_seed 23 | from mpu import layers 24 | from torch.nn.parameter import Parameter 25 | 26 | sys.path.append("../..") 27 | 28 | 29 | def test_parallel_embedding(tensor_model_parallel_size): 30 | if torch.distributed.get_rank() == 0: 31 | print( 32 | "> testing parallel embedding with model parallel size {} ...".format( 33 | tensor_model_parallel_size 34 | ) 35 | ) 36 | 37 | mpu.initialize_model_parallel(tensor_model_parallel_size) 38 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 39 | 40 | batch_size = 17 41 | seq_length = 23 42 | vocab_size = 48 43 | hidden_size = 16 44 | seed = 1236 45 | 46 | set_random_seed(123) 47 | input_data = ( 48 | torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda() 49 | ) 50 | loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() 51 | 52 | set_random_seed(seed) 53 | embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() 54 | 55 | output = embedding_original(input_data) 56 | loss_original = torch.mul(output, loss_weight).sum() 57 | loss_original.backward() 58 | 59 | set_random_seed(seed) 60 | embedding_parallel = layers.ParallelEmbedding( 61 | vocab_size, hidden_size, init_method=init.normal_ 62 | ).cuda() 63 | output = embedding_parallel(input_data) 64 | loss_parallel = torch.mul(output, loss_weight).sum() 65 | loss_parallel.backward() 66 | 67 | set_random_seed(seed) 68 | embedding_vocab_parallel = layers.VocabParallelEmbedding( 69 | vocab_size, hidden_size, init_method=init.normal_ 70 | ).cuda() 71 | output = embedding_vocab_parallel(input_data) 72 | loss_vocab_parallel = torch.mul(output, loss_weight).sum() 73 | loss_vocab_parallel.backward() 74 | 75 | torch.distributed.barrier() 76 | error = loss_parallel.sub(loss_original).abs() 77 | print( 78 | " error in loss (parallel) on global rank {}: {}".format( 79 | torch.distributed.get_rank(), error 80 | ) 81 | ) 82 | assert error < 1.0e-12, "error: {}".format(error) 
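    # NOTE: `layers.ParallelEmbedding` used above (and `layers._initialize_affine_weight`,
    # `mpu.BertParallelSelfAttention` and `mpu.BertParallelTransformerLayer` used in the
    # tests further below) are not defined in the trimmed layers.py vendored in this
    # repository; these tests appear to target the original upstream Megatron-LM mpu package.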
83 | 84 | torch.distributed.barrier() 85 | error = loss_vocab_parallel.sub(loss_original).abs() 86 | print( 87 | " error in loss (vocab parallel) on global rank {}: {}".format( 88 | torch.distributed.get_rank(), error 89 | ) 90 | ) 91 | assert error < 1.0e-12, "error: {}".format(error) 92 | 93 | weight_grad_orig = torch.split( 94 | embedding_original.weight.grad, hidden_size // tensor_model_parallel_size, 1 95 | )[mpu.get_tensor_model_parallel_rank()] 96 | error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() 97 | print( 98 | " error in grad (parallel) on global rank {}: {}".format( 99 | torch.distributed.get_rank(), error 100 | ) 101 | ) 102 | assert error < 1.0e-12, "error: {}".format(error) 103 | 104 | weight_grad_orig = torch.split( 105 | embedding_original.weight.grad, vocab_size // tensor_model_parallel_size, 0 106 | )[mpu.get_tensor_model_parallel_rank()] 107 | error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max() 108 | print( 109 | " error in grad (vocab parallel) on global rank {}: {}".format( 110 | torch.distributed.get_rank(), error 111 | ) 112 | ) 113 | assert error < 1.0e-12, "error: {}".format(error) 114 | 115 | # Reset groups 116 | mpu.destroy_model_parallel() 117 | 118 | torch.distributed.barrier() 119 | if torch.distributed.get_rank() == 0: 120 | print(">> passed the test :-)") 121 | 122 | 123 | def test_initialize_affine_weight(tensor_model_parallel_size): 124 | mpu.initialize_model_parallel(tensor_model_parallel_size) 125 | if torch.distributed.get_rank() == 0: 126 | print( 127 | "> testing initialize_affine_weight with model parallel size: {}".format( 128 | tensor_model_parallel_size 129 | ) 130 | ) 131 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 132 | 133 | seed = 12345 134 | input_size_coeff = 13 135 | input_size = input_size_coeff * tensor_model_parallel_size 136 | output_size_coeff = 17 137 | output_size = output_size_coeff * tensor_model_parallel_size 138 | 139 | # --------------- 140 | # Column parallel 141 | # --------------- 142 | weight = torch.empty(output_size_coeff, input_size) 143 | set_random_seed(seed) 144 | layers._initialize_affine_weight( 145 | weight, output_size, input_size, output_size_coeff, 0, torch.nn.init.normal_ 146 | ) 147 | # Target. 148 | set_random_seed(seed) 149 | master_weight = torch.empty(output_size, input_size) 150 | torch.nn.init.normal_(master_weight) 151 | rank = mpu.get_tensor_model_parallel_rank() 152 | my_weight = ( 153 | torch.split(master_weight, output_size_coeff, dim=0)[rank].contiguous().clone() 154 | ) 155 | 156 | # Compare. 157 | error = weight.sub(my_weight).abs().max() 158 | torch.distributed.barrier() 159 | print( 160 | " column parallel max error (should be zero) on global rank {}: {}".format( 161 | torch.distributed.get_rank(), error 162 | ) 163 | ) 164 | assert error < 1.0e-6 165 | 166 | # ------------ 167 | # Row parallel 168 | # ------------ 169 | weight = torch.empty(output_size, input_size_coeff) 170 | set_random_seed(seed) 171 | mpu.layers._initialize_affine_weight( 172 | weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_ 173 | ) 174 | # Target. 175 | set_random_seed(seed) 176 | master_weight = torch.empty(output_size, input_size) 177 | torch.nn.init.normal_(master_weight) 178 | rank = mpu.get_tensor_model_parallel_rank() 179 | my_weight = ( 180 | torch.split(master_weight, input_size_coeff, dim=1)[rank].contiguous().clone() 181 | ) 182 | 183 | # Compare. 
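    # The shard initialized directly above must match the slice of the full
    # master weight owned by this rank, since both were drawn with the same seed.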
184 | error = weight.sub(my_weight).abs().max() 185 | torch.distributed.barrier() 186 | print( 187 | " row parallel max error (should be zero) on global rank {}: {}".format( 188 | torch.distributed.get_rank(), error 189 | ) 190 | ) 191 | assert error < 1.0e-6 192 | 193 | # Reset groups 194 | mpu.destroy_model_parallel() 195 | 196 | torch.distributed.barrier() 197 | if torch.distributed.get_rank() == 0: 198 | print(" >> passed the test :-)") 199 | 200 | 201 | class IdentityLayer2D(torch.nn.Module): 202 | def __init__(self, m, n): 203 | super(IdentityLayer2D, self).__init__() 204 | self.weight = Parameter(torch.Tensor(m, n)) 205 | torch.nn.init.xavier_normal_(self.weight) 206 | 207 | def forward(self): 208 | return self.weight 209 | 210 | 211 | def test_column_parallel_linear(tensor_model_parallel_size): 212 | mpu.initialize_model_parallel(tensor_model_parallel_size) 213 | if torch.distributed.get_rank() == 0: 214 | print( 215 | "> testing ColumnParallelLinear with model parallel size: {}".format( 216 | tensor_model_parallel_size 217 | ) 218 | ) 219 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 220 | 221 | seed = 12345 222 | set_random_seed(seed) 223 | input_size_coeff = 13 224 | input_size = input_size_coeff * tensor_model_parallel_size 225 | output_size_coeff = 17 226 | output_size = output_size_coeff * tensor_model_parallel_size 227 | batch_size = 7 228 | 229 | # Network 230 | identity_layer = IdentityLayer2D(batch_size, input_size).cuda() 231 | linear_layer = mpu.ColumnParallelLinear( 232 | input_size, output_size, keep_master_weight_for_test=True 233 | ).cuda() 234 | loss_weight = torch.randn([batch_size, output_size]).cuda() 235 | # Forward 236 | input_ = identity_layer() 237 | output = linear_layer(input_) 238 | loss = torch.mul(output, loss_weight).sum() 239 | # Backward 240 | loss.backward() 241 | 242 | # Values. 
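    # Reference gradients for Y = X A^T + b with loss L = sum(Y * dLdY):
    # dL/dA = dLdY^T X, dL/db = sum of dLdY over the batch, dL/dX = dLdY A.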
243 | dLdY = loss_weight 244 | X = identity_layer.weight 245 | A = linear_layer.master_weight.cuda() 246 | dLdA = torch.matmul(dLdY.t(), X) 247 | dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) 248 | dLdX = torch.matmul(dLdY, A) 249 | 250 | rank = mpu.get_tensor_model_parallel_rank() 251 | my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous().clone() 252 | error = my_dLdA.sub(linear_layer.weight.grad).abs().max() 253 | torch.distributed.barrier() 254 | print( 255 | " error in dLdA on global rank {}: {}".format( 256 | torch.distributed.get_rank(), error 257 | ) 258 | ) 259 | assert error < 1.0e-6 260 | 261 | my_dLdb = torch.split(dLdb, output_size_coeff, dim=0)[rank].contiguous().clone() 262 | error = my_dLdb.sub(linear_layer.bias.grad).abs().max() 263 | torch.distributed.barrier() 264 | print( 265 | " error in dLdb on global rank {}: {}".format( 266 | torch.distributed.get_rank(), error 267 | ) 268 | ) 269 | assert error < 1.0e-6 270 | 271 | error = dLdX.sub(identity_layer.weight.grad).abs().max() 272 | torch.distributed.barrier() 273 | print( 274 | " error in dLdX on global rank {}: {}".format( 275 | torch.distributed.get_rank(), error 276 | ) 277 | ) 278 | assert error < 1.0e-6 279 | 280 | # Reset groups 281 | mpu.destroy_model_parallel() 282 | 283 | torch.distributed.barrier() 284 | if torch.distributed.get_rank() == 0: 285 | print(" >> passed the test :-)") 286 | 287 | 288 | def test_row_parallel_linear(tensor_model_parallel_size): 289 | mpu.initialize_model_parallel(tensor_model_parallel_size) 290 | if torch.distributed.get_rank() == 0: 291 | print( 292 | "> testing RowParallelLinear with model parallel size: {}".format( 293 | tensor_model_parallel_size 294 | ) 295 | ) 296 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 297 | 298 | seed = 12345 299 | set_random_seed(seed) 300 | input_size_coeff = 13 301 | input_size = input_size_coeff * tensor_model_parallel_size 302 | output_size_coeff = 17 303 | output_size = output_size_coeff * tensor_model_parallel_size 304 | batch_size = 7 305 | 306 | # Network 307 | identity_layer = IdentityLayer2D(batch_size, input_size).cuda() 308 | linear_layer = mpu.RowParallelLinear( 309 | input_size, output_size, keep_master_weight_for_test=True 310 | ).cuda() 311 | loss_weight = torch.randn([batch_size, output_size]).cuda() 312 | # Forward 313 | input_ = identity_layer() 314 | output = linear_layer(input_) 315 | loss = torch.mul(output, loss_weight).sum() 316 | # Backward 317 | loss.backward() 318 | 319 | # Values. 
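    # Same reference gradients as in the column-parallel test above; here the
    # weight gradient is split along the input dimension (dim 1) and the bias
    # gradient is compared without splitting, since the bias is not parallelized.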
320 | dLdY = loss_weight 321 | X = identity_layer.weight 322 | A = linear_layer.master_weight.cuda() 323 | dLdA = torch.matmul(dLdY.t(), X) 324 | dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) 325 | dLdX = torch.matmul(dLdY, A) 326 | 327 | rank = mpu.get_tensor_model_parallel_rank() 328 | my_dLdA = torch.split(dLdA, input_size_coeff, dim=1)[rank].contiguous().clone() 329 | error = my_dLdA.sub(linear_layer.weight.grad).abs().max() 330 | torch.distributed.barrier() 331 | print( 332 | " error in dLdA on global rank {}: {}".format( 333 | torch.distributed.get_rank(), error 334 | ) 335 | ) 336 | assert error < 1.0e-6 337 | 338 | error = dLdb.sub(linear_layer.bias.grad).abs().max() 339 | torch.distributed.barrier() 340 | print( 341 | " error in dLdb on global rank {}: {}".format( 342 | torch.distributed.get_rank(), error 343 | ) 344 | ) 345 | assert error < 1.0e-6 346 | 347 | error = dLdX.sub(identity_layer.weight.grad).abs().max() 348 | torch.distributed.barrier() 349 | print( 350 | " error in dLdX on global rank {}: {}".format( 351 | torch.distributed.get_rank(), error 352 | ) 353 | ) 354 | assert error < 1.0e-6 355 | 356 | # Reset groups 357 | mpu.destroy_model_parallel() 358 | 359 | torch.distributed.barrier() 360 | if torch.distributed.get_rank() == 0: 361 | print(" >> passed the test :-)") 362 | 363 | 364 | class IdentityLayer3D(torch.nn.Module): 365 | def __init__(self, m, n, k): 366 | super(IdentityLayer3D, self).__init__() 367 | self.weight = Parameter(torch.Tensor(m, n, k)) 368 | torch.nn.init.xavier_normal_(self.weight) 369 | 370 | def forward(self): 371 | return self.weight 372 | 373 | 374 | def parallel_self_attention( 375 | tensor_model_parallel_size, 376 | num_att_heads_per_partition, 377 | hidden_size_per_att_head, 378 | dropout_prob, 379 | batch_size, 380 | sequence_length, 381 | ): 382 | mpu.initialize_model_parallel(tensor_model_parallel_size) 383 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 384 | 385 | seed = 12345 386 | set_random_seed(seed) 387 | 388 | num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size() 389 | hidden_size = hidden_size_per_att_head * num_att_heads 390 | 391 | # Network 392 | identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda() 393 | attention_layer = mpu.BertParallelSelfAttention( 394 | hidden_size, num_att_heads, dropout_prob 395 | ).cuda() 396 | loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() 397 | attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() 398 | # Forward 399 | input_ = identity_layer() 400 | output = attention_layer(input_, attention_mask) 401 | loss = torch.mul(output, loss_weight).sum() 402 | # Backward 403 | loss.backward() 404 | 405 | rank = mpu.get_tensor_model_parallel_rank() 406 | mpu.destroy_model_parallel() 407 | return ( 408 | rank, 409 | hidden_size, 410 | tensor_model_parallel_size, 411 | loss, 412 | attention_layer, 413 | identity_layer, 414 | ) 415 | 416 | 417 | def test_parallel_self_attention(tensor_model_parallel_size): 418 | if torch.distributed.get_rank() == 0: 419 | print( 420 | "> testing ParallelSelfAttention with model parallel size: {}".format( 421 | tensor_model_parallel_size 422 | ) 423 | ) 424 | 425 | num_att_heads_per_partition = 3 426 | hidden_size_per_att_head = 7 427 | dropout_prob = 0.0 # has to be zero 428 | batch_size = 5 429 | sequence_length = 13 430 | 431 | ( 432 | rank_1, 433 | hideen_size_1, 434 | tensor_model_parallel_size_1, 435 | loss_1, 436 | 
attention_layer_1, 437 | identity_layer_1, 438 | ) = parallel_self_attention( 439 | 1, 440 | num_att_heads_per_partition, 441 | hidden_size_per_att_head, 442 | dropout_prob, 443 | batch_size, 444 | sequence_length, 445 | ) 446 | 447 | ( 448 | rank, 449 | hidden_size, 450 | tensor_model_parallel_size, 451 | loss, 452 | attention_layer, 453 | identity_layer, 454 | ) = parallel_self_attention( 455 | tensor_model_parallel_size, 456 | num_att_heads_per_partition, 457 | hidden_size_per_att_head, 458 | dropout_prob, 459 | batch_size, 460 | sequence_length, 461 | ) 462 | assert hideen_size_1 == hidden_size 463 | 464 | error = loss_1.sub(loss).abs().max() 465 | torch.distributed.barrier() 466 | print( 467 | " loss error on global rank {}: {}".format( 468 | torch.distributed.get_rank(), error 469 | ) 470 | ) 471 | assert error < 5.0e-6 472 | 473 | my_lin_grad_list = torch.split( 474 | attention_layer_1.query_key_value.weight.grad, 475 | hidden_size // tensor_model_parallel_size, 476 | 0, 477 | )[rank::tensor_model_parallel_size] 478 | my_lin_grad = torch.cat(my_lin_grad_list, dim=0) 479 | error = my_lin_grad.sub(attention_layer.query_key_value.weight.grad).abs().max() 480 | torch.distributed.barrier() 481 | print( 482 | " weight gradient error on global rank {}: {}".format( 483 | torch.distributed.get_rank(), error 484 | ) 485 | ) 486 | assert error < 5.0e-6 487 | 488 | error = identity_layer_1.weight.grad.sub(identity_layer.weight.grad).abs().max() 489 | torch.distributed.barrier() 490 | print( 491 | " input gradient error on global rank {}: {}".format( 492 | torch.distributed.get_rank(), error 493 | ) 494 | ) 495 | assert error < 5.0e-6 496 | 497 | torch.distributed.barrier() 498 | if torch.distributed.get_rank() == 0: 499 | print(" >> passed the test :-)") 500 | 501 | 502 | def parallel_transformer( 503 | tensor_model_parallel_size, 504 | num_att_heads_per_partition, 505 | hidden_size_per_att_head, 506 | batch_size, 507 | sequence_length, 508 | ): 509 | mpu.initialize_model_parallel(tensor_model_parallel_size) 510 | tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size() 511 | 512 | seed = 12345 513 | set_random_seed(seed) 514 | 515 | num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size() 516 | hidden_size = hidden_size_per_att_head * num_att_heads 517 | intermediate_size = 4 * hidden_size 518 | 519 | # Network 520 | identity_layer = IdentityLayer3D(batch_size, sequence_length, hidden_size).cuda() 521 | transformer_layer = mpu.BertParallelTransformerLayer( 522 | hidden_size, 523 | intermediate_size, 524 | num_att_heads, 525 | 0.0, 526 | 0.0, 527 | torch.nn.functional.relu, 528 | 1.0e-5, 529 | ).cuda() 530 | 531 | loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() 532 | attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() 533 | # Forward 534 | input_ = identity_layer() 535 | output = transformer_layer(input_, attention_mask) 536 | loss = torch.mul(output, loss_weight).sum() 537 | # Backward 538 | loss.backward() 539 | 540 | rank = mpu.get_tensor_model_parallel_rank() 541 | mpu.destroy_model_parallel() 542 | return ( 543 | rank, 544 | hidden_size, 545 | tensor_model_parallel_size, 546 | loss, 547 | transformer_layer, 548 | identity_layer, 549 | ) 550 | 551 | 552 | def test_parallel_transformer_layer(tensor_model_parallel_size): 553 | if torch.distributed.get_rank() == 0: 554 | print( 555 | "> testing ParallelTransformerLayer with model parallel size: {}".format( 556 | tensor_model_parallel_size 557 | ) 558 | ) 
559 | 560 | num_att_heads_per_partition = 3 561 | hidden_size_per_att_head = 7 562 | batch_size = 5 563 | sequence_length = 13 564 | 565 | ( 566 | rank_1, 567 | hidden_size_1, 568 | tensor_model_parallel_size_1, 569 | loss_1, 570 | transformer_layer_1, 571 | identity_layer_1, 572 | ) = parallel_transformer( 573 | 1, 574 | num_att_heads_per_partition, 575 | hidden_size_per_att_head, 576 | batch_size, 577 | sequence_length, 578 | ) 579 | 580 | ( 581 | rank, 582 | hidden_size, 583 | tensor_model_parallel_size, 584 | loss, 585 | transformer_layer, 586 | identity_layer, 587 | ) = parallel_transformer( 588 | tensor_model_parallel_size, 589 | num_att_heads_per_partition, 590 | hidden_size_per_att_head, 591 | batch_size, 592 | sequence_length, 593 | ) 594 | 595 | error = loss_1.sub(loss).abs().max() 596 | torch.distributed.barrier() 597 | print( 598 | " loss error on global rank {}: {}".format( 599 | torch.distributed.get_rank(), error 600 | ) 601 | ) 602 | assert error < 5.0e-5, "error: {}".format(error) 603 | 604 | error = identity_layer_1.weight.grad.sub(identity_layer.weight.grad).abs().max() 605 | torch.distributed.barrier() 606 | print( 607 | " input gradient error on global rank {}: {}".format( 608 | torch.distributed.get_rank(), error 609 | ) 610 | ) 611 | assert error < 5.0e-5, "error: {}".format(error) 612 | 613 | torch.distributed.barrier() 614 | if torch.distributed.get_rank() == 0: 615 | print(" >> passed the test :-)") 616 | 617 | 618 | if __name__ == "__main__": 619 | torch.backends.cudnn.deterministic = True 620 | torch.backends.cudnn.benchmark = False 621 | 622 | initialize_distributed() 623 | world_size = torch.distributed.get_world_size() 624 | 625 | print_separator("test initialize affine weight") 626 | tensor_model_parallel_size = 1 627 | while tensor_model_parallel_size <= world_size: 628 | test_initialize_affine_weight(tensor_model_parallel_size) 629 | tensor_model_parallel_size *= 2 630 | 631 | tensor_model_parallel_size = 1 632 | while tensor_model_parallel_size <= world_size: 633 | print_separator("test parallel embedding") 634 | test_parallel_embedding(tensor_model_parallel_size) 635 | tensor_model_parallel_size *= 2 636 | 637 | print_separator("test column-parallel linear") 638 | tensor_model_parallel_size = 1 639 | while tensor_model_parallel_size <= world_size: 640 | test_column_parallel_linear(tensor_model_parallel_size) 641 | tensor_model_parallel_size *= 2 642 | 643 | print_separator("test row-parallel linear") 644 | tensor_model_parallel_size = 1 645 | while tensor_model_parallel_size <= world_size: 646 | test_row_parallel_linear(tensor_model_parallel_size) 647 | tensor_model_parallel_size *= 2 648 | 649 | print_separator("test parallel self-attention") 650 | tensor_model_parallel_size = 1 651 | while tensor_model_parallel_size <= world_size: 652 | test_parallel_self_attention(tensor_model_parallel_size) 653 | tensor_model_parallel_size *= 2 654 | 655 | print_separator("test parallel transformer") 656 | tensor_model_parallel_size = 1 657 | while tensor_model_parallel_size <= world_size: 658 | test_parallel_transformer_layer(tensor_model_parallel_size) 659 | tensor_model_parallel_size *= 2 660 | -------------------------------------------------------------------------------- /distill_bloom/dataset/megatron/helpers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | coding=utf-8 3 | Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
4 | 
5 |  Licensed under the Apache License, Version 2.0 (the "License");
6 |  you may not use this file except in compliance with the License.
7 |  You may obtain a copy of the License at
8 | 
9 |      http://www.apache.org/licenses/LICENSE-2.0
10 | 
11 |  Unless required by applicable law or agreed to in writing, software
12 |  distributed under the License is distributed on an "AS IS" BASIS,
13 |  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  See the License for the specific language governing permissions and
15 |  limitations under the License.
16 |  */
17 | 
18 | 
19 | /* Helper methods for fast index mapping builds */
20 | 
21 | #include <algorithm>
22 | #include <iostream>
23 | #include <limits>
24 | #include <math.h>
25 | #include <stdexcept>
26 | #include <pybind11/pybind11.h>
27 | #include <pybind11/numpy.h>
28 | #include <random>
29 | 
30 | namespace py = pybind11;
31 | using namespace std;
32 | 
33 | const int32_t LONG_SENTENCE_LEN = 512;
34 | 
35 | 
36 | void build_blending_indices(py::array_t<uint8_t>& dataset_index,
37 |                             py::array_t<int64_t>& dataset_sample_index,
38 |                             const py::array_t<double>& weights,
39 |                             const int32_t num_datasets,
40 |                             const int64_t size, const bool verbose) {
41 |     /* Given multiple datasets and a weighting array, build samples
42 |        such that it follows those weights.*/
43 | 
44 |     if (verbose) {
45 |         std::cout << "> building indices for blendable datasets ..." << std::endl;
46 |     }
47 | 
48 |     // Get the pointer access without the checks.
49 |     auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
50 |     auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
51 |     auto weights_ptr = weights.unchecked<1>();
52 | 
53 |     // Initialize buffer for number of samples used for each dataset.
54 |     int64_t current_samples[num_datasets];
55 |     for(int64_t i = 0; i < num_datasets; ++i) {
56 |         current_samples[i] = 0;
57 |     }
58 | 
59 |     // For each sample:
60 |     for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
61 | 
62 |         // Determine where the max error in sampling is happening.
63 |         auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
64 |         int64_t max_error_index = 0;
65 |         double max_error = weights_ptr[0] * sample_idx_double -
66 |             static_cast<double>(current_samples[0]);
67 |         for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
68 |             double error = weights_ptr[dataset_idx] * sample_idx_double -
69 |                 static_cast<double>(current_samples[dataset_idx]);
70 |             if (error > max_error) {
71 |                 max_error = error;
72 |                 max_error_index = dataset_idx;
73 |             }
74 |         }
75 | 
76 |         // Populate the indices.
77 |         dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
78 |         dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
79 | 
80 |         // Update the total samples.
81 |         current_samples[max_error_index] += 1;
82 | 
83 |     }
84 | 
85 |     // print info
86 |     if (verbose) {
87 |         std::cout << " > sample ratios:" << std::endl;
88 |         for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
89 |             auto ratio = static_cast<double>(current_samples[dataset_idx]) /
90 |                 static_cast<double>(size);
91 |             std::cout << " dataset " << dataset_idx << ", input: " <<
92 |                 weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
93 |         }
94 |     }
95 | 
96 | }
97 | 
98 | 
99 | py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
100 |                            const py::array_t<int32_t>& doc_idx_,
101 |                            const int32_t seq_length,
102 |                            const int32_t num_epochs,
103 |                            const int64_t tokens_per_epoch) {
104 |     /* Sample index (sample_idx) is used for gpt2 like dataset for which
105 |        the documents are flattened and the samples are built based on this
106 |        1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
107 |        where [..., 0] contains the index into `doc_idx` and [..., 1] is the
108 |        starting offset in that document.*/
109 | 
110 |     // Consistency checks.
111 |     assert(seq_length > 1);
112 |     assert(num_epochs > 0);
113 |     assert(tokens_per_epoch > 1);
114 | 
115 |     // Remove bound checks.
116 |     auto sizes = sizes_.unchecked<1>();
117 |     auto doc_idx = doc_idx_.unchecked<1>();
118 | 
119 |     // Mapping and its length (1D).
120 |     int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
121 |     int32_t* sample_idx = new int32_t[2*(num_samples+1)];
122 | 
123 |     cout << " using:" << endl << std::flush;
124 |     cout << " number of documents: " <<
125 |         doc_idx_.shape(0) / num_epochs << endl << std::flush;
126 |     cout << " number of epochs: " << num_epochs <<
127 |         endl << std::flush;
128 |     cout << " sequence length: " << seq_length <<
129 |         endl << std::flush;
130 |     cout << " total number of samples: " << num_samples <<
131 |         endl << std::flush;
132 | 
133 |     // Index into sample_idx.
134 |     int64_t sample_index = 0;
135 |     // Index into doc_idx.
136 |     int64_t doc_idx_index = 0;
137 |     // Beginning offset for each document.
138 |     int32_t doc_offset = 0;
139 |     // Start with first document and no offset.
140 |     sample_idx[2 * sample_index] = doc_idx_index;
141 |     sample_idx[2 * sample_index + 1] = doc_offset;
142 |     ++sample_index;
143 | 
144 |     while (sample_index <= num_samples) {
145 |         // Start with a fresh sequence.
146 |         int32_t remaining_seq_length = seq_length + 1;
147 |         while (remaining_seq_length != 0) {
148 |             // Get the document length.
149 |             auto doc_id = doc_idx[doc_idx_index];
150 |             auto doc_length = sizes[doc_id] - doc_offset;
151 |             // And add it to the current sequence.
152 |             remaining_seq_length -= doc_length;
153 |             // If we have more than a full sequence, adjust offset and set
154 |             // remaining length to zero so we return from the while loop.
155 |             // Note that -1 here is for the same reason we have -1 in
156 |             // `_num_epochs` calculations.
157 |             if (remaining_seq_length <= 0) {
158 |                 doc_offset += (remaining_seq_length + doc_length - 1);
159 |                 remaining_seq_length = 0;
160 |             } else {
161 |                 // Otherwise, start from the beginning of the next document.
162 |                 ++doc_idx_index;
163 |                 doc_offset = 0;
164 |             }
165 |         }
166 |         // Record the sequence.
167 |         sample_idx[2 * sample_index] = doc_idx_index;
168 |         sample_idx[2 * sample_index + 1] = doc_offset;
169 |         ++sample_index;
170 |     }
171 | 
172 |     // Method to deallocate memory.
173 |     py::capsule free_when_done(sample_idx, [](void *mem_) {
174 |         int32_t *mem = reinterpret_cast<int32_t*>(mem_);
175 |         delete[] mem;
176 |     });
177 | 
178 |     // Return the numpy array.
179 |     const auto byte_size = sizeof(int32_t);
180 |     return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
181 |                      {2*byte_size, byte_size}, // C-style contiguous strides
182 |                      sample_idx, // the data pointer
183 |                      free_when_done); // numpy array references
184 | 
185 | }
186 | 
187 | 
188 | inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
189 |                                      const int32_t max_length,
190 |                                      std::mt19937& rand32_gen) {
191 |     /* Training sample length. */
192 |     if (short_seq_ratio == 0) {
193 |         return max_length;
194 |     }
195 |     const auto random_number = rand32_gen();
196 |     if ((random_number % short_seq_ratio) == 0) {
197 |         return 2 + random_number % (max_length - 1);
198 |     }
199 |     return max_length;
200 | }
201 | 
202 | 
203 | template<typename DocIdx>
204 | py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
205 |                              const py::array_t<int32_t>& sizes_,
206 |                              const int32_t num_epochs,
207 |                              const uint64_t max_num_samples,
208 |                              const int32_t max_seq_length,
209 |                              const double short_seq_prob,
210 |                              const int32_t seed,
211 |                              const bool verbose,
212 |                              const int32_t min_num_sent) {
213 |     /* Build a mapping of (start-index, end-index, sequence-length) where
214 |        start and end index are the indices of the sentences in the sample
215 |        and sequence-length is the target sequence length.
216 |     */
217 | 
218 |     // Consistency checks.
219 |     assert(num_epochs > 0);
220 |     assert(max_seq_length > 1);
221 |     assert(short_seq_prob >= 0.0);
222 |     assert(short_seq_prob <= 1.0);
223 |     assert(seed > 0);
224 | 
225 |     // Remove bound checks.
226 |     auto docs = docs_.unchecked<1>();
227 |     auto sizes = sizes_.unchecked<1>();
228 | 
229 |     // For efficiency, convert probability to ratio. Note: rand() generates int.
230 |     int32_t short_seq_ratio = 0;
231 |     if (short_seq_prob > 0) {
232 |         short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
233 |     }
234 | 
235 |     if (verbose) {
236 |         const auto sent_start_index = docs[0];
237 |         const auto sent_end_index = docs[docs_.shape(0) - 1];
238 |         const auto num_sentences = sent_end_index - sent_start_index;
239 |         cout << " using:" << endl << std::flush;
240 |         cout << " number of documents: " << docs_.shape(0) - 1 <<
241 |             endl << std::flush;
242 |         cout << " sentences range: [" << sent_start_index <<
243 |             ", " << sent_end_index << ")" << endl << std::flush;
244 |         cout << " total number of sentences: " << num_sentences <<
245 |             endl << std::flush;
246 |         cout << " number of epochs: " << num_epochs <<
247 |             endl << std::flush;
248 |         cout << " maximum number of samples: " << max_num_samples <<
249 |             endl << std::flush;
250 |         cout << " maximum sequence length: " << max_seq_length <<
251 |             endl << std::flush;
252 |         cout << " short sequence probability: " << short_seq_prob <<
253 |             endl << std::flush;
254 |         cout << " short sequence ratio (1/prob): " << short_seq_ratio <<
255 |             endl << std::flush;
256 |         cout << " seed: " << seed << endl <<
257 |             std::flush;
258 |     }
259 | 
260 |     // Mapping and its length (1D).
261 |     int64_t num_samples = -1;
262 |     DocIdx* maps = NULL;
263 | 
264 |     // Perform two iterations, in the first iteration get the size
265 |     // and allocate memory and in the second iteration populate the map.
266 |     bool second = false;
267 |     for (int32_t iteration=0; iteration<2; ++iteration) {
268 | 
269 |         // Set the seed so both iterations produce the same results.
270 |         std::mt19937 rand32_gen(seed);
271 | 
272 |         // Set the flag on second iteration.
273 |         second = (iteration == 1);
274 | 
275 |         // Counters:
276 |         uint64_t empty_docs = 0;
277 |         uint64_t one_sent_docs = 0;
278 |         uint64_t long_sent_docs = 0;
279 | 
280 |         // Current map index.
281 |         uint64_t map_index = 0;
282 | 
283 |         // For each epoch:
284 |         for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
285 |             if (map_index >= max_num_samples) {
286 |                 if (verbose && (!second)) {
287 |                     cout << " reached " << max_num_samples << " samples after "
288 |                          << epoch << " epochs ..." << endl << std::flush;
289 |                 }
290 |                 break;
291 |             }
292 |             // For each document:
293 |             for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
294 | 
295 |                 // Document sentences are in [sent_index_first, sent_index_last)
296 |                 const auto sent_index_first = docs[doc];
297 |                 const auto sent_index_last = docs[doc + 1];
298 | 
299 |                 // At the beginning of the document previous index is the
300 |                 // start index.
301 |                 auto prev_start_index = sent_index_first;
302 | 
303 |                 // Remaining documents.
304 |                 auto num_remain_sent = sent_index_last - sent_index_first;
305 | 
306 |                 // Some bookkeeping
307 |                 if ((epoch == 0) && (!second)) {
308 |                     if (num_remain_sent == 0) {
309 |                         ++empty_docs;
310 |                     }
311 |                     if (num_remain_sent == 1) {
312 |                         ++one_sent_docs;
313 |                     }
314 |                 }
315 | 
316 |                 // Detect documents with long sentences.
317 |                 bool contains_long_sentence = false;
318 |                 if (num_remain_sent > 1) {
319 |                     for (auto sent_index=sent_index_first;
320 |                          sent_index < sent_index_last; ++sent_index) {
321 |                         if (sizes[sent_index] > LONG_SENTENCE_LEN){
322 |                             if ((epoch == 0) && (!second)) {
323 |                                 ++long_sent_docs;
324 |                             }
325 |                             contains_long_sentence = true;
326 |                             break;
327 |                         }
328 |                     }
329 |                 }
330 | 
331 |                 // If we have more than two sentences.
332 |                 if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
333 | 
334 |                     // Set values.
335 |                     auto seq_len = int32_t{0};
336 |                     auto num_sent = int32_t{0};
337 |                     auto target_seq_len = get_target_sample_len(short_seq_ratio,
338 |                                                                 max_seq_length,
339 |                                                                 rand32_gen);
340 | 
341 |                     // Loop through sentences.
342 |                     for (auto sent_index=sent_index_first;
343 |                          sent_index < sent_index_last; ++sent_index) {
344 | 
345 |                         // Add the size and number of sentences.
346 |                         seq_len += sizes[sent_index];
347 |                         ++num_sent;
348 |                         --num_remain_sent;
349 | 
350 |                         // If we have reached the target length.
351 |                         // and if not only one sentence is left in the document.
352 |                         // and if we have at least two sentences.
353 |                         // or if we have reached the end of the document.
354 |                         if (((seq_len >= target_seq_len) &&
355 |                              (num_remain_sent > 1) &&
356 |                              (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
357 | 
358 |                             // Check for overflow.
359 |                             if ((3 * map_index + 2) >
360 |                                 std::numeric_limits<int64_t>::max()) {
361 |                                 cout << "number of samples exceeded maximum "
362 |                                      << "allowed by type int64: "
363 |                                      << std::numeric_limits<int64_t>::max()
364 |                                      << endl;
365 |                                 throw std::overflow_error("Number of samples");
366 |                             }
367 | 
368 |                             // Populate the map.
369 |                             if (second) {
370 |                                 const auto map_index_0 = 3 * map_index;
371 |                                 maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
372 |                                 maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
373 |                                 maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
374 |                             }
375 | 
376 |                             // Update indices / counters.
377 |                             ++map_index;
378 |                             prev_start_index = sent_index + 1;
379 |                             target_seq_len = get_target_sample_len(short_seq_ratio,
380 |                                                                    max_seq_length,
381 |                                                                    rand32_gen);
382 |                             seq_len = 0;
383 |                             num_sent = 0;
384 |                         }
385 | 
386 |                     } // for (auto sent_index=sent_index_first; ...
387 |                 } // if (num_remain_sent > 1) {
388 |             } // for (int doc=0; doc < num_docs; ++doc) {
389 |         } // for (int epoch=0; epoch < num_epochs; ++epoch) {
390 | 
391 |         if (!second) {
392 |             if (verbose) {
393 |                 cout << " number of empty documents: " << empty_docs <<
394 |                     endl << std::flush;
395 |                 cout << " number of documents with one sentence: " <<
396 |                     one_sent_docs << endl << std::flush;
397 |                 cout << " number of documents with long sentences: " <<
398 |                     long_sent_docs << endl << std::flush;
399 |                 cout << " will create mapping for " << map_index <<
400 |                     " samples" << endl << std::flush;
401 |             }
402 |             assert(maps == NULL);
403 |             assert(num_samples < 0);
404 |             maps = new DocIdx[3*map_index];
405 |             num_samples = static_cast<int64_t>(map_index);
406 |         }
407 | 
408 |     } // for (int iteration=0; iteration < 2; ++iteration) {
409 | 
410 |     // Shuffle.
411 |     // We need a 64 bit random number generator as we might have more
412 |     // than 2 billion samples.
413 |     std::mt19937_64 rand64_gen(seed + 1);
414 |     for (auto i=(num_samples - 1); i > 0; --i) {
415 |         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
416 |         const auto i0 = 3 * i;
417 |         const auto j0 = 3 * j;
418 |         // Swap values.
419 |         swap(maps[i0], maps[j0]);
420 |         swap(maps[i0 + 1], maps[j0 + 1]);
421 |         swap(maps[i0 + 2], maps[j0 + 2]);
422 |     }
423 | 
424 |     // Method to deallocate memory.
425 |     py::capsule free_when_done(maps, [](void *mem_) {
426 |         DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
427 |         delete[] mem;
428 |     });
429 | 
430 |     // Return the numpy array.
431 |     const auto byte_size = sizeof(DocIdx);
432 |     return py::array(std::vector<int64_t>{num_samples, 3}, // shape
433 |                      {3*byte_size, byte_size}, // C-style contiguous strides
434 |                      maps, // the data pointer
435 |                      free_when_done); // numpy array references
436 | 
437 | }
438 | 
439 | 
440 | py::array build_mapping(const py::array_t<int64_t>& docs_,
441 |                         const py::array_t<int32_t>& sizes_,
442 |                         const int num_epochs,
443 |                         const uint64_t max_num_samples,
444 |                         const int max_seq_length,
445 |                         const double short_seq_prob,
446 |                         const int seed,
447 |                         const bool verbose,
448 |                         const int32_t min_num_sent) {
449 | 
450 |     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
451 |         if (verbose) {
452 |             cout << " using uint64 for data mapping..." << endl << std::flush;
453 |         }
454 |         return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
455 |                                             max_num_samples, max_seq_length,
456 |                                             short_seq_prob, seed, verbose,
457 |                                             min_num_sent);
458 |     } else {
459 |         if (verbose) {
460 |             cout << " using uint32 for data mapping..." << endl << std::flush;
461 |         }
462 |         return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
463 |                                             max_num_samples, max_seq_length,
464 |                                             short_seq_prob, seed, verbose,
465 |                                             min_num_sent);
466 |     }
467 | }
468 | 
469 | template<typename DocIdx>
470 | py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
471 |                                     const py::array_t<int32_t>& sizes_,
472 |                                     const py::array_t<int32_t>& titles_sizes_,
473 |                                     const int32_t num_epochs,
474 |                                     const uint64_t max_num_samples,
475 |                                     const int32_t max_seq_length,
476 |                                     const int32_t seed,
477 |                                     const bool verbose,
478 |                                     const bool use_one_sent_blocks) {
479 |     /* Build a mapping of (start-index, end-index, sequence-length) where
480 |        start and end index are the indices of the sentences in the sample
481 |        and sequence-length is the target sequence length.
482 |     */
483 | 
484 |     // Consistency checks.
485 |     assert(num_epochs > 0);
486 |     assert(max_seq_length > 1);
487 |     assert(seed > 0);
488 | 
489 |     // Remove bound checks.
490 |     auto docs = docs_.unchecked<1>();
491 |     auto sizes = sizes_.unchecked<1>();
492 |     auto titles_sizes = titles_sizes_.unchecked<1>();
493 | 
494 |     if (verbose) {
495 |         const auto sent_start_index = docs[0];
496 |         const auto sent_end_index = docs[docs_.shape(0) - 1];
497 |         const auto num_sentences = sent_end_index - sent_start_index;
498 |         cout << " using:" << endl << std::flush;
499 |         cout << " number of documents: " << docs_.shape(0) - 1 <<
500 |             endl << std::flush;
501 |         cout << " sentences range: [" << sent_start_index <<
502 |             ", " << sent_end_index << ")" << endl << std::flush;
503 |         cout << " total number of sentences: " << num_sentences <<
504 |             endl << std::flush;
505 |         cout << " number of epochs: " << num_epochs <<
506 |             endl << std::flush;
507 |         cout << " maximum number of samples: " << max_num_samples <<
508 |             endl << std::flush;
509 |         cout << " maximum sequence length: " << max_seq_length <<
510 |             endl << std::flush;
511 |         cout << " seed: " << seed << endl <<
512 |             std::flush;
513 |     }
514 | 
515 |     // Mapping and its length (1D).
516 |     int64_t num_samples = -1;
517 |     DocIdx* maps = NULL;
518 | 
519 |     // Acceptable number of sentences per block.
520 |     int min_num_sent = 2;
521 |     if (use_one_sent_blocks) {
522 |         min_num_sent = 1;
523 |     }
524 | 
525 |     // Perform two iterations, in the first iteration get the size
526 |     // and allocate memory and in the second iteration populate the map.
527 |     bool second = false;
528 |     for (int32_t iteration=0; iteration<2; ++iteration) {
529 | 
530 |         // Set the flag on second iteration.
531 |         second = (iteration == 1);
532 | 
533 |         // Current map index.
534 |         uint64_t map_index = 0;
535 | 
536 |         uint64_t empty_docs = 0;
537 |         uint64_t one_sent_docs = 0;
538 |         uint64_t long_sent_docs = 0;
539 |         // For each epoch:
540 |         for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
541 |             // assign every block a unique id
542 |             int32_t block_id = 0;
543 | 
544 |             if (map_index >= max_num_samples) {
545 |                 if (verbose && (!second)) {
546 |                     cout << " reached " << max_num_samples << " samples after "
547 |                          << epoch << " epochs ..." << endl << std::flush;
548 |                 }
549 |                 break;
550 |             }
551 |             // For each document:
552 |             for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
553 | 
554 |                 // Document sentences are in [sent_index_first, sent_index_last)
555 |                 const auto sent_index_first = docs[doc];
556 |                 const auto sent_index_last = docs[doc + 1];
557 |                 const auto target_seq_len = max_seq_length - titles_sizes[doc];
558 | 
559 |                 // At the beginning of the document previous index is the
560 |                 // start index.
561 |                 auto prev_start_index = sent_index_first;
562 | 
563 |                 // Remaining documents.
564 |                 auto num_remain_sent = sent_index_last - sent_index_first;
565 | 
566 |                 // Some bookkeeping
567 |                 if ((epoch == 0) && (!second)) {
568 |                     if (num_remain_sent == 0) {
569 |                         ++empty_docs;
570 |                     }
571 |                     if (num_remain_sent == 1) {
572 |                         ++one_sent_docs;
573 |                     }
574 |                 }
575 |                 // Detect documents with long sentences.
576 |                 bool contains_long_sentence = false;
577 |                 if (num_remain_sent >= min_num_sent) {
578 |                     for (auto sent_index=sent_index_first;
579 |                          sent_index < sent_index_last; ++sent_index) {
580 |                         if (sizes[sent_index] > LONG_SENTENCE_LEN){
581 |                             if ((epoch == 0) && (!second)) {
582 |                                 ++long_sent_docs;
583 |                             }
584 |                             contains_long_sentence = true;
585 |                             break;
586 |                         }
587 |                     }
588 |                 }
589 |                 // If we have enough sentences and no long sentences.
590 |                 if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
591 | 
592 |                     // Set values.
593 |                     auto seq_len = int32_t{0};
594 |                     auto num_sent = int32_t{0};
595 | 
596 |                     // Loop through sentences.
597 |                     for (auto sent_index=sent_index_first;
598 |                          sent_index < sent_index_last; ++sent_index) {
599 | 
600 |                         // Add the size and number of sentences.
601 |                         seq_len += sizes[sent_index];
602 |                         ++num_sent;
603 |                         --num_remain_sent;
604 | 
605 |                         // If we have reached the target length.
606 |                         // and there are an acceptable number of sentences left
607 |                         // and if we have at least the minimum number of sentences.
608 |                         // or if we have reached end of the document.
609 |                         if (((seq_len >= target_seq_len) &&
610 |                              (num_remain_sent >= min_num_sent) &&
611 |                              (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
612 | 
613 |                             // Populate the map.
614 |                             if (second) {
615 |                                 const auto map_index_0 = 4 * map_index;
616 |                                 // Each sample has 4 items: the starting sentence index, ending sentence index,
617 |                                 // the index of the document from which the block comes (used for fetching titles)
618 |                                 // and the unique id of the block (used for creating block indexes)
619 | 
620 |                                 maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
621 |                                 maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
622 |                                 maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
623 |                                 maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
624 |                             }
625 | 
626 |                             // Update indices / counters.
627 |                             ++map_index;
628 |                             ++block_id;
629 |                             prev_start_index = sent_index + 1;
630 |                             seq_len = 0;
631 |                             num_sent = 0;
632 |                         }
633 |                     } // for (auto sent_index=sent_index_first; ...
634 |                 } // if (num_remain_sent > 1) {
635 |             } // for (int doc=0; doc < num_docs; ++doc) {
636 |         } // for (int epoch=0; epoch < num_epochs; ++epoch) {
637 | 
638 |         if (!second) {
639 |             if (verbose) {
640 |                 cout << " number of empty documents: " << empty_docs <<
641 |                     endl << std::flush;
642 |                 cout << " number of documents with one sentence: " <<
643 |                     one_sent_docs << endl << std::flush;
644 |                 cout << " number of documents with long sentences: " <<
645 |                     long_sent_docs << endl << std::flush;
646 |                 cout << " will create mapping for " << map_index <<
647 |                     " samples" << endl << std::flush;
648 |             }
649 |             assert(maps == NULL);
650 |             assert(num_samples < 0);
651 |             maps = new DocIdx[4*map_index];
652 |             num_samples = static_cast<int64_t>(map_index);
653 |         }
654 | 
655 |     } // for (int iteration=0; iteration < 2; ++iteration) {
656 | 
657 |     // Shuffle.
658 |     // We need a 64 bit random number generator as we might have more
659 |     // than 2 billion samples.
660 |     std::mt19937_64 rand64_gen(seed + 1);
661 |     for (auto i=(num_samples - 1); i > 0; --i) {
662 |         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
663 |         const auto i0 = 4 * i;
664 |         const auto j0 = 4 * j;
665 |         // Swap values.
666 |         swap(maps[i0], maps[j0]);
667 |         swap(maps[i0 + 1], maps[j0 + 1]);
668 |         swap(maps[i0 + 2], maps[j0 + 2]);
669 |         swap(maps[i0 + 3], maps[j0 + 3]);
670 |     }
671 | 
672 |     // Method to deallocate memory.
673 |     py::capsule free_when_done(maps, [](void *mem_) {
674 |         DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
675 |         delete[] mem;
676 |     });
677 | 
678 |     // Return the numpy array.
679 |     const auto byte_size = sizeof(DocIdx);
680 |     return py::array(std::vector<int64_t>{num_samples, 4}, // shape
681 |                      {4*byte_size, byte_size}, // C-style contiguous strides
682 |                      maps, // the data pointer
683 |                      free_when_done); // numpy array references
684 | 
685 | }
686 | 
687 | py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
688 |                                const py::array_t<int32_t>& sizes_,
689 |                                const py::array_t<int32_t>& titles_sizes_,
690 |                                const int num_epochs,
691 |                                const uint64_t max_num_samples,
692 |                                const int max_seq_length,
693 |                                const int seed,
694 |                                const bool verbose,
695 |                                const bool use_one_sent_blocks) {
696 | 
697 |     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
698 |         if (verbose) {
699 |             cout << " using uint64 for data mapping..." << endl << std::flush;
700 |         }
701 |         return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
702 |             num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
703 |     } else {
704 |         if (verbose) {
705 |             cout << " using uint32 for data mapping..." << endl << std::flush;
706 |         }
707 |         return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
708 |             num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
709 |     }
710 | }
711 | 
712 | PYBIND11_MODULE(helpers, m) {
713 |     m.def("build_mapping", &build_mapping);
714 |     m.def("build_blocks_mapping", &build_blocks_mapping);
715 |     m.def("build_sample_idx", &build_sample_idx);
716 |     m.def("build_blending_indices", &build_blending_indices);
717 | }
718 | 
--------------------------------------------------------------------------------
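Usage note (illustrative, not part of the repository): the pybind11 module above is what the Makefile in distill_bloom/dataset/megatron/ compiles into a `helpers` extension. The sketch below shows one way the exported functions could be called directly with NumPy arrays. It assumes the extension has been built with `make` and is importable from the build directory, and that the array dtypes mirror how Megatron-style dataset code typically drives these helpers (uint8/int64 for the blending outputs, int32 for document sizes and order); the toy sizes and weights are made up for demonstration only.

import numpy as np

import helpers  # hypothetical import of the compiled pybind11 extension built from helpers.cpp

# build_blending_indices: pick, sample by sample, which of two datasets to draw from
# so that the realized ratio tracks the requested 0.7 / 0.3 weights.
size = 10
weights = np.array([0.7, 0.3], dtype=np.float64)
dataset_index = np.zeros(size, dtype=np.uint8)          # filled in-place: dataset id per sample
dataset_sample_index = np.zeros(size, dtype=np.int64)   # filled in-place: index within that dataset
helpers.build_blending_indices(dataset_index, dataset_sample_index, weights,
                               len(weights), size, True)

# build_sample_idx: map fixed-length training samples onto a flattened document stream.
sizes = np.array([5, 8, 4], dtype=np.int32)       # token count per document
doc_idx = np.arange(len(sizes), dtype=np.int32)   # document visit order for one epoch
seq_length, num_epochs = 6, 1
tokens_per_epoch = int(sizes.sum())
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                      num_epochs, tokens_per_epoch)

# sample_idx has shape [num_samples + 1, 2]; row i holds (document index, offset) for the
# start of sample i, so consecutive rows delimit one training sample of seq_length tokens.
print(dataset_index, dataset_sample_index)
print(sample_idx)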