├── test ├── __init__.py ├── chunk │ ├── __init__.py │ ├── test_block.py │ ├── test_fetcher.py │ ├── fetcher_utils.py │ ├── test_group.py │ ├── test_scheduler.py │ └── test_chunk.py ├── parameter │ ├── __init__.py │ ├── hf_models │ │ ├── __init__.py │ │ ├── diffuser.py │ │ ├── gpt.py │ │ └── albert.py │ ├── test_timm.py │ └── test_torchvision.py ├── test_core_functions.sh ├── src │ └── test_move.py ├── utils │ ├── iterator.py │ ├── resnet.py │ ├── registry.py │ ├── mlp.py │ ├── small.py │ ├── __init__.py │ ├── opt.py │ └── gpt.py ├── tools │ └── test_registry.py ├── ctx │ └── test_meta_ctx.py ├── search │ ├── test_optimal.py │ ├── test_mini_waste.py │ └── test_simple.py ├── exp_tracer │ ├── test_fx_order.py │ └── test_td_order.py ├── kernels │ ├── test_attn.py │ └── test_ln.py ├── tracer │ ├── test_tf_order.py │ ├── test_cuda_profiler.py │ └── test_op_cache.py ├── hook │ └── test_hook.py ├── wrapper │ ├── test_optimizer.py │ ├── test_prefetch.py │ ├── test_module.py │ └── test_amp.py └── test_models.py ├── example ├── __init__.py ├── common │ ├── __init__.py │ ├── zero2_config.json │ ├── opt.py │ ├── zero3_config.json │ ├── fsdp.py │ ├── utils.py │ ├── elx.py │ ├── ds.py │ └── models.py ├── fine-tune │ ├── run_ddp.sh │ ├── run_elixir.sh │ ├── readme.md │ ├── func_module.py │ ├── data_module.py │ ├── torch_ddp.py │ └── elixir_mini.py ├── activation │ ├── profile_activation.sh │ └── activation.py ├── benchmark │ ├── scripts │ │ ├── rm_lock.sh │ │ ├── fsdp.sh │ │ ├── elixir.sh │ │ ├── deepspeed.sh │ │ ├── fsdp_benchmark.sh │ │ ├── run_script.sh │ │ └── benchmark.sh │ ├── run_elixir.sh │ ├── fetch_hf_settings.py │ ├── readme.md │ └── run_gemini.sh └── search_example.py ├── elixir ├── tracer │ ├── __init__.py │ ├── memory_tracer │ │ ├── __init__.py │ │ ├── output_shape.py │ │ ├── cuda_profiler.py │ │ ├── memory_tensor.py │ │ └── op_cache.py │ ├── param_tracer │ │ ├── __init__.py │ │ ├── fx_order.py │ │ └── td_order.py │ ├── ops.py │ └── utils.py ├── __init__.py ├── hook │ ├── __init__.py │ ├── storage.py │ ├── functions.py │ └── parameter.py ├── wrapper │ └── __init__.py ├── chunk │ ├── scheduler │ │ ├── __init__.py │ │ ├── fifo.py │ │ ├── base.py │ │ └── prefetch.py │ ├── __init__.py │ └── core │ │ ├── __init__.py │ │ ├── states.py │ │ └── memory_pool.py ├── search │ ├── __init__.py │ ├── result.py │ ├── utils.py │ ├── simple.py │ └── simulator.py ├── kernels │ ├── layernorm.py │ ├── __init__.py │ ├── attn_wrapper.py │ ├── attention.py │ ├── gpt_attention.py │ └── opt_attention.py ├── ctx │ ├── meta_ctx.py │ └── __init__.py ├── cuda.py ├── meta_registrations.py ├── utils.py └── parameter │ └── __init__.py ├── requirements.txt ├── .isort.cfg ├── .style.yapf ├── profile ├── get_bandwidth.sh ├── get_optim_v.sh ├── profile_bandwidth.py └── profile_optimizer.py ├── .pre-commit-config.yaml ├── setup.py ├── src └── simulator.cpp ├── .gitignore └── README.md /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/chunk/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elixir/tracer/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/parameter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /elixir/__init__.py: -------------------------------------------------------------------------------- 1 | import elixir.cuda 2 | import elixir.utils 3 | -------------------------------------------------------------------------------- /test/test_core_functions.sh: -------------------------------------------------------------------------------- 1 | pytest ctx tracer hook kernels chunk search wrapper 2 | -------------------------------------------------------------------------------- /example/fine-tune/run_ddp.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 --master_port=25912 torch_ddp.py 2 | -------------------------------------------------------------------------------- /elixir/hook/__init__.py: -------------------------------------------------------------------------------- 1 | from .parameter import HookParam 2 | from .storage import BufferStore 3 | -------------------------------------------------------------------------------- /example/fine-tune/run_elixir.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 --master_port=25912 elixir_mini.py 2 | -------------------------------------------------------------------------------- /elixir/wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import ElixirModule 2 | from .optimizer import ElixirOptimizer 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | pytest 3 | sortedcontainers 4 | timm 5 | torch>=1.13.1 6 | transformers 7 | einops 8 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length = 120 3 | multi_line_output=3 4 | include_trailing_comma = true 5 | ignore_comments = true 6 | -------------------------------------------------------------------------------- /elixir/tracer/memory_tracer/__init__.py: -------------------------------------------------------------------------------- 1 | from .cuda_profiler import cuda_memory_profiling 2 | from .memory_tensor import MTensor 3 | -------------------------------------------------------------------------------- /example/activation/profile_activation.sh: -------------------------------------------------------------------------------- 1 | for name_model in "gpt2-15b"; do 2 | python activation.py --model_name ${name_model} 3 | done 4 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | spaces_before_comment = 4 4 | split_before_logical_operator = true 5 | column_limit = 120 6 | 
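A note on `elixir/tracer/memory_tracer/__init__.py` above: `cuda_memory_profiling` takes a model, its input dict, and a step function, and returns a dict of measured CUDA memory figures (including `activation_occ`). A minimal sketch of that calling convention, adapted from `test/tracer/test_cuda_profiler.py` further down and using the test registry's `resnet` entry as a stand-in:

```python
from test.utils import TEST_MODELS, to_cuda

from elixir.tracer.memory_tracer import cuda_memory_profiling


def one_step(model, inp):
    # one forward/backward pass; the profiler measures the CUDA memory this step needs
    loss = model(**inp)
    loss.backward()
    return loss


model_fn, data_fn = TEST_MODELS.get('resnet')
model = model_fn().cuda()
data = to_cuda(data_fn())

one_step(model, data)    # warm-up pass to materialize gradients, mirroring the unit test
profiling_dict = cuda_memory_profiling(model, data, one_step)
print(profiling_dict['activation_occ'])
```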
-------------------------------------------------------------------------------- /elixir/chunk/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ChunkScheduler 2 | from .fifo import FIFOScheduler 3 | from .prefetch import PrefetchScheduler 4 | -------------------------------------------------------------------------------- /example/benchmark/scripts/rm_lock.sh: -------------------------------------------------------------------------------- 1 | rm -rf ~/.cache/torch_extensions/py39_cu116/cpu_adam/lock 2 | rm -rf ~/.cache/torch_extensions/py39_cu116/fused_adam/lock 3 | -------------------------------------------------------------------------------- /elixir/tracer/param_tracer/__init__.py: -------------------------------------------------------------------------------- 1 | from .fx_order import generate_fx_order 2 | from .td_order import generate_td_order 3 | from .tf_order import generate_tf_order 4 | -------------------------------------------------------------------------------- /elixir/chunk/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState 2 | from .fetcher import ChunkFetcher 3 | -------------------------------------------------------------------------------- /profile/get_bandwidth.sh: -------------------------------------------------------------------------------- 1 | for num_gpu in 1 2 4; do 2 | echo "${num_gpu} is used" 3 | torchrun --nproc_per_node=${num_gpu} --master_port=29515 profile_bandwidth.py 4 | done 5 | -------------------------------------------------------------------------------- /elixir/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .mini_waste import minimum_waste_search 2 | from .optimal import optimal_search 3 | from .result import ChunkPlan, SearchResult 4 | from .simple import simple_search 5 | -------------------------------------------------------------------------------- /elixir/chunk/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk import Chunk 2 | from .group import ChunkGroup 3 | from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock 4 | from .states import TensorState 5 | -------------------------------------------------------------------------------- /example/benchmark/scripts/fsdp.sh: -------------------------------------------------------------------------------- 1 | conda activate ds-torch-1.13 2 | 3 | export T_MODEL=${T_MODEL:-"opt-1b"} 4 | 5 | export N_GPU=${N_GPU:-1} 6 | export N_BS=${N_BS:-16} 7 | export N_STEP=${N_STEP:-6} 8 | 9 | export T_DP="fsdp" 10 | bash ./run_script.sh 11 | -------------------------------------------------------------------------------- /profile/get_optim_v.sh: -------------------------------------------------------------------------------- 1 | for num_gpu in 1 2 4; do 2 | echo "${num_gpu} is used" 3 | wc=`cat /proc/cpuinfo | grep "processor"| wc -l` 4 | let TNUM=wc/${num_gpu} 5 | env OMP_NUM_THREADS=${TNUM} torchrun --nproc_per_node=${num_gpu} --master_port=29515 profile_optimizer.py 6 | done 7 | -------------------------------------------------------------------------------- /test/src/test_move.py: -------------------------------------------------------------------------------- 1 | from elixir.c_utils import move_count 2 | 3 | 4 
| def test_move_count(): 5 | steps = [[0], [1, 2], [3], [3], [1, 2], [0]] 6 | size = 2 7 | assert move_count(steps, size) == 12 8 | 9 | 10 | if __name__ == '__main__': 11 | test_move_count() 12 | -------------------------------------------------------------------------------- /example/benchmark/scripts/elixir.sh: -------------------------------------------------------------------------------- 1 | # source /opt/conda/etc/profile.d/conda.sh 2 | conda activate adv-torch-1.13 3 | 4 | export T_MODEL=${T_MODEL:-"gpt2-20b"} 5 | 6 | export N_GPU=${N_GPU:-4} 7 | export N_BS=${N_BS:-16} 8 | export N_STEP=${N_STEP:-6} 9 | 10 | export T_DP="elixir" 11 | bash ./run_script.sh 12 | -------------------------------------------------------------------------------- /elixir/kernels/layernorm.py: -------------------------------------------------------------------------------- 1 | from apex.normalization.fused_layer_norm import fused_layer_norm, fused_layer_norm_affine 2 | 3 | 4 | def ln_func(input, normalized_shape, weight=None, bias=None, eps=1e-05): 5 | if weight is None: 6 | assert bias is None 7 | return fused_layer_norm(input, normalized_shape, eps) 8 | else: 9 | assert weight is not None and bias is not None 10 | return fused_layer_norm_affine(input, weight, bias, normalized_shape, eps) 11 | -------------------------------------------------------------------------------- /example/fine-tune/readme.md: -------------------------------------------------------------------------------- 1 | # Fine-tune Example 2 | 3 | ## Command 4 | 5 | ```bash 6 | bash run_elixir.sh (or run_ddp.sh) 7 | ``` 8 | 9 | ## Results 10 | 11 | 12 | | config | accuracy | f1 | 13 | | :----: | :----: | :----: | 14 | | ddp-fp32-1GPUs | 0.8382 | 0.8889 | 15 | | elixir-fp32-1GPUs | 0.8407 | 0.8904 | 16 | | ddp-fp32-2GPUs | 0.8333 | 0.8859 | 17 | | elixir-fp32-2GPUs | 0.8333 | 0.8855 | 18 | | ddp-fp32-4GPUs | 0.8358 | 0.8874 | 19 | | elixir-fp32-4GPUs | 0.8382 | 0.8889 | 20 | -------------------------------------------------------------------------------- /example/benchmark/run_elixir.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | export GPUNUM=${GPUNUM:-4} 4 | export BATCH_SIZE=${BATCH_SIZE:-32} 5 | export MODEL_NAME=${MODEL_TYPE:-"gpt2-400m"} 6 | export TRAIN_STEP=${TRAIN_STEP:-6} 7 | # export PYTHONPATH=$PWD:$PYTHONPATH 8 | 9 | mkdir -p elixir_logs 10 | 11 | torchrun --standalone --nproc_per_node=${GPUNUM} ./elixir_demo.py \ 12 | --model_name=${MODEL_NAME} \ 13 | --batch_size=${BATCH_SIZE} \ 14 | --train_step=${TRAIN_STEP} \ 15 | 2>&1 | tee ./elixir_logs/${MODEL_TYPE}_gpu_${GPUNUM}_bs_${BATCH_SIZE}.log 16 | -------------------------------------------------------------------------------- /elixir/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | fused_torch_functions = {F.layer_norm: F.layer_norm} 5 | 6 | 7 | def register_fused_layer_norm(): 8 | try: 9 | from .layernorm import ln_func 10 | fused_torch_functions[F.layer_norm] = ln_func 11 | print('Register fused layer norm successfully from apex.') 12 | except: 13 | print('Cannot import fused layer norm, please install apex from source.') 14 | pass 15 | 16 | 17 | register_fused_layer_norm() 18 | -------------------------------------------------------------------------------- /test/utils/iterator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 
4 | class TestIterator(ABC): 5 | 6 | def __init__(self, length=10) -> None: 7 | self.length = length 8 | self.step = 0 9 | 10 | @abstractmethod 11 | def generate(self): 12 | pass 13 | 14 | def __next__(self): 15 | if self.step < self.length: 16 | self.step += 1 17 | return self.generate() 18 | else: 19 | raise StopIteration 20 | 21 | def __len__(self): 22 | return self.length 23 | -------------------------------------------------------------------------------- /test/tools/test_registry.py: -------------------------------------------------------------------------------- 1 | from test.utils import to_cuda 2 | 3 | import pytest 4 | import torch 5 | 6 | 7 | def test_registry(): 8 | from test.utils.registry import TEST_MODELS 9 | for name, model_tuple in TEST_MODELS: 10 | torch.cuda.synchronize() 11 | print(f'model `{name}` is in testing') 12 | 13 | model_fn, data_fn = model_tuple 14 | model = model_fn().cuda() 15 | data = to_cuda(data_fn()) 16 | loss = model(**data) 17 | loss.backward() 18 | 19 | 20 | if __name__ == '__main__': 21 | test_registry() 22 | -------------------------------------------------------------------------------- /example/benchmark/scripts/deepspeed.sh: -------------------------------------------------------------------------------- 1 | # source /opt/conda/etc/profile.d/conda.sh 2 | conda activate ds-torch-1.13 3 | 4 | export T_MODEL=${T_MODEL:-"opt-1b"} 5 | 6 | export N_GPU=${N_GPU:-1} 7 | export N_BS=${N_BS:-16} 8 | export N_STEP=${N_STEP:-6} 9 | 10 | export T_DP="zero2" 11 | bash ./rm_lock.sh 12 | bash ./run_script.sh 13 | 14 | export T_DP="zero2-offload" 15 | bash ./rm_lock.sh 16 | bash ./run_script.sh 17 | 18 | export T_DP="zero3" 19 | bash ./rm_lock.sh 20 | bash ./run_script.sh 21 | 22 | export T_DP="zero3-offload" 23 | bash ./rm_lock.sh 24 | bash ./run_script.sh 25 | -------------------------------------------------------------------------------- /example/benchmark/scripts/fsdp_benchmark.sh: -------------------------------------------------------------------------------- 1 | export HF_DATASETS_OFFLINE=1 2 | export TRANSFORMERS_OFFLINE=1 3 | 4 | for name_model in "gpt2-4b" "gpt2-10b" "gpt2-15b" "gpt2-20b"; do 5 | for num_gpu in 1 2 4; do 6 | for batch_size in 4 8 12 16; do 7 | echo "****************** Begin ***************************" 8 | T_MODEL=${name_model} N_GPU=${num_gpu} N_BS=${batch_size} bash ./fsdp.sh 9 | echo "****************** Finished ***************************" 10 | echo "" 11 | echo "" 12 | done 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /example/benchmark/scripts/run_script.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | T_DP=${T_DP:-"fsdp"} 4 | T_MODEL=${T_MODEL:-"opt-1b"} 5 | 6 | N_GPU=${N_GPU:-1} 7 | N_BS=${N_BS:-16} 8 | N_STEP=${N_STEP:-6} 9 | 10 | mkdir -p benchmark_logs 11 | 12 | wc=`cat /proc/cpuinfo | grep "processor"| wc -l` 13 | let TNUM=wc/${N_GPU} 14 | 15 | env OMP_NUM_THREADS=${TNUM} torchrun --nproc_per_node=${N_GPU} --master_port=29911 ./script.py \ 16 | --dp_type=${T_DP} \ 17 | --model_name=${T_MODEL} \ 18 | --batch_size=${N_BS} \ 19 | --train_step=${N_STEP} \ 20 | 2>&1 | tee ./benchmark_logs/${T_MODEL}_bs_${N_BS}_gpu_${N_GPU}_${T_DP}.log 21 | -------------------------------------------------------------------------------- /test/ctx/test_meta_ctx.py: -------------------------------------------------------------------------------- 1 | from test.utils import TEST_MODELS 2 | 3 | from elixir.ctx import MetaContext 4 | 5 | 
6 | def test_meta_context(): 7 | builder, *_ = TEST_MODELS.get('resnet') 8 | with MetaContext(): 9 | model = builder() 10 | 11 | for name, param in model.named_parameters(): 12 | assert param.device.type == 'meta' 13 | print(name, param) 14 | 15 | for name, buffer in model.named_buffers(): 16 | assert buffer.device.type == 'meta' 17 | print(name, buffer) 18 | 19 | 20 | if __name__ == '__main__': 21 | test_meta_context() 22 | -------------------------------------------------------------------------------- /example/benchmark/fetch_hf_settings.py: -------------------------------------------------------------------------------- 1 | from elixir.ctx import MetaContext 2 | from example.common.models import get_model 3 | 4 | gpt_list = ['gpt2-400m', 'gpt2-1b'] 5 | opt_list = ['opt-350m', 'opt-1b', 'opt-3b', 'opt-7b', 'opt-13b', 'opt-30b', 'opt-66b', 'opt-175b'] 6 | 7 | 8 | def init_model(name: str): 9 | with MetaContext(): 10 | model = get_model(name) 11 | del model 12 | 13 | 14 | def main(): 15 | for name in gpt_list: 16 | init_model(name) 17 | 18 | for name in opt_list: 19 | init_model(name) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /test/utils/resnet.py: -------------------------------------------------------------------------------- 1 | from test.utils.iterator import TestIterator 2 | from test.utils.registry import TEST_MODELS 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.models import resnet18 7 | 8 | 9 | def resnet_data_fn(): 10 | return dict(x=torch.randn(4, 3, 32, 32)) 11 | 12 | 13 | class ResNetModel(nn.Module): 14 | 15 | def __init__(self) -> None: 16 | super().__init__() 17 | self.r = resnet18() 18 | 19 | def forward(self, x): 20 | output = self.r(x) 21 | return output.sum() 22 | 23 | 24 | TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn) 25 | -------------------------------------------------------------------------------- /example/benchmark/readme.md: -------------------------------------------------------------------------------- 1 | # Benchmark 2 | 3 | See the folder `scripts` checking how to run benchmarks. 4 | 5 | ## Environments 6 | 7 | Here are some notice for users. 8 | 9 | * ColossalAI and DeepSpeed can not be installed in one conda environment, just create two for them respectively. 10 | 11 | * The version of `PyTorch` should be `1.13.1`. 12 | 13 | * The version of `transformers` should be `4.26.1`. 14 | 15 | * The version of `deepspeed` should be `0.8.3`. 16 | 17 | * We found that FSDP in PyTorch is not compatible with the gradient checkpointing even in PyTorch 2.0. 18 | Thus, we use FSDP from `fairscale(0.4.13)`. 
19 | -------------------------------------------------------------------------------- /example/benchmark/scripts/benchmark.sh: -------------------------------------------------------------------------------- 1 | export HF_DATASETS_OFFLINE=1 2 | export TRANSFORMERS_OFFLINE=1 3 | 4 | for name_model in "gpt2-4b" "gpt2-10b" "gpt2-15b" "gpt2-20b"; do 5 | for num_gpu in 1 2 4; do 6 | for batch_size in 4 8 12 16; do 7 | echo "****************** Begin ***************************" 8 | T_MODEL=${name_model} N_GPU=${num_gpu} N_BS=${batch_size} bash ./deepspeed.sh 9 | T_MODEL=${name_model} N_GPU=${num_gpu} N_BS=${batch_size} bash ./elixir.sh 10 | echo "****************** Finished ***************************" 11 | echo "" 12 | echo "" 13 | done 14 | done 15 | done 16 | -------------------------------------------------------------------------------- /example/benchmark/run_gemini.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | # The following options only valid when DISTPLAN="colossalai" 4 | export GPUNUM=${GPUNUM:-4} 5 | export PLACE_POLICY=${PLACEMENT:-"cuda"} 6 | export BATCH_SIZE=${BATCH_SIZE:-32} 7 | export MODEL_NAME=${MODEL_TYPE:-"gpt2-400m"} 8 | export TRAIN_STEP=${TRAIN_STEP:-6} 9 | # export PYTHONPATH=$PWD:$PYTHONPATH 10 | 11 | mkdir -p gemini_logs 12 | 13 | torchrun --standalone --nproc_per_node=${GPUNUM} ./gemini_demo.py \ 14 | --model_name=${MODEL_NAME} \ 15 | --batch_size=${BATCH_SIZE} \ 16 | --place_policy=${PLACE_POLICY} \ 17 | --train_step=${TRAIN_STEP} \ 18 | 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_${PLACEMENT}.log 19 | -------------------------------------------------------------------------------- /test/utils/registry.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Callable 3 | 4 | 5 | class Registry(object): 6 | 7 | def __init__(self) -> None: 8 | super().__init__() 9 | self._registry_dict = OrderedDict() 10 | 11 | def register(self, name: str, model_fn: Callable, data_fn: Callable): 12 | assert name not in self._registry_dict 13 | 14 | model_tuple = (model_fn, data_fn) 15 | self._registry_dict[name] = model_tuple 16 | 17 | def get(self, name: str): 18 | return self._registry_dict[name] 19 | 20 | def __iter__(self): 21 | return iter(self._registry_dict.items()) 22 | 23 | 24 | TEST_MODELS = Registry() 25 | 26 | __all__ = [TEST_MODELS] 27 | -------------------------------------------------------------------------------- /example/common/zero2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "steps_per_print": 100000000, 3 | "train_batch_size": 256, 4 | "train_micro_batch_size_per_gpu": 32, 5 | "gradient_accumulation_steps": 1, 6 | "zero_optimization": { 7 | "stage": 2, 8 | "offload_optimizer": { 9 | "device": "cpu" 10 | }, 11 | "allgather_partitions": true, 12 | "reduce_scatter": true, 13 | "allgather_bucket_size": 5e8, 14 | "reduce_bucket_size": 5e8, 15 | "overlap_comm": true, 16 | "contiguous_gradients": true 17 | }, 18 | "fp16": { 19 | "enabled": true, 20 | "loss_scale": 128.0, 21 | "loss_scale_window": 1000, 22 | "hysteresis": 2, 23 | "min_loss_scale": 1 24 | }, 25 | "wall_clock_breakdown": false, 26 | "zero_allow_untested_optimizer": true 27 | } 28 | -------------------------------------------------------------------------------- /example/common/opt.py: -------------------------------------------------------------------------------- 1 | import 
torch.nn as nn 2 | from transformers import OPTForCausalLM 3 | 4 | 5 | class OPTLMModel(nn.Module): 6 | 7 | def __init__(self, config) -> None: 8 | super().__init__() 9 | self.config = config 10 | self.module = OPTForCausalLM(config=config) 11 | self.enable_gc = False 12 | 13 | def gradient_checkpointing_enable(self): 14 | self.module.gradient_checkpointing_enable() 15 | self.enable_gc = True 16 | 17 | def forward(self, input_ids, attention_mask): 18 | loss = self.module( 19 | # pre-commit: do not rearrange 20 | input_ids=input_ids, 21 | attention_mask=attention_mask, 22 | labels=input_ids, 23 | use_cache=(not self.enable_gc))['loss'] 24 | return loss 25 | -------------------------------------------------------------------------------- /test/search/test_optimal.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from test.utils import TEST_MODELS 3 | 4 | import torch 5 | 6 | from elixir.cuda import gpu_device 7 | from elixir.search import optimal_search 8 | 9 | 10 | def step_fn(model, inp): 11 | model(**inp).backward() 12 | 13 | 14 | def test_optimal_search(): 15 | model_fn, data_fn = TEST_MODELS.get('gpt2_small') 16 | model = model_fn() 17 | data = data_fn() 18 | 19 | sr = optimal_search(model, 1, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=step_fn) 20 | 21 | chunk_plans = deepcopy(sr.param_chunk_plans) 22 | for plan in chunk_plans: 23 | assert plan.chunk_dtype == torch.float16 24 | assert plan.kwargs.get('shard_device') == gpu_device() 25 | 26 | 27 | if __name__ == '__main__': 28 | test_optimal_search() 29 | -------------------------------------------------------------------------------- /test/parameter/hf_models/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | 5 | 6 | def test_hf_model(builder, kwargs, data_fn): 7 | torch_model = builder(**kwargs).cuda() 8 | test_model = deepcopy(torch_model) 9 | 10 | from elixir.parameter.temp import transform 11 | test_model = transform(test_model) 12 | 13 | torch_model.eval() 14 | test_model.eval() 15 | 16 | data = data_fn() 17 | for k, v in data.items(): 18 | if isinstance(v, torch.Tensor): 19 | data[k] = v.cuda() 20 | 21 | torch_out = torch_model(**data) 22 | test_out = test_model(**data) 23 | 24 | for k, u in torch_out.items(): 25 | v = test_out[k] 26 | if isinstance(u, torch.Tensor): 27 | assert torch.equal(u, v), f'output {k} is wrong' 28 | 29 | torch.cuda.synchronize() 30 | -------------------------------------------------------------------------------- /elixir/kernels/attn_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model 4 | from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder 5 | 6 | from .gpt_attention import XGPT2Attention, XGPT2Model 7 | from .opt_attention import XOPTAttention, XOPTDecoder 8 | 9 | 10 | def wrap_attention(model: nn.Module): 11 | for name, module in model.named_modules(): 12 | if isinstance(module, GPT2Model): 13 | module.__class__ = XGPT2Model 14 | elif isinstance(module, GPT2Attention): 15 | module.__class__ = XGPT2Attention 16 | elif isinstance(module, OPTAttention): 17 | module.__class__ = XOPTAttention 18 | elif isinstance(module, OPTDecoder): 19 | module.__class__ = XOPTDecoder 20 | return model 21 | 
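A note on `elixir/kernels/attn_wrapper.py` above: `wrap_attention` rebinds the `__class__` of every `GPT2Model`, `GPT2Attention`, `OPTDecoder` and `OPTAttention` submodule to the xformers-backed variants, leaving the weights untouched. A minimal usage sketch, assuming `transformers` and `xformers` are installed; the tiny GPT-2 config here is purely illustrative:

```python
from transformers import GPT2Config, GPT2LMHeadModel

from elixir.kernels.attn_wrapper import wrap_attention

# Any module tree containing the HuggingFace GPT-2 or OPT attention classes works;
# this small config is only an example.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=4, n_embd=128)).cuda()

# Classes are swapped in place and the same module is returned, so existing
# training or evaluation code does not need to change.
model = wrap_attention(model)
```

`test/kernels/test_attn.py` further down checks that a wrapped model reproduces the outputs and gradients of the stock HuggingFace attention.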
-------------------------------------------------------------------------------- /elixir/ctx/meta_ctx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from elixir.ctx import tensor_creation_methods 4 | 5 | 6 | class MetaContext(object): 7 | 8 | def __init__(self, device_type: str = 'meta') -> None: 9 | super().__init__() 10 | self.device_type = device_type 11 | return None 12 | 13 | def __enter__(self): 14 | 15 | def meta_wrap(func): 16 | 17 | def wrapped_func(*args, **kwargs): 18 | kwargs['device'] = self.device_type 19 | return func(*args, **kwargs) 20 | 21 | return wrapped_func 22 | 23 | for name, method in tensor_creation_methods.items(): 24 | setattr(torch, name, meta_wrap(method)) 25 | 26 | def __exit__(self, exc_type, exc_val, exc_tb): 27 | for name, method in tensor_creation_methods.items(): 28 | setattr(torch, name, method) 29 | -------------------------------------------------------------------------------- /example/common/zero3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "steps_per_print": 100000000, 3 | "train_batch_size": 256, 4 | "train_micro_batch_size_per_gpu": 32, 5 | "gradient_accumulation_steps": 1, 6 | "zero_optimization": { 7 | "stage": 3, 8 | "offload_optimizer": { 9 | "device": "cpu" 10 | }, 11 | "offload_param": { 12 | "device": "cpu" 13 | }, 14 | "stage3_max_live_parameters": 1e9, 15 | "stage3_max_reuse_distance": 1e9, 16 | "stage3_prefetch_bucket_size": 1e7, 17 | "stage3_param_persistence_threshold": 1e5, 18 | "reduce_bucket_size": 1e7, 19 | "contiguous_gradients": true 20 | }, 21 | "fp16": { 22 | "enabled": true, 23 | "loss_scale": 128.0, 24 | "loss_scale_window": 1000, 25 | "hysteresis": 2, 26 | "min_loss_scale": 1 27 | }, 28 | "wall_clock_breakdown": false, 29 | "zero_allow_untested_optimizer": true 30 | } 31 | -------------------------------------------------------------------------------- /elixir/cuda.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import torch 4 | from torch.cuda._utils import _get_device_index 5 | 6 | elixir_cuda_fraction = dict() 7 | 8 | 9 | @cache 10 | def gpu_device(): 11 | return torch.device(torch.cuda.current_device()) 12 | 13 | 14 | def set_memory_fraction(fraction, device=None): 15 | torch.cuda.set_per_process_memory_fraction(fraction, device) 16 | if device is None: 17 | device = torch.cuda.current_device() 18 | device = _get_device_index(device) 19 | elixir_cuda_fraction[device] = fraction 20 | 21 | 22 | def get_allowed_memory(device=None): 23 | total_memory = torch.cuda.get_device_properties(device).total_memory 24 | if device is None: 25 | device = torch.cuda.current_device() 26 | device = _get_device_index(device) 27 | fraction = elixir_cuda_fraction.get(device, 1.0) 28 | return int(fraction * total_memory) 29 | -------------------------------------------------------------------------------- /test/utils/mlp.py: -------------------------------------------------------------------------------- 1 | from test.utils.registry import TEST_MODELS 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def mlp_data_fn(): 8 | return dict(x=torch.randn(4, 16)) 9 | 10 | 11 | class MlpModule(nn.Module): 12 | 13 | def __init__(self, hidden_dim: int = 16) -> None: 14 | super().__init__() 15 | self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim) 16 | self.act = nn.GELU() 17 | self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim) 18 | 19 | def 
forward(self, x): 20 | return x + (self.proj2(self.act(self.proj1(x)))) 21 | 22 | 23 | class MlpModel(nn.Module): 24 | 25 | def __init__(self, hidden_dim: int = 16) -> None: 26 | super().__init__() 27 | self.mlp = MlpModule(hidden_dim) 28 | 29 | def forward(self, x): 30 | output = self.mlp(x) 31 | return output.sum() 32 | 33 | 34 | TEST_MODELS.register('mlp', MlpModel, mlp_data_fn) 35 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/pycqa/isort 4 | rev: 5.12.0 5 | hooks: 6 | - id: isort 7 | name: sort all imports (python) 8 | 9 | - repo: https://github.com/pre-commit/mirrors-yapf 10 | rev: v0.32.0 11 | hooks: 12 | - id: yapf 13 | name: yapf formatter 14 | args: ['--style=.style.yapf', '--parallel', '--in-place'] 15 | 16 | - repo: https://github.com/pre-commit/mirrors-clang-format 17 | rev: v13.0.1 18 | hooks: 19 | - id: clang-format 20 | name: clang formatter 21 | 22 | - repo: https://github.com/pre-commit/pre-commit-hooks 23 | rev: v4.3.0 24 | hooks: 25 | - id: check-yaml 26 | - id: check-merge-conflict 27 | - id: check-case-conflict 28 | - id: trailing-whitespace 29 | - id: end-of-file-fixer 30 | - id: mixed-line-ending 31 | args: ['--fix=lf'] 32 | - id: double-quote-string-fixer 33 | -------------------------------------------------------------------------------- /test/exp_tracer/test_fx_order.py: -------------------------------------------------------------------------------- 1 | from test.utils import TEST_MODELS, assert_dict_keys 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from elixir.tracer.param_tracer import generate_fx_order 7 | 8 | 9 | def test_fx_forward(): 10 | builder, *_ = TEST_MODELS.get_func('small')() 11 | model = builder() 12 | forward_order = generate_fx_order(model) 13 | 14 | # for step in forward_order: 15 | # print(step) 16 | 17 | assert_dict_keys(forward_order[0], ['embed.weight']) 18 | assert_dict_keys(forward_order[1], ['mlp.proj1.weight', 'mlp.proj1.bias']) 19 | assert_dict_keys(forward_order[2], ['mlp.proj2.weight', 'mlp.proj2.bias']) 20 | assert_dict_keys(forward_order[3], ['norm1.weight', 'norm1.bias']) 21 | assert_dict_keys(forward_order[4], ['norm2.weight', 'norm2.bias']) 22 | assert_dict_keys(forward_order[5], ['embed.weight']) 23 | 24 | 25 | if __name__ == '__main__': 26 | test_fx_forward() 27 | -------------------------------------------------------------------------------- /test/kernels/test_attn.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from test.utils import TEST_MODELS, to_cuda 3 | 4 | from torch.testing import assert_close 5 | 6 | from elixir.kernels.attn_wrapper import wrap_attention 7 | 8 | 9 | def exam_one_model(model_fn, data_fn): 10 | torch_model = model_fn().cuda() 11 | test_model = deepcopy(torch_model) 12 | test_model = wrap_attention(test_model) 13 | 14 | data = to_cuda(data_fn()) 15 | torch_out = torch_model(**data) 16 | torch_out.backward() 17 | 18 | test_out = test_model(**data) 19 | test_out.backward() 20 | 21 | assert_close(torch_out, test_out) 22 | for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()): 23 | assert_close(p_torch.grad, p_test.grad) 24 | 25 | 26 | def test_gpt_atten_kernel(): 27 | exam_one_model(*TEST_MODELS.get('gpt2_micro')) 28 | exam_one_model(*TEST_MODELS.get('opt_micro')) 29 | 30 | 31 | if __name__ == 
'__main__': 32 | test_gpt_atten_kernel() 33 | -------------------------------------------------------------------------------- /test/utils/small.py: -------------------------------------------------------------------------------- 1 | from test.utils.mlp import MlpModule 2 | from test.utils.registry import TEST_MODELS 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def small_data_fn(): 9 | return dict(x=torch.randint(low=0, high=20, size=(4, 8))) 10 | 11 | 12 | class SmallModel(nn.Module): 13 | 14 | def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None: 15 | super().__init__() 16 | self.embed = nn.Embedding(num_embeddings, hidden_dim) 17 | self.norm1 = nn.LayerNorm(hidden_dim) 18 | self.mlp = MlpModule(hidden_dim=hidden_dim) 19 | self.norm2 = nn.LayerNorm(hidden_dim) 20 | self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False) 21 | self.proj.weight = self.embed.weight 22 | 23 | def forward(self, x): 24 | x = self.embed(x) 25 | x = x + self.norm1(self.mlp(x)) 26 | x = self.proj(self.norm2(x)) 27 | x = x.mean(dim=-2) 28 | return x.sum() 29 | 30 | 31 | TEST_MODELS.register('small', SmallModel, small_data_fn) 32 | -------------------------------------------------------------------------------- /elixir/search/result.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, NamedTuple 2 | 3 | import torch 4 | 5 | from elixir.chunk import ChunkGroup 6 | 7 | 8 | class ChunkPlan(NamedTuple): 9 | """ChunkPlan is a type of configuration used to instruct the initialization of a chunk. 10 | 11 | args: 12 | name_list: contains the names of parameters that should be pushed into this chunk 13 | chunk_size: the size of this chunk 14 | chunk_dtype: the dtype of this chunk 15 | kwargs: a dictionary used in __init__ function of Chunk 16 | """ 17 | name_list: List[str] 18 | chunk_size: int 19 | chunk_dtype: torch.dtype 20 | kwargs: Dict 21 | 22 | 23 | class SearchResult(object): 24 | 25 | def __init__(self, 26 | chunk_group: ChunkGroup, 27 | chunk_plans: List[ChunkPlan], 28 | param_called_per_step: List[List[str]] = None) -> None: 29 | super().__init__() 30 | self.chunk_group = chunk_group 31 | self.param_chunk_plans = chunk_plans 32 | self.param_called_per_step = param_called_per_step 33 | -------------------------------------------------------------------------------- /elixir/chunk/scheduler/fifo.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from elixir.chunk.core import Chunk 4 | 5 | from .base import Chunk, ChunkScheduler 6 | 7 | 8 | class FIFOScheduler(ChunkScheduler): 9 | 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.fifo_dict: Optional[dict] = None 13 | 14 | def reset(self) -> None: 15 | super().reset() 16 | self.fifo_dict = dict() 17 | 18 | def clear(self) -> None: 19 | super().clear() 20 | self.fifo_dict = None 21 | 22 | def top(self) -> Optional[Chunk]: 23 | if not super().top(): 24 | return None 25 | dict_iter = iter(self.fifo_dict) 26 | ret = next(dict_iter) 27 | return ret 28 | 29 | def add(self, chunk: Chunk) -> bool: 30 | if not super().add(chunk): 31 | return False 32 | self.fifo_dict[chunk] = True 33 | return True 34 | 35 | def remove(self, chunk: Chunk) -> bool: 36 | if not super().remove(chunk): 37 | return False 38 | self.fifo_dict.pop(chunk) 39 | return True 40 | -------------------------------------------------------------------------------- /elixir/chunk/core/states.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TensorState(Enum): 5 | FREE = 0 6 | COMPUTE = 1 7 | HOLD = 2 8 | HOLD_AFTER_BWD = 3 9 | READY_FOR_REDUCE = 4 10 | 11 | 12 | # expected: free -> hold -> compute -> hold -> 13 | # -> compute -> hold_after_bwd -> ready_for_reduce 14 | legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), 15 | (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), 16 | (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), 17 | (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), 18 | (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), 19 | (TensorState.READY_FOR_REDUCE, TensorState.HOLD)] 20 | 21 | 22 | def ts_update_sanity_check(old_state, new_state) -> bool: 23 | if (old_state, new_state) not in legal_ts_update_list: 24 | raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}') 25 | return True 26 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing import assert_close 3 | from torch.utils._pytree import tree_map 4 | 5 | from . import gpt, mlp, opt, resnet, small 6 | from .registry import TEST_MODELS 7 | 8 | 9 | def to_cuda(input_dict): 10 | 11 | def local_fn(t): 12 | if isinstance(t, torch.Tensor): 13 | t = t.cuda() 14 | return t 15 | 16 | ret = tree_map(local_fn, input_dict) 17 | return ret 18 | 19 | 20 | def allclose(ta, tb, **kwargs): 21 | assert_close(ta, tb, **kwargs) 22 | return True 23 | 24 | 25 | def assert_dict_keys(test_dict, keys): 26 | assert len(test_dict) == len(keys) 27 | for k in keys: 28 | assert k in test_dict 29 | 30 | 31 | def assert_dict_values(da, db, fn): 32 | assert len(da) == len(db) 33 | for k, v in da.items(): 34 | assert k in db 35 | if not torch.is_tensor(v): 36 | continue 37 | u = db.get(k) 38 | if u.device != v.device: 39 | v = v.to(u.device) 40 | # print(f"checking key {k}: {u.shape} vs {v.shape}") 41 | assert fn(u.data, v.data), f'max diff {torch.max(torch.abs(u.data - v.data))}' 42 | -------------------------------------------------------------------------------- /test/search/test_mini_waste.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from test.utils import TEST_MODELS 3 | 4 | import torch 5 | 6 | from elixir.cuda import gpu_device 7 | from elixir.search import minimum_waste_search 8 | 9 | 10 | def step_fn(model, inp): 11 | model(**inp).backward() 12 | 13 | 14 | def test_mini_waste_search(): 15 | model_fn, data_fn = TEST_MODELS.get('gpt2_small') 16 | model = model_fn() 17 | data = data_fn() 18 | 19 | sr = minimum_waste_search(model, 20 | 1, 21 | unified_dtype=torch.float16, 22 | cpu_offload=True, 23 | prefetch=True, 24 | verbose=True, 25 | inp=data, 26 | step_fn=step_fn) 27 | 28 | chunk_plans = deepcopy(sr.param_chunk_plans) 29 | for plan in chunk_plans: 30 | assert plan.chunk_dtype == torch.float16 31 | assert plan.kwargs.get('shard_device') == torch.device('cpu') 32 | assert plan.kwargs.get('cpu_pin_memory') == True 33 | 34 | 35 | if __name__ == '__main__': 36 | test_mini_waste_search() 37 | -------------------------------------------------------------------------------- /example/common/fsdp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP 4 | 5 | from elixir.ctx import MetaContext 6 | from elixir.kernels.attn_wrapper import wrap_attention 7 | from elixir.utils import get_model_size 8 | from example.common.models import get_model 9 | 10 | 11 | def train_init(model_name: str): 12 | with MetaContext('cuda'): 13 | model = get_model(model_name) 14 | model_size = get_model_size(model) 15 | model = FSDP(module=model, mixed_precision=True, flatten_parameters=False) 16 | 17 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0) 18 | model.gradient_checkpointing_enable() 19 | 20 | model = wrap_attention(model) 21 | model.train() 22 | 23 | def forward(data): 24 | return model(**data) 25 | 26 | def backward(loss): 27 | loss.backward() 28 | 29 | def optim(): 30 | optimizer.step() 31 | optimizer.zero_grad() 32 | 33 | return forward, backward, optim, model_size 34 | 35 | 36 | if __name__ == '__main__': 37 | import colossalai 38 | colossalai.launch_from_torch(config={}) 39 | print(train_init('opt-1b')) 40 | -------------------------------------------------------------------------------- /test/exp_tracer/test_td_order.py: -------------------------------------------------------------------------------- 1 | from test.utils import TEST_MODELS, assert_dict_keys 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from elixir.tracer.param_tracer import generate_fx_order, generate_td_order 7 | 8 | 9 | def test_td_forward_backward(): 10 | builder, train_iter, test_iter, criterion = TEST_MODELS.get_func('mlp')() 11 | model = builder() 12 | data, label = next(train_iter) 13 | data.requires_grad = True 14 | 15 | def forward_backward_fn(model, inp): 16 | model(*inp).sum().backward() 17 | 18 | td_order = generate_td_order(model, data, forward_backward_fn) 19 | for step_dict in td_order: 20 | print(step_dict) 21 | 22 | assert_dict_keys(td_order[0], ['proj1.weight']) 23 | assert_dict_keys(td_order[1], ['proj1.weight', 'proj1.bias']) 24 | assert_dict_keys(td_order[2], ['proj2.weight']) 25 | assert_dict_keys(td_order[3], ['proj2.weight', 'proj2.bias']) 26 | assert_dict_keys(td_order[4], ['proj2.weight']) 27 | assert_dict_keys(td_order[5], ['proj2.weight']) 28 | assert_dict_keys(td_order[6], ['proj1.weight']) 29 | assert_dict_keys(td_order[7], ['proj1.weight']) 30 | 31 | 32 | if __name__ == '__main__': 33 | test_td_forward_backward() 34 | -------------------------------------------------------------------------------- /elixir/meta_registrations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._meta_registrations import register_meta 3 | 4 | aten = torch.ops.aten 5 | 6 | 7 | # since we fix the torch version to 1.13.1, we have to add unimplemented meta ops 8 | # all these functions are from here https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py 9 | @register_meta([aten.convolution_backward.default]) 10 | def meta_convolution_backward( 11 | grad_output_, 12 | input_, 13 | weight_, 14 | bias_sizes_opt, 15 | stride, 16 | padding, 17 | dilation, 18 | transposed, 19 | output_padding, 20 | groups, 21 | output_mask, 22 | ): 23 | # High level logic taken from slow_conv3d_backward_cpu which should 24 | # be representative of all convolution_backward impls 25 | backend_grad_input = None 26 | backend_grad_weight = None 27 | backend_grad_bias = 
None 28 | 29 | if output_mask[0]: 30 | backend_grad_input = grad_output_.new_empty(input_.size()) 31 | if output_mask[1]: 32 | backend_grad_weight = grad_output_.new_empty(weight_.size()) 33 | if output_mask[2]: 34 | backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) 35 | 36 | return (backend_grad_input, backend_grad_weight, backend_grad_bias) 37 | -------------------------------------------------------------------------------- /example/search_example.py: -------------------------------------------------------------------------------- 1 | from test.utils.gpt import GPTLMModel, small_data_fn 2 | 3 | import torch 4 | from transformers import AutoConfig, OPTConfig, OPTForCausalLM 5 | 6 | import elixir 7 | from elixir.ctx import MetaContext 8 | from elixir.search import optimal_search 9 | from elixir.search.utils import find_search_range 10 | from elixir.tracer.memory_tracer import cuda_memory_profiling 11 | from elixir.utils import get_model_size, model_size_formatter 12 | from example.common.models import get_model 13 | from example.common.utils import fake_gpt_data 14 | 15 | 16 | def profile_optimal_search(): 17 | elixir.cuda.set_memory_fraction(0.2) 18 | 19 | with MetaContext(): 20 | model = get_model('opt-1b') 21 | model_size = get_model_size(model) 22 | print(f'model size: {model_size_formatter(model_size)}') 23 | 24 | ids, mask = fake_gpt_data(16, 1024, 50257) 25 | data = dict(input_ids=ids, attention_mask=mask) 26 | 27 | def train_step(model_in, inp_in): 28 | loss = model_in(**inp_in) 29 | loss.backward() 30 | 31 | model.gradient_checkpointing_enable() 32 | sr = optimal_search(model, 4, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=train_step) 33 | 34 | 35 | if __name__ == '__main__': 36 | profile_optimal_search() 37 | -------------------------------------------------------------------------------- /test/tracer/test_tf_order.py: -------------------------------------------------------------------------------- 1 | from test.utils import TEST_MODELS 2 | 3 | from elixir.tracer.param_tracer import generate_tf_order 4 | 5 | 6 | def test_tf_forward_backward(): 7 | model_fn, data_fn = TEST_MODELS.get('gpt2_micro') 8 | model = model_fn() 9 | data = data_fn() 10 | 11 | def forward_backward_fn(local_model, local_input): 12 | local_model(**local_input).backward() 13 | 14 | # model.gradient_checkpointing_enable() 15 | tf_order = generate_tf_order(model, data, forward_backward_fn) 16 | params_per_step = tf_order['params_per_step'] 17 | assert len(params_per_step) == 32 18 | 19 | model.gradient_checkpointing_enable() 20 | tf_order = generate_tf_order(model, data, forward_backward_fn) 21 | params_per_step = tf_order['params_per_step'] 22 | checkpoint_info = tf_order['checkpoint_info'] 23 | for i, step in enumerate(params_per_step): 24 | print(f'step {i}: {step}') 25 | for c in checkpoint_info: 26 | print(f'checkpoint info: {c}') 27 | assert len(params_per_step) == 44 28 | 29 | assert data['input_ids'].device.type == 'cpu' 30 | assert data['attention_mask'].device.type == 'cpu' 31 | for param in model.parameters(): 32 | assert param.device.type == 'cpu' 33 | 34 | 35 | if __name__ == '__main__': 36 | test_tf_forward_backward() 37 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | __version__ = '0.0.1' 5 | 6 | 7 | def 
fetch_requirements(path): 8 | with open(path, 'r') as fd: 9 | return [r.strip() for r in fd.readlines()] 10 | 11 | 12 | def fetch_readme(): 13 | with open('README.md', encoding='utf-8') as f: 14 | return f.read() 15 | 16 | 17 | ext_modules = [ 18 | CppExtension(name='elixir.c_utils', 19 | sources=['src/simulator.cpp'], 20 | extra_compile_args=['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']) 21 | ] 22 | 23 | setup( 24 | name='elixir', 25 | version=__version__, 26 | author='Haichen Huang', 27 | author_email='c2h214748@gmail.com', 28 | url='https://github.com/hpcaitech/Elixir', 29 | packages=find_packages(exclude=( 30 | 'example', 31 | 'profile', 32 | 'src', 33 | 'test', 34 | '*.egg-info', 35 | )), 36 | description='An Optimized Implementation of Elixir (Gemini2.0)', 37 | long_description=fetch_readme(), 38 | long_description_content_type='text/markdown', 39 | ext_modules=ext_modules, 40 | cmdclass={'build_ext': BuildExtension}, 41 | install_requires=fetch_requirements('requirements.txt'), 42 | python_requires='>=3.8', 43 | ) 44 | -------------------------------------------------------------------------------- /elixir/kernels/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import xformers.ops as xops 3 | from torch.utils._pytree import tree_map 4 | 5 | from elixir.tracer.memory_tracer.memory_tensor import MTensor 6 | 7 | 8 | def lower_triangular_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, p: float = 0.0): 9 | 10 | args = (query, key, value) 11 | meta_flag = False 12 | 13 | for x in args: 14 | if x.device.type == 'meta': 15 | meta_flag = True 16 | break 17 | 18 | if meta_flag: 19 | atten = query @ key.transpose(-2, -1) 20 | output = atten @ value 21 | return output 22 | 23 | profile_flag = False 24 | 25 | def to_torch_tensor(x): 26 | if isinstance(x, MTensor): 27 | nonlocal profile_flag 28 | profile_flag = True 29 | return x.elem 30 | return x 31 | 32 | args = tree_map(to_torch_tensor, args) 33 | query, key, value = args 34 | output = xops.memory_efficient_attention(query=query, 35 | key=key, 36 | value=value, 37 | p=p, 38 | attn_bias=xops.LowerTriangularMask()) 39 | 40 | if profile_flag: 41 | output = MTensor(output) 42 | 43 | return output 44 | -------------------------------------------------------------------------------- /elixir/chunk/scheduler/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | from elixir.chunk.core import Chunk 5 | 6 | 7 | class ChunkScheduler(ABC): 8 | 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self.releasable_set: Optional[set] = None 12 | self.current_step = -1 13 | 14 | @abstractmethod 15 | def reset(self) -> None: 16 | self.releasable_set = set() 17 | self.current_step = -1 18 | 19 | @abstractmethod 20 | def clear(self) -> None: 21 | # asure the set is empty now 22 | assert not bool(self.releasable_set) 23 | 24 | @abstractmethod 25 | def top(self) -> Optional[Chunk]: 26 | # return None if the releasable set is empty 27 | if not self.releasable_set: 28 | return False 29 | return True 30 | 31 | @abstractmethod 32 | def add(self, chunk: Chunk) -> bool: 33 | if chunk in self.releasable_set: 34 | return False 35 | self.releasable_set.add(chunk) 36 | return True 37 | 38 | @abstractmethod 39 | def remove(self, chunk: Chunk) -> bool: 40 | if chunk not in self.releasable_set: 41 | return False 42 | 
self.releasable_set.remove(chunk) 43 | return True 44 | 45 | def step(self, *args, **kwags): 46 | self.current_step += 1 47 | -------------------------------------------------------------------------------- /test/utils/opt.py: -------------------------------------------------------------------------------- 1 | from test.utils.registry import TEST_MODELS 2 | 3 | import torch.nn as nn 4 | from transformers import OPTConfig, OPTForCausalLM 5 | 6 | from .gpt import micro_data_fn 7 | 8 | 9 | class OPTLMModel(nn.Module): 10 | 11 | def __init__(self, config) -> None: 12 | super().__init__() 13 | self.config = config 14 | self.module = OPTForCausalLM(config=config) 15 | self.enable_gc = False 16 | 17 | def gradient_checkpointing_enable(self): 18 | self.module.gradient_checkpointing_enable() 19 | self.enable_gc = True 20 | 21 | def forward(self, input_ids, attention_mask): 22 | loss = self.module( 23 | # pre-commit: do not rearrange 24 | input_ids=input_ids, 25 | attention_mask=attention_mask, 26 | labels=input_ids, 27 | use_cache=(not self.enable_gc))['loss'] 28 | return loss 29 | 30 | 31 | def opt_micro(): 32 | opt_config = OPTConfig( 33 | # pre-commit: do not rearrange 34 | vocab_size=128, 35 | activation_dropout=0.0, 36 | dropout=0, 37 | hidden_size=32, 38 | num_hidden_layers=4, 39 | ffn_dim=128, 40 | num_attention_heads=4, 41 | word_embed_proj_dim=32, 42 | output_projection=True) 43 | return OPTLMModel(opt_config) 44 | 45 | 46 | TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn) 47 | -------------------------------------------------------------------------------- /test/tracer/test_cuda_profiler.py: -------------------------------------------------------------------------------- 1 | from test.utils import TEST_MODELS, to_cuda 2 | 3 | import pytest 4 | import torch 5 | 6 | from elixir.tracer.memory_tracer import cuda_memory_profiling 7 | 8 | 9 | def one_step(model, inp): 10 | loss = model(**inp) 11 | loss.backward() 12 | return loss 13 | 14 | 15 | def try_one_model(model_fn, data_fn): 16 | model = model_fn().cuda() 17 | data = to_cuda(data_fn()) 18 | one_step(model, data) # generate gradients 19 | 20 | pre_cuda_alc = torch.cuda.memory_allocated() 21 | torch.cuda.reset_peak_memory_stats() 22 | one_step(model, data) 23 | aft_cuda_alc = torch.cuda.max_memory_allocated() 24 | torch_activation_occ = aft_cuda_alc - pre_cuda_alc 25 | model.zero_grad(set_to_none=True) 26 | print('normal', torch_activation_occ) 27 | 28 | before = torch.cuda.memory_allocated() 29 | profiling_dict = cuda_memory_profiling(model, data, one_step) 30 | after = torch.cuda.memory_allocated() 31 | print('profiling', profiling_dict) 32 | assert before == after 33 | assert torch_activation_occ == profiling_dict['activation_occ'] 34 | print('Check is ok.') 35 | 36 | 37 | def test_cuda_profiler(): 38 | model_list = ['resnet', 'gpt2_micro'] 39 | for name in model_list: 40 | model_fn, data_fn = TEST_MODELS.get(name) 41 | try_one_model(model_fn, data_fn) 42 | 43 | 44 | if __name__ == '__main__': 45 | test_cuda_profiler() 46 | -------------------------------------------------------------------------------- /test/hook/test_hook.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from elixir.hook import BufferStore, HookParam 7 | from elixir.parameter import FakeTensor 8 | 9 | 10 | def test_hook(): 11 | x = nn.Parameter(torch.randn(4, 4)) 12 | 13 | ori_numel = x.numel() 14 | ori_size = x.size() 15 | ori_stride = 
x.stride() 16 | ori_offset = x.storage_offset() 17 | 18 | fake_data = FakeTensor(x.data) 19 | x.data = fake_data 20 | x.__class__ = HookParam 21 | 22 | assert x.numel() == ori_numel 23 | assert x.size() == ori_size 24 | assert x.stride() == ori_stride 25 | assert x.storage_offset() == ori_offset 26 | 27 | 28 | def test_store(): 29 | buffer = BufferStore(1024, torch.float16) 30 | print(buffer) 31 | 32 | x = torch.randn(4, 128, dtype=torch.float16, device='cuda') 33 | original_ptr_x = x.data_ptr() 34 | copy_x = deepcopy(x) 35 | 36 | y = torch.randn(512, dtype=torch.float16, device='cuda') 37 | original_ptr_y = y.data_ptr() 38 | copy_y = deepcopy(y) 39 | 40 | offset = 0 41 | offset = buffer.insert(x, offset) 42 | assert offset == x.numel() 43 | assert torch.equal(x, copy_x) 44 | 45 | offset = buffer.insert(y, offset) 46 | assert offset == 1024 47 | assert torch.equal(y, copy_y) 48 | 49 | buffer.erase(x) 50 | buffer.erase(y) 51 | assert x.data_ptr() == original_ptr_x 52 | assert y.data_ptr() == original_ptr_y 53 | 54 | 55 | if __name__ == '__main__': 56 | test_store() 57 | -------------------------------------------------------------------------------- /profile/profile_bandwidth.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import colossalai 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from elixir.utils import print_rank_0 8 | 9 | 10 | def profile_function(n_times=20): 11 | l_gb = 10 12 | length = l_gb * 10**9 13 | cpu_x = torch.empty(length, dtype=torch.int8, pin_memory=True) 14 | cuda_x = torch.empty(length, dtype=torch.int8, device='cuda') 15 | 16 | torch.cuda.synchronize() 17 | cpu_to_cuda_start = time() 18 | for _ in range(n_times): 19 | cuda_x.copy_(cpu_x) 20 | torch.cuda.synchronize() 21 | cpu_to_cuda_span = time() - cpu_to_cuda_start 22 | 23 | cuda_to_cpu_start = time() 24 | for _ in range(n_times): 25 | cpu_x.copy_(cuda_x) 26 | torch.cuda.synchronize() 27 | cuda_to_cpu_span = time() - cuda_to_cpu_start 28 | 29 | n_proc = dist.get_world_size() 30 | sum_time = torch.tensor(cpu_to_cuda_span, dtype=torch.double, device='cuda') 31 | dist.all_reduce(sum_time, op=dist.ReduceOp.MAX) 32 | cpu_to_cuda_bandwidth = n_times * n_proc * l_gb / sum_time.item() 33 | 34 | sum_time = torch.tensor(cuda_to_cpu_span, dtype=torch.double, device='cuda') 35 | dist.all_reduce(sum_time, op=dist.ReduceOp.MAX) 36 | cuda_to_cpu_bandwidth = n_times * n_proc * l_gb / sum_time.item() 37 | 38 | print_rank_0( 39 | f'Bandwidth profiling result: cpu -> cuda: {cpu_to_cuda_bandwidth: .3f}, cuda -> cpu: {cuda_to_cpu_bandwidth: .3f}' 40 | ) 41 | 42 | 43 | if __name__ == '__main__': 44 | colossalai.launch_from_torch(config={}) 45 | profile_function() 46 | -------------------------------------------------------------------------------- /test/search/test_simple.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from test.utils import TEST_MODELS 3 | 4 | import torch 5 | 6 | from elixir.cuda import gpu_device 7 | from elixir.search import simple_search 8 | 9 | 10 | def step_fn(model, inp): 11 | model(**inp).backward() 12 | 13 | 14 | def test_simple_search(): 15 | model_fn, data_fn = TEST_MODELS.get('small') 16 | model = model_fn() 17 | data = data_fn() 18 | 19 | sr = simple_search(model, 20 | 1, 21 | split_number=5, 22 | shard_device=gpu_device(), 23 | prefetch=True, 24 | verbose=True, 25 | inp=data, 26 | step_fn=step_fn) 27 | 28 | chunk_plans = 
deepcopy(sr.param_chunk_plans) 29 | private_plan = chunk_plans.pop(0) 30 | assert private_plan.name_list == ['embed.weight'] 31 | assert private_plan.chunk_size == 320 32 | assert private_plan.kwargs.get('shard_device') == gpu_device() 33 | 34 | assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias'] 35 | assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias'] 36 | assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias'] 37 | assert chunk_plans[3].name_list == ['norm2.weight'] 38 | assert chunk_plans[4].name_list == ['norm2.bias'] 39 | 40 | for plan in chunk_plans: 41 | assert plan.chunk_size == 1088 42 | assert plan.kwargs.get('shard_device') == gpu_device() 43 | 44 | 45 | if __name__ == '__main__': 46 | test_simple_search() 47 | -------------------------------------------------------------------------------- /elixir/tracer/memory_tracer/output_shape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Output functions come from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py 5 | def check_cuda_mm(*args): 6 | for x in args: 7 | assert isinstance(x, torch.Tensor) 8 | assert x.device.type == 'cuda' 9 | 10 | 11 | def mm_output(a, b): 12 | assert a.dim() == 2, 'a must be 2D' 13 | assert b.dim() == 2, 'b must be 2D' 14 | N, M1 = a.shape 15 | M2, P = b.shape 16 | assert M1 == M2, 'a and b must have same reduction dim' 17 | return (N, P) 18 | 19 | 20 | def addmm_output(bias, x, y): 21 | return mm_output(x, y) 22 | 23 | 24 | def common_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): 25 | assert batch1.dim() == 3, 'batch1 must be a 3D tensor' 26 | assert batch2.dim() == 3, 'batch2 must be a 3D tensor' 27 | 28 | batch1_sizes = batch1.size() 29 | batch2_sizes = batch2.size() 30 | 31 | bs = batch1_sizes[0] 32 | contraction_size = batch1_sizes[2] 33 | res_rows = batch1_sizes[1] 34 | res_cols = batch2_sizes[2] 35 | output_size = (bs, res_rows, res_cols) 36 | 37 | assert batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size 38 | 39 | if not is_bmm and self_baddbmm is not None: 40 | assert self_baddbmm.dim() == 3, 'self must be a 3D tensor' 41 | assert self_baddbmm.size() == output_size, \ 42 | f'Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}' 43 | 44 | return output_size 45 | 46 | 47 | def bmm_output(mat1, mat2): 48 | return common_baddbmm_bmm(mat1, mat2, True) 49 | -------------------------------------------------------------------------------- /elixir/kernels/gpt_attention.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model 4 | 5 | from .attention import lower_triangular_attention 6 | 7 | 8 | class XGPT2Attention(GPT2Attention): 9 | 10 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 11 | assert self.scale_attn_weights 12 | assert not self.is_cross_attention 13 | assert not self.scale_attn_by_inverse_layer_idx 14 | assert not self.reorder_and_upcast_attn 15 | 16 | b_size, h_size, m_size, k_size = query.size() 17 | 18 | assert self.bias.size(-1) == m_size 19 | query = einops.rearrange(query, 'b h m k -> b m h k') 20 | key = einops.rearrange(key, 'b h m k -> b m h k') 21 | value = einops.rearrange(value, 'b h m k -> b m h k') 22 | 23 | drop_rate = self.attn_dropout.p 24 | output = lower_triangular_attention(query, key, value, p=drop_rate) 
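        # [Added note] `lower_triangular_attention` is assumed here to be a causal
        # (lower-triangular mask) attention kernel that takes query/key/value in
        # (batch, seq, heads, head_dim) layout; that assumption is why q/k/v were
        # rearranged from GPT-2's native (batch, heads, seq, head_dim) above and why
        # the output is rearranged back just below before being returned.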
25 | 26 | ret = einops.rearrange(output, 'b m h k -> b h m k') 27 | 28 | return ret, None 29 | 30 | 31 | class XGPT2Model(GPT2Model): 32 | 33 | def forward(self, *args, **kwargs): 34 | assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg' 35 | attn_mask = kwargs.get('attention_mask') 36 | # assert torch.all(attn_mask == 1), 'only accept no padding mask' 37 | 38 | head_mask = kwargs.get('head_mask', None) 39 | assert head_mask is None, 'head mask should be None' 40 | 41 | output_attn = kwargs.get('output_attentions', False) 42 | if output_attn: 43 | Warning('output_attentions is not supported for XGPT2Model') 44 | 45 | return super().forward(*args, **kwargs) 46 | -------------------------------------------------------------------------------- /example/activation/activation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from test.utils.gpt import GPTLMModel, small_data_fn 3 | from time import time 4 | 5 | import torch 6 | from torch.autograd.profiler_util import _format_memory 7 | from transformers import AutoConfig, OPTConfig, OPTForCausalLM 8 | 9 | from elixir.ctx import MetaContext 10 | from elixir.kernels.attn_wrapper import wrap_attention 11 | from elixir.tracer.memory_tracer import cuda_memory_profiling 12 | from elixir.utils import get_model_size, model_size_formatter 13 | from example.common.models import get_model 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='test activation settings') 18 | parser.add_argument('--model_name', type=str, default='opt-1b', help='test model name') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def profile_max_activation(): 24 | args = parse_args() 25 | with MetaContext(): 26 | model = get_model(args.model_name) 27 | model_size = get_model_size(model) 28 | print(f'model size: {model_size_formatter(model_size)}') 29 | 30 | data = small_data_fn() 31 | 32 | def train_step(model_in, inp_in): 33 | loss = model_in(**inp_in) 34 | loss.backward() 35 | 36 | model = wrap_attention(model) 37 | model.gradient_checkpointing_enable() 38 | 39 | start = time() 40 | 41 | profiling_dict = cuda_memory_profiling(model, data, train_step, dtype=torch.float16) 42 | 43 | torch.cuda.synchronize() 44 | end = time() 45 | 46 | print(f'profile time: {end - start: .2f} sec') 47 | print('memory usage', profiling_dict) 48 | print('activation', _format_memory(profiling_dict['activation_occ'])) 49 | 50 | 51 | if __name__ == '__main__': 52 | profile_max_activation() 53 | -------------------------------------------------------------------------------- /elixir/hook/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.profiler_util import _format_memory 3 | 4 | from elixir.cuda import gpu_device 5 | 6 | 7 | class BufferStore(object): 8 | """A place to store parameters temporarily when computing. 
9 | """ 10 | 11 | def __init__(self, buffer_size: torch.Tensor, buffer_dtype: torch.dtype, device_str: str = 'cuda') -> None: 12 | super().__init__() 13 | self.buffer_size = buffer_size 14 | self.buffer_dtype = buffer_dtype 15 | self.buffer: torch.Tensor = torch.empty(buffer_size, dtype=buffer_dtype, device=device_str) 16 | self.buffer_occ = buffer_size * self.buffer.element_size() 17 | self.record_dict = dict() 18 | 19 | def zeros(self): 20 | torch.zero_(self.buffer) 21 | 22 | def insert(self, t: torch.Tensor, offset: int) -> int: 23 | assert t not in self.record_dict 24 | end = offset + t.numel() 25 | assert end <= self.buffer_size, f'buffer size is {self.buffer_size} but needs {end}' 26 | 27 | new_data = self.buffer[offset:end].view(t.shape) 28 | new_data.copy_(t.data) 29 | 30 | self.record_dict[t] = t.data 31 | t.data = new_data 32 | 33 | return end 34 | 35 | def erase(self, t: torch.Tensor): 36 | assert t in self.record_dict 37 | 38 | new_data = self.record_dict.pop(t) 39 | t.data = new_data 40 | 41 | return 42 | 43 | def empty_like(self, t: torch.Tensor): 44 | return self.buffer[:t.numel()].view(t.shape) 45 | 46 | def empty_1d(self, size: int): 47 | return self.buffer[:size] 48 | 49 | def __repr__(self) -> str: 50 | return f'Buffer(size={self.buffer_size}, dtype={self.buffer_dtype}, memo_occ={_format_memory(self.buffer_occ)})' 51 | -------------------------------------------------------------------------------- /elixir/ctx/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | tensor_creation_methods = dict(tensor=torch.tensor, 4 | sparse_coo_tensor=torch.sparse_coo_tensor, 5 | asarray=torch.asarray, 6 | as_tensor=torch.as_tensor, 7 | as_strided=torch.as_strided, 8 | from_numpy=torch.from_numpy, 9 | from_dlpack=torch.from_dlpack, 10 | frombuffer=torch.frombuffer, 11 | zeros=torch.zeros, 12 | zeros_like=torch.zeros_like, 13 | ones=torch.ones, 14 | ones_like=torch.ones_like, 15 | arange=torch.arange, 16 | range=torch.range, 17 | linspace=torch.linspace, 18 | logspace=torch.logspace, 19 | eye=torch.eye, 20 | empty=torch.empty, 21 | empty_like=torch.empty_like, 22 | empty_strided=torch.empty_strided, 23 | full=torch.full, 24 | full_like=torch.full_like, 25 | quantize_per_tensor=torch.quantize_per_tensor, 26 | quantize_per_channel=torch.quantize_per_channel, 27 | dequantize=torch.dequantize, 28 | complex=torch.complex, 29 | polar=torch.polar, 30 | heaviside=torch.heaviside) 31 | 32 | from .meta_ctx import MetaContext 33 | -------------------------------------------------------------------------------- /test/parameter/hf_models/diffuser.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from test.parameter.hf_models import test_hf_model 3 | 4 | import diffusers 5 | import torch 6 | from torch.testing import assert_close 7 | 8 | from elixir.parameter.temp import transform 9 | from elixir.utils import seed_all 10 | 11 | BATCH_SIZE = 2 12 | SEQ_LENGTH = 5 13 | HEIGHT = 224 14 | WIDTH = 224 15 | IN_CHANNELS = 3 16 | LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7) 17 | TIME_STEP = 3 18 | 19 | 20 | def test_model(builder, kwargs): 21 | data = torch.randn(LATENTS_SHAPE, device='cuda') 22 | torch_model = builder().cuda() 23 | test_model = deepcopy(torch_model) 24 | test_model = transform(test_model) 25 | 26 | torch_model.eval() 27 | test_model.eval() 28 | 29 | torch_out = torch_model(data, **kwargs) 30 | test_out = test_model(data, **kwargs) 31 | 32 
| assert_close(torch_out['sample'], test_out['sample']) 33 | torch.cuda.synchronize() 34 | 35 | 36 | def test_vae(): 37 | seed_all(319) 38 | model_list = [ 39 | diffusers.AutoencoderKL, 40 | diffusers.VQModel, 41 | ] 42 | kwargs = {} 43 | for builder in model_list: 44 | flag = '√' 45 | try: 46 | test_model(builder, kwargs) 47 | except: 48 | flag = 'x' 49 | print(f'{builder.__name__:40s} {flag}') 50 | 51 | 52 | def test_unet(): 53 | seed_all(221) 54 | model_list = [ 55 | diffusers.UNet2DModel, 56 | ] 57 | kwargs = {'timestep': TIME_STEP} 58 | for builder in model_list: 59 | flag = '√' 60 | try: 61 | test_model(builder, kwargs) 62 | except: 63 | flag = 'x' 64 | print(f'{builder.__name__:40s} {flag}') 65 | 66 | 67 | if __name__ == '__main__': 68 | test_vae() 69 | test_unet() 70 | -------------------------------------------------------------------------------- /elixir/tracer/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | aten = torch.ops.aten 5 | 6 | __all__ = [ 7 | 'TorchFactoryMethod', 'TorchOverrideableFactoryMethod', 'TorchNonOverrideableFactoryMethod', 'TensorPropertyMethod', 8 | 'DistCommMethod', 'AliasATen', 'InplaceATen', 'MaybeInplaceAten', 'SameStorageAten' 9 | ] 10 | 11 | TorchOverrideableFactoryMethod = [ 12 | 'empty', 13 | 'eye', 14 | 'full', 15 | 'ones', 16 | 'rand', 17 | 'randn', 18 | 'zeros', 19 | ] 20 | 21 | TorchNonOverrideableFactoryMethod = [ 22 | 'arange', 23 | 'finfo', 24 | 'linspace', 25 | 'logspace', 26 | 'randint', 27 | 'randperm', 28 | 'tensor', 29 | ] 30 | 31 | TorchFactoryMethod = TorchOverrideableFactoryMethod + TorchNonOverrideableFactoryMethod 32 | 33 | TensorPropertyMethod = ['dtype', 'shape', 'device', 'requires_grad', 'grad', 'grad_fn', 'data'] 34 | 35 | DistCommMethod = [ 36 | 'all_gather', 37 | 'all_reduce', 38 | 'all_to_all', 39 | 'broadcast', 40 | 'gather', 41 | 'reduce', 42 | 'reduce_scatter', 43 | 'scatter', 44 | ] 45 | 46 | AliasATen = [ 47 | aten.detach.default, 48 | aten.detach_.default, 49 | aten.t.default, 50 | aten.transpose.int, 51 | aten.view.default, 52 | aten._unsafe_view.default, 53 | aten._reshape_alias.default, 54 | ] 55 | 56 | InplaceATen = [ 57 | aten.add_.Tensor, 58 | aten.add_.Scalar, 59 | aten.sub_.Tensor, 60 | aten.sub_.Scalar, 61 | aten.mul_.Tensor, 62 | aten.mul_.Scalar, 63 | aten.div_.Tensor, 64 | aten.div_.Scalar, 65 | aten.pow_.Tensor, 66 | aten.pow_.Scalar, 67 | ] 68 | 69 | MaybeInplaceAten = [ 70 | aten.diagonal.default, 71 | aten.select.int, 72 | aten.slice.Tensor, 73 | aten.as_strided.default, 74 | ] 75 | 76 | SameStorageAten = AliasATen + InplaceATen + MaybeInplaceAten 77 | -------------------------------------------------------------------------------- /test/kernels/test_ln.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from functools import partial 4 | 5 | import pytest 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | from torch.testing import assert_close 10 | 11 | from elixir.search import simple_search 12 | from elixir.utils import init_distributed 13 | from elixir.wrapper import ElixirModule 14 | 15 | 16 | def exam_fused_layernorm(nproc, group): 17 | torch_model = nn.LayerNorm(2048) 18 | fused_model = deepcopy(torch_model) 19 | 20 | torch_model = torch_model.cuda() 21 | sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True) 22 | fused_model = ElixirModule(fused_model, sr, group, 
use_fused_kernels=True) 23 | 24 | data = torch.randn(2, 2048, device='cuda') 25 | 26 | torch_loss = torch_model(data).sum() 27 | torch_loss.backward() 28 | 29 | fused_loss = fused_model(data).sum() 30 | fused_model.backward(fused_loss) 31 | 32 | assert_close(torch_loss, fused_loss) 33 | 34 | grad_state = fused_model.state_dict(from_param=True) 35 | for name, param in torch_model.named_parameters(): 36 | assert_close(param.grad.cpu(), grad_state[name]) 37 | 38 | 39 | def run_dist(rank, world_size): 40 | os.environ['RANK'] = str(rank) 41 | os.environ['LOCAL_RANK'] = str(rank) 42 | os.environ['WORLD_SIZE'] = str(world_size) 43 | os.environ['MASTER_ADDR'] = '127.0.0.1' 44 | os.environ['MASTER_PORT'] = str(29512) 45 | init_distributed() 46 | exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD) 47 | 48 | 49 | @pytest.mark.dist 50 | @pytest.mark.parametrize('world_size', [1]) 51 | def test_fused_layernorm(world_size): 52 | run_func = partial(run_dist, world_size=world_size) 53 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 54 | 55 | 56 | if __name__ == '__main__': 57 | test_fused_layernorm(world_size=1) 58 | -------------------------------------------------------------------------------- /test/chunk/test_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock 4 | 5 | 6 | def test_block(): 7 | b = PublicBlock(123, torch.float16, 'cuda') 8 | payload_b = b.payload 9 | 10 | assert payload_b.numel() == 123 11 | assert payload_b.dtype == torch.float16 12 | assert payload_b.device.type == 'cuda' 13 | assert payload_b.numel() * payload_b.element_size() == b.memo_occ 14 | 15 | c = PrivateBlock(77, torch.float, 'cpu') 16 | payload_c = c.payload 17 | 18 | assert payload_c.numel() == 77 19 | assert payload_c.dtype == torch.float 20 | assert payload_c.device.type == 'cpu' 21 | assert payload_c.numel() * payload_c.element_size() == c.memo_occ 22 | 23 | print('test_block: ok') 24 | 25 | 26 | def test_memory_pool(): 27 | mp = MemoryPool(device_type='cuda') 28 | private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)] 29 | mp.allocate(public_block_number=4, private_block_list=private_list) 30 | 31 | block0 = mp.get_public_block() 32 | 33 | assert block0 in mp.public_used_blocks 34 | assert mp.public_used_cnt == 1 35 | assert mp.public_free_cnt == 3 36 | 37 | block1 = mp.get_public_block() 38 | 39 | assert block1 in mp.public_used_blocks 40 | assert mp.public_used_cnt == 2 41 | assert mp.public_free_cnt == 2 42 | 43 | mp.free_public_block(block0) 44 | mp.free_public_block(block1) 45 | 46 | assert block0 in mp.public_free_blocks 47 | assert block1 in mp.public_free_blocks 48 | assert mp.public_used_cnt == 0 49 | assert mp.public_free_cnt == 4 50 | 51 | block0 = mp.get_private_block(5, torch.float) 52 | assert block0.numel == 5 53 | assert block0.dtype == torch.float 54 | 55 | print('test_memory_pool: ok') 56 | 57 | 58 | if __name__ == '__main__': 59 | test_block() 60 | test_memory_pool() 61 | -------------------------------------------------------------------------------- /example/common/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import nullcontext 3 | 4 | import psutil 5 | import torch 6 | from torch.autograd.profiler_util import _format_memory 7 | from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler 8 | 9 | 10 
| class DummyProfiler: 11 | 12 | def __init__(self): 13 | self.step_number = 0 14 | 15 | def step(self): 16 | self.step_number += 1 17 | 18 | 19 | # Randomly Generated Data 20 | def fake_gpt_data(batch_size, seq_len, vocab_size): 21 | input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) 22 | attention_mask = torch.ones_like(input_ids) 23 | return input_ids, attention_mask 24 | 25 | 26 | def fake_img_data(batch_size, channel, width, height): 27 | return torch.randn((batch_size, channel, width, height)) 28 | 29 | 30 | def get_tflops(model_numel, batch_size, seq_len, step_time): 31 | return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) 32 | 33 | 34 | def get_mem_info(prefix=''): 35 | cpu_memory = psutil.Process().memory_info().rss 36 | gpu_memory = torch.cuda.memory_allocated() 37 | 38 | return f'{prefix}GPU memory usage: {_format_memory(gpu_memory)}, CPU memory usage: {_format_memory(cpu_memory)}' 39 | 40 | 41 | def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): 42 | if enable_flag: 43 | return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 44 | schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), 45 | on_trace_ready=tensorboard_trace_handler(save_dir), 46 | record_shapes=True, 47 | profile_memory=True) 48 | else: 49 | return nullcontext(DummyProfiler()) 50 | 51 | 52 | def get_time_stamp(): 53 | cur_time = time.strftime('%d-%H:%M', time.localtime()) 54 | return cur_time 55 | -------------------------------------------------------------------------------- /src/simulator.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | int move_count_impl(std::vector> &steps, int n_blocks) { 6 | int n_steps = steps.size(); 7 | std::unordered_map my_map; 8 | std::map, int> next_map; 9 | 10 | for (auto i = n_steps - 1; ~i; --i) { 11 | auto ids = steps.at(i); 12 | for (auto c_id : ids) { 13 | auto iter = my_map.find(c_id); 14 | auto nxt = n_steps; 15 | if (iter != my_map.end()) 16 | nxt = iter->second; 17 | next_map.emplace(std::make_pair(i, c_id), nxt); 18 | my_map[c_id] = i; 19 | } 20 | } 21 | // reuse this map 22 | for (auto iter : my_map) 23 | my_map[iter.first] = 0; 24 | 25 | int cache_size = 0, count = 0; 26 | std::priority_queue> cache; 27 | for (auto i = 0; i < n_steps; ++i) { 28 | auto ids = steps.at(i); 29 | assert(n_blocks >= ids.size()); 30 | 31 | int not_in = 0; 32 | for (auto c_id : ids) 33 | if (my_map[c_id] == 0) 34 | ++not_in; 35 | 36 | while (cache_size + not_in > n_blocks) { 37 | std::pair q_top = cache.top(); 38 | cache.pop(); 39 | assert(q_top.first > i); 40 | assert(my_map[q_top.second] == 1); 41 | my_map[q_top.second] = 0; 42 | --cache_size; 43 | ++count; 44 | } 45 | 46 | for (auto c_id : ids) { 47 | auto iter = next_map.find(std::make_pair(i, c_id)); 48 | cache.push(std::make_pair(iter->second, c_id)); 49 | if (my_map[c_id] == 0) { 50 | my_map[c_id] = 1; 51 | ++cache_size; 52 | } 53 | } 54 | } 55 | return (count + cache_size) << 1; 56 | } 57 | 58 | int move_count(std::vector> &steps, int n_blocks) { 59 | return move_count_impl(steps, n_blocks); 60 | } 61 | 62 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 63 | m.def("move_count", &move_count, "Count the number of moves."); 64 | } 65 | -------------------------------------------------------------------------------- /example/common/elx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as 
dist 3 | from colossalai.nn.optimizer import HybridAdam 4 | from transformers.modeling_utils import no_init_weights 5 | 6 | from elixir.ctx import MetaContext 7 | from elixir.kernels.attn_wrapper import wrap_attention 8 | from elixir.search import optimal_search 9 | from elixir.utils import get_model_size 10 | from elixir.wrapper import ElixirModule, ElixirOptimizer 11 | from example.common.models import get_model 12 | 13 | 14 | def train_step(model, data): 15 | loss = model(**data) 16 | loss.backward() 17 | return loss 18 | 19 | 20 | def train_init(model_name: str, data: dict): 21 | global_group = dist.GroupMember.WORLD 22 | global_size = dist.get_world_size() 23 | 24 | with no_init_weights(): 25 | model = get_model(model_name) 26 | model_size = get_model_size(model) 27 | optimizer = HybridAdam(model.parameters(), lr=1e-3) 28 | 29 | model.gradient_checkpointing_enable() 30 | model = wrap_attention(model) 31 | 32 | sr = optimal_search(model, 33 | global_size, 34 | unified_dtype=torch.float16, 35 | overlap=True, 36 | verbose=True, 37 | inp=data, 38 | step_fn=train_step) 39 | model = ElixirModule(model, sr, global_group, prefetch=True, dtype=torch.float16, use_fused_kernels=True) 40 | optimizer = ElixirOptimizer(model, optimizer, initial_scale=64, hysteresis=1, init_step=True) 41 | 42 | model.train() 43 | 44 | def forward(data): 45 | return model(**data) 46 | 47 | def backward(loss): 48 | optimizer.backward(loss) 49 | 50 | def optim(): 51 | optimizer.step() 52 | optimizer.zero_grad() 53 | 54 | return forward, backward, optim, model_size 55 | 56 | 57 | if __name__ == '__main__': 58 | import colossalai 59 | colossalai.launch_from_torch(config={}) 60 | print(train_init('opt-13b')) 61 | -------------------------------------------------------------------------------- /example/common/ds.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import deepspeed 5 | import torch 6 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam 7 | from transformers.modeling_utils import no_init_weights 8 | 9 | from elixir.kernels.attn_wrapper import wrap_attention 10 | from elixir.utils import get_model_size 11 | from example.common.models import get_model 12 | 13 | 14 | def train_init(batch_size: int, model_name: str, zero_stage: int, cpu_offload: bool): 15 | cur_path = os.path.abspath(os.path.dirname(__file__)) 16 | if zero_stage == 2: 17 | ds_path = os.path.join(cur_path, 'zero2_config.json') 18 | else: 19 | ds_path = os.path.join(cur_path, 'zero3_config.json') 20 | ds_config = json.load(open(ds_path)) 21 | 22 | if not cpu_offload: 23 | zero_optim = ds_config.get('zero_optimization') 24 | zero_optim.pop('offload_optimizer') 25 | if zero_stage == 3: 26 | zero_optim.pop('offload_param') 27 | 28 | total_bs = batch_size * int(os.environ['WORLD_SIZE']) 29 | ds_config['train_batch_size'] = total_bs 30 | ds_config['train_micro_batch_size_per_gpu'] = batch_size 31 | 32 | deepspeed.init_distributed() 33 | if zero_stage == 2: 34 | with no_init_weights(): 35 | model = get_model(model_name) 36 | numel = get_model_size(model) 37 | else: 38 | with deepspeed.zero.Init(config_dict_or_path=ds_config): 39 | model = get_model(model_name) 40 | numel = deepspeed.runtime.zero.partition_parameters.param_count 41 | 42 | if cpu_offload: 43 | optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3) 44 | else: 45 | optimizer = FusedAdam(model.parameters(), lr=1e-3) 46 | model, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, 
config=ds_config) 47 | model.gradient_checkpointing_enable() 48 | model = wrap_attention(model) 49 | model.train() 50 | 51 | def forward(data): 52 | return model(**data) 53 | 54 | def backward(loss): 55 | model.backward(loss) 56 | 57 | def optim(): 58 | model.step() 59 | 60 | return forward, backward, optim, numel 61 | 62 | 63 | if __name__ == '__main__': 64 | train_init(1, 'opt-1b', 3, False) 65 | exit(0) 66 | -------------------------------------------------------------------------------- /profile/profile_optimizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | 4 | import colossalai 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | from colossalai.nn.optimizer import HybridAdam 9 | 10 | from elixir.utils import print_rank_0 11 | 12 | 13 | def set_cpu_maximum_parallelism(): 14 | conf_str = torch.__config__.parallel_info() 15 | inter_str = conf_str.split('hardware_concurrency() : ')[1] 16 | max_concurrency = int(inter_str.split('\n')[0]) 17 | concurrency_per_process = max_concurrency // dist.get_world_size() 18 | os.environ['OMP_NUM_THREADS'] = str(concurrency_per_process) 19 | print(f'environmental variable OMP_NUM_THREADS is set to {max_concurrency}.') 20 | 21 | 22 | class FlattenModel(nn.Module): 23 | 24 | def __init__(self, length: int = 2 * 10**9, device_type: str = 'cuda') -> None: 25 | super().__init__() 26 | self.length = length 27 | self.device_type = device_type 28 | self.weight = nn.Parameter(torch.zeros(length, device=device_type)) 29 | 30 | def set_grad(self): 31 | self.weight.grad = torch.ones(self.length, dtype=torch.float, device=self.device_type) 32 | 33 | 34 | def test_optimizer_update(device_type: str = 'cuda', n_times: int = 50): 35 | l_gb = 1 36 | length = int(l_gb * 10**9) 37 | model = FlattenModel(length, device_type=device_type) 38 | optimizer = HybridAdam(model.parameters(), lr=1e-5) 39 | 40 | sum_time = 0 41 | for _ in range(n_times): 42 | optimizer.zero_grad() 43 | model.set_grad() 44 | 45 | dist.barrier() 46 | torch.cuda.synchronize() 47 | 48 | start = time() 49 | optimizer.step() 50 | 51 | dist.barrier() 52 | torch.cuda.synchronize() 53 | sum_time += time() - start 54 | 55 | n_proc = dist.get_world_size() 56 | sum_time = torch.tensor(sum_time, dtype=torch.double, device='cuda') 57 | dist.all_reduce(sum_time, op=dist.ReduceOp.MAX) 58 | velocity = n_times * n_proc * l_gb / sum_time.item() 59 | 60 | print_rank_0(f'GPU velocity result: {velocity: .2f}') 61 | 62 | 63 | if __name__ == '__main__': 64 | colossalai.launch_from_torch(config={}) 65 | # set_cpu_maximum_parallelism() 66 | test_optimizer_update('cuda') 67 | -------------------------------------------------------------------------------- /test/chunk/test_fetcher.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from functools import partial 4 | from test.chunk.fetcher_utils import hook_transform 5 | from test.utils import TEST_MODELS, to_cuda 6 | 7 | import pytest 8 | import torch 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.testing import assert_close 12 | 13 | from elixir.chunk import ChunkGroup 14 | from elixir.utils import init_distributed, seed_all 15 | 16 | 17 | def check_gradient(ddp_model, my_model, cg: ChunkGroup): 18 | for chunk in cg.fused_chunks: 19 | cg.access_chunk(chunk) 20 | 21 | for (name, p0), p1 in zip(ddp_model.named_parameters(), 
my_model.parameters()): 22 | torch.cuda.synchronize() 23 | print(f'checking parameter {name}') 24 | assert_close(p0.grad.data, p1.data) 25 | 26 | 27 | def exam_chunk_fetcher(nproc, group): 28 | model_fn, data_fn = TEST_MODELS.get('resnet') 29 | torch_model = model_fn().cuda() 30 | test_model = copy.deepcopy(torch_model) 31 | 32 | rank = dist.get_rank(group) 33 | # get different data 34 | seed_all(1001 + rank) 35 | data = to_cuda(data_fn()) 36 | 37 | seed_all(1001, cuda_deterministic=True) 38 | ddp_model = DDP(torch_model) 39 | ddp_loss = ddp_model(**data) 40 | ddp_loss.backward() 41 | 42 | hook_model, cg = hook_transform(test_model, group) 43 | my_loss = hook_model(**data) 44 | my_loss.backward() 45 | 46 | assert_close(ddp_loss, my_loss) 47 | check_gradient(ddp_model, hook_model, cg) 48 | print('private chunk fetcher is ok') 49 | 50 | 51 | def run_dist(rank, world_size): 52 | os.environ['RANK'] = str(rank) 53 | os.environ['LOCAL_RANK'] = str(rank) 54 | os.environ['WORLD_SIZE'] = str(world_size) 55 | os.environ['MASTER_ADDR'] = '127.0.0.1' 56 | os.environ['MASTER_PORT'] = str(29512) 57 | init_distributed() 58 | exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD) 59 | 60 | 61 | @pytest.mark.dist 62 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 63 | def test_chunk_fetcher(world_size): 64 | run_func = partial(run_dist, world_size=world_size) 65 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 66 | 67 | 68 | if __name__ == '__main__': 69 | test_chunk_fetcher(world_size=2) 70 | -------------------------------------------------------------------------------- /elixir/tracer/param_tracer/fx_order.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.fx import GraphModule, Node, symbolic_trace 6 | 7 | from elixir.tracer.utils import meta_copy 8 | 9 | 10 | def generate_fx_order(model: nn.Module) -> List[Dict[str, nn.Parameter]]: 11 | fxf_name_mark = '_fxf_name' 12 | fxf_param_mark = '_fxf_param' 13 | 14 | def tensor_trans(t): 15 | meta_t = t.data.to('meta') 16 | if isinstance(t, nn.Parameter): 17 | meta_t = nn.Parameter(meta_t) 18 | return meta_t 19 | 20 | meta_model = meta_copy(model, tensor_trans) 21 | 22 | # attach names for parameters 23 | for name, param in meta_model.named_parameters(): 24 | setattr(param, fxf_name_mark, name) 25 | 26 | fx_forward_order: List[Dict[str, nn.Parameter]] = list() 27 | 28 | gm: GraphModule = symbolic_trace(meta_model) 29 | 30 | for node in gm.graph.nodes: 31 | if node.op in ('output', 'placeholder'): 32 | continue 33 | 34 | step_dict = None 35 | if node.op == 'get_attr': 36 | maybe_param = getattr(gm, node.target) 37 | # mark this node as a parameter 38 | if maybe_param is not None: 39 | setattr(node, fxf_param_mark, maybe_param) 40 | continue 41 | elif node.op == 'call_module': 42 | target_module = gm.get_submodule(node.target) 43 | step_dict = dict() 44 | # collect all parameters in the module 45 | for maybe_param in target_module.parameters(): 46 | if maybe_param is not None: 47 | param_name = getattr(maybe_param, fxf_name_mark) 48 | step_dict[param_name] = maybe_param 49 | elif node.op in ('call_function', 'call_method'): 50 | step_dict = dict() 51 | for pre in node.args: 52 | if hasattr(pre, fxf_param_mark): 53 | param = getattr(pre, fxf_param_mark) 54 | param_name = getattr(param, fxf_name_mark) 55 | step_dict[param_name] = param 56 | else: 57 | raise RuntimeError(f'Unsupported node op {node.op}!') 58 | 
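        # [Added note] Only steps that actually touch parameters are recorded below:
        # 'get_attr' nodes are skipped after tagging the fetched parameter onto the fx
        # node, so the 'call_function' / 'call_method' node that later consumes that
        # node is the one that claims the parameter in its step dictionary.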
59 | if step_dict is not None and len(step_dict) > 0: 60 | fx_forward_order.append(step_dict) 61 | 62 | return fx_forward_order 63 | -------------------------------------------------------------------------------- /test/parameter/test_timm.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import timm.models as tmm 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.testing import assert_close 7 | 8 | from elixir.parameter.temp import transform 9 | from elixir.utils import seed_all 10 | 11 | 12 | def assert_tuple_close(ta, tb): 13 | if not isinstance(ta, tuple): 14 | ta = (ta,) 15 | if not isinstance(tb, tuple): 16 | tb = (tb,) 17 | 18 | for (a, b) in zip(ta, tb): 19 | assert_close(a, b) 20 | 21 | 22 | def test_model(builder, kwargs): 23 | data = torch.randn(2, 3, 224, 224, device='cuda') 24 | torch_model = builder(**kwargs).cuda() 25 | test_model = deepcopy(torch_model) 26 | test_model = transform(test_model) 27 | 28 | torch_model.eval() 29 | test_model.eval() 30 | 31 | torch_out = torch_model(data) 32 | test_out = test_model(data) 33 | 34 | assert_tuple_close(torch_out, test_out) 35 | torch.cuda.synchronize() 36 | 37 | 38 | def test_timm_models(): 39 | seed_all(1001, cuda_deterministic=True) 40 | model_list = [ 41 | tmm.beit_base_patch16_224, 42 | tmm.beitv2_base_patch16_224, 43 | tmm.cait_s24_224, 44 | tmm.coat_lite_mini, 45 | tmm.convit_base, 46 | tmm.deit3_base_patch16_224, 47 | tmm.dm_nfnet_f0, 48 | tmm.eca_nfnet_l0, 49 | tmm.efficientformer_l1, 50 | tmm.ese_vovnet19b_dw, 51 | tmm.gmixer_12_224, 52 | tmm.gmlp_b16_224, 53 | tmm.hardcorenas_a, 54 | tmm.hrnet_w18_small, 55 | tmm.inception_v3, 56 | tmm.mixer_b16_224, 57 | tmm.nf_ecaresnet101, 58 | tmm.nf_regnet_b0, 59 | # tmm.pit_b_224, # pretrained only 60 | tmm.regnetv_040, 61 | tmm.skresnet18, 62 | tmm.swin_base_patch4_window7_224, 63 | tmm.tnt_b_patch16_224, 64 | tmm.vgg11, 65 | tmm.vit_base_patch16_18x2_224, 66 | tmm.wide_resnet50_2, 67 | ] 68 | 69 | for builder in model_list: 70 | kwargs = {} 71 | 72 | flag = '√' 73 | try: 74 | test_model(builder, kwargs) 75 | except: 76 | flag = 'x' 77 | print(f'{builder.__name__:40s} {flag}') 78 | 79 | 80 | if __name__ == '__main__': 81 | test_timm_models() 82 | # torch.Tensor.add_ = torch.Tensor.add 83 | # test_model(tm.resnest.resnest50d, {}) 84 | -------------------------------------------------------------------------------- /elixir/hook/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from elixir.chunk import ChunkFetcher 4 | 5 | from .storage import BufferStore 6 | 7 | 8 | def prefwd_postbwd_function(fetcher: ChunkFetcher, store: BufferStore): 9 | 10 | class PreFwdPostBwd(torch.autograd.Function): 11 | 12 | @staticmethod 13 | def forward(ctx, params, *args): 14 | with torch._C.DisableTorchFunction(): 15 | ctx.params = params 16 | chunks = fetcher.trans_to_compute(params) 17 | fetcher.fetch_chunks(chunks) 18 | 19 | offset = 0 20 | for p in ctx.params: 21 | if not fetcher.is_in_fused(p): 22 | # we should add parameters to buffer 23 | # because their blocks may be changed 24 | offset = store.insert(p, offset) 25 | 26 | return args 27 | 28 | @staticmethod 29 | def backward(ctx, *grads): 30 | with torch._C.DisableTorchFunction(): 31 | fetcher.trans_to_hold(ctx.params, phase='b') 32 | 33 | for p in ctx.params: 34 | if not fetcher.is_in_fused(p): 35 | store.erase(p) 36 | 37 | return (None, *grads) 38 | 39 | return PreFwdPostBwd.apply 40 | 41 
| 42 | def postfwd_prebwd_function(fetcher: ChunkFetcher, store: BufferStore): 43 | 44 | class PostFwdPreBwd(torch.autograd.Function): 45 | 46 | @staticmethod 47 | def forward(ctx, params, *args): 48 | with torch._C.DisableTorchFunction(): 49 | ctx.params = params 50 | 51 | fetcher.trans_to_hold(ctx.params, phase='f') 52 | for p in ctx.params: 53 | if not fetcher.is_in_fused(p): 54 | store.erase(p) 55 | 56 | return args 57 | 58 | @staticmethod 59 | def backward(ctx, *grads): 60 | with torch._C.DisableTorchFunction(): 61 | chunks = fetcher.trans_to_compute(ctx.params) 62 | fetcher.fetch_chunks(chunks) 63 | 64 | offset = 0 65 | for p in ctx.params: 66 | if not fetcher.is_in_fused(p): 67 | # we should add parameters to buffer 68 | # because their blocks may be changed 69 | offset = store.insert(p, offset) 70 | 71 | return (None, *grads) 72 | 73 | return PostFwdPreBwd.apply 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # vscode configurations 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /test/parameter/hf_models/gpt.py: -------------------------------------------------------------------------------- 1 | from test.parameter.hf_models import test_hf_model 2 | 3 | import torch 4 | import transformers 5 | 6 | BS = 2 7 | SL = 16 8 | 9 | gpt_config = transformers.GPT2Config(n_positions=64, n_embd=128, n_layer=2, n_head=4, pad_token_id=0) 10 | opt_config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, pad_token_id=0) 11 | t5_config = transformers.T5Config(d_model=128, d_kv=32, d_ff=256, num_layers=2, num_heads=4, pad_token_id=0) 12 | 13 | 14 | def data_gpt(): 15 | input_ids = torch.zeros((BS, SL), dtype=torch.int64) 16 | token_type_ids = torch.zeros((BS, SL), dtype=torch.int64) 17 | attention_mask = torch.zeros((BS, SL), dtype=torch.int64) 18 | return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 19 | 20 | 21 | def data_opt(): 22 | input_ids = torch.zeros((BS, SL), dtype=torch.int64) 23 | attention_mask = torch.zeros((BS, SL), dtype=torch.int64) 24 | return dict(input_ids=input_ids, attention_mask=attention_mask) 25 | 26 | 27 | def data_t5(): 28 | input_ids = torch.zeros((BS, SL), dtype=torch.int64) 29 | decoder_input_ids = torch.zeros((BS, SL), dtype=torch.int64) 30 | return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 31 | 32 | 33 | def data_t5_encoder(): 34 | input_ids = torch.zeros((BS, SL), dtype=torch.int64) 35 | return dict(input_ids=input_ids) 36 | 37 | 38 | model_dict = { 39 | transformers.GPT2Model: dict(config=gpt_config, data=data_gpt), 40 | transformers.GPT2LMHeadModel: dict(config=gpt_config, data=data_gpt), 41 | transformers.GPT2DoubleHeadsModel: dict(config=gpt_config, data=data_gpt), 42 | transformers.GPT2ForTokenClassification: dict(config=gpt_config, data=data_gpt), 43 | transformers.GPT2ForSequenceClassification: dict(config=gpt_config, data=data_gpt), 44 | transformers.OPTModel: dict(config=opt_config, data=data_opt), 45 | transformers.OPTForCausalLM: dict(config=opt_config, data=data_opt), 46 | transformers.T5EncoderModel: dict(config=t5_config, data=data_t5_encoder), 47 | transformers.T5Model: dict(config=t5_config, data=data_t5), 48 | transformers.T5ForConditionalGeneration: dict(config=t5_config, data=data_t5), 49 | } 50 | 51 | 52 | def test_gpt(): 53 | for builder, config_dict in model_dict.items(): 54 | kwargs = dict(config=config_dict['config']) 55 | data_fn = config_dict['data'] 56 | 57 | flag = '√' 58 | try: 59 | test_hf_model(builder, kwargs, data_fn) 60 | except: 61 | flag = 'x' 62 | print(f'{builder.__name__:40s} {flag}') 63 | 64 | 65 | if __name__ == '__main__': 66 | test_gpt() 67 | -------------------------------------------------------------------------------- /test/wrapper/test_optimizer.py: -------------------------------------------------------------------------------- 1 | import 
copy 2 | import os 3 | from functools import partial 4 | from test.utils import TEST_MODELS, allclose, assert_dict_values, to_cuda 5 | 6 | import pytest 7 | import torch 8 | import torch.distributed as dist 9 | from colossalai.nn.optimizer import HybridAdam 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.testing import assert_close 12 | 13 | from elixir.cuda import gpu_device 14 | from elixir.search import simple_search 15 | from elixir.utils import init_distributed, seed_all 16 | from elixir.wrapper import ElixirModule, ElixirOptimizer 17 | 18 | 19 | def exam_optimizer_one_model(model_fn, data_fn, nproc, group, exam_seed=2261): 20 | ddp_model = model_fn().cuda() 21 | test_model = copy.deepcopy(ddp_model) 22 | 23 | ddp_model = DDP(ddp_model) 24 | ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0) 25 | 26 | test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0) 27 | sr = simple_search(test_model, nproc, shard_device=gpu_device()) 28 | test_model = ElixirModule(test_model, sr, group) 29 | test_optim = ElixirOptimizer(test_model, test_optim) 30 | 31 | # get different data 32 | seed_all(exam_seed + dist.get_rank(group)) 33 | data = to_cuda(data_fn()) 34 | 35 | seed_all(exam_seed, cuda_deterministic=True) 36 | ddp_optim.zero_grad() 37 | ddp_loss = ddp_model(**data) 38 | ddp_loss.backward() 39 | ddp_optim.step() 40 | 41 | test_optim.zero_grad() 42 | test_loss = test_model(**data) 43 | test_optim.backward(test_loss) 44 | test_optim.step() 45 | 46 | assert_close(ddp_loss, test_loss) 47 | torch_st = ddp_model.module.state_dict() 48 | test_st = test_model.state_dict() 49 | assert_dict_values(torch_st, test_st, fn=partial(allclose, rtol=2e-6, atol=2e-5)) 50 | 51 | 52 | def exam_optimizer_in_models(nproc, group): 53 | model_fn, data_fn = TEST_MODELS.get('resnet') 54 | exam_optimizer_one_model(model_fn, data_fn, nproc, group) 55 | 56 | 57 | def run_dist(rank, world_size): 58 | os.environ['RANK'] = str(rank) 59 | os.environ['LOCAL_RANK'] = str(rank) 60 | os.environ['WORLD_SIZE'] = str(world_size) 61 | os.environ['MASTER_ADDR'] = '127.0.0.1' 62 | os.environ['MASTER_PORT'] = str(29512) 63 | init_distributed() 64 | exam_optimizer_in_models(nproc=world_size, group=dist.GroupMember.WORLD) 65 | 66 | 67 | @pytest.mark.dist 68 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 69 | def test_elixir_optimizer(world_size): 70 | run_func = partial(run_dist, world_size=world_size) 71 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 72 | 73 | 74 | if __name__ == '__main__': 75 | test_elixir_optimizer(world_size=4) 76 | -------------------------------------------------------------------------------- /test/chunk/fetcher_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | 7 | from elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState 8 | from elixir.chunk.scheduler import FIFOScheduler 9 | from elixir.hook import BufferStore, HookParam 10 | from elixir.parameter import OutplaceTensor 11 | 12 | 13 | def to_divide(a: int, b: int): 14 | return a + (-a % b) 15 | 16 | 17 | def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher): 18 | empty_grad = torch.empty_like(grad) 19 | empty_grad.storage().resize_(0) 20 | 21 | with torch._C.DisableTorchFunction(): 22 | chunk = fetcher.get_one_chunk(param) 23 | if 
chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD: 24 | raise RuntimeError() 25 | fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) 26 | chunk.copy_tensor_to_chunk_slice(param, grad) 27 | fetcher.reduce_chunk(chunk) 28 | 29 | return empty_grad 30 | 31 | 32 | def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo): 33 | pg_size = dist.get_world_size(process_group) 34 | 35 | private_list = list() 36 | for param in model.parameters(): 37 | block_size = to_divide(param.numel(), pg_size) 38 | private_list.append(BlockRequire(block_size, param.dtype)) 39 | 40 | mp = MemoryPool('cuda') 41 | mp.allocate(private_block_list=private_list) 42 | cg = ChunkGroup(rcache=mp) 43 | # allocate chunk group 44 | fused_config = dict(rcache_fused=True) 45 | for param in model.parameters(): 46 | cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config) 47 | # initialize chunk fetcher 48 | scheduler = FIFOScheduler() 49 | fetcher = ChunkFetcher(scheduler, cg) 50 | buffer = BufferStore(0, torch.float32) 51 | # register fetcher and gradient handler 52 | HookParam.attach_fetcher(fetcher, buffer) 53 | for param in model.parameters(): 54 | param.register_hook(partial(grad_handler, param=param, fetcher=fetcher)) 55 | param.__class__ = HookParam 56 | # set inplace to False for all modules 57 | for module in model.modules(): 58 | if hasattr(module, 'inplace'): 59 | module.inplace = False 60 | 61 | def transform_input(self_module, inputs): 62 | fetcher.reset() 63 | input_list = list() 64 | for t in inputs: 65 | if isinstance(t, torch.Tensor): 66 | t = OutplaceTensor(t) 67 | input_list.append(t) 68 | return tuple(input_list) 69 | 70 | model.register_forward_pre_hook(transform_input) 71 | 72 | return model, cg 73 | -------------------------------------------------------------------------------- /test/utils/gpt.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from test.utils.registry import TEST_MODELS 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers import GPT2Config, GPT2LMHeadModel 7 | 8 | MICRO_VS = 128 9 | MICRO_BS = 4 10 | MICRO_SL = 64 11 | 12 | MACRO_VS = 50257 13 | MACRO_BS = 2 14 | MACRO_SL = 1024 15 | 16 | 17 | def micro_data_fn(): 18 | input_ids = torch.randint(low=0, high=MICRO_VS, size=(MICRO_BS, MICRO_SL)) 19 | attn_mask = torch.ones_like(input_ids) 20 | return dict(input_ids=input_ids, attention_mask=attn_mask) 21 | 22 | 23 | def small_data_fn(): 24 | input_ids = torch.randint(low=0, high=MACRO_VS, size=(MACRO_BS, MACRO_SL)) 25 | attn_mask = torch.ones_like(input_ids) 26 | return dict(input_ids=input_ids, attention_mask=attn_mask) 27 | 28 | 29 | class GPTLMLoss(nn.Module): 30 | 31 | def __init__(self): 32 | super().__init__() 33 | self.loss_fn = nn.CrossEntropyLoss() 34 | 35 | def forward(self, logits, labels): 36 | shift_logits = logits[..., :-1, :].contiguous() 37 | shift_labels = labels[..., 1:].contiguous() 38 | # Flatten the tokens 39 | return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 40 | 41 | 42 | class GPTLMModel(nn.Module): 43 | 44 | def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): 45 | super().__init__() 46 | self.enable_gc = False 47 | self.config = GPT2Config( 48 | # pre-commit: do not rearrange 49 | n_embd=hidden_size, 50 | n_layer=num_layers, 51 | n_head=num_attention_heads, 52 | n_positions=max_seq_len, 53 
| n_ctx=max_seq_len, 54 | vocab_size=vocab_size, 55 | resid_pdrop=0.0, 56 | embd_pdrop=0.0, 57 | attn_pdrop=0.0) 58 | self.module = GPT2LMHeadModel(config=self.config) 59 | self.criterion = GPTLMLoss() 60 | 61 | def gradient_checkpointing_enable(self): 62 | self.module.gradient_checkpointing_enable() 63 | self.enable_gc = True 64 | 65 | def forward(self, input_ids, attention_mask): 66 | # Only return lm_logits 67 | output = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=(not self.enable_gc))[0] 68 | loss = self.criterion(output, input_ids) 69 | return loss 70 | 71 | 72 | gpt2_micro = partial(GPTLMModel, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128) 73 | gpt2_small = GPTLMModel 74 | gpt2_base = partial(GPTLMModel, hidden_size=1024, num_layers=24, num_attention_heads=16) 75 | 76 | TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn) 77 | TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn) 78 | TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn) 79 | -------------------------------------------------------------------------------- /elixir/tracer/memory_tracer/cuda_profiler.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Callable, Dict, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils._pytree import tree_map 7 | 8 | from elixir.tracer.utils import get_cuda_allocated, meta_copy, model_memory_figure 9 | from elixir.utils import print_rank_0 10 | 11 | from .memory_tensor import MTensor 12 | 13 | 14 | def grad_cleaner(grad): 15 | empty_grad = torch.empty_like(grad.elem) 16 | grad.elem = None 17 | empty_grad.storage().resize_(0) 18 | return empty_grad 19 | 20 | 21 | def cuda_memory_profiling(model: nn.Module, inp: Dict, step_fn: Callable, dtype=torch.float): 22 | assert isinstance(inp, dict), 'the example input should be a dictionary' 23 | print_rank_0(f'You are profiling cuda memory with dtype `{dtype}`') 24 | 25 | def tensor_trans(t: torch.Tensor): 26 | # set dtype for tensors 27 | meta_dtype = dtype if t.is_floating_point() else t.dtype 28 | meta_t = torch.empty_like(t.data, device='meta', dtype=meta_dtype) 29 | # pack parameters 30 | if isinstance(t, nn.Parameter): 31 | meta_t = nn.Parameter(meta_t) 32 | return meta_t 33 | 34 | # first, transform the model into one dtype 35 | model = meta_copy(model, tensor_trans) 36 | # get the memory firgure of the model 37 | memo_dict = model_memory_figure(model) 38 | # initialize a empty pool for parameters 39 | pool = torch.zeros(memo_dict['param_max_numel'], device='cuda', dtype=dtype) 40 | 41 | def tensor_to_cuda(t): 42 | if isinstance(t, nn.Parameter): 43 | fake_data = pool[:t.numel()].view(t.shape) 44 | return nn.Parameter(fake_data) 45 | else: 46 | fake_data = torch.zeros(t.shape, device='cuda', dtype=t.dtype) 47 | return fake_data 48 | 49 | # make all parameters in CUDA and point to a same address 50 | model = meta_copy(model, tensor_to_cuda) 51 | # add hooks to clean gradients 52 | for param in model.parameters(): 53 | param.register_hook(grad_cleaner) 54 | 55 | def input_trans(t): 56 | if isinstance(t, torch.Tensor): 57 | cuda_dtype = dtype if t.is_floating_point() else t.dtype 58 | cuda_t = t.data.clone() 59 | cuda_t = cuda_t.to(dtype=cuda_dtype, device='cuda') 60 | cuda_t.requires_grad = t.requires_grad 61 | return MTensor(cuda_t) 62 | return t 63 | 64 | inp = tree_map(input_trans, inp) 65 | # reset all collected peak memory states 66 | 
MTensor.reset_peak_memory() 67 | before_cuda_alc = get_cuda_allocated() 68 | 69 | step_fn(model, inp) 70 | 71 | after_cuda_alc = MTensor.current_peak_memory() 72 | activation_occ = after_cuda_alc - before_cuda_alc 73 | 74 | return dict(param_occ=memo_dict['param_occ'], 75 | buffer_occ=memo_dict['buffer_occ'], 76 | grad_occ=memo_dict['param_occ'], 77 | activation_occ=activation_occ) 78 | -------------------------------------------------------------------------------- /test/wrapper/test_prefetch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from functools import partial 4 | from test.utils import TEST_MODELS, to_cuda 5 | 6 | import pytest 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.testing import assert_close 12 | 13 | from elixir.cuda import gpu_device 14 | from elixir.search import simple_search 15 | from elixir.utils import init_distributed, seed_all 16 | from elixir.wrapper import ElixirModule 17 | 18 | 19 | def check_gradient(ddp_model: nn.Module, test_model: ElixirModule): 20 | grad_state = test_model.state_dict(from_param=True) 21 | for name, param in ddp_model.named_parameters(): 22 | assert_close(param.grad.cpu(), grad_state[name]) 23 | 24 | 25 | def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2263): 26 | 27 | def one_step(local_model, local_input): 28 | loss = local_model(**local_input) 29 | loss.backward() 30 | return loss 31 | 32 | ddp_model = model_fn().cuda() 33 | test_model = copy.deepcopy(ddp_model) 34 | 35 | # get different data 36 | seed_all(exam_seed + dist.get_rank(group)) 37 | data = to_cuda(data_fn()) 38 | 39 | # wrap as DDP model 40 | ddp_model = DDP(ddp_model) 41 | # search how to initialize chunks 42 | sr = simple_search(test_model, 43 | nproc, 44 | shard_device=gpu_device(), 45 | prefetch=True, 46 | verbose=True, 47 | inp=data, 48 | step_fn=one_step) 49 | test_model = ElixirModule(test_model, sr, group, prefetch=True) 50 | 51 | seed_all(exam_seed, cuda_deterministic=True) 52 | ddp_loss = one_step(ddp_model, data) 53 | 54 | with torch.no_grad(): 55 | test_loss = test_model(**data) 56 | assert_close(ddp_loss, test_loss) 57 | 58 | test_loss = test_model(**data) 59 | test_model.backward(test_loss) 60 | assert_close(ddp_loss, test_loss) 61 | check_gradient(ddp_model.module, test_model) 62 | 63 | 64 | def exam_modules_fwd_bwd(nproc, group): 65 | model_fn, data_fn = TEST_MODELS.get('resnet') 66 | exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group) 67 | 68 | 69 | def run_dist(rank, world_size): 70 | os.environ['RANK'] = str(rank) 71 | os.environ['LOCAL_RANK'] = str(rank) 72 | os.environ['WORLD_SIZE'] = str(world_size) 73 | os.environ['MASTER_ADDR'] = '127.0.0.1' 74 | os.environ['MASTER_PORT'] = str(29512) 75 | init_distributed() 76 | exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD) 77 | 78 | 79 | @pytest.mark.dist 80 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 81 | def test_module_prefetch(world_size): 82 | run_func = partial(run_dist, world_size=world_size) 83 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 84 | 85 | 86 | if __name__ == '__main__': 87 | test_module_prefetch(world_size=2) 88 | -------------------------------------------------------------------------------- /elixir/chunk/scheduler/prefetch.py: -------------------------------------------------------------------------------- 1 | from collections import 
defaultdict 2 | from typing import Iterable, List, Optional 3 | 4 | import torch 5 | from sortedcontainers import SortedSet 6 | 7 | from elixir.chunk.core import Chunk 8 | 9 | from .base import Chunk, ChunkScheduler 10 | 11 | 12 | class PrefetchScheduler(ChunkScheduler): 13 | 14 | def __init__(self, chunk_called_per_step: List[Iterable[Chunk]]) -> None: 15 | super().__init__() 16 | self.chunk_mapping = None 17 | self.evict_set = None 18 | self.search_step = -1 19 | 20 | self.chunks_per_step = chunk_called_per_step 21 | self.total_steps = len(chunk_called_per_step) 22 | self.next_step_dict = defaultdict(list) 23 | # initialize the next_step dictionary 24 | for i, c_list in enumerate(chunk_called_per_step): 25 | for c in c_list: 26 | self.next_step_dict[c].append(i) 27 | 28 | def _get_next_step(self, chunk: Chunk): 29 | step_list = self.next_step_dict[chunk] 30 | for i in step_list: 31 | if i > self.current_step: 32 | return i 33 | return self.total_steps 34 | 35 | def reset(self) -> None: 36 | super().reset() 37 | self.chunk_mapping = dict() 38 | self.evict_set = SortedSet() 39 | self.search_step = -1 40 | 41 | def clear(self) -> None: 42 | super().clear() 43 | if torch.is_grad_enabled(): 44 | assert self.current_step == self.total_steps - 1 45 | self.chunk_mapping = None 46 | self.evict_set = None 47 | self.search_step = -1 48 | 49 | def top(self) -> Optional[Chunk]: 50 | if not super().top(): 51 | return None 52 | next_step, chunk = self.evict_set[-1] 53 | return chunk 54 | 55 | def add(self, chunk: Chunk) -> bool: 56 | if not super().add(chunk): 57 | return False 58 | value = (self._get_next_step(chunk), chunk) 59 | self.chunk_mapping[chunk] = value 60 | self.evict_set.add(value) 61 | return True 62 | 63 | def remove(self, chunk: Chunk) -> bool: 64 | if not super().remove(chunk): 65 | return False 66 | value = self.chunk_mapping[chunk] 67 | self.evict_set.remove(value) 68 | self.chunk_mapping.pop(chunk) 69 | return True 70 | 71 | def step(self, *args, **kwags): 72 | super().step(*args, **kwags) 73 | if self.current_step >= self.total_steps: 74 | raise RuntimeError('exceed simulated steps, please modify your profiling `step_fn`') 75 | 76 | def get_next_chunk(self, chunks: List[Chunk]): 77 | self.search_step = max(self.search_step, self.current_step + 1) 78 | while self.search_step < self.total_steps: 79 | c_list = self.chunks_per_step[self.search_step] 80 | for c in c_list: 81 | if c not in chunks: 82 | return c 83 | self.search_step += 1 84 | return None 85 | -------------------------------------------------------------------------------- /elixir/tracer/memory_tracer/memory_tensor.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Iterator 3 | 4 | import torch 5 | from torch.utils._pytree import tree_map 6 | 7 | from elixir.tracer.utils import get_cuda_max_allocated 8 | 9 | from .op_cache import wrapped_mm_ops 10 | 11 | aten = torch.ops.aten 12 | 13 | mm_ops_list = [aten.mm.default, aten.addmm.default, aten.bmm.default, aten.addbmm.default, aten.baddbmm.default] 14 | 15 | 16 | @contextlib.contextmanager 17 | def no_dispatch() -> Iterator[None]: 18 | guard = torch._C._DisableTorchDispatch() 19 | try: 20 | yield 21 | finally: 22 | del guard 23 | 24 | 25 | def normalize_tuple(x): 26 | if not isinstance(x, tuple): 27 | return (x,) 28 | return x 29 | 30 | 31 | class MTensor(torch.Tensor): 32 | elem: torch.Tensor 33 | 34 | __slots__ = ['elem'] 35 | 36 | peak_memory_allocated: int = 0 37 | 38 | @staticmethod 39 | def 
reset_peak_memory(): 40 | torch.cuda.reset_peak_memory_stats() 41 | MTensor.peak_memory_allocated = 0 42 | 43 | @staticmethod 44 | def update_peak_memory(new_peak): 45 | MTensor.peak_memory_allocated = max(MTensor.peak_memory_allocated, new_peak) 46 | 47 | @staticmethod 48 | def current_peak_memory(): 49 | cur_peak = get_cuda_max_allocated() 50 | return max(MTensor.peak_memory_allocated, cur_peak) 51 | 52 | @staticmethod 53 | def __new__(cls, elem, *args, **kwargs): 54 | r = torch.Tensor._make_wrapper_subclass( 55 | cls, 56 | elem.size(), 57 | strides=elem.stride(), 58 | storage_offset=elem.storage_offset(), 59 | # TODO: clone strides and storage aliasing 60 | dtype=elem.dtype, 61 | layout=elem.layout, 62 | device=elem.device, 63 | requires_grad=elem.requires_grad) 64 | r.elem = elem 65 | return r 66 | 67 | def __repr__(self): 68 | return f'MTensor({self.elem})' 69 | 70 | @classmethod 71 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 72 | 73 | def print_tensor(x): 74 | if isinstance(x, torch.Tensor): 75 | print(x.shape) 76 | 77 | # tree_map(print_tensor, args) 78 | # tree_map(print_tensor, kwargs) 79 | 80 | def unwrap(x): 81 | return x.elem if isinstance(x, MTensor) else x 82 | 83 | def wrap(x): 84 | return MTensor(x) if isinstance(x, torch.Tensor) else x 85 | 86 | if func in mm_ops_list: 87 | res, pre_max = wrapped_mm_ops(func, *tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 88 | MTensor.update_peak_memory(pre_max) 89 | else: 90 | with no_dispatch(): 91 | res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 92 | 93 | outs = normalize_tuple(res) 94 | res = tree_map(wrap, outs) 95 | 96 | if len(res) == 1: 97 | return res[0] 98 | else: 99 | return res 100 | -------------------------------------------------------------------------------- /test/parameter/hf_models/albert.py: -------------------------------------------------------------------------------- 1 | from test.parameter.hf_models import test_hf_model 2 | 3 | import torch 4 | import transformers 5 | 6 | BS = 2 7 | SL = 16 8 | 9 | one_sentence_config = transformers.AlbertConfig(embedding_size=128, 10 | hidden_size=128, 11 | num_hidden_layers=2, 12 | num_attention_heads=4, 13 | intermediate_size=256) 14 | 15 | multi_sentence_config = transformers.AlbertConfig(hidden_size=128, 16 | num_hidden_layers=2, 17 | num_attention_heads=4, 18 | intermediate_size=256) 19 | 20 | 21 | def data_one_sentence(): 22 | input_ids = torch.zeros((BS, SL), dtype=torch.int64) 23 | token_type_ids = torch.zeros((BS, SL), dtype=torch.int64) 24 | attention_mask = torch.zeros((BS, SL), dtype=torch.int64) 25 | return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 26 | 27 | 28 | def data_qa(): 29 | tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased') 30 | question, text = 'Who was Jim Henson?', 'Jim Henson was a nice puppet' 31 | inputs = tokenizer(question, text, return_tensors='pt') 32 | return inputs 33 | 34 | 35 | def data_mcq(): 36 | tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased') 37 | prompt = 'In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.' 38 | choice0 = 'It is eaten with a fork and a knife.' 39 | choice1 = 'It is eaten while held in the hand.' 
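    # pair the prompt with each choice and add a batch dimension, as the multiple-choice model expects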
40 | encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True) 41 | encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} 42 | return encoding 43 | 44 | 45 | model_dict = { 46 | transformers.AlbertModel: dict(config=one_sentence_config, data=data_one_sentence), 47 | transformers.AlbertForPreTraining: dict(config=one_sentence_config, data=data_one_sentence), 48 | transformers.AlbertForMaskedLM: dict(config=one_sentence_config, data=data_one_sentence), 49 | transformers.AlbertForSequenceClassification: dict(config=one_sentence_config, data=data_one_sentence), 50 | transformers.AlbertForTokenClassification: dict(config=one_sentence_config, data=data_one_sentence), 51 | transformers.AlbertForQuestionAnswering: dict(config=multi_sentence_config, data=data_qa), 52 | transformers.AlbertForMultipleChoice: dict(config=multi_sentence_config, data=data_mcq), 53 | } 54 | 55 | 56 | def test_albert(): 57 | for builder, config_dict in model_dict.items(): 58 | kwargs = dict(config=config_dict['config']) 59 | data_fn = config_dict['data'] 60 | 61 | flag = '√' 62 | try: 63 | test_hf_model(builder, kwargs, data_fn) 64 | except: 65 | flag = 'x' 66 | print(f'{builder.__name__:40s} {flag}') 67 | 68 | 69 | if __name__ == '__main__': 70 | test_albert() 71 | -------------------------------------------------------------------------------- /test/chunk/test_group.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import pytest 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState 9 | from elixir.utils import init_distributed 10 | 11 | 12 | def exam_chunk_group_functions(nproc, group): 13 | a = torch.randn(3, 64, device='cuda') 14 | copy_a = a.clone() 15 | b = torch.randn(2, 32, device='cuda') 16 | copy_b = b.clone() 17 | c = torch.randn(256, device='cuda') 18 | copy_c = c.clone() 19 | d = torch.randn(2, 2, 64, device='cuda') 20 | copy_d = d.clone() 21 | e = torch.randn(2, 33, device='cuda') 22 | copy_e = e.clone() 23 | 24 | mp = MemoryPool('cuda') 25 | mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)]) 26 | cg = ChunkGroup(rcache=mp) 27 | c0 = cg.allocate_chunk([a, b], 256, torch.float, group) 28 | c1 = cg.allocate_chunk([c], 256, torch.float, group) 29 | c2 = cg.allocate_chunk([d], 256, torch.float, group) 30 | 31 | fused_config = dict(rcache_fused=True) 32 | c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config) 33 | 34 | def check_chunk_0(): 35 | assert torch.equal(a, copy_a) 36 | assert torch.equal(b, copy_b) 37 | 38 | def check_chunk_1(): 39 | assert torch.equal(c, copy_c) 40 | 41 | def check_chunk_2(): 42 | assert torch.equal(d, copy_d) 43 | 44 | def check_chunk_3(): 45 | assert torch.equal(e, copy_e) 46 | 47 | # check tensors_to_chunks 48 | chunks = cg.tensors_to_chunks([e, a]) 49 | assert chunks[0] == c0 50 | assert chunks[1] == c3 51 | # check access_chunk for unfused chunks 52 | cg.access_chunk(c0) 53 | cg.access_chunk(c1) 54 | check_chunk_0() 55 | check_chunk_1() 56 | assert not cg.rcache_enough_check(c2) 57 | assert cg.rcache_enough_check(c3) 58 | # check access_chunk for fused chunks 59 | cg.access_chunk(c3) 60 | check_chunk_3() 61 | # check release_chunk for unfused chunks 62 | cg.release_chunk(c1) 63 | assert cg.rcache_enough_check(c2) 64 | # check access_chunk 65 | cg.access_chunk(c2) 66 | check_chunk_2() 67 | 68 | 
cg.tensor_trans_state(e, TensorState.COMPUTE) 69 | cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD) 70 | cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE) 71 | cg.reduce_chunk(c3) 72 | assert not c3.is_replica 73 | 74 | torch.cuda.synchronize() 75 | print('chunk group functions are ok') 76 | 77 | 78 | def run_dist(rank, world_size): 79 | os.environ['RANK'] = str(rank) 80 | os.environ['LOCAL_RANK'] = str(rank) 81 | os.environ['WORLD_SIZE'] = str(world_size) 82 | os.environ['MASTER_ADDR'] = '127.0.0.1' 83 | os.environ['MASTER_PORT'] = str(29512) 84 | init_distributed() 85 | exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD) 86 | 87 | 88 | @pytest.mark.dist 89 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 90 | def test_chunk_group(world_size): 91 | run_func = partial(run_dist, world_size=world_size) 92 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 93 | 94 | 95 | if __name__ == '__main__': 96 | test_chunk_group(world_size=2) 97 | -------------------------------------------------------------------------------- /example/fine-tune/func_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import psutil 5 | import torch 6 | import torch.distributed as dist 7 | from colossalai.utils import get_current_device 8 | from tqdm import tqdm 9 | 10 | 11 | def seed_all(seed: int = 42): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | 18 | def get_cpu_mem(): 19 | return psutil.Process().memory_info().rss / 1024**2 20 | 21 | 22 | def get_gpu_mem(): 23 | return torch.cuda.max_memory_allocated() / 1024**2 24 | 25 | 26 | def get_cur_gpu_mem(): 27 | return torch.cuda.memory_allocated() / 1024**2 28 | 29 | 30 | def get_mem_info(prefix=''): 31 | return '{}current CUDA memory: {:.2f} MB, past max CUDA memory: {:.2f}, CPU memory {:.2f} MB'.format( 32 | prefix, get_cur_gpu_mem(), get_gpu_mem(), get_cpu_mem()) 33 | 34 | 35 | def get_tflops(model_numel, batch_size, seq_len, step_time): 36 | return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) 37 | 38 | 39 | def train(epoch, sampler, model, loader, optimizer, show_progress=True, lr_scheduler=None, optimizer_backward=False): 40 | if sampler: 41 | sampler.set_epoch(epoch) 42 | model.train() 43 | train_iter = iter(loader) 44 | num_steps_per_epoch = len(loader) 45 | 46 | def run_step(): 47 | batch = next(train_iter) 48 | for key, val in batch.items(): 49 | batch[key] = val.cuda() 50 | 51 | optimizer.zero_grad() 52 | outputs = model(**batch) 53 | output_loss = outputs[0] 54 | step_loss = output_loss.item() 55 | if optimizer_backward: 56 | optimizer.backward(output_loss) 57 | else: 58 | output_loss.backward() 59 | optimizer.step() 60 | return step_loss 61 | 62 | with tqdm(range(num_steps_per_epoch), desc='train', ncols=0, disable=not show_progress) as t: 63 | for step in t: 64 | loss = run_step() 65 | lr_scheduler.step() 66 | t.set_postfix(loss=f'{loss:.4f}') 67 | 68 | try: 69 | while True: 70 | next(train_iter) 71 | except StopIteration: 72 | pass 73 | 74 | 75 | def evaluate(model, loader, metric, show_progress=True): 76 | model.eval() 77 | valid_iter = iter(loader) 78 | num_steps_per_epoch = len(loader) 79 | 80 | with torch.no_grad(): 81 | with tqdm(range(num_steps_per_epoch), desc='valid', ncols=0, disable=not show_progress) as t: 82 | for step in t: 83 | batch = next(valid_iter) 84 | for key, val in batch.items(): 85 | batch[key] = 
val.cuda() 86 | 87 | outputs = model(**batch) 88 | val_loss, logits = outputs[:2] 89 | preds = torch.argmax(logits, dim=-1) 90 | labels = batch['labels'] 91 | metric.add_batch(predictions=preds, references=labels) 92 | 93 | try: 94 | while True: 95 | next(valid_iter) 96 | except StopIteration: 97 | pass 98 | 99 | score = metric.compute() 100 | return score['accuracy'], score['f1'] 101 | -------------------------------------------------------------------------------- /test/parameter/test_torchvision.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models as tm 6 | from torch.testing import assert_close 7 | 8 | from elixir.parameter.temp import transform 9 | from elixir.utils import seed_all 10 | 11 | 12 | def test_model(builder, kwargs): 13 | torch_model = builder(**kwargs).cuda() 14 | test_model = deepcopy(torch_model) 15 | test_model = transform(test_model) 16 | 17 | torch_model.eval() 18 | test_model.eval() 19 | 20 | data = torch.randn(2, 3, 224, 224, device='cuda') 21 | torch_out = torch_model(data) 22 | test_out = test_model(data) 23 | assert_close(torch_out, test_out) 24 | 25 | 26 | def test_torchvision_models(): 27 | seed_all(1001, cuda_deterministic=True) 28 | 29 | model_list = [ 30 | tm.alexnet, 31 | tm.convnext_base, 32 | tm.densenet121, 33 | tm.efficientnet_v2_s, 34 | tm.googlenet, # output bad case 35 | tm.inception_v3, # bad case 36 | tm.mobilenet_v2, 37 | tm.mobilenet_v3_small, 38 | tm.mnasnet0_5, 39 | tm.resnet18, 40 | tm.regnet_x_16gf, 41 | tm.resnext50_32x4d, 42 | tm.shufflenet_v2_x0_5, 43 | tm.squeezenet1_0, 44 | tm.swin_s, # fx bad case 45 | tm.vgg11, 46 | tm.vit_b_16, 47 | tm.wide_resnet50_2, 48 | ] 49 | for builder in model_list: 50 | kwargs = {} 51 | flag = '√' 52 | try: 53 | test_model(builder, kwargs) 54 | except: 55 | flag = 'x' 56 | 57 | print(f'{builder.__name__:20s} {flag}') 58 | 59 | 60 | def test_fwd_bwd(builder, kwargs): 61 | torch_model = builder(**kwargs).cuda() 62 | test_model = deepcopy(torch_model) 63 | test_model = transform(test_model) 64 | 65 | torch_model.eval() 66 | test_model.eval() 67 | 68 | data = torch.randn(2, 3, 224, 224, device='cuda') 69 | torch_loss = torch_model(data).sum() 70 | torch_loss.backward() 71 | 72 | test_loss = test_model(data).sum() 73 | assert_close(torch_loss, test_loss) 74 | test_loss.backward() 75 | 76 | for (torch_p, test_p) in zip(torch_model.parameters(), test_model.parameters()): 77 | assert_close(torch_p.grad, test_p.grad) 78 | 79 | 80 | def test_fwd_bwd_models(): 81 | seed_all(1001, cuda_deterministic=True) 82 | 83 | model_list = [ 84 | tm.alexnet, 85 | tm.convnext_base, 86 | tm.densenet121, 87 | tm.efficientnet_v2_s, 88 | tm.googlenet, # output bad case 89 | tm.inception_v3, # bad case 90 | tm.mobilenet_v2, 91 | tm.mobilenet_v3_small, 92 | tm.mnasnet0_5, 93 | tm.resnet18, 94 | tm.regnet_x_16gf, 95 | tm.resnext50_32x4d, 96 | tm.shufflenet_v2_x0_5, 97 | tm.squeezenet1_0, 98 | tm.swin_s, # fx bad case 99 | tm.vgg11, 100 | tm.vit_b_16, 101 | tm.wide_resnet50_2, 102 | ] 103 | for builder in model_list: 104 | kwargs = {} 105 | flag = '√' 106 | try: 107 | test_fwd_bwd(builder, kwargs) 108 | except: 109 | flag = 'x' 110 | 111 | print(f'{builder.__name__:20s} {flag}') 112 | 113 | 114 | if __name__ == '__main__': 115 | test_fwd_bwd_models() 116 | -------------------------------------------------------------------------------- /elixir/hook/parameter.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils._pytree import tree_map 6 | 7 | from elixir.chunk import ChunkFetcher 8 | from elixir.kernels import fused_torch_functions 9 | from elixir.parameter import OutplaceTensor, is_no_hook_op, to_outplace_tensor 10 | 11 | from .functions import postfwd_prebwd_function, prefwd_postbwd_function 12 | from .storage import BufferStore 13 | 14 | 15 | class HookParam(OutplaceTensor, nn.Parameter): 16 | pre_fwd_func = None 17 | post_fwd_func = None 18 | use_fused_kernel = False 19 | 20 | @staticmethod 21 | def attach_fetcher(fetcher: ChunkFetcher, store: BufferStore): 22 | HookParam.pre_fwd_func = prefwd_postbwd_function(fetcher, store) 23 | HookParam.post_fwd_func = postfwd_prebwd_function(fetcher, store) 24 | 25 | @staticmethod 26 | def release_fetcher(): 27 | HookParam.pre_fwd_func = None 28 | HookParam.post_fwd_func = None 29 | 30 | @staticmethod 31 | def enable_fused_kernel(): 32 | HookParam.use_fused_kernel = True 33 | 34 | @staticmethod 35 | def disable_fused_kernel(): 36 | HookParam.use_fused_kernel = False 37 | 38 | @classmethod 39 | def __torch_function__(cls, func, types, args=(), kwargs=None): 40 | if kwargs is None: 41 | kwargs = {} 42 | 43 | if is_no_hook_op(func): 44 | with torch._C.DisableTorchFunction(): 45 | ret = func(*args, **kwargs) 46 | return ret 47 | 48 | params_to_index = OrderedDict() 49 | params_index = 0 50 | 51 | def append_param(x): 52 | nonlocal params_index 53 | if isinstance(x, HookParam): 54 | params_to_index[x] = params_index 55 | params_index += 1 56 | 57 | tree_map(append_param, args) 58 | tree_map(append_param, kwargs) 59 | 60 | params = tuple(params_to_index.keys()) 61 | new_params = HookParam.pre_fwd_func(params, *params) 62 | 63 | def replace_param(x): 64 | if isinstance(x, HookParam): 65 | return new_params[params_to_index[x]] 66 | return x 67 | 68 | with torch._C.DisableTorchFunction(): 69 | if HookParam.use_fused_kernel and func in fused_torch_functions: 70 | func = fused_torch_functions.get(func) 71 | ret = func(*tree_map(replace_param, args), **tree_map(replace_param, kwargs)) 72 | if not isinstance(ret, tuple): 73 | ret = (ret,) 74 | 75 | ptr_set = set() 76 | for p in new_params: 77 | ptr_set.add(p.data_ptr()) 78 | 79 | def clone_inplace_tensor(x): 80 | if isinstance(x, torch.Tensor): 81 | start_point = x.data_ptr() - x.element_size() * x.storage_offset() 82 | if start_point in ptr_set: 83 | return x.clone() 84 | return x 85 | 86 | ret = tree_map(clone_inplace_tensor, ret) 87 | ret = HookParam.post_fwd_func(params, *ret) 88 | 89 | def convert(t): 90 | if isinstance(t, torch.Tensor): 91 | t = to_outplace_tensor(t) 92 | return t 93 | 94 | ret = tree_map(convert, ret) 95 | 96 | if len(ret) == 1: 97 | return ret[0] 98 | else: 99 | return ret 100 | -------------------------------------------------------------------------------- /test/wrapper/test_module.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from functools import partial 4 | from test.utils import TEST_MODELS, assert_dict_values, to_cuda 5 | 6 | import pytest 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.testing import assert_close 12 | 13 | from elixir.search import simple_search 14 | from elixir.utils import init_distributed, 
seed_all 15 | from elixir.wrapper import ElixirModule 16 | 17 | 18 | def check_gradient(ddp_model: nn.Module, test_model: ElixirModule): 19 | grad_state = test_model.state_dict(from_param=True) 20 | for name, param in ddp_model.named_parameters(): 21 | assert_close(param.grad.cpu(), grad_state[name]) 22 | 23 | 24 | def exam_module_init(nproc, group, grad_flag): 25 | model_fn, data_fn = TEST_MODELS.get('resnet') 26 | torch_model = model_fn().cuda() 27 | test_model = model_fn().cuda() 28 | 29 | for p1, p2 in zip(torch_model.parameters(), test_model.parameters()): 30 | p1.requires_grad = p2.requires_grad = grad_flag 31 | 32 | sr = simple_search(test_model, nproc) 33 | model = ElixirModule(test_model, sr, group) 34 | # check function: ElixirModule.load_state_dict after ElixirModule.__init__ 35 | torch_st = torch_model.state_dict() 36 | if dist.get_rank() != 0: 37 | torch_st = None 38 | test_st = model.load_state_dict(torch_st, only_rank_0=True) 39 | # check function: ElixirModule.state_dict after ElixirModule.__init__ 40 | torch_st = torch_model.state_dict() 41 | test_st = model.state_dict() 42 | assert_dict_values(torch_st, test_st, fn=torch.equal) 43 | 44 | 45 | def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2261): 46 | ddp_model = model_fn().cuda() 47 | test_model = copy.deepcopy(ddp_model) 48 | sr = simple_search(test_model, nproc, allocate_factor=0.6) 49 | test_model = ElixirModule(test_model, sr, group) 50 | 51 | # get different data 52 | seed_all(exam_seed + dist.get_rank(group)) 53 | data = data_fn() 54 | data = to_cuda(data) 55 | 56 | seed_all(exam_seed, cuda_deterministic=True) 57 | ddp_model = DDP(ddp_model) 58 | ddp_loss = ddp_model(**data) 59 | ddp_loss.backward() 60 | 61 | test_loss = test_model(**data) 62 | test_model.backward(test_loss) 63 | 64 | assert_close(ddp_loss, test_loss) 65 | check_gradient(ddp_model.module, test_model) 66 | 67 | 68 | def exam_modules_fwd_bwd(nproc, group): 69 | model_fn, data_fn = TEST_MODELS.get('resnet') 70 | exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group) 71 | 72 | 73 | def run_dist(rank, world_size): 74 | os.environ['RANK'] = str(rank) 75 | os.environ['LOCAL_RANK'] = str(rank) 76 | os.environ['WORLD_SIZE'] = str(world_size) 77 | os.environ['MASTER_ADDR'] = '127.0.0.1' 78 | os.environ['MASTER_PORT'] = str(29512) 79 | init_distributed() 80 | exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=False) 81 | exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=True) 82 | exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD) 83 | 84 | 85 | @pytest.mark.dist 86 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 87 | def test_elixir_module(world_size): 88 | run_func = partial(run_dist, world_size=world_size) 89 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 90 | 91 | 92 | if __name__ == '__main__': 93 | test_elixir_module(world_size=2) 94 | -------------------------------------------------------------------------------- /test/chunk/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import pytest 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from elixir.chunk import Chunk, MemoryPool 9 | from elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler 10 | from elixir.utils import init_distributed 11 | 12 | 13 | def exam_fifo(nproc, group): 14 | mp = MemoryPool('cuda') 15 | mp.allocate(public_block_number=1) 16 | c0 = Chunk(mp, 
1024, torch.float, group) 17 | c1 = Chunk(mp, 1024, torch.float, group) 18 | c2 = Chunk(mp, 1024, torch.float, group) 19 | 20 | sdl = FIFOScheduler() 21 | sdl.reset() 22 | 23 | sdl.add(c0) 24 | sdl.add(c1) 25 | sdl.add(c2) 26 | sdl.add(c0) # nothing happens here 27 | assert sdl.top() == c0 28 | 29 | sdl.remove(c0) 30 | assert sdl.top() == c1, f'{sdl.top()}' 31 | sdl.remove(c0) 32 | assert sdl.top() == c1, f'{sdl.top()}' 33 | 34 | sdl.add(c0) 35 | assert sdl.top() == c1 36 | sdl.remove(c1) 37 | assert sdl.top() == c2 38 | sdl.remove(c2) 39 | assert sdl.top() == c0 40 | 41 | 42 | def exam_prefetch(nproc, group): 43 | mp = MemoryPool('cuda') 44 | mp.allocate() 45 | c0 = Chunk(mp, 1024, torch.float, group) 46 | c1 = Chunk(mp, 1024, torch.float, group) 47 | c2 = Chunk(mp, 1024, torch.float, group) 48 | 49 | chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]] 50 | 51 | sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step) 52 | print(sdl.next_step_dict) 53 | sdl.reset() 54 | 55 | sdl.step() 56 | sdl.add(c0) 57 | assert sdl.top() == c0 58 | 59 | sdl.step() 60 | sdl.add(c1) 61 | assert sdl.top() == c1 62 | 63 | sdl.step() 64 | sdl.add(c2) 65 | assert sdl.top() == c2 66 | 67 | sdl.remove(c0) 68 | sdl.step() 69 | sdl.add(c0) 70 | assert sdl.top() == c2 71 | 72 | sdl.remove(c0) 73 | sdl.step() 74 | sdl.add(c0) 75 | assert sdl.top() == c0 76 | sdl.remove(c0) # notice here 77 | 78 | sdl.remove(c1) 79 | sdl.step() 80 | sdl.add(c1) 81 | assert sdl.top() == c1 82 | 83 | sdl.remove(c2) 84 | sdl.step() 85 | sdl.add(c2) 86 | assert sdl.top() == c1 87 | 88 | sdl.remove(c2) 89 | sdl.step() 90 | sdl.add(c2) 91 | assert sdl.top() == c2 92 | sdl.remove(c2) # notice here 93 | sdl.add(c0) # notice here 94 | 95 | sdl.remove(c1) 96 | sdl.step() 97 | sdl.add(c1) 98 | assert sdl.top() == c1 99 | sdl.remove(c1) # notice here 100 | 101 | sdl.remove(c0) 102 | sdl.step() 103 | sdl.add(c0) 104 | assert sdl.top() == c0 105 | 106 | sdl.remove(c0) 107 | sdl.clear() 108 | 109 | 110 | def run_dist(rank, world_size): 111 | os.environ['RANK'] = str(rank) 112 | os.environ['LOCAL_RANK'] = str(rank) 113 | os.environ['WORLD_SIZE'] = str(world_size) 114 | os.environ['MASTER_ADDR'] = '127.0.0.1' 115 | os.environ['MASTER_PORT'] = str(29512) 116 | init_distributed() 117 | exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD) 118 | exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD) 119 | 120 | 121 | @pytest.mark.dist 122 | def test_chunk_scheduler(world_size=1): 123 | run_func = partial(run_dist, world_size=world_size) 124 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 125 | 126 | 127 | if __name__ == '__main__': 128 | test_chunk_scheduler() 129 | -------------------------------------------------------------------------------- /test/test_models.py: -------------------------------------------------------------------------------- 1 | import colossalai 2 | import torch 3 | import torch.distributed as dist 4 | from colossalai.nn.optimizer import HybridAdam 5 | from colossalai.testing import rerun_if_address_is_in_use, spawn 6 | from tests.kit.model_zoo import model_zoo 7 | 8 | from elixir.search import minimum_waste_search 9 | from elixir.wrapper import ElixirModule, ElixirOptimizer 10 | 11 | 12 | def check_gemini_plugin(early_stop: bool = True): 13 | """check gemini plugin over model zoo 14 | 15 | Args: 16 | early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. 
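    Models that are known to trigger CUDA errors are skipped explicitly inside the test loop.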
17 | """ 18 | passed_models = [] 19 | failed_info = {} # (model_name, error) pair 20 | 21 | for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): 22 | # These models lead to CUDA error 23 | if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', 24 | 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext', 25 | 'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'): 26 | continue 27 | 28 | try: 29 | print(name) 30 | global_size = dist.get_world_size() 31 | global_group = dist.GroupMember.WORLD 32 | 33 | model = model_fn() 34 | optimizer = HybridAdam(model.parameters(), lr=1e-3) 35 | criterion = lambda x: x.mean() 36 | data = data_gen_fn() 37 | 38 | data = { 39 | k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v 40 | for k, v in data.items() 41 | } 42 | 43 | sr = minimum_waste_search( 44 | # pre-commit: do not rearrange 45 | m=model, 46 | group_size=global_size, 47 | unified_dtype=torch.float16, 48 | prefetch=False, 49 | verbose=True) 50 | 51 | model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16) 52 | optimizer = ElixirOptimizer(model, optimizer, initial_scale=32) 53 | 54 | output = model(**data) 55 | output = output_transform_fn(output) 56 | output_key = list(output.keys())[0] 57 | loss = criterion(output[output_key]) 58 | 59 | optimizer.backward(loss) 60 | optimizer.step() 61 | passed_models.append(name) 62 | 63 | del model, optimizer, criterion, data, output, loss 64 | except Exception as e: 65 | failed_info[name] = e 66 | if early_stop: 67 | raise e 68 | 69 | torch.cuda.empty_cache() 70 | 71 | if dist.get_rank() == 0: 72 | print(f'Passed models({len(passed_models)}): {passed_models}\n\n') 73 | print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') 74 | assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) 75 | 76 | 77 | def run_dist(rank, world_size, port, early_stop: bool = True): 78 | # init dist env 79 | colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') 80 | check_gemini_plugin(early_stop=early_stop) 81 | 82 | 83 | @rerun_if_address_is_in_use() 84 | def test_gemini_plugin(early_stop: bool = True): 85 | spawn(run_dist, 2, early_stop=early_stop) 86 | 87 | 88 | if __name__ == '__main__': 89 | test_gemini_plugin(early_stop=False) 90 | -------------------------------------------------------------------------------- /elixir/tracer/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from copy import copy 3 | from typing import Optional, Set 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): 10 | """Get a dfs module list of the given module. Its order is same as the order of creations of modules. 11 | """ 12 | if memo is None: 13 | memo = set() 14 | if module not in memo: 15 | for name, submodule in module._modules.items(): 16 | if submodule is None: 17 | continue 18 | submodule_prefix = prefix + ('.' if prefix else '') + name 19 | for m in _get_dfs_module_list(submodule, memo, submodule_prefix): 20 | yield m 21 | 22 | memo.add(module) 23 | yield prefix, module 24 | 25 | 26 | def _get_shallow_copy_model(model: nn.Module): 27 | """Get a shallow copy of the given model. 
Each submodule is different from the original submodule. 28 | But the new submodule and the old submodule share all attributes. 29 | """ 30 | old_to_new = dict() 31 | for name, module in _get_dfs_module_list(model): 32 | new_module = copy(module) 33 | new_module._modules = OrderedDict() 34 | for subname, submodule in module._modules.items(): 35 | if submodule is None: 36 | continue 37 | setattr(new_module, subname, old_to_new[submodule]) 38 | old_to_new[module] = new_module 39 | return old_to_new[model] 40 | 41 | 42 | def meta_copy(model: nn.Module, meta_fn: callable): 43 | new_model = _get_shallow_copy_model(model) 44 | old_parameters = dict() 45 | old_buffers = dict() 46 | 47 | for (_, old_module), (_, new_module) in \ 48 | zip(_get_dfs_module_list(model), _get_dfs_module_list(new_model)): 49 | 50 | new_module._parameters = OrderedDict() 51 | for name, param in old_module._parameters.items(): 52 | new_param = None 53 | if param is not None: 54 | param_id = id(param) 55 | if param_id in old_parameters: 56 | new_param = old_parameters.get(param_id) 57 | else: 58 | new_param = meta_fn(param) 59 | old_parameters[param_id] = new_param 60 | setattr(new_module, name, new_param) 61 | 62 | new_module._buffers = OrderedDict() 63 | for name, buffer in old_module._buffers.items(): 64 | new_buffer = None 65 | if buffer is not None: 66 | buffer_id = id(buffer) 67 | if buffer_id in old_buffers: 68 | new_buffer = old_buffers.get(buffer_id) 69 | else: 70 | new_buffer = meta_fn(buffer) 71 | old_buffers[buffer_id] = new_buffer 72 | new_module.register_buffer(name, new_buffer) 73 | 74 | return new_model 75 | 76 | 77 | def get_cuda_allocated(): 78 | return torch.cuda.memory_allocated() 79 | 80 | 81 | def get_cuda_max_allocated(): 82 | return torch.cuda.max_memory_allocated() 83 | 84 | 85 | def model_memory_figure(model: nn.Module): 86 | param_occ = 0 87 | max_numel = 0 88 | for name, param in model.named_parameters(): 89 | param_occ += param.numel() * param.element_size() 90 | max_numel = max(max_numel, param.numel()) 91 | 92 | buffer_occ = 0 93 | for name, buffer in model.named_buffers(): 94 | buffer_occ += buffer.numel() * buffer.element_size() 95 | 96 | return dict(param_occ=param_occ, param_max_numel=max_numel, buffer_occ=buffer_occ) 97 | -------------------------------------------------------------------------------- /elixir/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | 10 | 11 | @contextlib.contextmanager 12 | def no_dispatch(): 13 | guard = torch._C._DisableTorchDispatch() 14 | try: 15 | yield 16 | finally: 17 | del guard 18 | 19 | 20 | def normalize_tuple(x): 21 | if not isinstance(x, tuple): 22 | return (x,) 23 | return x 24 | 25 | 26 | def seed_all(seed, cuda_deterministic=False): 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) 33 | if cuda_deterministic: # slower, more reproducible 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | else: 37 | torch.backends.cudnn.deterministic = False 38 | torch.backends.cudnn.benchmark = True 39 | 40 | 41 | def init_distributed(): 42 | rank = int(os.environ['RANK']) 43 | local_rank = int(os.environ['LOCAL_RANK']) 44 | world_size = int(os.environ['WORLD_SIZE']) 45 | host = 
os.environ['MASTER_ADDR'] 46 | port = int(os.environ['MASTER_PORT']) 47 | 48 | init_method = f'tcp://[{host}]:{port}' 49 | dist.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=world_size) 50 | 51 | # set cuda device 52 | if torch.cuda.is_available(): 53 | # if local rank is not given, calculate automatically 54 | torch.cuda.set_device(local_rank) 55 | 56 | seed_all(1024) 57 | 58 | 59 | def print_rank_0(*args, **kwargs): 60 | if dist.is_initialized(): 61 | if dist.get_rank() == 0: 62 | print(*args, **kwargs) 63 | dist.barrier() 64 | else: 65 | print(*args, **kwargs) 66 | 67 | 68 | def get_model_size(model: nn.Module): 69 | total_numel = 0 70 | for module in model.modules(): 71 | for p in module.parameters(recurse=False): 72 | total_numel += p.numel() 73 | return total_numel 74 | 75 | 76 | def model_size_formatter(numel: int) -> str: 77 | GB_SIZE = 10**9 78 | MB_SIZE = 10**6 79 | KB_SIZE = 10**3 80 | if numel >= GB_SIZE: 81 | return f'{numel / GB_SIZE:.1f}B' 82 | elif numel >= MB_SIZE: 83 | return f'{numel / MB_SIZE:.1f}M' 84 | elif numel >= KB_SIZE: 85 | return f'{numel / KB_SIZE:.1f}K' 86 | else: 87 | return str(numel) 88 | 89 | 90 | def calc_buffer_size(m: nn.Module, test_dtype: torch.dtype = torch.float): 91 | max_sum_size = 0 92 | for module in m.modules(): 93 | sum_p_size = 0 94 | for param in module.parameters(recurse=False): 95 | assert param.dtype == test_dtype 96 | sum_p_size += param.numel() 97 | max_sum_size = max(max_sum_size, sum_p_size) 98 | return max_sum_size 99 | 100 | 101 | def calc_block_usage(): 102 | snap_shot = torch.cuda.memory_snapshot() 103 | 104 | total_sum = 0 105 | active_sum = 0 106 | for info_dict in snap_shot: 107 | blocks = info_dict.get('blocks') 108 | for b in blocks: 109 | size = b.get('size') 110 | state = b.get('state') 111 | total_sum += size 112 | if state == 'active_allocated': 113 | active_sum += size 114 | 115 | active_ratio = 1 116 | if total_sum > 0: 117 | active_ratio = active_sum / total_sum 118 | 119 | print(f'memory snap shot: active ratio {active_ratio:.2f}') 120 | -------------------------------------------------------------------------------- /test/wrapper/test_amp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from functools import partial 4 | from test.utils import TEST_MODELS, to_cuda 5 | 6 | import pytest 7 | import torch 8 | import torch.distributed as dist 9 | from apex import amp 10 | from apex.parallel import DistributedDataParallel as DDP 11 | from colossalai.nn.optimizer import HybridAdam 12 | from torch.testing import assert_close 13 | 14 | from elixir.cuda import gpu_device 15 | from elixir.search import simple_search 16 | from elixir.utils import init_distributed, seed_all 17 | from elixir.wrapper import ElixirModule, ElixirOptimizer 18 | 19 | 20 | def amp_check_model_states(ddp_optim, test_model): 21 | test_states = test_model.state_dict() 22 | for (name, _), p in zip(test_model.module.named_parameters(), amp.master_params(ddp_optim)): 23 | test_p = test_states[name] 24 | copy_p = p.to(test_p.device) 25 | print(f'checking parameter `{name}`: {test_p.dtype} {copy_p.dtype}') 26 | assert_close(test_p.data, copy_p.data) 27 | 28 | 29 | def exam_amp_one_model(model_fn, data_fn, nproc, group, exam_seed=2261): 30 | ddp_model = model_fn().cuda() 31 | test_model = copy.deepcopy(ddp_model) 32 | # important here, since apex has a lazy fp32 init after the first optimizer step 33 | test_model = test_model.half() 34 | 35 | ddp_optim = 
HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0) 36 | ddp_model, ddp_optim = amp.initialize(ddp_model, 37 | ddp_optim, 38 | opt_level='O2', 39 | loss_scale=1.0, 40 | keep_batchnorm_fp32=False) 41 | ddp_model = DDP(ddp_model, message_size=0, allreduce_always_fp32=True) 42 | 43 | test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0) 44 | sr = simple_search(test_model, nproc, shard_device=gpu_device(), unified_dtype=torch.float16, verbose=True) 45 | test_model = ElixirModule(test_model, sr, group, dtype=torch.float16, reduce_always_fp32=True, output_fp32=True) 46 | test_optim = ElixirOptimizer(test_model, test_optim, initial_scale=1.0) 47 | 48 | # get different data 49 | seed_all(exam_seed + dist.get_rank(group), cuda_deterministic=True) 50 | for _ in range(2): 51 | data = to_cuda(data_fn()) 52 | 53 | ddp_optim.zero_grad() 54 | ddp_loss = ddp_model(**data) 55 | with amp.scale_loss(ddp_loss, ddp_optim) as scaled_loss: 56 | scaled_loss.backward() 57 | ddp_optim.step() 58 | 59 | test_optim.zero_grad() 60 | test_loss = test_model(**data) 61 | test_optim.backward(test_loss) 62 | test_optim.step() 63 | 64 | assert_close(ddp_loss, test_loss) 65 | amp_check_model_states(ddp_optim, test_model) 66 | 67 | 68 | def exam_amp_in_models(nproc, group): 69 | model_fn, data_fn = TEST_MODELS.get('gpt2_micro') 70 | exam_amp_one_model(model_fn, data_fn, nproc, group) 71 | 72 | 73 | def run_dist(rank, world_size): 74 | os.environ['RANK'] = str(rank) 75 | os.environ['LOCAL_RANK'] = str(rank) 76 | os.environ['WORLD_SIZE'] = str(world_size) 77 | os.environ['MASTER_ADDR'] = '127.0.0.1' 78 | os.environ['MASTER_PORT'] = str(29512) 79 | init_distributed() 80 | exam_amp_in_models(nproc=world_size, group=dist.GroupMember.WORLD) 81 | 82 | 83 | @pytest.mark.dist 84 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 85 | def test_elixir_amp(world_size): 86 | run_func = partial(run_dist, world_size=world_size) 87 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 88 | 89 | 90 | if __name__ == '__main__': 91 | test_elixir_amp(world_size=4) 92 | -------------------------------------------------------------------------------- /elixir/search/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def to_divide(a: int, b: int): 6 | return a + (-a % b) 7 | 8 | 9 | def to_meta_tensor(t: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor: 10 | # only float tensors need dtype change 11 | if t.is_floating_point() and dtype is not None: 12 | meta_dtype = dtype 13 | else: 14 | meta_dtype = t.dtype 15 | # we shall not use t.data.to here, since t might be a fake tensor 16 | meta_t = torch.empty(t.size(), dtype=meta_dtype, device='meta') 17 | # pack it if t is a parameter 18 | # we should filter parameters with no grad 19 | if isinstance(t, nn.Parameter) and t.requires_grad: 20 | meta_t = nn.Parameter(meta_t) 21 | return meta_t 22 | 23 | 24 | def get_multi_used_params(m: nn.Module) -> set[torch.Tensor]: 25 | multi_used_set = set() 26 | visit = dict() 27 | for module in m.modules(): 28 | for param in module.parameters(recurse=False): 29 | if param not in visit: 30 | visit[param] = True 31 | else: 32 | multi_used_set.add(param) 33 | return multi_used_set 34 | 35 | 36 | def find_minimum_waste_size(numel_group_list: list[list[int]], min_range: int, max_range: int, interval: int): 37 | 38 | max_per_group = list() 39 | for n_list in numel_group_list: 40 | max_per_group.append(max(n_list)) 41 | max_numel = 
max(max_per_group) 42 | 43 | test_size = to_divide(max(max_numel, min_range), interval) 44 | best_size = test_size 45 | min_waste = float('+inf') 46 | 47 | def calc_waste(numel_list: list[int], block_size: int): 48 | acc = 0 49 | left = 0 50 | for s in numel_list: 51 | if s > left: 52 | acc += left 53 | left = block_size 54 | left -= s 55 | return left + acc 56 | 57 | assert test_size <= max_range, 'max_numel or min_range is larger than max_range' 58 | while test_size <= max_range: 59 | current_waste = 0 60 | for n_list in numel_group_list: 61 | current_waste += calc_waste(n_list, test_size) 62 | if current_waste < min_waste: 63 | best_size = test_size 64 | min_waste = current_waste 65 | test_size += interval 66 | 67 | return best_size, min_waste 68 | 69 | 70 | def find_search_range(m: nn.Module): 71 | 72 | ele_size = 0 73 | for param in m.parameters(): 74 | if ele_size == 0: 75 | ele_size = param.element_size() 76 | else: 77 | assert param.element_size() == ele_size 78 | 79 | def next_2_pow(x: int): 80 | y = 1 81 | while y < x: 82 | y <<= 1 83 | return y 84 | 85 | private_params = get_multi_used_params(m) 86 | params = [p for p in m.parameters() if p not in private_params] 87 | memo_list = [p.numel() * p.element_size() for p in params] 88 | max_memo = max(memo_list) 89 | # minimum chunk memory is 32 MiB 90 | default_min = 32 * 1024**2 91 | while default_min < max_memo: 92 | default_min <<= 1 93 | default_max = int(3 * default_min) 94 | # * 2 for forward and backward 95 | length = 2 * next_2_pow(len(params)) 96 | default_iter_times = 16 * 1024**2 97 | default_search_times = default_iter_times // length 98 | 99 | gap = default_max - default_min 100 | # minimum search interval is 1024 101 | if default_search_times > (gap // 1024): 102 | interval = 1024 103 | else: 104 | interval = gap // default_search_times 105 | 106 | return (default_min // ele_size, default_max // ele_size, interval // ele_size) 107 | -------------------------------------------------------------------------------- /example/common/models.py: -------------------------------------------------------------------------------- 1 | from test.utils.gpt import GPTLMModel 2 | 3 | from transformers import AutoConfig, OPTConfig 4 | 5 | from example.common.opt import OPTLMModel 6 | 7 | 8 | def gpt2_400m(): 9 | return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16) 10 | 11 | 12 | def gpt2_1b(): 13 | return GPTLMModel(hidden_size=1536, num_layers=32, num_attention_heads=24) 14 | 15 | 16 | def gpt2_4b(): 17 | return GPTLMModel(hidden_size=3072, num_layers=32, num_attention_heads=24) 18 | 19 | 20 | def gpt2_10b(): 21 | return GPTLMModel(hidden_size=4096, num_layers=48, num_attention_heads=32) 22 | 23 | 24 | def gpt2_15b(): 25 | return GPTLMModel(hidden_size=8192, num_layers=18, num_attention_heads=64) 26 | 27 | 28 | def gpt2_20b(): 29 | return GPTLMModel(hidden_size=8192, num_layers=24, num_attention_heads=64) 30 | 31 | 32 | def gpt2_25b(): 33 | return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=64) 34 | 35 | 36 | def gpt2_30b(): 37 | return GPTLMModel(hidden_size=8192, num_layers=36, num_attention_heads=64) 38 | 39 | 40 | def gpt2_40b(): 41 | return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16) 42 | 43 | 44 | def opt_350m(): 45 | opt_config = AutoConfig.from_pretrained('facebook/opt-350m') 46 | return OPTLMModel(opt_config) 47 | 48 | 49 | def opt_1b(): 50 | opt_config = AutoConfig.from_pretrained('facebook/opt-1.3b') 51 | return OPTLMModel(opt_config) 52 | 53 | 54 | def 
opt_3b(): 55 | opt_config = AutoConfig.from_pretrained('facebook/opt-2.7b') 56 | return OPTLMModel(opt_config) 57 | 58 | 59 | def opt_7b(): 60 | opt_config = AutoConfig.from_pretrained('facebook/opt-6.7b') 61 | return OPTLMModel(opt_config) 62 | 63 | 64 | def opt_13b(): 65 | opt_config = AutoConfig.from_pretrained('facebook/opt-13b') 66 | return OPTLMModel(opt_config) 67 | 68 | 69 | def opt_30b(): 70 | opt_config = AutoConfig.from_pretrained('facebook/opt-30b') 71 | return OPTLMModel(opt_config) 72 | 73 | 74 | def opt_66b(): 75 | opt_config = AutoConfig.from_pretrained('facebook/opt-66b') 76 | return OPTLMModel(opt_config) 77 | 78 | 79 | def opt_175b(): 80 | opt_config = OPTConfig(activation_dropout=0.0, 81 | hidden_size=12288, 82 | num_hidden_layers=96, 83 | ffn_dim=49152, 84 | num_attention_heads=96, 85 | word_embed_proj_dim=12288, 86 | output_projection=True) 87 | return OPTLMModel(opt_config) 88 | 89 | 90 | def get_model(name: str): 91 | if name == 'gpt2-400m': 92 | return gpt2_400m() 93 | elif name == 'gpt2-1b': 94 | return gpt2_1b() 95 | elif name == 'gpt2-4b': 96 | return gpt2_4b() 97 | elif name == 'gpt2-10b': 98 | return gpt2_10b() 99 | elif name == 'gpt2-15b': 100 | return gpt2_15b() 101 | elif name == 'gpt2-20b': 102 | return gpt2_20b() 103 | elif name == 'gpt2-25b': 104 | return gpt2_25b() 105 | elif name == 'gpt2-30b': 106 | return gpt2_30b() 107 | elif name == 'gpt2-40b': 108 | return gpt2_40b() 109 | elif name == 'opt-350m': 110 | return opt_350m() 111 | elif name == 'opt-1b': 112 | return opt_1b() 113 | elif name == 'opt-3b': 114 | return opt_3b() 115 | elif name == 'opt-7b': 116 | return opt_7b() 117 | elif name == 'opt-13b': 118 | return opt_13b() 119 | elif name == 'opt-30b': 120 | return opt_30b() 121 | elif name == 'opt-66b': 122 | return opt_66b() 123 | elif name == 'opt-175b': 124 | return opt_175b() 125 | else: 126 | raise ValueError(f'Unknown model name: {name}') 127 | -------------------------------------------------------------------------------- /elixir/kernels/opt_attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import einops 4 | import torch 5 | import torch.nn as nn 6 | from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder 7 | 8 | from .attention import lower_triangular_attention 9 | 10 | 11 | class XOPTAttention(OPTAttention): 12 | 13 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 14 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | key_value_states: Optional[torch.Tensor] = None, 20 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 21 | attention_mask: Optional[torch.Tensor] = None, 22 | layer_head_mask: Optional[torch.Tensor] = None, 23 | output_attentions: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | """Input shape: Batch x Time x Channel""" 26 | 27 | # if key_value_states are provided this layer is used as a cross-attention layer 28 | # for the decoder 29 | is_cross_attention = key_value_states is not None 30 | assert is_cross_attention is False 31 | assert past_key_value is None 32 | assert layer_head_mask is None 33 | # assert output_attentions is False 34 | 35 | bsz, tgt_len, _ = hidden_states.size() 36 | 37 | # get query proj 38 | query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) 39 | # get key, value proj 40 | # 
self_attention 41 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 42 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 43 | 44 | if self.is_decoder: 45 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 46 | # Further calls to cross_attention layer can then reuse all cross-attention 47 | # key/value_states (first "if" case) 48 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 49 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 50 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 51 | # if encoder bi-directional self-attention `past_key_value` is always `None` 52 | past_key_value = (key_states, value_states) 53 | 54 | src_len = key_states.size(1) 55 | assert tgt_len == src_len 56 | 57 | attn_output = lower_triangular_attention(query=query_states, key=key_states, value=value_states, p=self.dropout) 58 | 59 | if attn_output.size() != (bsz, tgt_len, self.num_heads, self.head_dim): 60 | raise ValueError(f'`attn_output` should be of size {(bsz, tgt_len, self.num_heads, self.head_dim)}, but is' 61 | f' {attn_output.size()}') 62 | 63 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 64 | # partitioned aross GPUs when using tensor-parallelism. 65 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 66 | 67 | attn_output = self.out_proj(attn_output) 68 | 69 | return attn_output, None, past_key_value 70 | 71 | 72 | class XOPTDecoder(OPTDecoder): 73 | 74 | def forward(self, *args, **kwargs): 75 | assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg' 76 | attn_mask = kwargs.get('attention_mask') 77 | # assert torch.all(attn_mask == 1), 'only accept no padding mask' 78 | 79 | head_mask = kwargs.get('head_mask', None) 80 | assert head_mask is None, 'head mask should be None' 81 | 82 | output_attn = kwargs.get('output_attentions', False) 83 | if output_attn: 84 | Warning('output_attentions is not supported for XOPTDecoder') 85 | 86 | return super().forward(*args, **kwargs) 87 | -------------------------------------------------------------------------------- /elixir/parameter/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._pytree import tree_map 3 | 4 | debug_flag = False 5 | 6 | white_list = {torch.Tensor.__getitem__} 7 | 8 | fake_allowed = { 9 | # pre-commit: don't move 10 | torch.Tensor.numel, 11 | torch.Tensor.size, 12 | torch.Tensor.stride, 13 | torch.Tensor.storage_offset, 14 | torch.Tensor.is_floating_point 15 | } 16 | 17 | inpalce_mapping = { 18 | torch.Tensor.add_: torch.Tensor.add, 19 | torch.Tensor.sub_: torch.Tensor.sub, 20 | torch.Tensor.mul_: torch.Tensor.mul, 21 | torch.Tensor.div_: torch.Tensor.div 22 | } 23 | 24 | 25 | def is_no_hook_op(func) -> bool: 26 | if func.__name__.startswith('__') and func not in white_list: 27 | return True 28 | if func in fake_allowed: 29 | return True 30 | return False 31 | 32 | 33 | class FakeTensor(torch.Tensor): 34 | 35 | @staticmethod 36 | def __new__(cls, elem, *args, **kwargs): 37 | r = torch.Tensor._make_wrapper_subclass(cls, 38 | elem.size(), 39 | strides=elem.stride(), 40 | storage_offset=elem.storage_offset(), 41 | dtype=elem.dtype, 42 | layout=elem.layout, 43 | device=elem.device, 44 | requires_grad=elem.requires_grad) 45 | return r 46 | 47 | 
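    # FakeTensor only carries shape/dtype metadata; it must never reach a real kernel, so any dispatch raises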
@classmethod 48 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 49 | raise NotImplementedError 50 | 51 | 52 | def to_outplace_tensor(t): 53 | if isinstance(t, OutplaceTensor): 54 | return t 55 | assert type(t) is torch.Tensor, f'type: {type(t)}' 56 | t.__class__ = OutplaceTensor 57 | return t 58 | 59 | 60 | class OutplaceTensor(torch.Tensor): 61 | # TODO: rename this class 62 | def __new__(cls, tensor): 63 | rt = tensor.as_subclass(cls) 64 | return rt 65 | 66 | @classmethod 67 | def __torch_function__(cls, func, types, args=(), kwargs=None): 68 | 69 | if kwargs is None: 70 | kwargs = {} 71 | # in order to trigger pre-op hook in the forward of checkpoint module 72 | # we have to capture the `backward` function 73 | # and make sure that it does not in `torch._C.DisableTorchFunction()` context 74 | if func is torch.Tensor.backward: 75 | assert len(args) == 1 # only has 1 paramter 76 | backward_tensor = torch.Tensor(args[0]) 77 | tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} 78 | return backward_tensor.backward(**tensor_kwargs) 79 | # return a tensor if the output needs to be a torch.Tensor (such as Tensor.data.__get__) 80 | if is_no_hook_op(func): 81 | with torch._C.DisableTorchFunction(): 82 | ret = func(*args, **kwargs) 83 | return ret 84 | 85 | # debug inplace operations 86 | if debug_flag: 87 | if func.__name__.endswith('_'): 88 | print(f'found inplace operation {func.__name__}') 89 | 90 | # replace the in-place function 91 | if func in inpalce_mapping: 92 | func = inpalce_mapping[func] 93 | # set the 'inplace' kwargs to False 94 | if 'inplace' in kwargs: 95 | kwargs['inplace'] = False 96 | 97 | with torch._C.DisableTorchFunction(): 98 | ret = func(*args, **kwargs) 99 | if not isinstance(ret, tuple): 100 | ret = (ret,) 101 | 102 | def convert(t): 103 | if isinstance(t, torch.Tensor): 104 | t = to_outplace_tensor(t) 105 | return t 106 | 107 | ret = tree_map(convert, ret) 108 | 109 | if len(ret) == 1: 110 | ret = ret[0] 111 | 112 | return ret 113 | -------------------------------------------------------------------------------- /elixir/tracer/memory_tracer/op_cache.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Dict, Iterator, Tuple 3 | 4 | import torch 5 | 6 | from elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated 7 | 8 | from .output_shape import addmm_output, bmm_output, check_cuda_mm, mm_output 9 | 10 | 11 | @contextlib.contextmanager 12 | def no_dispatch() -> Iterator[None]: 13 | guard = torch._C._DisableTorchDispatch() 14 | try: 15 | yield 16 | finally: 17 | del guard 18 | 19 | 20 | def tensor_info(x: torch.Tensor): 21 | # returns the meta information used for CUDA kernels 22 | return (x.shape, x.stride(), x.layout, x.dtype) 23 | 24 | 25 | def get_args_info(*args): 26 | # returns a tuple contains the meta information of all inputs 27 | # every argument is expected to be a tensor 28 | info_list = [] 29 | for x in args: 30 | if isinstance(x, torch.Tensor): 31 | info_list.append(tensor_info(x)) 32 | return tuple(info_list) 33 | 34 | 35 | class OpCache(object): 36 | 37 | def __init__(self, name: str) -> None: 38 | super().__init__() 39 | self.name = name 40 | self.temp_memory: Dict[Tuple, int] = dict() 41 | 42 | def reset(self): 43 | self.temp_memory.clear() 44 | 45 | def get(self, info): 46 | if info in self.temp_memory: 47 | return True, self.temp_memory[info] 48 | else: 49 | return False, None 50 | 51 | def add(self, 
info, memo): 52 | self.temp_memory[info] = memo 53 | 54 | def print(self): 55 | print(f'OpCache {self.name} information:') 56 | for k, v in self.temp_memory.items(): 57 | print(f'key: {k}\ntemp_memo:{v}') 58 | 59 | 60 | aten = torch.ops.aten 61 | addmm_cache = OpCache('aten.addmm.default') 62 | bmm_cache = OpCache('aten.bmm.default') 63 | mm_cache = OpCache('aten.mm.default') 64 | 65 | op_mapping = { 66 | aten.mm.default: { 67 | 'cache': mm_cache, 68 | 'output': mm_output 69 | }, 70 | aten.addmm.default: { 71 | 'cache': addmm_cache, 72 | 'output': addmm_output 73 | }, 74 | aten.bmm.default: { 75 | 'cache': bmm_cache, 76 | 'output': bmm_output 77 | } 78 | } 79 | 80 | 81 | def reset_caches(): 82 | addmm_cache.reset() 83 | bmm_cache.reset() 84 | mm_cache.reset() 85 | 86 | 87 | def fake_cuda_output(temp_memo, output_shape, dtype): 88 | ret = torch.empty(output_shape, dtype=dtype, device='cuda') 89 | sub = temp_memo - ret.numel() * ret.element_size() 90 | 91 | if sub > 0: 92 | # allocate a temp empty tensor block to simulate the computation in kernels 93 | temp = torch.empty(sub, dtype=torch.int8, device='cuda') 94 | # release this tensor block 95 | del temp 96 | 97 | return ret 98 | 99 | 100 | def real_cuda_output(func, *args, **kwargs): 101 | cur_alc = get_cuda_allocated() 102 | # save the peak memory usage 103 | pre_max_alc = get_cuda_max_allocated() 104 | # the peak memory history is cleared here 105 | torch.cuda.reset_peak_memory_stats() 106 | 107 | with no_dispatch(): 108 | ret = func(*args, **kwargs) 109 | 110 | max_alc = get_cuda_max_allocated() 111 | # calculate the temporary memory allocation 112 | temp_memo = max_alc - cur_alc 113 | 114 | return ret, temp_memo, pre_max_alc 115 | 116 | 117 | def wrapped_mm_ops(func, *args, **kwargs): 118 | check_cuda_mm(*args) 119 | 120 | if func not in op_mapping: 121 | raise RuntimeError(f'Unsupported mm operation {func}') 122 | 123 | args_info = get_args_info(*args) 124 | cache = op_mapping[func]['cache'] 125 | cached_flag, temp_memo = cache.get(args_info) 126 | 127 | if cached_flag: 128 | output_fn = op_mapping[func]['output'] 129 | out_shape = output_fn(*args) 130 | ret = fake_cuda_output(temp_memo=temp_memo, output_shape=out_shape, dtype=args[0].dtype) 131 | return ret, 0 132 | else: 133 | ret, temp_memo, pre_max_alc = real_cuda_output(func, *args, **kwargs) 134 | cache.add(args_info, temp_memo) 135 | return ret, pre_max_alc 136 | -------------------------------------------------------------------------------- /example/fine-tune/data_module.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.distributed import DistributedSampler 4 | from transformers import AutoTokenizer 5 | 6 | 7 | class GLUEDataModule: 8 | 9 | task_text_field_map = { 10 | 'cola': ['sentence'], 11 | 'sst2': ['sentence'], 12 | 'mrpc': ['sentence1', 'sentence2'], 13 | 'qqp': ['question1', 'question2'], 14 | 'stsb': ['sentence1', 'sentence2'], 15 | 'mnli': ['premise', 'hypothesis'], 16 | 'qnli': ['question', 'sentence'], 17 | 'rte': ['sentence1', 'sentence2'], 18 | 'wnli': ['sentence1', 'sentence2'], 19 | 'ax': ['premise', 'hypothesis'], 20 | } 21 | 22 | glue_task_num_labels = { 23 | 'cola': 2, 24 | 'sst2': 2, 25 | 'mrpc': 2, 26 | 'qqp': 2, 27 | 'stsb': 1, 28 | 'mnli': 3, 29 | 'qnli': 2, 30 | 'rte': 2, 31 | 'wnli': 2, 32 | 'ax': 3, 33 | } 34 | 35 | loader_columns = [ 36 | 'datasets_idx', 37 | 'input_ids', 38 | 'token_type_ids', 39 | 'attention_mask', 40 | 
'start_positions', 41 | 'end_positions', 42 | 'labels', 43 | ] 44 | 45 | def __init__( 46 | self, 47 | model_name_or_path: str, 48 | task_name: str = 'mrpc', 49 | max_seq_length: int = 128, 50 | train_batch_size: int = 32, 51 | eval_batch_size: int = 32, 52 | **kwargs, 53 | ): 54 | self.model_name_or_path = model_name_or_path 55 | self.task_name = task_name 56 | self.max_seq_length = max_seq_length 57 | self.train_batch_size = train_batch_size 58 | self.eval_batch_size = eval_batch_size 59 | 60 | self.text_fields = self.task_text_field_map[task_name] 61 | self.num_labels = self.glue_task_num_labels[task_name] 62 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) 63 | 64 | self.dataset = None 65 | self.columns = None 66 | self.eval_splits = None 67 | 68 | def setup(self, stage: str): 69 | self.dataset = datasets.load_dataset('glue', self.task_name) 70 | 71 | for split in self.dataset.keys(): 72 | self.dataset[split] = self.dataset[split].map( 73 | self.convert_to_features, 74 | batched=True, 75 | remove_columns=['label'], 76 | ) 77 | self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] 78 | self.dataset[split].set_format(type='torch', columns=self.columns) 79 | 80 | self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x] 81 | 82 | def train_loader_and_sampler(self): 83 | train_set = self.dataset['train'] 84 | train_sampler = DistributedSampler(train_set, shuffle=True) 85 | train_loader = DataLoader(train_set, self.train_batch_size, sampler=train_sampler) 86 | return train_loader, train_sampler 87 | 88 | def val_loader_and_sampler(self): 89 | valid_set = self.dataset['validation'] 90 | valid_loader = DataLoader(valid_set, self.eval_batch_size) 91 | return valid_loader 92 | 93 | def convert_to_features(self, example_batch, indices=None): 94 | 95 | # Either encode single sentence or sentence pairs 96 | if len(self.text_fields) > 1: 97 | texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) 98 | else: 99 | texts_or_text_pairs = example_batch[self.text_fields[0]] 100 | 101 | # Tokenize the text/text pairs 102 | features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, 103 | max_length=self.max_seq_length, 104 | pad_to_max_length=True, 105 | truncation=True) 106 | 107 | # Rename label to labels to make it easier to pass to model forward 108 | features['labels'] = example_batch['label'] 109 | 110 | return features 111 | -------------------------------------------------------------------------------- /elixir/search/simple.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .base import SearchBase 8 | from .result import SearchResult 9 | from .utils import get_multi_used_params, to_divide 10 | 11 | 12 | class SearchSimple(SearchBase): 13 | 14 | def __init__(self, 15 | module: nn.Module, 16 | default_group_size: int, 17 | dtype: torch.dtype = torch.float, 18 | prefetch: bool = False, 19 | verbose: bool = False, 20 | inp=None, 21 | step_fn=None) -> None: 22 | 23 | super().__init__(module, dtype, prefetch, verbose, inp, step_fn) 24 | self.default_group_size = default_group_size 25 | 26 | def private_truncate(self, param: nn.Parameter) -> int: 27 | return to_divide(param.numel(), self.default_group_size) 28 | 29 | def public_trucate(self, length: int) -> int: 30 | return to_divide(length, self.default_group_size) 31 | 32 | 
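    # split single-use (public) parameters into roughly equal groups; the largest group decides the public block size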
def search(self, split_number: int, allocate_factor: float) -> Tuple: 33 | # get multi-used parameters 34 | private_params = get_multi_used_params(self.meta_module) 35 | # get parameters used only one time 36 | public_params = [p for p in self.meta_module.parameters() if p not in private_params] 37 | 38 | # calculate the size of each group 39 | len_public = len(public_params) 40 | split_number = min(len_public, split_number) 41 | # allocate a list for groups 42 | public_groups = list() 43 | if split_number > 0: 44 | average_size = len_public // split_number 45 | left_size = len_public % split_number 46 | 47 | # set the size of each segment 48 | pack_size_list = [average_size] * split_number 49 | for i in range(split_number): 50 | if left_size > 0: 51 | pack_size_list[i] += 1 52 | left_size -= 1 53 | 54 | # split public parameters 55 | for i in range(split_number): 56 | p_list = list() 57 | for _ in range(pack_size_list[i]): 58 | p = public_params.pop(0) 59 | p_list.append(p) 60 | public_groups.append(p_list) 61 | assert len(public_params) == 0 62 | 63 | # calculate the maximum summarized size 64 | max_sum_size = 0 65 | for p_list in public_groups: 66 | sum_size = sum([p.numel() for p in p_list]) 67 | max_sum_size = max(max_sum_size, sum_size) 68 | else: 69 | max_sum_size = 0 70 | 71 | self.public_block_size = max_sum_size 72 | self.public_block_number = math.ceil(split_number * allocate_factor) 73 | 74 | return (private_params, public_groups) 75 | 76 | 77 | def simple_search(m: nn.Module, 78 | group_size: int, 79 | split_number: int = 10, 80 | allocate_factor: float = 0.6, 81 | unified_dtype: torch.dtype = torch.float, 82 | shard_device: torch.device = torch.device('cpu'), 83 | prefetch: bool = False, 84 | verbose: bool = False, 85 | inp=None, 86 | step_fn=None) -> SearchResult: 87 | 88 | search_class = SearchSimple( 89 | # pre-commit: do not rearrange 90 | module=m, 91 | default_group_size=group_size, 92 | dtype=unified_dtype, 93 | prefetch=prefetch, 94 | verbose=verbose, 95 | inp=inp, 96 | step_fn=step_fn) 97 | 98 | private_group, public_groups = search_class.search(split_number, allocate_factor) 99 | chunk_plans = search_class.generate_chunk_plans(private_group, public_groups) 100 | # assign shard device 101 | for plan in chunk_plans: 102 | plan.kwargs['shard_device'] = shard_device 103 | 104 | chunk_group = search_class.allocate_chunk_group(chunk_plans) 105 | 106 | return SearchResult(chunk_group=chunk_group, 107 | chunk_plans=chunk_plans, 108 | param_called_per_step=search_class.param_per_step) 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Elixir (Gemini2.0) 2 | Elixir, also known as Gemini, is a technology designed to facilitate the training of large models on a small GPU cluster. 3 | Its goal is to eliminate data redundancy and leverage CPU memory to accommodate really large models. 4 | In addition, Elixir automatically profiles each training step prior to execution and selects the optimal configuration for the ratio of redundancy and the device for each parameter. 5 | This repository is used to benchmark the performance of Elixir. 6 | Elixir will be integrated into ColossalAI for usability. 7 | 8 | ## Environment 9 | 10 | This version is a beta release, so the running environment is somewhat restrictive. 11 | We are only demonstrating our running environment here, as we have not yet tested its compatibility. 
12 | We use CUDA `11.6` and PyTorch `1.13.1+cu11.6`. 13 | 14 | Three dependent packages should be installed from source. 15 | - [ColossalAI](https://github.com/hpcaitech/ColossalAI) (necessary): clone it and run `pip install .` from the latest master branch. 16 | - [Apex](https://github.com/NVIDIA/apex) (optional): clone it, check out tag `22.03`, and install it. 17 | - [Xformers](https://github.com/facebookresearch/xformers) (optional): clone it, check out tag `v0.0.17`, and install it. 18 | 19 | Finally, install all packages listed in `requirements.txt`. 20 | 21 | ## Tools 22 | 23 | ### CUDA Memory Profiling 24 | 25 | The function `cuda_memory_profiling` in `elixir.tracer.memory_tracer` helps you profile each kind of memory occupation during training. 26 | It reports the CUDA memory occupied by parameters and gradients, as well as the maximum size of the activations generated during training. 27 | Moreover, it is an efficient tool that can quickly profile the OPT-175B model on a single GPU. 28 | You can try it yourself with the `activation` folder in the `example` directory. 29 | 30 | (You should have at least 16GB of CUDA memory to run the OPT-175B example, but just give it a try first.) 31 | 32 | ### Hardware Performance Profiling 33 | 34 | See the folder `profile`. 35 | You can profile the aggregate bandwidth of GPU-CPU communication and the aggregate velocity of Adam optimizers. 36 | 37 | ## Examples 38 | 39 | Here is a simple example that wraps your model and optimizer for [fine-tuning](https://github.com/hpcaitech/Elixir/tree/main/example/fine-tune). 40 | 41 | ```python 42 | from elixir.search import minimum_waste_search 43 | from elixir.wrapper import ElixirModule, ElixirOptimizer 44 | 45 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased') 46 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-8) 47 | 48 | sr = minimum_waste_search(model, world_size) 49 | model = ElixirModule(model, sr, world_group) 50 | optimizer = ElixirOptimizer(model, optimizer) 51 | ``` 52 | 53 | Here is an advanced example tuned for performance, which is used in our [benchmark](https://github.com/hpcaitech/Elixir/blob/main/example/common/elx.py).
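It relies on two things you provide: an example input `inp` and a step function `step_fn` that runs one forward and backward pass, which the search profiles before training starts. A minimal sketch of these two pieces is shown below; the batch fields, shapes, and loss handling are illustrative assumptions for a Hugging Face style model rather than code from this repository, so adapt them to your own data and model.

```python
import torch


def train_step(model, inp: dict):
    # one forward/backward pass on the example batch; this is the step that gets profiled
    out = model(**inp)
    loss = out.loss if hasattr(out, 'loss') else out[0]
    loss.backward()


# an example batch in dictionary format, matching the keyword arguments of model.forward
data = dict(input_ids=torch.zeros(4, 128, dtype=torch.long),
            attention_mask=torch.ones(4, 128, dtype=torch.long),
            labels=torch.zeros(4, dtype=torch.long))
```

With `data` and `train_step` prepared, the full setup looks like this: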
54 | 55 | ```python 56 | import torch 57 | import torch.distributed as dist 58 | from colossalai.nn.optimizer import HybridAdam 59 | from elixir.search import optimal_search 60 | from elixir.wrapper import ElixirModule, ElixirOptimizer 61 | # get the world communication group 62 | global_group = dist.GroupMember.WORLD 63 | # get the communication world size 64 | global_size = dist.get_world_size() 65 | 66 | # initialize the model on CPU 67 | model = get_model(model_name) 68 | # HybridAdam allows part of the parameters to be updated on CPU and the rest on GPU 69 | optimizer = HybridAdam(model.parameters(), lr=1e-3) 70 | 71 | sr = optimal_search( 72 | model, 73 | global_size, 74 | unified_dtype=torch.float16, # enable for FP16 training 75 | overlap=True, # enable for overlapping communications 76 | verbose=True, # print detailed processing information 77 | inp=data, # provide an example input in dictionary format 78 | step_fn=train_step # provide an example step function 79 | ) 80 | model = ElixirModule( 81 | model, 82 | sr, 83 | global_group, 84 | prefetch=True, # prefetch chunks to overlap communications 85 | dtype=torch.float16, # use AMP 86 | use_fused_kernels=True # enable fused kernels in Apex 87 | ) 88 | optimizer = ElixirOptimizer( 89 | model, 90 | optimizer, 91 | initial_scale=64, # loss scale used in AMP 92 | init_step=True # enable for the stability of training 93 | ) 94 | ``` 95 | -------------------------------------------------------------------------------- /example/fine-tune/torch_ddp.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from datetime import datetime 3 | from time import time 4 | from typing import Optional 5 | 6 | import colossalai 7 | import datasets 8 | import torch 9 | import torch.distributed as dist 10 | from colossalai.logging import disable_existing_loggers, get_dist_logger 11 | from colossalai.nn import LinearWarmupLR 12 | from data_module import GLUEDataModule 13 | from func_module import evaluate, get_mem_info, get_tflops, seed_all, train 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup 16 | 17 | if __name__ == '__main__': 18 | disable_existing_loggers() 19 | colossalai.launch_from_torch(config={}) 20 | logger = get_dist_logger() 21 | world_size = dist.get_world_size() 22 | local_rank = dist.get_rank() 23 | 24 | parser = ArgumentParser() 25 | parser.add_argument('--task', default='mrpc') 26 | parser.add_argument('--epochs', type=int, default=3) 27 | parser.add_argument('--batch_size', type=int, default=32) 28 | parser.add_argument('--lr', type=float, default=2.4e-5) 29 | parser.add_argument('--weight_decay', type=float, default=0.01) 30 | parser.add_argument('--warmup_fraction', type=float, default=0.1) 31 | args = parser.parse_args() 32 | 33 | assert args.batch_size % world_size == 0 34 | global_batch_size = args.batch_size 35 | local_batch_size = args.batch_size // world_size 36 | 37 | global_seed = 3407 38 | seed_all(global_seed) 39 | logger.info('Random is set to {} in all processes.'.format(global_seed), ranks=[0]) 40 | 41 | model_name = 'bert-base-uncased' 42 | logger.info('Data is preparing now.', ranks=[0]) 43 | dm = GLUEDataModule(model_name_or_path=model_name, 44 | task_name=args.task, 45 | train_batch_size=local_batch_size, 46 | eval_batch_size=global_batch_size) 47 | dm.setup('fit') 48 | 49 | config = AutoConfig.from_pretrained(model_name, num_labels=dm.num_labels) 50 | metric =
datasets.load_metric('glue', dm.task_name, experiment_id=datetime.now().strftime('%d-%m-%Y_%H-%M-%S')) 51 | 52 | logger.info('Model is creating now.', ranks=[0]) 53 | model = BertForSequenceClassification.from_pretrained(model_name, config=config) 54 | numel = sum([p.numel() for p in model.parameters()]) 55 | logger.info(f'Model numel: {numel}', ranks=[0]) 56 | logger.info(get_mem_info(), ranks=[0]) 57 | 58 | model = model.cuda() 59 | model = DDP(model, device_ids=[local_rank]) 60 | 61 | logger.info('Optimizer is creating now.', ranks=[0]) 62 | no_decay = ['bias', 'LayerNorm.weight'] 63 | optimizer_grouped_parameters = [ 64 | { 65 | 'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 66 | 'weight_decay': args.weight_decay, 67 | }, 68 | { 69 | 'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 70 | 'weight_decay': 0.0, 71 | }, 72 | ] 73 | 74 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=args.lr, eps=1e-8) 75 | 76 | logger.info('Dataloder is creating now.', ranks=[0]) 77 | train_loader, train_sampler = dm.train_loader_and_sampler() 78 | valid_loader = dm.val_loader_and_sampler() 79 | 80 | logger.info('Learning rate scheduler is creating now.', ranks=[0]) 81 | num_epoch = args.epochs 82 | steps_per_epoch = len(train_loader) 83 | num_all_steps = num_epoch * steps_per_epoch 84 | num_warm_steps = int(num_all_steps * args.warmup_fraction) 85 | lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, 86 | num_warmup_steps=num_warm_steps, 87 | num_training_steps=num_all_steps) 88 | 89 | for epoch in range(num_epoch): 90 | logger.info('Epoch {} starts'.format(epoch), ranks=[0]) 91 | dist.barrier() 92 | train(epoch=epoch, 93 | sampler=train_sampler, 94 | model=model, 95 | loader=train_loader, 96 | optimizer=optimizer, 97 | lr_scheduler=lr_scheduler, 98 | show_progress=local_rank == 0) 99 | percentage, f1 = evaluate(model=model, metric=metric, loader=valid_loader, show_progress=local_rank == 0) 100 | logger.info('valid correct percentage: {:.4f}\nf1: {:.4f}'.format(percentage, f1), ranks=[0]) 101 | -------------------------------------------------------------------------------- /elixir/search/simulator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from .utils import to_divide 4 | 5 | 6 | class LinkedSet(object): 7 | 8 | def __init__(self) -> None: 9 | super().__init__() 10 | self.fifo_dict = dict() 11 | 12 | def __len__(self) -> int: 13 | return len(self.fifo_dict) 14 | 15 | def __contains__(self, x: int) -> bool: 16 | return x in self.fifo_dict 17 | 18 | def full(self, n: int): 19 | return len(self.fifo_dict) >= n 20 | 21 | def push(self, x: int): 22 | assert x not in self.fifo_dict 23 | self.fifo_dict[x] = True 24 | 25 | def pop_value(self, x: int): 26 | assert x in self.fifo_dict 27 | self.fifo_dict.pop(x) 28 | 29 | def pop_left(self): 30 | x = next(iter(self.fifo_dict)) 31 | self.fifo_dict.pop(x) 32 | 33 | 34 | def calc_move_times(param_per_step: list, param_to_chunk: dict, n_blocks: int): 35 | from elixir.c_utils import move_count 36 | chunk_per_step = list() 37 | 38 | for param_set in param_per_step: 39 | id_set = set() 40 | for name in param_set: 41 | # continue if the parameter is ignored 42 | if name not in param_to_chunk: 43 | continue 44 | id_set.add(param_to_chunk[name]) 45 | if len(id_set) > 0: 46 | chunk_per_step.append(list(id_set)) 47 | 48 | return move_count(chunk_per_step, n_blocks) 49 | 50 | 51 | def 
find_optimal_chunk_size( 52 | # pre-commit: do not rearrange 53 | param_per_step: list, 54 | param_names: list, 55 | param_numels: list, 56 | cuda_elements: int, 57 | overlap: bool, 58 | min_range: int, 59 | max_range: int, 60 | interval: int): 61 | 62 | max_numel = 0 63 | for numel in param_numels: 64 | max_numel = max(max_numel, numel) 65 | test_size = to_divide(max(max_numel, min_range), interval) 66 | # floor rounding 67 | cuda_elements = to_divide(cuda_elements - interval + 1, interval) 68 | max_range = min(max_range, cuda_elements) 69 | 70 | min_move_elements = float('+inf') 71 | best_size = test_size 72 | best_number_blocks = 0 73 | best_waste = 0 74 | 75 | def dispatch_chunks(param_to_chunk: dict, block_size: int) -> tuple: 76 | chunk_id = 0 77 | acc = 0 78 | left = 0 79 | for (name, numel) in zip(param_names, param_numels): 80 | if numel > left: 81 | acc += left 82 | chunk_id += 1 83 | left = block_size 84 | left -= numel 85 | param_to_chunk[name] = chunk_id 86 | return (chunk_id, left + acc) 87 | 88 | assert test_size <= max_range, 'max_numel or min_range is larger than max_range or cuda capacity' 89 | while test_size <= max_range: 90 | # calculate the number of blocks 91 | number_blocks = int(cuda_elements // test_size) 92 | # if prefetch is enabled, we pretend that two chunks are reserved 93 | if overlap: 94 | number_blocks -= 2 95 | if number_blocks <= 0: 96 | break # larger test sizes only shrink the block count, so the search can stop here 97 | # initialize the chunk id for each parameter 98 | param_to_chunk = dict() 99 | number_chunks, current_waste = dispatch_chunks(param_to_chunk, test_size) 100 | number_blocks = min(number_blocks, number_chunks) 101 | # calculate the minimum number of movements 102 | move_times = calc_move_times(param_per_step, param_to_chunk, number_blocks) 103 | 104 | current_move_elements = move_times * test_size 105 | # print("test", test_size, current_move_elements) 106 | if current_move_elements < min_move_elements: 107 | min_move_elements = current_move_elements 108 | best_size = test_size 109 | best_number_blocks = number_blocks 110 | best_waste = current_waste 111 | 112 | test_size += interval 113 | 114 | if min_move_elements == float('inf'): 115 | raise RuntimeError('optimal search: can not find a valid solution') 116 | 117 | return best_size, best_number_blocks, best_waste 118 | 119 | 120 | def bandwidth_c2g(n: int): 121 | return 16.3 * n + 8.7 122 | 123 | 124 | def bandwidth_g2c(n: int): 125 | return 15.8 * n + 2.3 126 | 127 | 128 | def velocity_gpu(n: int): 129 | return 50 * n 130 | 131 | 132 | def velocity_cpu(n: int): 133 | return 1.66 * math.log(n) + 5.15 134 | 135 | 136 | def rcache_prioirity_check(n: int, r_os: int, e_p: int, e_o: int): 137 | In = e_p / bandwidth_c2g(n) + e_p / bandwidth_g2c(n) 138 | Jn = (n / r_os) * (e_o / bandwidth_c2g(n) + In + e_p / bandwidth_g2c(n) + 1.0 / velocity_cpu(n) - 139 | 1.0 / velocity_gpu(n)) 140 | return In > Jn 141 | -------------------------------------------------------------------------------- /test/tracer/test_op_cache.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from elixir.tracer.memory_tracer import MTensor 5 | from elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache 6 | from elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated 7 | 8 | 9 | def op_mm(x, y): 10 | u = torch.matmul(x, y) 11 | return u.shape 12 | 13 | 14 | def op_addmm(x, y, z): 15 | u = torch.addmm(x, y, z) 16 | return u.shape 17 | 18 | 19 | def op_bmm(x, y): 20 | u = torch.bmm(x, y)
21 | return u.shape 22 | 23 | 24 | @pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16]) 25 | def test_mm(dtype, size0=(4, 256), size1=(256, 1024)): 26 | torch.cuda.reset_peak_memory_stats() 27 | assert get_cuda_allocated() == 0 28 | 29 | x = torch.randn(size0, dtype=dtype, device='cuda') 30 | y = torch.randn(size1, dtype=dtype, device='cuda') 31 | torch_pre_alc = get_cuda_allocated() 32 | 33 | torch_z_size = op_mm(x, y) 34 | torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc 35 | 36 | del x 37 | del y 38 | 39 | assert get_cuda_allocated() == 0 40 | x = MTensor(torch.randn(size0, dtype=dtype, device='cuda')) 41 | y = MTensor(torch.randn(size1, dtype=dtype, device='cuda')) 42 | op1_pre_alc = get_cuda_allocated() 43 | 44 | MTensor.reset_peak_memory() 45 | op1_z_size = op_mm(x, y) 46 | op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc 47 | 48 | assert torch_z_size == op1_z_size 49 | assert torch_pre_alc == op1_pre_alc 50 | assert torch_temp_alc == op1_temp_alc 51 | assert len(mm_cache.temp_memory) > 0 52 | 53 | MTensor.reset_peak_memory() 54 | op2_z_size = op_mm(x, y) 55 | op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc 56 | 57 | assert torch_z_size == op2_z_size 58 | assert torch_temp_alc == op2_temp_alc 59 | 60 | 61 | @pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16]) 62 | def test_addmm(dtype, size0=(4, 16), size1=(16, 64)): 63 | torch.cuda.reset_peak_memory_stats() 64 | assert get_cuda_allocated() == 0 65 | 66 | x = torch.randn(size0, dtype=dtype, device='cuda') 67 | y = torch.randn(size1, dtype=dtype, device='cuda') 68 | u = torch.randn(size1[-1], dtype=dtype, device='cuda') 69 | torch_pre_alc = get_cuda_allocated() 70 | 71 | torch_z_size = op_addmm(u, x, y) 72 | torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc 73 | 74 | del x 75 | del y 76 | del u 77 | 78 | assert get_cuda_allocated() == 0 79 | x = MTensor(torch.randn(size0, dtype=dtype, device='cuda')) 80 | y = MTensor(torch.randn(size1, dtype=dtype, device='cuda')) 81 | u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda')) 82 | op1_pre_alc = get_cuda_allocated() 83 | 84 | MTensor.reset_peak_memory() 85 | op1_z_size = op_addmm(u, x, y) 86 | op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc 87 | 88 | assert torch_z_size == op1_z_size 89 | assert torch_pre_alc == op1_pre_alc 90 | assert torch_temp_alc == op1_temp_alc 91 | assert len(addmm_cache.temp_memory) > 0 92 | 93 | MTensor.reset_peak_memory() 94 | op2_z_size = op_addmm(u, x, y) 95 | op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc 96 | 97 | assert torch_z_size == op2_z_size 98 | assert torch_temp_alc == op2_temp_alc 99 | 100 | 101 | @pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16]) 102 | def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)): 103 | torch.cuda.reset_peak_memory_stats() 104 | assert get_cuda_allocated() == 0 105 | 106 | x = torch.randn(size0, dtype=dtype, device='cuda') 107 | y = torch.randn(size1, dtype=dtype, device='cuda') 108 | torch_pre_alc = get_cuda_allocated() 109 | 110 | torch_z_size = op_bmm(x, y) 111 | torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc 112 | 113 | del x 114 | del y 115 | 116 | assert get_cuda_allocated() == 0 117 | x = MTensor(torch.randn(size0, dtype=dtype, device='cuda')) 118 | y = MTensor(torch.randn(size1, dtype=dtype, device='cuda')) 119 | op1_pre_alc = get_cuda_allocated() 120 | 121 | MTensor.reset_peak_memory() 122 | op1_z_size = op_bmm(x, y) 123 | op1_temp_alc = 
MTensor.current_peak_memory() - op1_pre_alc 124 | 125 | assert torch_z_size == op1_z_size 126 | assert torch_pre_alc == op1_pre_alc 127 | assert torch_temp_alc == op1_temp_alc 128 | assert len(bmm_cache.temp_memory) > 0 129 | 130 | bmm_cache.print() 131 | 132 | MTensor.reset_peak_memory() 133 | op2_z_size = op_bmm(x, y) 134 | op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc 135 | 136 | assert torch_z_size == op2_z_size 137 | assert torch_temp_alc == op2_temp_alc 138 | 139 | 140 | if __name__ == '__main__': 141 | test_addmm(dtype=torch.float) 142 | -------------------------------------------------------------------------------- /example/fine-tune/elixir_mini.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from datetime import datetime 3 | from time import time 4 | from typing import Optional 5 | 6 | import colossalai 7 | import datasets 8 | import torch 9 | import torch.distributed as dist 10 | from colossalai.logging import disable_existing_loggers, get_dist_logger 11 | from colossalai.nn import LinearWarmupLR 12 | from data_module import GLUEDataModule 13 | from func_module import evaluate, get_mem_info, get_tflops, seed_all, train 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup 16 | 17 | from elixir.search import minimum_waste_search 18 | from elixir.wrapper import ElixirModule, ElixirOptimizer 19 | 20 | if __name__ == '__main__': 21 | disable_existing_loggers() 22 | colossalai.launch_from_torch(config={}) 23 | logger = get_dist_logger() 24 | world_size = dist.get_world_size() 25 | world_group = dist.GroupMember.WORLD 26 | local_rank = dist.get_rank() 27 | 28 | parser = ArgumentParser() 29 | parser.add_argument('--task', default='mrpc') 30 | parser.add_argument('--epochs', type=int, default=3) 31 | parser.add_argument('--batch_size', type=int, default=32) 32 | parser.add_argument('--lr', type=float, default=2.4e-5) 33 | parser.add_argument('--weight_decay', type=float, default=0.01) 34 | parser.add_argument('--warmup_fraction', type=float, default=0.1) 35 | args = parser.parse_args() 36 | 37 | assert args.batch_size % world_size == 0 38 | global_batch_size = args.batch_size 39 | local_batch_size = args.batch_size // world_size 40 | 41 | global_seed = 3407 42 | seed_all(global_seed) 43 | logger.info('Random is set to {} in all processes.'.format(global_seed), ranks=[0]) 44 | 45 | model_name = 'bert-base-uncased' 46 | logger.info('Data is preparing now.', ranks=[0]) 47 | dm = GLUEDataModule(model_name_or_path=model_name, 48 | task_name=args.task, 49 | train_batch_size=local_batch_size, 50 | eval_batch_size=global_batch_size) 51 | dm.setup('fit') 52 | 53 | config = AutoConfig.from_pretrained(model_name, num_labels=dm.num_labels) 54 | metric = datasets.load_metric('glue', dm.task_name, experiment_id=datetime.now().strftime('%d-%m-%Y_%H-%M-%S')) 55 | 56 | logger.info('Model is creating now.', ranks=[0]) 57 | model = BertForSequenceClassification.from_pretrained(model_name, config=config) 58 | numel = sum([p.numel() for p in model.parameters()]) 59 | logger.info(f'Model numel: {numel}', ranks=[0]) 60 | logger.info(get_mem_info(), ranks=[0]) 61 | 62 | logger.info('Optimizer is creating now.', ranks=[0]) 63 | no_decay = ['bias', 'LayerNorm.weight'] 64 | optimizer_grouped_parameters = [ 65 | { 66 | 'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in 
no_decay)], 67 | 'weight_decay': args.weight_decay, 68 | }, 69 | { 70 | 'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 71 | 'weight_decay': 0.0, 72 | }, 73 | ] 74 | 75 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=args.lr, eps=1e-8) 76 | 77 | sr = minimum_waste_search(model, world_size, torch.float32, verbose=True) 78 | model = ElixirModule(model, sr, world_group, dtype=torch.float32, use_fused_kernels=True) 79 | optimizer = ElixirOptimizer(model, optimizer, init_step=True) 80 | 81 | logger.info('Dataloder is creating now.', ranks=[0]) 82 | train_loader, train_sampler = dm.train_loader_and_sampler() 83 | valid_loader = dm.val_loader_and_sampler() 84 | 85 | logger.info('Learning rate scheduler is creating now.', ranks=[0]) 86 | num_epoch = args.epochs 87 | steps_per_epoch = len(train_loader) 88 | num_all_steps = num_epoch * steps_per_epoch 89 | num_warm_steps = int(num_all_steps * args.warmup_fraction) 90 | lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, 91 | num_warmup_steps=num_warm_steps, 92 | num_training_steps=num_all_steps) 93 | 94 | for epoch in range(num_epoch): 95 | logger.info('Epoch {} starts'.format(epoch), ranks=[0]) 96 | dist.barrier() 97 | train(epoch=epoch, 98 | sampler=train_sampler, 99 | model=model, 100 | loader=train_loader, 101 | optimizer=optimizer, 102 | lr_scheduler=lr_scheduler, 103 | show_progress=local_rank == 0, 104 | optimizer_backward=True) 105 | percentage, f1 = evaluate(model=model, metric=metric, loader=valid_loader, show_progress=local_rank == 0) 106 | logger.info('valid correct percentage: {:.4f}\nf1: {:.4f}'.format(percentage, f1), ranks=[0]) 107 | -------------------------------------------------------------------------------- /elixir/tracer/param_tracer/td_order.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import uuid 3 | from typing import Callable, Dict, Iterator, List, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils._pytree import tree_map 8 | 9 | from elixir.tracer.ops import SameStorageAten 10 | from elixir.tracer.utils import meta_copy 11 | 12 | 13 | @contextlib.contextmanager 14 | def no_dispatch() -> Iterator[None]: 15 | guard = torch._C._DisableTorchDispatch() 16 | try: 17 | yield 18 | finally: 19 | del guard 20 | 21 | 22 | def normalize_tuple(x): 23 | if not isinstance(x, tuple): 24 | return (x,) 25 | return x 26 | 27 | 28 | def register_storage(x): 29 | assert isinstance(x, nn.Parameter) 30 | assert x.data_ptr() == 0 31 | 32 | data_ptr = uuid.uuid1() 33 | x.data_ptr = lambda: data_ptr 34 | 35 | 36 | class ATensor(torch.Tensor): 37 | elem: torch.Tensor 38 | 39 | __slots__ = ['elem'] 40 | 41 | data_ptr_dict: Dict[int, Tuple[str, nn.Parameter]] = None 42 | order_list: List[Dict] = None 43 | 44 | @staticmethod 45 | def reset(): 46 | ATensor.data_ptr_dict = dict() 47 | ATensor.order_list = list() 48 | 49 | @staticmethod 50 | def clear(): 51 | ATensor.data_ptr_dict = None 52 | ATensor.order_list = None 53 | 54 | @staticmethod 55 | def add_data_ptr(name: str, param: nn.Parameter): 56 | data_ptr = param.data_ptr() 57 | if data_ptr not in ATensor.data_ptr_dict: 58 | ATensor.data_ptr_dict[data_ptr] = (name, param) 59 | else: 60 | name_in, param_in = ATensor.data_ptr_dict[data_ptr] 61 | if name != name_in or id(param) != id(param_in): 62 | raise RuntimeError('Got two different parameters with the same data ptr') 63 | 64 | @staticmethod 65 | def get_param(data_ptr: int): 66 | if 
data_ptr in ATensor.data_ptr_dict: 67 | return ATensor.data_ptr_dict.get(data_ptr) 68 | else: 69 | return None, None 70 | 71 | @staticmethod 72 | def __new__(cls, elem, *args, **kwargs): 73 | r = torch.Tensor._make_wrapper_subclass( 74 | cls, 75 | elem.size(), 76 | strides=elem.stride(), 77 | storage_offset=elem.storage_offset(), 78 | # TODO: clone strides and storage aliasing 79 | dtype=elem.dtype, 80 | layout=elem.layout, 81 | device=elem.device, 82 | requires_grad=elem.requires_grad) 83 | r.elem = elem 84 | return r 85 | 86 | def __repr__(self): 87 | return f'ATensor({self.elem})' 88 | 89 | @classmethod 90 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 91 | step_dict = dict() 92 | 93 | def record_param(x): 94 | if isinstance(x, torch.Tensor): 95 | name, param = ATensor.get_param(x.data_ptr()) 96 | if name is not None: 97 | step_dict[name] = param 98 | 99 | def debug_tensor(x): 100 | if isinstance(x, torch.Tensor): 101 | print(type(x), x.shape, x.data_ptr(), id(x)) 102 | if x.grad_fn: 103 | print(x.grad_fn) 104 | 105 | tree_map(record_param, args) 106 | if len(step_dict) > 0: 107 | ATensor.order_list.append(step_dict) 108 | del step_dict 109 | 110 | def unwrap(x): 111 | return x.elem if isinstance(x, ATensor) else x 112 | 113 | def wrap(x): 114 | return ATensor(x) if isinstance(x, torch.Tensor) else x 115 | 116 | with no_dispatch(): 117 | res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 118 | outs = normalize_tuple(res) 119 | res = tree_map(wrap, outs) 120 | 121 | if func in SameStorageAten: 122 | for x in res: 123 | if isinstance(x, torch.Tensor): 124 | x.data_ptr = args[0].data_ptr 125 | 126 | if len(res) == 1: 127 | return res[0] 128 | else: 129 | return res 130 | 131 | 132 | def generate_td_order(model: nn.Module, inp: Union[torch.Tensor, Tuple], step_fn: Callable): 133 | ATensor.reset() 134 | 135 | def tensor_trans(t): 136 | meta_t = ATensor(t.data.to('meta')) 137 | if isinstance(t, nn.Parameter): 138 | meta_t = nn.Parameter(meta_t) 139 | return meta_t 140 | 141 | model = meta_copy(model, tensor_trans) 142 | for name, param in model.named_parameters(): 143 | register_storage(param) 144 | ATensor.add_data_ptr(name, param) 145 | 146 | # convert all input data to meta_tensor 147 | if not isinstance(inp, tuple): 148 | inp = (inp,) 149 | inp = tree_map(lambda t: ATensor(torch.empty_like(t, device='meta', requires_grad=t.requires_grad)), inp) 150 | 151 | step_fn(model, inp) 152 | 153 | ret = ATensor.order_list 154 | ATensor.clear() 155 | 156 | return ret 157 | -------------------------------------------------------------------------------- /elixir/chunk/core/memory_pool.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from collections import defaultdict 3 | from typing import Iterable, NamedTuple 4 | 5 | import torch 6 | from torch.autograd.profiler_util import _format_memory 7 | 8 | 9 | class BlockRequire(NamedTuple): 10 | numel: int 11 | dtype: torch.dtype 12 | 13 | 14 | class TensorBlock(ABC): 15 | total_count: int = 0 16 | 17 | def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: 18 | self.block_id = TensorBlock.total_count 19 | TensorBlock.total_count += 1 20 | 21 | self.device_type = device_type 22 | self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type) 23 | self.memo_occ: int = self.payload.numel() * self.payload.element_size() 24 | 25 | @property 26 | def numel(self): 27 | return self.payload.numel() 28 | 29 | @property 30 | def 
dtype(self): 31 | return self.payload.dtype 32 | 33 | @property 34 | def device(self): 35 | return self.payload.device 36 | 37 | def __hash__(self) -> int: 38 | return self.block_id 39 | 40 | def __eq__(self, other: object) -> bool: 41 | return self.block_id == other.block_id 42 | 43 | def __repr__(self) -> str: 44 | return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})' 45 | 46 | 47 | class PublicBlock(TensorBlock): 48 | 49 | def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: 50 | super().__init__(numel, dtype, device_type) 51 | self.block_type = 'public' 52 | 53 | def __repr__(self) -> str: 54 | return f'PublicBlock{super().__repr__()}' 55 | 56 | 57 | class PrivateBlock(TensorBlock): 58 | 59 | def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: 60 | super().__init__(numel, dtype, device_type) 61 | self.block_type = 'private' 62 | 63 | def __repr__(self) -> str: 64 | return f'PrivateBlock{super().__repr__()}' 65 | 66 | 67 | class MemoryPool(object): 68 | 69 | def __init__(self, device_type: str) -> None: 70 | self.device_type: str = device_type 71 | 72 | self.public_space: int = 0 73 | self.public_block_size: int = 0 74 | self.public_dtype: torch.dtype = None 75 | 76 | self.public_free_blocks: list = None 77 | self.public_used_blocks: set = None 78 | 79 | self.public_free_cnt: int = 0 80 | self.public_used_cnt: int = 0 81 | 82 | self.private_space: int = 0 83 | self.private_blocks: list = None 84 | self.private_lookup_dict: dict[BlockRequire, list] = None 85 | 86 | self.__allocate_flag = False 87 | 88 | def allocate(self, 89 | public_dtype: torch.dtype = torch.float, 90 | public_block_size: int = 1024, 91 | public_block_number: int = 0, 92 | private_block_list: Iterable[BlockRequire] = ()): 93 | assert self.__allocate_flag is False 94 | assert public_block_number >= 0 95 | 96 | self.public_free_blocks = list() 97 | self.public_used_blocks = set() 98 | for _ in range(public_block_number): 99 | block = PublicBlock(public_block_size, public_dtype, self.device_type) 100 | self.public_free_blocks.append(block) 101 | 102 | if public_block_number <= 0: 103 | self.public_space = 0 104 | else: 105 | self.public_space = self.public_free_blocks[0].memo_occ * public_block_number 106 | self.public_block_size = public_block_size 107 | self.public_dtype = public_dtype 108 | 109 | self.public_free_cnt = public_block_number 110 | self.public_used_cnt = 0 111 | 112 | self.private_space = 0 113 | self.private_blocks = list() 114 | self.private_lookup_dict = defaultdict(list) 115 | 116 | for require in private_block_list: 117 | block = PrivateBlock(require.numel, require.dtype, self.device_type) 118 | self.private_space += block.memo_occ 119 | self.private_blocks.append(block) 120 | self.private_lookup_dict[require].append(block) 121 | 122 | self.__allocate_flag = True 123 | 124 | def __repr__(self) -> str: 125 | return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})' 126 | 127 | def get_private_block(self, numel: int, dtype: torch.dtype): 128 | block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype)) 129 | return block_list.pop() 130 | 131 | def get_public_block(self): 132 | self.public_free_cnt -= 1 133 | self.public_used_cnt += 1 134 | 135 | block = self.public_free_blocks.pop() 136 | self.public_used_blocks.add(block) 137 | 138 | return block 139 | 140 | def free_public_block(self, block: TensorBlock): 141 
| assert isinstance(block, PublicBlock) 142 | assert block in self.public_used_blocks 143 | 144 | self.public_free_cnt += 1 145 | self.public_used_cnt -= 1 146 | 147 | self.public_used_blocks.remove(block) 148 | self.public_free_blocks.append(block) 149 | 150 | return block 151 | -------------------------------------------------------------------------------- /test/chunk/test_chunk.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import pytest 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState 9 | from elixir.utils import init_distributed 10 | 11 | 12 | def exam_chunk_functions(nproc, group): 13 | a = torch.randn(2, 64, device='cuda') 14 | copy_a = a.clone() 15 | b = torch.randn(2, 2, 128, device='cuda') 16 | copy_b = b.clone() 17 | c = torch.randn(128, device='cuda') 18 | copy_c = c.clone() 19 | d = torch.randn(4, 32, device='cuda') 20 | copy_d = d.clone() 21 | 22 | mp = MemoryPool('cuda') 23 | mp.allocate(public_block_number=1) 24 | 25 | chunk = Chunk(mp, 1024, torch.float, group) 26 | chunk.l2_norm_flag = True 27 | assert chunk.chunk_size == 1024 28 | assert chunk.chunk_dtype == torch.float 29 | assert chunk.shard_size == 1024 // nproc 30 | 31 | def check_tensors(): 32 | assert torch.equal(a, copy_a) 33 | assert torch.equal(b, copy_b) 34 | assert torch.equal(c, copy_c) 35 | assert torch.equal(d, copy_d) 36 | 37 | chunk.append_tensor(a) 38 | chunk.append_tensor(b) 39 | chunk.append_tensor(c) 40 | chunk.append_tensor(d) 41 | check_tensors() 42 | 43 | chunk.close_chunk() 44 | assert chunk.is_replica is False 45 | # check function: get_cpu_copy 46 | cpu_copys = chunk.get_cpu_copy() 47 | for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys): 48 | assert t_cpu.device.type == 'cpu' 49 | assert torch.equal(t_gpu.cpu(), t_cpu) 50 | # check function: access_chunk 51 | block = mp.get_public_block() 52 | chunk.access_chunk(block) 53 | assert chunk.is_replica 54 | assert chunk.scatter_check 55 | check_tensors() 56 | # check function: release_chunk 57 | chunk.optim_sync_flag = False 58 | block = chunk.release_chunk() 59 | assert block in mp.public_used_blocks 60 | assert chunk.is_replica is False 61 | assert chunk.optim_sync_flag is True 62 | # check function: access_chunk after release_chunk 63 | chunk.access_chunk(block) 64 | check_tensors() 65 | # check function: reduce_chunk 66 | norm = block.payload.float().norm(2)**2 67 | chunk.reduce_chunk() 68 | assert chunk.is_replica is False 69 | assert chunk.tensor_state_cnter[TensorState.HOLD] == 4 70 | 71 | test_norm = torch.Tensor([chunk.l2_norm]).cuda() 72 | dist.all_reduce(test_norm) 73 | assert torch.allclose(norm, test_norm) 74 | 75 | torch.cuda.synchronize() 76 | print('chunk functions are ok') 77 | 78 | 79 | def exam_chunk_states(nproc, group): 80 | a = torch.randn(2, 64, device='cuda') 81 | copy_a = a.clone() 82 | b = torch.randn(2, 2, 128, device='cuda') 83 | copy_b = b.clone() 84 | c = torch.randn(128, device='cuda') 85 | copy_c = c.clone() 86 | d = torch.randn(4, 32, device='cuda') 87 | copy_d = d.clone() 88 | 89 | private = [BlockRequire(1024, torch.float)] 90 | mp = MemoryPool('cuda') 91 | mp.allocate(private_block_list=private) 92 | 93 | chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True) 94 | assert chunk.chunk_size == 1024 95 | assert chunk.chunk_dtype == torch.float 96 | assert chunk.shard_size == 1024 // nproc 97 | 98 | def check_tensors(): 
99 | assert torch.equal(a, copy_a) 100 | assert torch.equal(b, copy_b) 101 | assert torch.equal(c, copy_c) 102 | assert torch.equal(d, copy_d) 103 | 104 | chunk.append_tensor(a) 105 | chunk.append_tensor(b) 106 | chunk.append_tensor(c) 107 | chunk.append_tensor(d) 108 | check_tensors() 109 | 110 | chunk.close_chunk() 111 | assert chunk.is_replica is False 112 | 113 | chunk.access_chunk() 114 | assert chunk.is_replica 115 | check_tensors() 116 | 117 | assert chunk.tensor_state_cnter[TensorState.HOLD] == 4 118 | chunk.tensor_trans_state(a, TensorState.COMPUTE) 119 | assert chunk.tensor_state_cnter[TensorState.HOLD] == 3 120 | assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1 121 | 122 | tensor_list = [a, b, c, d] 123 | for t in tensor_list: 124 | chunk.tensor_trans_state(t, TensorState.COMPUTE) 125 | chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD) 126 | chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE) 127 | assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 128 | assert chunk.reduce_check 129 | 130 | torch.cuda.synchronize() 131 | print('chunk states are ok') 132 | 133 | 134 | def run_dist(rank, world_size): 135 | os.environ['RANK'] = str(rank) 136 | os.environ['LOCAL_RANK'] = str(rank) 137 | os.environ['WORLD_SIZE'] = str(world_size) 138 | os.environ['MASTER_ADDR'] = '127.0.0.1' 139 | os.environ['MASTER_PORT'] = str(29512) 140 | init_distributed() 141 | exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD) 142 | exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD) 143 | 144 | 145 | @pytest.mark.dist 146 | @pytest.mark.parametrize('world_size', [1, 2, 4]) 147 | def test_chunk_functions(world_size): 148 | run_func = partial(run_dist, world_size=world_size) 149 | torch.multiprocessing.spawn(run_func, nprocs=world_size) 150 | 151 | 152 | if __name__ == '__main__': 153 | test_chunk_functions(world_size=4) 154 | --------------------------------------------------------------------------------