├── src ├── __init__.py ├── __pycache__ │ ├── utils.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── xl_wrapper.cpython-37.pyc │ └── download_utils.cpython-37.pyc ├── fp16 │ ├── __pycache__ │ │ ├── fp16.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── fp16util.cpython-37.pyc │ │ └── loss_scaler.cpython-37.pyc │ ├── __init__.py │ ├── fp16util.py │ └── loss_scaler.py ├── mpu │ ├── __pycache__ │ │ ├── data.cpython-37.pyc │ │ ├── grads.cpython-37.pyc │ │ ├── utils.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── layers.cpython-37.pyc │ │ ├── mappings.cpython-37.pyc │ │ ├── random.cpython-37.pyc │ │ ├── initialize.cpython-37.pyc │ │ ├── cross_entropy.cpython-37.pyc │ │ └── transformer.cpython-37.pyc │ ├── __init__.py │ ├── utils.py │ ├── grads.py │ ├── data.py │ ├── mappings.py │ ├── cross_entropy.py │ ├── initialize.py │ ├── layers.py │ ├── random.py │ └── transformer.py ├── model │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── distributed.cpython-37.pyc │ │ └── gpt3_modeling.cpython-37.pyc │ ├── __init__.py │ ├── distributed.py │ └── gpt3_modeling.py ├── deepspeed_config │ ├── gpt3_large_2048.json │ ├── gpt3_small.json │ ├── gpt3_small_2048.json │ ├── gpt3_medium_2048.json │ ├── gpt3_large.json │ ├── gpt2_small.json │ ├── gpt3_small_sparse.json │ ├── gpt3_small_sparse_2048.json │ ├── gpt3_xl_sparse.json │ └── gpt3_xl_sparse_2048.json ├── data_utils │ ├── __init__.py │ ├── lazy_loader.py │ └── file_utils.py ├── download_utils.py ├── learning_rates.py ├── gpt3_data_loader.py ├── dataset_rugpt3.py ├── xl_wrapper.py ├── arguments.py └── utils.py └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/fp16/__pycache__/fp16.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/fp16.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/grads.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/grads.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/xl_wrapper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/xl_wrapper.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/mappings.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/mappings.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/random.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/random.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/download_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/download_utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/fp16/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/fp16/__pycache__/fp16util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/fp16util.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/initialize.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/initialize.cpython-37.pyc -------------------------------------------------------------------------------- 
/src/fp16/__pycache__/loss_scaler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/loss_scaler.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/distributed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/model/__pycache__/distributed.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/cross_entropy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/cross_entropy.cpython-37.pyc -------------------------------------------------------------------------------- /src/mpu/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/gpt3_modeling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/model/__pycache__/gpt3_modeling.cpython-37.pyc -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_large_2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "fp16": { 4 | "enabled": false, 5 | "loss_scale": 0, 6 | "loss_scale_window": 2000, 7 | "min_loss_scale": 0.0 8 | }, 9 | "zero_optimization": { 10 | "stage": 0, 11 | "reduce_bucket_size": 50000000 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 8, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage":2, 15 | "reduce_bucket_size": 50000000 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_small_2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 2, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "reduce_bucket_size": 50000000 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_medium_2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | 
"gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "reduce_bucket_size": 50000000, 16 | "overlap_comm": true 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 20, 3 | "gradient_accumulation_steps": 5, 4 | "steps_per_print": 100, 5 | "zero_optimization": { 6 | "stage": 0 7 | }, 8 | "zero_allow_untested_optimizer": true, 9 | "gradient_clipping": 1.0, 10 | "fp16": { 11 | "enabled": true, 12 | "loss_scale": 0, 13 | "loss_scale_window": 1000, 14 | "hysteresis": 2, 15 | "min_loss_scale": 1 16 | }, 17 | "activation_checkpointing": { 18 | "partition_activations": false, 19 | "contiguous_memory_optimization": false 20 | }, 21 | "wall_clock_breakdown": false 22 | } 23 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt2_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 168, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 1000, 5 | "zero_optimization": { 6 | "stage": 2 7 | }, 8 | "zero_allow_untested_optimizer": true, 9 | "gradient_clipping": 1.0, 10 | "fp16": { 11 | "enabled": true, 12 | "loss_scale": 0, 13 | "loss_scale_window": 1000, 14 | "hysteresis": 2, 15 | "min_loss_scale": 1 16 | }, 17 | "activation_checkpointing": { 18 | "partition_activations": false, 19 | "contiguous_memory_optimization": false 20 | }, 21 | "wall_clock_breakdown": false 22 | } 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### ruGPT3 solution of RSSE 2 | 3 | 4 | Solution of [RuSimpleSentEval](https://github.com/dialogue-evaluation/RuSimpleSentEval), competition in Dialogue2021. 5 | 6 | Solution based on RuGPT-3XL. Model was tuned on train data. 7 | 8 | Our approach has achieved second place on the public leaderboard and fifth place on the privateleaderboard. 9 | It reaches about 37 SARI score. 10 | 11 | Download pre-trained XL model [here](https://disk.yandex.ru/d/dd7tM93w4g-14g) and put in folder `model`. 12 | 13 | Example of usage is [here](./Simplification%20with%20ruGPT3.ipynb) 14 | 15 | Read about all experiments in the article. 
16 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_small_sparse.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 16, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "reduce_bucket_size": 50000000 16 | }, 17 | "sparse_attention": { 18 | "mode": "fixed", 19 | "block": 16, 20 | "different_layout_per_head": true, 21 | "num_local_blocks": 8, 22 | "num_global_blocks": 1, 23 | "attention": "unidirectional", 24 | "horizontal_global_attention": false, 25 | "num_different_global_patterns": 8 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_small_sparse_2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage":1, 15 | "reduce_bucket_size": 500000000 16 | }, 17 | "sparse_attention": { 18 | "mode": "fixed", 19 | "block": 16, 20 | "different_layout_per_head": true, 21 | "num_local_blocks": 8, 22 | "num_global_blocks": 1, 23 | "attention": "unidirectional", 24 | "horizontal_global_attention": false, 25 | "num_different_global_patterns": 8 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """utils for creating datasets""" 16 | from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader 17 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_xl_sparse.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 8, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "reduce_bucket_size": 50000000, 16 | "overlap_comm": true 17 | }, 18 | "sparse_attention": { 19 | "mode": "fixed", 20 | "block": 16, 21 | "different_layout_per_head": true, 22 | "num_local_blocks": 8, 23 | "num_global_blocks": 1, 24 | "attention": "unidirectional", 25 | "horizontal_global_attention": false, 26 | "num_different_global_patterns": 8 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/deepspeed_config/gpt3_xl_sparse_2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 2, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 1.0, 6 | "fp16": { 7 | "enabled": true, 8 | "loss_scale": 0, 9 | "loss_scale_window": 2000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 0.0 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "reduce_bucket_size": 50000000, 16 | "overlap_comm": true 17 | }, 18 | "sparse_attention": { 19 | "mode": "fixed", 20 | "block": 16, 21 | "different_layout_per_head": true, 22 | "num_local_blocks": 8, 23 | "num_global_blocks": 1, 24 | "attention": "unidirectional", 25 | "horizontal_global_attention": false, 26 | "num_different_global_patterns": 8 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .distributed import * 17 | from .gpt3_modeling import gpt3_get_params_for_weight_decay_optimization 18 | from .gpt3_modeling import GPT3Model 19 | -------------------------------------------------------------------------------- /src/fp16/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from .fp16util import ( 16 | BN_convert_float, 17 | network_to_half, 18 | prep_param_lists, 19 | model_grads_to_master_grads, 20 | master_params_to_model_params, 21 | tofp16, 22 | to_python_float, 23 | clip_grad_norm, 24 | convert_module, 25 | convert_network, 26 | FP16Model, 27 | ) 28 | 29 | from .fp16 import * 30 | from .loss_scaler import * 31 | -------------------------------------------------------------------------------- /src/mpu/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Model parallel utility interface.""" 17 | 18 | from .cross_entropy import vocab_parallel_cross_entropy 19 | 20 | from .data import broadcast_data 21 | 22 | from .grads import clip_grad_norm 23 | 24 | from .initialize import destroy_model_parallel 25 | from .initialize import get_data_parallel_group 26 | from .initialize import get_data_parallel_rank 27 | from .initialize import get_data_parallel_world_size 28 | from .initialize import get_model_parallel_group 29 | from .initialize import get_model_parallel_rank 30 | from .initialize import get_model_parallel_src_rank 31 | from .initialize import get_model_parallel_world_size 32 | from .initialize import initialize_model_parallel 33 | from .initialize import model_parallel_is_initialized 34 | 35 | from .layers import ColumnParallelLinear 36 | from .layers import ParallelEmbedding 37 | from .layers import RowParallelLinear 38 | from .layers import VocabParallelEmbedding 39 | 40 | from .mappings import copy_to_model_parallel_region 41 | from .mappings import gather_from_model_parallel_region 42 | from .mappings import reduce_from_model_parallel_region 43 | from .mappings import scatter_to_model_parallel_region 44 | 45 | from .random import checkpoint 46 | from .random import partition_activations_in_checkpoint 47 | from .random import get_cuda_rng_tracker 48 | from .random import model_parallel_cuda_manual_seed 49 | 50 | from .transformer import GPT3ParallelTransformer 51 | from .transformer import LayerNorm 52 | -------------------------------------------------------------------------------- /src/mpu/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | 20 | def ensure_divisibility(numerator, denominator): 21 | """Ensure that numerator is divisible by the denominator.""" 22 | assert numerator % denominator == 0, '{} is not divisible by {}'.format( 23 | numerator, denominator) 24 | 25 | 26 | def divide(numerator, denominator): 27 | """Ensure that numerator is divisible by the denominator and return 28 | the division value.""" 29 | ensure_divisibility(numerator, denominator) 30 | return numerator // denominator 31 | 32 | 33 | def split_tensor_along_last_dim(tensor, num_partitions, 34 | contiguous_split_chunks=False): 35 | """Split a tensor along its last dimension. 36 | Arguments: 37 | tensor: input tensor. 38 | num_partitions: number of partitions to split the tensor 39 | contiguous_split_chunks: If True, make each chunk contiguous 40 | in memory. 41 | """ 42 | # Get the size and dimension. 43 | last_dim = tensor.dim() - 1 44 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 45 | # Split. 46 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 47 | # Note: torch.split does not create contiguous tensors by default. 48 | if contiguous_split_chunks: 49 | return tuple(chunk.contiguous() for chunk in tensor_list) 50 | 51 | return tensor_list 52 | 53 | 54 | class VocabUtility: 55 | """Split the vocabulary into `world_size` chunks amd return the 56 | first and last index of the vocabulary belonging to the `rank` 57 | partition: Note that indecies in [fist, last)""" 58 | 59 | @staticmethod 60 | def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, 61 | rank, world_size): 62 | index_f = rank * per_partition_vocab_size 63 | index_l = index_f + per_partition_vocab_size 64 | return index_f, index_l 65 | 66 | @staticmethod 67 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): 68 | per_partition_vocab_size = divide(global_vocab_size, world_size) 69 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 70 | per_partition_vocab_size, rank, world_size) 71 | -------------------------------------------------------------------------------- /src/download_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers.file_utils import ( 4 | cached_path, 5 | hf_bucket_url, 6 | is_remote_url, 7 | ) 8 | from transformers.utils import logging 9 | 10 | logger = logging.get_logger(__name__) 11 | WEIGHTS_NAME = "mp_rank_00_model_states.pt" 12 | DEEPSPEED_CONFIG_NAME = "deepspeed_config.json" 13 | 14 | 15 | def download_model_files(pretrained_model_name_or_path): 16 | weights_path = download_file_from_hf(pretrained_model_name_or_path, WEIGHTS_NAME) 17 | deepspeed_config_path = download_file_from_hf(pretrained_model_name_or_path, DEEPSPEED_CONFIG_NAME) 18 | return weights_path, deepspeed_config_path 19 | 20 | 21 | def download_file_from_hf(pretrained_model_name_or_path: str, file_name: str) -> str: 22 | # Load model 23 | if pretrained_model_name_or_path is not None: 24 | if os.path.isdir(pretrained_model_name_or_path): 25 | if 
os.path.isfile(os.path.join(pretrained_model_name_or_path, file_name)): 26 | # Load from a PyTorch checkpoint 27 | archive_file = os.path.join(pretrained_model_name_or_path, file_name) 28 | else: 29 | raise EnvironmentError( 30 | "Error no file named {} found in directory {}".format( 31 | file_name, 32 | pretrained_model_name_or_path, 33 | ) 34 | ) 35 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): 36 | archive_file = pretrained_model_name_or_path 37 | else: 38 | archive_file = hf_bucket_url( 39 | pretrained_model_name_or_path, 40 | filename=file_name, 41 | revision=None, 42 | mirror=None, 43 | ) 44 | 45 | try: 46 | # Load from URL or cache if already cached 47 | resolved_archive_file = cached_path( 48 | archive_file, 49 | cache_dir=None, 50 | force_download=False, 51 | proxies=None, 52 | resume_download=False, 53 | local_files_only=False, 54 | ) 55 | except EnvironmentError as err: 56 | logger.error(err) 57 | msg = ( 58 | f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" 59 | f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on" 60 | f"'https://huggingface.co/models'\n\n" 61 | f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a" 62 | f"file named one of {file_name}.\n\n" 63 | ) 64 | raise EnvironmentError(msg) 65 | 66 | if resolved_archive_file == archive_file: 67 | logger.info("loading weights file {}".format(archive_file)) 68 | else: 69 | logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file)) 70 | else: 71 | resolved_archive_file = None 72 | 73 | return resolved_archive_file 74 | -------------------------------------------------------------------------------- /src/mpu/grads.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | 21 | import torch 22 | from torch._six import inf 23 | 24 | from .initialize import get_model_parallel_group 25 | from .initialize import get_model_parallel_rank 26 | 27 | 28 | def clip_grad_norm(parameters, max_norm, norm_type=2): 29 | """Clips gradient norm of an iterable of parameters. 30 | 31 | This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and 32 | added functionality to handle model parallel parameters. Note that 33 | the gradients are modified in place. 34 | 35 | Arguments: 36 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 37 | single Tensor that will have gradients normalized 38 | max_norm (float or int): max norm of the gradients 39 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 40 | infinity norm. 41 | 42 | Returns: 43 | Total norm of the parameters (viewed as a single vector). 
44 | """ 45 | if isinstance(parameters, torch.Tensor): 46 | parameters = [parameters] 47 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 48 | max_norm = float(max_norm) 49 | norm_type = float(norm_type) 50 | if norm_type == inf: 51 | total_norm = max(p.grad.data.abs().max() for p in parameters) 52 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 53 | # Take max across all GPUs. 54 | torch.distributed.all_reduce(total_norm_cuda, 55 | op=torch.distributed.ReduceOp.MAX, 56 | group=get_model_parallel_group()) 57 | total_norm = total_norm_cuda[0].item() 58 | else: 59 | total_norm = 0 60 | for p in parameters: 61 | if p.model_parallel or (get_model_parallel_rank() == 0): 62 | param_norm = p.grad.data.norm(norm_type) 63 | total_norm += param_norm.item() ** norm_type 64 | # Sum across all model parallel GPUs. 65 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) 66 | torch.distributed.all_reduce(total_norm_cuda, 67 | op=torch.distributed.ReduceOp.SUM, 68 | group=get_model_parallel_group()) 69 | total_norm = total_norm_cuda[0].item() ** (1. / norm_type) 70 | clip_coef = max_norm / (total_norm + 1e-6) 71 | if clip_coef < 1: 72 | for p in parameters: 73 | p.grad.data.mul_(clip_coef) 74 | return total_norm 75 | -------------------------------------------------------------------------------- /src/learning_rates.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch DataLoader for TFRecords""" 16 | 17 | import math 18 | 19 | import torch 20 | from torch.optim.lr_scheduler import _LRScheduler 21 | 22 | 23 | class AnnealingLR(_LRScheduler): 24 | """Anneals the learning rate from start to zero along a cosine curve.""" 25 | 26 | DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] 27 | 28 | def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, min_lr=1e-6): 29 | self.optimizer = optimizer 30 | self.start_lr = start_lr 31 | self.warmup_iter = warmup_iter 32 | self.num_iters = last_iter + 1 33 | self.end_iter = num_iters 34 | self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None 35 | self.step(self.num_iters) 36 | self.min_lr = min_lr 37 | self._last_lr = start_lr 38 | self._min_reached = False 39 | if torch.distributed.get_rank() == 0: 40 | print('learning rate decaying', decay_style) 41 | 42 | def get_lr(self): 43 | # https://openreview.net/pdf?id=BJYwwY9ll pg. 
4 44 | if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: 45 | new_lr = float(self.start_lr) * self.num_iters / self.warmup_iter 46 | else: 47 | if self.decay_style == self.DECAY_STYLES[0]: 48 | lr = self.start_lr * ((self.end_iter - (self.num_iters - self.warmup_iter)) / self.end_iter) 49 | new_lr = max(self.min_lr, lr) 50 | elif self.decay_style == self.DECAY_STYLES[1]: 51 | new_lr = self.start_lr / 2.0 * ( 52 | math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1) 53 | if new_lr <= self.min_lr or self._min_reached or self.num_iters > self.end_iter: 54 | self._min_reached = True 55 | new_lr = self.min_lr 56 | elif self.decay_style == self.DECAY_STYLES[2]: 57 | # TODO: implement exponential decay 58 | new_lr = self.start_lr 59 | else: 60 | new_lr = self.start_lr 61 | self._last_lr = new_lr 62 | return new_lr 63 | 64 | def step(self, step_num=None): 65 | if step_num is None: 66 | step_num = self.num_iters + 1 67 | self.num_iters = step_num 68 | new_lr = self.get_lr() 69 | for group in self.optimizer.param_groups: 70 | group['lr'] = new_lr 71 | 72 | def state_dict(self): 73 | sd = { 74 | 'start_lr': self.start_lr, 75 | 'warmup_iter': self.warmup_iter, 76 | 'num_iters': self.num_iters, 77 | 'decay_style': self.decay_style, 78 | 'end_iter': self.end_iter 79 | } 80 | return sd 81 | 82 | def load_state_dict(self, sd): 83 | self.start_lr = sd['start_lr'] 84 | self.warmup_iter = sd['warmup_iter'] 85 | self.num_iters = sd['num_iters'] 86 | self.end_iter = sd['end_iter'] 87 | self.decay_style = sd['decay_style'] 88 | self.step(self.num_iters) 89 | -------------------------------------------------------------------------------- /src/mpu/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import get_model_parallel_group 19 | from .initialize import get_model_parallel_rank 20 | from .initialize import get_model_parallel_src_rank 21 | 22 | 23 | _MAX_DATA_DIM = 4 24 | 25 | 26 | def _check_data_types(keys, data, target_dtype): 27 | """Check that all the keys have the same target data type.""" 28 | for key in keys: 29 | assert data[key].dtype == target_dtype, '{} has data type {} which '\ 30 | 'is different than {}'.format(key, data[key].dtype, target_dtype) 31 | 32 | 33 | def _build_key_size_numel_dictionaries(keys, data): 34 | """Build the size on rank 0 and broadcast.""" 35 | max_dim = _MAX_DATA_DIM 36 | sizes = [0 for _ in range(max_dim) for _ in keys] 37 | 38 | # Pack the sizes on rank zero. 39 | if get_model_parallel_rank() == 0: 40 | offset = 0 41 | for key in keys: 42 | assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' 43 | size = data[key].size() 44 | for i, s in enumerate(size): 45 | sizes[i + offset] = s 46 | offset += max_dim 47 | 48 | # Move to GPU and broadcast. 
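# Only model-parallel rank 0 filled `sizes` above; the other ranks still hold
# zeros. Broadcasting the packed size list first lets every rank learn the
# shape of each tensor before the actual data broadcast in broadcast_data(),
# and the trailing zeros per key mark unused dimensions (the unpacking loop
# below stops at the first zero).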
49 | sizes_cuda = torch.cuda.LongTensor(sizes) 50 | torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(), 51 | group=get_model_parallel_group()) 52 | 53 | # Move back to cpu and unpack. 54 | sizes_cpu = sizes_cuda.cpu() 55 | key_size = {} 56 | key_numel = {} 57 | total_numel = 0 58 | offset = 0 59 | for key in keys: 60 | i = 0 61 | size = [] 62 | numel = 1 63 | while sizes_cpu[offset + i] > 0: 64 | this_size = sizes_cpu[offset + i] 65 | size.append(this_size) 66 | numel *= this_size 67 | i += 1 68 | key_size[key] = size 69 | key_numel[key] = numel 70 | total_numel += numel 71 | offset += max_dim 72 | 73 | return key_size, key_numel, total_numel 74 | 75 | 76 | def broadcast_data(keys, data, datatype): 77 | """Broadcast data from rank zero of each model parallel group to the 78 | members of the same model parallel group. 79 | 80 | Arguments: 81 | keys: list of keys in the data disctionary to be broadcasted 82 | data: data dictionary of string keys and cpu tensor values. 83 | datatype: torch data type of all tensors in data associated 84 | with keys. 85 | """ 86 | # Build (key, size) and (key, number of elements) dictionaries along 87 | # with the total number of elements on all ranks. 88 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, 89 | data) 90 | 91 | # Pack on rank zero. 92 | if get_model_parallel_rank() == 0: 93 | # Check that all keys have the same data type. 94 | _check_data_types(keys, data, datatype) 95 | # Flatten the data associated with the keys 96 | flatten_data = torch.cat( 97 | [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() 98 | else: 99 | flatten_data = torch.empty(total_numel, 100 | device=torch.cuda.current_device(), 101 | dtype=datatype) 102 | 103 | # Boradcast 104 | torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(), 105 | group=get_model_parallel_group()) 106 | 107 | # Unpack 108 | output = {} 109 | offset = 0 110 | for key in keys: 111 | size = key_size[key] 112 | numel = key_numel[key] 113 | output[key] = flatten_data.narrow(0, offset, numel).view(size) 114 | offset += numel 115 | 116 | return output 117 | -------------------------------------------------------------------------------- /src/mpu/mappings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from .initialize import get_model_parallel_group 19 | from .utils import split_tensor_along_last_dim 20 | 21 | 22 | def _reduce(input_): 23 | """All-reduce the the input tensor across model parallel group.""" 24 | group = get_model_parallel_group() 25 | 26 | # Bypass the function if we are using only 1 GPU. 27 | if torch.distributed.get_world_size(group=group) == 1: 28 | return input_ 29 | 30 | # All-reduce. 
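# The all-reduce below runs in place across the model-parallel group, so after
# it every rank holds the element-wise sum of the partial tensors -- this is
# how partial outputs of the row-/column-parallel linear layers are combined.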
31 | torch.distributed.all_reduce(input_, group=group) 32 | 33 | return input_ 34 | 35 | 36 | def _split(input_): 37 | """Split the tensor along its last dimension and keep the 38 | corresponding slice.""" 39 | group = get_model_parallel_group() 40 | 41 | # Bypass the function if we are using only 1 GPU. 42 | if torch.distributed.get_world_size(group=group) == 1: 43 | return input_ 44 | 45 | # Split along last dimension. 46 | world_size = torch.distributed.get_world_size(group=group) 47 | input_list = split_tensor_along_last_dim(input_, world_size) 48 | 49 | # Note: torch.split does not create contiguous tensors by default. 50 | rank = torch.distributed.get_rank(group=group) 51 | output = input_list[rank].contiguous() 52 | 53 | return output 54 | 55 | 56 | def _gather(input_): 57 | """Gather tensors and concatinate along the last dimension.""" 58 | group = get_model_parallel_group() 59 | 60 | # Bypass the function if we are using only 1 GPU. 61 | if torch.distributed.get_world_size(group=group) == 1: 62 | return input_ 63 | 64 | # Size and dimension. 65 | last_dim = input_.dim() - 1 66 | rank = torch.distributed.get_rank(group=group) 67 | world_size = torch.distributed.get_world_size(group=group) 68 | 69 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 70 | tensor_list[rank] = input_ 71 | torch.distributed.all_gather(tensor_list, input_, group=group) 72 | 73 | # Note: torch.cat already creates a contiguous tensor. 74 | output = torch.cat(tensor_list, dim=last_dim).contiguous() 75 | 76 | return output 77 | 78 | 79 | class _CopyToModelParallelRegion(torch.autograd.Function): 80 | """Pass the input to the model parallel region.""" 81 | 82 | @staticmethod 83 | def forward(ctx, input_): 84 | return input_ 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output): 88 | return _reduce(grad_output) 89 | 90 | 91 | class _ReduceFromModelParallelRegion(torch.autograd.Function): 92 | """All-redcue the input from the model parallel region.""" 93 | 94 | @staticmethod 95 | def forward(ctx, input_): 96 | return _reduce(input_) 97 | 98 | @staticmethod 99 | def backward(ctx, grad_output): 100 | return grad_output 101 | 102 | 103 | class _ScatterToModelParallelRegion(torch.autograd.Function): 104 | """Split the input and keep only the corresponding chuck to the rank.""" 105 | 106 | @staticmethod 107 | def forward(ctx, input_): 108 | return _split(input_) 109 | 110 | @staticmethod 111 | def backward(ctx, grad_output): 112 | return _gather(grad_output) 113 | 114 | 115 | class _GatherFromModelParallelRegion(torch.autograd.Function): 116 | """Gather the input from model parallel region and concatinate.""" 117 | 118 | @staticmethod 119 | def forward(ctx, input_): 120 | return _gather(input_) 121 | 122 | @staticmethod 123 | def backward(ctx, grad_output): 124 | return _split(grad_output) 125 | 126 | 127 | # ----------------- 128 | # Helper functions. 
129 | # ----------------- 130 | 131 | def copy_to_model_parallel_region(input_): 132 | return _CopyToModelParallelRegion.apply(input_) 133 | 134 | def reduce_from_model_parallel_region(input_): 135 | return _ReduceFromModelParallelRegion.apply(input_) 136 | 137 | def scatter_to_model_parallel_region(input_): 138 | return _ScatterToModelParallelRegion.apply(input_) 139 | 140 | def gather_from_model_parallel_region(input_): 141 | return _GatherFromModelParallelRegion.apply(input_) 142 | -------------------------------------------------------------------------------- /src/mpu/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | 19 | from .initialize import get_model_parallel_group 20 | from .initialize import get_model_parallel_rank 21 | from .initialize import get_model_parallel_world_size 22 | from .utils import VocabUtility 23 | 24 | 25 | class _VocabParallelCrossEntropy(torch.autograd.Function): 26 | 27 | @staticmethod 28 | def forward(ctx, vocab_parallel_logits, target): 29 | 30 | # Copy so the input remains unchanged. 31 | logits = vocab_parallel_logits.clone() 32 | # Maximum value along vocab dimension across all GPUs. 33 | logits_max = torch.max(logits, dim=-1)[0] 34 | torch.distributed.all_reduce(logits_max, 35 | op=torch.distributed.ReduceOp.MAX, 36 | group=get_model_parallel_group()) 37 | # Subtract the maximum value. 38 | logits.sub_(logits_max.unsqueeze(dim=-1)) 39 | # Sum of exponential of logits along vocab dimension across all GPUs. 40 | exp_logits = logits.exp() 41 | sum_exp_logits = exp_logits.sum(dim=-1) 42 | torch.distributed.all_reduce(sum_exp_logits, 43 | op=torch.distributed.ReduceOp.SUM, 44 | group=get_model_parallel_group()) 45 | 46 | # Get the partition's vocab indecies 47 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size 48 | partition_vocab_size = vocab_parallel_logits.size()[-1] 49 | rank = get_model_parallel_rank() 50 | world_size = get_model_parallel_world_size() 51 | vocab_start_index, vocab_end_index = get_vocab_range( 52 | partition_vocab_size, rank, world_size) 53 | 54 | # Create a mask of valid vocab ids (1 means it needs to be masked). 55 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index) 56 | masked_target = target.clone() - vocab_start_index 57 | masked_target[target_mask] = 0 58 | 59 | # Get predicted-logits = logits[target]. 60 | # For Simplicity, we convert logits to a 2-D tensor with size 61 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 
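# The advanced indexing below picks, for each token position, the logit of its
# target class re-indexed into this partition's vocab range. Positions whose
# target lives on another partition were clamped to index 0 above and are
# zeroed out again after the lookup, so the SUM all-reduce that follows leaves
# exactly the logit contributed by the rank that owns the target token.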
62 | logits_2d = logits.view(-1, partition_vocab_size) 63 | masked_target_1d = masked_target.view(-1) 64 | arange_1d = torch.arange(start=0, end=logits_2d.size()[0], 65 | device=logits_2d.device) 66 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] 67 | predicted_logits = predicted_logits_1d.view_as(target) 68 | predicted_logits[target_mask] = 0.0 69 | # All reduce is needed to get the chunks from other GPUs. 70 | torch.distributed.all_reduce(predicted_logits, 71 | op=torch.distributed.ReduceOp.SUM, 72 | group=get_model_parallel_group()) 73 | 74 | # Loss = log(sum(exp(logits))) - predicted-logit. 75 | loss = torch.log(sum_exp_logits) - predicted_logits 76 | 77 | # Store softmax, target-mask and masked-target for backward pass. 78 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) 79 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) 80 | 81 | return loss 82 | 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | 86 | # Retreive tensors from the forward path. 87 | softmax, target_mask, masked_target_1d = ctx.saved_tensors 88 | 89 | # All the inputs have softmax as thier gradient. 90 | grad_input = softmax 91 | # For simplicity, work with the 2D gradient. 92 | partition_vocab_size = softmax.size()[-1] 93 | grad_2d = grad_input.view(-1, partition_vocab_size) 94 | 95 | # Add the gradient from matching classes. 96 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0], 97 | device=grad_2d.device) 98 | grad_2d[arange_1d, masked_target_1d] -= ( 99 | 1.0 - target_mask.view(-1).float()) 100 | 101 | # Finally elementwise multiplication with the output gradients. 102 | grad_input.mul_(grad_output.unsqueeze(dim=-1)) 103 | 104 | return grad_input, None 105 | 106 | 107 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target): 108 | """Helper function for the cross entropy.""" 109 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) 110 | -------------------------------------------------------------------------------- /src/gpt3_data_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | 18 | import torch 19 | from torch.utils.data import BatchSampler, DataLoader 20 | 21 | from src import mpu 22 | from src.dataset_rugpt3 import RuGpt3TextDataset, RuGpt3DatasetArguments 23 | from src.utils import print_rank_0 24 | from transformers import GPT2Tokenizer 25 | 26 | 27 | class InfiniteDataLoader(DataLoader): 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | # Initialize an iterator over the dataset. 31 | self.dataset_iterator = super().__iter__() 32 | 33 | def __iter__(self): 34 | return self 35 | 36 | def __next__(self): 37 | try: 38 | batch = next(self.dataset_iterator) 39 | except StopIteration: 40 | # Dataset exhausted, use a new fresh iterator. 
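# Re-creating the iterator makes the loader an endless stream: the training
# loop can call next() for any number of iterations without handling epoch
# boundaries (ResumableBatchSampler below takes care of skipping batches that
# were already consumed before a restart).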
41 | self.dataset_iterator = super().__iter__() 42 | batch = next(self.dataset_iterator) 43 | return batch 44 | 45 | 46 | class ResumableBatchSampler(BatchSampler): 47 | start_iter = 0 48 | 49 | def __iter__(self): 50 | batch = [] 51 | i = 0 52 | for idx in self.sampler: 53 | batch.append(idx) 54 | if len(batch) == self.batch_size: 55 | if i >= self.start_iter: 56 | yield batch 57 | batch = [] 58 | i += 1 59 | if len(batch) > 0 and not self.drop_last: 60 | yield batch 61 | 62 | 63 | def make_gpt3_dataloaders(args): 64 | # Data parallel arguments 65 | world_size = mpu.get_data_parallel_world_size() 66 | rank = mpu.get_data_parallel_rank() 67 | # global_batch_size = args.batch_size * world_size 68 | num_workers = args.num_workers 69 | 70 | # data_dir = args.train_data_path if args.train_data_path else os.path.dirname(args.test_data_path) 71 | tokenizer_path = args.load_huggingface if args.load_huggingface else \ 72 | (args.tokenizer_path if args.tokenizer_path else os.path.join(os.path.dirname(args.train_data_path), 73 | '_tokenizer/')) 74 | print_rank_0('Load tokenizer from ' + tokenizer_path) 75 | tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) 76 | tokenizer.add_special_tokens({"bos_token": ""}) 77 | tokenizer.add_special_tokens({"eos_token": ""}) 78 | 79 | print("Add answer_sep:", args.answer_sep) 80 | tokenizer.add_tokens(args.answer_sep) 81 | 82 | print("Add start_sep", args.start_sep) 83 | tokenizer.add_tokens(args.start_sep) 84 | 85 | print("Add start_sep", args.end_sep) 86 | tokenizer.add_tokens(args.end_sep) 87 | 88 | eod_token = tokenizer.encoder[''] 89 | num_tokens = len(tokenizer) 90 | 91 | train_dataset_args = RuGpt3DatasetArguments( 92 | block_size=args.seq_length, max_files_load=args.max_files_per_process, overwrite_cache=args.overwrite_cache, 93 | tqdm=False) 94 | eval_dataset_args = RuGpt3DatasetArguments( 95 | block_size=args.seq_length, max_files_load=args.max_files_per_process, overwrite_cache=args.overwrite_cache, 96 | tqdm=True) 97 | 98 | def make_data_loader_(data_path, dataset_args): 99 | print_rank_0(f'Load RuGPT3 Dataset from {data_path}, {dataset_args.max_files_load} files per process') 100 | dataset = RuGpt3TextDataset( 101 | tokenizer=tokenizer, 102 | args=dataset_args, 103 | rank=rank, 104 | world_size=world_size, 105 | file_path=data_path, 106 | # cache_prefix=args.cache_prefix 107 | all_args=args 108 | ) 109 | # Use a simple sampler with distributed batch sampler. 110 | sampler = torch.utils.data.SequentialSampler(dataset) 111 | batch_sampler = ResumableBatchSampler(sampler=sampler, 112 | batch_size=args.batch_size, 113 | drop_last=True) 114 | 115 | return InfiniteDataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) 116 | 117 | train = make_data_loader_(args.train_data_path, train_dataset_args) if args.train_data_path else None 118 | valid = make_data_loader_(args.val_data_path, eval_dataset_args) if args.val_data_path else None 119 | test = make_data_loader_(args.test_data_path, eval_dataset_args) if args.test_data_path else None 120 | 121 | args.do_train = train is not None 122 | args.do_valid = valid is not None 123 | args.do_test = test is not None 124 | 125 | return (train, valid, test), num_tokens, eod_token, tokenizer 126 | -------------------------------------------------------------------------------- /src/model/distributed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 18 | import torch.distributed as dist 19 | from torch.nn.modules import Module 20 | from torch.autograd import Variable 21 | 22 | from src import mpu 23 | 24 | 25 | class DistributedDataParallel(Module): 26 | 27 | def __init__(self, module): 28 | super(DistributedDataParallel, self).__init__() 29 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 30 | 31 | self.module = module 32 | self.data_parallel_group = mpu.get_data_parallel_group() 33 | src_rank = mpu.get_model_parallel_rank() 34 | for p in self.module.parameters(): 35 | if torch.is_tensor(p): 36 | dist.broadcast(p, src_rank, group=self.data_parallel_group) 37 | 38 | def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): 39 | if(self.needs_reduction): 40 | self.needs_reduction = False 41 | buckets = {} 42 | for name, param in self.module.named_parameters(): 43 | if param.requires_grad and param.grad is not None: 44 | tp = (param.data.type()) 45 | if tp not in buckets: 46 | buckets[tp] = [] 47 | buckets[tp].append(param) 48 | if self.warn_on_half: 49 | if torch.cuda.HalfTensor in buckets: 50 | print("WARNING: gloo dist backend for half parameters may be extremely slow." 
+ 51 | " It is recommended to use the NCCL backend in this case.") 52 | self.warn_on_half = False 53 | for tp in buckets: 54 | bucket = buckets[tp] 55 | grads = [param.grad.data for param in bucket] 56 | coalesced = _flatten_dense_tensors(grads) 57 | if fp32_allreduce: 58 | coalesced = coalesced.float() 59 | if not no_scale and not reduce_after: 60 | coalesced /= dist.get_world_size(group=self.data_parallel_group) 61 | dist.all_reduce(coalesced, group=self.data_parallel_group) 62 | torch.cuda.synchronize() 63 | if not no_scale and reduce_after: 64 | coalesced /= dist.get_world_size(group=self.data_parallel_group) 65 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 66 | buf.copy_(synced) 67 | self.hook_handles = [] 68 | self.hooks = [] 69 | for param in list(self.module.parameters()): 70 | def allreduce_hook(*unused): 71 | Variable._execution_engine.queue_callback(allreduce_params) 72 | # handle = param.register_hook(allreduce_hook) 73 | #self.hooks.append(allreduce_hook) 74 | #self.hook_handles.append(handle) 75 | self.allreduce_params = allreduce_params 76 | 77 | def forward(self, *inputs, **kwargs): 78 | self.needs_reduction = True 79 | return self.module(*inputs, **kwargs) 80 | 81 | def state_dict(self, destination=None, prefix='', keep_vars=False): 82 | #[h.remove() for h in self.hook_handles] 83 | sd = self.module.state_dict(destination, prefix, keep_vars) 84 | # for handle, hook in zip(self.hook_handles, self.hooks): 85 | # d = handle.hooks_dict_ref() 86 | # d[handle.id] = hook 87 | 88 | return sd 89 | 90 | def load_state_dict(self, state_dict, strict=True): 91 | self.module.load_state_dict(state_dict, strict=strict) 92 | 93 | ''' 94 | def _sync_buffers(self): 95 | buffers = list(self.module._all_buffers()) 96 | if len(buffers) > 0: 97 | # cross-node buffer sync 98 | flat_buffers = _flatten_dense_tensors(buffers) 99 | dist.broadcast(flat_buffers, 0) 100 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 101 | buf.copy_(synced) 102 | def train(self, mode=True): 103 | # Clear NCCL communicator and CUDA event cache of the default group ID, 104 | # These cache will be recreated at the later call. This is currently a 105 | # work-around for a potential NCCL deadlock. 106 | if dist._backend == dist.dist_backend.NCCL: 107 | dist._clear_group_cache() 108 | super(DistributedDataParallel, self).train(mode) 109 | self.module.train(mode) 110 | ''' 111 | 112 | -------------------------------------------------------------------------------- /src/mpu/initialize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | """Model and data parallel groups.""" 18 | 19 | import torch 20 | 21 | from .utils import ensure_divisibility 22 | 23 | 24 | # Model parallel group that the current rank belongs to. 
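# Both globals below act as a process-wide registry: they stay None until
# initialize_model_parallel() fills them in, and the accessor functions in this
# module assert that they are set before returning them.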
25 | _MODEL_PARALLEL_GROUP = None 26 | # Data parallel group that the current rank belongs to. 27 | _DATA_PARALLEL_GROUP = None 28 | 29 | 30 | def initialize_model_parallel(model_parallel_size_): 31 | """ 32 | Initialize model data parallel groups. 33 | 34 | Arguments: 35 | model_parallel_size: number of GPUs used to parallelize the model. 36 | 37 | Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we 38 | use 2 GPUs to parallelize the model. The present function will 39 | create 4 model parallel groups and 2 data parallel groups as: 40 | 4 model parallel groups: 41 | [g0, g1], [g2, g3], [g4, g5], [g6, g7] 42 | 2 data parallel groups: 43 | [g0, g2, g4, g6], [g1, g3, g5, g7] 44 | Note that for efficiency, the caller should make sure adjacent ranks 45 | are on the same DGX box. For example if we are using 2 DGX-1 boxes 46 | with a total of 16 GPUs, ranks 0 to 7 belong to the first box and 47 | ranks 8 to 15 belong to the second box. 48 | """ 49 | if torch.distributed.get_rank() == 0: 50 | print('> initializing model parallel with size {}'.format( 51 | model_parallel_size_)) 52 | # Get world size and rank. Ensure some consistencies. 53 | assert torch.distributed.is_initialized() 54 | world_size = torch.distributed.get_world_size() 55 | model_parallel_size = min(model_parallel_size_, world_size) 56 | ensure_divisibility(world_size, model_parallel_size) 57 | rank = torch.distributed.get_rank() 58 | 59 | # Build the data parallel groups. 60 | global _DATA_PARALLEL_GROUP 61 | assert _DATA_PARALLEL_GROUP is None, \ 62 | 'data parallel group is already initialized' 63 | for i in range(model_parallel_size): 64 | ranks = range(i, world_size, model_parallel_size) 65 | group = torch.distributed.new_group(ranks) 66 | if i == (rank % model_parallel_size): 67 | _DATA_PARALLEL_GROUP = group 68 | 69 | # Build the model parallel groups.
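# Illustrative note (added commentary, not from the original source): with world_size=8 and
# model_parallel_size=2 the loop above builds the data parallel groups [g0, g2, g4, g6] and
# [g1, g3, g5, g7], and the loop below builds the model parallel groups [g0, g1], [g2, g3],
# [g4, g5], [g6, g7]; each rank keeps only the group object it belongs to.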
70 | global _MODEL_PARALLEL_GROUP 71 | assert _MODEL_PARALLEL_GROUP is None, \ 72 | 'model parallel group is already initialized' 73 | for i in range(world_size // model_parallel_size): 74 | ranks = range(i * model_parallel_size, 75 | (i + 1) * model_parallel_size) 76 | group = torch.distributed.new_group(ranks) 77 | if i == (rank // model_parallel_size): 78 | _MODEL_PARALLEL_GROUP = group 79 | 80 | 81 | def model_parallel_is_initialized(): 82 | """Check if model and data parallel groups are initialized.""" 83 | if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: 84 | return False 85 | return True 86 | 87 | 88 | def get_model_parallel_group(): 89 | """Get the model parallel group the caller rank belongs to.""" 90 | assert _MODEL_PARALLEL_GROUP is not None, \ 91 | 'model parallel group is not initialized' 92 | return _MODEL_PARALLEL_GROUP 93 | 94 | 95 | def get_data_parallel_group(): 96 | """Get the data parallel group the caller rank belongs to.""" 97 | assert _DATA_PARALLEL_GROUP is not None, \ 98 | 'data parallel group is not initialized' 99 | return _DATA_PARALLEL_GROUP 100 | 101 | 102 | def get_model_parallel_world_size(): 103 | """Return world size for the model parallel group.""" 104 | return torch.distributed.get_world_size(group=get_model_parallel_group()) 105 | 106 | 107 | def get_model_parallel_rank(): 108 | """Return my rank for the model parallel group.""" 109 | return torch.distributed.get_rank(group=get_model_parallel_group()) 110 | 111 | 112 | def get_model_parallel_src_rank(): 113 | """Calculate the global rank corresponding to local rank zero 114 | in the model parallel group.""" 115 | global_rank = torch.distributed.get_rank() 116 | local_world_size = get_model_parallel_world_size() 117 | return (global_rank // local_world_size) * local_world_size 118 | 119 | 120 | def get_data_parallel_world_size(): 121 | """Return world size for the data parallel group.""" 122 | return torch.distributed.get_world_size(group=get_data_parallel_group()) 123 | 124 | 125 | def get_data_parallel_rank(): 126 | """Return my rank for the data parallel group.""" 127 | return torch.distributed.get_rank(group=get_data_parallel_group()) 128 | 129 | 130 | def destroy_model_parallel(): 131 | """Set the groups to none.""" 132 | global _MODEL_PARALLEL_GROUP 133 | _MODEL_PARALLEL_GROUP = None 134 | global _DATA_PARALLEL_GROUP 135 | _DATA_PARALLEL_GROUP = None 136 | -------------------------------------------------------------------------------- /src/model/gpt3_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """GPT-3 model.""" 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | from src import mpu 22 | 23 | 24 | def init_method_normal(std=0.02): 25 | """Init method based on normal distribution. 26 | 27 | This is only used for embeddings.
The transformer has its 28 | own initializer. 29 | """ 30 | def init_(tensor): 31 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 32 | return init_ 33 | 34 | 35 | class GPT3Model(torch.nn.Module): 36 | """GPT-3 Language model. 37 | 38 | The output of the forward method are the logits (parallel or 39 | serial depending on the `parallel_output` flag. 40 | """ 41 | 42 | def __init__(self, 43 | num_layers, 44 | vocab_size, 45 | hidden_size, 46 | num_attention_heads, 47 | embedding_dropout_prob, 48 | attention_dropout_prob, 49 | output_dropout_prob, 50 | max_sequence_length, 51 | checkpoint_activations, 52 | checkpoint_num_layers=1, 53 | parallel_output=True, 54 | deepspeed_sparsity_config=None, 55 | sparse_mode=None): 56 | 57 | super(GPT3Model, self).__init__() 58 | 59 | self._conf_dict = { 60 | 'vocab_size': vocab_size, 61 | 'n_positions': max_sequence_length, 62 | 'n_ctx': max_sequence_length, 63 | 'n_embd': hidden_size, 64 | 'n_layer': num_layers, 65 | 'n_head': num_attention_heads 66 | } 67 | 68 | self.parallel_output = parallel_output 69 | 70 | init_method = init_method_normal(std=0.02) 71 | 72 | # Word embeddings (parallel). 73 | self.word_embeddings = mpu.VocabParallelEmbedding( 74 | vocab_size, hidden_size, init_method=init_method) 75 | 76 | # Position embedding (serial). 77 | self.position_embeddings = torch.nn.Embedding(max_sequence_length, 78 | hidden_size) 79 | # Initialize the position embeddings. 80 | init_method(self.position_embeddings.weight) 81 | 82 | # Embeddings dropout 83 | self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) 84 | 85 | # Transformer 86 | self.transformer = mpu.GPT3ParallelTransformer(num_layers, 87 | hidden_size, 88 | num_attention_heads, 89 | attention_dropout_prob, 90 | output_dropout_prob, 91 | checkpoint_activations, 92 | checkpoint_num_layers, 93 | use_deepspeed_sparse=deepspeed_sparsity_config, 94 | sparse_mode=sparse_mode) 95 | 96 | def forward(self, input_ids, position_ids, attention_mask): 97 | 98 | # Embeddings. 99 | # print('input ids tensor', input_ids.size(), input_ids[0,:2]) 100 | words_embeddings = self.word_embeddings(input_ids) 101 | position_embeddings = self.position_embeddings(position_ids) 102 | embeddings = words_embeddings + position_embeddings 103 | 104 | # Dropout. 105 | embeddings = self.embedding_dropout(embeddings) 106 | 107 | # Transformer. 108 | transformer_output = self.transformer(embeddings, attention_mask) 109 | 110 | # Parallel logits. 
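# Added note: the step below ties the output projection to the input word embeddings;
# logits are computed as F.linear(transformer_output, word_embeddings.weight), so no
# separate output vocabulary matrix is learned.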
111 | transformer_output_parallel = mpu.copy_to_model_parallel_region( 112 | transformer_output) 113 | logits_parallel = F.linear(transformer_output_parallel, 114 | self.word_embeddings.weight) 115 | 116 | if self.parallel_output: 117 | return logits_parallel 118 | 119 | return mpu.gather_from_model_parallel_region(logits_parallel) 120 | 121 | 122 | def gpt3_get_params_for_weight_decay_optimization(module): 123 | 124 | weight_decay_params = {'params': []} 125 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 126 | for module_ in module.modules(): 127 | if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): 128 | no_weight_decay_params['params'].extend( 129 | [p for p in list(module_._parameters.values()) 130 | if p is not None]) 131 | else: 132 | weight_decay_params['params'].extend( 133 | [p for n, p in list(module_._parameters.items()) 134 | if p is not None and n != 'bias']) 135 | no_weight_decay_params['params'].extend( 136 | [p for n, p in list(module_._parameters.items()) 137 | if p is not None and n == 'bias']) 138 | 139 | return weight_decay_params, no_weight_decay_params 140 | -------------------------------------------------------------------------------- /src/dataset_rugpt3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset 11 | from tqdm import tqdm 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.INFO) 15 | 16 | 17 | @dataclass 18 | class RuGpt3DatasetArguments: 19 | train_data_file: Optional[str] = field( 20 | default=None, metadata={"help": "The input training data file (a text file)."} 21 | ) 22 | eval_data_file: Optional[str] = field( 23 | default=None, 24 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 25 | ) 26 | block_size: int = field( 27 | default=-1, 28 | metadata={ 29 | "help": "Optional input sequence length after tokenization." 30 | "The training dataset will be truncated in block of this size for training." 31 | "Default to the model max input length for single sentence inputs" 32 | " (take into account special tokens)." 
33 | }, 34 | ) 35 | overwrite_cache: bool = field( 36 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 37 | ) 38 | random_shift: bool = field(default=False, metadata={"help": "Make random shift from start of each file"}) 39 | max_files_load: int = field(default=50000, metadata={"help": "Maximum number of files to load at one worker"}) 40 | tqdm: bool = field(default=False, metadata={"help": "Show tqdm progress bar"}) 41 | 42 | 43 | class RuGpt3TextDataset(Dataset): 44 | def process_file(self, file_path, filename, tokenizer, args): 45 | cached_features_file = os.path.join(self._cache_dir, filename.replace('/', '_') + '.pkl') 46 | examples = [] 47 | 48 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 49 | with open(cached_features_file, "rb") as handle: 50 | try: 51 | examples = pickle.load(handle) 52 | examples = np.asarray(examples, dtype=np.int32) 53 | except Exception as e: 54 | print('Failed to load cache file:', cached_features_file) 55 | raise e 56 | else: 57 | examples = [] 58 | with open(os.path.join(file_path, filename), encoding="utf-8") as f: 59 | text = f.read() 60 | if self.args.line_by_line: 61 | lines = [x + "" for x in text.strip().split("")] 62 | for line in lines: 63 | line = tokenizer.encode(line) 64 | line += [tokenizer.encoder['']] * self.args.seq_length 65 | line = line[:self.args.seq_length] 66 | examples.append(line) 67 | else: 68 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 69 | 70 | max_shift = max(min(args.block_size, len(tokenized_text) - args.block_size), 0) 71 | rnd_shift = random.randrange(max_shift) if max_shift and args.random_shift else 0 72 | 73 | for i in range(rnd_shift, len(tokenized_text) - args.block_size + 1, args.block_size): 74 | example = tokenized_text[i:i + args.block_size] 75 | if None in example: 76 | raise Exception('None in tokens!: ' + filename) 77 | if len(example) == args.block_size: 78 | examples.append(example) 79 | # Note that we are losing the last truncated example here for the sake of simplicity (no padding) 80 | # If your dataset is small, first you should look for a bigger one :-) and second you 81 | # can change this behavior by adding (model specific) padding.
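# Worked example (added, values purely illustrative): with block_size=2048, random_shift
# disabled and a file of 5000 tokens, the loop above produces two examples, tokens [0:2048]
# and [2048:4096]; the trailing 904 tokens are dropped, as the note above explains.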
82 | 83 | with open(cached_features_file, "wb") as handle: 84 | pickle.dump(examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 85 | examples = np.asarray(examples, dtype=np.int32) 86 | 87 | return examples 88 | 89 | def __init__(self, tokenizer, args, rank, world_size, file_path, cache_prefix='_', all_args=None): 90 | self.rank = rank 91 | self.world_size = world_size 92 | self.log(f"Loading dataset {file_path}") 93 | max_file_load = args.max_files_load 94 | self.args = all_args 95 | 96 | file_with_list = file_path 97 | file_path = os.path.dirname(file_with_list) 98 | self.log(f"Check filelist {file_with_list} with root dir {file_path}") 99 | 100 | if not os.path.exists(file_with_list) and rank < 1: 101 | raise Exception('No file list!') 102 | 103 | with open(file_with_list, 'r') as fp: 104 | files = [line.strip() for line in fp.read().split('\n') if line] 105 | 106 | if rank == -1: 107 | self.log('Shuffle') 108 | random.shuffle(files) 109 | if len(files) > max_file_load: 110 | files = files[:max_file_load] 111 | else: 112 | shard_size = len(files) // world_size 113 | if shard_size > max_file_load: 114 | logger.warning( 115 | f"Shard size {shard_size} > max_file_load {max_file_load}," 116 | f" only first {(max_file_load * world_size)}" 117 | f" files of dataset would be loaded!") 118 | shard_size = max_file_load 119 | shard_start = rank * shard_size 120 | shard_end = (rank + 1) * shard_size 121 | self.log(f"Shard [{shard_start}, {shard_end}]") 122 | files = files[shard_start:shard_end] 123 | 124 | self._cache_dir = os.path.join(file_path, f'{cache_prefix}cache_{args.block_size}_{len(tokenizer)}') 125 | os.makedirs(self._cache_dir, exist_ok=True) 126 | if args.overwrite_cache: 127 | self.log('Overwrite cache ' + self._cache_dir) 128 | 129 | examples = [] 130 | iterator = tqdm(files) if args.tqdm else files 131 | for i, filename in enumerate(iterator): 132 | if i % 1000 == 0: 133 | self.log(f"Loaded {i}/{len(files)} files") 134 | example = self.process_file(file_path, filename, tokenizer, args=args) 135 | if example.size: 136 | examples.append(example) 137 | self.examples = np.vstack(examples) 138 | np.random.shuffle(self.examples) 139 | self.log(f"Loaded {len(self.examples)} examples, {self.examples.size} tokens") 140 | 141 | def log(self, msg): 142 | logger.warning(f"R{self.rank}/{self.world_size}: {msg}") 143 | 144 | def __len__(self): 145 | return len(self.examples) 146 | 147 | def __getitem__(self, item): 148 | item = item % len(self.examples) # infinite loop, modulo dataset size 149 | if len(self.examples[item]) == 0: 150 | item = random.randint(1, len(self.examples)) 151 | return torch.tensor(self.examples[item], dtype=torch.long) 152 | -------------------------------------------------------------------------------- /src/data_utils/lazy_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """utils for loading text from disk""" 16 | import mmap 17 | import os 18 | import pickle as pkl 19 | import time 20 | from itertools import accumulate 21 | 22 | import torch 23 | from torch.multiprocessing import Lock 24 | 25 | 26 | def get_lazy_path(path): 27 | """ 28 | Gets directory path where lazy files are stored. 29 | """ 30 | return os.path.splitext(path)[0] + '.lazy' 31 | 32 | 33 | def exists_lazy(path, data_type='data'): 34 | """ 35 | Check if we've already made a lazy version of this file for the `data_type` field. 36 | """ 37 | if not os.path.exists(get_lazy_path(path)): 38 | return False 39 | contents = os.listdir(get_lazy_path(path)) 40 | if data_type not in contents: 41 | return False 42 | if data_type + '.len.pkl' not in contents: 43 | return False 44 | return True 45 | 46 | 47 | def make_lazy(path, strs, data_type='data'): 48 | """ 49 | Make lazy version of `data_type` field of the file. Byte offsets 50 | corresponding to data indices are stored in a `.len.pkl` data file. 51 | """ 52 | lazypath = get_lazy_path(path) 53 | if not os.path.exists(lazypath): 54 | os.makedirs(lazypath) 55 | datapath = os.path.join(lazypath, data_type) 56 | lenpath = os.path.join(lazypath, data_type + '.len.pkl') 57 | if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: 58 | with open(datapath, 'wb') as f: 59 | str_lens = [] 60 | str_cnt = 0 61 | for s in strs: 62 | if isinstance(s, dict): 63 | s = s['text'] 64 | encoded = s.encode('utf-8') 65 | f.write(encoded) 66 | str_cnt = len(encoded) 67 | str_lens.append(str_cnt) 68 | pkl.dump(str_lens, open(lenpath, 'wb')) 69 | else: 70 | while not os.path.exists(lenpath): 71 | time.sleep(1) 72 | 73 | 74 | def split_strings(strings, start, chr_lens): 75 | """ 76 | Split strings based on string lengths and given start. 77 | """ 78 | return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)] 79 | 80 | 81 | class ProcessorTokenizer: 82 | """ 83 | callable class that runs a preprocessing, as well as tokenization step, 84 | on input text. 85 | """ 86 | 87 | def __init__(self, tokenizer, process_fn=None): 88 | self.tokenizer = tokenizer 89 | self.process_fn = process_fn 90 | 91 | def __call__(self, string): 92 | if self.tokenizer is not None: 93 | string = self.tokenizer(string, process_fn=self.process_fn) 94 | elif self.process_fn is not None: 95 | string = self.process_fn(string) 96 | return string 97 | 98 | 99 | class lazy_array_loader(object): 100 | """ 101 | Arguments: 102 | path: path to directory where array entries are concatenated into one big string file 103 | and the .len file are located 104 | data_type (str): Some datsets have multiple fields that are stored in different paths. 105 | `data_type` specifies which of these fields to load in this class 106 | mem_map (boolean): Specifies whether to memory map file `path` 107 | map_fn (callable): Fetched strings are passed through map_fn before being returned. 
108 | 109 | Example of lazy loader directory structure: 110 | file.json 111 | file.lazy/ 112 | data_type1 113 | data_type1.len.pkl 114 | data_type2 115 | data_type2.len.pkl 116 | """ 117 | 118 | def __init__(self, path, data_type='data', mem_map=False, map_fn=None): 119 | lazypath = get_lazy_path(path) 120 | datapath = os.path.join(lazypath, data_type) 121 | # get file where array entries are concatenated into one big string 122 | self._file = open(datapath, 'rb') 123 | self.file = self._file 124 | # memory map file if necessary 125 | self.mem_map = mem_map 126 | if self.mem_map: 127 | self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ) 128 | lenpath = os.path.join(lazypath, data_type + '.len.pkl') 129 | self.lens = pkl.load(open(lenpath, 'rb')) 130 | self.ends = list(accumulate(self.lens)) 131 | self.dumb_ends = list(self.ends) 132 | self.read_lock = Lock() 133 | self.process_fn = map_fn 134 | self.map_fn = map_fn 135 | self._tokenizer = None 136 | 137 | def SetTokenizer(self, tokenizer): 138 | """ 139 | logic to set and remove (set to None) tokenizer. 140 | combines preprocessing/tokenization into one callable. 141 | """ 142 | if tokenizer is None: 143 | if not hasattr(self, '_tokenizer'): 144 | self._tokenizer = tokenizer 145 | else: 146 | self._tokenizer = tokenizer 147 | self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn) 148 | 149 | def GetTokenizer(self): 150 | return self._tokenizer 151 | 152 | def __getitem__(self, index): 153 | """ 154 | read file and splice strings based on string ending array `self.ends` 155 | """ 156 | if not isinstance(index, slice): 157 | if index == 0: 158 | start = 0 159 | else: 160 | start = self.ends[index - 1] 161 | end = self.ends[index] 162 | rtn = self.file_read(start, end) 163 | if self.map_fn is not None: 164 | return self.map_fn(rtn) 165 | else: 166 | # if slice, fetch strings with 1 diskread and then splice in memory 167 | chr_lens = self.ends[index] 168 | if index.start == 0 or index.start is None: 169 | start = 0 170 | else: 171 | start = self.ends[index.start - 1] 172 | stop = chr_lens[-1] 173 | strings = self.file_read(start, stop) 174 | rtn = split_strings(strings, start, chr_lens) 175 | if self.map_fn is not None: 176 | return self.map_fn([s for s in rtn]) 177 | return rtn 178 | 179 | def __len__(self): 180 | return len(self.ends) 181 | 182 | def file_read(self, start=0, end=None): 183 | """read specified portion of file""" 184 | 185 | # atomic reads to avoid race conditions with multiprocess dataloader 186 | self.read_lock.acquire() 187 | # seek to start of file read 188 | self.file.seek(start) 189 | # read to end of file if no end point provided 190 | if end is None: 191 | rtn = self.file.read() 192 | # else read amount needed to reach end point 193 | else: 194 | rtn = self.file.read(end - start) 195 | self.read_lock.release() 196 | # TODO: @raulp figure out mem map byte string bug 197 | # if mem map'd need to decode byte string to string 198 | rtn = rtn.decode('utf-8', 'ignore') 199 | # rtn = str(rtn) 200 | if self.mem_map: 201 | rtn = rtn.decode('unicode_escape') 202 | return rtn 203 | -------------------------------------------------------------------------------- /src/fp16/fp16util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 19 | from torch.autograd import Variable 20 | 21 | from src import mpu 22 | 23 | 24 | class tofp16(nn.Module): 25 | """ 26 | Utility module that implements:: 27 | 28 | def forward(self, input): 29 | return input.half() 30 | """ 31 | 32 | def __init__(self): 33 | super(tofp16, self).__init__() 34 | 35 | def forward(self, input): 36 | return input.half() 37 | 38 | 39 | def BN_convert_float(module): 40 | """ 41 | Utility function for network_to_half(). 42 | 43 | Retained for legacy purposes. 44 | """ 45 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: 46 | module.float() 47 | for child in module.children(): 48 | BN_convert_float(child) 49 | return module 50 | 51 | 52 | def network_to_half(network): 53 | """ 54 | Convert model to half precision in a batchnorm-safe way. 55 | 56 | Retained for legacy purposes. It is recommended to use FP16Model. 57 | """ 58 | return nn.Sequential(tofp16(), BN_convert_float(network.half())) 59 | 60 | 61 | def convert_module(module, dtype): 62 | """ 63 | Converts a module's immediate parameters and buffers to dtype. 64 | """ 65 | for param in module.parameters(recurse=False): 66 | if param is not None: 67 | if param.data.dtype.is_floating_point: 68 | param.data = param.data.to(dtype=dtype) 69 | if param._grad is not None and param._grad.data.dtype.is_floating_point: 70 | param._grad.data = param._grad.data.to(dtype=dtype) 71 | 72 | for buf in module.buffers(recurse=False): 73 | if buf is not None and buf.data.dtype.is_floating_point: 74 | buf.data = buf.data.to(dtype=dtype) 75 | 76 | 77 | def convert_network(network, dtype): 78 | """ 79 | Converts a network's parameters and buffers to dtype. 80 | """ 81 | for module in network.modules(): 82 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: 83 | continue 84 | convert_module(module, dtype) 85 | return network 86 | 87 | 88 | class FP16Model(nn.Module): 89 | """ 90 | Convert model to half precision in a batchnorm-safe way. 91 | """ 92 | 93 | def __init__(self, network): 94 | super(FP16Model, self).__init__() 95 | self.network = convert_network(network, dtype=torch.half) 96 | 97 | def forward(self, *inputs): 98 | inputs = tuple(t.half() for t in inputs) 99 | return self.network(*inputs) 100 | 101 | 102 | def backwards_debug_hook(grad): 103 | raise RuntimeError("master_params recieved a gradient in the backward pass!") 104 | 105 | 106 | def prep_param_lists(model, flat_master=False): 107 | """ 108 | Creates a list of FP32 master parameters for a given model, as in 109 | `Training Neural Networks with Mixed Precision: Real Examples`_. 110 | 111 | Args: 112 | model (torch.nn.Module): Existing Pytorch model 113 | flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. 114 | Returns: 115 | A tuple (``model_params``, ``master_params``). 
``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element. 116 | 117 | Example:: 118 | 119 | model_params, master_params = prep_param_lists(model) 120 | 121 | .. warning:: 122 | Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. 123 | 124 | .. _`Training Neural Networks with Mixed Precision: Real Examples`: 125 | http://on-demand.gputechconf.com/gtc/2018/video/S81012/ 126 | """ 127 | model_params = [param for param in model.parameters() if param.requires_grad] 128 | 129 | if flat_master: 130 | # Give the user some more useful error messages 131 | try: 132 | # flatten_dense_tensors returns a contiguous flat array. 133 | # http://pytorch.org/docs/master/_modules/torch/_utils.html 134 | master_params = _flatten_dense_tensors([param.data for param in model_params]).float() 135 | except: 136 | print("Error in prep_param_lists: model may contain a mixture of parameters " 137 | "of different types. Use flat_master=False, or use FP16_Optimizer.") 138 | raise 139 | master_params = torch.nn.Parameter(master_params) 140 | master_params.requires_grad = True 141 | # master_params.register_hook(backwards_debug_hook) 142 | if master_params.grad is None: 143 | master_params.grad = master_params.new(*master_params.size()) 144 | return model_params, [master_params] 145 | else: 146 | master_params = [param.clone().float().detach() for param in model_params] 147 | for param in master_params: 148 | param.requires_grad = True 149 | return model_params, master_params 150 | 151 | 152 | def model_grads_to_master_grads(model_params, master_params, flat_master=False): 153 | """ 154 | Copy model gradients to master gradients. 155 | 156 | Args: 157 | model_params: List of model parameters created by :func:`prep_param_lists`. 158 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. 159 | """ 160 | if flat_master: 161 | # The flattening may incur one more deep copy than is necessary. 162 | master_params[0].grad.data.copy_( 163 | _flatten_dense_tensors([p.grad.data for p in model_params])) 164 | else: 165 | for model, master in zip(model_params, master_params): 166 | if model.grad is not None: 167 | if master.grad is None: 168 | master.grad = Variable(master.data.new(*master.data.size())) 169 | master.grad.data.copy_(model.grad.data) 170 | else: 171 | master.grad = None 172 | 173 | 174 | def master_params_to_model_params(model_params, master_params, flat_master=False): 175 | """ 176 | Copy master parameters to model parameters. 177 | 178 | Args: 179 | model_params: List of model parameters created by :func:`prep_param_lists`. 180 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
181 | """ 182 | if flat_master: 183 | for model, master in zip(model_params, 184 | _unflatten_dense_tensors(master_params[0].data, model_params)): 185 | model.data.copy_(master) 186 | else: 187 | for model, master in zip(model_params, master_params): 188 | model.data.copy_(master.data) 189 | 190 | 191 | # Backward compatibility fixes 192 | 193 | def to_python_float(t): 194 | if hasattr(t, 'item'): 195 | return t.item() 196 | else: 197 | return t[0] 198 | 199 | 200 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 201 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 202 | 203 | clip_grad_norm = mpu.clip_grad_norm 204 | # elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4: 205 | # clip_grad_norm = torch.nn.utils.clip_grad_norm 206 | # else: 207 | # clip_grad_norm = torch.nn.utils.clip_grad_norm_ 208 | -------------------------------------------------------------------------------- /src/data_utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # This file is provided as is from: 2 | # https://github.com/huggingface/pytorch-pretrained-BERT 3 | # Please refer to their repository for copyright. 4 | 5 | """ 6 | Utilities for working with the local dataset cache. 7 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 8 | Copyright by the AllenNLP authors. 9 | """ 10 | from __future__ import (absolute_import, division, print_function, unicode_literals) 11 | 12 | import json 13 | import logging 14 | import os 15 | import shutil 16 | import tempfile 17 | from functools import wraps 18 | from hashlib import sha256 19 | import sys 20 | from io import open 21 | 22 | import boto3 23 | import requests 24 | from botocore.exceptions import ClientError 25 | from tqdm import tqdm 26 | 27 | try: 28 | from urllib.parse import urlparse 29 | except ImportError: 30 | from urlparse import urlparse 31 | 32 | try: 33 | from pathlib import Path 34 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 35 | Path.home() / '.pytorch_pretrained_bert')) 36 | except (AttributeError, ImportError): 37 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 38 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 39 | 40 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | def url_to_filename(url, etag=None): 44 | """ 45 | Convert `url` into a hashed filename in a repeatable way. 46 | If `etag` is specified, append its hash to the url's, delimited 47 | by a period. 48 | """ 49 | url_bytes = url.encode('utf-8') 50 | url_hash = sha256(url_bytes) 51 | filename = url_hash.hexdigest() 52 | 53 | if etag: 54 | etag_bytes = etag.encode('utf-8') 55 | etag_hash = sha256(etag_bytes) 56 | filename += '.' + etag_hash.hexdigest() 57 | 58 | return filename 59 | 60 | 61 | def filename_to_url(filename, cache_dir=None): 62 | """ 63 | Return the url and etag (which may be ``None``) stored for `filename`. 64 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
65 | """ 66 | if cache_dir is None: 67 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 68 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 69 | cache_dir = str(cache_dir) 70 | 71 | cache_path = os.path.join(cache_dir, filename) 72 | if not os.path.exists(cache_path): 73 | raise EnvironmentError("file {} not found".format(cache_path)) 74 | 75 | meta_path = cache_path + '.json' 76 | if not os.path.exists(meta_path): 77 | raise EnvironmentError("file {} not found".format(meta_path)) 78 | 79 | with open(meta_path, encoding="utf-8") as meta_file: 80 | metadata = json.load(meta_file) 81 | url = metadata['url'] 82 | etag = metadata['etag'] 83 | 84 | return url, etag 85 | 86 | 87 | def cached_path(url_or_filename, cache_dir=None): 88 | """ 89 | Given something that might be a URL (or might be a local path), 90 | determine which. If it's a URL, download the file and cache it, and 91 | return the path to the cached file. If it's already a local path, 92 | make sure the file exists and then return the path. 93 | """ 94 | if cache_dir is None: 95 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 96 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 97 | url_or_filename = str(url_or_filename) 98 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 99 | cache_dir = str(cache_dir) 100 | 101 | parsed = urlparse(url_or_filename) 102 | 103 | if parsed.scheme in ('http', 'https', 's3'): 104 | # URL, so get it from the cache (downloading if necessary) 105 | return get_from_cache(url_or_filename, cache_dir) 106 | elif os.path.exists(url_or_filename): 107 | # File, and it exists. 108 | return url_or_filename 109 | elif parsed.scheme == '': 110 | # File, but it doesn't exist. 111 | raise EnvironmentError("file {} not found".format(url_or_filename)) 112 | else: 113 | # Something unknown 114 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 115 | 116 | 117 | def split_s3_path(url): 118 | """Split a full s3 path into the bucket name and path.""" 119 | parsed = urlparse(url) 120 | if not parsed.netloc or not parsed.path: 121 | raise ValueError("bad s3 path {}".format(url)) 122 | bucket_name = parsed.netloc 123 | s3_path = parsed.path 124 | # Remove '/' at beginning of path. 125 | if s3_path.startswith("/"): 126 | s3_path = s3_path[1:] 127 | return bucket_name, s3_path 128 | 129 | 130 | def s3_request(func): 131 | """ 132 | Wrapper function for s3 requests in order to create more helpful error 133 | messages. 
134 | """ 135 | 136 | @wraps(func) 137 | def wrapper(url, *args, **kwargs): 138 | try: 139 | return func(url, *args, **kwargs) 140 | except ClientError as exc: 141 | if int(exc.response["Error"]["Code"]) == 404: 142 | raise EnvironmentError("file {} not found".format(url)) 143 | else: 144 | raise 145 | 146 | return wrapper 147 | 148 | 149 | @s3_request 150 | def s3_etag(url): 151 | """Check ETag on S3 object.""" 152 | s3_resource = boto3.resource("s3") 153 | bucket_name, s3_path = split_s3_path(url) 154 | s3_object = s3_resource.Object(bucket_name, s3_path) 155 | return s3_object.e_tag 156 | 157 | 158 | @s3_request 159 | def s3_get(url, temp_file): 160 | """Pull a file directly from S3.""" 161 | s3_resource = boto3.resource("s3") 162 | bucket_name, s3_path = split_s3_path(url) 163 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 164 | 165 | 166 | def http_get(url, temp_file): 167 | req = requests.get(url, stream=True) 168 | content_length = req.headers.get('Content-Length') 169 | total = int(content_length) if content_length is not None else None 170 | progress = tqdm(unit="B", total=total) 171 | for chunk in req.iter_content(chunk_size=1024): 172 | if chunk: 173 | # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | response = requests.head(url, allow_redirects=True) 197 | if response.status_code != 200: 198 | raise IOError("HEAD request failed for url {} with status code {}" 199 | .format(url, response.status_code)) 200 | etag = response.headers.get("ETag") 201 | 202 | filename = url_to_filename(url, etag) 203 | 204 | # get cache path to put the file 205 | cache_path = os.path.join(cache_dir, filename) 206 | 207 | if not os.path.exists(cache_path): 208 | # Download to temporary file, then copy to cache dir once finished. 209 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
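# Added note: the cache filename comes from url_to_filename above, i.e. sha256(url) plus
# '.' + sha256(etag) when an ETag is available; a sibling '<filename>.json' file holding
# the url/etag metadata is written near the end of this function.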
210 | with tempfile.NamedTemporaryFile() as temp_file: 211 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 212 | 213 | # GET file object 214 | if url.startswith("s3://"): 215 | s3_get(url, temp_file) 216 | else: 217 | http_get(url, temp_file) 218 | 219 | # we are copying the file before closing it, so flush to avoid truncation 220 | temp_file.flush() 221 | # shutil.copyfileobj() starts at the current position, so go to the start 222 | temp_file.seek(0) 223 | 224 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 225 | with open(cache_path, 'wb') as cache_file: 226 | shutil.copyfileobj(temp_file, cache_file) 227 | 228 | logger.info("creating metadata file for %s", cache_path) 229 | meta = {'url': url, 'etag': etag} 230 | meta_path = cache_path + '.json' 231 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 232 | json.dump(meta, meta_file) 233 | 234 | logger.info("removing temp file %s", temp_file.name) 235 | 236 | return cache_path 237 | 238 | 239 | def read_set_from_file(filename): 240 | """ 241 | Extract a de-duped collection (set) of text from a file. 242 | Expected file format is one item per line. 243 | """ 244 | collection = set() 245 | with open(filename, 'r', encoding='utf-8') as file_: 246 | for line in file_: 247 | collection.add(line.rstrip()) 248 | return collection 249 | 250 | 251 | def get_file_extension(path, dot=True, lower=True): 252 | ext = os.path.splitext(path)[1] 253 | ext = ext if dot else ext[1:] 254 | return ext.lower() if lower else ext 255 | -------------------------------------------------------------------------------- /src/fp16/loss_scaler.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | 18 | from src import mpu 19 | 20 | 21 | # item() is a recent addition, so this helps with backward compatibility. 22 | def to_python_float(t): 23 | if hasattr(t, 'item'): 24 | return t.item() 25 | else: 26 | return t[0] 27 | 28 | 29 | class LossScaler: 30 | """ 31 | Class that manages a static loss scale. This class is intended to interact with 32 | :class:`FP16_Optimizer`, and should not be directly manipulated by the user. 33 | 34 | Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to 35 | :class:`FP16_Optimizer`'s constructor. 36 | 37 | Args: 38 | scale (float, optional, default=1.0): The loss scale. 
39 | """ 40 | 41 | def __init__(self, scale=1): 42 | self.cur_scale = scale 43 | 44 | # `params` is a list / generator of torch.Variable 45 | def has_overflow(self, params): 46 | return False 47 | 48 | # `x` is a torch.Tensor 49 | def _has_inf_or_nan(x): 50 | return False 51 | 52 | def update_scale(self, overflow): 53 | pass 54 | 55 | @property 56 | def loss_scale(self): 57 | return self.cur_scale 58 | 59 | def scale_gradient(self, module, grad_in, grad_out): 60 | return tuple(self.loss_scale * g for g in grad_in) 61 | 62 | def backward(self, loss, retain_graph=False): 63 | scaled_loss = loss * self.loss_scale 64 | scaled_loss.backward(retain_graph=retain_graph) 65 | 66 | 67 | class DynamicLossScaler: 68 | """ 69 | Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` 70 | indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of 71 | :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` 72 | operates, because the default options can be changed using the 73 | the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. 74 | 75 | Loss scaling is designed to combat the problem of underflowing gradients encountered at long 76 | times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss 77 | scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are 78 | encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has 79 | occurred. 80 | :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, 81 | and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. 82 | If a certain number of iterations occur without overflowing gradients detected, 83 | :class:`DynamicLossScaler` increases the loss scale once more. 84 | In this way :class:`DynamicLossScaler` attempts to "ride the edge" of 85 | always using the highest loss scale possible without incurring overflow. 86 | 87 | Args: 88 | init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` 89 | scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. 90 | scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. 
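    Example (added for illustration): with the defaults init_scale=2**32, scale_factor=2.0 and
    scale_window=1000, every overflow halves the loss scale immediately, while the scale is
    doubled again once 1000 iterations pass without a new overflow.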
91 | """ 92 | 93 | def __init__(self, 94 | init_scale=2 ** 32, 95 | scale_factor=2., 96 | scale_window=1000, 97 | min_scale=1, 98 | delayed_shift=1, 99 | consecutive_hysteresis=False): 100 | self.cur_scale = init_scale 101 | self.cur_iter = 0 102 | self.last_overflow_iter = -1 103 | self.scale_factor = scale_factor 104 | self.scale_window = scale_window 105 | self.min_scale = min_scale 106 | self.delayed_shift = delayed_shift 107 | self.cur_hysteresis = delayed_shift 108 | self.consecutive_hysteresis = consecutive_hysteresis 109 | 110 | # `params` is a list / generator of torch.Variable 111 | def has_overflow_serial(self, params): 112 | for p in params: 113 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 114 | return True 115 | 116 | return False 117 | 118 | def has_overflow(self, params): 119 | overflow = self.has_overflow_serial(params) 120 | # Since each model parallel GPU carries only part of the model, 121 | # make sure overflow flag is synced across all the model parallel GPUs 122 | overflow_gpu = torch.cuda.ByteTensor([overflow]) 123 | torch.distributed.all_reduce(overflow_gpu, 124 | op=torch.distributed.ReduceOp.MAX, 125 | group=mpu.get_model_parallel_group()) 126 | overflow = overflow_gpu[0].item() 127 | return bool(overflow) 128 | 129 | # `x` is a torch.Tensor 130 | def _has_inf_or_nan(x): 131 | try: 132 | # if x is half, the .float() incurs an additional deep copy, but it's necessary if 133 | # Pytorch's .sum() creates a one-element tensor of the same type as x 134 | # (which is true for some recent version of pytorch). 135 | cpu_sum = float(x.float().sum()) 136 | # More efficient version that can be used if .sum() returns a Python scalar 137 | # cpu_sum = float(x.sum()) 138 | except RuntimeError as instance: 139 | # We want to check if inst is actually an overflow exception. 140 | # RuntimeError could come from a different error. 141 | # If so, we still want the exception to propagate. 
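# Added clarification: only a "value cannot be converted" RuntimeError is treated as an
# overflow below; any other RuntimeError is re-raised. The cpu_sum != cpu_sum comparison
# further down is the standard NaN check, since NaN never compares equal to itself.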
142 | if "value cannot be converted" not in instance.args[0]: 143 | raise 144 | return True 145 | else: 146 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 147 | return True 148 | return False 149 | 150 | # `overflow` is boolean indicating whether the gradient overflowed 151 | def update_scale(self, overflow): 152 | 153 | if not hasattr(self, 'min_scale'): 154 | self.min_scale = 1 155 | if not hasattr(self, 'delayed_shift'): 156 | self.delayed_shift = 1 157 | if not hasattr(self, 'cur_hysteresis'): 158 | self.cur_hysteresis = 1 159 | if not hasattr(self, 'consecutive_hysteresis'): 160 | self.consecutive_hysteresis = True 161 | if overflow: 162 | # self.cur_scale /= self.scale_factor 163 | if self.delayed_shift == 1 or self.cur_hysteresis == 1: 164 | self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale) 165 | else: 166 | self.cur_hysteresis -= 1 167 | self.last_overflow_iter = self.cur_iter 168 | else: 169 | if self.consecutive_hysteresis: 170 | self.cur_hysteresis = self.delayed_shift 171 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 172 | if not self.consecutive_hysteresis: 173 | self.cur_hysteresis = self.delayed_shift 174 | self.cur_scale *= self.scale_factor 175 | self.cur_iter += 1 176 | 177 | @property 178 | def loss_scale(self): 179 | return self.cur_scale 180 | 181 | def scale_gradient(self, module, grad_in, grad_out): 182 | return tuple(self.loss_scale * g for g in grad_in) 183 | 184 | def backward(self, loss, retain_graph=False): 185 | scaled_loss = loss * self.loss_scale 186 | scaled_loss.backward(retain_graph=retain_graph) 187 | 188 | 189 | ############################################################## 190 | # Example usage below here -- assuming it's in a separate file 191 | ############################################################## 192 | """ 193 | TO-DO separate out into an example. 194 | if __name__ == "__main__": 195 | import torch 196 | from torch.autograd import Variable 197 | from dynamic_loss_scaler import DynamicLossScaler 198 | 199 | # N is batch size; D_in is input dimension; 200 | # H is hidden dimension; D_out is output dimension. 201 | N, D_in, H, D_out = 64, 1000, 100, 10 202 | 203 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables. 204 | x = Variable(torch.randn(N, D_in), requires_grad=False) 205 | y = Variable(torch.randn(N, D_out), requires_grad=False) 206 | 207 | w1 = Variable(torch.randn(D_in, H), requires_grad=True) 208 | w2 = Variable(torch.randn(H, D_out), requires_grad=True) 209 | parameters = [w1, w2] 210 | 211 | learning_rate = 1e-6 212 | optimizer = torch.optim.SGD(parameters, lr=learning_rate) 213 | loss_scaler = DynamicLossScaler() 214 | 215 | for t in range(500): 216 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 217 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale 218 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) 219 | print('Iter {} scaled loss: {}'.format(t, loss.data[0])) 220 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) 221 | 222 | # Run backprop 223 | optimizer.zero_grad() 224 | loss.backward() 225 | 226 | # Check for overflow 227 | has_overflow = DynamicLossScaler.has_overflow(parameters) 228 | 229 | # If no overflow, unscale grad and update as usual 230 | if not has_overflow: 231 | for param in parameters: 232 | param.grad.data.mul_(1. 
/ loss_scaler.loss_scale) 233 | optimizer.step() 234 | # Otherwise, don't do anything -- ie, skip iteration 235 | else: 236 | print('OVERFLOW!') 237 | 238 | # Update loss scale for next iteration 239 | loss_scaler.update_scale(has_overflow) 240 | 241 | """ 242 | -------------------------------------------------------------------------------- /src/xl_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import random 5 | from typing import Union, Iterable 6 | 7 | import numpy as np 8 | import torch 9 | from deepspeed import DeepSpeedConfig 10 | from torch.nn import CrossEntropyLoss 11 | from transformers import GPT2Tokenizer, PreTrainedModel, PretrainedConfig 12 | 13 | from src import mpu 14 | from .fp16 import FP16_Module 15 | from .model import GPT3Model 16 | from .download_utils import download_model_files 17 | from transformers.utils import logging 18 | 19 | 20 | logger = logging.get_logger(__name__) 21 | NoneType = type(None) 22 | 23 | 24 | def get_deepspeed_config(path): 25 | return DeepSpeedConfig(path) 26 | 27 | 28 | def get_sparse_attention_config(path, num_heads): 29 | ds_config = get_deepspeed_config(path) 30 | if hasattr(ds_config, 'sparse_attention') and ds_config.sparse_attention: 31 | sa_config = ds_config.sparse_attention 32 | sa_mode = sa_config.get('mode') 33 | if sa_mode == 'dense': 34 | from deepspeed.ops.sparse_attention import DenseSparsityConfig as STConfig 35 | elif sa_mode == 'fixed': 36 | from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig 37 | elif sa_mode == 'bigbird': 38 | from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as STConfig 39 | elif sa_mode == 'bslongformer': 40 | from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as STConfig 41 | elif sa_mode == 'variable': 42 | from deepspeed.ops.sparse_attention import VariableSparsityConfig as STConfig 43 | else: 44 | raise NotImplementedError( 45 | f'Given sparsity mode, {sa_mode}, has not been implemented yet!' 46 | ) 47 | del sa_config['mode'] 48 | return STConfig(num_heads=num_heads, **sa_config) 49 | else: 50 | return None 51 | 52 | 53 | def get_model(deepspeed_config_path): 54 | num_local_heads = 16 55 | sparse_mode = 'alternating' 56 | deepspeed_sparsity_config = get_sparse_attention_config(deepspeed_config_path, num_local_heads) 57 | if deepspeed_sparsity_config is not None: 58 | logger.info(f"Use sparse attention with mode {sparse_mode}") 59 | else: 60 | logger.info(f"Use dense attention") 61 | model = GPT3Model(num_layers=24, 62 | vocab_size=50264, 63 | hidden_size=2048, 64 | num_attention_heads=num_local_heads, 65 | embedding_dropout_prob=0.1, attention_dropout_prob=0.1, output_dropout_prob=0.1, 66 | max_sequence_length=2048, 67 | checkpoint_activations=False, 68 | checkpoint_num_layers=1, 69 | parallel_output=False, 70 | deepspeed_sparsity_config=deepspeed_sparsity_config, 71 | sparse_mode=sparse_mode) 72 | # GPU allocation. 73 | model.cuda(torch.cuda.current_device()) 74 | 75 | # Fp16 conversion. 
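# Added note: the constants above correspond to the ruGPT-3 XL configuration (24 layers,
# hidden size 2048, 16 attention heads, vocab 50264, 2048-token context); FP16_Module below
# comes from src/fp16 and, presumably, wraps the model for half-precision inference.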
76 | model = FP16_Module(model) 77 | 78 | return model 79 | 80 | 81 | def setup_model(weights_path, deepspeed_config_path): 82 | model = get_model(deepspeed_config_path) 83 | logger.info("Load checkpoint from " + weights_path) 84 | checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)['module'] 85 | model.load_state_dict(checkpoint) 86 | model.eval() 87 | logger.info("Model Loaded") 88 | return model 89 | 90 | 91 | def get_masks_and_position_ids(data, 92 | eod_token, 93 | reset_position_ids, 94 | reset_attention_mask): 95 | # Extract batch size and sequence length. 96 | batch_size, seq_length = data.size() 97 | 98 | # Attention mask (lower triangular). 99 | if reset_attention_mask: 100 | att_mask_batch = batch_size 101 | else: 102 | att_mask_batch = 1 103 | attention_mask = torch.tril(torch.ones( 104 | (att_mask_batch, seq_length, seq_length), device=data.device)).view( 105 | att_mask_batch, 1, seq_length, seq_length) 106 | 107 | # Loss mask. 108 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 109 | loss_mask[data == eod_token] = 0.0 110 | 111 | # Position ids. 112 | position_ids = torch.arange(seq_length, dtype=torch.long, 113 | device=data.device) 114 | position_ids = position_ids.unsqueeze(0).expand_as(data) 115 | # We need to clone as the ids will be modified based on batch index. 116 | if reset_position_ids: 117 | position_ids = position_ids.clone() 118 | 119 | if reset_position_ids or reset_attention_mask: 120 | # Loop through the batches: 121 | for b in range(batch_size): 122 | 123 | # Find indices where EOD token is. 124 | eod_index = position_ids[b, data[b] == eod_token] 125 | # Detach indices from positions if going to modify positions. 126 | if reset_position_ids: 127 | eod_index = eod_index.clone() 128 | 129 | # Loop through EOD indices: 130 | prev_index = 0 131 | for j in range(eod_index.size()[0]): 132 | i = eod_index[j] 133 | # Mask attention loss. 134 | if reset_attention_mask: 135 | attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 136 | # Reset positions.
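# Worked example (added, illustrative): if the first EOD token of a sample sits at index
# i=4 (prev_index=0), the line below subtracts 5 from every position from index 5 onward,
# so the document following the EOD restarts its position ids at 0.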
137 | if reset_position_ids: 138 | position_ids[b, (i + 1):] -= (i + 1 - prev_index) 139 | prev_index = i + 1 140 | 141 | return attention_mask, loss_mask, position_ids 142 | 143 | 144 | class ModelOutput(object): 145 | def __init__(self, logits, loss=None): 146 | self.logits = logits 147 | self.loss = loss 148 | 149 | def __getitem__(self, key): 150 | if key == "logits": 151 | return self.logits 152 | raise StopIteration 153 | 154 | 155 | class RuGPT3XL(PreTrainedModel): 156 | def __init__(self, model, tokenizer, model_path, seq_len=512, min_generated_len=32): 157 | super().__init__(PretrainedConfig()) 158 | self.model = model 159 | self.pad_token_id = tokenizer.encoder[''] 160 | self.eos_token_id = tokenizer.encoder[''] 161 | self.seq_len = seq_len 162 | self.model_path = model_path 163 | self.tokenizer = tokenizer 164 | self.min_generated_len = min_generated_len 165 | 166 | @classmethod 167 | def from_pretrained( 168 | cls, 169 | model_name_or_path=None, 170 | seq_len=512, 171 | weights_path=None, 172 | deepspeed_config_path=None, 173 | master_port="6000", 174 | min_generated_len=32, 175 | rank=0 176 | ): 177 | init_method = 'tcp://' + os.getenv('MASTER_ADDR', 'localhost') + ':' + os.getenv('MASTER_PORT', master_port) 178 | try: 179 | torch.distributed.init_process_group(backend='nccl', world_size=1, rank=rank, init_method=init_method) 180 | mpu.initialize_model_parallel(1) 181 | except RuntimeError: 182 | logger.info("The default process group has already initialized...") 183 | 184 | seed = 1234 185 | random.seed(seed) 186 | np.random.seed(seed) 187 | torch.manual_seed(seed) 188 | tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path) 189 | logger.info("Check cached model files...") 190 | if weights_path is None: 191 | weights_path, deepspeed_config_path = download_model_files(model_name_or_path) 192 | model = setup_model(weights_path, deepspeed_config_path) 193 | mpu.model_parallel_cuda_manual_seed(seed) 194 | # model.cuda() 195 | model = model.eval() 196 | return cls(model, tokenizer=tokenizer, seq_len=seq_len, model_path=model_name_or_path, min_generated_len=min_generated_len) 197 | 198 | def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs): 199 | kwargs.update({"input_ids": input_ids}) 200 | return kwargs 201 | 202 | def generate( 203 | self, text: Union[str, NoneType] = None, 204 | input_ids: Union[torch.LongTensor, NoneType] = None, 205 | max_length: Union[int, None] = None, 206 | min_length: Union[int, NoneType] = None, 207 | do_sample: Union[bool, NoneType] = None, 208 | early_stopping: Union[bool, NoneType] = None, 209 | num_beams: Union[int, NoneType] = None, 210 | temperature: Union[float, NoneType] = None, 211 | top_k: Union[int, NoneType] = None, 212 | top_p: Union[float, NoneType] = None, 213 | repetition_penalty: Union[float, NoneType] = None, 214 | bad_words_ids: Union[Iterable[int], NoneType] = None, 215 | bos_token_id: Union[int, NoneType] = None, 216 | pad_token_id: Union[int, NoneType] = None, 217 | eos_token_id: Union[int, NoneType] = None, 218 | length_penalty: Union[float, NoneType] = None, 219 | no_repeat_ngram_size: Union[int, NoneType] = None, 220 | num_return_sequences: Union[int, NoneType] = None, 221 | decoder_start_token_id: Union[int, NoneType] = None, 222 | use_cache: Union[bool, NoneType] = None, 223 | **model_kwargs): 224 | if text is not None: 225 | input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']]) 226 | if eos_token_id is None: 227 | eos_token_id = self.eos_token_id 228 | if pad_token_id is None: 
229 | pad_token_id = self.pad_token_id 230 | if input_ids.shape[-1] > 2048: 231 | input_ids = input_ids[:, -2048 + self.min_generated_len:] 232 | # print(input_ids.shape, max_length, mpu.get_data_parallel_rank()) 233 | res = super().generate( 234 | input_ids=input_ids, 235 | max_length=max_length, 236 | min_length=min_length, 237 | do_sample=do_sample, 238 | early_stopping=early_stopping, 239 | num_beams=num_beams, 240 | temperature=temperature, 241 | top_k=top_k, 242 | top_p=top_p, 243 | repetition_penalty=repetition_penalty, 244 | bad_words_ids=bad_words_ids, 245 | bos_token_id=bos_token_id, 246 | pad_token_id=pad_token_id, 247 | eos_token_id=eos_token_id, 248 | length_penalty=length_penalty, 249 | no_repeat_ngram_size=no_repeat_ngram_size, 250 | num_return_sequences=num_return_sequences, 251 | decoder_start_token_id=decoder_start_token_id, 252 | use_cache=use_cache, 253 | **model_kwargs 254 | ) 255 | return list(map(self.tokenizer.decode, res.tolist())) 256 | 257 | def __call__(self, text=None, input_ids=None, labels=None, **kwargs): 258 | if input_ids is None: 259 | if text is None: 260 | text = "" 261 | input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']]) 262 | if isinstance(input_ids, list): 263 | input_ids = torch.cuda.LongTensor(input_ids) 264 | if isinstance(labels, list): 265 | labels = torch.cuda.LongTensor(labels) 266 | res = [] 267 | if labels is not None: 268 | lbls = labels 269 | else: 270 | lbls = [None] * len(input_ids) 271 | loss = None 272 | original_context_length = 0 273 | seq_len = self.seq_len 274 | for tokens, lbl in zip(input_ids, lbls): 275 | context_tokens = tokens.tolist() 276 | original_context_length = len(context_tokens) 277 | if labels is not None: 278 | lbl = lbl.tolist() 279 | assert original_context_length 280 | 281 | while len(context_tokens) % 16: 282 | context_tokens.append(self.pad_token_id) 283 | if labels is not None: 284 | lbl.append(self.pad_token_id) 285 | context_tokens = context_tokens[-2048:] 286 | context_length = len(context_tokens) 287 | if labels is not None: 288 | lbl = lbl[-2048:] 289 | lbl = torch.cuda.LongTensor(lbl) 290 | context_tokens_tensor = torch.cuda.LongTensor(context_tokens) 291 | context_length_tensor = torch.cuda.LongTensor([context_length]) 292 | 293 | torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), 294 | group=mpu.get_model_parallel_group()) 295 | torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), 296 | group=mpu.get_model_parallel_group()) 297 | 298 | # context_length = context_length_tensor[0].item() 299 | # print(context_tokens_tensor.shape, original_context_length, seq_len, mpu.get_data_parallel_rank()) 300 | 301 | tokens = context_tokens_tensor 302 | tokens = tokens.view(1, -1).contiguous() 303 | tokens = tokens.to(torch.cuda.current_device()) 304 | attention_mask, loss_mask, position_ids = get_masks_and_position_ids(tokens, self.pad_token_id, False, 305 | False) 306 | lm_logits = self.model(tokens, position_ids, attention_mask) 307 | loss = None 308 | if labels is not None: 309 | # Shift so that tokens < n predict n 310 | shift_logits = lm_logits[..., :-1, :].contiguous() 311 | shift_labels = lbl[..., 1:].contiguous() 312 | # Flatten the tokens 313 | loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id) 314 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 315 | res.append((lm_logits, loss)) 316 | logits = torch.cat([x[0] for x in res], dim=0)[:, : original_context_length, :] 317 | if 
loss is not None: 318 | loss = [x[1] for x in res] 319 | # print(logits.shape, mpu.get_data_parallel_rank(), "------------") 320 | return ModelOutput(logits, loss) 321 | -------------------------------------------------------------------------------- /src/mpu/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | # Parts of the code here are adapted from PyTorch 18 | # repo: https://github.com/pytorch/pytorch 19 | 20 | 21 | import math 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | import torch.nn.init as init 26 | from torch.nn.parameter import Parameter 27 | 28 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 29 | 30 | from .initialize import get_model_parallel_rank 31 | from .initialize import get_model_parallel_world_size 32 | from .mappings import copy_to_model_parallel_region 33 | from .mappings import gather_from_model_parallel_region 34 | from .mappings import reduce_from_model_parallel_region 35 | from .mappings import scatter_to_model_parallel_region 36 | from .random import get_cuda_rng_tracker 37 | from .utils import divide 38 | from .utils import split_tensor_along_last_dim 39 | from .utils import VocabUtility 40 | 41 | 42 | def _initialize_affine_weight(weight, output_size, input_size, 43 | per_partition_size, partition_dim, init_method, 44 | stride=1, return_master_weight=False): 45 | """Initialize affine weight for model parallel. 46 | 47 | Build the master weight on all processes and scatter 48 | the relevant chunk.""" 49 | # If we only use 1 process for model parallelism, bypass scatter. 50 | world_size = get_model_parallel_world_size() 51 | if world_size == 1: 52 | init_method(weight) 53 | if return_master_weight: 54 | return weight 55 | return None 56 | 57 | # Initialize master weight 58 | master_weight = torch.empty(output_size, input_size, 59 | dtype=weight.dtype, 60 | requires_grad=False) 61 | init_method(master_weight) 62 | 63 | # Split and copy 64 | per_partition_per_stride_size = divide(per_partition_size, stride) 65 | weight_list = torch.split(master_weight, per_partition_per_stride_size, 66 | dim=partition_dim) 67 | rank = get_model_parallel_rank() 68 | my_weight_list = weight_list[rank::world_size] 69 | 70 | with torch.no_grad(): 71 | torch.cat(my_weight_list, dim=partition_dim, out=weight) 72 | if return_master_weight: 73 | return master_weight 74 | return None 75 | 76 | 77 | class VocabParallelEmbedding(torch.nn.Module): 78 | """Embedding parallelized in the vocabulary dimension. 79 | 80 | This is mainly adapted from torch.nn.Embedding and all the default 81 | values are kept. 82 | Arguments: 83 | num_embeddings: vocabulary size. 84 | embedding_dim: size of hidden state. 85 | init_method: method to initialize weights. 
86 | """ 87 | def __init__(self, num_embeddings, embedding_dim, 88 | init_method=init.xavier_normal_): 89 | super(VocabParallelEmbedding, self).__init__() 90 | # Keep the input dimensions. 91 | self.num_embeddings = num_embeddings 92 | self.embedding_dim = embedding_dim 93 | # Set the detauls for compatibility. 94 | self.padding_idx = None 95 | self.max_norm = None 96 | self.norm_type = 2. 97 | self.scale_grad_by_freq = False 98 | self.sparse = False 99 | self._weight = None 100 | # Divide the weight matrix along the vocaburaly dimension. 101 | self.vocab_start_index, self.vocab_end_index = \ 102 | VocabUtility.vocab_range_from_global_vocab_size( 103 | self.num_embeddings, get_model_parallel_rank(), 104 | get_model_parallel_world_size()) 105 | self.num_embeddings_per_partition = self.vocab_end_index - \ 106 | self.vocab_start_index 107 | 108 | # Allocate weights. 109 | self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, 110 | self.embedding_dim)) 111 | self.weight.model_parallel = True 112 | # And initialize. 113 | _initialize_affine_weight( 114 | self.weight, self.num_embeddings, self.embedding_dim, 115 | self.num_embeddings_per_partition, 0, init_method) 116 | 117 | def forward(self, input_): 118 | # Build the mask. 119 | input_mask = (input_ < self.vocab_start_index) | \ 120 | (input_ >= self.vocab_end_index) 121 | # Mask the input. 122 | masked_input = input_.clone() - self.vocab_start_index 123 | masked_input[input_mask] = 0 124 | # Get the embeddings. 125 | output_parallel = F.embedding(masked_input, self.weight, 126 | self.padding_idx, self.max_norm, 127 | self.norm_type, self.scale_grad_by_freq, 128 | self.sparse) 129 | # Mask the output embedding. 130 | output_parallel[input_mask, :] = 0.0 131 | # Reduce across all the model parallel GPUs. 132 | output = reduce_from_model_parallel_region(output_parallel) 133 | return output 134 | 135 | 136 | class ParallelEmbedding(torch.nn.Module): 137 | """Embedding parallelized in the embedding dimension. 138 | 139 | This is mainly adapted from torch.nn.Embedding and all the default 140 | values are kept. 141 | Arguments: 142 | num_embeddings: vocabulary size. 143 | embedding_dim: size of hidden state. 144 | init_method: method to initialize weights. 145 | """ 146 | def __init__(self, num_embeddings, embedding_dim, 147 | init_method=init.xavier_normal_, 148 | keep_master_weight_for_test=False): 149 | super(ParallelEmbedding, self).__init__() 150 | # Keep the input dimensions. 151 | self.num_embeddings = num_embeddings 152 | self.embedding_dim = embedding_dim 153 | # Set some detauls for compatibility. 154 | self.padding_idx = None 155 | self.max_norm = None 156 | self.norm_type = 2. 157 | self.scale_grad_by_freq = False 158 | self.sparse = False 159 | self._weight = None 160 | # Divide the weight matrix along the embedding dimension. 161 | world_size = get_model_parallel_world_size() 162 | self.embedding_dim_per_partition = divide(self.embedding_dim, 163 | world_size) 164 | 165 | # Allocate weights. 166 | self.weight = Parameter(torch.Tensor(self.num_embeddings, 167 | self.embedding_dim_per_partition)) 168 | self.weight.model_parallel = True 169 | # And initialize. 
170 | _initialize_affine_weight( 171 | self.weight, self.num_embeddings, self.embedding_dim, 172 | self.embedding_dim_per_partition, 1, init_method, 173 | stride=1, return_master_weight=False) 174 | 175 | def forward(self, input_): 176 | input_parallel = copy_to_model_parallel_region(input_) 177 | output_parallel = F.embedding(input_parallel, self.weight, 178 | self.padding_idx, self.max_norm, 179 | self.norm_type, self.scale_grad_by_freq, 180 | self.sparse) 181 | output = gather_from_model_parallel_region(output_parallel) 182 | return output 183 | 184 | 185 | class ColumnParallelLinear(torch.nn.Module): 186 | """Linear layer with column parallelism. 187 | 188 | The linear layer is defined as Y = XA + b. A is parallelized along 189 | its second dimension as A = [A_1, ..., A_p]. 190 | 191 | Arguments: 192 | input_size: first dimension of matrix A. 193 | output_size: second dimension of matrix A. 194 | bias: If true, add bias 195 | gather_output: If true, call all-gether on output and make Y avaiable 196 | to all GPUs, otherwise, every GPU will have its output 197 | which is Y_i = XA_i 198 | init_method: method to initialize weights. Note that bias is always set 199 | to zero. 200 | stride: For the strided linear layers. 201 | keep_master_weight_for_test: This was added for testing and should be 202 | set to False. It returns the master weights 203 | used for initialization. 204 | """ 205 | def __init__(self, input_size, output_size, bias=True, gather_output=True, 206 | init_method=init.xavier_normal_, stride=1, 207 | keep_master_weight_for_test=False): 208 | super(ColumnParallelLinear, self).__init__() 209 | 210 | # Keep input parameters 211 | self.input_size = input_size 212 | self.output_size = output_size 213 | self.gather_output = gather_output 214 | # Divide the weight matrix along the last dimension. 215 | world_size = get_model_parallel_world_size() 216 | self.output_size_per_partition = divide(output_size, world_size) 217 | 218 | # Parameters. 219 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 220 | # we allocate the transpose. 221 | self.weight = Parameter(torch.Tensor(self.output_size_per_partition, 222 | self.input_size)) 223 | self.weight.model_parallel = True 224 | if bias: 225 | self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) 226 | self.bias.model_parallel = True 227 | # Always initialize bias to zero. 228 | with torch.no_grad(): 229 | self.bias.zero_() 230 | else: 231 | self.register_parameter('bias', None) 232 | 233 | # Initialize weight. 234 | self.master_weight = _initialize_affine_weight( 235 | self.weight, self.output_size, self.input_size, 236 | self.output_size_per_partition, 0, init_method, 237 | stride=stride, return_master_weight=keep_master_weight_for_test) 238 | 239 | def forward(self, input_): 240 | # Set up backprop all-reduce. 241 | input_parallel = copy_to_model_parallel_region(input_) 242 | # Matrix multiply. 243 | output_parallel = F.linear(input_parallel, self.weight, self.bias) 244 | if self.gather_output: 245 | # All-gather across the partitions. 246 | output = gather_from_model_parallel_region(output_parallel) 247 | else: 248 | output = output_parallel 249 | return output 250 | 251 | 252 | class RowParallelLinear(torch.nn.Module): 253 | """Linear layer with row parallelism. 254 | 255 | The linear layer is defined as Y = XA + b. A is parallelized along 256 | its first dimension and X along its second dimension as: 257 | - - 258 | | A_1 | 259 | | . | 260 | A = | . | X = [X_1, ..., X_p] 261 | | . 
| 262 | | A_p | 263 | - - 264 | Arguments: 265 | input_size: first dimension of matrix A. 266 | output_size: second dimension of matrix A. 267 | bias: If true, add bias. Note that bias is not parallelized. 268 | input_is_parallel: If true, we assume that the input is already 269 | split across the GPUs and we do not split 270 | again. 271 | init_method: method to initialize weights. Note that bias is always set 272 | to zero. 273 | stride: For the strided linear layers. 274 | keep_master_weight_for_test: This was added for testing and should be 275 | set to False. It returns the master weights 276 | used for initialization. 277 | """ 278 | def __init__(self, input_size, output_size, bias=True, 279 | input_is_parallel=False, 280 | init_method=init.xavier_normal_, stride=1, 281 | keep_master_weight_for_test=False): 282 | super(RowParallelLinear, self).__init__() 283 | 284 | # Keep input parameters 285 | self.input_size = input_size 286 | self.output_size = output_size 287 | self.input_is_parallel = input_is_parallel 288 | # Divide the weight matrix along the last dimension. 289 | world_size = get_model_parallel_world_size() 290 | self.input_size_per_partition = divide(input_size, world_size) 291 | 292 | # Parameters. 293 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 294 | # we allocate the transpose. 295 | self.weight = Parameter(torch.Tensor(self.output_size, 296 | self.input_size_per_partition)) 297 | self.weight.model_parallel = True 298 | if bias: 299 | self.bias = Parameter(torch.Tensor(self.output_size)) 300 | # Always initialize bias to zero. 301 | with torch.no_grad(): 302 | self.bias.zero_() 303 | else: 304 | self.register_parameter('bias', None) 305 | 306 | # Initialize weight. 307 | self.master_weight = _initialize_affine_weight( 308 | self.weight, self.output_size, self.input_size, 309 | self.input_size_per_partition, 1, init_method, 310 | stride=stride, return_master_weight=keep_master_weight_for_test) 311 | 312 | def forward(self, input_): 313 | # Set up backprop all-reduce. 314 | if self.input_is_parallel: 315 | input_parallel = input_ 316 | else: 317 | input_parallel = scatter_to_model_parallel_region(input_) 318 | # Matrix multiply. 319 | output_parallel = F.linear(input_parallel, self.weight) 320 | # All-reduce across all the partitions. 321 | output_ = reduce_from_model_parallel_region(output_parallel) 322 | if self.bias is not None: 323 | output = output_ + self.bias 324 | else: 325 | output = output_ 326 | return output 327 | 328 | -------------------------------------------------------------------------------- /src/mpu/random.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | #Modified by Samyam Rajbhandari 3 | #Used to partition the activations stored for backward propagation 4 | #Therefore reduces the memory consumption 5 | 6 | # Copyright (c) 2020, Sber. All rights reserved. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
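# Overview of this module: it provides a CudaRNGStatesTracker that keeps
# separately seeded CUDA RNG streams under string names (so model-parallel and
# data-parallel dropout stay reproducible), and a CheckpointFunction that
# re-runs the forward pass during backward, optionally all-gathering
# activations that were partitioned across model-parallel ranks.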
19 | 20 | 21 | # Parts of the code here are adapted from PyTorch 22 | # repo: https://github.com/pytorch/pytorch 23 | import contextlib 24 | import torch.distributed as dist 25 | import torch 26 | from torch import _C 27 | from torch.cuda import _lazy_call, device as device_ctx_manager 28 | #from torch.utils.checkpoint import detach_variable 29 | 30 | 31 | import torch.distributed as dist 32 | PARTITION_ACTIVATIONS = False 33 | PA_CORRECTNESS_TEST= False 34 | 35 | def see_memory_usage(message, force=False): 36 | if not force: 37 | return 38 | dist.barrier() 39 | if dist.get_rank() == 0: 40 | print(message) 41 | print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes") 42 | print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes") 43 | print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes") 44 | print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes") 45 | print(" ") 46 | #input("Press Any Key To Continue ..") 47 | 48 | 49 | from .initialize import get_data_parallel_rank 50 | from .initialize import get_model_parallel_rank 51 | from .initialize import get_model_parallel_world_size 52 | from .initialize import get_model_parallel_group 53 | 54 | mp_rank = None #get_model_parallel_rank() 55 | mp_size = None #get_model_parallel_world_size() 56 | mp_group = None #get_model_parallel_group() 57 | 58 | # Default name for the model parallel rng tracker. 59 | _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' 60 | transport_stream = None 61 | cuda_device=None 62 | def detach_variable(inputs, device=None): 63 | if isinstance(inputs, tuple): 64 | out = [] 65 | for inp in inputs: 66 | if not isinstance(inp, torch.Tensor): 67 | out.append(inp) 68 | continue 69 | 70 | requires_grad = inp.requires_grad 71 | 72 | if device is not None: 73 | x = inp.to(device=device) 74 | else: 75 | x = inp 76 | 77 | x = x.detach() 78 | x.requires_grad = requires_grad 79 | out.append(x) 80 | return tuple(out) 81 | else: 82 | raise RuntimeError( 83 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) 84 | 85 | def _set_cuda_rng_state(new_state, device=-1): 86 | """Sets the random number generator state of the current GPU. 87 | 88 | Argumentss: 89 | new_state (torch.ByteTensor): The desired state 90 | This function is adapted from PyTorch repo (torch.cuda.set_rng_state) 91 | with a single change: the input state is not cloned. Cloning caused 92 | major performance issues for +4 GPU cases. 93 | """ 94 | if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): 95 | # older PyTorch 96 | def cb(): 97 | with device_ctx_manager(device): 98 | _C._cuda_setRNGState(new_state) 99 | else: 100 | # newer PyTorch 101 | if device == -1: 102 | device = torch.device('cuda') 103 | elif isinstance(device, str): 104 | device = torch.device(device) 105 | elif isinstance(device, int): 106 | device = torch.device('cuda', device) 107 | 108 | def cb(): 109 | idx = device.index 110 | if idx is None: 111 | idx = torch.cuda.current_device() 112 | default_generator = torch.cuda.default_generators[idx] 113 | default_generator.set_state(new_state) 114 | 115 | _lazy_call(cb) 116 | 117 | 118 | 119 | class CudaRNGStatesTracker: 120 | """Tracker for the cuda RNG states. 121 | 122 | Using the `add` method, a cuda rng state is initialized based on 123 | the input `seed` and is assigned to `name`. 
Later, by forking the 124 | rng state, we can perform operations and return to our starting 125 | cuda state. 126 | """ 127 | def __init__(self): 128 | # Map from a string name to the cuda rng state. 129 | self.states_ = {} 130 | # Seeds are just for book keeping and ensure no seed is set twice. 131 | self.seeds_ = set() 132 | 133 | def reset(self): 134 | """Set to the initial state (no tracker).""" 135 | self.states_ = {} 136 | self.seeds_ = set() 137 | 138 | def get_states(self): 139 | """Get rng states. Copy the dictionary so we have direct 140 | pointers to the states, not just a pointer to the dictionary.""" 141 | states = {} 142 | for name in self.states_: 143 | states[name] = self.states_[name] 144 | return states 145 | 146 | def set_states(self, states): 147 | """Set the rng states. For efficiency purposes, we do not check 148 | the size of seed for compatibility.""" 149 | self.states_ = states 150 | 151 | def add(self, name, seed): 152 | """Track the rng state.""" 153 | # Check seed is not already used. 154 | if seed in self.seeds_: 155 | raise Exception('seed {} already exists'.format(seed)) 156 | self.seeds_.add(seed) 157 | # Check that state is not already defined. 158 | if name in self.states_: 159 | raise Exception('cuda rng state {} already exists'.format(name)) 160 | # Get the current rng state. 161 | orig_rng_state = torch.cuda.get_rng_state() 162 | # Set the new state and store it. 163 | torch.cuda.manual_seed(seed) 164 | self.states_[name] = torch.cuda.get_rng_state() 165 | # Reset rng state to what it was. 166 | _set_cuda_rng_state(orig_rng_state) 167 | 168 | @contextlib.contextmanager 169 | def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): 170 | """Fork the cuda rng state, perform operations, and exit with 171 | the original state.""" 172 | # Check if we have added the state 173 | if name not in self.states_: 174 | raise Exception('cuda rng state {} is not added'.format(name)) 175 | # Store current rng state. 176 | orig_cuda_rng_state = torch.cuda.get_rng_state() 177 | # Set rng state to the desired one 178 | _set_cuda_rng_state(self.states_[name]) 179 | # Do the stuff we wanted to do. 180 | try: 181 | yield 182 | finally: 183 | # Update the current rng state for later use. 184 | self.states_[name] = torch.cuda.get_rng_state() 185 | # And set the state to the original state we started with. 186 | _set_cuda_rng_state(orig_cuda_rng_state) 187 | 188 | 189 | # RNG tracker object. 190 | _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() 191 | 192 | 193 | def get_cuda_rng_tracker(): 194 | """Get cuda rng tracker.""" 195 | return _CUDA_RNG_STATE_TRACKER 196 | 197 | 198 | def model_parallel_cuda_manual_seed(seed): 199 | """Initialize model parallel cuda seed. 200 | 201 | This function should be called after the model parallel is 202 | initialized. Also, no torch.cuda.manual_seed should be called 203 | after this function. Basically, this is replacement for that 204 | function. 205 | Two set of RNG states are tracked: 206 | default state: This is for data parallelism and is the same among a 207 | set of model parallel GPUs but different across 208 | different model paralle groups. This is used for 209 | example for dropout in the non-model-parallel regions. 210 | model-parallel state: This state is different among a set of model 211 | parallel GPUs, but the same across data parallel 212 | groups. This is used for example for dropout in 213 | model parallel regions. 214 | """ 215 | # 2718 is just for fun and any POSITIVE value will work. 
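    # For example, with the default seed of 1234 and model parallel rank r, the
    # model-parallel RNG stream is seeded with 1234 + 2718 + r, while the
    # default CUDA generator keeps the original seed 1234 for data parallelism.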
216 | offset = seed + 2718 217 | model_parallel_seed = offset + get_model_parallel_rank() 218 | # Data parallel gets the original sedd. 219 | data_parallel_seed = seed 220 | 221 | if torch.distributed.get_rank() == 0: 222 | print('> initializing model parallel cuda seeds on global rank {}, ' 223 | 'model parallel rank {}, and data parallel rank {} with ' 224 | 'model parallel seed: {} and data parallel seed: {}'.format( 225 | torch.distributed.get_rank(), get_model_parallel_rank(), 226 | get_data_parallel_rank(), model_parallel_seed, 227 | data_parallel_seed), flush=True) 228 | _CUDA_RNG_STATE_TRACKER.reset() 229 | # Set the default state. 230 | torch.cuda.manual_seed(data_parallel_seed) 231 | # and model parallel state. 232 | _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 233 | model_parallel_seed) 234 | 235 | 236 | def get_partition_start(item): 237 | global mp_rank, mp_size, mp_group 238 | partition_size = get_partition_size(item) 239 | start = partition_size * mp_rank 240 | return int(start) 241 | 242 | def get_partition_size(item): 243 | global mp_rank, mp_size, mp_group 244 | size = item.numel() 245 | partition_size = size/mp_size 246 | return int(partition_size) 247 | 248 | def get_full_inputs(tensors): 249 | inputs=[] 250 | for i in range(int(len(tensors)/2)-1): 251 | item = tensors[2 * i] 252 | size = tensors[2* i + 1] 253 | partition_size = item.numel() 254 | tensor_size = partition_size * mp_size 255 | flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device) 256 | partitions=[] 257 | for i in range(mp_size): 258 | part_i = flat_tensor.narrow(0, partition_size * i , partition_size) 259 | if i == mp_rank: 260 | part_i.copy_(item) 261 | partitions.append(part_i) 262 | dist.all_gather(partitions,partitions[mp_rank], group=mp_group) 263 | input_tensor = flat_tensor.view(list(size.numpy())) 264 | item.data=input_tensor.data 265 | 266 | inputs.append(item) 267 | inputs.append(tensors[-2]) 268 | 269 | return tuple(inputs) 270 | 271 | 272 | 273 | class CheckpointFunction(torch.autograd.Function): 274 | """This function is adapted from torch.utils.checkpoint with 275 | two main changes: 276 | 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` 277 | 2) the states in the model parallel tracker are also properly 278 | tracked/set/reset. 279 | """ 280 | @staticmethod 281 | def forward(ctx, run_function, *args): 282 | ctx.run_function = run_function 283 | global mp_rank, mp_size, mp_group 284 | if mp_rank is None: 285 | mp_rank = get_model_parallel_rank() 286 | mp_size = get_model_parallel_world_size() 287 | mp_group = get_model_parallel_group() 288 | 289 | 290 | global cuda_device, transport_stream, PARTITION_ACTIVATIONS 291 | if cuda_device is None: 292 | if dist.get_rank() == 0: 293 | print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}") 294 | 295 | cuda_device = torch.cuda.current_device() 296 | #The transport stream is used to overlap the allgather communication for the activations 297 | #with the computation in the backward pass 298 | transport_stream = torch.cuda.Stream(device=cuda_device) 299 | 300 | if PARTITION_ACTIVATIONS: 301 | inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]] 302 | inputs.append(args[-1]) 303 | 304 | #just in case something funky is happening such as reuse of inputs 305 | inputs_cuda = [item.to(cuda_device) for item in args] 306 | 307 | # Copy the rng states. 
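        # Saving the CPU/CUDA RNG states (and the model-parallel tracker states)
        # here lets backward() restore them before re-running run_function, so
        # the recomputed forward pass reproduces the same dropout masks as the
        # original one.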
308 | ctx.fwd_cpu_rng_state = torch.get_rng_state() 309 | ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() 310 | ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() 311 | 312 | #ctx.save_for_backward(*args) 313 | with torch.no_grad(): 314 | outputs = run_function(*inputs_cuda) 315 | 316 | del inputs_cuda 317 | 318 | if PARTITION_ACTIVATIONS: 319 | new_args = [] 320 | for arg, inp in zip(args,inputs): 321 | size= torch.tensor(arg.size()) 322 | arg.data = inp.data 323 | new_args.append(arg) 324 | new_args.append(size) 325 | ctx.save_for_backward(*new_args) 326 | else: 327 | ctx.save_for_backward(*args) 328 | 329 | return outputs 330 | 331 | @staticmethod 332 | def backward(ctx, *args): 333 | if not torch.autograd._is_checkpoint_valid(): 334 | raise RuntimeError("Checkpointing is not compatible with .grad(), " 335 | "please use .backward() if possible") 336 | 337 | global cuda_device, transport_stream, PARTITION_ACTIVATIONS 338 | 339 | if PARTITION_ACTIVATIONS: 340 | with torch.cuda.stream(transport_stream): 341 | inputs = get_full_inputs(ctx.saved_tensors) 342 | detached_inputs = detach_variable(inputs) 343 | else: 344 | inputs = ctx.saved_tensors 345 | detached_inputs = detach_variable(inputs) 346 | 347 | # Store the current states. 348 | bwd_cpu_rng_state = torch.get_rng_state() 349 | bwd_cuda_rng_state = torch.cuda.get_rng_state() 350 | bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() 351 | 352 | # Set the states to what it used to be before the forward pass. 353 | torch.set_rng_state(ctx.fwd_cpu_rng_state) 354 | _set_cuda_rng_state(ctx.fwd_cuda_rng_state) 355 | get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) 356 | 357 | if PARTITION_ACTIVATIONS: 358 | current_stream=torch.cuda.current_stream() 359 | current_stream.wait_stream(transport_stream) 360 | 361 | with torch.enable_grad(): 362 | outputs = ctx.run_function(*detached_inputs) 363 | 364 | # Set the states back to what it was at the start of this function. 365 | torch.set_rng_state(bwd_cpu_rng_state) 366 | _set_cuda_rng_state(bwd_cuda_rng_state) 367 | get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) 368 | 369 | if isinstance(outputs, torch.Tensor): 370 | outputs = (outputs,) 371 | torch.autograd.backward(outputs, args) 372 | return (None,) + tuple(inp.grad for inp in detached_inputs) 373 | 374 | 375 | def checkpoint(function, *args): 376 | """Checkpoint a model or part of the model. 377 | This has been directly copied from torch.utils.checkpoint.""" 378 | return CheckpointFunction.apply(function, *args) 379 | 380 | def partition_activations_in_checkpoint(partition_activation): 381 | global PARTITION_ACTIVATIONS 382 | PARTITION_ACTIVATIONS=partition_activation 383 | if dist.get_rank() == 0: 384 | print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************") 385 | 386 | 387 | -------------------------------------------------------------------------------- /src/mpu/transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Transformer.""" 17 | 18 | import math 19 | 20 | import torch 21 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 22 | from .initialize import get_model_parallel_world_size 23 | from .layers import ColumnParallelLinear 24 | from .layers import RowParallelLinear 25 | from .utils import divide 26 | from .utils import split_tensor_along_last_dim 27 | from src.utils import DEEPSPEED_WRAP 28 | 29 | 30 | class GPT3ParallelSelfAttention(torch.nn.Module): 31 | """Parallel self-attention layer for GPT3. 32 | 33 | Self-attention layer takes input with size [b, s, h] where b is 34 | the batch size, s is the sequence lenght, and h is the hidden size 35 | and creates output of the same size. 36 | Arguments: 37 | hidden_size: total hidden size of the layer (h). 38 | num_attention_heads: number of attention heads (n). Note that we 39 | require n to be divisible by number of GPUs 40 | used to parallelize the model. Also, we 41 | require hidden size to be divisible by n. 42 | dropout_prob: dropout probability for the attention scores. 43 | init_method: weight initialization. 44 | output_layer_init_method: output layer initialization. If None, use 45 | `init_method`. 46 | We use the following notation: 47 | h: hidden_size 48 | n: num_attention_heads 49 | p: number of partitions 50 | np: n/p 51 | hp: h/p 52 | hn: h/n 53 | b: batch size 54 | s: sequence length 55 | """ 56 | 57 | def __init__(self, hidden_size, num_attention_heads, 58 | attention_dropout_prob, output_dropout_prob, 59 | init_method, output_layer_init_method=None, 60 | use_deepspeed_sparse=None): 61 | super(GPT3ParallelSelfAttention, self).__init__() 62 | self.use_deepspeed_sparse = use_deepspeed_sparse 63 | if DEEPSPEED_WRAP: 64 | deepspeed = DEEPSPEED_WRAP.deepspeed 65 | from deepspeed.ops.sparse_attention import SparseSelfAttention 66 | if self.use_deepspeed_sparse is not None: 67 | self.sparse_self_attention = SparseSelfAttention(self.use_deepspeed_sparse) 68 | # Set output layer initialization if not provided. 69 | if output_layer_init_method is None: 70 | output_layer_init_method = init_method 71 | # Per attention head and per partition values. 72 | world_size = get_model_parallel_world_size() 73 | self.hidden_size_per_partition = divide(hidden_size, world_size) 74 | self.hidden_size_per_attention_head = divide(hidden_size, 75 | num_attention_heads) 76 | self.num_attention_heads_per_partition = divide(num_attention_heads, 77 | world_size) 78 | # Strided linear layer. 79 | self.query_key_value = ColumnParallelLinear(hidden_size, 3 * hidden_size, 80 | stride=3, 81 | gather_output=False, 82 | init_method=init_method) 83 | # Dropout. Note that for a single iteration, this layer will generate 84 | # different outputs on different number of parallel partitions but 85 | # on average it should not be partition dependent. 86 | self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) 87 | 88 | # Output. 
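        # The query_key_value projection above is column-parallel with
        # gather_output=False, so each rank holds only its own attention heads;
        # the output projection below is therefore row-parallel with
        # input_is_parallel=True and ends with an all-reduce across partitions.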
89 | self.dense = RowParallelLinear(hidden_size, 90 | hidden_size, 91 | input_is_parallel=True, 92 | init_method=output_layer_init_method) 93 | self.output_dropout = torch.nn.Dropout(output_dropout_prob) 94 | 95 | if DEEPSPEED_WRAP: 96 | if DEEPSPEED_WRAP.deepspeed.checkpointing.is_configured(): 97 | global get_cuda_rng_tracker, checkpoint 98 | get_cuda_rng_tracker = DEEPSPEED_WRAP.deepspeed.checkpointing.get_cuda_rng_tracker 99 | checkpoint = deepspeed.checkpointing.checkpoint 100 | 101 | def _transpose_for_scores(self, tensor): 102 | """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with 103 | size [b, np, s, hn]. 104 | """ 105 | new_tensor_shape = tensor.size()[:-1] + \ 106 | (self.num_attention_heads_per_partition, 107 | self.hidden_size_per_attention_head) 108 | tensor = tensor.view(*new_tensor_shape) 109 | return tensor.permute(0, 2, 1, 3) 110 | 111 | def forward(self, hidden_states, ltor_mask): 112 | # hidden_states: [b, s, h] 113 | # ltor_mask: [1, 1, s, s] 114 | 115 | # Attention heads. [b, s, hp] 116 | mixed_x_layer = self.query_key_value(hidden_states) 117 | (mixed_query_layer, 118 | mixed_key_layer, 119 | mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) 120 | 121 | # Reshape and transpose [b, np, s, hn] 122 | query_layer = self._transpose_for_scores(mixed_query_layer) 123 | key_layer = self._transpose_for_scores(mixed_key_layer) 124 | value_layer = self._transpose_for_scores(mixed_value_layer) 125 | 126 | if self.use_deepspeed_sparse: 127 | context_layer = self.sparse_self_attention( 128 | query_layer, 129 | key_layer, 130 | value_layer, 131 | attn_mask=ltor_mask) 132 | else: 133 | # Raw attention scores. [b, np, s, s] 134 | attention_scores = torch.matmul(query_layer, 135 | key_layer.transpose(-1, -2)) 136 | attention_scores = attention_scores / math.sqrt( 137 | self.hidden_size_per_attention_head) 138 | # Apply the left to right attention mask. 139 | attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask) 140 | 141 | # Attention probabilities. [b, np, s, s] 142 | attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) 143 | # This is actually dropping out entire tokens to attend to, which might 144 | # seem a bit unusual, but is taken from the original Transformer paper. 145 | if DEEPSPEED_WRAP and DEEPSPEED_WRAP.deepspeed.checkpointing.is_configured(): 146 | with get_cuda_rng_tracker().fork(): 147 | attention_probs = self.attention_dropout(attention_probs) 148 | else: 149 | attention_probs = self.attention_dropout(attention_probs) 150 | 151 | # Context layer. 152 | # [b, np, s, hn] 153 | context_layer = torch.matmul(attention_probs, value_layer) 154 | # [b, s, np, hn] 155 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 156 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) 157 | # [b, s, hp] 158 | context_layer = context_layer.view(*new_context_layer_shape) 159 | 160 | # Output. [b, s, h] 161 | output = self.dense(context_layer) 162 | output = self.output_dropout(output) 163 | 164 | return output 165 | 166 | 167 | # @torch.jit.script 168 | # Remove torch.jit for colab 169 | def gelu_impl(x): 170 | """OpenAI's gelu implementation.""" 171 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 172 | (1.0 + 0.044715 * x * x))) 173 | 174 | 175 | def gelu(x): 176 | return gelu_impl(x) 177 | 178 | 179 | class GPT3ParallelMLP(torch.nn.Module): 180 | """MLP for GPT3. 
181 | 182 | MLP will take the input with h hidden state, project it to 4*h 183 | hidden dimension, perform gelu transformation, and project the 184 | state back into h hidden dimension. At the end, dropout is also 185 | applied. 186 | 187 | Arguments: 188 | hidden_size: The hidden size of the self attention. 189 | output_dropout_prob: dropout probability for the outputs 190 | after self attention and final output. 191 | init_method: initialization method used for the weights. Note 192 | that all biases are initialized to zero and 193 | layernorm weight are initialized to one. 194 | output_layer_init_method: output layer initialization. If None, 195 | use `init_method`. 196 | """ 197 | 198 | def __init__(self, hidden_size, output_dropout_prob, init_method, 199 | output_layer_init_method=None): 200 | super(GPT3ParallelMLP, self).__init__() 201 | # Set output layer initialization if not provided. 202 | if output_layer_init_method is None: 203 | output_layer_init_method = init_method 204 | # Project to 4h. 205 | self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size, 206 | gather_output=False, 207 | init_method=init_method) 208 | # Project back to h. 209 | self.dense_4h_to_h = RowParallelLinear( 210 | 4 * hidden_size, 211 | hidden_size, 212 | input_is_parallel=True, 213 | init_method=output_layer_init_method) 214 | self.dropout = torch.nn.Dropout(output_dropout_prob) 215 | 216 | def forward(self, hidden_states): 217 | # [b, s, 4hp] 218 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 219 | intermediate_parallel = gelu(intermediate_parallel) 220 | 221 | # [b, s, h] 222 | output = self.dense_4h_to_h(intermediate_parallel) 223 | output = self.dropout(output) 224 | return output 225 | 226 | 227 | class GPT3ParallelTransformerLayer(torch.nn.Module): 228 | """A single layer transformer for GPT3. 229 | 230 | We use the following notation: 231 | h: hidden size 232 | n: number of attention heads 233 | b: batch size 234 | s: sequence length 235 | Transformore layer takes input with size [b, s, h] and returns an 236 | output of the same size. 237 | 238 | Arguments: 239 | hidden_size: The hidden size of the self attention. 240 | num_attention_heads: number of attention head in the self 241 | attention. 242 | attention_dropout_prob: dropout probability of the attention 243 | score in self attention. 244 | output_dropout_prob: dropout probability for the outputs 245 | after self attention and final output. 246 | layernorm_epsilon: epsilon used in layernorm to avoid 247 | division by zero. 248 | init_method: initialization method used for the weights. Note 249 | that all biases are initialized to zero and 250 | layernorm weight are initialized to one. 251 | output_layer_init_method: output layers (attention output and 252 | mlp output) initialization. If None, 253 | use `init_method`. 254 | """ 255 | 256 | def __init__(self, 257 | hidden_size, 258 | num_attention_heads, 259 | attention_dropout_prob, 260 | output_dropout_prob, 261 | layernorm_epsilon, 262 | init_method, 263 | output_layer_init_method=None, 264 | use_deepspeed_sparse=None): 265 | super(GPT3ParallelTransformerLayer, self).__init__() 266 | # Set output layer initialization if not provided. 267 | if output_layer_init_method is None: 268 | output_layer_init_method = init_method 269 | 270 | # Layernorm on the input data. 271 | self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 272 | 273 | # Self attention. 
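        # Pre-LayerNorm block, GPT-2/GPT-3 style: layernorm -> self attention ->
        # residual add -> layernorm -> MLP -> residual add (see forward below).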
274 | self.attention = GPT3ParallelSelfAttention( 275 | hidden_size, 276 | num_attention_heads, 277 | attention_dropout_prob, 278 | output_dropout_prob, 279 | init_method, 280 | output_layer_init_method=output_layer_init_method, 281 | use_deepspeed_sparse=use_deepspeed_sparse) 282 | 283 | # Layernorm on the input data. 284 | self.post_attention_layernorm = LayerNorm(hidden_size, 285 | eps=layernorm_epsilon) 286 | 287 | # MLP 288 | self.mlp = GPT3ParallelMLP( 289 | hidden_size, 290 | output_dropout_prob, 291 | init_method, 292 | output_layer_init_method=output_layer_init_method) 293 | 294 | def forward(self, hidden_states, ltor_mask): 295 | # hidden_states: [b, s, h] 296 | # ltor_mask: [1, 1, s, s] 297 | 298 | # Layer norm at the begining of the transformer layer. 299 | layernorm_output = self.input_layernorm(hidden_states) 300 | # Self attention. 301 | attention_output = self.attention(layernorm_output, ltor_mask) 302 | # Residual connection. 303 | layernorm_input = hidden_states + attention_output 304 | # Layer norm post the self attention. 305 | layernorm_output = self.post_attention_layernorm(layernorm_input) 306 | # MLP. 307 | mlp_output = self.mlp(layernorm_output) 308 | # Second residual connection. 309 | output = layernorm_input + mlp_output 310 | 311 | return output 312 | 313 | 314 | def unscaled_init_method(sigma): 315 | """Init method based on N(0, sigma).""" 316 | 317 | def init_(tensor): 318 | return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) 319 | 320 | return init_ 321 | 322 | 323 | def scaled_init_method(sigma, num_layers): 324 | """Init method based on N(0, sigma/sqrt(2*num_layers).""" 325 | std = sigma / math.sqrt(2.0 * num_layers) 326 | 327 | def init_(tensor): 328 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 329 | 330 | return init_ 331 | 332 | 333 | class GPT3ParallelTransformer(torch.nn.Module): 334 | """GPT-3 transformer. 335 | 336 | This module takes input from embedding layer and it's output can 337 | be used directly by a logit layer. It consists of L (num-layers) 338 | blocks of: 339 | layer norm 340 | self attention 341 | residual connection 342 | layer norm 343 | mlp 344 | residual connection 345 | followed by a final layer norm. 346 | 347 | Arguments: 348 | num_layers: Number of transformer layers. 349 | hidden_size: The hidden size of the self attention. 350 | num_attention_heads: number of attention head in the self 351 | attention. 352 | attention_dropout_prob: dropout probability of the attention 353 | score in self attention. 354 | output_dropout_prob: dropout probability for the outputs 355 | after self attention and final output. 356 | checkpoint_activations: if True, checkpoint activations. 357 | checkpoint_num_layers: number of layers to checkpoint. This 358 | is basically the chunk size in checkpoitning. 359 | layernorm_epsilon: epsilon used in layernorm to avoid 360 | division by zero. 361 | init_method_std: standard deviation of the init method which has 362 | the form N(0, std). 363 | use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers) 364 | scaling for the output weights ( 365 | output of self attention and mlp). 
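        use_deepspeed_sparse: DeepSpeed sparse-attention config passed down to
                              the attention layers, or None for dense attention.
        sparse_mode: which layers use sparse attention when a config is given:
                     'all', 'alternating' (every other layer stays dense) or
                     'top_bottom' (the top half of the layers stays dense).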
366 | """ 367 | 368 | def __init__(self, 369 | num_layers, 370 | hidden_size, 371 | num_attention_heads, 372 | attention_dropout_prob, 373 | output_dropout_prob, 374 | checkpoint_activations, 375 | checkpoint_num_layers=1, 376 | layernorm_epsilon=1.0e-5, 377 | init_method_std=0.02, 378 | use_scaled_init_for_output_weights=True, 379 | use_deepspeed_sparse=None, 380 | sparse_mode='all'): 381 | super(GPT3ParallelTransformer, self).__init__() 382 | 383 | if DEEPSPEED_WRAP: 384 | from deepspeed.ops.sparse_attention import SparseSelfAttention 385 | 386 | # Store activation checkpoiting flag. 387 | self.checkpoint_activations = checkpoint_activations 388 | self.checkpoint_num_layers = checkpoint_num_layers 389 | 390 | output_layer_init_method = None 391 | if use_scaled_init_for_output_weights: 392 | output_layer_init_method = scaled_init_method(init_method_std, 393 | num_layers) 394 | if use_deepspeed_sparse and sparse_mode == 'alternating': 395 | print('Use alternating sparse & dense attention layers') 396 | 397 | def get_layer(layer_num, num_layers): 398 | sparsity_config = use_deepspeed_sparse 399 | if use_deepspeed_sparse: 400 | if sparse_mode == 'alternating' and layer_num % 2: # even layers are dense 401 | sparsity_config = None 402 | elif sparse_mode == 'top_bottom' and layer_num >= num_layers // 2: # top levels are dense 403 | sparsity_config = None 404 | return GPT3ParallelTransformerLayer( 405 | hidden_size, 406 | num_attention_heads, 407 | attention_dropout_prob, 408 | output_dropout_prob, 409 | layernorm_epsilon, 410 | unscaled_init_method(init_method_std), 411 | output_layer_init_method=output_layer_init_method, 412 | use_deepspeed_sparse=sparsity_config) 413 | 414 | # Transformer layers. 415 | self.layers = torch.nn.ModuleList( 416 | [get_layer(i, num_layers) for i in range(num_layers)]) 417 | 418 | # Final layer norm before output. 419 | self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) 420 | 421 | if DEEPSPEED_WRAP: 422 | if DEEPSPEED_WRAP.deepspeed.checkpointing.is_configured(): 423 | global get_cuda_rng_tracker, checkpoint 424 | get_cuda_rng_tracker = DEEPSPEED_WRAP.deepspeed.checkpointing.get_cuda_rng_tracker 425 | checkpoint = DEEPSPEED_WRAP.deepspeed.checkpointing.checkpoint 426 | 427 | def forward(self, hidden_states, attention_mask): 428 | 429 | def custom(start, end): 430 | def custom_forward(*inputs): 431 | layers_ = self.layers[start:end] 432 | x_ = inputs[0] 433 | for layer in layers_: 434 | x_ = layer(x_, inputs[1]) 435 | return x_ 436 | 437 | return custom_forward 438 | 439 | if self.checkpoint_activations: 440 | l = 0 441 | num_layers = len(self.layers) 442 | chunk_length = self.checkpoint_num_layers 443 | while l < num_layers: 444 | hidden_states = checkpoint(custom(l, l + chunk_length), 445 | hidden_states, attention_mask) 446 | l += chunk_length 447 | else: 448 | for layer in self.layers: 449 | hidden_states = layer(hidden_states, attention_mask) 450 | 451 | # Final layer norm. 452 | output = self.final_layernorm(hidden_states) 453 | 454 | return output 455 | -------------------------------------------------------------------------------- /src/arguments.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """argparser configuration""" 17 | 18 | import argparse 19 | import os 20 | import torch 21 | from src.utils import DEEPSPEED_WRAP 22 | from modules.data.read import DataReader, add_data_reader_arguments 23 | 24 | 25 | def add_model_config_args(parser): 26 | """Model arguments""" 27 | 28 | group = parser.add_argument_group('model', 'model configuration') 29 | 30 | group.add_argument('--attention-dropout', type=float, default=0.1, 31 | help='dropout probability for attention weights') 32 | group.add_argument('--num-attention-heads', type=int, default=16, 33 | help='num of transformer attention heads') 34 | group.add_argument('--hidden-size', type=int, default=1024, 35 | help='tansformer hidden size') 36 | group.add_argument('--intermediate-size', type=int, default=None, 37 | help='transformer embedding dimension for FFN' 38 | 'set to 4*`--hidden-size` if it is None') 39 | group.add_argument('--num-layers', type=int, default=24, 40 | help='num decoder layers') 41 | group.add_argument('--layernorm-epsilon', type=float, default=1e-5, 42 | help='layer norm epsilon') 43 | group.add_argument('--hidden-dropout', type=float, default=0.1, 44 | help='dropout probability for hidden state transformer') 45 | group.add_argument('--max-position-embeddings', type=int, default=512, 46 | help='maximum number of position embeddings to use') 47 | group.add_argument('--vocab-size', type=int, default=30522, 48 | help='vocab size to use for non-character-level ' 49 | 'tokenization. This value will only be used when ' 50 | 'creating a tokenizer') 51 | group.add_argument('--deep-init', action='store_true', 52 | help='initialize bert model similar to gpt2 model.' 53 | 'scales initialization of projection layers by a ' 54 | 'factor of 1/sqrt(2N). Necessary to train bert ' 55 | 'models larger than BERT-Large.') 56 | group.add_argument('--make-vocab-size-divisible-by', type=int, default=8, 57 | help='Pad the vocab size to be divisible by this value.' 
58 | 'This is added for computational efficieny reasons.') 59 | group.add_argument('--cpu-optimizer', action='store_true', 60 | help='Run optimizer on CPU') 61 | group.add_argument('--cpu_torch_adam', action='store_true', 62 | help='Use Torch Adam as optimizer on CPU.') 63 | group.add_argument('--sparse-mode', type=str, default='all', 64 | choices=['all', 'alternating', 'top_bottom'], 65 | help='sparse layers arrangement in model') 66 | 67 | return parser 68 | 69 | 70 | def add_fp16_config_args(parser): 71 | """Mixed precision arguments.""" 72 | 73 | group = parser.add_argument_group('fp16', 'fp16 configurations') 74 | 75 | group.add_argument('--fp16', action='store_true', 76 | help='Run model in fp16 mode') 77 | group.add_argument('--fp32-embedding', action='store_true', 78 | help='embedding in fp32') 79 | group.add_argument('--fp32-layernorm', action='store_true', 80 | help='layer norm in fp32') 81 | group.add_argument('--fp32-tokentypes', action='store_true', 82 | help='embedding token types in fp32') 83 | group.add_argument('--fp32-allreduce', action='store_true', 84 | help='all-reduce in fp32') 85 | group.add_argument('--hysteresis', type=int, default=2, 86 | help='hysteresis for dynamic loss scaling') 87 | group.add_argument('--loss-scale', type=float, default=None, 88 | help='Static loss scaling, positive power of 2 ' 89 | 'values can improve fp16 convergence. If None, dynamic' 90 | 'loss scaling is used.') 91 | group.add_argument('--loss-scale-window', type=float, default=1000, 92 | help='Window over which to raise/lower dynamic scale') 93 | group.add_argument('--min-scale', type=float, default=1, 94 | help='Minimum loss scale for dynamic loss scale') 95 | 96 | return parser 97 | 98 | 99 | def add_training_args(parser): 100 | """Training arguments.""" 101 | 102 | group = parser.add_argument_group('train', 'training configurations') 103 | 104 | group.add_argument('--batch-size', type=int, default=4, 105 | help='Data Loader batch size') 106 | group.add_argument('--weight-decay', type=float, default=0.01, 107 | help='weight decay coefficient for L2 regularization') 108 | group.add_argument('--checkpoint-activations', action='store_true', 109 | help='checkpoint activation to allow for training ' 110 | 'with larger models and sequences') 111 | group.add_argument('--checkpoint-num-layers', type=int, default=1, 112 | help='chunk size (number of layers) for checkpointing') 113 | group.add_argument('--deepspeed-activation-checkpointing', action='store_true', 114 | help='uses activation checkpointing from deepspeed') 115 | group.add_argument('--clip-grad', type=float, default=1.0, 116 | help='gradient clipping') 117 | group.add_argument('--train-iters', type=int, default=1000000, 118 | help='total number of iterations to train over all training runs') 119 | group.add_argument('--log-interval', type=int, default=100, 120 | help='report interval') 121 | group.add_argument('--logging-dir', type=str, default=None, 122 | help='tensorboard log dir') 123 | group.add_argument('--exit-interval', type=int, default=None, 124 | help='Exit the program after this many new iterations.') 125 | 126 | group.add_argument('--seed', type=int, default=1234, 127 | help='random seed') 128 | # Batch prodecuer arguments 129 | group.add_argument('--reset-position-ids', action='store_true', 130 | help='Reset posistion ids after end-of-document token.') 131 | group.add_argument('--reset-attention-mask', action='store_true', 132 | help='Reset self attention maske after ' 133 | 'end-of-document token.') 134 | 135 | # 
Learning rate. 136 | group.add_argument('--lr-decay-iters', type=int, default=None, 137 | help='number of iterations to decay LR over,' 138 | ' If None defaults to `--train-iters`*`--epochs`') 139 | group.add_argument('--lr-decay-style', type=str, default='linear', 140 | choices=['constant', 'linear', 'cosine', 'exponential'], 141 | help='learning rate decay function') 142 | group.add_argument('--lr', type=float, default=1.0e-4, 143 | help='initial learning rate') 144 | group.add_argument('--min-lr', type=float, default=1.0e-6, 145 | help='minimal learning rate') 146 | group.add_argument('--warmup', type=float, default=0.01, 147 | help='percentage of data to warmup on (.01 = 1% of all ' 148 | 'training iters). Default 0.01') 149 | # model checkpointing 150 | group.add_argument('--save', type=str, default=None, 151 | help='Output directory to save checkpoints to.') 152 | group.add_argument('--save-interval', type=int, default=5000, 153 | help='number of iterations between saves') 154 | group.add_argument('--no-save-optim', action='store_true', 155 | help='Do not save current optimizer.') 156 | group.add_argument('--no-save-rng', action='store_true', 157 | help='Do not save current rng state.') 158 | group.add_argument('--load', type=str, default=None, 159 | help='Path to a directory containing a model checkpoint.') 160 | group.add_argument('--no-load-optim', action='store_true', 161 | help='Do not load optimizer when loading checkpoint.') 162 | group.add_argument('--log-memory', action='store_true', 163 | help='Write memory consumption in tensorboard log') 164 | group.add_argument('--no-load-rng', action='store_true', 165 | help='Do not load rng state when loading checkpoint.') 166 | group.add_argument('--load-huggingface', type=str, default=None, 167 | help='Path to a directory containing a huggingface transformers model checkpoint.') 168 | group.add_argument('--export-huggingface', type=str, default=None, 169 | help='Exported model to path in huggingface format.') 170 | group.add_argument('--huggingface-double-pos-embeddings', action='store_true', 171 | help='Duplicate first half of pos embedding weights to last') 172 | group.add_argument('--load-tag', type=str, default='', 173 | help='checkpoint name to test') 174 | group.add_argument('--cache-prefix', type=str, default='_', 175 | help='cache folder prefix') 176 | group.add_argument('--finetune', action='store_true', 177 | help='Load model for finetuning. Do not load optimizer ' 178 | 'or rng state from checkpoint and set iteration to 0. ' 179 | 'Assumed when loading a release checkpoint.') 180 | group.add_argument('--resume-dataloader', action='store_true', 181 | help='Resume the dataloader when resuming training. ' 182 | 'Does not apply to tfrecords dataloader, try resuming' 183 | 'with a different seed in this case.') 184 | # distributed training args 185 | group.add_argument('--distributed-backend', default='nccl', 186 | help='which backend to use for distributed ' 187 | 'training. 
One of [gloo, nccl]') 188 | 189 | group.add_argument('--local_rank', type=int, default=None, 190 | help='local rank passed from distributed launcher') 191 | 192 | group.add_argument('--master_port', type=int, default=6000, 193 | help='master port for test parallel prediction') 194 | 195 | return parser 196 | 197 | 198 | def add_evaluation_args(parser): 199 | """Evaluation arguments.""" 200 | 201 | group = parser.add_argument_group('validation', 'validation configurations') 202 | 203 | group.add_argument('--eval-batch-size', type=int, default=None, 204 | help='Data Loader batch size for evaluation datasets.' 205 | 'Defaults to `--batch-size`') 206 | group.add_argument('--eval-iters', type=int, default=100, 207 | help='number of iterations to run for evaluation' 208 | 'validation/test for') 209 | group.add_argument('--eval-interval', type=int, default=1000, 210 | help='interval between running evaluation on validation set') 211 | group.add_argument('--eval-seq-length', type=int, default=None, 212 | help='Maximum sequence length to process for ' 213 | 'evaluation. Defaults to `--seq-length`') 214 | group.add_argument('--eval-max-preds-per-seq', type=int, default=None, 215 | help='Maximum number of predictions to use for ' 216 | 'evaluation. Defaults to ' 217 | 'math.ceil(`--eval-seq-length`*.15/10)*10') 218 | group.add_argument('--overlapping-eval', type=int, default=32, 219 | help='sliding window for overlapping eval ') 220 | group.add_argument('--cloze-eval', action='store_true', 221 | help='Evaluation dataset from `--valid-data` is a cloze task') 222 | group.add_argument('--eval-hf', action='store_true', 223 | help='perform evaluation with huggingface openai model.' 224 | 'use `--load` to specify weights path to be loaded') 225 | group.add_argument('--load-openai', action='store_true', 226 | help='load openai weights into our model. Use `--load` ' 227 | 'to specify weights path to be loaded') 228 | 229 | return parser 230 | 231 | 232 | def add_text_generate_args(parser): 233 | """Text generate arguments.""" 234 | 235 | group = parser.add_argument_group('Text generation', 'configurations') 236 | group.add_argument("--temperature", type=float, default=1.0) 237 | group.add_argument("--top_p", type=float, default=0.0) 238 | group.add_argument("--top_k", type=int, default=0) 239 | group.add_argument("--out-seq-length", type=int, default=256) 240 | group.add_argument("--tg-token-name", type=str, default='token.txt') 241 | return parser 242 | 243 | 244 | def add_data_args(parser): 245 | """Train/valid/test data arguments.""" 246 | 247 | group = parser.add_argument_group('data', 'data configurations') 248 | 249 | group.add_argument('--model-parallel-size', type=int, default=1, 250 | help='size of the model parallel.') 251 | group.add_argument('--shuffle', action='store_true', 252 | help='Shuffle data. Shuffling is deterministic ' 253 | 'based on seed and current epoch.') 254 | group.add_argument('--train-data', nargs='+', default=None, 255 | help='Whitespace separated filenames or corpora names ' 256 | 'for training.') 257 | 258 | group.add_argument('--use-npy-data-loader', action='store_true', 259 | help='Use the numpy data loader. 
If set, then' 260 | 'train-data-path, val-data-path, and test-data-path' 261 | 'should also be provided.') 262 | group.add_argument('--train-data-path', type=str, default='', 263 | help='path to the training data') 264 | group.add_argument('--val-data-path', type=str, default='', 265 | help='path to the validation data') 266 | group.add_argument('--test-data-path', type=str, default='', 267 | help='path to the test data') 268 | group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt', 269 | help='the filename containing all the shards sizes') 270 | 271 | group.add_argument('--delim', default=',', 272 | help='delimiter used to parse csv data files') 273 | group.add_argument('--text-key', default='sentence', 274 | help='key to use to extract text from json/csv') 275 | group.add_argument('--eval-text-key', default=None, 276 | help='key to use to extract text from ' 277 | 'json/csv evaluation datasets') 278 | group.add_argument('--valid-data', nargs='*', default=None, 279 | help="""Filename for validation data.""") 280 | group.add_argument('--split', default='1000,1,1', 281 | help='comma-separated list of proportions for training,' 282 | ' validation, and test split') 283 | group.add_argument('--test-data', nargs='*', default=None, 284 | help="""Filename for testing""") 285 | group.add_argument('--overwrite-cache', action='store_true', 286 | help='overwrite dataset cache') 287 | group.add_argument('--lazy-loader', action='store_true', 288 | help='whether to lazy read the data set') 289 | group.add_argument('--loose-json', action='store_true', 290 | help='Use loose json (one json-formatted string per ' 291 | 'newline), instead of tight json (data file is one ' 292 | 'json string)') 293 | group.add_argument('--presplit-sentences', action='store_true', 294 | help='Dataset content consists of documents where ' 295 | 'each document consists of newline separated sentences') 296 | group.add_argument('--num-workers', type=int, default=2, 297 | help="""Number of workers to use for dataloading""") 298 | group.add_argument('--tokenizer-path', type=str, default=None, 299 | help='path used to save/load sentencepiece tokenization ' 300 | 'models') 301 | group.add_argument("--cache-dir", default=None, type=str, 302 | help="Where to store pre-trained BERT downloads") 303 | group.add_argument('--use-tfrecords', action='store_true', 304 | help='load `--train-data`, `--valid-data`, ' 305 | '`--test-data` from BERT tf records instead of ' 306 | 'normal data pipeline') 307 | group.add_argument('--seq-length', type=int, default=512, 308 | help="Maximum sequence length to process") 309 | group.add_argument('--max-files-per-process', type=int, default=50000, 310 | help="Maximum files to load per process") 311 | group.add_argument('--max-preds-per-seq', type=int, default=None, 312 | help='Maximum number of predictions to use per sequence.' 313 | 'Defaults to math.ceil(`--seq-length`*.15/10)*10.' 
314 | 'MUST BE SPECIFIED IF `--use-tfrecords` is True.') 315 | 316 | return parser 317 | 318 | 319 | def get_args(cmd=None): 320 | """Parse all the args.""" 321 | 322 | parser = argparse.ArgumentParser(description='PyTorch BERT Model') 323 | parser = add_model_config_args(parser) 324 | parser = add_fp16_config_args(parser) 325 | parser = add_training_args(parser) 326 | parser = add_evaluation_args(parser) 327 | parser = add_text_generate_args(parser) 328 | parser = add_data_args(parser) 329 | parser = add_data_reader_arguments(parser) 330 | 331 | # Include DeepSpeed configuration arguments 332 | if DEEPSPEED_WRAP: 333 | parser = DEEPSPEED_WRAP.deepspeed.add_config_arguments(parser) 334 | if cmd is None: 335 | args = parser.parse_args() 336 | else: 337 | args = parser.parse_args(cmd) 338 | 339 | if not args.train_data and not args.train_data_path: 340 | print('WARNING: No training data specified') 341 | 342 | args.cuda = torch.cuda.is_available() 343 | 344 | args.rank = int(os.getenv('RANK', '0')) 345 | args.world_size = int(os.getenv("WORLD_SIZE", '1')) 346 | 347 | if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): 348 | # We are using (OpenMPI) mpirun for launching distributed data parallel processes 349 | local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) 350 | local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) 351 | 352 | # Possibly running with Slurm 353 | num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) 354 | nodeid = int(os.getenv('SLURM_NODEID', '0')) 355 | 356 | args.local_rank = local_rank 357 | args.rank = nodeid * local_size + local_rank 358 | args.world_size = num_nodes * local_size 359 | 360 | args.model_parallel_size = min(args.model_parallel_size, args.world_size) 361 | if args.rank == 0: 362 | print('using world size: {} and model-parallel size: {} '.format( 363 | args.world_size, args.model_parallel_size)) 364 | 365 | args.dynamic_loss_scale = False 366 | if args.loss_scale is None: 367 | args.dynamic_loss_scale = True 368 | if args.rank == 0: 369 | print(' > using dynamic loss scaling') 370 | 371 | # The args fp32_* or fp16_* meant to be active when the 372 | # args fp16 is set. So the default behaviour should all 373 | # be false. 374 | if not args.fp16: 375 | args.fp32_embedding = False 376 | args.fp32_tokentypes = False 377 | args.fp32_layernorm = False 378 | 379 | return args 380 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2020, Sber. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Utilities for logging and serialization""" 17 | 18 | import os 19 | import random 20 | import time 21 | 22 | import numpy as np 23 | import torch 24 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 25 | 26 | from src import mpu 27 | from src.fp16 import FP16_Optimizer, FP16_Module 28 | from src.model import DistributedDataParallel as DDP 29 | import os 30 | from src.download_utils import download_model_files 31 | 32 | 33 | class DeepSpeedImportWrap(object): 34 | def __init__(self): 35 | self.use_ds = os.environ.get("USE_DEEPSPEED", False) 36 | self.deepspeed = None 37 | if self.use_ds: 38 | import deepspeed 39 | self.deepspeed = deepspeed 40 | 41 | def __bool__(self): 42 | return bool(self.use_ds) 43 | 44 | 45 | DEEPSPEED_WRAP = DeepSpeedImportWrap() 46 | 47 | 48 | def print_rank_0(message): 49 | if torch.distributed.is_initialized(): 50 | if torch.distributed.get_rank() == 0: 51 | print(message, flush=True) 52 | else: 53 | print(message, flush=True) 54 | 55 | 56 | def print_args(args): 57 | """Print arguments.""" 58 | 59 | print('arguments:', flush=True) 60 | for arg in vars(args): 61 | dots = '.' * (29 - len(arg)) 62 | print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) 63 | 64 | 65 | def print_params_min_max_norm(optimizer, iteration): 66 | """Print min, max, and norm of all parameters.""" 67 | index = 0 68 | rank = torch.distributed.get_rank() 69 | string = 'iteration, rank, index, model-parallel,min, max, norm\n' 70 | optimizer_ = optimizer 71 | if isinstance(optimizer, FP16_Optimizer): 72 | optimizer_ = optimizer.optimizer 73 | for param_group in optimizer_.param_groups: 74 | for param in param_group['params']: 75 | index += 1 76 | min_ = param.data.min() 77 | max_ = param.data.max() 78 | norm = param.data.norm() 79 | string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( 80 | iteration, rank, index, int(param.model_parallel)) 81 | string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) 82 | print(string, flush=True) 83 | 84 | 85 | class Timers: 86 | """Group of timers.""" 87 | 88 | class Timer: 89 | """Timer.""" 90 | 91 | def __init__(self, name): 92 | self.name_ = name 93 | self.elapsed_ = 0.0 94 | self.started_ = False 95 | self.start_time = time.time() 96 | 97 | def start(self): 98 | """Start the timer.""" 99 | assert not self.started_, 'timer has already been started' 100 | torch.cuda.synchronize() 101 | self.start_time = time.time() 102 | self.started_ = True 103 | 104 | def stop(self): 105 | """Stop the timer.""" 106 | assert self.started_, 'timer is not started' 107 | torch.cuda.synchronize() 108 | self.elapsed_ += (time.time() - self.start_time) 109 | self.started_ = False 110 | 111 | def reset(self): 112 | """Reset timer.""" 113 | self.elapsed_ = 0.0 114 | self.started_ = False 115 | 116 | def elapsed(self, reset=True): 117 | """Calculate the elapsed time.""" 118 | started_ = self.started_ 119 | # If the timing in progress, end it first. 120 | if self.started_: 121 | self.stop() 122 | # Get the elapsed time. 123 | elapsed_ = self.elapsed_ 124 | # Reset the elapsed time 125 | if reset: 126 | self.reset() 127 | # If timing was in progress, set it back. 
128 | if started_: 129 | self.start() 130 | return elapsed_ 131 | 132 | def __init__(self): 133 | self.timers = {} 134 | 135 | def __call__(self, name): 136 | if name not in self.timers: 137 | self.timers[name] = self.Timer(name) 138 | return self.timers[name] 139 | 140 | def log(self, names, normalizer=1.0, reset=True): 141 | """Log a group of timers.""" 142 | assert normalizer > 0.0 143 | string = 'time (ms)' 144 | for name in names: 145 | elapsed_time = self.timers[name].elapsed( 146 | reset=reset) * 1000.0 / normalizer 147 | string += ' | {}: {:.2f}'.format(name, elapsed_time) 148 | print_rank_0(string) 149 | 150 | 151 | def report_memory(name): 152 | """Simple GPU memory report.""" 153 | 154 | mega_bytes = 1024.0 * 1024.0 155 | string = name + ' memory (MB)' 156 | string += ' | allocated: {}'.format( 157 | torch.cuda.memory_allocated() / mega_bytes) 158 | string += ' | max allocated: {}'.format( 159 | torch.cuda.max_memory_allocated() / mega_bytes) 160 | string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) 161 | string += ' | max cached: {}'.format( 162 | torch.cuda.max_memory_cached() / mega_bytes) 163 | print_rank_0(string) 164 | 165 | 166 | def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False): 167 | if release: 168 | d = 'release' 169 | else: 170 | d = 'iter_{:07d}'.format(iteration) 171 | if zero: 172 | dp_rank = mpu.get_data_parallel_rank() 173 | d += '_zero_dp_rank_{}'.format(dp_rank) 174 | return os.path.join(checkpoints_path, d, 175 | 'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()), 176 | 'model_optim_rng.pt') 177 | 178 | 179 | def ensure_directory_exists(filename): 180 | dirname = os.path.dirname(filename) 181 | if not os.path.exists(dirname): 182 | os.makedirs(dirname) 183 | 184 | 185 | def get_checkpoint_tracker_filename(checkpoints_path): 186 | return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') 187 | 188 | 189 | def save_zero_checkpoint(args, iteration, optimizer): 190 | zero_sd = {'iteration': iteration, 191 | 'optimizer_state_dict': optimizer.state_dict()} 192 | zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) 193 | ensure_directory_exists(zero_checkpoint_name) 194 | torch.save(zero_sd, zero_checkpoint_name) 195 | print(' successfully saved {}'.format(zero_checkpoint_name)) 196 | 197 | 198 | def save_checkpoint(iteration, model, optimizer, 199 | lr_scheduler, args, deepspeed=False): 200 | """Save a model checkpoint.""" 201 | if deepspeed: 202 | save_ds_checkpoint(iteration, model, args) 203 | else: 204 | # Only rank zer0 of the data parallel writes to the disk. 205 | if isinstance(model, torchDDP): 206 | model = model.module 207 | 208 | if mpu.get_data_parallel_rank() == 0: 209 | checkpoint_name = get_checkpoint_name(args.save, iteration) 210 | print('global rank {} is saving checkpoint at iteration {:7d} to {}'. 211 | format(torch.distributed.get_rank(), iteration, checkpoint_name)) 212 | 213 | sd = {} 214 | sd['iteration'] = iteration 215 | sd['model'] = model.state_dict() 216 | 217 | # Optimizer stuff. 218 | if not args.no_save_optim: 219 | if optimizer is not None: 220 | sd['optimizer'] = optimizer.state_dict() 221 | if lr_scheduler is not None: 222 | sd['lr_scheduler'] = lr_scheduler.state_dict() 223 | 224 | # rng states. 
225 | if not args.no_save_rng: 226 | sd['random_rng_state'] = random.getstate() 227 | sd['np_rng_state'] = np.random.get_state() 228 | sd['torch_rng_state'] = torch.get_rng_state() 229 | sd['cuda_rng_state'] = torch.cuda.get_rng_state() 230 | sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() 231 | 232 | ensure_directory_exists(checkpoint_name) 233 | torch.save(sd, checkpoint_name) 234 | print(' successfully saved {}'.format(checkpoint_name)) 235 | 236 | # Wait so everyone is done (necessary) 237 | torch.distributed.barrier() 238 | # And update the latest iteration 239 | if torch.distributed.get_rank() == 0: 240 | tracker_filename = get_checkpoint_tracker_filename(args.save) 241 | with open(tracker_filename, 'w') as f: 242 | f.write(str(iteration)) 243 | # Wait so everyone is done (not necessary) 244 | torch.distributed.barrier() 245 | 246 | 247 | def save_ds_checkpoint(iteration, model, args): 248 | """Save a model checkpoint.""" 249 | 250 | sd = {} 251 | sd['iteration'] = iteration 252 | # rng states. 253 | if not args.no_save_rng: 254 | sd['random_rng_state'] = random.getstate() 255 | sd['np_rng_state'] = np.random.get_state() 256 | sd['torch_rng_state'] = torch.get_rng_state() 257 | sd['cuda_rng_state'] = torch.cuda.get_rng_state() 258 | sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() 259 | 260 | model.save_checkpoint(args.save, str(iteration), client_state=sd) 261 | 262 | 263 | def get_checkpoints(load_dir): 264 | return [f for f in os.listdir(load_dir) if '.' not in f] 265 | 266 | 267 | def get_last_checkpoints(load_dir, n=1): 268 | checkpoints = get_checkpoints(load_dir) 269 | checkpoints = [(int(c), c) for c in checkpoints] 270 | by_iteration = sorted(checkpoints, key=lambda x: x[0], reverse=True) 271 | return [b[-1] for b in by_iteration[:n]] 272 | 273 | 274 | def get_outdated_checkpoints(load_dir, retain_last_n=5): 275 | last = get_last_checkpoints(load_dir, retain_last_n) 276 | return [d for d in get_checkpoints(load_dir) if d not in last] 277 | 278 | 279 | def get_checkpoint_iteration(args): 280 | # Read the tracker file and set the iteration. 281 | if args.load_tag: 282 | return args.load_tag, False, True 283 | tracker_filename = get_checkpoint_tracker_filename(args.load) 284 | if not os.path.isfile(tracker_filename): 285 | print_rank_0('WARNING: could not find the metadata file {} '.format( 286 | tracker_filename)) 287 | print_rank_0(' will not load any checkpoints and will start from ' 288 | 'random') 289 | return 0, False, False 290 | iteration = 0 291 | release = False 292 | with open(tracker_filename, 'r') as f: 293 | metastring = f.read().strip() 294 | try: 295 | iteration = int(metastring) 296 | except ValueError: 297 | release = metastring == 'release' 298 | if not release: 299 | print_rank_0('ERROR: Invalid metadata file {}. 
Exiting'.format(
300 |                     tracker_filename))
301 |                 exit()
302 |
303 |     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
304 |         tracker_filename)
305 |
306 |     return iteration, release, True
307 |
308 |
309 | def load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=False):
310 |     """Load a model checkpoint."""
311 |
312 |     iteration, release, success = get_checkpoint_iteration(args)
313 |
314 |     if not success:
315 |         return 0
316 |
317 |     if deepspeed:
318 |         load_optim = not args.no_load_optim
319 |         checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=load_optim,
320 |                                                     load_lr_scheduler_states=load_optim)
321 |
322 |         if checkpoint_name is None:
323 |             if mpu.get_data_parallel_rank() == 0:
324 |                 print("Unable to load checkpoint.")
325 |             return iteration
326 |
327 |     else:
328 |
329 |         # Checkpoint.
330 |         checkpoint_name = get_checkpoint_name(args.load, iteration, release)
331 |
332 |         if mpu.get_data_parallel_rank() == 0:
333 |             print('global rank {} is loading checkpoint {}'.format(
334 |                 torch.distributed.get_rank(), checkpoint_name))
335 |
336 |         # Load the checkpoint.
337 |         sd = torch.load(checkpoint_name, map_location='cpu')
338 |
339 |         if isinstance(model, torchDDP):
340 |             model = model.module
341 |
342 |         # Model.
343 |         try:
344 |             model.load_state_dict(sd['model'])
345 |         except KeyError:
346 |             print_rank_0('A metadata file exists but unable to load model '
347 |                          'from checkpoint {}, exiting'.format(checkpoint_name))
348 |             exit()
349 |
350 |         # Optimizer.
351 |         if not release and not args.finetune and not args.no_load_optim:
352 |             try:
353 |                 if optimizer is not None:
354 |                     optimizer.load_state_dict(sd['optimizer'])
355 |                 if lr_scheduler is not None:
356 |                     lr_scheduler.load_state_dict(sd['lr_scheduler'])
357 |             except KeyError:
358 |                 print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
359 |                              'Specify --no-load-optim or --finetune to prevent '
360 |                              'attempting to load the optimizer '
361 |                              'state.'.format(checkpoint_name))
362 |                 exit()
363 |
364 |     # Iterations.
365 |     if args.finetune or release:
366 |         iteration = 0
367 |     else:
368 |         try:
369 |             iteration = sd['iteration']
370 |         except KeyError:
371 |             try:  # Backward compatible with older checkpoints
372 |                 iteration = sd['total_iters']
373 |             except KeyError:
374 |                 print_rank_0('A metadata file exists but unable to load iteration '
375 |                              'from checkpoint {}, exiting'.format(checkpoint_name))
376 |                 exit()
377 |
378 |     # rng states.
379 |     if not release and not args.finetune and not args.no_load_rng:
380 |         try:
381 |             random.setstate(sd['random_rng_state'])
382 |             np.random.set_state(sd['np_rng_state'])
383 |             torch.set_rng_state(sd['torch_rng_state'])
384 |             torch.cuda.set_rng_state(sd['cuda_rng_state'])
385 |             mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
386 |         except KeyError:
387 |             print_rank_0('Unable to load rng state from checkpoint {}, exiting. '
388 |                          'Specify --no-load-rng or --finetune to prevent '
389 |                          'attempting to load the rng '
390 |                          'state.'.format(checkpoint_name))
391 |             exit()
392 |
393 |     torch.distributed.barrier()
394 |     if mpu.get_data_parallel_rank() == 0:
395 |         print(' successfully loaded {}'.format(checkpoint_name))
396 |
397 |     return iteration
398 |
399 |
400 | def load_weights(src, dst, dst2src=False, double_pos_embeddings=False):
401 |     """
402 |     Loads weights from src to dst via in place copy.
403 |     src is a huggingface gpt3model, while dst is one of our models. 
404 | dst2src=True loads parameters from our models into huggingface's. 405 | ^dst2src is still untested 406 | """ 407 | conv_layer = 'Conv1D' in str(type(src)) 408 | for n, p in src.named_parameters(): 409 | if dst2src: 410 | data = dst._parameters[n].data 411 | load = p.data 412 | else: 413 | if double_pos_embeddings: 414 | print('Double pos embeddings') 415 | mid = p.size(0) // 2 416 | p[mid:, :] = p[:mid, :] # copy first half of position embedings to last 417 | data = p.data 418 | load = dst._parameters[n].data 419 | if conv_layer and 'weight' in n: 420 | data = data.t().contiguous() 421 | load.copy_(data) 422 | 423 | 424 | # dst._parameters[n].data.copy_(data) 425 | 426 | def load_mlp(our, oai, dst2src=False): 427 | load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) 428 | load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) 429 | 430 | 431 | def load_attention(our, oai, dst2src=False): 432 | load_weights(oai.c_attn, our.query_key_value, dst2src) 433 | load_weights(oai.c_proj, our.dense, dst2src) 434 | 435 | 436 | def load_transformer_layer(our, oai, dst2src=False): 437 | load_weights(oai.ln_1, our.input_layernorm, dst2src) 438 | load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) 439 | load_mlp(our.mlp, oai.mlp, dst2src) 440 | load_attention(our.attention, oai.attn, dst2src) 441 | 442 | 443 | def move_weights(our, oai, dst2src=False, double_pos_embeddings=False): 444 | """ 445 | Loads weights from `oai` to `our` via in place copy. 446 | `oai` is a huggingface gpt3model, while `our` is one of our models. 447 | dst2src=True loads parameters from our models into huggingface's. 448 | ^dst2src=True is still untested 449 | """ 450 | # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): 451 | # our=our.module 452 | transformer_model = oai.transformer 453 | load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src) 454 | load_weights(transformer_model.wte, our.word_embeddings, dst2src) 455 | load_weights(transformer_model.wpe, our.position_embeddings, dst2src, double_pos_embeddings=double_pos_embeddings) 456 | 457 | for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): 458 | load_transformer_layer(our_layer, oai_layer, dst2src) 459 | 460 | 461 | def load_huggingface_model(model, path, double_pos_embeddings): 462 | from transformers import GPT2LMHeadModel 463 | print('Load huggingface model from', path, ('with pos emb doubling' if double_pos_embeddings else '')) 464 | model2fill = model 465 | while isinstance(model2fill, (torchDDP, FP16_Module)): 466 | model2fill = model2fill.module 467 | 468 | if path == "sberbank-ai/rugpt3xl": 469 | weights_path, _ = download_model_files("sberbank-ai/rugpt3xl") 470 | checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)['module'] 471 | model2fill.load_state_dict(checkpoint) 472 | else: 473 | h_model = GPT2LMHeadModel.from_pretrained(path) 474 | move_weights(model2fill, h_model, double_pos_embeddings) 475 | 476 | print('Loaded huggingface model', type(model)) 477 | return model 478 | 479 | 480 | def export_to_huggingface_model(model, path): 481 | from transformers import GPT2LMHeadModel, GPT2Config 482 | model_from = model 483 | while isinstance(model_from, (DDP, torchDDP, FP16_Module)): 484 | model_from = model_from.module 485 | conf_dict = model_from._conf_dict 486 | print('Export to huggingface model ', path, 'with config', conf_dict) 487 | config = GPT2Config(**conf_dict) 488 | hf_model = GPT2LMHeadModel(config=config) 489 | model_to = 
hf_model 490 | while isinstance(model_to, (DDP, torchDDP, FP16_Module)): 491 | model_to = model_to.module 492 | move_weights(model_from, model_to, dst2src=True) 493 | hf_model.save_pretrained(path) 494 | print('Saved huggingface model', type(model)) 495 | 496 | 497 | def get_deepspeed_config(args): 498 | if hasattr(args, 'deepspeed_config') and args.deepspeed_config: 499 | from deepspeed import DeepSpeedConfig 500 | return DeepSpeedConfig(args.deepspeed_config) 501 | else: 502 | raise RuntimeError('deepspeed_config is not found in args.') 503 | 504 | 505 | def get_sparse_attention_config(args, num_heads): 506 | ds_config = get_deepspeed_config(args) 507 | if hasattr(ds_config, 'sparse_attention') and ds_config.sparse_attention: 508 | sa_config = ds_config.sparse_attention 509 | sa_mode = sa_config.get('mode') 510 | if (sa_mode == 'dense'): 511 | from deepspeed.ops.sparse_attention import DenseSparsityConfig as STConfig 512 | elif (sa_mode == 'fixed'): 513 | from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig 514 | elif (sa_mode == 'bigbird'): 515 | from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as STConfig 516 | elif (sa_mode == 'bslongformer'): 517 | from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as STConfig 518 | elif (sa_mode == 'variable'): 519 | from deepspeed.ops.sparse_attention import VariableSparsityConfig as STConfig 520 | else: 521 | raise NotImplementedError( 522 | f'Given sparsity mode, {sa_mode}, has not been implemented yet!' 523 | ) 524 | del sa_config['mode'] 525 | return STConfig(num_heads=num_heads, **sa_config) 526 | else: 527 | return None 528 | 529 | 530 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 531 | # This function has been mostly taken from huggingface conversational ai code at 532 | # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 533 | 534 | if top_k > 0: 535 | # Remove all tokens with a probability less than the last token of the top-k 536 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 537 | logits[indices_to_remove] = filter_value 538 | 539 | if top_p > 0.0: 540 | # convert to 1D 541 | logits = logits.view(logits.size()[1]).contiguous() 542 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 543 | cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) 544 | 545 | # Remove tokens with cumulative probability above the threshold 546 | sorted_indices_to_remove = cumulative_probs > top_p 547 | # Shift the indices to the right to keep also the first token above the threshold 548 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 549 | sorted_indices_to_remove[..., 0] = 0 550 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 551 | logits[indices_to_remove] = filter_value 552 | # going back to 2D 553 | logits = logits.view(1, -1).contiguous() 554 | 555 | return logits 556 | --------------------------------------------------------------------------------
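Usage sketch (illustrative, not a file in this repository): the --temperature, --top_p and --top_k arguments registered by add_text_generate_args in src/arguments.py are normally fed into the top_k_logits filter defined at the end of src/utils.py. A minimal sampling step built on that filter could look like the following; the model, context and sample_next_token names are hypothetical placeholders, and the only assumption is an autoregressive LM that maps a [1, seq_len] tensor of token ids to logits of shape [1, seq_len, vocab_size].

# Illustrative sketch only -- `model`, `context` and `sample_next_token` are hypothetical.
import torch
import torch.nn.functional as F

from src.utils import top_k_logits


def sample_next_token(model, context, temperature=1.0, top_k=0, top_p=0.0):
    """Draw one token id from the (optionally top-k / nucleus filtered) next-token distribution."""
    with torch.no_grad():
        logits = model(context)                   # assumed shape: [1, seq_len, vocab_size]
        logits = logits[:, -1, :] / temperature   # keep only the last position: [1, vocab_size]
        logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
        probs = F.softmax(logits, dim=-1)         # filtered positions carry probability 0
        return torch.multinomial(probs, num_samples=1)   # shape [1, 1]


# e.g. next_id = sample_next_token(model, ids, temperature=args.temperature,
#                                  top_k=args.top_k, top_p=args.top_p)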