├── src
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── utils.cpython-37.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── xl_wrapper.cpython-37.pyc
│ │ └── download_utils.cpython-37.pyc
│ ├── fp16
│ │ ├── __pycache__
│ │ │ ├── fp16.cpython-37.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── fp16util.cpython-37.pyc
│ │ │ └── loss_scaler.cpython-37.pyc
│ │ ├── __init__.py
│ │ ├── fp16util.py
│ │ └── loss_scaler.py
│ ├── mpu
│ │ ├── __pycache__
│ │ │ ├── data.cpython-37.pyc
│ │ │ ├── grads.cpython-37.pyc
│ │ │ ├── utils.cpython-37.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── layers.cpython-37.pyc
│ │ │ ├── mappings.cpython-37.pyc
│ │ │ ├── random.cpython-37.pyc
│ │ │ ├── initialize.cpython-37.pyc
│ │ │ ├── cross_entropy.cpython-37.pyc
│ │ │ └── transformer.cpython-37.pyc
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── grads.py
│ │ ├── data.py
│ │ ├── mappings.py
│ │ ├── cross_entropy.py
│ │ ├── initialize.py
│ │ ├── layers.py
│ │ ├── random.py
│ │ └── transformer.py
│ ├── model
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── distributed.cpython-37.pyc
│ │ │ └── gpt3_modeling.cpython-37.pyc
│ │ ├── __init__.py
│ │ ├── distributed.py
│ │ └── gpt3_modeling.py
│ ├── deepspeed_config
│ │ ├── gpt3_large_2048.json
│ │ ├── gpt3_small.json
│ │ ├── gpt3_small_2048.json
│ │ ├── gpt3_medium_2048.json
│ │ ├── gpt3_large.json
│ │ ├── gpt2_small.json
│ │ ├── gpt3_small_sparse.json
│ │ ├── gpt3_small_sparse_2048.json
│ │ ├── gpt3_xl_sparse.json
│ │ └── gpt3_xl_sparse_2048.json
│ ├── data_utils
│ │ ├── __init__.py
│ │ ├── lazy_loader.py
│ │ └── file_utils.py
│ ├── download_utils.py
│ ├── learning_rates.py
│ ├── gpt3_data_loader.py
│ ├── dataset_rugpt3.py
│ ├── xl_wrapper.py
│ ├── arguments.py
│ └── utils.py
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/fp16/__pycache__/fp16.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/fp16.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/data.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/grads.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/grads.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/xl_wrapper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/xl_wrapper.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/mappings.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/mappings.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/random.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/random.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/download_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/__pycache__/download_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/fp16/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/fp16/__pycache__/fp16util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/fp16util.cpython-37.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/initialize.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/initialize.cpython-37.pyc
--------------------------------------------------------------------------------
/src/fp16/__pycache__/loss_scaler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/fp16/__pycache__/loss_scaler.cpython-37.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/distributed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/model/__pycache__/distributed.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/cross_entropy.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/cross_entropy.cpython-37.pyc
--------------------------------------------------------------------------------
/src/mpu/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/mpu/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/gpt3_modeling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TatianaShavrina/rugpt3simplification_rsse/main/src/model/__pycache__/gpt3_modeling.cpython-37.pyc
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_large_2048.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 1,
3 | "fp16": {
4 | "enabled": false,
5 | "loss_scale": 0,
6 | "loss_scale_window": 2000,
7 | "min_loss_scale": 0.0
8 | },
9 | "zero_optimization": {
10 | "stage": 0,
11 | "reduce_bucket_size": 50000000
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 8,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage":2,
15 | "reduce_bucket_size": 50000000
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_small_2048.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 2,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage": 2,
15 | "reduce_bucket_size": 50000000
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_medium_2048.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 4,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage": 2,
15 | "reduce_bucket_size": 50000000,
16 | "overlap_comm": true
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": 20,
3 | "gradient_accumulation_steps": 5,
4 | "steps_per_print": 100,
5 | "zero_optimization": {
6 | "stage": 0
7 | },
8 | "zero_allow_untested_optimizer": true,
9 | "gradient_clipping": 1.0,
10 | "fp16": {
11 | "enabled": true,
12 | "loss_scale": 0,
13 | "loss_scale_window": 1000,
14 | "hysteresis": 2,
15 | "min_loss_scale": 1
16 | },
17 | "activation_checkpointing": {
18 | "partition_activations": false,
19 | "contiguous_memory_optimization": false
20 | },
21 | "wall_clock_breakdown": false
22 | }
23 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt2_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": 168,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 1000,
5 | "zero_optimization": {
6 | "stage": 2
7 | },
8 | "zero_allow_untested_optimizer": true,
9 | "gradient_clipping": 1.0,
10 | "fp16": {
11 | "enabled": true,
12 | "loss_scale": 0,
13 | "loss_scale_window": 1000,
14 | "hysteresis": 2,
15 | "min_loss_scale": 1
16 | },
17 | "activation_checkpointing": {
18 | "partition_activations": false,
19 | "contiguous_memory_optimization": false
20 | },
21 | "wall_clock_breakdown": false
22 | }
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### ruGPT3 solution of RSSE
2 |
3 |
4 | Solution for [RuSimpleSentEval](https://github.com/dialogue-evaluation/RuSimpleSentEval), a shared task at the Dialogue 2021 conference.
5 |
6 | The solution is based on RuGPT-3XL. The model was fine-tuned on the train data.
7 |
8 | Our approach achieved second place on the public leaderboard and fifth place on the private leaderboard.
9 | It reaches a SARI score of about 37.
10 |
11 | Download the pre-trained XL model [here](https://disk.yandex.ru/d/dd7tM93w4g-14g) and put it in the `model` folder.
12 |
13 | A usage example is available [here](./Simplification%20with%20ruGPT3.ipynb); a minimal loading sketch is also given below.
14 |
15 | Read about all experiments in the article.
16 |
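17 | For reference, here is a minimal loading sketch. It is not part of the original README and assumes that the bundled `src/xl_wrapper.py` keeps the `RuGPT3XL` interface of the upstream ru-gpts wrapper; check the notebook above for the exact class and argument names.
18 |
19 | ```python
20 | # Sketch only: the class name, from_pretrained() arguments and generate() kwargs are
21 | # assumed to follow the upstream ru-gpts RuGPT3XL wrapper; verify against the notebook.
22 | from src.xl_wrapper import RuGPT3XL
23 |
24 | # Point the wrapper at the downloaded checkpoint placed in the `model` folder.
25 | gpt = RuGPT3XL.from_pretrained("model", seq_len=512)
26 |
27 | # The exact prompt format (separator tokens around the source sentence) is defined
28 | # in the training arguments / notebook; a plain prompt is used here as a placeholder.
29 | source = "Сложное предложение, которое нужно упростить."  # "A complex sentence to simplify."
30 | print(gpt.generate(source, max_length=64))
31 | ```
32 |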
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_small_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 16,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage": 2,
15 | "reduce_bucket_size": 50000000
16 | },
17 | "sparse_attention": {
18 | "mode": "fixed",
19 | "block": 16,
20 | "different_layout_per_head": true,
21 | "num_local_blocks": 8,
22 | "num_global_blocks": 1,
23 | "attention": "unidirectional",
24 | "horizontal_global_attention": false,
25 | "num_different_global_patterns": 8
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_small_sparse_2048.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 4,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage":1,
15 | "reduce_bucket_size": 500000000
16 | },
17 | "sparse_attention": {
18 | "mode": "fixed",
19 | "block": 16,
20 | "different_layout_per_head": true,
21 | "num_local_blocks": 8,
22 | "num_global_blocks": 1,
23 | "attention": "unidirectional",
24 | "horizontal_global_attention": false,
25 | "num_different_global_patterns": 8
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """utils for creating datasets"""
16 | from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
17 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_xl_sparse.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 8,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage": 2,
15 | "reduce_bucket_size": 50000000,
16 | "overlap_comm": true
17 | },
18 | "sparse_attention": {
19 | "mode": "fixed",
20 | "block": 16,
21 | "different_layout_per_head": true,
22 | "num_local_blocks": 8,
23 | "num_global_blocks": 1,
24 | "attention": "unidirectional",
25 | "horizontal_global_attention": false,
26 | "num_different_global_patterns": 8
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/src/deepspeed_config/gpt3_xl_sparse_2048.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 2,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 1.0,
6 | "fp16": {
7 | "enabled": true,
8 | "loss_scale": 0,
9 | "loss_scale_window": 2000,
10 | "hysteresis": 2,
11 | "min_loss_scale": 0.0
12 | },
13 | "zero_optimization": {
14 | "stage": 2,
15 | "reduce_bucket_size": 50000000,
16 | "overlap_comm": true
17 | },
18 | "sparse_attention": {
19 | "mode": "fixed",
20 | "block": 16,
21 | "different_layout_per_head": true,
22 | "num_local_blocks": 8,
23 | "num_global_blocks": 1,
24 | "attention": "unidirectional",
25 | "horizontal_global_attention": false,
26 | "num_different_global_patterns": 8
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .distributed import *
17 | from .gpt3_modeling import gpt3_get_params_for_weight_decay_optimization
18 | from .gpt3_modeling import GPT3Model
19 |
--------------------------------------------------------------------------------
/src/fp16/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from .fp16util import (
16 | BN_convert_float,
17 | network_to_half,
18 | prep_param_lists,
19 | model_grads_to_master_grads,
20 | master_params_to_model_params,
21 | tofp16,
22 | to_python_float,
23 | clip_grad_norm,
24 | convert_module,
25 | convert_network,
26 | FP16Model,
27 | )
28 |
29 | from .fp16 import *
30 | from .loss_scaler import *
31 |
--------------------------------------------------------------------------------
/src/mpu/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Model parallel utility interface."""
17 |
18 | from .cross_entropy import vocab_parallel_cross_entropy
19 |
20 | from .data import broadcast_data
21 |
22 | from .grads import clip_grad_norm
23 |
24 | from .initialize import destroy_model_parallel
25 | from .initialize import get_data_parallel_group
26 | from .initialize import get_data_parallel_rank
27 | from .initialize import get_data_parallel_world_size
28 | from .initialize import get_model_parallel_group
29 | from .initialize import get_model_parallel_rank
30 | from .initialize import get_model_parallel_src_rank
31 | from .initialize import get_model_parallel_world_size
32 | from .initialize import initialize_model_parallel
33 | from .initialize import model_parallel_is_initialized
34 |
35 | from .layers import ColumnParallelLinear
36 | from .layers import ParallelEmbedding
37 | from .layers import RowParallelLinear
38 | from .layers import VocabParallelEmbedding
39 |
40 | from .mappings import copy_to_model_parallel_region
41 | from .mappings import gather_from_model_parallel_region
42 | from .mappings import reduce_from_model_parallel_region
43 | from .mappings import scatter_to_model_parallel_region
44 |
45 | from .random import checkpoint
46 | from .random import partition_activations_in_checkpoint
47 | from .random import get_cuda_rng_tracker
48 | from .random import model_parallel_cuda_manual_seed
49 |
50 | from .transformer import GPT3ParallelTransformer
51 | from .transformer import LayerNorm
52 |
--------------------------------------------------------------------------------
/src/mpu/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 |
19 |
20 | def ensure_divisibility(numerator, denominator):
21 | """Ensure that numerator is divisible by the denominator."""
22 | assert numerator % denominator == 0, '{} is not divisible by {}'.format(
23 | numerator, denominator)
24 |
25 |
26 | def divide(numerator, denominator):
27 | """Ensure that numerator is divisible by the denominator and return
28 | the division value."""
29 | ensure_divisibility(numerator, denominator)
30 | return numerator // denominator
31 |
32 |
33 | def split_tensor_along_last_dim(tensor, num_partitions,
34 | contiguous_split_chunks=False):
35 | """Split a tensor along its last dimension.
36 | Arguments:
37 | tensor: input tensor.
38 | num_partitions: number of partitions to split the tensor
39 | contiguous_split_chunks: If True, make each chunk contiguous
40 | in memory.
41 | """
42 | # Get the size and dimension.
43 | last_dim = tensor.dim() - 1
44 | last_dim_size = divide(tensor.size()[last_dim], num_partitions)
45 | # Split.
46 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
47 | # Note: torch.split does not create contiguous tensors by default.
48 | if contiguous_split_chunks:
49 | return tuple(chunk.contiguous() for chunk in tensor_list)
50 |
51 | return tensor_list
52 |
53 |
54 | class VocabUtility:
55 |     """Split the vocabulary into `world_size` chunks and return the
56 |     first and last index of the vocabulary belonging to the `rank`
57 |     partition. Note that indices are in the range [first, last)."""
58 |
59 | @staticmethod
60 | def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
61 | rank, world_size):
62 | index_f = rank * per_partition_vocab_size
63 | index_l = index_f + per_partition_vocab_size
64 | return index_f, index_l
65 |
66 | @staticmethod
67 | def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
68 | per_partition_vocab_size = divide(global_vocab_size, world_size)
69 | return VocabUtility.vocab_range_from_per_partition_vocab_size(
70 | per_partition_vocab_size, rank, world_size)
71 |
--------------------------------------------------------------------------------
/src/download_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from transformers.file_utils import (
4 | cached_path,
5 | hf_bucket_url,
6 | is_remote_url,
7 | )
8 | from transformers.utils import logging
9 |
10 | logger = logging.get_logger(__name__)
11 | WEIGHTS_NAME = "mp_rank_00_model_states.pt"
12 | DEEPSPEED_CONFIG_NAME = "deepspeed_config.json"
13 |
14 |
15 | def download_model_files(pretrained_model_name_or_path):
16 | weights_path = download_file_from_hf(pretrained_model_name_or_path, WEIGHTS_NAME)
17 | deepspeed_config_path = download_file_from_hf(pretrained_model_name_or_path, DEEPSPEED_CONFIG_NAME)
18 | return weights_path, deepspeed_config_path
19 |
20 |
21 | def download_file_from_hf(pretrained_model_name_or_path: str, file_name: str) -> str:
22 | # Load model
23 | if pretrained_model_name_or_path is not None:
24 | if os.path.isdir(pretrained_model_name_or_path):
25 | if os.path.isfile(os.path.join(pretrained_model_name_or_path, file_name)):
26 | # Load from a PyTorch checkpoint
27 | archive_file = os.path.join(pretrained_model_name_or_path, file_name)
28 | else:
29 | raise EnvironmentError(
30 | "Error no file named {} found in directory {}".format(
31 | file_name,
32 | pretrained_model_name_or_path,
33 | )
34 | )
35 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
36 | archive_file = pretrained_model_name_or_path
37 | else:
38 | archive_file = hf_bucket_url(
39 | pretrained_model_name_or_path,
40 | filename=file_name,
41 | revision=None,
42 | mirror=None,
43 | )
44 |
45 | try:
46 | # Load from URL or cache if already cached
47 | resolved_archive_file = cached_path(
48 | archive_file,
49 | cache_dir=None,
50 | force_download=False,
51 | proxies=None,
52 | resume_download=False,
53 | local_files_only=False,
54 | )
55 | except EnvironmentError as err:
56 | logger.error(err)
57 | msg = (
58 | f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
59 |                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on "
60 |                 f"'https://huggingface.co/models'\n\n"
61 |                 f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a "
62 |                 f"file named {file_name}.\n\n"
63 | )
64 | raise EnvironmentError(msg)
65 |
66 | if resolved_archive_file == archive_file:
67 | logger.info("loading weights file {}".format(archive_file))
68 | else:
69 | logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
70 | else:
71 | resolved_archive_file = None
72 |
73 | return resolved_archive_file
74 |
--------------------------------------------------------------------------------
/src/mpu/grads.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | # Parts of the code here are adapted from PyTorch
18 | # repo: https://github.com/pytorch/pytorch
19 |
20 |
21 | import torch
22 | from torch._six import inf
23 |
24 | from .initialize import get_model_parallel_group
25 | from .initialize import get_model_parallel_rank
26 |
27 |
28 | def clip_grad_norm(parameters, max_norm, norm_type=2):
29 | """Clips gradient norm of an iterable of parameters.
30 |
31 | This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
32 | added functionality to handle model parallel parameters. Note that
33 | the gradients are modified in place.
34 |
35 | Arguments:
36 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
37 | single Tensor that will have gradients normalized
38 | max_norm (float or int): max norm of the gradients
39 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
40 | infinity norm.
41 |
42 | Returns:
43 | Total norm of the parameters (viewed as a single vector).
44 | """
45 | if isinstance(parameters, torch.Tensor):
46 | parameters = [parameters]
47 | parameters = list(filter(lambda p: p.grad is not None, parameters))
48 | max_norm = float(max_norm)
49 | norm_type = float(norm_type)
50 | if norm_type == inf:
51 | total_norm = max(p.grad.data.abs().max() for p in parameters)
52 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
53 | # Take max across all GPUs.
54 | torch.distributed.all_reduce(total_norm_cuda,
55 | op=torch.distributed.ReduceOp.MAX,
56 | group=get_model_parallel_group())
57 | total_norm = total_norm_cuda[0].item()
58 | else:
59 | total_norm = 0
60 | for p in parameters:
61 | if p.model_parallel or (get_model_parallel_rank() == 0):
62 | param_norm = p.grad.data.norm(norm_type)
63 | total_norm += param_norm.item() ** norm_type
64 | # Sum across all model parallel GPUs.
65 | total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
66 | torch.distributed.all_reduce(total_norm_cuda,
67 | op=torch.distributed.ReduceOp.SUM,
68 | group=get_model_parallel_group())
69 | total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
70 | clip_coef = max_norm / (total_norm + 1e-6)
71 | if clip_coef < 1:
72 | for p in parameters:
73 | p.grad.data.mul_(clip_coef)
74 | return total_norm
75 |
--------------------------------------------------------------------------------
/src/learning_rates.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Learning rate schedulers."""
16 |
17 | import math
18 |
19 | import torch
20 | from torch.optim.lr_scheduler import _LRScheduler
21 |
22 |
23 | class AnnealingLR(_LRScheduler):
24 | """Anneals the learning rate from start to zero along a cosine curve."""
25 |
26 | DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
27 |
28 | def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1, min_lr=1e-6):
29 | self.optimizer = optimizer
30 | self.start_lr = start_lr
31 | self.warmup_iter = warmup_iter
32 |         self.num_iters = last_iter + 1
33 |         self.end_iter = num_iters
34 |         self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None
35 |         self.min_lr = min_lr
36 |         self._last_lr = start_lr
37 |         self._min_reached = False
38 |         self.step(self.num_iters)  # call last: get_lr() reads the attributes set above
39 | if torch.distributed.get_rank() == 0:
40 | print('learning rate decaying', decay_style)
41 |
42 | def get_lr(self):
43 | # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
44 | if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
45 | new_lr = float(self.start_lr) * self.num_iters / self.warmup_iter
46 | else:
47 | if self.decay_style == self.DECAY_STYLES[0]:
48 | lr = self.start_lr * ((self.end_iter - (self.num_iters - self.warmup_iter)) / self.end_iter)
49 | new_lr = max(self.min_lr, lr)
50 | elif self.decay_style == self.DECAY_STYLES[1]:
51 | new_lr = self.start_lr / 2.0 * (
52 | math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1)
53 | if new_lr <= self.min_lr or self._min_reached or self.num_iters > self.end_iter:
54 | self._min_reached = True
55 | new_lr = self.min_lr
56 | elif self.decay_style == self.DECAY_STYLES[2]:
57 | # TODO: implement exponential decay
58 | new_lr = self.start_lr
59 | else:
60 | new_lr = self.start_lr
61 | self._last_lr = new_lr
62 | return new_lr
63 |
64 | def step(self, step_num=None):
65 | if step_num is None:
66 | step_num = self.num_iters + 1
67 | self.num_iters = step_num
68 | new_lr = self.get_lr()
69 | for group in self.optimizer.param_groups:
70 | group['lr'] = new_lr
71 |
72 | def state_dict(self):
73 | sd = {
74 | 'start_lr': self.start_lr,
75 | 'warmup_iter': self.warmup_iter,
76 | 'num_iters': self.num_iters,
77 | 'decay_style': self.decay_style,
78 | 'end_iter': self.end_iter
79 | }
80 | return sd
81 |
82 | def load_state_dict(self, sd):
83 | self.start_lr = sd['start_lr']
84 | self.warmup_iter = sd['warmup_iter']
85 | self.num_iters = sd['num_iters']
86 | self.end_iter = sd['end_iter']
87 | self.decay_style = sd['decay_style']
88 | self.step(self.num_iters)
89 |
--------------------------------------------------------------------------------
/src/mpu/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 | from .initialize import get_model_parallel_group
19 | from .initialize import get_model_parallel_rank
20 | from .initialize import get_model_parallel_src_rank
21 |
22 |
23 | _MAX_DATA_DIM = 4
24 |
25 |
26 | def _check_data_types(keys, data, target_dtype):
27 | """Check that all the keys have the same target data type."""
28 | for key in keys:
29 | assert data[key].dtype == target_dtype, '{} has data type {} which '\
30 | 'is different than {}'.format(key, data[key].dtype, target_dtype)
31 |
32 |
33 | def _build_key_size_numel_dictionaries(keys, data):
34 | """Build the size on rank 0 and broadcast."""
35 | max_dim = _MAX_DATA_DIM
36 | sizes = [0 for _ in range(max_dim) for _ in keys]
37 |
38 | # Pack the sizes on rank zero.
39 | if get_model_parallel_rank() == 0:
40 | offset = 0
41 | for key in keys:
42 | assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
43 | size = data[key].size()
44 | for i, s in enumerate(size):
45 | sizes[i + offset] = s
46 | offset += max_dim
47 |
48 | # Move to GPU and broadcast.
49 | sizes_cuda = torch.cuda.LongTensor(sizes)
50 | torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(),
51 | group=get_model_parallel_group())
52 |
53 | # Move back to cpu and unpack.
54 | sizes_cpu = sizes_cuda.cpu()
55 | key_size = {}
56 | key_numel = {}
57 | total_numel = 0
58 | offset = 0
59 | for key in keys:
60 | i = 0
61 | size = []
62 | numel = 1
63 | while sizes_cpu[offset + i] > 0:
64 | this_size = sizes_cpu[offset + i]
65 | size.append(this_size)
66 | numel *= this_size
67 | i += 1
68 | key_size[key] = size
69 | key_numel[key] = numel
70 | total_numel += numel
71 | offset += max_dim
72 |
73 | return key_size, key_numel, total_numel
74 |
75 |
76 | def broadcast_data(keys, data, datatype):
77 | """Broadcast data from rank zero of each model parallel group to the
78 | members of the same model parallel group.
79 |
80 | Arguments:
81 |         keys: list of keys in the data dictionary to be broadcast
82 | data: data dictionary of string keys and cpu tensor values.
83 | datatype: torch data type of all tensors in data associated
84 | with keys.
85 | """
86 | # Build (key, size) and (key, number of elements) dictionaries along
87 | # with the total number of elements on all ranks.
88 | key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
89 | data)
90 |
91 | # Pack on rank zero.
92 | if get_model_parallel_rank() == 0:
93 | # Check that all keys have the same data type.
94 | _check_data_types(keys, data, datatype)
95 | # Flatten the data associated with the keys
96 | flatten_data = torch.cat(
97 | [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
98 | else:
99 | flatten_data = torch.empty(total_numel,
100 | device=torch.cuda.current_device(),
101 | dtype=datatype)
102 |
103 |     # Broadcast.
104 | torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(),
105 | group=get_model_parallel_group())
106 |
107 | # Unpack
108 | output = {}
109 | offset = 0
110 | for key in keys:
111 | size = key_size[key]
112 | numel = key_numel[key]
113 | output[key] = flatten_data.narrow(0, offset, numel).view(size)
114 | offset += numel
115 |
116 | return output
117 |
--------------------------------------------------------------------------------
/src/mpu/mappings.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 | from .initialize import get_model_parallel_group
19 | from .utils import split_tensor_along_last_dim
20 |
21 |
22 | def _reduce(input_):
23 |     """All-reduce the input tensor across the model parallel group."""
24 | group = get_model_parallel_group()
25 |
26 | # Bypass the function if we are using only 1 GPU.
27 | if torch.distributed.get_world_size(group=group) == 1:
28 | return input_
29 |
30 | # All-reduce.
31 | torch.distributed.all_reduce(input_, group=group)
32 |
33 | return input_
34 |
35 |
36 | def _split(input_):
37 | """Split the tensor along its last dimension and keep the
38 | corresponding slice."""
39 | group = get_model_parallel_group()
40 |
41 | # Bypass the function if we are using only 1 GPU.
42 | if torch.distributed.get_world_size(group=group) == 1:
43 | return input_
44 |
45 | # Split along last dimension.
46 | world_size = torch.distributed.get_world_size(group=group)
47 | input_list = split_tensor_along_last_dim(input_, world_size)
48 |
49 | # Note: torch.split does not create contiguous tensors by default.
50 | rank = torch.distributed.get_rank(group=group)
51 | output = input_list[rank].contiguous()
52 |
53 | return output
54 |
55 |
56 | def _gather(input_):
57 |     """Gather tensors and concatenate along the last dimension."""
58 | group = get_model_parallel_group()
59 |
60 | # Bypass the function if we are using only 1 GPU.
61 | if torch.distributed.get_world_size(group=group) == 1:
62 | return input_
63 |
64 | # Size and dimension.
65 | last_dim = input_.dim() - 1
66 | rank = torch.distributed.get_rank(group=group)
67 | world_size = torch.distributed.get_world_size(group=group)
68 |
69 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
70 | tensor_list[rank] = input_
71 | torch.distributed.all_gather(tensor_list, input_, group=group)
72 |
73 | # Note: torch.cat already creates a contiguous tensor.
74 | output = torch.cat(tensor_list, dim=last_dim).contiguous()
75 |
76 | return output
77 |
78 |
79 | class _CopyToModelParallelRegion(torch.autograd.Function):
80 | """Pass the input to the model parallel region."""
81 |
82 | @staticmethod
83 | def forward(ctx, input_):
84 | return input_
85 |
86 | @staticmethod
87 | def backward(ctx, grad_output):
88 | return _reduce(grad_output)
89 |
90 |
91 | class _ReduceFromModelParallelRegion(torch.autograd.Function):
92 |     """All-reduce the input from the model parallel region."""
93 |
94 | @staticmethod
95 | def forward(ctx, input_):
96 | return _reduce(input_)
97 |
98 | @staticmethod
99 | def backward(ctx, grad_output):
100 | return grad_output
101 |
102 |
103 | class _ScatterToModelParallelRegion(torch.autograd.Function):
104 |     """Split the input and keep only the chunk corresponding to the rank."""
105 |
106 | @staticmethod
107 | def forward(ctx, input_):
108 | return _split(input_)
109 |
110 | @staticmethod
111 | def backward(ctx, grad_output):
112 | return _gather(grad_output)
113 |
114 |
115 | class _GatherFromModelParallelRegion(torch.autograd.Function):
116 |     """Gather the input from the model parallel region and concatenate."""
117 |
118 | @staticmethod
119 | def forward(ctx, input_):
120 | return _gather(input_)
121 |
122 | @staticmethod
123 | def backward(ctx, grad_output):
124 | return _split(grad_output)
125 |
126 |
127 | # -----------------
128 | # Helper functions.
129 | # -----------------
130 |
131 | def copy_to_model_parallel_region(input_):
132 | return _CopyToModelParallelRegion.apply(input_)
133 |
134 | def reduce_from_model_parallel_region(input_):
135 | return _ReduceFromModelParallelRegion.apply(input_)
136 |
137 | def scatter_to_model_parallel_region(input_):
138 | return _ScatterToModelParallelRegion.apply(input_)
139 |
140 | def gather_from_model_parallel_region(input_):
141 | return _GatherFromModelParallelRegion.apply(input_)
142 |
--------------------------------------------------------------------------------
/src/mpu/cross_entropy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 |
19 | from .initialize import get_model_parallel_group
20 | from .initialize import get_model_parallel_rank
21 | from .initialize import get_model_parallel_world_size
22 | from .utils import VocabUtility
23 |
24 |
25 | class _VocabParallelCrossEntropy(torch.autograd.Function):
26 |
27 | @staticmethod
28 | def forward(ctx, vocab_parallel_logits, target):
29 |
30 | # Copy so the input remains unchanged.
31 | logits = vocab_parallel_logits.clone()
32 | # Maximum value along vocab dimension across all GPUs.
33 | logits_max = torch.max(logits, dim=-1)[0]
34 | torch.distributed.all_reduce(logits_max,
35 | op=torch.distributed.ReduceOp.MAX,
36 | group=get_model_parallel_group())
37 | # Subtract the maximum value.
38 | logits.sub_(logits_max.unsqueeze(dim=-1))
39 | # Sum of exponential of logits along vocab dimension across all GPUs.
40 | exp_logits = logits.exp()
41 | sum_exp_logits = exp_logits.sum(dim=-1)
42 | torch.distributed.all_reduce(sum_exp_logits,
43 | op=torch.distributed.ReduceOp.SUM,
44 | group=get_model_parallel_group())
45 |
46 |         # Get the partition's vocab indices.
47 | get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
48 | partition_vocab_size = vocab_parallel_logits.size()[-1]
49 | rank = get_model_parallel_rank()
50 | world_size = get_model_parallel_world_size()
51 | vocab_start_index, vocab_end_index = get_vocab_range(
52 | partition_vocab_size, rank, world_size)
53 |
54 | # Create a mask of valid vocab ids (1 means it needs to be masked).
55 | target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
56 | masked_target = target.clone() - vocab_start_index
57 | masked_target[target_mask] = 0
58 |
59 | # Get predicted-logits = logits[target].
60 | # For Simplicity, we convert logits to a 2-D tensor with size
61 | # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
62 | logits_2d = logits.view(-1, partition_vocab_size)
63 | masked_target_1d = masked_target.view(-1)
64 | arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
65 | device=logits_2d.device)
66 | predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
67 | predicted_logits = predicted_logits_1d.view_as(target)
68 | predicted_logits[target_mask] = 0.0
69 | # All reduce is needed to get the chunks from other GPUs.
70 | torch.distributed.all_reduce(predicted_logits,
71 | op=torch.distributed.ReduceOp.SUM,
72 | group=get_model_parallel_group())
73 |
74 | # Loss = log(sum(exp(logits))) - predicted-logit.
75 | loss = torch.log(sum_exp_logits) - predicted_logits
76 |
77 | # Store softmax, target-mask and masked-target for backward pass.
78 | exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
79 | ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
80 |
81 | return loss
82 |
83 | @staticmethod
84 | def backward(ctx, grad_output):
85 |
86 |         # Retrieve tensors from the forward path.
87 | softmax, target_mask, masked_target_1d = ctx.saved_tensors
88 |
89 |         # All the inputs have softmax as their gradient.
90 | grad_input = softmax
91 | # For simplicity, work with the 2D gradient.
92 | partition_vocab_size = softmax.size()[-1]
93 | grad_2d = grad_input.view(-1, partition_vocab_size)
94 |
95 | # Add the gradient from matching classes.
96 | arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
97 | device=grad_2d.device)
98 | grad_2d[arange_1d, masked_target_1d] -= (
99 | 1.0 - target_mask.view(-1).float())
100 |
101 | # Finally elementwise multiplication with the output gradients.
102 | grad_input.mul_(grad_output.unsqueeze(dim=-1))
103 |
104 | return grad_input, None
105 |
106 |
107 | def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
108 | """Helper function for the cross entropy."""
109 | return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
110 |
--------------------------------------------------------------------------------
/src/gpt3_data_loader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 |
18 | import torch
19 | from torch.utils.data import BatchSampler, DataLoader
20 |
21 | from src import mpu
22 | from src.dataset_rugpt3 import RuGpt3TextDataset, RuGpt3DatasetArguments
23 | from src.utils import print_rank_0
24 | from transformers import GPT2Tokenizer
25 |
26 |
27 | class InfiniteDataLoader(DataLoader):
28 | def __init__(self, *args, **kwargs):
29 | super().__init__(*args, **kwargs)
30 | # Initialize an iterator over the dataset.
31 | self.dataset_iterator = super().__iter__()
32 |
33 | def __iter__(self):
34 | return self
35 |
36 | def __next__(self):
37 | try:
38 | batch = next(self.dataset_iterator)
39 | except StopIteration:
40 | # Dataset exhausted, use a new fresh iterator.
41 | self.dataset_iterator = super().__iter__()
42 | batch = next(self.dataset_iterator)
43 | return batch
44 |
45 |
46 | class ResumableBatchSampler(BatchSampler):
47 | start_iter = 0
48 |
49 | def __iter__(self):
50 | batch = []
51 | i = 0
52 | for idx in self.sampler:
53 | batch.append(idx)
54 | if len(batch) == self.batch_size:
55 | if i >= self.start_iter:
56 | yield batch
57 | batch = []
58 | i += 1
59 | if len(batch) > 0 and not self.drop_last:
60 | yield batch
61 |
62 |
63 | def make_gpt3_dataloaders(args):
64 | # Data parallel arguments
65 | world_size = mpu.get_data_parallel_world_size()
66 | rank = mpu.get_data_parallel_rank()
67 | # global_batch_size = args.batch_size * world_size
68 | num_workers = args.num_workers
69 |
70 | # data_dir = args.train_data_path if args.train_data_path else os.path.dirname(args.test_data_path)
71 | tokenizer_path = args.load_huggingface if args.load_huggingface else \
72 | (args.tokenizer_path if args.tokenizer_path else os.path.join(os.path.dirname(args.train_data_path),
73 | '_tokenizer/'))
74 | print_rank_0('Load tokenizer from ' + tokenizer_path)
75 | tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
76 | tokenizer.add_special_tokens({"bos_token": ""})
77 | tokenizer.add_special_tokens({"eos_token": ""})
78 |
79 | print("Add answer_sep:", args.answer_sep)
80 | tokenizer.add_tokens(args.answer_sep)
81 |
82 |     print("Add start_sep:", args.start_sep)
83 |     tokenizer.add_tokens(args.start_sep)
84 |
85 |     print("Add end_sep:", args.end_sep)
86 |     tokenizer.add_tokens(args.end_sep)
87 |
88 | eod_token = tokenizer.encoder['']
89 | num_tokens = len(tokenizer)
90 |
91 | train_dataset_args = RuGpt3DatasetArguments(
92 | block_size=args.seq_length, max_files_load=args.max_files_per_process, overwrite_cache=args.overwrite_cache,
93 | tqdm=False)
94 | eval_dataset_args = RuGpt3DatasetArguments(
95 | block_size=args.seq_length, max_files_load=args.max_files_per_process, overwrite_cache=args.overwrite_cache,
96 | tqdm=True)
97 |
98 | def make_data_loader_(data_path, dataset_args):
99 | print_rank_0(f'Load RuGPT3 Dataset from {data_path}, {dataset_args.max_files_load} files per process')
100 | dataset = RuGpt3TextDataset(
101 | tokenizer=tokenizer,
102 | args=dataset_args,
103 | rank=rank,
104 | world_size=world_size,
105 | file_path=data_path,
106 | # cache_prefix=args.cache_prefix
107 | all_args=args
108 | )
109 | # Use a simple sampler with distributed batch sampler.
110 | sampler = torch.utils.data.SequentialSampler(dataset)
111 | batch_sampler = ResumableBatchSampler(sampler=sampler,
112 | batch_size=args.batch_size,
113 | drop_last=True)
114 |
115 | return InfiniteDataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
116 |
117 | train = make_data_loader_(args.train_data_path, train_dataset_args) if args.train_data_path else None
118 | valid = make_data_loader_(args.val_data_path, eval_dataset_args) if args.val_data_path else None
119 | test = make_data_loader_(args.test_data_path, eval_dataset_args) if args.test_data_path else None
120 |
121 | args.do_train = train is not None
122 | args.do_valid = valid is not None
123 | args.do_test = test is not None
124 |
125 | return (train, valid, test), num_tokens, eod_token, tokenizer
126 |
--------------------------------------------------------------------------------
/src/model/distributed.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
18 | import torch.distributed as dist
19 | from torch.nn.modules import Module
20 | from torch.autograd import Variable
21 |
22 | from src import mpu
23 |
24 |
25 | class DistributedDataParallel(Module):
26 |
27 | def __init__(self, module):
28 | super(DistributedDataParallel, self).__init__()
29 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
30 |
31 | self.module = module
32 | self.data_parallel_group = mpu.get_data_parallel_group()
33 | src_rank = mpu.get_model_parallel_rank()
34 | for p in self.module.parameters():
35 | if torch.is_tensor(p):
36 | dist.broadcast(p, src_rank, group=self.data_parallel_group)
37 |
38 | def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
39 | if(self.needs_reduction):
40 | self.needs_reduction = False
41 | buckets = {}
42 | for name, param in self.module.named_parameters():
43 | if param.requires_grad and param.grad is not None:
44 | tp = (param.data.type())
45 | if tp not in buckets:
46 | buckets[tp] = []
47 | buckets[tp].append(param)
48 | if self.warn_on_half:
49 | if torch.cuda.HalfTensor in buckets:
50 | print("WARNING: gloo dist backend for half parameters may be extremely slow." +
51 | " It is recommended to use the NCCL backend in this case.")
52 | self.warn_on_half = False
53 | for tp in buckets:
54 | bucket = buckets[tp]
55 | grads = [param.grad.data for param in bucket]
56 | coalesced = _flatten_dense_tensors(grads)
57 | if fp32_allreduce:
58 | coalesced = coalesced.float()
59 | if not no_scale and not reduce_after:
60 | coalesced /= dist.get_world_size(group=self.data_parallel_group)
61 | dist.all_reduce(coalesced, group=self.data_parallel_group)
62 | torch.cuda.synchronize()
63 | if not no_scale and reduce_after:
64 | coalesced /= dist.get_world_size(group=self.data_parallel_group)
65 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
66 | buf.copy_(synced)
67 | self.hook_handles = []
68 | self.hooks = []
69 | for param in list(self.module.parameters()):
70 | def allreduce_hook(*unused):
71 | Variable._execution_engine.queue_callback(allreduce_params)
72 | # handle = param.register_hook(allreduce_hook)
73 | #self.hooks.append(allreduce_hook)
74 | #self.hook_handles.append(handle)
75 | self.allreduce_params = allreduce_params
76 |
77 | def forward(self, *inputs, **kwargs):
78 | self.needs_reduction = True
79 | return self.module(*inputs, **kwargs)
80 |
81 | def state_dict(self, destination=None, prefix='', keep_vars=False):
82 | #[h.remove() for h in self.hook_handles]
83 | sd = self.module.state_dict(destination, prefix, keep_vars)
84 | # for handle, hook in zip(self.hook_handles, self.hooks):
85 | # d = handle.hooks_dict_ref()
86 | # d[handle.id] = hook
87 |
88 | return sd
89 |
90 | def load_state_dict(self, state_dict, strict=True):
91 | self.module.load_state_dict(state_dict, strict=strict)
92 |
93 | '''
94 | def _sync_buffers(self):
95 | buffers = list(self.module._all_buffers())
96 | if len(buffers) > 0:
97 | # cross-node buffer sync
98 | flat_buffers = _flatten_dense_tensors(buffers)
99 | dist.broadcast(flat_buffers, 0)
100 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
101 | buf.copy_(synced)
102 | def train(self, mode=True):
103 | # Clear NCCL communicator and CUDA event cache of the default group ID,
104 | # These cache will be recreated at the later call. This is currently a
105 | # work-around for a potential NCCL deadlock.
106 | if dist._backend == dist.dist_backend.NCCL:
107 | dist._clear_group_cache()
108 | super(DistributedDataParallel, self).train(mode)
109 | self.module.train(mode)
110 | '''
111 |
112 |
--------------------------------------------------------------------------------
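The gradient all-reduce in DistributedDataParallel.allreduce_params follows a simple recipe: bucket gradients by dtype, flatten each bucket into one tensor, all-reduce it once, divide by the data-parallel world size, and copy the synced values back into the individual gradients. Below is a minimal single-process sketch of that recipe using plain torch.distributed with the gloo backend rather than the repo's mpu groups; the toy model, port, and bucket handling are illustrative assumptions, not code from the repository.

import os
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Single-process "world" so the call pattern is runnable without a launcher.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()

# Bucket grads by dtype, flatten each bucket, all-reduce once, average,
# and copy the synced values back -- the same steps allreduce_params takes.
buckets = {}
for p in model.parameters():
    if p.requires_grad and p.grad is not None:
        buckets.setdefault(p.grad.data.type(), []).append(p.grad.data)

for grads in buckets.values():
    coalesced = _flatten_dense_tensors(grads)
    dist.all_reduce(coalesced)
    coalesced /= dist.get_world_size()
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)

dist.destroy_process_group()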
/src/mpu/initialize.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | """Model and data parallel groups."""
18 |
19 | import torch
20 |
21 | from .utils import ensure_divisibility
22 |
23 |
24 | # Model parallel group that the current rank belongs to.
25 | _MODEL_PARALLEL_GROUP = None
26 | # Data parallel group that the current rank belongs to.
27 | _DATA_PARALLEL_GROUP = None
28 |
29 |
30 | def initialize_model_parallel(model_parallel_size_):
31 | """
32 | Initialize model data parallel groups.
33 |
34 | Arguments:
35 | model_parallel_size: number of GPUs used to parallelize model.
36 |
37 | Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
38 | use 2 GPUs to parallelize the model. The present function will
39 |     create 4 model parallel groups and 2 data parallel groups as:
40 | 4 model parallel groups:
41 | [g0, g1], [g2, g3], [g4, g5], [g6, g7]
42 | 2 data parallel groups:
43 | [g0, g2, g4, g6], [g1, g3, g5, g7]
44 | Note that for efficiency, the caller should make sure adjacent ranks
45 | are on the same DGX box. For example if we are using 2 DGX-1 boxes
46 | with a total of 16 GPUs, rank 0 to 7 belong to the first box and
47 | ranks 8 to 15 belong to the second box.
48 | """
49 | if torch.distributed.get_rank() == 0:
50 | print('> initializing model parallel with size {}'.format(
51 | model_parallel_size_))
52 | # Get world size and rank. Ensure some consistencies.
53 | assert torch.distributed.is_initialized()
54 | world_size = torch.distributed.get_world_size()
55 | model_parallel_size = min(model_parallel_size_, world_size)
56 | ensure_divisibility(world_size, model_parallel_size)
57 | rank = torch.distributed.get_rank()
58 |
59 | # Build the data parallel groups.
60 | global _DATA_PARALLEL_GROUP
61 | assert _DATA_PARALLEL_GROUP is None, \
62 | 'data parallel group is already initialized'
63 | for i in range(model_parallel_size):
64 | ranks = range(i, world_size, model_parallel_size)
65 | group = torch.distributed.new_group(ranks)
66 | if i == (rank % model_parallel_size):
67 | _DATA_PARALLEL_GROUP = group
68 |
69 | # Build the model parallel groups.
70 | global _MODEL_PARALLEL_GROUP
71 | assert _MODEL_PARALLEL_GROUP is None, \
72 | 'model parallel group is already initialized'
73 | for i in range(world_size // model_parallel_size):
74 | ranks = range(i * model_parallel_size,
75 | (i + 1) * model_parallel_size)
76 | group = torch.distributed.new_group(ranks)
77 | if i == (rank // model_parallel_size):
78 | _MODEL_PARALLEL_GROUP = group
79 |
80 |
81 | def model_parallel_is_initialized():
82 | """Check if model and data parallel groups are initialized."""
83 | if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
84 | return False
85 | return True
86 |
87 |
88 | def get_model_parallel_group():
89 | """Get the model parallel group the caller rank belongs to."""
90 | assert _MODEL_PARALLEL_GROUP is not None, \
91 | 'model parallel group is not initialized'
92 | return _MODEL_PARALLEL_GROUP
93 |
94 |
95 | def get_data_parallel_group():
96 | """Get the data parallel group the caller rank belongs to."""
97 | assert _DATA_PARALLEL_GROUP is not None, \
98 | 'data parallel group is not initialized'
99 | return _DATA_PARALLEL_GROUP
100 |
101 |
102 | def get_model_parallel_world_size():
103 | """Return world size for the model parallel group."""
104 | return torch.distributed.get_world_size(group=get_model_parallel_group())
105 |
106 |
107 | def get_model_parallel_rank():
108 | """Return my rank for the model parallel group."""
109 | return torch.distributed.get_rank(group=get_model_parallel_group())
110 |
111 |
112 | def get_model_parallel_src_rank():
113 | """Calculate the global rank corresponding to a local rank zeor
114 | in the model parallel group."""
115 | global_rank = torch.distributed.get_rank()
116 | local_world_size = get_model_parallel_world_size()
117 | return (global_rank // local_world_size) * local_world_size
118 |
119 |
120 | def get_data_parallel_world_size():
121 | """Return world size for the data parallel group."""
122 | return torch.distributed.get_world_size(group=get_data_parallel_group())
123 |
124 |
125 | def get_data_parallel_rank():
126 | """Return my rank for the data parallel group."""
127 | return torch.distributed.get_rank(group=get_data_parallel_group())
128 |
129 |
130 | def destroy_model_parallel():
131 | """Set the groups to none."""
132 | global _MODEL_PARALLEL_GROUP
133 | _MODEL_PARALLEL_GROUP = None
134 | global _DATA_PARALLEL_GROUP
135 | _DATA_PARALLEL_GROUP = None
136 |
--------------------------------------------------------------------------------
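For the 8-GPU example in the initialize_model_parallel docstring (model_parallel_size = 2), the two loops above partition global ranks as follows. This is a plain-Python sketch of the grouping arithmetic only; no torch.distributed groups are created here.

world_size, model_parallel_size = 8, 2

# Data parallel groups: one per model-parallel rank, striding by model_parallel_size.
data_parallel_groups = [list(range(i, world_size, model_parallel_size))
                        for i in range(model_parallel_size)]

# Model parallel groups: consecutive blocks of model_parallel_size ranks.
model_parallel_groups = [list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
                         for i in range(world_size // model_parallel_size)]

print(data_parallel_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(model_parallel_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]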
/src/model/gpt3_modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """GPT-3 model."""
17 |
18 | import torch
19 | import torch.nn.functional as F
20 |
21 | from src import mpu
22 |
23 |
24 | def init_method_normal(std=0.02):
25 | """Init method based on normal distribution.
26 |
27 | This is only used for embeddings. The transformer has its
28 | own initializer.
29 | """
30 | def init_(tensor):
31 | return torch.nn.init.normal_(tensor, mean=0.0, std=std)
32 | return init_
33 |
34 |
35 | class GPT3Model(torch.nn.Module):
36 | """GPT-3 Language model.
37 |
38 |     The output of the forward method is the logits (parallel or
39 |     serial, depending on the `parallel_output` flag).
40 | """
41 |
42 | def __init__(self,
43 | num_layers,
44 | vocab_size,
45 | hidden_size,
46 | num_attention_heads,
47 | embedding_dropout_prob,
48 | attention_dropout_prob,
49 | output_dropout_prob,
50 | max_sequence_length,
51 | checkpoint_activations,
52 | checkpoint_num_layers=1,
53 | parallel_output=True,
54 | deepspeed_sparsity_config=None,
55 | sparse_mode=None):
56 |
57 | super(GPT3Model, self).__init__()
58 |
59 | self._conf_dict = {
60 | 'vocab_size': vocab_size,
61 | 'n_positions': max_sequence_length,
62 | 'n_ctx': max_sequence_length,
63 | 'n_embd': hidden_size,
64 | 'n_layer': num_layers,
65 | 'n_head': num_attention_heads
66 | }
67 |
68 | self.parallel_output = parallel_output
69 |
70 | init_method = init_method_normal(std=0.02)
71 |
72 | # Word embeddings (parallel).
73 | self.word_embeddings = mpu.VocabParallelEmbedding(
74 | vocab_size, hidden_size, init_method=init_method)
75 |
76 | # Position embedding (serial).
77 | self.position_embeddings = torch.nn.Embedding(max_sequence_length,
78 | hidden_size)
79 | # Initialize the position embeddings.
80 | init_method(self.position_embeddings.weight)
81 |
82 | # Embeddings dropout
83 | self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
84 |
85 | # Transformer
86 | self.transformer = mpu.GPT3ParallelTransformer(num_layers,
87 | hidden_size,
88 | num_attention_heads,
89 | attention_dropout_prob,
90 | output_dropout_prob,
91 | checkpoint_activations,
92 | checkpoint_num_layers,
93 | use_deepspeed_sparse=deepspeed_sparsity_config,
94 | sparse_mode=sparse_mode)
95 |
96 | def forward(self, input_ids, position_ids, attention_mask):
97 |
98 | # Embeddings.
99 | # print('input ids tensor', input_ids.size(), input_ids[0,:2])
100 | words_embeddings = self.word_embeddings(input_ids)
101 | position_embeddings = self.position_embeddings(position_ids)
102 | embeddings = words_embeddings + position_embeddings
103 |
104 | # Dropout.
105 | embeddings = self.embedding_dropout(embeddings)
106 |
107 | # Transformer.
108 | transformer_output = self.transformer(embeddings, attention_mask)
109 |
110 | # Parallel logits.
111 | transformer_output_parallel = mpu.copy_to_model_parallel_region(
112 | transformer_output)
113 | logits_parallel = F.linear(transformer_output_parallel,
114 | self.word_embeddings.weight)
115 |
116 | if self.parallel_output:
117 | return logits_parallel
118 |
119 | return mpu.gather_from_model_parallel_region(logits_parallel)
120 |
121 |
122 | def gpt3_get_params_for_weight_decay_optimization(module):
123 |
124 | weight_decay_params = {'params': []}
125 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
126 | for module_ in module.modules():
127 | if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
128 | no_weight_decay_params['params'].extend(
129 | [p for p in list(module_._parameters.values())
130 | if p is not None])
131 | else:
132 | weight_decay_params['params'].extend(
133 | [p for n, p in list(module_._parameters.items())
134 | if p is not None and n != 'bias'])
135 | no_weight_decay_params['params'].extend(
136 | [p for n, p in list(module_._parameters.items())
137 | if p is not None and n == 'bias'])
138 |
139 | return weight_decay_params, no_weight_decay_params
140 |
--------------------------------------------------------------------------------
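gpt3_get_params_for_weight_decay_optimization splits parameters into a decayed group (non-bias weights) and an undecayed group (biases and LayerNorm parameters); each dict then becomes one optimizer parameter group. A minimal sketch of the same split on a hypothetical plain-torch module follows; the real function is called on the GPT3Model instance and also matches mpu.LayerNorm, and the optimizer settings below are illustrative assumptions.

import torch

# Hypothetical stand-in module for the GPT3Model instance.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

decay, no_decay = {'params': []}, {'params': [], 'weight_decay': 0.0}
for m in model.modules():
    if isinstance(m, torch.nn.LayerNorm):
        no_decay['params'].extend(p for p in m._parameters.values() if p is not None)
    else:
        decay['params'].extend(p for n, p in m._parameters.items()
                               if p is not None and n != 'bias')
        no_decay['params'].extend(p for n, p in m._parameters.items()
                                  if p is not None and n == 'bias')

# Each dict is one optimizer parameter group; weight decay only applies to `decay`.
optimizer = torch.optim.AdamW([decay, no_decay], lr=1e-4, weight_decay=0.01)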
/src/dataset_rugpt3.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import random
5 | from dataclasses import dataclass, field
6 | from typing import Optional
7 |
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 | from tqdm import tqdm
12 |
13 | logger = logging.getLogger(__name__)
14 | logger.setLevel(logging.INFO)
15 |
16 |
17 | @dataclass
18 | class RuGpt3DatasetArguments:
19 | train_data_file: Optional[str] = field(
20 | default=None, metadata={"help": "The input training data file (a text file)."}
21 | )
22 | eval_data_file: Optional[str] = field(
23 | default=None,
24 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
25 | )
26 | block_size: int = field(
27 | default=-1,
28 | metadata={
29 | "help": "Optional input sequence length after tokenization."
30 | "The training dataset will be truncated in block of this size for training."
31 | "Default to the model max input length for single sentence inputs"
32 | " (take into account special tokens)."
33 | },
34 | )
35 | overwrite_cache: bool = field(
36 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
37 | )
38 | random_shift: bool = field(default=False, metadata={"help": "Make random shift from start of each file"})
39 | max_files_load: int = field(default=50000, metadata={"help": "Maximum number of files to load at one worker"})
40 | tqdm: bool = field(default=False, metadata={"help": "Show tqdm progress bar"})
41 |
42 |
43 | class RuGpt3TextDataset(Dataset):
44 | def process_file(self, file_path, filename, tokenizer, args):
45 | cached_features_file = os.path.join(self._cache_dir, filename.replace('/', '_') + '.pkl')
46 | examples = []
47 |
48 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
49 | with open(cached_features_file, "rb") as handle:
50 | try:
51 | examples = pickle.load(handle)
52 | examples = np.asarray(examples, dtype=np.int32)
53 | except Exception as e:
54 | print('Failed to load cache file:', cached_features_file)
55 | raise e
56 | else:
57 | examples = []
58 | with open(os.path.join(file_path, filename), encoding="utf-8") as f:
59 | text = f.read()
60 | if self.args.line_by_line:
61 | lines = [x + "" for x in text.strip().split("")]
62 | for line in lines:
63 | line = tokenizer.encode(line)
64 | line += [tokenizer.encoder['']] * self.args.seq_length
65 | line = line[:self.args.seq_length]
66 | examples.append(line)
67 | else:
68 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
69 |
70 | max_shift = max(min(args.block_size, len(tokenized_text) - args.block_size), 0)
71 | rnd_shift = random.randrange(max_shift) if max_shift and args.random_shift else 0
72 |
73 | for i in range(rnd_shift, len(tokenized_text) - args.block_size + 1, args.block_size):
74 | example = tokenized_text[i:i + args.block_size]
75 | if None in example:
76 | raise Exception('None in tokens!: ' + filename)
77 | if len(example) == args.block_size:
78 | examples.append(example)
79 |                     # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
80 |                     # If your dataset is small, first you should look for a bigger one :-) and second you
81 | # can change this behavior by adding (model specific) padding.
82 |
83 | with open(cached_features_file, "wb") as handle:
84 | pickle.dump(examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
85 | examples = np.asarray(examples, dtype=np.int32)
86 |
87 | return examples
88 |
89 | def __init__(self, tokenizer, args, rank, world_size, file_path, cache_prefix='_', all_args=None):
90 | self.rank = rank
91 | self.world_size = world_size
92 | self.log(f"Loading dataset {file_path}")
93 | max_file_load = args.max_files_load
94 | self.args = all_args
95 |
96 | file_with_list = file_path
97 | file_path = os.path.dirname(file_with_list)
98 | self.log(f"Check filelist {file_with_list} with root dir {file_path}")
99 |
100 | if not os.path.exists(file_with_list) and rank < 1:
101 | raise Exception('No file list!')
102 |
103 | with open(file_with_list, 'r') as fp:
104 | files = [line.strip() for line in fp.read().split('\n') if line]
105 |
106 | if rank == -1:
107 | self.log('Shuffle')
108 | random.shuffle(files)
109 | if len(files) > max_file_load:
110 | files = files[:max_file_load]
111 | else:
112 | shard_size = len(files) // world_size
113 | if shard_size > max_file_load:
114 | logger.warning(
115 | f"Shard size {shard_size} > max_file_load {max_file_load},"
116 | f" only first {(max_file_load * world_size)}"
117 | f" files of dataset would be loaded!")
118 | shard_size = max_file_load
119 | shard_start = rank * shard_size
120 | shard_end = (rank + 1) * shard_size
121 | self.log(f"Shard [{shard_start}, {shard_end}]")
122 | files = files[shard_start:shard_end]
123 |
124 | self._cache_dir = os.path.join(file_path, f'{cache_prefix}cache_{args.block_size}_{len(tokenizer)}')
125 | os.makedirs(self._cache_dir, exist_ok=True)
126 | if args.overwrite_cache:
127 | self.log('Overwrite cache ' + self._cache_dir)
128 |
129 | examples = []
130 | iterator = tqdm(files) if args.tqdm else files
131 | for i, filename in enumerate(iterator):
132 | if i % 1000 == 0:
133 | self.log(f"Loaded {i}/{len(files)} files")
134 | example = self.process_file(file_path, filename, tokenizer, args=args)
135 | if example.size:
136 | examples.append(example)
137 | self.examples = np.vstack(examples)
138 | np.random.shuffle(self.examples)
139 | self.log(f"Loaded {len(self.examples)} examples, {self.examples.size} tokens")
140 |
141 | def log(self, msg):
142 | logger.warning(f"R{self.rank}/{self.world_size}: {msg}")
143 |
144 | def __len__(self):
145 | return len(self.examples)
146 |
147 | def __getitem__(self, item):
148 | item = item % len(self.examples) # infinite loop, modulo dataset size
149 | if len(self.examples[item]) == 0:
150 |             item = random.randrange(len(self.examples))  # fall back to a valid random index
151 | return torch.tensor(self.examples[item], dtype=torch.long)
152 |
--------------------------------------------------------------------------------
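RuGpt3TextDataset.process_file cuts each tokenized file into fixed block_size windows, optionally starting from a random shift so block boundaries differ between runs, and drops the trailing partial block. A small self-contained sketch of that windowing logic; chunk_tokens is a hypothetical helper written for illustration, not part of the repo.

import random

def chunk_tokens(tokenized_text, block_size, random_shift=False):
    # Same windowing as RuGpt3TextDataset.process_file: fixed-size blocks,
    # optional random start offset, trailing partial block discarded.
    max_shift = max(min(block_size, len(tokenized_text) - block_size), 0)
    rnd_shift = random.randrange(max_shift) if max_shift and random_shift else 0
    examples = []
    for i in range(rnd_shift, len(tokenized_text) - block_size + 1, block_size):
        example = tokenized_text[i:i + block_size]
        if len(example) == block_size:
            examples.append(example)
    return examples

print(chunk_tokens(list(range(10)), block_size=4))
# [[0, 1, 2, 3], [4, 5, 6, 7]] -- tokens 8 and 9 are dropped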
/src/data_utils/lazy_loader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """utils for loading text from disk"""
16 | import mmap
17 | import os
18 | import pickle as pkl
19 | import time
20 | from itertools import accumulate
21 |
22 | import torch
23 | from torch.multiprocessing import Lock
24 |
25 |
26 | def get_lazy_path(path):
27 | """
28 | Gets directory path where lazy files are stored.
29 | """
30 | return os.path.splitext(path)[0] + '.lazy'
31 |
32 |
33 | def exists_lazy(path, data_type='data'):
34 | """
35 | Check if we've already made a lazy version of this file for the `data_type` field.
36 | """
37 | if not os.path.exists(get_lazy_path(path)):
38 | return False
39 | contents = os.listdir(get_lazy_path(path))
40 | if data_type not in contents:
41 | return False
42 | if data_type + '.len.pkl' not in contents:
43 | return False
44 | return True
45 |
46 |
47 | def make_lazy(path, strs, data_type='data'):
48 | """
49 | Make lazy version of `data_type` field of the file. Byte offsets
50 | corresponding to data indices are stored in a `.len.pkl` data file.
51 | """
52 | lazypath = get_lazy_path(path)
53 | if not os.path.exists(lazypath):
54 | os.makedirs(lazypath)
55 | datapath = os.path.join(lazypath, data_type)
56 | lenpath = os.path.join(lazypath, data_type + '.len.pkl')
57 | if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
58 | with open(datapath, 'wb') as f:
59 | str_lens = []
60 | str_cnt = 0
61 | for s in strs:
62 | if isinstance(s, dict):
63 | s = s['text']
64 | encoded = s.encode('utf-8')
65 | f.write(encoded)
66 | str_cnt = len(encoded)
67 | str_lens.append(str_cnt)
68 | pkl.dump(str_lens, open(lenpath, 'wb'))
69 | else:
70 | while not os.path.exists(lenpath):
71 | time.sleep(1)
72 |
73 |
74 | def split_strings(strings, start, chr_lens):
75 | """
76 | Split strings based on string lengths and given start.
77 | """
78 | return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]
79 |
80 |
81 | class ProcessorTokenizer:
82 | """
83 | callable class that runs a preprocessing, as well as tokenization step,
84 | on input text.
85 | """
86 |
87 | def __init__(self, tokenizer, process_fn=None):
88 | self.tokenizer = tokenizer
89 | self.process_fn = process_fn
90 |
91 | def __call__(self, string):
92 | if self.tokenizer is not None:
93 | string = self.tokenizer(string, process_fn=self.process_fn)
94 | elif self.process_fn is not None:
95 | string = self.process_fn(string)
96 | return string
97 |
98 |
99 | class lazy_array_loader(object):
100 | """
101 | Arguments:
102 | path: path to directory where array entries are concatenated into one big string file
103 | and the .len file are located
104 |         data_type (str): Some datasets have multiple fields that are stored in different paths.
105 | `data_type` specifies which of these fields to load in this class
106 | mem_map (boolean): Specifies whether to memory map file `path`
107 | map_fn (callable): Fetched strings are passed through map_fn before being returned.
108 |
109 | Example of lazy loader directory structure:
110 | file.json
111 | file.lazy/
112 | data_type1
113 | data_type1.len.pkl
114 | data_type2
115 | data_type2.len.pkl
116 | """
117 |
118 | def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
119 | lazypath = get_lazy_path(path)
120 | datapath = os.path.join(lazypath, data_type)
121 | # get file where array entries are concatenated into one big string
122 | self._file = open(datapath, 'rb')
123 | self.file = self._file
124 | # memory map file if necessary
125 | self.mem_map = mem_map
126 | if self.mem_map:
127 | self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
128 | lenpath = os.path.join(lazypath, data_type + '.len.pkl')
129 | self.lens = pkl.load(open(lenpath, 'rb'))
130 | self.ends = list(accumulate(self.lens))
131 | self.dumb_ends = list(self.ends)
132 | self.read_lock = Lock()
133 | self.process_fn = map_fn
134 | self.map_fn = map_fn
135 | self._tokenizer = None
136 |
137 | def SetTokenizer(self, tokenizer):
138 | """
139 | logic to set and remove (set to None) tokenizer.
140 | combines preprocessing/tokenization into one callable.
141 | """
142 | if tokenizer is None:
143 | if not hasattr(self, '_tokenizer'):
144 | self._tokenizer = tokenizer
145 | else:
146 | self._tokenizer = tokenizer
147 | self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn)
148 |
149 | def GetTokenizer(self):
150 | return self._tokenizer
151 |
152 | def __getitem__(self, index):
153 | """
154 | read file and splice strings based on string ending array `self.ends`
155 | """
156 | if not isinstance(index, slice):
157 | if index == 0:
158 | start = 0
159 | else:
160 | start = self.ends[index - 1]
161 | end = self.ends[index]
162 | rtn = self.file_read(start, end)
163 | if self.map_fn is not None:
164 | return self.map_fn(rtn)
165 | else:
166 | # if slice, fetch strings with 1 diskread and then splice in memory
167 | chr_lens = self.ends[index]
168 | if index.start == 0 or index.start is None:
169 | start = 0
170 | else:
171 | start = self.ends[index.start - 1]
172 | stop = chr_lens[-1]
173 | strings = self.file_read(start, stop)
174 | rtn = split_strings(strings, start, chr_lens)
175 | if self.map_fn is not None:
176 | return self.map_fn([s for s in rtn])
177 | return rtn
178 |
179 | def __len__(self):
180 | return len(self.ends)
181 |
182 | def file_read(self, start=0, end=None):
183 | """read specified portion of file"""
184 |
185 | # atomic reads to avoid race conditions with multiprocess dataloader
186 | self.read_lock.acquire()
187 | # seek to start of file read
188 | self.file.seek(start)
189 | # read to end of file if no end point provided
190 | if end is None:
191 | rtn = self.file.read()
192 | # else read amount needed to reach end point
193 | else:
194 | rtn = self.file.read(end - start)
195 | self.read_lock.release()
196 | # TODO: @raulp figure out mem map byte string bug
197 | # if mem map'd need to decode byte string to string
198 | rtn = rtn.decode('utf-8', 'ignore')
199 | # rtn = str(rtn)
200 | if self.mem_map:
201 | rtn = rtn.decode('unicode_escape')
202 | return rtn
203 |
--------------------------------------------------------------------------------
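make_lazy and lazy_array_loader form a write/read pair: records are concatenated into one flat file and their byte lengths are pickled next to it, so individual records can later be fetched by offset without loading the whole corpus. A hedged round-trip sketch, assuming the repo's src package is importable; the temporary path and record strings below are made up for illustration.

import os
import tempfile
from src.data_utils.lazy_loader import make_lazy, lazy_array_loader

tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, 'corpus.json')   # the '.lazy' directory is derived from this path

make_lazy(path, ['first document', 'second document'], data_type='data')
loader = lazy_array_loader(path, data_type='data', mem_map=False)

print(len(loader))   # 2
print(loader[1])     # 'second document'
print(loader[0:2])   # ['first document', 'second document']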
/src/fp16/fp16util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn as nn
18 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
19 | from torch.autograd import Variable
20 |
21 | from src import mpu
22 |
23 |
24 | class tofp16(nn.Module):
25 | """
26 | Utility module that implements::
27 |
28 | def forward(self, input):
29 | return input.half()
30 | """
31 |
32 | def __init__(self):
33 | super(tofp16, self).__init__()
34 |
35 | def forward(self, input):
36 | return input.half()
37 |
38 |
39 | def BN_convert_float(module):
40 | """
41 | Utility function for network_to_half().
42 |
43 | Retained for legacy purposes.
44 | """
45 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
46 | module.float()
47 | for child in module.children():
48 | BN_convert_float(child)
49 | return module
50 |
51 |
52 | def network_to_half(network):
53 | """
54 | Convert model to half precision in a batchnorm-safe way.
55 |
56 | Retained for legacy purposes. It is recommended to use FP16Model.
57 | """
58 | return nn.Sequential(tofp16(), BN_convert_float(network.half()))
59 |
60 |
61 | def convert_module(module, dtype):
62 | """
63 | Converts a module's immediate parameters and buffers to dtype.
64 | """
65 | for param in module.parameters(recurse=False):
66 | if param is not None:
67 | if param.data.dtype.is_floating_point:
68 | param.data = param.data.to(dtype=dtype)
69 | if param._grad is not None and param._grad.data.dtype.is_floating_point:
70 | param._grad.data = param._grad.data.to(dtype=dtype)
71 |
72 | for buf in module.buffers(recurse=False):
73 | if buf is not None and buf.data.dtype.is_floating_point:
74 | buf.data = buf.data.to(dtype=dtype)
75 |
76 |
77 | def convert_network(network, dtype):
78 | """
79 | Converts a network's parameters and buffers to dtype.
80 | """
81 | for module in network.modules():
82 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
83 | continue
84 | convert_module(module, dtype)
85 | return network
86 |
87 |
88 | class FP16Model(nn.Module):
89 | """
90 | Convert model to half precision in a batchnorm-safe way.
91 | """
92 |
93 | def __init__(self, network):
94 | super(FP16Model, self).__init__()
95 | self.network = convert_network(network, dtype=torch.half)
96 |
97 | def forward(self, *inputs):
98 | inputs = tuple(t.half() for t in inputs)
99 | return self.network(*inputs)
100 |
101 |
102 | def backwards_debug_hook(grad):
103 |     raise RuntimeError("master_params received a gradient in the backward pass!")
104 |
105 |
106 | def prep_param_lists(model, flat_master=False):
107 | """
108 | Creates a list of FP32 master parameters for a given model, as in
109 | `Training Neural Networks with Mixed Precision: Real Examples`_.
110 |
111 | Args:
112 | model (torch.nn.Module): Existing Pytorch model
113 | flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
114 | Returns:
115 |     A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element.
116 |
117 | Example::
118 |
119 | model_params, master_params = prep_param_lists(model)
120 |
121 | .. warning::
122 | Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
123 |
124 | .. _`Training Neural Networks with Mixed Precision: Real Examples`:
125 | http://on-demand.gputechconf.com/gtc/2018/video/S81012/
126 | """
127 | model_params = [param for param in model.parameters() if param.requires_grad]
128 |
129 | if flat_master:
130 | # Give the user some more useful error messages
131 | try:
132 | # flatten_dense_tensors returns a contiguous flat array.
133 | # http://pytorch.org/docs/master/_modules/torch/_utils.html
134 | master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
135 | except:
136 | print("Error in prep_param_lists: model may contain a mixture of parameters "
137 |                   "of different types. Use flat_master=False, or use FP16_Optimizer.")
138 | raise
139 | master_params = torch.nn.Parameter(master_params)
140 | master_params.requires_grad = True
141 | # master_params.register_hook(backwards_debug_hook)
142 | if master_params.grad is None:
143 | master_params.grad = master_params.new(*master_params.size())
144 | return model_params, [master_params]
145 | else:
146 | master_params = [param.clone().float().detach() for param in model_params]
147 | for param in master_params:
148 | param.requires_grad = True
149 | return model_params, master_params
150 |
151 |
152 | def model_grads_to_master_grads(model_params, master_params, flat_master=False):
153 | """
154 | Copy model gradients to master gradients.
155 |
156 | Args:
157 | model_params: List of model parameters created by :func:`prep_param_lists`.
158 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
159 | """
160 | if flat_master:
161 | # The flattening may incur one more deep copy than is necessary.
162 | master_params[0].grad.data.copy_(
163 | _flatten_dense_tensors([p.grad.data for p in model_params]))
164 | else:
165 | for model, master in zip(model_params, master_params):
166 | if model.grad is not None:
167 | if master.grad is None:
168 | master.grad = Variable(master.data.new(*master.data.size()))
169 | master.grad.data.copy_(model.grad.data)
170 | else:
171 | master.grad = None
172 |
173 |
174 | def master_params_to_model_params(model_params, master_params, flat_master=False):
175 | """
176 | Copy master parameters to model parameters.
177 |
178 | Args:
179 | model_params: List of model parameters created by :func:`prep_param_lists`.
180 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
181 | """
182 | if flat_master:
183 | for model, master in zip(model_params,
184 | _unflatten_dense_tensors(master_params[0].data, model_params)):
185 | model.data.copy_(master)
186 | else:
187 | for model, master in zip(model_params, master_params):
188 | model.data.copy_(master.data)
189 |
190 |
191 | # Backward compatibility fixes
192 |
193 | def to_python_float(t):
194 | if hasattr(t, 'item'):
195 | return t.item()
196 | else:
197 | return t[0]
198 |
199 |
200 | TORCH_MAJOR = int(torch.__version__.split('.')[0])
201 | TORCH_MINOR = int(torch.__version__.split('.')[1])
202 |
203 | clip_grad_norm = mpu.clip_grad_norm
204 | # elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
205 | # clip_grad_norm = torch.nn.utils.clip_grad_norm
206 | # else:
207 | # clip_grad_norm = torch.nn.utils.clip_grad_norm_
208 |
--------------------------------------------------------------------------------
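These helpers implement the standard fp16 master-weights pattern: keep fp32 copies of the fp16 parameters, step the optimizer on the fp32 copies, and copy the result back to the fp16 model. The sketch below inlines what prep_param_lists, model_grads_to_master_grads and master_params_to_model_params do (with flat_master=False) on a single hypothetical half-precision parameter, rather than importing the repo; the loss scale and toy loss are assumptions for illustration.

import torch

# Hypothetical fp16 "model": a single half-precision parameter tensor.
model_params = [torch.nn.Parameter(torch.randn(8, dtype=torch.half))]

# fp32 master copies, as prep_param_lists(flat_master=False) builds them.
master_params = [p.clone().float().detach().requires_grad_(True) for p in model_params]
optimizer = torch.optim.SGD(master_params, lr=1e-2)

loss_scale = 128.0
x = torch.randn(8, dtype=torch.half)
loss = (model_params[0] * x).sum().float() * loss_scale   # scale to protect small grads
loss.backward()                                           # grads land on the fp16 params

# model_grads_to_master_grads: copy (and here also unscale) grads into the fp32 masters.
for m, M in zip(model_params, master_params):
    M.grad = m.grad.detach().float() / loss_scale
optimizer.step()                                          # update the fp32 master weights

# master_params_to_model_params: copy the updated fp32 weights back to fp16.
with torch.no_grad():
    for m, M in zip(model_params, master_params):
        m.copy_(M)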
/src/data_utils/file_utils.py:
--------------------------------------------------------------------------------
1 | # This file is provided as is from:
2 | # https://github.com/huggingface/pytorch-pretrained-BERT
3 | # Please refer to their repository for copyright.
4 |
5 | """
6 | Utilities for working with the local dataset cache.
7 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
8 | Copyright by the AllenNLP authors.
9 | """
10 | from __future__ import (absolute_import, division, print_function, unicode_literals)
11 |
12 | import json
13 | import logging
14 | import os
15 | import shutil
16 | import tempfile
17 | from functools import wraps
18 | from hashlib import sha256
19 | import sys
20 | from io import open
21 |
22 | import boto3
23 | import requests
24 | from botocore.exceptions import ClientError
25 | from tqdm import tqdm
26 |
27 | try:
28 | from urllib.parse import urlparse
29 | except ImportError:
30 | from urlparse import urlparse
31 |
32 | try:
33 | from pathlib import Path
34 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
35 | Path.home() / '.pytorch_pretrained_bert'))
36 | except (AttributeError, ImportError):
37 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
38 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
39 |
40 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
41 |
42 |
43 | def url_to_filename(url, etag=None):
44 | """
45 | Convert `url` into a hashed filename in a repeatable way.
46 | If `etag` is specified, append its hash to the url's, delimited
47 | by a period.
48 | """
49 | url_bytes = url.encode('utf-8')
50 | url_hash = sha256(url_bytes)
51 | filename = url_hash.hexdigest()
52 |
53 | if etag:
54 | etag_bytes = etag.encode('utf-8')
55 | etag_hash = sha256(etag_bytes)
56 | filename += '.' + etag_hash.hexdigest()
57 |
58 | return filename
59 |
60 |
61 | def filename_to_url(filename, cache_dir=None):
62 | """
63 | Return the url and etag (which may be ``None``) stored for `filename`.
64 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
65 | """
66 | if cache_dir is None:
67 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
68 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
69 | cache_dir = str(cache_dir)
70 |
71 | cache_path = os.path.join(cache_dir, filename)
72 | if not os.path.exists(cache_path):
73 | raise EnvironmentError("file {} not found".format(cache_path))
74 |
75 | meta_path = cache_path + '.json'
76 | if not os.path.exists(meta_path):
77 | raise EnvironmentError("file {} not found".format(meta_path))
78 |
79 | with open(meta_path, encoding="utf-8") as meta_file:
80 | metadata = json.load(meta_file)
81 | url = metadata['url']
82 | etag = metadata['etag']
83 |
84 | return url, etag
85 |
86 |
87 | def cached_path(url_or_filename, cache_dir=None):
88 | """
89 | Given something that might be a URL (or might be a local path),
90 | determine which. If it's a URL, download the file and cache it, and
91 | return the path to the cached file. If it's already a local path,
92 | make sure the file exists and then return the path.
93 | """
94 | if cache_dir is None:
95 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
96 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
97 | url_or_filename = str(url_or_filename)
98 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
99 | cache_dir = str(cache_dir)
100 |
101 | parsed = urlparse(url_or_filename)
102 |
103 | if parsed.scheme in ('http', 'https', 's3'):
104 | # URL, so get it from the cache (downloading if necessary)
105 | return get_from_cache(url_or_filename, cache_dir)
106 | elif os.path.exists(url_or_filename):
107 | # File, and it exists.
108 | return url_or_filename
109 | elif parsed.scheme == '':
110 | # File, but it doesn't exist.
111 | raise EnvironmentError("file {} not found".format(url_or_filename))
112 | else:
113 | # Something unknown
114 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
115 |
116 |
117 | def split_s3_path(url):
118 | """Split a full s3 path into the bucket name and path."""
119 | parsed = urlparse(url)
120 | if not parsed.netloc or not parsed.path:
121 | raise ValueError("bad s3 path {}".format(url))
122 | bucket_name = parsed.netloc
123 | s3_path = parsed.path
124 | # Remove '/' at beginning of path.
125 | if s3_path.startswith("/"):
126 | s3_path = s3_path[1:]
127 | return bucket_name, s3_path
128 |
129 |
130 | def s3_request(func):
131 | """
132 | Wrapper function for s3 requests in order to create more helpful error
133 | messages.
134 | """
135 |
136 | @wraps(func)
137 | def wrapper(url, *args, **kwargs):
138 | try:
139 | return func(url, *args, **kwargs)
140 | except ClientError as exc:
141 | if int(exc.response["Error"]["Code"]) == 404:
142 | raise EnvironmentError("file {} not found".format(url))
143 | else:
144 | raise
145 |
146 | return wrapper
147 |
148 |
149 | @s3_request
150 | def s3_etag(url):
151 | """Check ETag on S3 object."""
152 | s3_resource = boto3.resource("s3")
153 | bucket_name, s3_path = split_s3_path(url)
154 | s3_object = s3_resource.Object(bucket_name, s3_path)
155 | return s3_object.e_tag
156 |
157 |
158 | @s3_request
159 | def s3_get(url, temp_file):
160 | """Pull a file directly from S3."""
161 | s3_resource = boto3.resource("s3")
162 | bucket_name, s3_path = split_s3_path(url)
163 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
164 |
165 |
166 | def http_get(url, temp_file):
167 | req = requests.get(url, stream=True)
168 | content_length = req.headers.get('Content-Length')
169 | total = int(content_length) if content_length is not None else None
170 | progress = tqdm(unit="B", total=total)
171 | for chunk in req.iter_content(chunk_size=1024):
172 | if chunk:
173 | # filter out keep-alive new chunks
174 | progress.update(len(chunk))
175 | temp_file.write(chunk)
176 | progress.close()
177 |
178 |
179 | def get_from_cache(url, cache_dir=None):
180 | """
181 | Given a URL, look for the corresponding dataset in the local cache.
182 | If it's not there, download it. Then return the path to the cached file.
183 | """
184 | if cache_dir is None:
185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
187 | cache_dir = str(cache_dir)
188 |
189 | if not os.path.exists(cache_dir):
190 | os.makedirs(cache_dir)
191 |
192 | # Get eTag to add to filename, if it exists.
193 | if url.startswith("s3://"):
194 | etag = s3_etag(url)
195 | else:
196 | response = requests.head(url, allow_redirects=True)
197 | if response.status_code != 200:
198 | raise IOError("HEAD request failed for url {} with status code {}"
199 | .format(url, response.status_code))
200 | etag = response.headers.get("ETag")
201 |
202 | filename = url_to_filename(url, etag)
203 |
204 | # get cache path to put the file
205 | cache_path = os.path.join(cache_dir, filename)
206 |
207 | if not os.path.exists(cache_path):
208 | # Download to temporary file, then copy to cache dir once finished.
209 | # Otherwise you get corrupt cache entries if the download gets interrupted.
210 | with tempfile.NamedTemporaryFile() as temp_file:
211 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
212 |
213 | # GET file object
214 | if url.startswith("s3://"):
215 | s3_get(url, temp_file)
216 | else:
217 | http_get(url, temp_file)
218 |
219 | # we are copying the file before closing it, so flush to avoid truncation
220 | temp_file.flush()
221 | # shutil.copyfileobj() starts at the current position, so go to the start
222 | temp_file.seek(0)
223 |
224 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
225 | with open(cache_path, 'wb') as cache_file:
226 | shutil.copyfileobj(temp_file, cache_file)
227 |
228 | logger.info("creating metadata file for %s", cache_path)
229 | meta = {'url': url, 'etag': etag}
230 | meta_path = cache_path + '.json'
231 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
232 | json.dump(meta, meta_file)
233 |
234 | logger.info("removing temp file %s", temp_file.name)
235 |
236 | return cache_path
237 |
238 |
239 | def read_set_from_file(filename):
240 | """
241 | Extract a de-duped collection (set) of text from a file.
242 | Expected file format is one item per line.
243 | """
244 | collection = set()
245 | with open(filename, 'r', encoding='utf-8') as file_:
246 | for line in file_:
247 | collection.add(line.rstrip())
248 | return collection
249 |
250 |
251 | def get_file_extension(path, dot=True, lower=True):
252 | ext = os.path.splitext(path)[1]
253 | ext = ext if dot else ext[1:]
254 | return ext.lower() if lower else ext
255 |
--------------------------------------------------------------------------------
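cached_path is the entry point of this module: a local path is returned as-is after an existence check, while an http(s) or s3 URL is downloaded once into the cache directory (keyed by a hash of the URL and its ETag) and the cached file path is returned. A small hedged usage sketch, assuming requests and boto3 are installed, the repo's src package is importable, and the working directory is the repository root; the URL in the comment is hypothetical.

from src.data_utils.file_utils import cached_path

# A local path is checked for existence and returned unchanged.
readme = cached_path("README.md")

# A URL (hypothetical here) would be downloaded once into
# ~/.pytorch_pretrained_bert (or $PYTORCH_PRETRAINED_BERT_CACHE)
# and the cached local path returned:
# vocab = cached_path("https://example.com/vocab.json")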
/src/fp16/loss_scaler.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 | from src import mpu
19 |
20 |
21 | # item() is a recent addition, so this helps with backward compatibility.
22 | def to_python_float(t):
23 | if hasattr(t, 'item'):
24 | return t.item()
25 | else:
26 | return t[0]
27 |
28 |
29 | class LossScaler:
30 | """
31 | Class that manages a static loss scale. This class is intended to interact with
32 | :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
33 |
34 | Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
35 | :class:`FP16_Optimizer`'s constructor.
36 |
37 | Args:
38 | scale (float, optional, default=1.0): The loss scale.
39 | """
40 |
41 | def __init__(self, scale=1):
42 | self.cur_scale = scale
43 |
44 | # `params` is a list / generator of torch.Variable
45 | def has_overflow(self, params):
46 | return False
47 |
48 | # `x` is a torch.Tensor
49 | def _has_inf_or_nan(x):
50 | return False
51 |
52 | def update_scale(self, overflow):
53 | pass
54 |
55 | @property
56 | def loss_scale(self):
57 | return self.cur_scale
58 |
59 | def scale_gradient(self, module, grad_in, grad_out):
60 | return tuple(self.loss_scale * g for g in grad_in)
61 |
62 | def backward(self, loss, retain_graph=False):
63 | scaled_loss = loss * self.loss_scale
64 | scaled_loss.backward(retain_graph=retain_graph)
65 |
66 |
67 | class DynamicLossScaler:
68 | """
69 | Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
70 | indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
71 | :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
72 | operates, because the default options can be changed using the
73 |     ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
74 |
75 | Loss scaling is designed to combat the problem of underflowing gradients encountered at long
76 | times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
77 | scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
78 | encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
79 | occurred.
80 | :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
81 | and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
82 | If a certain number of iterations occur without overflowing gradients detected,
83 | :class:`DynamicLossScaler` increases the loss scale once more.
84 | In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
85 | always using the highest loss scale possible without incurring overflow.
86 |
87 | Args:
88 | init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
89 | scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
90 | scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
91 | """
92 |
93 | def __init__(self,
94 | init_scale=2 ** 32,
95 | scale_factor=2.,
96 | scale_window=1000,
97 | min_scale=1,
98 | delayed_shift=1,
99 | consecutive_hysteresis=False):
100 | self.cur_scale = init_scale
101 | self.cur_iter = 0
102 | self.last_overflow_iter = -1
103 | self.scale_factor = scale_factor
104 | self.scale_window = scale_window
105 | self.min_scale = min_scale
106 | self.delayed_shift = delayed_shift
107 | self.cur_hysteresis = delayed_shift
108 | self.consecutive_hysteresis = consecutive_hysteresis
109 |
110 | # `params` is a list / generator of torch.Variable
111 | def has_overflow_serial(self, params):
112 | for p in params:
113 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
114 | return True
115 |
116 | return False
117 |
118 | def has_overflow(self, params):
119 | overflow = self.has_overflow_serial(params)
120 | # Since each model parallel GPU carries only part of the model,
121 | # make sure overflow flag is synced across all the model parallel GPUs
122 | overflow_gpu = torch.cuda.ByteTensor([overflow])
123 | torch.distributed.all_reduce(overflow_gpu,
124 | op=torch.distributed.ReduceOp.MAX,
125 | group=mpu.get_model_parallel_group())
126 | overflow = overflow_gpu[0].item()
127 | return bool(overflow)
128 |
129 | # `x` is a torch.Tensor
130 | def _has_inf_or_nan(x):
131 | try:
132 | # if x is half, the .float() incurs an additional deep copy, but it's necessary if
133 | # Pytorch's .sum() creates a one-element tensor of the same type as x
134 | # (which is true for some recent version of pytorch).
135 | cpu_sum = float(x.float().sum())
136 | # More efficient version that can be used if .sum() returns a Python scalar
137 | # cpu_sum = float(x.sum())
138 | except RuntimeError as instance:
139 |             # We want to check if instance is actually an overflow exception.
140 | # RuntimeError could come from a different error.
141 | # If so, we still want the exception to propagate.
142 | if "value cannot be converted" not in instance.args[0]:
143 | raise
144 | return True
145 | else:
146 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
147 | return True
148 | return False
149 |
150 | # `overflow` is boolean indicating whether the gradient overflowed
151 | def update_scale(self, overflow):
152 |
153 | if not hasattr(self, 'min_scale'):
154 | self.min_scale = 1
155 | if not hasattr(self, 'delayed_shift'):
156 | self.delayed_shift = 1
157 | if not hasattr(self, 'cur_hysteresis'):
158 | self.cur_hysteresis = 1
159 | if not hasattr(self, 'consecutive_hysteresis'):
160 | self.consecutive_hysteresis = True
161 | if overflow:
162 | # self.cur_scale /= self.scale_factor
163 | if self.delayed_shift == 1 or self.cur_hysteresis == 1:
164 | self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
165 | else:
166 | self.cur_hysteresis -= 1
167 | self.last_overflow_iter = self.cur_iter
168 | else:
169 | if self.consecutive_hysteresis:
170 | self.cur_hysteresis = self.delayed_shift
171 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
172 | if not self.consecutive_hysteresis:
173 | self.cur_hysteresis = self.delayed_shift
174 | self.cur_scale *= self.scale_factor
175 | self.cur_iter += 1
176 |
177 | @property
178 | def loss_scale(self):
179 | return self.cur_scale
180 |
181 | def scale_gradient(self, module, grad_in, grad_out):
182 | return tuple(self.loss_scale * g for g in grad_in)
183 |
184 | def backward(self, loss, retain_graph=False):
185 | scaled_loss = loss * self.loss_scale
186 | scaled_loss.backward(retain_graph=retain_graph)
187 |
188 |
189 | ##############################################################
190 | # Example usage below here -- assuming it's in a separate file
191 | ##############################################################
192 | """
193 | TO-DO separate out into an example.
194 | if __name__ == "__main__":
195 | import torch
196 | from torch.autograd import Variable
197 | from dynamic_loss_scaler import DynamicLossScaler
198 |
199 | # N is batch size; D_in is input dimension;
200 | # H is hidden dimension; D_out is output dimension.
201 | N, D_in, H, D_out = 64, 1000, 100, 10
202 |
203 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
204 | x = Variable(torch.randn(N, D_in), requires_grad=False)
205 | y = Variable(torch.randn(N, D_out), requires_grad=False)
206 |
207 | w1 = Variable(torch.randn(D_in, H), requires_grad=True)
208 | w2 = Variable(torch.randn(H, D_out), requires_grad=True)
209 | parameters = [w1, w2]
210 |
211 | learning_rate = 1e-6
212 | optimizer = torch.optim.SGD(parameters, lr=learning_rate)
213 | loss_scaler = DynamicLossScaler()
214 |
215 | for t in range(500):
216 | y_pred = x.mm(w1).clamp(min=0).mm(w2)
217 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
218 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
219 | print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
220 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
221 |
222 | # Run backprop
223 | optimizer.zero_grad()
224 | loss.backward()
225 |
226 | # Check for overflow
227 | has_overflow = DynamicLossScaler.has_overflow(parameters)
228 |
229 | # If no overflow, unscale grad and update as usual
230 | if not has_overflow:
231 | for param in parameters:
232 | param.grad.data.mul_(1. / loss_scaler.loss_scale)
233 | optimizer.step()
234 | # Otherwise, don't do anything -- ie, skip iteration
235 | else:
236 | print('OVERFLOW!')
237 |
238 | # Update loss scale for next iteration
239 | loss_scaler.update_scale(has_overflow)
240 |
241 | """
242 |
--------------------------------------------------------------------------------
/src/xl_wrapper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import random
5 | from typing import Union, Iterable
6 |
7 | import numpy as np
8 | import torch
9 | from deepspeed import DeepSpeedConfig
10 | from torch.nn import CrossEntropyLoss
11 | from transformers import GPT2Tokenizer, PreTrainedModel, PretrainedConfig
12 |
13 | from src import mpu
14 | from .fp16 import FP16_Module
15 | from .model import GPT3Model
16 | from .download_utils import download_model_files
17 | from transformers.utils import logging
18 |
19 |
20 | logger = logging.get_logger(__name__)
21 | NoneType = type(None)
22 |
23 |
24 | def get_deepspeed_config(path):
25 | return DeepSpeedConfig(path)
26 |
27 |
28 | def get_sparse_attention_config(path, num_heads):
29 | ds_config = get_deepspeed_config(path)
30 | if hasattr(ds_config, 'sparse_attention') and ds_config.sparse_attention:
31 | sa_config = ds_config.sparse_attention
32 | sa_mode = sa_config.get('mode')
33 | if sa_mode == 'dense':
34 | from deepspeed.ops.sparse_attention import DenseSparsityConfig as STConfig
35 | elif sa_mode == 'fixed':
36 | from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig
37 | elif sa_mode == 'bigbird':
38 | from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as STConfig
39 | elif sa_mode == 'bslongformer':
40 | from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as STConfig
41 | elif sa_mode == 'variable':
42 | from deepspeed.ops.sparse_attention import VariableSparsityConfig as STConfig
43 | else:
44 | raise NotImplementedError(
45 | f'Given sparsity mode, {sa_mode}, has not been implemented yet!'
46 | )
47 | del sa_config['mode']
48 | return STConfig(num_heads=num_heads, **sa_config)
49 | else:
50 | return None
51 |
52 |
53 | def get_model(deepspeed_config_path):
54 | num_local_heads = 16
55 | sparse_mode = 'alternating'
56 | deepspeed_sparsity_config = get_sparse_attention_config(deepspeed_config_path, num_local_heads)
57 | if deepspeed_sparsity_config is not None:
58 | logger.info(f"Use sparse attention with mode {sparse_mode}")
59 | else:
60 | logger.info(f"Use dense attention")
61 | model = GPT3Model(num_layers=24,
62 | vocab_size=50264,
63 | hidden_size=2048,
64 | num_attention_heads=num_local_heads,
65 | embedding_dropout_prob=0.1, attention_dropout_prob=0.1, output_dropout_prob=0.1,
66 | max_sequence_length=2048,
67 | checkpoint_activations=False,
68 | checkpoint_num_layers=1,
69 | parallel_output=False,
70 | deepspeed_sparsity_config=deepspeed_sparsity_config,
71 | sparse_mode=sparse_mode)
72 | # GPU allocation.
73 | model.cuda(torch.cuda.current_device())
74 |
75 | # Fp16 conversion.
76 | model = FP16_Module(model)
77 |
78 | return model
79 |
80 |
81 | def setup_model(weights_path, deepspeed_config_path):
82 | model = get_model(deepspeed_config_path)
83 | logger.info("Load checkpoint from " + weights_path)
84 | checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)['module']
85 | model.load_state_dict(checkpoint)
86 | model.eval()
87 | logger.info("Model Loaded")
88 | return model
89 |
90 |
91 | def get_masks_and_position_ids(data,
92 | eod_token,
93 | reset_position_ids,
94 | reset_attention_mask):
95 | # Extract batch size and sequence length.
96 | batch_size, seq_length = data.size()
97 |
98 | # Attention mask (lower triangular).
99 | if reset_attention_mask:
100 | att_mask_batch = batch_size
101 | else:
102 | att_mask_batch = 1
103 | attention_mask = torch.tril(torch.ones(
104 | (att_mask_batch, seq_length, seq_length), device=data.device)).view(
105 | att_mask_batch, 1, seq_length, seq_length)
106 |
107 | # Loss mask.
108 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
109 | loss_mask[data == eod_token] = 0.0
110 |
111 | # Position ids.
112 | position_ids = torch.arange(seq_length, dtype=torch.long,
113 | device=data.device)
114 | position_ids = position_ids.unsqueeze(0).expand_as(data)
115 |     # We need to clone as the ids will be modified based on batch index.
116 | if reset_position_ids:
117 | position_ids = position_ids.clone()
118 |
119 | if reset_position_ids or reset_attention_mask:
120 | # Loop through the batches:
121 | for b in range(batch_size):
122 |
123 | # Find indices where EOD token is.
124 | eod_index = position_ids[b, data[b] == eod_token]
125 |             # Detach indices from positions if going to modify positions.
126 | if reset_position_ids:
127 | eod_index = eod_index.clone()
128 |
129 | # Loop through EOD indices:
130 | prev_index = 0
131 | for j in range(eod_index.size()[0]):
132 | i = eod_index[j]
133 | # Mask attention loss.
134 | if reset_attention_mask:
135 | attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
136 | # Reset positions.
137 | if reset_position_ids:
138 | position_ids[b, (i + 1):] -= (i + 1 - prev_index)
139 | prev_index = i + 1
140 |
141 | return attention_mask, loss_mask, position_ids
142 |
143 |
144 | class ModelOutput(object):
145 | def __init__(self, logits, loss=None):
146 | self.logits = logits
147 | self.loss = loss
148 |
149 | def __getitem__(self, key):
150 | if key == "logits":
151 | return self.logits
152 | raise StopIteration
153 |
154 |
155 | class RuGPT3XL(PreTrainedModel):
156 | def __init__(self, model, tokenizer, model_path, seq_len=512, min_generated_len=32):
157 | super().__init__(PretrainedConfig())
158 | self.model = model
159 | self.pad_token_id = tokenizer.encoder['']
160 | self.eos_token_id = tokenizer.encoder['']
161 | self.seq_len = seq_len
162 | self.model_path = model_path
163 | self.tokenizer = tokenizer
164 | self.min_generated_len = min_generated_len
165 |
166 | @classmethod
167 | def from_pretrained(
168 | cls,
169 | model_name_or_path=None,
170 | seq_len=512,
171 | weights_path=None,
172 | deepspeed_config_path=None,
173 | master_port="6000",
174 | min_generated_len=32,
175 | rank=0
176 | ):
177 | init_method = 'tcp://' + os.getenv('MASTER_ADDR', 'localhost') + ':' + os.getenv('MASTER_PORT', master_port)
178 | try:
179 | torch.distributed.init_process_group(backend='nccl', world_size=1, rank=rank, init_method=init_method)
180 | mpu.initialize_model_parallel(1)
181 | except RuntimeError:
182 | logger.info("The default process group has already initialized...")
183 |
184 | seed = 1234
185 | random.seed(seed)
186 | np.random.seed(seed)
187 | torch.manual_seed(seed)
188 | tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
189 | logger.info("Check cached model files...")
190 | if weights_path is None:
191 | weights_path, deepspeed_config_path = download_model_files(model_name_or_path)
192 | model = setup_model(weights_path, deepspeed_config_path)
193 | mpu.model_parallel_cuda_manual_seed(seed)
194 | # model.cuda()
195 | model = model.eval()
196 | return cls(model, tokenizer=tokenizer, seq_len=seq_len, model_path=model_name_or_path, min_generated_len=min_generated_len)
197 |
198 | def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
199 | kwargs.update({"input_ids": input_ids})
200 | return kwargs
201 |
202 | def generate(
203 | self, text: Union[str, NoneType] = None,
204 | input_ids: Union[torch.LongTensor, NoneType] = None,
205 | max_length: Union[int, None] = None,
206 | min_length: Union[int, NoneType] = None,
207 | do_sample: Union[bool, NoneType] = None,
208 | early_stopping: Union[bool, NoneType] = None,
209 | num_beams: Union[int, NoneType] = None,
210 | temperature: Union[float, NoneType] = None,
211 | top_k: Union[int, NoneType] = None,
212 | top_p: Union[float, NoneType] = None,
213 | repetition_penalty: Union[float, NoneType] = None,
214 | bad_words_ids: Union[Iterable[int], NoneType] = None,
215 | bos_token_id: Union[int, NoneType] = None,
216 | pad_token_id: Union[int, NoneType] = None,
217 | eos_token_id: Union[int, NoneType] = None,
218 | length_penalty: Union[float, NoneType] = None,
219 | no_repeat_ngram_size: Union[int, NoneType] = None,
220 | num_return_sequences: Union[int, NoneType] = None,
221 | decoder_start_token_id: Union[int, NoneType] = None,
222 | use_cache: Union[bool, NoneType] = None,
223 | **model_kwargs):
224 | if text is not None:
225 | input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']])
226 | if eos_token_id is None:
227 | eos_token_id = self.eos_token_id
228 | if pad_token_id is None:
229 | pad_token_id = self.pad_token_id
230 | if input_ids.shape[-1] > 2048:
231 | input_ids = input_ids[:, -2048 + self.min_generated_len:]
232 | # print(input_ids.shape, max_length, mpu.get_data_parallel_rank())
233 | res = super().generate(
234 | input_ids=input_ids,
235 | max_length=max_length,
236 | min_length=min_length,
237 | do_sample=do_sample,
238 | early_stopping=early_stopping,
239 | num_beams=num_beams,
240 | temperature=temperature,
241 | top_k=top_k,
242 | top_p=top_p,
243 | repetition_penalty=repetition_penalty,
244 | bad_words_ids=bad_words_ids,
245 | bos_token_id=bos_token_id,
246 | pad_token_id=pad_token_id,
247 | eos_token_id=eos_token_id,
248 | length_penalty=length_penalty,
249 | no_repeat_ngram_size=no_repeat_ngram_size,
250 | num_return_sequences=num_return_sequences,
251 | decoder_start_token_id=decoder_start_token_id,
252 | use_cache=use_cache,
253 | **model_kwargs
254 | )
255 | return list(map(self.tokenizer.decode, res.tolist()))
256 |
257 | def __call__(self, text=None, input_ids=None, labels=None, **kwargs):
258 | if input_ids is None:
259 | if text is None:
260 | text = ""
261 | input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']])
262 | if isinstance(input_ids, list):
263 | input_ids = torch.cuda.LongTensor(input_ids)
264 | if isinstance(labels, list):
265 | labels = torch.cuda.LongTensor(labels)
266 | res = []
267 | if labels is not None:
268 | lbls = labels
269 | else:
270 | lbls = [None] * len(input_ids)
271 | loss = None
272 | original_context_length = 0
273 | seq_len = self.seq_len
274 | for tokens, lbl in zip(input_ids, lbls):
275 | context_tokens = tokens.tolist()
276 | original_context_length = len(context_tokens)
277 | if labels is not None:
278 | lbl = lbl.tolist()
279 | assert original_context_length
280 |
281 | while len(context_tokens) % 16:
282 | context_tokens.append(self.pad_token_id)
283 | if labels is not None:
284 | lbl.append(self.pad_token_id)
285 | context_tokens = context_tokens[-2048:]
286 | context_length = len(context_tokens)
287 | if labels is not None:
288 | lbl = lbl[-2048:]
289 | lbl = torch.cuda.LongTensor(lbl)
290 | context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
291 | context_length_tensor = torch.cuda.LongTensor([context_length])
292 |
293 | torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
294 | group=mpu.get_model_parallel_group())
295 | torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
296 | group=mpu.get_model_parallel_group())
297 |
298 | # context_length = context_length_tensor[0].item()
299 | # print(context_tokens_tensor.shape, original_context_length, seq_len, mpu.get_data_parallel_rank())
300 |
301 | tokens = context_tokens_tensor
302 | tokens = tokens.view(1, -1).contiguous()
303 | tokens = tokens.to(torch.cuda.current_device())
304 | attention_mask, loss_mask, position_ids = get_masks_and_position_ids(tokens, self.pad_token_id, False,
305 | False)
306 | lm_logits = self.model(tokens, position_ids, attention_mask)
307 | loss = None
308 | if labels is not None:
309 | # Shift so that tokens < n predict n
310 | shift_logits = lm_logits[..., :-1, :].contiguous()
311 | shift_labels = lbl[..., 1:].contiguous()
312 | # Flatten the tokens
313 | loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id)
314 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
315 | res.append((lm_logits, loss))
316 | logits = torch.cat([x[0] for x in res], dim=0)[:, : original_context_length, :]
317 | if loss is not None:
318 | loss = [x[1] for x in res]
319 | # print(logits.shape, mpu.get_data_parallel_rank(), "------------")
320 | return ModelOutput(logits, loss)
321 |
--------------------------------------------------------------------------------
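
For reference, a minimal usage sketch of the RuGPT3XL wrapper above. It assumes a CUDA machine with this repository's dependencies (apex, deepspeed, transformers) installed; the model name, port and sampling settings are illustrative assumptions, not values fixed by the code.

import os

# Mirror the single-process defaults used by from_pretrained(); the
# address/port values here are assumptions.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "6000")

from src.xl_wrapper import RuGPT3XL

# from_pretrained() starts a one-rank NCCL group, loads the tokenizer and the
# DeepSpeed checkpoint (downloading it when weights_path is None), and returns
# the wrapper in eval mode.
gpt = RuGPT3XL.from_pretrained("sberbank-ai/rugpt3xl", seq_len=512)  # model name is illustrative

# generate() tokenizes the prompt, clips it to the 2048-token context window,
# delegates to transformers' generate(), and decodes every returned sequence.
texts = gpt.generate(text="Example prompt", max_length=64, do_sample=True, top_k=50, top_p=0.95)
print(texts[0])
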
/src/mpu/layers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | # Parts of the code here are adapted from PyTorch
18 | # repo: https://github.com/pytorch/pytorch
19 |
20 |
21 | import math
22 |
23 | import torch
24 | import torch.nn.functional as F
25 | import torch.nn.init as init
26 | from torch.nn.parameter import Parameter
27 |
28 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
29 |
30 | from .initialize import get_model_parallel_rank
31 | from .initialize import get_model_parallel_world_size
32 | from .mappings import copy_to_model_parallel_region
33 | from .mappings import gather_from_model_parallel_region
34 | from .mappings import reduce_from_model_parallel_region
35 | from .mappings import scatter_to_model_parallel_region
36 | from .random import get_cuda_rng_tracker
37 | from .utils import divide
38 | from .utils import split_tensor_along_last_dim
39 | from .utils import VocabUtility
40 |
41 |
42 | def _initialize_affine_weight(weight, output_size, input_size,
43 | per_partition_size, partition_dim, init_method,
44 | stride=1, return_master_weight=False):
45 | """Initialize affine weight for model parallel.
46 |
47 | Build the master weight on all processes and scatter
48 | the relevant chunk."""
49 | # If we only use 1 process for model parallelism, bypass scatter.
50 | world_size = get_model_parallel_world_size()
51 | if world_size == 1:
52 | init_method(weight)
53 | if return_master_weight:
54 | return weight
55 | return None
56 |
57 | # Initialize master weight
58 | master_weight = torch.empty(output_size, input_size,
59 | dtype=weight.dtype,
60 | requires_grad=False)
61 | init_method(master_weight)
62 |
63 | # Split and copy
64 | per_partition_per_stride_size = divide(per_partition_size, stride)
65 | weight_list = torch.split(master_weight, per_partition_per_stride_size,
66 | dim=partition_dim)
67 | rank = get_model_parallel_rank()
68 | my_weight_list = weight_list[rank::world_size]
69 |
70 | with torch.no_grad():
71 | torch.cat(my_weight_list, dim=partition_dim, out=weight)
72 | if return_master_weight:
73 | return master_weight
74 | return None
75 |
76 |
77 | class VocabParallelEmbedding(torch.nn.Module):
78 | """Embedding parallelized in the vocabulary dimension.
79 |
80 | This is mainly adapted from torch.nn.Embedding and all the default
81 | values are kept.
82 | Arguments:
83 | num_embeddings: vocabulary size.
84 | embedding_dim: size of hidden state.
85 | init_method: method to initialize weights.
86 | """
87 | def __init__(self, num_embeddings, embedding_dim,
88 | init_method=init.xavier_normal_):
89 | super(VocabParallelEmbedding, self).__init__()
90 | # Keep the input dimensions.
91 | self.num_embeddings = num_embeddings
92 | self.embedding_dim = embedding_dim
93 |         # Set the defaults for compatibility.
94 | self.padding_idx = None
95 | self.max_norm = None
96 | self.norm_type = 2.
97 | self.scale_grad_by_freq = False
98 | self.sparse = False
99 | self._weight = None
100 |         # Divide the weight matrix along the vocabulary dimension.
101 | self.vocab_start_index, self.vocab_end_index = \
102 | VocabUtility.vocab_range_from_global_vocab_size(
103 | self.num_embeddings, get_model_parallel_rank(),
104 | get_model_parallel_world_size())
105 | self.num_embeddings_per_partition = self.vocab_end_index - \
106 | self.vocab_start_index
107 |
108 | # Allocate weights.
109 | self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
110 | self.embedding_dim))
111 | self.weight.model_parallel = True
112 | # And initialize.
113 | _initialize_affine_weight(
114 | self.weight, self.num_embeddings, self.embedding_dim,
115 | self.num_embeddings_per_partition, 0, init_method)
116 |
117 | def forward(self, input_):
118 | # Build the mask.
119 | input_mask = (input_ < self.vocab_start_index) | \
120 | (input_ >= self.vocab_end_index)
121 | # Mask the input.
122 | masked_input = input_.clone() - self.vocab_start_index
123 | masked_input[input_mask] = 0
124 | # Get the embeddings.
125 | output_parallel = F.embedding(masked_input, self.weight,
126 | self.padding_idx, self.max_norm,
127 | self.norm_type, self.scale_grad_by_freq,
128 | self.sparse)
129 | # Mask the output embedding.
130 | output_parallel[input_mask, :] = 0.0
131 | # Reduce across all the model parallel GPUs.
132 | output = reduce_from_model_parallel_region(output_parallel)
133 | return output
134 |
135 |
136 | class ParallelEmbedding(torch.nn.Module):
137 | """Embedding parallelized in the embedding dimension.
138 |
139 | This is mainly adapted from torch.nn.Embedding and all the default
140 | values are kept.
141 | Arguments:
142 | num_embeddings: vocabulary size.
143 | embedding_dim: size of hidden state.
144 | init_method: method to initialize weights.
145 | """
146 | def __init__(self, num_embeddings, embedding_dim,
147 | init_method=init.xavier_normal_,
148 | keep_master_weight_for_test=False):
149 | super(ParallelEmbedding, self).__init__()
150 | # Keep the input dimensions.
151 | self.num_embeddings = num_embeddings
152 | self.embedding_dim = embedding_dim
153 |         # Set some defaults for compatibility.
154 | self.padding_idx = None
155 | self.max_norm = None
156 | self.norm_type = 2.
157 | self.scale_grad_by_freq = False
158 | self.sparse = False
159 | self._weight = None
160 | # Divide the weight matrix along the embedding dimension.
161 | world_size = get_model_parallel_world_size()
162 | self.embedding_dim_per_partition = divide(self.embedding_dim,
163 | world_size)
164 |
165 | # Allocate weights.
166 | self.weight = Parameter(torch.Tensor(self.num_embeddings,
167 | self.embedding_dim_per_partition))
168 | self.weight.model_parallel = True
169 | # And initialize.
170 | _initialize_affine_weight(
171 | self.weight, self.num_embeddings, self.embedding_dim,
172 | self.embedding_dim_per_partition, 1, init_method,
173 | stride=1, return_master_weight=False)
174 |
175 | def forward(self, input_):
176 | input_parallel = copy_to_model_parallel_region(input_)
177 | output_parallel = F.embedding(input_parallel, self.weight,
178 | self.padding_idx, self.max_norm,
179 | self.norm_type, self.scale_grad_by_freq,
180 | self.sparse)
181 | output = gather_from_model_parallel_region(output_parallel)
182 | return output
183 |
184 |
185 | class ColumnParallelLinear(torch.nn.Module):
186 | """Linear layer with column parallelism.
187 |
188 | The linear layer is defined as Y = XA + b. A is parallelized along
189 | its second dimension as A = [A_1, ..., A_p].
190 |
191 | Arguments:
192 | input_size: first dimension of matrix A.
193 | output_size: second dimension of matrix A.
194 | bias: If true, add bias
195 |         gather_output: If true, call all-gather on output and make Y available
196 | to all GPUs, otherwise, every GPU will have its output
197 | which is Y_i = XA_i
198 | init_method: method to initialize weights. Note that bias is always set
199 | to zero.
200 | stride: For the strided linear layers.
201 | keep_master_weight_for_test: This was added for testing and should be
202 | set to False. It returns the master weights
203 | used for initialization.
204 | """
205 | def __init__(self, input_size, output_size, bias=True, gather_output=True,
206 | init_method=init.xavier_normal_, stride=1,
207 | keep_master_weight_for_test=False):
208 | super(ColumnParallelLinear, self).__init__()
209 |
210 | # Keep input parameters
211 | self.input_size = input_size
212 | self.output_size = output_size
213 | self.gather_output = gather_output
214 | # Divide the weight matrix along the last dimension.
215 | world_size = get_model_parallel_world_size()
216 | self.output_size_per_partition = divide(output_size, world_size)
217 |
218 | # Parameters.
219 | # Note: torch.nn.functional.linear performs XA^T + b and as a result
220 | # we allocate the transpose.
221 | self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
222 | self.input_size))
223 | self.weight.model_parallel = True
224 | if bias:
225 | self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
226 | self.bias.model_parallel = True
227 | # Always initialize bias to zero.
228 | with torch.no_grad():
229 | self.bias.zero_()
230 | else:
231 | self.register_parameter('bias', None)
232 |
233 | # Initialize weight.
234 | self.master_weight = _initialize_affine_weight(
235 | self.weight, self.output_size, self.input_size,
236 | self.output_size_per_partition, 0, init_method,
237 | stride=stride, return_master_weight=keep_master_weight_for_test)
238 |
239 | def forward(self, input_):
240 | # Set up backprop all-reduce.
241 | input_parallel = copy_to_model_parallel_region(input_)
242 | # Matrix multiply.
243 | output_parallel = F.linear(input_parallel, self.weight, self.bias)
244 | if self.gather_output:
245 | # All-gather across the partitions.
246 | output = gather_from_model_parallel_region(output_parallel)
247 | else:
248 | output = output_parallel
249 | return output
250 |
251 |
252 | class RowParallelLinear(torch.nn.Module):
253 | """Linear layer with row parallelism.
254 |
255 | The linear layer is defined as Y = XA + b. A is parallelized along
256 | its first dimension and X along its second dimension as:
257 | - -
258 | | A_1 |
259 | | . |
260 | A = | . | X = [X_1, ..., X_p]
261 | | . |
262 | | A_p |
263 | - -
264 | Arguments:
265 | input_size: first dimension of matrix A.
266 | output_size: second dimension of matrix A.
267 | bias: If true, add bias. Note that bias is not parallelized.
268 | input_is_parallel: If true, we assume that the input is already
269 | split across the GPUs and we do not split
270 | again.
271 | init_method: method to initialize weights. Note that bias is always set
272 | to zero.
273 | stride: For the strided linear layers.
274 | keep_master_weight_for_test: This was added for testing and should be
275 | set to False. It returns the master weights
276 | used for initialization.
277 | """
278 | def __init__(self, input_size, output_size, bias=True,
279 | input_is_parallel=False,
280 | init_method=init.xavier_normal_, stride=1,
281 | keep_master_weight_for_test=False):
282 | super(RowParallelLinear, self).__init__()
283 |
284 | # Keep input parameters
285 | self.input_size = input_size
286 | self.output_size = output_size
287 | self.input_is_parallel = input_is_parallel
288 | # Divide the weight matrix along the last dimension.
289 | world_size = get_model_parallel_world_size()
290 | self.input_size_per_partition = divide(input_size, world_size)
291 |
292 | # Parameters.
293 | # Note: torch.nn.functional.linear performs XA^T + b and as a result
294 | # we allocate the transpose.
295 | self.weight = Parameter(torch.Tensor(self.output_size,
296 | self.input_size_per_partition))
297 | self.weight.model_parallel = True
298 | if bias:
299 | self.bias = Parameter(torch.Tensor(self.output_size))
300 | # Always initialize bias to zero.
301 | with torch.no_grad():
302 | self.bias.zero_()
303 | else:
304 | self.register_parameter('bias', None)
305 |
306 | # Initialize weight.
307 | self.master_weight = _initialize_affine_weight(
308 | self.weight, self.output_size, self.input_size,
309 | self.input_size_per_partition, 1, init_method,
310 | stride=stride, return_master_weight=keep_master_weight_for_test)
311 |
312 | def forward(self, input_):
313 | # Set up backprop all-reduce.
314 | if self.input_is_parallel:
315 | input_parallel = input_
316 | else:
317 | input_parallel = scatter_to_model_parallel_region(input_)
318 | # Matrix multiply.
319 | output_parallel = F.linear(input_parallel, self.weight)
320 | # All-reduce across all the partitions.
321 | output_ = reduce_from_model_parallel_region(output_parallel)
322 | if self.bias is not None:
323 | output = output_ + self.bias
324 | else:
325 | output = output_
326 | return output
327 |
328 |
--------------------------------------------------------------------------------
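
A single-process sketch of the partitioning arithmetic that ColumnParallelLinear and RowParallelLinear implement above: the "world size" is simulated with plain tensor splits, so no process groups or repo imports are needed.

import torch

torch.manual_seed(0)
batch, d_in, d_out, world_size = 2, 8, 12, 4

x = torch.randn(batch, d_in)
a = torch.randn(d_in, d_out)  # the full weight A in Y = X A

# Column parallelism: split A along its output dimension; each "rank" computes
# X A_i, and concatenating the partial outputs (the all-gather) recovers Y.
a_cols = torch.chunk(a, world_size, dim=1)
y_col = torch.cat([x @ a_i for a_i in a_cols], dim=1)

# Row parallelism: split A along its input dimension and X along its last
# dimension; summing the partial products (the all-reduce) recovers Y.
a_rows = torch.chunk(a, world_size, dim=0)
x_parts = torch.chunk(x, world_size, dim=1)
y_row = sum(x_i @ a_i for x_i, a_i in zip(x_parts, a_rows))

assert torch.allclose(y_col, x @ a, atol=1e-5)
assert torch.allclose(y_row, x @ a, atol=1e-5)
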
/src/mpu/random.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | #Modified by Samyam Rajbhandari
3 | #Used to partition the activations stored for backward propagation
4 | #Therefore reduces the memory consumption
5 |
6 | # Copyright (c) 2020, Sber. All rights reserved.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 |
21 | # Parts of the code here are adapted from PyTorch
22 | # repo: https://github.com/pytorch/pytorch
23 | import contextlib
24 | import torch.distributed as dist
25 | import torch
26 | from torch import _C
27 | from torch.cuda import _lazy_call, device as device_ctx_manager
28 | #from torch.utils.checkpoint import detach_variable
29 |
30 |
31 | import torch.distributed as dist
32 | PARTITION_ACTIVATIONS = False
33 | PA_CORRECTNESS_TEST= False
34 |
35 | def see_memory_usage(message, force=False):
36 | if not force:
37 | return
38 | dist.barrier()
39 | if dist.get_rank() == 0:
40 | print(message)
41 | print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
42 | print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
43 | print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
44 | print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
45 | print(" ")
46 | #input("Press Any Key To Continue ..")
47 |
48 |
49 | from .initialize import get_data_parallel_rank
50 | from .initialize import get_model_parallel_rank
51 | from .initialize import get_model_parallel_world_size
52 | from .initialize import get_model_parallel_group
53 |
54 | mp_rank = None #get_model_parallel_rank()
55 | mp_size = None #get_model_parallel_world_size()
56 | mp_group = None #get_model_parallel_group()
57 |
58 | # Default name for the model parallel rng tracker.
59 | _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
60 | transport_stream = None
61 | cuda_device=None
62 | def detach_variable(inputs, device=None):
63 | if isinstance(inputs, tuple):
64 | out = []
65 | for inp in inputs:
66 | if not isinstance(inp, torch.Tensor):
67 | out.append(inp)
68 | continue
69 |
70 | requires_grad = inp.requires_grad
71 |
72 | if device is not None:
73 | x = inp.to(device=device)
74 | else:
75 | x = inp
76 |
77 | x = x.detach()
78 | x.requires_grad = requires_grad
79 | out.append(x)
80 | return tuple(out)
81 | else:
82 | raise RuntimeError(
83 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
84 |
85 | def _set_cuda_rng_state(new_state, device=-1):
86 | """Sets the random number generator state of the current GPU.
87 |
88 |     Arguments:
89 | new_state (torch.ByteTensor): The desired state
90 | This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
91 | with a single change: the input state is not cloned. Cloning caused
92 |     major performance issues in cases with 4 or more GPUs.
93 | """
94 | if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
95 | # older PyTorch
96 | def cb():
97 | with device_ctx_manager(device):
98 | _C._cuda_setRNGState(new_state)
99 | else:
100 | # newer PyTorch
101 | if device == -1:
102 | device = torch.device('cuda')
103 | elif isinstance(device, str):
104 | device = torch.device(device)
105 | elif isinstance(device, int):
106 | device = torch.device('cuda', device)
107 |
108 | def cb():
109 | idx = device.index
110 | if idx is None:
111 | idx = torch.cuda.current_device()
112 | default_generator = torch.cuda.default_generators[idx]
113 | default_generator.set_state(new_state)
114 |
115 | _lazy_call(cb)
116 |
117 |
118 |
119 | class CudaRNGStatesTracker:
120 | """Tracker for the cuda RNG states.
121 |
122 | Using the `add` method, a cuda rng state is initialized based on
123 | the input `seed` and is assigned to `name`. Later, by forking the
124 | rng state, we can perform operations and return to our starting
125 | cuda state.
126 | """
127 | def __init__(self):
128 | # Map from a string name to the cuda rng state.
129 | self.states_ = {}
130 |         # Seeds are just for bookkeeping and ensure no seed is set twice.
131 | self.seeds_ = set()
132 |
133 | def reset(self):
134 | """Set to the initial state (no tracker)."""
135 | self.states_ = {}
136 | self.seeds_ = set()
137 |
138 | def get_states(self):
139 | """Get rng states. Copy the dictionary so we have direct
140 | pointers to the states, not just a pointer to the dictionary."""
141 | states = {}
142 | for name in self.states_:
143 | states[name] = self.states_[name]
144 | return states
145 |
146 | def set_states(self, states):
147 | """Set the rng states. For efficiency purposes, we do not check
148 | the size of seed for compatibility."""
149 | self.states_ = states
150 |
151 | def add(self, name, seed):
152 | """Track the rng state."""
153 | # Check seed is not already used.
154 | if seed in self.seeds_:
155 | raise Exception('seed {} already exists'.format(seed))
156 | self.seeds_.add(seed)
157 | # Check that state is not already defined.
158 | if name in self.states_:
159 | raise Exception('cuda rng state {} already exists'.format(name))
160 | # Get the current rng state.
161 | orig_rng_state = torch.cuda.get_rng_state()
162 | # Set the new state and store it.
163 | torch.cuda.manual_seed(seed)
164 | self.states_[name] = torch.cuda.get_rng_state()
165 | # Reset rng state to what it was.
166 | _set_cuda_rng_state(orig_rng_state)
167 |
168 | @contextlib.contextmanager
169 | def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
170 | """Fork the cuda rng state, perform operations, and exit with
171 | the original state."""
172 | # Check if we have added the state
173 | if name not in self.states_:
174 | raise Exception('cuda rng state {} is not added'.format(name))
175 | # Store current rng state.
176 | orig_cuda_rng_state = torch.cuda.get_rng_state()
177 | # Set rng state to the desired one
178 | _set_cuda_rng_state(self.states_[name])
179 | # Do the stuff we wanted to do.
180 | try:
181 | yield
182 | finally:
183 | # Update the current rng state for later use.
184 | self.states_[name] = torch.cuda.get_rng_state()
185 | # And set the state to the original state we started with.
186 | _set_cuda_rng_state(orig_cuda_rng_state)
187 |
188 |
189 | # RNG tracker object.
190 | _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
191 |
192 |
193 | def get_cuda_rng_tracker():
194 | """Get cuda rng tracker."""
195 | return _CUDA_RNG_STATE_TRACKER
196 |
197 |
198 | def model_parallel_cuda_manual_seed(seed):
199 | """Initialize model parallel cuda seed.
200 |
201 | This function should be called after the model parallel is
202 | initialized. Also, no torch.cuda.manual_seed should be called
203 |     after this function. Basically, this is a replacement for that
204 | function.
205 |     Two sets of RNG states are tracked:
206 | default state: This is for data parallelism and is the same among a
207 | set of model parallel GPUs but different across
208 |                        different model parallel groups. This is used for
209 | example for dropout in the non-model-parallel regions.
210 | model-parallel state: This state is different among a set of model
211 | parallel GPUs, but the same across data parallel
212 | groups. This is used for example for dropout in
213 | model parallel regions.
214 | """
215 | # 2718 is just for fun and any POSITIVE value will work.
216 | offset = seed + 2718
217 | model_parallel_seed = offset + get_model_parallel_rank()
218 |     # Data parallel gets the original seed.
219 | data_parallel_seed = seed
220 |
221 | if torch.distributed.get_rank() == 0:
222 | print('> initializing model parallel cuda seeds on global rank {}, '
223 | 'model parallel rank {}, and data parallel rank {} with '
224 | 'model parallel seed: {} and data parallel seed: {}'.format(
225 | torch.distributed.get_rank(), get_model_parallel_rank(),
226 | get_data_parallel_rank(), model_parallel_seed,
227 | data_parallel_seed), flush=True)
228 | _CUDA_RNG_STATE_TRACKER.reset()
229 | # Set the default state.
230 | torch.cuda.manual_seed(data_parallel_seed)
231 | # and model parallel state.
232 | _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
233 | model_parallel_seed)
234 |
235 |
236 | def get_partition_start(item):
237 | global mp_rank, mp_size, mp_group
238 | partition_size = get_partition_size(item)
239 | start = partition_size * mp_rank
240 | return int(start)
241 |
242 | def get_partition_size(item):
243 | global mp_rank, mp_size, mp_group
244 | size = item.numel()
245 | partition_size = size/mp_size
246 | return int(partition_size)
247 |
248 | def get_full_inputs(tensors):
249 | inputs=[]
250 | for i in range(int(len(tensors)/2)-1):
251 | item = tensors[2 * i]
252 | size = tensors[2* i + 1]
253 | partition_size = item.numel()
254 | tensor_size = partition_size * mp_size
255 | flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
256 | partitions=[]
257 | for i in range(mp_size):
258 | part_i = flat_tensor.narrow(0, partition_size * i , partition_size)
259 | if i == mp_rank:
260 | part_i.copy_(item)
261 | partitions.append(part_i)
262 | dist.all_gather(partitions,partitions[mp_rank], group=mp_group)
263 | input_tensor = flat_tensor.view(list(size.numpy()))
264 | item.data=input_tensor.data
265 |
266 | inputs.append(item)
267 | inputs.append(tensors[-2])
268 |
269 | return tuple(inputs)
270 |
271 |
272 |
273 | class CheckpointFunction(torch.autograd.Function):
274 | """This function is adapted from torch.utils.checkpoint with
275 | two main changes:
276 | 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
277 | 2) the states in the model parallel tracker are also properly
278 | tracked/set/reset.
279 | """
280 | @staticmethod
281 | def forward(ctx, run_function, *args):
282 | ctx.run_function = run_function
283 | global mp_rank, mp_size, mp_group
284 | if mp_rank is None:
285 | mp_rank = get_model_parallel_rank()
286 | mp_size = get_model_parallel_world_size()
287 | mp_group = get_model_parallel_group()
288 |
289 |
290 | global cuda_device, transport_stream, PARTITION_ACTIVATIONS
291 | if cuda_device is None:
292 | if dist.get_rank() == 0:
293 | print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}")
294 |
295 | cuda_device = torch.cuda.current_device()
296 | #The transport stream is used to overlap the allgather communication for the activations
297 | #with the computation in the backward pass
298 | transport_stream = torch.cuda.Stream(device=cuda_device)
299 |
300 | if PARTITION_ACTIVATIONS:
301 | inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
302 | inputs.append(args[-1])
303 |
304 | #just in case something funky is happening such as reuse of inputs
305 | inputs_cuda = [item.to(cuda_device) for item in args]
306 |
307 | # Copy the rng states.
308 | ctx.fwd_cpu_rng_state = torch.get_rng_state()
309 | ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
310 | ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
311 |
312 | #ctx.save_for_backward(*args)
313 | with torch.no_grad():
314 | outputs = run_function(*inputs_cuda)
315 |
316 | del inputs_cuda
317 |
318 | if PARTITION_ACTIVATIONS:
319 | new_args = []
320 | for arg, inp in zip(args,inputs):
321 | size= torch.tensor(arg.size())
322 | arg.data = inp.data
323 | new_args.append(arg)
324 | new_args.append(size)
325 | ctx.save_for_backward(*new_args)
326 | else:
327 | ctx.save_for_backward(*args)
328 |
329 | return outputs
330 |
331 | @staticmethod
332 | def backward(ctx, *args):
333 | if not torch.autograd._is_checkpoint_valid():
334 | raise RuntimeError("Checkpointing is not compatible with .grad(), "
335 | "please use .backward() if possible")
336 |
337 | global cuda_device, transport_stream, PARTITION_ACTIVATIONS
338 |
339 | if PARTITION_ACTIVATIONS:
340 | with torch.cuda.stream(transport_stream):
341 | inputs = get_full_inputs(ctx.saved_tensors)
342 | detached_inputs = detach_variable(inputs)
343 | else:
344 | inputs = ctx.saved_tensors
345 | detached_inputs = detach_variable(inputs)
346 |
347 | # Store the current states.
348 | bwd_cpu_rng_state = torch.get_rng_state()
349 | bwd_cuda_rng_state = torch.cuda.get_rng_state()
350 | bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
351 |
352 | # Set the states to what it used to be before the forward pass.
353 | torch.set_rng_state(ctx.fwd_cpu_rng_state)
354 | _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
355 | get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
356 |
357 | if PARTITION_ACTIVATIONS:
358 | current_stream=torch.cuda.current_stream()
359 | current_stream.wait_stream(transport_stream)
360 |
361 | with torch.enable_grad():
362 | outputs = ctx.run_function(*detached_inputs)
363 |
364 | # Set the states back to what it was at the start of this function.
365 | torch.set_rng_state(bwd_cpu_rng_state)
366 | _set_cuda_rng_state(bwd_cuda_rng_state)
367 | get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
368 |
369 | if isinstance(outputs, torch.Tensor):
370 | outputs = (outputs,)
371 | torch.autograd.backward(outputs, args)
372 | return (None,) + tuple(inp.grad for inp in detached_inputs)
373 |
374 |
375 | def checkpoint(function, *args):
376 | """Checkpoint a model or part of the model.
377 | This has been directly copied from torch.utils.checkpoint."""
378 | return CheckpointFunction.apply(function, *args)
379 |
380 | def partition_activations_in_checkpoint(partition_activation):
381 | global PARTITION_ACTIVATIONS
382 | PARTITION_ACTIVATIONS=partition_activation
383 | if dist.get_rank() == 0:
384 | print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
385 |
386 |
387 |
--------------------------------------------------------------------------------
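
A short sketch of how the RNG tracker above is meant to be used, e.g. so that dropout inside model-parallel regions draws from a separate stream. It assumes a CUDA device and that this repository's dependencies are installed; the stream name and seed are arbitrary.

import torch
from src.mpu.random import get_cuda_rng_tracker

# In training code this registration is done by model_parallel_cuda_manual_seed(),
# which seeds the default stream and adds the 'model-parallel-rng' state after
# mpu.initialize_model_parallel(); here we register a throwaway state by hand.
tracker = get_cuda_rng_tracker()
tracker.add("example-stream", 42)

with tracker.fork("example-stream"):
    # Draws inside the context consume the "example-stream" state; the default
    # CUDA RNG state is restored when the context exits.
    noise = torch.rand(4, device="cuda")
print(noise)
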
/src/mpu/transformer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Transformer."""
17 |
18 | import math
19 |
20 | import torch
21 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
22 | from .initialize import get_model_parallel_world_size
23 | from .layers import ColumnParallelLinear
24 | from .layers import RowParallelLinear
25 | from .utils import divide
26 | from .utils import split_tensor_along_last_dim
27 | from src.utils import DEEPSPEED_WRAP
28 |
29 |
30 | class GPT3ParallelSelfAttention(torch.nn.Module):
31 | """Parallel self-attention layer for GPT3.
32 |
33 | Self-attention layer takes input with size [b, s, h] where b is
34 |     the batch size, s is the sequence length, and h is the hidden size
35 | and creates output of the same size.
36 | Arguments:
37 | hidden_size: total hidden size of the layer (h).
38 | num_attention_heads: number of attention heads (n). Note that we
39 | require n to be divisible by number of GPUs
40 | used to parallelize the model. Also, we
41 | require hidden size to be divisible by n.
42 | dropout_prob: dropout probability for the attention scores.
43 | init_method: weight initialization.
44 | output_layer_init_method: output layer initialization. If None, use
45 | `init_method`.
46 | We use the following notation:
47 | h: hidden_size
48 | n: num_attention_heads
49 | p: number of partitions
50 | np: n/p
51 | hp: h/p
52 | hn: h/n
53 | b: batch size
54 | s: sequence length
55 | """
56 |
57 | def __init__(self, hidden_size, num_attention_heads,
58 | attention_dropout_prob, output_dropout_prob,
59 | init_method, output_layer_init_method=None,
60 | use_deepspeed_sparse=None):
61 | super(GPT3ParallelSelfAttention, self).__init__()
62 | self.use_deepspeed_sparse = use_deepspeed_sparse
63 | if DEEPSPEED_WRAP:
64 | deepspeed = DEEPSPEED_WRAP.deepspeed
65 | from deepspeed.ops.sparse_attention import SparseSelfAttention
66 | if self.use_deepspeed_sparse is not None:
67 | self.sparse_self_attention = SparseSelfAttention(self.use_deepspeed_sparse)
68 | # Set output layer initialization if not provided.
69 | if output_layer_init_method is None:
70 | output_layer_init_method = init_method
71 | # Per attention head and per partition values.
72 | world_size = get_model_parallel_world_size()
73 | self.hidden_size_per_partition = divide(hidden_size, world_size)
74 | self.hidden_size_per_attention_head = divide(hidden_size,
75 | num_attention_heads)
76 | self.num_attention_heads_per_partition = divide(num_attention_heads,
77 | world_size)
78 | # Strided linear layer.
79 | self.query_key_value = ColumnParallelLinear(hidden_size, 3 * hidden_size,
80 | stride=3,
81 | gather_output=False,
82 | init_method=init_method)
83 | # Dropout. Note that for a single iteration, this layer will generate
84 | # different outputs on different number of parallel partitions but
85 | # on average it should not be partition dependent.
86 | self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
87 |
88 | # Output.
89 | self.dense = RowParallelLinear(hidden_size,
90 | hidden_size,
91 | input_is_parallel=True,
92 | init_method=output_layer_init_method)
93 | self.output_dropout = torch.nn.Dropout(output_dropout_prob)
94 |
95 | if DEEPSPEED_WRAP:
96 | if DEEPSPEED_WRAP.deepspeed.checkpointing.is_configured():
97 | global get_cuda_rng_tracker, checkpoint
98 | get_cuda_rng_tracker = DEEPSPEED_WRAP.deepspeed.checkpointing.get_cuda_rng_tracker
99 | checkpoint = deepspeed.checkpointing.checkpoint
100 |
101 | def _transpose_for_scores(self, tensor):
102 | """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
103 | size [b, np, s, hn].
104 | """
105 | new_tensor_shape = tensor.size()[:-1] + \
106 | (self.num_attention_heads_per_partition,
107 | self.hidden_size_per_attention_head)
108 | tensor = tensor.view(*new_tensor_shape)
109 | return tensor.permute(0, 2, 1, 3)
110 |
111 | def forward(self, hidden_states, ltor_mask):
112 | # hidden_states: [b, s, h]
113 | # ltor_mask: [1, 1, s, s]
114 |
115 | # Attention heads. [b, s, hp]
116 | mixed_x_layer = self.query_key_value(hidden_states)
117 | (mixed_query_layer,
118 | mixed_key_layer,
119 | mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
120 |
121 | # Reshape and transpose [b, np, s, hn]
122 | query_layer = self._transpose_for_scores(mixed_query_layer)
123 | key_layer = self._transpose_for_scores(mixed_key_layer)
124 | value_layer = self._transpose_for_scores(mixed_value_layer)
125 |
126 | if self.use_deepspeed_sparse:
127 | context_layer = self.sparse_self_attention(
128 | query_layer,
129 | key_layer,
130 | value_layer,
131 | attn_mask=ltor_mask)
132 | else:
133 | # Raw attention scores. [b, np, s, s]
134 | attention_scores = torch.matmul(query_layer,
135 | key_layer.transpose(-1, -2))
136 | attention_scores = attention_scores / math.sqrt(
137 | self.hidden_size_per_attention_head)
138 | # Apply the left to right attention mask.
139 | attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
140 |
141 | # Attention probabilities. [b, np, s, s]
142 | attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
143 | # This is actually dropping out entire tokens to attend to, which might
144 | # seem a bit unusual, but is taken from the original Transformer paper.
145 | if DEEPSPEED_WRAP and DEEPSPEED_WRAP.deepspeed.checkpointing.is_configured():
146 | with get_cuda_rng_tracker().fork():
147 | attention_probs = self.attention_dropout(attention_probs)
148 | else:
149 | attention_probs = self.attention_dropout(attention_probs)
150 |
151 | # Context layer.
152 | # [b, np, s, hn]
153 | context_layer = torch.matmul(attention_probs, value_layer)
154 | # [b, s, np, hn]
155 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
156 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
157 | # [b, s, hp]
158 | context_layer = context_layer.view(*new_context_layer_shape)
159 |
160 | # Output. [b, s, h]
161 | output = self.dense(context_layer)
162 | output = self.output_dropout(output)
163 |
164 | return output
165 |
166 |
167 | # @torch.jit.script
168 | # Remove torch.jit for colab
169 | def gelu_impl(x):
170 | """OpenAI's gelu implementation."""
171 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
172 | (1.0 + 0.044715 * x * x)))
173 |
174 |
175 | def gelu(x):
176 | return gelu_impl(x)
177 |
178 |
179 | class GPT3ParallelMLP(torch.nn.Module):
180 | """MLP for GPT3.
181 |
182 | MLP will take the input with h hidden state, project it to 4*h
183 | hidden dimension, perform gelu transformation, and project the
184 | state back into h hidden dimension. At the end, dropout is also
185 | applied.
186 |
187 | Arguments:
188 | hidden_size: The hidden size of the self attention.
189 | output_dropout_prob: dropout probability for the outputs
190 | after self attention and final output.
191 | init_method: initialization method used for the weights. Note
192 | that all biases are initialized to zero and
193 | layernorm weight are initialized to one.
194 | output_layer_init_method: output layer initialization. If None,
195 | use `init_method`.
196 | """
197 |
198 | def __init__(self, hidden_size, output_dropout_prob, init_method,
199 | output_layer_init_method=None):
200 | super(GPT3ParallelMLP, self).__init__()
201 | # Set output layer initialization if not provided.
202 | if output_layer_init_method is None:
203 | output_layer_init_method = init_method
204 | # Project to 4h.
205 | self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size,
206 | gather_output=False,
207 | init_method=init_method)
208 | # Project back to h.
209 | self.dense_4h_to_h = RowParallelLinear(
210 | 4 * hidden_size,
211 | hidden_size,
212 | input_is_parallel=True,
213 | init_method=output_layer_init_method)
214 | self.dropout = torch.nn.Dropout(output_dropout_prob)
215 |
216 | def forward(self, hidden_states):
217 | # [b, s, 4hp]
218 | intermediate_parallel = self.dense_h_to_4h(hidden_states)
219 | intermediate_parallel = gelu(intermediate_parallel)
220 |
221 | # [b, s, h]
222 | output = self.dense_4h_to_h(intermediate_parallel)
223 | output = self.dropout(output)
224 | return output
225 |
226 |
227 | class GPT3ParallelTransformerLayer(torch.nn.Module):
228 | """A single layer transformer for GPT3.
229 |
230 | We use the following notation:
231 | h: hidden size
232 | n: number of attention heads
233 | b: batch size
234 | s: sequence length
235 |     Transformer layer takes input with size [b, s, h] and returns an
236 | output of the same size.
237 |
238 | Arguments:
239 | hidden_size: The hidden size of the self attention.
240 |         num_attention_heads: number of attention heads in the self
241 | attention.
242 | attention_dropout_prob: dropout probability of the attention
243 | score in self attention.
244 | output_dropout_prob: dropout probability for the outputs
245 | after self attention and final output.
246 | layernorm_epsilon: epsilon used in layernorm to avoid
247 | division by zero.
248 | init_method: initialization method used for the weights. Note
249 | that all biases are initialized to zero and
250 | layernorm weight are initialized to one.
251 | output_layer_init_method: output layers (attention output and
252 | mlp output) initialization. If None,
253 | use `init_method`.
254 | """
255 |
256 | def __init__(self,
257 | hidden_size,
258 | num_attention_heads,
259 | attention_dropout_prob,
260 | output_dropout_prob,
261 | layernorm_epsilon,
262 | init_method,
263 | output_layer_init_method=None,
264 | use_deepspeed_sparse=None):
265 | super(GPT3ParallelTransformerLayer, self).__init__()
266 | # Set output layer initialization if not provided.
267 | if output_layer_init_method is None:
268 | output_layer_init_method = init_method
269 |
270 | # Layernorm on the input data.
271 | self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
272 |
273 | # Self attention.
274 | self.attention = GPT3ParallelSelfAttention(
275 | hidden_size,
276 | num_attention_heads,
277 | attention_dropout_prob,
278 | output_dropout_prob,
279 | init_method,
280 | output_layer_init_method=output_layer_init_method,
281 | use_deepspeed_sparse=use_deepspeed_sparse)
282 |
283 |         # Layernorm on the attention output.
284 | self.post_attention_layernorm = LayerNorm(hidden_size,
285 | eps=layernorm_epsilon)
286 |
287 | # MLP
288 | self.mlp = GPT3ParallelMLP(
289 | hidden_size,
290 | output_dropout_prob,
291 | init_method,
292 | output_layer_init_method=output_layer_init_method)
293 |
294 | def forward(self, hidden_states, ltor_mask):
295 | # hidden_states: [b, s, h]
296 | # ltor_mask: [1, 1, s, s]
297 |
298 |         # Layer norm at the beginning of the transformer layer.
299 | layernorm_output = self.input_layernorm(hidden_states)
300 | # Self attention.
301 | attention_output = self.attention(layernorm_output, ltor_mask)
302 | # Residual connection.
303 | layernorm_input = hidden_states + attention_output
304 | # Layer norm post the self attention.
305 | layernorm_output = self.post_attention_layernorm(layernorm_input)
306 | # MLP.
307 | mlp_output = self.mlp(layernorm_output)
308 | # Second residual connection.
309 | output = layernorm_input + mlp_output
310 |
311 | return output
312 |
313 |
314 | def unscaled_init_method(sigma):
315 | """Init method based on N(0, sigma)."""
316 |
317 | def init_(tensor):
318 | return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
319 |
320 | return init_
321 |
322 |
323 | def scaled_init_method(sigma, num_layers):
324 |     """Init method based on N(0, sigma/sqrt(2*num_layers))."""
325 | std = sigma / math.sqrt(2.0 * num_layers)
326 |
327 | def init_(tensor):
328 | return torch.nn.init.normal_(tensor, mean=0.0, std=std)
329 |
330 | return init_
331 |
332 |
333 | class GPT3ParallelTransformer(torch.nn.Module):
334 | """GPT-3 transformer.
335 |
336 |     This module takes input from the embedding layer and its output can
337 | be used directly by a logit layer. It consists of L (num-layers)
338 | blocks of:
339 | layer norm
340 | self attention
341 | residual connection
342 | layer norm
343 | mlp
344 | residual connection
345 | followed by a final layer norm.
346 |
347 | Arguments:
348 | num_layers: Number of transformer layers.
349 | hidden_size: The hidden size of the self attention.
350 |         num_attention_heads: number of attention heads in the self
351 | attention.
352 | attention_dropout_prob: dropout probability of the attention
353 | score in self attention.
354 | output_dropout_prob: dropout probability for the outputs
355 | after self attention and final output.
356 | checkpoint_activations: if True, checkpoint activations.
357 | checkpoint_num_layers: number of layers to checkpoint. This
358 |                                is basically the chunk size in checkpointing.
359 | layernorm_epsilon: epsilon used in layernorm to avoid
360 | division by zero.
361 | init_method_std: standard deviation of the init method which has
362 | the form N(0, std).
363 |         use_scaled_init_for_output_weights: If True use 1/sqrt(2*num_layers)
364 | scaling for the output weights (
365 | output of self attention and mlp).
366 | """
367 |
368 | def __init__(self,
369 | num_layers,
370 | hidden_size,
371 | num_attention_heads,
372 | attention_dropout_prob,
373 | output_dropout_prob,
374 | checkpoint_activations,
375 | checkpoint_num_layers=1,
376 | layernorm_epsilon=1.0e-5,
377 | init_method_std=0.02,
378 | use_scaled_init_for_output_weights=True,
379 | use_deepspeed_sparse=None,
380 | sparse_mode='all'):
381 | super(GPT3ParallelTransformer, self).__init__()
382 |
383 | if DEEPSPEED_WRAP:
384 | from deepspeed.ops.sparse_attention import SparseSelfAttention
385 |
386 |         # Store activation checkpointing flag.
387 | self.checkpoint_activations = checkpoint_activations
388 | self.checkpoint_num_layers = checkpoint_num_layers
389 |
390 | output_layer_init_method = None
391 | if use_scaled_init_for_output_weights:
392 | output_layer_init_method = scaled_init_method(init_method_std,
393 | num_layers)
394 | if use_deepspeed_sparse and sparse_mode == 'alternating':
395 | print('Use alternating sparse & dense attention layers')
396 |
397 | def get_layer(layer_num, num_layers):
398 | sparsity_config = use_deepspeed_sparse
399 | if use_deepspeed_sparse:
400 |                 if sparse_mode == 'alternating' and layer_num % 2: # odd-indexed layers are dense
401 | sparsity_config = None
402 | elif sparse_mode == 'top_bottom' and layer_num >= num_layers // 2: # top levels are dense
403 | sparsity_config = None
404 | return GPT3ParallelTransformerLayer(
405 | hidden_size,
406 | num_attention_heads,
407 | attention_dropout_prob,
408 | output_dropout_prob,
409 | layernorm_epsilon,
410 | unscaled_init_method(init_method_std),
411 | output_layer_init_method=output_layer_init_method,
412 | use_deepspeed_sparse=sparsity_config)
413 |
414 | # Transformer layers.
415 | self.layers = torch.nn.ModuleList(
416 | [get_layer(i, num_layers) for i in range(num_layers)])
417 |
418 | # Final layer norm before output.
419 | self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
420 |
421 | if DEEPSPEED_WRAP:
422 | if DEEPSPEED_WRAP.deepspeed.checkpointing.is_configured():
423 | global get_cuda_rng_tracker, checkpoint
424 | get_cuda_rng_tracker = DEEPSPEED_WRAP.deepspeed.checkpointing.get_cuda_rng_tracker
425 | checkpoint = DEEPSPEED_WRAP.deepspeed.checkpointing.checkpoint
426 |
427 | def forward(self, hidden_states, attention_mask):
428 |
429 | def custom(start, end):
430 | def custom_forward(*inputs):
431 | layers_ = self.layers[start:end]
432 | x_ = inputs[0]
433 | for layer in layers_:
434 | x_ = layer(x_, inputs[1])
435 | return x_
436 |
437 | return custom_forward
438 |
439 | if self.checkpoint_activations:
440 | l = 0
441 | num_layers = len(self.layers)
442 | chunk_length = self.checkpoint_num_layers
443 | while l < num_layers:
444 | hidden_states = checkpoint(custom(l, l + chunk_length),
445 | hidden_states, attention_mask)
446 | l += chunk_length
447 | else:
448 | for layer in self.layers:
449 | hidden_states = layer(hidden_states, attention_mask)
450 |
451 | # Final layer norm.
452 | output = self.final_layernorm(hidden_states)
453 |
454 | return output
455 |
--------------------------------------------------------------------------------
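
As a side note, the constant 0.7978845608028654 in gelu_impl() above is sqrt(2/pi), i.e. this is the standard tanh approximation of GELU. A quick numerical check against PyTorch's builtin follows (the approximate="tanh" form needs PyTorch 1.12 or newer); no repo imports are required.

import torch
import torch.nn.functional as F

def gelu_impl(x):
    # Same formula as in transformer.py: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

x = torch.linspace(-4.0, 4.0, steps=9)
print(torch.allclose(gelu_impl(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True
print((gelu_impl(x) - F.gelu(x)).abs().max())  # small gap vs. the exact erf-based GELU
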
/src/arguments.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """argparser configuration"""
17 |
18 | import argparse
19 | import os
20 | import torch
21 | from src.utils import DEEPSPEED_WRAP
22 | from modules.data.read import DataReader, add_data_reader_arguments
23 |
24 |
25 | def add_model_config_args(parser):
26 | """Model arguments"""
27 |
28 | group = parser.add_argument_group('model', 'model configuration')
29 |
30 | group.add_argument('--attention-dropout', type=float, default=0.1,
31 | help='dropout probability for attention weights')
32 | group.add_argument('--num-attention-heads', type=int, default=16,
33 | help='num of transformer attention heads')
34 | group.add_argument('--hidden-size', type=int, default=1024,
35 |                        help='transformer hidden size')
36 | group.add_argument('--intermediate-size', type=int, default=None,
37 |                        help='transformer embedding dimension for FFN. '
38 | 'set to 4*`--hidden-size` if it is None')
39 | group.add_argument('--num-layers', type=int, default=24,
40 | help='num decoder layers')
41 | group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
42 | help='layer norm epsilon')
43 | group.add_argument('--hidden-dropout', type=float, default=0.1,
44 | help='dropout probability for hidden state transformer')
45 | group.add_argument('--max-position-embeddings', type=int, default=512,
46 | help='maximum number of position embeddings to use')
47 | group.add_argument('--vocab-size', type=int, default=30522,
48 | help='vocab size to use for non-character-level '
49 | 'tokenization. This value will only be used when '
50 | 'creating a tokenizer')
51 | group.add_argument('--deep-init', action='store_true',
52 |                        help='initialize bert model similar to gpt2 model. '
53 | 'scales initialization of projection layers by a '
54 | 'factor of 1/sqrt(2N). Necessary to train bert '
55 | 'models larger than BERT-Large.')
56 | group.add_argument('--make-vocab-size-divisible-by', type=int, default=8,
57 |                        help='Pad the vocab size to be divisible by this value. '
58 |                        'This is added for computational efficiency reasons.')
59 | group.add_argument('--cpu-optimizer', action='store_true',
60 | help='Run optimizer on CPU')
61 | group.add_argument('--cpu_torch_adam', action='store_true',
62 | help='Use Torch Adam as optimizer on CPU.')
63 | group.add_argument('--sparse-mode', type=str, default='all',
64 | choices=['all', 'alternating', 'top_bottom'],
65 | help='sparse layers arrangement in model')
66 |
67 | return parser
68 |
69 |
70 | def add_fp16_config_args(parser):
71 | """Mixed precision arguments."""
72 |
73 | group = parser.add_argument_group('fp16', 'fp16 configurations')
74 |
75 | group.add_argument('--fp16', action='store_true',
76 | help='Run model in fp16 mode')
77 | group.add_argument('--fp32-embedding', action='store_true',
78 | help='embedding in fp32')
79 | group.add_argument('--fp32-layernorm', action='store_true',
80 | help='layer norm in fp32')
81 | group.add_argument('--fp32-tokentypes', action='store_true',
82 | help='embedding token types in fp32')
83 | group.add_argument('--fp32-allreduce', action='store_true',
84 | help='all-reduce in fp32')
85 | group.add_argument('--hysteresis', type=int, default=2,
86 | help='hysteresis for dynamic loss scaling')
87 | group.add_argument('--loss-scale', type=float, default=None,
88 | help='Static loss scaling, positive power of 2 '
89 |                        'values can improve fp16 convergence. If None, dynamic '
90 | 'loss scaling is used.')
91 | group.add_argument('--loss-scale-window', type=float, default=1000,
92 | help='Window over which to raise/lower dynamic scale')
93 | group.add_argument('--min-scale', type=float, default=1,
94 | help='Minimum loss scale for dynamic loss scale')
95 |
96 | return parser
97 |
98 |
99 | def add_training_args(parser):
100 | """Training arguments."""
101 |
102 | group = parser.add_argument_group('train', 'training configurations')
103 |
104 | group.add_argument('--batch-size', type=int, default=4,
105 | help='Data Loader batch size')
106 | group.add_argument('--weight-decay', type=float, default=0.01,
107 | help='weight decay coefficient for L2 regularization')
108 | group.add_argument('--checkpoint-activations', action='store_true',
109 | help='checkpoint activation to allow for training '
110 | 'with larger models and sequences')
111 | group.add_argument('--checkpoint-num-layers', type=int, default=1,
112 | help='chunk size (number of layers) for checkpointing')
113 | group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
114 | help='uses activation checkpointing from deepspeed')
115 | group.add_argument('--clip-grad', type=float, default=1.0,
116 | help='gradient clipping')
117 | group.add_argument('--train-iters', type=int, default=1000000,
118 | help='total number of iterations to train over all training runs')
119 | group.add_argument('--log-interval', type=int, default=100,
120 | help='report interval')
121 | group.add_argument('--logging-dir', type=str, default=None,
122 | help='tensorboard log dir')
123 | group.add_argument('--exit-interval', type=int, default=None,
124 | help='Exit the program after this many new iterations.')
125 |
126 | group.add_argument('--seed', type=int, default=1234,
127 | help='random seed')
128 |     # Batch producer arguments
129 | group.add_argument('--reset-position-ids', action='store_true',
130 |                        help='Reset position ids after end-of-document token.')
131 | group.add_argument('--reset-attention-mask', action='store_true',
132 |                        help='Reset self attention mask after '
133 | 'end-of-document token.')
134 |
135 | # Learning rate.
136 | group.add_argument('--lr-decay-iters', type=int, default=None,
137 | help='number of iterations to decay LR over,'
138 | ' If None defaults to `--train-iters`*`--epochs`')
139 | group.add_argument('--lr-decay-style', type=str, default='linear',
140 | choices=['constant', 'linear', 'cosine', 'exponential'],
141 | help='learning rate decay function')
142 | group.add_argument('--lr', type=float, default=1.0e-4,
143 | help='initial learning rate')
144 | group.add_argument('--min-lr', type=float, default=1.0e-6,
145 | help='minimal learning rate')
146 | group.add_argument('--warmup', type=float, default=0.01,
147 | help='percentage of data to warmup on (.01 = 1% of all '
148 | 'training iters). Default 0.01')
149 | # model checkpointing
150 | group.add_argument('--save', type=str, default=None,
151 | help='Output directory to save checkpoints to.')
152 | group.add_argument('--save-interval', type=int, default=5000,
153 | help='number of iterations between saves')
154 | group.add_argument('--no-save-optim', action='store_true',
155 | help='Do not save current optimizer.')
156 | group.add_argument('--no-save-rng', action='store_true',
157 | help='Do not save current rng state.')
158 | group.add_argument('--load', type=str, default=None,
159 | help='Path to a directory containing a model checkpoint.')
160 | group.add_argument('--no-load-optim', action='store_true',
161 | help='Do not load optimizer when loading checkpoint.')
162 | group.add_argument('--log-memory', action='store_true',
163 | help='Write memory consumption to the tensorboard log')
164 | group.add_argument('--no-load-rng', action='store_true',
165 | help='Do not load rng state when loading checkpoint.')
166 | group.add_argument('--load-huggingface', type=str, default=None,
167 | help='Path to a directory containing a huggingface transformers model checkpoint.')
168 | group.add_argument('--export-huggingface', type=str, default=None,
169 | help='Path to export the model to in huggingface format.')
170 | group.add_argument('--huggingface-double-pos-embeddings', action='store_true',
171 | help='Duplicate the first half of the position embedding weights into the second half')
172 | group.add_argument('--load-tag', type=str, default='',
173 | help='checkpoint name to test')
174 | group.add_argument('--cache-prefix', type=str, default='_',
175 | help='cache folder prefix')
176 | group.add_argument('--finetune', action='store_true',
177 | help='Load model for finetuning. Do not load optimizer '
178 | 'or rng state from checkpoint and set iteration to 0. '
179 | 'Assumed when loading a release checkpoint.')
180 | group.add_argument('--resume-dataloader', action='store_true',
181 | help='Resume the dataloader when resuming training. '
182 | 'Does not apply to the tfrecords dataloader; try resuming '
183 | 'with a different seed in this case.')
184 | # distributed training args
185 | group.add_argument('--distributed-backend', default='nccl',
186 | help='which backend to use for distributed '
187 | 'training. One of [gloo, nccl]')
188 |
189 | group.add_argument('--local_rank', type=int, default=None,
190 | help='local rank passed from distributed launcher')
191 |
192 | group.add_argument('--master_port', type=int, default=6000,
193 | help='master port for test parallel prediction')
194 |
195 | return parser
196 |
197 |
198 | def add_evaluation_args(parser):
199 | """Evaluation arguments."""
200 |
201 | group = parser.add_argument_group('validation', 'validation configurations')
202 |
203 | group.add_argument('--eval-batch-size', type=int, default=None,
204 | help='Data Loader batch size for evaluation datasets. '
205 | 'Defaults to `--batch-size`')
206 | group.add_argument('--eval-iters', type=int, default=100,
207 | help='number of iterations to run for evaluation '
208 | 'on validation/test sets')
209 | group.add_argument('--eval-interval', type=int, default=1000,
210 | help='interval between running evaluation on validation set')
211 | group.add_argument('--eval-seq-length', type=int, default=None,
212 | help='Maximum sequence length to process for '
213 | 'evaluation. Defaults to `--seq-length`')
214 | group.add_argument('--eval-max-preds-per-seq', type=int, default=None,
215 | help='Maximum number of predictions to use for '
216 | 'evaluation. Defaults to '
217 | 'math.ceil(`--eval-seq-length`*.15/10)*10')
218 | group.add_argument('--overlapping-eval', type=int, default=32,
219 | help='sliding window for overlapping evaluation')
220 | group.add_argument('--cloze-eval', action='store_true',
221 | help='Evaluation dataset from `--valid-data` is a cloze task')
222 | group.add_argument('--eval-hf', action='store_true',
223 | help='perform evaluation with a huggingface openai model. '
224 | 'Use `--load` to specify the weights path to be loaded')
225 | group.add_argument('--load-openai', action='store_true',
226 | help='load openai weights into our model. Use `--load` '
227 | 'to specify weights path to be loaded')
228 |
229 | return parser
230 |
231 |
232 | def add_text_generate_args(parser):
233 | """Text generate arguments."""
234 |
235 | group = parser.add_argument_group('Text generation', 'configurations')
236 | group.add_argument("--temperature", type=float, default=1.0)
237 | group.add_argument("--top_p", type=float, default=0.0)
238 | group.add_argument("--top_k", type=int, default=0)
239 | group.add_argument("--out-seq-length", type=int, default=256)
240 | group.add_argument("--tg-token-name", type=str, default='token.txt')
241 | return parser
242 |
243 |
244 | def add_data_args(parser):
245 | """Train/valid/test data arguments."""
246 |
247 | group = parser.add_argument_group('data', 'data configurations')
248 |
249 | group.add_argument('--model-parallel-size', type=int, default=1,
250 | help='size of the model-parallel group.')
251 | group.add_argument('--shuffle', action='store_true',
252 | help='Shuffle data. Shuffling is deterministic '
253 | 'based on seed and current epoch.')
254 | group.add_argument('--train-data', nargs='+', default=None,
255 | help='Whitespace separated filenames or corpora names '
256 | 'for training.')
257 |
258 | group.add_argument('--use-npy-data-loader', action='store_true',
259 | help='Use the numpy data loader. If set, then '
260 | 'train-data-path, val-data-path, and test-data-path '
261 | 'should also be provided.')
262 | group.add_argument('--train-data-path', type=str, default='',
263 | help='path to the training data')
264 | group.add_argument('--val-data-path', type=str, default='',
265 | help='path to the validation data')
266 | group.add_argument('--test-data-path', type=str, default='',
267 | help='path to the test data')
268 | group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
269 | help='the filename containing all the shard sizes')
270 |
271 | group.add_argument('--delim', default=',',
272 | help='delimiter used to parse csv data files')
273 | group.add_argument('--text-key', default='sentence',
274 | help='key to use to extract text from json/csv')
275 | group.add_argument('--eval-text-key', default=None,
276 | help='key to use to extract text from '
277 | 'json/csv evaluation datasets')
278 | group.add_argument('--valid-data', nargs='*', default=None,
279 | help="""Filename for validation data.""")
280 | group.add_argument('--split', default='1000,1,1',
281 | help='comma-separated list of proportions for training,'
282 | ' validation, and test split')
283 | group.add_argument('--test-data', nargs='*', default=None,
284 | help="""Filename for testing""")
285 | group.add_argument('--overwrite-cache', action='store_true',
286 | help='overwrite dataset cache')
287 | group.add_argument('--lazy-loader', action='store_true',
288 | help='whether to lazy read the data set')
289 | group.add_argument('--loose-json', action='store_true',
290 | help='Use loose json (one json-formatted string per '
291 | 'newline), instead of tight json (data file is one '
292 | 'json string)')
293 | group.add_argument('--presplit-sentences', action='store_true',
294 | help='Dataset content consists of documents where '
295 | 'each document consists of newline separated sentences')
296 | group.add_argument('--num-workers', type=int, default=2,
297 | help="""Number of workers to use for dataloading""")
298 | group.add_argument('--tokenizer-path', type=str, default=None,
299 | help='path used to save/load sentencepiece tokenization '
300 | 'models')
301 | group.add_argument("--cache-dir", default=None, type=str,
302 | help="Where to store pre-trained BERT downloads")
303 | group.add_argument('--use-tfrecords', action='store_true',
304 | help='load `--train-data`, `--valid-data`, '
305 | '`--test-data` from BERT tf records instead of '
306 | 'normal data pipeline')
307 | group.add_argument('--seq-length', type=int, default=512,
308 | help="Maximum sequence length to process")
309 | group.add_argument('--max-files-per-process', type=int, default=50000,
310 | help="Maximum files to load per process")
311 | group.add_argument('--max-preds-per-seq', type=int, default=None,
312 | help='Maximum number of predictions to use per sequence. '
313 | 'Defaults to math.ceil(`--seq-length`*.15/10)*10. '
314 | 'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
315 |
316 | return parser
317 |
318 |
319 | def get_args(cmd=None):
320 | """Parse all the args."""
321 |
322 | parser = argparse.ArgumentParser(description='PyTorch BERT Model')
323 | parser = add_model_config_args(parser)
324 | parser = add_fp16_config_args(parser)
325 | parser = add_training_args(parser)
326 | parser = add_evaluation_args(parser)
327 | parser = add_text_generate_args(parser)
328 | parser = add_data_args(parser)
329 | parser = add_data_reader_arguments(parser)
330 |
331 | # Include DeepSpeed configuration arguments
332 | if DEEPSPEED_WRAP:
333 | parser = DEEPSPEED_WRAP.deepspeed.add_config_arguments(parser)
334 | if cmd is None:
335 | args = parser.parse_args()
336 | else:
337 | args = parser.parse_args(cmd)
338 |
339 | if not args.train_data and not args.train_data_path:
340 | print('WARNING: No training data specified')
341 |
342 | args.cuda = torch.cuda.is_available()
343 |
344 | args.rank = int(os.getenv('RANK', '0'))
345 | args.world_size = int(os.getenv("WORLD_SIZE", '1'))
346 |
347 | if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
348 | # We are using (OpenMPI) mpirun for launching distributed data parallel processes
349 | local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
350 | local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
351 |
352 | # Possibly running with Slurm
353 | num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1'))
354 | nodeid = int(os.getenv('SLURM_NODEID', '0'))
355 |
356 | args.local_rank = local_rank
357 | args.rank = nodeid * local_size + local_rank
358 | args.world_size = num_nodes * local_size
359 |
360 | args.model_parallel_size = min(args.model_parallel_size, args.world_size)
361 | if args.rank == 0:
362 | print('using world size: {} and model-parallel size: {} '.format(
363 | args.world_size, args.model_parallel_size))
364 |
365 | args.dynamic_loss_scale = False
366 | if args.loss_scale is None:
367 | args.dynamic_loss_scale = True
368 | if args.rank == 0:
369 | print(' > using dynamic loss scaling')
370 |
371 | # The fp32_* and fp16_* args are only meant to be active
372 | # when the fp16 arg is set, so by default they
373 | # should all be False.
374 | if not args.fp16:
375 | args.fp32_embedding = False
376 | args.fp32_tokentypes = False
377 | args.fp32_layernorm = False
378 |
379 | return args
380 |
--------------------------------------------------------------------------------
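
Illustrative only: get_args() above accepts an optional cmd list, so the composed parser can be driven programmatically (e.g. in tests or notebooks). The flag values below are placeholders rather than recommended settings, and this sketch assumes the src package and its dependencies are importable and that none of the model-config or data-reader arguments defined elsewhere in arguments.py are required.

from src.arguments import get_args

# Hypothetical invocation with explicit flags; the path is a placeholder.
args = get_args([
    '--train-data-path', '/path/to/train_shards/',
    '--batch-size', '4',
    '--lr', '1e-4',
    '--seq-length', '512',
])
print(args.batch_size, args.lr, args.world_size)
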
/src/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2020, Sber. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities for logging and serialization"""
17 |
18 | import os
19 | import random
20 | import time
21 |
22 | import numpy as np
23 | import torch
24 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
25 |
26 | from src import mpu
27 | from src.fp16 import FP16_Optimizer, FP16_Module
28 | from src.model import DistributedDataParallel as DDP
29 | import os
30 | from src.download_utils import download_model_files
31 |
32 |
33 | class DeepSpeedImportWrap(object):
34 | def __init__(self):
35 | self.use_ds = os.environ.get("USE_DEEPSPEED", False)
36 | self.deepspeed = None
37 | if self.use_ds:
38 | import deepspeed
39 | self.deepspeed = deepspeed
40 |
41 | def __bool__(self):
42 | return bool(self.use_ds)
43 |
44 |
45 | DEEPSPEED_WRAP = DeepSpeedImportWrap()
46 |
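# Note: os.environ.get returns a string, so any non-empty value of USE_DEEPSPEED
# (even "0") makes DEEPSPEED_WRAP truthy and triggers the deepspeed import.
# Illustrative launch: `USE_DEEPSPEED=1 python <your_training_script>.py ...`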
47 |
48 | def print_rank_0(message):
49 | if torch.distributed.is_initialized():
50 | if torch.distributed.get_rank() == 0:
51 | print(message, flush=True)
52 | else:
53 | print(message, flush=True)
54 |
55 |
56 | def print_args(args):
57 | """Print arguments."""
58 |
59 | print('arguments:', flush=True)
60 | for arg in vars(args):
61 | dots = '.' * (29 - len(arg))
62 | print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
63 |
64 |
65 | def print_params_min_max_norm(optimizer, iteration):
66 | """Print min, max, and norm of all parameters."""
67 | index = 0
68 | rank = torch.distributed.get_rank()
69 | string = 'iteration, rank, index, model-parallel, min, max, norm\n'
70 | optimizer_ = optimizer
71 | if isinstance(optimizer, FP16_Optimizer):
72 | optimizer_ = optimizer.optimizer
73 | for param_group in optimizer_.param_groups:
74 | for param in param_group['params']:
75 | index += 1
76 | min_ = param.data.min()
77 | max_ = param.data.max()
78 | norm = param.data.norm()
79 | string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
80 | iteration, rank, index, int(param.model_parallel))
81 | string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
82 | print(string, flush=True)
83 |
84 |
85 | class Timers:
86 | """Group of timers."""
87 |
88 | class Timer:
89 | """Timer."""
90 |
91 | def __init__(self, name):
92 | self.name_ = name
93 | self.elapsed_ = 0.0
94 | self.started_ = False
95 | self.start_time = time.time()
96 |
97 | def start(self):
98 | """Start the timer."""
99 | assert not self.started_, 'timer has already been started'
100 | torch.cuda.synchronize()
101 | self.start_time = time.time()
102 | self.started_ = True
103 |
104 | def stop(self):
105 | """Stop the timer."""
106 | assert self.started_, 'timer is not started'
107 | torch.cuda.synchronize()
108 | self.elapsed_ += (time.time() - self.start_time)
109 | self.started_ = False
110 |
111 | def reset(self):
112 | """Reset timer."""
113 | self.elapsed_ = 0.0
114 | self.started_ = False
115 |
116 | def elapsed(self, reset=True):
117 | """Calculate the elapsed time."""
118 | started_ = self.started_
119 | # If timing is in progress, stop it first.
120 | if self.started_:
121 | self.stop()
122 | # Get the elapsed time.
123 | elapsed_ = self.elapsed_
124 | # Reset the elapsed time
125 | if reset:
126 | self.reset()
127 | # If timing was in progress, restart it.
128 | if started_:
129 | self.start()
130 | return elapsed_
131 |
132 | def __init__(self):
133 | self.timers = {}
134 |
135 | def __call__(self, name):
136 | if name not in self.timers:
137 | self.timers[name] = self.Timer(name)
138 | return self.timers[name]
139 |
140 | def log(self, names, normalizer=1.0, reset=True):
141 | """Log a group of timers."""
142 | assert normalizer > 0.0
143 | string = 'time (ms)'
144 | for name in names:
145 | elapsed_time = self.timers[name].elapsed(
146 | reset=reset) * 1000.0 / normalizer
147 | string += ' | {}: {:.2f}'.format(name, elapsed_time)
148 | print_rank_0(string)
149 |
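# Usage sketch (illustrative):
#   timers = Timers()
#   timers('forward').start()
#   ...  # timed work
#   timers('forward').stop()
#   timers.log(['forward'])
# start()/stop() call torch.cuda.synchronize(), so a CUDA device is assumed.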
150 |
151 | def report_memory(name):
152 | """Simple GPU memory report."""
153 |
154 | mega_bytes = 1024.0 * 1024.0
155 | string = name + ' memory (MB)'
156 | string += ' | allocated: {}'.format(
157 | torch.cuda.memory_allocated() / mega_bytes)
158 | string += ' | max allocated: {}'.format(
159 | torch.cuda.max_memory_allocated() / mega_bytes)
160 | string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
161 | string += ' | max cached: {}'.format(
162 | torch.cuda.max_memory_cached() / mega_bytes)
163 | print_rank_0(string)
164 |
165 |
166 | def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
167 | if release:
168 | d = 'release'
169 | else:
170 | d = 'iter_{:07d}'.format(iteration)
171 | if zero:
172 | dp_rank = mpu.get_data_parallel_rank()
173 | d += '_zero_dp_rank_{}'.format(dp_rank)
174 | return os.path.join(checkpoints_path, d,
175 | 'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
176 | 'model_optim_rng.pt')
177 |
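# Worked example with hypothetical values: get_checkpoint_name('/ckpt', 5000) resolves to
# '/ckpt/iter_0005000/mp_rank_00/model_optim_rng.pt' when the model-parallel rank is 0;
# with zero=True and data-parallel rank 1, the directory becomes 'iter_0005000_zero_dp_rank_1'.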
178 |
179 | def ensure_directory_exists(filename):
180 | dirname = os.path.dirname(filename)
181 | if not os.path.exists(dirname):
182 | os.makedirs(dirname)
183 |
184 |
185 | def get_checkpoint_tracker_filename(checkpoints_path):
186 | return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
187 |
188 |
189 | def save_zero_checkpoint(args, iteration, optimizer):
190 | zero_sd = {'iteration': iteration,
191 | 'optimizer_state_dict': optimizer.state_dict()}
192 | zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True)
193 | ensure_directory_exists(zero_checkpoint_name)
194 | torch.save(zero_sd, zero_checkpoint_name)
195 | print(' successfully saved {}'.format(zero_checkpoint_name))
196 |
197 |
198 | def save_checkpoint(iteration, model, optimizer,
199 | lr_scheduler, args, deepspeed=False):
200 | """Save a model checkpoint."""
201 | if deepspeed:
202 | save_ds_checkpoint(iteration, model, args)
203 | else:
204 | # Only rank zero of the data-parallel group writes to disk.
205 | if isinstance(model, torchDDP):
206 | model = model.module
207 |
208 | if mpu.get_data_parallel_rank() == 0:
209 | checkpoint_name = get_checkpoint_name(args.save, iteration)
210 | print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
211 | format(torch.distributed.get_rank(), iteration, checkpoint_name))
212 |
213 | sd = {}
214 | sd['iteration'] = iteration
215 | sd['model'] = model.state_dict()
216 |
217 | # Optimizer stuff.
218 | if not args.no_save_optim:
219 | if optimizer is not None:
220 | sd['optimizer'] = optimizer.state_dict()
221 | if lr_scheduler is not None:
222 | sd['lr_scheduler'] = lr_scheduler.state_dict()
223 |
224 | # rng states.
225 | if not args.no_save_rng:
226 | sd['random_rng_state'] = random.getstate()
227 | sd['np_rng_state'] = np.random.get_state()
228 | sd['torch_rng_state'] = torch.get_rng_state()
229 | sd['cuda_rng_state'] = torch.cuda.get_rng_state()
230 | sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
231 |
232 | ensure_directory_exists(checkpoint_name)
233 | torch.save(sd, checkpoint_name)
234 | print(' successfully saved {}'.format(checkpoint_name))
235 |
236 | # Wait so everyone is done (necessary)
237 | torch.distributed.barrier()
238 | # And update the latest iteration
239 | if torch.distributed.get_rank() == 0:
240 | tracker_filename = get_checkpoint_tracker_filename(args.save)
241 | with open(tracker_filename, 'w') as f:
242 | f.write(str(iteration))
243 | # Wait so everyone is done (not necessary)
244 | torch.distributed.barrier()
245 |
246 |
247 | def save_ds_checkpoint(iteration, model, args):
248 | """Save a model checkpoint."""
249 |
250 | sd = {}
251 | sd['iteration'] = iteration
252 | # rng states.
253 | if not args.no_save_rng:
254 | sd['random_rng_state'] = random.getstate()
255 | sd['np_rng_state'] = np.random.get_state()
256 | sd['torch_rng_state'] = torch.get_rng_state()
257 | sd['cuda_rng_state'] = torch.cuda.get_rng_state()
258 | sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
259 |
260 | model.save_checkpoint(args.save, str(iteration), client_state=sd)
261 |
262 |
263 | def get_checkpoints(load_dir):
264 | return [f for f in os.listdir(load_dir) if '.' not in f]
265 |
266 |
267 | def get_last_checkpoints(load_dir, n=1):
268 | checkpoints = get_checkpoints(load_dir)
269 | checkpoints = [(int(c), c) for c in checkpoints]
270 | by_iteration = sorted(checkpoints, key=lambda x: x[0], reverse=True)
271 | return [b[-1] for b in by_iteration[:n]]
272 |
273 |
274 | def get_outdated_checkpoints(load_dir, retain_last_n=5):
275 | last = get_last_checkpoints(load_dir, retain_last_n)
276 | return [d for d in get_checkpoints(load_dir) if d not in last]
277 |
278 |
279 | def get_checkpoint_iteration(args):
280 | # Read the tracker file and set the iteration.
281 | if args.load_tag:
282 | return args.load_tag, False, True
283 | tracker_filename = get_checkpoint_tracker_filename(args.load)
284 | if not os.path.isfile(tracker_filename):
285 | print_rank_0('WARNING: could not find the metadata file {} '.format(
286 | tracker_filename))
287 | print_rank_0(' will not load any checkpoints and will start from '
288 | 'random')
289 | return 0, False, False
290 | iteration = 0
291 | release = False
292 | with open(tracker_filename, 'r') as f:
293 | metastring = f.read().strip()
294 | try:
295 | iteration = int(metastring)
296 | except ValueError:
297 | release = metastring == 'release'
298 | if not release:
299 | print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
300 | tracker_filename))
301 | exit()
302 |
303 | assert iteration > 0 or release, 'error parsing metadata file {}'.format(
304 | tracker_filename)
305 |
306 | return iteration, release, True
307 |
308 |
309 | def load_checkpoint(model, optimizer, lr_scheduler, args, deepspeed=False):
310 | """Load a model checkpoint."""
311 |
312 | iteration, release, success = get_checkpoint_iteration(args)
313 |
314 | if not success:
315 | return 0
316 |
317 | if deepspeed:
318 | load_optim = not args.no_load_optim
319 | checkpoint_name, sd = model.load_checkpoint(args.load, iteration, load_optimizer_states=load_optim,
320 | load_lr_scheduler_states=load_optim)
321 |
322 | if checkpoint_name is None:
323 | if mpu.get_data_parallel_rank() == 0:
324 | print("Unable to load checkpoint.")
325 | return iteration
326 |
327 | else:
328 |
329 | # Checkpoint.
330 | checkpoint_name = get_checkpoint_name(args.load, iteration, release)
331 |
332 | if mpu.get_data_parallel_rank() == 0:
333 | print('global rank {} is loading checkpoint {}'.format(
334 | torch.distributed.get_rank(), checkpoint_name))
335 |
336 | # Load the checkpoint.
337 | sd = torch.load(checkpoint_name, map_location='cpu')
338 |
339 | if isinstance(model, torchDDP):
340 | model = model.module
341 |
342 | # Model.
343 | try:
344 | model.load_state_dict(sd['model'])
345 | except KeyError:
346 | print_rank_0('A metadata file exists but the model could not be loaded '
347 | 'from checkpoint {}, exiting'.format(checkpoint_name))
348 | exit()
349 |
350 | # Optimizer.
351 | if not release and not args.finetune and not args.no_load_optim:
352 | try:
353 | if optimizer is not None:
354 | optimizer.load_state_dict(sd['optimizer'])
355 | if lr_scheduler is not None:
356 | lr_scheduler.load_state_dict(sd['lr_scheduler'])
357 | except KeyError:
358 | print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
359 | 'Specify --no-load-optim or --finetune to prevent '
360 | 'attempting to load the optimizer '
361 | 'state.'.format(checkpoint_name))
362 | exit()
363 |
364 | # Iterations.
365 | if args.finetune or release:
366 | iteration = 0
367 | else:
368 | try:
369 | iteration = sd['iteration']
370 | except KeyError:
371 | try: # Backward compatible with older checkpoints
372 | iteration = sd['total_iters']
373 | except KeyError:
374 | print_rank_0('A metadata file exists but the iteration could not be loaded '
375 | 'from checkpoint {}, exiting'.format(checkpoint_name))
376 | exit()
377 |
378 | # rng states.
379 | if not release and not args.finetune and not args.no_load_rng:
380 | try:
381 | random.setstate(sd['random_rng_state'])
382 | np.random.set_state(sd['np_rng_state'])
383 | torch.set_rng_state(sd['torch_rng_state'])
384 | torch.cuda.set_rng_state(sd['cuda_rng_state'])
385 | mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
386 | except KeyError:
387 | print_rank_0('Unable to load random state from checkpoint {}, exiting. '
388 | 'Specify --no-load-rng or --finetune to prevent '
389 | 'attempting to load the random '
390 | 'state.'.format(checkpoint_name))
391 | exit()
392 |
393 | torch.distributed.barrier()
394 | if mpu.get_data_parallel_rank() == 0:
395 | print(' successfully loaded {}'.format(checkpoint_name))
396 |
397 | return iteration
398 |
399 |
400 | def load_weights(src, dst, dst2src=False, double_pos_embeddings=False):
401 | """
402 | Loads weights from src into dst via in-place copy.
403 | src is a huggingface gpt3 model, while dst is one of our models.
404 | dst2src=True loads parameters from our model into huggingface's instead.
405 | ^dst2src=True is still untested
406 | """
407 | conv_layer = 'Conv1D' in str(type(src))
408 | for n, p in src.named_parameters():
409 | if dst2src:
410 | data = dst._parameters[n].data
411 | load = p.data
412 | else:
413 | if double_pos_embeddings:
414 | print('Double pos embeddings')
415 | mid = p.size(0) // 2
416 | p[mid:, :] = p[:mid, :] # copy the first half of the position embeddings into the second half
417 | data = p.data
418 | load = dst._parameters[n].data
419 | if conv_layer and 'weight' in n:
420 | data = data.t().contiguous()
421 | load.copy_(data)
422 |
423 |
424 | # dst._parameters[n].data.copy_(data)
425 |
426 | def load_mlp(our, oai, dst2src=False):
427 | load_weights(oai.c_fc, our.dense_h_to_4h, dst2src)
428 | load_weights(oai.c_proj, our.dense_4h_to_h, dst2src)
429 |
430 |
431 | def load_attention(our, oai, dst2src=False):
432 | load_weights(oai.c_attn, our.query_key_value, dst2src)
433 | load_weights(oai.c_proj, our.dense, dst2src)
434 |
435 |
436 | def load_transformer_layer(our, oai, dst2src=False):
437 | load_weights(oai.ln_1, our.input_layernorm, dst2src)
438 | load_weights(oai.ln_2, our.post_attention_layernorm, dst2src)
439 | load_mlp(our.mlp, oai.mlp, dst2src)
440 | load_attention(our.attention, oai.attn, dst2src)
441 |
442 |
443 | def move_weights(our, oai, dst2src=False, double_pos_embeddings=False):
444 | """
445 | Loads weights from `oai` to `our` via in-place copy.
446 | `oai` is a huggingface gpt3 model, while `our` is one of our models.
447 | dst2src=True loads parameters from our model into huggingface's instead.
448 | ^dst2src=True is still untested
449 | """
450 | # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)):
451 | # our=our.module
452 | transformer_model = oai.transformer
453 | load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src)
454 | load_weights(transformer_model.wte, our.word_embeddings, dst2src)
455 | load_weights(transformer_model.wpe, our.position_embeddings, dst2src, double_pos_embeddings=double_pos_embeddings)
456 |
457 | for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
458 | load_transformer_layer(our_layer, oai_layer, dst2src)
459 |
460 |
461 | def load_huggingface_model(model, path, double_pos_embeddings):
462 | from transformers import GPT2LMHeadModel
463 | print('Load huggingface model from', path, ('with pos emb doubling' if double_pos_embeddings else ''))
464 | model2fill = model
465 | while isinstance(model2fill, (torchDDP, FP16_Module)):
466 | model2fill = model2fill.module
467 |
468 | if path == "sberbank-ai/rugpt3xl":
469 | weights_path, _ = download_model_files("sberbank-ai/rugpt3xl")
470 | checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)['module']
471 | model2fill.load_state_dict(checkpoint)
472 | else:
473 | h_model = GPT2LMHeadModel.from_pretrained(path)
474 | move_weights(model2fill, h_model, double_pos_embeddings=double_pos_embeddings)
475 |
476 | print('Loaded huggingface model', type(model))
477 | return model
478 |
479 |
480 | def export_to_huggingface_model(model, path):
481 | from transformers import GPT2LMHeadModel, GPT2Config
482 | model_from = model
483 | while isinstance(model_from, (DDP, torchDDP, FP16_Module)):
484 | model_from = model_from.module
485 | conf_dict = model_from._conf_dict
486 | print('Export to huggingface model ', path, 'with config', conf_dict)
487 | config = GPT2Config(**conf_dict)
488 | hf_model = GPT2LMHeadModel(config=config)
489 | model_to = hf_model
490 | while isinstance(model_to, (DDP, torchDDP, FP16_Module)):
491 | model_to = model_to.module
492 | move_weights(model_from, model_to, dst2src=True)
493 | hf_model.save_pretrained(path)
494 | print('Saved huggingface model', type(model))
495 |
496 |
497 | def get_deepspeed_config(args):
498 | if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
499 | from deepspeed import DeepSpeedConfig
500 | return DeepSpeedConfig(args.deepspeed_config)
501 | else:
502 | raise RuntimeError('deepspeed_config is not found in args.')
503 |
504 |
505 | def get_sparse_attention_config(args, num_heads):
506 | ds_config = get_deepspeed_config(args)
507 | if hasattr(ds_config, 'sparse_attention') and ds_config.sparse_attention:
508 | sa_config = ds_config.sparse_attention
509 | sa_mode = sa_config.get('mode')
510 | if (sa_mode == 'dense'):
511 | from deepspeed.ops.sparse_attention import DenseSparsityConfig as STConfig
512 | elif (sa_mode == 'fixed'):
513 | from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig
514 | elif (sa_mode == 'bigbird'):
515 | from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as STConfig
516 | elif (sa_mode == 'bslongformer'):
517 | from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as STConfig
518 | elif (sa_mode == 'variable'):
519 | from deepspeed.ops.sparse_attention import VariableSparsityConfig as STConfig
520 | else:
521 | raise NotImplementedError(
522 | f'Given sparsity mode, {sa_mode}, has not been implemented yet!'
523 | )
524 | del sa_config['mode']
525 | return STConfig(num_heads=num_heads, **sa_config)
526 | else:
527 | return None
528 |
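# The sparse_attention block is read from the DeepSpeed JSON config; an illustrative
# fragment (keys other than "mode" depend on the chosen SparsityConfig) might be:
#   "sparse_attention": {"mode": "fixed", "block": 16, "num_local_blocks": 4}
# Everything except "mode" is forwarded as keyword arguments to the selected config class.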
529 |
530 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
531 | # This function has been mostly taken from huggingface conversational ai code at
532 | # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
533 |
534 | if top_k > 0:
535 | # Remove all tokens with a probability less than the last token of the top-k
536 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
537 | logits[indices_to_remove] = filter_value
538 |
539 | if top_p > 0.0:
540 | # convert to 1D
541 | logits = logits.view(logits.size()[1]).contiguous()
542 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
543 | cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
544 |
545 | # Remove tokens with cumulative probability above the threshold
546 | sorted_indices_to_remove = cumulative_probs > top_p
547 | # Shift the indices to the right to keep also the first token above the threshold
548 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
549 | sorted_indices_to_remove[..., 0] = 0
550 | indices_to_remove = sorted_indices[sorted_indices_to_remove]
551 | logits[indices_to_remove] = filter_value
552 | # going back to 2D
553 | logits = logits.view(1, -1).contiguous()
554 |
555 | return logits
556 |
--------------------------------------------------------------------------------
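
A minimal sketch of exercising top_k_logits on its own, assuming the src package above is importable. The vocabulary size and thresholds are arbitrary, and the logits tensor is cloned because the function also modifies its input in place.

import torch
from src.utils import top_k_logits

logits = torch.randn(1, 50257)  # hypothetical next-token logits for a batch of one
filtered = top_k_logits(logits.clone(), top_k=40, top_p=0.9)  # filtered entries become -inf
probs = torch.nn.functional.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample from the truncated distribution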