├── _lib └── fastertransformer_training_v1 │ └── cu111 │ └── torch190 │ ├── libfastertransformer.so │ └── th_fastertransformer.cpython-37m-x86_64-linux-gnu.so ├── .gitignore ├── .gitmodules ├── src └── veGiantModel │ ├── module │ ├── __init__.py │ └── dense.py │ ├── distributed │ └── __init__.py │ ├── __init__.py │ ├── launcher │ └── launch.py │ ├── initialize.py │ ├── patcher.py │ └── engine │ ├── p2p.py │ ├── schedule.py │ ├── module.py │ └── topology.py ├── docs ├── Dockerfile └── step-by-step-tutorial.md ├── README.md ├── examples └── gpt │ ├── pretrain_gpt2_distributed.sh │ ├── initialize.py │ ├── gpt_piped.py │ └── pretrain_gpt2.py └── LICENSE /_lib/fastertransformer_training_v1/cu111/torch190/libfastertransformer.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volcengine/veGiantModel/HEAD/_lib/fastertransformer_training_v1/cu111/torch190/libfastertransformer.so -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | application/cache 2 | *.pyc 3 | 4 | # general things to ignore 5 | build/ 6 | dist/ 7 | *.egg-info/ 8 | *.egg 9 | *.py[cod] 10 | __pycache__/ 11 | *~ 12 | 13 | # due to using tox and pytest 14 | .tox 15 | .cache 16 | -------------------------------------------------------------------------------- /_lib/fastertransformer_training_v1/cu111/torch190/th_fastertransformer.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/volcengine/veGiantModel/HEAD/_lib/fastertransformer_training_v1/cu111/torch190/th_fastertransformer.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/megatron"] 2 | path = third_party/megatron 3 | url = https://github.com/NVIDIA/Megatron-LM.git 4 | [submodule "third_party/deepspeed"] 5 | path = third_party/deepspeed 6 | url = https://github.com/microsoft/DeepSpeed.git 7 | -------------------------------------------------------------------------------- /src/veGiantModel/module/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | from .dense import ColumnSerialLinear, ColumnParallelLinear 3 | from .dense import RowSerialLinear, RowParallelLinear, MockModule 4 | from .dense import ColumnParallelLinearTranspose, ColumnSerialLinearTranspose 5 | 6 | __all__ = ['ColumnSerialLinear', 7 | 'ColumnParallelLinear', 8 | 'ColumnParallelLinearTranspose', 9 | 'ColumnSerialLinearTranspose', 10 | 'RowSerialLinear', 11 | 'RowParallelLinear', 12 | 'MockModule'] 13 | -------------------------------------------------------------------------------- /src/veGiantModel/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | from .. 
import patcher as dist 2 | from megatron import mpu 3 | 4 | def get_model_parallel_world_size(): 5 | return dist.get_model_parallel_world_size() 6 | 7 | def get_model_parallel_rank(): 8 | return dist.get_model_parallel_rank() 9 | 10 | def get_data_parallel_world_size(): 11 | return dist.get_data_parallel_world_size() 12 | 13 | def get_model_parallel_group(): 14 | return dist.get_model_parallel_group() 15 | 16 | def get_grid(): 17 | return dist.get_grid() 18 | 19 | def copy_to_model_parallel_region(input_): 20 | return mpu.copy_to_model_parallel_region(input_) 21 | 22 | def reduce_from_model_parallel_region(input_): 23 | return mpu.reduce_from_model_parallel_region(input_) 24 | 25 | def gather_from_model_parallel_region(input_): 26 | return mpu.gather_from_model_parallel_region(input_) 27 | -------------------------------------------------------------------------------- /docs/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:21.05-py3 2 | 3 | RUN pip3 install boto3 regex tensorboardX==1.8 wheel pybind11 ninja psutil pyprof 4 | RUN apt-get -yq autoremove --purge ibverbs-providers 5 | RUN apt-get update && \ 6 | DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends --allow-downgrades \ 7 | libibverbs-dev=28.0-1ubuntu1 libibverbs1=28.0-1ubuntu1 8 | 9 | RUN apt-get update && \ 10 | DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends --allow-downgrades \ 11 | cmake \ 12 | libopenmpi-dev \ 13 | openmpi-bin \ 14 | openssh-client \ 15 | openssh-server \ 16 | ibverbs-providers \ 17 | libibverbs-dev=28.0-1ubuntu1 \ 18 | librdmacm-dev \ 19 | vim \ 20 | iputils-ping \ 21 | llvm-10-dev \ 22 | iproute2 \ 23 | unzip 24 | 25 | RUN ln -s /usr/bin/aclocal-1.16 /usr/local/bin/aclocal-1.14 26 | RUN ln -s /usr/bin/automake /usr/local/bin/automake-1.14 27 | 28 | ENV LD_LIBRARY_PATH "/usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}" 29 | ENV BYTEPS_WITH_UCX 0 30 | 31 | #install byteps from package stored in tos at volcengine 32 | # RUN pip3 install https://giant-model-package.tos-cn-beijing.volces.com/byteps-0.7.2-cp38-cp38-linux_x86_64.whl 33 | 34 | #install byteps from source 35 | RUN git clone --recursive -b bccl-github https://github.com/bytedance/byteps.git && \ 36 | cd byteps && python3 setup.py install 37 | 38 | WORKDIR /root -------------------------------------------------------------------------------- /docs/step-by-step-tutorial.md: -------------------------------------------------------------------------------- 1 | # A Step-by-Step Tutorial 2 | The goal of this tutorial is to help you run the example quickly. 
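Before installing anything, it is worth confirming that the environment matches what the prebuilt FasterTransformer library under `_lib/fastertransformer_training_v1/cu111/torch190/` was built for (CUDA 11.1, PyTorch 1.9, Python 3.7). A quick sanity check (a minimal sketch, not specific to veGiantModel):
```
python3 -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
nvidia-smi -L
```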
3 | 4 | ## Prerequisites 5 | PyTorch: 6 | ``` 7 | pip3 install torch 8 | ``` 9 | 10 | Apex: 11 | ``` 12 | git clone https://github.com/NVIDIA/apex.git 13 | cd apex 14 | python3 setup.py -v --cpp_ext --cuda_ext bdist_wheel 15 | sudo pip3 install dist/* 16 | ``` 17 | 18 | BytePS: 19 | ``` 20 | git clone --recursive -b bccl-github https://github.com/bytedance/byteps.git 21 | cd byteps 22 | python3 setup.py install 23 | ``` 24 | ## Prepare data 25 | [GPT data preprocessing](https://github.com/NVIDIA/Megatron-LM#data-preprocessing) 26 | 27 | ## Set up veGiantModel 28 | ``` 29 | git clone https://github.com/volcengine/veGiantModel.git 30 | cd veGiantModel 31 | git submodule update --init --recursive 32 | ``` 33 | 34 | ## Modify the script 35 | Edit examples/gpt/pretrain_gpt2_distributed.sh before running it: 36 | ``` 37 | DATA_PATH -- local folder containing the preprocessed GPT data 38 | CHECKPOINT_PATH -- local path for saving/loading checkpoints 39 | MASTER_PORT -- port number used by torch DDP 40 | WORKER_0_PORT -- port number veGiantModel uses for communication 41 | WORKER_0_HOST -- IP of the master node (single-node training can use 'localhost') 42 | NUM_WORKER -- number of worker nodes in the training job 43 | WORKER_RANK -- rank of the current node 44 | GPU_PER_WORKER -- number of GPUs per node 45 | ``` 46 | 47 | ## Run the script 48 | ``` 49 | bash examples/gpt/pretrain_gpt2_distributed.sh 50 | ``` 51 | 52 | -------------------------------------------------------------------------------- /src/veGiantModel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | import sys 3 | import os 4 | 5 | cwd = os.path.dirname(os.path.abspath(__file__)) 6 | _deepspeed_dir = os.path.join(cwd, '../../third_party/deepspeed') 7 | _megatron_dir = os.path.join(cwd, '../../third_party/megatron') 8 | sys.path.append(cwd) 9 | sys.path.append(_deepspeed_dir) 10 | sys.path.append(_megatron_dir) 11 | 12 | from . import patcher 13 | from .engine.engine import VeGiantModelEngine 14 | from .initialize import initialize_megatron, init_distribute 15 | from .distributed import * 16 | 17 | def initialize(args, 18 | model, 19 | optimizer=None, 20 | model_parameters=None, 21 | training_data=None, 22 | lr_scheduler=None, 23 | mpu=None, 24 | dist_init_required=None, 25 | collate_fn=None, 26 | config_params=None): 27 | engine = VeGiantModelEngine(args=args, 28 | model=model, 29 | optimizer=optimizer, 30 | model_parameters=model_parameters, 31 | training_data=training_data, 32 | lr_scheduler=lr_scheduler, 33 | mpu=model.mpu(), 34 | dist_init_required=dist_init_required, 35 | collate_fn=collate_fn, 36 | config_params=config_params) 37 | 38 | return_items = [ 39 | engine, 40 | engine.optimizer, 41 | engine.training_dataloader, 42 | engine.lr_scheduler 43 | ] 44 | return tuple(return_items) 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # veGiantModel 2 | VeGiantModel is a PyTorch-based, highly efficient training library developed by the Applied Machine Learning team at ByteDance. This repository is for ongoing research to make training giant models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf) and [T5](https://arxiv.org/abs/1910.10683) easy, efficient, and effective.
VeGiantModel builds on top of [Megatron](https://github.com/NVIDIA/Megatron-LM) and [DeepSpeed](https://github.com/microsoft/DeepSpeed), and improves communication efficiency by integrating the high-performance communication library [BytePS](https://github.com/bytedance/byteps) and by providing customized pipeline partitioning. 3 | ## Initialization 4 | 5 | ```python 6 | import veGiantModel 7 | pipeline_parallel_size = 1 8 | model_parallel_size = 2 9 | veGiantModel.initialize.init_distribute(pipeline_parallel_size, model_parallel_size, init_method="env://") 10 | mp_size = veGiantModel.distributed.get_model_parallel_world_size() 11 | dp_size = veGiantModel.distributed.get_data_parallel_world_size() 12 | ``` 13 | 14 | ## Modules 15 | 16 | 17 | ```python 18 | from veGiantModel.module import ColumnParallelLinear, RowParallelLinear 19 | 20 | class PositionWiseFeedForward(nn.Module): 21 | """ FeedForward Neural Networks for each position """ 22 | 23 | def __init__(self, config: Config): 24 | super().__init__() 25 | self.config = config 26 | if self.config.use_mp_linear_in_ffn: 27 | assert ColumnParallelLinear is not None 28 | assert RowParallelLinear is not None 29 | self.fc1 = ColumnParallelLinear(config.dim, config.dim_ff, use_ft=False) 30 | self.fc2 = RowParallelLinear(config.dim_ff, config.dim, use_ft=False) 31 | else: 32 | self.fc1 = nn.Linear(config.dim, config.dim_ff) 33 | self.fc2 = nn.Linear(config.dim_ff, config.dim) 34 | self.act = Activation(config.act) 35 | self.dropout = nn.Dropout(config.p_drop_hidden) 36 | 37 | def forward(self, x) -> torch.Tensor: 38 | # (bsz, seq_len, dim) -> (bsz, seq_len, dim_ff / model_parallel_size) -> (bsz, seq_len, dim) 39 | fc1_out = self.act(self.fc1(x)) 40 | if self.config.dropout_in_ffn: 41 | fc1_out = self.dropout(fc1_out) 42 | fc2_out = self.fc2(fc1_out) 43 | if self.config.use_ffn_output_dropout: 44 | fc2_out = self.dropout(fc2_out) 45 | return fc2_out 46 | ``` 47 | 48 | 49 | ## Examples 50 | ### GPT Pretraining 51 | The `examples/gpt/pretrain_gpt2_distributed.sh` script runs 345M-parameter GPT pretraining on a single node with 8 GPUs. It largely follows the Megatron GPT script, with a few notable differences. It shows good compatibility with existing Megatron/DeepSpeed training jobs, requiring only small changes to adopt veGiantModel. 52 | -------------------------------------------------------------------------------- /examples/gpt/pretrain_gpt2_distributed.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | # Runs the "345M" parameter model 3 | 4 | DATA_PATH= 5 | CHECKPOINT_PATH= 6 | 7 | export WORKER_0_HOST=127.0.0.1 8 | export DMLC_NODE_HOST=127.0.0.1 9 | export WORKER_0_PORT=6000 10 | export NUM_WORKER=1 11 | export WORKER_RANK=0 12 | export GPU_PER_WORKER=8 13 | 14 | export BYTEPS_WITH_UCX=0 15 | export DMLC_ENABLE_UCX=0 16 | export DMLC_ENABLE_RDMA=0 17 | 18 | MASTER_PORT=6002 19 | MASTER_ADDR=$WORKER_0_HOST 20 | 21 | GPUS_PER_NODE=$GPU_PER_WORKER 22 | 23 | NNODES=$NUM_WORKER 24 | NODE_RANK=$WORKER_RANK 25 | 26 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 27 | 28 | base_dir=$(cd `dirname $0`; pwd) 29 | echo base_dir $base_dir 30 | 31 | DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" 32 | 33 | ds_config='{ 34 | "train_micro_batch_size_per_gpu":16, 35 | "train_batch_size" : 16, 36 | "gradient_accumulation_steps": 2, 37 | "steps_per_print": 1, 38 | "gradient_clipping": 1.0, 39 | "zero_optimization": { 40 | "stage": 0, 41 | "allgather_partitions": true, 42 | "allgather_bucket_size": 500000000, 43 | "overlap_comm": true, 44 | "reduce_scatter": true, 45 | "reduce_bucket_size": 500000000, 46 | "contiguous_gradients" : true, 47 | "cpu_offload": false 48 | }, 49 | "fp16": { 50 | "enabled": true, 51 | "loss_scale": 0, 52 | "loss_scale_window": 1000, 53 | "hysteresis": 2, 54 | "min_loss_scale": 1 55 | }, 56 | "wall_clock_breakdown": true 57 | }' 58 | 59 | python3 -m torch.distributed.launch $DISTRIBUTED_ARGS \ 60 | --no_python --use_env python3 \ 61 | ${base_dir}/pretrain_gpt2.py \ 62 | --model-parallel-size 2 \ 63 | --num-stages 2 \ 64 | --num-layers 24 \ 65 | --hidden-size 1024 \ 66 | --train-batch-size 64 \ 67 | --gradient_accumulation_steps 16 \ 68 | --num-attention-heads 16 \ 69 | --batch-size 4 \ 70 | --seq-length 1024 \ 71 | --max-position-embeddings 1024 \ 72 | --train-iters 500000 \ 73 | --lr-decay-iters 450000 \ 74 | --save $CHECKPOINT_PATH \ 75 | --load $CHECKPOINT_PATH \ 76 | --data-path $DATA_PATH/openwebtext-gpt2_text_document \ 77 | --vocab-file $DATA_PATH/gpt2-vocab.json \ 78 | --merge-file $DATA_PATH/gpt2-merges.txt \ 79 | --data-impl mmap \ 80 | --split 949,50,1 \ 81 | --distributed-backend nccl \ 82 | --lr 0.00025 \ 83 | --lr-decay-style cosine \ 84 | --min-lr 1.0e-5 \ 85 | --weight-decay 1e-2 \ 86 | --clip-grad 1.0 \ 87 | --warmup .02 \ 88 | --log-interval 1 \ 89 | --save-interval 100000 \ 90 | --vocab-size 145608 \ 91 | --DDP-impl torch \ 92 | --eod-mask-loss \ 93 | --deepspeed-pipeline \ 94 | --deepspeed \ 95 | --config_param "$ds_config" \ 96 | --fp16 \ 97 | --partition_method "type:ParallelTransformerLayerPiped" \ 98 | $@ 99 | set +x 100 | -------------------------------------------------------------------------------- /src/veGiantModel/launcher/launch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 
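# Launches the BytePS scheduler process on worker 0 and exports the DMLC_*/BYTEPS_*/UCX_*
# environment variables needed for multi-node communication before training starts.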
2 | #!/usr/bin/python 3 | 4 | from __future__ import print_function 5 | import os 6 | import subprocess 7 | import threading 8 | import sys 9 | from megatron import mpu 10 | from deepspeed.utils import log_dist 11 | import logging 12 | 13 | class PropagatingThread(threading.Thread): 14 | """ propagate exceptions to the parent's thread 15 | refer to https://stackoverflow.com/a/31614591/9601110 16 | """ 17 | 18 | def run(self): 19 | self.exc = None 20 | try: 21 | if hasattr(self, '_Thread__target'): 22 | # python 2.x 23 | self.ret = self._Thread__target( 24 | *self._Thread__args, **self._Thread__kwargs) 25 | else: 26 | # python 3.x 27 | self.ret = self._target(*self._args, **self._kwargs) 28 | except BaseException as e: 29 | self.exc = e 30 | 31 | def join(self): 32 | super(PropagatingThread, self).join() 33 | if self.exc: 34 | raise self.exc 35 | return self.exc 36 | 37 | def launch_scheduler(local_rank): 38 | if os.environ['WORKER_RANK'] != '0': 39 | return 40 | 41 | if local_rank != 0: 42 | return 43 | 44 | 45 | def scheduler_runner(): 46 | my_env = os.environ.copy() 47 | my_env['DMLC_ROLE'] = 'scheduler' 48 | my_env['PS_VERBOSE'] = os.environ.get('PS_VERBOSE', '1') 49 | nvidia_smi = f'nvidia-smi -L' 50 | devices = os.popen(nvidia_smi).read().strip() 51 | if 'A100' in devices: 52 | ip_cmd = f'ip addr show eth2' 53 | ip = os.popen(ip_cmd + ' | grep "\" | awk \'{ print $2 }\' | awk -F "/" \'{ print $1 }\'').read().strip() 54 | my_env['DMLC_NODE_HOST'] = ip 55 | my_env['UCX_RDMA_CM_SOURCE_ADDRESS'] = ip 56 | os.environ['UCX_NET_DEVICES'] = 'mlx5_2:1,eth0,eth1,eth2,eth3' 57 | 58 | command = "python3 -c 'import byteps.server'" 59 | subprocess.check_call(command, env=my_env, 60 | stdout=sys.stdout, stderr=sys.stderr, shell=True) 61 | t = PropagatingThread(target=scheduler_runner) 62 | t.setDaemon(True) 63 | t.start() 64 | 65 | def get_worker0_host(): 66 | host = os.environ['WORKER_0_HOST'] 67 | return host 68 | 69 | def get_worker0_port(): 70 | port = os.environ['WORKER_0_PORT'] 71 | return port 72 | 73 | def setup_env(local_rank): 74 | mp_size = mpu.get_model_parallel_world_size() 75 | 76 | num_nodes = int(os.environ['NUM_WORKER']) 77 | gpu_per_node = int(os.environ['GPU_PER_WORKER']) 78 | assert gpu_per_node >= mp_size 79 | assert gpu_per_node % mp_size == 0 80 | 81 | os.environ['BYTEPS_RDMA_START_DEPTH'] = str(32) 82 | os.environ['BYTEPS_RDMA_RX_DEPTH'] = str(512) 83 | 84 | os.environ['DMLC_NUM_WORKER'] = str(gpu_per_node * num_nodes) 85 | os.environ['DMLC_NUM_SERVER'] = str(gpu_per_node * num_nodes) 86 | 87 | os.environ['BYTEPS_LOCAL_SIZE'] = str(gpu_per_node) 88 | os.environ['BYTEPS_FORCE_DISTRIBUTED'] = '1' 89 | os.environ['BYTEPS_ENABLE_IPC'] = '0' 90 | os.environ['DMLC_PS_ROOT_PORT'] = get_worker0_port() 91 | os.environ['DMLC_PS_ROOT_URI'] = get_worker0_host() 92 | 93 | if 'DMLC_ENABLE_RDMA' not in os.environ: 94 | os.environ['DMLC_ENABLE_RDMA'] = '1' 95 | os.environ['DMLC_ENABLE_UCX'] = os.environ.get('DMLC_ENABLE_UCX', '1') 96 | os.environ['UCX_IB_TRAFFIC_CLASS'] = '236' 97 | os.environ['UCX_TLS'] = os.environ.get('UCX_TLS', 'rc_x,tcp,sm') 98 | nvidia_smi = f'nvidia-smi -L' 99 | devices = os.popen(nvidia_smi).read().strip() 100 | if 'A100' in devices: 101 | nic = 2 # TODO: use multiple NICs with `int(local_rank / 2)` 102 | ip_cmd = f'ip addr show eth{nic}' 103 | ip = os.popen(ip_cmd + ' | grep "\" | awk \'{ print $2 }\' | awk -F "/" \'{ print $1 }\'').read().strip() 104 | os.environ['UCX_RDMA_CM_SOURCE_ADDRESS'] = os.environ.get('UCX_RDMA_CM_SOURCE_ADDRESS', ip) 105 | devs = 
os.environ.get('UCX_NET_DEVICES', f'mlx5_{nic}:1,eth0,eth1,eth2,eth3') 106 | os.environ['UCX_NET_DEVICES'] = devs 107 | os.environ['DMLC_NODE_HOST'] = os.environ['UCX_RDMA_CM_SOURCE_ADDRESS'] 108 | elif 'V100' in devices or 'T4' in devices: 109 | devs = os.environ.get('UCX_NET_DEVICES', 'mlx5_2:1,eth0,eth2') 110 | os.environ['UCX_NET_DEVICES'] = devs 111 | else: 112 | raise RuntimeError(f"Unknown devices: {devices}") 113 | 114 | def launch_bps(local_rank): 115 | log_dist(f'launch_bps({local_rank})', ranks=[-1], level=logging.DEBUG) 116 | setup_env(local_rank) 117 | launch_scheduler(local_rank) -------------------------------------------------------------------------------- /src/veGiantModel/initialize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | import torch 4 | import os 5 | import random 6 | import numpy as np 7 | 8 | from megatron.global_vars import set_global_variables 9 | from megatron import get_args, mpu, print_rank_0 10 | from .engine.topology import PipeModelDataParallelTopology, PipelineParallelGrid 11 | from .launcher.launch import launch_bps 12 | from deepspeed.utils import log_dist 13 | import logging 14 | 15 | 16 | 17 | def add_byte_giant_model_customize_args(parser): 18 | import deepspeed 19 | parser = deepspeed.add_config_arguments(parser) 20 | group = parser.add_argument_group(title='bytedance') 21 | group.add_argument('--cpu-optimizer', action='store_true', 22 | help='Run optimizer on CPU') 23 | group.add_argument('--cpu_torch_adam', action='store_true', 24 | help='Use Torch Adam as optimizer on CPU.') 25 | group.add_argument('--vocab-size', type=int, default=1000, 26 | help='vocab size.') 27 | group.add_argument('--train-batch-size', type=int, default=0, 28 | help='global batch size') 29 | group.add_argument('--train_micro_batch_size_per_gpu', type=int, default=0, 30 | help='Batch size per model instance (for deepspeed). 
' 31 | 'Global batch size is local batch size times data ' 32 | 'parallel size.') 33 | group.add_argument('--deepspeed-activation-checkpointing', action='store_true', 34 | help='deepspeed_activation_checkpointing.') 35 | group.add_argument('--deepspeed-pipeline', action='store_true', 36 | help='enable pipeline parallelism via deepspeed.') 37 | group.add_argument('--ci', action='store_true', help="run in CI environment") 38 | group.add_argument('--gradient_accumulation_steps', type=int, default=1, 39 | help="set gradient_accumulation_steps for deepspeed config") 40 | group.add_argument('--train_batch_size', type=int, default=0, 41 | help="train_batch_size") 42 | group.add_argument('--broadcast_activation', action='store_true', help="use broadcast to send/recv activation") 43 | group.add_argument('--broadcast_grads', action='store_true', help="use broadcast to send/recv grads") 44 | group.add_argument('--partition_method', type=str, default='uniform', 45 | help='the method to partition layers in pipeline parallelism.') 46 | group.add_argument('--config_param', type=str, default='', 47 | help='json dict for deepspeed config') 48 | 49 | group.add_argument('--num-stages', type=int, default=1, 50 | help='number of stages') 51 | return parser 52 | 53 | def initialize_megatron(extra_args_provider=None, args_defaults={}): 54 | set_global_variables(extra_args_provider=add_byte_giant_model_customize_args, args_defaults=args_defaults) 55 | args = get_args() 56 | init_distribute(args.num_stages, args.model_parallel_size) 57 | _set_random_seed(args.seed) 58 | 59 | def _init_topology(num_stages, mp_size): 60 | num_pp = num_stages 61 | num_mp = mp_size 62 | num_dp = (torch.distributed.get_world_size() // num_pp) // num_mp 63 | log_dist('rank: {args.rank}, init topology with num_pp:{num_pp}, num_mp:{num_mp}, \ 64 | num_dp: {num_dp}', ranks=[-1], level=logging.DEBUG) 65 | topology = PipeModelDataParallelTopology(num_pp=num_pp, num_mp=num_mp, num_dp=num_dp) 66 | log_dist(f'finish building topology, topology.mapping: {topology.mapping}', \ 67 | ranks=[-1], level=logging.DEBUG) 68 | return PipelineParallelGrid(topology) 69 | 70 | def _set_random_seed(seed): 71 | """Set random seed for reproducability.""" 72 | if seed is not None and seed > 0: 73 | random.seed(seed) 74 | np.random.seed(seed) 75 | torch.manual_seed(seed) 76 | if torch.cuda.device_count() > 0: 77 | mpu.model_parallel_cuda_manual_seed(seed) 78 | else: 79 | raise ValueError('Seed ({}) should be a positive integer.'.format(seed)) 80 | 81 | def init_distribute(num_stages, mp_size, 82 | distributed_backend='nccl', init_method='tcp://'): 83 | rank = int(os.getenv('RANK', '0')) 84 | world_size = int(os.getenv("WORLD_SIZE", '1')) 85 | device_count = torch.cuda.device_count() 86 | local_rank = rank % device_count 87 | 88 | if torch.distributed.is_initialized(): 89 | print_rank_0('torch distributed is already initialized, ' 90 | 'skipping initialization ...') 91 | else: 92 | print_rank_0('> initializing torch distributed ...') 93 | 94 | torch.cuda.set_device(local_rank) 95 | # Call the init process 96 | master_ip = os.getenv('MASTER_ADDR', 'localhost') 97 | master_port = os.getenv('MASTER_PORT', '6000') 98 | init_method += master_ip + ':' + master_port 99 | torch.distributed.init_process_group( 100 | backend=distributed_backend, 101 | world_size=world_size, rank=rank, 102 | init_method=init_method) 103 | 104 | # Set the model-parallel / data-parallel communicators. 
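    # _init_topology splits the world size across the pipeline/model/data axes; for
    # example, a hypothetical WORLD_SIZE=8 with num_stages=2 and mp_size=2 gives
    # num_dp = 8 // 2 // 2 = 2.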
105 | grid = _init_topology(num_stages, mp_size) 106 | mpu.initialize_model_parallel(grid) 107 | if num_stages > 1: 108 | import byteps.torch as bps 109 | assert bps is not None 110 | launch_bps(local_rank) 111 | -------------------------------------------------------------------------------- /src/veGiantModel/patcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | 4 | import torch 5 | print("Loading veGiantModel submodules ...") 6 | 7 | _TOPOLOGY = None 8 | 9 | def is_unitialized(): 10 | """Useful for code segments that may be accessed with or without mpu initialization""" 11 | return _TOPOLOGY is None 12 | 13 | 14 | def initialize_model_parallel(grid): 15 | # Get world size and rank. Ensure some consistencies. 16 | assert torch.distributed.is_initialized() 17 | global _TOPOLOGY 18 | _TOPOLOGY = grid 19 | 20 | 21 | def model_parallel_is_initialized(): 22 | """Check if model and data parallel groups are initialized.""" 23 | if _TOPOLOGY is None: 24 | return False 25 | return True 26 | 27 | 28 | def get_model_parallel_group(): 29 | """Get the parallel group the caller rank belongs to.""" 30 | assert _TOPOLOGY is not None, \ 31 | ' parallel group is not initialized' 32 | return _TOPOLOGY.get_slice_parallel_group() 33 | 34 | 35 | def get_data_parallel_group(): 36 | """Get the data parallel group the caller rank belongs to.""" 37 | assert _TOPOLOGY is not None, \ 38 | 'data parallel group is not initialized' 39 | return _TOPOLOGY.get_data_parallel_group() 40 | 41 | 42 | def set_model_parallel_world_size(world_size): 43 | pass 44 | 45 | 46 | def get_model_parallel_world_size(): 47 | """Return world size for the model parallel group.""" 48 | return _TOPOLOGY.get_slice_parallel_world_size() 49 | 50 | 51 | def set_model_parallel_rank(rank): 52 | pass 53 | 54 | 55 | def get_model_parallel_rank(): 56 | """Return my rank for the model parallel group.""" 57 | return _TOPOLOGY.get_slice_parallel_rank() 58 | 59 | 60 | def get_model_parallel_src_rank(): 61 | return _TOPOLOGY.get_slice_parallel_src_rank() 62 | 63 | 64 | def get_data_parallel_world_size(): 65 | """Return world size for the data parallel group.""" 66 | return _TOPOLOGY.get_data_parallel_world_size() 67 | 68 | 69 | def get_data_parallel_rank(): 70 | """Return my rank for the data parallel group.""" 71 | return _TOPOLOGY.get_data_parallel_rank() 72 | 73 | def get_pipe_parallel_rank(): 74 | return _TOPOLOGY.get_pipe_parallel_rank() 75 | 76 | def destroy_model_parallel(): 77 | """Set the groups to none.""" 78 | global _TOPOLOGY 79 | _TOPOLOGY = None 80 | 81 | def get_grid(): 82 | return _TOPOLOGY 83 | 84 | def get_topo(): 85 | return _TOPOLOGY.topology() 86 | 87 | import megatron.mpu.initialize as initialize 88 | initialize.is_unitialized = is_unitialized 89 | initialize.initialize_model_parallel = initialize_model_parallel 90 | initialize.model_parallel_is_initialized = model_parallel_is_initialized 91 | initialize.get_model_parallel_group = get_model_parallel_group 92 | initialize.get_data_parallel_group = get_data_parallel_group 93 | initialize.set_model_parallel_world_size = set_model_parallel_world_size 94 | initialize.get_model_parallel_world_size = get_model_parallel_world_size 95 | initialize.set_model_parallel_rank = set_model_parallel_rank 96 | initialize.get_model_parallel_rank = get_model_parallel_rank 97 | initialize.get_model_parallel_src_rank = 
get_model_parallel_src_rank 98 | initialize.get_data_parallel_world_size = get_data_parallel_world_size 99 | initialize.get_data_parallel_rank = get_data_parallel_rank 100 | initialize.get_pipe_parallel_rank = get_pipe_parallel_rank 101 | initialize.destroy_model_parallel = destroy_model_parallel 102 | 103 | from megatron import mpu 104 | from importlib import reload 105 | reload(mpu.data) 106 | reload(mpu.mappings) 107 | reload(mpu.cross_entropy) 108 | mpu.get_pipe_parallel_rank = get_pipe_parallel_rank 109 | reload(mpu) 110 | 111 | from megatron.mpu import mappings 112 | 113 | def _gather(input_): 114 | """Gather tensors and concatinate along the last dimension.""" 115 | 116 | world_size = get_model_parallel_world_size() 117 | # Bypass the function if we are using only 1 GPU. 118 | if world_size==1: 119 | return input_ 120 | 121 | # Size and dimension. 122 | last_dim = input_.dim() - 1 123 | rank = get_model_parallel_rank() 124 | 125 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 126 | tensor_list[rank] = input_ 127 | group = get_model_parallel_group() 128 | torch.distributed.all_gather(tensor_list, input_, group=group) 129 | 130 | # Note: torch.cat already creates a contiguous tensor. 131 | output = torch.cat(tensor_list, dim=last_dim).contiguous() 132 | 133 | return output 134 | 135 | mappings._gather = _gather 136 | 137 | from megatron.tokenizer import tokenizer as token 138 | from megatron.tokenizer.tokenizer import _BertWordPieceTokenizer, _vocab_size_with_padding, _GPT2BPETokenizer 139 | 140 | def build_tokenizer(args): 141 | if args.vocab_file is None: 142 | args.padded_vocab_size = _vocab_size_with_padding(args.vocab_size, 143 | args) 144 | return None 145 | """Initialize tokenizer.""" 146 | if args.rank == 0: 147 | print('> building {} tokenizer ...'.format(args.tokenizer_type), 148 | flush=True) 149 | 150 | # Select and instantiate the tokenizer. 151 | assert args.vocab_file is not None 152 | if args.tokenizer_type == 'BertWordPieceLowerCase': 153 | tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, 154 | lower_case=True) 155 | elif args.tokenizer_type == 'BertWordPieceCase': 156 | tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, 157 | lower_case=False) 158 | elif args.tokenizer_type == 'GPT2BPETokenizer': 159 | assert args.merge_file is not None 160 | tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) 161 | else: 162 | raise NotImplementedError('{} tokenizer is not ' 163 | 'implemented.'.format(args.tokenizer_type)) 164 | 165 | # Add vocab size. 
166 | args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, 167 | args) 168 | 169 | return tokenizer 170 | 171 | token.build_tokenizer = build_tokenizer 172 | import megatron 173 | reload(megatron.tokenizer) 174 | reload(megatron.global_vars) 175 | reload(megatron.global_vars) 176 | print("veGiantModel loaded.") 177 | -------------------------------------------------------------------------------- /examples/gpt/initialize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import veGiantModel 4 | 5 | from megatron import get_args, mpu 6 | from megatron.fp16 import FP16_Module 7 | from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP 8 | from megatron.model import DistributedDataParallel as LocalDDP 9 | from megatron.model import get_params_for_weight_decay_optimization 10 | from apex.optimizers import FusedAdam as Adam 11 | from megatron.learning_rates import AnnealingLR 12 | from megatron import print_rank_0 13 | 14 | 15 | def get_learning_rate_scheduler(optimizer, lr_scheduler_builder): 16 | """Build the learning rate scheduler.""" 17 | args = get_args() 18 | 19 | 20 | if lr_scheduler_builder is not None: 21 | lr_scheduler = lr_scheduler_builder(optimizer) 22 | else: 23 | # Add linear learning rate scheduler. 24 | if args.lr_decay_iters is not None: 25 | num_iters = args.lr_decay_iters 26 | else: 27 | num_iters = args.train_iters 28 | num_iters = max(1, num_iters) 29 | init_step = 0 30 | warmup_iter = args.warmup * num_iters 31 | lr_scheduler = AnnealingLR( 32 | optimizer, 33 | start_lr=args.lr, 34 | warmup_iter=warmup_iter, 35 | total_iters=num_iters, 36 | decay_style=args.lr_decay_style, 37 | last_iter=init_step, 38 | min_lr=args.min_lr, 39 | use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, 40 | override_lr_scheduler=args.override_lr_scheduler) 41 | 42 | return lr_scheduler 43 | 44 | 45 | def get_model(model_provider_func): 46 | """Build the model.""" 47 | args = get_args() 48 | 49 | # Build model on cpu. 50 | model = model_provider_func() 51 | 52 | # Print number of parameters. 53 | if mpu.get_data_parallel_rank() == 0: 54 | print(' > number of parameters on model parallel rank {}: {}'.format( 55 | mpu.get_model_parallel_rank(), 56 | sum([p.nelement() for p in model.parameters()])), flush=True) 57 | 58 | # GPU allocation. 59 | model.cuda(torch.cuda.current_device()) 60 | 61 | return model 62 | 63 | def get_optimizer(model): 64 | """Set up the optimizer.""" 65 | args = get_args() 66 | 67 | # Build parameter groups (weight decay and non-decay). 68 | while isinstance(model, (torchDDP, LocalDDP, FP16_Module)): 69 | model = model.module 70 | param_groups = get_params_for_weight_decay_optimization(model) 71 | 72 | # Add model parallel attribute if it is not set. 73 | for param_group in param_groups: 74 | for param in param_group['params']: 75 | if not hasattr(param, 'model_parallel'): 76 | param.model_parallel = False 77 | 78 | if args.cpu_optimizer: 79 | if args.cpu_torch_adam: 80 | cpu_adam_optimizer = torch.optim.Adam 81 | else: 82 | from deepspeed.ops.adam import DeepSpeedCPUAdam 83 | cpu_adam_optimizer = DeepSpeedCPUAdam 84 | optimizer = cpu_adam_optimizer(param_groups, 85 | lr=args.lr, weight_decay=args.weight_decay) 86 | else: 87 | # Use Adam. 88 | optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay) 89 | 90 | if args.deepspeed: 91 | # fp16 wrapper is not required for DeepSpeed. 92 | return optimizer 93 | 94 | # Wrap into fp16 optimizer. 
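    # (Only reached when DeepSpeed is disabled; with --deepspeed the engine manages
    #  fp16 itself, as the early return above notes.)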
95 | if args.fp16: 96 | optimizer = FP16_Optimizer(optimizer, 97 | static_loss_scale=args.loss_scale, 98 | dynamic_loss_scale=args.dynamic_loss_scale, 99 | dynamic_loss_args={ 100 | 'scale_window': args.loss_scale_window, 101 | 'min_scale': args.min_scale, 102 | 'delayed_shift': args.hysteresis}, 103 | fp16_optim=args.fp16_optim) 104 | 105 | return optimizer 106 | 107 | def setup_model_and_optimizer(model, optimizer, train_dataset_provider, lr_scheduler_builder): 108 | """Setup model and optimizer.""" 109 | args = get_args() 110 | if optimizer is None: 111 | optimizer = get_optimizer(model) 112 | lr_scheduler = get_learning_rate_scheduler(optimizer, lr_scheduler_builder) 113 | 114 | print_rank_0("DeepSpeed is enabled.") 115 | 116 | # Print number of parameters. 117 | if mpu.get_data_parallel_rank() == 0: 118 | print(' > number of parameters on data parallel rank {}, model parallel rank {}, pipeline parallel rank {}: {}'.format( 119 | mpu.get_data_parallel_rank(), 120 | mpu.get_model_parallel_rank(), 121 | mpu.get_pipe_parallel_rank(), 122 | sum([p.nelement() for p in model.parameters()])), flush=True) 123 | 124 | if args.deepspeed_pipeline: 125 | print_rank_0("Pipeline Parallelism is enabled.") 126 | train_data = train_dataset_provider() if train_dataset_provider is not None else None 127 | _param_dict = json.loads(args.config_param) 128 | engine, optimizer, _, lr_scheduler = veGiantModel.initialize( 129 | model=model, 130 | optimizer=optimizer, 131 | args=args, 132 | lr_scheduler=lr_scheduler, 133 | mpu=None, 134 | dist_init_required=False, 135 | config_params = _param_dict, 136 | training_data=train_data 137 | ) 138 | engine.set_batch_fn(model.batch_fn) 139 | else: 140 | engine, optimizer, _, lr_scheduler = veGiantModel.initialize( 141 | model=model, 142 | optimizer=optimizer, 143 | args=args, 144 | lr_scheduler=lr_scheduler, 145 | mpu=mpu, 146 | dist_init_required=False 147 | ) 148 | 149 | print_rank_0("Model Preparation Done") 150 | args.iteration = 0 151 | 152 | return engine, optimizer, lr_scheduler 153 | 154 | 155 | def initialize_pipeline(model, optimizer, train_dataset_provider, lr_scheduler_builder=None): 156 | return setup_model_and_optimizer(model, optimizer, train_dataset_provider, lr_scheduler_builder) 157 | 158 | 159 | def initialize_distributed(num_stages, mp_size, distributed_backend='nccl'): 160 | veGiantModel.init_distribute(num_stages=num_stages, mp_size=mp_size, distributed_backend=distributed_backend) 161 | 162 | def initialize_megatron(extra_args_provider=None, args_defaults={}): 163 | veGiantModel.initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults) 164 | -------------------------------------------------------------------------------- /examples/gpt/gpt_piped.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from megatron import get_args, mpu 4 | 5 | from megatron.model.language_model import parallel_lm_logits, Embedding 6 | from megatron.model.transformer import ParallelTransformerLayer 7 | from megatron.model.transformer import LayerNorm 8 | from megatron.model.gpt2_model import gpt2_attention_mask_func 9 | from megatron.model.utils import init_method_normal 10 | from megatron.model.utils import scaled_init_method_normal 11 | from megatron.module import MegatronModule 12 | from megatron.utils import get_ltor_masks_and_position_ids 13 | 14 | from deepspeed.pipe import LayerSpec, TiedLayerSpec 15 | from megatron import get_tokenizer 16 | from veGiantModel.engine.module 
import VeGiantModule 17 | import veGiantModel 18 | 19 | class GPTModelPiped(VeGiantModule): 20 | def __init__(self): 21 | args = get_args() 22 | self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy 23 | self.tokenizer = get_tokenizer() 24 | self.parallel_output = True 25 | 26 | self.num_layers = args.num_layers 27 | self.hidden_size = args.hidden_size 28 | 29 | self.init_method = init_method_normal(args.init_method_std) 30 | self.scale_init_method = scaled_init_method_normal(args.init_method_std, 31 | args.num_layers) 32 | 33 | self.num_tokentypes = 0 34 | 35 | layers = [] 36 | layers.append(lambda x: self._get_batch(x)) 37 | layers.append(TiedLayerSpec("SharedEmbedding", 38 | EmbeddingPiped, 39 | self.hidden_size, 40 | args.padded_vocab_size, 41 | args.max_position_embeddings, 42 | args.hidden_dropout, 43 | self.init_method, 44 | self.num_tokentypes, 45 | tied_weight_attr='embedding_weight')) 46 | 47 | layers.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1])) 48 | 49 | for i in range(self.num_layers): 50 | layers.append(LayerSpec(ParallelTransformerLayerPiped, 51 | gpt2_attention_mask_func, 52 | self.init_method, 53 | self.scale_init_method, 54 | i+1)) 55 | 56 | layers.append(lambda x: (x[0].transpose(0, 1).contiguous())) 57 | 58 | layers.append(LayerSpec(LayerNorm, args.hidden_size, eps=args.layernorm_epsilon)) 59 | 60 | layers.append(TiedLayerSpec("SharedEmbedding", 61 | LMLogitsPiped, 62 | self.hidden_size, 63 | args.padded_vocab_size, 64 | self.init_method, 65 | tied_weight_attr='embedding_weight')) 66 | 67 | super().__init__(layers=layers, 68 | num_stages = args.num_stages, 69 | partition_method=args.partition_method, 70 | grid=veGiantModel.distributed.get_grid(), 71 | loss_fn=self.loss_fn) 72 | 73 | 74 | # Data Preprocessing, copied from pretrain_gpt2.py 75 | def _get_batch(self, data): 76 | """Generate a batch""" 77 | args = get_args() 78 | # Unpack. 
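    # `data` arrives as the token-id tensor prepared by batch_fn; the attention mask
    # and position ids are derived from it below.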
79 | tokens = data 80 | 81 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 82 | tokens, 83 | self.tokenizer.eod, 84 | args.reset_position_ids, 85 | args.reset_attention_mask, 86 | args.eod_mask_loss) 87 | 88 | return (tokens.to(device="cuda"), 89 | position_ids.to(device="cuda"), 90 | attention_mask.to(device="cuda")) 91 | 92 | def loss_fn(self, inputs, data): 93 | tokens = data[0] 94 | target = data[1] 95 | args = get_args() 96 | _, loss_mask, _ = get_ltor_masks_and_position_ids( 97 | tokens, 98 | self.tokenizer.eod, 99 | args.reset_position_ids, 100 | args.reset_attention_mask, 101 | args.eod_mask_loss) 102 | 103 | if self.fp16_lm_cross_entropy: 104 | assert inputs.dtype == torch.half 105 | loss = mpu.vocab_parallel_cross_entropy(inputs, target) 106 | else: 107 | loss = mpu.vocab_parallel_cross_entropy(inputs.float(), target) 108 | loss_mask = loss_mask.view(-1) 109 | loss_avg = torch.sum(loss.view(-1) * loss_mask) / loss_mask.sum() 110 | if loss.dtype == torch.half: 111 | loss_avg = loss_avg.half() 112 | 113 | return loss_avg 114 | 115 | def batch_fn(self, batch, is_train:bool): 116 | if batch is not None: 117 | data = {'text': torch.tensor(batch['text'].numpy())} 118 | else: 119 | data = None 120 | 121 | keys = ['text'] 122 | datatype = torch.int64 123 | 124 | data_b = mpu.broadcast_data(keys, data, datatype) 125 | 126 | tokens_ = data_b['text'].long() 127 | tokens_write = tokens_ 128 | labels = tokens_[:, 1:].contiguous() 129 | tokens_ = tokens_[:, :-1].contiguous() 130 | tokens_2 = torch.unsqueeze(tokens_, 0) 131 | data2 = torch.cat((tokens_2, labels[None, :, :]), dim=0) 132 | data = [] 133 | data.append(tokens_) 134 | data.append(data2) 135 | return data 136 | 137 | class LMLogitsPiped(MegatronModule): 138 | def __init__(self, hidden_size, vocab_size, init_method): 139 | super().__init__() 140 | self.word_embeddings = mpu.VocabParallelEmbedding( 141 | vocab_size, hidden_size, init_method=init_method) 142 | self.embedding_weight = self.word_embeddings.weight 143 | 144 | def forward(self, lm_output): 145 | return parallel_lm_logits(lm_output, self.embedding_weight, True) 146 | 147 | 148 | class EmbeddingPiped(Embedding): 149 | def __init__(self, 150 | hidden_size, 151 | vocab_size, 152 | max_sequence_length, 153 | embedding_dropout_prob, 154 | init_method, 155 | num_tokentypes=0): 156 | super().__init__(hidden_size, 157 | vocab_size, 158 | max_sequence_length, 159 | embedding_dropout_prob, 160 | init_method, 161 | num_tokentypes) 162 | self.embedding_weight = self.word_embeddings.weight 163 | 164 | def forward(self, inputs): 165 | input_ids, position_ids, attention_mask = inputs 166 | return super().forward(input_ids, position_ids, None), attention_mask 167 | 168 | class ParallelTransformerLayerPiped(ParallelTransformerLayer): 169 | def __init__(self, 170 | attention_mask_func, 171 | init_method, 172 | output_layer_init_method, 173 | layer_number): 174 | super().__init__(attention_mask_func, 175 | init_method, 176 | output_layer_init_method, 177 | layer_number) 178 | 179 | def forward(self, inputs): 180 | hidden_states, attention_mask = inputs 181 | return (super().forward(hidden_states, attention_mask), 182 | attention_mask) -------------------------------------------------------------------------------- /src/veGiantModel/engine/p2p.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 
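# Point-to-point helpers for pipeline parallelism: activations and gradients are
# exchanged between adjacent stages either through torch.distributed broadcast groups
# or through BytePS asynchronous send/recv.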
2 | # Copyright 2019 The Microsoft DeepSpeed Team 3 | ''' 4 | Copyright 2019 The Microsoft DeepSpeed Team 5 | ''' 6 | 7 | import os 8 | import torch 9 | import torch.distributed as dist 10 | from deepspeed.utils import logger, log_dist 11 | 12 | ENABLE_PYTORCH_BROADCAST = os.environ.get("ENABLE_PYTORCH_BROADCAST", "0") != "0" 13 | 14 | try: 15 | if not ENABLE_PYTORCH_BROADCAST: 16 | import byteps.torch as bps 17 | else: 18 | print("BytePS import is disabled", flush=True) 19 | bps = None 20 | except ImportError: 21 | print("BytePS is not installed") 22 | bps = None 23 | 24 | _groups = None 25 | _grid = None 26 | 27 | DS_PIPE_VERBOSE = os.environ.get('DS_PIPE_VERBOSE', "0") != "0" 28 | 29 | did_recv = False 30 | send_stream = None 31 | recv_stream = None 32 | 33 | 34 | bps_send_handles = {} 35 | bps_recv_handles = {} 36 | 37 | 38 | #initializes adjacent process groups 39 | #run this only after torch.distributed.init_process_group() has been called 40 | def init_process_groups(grid): 41 | global _groups, _grid 42 | _grid = grid 43 | 44 | assert _grid.pipe_parallel_size > 1, "There is no model parallelism" 45 | 46 | _groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups] 47 | 48 | 49 | def _is_valid_send_recv(src_stage, dest_stage): 50 | first_stage = 0 51 | last_stage = _grid.pipe_parallel_size - 1 52 | assert abs(src_stage-dest_stage) == 1 or \ 53 | (src_stage == first_stage and dest_stage == last_stage) or \ 54 | (src_stage == last_stage and dest_stage == first_stage), \ 55 | "Functionality currently limited to send and receive between adjacent ranks only" 56 | 57 | 58 | def send(tensor, dest_stage, async_op=False): 59 | global _groups 60 | 61 | async_op = False 62 | src_stage = _grid.get_stage_id() 63 | _is_valid_send_recv(src_stage, dest_stage) 64 | 65 | group = _get_send_recv_group(src_stage, dest_stage) 66 | src_rank = _grid.stage_to_global(stage_id=src_stage) 67 | 68 | import torch 69 | if tensor.dtype != torch.float32 and DS_PIPE_VERBOSE: 70 | print('warning: p2p send', tensor.dtype, tensor.shape, flush=True) 71 | return _send(tensor, src_rank, group, async_op) 72 | 73 | def _bps_get_name(src, dest, name, suffix): 74 | return "_".join([str(src), str(dest), str(name), str(suffix)]) 75 | 76 | def bps_send(tensor, dest_stage, name, index, async_op=True): 77 | global bps_send_handles 78 | 79 | src_stage = _grid.get_stage_id() 80 | _is_valid_send_recv(src_stage, dest_stage) 81 | src_rank = _grid.stage_to_global(stage_id=src_stage) 82 | dest_rank = _grid.stage_to_global(stage_id=dest_stage) 83 | name = _bps_get_name(src_rank, dest_rank, name, index) 84 | if name not in bps_send_handles: 85 | # XXX hard-code max number of tensors for this name 86 | bps_send_handles[name] = [None] * 10 87 | else: 88 | handle = bps_send_handles[name][index] 89 | if handle is not None: 90 | bps.synchronize(handle) 91 | handle = bps.send_async(tensor, dest_rank, name=name) 92 | # XXX 93 | if not async_op: 94 | bps.synchronize(handle) 95 | else: 96 | bps_send_handles[name][index] = handle 97 | return tensor 98 | 99 | def bps_sync(src_stage, name, index=0): 100 | dest_stage = _grid.get_stage_id() 101 | _is_valid_send_recv(src_stage, dest_stage) 102 | src_rank = _grid.stage_to_global(stage_id=src_stage) 103 | dest_rank = _grid.stage_to_global(stage_id=dest_stage) 104 | name = _bps_get_name(src_rank, dest_rank, name, index) 105 | if name in bps_recv_handles: 106 | handle = bps_recv_handles[name][index] 107 | if handle is not None: 108 | bps.synchronize(handle) 109 | 110 | def bps_sync_all(): 111 | for 
name, handles in bps_send_handles.items(): 112 | for handle in handles: 113 | if handle is not None: 114 | bps.synchronize(handle) 115 | 116 | for name, handles in bps_recv_handles.items(): 117 | for handle in handles: 118 | if handle is not None: 119 | bps.synchronize(handle) 120 | 121 | def bps_recv(tensor, src_stage, name, index=0, async_op=True): 122 | global bps_recv_handles 123 | 124 | dest_stage = _grid.get_stage_id() 125 | _is_valid_send_recv(src_stage, dest_stage) 126 | src_rank = _grid.stage_to_global(stage_id=src_stage) 127 | dest_rank = _grid.stage_to_global(stage_id=dest_stage) 128 | name = _bps_get_name(src_rank, dest_rank, name, index) 129 | if name not in bps_recv_handles: 130 | # XXX hard-code max number of tensors for this name 131 | bps_recv_handles[name] = [None] * 10 132 | else: 133 | handle = bps_recv_handles[name][index] 134 | if handle is not None: 135 | bps.synchronize(handle) 136 | handle = bps.recv_async(tensor, src_rank, name=name) 137 | if not async_op: 138 | bps.synchronize(handle) 139 | else: 140 | bps_recv_handles[name][index] = handle 141 | return tensor 142 | 143 | 144 | def _send(tensor, src_rank, group, async_op): 145 | global did_recv 146 | return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) 147 | 148 | def send_grads(tensor, grid, async_op=False): 149 | async_op = False 150 | if grid.send_grads_src_rank == grid.global_rank: 151 | # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}, send_grad_groups: {grid.send_grads_proc_group}', flush=True) 152 | _send(tensor, grid.send_grads_src_rank, grid.send_grads_proc_group, async_op) 153 | # print(f'finis rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True) 154 | else: 155 | # print(f'finish fast rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True) 156 | pass 157 | 158 | def _recv(tensor, src_rank, group, async_op): 159 | global did_recv 160 | tensor = dist.broadcast(tensor, src_rank, group=group, async_op=async_op) 161 | did_recv = True 162 | return tensor 163 | 164 | def recv_grads(tensor, grid, async_op=False): 165 | async_op = False 166 | # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}, recv_grad_groups: {grid.recv_grads_proc_group}', flush=True) 167 | _recv(tensor, grid.recv_grads_src_rank, grid.recv_grads_proc_group, async_op) 168 | # print(f'finish rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}', flush=True) 169 | 170 | 171 | def send_activations(tensor, grid, async_op=False): 172 | async_op = False 173 | if grid.send_activation_src_rank == grid.global_rank: 174 | # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}, send_grad_groups: {grid.send_grads_proc_group}', flush=True) 175 | _send(tensor, grid.send_activation_src_rank, grid.send_activation_proc_group, async_op) 176 | # print(f'finis rank: 
{grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True) 177 | else: 178 | # print(f'finish fast rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True) 179 | pass 180 | 181 | def recv_activations(tensor, grid, async_op=False): 182 | async_op = False 183 | # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}, recv_grad_groups: {grid.recv_grads_proc_group}', flush=True) 184 | _recv(tensor, grid.recv_activation_src_rank, grid.recv_activation_proc_group, async_op) 185 | # print(f'finish rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}', flush=True) 186 | 187 | def recv(tensor, src_stage, async_op=False): 188 | global _groups 189 | global did_recv 190 | 191 | async_op = False 192 | dest_stage = _grid.get_stage_id() 193 | _is_valid_send_recv(src_stage, dest_stage) 194 | 195 | group = _get_send_recv_group(src_stage, dest_stage) 196 | src_rank = _grid.stage_to_global(stage_id=src_stage) 197 | return _recv(tensor, src_rank, group, async_op) 198 | 199 | 200 | def barrier(stage_id): 201 | global _groups, _grid 202 | group_id = _grid.stage_to_global(stage_id=stage_id) 203 | if (dist.get_rank() >= 0): 204 | print("Barrier Group ID", group_id) 205 | print("Barrier Group", _grid.p2p_groups[group_id]) 206 | dist.barrier(group=_groups[group_id]) 207 | if (dist.get_rank() >= 0): 208 | print("Exiting Barrier ", group_id) 209 | 210 | 211 | def _get_send_recv_group(src_stage, dest_stage): 212 | '''the group id is always the smaller rank unless its a wrap around''' 213 | 214 | stage_id = None 215 | 216 | first_stage = 0 217 | last_stage = _grid.pipe_parallel_size - 1 218 | 219 | if (src_stage == first_stage and dest_stage == last_stage 220 | or dest_stage == first_stage and src_stage == last_stage): 221 | stage_id = last_stage 222 | elif src_stage > dest_stage: 223 | stage_id = dest_stage 224 | else: 225 | stage_id = src_stage 226 | '''group_id corresponds to group of [group_id, group_id+1] 227 | unless group_id is the rank of the last stage 228 | in which case group_id correspods to group[group_id-num_stages+1, group_id] 229 | ''' 230 | group_id = _grid.stage_to_global(stage_id=stage_id) 231 | 232 | return _groups[group_id] 233 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/veGiantModel/engine/schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | from deepspeed.runtime.pipe.schedule import ( 3 | BufferOpInstruction,PipeInstruction, 4 | ReduceTiedGrads,ReduceGrads,OptimizerStep, 5 | LoadMicroBatch,PipeSchedule,TrainSchedule, 6 | ) 7 | 8 | import os 9 | 10 | BYTEPS_REDUCED_MEM = os.environ.get('BYTEPS_REDUCED_MEM', '1') != '0' 11 | 12 | class BytePSInferenceSchedule(PipeSchedule): 13 | """A schedule for inferencing batches using pipeline parallelism. 14 | """ 15 | def __init__(self, micro_batches, stages, stage_id, prefetch=True): 16 | super().__init__(micro_batches, stages, stage_id) 17 | self.prefetch = prefetch 18 | 19 | def steps(self): 20 | """""" 21 | total_steps = self.micro_batches + self.stages - 1 22 | for step_id in range(total_steps): 23 | cmds = [] 24 | micro_batch_id = step_id - self.stage_id 25 | 26 | buffer_id = micro_batch_id % self.num_pipe_buffers() 27 | batch_is_valid = self._valid_micro_batch(micro_batch_id) 28 | 29 | if not self.prefetch: 30 | if batch_is_valid: 31 | if self.is_first_stage or self.is_last_stage: 32 | cmds.append(LoadMicroBatch(buffer_id)) 33 | if self._valid_stage(self.prev_stage): 34 | cmds.append(BytePSRecvActivation(buffer_id)) 35 | cmds.append(BytePSSyncActivation(buffer_id)) 36 | cmds.append(BytePSForwardPass(buffer_id)) 37 | if self._valid_stage(self.next_stage): 38 | cmds.append(BytePSSendActivation(buffer_id)) 39 | else: 40 | next_buffer_id = (micro_batch_id + 1) % self.num_pipe_buffers() 41 | next_batch_is_valid = self._valid_micro_batch(micro_batch_id + 1) 42 | # micro_batch starts at 0. Get the current batch, and start prefetching 43 | if micro_batch_id == 0: 44 | if self.is_first_stage or self.is_last_stage: 45 | cmds.append(LoadMicroBatch(buffer_id)) 46 | if self._valid_stage(self.prev_stage): 47 | cmds.append(BytePSRecvActivation(buffer_id)) 48 | if next_batch_is_valid: 49 | cmds.append(BytePSRecvActivation(next_buffer_id)) 50 | cmds.append(BytePSSyncActivation(buffer_id)) 51 | cmds.append(BytePSForwardPass(buffer_id)) 52 | if self._valid_stage(self.next_stage): 53 | cmds.append(BytePSSendActivation(buffer_id)) 54 | elif batch_is_valid: 55 | # After micro_batch 0, we prefetch the next one, 56 | # and wait for the current one 57 | if self._valid_stage(self.prev_stage) and next_batch_is_valid: 58 | cmds.append(BytePSRecvActivation(next_buffer_id)) 59 | if self.is_first_stage or self.is_last_stage: 60 | cmds.append(LoadMicroBatch(buffer_id)) 61 | if self._valid_stage(self.prev_stage): 62 | cmds.append(BytePSSyncActivation(buffer_id)) 63 | cmds.append(BytePSForwardPass(buffer_id)) 64 | if self._valid_stage(self.next_stage): 65 | cmds.append(BytePSSendActivation(buffer_id)) 66 | 67 | yield cmds 68 | 69 | def num_pipe_buffers(self): 70 | """Only `self.micro_batches` pipeline buffers are required for inferencing. 
71 | 72 | Returns: 73 | ``self.micro_batches`` 74 | """ 75 | buffers = min(self.micro_batches, self.stages * 2) 76 | if BYTEPS_REDUCED_MEM: 77 | buffers = min(self.stages + 1, self.micro_batches) 78 | return max(2, buffers) 79 | 80 | 81 | class BytePSTrainSchedule(TrainSchedule): 82 | """A schedule for training a batch using hybrid parallelism. 83 | 84 | Pipeline parallelism is extracted through gradient accumulation and thus 85 | convergence follows that of a data parallel approach with the same batch 86 | size. 87 | """ 88 | def __init__(self, micro_batches, stages, stage_id, prefetch=True): 89 | super().__init__(micro_batches, stages, stage_id) 90 | self.prefetch = prefetch and micro_batches > 1 91 | if not self.prefetch: 92 | print('BYTEPS NO PREFETCH STEPS', flush=True) 93 | 94 | def steps(self): 95 | if self.prefetch: 96 | return self._steps() 97 | else: 98 | return self._steps_no_prefetch() 99 | 100 | def _steps(self): 101 | """""" 102 | total_steps = 2 * (self.micro_batches + self.stages - 1) 103 | for step_id in range(total_steps): 104 | # Map the step of the pipeline to the micro-batch id and also whether it is a 105 | # forward or backward pass step. 106 | cmds = [] 107 | micro_batch_id, is_forward = self._step_to_micro_batch(step_id) 108 | batch_is_valid = self._valid_micro_batch(micro_batch_id) 109 | if not batch_is_valid: 110 | if step_id == total_steps - 1: 111 | cmds.append(BytePSSyncAll()) 112 | cmds.append(ReduceTiedGrads()) 113 | cmds.append(ReduceGrads()) 114 | cmds.append(OptimizerStep()) 115 | yield cmds 116 | continue 117 | curr_buffer = self._buffer_idx(micro_batch_id) 118 | 119 | # try to find the next valid batch 120 | next_step_id = step_id + 1 121 | next_micro_batch_id, next_is_forward, next_batch_is_valid = None, None, None 122 | while next_step_id < total_steps: 123 | next_micro_batch_id, next_is_forward = self._step_to_micro_batch(next_step_id) 124 | next_batch_is_valid = self._valid_micro_batch(next_micro_batch_id) 125 | if next_batch_is_valid: 126 | break 127 | next_step_id += 1 128 | 129 | next_buffer = None 130 | if next_batch_is_valid: 131 | next_buffer = self._buffer_idx(next_micro_batch_id) 132 | 133 | if micro_batch_id == 0 and is_forward: 134 | # first/last stage loads 135 | if self.stage_id == 0 or self.stage_id == self.stages - 1: 136 | cmds.append(LoadMicroBatch(curr_buffer)) 137 | # fetch 138 | if self._valid_stage(self.prev_stage): 139 | cmds.append(BytePSRecvActivation(curr_buffer)) 140 | # pre-fetch 141 | if next_batch_is_valid: 142 | if self._valid_stage(self.prev_stage) and next_is_forward: 143 | cmds.append(BytePSRecvActivation(next_buffer)) 144 | if self._valid_stage(self.next_stage) and not next_is_forward: 145 | cmds.append(BytePSRecvGrad(next_buffer)) 146 | # sync and compute 147 | if self._valid_stage(self.prev_stage): 148 | cmds.append(BytePSSyncActivation(curr_buffer)) 149 | cmds.append(BytePSForwardPass(curr_buffer)) 150 | if self._valid_stage(self.next_stage): 151 | cmds.append(BytePSSendActivation(curr_buffer)) 152 | else: 153 | # prefetch 154 | if next_batch_is_valid: 155 | if self._valid_stage(self.prev_stage) and next_is_forward: 156 | cmds.append(BytePSRecvActivation(next_buffer)) 157 | if self._valid_stage(self.next_stage) and not next_is_forward: 158 | cmds.append(BytePSRecvGrad(next_buffer)) 159 | if is_forward: 160 | if self.stage_id == 0 or self.stage_id == self.stages - 1: 161 | # First/last stage loads 162 | cmds.append(LoadMicroBatch(curr_buffer)) 163 | if self._valid_stage(self.prev_stage): 164 | 
cmds.append(BytePSSyncActivation(curr_buffer)) 165 | cmds.append(BytePSForwardPass(curr_buffer)) 166 | if self._valid_stage(self.next_stage): 167 | cmds.append(BytePSSendActivation(curr_buffer)) 168 | else: 169 | if self._valid_stage(self.next_stage): 170 | cmds.append(BytePSSyncGrad(curr_buffer)) 171 | cmds.append(BytePSBackwardPass(curr_buffer)) 172 | if self._valid_stage(self.prev_stage): 173 | cmds.append(BytePSSendGrad(curr_buffer)) 174 | 175 | # Model step at the end of the batch 176 | if step_id == total_steps - 1: 177 | cmds.append(BytePSSyncAll()) 178 | cmds.append(ReduceTiedGrads()) 179 | cmds.append(ReduceGrads()) 180 | cmds.append(OptimizerStep()) 181 | 182 | yield cmds 183 | 184 | def _steps_no_prefetch(self): 185 | """""" 186 | total_steps = 2 * (self.micro_batches + self.stages - 1) 187 | for step_id in range(total_steps): 188 | # Map the step of the pipeline to the micro-batch id and also whether it is a 189 | # forward or backward pass step. 190 | cmds = [] 191 | micro_batch_id, is_forward = self._step_to_micro_batch(step_id) 192 | batch_is_valid = self._valid_micro_batch(micro_batch_id) 193 | if not batch_is_valid: 194 | if step_id == total_steps - 1: 195 | cmds.append(BytePSSyncAll()) 196 | cmds.append(ReduceTiedGrads()) 197 | cmds.append(ReduceGrads()) 198 | cmds.append(OptimizerStep()) 199 | yield cmds 200 | continue 201 | 202 | curr_buffer = self._buffer_idx(micro_batch_id) 203 | 204 | if is_forward: 205 | if self._valid_stage(self.prev_stage): 206 | cmds.append(BytePSRecvActivation(curr_buffer)) 207 | cmds.append(BytePSSyncActivation(curr_buffer)) 208 | if self.stage_id == 0 or self.stage_id == self.stages - 1: 209 | # First/last stage loads 210 | cmds.append(LoadMicroBatch(curr_buffer)) 211 | cmds.append(BytePSForwardPass(curr_buffer)) 212 | if self._valid_stage(self.next_stage): 213 | cmds.append(BytePSSendActivation(curr_buffer)) 214 | else: 215 | if self._valid_stage(self.next_stage): 216 | cmds.append(BytePSRecvGrad(curr_buffer)) 217 | cmds.append(BytePSSyncGrad(curr_buffer)) 218 | cmds.append(BytePSBackwardPass(curr_buffer)) 219 | if self._valid_stage(self.prev_stage): 220 | cmds.append(BytePSSendGrad(curr_buffer)) 221 | 222 | # Model step at the end of the batch 223 | if step_id == total_steps - 1: 224 | cmds.append(BytePSSyncAll()) 225 | cmds.append(ReduceTiedGrads()) 226 | cmds.append(ReduceGrads()) 227 | cmds.append(OptimizerStep()) 228 | 229 | yield cmds 230 | 231 | def num_pipe_buffers(self): 232 | """As many buffers as the distance from this stage to the last stage. 
233 | """ 234 | buffers = min(self.micro_batches, self.stages * 2) 235 | if BYTEPS_REDUCED_MEM: 236 | buffers = min(self.stages + 1, self.micro_batches) 237 | return max(2, buffers) 238 | 239 | 240 | class BytePSSendActivation(BufferOpInstruction): 241 | pass 242 | 243 | class BytePSRecvActivation(BufferOpInstruction): 244 | pass 245 | 246 | class BytePSSyncActivation(BufferOpInstruction): 247 | pass 248 | 249 | class BytePSSyncGrad(BufferOpInstruction): 250 | pass 251 | 252 | class BytePSSendGrad(BufferOpInstruction): 253 | pass 254 | 255 | class BytePSRecvGrad(BufferOpInstruction): 256 | pass 257 | 258 | class BytePSForwardPass(BufferOpInstruction): 259 | pass 260 | 261 | class BytePSBackwardPass(BufferOpInstruction): 262 | pass 263 | 264 | class BytePSSyncAll(PipeInstruction): 265 | pass 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | -------------------------------------------------------------------------------- /examples/gpt/pretrain_gpt2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | """Pretrain GPT2""" 4 | import torch 5 | import os 6 | import numpy as np 7 | import time 8 | import sys 9 | 10 | _cwd = os.path.dirname(os.path.abspath(__file__)) 11 | _giantModel_dir = os.path.join(_cwd, '../../src') 12 | sys.path.append(_giantModel_dir) 13 | 14 | from initialize import initialize_megatron, initialize_pipeline 15 | from gpt_piped import GPTModelPiped 16 | 17 | from megatron import get_args, mpu 18 | from megatron import get_timers 19 | from megatron import get_tensorboard_writer 20 | from megatron import print_rank_0 21 | from megatron.learning_rates import AnnealingLR 22 | from megatron.training import build_train_valid_test_data_iterators 23 | from megatron.data.gpt2_dataset import get_indexed_dataset_, get_train_valid_test_split_, _num_tokens, _num_epochs, _build_doc_idx, _build_shuffle_idx 24 | from deepspeed.utils import log_dist 25 | 26 | def _build_index_mappings(name, data_prefix, documents, sizes, 27 | num_samples, seq_length, seed): 28 | """Build doc-idx, sample-idx, and shuffle-idx. 29 | doc-idx: is an array (ordered) of documents to be used in training. 30 | sample-idx: is the start document index and document offset for each 31 | training sample. 32 | shuffle-idx: maps the sample index into a random index into sample-idx. 33 | """ 34 | log_dist(f' >>>> Entering _build_index_mappings', ranks=[-1]) 35 | # Number of tokens in each epoch and number of required epochs. 36 | args = get_args() 37 | tokens_per_epoch = _num_tokens(documents, sizes) 38 | num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) 39 | # rng state 40 | np_rng = np.random.RandomState(seed=seed) 41 | 42 | # Filename of the index mappings. 43 | _filename = data_prefix 44 | _filename += '_{}_{}_indexmap'.format(args.rank, name) 45 | _filename += '_{}ns'.format(num_samples) 46 | _filename += '_{}sl'.format(seq_length) 47 | _filename += '_{}s'.format(seed) 48 | doc_idx_filename = _filename + '_doc_idx.npy' 49 | sample_idx_filename = _filename + '_sample_idx.npy' 50 | shuffle_idx_filename = _filename + '_shuffle_idx.npy' 51 | 52 | # Build the indexed mapping if not exist. 
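    # Editor's note (illustrative, not part of the original file): with hypothetical
    # values rank=0, name='train', num_samples=1000, seq_length=1024 and seed=1234,
    # the cache files assembled above would be named:
    #
    #   <data_prefix>_0_train_indexmap_1000ns_1024sl_1234s_doc_idx.npy
    #   <data_prefix>_0_train_indexmap_1000ns_1024sl_1234s_sample_idx.npy
    #   <data_prefix>_0_train_indexmap_1000ns_1024sl_1234s_shuffle_idx.npy
    #
    # The block below rebuilds all three only if at least one of them is missing.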
53 | device_count = torch.cuda.device_count() 54 | if (not os.path.isfile(doc_idx_filename)) or \ 55 | (not os.path.isfile(sample_idx_filename)) or \ 56 | (not os.path.isfile(shuffle_idx_filename)): 57 | 58 | log_dist(f' > WARNING: could not find index map files, building ' 59 | 'the indices ...', ranks=[-1]) 60 | # doc-idx. 61 | start_time = time.time() 62 | doc_idx = _build_doc_idx(documents, num_epochs, np_rng) 63 | np.save(doc_idx_filename, doc_idx, allow_pickle=True) 64 | log_dist(' > elasped time to build and save doc-idx mapping ' 65 | '(seconds): {:4f}'.format(time.time() - start_time), ranks=[-1]) 66 | # sample-idx. 67 | start_time = time.time() 68 | # Use C++ implementation for speed. 69 | # First compile and then import. 70 | from megatron.data.dataset_utils import compile_helper 71 | compile_helper() 72 | from megatron.data import helpers 73 | assert doc_idx.dtype == np.int32 74 | assert sizes.dtype == np.int32 75 | sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, 76 | num_epochs, tokens_per_epoch) 77 | # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, 78 | # num_epochs, tokens_per_epoch) 79 | np.save(sample_idx_filename, sample_idx, allow_pickle=True) 80 | log_dist(' > elasped time to build and save sample-idx mapping ' 81 | '(seconds): {:4f}'.format(time.time() - start_time), ranks=[-1]) 82 | # shuffle-idx. 83 | start_time = time.time() 84 | # -1 is due to data structure used to retieve the index: 85 | # sample i --> [sample_idx[i], sample_idx[i+1]) 86 | shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) 87 | np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) 88 | log_dist(' > elasped time to build and save shuffle-idx mapping' 89 | ' (seconds): {:4f}'.format(time.time() - start_time), ranks=[-1]) 90 | 91 | # This should be a barrier but nccl barrier assumes 92 | # device_index=rank which is not the case for model 93 | # parallel case 94 | counts = torch.cuda.LongTensor([1]) 95 | torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) 96 | assert counts[0].item() == torch.distributed.get_world_size( 97 | group=mpu.get_data_parallel_group()) 98 | 99 | # Load mappings. 
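    # Editor's note (illustrative): the three .npy files are opened below with
    # mmap_mode='r', so they are memory-mapped read-only rather than copied into RAM;
    # element accesses such as sample_idx[idx] page data in lazily. The all_reduce of
    # `counts` just above acts as a stand-in barrier: every rank in the data-parallel
    # group contributes 1, so the assertion only passes once all ranks reach this point.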
100 | start_time = time.time() 101 | log_dist(' > loading doc-idx mapping from {}'.format( 102 | doc_idx_filename)) 103 | 104 | if not os.path.isfile(doc_idx_filename): 105 | log_dist(' > loading doc-idx mapping from {} failed, file not exist'.format( 106 | doc_idx_filename), ranks=[-1]) 107 | 108 | doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') 109 | log_dist(' > loading sample-idx mapping from {}'.format( 110 | sample_idx_filename), ranks=[-1]) 111 | if not os.path.isfile(sample_idx_filename): 112 | log_dist(' > loading doc-idx mapping from {} failed, file not exist'.format( 113 | sample_idx_filename), ranks=[-1]) 114 | sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') 115 | log_dist(' > loading shuffle-idx mapping from {}'.format( 116 | shuffle_idx_filename), ranks=[-1]) 117 | if not os.path.isfile(shuffle_idx_filename): 118 | log_dist(' > loading doc-idx mapping from {} failed, file not exist'.format( 119 | shuffle_idx_filename), ranks=[-1]) 120 | shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') 121 | log_dist(' loaded indexed file in {:3.3f} seconds'.format( 122 | time.time() - start_time), ranks=[-1]) 123 | log_dist(' total number of samples: {}'.format( 124 | sample_idx.shape[0]), ranks=[-1]) 125 | log_dist(' total number of epochs: {}'.format(num_epochs), ranks=[-1]) 126 | 127 | log_dist(f' >>>> exiting _build_index_mappings', ranks=[-1]) 128 | return doc_idx, sample_idx, shuffle_idx 129 | 130 | class GPT2DatasetFixed(torch.utils.data.Dataset): 131 | def __init__(self, name, data_prefix, documents, indexed_dataset, 132 | num_samples, seq_length, seed): 133 | 134 | self.name = name 135 | self.indexed_dataset = indexed_dataset 136 | 137 | # Checks 138 | assert np.min(documents) >= 0 139 | assert np.max(documents) < indexed_dataset.sizes.shape[0] 140 | 141 | # Build index mappings. 142 | self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( 143 | self.name, data_prefix, documents, self.indexed_dataset.sizes, 144 | num_samples, seq_length, seed) 145 | 146 | def __len__(self): 147 | # -1 is due to data structure used to retieve the index: 148 | # sample i --> [sample_idx[i], sample_idx[i+1]) 149 | return self.sample_idx.shape[0] - 1 150 | 151 | def __getitem__(self, idx): 152 | # Get the shuffled index. 153 | idx = self.shuffle_idx[idx] 154 | # Start and end documents and offsets. 155 | doc_index_f = self.sample_idx[idx][0] 156 | doc_index_l = self.sample_idx[idx + 1][0] 157 | offset_f = self.sample_idx[idx][1] 158 | offset_l = self.sample_idx[idx + 1][1] 159 | # If we are within the same document, just extract the chunk. 160 | if doc_index_f == doc_index_l: 161 | sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], 162 | offset=offset_f, 163 | length=offset_l - offset_f + 1) 164 | else: 165 | # Otherwise, get the rest of the initial document. 166 | sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], 167 | offset=offset_f)] 168 | # Loop over all in between documents and add the entire document. 169 | for i in range(doc_index_f + 1, doc_index_l): 170 | sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) 171 | # And finally add the relevant portion of last document. 
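            # Editor's note (illustrative, hypothetical indices): if a sample starts at
            # offset 100 of document A and ends at offset 9 of document C, the pieces
            # concatenated below are A[100:], all of B, and C[:10] (offset_l + 1 tokens),
            # matching the single-document case above, which slices length offset_l - offset_f + 1.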
172 | sample_list.append(self.indexed_dataset.get( 173 | self.doc_idx[doc_index_l], 174 | length=offset_l + 1)) 175 | sample = np.concatenate(sample_list) 176 | 177 | return {'text': np.array(sample, dtype=np.int64)} 178 | 179 | 180 | 181 | def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, 182 | train_valid_test_num_samples, 183 | seq_length, seed, skip_warmup): 184 | """Build train, valid, and test datasets.""" 185 | 186 | # Indexed dataset. 187 | indexed_dataset = get_indexed_dataset_(data_prefix, 188 | data_impl, 189 | skip_warmup) 190 | 191 | total_num_of_documents = indexed_dataset.sizes.shape[0] 192 | splits = get_train_valid_test_split_(splits_string, total_num_of_documents) 193 | 194 | # Print stats about the splits. 195 | print_rank_0(' > dataset split:') 196 | 197 | def print_split_stats(name, index): 198 | print_rank_0(' {}:'.format(name)) 199 | print_rank_0(' document indices in [{}, {}) total of {} ' 200 | 'documents'.format(splits[index], splits[index + 1], 201 | splits[index + 1] - splits[index])) 202 | print_split_stats('train', 0) 203 | print_split_stats('validation', 1) 204 | print_split_stats('test', 2) 205 | 206 | def build_dataset(index, name): 207 | dataset = None 208 | if splits[index + 1] > splits[index]: 209 | documents = np.arange(start=splits[index], stop=splits[index + 1], 210 | step=1, dtype=np.int32) 211 | dataset = GPT2DatasetFixed(name, data_prefix, 212 | documents, indexed_dataset, 213 | train_valid_test_num_samples[index], 214 | seq_length, seed) 215 | return dataset 216 | 217 | train_dataset = build_dataset(0, 'train') 218 | valid_dataset = build_dataset(1, 'valid') 219 | test_dataset = build_dataset(2, 'test') 220 | 221 | return (train_dataset, valid_dataset, test_dataset) 222 | 223 | def model_provider(): 224 | """Build the model.""" 225 | 226 | print_rank_0('building GPT2 model ...') 227 | model = GPTModelPiped() 228 | return model 229 | 230 | def lr_scheduler_builder(optimizer): 231 | """Build the learning rate scheduler.""" 232 | args = get_args() 233 | 234 | # Add linear learning rate scheduler. 235 | if args.lr_decay_iters is not None: 236 | num_iters = args.lr_decay_iters 237 | else: 238 | num_iters = args.train_iters 239 | num_iters = max(1, num_iters) 240 | init_step = 0 241 | warmup_iter = args.warmup * num_iters 242 | 243 | lr_scheduler = AnnealingLR( 244 | optimizer, 245 | start_lr=args.lr, 246 | warmup_iter=warmup_iter, 247 | total_iters=num_iters, 248 | decay_style=args.lr_decay_style, 249 | last_iter=init_step, 250 | min_lr=args.min_lr, 251 | use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, 252 | override_lr_scheduler=args.override_lr_scheduler) 253 | 254 | return lr_scheduler 255 | 256 | 257 | def pretrain(model_provider, args_defaults={}): 258 | initialize_megatron(args_defaults=args_defaults) 259 | timers = get_timers() 260 | 261 | # Model, optimizer, and learning rate. 262 | timers('model and optimizer').start() 263 | model = model_provider() 264 | engine, optimizer, lr_scheduler = initialize_pipeline(model, None, None, lr_scheduler_builder) 265 | timers('model and optimizer').stop() 266 | 267 | # Print setup timing. 268 | print_rank_0('done with setups ...') 269 | print_rank_0('training ...') 270 | 271 | train(engine, optimizer, lr_scheduler) 272 | 273 | def traing_log(loss_dict, iteration): 274 | args = get_args() 275 | timers = get_timers() 276 | writer = get_tensorboard_writer() 277 | 278 | # Logging. 
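    # Editor's note (descriptive): when a TensorBoard writer exists on global rank 0,
    # the code below records the scalar loss and the collected timers; timer averages
    # are normalized by iteration % args.log_interval (falling back to args.log_interval
    # when that remainder is 0).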
279 | timers_to_log = [] 280 | 281 | def add_to_logging(name): 282 | if name in timers.timers: 283 | timers_to_log.append(name) 284 | add_to_logging('forward') 285 | add_to_logging('backward') 286 | add_to_logging('backward-backward') 287 | add_to_logging('backward-allreduce') 288 | add_to_logging('backward-master-grad') 289 | add_to_logging('backward-clip-grad') 290 | add_to_logging('optimizer') 291 | add_to_logging('batch generator') 292 | 293 | if writer and torch.distributed.get_rank() == 0: 294 | writer.add_scalar('loss', loss_dict, iteration) 295 | normalizer = iteration % args.log_interval 296 | if normalizer == 0: 297 | normalizer = args.log_interval 298 | timers.write(timers_to_log, writer, iteration, 299 | normalizer=normalizer) 300 | 301 | def train_valid_test_dataset_provider(train_val_test_num_samples): 302 | """Build train, valid, and test datasets.""" 303 | args = get_args() 304 | 305 | print_rank_0('> building train, validation, and test datasets ' 306 | 'for GPT ...') 307 | train_ds, valid_ds, test_ds = build_train_valid_test_datasets( 308 | data_prefix=args.data_path, 309 | data_impl=args.data_impl, 310 | splits_string=args.split, 311 | train_valid_test_num_samples=train_val_test_num_samples, 312 | seq_length=args.seq_length, 313 | seed=args.seed, 314 | skip_warmup=(not args.mmap_warmup)) 315 | print_rank_0("> finished creating GPT datasets ...") 316 | 317 | return train_ds, valid_ds, test_ds 318 | 319 | def train(engine, optimizer, lr_scheduler): 320 | """Train the model function.""" 321 | args = get_args() 322 | timers = get_timers() 323 | 324 | # Turn on training mode which enables dropout. 325 | engine.train() 326 | 327 | # Iterations. 328 | iteration = args.iteration 329 | 330 | timers('interval time').start() 331 | 332 | train_data_iterator, valid_data_iterator, test_data_iterator \ 333 | = build_train_valid_test_data_iterators(train_valid_test_dataset_provider) 334 | 335 | log_dist(f' >>>> start training', ranks=[-1]) 336 | while iteration < args.train_iters: 337 | engine.train_batch(train_data_iterator) 338 | 339 | if __name__ == "__main__": 340 | pretrain(model_provider, 341 | args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) 342 | -------------------------------------------------------------------------------- /src/veGiantModel/engine/module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | # Copyright 2019 The Microsoft DeepSpeed Team 3 | import os 4 | 5 | import re as regex 6 | 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.distributed as dist 12 | 13 | from math import floor 14 | 15 | from deepspeed.utils import logger 16 | from deepspeed.runtime import utils as ds_utils 17 | from deepspeed.runtime.activation_checkpointing import checkpointing 18 | from deepspeed.pipe import PipelineModule,LayerSpec, TiedLayerSpec 19 | from .topology import PipeDataParallelTopology, PipelineParallelGrid 20 | 21 | class VeGiantModule(PipelineModule): 22 | def __init__(self, 23 | layers, 24 | num_stages=None, 25 | loss_fn=None, 26 | seed_layers=False, 27 | seed_fn=None, 28 | base_seed=1234, 29 | grid=None, 30 | partition_method='parameters', 31 | activation_checkpoint_interval=0, 32 | activation_checkpoint_func=checkpointing.checkpoint): 33 | """Modules to be parallelized with pipeline parallelism. 
34 | 35 | The key constraint that enables pipeline parallelism is the 36 | representation of the forward pass as a sequence of layers 37 | and the enforcement of a simple interface between them. The 38 | forward pass is implicitly defined by the module ``layers``. The key 39 | assumption is that the output of each layer can be directly fed as 40 | input to the next, like a ``torch.nn.Sequence``. The forward pass is 41 | implicitly: 42 | 43 | .. code-block:: python 44 | 45 | def forward(self, inputs): 46 | x = inputs 47 | for layer in self.layers: 48 | x = layer(x) 49 | return x 50 | 51 | Args: 52 | layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module. 53 | num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided. 54 | topology (``deepseed.pipe.ProcessTopology``, optional): Defines the axes of parallelism axes for training. Must be provided if ``num_stages`` is ``None``. 55 | loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)`` 56 | base_seed (int, optional): [description]. Defaults to 1234. 57 | partition_method (str, optional): [description]. Defaults to 'parameters'. 58 | activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. 59 | activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. 60 | """ 61 | 62 | super(PipelineModule, self).__init__() 63 | 64 | topology = grid.topology() if grid is not None else None 65 | 66 | if num_stages is None and topology is None: 67 | raise RuntimeError('must provide num_stages or topology') 68 | 69 | self.micro_offset = 0 70 | 71 | self.loss_fn = loss_fn 72 | 73 | self.seed_layers = seed_layers 74 | self.seed_fn = seed_fn 75 | self.base_seed = base_seed 76 | if dist.get_rank() == 0: 77 | try: 78 | seed_str = self.seed_fn.__name__ 79 | except AttributeError: 80 | seed_str = None 81 | print( 82 | f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}' 83 | ) 84 | 85 | # Setup world info 86 | self.world_group = dist.new_group(ranks=range(dist.get_world_size())) 87 | self.global_rank = dist.get_rank(group=self.world_group) 88 | self.world_size = dist.get_world_size(group=self.world_group) 89 | 90 | if topology: 91 | self._topo = topology 92 | self.num_stages = self._topo.get_dim('pipe') 93 | else: 94 | self.num_stages = num_stages 95 | if topology is None: 96 | if self.world_size % self.num_stages != 0: 97 | raise RuntimeError( 98 | f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})' 99 | ) 100 | dp = self.world_size // num_stages 101 | topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp) 102 | self._topo = topology 103 | 104 | # Contruct communicators for pipeline topology 105 | self._grid = grid if grid is not None else PipelineParallelGrid(process_group=self.world_group, topology=self._topo) 106 | 107 | self.stage_id = self._topo.get_coord(self.global_rank).pipe 108 | 109 | # Initialize partition information 110 | self._layer_specs = list(layers) 111 | self._num_layers = len(self._layer_specs) 112 | self._local_start = 0 113 | self._local_stop = None 114 | self._partition_layers(method=partition_method) 115 | 116 | self.forward_funcs = [] 117 | self.tied_modules = nn.ModuleDict() 118 | self.tied_weight_attrs = {} 119 | 120 | # Offset the random seed by 
the stage ID. 121 | #newseed = torch.cuda.initial_seed() + self._grid.get_stage_id() 122 | #ds_utils.set_random_seed(newseed) 123 | 124 | #with torch.random.fork_rng(devices=[torch.cuda.current_device()]): 125 | self._build() 126 | self.to('cuda') 127 | 128 | self.tied_comms = self._index_tied_modules() 129 | self._synchronize_tied_weights() 130 | 131 | self.activation_checkpoint_interval = activation_checkpoint_interval 132 | self.activation_checkpoint_func = activation_checkpoint_func 133 | 134 | def _build(self): 135 | specs = self._layer_specs 136 | 137 | for local_idx, layer in enumerate(specs[self._local_start:self._local_stop]): 138 | layer_idx = local_idx + self._local_start 139 | if self.seed_layers: 140 | if self.seed_fn: 141 | self.seed_fn(self.base_seed + layer_idx) 142 | else: 143 | ds_utils.set_random_seed(self.base_seed + layer_idx) 144 | 145 | # Recursively build PipelineModule objects 146 | if isinstance(layer, PipelineModule): 147 | raise NotImplementedError('RECURSIVE BUILD NOT YET IMPLEMENTED') 148 | 149 | # LayerSpec objects contain an nn.Module that should be allocated now. 150 | elif isinstance(layer, nn.Module): 151 | name = str(layer_idx) 152 | self.forward_funcs.append(layer) 153 | self.add_module(name, layer) 154 | 155 | # TiedLayerSpec objects contain an nn.Module that should be allocated now. 156 | elif isinstance(layer, TiedLayerSpec): 157 | # Build and register the module if we haven't seen it before. 158 | if layer.key not in self.tied_modules: 159 | self.tied_modules[layer.key] = layer.build() 160 | self.tied_weight_attrs[layer.key] = layer.tied_weight_attr 161 | 162 | if layer.forward_fn is None: 163 | # Just use forward() 164 | self.forward_funcs.append(self.tied_modules[layer.key]) 165 | else: 166 | # User specified fn with args (module, input) 167 | self.forward_funcs.append( 168 | partial(layer.forward_fn, 169 | self.tied_modules[layer.key])) 170 | 171 | # LayerSpec objects contain an nn.Module that should be allocated now. 172 | elif isinstance(layer, LayerSpec): 173 | module = layer.build() 174 | name = str(layer_idx) 175 | self.forward_funcs.append(module) 176 | self.add_module(name, module) 177 | 178 | # Last option: layer may be a functional (e.g., lambda). We do nothing in 179 | # that case and just use it in forward() 180 | else: 181 | self.forward_funcs.append(layer) 182 | 183 | # All pipeline parameters should be considered as model parallel in the context 184 | # of our FP16 optimizer 185 | for p in self.parameters(): 186 | p.model_parallel = True 187 | 188 | def _count_layer_params(self): 189 | """Count the trainable parameters in individual layers. 190 | 191 | This routine will only build one layer at a time. 192 | 193 | Returns: 194 | A list of the number of parameters in each layer. 
195 | """ 196 | param_counts = [0] * len(self._layer_specs) 197 | for idx, layer in enumerate(self._layer_specs): 198 | if isinstance(layer, LayerSpec): 199 | l = layer.build() 200 | params = filter(lambda p: p.requires_grad, l.parameters()) 201 | param_counts[idx] = sum(p.numel() for p in params) 202 | elif isinstance(layer, nn.Module): 203 | params = filter(lambda p: p.requires_grad, layer.parameters()) 204 | param_counts[idx] = sum(p.numel() for p in params) 205 | return param_counts 206 | 207 | def _find_layer_type(self, layername): 208 | idxs = [] 209 | typeregex = regex.compile(layername, regex.IGNORECASE) 210 | for idx, layer in enumerate(self._layer_specs): 211 | name = None 212 | if isinstance(layer, LayerSpec): 213 | name = layer.typename.__name__ 214 | elif isinstance(layer, nn.Module): 215 | name = layer.__class__.__name__ 216 | else: 217 | try: 218 | name = layer.__name__ 219 | except AttributeError: 220 | continue 221 | if typeregex.search(name): 222 | idxs.append(idx) 223 | 224 | if len(idxs) == 0: 225 | raise RuntimeError( 226 | f"Partitioning '{layername}' found no valid layers to partition.") 227 | return idxs 228 | 229 | def forward(self, forward_input): 230 | # We need to offset the seed by the microbatch ID. Save it in a local var to 231 | # ensure it is preserved in the closure. Otherwise checkpointed forward funcs 232 | # will see a different offset. 233 | self.micro_offset += 1 234 | 235 | def exec_range_func(start, end): 236 | ''' Helper function to be used with checkpoint() 237 | Adapted from torch.utils.checkpoint:checkpoint_sequential() 238 | ''' 239 | local_micro_offset = self.micro_offset + 1 240 | 241 | def exec_func(*inputs): 242 | # Single tensor inputs need to be unwrapped 243 | if len(inputs) == 1: 244 | inputs = inputs[0] 245 | for idx, layer in enumerate(self.forward_funcs[start:end]): 246 | self.curr_layer = idx + self._local_start 247 | if self.seed_layers: 248 | new_seed = (self.base_seed * 249 | local_micro_offset) + self.curr_layer 250 | if self.seed_fn: 251 | self.seed_fn(new_seed) 252 | else: 253 | ds_utils.set_random_seed(new_seed) 254 | 255 | inputs = layer(inputs) 256 | return inputs 257 | 258 | return exec_func 259 | 260 | if self.activation_checkpoint_interval == 0: 261 | func = exec_range_func(0, len(self.forward_funcs)) 262 | x = func(forward_input) 263 | else: 264 | num_layers = len(self.forward_funcs) 265 | x = forward_input 266 | for start_idx in range(0, num_layers, self.activation_checkpoint_interval): 267 | end_idx = min(start_idx + self.activation_checkpoint_interval, 268 | num_layers) 269 | 270 | funcs = self.forward_funcs[start_idx:end_idx] 271 | # Since we either pass tensors or tuples of tensors without unpacking, we 272 | # need to be careful not to double-wrap tensors with tuple. 
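                # Editor's note (illustrative): a bare tensor t becomes (t,) here so that
                # the *x call below unpacks to positional arguments, while an existing
                # tuple of tensors is passed through unchanged; exec_func() above undoes
                # the wrapping again when it receives a single-element input.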
273 | if not isinstance(x, tuple): 274 | x = (x, ) 275 | 276 | if self._is_checkpointable(funcs): 277 | x = self.activation_checkpoint_func( 278 | exec_range_func(start_idx, 279 | end_idx), 280 | *x) 281 | else: 282 | x = exec_range_func(start_idx, end_idx)(*x) 283 | return x 284 | 285 | def _partition_uniform(self, num_items, num_parts): 286 | # print(f'enter _partition_uniform', flush=True) 287 | parts = [0] * (num_parts + 1) 288 | if num_items <= num_parts: 289 | for p in range(num_parts + 1): 290 | parts[p] = min(p, num_items) 291 | return parts 292 | expected_chunksize = num_items / num_parts 293 | for p in range(num_parts): 294 | parts[p] = min(floor(expected_chunksize * p), num_items) 295 | parts[num_parts] = num_items 296 | return parts 297 | 298 | def _partition_balanced(self, weights, num_parts, eps=1e-3): 299 | num_items = len(weights) 300 | # First check for the trivial edge case 301 | if num_items <= num_parts: 302 | return self._partition_uniform(num_items, num_parts) 303 | 304 | weights_ = ds_utils.prefix_sum_inc(weights) 305 | 306 | # Find the smallest bottleneck (weight of heaviest partition) 307 | bottleneck = ds_utils._rb_partition_balanced(weights_, num_parts, eps=eps) 308 | 309 | # Now compute that partitioning 310 | parts, success = ds_utils._lprobe(weights_, num_parts, bottleneck) 311 | assert success 312 | 313 | return parts 314 | 315 | def _partition_layers(self, method='uniform'): 316 | num_stages = self._topo.get_dim('pipe') 317 | stage_id = self._topo.get_coord(self.global_rank).pipe 318 | 319 | if self.global_rank == 0: 320 | logger.info(f'Partitioning pipeline stages with method {method}') 321 | 322 | method = method.lower() 323 | 324 | # Each stage gets a simple uniform number of layers. 325 | if method == 'uniform': 326 | num_layers = len(self._layer_specs) 327 | self.parts = self._partition_uniform(num_items=num_layers, 328 | num_parts=num_stages) 329 | elif method == 'parameters': 330 | param_counts = self._count_layer_params() 331 | self.parts = self._partition_balanced(weights=param_counts, 332 | num_parts=num_stages) 333 | elif method.startswith('type:'): 334 | layertype = method.split(':')[1] 335 | binary_weights = [0] * len(self._layer_specs) 336 | for idx in self._find_layer_type(layertype): 337 | binary_weights[idx] = 1 338 | else: 339 | self.parts = self._partition_balanced(weights=binary_weights, 340 | num_parts=num_stages) 341 | elif method.startswith('manual:'): 342 | msplit = method.split(':') 343 | layernum = int(msplit[1]) 344 | layerparts = msplit[2].split(',') 345 | assert len(self._layer_specs) == layernum # failsafe check for layer num 346 | assert num_stages == len(layerparts)-1 # failsafe check for num stages 347 | self.parts = list(map(int, layerparts)) 348 | elif method == 'profile': 349 | raise NotImplementedError(f'Partitioning method {method} not implemented.') 350 | else: 351 | raise NotImplementedError(f'Partitioning method {method} not implemented.') 352 | 353 | # Print some information on the partitioning. 
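        # Editor's note (illustrative, hypothetical values): self.parts holds num_stages + 1
        # boundary indices into the layer list, start inclusive and stop exclusive, e.g.
        # parts = [0, 8, 16, 24] assigns layers 0-7, 8-15 and 16-23 to stages 0, 1 and 2.
        # The 'manual:' method above expects the same boundary list spelled out, e.g.
        # partition_method='manual:24:0,8,16,24' for 24 layers over 3 stages.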
354 | if self.global_rank == 0: 355 | for stage in range(num_stages): 356 | start = self.parts[stage] 357 | stop = self.parts[stage + 1] 358 | print(f'stage={stage} layers={stop - start}') 359 | for idx, layer in enumerate(self._layer_specs[start:stop]): 360 | name = str(layer) 361 | if isinstance(layer, LayerSpec): 362 | name = layer.typename.__name__ 363 | if isinstance(layer, nn.Module): 364 | name = layer.__class__.__name__ 365 | else: 366 | try: 367 | name = layer.__name__ 368 | except AttributeError: 369 | pass 370 | print(f' {idx+start:2d}: {name}') 371 | if self.loss_fn: 372 | try: 373 | print(f' loss: {self.loss_fn.__name__}') 374 | except AttributeError: 375 | print(f' loss: {self.loss_fn.__class__.__name__}') 376 | 377 | self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1]) 378 | 379 | def allreduce_tied_weight_gradients(self): 380 | '''All reduce the gradients of the tied weights between tied stages''' 381 | for key, comm in self.tied_comms.items(): 382 | weight = getattr(self.tied_modules[key], comm['weight_attr']) 383 | dist.all_reduce(weight.grad, group=comm['group']) 384 | 385 | def _synchronize_tied_weights(self): 386 | for key, comm in self.tied_comms.items(): 387 | dist.broadcast( 388 | getattr(comm['module'], 389 | comm['weight_attr']), 390 | src=min(comm['ranks']), 391 | group=comm['group'], 392 | ) 393 | 394 | def _index_tied_modules(self): 395 | ''' Build communication structures for tied modules. ''' 396 | tied_comms = {} 397 | if self._topo.get_dim('pipe') == 1: 398 | return tied_comms 399 | 400 | specs = self._layer_specs 401 | tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec)) 402 | for key in tie_keys: 403 | # Find the layers that the tied module appears in 404 | tied_layers = [] 405 | for idx, layer in enumerate(specs): 406 | if isinstance(layer, TiedLayerSpec) and layer.key == key: 407 | tied_layers.append(idx) 408 | # Find all stages with this tied module 409 | # TODO: Would be nice to remove the nested data/model parallelism loops and 410 | # TODO: instead generalize in some way, since we really just care about the 411 | # TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...) 412 | # TODO: fiber to generate process groups. 413 | tied_stages = set(self.stage_owner(idx) for idx in tied_layers) 414 | for dp in range(self._grid.data_parallel_size): 415 | for mp in range(self._grid.model_parallel_size): 416 | tied_ranks = [] 417 | for s in sorted(tied_stages): 418 | if self._grid.model_parallel_size > 1: 419 | tied_ranks.append( 420 | self._grid.stage_to_global(stage_id=s, 421 | data=dp, 422 | model=mp)) 423 | else: 424 | tied_ranks.append( 425 | self._grid.stage_to_global(stage_id=s, 426 | data=dp)) 427 | group = dist.new_group(ranks=tied_ranks) 428 | 429 | # Record this tied module if we own a local copy of it. 
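                    # Editor's note (descriptive): tied_ranks holds one global rank per stage
                    # that owns this tied key for the current (data, model) fiber; the group
                    # created above is what _synchronize_tied_weights() broadcasts over and
                    # allreduce_tied_weight_gradients() all-reduces over.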
430 | if self.global_rank in tied_ranks: 431 | assert key in self.tied_modules 432 | if key in self.tied_modules: 433 | tied_comms[key] = { 434 | 'ranks': tied_ranks, 435 | 'group': group, 436 | 'weight_attr': self.tied_weight_attrs[key], 437 | 'module': self.tied_modules[key], 438 | } 439 | # Only count the tied module once in the eyes of the FP16 optimizer 440 | if self.global_rank != tied_ranks[0]: 441 | for p in self.tied_modules[key].parameters(): 442 | p.model_parallel = False 443 | ''' 444 | if len(tied_comms) > 0: 445 | print(f'RANK={self.global_rank} tied_comms={tied_comms}') 446 | ''' 447 | 448 | return tied_comms 449 | 450 | def partitions(self): 451 | return self.parts 452 | 453 | def stage_owner(self, layer_idx): 454 | assert 0 <= layer_idx < self._num_layers 455 | for stage in range(self._topo.get_dim('pipe')): 456 | if self.parts[stage] <= layer_idx < self.parts[stage + 1]: 457 | return stage 458 | raise RuntimeError(f'Layer {layer_idx} not owned? parts={self.parts}') 459 | 460 | def _set_bounds(self, start=None, stop=None): 461 | """Manually define the range of layers that will be built on this process. 462 | 463 | These boundaries are treated as list slices and so start is inclusive and stop is 464 | exclusive. The default of None for both results in all layers being built 465 | locally. 466 | """ 467 | self._local_start = start 468 | self._local_stop = stop 469 | 470 | def set_checkpoint_interval(self, interval): 471 | assert interval >= 0 472 | self.checkpoint_interval = interval 473 | 474 | def topology(self): 475 | """ ProcessTopology object to query process mappings. """ 476 | return self._topo 477 | 478 | def mpu(self): 479 | return self._grid 480 | 481 | def num_pipeline_stages(self): 482 | return self._topo.get_dim('pipe') 483 | 484 | def ckpt_prefix(self, checkpoints_path, tag): 485 | """Build a prefix for all checkpoint files written by this module. """ 486 | # All checkpoint files start with this 487 | rank_name = 'module' 488 | 489 | # Data parallelism is omitted from the naming convention because we are agnostic 490 | # to this in the checkpoint. 491 | omit_dims = frozenset(['data']) 492 | axes = [a for a in self._grid._topo.get_axis_names() if a not in omit_dims] 493 | for dim in axes: 494 | rank = getattr(self._grid._topo.get_coord(rank=self.global_rank), dim) 495 | rank_name += f'-{dim}_{rank:02d}' 496 | 497 | ckpt_name = os.path.join(checkpoints_path, str(tag), rank_name) 498 | return ckpt_name 499 | 500 | def ckpt_layer_path(self, ckpt_dir, local_layer_idx): 501 | """Customize a prefix for a specific pipeline module layer. 
""" 502 | idx = local_layer_idx + self._local_start 503 | layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}') 504 | rank_repr = self._grid._topo.get_rank_repr(rank=self.global_rank) 505 | if rank_repr is not '': 506 | layer_ckpt_path += f'-{rank_repr}' 507 | layer_ckpt_path += '-model_states.pt' 508 | return layer_ckpt_path 509 | 510 | def save_state_dict(self, save_dir): 511 | if self._grid.data_parallel_id != 0: 512 | return 513 | 514 | os.makedirs(save_dir, exist_ok=True) 515 | layer_offset = self._local_start 516 | for idx, layer in enumerate(self.forward_funcs): 517 | model_ckpt_path = self.ckpt_layer_path(save_dir, idx) 518 | if not hasattr(layer, 'state_dict'): 519 | continue 520 | torch.save(layer.state_dict(), model_ckpt_path) 521 | 522 | def load_state_dir(self, load_dir, strict=True): 523 | rank = dist.get_rank() 524 | 525 | layer_offset = self._local_start 526 | for idx, layer in enumerate(self.forward_funcs): 527 | # Functions, etc. will not have state_dicts 528 | if not hasattr(layer, 'load_state_dict'): 529 | continue 530 | 531 | model_ckpt_path = self.ckpt_layer_path(load_dir, idx) 532 | layer.load_state_dict(torch.load(model_ckpt_path, 533 | map_location=lambda storage, 534 | loc: storage), 535 | strict=strict) 536 | if self._grid.data_parallel_id == 0: 537 | logger.info( 538 | f'RANK={self.global_rank} Loaded layer={idx+layer_offset} file={model_ckpt_path}' 539 | ) 540 | 541 | self._synchronize_tied_weights() 542 | 543 | def _is_checkpointable(self, funcs): 544 | if self.__class__.__name__ == 'GPT2ModelPipe': 545 | return all('ParallelTransformerLayerPipe' in f.__class__.__name__ 546 | for f in funcs) 547 | 548 | params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] 549 | return any(len(list(p)) > 0 for p in params) 550 | -------------------------------------------------------------------------------- /src/veGiantModel/engine/topology.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | # Copyright 2019 The Microsoft DeepSpeed Team 3 | 4 | from deepspeed.utils import log_dist 5 | 6 | import torch.distributed as dist 7 | 8 | from collections import namedtuple 9 | from itertools import product as cartesian_product 10 | import logging, os 11 | 12 | import torch 13 | 14 | class ProcessTopology: 15 | """ Manages the mapping of n-dimensional Cartesian coordinates to linear 16 | indices. This mapping is used to map the rank of processes to the grid 17 | for various forms of parallelism. 18 | 19 | Each axis of the tensor is accessed by its name. The provided ordering 20 | of the axes defines the layout of the topology. ProcessTopology uses a "row-major" 21 | layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x,y) and 22 | (x,y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates 23 | (x,y) and (x+1,y) would be adjacent. 24 | 25 | Some methods return ProcessCoord namedtuples. 26 | """ 27 | def __init__(self, axes, dims): 28 | """Create a mapping of n-dimensional tensor coordinates to linear indices. 
29 | 30 | Arguments: 31 | axes (list): the names of the tensor axes 32 | dims (list): the dimension (length) of each axis of the topology tensor 33 | """ 34 | 35 | self.axes = axes # names of each topology axis 36 | self.dims = dims # length of each topology axis 37 | 38 | # This is actually a class that lets us hash {'row':3, 'col':2} mappings 39 | self.ProcessCoord = namedtuple('ProcessCoord', axes) 40 | 41 | self.mapping = {} 42 | ranges = [range(d) for d in dims] 43 | # example: 1, (0,0,1) 44 | for global_rank, coord in enumerate(cartesian_product(*ranges)): 45 | key = {axis: coord[self.axes.index(axis)] for axis in self.axes} 46 | key = self.ProcessCoord(**key) 47 | # for example, {ProcessCoord(row=0, col=1) : 1} 48 | self.mapping[key] = global_rank 49 | 50 | def get_rank(self, **coord_kwargs): 51 | """Return the global rank of a process via its coordinates. 52 | 53 | Coordinates are specified as kwargs. For example: 54 | 55 | >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3]) 56 | >>> X.get_rank(x=0, y=1) 57 | 1 58 | """ 59 | if len(coord_kwargs) != len(self.axes): 60 | raise ValueError('get_rank() does not support slices. Use filter_match())') 61 | 62 | key = self.ProcessCoord(**coord_kwargs) 63 | assert key in self.mapping, f'key {kwargs} invalid' 64 | return self.mapping[key] 65 | 66 | def get_axis_names(self): 67 | """Return a list of the axis names in the ordering of the topology. """ 68 | return self.axes 69 | 70 | def get_rank_repr(self, 71 | rank, 72 | omit_axes=['data', 73 | 'pipe'], 74 | inner_sep='_', 75 | outer_sep='-'): 76 | """Return a string representation of a rank. 77 | 78 | This method is primarily used for checkpointing model data. 79 | 80 | For example: 81 | >>> topo = Topo(axes=['a', 'b'], dims=[2, 2]) 82 | >>> topo.get_rank_repr(rank=3) 83 | 'a_01-b_01' 84 | >>> topo.get_rank_repr(rank=3, omit_axes=['a']) 85 | 'b_01' 86 | 87 | Args: 88 | rank (int): A rank in the topology. 89 | omit_axes (list, optional): Axes that should not be in the representation. Defaults to ['data', 'pipe']. 90 | inner_sep (str, optional): [description]. Defaults to '_'. 91 | outer_sep (str, optional): [description]. Defaults to '-'. 92 | 93 | Returns: 94 | str: A string representation of the coordinate owned by ``rank``. 95 | """ 96 | omit_axes = frozenset(omit_axes) 97 | axes = [a for a in self.get_axis_names() if a not in omit_axes] 98 | names = [] 99 | for ax in axes: 100 | ax_rank = getattr(self.get_coord(rank=rank), ax) 101 | names.append(f'{ax}{inner_sep}{ax_rank:02d}') 102 | return outer_sep.join(names) 103 | 104 | def get_dim(self, axis): 105 | """Return the number of processes along the given axis. 106 | 107 | For example: 108 | >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3]) 109 | >>> X.get_dim('y') 110 | 3 111 | """ 112 | if axis not in self.axes: 113 | return 0 114 | return self.dims[self.axes.index(axis)] 115 | 116 | def get_coord(self, rank): 117 | """Return the coordinate owned by a process rank. 118 | 119 | The axes of the returned namedtuple can be directly accessed as members. For 120 | example: 121 | >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3]) 122 | >>> coord = X.get_coord(rank=1) 123 | >>> coord.x 124 | 0 125 | >>> coord.y 126 | 1 127 | """ 128 | for coord, idx in self.mapping.items(): 129 | if idx == rank: 130 | return coord 131 | raise ValueError(f'rank {rank} not found in topology.') 132 | 133 | def get_axis_comm_lists(self, axis): 134 | """ Construct lists suitable for a communicator group along axis ``axis``. 
135 | 136 | Example: 137 | >>> topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2]) 138 | >>> topo.get_axis_comm_lists('pipe') 139 | [ 140 | [0, 4], # data=0, model=0 141 | [1, 5], # data=0, model=1 142 | [2, 6], # data=1, model=0 143 | [3, 7], # data=1, model=1 144 | ] 145 | 146 | Returns: 147 | A list of lists whose coordinates match in all axes *except* ``axis``. 148 | """ 149 | 150 | # We don't want to RuntimeError because it allows us to write more generalized 151 | # code for hybrid parallelisms. 152 | if axis not in self.axes: 153 | return [] 154 | 155 | # Grab all axes but `axis` 156 | other_axes = [a for a in self.axes if a != axis] 157 | 158 | lists = [] 159 | 160 | # Construct all combinations of coords with other_axes 161 | ranges = [range(self.get_dim(a)) for a in other_axes] 162 | for coord in cartesian_product(*ranges): 163 | other_keys = {a: coord[other_axes.index(a)] for a in other_axes} 164 | # now go over all ranks in `axis`. 165 | sub_list = [] 166 | for axis_key in range(self.get_dim(axis)): 167 | key = self.ProcessCoord(**other_keys, **{axis: axis_key}) 168 | sub_list.append(self.mapping[key]) 169 | lists.append(sub_list) 170 | 171 | return lists 172 | 173 | def filter_match(self, **filter_kwargs): 174 | """Return the list of ranks whose coordinates match the provided criteria. 175 | 176 | Example: 177 | >>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2]) 178 | >>> X.filter_match(pipe=0, data=1) 179 | [2, 3] 180 | >>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)] 181 | [ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)] 182 | 183 | Arguments: 184 | **filter_kwargs (dict): criteria used to select coordinates. 185 | 186 | Returns: 187 | The list of ranks whose coordinates match filter_kwargs. 188 | """ 189 | def _filter_helper(x): 190 | for key, val in filter_kwargs.items(): 191 | if getattr(x, key) != val: 192 | return False 193 | return True 194 | 195 | coords = filter(_filter_helper, self.mapping.keys()) 196 | return [self.mapping[coo] for coo in coords] 197 | 198 | def get_axis_list(self, axis, idx): 199 | """Returns the list of global ranks whose coordinate in an axis is idx. 200 | 201 | For example: 202 | >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3]) 203 | >>> X.get_axis_list(axis='x', idx=0) 204 | [0, 1, 2] 205 | >>> X.get_axis_list(axis='y', idx=0) 206 | [0, 3] 207 | """ 208 | 209 | # This could be faster by generating the desired keys directly instead of 210 | # filtering. 211 | axis_num = self.axes.index(axis) 212 | ranks = [self.mapping[k] for k in self.mapping.keys() if k[axis_num] == idx] 213 | return ranks 214 | 215 | def world_size(self): 216 | return len(self.mapping) 217 | 218 | def __str__(self): 219 | return str(self.mapping) 220 | 221 | 222 | def _prime_factors(N): 223 | """ Returns the prime factorization of positive integer N. """ 224 | if N <= 0: 225 | raise ValueError("Values must be strictly positive.") 226 | 227 | primes = [] 228 | while N != 1: 229 | for candidate in range(2, N + 1): 230 | if N % candidate == 0: 231 | primes.append(candidate) 232 | N //= candidate 233 | break 234 | return primes 235 | 236 | 237 | class PipeDataParallelTopology(ProcessTopology): 238 | """ A topology specialiation for hybrid data and pipeline parallelism. 239 | 240 | Uses data parallelism on the last dimension to encourage gradient 241 | reductions to use high-bandwidth intra-node links and lower-volume 242 | pipeline communications to use low-bandwidth inter-node links. 
243 | """ 244 | def __init__(self, num_pp, num_dp): 245 | super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp]) 246 | 247 | 248 | class PipeModelDataParallelTopology(ProcessTopology): 249 | """ A topology for hybrid pipeline, model, and data parallelism. """ 250 | def __init__(self, num_dp, num_pp, num_mp): 251 | # super().__init__(axes=['model', 'data', 'pipe'], dims=[num_mp, num_dp, num_pp]) 252 | super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp]) 253 | 254 | 255 | class PipelineParallelGrid: 256 | """Implements a grid object that stores the data parallel ranks 257 | corresponding to each o the model parallel stages 258 | 259 | The grid object organizes the processes in a distributed pytorch job 260 | into a 2D grid, of stage_id and data_parallel_id. 261 | 262 | self.stage_id and self.data_parallel_id stores the stage id 263 | and the data parallel id of current process. 264 | 265 | self.dp_group groups the processes by stage_id. 266 | self.dp_group[i], is a list containing all process ranks whose 267 | stage_id is i. 268 | 269 | self.p2p_groups stores a list of tuple, where each tuple 270 | stores process ranks of adjacent stages for a given data_parallel_id. 271 | For example if num_stage is 5 then a tuple [7,8] represents stages [3, 4], 272 | with data_parallel id = 1. A stage wrap around will appear as non-adjacent ranks, 273 | for example tuple [4,0] with representing wrap-around stage 4 and 0, for 274 | data_parallel_id = 0, or similarly [9,5] represents wrapped around stages [4,0] 275 | for data_parallel_id = 1. 276 | """ 277 | def __init__(self, topology=None, process_group=None): 278 | # TODO use process_group if provided 279 | self.global_rank = dist.get_rank() 280 | self.world_size = dist.get_world_size() 281 | if topology is not None: 282 | log_dist(f'building PipelineParallelGrid with topology: {topology}', ranks=[-1], level=logging.DEBUG) 283 | self._topo = topology 284 | else: 285 | num_pp = 1 286 | num_dp = 1 287 | for idx, prime in enumerate(_prime_factors(self.world_size)): 288 | if idx % 2 == 0: 289 | num_pp *= prime 290 | else: 291 | num_dp *= prime 292 | self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp) 293 | self.data_parallel_size = max(self._topo.get_dim('data'), 1) 294 | self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1) 295 | self.model_parallel_size = max(self._topo.get_dim('model'), 1) 296 | assert self._is_grid_valid(), "Invalid Grid" 297 | 298 | self.stage_id = self.get_stage_id() 299 | self.data_parallel_id = self.get_data_parallel_id() 300 | self.model_parallel_id = self.get_model_parallel_id() 301 | self.slice_parallel_src_id = self.get_src_parallel_src_id() 302 | log_dist(f'stage_id: {self.stage_id}, slice_parallel_src_id: {self.slice_parallel_src_id}', ranks=[-1], level=logging.DEBUG) 303 | # Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these 304 | # to detect overflow, etc. 
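        # Editor's note (illustrative, hypothetical sizes): the loop below groups together all
        # ranks that share a data-parallel index. With axes=['pipe', 'data'] and dims=[2, 2],
        # get_axis_list(axis='data', idx=0) returns ranks [0, 2] under the row-major mapping,
        # so that pair forms one "DeepSpeed model" group and ranks [1, 3] form the other.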
305 | 
306 | 
307 |         self.ds_model_proc_group = None
308 |         self.ds_model_rank = -1
309 |         for dp in range(self.data_parallel_size):
310 |             ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
311 |             if self.global_rank == 0:
312 |                 #print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
313 |                 pass
314 |             proc_group = dist.new_group(ranks=ranks)
315 | 
316 |             if self.global_rank in ranks:
317 |                 log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id}, \
318 |                     stage_id: {self.stage_id}, building ds model group: {ranks}', ranks=[-1], level=logging.DEBUG)
319 |                 self.ds_model_proc_group = proc_group
320 |                 self.ds_model_world_size = len(ranks)
321 |                 self.ds_model_rank = ranks.index(self.global_rank)
322 |         assert self.ds_model_rank > -1
323 |         assert self.ds_model_proc_group is not None
324 | 
325 |         # Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
326 |         self.dp_group = []
327 |         self.dp_groups = self._topo.get_axis_comm_lists('data')
328 |         for g in self.dp_groups:
329 |             proc_group = dist.new_group(ranks=g)
330 |             if self.global_rank in g:
331 |                 log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id}, \
332 |                     stage_id: {self.stage_id}, building dp group: {g}', ranks=[-1], level=logging.DEBUG)
333 |                 self.dp_group = g
334 |                 self.dp_proc_group = proc_group
335 | 
336 |         self.is_first_stage = (self.stage_id == 0)
337 |         self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))
338 | 
339 |         self.p2p_groups = self._build_p2p_groups()
340 |         self._build_grads_groups()
341 |         self._build_activation_groups()
342 | 
343 | 
344 | 
345 | 
346 | 
347 |         # Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
348 |         self.pp_group = []
349 |         self.pp_proc_group = None
350 |         self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
351 |         for ranks in self.pipe_groups:
352 |             # if self.global_rank == 0:
353 |             #     #print(f'RANK={self.global_rank} building pipeline group: {ranks}')
354 |             #     pass
355 |             proc_group = dist.new_group(ranks=ranks)
356 |             if self.global_rank in ranks:
357 |                 log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id},\
358 |                     stage_id: {self.stage_id}, building pipeline group: {ranks}', \
359 |                     ranks=[-1], level=logging.DEBUG)
360 |                 self.pp_group = ranks
361 |                 self.pp_proc_group = proc_group
362 |         assert self.pp_proc_group is not None
363 | 
364 |         # Create new ProcessGroup for model (tensor-slicing) collectives
365 | 
366 |         # Short circuit case without model parallelism.
367 |         # TODO: it would be nice if topology had bcast semantics to avoid this branching
368 |         # case?
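        # Illustration (same 2 x 2 x 2 example topology): with model_parallel_size == 2,
        # get_axis_comm_lists('model') used in the else-branch below yields
        # [[0, 1], [2, 3], [4, 5], [6, 7]], i.e. each rank joins a two-way
        # tensor-slicing group with its neighbour along the 'model' axis.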
369 | if self.model_parallel_size == 1: 370 | for group_rank in range(self.world_size): 371 | group_rank = [group_rank] 372 | group = dist.new_group(ranks=group_rank) 373 | if group_rank[0] == self.global_rank: 374 | self.slice_group = group_rank 375 | self.slice_proc_group = group 376 | return 377 | else: 378 | self.mp_group = [] 379 | self.model_groups = self._topo.get_axis_comm_lists('model') 380 | for g in self.model_groups: 381 | proc_group = dist.new_group(ranks=g) 382 | if self.global_rank in g: 383 | log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id}, \ 384 | stage_id: {self.stage_id}, building slice group: {g}', ranks=[-1], level=logging.DEBUG) 385 | self.slice_group = g 386 | self.slice_proc_group = proc_group 387 | 388 | def get_stage_id(self): 389 | return self._topo.get_coord(rank=self.global_rank).pipe 390 | 391 | def get_data_parallel_id(self): 392 | return self._topo.get_coord(rank=self.global_rank).data 393 | 394 | def get_model_parallel_id(self): 395 | if 'model' in self._topo.get_axis_names(): 396 | return self._topo.get_coord(rank=self.global_rank).model 397 | return 0 398 | 399 | def get_src_parallel_src_id(self): 400 | if 'model' not in self._topo.get_axis_names(): 401 | return 0 402 | return self.stage_to_global(stage_id=self.stage_id, 403 | data=self.data_parallel_id, 404 | model=0) 405 | 406 | def _build_p2p_groups(self): 407 | """Groups for sending and receiving activations and gradients across model 408 | parallel stages. 409 | """ 410 | comm_lists = self._topo.get_axis_comm_lists('pipe') 411 | log_dist(f'_build_p2p_groups data_parallel_id: {self.data_parallel_id}, \ 412 | model_parallel_id: {self.model_parallel_id}, stage_id: {self.stage_id}, \ 413 | comm_lists: {comm_lists}', ranks=[-1], level=logging.DEBUG) 414 | 415 | p2p_lists = [] 416 | for rank in range(self.world_size): 417 | for l in comm_lists: 418 | assert len(l) == self.pipe_parallel_size 419 | if rank in l: 420 | idx = l.index(rank) 421 | buddy_rank = l[(idx + 1) % self.pipe_parallel_size] 422 | p2p_lists.append([rank, buddy_rank]) 423 | break # next global rank 424 | assert len(p2p_lists) == self.world_size 425 | log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: \ 426 | {self.model_parallel_id}, stage_id: {self.stage_id}, \ 427 | p2p_lists: {p2p_lists}', ranks=[-1], level=logging.DEBUG) 428 | return p2p_lists 429 | 430 | def _build_grads_groups(self): 431 | self.send_grads_src_rank = -1 432 | self.recv_grads_src_rank = -1 433 | 434 | self.send_grads_group = [] 435 | self.recv_grads_group = [] 436 | 437 | self.send_grads_proc_group = None 438 | self.recv_grads_proc_group = None 439 | self.grads_proc_groups = [] 440 | 441 | for dp_id in range(self.data_parallel_size): 442 | for stage in range(self.pipe_parallel_size): 443 | next_stage = stage + 1 444 | prev_stage = stage - 1 445 | 446 | grads_group = [] 447 | grads_proc_group = None 448 | 449 | if prev_stage > -1: 450 | grads_src_rank = self._topo.filter_match(data=dp_id, pipe=stage, model=0)[0] 451 | prev_mp_group = self._topo.filter_match(data=dp_id, pipe=prev_stage) 452 | grads_group.append(grads_src_rank) 453 | grads_group.extend(prev_mp_group) 454 | grads_group.sort() 455 | # log_dist(f'_build_grads_groups stage: {stage}, grads_group: {grads_group}', ranks=[-1]) 456 | grads_proc_group = dist.new_group(ranks=grads_group) 457 | self.grads_proc_groups.append(grads_proc_group) 458 | if stage == self.stage_id and self.data_parallel_id == dp_id: 459 | self.send_grads_src_rank = 
grads_src_rank 460 | self.send_grads_group = grads_group 461 | self.send_grads_proc_group = grads_proc_group 462 | 463 | elif stage == self.stage_id + 1 and self.data_parallel_id == dp_id: 464 | self.recv_grads_src_rank = grads_src_rank 465 | self.recv_grads_group = grads_group 466 | self.recv_grads_proc_group = grads_proc_group 467 | log_dist(f'_build_grads_groups stage: {self.stage_id}, send_grads_src_rank : {self.send_grads_src_rank}, ' 468 | f'send_grads_group: {self.send_grads_group}, recv_grads_group: {self.recv_grads_group}', \ 469 | ranks=[-1], level=logging.DEBUG) 470 | 471 | def _build_activation_groups(self): 472 | self.send_activation_src_rank = -1 473 | self.recv_activation_src_rank = -1 474 | 475 | self.send_activation_group = [] 476 | self.recv_activation_group = [] 477 | 478 | self.send_activation_proc_group = None 479 | self.recv_activation_proc_group = None 480 | self.activation_proc_groups = [] 481 | 482 | for dp_id in range(self.data_parallel_size): 483 | for stage in range(self.pipe_parallel_size): 484 | next_stage = stage + 1 485 | prev_stage = stage - 1 486 | 487 | activation_group = [] 488 | activation_proc_group = None 489 | 490 | if next_stage < self.pipe_parallel_size: 491 | activation_src_rank = self._topo.filter_match(data=dp_id, pipe=stage, model=0)[0] 492 | next_mp_group = self._topo.filter_match(data=dp_id, pipe=next_stage) 493 | activation_group.append(activation_src_rank) 494 | activation_group.extend(next_mp_group) 495 | activation_group.sort() 496 | activation_proc_group = dist.new_group(ranks=activation_group) 497 | self.activation_proc_groups.append(activation_proc_group) 498 | if stage == self.stage_id and self.data_parallel_id == dp_id: 499 | self.send_activation_src_rank = activation_src_rank 500 | self.send_activation_group = activation_group 501 | self.send_activation_proc_group = activation_proc_group 502 | elif stage == self.stage_id - 1 and self.data_parallel_id == dp_id: 503 | self.recv_activation_src_rank = activation_src_rank 504 | self.recv_activation_group = activation_group 505 | self.recv_activation_proc_group = activation_proc_group 506 | log_dist(f'_build_activation_groups stage: {self.stage_id}, send_activation_src_rank : '\ 507 | f'{self.send_activation_src_rank}, send_activation_group: {self.send_activation_group}, '\ 508 | f'recv_grads_group: {self.recv_grads_group}', ranks=[-1], level=logging.DEBUG) 509 | 510 | def _is_grid_valid(self): 511 | ranks = 1 512 | for ax in self._topo.get_axis_names(): 513 | ranks *= self._topo.get_dim(ax) 514 | return ranks == dist.get_world_size() 515 | 516 | #returns the global rank of the process with the provided stage id 517 | #which has the same data_parallel_id as caller process 518 | def stage_to_global(self, stage_id, **kwargs): 519 | me = self._topo.get_coord(self.global_rank) 520 | transform = me._replace(pipe=stage_id, **kwargs)._asdict() 521 | return self._topo.get_rank(**transform) 522 | 523 | #returns the byteps rank of the process with the provided stage id 524 | def stage_to_byteps(self, stage_id): 525 | return self.pipe_parallel_size * self.data_parallel_id + stage_id 526 | 527 | def topology(self): 528 | return self._topo 529 | 530 | # MPU functions for DeepSpeed integration 531 | def get_global_rank(self): 532 | return self.global_rank 533 | 534 | def get_pipe_parallel_rank(self): 535 | """ The stage of the pipeline this rank resides in. """ 536 | return self.stage_id 537 | 538 | def get_pipe_parallel_world_size(self): 539 | """ The number of stages in the pipeline. 
""" 540 | return self.pipe_parallel_size 541 | 542 | def get_pipe_parallel_group(self): 543 | """ The group of ranks within the same pipeline. """ 544 | return self.pp_proc_group 545 | 546 | def get_data_parallel_rank(self): 547 | """ Which pipeline this rank resides in. """ 548 | return self.data_parallel_id 549 | 550 | def get_data_parallel_world_size(self): 551 | """ The number of pipelines. """ 552 | return self.data_parallel_size 553 | 554 | def get_data_parallel_group(self): 555 | """ The group of ranks within the same stage of all pipelines. """ 556 | return self.dp_proc_group 557 | 558 | # These are model parallel groups across all types of model parallelism. 559 | # Deepspeed uses them to detect overflow, etc. 560 | def get_model_parallel_rank(self): 561 | return self.model_parallel_id 562 | 563 | def get_model_parallel_world_size(self): 564 | return self.model_parallel_size 565 | 566 | def get_model_parallel_group(self): 567 | return self.slice_proc_group 568 | 569 | # For Megatron-style tensor slicing 570 | def get_slice_parallel_rank(self): 571 | return self.model_parallel_id 572 | 573 | def get_slice_parallel_world_size(self): 574 | return self.model_parallel_size 575 | 576 | def get_slice_parallel_group(self): 577 | return self.slice_proc_group 578 | 579 | def get_slice_parallel_src_rank(self): 580 | return self.slice_parallel_src_id 581 | -------------------------------------------------------------------------------- /src/veGiantModel/module/dense.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, ByteDance Inc. All rights reserved. 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.autograd as autograd 19 | 20 | # try: 21 | # import veGiantModel 22 | # except ImportError: 23 | # byteGiantModel = None 24 | 25 | class MockModule(nn.Module): 26 | """Module for testing model parallelism""" 27 | pass 28 | 29 | try: 30 | from th_fastertransformer import Linear 31 | 32 | class LinearFunction(autograd.Function): 33 | 34 | @staticmethod 35 | def forward(ctx, input_tensor, weight, bias, act_gelu=False, dropout_rate=0.0): 36 | bias_out = torch.Tensor(0) 37 | dropout_mask = torch.Tensor(0) 38 | if act_gelu == True or dropout_rate > 0.0: 39 | output, bias_out, dropout_mask = Linear.forward_gelu_dropout(input_tensor, weight, bias, act_gelu, dropout_rate) 40 | else: 41 | output = Linear.forward(input_tensor, weight, bias) 42 | ctx.save_for_backward(input_tensor, weight, bias_out, dropout_mask) 43 | ctx.act_gelu = act_gelu 44 | ctx.dropout_rate = dropout_rate 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_out): 49 | act_gelu = ctx.act_gelu 50 | dropout_rate = ctx.dropout_rate 51 | input_tensor, weight, bias_out, dropout_mask = ctx.saved_tensors 52 | if act_gelu == True or dropout_rate > 0.0: 53 | grad_in, grad_weight, grad_bias = Linear.backward_gelu_dropout( 54 | grad_out, input_tensor, weight, act_gelu, dropout_rate, bias_out, dropout_mask) 55 | else: 56 | grad_in, grad_weight, grad_bias = Linear.backward( 57 | grad_out, input_tensor, weight) 58 | return grad_in, grad_weight, grad_bias, None, None 59 | 60 | class FTLinear(nn.Module): 61 | def __init__(self, in_features, out_features, initializer_range=0.02, act_gelu=False, dropout_rate=0.0): 62 | super().__init__() 63 | 64 | self.in_features = in_features 65 | self.out_features = out_features 66 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 67 | self.bias = nn.Parameter(torch.Tensor(out_features)) 68 | self.act_gelu = act_gelu 69 | self.dropout_rate = dropout_rate 70 | 71 | self.weight.data.normal_(mean=0.0, std=initializer_range) 72 | self.bias.data.zero_() 73 | 74 | def forward(self, input_tensor): 75 | return LinearFunction.apply(input_tensor, self.weight, self.bias, self.act_gelu, self.dropout_rate if self.training else 0.) 
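
        # Usage sketch (illustrative; requires the optional FasterTransformer
        # extension to be importable):
        #   fc = FTLinear(1024, 4096, act_gelu=True, dropout_rate=0.1)
        #   y = fc(x)   # x: [batch, seq_len, 1024] -> y: [batch, seq_len, 4096]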
76 | 77 | def extra_repr(self): 78 | return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) 79 | 80 | except Exception as e: 81 | FTLinear = None 82 | 83 | try: 84 | from th_fastertransformer import LinearTranspose 85 | 86 | class LinearTransposeFunction(autograd.Function): 87 | @staticmethod 88 | def forward(ctx, input_tensor, weight, bias, head_num, transpose_type): 89 | output = LinearTranspose.forward(input_tensor, weight, bias, head_num, transpose_type) 90 | ctx.head_num = head_num 91 | ctx.transpose_type = transpose_type 92 | ctx.save_for_backward(input_tensor, weight) 93 | return output 94 | 95 | @staticmethod 96 | def backward(ctx, grad_out): 97 | input_tensor, weight = ctx.saved_tensors 98 | grad_in, grad_weight, grad_bias = LinearTranspose.backward(grad_out, input_tensor, weight, ctx.head_num, ctx.transpose_type) 99 | return grad_in, grad_weight, grad_bias, None, None 100 | 101 | class FTLinearTranspose(nn.Module): 102 | def __init__(self, in_features, out_features, head_num, transpose_type="0213", initializer_range=0.02): 103 | super().__init__() 104 | 105 | self.in_features = in_features 106 | self.out_features = out_features 107 | self.head_num = head_num 108 | self.transpose_type = transpose_type 109 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 110 | self.bias = nn.Parameter(torch.Tensor(out_features)) 111 | 112 | self.weight.data.normal_(mean=0.0, std=initializer_range) 113 | self.bias.data.zero_() 114 | 115 | def forward(self, input_tensor): 116 | return LinearTransposeFunction.apply(input_tensor, self.weight, self.bias, self.head_num, self.transpose_type) 117 | 118 | def extra_repr(self): 119 | return 'in_features={}, out_features={}, head_num={}'.format(self.in_features, self.out_features, self.head_num) 120 | 121 | except Exception as e: 122 | FTLinearTranspose = None 123 | FTDAGather = None 124 | 125 | def column_parallel_load_hook(module, log_fn): 126 | """hook for column parallel linear's load_state_dict function. 127 | It is a helper function to load a the checkpoint from a 128 | non-model-parallel module. It returns a hook function that 129 | pre-processes the checkpoint to parallel slices such that 130 | each model parallel rank could load the corresponding slice. 131 | 132 | Arguments: 133 | module: ColumnParallelLinear or ColumnParallelLinearTranspose 134 | 135 | log_fn: function for logging 136 | 137 | Returns: 138 | A hook function to help load model parallel modules from non- 139 | model-parallel checkpoints. 
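
    Example (illustrative; assumes model parallel size 2 and a full, unsharded
    checkpoint ``weight`` of shape [4096, 1024], i.e. a per-rank out_features of
    2048): rank 0 keeps rows 0..2047 and rank 1 keeps rows 2048..4095, so each
    rank loads only its own column-parallel slice.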
140 | """ 141 | assert module.mp_rank is not None 142 | assert module.out_features is not None 143 | def hook(state_dict, prefix, local_metadata, strict, missing_keys, 144 | unexpected_keys, error_msgs): 145 | weight_name = prefix + 'weight' 146 | bias_name = prefix + 'bias' 147 | if weight_name in state_dict: 148 | v = state_dict[weight_name] 149 | assert len(v.shape) == 2, v.shape 150 | idx_begin = module.mp_rank * module.out_features 151 | idx_end = (module.mp_rank + 1) * module.out_features 152 | shard = v[idx_begin:idx_end, :] 153 | state_dict[weight_name] = shard 154 | log_fn(f"slice param {weight_name}\tfor model parallelism: {v.shape} -> {shard.shape}") 155 | if bias_name in state_dict: 156 | v = state_dict[bias_name] 157 | assert len(v.shape) == 1, v.shape 158 | idx_begin = module.mp_rank * module.out_features 159 | idx_end = (module.mp_rank + 1) * module.out_features 160 | shard = v[idx_begin:idx_end] 161 | state_dict[bias_name] = shard 162 | log_fn(f"slice param {bias_name}\tfor model parallelism: {v.shape} -> {shard.shape}") 163 | return hook 164 | 165 | def column_serial_load_hook(module, log_fn): 166 | """hook for column serial linear's load_state_dict function. 167 | It is a helper function to load a the checkpoint from a 168 | non-model-parallel module. It returns a hook function that 169 | pre-processes the checkpoint to parallel slices such that 170 | each model parallel rank could load the corresponding slice. 171 | 172 | Arguments: 173 | module: ColumnSerialLinear or ColumnSerialLinearTranspose 174 | 175 | log_fn: function for logging 176 | 177 | Returns: 178 | A hook function to help load model serial modules from non- 179 | model-parallel checkpoints. 180 | """ 181 | assert module.model_parallel_size is not None 182 | assert module.out_features is not None 183 | def hook(state_dict, prefix, local_metadata, strict, missing_keys, 184 | unexpected_keys, error_msgs): 185 | weight_name = prefix + 'weight' 186 | bias_name = prefix + 'bias' 187 | if weight_name in state_dict: 188 | v = state_dict[weight_name] 189 | assert len(v.shape) == 2, v.shape 190 | for i in range(module.model_parallel_size): 191 | weight_name_i = weight_name + "." + str(i) 192 | idx_begin = i * module.out_features 193 | idx_end = (i + 1) * module.out_features 194 | shard = v[idx_begin:idx_end, :] 195 | state_dict[weight_name_i] = shard 196 | log_fn(f"slice param {weight_name_i}\tfor model parallelism: {v.shape} -> {shard.shape}") 197 | del state_dict[weight_name] 198 | if bias_name in state_dict: 199 | v = state_dict[bias_name] 200 | assert len(v.shape) == 1, v.shape 201 | for i in range(module.model_parallel_size): 202 | bias_name_i = bias_name + "." + str(i) 203 | idx_begin = i * module.out_features 204 | idx_end = (i + 1) * module.out_features 205 | shard = v[idx_begin:idx_end] 206 | state_dict[bias_name_i] = shard 207 | log_fn(f"slice param {bias_name_i}\tfor model parallelism: {v.shape} -> {shard.shape}") 208 | del state_dict[bias_name] 209 | return hook 210 | 211 | class ColumnSerialLinear(MockModule): 212 | def __init__(self, in_features, out_features, initializer_range=0.02, 213 | act_gelu=False, dropout_rate=0.0, load_from_shards=False, use_ft=False): 214 | """ 215 | A serial module that mocks the ColumnParallelLinear module. It mocks the parallel 216 | logic by applying the series of work on the same rank, and reduce the result if needed. 
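
        For illustration (assuming a model parallel world size of 2), a
        ColumnSerialLinear(1024, 4096) keeps two [2048, 1024] weight shards on the
        same rank and concatenates their outputs along the last dimension, so its
        result matches a plain nn.Linear(1024, 4096) as well as what the sharded
        ColumnParallelLinear ranks would compute jointly.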
217 | """ 218 | super().__init__() 219 | import veGiantModel 220 | model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size() 221 | self.model_parallel_size = model_parallel_size 222 | self.in_features = in_features 223 | self.out_features = out_features // model_parallel_size 224 | assert out_features % model_parallel_size == 0, (out_features, model_parallel_size) 225 | weight_params = [nn.Parameter(torch.Tensor(self.out_features, self.in_features)) for _ in range(model_parallel_size)] 226 | self.weight = nn.ParameterList(weight_params) 227 | bias_params = [nn.Parameter(torch.Tensor(self.out_features)) for _ in range(model_parallel_size)] 228 | self.bias = nn.ParameterList(bias_params) 229 | self.act_gelu = act_gelu 230 | self.dropout_rate = dropout_rate 231 | for weight in self.weight: 232 | weight.data.normal_(mean=0.0, std=initializer_range) 233 | for bias in self.bias: 234 | bias.data.zero_() 235 | self.use_ft = use_ft 236 | if not use_ft: 237 | assert not act_gelu 238 | assert not dropout_rate, dropout_rate 239 | if not load_from_shards: 240 | load_hook = column_serial_load_hook(self, print) 241 | self._register_load_state_dict_pre_hook(load_hook) 242 | 243 | def forward(self, input_tensor): 244 | outputs = [] 245 | for i in range(self.model_parallel_size): 246 | if self.use_ft: 247 | output_i = LinearFunction.apply(input_tensor, self.weight[i], self.bias[i], self.act_gelu, 248 | self.dropout_rate if self.training else 0.) 249 | else: 250 | output_i = nn.functional.linear(input_tensor, self.weight[i], self.bias[i]) 251 | outputs.append(output_i) 252 | output = torch.cat(outputs, dim=-1) 253 | return output 254 | 255 | def extra_repr(self): 256 | return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) 257 | 258 | class ColumnParallelLinear(nn.Module): 259 | def __init__(self, in_features, out_features, initializer_range=0.02, 260 | act_gelu=False, dropout_rate=0.0, load_from_shards=False, use_ft=False, 261 | bias=True, gather_output=False): 262 | """Linear layer with column parallelism. 263 | 264 | The linear layer is defined as Y = dropout(gelu(XA + b)). A is parallelized along 265 | its second dimension as A = [A_1, ..., A_p]. 266 | 267 | Arguments: 268 | in_features: first dimension of matrix A. 269 | out_features: second dimension of matrix A. 270 | initializer_range: range for weight initialization. Note that bias is always set 271 | to zero. 272 | act_gelu: If true, apply gelu activation to (XA+b) 273 | dropout_rate: If greater than zero, apply dropout to gelu(XA+b) 274 | load_from_shards: If true, load the states from sharded checkpoints. Otherwise, 275 | the module automatically slice the checkpoint tensor based on its 276 | model parallel rank. 277 | use_ft: use faster transformer for acceleration. 
278 |             bias: If true, add a bias term.
279 |             gather_output: If true, call all-gather on the output and make Y available
280 |                     to all GPUs; otherwise, every GPU keeps only its own slice of the
281 |                     output, which is Y_i = XA_i.
282 |         """
283 |         super().__init__()
284 |         import veGiantModel
285 |         model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()
286 |         self.in_features = in_features
287 |         self.out_features = out_features // model_parallel_size
288 |         assert out_features % model_parallel_size == 0, (out_features, model_parallel_size)
289 |         self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features))
290 |         self.weight.data.normal_(mean=0.0, std=initializer_range)
291 |         if bias:
292 |             self.bias = nn.Parameter(torch.Tensor(self.out_features))
293 |             self.bias.data.zero_()
294 |         else:
295 |             self.bias = None
296 |             assert not use_ft
297 |         self.gather_output = gather_output
298 |         self.act_gelu = act_gelu
299 |         self.dropout_rate = dropout_rate
300 |         self.use_ft = use_ft
301 |         self.mp_rank = veGiantModel.distributed.get_model_parallel_rank()
302 |         if not use_ft:
303 |             assert not act_gelu
304 |             assert not dropout_rate, dropout_rate
305 |         if not load_from_shards:
306 |             load_hook = column_parallel_load_hook(self, print)
307 |             self._register_load_state_dict_pre_hook(load_hook)
308 | 
309 |     def forward(self, input_tensor):
310 |         import veGiantModel
311 |         input_tensor = veGiantModel.distributed.copy_to_model_parallel_region(input_tensor)
312 |         if self.use_ft:
313 |             output = LinearFunction.apply(input_tensor, self.weight, self.bias, self.act_gelu,
314 |                 self.dropout_rate if self.training else 0.)
315 |         else:
316 |             output = nn.functional.linear(input_tensor, self.weight, self.bias)
317 |         if self.gather_output:
318 |             output = veGiantModel.distributed.gather_from_model_parallel_region(output)
319 |         return output
320 | 
321 |     def extra_repr(self):
322 |         return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)
323 | 
324 | class RowSerialLinear(MockModule):
325 |     def __init__(self, in_features, out_features, initializer_range=0.02, dropout_rate=0.0,
326 |                  load_from_shards=False, use_ft=False):
327 |         """
328 |         A serial module that mocks the RowParallelLinear module. It mocks the parallel
329 |         logic by applying the series of work on the same rank.
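
        For illustration (assuming a model parallel world size of 2), a
        RowSerialLinear(4096, 1024) splits the last dimension of its input into two
        2048-wide chunks, applies one [1024, 2048] weight shard to each, and sums
        the partial products -- the row-parallel decomposition of a 4096 -> 1024
        linear layer, executed on a single rank.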
330 | """ 331 | super().__init__() 332 | import veGiantModel 333 | model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size() 334 | self.model_parallel_size = model_parallel_size 335 | self.in_features = in_features // model_parallel_size 336 | self.out_features = out_features 337 | assert in_features % model_parallel_size == 0, (in_features, model_parallel_size) 338 | weight_params = [nn.Parameter(torch.Tensor(self.out_features, self.in_features)) for _ in range(model_parallel_size)] 339 | self.weight = nn.ParameterList(weight_params) 340 | self.bias = nn.Parameter(torch.Tensor(self.out_features)) 341 | self.dropout_rate = dropout_rate 342 | for weight in self.weight: 343 | weight.data.normal_(mean=0.0, std=initializer_range) 344 | self.bias.data.zero_() 345 | self.dropout = nn.Dropout(dropout_rate) 346 | self.use_ft = use_ft 347 | self.mp_rank = veGiantModel.distributed.get_model_parallel_rank() 348 | if not load_from_shards: 349 | def load_hook(state_dict, prefix, local_metadata, strict, missing_keys, 350 | unexpected_keys, error_msgs): 351 | weight_name = prefix + 'weight' 352 | if weight_name in state_dict: 353 | v = state_dict[weight_name] 354 | assert len(v.shape) == 2, v.shape 355 | for i in range(model_parallel_size): 356 | weight_name_i = weight_name + '.' + str(i) 357 | idx_begin = i * self.in_features 358 | idx_end = (i + 1) * self.in_features 359 | shard = v[:, idx_begin:idx_end] 360 | state_dict[weight_name_i] = shard 361 | print(f"slice param {weight_name_i}\tfor model parallelism: {v.shape} -> {shard.shape}") 362 | del state_dict[weight_name] 363 | self._register_load_state_dict_pre_hook(load_hook) 364 | 365 | def forward(self, input_tensor): 366 | input_tensors = torch.split(input_tensor, self.in_features, dim=-1) 367 | outputs = [] 368 | for i in range(self.model_parallel_size): 369 | if self.use_ft: 370 | output_i = LinearFunction.apply(input_tensors[i].contiguous(), self.weight[i], self.bias, False, 0.) 371 | else: 372 | output_i = nn.functional.linear(input_tensors[i].contiguous(), self.weight[i], self.bias) 373 | outputs.append(output_i) 374 | output = outputs[0] 375 | for i in range(self.model_parallel_size - 1): 376 | output = output + outputs[i + 1] 377 | if self.dropout_rate: 378 | output = self.dropout(output) 379 | return output 380 | 381 | def extra_repr(self): 382 | return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) 383 | 384 | class RowParallelLinear(nn.Module): 385 | def __init__(self, in_features, out_features, initializer_range=0.02, dropout_rate=0.0, 386 | load_from_shards=False, use_ft=False): 387 | """Linear layer with row parallelism. 388 | 389 | The linear layer is defined as Y = XA + b. A is parallelized along 390 | its first dimension and X along its second dimension as: 391 | - - 392 | | A_1 | 393 | | . | 394 | A = | . | X = [X_1, ..., X_p] 395 | | . | 396 | | A_p | 397 | - - 398 | 399 | Arguments: 400 | in_features: first dimension of matrix A. 401 | out_features: second dimension of matrix A. 402 | initializer_range: range for weight initialization. Note that bias is always set 403 | to zero. 404 | dropout_rate: If greater than zero, apply dropout XA+b 405 | load_from_shards: If true, load the states from sharded checkpoints. Otherwise, 406 | the module automatically slice the checkpoint tensor based on its 407 | model parallel rank. 408 | use_ft: use faster transformer for acceleration. 
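
        For illustration (with a model parallel world size of p), each rank holds a
        [out_features, in_features/p] shard of A, multiplies it with its own slice
        of the already-split input X_i, and the partial outputs are summed across
        ranks via reduce_from_model_parallel_region in forward().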
409 | """ 410 | super().__init__() 411 | import veGiantModel 412 | model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size() 413 | self.in_features = in_features // model_parallel_size 414 | self.out_features = out_features 415 | assert in_features % model_parallel_size == 0, (in_features, model_parallel_size) 416 | self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features)) 417 | self.bias = nn.Parameter(torch.Tensor(self.out_features)) 418 | self.dropout_rate = dropout_rate 419 | self.weight.data.normal_(mean=0.0, std=initializer_range) 420 | self.bias.data.zero_() 421 | self.dropout = nn.Dropout(dropout_rate) 422 | self.use_ft = use_ft 423 | self.mp_rank = veGiantModel.distributed.get_model_parallel_rank() 424 | if not load_from_shards: 425 | def load_hook(state_dict, prefix, local_metadata, strict, missing_keys, 426 | unexpected_keys, error_msgs): 427 | weight_name = prefix + 'weight' 428 | if weight_name in state_dict: 429 | v = state_dict[weight_name] 430 | assert len(v.shape) == 2, v.shape 431 | idx_begin = self.mp_rank * self.in_features 432 | idx_end = (self.mp_rank + 1) * self.in_features 433 | shard = v[:, idx_begin:idx_end] 434 | state_dict[weight_name] = shard 435 | print(f"slice param {weight_name}\tfor model parallelism: {v.shape} -> {shard.shape}") 436 | self._register_load_state_dict_pre_hook(load_hook) 437 | 438 | def forward(self, input_tensor): 439 | if self.use_ft: 440 | output = LinearFunction.apply(input_tensor, self.weight, self.bias, False, 0.) 441 | else: 442 | output = nn.functional.linear(input_tensor, self.weight, self.bias) 443 | import veGiantModel 444 | output = veGiantModel.distributed.reduce_from_model_parallel_region(output) 445 | 446 | if self.dropout_rate: 447 | output = self.dropout(output) 448 | return output 449 | 450 | def extra_repr(self): 451 | return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) 452 | 453 | 454 | class ColumnParallelLinearTranspose(nn.Module): 455 | def __init__(self, in_features, out_features, head_num, transpose_type="0213", initializer_range=0.02, 456 | use_ft=False, load_from_shards=False): 457 | """Linear layer with column parallelism. The output is then reshaped to 4D with 458 | (dim0, dim1, head_num, out_features / head_num), then permuted with axies provided by transpose_type. 459 | For equivalent computation, check the implementation of `ColumnSerialLinearTranspose`. 460 | 461 | The linear layer is defined as Y = XA + b. A is parallelized along 462 | its second dimension as A = [A_1, ..., A_p]. 463 | 464 | Arguments: 465 | in_features: first dimension of matrix A. 466 | out_features: second dimension of matrix A. 467 | head_num: number of "heads" for the out_feature dimension. 468 | transpose_type: the axies for permutation on the output. 469 | initializer_range: range for weight initialization. Note that bias is always set 470 | to zero. 471 | use_ft: use faster transformer for acceleration. 472 | load_from_shards: If true, load the states from sharded checkpoints. Otherwise, 473 | the module automatically slice the checkpoint tensor based on its 474 | model parallel rank. 
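
        Shape sketch (illustrative; assumes model parallel size 2, in_features=1024,
        out_features=3072, head_num=16, transpose_type="0213"): each rank computes a
        [batch, seq, 1536] linear output, reshapes it to [batch, seq, 8, 192] and
        permutes it to [batch, 8, seq, 192], i.e. its local 8 heads of size 192.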
475 | """ 476 | super().__init__() 477 | self.use_ft = use_ft 478 | self.in_features = in_features 479 | import veGiantModel 480 | model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size() 481 | self.mp_rank = veGiantModel.distributed.get_model_parallel_rank() 482 | 483 | assert out_features % model_parallel_size == 0, (out_features, model_parallel_size) 484 | self.out_features = out_features // model_parallel_size 485 | assert head_num % model_parallel_size == 0, (head_num, model_parallel_size) 486 | self.head_num = head_num // model_parallel_size 487 | self.head_dim = self.out_features // self.head_num 488 | self.transpose_type = transpose_type 489 | self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) 490 | self.bias = nn.Parameter(torch.Tensor(self.out_features)) 491 | self.weight.data.normal_(mean=0.0, std=initializer_range) 492 | self.bias.data.zero_() 493 | if not load_from_shards: 494 | load_hook = column_parallel_load_hook(self, print) 495 | self._register_load_state_dict_pre_hook(load_hook) 496 | 497 | def forward(self, input_tensor): 498 | import veGiantModel 499 | input_tensor = veGiantModel.distributed.copy_to_model_parallel_region(input_tensor) 500 | if self.use_ft: 501 | output = LinearTransposeFunction.apply(input_tensor, self.weight, self.bias, 502 | self.head_num, self.transpose_type) 503 | else: 504 | assert self.transpose_type == "0213", self.transpose_type 505 | linear_out = nn.functional.linear(input_tensor, self.weight, self.bias) 506 | new_shape = linear_out.size()[:-1] + (self.head_num, self.head_dim) 507 | linear_out = linear_out.view(*new_shape) 508 | output = linear_out.permute(0, 2, 1, 3).contiguous() 509 | return output 510 | 511 | def extra_repr(self): 512 | return 'in_features={}, out_features={}, head_num={}'.format(self.in_features, self.out_features, self.head_num) 513 | 514 | class ColumnSerialLinearTranspose(MockModule): 515 | def __init__(self, in_features, out_features, head_num, transpose_type="0213", initializer_range=0.02, 516 | use_ft=False, load_from_shards=False): 517 | """ 518 | A serial module that mocks the ColumnParallelLinearTranspose module. It mocks the parallel 519 | logic by applying the series of work on the same rank. 
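
        For illustration (assuming the same sizes as in ColumnParallelLinearTranspose
        above), it runs every shard's linear + reshape + permute on one rank and
        concatenates the per-shard outputs along dim=1 (the head dimension),
        recovering the full head_num.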
520 | """ 521 | super().__init__() 522 | self.use_ft = use_ft 523 | self.in_features = in_features 524 | import veGiantModel 525 | model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size() 526 | self.model_parallel_size = model_parallel_size 527 | self.mp_rank = veGiantModel.distributed.get_model_parallel_rank() 528 | assert out_features % model_parallel_size == 0, (out_features, model_parallel_size) 529 | self.out_features = out_features // model_parallel_size 530 | assert head_num % model_parallel_size == 0, (head_num, model_parallel_size) 531 | self.head_num = head_num // model_parallel_size 532 | self.head_dim = self.out_features // self.head_num 533 | self.transpose_type = transpose_type 534 | weight_params = [nn.Parameter(torch.Tensor(self.out_features, self.in_features)) for _ in range(model_parallel_size)] 535 | self.weight = nn.ParameterList(weight_params) 536 | bias_params = [nn.Parameter(torch.Tensor(self.out_features)) for _ in range(model_parallel_size)] 537 | self.bias = nn.ParameterList(bias_params) 538 | for weight in self.weight: 539 | weight.data.normal_(mean=0.0, std=initializer_range) 540 | for bias in self.bias: 541 | bias.data.zero_() 542 | 543 | if not load_from_shards: 544 | load_hook = column_serial_load_hook(self, print) 545 | self._register_load_state_dict_pre_hook(load_hook) 546 | 547 | def forward(self, input_tensor): 548 | outputs = [] 549 | for i in range(self.model_parallel_size): 550 | if self.use_ft: 551 | output_i = LinearTransposeFunction.apply(input_tensor, self.weight[i], self.bias[i], self.head_num, self.transpose_type) 552 | else: 553 | assert self.transpose_type == "0213", self.transpose_type 554 | linear_out = nn.functional.linear(input_tensor, self.weight[i], self.bias[i]) 555 | new_shape = linear_out.size()[:-1] + (self.head_num, self.head_dim) 556 | linear_out = linear_out.view(*new_shape) 557 | output_i = linear_out.permute(0, 2, 1, 3).contiguous() 558 | outputs.append(output_i) 559 | output = torch.cat(outputs, dim=1) 560 | return output 561 | 562 | def extra_repr(self): 563 | return 'in_features={}, out_features={}, head_num={}'.format(self.in_features, self.out_features, self.head_num) --------------------------------------------------------------------------------