├── eagle ├── __init__.py ├── prepare_sharegpt_dataset.py ├── generate_trajectories.py ├── train_tp.py ├── train.py └── llama2.py ├── benchmark ├── __init__.py ├── README.md ├── create_alpaca_prompts.py └── benchmark.py ├── .gitignore ├── pages ├── README.md ├── train_qwen_3.md └── reproduce_eagle_1_paper.md ├── resources ├── eagle_config.json ├── eagle_config_qwen3_8b.json ├── example_chat_template_with_generation_keyword.json └── raw_example_dataset.jsonl ├── README.md ├── docker └── Dockerfile ├── .github └── workflows │ └── docker_build_push.yaml └── LICENSE /eagle/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tokenized_dataset 2 | eagle.egg-info 3 | checkpoints 4 | wandb 5 | vllm 6 | venv 7 | eagle/__pycache__ 8 | models 9 | *.crt 10 | -------------------------------------------------------------------------------- /pages/README.md: -------------------------------------------------------------------------------- 1 | ## Pages 2 | 3 | Explore the pages to get started 4 | 5 | * [Reproduce eagle 1 paper on sharegpt dataset](./reproduce_eagle_1_paper.md) 6 | -------------------------------------------------------------------------------- /resources/eagle_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_act": "silu", 7 | "hidden_size": 4096, 8 | "initializer_range": 0.02, 9 | "torch_dtype": "bfloat16", 10 | "intermediate_size": 11008, 11 | "num_attention_heads": 32, 12 | "num_hidden_layers": 1, 13 | "num_key_value_heads": 32, 14 | "rms_norm_eps": 1e-06 15 | } 16 | -------------------------------------------------------------------------------- /resources/eagle_config_qwen3_8b.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "LlamaForCausalLM" 4 | ], 5 | "model_type": "llama", 6 | "hidden_act": "silu", 7 | "hidden_size": 4096, 8 | "initializer_range": 0.02, 9 | "torch_dtype": "bfloat16", 10 | "intermediate_size": 12288, 11 | "num_attention_heads": 32, 12 | "num_hidden_layers": 1, 13 | "num_key_value_heads": 32, 14 | "rms_norm_eps": 1e-06, 15 | "vocab_size": 151936 16 | } 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 🦅 eagle 2 | 3 | This repository allows you to train an eagle draft model fully compatible with [SGLang](https://github.com/sgl-project/sglang) that achieves the [paper](https://arxiv.org/abs/2401.15077v1) scores in terms of end-to-end latency speedup and generation throughput. I will keep this project as minimalistic as possible while making it scalable, so that you can train a SOTA eagle draft model in under 1 hour on a single node of enterprise GPUs, though it is not limited to that setup. 
Checkout [pages](./pages/README.md) to get started 4 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.9.1-devel-ubuntu24.04 2 | 3 | WORKDIR /opt/eagle 4 | 5 | ENV PYTHONUNBUFFERED=1 6 | ENV PYTHONPATH=/opt/eagle 7 | 8 | RUN apt update 9 | RUN apt install --yes python3-pip python3.12-venv 10 | RUN python3 -m venv /venv 11 | 12 | ENV PATH="/venv/bin:$PATH" 13 | RUN pip install --upgrade pip 14 | 15 | RUN pip install torch==2.7.1 transformers==4.52.4 datasets==3.6.0 accelerate==1.8.1 clearml==2.0.0 16 | RUN pip install psutil 17 | RUN pip uninstall -y ninja && pip install ninja 18 | RUN MAX_JOBS=4 pip install flash-attn==2.8.0.post2 --no-build-isolation 19 | 20 | COPY ./eagle ./eagle 21 | -------------------------------------------------------------------------------- /.github/workflows/docker_build_push.yaml: -------------------------------------------------------------------------------- 1 | name: Build and push release docker image 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | jobs: 9 | docker: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Authenticate to docker hub 13 | uses: docker/login-action@v3 14 | with: 15 | username: ${{ secrets.DOCKER_USERNAME }} 16 | password: ${{ secrets.DOCKER_PASSWORD }} 17 | registry: ${{ secrets.DOCKER_REGISTRY_URL }} 18 | 19 | - name: Configure qemu 20 | uses: docker/setup-qemu-action@v3 21 | 22 | - name: Configure buildx 23 | uses: docker/setup-buildx-action@v3 24 | 25 | - name: Build and push release docker image 26 | uses: docker/build-push-action@v6 27 | with: 28 | push: true 29 | file: docker/Dockerfile 30 | tags: | 31 | ${{ secrets.DOCKER_USERNAME }}/eagle:${{ github.ref_name }} 32 | ${{ secrets.DOCKER_USERNAME }}/eagle:latest 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Vladislav Kruglikov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /resources/example_chat_template_with_generation_keyword.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": false, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{% generation %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endgeneration %}{% endif %}{% endfor %}", 13 | "clean_up_tokenization_spaces": false, 14 | "eos_token": { 15 | "__type": "AddedToken", 16 | "content": "", 17 | "lstrip": false, 18 | "normalized": false, 19 | "rstrip": false, 20 | "single_word": false 21 | }, 22 | "legacy": false, 23 | "model_max_length": 1000000000000000019884624838656, 24 | "pad_token": null, 25 | "padding_side": "right", 26 | "sp_model_kwargs": {}, 27 | "tokenizer_class": "LlamaTokenizer", 28 | "unk_token": { 29 | "__type": "AddedToken", 30 | "content": "", 31 | "lstrip": false, 32 | "normalized": false, 33 | "rstrip": false, 34 | "single_word": false 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- 1 | ## 📊 Benchmark 2 | 3 | Create alpaca prompts 4 | 5 | ```bash 6 | python3 ./benchmark/create_alpaca_prompts.py \ 7 | --output ./benchmark/alpaca.jsonl \ 8 | --n 128 9 | ``` 10 | 11 | Benchmark base model with batch size 1 meaning only 1 request runs at most at the same time 12 | 13 | ```bash 14 | docker run \ 15 | --gpus all \ 16 | -e CUDA_VISIBLE_DEVICES=0 \ 17 | -v ./models/meta-llama2-7b-chat-hf:/mnt/llama2-7b \ 18 | -v ./eagle_model:/mnt/eagle \ 19 | -v ./benchmark:/opt/benchmark \ 20 | --ipc=host \ 21 | --shm-size 32g \ 22 | lmsysorg/sglang:v0.4.6.post5-cu124 \ 23 | bash -c "cd /opt && export PYTHONPATH=$PYTHONPATH:. && python3 benchmark/benchmark.py \ 24 | --model /mnt/llama2-7b \ 25 | --prompts benchmark/alpaca.jsonl \ 26 | --n 64 \ 27 | --bs 1 \ 28 | --output benchmark/report_alpaca_bs1_wo_eagle.json \ 29 | --temperature 0" 30 | ``` 31 | 32 | Benchmark base model with draft model with batch size 1 meaning only 1 request runs at most at the same time 33 | 34 | ```bash 35 | docker run \ 36 | --gpus all \ 37 | -e CUDA_VISIBLE_DEVICES=1 \ 38 | -v ./models/meta-llama2-7b-chat-hf:/mnt/llama2-7b \ 39 | -v ./step5000/vllm:/mnt/eagle \ 40 | -v ./benchmark:/opt/benchmark \ 41 | -v ./checkpoints:/opt/checkpoints \ 42 | --ipc=host \ 43 | --shm-size 32g \ 44 | lmsysorg/sglang:v0.4.6.post5-cu124 \ 45 | bash -c "cd /opt && export PYTHONPATH=$PYTHONPATH:. 
&& python3 benchmark/benchmark.py \ 46 | --model /mnt/llama2-7b \ 47 | --prompts benchmark/alpaca.jsonl \ 48 | --n 64 \ 49 | --bs 1 \ 50 | --output benchmark/report_alpaca_bs1_with_eagle_new_new.json \ 51 | --eagle /opt/checkpoints/step_10/sglang \ 52 | --steps 4 \ 53 | --k 1 \ 54 | --draft 4 \ 55 | --speculative-algorithm EAGLE \ 56 | --temperature 0" 57 | ``` 58 | -------------------------------------------------------------------------------- /benchmark/create_alpaca_prompts.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import argparse 3 | import datasets 4 | 5 | 6 | def create_alpaca_prompts() -> None: 7 | arguments = _parse_arguments() 8 | output_path = arguments.output 9 | n = arguments.n 10 | frac = arguments.frac 11 | if n is not None and frac is not None: 12 | raise ValueError("One of --n or --frac must be set") 13 | 14 | print("Loading dataset") 15 | dataset = datasets.load_dataset("tatsu-lab/alpaca")["train"] 16 | print(f"Dataset has {len(dataset)} rows") 17 | dataset = dataset.shuffle(seed=0) 18 | indices = range(n) if n is not None else range(int(frac * len(dataset))) 19 | dataset = dataset.select(indices) 20 | print(f"Dataset after select has {len(dataset)} rows") 21 | dataset = dataset.map( 22 | _apply_template, 23 | batched=False, 24 | num_proc=1, 25 | remove_columns=dataset.column_names, 26 | desc="Applying template" 27 | ) 28 | 29 | print("Saving to disk") 30 | dataset.to_json(output_path) 31 | 32 | 33 | def _parse_arguments() -> argparse.Namespace: 34 | parser = argparse.ArgumentParser(description="Create alpaca prompts") 35 | parser.add_argument( 36 | "--output", 37 | type=pathlib.Path, 38 | required=True, 39 | help="Path to save prompts" 40 | ) 41 | parser.add_argument( 42 | "--n", 43 | type=int, 44 | help="Number of sampels to take" 45 | ) 46 | parser.add_argument( 47 | "--frac", 48 | type=float, 49 | help="Number of sampels to take from 0.0 to 1.0 percent" 50 | ) 51 | return parser.parse_args() 52 | 53 | 54 | def _apply_template(example: dict) -> dict: 55 | if example["input"] == "": 56 | return {"prompt": _ALPACA_TEMPLATE_WITHOUT_INPUT.format(instruction=example["instruction"])} 57 | else: 58 | return {"prompt": _ALPACA_TEMPLATE_WITH_INPUT.format(instruction=example["instruction"], input=example["input"])} 59 | 60 | 61 | _ALPACA_TEMPLATE_WITH_INPUT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 62 | 63 | ### Instruction: 64 | {instruction} 65 | 66 | ### Input: 67 | {input} 68 | 69 | ### Response: 70 | """ 71 | 72 | _ALPACA_TEMPLATE_WITHOUT_INPUT = """Below is an instruction that describes a task. Write a response that appropriately completes the request. 73 | 74 | ### Instruction: 75 | {instruction} 76 | 77 | ### Response: 78 | """ 79 | 80 | 81 | if __name__ == "__main__": 82 | create_alpaca_prompts() 83 | -------------------------------------------------------------------------------- /pages/train_qwen_3.md: -------------------------------------------------------------------------------- 1 | ## Train qwen3 2 | 3 | ```bash 4 | huggingface-cli download Qwen/Qwen3-8B --local-dir ./resources/Qwen/Qwen3-8B 5 | ``` 6 | 7 | ```bash 8 | docker build --tag eagle -f docker/Dockerfile . 
&& docker run \ 9 | --gpus all \ 10 | -e OMP_NUM_THREADS=4 \ 11 | -e CUDA_VISIBLE_DEVICES=1,2,5,7 \ 12 | -e CLEARML_OFFLINE_MODE=1 \ 13 | -v ./resources:/mnt/resources \ 14 | eagle \ 15 | torchrun \ 16 | --nnodes=1 \ 17 | --nproc_per_node=4 \ 18 | eagle/train.py \ 19 | --micro-batch-size 2 \ 20 | --gradient-accumulation-steps 2 \ 21 | --num-warmup-steps 4096 \ 22 | --num-training-steps 524288 \ 23 | --epochs 4 \ 24 | --clearml-project eagle \ 25 | --clearml-task 4gpus-2microbs-2accum-16globalbs \ 26 | --verifier-model-path /mnt/resources/qwen3-8b \ 27 | --dataset-path /mnt/resources/sharegpt.jsonl \ 28 | --eagle-config-path /mnt/resources/eagle_config_qwen3_8b.json \ 29 | --learning-rate 2e-4 \ 30 | --maximum-model-length 2048 \ 31 | --noise-low -0.1 \ 32 | --noise-high 0.1 \ 33 | --v-w 1.0 \ 34 | --p-w 0.1 \ 35 | --grad-clip 0.5 \ 36 | --b1 0.9 \ 37 | --b2 0.95 \ 38 | --cpdir /mnt/resources/checkpoints \ 39 | --save 4096 \ 40 | --mixed-precision bf16 \ 41 | --verifier-model-lm-head-dtype bfloat16 \ 42 | --verifier-model-dtype bfloat16 \ 43 | --eagle-dtype bfloat16 \ 44 | --attn flash_attention_2 45 | ``` 46 | 47 | ```bash 48 | docker build --tag eagle -f docker/Dockerfile . && docker run \ 49 | --gpus all \ 50 | -e CUDA_VISIBLE_DEVICES=1,2,5,7 \ 51 | -e CLEARML_OFFLINE_MODE=1 \ 52 | -v ./resources:/mnt/resources \ 53 | eagle python3 ./eagle/train_tp.py \ 54 | --micro-batch-size 2 \ 55 | --gradient-accumulation-steps 2 \ 56 | --num-warmup-steps 4096 \ 57 | --num-training-steps 524288 \ 58 | --epochs 4 \ 59 | --clearml-project eagle \ 60 | --clearml-task 4gpus-2microbs-2accum-16globalbs \ 61 | --verifier-model-path /mnt/resources/qwen3-8b \ 62 | --dataset-path /mnt/resources/sharegpt.jsonl \ 63 | --eagle-config-path /mnt/resources/eagle_config_qwen3_8b.json \ 64 | --learning-rate 2e-4 \ 65 | --maximum-model-length 2048 \ 66 | --noise-low -0.1 \ 67 | --noise-high 0.1 \ 68 | --v-w 1.0 \ 69 | --p-w 0.1 \ 70 | --grad-clip 0.5 \ 71 | --b1 0.9 \ 72 | --b2 0.95 \ 73 | --cpdir /mnt/resources/checkpoints \ 74 | --save 4096 \ 75 | --mixed-precision bf16 \ 76 | --verifier-model-lm-head-dtype bfloat16 \ 77 | --verifier-model-dtype bfloat16 \ 78 | --eagle-dtype bfloat16 \ 79 | --attn flash_attention_2 80 | ``` 81 | -------------------------------------------------------------------------------- /eagle/prepare_sharegpt_dataset.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import datasets 3 | import argparse 4 | 5 | 6 | def convert_sharegpt_dataset() -> None: 7 | arguments = _parse_arguments() 8 | output_path: pathlib.Path = arguments.output 9 | n = arguments.n 10 | frac = arguments.frac 11 | if n is not None and frac is not None: 12 | raise ValueError("One of --n or --frac must be set") 13 | 14 | print("Loading raw dataset") 15 | dataset = ( 16 | datasets 17 | .load_dataset("json", data_files={ 18 | "train": "https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V4.3_unfiltered_cleaned_split.json" 19 | }) 20 | ["train"] 21 | .shuffle(seed=0) 22 | ) 23 | 24 | print(f"Dataset has {len(dataset)} rows") 25 | indices = range(n) if n is not None else range(int(frac * len(dataset))) 26 | dataset = dataset.select(indices) 27 | print(f"Dataset after select has {len(dataset)} rows") 28 | 29 | dataset = dataset.map( 30 | _convert_sharegpt_dataset, 31 | num_proc=1, 32 | remove_columns=dataset.column_names, 33 | desc="Converting dataset" 34 | ) 35 | 36 | print(f"Dataset after filtering has {len(dataset)} rows") 37 | 38 | 
print("Saving to disk") 39 | dataset.to_json(output_path) 40 | 41 | 42 | def _parse_arguments() -> argparse.Namespace: 43 | parser = argparse.ArgumentParser( 44 | description="Generate trajectories" 45 | ) 46 | parser.add_argument( 47 | "--output", 48 | type=pathlib.Path, 49 | required=True, 50 | help="Path to jsonlines file where the processed dataset will be stored" 51 | ) 52 | parser.add_argument( 53 | "--n", 54 | type=int, 55 | help="Number of sampels to take" 56 | ) 57 | parser.add_argument( 58 | "--frac", 59 | type=float, 60 | help="Number of sampels to take from 0.0 to 1.0 percent" 61 | ) 62 | return parser.parse_args() 63 | 64 | 65 | def _convert_sharegpt_dataset(example: dict) -> dict: 66 | new_turns = [] 67 | # https://github.com/SafeAILab/EAGLE/blob/4a9cf3a1f6cd4a294e6d30a4e7c77cba246d7ca5/eagle/ge_data/ge_data_all_llama2chat.py#L65 68 | system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." 69 | new_turns.append({"role": "system", "content": system_prompt}) 70 | for turn in example["conversations"]: 71 | if turn["from"] == "gpt": 72 | role = "assistant" 73 | elif turn["from"] == "human": 74 | role = "user" 75 | else: 76 | raise ValueError("Unknown role") 77 | new_turn = {"role": role, "content": turn["value"]} 78 | new_turns.append(new_turn) 79 | # https://github.com/SafeAILab/EAGLE/blob/4a9cf3a1f6cd4a294e6d30a4e7c77cba246d7ca5/eagle/ge_data/ge_data_all_llama2chat.py#L69 80 | if new_turns[1]["role"] == "assistant": 81 | new_turns = [new_turns[0]] + new_turns[2:] 82 | return {"messages": new_turns} 83 | 84 | 85 | if __name__ == "__main__": 86 | convert_sharegpt_dataset() 87 | -------------------------------------------------------------------------------- /eagle/generate_trajectories.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sglang 4 | import pathlib 5 | import argparse 6 | import datasets 7 | import transformers 8 | 9 | 10 | def _prepare_dataset() -> None: 11 | arguments = _parse_arguments() 12 | 13 | raw_path: pathlib.Path = arguments.input 14 | model_path: pathlib.Path = arguments.model 15 | tokenizer_path: pathlib.Path = arguments.tokenizer 16 | output_path: pathlib.Path = arguments.output 17 | n = arguments.n 18 | frac = arguments.frac 19 | if n is not None and frac is not None: 20 | raise ValueError("One of --n or --frac must be set") 21 | 22 | print("Loading tokenizer") 23 | tokenizer = transformers.AutoTokenizer.from_pretrained(str(tokenizer_path), use_fast=True) 24 | 25 | print("Loading raw dataset") 26 | dataset = ( 27 | datasets 28 | .load_dataset("json", data_files={"train": [str(raw_path)]}) 29 | ["train"] 30 | .shuffle(seed=0) 31 | ) 32 | print(f"Dataset has {len(dataset)} rows") 33 | indices = range(n) if n is not None else range(int(frac * len(dataset))) 34 | dataset = dataset.select(indices) 35 | print(f"Dataset after select has {len(dataset)} rows") 36 | 37 | dataset = dataset.map( 38 | lambda example: _tokenize_dataset(example=example, tokenizer=tokenizer), 39 | batched=False, 40 | num_proc=1, 41 | 
desc="Tokenizing dataset" 42 | ) 43 | 44 | llm = sglang.Engine( 45 | model_path=str(model_path), 46 | tp_size=arguments.tp, 47 | pp_size=arguments.pp, 48 | dp_size=arguments.dp, 49 | log_level="info" 50 | ) 51 | 52 | sampling_params = { 53 | "temperature": arguments.temperature, 54 | "max_new_tokens": arguments.max_new_tokens 55 | } 56 | 57 | batched_input_ids = [el["input_ids"] for el in dataset] 58 | 59 | start = time.perf_counter() 60 | outputs: list[dict] = llm.generate(input_ids=batched_input_ids, sampling_params=sampling_params) 61 | end = time.perf_counter() 62 | print(end - start, "total inference") 63 | 64 | 65 | def add_reply(example, idx): 66 | msgs = example["messages"].copy() 67 | msgs.append({"role": "assistant", "content": outputs[idx]["text"]}) 68 | return {"messages": msgs} 69 | 70 | 71 | dataset = dataset.map( 72 | add_reply, 73 | with_indices=True, 74 | batched=False, 75 | desc="Appending assistant replies" 76 | ) 77 | 78 | 79 | print("Saving to disk") 80 | dataset.select_columns(["id", "messages"]).to_json(output_path, force_ascii=False) 81 | 82 | 83 | def _parse_arguments() -> argparse.Namespace: 84 | parser = argparse.ArgumentParser( 85 | description="Generate trajectories" 86 | ) 87 | parser.add_argument( 88 | "--input", 89 | type=pathlib.Path, 90 | required=True, 91 | help="Path to JSON lines chat dataset as described in documentation wheres lines end with user response" 92 | ) 93 | parser.add_argument( 94 | "--model", 95 | type=pathlib.Path, 96 | required=True, 97 | help="Path to verifier model" 98 | ) 99 | parser.add_argument( 100 | "--tokenizer", 101 | type=pathlib.Path, 102 | required=True, 103 | help="Path to tokenizer, usually the same as the model" 104 | ) 105 | parser.add_argument( 106 | "--output", 107 | type=pathlib.Path, 108 | required=True, 109 | help="Path to jsonlines file where the processed dataset will be stored" 110 | ) 111 | parser.add_argument( 112 | "--temperature", 113 | type=float, 114 | default=0.0, 115 | help="Temperature" 116 | ) 117 | parser.add_argument( 118 | "--n", 119 | type=int, 120 | help="Number of sampels to take" 121 | ) 122 | parser.add_argument( 123 | "--max-new-tokens", 124 | type=int, 125 | help="Max new tokens" 126 | ) 127 | parser.add_argument( 128 | "--frac", 129 | type=float, 130 | help="Number of sampels to take from 0.0 to 1.0 percent" 131 | ) 132 | parser.add_argument( 133 | "--tp", 134 | type=int, 135 | help="tp size" 136 | ) 137 | parser.add_argument( 138 | "--pp", 139 | type=int, 140 | help="tp size" 141 | ) 142 | parser.add_argument( 143 | "--dp", 144 | type=int, 145 | help="dp size" 146 | ) 147 | return parser.parse_args() 148 | 149 | 150 | def _tokenize_dataset(example: dict, tokenizer: transformers.AutoTokenizer) -> dict[str, torch.LongTensor]: 151 | if example["messages"][-1]["role"] != "assistant": 152 | raise ValueError("Last message must be from an assistant") 153 | 154 | messages = example["messages"][:-1] # keep all except last assitant reply 155 | 156 | result = tokenizer.apply_chat_template( 157 | messages, 158 | tokenize=True, 159 | add_generation_prompt=True, 160 | return_dict=True 161 | ) 162 | 163 | return { 164 | "input_ids": result["input_ids"], 165 | "messages": messages 166 | } 167 | 168 | 169 | if __name__ == "__main__": 170 | _prepare_dataset() 171 | -------------------------------------------------------------------------------- /benchmark/benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sglang 3 | import pathlib 4 | 
import argparse 5 | import datasets 6 | 7 | 8 | def benchmark() -> None: 9 | # Parse arguments 10 | arguments = _parse_arguments() 11 | output_report_path = arguments.output 12 | prompts_path = arguments.prompts 13 | draft_tokens = arguments.draft 14 | top_k = arguments.k 15 | draft_steps = arguments.steps 16 | eagle_path = arguments.eagle 17 | model_path = arguments.model 18 | speculative_algorithm = arguments.speculative_algorithm 19 | n = arguments.n 20 | batch_size = arguments.bs 21 | temperature = arguments.temperature 22 | frac = arguments.frac 23 | if n is not None and frac is not None: 24 | raise ValueError("One of --n or --frac must be set") 25 | 26 | # Load dataset 27 | print("Loading raw dataset") 28 | dataset = ( 29 | datasets 30 | .load_dataset("json", data_files={"train": [str(prompts_path)]}) 31 | ["train"] 32 | .shuffle(seed=0) 33 | ) 34 | print(f"Dataset has {len(dataset)} rows") 35 | indices = range(n) if n is not None else range(int(frac * len(dataset))) 36 | dataset = dataset.select(indices) 37 | print(f"Dataset after select has {len(dataset)} rows") 38 | 39 | print("Creating sglang engine") 40 | 41 | # Prepare llm engine 42 | if speculative_algorithm is None: 43 | llm = sglang.Engine( 44 | model_path=str(model_path), 45 | max_running_requests=batch_size 46 | ) 47 | else: 48 | llm = sglang.Engine( 49 | model_path=str(model_path), 50 | speculative_algorithm=speculative_algorithm, 51 | speculative_draft_model_path=str(eagle_path), 52 | speculative_num_steps=draft_steps, 53 | speculative_eagle_topk=top_k, 54 | speculative_num_draft_tokens=draft_tokens, 55 | max_running_requests=batch_size, 56 | ) 57 | 58 | # Send requests to llm engine 59 | sampling_params = { 60 | "temperature": temperature, 61 | } 62 | 63 | total_verify_ct = 0 64 | total_latency = 0.0 65 | total_output_tokens = 0 66 | 67 | prompts = [example["prompt"] for example in dataset] 68 | outputs: list[dict] = llm.generate(prompts, sampling_params) 69 | llm.shutdown() 70 | 71 | # Collect metrics from llm engine 72 | for output in outputs: 73 | total_latency = max(total_latency, output["meta_info"]["e2e_latency"]) 74 | total_output_tokens += output["meta_info"]["completion_tokens"] 75 | if speculative_algorithm is not None: 76 | total_verify_ct += output["meta_info"]["spec_verify_ct"] 77 | 78 | total_output_throughput = total_output_tokens / total_latency 79 | if speculative_algorithm is not None and total_verify_ct != 0: 80 | accept_length = total_output_tokens / total_verify_ct 81 | else: 82 | accept_length = None 83 | 84 | # Create report and save to disk 85 | report_dict = { 86 | "output_throughput": total_output_throughput, 87 | "total_output_tokens": total_output_tokens, 88 | "total_latency": total_latency 89 | } 90 | 91 | if speculative_algorithm is not None: 92 | report_dict["acceptance_length"] = accept_length 93 | 94 | print(report_dict) 95 | 96 | with output_report_path.open("w") as report_file: 97 | json.dump(report_dict, report_file, indent=4) 98 | 99 | 100 | def _parse_arguments() -> argparse.Namespace: 101 | parser = argparse.ArgumentParser(description="Benchmark") 102 | parser.add_argument( 103 | "--model", 104 | type=pathlib.Path, 105 | required=True, 106 | help="Path to jsonlines file where the processed dataset will be stored" 107 | ) 108 | parser.add_argument( 109 | "--prompts", 110 | type=pathlib.Path, 111 | required=True, 112 | help="Path to jsonlines file with prompts" 113 | ) 114 | parser.add_argument( 115 | "--eagle", 116 | type=pathlib.Path, 117 | required=False, 118 | help="Path to 
jsonlines file where the processed dataset will be stored" 119 | ) 120 | parser.add_argument( 121 | "--speculative-algorithm", 122 | type=str, 123 | required=False, 124 | help="For example EAGLE or EAGLE 3" 125 | ) 126 | parser.add_argument( 127 | "--bs", 128 | type=int, 129 | required=True, 130 | help="For example EAGLE or EAGLE 3" 131 | ) 132 | parser.add_argument( 133 | "--steps", 134 | type=int, 135 | help="Number of sampels to take" 136 | ) 137 | parser.add_argument( 138 | "--k", 139 | type=int, 140 | help="Number of top k" 141 | ) 142 | parser.add_argument( 143 | "--draft", 144 | type=int, 145 | help="Number of draft tokens to generate" 146 | ) 147 | parser.add_argument( 148 | "--n", 149 | type=int, 150 | help="Number of sampels to take" 151 | ) 152 | parser.add_argument( 153 | "--frac", 154 | type=float, 155 | help="Number of sampels to take from 0.0 to 1.0 percent" 156 | ) 157 | parser.add_argument( 158 | "--output", 159 | type=pathlib.Path, 160 | required=True, 161 | help="Path to report" 162 | ) 163 | parser.add_argument( 164 | "--temperature", 165 | type=float, 166 | required=True, 167 | help="Temperature" 168 | ) 169 | return parser.parse_args() 170 | 171 | 172 | if __name__ == "__main__": 173 | benchmark() 174 | -------------------------------------------------------------------------------- /resources/raw_example_dataset.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 0, "messages": [{"role": "user", "content": "Give me 100 prompt parameters that I can specify that will influence your output, e.g. voice, tone, register, style, audience etc."}, {"role": "assistant", "content": "Sure, here are 100 prompt parameters that you can specify to influence my output:\n\n1. Voice (e.g., male or female)\n2. Tone (e.g., serious, sarcastic, humorous, etc.)\n3. Register (e.g., formal, informal, academic, conversational, etc.)\n4. Style (e.g., narrative, descriptive, expository, persuasive, etc.)\n5. Audience (e.g., children, adults, experts, non-experts, etc.)\n6. Genre (e.g., fiction, non-fiction, poetry, drama, etc.)\n7. Language (e.g., English, French, Spanish, etc.)\n8. Dialect (e.g., American English, British English, etc.)\n9. Purpose (e.g., informative, instructive, entertaining, etc.)\n10. Mood (e.g., happy, sad, angry, etc.)\n11. Point of view (e.g., first person, third person, etc.)\n12. Setting (e.g., time period, location, etc.)\n13. Characterization (e.g., flat, round, static, dynamic, etc.)\n14. Conflict (e.g., internal, external, etc.)\n15. Plot (e.g., exposition, rising action, climax, etc.)\n16. Theme (e.g., love, friendship, betrayal, etc.)\n17. Symbolism (e.g., colors, objects, etc.)\n18. Imagery (e.g., visual, auditory, olfactory, etc.)\n19. Figurative language (e.g., simile, metaphor, hyperbole, etc.)\n20. Allusion (e.g., historical, cultural, literary, etc.)\n21. Irony (e.g., verbal, situational, dramatic, etc.)\n22. Foreshadowing (e.g., subtle, explicit, etc.)\n23. Flashback (e.g., chronological, non-chronological, etc.)\n24. Dialogue (e.g., realistic, stylized, etc.)\n25. Monologue (e.g., internal, external, etc.)\n26. Soliloquy (e.g., dramatic, introspective, etc.)\n27. Sound effects (e.g., ambient, diegetic, non-diegetic, etc.)\n28. Music (e.g., genre, tempo, etc.)\n29. Visual effects (e.g., lighting, color, etc.)\n30. Camera angles (e.g., low, high, close-up, etc.)\n31. Shot composition (e.g., rule of thirds, symmetry, etc.)\n32. Editing (e.g., cuts, fades, transitions, etc.)\n33. 
Pacing (e.g., fast, slow, steady, etc.)\n34. Structure (e.g., chronological, non-chronological, etc.)\n35. Tone of voice (e.g., serious, angry, sad, etc.)\n36. Accent (e.g., regional, international, etc.)\n37. Syntax (e.g., simple, complex, compound, etc.)\n38. Vocabulary (e.g., formal, informal, technical, etc.)\n39. Grammar (e.g., correct, incorrect, etc.)\n40. Spelling (e.g., American, British, etc.)\n41. Pronunciation (e.g., clear, mumbled, etc.)\n42. Enunciation (e.g., precise, slurred, etc.)\n43. Volume (e.g., loud, soft, etc.)\n44. Inflection (e.g., rising, falling, etc.)\n45. Pitch (e.g., high, low, etc.)\n46. Speed (e.g., fast, slow, etc.)\n47. Diction (e.g., clear, mumbled, etc.)\n48. Rhythm (e.g., staccato, legato, etc.)\n49. Timing ("}, {"role": "user", "content": "Continue"}, {"role": "assistant", "content": "49. Timing (e.g., pacing, pauses, etc.)\n50. Emphasis (e.g., stress, intonation, etc.)\n51. Emotion (e.g., happy, sad, angry, etc.)\n52. Attitude (e.g., positive, negative, neutral, etc.)\n53. Knowledge level (e.g., basic, intermediate, advanced, etc.)\n54. Educational background (e.g., high school, college, graduate school, etc.)\n55. Professional background (e.g., medicine, law, engineering, etc.)\n56. Cultural background (e.g., Western, Eastern, etc.)\n57. Geographic location (e.g., North America, Asia, Europe, etc.)\n58. Age (e.g., young adult, middle-aged, senior, etc.)\n59. Gender (e.g., male, female, non-binary, etc.)\n60. Religion (e.g., Christianity, Islam, Judaism, etc.)\n61. Political affiliation (e.g., conservative, liberal, moderate, etc.)\n62. Social status (e.g., wealthy, middle-class, low-income, etc.)\n63. Interpersonal relationships (e.g., parent-child, romantic, etc.)\n64. Career goals (e.g., entrepreneurship, academia, etc.)\n65. Life goals (e.g., travel, family, etc.)\n66. Hobbies\/interests (e.g., sports, music, art, etc.)\n67. Personality traits (e.g., extroverted, introverted, etc.)\n68. Learning style (e.g., visual, auditory, kinesthetic, etc.)\n69. Communication style (e.g., direct, indirect, assertive, etc.)\n70. Sensory preferences (e.g., visual, auditory, tactile, etc.)\n71. Learning environment (e.g., classroom, online, etc.)\n72. Communication medium (e.g., email, phone, in-person, etc.)\n73. Learning objectives (e.g., mastery, exploration, etc.)\n74. Learning outcome (e.g., knowledge acquisition, skill development, etc.)\n75. Learning pace (e.g., accelerated, self-paced, etc.)\n76. Learning format (e.g., lecture, discussion, hands-on, etc.)\n77. Learning method (e.g., traditional, experiential, etc.)\n78. Learning tools (e.g., textbooks, videos, software, etc.)\n79. Learning resources (e.g., library, online databases, etc.)\n80. Learning environment (e.g., quiet, collaborative, etc.)\n81. Learning motivation (e.g., intrinsic, extrinsic, etc.)\n82. Learning obstacles (e.g., lack of resources, time constraints, etc.)\n83. Learning preferences (e.g., self-directed, group-oriented, etc.)\n84. Learning feedback (e.g., formative, summative, etc.)\n85. Learning assessment (e.g., quizzes, exams, etc.)\n86. Learning evaluation (e.g., self-evaluation, peer evaluation, etc.)\n87. Learning reflection (e.g., journaling, discussion, etc.)\n88. Learning feedback (e.g., written, oral, etc.)\n89. Learning environment (e.g., physical, virtual, etc.)\n90. Learning support (e.g., tutoring, counseling, etc.)\n91. Learning goals (e.g., cognitive, affective, etc.)\n92. Learning objectives (e.g., knowledge, skills, attitudes, etc.)\n93. 
Learning outcomes (e.g., competence, proficiency, etc.)\n94. Learning activities (e.g., problem-solving, critical thinking, etc.)\n95. Learning strategies (e.g., memorization, repetition, etc.)\n96. Learning styles (e.g., visual, auditory, kinesthetic, etc.)\n97. Learning theories (e.g., behaviorism, constructivism,"},{"role": "user", "content": "Continue"}, {"role":"assistant", "content": "97. Learning theories (e.g., behaviorism, constructivism, etc.)\n98. Learning models (e.g., ADDIE, SAM, etc.)\n99. Learning technology (e.g., LMS, authoring tools, etc.)\n100. Learning trends (e.g., gamification, microlearning, etc.)\n\nI hope these prompt parameters help you specify the output you're looking for. Let me know if you have any other questions or if there's anything else I can assist you with."}]} 2 | -------------------------------------------------------------------------------- /pages/reproduce_eagle_1_paper.md: -------------------------------------------------------------------------------- 1 | ## Reproduce eagle 1 paper 2 | 3 | Page will cover all steps neede to reproduce eagle 1 paper for model [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) 4 | 5 | ## Docker 6 | 7 | In order to use tool you need to build docker 8 | 9 | ```bash 10 | docker build --tag vladislavkruglikov/eagle:0.3.1 -f docker/Dockerfile . 11 | ``` 12 | 13 | Or use already build one 14 | 15 | ```bash 16 | docker pull vladislavkruglikov/eagle:latest 17 | ``` 18 | 19 | ## Model 20 | 21 | Downloading model might require restricted access so you might need to export your token before 22 | 23 | ```bash 24 | export HF_TOKEN= 25 | huggingface-cli download meta-llama/Llama-2-7b-chat-hf \ 26 | --local-dir ./resources/meta-llama2-7b-chat-hf 27 | ``` 28 | 29 | Make sure chate template guards assistant reply with {% generation %} and {% endgeneration %} as shown in [this example](../resources/example_chat_template_with_generation_keyword.json). This is needed for huggingface to generate loss mask correctly such that loss attends only on assistant replies 30 | 31 | ## Prepare dataset 32 | 33 | In order to train models with our frameworks you need to supply dataset in particular format. It is jsonlines where each lines looks like this. Basically you have a list of messages and each message has 2 keys which are role and content 34 | 35 | ```json 36 | {"id": 0, "messages": [{"role": "user", "content": "Give me 100 prompt parameters that I can specify that will influence your output, e.g. voice, tone, register, style, audience etc."}, {"role": "assistant", "content": "Sure, here are 100 prompt parameters that you can specify to influence my output:\n\n1. Voice (e.g., male or female)\n2. Tone (e.g., serious, sarcastic, humorous, etc.)"}, {"role": "user", "content": "Continue"}, {"role": "assistant", "content": "3. Timing (e.g., pacing, pauses, etc.)\n4. Emphasis (e.g., stress, intonation, etc.)"}]} 37 | ``` 38 | 39 | In order to reproduce eagle 1 paper we will be using pre built script that downloads sharegpt dataset and formats as reference proposes 40 | 41 | ```bash 42 | docker run \ 43 | -v ./resources:/mnt/resources \ 44 | eagle \ 45 | python3 eagle/prepare_sharegpt_dataset.py \ 46 | --frac 1.0 \ 47 | --output /mnt/resources/sharegpt.jsonl 48 | ``` 49 | 50 | ## Train draft model 51 | 52 | I will demonstrate how to train draft model for particular configuration but you can twick it under your needs. Training script will log data into clearml dashboards. 
If you want to disable it (which is recommended for a first trial run of the script) run 53 | 54 | ```bash 55 | export CLEARML_OFFLINE_MODE=1 56 | ``` 57 | 58 | 59 | Otherwise specify these environment variables 60 | 61 | ```bash 62 | export CLEARML_WEB_HOST= 63 | export CLEARML_API_HOST= 64 | export CLEARML_FILES_HOST= 65 | export CLEARML_API_ACCESS_KEY= 66 | export CLEARML_API_SECRET_KEY= 67 | export CLEARML_API_HOST_VERIFY_CERT= 68 | ``` 69 | 70 | It is also up to you to configure 71 | 72 | ```bash 73 | export OMP_NUM_THREADS=4 74 | export CUDA_VISIBLE_DEVICES=7 75 | ``` 76 | 77 | I will demonstrate a very simple setup which will probably be the most popular. For that you need a single GPU with roughly 80 GB of HBM 78 | 79 | ```bash 80 | docker run \ 81 | --gpus all \ 82 | -e OMP_NUM_THREADS=4 \ 83 | -e CUDA_VISIBLE_DEVICES=7 \ 84 | -e CLEARML_OFFLINE_MODE=1 \ 85 | -v ./resources:/mnt/resources \ 86 | eagle \ 87 | torchrun \ 88 | --nnodes=1 \ 89 | --nproc_per_node=1 \ 90 | eagle/train.py \ 91 | --micro-batch-size 8 \ 92 | --gradient-accumulation-steps 2 \ 93 | --num-warmup-steps 1024 \ 94 | --num-training-steps 131072 \ 95 | --epochs 4 \ 96 | --clearml-project eagle \ 97 | --clearml-task 1gpus-8microbs-2accum-16globalbs \ 98 | --verifier-model-path /mnt/resources/meta-llama2-7b-chat-hf \ 99 | --dataset-path /mnt/resources/sharegpt.jsonl \ 100 | --eagle-config-path /mnt/resources/eagle_config.json \ 101 | --learning-rate 2e-4 \ 102 | --maximum-model-length 2048 \ 103 | --noise-low -0.1 \ 104 | --noise-high 0.1 \ 105 | --v-w 1.0 \ 106 | --p-w 0.1 \ 107 | --grad-clip 0.5 \ 108 | --b1 0.9 \ 109 | --b2 0.95 \ 110 | --cpdir /mnt/resources/checkpoints \ 111 | --save 4096 \ 112 | --mixed-precision bf16 \ 113 | --verifier-model-lm-head-dtype bfloat16 \ 114 | --verifier-model-dtype bfloat16 \ 115 | --eagle-dtype bfloat16 \ 116 | --attn flash_attention_2 117 | ``` 118 | 119 | For those who have access to 2 GPUs 120 | 121 | ```bash 122 | docker run \ 123 | --gpus all \ 124 | -e OMP_NUM_THREADS=4 \ 125 | -e CUDA_VISIBLE_DEVICES=5,7 \ 126 | -e CLEARML_OFFLINE_MODE=1 \ 127 | -v ./resources:/mnt/resources \ 128 | eagle \ 129 | torchrun \ 130 | --nnodes=1 \ 131 | --nproc_per_node=2 \ 132 | eagle/train.py \ 133 | --micro-batch-size 4 \ 134 | --gradient-accumulation-steps 2 \ 135 | --num-warmup-steps 2048 \ 136 | --num-training-steps 262144 \ 137 | --epochs 4 \ 138 | --clearml-project eagle \ 139 | --clearml-task 2gpus-4microbs-2accum-16globalbs \ 140 | --verifier-model-path /mnt/resources/meta-llama2-7b-chat-hf \ 141 | --dataset-path /mnt/resources/sharegpt.jsonl \ 142 | --eagle-config-path /mnt/resources/eagle_config.json \ 143 | --learning-rate 2e-4 \ 144 | --maximum-model-length 2048 \ 145 | --noise-low -0.1 \ 146 | --noise-high 0.1 \ 147 | --v-w 1.0 \ 148 | --p-w 0.1 \ 149 | --grad-clip 0.5 \ 150 | --b1 0.9 \ 151 | --b2 0.95 \ 152 | --cpdir /mnt/resources/checkpoints \ 153 | --save 4096 \ 154 | --mixed-precision bf16 \ 155 | --verifier-model-lm-head-dtype bfloat16 \ 156 | --verifier-model-dtype bfloat16 \ 157 | --eagle-dtype bfloat16 \ 158 | --attn flash_attention_2 159 | ``` 160 | 161 | For those who have access to 4 GPUs 162 | 163 | ```bash 164 | docker run \ 165 | --gpus all \ 166 | -e OMP_NUM_THREADS=4 \ 167 | -e CUDA_VISIBLE_DEVICES=1,3,5,7 \ 168 | -e CLEARML_OFFLINE_MODE=1 \ 169 | -v ./resources:/mnt/resources \ 170 | eagle \ 171 | torchrun \ 172 | --nnodes=1 \ 173 | --nproc_per_node=4 \ 174 | eagle/train.py \ 175 | --micro-batch-size 2 \ 176 | --gradient-accumulation-steps 2 \ 177 | --num-warmup-steps 4096 \ 
178 | --num-training-steps 524288 \ 179 | --epochs 4 \ 180 | --clearml-project eagle \ 181 | --clearml-task 4gpus-2microbs-2accum-16globalbs \ 182 | --verifier-model-path /mnt/resources/meta-llama2-7b-chat-hf \ 183 | --dataset-path /mnt/resources/sharegpt.jsonl \ 184 | --eagle-config-path /mnt/resources/eagle_config.json \ 185 | --learning-rate 2e-4 \ 186 | --maximum-model-length 2048 \ 187 | --noise-low -0.1 \ 188 | --noise-high 0.1 \ 189 | --v-w 1.0 \ 190 | --p-w 0.1 \ 191 | --grad-clip 0.5 \ 192 | --b1 0.9 \ 193 | --b2 0.95 \ 194 | --cpdir /mnt/resources/checkpoints \ 195 | --save 4096 \ 196 | --mixed-precision bf16 \ 197 | --verifier-model-lm-head-dtype bfloat16 \ 198 | --verifier-model-dtype bfloat16 \ 199 | --eagle-dtype bfloat16 \ 200 | --attn flash_attention_2 201 | ``` 202 | -------------------------------------------------------------------------------- /eagle/train_tp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import math 5 | import torch 6 | import pathlib 7 | import clearml 8 | import logging 9 | import datasets 10 | import argparse 11 | import safetensors 12 | import transformers 13 | 14 | from eagle.llama2 import Llama2Model 15 | 16 | 17 | def coach() -> None: 18 | arguments = _parse_arguments() 19 | 20 | torch.backends.cuda.matmul.allow_tf32 = True 21 | logging.basicConfig(level=logging.INFO, format='%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 22 | 23 | logging.info("Start to prepare clearml ") 24 | clearml_task = clearml.Task.init(project_name=arguments.clearml_project, task_name=arguments.clearml_task, reuse_last_task_id=False, continue_last_task=False, output_uri=False, auto_connect_frameworks=False, auto_resource_monitoring=False) 25 | clearml_logger = clearml_task.get_logger() 26 | 27 | logging.info("Start to prepare language model head ") 28 | lm_head = _initialize_verifier_lm_head(verifier_path=arguments.verifier_model_path).to(getattr(torch, arguments.verifier_model_lm_head_dtype)).to("cuda") 29 | logging.info("Language model head has dtype %s", next(lm_head.parameters()).dtype) 30 | logging.info("Language model head has %f billion parameters", _count_parameters(model=lm_head) / 10 ** 9) 31 | clearml_logger.report_single_value(name="Language model head parameters billion", value=_count_parameters(model=lm_head) / 10 ** 9) 32 | 33 | logging.info("Start to prepare target model ") 34 | verifier_model = transformers.AutoModelForCausalLM.from_pretrained(arguments.verifier_model_path, device_map="auto", torch_dtype=getattr(torch, arguments.verifier_model_dtype), attn_implementation=arguments.attn) 35 | verifier_model = verifier_model.eval() 36 | logging.info("Target model head has dtype %s", next(verifier_model.parameters()).dtype) 37 | logging.info("Target model head has %f billion parameters", _count_parameters(model=verifier_model) / 10 ** 9) 38 | clearml_logger.report_single_value(name="Target model head parameters billion", value=_count_parameters(model=verifier_model) / 10 ** 9) 39 | 40 | logging.info("Start to prepare draft model ") 41 | config = transformers.AutoConfig.from_pretrained(arguments.eagle_config_path) 42 | model = Llama2Model(config, load_emb=True, path=arguments.verifier_model_path).to(getattr(torch, arguments.eagle_dtype)).to("cuda") 43 | logging.info("Draft model head has dtype %s", next(model.parameters()).dtype) 44 | logging.info("Draft model head has %f billion parameters", _count_parameters(model=model) / 
10 ** 9) 45 | model.train() 46 | clearml_logger.report_single_value(name="Draft model head parameters billion", value=_count_parameters(model=model) / 10 ** 9) 47 | 48 | logging.info("Start to prepare data ") 49 | dataset = datasets.load_dataset("json", data_files={"train": [arguments.dataset_path]})["train"] 50 | dataset = Dataset(dataset=dataset) 51 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=arguments.micro_batch_size, collate_fn=Collator(arguments.verifier_model_path)) 52 | logging.info("Dataset contains %d samples", len(dataset)) 53 | 54 | logging.info("Start to prepare miscellaneous ") 55 | criterion = torch.nn.SmoothL1Loss(reduction="none") 56 | model_optimizer = torch.optim.AdamW(model.parameters(), lr=arguments.learning_rate, betas=(arguments.b1, arguments.b2)) 57 | 58 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer=model_optimizer, num_warmup_steps=arguments.num_warmup_steps, num_training_steps=arguments.num_training_steps) 59 | 60 | logging.info("Start training ") 61 | total_steps_passed = 0 62 | for epoch in range(arguments.epochs): 63 | training_iterator = iter(dataloader) 64 | num_samples_in_epoch = len(dataloader) 65 | remainder = num_samples_in_epoch % arguments.gradient_accumulation_steps 66 | remainder = remainder if remainder != 0 else arguments.gradient_accumulation_steps 67 | total_gradient_updates = math.ceil(num_samples_in_epoch / arguments.gradient_accumulation_steps) 68 | for update_step in range(total_gradient_updates): 69 | step_start = time.perf_counter() 70 | accum_loss = 0.0 71 | batch_samples = [] 72 | num_batches_in_step = arguments.gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder 73 | for _ in range(num_batches_in_step): 74 | batch_samples += [next(training_iterator)] 75 | num_items_in_batch = sum([batch["loss_mask"][:, :arguments.maximum_model_length].sum() for batch in batch_samples]) 76 | 77 | if num_items_in_batch == 0: 78 | logging.warning("num_items_in_batch is zero, skipping steps") 79 | continue 80 | 81 | step_correctly_predicted_tokens_count = 0 82 | 83 | for i, batch in enumerate(batch_samples): 84 | batch = _make_eagle_input(batch, verifier_model, arguments.maximum_model_length, arguments.noise_low, arguments.noise_high, "cuda") 85 | batch["hidden_states"] = batch["hidden_states"] 86 | batch["target"] = batch["target"] 87 | predict = model(batch["hidden_states"].to( getattr(torch, arguments.eagle_dtype) ), input_ids=batch["input_ids"]) 88 | with torch.no_grad(): 89 | target_head = lm_head(batch["target"].to( getattr(torch, arguments.verifier_model_lm_head_dtype) ),) 90 | target_p = torch.nn.Softmax(dim=2)(target_head) 91 | target_p = target_p.detach() 92 | out_head = lm_head(predict.to(getattr(torch, arguments.verifier_model_lm_head_dtype))) 93 | out_logp = torch.nn.LogSoftmax(dim=2)(out_head) 94 | 95 | loss_mask = batch["loss_mask"][:, :, None] 96 | 97 | _, target_max_p_tokens = torch.max(target_p, 2) 98 | _, ealge_max_p_tokens = torch.max(out_logp, 2) 99 | step_correctly_predicted_tokens_count += ((target_max_p_tokens == ealge_max_p_tokens) * loss_mask.squeeze()).sum().item() 100 | 101 | plogp = target_p * out_logp 102 | ploss = -torch.sum(torch.sum(loss_mask * plogp, 2)) 103 | vloss = criterion(predict, batch["target"]) 104 | vloss = torch.sum(torch.mean(loss_mask * vloss, 2)) 105 | loss = arguments.v_w * vloss + arguments.p_w * ploss 106 | loss = loss / num_items_in_batch 107 | accum_loss += loss.item() 108 | loss.backward() 109 | 
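# Descriptive comment (added): each micro-batch loss above is divided by num_items_in_batch,
# the number of unmasked assistant tokens in the whole gradient-accumulation window, so the
# gradients summed by the repeated loss.backward() calls amount to one token-averaged loss.
# The clipping, optimizer step and zero_grad below therefore run once per accumulation
# window, not once per micro-batch.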
torch.nn.utils.clip_grad_norm_(model.parameters(), arguments.grad_clip) 110 | 111 | model_optimizer.step() 112 | scheduler.step() 113 | model_optimizer.zero_grad() 114 | 115 | total_steps_passed += 1 116 | 117 | step_end = time.perf_counter() 118 | mean_step_duration_across_gpus = step_end - step_start 119 | 120 | time_taken = step_end - step_start 121 | total_throughput = num_items_in_batch / time_taken 122 | 123 | loss_tensor = accum_loss 124 | 125 | accuracy = float("nan") 126 | if num_items_in_batch != 0: 127 | accuracy = step_correctly_predicted_tokens_count / num_items_in_batch 128 | 129 | current_lr = arguments.learning_rate 130 | if arguments.num_warmup_steps is not None: 131 | current_lr = scheduler.get_last_lr()[0] 132 | 133 | logging.info("epoch %d/%d, step %d/%d, mean step duration across gpus %.4f seconds, lr %.8f, loss %.4f, throughput %d tps, accuracy %.4f", epoch + 1, arguments.epochs, total_steps_passed, arguments.num_training_steps, mean_step_duration_across_gpus, current_lr, loss_tensor, total_throughput, accuracy) 134 | clearml_logger.report_scalar(title="train/steploss", series="series", value=loss_tensor, iteration=total_steps_passed) 135 | clearml_logger.report_scalar(title="train/throughput tokens/s", series="series", value=total_throughput, iteration=total_steps_passed) 136 | clearml_logger.report_scalar(title="train/stepaccuracy", series="series", value=accuracy, iteration=total_steps_passed) 137 | clearml_logger.report_scalar(title="train/epoch", series="series", value=epoch, iteration=total_steps_passed) 138 | clearml_logger.report_scalar(title="train/lr", series="series", value=current_lr, iteration=total_steps_passed) 139 | 140 | if total_steps_passed % arguments.save == 0: 141 | pathlib.Path(f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}").mkdir(parents=True) 142 | safetensors.torch.save_file(model.state_dict(), f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}/model.safetensors") 143 | with open(arguments.eagle_config_path, "r") as f: 144 | config_data = json.load(f) 145 | config_data["architectures"] = ["LlamaForCausalLMEagle"] 146 | with open(f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}/config.json", "w") as file: 147 | json.dump(config_data, file, ensure_ascii=False, indent=4) 148 | 149 | if total_steps_passed == arguments.num_training_steps: 150 | break 151 | 152 | if total_steps_passed == arguments.num_training_steps: 153 | break 154 | 155 | 156 | def _parse_arguments() -> argparse.Namespace: 157 | parser = argparse.ArgumentParser(description="Coach that trains eagle draft model") 158 | parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size") 159 | parser.add_argument("--gradient-accumulation-steps", type=int, required=True, help="Gradient accumulation steps") 160 | parser.add_argument("--num-warmup-steps", type=int, required=True, help="Num warmup steps") 161 | parser.add_argument("--num-training-steps", type=int, required=True, help="Num training steps") 162 | parser.add_argument("--clearml-project", type=str, required=True, help="Clearml project") 163 | parser.add_argument("--clearml-task", type=str, required=True, help="Clearml task") 164 | parser.add_argument("--epochs", type=int, required=True, help="Epochs") 165 | parser.add_argument("--verifier-model-path", type=str, required=True, help="verifier_model_path") 166 | parser.add_argument("--dataset-path", type=str, required=True, help="verifier_model_path") 167 | parser.add_argument("--eagle-config-path", type=str, 
required=True, help="eagle_config_path") 168 | parser.add_argument("--learning-rate", type=float, required=True, help="eagle_config_path") 169 | parser.add_argument("--maximum-model-length", type=int, required=True, help="eagle_config_path") 170 | parser.add_argument("--noise-low", type=float, required=True, help="eagle_config_path") 171 | parser.add_argument("--noise-high", type=float, required=True, help="eagle_config_path") 172 | parser.add_argument("--v-w", type=float, required=True, help="eagle_config_path") 173 | parser.add_argument("--p-w", type=float, required=True, help="eagle_config_path") 174 | parser.add_argument("--grad-clip", type=float, required=True, help="eagle_config_path") 175 | parser.add_argument("--b1", type=float, required=True, help="eagle_config_path") 176 | parser.add_argument("--b2", type=float, required=True, help="eagle_config_path") 177 | parser.add_argument("--cpdir", type=pathlib.Path, default="./checkpoints", help="Path to folder to save checkpoints") 178 | parser.add_argument("--save", type=int, required=False, help="Save model after every number of steps") 179 | parser.add_argument("--mixed-precision", type=str, required=False, help="Save model after every number of steps") 180 | parser.add_argument("--verifier-model-lm-head-dtype", type=str, required=False, help="Save model after every number of steps") 181 | parser.add_argument("--verifier-model-dtype", type=str, required=False, help="Save model after every number of steps") 182 | parser.add_argument("--eagle-dtype", type=str, required=False, help="Save model after every number of steps") 183 | parser.add_argument("--attn", type=str, required=False, help="Save model after every number of steps") 184 | return parser.parse_args() 185 | 186 | 187 | def _count_parameters(model: torch.nn.Module) -> int: 188 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 189 | 190 | 191 | def _initialize_verifier_lm_head(verifier_path: pathlib.Path) -> torch.nn.Linear: 192 | with open(f"{verifier_path}/config.json", "r") as file: 193 | config = json.load(file) 194 | head = torch.nn.Linear(config["hidden_size"], config["vocab_size"], bias=False) 195 | with open(os.path.join(verifier_path, "model.safetensors.index.json"), "r") as f: 196 | index_json = json.loads(f.read()) 197 | head_path = index_json["weight_map"]["lm_head.weight"] 198 | with safetensors.safe_open(os.path.join(verifier_path, head_path), framework="pt") as f: 199 | tensor = f.get_slice("lm_head.weight")[:, :config["hidden_size"]] 200 | head.weight.data = tensor 201 | head.eval() 202 | for param in head.parameters(): 203 | param.requires_grad = False 204 | return head 205 | 206 | 207 | class Dataset(torch.utils.data.Dataset): 208 | def __init__(self, dataset: datasets.Dataset) -> None: 209 | self._dataset = dataset 210 | 211 | def __len__(self) -> int: 212 | return len(self._dataset) 213 | 214 | def __getitem__(self, index: int) -> dict: 215 | return self._dataset[index] 216 | 217 | 218 | class Collator: 219 | def __init__(self, model_path) -> None: 220 | self._tokenizer = transformers.AutoTokenizer.from_pretrained(str(model_path), use_fast=True) 221 | self._tokenizer.pad_token = "[PAD]" 222 | self._tokenizer.pad_token_id = 0 223 | 224 | def __call__(self, features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 225 | result = self._tokenizer.apply_chat_template([m["messages"] for m in features], tokenize=True, add_generation_prompt=False, return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt", padding=True) 
226 | return { 227 | "input_ids": result["input_ids"], 228 | "loss_mask": result["assistant_masks"] 229 | } 230 | 231 | 232 | def _make_eagle_input(batch, verifier_model, max_model_len, transform_uniform_low, transformer_uniform_high, device): 233 | input_ids = batch["input_ids"].to(device)[:, :max_model_len] 234 | loss_mask = batch["loss_mask"].to(device)[:, :max_model_len] 235 | 236 | with torch.no_grad(): 237 | outs_big = verifier_model(input_ids, output_hidden_states=True, use_cache=False) 238 | hidden_state_big = outs_big.hidden_states[-1] 239 | hidden_state_big = _apply_noise_to_hidden_state(hidden_state_big, transform_uniform_low, transformer_uniform_high) 240 | T, L, D = hidden_state_big.shape 241 | target = hidden_state_big.new_zeros((T, L, D)) 242 | target[:, :-1, :] = hidden_state_big[:, 1:, :] 243 | input_ids = torch.cat((input_ids[:, 1:], torch.zeros(input_ids.size(0), 1, dtype=input_ids.dtype, device=input_ids.device)), dim=1) 244 | batch = {"input_ids": input_ids, "hidden_states": hidden_state_big, "target": target, "loss_mask": loss_mask} 245 | return batch 246 | 247 | 248 | def _apply_noise_to_hidden_state(hidden_state: torch.FloatTensor, transform_uniform_low, transformer_uniform_high) -> None: 249 | noise = torch.rand_like(hidden_state) * (transformer_uniform_high - transform_uniform_low) + transform_uniform_low 250 | noisy_tensor = hidden_state + noise 251 | return noisy_tensor 252 | 253 | 254 | if __name__ == "__main__": 255 | coach() 256 | -------------------------------------------------------------------------------- /eagle/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import math 5 | import torch 6 | import pathlib 7 | import clearml 8 | import logging 9 | import datasets 10 | import argparse 11 | import accelerate 12 | import contextlib 13 | import safetensors 14 | import transformers 15 | 16 | from eagle.llama2 import Llama2Model 17 | 18 | 19 | def coach() -> None: 20 | arguments = _parse_arguments() 21 | 22 | accelerator = accelerate.Accelerator(log_with="all", gradient_accumulation_steps=arguments.gradient_accumulation_steps, mixed_precision=arguments.mixed_precision) 23 | accelerate.utils.set_seed(seed=0) 24 | torch.backends.cuda.matmul.allow_tf32 = True 25 | logging.basicConfig(level=logging.INFO, format='%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 26 | logger = accelerate.logging.get_logger(name=__name__, log_level="INFO") 27 | 28 | if accelerator.is_main_process: 29 | logger.info("Start to prepare clearml ", main_process_only=True) 30 | clearml_task = clearml.Task.init(project_name=arguments.clearml_project, task_name=arguments.clearml_task, reuse_last_task_id=False, continue_last_task=False, output_uri=False, auto_connect_frameworks=False, auto_resource_monitoring=False) 31 | clearml_logger = clearml_task.get_logger() 32 | 33 | logger.info("Start to prepare language model head ", main_process_only=True) 34 | lm_head = _initialize_verifier_lm_head(verifier_path=arguments.verifier_model_path).to(getattr(torch, arguments.verifier_model_lm_head_dtype)).to(accelerator.device) 35 | logger.info("Language model head has dtype %s", next(lm_head.parameters()).dtype, main_process_only=True) 36 | logger.info("Language model head has %f billion parameters", _count_parameters(model=lm_head) / 10 ** 9, main_process_only=True) 37 | if accelerator.is_main_process: 38 | clearml_logger.report_single_value(name="Language model head 
38 | clearml_logger.report_single_value(name="Language model head parameters billion", value=_count_parameters(model=lm_head) / 10 ** 9) 39 | 40 | logger.info("Start to prepare target model", main_process_only=True) 41 | verifier_model = transformers.AutoModelForCausalLM.from_pretrained(arguments.verifier_model_path, device_map=accelerator.device, torch_dtype=getattr(torch, arguments.verifier_model_dtype), attn_implementation=arguments.attn) 42 | verifier_model = verifier_model.eval() 43 | logger.info("Target model has dtype %s", next(verifier_model.parameters()).dtype, main_process_only=True) 44 | logger.info("Target model has %f billion parameters", _count_parameters(model=verifier_model) / 10 ** 9, main_process_only=True) 45 | if accelerator.is_main_process: 46 | clearml_logger.report_single_value(name="Target model parameters billion", value=_count_parameters(model=verifier_model) / 10 ** 9) 47 | 48 | logger.info("Start to prepare draft model", main_process_only=True) 49 | config = transformers.AutoConfig.from_pretrained(arguments.eagle_config_path) 50 | model = Llama2Model(config, load_emb=True, path=arguments.verifier_model_path).to(getattr(torch, arguments.eagle_dtype)).to(accelerator.device) 51 | logger.info("Draft model has dtype %s", next(model.parameters()).dtype, main_process_only=True) 52 | logger.info("Draft model has %f billion parameters", _count_parameters(model=model) / 10 ** 9, main_process_only=True) 53 | model.train() 54 | accelerator.register_for_checkpointing(model) 55 | if accelerator.is_main_process: 56 | clearml_logger.report_single_value(name="Draft model parameters billion", value=_count_parameters(model=model) / 10 ** 9) 57 | 58 | logger.info("Start to prepare data", main_process_only=True) 59 | dataset = datasets.load_dataset("json", data_files={"train": [arguments.dataset_path]})["train"] 60 | dataset = Dataset(dataset=dataset) 61 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=arguments.micro_batch_size, collate_fn=Collator(arguments.verifier_model_path)) 62 | logger.info("Dataset contains %d samples", len(dataset), main_process_only=True) 63 | 64 | logger.info("Start to prepare miscellaneous", main_process_only=True) 65 | criterion = torch.nn.SmoothL1Loss(reduction="none") 66 | model_optimizer = torch.optim.AdamW(model.parameters(), lr=arguments.learning_rate, betas=(arguments.b1, arguments.b2)) 67 | accelerator.register_for_checkpointing(model_optimizer) 68 | 69 | scheduler = transformers.get_linear_schedule_with_warmup(optimizer=model_optimizer, num_warmup_steps=arguments.num_warmup_steps, num_training_steps=arguments.num_training_steps) 70 | accelerator.register_for_checkpointing(scheduler) 71 | 72 | model = accelerator.prepare_model(model) 73 | model_optimizer = accelerator.prepare_optimizer(model_optimizer) 74 | dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True) 75 | scheduler = accelerator.prepare_scheduler(scheduler) 76 | 77 | logger.info("Start training", main_process_only=True) 78 | total_steps_passed = 0 79 | for epoch in range(arguments.epochs): 80 | training_iterator = iter(dataloader) 81 | num_samples_in_epoch = len(dataloader) 82 | remainder = num_samples_in_epoch % arguments.gradient_accumulation_steps 83 | remainder = remainder if remainder != 0 else arguments.gradient_accumulation_steps 84 | total_gradient_updates = math.ceil(num_samples_in_epoch / arguments.gradient_accumulation_steps) 85 | for update_step in range(total_gradient_updates): 86 | step_start = time.perf_counter() 87 | accum_loss = 0.0 88 | batch_samples = []
89 | num_batches_in_step = arguments.gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder 90 | for _ in range(num_batches_in_step): 91 | batch_samples += [next(training_iterator)] 92 | num_items_in_batch = sum([batch["loss_mask"][:, :arguments.maximum_model_length].sum() for batch in batch_samples]) 93 | num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item() 94 | step_correctly_predicted_tokens_count = 0 95 | 96 | for i, batch in enumerate(batch_samples): 97 | if (i < len(batch_samples) - 1 and accelerator.num_processes > 1): 98 | ctx = model.no_sync 99 | else: 100 | ctx = contextlib.nullcontext 101 | with ctx(): 102 | batch = _make_eagle_input(batch, verifier_model, arguments.maximum_model_length, arguments.noise_low, arguments.noise_high, accelerator.device) 103 | # batch now holds "input_ids", "hidden_states", "target" and "loss_mask" prepared by _make_eagle_input 104 | # the draft model and the frozen lm head consume them below in their configured dtypes 105 | predict = model(batch["hidden_states"].to( getattr(torch, arguments.eagle_dtype) ), input_ids=batch["input_ids"]) 106 | with torch.no_grad(): 107 | target_head = lm_head(batch["target"].to( getattr(torch, arguments.verifier_model_lm_head_dtype) ),) 108 | target_p = torch.nn.Softmax(dim=2)(target_head) 109 | target_p = target_p.detach() 110 | out_head = lm_head(predict.to(getattr(torch, arguments.verifier_model_lm_head_dtype))) 111 | out_logp = torch.nn.LogSoftmax(dim=2)(out_head) 112 | 113 | loss_mask = batch["loss_mask"][:, :, None] 114 | 115 | _, target_max_p_tokens = torch.max(target_p, 2) 116 | _, eagle_max_p_tokens = torch.max(out_logp, 2) 117 | step_correctly_predicted_tokens_count += ((target_max_p_tokens == eagle_max_p_tokens) * loss_mask.squeeze()).sum().item() 118 | 119 | plogp = target_p * out_logp 120 | ploss = -torch.sum(torch.sum(loss_mask * plogp, 2)) 121 | vloss = criterion(predict, batch["target"]) 122 | vloss = torch.sum(torch.mean(loss_mask * vloss, 2)) 123 | loss = arguments.v_w * vloss + arguments.p_w * ploss 124 | loss = (loss * arguments.gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch 125 | accum_loss += loss.item() 126 | accelerator.backward(loss) 127 | accelerator.clip_grad_value_(model.parameters(), arguments.grad_clip) 128 | 129 | model_optimizer.step() 130 | scheduler.step() 131 | model_optimizer.zero_grad() 132 | 133 | total_steps_passed += 1 134 | 135 | step_end = time.perf_counter() 136 | step_duration = torch.tensor(step_end - step_start, device=accelerator.device) 137 | mean_step_duration_across_gpus = accelerator.reduce(step_duration, reduction="mean").item() 138 | 139 | time_taken = step_end - step_start 140 | throughput = torch.tensor(num_items_in_batch / time_taken, device=accelerator.device) 141 | total_throughput = accelerator.reduce(throughput, reduction="sum").item() 142 | 143 | loss_tensor = torch.tensor(accum_loss / (arguments.gradient_accumulation_steps * accelerator.num_processes), device=accelerator.device) 144 | loss_tensor = accelerator.reduce(loss_tensor, reduction="sum").item() 145 | 146 | accuracy = float("nan") 147 | if num_items_in_batch != 0: 148 | step_correctly_predicted_tokens_count = accelerator.reduce(torch.tensor(step_correctly_predicted_tokens_count, device=accelerator.device), reduction="sum").item() 149 | accuracy = step_correctly_predicted_tokens_count / num_items_in_batch 150 | 151 | current_lr = arguments.learning_rate 152 | if arguments.num_warmup_steps is not None: 153 | current_lr = scheduler.get_last_lr()[0] 154 |
155 | logger.info("epoch %d/%d, step %d/%d, mean step duration across gpus %.4f seconds, lr %.8f, loss %.4f, throughput %d tps, accuracy %.4f", epoch + 1, arguments.epochs, total_steps_passed, arguments.num_training_steps, mean_step_duration_across_gpus, current_lr, loss_tensor, total_throughput, accuracy, main_process_only=True) 156 | if accelerator.is_main_process: 157 | clearml_logger.report_scalar(title="train/steploss", series="series", value=loss_tensor, iteration=total_steps_passed) 158 | clearml_logger.report_scalar(title="train/throughput tokens/s", series="series", value=total_throughput, iteration=total_steps_passed) 159 | clearml_logger.report_scalar(title="train/stepaccuracy", series="series", value=accuracy, iteration=total_steps_passed) 160 | clearml_logger.report_scalar(title="train/epoch", series="series", value=epoch, iteration=total_steps_passed) 161 | clearml_logger.report_scalar(title="train/lr", series="series", value=current_lr, iteration=total_steps_passed) 162 | 163 | if accelerator.is_local_main_process and total_steps_passed % arguments.save == 0: 164 | accelerator.save_state(output_dir=f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}") 165 | with open(arguments.eagle_config_path, "r") as f: 166 | config_data = json.load(f) 167 | config_data["architectures"] = ["LlamaForCausalLMEagle"] 168 | with open(f"{arguments.cpdir}/epoch_{epoch}_step_{total_steps_passed}/config.json", "w") as file: 169 | json.dump(config_data, file, ensure_ascii=False, indent=4) 170 | 171 | if total_steps_passed == arguments.num_training_steps: 172 | break 173 | 174 | if total_steps_passed == arguments.num_training_steps: 175 | break 176 | 177 | accelerator.end_training() 178 | 179 | 180 | def _parse_arguments() -> argparse.Namespace: 181 | parser = argparse.ArgumentParser(description="Coach that trains eagle draft model") 182 | parser.add_argument("--micro-batch-size", type=int, required=True, help="Micro batch size") 183 | parser.add_argument("--gradient-accumulation-steps", type=int, required=True, help="Gradient accumulation steps") 184 | parser.add_argument("--num-warmup-steps", type=int, required=True, help="Num warmup steps") 185 | parser.add_argument("--num-training-steps", type=int, required=True, help="Num training steps") 186 | parser.add_argument("--clearml-project", type=str, required=True, help="Clearml project") 187 | parser.add_argument("--clearml-task", type=str, required=True, help="Clearml task") 188 | parser.add_argument("--epochs", type=int, required=True, help="Epochs") 189 | parser.add_argument("--verifier-model-path", type=str, required=True, help="Path to the verifier (target) model") 190 | parser.add_argument("--dataset-path", type=str, required=True, help="Path to the jsonl dataset with chat messages") 191 | parser.add_argument("--eagle-config-path", type=str, required=True, help="Path to the eagle draft model config") 192 | parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate") 193 | parser.add_argument("--maximum-model-length", type=int, required=True, help="Maximum sequence length in tokens") 194 | parser.add_argument("--noise-low", type=float, required=True, help="Lower bound of the uniform noise added to verifier hidden states") 195 | parser.add_argument("--noise-high", type=float, required=True, help="Upper bound of the uniform noise added to verifier hidden states") 196 | parser.add_argument("--v-w", type=float, required=True, help="Weight of the hidden state regression loss") 197 | parser.add_argument("--p-w", type=float, required=True, help="Weight of the token probability loss") 198 | parser.add_argument("--grad-clip", type=float, required=True, help="Gradient clipping value") 199 | parser.add_argument("--b1", type=float, required=True, help="AdamW beta1")
200 | parser.add_argument("--b2", type=float, required=True, help="AdamW beta2") 201 | parser.add_argument("--cpdir", type=pathlib.Path, default="./checkpoints", help="Path to folder to save checkpoints") 202 | parser.add_argument("--save", type=int, required=False, help="Save a checkpoint every this many steps") 203 | parser.add_argument("--mixed-precision", type=str, required=False, help="Mixed precision mode passed to accelerate, for example bf16") 204 | parser.add_argument("--verifier-model-lm-head-dtype", type=str, required=False, help="Torch dtype for the verifier language model head, for example bfloat16") 205 | parser.add_argument("--verifier-model-dtype", type=str, required=False, help="Torch dtype for the verifier model") 206 | parser.add_argument("--eagle-dtype", type=str, required=False, help="Torch dtype for the eagle draft model") 207 | parser.add_argument("--attn", type=str, required=False, help="Attention implementation passed to transformers, for example flash_attention_2") 208 | return parser.parse_args() 209 | 210 | 211 | def _count_parameters(model: torch.nn.Module) -> int: 212 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 213 | 214 | 215 | def _initialize_verifier_lm_head(verifier_path: pathlib.Path) -> torch.nn.Linear: 216 | with open(f"{verifier_path}/config.json", "r") as file: 217 | config = json.load(file) 218 | head = torch.nn.Linear(config["hidden_size"], config["vocab_size"], bias=False) 219 | with open(os.path.join(verifier_path, "model.safetensors.index.json"), "r") as f: 220 | index_json = json.loads(f.read()) 221 | head_path = index_json["weight_map"]["lm_head.weight"] 222 | with safetensors.safe_open(os.path.join(verifier_path, head_path), framework="pt") as f: 223 | tensor = f.get_slice("lm_head.weight")[:, :config["hidden_size"]] 224 | head.weight.data = tensor 225 | head.eval() 226 | for param in head.parameters(): 227 | param.requires_grad = False 228 | return head 229 | 230 | 231 | class Dataset(torch.utils.data.Dataset): 232 | def __init__(self, dataset: datasets.Dataset) -> None: 233 | self._dataset = dataset 234 | 235 | def __len__(self) -> int: 236 | return len(self._dataset) 237 | 238 | def __getitem__(self, index: int) -> dict: 239 | return self._dataset[index] 240 | 241 | 242 | class Collator: 243 | def __init__(self, model_path) -> None: 244 | self._tokenizer = transformers.AutoTokenizer.from_pretrained(str(model_path), use_fast=True) 245 | self._tokenizer.pad_token = "[PAD]" 246 | self._tokenizer.pad_token_id = 0 247 | 248 | def __call__(self, features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 249 | result = self._tokenizer.apply_chat_template([m["messages"] for m in features], tokenize=True, add_generation_prompt=False, return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt", padding=True) 250 | return { 251 | "input_ids": result["input_ids"], 252 | "loss_mask": result["assistant_masks"] 253 | } 254 | 255 | 256 | def _make_eagle_input(batch, verifier_model, max_model_len, transform_uniform_low, transform_uniform_high, device): 257 | input_ids = batch["input_ids"].to(device)[:, :max_model_len] 258 | loss_mask = batch["loss_mask"].to(device)[:, :max_model_len] 259 | 260 | with torch.no_grad(): 261 | outs_big = verifier_model(input_ids, output_hidden_states=True, use_cache=False) 262 | hidden_state_big = outs_big.hidden_states[-1] 263 | hidden_state_big = _apply_noise_to_hidden_state(hidden_state_big, transform_uniform_low, transform_uniform_high) 264 | T, L, D = hidden_state_big.shape 265 | target = hidden_state_big.new_zeros((T, L, D)) 266 | target[:, :-1, :] = hidden_state_big[:, 1:, :]  # regression target is the verifier hidden state of the next position
267 | input_ids = torch.cat((input_ids[:, 1:], torch.zeros(input_ids.size(0), 1, dtype=input_ids.dtype, device=input_ids.device)), dim=1)  # shift tokens left by one so the draft conditions on the next token 268 | batch = {"input_ids": input_ids, "hidden_states": hidden_state_big, "target": target, "loss_mask": loss_mask} 269 | return batch 270 | 271 | 272 | def _apply_noise_to_hidden_state(hidden_state: torch.FloatTensor, transform_uniform_low, transform_uniform_high) -> torch.FloatTensor: 273 | noise = torch.rand_like(hidden_state) * (transform_uniform_high - transform_uniform_low) + transform_uniform_low 274 | noisy_tensor = hidden_state + noise 275 | return noisy_tensor 276 | 277 | 278 | if __name__ == "__main__": 279 | coach() -------------------------------------------------------------------------------- /eagle/llama2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from typing import List, Optional, Tuple 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.utils.checkpoint 8 | from torch import nn 9 | 10 | from transformers.activations import ACT2FN 11 | top_k=10 12 | 13 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 14 | def _make_causal_mask( 15 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 16 | ): 17 | """ 18 | Make causal mask used for bi-directional self-attention. 19 | """ 20 | bsz, tgt_len = input_ids_shape 21 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 22 | mask_cond = torch.arange(mask.size(-1), device=device) 23 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 24 | mask = mask.to(dtype) 25 | 26 | if past_key_values_length > 0: 27 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 28 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 29 | 30 | 31 | # Copied from transformers.models.bart.modeling_bart._expand_mask 32 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 33 | """ 34 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 35 | """ 36 | bsz, src_len = mask.size() 37 | tgt_len = tgt_len if tgt_len is not None else src_len 38 | 39 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 40 | 41 | inverted_mask = 1.0 - expanded_mask 42 | 43 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 44 | 45 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 46 | """ 47 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 48 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 49 | """ 50 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 51 | if n_rep == 1: 52 | return hidden_states 53 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 54 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 55 | 56 | def rotate_half(x): 57 | """Rotates half the hidden dims of the input.""" 58 | x1 = x[..., : x.shape[-1] // 2] 59 | x2 = x[..., x.shape[-1] // 2 :] 60 | return torch.cat((-x2, x1), dim=-1) 61 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 62 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
63 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 64 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 65 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 66 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 67 | q_embed = (q * cos) + (rotate_half(q) * sin) 68 | k_embed = (k * cos) + (rotate_half(k) * sin) 69 | return q_embed, k_embed 70 | class LlamaRotaryEmbedding(torch.nn.Module): 71 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 72 | super().__init__() 73 | 74 | self.dim = dim 75 | self.max_position_embeddings = max_position_embeddings 76 | self.base = base 77 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 78 | self.register_buffer("inv_freq", inv_freq, persistent=False) 79 | 80 | # Build here to make `torch.jit.trace` work. 81 | self._set_cos_sin_cache( 82 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 83 | ) 84 | 85 | def _set_cos_sin_cache(self, seq_len, device, dtype): 86 | self.max_seq_len_cached = seq_len 87 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 88 | 89 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 90 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 91 | emb = torch.cat((freqs, freqs), dim=-1) 92 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 93 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 94 | 95 | def forward(self, x, seq_len=None): 96 | # x: [bs, num_attention_heads, seq_len, head_size] 97 | if seq_len > self.max_seq_len_cached: 98 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 99 | 100 | return ( 101 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 102 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 103 | ) 104 | 105 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 106 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 107 | 108 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 109 | self.scaling_factor = scaling_factor 110 | super().__init__(dim, max_position_embeddings, base, device) 111 | 112 | def _set_cos_sin_cache(self, seq_len, device, dtype): 113 | self.max_seq_len_cached = seq_len 114 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 115 | t = t / self.scaling_factor 116 | 117 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 118 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 119 | emb = torch.cat((freqs, freqs), dim=-1) 120 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 121 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 122 | 123 | 124 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 125 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 126 | 127 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 128 | self.scaling_factor = scaling_factor 129 | super().__init__(dim, max_position_embeddings, base, device) 130 | 131 | def _set_cos_sin_cache(self, seq_len, device, dtype): 132 | self.max_seq_len_cached = seq_len 133 | 134 | if seq_len > self.max_position_embeddings: 135 | base = self.base * ( 136 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 137 | ) ** (self.dim / (self.dim - 2)) 138 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 139 | self.register_buffer("inv_freq", inv_freq, persistent=False) 140 | 141 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 142 | 143 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 144 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 145 | emb = torch.cat((freqs, freqs), dim=-1) 146 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 147 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 148 | 149 | class LlamaAttention(nn.Module): 150 | """Multi-headed attention from 'Attention Is All You Need' paper""" 151 | 152 | def __init__(self, config): 153 | super().__init__() 154 | self.config = config 155 | self.hidden_size = config.hidden_size 156 | self.num_heads = config.num_attention_heads 157 | self.head_dim = self.hidden_size // self.num_heads 158 | self.num_key_value_heads = config.num_key_value_heads 159 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 160 | self.max_position_embeddings = config.max_position_embeddings 161 | 162 | if (self.head_dim * self.num_heads) != self.hidden_size: 163 | raise ValueError( 164 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 165 | f" and `num_heads`: {self.num_heads})." 
166 | ) 167 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 168 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 169 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 170 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 171 | self._init_rope() 172 | 173 | def _init_rope(self): 174 | if self.config.rope_scaling is None: 175 | self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) 176 | else: 177 | scaling_type = self.config.rope_scaling["type"] 178 | scaling_factor = self.config.rope_scaling["factor"] 179 | if scaling_type == "linear": 180 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 181 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor 182 | ) 183 | elif scaling_type == "dynamic": 184 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 185 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor 186 | ) 187 | else: 188 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 189 | 190 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 191 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 192 | 193 | def forward( 194 | self, 195 | hidden_states: torch.Tensor, 196 | attention_mask: Optional[torch.Tensor] = None, 197 | position_ids: Optional[torch.LongTensor] = None, 198 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 199 | output_attentions: bool = False, 200 | use_cache: bool = False, 201 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 202 | bsz, q_len, _ = hidden_states.size() 203 | 204 | if self.config.pretraining_tp > 1: 205 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 206 | query_slices = self.q_proj.weight.split( 207 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 208 | ) 209 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 210 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 211 | 212 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 213 | query_states = torch.cat(query_states, dim=-1) 214 | 215 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 216 | key_states = torch.cat(key_states, dim=-1) 217 | 218 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 219 | value_states = torch.cat(value_states, dim=-1) 220 | 221 | else: 222 | query_states = self.q_proj(hidden_states) 223 | key_states = self.k_proj(hidden_states) 224 | value_states = self.v_proj(hidden_states) 225 | 226 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 227 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 228 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 229 | 230 | kv_seq_len = key_states.shape[-2] 231 | if past_key_value is not None: 232 | kv_seq_len += past_key_value[0].shape[-2] 233 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 234 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 235 | 236 | if 
past_key_value is not None: 237 | # reuse k, v, self_attention 238 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 239 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 240 | 241 | past_key_value = (key_states, value_states) if use_cache else None 242 | 243 | # repeat k/v heads if n_kv_heads < n_heads 244 | key_states = repeat_kv(key_states, self.num_key_value_groups) 245 | value_states = repeat_kv(value_states, self.num_key_value_groups) 246 | 247 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 248 | 249 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 250 | raise ValueError( 251 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 252 | f" {attn_weights.size()}" 253 | ) 254 | 255 | if attention_mask is not None: 256 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 257 | raise ValueError( 258 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 259 | ) 260 | attn_weights = attn_weights + attention_mask 261 | 262 | # upcast attention to fp32 263 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 264 | attn_output = torch.matmul(attn_weights, value_states) 265 | 266 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 267 | raise ValueError( 268 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 269 | f" {attn_output.size()}" 270 | ) 271 | 272 | attn_output = attn_output.transpose(1, 2).contiguous() 273 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 274 | 275 | if self.config.pretraining_tp > 1: 276 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 277 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 278 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 279 | else: 280 | attn_output = self.o_proj(attn_output) 281 | 282 | if not output_attentions: 283 | attn_weights = None 284 | 285 | return attn_output, attn_weights, past_key_value 286 | 287 | 288 | class LlamaMLP(nn.Module): 289 | def __init__(self, config): 290 | super().__init__() 291 | self.config = config 292 | self.hidden_size = config.hidden_size 293 | self.intermediate_size = config.intermediate_size 294 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 295 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 296 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 297 | self.act_fn = ACT2FN[config.hidden_act] 298 | 299 | def forward(self, x): 300 | if self.config.pretraining_tp > 1: 301 | slice = self.intermediate_size // self.config.pretraining_tp 302 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 303 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 304 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 305 | 306 | gate_proj = torch.cat( 307 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 308 | ) 309 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 310 | 311 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 312 | down_proj = [ 313 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in 
range(self.config.pretraining_tp) 314 | ] 315 | down_proj = sum(down_proj) 316 | else: 317 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 318 | 319 | return down_proj 320 | 321 | class LlamaRMSNorm(nn.Module): 322 | def __init__(self, hidden_size, eps=1e-6): 323 | """ 324 | LlamaRMSNorm is equivalent to T5LayerNorm 325 | """ 326 | super().__init__() 327 | self.weight = nn.Parameter(torch.ones(hidden_size)) 328 | self.variance_epsilon = eps 329 | 330 | def forward(self, hidden_states): 331 | input_dtype = hidden_states.dtype 332 | hidden_states = hidden_states.to(torch.float32) 333 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 334 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 335 | return self.weight * hidden_states.to(input_dtype) 336 | 337 | class LlamaDecoderLayer(nn.Module): 338 | def __init__(self, config,index): 339 | super().__init__() 340 | self.hidden_size = config.hidden_size 341 | self.self_attn = LlamaAttention(config=config) 342 | self.mlp = LlamaMLP(config) 343 | self.index=index 344 | if self.index!=0: 345 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 346 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 347 | 348 | def forward( 349 | self, 350 | hidden_states: torch.Tensor, 351 | attention_mask: Optional[torch.Tensor] = None, 352 | position_ids: Optional[torch.LongTensor] = None, 353 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 354 | output_attentions: Optional[bool] = False, 355 | use_cache: Optional[bool] = False, 356 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 357 | """ 358 | Args: 359 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 360 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 361 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 362 | output_attentions (`bool`, *optional*): 363 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 364 | returned tensors for more detail. 365 | use_cache (`bool`, *optional*): 366 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 367 | (see `past_key_values`). 
368 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 369 | """ 370 | 371 | residual = hidden_states 372 | 373 | if self.index != 0: 374 | hidden_states = self.input_layernorm(hidden_states) 375 | 376 | # Self Attention 377 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 378 | hidden_states=hidden_states, 379 | attention_mask=attention_mask, 380 | position_ids=position_ids, 381 | past_key_value=past_key_value, 382 | output_attentions=output_attentions, 383 | use_cache=use_cache, 384 | ) 385 | hidden_states = residual + hidden_states 386 | 387 | # Fully Connected 388 | residual = hidden_states 389 | hidden_states = self.post_attention_layernorm(hidden_states) 390 | hidden_states = self.mlp(hidden_states) 391 | hidden_states = residual + hidden_states 392 | 393 | outputs = (hidden_states,) 394 | 395 | if output_attentions: 396 | outputs += (self_attn_weights,) 397 | 398 | if use_cache: 399 | outputs += (present_key_value,) 400 | 401 | return outputs 402 | 403 | 404 | class Llama2Model(nn.Module): 405 | def __init__(self,config,load_emb=False,path=None,bias=True): 406 | super().__init__() 407 | self.gradient_checkpointing = True 408 | self.padding_idx = config.pad_token_id 409 | self.vocab_size = config.vocab_size 410 | 411 | ##################################################################################### 412 | # Init embedding from base model 413 | 414 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 415 | if load_emb: 416 | from safetensors import safe_open 417 | import json 418 | try: 419 | with open(os.path.join(path,"model.safetensors.index.json"),"r") as f: 420 | index_json=json.loads(f.read()) 421 | emb_path=index_json["weight_map"]["model.embed_tokens.weight"] 422 | with safe_open(os.path.join(path,emb_path), 423 | framework="pt", 424 | device="cpu") as f: 425 | tensor_slice = f.get_slice("model.embed_tokens.weight") 426 | vocab_size, hidden_dim = tensor_slice.get_shape() 427 | tensor = tensor_slice[:, :hidden_dim].float() 428 | except: 429 | with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f: 430 | index_json = json.loads(f.read()) 431 | emb_path = index_json["weight_map"]["model.embed_tokens.weight"] 432 | weights=torch.load(os.path.join(path,emb_path)) 433 | tensor=weights["model.embed_tokens.weight"].float() 434 | self.embed_tokens.weight.data = tensor 435 | 436 | ##################################################################################### 437 | # init eagle layers 438 | 439 | self.layers = nn.ModuleList([LlamaDecoderLayer(config,index) for index in range(config.num_hidden_layers)]) 440 | 441 | ##################################################################################### 442 | # init eagle down pooler from concat([emb, hidden_state]) -> hidden_size 443 | 444 | self.fc=nn.Linear(2*config.hidden_size,config.hidden_size,bias=bias) 445 | 446 | ##################################################################################### 447 | # freeze emb + extra stuff 448 | 449 | self.act=ACT2FN[config.hidden_act] 450 | for param in self.embed_tokens.parameters(): 451 | param.requires_grad = False 452 | 453 | def init_tree(self): 454 | self.tree = mc_sim_7b_63 455 | self.tree_buffer=generate_tree_buffers(self.tree,self.embed_tokens.weight.device) 456 | 457 | def reset(self): 458 | self.tree_mask=None 459 | 460 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 461 | # 
create causal mask 462 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 463 | combined_attention_mask = None 464 | if input_shape[-1] > 1: 465 | combined_attention_mask = _make_causal_mask( 466 | input_shape, 467 | #inputs_embeds.dtype, 468 | torch.float32, # [MODIFIED] force to cast to float32 469 | device=inputs_embeds.device, 470 | past_key_values_length=past_key_values_length, 471 | ) 472 | 473 | if attention_mask is not None: 474 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 475 | expanded_attn_mask = _expand_mask(attention_mask, torch.float32, tgt_len=input_shape[-1]).to( 476 | inputs_embeds.device 477 | ) 478 | combined_attention_mask = ( 479 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 480 | ) 481 | 482 | # [MODIFIED] add tree mask 483 | if hasattr(self, "tree_mask") and self.tree_mask is not None: 484 | tree_mask = self.tree_mask 485 | tree_len = tree_mask.size(-1) 486 | combined_attention_mask[:, :, -tree_len:, -tree_len:][ 487 | tree_mask == 0 488 | ] = torch.finfo(torch.float32).min 489 | 490 | 491 | return combined_attention_mask 492 | 493 | def forward( 494 | self, 495 | hidden_states, 496 | input_ids, 497 | attention_mask: Optional[torch.Tensor] = None, 498 | position_ids: Optional[torch.LongTensor] = None, 499 | past_key_values: Optional[List[torch.FloatTensor]] = None, 500 | inputs_embeds: Optional[torch.FloatTensor] = None, 501 | use_cache: Optional[bool] = None, 502 | output_attentions: Optional[bool] = None, 503 | output_hidden_states: Optional[bool] = None, 504 | return_dict: Optional[bool] = None, 505 | std=None 506 | ): 507 | batch_size, seq_length, _ = hidden_states.shape 508 | seq_length_with_past = seq_length 509 | past_key_values_length = 0 510 | 511 | with torch.no_grad(): 512 | ##################################################################################### 513 | # embed tokens using embedding table 514 | 515 | inputs_embeds = self.embed_tokens(input_ids) 516 | 517 | if past_key_values is not None: 518 | past_key_values_length = past_key_values[0][0].shape[2] 519 | seq_length_with_past = seq_length_with_past + past_key_values_length 520 | if position_ids is None: 521 | device = hidden_states.device if hidden_states is not None else inputs_embeds.device 522 | position_ids = torch.arange( 523 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 524 | ) 525 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 526 | else: 527 | position_ids = position_ids.view(-1, seq_length).long() 528 | 529 | if attention_mask is None: 530 | attention_mask = torch.ones( 531 | (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device 532 | ) 533 | attention_mask = self._prepare_decoder_attention_mask( 534 | attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length 535 | ) 536 | 537 | inputs_embeds=inputs_embeds.to(hidden_states.dtype) 538 | 539 | ##################################################################################### 540 | # Concat emb + hidden state and lower dim back to hidden_size using fully connected 541 | 542 | hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) 543 | 544 | 545 | all_hidden_states = () if output_hidden_states else None 546 | next_decoder_cache = () if use_cache else None 547 | 548 | for idx, decoder_layer in enumerate(self.layers): 549 | if output_hidden_states: 550 | all_hidden_states += (hidden_states,) 551 | 552 | past_key_value = 
past_key_values[idx] if past_key_values is not None else None 553 | 554 | if self.gradient_checkpointing and self.training: 555 | 556 | def create_custom_forward(module): 557 | def custom_forward(*inputs): 558 | # None for past_key_value 559 | return module(*inputs, past_key_value, output_attentions) 560 | 561 | return custom_forward 562 | 563 | layer_outputs = torch.utils.checkpoint.checkpoint( 564 | create_custom_forward(decoder_layer), 565 | hidden_states, 566 | attention_mask, 567 | position_ids, 568 | ) 569 | else: 570 | layer_outputs = decoder_layer( 571 | hidden_states, 572 | attention_mask=attention_mask, 573 | position_ids=position_ids, 574 | past_key_value=past_key_value, 575 | output_attentions=output_attentions, 576 | use_cache=use_cache, 577 | ) 578 | 579 | hidden_states = layer_outputs[0] 580 | 581 | if use_cache: 582 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 583 | 584 | if use_cache: 585 | return hidden_states,next_decoder_cache 586 | 587 | return hidden_states 588 | --------------------------------------------------------------------------------