├── eagle
│   ├── __init__.py
│   ├── prepare_sharegpt_dataset.py
│   ├── generate_trajectories.py
│   ├── train_tp.py
│   ├── train.py
│   └── llama2.py
├── benchmark
│   ├── __init__.py
│   ├── README.md
│   ├── create_alpaca_prompts.py
│   └── benchmark.py
├── .gitignore
├── pages
│   ├── README.md
│   ├── train_qwen_3.md
│   └── reproduce_eagle_1_paper.md
├── resources
│   ├── eagle_config.json
│   ├── eagle_config_qwen3_8b.json
│   ├── example_chat_template_with_generation_keyword.json
│   └── raw_example_dataset.jsonl
├── README.md
├── docker
│   └── Dockerfile
├── .github
│   └── workflows
│       └── docker_build_push.yaml
└── LICENSE
/eagle/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | tokenized_dataset
2 | eagle.egg-info
3 | checkpoints
4 | wandb
5 | vllm
6 | venv
7 | eagle/__pycache__
8 | models
9 | *.crt
10 |
--------------------------------------------------------------------------------
/pages/README.md:
--------------------------------------------------------------------------------
1 | ## Pages
2 |
3 | Explore the pages below to get started
4 |
5 | * [Reproduce eagle 1 paper on sharegpt dataset](./reproduce_eagle_1_paper.md)
6 |
--------------------------------------------------------------------------------
/resources/eagle_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LlamaForCausalLM"
4 | ],
5 | "model_type": "llama",
6 | "hidden_act": "silu",
7 | "hidden_size": 4096,
8 | "initializer_range": 0.02,
9 | "torch_dtype": "bfloat16",
10 | "intermediate_size": 11008,
11 | "num_attention_heads": 32,
12 | "num_hidden_layers": 1,
13 | "num_key_value_heads": 32,
14 | "rms_norm_eps": 1e-06
15 | }
16 |
--------------------------------------------------------------------------------
/resources/eagle_config_qwen3_8b.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "LlamaForCausalLM"
4 | ],
5 | "model_type": "llama",
6 | "hidden_act": "silu",
7 | "hidden_size": 4096,
8 | "initializer_range": 0.02,
9 | "torch_dtype": "bfloat16",
10 | "intermediate_size": 12288,
11 | "num_attention_heads": 32,
12 | "num_hidden_layers": 1,
13 | "num_key_value_heads": 32,
14 | "rms_norm_eps": 1e-06,
15 | "vocab_size": 151936
16 | }
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## 🦅 eagle
2 |
3 | This repository lets you train an eagle draft model fully compatible with [SGLang](https://github.com/sgl-project/sglang) that achieves the [paper](https://arxiv.org/abs/2401.15077v1) score in terms of end-to-end latency speedup and generation throughput. I will keep working on this project to make it as minimalistic as possible while keeping it scalable, so that you can train a SOTA eagle draft model in under 1 hour on a single node of enterprise GPUs, though it is not limited to that setup. Check out the [pages](./pages/README.md) to get started
4 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.9.1-devel-ubuntu24.04
2 |
3 | WORKDIR /opt/eagle
4 |
5 | ENV PYTHONUNBUFFERED=1
6 | ENV PYTHONPATH=/opt/eagle
7 |
8 | RUN apt update
9 | RUN apt install --yes python3-pip python3.12-venv
10 | RUN python3 -m venv /venv
11 |
12 | ENV PATH="/venv/bin:$PATH"
13 | RUN pip install --upgrade pip
14 |
15 | RUN pip install torch==2.7.1 transformers==4.52.4 datasets==3.6.0 accelerate==1.8.1 clearml==2.0.0
16 | RUN pip install psutil
17 | RUN pip uninstall -y ninja && pip install ninja
18 | RUN MAX_JOBS=4 pip install flash-attn==2.8.0.post2 --no-build-isolation
19 |
20 | COPY ./eagle ./eagle
21 |
--------------------------------------------------------------------------------
/.github/workflows/docker_build_push.yaml:
--------------------------------------------------------------------------------
1 | name: Build and push release docker image
2 |
3 | on:
4 | push:
5 | tags:
6 | - "*"
7 |
8 | jobs:
9 | docker:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Authenticate to docker hub
13 | uses: docker/login-action@v3
14 | with:
15 | username: ${{ secrets.DOCKER_USERNAME }}
16 | password: ${{ secrets.DOCKER_PASSWORD }}
17 | registry: ${{ secrets.DOCKER_REGISTRY_URL }}
18 |
19 | - name: Configure qemu
20 | uses: docker/setup-qemu-action@v3
21 |
22 | - name: Configure buildx
23 | uses: docker/setup-buildx-action@v3
24 |
25 | - name: Build and push release docker image
26 | uses: docker/build-push-action@v6
27 | with:
28 | push: true
29 | file: docker/Dockerfile
30 | tags: |
31 | ${{ secrets.DOCKER_USERNAME }}/eagle:${{ github.ref_name }}
32 | ${{ secrets.DOCKER_USERNAME }}/eagle:latest
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Vladislav Kruglikov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/resources/example_chat_template_with_generation_keyword.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_bos_token": true,
3 | "add_eos_token": false,
4 | "bos_token": {
5 | "__type": "AddedToken",
6 | "content": "",
7 | "lstrip": false,
8 | "normalized": false,
9 | "rstrip": false,
10 | "single_word": false
11 | },
12 | "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{% generation %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endgeneration %}{% endif %}{% endfor %}",
13 | "clean_up_tokenization_spaces": false,
14 | "eos_token": {
15 | "__type": "AddedToken",
16 | "content": "",
17 | "lstrip": false,
18 | "normalized": false,
19 | "rstrip": false,
20 | "single_word": false
21 | },
22 | "legacy": false,
23 | "model_max_length": 1000000000000000019884624838656,
24 | "pad_token": null,
25 | "padding_side": "right",
26 | "sp_model_kwargs": {},
27 | "tokenizer_class": "LlamaTokenizer",
28 | "unk_token": {
29 | "__type": "AddedToken",
30 | "content": "",
31 | "lstrip": false,
32 | "normalized": false,
33 | "rstrip": false,
34 | "single_word": false
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/benchmark/README.md:
--------------------------------------------------------------------------------
1 | ## 📊 Benchmark
2 |
3 | Create alpaca prompts
4 |
5 | ```bash
6 | python3 ./benchmark/create_alpaca_prompts.py \
7 | --output ./benchmark/alpaca.jsonl \
8 | --n 128
9 | ```
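
Each line of `alpaca.jsonl` carries a single `prompt` field with the alpaca template already applied, roughly like this (the instruction is only illustrative):

```json
{"prompt": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nName three primary colors.\n\n### Response:\n"}
```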
10 |
11 | Benchmark the base model with batch size 1, meaning at most one request runs at a time
12 |
13 | ```bash
14 | docker run \
15 | --gpus all \
16 | -e CUDA_VISIBLE_DEVICES=0 \
17 | -v ./models/meta-llama2-7b-chat-hf:/mnt/llama2-7b \
18 | -v ./eagle_model:/mnt/eagle \
19 | -v ./benchmark:/opt/benchmark \
20 | --ipc=host \
21 | --shm-size 32g \
22 | lmsysorg/sglang:v0.4.6.post5-cu124 \
23 | bash -c "cd /opt && export PYTHONPATH=$PYTHONPATH:. && python3 benchmark/benchmark.py \
24 | --model /mnt/llama2-7b \
25 | --prompts benchmark/alpaca.jsonl \
26 | --n 64 \
27 | --bs 1 \
28 | --output benchmark/report_alpaca_bs1_wo_eagle.json \
29 | --temperature 0"
30 | ```
31 |
32 | Benchmark the base model together with the eagle draft model, again with batch size 1 so that at most one request runs at a time
33 |
34 | ```bash
35 | docker run \
36 | --gpus all \
37 | -e CUDA_VISIBLE_DEVICES=1 \
38 | -v ./models/meta-llama2-7b-chat-hf:/mnt/llama2-7b \
39 | -v ./step5000/vllm:/mnt/eagle \
40 | -v ./benchmark:/opt/benchmark \
41 | -v ./checkpoints:/opt/checkpoints \
42 | --ipc=host \
43 | --shm-size 32g \
44 | lmsysorg/sglang:v0.4.6.post5-cu124 \
45 | bash -c "cd /opt && export PYTHONPATH=$PYTHONPATH:. && python3 benchmark/benchmark.py \
46 | --model /mnt/llama2-7b \
47 | --prompts benchmark/alpaca.jsonl \
48 | --n 64 \
49 | --bs 1 \
50 | --output benchmark/report_alpaca_bs1_with_eagle_new_new.json \
51 | --eagle /opt/checkpoints/step_10/sglang \
52 | --steps 4 \
53 | --k 1 \
54 | --draft 4 \
55 | --speculative-algorithm EAGLE \
56 | --temperature 0"
57 | ```
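
Each run writes a small JSON report with the fields produced by `benchmark/benchmark.py` (the numbers below are only illustrative; `acceptance_length` appears only when a speculative algorithm is used):

```json
{
    "output_throughput": 95.3,
    "total_output_tokens": 18432,
    "total_latency": 193.4,
    "acceptance_length": 3.1
}
```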
58 |
--------------------------------------------------------------------------------
/benchmark/create_alpaca_prompts.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import argparse
3 | import datasets
4 |
5 |
6 | def create_alpaca_prompts() -> None:
7 | arguments = _parse_arguments()
8 | output_path = arguments.output
9 | n = arguments.n
10 | frac = arguments.frac
11 |     if (n is None) == (frac is None):
12 |         raise ValueError("Exactly one of --n or --frac must be set")
13 |
14 | print("Loading dataset")
15 | dataset = datasets.load_dataset("tatsu-lab/alpaca")["train"]
16 | print(f"Dataset has {len(dataset)} rows")
17 | dataset = dataset.shuffle(seed=0)
18 | indices = range(n) if n is not None else range(int(frac * len(dataset)))
19 | dataset = dataset.select(indices)
20 | print(f"Dataset after select has {len(dataset)} rows")
21 | dataset = dataset.map(
22 | _apply_template,
23 | batched=False,
24 | num_proc=1,
25 | remove_columns=dataset.column_names,
26 | desc="Applying template"
27 | )
28 |
29 | print("Saving to disk")
30 | dataset.to_json(output_path)
31 |
32 |
33 | def _parse_arguments() -> argparse.Namespace:
34 | parser = argparse.ArgumentParser(description="Create alpaca prompts")
35 | parser.add_argument(
36 | "--output",
37 | type=pathlib.Path,
38 | required=True,
39 | help="Path to save prompts"
40 | )
41 | parser.add_argument(
42 | "--n",
43 | type=int,
44 | help="Number of sampels to take"
45 | )
46 | parser.add_argument(
47 | "--frac",
48 | type=float,
49 | help="Number of sampels to take from 0.0 to 1.0 percent"
50 | )
51 | return parser.parse_args()
52 |
53 |
54 | def _apply_template(example: dict) -> dict:
55 | if example["input"] == "":
56 | return {"prompt": _ALPACA_TEMPLATE_WITHOUT_INPUT.format(instruction=example["instruction"])}
57 | else:
58 | return {"prompt": _ALPACA_TEMPLATE_WITH_INPUT.format(instruction=example["instruction"], input=example["input"])}
59 |
60 |
61 | _ALPACA_TEMPLATE_WITH_INPUT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
62 |
63 | ### Instruction:
64 | {instruction}
65 |
66 | ### Input:
67 | {input}
68 |
69 | ### Response:
70 | """
71 |
72 | _ALPACA_TEMPLATE_WITHOUT_INPUT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
73 |
74 | ### Instruction:
75 | {instruction}
76 |
77 | ### Response:
78 | """
79 |
80 |
81 | if __name__ == "__main__":
82 | create_alpaca_prompts()
83 |
--------------------------------------------------------------------------------
/pages/train_qwen_3.md:
--------------------------------------------------------------------------------
1 | ## Train qwen3
2 |
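Download the verifier model: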
3 | ```bash
4 | huggingface-cli download Qwen/Qwen3-8B --local-dir ./resources/qwen3-8b
5 | ```
6 |
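Train the draft model with `torchrun` (data parallel across the 4 visible GPUs; adjust `CUDA_VISIBLE_DEVICES` and `--nproc_per_node` to match your machine):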
7 | ```bash
8 | docker build --tag eagle -f docker/Dockerfile . && docker run \
9 | --gpus all \
10 | -e OMP_NUM_THREADS=4 \
11 | -e CUDA_VISIBLE_DEVICES=1,2,5,7 \
12 | -e CLEARML_OFFLINE_MODE=1 \
13 | -v ./resources:/mnt/resources \
14 | eagle \
15 | torchrun \
16 | --nnodes=1 \
17 | --nproc_per_node=4 \
18 | eagle/train.py \
19 | --micro-batch-size 2 \
20 | --gradient-accumulation-steps 2 \
21 | --num-warmup-steps 4096 \
22 | --num-training-steps 524288 \
23 | --epochs 4 \
24 | --clearml-project eagle \
25 | --clearml-task 4gpus-2microbs-2accum-16globalbs \
26 | --verifier-model-path /mnt/resources/qwen3-8b \
27 | --dataset-path /mnt/resources/sharegpt.jsonl \
28 | --eagle-config-path /mnt/resources/eagle_config_qwen3_8b.json \
29 | --learning-rate 2e-4 \
30 | --maximum-model-length 2048 \
31 | --noise-low -0.1 \
32 | --noise-high 0.1 \
33 | --v-w 1.0 \
34 | --p-w 0.1 \
35 | --grad-clip 0.5 \
36 | --b1 0.9 \
37 | --b2 0.95 \
38 | --cpdir /mnt/resources/checkpoints \
39 | --save 4096 \
40 | --mixed-precision bf16 \
41 | --verifier-model-lm-head-dtype bfloat16 \
42 | --verifier-model-dtype bfloat16 \
43 | --eagle-dtype bfloat16 \
44 | --attn flash_attention_2
45 | ```
46 |
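Alternatively, if the verifier model does not fit on a single GPU, `eagle/train_tp.py` runs as one process and loads the verifier with `device_map="auto"`, sharding it across all visible GPUs: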
47 | ```bash
48 | docker build --tag eagle -f docker/Dockerfile . && docker run \
49 | --gpus all \
50 | -e CUDA_VISIBLE_DEVICES=1,2,5,7 \
51 | -e CLEARML_OFFLINE_MODE=1 \
52 | -v ./resources:/mnt/resources \
53 | eagle python3 ./eagle/train_tp.py \
54 | --micro-batch-size 2 \
55 | --gradient-accumulation-steps 2 \
56 | --num-warmup-steps 4096 \
57 | --num-training-steps 524288 \
58 | --epochs 4 \
59 | --clearml-project eagle \
60 | --clearml-task 4gpus-2microbs-2accum-16globalbs \
61 | --verifier-model-path /mnt/resources/qwen3-8b \
62 | --dataset-path /mnt/resources/sharegpt.jsonl \
63 | --eagle-config-path /mnt/resources/eagle_config_qwen3_8b.json \
64 | --learning-rate 2e-4 \
65 | --maximum-model-length 2048 \
66 | --noise-low -0.1 \
67 | --noise-high 0.1 \
68 | --v-w 1.0 \
69 | --p-w 0.1 \
70 | --grad-clip 0.5 \
71 | --b1 0.9 \
72 | --b2 0.95 \
73 | --cpdir /mnt/resources/checkpoints \
74 | --save 4096 \
75 | --mixed-precision bf16 \
76 | --verifier-model-lm-head-dtype bfloat16 \
77 | --verifier-model-dtype bfloat16 \
78 | --eagle-dtype bfloat16 \
79 | --attn flash_attention_2
80 | ```
81 |
--------------------------------------------------------------------------------
/eagle/prepare_sharegpt_dataset.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import datasets
3 | import argparse
4 |
5 |
6 | def convert_sharegpt_dataset() -> None:
7 | arguments = _parse_arguments()
8 | output_path: pathlib.Path = arguments.output
9 | n = arguments.n
10 | frac = arguments.frac
11 |     if (n is None) == (frac is None):
12 |         raise ValueError("Exactly one of --n or --frac must be set")
13 |
14 | print("Loading raw dataset")
15 | dataset = (
16 | datasets
17 | .load_dataset("json", data_files={
18 | "train": "https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V4.3_unfiltered_cleaned_split.json"
19 | })
20 | ["train"]
21 | .shuffle(seed=0)
22 | )
23 |
24 | print(f"Dataset has {len(dataset)} rows")
25 | indices = range(n) if n is not None else range(int(frac * len(dataset)))
26 | dataset = dataset.select(indices)
27 | print(f"Dataset after select has {len(dataset)} rows")
28 |
29 | dataset = dataset.map(
30 | _convert_sharegpt_dataset,
31 | num_proc=1,
32 | remove_columns=dataset.column_names,
33 | desc="Converting dataset"
34 | )
35 |
36 | print(f"Dataset after filtering has {len(dataset)} rows")
37 |
38 | print("Saving to disk")
39 | dataset.to_json(output_path)
40 |
41 |
42 | def _parse_arguments() -> argparse.Namespace:
43 | parser = argparse.ArgumentParser(
44 | description="Generate trajectories"
45 | )
46 | parser.add_argument(
47 | "--output",
48 | type=pathlib.Path,
49 | required=True,
50 | help="Path to jsonlines file where the processed dataset will be stored"
51 | )
52 | parser.add_argument(
53 | "--n",
54 | type=int,
55 | help="Number of sampels to take"
56 | )
57 | parser.add_argument(
58 | "--frac",
59 | type=float,
60 | help="Number of sampels to take from 0.0 to 1.0 percent"
61 | )
62 | return parser.parse_args()
63 |
64 |
65 | def _convert_sharegpt_dataset(example: dict) -> dict:
66 | new_turns = []
67 | # https://github.com/SafeAILab/EAGLE/blob/4a9cf3a1f6cd4a294e6d30a4e7c77cba246d7ca5/eagle/ge_data/ge_data_all_llama2chat.py#L65
68 | system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
69 | new_turns.append({"role": "system", "content": system_prompt})
70 | for turn in example["conversations"]:
71 | if turn["from"] == "gpt":
72 | role = "assistant"
73 | elif turn["from"] == "human":
74 | role = "user"
75 | else:
76 | raise ValueError("Unknown role")
77 | new_turn = {"role": role, "content": turn["value"]}
78 | new_turns.append(new_turn)
79 | # https://github.com/SafeAILab/EAGLE/blob/4a9cf3a1f6cd4a294e6d30a4e7c77cba246d7ca5/eagle/ge_data/ge_data_all_llama2chat.py#L69
80 | if new_turns[1]["role"] == "assistant":
81 | new_turns = [new_turns[0]] + new_turns[2:]
82 | return {"messages": new_turns}
83 |
84 |
85 | if __name__ == "__main__":
86 | convert_sharegpt_dataset()
87 |
--------------------------------------------------------------------------------
/eagle/generate_trajectories.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import sglang
4 | import pathlib
5 | import argparse
6 | import datasets
7 | import transformers
8 |
9 |
10 | def _prepare_dataset() -> None:
11 | arguments = _parse_arguments()
12 |
13 | raw_path: pathlib.Path = arguments.input
14 | model_path: pathlib.Path = arguments.model
15 | tokenizer_path: pathlib.Path = arguments.tokenizer
16 | output_path: pathlib.Path = arguments.output
17 | n = arguments.n
18 | frac = arguments.frac
19 |     if (n is None) == (frac is None):
20 |         raise ValueError("Exactly one of --n or --frac must be set")
21 |
22 | print("Loading tokenizer")
23 | tokenizer = transformers.AutoTokenizer.from_pretrained(str(tokenizer_path), use_fast=True)
24 |
25 | print("Loading raw dataset")
26 | dataset = (
27 | datasets
28 | .load_dataset("json", data_files={"train": [str(raw_path)]})
29 | ["train"]
30 | .shuffle(seed=0)
31 | )
32 | print(f"Dataset has {len(dataset)} rows")
33 | indices = range(n) if n is not None else range(int(frac * len(dataset)))
34 | dataset = dataset.select(indices)
35 | print(f"Dataset after select has {len(dataset)} rows")
36 |
37 | dataset = dataset.map(
38 | lambda example: _tokenize_dataset(example=example, tokenizer=tokenizer),
39 | batched=False,
40 | num_proc=1,
41 | desc="Tokenizing dataset"
42 | )
43 |
44 | llm = sglang.Engine(
45 | model_path=str(model_path),
46 | tp_size=arguments.tp,
47 | pp_size=arguments.pp,
48 | dp_size=arguments.dp,
49 | log_level="info"
50 | )
51 |
52 | sampling_params = {
53 | "temperature": arguments.temperature,
54 | "max_new_tokens": arguments.max_new_tokens
55 | }
56 |
57 | batched_input_ids = [el["input_ids"] for el in dataset]
58 |
59 | start = time.perf_counter()
60 | outputs: list[dict] = llm.generate(input_ids=batched_input_ids, sampling_params=sampling_params)
61 | end = time.perf_counter()
62 | print(end - start, "total inference")
63 |
64 |
65 | def add_reply(example, idx):
66 | msgs = example["messages"].copy()
67 | msgs.append({"role": "assistant", "content": outputs[idx]["text"]})
68 | return {"messages": msgs}
69 |
70 |
71 | dataset = dataset.map(
72 | add_reply,
73 | with_indices=True,
74 | batched=False,
75 | desc="Appending assistant replies"
76 | )
77 |
78 |
79 | print("Saving to disk")
80 | dataset.select_columns(["id", "messages"]).to_json(output_path, force_ascii=False)
81 |
82 |
83 | def _parse_arguments() -> argparse.Namespace:
84 | parser = argparse.ArgumentParser(
85 | description="Generate trajectories"
86 | )
87 | parser.add_argument(
88 | "--input",
89 | type=pathlib.Path,
90 | required=True,
91 | help="Path to JSON lines chat dataset as described in documentation wheres lines end with user response"
92 | )
93 | parser.add_argument(
94 | "--model",
95 | type=pathlib.Path,
96 | required=True,
97 | help="Path to verifier model"
98 | )
99 | parser.add_argument(
100 | "--tokenizer",
101 | type=pathlib.Path,
102 | required=True,
103 | help="Path to tokenizer, usually the same as the model"
104 | )
105 | parser.add_argument(
106 | "--output",
107 | type=pathlib.Path,
108 | required=True,
109 | help="Path to jsonlines file where the processed dataset will be stored"
110 | )
111 | parser.add_argument(
112 | "--temperature",
113 | type=float,
114 | default=0.0,
115 | help="Temperature"
116 | )
117 | parser.add_argument(
118 | "--n",
119 | type=int,
120 | help="Number of sampels to take"
121 | )
122 | parser.add_argument(
123 | "--max-new-tokens",
124 | type=int,
125 | help="Max new tokens"
126 | )
127 | parser.add_argument(
128 | "--frac",
129 | type=float,
130 | help="Number of sampels to take from 0.0 to 1.0 percent"
131 | )
132 | parser.add_argument(
133 | "--tp",
134 | type=int,
135 | help="tp size"
136 | )
137 | parser.add_argument(
138 | "--pp",
139 | type=int,
140 | help="tp size"
141 | )
142 | parser.add_argument(
143 | "--dp",
144 | type=int,
145 | help="dp size"
146 | )
147 | return parser.parse_args()
148 |
149 |
150 | def _tokenize_dataset(example: dict, tokenizer: transformers.AutoTokenizer) -> dict[str, torch.LongTensor]:
151 | if example["messages"][-1]["role"] != "assistant":
152 | raise ValueError("Last message must be from an assistant")
153 |
154 |     messages = example["messages"][:-1]  # keep all except the last assistant reply
155 |
156 | result = tokenizer.apply_chat_template(
157 | messages,
158 | tokenize=True,
159 | add_generation_prompt=True,
160 | return_dict=True
161 | )
162 |
163 | return {
164 | "input_ids": result["input_ids"],
165 | "messages": messages
166 | }
167 |
168 |
169 | if __name__ == "__main__":
170 | _prepare_dataset()
171 |
--------------------------------------------------------------------------------
/benchmark/benchmark.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sglang
3 | import pathlib
4 | import argparse
5 | import datasets
6 |
7 |
8 | def benchmark() -> None:
9 | # Parse arguments
10 | arguments = _parse_arguments()
11 | output_report_path = arguments.output
12 | prompts_path = arguments.prompts
13 | draft_tokens = arguments.draft
14 | top_k = arguments.k
15 | draft_steps = arguments.steps
16 | eagle_path = arguments.eagle
17 | model_path = arguments.model
18 | speculative_algorithm = arguments.speculative_algorithm
19 | n = arguments.n
20 | batch_size = arguments.bs
21 | temperature = arguments.temperature
22 | frac = arguments.frac
23 |     if (n is None) == (frac is None):
24 |         raise ValueError("Exactly one of --n or --frac must be set")
25 |
26 | # Load dataset
27 | print("Loading raw dataset")
28 | dataset = (
29 | datasets
30 | .load_dataset("json", data_files={"train": [str(prompts_path)]})
31 | ["train"]
32 | .shuffle(seed=0)
33 | )
34 | print(f"Dataset has {len(dataset)} rows")
35 | indices = range(n) if n is not None else range(int(frac * len(dataset)))
36 | dataset = dataset.select(indices)
37 | print(f"Dataset after select has {len(dataset)} rows")
38 |
39 | print("Creating sglang engine")
40 |
41 | # Prepare llm engine
42 | if speculative_algorithm is None:
43 | llm = sglang.Engine(
44 | model_path=str(model_path),
45 | max_running_requests=batch_size
46 | )
47 | else:
48 | llm = sglang.Engine(
49 | model_path=str(model_path),
50 | speculative_algorithm=speculative_algorithm,
51 | speculative_draft_model_path=str(eagle_path),
52 | speculative_num_steps=draft_steps,
53 | speculative_eagle_topk=top_k,
54 | speculative_num_draft_tokens=draft_tokens,
55 | max_running_requests=batch_size,
56 | )
57 |
58 | # Send requests to llm engine
59 | sampling_params = {
60 | "temperature": temperature,
61 | }
62 |
63 | total_verify_ct = 0
64 | total_latency = 0.0
65 | total_output_tokens = 0
66 |
67 | prompts = [example["prompt"] for example in dataset]
68 | outputs: list[dict] = llm.generate(prompts, sampling_params)
69 | llm.shutdown()
70 |
71 | # Collect metrics from llm engine
72 | for output in outputs:
73 | total_latency = max(total_latency, output["meta_info"]["e2e_latency"])
74 | total_output_tokens += output["meta_info"]["completion_tokens"]
75 | if speculative_algorithm is not None:
76 | total_verify_ct += output["meta_info"]["spec_verify_ct"]
77 |
78 | total_output_throughput = total_output_tokens / total_latency
79 | if speculative_algorithm is not None and total_verify_ct != 0:
80 | accept_length = total_output_tokens / total_verify_ct
81 | else:
82 | accept_length = None
83 |
84 | # Create report and save to disk
85 | report_dict = {
86 | "output_throughput": total_output_throughput,
87 | "total_output_tokens": total_output_tokens,
88 | "total_latency": total_latency
89 | }
90 |
91 | if speculative_algorithm is not None:
92 | report_dict["acceptance_length"] = accept_length
93 |
94 | print(report_dict)
95 |
96 | with output_report_path.open("w") as report_file:
97 | json.dump(report_dict, report_file, indent=4)
98 |
99 |
100 | def _parse_arguments() -> argparse.Namespace:
101 | parser = argparse.ArgumentParser(description="Benchmark")
102 | parser.add_argument(
103 | "--model",
104 | type=pathlib.Path,
105 | required=True,
106 | help="Path to jsonlines file where the processed dataset will be stored"
107 | )
108 | parser.add_argument(
109 | "--prompts",
110 | type=pathlib.Path,
111 | required=True,
112 | help="Path to jsonlines file with prompts"
113 | )
114 | parser.add_argument(
115 | "--eagle",
116 | type=pathlib.Path,
117 | required=False,
118 | help="Path to jsonlines file where the processed dataset will be stored"
119 | )
120 | parser.add_argument(
121 | "--speculative-algorithm",
122 | type=str,
123 | required=False,
124 | help="For example EAGLE or EAGLE 3"
125 | )
126 | parser.add_argument(
127 | "--bs",
128 | type=int,
129 | required=True,
130 | help="For example EAGLE or EAGLE 3"
131 | )
132 | parser.add_argument(
133 | "--steps",
134 | type=int,
135 | help="Number of sampels to take"
136 | )
137 | parser.add_argument(
138 | "--k",
139 | type=int,
140 | help="Number of top k"
141 | )
142 | parser.add_argument(
143 | "--draft",
144 | type=int,
145 | help="Number of draft tokens to generate"
146 | )
147 | parser.add_argument(
148 | "--n",
149 | type=int,
150 | help="Number of sampels to take"
151 | )
152 | parser.add_argument(
153 | "--frac",
154 | type=float,
155 | help="Number of sampels to take from 0.0 to 1.0 percent"
156 | )
157 | parser.add_argument(
158 | "--output",
159 | type=pathlib.Path,
160 | required=True,
161 | help="Path to report"
162 | )
163 | parser.add_argument(
164 | "--temperature",
165 | type=float,
166 | required=True,
167 | help="Temperature"
168 | )
169 | return parser.parse_args()
170 |
171 |
172 | if __name__ == "__main__":
173 | benchmark()
174 |
--------------------------------------------------------------------------------
/resources/raw_example_dataset.jsonl:
--------------------------------------------------------------------------------
1 | {"id": 0, "messages": [{"role": "user", "content": "Give me 100 prompt parameters that I can specify that will influence your output, e.g. voice, tone, register, style, audience etc."}, {"role": "assistant", "content": "Sure, here are 100 prompt parameters that you can specify to influence my output:\n\n1. Voice (e.g., male or female)\n2. Tone (e.g., serious, sarcastic, humorous, etc.)\n3. Register (e.g., formal, informal, academic, conversational, etc.)\n4. Style (e.g., narrative, descriptive, expository, persuasive, etc.)\n5. Audience (e.g., children, adults, experts, non-experts, etc.)\n6. Genre (e.g., fiction, non-fiction, poetry, drama, etc.)\n7. Language (e.g., English, French, Spanish, etc.)\n8. Dialect (e.g., American English, British English, etc.)\n9. Purpose (e.g., informative, instructive, entertaining, etc.)\n10. Mood (e.g., happy, sad, angry, etc.)\n11. Point of view (e.g., first person, third person, etc.)\n12. Setting (e.g., time period, location, etc.)\n13. Characterization (e.g., flat, round, static, dynamic, etc.)\n14. Conflict (e.g., internal, external, etc.)\n15. Plot (e.g., exposition, rising action, climax, etc.)\n16. Theme (e.g., love, friendship, betrayal, etc.)\n17. Symbolism (e.g., colors, objects, etc.)\n18. Imagery (e.g., visual, auditory, olfactory, etc.)\n19. Figurative language (e.g., simile, metaphor, hyperbole, etc.)\n20. Allusion (e.g., historical, cultural, literary, etc.)\n21. Irony (e.g., verbal, situational, dramatic, etc.)\n22. Foreshadowing (e.g., subtle, explicit, etc.)\n23. Flashback (e.g., chronological, non-chronological, etc.)\n24. Dialogue (e.g., realistic, stylized, etc.)\n25. Monologue (e.g., internal, external, etc.)\n26. Soliloquy (e.g., dramatic, introspective, etc.)\n27. Sound effects (e.g., ambient, diegetic, non-diegetic, etc.)\n28. Music (e.g., genre, tempo, etc.)\n29. Visual effects (e.g., lighting, color, etc.)\n30. Camera angles (e.g., low, high, close-up, etc.)\n31. Shot composition (e.g., rule of thirds, symmetry, etc.)\n32. Editing (e.g., cuts, fades, transitions, etc.)\n33. Pacing (e.g., fast, slow, steady, etc.)\n34. Structure (e.g., chronological, non-chronological, etc.)\n35. Tone of voice (e.g., serious, angry, sad, etc.)\n36. Accent (e.g., regional, international, etc.)\n37. Syntax (e.g., simple, complex, compound, etc.)\n38. Vocabulary (e.g., formal, informal, technical, etc.)\n39. Grammar (e.g., correct, incorrect, etc.)\n40. Spelling (e.g., American, British, etc.)\n41. Pronunciation (e.g., clear, mumbled, etc.)\n42. Enunciation (e.g., precise, slurred, etc.)\n43. Volume (e.g., loud, soft, etc.)\n44. Inflection (e.g., rising, falling, etc.)\n45. Pitch (e.g., high, low, etc.)\n46. Speed (e.g., fast, slow, etc.)\n47. Diction (e.g., clear, mumbled, etc.)\n48. Rhythm (e.g., staccato, legato, etc.)\n49. Timing ("}, {"role": "user", "content": "Continue"}, {"role": "assistant", "content": "49. Timing (e.g., pacing, pauses, etc.)\n50. Emphasis (e.g., stress, intonation, etc.)\n51. Emotion (e.g., happy, sad, angry, etc.)\n52. Attitude (e.g., positive, negative, neutral, etc.)\n53. Knowledge level (e.g., basic, intermediate, advanced, etc.)\n54. Educational background (e.g., high school, college, graduate school, etc.)\n55. Professional background (e.g., medicine, law, engineering, etc.)\n56. Cultural background (e.g., Western, Eastern, etc.)\n57. Geographic location (e.g., North America, Asia, Europe, etc.)\n58. Age (e.g., young adult, middle-aged, senior, etc.)\n59. Gender (e.g., male, female, non-binary, etc.)\n60. 
Religion (e.g., Christianity, Islam, Judaism, etc.)\n61. Political affiliation (e.g., conservative, liberal, moderate, etc.)\n62. Social status (e.g., wealthy, middle-class, low-income, etc.)\n63. Interpersonal relationships (e.g., parent-child, romantic, etc.)\n64. Career goals (e.g., entrepreneurship, academia, etc.)\n65. Life goals (e.g., travel, family, etc.)\n66. Hobbies\/interests (e.g., sports, music, art, etc.)\n67. Personality traits (e.g., extroverted, introverted, etc.)\n68. Learning style (e.g., visual, auditory, kinesthetic, etc.)\n69. Communication style (e.g., direct, indirect, assertive, etc.)\n70. Sensory preferences (e.g., visual, auditory, tactile, etc.)\n71. Learning environment (e.g., classroom, online, etc.)\n72. Communication medium (e.g., email, phone, in-person, etc.)\n73. Learning objectives (e.g., mastery, exploration, etc.)\n74. Learning outcome (e.g., knowledge acquisition, skill development, etc.)\n75. Learning pace (e.g., accelerated, self-paced, etc.)\n76. Learning format (e.g., lecture, discussion, hands-on, etc.)\n77. Learning method (e.g., traditional, experiential, etc.)\n78. Learning tools (e.g., textbooks, videos, software, etc.)\n79. Learning resources (e.g., library, online databases, etc.)\n80. Learning environment (e.g., quiet, collaborative, etc.)\n81. Learning motivation (e.g., intrinsic, extrinsic, etc.)\n82. Learning obstacles (e.g., lack of resources, time constraints, etc.)\n83. Learning preferences (e.g., self-directed, group-oriented, etc.)\n84. Learning feedback (e.g., formative, summative, etc.)\n85. Learning assessment (e.g., quizzes, exams, etc.)\n86. Learning evaluation (e.g., self-evaluation, peer evaluation, etc.)\n87. Learning reflection (e.g., journaling, discussion, etc.)\n88. Learning feedback (e.g., written, oral, etc.)\n89. Learning environment (e.g., physical, virtual, etc.)\n90. Learning support (e.g., tutoring, counseling, etc.)\n91. Learning goals (e.g., cognitive, affective, etc.)\n92. Learning objectives (e.g., knowledge, skills, attitudes, etc.)\n93. Learning outcomes (e.g., competence, proficiency, etc.)\n94. Learning activities (e.g., problem-solving, critical thinking, etc.)\n95. Learning strategies (e.g., memorization, repetition, etc.)\n96. Learning styles (e.g., visual, auditory, kinesthetic, etc.)\n97. Learning theories (e.g., behaviorism, constructivism,"},{"role": "user", "content": "Continue"}, {"role":"assistant", "content": "97. Learning theories (e.g., behaviorism, constructivism, etc.)\n98. Learning models (e.g., ADDIE, SAM, etc.)\n99. Learning technology (e.g., LMS, authoring tools, etc.)\n100. Learning trends (e.g., gamification, microlearning, etc.)\n\nI hope these prompt parameters help you specify the output you're looking for. Let me know if you have any other questions or if there's anything else I can assist you with."}]}
2 |
--------------------------------------------------------------------------------
/pages/reproduce_eagle_1_paper.md:
--------------------------------------------------------------------------------
1 | ## Reproduce eagle 1 paper
2 |
3 | This page covers all the steps needed to reproduce the eagle 1 paper for the model [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
4 |
5 | ## Docker
6 |
7 | In order to use the tool you need to build the docker image
8 |
9 | ```bash
10 | docker build --tag vladislavkruglikov/eagle:0.3.1 -f docker/Dockerfile .
11 | ```
12 |
13 | Or use an already built one
14 |
15 | ```bash
16 | docker pull vladislavkruglikov/eagle:latest
17 | ```
18 |
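The run commands below refer to the image simply as `eagle`, so either build with `--tag eagle` or re-tag the image you built or pulled, for example `docker tag vladislavkruglikov/eagle:latest eagle`.
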
19 | ## Model
20 |
21 | Downloading the model might require restricted access, so you might need to export your token first
22 |
23 | ```bash
24 | export HF_TOKEN=
25 | huggingface-cli download meta-llama/Llama-2-7b-chat-hf \
26 | --local-dir ./resources/meta-llama2-7b-chat-hf
27 | ```
28 |
29 | Make sure the chat template wraps assistant replies with {% generation %} and {% endgeneration %} as shown in [this example](../resources/example_chat_template_with_generation_keyword.json). This is needed for huggingface to generate the loss mask correctly so that the loss attends only to assistant replies
30 |
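You can sanity check the template with the tokenizer itself: when the generation keywords are present, `apply_chat_template` returns a non-empty assistant token mask (a minimal sketch; the model path is the one downloaded above):

```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("./resources/meta-llama2-7b-chat-hf", use_fast=True)
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
result = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
# With {% generation %} ... {% endgeneration %} in place the mask marks the assistant tokens
print(sum(result["assistant_masks"]))
```
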
31 | ## Prepare dataset
32 |
33 | In order to train models with our framework you need to supply the dataset in a particular format. It is jsonlines where each line looks like the example below. Basically you have a list of messages and each message has 2 keys, role and content
34 |
35 | ```json
36 | {"id": 0, "messages": [{"role": "user", "content": "Give me 100 prompt parameters that I can specify that will influence your output, e.g. voice, tone, register, style, audience etc."}, {"role": "assistant", "content": "Sure, here are 100 prompt parameters that you can specify to influence my output:\n\n1. Voice (e.g., male or female)\n2. Tone (e.g., serious, sarcastic, humorous, etc.)"}, {"role": "user", "content": "Continue"}, {"role": "assistant", "content": "3. Timing (e.g., pacing, pauses, etc.)\n4. Emphasis (e.g., stress, intonation, etc.)"}]}
37 | ```
38 |
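If you are building such a file from your own data, a minimal sketch looks like this (file name and contents are only illustrative):

```python
import json

conversation = {
    "id": 0,
    "messages": [
        {"role": "user", "content": "Give me three prompt parameters."},
        {"role": "assistant", "content": "1. Voice\n2. Tone\n3. Register"},
    ],
}

# Append one conversation per line to the jsonlines dataset
with open("resources/my_dataset.jsonl", "a", encoding="utf-8") as file:
    file.write(json.dumps(conversation, ensure_ascii=False) + "\n")
```
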
39 | In order to reproduce the eagle 1 paper we will use the pre-built script that downloads the sharegpt dataset and formats it the way the reference implementation proposes
40 |
41 | ```bash
42 | docker run \
43 | -v ./resources:/mnt/resources \
44 | eagle \
45 | python3 eagle/prepare_sharegpt_dataset.py \
46 | --frac 1.0 \
47 | --output /mnt/resources/sharegpt.jsonl
48 | ```
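
The script prepends the Llama 2 system prompt used in the reference implementation, maps the ShareGPT `human`/`gpt` roles to `user`/`assistant`, and drops a conversation's leading assistant turn, so every line ends up in the chat format shown above.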
49 |
50 | ## Train draft model
51 |
52 | I will demonstrate how to train the draft model for a particular configuration, but you can tweak it to your needs. The training script logs data to clearml dashboards. If you want to disable that (which is recommended for a first trial run), run
53 |
54 | ```bash
55 | export CLEARML_OFFLINE_MODE=1
56 | ```
57 |
58 |
59 | Otherwise specify these environment variables
60 |
61 | ```bash
62 | export CLEARML_WEB_HOST=
63 | export CLEARML_API_HOST=
64 | export CLEARML_FILES_HOST=
65 | export CLEARML_API_ACCESS_KEY=
66 | export CLEARML_API_SECRET_KEY=
67 | export CLEARML_API_HOST_VERIFY_CERT=
68 | ```
69 |
70 | It is also up to you to configure
71 |
72 | ```bash
73 | export OMP_NUM_THREADS=4
74 | export CUDA_VISIBLE_DEVICES=7
75 | ```
76 |
77 | I will demonstrate a very simple setup which will probably be the most popular. For that you need a single GPU with roughly 80 GB of HBM. In every example below the effective global batch size stays at 16, that is micro-batch-size × gradient-accumulation-steps × number of GPUs (8 × 2 × 1 here), so between the 1, 2 and 4 GPU variants only the per-GPU micro batch and the warmup and total step counts change
78 |
79 | ```bash
80 | docker run \
81 | --gpus all \
82 | -e OMP_NUM_THREADS=4 \
83 | -e CUDA_VISIBLE_DEVICES=7 \
84 | -e CLEARML_OFFLINE_MODE=1 \
85 | -v ./resources:/mnt/resources \
86 | eagle \
87 | torchrun \
88 | --nnodes=1 \
89 | --nproc_per_node=1 \
90 | eagle/train.py \
91 | --micro-batch-size 8 \
92 | --gradient-accumulation-steps 2 \
93 | --num-warmup-steps 1024 \
94 | --num-training-steps 131072 \
95 | --epochs 4 \
96 | --clearml-project eagle \
97 | --clearml-task 1gpus-8microbs-2accum-16globalbs \
98 | --verifier-model-path /mnt/resources/meta-llama2-7b-chat-hf \
99 | --dataset-path /mnt/resources/sharegpt.jsonl \
100 | --eagle-config-path /mnt/resources/eagle_config.json \
101 | --learning-rate 2e-4 \
102 | --maximum-model-length 2048 \
103 | --noise-low -0.1 \
104 | --noise-high 0.1 \
105 | --v-w 1.0 \
106 | --p-w 0.1 \
107 | --grad-clip 0.5 \
108 | --b1 0.9 \
109 | --b2 0.95 \
110 | --cpdir /mnt/resources/checkpoints \
111 | --save 4096 \
112 | --mixed-precision bf16 \
113 | --verifier-model-lm-head-dtype bfloat16 \
114 | --verifier-model-dtype bfloat16 \
115 | --eagle-dtype bfloat16 \
116 | --attn flash_attention_2
117 | ```
118 |
119 | For those who have access to 2 GPUs
120 |
121 | ```bash
122 | docker run \
123 | --gpus all \
124 | -e OMP_NUM_THREADS=4 \
125 | -e CUDA_VISIBLE_DEVICES=5,7 \
126 | -e CLEARML_OFFLINE_MODE=1 \
127 | -v ./resources:/mnt/resources \
128 | eagle \
129 | torchrun \
130 | --nnodes=1 \
131 | --nproc_per_node=2 \
132 | eagle/train.py \
133 | --micro-batch-size 4 \
134 | --gradient-accumulation-steps 2 \
135 | --num-warmup-steps 2048 \
136 | --num-training-steps 262144 \
137 | --epochs 4 \
138 | --clearml-project eagle \
139 | --clearml-task 2gpus-4microbs-2accum-16globalbs \
140 | --verifier-model-path /mnt/resources/meta-llama2-7b-chat-hf \
141 | --dataset-path /mnt/resources/sharegpt.jsonl \
142 | --eagle-config-path /mnt/resources/eagle_config.json \
143 | --learning-rate 2e-4 \
144 | --maximum-model-length 2048 \
145 | --noise-low -0.1 \
146 | --noise-high 0.1 \
147 | --v-w 1.0 \
148 | --p-w 0.1 \
149 | --grad-clip 0.5 \
150 | --b1 0.9 \
151 | --b2 0.95 \
152 | --cpdir /mnt/resources/checkpoints \
153 | --save 4096 \
154 | --mixed-precision bf16 \
155 | --verifier-model-lm-head-dtype bfloat16 \
156 | --verifier-model-dtype bfloat16 \
157 | --eagle-dtype bfloat16 \
158 | --attn flash_attention_2
159 | ```
160 |
161 | For those who have access to 4 GPUs
162 |
163 | ```bash
164 | docker run \
165 | --gpus all \
166 | -e OMP_NUM_THREADS=4 \
167 | -e CUDA_VISIBLE_DEVICES=1,3,5,7 \
168 | -e CLEARML_OFFLINE_MODE=1 \
169 | -v ./resources:/mnt/resources \
170 | eagle \
171 | torchrun \
172 | --nnodes=1 \
173 | --nproc_per_node=4 \
174 | eagle/train.py \
175 | --micro-batch-size 2 \
176 | --gradient-accumulation-steps 2 \
177 | --num-warmup-steps 4096 \
178 | --num-training-steps 524288 \
179 | --epochs 4 \
180 | --clearml-project eagle \
181 | --clearml-task 4gpus-2microbs-2accum-16globalbs \
182 | --verifier-model-path /mnt/resources/meta-llama2-7b-chat-hf \
183 | --dataset-path /mnt/resources/sharegpt.jsonl \
184 | --eagle-config-path /mnt/resources/eagle_config.json \
185 | --learning-rate 2e-4 \
186 | --maximum-model-length 2048 \
187 | --noise-low -0.1 \
188 | --noise-high 0.1 \
189 | --v-w 1.0 \
190 | --p-w 0.1 \
191 | --grad-clip 0.5 \
192 | --b1 0.9 \
193 | --b2 0.95 \
194 | --cpdir /mnt/resources/checkpoints \
195 | --save 4096 \
196 | --mixed-precision bf16 \
197 | --verifier-model-lm-head-dtype bfloat16 \
198 | --verifier-model-dtype bfloat16 \
199 | --eagle-dtype bfloat16 \
200 | --attn flash_attention_2
201 | ```
202 |
--------------------------------------------------------------------------------
/eagle/train_tp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | import math
5 | import torch
6 | import pathlib
7 | import clearml
8 | import logging
9 | import datasets
10 | import argparse
11 | import safetensors
12 | import transformers
13 |
14 | from eagle.llama2 import Llama2Model
15 |
16 |
17 | def coach() -> None:
18 | arguments = _parse_arguments()
19 |
20 | torch.backends.cuda.matmul.allow_tf32 = True
21 | logging.basicConfig(level=logging.INFO, format='%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
22 |
23 | logging.info("Start to prepare clearml ")
24 | clearml_task = clearml.Task.init(project_name=arguments.clearml_project, task_name=arguments.clearml_task, reuse_last_task_id=False, continue_last_task=False, output_uri=False, auto_connect_frameworks=False, auto_resource_monitoring=False)
25 | clearml_logger = clearml_task.get_logger()
26 |
27 | logging.info("Start to prepare language model head ")
28 | lm_head = _initialize_verifier_lm_head(verifier_path=arguments.verifier_model_path).to(getattr(torch, arguments.verifier_model_lm_head_dtype)).to("cuda")
29 | logging.info("Language model head has dtype %s", next(lm_head.parameters()).dtype)
30 | logging.info("Language model head has %f billion parameters", _count_parameters(model=lm_head) / 10 ** 9)
31 | clearml_logger.report_single_value(name="Language model head parameters billion", value=_count_parameters(model=lm_head) / 10 ** 9)
32 |
33 | logging.info("Start to prepare target model ")
34 | verifier_model = transformers.AutoModelForCausalLM.from_pretrained(arguments.verifier_model_path, device_map="auto", torch_dtype=getattr(torch, arguments.verifier_model_dtype), attn_implementation=arguments.attn)
35 | verifier_model = verifier_model.eval()
36 | logging.info("Target model head has dtype %s", next(verifier_model.parameters()).dtype)
37 | logging.info("Target model head has %f billion parameters", _count_parameters(model=verifier_model) / 10 ** 9)
38 | clearml_logger.report_single_value(name="Target model head parameters billion", value=_count_parameters(model=verifier_model) / 10 ** 9)
39 |
40 | logging.info("Start to prepare draft model ")
41 | config = transformers.AutoConfig.from_pretrained(arguments.eagle_config_path)
42 | model = Llama2Model(config, load_emb=True, path=arguments.verifier_model_path).to(getattr(torch, arguments.eagle_dtype)).to("cuda")
43 | logging.info("Draft model head has dtype %s", next(model.parameters()).dtype)
44 | logging.info("Draft model head has %f billion parameters", _count_parameters(model=model) / 10 ** 9)
45 | model.train()
46 | clearml_logger.report_single_value(name="Draft model head parameters billion", value=_count_parameters(model=model) / 10 ** 9)
47 |
48 | logging.info("Start to prepare data ")
49 | dataset = datasets.load_dataset("json", data_files={"train": [arguments.dataset_path]})["train"]
50 | dataset = Dataset(dataset=dataset)
51 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=arguments.micro_batch_size, collate_fn=Collator(arguments.verifier_model_path))
52 | logging.info("Dataset contains %d samples", len(dataset))
53 |
54 | logging.info("Start to prepare miscellaneous ")
55 | criterion = torch.nn.SmoothL1Loss(reduction="none")
56 | model_optimizer = torch.optim.AdamW(model.parameters(), lr=arguments.learning_rate, betas=(arguments.b1, arguments.b2))
57 |
58 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer=model_optimizer, num_warmup_steps=arguments.num_warmup_steps, num_training_steps=arguments.num_training_steps)
59 |
60 | logging.info("Start training ")
61 | total_steps_passed = 0
62 | for epoch in range(arguments.epochs):
63 | training_iterator = iter(dataloader)
64 | num_samples_in_epoch = len(dataloader)
65 | remainder = num_samples_in_epoch % arguments.gradient_accumulation_steps
66 | remainder = remainder if remainder != 0 else arguments.gradient_accumulation_steps
67 | total_gradient_updates = math.ceil(num_samples_in_epoch / arguments.gradient_accumulation_steps)
68 | for update_step in range(total_gradient_updates):
69 | step_start = time.perf_counter()
70 | accum_loss = 0.0
71 | batch_samples = []
72 | num_batches_in_step = arguments.gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
73 | for _ in range(num_batches_in_step):
74 | batch_samples += [next(training_iterator)]
75 | num_items_in_batch = sum([batch["loss_mask"][:, :arguments.maximum_model_length].sum() for batch in batch_samples])
76 |
77 | if num_items_in_batch == 0:
78 | logging.warning("num_items_in_batch is zero, skipping steps")
79 | continue
80 |
81 | step_correctly_predicted_tokens_count = 0
82 |
83 | for i, batch in enumerate(batch_samples):
84 | batch = _make_eagle_input(batch, verifier_model, arguments.maximum_model_length, arguments.noise_low, arguments.noise_high, "cuda")
85 | batch["hidden_states"] = batch["hidden_states"]
86 | batch["target"] = batch["target"]
87 | predict = model(batch["hidden_states"].to( getattr(torch, arguments.eagle_dtype) ), input_ids=batch["input_ids"])
88 | with torch.no_grad():
89 | target_head = lm_head(batch["target"].to( getattr(torch, arguments.verifier_model_lm_head_dtype) ),)
90 | target_p = torch.nn.Softmax(dim=2)(target_head)
91 | target_p = target_p.detach()
92 | out_head = lm_head(predict.to(getattr(torch, arguments.verifier_model_lm_head_dtype)))
93 | out_logp = torch.nn.LogSoftmax(dim=2)(out_head)
94 |
95 | loss_mask = batch["loss_mask"][:, :, None]
96 |
97 | _, target_max_p_tokens = torch.max(target_p, 2)
98 |                 _, eagle_max_p_tokens = torch.max(out_logp, 2)
99 |                 step_correctly_predicted_tokens_count += ((target_max_p_tokens == eagle_max_p_tokens) * loss_mask.squeeze()).sum().item()
100 |
101 | plogp = target_p * out_logp
102 | ploss = -torch.sum(torch.sum(loss_mask * plogp, 2))
103 | vloss = criterion(predict, batch["target"])
104 | vloss = torch.sum(torch.mean(loss_mask * vloss, 2))
105 | loss = arguments.v_w * vloss + arguments.p_w * ploss
106 | loss = loss / num_items_in_batch
107 | accum_loss += loss.item()
108 | loss.backward()
109 | torch.nn.utils.clip_grad_norm_(model.parameters(), arguments.grad_clip)
110 |
111 | model_optimizer.step()
112 | scheduler.step()
113 | model_optimizer.zero_grad()
114 |
115 | total_steps_passed += 1
116 |
117 | step_end = time.perf_counter()
118 | mean_step_duration_across_gpus = step_end - step_start
119 |
120 | time_taken = step_end - step_start
121 | total_throughput = num_items_in_batch / time_taken
122 |
123 | loss_tensor = accum_loss
124 |
125 | accuracy = float("nan")
126 | if num_items_in_batch != 0:
127 | accuracy = step_correctly_predicted_tokens_count / num_items_in_batch
128 |
129 | current_lr = arguments.learning_rate
130 | if arguments.num_warmup_steps is not None:
131 | current_lr = scheduler.get_last_lr()[0]
132 |
133 | logging.info("epoch %d/%d, step %d/%d, mean step duration across gpus %.4f seconds, lr %.8f, loss %.4f, throughput %d tps, accuracy %.4f", epoch + 1, arguments.epochs, total_steps_passed, arguments.num_training_steps, mean_step_duration_across_gpus, current_lr, loss_tensor, total_throughput, accuracy)
134 | clearml_logger.report_scalar(title="train/steploss", series="series", value=loss_tensor, iteration=total_steps_passed)
135 | clearml_logger.report_scalar(title="train/throughput tokens/s", series="series", value=total_throughput, iteration=total_steps_passed)
136 | clearml_logger.report_scalar(title="train/stepaccuracy", series="series", value=accuracy, iteration=total_steps_passed)
137 | clearml_logger.report_scalar(title="train/epoch", series="series", value=epoch, iteration=total_steps_passed)
138 | clearml_logger.report_scalar(title="train/lr", series="series", value=current_lr, iteration=total_steps_passed)
139 |
140 | if total_steps_passed % arguments.save == 0:
141 | pathlib.Path(f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}").mkdir(parents=True)
142 | safetensors.torch.save_file(model.state_dict(), f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}/model.safetensors")
143 | with open(arguments.eagle_config_path, "r") as f:
144 | config_data = json.load(f)
145 | config_data["architectures"] = ["LlamaForCausalLMEagle"]
146 | with open(f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}/config.json", "w") as file:
147 | json.dump(config_data, file, ensure_ascii=False, indent=4)
148 |
149 | if total_steps_passed == arguments.num_training_steps:
150 | break
151 |
152 | if total_steps_passed == arguments.num_training_steps:
153 | break
154 |
155 |
156 | def _parse_arguments() -> argparse.Namespace:
157 | parser = argparse.ArgumentParser(description="Coach that trains eagle draft model")
158 | parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size")
159 | parser.add_argument("--gradient-accumulation-steps", type=int, required=True, help="Gradient accumulation steps")
160 | parser.add_argument("--num-warmup-steps", type=int, required=True, help="Num warmup steps")
161 | parser.add_argument("--num-training-steps", type=int, required=True, help="Num training steps")
162 | parser.add_argument("--clearml-project", type=str, required=True, help="Clearml project")
163 | parser.add_argument("--clearml-task", type=str, required=True, help="Clearml task")
164 | parser.add_argument("--epochs", type=int, required=True, help="Epochs")
165 | parser.add_argument("--verifier-model-path", type=str, required=True, help="verifier_model_path")
166 | parser.add_argument("--dataset-path", type=str, required=True, help="verifier_model_path")
167 | parser.add_argument("--eagle-config-path", type=str, required=True, help="eagle_config_path")
168 | parser.add_argument("--learning-rate", type=float, required=True, help="eagle_config_path")
169 | parser.add_argument("--maximum-model-length", type=int, required=True, help="eagle_config_path")
170 | parser.add_argument("--noise-low", type=float, required=True, help="eagle_config_path")
171 | parser.add_argument("--noise-high", type=float, required=True, help="eagle_config_path")
172 | parser.add_argument("--v-w", type=float, required=True, help="eagle_config_path")
173 | parser.add_argument("--p-w", type=float, required=True, help="eagle_config_path")
174 | parser.add_argument("--grad-clip", type=float, required=True, help="eagle_config_path")
175 | parser.add_argument("--b1", type=float, required=True, help="eagle_config_path")
176 | parser.add_argument("--b2", type=float, required=True, help="eagle_config_path")
177 | parser.add_argument("--cpdir", type=pathlib.Path, default="./checkpoints", help="Path to folder to save checkpoints")
178 | parser.add_argument("--save", type=int, required=False, help="Save model after every number of steps")
179 | parser.add_argument("--mixed-precision", type=str, required=False, help="Save model after every number of steps")
180 | parser.add_argument("--verifier-model-lm-head-dtype", type=str, required=False, help="Save model after every number of steps")
181 | parser.add_argument("--verifier-model-dtype", type=str, required=False, help="Save model after every number of steps")
182 | parser.add_argument("--eagle-dtype", type=str, required=False, help="Save model after every number of steps")
183 | parser.add_argument("--attn", type=str, required=False, help="Save model after every number of steps")
184 | return parser.parse_args()
185 |
186 |
187 | def _count_parameters(model: torch.nn.Module) -> int:
188 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
189 |
190 |
191 | def _initialize_verifier_lm_head(verifier_path: pathlib.Path) -> torch.nn.Linear:
192 | with open(f"{verifier_path}/config.json", "r") as file:
193 | config = json.load(file)
194 | head = torch.nn.Linear(config["hidden_size"], config["vocab_size"], bias=False)
195 | with open(os.path.join(verifier_path, "model.safetensors.index.json"), "r") as f:
196 | index_json = json.loads(f.read())
197 | head_path = index_json["weight_map"]["lm_head.weight"]
198 | with safetensors.safe_open(os.path.join(verifier_path, head_path), framework="pt") as f:
199 | tensor = f.get_slice("lm_head.weight")[:, :config["hidden_size"]]
200 | head.weight.data = tensor
201 | head.eval()
202 | for param in head.parameters():
203 | param.requires_grad = False
204 | return head
205 |
206 |
207 | class Dataset(torch.utils.data.Dataset):
208 | def __init__(self, dataset: datasets.Dataset) -> None:
209 | self._dataset = dataset
210 |
211 | def __len__(self) -> int:
212 | return len(self._dataset)
213 |
214 | def __getitem__(self, index: int) -> dict:
215 | return self._dataset[index]
216 |
217 |
218 | class Collator:
219 | def __init__(self, model_path) -> None:
220 | self._tokenizer = transformers.AutoTokenizer.from_pretrained(str(model_path), use_fast=True)
221 | self._tokenizer.pad_token = "[PAD]"
222 | self._tokenizer.pad_token_id = 0
223 |
224 | def __call__(self, features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
225 | result = self._tokenizer.apply_chat_template([m["messages"] for m in features], tokenize=True, add_generation_prompt=False, return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt", padding=True)
226 | return {
227 | "input_ids": result["input_ids"],
228 | "loss_mask": result["assistant_masks"]
229 | }
230 |
231 |
232 | def _make_eagle_input(batch, verifier_model, max_model_len, transform_uniform_low, transformer_uniform_high, device):
233 | input_ids = batch["input_ids"].to(device)[:, :max_model_len]
234 | loss_mask = batch["loss_mask"].to(device)[:, :max_model_len]
235 |
236 | with torch.no_grad():
237 | outs_big = verifier_model(input_ids, output_hidden_states=True, use_cache=False)
238 | hidden_state_big = outs_big.hidden_states[-1]
239 | hidden_state_big = _apply_noise_to_hidden_state(hidden_state_big, transform_uniform_low, transformer_uniform_high)
240 | T, L, D = hidden_state_big.shape
241 | target = hidden_state_big.new_zeros((T, L, D))
242 |         target[:, :-1, :] = hidden_state_big[:, 1:, :]  # target at position i is the verifier hidden state of position i + 1
243 |         input_ids = torch.cat((input_ids[:, 1:], torch.zeros(input_ids.size(0), 1, dtype=input_ids.dtype, device=input_ids.device)), dim=1)  # shift token ids left so position i holds token i + 1, paired with hidden state i
244 | batch = {"input_ids": input_ids, "hidden_states": hidden_state_big, "target": target, "loss_mask": loss_mask}
245 | return batch
246 |
247 |
248 | def _apply_noise_to_hidden_state(hidden_state: torch.FloatTensor, transform_uniform_low, transformer_uniform_high) -> torch.FloatTensor:
249 | noise = torch.rand_like(hidden_state) * (transformer_uniform_high - transform_uniform_low) + transform_uniform_low
250 | noisy_tensor = hidden_state + noise
251 | return noisy_tensor
252 |
253 |
254 | if __name__ == "__main__":
255 | coach()
256 |
--------------------------------------------------------------------------------
/eagle/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | import math
5 | import torch
6 | import pathlib
7 | import clearml
8 | import logging
9 | import datasets
10 | import argparse
11 | import accelerate
12 | import contextlib
13 | import safetensors
14 | import transformers
15 |
16 | from eagle.llama2 import Llama2Model
17 |
18 |
19 | def coach() -> None:
20 | arguments = _parse_arguments()
21 |
22 | accelerator = accelerate.Accelerator(log_with="all", gradient_accumulation_steps=arguments.gradient_accumulation_steps, mixed_precision=arguments.mixed_precision)
23 | accelerate.utils.set_seed(seed=0)
24 | torch.backends.cuda.matmul.allow_tf32 = True
25 | logging.basicConfig(level=logging.INFO, format='%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
26 | logger = accelerate.logging.get_logger(name=__name__, log_level="INFO")
27 |
28 | if accelerator.is_main_process:
29 | logger.info("Start to prepare clearml ", main_process_only=True)
30 | clearml_task = clearml.Task.init(project_name=arguments.clearml_project, task_name=arguments.clearml_task, reuse_last_task_id=False, continue_last_task=False, output_uri=False, auto_connect_frameworks=False, auto_resource_monitoring=False)
31 | clearml_logger = clearml_task.get_logger()
32 |
33 | logger.info("Start to prepare language model head ", main_process_only=True)
34 | lm_head = _initialize_verifier_lm_head(verifier_path=arguments.verifier_model_path).to(getattr(torch, arguments.verifier_model_lm_head_dtype)).to(accelerator.device)
35 | logger.info("Language model head has dtype %s", next(lm_head.parameters()).dtype, main_process_only=True)
36 | logger.info("Language model head has %f billion parameters", _count_parameters(model=lm_head) / 10 ** 9, main_process_only=True)
37 | if accelerator.is_main_process:
38 | clearml_logger.report_single_value(name="Language model head parameters billion", value=_count_parameters(model=lm_head) / 10 ** 9)
39 |
40 | logger.info("Start to prepare target model ", main_process_only=True)
41 | verifier_model = transformers.AutoModelForCausalLM.from_pretrained(arguments.verifier_model_path, device_map=accelerator.device, torch_dtype=getattr(torch, arguments.verifier_model_dtype), attn_implementation=arguments.attn)
42 | verifier_model = verifier_model.eval()
43 | logger.info("Target model head has dtype %s", next(verifier_model.parameters()).dtype, main_process_only=True)
44 | logger.info("Target model head has %f billion parameters", _count_parameters(model=verifier_model) / 10 ** 9, main_process_only=True)
45 | if accelerator.is_main_process:
46 | clearml_logger.report_single_value(name="Target model head parameters billion", value=_count_parameters(model=verifier_model) / 10 ** 9)
47 |
48 | logger.info("Start to prepare draft model ", main_process_only=True)
49 | config = transformers.AutoConfig.from_pretrained(arguments.eagle_config_path)
50 | model = Llama2Model(config, load_emb=True, path=arguments.verifier_model_path).to(getattr(torch, arguments.eagle_dtype)).to(accelerator.device)
51 | logger.info("Draft model head has dtype %s", next(model.parameters()).dtype, main_process_only=True)
52 | logger.info("Draft model head has %f billion parameters", _count_parameters(model=model) / 10 ** 9, main_process_only=True)
53 | model.train()
54 | accelerator.register_for_checkpointing(model)
55 | if accelerator.is_main_process:
56 | clearml_logger.report_single_value(name="Draft model head parameters billion", value=_count_parameters(model=model) / 10 ** 9)
57 |
58 | logger.info("Start to prepare data ", main_process_only=True)
59 | dataset = datasets.load_dataset("json", data_files={"train": [arguments.dataset_path]})["train"]
60 | dataset = Dataset(dataset=dataset)
61 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=arguments.micro_batch_size, collate_fn=Collator(arguments.verifier_model_path))
62 | logger.info("Dataset contains %d samples", len(dataset), main_process_only=True)
63 |
64 | logger.info("Start to prepare miscellaneous ", main_process_only=True)
65 | criterion = torch.nn.SmoothL1Loss(reduction="none")
66 | model_optimizer = torch.optim.AdamW(model.parameters(), lr=arguments.learning_rate, betas=(arguments.b1, arguments.b2))
67 | accelerator.register_for_checkpointing(model_optimizer)
68 |
69 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer=model_optimizer, num_warmup_steps=arguments.num_warmup_steps, num_training_steps=arguments.num_training_steps)
70 | accelerator.register_for_checkpointing(scheduler)
71 |
72 | model = accelerator.prepare_model(model)
73 | model_optimizer = accelerator.prepare_optimizer(model_optimizer)
74 | dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)
75 | scheduler = accelerator.prepare_scheduler(scheduler)
76 |
77 | logger.info("Start training ", main_process_only=True)
78 | total_steps_passed = 0
79 | for epoch in range(arguments.epochs):
80 | training_iterator = iter(dataloader)
81 | num_batches_in_epoch = len(dataloader)
82 | remainder = num_batches_in_epoch % arguments.gradient_accumulation_steps
83 | remainder = remainder if remainder != 0 else arguments.gradient_accumulation_steps
84 | total_gradient_updates = math.ceil(num_batches_in_epoch / arguments.gradient_accumulation_steps)
85 | for update_step in range(total_gradient_updates):
86 | step_start = time.perf_counter()
87 | accum_loss = 0.0
88 | batch_samples = []
89 | num_batches_in_step = arguments.gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
90 | for _ in range(num_batches_in_step):
91 | batch_samples += [next(training_iterator)]
92 | num_items_in_batch = sum([batch["loss_mask"][:, :arguments.maximum_model_length].sum() for batch in batch_samples])
93 | num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
94 | step_correctly_predicted_tokens_count = 0
95 |
96 | for i, batch in enumerate(batch_samples):
97 | if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
98 | ctx = model.no_sync
99 | else:
100 | ctx = contextlib.nullcontext
101 | with ctx():
102 | batch = _make_eagle_input(batch, verifier_model, arguments.maximum_model_length, arguments.noise_low, arguments.noise_high, accelerator.device)
103 | batch["hidden_states"] = batch["hidden_states"]
104 | batch["target"] = batch["target"]
105 | predict = model(batch["hidden_states"].to(getattr(torch, arguments.eagle_dtype)), input_ids=batch["input_ids"])
106 | with torch.no_grad():
107 | target_head = lm_head(batch["target"].to(getattr(torch, arguments.verifier_model_lm_head_dtype)))
108 | target_p = torch.nn.Softmax(dim=2)(target_head)
109 | target_p = target_p.detach()
110 | out_head = lm_head(predict.to(getattr(torch, arguments.verifier_model_lm_head_dtype)))
111 | out_logp = torch.nn.LogSoftmax(dim=2)(out_head)
112 |
113 | loss_mask = batch["loss_mask"][:, :, None]
114 |
115 | _, target_max_p_tokens = torch.max(target_p, 2)
116 | _, eagle_max_p_tokens = torch.max(out_logp, 2)
117 | step_correctly_predicted_tokens_count += ((target_max_p_tokens == eagle_max_p_tokens) * loss_mask.squeeze()).sum().item()
118 |
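    | # EAGLE loss has two parts: ploss is a soft cross-entropy between the verifier's and the
    | # draft's next-token distributions (both obtained through the frozen verifier lm_head),
    | # and vloss is a SmoothL1 regression of the predicted hidden state onto the verifier's
    | # next hidden state; both are restricted to assistant tokens via loss_mask.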
119 | plogp = target_p * out_logp
120 | ploss = -torch.sum(torch.sum(loss_mask * plogp, 2))
121 | vloss = criterion(predict, batch["target"])
122 | vloss = torch.sum(torch.mean(loss_mask * vloss, 2))
123 | loss = arguments.v_w * vloss + arguments.p_w * ploss
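    | # Normalize so the gradient becomes a global per-token mean: Accelerator.backward() scales
    | # the loss by 1/gradient_accumulation_steps and DDP averages gradients across processes,
    | # so multiplying by both factors here and dividing by the gathered token count cancels
    | # those reductions.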
124 | loss = (loss * arguments.gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
125 | accum_loss += loss.item()
126 | accelerator.backward(loss)
127 | accelerator.clip_grad_value_(model.parameters(), arguments.grad_clip)
128 |
129 | model_optimizer.step()
130 | scheduler.step()
131 | model_optimizer.zero_grad()
132 |
133 | total_steps_passed += 1
134 |
135 | step_end = time.perf_counter()
136 | step_duration = torch.tensor(step_end - step_start, device=accelerator.device)
137 | mean_step_duration_across_gpus = accelerator.reduce(step_duration, reduction="mean").item()
138 |
139 | time_taken = step_end - step_start
140 | throughput = torch.tensor(num_items_in_batch / time_taken, device=accelerator.device)
141 | total_throughput = accelerator.reduce(throughput, reduction="mean").item()  # num_items_in_batch is already the global token count, so average rather than sum across processes
142 |
143 | loss_tensor = torch.tensor(accum_loss / (arguments.gradient_accumulation_steps * accelerator.num_processes), device=accelerator.device)
144 | loss_tensor = accelerator.reduce(loss_tensor, reduction="sum").item()
145 |
146 | accuracy = float("nan")
147 | if num_items_in_batch != 0:
148 | step_correctly_predicted_tokens_count = accelerator.reduce(torch.tensor(step_correctly_predicted_tokens_count, device=accelerator.device), reduction="sum").item()
149 | accuracy = step_correctly_predicted_tokens_count / num_items_in_batch
150 |
151 | current_lr = arguments.learning_rate
152 | if arguments.num_warmup_steps is not None:
153 | current_lr = scheduler.get_last_lr()[0]
154 |
155 | logger.info("epoch %d/%d, step %d/%d, mean step duration across gpus %.4f seconds, lr %.8f, loss %.4f, throughput %d tps, accuracy %.4f", epoch + 1, arguments.epochs, total_steps_passed, arguments.num_training_steps, mean_step_duration_across_gpus, current_lr, loss_tensor, total_throughput, accuracy, main_process_only=True)
156 | if accelerator.is_main_process:
157 | clearml_logger.report_scalar(title="train/steploss", series="series", value=loss_tensor, iteration=total_steps_passed)
158 | clearml_logger.report_scalar(title="train/throughput tokens/s", series="series", value=total_throughput, iteration=total_steps_passed)
159 | clearml_logger.report_scalar(title="train/stepaccuracy", series="series", value=accuracy, iteration=total_steps_passed)
160 | clearml_logger.report_scalar(title="train/epoch", series="series", value=epoch, iteration=total_steps_passed)
161 | clearml_logger.report_scalar(title="train/lr", series="series", value=current_lr, iteration=total_steps_passed)
162 |
163 | if accelerator.is_local_main_process and arguments.save and total_steps_passed % arguments.save == 0:
164 | accelerator.save_state(output_dir=f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}")
165 | with open(arguments.eagle_config_path, "r") as f:
166 | config_data = json.load(f)
167 | config_data["architectures"] = ["LlamaForCausalLMEagle"]
168 | with open(f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}/config.json", "w") as file:
169 | json.dump(config_data, file, ensure_ascii=False, indent=4)
170 |
171 | if total_steps_passed == arguments.num_training_steps:
172 | break
173 |
174 | if total_steps_passed == arguments.num_training_steps:
175 | break
176 |
177 | accelerator.end_training()
178 |
179 |
180 | def _parse_arguments() -> argparse.Namespace:
181 | parser = argparse.ArgumentParser(description="Coach that trains eagle draft model")
182 | parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size")
183 | parser.add_argument("--gradient-accumulation-steps", type=int, required=True, help="Gradient accumulation steps")
184 | parser.add_argument("--num-warmup-steps", type=int, required=True, help="Num warmup steps")
185 | parser.add_argument("--num-training-steps", type=int, required=True, help="Num training steps")
186 | parser.add_argument("--clearml-project", type=str, required=True, help="Clearml project")
187 | parser.add_argument("--clearml-task", type=str, required=True, help="Clearml task")
188 | parser.add_argument("--epochs", type=int, required=True, help="Epochs")
189 | parser.add_argument("--verifier-model-path", type=str, required=True, help="verifier_model_path")
190 | parser.add_argument("--dataset-path", type=str, required=True, help="verifier_model_path")
191 | parser.add_argument("--eagle-config-path", type=str, required=True, help="eagle_config_path")
192 | parser.add_argument("--learning-rate", type=float, required=True, help="eagle_config_path")
193 | parser.add_argument("--maximum-model-length", type=int, required=True, help="eagle_config_path")
194 | parser.add_argument("--noise-low", type=float, required=True, help="eagle_config_path")
195 | parser.add_argument("--noise-high", type=float, required=True, help="eagle_config_path")
196 | parser.add_argument("--v-w", type=float, required=True, help="eagle_config_path")
197 | parser.add_argument("--p-w", type=float, required=True, help="eagle_config_path")
198 | parser.add_argument("--grad-clip", type=float, required=True, help="eagle_config_path")
199 | parser.add_argument("--b1", type=float, required=True, help="eagle_config_path")
200 | parser.add_argument("--b2", type=float, required=True, help="eagle_config_path")
201 | parser.add_argument("--cpdir", type=pathlib.Path, default="./checkpoints", help="Path to folder to save checkpoints")
202 | parser.add_argument("--save", type=int, required=False, help="Save model after every number of steps")
203 | parser.add_argument("--mixed-precision", type=str, required=False, help="Save model after every number of steps")
204 | parser.add_argument("--verifier-model-lm-head-dtype", type=str, required=False, help="Save model after every number of steps")
205 | parser.add_argument("--verifier-model-dtype", type=str, required=False, help="Save model after every number of steps")
206 | parser.add_argument("--eagle-dtype", type=str, required=False, help="Save model after every number of steps")
207 | parser.add_argument("--attn", type=str, required=False, help="Save model after every number of steps")
208 | return parser.parse_args()
209 |
210 |
211 | def _count_parameters(model: torch.nn.Module) -> int:
212 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
213 |
214 |
215 | def _initialize_verifier_lm_head(verifier_path: pathlib.Path) -> torch.nn.Linear:
216 | with open(f"{verifier_path}/config.json", "r") as file:
217 | config = json.load(file)
218 | head = torch.nn.Linear(config["hidden_size"], config["vocab_size"], bias=False)
219 | with open(os.path.join(verifier_path, "model.safetensors.index.json"), "r") as f:
220 | index_json = json.loads(f.read())
221 | head_path = index_json["weight_map"]["lm_head.weight"]
222 | with safetensors.safe_open(os.path.join(verifier_path, head_path), framework="pt") as f:
223 | tensor = f.get_slice("lm_head.weight")[:, :config["hidden_size"]]
224 | head.weight.data = tensor
225 | head.eval()
226 | for param in head.parameters():
227 | param.requires_grad = False
228 | return head
229 |
230 |
231 | class Dataset(torch.utils.data.Dataset):
232 | def __init__(self, dataset: datasets.Dataset) -> None:
233 | self._dataset = dataset
234 |
235 | def __len__(self) -> int:
236 | return len(self._dataset)
237 |
238 | def __getitem__(self, index: int) -> dict:
239 | return self._dataset[index]
240 |
241 |
242 | class Collator:
243 | def __init__(self, model_path) -> None:
244 | self._tokenizer = transformers.AutoTokenizer.from_pretrained(str(model_path), use_fast=True)
245 | self._tokenizer.pad_token = "[PAD]"
246 | self._tokenizer.pad_token_id = 0
247 |
248 | def __call__(self, features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
249 | result = self._tokenizer.apply_chat_template([m["messages"] for m in features], tokenize=True, add_generation_prompt=False, return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt", padding=True)
250 | return {
251 | "input_ids": result["input_ids"],
252 | "loss_mask": result["assistant_masks"]
253 | }
254 |
255 |
256 | def _make_eagle_input(batch, verifier_model, max_model_len, transform_uniform_low, transformer_uniform_high, device):
257 | input_ids = batch["input_ids"].to(device)[:, :max_model_len]
258 | loss_mask = batch["loss_mask"].to(device)[:, :max_model_len]
259 |
260 | with torch.no_grad():
261 | outs_big = verifier_model(input_ids, output_hidden_states=True, use_cache=False)
262 | hidden_state_big = outs_big.hidden_states[-1]
263 | hidden_state_big = _apply_noise_to_hidden_state(hidden_state_big, transform_uniform_low, transformer_uniform_high)
264 | T, L, D = hidden_state_big.shape
265 | target = hidden_state_big.new_zeros((T, L, D))
266 | target[:, :-1, :] = hidden_state_big[:, 1:, :]
267 | input_ids = torch.cat((input_ids[:, 1:], torch.zeros(input_ids.size(0), 1, dtype=input_ids.dtype, device=input_ids.device)), dim=1)
268 | batch = {"input_ids": input_ids, "hidden_states": hidden_state_big, "target": target, "loss_mask": loss_mask}
269 | return batch
270 |
271 |
272 | def _apply_noise_to_hidden_state(hidden_state: torch.FloatTensor, transform_uniform_low, transformer_uniform_high) -> torch.FloatTensor:
273 | noise = torch.rand_like(hidden_state) * (transformer_uniform_high - transform_uniform_low) + transform_uniform_low
274 | noisy_tensor = hidden_state + noise
275 | return noisy_tensor
276 |
277 |
278 | if __name__ == "__main__":
279 | coach()
--------------------------------------------------------------------------------
/eagle/llama2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from typing import List, Optional, Tuple
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.utils.checkpoint
8 | from torch import nn
9 |
10 | from transformers.activations import ACT2FN
11 | top_k=10
12 |
13 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask
14 | def _make_causal_mask(
15 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
16 | ):
17 | """
18 | Make causal mask used for bi-directional self-attention.
19 | """
20 | bsz, tgt_len = input_ids_shape
21 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
22 | mask_cond = torch.arange(mask.size(-1), device=device)
23 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
24 | mask = mask.to(dtype)
25 |
26 | if past_key_values_length > 0:
27 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
28 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
29 |
30 |
31 | # Copied from transformers.models.bart.modeling_bart._expand_mask
32 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
33 | """
34 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
35 | """
36 | bsz, src_len = mask.size()
37 | tgt_len = tgt_len if tgt_len is not None else src_len
38 |
39 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
40 |
41 | inverted_mask = 1.0 - expanded_mask
42 |
43 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
44 |
45 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
46 | """
47 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
48 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
49 | """
50 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
51 | if n_rep == 1:
52 | return hidden_states
53 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
54 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
55 |
56 | def rotate_half(x):
57 | """Rotates half the hidden dims of the input."""
58 | x1 = x[..., : x.shape[-1] // 2]
59 | x2 = x[..., x.shape[-1] // 2 :]
60 | return torch.cat((-x2, x1), dim=-1)
61 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
62 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
63 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
64 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
65 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
66 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
67 | q_embed = (q * cos) + (rotate_half(q) * sin)
68 | k_embed = (k * cos) + (rotate_half(k) * sin)
69 | return q_embed, k_embed
70 | class LlamaRotaryEmbedding(torch.nn.Module):
71 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
72 | super().__init__()
73 |
74 | self.dim = dim
75 | self.max_position_embeddings = max_position_embeddings
76 | self.base = base
77 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
78 | self.register_buffer("inv_freq", inv_freq, persistent=False)
79 |
80 | # Build here to make `torch.jit.trace` work.
81 | self._set_cos_sin_cache(
82 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
83 | )
84 |
85 | def _set_cos_sin_cache(self, seq_len, device, dtype):
86 | self.max_seq_len_cached = seq_len
87 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
88 |
89 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
90 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
91 | emb = torch.cat((freqs, freqs), dim=-1)
92 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
93 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
94 |
95 | def forward(self, x, seq_len=None):
96 | # x: [bs, num_attention_heads, seq_len, head_size]
97 | if seq_len > self.max_seq_len_cached:
98 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
99 |
100 | return (
101 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
102 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
103 | )
104 |
105 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
106 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
107 |
108 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
109 | self.scaling_factor = scaling_factor
110 | super().__init__(dim, max_position_embeddings, base, device)
111 |
112 | def _set_cos_sin_cache(self, seq_len, device, dtype):
113 | self.max_seq_len_cached = seq_len
114 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
115 | t = t / self.scaling_factor
116 |
117 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
118 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
119 | emb = torch.cat((freqs, freqs), dim=-1)
120 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
121 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
122 |
123 |
124 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
125 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
126 |
127 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
128 | self.scaling_factor = scaling_factor
129 | super().__init__(dim, max_position_embeddings, base, device)
130 |
131 | def _set_cos_sin_cache(self, seq_len, device, dtype):
132 | self.max_seq_len_cached = seq_len
133 |
134 | if seq_len > self.max_position_embeddings:
135 | base = self.base * (
136 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
137 | ) ** (self.dim / (self.dim - 2))
138 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
139 | self.register_buffer("inv_freq", inv_freq, persistent=False)
140 |
141 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
142 |
143 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
144 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
145 | emb = torch.cat((freqs, freqs), dim=-1)
146 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
147 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
148 |
149 | class LlamaAttention(nn.Module):
150 | """Multi-headed attention from 'Attention Is All You Need' paper"""
151 |
152 | def __init__(self, config):
153 | super().__init__()
154 | self.config = config
155 | self.hidden_size = config.hidden_size
156 | self.num_heads = config.num_attention_heads
157 | self.head_dim = self.hidden_size // self.num_heads
158 | self.num_key_value_heads = config.num_key_value_heads
159 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
160 | self.max_position_embeddings = config.max_position_embeddings
161 |
162 | if (self.head_dim * self.num_heads) != self.hidden_size:
163 | raise ValueError(
164 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
165 | f" and `num_heads`: {self.num_heads})."
166 | )
167 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
168 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
169 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
170 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
171 | self._init_rope()
172 |
173 | def _init_rope(self):
174 | if self.config.rope_scaling is None:
175 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
176 | else:
177 | scaling_type = self.config.rope_scaling["type"]
178 | scaling_factor = self.config.rope_scaling["factor"]
179 | if scaling_type == "linear":
180 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
181 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
182 | )
183 | elif scaling_type == "dynamic":
184 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
185 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
186 | )
187 | else:
188 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
189 |
190 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
191 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
192 |
193 | def forward(
194 | self,
195 | hidden_states: torch.Tensor,
196 | attention_mask: Optional[torch.Tensor] = None,
197 | position_ids: Optional[torch.LongTensor] = None,
198 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
199 | output_attentions: bool = False,
200 | use_cache: bool = False,
201 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
202 | bsz, q_len, _ = hidden_states.size()
203 |
204 | if self.config.pretraining_tp > 1:
205 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
206 | query_slices = self.q_proj.weight.split(
207 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
208 | )
209 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
210 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
211 |
212 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
213 | query_states = torch.cat(query_states, dim=-1)
214 |
215 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
216 | key_states = torch.cat(key_states, dim=-1)
217 |
218 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
219 | value_states = torch.cat(value_states, dim=-1)
220 |
221 | else:
222 | query_states = self.q_proj(hidden_states)
223 | key_states = self.k_proj(hidden_states)
224 | value_states = self.v_proj(hidden_states)
225 |
226 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
227 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
228 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
229 |
230 | kv_seq_len = key_states.shape[-2]
231 | if past_key_value is not None:
232 | kv_seq_len += past_key_value[0].shape[-2]
233 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
234 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
235 |
236 | if past_key_value is not None:
237 | # reuse k, v, self_attention
238 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
239 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
240 |
241 | past_key_value = (key_states, value_states) if use_cache else None
242 |
243 | # repeat k/v heads if n_kv_heads < n_heads
244 | key_states = repeat_kv(key_states, self.num_key_value_groups)
245 | value_states = repeat_kv(value_states, self.num_key_value_groups)
246 |
247 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
248 |
249 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
250 | raise ValueError(
251 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
252 | f" {attn_weights.size()}"
253 | )
254 |
255 | if attention_mask is not None:
256 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
257 | raise ValueError(
258 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
259 | )
260 | attn_weights = attn_weights + attention_mask
261 |
262 | # upcast attention to fp32
263 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
264 | attn_output = torch.matmul(attn_weights, value_states)
265 |
266 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
267 | raise ValueError(
268 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
269 | f" {attn_output.size()}"
270 | )
271 |
272 | attn_output = attn_output.transpose(1, 2).contiguous()
273 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
274 |
275 | if self.config.pretraining_tp > 1:
276 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
277 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
278 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
279 | else:
280 | attn_output = self.o_proj(attn_output)
281 |
282 | if not output_attentions:
283 | attn_weights = None
284 |
285 | return attn_output, attn_weights, past_key_value
286 |
287 |
288 | class LlamaMLP(nn.Module):
289 | def __init__(self, config):
290 | super().__init__()
291 | self.config = config
292 | self.hidden_size = config.hidden_size
293 | self.intermediate_size = config.intermediate_size
294 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
295 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
296 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
297 | self.act_fn = ACT2FN[config.hidden_act]
298 |
299 | def forward(self, x):
300 | if self.config.pretraining_tp > 1:
301 | slice = self.intermediate_size // self.config.pretraining_tp
302 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
303 | up_proj_slices = self.up_proj.weight.split(slice, dim=0)
304 | down_proj_slices = self.down_proj.weight.split(slice, dim=1)
305 |
306 | gate_proj = torch.cat(
307 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
308 | )
309 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
310 |
311 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
312 | down_proj = [
313 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
314 | ]
315 | down_proj = sum(down_proj)
316 | else:
317 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
318 |
319 | return down_proj
320 |
321 | class LlamaRMSNorm(nn.Module):
322 | def __init__(self, hidden_size, eps=1e-6):
323 | """
324 | LlamaRMSNorm is equivalent to T5LayerNorm
325 | """
326 | super().__init__()
327 | self.weight = nn.Parameter(torch.ones(hidden_size))
328 | self.variance_epsilon = eps
329 |
330 | def forward(self, hidden_states):
331 | input_dtype = hidden_states.dtype
332 | hidden_states = hidden_states.to(torch.float32)
333 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
334 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
335 | return self.weight * hidden_states.to(input_dtype)
336 |
337 | class LlamaDecoderLayer(nn.Module):
338 | def __init__(self, config,index):
339 | super().__init__()
340 | self.hidden_size = config.hidden_size
341 | self.self_attn = LlamaAttention(config=config)
342 | self.mlp = LlamaMLP(config)
343 | self.index=index
344 | if self.index!=0:
345 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
346 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
347 |
348 | def forward(
349 | self,
350 | hidden_states: torch.Tensor,
351 | attention_mask: Optional[torch.Tensor] = None,
352 | position_ids: Optional[torch.LongTensor] = None,
353 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
354 | output_attentions: Optional[bool] = False,
355 | use_cache: Optional[bool] = False,
356 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
357 | """
358 | Args:
359 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
360 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
361 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
362 | output_attentions (`bool`, *optional*):
363 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
364 | returned tensors for more detail.
365 | use_cache (`bool`, *optional*):
366 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
367 | (see `past_key_values`).
368 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
369 | """
370 |
371 | residual = hidden_states
372 |
373 | if self.index != 0:
374 | hidden_states = self.input_layernorm(hidden_states)
375 |
376 | # Self Attention
377 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
378 | hidden_states=hidden_states,
379 | attention_mask=attention_mask,
380 | position_ids=position_ids,
381 | past_key_value=past_key_value,
382 | output_attentions=output_attentions,
383 | use_cache=use_cache,
384 | )
385 | hidden_states = residual + hidden_states
386 |
387 | # Fully Connected
388 | residual = hidden_states
389 | hidden_states = self.post_attention_layernorm(hidden_states)
390 | hidden_states = self.mlp(hidden_states)
391 | hidden_states = residual + hidden_states
392 |
393 | outputs = (hidden_states,)
394 |
395 | if output_attentions:
396 | outputs += (self_attn_weights,)
397 |
398 | if use_cache:
399 | outputs += (present_key_value,)
400 |
401 | return outputs
402 |
403 |
404 | class Llama2Model(nn.Module):
405 | def __init__(self,config,load_emb=False,path=None,bias=True):
406 | super().__init__()
407 | self.gradient_checkpointing = True
408 | self.padding_idx = config.pad_token_id
409 | self.vocab_size = config.vocab_size
410 |
411 | #####################################################################################
412 | # Init embedding from base model
413 |
414 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
415 | if load_emb:
416 | from safetensors import safe_open
417 | import json
418 | try:
419 | with open(os.path.join(path,"model.safetensors.index.json"),"r") as f:
420 | index_json=json.loads(f.read())
421 | emb_path=index_json["weight_map"]["model.embed_tokens.weight"]
422 | with safe_open(os.path.join(path,emb_path),
423 | framework="pt",
424 | device="cpu") as f:
425 | tensor_slice = f.get_slice("model.embed_tokens.weight")
426 | vocab_size, hidden_dim = tensor_slice.get_shape()
427 | tensor = tensor_slice[:, :hidden_dim].float()
428 | except Exception:
429 | with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f:
430 | index_json = json.loads(f.read())
431 | emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
432 | weights=torch.load(os.path.join(path,emb_path))
433 | tensor=weights["model.embed_tokens.weight"].float()
434 | self.embed_tokens.weight.data = tensor
435 |
436 | #####################################################################################
437 | # init eagle layers
438 |
439 | self.layers = nn.ModuleList([LlamaDecoderLayer(config,index) for index in range(config.num_hidden_layers)])
440 |
441 | #####################################################################################
442 | # init eagle down pooler from concat([emb, hidden_state]) -> hidden_size
443 |
444 | self.fc=nn.Linear(2*config.hidden_size,config.hidden_size,bias=bias)
445 |
446 | #####################################################################################
447 | # freeze emb + extra stuff
448 |
449 | self.act=ACT2FN[config.hidden_act]
450 | for param in self.embed_tokens.parameters():
451 | param.requires_grad = False
452 |
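    | # NOTE: mc_sim_7b_63 and generate_tree_buffers are neither defined nor imported in this
    | # file; they presumably come from the upstream EAGLE inference code (tree-based drafting).
    | # The training forward pass below never calls init_tree.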
453 | def init_tree(self):
454 | self.tree = mc_sim_7b_63
455 | self.tree_buffer=generate_tree_buffers(self.tree,self.embed_tokens.weight.device)
456 |
457 | def reset(self):
458 | self.tree_mask=None
459 |
460 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
461 | # create causal mask
462 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
463 | combined_attention_mask = None
464 | if input_shape[-1] > 1:
465 | combined_attention_mask = _make_causal_mask(
466 | input_shape,
467 | #inputs_embeds.dtype,
468 | torch.float32, # [MODIFIED] force to cast to float32
469 | device=inputs_embeds.device,
470 | past_key_values_length=past_key_values_length,
471 | )
472 |
473 | if attention_mask is not None:
474 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
475 | expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to(
476 | inputs_embeds.device
477 | )
478 | combined_attention_mask = (
479 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
480 | )
481 |
482 | # [MODIFIED] add tree mask
483 | if hasattr(self, "tree_mask") and self.tree_mask is not None:
484 | tree_mask = self.tree_mask
485 | tree_len = tree_mask.size(-1)
486 | combined_attention_mask[:, :, -tree_len:, -tree_len:][
487 | tree_mask == 0
488 | ] = torch.finfo(torch.float32).min
489 |
490 |
491 | return combined_attention_mask
492 |
493 | def forward(
494 | self,
495 | hidden_states,
496 | input_ids,
497 | attention_mask: Optional[torch.Tensor] = None,
498 | position_ids: Optional[torch.LongTensor] = None,
499 | past_key_values: Optional[List[torch.FloatTensor]] = None,
500 | inputs_embeds: Optional[torch.FloatTensor] = None,
501 | use_cache: Optional[bool] = None,
502 | output_attentions: Optional[bool] = None,
503 | output_hidden_states: Optional[bool] = None,
504 | return_dict: Optional[bool] = None,
505 | std=None
506 | ):
507 | batch_size, seq_length, _ = hidden_states.shape
508 | seq_length_with_past = seq_length
509 | past_key_values_length = 0
510 |
511 | with torch.no_grad():
512 | #####################################################################################
513 | # embed tokens using embedding table
514 |
515 | inputs_embeds = self.embed_tokens(input_ids)
516 |
517 | if past_key_values is not None:
518 | past_key_values_length = past_key_values[0][0].shape[2]
519 | seq_length_with_past = seq_length_with_past + past_key_values_length
520 | if position_ids is None:
521 | device = hidden_states.device if hidden_states is not None else inputs_embeds.device
522 | position_ids = torch.arange(
523 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
524 | )
525 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
526 | else:
527 | position_ids = position_ids.view(-1, seq_length).long()
528 |
529 | if attention_mask is None:
530 | attention_mask = torch.ones(
531 | (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
532 | )
533 | attention_mask = self._prepare_decoder_attention_mask(
534 | attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
535 | )
536 |
537 | inputs_embeds=inputs_embeds.to(hidden_states.dtype)
538 |
539 | #####################################################################################
540 | # Concat emb + hidden state and lower dim back to hidden_size using fully connected
541 |
542 | hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
543 |
544 |
545 | all_hidden_states = () if output_hidden_states else None
546 | next_decoder_cache = () if use_cache else None
547 |
548 | for idx, decoder_layer in enumerate(self.layers):
549 | if output_hidden_states:
550 | all_hidden_states += (hidden_states,)
551 |
552 | past_key_value = past_key_values[idx] if past_key_values is not None else None
553 |
554 | if self.gradient_checkpointing and self.training:
555 |
556 | def create_custom_forward(module):
557 | def custom_forward(*inputs):
558 | # None for past_key_value
559 | return module(*inputs, past_key_value, output_attentions)
560 |
561 | return custom_forward
562 |
563 | layer_outputs = torch.utils.checkpoint.checkpoint(
564 | create_custom_forward(decoder_layer),
565 | hidden_states,
566 | attention_mask,
567 | position_ids,
568 | )
569 | else:
570 | layer_outputs = decoder_layer(
571 | hidden_states,
572 | attention_mask=attention_mask,
573 | position_ids=position_ids,
574 | past_key_value=past_key_value,
575 | output_attentions=output_attentions,
576 | use_cache=use_cache,
577 | )
578 |
579 | hidden_states = layer_outputs[0]
580 |
581 | if use_cache:
582 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
583 |
584 | if use_cache:
585 | return hidden_states,next_decoder_cache
586 |
587 | return hidden_states
588 |
--------------------------------------------------------------------------------