├── utils ├── __init__.py ├── easy_context │ ├── ulysses_attn │ │ ├── __pycache__ │ │ │ ├── monkey_patch.cpython-310.pyc │ │ │ └── prepare_inputs.cpython-310.pyc │ │ ├── prepare_inputs.py │ │ └── monkey_patch.py │ ├── dist_flash_attn │ │ ├── __pycache__ │ │ │ ├── monkey_patch.cpython-310.pyc │ │ │ ├── prepare_input.cpython-310.pyc │ │ │ ├── async_communication.cpython-310.pyc │ │ │ └── lightseq_async_attn.cpython-310.pyc │ │ ├── prepare_input.py │ │ ├── async_communication.py │ │ ├── monkey_patch.py │ │ ├── lightseq_async_attn.py │ │ └── lightseq_async_attn_varlen.py │ ├── zigzag_ring_attn │ │ ├── __pycache__ │ │ │ ├── monkey_patch.cpython-310.pyc │ │ │ └── prepare_inputs.cpython-310.pyc │ │ ├── prepare_inputs.py │ │ └── monkey_patch.py │ ├── unsloth_offloaded_gradient_checkpoint │ │ ├── __pycache__ │ │ │ └── monkey_patch.cpython-310.pyc │ │ └── monkey_patch.py │ └── __init__.py ├── accelerate_configs │ ├── single_node.yaml │ ├── single_node_2.yaml │ ├── zero3_offload.json │ └── zero3_offload_stage2.json ├── logger.py ├── logits_compute │ ├── readme.md │ └── logits_compute.py ├── loader.py ├── preprocess_data.py └── train.py ├── preprocess_token_PI ├── __init__.py ├── dataprocessor.py └── FSProcessor.py ├── requirements.txt ├── train_LR_llama3_target80k_use24k.sh └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess_token_PI/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/easy_context/ulysses_attn/__pycache__/monkey_patch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/ulysses_attn/__pycache__/monkey_patch.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/__pycache__/monkey_patch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/dist_flash_attn/__pycache__/monkey_patch.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/ulysses_attn/__pycache__/prepare_inputs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/ulysses_attn/__pycache__/prepare_inputs.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/__pycache__/prepare_input.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/dist_flash_attn/__pycache__/prepare_input.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/zigzag_ring_attn/__pycache__/monkey_patch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/zigzag_ring_attn/__pycache__/monkey_patch.cpython-310.pyc 
-------------------------------------------------------------------------------- /utils/easy_context/zigzag_ring_attn/__pycache__/prepare_inputs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/zigzag_ring_attn/__pycache__/prepare_inputs.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/__pycache__/async_communication.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/dist_flash_attn/__pycache__/async_communication.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/__pycache__/lightseq_async_attn.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/dist_flash_attn/__pycache__/lightseq_async_attn.cpython-310.pyc -------------------------------------------------------------------------------- /utils/easy_context/unsloth_offloaded_gradient_checkpoint/__pycache__/monkey_patch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhiyuanhubj/LongRecipe/HEAD/utils/easy_context/unsloth_offloaded_gradient_checkpoint/__pycache__/monkey_patch.cpython-310.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.28.0 2 | datasets==2.18.0 3 | einops==0.8.0 4 | torch==2.1.2 5 | flash_attn==2.5.8 6 | numpy==2.1.0 7 | pytest==8.1.1 8 | ring_flash_attn@git+https://github.com/zhuzilin/ring-flash-attention 9 | tiktoken==0.5.2 10 | tqdm==4.65.0 11 | transformers==4.39.1 12 | triton==2.1.0 13 | xformers==0.0.23.post1 14 | deepspeed==0.14.0 15 | yunchang==0.1 16 | -------------------------------------------------------------------------------- /utils/accelerate_configs/single_node.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: utils/accelerate_configs/zero3_offload.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | machine_rank: 0 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 1 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /utils/accelerate_configs/single_node_2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: utils/accelerate_configs/zero3_offload_stage2.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | machine_rank: 0 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 1 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | 
-------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tqdm import tqdm 3 | 4 | class TqdmToLogger(object): 5 | def __init__(self, logger, level=logging.INFO): 6 | self.logger = logger 7 | self.level = level 8 | self.terminal = tqdm(total=0, position=0, file=open('/dev/null', 'w')) # use /dev/null to discard tqdm's default console output 9 | 10 | def write(self, message): 11 | # tqdm appends '\r' to return to the start of the line; strip it before logging 12 | if message.rstrip() != '': 13 | self.logger.log(self.level, message.rstrip('\r')) 14 | 15 | def flush(self): 16 | pass 17 | 18 | -------------------------------------------------------------------------------- /utils/logits_compute/readme.md: -------------------------------------------------------------------------------- 1 | ## Logits Computation Based on a Llama3-8B Example 2 | 3 | We reuse vllm's existing next-token prediction and modify its `sampler.py` to obtain the logits for the input prompt. The version of vllm we are using is [v0.4.0](https://github.com/vllm-project/vllm/tree/v0.4.0). After successfully installing vllm, replace the original `sampler.py` file at `/model_executor/layers/sampler.py` with the modified version. 4 | 5 | 6 | You can then execute the following command, adjusting the file-saving paths as needed. 7 | ``` 8 | python logits_compute.py 9 | ``` 10 | 11 | This produces three files: `prompt_logits.json` containing the logits for the inputs, `prompt_tokens_logits.json` mapping tokens to their corresponding logits, and `prompt_ids_logits.json` mapping token ids to their corresponding logits. 12 | -------------------------------------------------------------------------------- /utils/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from utils.logger import TqdmToLogger 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | 6 | 7 | 8 | 9 | 10 | def load_model_and_tokenizer(model_path, accelerator): 11 | 12 | model = AutoModelForCausalLM.from_pretrained( 13 | model_path, 14 | device_map=accelerator.device, 15 | torch_dtype=torch.bfloat16, 16 | _attn_implementation="flash_attention_2", 17 | ) 18 | 19 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast_tokenizer=True) 20 | 21 | return model, tokenizer 22 | 23 | 24 | def load_logger(log_path): 25 | 26 | logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 27 | logger = logging.getLogger(__name__) 28 | tqdm_out = TqdmToLogger(logger, level=logging.INFO) 29 | 30 | return logger, tqdm_out 31 | -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/prepare_input.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def extract_local(value, rank, world_size, device, dim=1): 4 | value_local = value.chunk(world_size, dim=dim)[rank] 5 | return value_local.to(device) 6 | 7 | 8 | def prepare_dist_flash_attn_inputs( 9 | input_ids, position_ids, target_ids, rank, world_size, device 10 | ): 11 | local_input_ids = extract_local( 12 | input_ids, 13 | rank, 14 | world_size, 15 | device, 16 | ) 17 | local_position_ids = extract_local( 18 | position_ids, 19 | rank, 20 | world_size, 21 | device, 22 | ) 23 | if target_ids is not None: 24 | local_target_ids = extract_local( 25 |
target_ids, 26 | rank, 27 | world_size, 28 | device, 29 | ) 30 | else: 31 | local_target_ids = None 32 | return { 33 | "local_input_ids": local_input_ids, 34 | "local_position_ids": local_position_ids, 35 | "local_target_ids": local_target_ids, 36 | } -------------------------------------------------------------------------------- /utils/easy_context/zigzag_ring_attn/prepare_inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_local(value, rank, world_size, device, dim=1): 5 | value_chunks = value.chunk(2 * world_size, dim=dim) 6 | local_value = torch.cat( 7 | [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim 8 | ) 9 | return local_value.to(device) 10 | 11 | 12 | def prepare_zigzag_ring_attn_inputs( 13 | input_ids, position_ids, target_ids, rank, world_size, device 14 | ): 15 | local_input_ids = extract_local( 16 | input_ids, 17 | rank, 18 | world_size, 19 | device, 20 | ) 21 | local_position_ids = extract_local( 22 | position_ids, 23 | rank, 24 | world_size, 25 | device, 26 | ) 27 | if target_ids is not None: 28 | local_target_ids = extract_local( 29 | target_ids, 30 | rank, 31 | world_size, 32 | device, 33 | ) 34 | else: 35 | local_target_ids = None 36 | return { 37 | "local_input_ids": local_input_ids, 38 | "local_position_ids": local_position_ids, 39 | "local_target_ids": local_target_ids, 40 | } 41 | -------------------------------------------------------------------------------- /utils/logits_compute/logits_compute.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | from transformers import AutoTokenizer 3 | import json 4 | 5 | prompt = ['Hello, world, greate world'] 6 | # prompt = ['Hugging Face is creating great transformers models!'] 7 | sampling_params = SamplingParams(prompt_logprobs=1) 8 | model_path = 'meta-llama/Meta-Llama-3-8B' 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(model_path) 11 | llm = LLM(model=model_path,trust_remote_code=True,gpu_memory_utilization=0.9) 12 | 13 | tokenized_input = tokenizer(prompt[0]) 14 | tokens = tokenizer.convert_ids_to_tokens(tokenized_input['input_ids']) 15 | ids = tokenized_input['input_ids'] 16 | ids.pop(0) 17 | tokens.pop(0) 18 | 19 | try: 20 | outputs = llm.generate(prompt,sampling_params) 21 | except: 22 | pass 23 | 24 | with open('prompt_logits.json','r') as f: 25 | logits = json.load(f) 26 | 27 | def save_to_json(filename, data): 28 | with open(filename, 'w') as f: 29 | json.dump(data, f) 30 | 31 | save_to_json('prompt_tokens_logits.json', dict(zip(tokens, logits[0]))) ## chagne to your customized path 32 | save_to_json('prompt_ids_logits.json', dict(zip(ids, logits[0]))) ## chagne to your customized path 33 | -------------------------------------------------------------------------------- /utils/easy_context/ulysses_attn/prepare_inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_local(value, rank, world_size, device, dim=1): 5 | dimension_size = value.shape[dim] 6 | sub_seq_length = dimension_size // world_size 7 | 8 | sub_seq_start = rank * sub_seq_length 9 | sub_seq_end = (rank + 1) * sub_seq_length 10 | local_value = value[:, sub_seq_start:sub_seq_end] 11 | 12 | return local_value.to(device) 13 | 14 | 15 | def prepare_ulysses_attn_inputs( 16 | input_ids, position_ids, target_ids, rank, world_size, device 17 | ): 18 | 19 | local_input_ids = extract_local( 20 | 
input_ids, 21 | rank, 22 | world_size, 23 | device, 24 | ) 25 | local_position_ids = extract_local( 26 | position_ids, 27 | rank, 28 | world_size, 29 | device, 30 | ) 31 | 32 | if target_ids is not None: 33 | local_target_ids = extract_local( 34 | target_ids, 35 | rank, 36 | world_size, 37 | device, 38 | ) 39 | else: 40 | local_target_ids = None 41 | return { 42 | "local_input_ids": local_input_ids, 43 | "local_position_ids": local_position_ids, 44 | "local_target_ids": local_target_ids, 45 | } 46 | -------------------------------------------------------------------------------- /utils/accelerate_configs/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "fp16": { 6 | "enabled": "auto" 7 | }, 8 | "scheduler": { 9 | "type": "WarmupLR", 10 | "params": { 11 | "warmup_min_lr": 0, 12 | "warmup_max_lr": 5e-5, 13 | "warmup_num_steps": 0, 14 | "warmup_type": "linear" 15 | } 16 | }, 17 | "optimizer": { 18 | "type": "AdamW", 19 | "params": { 20 | "lr": "auto", 21 | "betas": [0.9, 0.999], 22 | "eps": 1e-8, 23 | "weight_decay": 0.1 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": true, 37 | "contiguous_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 2000, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": 1, 51 | "wall_clock_breakdown": false 52 | } 53 | -------------------------------------------------------------------------------- /utils/accelerate_configs/zero3_offload_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "fp16": { 6 | "enabled": "auto" 7 | }, 8 | "scheduler": { 9 | "type": "WarmupLR", 10 | "params": { 11 | "warmup_min_lr": 0, 12 | "warmup_max_lr": 5e-6, 13 | "warmup_num_steps": 0, 14 | "warmup_type": "linear" 15 | } 16 | }, 17 | "optimizer": { 18 | "type": "AdamW", 19 | "params": { 20 | "lr": "auto", 21 | "betas": [0.9, 0.999], 22 | "eps": 1e-8, 23 | "weight_decay": 0.1 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": true, 37 | "contiguous_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 2000, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": 1, 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /utils/easy_context/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs 2 | from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama 3 | from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs 4 | from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama 5 | from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_mistral 6 | from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch 7 | from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs 8 | from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama 9 | 10 | def prepare_seq_parallel_inputs( 11 | seq_algo, input_ids, position_ids, target_ids, rank, world_size, device 12 | ): 13 | if seq_algo == "zigzag_ring_attn": 14 | return prepare_zigzag_ring_attn_inputs( 15 | input_ids, position_ids, target_ids, rank, world_size, device 16 | ) 17 | elif seq_algo == "dist_flash_attn": 18 | return prepare_dist_flash_attn_inputs( 19 | input_ids, position_ids, target_ids, rank, world_size, device 20 | ) 21 | elif seq_algo == "ulysses_attn": 22 | return prepare_ulysses_attn_inputs( 23 | input_ids, position_ids, target_ids, rank, world_size, device 24 | ) 25 | elif seq_algo == "data_parallel": 26 | return { 27 | "local_input_ids": input_ids.to(device), 28 | "local_position_ids": position_ids.to(device), 29 | "local_target_ids": target_ids.to(device), 30 | } 31 | else: 32 | raise ValueError(f"Invalid seq_algo: {seq_algo}") 33 | 34 | def apply_seq_parallel_monkey_patch( 35 | seq_algo, model 36 | ): 37 | assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}" 38 | assert model in ["llama", "mistral"], f"Invalid model: {model}" 39 | if seq_algo == "data_parallel": 40 | return 41 | elif seq_algo == "zigzag_ring_attn" and model == "llama": 42 | apply_zigzag_ring_attn_monkey_patch_llama() 43 | elif seq_algo == "zigzag_ring_attn" and model == "mistral": 44 | apply_zigzag_ring_attn_monkey_patch_mistral() 45 | elif seq_algo == "dist_flash_attn" and model == "llama": 46 | apply_dist_flash_attn_monkey_patch_llama() 47 | elif seq_algo == "ulysses_attn" and model == "llama": 48 | apply_ulysses_attn_monkey_patch_llama() 49 | else: 50 | raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}") 51 | 52 | def prepare_dataloader(seq_algo, dataloader, acclerator): 53 | if seq_algo == "data_parallel": 54 | return acclerator.prepare(dataloader) 55 | else: 56 | return dataloader -------------------------------------------------------------------------------- /train_LR_llama3_target80k_use24k.sh: -------------------------------------------------------------------------------- 1 | 2 | SEQ_LENGTH=24000 3 | TARGET_LENGTH=80000 4 | SETTING='LongRecipe' 5 | MODEL_NAME=llama3_8b 6 | Right_Points_PATH=./output/llama3_LPS_6_digits_l24000_t80000_min1_max1.pkl 7 | FS_PI_PATH=./output/llama3_LF_6_digits_l24000_t80000_min1_max1.pkl 8 | SUB_LABEL='LR_target80k_use24k' 9 | DATA_PATH_1='./output/feature_6_tokens.jsonl' 10 | DATA_PATH_2='replay_dataset' 11 | MODEL='model_path' 12 | # --parallel_mode: data_parallel; 13 | 14 | accelerate launch \ 15 | --config_file utils/accelerate_configs/single_node.yaml \ 16 | utils/train.py \ 17 | --batch-size 1 \ 18 | --gradient-accumulate-every 96 \ 19 | --learning-rate 5e-5 \ 20 | --epoch 1 \ 21 | 
--data_path $DATA_PATH_1 \ 22 | --output-dir ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL \ 23 | --seed 2027 \ 24 | --model $MODEL \ 25 | --seq-length $SEQ_LENGTH \ 26 | --target-length $TARGET_LENGTH \ 27 | --log-path $SETTING-$SEQ_LENGTH-$MODEL_NAME-$SUB_LABEL.log \ 28 | --setting $SETTING \ 29 | --right_points-path $Right_Points_PATH \ 30 | --fs_PI-path $FS_PI_PATH \ 31 | --parallel_mode data_parallel \ 32 | --num_proc 5 \ 33 | --stage 0 34 | 35 | cp $MODEL/special_tokens_map.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0 36 | cp $MODEL/tokenizer_config.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0 37 | cp $MODEL/tokenizer.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0 38 | rm ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0/model.safetensors 39 | 40 | 41 | accelerate launch \ 42 | --config_file utils/accelerate_configs/single_node_2.yaml \ 43 | utils/train.py \ 44 | --data_path $DATA_PATH_2 \ 45 | --batch-size 1 \ 46 | --gradient-accumulate-every 96 \ 47 | --learning-rate 5e-6 \ 48 | --epoch 1 \ 49 | --output-dir ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL \ 50 | --seed 2027 \ 51 | --model $MODEL \ 52 | --seq-length $SEQ_LENGTH \ 53 | --target-length $TARGET_LENGTH \ 54 | --log-path $SETTING-$SEQ_LENGTH-$MODEL_NAME-$SUB_LABEL.log \ 55 | --setting full \ 56 | --right_points-path $Right_Points_PATH \ 57 | --fs_PI-path $FS_PI_PATH \ 58 | --parallel_mode data_parallel \ 59 | --num_proc 5 \ 60 | --stage 1 61 | 62 | cp $MODEL/special_tokens_map.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1 63 | cp $MODEL/tokenizer_config.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1 64 | cp $MODEL/tokenizer.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1 65 | rm ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1/model.safetensors 66 | 67 | 68 | accelerate launch \ 69 | --config_file utils/accelerate_configs/single_node.yaml \ 70 | utils/train.py \ 71 | --output-dir ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL \ 72 | --seed 2027 \ 73 | --model $MODEL \ 74 | --log-path $SETTING-$SEQ_LENGTH-$MODEL_NAME-$SUB_LABEL.log \ 75 | --stage 2 76 | 77 | cp $MODEL/special_tokens_map.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_2 78 | cp $MODEL/tokenizer_config.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_2 79 | cp $MODEL/tokenizer.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_2 80 | # rm ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_2/model.safetensors 81 | 82 | -------------------------------------------------------------------------------- /utils/easy_context/ulysses_attn/monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from typing import List, Optional, Tuple, Union 3 | import warnings 4 | import torch 5 | import torch.utils.checkpoint 6 | try: 7 | from yunchang.ulysses import UlyssesAttention 8 | ulysses_attn = UlyssesAttention() 9 | except: 10 | print("If you want to use the UlyssesAttention class, please install the yunchang package.") 11 | ulysses_attn = None 12 | 13 | 14 | def new_flash_attn_forward( 15 | self, 16 | query_states, 17 | key_states, 18 | value_states, 19 | attention_mask, 20 | query_length, 21 | dropout=0.0, 22 | softmax_scale=None, 23 | use_sliding_windows=False, 24 | ): 25 | if not self._flash_attn_uses_top_left_mask: 26 | causal = self.is_causal 27 | else: 28 | causal = self.is_causal and
query_length != 1 29 | 30 | # Contains at least one padding token in the sequence 31 | assert attention_mask is None 32 | assert causal is True 33 | assert use_sliding_windows is False 34 | attn_output = ulysses_attn( 35 | query_states, 36 | key_states, 37 | value_states, 38 | dropout, 39 | softmax_scale, 40 | causal=causal, 41 | ) 42 | 43 | return attn_output 44 | 45 | 46 | def new_decoder_forward( 47 | self, 48 | hidden_states: torch.Tensor, 49 | attention_mask: Optional[torch.Tensor] = None, 50 | position_ids: Optional[torch.LongTensor] = None, 51 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 52 | output_attentions: Optional[bool] = False, 53 | use_cache: Optional[bool] = False, 54 | cache_position: Optional[torch.LongTensor] = None, 55 | **kwargs, 56 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 57 | assert isinstance( 58 | self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 59 | ) or isinstance( 60 | self.self_attn, 61 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2, 62 | ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 63 | 64 | if "padding_mask" in kwargs: 65 | warnings.warn( 66 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 67 | ) 68 | 69 | residual = hidden_states 70 | 71 | hidden_states = self.input_layernorm(hidden_states) 72 | 73 | # Self Attention 74 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 75 | hidden_states=hidden_states, 76 | attention_mask=attention_mask, 77 | position_ids=position_ids, 78 | past_key_value=past_key_value, 79 | output_attentions=output_attentions, 80 | use_cache=use_cache, 81 | cache_position=cache_position, 82 | **kwargs, 83 | ) 84 | hidden_states = residual + hidden_states 85 | 86 | # Fully Connected 87 | residual = hidden_states 88 | hidden_states = self.post_attention_layernorm(hidden_states) 89 | hidden_states = self.mlp(hidden_states) 90 | hidden_states = residual + hidden_states 91 | 92 | outputs = (hidden_states,) 93 | 94 | if output_attentions: 95 | outputs += (self_attn_weights,) 96 | 97 | if use_cache: 98 | outputs += (present_key_value,) 99 | 100 | return outputs 101 | 102 | 103 | def apply_ulysses_attn_monkey_patch_llama(): 104 | transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( 105 | new_flash_attn_forward 106 | ) 107 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( 108 | new_decoder_forward 109 | ) 110 | 111 | 112 | -------------------------------------------------------------------------------- /utils/easy_context/unsloth_offloaded_gradient_checkpoint/monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
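# This module monkey patches transformers' PreTrainedModel.gradient_checkpointing_enable so that the
# hidden states saved at each gradient checkpoint are offloaded to CPU RAM via non-blocking copies in
# the forward pass and copied back to the GPU only when the backward pass recomputes the block.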
14 | 15 | import torch 16 | import transformers 17 | import inspect 18 | 19 | 20 | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): 21 | """ 22 | Saves VRAM by smartly offloading to RAM. 23 | Tiny hit to performance, since we mask the movement via non blocking calls. 24 | """ 25 | 26 | @staticmethod 27 | @torch.cuda.amp.custom_fwd 28 | def forward(ctx, forward_function, hidden_states, *args): 29 | saved_hidden_states = hidden_states.to("cpu", non_blocking=True) 30 | with torch.no_grad(): 31 | output = forward_function(hidden_states, *args) 32 | ctx.save_for_backward(saved_hidden_states) 33 | ctx.forward_function = forward_function 34 | ctx.args = args 35 | 36 | return output 37 | 38 | pass 39 | 40 | @staticmethod 41 | @torch.cuda.amp.custom_bwd 42 | def backward(ctx, dY): 43 | (hidden_states,) = ctx.saved_tensors 44 | hidden_states = hidden_states.to("cuda", non_blocking=True).detach() 45 | hidden_states.requires_grad = True 46 | with torch.enable_grad(): 47 | (output,) = ctx.forward_function(hidden_states, *ctx.args) 48 | torch.autograd.backward(output, dY) 49 | return ( 50 | None, 51 | hidden_states.grad, 52 | ) + ( 53 | None, 54 | ) * len(ctx.args) 55 | 56 | pass 57 | 58 | 59 | pass 60 | 61 | 62 | def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): 63 | assert gradient_checkpointing_kwargs == None 64 | if not self.supports_gradient_checkpointing: 65 | raise ValueError( 66 | f"{self.__class__.__name__} does not support gradient checkpointing." 67 | ) 68 | 69 | gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply 70 | # For old GC format (transformers < 4.35.0) for models that live on the Hub 71 | # we will fall back to the overwritten `_set_gradient_checkpointing` method 72 | _is_using_old_format = ( 73 | "value" in inspect.signature(self._set_gradient_checkpointing).parameters 74 | ) 75 | 76 | if not _is_using_old_format: 77 | self._set_gradient_checkpointing( 78 | enable=True, gradient_checkpointing_func=gradient_checkpointing_func 79 | ) 80 | else: 81 | raise NotImplementedError() 82 | 83 | if getattr(self, "_hf_peft_config_loaded", False): 84 | # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True 85 | # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 86 | # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate 87 | # the gradients to make sure the gradient flows. 
88 | self.enable_input_require_grads() 89 | 90 | 91 | def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(): 92 | transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = ( 93 | new_gradient_checkpointing_enable 94 | ) 95 | -------------------------------------------------------------------------------- /utils/easy_context/zigzag_ring_attn/monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from typing import List, Optional, Tuple, Union 3 | import warnings 4 | import torch 5 | import torch.utils.checkpoint 6 | from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func 7 | 8 | 9 | def new_flash_attn_forward( 10 | self, 11 | query_states, 12 | key_states, 13 | value_states, 14 | attention_mask, 15 | query_length, 16 | dropout=0.0, 17 | softmax_scale=None, 18 | use_sliding_windows=False, 19 | ): 20 | if not self._flash_attn_uses_top_left_mask: 21 | causal = self.is_causal 22 | else: 23 | causal = self.is_causal and query_length != 1 24 | 25 | # Contains at least one padding token in the sequence 26 | assert attention_mask is None 27 | assert causal is True 28 | assert use_sliding_windows is False 29 | attn_output = zigzag_ring_flash_attn_func( 30 | query_states, 31 | key_states, 32 | value_states, 33 | dropout, 34 | softmax_scale, 35 | causal=causal, 36 | ) 37 | 38 | return attn_output 39 | 40 | 41 | def new_decoder_forward( 42 | self, 43 | hidden_states: torch.Tensor, 44 | attention_mask: Optional[torch.Tensor] = None, 45 | position_ids: Optional[torch.LongTensor] = None, 46 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 47 | output_attentions: Optional[bool] = False, 48 | use_cache: Optional[bool] = False, 49 | cache_position: Optional[torch.LongTensor] = None, 50 | **kwargs, 51 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 52 | assert isinstance( 53 | self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2 54 | ) or isinstance( 55 | self.self_attn, 56 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2, 57 | ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." 58 | 59 | if "padding_mask" in kwargs: 60 | warnings.warn( 61 | "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 62 | ) 63 | 64 | residual = hidden_states 65 | 66 | hidden_states = self.input_layernorm(hidden_states) 67 | 68 | # Self Attention 69 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 70 | hidden_states=hidden_states, 71 | attention_mask=attention_mask, 72 | position_ids=position_ids, 73 | past_key_value=past_key_value, 74 | output_attentions=output_attentions, 75 | use_cache=use_cache, 76 | cache_position=cache_position, 77 | **kwargs, 78 | ) 79 | hidden_states = residual + hidden_states 80 | 81 | # Fully Connected 82 | residual = hidden_states 83 | hidden_states = self.post_attention_layernorm(hidden_states) 84 | hidden_states = self.mlp(hidden_states) 85 | hidden_states = residual + hidden_states 86 | 87 | outputs = (hidden_states,) 88 | 89 | if output_attentions: 90 | outputs += (self_attn_weights,) 91 | 92 | if use_cache: 93 | outputs += (present_key_value,) 94 | 95 | return outputs 96 | 97 | 98 | def apply_zigzag_ring_attn_monkey_patch_llama(): 99 | transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( 100 | new_flash_attn_forward 101 | ) 102 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( 103 | new_decoder_forward 104 | ) 105 | 106 | 107 | def apply_zigzag_ring_attn_monkey_patch_mistral(): 108 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( 109 | new_flash_attn_forward 110 | ) 111 | transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = ( 112 | new_decoder_forward 113 | ) 114 | -------------------------------------------------------------------------------- /preprocess_token_PI/dataprocessor.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import tqdm 4 | import numpy as np 5 | from transformers import AutoTokenizer 6 | from datasets import load_dataset 7 | import multiprocessing 8 | import pickle 9 | import random 10 | from functools import partial 11 | import os 12 | from concurrent.futures import ThreadPoolExecutor, as_completed 13 | from FSProcessor import FSProcessor 14 | 15 | def process_data_single(dat, llama_tokenizer, llama_3_tokenizer): 16 | new_token = llama_3_tokenizer.encode(llama_tokenizer.decode(dat)) 17 | return {'input_ids': new_token} 18 | 19 | 20 | class DataProcessor: 21 | def __init__(self, dataset_name, target_model_tokenizer_file, dataset_split, model_pre_id, model_post_id, use_length, target_length, model_name, processor_type, select_ratio): 22 | self.dataset_name = dataset_name 23 | self.target_model_tokenizer_file = target_model_tokenizer_file 24 | self.dataset_split = dataset_split 25 | self.model_pre_id = model_pre_id 26 | self.model_post_id = model_post_id 27 | self.use_length = use_length 28 | self.target_length = target_length 29 | self.model_name = model_name 30 | self.retain = 0 31 | self.select_ratio = select_ratio 32 | self.base_length = use_length - self.retain 33 | self.extend_length = target_length - self.retain 34 | 35 | self.llama_tokenizer = AutoTokenizer.from_pretrained(model_pre_id, use_fast_tokenizer=True) 36 | self.llama_3_tokenizer = AutoTokenizer.from_pretrained(model_post_id, use_fast_tokenizer=True) 37 | 38 | if processor_type == "FS": 39 | self.processor = FSProcessor( 40 | model_path=model_post_id, 41 | base_length=self.base_length, 42 | extend_length=self.extend_length, 43 | select_ratio=self.select_ratio 44 | ) 45 | 46 | def load_dataset_(self): 47 | return 
load_dataset(self.dataset_name, split=self.dataset_split) 48 | 49 | def process_data_parallel(self, dataset): 50 | print('Processing dataset in parallel...') 51 | new_dataset = [dat['input_ids'][:] for dat in tqdm.tqdm(dataset)] 52 | 53 | num_cpus = multiprocessing.cpu_count() 54 | num_processes = min(100, num_cpus) 55 | 56 | length = [] 57 | with open(self.target_model_tokenizer_file, 'w') as f: 58 | with multiprocessing.Pool(processes=num_processes) as pool: 59 | process_func = partial(process_data_single, llama_tokenizer=self.llama_tokenizer, llama_3_tokenizer=self.llama_3_tokenizer) 60 | for new_data in tqdm.tqdm(pool.map(process_func, new_dataset)): 61 | f.write(json.dumps(new_data) + '\n') 62 | length.append(len(new_data['input_ids'])) 63 | print(np.mean(length)) 64 | 65 | def save_tokenized_data(self, processe_file): 66 | processed_data = [] 67 | with open(processe_file, 'r') as f: 68 | for line in f: 69 | processed_data.append(json.loads(line)) 70 | return processed_data 71 | 72 | def load_processed_data(self, processed_file): 73 | processed_data = [] 74 | with open(processed_file, 'r') as f: 75 | for line in f: 76 | processed_data.append(json.loads(line)) 77 | return processed_data 78 | 79 | def run(self): 80 | dataset = self.load_dataset_() 81 | if not os.path.exists(self.target_model_tokenizer_file): 82 | self.process_data_parallel(dataset) 83 | processed_data = self.load_processed_data(self.target_model_tokenizer_file) 84 | 85 | self.processor.run_process(processed_data) 86 | 87 | 88 | if __name__ == "__main__": 89 | processor = DataProcessor( 90 | dataset_name='yaofu/slimpajama-per-source-length-upsample', 91 | target_model_tokenizer_file='output/processed_llama_3_s10l_full.json', 92 | dataset_split='train[0:10000]', 93 | model_pre_id="meta-llama/Llama-2-7b-hf", 94 | model_post_id="meta-llama/Meta-Llama-3-8B", 95 | use_length=24000, 96 | target_length=80000, 97 | model_name='__', 98 | processor_type='FS', 99 | select_ratio=0.6 100 | ) 101 | processor.run() 102 | -------------------------------------------------------------------------------- /preprocess_token_PI/FSProcessor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tqdm 3 | import numpy as np 4 | from transformers import AutoTokenizer 5 | import pickle 6 | import random 7 | import os 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | 10 | class FSProcessor: 11 | def __init__(self, model_path, base_length=24000, extend_length=80000, min_num=1, max_num=1, select_ratio=0.8): 12 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 13 | self.base_length = base_length 14 | self.extend_length = extend_length 15 | self.min_num = min_num 16 | self.max_num = max_num 17 | self.select_ratio = select_ratio 18 | self.sentences = [] 19 | self.indexes = [] 20 | 21 | def flatten_list(self, lst): 22 | flattened_list = [] 23 | for item in lst: 24 | if isinstance(item, list): 25 | flattened_list.extend(self.flatten_list(item)) 26 | else: 27 | flattened_list.append(item) 28 | return flattened_list 29 | 30 | def split_sentences(self, data, tokens_to_split): 31 | split_lists = [] 32 | split_idxs = [] 33 | 34 | current_split = [] 35 | current_idx = [] 36 | 37 | for idx, num in enumerate(data): 38 | if num in tokens_to_split: 39 | if current_split: 40 | current_split.append(num) 41 | current_idx.append(idx) 42 | split_lists.append(current_split) 43 | split_idxs.append(current_idx) 44 | current_split = [] 45 | current_idx = [] 46 | else: 47 | 
current_split.append(num) 48 | current_idx.append(idx) 49 | split_lists.append(current_split) 50 | split_idxs.append(current_idx) 51 | current_split = [] 52 | current_idx = [] 53 | else: 54 | current_split.append(num) 55 | current_idx.append(idx) 56 | if current_split: 57 | split_lists.append(current_split) 58 | split_idxs.append(current_idx) 59 | return split_lists, split_idxs 60 | 61 | def load_and_process_data(self, dataset): 62 | length_6k = [data['input_ids'][:self.base_length] for data in dataset] 63 | tokens_to_split = [13, 198, 30, 0, 627] 64 | 65 | for data in tqdm.tqdm(length_6k): 66 | data = self.flatten_list(data) 67 | split_lists, split_idxs = self.split_sentences(data,tokens_to_split) 68 | self.sentences.append([sent for sent in split_lists]) 69 | self.indexes.append([sent for sent in split_idxs]) 70 | 71 | def process_index(self, idx): 72 | selected = [] 73 | tokenized = self.tokenizer.batch_decode(idx) 74 | for ids, tokens in zip(idx, tokenized): 75 | if any(char.isdigit() for char in tokens): 76 | selected.append(ids) 77 | else: 78 | random_number = random.randint(1, 10) 79 | if random_number < int(self.select_ratio * 10): 80 | continue 81 | else: 82 | selected.append(ids) 83 | return selected 84 | 85 | def process_all_indexes(self, max_workers=40): 86 | selected = [] 87 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 88 | futures = [executor.submit(self.process_index, idx) for idx in self.sentences] 89 | for future in tqdm.tqdm(as_completed(futures), total=len(futures), desc="Processing"): 90 | selected.append(future.result()) 91 | return selected 92 | 93 | def random_chunk_list(self, input_list): 94 | result = [] 95 | i = 0 96 | chunk_sizes = [] 97 | while i < len(input_list): 98 | chunk_size = random.randint(self.min_num, self.max_num) 99 | result.append(input_list[i:i + chunk_size]) 100 | i += chunk_size 101 | chunk_sizes.append(chunk_size) 102 | return result, chunk_sizes 103 | 104 | def process_new_indexes(self): 105 | new_total_indexes = [] 106 | for index in tqdm.tqdm(self.indexes): 107 | result, chunk_size = self.random_chunk_list(list(range(len(index)))) 108 | new_indexes = [] 109 | for idx in result: 110 | new_index = [] 111 | for id in idx: 112 | new_index.extend(index[id]) 113 | new_indexes.append(new_index) 114 | new_total_indexes.append(new_indexes) 115 | return new_total_indexes 116 | 117 | def generate_res_lists(self, new_total_indexes): 118 | res_lists = [] 119 | for index in tqdm.tqdm(new_total_indexes): 120 | random_zeros = [0] * (self.extend_length - self.base_length + len(index)) 121 | 122 | insert_indices = sorted(random.sample(range(len(random_zeros)), len(index))) 123 | 124 | random_zeros[0] = index[0] 125 | for random_number, idx in zip(insert_indices[1:], index[1:]): 126 | random_zeros[random_number] = idx 127 | 128 | res_list = self.flatten_list(random_zeros) 129 | PI_list = [id for id in range(len(res_list)) if res_list[id] != 0] 130 | new_PI_list = [PI_list[0] - 1] + PI_list 131 | if new_PI_list[0] == -1: 132 | updated_PI_list = [x + 1 for x in new_PI_list] 133 | else: 134 | updated_PI_list = new_PI_list 135 | res_lists.append(updated_PI_list) 136 | return res_lists 137 | 138 | def save_data(self, data, file_name): 139 | with open(file_name, "wb") as f: 140 | pickle.dump(data, f) 141 | 142 | def run_process(self,data): 143 | self.load_and_process_data(data) 144 | selected = self.process_all_indexes() 145 | lengths = [len(item) for item in selected] 146 | print(f'sample_length: {np.mean(lengths)}') 147 | 148 | with 
open(f'output/feature_{str(int(self.select_ratio*10))}_tokens.jsonl', 'w') as f: 149 | for new_data in tqdm.tqdm(selected): 150 | f.write(json.dumps({'input_ids': new_data}) + '\n') 151 | 152 | new_total_indexes = self.process_new_indexes() 153 | res_lists = self.generate_res_lists(new_total_indexes) 154 | 155 | print(len(res_lists[0])) 156 | self.save_data(res_lists, f"output/llama3_LF_{str(int(self.select_ratio*10))}_digits_l{self.base_length}_t{self.extend_length}_min{self.min_num}_max{self.max_num}.pkl") 157 | 158 | right_points = [random.randint(1, (self.base_length + 1) // 2) for _ in range(len(res_lists))] 159 | self.save_data(right_points, f"output/llama3_LPS_{str(int(self.select_ratio*10))}_digits_l{self.base_length}_t{self.extend_length}_min{self.min_num}_max{self.max_num}.pkl") 160 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LongRecipe: Recipe for Efficient Long Context Generalization in Large Language Models 2 | 3 |

4 | 🤗 LongRecipe-Llama3-8B-128k • 🤗 LongRecipe-Qwen2-7B-128k • 📃 Paper 5 | 6 | 7 | ## Project Directory Structure 8 | 9 | 10 | ``` 11 | LongRecipe/ 12 | ├── accelerate_configs/ 13 | │ ├── config_files 14 | ├── utils/ 15 | │ └── preprocess_token_PI/ 16 | │ ├── dataprocessor.py 17 | │ └── FSProcessor.py 18 | │ └── easy_context/ 19 | │ ├── dist_flash_attn/ 20 | │ ├── ulysses_attn/ 21 | │ └── zigzag_ring_attn/ 22 | │ ├── loader.py 23 | │ ├── logger.py 24 | │ └── preprocess_data.py 25 | ├── README.md 26 | ├── train_LR_llama3_target80k_use24k.sh 27 | ├── requirements.txt 28 | └── train.py 29 | ``` 30 | 31 | ## Reproduction: 32 | 33 | Before starting with the data preprocessing and model training, ensure that all necessary dependencies are installed. Use the following command to install the required packages: 34 | 35 | `pip install -r requirements.txt` 36 | 37 | ### Data Preprocessing (Example: Llama3) 38 | 39 | To begin, download the dataset represented by the Llama3 tokenizer from this link. After downloading, execute the following command to generate the position index files for different training approaches: 40 | 41 | 42 | ``` 43 | # Command to load dataset and generate position index files 44 | python preprocess_token_PI/dataprocessor.py 45 | ``` 46 | 47 | 48 | ### Model Training: 49 | 50 | The model training process is divided into three distinct stages to effectively extend the context window of the LLM while maintaining its original capabilities. 51 | 52 | #### Context Window Extension 53 | 54 | In the first stage, we extend the context window using a dataset containing 1.7B tokens. The following command initiates this training stage: 55 | 56 | 57 | ``` 58 | accelerate launch \ 59 | --config_file accelerate_configs/single_node.yaml \ 60 | train.py \ 61 | --batch-size 1 \ 62 | --gradient-accumulate-every 96 \ 63 | --learning-rate 5e-5 \ 64 | --epoch 1 \ 65 | --data_path $DATA_PATH_CONTEXT_EXTENSION \ 66 | --output-dir ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL \ 67 | --seed 2027 \ 68 | --model $MODEL \ 69 | --seq-length $SEQ_LENGTH \ 70 | --target-length $TARGET_LENGTH \ 71 | --log-path $SETTING-$SEQ_LENGTH-$MODEL_NAME-$SUB_LABEL.log \ 72 | --setting $SETTING \ 73 | --right_points-path $Right_Points_PATH \ 74 | --fs_PI-path $FS_PI_PATH \ 75 | --parallel_mode ulysses_attn \ 76 | --num_proc 5 \ 77 | --stage 0 78 | ``` 79 | 80 | 81 | Arguments Explanation: 82 | * **--data_path**: Path to the dataset with Llama3-tokenized samples. 83 | * **--model**: The base model used for training. 84 | * **--seq-length**: The sequence length for training. 85 | * **--target-length**: The target context window length. 86 | * **--setting**: The training method, which could include FLT, RPES, PoSE, LongRecipe. 87 | * **--right_points-path**: Path to the PoSE right point set file. 88 | * **--fs_PI-path**: Path to the LongRecipe’s position index file. 
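
If you want to sanity-check the two auxiliary files before launching, they are plain pickled lists written by `preprocess_token_PI/FSProcessor.py`: one right-point integer and one position-index list per sample. A minimal inspection sketch, assuming the default file names produced for `select_ratio=0.6` with a 24k training length and an 80k target length (adjust the paths to your own run):

```
import pickle

with open("./output/llama3_LPS_6_digits_l24000_t80000_min1_max1.pkl", "rb") as f:
    right_points = pickle.load(f)       # list[int]: one right point per sample
with open("./output/llama3_LF_6_digits_l24000_t80000_min1_max1.pkl", "rb") as f:
    position_indexes = pickle.load(f)   # list[list[int]]: one position-index list per sample

print(len(right_points), len(position_indexes))
print(right_points[0], position_indexes[0][:10])
```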
89 | 90 | Post-training, copy the tokenizer files to the output directory and remove any unnecessary files: 91 | 92 | ``` 93 | cp $MODEL/special_tokens_map.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0 94 | cp $MODEL/tokenizer_config.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0 95 | cp $MODEL/tokenizer.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0 96 | rm ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_0/model.safetensors 97 | ``` 98 | 99 | #### Stage 2: Training Annealing 100 | 101 | 102 | In the second stage, we perform training annealing using both general and domain-specific data, gradually reducing the learning rate to zero. Approximately 100M tokens of data are used in this phase. 103 | ``` 104 | accelerate launch \ 105 | --config_file accelerate_configs/single_node_2.yaml \ 106 | train.py \ 107 | --data_path $DATA_PATH_ANNEALING \ 108 | --batch-size 1 \ 109 | --gradient-accumulate-every 96 \ 110 | --learning-rate 5e-6 \ 111 | --epoch 1 \ 112 | --output-dir ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL \ 113 | --seed 2027 \ 114 | --model $STAGE_1_MODEL \ 115 | --seq-length $SEQ_LENGTH \ 116 | --target-length $TARGET_LENGTH \ 117 | --log-path $SETTING-$SEQ_LENGTH-$MODEL_NAME-$SUB_LABEL.log \ 118 | --setting $SETTING \ 119 | --right_points-path $Right_Points_PATH \ 120 | --fs_PI-path $FS_PI_PATH \ 121 | --parallel_mode ulysses_attn \ 122 | --num_proc 10 \ 123 | --stage 1 124 | ``` 125 | 126 | Copy the updated tokenizer files to the output directory: 127 | 128 | ``` 129 | cp $MODEL/special_tokens_map.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1 130 | cp $MODEL/tokenizer_config.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1 131 | cp $MODEL/tokenizer.json ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1 132 | rm ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL/stage_1/model.safetensors 133 | ``` 134 | 135 | In our experiments, we merge the two datasets mentioned in our paper and format each sample as follows: 136 | 137 | ``` 138 | { 139 | "prompt": "<prompt text>", 140 | "response": "<response text>" 141 | } 142 | ``` 143 | 144 | #### Stage 3: Model Merge 145 | 146 | The final stage involves merging the original model with the fine-tuned model using an average weight strategy to enhance the model's foundational capabilities. 147 | 148 | ``` 149 | accelerate launch \ 150 | --config_file accelerate_configs/single_node.yaml \ 151 | train.py \ 155 | --output-dir ./output/$MODEL_NAME-$SETTING-$SEQ_LENGTH-$SUB_LABEL \ 156 | --seed 2027 \ 157 | --model $MODEL \ 158 | --log-path $SETTING-$SEQ_LENGTH-$MODEL_NAME-$SUB_LABEL.log \ 159 | --stage 2 160 | ``` 161 | 162 | You can also run 163 | 164 | ``` 165 | bash ./train_LR_llama3_target80k_use24k.sh 166 | ``` 167 | 168 | after preprocessing your data to run all three stages with a single command.
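
For reference, the Stage 3 merge above is conceptually an average of the two checkpoints' weights, which `train.py --stage 2` performs for you. A minimal standalone sketch of the same average-weight idea (the paths below are placeholders, not files shipped with this repository):

```
import torch
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/original_model", torch_dtype=torch.bfloat16)
tuned = AutoModelForCausalLM.from_pretrained("path/to/stage_1_checkpoint", torch_dtype=torch.bfloat16)

# 50/50 average of every parameter in the two checkpoints
base_state = base.state_dict()
merged_state = {name: (param + base_state[name]) / 2 for name, param in tuned.state_dict().items()}

tuned.load_state_dict(merged_state)
tuned.save_pretrained("path/to/merged_model")
```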
169 | 170 | 200 | 201 | 202 | ## Citation 203 | 204 | If you find this repo helpful, please cite our paper as follows: 205 | 206 | ``` 207 | @article{hu2024longrecipe, 208 | title={LongRecipe: Recipe for Efficient Long Context Generalization in Large Languge Models}, 209 | author={Zhiyuan Hu, Yuliang Liu, Jinman Zhao, Suyuchen Wang, Yan Wang, Wei Shen, Qing Gu, Anh Tuan Luu, See-Kiong Ng, Zhiwei Jiang, Bryan Hooi}, 210 | journal={arXiv preprint arXiv:2409.00509}, 211 | year={2024} 212 | } 213 | ``` 214 | -------------------------------------------------------------------------------- /utils/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tqdm 4 | import tiktoken 5 | import random 6 | import pickle 7 | from itertools import chain 8 | from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union 9 | import pdb 10 | from transformers import TrainingArguments, HfArgumentParser 11 | import torch 12 | import random 13 | 14 | 15 | def flatten_list(lst): 16 | flattened_list = [] 17 | for item in lst: 18 | if isinstance(item, list): 19 | flattened_list.extend(flatten_list(item)) 20 | else: 21 | flattened_list.append(item) 22 | return flattened_list 23 | 24 | def preprocess_dataset( 25 | dataset: Union["Dataset", "IterableDataset"], 26 | tokenizer: "PreTrainedTokenizer", 27 | model_type, 28 | setting_choice, 29 | seq_length_, 30 | target_len, 31 | right_points_file, 32 | fs_PI_s_file, 33 | num_proc, 34 | ) -> Union["Dataset", "IterableDataset"]: 35 | 36 | global seq_length 37 | seq_length = seq_length_ 38 | def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: 39 | 40 | global seq_length 41 | 42 | if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) 43 | kwargs = dict(allowed_special="all") 44 | else: 45 | kwargs = dict(add_special_tokens=True) 46 | 47 | if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer 48 | add_eos_token_flag = getattr(tokenizer, "add_eos_token") 49 | setattr(tokenizer, "add_eos_token", True) 50 | 51 | if setting_choice == 'full_str': 52 | sentence = [] 53 | for data_1, data_2 in zip(examples["prompt"], examples["response"]): 54 | sentence.append(data_1 + data_2) 55 | 56 | tokenized_examples = tokenizer(sentence, **kwargs) 57 | 58 | else: 59 | tokenized_examples = {"input_ids": examples["input_ids"]} 60 | 61 | try: 62 | new_tokenized_examples = {'input_ids': [data[:] for data in tokenized_examples['input_ids']]} 63 | except: 64 | new_tokenized_examples = {'input_ids': [data[:] for data in tokenized_examples]} 65 | 66 | concatenated_examples = {k: list(chain(*new_tokenized_examples[k])) for k in new_tokenized_examples.keys()} 67 | 68 | 69 | total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) 70 | block_size = seq_length 71 | total_length = (total_length // block_size) * block_size 72 | # split by chunks of cutoff_len 73 | result = { 74 | k: [t[i: i + block_size] for i in range(0, total_length, block_size)] 75 | for k, t in concatenated_examples.items() 76 | } 77 | if hasattr(tokenizer, "add_eos_token"): 78 | setattr(tokenizer, "add_eos_token", add_eos_token_flag) 79 | 80 | if setting_choice == 'full' or setting_choice == 'full_str': 81 | return result 82 | 83 | if setting_choice == 'LongRecipe': 84 | # llama 3 RSS-S 85 | 86 | with open(right_points_file, 'rb') as fr: 87 | rt1s = pickle.load(fr) 88 | with open(fs_PI_s_file, 'rb') as fr: 89 | topk = 
pickle.load(fr) 90 | 91 | input_ids = [] 92 | position_ids = [] 93 | if 'llama2' in model_type: 94 | retain = 1 95 | else: 96 | retain = 0 97 | 98 | for idx, ids in tqdm.tqdm(enumerate(new_tokenized_examples["input_ids"])): 99 | ids = flatten_list(ids) 100 | base_length = min(seq_length, len(ids)) 101 | 102 | if len(topk[idx]) == 0: 103 | continue 104 | rt1 = rt1s[idx] 105 | new_input_ids = ids[:base_length] 106 | pos_ids = torch.arange(retain + rt1).tolist() 107 | new_position_ids = topk[idx][rt1:base_length] 108 | pos_ids.extend(new_position_ids) 109 | 110 | try: 111 | assert len(pos_ids) == len(new_input_ids) 112 | position_ids.append(pos_ids) 113 | input_ids.append(new_input_ids) 114 | except: 115 | print(len(pos_ids)) 116 | print(len(new_input_ids)) 117 | 118 | model_inputs = {"input_ids": input_ids, "position_ids": position_ids} 119 | return model_inputs 120 | 121 | 122 | elif setting_choice == 'R': 123 | 124 | input_ids = [] 125 | position_ids = [] 126 | 127 | retain = 0 128 | 129 | 130 | for idx, ids in tqdm.tqdm(enumerate(new_tokenized_examples["input_ids"])): 131 | base_length = min(seq_length, len(ids)) 132 | new_input_ids = ids[:base_length] 133 | 134 | pos_ids = torch.arange(retain, dtype=torch.long).tolist() 135 | pos_ids.extend(sorted(random.sample(list(range(retain, target_len)), base_length-retain))) 136 | 137 | input_ids.append(new_input_ids) 138 | position_ids.append(pos_ids) 139 | try: 140 | assert len(pos_ids) == len(new_input_ids) 141 | except: 142 | print(len(pos_ids)) 143 | print(len(new_input_ids)) 144 | # pdb.set_trace() 145 | 146 | model_inputs = {"input_ids": input_ids, "position_ids": position_ids} 147 | return model_inputs 148 | 149 | elif setting_choice == 'RSS_S_base': 150 | 151 | with open(right_points_file, 'rb') as fr: 152 | rt1s = pickle.load(fr) 153 | with open(fs_PI_s_file, 'rb') as fr: 154 | topk = pickle.load(fr) 155 | 156 | input_ids = [] 157 | position_ids = [] 158 | 159 | retain = 0 160 | 161 | for idx, ids in tqdm.tqdm(enumerate(new_tokenized_examples["input_ids"])): 162 | base_length = min(seq_length, len(ids)) 163 | if len(topk[idx]) == 0: 164 | continue 165 | rt1 = 8000 166 | # rt1 = random.randint(1, (base_length+1)//2) 167 | new_input_ids = ids[:base_length] 168 | 169 | pos_ids = torch.arange(rt1).tolist() 170 | new_position_ids = topk[idx][rt1:base_length] 171 | pos_ids.extend(new_position_ids) 172 | 173 | position_ids.append(pos_ids) 174 | input_ids.append(new_input_ids) 175 | assert len(pos_ids) == len(new_input_ids) 176 | 177 | model_inputs = {"input_ids": input_ids, "position_ids": position_ids} 178 | return model_inputs 179 | 180 | 181 | elif setting_choice == 'pose': 182 | with open(right_points_file, 'rb') as fr: 183 | rt1s = pickle.load(fr) 184 | 185 | 186 | with open(fs_PI_s_file, 'rb') as fr: 187 | topk = pickle.load(fr) 188 | 189 | rts = [] 190 | lt1s = [] 191 | input_ids = [] 192 | position_ids = [] 193 | if 'llama2' in model_type: 194 | retain = 1 195 | else: 196 | retain = 0 197 | 198 | scaled_max_position_embeddings = target_len 199 | 200 | 201 | for idx, ids in tqdm.tqdm(enumerate(new_tokenized_examples["input_ids"])): 202 | 203 | # if 'qwen2_7b' in model_type or 'mistral_3_7b' in model_type: 204 | # if idx >= 5000:continue 205 | 206 | base_length = min(seq_length, len(ids)) 207 | len_chunk = min(len(ids), base_length) 208 | len_input = len(ids) 209 | lt1 = 0 210 | rt1 = rt1s[idx] 211 | lt1 += retain; rt1 += retain 212 | chunked_ids = ids[:len_chunk] 213 | new_input_ids = ids[:retain] + chunked_ids 214 | # 
new_input_ids.extend(chunked_ids) 215 | input_ids.append(new_input_ids) 216 | 217 | pos_ids = torch.arange(len(chunked_ids), dtype=torch.long) 218 | len_pos_ids = len(pos_ids) 219 | lt = random.randint(0, scaled_max_position_embeddings-len_pos_ids) 220 | # lt = 0 221 | rt = random.randint(lt, scaled_max_position_embeddings-len_pos_ids) 222 | 223 | new_pos_ids = torch.arange(lt, dtype=torch.long) 224 | new_pos_ids = torch.cat((new_pos_ids, pos_ids[lt:] + (rt + retain))) 225 | 226 | position_ids.append(new_pos_ids) 227 | 228 | model_inputs = {"input_ids": input_ids, "position_ids": position_ids} 229 | return model_inputs 230 | 231 | elif setting_choice == 'RSS_L': 232 | 233 | with open(right_points_file, 'rb') as fr: 234 | rt1s = pickle.load(fr) 235 | 236 | with open(fs_PI_s_file, 'rb') as fr: 237 | topk = pickle.load(fr) 238 | 239 | input_ids = [] 240 | position_ids = [] 241 | 242 | retain = 0 243 | for idx, ids in tqdm.tqdm(enumerate(new_tokenized_examples["input_ids"])): 244 | if len(topk[idx]) == 0: 245 | continue 246 | new_input_ids = ids[:retain] 247 | try: 248 | new_delta_ids = [ids[i] for i in topk[idx][:seq_length]] 249 | except: 250 | # pdb.set_trace() 251 | pass 252 | new_input_ids.extend(new_delta_ids) 253 | 254 | 255 | pos_ids = torch.arange(rt1s[idx]).tolist() 256 | new_position_ids = topk[idx][rt1s[idx]:seq_length] 257 | pos_ids.extend(new_position_ids) 258 | 259 | try: 260 | assert len(pos_ids) == len(new_input_ids) 261 | position_ids.append(pos_ids) 262 | input_ids.append(new_input_ids) 263 | except: 264 | pass 265 | # pdb.set_trace() 266 | model_inputs = {"input_ids": input_ids, "position_ids": position_ids} 267 | # print(len(model_inputs)) 268 | return model_inputs 269 | 270 | 271 | preprocess_func = preprocess_pretrain_dataset 272 | 273 | 274 | if len(dataset) == 1: 275 | column_names = [] 276 | else: 277 | column_names = list(next(iter(dataset)).keys()) 278 | kwargs = {} 279 | kwargs = dict( 280 | num_proc=num_proc, 281 | load_from_cache_file=True, 282 | desc="Running tokenizer on dataset" 283 | ) 284 | dataset = dataset.map( 285 | preprocess_func, 286 | batched=True, 287 | remove_columns=column_names, 288 | **kwargs 289 | ) 290 | 291 | return dataset 292 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import pickle 5 | import random 6 | from tqdm import tqdm 7 | from datetime import timedelta 8 | from accelerate import Accelerator 9 | from utils.loader import load_logger, load_model_and_tokenizer 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | from utils.preprocess_data import preprocess_dataset 12 | from accelerate.utils import InitProcessGroupKwargs, set_seed 13 | from torch.utils.data import DataLoader 14 | from datasets import load_dataset, load_from_disk, DatasetDict 15 | from transformers import default_data_collator 16 | import transformers 17 | from flash_attn.losses.cross_entropy import CrossEntropyLoss 18 | import math 19 | from accelerate.utils import ( 20 | InitProcessGroupKwargs, 21 | set_seed, 22 | DummyOptim, 23 | DummyScheduler, 24 | ) 25 | from utils.easy_context import ( 26 | prepare_dataloader, 27 | prepare_seq_parallel_inputs, 28 | apply_seq_parallel_monkey_patch, 29 | apply_unsloth_offloaded_gradient_checkpoint_monkey_patch 30 | ) 31 | 32 | 33 | class LongRecipe_TRAIN: 34 | """ 35 | Train LLMs with LongRecipe. 
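    Training is split into three stages (see train_with_stage):
      stage 0: train on the preprocessed long-context data with the chosen `setting`
               (e.g. 'LongRecipe' position-id construction) and save to output_path/stage_0;
      stage 1: continue training the stage_0 checkpoint with 'full_str' preprocessing
               (prompt + response concatenation and fixed-length chunking);
      stage 2: average the stage_1 weights with the original base model via merge_model.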
36 | """ 37 | def __init__( 38 | self, 39 | output_path, 40 | log_path, 41 | model_path, 42 | data_path, 43 | right_points_path, 44 | fs_PI_path, 45 | learning_rate, 46 | gradient_accumulate_every, 47 | epoch, 48 | batch_size, 49 | parallel_mode, 50 | setting, 51 | stage, 52 | seq_length, 53 | target_length, 54 | num_proc, 55 | seed=None, 56 | 57 | ): 58 | # logger 59 | self.logger, self.tqdm_out = load_logger(log_path) 60 | # set_seed 61 | self.seed = seed 62 | set_seed(self.seed) 63 | 64 | # training setting 65 | self.setting = setting 66 | 67 | # paths 68 | self.model_path = model_path 69 | self.output_path = output_path 70 | self.data_path = data_path 71 | self.fs_PI_path = fs_PI_path 72 | self.right_points_path = right_points_path 73 | 74 | if self.output_path: 75 | os.makedirs(self.output_path, exist_ok=True) 76 | 77 | # parallel_mode 78 | self.parallel_mode = parallel_mode 79 | 80 | # training params 81 | self.learning_rate = learning_rate 82 | self.batch_size = batch_size 83 | self.gradient_accumulate_every = gradient_accumulate_every 84 | self.epoch = epoch 85 | self.stage = stage 86 | 87 | # config 88 | self.seq_length = seq_length 89 | self.target_length = target_length 90 | self.num_proc = num_proc 91 | 92 | 93 | def load_dataset_util(self, stage, data_path, tokenizer, batch_size, accelerator, model_type): 94 | try: 95 | train_dataset = load_dataset(data_path) 96 | except: 97 | train_dataset = load_dataset('json', data_files=data_path)['train'] 98 | 99 | 100 | if stage == 0: 101 | train_dataset = preprocess_dataset(train_dataset, 102 | tokenizer=tokenizer, 103 | model_type=model_type, 104 | setting_choice=self.setting, 105 | seq_length_=self.seq_length, 106 | target_len=self.target_length, 107 | right_points_file=self.right_points_path, 108 | fs_PI_s_file=self.fs_PI_path, 109 | num_proc=self.num_proc 110 | ) 111 | print("Dataset Size:", len(train_dataset)) 112 | 113 | train_loader = DataLoader( 114 | train_dataset, 115 | collate_fn=default_data_collator, 116 | shuffle=False, 117 | batch_size=batch_size, 118 | ) 119 | 120 | train_dataset_loader = prepare_dataloader(self.parallel_mode, train_loader, accelerator) 121 | 122 | 123 | if stage == 1: 124 | train_dataset = preprocess_dataset(train_dataset, 125 | tokenizer=tokenizer, 126 | setting_choice='full_str', 127 | model_type=self.model_path, 128 | seq_length_=self.seq_length, 129 | target_len=self.target_length, 130 | right_points_file=self.right_points_path, 131 | fs_PI_s_file=self.fs_PI_path, 132 | num_proc=self.num_proc 133 | ) 134 | print("Dataset Size:", len(train_dataset)) 135 | 136 | train_loader = DataLoader( 137 | train_dataset, 138 | collate_fn=default_data_collator, 139 | shuffle=False, 140 | batch_size=batch_size, 141 | ) 142 | train_dataset_loader = prepare_dataloader(self.parallel_mode, train_loader, accelerator) 143 | 144 | return train_dataset_loader 145 | 146 | 147 | def prepare_accelerator(self, stage): 148 | 149 | # accelerator 150 | timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000)) 151 | 152 | accelerator = Accelerator( 153 | gradient_accumulation_steps=self.gradient_accumulate_every if stage == 0 else self.gradient_accumulate_every, 154 | mixed_precision="bf16", 155 | kwargs_handlers=[timeout], 156 | # fsdp_plugin=fsdp_plugin, 157 | ) 158 | return accelerator 159 | 160 | def prepare_model_and_tokenizer(self, stage, accelerator): 161 | if stage == 0: 162 | model, tokenizer = load_model_and_tokenizer(self.model_path, accelerator) 163 | 164 | elif stage == 1: 165 | # model, tokenizer = 
load_model_and_tokenizer(self.model_path, accelerator) 166 | model, tokenizer = load_model_and_tokenizer(self.output_path + '/stage_' + str(stage-1) , accelerator) 167 | 168 | 169 | assert isinstance( 170 | model, (transformers.LlamaForCausalLM, transformers.MistralForCausalLM) 171 | ), "Only support llama and mistral model" 172 | model_type = ( 173 | "llama" if isinstance(model, transformers.LlamaForCausalLM) else "mistral" 174 | ) 175 | apply_seq_parallel_monkey_patch(self.parallel_mode, model_type) 176 | 177 | return model, tokenizer 178 | 179 | def prepare_scheduler(self, stage, optim, num_training_steps): 180 | # scheduler 181 | scheduler = DummyScheduler( 182 | optim, 183 | num_training_steps=num_training_steps, 184 | total_num_steps=num_training_steps, 185 | ) 186 | return scheduler 187 | 188 | 189 | def prepare_optimizer(self, stage, model, learning_rate): 190 | 191 | optim = DummyOptim(model.parameters(), lr=learning_rate) 192 | return optim 193 | 194 | def prepare_and_check_params(self, stage, epoch, train_dataset_loader, gradient_accumulate_every): 195 | 196 | # calculate training steps 197 | num_training_steps = math.ceil(epoch * train_dataset_loader.dataset.shape[0] / gradient_accumulate_every) 198 | print(num_training_steps) 199 | return num_training_steps 200 | 201 | def prepare_loss_fn(self): 202 | loss_func = CrossEntropyLoss(inplace_backward=True) 203 | return loss_func 204 | 205 | def prepare_modules(self, stage, lr, data_path, epoch, batch_size, gradient_accumulate_every): 206 | """ 207 | prepare accelerator, model, tokenizer, dataloader, optimizer, scheduler 208 | """ 209 | 210 | # if stage == 0: 211 | accelerator = self.prepare_accelerator(stage=stage) 212 | model, tokenizer = self.prepare_model_and_tokenizer(stage, accelerator) 213 | train_data_loader = self.load_dataset_util(stage=stage, data_path=data_path, tokenizer=tokenizer, batch_size=batch_size, accelerator=accelerator, model_type=self.model_path) 214 | optim = self.prepare_optimizer(stage=stage, model=model, learning_rate=lr) 215 | num_training_steps = self.prepare_and_check_params(stage=stage, epoch=epoch, train_dataset_loader=train_data_loader, gradient_accumulate_every=gradient_accumulate_every) 216 | scheduler = self.prepare_scheduler(stage, optim, num_training_steps) 217 | model, optim, scheduler = accelerator.prepare(model, optim, scheduler) 218 | 219 | model.gradient_checkpointing_enable() 220 | accelerator.register_for_checkpointing(scheduler) 221 | accelerator.print(f"Max train epoches: {epoch}, Max train steps: {num_training_steps}") 222 | 223 | progress_bar = tqdm( 224 | range(num_training_steps), file=self.tqdm_out, mininterval=1, disable=not accelerator.is_local_main_process 225 | ) 226 | 227 | loss_func = self.prepare_loss_fn() 228 | 229 | return model, accelerator, train_data_loader, loss_func, optim, scheduler, progress_bar 230 | 231 | 232 | def train(self, stage, model, accelerator, train_data_loader, loss_func, optim, scheduler, progress_bar): 233 | completed_steps = 0 234 | for idx, batch in enumerate(train_data_loader): 235 | 236 | input_ids = batch["input_ids"][0][..., :-1].unsqueeze(dim=0) 237 | target_ids = batch["input_ids"][0][..., 1:].unsqueeze(dim=0) 238 | if stage != 1: 239 | try: 240 | position_ids = batch["position_ids"][0][..., : input_ids.shape[-1]].unsqueeze(dim=0) 241 | except: 242 | print('Position idx error, check your pkl in train.py') 243 | position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0) 244 | 245 | prepared = prepare_seq_parallel_inputs( 246 | 
self.parallel_mode, 247 | input_ids, 248 | position_ids, 249 | target_ids, 250 | accelerator.process_index, 251 | accelerator.num_processes, 252 | accelerator.device, 253 | ) 254 | 255 | else: 256 | position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0) 257 | prepared = prepare_seq_parallel_inputs( 258 | self.parallel_mode, 259 | input_ids, 260 | position_ids, 261 | target_ids, 262 | accelerator.process_index, 263 | accelerator.num_processes, 264 | accelerator.device, 265 | ) 266 | 267 | local_input_ids = prepared["local_input_ids"] 268 | local_position_ids = prepared["local_position_ids"] 269 | local_target_ids = prepared["local_target_ids"] 270 | 271 | loss_log = None 272 | 273 | with accelerator.accumulate(model): 274 | logits = model( 275 | local_input_ids, 276 | position_ids=local_position_ids, 277 | ).logits 278 | loss = loss_func( 279 | logits.reshape(-1, logits.shape[-1]), local_target_ids.reshape(-1) 280 | ) 281 | try: 282 | accelerator.backward(loss) 283 | except: 284 | self.logger.log(msg='wrong step', level=0) 285 | 286 | if accelerator.sync_gradients: 287 | gathered_loss = accelerator.reduce(loss.clone().detach(), "mean") 288 | loss_log = { 289 | "loss": gathered_loss.item(), 290 | "ppl": math.exp(gathered_loss.item()), 291 | } 292 | accelerator.log(loss_log, step=completed_steps) 293 | 294 | optim.step() 295 | scheduler.step() 296 | optim.zero_grad() 297 | 298 | if accelerator.sync_gradients: 299 | progress_bar.update(1) 300 | if loss_log is not None: 301 | progress_bar.set_postfix(loss_log) 302 | completed_steps += 1 303 | return model, accelerator 304 | 305 | def finish_training(self, stage, model, accelerator): 306 | accelerator.print(f"Training Finished") 307 | accelerator.end_training() 308 | 309 | if self.output_path is not None: 310 | accelerator.print(f"Saving model to {self.output_path}, stage: {stage}") 311 | 312 | accelerator.wait_for_everyone() 313 | 314 | state_dict = accelerator.get_state_dict(model) 315 | 316 | accelerator.unwrap_model(model).save_pretrained( 317 | f"{self.output_path + '/stage_' + str(stage) + '/'}", 318 | is_main_process=accelerator.is_main_process, 319 | save_function=accelerator.save, 320 | state_dict=state_dict, 321 | ) 322 | 323 | accelerator.print(f"Saving Finished") 324 | 325 | def merge_model(self, stage, model_1_path, model_2_path): 326 | 327 | model_1 = AutoModelForCausalLM.from_pretrained(model_1_path) 328 | model_2 = AutoModelForCausalLM.from_pretrained(model_2_path) 329 | 330 | model_1_params = list(model_1.named_parameters()) 331 | model_2_params = list(model_2.named_parameters()) 332 | 333 | assert len(model_1_params) == len(model_2_params), "The two models do not have the same number of parameters" 334 | 335 | delta_params = {} 336 | 337 | for (name_1, param_1), (name_2, param_2) in zip(model_1.named_parameters(), model_2.named_parameters()): 338 | delta_params[name_1] = (param_2.data + param_1.data) / 2 339 | 340 | 341 | new_model = AutoModelForCausalLM.from_pretrained(model_1_path) 342 | 343 | 344 | for name, param in new_model.named_parameters(): 345 | if name in delta_params: 346 | param.data = delta_params[name] 347 | 348 | 349 | new_model.save_pretrained(self.output_path + '/stage_' + str(stage) + '/',) 350 | 351 | 352 | def train_with_stage(self): 353 | stage = self.stage 354 | 355 | if stage == 0: 356 | 357 | model, accelerator, train_data_loader, loss_func, optim, scheduler, progress_bar = self.prepare_modules(stage=stage, lr=self.learning_rate, data_path=self.data_path, epoch=self.epoch, 
batch_size=self.batch_size, gradient_accumulate_every=self.gradient_accumulate_every) 358 | model.train() 359 | model, accelerator = self.train(stage, model, accelerator, train_data_loader, loss_func, optim, scheduler, progress_bar) 360 | self.finish_training(stage, model, accelerator) 361 | 362 | elif stage == 1: 363 | model, accelerator, train_data_loader, loss_func, optim, scheduler, progress_bar = self.prepare_modules(stage=stage, lr=self.learning_rate, data_path=self.data_path, epoch=self.epoch, batch_size=self.batch_size, gradient_accumulate_every=self.gradient_accumulate_every) 364 | model.train() 365 | model, accelerator = self.train(stage, model, accelerator, train_data_loader, loss_func, optim, scheduler, progress_bar) 366 | self.finish_training(stage, model, accelerator) 367 | 368 | 369 | elif stage == 2: 370 | self.merge_model(stage, self.output_path + '/stage_' + str(stage-1) + '/', self.model_path) 371 | 372 | 373 | 374 | 375 | if __name__ == "__main__": 376 | args = argparse.ArgumentParser() 377 | args.add_argument("--batch-size", type=int, default=1) 378 | args.add_argument("--gradient-accumulate-every", type=int, default=8) 379 | args.add_argument("--epoch", type=int, default=1) 380 | args.add_argument("--learning-rate", type=float, default=2e-5) 381 | args.add_argument("--data_path",type=str, default=None) 382 | args.add_argument("--output-dir", type=str, required=True) 383 | args.add_argument("--wandb", type=str) 384 | args.add_argument("--seed", type=int, default=42) 385 | args.add_argument("--model", type=str, default="meta-llama/Llama-2-7b-hf") 386 | args.add_argument("--seq-length", type=int, default=16384) 387 | args.add_argument("--model-base-length", type=int, default=8192) 388 | args.add_argument("--target-length", type=int, default=128000) 389 | args.add_argument("--setting", type=str, default=None) 390 | args.add_argument("--right_points-path", type=str, default=None) 391 | args.add_argument("--fs_PI-path", type=str, default=None) 392 | args.add_argument("--log-path", type=str, default=None) 393 | args.add_argument("--stage", type=int, default=3) 394 | args.add_argument("--num_proc", type=int, default=1) 395 | args.add_argument( 396 | "--parallel_mode", 397 | type=str, 398 | choices=["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], 399 | ) 400 | 401 | args = args.parse_args() 402 | 403 | LongRecipe_train = LongRecipe_TRAIN( 404 | output_path=args.output_dir, 405 | log_path=args.log_path, 406 | model_path=args.model, 407 | data_path=args.data_path, 408 | right_points_path=args.right_points_path, 409 | fs_PI_path=args.fs_PI_path, 410 | seed=args.seed, 411 | learning_rate=args.learning_rate, 412 | gradient_accumulate_every=args.gradient_accumulate_every, 413 | epoch=args.epoch, 414 | batch_size=args.batch_size, 415 | parallel_mode=args.parallel_mode, 416 | stage=args.stage, 417 | setting=args.setting, 418 | seq_length=args.seq_length, 419 | target_length=args.target_length, 420 | num_proc=args.num_proc, 421 | ) 422 | 423 | LongRecipe_train.train_with_stage() 424 | -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/async_communication.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import math 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.distributed import batch_isend_irecv, P2POp, isend, irecv 8 | 9 | # Sequence parallel group that the current rank belongs to. 
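# (These module-level globals are filled in by _initialize_sequence_parallel(); the
#  _PEER_* / _DELTA_* tensors below are double-buffered scratch space that the forward
#  and backward P2P exchanges reuse instead of reallocating every step.)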
10 | _SEQUENCE_PARALLEL_GROUP = None 11 | 12 | # These values enable us to change the sequence parallel sizes on the fly. 13 | _SEQUENCE_PARALLEL_SIZE = None 14 | _SEQUENCE_PARALLEL_RANK = None 15 | 16 | # Global buffer for P2P 17 | _PEER_Q = None 18 | _PEER_K = None 19 | _PEER_V = None 20 | _PEER_M = None 21 | _PEER_L = None 22 | _PEER_O = None 23 | _PEER_Q_BWD = None 24 | _PEER_K_BWD = None 25 | _PEER_V_BWD = None 26 | _PEER_O_BWD = None 27 | 28 | _DELTA_DQ = None 29 | _PEER_L = None 30 | _DELTA_DK = None 31 | _DELTA_DV = None 32 | _DK_DELTA_FROM_PEER = None 33 | _DV_DELTA_FROM_PEER = None 34 | _PEER_DO = None 35 | 36 | 37 | _fwd_send_volume = 0 38 | _fwd_recv_volume = 0 39 | _bwd_send_volume = 0 40 | _bwd_recv_volume = 0 41 | 42 | def initialize_distributed(): 43 | if dist.is_initialized(): 44 | if dist.get_rank() == 0: 45 | print( 46 | "torch distributed is already initialized, " 47 | "skipping initialization ...", 48 | flush=True, 49 | ) 50 | else: 51 | if int(os.environ["RANK"]) == 0: 52 | print("Initializing Torch distributed.") 53 | dist.init_process_group(backend="nccl") 54 | local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 55 | global_world_size = dist.get_world_size() 56 | torch.cuda.set_device(dist.get_rank() % local_world_size) 57 | 58 | _initialize_sequence_parallel() 59 | # create_nccl_communicators() 60 | 61 | def _initialize_sequence_parallel(sequence_parallel_size=None): 62 | # Get world size and rank. Ensure some consistencies. 63 | assert sequence_parallel_size is None, "Multiple sequence parallel group not implemented." 64 | assert torch.distributed.is_initialized() 65 | world_size: int = torch.distributed.get_world_size() 66 | 67 | if sequence_parallel_size is None: 68 | sequence_parallel_size = world_size 69 | else: 70 | assert world_size % sequence_parallel_size == 0 71 | num_sequence_parallel_groups: int = world_size // sequence_parallel_size 72 | 73 | rank = torch.distributed.get_rank() 74 | 75 | # Build the sequence parallel groups. 76 | global _SEQUENCE_PARALLEL_GROUP 77 | global _SEQUENCE_PARALLEL_RANK 78 | global _SEQUENCE_PARALLEL_SIZE 79 | 80 | assert ( 81 | _SEQUENCE_PARALLEL_GROUP is None 82 | ), 'sequence parallel group is already initialized' 83 | for i in range(num_sequence_parallel_groups): 84 | ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) 85 | group = torch.distributed.new_group(ranks) 86 | if rank in ranks: 87 | _SEQUENCE_PARALLEL_GROUP = group 88 | _SEQUENCE_PARALLEL_RANK = ranks.index(rank) 89 | _SEQUENCE_PARALLEL_SIZE = len(ranks) 90 | 91 | if dist.get_rank() == 0: 92 | print("************ Finish sequence pralell group Initialization. 
***********") 93 | # _set_global_memory_buffer() 94 | 95 | def maybe_get_set_global_memory_buffer(q, k, v, m, l, o): 96 | global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O 97 | if _PEER_Q is None: 98 | try: 99 | if get_sequence_parallel_rank() == 0: 100 | print("Initializing global memoery buffer.") 101 | except: 102 | print("Initializing global memoery buffer.") 103 | _PEER_Q = [torch.empty_like(q) for _ in range(2)] 104 | _PEER_K = [torch.empty_like(k) for _ in range(2)] 105 | _PEER_V = [torch.empty_like(v) for _ in range(2)] 106 | _PEER_M = [torch.empty_like(m) for _ in range(2)] 107 | _PEER_L = [torch.empty_like(l) for _ in range(2)] 108 | _PEER_O = [torch.empty_like(o) for _ in range(2)] 109 | 110 | return _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O 111 | 112 | def maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do): 113 | global _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER,_PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO 114 | if _DELTA_DQ is None: 115 | try: 116 | if get_sequence_parallel_rank() == 0: 117 | print("Initializing global memoery buffer for backward.") 118 | except: 119 | print("Initializing global memoery buffer for backward.") 120 | _DELTA_DQ = [torch.empty_like(dq) for _ in range(2)] 121 | _DELTA_DK = [torch.empty_like(dk) for _ in range(2)] 122 | _DELTA_DV = [torch.empty_like(dv) for _ in range(2)] 123 | _PEER_L = [torch.empty_like(L) for _ in range(2)] 124 | 125 | _DK_DELTA_FROM_PEER = torch.empty_like(dk) 126 | _DV_DELTA_FROM_PEER = torch.empty_like(dv) 127 | 128 | # may already be initailized in the forward call. 129 | # current forward and backward needs a transpose in q's format 130 | _PEER_Q_BWD = [torch.empty_like(q) for _ in range(2)] 131 | _PEER_K_BWD = [torch.empty_like(k) for _ in range(2)] 132 | _PEER_V_BWD = [torch.empty_like(v) for _ in range(2)] 133 | _PEER_O_BWD = [torch.empty_like(o) for _ in range(2)] 134 | 135 | _PEER_DO = [torch.empty_like(do) for _ in range(2)] 136 | 137 | return _DELTA_DQ, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_Q_BWD, _PEER_L, _PEER_K_BWD, _PEER_V_BWD, _PEER_O_BWD, _PEER_DO 138 | 139 | def reset_global_memory_buffer(): 140 | global _PEER_Q, _PEER_K, _PEER_V, _PEER_M, _PEER_L, _PEER_O, _DELTA_DQ, _PEER_L, _DELTA_DK, _DELTA_DV, _DK_DELTA_FROM_PEER, _DV_DELTA_FROM_PEER, _PEER_DO 141 | _PEER_Q = None 142 | _PEER_K = None 143 | _PEER_V = None 144 | _PEER_M = None 145 | _PEER_L = None 146 | _PEER_O = None 147 | 148 | _DELTA_DQ = None 149 | _PEER_L = None 150 | _DELTA_DK = None 151 | _DELTA_DV = None 152 | _DK_DELTA_FROM_PEER = None 153 | _DV_DELTA_FROM_PEER = None 154 | _PEER_DO = None 155 | 156 | # Pytorch defers the creation of nccl communicators to the first P2P call, 157 | # We manually create them so the first isend does not hang without an irecv. 158 | # reference: https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L138 159 | # Only support even number of GPUs. 
160 | def create_nccl_communicators(): 161 | seq_rank = get_sequence_parallel_rank() 162 | seq_group = get_sequence_parallel_group() 163 | 164 | empty_tensor = torch.empty(1,).cuda() 165 | empty_tensor_2 = torch.empty(1,).cuda() 166 | if torch.distributed.get_rank() % 2 == 0: 167 | # sender 168 | op1 = P2POp(op=isend, tensor=torch.empty(1,).cuda(), peer=seq_rank+1, group=seq_group) 169 | op2 = P2POp(op=irecv, tensor=torch.empty(1,).cuda(), peer=seq_rank+1, group=seq_group) 170 | #req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) 171 | dist.batch_isend_irecv([op1, op2]) 172 | else: 173 | # receiver 174 | op1 = P2POp(op=irecv, tensor=torch.empty(1,).cuda(), peer=seq_rank-1, group=seq_group) 175 | op2 = P2POp(op=isend, tensor=torch.empty(1,).cuda(), peer=seq_rank-1, group=seq_group) 176 | #req = torch.distributed.isend(tensor=empty_tensor, dst=seq_rank + 1, group=seq_group) 177 | handles = dist.batch_isend_irecv([op1, op2]) 178 | #req = torch.distributed.irecv(tensor=empty_tensor, src=seq_rank - 1, group=seq_group) 179 | dist.all_reduce(empty_tensor, group=seq_group) 180 | 181 | def get_sequence_parallel_group(): 182 | """Get the sequence parallel group the caller rank belongs to.""" 183 | #global _SEQUENCE_PARALLEL_GROUP 184 | assert ( 185 | _SEQUENCE_PARALLEL_GROUP is not None 186 | ), 'sequence parallel group is not initialized' 187 | return _SEQUENCE_PARALLEL_GROUP 188 | 189 | def get_sequence_parallel_rank(): 190 | """Return my rank for the sequence parallel group.""" 191 | global _SEQUENCE_PARALLEL_RANK 192 | if _SEQUENCE_PARALLEL_RANK is not None: 193 | return _SEQUENCE_PARALLEL_RANK 194 | return torch.distributed.get_rank(group=get_sequence_parallel_group()) 195 | 196 | def get_sequence_parallel_size(): 197 | """Return my rank for the sequence parallel group.""" 198 | global _SEQUENCE_PARALLEL_SIZE 199 | if _SEQUENCE_PARALLEL_SIZE is not None: 200 | return _SEQUENCE_PARALLEL_SIZE 201 | return torch.distributed.get_world_size(group=get_sequence_parallel_group()) 202 | 203 | def destroy_sequence_parallel(): 204 | """Set the groups to none.""" 205 | global _SEQUENCE_PARALLEL_GROUP 206 | _SEQUENCE_PARALLEL_GROUP = None 207 | 208 | # whether this is the last time the kernel being called 209 | def is_last_time(time_step): 210 | # e.g. 
on a 8-GPU setup: 211 | # R=0: 0 212 | # R=1: 1 213 | # R=2: 2 214 | # R=3: 3 215 | # R=4: 4, 5, 6, 7 216 | seq_rank = get_sequence_parallel_rank() 217 | seq_world_size = get_sequence_parallel_size() 218 | if seq_rank <= seq_world_size // 2: # no one helps these ranks 219 | rank_finish_time = seq_rank 220 | else: 221 | rank_finish_time = seq_world_size // 2 222 | return rank_finish_time == time_step 223 | 224 | # Whether the current time step is computing for local q 225 | def is_compute_for_local_query(time_step): 226 | # R=3,4,5,6,7: Yes 227 | # R=0: 0 228 | # R=1: 0, 1 229 | # R=2: 0, 1, 2 230 | seq_rank = get_sequence_parallel_rank() 231 | seq_world_size = get_sequence_parallel_size() 232 | if seq_rank >= min(seq_world_size // 2, time_step): 233 | return True 234 | return False 235 | 236 | # Whether the current time step is idle 237 | def is_idle(time_step): 238 | # 0, 1, 2, 3: 4 239 | # 4, 5, 6, 7: No 240 | seq_rank = get_sequence_parallel_rank() 241 | seq_world_size = get_sequence_parallel_size() 242 | 243 | if seq_rank < (seq_world_size // 2) and time_step == seq_world_size // 2: 244 | return True 245 | return False 246 | 247 | # Whether the current time step needs to synchronize with a remote computed result 248 | def is_sync_from_remote(time_step): 249 | # R=0, 1, 2, 3, 4: No 250 | # R=5: 4 251 | # R=6: 3, 4 252 | # R=7: 2, 3, 4 253 | seq_rank = get_sequence_parallel_rank() 254 | seq_world_size = get_sequence_parallel_size() 255 | if seq_rank > max(seq_world_size // 2, seq_world_size - time_step): 256 | return True 257 | return False 258 | 259 | def maybe_send_recv_fwd_qkvo(q: torch.Tensor, peer_q: torch.Tensor, 260 | k: torch.Tensor, peer_k: torch.Tensor, 261 | v: torch.Tensor, peer_v: torch.Tensor, 262 | o_stats: list,# peer_o_stats: list, 263 | time_step: int, comm_mode, debug=False) -> torch.Tensor: 264 | 265 | seq_group = get_sequence_parallel_group() 266 | seq_rank = get_sequence_parallel_rank() 267 | seq_world_size = get_sequence_parallel_size() 268 | 269 | # Handles for operations that actually need to be wait before going to the next iteration. 270 | # For instance, QKV sender never needs to wait -> it seems fusing these calls help scheduler; 271 | all_handles = [] 272 | # KV logic: different than older version, every rank to send/recv its own kv, 273 | # to balance communication. In a balanced communication, every step each rank 274 | # should send/recv 4 tensors in total (kv, or qo). For instance, rank 0 when 275 | # time step > 0, should send its own kv and send/recv qo. In the older version, 276 | # rank 0 does not send its kv, and rely on a later rank to pass it, where the 277 | # later rank has to (1) receive kv, send rank 0's kv and send/recv qo. 278 | # Q (load balancing) logic: semantically, this will be "%" world size, so 279 | # the same send/recv rank as KV. Note: Only support even number of machines. 280 | # O (load balancing) logic: rank 0 sends result to rank 7 at time 1. 281 | # It get delayed for one time step, and thus has different maybe_send/recv_rank. 282 | # Use (time_step + 1) to easily convert to synchornize version. 
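# Example with seq_world_size=8, time_step=0: maybe_send_rank = rank + 1, so ranks 0-6
# send their own k/v to rank + 1, while rank 7 wraps around (8 >= world size) and sends
# its q to rank 0; the receive side mirrors this with maybe_recv_rank = rank - 1.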
283 | maybe_send_rank = seq_rank + (time_step + 1) 284 | maybe_recv_rank = seq_rank - (time_step + 1) 285 | 286 | if debug: 287 | global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume 288 | _debug_send = _fwd_send_volume 289 | _debug_recv = _fwd_recv_volume 290 | 291 | if maybe_send_rank >= seq_world_size: 292 | #send q, no one needs to do remote computation in the last time step 293 | if time_step < (seq_world_size // 2 - 1): 294 | #print(f"t={time_step}: R={seq_rank} sends q to {maybe_send_rank % seq_world_size} (not wait)") 295 | #q_send_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) 296 | all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) 297 | if debug: 298 | _fwd_send_volume += torch.numel(q) * q.element_size() 299 | else: 300 | # send kv 301 | #print(f"t={time_step}: R={seq_rank} sends kv to {maybe_send_rank} (not wait)") 302 | #kv_send_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) 303 | #kv_send_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) 304 | all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) 305 | all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) 306 | if debug: 307 | _fwd_send_volume += torch.numel(k) * k.element_size() 308 | _fwd_send_volume += torch.numel(v) * v.element_size() 309 | 310 | if maybe_recv_rank < 0: 311 | # recv q, no one needs to do remote computation in the last time step 312 | if time_step < (seq_world_size // 2 - 1): 313 | # print(f"t={time_step}: R={seq_rank} receives q from {maybe_recv_rank % seq_world_size} (wait)") 314 | #q_recv_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) 315 | all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) 316 | if debug: 317 | _fwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() 318 | else: 319 | # recv kv 320 | #print(f"t={time_step}: R={seq_rank} receivs kv from {maybe_recv_rank} (wait)") 321 | #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) 322 | #kv_recv_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) 323 | all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) 324 | all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) 325 | if debug: 326 | _fwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() 327 | _fwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() 328 | 329 | maybe_send_rank_o = seq_rank - (time_step - 1) 330 | maybe_recv_rank_o = seq_rank + (time_step - 1) 331 | if maybe_send_rank_o < 0 and time_step > 1: 332 | for t in o_stats: 333 | # print(f"t={time_step}: R={seq_rank} sends o to {maybe_send_rank_o % seq_world_size} (wait)") 334 | #o_send_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) 335 | all_handles.append(P2POp(op=isend, tensor=t, peer=maybe_send_rank_o % seq_world_size, group=seq_group)) 336 | if debug: 337 | _fwd_send_volume += torch.numel(t) * t.element_size() 338 | if maybe_recv_rank_o >= seq_world_size and time_step > 1 : 339 | for t in o_stats: 340 | # print(f"t={time_step}: R={seq_rank} receives o from {maybe_recv_rank_o % seq_world_size} (wait)") 341 | #o_recv_handles.append(P2POp(op=irecv, 
tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) 342 | all_handles.append(P2POp(op=irecv, tensor=t, peer=maybe_recv_rank_o % seq_world_size, group=seq_group)) 343 | if debug: 344 | _fwd_recv_volume += torch.numel(t) * t.element_size() 345 | 346 | #reqs = [] 347 | 348 | if debug: 349 | if seq_rank in [0, 8]: 350 | print(f"R={seq_rank} time_step={time_step} increases: send {(_fwd_send_volume - _debug_send) * 1e-9} GB recv {(_fwd_recv_volume - _debug_recv) * 1e-9} GB") 351 | #return reqs 352 | all_reqs = launch_async_handles(all_handles, comm_mode) 353 | return [all_reqs] 354 | 355 | # delta: may be you are using it for your local compute or as a distributed buffer to send to others 356 | # .. Sorry for the bad naming.. 357 | def maybe_send_recv_bwd_qkvo(dq_delta: torch.Tensor, dk_delta: torch.Tensor, 358 | dv_delta: torch.Tensor, dk_delta_from_peer: torch.Tensor, 359 | dv_delta_from_peer: torch.Tensor, q: torch.Tensor, 360 | peer_q: torch.Tensor, L: torch.Tensor, 361 | peer_L: torch.Tensor, k: torch.Tensor, 362 | peer_k: torch.Tensor, v: torch.Tensor, 363 | peer_v: torch.Tensor, o: torch.Tensor, 364 | peer_o: torch.Tensor, do: torch.Tensor, 365 | peer_do: torch.Tensor, time_step: int, comm_mode, debug=False): 366 | 367 | seq_group = get_sequence_parallel_group() 368 | seq_rank = get_sequence_parallel_rank() 369 | seq_world_size = get_sequence_parallel_size() 370 | 371 | all_handles = [] 372 | maybe_send_rank = seq_rank + (time_step + 1) 373 | maybe_recv_rank = seq_rank - (time_step + 1) 374 | 375 | if debug: 376 | global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume 377 | 378 | if maybe_send_rank >= seq_world_size: 379 | #send q, no one needs to do remote computation in the last time step 380 | if time_step < (seq_world_size // 2 - 1): 381 | all_handles.append(P2POp(op=isend, tensor=q, peer=maybe_send_rank % seq_world_size, group=seq_group)) 382 | all_handles.append(P2POp(op=isend, tensor=L, peer=maybe_send_rank % seq_world_size, group=seq_group)) 383 | all_handles.append(P2POp(op=isend, tensor=o, peer=maybe_send_rank % seq_world_size, group=seq_group)) 384 | all_handles.append(P2POp(op=isend, tensor=do, peer=maybe_send_rank % seq_world_size, group=seq_group)) 385 | if debug: 386 | _bwd_send_volume += torch.numel(q) * q.element_size() 387 | _bwd_send_volume += torch.numel(L) * L.element_size() 388 | _bwd_send_volume += torch.numel(o) * o.element_size() 389 | _bwd_send_volume += torch.numel(do) * do.element_size() 390 | else: 391 | # send kv 392 | all_handles.append(P2POp(op=isend, tensor=k, peer=maybe_send_rank, group=seq_group)) 393 | all_handles.append(P2POp(op=isend, tensor=v, peer=maybe_send_rank, group=seq_group)) 394 | if debug: 395 | _bwd_send_volume += torch.numel(k) * k.element_size() 396 | _bwd_send_volume += torch.numel(v) * v.element_size() 397 | 398 | if maybe_recv_rank < 0: 399 | # recv q, no one needs to do remote computation in the last time step 400 | if time_step < (seq_world_size // 2 - 1): 401 | all_handles.append(P2POp(op=irecv, tensor=peer_q, peer=maybe_recv_rank % seq_world_size, group=seq_group)) 402 | all_handles.append(P2POp(op=irecv, tensor=peer_L, peer=maybe_recv_rank % seq_world_size, group=seq_group)) 403 | all_handles.append(P2POp(op=irecv, tensor=peer_o, peer=maybe_recv_rank % seq_world_size, group=seq_group)) 404 | all_handles.append(P2POp(op=irecv, tensor=peer_do, peer=maybe_recv_rank % seq_world_size, group=seq_group)) 405 | if debug: 406 | _bwd_recv_volume += torch.numel(peer_q) * peer_q.element_size() 407 | 
_bwd_recv_volume += torch.numel(peer_L) * peer_L.element_size() 408 | _bwd_recv_volume += torch.numel(peer_o) * peer_o.element_size() 409 | _bwd_recv_volume += torch.numel(peer_do) * peer_do.element_size() 410 | else: 411 | # recv kv 412 | all_handles.append(P2POp(op=irecv, tensor=peer_k, peer=maybe_recv_rank, group=seq_group)) 413 | all_handles.append(P2POp(op=irecv, tensor=peer_v, peer=maybe_recv_rank, group=seq_group)) 414 | if debug: 415 | _bwd_recv_volume += torch.numel(peer_k) * peer_k.element_size() 416 | _bwd_recv_volume += torch.numel(peer_v) * peer_v.element_size() 417 | 418 | # Whether I should update dq, dk and dv after waiting these requests 419 | is_update_dq = False 420 | is_update_dkv = False 421 | 422 | maybe_send_rank_dqkv = seq_rank - (time_step - 1) 423 | maybe_recv_rank_dqkv = seq_rank + (time_step - 1) 424 | 425 | if time_step > 1: 426 | if maybe_send_rank_dqkv < 0: 427 | #print(f"BWD t={time_step}: R={seq_rank} sends dq delta to {maybe_send_rank_dqkv % seq_world_size}") 428 | all_handles.append(P2POp(op=isend, tensor=dq_delta, peer=maybe_send_rank_dqkv % seq_world_size, group=seq_group)) 429 | if debug: 430 | _bwd_send_volume += torch.numel(dq_delta) * dq_delta.element_size() 431 | else: 432 | #print(f"BWD t={time_step}: R={seq_rank} sends dkv delta to {maybe_send_rank_dqkv}") 433 | all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank_dqkv, group=seq_group)) 434 | all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank_dqkv, group=seq_group)) 435 | if debug: 436 | _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() 437 | _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() 438 | 439 | if maybe_recv_rank_dqkv >= seq_world_size: 440 | #print(f"BWD t={time_step}: R={seq_rank} receives dq delta to {maybe_recv_rank_dqkv % seq_world_size}") 441 | all_handles.append(P2POp(op=irecv, tensor=dq_delta, peer=maybe_recv_rank_dqkv % seq_world_size, group=seq_group)) 442 | is_update_dq = True 443 | if debug: 444 | _bwd_recv_volume += torch.numel(dq_delta) * dq_delta.element_size() 445 | else: 446 | #print(f"BWD t={time_step}: R={seq_rank} receives dk dv delta from {maybe_recv_rank_dqkv}") 447 | all_handles.append(P2POp(op=irecv, tensor=dk_delta_from_peer, peer=maybe_recv_rank_dqkv, group=seq_group)) 448 | all_handles.append(P2POp(op=irecv, tensor=dv_delta_from_peer, peer=maybe_recv_rank_dqkv, group=seq_group)) 449 | is_update_dkv = True 450 | if debug: 451 | _bwd_recv_volume += torch.numel(dk_delta_from_peer) * dk_delta_from_peer.element_size() 452 | _bwd_recv_volume += torch.numel(dv_delta_from_peer) * dv_delta_from_peer.element_size() 453 | 454 | # return [], is_update_dq, is_update_dkv 455 | all_reqs = launch_async_handles(all_handles, comm_mode) 456 | return [all_reqs], is_update_dq, is_update_dkv 457 | 458 | def maybe_send_recv_bwd_last_dkv(dk_delta: torch.Tensor, dv_delta: torch.Tensor, time_step, comm_mode, debug=False): 459 | is_update_last_dkv = False 460 | 461 | seq_group = get_sequence_parallel_group() 462 | seq_rank = get_sequence_parallel_rank() 463 | seq_world_size = get_sequence_parallel_size() 464 | 465 | if seq_world_size == 1: return [], is_update_last_dkv 466 | 467 | all_handles = [] 468 | 469 | if debug: 470 | global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume 471 | 472 | if time_step == seq_world_size // 2: 473 | maybe_send_rank = seq_rank - time_step 474 | maybe_recv_rank = seq_rank + time_step 475 | 476 | assert (maybe_send_rank >= 0) ^ (maybe_recv_rank < 
seq_world_size), "R={seq_rank} should be either sending or receiving dkv in the last time step." 477 | 478 | if maybe_send_rank >= 0: 479 | # print(f"BWD t={time_step}: R={seq_rank} last send dkv to {maybe_send_rank}") 480 | all_handles.append(P2POp(op=isend, tensor=dk_delta, peer=maybe_send_rank, group=seq_group)) 481 | all_handles.append(P2POp(op=isend, tensor=dv_delta, peer=maybe_send_rank, group=seq_group)) 482 | if debug: 483 | _bwd_send_volume += torch.numel(dk_delta) * dk_delta.element_size() 484 | _bwd_send_volume += torch.numel(dv_delta) * dv_delta.element_size() 485 | if maybe_recv_rank < seq_world_size: 486 | # print(f"BWD t={time_step}: R={seq_rank} last receive dkv from {maybe_recv_rank}") 487 | all_handles.append(P2POp(op=irecv, tensor=dk_delta, peer=maybe_recv_rank, group=seq_group)) 488 | all_handles.append(P2POp(op=irecv, tensor=dv_delta, peer=maybe_recv_rank, group=seq_group)) 489 | if debug: 490 | _bwd_recv_volume += torch.numel(dk_delta) * dk_delta.element_size() 491 | _bwd_recv_volume += torch.numel(dv_delta) * dv_delta.element_size() 492 | is_update_last_dkv = True 493 | 494 | # return [], is_update_last_dkv 495 | all_reqs = launch_async_handles(all_handles, comm_mode) 496 | 497 | return [all_reqs], is_update_last_dkv 498 | 499 | def print_and_reset_comm_stats(): 500 | seq_rank = get_sequence_parallel_rank() 501 | 502 | global _fwd_send_volume, _fwd_recv_volume, _bwd_send_volume, _bwd_recv_volume 503 | _fwd_send_volume *= 1e-9 504 | _fwd_recv_volume *= 1e-9 505 | _bwd_send_volume *= 1e-9 506 | _bwd_recv_volume *= 1e-9 507 | 508 | print(f"R={seq_rank} fwd send: {_fwd_send_volume} fwd recv: {_fwd_recv_volume}; bwd send: {_bwd_send_volume}, bwd recv: {_bwd_recv_volume} GB.") 509 | _fwd_send_volume = 0 510 | _fwd_recv_volume = 0 511 | _bwd_send_volume = 0 512 | _bwd_recv_volume = 0 513 | 514 | def launch_async_handles(handles, comm_mode): 515 | global _args 516 | if comm_mode == "nocomm": 517 | #print("skipping communication for ablation") 518 | return [] 519 | if len(handles) > 0: 520 | return dist.batch_isend_irecv(handles) 521 | return [] 522 | 523 | def wait_async_handles(reqs): 524 | if len(reqs) > 0: 525 | for req in reqs: 526 | for r in req: 527 | r.wait() -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Materialization-aware gradient checkpointing monkey patch. 
3 | """ 4 | from typing import List, Optional, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | from torch.utils.checkpoint import _get_autocast_kwargs, check_backward_validity, get_device_states, set_device_states, detach_variable 9 | 10 | import transformers 11 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast 12 | 13 | from einops import rearrange 14 | 15 | from .lightseq_async_attn import _lightseq_forward, _lightseq_backward 16 | from .async_communication import initialize_distributed, reset_global_memory_buffer 17 | 18 | 19 | # define a global buffer to save flash attention outputs 20 | # it's called global because it saves the outputs for all layers 21 | global_flash_attn_out_buffer = None 22 | 23 | # define a local buffer to save recomputed qkv 24 | # it's called local because it's a temporary buffer which will be updated across layers 25 | local_res_grad_buffer = None 26 | 27 | # hooks for the gradients of residual 28 | global_hooks = [] 29 | 30 | def init_flash_attn_buffers(num_layers): 31 | # update the global buffer according to number of layers 32 | global global_flash_attn_out_buffer 33 | global_flash_attn_out_buffer = [None] * num_layers 34 | 35 | def clean_hook(): 36 | # Remove all hooks in the global buffer 37 | for hook in global_hooks: 38 | hook.remove() 39 | # Clear the global buffer 40 | global_hooks.clear() 41 | 42 | def clear_all_buffers_at_the_end_of_training(): 43 | # call it at the end of training 44 | global lobal_flash_attn_out_buffer 45 | global_flash_attn_out_buffer = None 46 | global local_res_grad_buffer 47 | local_res_grad_buffer = None 48 | clean_hook() 49 | 50 | def save_flash_attn_out_to_global_buffer(idx, out): 51 | global global_flash_attn_out_buffer 52 | global_flash_attn_out_buffer[idx] = out 53 | 54 | def get_flash_attn_out_from_global_buffer(idx): 55 | global global_flash_attn_out_buffer 56 | return global_flash_attn_out_buffer[idx] 57 | 58 | def free_flash_attn_out_buffer(idx): 59 | global global_flash_attn_out_buffer 60 | global_flash_attn_out_buffer[idx] = None 61 | 62 | def write_gradient_to_flash_attn_out(idx, grad): 63 | global global_flash_attn_out_buffer 64 | global_flash_attn_out_buffer[idx].grad = grad 65 | 66 | def save_res_grad_hook(grad): 67 | global local_res_grad_buffer 68 | local_res_grad_buffer = grad 69 | 70 | def load_and_add_res_grad_hook(grad): 71 | grad += get_res_grad_from_local_buffer() 72 | 73 | def get_res_grad_from_local_buffer(): 74 | global local_res_grad_buffer 75 | assert local_res_grad_buffer is not None 76 | return local_res_grad_buffer 77 | 78 | class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function): 79 | """ Avoid doing twice flash attention forward during checkpointed backward. 80 | args: 81 | hidden_states, # i.e., flash attention output which is saved in global buffer. 82 | attention_mask, 83 | position_ids, 84 | residual, # the gradient of residual is saved in local buffer to pass across ckpt layers. 85 | """ 86 | 87 | @staticmethod 88 | def forward(ctx, run_function, layer_idx, preserve_rng_state, *args): 89 | check_backward_validity(args) 90 | ctx.run_function = run_function 91 | ctx.layer_idx = layer_idx 92 | ctx.preserve_rng_state = preserve_rng_state 93 | # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. 
94 | ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() 95 | if preserve_rng_state: 96 | ctx.fwd_cpu_state = torch.get_rng_state() 97 | # Don't eagerly initialize the cuda context by accident. 98 | # (If the user intends that the context is initialized later, within their 99 | # run_function, we SHOULD actually stash the cuda state here. Unfortunately, 100 | # we have no way to anticipate this will happen before we run the function.) 101 | ctx.had_cuda_in_fwd = False 102 | if torch.cuda._initialized: 103 | ctx.had_cuda_in_fwd = True 104 | ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) 105 | 106 | # Save non-tensor inputs in ctx, keep a placeholder None for tensors 107 | # to be filled out during the backward. 108 | ctx.inputs = [] 109 | ctx.tensor_indices = [] 110 | tensor_inputs = [] 111 | for i, arg in enumerate(args): 112 | if i == 0 and ctx.layer_idx != 0: 113 | # flash attention output is saved to the global buffer during forward 114 | ctx.inputs.append(None) 115 | else: 116 | if torch.is_tensor(arg): 117 | tensor_inputs.append(arg) 118 | ctx.tensor_indices.append(i) 119 | ctx.inputs.append(None) 120 | else: 121 | ctx.inputs.append(arg) 122 | 123 | with torch.no_grad(): 124 | q, k, v, residual = run_function(*args) 125 | softmax_scale = q.shape[-1] ** (-0.5) 126 | 127 | # lightseq version 128 | _, _, _, out, softmax_lse = _lightseq_forward(q, k, v, True, softmax_scale, comm_mode='lightseq') 129 | rng_state = None 130 | 131 | # save flash attention output to global buffer 132 | save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) 133 | tensor_inputs += [softmax_lse] 134 | ctx.softmax_scale = softmax_scale 135 | 136 | ctx.save_for_backward(*tensor_inputs) 137 | 138 | return out, residual 139 | 140 | @staticmethod 141 | def backward(ctx, *args): 142 | if not torch.autograd._is_checkpoint_valid(): 143 | raise RuntimeError( 144 | "Checkpointing is not compatible with .grad() or when an `inputs` parameter" 145 | " is passed to .backward(). Please use .backward() and do not pass its `inputs`" 146 | " argument.") 147 | # Copy the list to avoid modifying original list. 148 | inputs = list(ctx.inputs) 149 | tensor_indices = ctx.tensor_indices 150 | tensors = ctx.saved_tensors 151 | tensors, softmax_lse = tensors[:-1], tensors[-1] 152 | 153 | # Fill in inputs with appropriate saved tensors. 154 | # Fill the flash attention output first 155 | if ctx.layer_idx > 0: 156 | # inputs[0] should be flash attention output 157 | inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) 158 | for i, idx in enumerate(tensor_indices): 159 | inputs[idx] = tensors[i] 160 | 161 | # Stash the surrounding rng state, and mimic the state that was 162 | # present at this time during forward. Restore the surrounding state 163 | # when we're done. 
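# (The recomputation below stops at q/k/v: flash attention itself is not re-run.
#  dq/dk/dv come from _lightseq_backward using the buffered output and the saved
#  softmax_lse, and are then backpropagated through the recomputed projections.)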
164 | rng_devices = [] 165 | if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: 166 | rng_devices = ctx.fwd_gpu_devices 167 | with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): 168 | if ctx.preserve_rng_state: 169 | torch.set_rng_state(ctx.fwd_cpu_state) 170 | if ctx.had_cuda_in_fwd: 171 | set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) 172 | detached_inputs = detach_variable(tuple(inputs)) 173 | with torch.enable_grad(), \ 174 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ 175 | torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): 176 | # Stop recomputation before flash attention 177 | # It is unecessary to run recomputation for flash attn 178 | q, k, v, residual = ctx.run_function(*detached_inputs) 179 | 180 | # run backward() with only tensor that requires grad 181 | # run flash attention backward first: 182 | # get 'dout' from auto_grad inputs 183 | # get 'out' from global buffer 184 | # get 'qkv' from the recomputed tensors 185 | #dq = torch.empty(q.shape, dtype=q.dtype, device=q.device) 186 | #dk = torch.empty(k.shape, dtype=q.dtype, device=q.device) 187 | #dv = torch.empty(v.shape, dtype=q.dtype, device=q.device) 188 | out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) 189 | # todo get dout 190 | dout = args[0] 191 | 192 | # lightseq version 193 | dq, dk, dv = _lightseq_backward(dout, q, k, v, out, softmax_lse, ctx.softmax_scale, comm_mode='lightseq', backward_engine='flash') 194 | #dqkv = torch.stack([dq, dk, dv]) 195 | 196 | # run backward for the part before flash attention 197 | #qkv.backward(dqkv) 198 | torch.autograd.backward([q, k, v], [dq, dk, dv]) 199 | 200 | grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None 201 | for inp in detached_inputs) 202 | 203 | # write flash attention output gradients to buffer 204 | if ctx.layer_idx > 0: 205 | write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) 206 | 207 | return (None, None, None) + grads 208 | 209 | 210 | def checkpoint_end_with_flash_attention(function, layer_idx, *args, use_reentrant: bool = True, **kwargs): 211 | # Hack to mix *args with **kwargs in a python 2.7-compliant way 212 | preserve = kwargs.pop('preserve_rng_state', True) 213 | if kwargs and use_reentrant: 214 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 215 | 216 | return CheckpointFunctionEndWithFlashAttention.apply(function, layer_idx, preserve, *args) 217 | 218 | 219 | class CheckpointFunctionLastModule(torch.autograd.Function): 220 | """ 221 | for the last ffn layer after flash attention, modifications include: 222 | write the gradients wrt flash attention output and residual to the global buffer. 223 | """ 224 | 225 | @staticmethod 226 | def forward(ctx, run_function, preserve_rng_state, *args): 227 | check_backward_validity(args) 228 | ctx.run_function = run_function 229 | ctx.preserve_rng_state = preserve_rng_state 230 | # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. 231 | ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() 232 | if preserve_rng_state: 233 | ctx.fwd_cpu_state = torch.get_rng_state() 234 | # Don't eagerly initialize the cuda context by accident. 235 | # (If the user intends that the context is initialized later, within their 236 | # run_function, we SHOULD actually stash the cuda state here. Unfortunately, 237 | # we have no way to anticipate this will happen before we run the function.) 
238 | ctx.had_cuda_in_fwd = False 239 | if torch.cuda._initialized: 240 | ctx.had_cuda_in_fwd = True 241 | ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) 242 | 243 | # Save non-tensor inputs in ctx, keep a placeholder None for tensors 244 | # to be filled out during the backward. 245 | ctx.inputs = [] 246 | ctx.tensor_indices = [] 247 | tensor_inputs = [] 248 | 249 | assert torch.is_tensor(args[0]), "assuming the first tensor is the flash attention output" 250 | for i, arg in enumerate(args): 251 | if torch.is_tensor(arg) and i == 0: 252 | # flash attn output has been saved to global buffer 253 | ctx.inputs.append(None) 254 | elif torch.is_tensor(arg): 255 | tensor_inputs.append(arg) 256 | ctx.tensor_indices.append(i) 257 | ctx.inputs.append(None) 258 | else: 259 | ctx.inputs.append(arg) 260 | 261 | ctx.save_for_backward(*tensor_inputs) 262 | 263 | with torch.no_grad(): 264 | outputs = run_function(*args) 265 | return outputs 266 | 267 | @staticmethod 268 | def backward(ctx, *args): 269 | if not torch.autograd._is_checkpoint_valid(): 270 | raise RuntimeError( 271 | "Checkpointing is not compatible with .grad() or when an `inputs` parameter" 272 | " is passed to .backward(). Please use .backward() and do not pass its `inputs`" 273 | " argument.") 274 | # Copy the list to avoid modifying original list. 275 | inputs = list(ctx.inputs) 276 | tensor_indices = ctx.tensor_indices 277 | tensors = ctx.saved_tensors 278 | 279 | # Fill in inputs with appropriate saved tensors. 280 | # Fill the flash attention output first 281 | # inputs[0] should be flash attention output 282 | inputs[0] = get_flash_attn_out_from_global_buffer(-1) 283 | for i, idx in enumerate(tensor_indices): 284 | inputs[idx] = tensors[i] 285 | 286 | # Stash the surrounding rng state, and mimic the state that was 287 | # present at this time during forward. Restore the surrounding state 288 | # when we're done. 
289 | rng_devices = [] 290 | if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: 291 | rng_devices = ctx.fwd_gpu_devices 292 | with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): 293 | if ctx.preserve_rng_state: 294 | torch.set_rng_state(ctx.fwd_cpu_state) 295 | if ctx.had_cuda_in_fwd: 296 | set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) 297 | detached_inputs = detach_variable(tuple(inputs)) 298 | with torch.enable_grad(), \ 299 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ 300 | torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): 301 | outputs = ctx.run_function(*detached_inputs) 302 | 303 | if isinstance(outputs, torch.Tensor): 304 | outputs = (outputs,) 305 | 306 | # run backward() with only tensor that requires grad 307 | outputs_with_grad = [] 308 | args_with_grad = [] 309 | for i in range(len(outputs)): 310 | if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: 311 | outputs_with_grad.append(outputs[i]) 312 | args_with_grad.append(args[i]) 313 | if len(outputs_with_grad) == 0: 314 | raise RuntimeError( 315 | "none of output has requires_grad=True," 316 | " this checkpoint() is not necessary") 317 | torch.autograd.backward(outputs_with_grad, args_with_grad) 318 | grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None 319 | for inp in detached_inputs) 320 | 321 | # write flash attention output gradients to buffer 322 | write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) 323 | 324 | return (None, None) + grads 325 | 326 | def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): 327 | preserve = kwargs.pop('preserve_rng_state', True) 328 | if kwargs and use_reentrant: 329 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 330 | 331 | return CheckpointFunctionLastModule.apply(function, preserve, *args) 332 | 333 | 334 | def llama_layer_forward( 335 | self, 336 | hidden_states: torch.Tensor, 337 | attention_mask: Optional[torch.Tensor] = None, 338 | position_ids: Optional[torch.LongTensor] = None, 339 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 340 | output_attentions: Optional[bool] = False, 341 | use_cache: Optional[bool] = False, 342 | compute_attn_only: Optional[bool] = False, 343 | compute_ffn_only: Optional[bool] = False, 344 | residual: Optional[bool] = None, 345 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 346 | """ 347 | Args: 348 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 349 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 350 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 351 | output_attentions (`bool`, *optional*): 352 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 353 | returned tensors for more detail. 354 | use_cache (`bool`, *optional*): 355 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 356 | (see `past_key_values`). 
357 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 358 | """ 359 | assert compute_ffn_only or compute_attn_only 360 | 361 | if compute_attn_only: 362 | residual = hidden_states 363 | 364 | if residual.requires_grad: 365 | # register a hook to add the gradient of residual 366 | # from next checkpoint layer when doing recomputation 367 | hook = residual.register_hook(load_and_add_res_grad_hook) 368 | global_hooks.append(hook) 369 | 370 | hidden_states = self.input_layernorm(hidden_states) 371 | 372 | # Flash Attention 373 | bsz, q_len, _ = hidden_states.size() 374 | try: 375 | query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 376 | key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim).transpose(1, 2) 377 | value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_key_value_heads, self.self_attn.head_dim).transpose(1, 2) 378 | except: 379 | # old transformers versions don't support num_key_value_heads 380 | query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 381 | key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 382 | value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 383 | 384 | kv_seq_len = key_states.shape[-2] 385 | assert past_key_value is None, "past_key_value is not supported" 386 | 387 | cos, sin = self.self_attn.rotary_emb(value_states, position_ids) 388 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 389 | # [bsz, nh, t, hd] 390 | assert not output_attentions, "output_attentions is not supported" 391 | assert not use_cache, "use_cache is not supported" 392 | return query_states.contiguous(), key_states.contiguous(), value_states.contiguous(), residual 393 | 394 | elif compute_ffn_only: 395 | hidden_states = self.self_attn.o_proj(rearrange(hidden_states, 'b h s d -> b s (h d)')) 396 | # Need to add residual here to make sure checkpoint is right after attention 397 | if residual.requires_grad: 398 | # save the gradient of residual to the local buffer 399 | # collect the hooks which should be removed after backward to avoid memory leak 400 | hook = residual.register_hook(save_res_grad_hook) 401 | global_hooks.append(hook) 402 | 403 | hidden_states = residual + hidden_states 404 | 405 | # Fully Connected 406 | 407 | residual = hidden_states 408 | hidden_states = self.post_attention_layernorm(hidden_states) 409 | hidden_states = self.mlp(hidden_states) 410 | hidden_states = residual + hidden_states 411 | 412 | outputs = (hidden_states,) 413 | 414 | else: 415 | raise AttributeError 416 | 417 | return outputs 418 | 419 | 420 | def forward( 421 | self, 422 | input_ids: torch.LongTensor = None, 423 | attention_mask: Optional[torch.Tensor] = None, 424 | position_ids: Optional[torch.LongTensor] = None, 425 | past_key_values: Optional[List[torch.FloatTensor]] = None, 426 | inputs_embeds: Optional[torch.FloatTensor] = None, 427 | use_cache: Optional[bool] = None, 428 | output_attentions: Optional[bool] = None, 429 | output_hidden_states: Optional[bool] = None, 430 | cache_position: Optional[torch.LongTensor] = None, 431 | return_dict: Optional[bool] = None, 432 | ): 
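# Patched LlamaModel.forward: forbids cache_position, drops the attention mask after
# embedding, and, when gradient checkpointing is enabled during training, runs the
# decoder through len(self.layers) + 1 attention-aware checkpoint segments built from
# the wrappers above (use_cache is forced off in that case).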
433 | assert cache_position is None, "cache_position is not supported" 434 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 435 | output_hidden_states = ( 436 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 437 | ) 438 | use_cache = use_cache if use_cache is not None else self.config.use_cache 439 | 440 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 441 | 442 | # retrieve input_ids and inputs_embeds 443 | if input_ids is not None and inputs_embeds is not None: 444 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 445 | elif input_ids is not None: 446 | batch_size, seq_length = input_ids.shape 447 | elif inputs_embeds is not None: 448 | batch_size, seq_length, _ = inputs_embeds.shape 449 | else: 450 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 451 | 452 | seq_length_with_past = seq_length 453 | past_key_values_length = 0 454 | 455 | if past_key_values is not None: 456 | past_key_values_length = past_key_values[0][0].shape[2] 457 | seq_length_with_past = seq_length_with_past + past_key_values_length 458 | 459 | if position_ids is None: 460 | device = input_ids.device if input_ids is not None else inputs_embeds.device 461 | position_ids = torch.arange( 462 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 463 | ) 464 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 465 | else: 466 | position_ids = position_ids.view(-1, seq_length).long() 467 | 468 | if inputs_embeds is None: 469 | inputs_embeds = self.embed_tokens(input_ids) 470 | # embed positions 471 | attention_mask = None 472 | 473 | hidden_states = inputs_embeds 474 | 475 | if self.gradient_checkpointing and self.training: 476 | try: 477 | logger.warning_once( 478 | "***** Using fast gradient checkpointing... *****" 479 | ) 480 | except: 481 | pass 482 | # initialize the global buffer 483 | init_flash_attn_buffers(len(self.layers)) 484 | 485 | if use_cache: 486 | try: 487 | logger.warning_once( 488 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
489 | ) 490 | except: 491 | pass 492 | use_cache = False 493 | 494 | # decoder layers 495 | all_hidden_states = () if output_hidden_states else None 496 | all_self_attns = () if output_attentions else None 497 | next_decoder_cache = () if use_cache else None 498 | 499 | # apply flash-attention friendly gradient checkpointing 500 | if self.gradient_checkpointing and self.training: 501 | for idx in range(len(self.layers) + 1): 502 | if output_hidden_states: 503 | all_hidden_states += (hidden_states,) 504 | 505 | past_key_value = past_key_values[idx] if past_key_values is not None else None 506 | 507 | def forward_first_attn_module(module): 508 | def custom_forward(*inputs): 509 | hidden_states, attention_mask, position_ids, _ = inputs 510 | # None for past_key_value 511 | return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) 512 | return custom_forward 513 | 514 | def forward_ffn_attn_layer(module1, module2): 515 | def custom_forward(*inputs): 516 | hidden_states, attention_mask, position_ids, residual = inputs 517 | # None for past_key_value 518 | layer_outputs = module1(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual) 519 | hidden_states = layer_outputs[0] 520 | return module2(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) 521 | return custom_forward 522 | 523 | def forward_last_ffn_module(module): 524 | def custom_forward(*inputs): 525 | hidden_states, attention_mask, position_ids, residual = inputs 526 | # None for past_key_value 527 | return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual) 528 | return custom_forward 529 | 530 | if idx == 0: 531 | layer_outputs = checkpoint_end_with_flash_attention( 532 | forward_first_attn_module(self.layers[0]), 533 | idx, 534 | hidden_states, 535 | attention_mask, 536 | position_ids, 537 | None, 538 | ) 539 | hidden_states, residual = layer_outputs[0], layer_outputs[-1] 540 | elif idx == len(self.layers): 541 | layer_outputs = checkpoint_last_module( 542 | forward_last_ffn_module(self.layers[-1]), 543 | hidden_states, 544 | attention_mask, 545 | position_ids, 546 | residual, 547 | ) 548 | hidden_states = layer_outputs[0] 549 | else: 550 | layer_outputs = checkpoint_end_with_flash_attention( 551 | forward_ffn_attn_layer(self.layers[idx-1], self.layers[idx]), 552 | idx, 553 | hidden_states, 554 | attention_mask, 555 | position_ids, 556 | residual, 557 | ) 558 | hidden_states, residual = layer_outputs[0], layer_outputs[-1] 559 | 560 | if use_cache: 561 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 562 | 563 | if output_attentions: 564 | all_self_attns += (layer_outputs[1],) 565 | else: 566 | for idx, decoder_layer in enumerate(self.layers): 567 | if output_hidden_states: 568 | all_hidden_states += (hidden_states,) 569 | 570 | past_key_value = past_key_values[idx] if past_key_values is not None else None 571 | 572 | layer_outputs = decoder_layer( 573 | hidden_states, 574 | attention_mask=attention_mask, 575 | position_ids=position_ids, 576 | past_key_value=past_key_value, 577 | output_attentions=output_attentions, 578 | use_cache=use_cache, 579 | ) 580 | 581 | hidden_states = layer_outputs[0] 582 | 583 | if use_cache: 584 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 585 | 586 | if output_attentions: 587 | all_self_attns += 
(layer_outputs[1],) 588 | 589 | hidden_states = self.norm(hidden_states) 590 | 591 | # add hidden states from the last decoder layer 592 | if output_hidden_states: 593 | all_hidden_states += (hidden_states,) 594 | 595 | next_cache = next_decoder_cache if use_cache else None 596 | if not return_dict: 597 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 598 | return BaseModelOutputWithPast( 599 | last_hidden_state=hidden_states, 600 | past_key_values=next_cache, 601 | hidden_states=all_hidden_states, 602 | attentions=all_self_attns, 603 | ) 604 | 605 | 606 | def apply_dist_flash_attn_monkey_patch_llama(): 607 | initialize_distributed() 608 | transformers.models.llama.modeling_llama.LlamaModel.forward = forward 609 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward 610 | -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/lightseq_async_attn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | from einops import rearrange 5 | import argparse 6 | 7 | import pytest 8 | import torch 9 | import torch.distributed as dist 10 | from torch.distributed import ReduceOp 11 | #from torch.profiler import profile, record_function, ProfilerActivity 12 | import functools 13 | import triton 14 | import triton.language as tl 15 | import time 16 | import numpy as np 17 | from tqdm import tqdm 18 | 19 | try: 20 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 21 | except: 22 | pass 23 | 24 | from .async_communication import (is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, print_and_reset_comm_stats, 25 | launch_async_handles, wait_async_handles, maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, reset_global_memory_buffer, 26 | maybe_get_set_global_memory_buffer, maybe_get_set_global_memory_buffer_bwd, initialize_distributed, get_sequence_parallel_size, get_sequence_parallel_rank) 27 | 28 | @triton.jit 29 | def max_fn(x, y): 30 | return tl.math.max(x, y) 31 | 32 | @triton.jit 33 | def _rescale_kernel( 34 | peer_m, 35 | m, 36 | peer_l, 37 | l, 38 | peer_o, 39 | o, 40 | L, 41 | stride_oz, stride_oh, stride_om, stride_on, 42 | Z, H, N_CTX, 43 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 44 | BLOCK_N: tl.constexpr, 45 | LAST_STEP: tl.constexpr, 46 | ): 47 | start_m = tl.program_id(0) 48 | off_hz = tl.program_id(1) 49 | o_offset = off_hz * stride_oh 50 | peer_o_block_ptr = tl.make_block_ptr( 51 | base=peer_o + o_offset, 52 | shape=(N_CTX, BLOCK_DMODEL), 53 | strides=(stride_om, stride_on), 54 | offsets=(start_m * BLOCK_M, 0), 55 | block_shape=(BLOCK_M, BLOCK_DMODEL), 56 | order=(1, 0) 57 | ) 58 | o_block_ptr = tl.make_block_ptr( 59 | base=o + o_offset, 60 | shape=(N_CTX, BLOCK_DMODEL), 61 | strides=(stride_om, stride_on), 62 | offsets=(start_m * BLOCK_M, 0), 63 | block_shape=(BLOCK_M, BLOCK_DMODEL), 64 | order=(1, 0) 65 | ) 66 | # initialize offsets 67 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 68 | offs_n = tl.arange(0, BLOCK_N) 69 | 70 | peer_m_ptrs = peer_m + off_hz * N_CTX + offs_m 71 | m_ptrs = m + off_hz * N_CTX + offs_m 72 | peer_l_ptrs = peer_l + off_hz * N_CTX + offs_m 73 | l_ptrs = l + off_hz * N_CTX + offs_m 74 | 75 | peer_m_i = tl.load(peer_m_ptrs) 76 | peer_m_i = peer_m_i.to(tl.float32) 77 | m_i = tl.load(m_ptrs) 78 | m_i = m_i.to(tl.float32) 79 | peer_l_i = 
tl.load(peer_l_ptrs) 80 | peer_l_i = peer_l_i.to(tl.float32) 81 | l_i = tl.load(l_ptrs) 82 | l_i = l_i.to(tl.float32) 83 | 84 | peer_acc = tl.load(peer_o_block_ptr) 85 | peer_acc = peer_acc.to(tl.float32) 86 | acc = tl.load(o_block_ptr) 87 | acc = acc.to(tl.float32) 88 | lo = 0 89 | hi = N_CTX 90 | m_i_sync = tl.maximum(m_i, peer_m_i) 91 | alpha = tl.math.exp2(m_i - m_i_sync) 92 | peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) 93 | # -- scale and update acc -- 94 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 95 | peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug 96 | 97 | acc *= acc_scale[:, None] 98 | peer_acc *= peer_acc_scale[:, None] 99 | acc += peer_acc 100 | l_i = l_i * acc_scale + peer_l_i * peer_acc_scale 101 | # write back O, l, m 102 | tl.store(m_ptrs, m_i_sync) 103 | tl.store(l_ptrs, l_i) 104 | if LAST_STEP: 105 | acc = acc / l_i[:, None] 106 | L_ptrs = L + off_hz * N_CTX + offs_m 107 | tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) 108 | tl.store(o_block_ptr, acc.to(tl.bfloat16)) 109 | 110 | @triton.jit 111 | def _fwd_kernel( 112 | Q, K, V, sm_scale, 113 | m, 114 | l, 115 | O, 116 | L, 117 | stride_qz, stride_qh, stride_qm, stride_qk, 118 | stride_kz, stride_kh, stride_kn, stride_kk, 119 | stride_vz, stride_vh, stride_vk, stride_vn, 120 | stride_oz, stride_oh, stride_om, stride_on, 121 | Z, H, N_CTX, 122 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 123 | BLOCK_N: tl.constexpr, 124 | IS_CAUSAL: tl.constexpr, 125 | LAST_STEP: tl.constexpr 126 | ): 127 | start_m = tl.program_id(0) 128 | off_hz = tl.program_id(1) 129 | qvk_offset = off_hz * stride_qh 130 | Q_block_ptr = tl.make_block_ptr( 131 | base=Q + qvk_offset, 132 | shape=(N_CTX, BLOCK_DMODEL), 133 | strides=(stride_qm, stride_qk), 134 | offsets=(start_m * BLOCK_M, 0), 135 | block_shape=(BLOCK_M, BLOCK_DMODEL), 136 | order=(1, 0) 137 | ) 138 | K_block_ptr = tl.make_block_ptr( 139 | base=K + qvk_offset, 140 | shape=(BLOCK_DMODEL, N_CTX), 141 | strides=(stride_kk, stride_kn), 142 | offsets=(0, 0), 143 | block_shape=(BLOCK_DMODEL, BLOCK_N), 144 | order=(0, 1) 145 | ) 146 | V_block_ptr = tl.make_block_ptr( 147 | base=V + qvk_offset, 148 | shape=(N_CTX, BLOCK_DMODEL), 149 | strides=(stride_vk, stride_vn), 150 | offsets=(0, 0), 151 | block_shape=(BLOCK_N, BLOCK_DMODEL), 152 | order=(1, 0) 153 | ) 154 | O_block_ptr = tl.make_block_ptr( 155 | base=O + qvk_offset, 156 | shape=(N_CTX, BLOCK_DMODEL), 157 | strides=(stride_om, stride_on), 158 | offsets=(start_m * BLOCK_M, 0), 159 | block_shape=(BLOCK_M, BLOCK_DMODEL), 160 | order=(1, 0) 161 | ) 162 | # initialize offsets 163 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 164 | offs_n = tl.arange(0, BLOCK_N) 165 | # initialize pointer to m and l -> load from provided pointer 166 | m_ptrs = m + off_hz * N_CTX + offs_m 167 | l_ptrs = l + off_hz * N_CTX + offs_m 168 | m_i = tl.load(m_ptrs) 169 | m_i = m_i.to(tl.float32) 170 | l_i = tl.load(l_ptrs) 171 | l_i = l_i.to(tl.float32) 172 | acc = tl.load(O_block_ptr) 173 | acc = acc.to(tl.float32) 174 | # scale sm_scale by log_2(e) and use 175 | # 2^x instead of exp in the loop because CSE and LICM 176 | # don't work as expected with `exp` in the loop 177 | qk_scale = sm_scale * 1.44269504 178 | # load q: it will stay in SRAM throughout 179 | q = tl.load(Q_block_ptr) 180 | q = (q * qk_scale).to(tl.bfloat16) 181 | # loop over k, v and update accumulator 182 | lo = 0 183 | hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX 184 | for start_n in range(lo, hi, BLOCK_N): 185 | # -- load k, 
v -- 186 | k = tl.load(K_block_ptr) 187 | v = tl.load(V_block_ptr) 188 | # -- compute qk --- 189 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 190 | if IS_CAUSAL: 191 | qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) 192 | qk += tl.dot(q, k) 193 | # -- compute scaling constant --- 194 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 195 | alpha = tl.math.exp2(m_i - m_i_new) 196 | p = tl.math.exp2(qk - m_i_new[:, None]) 197 | # -- scale and update acc -- 198 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 199 | acc *= acc_scale[:, None] 200 | acc += tl.dot(p.to(tl.bfloat16), v) 201 | # -- update m_i and l_i -- 202 | l_i = l_i * alpha + tl.sum(p, 1) 203 | m_i = m_i_new 204 | # update pointers 205 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 206 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 207 | # write back original l and m 208 | tl.store(m_ptrs, m_i) 209 | tl.store(l_ptrs, l_i) 210 | # write back O, L 211 | if LAST_STEP: 212 | acc = acc / l_i[:, None] 213 | L_ptrs = L + off_hz * N_CTX + offs_m 214 | tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) 215 | tl.store(O_block_ptr, acc.to(tl.bfloat16)) 216 | 217 | # for gqa/mqa to expand kv heads 218 | def maybe_repeat_kv_fwd(nqh, kv): 219 | bs, nkvh, slen, hdim = kv.shape 220 | n_rep = nqh // nkvh 221 | if n_rep == 1: 222 | return kv 223 | kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) 224 | return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) 225 | 226 | def maybe_repeat_kv_bwd(nqh, kv): 227 | bs, slen, nkvh, hdim = kv.shape 228 | n_rep = nqh // nkvh 229 | if n_rep == 1: 230 | return kv 231 | kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) 232 | return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) 233 | 234 | # kv grad has shape bs, slen, nqh, hdim 235 | def maybe_reduce_dkv(nkvh, dkv): 236 | bs, slen, nqh, hdim = dkv.shape 237 | n_rep = nqh // nkvh 238 | if n_rep == 1: 239 | return dkv 240 | dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) 241 | return torch.sum(dkv_reshape, dim=3) 242 | 243 | 244 | def _lightseq_forward(q, k, v, causal, sm_scale, comm_mode): 245 | # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x 246 | # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] 247 | 248 | # shape constraints 249 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 250 | assert Lq == Lk and Lk == Lv 251 | assert Lk in {16, 32, 64, 128} 252 | # Why do I have to change it from 128 64 to 32 32? 
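# [Editor's note, not in the original] A plausible answer to the question above: unlike the stock
# Triton flash-attention tutorial kernel, _fwd_kernel also loads and re-stores the running softmax
# statistics m/l and the partial output O on every call, so the larger 128x64 tiling can exceed the
# shared-memory/register budget on some GPUs; shrinking the tiles to 32x32 trades some throughput
# for a configuration that always fits. For reference, the per-block update the kernel performs is
# the usual online softmax (in log2 domain, since q is pre-scaled by sm_scale * log2(e)):
#   m_new = max(m, rowmax(qk))
#   p     = exp2(qk - m_new)
#   acc   = acc * exp2(m - m_new) + p @ v
#   l     = l   * exp2(m - m_new) + rowsum(p)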
253 | BLOCK_M = 32 254 | BLOCK_N = 32 255 | 256 | bsz, nh, seq_len, hdim = q.shape 257 | 258 | m = torch.full((bsz * nh, seq_len), fill_value=-float("inf"), device=q.device, dtype=torch.float32) 259 | l = torch.zeros_like(m) 260 | L = torch.zeros_like(m) 261 | o = torch.zeros_like(q) 262 | 263 | grid = (triton.cdiv(seq_len, BLOCK_M), bsz * nh, 1) 264 | num_warps = 4 if Lk <= 64 else 8 265 | 266 | seq_rank = get_sequence_parallel_rank() 267 | seq_world_size = get_sequence_parallel_size() 268 | 269 | # Initialize all buffers 270 | peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o) 271 | 272 | fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid]( 273 | q, k, v, sm_scale, 274 | m, 275 | l, 276 | o, 277 | L, 278 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 279 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 280 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 281 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 282 | q.shape[0], q.shape[1], q.shape[2], 283 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, 284 | IS_CAUSAL=IS_CAUSAL, 285 | LAST_STEP=LAST_STEP, 286 | num_warps=num_warps, 287 | num_stages=4) 288 | 289 | for time_step in range(seq_world_size // 2 + 1): 290 | # This is important for cuda scheduler to execute nccl calls first. 291 | torch.cuda.synchronize() 292 | # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. 293 | buffer_idx_1 = time_step % 2 294 | buffer_idx_2 = (time_step - 1) % 2 295 | 296 | reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], 297 | [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode) 298 | if comm_mode == "sync": 299 | # if seq_rank == 0: 300 | # print("Immediate wait for abalation") 301 | wait_async_handles(reqs) 302 | if is_compute_for_local_query(time_step): 303 | # print(f"t={time_step}: (Comp) R={seq_rank} local compute") 304 | if time_step == 0: 305 | fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step)) 306 | else: 307 | # if needs to sync from others, do not normalize here 308 | fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step)) 309 | elif is_idle(time_step): 310 | # print(f"t={time_step}: (Comp) R={seq_rank} idle") 311 | pass 312 | else: 313 | # print(f"t={time_step}: (Comp) R={seq_rank} helps other") 314 | peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf")) 315 | peer_l[buffer_idx_2] = torch.zeros_like(l) 316 | peer_o[buffer_idx_2] = torch.zeros_like(o) 317 | 318 | #print(f"rank 3 q is: {peer_q[buffer_idx_2]}") 319 | fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False) 320 | 321 | if comm_mode == "lightseq": 322 | # Make sure tensors for next steps are ready 323 | wait_async_handles(reqs) 324 | # sync between statistics get from other ranks and the local ones 325 | if is_sync_from_remote(time_step): 326 | _rescale_kernel[grid]( 327 | peer_m[buffer_idx_1], 328 | m, 329 | peer_l[buffer_idx_1], 330 | l, 331 | peer_o[buffer_idx_1], 332 | o, 333 | L, 334 | o.stride(0), o.stride(1), 
o.stride(2), o.stride(3), 335 | o.shape[0], o.shape[1], o.shape[2], 336 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, 337 | LAST_STEP=is_last_time(time_step), 338 | num_warps=num_warps, 339 | num_stages=4) 340 | return q, k, v, o, L 341 | 342 | def _lightseq_backward(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine): 343 | BLOCK = 128 344 | q, k, v, o, do = [rearrange(_x, 'b h s d -> b s h d').contiguous() for _x in [q, k, v, o, do]] 345 | L = rearrange(L, '(b h) s -> b h s', b=q.shape[0]) 346 | 347 | dq = torch.empty_like(q) 348 | dk = torch.empty_like(k) 349 | dv = torch.empty_like(v) 350 | 351 | # maybe gqa 352 | nqh = q.shape[2] 353 | nkvh = k.shape[2] 354 | is_gqa = (nqh > nkvh) 355 | 356 | seq_rank = get_sequence_parallel_rank() 357 | seq_world_size = get_sequence_parallel_size() 358 | 359 | # Initialize all backward buffers 360 | dq_delta, dk_delta, dv_delta, dk_delta_from_peer, dv_delta_from_peer, \ 361 | peer_q, peer_L, peer_k, peer_v, peer_o, peer_do = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do) 362 | 363 | for time_step in range(0, get_sequence_parallel_size() // 2 + 1): 364 | torch.cuda.synchronize() 365 | buffer_idx_1 = time_step % 2 366 | buffer_idx_2 = (time_step - 1) % 2 367 | 368 | reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(dq_delta[buffer_idx_1], dk_delta[buffer_idx_1], dv_delta[buffer_idx_1], dk_delta_from_peer, dv_delta_from_peer, q, peer_q[buffer_idx_1], L, peer_L[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], o, peer_o[buffer_idx_1], do, peer_do[buffer_idx_1], time_step, comm_mode) 369 | if comm_mode == "sync": 370 | # if seq_rank == 0: 371 | # print("(bwd) Immediate wait for abalation") 372 | wait_async_handles(reqs) 373 | 374 | if is_compute_for_local_query(time_step): 375 | if time_step == 0: 376 | if backward_engine == "flash": 377 | _flash_attn_backward(do, q, k, v, o, L, dq, dk, dv, 0.0, sm_scale, True, (-1,-1), None, False) 378 | else: 379 | inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=xformers.ops.LowerTriangularMask(), p=0, scale=sm_scale) 380 | op_ctx = Context(lse=L, out=o, rng_state=None) 381 | # Let xformers dispatch the correct backend 382 | grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) 383 | dq = grads.dq 384 | dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) 385 | else: 386 | if backward_engine == "flash": 387 | _flash_attn_backward(do, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], o, L, dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], 0.0, sm_scale, False, (-1,-1), None, False) 388 | else: 389 | inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), attn_bias=None, p=0, scale=sm_scale) 390 | op_ctx = Context(lse=L, out=o, rng_state=None) 391 | grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None) 392 | dq_delta[buffer_idx_2] = grads.dq 393 | dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) 394 | dq += dq_delta[buffer_idx_2] 395 | elif is_idle(time_step): 396 | pass 397 | else: 398 | if backward_engine == "flash": 399 | _flash_attn_backward(peer_do[buffer_idx_2], peer_q[buffer_idx_2], k, v, peer_o[buffer_idx_2], peer_L[buffer_idx_2], dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], 0.0, sm_scale, False, (-1,-1), None, 
False) 400 | else: 401 | inp = Inputs(query=peer_q[buffer_idx_2], key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=None, p=0, scale=sm_scale) 402 | op_ctx = Context(lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None) 403 | grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None) 404 | dq_delta[buffer_idx_2] = grads.dq 405 | dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv) 406 | dk += dk_delta[buffer_idx_2] 407 | dv += dv_delta[buffer_idx_2] 408 | 409 | if comm_mode == "lightseq": 410 | # Make sure tensors for next steps are ready 411 | wait_async_handles(reqs) 412 | 413 | # The last time step needs to send dk and dv immediately, move it up here to maximize overlap with the following three addition. 414 | reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode) 415 | 416 | if comm_mode == "sync": 417 | # if seq_rank == 0: 418 | # print("(bwd) dkv Immediate wait for abalation") 419 | wait_async_handles(reqs) 420 | # apply dq_delta, dk_delta and dv_delta from remote 421 | if is_update_dq: 422 | dq += dq_delta[buffer_idx_1] 423 | if is_update_dkv: 424 | dk += dk_delta_from_peer 425 | dv += dv_delta_from_peer 426 | 427 | if comm_mode == "lightseq": 428 | wait_async_handles(reqs) 429 | # apply dk_delta and dv_delta to sender 430 | if is_update_last_dkv: 431 | dk += dk_delta[buffer_idx_2] 432 | dv += dv_delta[buffer_idx_2] 433 | 434 | dq, dk, dv = [rearrange(_x, 'b h s d -> b s h d') for _x in [dq, dk, dv]] 435 | return dq, dk, dv 436 | 437 | class _attention(torch.autograd.Function): 438 | @staticmethod 439 | def forward(ctx, q, k, v, causal, sm_scale): 440 | try: 441 | global args 442 | comm_mode = args.comm_mode 443 | backward_engine = args.backward_engine 444 | except: 445 | comm_mode = 'lightseq' 446 | backward_engine = 'flash' 447 | 448 | q, k, v, o, L = _lightseq_forward(q, k, v, causal, sm_scale, comm_mode) 449 | 450 | ctx.save_for_backward(q, k, v, o, L) 451 | ctx.sm_scale = sm_scale 452 | ctx.comm_mode = comm_mode 453 | ctx.backward_engine = backward_engine 454 | return o 455 | 456 | @staticmethod 457 | def backward(ctx, do): 458 | q, k, v, o, L = ctx.saved_tensors 459 | sm_scale = ctx.sm_scale 460 | 461 | dq, dk, dv = _lightseq_backward(do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine) 462 | return dq, dk, dv, None, None 463 | 464 | attention = _attention.apply 465 | 466 | 467 | #@pytest.mark.parametrize('causal', [False, True]) 468 | #@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)]) 469 | def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): 470 | torch.manual_seed(20) 471 | q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 472 | k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 473 | v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 474 | 475 | rank = dist.get_rank() 476 | world_size = dist.get_world_size() 477 | seq_per_rank = N_CTX // world_size 478 | 479 | sm_scale = 0.5 480 | dout = torch.randn_like(q) 481 | # reference implementation 482 | M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) 483 | p = torch.matmul(q, k.transpose(2, 3)) * sm_scale 484 | assert causal 485 | if causal: 486 | p[:, :, M == 0] = 
float("-inf") 487 | p = torch.softmax(p.float(), dim=-1).half() 488 | ref_out = torch.matmul(p, v) 489 | ref_out.backward(dout) 490 | ref_dv, v.grad = v.grad.clone(), None 491 | ref_dk, k.grad = k.grad.clone(), None 492 | ref_dq, q.grad = q.grad.clone(), None 493 | 494 | # triton implementation 495 | 496 | a, b, c, d = q.size() 497 | real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) 498 | real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) 499 | real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) 500 | real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) 501 | 502 | tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() 503 | 504 | # compare 505 | assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" 506 | print(f" *** rank {rank} passes forward") 507 | tri_out.backward(real_do) 508 | tri_dv, real_v.grad = real_v.grad.clone(), None 509 | tri_dk, real_k.grad = real_k.grad.clone(), None 510 | tri_dq, real_q.grad = real_q.grad.clone(), None 511 | assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" 512 | assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk" #f" {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" 513 | assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" 514 | print(f"rank {rank} passes backward") 515 | 516 | 517 | def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16): 518 | torch.manual_seed(177) 519 | q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 520 | k = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 521 | v = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() 522 | 523 | rank = dist.get_rank() 524 | world_size = dist.get_world_size() 525 | seq_per_rank = N_CTX // world_size 526 | 527 | sm_scale = 0.5 528 | dout = torch.randn_like(q) 529 | # torch reference implementation 530 | M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) 531 | ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True) 532 | ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True) 533 | p = torch.matmul(q, ref_k.transpose(2,3)) * sm_scale 534 | assert causal 535 | if causal: 536 | p[:, :, M == 0] = float("-inf") 537 | p = torch.softmax(p.float(), dim=-1).half() 538 | ref_out = torch.matmul(p, ref_v) 539 | ref_out.backward(dout) 540 | ref_dv, v.grad = ref_v.grad.clone(), None 541 | ref_dv = (maybe_reduce_dkv(KVH, 
ref_dv.transpose(1,2))).transpose(1,2) 542 | ref_dk, k.grad = ref_k.grad.clone(), None 543 | ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1,2))).transpose(1,2) 544 | ref_dq, q.grad = q.grad.clone(), None 545 | 546 | # flash reference 547 | from flash_attn import flash_attn_qkvpacked_func, flash_attn_func 548 | flash_q = q.transpose(1,2).clone().detach().requires_grad_(True) 549 | flash_k = k.transpose(1,2).clone().detach().requires_grad_(True) 550 | flash_v = v.transpose(1,2).clone().detach().requires_grad_(True) 551 | flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True) 552 | flash_ref_out.backward(dout.transpose(1,2)) 553 | flash_ref_out = flash_ref_out.transpose(1,2) 554 | flash_ref_dv, v.grad = flash_v.grad.clone(), None 555 | flash_ref_dv = flash_ref_dv.transpose(1,2) 556 | flash_ref_dk, k.grad = flash_k.grad.clone(), None 557 | flash_ref_dk = flash_ref_dk.transpose(1,2) 558 | flash_ref_dq, q.grad = flash_q.grad.clone(), None 559 | flash_ref_dq = flash_ref_dq.transpose(1,2) 560 | 561 | # triton implementation 562 | 563 | a, b, c, d = q.size() 564 | real_q = q[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) 565 | real_k = k[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) 566 | real_v = v[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True) 567 | real_do = dout[:,:, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True) 568 | 569 | tri_out = attention(real_q, real_k, real_v, causal, sm_scale).half() 570 | 571 | # compare 572 | assert torch.allclose(flash_ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward against flash" 573 | print(f" *** rank {rank} passes forward") 574 | tri_out.backward(real_do) 575 | tri_dv, real_v.grad = real_v.grad.clone(), None 576 | tri_dk, real_k.grad = real_k.grad.clone(), None 577 | tri_dq, real_q.grad = real_q.grad.clone(), None 578 | assert torch.allclose(flash_ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq against flash" 579 | #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) 580 | assert torch.allclose(flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk against flash {flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" 581 | assert torch.allclose(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv against flash {flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" 582 | print(f"rank {rank} passes backward against flash") 583 | 584 | assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f" rank {rank} fails forward" 585 | print(f" *** rank {rank} passes forward") 586 | assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank 
+ 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f" rank {rank} fails backward dq" 587 | #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape) 588 | assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk {ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dk} {torch.max(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dk)} rank {rank} fails backward dk" 589 | assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv {ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :]} {tri_dv} {torch.max(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :] - tri_dv)} rank {rank} fails backward dv" 590 | print(f"rank {rank} passes backward") 591 | 592 | #BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 593 | try: 594 | from flash_attn.flash_attn_interface import \ 595 | flash_attn_qkvpacked_func as flash_attn_func 596 | FLASH_VER = 2 597 | except BaseException: 598 | try: 599 | from flash_attn.flash_attn_interface import flash_attn_func 600 | FLASH_VER = 1 601 | except BaseException: 602 | FLASH_VER = None 603 | HAS_FLASH = FLASH_VER is not None 604 | HAS_FLASH = None 605 | ONLY_FLASH = False 606 | 607 | #BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 608 | BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128 609 | # vary seq length for fixed head and batch=4 610 | configs = [triton.testing.Benchmark( 611 | x_names=['N_CTX'], 612 | x_vals=[2**i for i in range(18, 19)],#[ 20, 21]],#[10, 11, 12, 13, 14, 15, 16, 17, 18]], 613 | line_arg='provider', 614 | line_vals=['triton'] if not ONLY_FLASH else [] + (['flash'] if HAS_FLASH else []), 615 | line_names=['Triton'] if not ONLY_FLASH else [] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), 616 | styles=[('red', '-'), ('blue', '-')], 617 | ylabel='ms', 618 | plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}', 619 | args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.bfloat16, 'mode': mode, 'causal': causal} 620 | ) for mode in ["all"] for causal in [True]] 621 | 622 | # @triton.testing.perf_report(configs) 623 | def bench_flash_attention(BATCH, H, KVH, N_CTX, D_HEAD, causal, mode, provider, args, dtype=torch.bfloat16, device="cuda"): 624 | assert mode == "all" #mode in ['fwd', 'bwd'] 625 | n_warmup = 10 626 | n_repeat = 10 627 | cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') 628 | seq_rank = get_sequence_parallel_rank() 629 | seq_world_size = get_sequence_parallel_size() 630 | if provider == "triton": 631 | q = torch.randn((BATCH, H, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 632 | k = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 633 | v = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) 634 | if seq_rank == 0: 635 | print(f"Benchmarking per GPU qkv shape: {q.shape}") 636 | sm_scale = 1.3 637 | fwd_fn = lambda: attention(q, k, v, causal, sm_scale) 638 | if provider == "flash": 639 | qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) 640 | if FLASH_VER == 1: 641 | lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) 642 | cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) 643 
| cu_seqlens[1:] = lengths.cumsum(0) 644 | qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) 645 | fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) 646 | elif FLASH_VER == 2: 647 | fwd_fn = lambda: flash_attn_func(qkv, causal=causal) 648 | else: 649 | raise ValueError(f'unknown {FLASH_VER = }') 650 | 651 | flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size 652 | attn_flops = 2 * flops_per_matmul 653 | 654 | assert causal 655 | if causal: 656 | attn_flops *= 0.5 657 | fwd_flops = attn_flops 658 | bwd_flops = attn_flops * 2.5 # 2.0(bwd) + 0.5(recompute) 659 | 660 | o = fwd_fn() 661 | do = torch.randn_like(o) 662 | bwd_fn = lambda: o.backward(do, retain_graph=True) 663 | 664 | def run_benchmark(fn): 665 | time_list = [] 666 | for _ in tqdm(range(n_warmup)): 667 | cache.zero_() 668 | fn() 669 | torch.cuda.synchronize() 670 | if args.debug: 671 | print_and_reset_comm_stats() 672 | for i in tqdm(range(n_repeat)): 673 | cache.zero_() 674 | torch.cuda.synchronize() 675 | time_s = time.time() 676 | fn() 677 | torch.cuda.synchronize() 678 | time_e = time.time() 679 | time_list.append((time_e - time_s) * 1000.0) 680 | if args.debug: 681 | print_and_reset_comm_stats() 682 | return np.asarray(time_list) 683 | 684 | fwd_time_arr = run_benchmark(fwd_fn) 685 | bwd_time_arr = run_benchmark(bwd_fn) 686 | 687 | fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9 688 | print(f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n") 689 | 690 | bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9 691 | print(f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n") 692 | 693 | # total 694 | total_time_arr = fwd_time_arr + bwd_time_arr 695 | total_flops = fwd_flops + bwd_flops 696 | total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9 697 | print(f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n") 698 | 699 | #return total_flops_ps 700 | 701 | 702 | if __name__ == "__main__": 703 | parser = argparse.ArgumentParser() 704 | parser.add_argument("--comm-mode", type=str, default="lightseq") 705 | parser.add_argument("--debug", action="store_true") 706 | parser.add_argument("--run-mode", type=str, default="benchmark") 707 | parser.add_argument("--bs", type=int, default=1) 708 | parser.add_argument("--n_heads", type=int, default=32) 709 | parser.add_argument("--n_kvheads", type=int, default=32) 710 | parser.add_argument("--d_head", type=int, default=128) 711 | parser.add_argument("--start_ctx", type=int, default=12) 712 | parser.add_argument("--end_ctx", type=int, default=18) 713 | parser.add_argument("--forward_engine", type=str, default="triton") 714 | parser.add_argument("--backward_engine", type=str, default="flash") 715 | 716 | global args 717 | args = parser.parse_args() 718 | initialize_distributed() 719 | 720 | assert args.forward_engine == "triton", "Only triton forward is implemented." 721 | assert args.backward_engine in ["flash", "xformers"], "Only flash or xformers backward is implemented."
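# [Editor's note, not part of the original script] initialize_distributed() above comes from
# .async_communication and presumably reads the usual torch.distributed environment variables
# (RANK, WORLD_SIZE, MASTER_ADDR, ...), so this file is meant to run with one process per GPU.
# A hypothetical invocation, assuming the repository root as working directory and torchrun as
# the launcher (the flag names below come from the argparse definitions above; the exact command
# is an assumption, any torch.distributed-compatible launcher should work):
#   torchrun --nproc_per_node=8 -m utils.easy_context.dist_flash_attn.lightseq_async_attn --run-mode test
#   torchrun --nproc_per_node=8 -m utils.easy_context.dist_flash_attn.lightseq_async_attn --run-mode benchmark --comm-mode lightseq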
722 | 723 | if args.backward_engine == "flash": 724 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 725 | else: 726 | try: 727 | import xformers.ops 728 | from xformers.ops.fmha.common import Inputs, Context 729 | from xformers.ops.fmha import _memory_efficient_attention_backward 730 | from xformers.ops.fmha import cutlass, flash 731 | except ImportError: 732 | print("xformers not found! Please install it before trying to use it.") 733 | 734 | if args.run_mode == "benchmark": 735 | for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]: 736 | bench_flash_attention(args.bs, args.n_heads, args.n_kvheads, N_CTX, args.d_head, True, "all", "triton", args)#.run(save_path='.', print_data=True) 737 | reset_global_memory_buffer() 738 | else: 739 | assert args.run_mode == "test" 740 | for N_CTX in [2048, 4096]: 741 | test_op(1, 16, N_CTX, 128, True) 742 | #test_gqa(1, 16, 8, N_CTX, 128, True) 743 | reset_global_memory_buffer() 744 | -------------------------------------------------------------------------------- /utils/easy_context/dist_flash_attn/lightseq_async_attn_varlen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | from einops import rearrange 5 | import argparse 6 | 7 | import pytest 8 | import torch 9 | import torch.distributed as dist 10 | from torch.distributed import ReduceOp 11 | #from torch.profiler import profile, record_function, ProfilerActivity 12 | 13 | import triton 14 | import triton.language as tl 15 | import time 16 | import numpy as np 17 | from tqdm import tqdm 18 | 19 | try: 20 | from flash_attn.flash_attn_interface import _flash_attn_varlen_backward 21 | except: 22 | pass 23 | 24 | from .async_communication import (is_last_time, is_compute_for_local_query, is_sync_from_remote, is_idle, print_and_reset_comm_stats, 25 | launch_async_handles, wait_async_handles, maybe_send_recv_fwd_qkvo, maybe_send_recv_bwd_qkvo, maybe_send_recv_bwd_last_dkv, reset_global_memory_buffer, 26 | maybe_get_set_global_memory_buffer, maybe_get_set_global_memory_buffer_bwd, initialize_distributed, get_sequence_parallel_size, get_sequence_parallel_rank) 27 | 28 | @triton.jit 29 | def max_fn(x, y): 30 | return tl.math.max(x, y) 31 | 32 | @triton.jit 33 | def _rescale_kernel( 34 | peer_m, 35 | m, 36 | peer_l, 37 | l, 38 | peer_o, 39 | o, 40 | L, 41 | stride_oz, stride_oh, stride_om, stride_on, 42 | Z, H, N_CTX, 43 | seqlen_q_rounded, seqlen_peer_q_rounded, 44 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 45 | BLOCK_N: tl.constexpr, 46 | LAST_STEP: tl.constexpr, 47 | ): 48 | start_m = tl.program_id(0) 49 | off_hz = tl.program_id(1) 50 | o_offset = off_hz * stride_oh 51 | peer_o_block_ptr = tl.make_block_ptr( 52 | base=peer_o + o_offset, 53 | shape=(N_CTX, BLOCK_DMODEL), 54 | strides=(stride_om, stride_on), 55 | offsets=(start_m * BLOCK_M, 0), 56 | block_shape=(BLOCK_M, BLOCK_DMODEL), 57 | order=(1, 0) 58 | ) 59 | o_block_ptr = tl.make_block_ptr( 60 | base=o + o_offset, 61 | shape=(N_CTX, BLOCK_DMODEL), 62 | strides=(stride_om, stride_on), 63 | offsets=(start_m * BLOCK_M, 0), 64 | block_shape=(BLOCK_M, BLOCK_DMODEL), 65 | order=(1, 0) 66 | ) 67 | # initialize offsets 68 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 69 | offs_n = tl.arange(0, BLOCK_N) 70 | 71 | peer_m_ptrs = peer_m + off_hz * seqlen_peer_q_rounded + offs_m 72 | m_ptrs = m + off_hz * seqlen_q_rounded + offs_m 73 | peer_l_ptrs = peer_l + off_hz * seqlen_peer_q_rounded + offs_m 74 | l_ptrs = l + off_hz * 
seqlen_q_rounded + offs_m 75 | 76 | peer_m_i = tl.load(peer_m_ptrs) 77 | peer_m_i = peer_m_i.to(tl.float32) 78 | m_i = tl.load(m_ptrs) 79 | m_i = m_i.to(tl.float32) 80 | peer_l_i = tl.load(peer_l_ptrs) 81 | peer_l_i = peer_l_i.to(tl.float32) 82 | l_i = tl.load(l_ptrs) 83 | l_i = l_i.to(tl.float32) 84 | 85 | peer_acc = tl.load(peer_o_block_ptr)#, boundary_check=(0, 1), padding_option='zero') 86 | peer_acc = peer_acc.to(tl.float32) 87 | acc = tl.load(o_block_ptr) #, boundary_check=(0, 1), padding_option='zero') 88 | acc = acc.to(tl.float32) 89 | lo = 0 90 | hi = N_CTX 91 | m_i_sync = tl.maximum(m_i, peer_m_i) 92 | alpha = tl.math.exp2(m_i - m_i_sync) 93 | peer_alpha = tl.math.exp2(peer_m_i - m_i_sync) 94 | # -- scale and update acc -- 95 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 96 | peer_acc_scale = peer_l_i * 0 + peer_alpha # workaround some compiler bug 97 | 98 | acc *= acc_scale[:, None] 99 | peer_acc *= peer_acc_scale[:, None] 100 | acc += peer_acc 101 | l_i = l_i * acc_scale + peer_l_i * peer_acc_scale 102 | # write back O, l, m 103 | tl.store(m_ptrs, m_i_sync) 104 | tl.store(l_ptrs, l_i) 105 | if LAST_STEP: 106 | acc = acc / l_i[:, None] 107 | L_ptrs = L + off_hz * N_CTX + offs_m 108 | tl.store(L_ptrs, m_i_sync / 1.44269504 + tl.math.log(l_i)) 109 | tl.store(o_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) 110 | 111 | @triton.jit 112 | def _fwd_kernel( 113 | Q, K, V, sm_scale, 114 | m, 115 | l, 116 | O, 117 | L, 118 | stride_qz, stride_qh, stride_qm, stride_qk, 119 | stride_kz, stride_kh, stride_kn, stride_kk, 120 | stride_vz, stride_vh, stride_vk, stride_vn, 121 | stride_oz, stride_oh, stride_om, stride_on, 122 | Z, H, N_CTX, 123 | seqlen_q_rounded, 124 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, 125 | BLOCK_N: tl.constexpr, 126 | IS_CAUSAL: tl.constexpr, 127 | LAST_STEP: tl.constexpr 128 | ): 129 | start_m = tl.program_id(0) 130 | off_hz = tl.program_id(1) 131 | qvk_offset = off_hz * stride_qh 132 | Q_block_ptr = tl.make_block_ptr( 133 | base=Q + qvk_offset, 134 | shape=(N_CTX, BLOCK_DMODEL), 135 | strides=(stride_qm, stride_qk), 136 | offsets=(start_m * BLOCK_M, 0), 137 | block_shape=(BLOCK_M, BLOCK_DMODEL), 138 | order=(1, 0) 139 | ) 140 | K_block_ptr = tl.make_block_ptr( 141 | base=K + qvk_offset, 142 | shape=(BLOCK_DMODEL, N_CTX), 143 | strides=(stride_kk, stride_kn), 144 | offsets=(0, 0), 145 | block_shape=(BLOCK_DMODEL, BLOCK_N), 146 | order=(0, 1) 147 | ) 148 | V_block_ptr = tl.make_block_ptr( 149 | base=V + qvk_offset, 150 | shape=(N_CTX, BLOCK_DMODEL), 151 | strides=(stride_vk, stride_vn), 152 | offsets=(0, 0), 153 | block_shape=(BLOCK_N, BLOCK_DMODEL), 154 | order=(1, 0) 155 | ) 156 | O_block_ptr = tl.make_block_ptr( 157 | base=O + qvk_offset, 158 | shape=(N_CTX, BLOCK_DMODEL), 159 | strides=(stride_om, stride_on), 160 | offsets=(start_m * BLOCK_M, 0), 161 | block_shape=(BLOCK_M, BLOCK_DMODEL), 162 | order=(1, 0) 163 | ) 164 | # initialize offsets 165 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 166 | offs_n = tl.arange(0, BLOCK_N) 167 | # initialize pointer to m and l -> load from provided pointer 168 | # (TODO): Why float32? 
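# [Editor's note, not in the original] A likely answer to the TODO above: m and l hold the running
# row-wise max and row-wise sum of the online softmax, which are repeatedly rescaled by exp2()
# across blocks, time steps and peer ranks; keeping them in fp32 avoids the rounding/overflow
# issues bf16 would introduce, even though q/k/v/o themselves stay in bf16. They are laid out with
# a row stride of seqlen_q_rounded (the query length rounded up to a multiple of BLOCK_M), so each
# program id below addresses a full, padded row of statistics.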
169 | m_ptrs = m + off_hz * seqlen_q_rounded + offs_m 170 | l_ptrs = l + off_hz * seqlen_q_rounded + offs_m 171 | m_i = tl.load(m_ptrs) 172 | m_i = m_i.to(tl.float32) 173 | l_i = tl.load(l_ptrs) 174 | l_i = l_i.to(tl.float32) 175 | acc = tl.load(O_block_ptr) 176 | acc = acc.to(tl.float32) 177 | # scale sm_scale by log_2(e) and use 178 | # 2^x instead of exp in the loop because CSE and LICM 179 | # don't work as expected with `exp` in the loop 180 | qk_scale = sm_scale * 1.44269504 181 | # load q: it will stay in SRAM throughout 182 | q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option='zero') 183 | q = (q * qk_scale).to(tl.bfloat16) 184 | # loop over k, v and update accumulator 185 | lo = 0 186 | hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX 187 | for start_n in range(lo, hi, BLOCK_N): 188 | # -- load k, v -- 189 | k = tl.load(K_block_ptr, boundary_check=(1,), padding_option='zero') 190 | v = tl.load(V_block_ptr, boundary_check=(0,), padding_option='zero') 191 | # -- compute qk --- 192 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 193 | if IS_CAUSAL: 194 | qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) 195 | qk += tl.dot(q, k) 196 | # -- compute scaling constant --- 197 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 198 | alpha = tl.math.exp2(m_i - m_i_new) 199 | p = tl.math.exp2(qk - m_i_new[:, None]) 200 | # -- scale and update acc -- 201 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 202 | acc *= acc_scale[:, None] 203 | acc += tl.dot(p.to(tl.bfloat16), v) 204 | # -- update m_i and l_i -- 205 | l_i = l_i * alpha + tl.sum(p, 1) 206 | m_i = m_i_new 207 | # update pointers 208 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 209 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 210 | # write back original l and m 211 | tl.store(m_ptrs, m_i) 212 | tl.store(l_ptrs, l_i) 213 | # write back O, L 214 | if LAST_STEP: 215 | acc = acc / l_i[:, None] 216 | L_ptrs = L + off_hz * seqlen_q_rounded + offs_m 217 | tl.store(L_ptrs, m_i / 1.44269504 + tl.math.log(l_i)) 218 | tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) 219 | 220 | # for gqa/mqa to expand kv heads 221 | def maybe_repeat_kv_fwd(nqh, kv): 222 | bs, nkvh, slen, hdim = kv.shape 223 | n_rep = nqh // nkvh 224 | if n_rep == 1: 225 | return kv 226 | kv_expand = kv[:, :, None, :, :].expand(bs, nkvh, n_rep, slen, hdim) 227 | return kv_expand.reshape(bs, nkvh * n_rep, slen, hdim) 228 | 229 | def maybe_repeat_kv_bwd(nqh, kv): 230 | bs, slen, nkvh, hdim = kv.shape 231 | n_rep = nqh // nkvh 232 | if n_rep == 1: 233 | return kv 234 | kv_expand = kv[:, :, :, None, :].expand(bs, slen, nkvh, n_rep, hdim) 235 | return kv_expand.reshape(bs, slen, nkvh * n_rep, hdim) 236 | 237 | # kv grad has shape bs, slen, nqh, hdim 238 | def maybe_reduce_dkv(nkvh, dkv): 239 | bs, slen, nqh, hdim = dkv.shape 240 | n_rep = nqh // nkvh 241 | if n_rep == 1: 242 | return dkv 243 | #print("*"*100, dkv.shape, bs, slen, nkvh, n_rep, hdim) 244 | dkv_reshape = dkv.view(bs, slen, nkvh, n_rep, hdim) 245 | #print("-"*100, dkv_reshape.shape, bs, slen, nkvh, n_rep, hdim) 246 | return torch.sum(dkv_reshape, dim=3) 247 | 248 | 249 | def _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode): 250 | # maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x 251 | # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] 252 | 253 | # shape constraints 254 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 255 | # assert Lq == Lk and Lk == Lv 256 | # assert Lk in {16, 32, 64, 
128} 257 | BLOCK_M = 128 258 | BLOCK_N = 64 259 | 260 | bsz, nh, unpadded_seq_len, hdim = q.shape 261 | cu_seq_lens = torch.arange(0, (bsz+1) * unpadded_seq_len, unpadded_seq_len, dtype=torch.int32, device=q.device) 262 | max_seqlen = unpadded_seq_len 263 | seqlen_q_rounded = math.ceil(q.shape[2] / BLOCK_M) * BLOCK_M 264 | 265 | m = torch.full((bsz * nh, seqlen_q_rounded), fill_value=-float("inf"), device=q.device, dtype=torch.float32) 266 | l = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) 267 | L = torch.zeros((bsz * nh, seqlen_q_rounded), device=q.device, dtype=torch.float32) 268 | o = torch.zeros_like(q) 269 | 270 | grid = (triton.cdiv(q.shape[2], BLOCK_M), bsz * nh, 1) 271 | num_warps = 4 if Lk <= 64 else 8 272 | 273 | seq_rank = get_sequence_parallel_rank() 274 | seq_world_size = get_sequence_parallel_size() 275 | 276 | # Initialize all buffers 277 | peer_q, peer_k, peer_v, peer_m, peer_l, peer_o = maybe_get_set_global_memory_buffer(q, k, v, m, l, o) 278 | 279 | fwd_launch_helper = lambda q, k, v, m, l, o, L, IS_CAUSAL, LAST_STEP: _fwd_kernel[grid]( 280 | q, k, v, sm_scale, 281 | m, 282 | l, 283 | o, 284 | L, 285 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 286 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 287 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 288 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 289 | q.shape[0], q.shape[1], q.shape[2], 290 | seqlen_q_rounded, 291 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, 292 | IS_CAUSAL=IS_CAUSAL, 293 | LAST_STEP=LAST_STEP, 294 | num_warps=num_warps, 295 | num_stages=4) 296 | 297 | for time_step in range(seq_world_size // 2 + 1): 298 | # This is important for cuda scheduler to execute nccl calls first. 299 | torch.cuda.synchronize() 300 | # Communication uses buffer_idx_1, and compute uses buffer_idx_2, which effectively are contents from the last time step. 
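# [Editor's note, not in the original] Sketch of the double-buffering schedule implied by the
# comment above, for a sequence-parallel group of size P and time steps t = 0 .. P//2:
#   t = 0: receive peer data into buffer 0, compute attention on the purely local q/k/v
#   t = 1: receive into buffer 1, compute on the peer data that arrived in buffer 0
#   t = 2: receive into buffer 0, compute on buffer 1, and so on (buffer index = parity of t)
# In other words, communication for step t+1 overlaps with the attention compute on data from
# step t, and _rescale_kernel merges the peer's (m, l, O) statistics back into the local ones
# whenever is_sync_from_remote(t) is true.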
        buffer_idx_1 = time_step % 2
        buffer_idx_2 = (time_step - 1) % 2

        reqs = maybe_send_recv_fwd_qkvo(q, peer_q[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1],
                                        [peer_o[buffer_idx_1], peer_m[buffer_idx_1], peer_l[buffer_idx_1]], time_step, comm_mode)
        if comm_mode == "sync":
            # if seq_rank == 0:
            #     print("Immediate wait for ablation")
            wait_async_handles(reqs)
        if is_compute_for_local_query(time_step):
            # print(f"t={time_step}: (Comp) R={seq_rank} local compute")
            if time_step == 0:
                fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), m, l, o, L, True, is_last_time(time_step))
            else:
                # if this rank still needs to sync with others, do not normalize here
                fwd_launch_helper(q, maybe_repeat_kv_fwd(q.shape[1], peer_k[buffer_idx_2]), maybe_repeat_kv_fwd(q.shape[1], peer_v[buffer_idx_2]), m, l, o, L, False, not is_sync_from_remote(time_step) and is_last_time(time_step))
        elif is_idle(time_step):
            # print(f"t={time_step}: (Comp) R={seq_rank} idle")
            pass
        else:
            # print(f"t={time_step}: (Comp) R={seq_rank} helps other")
            peer_m[buffer_idx_2] = torch.full_like(m, fill_value=-float("inf"))
            peer_l[buffer_idx_2] = torch.zeros_like(l)
            peer_o[buffer_idx_2] = torch.zeros_like(o)

            # print(f"rank 3 q is: {peer_q[buffer_idx_2]}")
            fwd_launch_helper(peer_q[buffer_idx_2], maybe_repeat_kv_fwd(q.shape[1], k), maybe_repeat_kv_fwd(q.shape[1], v), peer_m[buffer_idx_2], peer_l[buffer_idx_2], peer_o[buffer_idx_2], None, False, False)

        if comm_mode == "lightseq":
            # Make sure tensors for the next step are ready
            wait_async_handles(reqs)
        # merge the softmax statistics received from other ranks with the local ones
        if is_sync_from_remote(time_step):
            # print(f"t={time_step}: (Comp) R={seq_rank} sync with other - last time: {is_last_time(time_step)}")
            seqlen_peer_q_rounded = peer_l[buffer_idx_1].shape[-1]
            _rescale_kernel[grid](
                peer_m[buffer_idx_1],
                m,
                peer_l[buffer_idx_1],
                l,
                peer_o[buffer_idx_1],
                o,
                L,
                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                o.shape[0], o.shape[1], o.shape[2],
                seqlen_q_rounded, seqlen_peer_q_rounded,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
                LAST_STEP=is_last_time(time_step),
                num_warps=num_warps,
                num_stages=4)
    return q, k, v, o, L, cu_seq_lens, max_seqlen


def _lightseq_backward_varlen(do, q, k, v, o, L, sm_scale, comm_mode, backward_engine, cu_seq_lens, max_seqlen):
    BLOCK = 128
    L = rearrange(L[:, :max_seqlen].contiguous(), '(b h) s -> b h s', b=q.shape[0])
    q, k, v, o, do = [rearrange(_x, 'b h s d -> (b s) h d').contiguous() for _x in [q, k, v, o, do]]

    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    # maybe gqa
    nqh = q.shape[1]
    nkvh = k.shape[1]
    is_gqa = (nqh > nkvh)

    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()

    # Initialize all backward buffers
    dq_delta, dk_delta, dv_delta, dk_delta_from_peer, dv_delta_from_peer, \
        peer_q, peer_L, peer_k, peer_v, peer_o, peer_do = maybe_get_set_global_memory_buffer_bwd(dq, dk, dv, q, L, k, v, o, do)

    for time_step in range(0, get_sequence_parallel_size() // 2 + 1):
        torch.cuda.synchronize()
        buffer_idx_1 = time_step % 2
        buffer_idx_2 = (time_step - 1) % 2

        reqs, is_update_dq, is_update_dkv = maybe_send_recv_bwd_qkvo(dq_delta[buffer_idx_1], dk_delta[buffer_idx_1], dv_delta[buffer_idx_1], dk_delta_from_peer, dv_delta_from_peer, q, peer_q[buffer_idx_1], L, peer_L[buffer_idx_1], k, peer_k[buffer_idx_1], v, peer_v[buffer_idx_1], o, peer_o[buffer_idx_1], do, peer_do[buffer_idx_1], time_step, comm_mode)
        if comm_mode == "sync":
            wait_async_handles(reqs)

        if is_compute_for_local_query(time_step):
            if time_step == 0:
                assert backward_engine == "flash", "The varlen backward is not yet supported with the xformers engine"
                if backward_engine == "flash":
                    _flash_attn_varlen_backward(do, q, k, v, o, L, dq, dk, dv, cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, True, None)
                else:
                    inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=xformers.ops.LowerTriangularMask(), p=0, scale=sm_scale)
                    op_ctx = Context(lse=L, out=o, rng_state=None)
                    # Let xformers dispatch the correct backend
                    grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None)
                    dq = grads.dq
                    dk, dv = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv)
            else:
                assert backward_engine == "flash", "The varlen backward is not yet supported with the xformers engine"
                if backward_engine == "flash":
                    _flash_attn_varlen_backward(do, q, peer_k[buffer_idx_2], peer_v[buffer_idx_2], o, L, dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, False, None)
                else:
                    inp = Inputs(query=q, key=maybe_repeat_kv_bwd(q.shape[2], peer_k[buffer_idx_2]), value=maybe_repeat_kv_bwd(q.shape[2], peer_v[buffer_idx_2]), attn_bias=None, p=0, scale=sm_scale)
                    op_ctx = Context(lse=L, out=o, rng_state=None)
                    grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=do, op=None)
                    dq_delta[buffer_idx_2] = grads.dq
                    dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv)
                dq += dq_delta[buffer_idx_2]
        elif is_idle(time_step):
            # print(f"BWD t={time_step}: (Comp) R={seq_rank} idle")
            pass
        else:
            # print(f"BWD t={time_step}: (Comp) R={seq_rank} helps other")
            assert backward_engine == "flash", "The varlen backward is not yet supported with the xformers engine"
            if backward_engine == "flash":
                _flash_attn_varlen_backward(peer_do[buffer_idx_2], peer_q[buffer_idx_2], k, v, peer_o[buffer_idx_2], peer_L[buffer_idx_2], dq_delta[buffer_idx_2], dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], cu_seq_lens, cu_seq_lens, max_seqlen, max_seqlen, 0.0, sm_scale, False, None)
            else:
                inp = Inputs(query=peer_q[buffer_idx_2], key=maybe_repeat_kv_bwd(q.shape[2], k), value=maybe_repeat_kv_bwd(q.shape[2], v), attn_bias=None, p=0, scale=sm_scale)
                op_ctx = Context(lse=peer_L[buffer_idx_2], out=peer_o[buffer_idx_2], rng_state=None)
                grads = _memory_efficient_attention_backward(ctx=op_ctx, inp=inp, grad=peer_do[buffer_idx_2], op=None)
                dq_delta[buffer_idx_2] = grads.dq
                dk_delta[buffer_idx_2], dv_delta[buffer_idx_2] = maybe_reduce_dkv(nkvh, grads.dk), maybe_reduce_dkv(nkvh, grads.dv)
            dk += dk_delta[buffer_idx_2]
            dv += dv_delta[buffer_idx_2]

        if comm_mode == "lightseq":
            # Make sure tensors for the next step are ready
            wait_async_handles(reqs)

        # The last time step needs to send dk and dv immediately; move it up here to maximize overlap with the following three additions.
        reqs, is_update_last_dkv = maybe_send_recv_bwd_last_dkv(dk_delta[buffer_idx_2], dv_delta[buffer_idx_2], time_step, comm_mode)

        if comm_mode == "sync":
            # if seq_rank == 0:
            #     print("(bwd) dkv Immediate wait for ablation")
            wait_async_handles(reqs)
        # apply dq_delta, dk_delta and dv_delta from remote
        if is_update_dq:
            dq += dq_delta[buffer_idx_1]
        if is_update_dkv:
            dk += dk_delta_from_peer
            dv += dv_delta_from_peer

        if comm_mode == "lightseq":
            wait_async_handles(reqs)
            # apply dk_delta and dv_delta to the sender
            if is_update_last_dkv:
                dk += dk_delta[buffer_idx_2]
                dv += dv_delta[buffer_idx_2]

    dq, dk, dv = [rearrange(_x, '(b s) h d -> b h s d', s=max_seqlen) for _x in [dq, dk, dv]]
    return dq, dk, dv


class _attention_varlen(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        try:
            global args
            comm_mode = args.comm_mode
            backward_engine = args.backward_engine
        except Exception:
            # fall back to defaults when the module is used without CLI args
            comm_mode = 'lightseq'
            backward_engine = 'flash'

        q, k, v, o, L, cu_seq_lens, max_seqlen = _lightseq_forward_varlen(q, k, v, causal, sm_scale, comm_mode)

        ctx.save_for_backward(q, k, v, o, L, cu_seq_lens)
        ctx.max_seqlen = max_seqlen
        ctx.sm_scale = sm_scale
        ctx.comm_mode = comm_mode
        ctx.backward_engine = backward_engine
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, L, cu_seq_lens = ctx.saved_tensors
        sm_scale = ctx.sm_scale
        max_seqlen = ctx.max_seqlen

        dq, dk, dv = _lightseq_backward_varlen(do, q, k, v, o, L, sm_scale, ctx.comm_mode, ctx.backward_engine, cu_seq_lens, max_seqlen)
        return dq, dk, dv, None, None


dist_attn_varlen = _attention_varlen.apply


#@pytest.mark.parametrize('causal', [False, True])
#@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 9, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.bfloat16):
    torch.manual_seed(20)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    PAD = world_size * 256
    seq_per_rank = (N_CTX - PAD) // world_size
    q = torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX - PAD, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

    # DEBUG: mask out
    #mask = torch.zeros(Z, H, seq_per_rank * (world_size - 1), D_HEAD).cuda()
    #mask_2 = torch.ones(Z, H, seq_per_rank, D_HEAD).cuda()
    #mask = torch.cat((mask, mask_2), dim=-2).to(dtype)
    #q = mask * q
    #k = mask * k
    #v = mask * v

    sm_scale = 0.5
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX - PAD, N_CTX - PAD), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    assert causal
    if causal:
        p[:, :, M == 0] = float("-inf")
    # cast back to the test dtype (bfloat16 by default) so the reference matmul and comparisons stay in a single dtype
    p = torch.softmax(p.float(), dim=-1).to(dtype)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None

    # triton implementation
    a, b, c, d = q.size()
    real_q = q[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True)
    real_k = k[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True)
    real_v = v[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True)
    real_do = dout[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True)

    tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).to(dtype)

    # compare
    assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f"rank {rank} fails forward"
    print(f" *** rank {rank} passes forward")
    tri_out.backward(real_do)
    tri_dv, real_v.grad = real_v.grad.clone(), None
    tri_dk, real_k.grad = real_k.grad.clone(), None
    tri_dq, real_q.grad = real_q.grad.clone(), None
    assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f"rank {rank} fails backward dq"
    assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk"
    assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv"
    print(f"rank {rank} passes backward")

#TODO(High Priority): Investigate why rank 0 tends to have larger numerical difference.
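
# Illustrative sketch (added for clarity; not part of the original module): both distributed
# tests in this file build a full-sequence reference on every rank and then slice out that
# rank's contiguous shard along the sequence dimension before calling dist_attn_varlen.
# The helper below is a hypothetical summary of that sharding pattern, assuming tensors laid
# out as (Z, H, S, D); it is not referenced by the surrounding code.
def _example_shard_for_rank(x, rank, seq_per_rank):
    """Return this rank's contiguous sequence shard of a (Z, H, S, D) tensor as a fresh autograd leaf."""
    shard = x[:, :, rank * seq_per_rank:(rank + 1) * seq_per_rank, :]
    return shard.contiguous().clone().detach().requires_grad_(True)
# e.g. real_q = _example_shard_for_rank(q, dist.get_rank(), seq_per_rank)
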
def test_gqa(Z, H, KVH, N_CTX, D_HEAD, causal, dtype=torch.bfloat16):
    torch.manual_seed(177)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, KVH, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    seq_per_rank = N_CTX // world_size

    sm_scale = 0.5
    dout = torch.randn_like(q)
    # torch reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    ref_k = maybe_repeat_kv_fwd(q.shape[1], k).clone().detach().requires_grad_(True)
    ref_v = maybe_repeat_kv_fwd(q.shape[1], v).clone().detach().requires_grad_(True)
    #print(q.shape, ref_k.shape, k.shape)
    p = torch.matmul(q, ref_k.transpose(2, 3)) * sm_scale
    assert causal
    if causal:
        p[:, :, M == 0] = float("-inf")
    # keep the reference in the test dtype so the comparisons below stay in a single dtype
    p = torch.softmax(p.float(), dim=-1).to(dtype)
    ref_out = torch.matmul(p, ref_v)
    ref_out.backward(dout)
    # gradients accumulate on the repeated ref_k/ref_v leaves, so read and clear those
    ref_dv, ref_v.grad = ref_v.grad.clone(), None
    #print("Before reduce", ref_dv.shape)
    ref_dv = (maybe_reduce_dkv(KVH, ref_dv.transpose(1, 2))).transpose(1, 2)
    #print("After reduce", ref_dv.shape)
    ref_dk, ref_k.grad = ref_k.grad.clone(), None
    ref_dk = (maybe_reduce_dkv(KVH, ref_dk.transpose(1, 2))).transpose(1, 2)
    ref_dq, q.grad = q.grad.clone(), None

    # flash reference
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    flash_q = q.transpose(1, 2).clone().detach().requires_grad_(True)
    flash_k = k.transpose(1, 2).clone().detach().requires_grad_(True)
    flash_v = v.transpose(1, 2).clone().detach().requires_grad_(True)
    flash_ref_out = flash_attn_func(flash_q, flash_k, flash_v, 0, sm_scale, True)
    flash_ref_out.backward(dout.transpose(1, 2))
    flash_ref_out = flash_ref_out.transpose(1, 2)
    flash_ref_dv, flash_v.grad = flash_v.grad.clone(), None
    flash_ref_dv = flash_ref_dv.transpose(1, 2)
    flash_ref_dk, flash_k.grad = flash_k.grad.clone(), None
    flash_ref_dk = flash_ref_dk.transpose(1, 2)
    flash_ref_dq, flash_q.grad = flash_q.grad.clone(), None
    flash_ref_dq = flash_ref_dq.transpose(1, 2)

    # triton implementation
    a, b, c, d = q.size()
    real_q = q[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True)
    real_k = k[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True)
    real_v = v[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, KVH, -1, d).contiguous().clone().detach().requires_grad_(True)
    real_do = dout[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].view(a, b, -1, d).contiguous().clone().detach().requires_grad_(True)

    tri_out = dist_attn_varlen(real_q, real_k, real_v, causal, sm_scale).to(dtype)

    # compare against the flash reference
    assert torch.allclose(flash_ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f"rank {rank} fails forward against flash"
    print(f" *** rank {rank} passes forward")
    tri_out.backward(real_do)
    tri_dv, real_v.grad = real_v.grad.clone(), None
    tri_dk, real_k.grad = real_k.grad.clone(), None
    tri_dq, real_q.grad = real_q.grad.clone(), None
    assert torch.allclose(flash_ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f"rank {rank} fails backward dq against flash"
    #print(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :].shape, ref_dk.shape, tri_dk.shape)
    assert torch.allclose(flash_ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk against flash"
    assert torch.allclose(flash_ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv against flash"
    print(f"rank {rank} passes backward against flash")

    # compare against the torch reference
    assert torch.allclose(ref_out[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_out, atol=1e-2, rtol=0), f"rank {rank} fails forward"
    print(f" *** rank {rank} passes forward")
    assert torch.allclose(ref_dq[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dq, atol=1e-2, rtol=0), f"rank {rank} fails backward dq"
    assert torch.allclose(ref_dk[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dk, atol=1e-2, rtol=0), f"rank {rank} fails backward dk"
    assert torch.allclose(ref_dv[:, :, rank * seq_per_rank: (rank + 1) * seq_per_rank, :], tri_dv, atol=1e-2, rtol=0), f"rank {rank} fails backward dv"
    print(f"rank {rank} passes backward")


#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    FLASH_VER = 2
except BaseException:
    try:
        from flash_attn.flash_attn_interface import flash_attn_func
        FLASH_VER = 1
    except BaseException:
        FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
# force-disable the flash baseline in the benchmark sweep below
HAS_FLASH = None
ONLY_FLASH = False

#BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX, D_HEAD = 1, 32, None, 128
# vary the sequence length for a fixed head count and batch size
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(18, 19)],  # alternative sweeps: range(10, 19) or [20, 21]
    line_arg='provider',
    line_vals=['triton'] if not ONLY_FLASH else [] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] if not ONLY_FLASH else [] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.bfloat16, 'mode': mode, 'causal': causal}
) for mode in ["all"] for causal in [True]]


# @triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, KVH, N_CTX, D_HEAD, causal, mode, provider, args, dtype=torch.bfloat16, device="cuda"):
    assert mode == "all"  # mode in ['fwd', 'bwd']
    n_warmup = 10
    n_repeat = 10
    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    seq_rank = get_sequence_parallel_rank()
    seq_world_size = get_sequence_parallel_size()
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, KVH, N_CTX // seq_world_size, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        if seq_rank == 0:
            print(f"Benchmarking per GPU qkv shape: {q.shape}")
        sm_scale = 1.3
        fwd_fn = lambda: dist_attn_varlen(q, k, v, causal, sm_scale)
    if provider == "flash":
        qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        if FLASH_VER == 1:
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD)
            fwd_fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
        elif FLASH_VER == 2:
            fwd_fn = lambda: flash_attn_func(qkv, causal=causal)
        else:
            raise ValueError(f'unknown {FLASH_VER = }')

    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD / seq_world_size
    attn_flops = 2 * flops_per_matmul

    assert causal
    if causal:
        attn_flops *= 0.5
    fwd_flops = attn_flops
    bwd_flops = attn_flops * 2.5  # 2.0 (bwd) + 0.5 (recompute)

    o = fwd_fn()
    do = torch.randn_like(o)
    bwd_fn = lambda: o.backward(do, retain_graph=True)

    def run_benchmark(fn):
        time_list = []
        for _ in tqdm(range(n_warmup)):
            cache.zero_()
            fn()
            torch.cuda.synchronize()
            if args.debug:
                print_and_reset_comm_stats()
        for i in tqdm(range(n_repeat)):
            cache.zero_()
            torch.cuda.synchronize()
            time_s = time.time()
            fn()
            torch.cuda.synchronize()
            time_e = time.time()
            time_list.append((time_e - time_s) * 1000.0)
            if args.debug:
                print_and_reset_comm_stats()
        return np.asarray(time_list)

    fwd_time_arr = run_benchmark(fwd_fn)
    bwd_time_arr = run_benchmark(bwd_fn)

    fwd_flops_ps = fwd_flops / np.mean(fwd_time_arr) * 1e-9
    print(f"(FWD) R={seq_rank} avg: {np.mean(fwd_time_arr)}, std: {np.std(fwd_time_arr)} flops: {fwd_flops_ps} \n")

    bwd_flops_ps = bwd_flops / np.mean(bwd_time_arr) * 1e-9
    print(f"(BWD) R={seq_rank} avg: {np.mean(bwd_time_arr)}, std: {np.std(bwd_time_arr)} flops: {bwd_flops_ps} \n")

    # total
    total_time_arr = fwd_time_arr + bwd_time_arr
    total_flops = fwd_flops + bwd_flops
    total_flops_ps = total_flops / np.mean(total_time_arr) * 1e-9
    print(f"(Total) R={seq_rank} avg: {np.mean(total_time_arr)}, std: {np.std(total_time_arr)} flops: {total_flops_ps} \n")

    #return total_flops_ps


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--comm-mode", type=str, default="lightseq")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--run-mode", type=str, default="test")
    parser.add_argument("--bs", type=int, default=1)
    parser.add_argument("--n_heads", type=int, default=32)
    parser.add_argument("--n_kvheads", type=int, default=32)
    parser.add_argument("--d_head", type=int, default=128)
    parser.add_argument("--start_ctx", type=int, default=12)
    parser.add_argument("--end_ctx", type=int, default=18)
    parser.add_argument("--forward_engine", type=str, default="triton")
    parser.add_argument("--backward_engine", type=str, default="flash")

    global args
    args = parser.parse_args()
    initialize_distributed()

    assert args.forward_engine == "triton", "Only the triton forward is implemented."
    assert args.backward_engine in ["flash", "xformers"], "Only flash or xformers backward is implemented."

    if args.backward_engine == "flash":
        from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
    else:
        try:
            import xformers.ops
            from xformers.ops.fmha.common import Inputs, Context
            from xformers.ops.fmha import _memory_efficient_attention_backward
            from xformers.ops.fmha import cutlass, flash
        except ImportError:
            print("xformers not found! Please install it before trying to use it.")

    if args.run_mode == "benchmark":
        for N_CTX in [2**i for i in range(args.start_ctx, args.end_ctx)]:
            bench_flash_attention(args.bs, args.n_heads, args.n_kvheads, N_CTX, args.d_head, True, "all", "triton", args)  # .run(save_path='.', print_data=True)
            reset_global_memory_buffer()
    else:
        assert args.run_mode == "test"
        for N_CTX in [4096]:
            test_op(2, 16, N_CTX, 128, True)
            #test_gqa(1, 16, 8, N_CTX, 128, True)
            reset_global_memory_buffer()
--------------------------------------------------------------------------------