├── requirements.txt ├── bert ├── colossalai_utils │ ├── model_zoo │ │ ├── __init__.py │ │ └── colo_bert.py │ ├── requirement.txt │ ├── bert_config_pp.json │ ├── bert_config_tp1d.json │ ├── bert_config_tp2d.json │ ├── bert_config_tp3d.json │ ├── bert_config_tp1dpp.json │ ├── bert_config_tp2p5d.json │ ├── bert_config_zero.json │ ├── bert_config_zerotppp.json │ └── utils.py ├── README.md ├── run.py └── common │ ├── helper.py │ └── train.py ├── .github └── ISSUE_TEMPLATE │ ├── config.yml │ ├── documentation.yml │ ├── bug-report.yml │ └── feature_request.yml ├── zero ├── requirement.txt ├── torch_utils │ ├── vit_config.json │ ├── gpt2_config.json │ └── utils.py ├── colossalai_utils │ ├── gpt2_config_v1.json │ ├── vit_config.json │ ├── gpt2_config.json │ └── utils.py ├── patrickstar_utils │ ├── vit_config.json │ ├── gpt2_config.json │ └── utils.py ├── fairscale_utils │ ├── vit_config.json │ ├── gpt2_config.json │ └── utils.py ├── deepspeed_utils │ ├── utils.py │ ├── vit_config.json │ ├── gpt2_config.json │ └── gpt2_nvme_config.json ├── README.md ├── run.py └── common │ ├── utils.py │ ├── train.py │ ├── vit.py │ └── gpt2.py ├── cifar ├── configs │ ├── vit_1d.py │ ├── vit_2d.py │ ├── vit_3d.py │ ├── vit_vanilla.py │ └── vit_2p5d.py └── train.py ├── imagenet1k ├── configs │ ├── vit_1d.py │ ├── vit_2d.py │ ├── vit_3d.py │ ├── vit_2p5d.py │ └── vit_vanilla.py └── train.py ├── gpt ├── configs │ ├── gpt3_vanilla.py │ ├── gpt2_1d.py │ ├── gpt2_2d.py │ ├── gpt2_3d.py │ ├── gpt2_8b_1d_256.py │ ├── gpt2_8b_1d_512.py │ ├── gpt2_vanilla.py │ ├── gpt2_xl_1d_256.py │ ├── gpt2_2p5d.py │ ├── gpt2_8b_2p5d_256.py │ ├── gpt2_8b_2p5d_512.py │ ├── gpt2_xl_2p5d_256.py │ ├── gpt2_pp1d.py │ └── gpt3_pp1d.py ├── data.py ├── readme.md └── train.py ├── .gitignore ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.8.0 2 | torchvision 3 | colossalai -------------------------------------------------------------------------------- /bert/colossalai_utils/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .colo_bert import create_colo_bert_pipeline_model, ColoBertForMaskedLM, ColoBertMaskedLMLoss 2 | 3 | __all__ = ['create_colo_bert_pipeline_model', 'ColoBertForMaskedLM', 'ColoBertMaskedLMLoss'] 4 | -------------------------------------------------------------------------------- /bert/colossalai_utils/requirement.txt: -------------------------------------------------------------------------------- 1 | 2 | torch>=1.10 -f https://download.pytorch.org/whl/cu113/torch_stable.html 3 | torchvision -f https://download.pytorch.org/whl/cu113/torch_stable.html 4 | transformers 5 | datasets 6 | colossalai 7 | rich -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: "😊 Discussions" 4 | url: https://github.com/hpcaitech/ColossalAI/discussions 5 | about: Ask questions and discuss with other Colossal-AI community members in our forum 6 | -------------------------------------------------------------------------------- /zero/requirement.txt: -------------------------------------------------------------------------------- 1 | 2 | torch>=1.10 -f https://download.pytorch.org/whl/cu113/torch_stable.html 3 | torchvision -f https://download.pytorch.org/whl/cu113/torch_stable.html 4 | transformers 
5 | datasets 6 | colossalai 7 | deepspeed 8 | fairscale 9 | rich 10 | nvidia-dali-cuda110 --extra-index-url https://developer.download.nvidia.com/compute/redist -------------------------------------------------------------------------------- /zero/torch_utils/vit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "torch", 3 | "model": { 4 | "type": "vit_h" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 4 8 | }, 9 | "fp16": { 10 | "enabled": true, 11 | "init_scale": 32768, 12 | "growth_factor": 2.0, 13 | "backoff_factor": 0.5, 14 | "growth_interval": 1000 15 | }, 16 | "gradient_clipping": 1.0, 17 | "use_mem_monitor": true 18 | } -------------------------------------------------------------------------------- /cifar/configs/vit_1d.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 512 2 | LEARNING_RATE = 2e-3 3 | WEIGHT_DECAY = 3e-2 4 | 5 | TENSOR_PARALLEL_SIZE = 2 6 | TENSOR_PARALLEL_MODE = '1d' 7 | 8 | NUM_EPOCHS = 200 9 | WARMUP_EPOCHS = 40 10 | 11 | parallel = dict( 12 | pipeline=1, 13 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 14 | ) 15 | 16 | seed = 42 17 | 18 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/" 19 | -------------------------------------------------------------------------------- /cifar/configs/vit_2d.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 512 2 | LEARNING_RATE = 2e-3 3 | WEIGHT_DECAY = 3e-2 4 | 5 | TENSOR_PARALLEL_SIZE = 4 6 | TENSOR_PARALLEL_MODE = '2d' 7 | 8 | NUM_EPOCHS = 200 9 | WARMUP_EPOCHS = 40 10 | 11 | parallel = dict( 12 | pipeline=1, 13 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 14 | ) 15 | 16 | seed = 42 17 | 18 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/" 19 | -------------------------------------------------------------------------------- /cifar/configs/vit_3d.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 512 2 | LEARNING_RATE = 2e-3 3 | WEIGHT_DECAY = 3e-2 4 | 5 | TENSOR_PARALLEL_SIZE = 8 6 | TENSOR_PARALLEL_MODE = '3d' 7 | 8 | NUM_EPOCHS = 200 9 | WARMUP_EPOCHS = 40 10 | 11 | parallel = dict( 12 | pipeline=1, 13 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 14 | ) 15 | 16 | seed = 42 17 | 18 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/" 19 | -------------------------------------------------------------------------------- /cifar/configs/vit_vanilla.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 512 2 | LEARNING_RATE = 2e-3 3 | WEIGHT_DECAY = 3e-2 4 | 5 | TENSOR_PARALLEL_SIZE = 1 6 | TENSOR_PARALLEL_MODE = None 7 | 8 | NUM_EPOCHS = 200 9 | WARMUP_EPOCHS = 40 10 | 11 | parallel = dict( 12 | pipeline=1, 13 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 14 | ) 15 | 16 | seed = 42 17 | 18 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/" 19 | -------------------------------------------------------------------------------- /cifar/configs/vit_2p5d.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 512 2 | LEARNING_RATE = 2e-3 3 | WEIGHT_DECAY = 3e-2 4 | 5 | TENSOR_PARALLEL_SIZE = 4 
6 | DEPTH = 1 7 | TENSOR_PARALLEL_MODE = '2.5d' 8 | 9 | NUM_EPOCHS = 200 10 | WARMUP_EPOCHS = 40 11 | 12 | parallel = dict( 13 | pipeline=1, 14 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH), 15 | ) 16 | 17 | seed = 42 18 | 19 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/" 20 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_pp.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 20, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "parallel": { 13 | "pipeline": 4, 14 | "tensor": { 15 | "mode": "1d", 16 | "size": 1 17 | } 18 | }, 19 | "use_mem_monitor": true 20 | } 21 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_tp1d.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 10, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "parallel": { 13 | "pipeline": 1, 14 | "tensor": { 15 | "mode": "1d", 16 | "size": 2 17 | } 18 | }, 19 | "use_mem_monitor": true 20 | } 21 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_tp2d.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 20, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "parallel": { 13 | "pipeline": 1, 14 | "tensor": { 15 | "mode": "2d", 16 | "size": 4 17 | } 18 | }, 19 | "use_mem_monitor": true 20 | } 21 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_tp3d.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 20, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "parallel": { 13 | "pipeline": 1, 14 | "tensor": { 15 | "mode": "3d", 16 | "size": 8 17 | } 18 | }, 19 | "use_mem_monitor": true 20 | } 21 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_tp1dpp.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 20, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "parallel": { 13 | "pipeline": 2, 14 | "tensor": { 15 | "mode": "1d", 16 | "size": 2 17 | } 18 | }, 19 | "use_mem_monitor": true 20 | } 21 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_tp2p5d.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | 
"num_epochs": 20, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "parallel": { 13 | "pipeline": 1, 14 | "tensor": { 15 | "mode": "2.5d", 16 | "size": 8, 17 | "depth": 2 18 | } 19 | }, 20 | "use_mem_monitor": true 21 | } 22 | -------------------------------------------------------------------------------- /zero/torch_utils/gpt2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "torch", 3 | "model": { 4 | "type": "gpt2_10b" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 1, 8 | "num_epochs": 2, 9 | "steps_per_epoch": 10, 10 | "synthetic": true 11 | }, 12 | "fp16": { 13 | "enabled": true, 14 | "init_scale": 32768, 15 | "growth_factor": 2.0, 16 | "backoff_factor": 0.5, 17 | "growth_interval": 1000 18 | }, 19 | "gradient_clipping": 1.0, 20 | "use_mem_monitor": true 21 | } 22 | -------------------------------------------------------------------------------- /bert/README.md: -------------------------------------------------------------------------------- 1 | # Bert Benchmark 2 | Bert Benchmark with data parallel, tensor parallel(tp), pipeline parallel(pp) and ZeRO. 3 | 4 | ## Setup 5 | 1. Install dependencies if you do not have them 6 | ``` 7 | pip install -r requirement.txt 8 | ``` 9 | 10 | 2. Add root dir into PYTHONPATH 11 | ``` 12 | export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH 13 | ``` 14 | 15 | ## Bert Usage 16 | 17 | 1. Prepare datasets and tokenizers from HuggingFace Hub if necessary (e.g. we provide an example of training `wikitext-2`). 18 | 19 | 2. Run benchmark with one of the systems to evaluate 20 | ``` 21 | DATA=/PATH/TO/DATASET TOKENIZER=/PATH/TO/TOKENIZER LOG=/PATH/TO/LOG torchrun --nproc_per_node=NUM_GPUS run.py --config=CONFIG_FILE 22 | ``` -------------------------------------------------------------------------------- /imagenet1k/configs/vit_1d.py: -------------------------------------------------------------------------------- 1 | from colossalai.amp import AMP_TYPE 2 | 3 | TOTAL_BATCH_SIZE = 4096 4 | LEARNING_RATE = 3e-3 5 | WEIGHT_DECAY = 0.3 6 | 7 | TENSOR_PARALLEL_SIZE = 2 8 | TENSOR_PARALLEL_MODE = '1d' 9 | 10 | NUM_EPOCHS = 300 11 | WARMUP_EPOCHS = 32 12 | 13 | parallel = dict( 14 | pipeline=1, 15 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 16 | ) 17 | 18 | fp16 = dict(mode=AMP_TYPE.TORCH, ) 19 | 20 | gradient_accumulation = 2 21 | 22 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 23 | 24 | clip_grad_norm = 1.0 25 | 26 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/" 27 | -------------------------------------------------------------------------------- /imagenet1k/configs/vit_2d.py: -------------------------------------------------------------------------------- 1 | from colossalai.amp import AMP_TYPE 2 | 3 | TOTAL_BATCH_SIZE = 4096 4 | LEARNING_RATE = 3e-3 5 | WEIGHT_DECAY = 0.3 6 | 7 | TENSOR_PARALLEL_SIZE = 4 8 | TENSOR_PARALLEL_MODE = '2d' 9 | 10 | NUM_EPOCHS = 300 11 | WARMUP_EPOCHS = 32 12 | 13 | parallel = dict( 14 | pipeline=1, 15 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 16 | ) 17 | 18 | fp16 = dict(mode=AMP_TYPE.TORCH, ) 19 | 20 | gradient_accumulation = 2 21 | 22 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 23 | 24 | clip_grad_norm = 1.0 25 | 26 | LOG_PATH = 
f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/" 27 | -------------------------------------------------------------------------------- /imagenet1k/configs/vit_3d.py: -------------------------------------------------------------------------------- 1 | from colossalai.amp import AMP_TYPE 2 | 3 | TOTAL_BATCH_SIZE = 4096 4 | LEARNING_RATE = 3e-3 5 | WEIGHT_DECAY = 0.3 6 | 7 | TENSOR_PARALLEL_SIZE = 8 8 | TENSOR_PARALLEL_MODE = '3d' 9 | 10 | NUM_EPOCHS = 300 11 | WARMUP_EPOCHS = 32 12 | 13 | parallel = dict( 14 | pipeline=1, 15 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 16 | ) 17 | 18 | fp16 = dict(mode=AMP_TYPE.TORCH, ) 19 | 20 | gradient_accumulation = 2 21 | 22 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 23 | 24 | clip_grad_norm = 1.0 25 | 26 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/" 27 | -------------------------------------------------------------------------------- /zero/colossalai_utils/gpt2_config_v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "gpt2_small" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 3, 8 | "steps_per_epoch":10 9 | }, 10 | "fp16": { 11 | "initial_scale": 32768, 12 | "min_scale": 1, 13 | "growth_factor": 2.0, 14 | "backoff_factor": 0.5, 15 | "growth_interval": 1000 16 | }, 17 | "gradient_clipping": 0.0, 18 | "zero": { 19 | "mixed_precision": true, 20 | "reshard_after_forward": false, 21 | "offload_config" : { 22 | "device": "cpu" 23 | }, 24 | "version": 1 25 | }, 26 | "use_mem_monitor": true 27 | } 28 | -------------------------------------------------------------------------------- /zero/colossalai_utils/vit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "vit_h" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 4, 8 | "steps_per_epoch":10 9 | }, 10 | "fp16": { 11 | "initial_scale": 32768, 12 | "min_scale": 1, 13 | "growth_factor": 2.0, 14 | "backoff_factor": 0.5, 15 | "growth_interval": 1000 16 | }, 17 | "gradient_clipping": 1.0, 18 | "zero": { 19 | "reduce_scatter_bucket_size_mb": 25, 20 | "fp32_reduce_scatter": false, 21 | "offload_config": { 22 | "device": "cpu" 23 | }, 24 | "shard_param": true 25 | }, 26 | "use_mem_monitor": true 27 | } -------------------------------------------------------------------------------- /imagenet1k/configs/vit_2p5d.py: -------------------------------------------------------------------------------- 1 | from colossalai.amp import AMP_TYPE 2 | 3 | TOTAL_BATCH_SIZE = 4096 4 | LEARNING_RATE = 3e-3 5 | WEIGHT_DECAY = 0.3 6 | 7 | TENSOR_PARALLEL_SIZE = 4 8 | DEPTH = 1 9 | TENSOR_PARALLEL_MODE = '2.5d' 10 | 11 | NUM_EPOCHS = 300 12 | WARMUP_EPOCHS = 32 13 | 14 | parallel = dict( 15 | pipeline=1, 16 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH), 17 | ) 18 | 19 | fp16 = dict(mode=AMP_TYPE.TORCH, ) 20 | 21 | gradient_accumulation = 2 22 | 23 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 24 | 25 | clip_grad_norm = 1.0 26 | 27 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/" 28 | 
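The ViT config modules above (cifar/configs and imagenet1k/configs) all describe the parallel layout through the same `parallel = dict(pipeline=..., tensor=dict(mode=..., size=..., depth=...))` structure. The sketch below is not part of the repository; it only illustrates, under stated assumptions, the consistency constraints implied by the values used in these configs (e.g. `size=4, depth=1` and `size=8, depth=2` for 2.5-D). The helper name and the exact 2-D/2.5-D/3-D rules are assumptions for illustration, not Colossal-AI API.

```
import math

def check_parallel_config(parallel: dict, world_size: int) -> None:
    # Assumed invariants for a `parallel` dict like the ones in these config files.
    pipeline = parallel.get("pipeline", 1)
    tensor = parallel.get("tensor", {})
    mode, size = tensor.get("mode"), tensor.get("size", 1)

    # Whatever is left over after pipeline and tensor parallelism is data parallelism.
    assert world_size % (pipeline * size) == 0, "pipeline * tensor size must divide world size"

    if mode == "2d":
        q = math.isqrt(size)
        assert q * q == size, "2d expects a square tensor-parallel size (e.g. 4)"
    elif mode == "2.5d":
        depth = tensor.get("depth", 1)
        assert size % depth == 0, "2.5d expects size to be a multiple of depth"
        q = math.isqrt(size // depth)
        assert q * q == size // depth, "2.5d expects size == depth * q**2 (e.g. 4 = 1 * 2**2, 8 = 2 * 2**2)"
    elif mode == "3d":
        q = round(size ** (1 / 3))
        assert q ** 3 == size, "3d expects a cubic tensor-parallel size (e.g. 8)"

# Example: the 2.5-D ViT config above on 8 GPUs (tp size 4, depth 1, leaving 2-way data parallelism).
check_parallel_config(dict(pipeline=1, tensor=dict(mode="2.5d", size=4, depth=1)), world_size=8)
```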
-------------------------------------------------------------------------------- /imagenet1k/configs/vit_vanilla.py: -------------------------------------------------------------------------------- 1 | from colossalai.amp import AMP_TYPE 2 | 3 | TOTAL_BATCH_SIZE = 4096 4 | LEARNING_RATE = 3e-3 5 | WEIGHT_DECAY = 0.3 6 | 7 | TENSOR_PARALLEL_SIZE = 1 8 | TENSOR_PARALLEL_MODE = None 9 | 10 | NUM_EPOCHS = 300 11 | WARMUP_EPOCHS = 32 12 | 13 | parallel = dict( 14 | pipeline=1, 15 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 16 | ) 17 | 18 | fp16 = dict(mode=AMP_TYPE.TORCH, ) 19 | 20 | gradient_accumulation = 2 21 | 22 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 23 | 24 | clip_grad_norm = 1.0 25 | 26 | LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/" 27 | -------------------------------------------------------------------------------- /zero/colossalai_utils/gpt2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "gpt2_10b" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 1, 8 | "steps_per_epoch": 3 9 | }, 10 | "fp16": { 11 | "initial_scale": 32768, 12 | "min_scale": 1, 13 | "growth_factor": 2.0, 14 | "backoff_factor": 0.5, 15 | "growth_interval": 1000 16 | }, 17 | "gradient_clipping": 0.0, 18 | "zero": { 19 | "reduce_scatter_bucket_size_mb": 25, 20 | "fp32_reduce_scatter": false, 21 | "offload_config": { 22 | "device": "cpu" 23 | }, 24 | "reuse_fp16_shard": true, 25 | "version": 2 26 | }, 27 | "use_mem_monitor": true 28 | } -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 20, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "zero": { 13 | "model_config": { 14 | "offload_config": { 15 | "device": "cpu" 16 | } 17 | }, 18 | "optimizer_config": { 19 | "cpu_offload": true, 20 | "initial_scale": 256, 21 | "min_scale": 1, 22 | "growth_factor": 2.0, 23 | "backoff_factor": 0.5, 24 | "growth_interval": 1000 25 | } 26 | }, 27 | "use_mem_monitor": true 28 | } 29 | -------------------------------------------------------------------------------- /zero/patrickstar_utils/vit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "patrickstar", 3 | "model": { 4 | "type": "vit_h" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 4 8 | }, 9 | "optimizer": { 10 | "type": "AdamW", 11 | "params": { 12 | "lr": 0.0015, 13 | "weight_decay": 0.01, 14 | "use_hybrid_adam": true 15 | } 16 | }, 17 | "fp16": { 18 | "enabled": true, 19 | "loss_scale": 0, 20 | "initial_scale_power": 15, 21 | "loss_scale_window": 1000, 22 | "hysteresis": 2, 23 | "min_loss_scale": 1 24 | }, 25 | "default_chunk_size": 67108864, 26 | "release_after_init": true, 27 | "gradient_clipping": 1.0, 28 | "use_cpu_embedding": false, 29 | "use_mem_monitor": true 30 | } -------------------------------------------------------------------------------- /zero/fairscale_utils/vit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "fairscale", 3 | "model": { 4 | "type": "vit_h", 5 | 
"checkpoint": false 6 | }, 7 | "hyperparameter": { 8 | "batch_size": 4 9 | }, 10 | "fp16": { 11 | "enabled": true, 12 | "init_scale": 32768, 13 | "growth_factor": 2.0, 14 | "backoff_factor": 0.5, 15 | "growth_interval": 1000 16 | }, 17 | "gradient_clipping": 1.0, 18 | "fsdp": { 19 | "reshard_after_forward": true, 20 | "mixed_precision": true, 21 | "fp32_reduce_scatter": false, 22 | "flatten_parameters": true, 23 | "move_params_to_cpu": true, 24 | "bucket_cap_mb": 25, 25 | "clear_autocast_cache": false, 26 | "force_input_to_fp32": false, 27 | "state_dict_on_rank_0_only": false 28 | }, 29 | "use_mem_monitor": true 30 | } -------------------------------------------------------------------------------- /zero/patrickstar_utils/gpt2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "patrickstar", 3 | "model": { 4 | "type": "gpt2_10b" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 2, 9 | "steps_per_epoch": 10, 10 | "synthetic": true 11 | }, 12 | "optimizer": { 13 | "type": "AdamW", 14 | "params": { 15 | "lr": 0.0015, 16 | "weight_decay": 0.01, 17 | "use_hybrid_adam": true 18 | } 19 | }, 20 | "fp16": { 21 | "enabled": true, 22 | "loss_scale": 0, 23 | "initial_scale_power": 15, 24 | "loss_scale_window": 1000, 25 | "hysteresis": 2, 26 | "min_loss_scale": 1 27 | }, 28 | "default_chunk_size": 1073741824, 29 | "release_after_init": true, 30 | "gradient_clipping": 1.0, 31 | "use_cpu_embedding": false, 32 | "use_mem_monitor": true 33 | } -------------------------------------------------------------------------------- /gpt/configs/gpt3_vanilla.py: -------------------------------------------------------------------------------- 1 | from colossalai.amp import AMP_TYPE 2 | 3 | VOCAB_SIZE = 50304 4 | SEQ_LENGTH = 1024 5 | 6 | TOTAL_BATCH_SIZE = 64 7 | LEARNING_RATE = 0.00015 8 | WEIGHT_DECAY = 1e-2 9 | 10 | TENSOR_PARALLEL_SIZE = 1 11 | TENSOR_PARALLEL_MODE = None 12 | 13 | NUM_EPOCHS = 60 14 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 15 | 16 | parallel = dict( 17 | pipeline=4, 18 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 19 | ) 20 | 21 | fp16 = dict(mode=AMP_TYPE.NAIVE, ) 22 | 23 | gradient_accumulation = 1 24 | 25 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 26 | 27 | NUM_MICRO_BATCHES = 16 28 | TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, 1600) 29 | 30 | clip_grad_norm = 1.0 31 | 32 | # LOG_PATH = f"./gpt3_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 33 | -------------------------------------------------------------------------------- /bert/colossalai_utils/bert_config_zerotppp.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "colossalai", 3 | "model": { 4 | "type": "bert_base" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 8, 8 | "num_epochs": 100, 9 | "steps_per_epoch": 10 10 | }, 11 | "gradient_clipping": 1.0, 12 | "zero": { 13 | "model_config": { 14 | "offload_config": { 15 | "device": "cpu" 16 | } 17 | }, 18 | "optimizer_config": { 19 | "cpu_offload": true, 20 | "initial_scale": 256, 21 | "min_scale": 1, 22 | "growth_factor": 2.0, 23 | "backoff_factor": 0.5, 24 | "growth_interval": 1000 25 | } 26 | }, 27 | "parallel": { 28 | "pipeline":1, 29 | "tensor": { 30 | "mode": "1d", 31 | "size": 2 32 | } 33 | }, 34 | "use_mem_monitor": true 35 | } 36 | 
-------------------------------------------------------------------------------- /bert/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from bert.common.helper import bert_builder 4 | from bert.colossalai_utils.utils import init_w_col 5 | from bert.common.train import train 6 | from zero.common.utils import CONFIG, load_config, print_log 7 | from zero.torch_utils.utils import init_w_torch 8 | 9 | _method = { 10 | 'torch': init_w_torch, 11 | 'colossalai': init_w_col, 12 | } 13 | 14 | _builder = { 15 | 'bert': bert_builder, 16 | } 17 | 18 | def run_bert(): 19 | method = CONFIG['method'] 20 | 21 | model = CONFIG['model']['type'] 22 | model_type = model.split('_')[0] 23 | 24 | train(*_method[method](_builder[model_type])) 25 | 26 | if __name__ == '__main__': 27 | load_config() 28 | 29 | CONFIG['log_path'] = os.environ.get('LOG', '.') 30 | os.makedirs(CONFIG['log_path'], exist_ok=True) 31 | 32 | print_log(f'Initializing {CONFIG["method"]} ...') 33 | 34 | run_bert() 35 | -------------------------------------------------------------------------------- /zero/fairscale_utils/gpt2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "fairscale", 3 | "model": { 4 | "type": "gpt2_10b" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 1, 8 | "num_epochs": 2, 9 | "steps_per_epoch": 10, 10 | "synthetic": true 11 | }, 12 | "fp16": { 13 | "enabled": true, 14 | "init_scale": 32768, 15 | "growth_factor": 2.0, 16 | "backoff_factor": 0.5, 17 | "growth_interval": 1000 18 | }, 19 | "gradient_clipping": 1.0, 20 | "fsdp": { 21 | "reshard_after_forward": true, 22 | "mixed_precision": true, 23 | "fp32_reduce_scatter": false, 24 | "flatten_parameters": true, 25 | "move_params_to_cpu": true, 26 | "bucket_cap_mb": 25, 27 | "clear_autocast_cache": false, 28 | "force_input_to_fp32": false, 29 | "state_dict_on_rank_0_only": false 30 | }, 31 | "use_mem_monitor": true 32 | } -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue related to https://www.colossalai.org/ 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). 9 | - type: textarea 10 | attributes: 11 | label: 📚 The doc issue 12 | description: | 13 | **Description** What content in [Documentation](https://www.colossalai.org/) is an issue? 14 | **Location** Where is the issue location? 15 | **Expectation** What is your expected content about it? 16 | **Screenshots** If applicable, add screenshots to help explain your problem. 17 | **Suggestions** Tell us how we could improve the documentation. 18 | placeholder: | 19 | A clear and concise description of the issue. 20 | validations: 21 | required: true 22 | 23 | - type: markdown 24 | attributes: 25 | value: > 26 | Thanks for contributing 🎉! 
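bert/run.py above dispatches on the `method` and `model.type` fields of whichever JSON config file is passed via `--config` (see the READMEs). These JSON configs — bert/colossalai_utils/bert_config_*.json and the zero/*_utils configs such as the FairScale GPT-2 one above — keep the per-process hyperparameters and the parallel layout in a single file. As a rough illustration only — assuming `hyperparameter.batch_size` is the batch size per data-parallel rank, which the configs do not state explicitly — the effective global batch size could be estimated like this:

```
import json

def effective_global_batch_size(config_path: str, world_size: int) -> int:
    # Estimate only; assumes `hyperparameter.batch_size` is the per-data-parallel-rank batch size.
    with open(config_path) as f:
        cfg = json.load(f)
    per_rank = cfg["hyperparameter"]["batch_size"]
    parallel = cfg.get("parallel", {})
    pipeline = parallel.get("pipeline", 1)
    tensor_size = parallel.get("tensor", {}).get("size", 1)
    data_parallel = world_size // (pipeline * tensor_size)
    return per_rank * data_parallel

# e.g. bert_config_zerotppp.json (batch_size=8, pipeline=1, tensor size 2) on 8 GPUs -> 8 * 4 = 32
print(effective_global_batch_size("bert/colossalai_utils/bert_config_zerotppp.json", world_size=8))
```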
27 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl, gpt2_medium, gpt2_large 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 32 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '1d' 14 | 15 | NUM_EPOCHS = 20 16 | WARMUP_EPOCHS = 1 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | # fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl, gpt2_medium, gpt2_large 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 32 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '2d' 14 | 15 | NUM_EPOCHS = 20 16 | WARMUP_EPOCHS = 1 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | # fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl, gpt2_medium, gpt2_large 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 32 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 8 13 | TENSOR_PARALLEL_MODE = '3d' 14 | 15 | NUM_EPOCHS = 20 16 | WARMUP_EPOCHS = 1 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | # fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = 
f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_8b_1d_256.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 256 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '1d' 14 | 15 | NUM_EPOCHS = 60 16 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_8B, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_8b_1d_512.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 512 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '1d' 14 | 15 | NUM_EPOCHS = 60 16 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_8B, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl, gpt2_medium, gpt2_large 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 32 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 1 13 | TENSOR_PARALLEL_MODE = None 14 | 15 | NUM_EPOCHS = 20 16 | WARMUP_EPOCHS = 1 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | # fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | 
gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_xl_1d_256.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 1024 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '1d' 14 | 15 | NUM_EPOCHS = 60 16 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_2p5d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl, gpt2_medium, gpt2_large 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 32 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '2.5d' 14 | 15 | NUM_EPOCHS = 20 16 | WARMUP_EPOCHS = 1 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=1), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | # fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_8b_2p5d_256.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 1280 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 16 13 | TENSOR_PARALLEL_MODE = '2.5d' 14 | 15 | NUM_EPOCHS = 60 16 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=1), 21 | ) 22 | 23 | model = dict(type=gpt2_8B, 24 | vocab_size=VOCAB_SIZE, 25 | 
max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_8b_2p5d_512.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 2560 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 16 13 | TENSOR_PARALLEL_MODE = '2.5d' 14 | 15 | NUM_EPOCHS = 60 16 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=1), 21 | ) 22 | 23 | model = dict(type=gpt2_8B, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /gpt/configs/gpt2_xl_2p5d_256.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from model_zoo.gpt import gpt2_8B, gpt2_xl 4 | 5 | VOCAB_SIZE = 50304 6 | SEQ_LENGTH = 1024 7 | 8 | TOTAL_BATCH_SIZE = 2048 9 | LEARNING_RATE = 0.00015 10 | WEIGHT_DECAY = 1e-2 11 | 12 | TENSOR_PARALLEL_SIZE = 4 13 | TENSOR_PARALLEL_MODE = '2.5d' 14 | 15 | NUM_EPOCHS = 60 16 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 17 | 18 | parallel = dict( 19 | pipeline=1, 20 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=1), 21 | ) 22 | 23 | model = dict(type=gpt2_xl, 24 | vocab_size=VOCAB_SIZE, 25 | max_position_embeddings=SEQ_LENGTH, 26 | dtype=torch.half, 27 | fuse_scale_mask_softmax=True, 28 | checkpoint=True) 29 | 30 | fp16 = dict(mode=AMP_TYPE.NAIVE) 31 | 32 | gradient_accumulation = 1 33 | 34 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 35 | 36 | clip_grad_norm = 1.0 37 | 38 | # LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). 
9 | - type: textarea 10 | attributes: 11 | label: 🐛 Describe the bug 12 | description: | 13 | **Describe the bug** 14 | A clear and concise description of what the bug is. 15 | **To Reproduce** 16 | Steps or code snippet to reproduce the behavior. 17 | **Expected behavior** 18 | A clear and concise description of what you expected to happen. 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | placeholder: | 22 | A clear and concise description of what the bug is. 23 | validations: 24 | required: true 25 | - type: textarea 26 | attributes: 27 | label: Environment 28 | description: | 29 | Please provide the environment information, eg. CUDA/cuDNN/NCCL/Python/PyTorch version. 30 | 31 | - type: markdown 32 | attributes: 33 | value: > 34 | Thanks for contributing 🎉! 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Suggest an idea for this project 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Not suitable for your needs? [Open a blank issue](https://github.com/hpcaitech/ColossalAI/issues/new). 9 | - type: textarea 10 | attributes: 11 | label: Describe the feature 12 | description: | 13 | **Is your feature request related to a problem? Please describe.** 14 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 15 | **Describe the solution you'd like** 16 | A clear and concise description of what you want to happen. 17 | **Describe alternatives you've considered** 18 | A clear and concise description of any alternative solutions or features you've considered. 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | **Suggest a potential alternative/fix** 22 | Tell us how we could improve this project. 23 | placeholder: | 24 | A clear and concise description of your idea. 25 | validations: 26 | required: true 27 | 28 | - type: markdown 29 | attributes: 30 | value: > 31 | Thanks for contributing 🎉! 
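Stepping back to the gpt/configs modules listed above: they all derive `BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation`, and the pipeline variants additionally split `BATCH_SIZE` into micro-batches. The small sketch below just mirrors that arithmetic so the divisibility requirements are explicit; it is illustrative only and not code from the repository.

```
from typing import Optional

def derive_gpt_batch_sizes(total_batch_size: int,
                           gradient_accumulation: int,
                           micro_batch_size: Optional[int] = None) -> dict:
    # Mirrors the arithmetic used in gpt/configs/*.py.
    assert total_batch_size % gradient_accumulation == 0, "TOTAL_BATCH_SIZE must divide evenly"
    batch_size = total_batch_size // gradient_accumulation
    sizes = {"BATCH_SIZE": batch_size}
    if micro_batch_size is not None:  # pipeline configs only
        assert batch_size % micro_batch_size == 0, "BATCH_SIZE must be a multiple of MICRO_BATCH_SIZE"
        sizes["NUM_MICRO_BATCHES"] = batch_size // micro_batch_size
    return sizes

# gpt3_vanilla.py: TOTAL_BATCH_SIZE=64, accumulation=1, 16 micro-batches imply a micro batch size of 4
print(derive_gpt_batch_sizes(64, 1, micro_batch_size=4))  # {'BATCH_SIZE': 64, 'NUM_MICRO_BATCHES': 16}
```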
32 | -------------------------------------------------------------------------------- /zero/patrickstar_utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from zero.common.utils import CONFIG, get_gpu_memory_mb, print_log 5 | from torch.distributed import init_process_group 6 | 7 | 8 | def init_w_ps(builder): 9 | from patrickstar.runtime import initialize_engine 10 | 11 | config = CONFIG.copy() 12 | 13 | rank = int(os.environ['RANK']) 14 | world_size = int(os.environ['WORLD_SIZE']) 15 | host = os.environ['MASTER_ADDR'] 16 | port = int(os.environ['MASTER_PORT']) 17 | init_process_group(rank=rank, world_size=world_size, init_method=f'tcp://{host}:{port}', backend='nccl') 18 | 19 | torch.cuda.set_device(rank) 20 | if CONFIG.get('gpu_mem_fraction', None) is not None: 21 | torch.cuda.set_per_process_memory_fraction(CONFIG['gpu_mem_fraction']) 22 | print_log(f'Set max GPU mem: {get_gpu_memory_mb() * CONFIG["gpu_mem_fraction"]:.2f} MB') 23 | 24 | build_data, build_model, build_loss, _, build_scheduler = builder() 25 | 26 | train_data, test_data = build_data() 27 | 28 | criterion = build_loss() 29 | 30 | model, optimizer = initialize_engine(model_func=build_model, local_rank=rank, config=config) 31 | 32 | lr_scheduler = build_scheduler(len(train_data), optimizer) 33 | 34 | return model, train_data, test_data, criterion, optimizer, None, lr_scheduler 35 | -------------------------------------------------------------------------------- /zero/deepspeed_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from zero.common.utils import CONFIG, get_gpu_memory_mb, print_log 3 | 4 | 5 | def init_w_ds(builder): 6 | import deepspeed 7 | 8 | config = CONFIG.copy() 9 | 10 | deepspeed.init_distributed() 11 | 12 | if CONFIG.get('gpu_mem_fraction', None) is not None: 13 | torch.cuda.set_per_process_memory_fraction(CONFIG['gpu_mem_fraction']) 14 | print_log(f'Set max GPU mem: {get_gpu_memory_mb() * CONFIG["gpu_mem_fraction"]:.2f} MB') 15 | 16 | build_data, build_model, build_loss, build_optimizer, build_scheduler = builder() 17 | 18 | train_data, test_data = build_data() 19 | 20 | with deepspeed.zero.Init(config_dict_or_path=config): 21 | model = build_model() 22 | 23 | criterion = build_loss() 24 | 25 | optimizer = build_optimizer(model.parameters()) 26 | 27 | lr_scheduler = build_scheduler(len(train_data), optimizer) 28 | 29 | model, optimizer, _, lr_scheduler = deepspeed.initialize(model=model, 30 | optimizer=optimizer, 31 | lr_scheduler=lr_scheduler, 32 | config=config) 33 | 34 | return model, train_data, test_data, criterion, optimizer, None, lr_scheduler 35 | -------------------------------------------------------------------------------- /zero/README.md: -------------------------------------------------------------------------------- 1 | # GPT2 ZeRO Benchmark 2 | GPT2 ZeRO benchmark with data parallelism to evaluate Colossal-AI, DeepSpeed, FairScale and PatrickStar. 3 | 4 | ## Requirements 5 | ``` 6 | CUDA>=11.3 7 | torch>=1.10.0 8 | deepspeed>=0.5.8 9 | fairscale>=0.4.5 10 | patrickstar>=0.4.6 11 | nvidia-dali>=1.8.0 12 | ``` 13 | 14 | ## Setup 15 | 1. Install dependencies if you do not have them 16 | ``` 17 | pip install -r requirement.txt 18 | ``` 19 | 2. Also, download PatrickStar from github 20 | ``` 21 | https://github.com/Tencent/PatrickStar.git 22 | ``` 23 | 3. Install PatrickStar 24 | ``` 25 | cd PatrickStar 26 | pip install . 27 | ``` 28 | 4. 
Add root dir into PYTHONPATH 29 | ``` 30 | export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH 31 | ``` 32 | 33 | ## GPT Usage 34 | 35 | 1. Prepare datasets and tokenizers from HuggingFace Hub if necessary (e.g. we provide an example of training `wikitext-2`). 36 | 37 | 2. Run benchmark with one of the systems to evaluate 38 | ``` 39 | DATA=/PATH/TO/DATASET TOKENIZER=/PATH/TO/TOKENIZER LOG=/PATH/TO/LOG torchrun --nproc_per_node=NUM_GPUS run.py --config=CONFIG_FILE 40 | ``` 41 | 42 | ## VIT Usage 43 | 1. Prepare ImageNet-1k datasets (TFrecord version). 44 | 45 | 2. Run benchmark with one of the systems to evaluate 46 | ``` 47 | DATA=/PATH/TO/DATASET LOG=/PATH/TO/LOG torchrun --nproc_per_node=NUM_GPUS run.py --config=CONFIG_FILE 48 | ``` 49 | -------------------------------------------------------------------------------- /zero/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from zero.colossalai_utils.utils import init_w_col 4 | from zero.common.gpt2 import gpt2_builder 5 | from zero.common.train import train 6 | from zero.common.utils import CONFIG, load_config, print_log 7 | from zero.common.vit import vit_builder 8 | from zero.deepspeed_utils.utils import init_w_ds 9 | from zero.fairscale_utils.utils import init_w_fs 10 | from zero.patrickstar_utils.utils import init_w_ps 11 | from zero.torch_utils.utils import init_w_torch 12 | 13 | _zero_method = { 14 | 'fairscale': init_w_fs, 15 | 'colossalai': init_w_col, 16 | 'torch': init_w_torch, 17 | 'patrickstar': init_w_ps, 18 | 'deepspeed': init_w_ds 19 | } 20 | 21 | _builder = { 22 | 'gpt2': gpt2_builder, 23 | 'vit': vit_builder, 24 | } 25 | 26 | 27 | def run_zero(): 28 | method = CONFIG['method'] 29 | assert method in ['colossalai', 'deepspeed', 'fairscale', 'patrickstar', 'torch'], f'No support for {method}.' 30 | 31 | model = CONFIG['model']['type'] 32 | model_type = model.split('_')[0] 33 | assert model_type in ['gpt2', 'vit'], f'No support for {model}.' 
34 | 35 | train(*_zero_method[method](_builder[model_type])) 36 | 37 | 38 | if __name__ == '__main__': 39 | load_config() 40 | 41 | CONFIG['log_path'] = os.environ.get('LOG', '.') 42 | os.makedirs(CONFIG['log_path'], exist_ok=True) 43 | 44 | print_log(f'Initializing {CONFIG["method"]} ...') 45 | 46 | run_zero() 47 | -------------------------------------------------------------------------------- /zero/deepspeed_utils/vit_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "deepspeed", 3 | "model": { 4 | "type": "vit_h" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 4, 8 | "steps_per_epoch":10 9 | }, 10 | "train_batch_size": 32, 11 | "steps_per_print": 2147483647, 12 | "zero_optimization": { 13 | "stage": 3, 14 | "offload_optimizer": { 15 | "device": "cpu", 16 | "pin_memory": true, 17 | "buffer_count": 4, 18 | "fast_init": false 19 | }, 20 | "offload_param": { 21 | "device": "cpu", 22 | "pin_memory": true, 23 | "buffer_count": 5, 24 | "buffer_size": 1e8, 25 | "max_in_cpu": 1e9 26 | }, 27 | "allgather_partitions": true, 28 | "allgather_bucket_size": 5e8, 29 | "overlap_comm": true, 30 | "reduce_scatter": true, 31 | "reduce_bucket_size": 5e8, 32 | "contiguous_gradients": true, 33 | "stage3_max_live_parameters": 1e9, 34 | "stage3_max_reuse_distance": 1e9, 35 | "stage3_prefetch_bucket_size": 5e8, 36 | "stage3_param_persistence_threshold": 1e6 37 | }, 38 | "gradient_clipping": 1.0, 39 | "fp16": { 40 | "enabled": true, 41 | "loss_scale": 0, 42 | "initial_scale_power": 15, 43 | "loss_scale_window": 1000, 44 | "hysteresis": 2, 45 | "min_loss_scale": 1 46 | }, 47 | "use_mem_monitor": true 48 | } -------------------------------------------------------------------------------- /gpt/configs/gpt2_pp1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from colossalai.engine.schedule import PipelineSchedule 4 | from model_zoo.gpt import gpt2_8B_pipeline 5 | 6 | VOCAB_SIZE = 50304 7 | SEQ_LENGTH = 1024 8 | 9 | TOTAL_BATCH_SIZE = 64 10 | LEARNING_RATE = 0.00015 11 | WEIGHT_DECAY = 1e-2 12 | 13 | gradient_accumulation = 1 14 | 15 | clip_grad_norm = 1.0 16 | 17 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 18 | 19 | TENSOR_PARALLEL_SIZE = 4 20 | TENSOR_PARALLEL_MODE = '1d' 21 | 22 | PIPELINE_SIZE = 2 23 | MICRO_BATCH_SIZE = 4 24 | NUM_MICRO_BATCHES = BATCH_SIZE // MICRO_BATCH_SIZE 25 | 26 | NUM_EPOCHS = 20 27 | WARMUP_EPOCHS = 1 28 | 29 | parallel = dict( 30 | pipeline=PIPELINE_SIZE, 31 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 32 | ) 33 | 34 | model = dict(type=gpt2_8B_pipeline, 35 | vocab_size=VOCAB_SIZE, 36 | max_position_embeddings=SEQ_LENGTH, 37 | dtype=torch.half, 38 | checkpoint=True) 39 | 40 | schedule = dict(type=PipelineSchedule, 41 | num_microbatches=NUM_MICRO_BATCHES, 42 | tensor_shape=(MICRO_BATCH_SIZE, SEQ_LENGTH, 3072), 43 | scatter_gather_tensors=True) 44 | 45 | fp16 = dict(mode=AMP_TYPE.NAIVE, ) 46 | 47 | # LOG_PATH = f"./gpt3_{TENSOR_PARALLEL_MODE}_pp{PIPELINE_SIZE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 48 | -------------------------------------------------------------------------------- /zero/deepspeed_utils/gpt2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "deepspeed", 3 | "model": { 4 | "type": "gpt2_10b" 5 | }, 6 | "hyperparameter": { 7 | 
"batch_size": 20, 8 | "num_epochs": 2, 9 | "steps_per_epoch": 10, 10 | "synthetic": true 11 | }, 12 | "train_batch_size": 40, 13 | "steps_per_print": 2147483647, 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "cpu", 18 | "pin_memory": true, 19 | "buffer_count": 4, 20 | "fast_init": false 21 | }, 22 | "offload_param": { 23 | "device": "cpu", 24 | "pin_memory": true, 25 | "buffer_count": 5, 26 | "buffer_size": 1e8, 27 | "max_in_cpu": 1e9 28 | }, 29 | "allgather_partitions": true, 30 | "allgather_bucket_size": 5e8, 31 | "overlap_comm": true, 32 | "reduce_scatter": true, 33 | "reduce_bucket_size": 5e8, 34 | "contiguous_gradients": true, 35 | "stage3_max_live_parameters": 1e9, 36 | "stage3_max_reuse_distance": 1e9, 37 | "stage3_prefetch_bucket_size": 5e8, 38 | "stage3_param_persistence_threshold": 1e6 39 | }, 40 | "gradient_clipping": 1.0, 41 | "fp16": { 42 | "enabled": true, 43 | "loss_scale": 0, 44 | "initial_scale_power": 5, 45 | "loss_scale_window": 1000, 46 | "hysteresis": 2, 47 | "min_loss_scale": 1 48 | }, 49 | "use_mem_monitor": true 50 | } -------------------------------------------------------------------------------- /zero/torch_utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from zero.common.utils import CONFIG, get_gpu_memory_mb, get_model_size, print_log 5 | from torch.distributed import init_process_group 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | 8 | 9 | def init_w_torch(builder): 10 | rank = int(os.environ['RANK']) 11 | world_size = int(os.environ['WORLD_SIZE']) 12 | host = os.environ['MASTER_ADDR'] 13 | port = int(os.environ['MASTER_PORT']) 14 | init_process_group(rank=rank, world_size=world_size, init_method=f'tcp://{host}:{port}', backend='nccl') 15 | 16 | torch.cuda.set_device(rank) 17 | if CONFIG.get('gpu_mem_fraction', None) is not None: 18 | torch.cuda.set_per_process_memory_fraction(CONFIG['gpu_mem_fraction']) 19 | print_log(f'Set max GPU mem: {get_gpu_memory_mb() * CONFIG["gpu_mem_fraction"]:.2f} MB') 20 | 21 | build_data, build_model, build_loss, build_optimizer, build_scheduler = builder() 22 | 23 | train_data, test_data = build_data() 24 | 25 | model = build_model().to(rank) 26 | if 'numel' not in CONFIG['model']: 27 | CONFIG['model']['numel'] = get_model_size(model) 28 | model = DDP(model) 29 | 30 | criterion = build_loss() 31 | 32 | optimizer = build_optimizer(model.parameters()) 33 | 34 | scaler = torch.cuda.amp.GradScaler(**CONFIG['fp16']) if 'fp16' in CONFIG else None 35 | 36 | lr_scheduler = build_scheduler(len(train_data), optimizer) 37 | 38 | return model, train_data, test_data, criterion, optimizer, scaler, lr_scheduler 39 | -------------------------------------------------------------------------------- /zero/deepspeed_utils/gpt2_nvme_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "deepspeed", 3 | "model": { 4 | "type": "gpt2_32b" 5 | }, 6 | "hyperparameter": { 7 | "batch_size": 4, 8 | "num_epochs": 2, 9 | "steps_per_epoch": 4, 10 | "synthetic": true 11 | }, 12 | "train_batch_size": 4, 13 | "steps_per_print": 2147483647, 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "nvme", 18 | "nvme_path": "/data/user/offload", 19 | "pin_memory": false, 20 | "buffer_count": 4, 21 | "fast_init": false 22 | }, 23 | "offload_param": { 24 | "device": "cpu", 25 | "pin_memory": true, 26 | "buffer_count": 5, 27 | 
"buffer_size": 1e8, 28 | "max_in_cpu": 1e9 29 | }, 30 | "allgather_partitions": true, 31 | "allgather_bucket_size": 5e8, 32 | "overlap_comm": true, 33 | "reduce_scatter": true, 34 | "reduce_bucket_size": 5e8, 35 | "contiguous_gradients": true, 36 | "stage3_max_live_parameters": 1e9, 37 | "stage3_max_reuse_distance": 1e9, 38 | "stage3_prefetch_bucket_size": 5e8, 39 | "stage3_param_persistence_threshold": 1e6 40 | }, 41 | "gradient_clipping": 1.0, 42 | "fp16": { 43 | "enabled": true, 44 | "loss_scale": 0, 45 | "initial_scale_power": 5, 46 | "loss_scale_window": 1000, 47 | "hysteresis": 2, 48 | "min_loss_scale": 1 49 | }, 50 | "use_mem_monitor": true 51 | } -------------------------------------------------------------------------------- /gpt/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | from colossalai.registry import DATASETS 6 | from torch.utils.data import Dataset 7 | from transformers import GPT2Tokenizer 8 | 9 | 10 | @DATASETS.register_module 11 | class WebtextDataset(Dataset): 12 | def __init__(self, path, seq_len=1024) -> None: 13 | super().__init__() 14 | root = os.path.dirname(path) 15 | encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') 16 | if os.path.isfile(encoded_data_cache_path): 17 | seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) 18 | if seq_len_ == seq_len: 19 | self.data = data 20 | self.attention_mask = attention_mask 21 | return 22 | raw_data = [] 23 | with open(path) as f: 24 | for line in f.readlines(): 25 | raw_data.append(json.loads(line)['text']) 26 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 27 | tokenizer.pad_token = tokenizer.unk_token 28 | encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') 29 | self.data = encoded_data['input_ids'] 30 | self.attention_mask = encoded_data['attention_mask'] 31 | torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path) 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | def __getitem__(self, index): 37 | return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index] 38 | -------------------------------------------------------------------------------- /gpt/configs/gpt3_pp1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colossalai.amp import AMP_TYPE 3 | from colossalai.engine.schedule import InterleavedPipelineSchedule 4 | from model_zoo.gpt import gpt3_pipeline 5 | 6 | VOCAB_SIZE = 50304 7 | SEQ_LENGTH = 2048 8 | 9 | TOTAL_BATCH_SIZE = 192 10 | LEARNING_RATE = 0.00015 11 | WEIGHT_DECAY = 1e-2 12 | 13 | gradient_accumulation = 1 14 | 15 | clip_grad_norm = 1.0 16 | 17 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation 18 | 19 | TENSOR_PARALLEL_SIZE = 4 20 | TENSOR_PARALLEL_MODE = '1d' 21 | 22 | PIPELINE_SIZE = 32 23 | MICRO_BATCH_SIZE = 1 24 | NUM_MICRO_BATCHES = BATCH_SIZE // MICRO_BATCH_SIZE 25 | NUM_CHUNKS = 1 26 | 27 | NUM_EPOCHS = 20 28 | WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) 29 | 30 | parallel = dict( 31 | pipeline=PIPELINE_SIZE, 32 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 33 | ) 34 | 35 | model = dict(type=gpt3_pipeline, 36 | num_chunks=NUM_CHUNKS, 37 | vocab_size=VOCAB_SIZE, 38 | max_position_embeddings=SEQ_LENGTH, 39 | dtype=torch.half, 40 | fuse_scale_mask_softmax=True, 41 | checkpoint=True) 42 | 43 | schedule = 
dict(type=InterleavedPipelineSchedule, 44 | num_microbatches=NUM_MICRO_BATCHES, 45 | num_model_chunks=NUM_CHUNKS, 46 | tensor_shape=(MICRO_BATCH_SIZE, SEQ_LENGTH, 12288), 47 | scatter_gather_tensors=True) 48 | 49 | fp16 = dict(mode=AMP_TYPE.NAIVE, ) 50 | 51 | # LOG_PATH = f"./gpt3_{TENSOR_PARALLEL_MODE}_pp{PIPELINE_SIZE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" 52 | -------------------------------------------------------------------------------- /zero/fairscale_utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from zero.common.utils import CONFIG, get_gpu_memory_mb, print_log 5 | from torch.distributed import init_process_group 6 | 7 | 8 | def init_w_fs(builder): 9 | from fairscale.nn.checkpoint import checkpoint_wrapper 10 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP 11 | from fairscale.optim.grad_scaler import ShardedGradScaler 12 | 13 | rank = int(os.environ['RANK']) 14 | world_size = int(os.environ['WORLD_SIZE']) 15 | host = os.environ['MASTER_ADDR'] 16 | port = int(os.environ['MASTER_PORT']) 17 | init_process_group(rank=rank, world_size=world_size, init_method=f'tcp://{host}:{port}', backend='nccl') 18 | 19 | torch.cuda.set_device(rank) 20 | if CONFIG.get('gpu_mem_fraction', None) is not None: 21 | torch.cuda.set_per_process_memory_fraction(CONFIG['gpu_mem_fraction']) 22 | print_log(f'Set max GPU mem: {get_gpu_memory_mb() * CONFIG["gpu_mem_fraction"]:.2f} MB') 23 | 24 | build_data, build_model, build_loss, build_optimizer, build_scheduler = builder() 25 | 26 | train_data, test_data = build_data() 27 | 28 | assert 'fsdp' in CONFIG 29 | use_checkpoint = CONFIG['model'].get('checkpoint') 30 | CONFIG['model']['checkpoint'] = False 31 | model = build_model() 32 | if use_checkpoint: 33 | model = checkpoint_wrapper(model) 34 | model = FSDP(model, **CONFIG['fsdp']) 35 | 36 | criterion = build_loss() 37 | 38 | optimizer = build_optimizer(model.parameters()) 39 | 40 | scaler = ShardedGradScaler(**CONFIG['fp16']) if 'fp16' in CONFIG else None 41 | 42 | lr_scheduler = build_scheduler(len(train_data), optimizer) 43 | 44 | return model, train_data, test_data, criterion, optimizer, scaler, lr_scheduler 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/.build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # IDE 133 | .idea/ 134 | .vscode/ 135 | 136 | # macos 137 | .DS_Store 138 | #data/ 139 | 140 | docs/.build -------------------------------------------------------------------------------- /bert/colossalai_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from zero.common.utils import CONFIG, print_log 3 | from torch.cuda import max_memory_allocated, reset_peak_memory_stats 4 | from torch.distributed import get_rank 5 | 6 | def init_w_col(builder): 7 | import colossalai 8 | from colossalai.core import global_context as gpc 9 | from colossalai.nn.optimizer import CPUAdam 10 | from colossalai.zero.init_ctx import ZeroInitContext 11 | from colossalai.zero.shard_utils import (BucketTensorShardStrategy) 12 | 13 | from colossalai.utils.memory_utils.utils import colo_set_process_memory_fraction 14 | colo_set_process_memory_fraction(0.2) 15 | 16 | colossalai.launch_from_torch(config=CONFIG) 17 | 18 | build_data, build_model, build_loss, optimizer_class, build_scheduler = builder() 19 | 20 | print_log('Building data') 21 | train_data, test_data = build_data() 22 | 23 | use_zero = "zero" in gpc.config 24 | if use_zero: 25 | cpu_offload = gpc.config.zero.model_config.offload_config.device == 'cpu' 26 | else: 27 | cpu_offload = None 28 | 29 | rank = get_rank() 30 | reset_peak_memory_stats(rank) 31 | 32 | print_log('Building model') 33 | if use_zero: 34 | shard_strategy = BucketTensorShardStrategy() 35 | with ZeroInitContext(target_device=torch.cuda.current_device(), 36 | shard_strategy=shard_strategy, 37 | shard_param=True): 38 | model = build_model() 39 | gpc.config.zero.model_config['shard_strategy'] = shard_strategy 40 | 
41 | else: 42 | model = build_model() 43 | 44 | criterion = build_loss() 45 | 46 | print_log(f'Peak Memory = {max_memory_allocated(rank) / (1024 * 1024)} M') 47 | reset_peak_memory_stats(rank) 48 | 49 | optimizer_kwargs = {} 50 | if use_zero and cpu_offload: 51 | optimizer_class = CPUAdam 52 | optimizer_kwargs = { 53 | 'lr': CONFIG['hyperparameter']['learning_rate'], 54 | 'weight_decay': CONFIG['hyperparameter']['weight_decay'] 55 | } 56 | 57 | optimizer = optimizer_class(model.parameters(), **optimizer_kwargs) 58 | 59 | lr_scheduler = build_scheduler(len(train_data), optimizer) 60 | print_log(f'Peak Memory = {max_memory_allocated(rank) / (1024 * 1024)} M') 61 | 62 | engine, train_data, test_data, lr_scheduler = colossalai.initialize(model, 63 | optimizer, 64 | criterion, 65 | train_data, 66 | test_data, 67 | lr_scheduler) 68 | model = engine 69 | criterion = engine.criterion 70 | optimizer = engine 71 | 72 | return model, train_data, test_data, criterion, optimizer, None, lr_scheduler 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Benchmark for Tuning Accuracy and Efficiency 2 | 3 | ## Overview 4 | 5 | The benchmark includes our efforts in using Colossal-AI to train different tasks to achieve SOTA results. 6 | We are interested in both validation accuracy and training speed, and prefer larger batch sizes to take advantage of more GPU devices. 7 | For example, we trained a vision transformer with batch size 512 on CIFAR10 and 4096 on ImageNet1k, which are rarely used in existing works. 8 | Some of the results in the benchmark, trained with 8x A100 GPUs, are shown below. 9 | 10 | | Task | Model | Training Time | Top-1 Accuracy | 11 | | ---------- | ------------ | ------------- | -------------- | 12 | | CIFAR10 | [ViT-Lite-7/4](https://arxiv.org/pdf/2104.05704.pdf) | ~ 16 min | ~ 90.5% | 13 | | ImageNet1k | ViT-S/16 | ~ 16.5 h | ~ 74.5% | 14 | 15 | The `train.py` script in each task runs training with the specific configuration script in `configs/` for different parallelisms. 16 | Supported parallelisms include data parallel only (ends with `vanilla`), 1D (ends with `1d`), 2D (ends with `2d`), 2.5D (ends with `2p5d`), and 3D (ends with `3d`). 17 | 18 | Each configuration script includes the following elements, taking the ImageNet1k task as an example: 19 | ``` 20 | TOTAL_BATCH_SIZE = 4096 21 | LEARNING_RATE = 3e-3 22 | WEIGHT_DECAY = 0.3 23 | 24 | NUM_EPOCHS = 300 25 | WARMUP_EPOCHS = 32 26 | 27 | # data parallel only 28 | TENSOR_PARALLEL_SIZE = 1 29 | TENSOR_PARALLEL_MODE = None 30 | 31 | # parallelism setting 32 | parallel = dict( 33 | pipeline=1, 34 | tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), 35 | ) 36 | 37 | fp16 = dict(mode=AMP_TYPE.TORCH, ) # amp setting 38 | 39 | gradient_accumulation = 2 # accumulate 2 steps for gradient update 40 | 41 | BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation # actual batch size for dataloader 42 | 43 | clip_grad_norm = 1.0 # clip gradient with norm 1.0 44 | ``` 45 | Upper-case elements are what `train.py` reads directly, and lower-case elements are what Colossal-AI needs to initialize the training, as sketched below.
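To make the split concrete, here is a minimal sketch of how a training script consumes these values after launching Colossal-AI. It mirrors the `train.py` scripts in this repo; the tiny placeholder model and the config path are for illustration only and are not part of the benchmark:
```
import colossalai
import torch
from colossalai.core import global_context as gpc

def main():
    # lower-case elements (parallel, fp16, gradient_accumulation, clip_grad_norm)
    # are parsed by Colossal-AI when the context is launched and later consumed
    # by colossalai.initialize(...)
    colossalai.launch_from_torch(config='./configs/vit_vanilla.py', seed=42)

    # upper-case elements are read directly by the training script
    per_rank_batch_size = gpc.config.BATCH_SIZE // gpc.data_parallel_size
    model = torch.nn.Linear(32, 10)  # placeholder model, for illustration only
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=gpc.config.LEARNING_RATE,
                                  weight_decay=gpc.config.WEIGHT_DECAY)
    print(f'per-rank batch size = {per_rank_batch_size}, optimizer = {type(optimizer).__name__}')

if __name__ == '__main__':
    main()
```
Like `train.py`, such a script has to be started through a distributed launcher (e.g. `torchrun`), since `launch_from_torch` reads the usual `RANK`/`WORLD_SIZE`/`MASTER_ADDR` environment variables.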
46 | 47 | ## Usage 48 | 49 | To start training, use the following command to run each worker: 50 | ``` 51 | $ DATA=/path/to/dataset python train.py --world_size=WORLD_SIZE \ 52 | --rank=RANK \ 53 | --local_rank=LOCAL_RANK \ 54 | --host=MASTER_IP_ADDRESS \ 55 | --port=MASTER_PORT \ 56 | --config=CONFIG_FILE 57 | ``` 58 | It is also recommended to start training with `torchrun` as: 59 | ``` 60 | $ DATA=/path/to/dataset torchrun --nproc_per_node=NUM_GPUS_PER_NODE \ 61 | --nnodes=NUM_NODES \ 62 | --node_rank=NODE_RANK \ 63 | --master_addr=MASTER_IP_ADDRESS \ 64 | --master_port=MASTER_PORT \ 65 | train.py --config=CONFIG_FILE 66 | ``` -------------------------------------------------------------------------------- /zero/colossalai_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda import max_memory_allocated, reset_peak_memory_stats 3 | from torch.distributed import get_rank 4 | from zero.common.utils import (CONFIG, get_gpu_memory_mb, get_model_size, 5 | print_log) 6 | 7 | 8 | def init_w_col(builder): 9 | import colossalai 10 | from colossalai.core import global_context as gpc 11 | from colossalai.logging import disable_existing_loggers 12 | from colossalai.nn.optimizer import CPUAdam 13 | from colossalai.zero.init_ctx import ZeroInitContext 14 | from colossalai.zero.shard_utils import (BucketTensorShardStrategy, 15 | TensorShardStrategy) 16 | from colossalai.zero.sharded_model import ShardedModel, ShardedModelV2 17 | from colossalai.zero.sharded_optim import ShardedOptimizerV2 18 | 19 | disable_existing_loggers() 20 | colossalai.launch_from_torch(config=CONFIG) 21 | 22 | if CONFIG.get('gpu_mem_fraction', None) is not None: 23 | torch.cuda.set_per_process_memory_fraction(CONFIG['gpu_mem_fraction']) 24 | print_log(f'Set max GPU mem: {get_gpu_memory_mb() * CONFIG["gpu_mem_fraction"]:.2f} MB') 25 | 26 | build_data, build_model, build_loss, optimizer_class, build_scheduler = builder() 27 | 28 | print_log('Building data') 29 | train_data, test_data = build_data() 30 | 31 | use_v2 = gpc.config.zero.pop('version', 2) == 2 32 | 33 | cpu_offload = gpc.config.zero.offload_config.device == 'cpu' 34 | 35 | rank = get_rank() 36 | reset_peak_memory_stats(rank) 37 | 38 | print_log('Building model') 39 | if use_v2: 40 | shard_strategy = TensorShardStrategy() 41 | model_numel = torch.zeros(1, dtype=torch.long) 42 | with ZeroInitContext(target_device=torch.cuda.current_device(), 43 | shard_strategy=shard_strategy, 44 | shard_param=True, 45 | model_numel_tensor=model_numel, 46 | rm_torch_payload_on_the_fly=True): 47 | model = build_model() 48 | model = ShardedModelV2(model, shard_strategy, **gpc.config.zero) 49 | if 'numel' not in CONFIG['model']: 50 | CONFIG['model']['numel'] = model_numel.item() 51 | print_log(f'model numel: {model_numel.item()}') 52 | else: 53 | model = build_model() 54 | if 'numel' not in CONFIG['model']: 55 | CONFIG['model']['numel'] = get_model_size(model) 56 | model = ShardedModel(model, **gpc.config.zero) 57 | 58 | criterion = build_loss() 59 | 60 | print_log(f'Peak Memory = {max_memory_allocated(rank) / (1024 * 1024)} M') 61 | reset_peak_memory_stats(rank) 62 | 63 | optimizer_kwargs = {} 64 | if cpu_offload: 65 | optimizer_class = CPUAdam 66 | optimizer_kwargs = { 67 | 'lr': CONFIG['hyperparameter']['learning_rate'], 68 | 'weight_decay': CONFIG['hyperparameter']['weight_decay'] 69 | } 70 | 71 | if use_v2: 72 | optimizer = optimizer_class(model.parameters(), **optimizer_kwargs) 73 | optimizer = 
ShardedOptimizerV2(model, 74 | optimizer, 75 | **gpc.config.get('fp16', dict()), 76 | cpu_offload=cpu_offload) 77 | else: 78 | optimizer = optimizer_class(model.parameters()) 79 | 80 | lr_scheduler = build_scheduler(len(train_data), optimizer) 81 | print_log(f'Peak Memory = {max_memory_allocated(rank) / (1024 * 1024)} M') 82 | 83 | return model, train_data, test_data, criterion, optimizer, None, lr_scheduler 84 | -------------------------------------------------------------------------------- /zero/common/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | import torch 8 | from torch.distributed import get_rank, is_initialized 9 | 10 | CONFIG = dict() 11 | 12 | 13 | def load_config(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--config', type=str) 16 | args = parser.parse_args() 17 | 18 | config_file = args.config 19 | 20 | assert os.path.exists(config_file), 'No valid config file found.' 21 | 22 | with open(config_file, 'r') as f: 23 | cfg = json.load(f) 24 | for k, v in cfg.items(): 25 | CONFIG[k] = v 26 | 27 | 28 | class AsyncMemoryMonitor: 29 | 30 | def __init__(self, rank, power=3, save_to_disk=True): 31 | """ 32 | Adapted from https://github.com/Tencent/PatrickStar/blob/master/patrickstar/core/memtracer/memtracer.py. 33 | An Async Mem Monitor runing during computing. 34 | Sampling GPU memory usage of the current GPU dev 35 | at interval of 1/(10**power) sec. 36 | """ 37 | self.keep_measuring = False 38 | device = torch.cuda.current_device() 39 | self.executor = ThreadPoolExecutor(max_workers=1, initializer=lambda: torch.cuda.set_device(device)) 40 | self.monitor_thread = None 41 | self.interval = 1 / (10**power) 42 | self.rank = rank 43 | self.file = os.path.join(CONFIG['log_path'], f'memory_rank_{rank}.log') if save_to_disk else None 44 | 45 | def set_interval(self, power: int): 46 | self.interval = 1 / (10**power) 47 | 48 | def start(self): 49 | self.keep_measuring = True 50 | torch.cuda.reset_peak_memory_stats(self.rank) 51 | self.monitor_thread = self.executor.submit(self._measure_usage) 52 | 53 | def finish(self): 54 | if self.keep_measuring is False: 55 | return 0 56 | self.keep_measuring = False 57 | gpu_usage = self.monitor_thread.result() 58 | self.monitor_thread = None 59 | if self.file is not None: 60 | with open(self.file, 'a') as f: 61 | f.writelines(list(map(lambda x: str(x) + '\n', gpu_usage))) 62 | return gpu_usage 63 | 64 | def _measure_usage(self): 65 | gpu_usage = list() 66 | while self.keep_measuring: 67 | gpu_usage.append(torch.cuda.max_memory_allocated(self.rank) / (1024 * 1024)) # MB 68 | torch.cuda.reset_peak_memory_stats(self.rank) 69 | time.sleep(self.interval) 70 | 71 | return gpu_usage 72 | 73 | 74 | def print_log(msg): 75 | msg = f'{time.asctime()} > {msg}' 76 | rank = get_rank() if is_initialized() else 0 77 | log_file = os.path.join(CONFIG['log_path'], f'training_rank_{rank}.log') 78 | with open(log_file, 'a') as f: 79 | f.write(msg + '\n') 80 | if rank == 0: 81 | print(msg) 82 | 83 | 84 | class ModelFromHF(torch.nn.Module): 85 | 86 | def __init__(self, config, model_cls): 87 | super().__init__() 88 | self.module = model_cls(config) 89 | if CONFIG['model'].get('checkpoint'): 90 | self.module.apply(self.set_checkpointing) 91 | 92 | def set_checkpointing(self, module): 93 | if hasattr(module, 'gradient_checkpointing'): 94 | module.gradient_checkpointing = True 95 | 96 
| def forward(self, *args, **kwargs): 97 | output = self.module(*args, **kwargs) 98 | return output.logits 99 | 100 | 101 | def get_tflops(iter_time: float, num_tokens: int) -> float: 102 | flops = CONFIG['model']['numel'] * num_tokens * 2.0 * 4.0 103 | return (flops / 1e12) / (iter_time + 1e-12) 104 | 105 | 106 | def get_model_size(model: torch.nn.Module): 107 | return sum(p.numel() for p in model.parameters()) 108 | 109 | 110 | def get_gpu_memory_mb(): 111 | return torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / 1024**2 112 | -------------------------------------------------------------------------------- /gpt/readme.md: -------------------------------------------------------------------------------- 1 | # Step1: Datasets 2 | We do not host any datasets for GPT or BERT training; however, we detail their collection so that our results may be reproduced. 3 | 4 | 5 | ## Option1: Use a Toy Dataset 6 | If you just want to go through the process quickly, you can download a toy dataset (about 80 MB). 7 | Download it from [Google Drive](https://drive.google.com/file/d/1eCY30B9g-I3oPdtQHR8rmIxx64Js_LZh/view?usp=sharing). 8 | 9 | 10 | ## Option2: Prepare the Webtext Data 11 | ### Collecting GPT Webtext Data 12 | We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [shenggan's](https://github.com/Shenggan/openwebtext) work (modified from [eukaryote31's](https://github.com/eukaryote31/openwebtext)) to download URLs. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in Megatron's [openwebtext](./tools/openwebtext) directory. 13 | 14 | #### Install necessary packages 15 | 16 | ``` 17 | pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract cached-path 18 | git clone git@github.com:Shenggan/LSH.git 19 | cd LSH 20 | python setup.py install 21 | ``` 22 | 23 | #### Download Data 24 | 25 | 1. Download the deduplicated URLs `` from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ). 26 | 27 | 2. Remove blacklisted URLs. 28 | 29 | git clone git@github.com:WANG-CR/Megatron-LM.git 30 | python Megatron-LM/tools/openwebtext/blacklist_urls.py 31 | 32 | 3. Download the content from the clean URLs and merge the contents into one loose JSON file, with one JSON object per line in the format `{'text': text, 'url': unique_url}`. The output content will be called 33 | 34 | ``` 35 | git clone git@github.com:Shenggan/openwebtext.git 36 | python openwebtext/download.py --n_procs 50 37 | ``` 38 | 39 | #### Prepare Data for GPT Training 40 | 41 | 1. Perform ftfy, English detection, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards. 42 | 43 | ``` 44 | cd Megatron-LM/tools/openwebtext 45 | python cleanup_dataset.py 46 | ``` 47 | 48 | Additional cleanup (e.g. removing documents with fewer than 512 characters, or dataset-specific cleaning for the stories and realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`. 49 | 50 | 2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplication, and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
51 | 52 | ``` 53 | python find_duplicates.py --inputs --output 54 | ``` 55 | 56 | 3. Based on the similarity measure defined inside the function `is_similar` (default threshold: 0.9), group URLs that are similar. For each group, keep only one URL and remove the rest. 57 | 58 | ``` 59 | python group_duplicate_urls.py 60 | ``` 61 | 62 | 4. Remove similar documents that were detected in the last step. 63 | 64 | ``` 65 | python remove_group_duplicates.py 66 | ``` 67 | 68 | 5. Shuffle the dataset. 69 | 70 | ``` 71 | shuf -o 72 | ``` 73 | 74 | 75 | # Step2: Training 76 | 77 | Run GPT training using 4 GPUs with the vanilla parallel strategy. You can try other strategies by using different config files from `./configs`. 78 | The training lasts for 10 steps. 79 | 80 | ```bash 81 | NUM_GPUS_PER_NODE=4 82 | NUM_NODES=1 83 | NODE_RANK=0 84 | 85 | export EXEC="torchrun" 86 | export CONFIG="./configs/gpt2_vanilla.py" 87 | 88 | DATA=/your_own_path/small-gpt-dataset.json ${EXEC} --nproc_per_node=${NUM_GPUS_PER_NODE} \ 89 | --nnodes=${NUM_NODES} \ 90 | --node_rank=${NODE_RANK} \ 91 | train.py --from_torch --config=${CONFIG} 92 | ``` -------------------------------------------------------------------------------- /cifar/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import os 5 | 6 | import colossalai 7 | import torch 8 | import torchvision 9 | from colossalai.builder import * 10 | from colossalai.core import global_context as gpc 11 | from colossalai.logging import disable_existing_loggers, get_dist_logger 12 | from colossalai.nn import Accuracy, CrossEntropyLoss 13 | from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR 14 | from colossalai.trainer import Trainer, hooks 15 | from colossalai.utils import MultiTimer, get_dataloader 16 | from model_zoo.vit import vit_lite_depth7_patch4_32 17 | from torchvision import transforms 18 | 19 | DATASET_PATH = str(os.environ['DATA']) 20 | 21 | 22 | def build_cifar(batch_size): 23 | transform_train = transforms.Compose([ 24 | transforms.RandomCrop(32, padding=4), 25 | transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 28 | ]) 29 | transform_test = transforms.Compose([ 30 | transforms.Resize(32), 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 33 | ]) 34 | 35 | train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, 36 | train=True, 37 | download=True, 38 | transform=transform_train) 39 | test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test) 40 | train_dataloader = get_dataloader(dataset=train_dataset, 41 | shuffle=True, 42 | batch_size=batch_size, 43 | num_workers=4, 44 | pin_memory=True) 45 | test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True) 46 | return train_dataloader, test_dataloader 47 | 48 | 49 | def train_cifar(): 50 | disable_existing_loggers() 51 | parser = colossalai.get_default_parser() 52 | parser.add_argument('--from_torch', default=False, action='store_true') 53 | args = parser.parse_args() 54 | if args.from_torch: 55 | colossalai.launch_from_torch(config=args.config, seed=42) 56 | else: 57 | # standard launch 58 | colossalai.launch(config=args.config, 59 | rank=args.rank, 60 | world_size=args.world_size, 61 | local_rank=args.local_rank, 62 | 
host=args.host, 63 | port=args.port, 64 | seed=42) 65 | 66 | logger = get_dist_logger() 67 | if hasattr(gpc.config, 'LOG_PATH'): 68 | if gpc.get_global_rank() == 0: 69 | log_path = gpc.config.LOG_PATH 70 | if not os.path.exists(log_path): 71 | os.mkdir(log_path) 72 | logger.log_to_file(log_path) 73 | 74 | model = vit_lite_depth7_patch4_32() 75 | 76 | train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size) 77 | 78 | criterion = CrossEntropyLoss(label_smoothing=0.1) 79 | 80 | optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) 81 | 82 | steps_per_epoch = len(train_dataloader) 83 | 84 | lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, 85 | total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch, 86 | warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch) 87 | 88 | engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model, 89 | optimizer=optimizer, 90 | criterion=criterion, 91 | train_dataloader=train_dataloader, 92 | test_dataloader=test_dataloader, 93 | lr_scheduler=lr_scheduler) 94 | 95 | logger.info("Engine is built", ranks=[0]) 96 | 97 | timer = MultiTimer() 98 | 99 | trainer = Trainer(engine=engine, logger=logger, timer=timer) 100 | logger.info("Trainer is built", ranks=[0]) 101 | 102 | hook_list = [ 103 | hooks.LogMetricByEpochHook(logger=logger), 104 | hooks.LogMetricByStepHook(), 105 | # hooks.LogTimingByEpochHook(timer=timer, logger=logger), 106 | # hooks.LogMemoryByEpochHook(logger=logger), 107 | hooks.AccuracyHook(accuracy_func=Accuracy()), 108 | hooks.LossHook(), 109 | hooks.ThroughputHook(), 110 | hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False) 111 | ] 112 | 113 | logger.info("Train start", ranks=[0]) 114 | trainer.fit(train_dataloader=train_dataloader, 115 | test_dataloader=test_dataloader, 116 | epochs=gpc.config.NUM_EPOCHS, 117 | hooks=hook_list, 118 | display_progress=True, 119 | test_interval=1) 120 | 121 | 122 | if __name__ == '__main__': 123 | train_cifar() 124 | -------------------------------------------------------------------------------- /gpt/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import colossalai 4 | import torch 5 | from colossalai.core import global_context as gpc 6 | from colossalai.engine.schedule import InterleavedPipelineSchedule 7 | from colossalai.logging import disable_existing_loggers, get_dist_logger 8 | from colossalai.nn import CosineAnnealingWarmupLR 9 | from colossalai.trainer import Trainer, hooks 10 | from colossalai.utils import MultiTimer, get_dataloader, is_using_pp 11 | from model_zoo.gpt import GPTLMLoss 12 | 13 | from data import WebtextDataset 14 | 15 | 16 | def train_gpt(): 17 | disable_existing_loggers() 18 | parser = colossalai.get_default_parser() 19 | parser.add_argument('--from_torch', default=False, action='store_true') 20 | args = parser.parse_args() 21 | if args.from_torch: 22 | colossalai.launch_from_torch(config=args.config, seed=42) 23 | else: 24 | # standard launch 25 | colossalai.launch(config=args.config, 26 | rank=args.rank, 27 | world_size=args.world_size, 28 | local_rank=args.local_rank, 29 | host=args.host, 30 | port=args.port, 31 | seed=42) 32 | 33 | logger = get_dist_logger() 34 | if hasattr(gpc.config, 'LOG_PATH'): 35 | log_path = gpc.config.LOG_PATH 36 | if not os.path.exists(log_path): 37 | os.mkdir(log_path) 38 | logger.log_to_file(log_path) 39 | 40 | train_dataset = 
WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LENGTH) 41 | train_dataloader = get_dataloader(train_dataset, 42 | seed=42, 43 | batch_size=gpc.config.BATCH_SIZE // gpc.data_parallel_size, 44 | num_workers=1, 45 | pin_memory=True, 46 | shuffle=True, 47 | drop_last=True) 48 | logger.info(f'Loaded {len(train_dataset)}/{len(train_dataloader)} samples/batches', ranks=[0]) 49 | 50 | model = gpc.config.model.pop('type')(**gpc.config.model) 51 | if is_using_pp(): 52 | schedule = gpc.config.schedule.pop('type')(**gpc.config.schedule) 53 | # tensor_shape = getattr(gpc.config, 'TENSOR_SHAPE', None) 54 | # if hasattr(gpc.config, 'NUM_CHUNKS'): 55 | # schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, 56 | # gpc.config.NUM_CHUNKS, 57 | # tensor_shape=tensor_shape, 58 | # scatter_gather_tensors=True) 59 | if isinstance(schedule, InterleavedPipelineSchedule) and not isinstance(model, torch.nn.ModuleList): 60 | model = torch.nn.ModuleList([model]) 61 | # else: 62 | # schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, 63 | # tensor_shape=tensor_shape, 64 | # scatter_gather_tensors=True) 65 | else: 66 | schedule = None 67 | 68 | numel = 0 69 | for p in model.parameters(): 70 | numel += p.numel() 71 | logger.info( 72 | f'Rank {gpc.get_global_rank()}: {numel / (1024*1024):.2f} M parameters | memory usage = {torch.cuda.memory_allocated() / (1024 * 1024 * 1024):.2f} GB.' 73 | ) 74 | 75 | criterion = GPTLMLoss() 76 | 77 | optimizer = torch.optim.Adam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) 78 | 79 | steps_per_epoch = len(train_dataloader) // gpc.config.gradient_accumulation 80 | 81 | lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, 82 | total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch, 83 | warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch, 84 | eta_min=1e-5) 85 | 86 | engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model=model, 87 | optimizer=optimizer, 88 | criterion=criterion, 89 | train_dataloader=train_dataloader, 90 | lr_scheduler=lr_scheduler) 91 | 92 | timer = MultiTimer() 93 | 94 | trainer = Trainer(engine=engine, logger=logger, timer=timer, schedule=schedule) 95 | 96 | hook_list = [ 97 | hooks.LogMetricByEpochHook(logger=logger), 98 | hooks.LogMetricByStepHook(), 99 | hooks.LossHook(), 100 | hooks.ThroughputHook(ignored_steps=5), 101 | hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), 102 | # hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]), 103 | # hooks.LogMemoryByEpochHook(logger), 104 | # hooks.LogTimingByEpochHook(timer, logger, ignore_num_train_steps=5), 105 | # hooks.SaveCheckpointHook(checkpoint_dir='./ckpt') 106 | ] 107 | 108 | logger.info("Training start", ranks=[0]) 109 | torch.cuda.reset_peak_memory_stats() 110 | trainer.fit(train_dataloader=train_dataloader, 111 | epochs=gpc.config.NUM_EPOCHS, 112 | max_steps=10, 113 | hooks=hook_list, 114 | return_output_label=False, 115 | display_progress=True) 116 | 117 | logger.info( 118 | f'Rank {gpc.get_global_rank()}: peak memory usage = {torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024):.2f} GB.' 
119 | ) 120 | 121 | 122 | if __name__ == '__main__': 123 | train_gpt() 124 | -------------------------------------------------------------------------------- /bert/common/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.distributed import get_world_size 5 | from transformers import BertConfig, BertTokenizer 6 | 7 | from zero.common.utils import CONFIG, ModelFromHF, get_model_size 8 | from bert.colossalai_utils.model_zoo.colo_bert import ColoBertMaskedLMLoss, ColoBertForMaskedLM, create_colo_bert_pipeline_model 9 | 10 | _bert_base = dict( 11 | seq_length=512, 12 | vocab_size=50304, 13 | hidden_size=768, 14 | num_heads=12, 15 | depth=12, 16 | ff_size=3072, 17 | checkpoint=False, 18 | evaluation='ppl', 19 | ) 20 | 21 | _bert_large = dict( 22 | seq_length=512, 23 | vocab_size=50304, 24 | hidden_size=1024, 25 | num_heads=16, 26 | depth=24, 27 | ff_size=3072, 28 | checkpoint=False, 29 | evaluation='ppl', 30 | ) 31 | 32 | _bert_configurations = dict( 33 | bert=_bert_base, 34 | bert_base=_bert_base, 35 | bert_large=_bert_large 36 | ) 37 | 38 | _default_hyperparameters = dict( 39 | tokenize_mode='concat', 40 | batch_size=8, 41 | learning_rate=5e-5, 42 | weight_decay=1e-2, 43 | num_epochs=2, 44 | warmup_epochs=1, 45 | steps_per_epoch=100, 46 | ) 47 | 48 | 49 | def build_data(): 50 | import random 51 | from functools import partial 52 | from itertools import chain 53 | 54 | import numpy as np 55 | from datasets import load_from_disk, set_progress_bar_enabled 56 | from torch.utils.data import DataLoader, DistributedSampler 57 | from transformers import DataCollatorForLanguageModeling 58 | 59 | world_size = get_world_size() 60 | 61 | set_progress_bar_enabled(False) 62 | dataset = load_from_disk(CONFIG['dataset']) 63 | tokenizer = BertTokenizer(vocab_file=CONFIG['tokenizer'] + '/vocab.txt') 64 | 65 | def tokenize(examples, mode='concat'): 66 | assert mode in ['concat', 'pad'] 67 | seq_len = CONFIG['model']['seq_length'] 68 | if mode == 'concat': 69 | examples = tokenizer(examples['text']) 70 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 71 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 72 | if total_length >= seq_len: 73 | total_length = (total_length // seq_len) * seq_len 74 | 75 | result = { 76 | k: [t[i:i + seq_len] for i in range(0, total_length, seq_len)] 77 | for k, t in concatenated_examples.items() 78 | } 79 | else: 80 | tokenizer.pad_token = tokenizer.unk_token 81 | result = tokenizer(examples, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') 82 | 83 | return result 84 | 85 | tokenized_dataset = dataset.map(partial(tokenize, mode=CONFIG['hyperparameter']['tokenize_mode']), 86 | batched=True, 87 | num_proc=16, 88 | load_from_cache_file=False, 89 | keep_in_memory=True, 90 | remove_columns='text') 91 | 92 | def seed_worker(_): 93 | worker_seed = 1024 94 | np.random.seed(worker_seed) 95 | torch.manual_seed(worker_seed) 96 | random.seed(worker_seed) 97 | 98 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15) 99 | train_sampler = DistributedSampler(tokenized_dataset['train'], shuffle=True) if world_size > 1 else None 100 | train_data = DataLoader(tokenized_dataset['train'], 101 | shuffle=(train_sampler is None), 102 | sampler=train_sampler, 103 | drop_last=True, 104 | collate_fn=data_collator, 105 | worker_init_fn=seed_worker, 106 | 
batch_size=CONFIG['hyperparameter']['batch_size'], 107 | pin_memory=True) 108 | test_sampler = DistributedSampler(tokenized_dataset['validation'], shuffle=False) if world_size > 1 else None 109 | test_data = DataLoader(tokenized_dataset['validation'], 110 | sampler=test_sampler, 111 | drop_last=True, 112 | collate_fn=data_collator, 113 | worker_init_fn=seed_worker, 114 | batch_size=CONFIG['hyperparameter']['batch_size'], 115 | pin_memory=True) 116 | 117 | return train_data, test_data 118 | 119 | 120 | def build_model(): 121 | model_cfg = CONFIG['model'] 122 | bert_cfg = BertConfig(vocab_size=model_cfg['vocab_size'], 123 | hidden_size=model_cfg['hidden_size'], 124 | num_hidden_layers=model_cfg['depth'], 125 | num_attention_heads=model_cfg['num_heads'], 126 | intermediate_size=model_cfg['ff_size'], 127 | max_position_embeddings=model_cfg['seq_length'], 128 | use_cache=not CONFIG['model'].get('checkpoint', False)) 129 | 130 | use_pipeline = 'parallel' in CONFIG and 'pipeline' in CONFIG['parallel'] and int(CONFIG['parallel']['pipeline']) > 1 131 | if use_pipeline: 132 | model = create_colo_bert_pipeline_model(bert_cfg) 133 | else: 134 | model = ModelFromHF(bert_cfg, ColoBertForMaskedLM) 135 | 136 | return model 137 | 138 | def build_loss(): 139 | return ColoBertMaskedLMLoss() 140 | 141 | def build_optimizer(params): 142 | optimizer = torch.optim.AdamW(params, 143 | lr=CONFIG['hyperparameter']['learning_rate'], 144 | weight_decay=CONFIG['hyperparameter']['weight_decay']) 145 | return optimizer 146 | 147 | 148 | def build_scheduler(epoch_steps, optimizer): 149 | from transformers.optimization import get_linear_schedule_with_warmup 150 | from colossalai.nn.lr_scheduler import LinearWarmupLR 151 | 152 | max_steps = epoch_steps * CONFIG['hyperparameter']['num_epochs'] 153 | warmup_steps = epoch_steps * CONFIG['hyperparameter']['warmup_epochs'] 154 | 155 | if CONFIG['method'] == 'colossalai': 156 | lr_scheduler = LinearWarmupLR(optimizer, 157 | total_steps=max_steps, 158 | warmup_steps=warmup_steps) 159 | else: 160 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, 161 | num_warmup_steps=warmup_steps, 162 | num_training_steps=max_steps) 163 | 164 | return lr_scheduler 165 | 166 | 167 | def bert_builder(): 168 | model_type = CONFIG['model']['type'] 169 | if model_type in _bert_configurations: 170 | for k, v in _bert_configurations[model_type].items(): 171 | if k not in CONFIG['model']: 172 | CONFIG['model'][k] = v 173 | 174 | if 'hyperparameter' in CONFIG: 175 | for k, v in _default_hyperparameters.items(): 176 | if k not in CONFIG['hyperparameter']: 177 | CONFIG['hyperparameter'][k] = v 178 | else: 179 | CONFIG['hyperparameter'] = _default_hyperparameters 180 | 181 | CONFIG['dataset'] = os.environ['DATA'] 182 | CONFIG['tokenizer'] = os.environ['TOKENIZER'] 183 | if 'numel' not in CONFIG['model']: 184 | CONFIG['model']['numel'] = get_model_size(build_model()) 185 | 186 | return build_data, build_model, build_loss, build_optimizer, build_scheduler 187 | -------------------------------------------------------------------------------- /imagenet1k/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import glob 5 | import os 6 | 7 | import colossalai 8 | import nvidia.dali.fn as fn 9 | import nvidia.dali.tfrecord as tfrec 10 | import torch 11 | from colossalai.builder import * 12 | from colossalai.context import ParallelMode 13 | from colossalai.core import global_context as gpc 14 | from 
colossalai.logging import disable_existing_loggers, get_dist_logger 15 | from colossalai.nn import Accuracy, CrossEntropyLoss 16 | from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR 17 | from colossalai.trainer import Trainer, hooks 18 | from colossalai.utils import MultiTimer 19 | from model_zoo.vit import vit_small_patch16_224 20 | from nvidia.dali import types 21 | from nvidia.dali.pipeline import Pipeline 22 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator 23 | 24 | DATASET_PATH = str(os.environ['DATA']) 25 | 26 | TRAIN_RECS = DATASET_PATH + '/train/*' 27 | VAL_RECS = DATASET_PATH + '/validation/*' 28 | TRAIN_IDX = DATASET_PATH + '/idx_files/train/*' 29 | VAL_IDX = DATASET_PATH + '/idx_files/validation/*' 30 | 31 | 32 | class DaliDataloader(DALIClassificationIterator): 33 | 34 | def __init__(self, 35 | tfrec_filenames, 36 | tfrec_idx_filenames, 37 | shard_id=0, 38 | num_shards=1, 39 | batch_size=128, 40 | num_threads=4, 41 | resize=256, 42 | crop=224, 43 | prefetch=2, 44 | training=True, 45 | gpu_aug=False, 46 | cuda=True): 47 | pipe = Pipeline(batch_size=batch_size, 48 | num_threads=num_threads, 49 | device_id=torch.cuda.current_device() if cuda else None, 50 | seed=1024) 51 | with pipe: 52 | inputs = fn.readers.tfrecord(path=tfrec_filenames, 53 | index_path=tfrec_idx_filenames, 54 | random_shuffle=training, 55 | shard_id=shard_id, 56 | num_shards=num_shards, 57 | initial_fill=10000, 58 | read_ahead=True, 59 | prefetch_queue_depth=prefetch, 60 | name='Reader', 61 | features={ 62 | 'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""), 63 | 'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1), 64 | }) 65 | images = inputs["image/encoded"] 66 | 67 | if training: 68 | images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB) 69 | images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu') 70 | flip_lr = fn.random.coin_flip(probability=0.5) 71 | else: 72 | # decode jpeg and resize 73 | images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB) 74 | images = fn.resize(images, 75 | device='gpu' if gpu_aug else 'cpu', 76 | resize_x=resize, 77 | resize_y=resize, 78 | dtype=types.FLOAT, 79 | interp_type=types.INTERP_TRIANGULAR) 80 | flip_lr = False 81 | 82 | # center crop and normalise 83 | images = fn.crop_mirror_normalize(images, 84 | dtype=types.FLOAT, 85 | crop=(crop, crop), 86 | mean=[127.5], 87 | std=[127.5], 88 | mirror=flip_lr) 89 | label = inputs["image/class/label"] - 1 # 0-999 90 | # LSG: element_extract will raise exception, let's flatten outside 91 | # label = fn.element_extract(label, element_map=0) # Flatten 92 | if cuda: # transfer data to gpu 93 | pipe.set_outputs(images.gpu(), label.gpu()) 94 | else: 95 | pipe.set_outputs(images, label) 96 | 97 | pipe.build() 98 | last_batch_policy = 'DROP' if training else 'PARTIAL' 99 | super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy) 100 | 101 | def __iter__(self): 102 | # if not reset (after an epoch), reset; if just initialize, ignore 103 | if self._counter >= self._size or self._size < 0: 104 | self.reset() 105 | return self 106 | 107 | def __next__(self): 108 | data = super().__next__() 109 | img, label = data[0]['data'], data[0]['label'] 110 | label = label.squeeze() 111 | return img, label 112 | 113 | 114 | def build_dali_train(batch_size): 115 | return DaliDataloader( 116 | sorted(glob.glob(TRAIN_RECS)), 117 | 
sorted(glob.glob(TRAIN_IDX)), 118 | batch_size=batch_size, 119 | shard_id=gpc.get_local_rank(ParallelMode.DATA), 120 | num_shards=gpc.get_world_size(ParallelMode.DATA), 121 | training=True, 122 | gpu_aug=True, 123 | cuda=True, 124 | ) 125 | 126 | 127 | def build_dali_test(batch_size): 128 | return DaliDataloader( 129 | sorted(glob.glob(VAL_RECS)), 130 | sorted(glob.glob(VAL_IDX)), 131 | batch_size=batch_size, 132 | shard_id=gpc.get_local_rank(ParallelMode.DATA), 133 | num_shards=gpc.get_world_size(ParallelMode.DATA), 134 | training=False, 135 | gpu_aug=True, 136 | cuda=True, 137 | ) 138 | 139 | 140 | def train_imagenet(): 141 | disable_existing_loggers() 142 | parser = colossalai.get_default_parser() 143 | parser.add_argument('--from_torch', default=False, action='store_true') 144 | args = parser.parse_args() 145 | if args.from_torch: 146 | colossalai.launch_from_torch(config=args.config, seed=42) 147 | else: 148 | # standard launch 149 | colossalai.launch(config=args.config, 150 | rank=args.rank, 151 | world_size=args.world_size, 152 | local_rank=args.local_rank, 153 | host=args.host, 154 | port=args.port, 155 | seed=42) 156 | 157 | logger = get_dist_logger() 158 | if hasattr(gpc.config, 'LOG_PATH'): 159 | if gpc.get_global_rank() == 0: 160 | log_path = gpc.config.LOG_PATH 161 | if not os.path.exists(log_path): 162 | os.mkdir(log_path) 163 | logger.log_to_file(log_path) 164 | 165 | model = vit_small_patch16_224(num_classes=1000, init_method='jax') 166 | 167 | train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size) 168 | test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size) 169 | 170 | criterion = CrossEntropyLoss(label_smoothing=0.1) 171 | 172 | optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) 173 | 174 | lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, 175 | total_steps=gpc.config.NUM_EPOCHS, 176 | warmup_steps=gpc.config.WARMUP_EPOCHS) 177 | 178 | engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, 179 | optimizer=optimizer, 180 | criterion=criterion, 181 | train_dataloader=train_dataloader, 182 | test_dataloader=test_dataloader) 183 | 184 | logger.info("Engine is built", ranks=[0]) 185 | 186 | timer = MultiTimer() 187 | 188 | trainer = Trainer(engine=engine, logger=logger, timer=timer) 189 | logger.info("Trainer is built", ranks=[0]) 190 | 191 | hook_list = [ 192 | hooks.LogMetricByEpochHook(logger=logger), 193 | hooks.LogMetricByStepHook(), 194 | # hooks.LogTimingByEpochHook(timer=timer, logger=logger), 195 | # hooks.LogMemoryByEpochHook(logger=logger), 196 | hooks.AccuracyHook(accuracy_func=Accuracy()), 197 | hooks.LossHook(), 198 | hooks.ThroughputHook(), 199 | hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) 200 | ] 201 | 202 | logger.info("Train start", ranks=[0]) 203 | trainer.fit(train_dataloader=train_dataloader, 204 | test_dataloader=test_dataloader, 205 | epochs=gpc.config.NUM_EPOCHS, 206 | hooks=hook_list, 207 | display_progress=True, 208 | test_interval=1) 209 | 210 | 211 | if __name__ == '__main__': 212 | train_imagenet() 213 | -------------------------------------------------------------------------------- /bert/common/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | from torch.distributed import all_reduce, get_rank, get_world_size 6 | from tqdm import tqdm 7 | 8 | from zero.common.utils import 
CONFIG, AsyncMemoryMonitor, print_log, get_tflops 9 | 10 | 11 | def _train(epoch, rank, world_size, train_dataloader, model, criterion, optimizer, lr_scheduler, scaler, mem_monitor): 12 | use_optimizer_backward = CONFIG['method'] in ['colossalai'] 13 | use_integrated_backward = CONFIG['method'] in ['deepspeed', 'patrickstar'] 14 | use_integrated_step = CONFIG['method'] in ['deepspeed'] 15 | use_autocast = CONFIG['method'] in ['torch', 'colossalai'] and \ 16 | 'fp16' in CONFIG and CONFIG['fp16'].get('enabled', True) 17 | clip_grad_norm = CONFIG.get('gradient_clipping', 0.) 18 | use_integraded_clip_grad = CONFIG['method'] in ['fairscale'] 19 | use_colossalai_zero_v1 = CONFIG['method'] == 'colossalai' and CONFIG.get('sharded_model_version', 2) == 1 20 | 21 | model.train() 22 | 23 | num_steps = len(train_dataloader) 24 | if 'steps_per_epoch' in CONFIG['hyperparameter'] and CONFIG['hyperparameter']['steps_per_epoch'] < num_steps: 25 | num_steps = CONFIG['hyperparameter']['steps_per_epoch'] 26 | progress = range(num_steps) 27 | 28 | if rank == 0: 29 | progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]") 30 | 31 | train_loss = torch.zeros(()).to(torch.float).to(rank) 32 | used_time = 0. 33 | num_steps = 0 34 | num_samples = torch.zeros(()).to(torch.int).to(rank) 35 | num_tokens = torch.zeros(()).to(torch.int).to(rank) 36 | 37 | data_iter = iter(train_dataloader) 38 | 39 | if mem_monitor is not None: 40 | mem_monitor.start() 41 | 42 | for _ in progress: 43 | fwd_start = time.time() 44 | 45 | optimizer.zero_grad() 46 | 47 | if use_colossalai_zero_v1: 48 | model.zero_grad(set_to_none=True) 49 | 50 | batch = next(data_iter) 51 | 52 | labels = batch.pop('labels') 53 | batch_size = None 54 | batch_tokens = None 55 | if isinstance(labels, torch.Tensor): 56 | labels = labels.to(rank) 57 | batch_size = labels.size(0) 58 | batch_tokens = labels.numel() 59 | else: 60 | for k, v in labels.items(): 61 | labels[k] = v.to(rank) 62 | if batch_size is None: 63 | batch_size = v.size(0) 64 | if batch_tokens is None: 65 | batch_tokens = v.numel() 66 | 67 | for k, v in batch.items(): 68 | batch[k] = v.to(rank) 69 | 70 | if use_autocast: 71 | with torch.cuda.amp.autocast(): 72 | outputs = model(**batch) 73 | else: 74 | outputs = model(**batch) 75 | 76 | loss = criterion(outputs, labels) 77 | train_loss += loss 78 | 79 | fwd_end = time.time() 80 | 81 | bwd_start = time.time() 82 | 83 | optimizer.backward(loss) 84 | optimizer.step() 85 | lr_scheduler.step() 86 | 87 | bwd_end = time.time() 88 | 89 | num_steps += 1 90 | num_samples += batch_size 91 | num_tokens += batch_tokens 92 | 93 | fwd_time = fwd_end - fwd_start 94 | bwd_time = bwd_end - bwd_start 95 | batch_time = fwd_time + bwd_time 96 | used_time += batch_time 97 | 98 | if rank == 0: 99 | progress.set_postfix(loss=loss.item(), 100 | lr=lr_scheduler.get_last_lr()[0], 101 | time_forward=fwd_time, 102 | time_backward=bwd_time, 103 | throughput=batch_size * world_size / (batch_time + 1e-12), 104 | tflops=get_tflops(batch_time, batch_tokens * world_size)) 105 | 106 | peak_mem = None 107 | if mem_monitor is not None: 108 | peak_mem = max(mem_monitor.finish()) 109 | 110 | all_reduce(train_loss) 111 | all_reduce(num_samples) 112 | all_reduce(num_tokens) 113 | 114 | msg = f'[Epoch {epoch} / Train]: Loss = {train_loss.item() / (world_size * num_steps):.3f}' 115 | msg += f' | Throughput = {num_samples.item() / (used_time + 1e-12):.3f} samples/sec' 116 | msg += f' | TFLOPS = {get_tflops(used_time, num_tokens.item()):.3f}' 117 | if peak_mem is not None: 118 | msg += f' | 
Peak memory = {peak_mem / 1024:.3f} GB.' 119 | print_log(msg) 120 | 121 | 122 | def _test(epoch, rank, world_size, test_dataloader, model, criterion, mem_monitor): 123 | use_autocast = CONFIG['method'] in ['torch', 'colossalai'] and \ 124 | 'fp16' in CONFIG and CONFIG['fp16'].get('enabled', True) 125 | evaluation = CONFIG['model']['evaluation'] 126 | 127 | model.eval() 128 | 129 | num_steps = len(test_dataloader) 130 | if 'steps_per_epoch' in CONFIG['hyperparameter'] and CONFIG['hyperparameter']['steps_per_epoch'] < num_steps: 131 | num_steps = CONFIG['hyperparameter']['steps_per_epoch'] 132 | progress = range(num_steps) 133 | if rank == 0: 134 | progress = tqdm(progress, desc=f"[Epoch {epoch} / Test]") 135 | 136 | test_loss = torch.zeros(()).to(torch.float).to(rank) 137 | used_time = 0. 138 | num_steps = 0 139 | num_samples = torch.zeros(()).to(torch.int).to(rank) 140 | num_tokens = torch.zeros(()).to(torch.int).to(rank) 141 | correct = torch.zeros(()).to(torch.int).to(rank) 142 | 143 | data_iter = iter(test_dataloader) 144 | 145 | if mem_monitor is not None: 146 | mem_monitor.start() 147 | 148 | with torch.no_grad(): 149 | for _ in progress: 150 | batch_start = time.time() 151 | 152 | batch = next(data_iter) 153 | 154 | labels = batch.pop('labels') 155 | batch_size = None 156 | batch_tokens = None 157 | if isinstance(labels, torch.Tensor): 158 | labels = labels.to(rank) 159 | batch_size = labels.size(0) 160 | batch_tokens = labels.numel() 161 | else: 162 | for k, v in labels.items(): 163 | labels[k] = v.to(rank) 164 | if batch_size is None: 165 | batch_size = v.size(0) 166 | if batch_tokens is None: 167 | batch_tokens = v.numel() 168 | 169 | for k, v in batch.items(): 170 | batch[k] = v.to(rank) 171 | if use_autocast: 172 | with torch.cuda.amp.autocast(): 173 | outputs = model(**batch) 174 | else: 175 | outputs = model(**batch) 176 | 177 | loss = criterion(outputs, labels) 178 | test_loss += loss 179 | 180 | batch_end = time.time() 181 | 182 | num_steps += 1 183 | num_samples += batch_size 184 | num_tokens += batch_tokens 185 | 186 | batch_time = batch_end - batch_start 187 | used_time += batch_time 188 | 189 | if rank == 0: 190 | metrics = dict(loss=loss.item(), 191 | step_time=batch_time, 192 | throughput=batch_size * world_size / (batch_time + 1e-12), 193 | tflops=get_tflops(batch_time, batch_tokens * world_size)) 194 | if evaluation == 'ppl': 195 | metrics['perplexity'] = math.exp(loss.item()) 196 | elif evaluation == 'acc': 197 | if not isinstance(labels, torch.Tensor): 198 | labels = labels['targets_a'] 199 | batch_correct = torch.sum(labels == torch.argmax(outputs, dim=-1)).item() 200 | metrics['accuracy'] = batch_correct / batch_size 201 | correct += batch_correct 202 | else: 203 | raise ValueError(f'Invalid evaluation method {evaluation}') 204 | progress.set_postfix(**metrics) 205 | 206 | peak_mem = None 207 | if mem_monitor is not None: 208 | peak_mem = max(mem_monitor.finish()) 209 | 210 | all_reduce(test_loss) 211 | reduced_loss = test_loss.item() / (world_size * num_steps) 212 | all_reduce(num_samples) 213 | all_reduce(num_tokens) 214 | if evaluation == 'acc': 215 | all_reduce(correct) 216 | 217 | msg = f'[Epoch {epoch} / Test]: Loss = {reduced_loss:.3f}' 218 | if evaluation == 'ppl': 219 | msg += f' | Perplexity = {math.exp(reduced_loss):.3f}' 220 | else: 221 | msg += f' | Accuracy = {correct.item() * 100 / num_samples.item():.3f} %' 222 | msg += f' | Throughput = {num_samples.item() / (used_time + 1e-12):.3f} samples/sec' 223 | msg += f' | TFLOPS = {get_tflops(used_time, 
num_tokens.item()):.3f}' 224 | if peak_mem is not None: 225 | msg += f' | Peak memory = {peak_mem / 1024:.3f} GB.' 226 | print_log(msg) 227 | 228 | 229 | def train(model, train_data, test_data, criterion, optimizer, scaler, lr_scheduler): 230 | use_pipeline = 'parallel' in CONFIG and 'pipeline' in CONFIG['parallel'] and int(CONFIG['parallel']['pipeline']) > 1 231 | 232 | rank = get_rank() 233 | world_size = get_world_size() 234 | 235 | mem_monitor = None 236 | if CONFIG.get('use_mem_monitor'): 237 | mem_monitor = AsyncMemoryMonitor(rank) 238 | 239 | numel = CONFIG['model']['numel'] 240 | if numel < 1e9: 241 | msg = f'{numel / 1e6:.3f} M' 242 | else: 243 | msg = f'{numel / 1e9:.3f} B' 244 | print_log(f'Model is built (parameter size = {msg}).') 245 | 246 | print_log('Benchmark start.') 247 | 248 | if use_pipeline: 249 | import colossalai.nn as col_nn 250 | from colossalai.engine.schedule import PipelineSchedule 251 | from colossalai.utils import MultiTimer, get_dataloader 252 | from colossalai.logging import get_dist_logger 253 | from colossalai.trainer import Trainer, hooks 254 | 255 | def batch_data_process_func(batch_data): 256 | data = {'input_ids': batch_data['input_ids'], 'token_type_ids': batch_data['token_type_ids'], 'attention_mask': batch_data['attention_mask']} 257 | labels = batch_data['labels'] 258 | return data, labels 259 | 260 | timer = MultiTimer() 261 | schedule = PipelineSchedule(num_microbatches=2, batch_data_process_func=batch_data_process_func) 262 | engine = model 263 | logger = get_dist_logger() 264 | trainer = Trainer(engine=engine, timer=timer, logger=logger, schedule=schedule) 265 | 266 | hook_list = [ 267 | hooks.LossHook(), 268 | hooks.AccuracyHook(col_nn.metric.Accuracy()), 269 | hooks.LogMetricByEpochHook(logger), 270 | hooks.LRSchedulerHook(lr_scheduler, by_epoch=True) 271 | ] 272 | 273 | trainer.fit(train_dataloader=train_data, 274 | epochs=CONFIG['hyperparameter']['num_epochs'], 275 | test_dataloader=test_data, 276 | test_interval=1, 277 | hooks=hook_list, 278 | display_progress=True) 279 | 280 | else: 281 | for epoch in range(CONFIG['hyperparameter']['num_epochs']): 282 | _train(epoch, rank, world_size, train_data, model, criterion, optimizer, lr_scheduler, scaler, mem_monitor) 283 | _test(epoch, rank, world_size, test_data, model, criterion, mem_monitor) 284 | 285 | print_log('Benchmark complete.') 286 | -------------------------------------------------------------------------------- /zero/common/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | from torch.distributed import all_reduce, get_rank, get_world_size 6 | from tqdm import tqdm 7 | 8 | from zero.common.utils import CONFIG, AsyncMemoryMonitor, print_log, get_tflops 9 | 10 | 11 | def _train(epoch, rank, world_size, train_dataloader, model, criterion, optimizer, lr_scheduler, scaler, mem_monitor): 12 | use_optimizer_backward = CONFIG['method'] in ['colossalai'] 13 | use_integrated_backward = CONFIG['method'] in ['deepspeed', 'patrickstar'] 14 | use_integrated_step = CONFIG['method'] in ['deepspeed'] 15 | use_autocast = CONFIG['method'] in ['torch', 'colossalai'] and \ 16 | 'fp16' in CONFIG and CONFIG['fp16'].get('enabled', True) 17 | clip_grad_norm = CONFIG.get('gradient_clipping', 0.) 
18 | use_integraded_clip_grad = CONFIG['method'] in ['fairscale'] 19 | use_colossalai_zero_v1 = CONFIG['method'] == 'colossalai' and CONFIG.get('sharded_model_version', 2) == 1 20 | 21 | model.train() 22 | 23 | num_steps = len(train_dataloader) 24 | if 'steps_per_epoch' in CONFIG['hyperparameter'] and CONFIG['hyperparameter']['steps_per_epoch'] < num_steps: 25 | num_steps = CONFIG['hyperparameter']['steps_per_epoch'] 26 | progress = range(num_steps) 27 | 28 | if rank == 0: 29 | progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]") 30 | 31 | train_loss = torch.zeros(()).to(torch.float).to(rank) 32 | used_time = 0. 33 | num_steps = 0 34 | num_samples = torch.zeros(()).to(torch.int).to(rank) 35 | num_tokens = torch.zeros(()).to(torch.int).to(rank) 36 | 37 | data_iter = iter(train_dataloader) 38 | 39 | if mem_monitor is not None: 40 | mem_monitor.start() 41 | 42 | for _ in progress: 43 | fwd_start = time.time() 44 | 45 | optimizer.zero_grad() 46 | 47 | if use_colossalai_zero_v1: 48 | model.zero_grad(set_to_none=True) 49 | 50 | batch = next(data_iter) 51 | 52 | labels = batch.pop('labels') 53 | batch_size = None 54 | batch_tokens = None 55 | if isinstance(labels, torch.Tensor): 56 | labels = labels.to(rank) 57 | batch_size = labels.size(0) 58 | batch_tokens = labels.numel() 59 | else: 60 | for k, v in labels.items(): 61 | labels[k] = v.to(rank) 62 | if batch_size is None: 63 | batch_size = v.size(0) 64 | if batch_tokens is None: 65 | batch_tokens = v.numel() 66 | 67 | for k, v in batch.items(): 68 | batch[k] = v.to(rank) 69 | 70 | if use_autocast: 71 | with torch.cuda.amp.autocast(): 72 | outputs = model(**batch) 73 | else: 74 | outputs = model(**batch) 75 | 76 | loss = criterion(outputs, labels) 77 | train_loss += loss 78 | 79 | fwd_end = time.time() 80 | 81 | bwd_start = time.time() 82 | 83 | if use_colossalai_zero_v1: 84 | loss.backward() 85 | optimizer.step() 86 | lr_scheduler.step() 87 | elif use_integrated_backward: # deepspeed & patrickstar style 88 | model.backward(loss) 89 | if use_integrated_step: 90 | model.step() # deepspeed style 91 | else: 92 | optimizer.step() # patrickstar style 93 | lr_scheduler.step() 94 | 95 | elif use_optimizer_backward: # colossalai style 96 | optimizer.backward(loss) 97 | if clip_grad_norm > 0: 98 | optimizer.clip_grad_norm(model, clip_grad_norm) 99 | optimizer.step() 100 | lr_scheduler.step() 101 | 102 | elif scaler is not None: # torch & fairscale amp style 103 | scaler.scale(loss).backward() 104 | scaler.unscale_(optimizer) 105 | if clip_grad_norm > 0: 106 | if use_integraded_clip_grad: # fairscale style 107 | model.clip_grad_norm_(clip_grad_norm) 108 | else: # torch style 109 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) 110 | scaler.step(optimizer) 111 | scaler.update() 112 | lr_scheduler.step() 113 | 114 | else: # torch & fairscale normal style 115 | loss.backward() 116 | if clip_grad_norm > 0: 117 | if use_integraded_clip_grad: # fairscale style 118 | model.clip_grad_norm_(clip_grad_norm) 119 | else: # torch style 120 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) 121 | optimizer.step() 122 | lr_scheduler.step() 123 | 124 | bwd_end = time.time() 125 | 126 | num_steps += 1 127 | num_samples += batch_size 128 | num_tokens += batch_tokens 129 | 130 | fwd_time = fwd_end - fwd_start 131 | bwd_time = bwd_end - bwd_start 132 | batch_time = fwd_time + bwd_time 133 | used_time += batch_time 134 | 135 | if rank == 0: 136 | progress.set_postfix(loss=loss.item(), 137 | lr=lr_scheduler.get_last_lr()[0], 138 | 
time_forward=fwd_time, 139 | time_backward=bwd_time, 140 | throughput=batch_size * world_size / (batch_time + 1e-12), 141 | tflops=get_tflops(batch_time, batch_tokens * world_size)) 142 | 143 | peak_mem = None 144 | if mem_monitor is not None: 145 | peak_mem = max(mem_monitor.finish()) 146 | 147 | all_reduce(train_loss) 148 | all_reduce(num_samples) 149 | all_reduce(num_tokens) 150 | 151 | msg = f'[Epoch {epoch} / Train]: Loss = {train_loss.item() / (world_size * num_steps):.3f}' 152 | msg += f' | Throughput = {num_samples.item() / (used_time + 1e-12):.3f} samples/sec' 153 | msg += f' | TFLOPS = {get_tflops(used_time, num_tokens.item()):.3f}' 154 | if peak_mem is not None: 155 | msg += f' | Peak memory = {peak_mem / 1024:.3f} GB.' 156 | print_log(msg) 157 | 158 | 159 | def _test(epoch, rank, world_size, test_dataloader, model, criterion, mem_monitor): 160 | use_autocast = CONFIG['method'] in ['torch', 'colossalai'] and \ 161 | 'fp16' in CONFIG and CONFIG['fp16'].get('enabled', True) 162 | evaluation = CONFIG['model']['evaluation'] 163 | 164 | model.eval() 165 | 166 | num_steps = len(test_dataloader) 167 | if 'steps_per_epoch' in CONFIG['hyperparameter'] and CONFIG['hyperparameter']['steps_per_epoch'] < num_steps: 168 | num_steps = CONFIG['hyperparameter']['steps_per_epoch'] 169 | progress = range(num_steps) 170 | if rank == 0: 171 | progress = tqdm(progress, desc=f"[Epoch {epoch} / Test]") 172 | 173 | test_loss = torch.zeros(()).to(torch.float).to(rank) 174 | used_time = 0. 175 | num_steps = 0 176 | num_samples = torch.zeros(()).to(torch.int).to(rank) 177 | num_tokens = torch.zeros(()).to(torch.int).to(rank) 178 | correct = torch.zeros(()).to(torch.int).to(rank) 179 | 180 | data_iter = iter(test_dataloader) 181 | 182 | if mem_monitor is not None: 183 | mem_monitor.start() 184 | 185 | with torch.no_grad(): 186 | for _ in progress: 187 | batch_start = time.time() 188 | 189 | batch = next(data_iter) 190 | 191 | labels = batch.pop('labels') 192 | batch_size = None 193 | batch_tokens = None 194 | if isinstance(labels, torch.Tensor): 195 | labels = labels.to(rank) 196 | batch_size = labels.size(0) 197 | batch_tokens = labels.numel() 198 | else: 199 | for k, v in labels.items(): 200 | labels[k] = v.to(rank) 201 | if batch_size is None: 202 | batch_size = v.size(0) 203 | if batch_tokens is None: 204 | batch_tokens = v.numel() 205 | 206 | for k, v in batch.items(): 207 | batch[k] = v.to(rank) 208 | if use_autocast: 209 | with torch.cuda.amp.autocast(): 210 | outputs = model(**batch) 211 | else: 212 | outputs = model(**batch) 213 | 214 | loss = criterion(outputs, labels) 215 | test_loss += loss 216 | 217 | batch_end = time.time() 218 | 219 | num_steps += 1 220 | num_samples += batch_size 221 | num_tokens += batch_tokens 222 | 223 | batch_time = batch_end - batch_start 224 | used_time += batch_time 225 | 226 | if rank == 0: 227 | metrics = dict(loss=loss.item(), 228 | step_time=batch_time, 229 | throughput=batch_size * world_size / (batch_time + 1e-12), 230 | tflops=get_tflops(batch_time, batch_tokens * world_size)) 231 | if evaluation == 'ppl': 232 | metrics['perplexity'] = math.exp(loss.item()) 233 | elif evaluation == 'acc': 234 | if not isinstance(labels, torch.Tensor): 235 | labels = labels['targets_a'] 236 | batch_correct = torch.sum(labels == torch.argmax(outputs, dim=-1)).item() 237 | metrics['accuracy'] = batch_correct / batch_size 238 | correct += batch_correct 239 | else: 240 | raise ValueError(f'Invalid evaluation method {evaluation}') 241 | progress.set_postfix(**metrics) 242 | 243 | 
peak_mem = None 244 | if mem_monitor is not None: 245 | peak_mem = max(mem_monitor.finish()) 246 | 247 | all_reduce(test_loss) 248 | reduced_loss = test_loss.item() / (world_size * num_steps) 249 | all_reduce(num_samples) 250 | all_reduce(num_tokens) 251 | if evaluation == 'acc': 252 | all_reduce(correct) 253 | 254 | msg = f'[Epoch {epoch} / Test]: Loss = {reduced_loss:.3f}' 255 | if evaluation == 'ppl': 256 | msg += f' | Perplexity = {math.exp(reduced_loss):.3f}' 257 | else: 258 | msg += f' | Accuracy = {correct.item() * 100 / num_samples.item():.3f} %' 259 | msg += f' | Throughput = {num_samples.item() / (used_time + 1e-12):.3f} samples/sec' 260 | msg += f' | TFLOPS = {get_tflops(used_time, num_tokens.item()):.3f}' 261 | if peak_mem is not None: 262 | msg += f' | Peak memory = {peak_mem / 1024:.3f} GB.' 263 | print_log(msg) 264 | 265 | 266 | def train(model, train_data, test_data, criterion, optimizer, scaler, lr_scheduler): 267 | rank = get_rank() 268 | world_size = get_world_size() 269 | 270 | mem_monitor = None 271 | if CONFIG.get('use_mem_monitor'): 272 | mem_monitor = AsyncMemoryMonitor(rank) 273 | 274 | numel = CONFIG['model']['numel'] 275 | if numel < 1e9: 276 | msg = f'{numel / 1e6:.3f} M' 277 | else: 278 | msg = f'{numel / 1e9:.3f} B' 279 | print_log(f'Model is built (parameter size = {msg}).') 280 | 281 | print_log('Benchmark start.') 282 | 283 | for epoch in range(CONFIG['hyperparameter']['num_epochs']): 284 | _train(epoch, rank, world_size, train_data, model, criterion, optimizer, lr_scheduler, scaler, mem_monitor) 285 | _test(epoch, rank, world_size, test_data, model, criterion, mem_monitor) 286 | 287 | print_log('Benchmark complete.') 288 | -------------------------------------------------------------------------------- /zero/common/vit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.distributed import get_rank, get_world_size 5 | from transformers import ViTConfig, ViTForImageClassification 6 | 7 | from zero.common.utils import CONFIG, ModelFromHF 8 | 9 | _vit_b = dict( 10 | img_size=224, 11 | patch_size=16, 12 | hidden_size=768, 13 | intermediate_size=3072, 14 | num_heads=12, 15 | depth=12, 16 | dropout=0.1, 17 | num_labels=1000, 18 | numel=86567656, 19 | checkpoint=False, 20 | evaluation='acc', 21 | ) 22 | 23 | _vit_h = dict( 24 | img_size=224, 25 | patch_size=16, 26 | hidden_size=1280, 27 | intermediate_size=5120, 28 | num_heads=16, 29 | depth=32, 30 | dropout=0.1, 31 | num_labels=1000, 32 | numel=632199400, 33 | checkpoint=True, 34 | evaluation='acc', 35 | ) 36 | 37 | _vit_g = dict( 38 | img_size=224, 39 | patch_size=14, 40 | hidden_size=1664, 41 | intermediate_size=8192, 42 | num_heads=16, 43 | depth=48, 44 | dropout=0.1, 45 | num_labels=1000, 46 | numel=1844440680, 47 | checkpoint=True, 48 | evaluation='acc', 49 | ) 50 | 51 | _vit_10b = dict( 52 | img_size=224, 53 | patch_size=16, 54 | hidden_size=4096, 55 | intermediate_size=16384, 56 | num_heads=16, 57 | depth=50, 58 | dropout=0.1, 59 | num_labels=1000, 60 | numel=10077058024, 61 | checkpoint=True, 62 | evaluation='acc', 63 | ) 64 | 65 | _vit_configurations = dict( 66 | vit=_vit_b, 67 | vit_b=_vit_b, 68 | vit_h=_vit_h, 69 | vit_g=_vit_g, 70 | vit_10b=_vit_10b, 71 | ) 72 | 73 | _default_hyperparameters = dict( 74 | batch_size=4, 75 | mixup_alpha=0.2, 76 | learning_rate=3e-3, 77 | weight_decay=0.3, 78 | num_epochs=2, 79 | warmup_epochs=1, 80 | steps_per_epoch=100, 81 | ) 82 | 83 | 84 | def build_data(): 85 | import glob 86 | 87 | 
import numpy as np 88 | import nvidia.dali.fn as fn 89 | import nvidia.dali.tfrecord as tfrec 90 | import nvidia.dali.types as types 91 | from nvidia.dali.pipeline import Pipeline 92 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy 93 | 94 | class DaliDataloader(DALIClassificationIterator): 95 | 96 | def __init__(self, 97 | tfrec_filenames, 98 | tfrec_idx_filenames, 99 | shard_id=0, 100 | num_shards=1, 101 | batch_size=128, 102 | num_threads=4, 103 | resize=256, 104 | crop=224, 105 | prefetch=2, 106 | training=True, 107 | gpu_aug=False, 108 | cuda=True, 109 | mixup_alpha=0.0): 110 | self.mixup_alpha = mixup_alpha 111 | self.training = training 112 | pipe = Pipeline(batch_size=batch_size, 113 | num_threads=num_threads, 114 | device_id=torch.cuda.current_device() if cuda else None, 115 | seed=1024) 116 | with pipe: 117 | inputs = fn.readers.tfrecord(path=tfrec_filenames, 118 | index_path=tfrec_idx_filenames, 119 | random_shuffle=training, 120 | shard_id=shard_id, 121 | num_shards=num_shards, 122 | initial_fill=10000, 123 | read_ahead=True, 124 | prefetch_queue_depth=prefetch, 125 | name='Reader', 126 | features={ 127 | 'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""), 128 | 'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1), 129 | }) 130 | images = inputs["image/encoded"] 131 | 132 | if training: 133 | images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB) 134 | images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu') 135 | flip_lr = fn.random.coin_flip(probability=0.5) 136 | else: 137 | # decode jpeg and resize 138 | images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB) 139 | images = fn.resize(images, 140 | device='gpu' if gpu_aug else 'cpu', 141 | resize_x=resize, 142 | resize_y=resize, 143 | dtype=types.FLOAT, 144 | interp_type=types.INTERP_TRIANGULAR) 145 | flip_lr = False 146 | 147 | # center crop and normalise 148 | images = fn.crop_mirror_normalize(images, 149 | dtype=types.FLOAT, 150 | crop=(crop, crop), 151 | mean=[127.5], 152 | std=[127.5], 153 | mirror=flip_lr) 154 | label = inputs["image/class/label"] - 1 # 0-999 155 | # LSG: element_extract will raise exception, let's flatten outside 156 | # label = fn.element_extract(label, element_map=0) # Flatten 157 | if cuda: # transfer data to gpu 158 | pipe.set_outputs(images.gpu(), label.gpu()) 159 | else: 160 | pipe.set_outputs(images, label) 161 | 162 | pipe.build() 163 | last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL 164 | super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy) 165 | 166 | def __iter__(self): 167 | # if not reset (after an epoch), reset; if just initialize, ignore 168 | if self._counter >= self._size or self._size < 0: 169 | self.reset() 170 | return self 171 | 172 | def __next__(self): 173 | data = super().__next__() 174 | img, label = data[0]['data'], data[0]['label'] 175 | img = (img - 127.5) / 127.5 176 | label = label.squeeze() 177 | if self.mixup_alpha > 0.0: 178 | if self.training: 179 | lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) 180 | idx = torch.randperm(img.size(0)).to(img.device) 181 | img = lam * img + (1 - lam) * img[idx, :] 182 | label_a, label_b = label, label[idx] 183 | lam = torch.tensor(lam, device=img.device, dtype=img.dtype) 184 | label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam} 185 | else: 186 | label = { 187 | 
'targets_a': label, 188 | 'targets_b': label, 189 | 'lam': torch.ones((), device=img.device, dtype=img.dtype) 190 | } 191 | return {'pixel_values': img, 'labels': label} 192 | return {'pixel_values': img, 'labels': label} 193 | 194 | rank = get_rank() 195 | world_size = get_world_size() 196 | 197 | train_pat = os.path.join(CONFIG['dataset'], 'train/*') 198 | train_idx_pat = os.path.join(CONFIG['dataset'], 'idx_files/train/*') 199 | val_pat = os.path.join(CONFIG['dataset'], 'validation/*') 200 | val_idx_pat = os.path.join(CONFIG['dataset'], 'idx_files/validation/*') 201 | 202 | train_data = DaliDataloader(sorted(glob.glob(train_pat)), 203 | sorted(glob.glob(train_idx_pat)), 204 | batch_size=CONFIG['hyperparameter']['batch_size'], 205 | shard_id=rank, 206 | num_shards=world_size, 207 | gpu_aug=True, 208 | cuda=True, 209 | mixup_alpha=CONFIG['hyperparameter']['mixup_alpha']) 210 | 211 | test_data = DaliDataloader(sorted(glob.glob(val_pat)), 212 | sorted(glob.glob(val_idx_pat)), 213 | batch_size=CONFIG['hyperparameter']['batch_size'], 214 | shard_id=rank, 215 | num_shards=world_size, 216 | training=False, 217 | gpu_aug=False, 218 | cuda=True, 219 | mixup_alpha=CONFIG['hyperparameter']['mixup_alpha']) 220 | 221 | return train_data, test_data 222 | 223 | 224 | def build_model(): 225 | vit_config = ViTConfig(image_size=CONFIG['model']['img_size'], 226 | patch_size=CONFIG['model']['patch_size'], 227 | hidden_size=CONFIG['model']['hidden_size'], 228 | intermediate_size=CONFIG['model']['intermediate_size'], 229 | num_hidden_layers=CONFIG['model']['depth'], 230 | hidden_dropout_prob=CONFIG['model']['dropout'], 231 | num_attention_heads=CONFIG['model']['num_heads'], 232 | num_labels=CONFIG['model']['num_labels']) 233 | model = ModelFromHF(vit_config, ViTForImageClassification) 234 | 235 | return model 236 | 237 | 238 | class MixupLoss(torch.nn.Module): 239 | 240 | def __init__(self): 241 | super().__init__() 242 | self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1) 243 | 244 | def forward(self, inputs, targets): 245 | targets_a, targets_b, lam = targets['targets_a'], targets['targets_b'], targets['lam'] 246 | return lam * self.loss_fn(inputs, targets_a) + (1 - lam) * self.loss_fn(inputs, targets_b) 247 | 248 | 249 | def build_loss(): 250 | return MixupLoss() 251 | 252 | 253 | def build_optimizer(params): 254 | optimizer = torch.optim.Adam(params, 255 | lr=CONFIG['hyperparameter']['learning_rate'], 256 | weight_decay=CONFIG['hyperparameter']['weight_decay']) 257 | 258 | return optimizer 259 | 260 | 261 | def build_scheduler(epoch_steps, optimizer): 262 | from transformers.optimization import get_cosine_schedule_with_warmup 263 | 264 | max_steps = epoch_steps * CONFIG['hyperparameter']['num_epochs'] 265 | warmup_steps = epoch_steps * CONFIG['hyperparameter']['warmup_epochs'] 266 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, 267 | num_warmup_steps=warmup_steps, 268 | num_training_steps=max_steps) 269 | 270 | return lr_scheduler 271 | 272 | 273 | def vit_builder(): 274 | model_type = CONFIG['model']['type'] 275 | if model_type in _vit_configurations: 276 | for k, v in _vit_configurations[model_type].items(): 277 | if k not in CONFIG['model']: 278 | CONFIG['model'][k] = v 279 | 280 | if 'hyperparameter' in CONFIG: 281 | for k, v in _default_hyperparameters.items(): 282 | if k not in CONFIG['hyperparameter']: 283 | CONFIG['hyperparameter'][k] = v 284 | else: 285 | CONFIG['hyperparameter'] = _default_hyperparameters 286 | 287 | CONFIG['dataset'] = os.environ['DATA'] 288 | 289 | return 
build_data, build_model, build_loss, build_optimizer, build_scheduler 290 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /zero/common/gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.distributed import get_world_size 5 | from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer 6 | 7 | from zero.common.utils import CONFIG, ModelFromHF, get_model_size 8 | 9 | _gpt2_small = dict( 10 | seq_length=1024, 11 | vocab_size=50257, 12 | hidden_size=768, 13 | num_heads=12, 14 | depth=12, 15 | numel=124439808, 16 | checkpoint=True, 17 | evaluation='ppl', 18 | ) 19 | 20 | _gpt2_xl = dict( 21 | seq_length=1024, 22 | vocab_size=50257, 23 | hidden_size=1600, 24 | num_heads=25, 25 | depth=48, 26 | numel=1557611200, 27 | checkpoint=True, 28 | evaluation='ppl', 29 | ) 30 | 31 | _gpt2_10b = dict( 32 | seq_length=1024, 33 | vocab_size=50257, 34 | hidden_size=4096, 35 | num_heads=16, 36 | depth=50, 37 | numel=10279047168, 38 | checkpoint=True, 39 | evaluation='ppl', 40 | ) 41 | 42 | _gpt2_4b = dict( 43 | seq_length=1024, 44 | vocab_size=50257, 45 | hidden_size=2304, 46 | num_heads=16, 47 | depth=64, 48 | checkpoint=True, 49 | evaluation='ppl', 50 | ) 51 | 52 | _gpt2_2b = dict( 53 | seq_length=1024, 54 | vocab_size=50257, 55 | hidden_size=2048, 56 | num_heads=16, 57 | depth=40, 58 | checkpoint=True, 59 | evaluation='ppl', 60 | ) 61 | 62 | _gpt2_3b = dict( 63 | seq_length=1024, 64 | vocab_size=50257, 65 | hidden_size=2560, 66 | num_heads=40, 67 | depth=24, 68 | checkpoint=True, 69 | evaluation='ppl', 70 | ) 71 | 72 | _gpt2_6b = dict( 73 | seq_length=1024, 74 | vocab_size=50257, 75 | hidden_size=4096, 76 | num_heads=16, 77 | depth=30, 78 | checkpoint=True, 79 | evaluation='ppl', 80 | ) 81 | 82 | _gpt2_8b = dict( 83 | seq_length=1024, 84 | vocab_size=50257, 85 | hidden_size=4096, 86 | num_heads=16, 87 | depth=40, 88 | checkpoint=True, 89 | evaluation='ppl', 90 | ) 91 | 92 | _gpt2_12b = dict( 93 | seq_length=1024, 94 | vocab_size=50257, 95 | hidden_size=4096, 96 | num_heads=16, 97 | depth=60, 98 | checkpoint=True, 99 | evaluation='ppl', 100 | ) 101 | 102 | _gpt2_15b = dict( 103 | seq_length=1024, 104 | vocab_size=50257, 105 | hidden_size=4096, 106 | num_heads=16, 107 | depth=78, 108 | checkpoint=True, 109 | evaluation='ppl', 110 | ) 111 | 112 | _gpt2_18b = dict( 113 | seq_length=1024, 114 | 
vocab_size=50257, 115 | hidden_size=4096, 116 | num_heads=16, 117 | depth=90, 118 | checkpoint=True, 119 | evaluation='ppl', 120 | ) 121 | 122 | _gpt2_20b = dict( 123 | seq_length=1024, 124 | vocab_size=50257, 125 | hidden_size=8192, 126 | num_heads=16, 127 | depth=25, 128 | checkpoint=True, 129 | evaluation='ppl', 130 | ) 131 | 132 | _gpt2_24b = dict( 133 | seq_length=1024, 134 | vocab_size=50257, 135 | hidden_size=8192, 136 | num_heads=16, 137 | depth=30, 138 | checkpoint=True, 139 | evaluation='ppl', 140 | ) 141 | 142 | _gpt2_28b = dict( 143 | seq_length=1024, 144 | vocab_size=50257, 145 | hidden_size=8192, 146 | num_heads=16, 147 | depth=35, 148 | checkpoint=True, 149 | evaluation='ppl', 150 | ) 151 | 152 | _gpt2_32b = dict( 153 | seq_length=1024, 154 | vocab_size=50257, 155 | hidden_size=8192, 156 | num_heads=16, 157 | depth=40, 158 | checkpoint=True, 159 | evaluation='ppl', 160 | ) 161 | 162 | _gpt2_36b = dict( 163 | seq_length=1024, 164 | vocab_size=50257, 165 | hidden_size=8192, 166 | num_heads=16, 167 | depth=45, 168 | checkpoint=True, 169 | evaluation='ppl', 170 | ) 171 | 172 | _gpt2_40b = dict( 173 | seq_length=1024, 174 | vocab_size=50257, 175 | hidden_size=8192, 176 | num_heads=16, 177 | depth=50, 178 | checkpoint=True, 179 | evaluation='ppl', 180 | ) 181 | 182 | _gpt2_configurations = dict( 183 | gpt2=_gpt2_small, 184 | gpt2_small=_gpt2_small, 185 | gpt2_xl=_gpt2_xl, 186 | gpt2_10b=_gpt2_10b, 187 | gpt2_4b=_gpt2_4b, 188 | gpt2_6b=_gpt2_6b, 189 | gpt2_8b=_gpt2_8b, 190 | gpt2_2b=_gpt2_2b, 191 | gpt2_3b=_gpt2_3b, 192 | gpt2_12b=_gpt2_12b, 193 | gpt2_15b=_gpt2_15b, 194 | gpt2_18b=_gpt2_18b, 195 | gpt2_20b=_gpt2_20b, 196 | gpt2_40b=_gpt2_40b, 197 | gpt2_24b=_gpt2_24b, 198 | gpt2_28b=_gpt2_28b, 199 | gpt2_32b=_gpt2_32b, 200 | gpt2_36b=_gpt2_36b 201 | ) 202 | 203 | _default_hyperparameters = dict( 204 | tokenize_mode='concat', 205 | batch_size=4, 206 | learning_rate=0.00015, 207 | weight_decay=1e-2, 208 | num_epochs=2, 209 | warmup_epochs=1, 210 | steps_per_epoch=100, 211 | ) 212 | 213 | 214 | def build_data(): 215 | import copy 216 | import random 217 | from functools import partial 218 | from itertools import chain 219 | 220 | import numpy as np 221 | from datasets import load_from_disk, set_progress_bar_enabled 222 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 223 | from transformers import default_data_collator 224 | 225 | world_size = get_world_size() 226 | 227 | if CONFIG['hyperparameter'].get('synthetic'): 228 | 229 | class SyntheticDataset(Dataset): 230 | 231 | def __init__(self, vocab_size, seq_length, size) -> None: 232 | super().__init__() 233 | self.size = size 234 | self.vocab_size = vocab_size 235 | self.seq_length = seq_length 236 | 237 | def __len__(self): 238 | return self.size 239 | 240 | def __getitem__(self, _): 241 | input_ids = torch.randint(self.vocab_size, (self.seq_length, )) 242 | attention_mask = torch.ones((self.seq_length, ), dtype=torch.long) 243 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': input_ids.clone()} 244 | 245 | vocab_size = CONFIG['model']['vocab_size'] 246 | seq_len = CONFIG['model']['seq_length'] 247 | train_size = CONFIG['hyperparameter']['batch_size'] * world_size * CONFIG['hyperparameter'].get( 248 | 'steps_per_epoch', 100) 249 | test_size = CONFIG['hyperparameter']['batch_size'] * world_size * CONFIG['hyperparameter'].get( 250 | 'steps_per_epoch', 10) 251 | tokenized_dataset = { 252 | 'train': SyntheticDataset(vocab_size, seq_len, train_size), 253 | 'validation': 
SyntheticDataset(vocab_size, seq_len, test_size) 254 | } 255 | 256 | else: 257 | set_progress_bar_enabled(False) 258 | dataset = load_from_disk(CONFIG['dataset']) 259 | tokenizer = GPT2Tokenizer(vocab_file=CONFIG['tokenizer'] + '/vocab.json', 260 | merges_file=CONFIG['tokenizer'] + '/merges.txt') 261 | 262 | def tokenize(examples, mode='concat'): 263 | assert mode in ['concat', 'pad'] 264 | seq_len = CONFIG['model']['seq_length'] 265 | if mode == 'concat': 266 | examples = tokenizer(examples['text']) 267 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 268 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 269 | if total_length >= seq_len: 270 | total_length = (total_length // seq_len) * seq_len 271 | 272 | result = { 273 | k: [t[i:i + seq_len] for i in range(0, total_length, seq_len)] 274 | for k, t in concatenated_examples.items() 275 | } 276 | else: 277 | tokenizer.pad_token = tokenizer.unk_token 278 | result = tokenizer(examples, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') 279 | 280 | result["labels"] = copy.deepcopy(result["input_ids"]) 281 | 282 | return result 283 | 284 | tokenized_dataset = dataset.map(partial(tokenize, mode=CONFIG['hyperparameter']['tokenize_mode']), 285 | batched=True, 286 | num_proc=16, 287 | load_from_cache_file=False, 288 | keep_in_memory=True, 289 | remove_columns='text') 290 | 291 | CONFIG['model']['vocab_size'] = len(tokenizer) 292 | 293 | def seed_worker(_): 294 | worker_seed = 1024 295 | np.random.seed(worker_seed) 296 | torch.manual_seed(worker_seed) 297 | random.seed(worker_seed) 298 | 299 | train_sampler = DistributedSampler(tokenized_dataset['train'], shuffle=True) if world_size > 1 else None 300 | train_data = DataLoader(tokenized_dataset['train'], 301 | shuffle=(train_sampler is None), 302 | sampler=train_sampler, 303 | drop_last=True, 304 | collate_fn=default_data_collator, 305 | worker_init_fn=seed_worker, 306 | batch_size=CONFIG['hyperparameter']['batch_size'], 307 | num_workers=4, 308 | pin_memory=True) 309 | test_sampler = DistributedSampler(tokenized_dataset['validation'], shuffle=False) if world_size > 1 else None 310 | test_data = DataLoader(tokenized_dataset['validation'], 311 | sampler=test_sampler, 312 | collate_fn=default_data_collator, 313 | worker_init_fn=seed_worker, 314 | batch_size=CONFIG['hyperparameter']['batch_size'], 315 | num_workers=4, 316 | pin_memory=True) 317 | 318 | return train_data, test_data 319 | 320 | 321 | def build_model(): 322 | model_cfg = CONFIG['model'] 323 | gpt2_cfg = GPT2Config(vocab_size=model_cfg['vocab_size'], 324 | n_positions=model_cfg['seq_length'], 325 | n_embd=model_cfg['hidden_size'], 326 | n_layer=model_cfg['depth'], 327 | n_head=model_cfg['num_heads'], 328 | use_cache=not CONFIG['model'].get('checkpoint', False)) 329 | 330 | model = ModelFromHF(gpt2_cfg, GPT2LMHeadModel) 331 | 332 | return model 333 | 334 | 335 | class GPTLMLoss(torch.nn.Module): 336 | 337 | def __init__(self): 338 | super().__init__() 339 | self.loss = torch.nn.CrossEntropyLoss() 340 | 341 | def forward(self, logits, labels): 342 | shift_logits = logits[..., :-1, :].contiguous() 343 | shift_labels = labels[..., 1:].contiguous() 344 | # Flatten the tokens 345 | return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 346 | 347 | 348 | def build_loss(): 349 | return GPTLMLoss() 350 | 351 | 352 | def build_optimizer(params): 353 | optimizer = torch.optim.AdamW(params, 354 | lr=CONFIG['hyperparameter']['learning_rate'], 355 | 
weight_decay=CONFIG['hyperparameter']['weight_decay']) 356 | return optimizer 357 | 358 | 359 | def build_scheduler(epoch_steps, optimizer): 360 | from transformers.optimization import get_linear_schedule_with_warmup 361 | 362 | max_steps = epoch_steps * CONFIG['hyperparameter']['num_epochs'] 363 | warmup_steps = epoch_steps * CONFIG['hyperparameter']['warmup_epochs'] 364 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, 365 | num_warmup_steps=warmup_steps, 366 | num_training_steps=max_steps) 367 | 368 | return lr_scheduler 369 | 370 | 371 | def gpt2_builder(): 372 | model_type = CONFIG['model']['type'] 373 | if model_type in _gpt2_configurations: 374 | for k, v in _gpt2_configurations[model_type].items(): 375 | if k not in CONFIG['model']: 376 | CONFIG['model'][k] = v 377 | 378 | if 'hyperparameter' in CONFIG: 379 | for k, v in _default_hyperparameters.items(): 380 | if k not in CONFIG['hyperparameter']: 381 | CONFIG['hyperparameter'][k] = v 382 | else: 383 | CONFIG['hyperparameter'] = _default_hyperparameters 384 | 385 | CONFIG['dataset'] = os.environ['DATA'] 386 | CONFIG['tokenizer'] = os.environ['TOKENIZER'] 387 | 388 | return build_data, build_model, build_loss, build_optimizer, build_scheduler 389 | -------------------------------------------------------------------------------- /bert/colossalai_utils/model_zoo/colo_bert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from huggingface modeling_bert.py. Change the necessary part to use Colossolai. 3 | """ 4 | import torch 5 | import math 6 | 7 | from transformers import (PreTrainedModel, BertConfig, load_tf_weights_in_bert, 8 | apply_chunking_to_forward 9 | ) 10 | from transformers.activations import ACT2FN 11 | 12 | from transformers.modeling_outputs import MaskedLMOutput as HFMaskedLMOutput 13 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions as HFBaseModelOutputWithPoolingAndCrossAttentions 14 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions as HFBaseModelOutputWithPastAndCrossAttentions 15 | 16 | from torch import nn 17 | from packaging import version 18 | 19 | from colossalai import nn as col_nn 20 | from colossalai.nn.layer.utils import divide 21 | from colossalai.core import global_context as gpc 22 | from colossalai.logging import get_dist_logger 23 | from colossalai.context.parallel_mode import ParallelMode 24 | from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper 25 | from colossalai.utils import get_current_device 26 | from colossalai.builder.pipeline import partition_uniform 27 | 28 | class BertSelfOutput(nn.Module): 29 | def __init__(self, config): 30 | super().__init__() 31 | self.dense = col_nn.Linear(config.hidden_size, config.hidden_size) 32 | self.dropout = col_nn.Dropout(config.hidden_dropout_prob) 33 | self.LayerNorm = col_nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 34 | 35 | def forward(self, hidden_states, input_tensor): 36 | hidden_states = self.dense(hidden_states) 37 | hidden_states = self.dropout(hidden_states) 38 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 39 | return hidden_states 40 | 41 | class BertSelfAttention(nn.Module): 42 | def __init__(self, config, position_embedding_type=None): 43 | super().__init__() 44 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 45 | raise ValueError( 46 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of 
attention " 47 | f"heads ({config.num_attention_heads})" 48 | ) 49 | 50 | self.num_attention_heads = config.num_attention_heads 51 | self.attention_head_size = divide(config.hidden_size, config.num_attention_heads) 52 | 53 | self.query_key_value = col_nn.Linear(config.hidden_size, self.num_attention_heads * self.attention_head_size * 3) 54 | 55 | self.dropout = col_nn.Dropout(config.attention_probs_dropout_prob) 56 | self.position_embedding_type = position_embedding_type or getattr( 57 | config, "position_embedding_type", "absolute" 58 | ) 59 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 60 | self.max_position_embeddings = config.max_position_embeddings 61 | self.distance_embedding = col_nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 62 | 63 | self.is_decoder = config.is_decoder 64 | 65 | def forward( 66 | self, 67 | hidden_states, 68 | attention_mask=None, 69 | head_mask=None, 70 | encoder_hidden_states=None, 71 | encoder_attention_mask=None, 72 | past_key_value=None, 73 | output_attentions=False, 74 | ): 75 | qkv = self.query_key_value(hidden_states) 76 | all_head_size = qkv.shape[-1] // 3 77 | num_attention_heads = divide(all_head_size, self.attention_head_size) 78 | new_qkv_shape = qkv.shape[:-1] + \ 79 | (num_attention_heads, 3 * self.attention_head_size) 80 | qkv = qkv.view(new_qkv_shape) 81 | qkv = qkv.permute((0, 2, 1, 3)) 82 | ###print("BertSelfAttention:qkv:", qkv.shape) 83 | q, k, v = torch.chunk(qkv, 3, dim=-1) 84 | 85 | # If this is instantiated as a cross-attention module, the keys 86 | # and values come from an encoder; the attention mask needs to be 87 | # such that the encoder's padding tokens are not attended to. 88 | is_cross_attention = encoder_hidden_states is not None 89 | 90 | query_layer = q 91 | key_layer = k 92 | value_layer = v 93 | 94 | if self.is_decoder: 95 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 96 | # Further calls to cross_attention layer can then reuse all cross-attention 97 | # key/value_states (first "if" case) 98 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 99 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 100 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 101 | # if encoder bi-directional self-attention `past_key_value` is always `None` 102 | past_key_value = (key_layer, value_layer) 103 | 104 | # Take the dot product between "query" and "key" to get the raw attention scores. 
105 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 106 | 107 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 108 | seq_length = hidden_states.size()[1] 109 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 110 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 111 | distance = position_ids_l - position_ids_r 112 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 113 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 114 | 115 | if self.position_embedding_type == "relative_key": 116 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 117 | attention_scores = attention_scores + relative_position_scores 118 | elif self.position_embedding_type == "relative_key_query": 119 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 120 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 121 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 122 | 123 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 124 | ###print("BertSelfAttention:attention_scores:", attention_scores.shape) 125 | ###print("BertSelfAttention:attention_mask:", attention_mask.shape) 126 | if attention_mask is not None: 127 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 128 | attention_scores = attention_scores + attention_mask 129 | 130 | # Normalize the attention scores to probabilities. 131 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 132 | 133 | # This is actually dropping out entire tokens to attend to, which might 134 | # seem a bit unusual, but is taken from the original Transformer paper. 
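        # Editorial note: this dropout uses config.attention_probs_dropout_prob and acts on the
        # softmax-normalised attention weights, so zeroing an entry removes an entire key
        # position for that query rather than a single hidden activation.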
135 | attention_probs = self.dropout(attention_probs) 136 | 137 | # Mask heads if we want to 138 | if head_mask is not None: 139 | attention_probs = attention_probs * head_mask 140 | 141 | context_layer = torch.matmul(attention_probs, value_layer) 142 | 143 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 144 | new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,) 145 | context_layer = context_layer.view(new_context_layer_shape) 146 | 147 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 148 | 149 | if self.is_decoder: 150 | outputs = outputs + (past_key_value,) 151 | return outputs 152 | 153 | class BertAttention(nn.Module): 154 | def __init__(self, config, position_embedding_type=None): 155 | super().__init__() 156 | self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) 157 | self.output = BertSelfOutput(config) 158 | 159 | def forward( 160 | self, 161 | hidden_states, 162 | attention_mask=None, 163 | head_mask=None, 164 | encoder_hidden_states=None, 165 | encoder_attention_mask=None, 166 | past_key_value=None, 167 | output_attentions=False, 168 | ): 169 | self_outputs = self.self( 170 | hidden_states, 171 | attention_mask, 172 | head_mask, 173 | encoder_hidden_states, 174 | encoder_attention_mask, 175 | past_key_value, 176 | output_attentions, 177 | ) 178 | attention_output = self.output(self_outputs[0], hidden_states) 179 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 180 | ###print("BertAttention:attention_output:", attention_output.shape) 181 | return outputs 182 | 183 | class BertIntermediate(nn.Module): 184 | def __init__(self, config): 185 | super().__init__() 186 | self.dense = col_nn.Linear(config.hidden_size, config.intermediate_size) 187 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 188 | 189 | def forward(self, hidden_states): 190 | hidden_states = self.dense(hidden_states) 191 | hidden_states = self.intermediate_act_fn(hidden_states) 192 | return hidden_states 193 | 194 | class BertOutput(nn.Module): 195 | def __init__(self, config): 196 | super().__init__() 197 | self.dense = col_nn.Linear(config.intermediate_size, config.hidden_size) 198 | self.dropout = col_nn.Dropout(config.hidden_dropout_prob) 199 | self.LayerNorm = col_nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 200 | 201 | def forward(self, hidden_states, input_tensor): 202 | hidden_states = self.dense(hidden_states) 203 | hidden_states = self.dropout(hidden_states) 204 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 205 | return hidden_states 206 | 207 | class BertLayer(nn.Module): 208 | def __init__(self, config): 209 | super().__init__() 210 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 211 | self.seq_len_dim = 1 212 | self.attention = BertAttention(config) 213 | self.is_decoder = config.is_decoder 214 | self.add_cross_attention = config.add_cross_attention 215 | if self.add_cross_attention: 216 | if not self.is_decoder: 217 | raise ValueError(f"{self} should be used as a decoder model if cross attention is added") 218 | self.crossattention = BertAttention(config, position_embedding_type="absolute") 219 | self.intermediate = BertIntermediate(config) 220 | self.output = BertOutput(config) 221 | 222 | def forward( 223 | self, 224 | hidden_states, 225 | attention_mask=None, 226 | head_mask=None, 227 | encoder_hidden_states=None, 228 | encoder_attention_mask=None, 229 | past_key_value=None, 230 | output_attentions=False, 
231 | ): 232 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 233 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 234 | self_attention_outputs = self.attention( 235 | hidden_states, 236 | attention_mask, 237 | head_mask, 238 | output_attentions=output_attentions, 239 | past_key_value=self_attn_past_key_value, 240 | ) 241 | attention_output = self_attention_outputs[0] 242 | 243 | # if decoder, the last output is tuple of self-attn cache 244 | if self.is_decoder: 245 | outputs = self_attention_outputs[1:-1] 246 | present_key_value = self_attention_outputs[-1] 247 | else: 248 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 249 | 250 | cross_attn_present_key_value = None 251 | if self.is_decoder and encoder_hidden_states is not None: 252 | if not hasattr(self, "crossattention"): 253 | raise ValueError( 254 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 255 | ) 256 | 257 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 258 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 259 | cross_attention_outputs = self.crossattention( 260 | attention_output, 261 | attention_mask, 262 | head_mask, 263 | encoder_hidden_states, 264 | encoder_attention_mask, 265 | cross_attn_past_key_value, 266 | output_attentions, 267 | ) 268 | attention_output = cross_attention_outputs[0] 269 | outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 270 | 271 | # add cross-attn cache to positions 3,4 of present_key_value tuple 272 | cross_attn_present_key_value = cross_attention_outputs[-1] 273 | present_key_value = present_key_value + cross_attn_present_key_value 274 | 275 | layer_output = apply_chunking_to_forward( 276 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 277 | ) 278 | outputs = (layer_output,) + outputs 279 | 280 | # if decoder, return the attn key/values as the last output 281 | if self.is_decoder: 282 | outputs = outputs + (present_key_value,) 283 | 284 | return outputs 285 | 286 | def feed_forward_chunk(self, attention_output): 287 | intermediate_output = self.intermediate(attention_output) 288 | layer_output = self.output(intermediate_output, attention_output) 289 | return layer_output 290 | 291 | class BertEmbeddings(nn.Module): 292 | """Construct the embeddings from word, position and token_type embeddings.""" 293 | 294 | def __init__(self, config): 295 | super().__init__() 296 | self.word_embeddings = col_nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 297 | self.token_type_embeddings = col_nn.Embedding(config.type_vocab_size, config.hidden_size) 298 | self.position_embeddings = col_nn.Embedding(config.max_position_embeddings, config.hidden_size) 299 | 300 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 301 | # any TensorFlow checkpoint file 302 | self.LayerNorm = col_nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 303 | self.dropout = col_nn.Dropout(config.hidden_dropout_prob) 304 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 305 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 306 | self.register_buffer("position_ids", 
torch.arange(config.max_position_embeddings).expand((1, -1))) 307 | if version.parse(torch.__version__) > version.parse("1.6.0"): 308 | self.register_buffer( 309 | "token_type_ids", 310 | torch.zeros(self.position_ids.size(), dtype=torch.long), 311 | persistent=False, 312 | ) 313 | 314 | def forward( 315 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 316 | ): 317 | if input_ids is not None: 318 | input_shape = input_ids.size() 319 | else: 320 | input_shape = inputs_embeds.size()[:-1] 321 | 322 | seq_length = input_shape[1] 323 | 324 | if position_ids is None: 325 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 326 | 327 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs 328 | # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves 329 | # issue #5664 330 | if token_type_ids is None: 331 | if hasattr(self, "token_type_ids"): 332 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 333 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 334 | token_type_ids = buffered_token_type_ids_expanded 335 | else: 336 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 337 | 338 | if inputs_embeds is None: 339 | inputs_embeds = self.word_embeddings(input_ids) 340 | ###print("BertEmbeddings:word_embeddings:", inputs_embeds.shape) 341 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 342 | ###print("BertEmbeddings:token_type_embeddings:", token_type_embeddings.shape) 343 | 344 | embeddings = inputs_embeds + token_type_embeddings 345 | if self.position_embedding_type == "absolute": 346 | position_embeddings = self.position_embeddings(position_ids) 347 | ###print("BertEmbeddings:position_embeddings:", position_embeddings.shape) 348 | embeddings += position_embeddings 349 | embeddings = self.LayerNorm(embeddings) 350 | ###print("BertEmbeddings:LayerNorm:", embeddings.shape) 351 | embeddings = self.dropout(embeddings) 352 | ###print("BertEmbeddings:dropout:", embeddings.shape) 353 | return embeddings 354 | 355 | class BertPooler(nn.Module): 356 | def __init__(self, config): 357 | super().__init__() 358 | self.dense = col_nn.Linear(config.hidden_size, config.hidden_size) 359 | self.activation = nn.Tanh() 360 | 361 | def forward(self, hidden_states): 362 | # We "pool" the model by simply taking the hidden state corresponding 363 | # to the first token. 
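        # Editorial note: hidden_states has shape (batch, seq_len, hidden_size); indexing
        # position 0 selects the first ([CLS]) token, which the col_nn.Linear below projects
        # and tanh squashes to produce the pooled output returned as pooler_output.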
364 | first_token_tensor = hidden_states[:, 0] 365 | pooled_output = self.dense(first_token_tensor) 366 | pooled_output = self.activation(pooled_output) 367 | return pooled_output 368 | 369 | class ColoBertMaskedLMLoss(torch.nn.Module): 370 | def __init__(self): 371 | super().__init__() 372 | self.loss = col_nn.CrossEntropyLoss() 373 | 374 | def forward(self, logits, labels): 375 | shift_logits = logits[..., :-1, :].contiguous() 376 | shift_labels = labels[..., 1:].contiguous() 377 | # Flatten the tokens 378 | return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 379 | 380 | class BertEncoder(nn.Module): 381 | def __init__(self, config): 382 | super().__init__() 383 | self.config = config 384 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 385 | self.gradient_checkpointing = False 386 | 387 | def forward( 388 | self, 389 | hidden_states, 390 | attention_mask=None, 391 | head_mask=None, 392 | encoder_hidden_states=None, 393 | encoder_attention_mask=None, 394 | past_key_values=None, 395 | use_cache=None, 396 | output_attentions=False, 397 | output_hidden_states=False, 398 | return_dict=True, 399 | ): 400 | all_hidden_states = () if output_hidden_states else None 401 | all_self_attentions = () if output_attentions else None 402 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 403 | 404 | next_decoder_cache = () if use_cache else None 405 | 406 | for i, layer_module in enumerate(self.layer): 407 | if output_hidden_states: 408 | all_hidden_states = all_hidden_states + (hidden_states,) 409 | 410 | layer_head_mask = head_mask[i] if head_mask is not None else None 411 | past_key_value = past_key_values[i] if past_key_values is not None else None 412 | 413 | if self.gradient_checkpointing and self.training: 414 | 415 | if use_cache: 416 | use_cache = False 417 | 418 | def create_custom_forward(module): 419 | def custom_forward(*inputs): 420 | return module(*inputs, past_key_value, output_attentions) 421 | 422 | return custom_forward 423 | 424 | layer_outputs = torch.utils.checkpoint.checkpoint( 425 | create_custom_forward(layer_module), 426 | hidden_states, 427 | attention_mask, 428 | layer_head_mask, 429 | encoder_hidden_states, 430 | encoder_attention_mask, 431 | ) 432 | else: 433 | layer_outputs = layer_module( 434 | hidden_states, 435 | attention_mask, 436 | layer_head_mask, 437 | encoder_hidden_states, 438 | encoder_attention_mask, 439 | past_key_value, 440 | output_attentions, 441 | ) 442 | 443 | hidden_states = layer_outputs[0] 444 | if use_cache: 445 | next_decoder_cache += (layer_outputs[-1],) 446 | if output_attentions: 447 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 448 | if self.config.add_cross_attention: 449 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 450 | 451 | if output_hidden_states: 452 | all_hidden_states = all_hidden_states + (hidden_states,) 453 | 454 | if not return_dict: 455 | return tuple( 456 | v 457 | for v in [ 458 | hidden_states, 459 | next_decoder_cache, 460 | all_hidden_states, 461 | all_self_attentions, 462 | all_cross_attentions, 463 | ] 464 | if v is not None 465 | ) 466 | return HFBaseModelOutputWithPastAndCrossAttentions( 467 | last_hidden_state=hidden_states, 468 | past_key_values=next_decoder_cache, 469 | hidden_states=all_hidden_states, 470 | attentions=all_self_attentions, 471 | cross_attentions=all_cross_attentions, 472 | ) 473 | 474 | class BertPreTrainedModel(PreTrainedModel): 475 | """ 476 
| An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 477 | models. 478 | """ 479 | config_class = BertConfig 480 | load_tf_weights = load_tf_weights_in_bert 481 | base_model_prefix = "bert" 482 | supports_gradient_checkpointing = True 483 | _keys_to_ignore_on_load_missing = [r"position_ids"] 484 | 485 | def _init_weights(self, module): 486 | """Initialize the weights""" 487 | if isinstance(module, col_nn.Linear): 488 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 489 | if module.bias is not None: 490 | module.bias.data.zero_() 491 | elif isinstance(module, col_nn.Embedding): 492 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 493 | 494 | def _set_gradient_checkpointing(self, module, value=False): 495 | if isinstance(module, BertEncoder): 496 | module.gradient_checkpointing = value 497 | 498 | class BertModel(BertPreTrainedModel): 499 | """ 500 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 501 | cross-attention is added between the self-attention layers, following the architecture described in [Attention is 502 | all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, 503 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 504 | To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set 505 | to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and 506 | `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 507 | """ 508 | 509 | def __init__(self, config, add_pooling_layer=True): 510 | super().__init__(config) 511 | self.config = config 512 | 513 | self.embeddings = BertEmbeddings(config) 514 | self.encoder = BertEncoder(config) 515 | 516 | self.pooler = BertPooler(config) if add_pooling_layer else None 517 | 518 | def get_input_embeddings(self): 519 | return self.embeddings.word_embeddings 520 | 521 | def set_input_embeddings(self, value): 522 | self.embeddings.word_embeddings = value 523 | 524 | def forward( 525 | self, 526 | input_ids=None, 527 | attention_mask=None, 528 | token_type_ids=None, 529 | position_ids=None, 530 | head_mask=None, 531 | inputs_embeds=None, 532 | encoder_hidden_states=None, 533 | encoder_attention_mask=None, 534 | past_key_values=None, 535 | use_cache=None, 536 | output_attentions=None, 537 | output_hidden_states=None, 538 | return_dict=None, 539 | ): 540 | r""" 541 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 542 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 543 | the model is configured as a decoder. 544 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 545 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 546 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 547 | - 1 for tokens that are **not masked**, 548 | - 0 for tokens that are **masked**. 
549 |         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
550 |             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
551 |             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
552 |             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
553 |             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
554 |         use_cache (`bool`, *optional*):
555 |             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
556 |             `past_key_values`).
557 |         """
558 |         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
559 |         output_hidden_states = (
560 |             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
561 |         )
562 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
563 |
564 |         if self.config.is_decoder:
565 |             use_cache = use_cache if use_cache is not None else self.config.use_cache
566 |         else:
567 |             use_cache = False
568 |
569 |         if input_ids is not None and inputs_embeds is not None:
570 |             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
571 |         elif input_ids is not None:
572 |             input_shape = input_ids.size()
573 |         elif inputs_embeds is not None:
574 |             input_shape = inputs_embeds.size()[:-1]
575 |         else:
576 |             raise ValueError("You have to specify either input_ids or inputs_embeds")
577 |
578 |         batch_size, seq_length = input_shape
579 |         device = input_ids.device if input_ids is not None else inputs_embeds.device
580 |
581 |         # past_key_values_length
582 |         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
583 |
584 |         if attention_mask is None:
585 |             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
586 |
587 |         if token_type_ids is None:
588 |             if hasattr(self.embeddings, "token_type_ids"):
589 |                 buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
590 |                 buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
591 |                 token_type_ids = buffered_token_type_ids_expanded
592 |             else:
593 |                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
594 |
595 |         # Prepare head mask if needed
596 |         # 1.0 in head_mask indicates we keep the head
597 |         # attention_probs has shape bsz x n_heads x N x N
598 |         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
599 |         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
600 |         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
601 |
602 |         embedding_output = self.embeddings(
603 |             input_ids=input_ids,
604 |             position_ids=position_ids,
605 |             token_type_ids=token_type_ids,
606 |             inputs_embeds=inputs_embeds,
607 |             past_key_values_length=past_key_values_length,
608 |         )
609 |
610 |         # We create a 3D attention mask from a 2D tensor mask.
611 |         # Sizes are [batch_size, 1, 1, to_seq_length]
612 |         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
613 |         # Adapted from huggingface
614 |         if attention_mask is not None:
615 |             batch_size = input_shape[0]  # works even when only inputs_embeds is provided
616 |             extended_attention_mask = attention_mask.view(batch_size, -1)
617 |             extended_attention_mask = col_nn.partition_batch(extended_attention_mask)
618 |             extended_attention_mask = extended_attention_mask.unsqueeze(1).unsqueeze(2)
619 |             extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
620 |             extended_attention_mask = extended_attention_mask.to(dtype=embedding_output.dtype)
621 |
622 |         # If a 2D or 3D attention mask is provided for the cross-attention
623 |         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
624 |         if self.config.is_decoder and encoder_hidden_states is not None:
625 |             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
626 |             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
627 |             if encoder_attention_mask is None:
628 |                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
629 |             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
630 |         else:
631 |             encoder_extended_attention_mask = None
632 |
633 |         ###print("BertModel:embeddings:", embedding_output.shape)
634 |         encoder_outputs = self.encoder(
635 |             embedding_output,
636 |             attention_mask=extended_attention_mask,
637 |             head_mask=head_mask,
638 |             encoder_hidden_states=encoder_hidden_states,
639 |             encoder_attention_mask=encoder_extended_attention_mask,
640 |             past_key_values=past_key_values,
641 |             use_cache=use_cache,
642 |             output_attentions=output_attentions,
643 |             output_hidden_states=output_hidden_states,
644 |             return_dict=return_dict,
645 |         )
646 |         sequence_output = encoder_outputs[0]
647 |         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
648 |
649 |         if not return_dict:
650 |             return (sequence_output, pooled_output) + encoder_outputs[1:]
651 |
652 |         return HFBaseModelOutputWithPoolingAndCrossAttentions(
653 |             last_hidden_state=sequence_output,
654 |             pooler_output=pooled_output,
655 |             past_key_values=encoder_outputs.past_key_values,
656 |             hidden_states=encoder_outputs.hidden_states,
657 |             attentions=encoder_outputs.attentions,
658 |             cross_attentions=encoder_outputs.cross_attentions,
659 |         )
660 |
661 | class BertPredictionHeadTransform(nn.Module):
662 |     def __init__(self, config):
663 |         super().__init__()
664 |         self.dense = col_nn.Linear(config.hidden_size, config.hidden_size, gather_output=True)
665 |         self.transform_act_fn = ACT2FN[config.hidden_act]
666 |         self.LayerNorm = col_nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
667 |
668 |     def forward(self, hidden_states):
669 |         ###print("BertPredictionHeadTransform:input:", hidden_states.shape)
670 |         hidden_states = self.dense(hidden_states)
671 |         ###print("BertPredictionHeadTransform:after dense:", hidden_states.shape)
672 |         hidden_states = self.transform_act_fn(hidden_states)
673 |         ###print("BertPredictionHeadTransform:after act2fn:", hidden_states.shape)
674 |         hidden_states = self.LayerNorm(hidden_states)
675 |         ###print("BertPredictionHeadTransform:output:", hidden_states.shape)
676 |         return hidden_states
677 |
678 | class BertLMPredictionHead(nn.Module):
679 |     def __init__(self, config):
680 |         super().__init__()
681 |         self.transform = BertPredictionHeadTransform(config)
682 |
683 |         # The output weights are the same as the input embeddings, but there is
684 |         # an output-only bias for each token.
685 |         self.decoder = col_nn.Classifier(config.hidden_size, config.vocab_size, bias=True)
686 |
687 |     def forward(self, hidden_states):
688 |         ###print("BertLMPredictionHead:input:", hidden_states.shape)
689 |         hidden_states = self.transform(hidden_states)
690 |         ###print("BertLMPredictionHead:after transform:", hidden_states.shape)
691 |         hidden_states = self.decoder(hidden_states)
692 |         ###print("BertLMPredictionHead:output:", hidden_states.shape)
693 |         return hidden_states
694 |
695 | class BertOnlyMLMHead(nn.Module):
696 |     def __init__(self, config):
697 |         super().__init__()
698 |         self.predictions = BertLMPredictionHead(config)
699 |
700 |     def forward(self, sequence_output):
701 |         prediction_scores = self.predictions(sequence_output)
702 |         return prediction_scores
703 |
704 | class ColoBertForMaskedLM(BertPreTrainedModel):
705 |
706 |     _keys_to_ignore_on_load_unexpected = [r"pooler"]
707 |     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
708 |
709 |     def __init__(self, config):
710 |         super().__init__(config)
711 |
712 |         self.bert = BertModel(config, add_pooling_layer=False)
713 |         self.cls = BertOnlyMLMHead(config)
714 |
715 |     def get_output_embeddings(self):
716 |         return self.cls.predictions.decoder
717 |
718 |     def set_output_embeddings(self, new_embeddings):
719 |         self.cls.predictions.decoder = new_embeddings
720 |
721 |     def forward(
722 |         self,
723 |         input_ids=None,
724 |         attention_mask=None,
725 |         token_type_ids=None,
726 |         position_ids=None,
727 |         head_mask=None,
728 |         inputs_embeds=None,
729 |         encoder_hidden_states=None,
730 |         encoder_attention_mask=None,
731 |         labels=None,
732 |         output_attentions=None,
733 |         output_hidden_states=None,
734 |         return_dict=None,
735 |     ):
736 |         r"""
737 |         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
738 |             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
739 |             config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the
740 |             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
741 |         """
742 |
743 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
744 |
745 |         outputs = self.bert(
746 |             input_ids,
747 |             attention_mask=attention_mask,
748 |             token_type_ids=token_type_ids,
749 |             position_ids=position_ids,
750 |             head_mask=head_mask,
751 |             inputs_embeds=inputs_embeds,
752 |             encoder_hidden_states=encoder_hidden_states,
753 |             encoder_attention_mask=encoder_attention_mask,
754 |             output_attentions=output_attentions,
755 |             output_hidden_states=output_hidden_states,
756 |             return_dict=return_dict,
757 |         )
758 |
759 |         sequence_output = outputs[0]
760 |         prediction_scores = self.cls(sequence_output)
761 |         ###print("BertForMaskedLM:")
762 |         ###print(sequence_output.shape)
763 |         ###print(prediction_scores.shape)
764 |
765 |         masked_lm_loss = None  # labels are not consumed here; the loss is computed externally (see ColoBertMaskedLMLoss)
766 |
767 |         if not return_dict:
768 |             output = (prediction_scores,) + outputs[2:]
769 |             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
770 |
771 |         return HFMaskedLMOutput(
772 |             loss=masked_lm_loss,
773 |             logits=prediction_scores,
774 |             hidden_states=outputs.hidden_states,
775 |             attentions=outputs.attentions,
776 |         )
777 |
778 |     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
779 |         input_shape = input_ids.shape
780 |         effective_batch_size = input_shape[0]
781 |
782 |         # add a dummy token
783 |         if self.config.pad_token_id is None:
784 |             raise ValueError("The PAD token should be defined for generation")
785 |
786 |         attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
787 |         dummy_token = torch.full(
788 |             (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
789 |         )
790 |         input_ids = torch.cat([input_ids, dummy_token], dim=1)
791 |
792 |         return {"input_ids": input_ids, "attention_mask": attention_mask}
793 |
794 | '''
795 | Colossalai PipelineBert
796 | '''
797 | class PipelineBertForMaskedLM(nn.Module):
798 |     def __init__(self, config, first: bool = False, last: bool = False):
799 |         super().__init__()
800 |         self.first = first
801 |         self.last = last
802 |
803 |         if self.first:
804 |             self.embeddings = BertEmbeddings(config)
805 |
806 |         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
807 |
808 |         if self.last:  # For pipeline, only last pipe does Head.
809 |             self.cls = BertOnlyMLMHead(config)
810 |
811 |     def forward(
812 |         self,
813 |         x=None,
814 |         input_ids=None,
815 |         attention_mask=None,
816 |         token_type_ids=None,
817 |         position_ids=None,
818 |         head_mask=None,
819 |         inputs_embeds=None,
820 |         encoder_hidden_states=None,
821 |         encoder_attention_mask=None,
822 |         labels=None,
823 |         output_attentions=None,
824 |         output_hidden_states=None,
825 |         return_dict=None,
826 |     ):
827 |         if self.first:
828 |             x = self.embeddings(
829 |                 input_ids=input_ids,
830 |                 position_ids=position_ids,
831 |                 token_type_ids=token_type_ids,
832 |                 inputs_embeds=inputs_embeds
833 |             )
834 |
835 |         # We create a 3D attention mask from a 2D tensor mask.
836 |         # Sizes are [batch_size, 1, 1, to_seq_length]
837 |         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
838 |         # Adapted from huggingface
839 |         if attention_mask is not None:
840 |             if self.first:
841 |                 batch_size = input_ids.shape[0]
842 |             else:
843 |                 batch_size = x.shape[0]
844 |             attention_mask = attention_mask.view(batch_size, -1)
845 |             attention_mask = col_nn.partition_batch(attention_mask)
846 |             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
847 |             attention_mask = attention_mask.to(dtype=x.dtype)  # fp16 compatibility
848 |             attention_mask = (1.0 - attention_mask) * -10000.0
849 |
850 |         for i, layer_module in enumerate(self.layer):
851 |             layer_outputs = layer_module(x, attention_mask)
852 |             x = layer_outputs[0]
853 |
854 |         if self.last:  # For pipeline, only last pipe does Head.
855 |             x = self.cls(x)
856 |
857 |         return x
858 |
859 | def create_colo_bert_pipeline_model(config):
860 |     num_chunks = 1
861 |     layer_partitions = None
862 |     depth = config.num_hidden_layers
863 |
864 |     logger = get_dist_logger()
865 |     pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
866 |     pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
867 |     rank = gpc.get_global_rank()
868 |     wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
869 |     parts = partition_uniform(depth, pipeline_size,
870 |                               num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions
871 |     models = []
872 |     for start, end in parts:
873 |         first = start == 0
874 |         last = end == depth
875 |         config.num_hidden_layers = end - start
876 |         chunk = PipelineBertForMaskedLM(config, first, last).to(get_current_device())
877 |
878 |         models.append(chunk)
879 |         logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}')
880 |     if len(models) == 1:
881 |         model = models[0]
882 |     else:
883 |         model = nn.ModuleList(models)
884 |     return model
885 |
--------------------------------------------------------------------------------
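Note: `ColoBertForMaskedLM` above never fills in a loss itself; training code is expected to pair it with the exported `ColoBertMaskedLMLoss`. The sketch below shows that wiring in a minimal, non-distributed form. It assumes the Colossal-AI parallel context has already been initialized by the launcher (the `col_nn` layers require it), the `from model_zoo import ...` path is illustrative (adjust to how the package is on your `sys.path`), and the config values are placeholders rather than the benchmark's own `bert_config_*.json` settings.

import torch
from transformers import BertConfig

# Hypothetical import path; the benchmark's __init__.py exports these names.
from model_zoo import ColoBertForMaskedLM, ColoBertMaskedLMLoss

# Small illustrative config (placeholder values).
config = BertConfig(vocab_size=30522, hidden_size=256, num_hidden_layers=4,
                    num_attention_heads=4, intermediate_size=1024)
model = ColoBertForMaskedLM(config)
criterion = ColoBertMaskedLMLoss()

batch_size, seq_len = 8, 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
labels = input_ids.clone()

# The model returns an HF-style output object; note that ColoBertMaskedLMLoss
# shifts logits/labels by one position (see its forward above), so the labels
# are expected to follow that convention.
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
loss = criterion(logits, labels)
loss.backward()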
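For pipeline parallelism, `create_colo_bert_pipeline_model` gives each pipeline rank a `PipelineBertForMaskedLM` chunk holding a contiguous slice of the encoder layers; only the rank whose slice starts at layer 0 builds `BertEmbeddings`, and only the rank whose slice ends at the last layer builds `BertOnlyMLMHead`. The snippet below is a simplified stand-in for that partitioning arithmetic, assuming `partition_uniform` splits the depth as evenly as possible (the real Colossal-AI helper may differ in detail):

def uniform_parts(depth, pipeline_size):
    # Split `depth` layers into `pipeline_size` contiguous (start, end) slices.
    base, rem = divmod(depth, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        end = start + base + (1 if rank < rem else 0)
        parts.append((start, end))
        start = end
    return parts

# 12 encoder layers over 4 pipeline stages -> [(0, 3), (3, 6), (6, 9), (9, 12)]
for rank, (start, end) in enumerate(uniform_parts(12, 4)):
    first = start == 0   # this rank also builds the embeddings
    last = end == 12     # this rank also builds the MLM head
    print(f"rank {rank}: layers [{start}, {end}), first={first}, last={last}")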