├── README.md
├── data
│   ├── data_process.sh
│   └── script
│       ├── preprocess.example.py
│       ├── pseudo.multiturn.py
│       ├── step1.sample.py
│       ├── step2.tokenize.py
│       └── step3.pack.py
├── docs
│   ├── image-1.png
│   ├── image-2.png
│   ├── image-3.png
│   └── image-4.png
└── train
    ├── config.py
    ├── config
    │   ├── multi_node.yaml
    │   ├── single_node.yaml
    │   └── zero3_offload.json
    ├── data.py
    ├── requirements.txt
    ├── ring-flash-attention
    │   ├── LICENSE
    │   ├── README.md
    │   ├── benchmark
    │   │   ├── benchmark_kvpacked_func.py
    │   │   └── benchmark_varlen_kvpacked_func.py
    │   ├── build
    │   │   └── lib
    │   │       └── ring_flash_attn
    │   │           ├── __init__.py
    │   │           ├── adapters
    │   │           │   ├── __init__.py
    │   │           │   └── hf_adapter.py
    │   │           ├── llama3_flash_attn_varlen.py
    │   │           ├── ring_flash_attn.py
    │   │           ├── ring_flash_attn_varlen.py
    │   │           ├── stripe_flash_attn.py
    │   │           ├── triton_utils.py
    │   │           ├── utils.py
    │   │           ├── zigzag_ring_flash_attn.py
    │   │           └── zigzag_ring_flash_attn_varlen.py
    │   ├── pyproject.toml
    │   ├── ring_flash_attn.egg-info
    │   │   ├── PKG-INFO
    │   │   ├── SOURCES.txt
    │   │   ├── dependency_links.txt
    │   │   └── top_level.txt
    │   ├── ring_flash_attn
    │   │   ├── .___pycache__
    │   │   ├── .ipynb_checkpoints
    │   │   │   ├── llama3_flash_attn_varlen-checkpoint.py
    │   │   │   └── utils-checkpoint.py
    │   │   ├── __init__.py
    │   │   ├── adapters
    │   │   │   ├── .ipynb_checkpoints
    │   │   │   │   └── hf_adapter-checkpoint.py
    │   │   │   ├── __init__.py
    │   │   │   └── hf_adapter.py
    │   │   ├── llama3_flash_attn_varlen.py
    │   │   ├── ring_flash_attn.py
    │   │   ├── ring_flash_attn_varlen.py
    │   │   ├── stripe_flash_attn.py
    │   │   ├── triton_utils.py
    │   │   ├── utils.py
    │   │   ├── zigzag_ring_flash_attn.py
    │   │   └── zigzag_ring_flash_attn_varlen.py
    │   ├── setup.py
    │   └── test
    │       ├── test_llama3_flash_attn_varlen_func.py
    │       ├── test_llama3_prepare_cu_seqlens.py
    │       ├── test_ring_flash_attn_func.py
    │       ├── test_ring_flash_attn_varlen_func.py
    │       ├── test_stripe_flash_attn_func.py
    │       ├── test_triton_kernels.py
    │       ├── test_zigzag_ring_flash_attn_func.py
    │       ├── test_zigzag_ring_flash_attn_varlen_func.py
    │       └── utils.py
    ├── ring_attn_utils.py
    ├── train.py
    └── train.sh
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # OpenSFT
3 |
4 |
5 |
6 | Open-source / Lightweight / Easy-to-use / Large-Scale / Extra-Long-Text
7 |
8 |
9 |
10 |
11 |
12 | OpenSFT is an open-source SFT training framework built on accelerate + DeepSpeed + ring flash attention.
13 |
14 | The project implements a length-pack data organization scheme that further increases parallelism, per-category loss tracking under sequence parallelism for better monitoring of each category, and a turn-loss under sequence parallelism that balances long and short conversations.
15 |
16 | The framework is very lightweight and easy to learn and build on. Stars are welcome.
17 |
18 | ### Updates
19 |
20 | 2025.03.27 Added probabilistic concatenation into pseudo multi-turn conversations; see data/script/pseudo.multiturn.py
21 |
22 |
23 |
24 | ### Installation
25 |
26 | ```bash
27 | git clone https://github.com/mlpod/OpenSFT.git
28 | cd OpenSFT/train
29 | pip install -r requirements.txt
30 | ```
31 |
32 | ### Data organization
33 |
34 | Several ways to organize SFT data:
35 |
36 | 1. Traditional organization: the red parts are where the loss is computed; there is a large amount of padding.
37 |
38 |
39 |
40 | 2. Variable-length, multi-turn-loss organization: the red parts are where the loss is computed; some padding remains.
41 |
42 |
43 |
44 | 3. Pack the sequences within a batch into one sample and compute with sequence parallelism. Padding is still needed (the length must be padded to a multiple of the sequence-parallel degree), and lengths differ greatly between samples.
45 |
46 |
47 | **Data organization implemented in this project**: building on 3, data is organized by a specified maximum length instead of per batch, packing until adding another sample would exceed the maximum length, so there is much less padding and the length difference between samples is negligible. For example, with a 128K maximum length, the distill_r1_110k dataset can be compressed to about 1k samples.
48 |
49 | Online packing adds overhead, so packing is done offline here, decoupling training from data processing.
50 | The offline step also uses greedy packing and multiprocessing, which further reduce the length differences between packed samples and the processing time; see the code for details.
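To make the idea concrete, here is a minimal single-process sketch of the greedy length-pack step (an illustration only, with a hypothetical `greedy_pack` helper; the actual implementation in data/script/step3.pack.py adds a bounded look-ahead via max_attempt and multiprocessing):

```python
# Minimal greedy length-pack sketch (illustration, not the repo's exact code).
# Groups sample indices so that each group's total token count stays within max_length.
def greedy_pack(lengths, max_length):
    unused = set(range(len(lengths)))
    groups = []
    while unused:
        group, total = [], 0
        for idx in sorted(unused):  # iterate over a sorted copy, so removal below is safe
            if total + lengths[idx] <= max_length:
                group.append(idx)
                total += lengths[idx]
                unused.discard(idx)
        groups.append(group)
    return groups

# Example: four samples packed into groups of at most 128 tokens.
print(greedy_pack([100, 30, 90, 20], max_length=128))  # [[0, 3], [1, 2]] -> both groups total 120 tokens
```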
51 |
52 | The data-processing code lives in the data directory. The data/raw directory holds the raw, unpreprocessed files, in the following format:
53 | ```json
54 | {
55 | "messages": [
56 | {
57 | "role": "system",
58 | "content": ""
59 | },
60 | {
61 | "role": "user",
62 | "content": ""
63 | },
64 | {
65 | "role": "assistant",
66 | "content": ""
67 | }
68 | ],
69 | "labels": [0, 0, 1],
70 | "meta": {
71 | "category_id": 1,
72 | "category_name": ""
73 | }
74 | }
75 | ```
76 | Here labels marks which turns in messages contribute to the loss; it only takes effect for assistant turns. When category_id is nonzero, the loss for that category_id is printed during training.
77 |
78 | The data-ratio configuration file has the following format:
79 | ```json
80 | {
81 | "data_name": "",
82 | "data_path": "",
83 | "ratio": [
84 | {
85 | "category_id": 1,
86 | "category_name": "",
87 | "size": 100,
88 | "sample_rate": 1.0
89 | },
90 | {
91 | "category_id": 2,
92 | "category_name": "",
93 | "size": 200,
94 | "sample_rate": 0.8
95 | }
96 | ]
97 | }
98 | ```
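For reference, sample_rate is applied per category (see data/script/step1.sample.py): the integer part replicates the whole category, and the fractional part draws a random subset. A small self-contained illustration that mirrors that logic (the helper name is illustrative):

```python
import random

def apply_sample_rate(items, rate, seed=2025):
    """Mirrors process_category in step1.sample.py: integer part = full copies,
    fractional part = a random subset of the category."""
    random.seed(seed)
    result = items * int(rate)                    # whole copies
    extra = int(len(items) * (rate - int(rate)))  # fractional remainder
    if extra > 0:
        result += random.sample(items, extra)
    return result

# A category of 200 samples: rate 0.8 keeps a random 160; rate 2.5 yields 2 full copies plus a random 100.
print(len(apply_sample_rate(list(range(200)), 0.8)))  # 160
print(len(apply_sample_rate(list(range(200)), 2.5)))  # 500
```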
99 | Users need to adapt the preprocessing script data/script/preprocess.example.py to convert their data into the target format and to generate the data-ratio configuration file.
100 | The ratios can then be adjusted as needed.
101 |
102 | Demo data used by the scripts: https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k/resolve/main/distill_r1_110k.jsonl
103 |
104 | See the data-processing script for details:
105 | ```bash
106 | sh data_process.sh
107 | ```
108 |
109 | ### Training
110 | ```bash
111 | sh train.sh
112 | ```
113 | #### category-loss
114 |
115 | ```json
116 | {
117 | "epoch": 2,
118 | "steps": 3,
119 | "lr": 1.22375e-05,
120 | "loss": 0.515095591545105,
121 | "coig/neo_loss": 0.5617072582244873,
122 | "stem_zh/phy_loss": 0.4811963737010956,
123 | "EduChat-Math_loss": 0.4951120913028717,
124 | "meta-math/GSM8K_zh_loss": 0.5640832781791687,
125 | "exam/coig_exam_loss": 0.6263442635536194,
126 | "gavinluo/applied_math_loss": 0.4919000566005707,
127 | "stem_zh/chem_loss": 0.4528641700744629,
128 | "stem_zh/bio_loss": 0.46091940999031067,
129 | "zhihu/zhihu_score9.0-10_clean_v10_loss": 0.5875096917152405,
130 | "xhs/xhs_loss": 0.7661288380622864,
131 | "stem_zh/med_loss": 0.42540857195854187,
132 | "human_value/100poison_loss": 0.5484293699264526,
133 | "ruozhiba/ruozhiba_ruozhiba_loss": 0.825197160243988,
134 | "logi_qa/logi-qa_loss": 0.6175104975700378,
135 | "Haijian/Advanced-Math_loss": 0.4288356602191925,
136 | "exam/kaoyan_loss": 0.6865882873535156
137 | }
138 | ```
139 |
140 | #### turn-loss
141 | Inspired by [SFT loss 计算的那些坑(多轮合并/packing)](https://zhuanlan.zhihu.com/p/721652210), a turn-loss was added to balance long and short conversations. See the code for details and use it as needed; a rough illustration follows.
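This is a minimal single-GPU sketch (not the sequence-parallel implementation in train.py; the function name and the turn_ids argument are hypothetical): the usual token-averaged loss lets long turns dominate, while turn-loss first averages inside each turn and then across turns.

```python
import torch
import torch.nn.functional as F

def token_and_turn_loss(logits, labels, turn_ids, ignore_index=-100):
    """Compare the usual token-averaged loss with a turn-averaged loss.
    turn_ids assigns every token to a conversation turn (same shape as labels)."""
    losses = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
        ignore_index=ignore_index, reduction="none",
    )
    mask = labels.reshape(-1) != ignore_index
    token_loss = losses[mask].mean()  # long turns dominate this average

    flat_turn_ids = turn_ids.reshape(-1)
    turn_losses = [losses[mask & (flat_turn_ids == t)].mean()  # average inside each turn
                   for t in flat_turn_ids[mask].unique()]
    turn_loss = torch.stack(turn_losses).mean()  # then weight every turn equally
    return token_loss, turn_loss
```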
142 |
143 |
--------------------------------------------------------------------------------
/data/data_process.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 |
4 | step0=0
5 | step1=1
6 | step2=0
7 | step3=0
8 |
9 | if [ $step0 -ne 0 ]; then
10 | python script/preprocess.example.py
11 | fi
12 |
13 | config_path=config/distill_r1_110k.json
14 | preprocessed_data_path=dataset/distill_r1_110k.preprocessed.jsonl
15 | sampled_data_path=dataset/distill_r1_110k.sampled.jsonl
16 | tokenized_data_path=dataset/distill_r1_110k.tokenized.jsonl
17 | packed_data_path=dataset/distill_r1_110k.packed.jsonl
18 |
19 |
20 | if [ $step1 -ne 0 ]; then
21 | python script/step1.sample.py \
22 | --config-path=$config_path \
23 | --preprocessed-data-path=$preprocessed_data_path \
24 | --output-path=$sampled_data_path \
25 | --seed=2025 --num-workers=10
26 | fi
27 |
28 | # python script/pseudo.multiturn.py
29 |
30 | tokenizer_path=
31 | padding_value=
32 |
33 | if [ $step2 -ne 0 ]; then
34 | python script/step2.tokenize.py \
35 | --input-path=$sampled_data_path \
36 | --output-path=$tokenized_data_path \
37 | --tokenizer-path=$tokenizer_path \
38 | --num-workers=10
39 | fi
40 |
41 |
42 | if [ $step3 -ne 0 ]; then
43 | python script/step3.pack.py \
44 | --input-path=$tokenized_data_path \
45 | --output-path=$packed_data_path \
46 | --max-length=131072 \
47 | --padding-value=$padding_value \
48 | --num-workers=10
49 | fi
--------------------------------------------------------------------------------
/data/script/preprocess.example.py:
--------------------------------------------------------------------------------
1 | import json
2 | from functools import partial
3 | from collections import Counter
4 | from typing import Dict, Any, List
5 | from pathlib import Path
6 | from datasets import load_dataset, Dataset
7 |
8 |
9 | def process_record(record: Dict[str, Any], n2i: Dict[str, int], i2n: Dict[int, str]) -> Dict[str, Any]:
10 | """处理单条数据记录
11 |
12 | Args:
13 | record: 原始数据记录
14 | n2i: name到id的映射字典
15 | i2n: id到name的映射字典
16 |
17 | Returns:
18 | 处理后的数据记录
19 | """
20 | think_template = '''
21 | {think}
22 |
23 | {answer}'''
24 |
25 | try:
26 | category_id = n2i[record['repo_name']]
27 | category_name = i2n[category_id]
28 |
29 | return {
30 | "meta": {
31 | "category_id": category_id,
32 | "category_name": category_name
33 | },
34 | "messages": [
35 | {
36 | "role": "system",
37 | "content": "你是一位擅长深度思考的助手。在回答问题之前,你会像人类一样展现出一步一步的思考过程,包括问题理解、分析推理、自我质疑、反思验证等环节。之后你会基于思考过程,作出准确的回答。"
38 | },
39 | {
40 | "role": "user",
41 | "content": record['input']
42 | },
43 | {
44 | "role": "assistant",
45 | "content": think_template.format(
46 | think=record['reasoning_content'].strip(),
47 | answer=record['content'].strip()
48 | )
49 | }
50 | ],
51 | "labels": [0, 0, 1]
52 | }
53 | except KeyError as e:
54 | raise KeyError(f"处理记录时发生错误,缺少必要字段: {e}")
55 |
56 | def generate_category_mapping(dataset: Dataset) -> tuple[Dict[int, str], Dict[str, int], List[Dict]]:
57 | """生成类别映射和配置信息
58 |
59 | Args:
60 | dataset: 原始数据集
61 |
62 | Returns:
63 | tuple包含:
64 | - id2name: id到name的映射
65 | - name2id: name到id的映射
66 | - ratio_config: 采样率配置
67 | """
68 | id2name = {}
69 | name2id = {}
70 | ratio_config = []
71 |
72 | name_cnt = Counter(dataset['repo_name'])
73 | category_id = 1  # category_id 0 means the loss for that category is not tracked
74 |
75 | for name in name_cnt:
76 | id2name[category_id] = name
77 | name2id[name] = category_id
78 | ratio_config.append({
79 | "category_id": category_id,
80 | "category_name": name,
81 | "size": name_cnt[name],
82 | "sample_rate": 1.0
83 | })
84 | category_id += 1
85 |
86 | return id2name, name2id, ratio_config
87 |
88 | def main():
89 | """主函数"""
90 | # 定义路径
91 | raw_data_path = 'raw/distill_r1_110k.jsonl'
92 | config_path = 'config/distill_r1_110k.json'
93 | preprocessed_data_path = 'dataset/distill_r1_110k.preprocessed.jsonl'
94 |
95 | # Make sure the required directories exist
96 | Path(config_path).parent.mkdir(parents=True, exist_ok=True)
97 | Path(preprocessed_data_path).parent.mkdir(parents=True, exist_ok=True)
98 |
99 | try:
100 | # Load the dataset
101 | dataset = load_dataset('json', data_files=raw_data_path)['train']
102 |
103 | # Generate the config
104 | id2name, name2id, ratio_config = generate_category_mapping(dataset)
105 |
106 | config = {
107 | "data_name": "distill_r1_110k",
108 | "ratio": ratio_config
109 | }
110 |
111 | # Save the config
112 | with open(config_path, 'w', encoding='utf-8') as f:
113 | f.write(json.dumps(config, ensure_ascii=False, indent=4))
114 |
115 | # Process the dataset
116 | process_record_partial = partial(process_record, n2i=name2id, i2n=id2name)
117 | preprocessed_dataset = dataset.map(
118 | process_record_partial,
119 | num_proc=10,
120 | remove_columns=dataset.column_names
121 | )
122 |
123 | # Save the processed dataset
124 | preprocessed_dataset.save_to_disk(preprocessed_data_path)
125 |
126 | except Exception as e:
127 | print(f"处理过程中发生错误: {e}")
128 | raise
129 |
130 | if __name__ == "__main__":
131 | main()
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/data/script/pseudo.multiturn.py:
--------------------------------------------------------------------------------
1 | import random
2 | from datasets import load_from_disk, Dataset
3 |
4 |
5 | def merge_records(record1, record2):
6 | """
7 | 合并两个记录。
8 | :param record1: 第一个记录 (dict)
9 | :param record2: 第二个记录 (dict)
10 | :return: 合并后的记录 (dict)
11 | """
12 | c1 = record1["messages"][0]["role"] == "system" and record2["messages"][0]["role"] == "system"
13 | c2 = record1["messages"][0]["content"] == record2["messages"][0]["content"]
14 | if c1 and c2:
15 | record2["messages"] = record2["messages"][1:]
16 | record2['labels'] = record2['labels'][1:]
17 |
18 | if 'meta' in record1:
19 | record1['category_id_list'] = [record1["meta"]["category_id"]] * len(record1["messages"])
20 |
21 | merged_record = {
22 | "messages": record1["messages"] + record2["messages"], # 合并 messages
23 | "labels": record1["labels"] + record2["labels"], # 合并 labels
24 | "category_id_list": record1['category_id_list'] +
25 | [record2["meta"]["category_id"]] * len(record2["messages"]), # 第二条记录的 category_id
26 | }
27 | assert len(merged_record['messages']) == len(merged_record['labels']) == len(merged_record['category_id_list'])
28 | return merged_record
29 |
30 |
31 |
32 |
33 | def merge(input_file, output_file, merge_probability):
34 | """
35 | :param input_file: 输入 JSONL 文件路径
36 | :param output_file: 输出 JSONL 文件路径
37 | :param merge_probability: 合并的概率 (0~1)
38 | """
39 | # 读取 JSONL 文件内容
40 | dataset = load_from_disk(input_file)
41 | records = []
42 | for record in dataset:
43 | records.append(record)
44 |
45 | merged_records = []
46 | i = 0
47 | while i < len(records):
48 | record1 = records[i]
49 | i += 1
50 | # Probabilistically merge consecutive records
51 | while i < len(records) and random.random() < merge_probability:
52 | record2 = records[i]
53 | record1 = merge_records(record1, record2)
54 | i += 1
55 | # Append the fully merged record to the results
56 | merged_records.append(record1)
57 |
58 | turn_len_list = []
59 | for rec in merged_records:
60 | turn_len = len(rec['messages'])
61 | turn_len_list.append(turn_len)
62 | print("平均轮次:", sum(turn_len_list)/len(turn_len_list))
63 |
64 | # Save the result
65 | data = Dataset.from_list(merged_records)
66 | print("合并前:", len(dataset), "合并后:", len(data))
67 | data.save_to_disk(output_file)
68 |
69 |
70 | if __name__ == '__main__':
71 | # Example invocation
72 | input_file = "dataset/distill_r1_110k.sampled.jsonl"  # input path
73 | output_file = "dataset/distill_r1_110k.sampled.pseudo.multiturn.jsonl"  # output path
74 | merge_probability = 0.5  # merge probability (50%)
75 | num_workers = 1
76 | merge(input_file, output_file, merge_probability)
--------------------------------------------------------------------------------
/data/script/step1.sample.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import argparse
4 | from tqdm import tqdm
5 | from multiprocessing import Pool
6 | from datasets import Dataset, load_from_disk
7 |
8 | def load_config(config_path):
9 | with open(config_path, 'r', encoding='utf-8') as f:
10 | return json.load(f)
11 |
12 | def process_category(items, ratio, seed):
13 | random.seed(seed)
14 | int_part = int(ratio)
15 | float_part = ratio - int_part
16 | result = items * int_part
17 | extra_sample_size = int(len(items) * float_part)
18 | if extra_sample_size > 0:
19 | result += random.sample(items, extra_sample_size)
20 | return result
21 |
22 | class Sampler(object):
23 | def __init__(self, config_path, preprocessed_data_path, num_workers, seed, output_path):
24 | self.config = load_config(config_path=config_path)
25 | self.num_workers = num_workers
26 | self.seed = seed
27 | self.output_path = output_path
28 | self.dataset = load_from_disk(preprocessed_data_path)
29 |
30 | def sample(self):
31 | category_ratios = {
32 | item['category_id']: item['sample_rate']
33 | for item in self.config['ratio']
34 | }
35 | categorized_data = {}
36 | for item in tqdm(self.dataset):
37 | category_id = item['meta']['category_id']
38 | if category_id not in categorized_data:
39 | categorized_data[category_id] = []
40 | categorized_data[category_id].append(item)
41 | with Pool(self.num_workers) as pool:
42 | tasks = [
43 | (items, category_ratios[category_id], self.seed + category_id)
44 | for category_id, items in categorized_data.items()
45 | ]
46 | results = pool.starmap(process_category, tasks)
47 | sampled_data = []
48 | for result in results:
49 | sampled_data.extend(result)
50 | random.seed(self.seed)
51 | random.shuffle(sampled_data)
52 | self.dataset = Dataset.from_list(sampled_data)
53 |
54 | def save(self):
55 | self.dataset.save_to_disk(self.output_path)
56 |
57 |
58 |
59 | def parse_args():
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument('--config-path', type=str, help='path to the data-ratio config file')
62 | parser.add_argument('--preprocessed-data-path', type=str, help='path to the preprocessed data')
63 | parser.add_argument('--output-path', type=str, help='path to save the sampled data')
64 | parser.add_argument('--seed', type=int, default=42, help='random seed')
65 | parser.add_argument('--num-workers', type=int, default=5, help='number of parallel worker processes')
66 | return parser.parse_args()
67 |
68 | def main():
69 | args = parse_args()
70 | sampler = Sampler(config_path=args.config_path,
71 | preprocessed_data_path=args.preprocessed_data_path,
72 | output_path=args.output_path,
73 | seed=args.seed,
74 | num_workers=args.num_workers)
75 | sampler.sample()
76 | sampler.save()
77 |
78 | if __name__ == "__main__":
79 | main()
--------------------------------------------------------------------------------
/data/script/step2.tokenize.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import argparse
4 | from transformers import AutoTokenizer
5 | from datasets import load_from_disk
6 |
7 | def load_config(config_path):
8 | with open(config_path, 'r', encoding='utf-8') as f:
9 | return json.load(f)
10 |
11 | class Tokenizer(object):
12 | def __init__(self, input_path, tokenizer_path, num_workers, output_path, ignore_index=-100):
13 | self.dataset = load_from_disk(input_path)
14 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
15 | self.num_workers = num_workers
16 | self.output_path = output_path
17 | self.ignore_index = ignore_index
18 |
19 |
20 | def process(self, record):
21 | input_ids = self.tokenizer.apply_chat_template(
22 | record["messages"],
23 | tokenize=True,
24 | add_generation_prompt=False,
25 | return_tensors="pt")[0]
26 |
27 | token_labels = torch.full_like(input_ids, fill_value=self.ignore_index)
28 | category_ids = [0] * len(input_ids)
29 | for idx, message in enumerate(record["messages"]):
30 | if message["role"] == "assistant":
31 | prompt = self.tokenizer.apply_chat_template(record["messages"][:idx], tokenize=False,
32 | add_generation_prompt=True)
33 | response = self.tokenizer.apply_chat_template(record["messages"][: idx + 1], tokenize=False)[
34 | len(prompt):]
35 | start_idx = self.tokenizer(
36 | prompt,
37 | padding=False,
38 | return_tensors="pt",
39 | add_special_tokens=False,
40 | )["attention_mask"].int().sum().item()
41 | end_idx = start_idx + self.tokenizer(
42 | response,
43 | padding=False,
44 | return_tensors="pt",
45 | add_special_tokens=False,
46 | )["attention_mask"].int().sum().item()
47 | if record["labels"][idx] == 1:
48 | token_labels[start_idx:end_idx] = input_ids[start_idx:end_idx]
49 | if 'meta' in record:
50 | category_ids[start_idx:end_idx] = [record['meta']['category_id']] * (end_idx - start_idx)
51 | else:
52 | category_ids[start_idx:end_idx] = [record['category_id_list'][idx]] * (end_idx - start_idx)
53 | return {
54 | "input_ids": input_ids,
55 | "token_labels": token_labels,
56 | "category_ids": category_ids,
57 | }
58 |
59 | def tokenize(self):
60 | self.dataset = self.dataset.map(
61 | self.process,
62 | num_proc=self.num_workers,
63 | remove_columns=self.dataset.column_names,
64 | desc="分词"
65 | )
66 |
67 | def save(self):
68 | self.dataset.save_to_disk(self.output_path)
69 |
70 |
71 | def parse_args():
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--input-path', type=str, help='path to the input data')
74 | parser.add_argument('--output-path', type=str, help='path to save the tokenized data')
75 | parser.add_argument('--tokenizer-path', type=str, help='path to the tokenizer')
76 | parser.add_argument('--num-workers', type=int, help='number of parallel worker processes')
77 | return parser.parse_args()
78 |
79 | def main():
80 | args = parse_args()
81 | tokenizer = Tokenizer(input_path=args.input_path,
82 | tokenizer_path=args.tokenizer_path,
83 | num_workers=args.num_workers,
84 | output_path=args.output_path)
85 | tokenizer.tokenize()
86 | tokenizer.save()
87 |
88 | if __name__ == "__main__":
89 | main()
--------------------------------------------------------------------------------
/data/script/step3.pack.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import argparse
3 | import random
4 | import torch
5 | import numpy as np
6 | from functools import partial
7 | from multiprocessing import Pool
8 | from datasets import load_from_disk, Dataset
9 |
10 |
11 | def process_single_group(group_indices, sequences, max_length, padding_value):
12 | """处理单个packed sequence组"""
13 | # 预分配numpy数组,使用padding_value初始化
14 | packed_input_ids = np.full(max_length, padding_value, dtype=np.int64)
15 | packed_labels = np.full(max_length, -100, dtype=np.int64)
16 | packed_category_ids = np.zeros(max_length, dtype=np.int64)
17 | attention_mask = np.zeros(max_length, dtype=np.int64)
18 |
19 | current_pos = 0
20 | packed_seq_lens = []
21 |
22 | # Shuffle the indices
23 | random.shuffle(group_indices)
24 |
25 | # Fill in the data
26 | for idx in group_indices:
27 | seq_len = len(sequences[idx]['input_ids'])
28 | packed_seq_lens.append(seq_len)
29 |
30 | end_pos = current_pos + seq_len
31 | packed_input_ids[current_pos:end_pos] = sequences[idx]['input_ids']
32 | packed_labels[current_pos:end_pos] = sequences[idx]['token_labels']
33 | packed_category_ids[current_pos:end_pos] = sequences[idx]['category_ids']
34 | attention_mask[current_pos:end_pos] = 1
35 |
36 | current_pos = end_pos
37 |
38 | return {
39 | 'input_ids': torch.from_numpy(packed_input_ids),
40 | 'labels': torch.from_numpy(packed_labels),
41 | 'attention_mask': torch.from_numpy(attention_mask),
42 | 'packed_seq_lens': torch.tensor(packed_seq_lens),
43 | 'category_ids': torch.from_numpy(packed_category_ids),
44 | }
45 |
46 |
47 | def pack_sequences(sequences, max_length, num_workers=20, padding_value=0, max_attempt=2000):
48 | """
49 | 使用贪心算法将多个序列打包成固定长度的序列,尽量减少packed_sequences的数量
50 |
51 | Args:
52 | sequences: 包含多个序列的列表,每个序列是一个字典,包含'input_ids'键
53 | max_length: 打包后序列的最大长度
54 | padding_value: 填充值
55 |
56 | Returns:
57 | packed_results: 包含打包后的序列和长度信息的字典列表
58 | - input_ids: 拼接后的序列
59 | - lengths: 原始序列长度列表
60 | - indices: 原始序列索引列表
61 | """
62 |
63 | def compute_length(example):
64 | return {'length': len(example['input_ids'])}
65 |
66 | data_with_length = sequences.map(
67 | compute_length,
68 | num_proc=num_workers,
69 | desc="计算序列长度"
70 | )
71 | idx2len = dict(enumerate(data_with_length['length']))
72 |
73 | unused_indices = set(idx2len.keys())
74 | packed_sequences = []
75 |
76 | pbar = tqdm(total=len(unused_indices), desc="打包进度")
77 |
78 | while unused_indices:
79 | current_seq = []
80 | current_len = 0
81 | attempt = max_attempt
82 | for idx in idx2len:
83 | if idx in unused_indices:
84 | if current_len + idx2len[idx] <= max_length:
85 | current_seq.append(idx)
86 | current_len += idx2len[idx]
87 | unused_indices.remove(idx)
88 | pbar.update(1)
89 | else:
90 | if attempt > 0:
91 | attempt -= 1
92 | continue
93 | else:
94 | break
95 | packed_sequences.append(current_seq)
96 | pbar.close()
97 |
98 | with Pool(num_workers) as pool:
99 | process_fn = partial(
100 | process_single_group,
101 | sequences=sequences,
102 | max_length=max_length,
103 | padding_value=padding_value
104 | )
105 |
106 | # Use imap to show a progress bar
107 | packed_results = list(tqdm(
108 | pool.imap(process_fn, packed_sequences),
109 | total=len(packed_sequences),
110 | desc="并行打包处理"
111 | ))
112 |
113 | return packed_results
114 |
115 | class Packer(object):
116 | def __init__(self, input_path, output_path, max_length, padding_value, num_workers, max_attempt):
117 | self.dataset = load_from_disk(input_path)
118 | self.output_path = output_path
119 | self.max_length = max_length
120 | self.padding_value = padding_value
121 | self.num_workers = num_workers
122 | self.max_attempt = max_attempt
123 | self.packed_dataset = None
124 |
125 | def filter_fn(self, example):
126 | return len(example['input_ids']) <= self.max_length
127 |
128 | def pack(self):
129 | valid_items = self.dataset.filter(
130 | self.filter_fn,
131 | num_proc=self.num_workers
132 | )
133 | packed_data = pack_sequences(
134 | valid_items,
135 | max_length=self.max_length,
136 | num_workers=self.num_workers,
137 | padding_value=self.padding_value,
138 | max_attempt=self.max_attempt)
139 | self.packed_dataset = Dataset.from_list(packed_data)
140 | self.packed_dataset.set_format('torch')
141 |
142 |
143 | def save(self):
144 | self.packed_dataset.save_to_disk(self.output_path)
145 |
146 |
147 |
148 |
149 | def parse_args():
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('--input-path', type=str, help='path to the tokenized data')
152 | parser.add_argument('--output-path', type=str, help='path to save the packed data')
153 | parser.add_argument('--max-length', type=int, help='maximum packed-sequence length')
154 | parser.add_argument('--padding-value', type=int, help='padding value')
155 | parser.add_argument('--num-workers', type=int, help='number of parallel worker processes')
156 | return parser.parse_args()
157 |
158 | def main():
159 | args = parse_args()
160 | packer = Packer(
161 | input_path=args.input_path,
162 | output_path=args.output_path,
163 | max_length=args.max_length,
164 | padding_value=args.padding_value,
165 | num_workers=args.num_workers,
166 | max_attempt=2000)  # how many more sequences to try when one would exceed the max length
167 | packer.pack()
168 | packer.save()
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
173 |
174 |
--------------------------------------------------------------------------------
/docs/image-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-1.png
--------------------------------------------------------------------------------
/docs/image-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-2.png
--------------------------------------------------------------------------------
/docs/image-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-3.png
--------------------------------------------------------------------------------
/docs/image-4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpod/OpenSFT/57ecd3c6037c4dca600009766dbeb0c1249a6048/docs/image-4.png
--------------------------------------------------------------------------------
/train/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--max_steps', type=int, default=2000)
6 | parser.add_argument('--warmup_num_steps', type=int, default=200)
7 | parser.add_argument("--learning_rate", type=float, default=2e-5)
8 | args = parser.parse_args()
9 |
10 | with open("config/zero3_offload.json") as f:
11 | config = json.loads(f.read())
12 |
13 | config['scheduler']['params']['total_num_steps'] = args.max_steps
14 | config['scheduler']['params']['warmup_num_steps'] = args.warmup_num_steps
15 | config['scheduler']['params']['warmup_max_lr'] = args.learning_rate
16 | config['scheduler']['params']['warmup_min_lr'] = args.learning_rate * 0.1
17 |
18 | print(config)
19 | with open("config/zero3_offload.json", "w", encoding="utf-8") as f:
20 | f.write(json.dumps(config, indent=4))
21 |
--------------------------------------------------------------------------------
/train/config/multi_node.yaml:
--------------------------------------------------------------------------------
1 | debug: false
2 | deepspeed_config:
3 | deepspeed_config_file: zero3_offload.json
4 | deepspeed_multinode_launcher: standard
5 | zero3_init_flag: true
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | num_machines: 24
9 | num_processes: 196
10 | main_process_port: 38222
11 | main_training_function: main
12 | rdzv_backend: c10d
13 | same_network: false
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/train/config/single_node.yaml:
--------------------------------------------------------------------------------
1 | debug: false
2 | deepspeed_config:
3 | deepspeed_config_file: zero3_offload.json
4 | deepspeed_multinode_launcher: standard
5 | zero3_init_flag: true
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | num_machines: 1
9 | num_processes: 8
10 | main_process_port: 38222
11 | main_training_function: main
12 | rdzv_backend: c10d
13 | same_network: false
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/train/config/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "true"
4 | },
5 | "scheduler": {
6 | "type": "WarmupDecayLR",
7 | "params": {
8 | "warmup_min_lr": 2e-06,
9 | "warmup_max_lr": 2e-05,
10 | "warmup_num_steps": 200,
11 | "warmup_type": "linear",
12 | "total_num_steps": 2000
13 | }
14 | },
15 | "optimizer": {
16 | "type": "AdamW",
17 | "params": {
18 | "lr": "auto",
19 | "betas": [
20 | 0.9,
21 | 0.95
22 | ],
23 | "eps": 1e-08,
24 | "weight_decay": 0.1
25 | }
26 | },
27 | "zero_optimization": {
28 | "stage": 3,
29 | "overlap_comm": true,
30 | "contiguous_gradients": true,
31 | "sub_group_size": 1000000000.0,
32 | "reduce_bucket_size": "auto",
33 | "stage3_prefetch_bucket_size": "auto",
34 | "stage3_param_persistence_threshold": "auto",
35 | "stage3_max_live_parameters": 1000000000.0,
36 | "stage3_max_reuse_distance": 1000000000.0,
37 | "stage3_gather_16bit_weights_on_model_save": true
38 | },
39 | "gradient_accumulation_steps": "auto",
40 | "gradient_clipping": "auto",
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": 1
43 | }
--------------------------------------------------------------------------------
/train/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datasets
3 | from datasets import disable_caching
4 | from torch.utils.data import DataLoader, DistributedSampler
5 |
6 | disable_caching()
7 |
8 | class SFTData(object):
9 | def __init__(self, data_path):
10 | self.data = datasets.load_from_disk(data_path)
11 |
12 | def collate_fn(self, records):
13 | input_ids = records[0]["input_ids"].unsqueeze(0)
14 | labels = records[0]["labels"].unsqueeze(0)
15 | attention_mask = records[0]["attention_mask"].unsqueeze(0)
16 | category_ids = records[0]["category_ids"].unsqueeze(0)
17 | packed_seq_lens = records[0]['packed_seq_lens'].to(torch.int32)
18 | return dict(
19 | input_ids=input_ids,
20 | labels=labels,
21 | attention_mask=attention_mask,
22 | packed_seq_lens=packed_seq_lens,
23 | category_ids=category_ids
24 | )
25 |
26 | def get_dataloader(self, dp_world_size, dp_rank, seed, epoch, shuffle=True):
27 | sampler = DistributedSampler(
28 | self.data, num_replicas=dp_world_size, rank=dp_rank, seed=seed+epoch, shuffle=shuffle
29 | )
30 | train_dataloader = DataLoader(
31 | self.data,
32 | batch_size=1,
33 | sampler=sampler,
34 | collate_fn=self.collate_fn,
35 | pin_memory=True,
36 | )
37 | return train_dataloader
--------------------------------------------------------------------------------
/train/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.33.0
2 | flash-attn==2.6.3
3 | deepspeed==0.16.0
4 | transformers==4.45.2
5 |
6 | -e ./ring-flash-attention/
7 |
8 | wandb
9 | pandas
--------------------------------------------------------------------------------
/train/ring-flash-attention/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2024 Zilin Zhu
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/train/ring-flash-attention/README.md:
--------------------------------------------------------------------------------
1 | ## Ring Flash Attention
2 |
3 | This repo implements [RingAttention](https://github.com/lhao499/RingAttention) using [FlashAttention](https://github.com/Dao-AILab/flash-attention). The current implementation supports:
4 |
5 | - varlen (packing samples) api, corresponding to `flash_attn_varlen_func`:
6 | - `ring_flash_attn_varlen_func`: A basic implementation of ring attention.
7 | - `zigzag_ring_flash_attn_varlen_func`: a more compute-balanced version of ring attention. More details in [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
8 | - `llama3_flash_attn_varlen_func`: The context parallelism used in [llama3 tech report](https://arxiv.org/abs/2407.21783) with extra design for varlen and low memory overhead. Although technically not ring attention, this is **recommended** for most varlen use cases, as it offers a less intrusive alternative for training frameworks with fewer data manipulations and better arithmetic precision.
9 | - batch api, corresponding to `flash_attn_func`:
10 | - `ring_flash_attn_func`: basic ring attention.
11 | - `zigzag_ring_flash_attn_func`: a more compute-balanced version of ring attention, see [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
12 | - `stripe_flash_attn_func`: stripe-attention version of `ring_flash_attn_func`; the block size is set to 1 to use the flash_attn api, see: https://arxiv.org/abs/2311.09431
13 | - [huggingface model adapter](ring_flash_attn/adapters/hf_adapter.py). Here is an example to use the adapter: [OpenRLHF/OpenRLHF/pull#439](https://github.com/OpenRLHF/OpenRLHF/pull/439/files); a minimal usage sketch is also given after the notes below.
14 |
15 | Note that
16 |
17 | - Each function includes `*_func`, `*_kvpacked_func`, `*_qkvpacked_func` variants.
18 | - The varlen versions (except the llama3 version) only support passing one `cu_seqlens`.
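For packed (varlen) data, usage of the adapter might look like the following minimal sketch (assumptions: the default process group is used as the ring group, the batch size is 1, and `packed_seq_lens` describes the full packed sequence before it is split across ranks; variable names are illustrative):

```python
import torch
import torch.distributed as dist
from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params

dist.init_process_group("nccl")
ring_group = dist.group.WORLD  # illustrative: one ring spanning all ranks

# Patch transformers' flash-attention forward once, before running the model.
substitute_hf_flash_attn(process_group=ring_group, heads_k_stride=4)

# Per batch: cu_seqlens over the *full* packed sequence (before splitting across ranks).
packed_seq_lens = torch.tensor([4096, 2048, 2048], dtype=torch.int32, device="cuda")
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device="cuda"),
                        packed_seq_lens.cumsum(0).to(torch.int32)])
update_ring_flash_attn_params(cu_seqlens, ring_group)

# ...then run the HF model (attn_implementation="flash_attention_2") on this rank's slice, batch size 1.
```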
19 |
20 | ## Performance Summary
21 |
22 | The following table summarizes the performance of the implemented APIs:
23 |
24 | | batch api | GPU | theoretic flash_attn | ring_attn | zigzag_ring | stripe_attn |
25 | | -------------------- | ------- | ----------------------------- | ------------- | --------------- | --------------- |
26 | | fwd only (iter/sec) | 8xH800 | 591.5 / 8 = 73.9 | 38.5 | 63.0 | 55.0 |
27 | | | | | 52.1% | **85.2%** | 74.4% |
28 | | fwd + bwd (iter/sec) | 8xH800 | 154.7 / 8 = 19.3 | 10.4 | 17.4 | 16.0 |
29 | | | | | 53.9% | **90.2%** | 82.9% |
30 | | fwd only (iter/sec) | 8xA100 | 373.4 / 8 = 46.7 | 24.0 | 38.2 | 32.5 |
31 | | | | | 51.4% | **81.7%** | 69.6% |
32 | | fwd + bwd (iter/sec) | 8xA100 | 94.7 / 8 = 11.8 | 6.2 | 10.6 | 9.75 |
33 | | | | | 52.5% | **89.8%** | 82.6% |
34 | | **varlen api** | **GPU** | **theoretic flash_attn** | **ring_attn** | **zigzag_ring** | **llama3_attn** |
35 | | fwd only (iter/sec) | 8xH800 | 852.4 / 8 = 106.6 | 52.4 | 74.8 | 60.8 |
36 | | | | | 49.1% | **70.2%** | 57.0% |
37 | | fwd + bwd (iter/sec) | 8xH800 | 225.4 / 8 = 28.2 | 14.4 | 21.4 | 16.4 |
38 | | | | | 51.1% | **75.9%** | 58.1% |
39 | | fwd only (iter/sec) | 8xA100 | 532.3 / 8 = 66.5 | 33.1 | 47.9 | 34.3 |
40 | | | | | 49.8% | **72.0%** | 51.6% |
41 | | fwd + bwd (iter/sec) | 8xA100 | 133.8 / 8 = 16.7 | 8.7 | 13.4 | 9.7 |
42 | | | | | 52.1% | **80.2%** | 58.0% |
43 |
44 | Note that
45 |
46 | - The benchmark code is in [benchmark](benchmark/); its configuration matches the [Meta-Llama-3.1-8B](https://huggingface.co/NousResearch/Meta-Llama-3.1-8B/blob/main/config.json) setting, with a total sequence length of 8k per GPU.
47 | - When running the benchmark with 8 GPUs, the flash attn baseline performs 1/8 of the computation of ring attention: flash attn computes `8*1^2` blocks, while ring attn computes `1*8^2`. The percentages report each implementation relative to the theoretical per-GPU flash_attn throughput (e.g. 63.0 / 73.9 ≈ 85.2%).
48 | - NVLink between GPUs is required for high performance.
49 | - Please remember to adapt the RoPE offset for the different APIs.
50 |
51 | ### Installation
52 |
53 | ```bash
54 | pip install ring-flash-attn
55 | ```
56 |
57 | or use the following command to build from source:
58 |
59 | ```bash
60 | git clone https://github.com/zhuzilin/ring-flash-attention.git
61 | cd ring-flash-attention
62 | pip install .
63 | ```
64 |
65 | ### TODOs
66 |
67 | - [x] Implement `ring_flash_attn_varlen_qkvpacked_func`
68 | - [x] Implement `zigzag_ring_flash_attn_qkvpacked_func` [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2)
69 | - [x] Implement `stripe_flash_attn_qkvpacked_func`
70 | - [x] Implement `zigzag_ring_flash_attn_varlen_qkvpacked_func`
71 | - [x] Implement `*_kvpacked_func` and `*_func` variant for all APIs
72 | - [x] ~~Optimize `*_varlen_func`~~ Implement `llama3_flash_attn_varlen_func`
73 | - [x] ~~Add an example to train llama~~ Implement adapter for huggingface model
74 | - [ ] Implement `zigzag_llama3_flash_attn_varlen_func`
75 |
76 | ### Test
77 |
78 | ```bash
79 | torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
80 | torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
81 | torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
82 | torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
83 | torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
84 | torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py
85 | ```
86 |
87 | ### Benchmark
88 |
89 | ```bash
90 | torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
91 | torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py
92 | ```
93 |
94 | ### Known Limitations
95 |
96 | There are some arithmetic errors in the current implementation, probably because flash attention returns bf16 values for each block, so we cannot accumulate the values with the original fp32 ones.
97 |
98 | Also, because we need to save an extra fp32 buffer during computation, the memory usage is higher than the theoretical limit.
99 |
100 | Also,
101 |
102 | - dropout is not supported at the moment, because it's hard to save all the rng_states.
103 | - window_size is not supported, because it will be really tricky to implement a varlen version with window_size.
104 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/benchmark/benchmark_kvpacked_func.py:
--------------------------------------------------------------------------------
1 | from flash_attn import flash_attn_kvpacked_func
2 | import os
3 | import torch
4 | import torch.distributed as dist
5 | from ring_flash_attn import (
6 | ring_flash_attn_kvpacked_func,
7 | zigzag_ring_flash_attn_kvpacked_func,
8 | stripe_flash_attn_kvpacked_func,
9 | )
10 |
11 |
12 | def benchmark(f, num_iter=100, forward_only=True, log=True, profile=False):
13 | dtype = torch.bfloat16
14 | rank = dist.get_rank()
15 | world_size = dist.get_world_size()
16 | device = torch.device(f"cuda:{rank}")
17 | torch.cuda.set_device(device)
18 |
19 | batch_size = 1
20 | deterministic = False
21 | # config of llama3 8B
22 | seqlen = 1024 * 8
23 | num_heads = 32
24 | num_kv_heads = 8
25 | head_dim = 128
26 | causal = True
27 |
28 | assert seqlen % (2 * world_size) == 0
29 | assert head_dim % 8 == 0
30 |
31 | q = torch.randn(
32 | batch_size,
33 | seqlen,
34 | num_heads,
35 | head_dim,
36 | device=device,
37 | dtype=dtype,
38 | requires_grad=True,
39 | )
40 | kv = torch.randn(
41 | batch_size,
42 | seqlen,
43 | 2,
44 | num_kv_heads,
45 | head_dim,
46 | device=device,
47 | dtype=dtype,
48 | requires_grad=True,
49 | )
50 | dout = torch.randn(
51 | batch_size, seqlen, num_heads, head_dim, device=device, dtype=dtype
52 | )
53 |
54 | if profile:
55 | torch.backends.cudnn.benchmark = True
56 | profiler = torch.profiler.profile(
57 | activities=[
58 | torch.profiler.ProfilerActivity.CPU,
59 | torch.profiler.ProfilerActivity.CUDA,
60 | ],
61 | schedule=torch.profiler.schedule(
62 | wait=5,
63 | warmup=5,
64 | active=5,
65 | ),
66 | record_shapes=True,
67 | profile_memory=True,
68 | with_flops=True,
69 | with_modules=True,
70 | with_stack=True,
71 | on_trace_ready=torch.profiler.tensorboard_trace_handler(
72 | os.path.join(
73 | f"./benchmark/logs/{f.__name__}", f"rank_{dist.get_rank()}"
74 | )
75 | ),
76 | )
77 |
78 | if profile:
79 | profiler.start()
80 |
81 | begin = torch.cuda.Event(enable_timing=True)
82 | begin.record()
83 |
84 | if forward_only:
85 | with torch.no_grad():
86 | for _ in range(num_iter):
87 | _ = f(
88 | q,
89 | kv,
90 | causal=causal,
91 | window_size=(-1, -1),
92 | alibi_slopes=None,
93 | deterministic=deterministic,
94 | return_attn_probs=False,
95 | )
96 | if profile:
97 | profiler.step()
98 |
99 | else:
100 | for _ in range(num_iter):
101 | q.grad = None
102 | kv.grad = None
103 | out = f(
104 | q,
105 | kv,
106 | causal=causal,
107 | window_size=(-1, -1),
108 | alibi_slopes=None,
109 | deterministic=deterministic,
110 | return_attn_probs=False,
111 | )
112 | out.backward(dout)
113 | if profile:
114 | profiler.step()
115 |
116 | end = torch.cuda.Event(enable_timing=True)
117 | end.record()
118 | torch.cuda.synchronize(device=device)
119 | time = begin.elapsed_time(end) / 1000.0
120 |
121 | if profile:
122 | profiler.stop()
123 |
124 | if rank == 0 and log:
125 | print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec")
126 |
127 |
128 | if __name__ == "__main__":
129 | dist.init_process_group("nccl")
130 | rank = dist.get_rank()
131 |
132 | forward_only = False
133 | profile = False
134 | num_iter = 500 if forward_only else 100
135 |
136 | for f in [
137 | flash_attn_kvpacked_func,
138 | ring_flash_attn_kvpacked_func,
139 | zigzag_ring_flash_attn_kvpacked_func,
140 | stripe_flash_attn_kvpacked_func,
141 | ]:
142 | torch.cuda.empty_cache()
143 | if rank == 0:
144 | print(f"# {f.__name__}")
145 | benchmark(f, forward_only=forward_only, num_iter=num_iter, log=False)
146 | benchmark(
147 | f, forward_only=forward_only, num_iter=num_iter, log=True, profile=profile
148 | )
149 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py:
--------------------------------------------------------------------------------
1 | from flash_attn import flash_attn_varlen_kvpacked_func
2 | import os
3 | import torch
4 | import torch.distributed as dist
5 | from ring_flash_attn import (
6 | ring_flash_attn_varlen_kvpacked_func,
7 | zigzag_ring_flash_attn_varlen_kvpacked_func,
8 | llama3_flash_attn_varlen_kvpacked_func,
9 | llama3_flash_attn_prepare_cu_seqlens,
10 | )
11 |
12 |
13 | def benchmark(
14 | f,
15 | use_double_cu_seqlens,
16 | use_llama3=False,
17 | num_iter=100,
18 | forward_only=True,
19 | log=True,
20 | profile=False,
21 | ):
22 | dtype = torch.bfloat16
23 | rank = dist.get_rank()
24 | world_size = dist.get_world_size()
25 | device = torch.device(f"cuda:{rank}")
26 | torch.cuda.set_device(device)
27 |
28 | deterministic = False
29 | # config of llama3 8B
30 | seqlen = 1024 * 8
31 | num_heads = 32
32 | num_kv_heads = 8
33 | head_dim = 128
34 | causal = True
35 |
36 | assert seqlen % (2 * world_size) == 0
37 | assert head_dim % 8 == 0
38 |
39 | q = torch.randn(
40 | seqlen, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True
41 | )
42 | kv = torch.randn(
43 | seqlen,
44 | 2,
45 | num_kv_heads,
46 | head_dim,
47 | device=device,
48 | dtype=dtype,
49 | requires_grad=True,
50 | )
51 | dout = torch.randn(seqlen, num_heads, head_dim, device=device, dtype=dtype)
52 |
53 | cu_seqlens_list = [
54 | torch.tensor([0, 8192], device=device, dtype=torch.int32),
55 | torch.tensor([0, 256, 7648, 8192], device=device, dtype=torch.int32),
56 | torch.tensor([0, 4096, 8192], device=device, dtype=torch.int32),
57 | torch.tensor(
58 | [0, 3104, 6304, 7904, 8064, 8192], device=device, dtype=torch.int32
59 | ),
60 | ]
61 |
62 | if use_llama3:
63 | cu_seqlens_q_list = []
64 | cu_seqlens_k_list = []
65 | max_seqlen_q_list = []
66 | max_seqlen_k_list = []
67 | local_k_slice_list = []
68 | for cu_seqlens in cu_seqlens_list:
69 | (
70 | cu_seqlens_q,
71 | cu_seqlens_k,
72 | max_seqlen_q,
73 | max_seqlen_k,
74 | local_k_slice,
75 | ) = llama3_flash_attn_prepare_cu_seqlens(
76 | cu_seqlens * world_size,
77 | causal=causal,
78 | rank=rank,
79 | world_size=world_size,
80 | )
81 | cu_seqlens_q_list.append(cu_seqlens_q)
82 | cu_seqlens_k_list.append(cu_seqlens_k)
83 | max_seqlen_q_list.append(max_seqlen_q)
84 | max_seqlen_k_list.append(max_seqlen_k)
85 | local_k_slice_list.append(local_k_slice)
86 | else:
87 | max_seqlen_list = [
88 | (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
89 | for cu_seqlens in cu_seqlens_list
90 | ]
91 |
92 | if profile:
93 | torch.backends.cudnn.benchmark = True
94 | profiler = torch.profiler.profile(
95 | activities=[
96 | torch.profiler.ProfilerActivity.CPU,
97 | torch.profiler.ProfilerActivity.CUDA,
98 | ],
99 | schedule=torch.profiler.schedule(
100 | wait=5,
101 | warmup=5,
102 | active=5,
103 | ),
104 | record_shapes=True,
105 | profile_memory=True,
106 | with_flops=True,
107 | with_modules=True,
108 | with_stack=True,
109 | on_trace_ready=torch.profiler.tensorboard_trace_handler(
110 | os.path.join(
111 | f"./benchmark/logs/{f.__name__}", f"rank_{dist.get_rank()}"
112 | )
113 | ),
114 | )
115 |
116 | if profile:
117 | profiler.start()
118 |
119 | begin = torch.cuda.Event(enable_timing=True)
120 | begin.record()
121 |
122 | def wrapper(i: int):
123 | if use_llama3:
124 | return f(
125 | q,
126 | kv,
127 | cu_seqlens_q_list[i % len(cu_seqlens_list)],
128 | cu_seqlens_k_list[i % len(cu_seqlens_list)],
129 | max_seqlen_q_list[i % len(cu_seqlens_list)],
130 | max_seqlen_k_list[i % len(cu_seqlens_list)],
131 | heads_k_stride=4,
132 | local_k_slice=local_k_slice_list[i % len(cu_seqlens_list)],
133 | causal=causal,
134 | window_size=(-1, -1),
135 | alibi_slopes=None,
136 | deterministic=deterministic,
137 | return_attn_probs=False,
138 | )
139 | elif use_double_cu_seqlens:
140 | return f(
141 | q,
142 | kv,
143 | cu_seqlens_list[i % len(cu_seqlens_list)],
144 | cu_seqlens_list[i % len(cu_seqlens_list)],
145 | max_seqlen_list[i % len(cu_seqlens_list)],
146 | max_seqlen_list[i % len(cu_seqlens_list)],
147 | causal=causal,
148 | window_size=(-1, -1),
149 | alibi_slopes=None,
150 | deterministic=deterministic,
151 | return_attn_probs=False,
152 | )
153 | else:
154 | return f(
155 | q,
156 | kv,
157 | cu_seqlens_list[i % len(cu_seqlens_list)],
158 | max_seqlen_list[i % len(cu_seqlens_list)],
159 | causal=causal,
160 | window_size=(-1, -1),
161 | alibi_slopes=None,
162 | deterministic=deterministic,
163 | return_attn_probs=False,
164 | )
165 |
166 | if forward_only:
167 | with torch.no_grad():
168 | for i in range(num_iter):
169 | _ = wrapper(i)
170 | else:
171 | for i in range(num_iter):
172 | q.grad = None
173 | kv.grad = None
174 | out = wrapper(i)
175 | out.backward(dout)
176 | if profile:
177 | profiler.step()
178 | end = torch.cuda.Event(enable_timing=True)
179 | end.record()
180 | torch.cuda.synchronize(device=device)
181 | time = begin.elapsed_time(end) / 1000.0
182 |
183 | if profile:
184 | profiler.stop()
185 | if rank == 0 and log:
186 | print(f"{num_iter / time} iter/s, {time} sec")
187 |
188 |
189 | if __name__ == "__main__":
190 | dist.init_process_group("nccl")
191 | rank = dist.get_rank()
192 |
193 | forward_only = False
194 | profile = False
195 | num_iter = 500 if forward_only else 100
196 |
197 | for f, use_double_cu_seqlens in [
198 | (flash_attn_varlen_kvpacked_func, True),
199 | (ring_flash_attn_varlen_kvpacked_func, False),
200 | (zigzag_ring_flash_attn_varlen_kvpacked_func, False),
201 | ]:
202 | torch.cuda.empty_cache()
203 | if rank == 0:
204 | print(f"# {f.__name__}")
205 | benchmark(
206 | f,
207 | use_double_cu_seqlens,
208 | forward_only=forward_only,
209 | num_iter=num_iter,
210 | log=False,
211 | )
212 | benchmark(
213 | f,
214 | use_double_cu_seqlens,
215 | forward_only=forward_only,
216 | num_iter=num_iter,
217 | log=True,
218 | profile=profile,
219 | )
220 |
221 | for f, use_double_cu_seqlens in [
222 | (llama3_flash_attn_varlen_kvpacked_func, True),
223 | ]:
224 | torch.cuda.empty_cache()
225 | if rank == 0:
226 | print(f"# {f.__name__}")
227 | benchmark(
228 | f,
229 | use_double_cu_seqlens,
230 | use_llama3=True,
231 | forward_only=forward_only,
232 | num_iter=num_iter,
233 | log=False,
234 | )
235 | benchmark(
236 | f,
237 | use_double_cu_seqlens,
238 | use_llama3=True,
239 | forward_only=forward_only,
240 | num_iter=num_iter,
241 | log=True,
242 | profile=profile,
243 | )
244 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama3_flash_attn_varlen import (
2 | llama3_flash_attn_prepare_cu_seqlens,
3 | llama3_flash_attn_varlen_func,
4 | llama3_flash_attn_varlen_kvpacked_func,
5 | llama3_flash_attn_varlen_qkvpacked_func,
6 | )
7 | from .ring_flash_attn import (
8 | ring_flash_attn_func,
9 | ring_flash_attn_kvpacked_func,
10 | ring_flash_attn_qkvpacked_func,
11 | )
12 | from .ring_flash_attn_varlen import (
13 | ring_flash_attn_varlen_func,
14 | ring_flash_attn_varlen_kvpacked_func,
15 | ring_flash_attn_varlen_qkvpacked_func,
16 | )
17 | from .zigzag_ring_flash_attn import (
18 | zigzag_ring_flash_attn_func,
19 | zigzag_ring_flash_attn_kvpacked_func,
20 | zigzag_ring_flash_attn_qkvpacked_func,
21 | )
22 | from .zigzag_ring_flash_attn_varlen import (
23 | zigzag_ring_flash_attn_varlen_func,
24 | zigzag_ring_flash_attn_varlen_kvpacked_func,
25 | zigzag_ring_flash_attn_varlen_qkvpacked_func,
26 | )
27 | from .stripe_flash_attn import (
28 | stripe_flash_attn_func,
29 | stripe_flash_attn_kvpacked_func,
30 | stripe_flash_attn_qkvpacked_func,
31 | )
32 | from .adapters import (
33 | substitute_hf_flash_attn,
34 | update_ring_flash_attn_params,
35 | )
36 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from .hf_adapter import (
2 | substitute_hf_flash_attn,
3 | update_ring_flash_attn_params,
4 | )
5 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/adapters/hf_adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import inspect
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import transformers
8 | import transformers.modeling_flash_attention_utils
9 | from transformers.modeling_flash_attention_utils import (
10 | _flash_supports_window_size,
11 | is_flash_attn_greater_or_equal,
12 | )
13 | from ..llama3_flash_attn_varlen import (
14 | llama3_flash_attn_varlen_func,
15 | llama3_flash_attn_prepare_cu_seqlens,
16 | )
17 |
18 |
19 | DATA_PARAMS = {}
20 | RING_ATTN_SWITCH = True
21 |
22 |
23 | def check_params(f1, f2):
24 | return len(inspect.signature(f1).parameters) == len(
25 | inspect.signature(f2).parameters
26 | )
27 |
28 |
29 | def update_ring_flash_attn_params(
30 | cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
31 | ):
32 | world_size = dist.get_world_size(group=process_group)
33 | rank = dist.get_rank(group=process_group)
34 | (
35 | cu_seqlens_q,
36 | cu_seqlens_k,
37 | max_seqlen_q,
38 | max_seqlen_k,
39 | local_k_slice,
40 | ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
41 | DATA_PARAMS.update(
42 | {
43 | "cu_seqlens_q": cu_seqlens_q,
44 | "cu_seqlens_k": cu_seqlens_k,
45 | "max_seqlen_q": max_seqlen_q,
46 | "max_seqlen_k": max_seqlen_k,
47 | "local_k_slice": local_k_slice,
48 | }
49 | )
50 |
51 |
52 | def use_ring_attn(flag):
53 | global RING_ATTN_SWITCH
54 | RING_ATTN_SWITCH = flag
55 |
56 |
57 | def create_ring_flash_attention_forward(
58 | process_group: dist.ProcessGroup, heads_k_stride: int
59 | ):
60 | def _flash_attention_forward(
61 | query_states: torch.Tensor,
62 | key_states: torch.Tensor,
63 | value_states: torch.Tensor,
64 | attention_mask: torch.Tensor,
65 | query_length: int,
66 | is_causal: bool,
67 | dropout: float = 0.0,
68 | position_ids: Optional[torch.Tensor] = None,
69 | softmax_scale: Optional[float] = None,
70 | sliding_window: Optional[int] = None,
71 | use_top_left_mask: bool = False,
72 | softcap: Optional[float] = None,
73 | deterministic: bool = None,
74 | ):
75 | """
76 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
77 | first unpad the input, then computes the attention scores and pad the final attention scores.
78 |
79 | Args:
80 | query_states (`torch.Tensor`):
81 | Input query states to be passed to Flash Attention API
82 | key_states (`torch.Tensor`):
83 | Input key states to be passed to Flash Attention API
84 | value_states (`torch.Tensor`):
85 | Input value states to be passed to Flash Attention API
86 | attention_mask (`torch.Tensor`):
87 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
88 | position of padding tokens and 1 for the position of non-padding tokens.
89 | dropout (`float`):
90 | Attention dropout
91 | softmax_scale (`float`, *optional*):
92 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
93 | use_top_left_mask (`bool`, defaults to `False`):
94 | flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
95 | softcap (`float`, *optional*):
96 | Softcap for the attention logits, used e.g. in gemma2.
97 | deterministic (`bool`, *optional*):
98 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
99 | """
100 | if not use_top_left_mask:
101 | causal = is_causal
102 | else:
103 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
104 | causal = is_causal and query_length != 1
105 |
106 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
107 | use_sliding_windows = (
108 | _flash_supports_window_size
109 | and sliding_window is not None
110 | and key_states.shape[1] > sliding_window
111 | )
112 | flash_kwargs = (
113 | {"window_size": (sliding_window, sliding_window)}
114 | if use_sliding_windows
115 | else {}
116 | )
117 |
118 | if is_flash_attn_greater_or_equal("2.4.1"):
119 | if deterministic is None:
120 | deterministic = (
121 | os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
122 | )
123 | flash_kwargs["deterministic"] = deterministic
124 | assert (
125 | softcap is None
126 | ), "llama3_flash_attn_varlen_func does not support softcap yet."
127 | # flash_kwargs["softcap"] = softcap
128 | flash_kwargs["group"] = process_group
129 |
130 | # not sure why attention_mask can be not None...
131 | assert causal, "only causal attention is supported yet."
132 | batch_size = query_states.size(0)
133 | assert batch_size == 1, "varlen data should be processed in advance."
134 |
135 | attn_output = llama3_flash_attn_varlen_func(
136 | query_states.squeeze(dim=0),
137 | key_states.squeeze(dim=0),
138 | value_states.squeeze(dim=0),
139 | cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
140 | cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
141 | max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
142 | max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
143 | heads_k_stride=heads_k_stride,
144 | local_k_slice=DATA_PARAMS["local_k_slice"],
145 | dropout_p=dropout,
146 | softmax_scale=softmax_scale,
147 | causal=causal,
148 | **flash_kwargs,
149 | )
150 |
151 | attn_output = attn_output.unsqueeze(dim=0)
152 |
153 | return attn_output
154 |
155 | return _flash_attention_forward
156 |
157 |
158 | def substitute_hf_flash_attn(process_group: dist.ProcessGroup, heads_k_stride: int):
159 | try:
160 | # substitute flash attn
161 | old_flash_attention_forward = (
162 | transformers.modeling_flash_attention_utils._flash_attention_forward
163 | )
164 | new_flash_attention_forward = create_ring_flash_attention_forward(
165 | process_group, heads_k_stride
166 | )
167 | assert check_params(old_flash_attention_forward, new_flash_attention_forward)
168 | transformers.modeling_flash_attention_utils._flash_attention_forward = (
169 | lambda *args, **kwargs: (
170 | new_flash_attention_forward(*args, **kwargs)
171 | if RING_ATTN_SWITCH
172 | else old_flash_attention_forward(*args, **kwargs)
173 | )
174 | )
175 |     except Exception:
176 |         raise ValueError(
177 |             f"The current transformers version {transformers.__version__} is not supported. "
178 |             "Please use `pip install -U transformers` to upgrade to the latest version. "
179 |             "If the code still fails with the latest version, "
180 |             "please file an issue at https://github.com/zhuzilin/ring-flash-attention/issues"
181 |         )
182 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/ring_flash_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
4 | from .utils import RingComm, update_out_and_lse, get_default_args
5 |
6 |
7 | def ring_flash_attn_forward(
8 | process_group,
9 | q: torch.Tensor,
10 | k: torch.Tensor,
11 | v: torch.Tensor,
12 | softmax_scale,
13 | dropout_p=0,
14 | causal=True,
15 | window_size=(-1, -1),
16 | alibi_slopes=None,
17 | deterministic=False,
18 | ):
19 | comm = RingComm(process_group)
20 |
21 | out = None
22 | lse = None
23 |
24 | next_k, next_v = None, None
25 |
26 | for step in range(comm.world_size):
27 | if step + 1 != comm.world_size:
28 | next_k: torch.Tensor = comm.send_recv(k)
29 | next_v: torch.Tensor = comm.send_recv(v)
30 | comm.commit()
31 |
32 | if not causal or step <= comm.rank:
33 | params = get_default_args(_flash_attn_forward).copy()
34 | params.update(
35 | {
36 | "q": q,
37 | "k": k,
38 | "v": v,
39 | "dropout_p": dropout_p,
40 | "softmax_scale": softmax_scale,
41 | "causal": causal and step == 0,
42 | "window_size": window_size,
43 | "alibi_slopes": alibi_slopes,
44 | "return_softmax": True and dropout_p > 0,
45 | }
46 | )
47 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
48 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
49 |
50 | if step + 1 != comm.world_size:
51 | comm.wait()
52 | k = next_k
53 | v = next_v
54 |
55 | out = out.to(q.dtype)
56 | lse = lse.squeeze(dim=-1).transpose(1, 2)
57 | return out, lse
58 |
59 |
60 | def ring_flash_attn_backward(
61 | process_group,
62 | dout,
63 | q,
64 | k,
65 | v,
66 | out,
67 | softmax_lse,
68 | softmax_scale,
69 | dropout_p=0,
70 | causal=True,
71 | window_size=(-1, -1),
72 | alibi_slopes=None,
73 | deterministic=False,
74 | ):
75 | kv_comm = RingComm(process_group)
76 | d_kv_comm = RingComm(process_group)
77 | dq, dk, dv = None, None, None
78 | next_dk, next_dv = None, None
79 |
80 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
81 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
82 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
83 |
84 | next_dk, next_dv = None, None
85 | next_k, next_v = None, None
86 |
87 | for step in range(kv_comm.world_size):
88 | if step + 1 != kv_comm.world_size:
89 | next_k = kv_comm.send_recv(k)
90 | next_v = kv_comm.send_recv(v)
91 | kv_comm.commit()
92 | if step <= kv_comm.rank or not causal:
93 | bwd_causal = causal and step == 0
94 | params = get_default_args(_flash_attn_backward).copy()
95 | params.update(
96 | {
97 | "dout": dout,
98 | "q": q,
99 | "k": k,
100 | "v": v,
101 | "out": out,
102 | "softmax_lse": softmax_lse,
103 | "dq": block_dq_buffer,
104 | "dk": block_dk_buffer,
105 | "dv": block_dv_buffer,
106 | "dropout_p": dropout_p,
107 | "softmax_scale": softmax_scale,
108 | "causal": bwd_causal,
109 | "window_size": window_size,
110 | "alibi_slopes": alibi_slopes,
111 | "deterministic": deterministic,
112 | }
113 | )
114 | _flash_attn_backward(**params)
115 |
116 | if dq is None:
117 | dq = block_dq_buffer.to(torch.float32)
118 | dk = block_dk_buffer.to(torch.float32)
119 | dv = block_dv_buffer.to(torch.float32)
120 | else:
121 | dq += block_dq_buffer
122 | d_kv_comm.wait()
123 | dk = block_dk_buffer + next_dk
124 | dv = block_dv_buffer + next_dv
125 | elif step != 0:
126 | d_kv_comm.wait()
127 | dk = next_dk
128 | dv = next_dv
129 |
130 | if step + 1 != kv_comm.world_size:
131 | kv_comm.wait()
132 | k = next_k
133 | v = next_v
134 |
135 | next_dk = d_kv_comm.send_recv(dk)
136 | next_dv = d_kv_comm.send_recv(dv)
137 | d_kv_comm.commit()
138 |
139 | d_kv_comm.wait()
140 |
141 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)
142 |
143 |
144 | class RingFlashAttnFunc(torch.autograd.Function):
145 | @staticmethod
146 | def forward(
147 | ctx,
148 | q,
149 | k,
150 | v,
151 | dropout_p,
152 | softmax_scale,
153 | causal,
154 | window_size,
155 | alibi_slopes,
156 | deterministic,
157 | return_softmax,
158 | group,
159 | ):
160 | if softmax_scale is None:
161 | softmax_scale = q.shape[-1] ** (-0.5)
162 |
163 | assert alibi_slopes is None
164 | k = k.contiguous()
165 | v = v.contiguous()
166 | out, softmax_lse = ring_flash_attn_forward(
167 | group,
168 | q,
169 | k,
170 | v,
171 | softmax_scale=softmax_scale,
172 | dropout_p=dropout_p,
173 | causal=causal,
174 | window_size=window_size,
175 | alibi_slopes=alibi_slopes,
176 | deterministic=False,
177 | )
178 | # this should be out_padded
179 | ctx.save_for_backward(q, k, v, out, softmax_lse)
180 | ctx.dropout_p = dropout_p
181 | ctx.softmax_scale = softmax_scale
182 | ctx.causal = causal
183 | ctx.window_size = window_size
184 | ctx.alibi_slopes = alibi_slopes
185 | ctx.deterministic = deterministic
186 | ctx.group = group
187 | return out if not return_softmax else (out, softmax_lse, None)
188 |
189 | @staticmethod
190 | def backward(ctx, dout, *args):
191 | q, k, v, out, softmax_lse = ctx.saved_tensors
192 | dq, dk, dv = ring_flash_attn_backward(
193 | ctx.group,
194 | dout,
195 | q,
196 | k,
197 | v,
198 | out,
199 | softmax_lse,
200 | softmax_scale=ctx.softmax_scale,
201 | dropout_p=ctx.dropout_p,
202 | causal=ctx.causal,
203 | window_size=ctx.window_size,
204 | alibi_slopes=ctx.alibi_slopes,
205 | deterministic=ctx.deterministic,
206 | )
207 | return dq, dk, dv, None, None, None, None, None, None, None, None
208 |
209 |
210 | def ring_flash_attn_qkvpacked_func(
211 | qkv,
212 | dropout_p=0.0,
213 | softmax_scale=None,
214 | causal=False,
215 | window_size=(-1, -1),
216 | alibi_slopes=None,
217 | deterministic=False,
218 | return_attn_probs=False,
219 | group=None,
220 | ):
221 | return RingFlashAttnFunc.apply(
222 | qkv[:, :, 0],
223 | qkv[:, :, 1],
224 | qkv[:, :, 2],
225 | dropout_p,
226 | softmax_scale,
227 | causal,
228 | window_size,
229 | alibi_slopes,
230 | deterministic,
231 | return_attn_probs,
232 | group,
233 | )
234 |
235 |
236 | def ring_flash_attn_kvpacked_func(
237 | q,
238 | kv,
239 | dropout_p=0.0,
240 | softmax_scale=None,
241 | causal=False,
242 | window_size=(-1, -1),
243 | alibi_slopes=None,
244 | deterministic=False,
245 | return_attn_probs=False,
246 | group=None,
247 | ):
248 | return RingFlashAttnFunc.apply(
249 | q,
250 | kv[:, :, 0],
251 | kv[:, :, 1],
252 | dropout_p,
253 | softmax_scale,
254 | causal,
255 | window_size,
256 | alibi_slopes,
257 | deterministic,
258 | return_attn_probs,
259 | group,
260 | )
261 |
262 |
263 | def ring_flash_attn_func(
264 | q,
265 | k,
266 | v,
267 | dropout_p=0.0,
268 | softmax_scale=None,
269 | causal=False,
270 | window_size=(-1, -1),
271 | alibi_slopes=None,
272 | deterministic=False,
273 | return_attn_probs=False,
274 | group=None,
275 | ):
276 | return RingFlashAttnFunc.apply(
277 | q,
278 | k,
279 | v,
280 | dropout_p,
281 | softmax_scale,
282 | causal,
283 | window_size,
284 | alibi_slopes,
285 | deterministic,
286 | return_attn_probs,
287 | group,
288 | )
289 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/ring_flash_attn_varlen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn.flash_attn_interface import (
4 | _flash_attn_varlen_forward,
5 | _flash_attn_varlen_backward,
6 | )
7 | from .utils import (
8 | RingComm,
9 | update_out_and_lse,
10 | get_default_args,
11 | )
12 |
13 | try:
14 | from .triton_utils import (
15 | flatten_varlen_lse,
16 | unflatten_varlen_lse,
17 | )
18 | except:
19 | from .utils import (
20 | flatten_varlen_lse,
21 | unflatten_varlen_lse,
22 | )
23 |
24 |
25 | def ring_flash_attn_varlen_forward(
26 | process_group,
27 | q: torch.Tensor,
28 | k: torch.Tensor,
29 | v: torch.Tensor,
30 | cu_seqlens,
31 | max_seqlen,
32 | softmax_scale,
33 | dropout_p=0,
34 | causal=True,
35 | window_size=(-1, -1),
36 | alibi_slopes=None,
37 | deterministic=False,
38 | ):
39 | comm = RingComm(process_group)
40 |
41 | out = None
42 | lse = None
43 | next_k, next_v = None, None
44 |
45 | old_lse = False
46 | for step in range(comm.world_size):
47 | if step + 1 != comm.world_size:
48 | next_k: torch.Tensor = comm.send_recv(k)
49 | next_v: torch.Tensor = comm.send_recv(v)
50 | comm.commit()
51 | if not causal or step <= comm.rank:
52 | params = get_default_args(_flash_attn_varlen_forward).copy()
53 | params.update(
54 | {
55 | "q": q,
56 | "k": k,
57 | "v": v,
58 | "cu_seqlens_q": cu_seqlens,
59 | "cu_seqlens_k": cu_seqlens,
60 | "max_seqlen_q": max_seqlen,
61 | "max_seqlen_k": max_seqlen,
62 | "dropout_p": dropout_p,
63 | "softmax_scale": softmax_scale,
64 | "causal": causal and step == 0,
65 | "window_size": window_size,
66 | "alibi_slopes": alibi_slopes,
67 | "return_softmax": True and dropout_p > 0,
68 | }
69 | )
70 |
71 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
72 | **params
73 | )
74 | if block_lse.dim() == 3:
75 | old_lse = True
76 | block_lse = flatten_varlen_lse(
77 | block_lse,
78 | cu_seqlens=cu_seqlens,
79 | )
80 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
81 |
82 | if step + 1 != comm.world_size:
83 | comm.wait()
84 | k = next_k
85 | v = next_v
86 |
87 | out = out.to(q.dtype)
88 | if old_lse:
89 | lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
90 | else:
91 | lse = lse.squeeze(dim=-1).transpose(0, 1)
92 | return out, lse
93 |
94 |
95 | def ring_flash_attn_varlen_backward(
96 | process_group,
97 | dout,
98 | q,
99 | k,
100 | v,
101 | out,
102 | softmax_lse,
103 | cu_seqlens,
104 | max_seqlen,
105 | softmax_scale,
106 | dropout_p=0,
107 | causal=True,
108 | window_size=(-1, -1),
109 | alibi_slopes=None,
110 | deterministic=False,
111 | ):
112 | kv_comm = RingComm(process_group)
113 | d_kv_comm = RingComm(process_group)
114 | dq, dk, dv = None, None, None
115 | next_dk, next_dv = None, None
116 |
117 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
118 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
119 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
120 |
121 | next_dk, next_dv = None, None
122 | next_k, next_v = None, None
123 | for step in range(kv_comm.world_size):
124 | if step + 1 != kv_comm.world_size:
125 | next_k = kv_comm.send_recv(k)
126 | next_v = kv_comm.send_recv(v)
127 | kv_comm.commit()
128 | if step <= kv_comm.rank or not causal:
129 | bwd_causal = causal and step == 0
130 | params = get_default_args(_flash_attn_varlen_backward).copy()
131 | params.update(
132 | {
133 | "dout": dout,
134 | "q": q,
135 | "k": k,
136 | "v": v,
137 | "out": out,
138 | "softmax_lse": softmax_lse,
139 | "dq": block_dq_buffer,
140 | "dk": block_dk_buffer,
141 | "dv": block_dv_buffer,
142 | "cu_seqlens_q": cu_seqlens,
143 | "cu_seqlens_k": cu_seqlens,
144 | "max_seqlen_q": max_seqlen,
145 | "max_seqlen_k": max_seqlen,
146 | "dropout_p": dropout_p,
147 | "softmax_scale": softmax_scale,
148 | "causal": bwd_causal,
149 | "window_size": window_size,
150 | "alibi_slopes": alibi_slopes,
151 | "deterministic": deterministic,
152 | }
153 | )
154 | _flash_attn_varlen_backward(**params)
155 |
156 | if dq is None:
157 | dq = block_dq_buffer.to(torch.float32)
158 | dk = block_dk_buffer.to(torch.float32)
159 | dv = block_dv_buffer.to(torch.float32)
160 | else:
161 | dq += block_dq_buffer
162 | d_kv_comm.wait()
163 | dk = block_dk_buffer + next_dk
164 | dv = block_dv_buffer + next_dv
165 | elif step != 0:
166 | d_kv_comm.wait()
167 | dk = next_dk
168 | dv = next_dv
169 |
170 | if step + 1 != kv_comm.world_size:
171 | kv_comm.wait()
172 | k = next_k
173 | v = next_v
174 |
175 | next_dk = d_kv_comm.send_recv(dk)
176 | next_dv = d_kv_comm.send_recv(dv)
177 | d_kv_comm.commit()
178 |
179 | d_kv_comm.wait()
180 |
181 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)
182 |
183 |
184 | class RingFlashAttnVarlenFunc(torch.autograd.Function):
185 | @staticmethod
186 | def forward(
187 | ctx,
188 | q,
189 | k,
190 | v,
191 | cu_seqlens,
192 | max_seqlen,
193 | dropout_p,
194 | softmax_scale,
195 | causal,
196 | window_size,
197 | alibi_slopes,
198 | deterministic,
199 | return_softmax,
200 | group,
201 | ):
202 | if softmax_scale is None:
203 | softmax_scale = q.shape[-1] ** (-0.5)
204 |
205 | assert alibi_slopes is None
206 | k = k.contiguous()
207 | v = v.contiguous()
208 | out, softmax_lse = ring_flash_attn_varlen_forward(
209 | group,
210 | q,
211 | k,
212 | v,
213 | cu_seqlens,
214 | max_seqlen,
215 | softmax_scale=softmax_scale,
216 | dropout_p=dropout_p,
217 | causal=causal,
218 | window_size=window_size,
219 | alibi_slopes=alibi_slopes,
220 | deterministic=False,
221 | )
222 | # this should be out_padded
223 | ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens)
224 | ctx.max_seqlen = max_seqlen
225 | ctx.dropout_p = dropout_p
226 | ctx.softmax_scale = softmax_scale
227 | ctx.causal = causal
228 | ctx.window_size = window_size
229 | ctx.alibi_slopes = alibi_slopes
230 | ctx.deterministic = deterministic
231 | ctx.group = group
232 | return out if not return_softmax else (out, softmax_lse, None)
233 |
234 | @staticmethod
235 | def backward(ctx, dout, *args):
236 | q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors
237 | dq, dk, dv = ring_flash_attn_varlen_backward(
238 | ctx.group,
239 | dout,
240 | q,
241 | k,
242 | v,
243 | out,
244 | softmax_lse,
245 | cu_seqlens,
246 | ctx.max_seqlen,
247 | softmax_scale=ctx.softmax_scale,
248 | dropout_p=ctx.dropout_p,
249 | causal=ctx.causal,
250 | window_size=ctx.window_size,
251 | alibi_slopes=ctx.alibi_slopes,
252 | deterministic=ctx.deterministic,
253 | )
254 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
255 |
256 |
257 | def ring_flash_attn_varlen_qkvpacked_func(
258 | qkv,
259 | cu_seqlens,
260 | max_seqlen,
261 | dropout_p=0.0,
262 | softmax_scale=None,
263 | causal=False,
264 | window_size=(-1, -1), # -1 means infinite context window
265 | alibi_slopes=None,
266 | deterministic=False,
267 | return_attn_probs=False,
268 | group=None,
269 | ):
270 | return RingFlashAttnVarlenFunc.apply(
271 | qkv[:, 0],
272 | qkv[:, 1],
273 | qkv[:, 2],
274 | cu_seqlens,
275 | max_seqlen,
276 | dropout_p,
277 | softmax_scale,
278 | causal,
279 | window_size,
280 | alibi_slopes,
281 | deterministic,
282 | return_attn_probs,
283 | group,
284 | )
285 |
286 |
287 | def ring_flash_attn_varlen_kvpacked_func(
288 | q,
289 | kv,
290 | cu_seqlens,
291 | max_seqlen,
292 | dropout_p=0.0,
293 | softmax_scale=None,
294 | causal=False,
295 | window_size=(-1, -1), # -1 means infinite context window
296 | alibi_slopes=None,
297 | deterministic=False,
298 | return_attn_probs=False,
299 | group=None,
300 | ):
301 | return RingFlashAttnVarlenFunc.apply(
302 | q,
303 | kv[:, 0],
304 | kv[:, 1],
305 | cu_seqlens,
306 | max_seqlen,
307 | dropout_p,
308 | softmax_scale,
309 | causal,
310 | window_size,
311 | alibi_slopes,
312 | deterministic,
313 | return_attn_probs,
314 | group,
315 | )
316 |
317 |
318 | def ring_flash_attn_varlen_func(
319 | q,
320 | k,
321 | v,
322 | cu_seqlens,
323 | max_seqlen,
324 | dropout_p=0.0,
325 | softmax_scale=None,
326 | causal=False,
327 | window_size=(-1, -1), # -1 means infinite context window
328 | alibi_slopes=None,
329 | deterministic=False,
330 | return_attn_probs=False,
331 | group=None,
332 | ):
333 | return RingFlashAttnVarlenFunc.apply(
334 | q,
335 | k,
336 | v,
337 | cu_seqlens,
338 | max_seqlen,
339 | dropout_p,
340 | softmax_scale,
341 | causal,
342 | window_size,
343 | alibi_slopes,
344 | deterministic,
345 | return_attn_probs,
346 | group,
347 | )
348 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/triton_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 |
6 | @triton.jit
7 | def flatten_kernel(
8 | # pointers to matrices
9 | OUT,
10 | LSE,
11 | CU_SEQLENS,
12 | # strides
13 | stride_out_nheads,
14 | stride_out_seqlen,
15 | stride_lse_batch,
16 | stride_lse_nheads,
17 | stride_lse_seqlen,
18 | # meta-parameters
19 | BLOCK_M: tl.constexpr,
20 | ):
21 | pid_m = tl.program_id(axis=0)
22 | pid_batch = tl.program_id(axis=1)
23 | pid_head = tl.program_id(axis=2)
24 |
25 | start_idx = tl.load(CU_SEQLENS + pid_batch)
26 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
27 | LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads
28 | OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen
29 |
30 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
31 |
32 | LSE = LSE + rm[:, None] * stride_lse_seqlen
33 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)
34 |
35 | OUT = OUT + rm[:, None] * stride_out_seqlen
36 | tl.store(OUT, x, mask=rm[:, None] < seqlen)
37 |
38 |
39 | def flatten_varlen_lse(lse, cu_seqlens):
40 | """
41 | Arguments:
42 | lse: (batch_size, nheads, max_seqlen)
43 | cu_seqlens: (batch_size + 1,)
44 | Return:
45 | flatten_lse: (nheads, total_seqlen)
46 | """
47 | total_seqlen = cu_seqlens[-1]
48 | batch_size, nheads, max_seqlen = lse.shape
49 | output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device)
50 |
51 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
52 | BLOCK_M = 4
53 |
54 | with torch.cuda.device(lse.device.index):
55 | flatten_kernel[grid](
56 | output,
57 | lse,
58 | cu_seqlens,
59 | # strides
60 | output.stride(0),
61 | output.stride(1),
62 | lse.stride(0),
63 | lse.stride(1),
64 | lse.stride(2),
65 | BLOCK_M,
66 | )
67 | return output
68 |
69 |
70 | @triton.jit
71 | def unflatten_kernel(
72 | # pointers to matrices
73 | OUT,
74 | LSE,
75 | CU_SEQLENS,
76 | # strides
77 | stride_out_batch,
78 | stride_out_nheads,
79 | stride_out_seqlen,
80 | stride_lse_seqlen,
81 | stride_lse_nheads,
82 | # meta-parameters
83 | BLOCK_M: tl.constexpr,
84 | ):
85 | pid_m = tl.program_id(axis=0)
86 | pid_batch = tl.program_id(axis=1)
87 | pid_head = tl.program_id(axis=2)
88 |
89 | start_idx = tl.load(CU_SEQLENS + pid_batch)
90 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
91 | LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen
92 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
93 |
94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
95 |
96 | LSE = LSE + rm[:, None] * stride_lse_seqlen
97 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)
98 |
99 | OUT = OUT + rm[:, None] * stride_out_seqlen
100 | tl.store(OUT, x, mask=rm[:, None] < seqlen)
101 |
102 |
103 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
104 | """
105 | Arguments:
106 | lse: (total_seqlen, nheads, 1)
107 | cu_seqlens: (batch_size + 1,)
108 | max_seqlen: int
109 | Return:
110 | unflatten_lse: (batch_size, nheads, max_seqlen)
111 | """
112 | lse = lse.unsqueeze(dim=-1)
113 | batch_size = len(cu_seqlens) - 1
114 | nheads = lse.shape[1]
115 | output = torch.empty(
116 | (batch_size, nheads, max_seqlen),
117 | dtype=lse.dtype,
118 | device=lse.device,
119 | )
120 |
121 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
122 | BLOCK_M = 4
123 |
124 | with torch.cuda.device(lse.device.index):
125 | unflatten_kernel[grid](
126 | output,
127 | lse,
128 | cu_seqlens,
129 | # strides
130 | output.stride(0),
131 | output.stride(1),
132 | output.stride(2),
133 | lse.stride(0),
134 | lse.stride(1),
135 | BLOCK_M,
136 | )
137 | return output
138 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import torch.nn.functional as F
6 | import inspect
7 | from functools import cache
8 |
9 |
10 | __all__ = ["update_out_and_lse", "RingComm", "get_default_args"]
11 |
12 |
13 | @cache
14 | def get_default_args(func):
15 | spec = inspect.getfullargspec(func)
16 | defaults = spec.defaults if spec.defaults is not None else ()
17 | padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults
18 | args = dict(zip(spec.args, padded_defaults))
19 | if "softcap" in args:
20 | args["softcap"] = 0.0
21 | return args
22 |
23 |
24 | @torch.jit.script
25 | def _update_out_and_lse(
26 | out: torch.Tensor,
27 | lse: torch.Tensor,
28 | block_out: torch.Tensor,
29 | block_lse: torch.Tensor,
30 | ) -> Tuple[torch.Tensor, torch.Tensor]:
31 |
32 | block_out = block_out.to(torch.float32)
33 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
34 |
35 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
36 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
37 | # For additional context and discussion, please refer to:
38 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
39 | out = out - F.sigmoid(block_lse - lse) * (out - block_out)
40 | lse = lse - F.logsigmoid(lse - block_lse)
41 |
42 | return out, lse
43 |
44 |
45 | def update_out_and_lse(
46 | out: Optional[torch.Tensor],
47 | lse: Optional[torch.Tensor],
48 | block_out: torch.Tensor,
49 | block_lse: torch.Tensor,
50 | slice_=None,
51 | ) -> Tuple[torch.Tensor, torch.Tensor]:
52 | if out is None:
53 | if slice_ is not None:
54 | raise RuntimeError("first update_out_and_lse should not pass slice_ args")
55 | out = block_out.to(torch.float32)
56 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
57 | elif slice_ is not None:
58 | slice_out, slice_lse = out[slice_], lse[slice_]
59 | slice_out, slice_lse = _update_out_and_lse(
60 | slice_out, slice_lse, block_out, block_lse
61 | )
62 | out[slice_], lse[slice_] = slice_out, slice_lse
63 | else:
64 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
65 | return out, lse
66 |
67 |
68 | @torch.jit.script
69 | def flatten_varlen_lse(lse, cu_seqlens):
70 | new_lse = []
71 | for i in range(len(cu_seqlens) - 1):
72 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
73 | new_lse.append(lse[i, :, : end - start])
74 | return torch.cat(new_lse, dim=1)
75 |
76 |
77 | @torch.jit.script
78 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
79 | num_seq = len(cu_seqlens) - 1
80 | num_head = lse.shape[-2]
81 | new_lse = torch.empty(
82 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device
83 | )
84 | for i in range(num_seq):
85 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
86 | new_lse[i, : end - start] = lse[start:end]
87 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()
88 |
89 |
90 | class RingComm:
91 | def __init__(self, process_group: dist.ProcessGroup):
92 | self._process_group = process_group
93 | self._ops = []
94 | self.rank = dist.get_rank(self._process_group)
95 | self.world_size = dist.get_world_size(self._process_group)
96 | self._reqs = None
97 |
98 | self.send_rank = (self.rank + 1) % self.world_size
99 | self.recv_rank = (self.rank - 1) % self.world_size
100 |
101 | if process_group is not None:
102 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
103 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
104 |
105 | def send_recv(
106 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
107 | ) -> torch.Tensor:
108 | if recv_tensor is None:
109 | res = torch.empty_like(to_send)
110 | else:
111 | res = recv_tensor
112 |
113 | send_op = dist.P2POp(
114 | dist.isend, to_send, self.send_rank, group=self._process_group
115 | )
116 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
117 | self._ops.append(send_op)
118 | self._ops.append(recv_op)
119 | return res
120 |
121 | def commit(self):
122 | if self._reqs is not None:
123 | raise RuntimeError("commit called twice")
124 | self._reqs = dist.batch_isend_irecv(self._ops)
125 |
126 | def wait(self):
127 | if self._reqs is None:
128 | raise RuntimeError("wait called before commit")
129 | for req in self._reqs:
130 | req.wait()
131 | self._reqs = None
132 | self._ops = []
133 |
134 |
135 | class AllGatherComm:
136 | def __init__(self, group=None) -> None:
137 | self.group = group
138 | self.handles = []
139 |
140 | def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor):
141 | handle = dist.all_gather_into_tensor(
142 | output_tensor, input_tensor, group=self.group, async_op=True
143 | )
144 | self.handles.append(handle)
145 |
146 | def wait(self):
147 | for handle in self.handles:
148 | handle.wait()
149 | self.handles = []
150 |
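A minimal sketch (not part of the file above) of the call order that `RingComm` expects, mirroring the loops in `ring_flash_attn_forward`: enqueue `send_recv`, launch with `commit()`, overlap compute, then `wait()` before reading the received buffers. The shapes and surrounding setup are illustrative only.

```python
import torch
import torch.distributed as dist
from ring_flash_attn.utils import RingComm

def ring_pass_kv(k: torch.Tensor, v: torch.Tensor, group: dist.ProcessGroup):
    """Illustrative only: rotate k/v once around the ring, overlapping comm and compute."""
    comm = RingComm(group)
    for step in range(comm.world_size):
        if step + 1 != comm.world_size:
            next_k = comm.send_recv(k)  # enqueue an isend/irecv pair
            next_v = comm.send_recv(v)
            comm.commit()               # launch the batched P2P ops
        # ... compute attention against the current k/v block here ...
        if step + 1 != comm.world_size:
            comm.wait()                 # received buffers are valid only after wait()
            k, v = next_k, next_v
    return k, v
```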
--------------------------------------------------------------------------------
/train/ring-flash-attention/build/lib/ring_flash_attn/zigzag_ring_flash_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
4 | from .utils import RingComm, update_out_and_lse, get_default_args
5 |
6 |
7 | def zigzag_ring_flash_attn_forward(
8 | process_group,
9 | q: torch.Tensor,
10 | k: torch.Tensor,
11 | v: torch.Tensor,
12 | softmax_scale,
13 | dropout_p=0,
14 | causal=True,
15 | window_size=(-1, -1),
16 | alibi_slopes=None,
17 | deterministic=False,
18 | ):
19 | assert causal == True, "zigzag ring is meaningless for causal=False"
20 | comm = RingComm(process_group)
21 |
22 | block_seq_len = q.shape[1] // 2
23 | q1 = q[:, block_seq_len:]
24 |
25 | out = None
26 | lse = None
27 | next_k, next_v = None, None
28 |
29 | def forward(q, k, v, causal):
30 | params = get_default_args(_flash_attn_forward).copy()
31 | params.update(
32 | {
33 | "q": q,
34 | "k": k,
35 | "v": v,
36 | "dropout_p": dropout_p,
37 | "softmax_scale": softmax_scale,
38 | "causal": causal,
39 | "window_size": window_size,
40 | "alibi_slopes": alibi_slopes,
41 | "return_softmax": True and dropout_p > 0,
42 | }
43 | )
44 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
45 | return block_out, block_lse
46 |
47 | for step in range(comm.world_size):
48 | if step + 1 != comm.world_size:
49 | next_k: torch.Tensor = comm.send_recv(k)
50 | next_v: torch.Tensor = comm.send_recv(v)
51 | comm.commit()
52 |
53 | if step == 0:
54 | block_out, block_lse = forward(q, k, v, causal=True)
55 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
56 | elif step <= comm.rank:
57 | k0 = k[:, :block_seq_len]
58 | v0 = v[:, :block_seq_len]
59 | block_out, block_lse = forward(q, k0, v0, causal=False)
60 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
61 | else:
62 | block_out, block_lse = forward(q1, k, v, causal=False)
63 | out, lse = update_out_and_lse(
64 | out,
65 | lse,
66 | block_out,
67 | block_lse,
68 | slice_=(slice(None), slice(block_seq_len, None)),
69 | )
70 |
71 | if step + 1 != comm.world_size:
72 | comm.wait()
73 | k = next_k
74 | v = next_v
75 |
76 | out = out.to(q.dtype)
77 | lse = lse.squeeze(dim=-1).transpose(1, 2)
78 | return out, lse
79 |
80 |
81 | def zigzag_ring_flash_attn_backward(
82 | process_group,
83 | dout,
84 | q,
85 | k,
86 | v,
87 | out,
88 | softmax_lse,
89 | softmax_scale,
90 | dropout_p=0,
91 | causal=True,
92 | window_size=(-1, -1),
93 | alibi_slopes=None,
94 | deterministic=False,
95 | ):
96 | assert causal == True, "zigzag ring is meaningless for causal=False"
97 | kv_comm = RingComm(process_group)
98 | d_kv_comm = RingComm(process_group)
99 | dq, dk, dv = None, None, None
100 | next_dk, next_dv = None, None
101 | next_k, next_v = None, None
102 | dk_comm_buffer, dv_comm_buffer = None, None
103 |
104 | dout1 = dout.chunk(2, dim=1)[1]
105 | q1 = q.chunk(2, dim=1)[1]
106 | out1 = out.chunk(2, dim=1)[1]
107 | softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
108 | block_seq_len = q.shape[1] // 2
109 |
110 |     # repeatedly allocating buffers may be slow...
111 | dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
112 | dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
113 | dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
114 |
115 | def backward(dout, q, k, v, out, softmax_lse, causal):
116 | seqlen_q = q.shape[1]
117 | seqlen_kv = k.shape[1]
118 | params = get_default_args(_flash_attn_backward).copy()
119 | params.update(
120 | {
121 | "dout": dout,
122 | "q": q,
123 | "k": k,
124 | "v": v,
125 | "out": out,
126 | "softmax_lse": softmax_lse,
127 | "dq": dq_buffer[:, :seqlen_q],
128 | "dk": dk_buffer[:, :seqlen_kv],
129 | "dv": dv_buffer[:, :seqlen_kv],
130 | "dropout_p": dropout_p,
131 | "softmax_scale": softmax_scale,
132 | "causal": causal,
133 | "window_size": window_size,
134 | "alibi_slopes": alibi_slopes,
135 | "deterministic": deterministic,
136 | }
137 | )
138 | _flash_attn_backward(**params)
139 |
140 | for step in range(kv_comm.world_size):
141 | if step + 1 != kv_comm.world_size:
142 | next_k = kv_comm.send_recv(k)
143 | next_v = kv_comm.send_recv(v)
144 | kv_comm.commit()
145 |
146 | if step == 0:
147 | backward(dout, q, k, v, out, softmax_lse, causal=True)
148 | dq = dq_buffer.to(torch.float32)
149 | dk = dk_buffer.to(torch.float32)
150 | dv = dv_buffer.to(torch.float32)
151 | else:
152 | if step <= kv_comm.rank:
153 | k0 = k[:, :block_seq_len]
154 | v0 = v[:, :block_seq_len]
155 | backward(dout, q, k0, v0, out, softmax_lse, causal=False)
156 | dq += dq_buffer
157 | else:
158 | backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
159 | # always use the first half in dq_buffer.
160 | dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]
161 |
162 | d_kv_comm.wait()
163 | dk_comm_buffer, dv_comm_buffer = dk, dv
164 | dk, dv = next_dk, next_dv
165 |
166 | if step <= kv_comm.rank:
167 | dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
168 | dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
169 | else:
170 | dk += dk_buffer
171 | dv += dv_buffer
172 |
173 | if step + 1 != kv_comm.world_size:
174 | kv_comm.wait()
175 | k = next_k
176 | v = next_v
177 |
178 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
179 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
180 | d_kv_comm.commit()
181 |
182 | d_kv_comm.wait()
183 |
184 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
185 |
186 |
187 | class ZigZagRingFlashAttnFunc(torch.autograd.Function):
188 | @staticmethod
189 | def forward(
190 | ctx,
191 | q,
192 | k,
193 | v,
194 | dropout_p,
195 | softmax_scale,
196 | causal,
197 | window_size,
198 | alibi_slopes,
199 | deterministic,
200 | return_softmax,
201 | group,
202 | ):
203 | if softmax_scale is None:
204 | softmax_scale = q.shape[-1] ** (-0.5)
205 |
206 | assert alibi_slopes is None
207 | k = k.contiguous()
208 | v = v.contiguous()
209 | out, softmax_lse = zigzag_ring_flash_attn_forward(
210 | group,
211 | q,
212 | k,
213 | v,
214 | softmax_scale=softmax_scale,
215 | dropout_p=dropout_p,
216 | causal=causal,
217 | window_size=window_size,
218 | alibi_slopes=alibi_slopes,
219 | deterministic=False,
220 | )
221 | # this should be out_padded
222 | ctx.save_for_backward(q, k, v, out, softmax_lse)
223 | ctx.dropout_p = dropout_p
224 | ctx.softmax_scale = softmax_scale
225 | ctx.causal = causal
226 | ctx.window_size = window_size
227 | ctx.alibi_slopes = alibi_slopes
228 | ctx.deterministic = deterministic
229 | ctx.group = group
230 | return out if not return_softmax else (out, softmax_lse, None)
231 |
232 | @staticmethod
233 | def backward(ctx, dout, *args):
234 | q, k, v, out, softmax_lse = ctx.saved_tensors
235 | dq, dk, dv = zigzag_ring_flash_attn_backward(
236 | ctx.group,
237 | dout,
238 | q,
239 | k,
240 | v,
241 | out,
242 | softmax_lse,
243 | softmax_scale=ctx.softmax_scale,
244 | dropout_p=ctx.dropout_p,
245 | causal=ctx.causal,
246 | window_size=ctx.window_size,
247 | alibi_slopes=ctx.alibi_slopes,
248 | deterministic=ctx.deterministic,
249 | )
250 | return dq, dk, dv, None, None, None, None, None, None, None, None
251 |
252 |
253 | def zigzag_ring_flash_attn_qkvpacked_func(
254 | qkv,
255 | dropout_p=0.0,
256 | softmax_scale=None,
257 | causal=False,
258 | window_size=(-1, -1),
259 | alibi_slopes=None,
260 | deterministic=False,
261 | return_attn_probs=False,
262 | group=None,
263 | ):
264 | return ZigZagRingFlashAttnFunc.apply(
265 | qkv[:, :, 0],
266 | qkv[:, :, 1],
267 | qkv[:, :, 2],
268 | dropout_p,
269 | softmax_scale,
270 | causal,
271 | window_size,
272 | alibi_slopes,
273 | deterministic,
274 | return_attn_probs,
275 | group,
276 | )
277 |
278 |
279 | def zigzag_ring_flash_attn_kvpacked_func(
280 | q,
281 | kv,
282 | dropout_p=0.0,
283 | softmax_scale=None,
284 | causal=False,
285 | window_size=(-1, -1),
286 | alibi_slopes=None,
287 | deterministic=False,
288 | return_attn_probs=False,
289 | group=None,
290 | ):
291 | return ZigZagRingFlashAttnFunc.apply(
292 | q,
293 | kv[:, :, 0],
294 | kv[:, :, 1],
295 | dropout_p,
296 | softmax_scale,
297 | causal,
298 | window_size,
299 | alibi_slopes,
300 | deterministic,
301 | return_attn_probs,
302 | group,
303 | )
304 |
305 |
306 | def zigzag_ring_flash_attn_func(
307 | q,
308 | k,
309 | v,
310 | dropout_p=0.0,
311 | softmax_scale=None,
312 | causal=False,
313 | window_size=(-1, -1),
314 | alibi_slopes=None,
315 | deterministic=False,
316 | return_attn_probs=False,
317 | group=None,
318 | ):
319 | return ZigZagRingFlashAttnFunc.apply(
320 | q,
321 | k,
322 | v,
323 | dropout_p,
324 | softmax_scale,
325 | causal,
326 | window_size,
327 | alibi_slopes,
328 | deterministic,
329 | return_attn_probs,
330 | group,
331 | )
332 |
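As a hedged illustration of how inputs are expected to be laid out for the zigzag functions above (assuming the usual zigzag scheme, where the full sequence is split into `2 * world_size` chunks and rank `r` keeps chunks `r` and `2 * world_size - 1 - r`; not code from this file):

```python
import torch

def zigzag_shard(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Illustrative sharding along dim=1 for the assumed zigzag layout."""
    chunks = x.chunk(2 * world_size, dim=1)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)

# e.g. world_size=4: rank 0 holds chunks (0, 7), rank 1 (1, 6), rank 2 (2, 5), rank 3 (3, 4),
# which is what keeps the causal-attention workload roughly balanced across ranks.
```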
--------------------------------------------------------------------------------
/train/ring-flash-attention/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "ring-flash-attn"
7 | version = "0.1.1"
8 | authors = [
9 | { name="zhuzilin", email="zhuzilinallen@gmail.com" },
10 | ]
11 | description = "Ring attention implementation with flash attention."
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 |
20 | [project.urls]
21 | Homepage = "https://github.com/zhuzilin/ring-flash-attention"
22 | Issues = "https://github.com/zhuzilin/ring-flash-attention/issues"
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.2
2 | Name: ring-flash-attn
3 | Version: 0.1.1
4 | Summary: Ring attention implementation with flash attention.
5 | Home-page: https://github.com/zhuzilin/ring-flash-attention
6 | Author: zhuzilin
7 | Author-email: zhuzilin
8 | Project-URL: Homepage, https://github.com/zhuzilin/ring-flash-attention
9 | Project-URL: Issues, https://github.com/zhuzilin/ring-flash-attention/issues
10 | Classifier: Programming Language :: Python :: 3
11 | Classifier: License :: OSI Approved :: MIT License
12 | Classifier: Operating System :: OS Independent
13 | Requires-Python: >=3.8
14 | Description-Content-Type: text/markdown
15 | License-File: LICENSE
16 | Dynamic: author
17 | Dynamic: home-page
18 |
19 | ## Ring Flash Attention
20 |
21 | This repo implements [RingAttention](https://github.com/lhao499/RingAttention) using [FlashAttention](https://github.com/Dao-AILab/flash-attention). The current implementation supports:
22 |
23 | - varlen (packing samples) api, corresponding to `flash_attn_varlen_func`:
24 | - `ring_flash_attn_varlen_func`: A basic implementation of ring attention.
25 |   - `zigzag_ring_flash_attn_varlen_func`: A more compute-balanced version of ring attention. More details in [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
26 | - `llama3_flash_attn_varlen_func`: The context parallelism used in [llama3 tech report](https://arxiv.org/abs/2407.21783) with extra design for varlen and low memory overhead. Although technically not ring attention, this is **recommended** for most varlen use cases, as it offers a less intrusive alternative for training frameworks with fewer data manipulations and better arithmetic precision.
27 | - batch api, corresponding to `flash_attn_func`:
28 | - `ring_flash_attn_func`: basic ring attention.
29 |   - `zigzag_ring_flash_attn_func`: A more compute-balanced version of ring attention, see [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2).
30 |   - `stripe_flash_attn_func`: A stripe attention version of `ring_flash_attn_func`; the block size is set to 1 so the flash_attn api can be reused, see: https://arxiv.org/abs/2311.09431
31 | - [huggingface model adapter](ring_flash_attn/adapters/hf_adapter.py). Here is an example of using the adapter: [OpenRLHF/OpenRLHF/pull#439](https://github.com/OpenRLHF/OpenRLHF/pull/439/files). A minimal usage sketch also follows the notes below.
32 |
33 | Note that
34 |
35 | - Each function includes `*_func`, `*_kvpacked_func`, `*_qkvpacked_func` variants.
36 | - The varlen versions (except the llama3 version) only support passing one `cu_seqlens`.
37 |
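For orientation, a minimal sketch of how the huggingface adapter could be wired into a training loop (not code from this repo; the process-group choice and the `cu_seqlens` values are illustrative assumptions):

```python
import torch
import torch.distributed as dist
from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params

dist.init_process_group("nccl")
sp_group = dist.group.WORLD  # assumption: all ranks form a single sequence-parallel ring

# Patch transformers' _flash_attention_forward once, before the first forward pass.
substitute_hf_flash_attn(process_group=sp_group, heads_k_stride=1)

# For each packed batch: cu_seqlens marks sample boundaries over the full packed sequence,
# e.g. three samples of length 1024 / 2048 / 512 tokens.
cu_seqlens = torch.tensor([0, 1024, 3072, 3584], dtype=torch.int32, device="cuda")
update_ring_flash_attn_params(cu_seqlens, sp_group)
# model(**local_shard)  # each rank then runs forward on its shard of the packed sequence
```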
38 | ## Performance Summary
39 |
40 | The following table summarizes the performance of the implemented APIs:
41 |
42 | | batch api | GPU | theoretic flash_attn | ring_attn | zigzag_ring | stripe_attn |
43 | | -------------------- | ------- | ----------------------------- | ------------- | --------------- | --------------- |
44 | | fwd only (iter/sec) | 8xH800 | 591.5 / 8 = 73.9 | 38.5 | 63.0 | 55.0 |
45 | | | | | 52.1% | **85.2%** | 74.4% |
46 | | fwd + bwd (iter/sec) | 8xH800 | 154.7 / 8 = 19.3 | 10.4 | 17.4 | 16.0 |
47 | | | | | 53.9% | **90.2%** | 82.9% |
48 | | fwd only (iter/sec) | 8xA100 | 373.4 / 8 = 46.7 | 24.0 | 38.2 | 32.5 |
49 | | | | | 51.4% | **81.7%** | 69.6% |
50 | | fwd + bwd (iter/sec) | 8xA100 | 94.7 / 8 = 11.8 | 6.2 | 10.6 | 9.75 |
51 | | | | | 52.5% | **89.8%** | 82.6% |
52 | | **varlen api** | **GPU** | **theoretic flash_attn** | **ring_attn** | **zigzag_ring** | **llama3_attn** |
53 | | fwd only (iter/sec) | 8xH800 | 852.4 / 8 = 106.6 | 52.4 | 74.8 | 60.8 |
54 | | | | | 49.1% | **70.2%** | 57.0% |
55 | | fwd + bwd (iter/sec) | 8xH800 | 225.4 / 8 = 28.2 | 14.4 | 21.4 | 16.4 |
56 | | | | | 51.1% | **75.9%** | 58.1% |
57 | | fwd only (iter/sec) | 8xA100 | 532.3 / 8 = 66.5 | 33.1 | 47.9 | 34.3 |
58 | | | | | 49.8% | **72.0%** | 51.6% |
59 | | fwd + bwd (iter/sec) | 8xA100 | 133.8 / 8 = 16.7 | 8.7 | 13.4 | 9.7 |
60 | | | | | 52.1% | **80.2%** | 58.0% |
61 |
62 | Note that
63 |
64 | - The code of the benchmark is in [benchmark](benchmark/); its configuration matches the [Meta-Llama-3.1-8B](https://huggingface.co/NousResearch/Meta-Llama-3.1-8B/blob/main/config.json) setting, with a total sequence length of 8k per GPU.
65 | - When running the benchmark with 8 GPUs, the flash attn baseline performs only 1/8 of ring attention's computation: flash attn computes `8*1^2` units of causal attention, while ring attn computes `1*8^2`. This is why the theoretic column divides the flash attn throughput by 8 (a worked example follows these notes).
66 | - NVLink between GPUs is required for high performance.
67 | - Please remember to adapt the RoPE offset for the different apis.
68 |
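To make the "theoretic" column concrete, here is the arithmetic behind the first 8xH800 row, using only numbers from the table above:

```python
# Worked example for the first 8xH800 row of the table above.
flash_attn_iters = 591.5          # iter/sec for plain flash_attn, one 8k sequence per GPU
theoretic = flash_attn_iters / 8  # ring attention over 8 GPUs does 8x the attention work
print(round(theoretic, 1))        # 73.9

zigzag_iters = 63.0
print(f"{zigzag_iters / theoretic:.1%}")  # ~85.2% of the theoretical throughput
```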
69 | ### Installation
70 |
71 | ```bash
72 | pip install ring-flash-attn
73 | ```
74 |
75 | or use the following command to build from source:
76 |
77 | ```bash
78 | git clone https://github.com/zhuzilin/ring-flash-attention.git
79 | cd ring-flash-attention
80 | pip install .
81 | ```
82 |
83 | ### TODOs
84 |
85 | - [x] Implement `ring_flash_attn_varlen_qkvpacked_func`
86 | - [x] Implement `zigzag_ring_flash_attn_qkvpacked_func` [issue#2](https://github.com/zhuzilin/ring-flash-attention/issues/2)
87 | - [x] Implement `stripe_flash_attn_qkvpacked_func`
88 | - [x] Implement `zigzag_ring_flash_attn_varlen_qkvpacked_func`
89 | - [x] Implement `*_kvpacked_func` and `*_func` variant for all APIs
90 | - [x] ~~Optimize `*_varlen_func`~~ Implement `llama3_flash_attn_varlen_func`
91 | - [x] ~~Add an example to train llama~~ Implement adapter for huggingface model
92 | - [ ] Implement `zigzag_llama3_flash_attn_varlen_func`
93 |
94 | ### Test
95 |
96 | ```bash
97 | torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
98 | torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
99 | torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
100 | torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
101 | torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
102 | torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py
103 | ```
104 |
105 | ### Benchmark
106 |
107 | ```bash
108 | torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
109 | torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py
110 | ```
111 |
112 | ### Known Limitations
113 |
114 | There are some arithmetic errors with the current implementation. The reason is probably that flash attention returns bf16 values for each block, so we cannot accumulate them against the original fp32 values without losing precision (a small numeric illustration follows below).
115 |
116 | Also, because we need to keep an extra fp32 buffer during computation, the memory usage is higher than the theoretical limit.
117 |
118 | Also,
119 |
120 | - dropout is not supported at the moment, because it's hard to save all the rng_states.
121 | - window_size is not supported, because it will be really tricky to implement a varlen version with window_size.
122 |
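A tiny, self-contained illustration of the bf16 accumulation issue described above (not code from this repo): accumulating per-block bf16 outputs drifts away from an fp32 accumulation, because bf16 keeps only ~8 bits of mantissa.

```python
import torch

torch.manual_seed(0)
blocks = [torch.randn(4096) for _ in range(8)]  # stand-ins for per-block attention outputs

acc_fp32 = sum(b.float() for b in blocks)             # accumulate in fp32
acc_bf16 = sum(b.bfloat16() for b in blocks).float()  # per-block bf16 values, as flash attention returns them

print((acc_fp32 - acc_bf16).abs().max())  # small but nonzero drift
```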
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | pyproject.toml
4 | setup.py
5 | ring_flash_attn/__init__.py
6 | ring_flash_attn/llama3_flash_attn_varlen.py
7 | ring_flash_attn/ring_flash_attn.py
8 | ring_flash_attn/ring_flash_attn_varlen.py
9 | ring_flash_attn/stripe_flash_attn.py
10 | ring_flash_attn/triton_utils.py
11 | ring_flash_attn/utils.py
12 | ring_flash_attn/zigzag_ring_flash_attn.py
13 | ring_flash_attn/zigzag_ring_flash_attn_varlen.py
14 | ring_flash_attn.egg-info/PKG-INFO
15 | ring_flash_attn.egg-info/SOURCES.txt
16 | ring_flash_attn.egg-info/dependency_links.txt
17 | ring_flash_attn.egg-info/top_level.txt
18 | ring_flash_attn/adapters/__init__.py
19 | ring_flash_attn/adapters/hf_adapter.py
20 | test/test_llama3_flash_attn_varlen_func.py
21 | test/test_llama3_prepare_cu_seqlens.py
22 | test/test_ring_flash_attn_func.py
23 | test/test_ring_flash_attn_varlen_func.py
24 | test/test_stripe_flash_attn_func.py
25 | test/test_triton_kernels.py
26 | test/test_zigzag_ring_flash_attn_func.py
27 | test/test_zigzag_ring_flash_attn_varlen_func.py
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | ring_flash_attn
2 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/.ipynb_checkpoints/utils-checkpoint.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import torch.nn.functional as F
6 | import inspect
7 | from functools import cache
8 |
9 |
10 | __all__ = ["update_out_and_lse", "RingComm", "get_default_args"]
11 |
12 |
13 | @cache
14 | def get_default_args(func):
15 | spec = inspect.getfullargspec(func)
16 | defaults = spec.defaults if spec.defaults is not None else ()
17 | padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults
18 | args = dict(zip(spec.args, padded_defaults))
19 | if "softcap" in args:
20 | args["softcap"] = 0.0
21 | return args
22 |
23 |
24 | @torch.jit.script
25 | def _update_out_and_lse(
26 | out: torch.Tensor,
27 | lse: torch.Tensor,
28 | block_out: torch.Tensor,
29 | block_lse: torch.Tensor,
30 | ) -> Tuple[torch.Tensor, torch.Tensor]:
31 |
32 | block_out = block_out.to(torch.float32)
33 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
34 |
35 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
36 | # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
37 | # For additional context and discussion, please refer to:
38 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
39 | out = out - F.sigmoid(block_lse - lse) * (out - block_out)
40 | lse = lse - F.logsigmoid(lse - block_lse)
41 |
42 | return out, lse
43 |
44 |
45 | def update_out_and_lse(
46 | out: Optional[torch.Tensor],
47 | lse: Optional[torch.Tensor],
48 | block_out: torch.Tensor,
49 | block_lse: torch.Tensor,
50 | slice_=None,
51 | ) -> Tuple[torch.Tensor, torch.Tensor]:
52 | if out is None:
53 | if slice_ is not None:
54 | raise RuntimeError("first update_out_and_lse should not pass slice_ args")
55 | out = block_out.to(torch.float32)
56 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
57 | elif slice_ is not None:
58 | slice_out, slice_lse = out[slice_], lse[slice_]
59 | slice_out, slice_lse = _update_out_and_lse(
60 | slice_out, slice_lse, block_out, block_lse
61 | )
62 | out[slice_], lse[slice_] = slice_out, slice_lse
63 | else:
64 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
65 | return out, lse
66 |
67 |
68 | @torch.jit.script
69 | def flatten_varlen_lse(lse, cu_seqlens):
70 | new_lse = []
71 | for i in range(len(cu_seqlens) - 1):
72 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
73 | new_lse.append(lse[i, :, : end - start])
74 | return torch.cat(new_lse, dim=1)
75 |
76 |
77 | @torch.jit.script
78 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
79 | num_seq = len(cu_seqlens) - 1
80 | num_head = lse.shape[-2]
81 | new_lse = torch.empty(
82 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device
83 | )
84 | for i in range(num_seq):
85 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
86 | new_lse[i, : end - start] = lse[start:end]
87 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()
88 |
89 |
90 | class RingComm:
91 | def __init__(self, process_group: dist.ProcessGroup):
92 | self._process_group = process_group
93 | self._ops = []
94 | self.rank = dist.get_rank(self._process_group)
95 | self.world_size = dist.get_world_size(self._process_group)
96 | self._reqs = None
97 |
98 | self.send_rank = (self.rank + 1) % self.world_size
99 | self.recv_rank = (self.rank - 1) % self.world_size
100 |
101 | if process_group is not None:
102 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
103 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
104 |
105 | def send_recv(
106 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
107 | ) -> torch.Tensor:
108 | if recv_tensor is None:
109 | res = torch.empty_like(to_send)
110 | else:
111 | res = recv_tensor
112 |
113 | send_op = dist.P2POp(
114 | dist.isend, to_send, self.send_rank, group=self._process_group
115 | )
116 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
117 | self._ops.append(send_op)
118 | self._ops.append(recv_op)
119 | return res
120 |
121 | def commit(self):
122 | if self._reqs is not None:
123 | raise RuntimeError("commit called twice")
124 | self._reqs = dist.batch_isend_irecv(self._ops)
125 |
126 | def wait(self):
127 | if self._reqs is None:
128 | raise RuntimeError("wait called before commit")
129 | for req in self._reqs:
130 | req.wait()
131 | self._reqs = None
132 | self._ops = []
133 |
134 |
135 | class AllGatherComm:
136 | def __init__(self, group=None) -> None:
137 | self.group = group
138 | self.handles = []
139 |
140 | def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor):
141 | handle = dist.all_gather_into_tensor(
142 | output_tensor, input_tensor, group=self.group, async_op=True
143 | )
144 | self.handles.append(handle)
145 |
146 | def wait(self):
147 | for handle in self.handles:
148 | handle.wait()
149 | self.handles = []
150 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama3_flash_attn_varlen import (
2 | llama3_flash_attn_prepare_cu_seqlens,
3 | llama3_flash_attn_varlen_func,
4 | llama3_flash_attn_varlen_kvpacked_func,
5 | llama3_flash_attn_varlen_qkvpacked_func,
6 | )
7 | from .ring_flash_attn import (
8 | ring_flash_attn_func,
9 | ring_flash_attn_kvpacked_func,
10 | ring_flash_attn_qkvpacked_func,
11 | )
12 | from .ring_flash_attn_varlen import (
13 | ring_flash_attn_varlen_func,
14 | ring_flash_attn_varlen_kvpacked_func,
15 | ring_flash_attn_varlen_qkvpacked_func,
16 | )
17 | from .zigzag_ring_flash_attn import (
18 | zigzag_ring_flash_attn_func,
19 | zigzag_ring_flash_attn_kvpacked_func,
20 | zigzag_ring_flash_attn_qkvpacked_func,
21 | )
22 | from .zigzag_ring_flash_attn_varlen import (
23 | zigzag_ring_flash_attn_varlen_func,
24 | zigzag_ring_flash_attn_varlen_kvpacked_func,
25 | zigzag_ring_flash_attn_varlen_qkvpacked_func,
26 | )
27 | from .stripe_flash_attn import (
28 | stripe_flash_attn_func,
29 | stripe_flash_attn_kvpacked_func,
30 | stripe_flash_attn_qkvpacked_func,
31 | )
32 | from .adapters import (
33 | substitute_hf_flash_attn,
34 | update_ring_flash_attn_params,
35 | )
36 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/adapters/.ipynb_checkpoints/hf_adapter-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import inspect
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import transformers
8 | import transformers.modeling_flash_attention_utils
9 | from transformers.modeling_flash_attention_utils import (
10 | _flash_supports_window_size,
11 | is_flash_attn_greater_or_equal,
12 | )
13 | from ..llama3_flash_attn_varlen import (
14 | llama3_flash_attn_varlen_func,
15 | llama3_flash_attn_prepare_cu_seqlens,
16 | )
17 |
18 |
19 | DATA_PARAMS = {}
20 | RING_ATTN_SWITCH = True
21 |
22 |
23 | def check_params(f1, f2):
24 | return len(inspect.signature(f1).parameters) == len(
25 | inspect.signature(f2).parameters
26 | )
27 |
28 |
29 | def update_ring_flash_attn_params(
30 | cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
31 | ):
32 | world_size = dist.get_world_size(group=process_group)
33 | rank = dist.get_rank(group=process_group)
34 | (
35 | cu_seqlens_q,
36 | cu_seqlens_k,
37 | max_seqlen_q,
38 | max_seqlen_k,
39 | local_k_slice,
40 | ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
41 | DATA_PARAMS.update(
42 | {
43 | "cu_seqlens_q": cu_seqlens_q,
44 | "cu_seqlens_k": cu_seqlens_k,
45 | "max_seqlen_q": max_seqlen_q,
46 | "max_seqlen_k": max_seqlen_k,
47 | "local_k_slice": local_k_slice,
48 | }
49 | )
50 |
51 |
52 | def use_ring_attn(flag):
53 | global RING_ATTN_SWITCH
54 | RING_ATTN_SWITCH = flag
55 |
56 |
57 | def create_ring_flash_attention_forward(
58 | process_group: dist.ProcessGroup, heads_k_stride: int
59 | ):
60 | def _flash_attention_forward(
61 | query_states: torch.Tensor,
62 | key_states: torch.Tensor,
63 | value_states: torch.Tensor,
64 | attention_mask: torch.Tensor,
65 | query_length: int,
66 | is_causal: bool,
67 | dropout: float = 0.0,
68 | position_ids: Optional[torch.Tensor] = None,
69 | softmax_scale: Optional[float] = None,
70 | sliding_window: Optional[int] = None,
71 | use_top_left_mask: bool = False,
72 | softcap: Optional[float] = None,
73 | deterministic: bool = None,
74 | ):
75 | """
76 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
77 | first unpad the input, then computes the attention scores and pad the final attention scores.
78 |
79 | Args:
80 | query_states (`torch.Tensor`):
81 | Input query states to be passed to Flash Attention API
82 | key_states (`torch.Tensor`):
83 | Input key states to be passed to Flash Attention API
84 | value_states (`torch.Tensor`):
85 | Input value states to be passed to Flash Attention API
86 | attention_mask (`torch.Tensor`):
87 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
88 | position of padding tokens and 1 for the position of non-padding tokens.
89 | dropout (`float`):
90 | Attention dropout
91 | softmax_scale (`float`, *optional*):
92 |                 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
93 |             use_top_left_mask (`bool`, defaults to `False`):
94 |                 flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
95 | softcap (`float`, *optional*):
96 | Softcap for the attention logits, used e.g. in gemma2.
97 | deterministic (`bool`, *optional*):
98 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
99 | """
100 | if not use_top_left_mask:
101 | causal = is_causal
102 | else:
103 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
104 | causal = is_causal and query_length != 1
105 |
106 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
107 | use_sliding_windows = (
108 | _flash_supports_window_size
109 | and sliding_window is not None
110 | and key_states.shape[1] > sliding_window
111 | )
112 | flash_kwargs = (
113 | {"window_size": (sliding_window, sliding_window)}
114 | if use_sliding_windows
115 | else {}
116 | )
117 |
118 | if is_flash_attn_greater_or_equal("2.4.1"):
119 | if deterministic is None:
120 | deterministic = (
121 | os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
122 | )
123 | flash_kwargs["deterministic"] = deterministic
124 | assert (
125 | softcap is None
126 | ), "llama3_flash_attn_varlen_func does not support softcap yet."
127 | # flash_kwargs["softcap"] = softcap
128 | flash_kwargs["group"] = process_group
129 |
130 | # not sure why attention_mask can be not None...
131 | assert causal, "only causal attention is supported yet."
132 | batch_size = query_states.size(0)
133 | assert batch_size == 1, "varlen data should be processed in advance."
134 |
135 | attn_output = llama3_flash_attn_varlen_func(
136 | query_states.squeeze(dim=0),
137 | key_states.squeeze(dim=0),
138 | value_states.squeeze(dim=0),
139 | cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
140 | cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
141 | max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
142 | max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
143 | heads_k_stride=heads_k_stride,
144 | local_k_slice=DATA_PARAMS["local_k_slice"],
145 | dropout_p=dropout,
146 | softmax_scale=softmax_scale,
147 | causal=causal,
148 | **flash_kwargs,
149 | )
150 |
151 | attn_output = attn_output.unsqueeze(dim=0)
152 |
153 | return attn_output
154 |
155 | return _flash_attention_forward
156 |
157 |
158 | def substitute_hf_flash_attn(process_group: dist.ProcessGroup, heads_k_stride: int):
159 | try:
160 | # substitute flash attn
161 | old_flash_attention_forward = (
162 | transformers.modeling_flash_attention_utils._flash_attention_forward
163 | )
164 | new_flash_attention_forward = create_ring_flash_attention_forward(
165 | process_group, heads_k_stride
166 | )
167 | assert check_params(old_flash_attention_forward, new_flash_attention_forward)
168 | transformers.modeling_flash_attention_utils._flash_attention_forward = (
169 | lambda *args, **kwargs: (
170 | new_flash_attention_forward(*args, **kwargs)
171 | if RING_ATTN_SWITCH
172 | else old_flash_attention_forward(*args, **kwargs)
173 | )
174 | )
175 |     except Exception as e:
176 |         raise ValueError(
177 |             f"The current transformers version {transformers.__version__} is not supported. "
178 |             "Please run `pip install -U transformers` to upgrade to the latest version. "
179 |             "If the code still fails with the latest version, "
180 |             "please file an issue at https://github.com/zhuzilin/ring-flash-attention/issues"
181 |         ) from e
182 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from .hf_adapter import (
2 | substitute_hf_flash_attn,
3 | update_ring_flash_attn_params,
4 | )
5 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/ring_flash_attn/adapters/hf_adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import inspect
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import transformers
8 | import transformers.modeling_flash_attention_utils
9 | from transformers.modeling_flash_attention_utils import (
10 | _flash_supports_window_size,
11 | is_flash_attn_greater_or_equal,
12 | )
13 | from ..llama3_flash_attn_varlen import (
14 | llama3_flash_attn_varlen_func,
15 | llama3_flash_attn_prepare_cu_seqlens,
16 | )
17 |
18 |
19 | DATA_PARAMS = {}
20 | RING_ATTN_SWITCH = True
21 |
22 |
23 | def check_params(f1, f2):
24 | return len(inspect.signature(f1).parameters) == len(
25 | inspect.signature(f2).parameters
26 | )
27 |
28 |
29 | def update_ring_flash_attn_params(
30 | cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
31 | ):
32 | world_size = dist.get_world_size(group=process_group)
33 | rank = dist.get_rank(group=process_group)
34 | (
35 | cu_seqlens_q,
36 | cu_seqlens_k,
37 | max_seqlen_q,
38 | max_seqlen_k,
39 | local_k_slice,
40 | ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
41 | DATA_PARAMS.update(
42 | {
43 | "cu_seqlens_q": cu_seqlens_q,
44 | "cu_seqlens_k": cu_seqlens_k,
45 | "max_seqlen_q": max_seqlen_q,
46 | "max_seqlen_k": max_seqlen_k,
47 | "local_k_slice": local_k_slice,
48 | }
49 | )
50 |
51 |
52 | def use_ring_attn(flag):
53 | global RING_ATTN_SWITCH
54 | RING_ATTN_SWITCH = flag
55 |
56 |
57 | def create_ring_flash_attention_forward(
58 | process_group: dist.ProcessGroup, heads_k_stride: int
59 | ):
60 | def _flash_attention_forward(
61 | query_states: torch.Tensor,
62 | key_states: torch.Tensor,
63 | value_states: torch.Tensor,
64 | attention_mask: torch.Tensor,
65 | query_length: int,
66 | is_causal: bool,
67 | dropout: float = 0.0,
68 | position_ids: Optional[torch.Tensor] = None,
69 | softmax_scale: Optional[float] = None,
70 | sliding_window: Optional[int] = None,
71 | use_top_left_mask: bool = False,
72 | softcap: Optional[float] = None,
73 | deterministic: bool = None,
74 | ):
75 | """
76 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
 77 |         first unpads the input, then computes the attention scores and pads the final attention scores.
78 |
79 | Args:
80 | query_states (`torch.Tensor`):
81 | Input query states to be passed to Flash Attention API
82 | key_states (`torch.Tensor`):
83 | Input key states to be passed to Flash Attention API
84 | value_states (`torch.Tensor`):
85 | Input value states to be passed to Flash Attention API
86 | attention_mask (`torch.Tensor`):
87 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
88 | position of padding tokens and 1 for the position of non-padding tokens.
89 | dropout (`float`):
90 | Attention dropout
91 | softmax_scale (`float`, *optional*):
 92 |                 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
93 | use_top_left_mask (`bool`, defaults to `False`):
 94 |                 flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
95 | softcap (`float`, *optional*):
96 | Softcap for the attention logits, used e.g. in gemma2.
97 | deterministic (`bool`, *optional*):
98 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
99 | """
100 | if not use_top_left_mask:
101 | causal = is_causal
102 | else:
103 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
104 | causal = is_causal and query_length != 1
105 |
106 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
107 | use_sliding_windows = (
108 | _flash_supports_window_size
109 | and sliding_window is not None
110 | and key_states.shape[1] > sliding_window
111 | )
112 | flash_kwargs = (
113 | {"window_size": (sliding_window, sliding_window)}
114 | if use_sliding_windows
115 | else {}
116 | )
117 |
118 | if is_flash_attn_greater_or_equal("2.4.1"):
119 | if deterministic is None:
120 | deterministic = (
121 | os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
122 | )
123 | flash_kwargs["deterministic"] = deterministic
124 | assert (
125 | softcap is None
126 | ), "llama3_flash_attn_varlen_func does not support softcap yet."
127 | # flash_kwargs["softcap"] = softcap
128 | flash_kwargs["group"] = process_group
129 |
130 |         # not sure why attention_mask can be non-None here...
131 |         assert causal, "only causal attention is currently supported."
132 | batch_size = query_states.size(0)
133 | assert batch_size == 1, "varlen data should be processed in advance."
134 |
135 | attn_output = llama3_flash_attn_varlen_func(
136 | query_states.squeeze(dim=0),
137 | key_states.squeeze(dim=0),
138 | value_states.squeeze(dim=0),
139 | cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
140 | cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
141 | max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
142 | max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
143 | heads_k_stride=heads_k_stride,
144 | local_k_slice=DATA_PARAMS["local_k_slice"],
145 | dropout_p=dropout,
146 | softmax_scale=softmax_scale,
147 | causal=causal,
148 | **flash_kwargs,
149 | )
150 |
151 | attn_output = attn_output.unsqueeze(dim=0)
152 |
153 | return attn_output
154 |
155 | return _flash_attention_forward
156 |
157 |
158 | def substitute_hf_flash_attn(process_group: dist.ProcessGroup, heads_k_stride: int):
159 | try:
160 | # substitute flash attn
161 | old_flash_attention_forward = (
162 | transformers.modeling_flash_attention_utils._flash_attention_forward
163 | )
164 | new_flash_attention_forward = create_ring_flash_attention_forward(
165 | process_group, heads_k_stride
166 | )
167 | assert check_params(old_flash_attention_forward, new_flash_attention_forward)
168 | transformers.modeling_flash_attention_utils._flash_attention_forward = (
169 | lambda *args, **kwargs: (
170 | new_flash_attention_forward(*args, **kwargs)
171 | if RING_ATTN_SWITCH
172 | else old_flash_attention_forward(*args, **kwargs)
173 | )
174 | )
175 |     except Exception as e:
176 |         raise ValueError(
177 |             f"The current transformers version {transformers.__version__} is not supported. "
178 |             "Please run `pip install -U transformers` to upgrade to the latest version. "
179 |             "If the code still fails with the latest version, "
180 |             "please file an issue at https://github.com/zhuzilin/ring-flash-attention/issues"
181 |         ) from e
182 |
--------------------------------------------------------------------------------
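
The adapter above only defines the hooks; it does not show how they are driven. Below is a minimal usage sketch of how `substitute_hf_flash_attn` and `update_ring_flash_attn_params` could be wired around a Hugging Face model in a sequence-parallel training step. The helper names (`setup_ring_attention`, `training_step`) and the `batch` fields are illustrative assumptions for this sketch, not APIs of this repository.

```python
# Illustrative sketch only: assumes a sequence-parallel process group has
# already been created and each packed batch carries its cu_seqlens boundaries.
import torch.distributed as dist
from ring_flash_attn.adapters.hf_adapter import (
    substitute_hf_flash_attn,
    update_ring_flash_attn_params,
)


def setup_ring_attention(sp_group: dist.ProcessGroup, heads_k_stride: int = 1):
    # Patch transformers' _flash_attention_forward once, before the first forward.
    substitute_hf_flash_attn(sp_group, heads_k_stride)


def training_step(model, batch, sp_group: dist.ProcessGroup):
    # Refresh DATA_PARAMS for this batch; cu_seqlens describes the packed
    # sequence boundaries of the full (un-sharded) input.
    update_ring_flash_attn_params(batch["cu_seqlens"], sp_group)
    # Each rank feeds only its local shard of the packed sequence.
    return model(
        input_ids=batch["local_input_ids"],
        position_ids=batch["local_position_ids"],
    )
```

`use_ring_attn(False)` can be called to fall back temporarily to the original flash-attention forward, since the patched function dispatches on `RING_ATTN_SWITCH`.
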
/train/ring-flash-attention/ring_flash_attn/ring_flash_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
4 | from .utils import RingComm, update_out_and_lse, get_default_args
5 |
6 |
7 | def ring_flash_attn_forward(
8 | process_group,
9 | q: torch.Tensor,
10 | k: torch.Tensor,
11 | v: torch.Tensor,
12 | softmax_scale,
13 | dropout_p=0,
14 | causal=True,
15 | window_size=(-1, -1),
16 | alibi_slopes=None,
17 | deterministic=False,
18 | ):
19 | comm = RingComm(process_group)
20 |
21 | out = None
22 | lse = None
23 |
24 | next_k, next_v = None, None
25 |
26 | for step in range(comm.world_size):
27 | if step + 1 != comm.world_size:
28 | next_k: torch.Tensor = comm.send_recv(k)
29 | next_v: torch.Tensor = comm.send_recv(v)
30 | comm.commit()
31 |
32 | if not causal or step <= comm.rank:
33 | params = get_default_args(_flash_attn_forward).copy()
34 | params.update(
35 | {
36 | "q": q,
37 | "k": k,
38 | "v": v,
39 | "dropout_p": dropout_p,
40 | "softmax_scale": softmax_scale,
41 | "causal": causal and step == 0,
42 | "window_size": window_size,
43 | "alibi_slopes": alibi_slopes,
44 | "return_softmax": True and dropout_p > 0,
45 | }
46 | )
47 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
48 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
49 |
50 | if step + 1 != comm.world_size:
51 | comm.wait()
52 | k = next_k
53 | v = next_v
54 |
55 | out = out.to(q.dtype)
56 | lse = lse.squeeze(dim=-1).transpose(1, 2)
57 | return out, lse
58 |
59 |
60 | def ring_flash_attn_backward(
61 | process_group,
62 | dout,
63 | q,
64 | k,
65 | v,
66 | out,
67 | softmax_lse,
68 | softmax_scale,
69 | dropout_p=0,
70 | causal=True,
71 | window_size=(-1, -1),
72 | alibi_slopes=None,
73 | deterministic=False,
74 | ):
75 | kv_comm = RingComm(process_group)
76 | d_kv_comm = RingComm(process_group)
77 | dq, dk, dv = None, None, None
78 | next_dk, next_dv = None, None
79 |
80 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
81 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
82 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
83 |
84 | next_dk, next_dv = None, None
85 | next_k, next_v = None, None
86 |
87 | for step in range(kv_comm.world_size):
88 | if step + 1 != kv_comm.world_size:
89 | next_k = kv_comm.send_recv(k)
90 | next_v = kv_comm.send_recv(v)
91 | kv_comm.commit()
92 | if step <= kv_comm.rank or not causal:
93 | bwd_causal = causal and step == 0
94 | params = get_default_args(_flash_attn_backward).copy()
95 | params.update(
96 | {
97 | "dout": dout,
98 | "q": q,
99 | "k": k,
100 | "v": v,
101 | "out": out,
102 | "softmax_lse": softmax_lse,
103 | "dq": block_dq_buffer,
104 | "dk": block_dk_buffer,
105 | "dv": block_dv_buffer,
106 | "dropout_p": dropout_p,
107 | "softmax_scale": softmax_scale,
108 | "causal": bwd_causal,
109 | "window_size": window_size,
110 | "alibi_slopes": alibi_slopes,
111 | "deterministic": deterministic,
112 | }
113 | )
114 | _flash_attn_backward(**params)
115 |
116 | if dq is None:
117 | dq = block_dq_buffer.to(torch.float32)
118 | dk = block_dk_buffer.to(torch.float32)
119 | dv = block_dv_buffer.to(torch.float32)
120 | else:
121 | dq += block_dq_buffer
122 | d_kv_comm.wait()
123 | dk = block_dk_buffer + next_dk
124 | dv = block_dv_buffer + next_dv
125 | elif step != 0:
126 | d_kv_comm.wait()
127 | dk = next_dk
128 | dv = next_dv
129 |
130 | if step + 1 != kv_comm.world_size:
131 | kv_comm.wait()
132 | k = next_k
133 | v = next_v
134 |
135 | next_dk = d_kv_comm.send_recv(dk)
136 | next_dv = d_kv_comm.send_recv(dv)
137 | d_kv_comm.commit()
138 |
139 | d_kv_comm.wait()
140 |
141 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)
142 |
143 |
144 | class RingFlashAttnFunc(torch.autograd.Function):
145 | @staticmethod
146 | def forward(
147 | ctx,
148 | q,
149 | k,
150 | v,
151 | dropout_p,
152 | softmax_scale,
153 | causal,
154 | window_size,
155 | alibi_slopes,
156 | deterministic,
157 | return_softmax,
158 | group,
159 | ):
160 | if softmax_scale is None:
161 | softmax_scale = q.shape[-1] ** (-0.5)
162 |
163 | assert alibi_slopes is None
164 | k = k.contiguous()
165 | v = v.contiguous()
166 | out, softmax_lse = ring_flash_attn_forward(
167 | group,
168 | q,
169 | k,
170 | v,
171 | softmax_scale=softmax_scale,
172 | dropout_p=dropout_p,
173 | causal=causal,
174 | window_size=window_size,
175 | alibi_slopes=alibi_slopes,
176 | deterministic=False,
177 | )
178 | # this should be out_padded
179 | ctx.save_for_backward(q, k, v, out, softmax_lse)
180 | ctx.dropout_p = dropout_p
181 | ctx.softmax_scale = softmax_scale
182 | ctx.causal = causal
183 | ctx.window_size = window_size
184 | ctx.alibi_slopes = alibi_slopes
185 | ctx.deterministic = deterministic
186 | ctx.group = group
187 | return out if not return_softmax else (out, softmax_lse, None)
188 |
189 | @staticmethod
190 | def backward(ctx, dout, *args):
191 | q, k, v, out, softmax_lse = ctx.saved_tensors
192 | dq, dk, dv = ring_flash_attn_backward(
193 | ctx.group,
194 | dout,
195 | q,
196 | k,
197 | v,
198 | out,
199 | softmax_lse,
200 | softmax_scale=ctx.softmax_scale,
201 | dropout_p=ctx.dropout_p,
202 | causal=ctx.causal,
203 | window_size=ctx.window_size,
204 | alibi_slopes=ctx.alibi_slopes,
205 | deterministic=ctx.deterministic,
206 | )
207 | return dq, dk, dv, None, None, None, None, None, None, None, None
208 |
209 |
210 | def ring_flash_attn_qkvpacked_func(
211 | qkv,
212 | dropout_p=0.0,
213 | softmax_scale=None,
214 | causal=False,
215 | window_size=(-1, -1),
216 | alibi_slopes=None,
217 | deterministic=False,
218 | return_attn_probs=False,
219 | group=None,
220 | ):
221 | return RingFlashAttnFunc.apply(
222 | qkv[:, :, 0],
223 | qkv[:, :, 1],
224 | qkv[:, :, 2],
225 | dropout_p,
226 | softmax_scale,
227 | causal,
228 | window_size,
229 | alibi_slopes,
230 | deterministic,
231 | return_attn_probs,
232 | group,
233 | )
234 |
235 |
236 | def ring_flash_attn_kvpacked_func(
237 | q,
238 | kv,
239 | dropout_p=0.0,
240 | softmax_scale=None,
241 | causal=False,
242 | window_size=(-1, -1),
243 | alibi_slopes=None,
244 | deterministic=False,
245 | return_attn_probs=False,
246 | group=None,
247 | ):
248 | return RingFlashAttnFunc.apply(
249 | q,
250 | kv[:, :, 0],
251 | kv[:, :, 1],
252 | dropout_p,
253 | softmax_scale,
254 | causal,
255 | window_size,
256 | alibi_slopes,
257 | deterministic,
258 | return_attn_probs,
259 | group,
260 | )
261 |
262 |
263 | def ring_flash_attn_func(
264 | q,
265 | k,
266 | v,
267 | dropout_p=0.0,
268 | softmax_scale=None,
269 | causal=False,
270 | window_size=(-1, -1),
271 | alibi_slopes=None,
272 | deterministic=False,
273 | return_attn_probs=False,
274 | group=None,
275 | ):
276 | return RingFlashAttnFunc.apply(
277 | q,
278 | k,
279 | v,
280 | dropout_p,
281 | softmax_scale,
282 | causal,
283 | window_size,
284 | alibi_slopes,
285 | deterministic,
286 | return_attn_probs,
287 | group,
288 | )
289 |
--------------------------------------------------------------------------------
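
For reference, a minimal usage sketch of the basic ring variant above: every rank holds one contiguous chunk of the sequence, and `ring_flash_attn_func` rotates the k/v blocks around the ring internally. This mirrors `test/test_ring_flash_attn_func.py`; the shapes and launch details here are illustrative assumptions.

```python
# Usage sketch (assumptions: launched with torchrun, NCCL backend, and a
# sequence length divisible by the world size).
import torch
import torch.distributed as dist
from ring_flash_attn.ring_flash_attn import ring_flash_attn_func

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device(f"cuda:{rank}")

batch, seqlen, nheads, head_dim = 1, 4096, 8, 128
local_seqlen = seqlen // world_size  # each rank owns one contiguous shard

local_q = torch.randn(batch, local_seqlen, nheads, head_dim,
                      device=device, dtype=torch.bfloat16)
local_k = torch.randn_like(local_q)
local_v = torch.randn_like(local_q)

# The output has the same shape as local_q; k/v travel around the ring.
local_out = ring_flash_attn_func(local_q, local_k, local_v, causal=True)
```
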
/train/ring-flash-attention/ring_flash_attn/ring_flash_attn_varlen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn.flash_attn_interface import (
4 | _flash_attn_varlen_forward,
5 | _flash_attn_varlen_backward,
6 | )
7 | from .utils import (
8 | RingComm,
9 | update_out_and_lse,
10 | get_default_args,
11 | )
12 |
13 | try:
14 | from .triton_utils import (
15 | flatten_varlen_lse,
16 | unflatten_varlen_lse,
17 | )
 18 | except Exception:
19 | from .utils import (
20 | flatten_varlen_lse,
21 | unflatten_varlen_lse,
22 | )
23 |
24 |
25 | def ring_flash_attn_varlen_forward(
26 | process_group,
27 | q: torch.Tensor,
28 | k: torch.Tensor,
29 | v: torch.Tensor,
30 | cu_seqlens,
31 | max_seqlen,
32 | softmax_scale,
33 | dropout_p=0,
34 | causal=True,
35 | window_size=(-1, -1),
36 | alibi_slopes=None,
37 | deterministic=False,
38 | ):
39 | comm = RingComm(process_group)
40 |
41 | out = None
42 | lse = None
43 | next_k, next_v = None, None
44 |
45 | old_lse = False
46 | for step in range(comm.world_size):
47 | if step + 1 != comm.world_size:
48 | next_k: torch.Tensor = comm.send_recv(k)
49 | next_v: torch.Tensor = comm.send_recv(v)
50 | comm.commit()
51 | if not causal or step <= comm.rank:
52 | params = get_default_args(_flash_attn_varlen_forward).copy()
53 | params.update(
54 | {
55 | "q": q,
56 | "k": k,
57 | "v": v,
58 | "cu_seqlens_q": cu_seqlens,
59 | "cu_seqlens_k": cu_seqlens,
60 | "max_seqlen_q": max_seqlen,
61 | "max_seqlen_k": max_seqlen,
62 | "dropout_p": dropout_p,
63 | "softmax_scale": softmax_scale,
64 | "causal": causal and step == 0,
65 | "window_size": window_size,
66 | "alibi_slopes": alibi_slopes,
67 | "return_softmax": True and dropout_p > 0,
68 | }
69 | )
70 |
71 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
72 | **params
73 | )
74 | if block_lse.dim() == 3:
75 | old_lse = True
76 | block_lse = flatten_varlen_lse(
77 | block_lse,
78 | cu_seqlens=cu_seqlens,
79 | )
80 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
81 |
82 | if step + 1 != comm.world_size:
83 | comm.wait()
84 | k = next_k
85 | v = next_v
86 |
87 | out = out.to(q.dtype)
88 | if old_lse:
89 | lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
90 | else:
91 | lse = lse.squeeze(dim=-1).transpose(0, 1)
92 | return out, lse
93 |
94 |
95 | def ring_flash_attn_varlen_backward(
96 | process_group,
97 | dout,
98 | q,
99 | k,
100 | v,
101 | out,
102 | softmax_lse,
103 | cu_seqlens,
104 | max_seqlen,
105 | softmax_scale,
106 | dropout_p=0,
107 | causal=True,
108 | window_size=(-1, -1),
109 | alibi_slopes=None,
110 | deterministic=False,
111 | ):
112 | kv_comm = RingComm(process_group)
113 | d_kv_comm = RingComm(process_group)
114 | dq, dk, dv = None, None, None
115 | next_dk, next_dv = None, None
116 |
117 | block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
118 | block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
119 | block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
120 |
121 | next_dk, next_dv = None, None
122 | next_k, next_v = None, None
123 | for step in range(kv_comm.world_size):
124 | if step + 1 != kv_comm.world_size:
125 | next_k = kv_comm.send_recv(k)
126 | next_v = kv_comm.send_recv(v)
127 | kv_comm.commit()
128 | if step <= kv_comm.rank or not causal:
129 | bwd_causal = causal and step == 0
130 | params = get_default_args(_flash_attn_varlen_backward).copy()
131 | params.update(
132 | {
133 | "dout": dout,
134 | "q": q,
135 | "k": k,
136 | "v": v,
137 | "out": out,
138 | "softmax_lse": softmax_lse,
139 | "dq": block_dq_buffer,
140 | "dk": block_dk_buffer,
141 | "dv": block_dv_buffer,
142 | "cu_seqlens_q": cu_seqlens,
143 | "cu_seqlens_k": cu_seqlens,
144 | "max_seqlen_q": max_seqlen,
145 | "max_seqlen_k": max_seqlen,
146 | "dropout_p": dropout_p,
147 | "softmax_scale": softmax_scale,
148 | "causal": bwd_causal,
149 | "window_size": window_size,
150 | "alibi_slopes": alibi_slopes,
151 | "deterministic": deterministic,
152 | }
153 | )
154 | _flash_attn_varlen_backward(**params)
155 |
156 | if dq is None:
157 | dq = block_dq_buffer.to(torch.float32)
158 | dk = block_dk_buffer.to(torch.float32)
159 | dv = block_dv_buffer.to(torch.float32)
160 | else:
161 | dq += block_dq_buffer
162 | d_kv_comm.wait()
163 | dk = block_dk_buffer + next_dk
164 | dv = block_dv_buffer + next_dv
165 | elif step != 0:
166 | d_kv_comm.wait()
167 | dk = next_dk
168 | dv = next_dv
169 |
170 | if step + 1 != kv_comm.world_size:
171 | kv_comm.wait()
172 | k = next_k
173 | v = next_v
174 |
175 | next_dk = d_kv_comm.send_recv(dk)
176 | next_dv = d_kv_comm.send_recv(dv)
177 | d_kv_comm.commit()
178 |
179 | d_kv_comm.wait()
180 |
181 | return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype)
182 |
183 |
184 | class RingFlashAttnVarlenFunc(torch.autograd.Function):
185 | @staticmethod
186 | def forward(
187 | ctx,
188 | q,
189 | k,
190 | v,
191 | cu_seqlens,
192 | max_seqlen,
193 | dropout_p,
194 | softmax_scale,
195 | causal,
196 | window_size,
197 | alibi_slopes,
198 | deterministic,
199 | return_softmax,
200 | group,
201 | ):
202 | if softmax_scale is None:
203 | softmax_scale = q.shape[-1] ** (-0.5)
204 |
205 | assert alibi_slopes is None
206 | k = k.contiguous()
207 | v = v.contiguous()
208 | out, softmax_lse = ring_flash_attn_varlen_forward(
209 | group,
210 | q,
211 | k,
212 | v,
213 | cu_seqlens,
214 | max_seqlen,
215 | softmax_scale=softmax_scale,
216 | dropout_p=dropout_p,
217 | causal=causal,
218 | window_size=window_size,
219 | alibi_slopes=alibi_slopes,
220 | deterministic=False,
221 | )
222 | # this should be out_padded
223 | ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens)
224 | ctx.max_seqlen = max_seqlen
225 | ctx.dropout_p = dropout_p
226 | ctx.softmax_scale = softmax_scale
227 | ctx.causal = causal
228 | ctx.window_size = window_size
229 | ctx.alibi_slopes = alibi_slopes
230 | ctx.deterministic = deterministic
231 | ctx.group = group
232 | return out if not return_softmax else (out, softmax_lse, None)
233 |
234 | @staticmethod
235 | def backward(ctx, dout, *args):
236 | q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors
237 | dq, dk, dv = ring_flash_attn_varlen_backward(
238 | ctx.group,
239 | dout,
240 | q,
241 | k,
242 | v,
243 | out,
244 | softmax_lse,
245 | cu_seqlens,
246 | ctx.max_seqlen,
247 | softmax_scale=ctx.softmax_scale,
248 | dropout_p=ctx.dropout_p,
249 | causal=ctx.causal,
250 | window_size=ctx.window_size,
251 | alibi_slopes=ctx.alibi_slopes,
252 | deterministic=ctx.deterministic,
253 | )
254 | return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
255 |
256 |
257 | def ring_flash_attn_varlen_qkvpacked_func(
258 | qkv,
259 | cu_seqlens,
260 | max_seqlen,
261 | dropout_p=0.0,
262 | softmax_scale=None,
263 | causal=False,
264 | window_size=(-1, -1), # -1 means infinite context window
265 | alibi_slopes=None,
266 | deterministic=False,
267 | return_attn_probs=False,
268 | group=None,
269 | ):
270 | return RingFlashAttnVarlenFunc.apply(
271 | qkv[:, 0],
272 | qkv[:, 1],
273 | qkv[:, 2],
274 | cu_seqlens,
275 | max_seqlen,
276 | dropout_p,
277 | softmax_scale,
278 | causal,
279 | window_size,
280 | alibi_slopes,
281 | deterministic,
282 | return_attn_probs,
283 | group,
284 | )
285 |
286 |
287 | def ring_flash_attn_varlen_kvpacked_func(
288 | q,
289 | kv,
290 | cu_seqlens,
291 | max_seqlen,
292 | dropout_p=0.0,
293 | softmax_scale=None,
294 | causal=False,
295 | window_size=(-1, -1), # -1 means infinite context window
296 | alibi_slopes=None,
297 | deterministic=False,
298 | return_attn_probs=False,
299 | group=None,
300 | ):
301 | return RingFlashAttnVarlenFunc.apply(
302 | q,
303 | kv[:, 0],
304 | kv[:, 1],
305 | cu_seqlens,
306 | max_seqlen,
307 | dropout_p,
308 | softmax_scale,
309 | causal,
310 | window_size,
311 | alibi_slopes,
312 | deterministic,
313 | return_attn_probs,
314 | group,
315 | )
316 |
317 |
318 | def ring_flash_attn_varlen_func(
319 | q,
320 | k,
321 | v,
322 | cu_seqlens,
323 | max_seqlen,
324 | dropout_p=0.0,
325 | softmax_scale=None,
326 | causal=False,
327 | window_size=(-1, -1), # -1 means infinite context window
328 | alibi_slopes=None,
329 | deterministic=False,
330 | return_attn_probs=False,
331 | group=None,
332 | ):
333 | return RingFlashAttnVarlenFunc.apply(
334 | q,
335 | k,
336 | v,
337 | cu_seqlens,
338 | max_seqlen,
339 | dropout_p,
340 | softmax_scale,
341 | causal,
342 | window_size,
343 | alibi_slopes,
344 | deterministic,
345 | return_attn_probs,
346 | group,
347 | )
348 |
--------------------------------------------------------------------------------
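
The varlen variant above takes `cu_seqlens` and `max_seqlen` expressed in local (per-rank) token counts: every packed sequence is split evenly across the ranks of the group, so the global boundaries simply divide by the world size (this is exactly what `test/test_ring_flash_attn_varlen_func.py` does). A small bookkeeping sketch; the helper name is an illustrative assumption.

```python
# Sketch of per-rank metadata for ring_flash_attn_varlen_*: every packed
# sequence is assumed to be evenly divisible by world_size.
import torch


def shard_varlen_metadata(cu_seqlens: torch.Tensor, world_size: int):
    # Global boundaries, e.g. [0, 120, 1248, 4232], become the same values
    # divided by world_size on every rank.
    assert torch.all(cu_seqlens % world_size == 0)
    local_cu_seqlens = cu_seqlens // world_size
    local_max_seqlen = int((local_cu_seqlens[1:] - local_cu_seqlens[:-1]).max())
    return local_cu_seqlens, local_max_seqlen
```
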
/train/ring-flash-attention/ring_flash_attn/triton_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 |
6 | @triton.jit
7 | def flatten_kernel(
8 | # pointers to matrices
9 | OUT,
10 | LSE,
11 | CU_SEQLENS,
12 | # strides
13 | stride_out_nheads,
14 | stride_out_seqlen,
15 | stride_lse_batch,
16 | stride_lse_nheads,
17 | stride_lse_seqlen,
18 | # meta-parameters
19 | BLOCK_M: tl.constexpr,
20 | ):
21 | pid_m = tl.program_id(axis=0)
22 | pid_batch = tl.program_id(axis=1)
23 | pid_head = tl.program_id(axis=2)
24 |
25 | start_idx = tl.load(CU_SEQLENS + pid_batch)
26 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
27 | LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads
28 | OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen
29 |
30 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
31 |
32 | LSE = LSE + rm[:, None] * stride_lse_seqlen
33 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)
34 |
35 | OUT = OUT + rm[:, None] * stride_out_seqlen
36 | tl.store(OUT, x, mask=rm[:, None] < seqlen)
37 |
38 |
39 | def flatten_varlen_lse(lse, cu_seqlens):
40 | """
41 | Arguments:
42 | lse: (batch_size, nheads, max_seqlen)
43 | cu_seqlens: (batch_size + 1,)
44 | Return:
45 | flatten_lse: (nheads, total_seqlen)
46 | """
47 | total_seqlen = cu_seqlens[-1]
48 | batch_size, nheads, max_seqlen = lse.shape
49 | output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device)
50 |
51 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
52 | BLOCK_M = 4
53 |
54 | with torch.cuda.device(lse.device.index):
55 | flatten_kernel[grid](
56 | output,
57 | lse,
58 | cu_seqlens,
59 | # strides
60 | output.stride(0),
61 | output.stride(1),
62 | lse.stride(0),
63 | lse.stride(1),
64 | lse.stride(2),
65 | BLOCK_M,
66 | )
67 | return output
68 |
69 |
70 | @triton.jit
71 | def unflatten_kernel(
72 | # pointers to matrices
73 | OUT,
74 | LSE,
75 | CU_SEQLENS,
76 | # strides
77 | stride_out_batch,
78 | stride_out_nheads,
79 | stride_out_seqlen,
80 | stride_lse_seqlen,
81 | stride_lse_nheads,
82 | # meta-parameters
83 | BLOCK_M: tl.constexpr,
84 | ):
85 | pid_m = tl.program_id(axis=0)
86 | pid_batch = tl.program_id(axis=1)
87 | pid_head = tl.program_id(axis=2)
88 |
89 | start_idx = tl.load(CU_SEQLENS + pid_batch)
90 | seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
91 | LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen
92 | OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
93 |
94 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
95 |
96 | LSE = LSE + rm[:, None] * stride_lse_seqlen
97 | x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)
98 |
99 | OUT = OUT + rm[:, None] * stride_out_seqlen
100 | tl.store(OUT, x, mask=rm[:, None] < seqlen)
101 |
102 |
103 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
104 | """
105 | Arguments:
106 | lse: (total_seqlen, nheads, 1)
107 | cu_seqlens: (batch_size + 1,)
108 | max_seqlen: int
109 | Return:
110 | unflatten_lse: (batch_size, nheads, max_seqlen)
111 | """
112 | lse = lse.unsqueeze(dim=-1)
113 | batch_size = len(cu_seqlens) - 1
114 | nheads = lse.shape[1]
115 | output = torch.empty(
116 | (batch_size, nheads, max_seqlen),
117 | dtype=lse.dtype,
118 | device=lse.device,
119 | )
120 |
121 | grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
122 | BLOCK_M = 4
123 |
124 | with torch.cuda.device(lse.device.index):
125 | unflatten_kernel[grid](
126 | output,
127 | lse,
128 | cu_seqlens,
129 | # strides
130 | output.stride(0),
131 | output.stride(1),
132 | output.stride(2),
133 | lse.stride(0),
134 | lse.stride(1),
135 | BLOCK_M,
136 | )
137 | return output
138 |
--------------------------------------------------------------------------------
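
The two Triton kernels above only reshape the log-sum-exp between the padded per-batch layout and the flattened varlen layout; `test/test_triton_kernels.py` checks them against the pure PyTorch versions in `utils.py`. A shape round-trip sketch (CUDA device required; the boundary values here are illustrative):

```python
# Round-trip sketch: (batch, nheads, max_seqlen) -> (nheads, total_seqlen) -> back.
import torch
from ring_flash_attn.triton_utils import flatten_varlen_lse, unflatten_varlen_lse

device = torch.device("cuda:0")
cu_seqlens = torch.tensor([0, 15, 156, 529], dtype=torch.int32, device=device)
batch_size, nheads = len(cu_seqlens) - 1, 5
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

lse = torch.randn(batch_size, nheads, max_seqlen, dtype=torch.float32, device=device)
flat = flatten_varlen_lse(lse, cu_seqlens)  # (nheads, total_seqlen)

# unflatten expects the (total_seqlen, nheads, 1) layout used inside the ring loop.
restored = unflatten_varlen_lse(flat.transpose(-2, -1).unsqueeze(-1), cu_seqlens, max_seqlen)
assert restored.shape == (batch_size, nheads, max_seqlen)
```
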
/train/ring-flash-attention/ring_flash_attn/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import torch.nn.functional as F
6 | import inspect
7 | from functools import cache
8 |
9 |
10 | __all__ = ["update_out_and_lse", "RingComm", "get_default_args"]
11 |
12 |
13 | @cache
14 | def get_default_args(func):
15 | spec = inspect.getfullargspec(func)
16 | defaults = spec.defaults if spec.defaults is not None else ()
17 | padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults
18 | args = dict(zip(spec.args, padded_defaults))
19 | if "softcap" in args:
20 | args["softcap"] = 0.0
21 | return args
22 |
23 |
24 | @torch.jit.script
25 | def _update_out_and_lse(
26 | out: torch.Tensor,
27 | lse: torch.Tensor,
28 | block_out: torch.Tensor,
29 | block_lse: torch.Tensor,
30 | ) -> Tuple[torch.Tensor, torch.Tensor]:
31 |
32 | block_out = block_out.to(torch.float32)
33 | block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
34 |
35 | # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
 36 |     # out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
37 | # For additional context and discussion, please refer to:
38 | # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
39 | out = out - F.sigmoid(block_lse - lse) * (out - block_out)
40 | lse = lse - F.logsigmoid(lse - block_lse)
41 |
42 | return out, lse
43 |
44 |
45 | def update_out_and_lse(
46 | out: Optional[torch.Tensor],
47 | lse: Optional[torch.Tensor],
48 | block_out: torch.Tensor,
49 | block_lse: torch.Tensor,
50 | slice_=None,
51 | ) -> Tuple[torch.Tensor, torch.Tensor]:
52 | if out is None:
53 | if slice_ is not None:
54 | raise RuntimeError("first update_out_and_lse should not pass slice_ args")
55 | out = block_out.to(torch.float32)
56 | lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
57 | elif slice_ is not None:
58 | slice_out, slice_lse = out[slice_], lse[slice_]
59 | slice_out, slice_lse = _update_out_and_lse(
60 | slice_out, slice_lse, block_out, block_lse
61 | )
62 | out[slice_], lse[slice_] = slice_out, slice_lse
63 | else:
64 | out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
65 | return out, lse
66 |
67 |
68 | @torch.jit.script
69 | def flatten_varlen_lse(lse, cu_seqlens):
70 | new_lse = []
71 | for i in range(len(cu_seqlens) - 1):
72 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
73 | new_lse.append(lse[i, :, : end - start])
74 | return torch.cat(new_lse, dim=1)
75 |
76 |
77 | @torch.jit.script
78 | def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
79 | num_seq = len(cu_seqlens) - 1
80 | num_head = lse.shape[-2]
81 | new_lse = torch.empty(
82 | (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device
83 | )
84 | for i in range(num_seq):
85 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
86 | new_lse[i, : end - start] = lse[start:end]
87 | return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()
88 |
89 |
90 | class RingComm:
91 | def __init__(self, process_group: dist.ProcessGroup):
92 | self._process_group = process_group
93 | self._ops = []
94 | self.rank = dist.get_rank(self._process_group)
95 | self.world_size = dist.get_world_size(self._process_group)
96 | self._reqs = None
97 |
98 | self.send_rank = (self.rank + 1) % self.world_size
99 | self.recv_rank = (self.rank - 1) % self.world_size
100 |
101 | if process_group is not None:
102 | self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
103 | self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
104 |
105 | def send_recv(
106 | self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
107 | ) -> torch.Tensor:
108 | if recv_tensor is None:
109 | res = torch.empty_like(to_send)
110 | else:
111 | res = recv_tensor
112 |
113 | send_op = dist.P2POp(
114 | dist.isend, to_send, self.send_rank, group=self._process_group
115 | )
116 | recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
117 | self._ops.append(send_op)
118 | self._ops.append(recv_op)
119 | return res
120 |
121 | def commit(self):
122 | if self._reqs is not None:
123 | raise RuntimeError("commit called twice")
124 | self._reqs = dist.batch_isend_irecv(self._ops)
125 |
126 | def wait(self):
127 | if self._reqs is None:
128 | raise RuntimeError("wait called before commit")
129 | for req in self._reqs:
130 | req.wait()
131 | self._reqs = None
132 | self._ops = []
133 |
134 |
135 | class AllGatherComm:
136 | def __init__(self, group=None) -> None:
137 | self.group = group
138 | self.handles = []
139 |
140 | def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor):
141 | handle = dist.all_gather_into_tensor(
142 | output_tensor, input_tensor, group=self.group, async_op=True
143 | )
144 | self.handles.append(handle)
145 |
146 | def wait(self):
147 | for handle in self.handles:
148 | handle.wait()
149 | self.handles = []
150 |
--------------------------------------------------------------------------------
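
The `_update_out_and_lse` helper above is the standard online-softmax merge of two attention blocks, rewritten with `sigmoid`/`logsigmoid` as an equivalent, numerically stable form that never materializes `new_lse` (see the PR discussion linked in the code). With running output $o$ and log-sum-exp $\ell$, and a new block $(o_b, \ell_b)$, the update it implements is:

```latex
\ell_{\mathrm{new}} = \log\!\left(e^{\ell} + e^{\ell_b}\right)
                    = \ell + \log\!\left(1 + e^{\ell_b - \ell}\right)
                    = \ell - \log\sigma(\ell - \ell_b),
\qquad
o_{\mathrm{new}} = e^{\ell - \ell_{\mathrm{new}}}\, o + e^{\ell_b - \ell_{\mathrm{new}}}\, o_b
                 = o - \sigma(\ell_b - \ell)\,(o - o_b),
```

where $\sigma$ is the logistic sigmoid; the final form on each line corresponds exactly to the `F.logsigmoid` and `F.sigmoid` expressions in the code.
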
/train/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
4 | from .utils import RingComm, update_out_and_lse, get_default_args
5 |
6 |
7 | def zigzag_ring_flash_attn_forward(
8 | process_group,
9 | q: torch.Tensor,
10 | k: torch.Tensor,
11 | v: torch.Tensor,
12 | softmax_scale,
13 | dropout_p=0,
14 | causal=True,
15 | window_size=(-1, -1),
16 | alibi_slopes=None,
17 | deterministic=False,
18 | ):
19 | assert causal == True, "zigzag ring is meaningless for causal=False"
20 | comm = RingComm(process_group)
21 |
22 | block_seq_len = q.shape[1] // 2
23 | q1 = q[:, block_seq_len:]
24 |
25 | out = None
26 | lse = None
27 | next_k, next_v = None, None
28 |
29 | def forward(q, k, v, causal):
30 | params = get_default_args(_flash_attn_forward).copy()
31 | params.update(
32 | {
33 | "q": q,
34 | "k": k,
35 | "v": v,
36 | "dropout_p": dropout_p,
37 | "softmax_scale": softmax_scale,
38 | "causal": causal,
39 | "window_size": window_size,
40 | "alibi_slopes": alibi_slopes,
41 | "return_softmax": True and dropout_p > 0,
42 | }
43 | )
44 | block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(**params)
45 | return block_out, block_lse
46 |
47 | for step in range(comm.world_size):
48 | if step + 1 != comm.world_size:
49 | next_k: torch.Tensor = comm.send_recv(k)
50 | next_v: torch.Tensor = comm.send_recv(v)
51 | comm.commit()
52 |
53 | if step == 0:
54 | block_out, block_lse = forward(q, k, v, causal=True)
55 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
56 | elif step <= comm.rank:
57 | k0 = k[:, :block_seq_len]
58 | v0 = v[:, :block_seq_len]
59 | block_out, block_lse = forward(q, k0, v0, causal=False)
60 | out, lse = update_out_and_lse(out, lse, block_out, block_lse)
61 | else:
62 | block_out, block_lse = forward(q1, k, v, causal=False)
63 | out, lse = update_out_and_lse(
64 | out,
65 | lse,
66 | block_out,
67 | block_lse,
68 | slice_=(slice(None), slice(block_seq_len, None)),
69 | )
70 |
71 | if step + 1 != comm.world_size:
72 | comm.wait()
73 | k = next_k
74 | v = next_v
75 |
76 | out = out.to(q.dtype)
77 | lse = lse.squeeze(dim=-1).transpose(1, 2)
78 | return out, lse
79 |
80 |
81 | def zigzag_ring_flash_attn_backward(
82 | process_group,
83 | dout,
84 | q,
85 | k,
86 | v,
87 | out,
88 | softmax_lse,
89 | softmax_scale,
90 | dropout_p=0,
91 | causal=True,
92 | window_size=(-1, -1),
93 | alibi_slopes=None,
94 | deterministic=False,
95 | ):
96 | assert causal == True, "zigzag ring is meaningless for causal=False"
97 | kv_comm = RingComm(process_group)
98 | d_kv_comm = RingComm(process_group)
99 | dq, dk, dv = None, None, None
100 | next_dk, next_dv = None, None
101 | next_k, next_v = None, None
102 | dk_comm_buffer, dv_comm_buffer = None, None
103 |
104 | dout1 = dout.chunk(2, dim=1)[1]
105 | q1 = q.chunk(2, dim=1)[1]
106 | out1 = out.chunk(2, dim=1)[1]
107 | softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous()
108 | block_seq_len = q.shape[1] // 2
109 |
110 |     # repeatedly allocating buffers may be slow...
111 | dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
112 | dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
113 | dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)
114 |
115 | def backward(dout, q, k, v, out, softmax_lse, causal):
116 | seqlen_q = q.shape[1]
117 | seqlen_kv = k.shape[1]
118 | params = get_default_args(_flash_attn_backward).copy()
119 | params.update(
120 | {
121 | "dout": dout,
122 | "q": q,
123 | "k": k,
124 | "v": v,
125 | "out": out,
126 | "softmax_lse": softmax_lse,
127 | "dq": dq_buffer[:, :seqlen_q],
128 | "dk": dk_buffer[:, :seqlen_kv],
129 | "dv": dv_buffer[:, :seqlen_kv],
130 | "dropout_p": dropout_p,
131 | "softmax_scale": softmax_scale,
132 | "causal": causal,
133 | "window_size": window_size,
134 | "alibi_slopes": alibi_slopes,
135 | "deterministic": deterministic,
136 | }
137 | )
138 | _flash_attn_backward(**params)
139 |
140 | for step in range(kv_comm.world_size):
141 | if step + 1 != kv_comm.world_size:
142 | next_k = kv_comm.send_recv(k)
143 | next_v = kv_comm.send_recv(v)
144 | kv_comm.commit()
145 |
146 | if step == 0:
147 | backward(dout, q, k, v, out, softmax_lse, causal=True)
148 | dq = dq_buffer.to(torch.float32)
149 | dk = dk_buffer.to(torch.float32)
150 | dv = dv_buffer.to(torch.float32)
151 | else:
152 | if step <= kv_comm.rank:
153 | k0 = k[:, :block_seq_len]
154 | v0 = v[:, :block_seq_len]
155 | backward(dout, q, k0, v0, out, softmax_lse, causal=False)
156 | dq += dq_buffer
157 | else:
158 | backward(dout1, q1, k, v, out1, softmax_lse1, causal=False)
159 | # always use the first half in dq_buffer.
160 | dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len]
161 |
162 | d_kv_comm.wait()
163 | dk_comm_buffer, dv_comm_buffer = dk, dv
164 | dk, dv = next_dk, next_dv
165 |
166 | if step <= kv_comm.rank:
167 | dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len]
168 | dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len]
169 | else:
170 | dk += dk_buffer
171 | dv += dv_buffer
172 |
173 | if step + 1 != kv_comm.world_size:
174 | kv_comm.wait()
175 | k = next_k
176 | v = next_v
177 |
178 | next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer)
179 | next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer)
180 | d_kv_comm.commit()
181 |
182 | d_kv_comm.wait()
183 |
184 | return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
185 |
186 |
187 | class ZigZagRingFlashAttnFunc(torch.autograd.Function):
188 | @staticmethod
189 | def forward(
190 | ctx,
191 | q,
192 | k,
193 | v,
194 | dropout_p,
195 | softmax_scale,
196 | causal,
197 | window_size,
198 | alibi_slopes,
199 | deterministic,
200 | return_softmax,
201 | group,
202 | ):
203 | if softmax_scale is None:
204 | softmax_scale = q.shape[-1] ** (-0.5)
205 |
206 | assert alibi_slopes is None
207 | k = k.contiguous()
208 | v = v.contiguous()
209 | out, softmax_lse = zigzag_ring_flash_attn_forward(
210 | group,
211 | q,
212 | k,
213 | v,
214 | softmax_scale=softmax_scale,
215 | dropout_p=dropout_p,
216 | causal=causal,
217 | window_size=window_size,
218 | alibi_slopes=alibi_slopes,
219 | deterministic=False,
220 | )
221 | # this should be out_padded
222 | ctx.save_for_backward(q, k, v, out, softmax_lse)
223 | ctx.dropout_p = dropout_p
224 | ctx.softmax_scale = softmax_scale
225 | ctx.causal = causal
226 | ctx.window_size = window_size
227 | ctx.alibi_slopes = alibi_slopes
228 | ctx.deterministic = deterministic
229 | ctx.group = group
230 | return out if not return_softmax else (out, softmax_lse, None)
231 |
232 | @staticmethod
233 | def backward(ctx, dout, *args):
234 | q, k, v, out, softmax_lse = ctx.saved_tensors
235 | dq, dk, dv = zigzag_ring_flash_attn_backward(
236 | ctx.group,
237 | dout,
238 | q,
239 | k,
240 | v,
241 | out,
242 | softmax_lse,
243 | softmax_scale=ctx.softmax_scale,
244 | dropout_p=ctx.dropout_p,
245 | causal=ctx.causal,
246 | window_size=ctx.window_size,
247 | alibi_slopes=ctx.alibi_slopes,
248 | deterministic=ctx.deterministic,
249 | )
250 | return dq, dk, dv, None, None, None, None, None, None, None, None
251 |
252 |
253 | def zigzag_ring_flash_attn_qkvpacked_func(
254 | qkv,
255 | dropout_p=0.0,
256 | softmax_scale=None,
257 | causal=False,
258 | window_size=(-1, -1),
259 | alibi_slopes=None,
260 | deterministic=False,
261 | return_attn_probs=False,
262 | group=None,
263 | ):
264 | return ZigZagRingFlashAttnFunc.apply(
265 | qkv[:, :, 0],
266 | qkv[:, :, 1],
267 | qkv[:, :, 2],
268 | dropout_p,
269 | softmax_scale,
270 | causal,
271 | window_size,
272 | alibi_slopes,
273 | deterministic,
274 | return_attn_probs,
275 | group,
276 | )
277 |
278 |
279 | def zigzag_ring_flash_attn_kvpacked_func(
280 | q,
281 | kv,
282 | dropout_p=0.0,
283 | softmax_scale=None,
284 | causal=False,
285 | window_size=(-1, -1),
286 | alibi_slopes=None,
287 | deterministic=False,
288 | return_attn_probs=False,
289 | group=None,
290 | ):
291 | return ZigZagRingFlashAttnFunc.apply(
292 | q,
293 | kv[:, :, 0],
294 | kv[:, :, 1],
295 | dropout_p,
296 | softmax_scale,
297 | causal,
298 | window_size,
299 | alibi_slopes,
300 | deterministic,
301 | return_attn_probs,
302 | group,
303 | )
304 |
305 |
306 | def zigzag_ring_flash_attn_func(
307 | q,
308 | k,
309 | v,
310 | dropout_p=0.0,
311 | softmax_scale=None,
312 | causal=False,
313 | window_size=(-1, -1),
314 | alibi_slopes=None,
315 | deterministic=False,
316 | return_attn_probs=False,
317 | group=None,
318 | ):
319 | return ZigZagRingFlashAttnFunc.apply(
320 | q,
321 | k,
322 | v,
323 | dropout_p,
324 | softmax_scale,
325 | causal,
326 | window_size,
327 | alibi_slopes,
328 | deterministic,
329 | return_attn_probs,
330 | group,
331 | )
332 |
--------------------------------------------------------------------------------
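
Unlike the basic ring variant, the zigzag variant expects each rank to hold two non-adjacent chunks of the sequence: out of `2 * world_size` equal chunks, rank `r` owns chunks `r` and `2 * world_size - 1 - r`, which pairs an "early" (cheap under the causal mask) chunk with a "late" (expensive) one so the per-rank work is balanced. The zigzag test later in this repo builds its local shards exactly this way; the sketch below restates that layout for reference.

```python
# Zigzag layout sketch: rank r owns chunks (r, 2 * world_size - 1 - r).
import torch


def zigzag_extract_local(value: torch.Tensor, rank: int, world_size: int, dim: int = 1):
    chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat(
        [chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim
    ).contiguous()

# Example: world_size = 2 -> rank 0 holds chunks (0, 3), rank 1 holds chunks (1, 2).
```
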
/train/ring-flash-attention/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="ring_flash_attn",
5 | version="0.1",
6 | author="zhuzilin",
7 | url="https://github.com/zhuzilin/ring-flash-attention",
8 | packages=find_packages(),
9 | )
10 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_llama3_flash_attn_varlen_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn import flash_attn_varlen_qkvpacked_func
4 | from ring_flash_attn import (
5 | llama3_flash_attn_prepare_cu_seqlens,
6 | llama3_flash_attn_varlen_qkvpacked_func,
7 | )
8 | from utils import log, set_seed
9 |
10 |
11 | if __name__ == "__main__":
12 | dist.init_process_group("nccl")
13 | rank = dist.get_rank()
14 | set_seed(rank)
15 | world_size = dist.get_world_size()
16 | dtype = torch.bfloat16
17 | device = torch.device(f"cuda:{rank}")
18 |
19 | batch_size = 1
20 | nheads = 5
21 | d = 8
22 | dropout_p = 0
23 | causal = True
24 | deterministic = False
25 |
26 | cu_seqlens = [0, 120, 1248, 4232]
27 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
28 | max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
29 | total_length = cu_seqlens[-1]
30 | local_length = total_length // world_size
31 | num_seq = len(cu_seqlens) - 1
32 |
33 | assert cu_seqlens_tensor[-1] % world_size == 0
34 | assert d % 8 == 0
35 |
36 | qkv = torch.randn(
37 | total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
38 | )
39 | dist.broadcast(qkv, src=0)
40 |
41 | dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
42 | dist.broadcast(dout, src=0)
43 |
44 | local_qkv = qkv[rank * local_length : (rank + 1) * local_length].detach().clone()
45 | local_dout = dout[rank * local_length : (rank + 1) * local_length].detach().clone()
46 | local_qkv.requires_grad = True
47 |
48 | dist.barrier()
49 | if rank == 0:
50 | print("#" * 30)
51 | print("# forward:")
52 | print("#" * 30)
53 |
54 | out, lse, _ = flash_attn_varlen_qkvpacked_func(
55 | qkv,
56 | cu_seqlens_tensor,
57 | max_seqlen,
58 | dropout_p=dropout_p,
59 | causal=causal,
60 | window_size=(-1, -1),
61 | alibi_slopes=None,
62 | deterministic=deterministic,
63 | return_attn_probs=True,
64 | )
65 |
66 | local_out = out[rank * local_length : (rank + 1) * local_length]
67 | if lse.dim() == 2:
68 | local_lse = lse[:, rank * local_length : (rank + 1) * local_length]
69 |
70 | (
71 | local_cu_seqlens_q,
72 | local_cu_seqlens_k,
73 | max_seqlen_q,
74 | max_seqlen_k,
75 | local_k_slice,
76 | ) = llama3_flash_attn_prepare_cu_seqlens(
77 | cu_seqlens_tensor,
78 | causal=causal,
79 | rank=rank,
80 | world_size=world_size,
81 | )
82 |
83 | llama3_out, llama3_lse, _ = llama3_flash_attn_varlen_qkvpacked_func(
84 | local_qkv,
85 | local_cu_seqlens_q,
86 | local_cu_seqlens_k,
87 | max_seqlen_q,
88 | max_seqlen_k,
89 | heads_k_stride=1,
90 | local_k_slice=local_k_slice,
91 | dropout_p=dropout_p,
92 | causal=causal,
93 | window_size=(-1, -1),
94 | alibi_slopes=None,
95 | deterministic=deterministic,
96 | return_attn_probs=True,
97 | )
98 |
99 | log("out", out, rank0_only=True)
100 | log("out diff", local_out - llama3_out)
101 | if lse.dim() == 2:
102 | log("lse", lse, rank0_only=True)
103 | log("lse diff", local_lse - llama3_lse)
104 |
105 | dist.barrier()
106 | if rank == 0:
107 | print("#" * 30)
108 | print("# backward:")
109 | print("#" * 30)
110 |
111 | out.backward(dout)
112 | dqkv = qkv.grad
113 | local_dqkv = dqkv[rank * local_length : (rank + 1) * local_length]
114 |
115 | llama3_out.backward(local_dout)
116 | llama3_dqkv = local_qkv.grad
117 |
118 | log("local_dq", local_dqkv[:, 0])
119 | log("dq diff", local_dqkv[:, 0] - llama3_dqkv[:, 0])
120 | log("dk diff", local_dqkv[:, 1] - llama3_dqkv[:, 1])
121 | log("dv diff", local_dqkv[:, 2] - llama3_dqkv[:, 2])
122 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_llama3_prepare_cu_seqlens.py:
--------------------------------------------------------------------------------
1 | from ring_flash_attn import llama3_flash_attn_prepare_cu_seqlens
2 | import torch
3 |
4 | if __name__ == "__main__":
5 | device = torch.device("cuda")
6 | cu_seqlens = [0, 7, 14, 16]
7 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
8 |
9 | world_size = 8
10 | for rank in range(world_size):
11 | (
12 | cu_seqlens_q,
13 | cu_seqlens_k,
14 | max_seqlen_q,
15 | max_seqlen_k,
16 | local_k_slice,
17 | ) = llama3_flash_attn_prepare_cu_seqlens(
18 | cu_seqlens_tensor,
19 | causal=True,
20 | rank=rank,
21 | world_size=world_size,
22 | )
23 |
24 | assert max_seqlen_q == (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max()
25 | assert max_seqlen_k == (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max()
26 | print(f"RANK: {rank}")
27 | print(f" cu_seqlens_q: {cu_seqlens_q}")
28 | print(f" cu_seqlens_k: {cu_seqlens_k}")
29 | print(f" local_k_slice: {local_k_slice}")
30 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_ring_flash_attn_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn import flash_attn_qkvpacked_func
4 | from ring_flash_attn import ring_flash_attn_qkvpacked_func
5 | from utils import log, set_seed
6 |
7 |
8 | if __name__ == "__main__":
9 | dist.init_process_group("nccl")
10 | rank = dist.get_rank()
11 | set_seed(rank)
12 | world_size = dist.get_world_size()
13 | dtype = torch.bfloat16
14 | device = torch.device(f"cuda:{rank}")
15 |
16 | batch_size = 1
17 | seqlen = 3816
18 | nheads = 5
19 | d = 128
20 | dropout_p = 0
21 | causal = True
22 | deterministic = False
23 |
24 | assert seqlen % world_size == 0
25 | assert d % 8 == 0
26 |
27 | qkv = torch.randn(
28 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
29 | )
30 | dist.broadcast(qkv, src=0)
31 |
32 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
33 | dist.broadcast(dout, src=0)
34 |
35 | local_qkv = qkv.chunk(world_size, dim=1)[rank].detach().clone()
36 | local_qkv.requires_grad = True
37 | local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone()
38 |
39 | dist.barrier()
40 | if rank == 0:
41 | print("#" * 30)
42 | print("# forward:")
43 | print("#" * 30)
44 |
45 | out, lse, _ = flash_attn_qkvpacked_func(
46 | qkv,
47 | dropout_p=dropout_p,
48 | causal=causal,
49 | window_size=(-1, -1),
50 | alibi_slopes=None,
51 | deterministic=deterministic,
52 | return_attn_probs=True,
53 | )
54 |
55 | local_out = out.chunk(world_size, dim=1)[rank]
56 | local_lse = lse.chunk(world_size, dim=-1)[rank]
57 |
58 | fn = ring_flash_attn_qkvpacked_func
59 |
60 | ring_out, ring_lse, _ = fn(
61 | local_qkv,
62 | dropout_p=dropout_p,
63 | causal=causal,
64 | window_size=(-1, -1),
65 | alibi_slopes=None,
66 | deterministic=deterministic,
67 | return_attn_probs=True,
68 | )
69 |
70 | log("out", out, rank0_only=True)
71 | log("lse", lse, rank0_only=True)
72 | log("out diff", local_out - ring_out)
73 | log("lse diff", local_lse - ring_lse)
74 |
75 | dist.barrier()
76 | if rank == 0:
77 | print("#" * 30)
78 | print("# backward:")
79 | print("#" * 30)
80 |
81 | out.backward(dout)
82 | dqkv = qkv.grad
83 | local_dqkv = dqkv.chunk(world_size, dim=1)[rank]
84 |
85 | ring_out.backward(local_dout)
86 | ring_dqkv = local_qkv.grad
87 |
88 | log("local_dqkv", local_dqkv)
89 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
90 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
91 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])
92 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_ring_flash_attn_varlen_func.py:
--------------------------------------------------------------------------------
1 | from flash_attn import flash_attn_varlen_qkvpacked_func
2 | import torch
3 | import torch.distributed as dist
4 | from ring_flash_attn import ring_flash_attn_varlen_qkvpacked_func
5 | from utils import log, set_seed
6 |
7 |
8 | def extract_local(value, cu_seqlens, rank, world_size):
9 | local_values = []
10 | for i in range(len(cu_seqlens) - 1):
11 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
12 | local_value = value[start:end].chunk(world_size, dim=0)[rank].detach().clone()
13 | local_values.append(local_value)
14 | return torch.cat(local_values, dim=0).contiguous()
15 |
16 |
17 | def extract_lse(lse, cu_seqlens):
18 | values = []
19 | if lse.dim() == 2:
20 | for i in range(len(cu_seqlens) - 1):
21 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
22 | value = lse[:, start:end]
23 | values.append(value)
24 | else:
25 | assert lse.dim() == 3
26 | for i in range(len(cu_seqlens) - 1):
27 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
28 | value = lse[i, :, : end - start]
29 | values.append(value)
30 | return values
31 |
32 |
33 | if __name__ == "__main__":
34 | dist.init_process_group("nccl")
35 | rank = dist.get_rank()
36 | set_seed(rank)
37 | world_size = dist.get_world_size()
38 | dtype = torch.bfloat16
39 | device = torch.device(f"cuda:{rank}")
40 |
41 | batch_size = 1
42 | nheads = 5
43 | d = 128
44 | dropout_p = 0
45 | causal = True
46 | deterministic = False
47 |
48 | cu_seqlens = [0, 120, 1248, 4232]
49 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
50 | max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
51 | total_length = cu_seqlens[-1]
52 | num_seq = len(cu_seqlens) - 1
53 |
54 | assert torch.all(cu_seqlens_tensor % world_size == 0)
55 | assert d % 8 == 0
56 |
57 | qkv = torch.randn(
58 | total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
59 | )
60 | dist.broadcast(qkv, src=0)
61 |
62 | dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
63 | dist.broadcast(dout, src=0)
64 |
65 | local_cu_seqlens_tensor = cu_seqlens_tensor // world_size
66 | local_max_seqlen = max_seqlen // world_size
67 |
68 | local_qkv = extract_local(qkv, cu_seqlens, rank, world_size)
69 | local_qkv.requires_grad = True
70 | local_dout = extract_local(dout, cu_seqlens, rank, world_size)
71 |
72 | dist.barrier()
73 | if rank == 0:
74 | print("#" * 30)
75 | print("# forward:")
76 | print("#" * 30)
77 |
78 | out, lse, _ = flash_attn_varlen_qkvpacked_func(
79 | qkv,
80 | cu_seqlens_tensor,
81 | max_seqlen,
82 | dropout_p=dropout_p,
83 | causal=causal,
84 | window_size=(-1, -1),
85 | alibi_slopes=None,
86 | deterministic=deterministic,
87 | return_attn_probs=True,
88 | )
89 |
90 | local_out = extract_local(out, cu_seqlens, rank, world_size)
91 | lse_list = extract_lse(lse, cu_seqlens)
92 |
93 | ring_out, ring_lse, _ = ring_flash_attn_varlen_qkvpacked_func(
94 | local_qkv,
95 | local_cu_seqlens_tensor,
96 | local_max_seqlen,
97 | dropout_p=dropout_p,
98 | causal=causal,
99 | window_size=(-1, -1),
100 | alibi_slopes=None,
101 | deterministic=deterministic,
102 | return_attn_probs=True,
103 | )
104 |
105 | ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist())
106 |
107 | log("out", out, rank0_only=True)
108 | log("out diff", local_out - ring_out)
109 |
110 | for lse, ring_lse in zip(lse_list, ring_lse_list):
111 | local_lse = lse.chunk(world_size, dim=-1)[rank]
112 | log("lse", lse, rank0_only=True)
113 | log("lse diff", local_lse - ring_lse)
114 |
115 | dist.barrier()
116 | if rank == 0:
117 | print("#" * 30)
118 | print("# backward:")
119 | print("#" * 30)
120 |
121 | out.backward(dout)
122 | dqkv = qkv.grad
123 | local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size)
124 |
125 | ring_out.backward(local_dout)
126 | ring_dqkv = local_qkv.grad
127 |
128 | log("local_dqkv", local_dqkv)
129 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
130 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
131 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])
132 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_stripe_flash_attn_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from flash_attn import flash_attn_qkvpacked_func
4 | from ring_flash_attn import stripe_flash_attn_qkvpacked_func
5 | from utils import log, set_seed
6 |
7 |
8 | def extract_local(value, rank, world_size, dim=1):
9 | value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(
10 | dim, dim + 1
11 | )
12 | slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))]
13 | return value[slicer].contiguous()
14 |
15 |
16 | if __name__ == "__main__":
17 | dist.init_process_group("nccl")
18 | rank = dist.get_rank()
19 | set_seed(rank)
20 | world_size = dist.get_world_size()
21 | dtype = torch.bfloat16
22 | device = torch.device(f"cuda:{rank}")
23 |
24 | batch_size = 1
25 | seqlen = 3824
26 | nheads = 5
27 | d = 128
28 | dropout_p = 0
29 | causal = True
30 | deterministic = False
31 |
32 | assert causal
33 | assert seqlen % (2 * world_size) == 0
34 | assert d % 8 == 0
35 |
36 | qkv = torch.randn(
37 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
38 | )
39 | dist.broadcast(qkv, src=0)
40 |
41 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
42 | dist.broadcast(dout, src=0)
43 |
44 | local_qkv = extract_local(qkv, rank, world_size).detach().clone()
45 | local_qkv.requires_grad = True
46 | local_dout = extract_local(dout, rank, world_size).detach().clone()
47 |
48 | dist.barrier()
49 | if rank == 0:
50 | print("#" * 30)
51 | print("# forward:")
52 | print("#" * 30)
53 |
54 | out, lse, _ = flash_attn_qkvpacked_func(
55 | qkv,
56 | dropout_p=dropout_p,
57 | causal=causal,
58 | window_size=(-1, -1),
59 | alibi_slopes=None,
60 | deterministic=deterministic,
61 | return_attn_probs=True,
62 | )
63 |
64 | local_out = extract_local(out, rank, world_size)
65 | local_lse = extract_local(lse, rank, world_size, dim=2)
66 |
67 | ring_out, ring_lse, _ = stripe_flash_attn_qkvpacked_func(
68 | local_qkv,
69 | dropout_p=dropout_p,
70 | causal=causal,
71 | window_size=(-1, -1),
72 | alibi_slopes=None,
73 | deterministic=deterministic,
74 | return_attn_probs=True,
75 | )
76 |
77 | log("out", out, rank0_only=True)
78 | log("lse", lse, rank0_only=True)
79 | log("out diff", local_out - ring_out)
80 | log("lse diff", local_lse - ring_lse)
81 |
82 | dist.barrier()
83 | if rank == 0:
84 | print("#" * 30)
85 | print("# backward:")
86 | print("#" * 30)
87 |
88 | out.backward(dout)
89 | dqkv = qkv.grad
90 |
91 | local_dqkv = extract_local(dqkv, rank, world_size)
92 |
93 | ring_out.backward(local_dout)
94 | ring_dqkv = local_qkv.grad
95 |
96 | log("local_dqkv", local_dqkv)
97 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
98 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
99 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])
100 |
--------------------------------------------------------------------------------
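
A minimal CPU-only sketch (not part of the repository) that replays the striping rule of `extract_local` from the stripe test above on toy tensors; `stripe_extract_local` is just a standalone copy of that helper, and the tensor values and `world_size` are made up, so it runs without GPUs, `flash_attn`, or a process group.

```python
import torch

def stripe_extract_local(value, rank, world_size, dim=1):
    # Same logic as extract_local in test_stripe_flash_attn_func.py: tokens are
    # grouped into blocks of `world_size`, and rank r keeps the r-th token of
    # every block, i.e. indices r, r + world_size, r + 2 * world_size, ...
    value = torch.stack(value.split(world_size, dim=dim), dim=dim).transpose(dim, dim + 1)
    slicer = [rank if i == dim else slice(None) for i in range(len(value.shape))]
    return value[slicer].contiguous()

if __name__ == "__main__":
    world_size = 4
    token_ids = torch.arange(16).view(1, 16)  # toy batch=1, seqlen=16
    for rank in range(world_size):
        print(rank, stripe_extract_local(token_ids, rank, world_size).tolist())
    # rank 0 -> [[0, 4, 8, 12]], rank 1 -> [[1, 5, 9, 13]], ...
```
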
/train/ring-flash-attention/test/test_triton_kernels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ring_flash_attn.utils import (
3 | flatten_varlen_lse,
4 | unflatten_varlen_lse,
5 | )
6 | from ring_flash_attn.triton_utils import (
7 | flatten_varlen_lse as triton_flatten_varlen_lse,
8 | unflatten_varlen_lse as triton_unflatten_varlen_lse,
9 | )
10 |
11 |
12 | if __name__ == "__main__":
13 | device = torch.device("cuda:0")
14 |
15 | cu_seqlens = [0, 15, 156, 529]
16 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
17 | batch_size = len(cu_seqlens) - 1
18 | max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
19 | n_head = 5
20 |
21 | lse = torch.randn(
22 | (batch_size, n_head, max_seqlen), dtype=torch.float32, device=device
23 | )
24 | flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor)
25 | triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor)
26 | assert torch.all(flatten_lse == triton_flatten_lse)
27 |
28 | flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1)
29 | triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1)
30 |
31 | unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen)
32 | triton_unflatten_lse = triton_unflatten_varlen_lse(
33 | triton_flatten_lse, cu_seqlens_tensor, max_seqlen
34 | )
35 |
36 | for i in range(batch_size):
37 | seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
38 | assert torch.all(
39 | unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen]
40 | ), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}"
41 |
--------------------------------------------------------------------------------
/train/ring-flash-attention/test/test_zigzag_ring_flash_attn_func.py:
--------------------------------------------------------------------------------
1 | from flash_attn import flash_attn_qkvpacked_func
2 | import torch
3 | import torch.distributed as dist
4 | from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func
5 | from utils import log, set_seed
6 |
7 |
8 | def extract_local(value, rank, world_size, dim=1):
9 | value_chunks = value.chunk(2 * world_size, dim=dim)
10 | local_value = torch.cat(
11 | [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
12 | )
13 | return local_value.contiguous()
14 |
15 |
16 | if __name__ == "__main__":
17 | dist.init_process_group("nccl")
18 | rank = dist.get_rank()
19 | set_seed(rank)
20 | world_size = dist.get_world_size()
21 | dtype = torch.bfloat16
22 | device = torch.device(f"cuda:{rank}")
23 |
24 | batch_size = 1
25 | seqlen = 3824
26 | nheads = 5
27 | d = 128
28 | dropout_p = 0
29 | causal = True
30 | deterministic = False
31 |
32 | assert causal
33 | assert seqlen % (2 * world_size) == 0
34 | assert d % 8 == 0
35 |
36 | qkv = torch.randn(
37 | batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
38 | )
39 | dist.broadcast(qkv, src=0)
40 |
41 | dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
42 | dist.broadcast(dout, src=0)
43 |
44 | local_qkv = extract_local(qkv, rank, world_size).detach().clone()
45 | local_qkv.requires_grad = True
46 | local_dout = extract_local(dout, rank, world_size).detach().clone()
47 |
48 | dist.barrier()
49 | if rank == 0:
50 | print("#" * 30)
51 | print("# forward:")
52 | print("#" * 30)
53 |
54 | out, lse, _ = flash_attn_qkvpacked_func(
55 | qkv,
56 | dropout_p=dropout_p,
57 | causal=causal,
58 | window_size=(-1, -1),
59 | alibi_slopes=None,
60 | deterministic=deterministic,
61 | return_attn_probs=True,
62 | )
63 |
64 | local_out = extract_local(out, rank, world_size)
65 | local_lse = extract_local(lse, rank, world_size, dim=2)
66 |
67 | ring_out, ring_lse, _ = zigzag_ring_flash_attn_qkvpacked_func(
68 | local_qkv,
69 | dropout_p=dropout_p,
70 | causal=causal,
71 | window_size=(-1, -1),
72 | alibi_slopes=None,
73 | deterministic=deterministic,
74 | return_attn_probs=True,
75 | )
76 |
77 | log("out", out, rank0_only=True)
78 | log("lse", lse, rank0_only=True)
79 | log("out diff", local_out - ring_out)
80 | log("lse diff", local_lse - ring_lse)
81 |
82 | dist.barrier()
83 | if rank == 0:
84 | print("#" * 30)
85 | print("# backward:")
86 | print("#" * 30)
87 |
88 | out.backward(dout)
89 | dqkv = qkv.grad
90 |
91 | local_dqkv = extract_local(dqkv, rank, world_size)
92 |
93 | ring_out.backward(local_dout)
94 | ring_dqkv = local_qkv.grad
95 |
96 | log("local_dqkv", local_dqkv)
97 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
98 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
99 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])
100 |
--------------------------------------------------------------------------------
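
A minimal CPU-only sketch (not part of the repository) of the zigzag split used by `extract_local` in the test above, with made-up tensors: the sequence is cut into `2 * world_size` chunks and rank r keeps the pair (r, 2*world_size-1-r), so every rank holds one early and one late chunk and the causal-attention work is balanced across ranks. `zigzag_extract_local` is a standalone copy of that helper.

```python
import torch

def zigzag_extract_local(value, rank, world_size, dim=1):
    # Same pairing rule as extract_local in test_zigzag_ring_flash_attn_func.py.
    value_chunks = value.chunk(2 * world_size, dim=dim)
    return torch.cat(
        [value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim
    ).contiguous()

if __name__ == "__main__":
    world_size = 4
    token_ids = torch.arange(16).view(1, 16)  # 8 chunks of 2 tokens each
    for rank in range(world_size):
        print(rank, zigzag_extract_local(token_ids, rank, world_size).tolist())
    # rank 0 -> [[0, 1, 14, 15]], rank 1 -> [[2, 3, 12, 13]], ...
```
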
/train/ring-flash-attention/test/test_zigzag_ring_flash_attn_varlen_func.py:
--------------------------------------------------------------------------------
1 | from flash_attn import flash_attn_varlen_qkvpacked_func
2 | import torch
3 | import torch.distributed as dist
4 | from ring_flash_attn import zigzag_ring_flash_attn_varlen_qkvpacked_func
5 | from utils import log, set_seed
6 |
7 |
8 | def extract_local(value, cu_seqlens, rank, world_size):
9 | local_values = []
10 | for i in range(len(cu_seqlens) - 1):
11 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
12 | local_value = value[start:end].chunk(2 * world_size, dim=0)
13 | local_values.extend(
14 | [
15 | local_value[rank].detach().clone(),
16 | local_value[2 * world_size - 1 - rank].detach().clone(),
17 | ]
18 | )
19 | return torch.cat(local_values, dim=0).contiguous()
20 |
21 |
22 | def extract_lse(lse, cu_seqlens):
23 | values = []
24 | if lse.dim() == 2:
25 | for i in range(len(cu_seqlens) - 1):
26 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
27 | value = lse[:, start:end]
28 | values.append(value)
29 | else:
30 | assert lse.dim() == 3
31 | for i in range(len(cu_seqlens) - 1):
32 | start, end = cu_seqlens[i], cu_seqlens[i + 1]
33 | value = lse[i, :, : end - start]
34 | values.append(value)
35 | return values
36 |
37 |
38 | if __name__ == "__main__":
39 | dist.init_process_group("nccl")
40 | rank = dist.get_rank()
41 | set_seed(rank)
42 | world_size = dist.get_world_size()
43 | dtype = torch.bfloat16
44 | device = torch.device(f"cuda:{rank}")
45 |
46 | batch_size = 1
47 | nheads = 5
48 | d = 128
49 | dropout_p = 0
50 | causal = True
51 | deterministic = False
52 |
53 | cu_seqlens = [0, 128, 1248, 4240]
54 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
55 | max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
56 | total_length = cu_seqlens[-1]
57 | num_seq = len(cu_seqlens) - 1
58 |
59 | assert torch.all(cu_seqlens_tensor % (2 * world_size) == 0)
60 | assert d % 8 == 0
61 |
62 | qkv = torch.randn(
63 | total_length, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
64 | )
65 | dist.broadcast(qkv, src=0)
66 |
67 | dout = torch.randn(total_length, nheads, d, device=device, dtype=dtype)
68 | dist.broadcast(dout, src=0)
69 |
70 | local_cu_seqlens_tensor = cu_seqlens_tensor // world_size
71 | local_max_seqlen = max_seqlen // world_size
72 |
73 | local_qkv = extract_local(qkv, cu_seqlens, rank, world_size)
74 | local_qkv.requires_grad = True
75 | local_dout = extract_local(dout, cu_seqlens, rank, world_size)
76 |
77 | dist.barrier()
78 | if rank == 0:
79 | print("#" * 30)
80 | print("# forward:")
81 | print("#" * 30)
82 |
83 | out, lse, _ = flash_attn_varlen_qkvpacked_func(
84 | qkv,
85 | cu_seqlens_tensor,
86 | max_seqlen,
87 | dropout_p=dropout_p,
88 | causal=causal,
89 | window_size=(-1, -1),
90 | alibi_slopes=None,
91 | deterministic=deterministic,
92 | return_attn_probs=True,
93 | )
94 |
95 | local_out = extract_local(out, cu_seqlens, rank, world_size)
96 | lse_list = extract_lse(lse, cu_seqlens)
97 |
98 | ring_out, ring_lse, _ = zigzag_ring_flash_attn_varlen_qkvpacked_func(
99 | local_qkv,
100 | local_cu_seqlens_tensor,
101 | local_max_seqlen,
102 | dropout_p=dropout_p,
103 | causal=causal,
104 | window_size=(-1, -1),
105 | alibi_slopes=None,
106 | deterministic=deterministic,
107 | return_attn_probs=True,
108 | )
109 |
110 | ring_lse_list = extract_lse(ring_lse, local_cu_seqlens_tensor.tolist())
111 |
112 | log("out", out, rank0_only=True)
113 | log("out diff", local_out - ring_out)
114 |
115 | for i, (lse, ring_lse) in enumerate(zip(lse_list, ring_lse_list)):
116 | local_lse = lse.chunk(2 * world_size, dim=-1)
117 | local_lse = torch.cat(
118 | [local_lse[rank], local_lse[2 * world_size - 1 - rank]], dim=-1
119 | )
120 | log(f"lse {i}", lse, rank0_only=True)
121 | log(f"lse diff {i}", local_lse - ring_lse)
122 |
123 | dist.barrier()
124 | if rank == 0:
125 | print("#" * 30)
126 | print("# backward:")
127 | print("#" * 30)
128 |
129 | out.backward(dout)
130 | dqkv = qkv.grad
131 | local_dqkv = extract_local(dqkv, cu_seqlens, rank, world_size)
132 |
133 | ring_out.backward(local_dout)
134 | ring_dqkv = local_qkv.grad
135 |
136 | log("local_dqkv", local_dqkv)
137 | log("dq diff", local_dqkv[:, 0] - ring_dqkv[:, 0])
138 | log("dk diff", local_dqkv[:, 1] - ring_dqkv[:, 1])
139 | log("dv diff", local_dqkv[:, 2] - ring_dqkv[:, 2])
140 |
--------------------------------------------------------------------------------
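
The same idea for the varlen case, again as a CPU-only sketch with made-up `cu_seqlens`: each packed sequence is zigzag-split independently, which is why the test asserts that every boundary in `cu_seqlens` is divisible by `2 * world_size` and why the local `cu_seqlens` are simply the global ones divided by `world_size`. `zigzag_varlen_extract_local` is a standalone copy of the test's helper.

```python
import torch

def zigzag_varlen_extract_local(value, cu_seqlens, rank, world_size):
    # Same logic as extract_local in test_zigzag_ring_flash_attn_varlen_func.py.
    local_values = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        chunks = value[start:end].chunk(2 * world_size, dim=0)
        local_values.extend([chunks[rank], chunks[2 * world_size - 1 - rank]])
    return torch.cat(local_values, dim=0)

if __name__ == "__main__":
    world_size = 2
    cu_seqlens = [0, 8, 16]  # two toy packed sequences of 8 tokens each
    token_ids = torch.arange(16)
    for rank in range(world_size):
        print(rank, zigzag_varlen_extract_local(token_ids, cu_seqlens, rank, world_size).tolist())
    # rank 0 -> [0, 1, 6, 7, 8, 9, 14, 15], rank 1 -> [2, 3, 4, 5, 10, 11, 12, 13]
```
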
/train/ring-flash-attention/test/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | def set_seed(rank, seed=42):
8 | seed = rank + seed
9 | random.seed(seed)
10 | torch.manual_seed(seed)
11 | torch.cuda.manual_seed(seed)
12 | torch.cuda.manual_seed_all(seed)
13 |
14 |
15 | def log(msg, a, rank0_only=False):
16 | world_size = dist.get_world_size()
17 | rank = dist.get_rank()
18 | if rank0_only:
19 | if rank == 0:
20 | print(
21 | f"{msg}: "
22 | f"max {a.abs().max().item():.3g}, "
23 | f"mean {a.abs().mean().item():.3g}",
24 | flush=True,
25 | )
26 | return
27 |
28 | for i in range(world_size):
29 | if i == rank:
30 | if rank == 0:
31 | print(f"{msg}:")
32 | print(
33 | f"[{rank}] "
34 | f"max {a.abs().max().item():.3g}, "
35 | f"mean {a.abs().mean().item():.3g}",
36 | flush=True,
37 | )
38 | dist.barrier()
39 |
--------------------------------------------------------------------------------
/train/ring_attn_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import transformers
4 | from typing import Optional
5 | import torch.distributed as dist
6 | import torch.nn.functional as F
7 | from transformers.modeling_flash_attention_utils import (
8 | _flash_supports_window_size,
9 | is_flash_attn_greater_or_equal,
10 | )
11 | from ring_flash_attn.llama3_flash_attn_varlen import (
12 | llama3_flash_attn_varlen_func,
13 | llama3_flash_attn_prepare_cu_seqlens
14 | )
15 | from flash_attn.losses.cross_entropy import CrossEntropyLoss
16 | from flash_attn.ops.rms_norm import rms_norm
17 |
18 | def flash_rms_norm(self, x):
19 | return rms_norm(x, self.weight, self.variance_epsilon)
20 |
21 | RING_ATTN_GROUP = None
22 |
23 |
24 | def set_ring_attn_group(group):
25 | global RING_ATTN_GROUP
26 | RING_ATTN_GROUP = group
27 |
28 |
29 | def get_ring_attn_group():
30 | return RING_ATTN_GROUP
31 |
32 |
33 | def reset_ring_attn_position_ids(start, end, packed_seq_lens):
34 | """
35 | Calculate position ids for packed_seq_ids[start:end].
36 | For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8,
37 | the position ids will be [2, 0, 1, 0, 1, 2].
38 |
39 | Args:
40 | start: the start position
41 | end: the end position
42 | packed_seq_lens: the sequence lengths of packed sequences
43 | """
44 | position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device())
45 | offset = 0
46 | for seqlen in packed_seq_lens:
47 | seq_start = max(offset, start)
48 | seq_end = min(offset + seqlen, end)
49 | if seq_start < seq_end:
50 | position_ids[0, seq_start - start: seq_end - start] = torch.arange(seq_start - offset, seq_end - offset)
51 |
52 | offset += seqlen
53 | if offset >= end:
54 | break
55 | return position_ids
56 |
57 |
58 | def update_ring_attn_params(packed_seq_lens, total_seq_len):
59 | """
60 | Calculate the cu_seqlens for the current forward pass and pass the value to
61 | the substituted ring_flash_attn.
62 |
63 | Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding.
64 | """
65 | assert RING_ATTN_GROUP is not None
66 | cu_seqlens = torch.cumsum(
67 |         packed_seq_lens.clone().detach(),  # already an int tensor on the current device, so no torch.tensor(...) conversion is needed
68 | dim=-1,
69 | dtype=torch.int32,
70 | )
71 | cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
72 | update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP)
73 |
74 |
75 |
76 | def convert_to_local_input(input_ids, labels, attention_mask, packed_seq_lens, category_ids):
77 | ring_attn_rank = dist.get_rank(group=RING_ATTN_GROUP)
78 | ring_attn_size = dist.get_world_size(group=RING_ATTN_GROUP)
79 | total_seq_len = input_ids.numel()
80 | local_seq_len = total_seq_len // ring_attn_size
81 | start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len
82 | local_input_ids = input_ids[:, start:end]
83 | local_labels_ids = labels[:, start:end]
84 | local_attention_mask = attention_mask[:, start:end]
85 | local_category_ids = category_ids[:, start:end]
86 | local_position_ids = reset_ring_attn_position_ids(start, end, packed_seq_lens)
87 | update_ring_attn_params(packed_seq_lens, total_seq_len)
88 | return local_input_ids, local_labels_ids, local_attention_mask, local_position_ids, local_category_ids
89 |
90 |
91 | DATA_PARAMS = {}
92 |
93 |
94 | def update_ring_flash_attn_params(
95 | cu_seqlens: torch.Tensor, process_group: dist.ProcessGroup
96 | ):
97 | world_size = dist.get_world_size(group=process_group)
98 | rank = dist.get_rank(group=process_group)
99 | (
100 | cu_seqlens_q,
101 | cu_seqlens_k,
102 | max_seqlen_q,
103 | max_seqlen_k,
104 | local_k_slice,
105 | ) = llama3_flash_attn_prepare_cu_seqlens(cu_seqlens, True, rank, world_size)
106 | DATA_PARAMS.update(
107 | {
108 | "cu_seqlens_q": cu_seqlens_q,
109 | "cu_seqlens_k": cu_seqlens_k,
110 | "max_seqlen_q": max_seqlen_q,
111 | "max_seqlen_k": max_seqlen_k,
112 | "local_k_slice": local_k_slice,
113 | }
114 | )
115 |
116 |
117 | def create_ring_flash_attention_forward(
118 | process_group: dist.ProcessGroup, heads_k_stride: int
119 | ):
120 | def _flash_attention_forward(
121 | query_states: torch.Tensor,
122 | key_states: torch.Tensor,
123 | value_states: torch.Tensor,
124 | attention_mask: torch.Tensor,
125 | query_length: int,
126 | is_causal: bool,
127 | dropout: float = 0.0,
128 | position_ids: Optional[torch.Tensor] = None,
129 | softmax_scale: Optional[float] = None,
130 | sliding_window: Optional[int] = None,
131 | use_top_left_mask: bool = False,
132 | softcap: Optional[float] = None,
133 | deterministic: bool = None,
134 | ):
135 | """
136 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
137 | first unpad the input, then computes the attention scores and pad the final attention scores.
138 |
139 | Args:
140 | query_states (`torch.Tensor`):
141 | Input query states to be passed to Flash Attention API
142 | key_states (`torch.Tensor`):
143 | Input key states to be passed to Flash Attention API
144 | value_states (`torch.Tensor`):
145 | Input value states to be passed to Flash Attention API
146 | attention_mask (`torch.Tensor`):
147 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
148 | position of padding tokens and 1 for the position of non-padding tokens.
149 | dropout (`float`):
150 | Attention dropout
151 | softmax_scale (`float`, *optional*):
152 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
153 | use_top_left_mask (`bool`, defaults to `False`):
154 | flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
155 | softcap (`float`, *optional*):
156 | Softcap for the attention logits, used e.g. in gemma2.
157 | deterministic (`bool`, *optional*):
158 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
159 | """
160 | if not use_top_left_mask:
161 | causal = is_causal
162 | else:
163 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
164 | causal = is_causal and query_length != 1
165 |
166 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
167 | use_sliding_windows = (
168 | _flash_supports_window_size
169 | and sliding_window is not None
170 | and key_states.shape[1] > sliding_window
171 | )
172 | flash_kwargs = (
173 | {"window_size": (sliding_window, sliding_window)}
174 | if use_sliding_windows
175 | else {}
176 | )
177 |
178 | if is_flash_attn_greater_or_equal("2.4.1"):
179 | if deterministic is None:
180 | deterministic = (
181 | os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
182 | )
183 | flash_kwargs["deterministic"] = deterministic
184 | assert (
185 | softcap is None
186 | ), "llama3_flash_attn_varlen_func does not support softcap yet."
187 | # flash_kwargs["softcap"] = softcap
188 | flash_kwargs["group"] = process_group
189 |
190 | # not sure why attention_mask can be not None...
191 | assert causal, "only causal attention is supported yet."
192 | batch_size = query_states.size(0)
193 | assert batch_size == 1, "varlen data should be processed in advance."
194 |
195 | attn_output = llama3_flash_attn_varlen_func(
196 | query_states.squeeze(dim=0),
197 | key_states.squeeze(dim=0),
198 | value_states.squeeze(dim=0),
199 | cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
200 | cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
201 | max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
202 | max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
203 | heads_k_stride=heads_k_stride,
204 | local_k_slice=DATA_PARAMS["local_k_slice"],
205 | dropout_p=dropout,
206 | softmax_scale=softmax_scale,
207 | causal=causal,
208 | **flash_kwargs,
209 | )
210 |
211 | attn_output = attn_output.unsqueeze(dim=0)
212 |
213 | return attn_output
214 |
215 | return _flash_attention_forward
216 |
217 |
218 | class Config(object):
219 | def __init__(self, ring_attn_size, ring_head_stride):
220 | self.ring_attn_size = ring_attn_size
221 | self.ring_head_stride = ring_head_stride
222 | self.ring_attn_rank = None
223 |
224 |
225 | def setup_ring_attn(self):
226 | for i in range(dist.get_world_size() // self.ring_attn_size):
227 | ring_attn_ranks = list(
228 | range(
229 | i * self.ring_attn_size,
230 | (i + 1) * self.ring_attn_size,
231 | )
232 | )
233 | group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
234 | if dist.get_rank() in ring_attn_ranks:
235 | set_ring_attn_group(group)
236 | self.ring_attn_rank = dist.get_rank(group=group)
237 |
238 | transformers.models.qwen2.modeling_qwen2._flash_attention_forward = create_ring_flash_attention_forward(
239 | RING_ATTN_GROUP,
240 | heads_k_stride=self.ring_head_stride
241 | )
242 | transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm.forward = flash_rms_norm
243 | transformers.models.qwen2.modeling_qwen2.CrossEntropyLoss = CrossEntropyLoss
244 |
245 | transformers.models.llama.modeling_llama._flash_attention_forward = create_ring_flash_attention_forward(
246 | RING_ATTN_GROUP,
247 | heads_k_stride=self.ring_head_stride
248 | )
249 | transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = flash_rms_norm
250 | transformers.models.llama.modeling_llama.CrossEntropyLoss = CrossEntropyLoss
251 | return RING_ATTN_GROUP
252 |
253 |
--------------------------------------------------------------------------------
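
A CPU-only sketch (with made-up lengths) of the two pieces of bookkeeping above: how `update_ring_attn_params` turns `packed_seq_lens` into `cu_seqlens`, and what `reset_ring_attn_position_ids` produces for the example in its docstring. `demo_position_ids` is a standalone copy of that logic so it can run without `flash_attn` or a distributed group.

```python
import torch
import torch.nn.functional as F

def demo_position_ids(start, end, packed_seq_lens):
    # Standalone copy of the reset_ring_attn_position_ids logic above, on CPU.
    position_ids = torch.zeros((1, end - start), dtype=torch.long)
    offset = 0
    for seqlen in packed_seq_lens:
        seq_start = max(offset, start)
        seq_end = min(offset + seqlen, end)
        if seq_start < seq_end:
            position_ids[0, seq_start - start: seq_end - start] = torch.arange(
                seq_start - offset, seq_end - offset
            )
        offset += seqlen
        if offset >= end:
            break
    return position_ids

if __name__ == "__main__":
    # cu_seqlens as built in update_ring_attn_params: cumsum, a leading 0,
    # and the (possibly padded) total length appended at the end.
    packed_seq_lens = torch.tensor([3, 2, 4, 1], dtype=torch.int32)
    total_seq_len = 12  # toy value, padded beyond sum(packed_seq_lens) = 10
    cu_seqlens = torch.cumsum(packed_seq_lens, dim=-1, dtype=torch.int32)
    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
    print(cu_seqlens.tolist())                    # [0, 3, 5, 9, 10, 12]
    print(demo_position_ids(2, 8, [3, 2, 4, 1]))  # tensor([[2, 0, 1, 0, 1, 2]])
```
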
/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import json
5 | import torch
6 | import argparse
7 | import pandas as pd
8 | from tqdm import tqdm
9 | from datetime import timedelta
10 | from accelerate import Accelerator
11 | from torch import distributed as dist
12 | from flash_attn.losses.cross_entropy import CrossEntropyLoss
13 | from flash_attn.utils.distributed import all_gather, all_reduce
14 | from transformers import AutoModelForCausalLM, AutoTokenizer
15 | from accelerate.utils import (
16 | InitProcessGroupKwargs,
17 | set_seed,
18 | DummyOptim,
19 | DummyScheduler,
20 | )
21 |
22 | os.environ["WANDB_MODE"] = "offline"
23 |
24 | from data import SFTData
25 | from ring_attn_utils import Config, convert_to_local_input
26 |
27 |
28 |
29 | def main(args):
30 | set_seed(args.seed)
31 | timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
32 | accelerator = Accelerator(
33 | gradient_accumulation_steps=args.gradient_accumulation_steps,
34 | mixed_precision="bf16",
35 | kwargs_handlers=[timeout],
36 | log_with="wandb"
37 | )
38 | accelerator.init_trackers(project_name=args.project_name)
39 | world_size = accelerator.num_processes
40 | sp_world_size = args.sequence_parallel_degree
41 | dp_world_size = world_size // sp_world_size
42 |
43 | accelerator.print(f"world_size: {world_size}")
44 | accelerator.print(f"sp_world_size: {sp_world_size}")
45 | accelerator.print(f"dp_world_size: {dp_world_size}")
46 |
47 | config = Config(ring_attn_size=sp_world_size, ring_head_stride=4)
48 | ring_group = config.setup_ring_attn()
49 |
50 | with open(args.data_config_path, 'r') as f:
51 | data_config = json.loads(f.read())
52 | idx2name = {}
53 | for file in data_config['ratio']:
54 | idx2name[file['category_id']] = file['category_name']
55 |
56 | sft_data = SFTData(args.data_path)
57 |
58 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
59 | model = AutoModelForCausalLM.from_pretrained(
60 | args.model_path,
61 | torch_dtype=torch.bfloat16,
62 | _attn_implementation="flash_attention_2",
63 | )
64 |
65 | num_steps_per_epoch = math.ceil(len(sft_data.data) / dp_world_size / args.gradient_accumulation_steps)
66 | max_train_steps = num_steps_per_epoch * args.num_epochs
67 |
68 | accelerator.print(f"total samples: {len(sft_data.data)}")
69 | accelerator.print(f"total training steps: {max_train_steps}")
70 |
71 | optim = DummyOptim(model.parameters())
72 | scheduler = DummyScheduler(optim)
73 |
74 | local_ce = CrossEntropyLoss(reduction="none")
75 |
76 | model, optim, scheduler = accelerator.prepare(model, optim, scheduler)
77 | model.gradient_checkpointing_enable()
78 | accelerator.register_for_checkpointing(scheduler)
79 | model.train()
80 |
81 | global_step = 0
82 | for epoch in range(args.num_epochs):
83 | train_loader = sft_data.get_dataloader(dp_world_size, dist.get_rank() // sp_world_size, seed=args.seed,
84 | epoch=epoch, shuffle=True)
85 | train_steps = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
86 | accelerator.print(f"samples per epoch: {len(train_loader)}")
87 | accelerator.print(f"training steps per epoch: {train_steps}")
88 | progress_bar = tqdm(
89 | range(train_steps), disable=not accelerator.is_local_main_process
90 | )
91 | for step, batch in enumerate(train_loader):
92 | input_ids = batch['input_ids'].to(accelerator.device)
93 | labels = batch['labels'].to(accelerator.device)
94 | attention_mask = batch['attention_mask'].to(accelerator.device)
95 | packed_seq_lens = batch['packed_seq_lens'].to(accelerator.device)
96 | category_ids = batch['category_ids'].to(accelerator.device)
97 |
98 | local_input_ids, local_labels, local_attention_mask, local_position_ids, local_category_ids = convert_to_local_input(
99 | input_ids, labels, attention_mask, packed_seq_lens, category_ids
100 | )
101 |
102 | with accelerator.accumulate(model):
103 | out = model(
104 | input_ids=local_input_ids,
105 | attention_mask=local_attention_mask,
106 | position_ids=local_position_ids,
107 | )
108 | shift_logits = out.logits[..., :-1, :].contiguous()
109 | shift_labels = local_labels[..., 1:].contiguous()
110 | token_losses = local_ce(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
111 |
112 | if args.use_turn_loss:
113 | all_token_losses = all_gather(token_losses, ring_group)
114 | all_shift_labels = all_gather(shift_labels, ring_group)
115 | mask = all_shift_labels.view(-1) != -100
116 | idx = torch.where(mask)[0]
117 | diffs = torch.diff(idx)
118 | split_points = torch.where(diffs > 1)[0] + 1
119 | groups = torch.tensor_split(idx, split_points.tolist())
120 | loss = torch.stack([all_token_losses[g].mean() for g in groups]).mean()
121 | loss = torch.nan_to_num(loss, nan=0.0)
122 | else:
123 | loss = token_losses[shift_labels.view(-1) != -100].mean()
124 | dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=ring_group)
125 | loss = loss / sp_world_size
126 | loss = torch.nan_to_num(loss, nan=0.0)
127 | accelerator.backward(loss)
128 |
129 | if accelerator.sync_gradients:
130 | gathered_loss = accelerator.gather(loss.clone().detach())
131 | mask = (gathered_loss != 0)
132 | if mask.sum() == 0:
133 | loss_ = torch.tensor(0.0, device=accelerator.device)
134 | else:
135 | loss_ = gathered_loss[mask].mean()
136 | loss_log = {
137 | "epoch": epoch,
138 | "steps": step,
139 | "lr": scheduler.get_last_lr()[0],
140 | "loss": loss_.item()
141 | }
142 |
143 | progress_bar.set_postfix(loss_log)
144 | progress_bar.update(1)
145 | time.sleep(0.1)
146 |
147 | token_losses_ = token_losses.clone().detach()
148 | all_token_losses = accelerator.gather(token_losses_)
149 | all_category_ids = accelerator.gather(local_category_ids)[..., 1:].contiguous()
150 |
151 | idx2mean_loss = pd.DataFrame({
152 | 'category_id': all_category_ids.cpu().ravel(),
153 | 'token_loss': all_token_losses.cpu().numpy().ravel()
154 | }).groupby('category_id')['token_loss'].mean().to_dict()
155 |
156 | for idx, mean_loss in idx2mean_loss.items():
157 | if idx == 0 or idx not in idx2name:
158 | continue
159 | loss_log[f"{idx2name[idx]}_loss"] = mean_loss
160 |
161 | loss_log_str = '\n' + json.dumps(loss_log, ensure_ascii=False, indent=4)
162 | accelerator.print(loss_log_str)
163 | accelerator.log(loss_log, step=global_step)
164 | global_step += 1
165 |
166 | optim.step()
167 | scheduler.step()
168 | optim.zero_grad()
169 |
170 | if args.save_checkpoint is not None:
171 | if not os.path.exists(args.save_checkpoint):
172 | os.makedirs(args.save_checkpoint, exist_ok=True)
173 | save_path = f"{args.save_checkpoint}/epoch{epoch}_end"
174 | accelerator.print(f"Saving model to {save_path}")
175 | accelerator.wait_for_everyone()
176 | state_dict = accelerator.get_state_dict(model)
177 | accelerator.unwrap_model(model).save_pretrained(
178 | save_path,
179 | is_main_process=accelerator.is_main_process,
180 | save_function=accelerator.save,
181 | state_dict=state_dict,
182 | )
183 | if accelerator.is_main_process:
184 | tokenizer.save_pretrained(save_path)
185 | os.remove(f"{save_path}/model.safetensors")
186 |
187 | accelerator.print("Saving Finished")
188 | 
189 | accelerator.print("Training Finished")
190 | accelerator.end_training()
191 |
192 |
193 | if __name__ == "__main__":
194 | args = argparse.ArgumentParser()
195 | args.add_argument("--project_name", type=str, default='')
196 | args.add_argument("--gradient_accumulation_steps", type=int, default=1)
197 | args.add_argument("--save_checkpoint", type=str, default="")
198 | args.add_argument("--seed", type=int, default=2025)
199 | args.add_argument("--model_path", type=str, default="")
200 | args.add_argument("--data_path", type=str, default="")
201 | args.add_argument("--data_config_path", type=str, default="")
202 | args.add_argument("--sequence_parallel_degree", type=int, default=8)
203 | args.add_argument("--num_epochs", type=int, default=4)
204 | args.add_argument("--use_turn_loss", action='store_true')
205 | main(args.parse_args())
--------------------------------------------------------------------------------
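
A toy, CPU-only sketch (values made up) of the turn-loss branch in train.py: positions whose shifted label is not -100 are split into contiguous runs, one run per assistant turn, each turn's token losses are averaged, and the per-turn means are averaged again so short and long turns contribute equally. Compared with the plain token mean in the else-branch, this keeps one long turn from dominating the loss of a short one.

```python
import torch

token_losses = torch.tensor([0.5, 0.7, 9.0, 9.0, 1.0, 2.0, 3.0])  # per-token CE
shift_labels = torch.tensor([7,   8,  -100, -100, 5,   6,   9])   # -100 = no loss

mask = shift_labels.view(-1) != -100
idx = torch.where(mask)[0]                    # tensor([0, 1, 4, 5, 6])
diffs = torch.diff(idx)
split_points = torch.where(diffs > 1)[0] + 1  # a gap > 1 starts a new turn
groups = torch.tensor_split(idx, split_points.tolist())
loss = torch.stack([token_losses[g].mean() for g in groups]).mean()
print([g.tolist() for g in groups])  # [[0, 1], [4, 5, 6]]
print(round(loss.item(), 4))         # mean(0.6, 2.0) = 1.3
```
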
/train/train.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | export HOME=/tmp
3 | export CUDA_DEVICE_MAX_CONNECTIONS=1
4 | export PYTORCH_CUDA_ALLOC_CONF='max_split_size_mb:2048'
5 |
6 | ip=`hostname -i`
7 |
8 | project_name=Qwen2.5-72B-Instruct
9 | model_path=path_to_model
10 | data_path=../data/dataset/distill_r1_110k.packed.jsonl
11 | data_config_path=../data/config/distill_r1_110k.json
12 | save_checkpoint=path_to_save_dir
13 |
14 | python config.py --max_steps=2000 --learning_rate=2e-5 --warmup_num_steps=200
15 |
16 | master_ip=192.168.1.1 # IP of the main node (rank 0)
17 | num_machines=24
18 | num_processes=192
19 | machine_rank=0 # 0 to num_machines-1, one rank per machine
20 |
21 | nohup accelerate launch \
22 | --config_file config/multi_node.yaml \
23 | --num_processes=$num_processes \
24 | --num_machines=$num_machines \
25 | --machine_rank=$machine_rank \
26 | --main_process_ip=$master_ip \
27 | train.py \
28 | --project_name $project_name \
29 | --gradient_accumulation_steps 1 \
30 | --save_checkpoint $save_checkpoint \
31 | --seed 2025 \
32 | --model_path $model_path \
33 | --data_path $data_path \
34 | --data_config_path $data_config_path \
35 | --sequence_parallel_degree 8 \
36 | --num_epochs 4 > logs/$ip.log 2>&1 &
--------------------------------------------------------------------------------