├── README.md
├── data
│   ├── data_process.sh
│   └── script
│       ├── preprocess.example.py
│       ├── pseudo.multiturn.py
│       ├── step1.sample.py
│       ├── step2.tokenize.py
│       └── step3.pack.py
├── docs
│   ├── image-1.png
│   ├── image-2.png
│   ├── image-3.png
│   └── image-4.png
└── train
    ├── config.py
    ├── config
    │   ├── multi_node.yaml
    │   ├── single_node.yaml
    │   └── zero3_offload.json
    ├── data.py
    ├── requirements.txt
    ├── ring-flash-attention
    │   ├── LICENSE
    │   ├── README.md
    │   ├── benchmark
    │   │   ├── benchmark_kvpacked_func.py
    │   │   └── benchmark_varlen_kvpacked_func.py
    │   ├── build
    │   │   └── lib
    │   │       └── ring_flash_attn
    │   │           ├── __init__.py
    │   │           ├── adapters
    │   │           │   ├── __init__.py
    │   │           │   └── hf_adapter.py
    │   │           ├── llama3_flash_attn_varlen.py
    │   │           ├── ring_flash_attn.py
    │   │           ├── ring_flash_attn_varlen.py
    │   │           ├── stripe_flash_attn.py
    │   │           ├── triton_utils.py
    │   │           ├── utils.py
    │   │           ├── zigzag_ring_flash_attn.py
    │   │           └── zigzag_ring_flash_attn_varlen.py
    │   ├── pyproject.toml
    │   ├── ring_flash_attn.egg-info
    │   │   ├── PKG-INFO
    │   │   ├── SOURCES.txt
    │   │   ├── dependency_links.txt
    │   │   └── top_level.txt
    │   ├── ring_flash_attn
    │   │   ├── .___pycache__
    │   │   ├── .ipynb_checkpoints
    │   │   │   ├── llama3_flash_attn_varlen-checkpoint.py
    │   │   │   └── utils-checkpoint.py
    │   │   ├── __init__.py
    │   │   ├── adapters
    │   │   │   ├── .ipynb_checkpoints
    │   │   │   │   └── hf_adapter-checkpoint.py
    │   │   │   ├── __init__.py
    │   │   │   └── hf_adapter.py
    │   │   ├── llama3_flash_attn_varlen.py
    │   │   ├── ring_flash_attn.py
    │   │   ├── ring_flash_attn_varlen.py
    │   │   ├── stripe_flash_attn.py
    │   │   ├── triton_utils.py
    │   │   ├── utils.py
    │   │   ├── zigzag_ring_flash_attn.py
    │   │   └── zigzag_ring_flash_attn_varlen.py
    │   ├── setup.py
    │   └── test
    │       ├── test_llama3_flash_attn_varlen_func.py
    │       ├── test_llama3_prepare_cu_seqlens.py
    │       ├── test_ring_flash_attn_func.py
    │       ├── test_ring_flash_attn_varlen_func.py
    │       ├── test_stripe_flash_attn_func.py
    │       ├── test_triton_kernels.py
    │       ├── test_zigzag_ring_flash_attn_func.py
    │       ├── test_zigzag_ring_flash_attn_varlen_func.py
    │       └── utils.py
    ├── ring_attn_utils.py
    ├── train.py
    └── train.sh


/README.md:
--------------------------------------------------------------------------------
# OpenSFT

Open-source / Lightweight / Easy-to-use / Large-Scale / Extra-Long-Text

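OpenSFT is an open-source SFT training framework built on accelerate + deepspeed + ring flash attention.

This project implements length-pack data organization, which further increases parallelism; per-category loss statistics under sequence parallelism, to better monitor how each category is doing; and a turn-loss under sequence parallelism, to balance long and short conversations.

The framework is very lightweight and easy to learn and extend. Stars are welcome.

### Updates

2025.03.27: Added probabilistic concatenation of samples into pseudo multi-turn conversations; see data/script/pseudo.multiturn.py.

### Installation

```bash
git clone https://github.com/mlpod/OpenSFT.git
cd OpenSFT/train
pip install -r requirements.txt
```

### Data organization

Several ways to organize data for SFT:

1. Traditional organization. The part marked in red is where the loss is computed; there is a large amount of padding.

2. Variable-length, multi-turn-loss organization. The part marked in red is where the loss is computed; some padding remains.

3. Pack the sequences within a batch into a single sample and compute with sequence parallelism. Padding remains (the sample must be padded to a length divisible by the sequence-parallel degree), and lengths differ greatly between samples.

**The organization implemented in this project**: building on option 3, data is packed up to a specified maximum length rather than per batch. Sequences are added until the next one would exceed the maximum length, so there is less padding, and the length differences between samples are negligible. For example, with a maximum length of 128K, the distill_r1_110k dataset can be compressed to about 1k samples.

Online packing adds some overhead, so packing is done offline here, which decouples training from data processing.
The offline step also uses greedy packing and multiprocessing to further reduce both the length differences between packed samples and the processing time; see the code for details.

The data processing code lives in the data directory. The raw directory inside data holds the files that have not been preprocessed. The data format is as follows:
```json
{
    "messages": [
        {
            "role": "system",
            "content": ""
        },
        {
            "role": "user",
            "content": ""
        },
        {
            "role": "assistant",
            "content": ""
        }
    ],
    "labels": [0, 0, 1],
    "meta": {
        "category_id": 1,
        "category_name": ""
    }
}
```
`labels` marks which turns in `messages` contribute to the loss; it only takes effect for assistant turns. When `category_id` is nonzero, the loss for that category is printed during training.

The data-ratio configuration file has the following format:
```json
{
    "data_name": "",
    "data_path": "",
    "ratio": [
        {
            "category_id": 1,
            "category_name": "",
            "size": 100,
            "sample_rate": 1.0
        },
        {
            "category_id": 2,
            "category_name": "",
            "size": 200,
            "sample_rate": 0.8
        }
    ]
}
```
Users need to adapt the preprocessing code, data/script/preprocess.example.py, to convert their data into the target format and generate the data-ratio configuration file.
The ratios can then be adjusted as needed.
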
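To make the effect of `sample_rate` concrete, here is a minimal sketch of how a rate could be applied to one category. It mirrors the logic of `process_category` in data/script/step1.sample.py as far as that script shows it; the helper name `apply_sample_rate` and the toy records are illustrative assumptions, not part of the project.

```python
import random


def apply_sample_rate(items, sample_rate, seed):
    """Hypothetical helper: repeat `items` int(sample_rate) times,
    then add a random subset for the fractional part of the rate."""
    rng = random.Random(seed)  # local RNG, leaves the global random state alone
    whole = int(sample_rate)
    fraction = sample_rate - whole
    result = list(items) * whole
    extra = int(len(items) * fraction)
    if extra > 0:
        # sample without replacement from the original items
        result += rng.sample(list(items), extra)
    return result


# Toy categories sized like the "size" fields in the config above.
category_1 = [f"c1-{i}" for i in range(100)]
category_2 = [f"c2-{i}" for i in range(200)]

kept_1 = apply_sample_rate(category_1, 1.0, seed=2025)     # keep everything once
kept_2 = apply_sample_rate(category_2, 0.8, seed=2026)     # downsample to 80%
upsampled = apply_sample_rate(category_2, 2.5, seed=2027)  # 2 full copies + 50%

print(len(kept_1), len(kept_2), len(upsampled))
```

Under this reading, a rate below 1 downsamples a category and a rate above 1 upsamples it by whole copies plus a random fraction, so the mixture can be tuned without touching the preprocessed data itself.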
Demo data used by the scripts: https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k/resolve/main/distill_r1_110k.jsonl

See the data processing script for details:
```bash
sh data_process.sh
```

### Training
```bash
sh train.sh
```
#### category-loss

```json
{
    "epoch": 2,
    "steps": 3,
    "lr": 1.22375e-05,
    "loss": 0.515095591545105,
    "coig/neo_loss": 0.5617072582244873,
    "stem_zh/phy_loss": 0.4811963737010956,
    "EduChat-Math_loss": 0.4951120913028717,
    "meta-math/GSM8K_zh_loss": 0.5640832781791687,
    "exam/coig_exam_loss": 0.6263442635536194,
    "gavinluo/applied_math_loss": 0.4919000566005707,
    "stem_zh/chem_loss": 0.4528641700744629,
    "stem_zh/bio_loss": 0.46091940999031067,
    "zhihu/zhihu_score9.0-10_clean_v10_loss": 0.5875096917152405,
    "xhs/xhs_loss": 0.7661288380622864,
    "stem_zh/med_loss": 0.42540857195854187,
    "human_value/100poison_loss": 0.5484293699264526,
    "ruozhiba/ruozhiba_ruozhiba_loss": 0.825197160243988,
    "logi_qa/logi-qa_loss": 0.6175104975700378,
    "Haijian/Advanced-Math_loss": 0.4288356602191925,
    "exam/kaoyan_loss": 0.6865882873535156
}
```

#### turn-loss
Inspired by [SFT loss 计算的那些坑(多轮合并/packing)](https://zhuanlan.zhihu.com/p/721652210) ("Pitfalls of SFT loss computation: multi-turn merging / packing"), a turn-loss was added to balance long and short conversations. See the code for details; use it as needed.

--------------------------------------------------------------------------------
/data/data_process.sh:
--------------------------------------------------------------------------------
set -x


step0=0
step1=1
step2=0
step3=0

if [ $step0 -ne 0 ]; then
    python script/preprocess.example.py
fi

config_path=config/distill_r1_110k.json
preprocessed_data_path=dataset/distill_r1_110k.preprocessed.jsonl
sampled_data_path=dataset/distill_r1_110k.sampled.jsonl
tokenized_data_path=dataset/distill_r1_110k.tokenized.jsonl
packed_data_path=dataset/distill_r1_110k.packed.jsonl


if [ $step1 -ne 0 ]; then
    python script/step1.sample.py \
        --config-path=$config_path \
        --preprocessed-data-path=$preprocessed_data_path \
        --output-path=$sampled_data_path \
        --seed=2025 --num-workers=10
fi

# python script/pseudo.multiturn.py

# Fill these in before running step2 / step3.
tokenizer_path=
padding_value=

if [ $step2 -ne 0 ]; then
    python script/step2.tokenize.py \
        --input-path=$sampled_data_path \
        --output-path=$tokenized_data_path \
        --tokenizer-path=$tokenizer_path \
        --num-workers=10
fi


if [ $step3 -ne 0 ]; then
    python script/step3.pack.py \
        --input-path=$tokenized_data_path \
        --output-path=$packed_data_path \
        --max-length=131072 \
        --padding-value=$padding_value \
        --num-workers=10
fi

--------------------------------------------------------------------------------
/data/script/preprocess.example.py:
--------------------------------------------------------------------------------
import json
from functools import partial
from collections import Counter
from typing import Dict, Any, List
from pathlib import Path
from datasets import load_dataset, Dataset


def process_record(record: Dict[str, Any], n2i: Dict[str, int], i2n: Dict[int, str]) -> Dict[str, Any]:
    """Process a single data record.

    Args:
        record: the raw data record
        n2i: mapping from category name to id
        i2n: mapping from category id to name

    Returns:
        The processed record in the target format.
    """
    think_template = '''
{think}

{answer}'''

    try:
        category_id = n2i[record['repo_name']]
        category_name = i2n[category_id]

        return {
            "meta": {
                "category_id": category_id,
                "category_name": category_name
            },
            "messages": [
                {
                    "role": "system",
                    "content": "你是一位擅长深度思考的助手。在回答问题之前,你会像人类一样展现出一步一步的思考过程,包括问题理解、分析推理、自我质疑、反思验证等环节。之后你会基于思考过程,作出准确的回答。"
                },
                {
                    "role": "user",
                    "content": record['input']
                },
                {
                    "role": "assistant",
                    "content": think_template.format(
                        think=record['reasoning_content'].strip(),
                        answer=record['content'].strip()
                    )
                }
            ],
            # Only the assistant turn contributes to the loss.
            "labels": [0, 0, 1]
        }
    except KeyError as e:
        raise KeyError(f"Error while processing record, missing required field: {e}")

def generate_category_mapping(dataset: Dataset) -> tuple[Dict[int, str], Dict[str, int], List[Dict]]:
    """Build the category mappings and the ratio configuration.

    Args:
        dataset: the raw dataset

    Returns:
        A tuple containing:
        - id2name: mapping from id to name
        - name2id: mapping from name to id
        - ratio_config: per-category sampling-rate configuration
    """
    id2name = {}
    name2id = {}
    ratio_config = []

    name_cnt = Counter(dataset['repo_name'])
    category_id = 1  # no loss is reported for category_id 0

    for name in name_cnt:
        id2name[category_id] = name
        name2id[name] = category_id
        ratio_config.append({
            "category_id": category_id,
            "category_name": name,
            "size": name_cnt[name],
            "sample_rate": 1.0
        })
        category_id += 1

    return id2name, name2id, ratio_config

def main():
    """Entry point."""
    # Paths
    raw_data_path = 'raw/distill_r1_110k.jsonl'
    config_path = 'config/distill_r1_110k.json'
    preprocessed_data_path = 'dataset/distill_r1_110k.preprocessed.jsonl'

    # Make sure the required directories exist
    Path(config_path).parent.mkdir(parents=True, exist_ok=True)
    Path(preprocessed_data_path).parent.mkdir(parents=True, exist_ok=True)

    try:
        # Load the dataset
        dataset = load_dataset('json', data_files=raw_data_path)['train']

        # Build the configuration
        id2name, name2id, ratio_config = generate_category_mapping(dataset)

        config = {
            "data_name": "distill_r1_110k",
            "ratio": ratio_config
        }

        # Save the configuration
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(config, ensure_ascii=False, indent=4))

        # Process the dataset
        process_record_partial = partial(process_record, n2i=name2id, i2n=id2name)
        preprocessed_dataset = dataset.map(
            process_record_partial,
            num_proc=10,
            remove_columns=dataset.column_names
        )

        # Save the processed dataset
        preprocessed_dataset.save_to_disk(preprocessed_data_path)

    except Exception as e:
        print(f"Error during processing: {e}")
        raise

if __name__ == "__main__":
    main()




--------------------------------------------------------------------------------
/data/script/pseudo.multiturn.py:
--------------------------------------------------------------------------------
import random
from datasets import load_from_disk, Dataset


def merge_records(record1, record2):
    """
    Merge two records.
    :param record1: first record (dict)
    :param record2: second record (dict)
    :return: merged record (dict)
    """
    # If both records start with the same system message, drop the duplicate from record2.
    c1 = record1["messages"][0]["role"] == "system" and record2["messages"][0]["role"] == "system"
    c2 = record1["messages"][0]["content"] == record2["messages"][0]["content"]
    if c1 and c2:
        record2["messages"] = record2["messages"][1:]
        record2['labels'] = record2['labels'][1:]

    # An unmerged record carries 'meta'; an already merged one carries 'category_id_list'.
    if 'meta' in record1:
        record1['category_id_list'] = [record1["meta"]["category_id"]] * len(record1["messages"])

    merged_record = {
        "messages": record1["messages"] + record2["messages"],  # merge messages
        "labels": record1["labels"] + record2["labels"],  # merge labels
        "category_id_list": record1['category_id_list'] +
                            [record2["meta"]["category_id"]] * len(record2["messages"]),  # category_id of the second record
    }
    assert len(merged_record['messages']) == len(merged_record['labels']) == len(merged_record['category_id_list'])
    return merged_record




def merge(input_file, output_file, merge_probability):
    """
    :param input_file: input dataset path (saved with save_to_disk)
    :param output_file: output dataset path
    :param merge_probability: merge probability (0~1)
    """
    # Load the dataset
    dataset = load_from_disk(input_file)
    records = []
    for record in dataset:
        records.append(record)

    merged_records = []
    i = 0
    while i < len(records):
        record1 = records[i]
        i += 1
        # Keep merging consecutive records with the given probability
        while i < len(records) and random.random() < merge_probability:
            record2 = records[i]
            record1 = merge_records(record1, record2)
            i += 1
        # Append the final merged record to the result
        merged_records.append(record1)

    turn_len_list = []
    for rec in merged_records:
        turn_len = len(rec['messages'])
        turn_len_list.append(turn_len)
    print("Average turns:", sum(turn_len_list)/len(turn_len_list))

    # Save the result
    data = Dataset.from_list(merged_records)
    print("Before merge:", len(dataset), "After merge:", len(data))
    data.save_to_disk(output_file)


if __name__ == '__main__':
    # Example: run the script
    input_file = "dataset/distill_r1_110k.sampled.jsonl"  # input path
    output_file = "dataset/distill_r1_110k.sampled.pseudo.multiturn.jsonl"  # output path
    merge_probability = 0.5  # merge probability (50%)
    num_workers = 1
    merge(input_file, output_file, merge_probability)

--------------------------------------------------------------------------------
/data/script/step1.sample.py:
--------------------------------------------------------------------------------
import json
import random
import argparse
from tqdm import tqdm
from multiprocessing import Pool
from datasets import Dataset, load_from_disk

def load_config(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def process_category(items, ratio, seed):
    # Repeat the category int(ratio) times, then sample the fractional part.
    random.seed(seed)
    int_part = int(ratio)
    float_part = ratio - int_part
    result = items * int_part
    extra_sample_size = int(len(items) * float_part)
    if extra_sample_size > 0:
        result += random.sample(items, extra_sample_size)
    return result

class Sampler(object):
    def __init__(self, config_path, preprocessed_data_path, num_workers, seed, output_path):
        self.config = load_config(config_path=config_path)
        self.num_workers = num_workers
        self.seed = seed
        self.output_path = output_path
        self.dataset = load_from_disk(preprocessed_data_path)

    def sample(self):
        category_ratios = {
            item['category_id']: item['sample_rate']
            for item in self.config['ratio']
        }
        # Group records by category
        categorized_data = {}
        for item in tqdm(self.dataset):
            category_id = item['meta']['category_id']
            if category_id not in categorized_data:
                categorized_data[category_id] = []
            categorized_data[category_id].append(item)
        # Sample each category in parallel, with a per-category seed
        with Pool(self.num_workers) as pool:
            tasks = [
                (items, category_ratios[category_id], self.seed + category_id)
                for category_id, items in categorized_data.items()
            ]
            results = pool.starmap(process_category, tasks)
        sampled_data = []
        for result in results:
            sampled_data.extend(result)
        random.seed(self.seed)
        random.shuffle(sampled_data)
        self.dataset = Dataset.from_list(sampled_data)

    def save(self):
        self.dataset.save_to_disk(self.output_path)



def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config-path', type=str, help='path to the data configuration file')
    parser.add_argument('--preprocessed-data-path', type=str, help='path to the preprocessed input data')
    parser.add_argument('--output-path', type=str, help='path to save the sampled data')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--num-workers', type=int, default=5, help='number of worker processes')
    return parser.parse_args()

def main():
    args = parse_args()
    sampler = Sampler(config_path=args.config_path,
                      preprocessed_data_path=args.preprocessed_data_path,
                      output_path=args.output_path,
                      seed=args.seed,
                      num_workers=args.num_workers)
    sampler.sample()
    sampler.save()

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/data/script/step2.tokenize.py:
--------------------------------------------------------------------------------
import json
import torch
import argparse
from transformers import AutoTokenizer
from datasets import load_from_disk

def load_config(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

class Tokenizer(object):
    def __init__(self, input_path, tokenizer_path, num_workers, output_path, ignore_index=-100):
        self.dataset = load_from_disk(input_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.num_workers = num_workers
        self.output_path = output_path
        self.ignore_index = ignore_index


    def process(self, record):
        input_ids = self.tokenizer.apply_chat_template(
            record["messages"],
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt")[0]

        # Every token is ignored by default; assistant spans are filled in below.
        token_labels = torch.full_like(input_ids, fill_value=self.ignore_index)
        category_ids = [0] * len(input_ids)
        for idx, message in enumerate(record["messages"]):
            if message["role"] == "assistant":
                # Locate the token span of this assistant turn via the rendered prompt prefix.
                prompt = self.tokenizer.apply_chat_template(record["messages"][:idx], tokenize=False,
                                                            add_generation_prompt=True)
                response = self.tokenizer.apply_chat_template(record["messages"][: idx + 1], tokenize=False)[
                           len(prompt):]
                start_idx = self.tokenizer(
                    prompt,
                    padding=False,
                    return_tensors="pt",
                    add_special_tokens=False,
                )["attention_mask"].int().sum().item()
                end_idx = start_idx + self.tokenizer(
                    response,
                    padding=False,
                    return_tensors="pt",
                    add_special_tokens=False,
                )["attention_mask"].int().sum().item()
                if record["labels"][idx] == 1:
                    token_labels[start_idx:end_idx] = input_ids[start_idx:end_idx]
                if 'meta' in record:
                    category_ids[start_idx:end_idx] = [record['meta']['category_id']] * (end_idx - start_idx)
                else:
                    category_ids[start_idx:end_idx] = [record['category_id_list'][idx]] * (end_idx - start_idx)
        return {
            "input_ids": input_ids,
            "token_labels": token_labels,
            "category_ids": category_ids,
        }

    def tokenize(self):
        self.dataset = self.dataset.map(
            self.process,
            num_proc=self.num_workers,
            remove_columns=self.dataset.column_names,
            desc="Tokenizing"
        )

    def save(self):
        self.dataset.save_to_disk(self.output_path)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help='input data path')
    parser.add_argument('--output-path', type=str, help='path to save the data')
    parser.add_argument('--tokenizer-path', type=str, help='tokenizer path')
    parser.add_argument('--num-workers', type=int, help='number of worker processes')
    return parser.parse_args()

def main():
    args = parse_args()
    tokenizer = Tokenizer(input_path=args.input_path,
                          tokenizer_path=args.tokenizer_path,
                          num_workers=args.num_workers,
                          output_path=args.output_path)
    tokenizer.tokenize()
    tokenizer.save()

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/data/script/step3.pack.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import argparse
import random
import torch
import numpy as np
from functools import partial
from multiprocessing import Pool
from datasets import load_from_disk, Dataset


def process_single_group(group_indices, sequences, max_length, padding_value):
    """Build one packed sample from a group of sequence indices."""
    # Preallocate numpy arrays, initialized with padding_value
    packed_input_ids = np.full(max_length, padding_value, dtype=np.int64)
    packed_labels = np.full(max_length, -100, dtype=np.int64)
    packed_category_ids = np.zeros(max_length, dtype=np.int64)
    attention_mask = np.zeros(max_length, dtype=np.int64)

    current_pos = 0
    packed_seq_lens = []

    # Shuffle the indices
    random.shuffle(group_indices)

    # Fill in the data
    for idx in group_indices:
        seq_len = len(sequences[idx]['input_ids'])
        packed_seq_lens.append(seq_len)

        end_pos = current_pos + seq_len
        packed_input_ids[current_pos:end_pos] = sequences[idx]['input_ids']
        packed_labels[current_pos:end_pos] = sequences[idx]['token_labels']
        packed_category_ids[current_pos:end_pos] = sequences[idx]['category_ids']
        attention_mask[current_pos:end_pos] = 1

        current_pos = end_pos

    return {
        'input_ids': torch.from_numpy(packed_input_ids),
        'labels': torch.from_numpy(packed_labels),
        'attention_mask': torch.from_numpy(attention_mask),
        'packed_seq_lens': torch.tensor(packed_seq_lens),
        'category_ids': torch.from_numpy(packed_category_ids),
    }


def pack_sequences(sequences, max_length, num_workers=20, padding_value=0, max_attempt=2000):
    """
    Greedily pack sequences into fixed-length samples, keeping the number of packed samples low.

    Args:
        sequences: dataset of tokenized sequences; each item has 'input_ids', 'token_labels', 'category_ids'
        max_length: maximum length of a packed sample
        num_workers: number of worker processes
        padding_value: value used for padding
        max_attempt: how many non-fitting sequences to skip before closing the current pack

    Returns:
        packed_results: list of dicts, one per packed sample, with keys
            - input_ids: concatenated, padded token ids
            - labels: concatenated labels (-100 where ignored)
            - attention_mask: 1 for real tokens, 0 for padding
            - packed_seq_lens: lengths of the original sequences in the pack
            - category_ids: per-token category ids
    """

    def compute_length(example):
        return {'length': len(example['input_ids'])}

    data_with_length = sequences.map(
        compute_length,
        num_proc=num_workers,
        desc="Computing sequence lengths"
    )
    idx2len = dict(enumerate(data_with_length['length']))

    unused_indices = set(idx2len.keys())
    packed_sequences = []

    pbar = tqdm(total=len(unused_indices), desc="Packing progress")

    while unused_indices:
        current_seq = []
        current_len = 0
        attempt = max_attempt
        for idx in idx2len:
            if idx in unused_indices:
                if current_len + idx2len[idx] <= max_length:
                    current_seq.append(idx)
                    current_len += idx2len[idx]
                    unused_indices.remove(idx)
                    pbar.update(1)
                else:
                    # Does not fit: keep looking ahead until the attempt budget runs out.
                    if attempt > 0:
                        attempt -= 1
                        continue
                    else:
                        break
        packed_sequences.append(current_seq)
    pbar.close()

    with Pool(num_workers) as pool:
        process_fn = partial(
            process_single_group,
            sequences=sequences,
            max_length=max_length,
            padding_value=padding_value
        )

        # Use imap so the progress bar can be shown
        packed_results = list(tqdm(
            pool.imap(process_fn, packed_sequences),
            total=len(packed_sequences),
            desc="Building packed samples in parallel"
        ))

    return packed_results

class Packer(object):
    def __init__(self, input_path, output_path, max_length, padding_value, num_workers, max_attempt):
        self.dataset = load_from_disk(input_path)
        self.output_path = output_path
        self.max_length = max_length
        self.padding_value = padding_value
        self.num_workers = num_workers
        self.max_attempt = max_attempt
        self.packed_dataset = None

    def filter_fn(self, example):
        return len(example['input_ids']) <= self.max_length

    def pack(self):
        valid_items = self.dataset.filter(
            self.filter_fn,
            num_proc=self.num_workers
        )
        packed_data = pack_sequences(
            valid_items,
            max_length=self.max_length,
            num_workers=self.num_workers,  # previously not forwarded, so the default of 20 was used
            padding_value=self.padding_value,
            max_attempt=self.max_attempt
        )
        self.packed_dataset = Dataset.from_list(packed_data)
        self.packed_dataset.set_format('torch')


    def save(self):
        self.packed_dataset.save_to_disk(self.output_path)




def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help='input data path')
    parser.add_argument('--output-path', type=str, help='path to save the data')
    parser.add_argument('--max-length', type=int, help='maximum length')
    parser.add_argument('--padding-value', type=int, help='padding value')
    parser.add_argument('--num-workers', type=int, help='number of worker processes')
    return parser.parse_args()

def main():
    args = parse_args()
    packer = Packer(
        input_path=args.input_path,
        output_path=args.output_path,
        max_length=args.max_length,
        padding_value=args.padding_value,
        num_workers=args.num_workers,
        max_attempt=2000)  # how many further candidates to check after one exceeds the maximum length
    packer.pack()
    packer.save()

if __name__ == "__main__":
    main()



--------------------------------------------------------------------------------
/docs/image-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-1.png

--------------------------------------------------------------------------------
/docs/image-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-2.png

--------------------------------------------------------------------------------
/docs/image-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-3.png

--------------------------------------------------------------------------------
/docs/image-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-4.png

--------------------------------------------------------------------------------
/train/config.py:
--------------------------------------------------------------------------------
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max_steps', type=int, default=2000)
parser.add_argument('--warmup_num_steps', type=int, default=200)
parser.add_argument("--learning_rate", type=float, default=2e-5)
args = parser.parse_args()

with open("config/zero3_offload.json") as f:
    config = json.loads(f.read())

# Overwrite the scheduler parameters in place with the command-line values.
config['scheduler']['params']['total_num_steps'] = args.max_steps
config['scheduler']['params']['warmup_num_steps'] = args.warmup_num_steps
config['scheduler']['params']['warmup_max_lr'] = args.learning_rate
config['scheduler']['params']['warmup_min_lr'] = args.learning_rate * 0.1

print(config)
with open("config/zero3_offload.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(config, indent=4))

--------------------------------------------------------------------------------
/train/config/multi_node.yaml:
--------------------------------------------------------------------------------
debug: false
deepspeed_config:
  deepspeed_config_file: zero3_offload.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
num_machines: 24
num_processes: 196
main_process_port: 38222
main_training_function: main
rdzv_backend: c10d
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

--------------------------------------------------------------------------------
/train/config/single_node.yaml:
--------------------------------------------------------------------------------
debug: false
deepspeed_config:
  deepspeed_config_file: zero3_offload.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
num_machines: 1
num_processes: 8
main_process_port: 38222
main_training_function: main
rdzv_backend: c10d
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

--------------------------------------------------------------------------------
/train/config/zero3_offload.json:
--------------------------------------------------------------------------------
{
    "bf16": {
        "enabled": "true"
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 2e-06,
            "warmup_max_lr": 2e-05,
            "warmup_num_steps": 200,
            "warmup_type": "linear",
            "total_num_steps": 2000
        }
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": [
                0.9,
                0.95
            ],
            "eps": 1e-08,
            "weight_decay": 0.1
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1000000000.0,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1000000000.0,
        "stage3_max_reuse_distance": 1000000000.0,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": 1
}

--------------------------------------------------------------------------------
/train/data.py:
--------------------------------------------------------------------------------
import torch
import datasets
from datasets import disable_caching
from torch.utils.data import DataLoader, DistributedSampler

disable_caching()

class SFTData(object):
    def __init__(self, data_path):
        self.data = datasets.load_from_disk(data_path)

    def collate_fn(self, records):
        # Each record is already a full packed sample, so the batch holds exactly one record.
        input_ids = records[0]["input_ids"].unsqueeze(0)
        labels = records[0]["labels"].unsqueeze(0)
        attention_mask = records[0]["attention_mask"].unsqueeze(0)
        category_ids = records[0]["category_ids"].unsqueeze(0)
        packed_seq_lens = records[0]['packed_seq_lens'].to(torch.int32)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            packed_seq_lens=packed_seq_lens,
            category_ids=category_ids
        )

    def get_dataloader(self, dp_world_size, dp_rank, seed, epoch, shuffle=True):
        sampler = DistributedSampler(
            self.data, num_replicas=dp_world_size, rank=dp_rank, seed=seed+epoch, shuffle=shuffle
        )
        train_dataloader = DataLoader(
            self.data,
            batch_size=1,
            sampler=sampler,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )
        return train_dataloader

--------------------------------------------------------------------------------
/train/requirements.txt:
--------------------------------------------------------------------------------
accelerate==0.33.0
flash-attn==2.6.3
deepspeed==0.16.0
transformers==4.45.2

-e ./ring-flash-attention/

wandb
pandas

--------------------------------------------------------------------------------
/train/ring-flash-attention/LICENSE:
--------------------------------------------------------------------------------
Copyright 2024 Zilin Zhu

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/train/ring-flash-attention/README.md:
--------------------------------------------------------------------------------
## Ring Flash Attention

This repo implements [RingAttention](https://github.com/lhao499/RingAttention) using [FlashAttention](https://github.com/Dao-AILab/flash-attention). The current implementation supports:

- varlen (packing samples) api, corresponding to `flash_attn_varlen_func`:
  - `ring_flash_attn_varlen_func`: a basic implementation of ring attention.
  - `zigzag_ring_flash_attn_varlen_func`: a more compute-balanced version of ring attention. More details in [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
  - `llama3_flash_attn_varlen_func`: the context parallelism used in the [llama3 tech report](https://arxiv.org/abs/2407.21783), with extra design for varlen and low memory overhead. Although technically not ring attention, this is **recommended** for most varlen use cases, as it offers a less intrusive alternative for training frameworks, with fewer data manipulations and better arithmetic precision.
- batch api, corresponding to `flash_attn_func`:
  - `ring_flash_attn_func`: basic ring attention.
  - `zigzag_ring_flash_attn_func`: a more compute-balanced version of ring attention, see [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
  - `stripe_flash_attn_func`: the stripe attention version of `ring_flash_attn_func`; the block size is set to 1 so that the flash_attn api can be used, see https://arxiv.org/abs/2311.09431
- [huggingface model adapter](ring_flash_attn/adapters/hf_adapter.py). Here is an example of using the adapter: [OpenRLHF/OpenRLHF/pull#439](https://github.com/OpenRLHF/OpenRLHF/pull/439/files).

Note that:

- Each function includes `*_func`, `*_kvpacked_func`, `*_qkvpacked_func` variants.
- The varlen versions (except the llama3 version) only support passing one `cu_seqlens`.

## Performance Summary

The following table summarizes the performance of the implemented APIs:

| batch api            | GPU     | theoretical flash_attn     | ring_attn     | zigzag_ring     | stripe_attn     |
| -------------------- | ------- | -------------------------- | ------------- | --------------- | --------------- |
| fwd only (iter/sec)  | 8xH800  | 591.5 / 8 = 73.9           | 38.5          | 63.0            | 55.0            |
|                      |         |                            | 52.1%         | **85.2%**       | 74.4%           |
| fwd + bwd (iter/sec) | 8xH800  | 154.7 / 8 = 19.3           | 10.4          | 17.4            | 16.0            |
|                      |         |                            | 53.9%         | **90.2%**       | 82.9%           |
| fwd only (iter/sec)  | 8xA100  | 373.4 / 8 = 46.7           | 24.0          | 38.2            | 32.5            |
|                      |         |                            | 51.4%         | **81.7%**       | 69.6%           |
| fwd + bwd (iter/sec) | 8xA100  | 94.7 / 8 = 11.8            | 6.2           | 10.6            | 9.75            |
|                      |         |                            | 52.5%         | **89.8%**       | 82.6%           |
| **varlen api**       | **GPU** | **theoretical flash_attn** | **ring_attn** | **zigzag_ring** | **llama3_attn** |
| fwd only (iter/sec)  | 8xH800  | 852.4 / 8 = 106.6          | 52.4          | 74.8            | 60.8            |
|                      |         |                            | 49.1%         | **70.2%**       | 57.0%           |
| fwd + bwd (iter/sec) | 8xH800  | 225.4 / 8 = 28.2           | 14.4          | 21.4            | 16.4            |
|                      |         |                            | 51.1%         | **75.9%**       | 58.1%           |
| fwd only (iter/sec)  | 8xA100  | 532.3 / 8 = 66.5           | 33.1          | 47.9            | 34.3            |
|                      |         |                            | 49.8%         | **72.0%**       | 51.6%           |
| fwd + bwd (iter/sec) | 8xA100  | 133.8 / 8 = 16.7           | 8.7           | 13.4            | 9.7             |
|                      |         |                            | 52.1%         | **80.2%**       | 58.0%           |

Note that:

- The benchmark code is in [benchmark](benchmark/). Its configuration matches the [Meta-Llama-3.1-8B](https://huggingface.co/NousResearch/Meta-Llama-3.1-8B/blob/main/config.json) setting, with a total sequence length of 8k per GPU.
- When running the benchmark with 8 GPUs, the flash attn code performs 1/8 of the computation of ring attention: the flash attn code computes `8*1^2`, while the ring attn code computes `1*8^2`.
- NVLink between GPUs is required for high performance.
- Please remember to adapt the RoPE offset for the different APIs.

### Installation

```bash
pip install ring-flash-attn
```

or use the following commands to build from source:

```bash
git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .
```

### TODOs

- [x] Implement `ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_qkvpacked_func` [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2)
- [x] Implement `stripe_flash_attn_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `*_kvpacked_func` and `*_func` variant for all APIs
- [x] ~~Optimize `*_varlen_func`~~ Implement `llama3_flash_attn_varlen_func`
- [x] ~~Add an example to train llama~~ Implement adapter for huggingface model
- [ ] Implement `zigzag_llama3_flash_attn_varlen_func`

### Test

```bash
torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py
```

### Benchmark

```bash
torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py
```

### Known Limitations

The current implementation has some arithmetic errors. The probable reason is that flash attention returns a bf16 value for each block, so we cannot accumulate the values with the original fp32 ones.

Also, because we need to keep an extra fp32 buffer during computation, the memory usage is higher than the theoretical limit.

In addition:

- dropout is not supported at the moment, because it's hard to save all the rng_states.
- window_size is not supported, because it will be really tricky to implement a varlen version with window_size.
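To make the accumulation mentioned under Known Limitations more concrete, here is a minimal sketch of how per-block attention results can be merged using their log-sum-exp (LSE) values. It is plain Python on a single query and is only an illustration of the general technique: the helper names `attend` and `merge` are hypothetical and are not the functions used in this repo, and everything here runs in Python floats rather than bf16/fp32 tensors.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query over one block of keys/values.

    Returns the block output and the log-sum-exp (LSE) of the block's scores.
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    weights = [math.exp(s - lse) for s in scores]
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two block results into the result over the union of their keys."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each block output is rescaled by its share of the total softmax mass.
    w_a = math.exp(lse_a - lse)
    w_b = math.exp(lse_b - lse)
    out = [w_a * a + w_b * b for a, b in zip(out_a, out_b)]
    return out, lse


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

# Reference: attention over all keys at once.
full_out, full_lse = attend(q, keys, values)

# Blockwise: two halves computed independently, then merged.
out_a, lse_a = attend(q, keys[:2], values[:2])
out_b, lse_b = attend(q, keys[2:], values[2:])
merged_out, merged_lse = merge(out_a, lse_a, out_b, lse_b)
```

In exact arithmetic the merged result equals the full result. If each block output is rounded to a lower precision before the merge, the rounding errors are carried into the rescaled sum, which is the kind of discrepancy the limitation above describes.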
104 | -------------------------------------------------------------------------------- /train/ring-flash-attention/benchmark/benchmark_kvpacked_func.py: -------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_kvpacked_func 2 | import os 3 | import torch 4 | import torch.distributed as dist 5 | from ring_flash_attn import ( 6 | ring_flash_attn_kvpacked_func, 7 | zigzag_ring_flash_attn_kvpacked_func, 8 | stripe_flash_attn_kvpacked_func, 9 | ) 10 | 11 | 12 | def benchmark(f, num_iter=100, forward_only=True, log=True, profile=False): 13 | dtype = torch.bfloat16 14 | rank = dist.get_rank() 15 | world_size = dist.get_world_size() 16 | device = torch.device(f"cuda:{rank}") 17 | torch.cuda.set_device(device) 18 | 19 | batch_size = 1 20 | deterministic = False 21 | # config of llama3 8B 22 | seqlen = 1024 * 8 23 | num_heads = 32 24 | num_kv_heads = 8 25 | head_dim = 128 26 | causal = True 27 | 28 | assert seqlen % (2 * world_size) == 0 29 | assert head_dim % 8 == 0 30 | 31 | q = torch.randn( 32 | batch_size, 33 | seqlen, 34 | num_heads, 35 | head_dim, 36 | device=device, 37 | dtype=dtype, 38 | requires_grad=True, 39 | ) 40 | kv = torch.randn( 41 | batch_size, 42 | seqlen, 43 | 2, 44 | num_kv_heads, 45 | head_dim, 46 | device=device, 47 | dtype=dtype, 48 | requires_grad=True, 49 | ) 50 | dout = torch.randn( 51 | batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype 52 | ) 53 | 54 | if profile: 55 | torch.backends.cudnn.benchmark = True 56 | profiler = torch.profiler.profile( 57 | activities=[ 58 | torch.profiler.ProfilerActivity.CPU, 59 | torch.profiler.ProfilerActivity.CUDA, 60 | ], 61 | schedule=torch.profiler.schedule( 62 | wait=5, 63 | warmup=5, 64 | active=5, 65 | ), 66 | record_shapes=True, 67 | profile_memory=True, 68 | with_flops=True, 69 | with_modules=True, 70 | with_stack=True, 71 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 72 | os.path.join( 73 | 
f"./benchmark/logs/{f.__name__}", f"rank_{dist.get_rank()}" 74 | ) 75 | ), 76 | ) 77 | 78 | if profile: 79 | profiler.start() 80 | 81 | begin = torch.cuda.Event(enable_timing=True) 82 | begin.record() 83 | 84 | if forward_only: 85 | with torch.no_grad(): 86 | for _ in range(num_iter): 87 | _ = f( 88 | q, 89 | kv, 90 | causal=causal, 91 | window_size=(-1, -1), 92 | alibi_slopes=None, 93 | deterministic=deterministic, 94 | return_attn_probs=False, 95 | ) 96 | if profile: 97 | profiler.step() 98 | 99 | else: 100 | for _ in range(num_iter): 101 | q.grad = None 102 | kv.grad = None 103 | out = f( 104 | q, 105 | kv, 106 | causal=causal, 107 | window_size=(-1, -1), 108 | alibi_slopes=None, 109 | deterministic=deterministic, 110 | return_attn_probs=False, 111 | ) 112 | out.backward(dout) 113 | if profile: 114 | profiler.step() 115 | 116 | end = torch.cuda.Event(enable_timing=True) 117 | end.record() 118 | torch.cuda.synchronize(device=device) 119 | time = begin.elapsed_time(end) / 1000.0 120 | 121 | if profile: 122 | profiler.stop() 123 | 124 | if rank == 0 and log: 125 | print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec") 126 | 127 | 128 | if __name__ == "__main__": 129 | dist.init_process_group("nccl") 130 | rank = dist.get_rank() 131 | 132 | forward_only = False 133 | profile = False 134 | num_iter = 500 if forward_only else 100 135 | 136 | for f in [ 137 | flash_attn_kvpacked_func, 138 | ring_flash_attn_kvpacked_func, 139 | zigzag_ring_flash_attn_kvpacked_func, 140 | stripe_flash_attn_kvpacked_func, 141 | ]: 142 | torch.cuda.empty_cache() 143 | if rank == 0: 144 | print(f"# {f.__name__}") 145 | benchmark(f, forward_only=forward_only, num_iter=num_iter, log=False) 146 | benchmark( 147 | f, forward_only=forward_only, num_iter=num_iter, log=True, profile=profile 148 | ) 149 | -------------------------------------------------------------------------------- /train/ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py: 
-------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_varlen_kvpacked_func 2 | import os 3 | import torch 4 | import torch.distributed as dist 5 | from ring_flash_attn import ( 6 | ring_flash_attn_varlen_kvpacked_func, 7 | zigzag_ring_flash_attn_varlen_kvpacked_func, 8 | llama3_flash_attn_varlen_kvpacked_func, 9 | llama3_flash_attn_prepare_cu_seqlens, 10 | ) 11 | 12 | 13 | def benchmark( 14 | f, 15 | use_double_cu_seqlens, 16 | use_llama3=False, 17 | num_iter=100, 18 | forward_only=True, 19 | log=True, 20 | profile=False, 21 | ): 22 | dtype = torch.bfloat16 23 | rank = dist.get_rank() 24 | world_size = dist.get_world_size() 25 | device = torch.device(f"cuda:{rank}") 26 | torch.cuda.set_device(device) 27 | 28 | deterministic = False 29 | # config of llama3 8B 30 | seqlen = 1024 * 8 31 | num_heads = 32 32 | num_kv_heads = 8 33 | head_dim = 128 34 | causal = True 35 | 36 | assert seqlen % (2 * world_size) == 0 37 | assert head_dim % 8 == 0 38 | 39 | q = torch.randn( 40 | seqlen, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True 41 | ) 42 | kv = torch.randn( 43 | seqlen, 44 | 2, 45 | num_kv_heads, 46 | head_dim, 47 | device=device, 48 | dtype=dtype, 49 | requires_grad=True, 50 | ) 51 | dout = torch.randn(seqlen, num_heads, head_dim, device=device, dtype=dtype) 52 | 53 | cu_seqlens_list = [ 54 | torch.tensor([0, 8192], device=device, dtype=torch.int32), 55 | torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32), 56 | torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32), 57 | torch.tensor( 58 | [0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32 59 | ), 60 | ] 61 | 62 | if use_llama3: 63 | cu_seqlens_q_list = [] 64 | cu_seqlens_k_list = [] 65 | max_seqlen_q_list = [] 66 | max_seqlen_k_list = [] 67 | local_k_slice_list = [] 68 | for cu_seqlens in cu_seqlens_list: 69 | ( 70 | cu_seqlens_q, 71 | cu_seqlens_k, 72 | max_seqlen_q, 73 | max_seqlen_k, 74 | 
local_k_slice, 75 | ) = llama3_flash_attn_prepare_cu_seqlens( 76 | cu_seqlens * world_size, 77 | causal=causal, 78 | rank=rank, 79 | world_size=world_size, 80 | ) 81 | cu_seqlens_q_list.append(cu_seqlens_q) 82 | cu_seqlens_k_list.append(cu_seqlens_k) 83 | max_seqlen_q_list.append(max_seqlen_q) 84 | max_seqlen_k_list.append(max_seqlen_k) 85 | local_k_slice_list.append(local_k_slice) 86 | else: 87 | max_seqlen_list = [ 88 | (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() 89 | for cu_seqlens in cu_seqlens_list 90 | ] 91 | 92 | if profile: 93 | torch.backends.cudnn.benchmark = True 94 | profiler = torch.profiler.profile( 95 | activities=[ 96 | torch.profiler.ProfilerActivity.CPU, 97 | torch.profiler.ProfilerActivity.CUDA, 98 | ], 99 | schedule=torch.profiler.schedule( 100 | wait=5, 101 | warmup=5, 102 | active=5, 103 | ), 104 | record_shapes=True, 105 | profile_memory=True, 106 | with_flops=True, 107 | with_modules=True, 108 | with_stack=True, 109 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 110 | os.path.join( 111 | f"./benchmark/logs/{f.__name__}", f"rank_{dist.get_rank()}" 112 | ) 113 | ), 114 | ) 115 | 116 | if profile: 117 | profiler.start() 118 | 119 | begin = torch.cuda.Event(enable_timing=True) 120 | begin.record() 121 | 122 | def wrapper(i: int): 123 | if use_llama3: 124 | return f( 125 | q, 126 | kv, 127 | cu_seqlens_q_list[i % len(cu_seqlens_list)], 128 | cu_seqlens_k_list[i % len(cu_seqlens_list)], 129 | max_seqlen_q_list[i % len(cu_seqlens_list)], 130 | max_seqlen_k_list[i % len(cu_seqlens_list)], 131 | heads_k_stride=4, 132 | local_k_slice=local_k_slice_list[i % len(cu_seqlens_list)], 133 | causal=causal, 134 | window_size=(-1, -1), 135 | alibi_slopes=None, 136 | deterministic=deterministic, 137 | return_attn_probs=False, 138 | ) 139 | elif use_double_cu_seqlens: 140 | return f( 141 | q, 142 | kv, 143 | cu_seqlens_list[i % len(cu_seqlens_list)], 144 | cu_seqlens_list[i % len(cu_seqlens_list)], 145 | max_seqlen_list[i % len(cu_seqlens_list)], 
146 | max_seqlen_list[i % len(cu_seqlens_list)], 147 | causal=causal, 148 | window_size=(-1, -1), 149 | alibi_slopes=None, 150 | deterministic=deterministic, 151 | return_attn_probs=False, 152 | ) 153 | else: 154 | return f( 155 | q, 156 | kv, 157 | cu_seqlens_list[i % len(cu_seqlens_list)], 158 | max_seqlen_list[i % len(cu_seqlens_list)], 159 | causal=causal, 160 | window_size=(-1, -1), 161 | alibi_slopes=None, 162 | deterministic=deterministic, 163 | return_attn_probs=False, 164 | ) 165 | 166 | if forward_only: 167 | with torch.no_grad(): 168 | for i in range(num_iter): 169 | _ = wrapper(i) 170 | else: 171 | for i in range(num_iter): 172 | q.grad = None 173 | kv.grad = None 174 | out = wrapper(i) 175 | out.backward(dout) 176 | if profile: 177 | profiler.step() 178 | end = torch.cuda.Event(enable_timing=True) 179 | end.record() 180 | torch.cuda.synchronize(device=device) 181 | time = begin.elapsed_time(end) / 1000.0 182 | 183 | if profile: 184 | profiler.stop() 185 | if rank == 0 and log: 186 | print(f"{num_iter / time} iter/s, {time} sec") 187 | 188 | 189 | if __name__ == "__main__": 190 | dist.init_process_group("nccl") 191 | rank = dist.get_rank() 192 | 193 | forward_only = False 194 | profile = False 195 | num_iter = 500 if forward_only else 100 196 | 197 | for f, use_double_cu_seqlens in [ 198 | (flash_attn_varlen_kvpacked_func, True), 199 | (ring_flash_attn_varlen_kvpacked_func, False), 200 | (zigzag_ring_flash_attn_varlen_kvpacked_func, False), 201 | ]: 202 | torch.cuda.empty_cache() 203 | if rank == 0: 204 | print(f"# {f.__name__}") 205 | benchmark( 206 | f, 207 | use_double_cu_seqlens, 208 | forward_only=forward_only, 209 | num_iter=num_iter, 210 | log=False, 211 | ) 212 | benchmark( 213 | f, 214 | use_double_cu_seqlens, 215 | forward_only=forward_only, 216 | num_iter=num_iter, 217 | log=True, 218 | profile=profile, 219 | ) 220 | 221 | for f, use_double_cu_seqlens in [ 222 | (llama3_flash_attn_varlen_kvpacked_func, True), 223 | ]: 224 | 
torch.cuda.empty_cache() 225 | if rank == 0: 226 | print(f"# {f.__name__}") 227 | benchmark( 228 | f, 229 | use_double_cu_seqlens, 230 | use_llama3=True, 231 | forward_only=forward_only, 232 | num_iter=num_iter, 233 | log=False, 234 | ) 235 | benchmark( 236 | f, 237 | use_double_cu_seqlens, 238 | use_llama3=True, 239 | forward_only=forward_only, 240 | num_iter=num_iter, 241 | log=True, 242 | profile=profile, 243 | ) 244 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama3_flash_attn_varlen import ( 2 | llama3_flash_attn_prepare_cu_seqlens, 3 | llama3_flash_attn_varlen_func, 4 | llama3_flash_attn_varlen_kvpacked_func, 5 | llama3_flash_attn_varlen_qkvpacked_func, 6 | ) 7 | from .ring_flash_attn import ( 8 | ring_flash_attn_func, 9 | ring_flash_attn_kvpacked_func, 10 | ring_flash_attn_qkvpacked_func, 11 | ) 12 | from .ring_flash_attn_varlen import ( 13 | ring_flash_attn_varlen_func, 14 | ring_flash_attn_varlen_kvpacked_func, 15 | ring_flash_attn_varlen_qkvpacked_func, 16 | ) 17 | from .zigzag_ring_flash_attn import ( 18 | zigzag_ring_flash_attn_func, 19 | zigzag_ring_flash_attn_kvpacked_func, 20 | zigzag_ring_flash_attn_qkvpacked_func, 21 | ) 22 | from .zigzag_ring_flash_attn_varlen import ( 23 | zigzag_ring_flash_attn_varlen_func, 24 | zigzag_ring_flash_attn_varlen_kvpacked_func, 25 | zigzag_ring_flash_attn_varlen_qkvpacked_func, 26 | ) 27 | from .stripe_flash_attn import ( 28 | stripe_flash_attn_func, 29 | stripe_flash_attn_kvpacked_func, 30 | stripe_flash_attn_qkvpacked_func, 31 | ) 32 | from .adapters import ( 33 | substitute_hf_flash_attn, 34 | update_ring_flash_attn_params, 35 | ) 36 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/adapters/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .hf_adapter import ( 2 | substitute_hf_flash_attn, 3 | update_ring_flash_attn_params, 4 | ) 5 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/adapters/hf_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import transformers 8 | import transformers.modeling_flash_attention_utils 9 | from transformers.modeling_flash_attention_utils import ( 10 | _flash_supports_window_size, 11 | is_flash_attn_greater_or_equal, 12 | ) 13 | from ..llama3_flash_attn_varlen import ( 14 | llama3_flash_attn_varlen_func, 15 | llama3_flash_attn_prepare_cu_seqlens, 16 | ) 17 | 18 | 19 | DATA_PARAMS = {} 20 | RING_ATTN_SWITCH = True 21 | 22 | 23 | def check_params(f1, f2): 24 | return len(inspect.signature(f1).parameters) == len( 25 | inspect.signature(f2).parameters 26 | ) 27 | 28 | 29 | def update_ring_flash_attn_params( 30 | cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup 31 | ): 32 | world_size = dist.get_world_size(group=process_group) 33 | rank = dist.get_rank(group=process_group) 34 | ( 35 | cu_seqlens_q, 36 | cu_seqlens_k, 37 | max_seqlen_q, 38 | max_seqlen_k, 39 | local_k_slice, 40 | ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size) 41 | DATA_PARAMS.update( 42 | { 43 | "cu_seqlens_q": cu_seqlens_q, 44 | "cu_seqlens_k": cu_seqlens_k, 45 | "max_seqlen_q": max_seqlen_q, 46 | "max_seqlen_k": max_seqlen_k, 47 | "local_k_slice": local_k_slice, 48 | } 49 | ) 50 | 51 | 52 | def use_ring_attn(flag): 53 | global RING_ATTN_SWITCH 54 | RING_ATTN_SWITCH = flag 55 | 56 | 57 | def create_ring_flash_attention_forward( 58 | process_group: dist.ProcessGroup, heads_k_stride: int 59 | ): 60 | def _flash_attention_forward( 61 
| query_states: torch.Tensor, 62 | key_states: torch.Tensor, 63 | value_states: torch.Tensor, 64 | attention_mask: torch.Tensor, 65 | query_length: int, 66 | is_causal: bool, 67 | dropout: float = 0.0, 68 | position_ids: Optional[torch.Tensor] = None, 69 | softmax_scale: Optional[float] = None, 70 | sliding_window: Optional[int] = None, 71 | use_top_left_mask: bool = False, 72 | softcap: Optional[float] = None, 73 | deterministic: Optional[bool] = None, 74 | ): 75 | """ 76 | Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token, 77 | first unpad the input, then compute the attention scores and pad the final attention scores. 78 | 79 | Args: 80 | query_states (`torch.Tensor`): 81 | Input query states to be passed to Flash Attention API 82 | key_states (`torch.Tensor`): 83 | Input key states to be passed to Flash Attention API 84 | value_states (`torch.Tensor`): 85 | Input value states to be passed to Flash Attention API 86 | attention_mask (`torch.Tensor`): 87 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 88 | position of padding tokens and 1 for the position of non-padding tokens. 89 | dropout (`float`): 90 | Attention dropout 91 | softmax_scale (`float`, *optional*): 92 | The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim) 93 | use_top_left_mask (`bool`, defaults to `False`): 94 | flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. 95 | softcap (`float`, *optional*): 96 | Softcap for the attention logits, used e.g. in gemma2. 97 | deterministic (`bool`, *optional*): 98 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. 
99 | """ 100 | if not use_top_left_mask: 101 | causal = is_causal 102 | else: 103 | # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. 104 | causal = is_causal and query_length != 1 105 | 106 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 107 | use_sliding_windows = ( 108 | _flash_supports_window_size 109 | and sliding_window is not None 110 | and key_states.shape[1] > sliding_window 111 | ) 112 | flash_kwargs = ( 113 | {"window_size": (sliding_window, sliding_window)} 114 | if use_sliding_windows 115 | else {} 116 | ) 117 | 118 | if is_flash_attn_greater_or_equal("2.4.1"): 119 | if deterministic is None: 120 | deterministic = ( 121 | os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" 122 | ) 123 | flash_kwargs["deterministic"] = deterministic 124 | assert ( 125 | softcap is None 126 | ), "llama3_flash_attn_varlen_func does not support softcap yet." 127 | # flash_kwargs["softcap"] = softcap 128 | flash_kwargs["group"] = process_group 129 | 130 | # not sure why attention_mask can be not None... 131 | assert causal, "only causal attention is supported for now." 132 | batch_size = query_states.size(0) 133 | assert batch_size == 1, "varlen data should be processed in advance." 
134 | 135 | attn_output = llama3_flash_attn_varlen_func( 136 | query_states.squeeze(dim=0), 137 | key_states.squeeze(dim=0), 138 | value_states.squeeze(dim=0), 139 | cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"], 140 | cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"], 141 | max_seqlen_q=DATA_PARAMS["max_seqlen_q"], 142 | max_seqlen_k=DATA_PARAMS["max_seqlen_k"], 143 | heads_k_stride=heads_k_stride, 144 | local_k_slice=DATA_PARAMS["local_k_slice"], 145 | dropout_p=dropout, 146 | softmax_scale=softmax_scale, 147 | causal=causal, 148 | **flash_kwargs, 149 | ) 150 | 151 | attn_output = attn_output.unsqueeze(dim=0) 152 | 153 | return attn_output 154 | 155 | return _flash_attention_forward 156 | 157 | 158 | def substitute_hf_flash_attn(process_group: dist.ProcessGroup, heads_k_stride: int): 159 | try: 160 | # substitute flash attn 161 | old_flash_attention_forward = ( 162 | transformers.modeling_flash_attention_utils._flash_attention_forward 163 | ) 164 | new_flash_attention_forward = create_ring_flash_attention_forward( 165 | process_group, heads_k_stride 166 | ) 167 | assert check_params(old_flash_attention_forward, new_flash_attention_forward) 168 | transformers.modeling_flash_attention_utils._flash_attention_forward = ( 169 | lambda *args, **kwargs: ( 170 | new_flash_attention_forward(*args, **kwargs) 171 | if RING_ATTN_SWITCH 172 | else old_flash_attention_forward(*args, **kwargs) 173 | ) 174 | ) 175 | except: 176 | raise ValueError( 177 | f"The current transformer version {transformers.__version__} is not supported. " 178 | "please use pip install -U transformers to upgrade to the latest version. 
" 179 | "If the code failed with the latest version, " 180 | "please file an issue to https://github.com/zhuzilin/ring-flash-attention/issues" 181 | ) 182 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 4 | from .utils import RingComm, update_out_and_lse, get_default_args 5 | 6 | 7 | def ring_flash_attn_forward( 8 | process_group, 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | softmax_scale, 13 | dropout_p=0, 14 | causal=True, 15 | window_size=(-1, -1), 16 | alibi_slopes=None, 17 | deterministic=False, 18 | ): 19 | comm = RingComm(process_group) 20 | 21 | out = None 22 | lse = None 23 | 24 | next_k, next_v = None, None 25 | 26 | for step in range(comm.world_size): 27 | if step + 1 != comm.world_size: 28 | next_k: torch.Tensor = comm.send_recv(k) 29 | next_v: torch.Tensor = comm.send_recv(v) 30 | comm.commit() 31 | 32 | if not causal or step <= comm.rank: 33 | params = get_default_args(_flash_attn_forward).copy() 34 | params.update( 35 | { 36 | "q": q, 37 | "k": k, 38 | "v": v, 39 | "dropout_p": dropout_p, 40 | "softmax_scale": softmax_scale, 41 | "causal": causal and step == 0, 42 | "window_size": window_size, 43 | "alibi_slopes": alibi_slopes, 44 | "return_softmax": True and dropout_p > 0, 45 | } 46 | ) 47 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params) 48 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 49 | 50 | if step + 1 != comm.world_size: 51 | comm.wait() 52 | k = next_k 53 | v = next_v 54 | 55 | out = out.to(q.dtype) 56 | lse = lse.squeeze(dim=-1).transpose(1, 2) 57 | return out, lse 58 | 59 | 60 | def ring_flash_attn_backward( 61 | process_group, 62 | dout, 63 | q, 
64 | k, 65 | v, 66 | out, 67 | softmax_lse, 68 | softmax_scale, 69 | dropout_p=0, 70 | causal=True, 71 | window_size=(-1, -1), 72 | alibi_slopes=None, 73 | deterministic=False, 74 | ): 75 | kv_comm = RingComm(process_group) 76 | d_kv_comm = RingComm(process_group) 77 | dq, dk, dv = None, None, None 78 | next_dk, next_dv = None, None 79 | 80 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 81 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 82 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 83 | 84 | next_dk, next_dv = None, None 85 | next_k, next_v = None, None 86 | 87 | for step in range(kv_comm.world_size): 88 | if step + 1 != kv_comm.world_size: 89 | next_k = kv_comm.send_recv(k) 90 | next_v = kv_comm.send_recv(v) 91 | kv_comm.commit() 92 | if step <= kv_comm.rank or not causal: 93 | bwd_causal = causal and step == 0 94 | params = get_default_args(_flash_attn_backward).copy() 95 | params.update( 96 | { 97 | "dout": dout, 98 | "q": q, 99 | "k": k, 100 | "v": v, 101 | "out": out, 102 | "softmax_lse": softmax_lse, 103 | "dq": block_dq_buffer, 104 | "dk": block_dk_buffer, 105 | "dv": block_dv_buffer, 106 | "dropout_p": dropout_p, 107 | "softmax_scale": softmax_scale, 108 | "causal": bwd_causal, 109 | "window_size": window_size, 110 | "alibi_slopes": alibi_slopes, 111 | "deterministic": deterministic, 112 | } 113 | ) 114 | _flash_attn_backward(**params) 115 | 116 | if dq is None: 117 | dq = block_dq_buffer.to(torch.float32) 118 | dk = block_dk_buffer.to(torch.float32) 119 | dv = block_dv_buffer.to(torch.float32) 120 | else: 121 | dq += block_dq_buffer 122 | d_kv_comm.wait() 123 | dk = block_dk_buffer + next_dk 124 | dv = block_dv_buffer + next_dv 125 | elif step != 0: 126 | d_kv_comm.wait() 127 | dk = next_dk 128 | dv = next_dv 129 | 130 | if step + 1 != kv_comm.world_size: 131 | kv_comm.wait() 132 | k = next_k 133 | v = next_v 134 | 135 | next_dk = d_kv_comm.send_recv(dk) 136 | next_dv = 
d_kv_comm.send_recv(dv) 137 | d_kv_comm.commit() 138 | 139 | d_kv_comm.wait() 140 | 141 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 142 | 143 | 144 | class RingFlashAttnFunc(torch.autograd.Function): 145 | @staticmethod 146 | def forward( 147 | ctx, 148 | q, 149 | k, 150 | v, 151 | dropout_p, 152 | softmax_scale, 153 | causal, 154 | window_size, 155 | alibi_slopes, 156 | deterministic, 157 | return_softmax, 158 | group, 159 | ): 160 | if softmax_scale is None: 161 | softmax_scale = q.shape[-1] ** (-0.5) 162 | 163 | assert alibi_slopes is None 164 | k = k.contiguous() 165 | v = v.contiguous() 166 | out, softmax_lse = ring_flash_attn_forward( 167 | group, 168 | q, 169 | k, 170 | v, 171 | softmax_scale=softmax_scale, 172 | dropout_p=dropout_p, 173 | causal=causal, 174 | window_size=window_size, 175 | alibi_slopes=alibi_slopes, 176 | deterministic=False, 177 | ) 178 | # this should be out_padded 179 | ctx.save_for_backward(q, k, v, out, softmax_lse) 180 | ctx.dropout_p = dropout_p 181 | ctx.softmax_scale = softmax_scale 182 | ctx.causal = causal 183 | ctx.window_size = window_size 184 | ctx.alibi_slopes = alibi_slopes 185 | ctx.deterministic = deterministic 186 | ctx.group = group 187 | return out if not return_softmax else (out, softmax_lse, None) 188 | 189 | @staticmethod 190 | def backward(ctx, dout, *args): 191 | q, k, v, out, softmax_lse = ctx.saved_tensors 192 | dq, dk, dv = ring_flash_attn_backward( 193 | ctx.group, 194 | dout, 195 | q, 196 | k, 197 | v, 198 | out, 199 | softmax_lse, 200 | softmax_scale=ctx.softmax_scale, 201 | dropout_p=ctx.dropout_p, 202 | causal=ctx.causal, 203 | window_size=ctx.window_size, 204 | alibi_slopes=ctx.alibi_slopes, 205 | deterministic=ctx.deterministic, 206 | ) 207 | return dq, dk, dv, None, None, None, None, None, None, None, None 208 | 209 | 210 | def ring_flash_attn_qkvpacked_func( 211 | qkv, 212 | dropout_p=0.0, 213 | softmax_scale=None, 214 | causal=False, 215 | window_size=(-1, -1), 216 | 
alibi_slopes=None, 217 | deterministic=False, 218 | return_attn_probs=False, 219 | group=None, 220 | ): 221 | return RingFlashAttnFunc.apply( 222 | qkv[:, :, 0], 223 | qkv[:, :, 1], 224 | qkv[:, :, 2], 225 | dropout_p, 226 | softmax_scale, 227 | causal, 228 | window_size, 229 | alibi_slopes, 230 | deterministic, 231 | return_attn_probs, 232 | group, 233 | ) 234 | 235 | 236 | def ring_flash_attn_kvpacked_func( 237 | q, 238 | kv, 239 | dropout_p=0.0, 240 | softmax_scale=None, 241 | causal=False, 242 | window_size=(-1, -1), 243 | alibi_slopes=None, 244 | deterministic=False, 245 | return_attn_probs=False, 246 | group=None, 247 | ): 248 | return RingFlashAttnFunc.apply( 249 | q, 250 | kv[:, :, 0], 251 | kv[:, :, 1], 252 | dropout_p, 253 | softmax_scale, 254 | causal, 255 | window_size, 256 | alibi_slopes, 257 | deterministic, 258 | return_attn_probs, 259 | group, 260 | ) 261 | 262 | 263 | def ring_flash_attn_func( 264 | q, 265 | k, 266 | v, 267 | dropout_p=0.0, 268 | softmax_scale=None, 269 | causal=False, 270 | window_size=(-1, -1), 271 | alibi_slopes=None, 272 | deterministic=False, 273 | return_attn_probs=False, 274 | group=None, 275 | ): 276 | return RingFlashAttnFunc.apply( 277 | q, 278 | k, 279 | v, 280 | dropout_p, 281 | softmax_scale, 282 | causal, 283 | window_size, 284 | alibi_slopes, 285 | deterministic, 286 | return_attn_probs, 287 | group, 288 | ) 289 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/ring_flash_attn_varlen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn.flash_attn_interface import ( 4 | _flash_attn_varlen_forward, 5 | _flash_attn_varlen_backward, 6 | ) 7 | from .utils import ( 8 | RingComm, 9 | update_out_and_lse, 10 | get_default_args, 11 | ) 12 | 13 | try: 14 | from .triton_utils import ( 15 | flatten_varlen_lse, 16 | unflatten_varlen_lse, 17 
| ) 18 | except: 19 | from .utils import ( 20 | flatten_varlen_lse, 21 | unflatten_varlen_lse, 22 | ) 23 | 24 | 25 | def ring_flash_attn_varlen_forward( 26 | process_group, 27 | q: torch.Tensor, 28 | k: torch.Tensor, 29 | v: torch.Tensor, 30 | cu_seqlens, 31 | max_seqlen, 32 | softmax_scale, 33 | dropout_p=0, 34 | causal=True, 35 | window_size=(-1, -1), 36 | alibi_slopes=None, 37 | deterministic=False, 38 | ): 39 | comm = RingComm(process_group) 40 | 41 | out = None 42 | lse = None 43 | next_k, next_v = None, None 44 | 45 | old_lse = False 46 | for step in range(comm.world_size): 47 | if step + 1 != comm.world_size: 48 | next_k: torch.Tensor = comm.send_recv(k) 49 | next_v: torch.Tensor = comm.send_recv(v) 50 | comm.commit() 51 | if not causal or step <= comm.rank: 52 | params = get_default_args(_flash_attn_varlen_forward).copy() 53 | params.update( 54 | { 55 | "q": q, 56 | "k": k, 57 | "v": v, 58 | "cu_seqlens_q": cu_seqlens, 59 | "cu_seqlens_k": cu_seqlens, 60 | "max_seqlen_q": max_seqlen, 61 | "max_seqlen_k": max_seqlen, 62 | "dropout_p": dropout_p, 63 | "softmax_scale": softmax_scale, 64 | "causal": causal and step == 0, 65 | "window_size": window_size, 66 | "alibi_slopes": alibi_slopes, 67 | "return_softmax": True and dropout_p > 0, 68 | } 69 | ) 70 | 71 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward( 72 | **params 73 | ) 74 | if block_lse.dim() == 3: 75 | old_lse = True 76 | block_lse = flatten_varlen_lse( 77 | block_lse, 78 | cu_seqlens=cu_seqlens, 79 | ) 80 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 81 | 82 | if step + 1 != comm.world_size: 83 | comm.wait() 84 | k = next_k 85 | v = next_v 86 | 87 | out = out.to(q.dtype) 88 | if old_lse: 89 | lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen) 90 | else: 91 | lse = lse.squeeze(dim=-1).transpose(0, 1) 92 | return out, lse 93 | 94 | 95 | def ring_flash_attn_varlen_backward( 96 | process_group, 97 | dout, 98 | q, 99 | k, 100 | v, 101 | out, 102 | softmax_lse, 103 
| cu_seqlens, 104 | max_seqlen, 105 | softmax_scale, 106 | dropout_p=0, 107 | causal=True, 108 | window_size=(-1, -1), 109 | alibi_slopes=None, 110 | deterministic=False, 111 | ): 112 | kv_comm = RingComm(process_group) 113 | d_kv_comm = RingComm(process_group) 114 | dq, dk, dv = None, None, None 115 | next_dk, next_dv = None, None 116 | 117 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 118 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 119 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 120 | 121 | next_dk, next_dv = None, None 122 | next_k, next_v = None, None 123 | for step in range(kv_comm.world_size): 124 | if step + 1 != kv_comm.world_size: 125 | next_k = kv_comm.send_recv(k) 126 | next_v = kv_comm.send_recv(v) 127 | kv_comm.commit() 128 | if step <= kv_comm.rank or not causal: 129 | bwd_causal = causal and step == 0 130 | params = get_default_args(_flash_attn_varlen_backward).copy() 131 | params.update( 132 | { 133 | "dout": dout, 134 | "q": q, 135 | "k": k, 136 | "v": v, 137 | "out": out, 138 | "softmax_lse": softmax_lse, 139 | "dq": block_dq_buffer, 140 | "dk": block_dk_buffer, 141 | "dv": block_dv_buffer, 142 | "cu_seqlens_q": cu_seqlens, 143 | "cu_seqlens_k": cu_seqlens, 144 | "max_seqlen_q": max_seqlen, 145 | "max_seqlen_k": max_seqlen, 146 | "dropout_p": dropout_p, 147 | "softmax_scale": softmax_scale, 148 | "causal": bwd_causal, 149 | "window_size": window_size, 150 | "alibi_slopes": alibi_slopes, 151 | "deterministic": deterministic, 152 | } 153 | ) 154 | _flash_attn_varlen_backward(**params) 155 | 156 | if dq is None: 157 | dq = block_dq_buffer.to(torch.float32) 158 | dk = block_dk_buffer.to(torch.float32) 159 | dv = block_dv_buffer.to(torch.float32) 160 | else: 161 | dq += block_dq_buffer 162 | d_kv_comm.wait() 163 | dk = block_dk_buffer + next_dk 164 | dv = block_dv_buffer + next_dv 165 | elif step != 0: 166 | d_kv_comm.wait() 167 | dk = next_dk 168 | dv = next_dv 169 | 
170 | if step + 1 != kv_comm.world_size: 171 | kv_comm.wait() 172 | k = next_k 173 | v = next_v 174 | 175 | next_dk = d_kv_comm.send_recv(dk) 176 | next_dv = d_kv_comm.send_recv(dv) 177 | d_kv_comm.commit() 178 | 179 | d_kv_comm.wait() 180 | 181 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 182 | 183 | 184 | class RingFlashAttnVarlenFunc(torch.autograd.Function): 185 | @staticmethod 186 | def forward( 187 | ctx, 188 | q, 189 | k, 190 | v, 191 | cu_seqlens, 192 | max_seqlen, 193 | dropout_p, 194 | softmax_scale, 195 | causal, 196 | window_size, 197 | alibi_slopes, 198 | deterministic, 199 | return_softmax, 200 | group, 201 | ): 202 | if softmax_scale is None: 203 | softmax_scale = q.shape[-1] ** (-0.5) 204 | 205 | assert alibi_slopes is None 206 | k = k.contiguous() 207 | v = v.contiguous() 208 | out, softmax_lse = ring_flash_attn_varlen_forward( 209 | group, 210 | q, 211 | k, 212 | v, 213 | cu_seqlens, 214 | max_seqlen, 215 | softmax_scale=softmax_scale, 216 | dropout_p=dropout_p, 217 | causal=causal, 218 | window_size=window_size, 219 | alibi_slopes=alibi_slopes, 220 | deterministic=False, 221 | ) 222 | # this should be out_padded 223 | ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) 224 | ctx.max_seqlen = max_seqlen 225 | ctx.dropout_p = dropout_p 226 | ctx.softmax_scale = softmax_scale 227 | ctx.causal = causal 228 | ctx.window_size = window_size 229 | ctx.alibi_slopes = alibi_slopes 230 | ctx.deterministic = deterministic 231 | ctx.group = group 232 | return out if not return_softmax else (out, softmax_lse, None) 233 | 234 | @staticmethod 235 | def backward(ctx, dout, *args): 236 | q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors 237 | dq, dk, dv = ring_flash_attn_varlen_backward( 238 | ctx.group, 239 | dout, 240 | q, 241 | k, 242 | v, 243 | out, 244 | softmax_lse, 245 | cu_seqlens, 246 | ctx.max_seqlen, 247 | softmax_scale=ctx.softmax_scale, 248 | dropout_p=ctx.dropout_p, 249 | causal=ctx.causal, 250 |
window_size=ctx.window_size, 251 | alibi_slopes=ctx.alibi_slopes, 252 | deterministic=ctx.deterministic, 253 | ) 254 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 255 | 256 | 257 | def ring_flash_attn_varlen_qkvpacked_func( 258 | qkv, 259 | cu_seqlens, 260 | max_seqlen, 261 | dropout_p=0.0, 262 | softmax_scale=None, 263 | causal=False, 264 | window_size=(-1, -1), # -1 means infinite context window 265 | alibi_slopes=None, 266 | deterministic=False, 267 | return_attn_probs=False, 268 | group=None, 269 | ): 270 | return RingFlashAttnVarlenFunc.apply( 271 | qkv[:, 0], 272 | qkv[:, 1], 273 | qkv[:, 2], 274 | cu_seqlens, 275 | max_seqlen, 276 | dropout_p, 277 | softmax_scale, 278 | causal, 279 | window_size, 280 | alibi_slopes, 281 | deterministic, 282 | return_attn_probs, 283 | group, 284 | ) 285 | 286 | 287 | def ring_flash_attn_varlen_kvpacked_func( 288 | q, 289 | kv, 290 | cu_seqlens, 291 | max_seqlen, 292 | dropout_p=0.0, 293 | softmax_scale=None, 294 | causal=False, 295 | window_size=(-1, -1), # -1 means infinite context window 296 | alibi_slopes=None, 297 | deterministic=False, 298 | return_attn_probs=False, 299 | group=None, 300 | ): 301 | return RingFlashAttnVarlenFunc.apply( 302 | q, 303 | kv[:, 0], 304 | kv[:, 1], 305 | cu_seqlens, 306 | max_seqlen, 307 | dropout_p, 308 | softmax_scale, 309 | causal, 310 | window_size, 311 | alibi_slopes, 312 | deterministic, 313 | return_attn_probs, 314 | group, 315 | ) 316 | 317 | 318 | def ring_flash_attn_varlen_func( 319 | q, 320 | k, 321 | v, 322 | cu_seqlens, 323 | max_seqlen, 324 | dropout_p=0.0, 325 | softmax_scale=None, 326 | causal=False, 327 | window_size=(-1, -1), # -1 means infinite context window 328 | alibi_slopes=None, 329 | deterministic=False, 330 | return_attn_probs=False, 331 | group=None, 332 | ): 333 | return RingFlashAttnVarlenFunc.apply( 334 | q, 335 | k, 336 | v, 337 | cu_seqlens, 338 | max_seqlen, 339 | dropout_p, 340 | softmax_scale, 341 | causal, 342 | window_size, 
343 | alibi_slopes, 344 | deterministic, 345 | return_attn_probs, 346 | group, 347 | ) 348 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/triton_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def flatten_kernel( 8 | # pointers to matrices 9 | OUT, 10 | LSE, 11 | CU_SEQLENS, 12 | # strides 13 | stride_out_nheads, 14 | stride_out_seqlen, 15 | stride_lse_batch, 16 | stride_lse_nheads, 17 | stride_lse_seqlen, 18 | # meta-parameters 19 | BLOCK_M: tl.constexpr, 20 | ): 21 | pid_m = tl.program_id(axis=0) 22 | pid_batch = tl.program_id(axis=1) 23 | pid_head = tl.program_id(axis=2) 24 | 25 | start_idx = tl.load(CU_SEQLENS + pid_batch) 26 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 27 | LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads 28 | OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen 29 | 30 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 31 | 32 | LSE = LSE + rm[:, None] * stride_lse_seqlen 33 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 34 | 35 | OUT = OUT + rm[:, None] * stride_out_seqlen 36 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 37 | 38 | 39 | def flatten_varlen_lse(lse, cu_seqlens): 40 | """ 41 | Arguments: 42 | lse: (batch_size, nheads, max_seqlen) 43 | cu_seqlens: (batch_size + 1,) 44 | Return: 45 | flatten_lse: (nheads, total_seqlen) 46 | """ 47 | total_seqlen = cu_seqlens[-1] 48 | batch_size, nheads, max_seqlen = lse.shape 49 | output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) 50 | 51 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) 52 | BLOCK_M = 4 53 | 54 | with torch.cuda.device(lse.device.index): 55 | flatten_kernel[grid]( 56 | output, 57 | lse, 58 | cu_seqlens, 59 | # strides 60 | 
output.stride(0), 61 | output.stride(1), 62 | lse.stride(0), 63 | lse.stride(1), 64 | lse.stride(2), 65 | BLOCK_M, 66 | ) 67 | return output 68 | 69 | 70 | @triton.jit 71 | def unflatten_kernel( 72 | # pointers to matrices 73 | OUT, 74 | LSE, 75 | CU_SEQLENS, 76 | # strides 77 | stride_out_batch, 78 | stride_out_nheads, 79 | stride_out_seqlen, 80 | stride_lse_seqlen, 81 | stride_lse_nheads, 82 | # meta-parameters 83 | BLOCK_M: tl.constexpr, 84 | ): 85 | pid_m = tl.program_id(axis=0) 86 | pid_batch = tl.program_id(axis=1) 87 | pid_head = tl.program_id(axis=2) 88 | 89 | start_idx = tl.load(CU_SEQLENS + pid_batch) 90 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 91 | LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen 92 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads 93 | 94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 95 | 96 | LSE = LSE + rm[:, None] * stride_lse_seqlen 97 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 98 | 99 | OUT = OUT + rm[:, None] * stride_out_seqlen 100 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 101 | 102 | 103 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 104 | """ 105 | Arguments: 106 | lse: (total_seqlen, nheads, 1) 107 | cu_seqlens: (batch_size + 1,) 108 | max_seqlen: int 109 | Return: 110 | unflatten_lse: (batch_size, nheads, max_seqlen) 111 | """ 112 | lse = lse.unsqueeze(dim=-1) 113 | batch_size = len(cu_seqlens) - 1 114 | nheads = lse.shape[1] 115 | output = torch.empty( 116 | (batch_size, nheads, max_seqlen), 117 | dtype=lse.dtype, 118 | device=lse.device, 119 | ) 120 | 121 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) 122 | BLOCK_M = 4 123 | 124 | with torch.cuda.device(lse.device.index): 125 | unflatten_kernel[grid]( 126 | output, 127 | lse, 128 | cu_seqlens, 129 | # strides 130 | output.stride(0), 131 | output.stride(1), 132 | output.stride(2), 133 | lse.stride(0), 134 | lse.stride(1), 135 | BLOCK_M, 
136 | ) 137 | return output 138 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | import inspect 7 | from functools import cache 8 | 9 | 10 | __all__ = ["update_out_and_lse", "RingComm", "get_default_args"] 11 | 12 | 13 | @cache 14 | def get_default_args(func): 15 | spec = inspect.getfullargspec(func) 16 | defaults = spec.defaults if spec.defaults is not None else () 17 | padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults 18 | args = dict(zip(spec.args, padded_defaults)) 19 | if "softcap" in args: 20 | args["softcap"] = 0.0 21 | return args 22 | 23 | 24 | @torch.jit.script 25 | def _update_out_and_lse( 26 | out: torch.Tensor, 27 | lse: torch.Tensor, 28 | block_out: torch.Tensor, 29 | block_lse: torch.Tensor, 30 | ) -> Tuple[torch.Tensor, torch.Tensor]: 31 | 32 | block_out = block_out.to(torch.float32) 33 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 34 | 35 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) 36 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out 37 | # For additional context and discussion, please refer to: 38 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 39 | out = out - F.sigmoid(block_lse - lse) * (out - block_out) 40 | lse = lse - F.logsigmoid(lse - block_lse) 41 | 42 | return out, lse 43 | 44 | 45 | def update_out_and_lse( 46 | out: Optional[torch.Tensor], 47 | lse: Optional[torch.Tensor], 48 | block_out: torch.Tensor, 49 | block_lse: torch.Tensor, 50 | slice_=None, 51 | ) -> Tuple[torch.Tensor, torch.Tensor]: 52 | if out is None: 53 | if slice_ is not None: 54 | raise RuntimeError("first update_out_and_lse should not pass 
slice_ args") 55 | out = block_out.to(torch.float32) 56 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 57 | elif slice_ is not None: 58 | slice_out, slice_lse = out[slice_], lse[slice_] 59 | slice_out, slice_lse = _update_out_and_lse( 60 | slice_out, slice_lse, block_out, block_lse 61 | ) 62 | out[slice_], lse[slice_] = slice_out, slice_lse 63 | else: 64 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse) 65 | return out, lse 66 | 67 | 68 | @torch.jit.script 69 | def flatten_varlen_lse(lse, cu_seqlens): 70 | new_lse = [] 71 | for i in range(len(cu_seqlens) - 1): 72 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 73 | new_lse.append(lse[i, :, : end - start]) 74 | return torch.cat(new_lse, dim=1) 75 | 76 | 77 | @torch.jit.script 78 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 79 | num_seq = len(cu_seqlens) - 1 80 | num_head = lse.shape[-2] 81 | new_lse = torch.empty( 82 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device 83 | ) 84 | for i in range(num_seq): 85 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 86 | new_lse[i, : end - start] = lse[start:end] 87 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() 88 | 89 | 90 | class RingComm: 91 | def __init__(self, process_group: dist.ProcessGroup): 92 | self._process_group = process_group 93 | self._ops = [] 94 | self.rank = dist.get_rank(self._process_group) 95 | self.world_size = dist.get_world_size(self._process_group) 96 | self._reqs = None 97 | 98 | self.send_rank = (self.rank + 1) % self.world_size 99 | self.recv_rank = (self.rank - 1) % self.world_size 100 | 101 | if process_group is not None: 102 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) 103 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) 104 | 105 | def send_recv( 106 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None 107 | ) -> torch.Tensor: 108 | if recv_tensor is None: 109 | res = 
torch.empty_like(to_send) 110 | else: 111 | res = recv_tensor 112 | 113 | send_op = dist.P2POp( 114 | dist.isend, to_send, self.send_rank, group=self._process_group 115 | ) 116 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) 117 | self._ops.append(send_op) 118 | self._ops.append(recv_op) 119 | return res 120 | 121 | def commit(self): 122 | if self._reqs is not None: 123 | raise RuntimeError("commit called twice") 124 | self._reqs = dist.batch_isend_irecv(self._ops) 125 | 126 | def wait(self): 127 | if self._reqs is None: 128 | raise RuntimeError("wait called before commit") 129 | for req in self._reqs: 130 | req.wait() 131 | self._reqs = None 132 | self._ops = [] 133 | 134 | 135 | class AllGatherComm: 136 | def __init__(self, group=None) -> None: 137 | self.group = group 138 | self.handles = [] 139 | 140 | def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): 141 | handle = dist.all_gather_into_tensor( 142 | output_tensor, input_tensor, group=self.group, async_op=True 143 | ) 144 | self.handles.append(handle) 145 | 146 | def wait(self): 147 | for handle in self.handles: 148 | handle.wait() 149 | self.handles = [] 150 | -------------------------------------------------------------------------------- /train/ring-flash-attention/build/lib/ring_flash_attn/zigzag_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 4 | from .utils import RingComm, update_out_and_lse, get_default_args 5 | 6 | 7 | def zigzag_ring_flash_attn_forward( 8 | process_group, 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | softmax_scale, 13 | dropout_p=0, 14 | causal=True, 15 | window_size=(-1, -1), 16 | alibi_slopes=None, 17 | deterministic=False, 18 | ): 19 | assert causal == True, "zigzag ring is meaningless for 
causal=False" 20 | comm = RingComm(process_group) 21 | 22 | block_seq_len = q.shape[1] // 2 23 | q1 = q[:, block_seq_len:] 24 | 25 | out = None 26 | lse = None 27 | next_k, next_v = None, None 28 | 29 | def forward(q, k, v, causal): 30 | params = get_default_args(_flash_attn_forward).copy() 31 | params.update( 32 | { 33 | "q": q, 34 | "k": k, 35 | "v": v, 36 | "dropout_p": dropout_p, 37 | "softmax_scale": softmax_scale, 38 | "causal": causal, 39 | "window_size": window_size, 40 | "alibi_slopes": alibi_slopes, 41 | "return_softmax": True and dropout_p > 0, 42 | } 43 | ) 44 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params) 45 | return block_out, block_lse 46 | 47 | for step in range(comm.world_size): 48 | if step + 1 != comm.world_size: 49 | next_k: torch.Tensor = comm.send_recv(k) 50 | next_v: torch.Tensor = comm.send_recv(v) 51 | comm.commit() 52 | 53 | if step == 0: 54 | block_out, block_lse = forward(q, k, v, causal=True) 55 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 56 | elif step <= comm.rank: 57 | k0 = k[:, :block_seq_len] 58 | v0 = v[:, :block_seq_len] 59 | block_out, block_lse = forward(q, k0, v0, causal=False) 60 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 61 | else: 62 | block_out, block_lse = forward(q1, k, v, causal=False) 63 | out, lse = update_out_and_lse( 64 | out, 65 | lse, 66 | block_out, 67 | block_lse, 68 | slice_=(slice(None), slice(block_seq_len, None)), 69 | ) 70 | 71 | if step + 1 != comm.world_size: 72 | comm.wait() 73 | k = next_k 74 | v = next_v 75 | 76 | out = out.to(q.dtype) 77 | lse = lse.squeeze(dim=-1).transpose(1, 2) 78 | return out, lse 79 | 80 | 81 | def zigzag_ring_flash_attn_backward( 82 | process_group, 83 | dout, 84 | q, 85 | k, 86 | v, 87 | out, 88 | softmax_lse, 89 | softmax_scale, 90 | dropout_p=0, 91 | causal=True, 92 | window_size=(-1, -1), 93 | alibi_slopes=None, 94 | deterministic=False, 95 | ): 96 | assert causal == True, "zigzag ring is meaningless for 
causal=False" 97 | kv_comm = RingComm(process_group) 98 | d_kv_comm = RingComm(process_group) 99 | dq, dk, dv = None, None, None 100 | next_dk, next_dv = None, None 101 | next_k, next_v = None, None 102 | dk_comm_buffer, dv_comm_buffer = None, None 103 | 104 | dout1 = dout.chunk(2, dim=1)[1] 105 | q1 = q.chunk(2, dim=1)[1] 106 | out1 = out.chunk(2, dim=1)[1] 107 | softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() 108 | block_seq_len = q.shape[1] // 2 109 | 110 | # repeatedly allocating buffers may be slow... 111 | dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 112 | dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 113 | dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 114 | 115 | def backward(dout, q, k, v, out, softmax_lse, causal): 116 | seqlen_q = q.shape[1] 117 | seqlen_kv = k.shape[1] 118 | params = get_default_args(_flash_attn_backward).copy() 119 | params.update( 120 | { 121 | "dout": dout, 122 | "q": q, 123 | "k": k, 124 | "v": v, 125 | "out": out, 126 | "softmax_lse": softmax_lse, 127 | "dq": dq_buffer[:, :seqlen_q], 128 | "dk": dk_buffer[:, :seqlen_kv], 129 | "dv": dv_buffer[:, :seqlen_kv], 130 | "dropout_p": dropout_p, 131 | "softmax_scale": softmax_scale, 132 | "causal": causal, 133 | "window_size": window_size, 134 | "alibi_slopes": alibi_slopes, 135 | "deterministic": deterministic, 136 | } 137 | ) 138 | _flash_attn_backward(**params) 139 | 140 | for step in range(kv_comm.world_size): 141 | if step + 1 != kv_comm.world_size: 142 | next_k = kv_comm.send_recv(k) 143 | next_v = kv_comm.send_recv(v) 144 | kv_comm.commit() 145 | 146 | if step == 0: 147 | backward(dout, q, k, v, out, softmax_lse, causal=True) 148 | dq = dq_buffer.to(torch.float32) 149 | dk = dk_buffer.to(torch.float32) 150 | dv = dv_buffer.to(torch.float32) 151 | else: 152 | if step <= kv_comm.rank: 153 | k0 = k[:, :block_seq_len] 154 | v0 = v[:, :block_seq_len] 155 | backward(dout, q, k0, v0, out, softmax_lse, causal=False) 156 | dq
+= dq_buffer 157 | else: 158 | backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) 159 | # always use the first half in dq_buffer. 160 | dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] 161 | 162 | d_kv_comm.wait() 163 | dk_comm_buffer, dv_comm_buffer = dk, dv 164 | dk, dv = next_dk, next_dv 165 | 166 | if step <= kv_comm.rank: 167 | dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] 168 | dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] 169 | else: 170 | dk += dk_buffer 171 | dv += dv_buffer 172 | 173 | if step + 1 != kv_comm.world_size: 174 | kv_comm.wait() 175 | k = next_k 176 | v = next_v 177 | 178 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) 179 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) 180 | d_kv_comm.commit() 181 | 182 | d_kv_comm.wait() 183 | 184 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 185 | 186 | 187 | class ZigZagRingFlashAttnFunc(torch.autograd.Function): 188 | @staticmethod 189 | def forward( 190 | ctx, 191 | q, 192 | k, 193 | v, 194 | dropout_p, 195 | softmax_scale, 196 | causal, 197 | window_size, 198 | alibi_slopes, 199 | deterministic, 200 | return_softmax, 201 | group, 202 | ): 203 | if softmax_scale is None: 204 | softmax_scale = q.shape[-1] ** (-0.5) 205 | 206 | assert alibi_slopes is None 207 | k = k.contiguous() 208 | v = v.contiguous() 209 | out, softmax_lse = zigzag_ring_flash_attn_forward( 210 | group, 211 | q, 212 | k, 213 | v, 214 | softmax_scale=softmax_scale, 215 | dropout_p=dropout_p, 216 | causal=causal, 217 | window_size=window_size, 218 | alibi_slopes=alibi_slopes, 219 | deterministic=False, 220 | ) 221 | # this should be out_padded 222 | ctx.save_for_backward(q, k, v, out, softmax_lse) 223 | ctx.dropout_p = dropout_p 224 | ctx.softmax_scale = softmax_scale 225 | ctx.causal = causal 226 | ctx.window_size = window_size 227 | ctx.alibi_slopes = alibi_slopes 228 | ctx.deterministic = deterministic 229 | ctx.group = group 230 | return out if not return_softmax else (out, 
softmax_lse, None) 231 | 232 | @staticmethod 233 | def backward(ctx, dout, *args): 234 | q, k, v, out, softmax_lse = ctx.saved_tensors 235 | dq, dk, dv = zigzag_ring_flash_attn_backward( 236 | ctx.group, 237 | dout, 238 | q, 239 | k, 240 | v, 241 | out, 242 | softmax_lse, 243 | softmax_scale=ctx.softmax_scale, 244 | dropout_p=ctx.dropout_p, 245 | causal=ctx.causal, 246 | window_size=ctx.window_size, 247 | alibi_slopes=ctx.alibi_slopes, 248 | deterministic=ctx.deterministic, 249 | ) 250 | return dq, dk, dv, None, None, None, None, None, None, None, None 251 | 252 | 253 | def zigzag_ring_flash_attn_qkvpacked_func( 254 | qkv, 255 | dropout_p=0.0, 256 | softmax_scale=None, 257 | causal=False, 258 | window_size=(-1, -1), 259 | alibi_slopes=None, 260 | deterministic=False, 261 | return_attn_probs=False, 262 | group=None, 263 | ): 264 | return ZigZagRingFlashAttnFunc.apply( 265 | qkv[:, :, 0], 266 | qkv[:, :, 1], 267 | qkv[:, :, 2], 268 | dropout_p, 269 | softmax_scale, 270 | causal, 271 | window_size, 272 | alibi_slopes, 273 | deterministic, 274 | return_attn_probs, 275 | group, 276 | ) 277 | 278 | 279 | def zigzag_ring_flash_attn_kvpacked_func( 280 | q, 281 | kv, 282 | dropout_p=0.0, 283 | softmax_scale=None, 284 | causal=False, 285 | window_size=(-1, -1), 286 | alibi_slopes=None, 287 | deterministic=False, 288 | return_attn_probs=False, 289 | group=None, 290 | ): 291 | return ZigZagRingFlashAttnFunc.apply( 292 | q, 293 | kv[:, :, 0], 294 | kv[:, :, 1], 295 | dropout_p, 296 | softmax_scale, 297 | causal, 298 | window_size, 299 | alibi_slopes, 300 | deterministic, 301 | return_attn_probs, 302 | group, 303 | ) 304 | 305 | 306 | def zigzag_ring_flash_attn_func( 307 | q, 308 | k, 309 | v, 310 | dropout_p=0.0, 311 | softmax_scale=None, 312 | causal=False, 313 | window_size=(-1, -1), 314 | alibi_slopes=None, 315 | deterministic=False, 316 | return_attn_probs=False, 317 | group=None, 318 | ): 319 | return ZigZagRingFlashAttnFunc.apply( 320 | q, 321 | k, 322 | v, 323 | 
dropout_p, 324 | softmax_scale, 325 | causal, 326 | window_size, 327 | alibi_slopes, 328 | deterministic, 329 | return_attn_probs, 330 | group, 331 | ) 332 | -------------------------------------------------------------------------------- /train/ring-flash-attention/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ring-flash-attn" 7 | version = "0.1.1" 8 | authors = [ 9 | { name="zhuzilin", email="zhuzilinallen@gmail.com" }, 10 | ] 11 | description = "Ring attention implementation with flash attention." 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | 20 | [project.urls] 21 | Homepage = "https://github.com/zhuzilin/ring-flash-attention" 22 | Issues = "https://github.com/zhuzilin/ring-flash-attention/issues" -------------------------------------------------------------------------------- /train/ring-flash-attention/ring_flash_attn.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.2 2 | Name: ring-flash-attn 3 | Version: 0.1.1 4 | Summary: Ring attention implementation with flash attention. 
Home-page: https://github.com/zhuzilin/ring-flash-attention
Author: zhuzilin
Author-email: zhuzilin
Project-URL: Homepage, https://github.com/zhuzilin/ring-flash-attention
Project-URL: Issues, https://github.com/zhuzilin/ring-flash-attention/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Dynamic: author
Dynamic: home-page

## Ring Flash Attention

This repo implements [RingAttention](https://github.com/lhao499/RingAttention) using [FlashAttention](https://github.com/Dao-AILab/flash-attention). The current implementation supports:

- varlen (packing samples) api, corresponding to `flash_attn_varlen_func`:
  - `ring_flash_attn_varlen_func`: A basic implementation of ring attention.
  - `zigzag_ring_flash_attn_varlen_func`: A more compute-balanced version of ring attention. More details in [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
  - `llama3_flash_attn_varlen_func`: The context parallelism used in the [llama3 tech report](https://arxiv.org/abs/2407.21783), with extra design for varlen and low memory overhead. Although technically not ring attention, this is **recommended** for most varlen use cases, as it offers a less intrusive alternative for training frameworks, with fewer data manipulations and better arithmetic precision.
- batch api, corresponding to `flash_attn_func`:
  - `ring_flash_attn_func`: Basic ring attention.
  - `zigzag_ring_flash_attn_func`: A more compute-balanced version of ring attention, see [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
  - `stripe_flash_attn_func`: Stripe attention version of `ring_flash_attn_func`; the block size is set to 1 to use the flash_attn api, see: https://arxiv.org/abs/2311.09431
- [huggingface model adapter](ring_flash_attn/adapters/hf_adapter.py). Here is an example of using the adapter: [OpenRLHF/OpenRLHF/pull#439](https://github.com/OpenRLHF/OpenRLHF/pull/439/files).

Note that

- Each function includes `*_func`, `*_kvpacked_func`, `*_qkvpacked_func` variants.
- The varlen versions (except the llama3 version) only support passing one `cu_seqlens`.

## Performance Summary

The following table summarizes the performance of the implemented APIs:

| batch api            | GPU     | theoretic flash_attn     | ring_attn     | zigzag_ring     | stripe_attn     |
| -------------------- | ------- | ------------------------ | ------------- | --------------- | --------------- |
| fwd only (iter/sec)  | 8xH800  | 591.5 / 8 = 73.9         | 38.5          | 63.0            | 55.0            |
|                      |         |                          | 52.1%         | **85.2%**       | 74.4%           |
| fwd + bwd (iter/sec) | 8xH800  | 154.7 / 8 = 19.3         | 10.4          | 17.4            | 16.0            |
|                      |         |                          | 53.9%         | **90.2%**       | 82.9%           |
| fwd only (iter/sec)  | 8xA100  | 373.4 / 8 = 46.7         | 24.0          | 38.2            | 32.5            |
|                      |         |                          | 51.4%         | **81.7%**       | 69.6%           |
| fwd + bwd (iter/sec) | 8xA100  | 94.7 / 8 = 11.8          | 6.2           | 10.6            | 9.75            |
|                      |         |                          | 52.5%         | **89.8%**       | 82.6%           |
| **varlen api**       | **GPU** | **theoretic flash_attn** | **ring_attn** | **zigzag_ring** | **llama3_attn** |
| fwd only (iter/sec)  | 8xH800  | 852.4 / 8 = 106.6        | 52.4          | 74.8            | 60.8            |
|                      |         |                          | 49.1%         | **70.2%**       | 57.0%           |
| fwd + bwd (iter/sec) | 8xH800  | 225.4 / 8 = 28.2         | 14.4          | 21.4            | 16.4            |
|                      |         |                          | 51.1%         | **75.9%**       | 58.1%           |
| fwd only (iter/sec)  | 8xA100  | 532.3 / 8 = 66.5         | 33.1          | 47.9            | 34.3            |
|                      |         |                          | 49.8%         | **72.0%**       | 51.6%           |
| fwd + bwd (iter/sec) | 8xA100  | 133.8 / 8 = 16.7         | 8.7           | 13.4            | 9.7             |
|                      |         |                          | 52.1%         | **80.2%**       | 58.0%           |

Note that

- The benchmark code is in [benchmark](benchmark/); its configuration matches the [Meta-Llama-3.1-8B](https://huggingface.co/NousResearch/Meta-Llama-3.1-8B/blob/main/config.json) setting, with a total sequence length of 8k per GPU.
- When running the benchmark with 8 GPUs, the flash attn code performs 1/8 of the computation of ring attention: the flash attn code runs `8*1^2`, while the ring attn code runs `1*8^2`.
- NVLink between GPUs is required for high performance.
- Please remember to adapt the RoPE offset for the different APIs.

### Installation

```bash
pip install ring-flash-attn
```

or use the following command to build from source:

```bash
git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .
```

### TODOs

- [x] Implement `ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_qkvpacked_func` [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2)
- [x] Implement `stripe_flash_attn_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `*_kvpacked_func` and `*_func` variant for all APIs
- [x] ~~Optimize `*_varlen_func`~~ Implement `llama3_flash_attn_varlen_func`
- [x] ~~Add an example to train llama~~ Implement adapter for huggingface model
- [ ] Implement `zigzag_llama3_flash_attn_varlen_func`

### Test

```bash
torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py
```

### Benchmark

```bash
torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py
```

### Known Limitations

There are some arithmetic errors with the current implementation. The probable reason is that flash attention returns a bf16 value for each block, so we cannot accumulate the values with the original fp32 ones.

Also, because we need to keep an extra fp32 buffer during computation, memory usage is higher than the theoretical limit.

Also,

- dropout is not supported at the moment, because it's hard to save all the rng_states.
- window_size is not supported, because it will be really tricky to implement a varlen version with window_size.
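The accumulation mentioned under Known Limitations is the step that folds each block's output and log-sum-exp (LSE) into a running result. The following is a minimal sketch of that merge in plain Python, assuming a single query and one scalar value per key so that no tensor library is needed. The names `attend`, `log_sigmoid`, and `merge` are illustrative and not part of the package; `merge` follows the sigmoid/logsigmoid formula used by `_update_out_and_lse` in `utils.py`.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query over one block of keys.

    Returns (out, lse), where lse = log(sum(exp(q . k))) over the block.
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def merge(out, lse, block_out, block_lse):
    """Fold one block's (out, lse) into the running (out, lse)."""
    weight = math.exp(log_sigmoid(block_lse - lse))  # sigmoid(block_lse - lse)
    new_out = out - weight * (out - block_out)
    new_lse = lse - log_sigmoid(lse - block_lse)
    return new_out, new_lse


# Hypothetical toy data: one 2-d query, four keys, one scalar value per key.
q = [0.3, -1.2]
keys = [[1.0, 0.5], [-0.7, 2.0], [0.2, 0.1], [1.5, -0.4]]
values = [1.0, -2.0, 0.5, 3.0]

# Reference: attention over all keys at once.
full_out, full_lse = attend(q, keys, values)

# Ring-style: attend to two key blocks separately, then merge the results.
out, lse = attend(q, keys[:2], values[:2])
block_out, block_lse = attend(q, keys[2:], values[2:])
out, lse = merge(out, lse, block_out, block_lse)

print(abs(out - full_out), abs(lse - full_lse))
```

In double precision the merged result agrees with the single-pass result up to rounding. In the package, each `block_out` arrives from flash attention in the input dtype (for example bf16) and is cast to fp32 before merging, which is consistent with the precision loss and extra fp32 buffer described above.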

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
LICENSE
README.md
pyproject.toml
setup.py
ring_flash_attn/__init__.py
ring_flash_attn/llama3_flash_attn_varlen.py
ring_flash_attn/ring_flash_attn.py
ring_flash_attn/ring_flash_attn_varlen.py
ring_flash_attn/stripe_flash_attn.py
ring_flash_attn/triton_utils.py
ring_flash_attn/utils.py
ring_flash_attn/zigzag_ring_flash_attn.py
ring_flash_attn/zigzag_ring_flash_attn_varlen.py
ring_flash_attn.egg-info/PKG-INFO
ring_flash_attn.egg-info/SOURCES.txt
ring_flash_attn.egg-info/dependency_links.txt
ring_flash_attn.egg-info/top_level.txt
ring_flash_attn/adapters/__init__.py
ring_flash_attn/adapters/hf_adapter.py
test/test_llama3_flash_attn_varlen_func.py
test/test_llama3_prepare_cu_seqlens.py
test/test_ring_flash_attn_func.py
test/test_ring_flash_attn_varlen_func.py
test/test_stripe_flash_attn_func.py
test/test_triton_kernels.py
test/test_zigzag_ring_flash_attn_func.py
test/test_zigzag_ring_flash_attn_varlen_func.py

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/top_level.txt:
--------------------------------------------------------------------------------
ring_flash_attn

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/.___pycache__:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/train/ring-flash-attention/ring_flash_attn/.___pycache__

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/.ipynb_checkpoints/utils-checkpoint.py:
--------------------------------------------------------------------------------
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
import inspect
from functools import cache


__all__ = ["update_out_and_lse", "RingComm", "get_default_args"]


@cache
def get_default_args(func):
    spec = inspect.getfullargspec(func)
    defaults = spec.defaults if spec.defaults is not None else ()
    padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults
    args = dict(zip(spec.args, padded_defaults))
    if "softcap" in args:
        args["softcap"] = 0.0
    return args


@torch.jit.script
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    lse = lse - F.logsigmoid(lse - block_lse)

    return out, lse


def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor,
torch.Tensor]:
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        out = block_out.to(torch.float32)
        lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    elif slice_ is not None:
        slice_out, slice_lse = out[slice_], lse[slice_]
        slice_out, slice_lse = _update_out_and_lse(
            slice_out, slice_lse, block_out, block_lse
        )
        out[slice_], lse[slice_] = slice_out, slice_lse
    else:
        out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
    return out, lse


@torch.jit.script
def flatten_varlen_lse(lse, cu_seqlens):
    new_lse = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        new_lse.append(lse[i, :, : end - start])
    return torch.cat(new_lse, dim=1)


@torch.jit.script
def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
    num_seq = len(cu_seqlens) - 1
    num_head = lse.shape[-2]
    new_lse = torch.empty(
        (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device
    )
    for i in range(num_seq):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        new_lse[i, : end - start] = lse[start:end]
    return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()


class RingComm:
    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size

        if process_group is not None:
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(
        self, to_send: torch.Tensor,
        recv_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if recv_tensor is None:
            res = torch.empty_like(to_send)
        else:
            res = recv_tensor

        send_op = dist.P2POp(
            dist.isend, to_send, self.send_rank, group=self._process_group
        )
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res

    def commit(self):
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []


class AllGatherComm:
    def __init__(self, group=None) -> None:
        self.group = group
        self.handles = []

    def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor):
        handle = dist.all_gather_into_tensor(
            output_tensor, input_tensor, group=self.group, async_op=True
        )
        self.handles.append(handle)

    def wait(self):
        for handle in self.handles:
            handle.wait()
        self.handles = []

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/__init__.py:
--------------------------------------------------------------------------------
from .llama3_flash_attn_varlen import (
    llama3_flash_attn_prepare_cu_seqlens,
    llama3_flash_attn_varlen_func,
    llama3_flash_attn_varlen_kvpacked_func,
    llama3_flash_attn_varlen_qkvpacked_func,
)
from .ring_flash_attn import (
    ring_flash_attn_func,
    ring_flash_attn_kvpacked_func,
    ring_flash_attn_qkvpacked_func,
)
from .ring_flash_attn_varlen import (
    ring_flash_attn_varlen_func,
    ring_flash_attn_varlen_kvpacked_func,
    ring_flash_attn_varlen_qkvpacked_func,
)
from .zigzag_ring_flash_attn import (
    zigzag_ring_flash_attn_func,
    zigzag_ring_flash_attn_kvpacked_func,
    zigzag_ring_flash_attn_qkvpacked_func,
)
from .zigzag_ring_flash_attn_varlen import (
    zigzag_ring_flash_attn_varlen_func,
    zigzag_ring_flash_attn_varlen_kvpacked_func,
    zigzag_ring_flash_attn_varlen_qkvpacked_func,
)
from .stripe_flash_attn import (
    stripe_flash_attn_func,
    stripe_flash_attn_kvpacked_func,
    stripe_flash_attn_qkvpacked_func,
)
from .adapters import (
    substitute_hf_flash_attn,
    update_ring_flash_attn_params,
)

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/adapters/.ipynb_checkpoints/hf_adapter-checkpoint.py:
--------------------------------------------------------------------------------
import os
import inspect
from typing import Optional

import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from transformers.modeling_flash_attention_utils import (
    _flash_supports_window_size,
    is_flash_attn_greater_or_equal,
)
from ..llama3_flash_attn_varlen import (
    llama3_flash_attn_varlen_func,
    llama3_flash_attn_prepare_cu_seqlens,
)


DATA_PARAMS = {}
RING_ATTN_SWITCH = True


def check_params(f1, f2):
    return len(inspect.signature(f1).parameters) == len(
        inspect.signature(f2).parameters
    )


def update_ring_flash_attn_params(
    cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
):
    world_size = dist.get_world_size(group=process_group)
    rank = dist.get_rank(group=process_group)
    (
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        local_k_slice,
    ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
    DATA_PARAMS.update(
        {
            "cu_seqlens_q": cu_seqlens_q,
            "cu_seqlens_k": cu_seqlens_k,
            "max_seqlen_q": max_seqlen_q,
            "max_seqlen_k": max_seqlen_k,
            "local_k_slice": local_k_slice,
        }
    )


def use_ring_attn(flag):
    global RING_ATTN_SWITCH
    RING_ATTN_SWITCH = flag


def create_ring_flash_attention_forward(
    process_group: dist.ProcessGroup, heads_k_stride: int
):
    def _flash_attention_forward(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        sliding_window: Optional[int] = None,
        use_top_left_mask: bool = False,
        softcap: Optional[float] = None,
        deterministic: bool = None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax.
                Default to 1 / sqrt(head_dim)
            use_top_left_mask (`bool`, defaults to `False`):
                flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
            softcap (`float`, *optional*):
                Softcap for the attention logits, used e.g. in gemma2.
            deterministic (`bool`, *optional*):
                Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
        """
        if not use_top_left_mask:
            causal = is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
            causal = is_causal and query_length != 1

        # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        flash_kwargs = (
            {"window_size": (sliding_window, sliding_window)}
            if use_sliding_windows
            else {}
        )

        if is_flash_attn_greater_or_equal("2.4.1"):
            if deterministic is None:
                deterministic = (
                    os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
                )
        flash_kwargs["deterministic"] = deterministic
        assert (
            softcap is None
        ), "llama3_flash_attn_varlen_func does not support softcap yet."
        # flash_kwargs["softcap"] = softcap
        flash_kwargs["group"] = process_group

        # not sure why attention_mask can be not None...
        assert causal, "only causal attention is supported yet."
        batch_size = query_states.size(0)
        assert batch_size == 1, "varlen data should be processed in advance."

        attn_output = llama3_flash_attn_varlen_func(
            query_states.squeeze(dim=0),
            key_states.squeeze(dim=0),
            value_states.squeeze(dim=0),
            cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
            cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
            max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
            max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
            heads_k_stride=heads_k_stride,
            local_k_slice=DATA_PARAMS["local_k_slice"],
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.unsqueeze(dim=0)

        return attn_output

    return _flash_attention_forward


def substitute_hf_flash_attn(process_group: dist.ProcessGroup, heads_k_stride: int):
    try:
        # substitute flash attn
        old_flash_attention_forward = (
            transformers.modeling_flash_attention_utils._flash_attention_forward
        )
        new_flash_attention_forward = create_ring_flash_attention_forward(
            process_group, heads_k_stride
        )
        assert check_params(old_flash_attention_forward, new_flash_attention_forward)
        transformers.modeling_flash_attention_utils._flash_attention_forward = (
            lambda *args, **kwargs: (
                new_flash_attention_forward(*args, **kwargs)
                if RING_ATTN_SWITCH
                else old_flash_attention_forward(*args, **kwargs)
            )
        )
    except Exception as exc:
        # Catch Exception rather than using a bare except, and keep the original cause.
        raise ValueError(
            f"The current transformer version {transformers.__version__} is not supported. "
            "please use pip install -U transformers to upgrade to the latest version. "
            "If the code failed with the latest version, "
            "please file an issue to https://github.com/zhuzilin/ring-flash-attention/issues"
        ) from exc

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/adapters/__init__.py:
--------------------------------------------------------------------------------
from .hf_adapter import (
    substitute_hf_flash_attn,
    update_ring_flash_attn_params,
)

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/adapters/hf_adapter.py:
--------------------------------------------------------------------------------
import os
import inspect
from typing import Optional

import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from transformers.modeling_flash_attention_utils import (
    _flash_supports_window_size,
    is_flash_attn_greater_or_equal,
)
from ..llama3_flash_attn_varlen import (
    llama3_flash_attn_varlen_func,
    llama3_flash_attn_prepare_cu_seqlens,
)


DATA_PARAMS = {}
RING_ATTN_SWITCH = True


def check_params(f1, f2):
    return len(inspect.signature(f1).parameters) == len(
        inspect.signature(f2).parameters
    )


def update_ring_flash_attn_params(
    cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
):
    world_size = dist.get_world_size(group=process_group)
    rank = dist.get_rank(group=process_group)
    (
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        local_k_slice,
    ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
    DATA_PARAMS.update(
        {
            "cu_seqlens_q": cu_seqlens_q,
            "cu_seqlens_k": cu_seqlens_k,
            "max_seqlen_q": max_seqlen_q,
            "max_seqlen_k": max_seqlen_k,
            "local_k_slice": local_k_slice,
        }
    )


def use_ring_attn(flag):
    global RING_ATTN_SWITCH
    RING_ATTN_SWITCH = flag


def create_ring_flash_attention_forward(
    process_group: dist.ProcessGroup, heads_k_stride: int
):
    def _flash_attention_forward(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        sliding_window: Optional[int] = None,
        use_top_left_mask: bool = False,
        softcap: Optional[float] = None,
        deterministic: bool = None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_top_left_mask (`bool`, defaults to `False`):
                flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
            softcap (`float`, *optional*):
                Softcap for the attention logits, used e.g. in gemma2.
            deterministic (`bool`, *optional*):
                Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
        """
        if not use_top_left_mask:
            causal = is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
            causal = is_causal and query_length != 1

        # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        flash_kwargs = (
            {"window_size": (sliding_window, sliding_window)}
            if use_sliding_windows
            else {}
        )

        if is_flash_attn_greater_or_equal("2.4.1"):
            if deterministic is None:
                deterministic = (
                    os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
                )
        flash_kwargs["deterministic"] = deterministic
        assert (
            softcap is None
        ), "llama3_flash_attn_varlen_func does not support softcap yet."
        # flash_kwargs["softcap"] = softcap
        flash_kwargs["group"] = process_group

        # not sure why attention_mask can be not None...
        assert causal, "only causal attention is supported yet."
        batch_size = query_states.size(0)
        assert batch_size == 1, "varlen data should be processed in advance."

        attn_output = llama3_flash_attn_varlen_func(
            query_states.squeeze(dim=0),
            key_states.squeeze(dim=0),
            value_states.squeeze(dim=0),
            cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
            cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
            max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
            max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
            heads_k_stride=heads_k_stride,
            local_k_slice=DATA_PARAMS["local_k_slice"],
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.unsqueeze(dim=0)

        return attn_output

    return _flash_attention_forward


def substitute_hf_flash_attn(process_group: dist.ProcessGroup, heads_k_stride: int):
    try:
        # substitute flash attn
        old_flash_attention_forward = (
            transformers.modeling_flash_attention_utils._flash_attention_forward
        )
        new_flash_attention_forward = create_ring_flash_attention_forward(
            process_group, heads_k_stride
        )
        assert check_params(old_flash_attention_forward, new_flash_attention_forward)
        transformers.modeling_flash_attention_utils._flash_attention_forward = (
            lambda *args, **kwargs: (
                new_flash_attention_forward(*args, **kwargs)
                if RING_ATTN_SWITCH
                else old_flash_attention_forward(*args, **kwargs)
            )
        )
    except Exception as exc:
        # Catch Exception rather than using a bare except, and keep the original cause.
        raise ValueError(
            f"The current transformer version {transformers.__version__} is not supported. "
            "please use pip install -U transformers to upgrade to the latest version. "
            "If the code failed with the latest version, "
            "please file an issue to https://github.com/zhuzilin/ring-flash-attention/issues"
        ) from exc

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/ring_flash_attn.py:
--------------------------------------------------------------------------------
import torch
import torch.distributed as dist
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
from .utils import RingComm, update_out_and_lse, get_default_args


def ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    comm = RingComm(process_group)

    out = None
    lse = None

    next_k, next_v = None, None

    for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()

        if not causal or step <= comm.rank:
            params = get_default_args(_flash_attn_forward).copy()
            params.update(
                {
                    "q": q,
                    "k": k,
                    "v": v,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": causal and step == 0,
                    "window_size": window_size,
                    "alibi_slopes": alibi_slopes,
                    "return_softmax": True and dropout_p > 0,
                }
            )
            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse


def ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None

    block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    next_dk, next_dv = None, None
    next_k, next_v = None, None

    for step in range(kv_comm.world_size):
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()
        if step <= kv_comm.rank or not causal:
            bwd_causal = causal and step == 0
            params = get_default_args(_flash_attn_backward).copy()
            params.update(
                {
                    "dout": dout,
                    "q": q,
                    "k": k,
                    "v": v,
                    "out": out,
                    "softmax_lse": softmax_lse,
                    "dq": block_dq_buffer,
                    "dk": block_dk_buffer,
                    "dv": block_dv_buffer,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": bwd_causal,
                    "window_size": window_size,
                    "alibi_slopes": alibi_slopes,
                    "deterministic": deterministic,
                }
            )
            _flash_attn_backward(**params)

            if dq is None:
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                dq += block_dq_buffer
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            d_kv_comm.wait()
            dk = next_dk
            dv = next_dv

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk)
        next_dv = d_kv_comm.send_recv(dv)
        d_kv_comm.commit()

    d_kv_comm.wait()

    # Cast dq back to the input dtype (not a hard-coded bf16) so fp16 inputs get fp16 gradients.
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)


class RingFlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=False,
        )
        # this should be out_padded
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        ctx.group = group
        return out if not return_softmax else (out, softmax_lse, None)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
        )
        return dq, dk, dv, None, None, None, None, None, None, None, None


def ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    return RingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )


def ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    return RingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )


def ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
):
    return RingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
    )

--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py:
--------------------------------------------------------------------------------
import torch
import torch.distributed as dist
from flash_attn.flash_attn_interface import (
    _flash_attn_varlen_forward,
    _flash_attn_varlen_backward,
)
from .utils import (
    RingComm,
    update_out_and_lse,
    get_default_args,
)

try:
    from .triton_utils import (
        flatten_varlen_lse,
        unflatten_varlen_lse,
    )
except Exception:
    # Fall back to the pure PyTorch helpers when the Triton versions cannot be imported.
    from .utils import (
        flatten_varlen_lse,
        unflatten_varlen_lse,
    )


def ring_flash_attn_varlen_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens,
    max_seqlen,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    comm = RingComm(process_group)

    out = None
    lse = None
    next_k, next_v = None, None

    old_lse = False
    for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            next_k: torch.Tensor = comm.send_recv(k)
            next_v: torch.Tensor = comm.send_recv(v)
            comm.commit()
        if not causal or step <= comm.rank:
            params = get_default_args(_flash_attn_varlen_forward).copy()
            params.update(
                {
                    "q": q,
                    "k": k,
                    "v": v,
                    "cu_seqlens_q": cu_seqlens,
                    "cu_seqlens_k": cu_seqlens,
                    "max_seqlen_q": max_seqlen,
                    "max_seqlen_k": max_seqlen,
                    "dropout_p": dropout_p,
                    "softmax_scale": softmax_scale,
                    "causal": causal and step == 0,
                    "window_size": window_size,
                    "alibi_slopes": alibi_slopes,
                    "return_softmax": True and dropout_p > 0,
                }
            )

            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
                **params
            )
            if block_lse.dim() == 3:
                old_lse = True
                block_lse = flatten_varlen_lse(
                    block_lse,
                    cu_seqlens=cu_seqlens,
                )
            out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v

    out = out.to(q.dtype)
    if old_lse:
        lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
    else:
        lse = lse.squeeze(dim=-1).transpose(0, 1)
    return out, lse


def ring_flash_attn_varlen_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
cu_seqlens, 104 | max_seqlen, 105 | softmax_scale, 106 | dropout_p=0, 107 | causal=True, 108 | window_size=(-1, -1), 109 | alibi_slopes=None, 110 | deterministic=False, 111 | ): 112 | kv_comm = RingComm(process_group) 113 | d_kv_comm = RingComm(process_group) 114 | dq, dk, dv = None, None, None 115 | next_dk, next_dv = None, None 116 | 117 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 118 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 119 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 120 | 121 | next_dk, next_dv = None, None 122 | next_k, next_v = None, None 123 | for step in range(kv_comm.world_size): 124 | if step + 1 != kv_comm.world_size: 125 | next_k = kv_comm.send_recv(k) 126 | next_v = kv_comm.send_recv(v) 127 | kv_comm.commit() 128 | if step <= kv_comm.rank or not causal: 129 | bwd_causal = causal and step == 0 130 | params = get_default_args(_flash_attn_varlen_backward).copy() 131 | params.update( 132 | { 133 | "dout": dout, 134 | "q": q, 135 | "k": k, 136 | "v": v, 137 | "out": out, 138 | "softmax_lse": softmax_lse, 139 | "dq": block_dq_buffer, 140 | "dk": block_dk_buffer, 141 | "dv": block_dv_buffer, 142 | "cu_seqlens_q": cu_seqlens, 143 | "cu_seqlens_k": cu_seqlens, 144 | "max_seqlen_q": max_seqlen, 145 | "max_seqlen_k": max_seqlen, 146 | "dropout_p": dropout_p, 147 | "softmax_scale": softmax_scale, 148 | "causal": bwd_causal, 149 | "window_size": window_size, 150 | "alibi_slopes": alibi_slopes, 151 | "deterministic": deterministic, 152 | } 153 | ) 154 | _flash_attn_varlen_backward(**params) 155 | 156 | if dq is None: 157 | dq = block_dq_buffer.to(torch.float32) 158 | dk = block_dk_buffer.to(torch.float32) 159 | dv = block_dv_buffer.to(torch.float32) 160 | else: 161 | dq += block_dq_buffer 162 | d_kv_comm.wait() 163 | dk = block_dk_buffer + next_dk 164 | dv = block_dv_buffer + next_dv 165 | elif step != 0: 166 | d_kv_comm.wait() 167 | dk = next_dk 168 | dv = next_dv 169 | 
170 | if step + 1 != kv_comm.world_size: 171 | kv_comm.wait() 172 | k = next_k 173 | v = next_v 174 | 175 | next_dk = d_kv_comm.send_recv(dk) 176 | next_dv = d_kv_comm.send_recv(dv) 177 | d_kv_comm.commit() 178 | 179 | d_kv_comm.wait() 180 | 181 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 182 | 183 | 184 | class RingFlashAttnVarlenFunc(torch.autograd.Function): 185 | @staticmethod 186 | def forward( 187 | ctx, 188 | q, 189 | k, 190 | v, 191 | cu_seqlens, 192 | max_seqlen, 193 | dropout_p, 194 | softmax_scale, 195 | causal, 196 | window_size, 197 | alibi_slopes, 198 | deterministic, 199 | return_softmax, 200 | group, 201 | ): 202 | if softmax_scale is None: 203 | softmax_scale = q.shape[-1] ** (-0.5) 204 | 205 | assert alibi_slopes is None 206 | k = k.contiguous() 207 | v = v.contiguous() 208 | out, softmax_lse = ring_flash_attn_varlen_forward( 209 | group, 210 | q, 211 | k, 212 | v, 213 | cu_seqlens, 214 | max_seqlen, 215 | softmax_scale=softmax_scale, 216 | dropout_p=dropout_p, 217 | causal=causal, 218 | window_size=window_size, 219 | alibi_slopes=alibi_slopes, 220 | deterministic=False, 221 | ) 222 | # this should be out_padded 223 | ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) 224 | ctx.max_seqlen = max_seqlen 225 | ctx.dropout_p = dropout_p 226 | ctx.softmax_scale = softmax_scale 227 | ctx.causal = causal 228 | ctx.window_size = window_size 229 | ctx.alibi_slopes = alibi_slopes 230 | ctx.deterministic = deterministic 231 | ctx.group = group 232 | return out if not return_softmax else (out, softmax_lse, None) 233 | 234 | @staticmethod 235 | def backward(ctx, dout, *args): 236 | q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors 237 | dq, dk, dv = ring_flash_attn_varlen_backward( 238 | ctx.group, 239 | dout, 240 | q, 241 | k, 242 | v, 243 | out, 244 | softmax_lse, 245 | cu_seqlens, 246 | ctx.max_seqlen, 247 | softmax_scale=ctx.softmax_scale, 248 | dropout_p=ctx.dropout_p, 249 | causal=ctx.causal, 250 | 
window_size=ctx.window_size, 251 | alibi_slopes=ctx.alibi_slopes, 252 | deterministic=ctx.deterministic, 253 | ) 254 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None 255 | 256 | 257 | def ring_flash_attn_varlen_qkvpacked_func( 258 | qkv, 259 | cu_seqlens, 260 | max_seqlen, 261 | dropout_p=0.0, 262 | softmax_scale=None, 263 | causal=False, 264 | window_size=(-1, -1), # -1 means infinite context window 265 | alibi_slopes=None, 266 | deterministic=False, 267 | return_attn_probs=False, 268 | group=None, 269 | ): 270 | return RingFlashAttnVarlenFunc.apply( 271 | qkv[:, 0], 272 | qkv[:, 1], 273 | qkv[:, 2], 274 | cu_seqlens, 275 | max_seqlen, 276 | dropout_p, 277 | softmax_scale, 278 | causal, 279 | window_size, 280 | alibi_slopes, 281 | deterministic, 282 | return_attn_probs, 283 | group, 284 | ) 285 | 286 | 287 | def ring_flash_attn_varlen_kvpacked_func( 288 | q, 289 | kv, 290 | cu_seqlens, 291 | max_seqlen, 292 | dropout_p=0.0, 293 | softmax_scale=None, 294 | causal=False, 295 | window_size=(-1, -1), # -1 means infinite context window 296 | alibi_slopes=None, 297 | deterministic=False, 298 | return_attn_probs=False, 299 | group=None, 300 | ): 301 | return RingFlashAttnVarlenFunc.apply( 302 | q, 303 | kv[:, 0], 304 | kv[:, 1], 305 | cu_seqlens, 306 | max_seqlen, 307 | dropout_p, 308 | softmax_scale, 309 | causal, 310 | window_size, 311 | alibi_slopes, 312 | deterministic, 313 | return_attn_probs, 314 | group, 315 | ) 316 | 317 | 318 | def ring_flash_attn_varlen_func( 319 | q, 320 | k, 321 | v, 322 | cu_seqlens, 323 | max_seqlen, 324 | dropout_p=0.0, 325 | softmax_scale=None, 326 | causal=False, 327 | window_size=(-1, -1), # -1 means infinite context window 328 | alibi_slopes=None, 329 | deterministic=False, 330 | return_attn_probs=False, 331 | group=None, 332 | ): 333 | return RingFlashAttnVarlenFunc.apply( 334 | q, 335 | k, 336 | v, 337 | cu_seqlens, 338 | max_seqlen, 339 | dropout_p, 340 | softmax_scale, 341 | causal, 342 | window_size, 
343 | alibi_slopes, 344 | deterministic, 345 | return_attn_probs, 346 | group, 347 | ) 348 | -------------------------------------------------------------------------------- /train/ring-flash-attention/ring_flash_attn/triton_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def flatten_kernel( 8 | # pointers to matrices 9 | OUT, 10 | LSE, 11 | CU_SEQLENS, 12 | # strides 13 | stride_out_nheads, 14 | stride_out_seqlen, 15 | stride_lse_batch, 16 | stride_lse_nheads, 17 | stride_lse_seqlen, 18 | # meta-parameters 19 | BLOCK_M: tl.constexpr, 20 | ): 21 | pid_m = tl.program_id(axis=0) 22 | pid_batch = tl.program_id(axis=1) 23 | pid_head = tl.program_id(axis=2) 24 | 25 | start_idx = tl.load(CU_SEQLENS + pid_batch) 26 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 27 | LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads 28 | OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen 29 | 30 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 31 | 32 | LSE = LSE + rm[:, None] * stride_lse_seqlen 33 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 34 | 35 | OUT = OUT + rm[:, None] * stride_out_seqlen 36 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 37 | 38 | 39 | def flatten_varlen_lse(lse, cu_seqlens): 40 | """ 41 | Arguments: 42 | lse: (batch_size, nheads, max_seqlen) 43 | cu_seqlens: (batch_size + 1,) 44 | Return: 45 | flatten_lse: (nheads, total_seqlen) 46 | """ 47 | total_seqlen = cu_seqlens[-1] 48 | batch_size, nheads, max_seqlen = lse.shape 49 | output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) 50 | 51 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) 52 | BLOCK_M = 4 53 | 54 | with torch.cuda.device(lse.device.index): 55 | flatten_kernel[grid]( 56 | output, 57 | lse, 58 | cu_seqlens, 59 | # strides 60 | 
output.stride(0), 61 | output.stride(1), 62 | lse.stride(0), 63 | lse.stride(1), 64 | lse.stride(2), 65 | BLOCK_M, 66 | ) 67 | return output 68 | 69 | 70 | @triton.jit 71 | def unflatten_kernel( 72 | # pointers to matrices 73 | OUT, 74 | LSE, 75 | CU_SEQLENS, 76 | # strides 77 | stride_out_batch, 78 | stride_out_nheads, 79 | stride_out_seqlen, 80 | stride_lse_seqlen, 81 | stride_lse_nheads, 82 | # meta-parameters 83 | BLOCK_M: tl.constexpr, 84 | ): 85 | pid_m = tl.program_id(axis=0) 86 | pid_batch = tl.program_id(axis=1) 87 | pid_head = tl.program_id(axis=2) 88 | 89 | start_idx = tl.load(CU_SEQLENS + pid_batch) 90 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx 91 | LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen 92 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads 93 | 94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 95 | 96 | LSE = LSE + rm[:, None] * stride_lse_seqlen 97 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) 98 | 99 | OUT = OUT + rm[:, None] * stride_out_seqlen 100 | tl.store(OUT, x, mask=rm[:, None] < seqlen) 101 | 102 | 103 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 104 | """ 105 | Arguments: 106 | lse: (total_seqlen, nheads, 1) 107 | cu_seqlens: (batch_size + 1,) 108 | max_seqlen: int 109 | Return: 110 | unflatten_lse: (batch_size, nheads, max_seqlen) 111 | """ 112 | lse = lse.unsqueeze(dim=-1) 113 | batch_size = len(cu_seqlens) - 1 114 | nheads = lse.shape[1] 115 | output = torch.empty( 116 | (batch_size, nheads, max_seqlen), 117 | dtype=lse.dtype, 118 | device=lse.device, 119 | ) 120 | 121 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) 122 | BLOCK_M = 4 123 | 124 | with torch.cuda.device(lse.device.index): 125 | unflatten_kernel[grid]( 126 | output, 127 | lse, 128 | cu_seqlens, 129 | # strides 130 | output.stride(0), 131 | output.stride(1), 132 | output.stride(2), 133 | lse.stride(0), 134 | lse.stride(1), 135 | BLOCK_M, 
136 | ) 137 | return output 138 | -------------------------------------------------------------------------------- /train/ring-flash-attention/ring_flash_attn/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn.functional as F 6 | import inspect 7 | from functools import cache 8 | 9 | 10 | __all__ = ["update_out_and_lse", "RingComm", "get_default_args"] 11 | 12 | 13 | @cache 14 | def get_default_args(func): 15 | spec = inspect.getfullargspec(func) 16 | defaults = spec.defaults if spec.defaults is not None else () 17 | padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults 18 | args = dict(zip(spec.args, padded_defaults)) 19 | if "softcap" in args: 20 | args["softcap"] = 0.0 21 | return args 22 | 23 | 24 | @torch.jit.script 25 | def _update_out_and_lse( 26 | out: torch.Tensor, 27 | lse: torch.Tensor, 28 | block_out: torch.Tensor, 29 | block_lse: torch.Tensor, 30 | ) -> Tuple[torch.Tensor, torch.Tensor]: 31 | 32 | block_out = block_out.to(torch.float32) 33 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 34 | 35 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) 36 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out 37 | # For additional context and discussion, please refer to: 38 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 39 | out = out - F.sigmoid(block_lse - lse) * (out - block_out) 40 | lse = lse - F.logsigmoid(lse - block_lse) 41 | 42 | return out, lse 43 | 44 | 45 | def update_out_and_lse( 46 | out: Optional[torch.Tensor], 47 | lse: Optional[torch.Tensor], 48 | block_out: torch.Tensor, 49 | block_lse: torch.Tensor, 50 | slice_=None, 51 | ) -> Tuple[torch.Tensor, torch.Tensor]: 52 | if out is None: 53 | if slice_ is not None: 54 | raise RuntimeError("first update_out_and_lse should not pass slice_ 
args") 55 | out = block_out.to(torch.float32) 56 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) 57 | elif slice_ is not None: 58 | slice_out, slice_lse = out[slice_], lse[slice_] 59 | slice_out, slice_lse = _update_out_and_lse( 60 | slice_out, slice_lse, block_out, block_lse 61 | ) 62 | out[slice_], lse[slice_] = slice_out, slice_lse 63 | else: 64 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse) 65 | return out, lse 66 | 67 | 68 | @torch.jit.script 69 | def flatten_varlen_lse(lse, cu_seqlens): 70 | new_lse = [] 71 | for i in range(len(cu_seqlens) - 1): 72 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 73 | new_lse.append(lse[i, :, : end - start]) 74 | return torch.cat(new_lse, dim=1) 75 | 76 | 77 | @torch.jit.script 78 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): 79 | num_seq = len(cu_seqlens) - 1 80 | num_head = lse.shape[-2] 81 | new_lse = torch.empty( 82 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device 83 | ) 84 | for i in range(num_seq): 85 | start, end = cu_seqlens[i], cu_seqlens[i + 1] 86 | new_lse[i, : end - start] = lse[start:end] 87 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() 88 | 89 | 90 | class RingComm: 91 | def __init__(self, process_group: dist.ProcessGroup): 92 | self._process_group = process_group 93 | self._ops = [] 94 | self.rank = dist.get_rank(self._process_group) 95 | self.world_size = dist.get_world_size(self._process_group) 96 | self._reqs = None 97 | 98 | self.send_rank = (self.rank + 1) % self.world_size 99 | self.recv_rank = (self.rank - 1) % self.world_size 100 | 101 | if process_group is not None: 102 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) 103 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) 104 | 105 | def send_recv( 106 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None 107 | ) -> torch.Tensor: 108 | if recv_tensor is None: 109 | res = torch.empty_like(to_send) 
110 | else: 111 | res = recv_tensor 112 | 113 | send_op = dist.P2POp( 114 | dist.isend, to_send, self.send_rank, group=self._process_group 115 | ) 116 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) 117 | self._ops.append(send_op) 118 | self._ops.append(recv_op) 119 | return res 120 | 121 | def commit(self): 122 | if self._reqs is not None: 123 | raise RuntimeError("commit called twice") 124 | self._reqs = dist.batch_isend_irecv(self._ops) 125 | 126 | def wait(self): 127 | if self._reqs is None: 128 | raise RuntimeError("wait called before commit") 129 | for req in self._reqs: 130 | req.wait() 131 | self._reqs = None 132 | self._ops = [] 133 | 134 | 135 | class AllGatherComm: 136 | def __init__(self, group=None) -> None: 137 | self.group = group 138 | self.handles = [] 139 | 140 | def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): 141 | handle = dist.all_gather_into_tensor( 142 | output_tensor, input_tensor, group=self.group, async_op=True 143 | ) 144 | self.handles.append(handle) 145 | 146 | def wait(self): 147 | for handle in self.handles: 148 | handle.wait() 149 | self.handles = [] 150 | -------------------------------------------------------------------------------- /train/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward 4 | from .utils import RingComm, update_out_and_lse, get_default_args 5 | 6 | 7 | def zigzag_ring_flash_attn_forward( 8 | process_group, 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | softmax_scale, 13 | dropout_p=0, 14 | causal=True, 15 | window_size=(-1, -1), 16 | alibi_slopes=None, 17 | deterministic=False, 18 | ): 19 | assert causal == True, "zigzag ring is meaningless for causal=False" 20 | comm = 
RingComm(process_group) 21 | 22 | block_seq_len = q.shape[1] // 2 23 | q1 = q[:, block_seq_len:] 24 | 25 | out = None 26 | lse = None 27 | next_k, next_v = None, None 28 | 29 | def forward(q, k, v, causal): 30 | params = get_default_args(_flash_attn_forward).copy() 31 | params.update( 32 | { 33 | "q": q, 34 | "k": k, 35 | "v": v, 36 | "dropout_p": dropout_p, 37 | "softmax_scale": softmax_scale, 38 | "causal": causal, 39 | "window_size": window_size, 40 | "alibi_slopes": alibi_slopes, 41 | "return_softmax": True and dropout_p > 0, 42 | } 43 | ) 44 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params) 45 | return block_out, block_lse 46 | 47 | for step in range(comm.world_size): 48 | if step + 1 != comm.world_size: 49 | next_k: torch.Tensor = comm.send_recv(k) 50 | next_v: torch.Tensor = comm.send_recv(v) 51 | comm.commit() 52 | 53 | if step == 0: 54 | block_out, block_lse = forward(q, k, v, causal=True) 55 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 56 | elif step <= comm.rank: 57 | k0 = k[:, :block_seq_len] 58 | v0 = v[:, :block_seq_len] 59 | block_out, block_lse = forward(q, k0, v0, causal=False) 60 | out, lse = update_out_and_lse(out, lse, block_out, block_lse) 61 | else: 62 | block_out, block_lse = forward(q1, k, v, causal=False) 63 | out, lse = update_out_and_lse( 64 | out, 65 | lse, 66 | block_out, 67 | block_lse, 68 | slice_=(slice(None), slice(block_seq_len, None)), 69 | ) 70 | 71 | if step + 1 != comm.world_size: 72 | comm.wait() 73 | k = next_k 74 | v = next_v 75 | 76 | out = out.to(q.dtype) 77 | lse = lse.squeeze(dim=-1).transpose(1, 2) 78 | return out, lse 79 | 80 | 81 | def zigzag_ring_flash_attn_backward( 82 | process_group, 83 | dout, 84 | q, 85 | k, 86 | v, 87 | out, 88 | softmax_lse, 89 | softmax_scale, 90 | dropout_p=0, 91 | causal=True, 92 | window_size=(-1, -1), 93 | alibi_slopes=None, 94 | deterministic=False, 95 | ): 96 | assert causal == True, "zigzag ring is meaningless for causal=False" 97 | kv_comm = 
RingComm(process_group) 98 | d_kv_comm = RingComm(process_group) 99 | dq, dk, dv = None, None, None 100 | next_dk, next_dv = None, None 101 | next_k, next_v = None, None 102 | dk_comm_buffer, dv_comm_buffer = None, None 103 | 104 | dout1 = dout.chunk(2, dim=1)[1] 105 | q1 = q.chunk(2, dim=1)[1] 106 | out1 = out.chunk(2, dim=1)[1] 107 | softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() 108 | block_seq_len = q.shape[1] // 2 109 | 110 | # repeatedly allocating buffers may be slow... 111 | dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) 112 | dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) 113 | dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) 114 | 115 | def backward(dout, q, k, v, out, softmax_lse, causal): 116 | seqlen_q = q.shape[1] 117 | seqlen_kv = k.shape[1] 118 | params = get_default_args(_flash_attn_backward).copy() 119 | params.update( 120 | { 121 | "dout": dout, 122 | "q": q, 123 | "k": k, 124 | "v": v, 125 | "out": out, 126 | "softmax_lse": softmax_lse, 127 | "dq": dq_buffer[:, :seqlen_q], 128 | "dk": dk_buffer[:, :seqlen_kv], 129 | "dv": dv_buffer[:, :seqlen_kv], 130 | "dropout_p": dropout_p, 131 | "softmax_scale": softmax_scale, 132 | "causal": causal, 133 | "window_size": window_size, 134 | "alibi_slopes": alibi_slopes, 135 | "deterministic": deterministic, 136 | } 137 | ) 138 | _flash_attn_backward(**params) 139 | 140 | for step in range(kv_comm.world_size): 141 | if step + 1 != kv_comm.world_size: 142 | next_k = kv_comm.send_recv(k) 143 | next_v = kv_comm.send_recv(v) 144 | kv_comm.commit() 145 | 146 | if step == 0: 147 | backward(dout, q, k, v, out, softmax_lse, causal=True) 148 | dq = dq_buffer.to(torch.float32) 149 | dk = dk_buffer.to(torch.float32) 150 | dv = dv_buffer.to(torch.float32) 151 | else: 152 | if step <= kv_comm.rank: 153 | k0 = k[:, :block_seq_len] 154 | v0 = v[:, :block_seq_len] 155 | backward(dout, q, k0, v0, out, softmax_lse, causal=False) 156 | dq += dq_buffer 157 | else: 158 | 
backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) 159 | # always use the first half in dq_buffer. 160 | dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] 161 | 162 | d_kv_comm.wait() 163 | dk_comm_buffer, dv_comm_buffer = dk, dv 164 | dk, dv = next_dk, next_dv 165 | 166 | if step <= kv_comm.rank: 167 | dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] 168 | dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] 169 | else: 170 | dk += dk_buffer 171 | dv += dv_buffer 172 | 173 | if step + 1 != kv_comm.world_size: 174 | kv_comm.wait() 175 | k = next_k 176 | v = next_v 177 | 178 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) 179 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) 180 | d_kv_comm.commit() 181 | 182 | d_kv_comm.wait() 183 | 184 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) 185 | 186 | 187 | class ZigZagRingFlashAttnFunc(torch.autograd.Function): 188 | @staticmethod 189 | def forward( 190 | ctx, 191 | q, 192 | k, 193 | v, 194 | dropout_p, 195 | softmax_scale, 196 | causal, 197 | window_size, 198 | alibi_slopes, 199 | deterministic, 200 | return_softmax, 201 | group, 202 | ): 203 | if softmax_scale is None: 204 | softmax_scale = q.shape[-1] ** (-0.5) 205 | 206 | assert alibi_slopes is None 207 | k = k.contiguous() 208 | v = v.contiguous() 209 | out, softmax_lse = zigzag_ring_flash_attn_forward( 210 | group, 211 | q, 212 | k, 213 | v, 214 | softmax_scale=softmax_scale, 215 | dropout_p=dropout_p, 216 | causal=causal, 217 | window_size=window_size, 218 | alibi_slopes=alibi_slopes, 219 | deterministic=False, 220 | ) 221 | # this should be out_padded 222 | ctx.save_for_backward(q, k, v, out, softmax_lse) 223 | ctx.dropout_p = dropout_p 224 | ctx.softmax_scale = softmax_scale 225 | ctx.causal = causal 226 | ctx.window_size = window_size 227 | ctx.alibi_slopes = alibi_slopes 228 | ctx.deterministic = deterministic 229 | ctx.group = group 230 | return out if not return_softmax else (out, softmax_lse, None) 231 | 232 | 
@staticmethod 233 | def backward(ctx, dout, *args): 234 | q, k, v, out, softmax_lse = ctx.saved_tensors 235 | dq, dk, dv = zigzag_ring_flash_attn_backward( 236 | ctx.group, 237 | dout, 238 | q, 239 | k, 240 | v, 241 | out, 242 | softmax_lse, 243 | softmax_scale=ctx.softmax_scale, 244 | dropout_p=ctx.dropout_p, 245 | causal=ctx.causal, 246 | window_size=ctx.window_size, 247 | alibi_slopes=ctx.alibi_slopes, 248 | deterministic=ctx.deterministic, 249 | ) 250 | return dq, dk, dv, None, None, None, None, None, None, None, None 251 | 252 | 253 | def zigzag_ring_flash_attn_qkvpacked_func( 254 | qkv, 255 | dropout_p=0.0, 256 | softmax_scale=None, 257 | causal=False, 258 | window_size=(-1, -1), 259 | alibi_slopes=None, 260 | deterministic=False, 261 | return_attn_probs=False, 262 | group=None, 263 | ): 264 | return ZigZagRingFlashAttnFunc.apply( 265 | qkv[:, :, 0], 266 | qkv[:, :, 1], 267 | qkv[:, :, 2], 268 | dropout_p, 269 | softmax_scale, 270 | causal, 271 | window_size, 272 | alibi_slopes, 273 | deterministic, 274 | return_attn_probs, 275 | group, 276 | ) 277 | 278 | 279 | def zigzag_ring_flash_attn_kvpacked_func( 280 | q, 281 | kv, 282 | dropout_p=0.0, 283 | softmax_scale=None, 284 | causal=False, 285 | window_size=(-1, -1), 286 | alibi_slopes=None, 287 | deterministic=False, 288 | return_attn_probs=False, 289 | group=None, 290 | ): 291 | return ZigZagRingFlashAttnFunc.apply( 292 | q, 293 | kv[:, :, 0], 294 | kv[:, :, 1], 295 | dropout_p, 296 | softmax_scale, 297 | causal, 298 | window_size, 299 | alibi_slopes, 300 | deterministic, 301 | return_attn_probs, 302 | group, 303 | ) 304 | 305 | 306 | def zigzag_ring_flash_attn_func( 307 | q, 308 | k, 309 | v, 310 | dropout_p=0.0, 311 | softmax_scale=None, 312 | causal=False, 313 | window_size=(-1, -1), 314 | alibi_slopes=None, 315 | deterministic=False, 316 | return_attn_probs=False, 317 | group=None, 318 | ): 319 | return ZigZagRingFlashAttnFunc.apply( 320 | q, 321 | k, 322 | v, 323 | dropout_p, 324 | softmax_scale, 325 | 
causal, 326 | window_size, 327 | alibi_slopes, 328 | deterministic, 329 | return_attn_probs, 330 | group, 331 | ) 332 | -------------------------------------------------------------------------------- /train/ring-flash-attention/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="ring_flash_attn", 5 | version="0.1", 6 | author="zhuzilin", 7 | url="https://github.com/zhuzilin/ring-flash-attention", 8 | packages=find_packages(), 9 | ) 10 | -------------------------------------------------------------------------------- /train/ring-flash-attention/test/test_llama3_flash_attn_varlen_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_varlen_qkvpacked_func 4 | from ring_flash_attn import ( 5 | llama3_flash_attn_prepare_cu_seqlens, 6 | llama3_flash_attn_varlen_qkvpacked_func, 7 | ) 8 | from utils import log, set_seed 9 | 10 | 11 | if __name__ == "__main__": 12 | dist.init_process_group("nccl") 13 | rank = dist.get_rank() 14 | set_seed(rank) 15 | world_size = dist.get_world_size() 16 | dtype = torch.bfloat16 17 | device = torch.device(f"cuda:{rank}") 18 | 19 | batch_size = 1 20 | nheads = 5 21 | d = 8 22 | dropout_p = 0 23 | causal = True 24 | deterministic = False 25 | 26 | cu_seqlens = [0, 120, 1248, 4232] 27 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) 28 | max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() 29 | total_length = cu_seqlens[-1] 30 | local_length = total_length // world_size 31 | num_seq = len(cu_seqlens) - 1 32 | 33 | assert cu_seqlens_tensor[-1] % world_size == 0 34 | assert d % 8 == 0 35 | 36 | qkv = torch.randn( 37 | total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True 38 | ) 39 | dist.broadcast(qkv, src=0) 40 | 41 | dout = 
torch.randn(total_length, nheads, d, device=device, dtype=dtype) 42 | dist.broadcast(dout, src=0) 43 | 44 | local_qkv = qkv[rank * local_length : (rank + 1) * local_length].detach().clone() 45 | local_dout = dout[rank * local_length : (rank + 1) * local_length].detach().clone() 46 | local_qkv.requires_grad = True 47 | 48 | dist.barrier() 49 | if rank == 0: 50 | print("#" * 30) 51 | print("# forward:") 52 | print("#" * 30) 53 | 54 | out, lse, _ = flash_attn_varlen_qkvpacked_func( 55 | qkv, 56 | cu_seqlens_tensor, 57 | max_seqlen, 58 | dropout_p=dropout_p, 59 | causal=causal, 60 | window_size=(-1, -1), 61 | alibi_slopes=None, 62 | deterministic=deterministic, 63 | return_attn_probs=True, 64 | ) 65 | 66 | local_out = out[rank * local_length : (rank + 1) * local_length] 67 | if lse.dim() == 2: 68 | local_lse = lse[:, rank * local_length : (rank + 1) * local_length] 69 | 70 | ( 71 | local_cu_seqlens_q, 72 | local_cu_seqlens_k, 73 | max_seqlen_q, 74 | max_seqlen_k, 75 | local_k_slice, 76 | ) = llama3_flash_attn_prepare_cu_seqlens( 77 | cu_seqlens_tensor, 78 | causal=causal, 79 | rank=rank, 80 | world_size=world_size, 81 | ) 82 | 83 | llama3_out, llama3_lse, _ = llama3_flash_attn_varlen_qkvpacked_func( 84 | local_qkv, 85 | local_cu_seqlens_q, 86 | local_cu_seqlens_k, 87 | max_seqlen_q, 88 | max_seqlen_k, 89 | heads_k_stride=1, 90 | local_k_slice=local_k_slice, 91 | dropout_p=dropout_p, 92 | causal=causal, 93 | window_size=(-1, -1), 94 | alibi_slopes=None, 95 | deterministic=deterministic, 96 | return_attn_probs=True, 97 | ) 98 | 99 | log("out", out, rank0_only=True) 100 | log("out diff", local_out - llama3_out) 101 | if lse.dim() == 2: 102 | log("lse", lse, rank0_only=True) 103 | log("lse diff", local_lse - llama3_lse) 104 | 105 | dist.barrier() 106 | if rank == 0: 107 | print("#" * 30) 108 | print("# backward:") 109 | print("#" * 30) 110 | 111 | out.backward(dout) 112 | dqkv = qkv.grad 113 | local_dqkv = dqkv[rank * local_length : (rank + 1) * local_length] 114 | 115 | 
llama3_out.backward(local_dout) 116 | llama3_dqkv = local_qkv.grad 117 | 118 | log("local_dq", local_dqkv[:, 0]) 119 | log("dq diff", local_dqkv[:, 0] - llama3_dqkv[:, 0]) 120 | log("dk diff", local_dqkv[:, 1] - llama3_dqkv[:, 1]) 121 | log("dv diff", local_dqkv[:, 2] - llama3_dqkv[:, 2]) 122 | -------------------------------------------------------------------------------- /train/ring-flash-attention/test/test_llama3_prepare_cu_seqlens.py: -------------------------------------------------------------------------------- 1 | from ring_flash_attn import llama3_flash_attn_prepare_cu_seqlens 2 | import torch 3 | 4 | if __name__ == "__main__": 5 | device = torch.device("cuda") 6 | cu_seqlens = [0, 7, 14, 16] 7 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) 8 | 9 | world_size = 8 10 | for rank in range(world_size): 11 | ( 12 | cu_seqlens_q, 13 | cu_seqlens_k, 14 | max_seqlen_q, 15 | max_seqlen_k, 16 | local_k_slice, 17 | ) = llama3_flash_attn_prepare_cu_seqlens( 18 | cu_seqlens_tensor, 19 | causal=True, 20 | rank=rank, 21 | world_size=world_size, 22 | ) 23 | 24 | assert max_seqlen_q == (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max() 25 | assert max_seqlen_k == (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max() 26 | print(f"RANK: {rank}") 27 | print(f" cu_seqlens_q: {cu_seqlens_q}") 28 | print(f" cu_seqlens_k: {cu_seqlens_k}") 29 | print(f" local_k_slice: {local_k_slice}") 30 | -------------------------------------------------------------------------------- /train/ring-flash-attention/test/test_ring_flash_attn_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from flash_attn import flash_attn_qkvpacked_func 4 | from ring_flash_attn import ring_flash_attn_qkvpacked_func 5 | from utils import log, set_seed 6 | 7 | 8 | if __name__ == "__main__": 9 | dist.init_process_group("nccl") 10 | rank = dist.get_rank() 11 | set_seed(rank) 12 | world_size = 
dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{rank}")

    batch_size = 1
    seqlen = 3816
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    assert seqlen % world_size == 0
    assert d % 8 == 0

    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dist.broadcast(qkv, src=0)

    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone()
    local_qkv.requires_grad = True
    local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone()

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)

    out, lse, _ = flash_attn_qkvpacked_func(
        qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    local_out = out.chunk(world_size, dim=1)[rank]
    local_lse = lse.chunk(world_size, dim=-1)[rank]

    fn = ring_flash_attn_qkvpacked_func

    ring_out, ring_lse, _ = fn(
        local_qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    log("out", out, rank0_only=True)
    log("lse", lse, rank0_only=True)
    log("out diff", local_out - ring_out)
    log("lse diff", local_lse - ring_lse)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    out.backward(dout)
    dqkv = qkv.grad
    local_dqkv = dqkv.chunk(world_size, dim=1)[rank]

    ring_out.backward(local_dout)
    ring_dqkv = local_qkv.grad

    # qkv is (batch, seqlen, 3, nheads, d), so q/k/v are selected on dim 2.
    log("local_dqkv", local_dqkv)
    log("dq diff", local_dqkv[:, :, 0] - ring_dqkv[:, :, 0])
    log("dk diff", local_dqkv[:, :, 1] - ring_dqkv[:, :, 1])
    log("dv diff", local_dqkv[:, :, 2] - ring_dqkv[:, :, 2])

--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py:
--------------------------------------------------------------------------------
from flash_attn import flash_attn_varlen_qkvpacked_func
import torch
import torch.distributed as dist
from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func
from utils import log, set_seed


def extract_local(value, cu_seqlens, rank, world_size):
    local_values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone()
        local_values.append(local_value)
    return torch.cat(local_values, dim=0).contiguous()


def extract_lse(lse, cu_seqlens):
    values = []
    if lse.dim() == 2:
        for i in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[i], cu_seqlens[i + 1]
            value = lse[:, start:end]
            values.append(value)
    else:
        assert lse.dim() == 3
        for i in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[i], cu_seqlens[i + 1]
            value = lse[i, :, : end - start]
            values.append(value)
    return values


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{rank}")

    batch_size = 1
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    cu_seqlens = [0, 120, 1248, 4232]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
    total_length = cu_seqlens[-1]
    num_seq = len(cu_seqlens) - 1

    assert torch.all(cu_seqlens_tensor % world_size == 0)
    assert d % 8 == 0

    qkv = torch.randn(
        total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dist.broadcast(qkv, src=0)

    dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    local_cu_seqlens_tensor = cu_seqlens_tensor // world_size
    local_max_seqlen = max_seqlen // world_size

    local_qkv = extract_local(qkv, cu_seqlens, rank, world_size)
    local_qkv.requires_grad = True
    local_dout = extract_local(dout, cu_seqlens, rank, world_size)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)

    out, lse, _ = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens_tensor,
        max_seqlen,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    local_out = extract_local(out, cu_seqlens, rank, world_size)
    lse_list = extract_lse(lse, cu_seqlens)

    ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func(
        local_qkv,
        local_cu_seqlens_tensor,
        local_max_seqlen,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist())

    log("out", out, rank0_only=True)
    log("out diff", local_out - ring_out)

    for lse, ring_lse in zip(lse_list, ring_lse_list):
        local_lse = lse.chunk(world_size, dim=-1)[rank]
        log("lse", lse, rank0_only=True)
        log("lse diff", local_lse - ring_lse)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    out.backward(dout)
    dqkv = qkv.grad
    local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size)

    ring_out.backward(local_dout)
    ring_dqkv = local_qkv.grad

    log("local_dqkv", local_dqkv)
    log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
    log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
    log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])

--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_stripe_flash_attn_func.py:
--------------------------------------------------------------------------------
import torch
import torch.distributed as dist
from flash_attn import flash_attn_qkvpacked_func
from ring_flash_attn import stripe_flash_attn_qkvpacked_func
from utils import log, set_seed


def extract_local(value, rank, world_size, dim=1):
    # Striped layout: rank r keeps tokens r, r + world_size, r + 2 * world_size, ...
    value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(
        dim, dim + 1
    )
    slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))]
    # Index with a tuple; indexing a tensor with a list of slices is deprecated.
    return value[tuple(slicer)].contiguous()


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{rank}")

    batch_size = 1
    seqlen = 3824
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    assert causal
    assert seqlen % (2 * world_size) == 0
    assert d % 8 == 0

    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dist.broadcast(qkv, src=0)

    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    local_qkv = extract_local(qkv, rank, world_size).detach().clone()
    local_qkv.requires_grad = True
    local_dout = extract_local(dout, rank, world_size).detach().clone()

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)

    out, lse, _ = flash_attn_qkvpacked_func(
        qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    local_out = extract_local(out, rank, world_size)
    local_lse = extract_local(lse, rank, world_size, dim=2)

    ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func(
        local_qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    log("out", out, rank0_only=True)
    log("lse", lse, rank0_only=True)
    log("out diff", local_out - ring_out)
    log("lse diff", local_lse - ring_lse)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    out.backward(dout)
    dqkv = qkv.grad

    local_dqkv = extract_local(dqkv, rank, world_size)

    ring_out.backward(local_dout)
    ring_dqkv = local_qkv.grad

    # qkv is (batch, seqlen, 3, nheads, d), so q/k/v are selected on dim 2.
    log("local_dqkv", local_dqkv)
    log("dq diff", local_dqkv[:, :, 0] - ring_dqkv[:, :, 0])
    log("dk diff", local_dqkv[:, :, 1] - ring_dqkv[:, :, 1])
    log("dv diff", local_dqkv[:, :, 2] - ring_dqkv[:, :, 2])

--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_triton_kernels.py:
--------------------------------------------------------------------------------
import torch
from ring_flash_attn.utils import (
    flatten_varlen_lse,
    unflatten_varlen_lse,
)
from ring_flash_attn.triton_utils import (
    flatten_varlen_lse as triton_flatten_varlen_lse,
    unflatten_varlen_lse as triton_unflatten_varlen_lse,
)


if __name__ == "__main__":
    device = torch.device("cuda:0")

    cu_seqlens = [0, 15, 156, 529]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    batch_size = len(cu_seqlens) - 1
    max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
    n_head = 5

    lse = torch.randn(
        (batch_size, n_head, max_seqlen), dtype=torch.float32, device=device
    )
    flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor)
    triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor)
    assert torch.all(flatten_lse == triton_flatten_lse)

    flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1)
    triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1)

    unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen)
    triton_unflatten_lse = triton_unflatten_varlen_lse(
        triton_flatten_lse, cu_seqlens_tensor, max_seqlen
    )

    for i in range(batch_size):
        seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
        # The failure message prints the same slice that is being compared.
        assert torch.all(
            unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen]
        ), f"{unflatten_lse[i, :, :seqlen]} vs {triton_unflatten_lse[i, :, :seqlen]}"

--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py:
--------------------------------------------------------------------------------
from flash_attn import flash_attn_qkvpacked_func
import torch
import torch.distributed as dist
from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func
from utils import log, set_seed


def extract_local(value, rank, world_size, dim=1):
    # Zigzag layout: rank r keeps chunk r and chunk (2 * world_size - 1 - r).
    value_chunks = value.chunk(2 * world_size, dim=dim)
    local_value = torch.cat(
        [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
    )
    return local_value.contiguous()


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{rank}")

    batch_size = 1
    seqlen = 3824
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    assert causal
    assert seqlen % (2 * world_size) == 0
    assert d % 8 == 0

    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dist.broadcast(qkv, src=0)

    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    local_qkv = extract_local(qkv, rank, world_size).detach().clone()
    local_qkv.requires_grad = True
    local_dout = extract_local(dout, rank, world_size).detach().clone()

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)

    out, lse, _ = flash_attn_qkvpacked_func(
        qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    local_out = extract_local(out, rank, world_size)
    local_lse = extract_local(lse, rank, world_size, dim=2)

    ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func(
        local_qkv,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    log("out", out, rank0_only=True)
    log("lse", lse, rank0_only=True)
    log("out diff", local_out - ring_out)
    log("lse diff", local_lse - ring_lse)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    out.backward(dout)
    dqkv = qkv.grad

    local_dqkv = extract_local(dqkv, rank, world_size)

    ring_out.backward(local_dout)
    ring_dqkv = local_qkv.grad

    # qkv is (batch, seqlen, 3, nheads, d), so q/k/v are selected on dim 2.
    log("local_dqkv", local_dqkv)
    log("dq diff", local_dqkv[:, :, 0] - ring_dqkv[:, :, 0])
    log("dk diff", local_dqkv[:, :, 1] - ring_dqkv[:, :, 1])
    log("dv diff", local_dqkv[:, :, 2] - ring_dqkv[:, :, 2])

--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py:
--------------------------------------------------------------------------------
from flash_attn import flash_attn_varlen_qkvpacked_func
import torch
import torch.distributed as dist
from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func
from utils import log, set_seed


def extract_local(value, cu_seqlens, rank, world_size):
    local_values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        local_value = value[start:end].chunk(2 * world_size, dim=0)
        local_values.extend(
            [
                local_value[rank].detach().clone(),
                local_value[2 * world_size - 1 - rank].detach().clone(),
            ]
        )
    return torch.cat(local_values, dim=0).contiguous()


def extract_lse(lse, cu_seqlens):
    values = []
    if lse.dim() == 2:
        for i in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[i], cu_seqlens[i + 1]
            value = lse[:, start:end]
            values.append(value)
    else:
        assert lse.dim() == 3
        for i in range(len(cu_seqlens) - 1):
            start, end = cu_seqlens[i], cu_seqlens[i + 1]
            value = lse[i, :, : end - start]
            values.append(value)
    return values


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    set_seed(rank)
    world_size = dist.get_world_size()
    dtype = torch.bfloat16
    device = torch.device(f"cuda:{rank}")

    batch_size = 1
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    cu_seqlens = [0, 128, 1248, 4240]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
    total_length = cu_seqlens[-1]
    num_seq = len(cu_seqlens) - 1

    assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0)
    assert d % 8 == 0

    qkv = torch.randn(
        total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dist.broadcast(qkv, src=0)

    dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
    dist.broadcast(dout, src=0)

    local_cu_seqlens_tensor = cu_seqlens_tensor // world_size
    local_max_seqlen = max_seqlen // world_size

    local_qkv = extract_local(qkv, cu_seqlens, rank, world_size)
    local_qkv.requires_grad = True
    local_dout = extract_local(dout, cu_seqlens, rank, world_size)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# forward:")
        print("#" * 30)

    out, lse, _ = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens_tensor,
        max_seqlen,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    local_out = extract_local(out, cu_seqlens, rank, world_size)
    lse_list = extract_lse(lse, cu_seqlens)

    ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func(
        local_qkv,
        local_cu_seqlens_tensor,
        local_max_seqlen,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist())

    log("out", out, rank0_only=True)
    log("out diff", local_out - ring_out)

    for i, (lse, ring_lse) in enumerate(zip(lse_list,
    ring_lse_list)):
        local_lse = lse.chunk(2 * world_size, dim=-1)
        local_lse = torch.cat(
            [local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1
        )
        log(f"lse {i}", lse, rank0_only=True)
        log(f"lse diff {i}", local_lse - ring_lse)

    dist.barrier()
    if rank == 0:
        print("#" * 30)
        print("# backward:")
        print("#" * 30)

    out.backward(dout)
    dqkv = qkv.grad
    local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size)

    ring_out.backward(local_dout)
    ring_dqkv = local_qkv.grad

    log("local_dqkv", local_dqkv)
    log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
    log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
    log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])

--------------------------------------------------------------------------------
/train/ring-flash-attention/test/utils.py:
--------------------------------------------------------------------------------
import random

import torch
import torch.distributed as dist


def set_seed(rank, seed=42):
    seed = rank + seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def log(msg, a, rank0_only=False):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if rank0_only:
        if rank == 0:
            print(
                f"{msg}: "
                f"max {a.abs().max().item():.3g}, "
                f"mean {a.abs().mean().item():.3g}",
                flush=True,
            )
        return

    # Print rank by rank, with a barrier after each turn to keep the output ordered.
    for i in range(world_size):
        if i == rank:
            if rank == 0:
                print(f"{msg}:")
            print(
                f"[{rank}] "
                f"max {a.abs().max().item():.3g}, "
                f"mean {a.abs().mean().item():.3g}",
                flush=True,
            )
        dist.barrier()

--------------------------------------------------------------------------------
/train/ring_attn_utils.py:
--------------------------------------------------------------------------------
import os
import torch
import transformers
from typing import Optional
import torch.distributed as dist
import torch.nn.functional as F
from transformers.modeling_flash_attention_utils import (
    _flash_supports_window_size,
    is_flash_attn_greater_or_equal,
)
from ring_flash_attn.llama3_flash_attn_varlen import (
    llama3_flash_attn_varlen_func,
    llama3_flash_attn_prepare_cu_seqlens
)
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.ops.rms_norm import rms_norm

def flash_rms_norm(self, x):
    return rms_norm(x, self.weight, self.variance_epsilon)

RING_ATTN_GROUP = None


def set_ring_attn_group(group):
    global RING_ATTN_GROUP
    RING_ATTN_GROUP = group


def get_ring_attn_group():
    return RING_ATTN_GROUP


def reset_ring_attn_position_ids(start, end, packed_seq_lens):
    """
    Calculate position ids for packed_seq_ids[start:end].
    For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8,
    the position ids will be [2, 0, 1, 0, 1, 2].

    Args:
        start: the start position
        end: the end position
        packed_seq_lens: the sequence lengths of packed sequences
    """
    position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device())
    offset = 0
    for seqlen in packed_seq_lens:
        seq_start = max(offset, start)
        seq_end = min(offset + seqlen, end)
        if seq_start < seq_end:
            position_ids[0, seq_start - start: seq_end - start] = torch.arange(seq_start - offset, seq_end - offset)

        offset += seqlen
        if offset >= end:
            break
    return position_ids


def update_ring_attn_params(packed_seq_lens, total_seq_len):
    """
    Calculate the cu_seqlens for the current forward pass and pass the value to
    the substituted ring_flash_attn.

    Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding.
    """
    assert RING_ATTN_GROUP is not None
    cu_seqlens = torch.cumsum(
        packed_seq_lens.clone().detach(),  # torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
        dim=-1,
        dtype=torch.int32,
    )
    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
    update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP)



def convert_to_local_input(input_ids, labels, attention_mask, packed_seq_lens, category_ids):
    ring_attn_rank = dist.get_rank(group=RING_ATTN_GROUP)
    ring_attn_size = dist.get_world_size(group=RING_ATTN_GROUP)
    total_seq_len = input_ids.numel()
    local_seq_len = total_seq_len // ring_attn_size
    start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len
    local_input_ids = input_ids[:, start:end]
    local_labels_ids = labels[:, start:end]
    local_attention_mask = attention_mask[:, start:end]
    local_category_ids = category_ids[:, start:end]
    local_position_ids = reset_ring_attn_position_ids(start, end,
    packed_seq_lens)
    update_ring_attn_params(packed_seq_lens, total_seq_len)
    return local_input_ids, local_labels_ids, local_attention_mask, local_position_ids, local_category_ids


DATA_PARAMS = {}


def update_ring_flash_attn_params(
    cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
):
    world_size = dist.get_world_size(group=process_group)
    rank = dist.get_rank(group=process_group)
    (
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        local_k_slice,
    ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
    DATA_PARAMS.update(
        {
            "cu_seqlens_q": cu_seqlens_q,
            "cu_seqlens_k": cu_seqlens_k,
            "max_seqlen_q": max_seqlen_q,
            "max_seqlen_k": max_seqlen_k,
            "local_k_slice": local_k_slice,
        }
    )


def create_ring_flash_attention_forward(
    process_group: dist.ProcessGroup, heads_k_stride: int
):
    def _flash_attention_forward(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        sliding_window: Optional[int] = None,
        use_top_left_mask: bool = False,
        softcap: Optional[float] = None,
        deterministic: bool = None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
            use_top_left_mask (`bool`, defaults to `False`):
                flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference.
            softcap (`float`, *optional*):
                Softcap for the attention logits, used e.g. in gemma2.
            deterministic (`bool`, *optional*):
                Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
        """
        if not use_top_left_mask:
            causal = is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
            causal = is_causal and query_length != 1

        # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        flash_kwargs = (
            {"window_size": (sliding_window, sliding_window)}
            if use_sliding_windows
            else {}
        )

        if is_flash_attn_greater_or_equal("2.4.1"):
            if deterministic is None:
                deterministic = (
                    os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
                )
            flash_kwargs["deterministic"] = deterministic
        assert (
            softcap is None
        ), "llama3_flash_attn_varlen_func does not support softcap yet."
        # flash_kwargs["softcap"] = softcap
        flash_kwargs["group"] = process_group

        # not sure why attention_mask can be not None...
        assert causal, "only causal attention is supported for now."
        batch_size = query_states.size(0)
        assert batch_size == 1, "varlen data should be processed in advance."

        attn_output = llama3_flash_attn_varlen_func(
            query_states.squeeze(dim=0),
            key_states.squeeze(dim=0),
            value_states.squeeze(dim=0),
            cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
            cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
            max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
            max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
            heads_k_stride=heads_k_stride,
            local_k_slice=DATA_PARAMS["local_k_slice"],
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.unsqueeze(dim=0)

        return attn_output

    return _flash_attention_forward


class Config(object):
    def __init__(self, ring_attn_size, ring_head_stride):
        self.ring_attn_size = ring_attn_size
        self.ring_head_stride = ring_head_stride
        self.ring_attn_rank = None


    def setup_ring_attn(self):
        # Every rank must call new_group for every group, including groups it does not belong to.
        for i in range(dist.get_world_size() // self.ring_attn_size):
            ring_attn_ranks = list(
                range(
                    i * self.ring_attn_size,
                    (i + 1) * self.ring_attn_size,
                )
            )
            group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
            if dist.get_rank() in ring_attn_ranks:
                set_ring_attn_group(group)
                self.ring_attn_rank = dist.get_rank(group=group)

        # Monkey-patch the Qwen2 and Llama modeling modules to use ring attention and flash-attn kernels.
        transformers.models.qwen2.modeling_qwen2._flash_attention_forward = create_ring_flash_attention_forward(
            RING_ATTN_GROUP,
            heads_k_stride=self.ring_head_stride
        )
        transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm.forward = flash_rms_norm
        transformers.models.qwen2.modeling_qwen2.CrossEntropyLoss = CrossEntropyLoss

        transformers.models.llama.modeling_llama._flash_attention_forward = create_ring_flash_attention_forward(
            RING_ATTN_GROUP,
            heads_k_stride=self.ring_head_stride
        )
        transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = flash_rms_norm
        transformers.models.llama.modeling_llama.CrossEntropyLoss = CrossEntropyLoss
        return RING_ATTN_GROUP


--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
import os
import math
import time
import json
import torch
import argparse
import pandas as pd
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator
from torch import distributed as dist
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.utils.distributed import all_gather, all_reduce
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate.utils import (
    InitProcessGroupKwargs,
    set_seed,
    DummyOptim,
    DummyScheduler,
)

os.environ["WANDB_MODE"] = "offline"

from data import SFTData
from ring_attn_utils import Config, convert_to_local_input



def main(args):
    set_seed(args.seed)
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
        kwargs_handlers=[timeout],
        log_with="wandb"
    )
    accelerator.init_trackers(project_name=args.project_name)
    world_size = accelerator.num_processes
    sp_world_size = args.sequence_parallel_degree
    dp_world_size = world_size // sp_world_size

    accelerator.print(f"world_size: {world_size}")
    accelerator.print(f"sp_world_size: {sp_world_size}")
    accelerator.print(f"dp_world_size: {dp_world_size}")

    config = Config(ring_attn_size=sp_world_size, ring_head_stride=4)
    ring_group = config.setup_ring_attn()

    with open(args.data_config_path, 'r') as f:
        data_config = json.loads(f.read())
    idx2name = {}
    for file in data_config['ratio']:
        idx2name[file['category_id']] = file['category_name']

    sft_data = SFTData(args.data_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    )

    num_steps_per_epoch = math.ceil(len(sft_data.data) / dp_world_size / args.gradient_accumulation_steps)
    max_train_steps = num_steps_per_epoch * args.num_epochs

    accelerator.print(f"Total samples: {len(sft_data.data)}")
    accelerator.print(f"Training steps: {max_train_steps}")

    optim = DummyOptim(model.parameters())
    scheduler = DummyScheduler(optim)

    local_ce = CrossEntropyLoss(reduction="none")

    model, optim, scheduler = accelerator.prepare(model, optim, scheduler)
    model.gradient_checkpointing_enable()
    accelerator.register_for_checkpointing(scheduler)
    model.train()

    global_step = 0
    for epoch in range(args.num_epochs):
        train_loader = sft_data.get_dataloader(dp_world_size, dist.get_rank() // sp_world_size, seed=args.seed,
                                               epoch=epoch, shuffle=True)
        train_steps = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
        accelerator.print(f"Samples per epoch: {len(train_loader)}")
        accelerator.print(f"Training steps per epoch: {train_steps}")
        progress_bar = tqdm(
            range(train_steps), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(accelerator.device)
            labels = batch['labels'].to(accelerator.device)
            attention_mask = batch['attention_mask'].to(accelerator.device)
            packed_seq_lens = batch['packed_seq_lens'].to(accelerator.device)
            category_ids = batch['category_ids'].to(accelerator.device)

            local_input_ids, local_labels, local_attention_mask, local_position_ids, local_category_ids = convert_to_local_input(
                input_ids, labels, attention_mask, packed_seq_lens, category_ids
            )

            with accelerator.accumulate(model):
                out = model(
                    input_ids=local_input_ids,
                    attention_mask=local_attention_mask,
                    position_ids=local_position_ids,
                )
                shift_logits = out.logits[..., :-1, :].contiguous()
                shift_labels = local_labels[..., 1:].contiguous()
                token_losses = local_ce(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

                if args.use_turn_loss:
                    # Average the loss within each contiguous run of supervised tokens (one run per turn),
                    # then average across runs.
                    all_token_losses = all_gather(token_losses, ring_group)
                    all_shift_labels = all_gather(shift_labels, ring_group)
                    mask = all_shift_labels.view(-1) != -100
                    idx = torch.where(mask)[0]
                    diffs = torch.diff(idx)
                    split_points = torch.where(diffs > 1)[0] + 1
                    groups = torch.tensor_split(idx, split_points.tolist())
                    loss = torch.stack([all_token_losses[g].mean() for g in groups]).mean()
                    loss = torch.nan_to_num(loss, nan=0.0)
                else:
                    loss = token_losses[shift_labels.view(-1) != -100].mean()
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=ring_group)
                    loss = loss / sp_world_size
                    loss = torch.nan_to_num(loss, nan=0.0)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    gathered_loss = accelerator.gather(loss.clone().detach())
                    mask = (gathered_loss != 0)
                    if mask.sum() == 0:
                        loss_ = torch.tensor(0.0, device=accelerator.device)
                    else:
                        loss_ = gathered_loss[mask].mean()
                    loss_log = {
                        "epoch": epoch,
                        "steps": step,
                        "lr": scheduler.get_last_lr()[0],
                        "loss": loss_.item()
                    }

                    progress_bar.set_postfix(loss_log)
                    progress_bar.update(1)
                    time.sleep(0.1)

                    token_losses_ = token_losses.clone().detach()
                    all_token_losses = accelerator.gather(token_losses_)
                    all_category_ids = accelerator.gather(local_category_ids)[..., 1:].contiguous()

                    # Mean token loss per data category, keyed by category id.
                    idx2mean_loss = pd.DataFrame({
                        'category_id': all_category_ids.cpu().ravel(),
                        'token_loss': all_token_losses.cpu().numpy().ravel()
                    }).groupby('category_id')['token_loss'].mean().to_dict()

                    for idx, mean_loss in idx2mean_loss.items():
                        if idx == 0 or idx not in idx2name:
                            continue
                        loss_log[f"{idx2name[idx]}_loss"] = mean_loss

                    loss_log_str = '\n' + json.dumps(loss_log, ensure_ascii=False, indent=4)
                    accelerator.print(loss_log_str)
                    accelerator.log(loss_log, step=global_step)
                    global_step += 1

                optim.step()
                scheduler.step()
                optim.zero_grad()

        # --save_checkpoint defaults to "", so test for a non-empty path rather than for None.
        if args.save_checkpoint:
            os.makedirs(args.save_checkpoint, exist_ok=True)
            save_path = f"{args.save_checkpoint}/epoch{epoch}_end"
            accelerator.print(f"Saving model to {save_path}")
            accelerator.wait_for_everyone()
            state_dict = accelerator.get_state_dict(model)
            accelerator.unwrap_model(model).save_pretrained(
                save_path,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
                state_dict=state_dict,
            )
            if accelerator.is_main_process:
                tokenizer.save_pretrained(save_path)
                os.remove(f"{save_path}/model.safetensors")

            accelerator.print("Saving Finished")

    accelerator.print("Training Finished")
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--project_name", type=str, default='')
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    # None (the default) disables checkpoint saving.
    parser.add_argument("--save_checkpoint", type=str, default=None)
    parser.add_argument("--seed", type=int, default=2025)
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--data_config_path", type=str, default="")
    parser.add_argument("--sequence_parallel_degree", type=int, default=8)
    parser.add_argument("--num_epochs", type=int, default=4)
    parser.add_argument("--use_turn_loss", action='store_true')
    main(parser.parse_args())

--------------------------------------------------------------------------------
/train/train.sh:
--------------------------------------------------------------------------------
set -x
export HOME=/tmp
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:2048'

ip=$(hostname -i)

project_name=Qwen2.5-72B-Instruct
model_path=path_to_model
data_path=../data/dataset/distill_r1_110k.packed.jsonl
data_config_path=../data/config/distill_r1_110k.json
save_checkpoint=path_to_save_dir

python config.py --max_steps=2000 --learning_rate=2e-5 --warmup_num_steps=200

master_ip=192.168.1.1 # IP of the master node
num_machines=24
num_processes=192
machine_rank=0 # 0 to num_machines-1, one rank per machine

mkdir -p logs # the log redirect below fails if logs/ does not exist

nohup accelerate launch \
    --config_file config/multi_node.yaml \
    --num_processes=$num_processes \
    --num_machines=$num_machines \
    --machine_rank=$machine_rank \
    --main_process_ip=$master_ip \
    train.py \
    --project_name $project_name \
    --gradient_accumulation_steps 1 \
    --save_checkpoint $save_checkpoint \
    --seed 2025 \
    --model_path $model_path \
    --data_path $data_path \
    --data_config_path $data_config_path \
    --sequence_parallel_degree 8 \
    --num_epochs 4 > "logs/$ip.log" 2>&1 &

--------------------------------------------------------------------------------