├── scaled_rope
│   ├── __init__.py
│   ├── GPTNeoXNTKScaledRotaryEmbedding.py
│   ├── GPTNeoXDynamicScaledRotaryEmbedding.py
│   ├── LlamaNTKScaledRotaryEmbedding.py
│   ├── LlamaLinearScaledRotaryEmbedding.py
│   ├── LlamaDynamicScaledRotaryEmbedding.py
│   ├── LlamaYaRNScaledRotaryEmbedding.py
│   ├── LlamaPartNTKScaledRotaryEmbedding.py
│   ├── patch.py
│   ├── LlamaDynamicYaRNScaledRotaryEmbedding.py
│   ├── FalconDynamicPartNTKScaledRotaryEmbedding.py
│   ├── LlamaDynamicPartNTKScaledRotaryEmbedding.py
│   ├── LlamaReRoPE.py
│   ├── configuration_mistral.py
│   └── configuration_llama.py
├── paper
│   ├── yarn.pdf
│   ├── mmlu-average.py
│   └── plot.py
├── data
│   ├── proofpile-long-small.csv.pdf
│   ├── proofpile-long-small.csv.png
│   ├── proofpile-long-small-32k-70b.csv.pdf
│   ├── proofpile-long-small-32k-70b.csv.png
│   ├── proofpile-long-small-mistral.csv.pdf
│   ├── proofpile-long-small-mistral.csv.png
│   ├── proofpile-long-small-solar.csv.pdf
│   ├── proofpile-long-small-solar.csv.png
│   ├── govreport.csv
│   ├── proofpile-long-small-8k.csv
│   ├── Yarn-Llama-2-7b-64k-truthfulqa.json
│   ├── Yarn-Llama-2-13b-64k-hellaswag.json
│   ├── Yarn-Llama-2-13b-64k-truthfulqa.json
│   ├── Yarn-Llama-2-70b-32k-truthfulqa.json
│   ├── Yarn-Llama-2-7b-128k-truthfulqa.json
│   ├── Yarn-Llama-2-7b-64k-arc.json
│   ├── Yarn-Llama-2-7b-64k-hellaswag.json
│   ├── Yarn-Mistral-7b-128k-hellaswag.json
│   ├── Yarn-Mistral-7b-128k-truthfulqa.json
│   ├── Yarn-Mistral-7b-64k-arc.json
│   ├── Yarn-Mistral-7b-64k-hellaswag.json
│   ├── Yarn-Mistral-7b-64k-truthfulqa.json
│   ├── Yarn-Llama-2-13b-128k-arc.json
│   ├── Yarn-Llama-2-13b-128k-hellaswag.json
│   ├── Yarn-Llama-2-13b-128k-truthfulqa.json
│   ├── Yarn-Llama-2-13b-64k-arc.json
│   ├── Yarn-Llama-2-70b-32k-arc.json
│   ├── Yarn-Llama-2-7b-128k-arc.json
│   ├── Yarn-Llama-2-7b-128k-hellaswag.json
│   ├── Yarn-Mistral-7b-128k-arc.json
│   ├── Yarn-Solar-10b-32k-hellaswag.json
│   ├── Yarn-Solar-10b-32k-truthfulqa.json
│   ├── Yarn-Solar-10b-64k-hellaswag.json
│   ├── Yarn-Solar-10b-64k-truthfulqa.json
│   ├── Yarn-Solar-10b-32k-arc.json
│   ├── Yarn-Solar-10b-64k-arc.json
│   ├── proofpile-long-small-32k-70b.csv
│   ├── proofpile-long-small-mistral.csv
│   ├── proofpile-long-small-solar.csv
│   ├── proofpile-long-small.csv
│   ├── Yarn-Llama-2-13b-64k-mmlu.json
│   └── Yarn-Llama-2-7b-128k-mmlu.json
├── .gitignore
├── requirements.txt
├── setup.py
├── truncate.py
├── LICENSE
├── deepspeed
│   ├── zero2.json
│   ├── zero3.json
│   └── zero3_offload.json
├── eval
│   ├── prompt-loop.py
│   ├── quality.py
│   ├── passkey.py
│   ├── passkey_hard.py
│   ├── perplexity.py
│   └── model_loader.py
├── train.sh
├── eval.sh
├── tokenization.py
├── README.md
├── eval-harness.sh
└── finetune.py

/scaled_rope/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/paper/yarn.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/paper/yarn.pdf
--------------------------------------------------------------------------------
/data/proofpile-long-small.csv.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small.csv.pdf
--------------------------------------------------------------------------------
/data/proofpile-long-small.csv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small.csv.png
--------------------------------------------------------------------------------
/data/proofpile-long-small-32k-70b.csv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small-32k-70b.csv.pdf -------------------------------------------------------------------------------- /data/proofpile-long-small-32k-70b.csv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small-32k-70b.csv.png -------------------------------------------------------------------------------- /data/proofpile-long-small-mistral.csv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small-mistral.csv.pdf -------------------------------------------------------------------------------- /data/proofpile-long-small-mistral.csv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small-mistral.csv.png -------------------------------------------------------------------------------- /data/proofpile-long-small-solar.csv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small-solar.csv.pdf -------------------------------------------------------------------------------- /data/proofpile-long-small-solar.csv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jquesnelle/yarn/HEAD/data/proofpile-long-small-solar.csv.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | env/ 3 | venv/ 4 | /.vscode 5 | /output 6 | *.txt 7 | *.egg-info 8 | wandb/ 9 | lm_cache/ 10 | run.sh -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | datasets 3 | evaluate 4 | torch>=2 5 | transformers>=4.36.2 6 | tqdm 7 | einops 8 | sentencepiece 9 | protobuf==3.19.6 10 | scikit-learn 11 | matplotlib 12 | pandas 13 | numpy 14 | -------------------------------------------------------------------------------- /data/govreport.csv: -------------------------------------------------------------------------------- 1 | ,32768 2 | NousResearch/CodeLlama-13b-hf,4.223242283 3 | NousResearch/Yarn-Llama-2-13b-64k,3.35193944 4 | NousResearch/Yarn-Llama-2-13b-128k,3.392480612 5 | togethercomputer/LLaMA-2-7B-32K,3.667786837 6 | NousResearch/CodeLlama-7b-hf,4.438117504 7 | NousResearch/Yarn-Llama-2-7b-64k,3.594996214 8 | NousResearch/Yarn-Llama-2-7b-128k,3.642920256 -------------------------------------------------------------------------------- /paper/mmlu-average.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy 3 | import json 4 | 5 | def main(args): 6 | obj = json.load(open(args.file, "r", encoding="utf-8")) 7 | results = [result["acc"] for result in obj["results"].values()] 8 | print(numpy.average(results)) 9 | 10 | if __name__ == "__main__": 11 | args = argparse.ArgumentParser() 12 | args.add_argument("file", type=str) 13 | main(args.parse_args()) 
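# Example usage (an illustrative sketch; any of the lm-eval-harness MMLU result files
# under data/ works as input), e.g.:
#   python paper/mmlu-average.py data/Yarn-Llama-2-13b-64k-mmlu.json
# prints the unweighted mean of the per-task "acc" values from that results JSON.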
-------------------------------------------------------------------------------- /data/proofpile-long-small-8k.csv: -------------------------------------------------------------------------------- 1 | ,2048,4096,6144,8192,10240,12288,14336,16384 2 | emozilla/Yarn-Llama-2-7b-8k,3.9184463024139404,3.5097882747650146,3.5142462253570557,3.3563895225524902,6.0485148429870605,18.118249893188477,41.190216064453125,77.2326889038086 3 | emozilla/NTK-Llama-2-7b-8k,4.202556610107422,3.750666856765747,3.7433838844299316,3.586596727371216,6.235434532165527,18.260393142700195,41.89409255981445,74.43867492675781 4 | conceptofmind/LLongMA-2-7b,3.9218568801879883,3.509538412094116,3.5088977813720703,3.344118118286133,8.071100234985352,23.530475616455078,54.98833465576172,102.7402114868164 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | install_requires = [] 4 | with open("./requirements.txt", encoding="utf-8") as requirements_file: 5 | reqs = [r.strip() for r in requirements_file.readlines()] 6 | reqs = [r for r in reqs if r and r[0] != "#"] 7 | for r in reqs: 8 | install_requires.append(r) 9 | 10 | setup( 11 | name="scaled-rope", 12 | version="0.1", 13 | packages=["scaled_rope"], 14 | install_requires=install_requires, 15 | url='https://github.com/jquesnelle/scaled-rope/', 16 | license='MIT', 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | ] 21 | ) -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-64k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.2533659730722154, 5 | "mc1_stderr": 0.015225899340826842, 6 | "mc2": 0.38227179754259044, 7 | "mc2_stderr": 0.0135262280350684 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-64k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.6189006174068911, 5 | "acc_stderr": 0.004846643735666543, 6 | "acc_norm": 0.8237402907787293, 7 | "acc_norm_stderr": 0.003802622341529011 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 10, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-64k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 
0.2558139534883721, 5 | "mc1_stderr": 0.015274176219283352, 6 | "mc2": 0.37809984999128077, 7 | "mc2_stderr": 0.013601944131929177 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-70b-32k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.3157894736842105, 5 | "mc1_stderr": 0.016272287957916912, 6 | "mc2": 0.4613857455330124, 7 | "mc2_stderr": 0.0140068920418935 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Llama-2-70b-32k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | "batch_size": "20", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-128k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.2484700122399021, 5 | "mc1_stderr": 0.015127427096520688, 6 | "mc2": 0.3737977589710703, 7 | "mc2_stderr": 0.01349079049699645 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-64k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.49402730375426623, 5 | "acc_stderr": 0.014610348300255795, 6 | "acc_norm": 0.523037542662116, 7 | "acc_norm_stderr": 0.014595873205358266 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-64k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.5888269269069907, 5 | "acc_stderr": 0.004910409150135491, 6 | "acc_norm": 0.788388767177853, 7 | "acc_norm_stderr": 0.00407615874434677 8 | } 9 | }, 10 | 
"versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 10, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Mistral-7b-128k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.6096395140410277, 5 | "acc_stderr": 0.00486834105656622, 6 | "acc_norm": 0.8058155745867357, 7 | "acc_norm_stderr": 0.0039476309218879286 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Mistral-7b-128k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 10, 17 | "batch_size": "8", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Mistral-7b-128k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.2717258261933905, 5 | "mc1_stderr": 0.015572840452875833, 6 | "mc2": 0.4246618096588122, 7 | "mc2_stderr": 0.014162245453141755 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Mistral-7b-128k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | "batch_size": "8", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Mistral-7b-64k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.5648464163822525, 5 | "acc_stderr": 0.014487986197186045, 6 | "acc_norm": 0.5938566552901023, 7 | "acc_norm_stderr": 0.014351656690097862 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Mistral-7b-64k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "8", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Mistral-7b-64k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.6149173471420036, 5 | "acc_stderr": 0.004856203374715456, 6 | "acc_norm": 0.8121888070105556, 7 | "acc_norm_stderr": 0.0038976312814765204 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": 
"pretrained=NousResearch/Yarn-Mistral-7b-64k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 10, 17 | "batch_size": "8", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Mistral-7b-64k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.2766217870257038, 5 | "mc1_stderr": 0.015659605755326923, 6 | "mc2": 0.42507003020142764, 7 | "mc2_stderr": 0.01414022344947813 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Mistral-7b-64k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | "batch_size": "8", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-128k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.5401023890784983, 5 | "acc_stderr": 0.01456431885692485, 6 | "acc_norm": 0.5802047781569966, 7 | "acc_norm_stderr": 0.014422181226303028 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-128k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.6189006174068911, 5 | "acc_stderr": 0.0048466437356665445, 6 | "acc_norm": 0.8224457279426409, 7 | "acc_norm_stderr": 0.0038135610571503444 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 10, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-128k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.24969400244798043, 5 | "mc1_stderr": 0.015152286907148128, 6 | "mc2": 0.3736016332066591, 7 | "mc2_stderr": 0.013540879642054216 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 0, 17 | 
"batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-64k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.5409556313993175, 5 | "acc_stderr": 0.01456229107360123, 6 | "acc_norm": 0.5810580204778157, 7 | "acc_norm_stderr": 0.014418106953639013 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-70b-32k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.6262798634812287, 5 | "acc_stderr": 0.014137708601759091, 6 | "acc_norm": 0.674061433447099, 7 | "acc_norm_stderr": 0.013697432466693244 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Llama-2-70b-32k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-128k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.49402730375426623, 5 | "acc_stderr": 0.014610348300255793, 6 | "acc_norm": 0.5213310580204779, 7 | "acc_norm_stderr": 0.014598087973127108 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-128k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.5893248356901015, 5 | "acc_stderr": 0.004909509538525163, 6 | "acc_norm": 0.7843059151563434, 7 | "acc_norm_stderr": 0.0041046239918463645 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 10, 17 | "batch_size": "2", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | 
"description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Mistral-7b-128k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.5520477815699659, 5 | "acc_stderr": 0.014532011498211674, 6 | "acc_norm": 0.5887372013651877, 7 | "acc_norm_stderr": 0.014379441068522082 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Mistral-7b-128k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 16 | "num_fewshot": 25, 17 | "batch_size": "8", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": false, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Solar-10b-32k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.6373232423819957, 5 | "acc_stderr": 0.004797900720081486, 6 | "acc_norm": 0.8364867556263692, 7 | "acc_norm_stderr": 0.003690774563638007 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Solar-10b-32k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True,peft=", 16 | "num_fewshot": 10, 17 | "batch_size": "24", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": true, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Solar-10b-32k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.3023255813953488, 5 | "mc1_stderr": 0.01607750926613303, 6 | "mc2": 0.4481800921545103, 7 | "mc2_stderr": 0.014270815215139987 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Solar-10b-32k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True,peft=", 16 | "num_fewshot": 0, 17 | "batch_size": "48", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": true, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Solar-10b-64k-hellaswag.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hellaswag": { 4 | "acc": 0.6346345349531965, 5 | "acc_stderr": 0.00480548376705535, 6 | "acc_norm": 0.8308105954989046, 7 | "acc_norm_stderr": 0.0037415289563158456 8 | } 9 | }, 10 | "versions": { 11 | "hellaswag": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Solar-10b-64k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True,peft=", 16 | "num_fewshot": 10, 17 | "batch_size": "24", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": true, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- 
/data/Yarn-Solar-10b-64k-truthfulqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "truthfulqa_mc": { 4 | "mc1": 0.30599755201958384, 5 | "mc1_stderr": 0.01613222972815504, 6 | "mc2": 0.4565955145484738, 7 | "mc2_stderr": 0.014514499295144236 8 | } 9 | }, 10 | "versions": { 11 | "truthfulqa_mc": 1 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Solar-10b-64k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True,peft=", 16 | "num_fewshot": 0, 17 | "batch_size": "48", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": true, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Solar-10b-32k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.5580204778156996, 5 | "acc_stderr": 0.014512682523128345, 6 | "acc_norm": 0.5964163822525598, 7 | "acc_norm_stderr": 0.014337158914268447 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Solar-10b-32k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True,peft=", 16 | "num_fewshot": 25, 17 | "batch_size": "24", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": true, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /data/Yarn-Solar-10b-64k-arc.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "arc_challenge": { 4 | "acc": 0.5597269624573379, 5 | "acc_stderr": 0.014506769524804236, 6 | "acc_norm": 0.5921501706484642, 7 | "acc_norm_stderr": 0.014361097288449703 8 | } 9 | }, 10 | "versions": { 11 | "arc_challenge": 0 12 | }, 13 | "config": { 14 | "model": "hf-causal-experimental", 15 | "model_args": "pretrained=NousResearch/Yarn-Solar-10b-64k,use_accelerate=True,dtype=bfloat16,trust_remote_code=True,peft=", 16 | "num_fewshot": 25, 17 | "batch_size": "24", 18 | "batch_sizes": [], 19 | "device": null, 20 | "no_cache": true, 21 | "limit": null, 22 | "bootstrap_iters": 100000, 23 | "description_dict": {} 24 | } 25 | } -------------------------------------------------------------------------------- /truncate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | 4 | def main(args): 5 | dataset = load_dataset(args.dataset, split="train") 6 | def truncate(sample): 7 | sample["input_ids"] = sample["input_ids"][0:args.truncate] 8 | sample["labels"] = sample["labels"][0:args.truncate] 9 | sample["attention_mask"] = sample["attention_mask"][0:args.truncate] 10 | return sample 11 | dataset = dataset.map(truncate, desc="Truncating", num_proc=args.num_proc) 12 | dataset.save_to_disk(args.output) 13 | 14 | 15 | if __name__ == "__main__": 16 | args = argparse.ArgumentParser() 17 | args.add_argument("truncate", type=int) 18 | args.add_argument("output", type=str) 19 | args.add_argument("--num-proc", type=int, default=32) 20 | args.add_argument("--dataset", type=str, 21 | default="emozilla/pg_books-tokenized-bos-eos-chunked-65536") 22 | main(args.parse_args()) 
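# Example invocation (the same call used for the 8k ablations in train.sh):
#   python3 truncate.py 8192 output/truncated-8k
# i.e. clip input_ids, labels and attention_mask of every sample in the default
# dataset to 8192 tokens and save the truncated copy to output/truncated-8k.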
-------------------------------------------------------------------------------- /paper/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | 5 | def main(args): 6 | data = pd.read_csv(args.csv) 7 | fig, ax = plt.subplots(figsize=(10,5)) 8 | 9 | x_data = [float(x) for x in data.columns[1:]] 10 | for row in data.values: 11 | label = row[0].replace("NousResearch/", "") 12 | ax.plot(x_data, [float(x) for x in row[1:]], label=label) 13 | 14 | ax.set_xlabel("Context Window") 15 | ax.set_ylabel("Perplexity (lower is better)") 16 | 17 | ax.set_xlim(args.xmin, args.xmax) 18 | ax.set_ylim(args.ymin, args.ymax) 19 | 20 | ax.legend(loc="upper right") 21 | 22 | fig.savefig(args.csv + ".png") 23 | fig.savefig(args.csv + ".pdf", transparent=True) 24 | 25 | if __name__ == "__main__": 26 | args = argparse.ArgumentParser() 27 | args.add_argument("csv", type=str) 28 | args.add_argument("--xmin", type=int, default=0) 29 | args.add_argument("--xmax", type=int, default=131072) 30 | args.add_argument("--ymin", type=float, default=2.2) 31 | args.add_argument("--ymax", type=float, default=3.8) 32 | main(args.parse_args()) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /deepspeed/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": [0.9, 0.95], 18 | "eps": 1e-8, 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto", 28 | "warmup_type": "linear", 29 | "total_num_steps": "auto" 30 | } 31 | }, 32 | "zero_optimization": { 33 | "stage": 2, 34 | "allgather_partitions": true, 35 | "allgather_bucket_size": 1e9, 36 | "overlap_comm": false, 37 | "reduce_scatter": true, 38 | "reduce_bucket_size": 1e9, 39 | "contiguous_gradients": true 40 | }, 41 | "gradient_accumulation_steps": "auto", 42 | "gradient_clipping": "auto", 43 | "steps_per_print": 2000, 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "wall_clock_breakdown": false 47 | } -------------------------------------------------------------------------------- /data/proofpile-long-small-32k-70b.csv: -------------------------------------------------------------------------------- 1 | ,1024,1536,2048,2560,3072,3584,4096,4608,5120,5632,6144,6656,7168,7680,8192,9216,10240,11264,12288,13312,14336,15360,16384,17408,18432,19456,20480,21504,22528,23552,24576,25600,26624,27648,28672,29696,30720,31744,32768 2 | meta-llama/Llama-2-70b-hf,3.7055604457855225,3.45805287361145,3.2749223709106445,3.173884868621826,3.0366575717926025,3.018570899963379,2.957059383392334,2.9754884243011475,3.6701107025146484,6.477595806121826,11.858061790466309,19.669586181640625,29.5391902923584,41.38957214355469,56.098636627197266 3 | NousResearch/Yarn-Llama-2-70b-32k,3.6099791526794434,3.392317056655884,3.22432279586792,3.123823404312134,2.9928181171417236,2.969006061553955,2.9126534461975098,2.9208717346191406,2.960122585296631,2.94386363029484,2.9340384006500244,2.917194366455078,2.877331256866455,2.838615655899048,2.8187291622161865,2.7568020820617676,2.6751885414123535,2.623396635055542,2.5787668228149414,2.5502796173095703,2.527597188949585,2.4948105812072754,2.450591802597046,2.4069149494171143,2.386108636856079,2.3773491382598877,2.3615598678588867,2.3387508392333984,2.3200013637542725,2.3081839084625244,2.2967910766601562,2.2946012020111084,2.2839956283569336,2.273364305496216,2.2604012489318848,2.2527451515197754,2.24609112739563,2.2412569522857666,2.2288265228271484 4 | -------------------------------------------------------------------------------- /deepspeed/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": [0.9, 0.95], 18 | "eps": 1e-8, 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto", 28 | "warmup_type": "linear", 29 | 
"total_num_steps": "auto" 30 | } 31 | }, 32 | "zero_optimization": { 33 | "stage": 3, 34 | "overlap_comm": true, 35 | "contiguous_gradients": true, 36 | "sub_group_size": 1e9, 37 | "reduce_bucket_size": "auto", 38 | "stage3_prefetch_bucket_size": "auto", 39 | "stage3_param_persistence_threshold": "auto", 40 | "stage3_max_live_parameters": 1e9, 41 | "stage3_max_reuse_distance": 1e9, 42 | "stage3_gather_16bit_weights_on_model_save": true 43 | }, 44 | "gradient_accumulation_steps": "auto", 45 | "gradient_clipping": "auto", 46 | "steps_per_print": 2000, 47 | "train_batch_size": "auto", 48 | "train_micro_batch_size_per_gpu": "auto", 49 | "wall_clock_breakdown": false 50 | } -------------------------------------------------------------------------------- /deepspeed/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "scheduler": { 6 | "type": "WarmupDecayLR", 7 | "params": { 8 | "warmup_min_lr": "auto", 9 | "warmup_max_lr": "auto", 10 | "warmup_num_steps": "auto", 11 | "warmup_type": "linear", 12 | "total_num_steps": "auto" 13 | } 14 | }, 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": [0.9, 0.95], 20 | "eps": 1e-8, 21 | "weight_decay": "auto" 22 | } 23 | }, 24 | "zero_optimization": { 25 | "stage": 3, 26 | "offload_optimizer": { 27 | "device": "cpu", 28 | "pin_memory": true 29 | }, 30 | "offload_param": { 31 | "device": "cpu", 32 | "pin_memory": true 33 | }, 34 | "overlap_comm": true, 35 | "contiguous_gradients": true, 36 | "sub_group_size": 1e9, 37 | "reduce_bucket_size": "auto", 38 | "stage3_prefetch_bucket_size": "auto", 39 | "stage3_param_persistence_threshold": "auto", 40 | "stage3_max_live_parameters": 1e9, 41 | "stage3_max_reuse_distance": 1e9, 42 | "stage3_gather_16bit_weights_on_model_save": true 43 | }, 44 | "gradient_accumulation_steps": "auto", 45 | "gradient_clipping": "auto", 46 | "steps_per_print": 2000, 47 | "train_batch_size": "auto", 48 | "train_micro_batch_size_per_gpu": "auto", 49 | "wall_clock_breakdown": false 50 | } -------------------------------------------------------------------------------- /eval/prompt-loop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from datasets import load_dataset 4 | from transformers import AutoTokenizer, pipeline 5 | from tqdm import tqdm 6 | from model_loader import * 7 | 8 | 9 | def main(args): 10 | tokenizer = AutoTokenizer.from_pretrained( 11 | args.model, model_max_length=sys.maxsize, trust_remote_code=True) 12 | tokenizer.pad_token = tokenizer.eos_token 13 | 14 | model = load_model_and_apply_patches(args.model, args) 15 | 16 | pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id, 17 | temperature=args.temperature, repetition_penalty=args.repetition_penalty, 18 | top_k=args.top_k, penalty_alpha=args.penalty_alpha, do_sample=args.temperature is not None) 19 | 20 | while True: 21 | if args.input_file is None: 22 | prompt_text = input("> ") 23 | else: 24 | input(f"Press enter to read {args.input_file} ") 25 | prompt_text = open(args.input_file, encoding="utf=8").read() 26 | response = pipe(prompt_text, num_return_sequences=1, max_new_tokens=args.max_new_tokens)[ 27 | 0]["generated_text"][len(prompt_text):] 28 | print(f"< {response}") 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("-m", "--model", type=str, 
required=True) 34 | parser.add_argument("--max-new-tokens", type=int, default=256) 35 | parser.add_argument("--input-file", type=str) 36 | parser.add_argument("--temperature", type=float) 37 | parser.add_argument("--repetition-penalty", type=float) 38 | parser.add_argument("--penalty-alpha", type=float) 39 | parser.add_argument("--top-k", type=int) 40 | 41 | args = add_args(parser).parse_args() 42 | main(args) 43 | -------------------------------------------------------------------------------- /scaled_rope/GPTNeoXNTKScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class GPTNeoXNTKScaledRotaryEmbedding(torch.nn.Module): 4 | def __init__(self, dim, max_position_embeddings, base=10000, alpha=1, device=None): 5 | super().__init__() 6 | base = base * alpha ** (dim / (dim-2)) 7 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 8 | self.register_buffer("inv_freq", inv_freq) 9 | 10 | # Build here to make `torch.jit.trace` work. 11 | self.max_seq_len_cached = max_position_embeddings 12 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 13 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 14 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 15 | emb = torch.cat((freqs, freqs), dim=-1) 16 | self.cos_cached = emb.cos()[None, None, :, :] 17 | self.sin_cached = emb.sin()[None, None, :, :] 18 | 19 | def forward(self, x, seq_len=None): 20 | # x: [bs, num_attention_heads, seq_len, head_size] 21 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 22 | if seq_len > self.max_seq_len_cached: 23 | self.max_seq_len_cached = seq_len 24 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 25 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 26 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 27 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 28 | self.cos_cached = emb.cos()[None, None, :, :] 29 | self.sin_cached = emb.sin()[None, None, :, :] 30 | return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) -------------------------------------------------------------------------------- /scaled_rope/GPTNeoXDynamicScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class GPTNeoXDynamicScaledRotaryEmbedding(torch.nn.Module): 4 | def __init__(self, dim, max_position_embeddings, base=10000, device=None): 5 | super().__init__() 6 | self.max_position_embeddings = max_position_embeddings 7 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 8 | self.register_buffer("inv_freq", inv_freq) 9 | 10 | # Build here to make `torch.jit.trace` work. 
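# (This is the dynamic linear variant: the cache below is built for the original context
#  length, and when forward() sees a longer seq_len it rescales the position indices by
#  max_position_embeddings / seq_len before rebuilding the cache.)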
11 | self.max_seq_len_cached = max_position_embeddings 12 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 13 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 14 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 15 | emb = torch.cat((freqs, freqs), dim=-1) 16 | self.cos_cached = emb.cos()[None, None, :, :] 17 | self.sin_cached = emb.sin()[None, None, :, :] 18 | 19 | def forward(self, x, seq_len=None): 20 | # x: [bs, num_attention_heads, seq_len, head_size] 21 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 22 | if seq_len > self.max_seq_len_cached: 23 | self.max_seq_len_cached = seq_len 24 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 25 | t *= self.max_position_embeddings / seq_len 26 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 27 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 28 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 29 | self.cos_cached = emb.cos()[None, None, :, :] 30 | self.sin_cached = emb.sin()[None, None, :, :] 31 | return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) -------------------------------------------------------------------------------- /scaled_rope/LlamaNTKScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LlamaNTKScaledRotaryEmbedding(torch.nn.Module): 4 | def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None): 5 | super().__init__() 6 | base = base * alpha ** (dim / (dim-2)) 7 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 8 | self.register_buffer("inv_freq", inv_freq) 9 | 10 | # Build here to make `torch.jit.trace` work. 11 | self.max_seq_len_cached = max_position_embeddings 12 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 13 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 14 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 15 | emb = torch.cat((freqs, freqs), dim=-1) 16 | dtype = torch.get_default_dtype() 17 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 18 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 19 | 20 | def forward(self, x, seq_len=None): 21 | # x: [bs, num_attention_heads, seq_len, head_size] 22 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
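# (The NTK scaling itself happens once in __init__: base is stretched to
#  base * alpha ** (dim / (dim - 2)). As an illustrative calculation -- not a value used
#  anywhere in this repo -- with dim=128, base=10000 and alpha=4 the effective base is
#  roughly 40,900, which lowers the per-dimension frequencies instead of rescaling positions.)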
23 | if seq_len > self.max_seq_len_cached: 24 | self.max_seq_len_cached = seq_len 25 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 26 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 27 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 28 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 29 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) 30 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) 31 | return ( 32 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 33 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 34 | ) -------------------------------------------------------------------------------- /scaled_rope/LlamaLinearScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LlamaLinearScaledRotaryEmbedding(torch.nn.Module): 4 | def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None): 5 | super().__init__() 6 | self.scale = scale 7 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 8 | self.register_buffer("inv_freq", inv_freq) 9 | 10 | # Build here to make `torch.jit.trace` work. 11 | self.max_seq_len_cached = max_position_embeddings 12 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 13 | t /= self.scale 14 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 15 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 16 | emb = torch.cat((freqs, freqs), dim=-1) 17 | dtype = torch.get_default_dtype() 18 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 19 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 20 | 21 | def forward(self, x, seq_len=None): 22 | # x: [bs, num_attention_heads, seq_len, head_size] 23 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 24 | if seq_len > self.max_seq_len_cached: 25 | self.max_seq_len_cached = seq_len 26 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 27 | t /= self.scale 28 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 29 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 30 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 31 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) 32 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) 33 | return ( 34 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 35 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 36 | ) -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run `accelerate config` first. 
pass --deepspeed to finetune.py if using DeepSpeed 4 | 5 | accelerate launch finetune.py \ 6 | --output-dir output/yarn-7b-64k \ 7 | --model NousResearch/Llama-2-7b-hf 8 | 9 | accelerate launch finetune.py \ 10 | --output-dir output/yarn-7b-128k \ 11 | --model output/yarn-7b-64k \ 12 | --max-train-steps 200 \ 13 | --scaling-factor 32 \ 14 | --seed 31337 15 | 16 | accelerate launch finetune.py \ 17 | --model NousResearch/Llama-2-13b-hf \ 18 | --output-dir output/yarn-13b-64k 19 | 20 | accelerate launch finetune.py \ 21 | --output-dir output/yarn-13b-128k \ 22 | --model output/yarn-13b-64k \ 23 | --max-train-steps 200 \ 24 | --scaling-factor 32 \ 25 | --seed 31337 26 | 27 | accelerate launch finetune.py \ 28 | --model NousResearch/Llama-2-70b-hf \ 29 | --output-dir output/yarn-70b-32k \ 30 | --learning-rate 0.00001 \ 31 | --lr-schedule constant \ 32 | --scaling-factor 8 \ 33 | --dataset emozilla/yarn-train-tokenized-8k-llama 34 | 35 | # ablations 36 | 37 | python3 truncate.py 8192 output/truncated-8k 38 | 39 | accelerate launch finetune.py \ 40 | --output-dir output/linear-7b-8k \ 41 | --model NousResearch/Llama-2-7b-hf \ 42 | --scaling-type linear \ 43 | --scaling-factor 2 \ 44 | --dataset output/truncated-8k 45 | 46 | accelerate launch finetune.py \ 47 | --output-dir output/ntk-7b-8k \ 48 | --model NousResearch/Llama-2-7b-hf \ 49 | --scaling-type ntk \ 50 | --scaling-factor 1 \ 51 | --rope-theta 20000 \ 52 | --dataset output/truncated-8k 53 | 54 | accelerate launch finetune.py \ 55 | --output-dir output/yarn-7b-8k \ 56 | --model NousResearch/Llama-2-7b-hf \ 57 | --scaling-factor 2 \ 58 | --dataset output/truncated-8k 59 | 60 | # mistral 61 | 62 | accelerate launch finetune.py \ 63 | --output-dir output/yarn-mistral-7b-64k \ 64 | --model mistralai/Mistral-7B-v0.1 \ 65 | --architecture mistral \ 66 | --scaling-factor 8 \ 67 | --max-position-embeddings 16384 \ 68 | --dataset emozilla/yarn-train-tokenized-16k-mistral \ 69 | --sliding-window-attention-schedule 65536 \ 70 | --lr-schedule constant \ 71 | --learning-rate 0.000001 \ 72 | --max-train-steps 1000 73 | 74 | accelerate launch finetune.py \ 75 | --output-dir output/yarn-mistral-7b-128k \ 76 | --model output/yarn-mistral-7b-64k \ 77 | --architecture mistral \ 78 | --scaling-factor 16 \ 79 | --max-position-embeddings 16384 \ 80 | --dataset emozilla/yarn-train-tokenized-16k-mistral \ 81 | --sliding-window-attention-schedule 131072 \ 82 | --lr-schedule constant \ 83 | --learning-rate 0.000001 \ 84 | --max-train-steps 500 \ 85 | --seed 31337 -------------------------------------------------------------------------------- /scaled_rope/LlamaDynamicScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module): 5 | def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None): 6 | super().__init__() 7 | self.ntk = ntk 8 | self.base = base 9 | self.dim = dim 10 | self.max_position_embeddings = max_position_embeddings 11 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 12 | self.register_buffer("inv_freq", inv_freq) 13 | 14 | # Build here to make `torch.jit.trace` work. 
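# (forward() supports two dynamic modes: when self.ntk is set, the rotary base itself is
#  recomputed from the current seq_len (dynamic NTK); otherwise the position indices are
#  rescaled by max_position_embeddings / seq_len (dynamic linear interpolation).)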
15 | self.max_seq_len_cached = max_position_embeddings 16 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 17 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 18 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 19 | emb = torch.cat((freqs, freqs), dim=-1) 20 | dtype = torch.get_default_dtype() 21 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 22 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 23 | 24 | def forward(self, x, seq_len=None): 25 | # x: [bs, num_attention_heads, seq_len, head_size] 26 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 27 | if seq_len > self.max_seq_len_cached: 28 | self.max_seq_len_cached = seq_len 29 | if self.ntk: 30 | base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2)) 31 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)) 32 | self.register_buffer("inv_freq", inv_freq) 33 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 34 | if not self.ntk: 35 | t *= self.max_position_embeddings / seq_len 36 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 37 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 38 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 39 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) 40 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) 41 | return ( 42 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 43 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 44 | ) -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # python eval/perplexity.py -m meta-llama/Llama-2-7b-hf --dataset pg19 --split test --feature text --save-tokenized output/pg19-test-tokenized 4 | PG19="--tokenized emozilla/pg19-test-tokenized" 5 | 6 | # python eval/perplexity.py -m meta-llama/Llama-2-7b-hf --dataset tau/scrolls --subset gov_report --split test --feature input --save-tokenized output/govreport-test-tokenized 7 | GOVREPORT="--tokenized emozilla/govreport-test-tokenized --dataset-min-tokens 16384 --samples 50" 8 | 9 | # python eval/perplexity.py -m meta-llama/Llama-2-7b-hf --dataset hoskinson-center/proof-pile --split test --feature text --save-tokenized output/proofpile-test-tokenized 10 | PROOFPILE="--tokenized emozilla/proofpile-test-tokenized --dataset-min-tokens 32768 --samples 50" 11 | PROOFPILE_LONG_SMALL="--tokenized emozilla/proofpile-test-tokenized --dataset-min-tokens 131072 --samples 10 --truncate" 12 | 13 | # python eval/perplexity.py -m mistralai/Mistral-7B-v0.1 --dataset hoskinson-center/proof-pile --split test --feature text --save-tokenized output/proofpile-test-tokenized-mistral 14 | PROOFPILE_LONG_SMALL_MISTRAL="--tokenized emozilla/proofpile-test-tokenized-mistral --dataset-min-tokens 131072 --samples 10 --truncate --split train" 15 | 16 | CUSTOM="--custom-model-together" 17 | 18 | python eval/perplexity.py \ 19 | ${PROOFPILE_LONG_SMALL} ${CUSTOM} \ 20 | --output-file data/proofpile-long-small.csv \ 21 | --min-tokens 2048 --max-tokens 131072 
--tokens-step 2048 --aggressive-memory \ 22 | -m NousResearch/CodeLlama-13b-hf \ 23 | -m NousResearch/Yarn-Llama-2-13b-64k \ 24 | -m NousResearch/Yarn-Llama-2-13b-128k \ 25 | -m togethercomputer/LLaMA-2-7B-32K \ 26 | -m NousResearch/CodeLlama-7b-hf \ 27 | -m NousResearch/Yarn-Llama-2-7b-64k \ 28 | -m NousResearch/Yarn-Llama-2-7b-128k 29 | 30 | python eval/perplexity.py \ 31 | ${GOVREPORT} ${CUSTOM} \ 32 | --output-file data/govreport.csv \ 33 | --min-tokens 32768 --max-tokens 32768 \ 34 | -m NousResearch/CodeLlama-13b-hf \ 35 | -m NousResearch/Yarn-Llama-2-13b-64k \ 36 | -m NousResearch/Yarn-Llama-2-13b-128k \ 37 | -m togethercomputer/LLaMA-2-7B-32K \ 38 | -m NousResearch/CodeLlama-7b-hf \ 39 | -m NousResearch/Yarn-Llama-2-7b-64k \ 40 | -m NousResearch/Yarn-Llama-2-7b-128k 41 | 42 | python eval/perplexity.py \ 43 | ${PROOFPILE_LONG_SMALL} ${CUSTOM} \ 44 | --output-file data/proofpile-long-small-8k.csv \ 45 | --min-tokens 2048 --max-tokens 16384 --tokens-step 2048 \ 46 | -m emozilla/Yarn-Llama-2-7b-8k \ 47 | -m emozilla/NTK-Llama-2-7b-8k \ 48 | -m conceptofmind/LLongMA-2-7b 49 | 50 | python eval/perplexity.py \ 51 | ${PROOFPILE_LONG_SMALL_MISTRAL} \ 52 | --output-file data/proofpile-long-small-mistral.csv \ 53 | --flash-attention --custom-model-mistral \ 54 | --min-tokens 2048 --max-tokens 131072 --tokens-step 2048 --aggressive-memory \ 55 | --sliding-window-attention 131072 \ 56 | -m NousResearch/Yarn-Mistral-7b-64k \ 57 | -m NousResearch/Yarn-Mistral-7b-128k \ 58 | -m amazon/MistralLite \ 59 | -m mistralai/Mistral-7B-v0.1 -------------------------------------------------------------------------------- /data/proofpile-long-small-mistral.csv: -------------------------------------------------------------------------------- 1 | ,2048,4096,6144,8192,10240,12288,14336,16384,18432,20480,22528,24576,26624,28672,30720,32768,34816,36864,38912,40960,43008,45056,47104,49152,51200,53248,55296,57344,59392,61440,63488,65536,67584,69632,71680,73728,75776,77824,79872,81920,83968,86016,88064,90112,92160,94208,96256,98304,100352,102400,104448,106496,108544,110592,112640,114688,116736,118784,120832,122880,124928,126976,129024,131072 2 | NousResearch/Yarn-Mistral-7b-64k,3.543721914,3.179555655,3.202048063,3.042807817,2.894950151,2.795645714,2.730502129,2.646996498,2.567584038,2.536862135,2.492991209,2.461726665,2.445360184,2.419542313,2.395036697,2.374049664,2.354360342,2.339094639,2.330737829,2.315256357,2.312551737,2.297796965,2.284247398,2.270946741,2.264210939,2.251716614,2.235651016,2.224263191,2.214893103,2.211028099,2.203557014,2.196890116,2.193943739,2.189531803,2.18980217,2.233218908,2.356315136,2.56439805,2.845679522,3.170113802,3.569000244,4.059425354,4.652839661,5.442737579,6.42795229,7.59900856,8.925559998,10.36964417,11.93230343,13.71412182,15.65219021,17.80123711,20.06212044,22.48859978,25.13574409,27.92183876,30.82406998,34.03783798,37.50074387,41.03094864,44.77576447,48.74263,52.95149612,57.42604446 3 | 
NousResearch/Yarn-Mistral-7b-128k,3.57974577,3.213460445,3.241326571,3.082948208,2.932858944,2.83460021,2.769934416,2.687001467,2.608081818,2.577638865,2.534418106,2.502552509,2.485738754,2.460346937,2.435867786,2.415941,2.396462202,2.381477594,2.372894526,2.357146025,2.354343891,2.339352608,2.325308323,2.311910391,2.305170536,2.292590141,2.276251316,2.264626741,2.254742622,2.250113487,2.24245739,2.235068083,2.230760574,2.223576069,2.214715242,2.212279081,2.209080696,2.203451633,2.198322058,2.196845531,2.192856312,2.192943335,2.191142321,2.190784454,2.187283993,2.181257963,2.17683506,2.175874472,2.175319672,2.176944494,2.178384781,2.177542686,2.174207926,2.175656319,2.17612648,2.177727222,2.175268173,2.179584503,2.181523085,2.182691097,2.18452549,2.18682313,2.187037706,2.187235117 4 | amazon/MistralLite,3.660849571,3.26819706,3.294190645,3.128564596,3.12820673,8.094698906,23.45741653,47.2832489,80.00630188,124.1926651,177.1727142,240.5597229,320.0137329,409.2606812,508.006134,633.5405273,767.4940796,889.2739258,1011.397034,1141.009888,1295.947021,1451.602173,1632.844604,1839.353638,2039.824829,2248.419434,2424.333984,2581.619141,2724.690186,2862.85376,2998.141357,3130.638184,3267.87793,3399.478271,3528.172363,3652.74707,3771.285889,3885.411377,4003.28418,4114.573242,4227.552734,4341.183105,4449.754883,4556.889648,4661.644531,4761.617676,4854.407227,4947.772949,5036.537598,5127.046387,5214.43457,5304.992676,5388.914551,5471.793945,5553.78125,5631.61377,5705.724609,5783.854492,5861.807617,5933.722168,6016.069824,6097.908203,6177.30127,6255.70459 5 | mistralai/Mistral-7B-v0.1,3.448711634,3.090441465,3.110674858,2.960596561,3.015338421,6.869444847,18.1113987,36.80331039,64.37407684,102.7547073,146.9998169,200.4661407,265.9177551,340.8301392,442.8686523,576.6227417,737.1204224,894.3130493,1050.718628,1211.67334,1409.42688,1631.725342,1829.34375,2031.636963,2216.33252,2404.873535,2569.363037,2719.29834,2864.481445,3004.575928,3141.500732,3274.869629,3413.874023,3546.929932,3676.311768,3801.981201,3920.933594,4034.686279,4151.763184,4262.827637,4377.522461,4490.838867,4597.029297,4702.352539,4807.625977,,,,,,,,,,,,,,,,,,, -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from functools import reduce 3 | import multiprocessing 4 | import argparse 5 | from typing import List 6 | from datasets import concatenate_datasets, load_dataset, load_from_disk 7 | from transformers import AutoTokenizer 8 | 9 | def main(args): 10 | if args.dataset is None or len(args.dataset[0]) == 0: 11 | raise RuntimeError("No datasets provided") 12 | datasets = args.dataset[0] 13 | 14 | splits = [x.split(",")[1] if len(x.split(",")) == 2 else "" for x in datasets] 15 | datasets = [x.split(",")[0] for x in datasets] 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 18 | if args.json: 19 | dataset = load_dataset("json", data_files=datasets)[args.split] 20 | if reduce(lambda x,y: x or len(y) > 0, splits, False): 21 | if len(datasets) > 1: 22 | raise RuntimeError("Can only use splitting on json datasets if there is exactly one input file") 23 | dataset = dataset.train_test_split(train_size=float(splits[0]), seed=args.seed)["train"] 24 | else: 25 | to_concatenate = [] 26 | for i in range(0, len(datasets)): 27 | try: 28 | loaded = load_from_disk(datasets[i]) 29 | except: 30 | loaded = load_dataset([i])[args.split] 31 | if len(splits[i]) > 0: 32 | 
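# a dataset passed as "name,fraction" keeps only the train side of a train_test_split made with that fraction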
loaded = loaded.train_test_split(train_size=float(splits[i]), seed=args.seed)["train"] 33 | to_concatenate.append(loaded) 34 | dataset = concatenate_datasets(to_concatenate) 35 | 36 | dataset = dataset.remove_columns([x for x in dataset.column_names if x not in [args.feature]]) 37 | 38 | tokenized_dataset = dataset.map( 39 | lambda example: tokenizer( 40 | [t + tokenizer.eos_token for t in example[args.feature]]), 41 | batched=True, 42 | num_proc=args.num_proc, 43 | remove_columns=[args.feature], 44 | ) 45 | 46 | block_size = args.sequence_length 47 | 48 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 49 | def group_texts(examples): 50 | # Concatenate all texts. 51 | concatenated_examples = { 52 | k: list(chain(*examples[k])) for k in examples.keys()} 53 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 54 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 55 | # customize this part to your needs. 56 | if total_length >= block_size: 57 | total_length = (total_length // block_size) * block_size 58 | # Split by chunks of max_len. 59 | result = { 60 | k: [t[i: i + block_size] 61 | for i in range(0, total_length, block_size)] 62 | for k, t in concatenated_examples.items() 63 | } 64 | return result 65 | 66 | train_dataset = tokenized_dataset.map( 67 | group_texts, batched=True, num_proc=args.num_proc, 68 | ) 69 | 70 | if args.output: 71 | train_dataset.save_to_disk(args.output) 72 | if args.push_to_hub: 73 | train_dataset.push_to_hub(args.push_to_hub, private=True) 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--dataset", action="append", nargs="+") 78 | parser.add_argument("--tokenizer", default="NousResearch/Llama-2-7b-hf") 79 | parser.add_argument("--sequence-length", type=int, default=8192) 80 | parser.add_argument("--feature", type=str, default="text") 81 | parser.add_argument("--split", type=str, default="train") 82 | parser.add_argument("--output", type=str) 83 | parser.add_argument("--push-to-hub", type=str) 84 | parser.add_argument("--seed", type=int, default=42) 85 | parser.add_argument("--json", action="store_true") 86 | parser.add_argument("--num-proc", type=int, 87 | default=multiprocessing.cpu_count()) 88 | main(parser.parse_args()) 89 | -------------------------------------------------------------------------------- /eval/quality.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy 3 | import sys 4 | import torch 5 | from datasets import load_dataset 6 | from transformers import AutoTokenizer 7 | from tqdm import tqdm 8 | from model_loader import load_model_and_apply_patches, add_args 9 | 10 | 11 | ZERO_SCROLLS_QUALITY_PROMPT = "You are provided a story and a multiple-choice question with 4 possible answers (marked by A, B, C, D). 
Choose the best answer by writing its corresponding letter (either A, B, C, or D).\n\nStory:\n{story}\n\nQuestion and Possible Answers:\n{question}\n (A) {a}\n (B) {b}\n (C) {c}\n (D) {d}" 12 | CHOICES = ["A", "B", "C", "D"] 13 | 14 | 15 | def get_prompt(sample): 16 | options = sample["options"] 17 | instruction = ZERO_SCROLLS_QUALITY_PROMPT.format( 18 | story=sample["article"], question=sample["question"], a=options[0], b=options[1], c=options[2], d=options[3]) 19 | return f"{instruction}\n\nAnswer: (" 20 | 21 | 22 | def main(args): 23 | models = [x[0] for x in args.model] 24 | 25 | tokenizer = AutoTokenizer.from_pretrained( 26 | models[0], model_max_length=sys.maxsize, trust_remote_code=True) 27 | tokenizer.pad_token = tokenizer.eos_token 28 | tokenizer.pad_token_id = tokenizer.eos_token_id 29 | 30 | dataset = load_dataset("emozilla/quality", split=args.split) 31 | dataset = dataset.map(lambda sample: {"prompt": get_prompt(sample)}) 32 | if args.max_tokens: 33 | dataset = dataset.filter(lambda sample: len( 34 | tokenizer(sample["prompt"]).input_ids) <= args.max_tokens) 35 | 36 | choice_tokens = [x[0] for x in tokenizer( 37 | CHOICES, add_special_tokens=False).input_ids] 38 | decoded_choice = tokenizer.decode( 39 | choice_tokens, clean_up_tokenization_spaces=True) 40 | 41 | results = [] 42 | for model in models: 43 | torch.cuda.empty_cache() 44 | 45 | loaded = load_model_and_apply_patches(model, args) 46 | 47 | correct_answers = 0 48 | i = 0 49 | max = len(dataset) if args.limit is None else args.limit 50 | bar = tqdm(total=max) 51 | while i < max: 52 | sample = dataset[i] 53 | tokenized_prompt = tokenizer(sample["prompt"], return_tensors="pt") 54 | input_ids = tokenized_prompt.input_ids.to("cuda") 55 | attention_mask = tokenized_prompt.attention_mask.to("cuda") 56 | 57 | output = loaded.generate(input_ids, attention_mask=attention_mask, 58 | max_new_tokens=1, return_dict_in_generate=True, output_scores=True, pad_token_id=tokenizer.eos_token_id) 59 | scores = output.scores[0][0] 60 | choice_scores = [x.cpu() for x in [scores[choice_tokens[0]], 61 | scores[choice_tokens[1]], scores[choice_tokens[2]], scores[choice_tokens[3]]]] 62 | selection = numpy.argmax([x.float().cpu() for x in choice_scores]) 63 | 64 | correct_answers += 1 if selection == sample["answer"] else 0 65 | 66 | if args.print_choices: 67 | print( 68 | f"Choice: {CHOICES[selection]} Correct: {CHOICES[sample['answer']]}") 69 | 70 | i += 1 71 | percent = (correct_answers / i) * 100.0 72 | 73 | bar.desc = f"{model}: {percent:.1f}%" 74 | bar.update() 75 | 76 | percent = correct_answers / max 77 | results.append(str(percent)) 78 | 79 | if args.output_file: 80 | with open(args.output_file, "w", encoding="utf-8") as f: 81 | f.write(",".join(models) + "\n") 82 | f.write(",".join(results) + "\n") 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("-m", "--model", action="append", nargs="+") 87 | parser.add_argument("--limit", type=int) 88 | parser.add_argument("--max-tokens", type=int) 89 | parser.add_argument("--split", type=str, default="validation") 90 | parser.add_argument("--print-choices", action="store_true") 91 | parser.add_argument("--output-file", type=str) 92 | 93 | args = add_args(parser).parse_args() 94 | main(args) 95 | -------------------------------------------------------------------------------- /data/proofpile-long-small-solar.csv: -------------------------------------------------------------------------------- 1 | 
,2048,4096,6144,8192,10240,12288,14336,16384,18432,20480,22528,24576,26624,28672,30720,32768,34816,36864,38912,40960,43008,45056,47104,49152,51200,53248,55296,57344,59392,61440,63488,65536,67584,69632,71680,73728,75776,77824,79872,81920,83968,86016,88064,90112,92160,94208,96256,98304,100352,102400,104448,106496,108544,110592,112640,114688,116736,118784,120832,122880,124928,126976,129024,131072 2 | mistralai/Mistral-7B-v0.1,3.448711634,3.090441465,3.110674858,2.960596561,3.015338421,6.869444847,18.1113987,36.80331039,64.37407684,102.7547073,146.9998169,200.4661407,265.9177551,340.8301392,442.8686523,576.6227417,737.1204224,894.3130493,1050.718628,1211.67334,1409.42688,1631.725342,1829.34375,2031.636963,2216.33252,2404.873535,2569.363037,2719.29834,2864.481445,3004.575928,3141.500732,3274.869629,3413.874023,3546.929932,3676.311768,3801.981201,3920.933594,4034.686279,4151.763184,4262.827637,4377.522461,4490.838867,4597.029297,4702.352539,4807.625977,,,,,,,,,,,,,,,,,,, 3 | Yarn-Mistral-7b-64k,3.543721914,3.179555655,3.202048063,3.042807817,2.894950151,2.795645714,2.730502129,2.646996498,2.567584038,2.536862135,2.492991209,2.461726665,2.445360184,2.419542313,2.395036697,2.374049664,2.354360342,2.339094639,2.330737829,2.315256357,2.312551737,2.297796965,2.284247398,2.270946741,2.264210939,2.251716614,2.235651016,2.224263191,2.214893103,2.211028099,2.203557014,2.196890116,2.193943739,2.189531803,2.18980217,2.233218908,2.356315136,2.56439805,2.845679522,3.170113802,3.569000244,4.059425354,4.652839661,5.442737579,6.42795229,7.59900856,8.925559998,10.36964417,11.93230343,13.71412182,15.65219021,17.80123711,20.06212044,22.48859978,25.13574409,27.92183876,30.82406998,34.03783798,37.50074387,41.03094864,44.77576447,48.74263,52.95149612,57.42604446 4 | Yarn-Mistral-7b-128k,3.57974577,3.213460445,3.241326571,3.082948208,2.932858944,2.83460021,2.769934416,2.687001467,2.608081818,2.577638865,2.534418106,2.502552509,2.485738754,2.460346937,2.435867786,2.415941,2.396462202,2.381477594,2.372894526,2.357146025,2.354343891,2.339352608,2.325308323,2.311910391,2.305170536,2.292590141,2.276251316,2.264626741,2.254742622,2.250113487,2.24245739,2.235068083,2.230760574,2.223576069,2.214715242,2.212279081,2.209080696,2.203451633,2.198322058,2.196845531,2.192856312,2.192943335,2.191142321,2.190784454,2.187283993,2.181257963,2.17683506,2.175874472,2.175319672,2.176944494,2.178384781,2.177542686,2.174207926,2.175656319,2.17612648,2.177727222,2.175268173,2.179584503,2.181523085,2.182691097,2.18452549,2.18682313,2.187037706,2.187235117 5 | upstage/SOLAR-10.7B-v1.0,3.4076623916625977,3.073754072189331,4.124830722808838,11.522461891174316,32.61447525024414,81.04706573486328,181.0179901123047,351.9505310058594,602.053466796875,874.5831909179688,1082.156494140625,1294.7906494140625,1461.424072265625,1611.6634521484375,1845.816650390625,2064.8994140625,2212.74853515625,2405.04541015625,2591.93359375,2761.818359375,2847.3662109375,2902.39599609375,3019.432373046875,3152.097412109375,3311.100341796875,3452.78173828125,3567.749755859375,3596.150390625,3603.256591796875,3670.601806640625,3755.050048828125 6 | 
Yarn-Solar-10b-32k,3.4276325702667236,3.087886333465576,3.111973762512207,2.950162172317505,2.8105783462524414,2.7172908782958984,2.653538703918457,2.5744435787200928,2.496849298477173,2.4663002490997314,2.4250009059906006,2.3932745456695557,2.377849817276001,2.35152006149292,2.327505350112915,2.307377338409424,2.2881836891174316,2.2747302055358887,2.278059244155884,2.299076557159424,2.407583713531494,2.5985310077667236,2.913794755935669,3.3971292972564697,4.003856658935547,4.71740198135376,5.495856285095215,6.412007808685303,7.422156810760498,8.540860176086426,9.84152889251709 7 | Yarn-Solar-10b-64k,3.4785315990448,3.1318323612213135,3.1546216011047363,2.9895522594451904,2.8481993675231934,2.753148078918457,2.68953275680542,2.6093976497650146,2.530197858810425,2.4994125366210938,2.4572999477386475,2.424976348876953,2.4085893630981445,2.381751537322998,2.3567190170288086,2.335132598876953,2.314568042755127,2.298326253890991,2.2893223762512207,2.273292303085327,2.269257068634033,2.2547223567962646,2.2408652305603027,2.2274065017700195,2.2199928760528564,2.2070305347442627,2.190474271774292,2.1788814067840576,2.1688458919525146,2.1640143394470215,2.155989408493042 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YaRN 2 | 3 | This repo contains the code and data for the YaRN context window extension method. 4 | 5 | ## Paper 6 | 7 | Paper (ICLR 2024): [YaRN: Efficient Context Window Extension of Large Language Models](https://openreview.net/forum?id=wHBfxhZu1u) 8 | Old Preprint [(arXiv)](https://arxiv.org/abs/2309.00071) 9 | 10 | ## Models 11 | 12 | ### LLaMA 13 | 14 | We publish variants of [Llama 2](https://about.fb.com/news/2023/07/llama-2/) fine-tuned with YaRN at 32K, 64K and 128K context window length. 15 | They are available under the Llama 2 license on 🤗 Hugging Face. 16 | 17 | | Size | Context | Link | 18 | | ---: | ------: | :----- | 19 | | 7B | 64K | [NousResearch/Yarn-Llama-2-7b-64k](https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k) | 20 | | 7B | 128K | [NousResearch/Yarn-Llama-2-7b-128k](https://huggingface.co/NousResearch/Yarn-Llama-2-7b-128k) | 21 | | 13B | 64K | [NousResearch/Yarn-Llama-2-13b-64k](https://huggingface.co/NousResearch/Yarn-Llama-2-13b-64k) | 22 | | 13B | 128K | [NousResearch/Yarn-Llama-2-13b-128k](https://huggingface.co/NousResearch/Yarn-Llama-2-13b-128k) | 23 | | 70B | 32K | [NousResearch/Yarn-Llama-2-70b-32k](https://huggingface.co/NousResearch/Yarn-Llama-2-70b-32k) | 24 | 25 | In addition, we also publish 8K context window versions of Llama 2 7B fine-tuned with [NTK-aware](https://huggingface.co/emozilla/NTK-Llama-2-7b-8k) and [YaRN](https://huggingface.co/emozilla/Yarn-Llama-2-7b-8k) (Table 1 in the conference paper). 26 | 27 | ### Mistral 28 | 29 | With the release of v2 of our paper we are also publishing 64K and 128K variants of [Mistral 7B v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1). 
30 | 31 | | Size | Context | Link | 32 | | ---: | ------: | :----- | 33 | | 7B | 64K | [NousResearch/Yarn-Mistral-7b-64k](https://huggingface.co/NousResearch/Yarn-Mistral-7b-64k) | 34 | | 7B | 128K | [NousResearch/Yarn-Mistral-7b-128k](https://huggingface.co/NousResearch/Yarn-Mistral-7b-128k) | 35 | 36 | ### SOLAR 37 | 38 | The [SOLAR 10.7B v1.0](https://huggingface.co/upstage/SOLAR-10.7B-v1.0) model utilizes [depth-up scaling](https://arxiv.org/abs/2312.15166) to add layers to [Mistral 7B v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1), which may improve long context performance on a per-parameter basis. 39 | We publish 32K and 64K variants. 40 | 41 | | Size | Context | Link | 42 | | ------: | ------: | :----- | 43 | | 10.7B | 32K | [NousResearch/Yarn-Solar-10b-32k](https://huggingface.co/NousResearch/Yarn-Solar-10b-32k) | 44 | | 10.7B | 64K | [NousResearch/Yarn-Solar-10b-64k](https://huggingface.co/NousResearch/Yarn-Solar-10b-64k) | 45 | 46 | ## Reproduction 47 | 48 | We strongly believe in open science, and thus publish all code and data to reproduce the results in our paper. 49 | To reproduce, clone the repository and perform a local installation. 50 | 51 | ```sh 52 | git clone https://github.com/jquesnelle/yarn 53 | cd yarn 54 | pip install -e . 55 | ``` 56 | 57 | ### Training 58 | 59 | To train the models, run `accelerate config` and enable DeepSpeed acceleration. `deepspeed/zero3.json` was the configuration file used for training. 60 | 61 | ```sh 62 | # ./train.sh 63 | ``` 64 | 65 | The tokenized training data is available on [🤗Hugging Face](https://huggingface.co/datasets/emozilla/pg_books-tokenized-bos-eos-chunked-65536) and was derived from the [pg19](https://huggingface.co/datasets/emozilla/pg19) dataset. 66 | For the Mistral models, a mix of the pretrain and fine-tune splits of [Long-Data-Collections](https://huggingface.co/datasets/togethercomputer/Long-Data-Collections) was used and the tokenized dataset is also available on [🤗Hugging Face](https://huggingface.co/datasets/emozilla/yarn-train-tokenized-16k-mistral). 67 | 68 | ### Evaluation 69 | 70 | To reproduce the evaluations, install [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with `pip install git+https://github.com/EleutherAI/lm-evaluation-harness` and then run the two provided scripts.
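`eval.sh` sweeps perplexity over increasing context lengths with `eval/perplexity.py` and writes the CSVs found under `data/`, while `eval-harness.sh` runs the lm-evaluation-harness tasks (ARC, HellaSwag, MMLU, TruthfulQA) whose JSON results are also stored under `data/`. As a quicker smoke test, a single sweep can be launched directly; the sketch below reuses flags from `eval.sh`, the output filename is arbitrary, and model-specific loader flags such as `--custom-model-together` or `--flash-attention` may additionally be required:

```sh
python eval/perplexity.py \
  --tokenized emozilla/proofpile-test-tokenized \
  --dataset-min-tokens 131072 --samples 10 --truncate \
  --min-tokens 2048 --max-tokens 65536 --tokens-step 2048 \
  --output-file data/proofpile-smoke-test.csv \
  -m NousResearch/Yarn-Llama-2-7b-64k
```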
71 | 72 | ```sh 73 | # ./eval.sh 74 | # ./eval-harness.sh 75 | ``` 76 | 77 | ### Citation 78 | 79 | ``` 80 | @inproceedings{ 81 | peng2024yarn, 82 | title={Ya{RN}: Efficient Context Window Extension of Large Language Models}, 83 | author={Bowen Peng and Jeffrey Quesnelle and Honglu Fan and Enrico Shippole}, 84 | booktitle={The Twelfth International Conference on Learning Representations}, 85 | year={2024}, 86 | url={https://openreview.net/forum?id=wHBfxhZu1u} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /scaled_rope/LlamaYaRNScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | # Inverse dim formula to find dim based on number of rotations 5 | def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): 6 | return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) 7 | 8 | # Find dim range bounds based on rotations 9 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): 10 | low = math.floor(find_correction_dim( 11 | low_rot, dim, base, max_position_embeddings)) 12 | high = math.ceil(find_correction_dim( 13 | high_rot, dim, base, max_position_embeddings)) 14 | return max(low, 0), min(high, dim-1) # Clamp values just in case 15 | 16 | def linear_ramp_mask(min, max, dim): 17 | if min == max: 18 | max += 0.001 # Prevent singularity 19 | 20 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 21 | ramp_func = torch.clamp(linear_func, 0, 1) 22 | return ramp_func 23 | 24 | def get_mscale(scale=1): 25 | if scale <= 1: 26 | return 1.0 27 | return 0.1 * math.log(scale) + 1.0 28 | 29 | class LlamaYaRNScaledRotaryEmbedding(torch.nn.Module): 30 | def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, original_max_position_embeddings=2048, extrapolation_factor=1, attn_factor=1, beta_fast=32, beta_slow=1, finetuned=False, device=None): 31 | super().__init__() 32 | 33 | self.dim = dim 34 | self.max_position_embeddings = max_position_embeddings 35 | self.base = base 36 | self.scale = scale 37 | self.original_max_position_embeddings = original_max_position_embeddings 38 | self.extrapolation_factor = extrapolation_factor 39 | self.attn_factor = attn_factor 40 | self.beta_fast = beta_fast 41 | self.beta_slow = beta_slow 42 | 43 | self.yarn(device) 44 | 45 | # Build here to make `torch.jit.trace` work. 46 | self.max_seq_len_cached = max_position_embeddings 47 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 48 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 49 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 50 | emb = torch.cat((freqs, freqs), dim=-1) 51 | dtype = torch.get_default_dtype() 52 | 53 | self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, None, :, :].to(dtype), persistent=False) 54 | self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, None, :, :].to(dtype), persistent=False) 55 | 56 | def forward(self, x, seq_len=None): 57 | # x: [bs, num_attention_heads, seq_len, head_size] 58 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
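# Unlike the Dynamic variants, inv_freq and mscale were fixed by yarn() at construction; only the cos/sin cache is extended for longer sequences.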
59 | if seq_len > self.max_seq_len_cached: 60 | self.max_seq_len_cached = seq_len 61 | 62 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 63 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 64 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 65 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 66 | 67 | self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, None, :, :].to(x.dtype), persistent=False) 68 | self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, None, :, :].to(x.dtype), persistent=False) 69 | return ( 70 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 71 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 72 | ) 73 | 74 | def yarn(self, device): 75 | pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 76 | inv_freq_extrapolation = 1.0 / pos_freqs 77 | inv_freq_interpolation = 1.0 / (self.scale * pos_freqs) 78 | 79 | low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings) 80 | inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation 81 | inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask 82 | 83 | self.register_buffer("inv_freq", inv_freq) 84 | self.mscale = float(get_mscale(self.scale) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation -------------------------------------------------------------------------------- /scaled_rope/LlamaPartNTKScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048): 5 | return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations 6 | 7 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): 8 | low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings)) 9 | high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings)) 10 | return max(low, 0), min(high, dim-1) #Clamp values just in case 11 | 12 | def linear_ramp_mask(min, max, dim): 13 | if min == max: 14 | max += 0.001 #Prevent singularity 15 | 16 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 17 | ramp_func = torch.clamp(linear_func, 0, 1) 18 | return ramp_func 19 | 20 | def find_newbase_ntk(dim, base=10000, scale=1): 21 | return base * scale ** (dim / (dim-2)) 22 | 23 | class LlamaPartNTKScaledRotaryEmbedding(torch.nn.Module): 24 | def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, ntk_factor=1, extrapolation_factor=1, original_max_position_embeddings=2048, device=None): 25 | super().__init__() 26 | 27 | #Interpolation constants found experimentally for LLaMA (might not be totally optimal though) 28 | #Do not change unless there is a good reason for doing so! 
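# All four constants are rotation counts over original_max_position_embeddings; find_correction_range converts them to dimension indices.
# The beta pair bounds the ramp blending linear interpolation with NTK-by-parts; the gamma pair bounds the ramp blending that result with unscaled (extrapolated) frequencies.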
29 | beta_0 = 1.25 30 | beta_1 = 0.75 31 | gamma_0 = 16 32 | gamma_1 = 2 33 | 34 | #Three RoPE extrapolation/interpolation methods 35 | inv_freq_base = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 36 | inv_freq_linear = 1.0 / (scale * (base ** (torch.arange(0, dim, 2).float().to(device) / dim))) 37 | inv_freq_ntk = 1.0 / (find_newbase_ntk(dim, base, scale) ** (torch.arange(0, dim, 2).float().to(device) / dim)) 38 | 39 | current_dtype = inv_freq_ntk.dtype 40 | current_device = inv_freq_ntk.device 41 | 42 | #Combine NTK and Linear 43 | low, high = find_correction_range(beta_0, beta_1, dim, base, original_max_position_embeddings) 44 | inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * ntk_factor 45 | inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask 46 | 47 | #Combine Extrapolation and NTK and Linear 48 | low, high = find_correction_range(gamma_0, gamma_1, dim, base, original_max_position_embeddings) 49 | inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).type(current_dtype).to(current_device)) * extrapolation_factor 50 | inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask 51 | 52 | self.register_buffer("inv_freq", inv_freq) 53 | 54 | # Build here to make `torch.jit.trace` work. 55 | self.max_seq_len_cached = max_position_embeddings 56 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 57 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 58 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 59 | emb = torch.cat((freqs, freqs), dim=-1) 60 | dtype = torch.get_default_dtype() 61 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 62 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 63 | 64 | def forward(self, x, seq_len=None): 65 | # x: [bs, num_attention_heads, seq_len, head_size] 66 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
67 | if seq_len > self.max_seq_len_cached: 68 | self.max_seq_len_cached = seq_len 69 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 70 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 71 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 72 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 73 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) 74 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) 75 | return ( 76 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 77 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 78 | ) 79 | -------------------------------------------------------------------------------- /scaled_rope/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def patch_llama_for_dynamic_scaled_rotary_embeddings(model, ntk): 5 | from .LlamaDynamicScaledRotaryEmbedding import LlamaDynamicScaledRotaryEmbedding 6 | for each in model.model.layers: 7 | each.self_attn.rotary_emb = LlamaDynamicScaledRotaryEmbedding( 8 | each.self_attn.head_dim, device=each.self_attn.rotary_emb.inv_freq.device, ntk=ntk) 9 | 10 | def patch_llama_for_dynamic_part_ntk_rotary_embeddings(model, finetuned): 11 | from .LlamaDynamicPartNTKScaledRotaryEmbedding import LlamaDynamicPartNTKScaledRotaryEmbedding 12 | for each in model.model.layers: 13 | each.self_attn.rotary_emb = LlamaDynamicPartNTKScaledRotaryEmbedding( 14 | each.self_attn.head_dim, finetuned=finetuned, device=each.self_attn.rotary_emb.inv_freq.device) 15 | 16 | def patch_llama_for_dynamic_yarn_rotary_embeddings(model, original_max_position_embeddings, finetuned): 17 | from .LlamaDynamicYaRNScaledRotaryEmbedding import LlamaDynamicYaRNScaledRotaryEmbedding 18 | for each in model.model.layers: 19 | each.self_attn.rotary_emb = LlamaDynamicYaRNScaledRotaryEmbedding( 20 | each.self_attn.head_dim, finetuned=finetuned, original_max_position_embeddings=original_max_position_embeddings, device=each.self_attn.rotary_emb.inv_freq.device) 21 | 22 | def patch_falcon_for_dynamic_part_ntk_rotary_embeddings(model): 23 | from .FalconDynamicPartNTKScaledRotaryEmbedding import FalconDynamicPartNTKScaledRotaryEmbedding 24 | for each in model.transformer.h: 25 | each.self_attention.maybe_rotary = FalconDynamicPartNTKScaledRotaryEmbedding(each.self_attention.head_dim) 26 | 27 | def patch_llama_for_ntk_scaled_rotary_embeddings(model, alpha): 28 | from .LlamaNTKScaledRotaryEmbedding import LlamaNTKScaledRotaryEmbedding 29 | for each in model.model.layers: 30 | each.self_attn.rotary_emb = LlamaNTKScaledRotaryEmbedding( 31 | each.self_attn.head_dim, alpha=alpha, device=each.self_attn.rotary_emb.inv_freq.device) 32 | 33 | 34 | def patch_llama_for_linear_scaled_rotary_embeddings(model, scale): 35 | from .LlamaLinearScaledRotaryEmbedding import LlamaLinearScaledRotaryEmbedding 36 | for each in model.model.layers: 37 | each.self_attn.rotary_emb = LlamaLinearScaledRotaryEmbedding( 38 | each.self_attn.head_dim, scale=scale, device=each.self_attn.rotary_emb.inv_freq.device) 39 | 40 | 41 | def patch_llama_for_part_ntk_scaled_rotary_embeddings(model, scale): 42 | from .LlamaPartNTKScaledRotaryEmbedding import LlamaPartNTKScaledRotaryEmbedding 43 | for each in model.model.layers: 44 | each.self_attn.rotary_emb = LlamaPartNTKScaledRotaryEmbedding( 45 | each.self_attn.head_dim, scale=scale, 
device=each.self_attn.rotary_emb.inv_freq.device) 46 | 47 | def patch_llama_for_yarn_scaled_rotary_embeddings(model, scale, original_max_position_embeddings): 48 | from .LlamaYaRNScaledRotaryEmbedding import LlamaYaRNScaledRotaryEmbedding 49 | for each in model.model.layers: 50 | each.self_attn.rotary_emb = LlamaYaRNScaledRotaryEmbedding( 51 | each.self_attn.head_dim, scale=scale, original_max_position_embeddings=original_max_position_embeddings, device=each.self_attn.rotary_emb.inv_freq.device) 52 | 53 | def patch_gptneox_for_scaled_rotary_embeddings(model): 54 | from .GPTNeoXDynamicScaledRotaryEmbedding import GPTNeoXDynamicScaledRotaryEmbedding 55 | for each in model.gpt_neox.layers: 56 | each.attention.rotary_emb = GPTNeoXDynamicScaledRotaryEmbedding( 57 | each.attention.rotary_ndims, model.config.max_position_embeddings, device=each.attention.rotary_emb.inv_freq.device) 58 | 59 | 60 | def patch_gptneox_for_ntk_scaled_rotary_embeddings(model, alpha): 61 | from .GPTNeoXNTKScaledRotaryEmbedding import GPTNeoXNTKScaledRotaryEmbedding 62 | for each in model.gpt_neox.layers: 63 | each.attention.rotary_emb = GPTNeoXNTKScaledRotaryEmbedding( 64 | each.attention.rotary_ndims, model.config.max_position_embeddings, alpha=alpha, device=each.attention.rotary_emb.inv_freq.device) 65 | 66 | 67 | def patch_gptneox_for_longer_sequences(model, max_positions): 68 | for each in model.gpt_neox.layers: 69 | each.attention.bias = torch.tril(torch.ones((max_positions, max_positions), dtype=each.attention.bias.dtype, device=each.attention.bias.device)).view( 70 | 1, 1, max_positions, max_positions 71 | ) 72 | 73 | def patch_llama_for_rerope(model, training_length, window): 74 | from .LlamaReRoPE import forward_with_rerope 75 | for each in model.model.layers: 76 | def forward(*args, **kwargs): 77 | return forward_with_rerope(each.self_attn, *args, **kwargs) 78 | 79 | each.self_attn.training_length = int(training_length) 80 | each.self_attn.window = int(window) 81 | # each.self_attn.forward = forward -------------------------------------------------------------------------------- /scaled_rope/LlamaDynamicYaRNScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | # Inverse dim formula to find dim based on number of rotations 5 | def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): 6 | return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) 7 | 8 | # Find dim range bounds based on rotations 9 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): 10 | low = math.floor(find_correction_dim( 11 | low_rot, dim, base, max_position_embeddings)) 12 | high = math.ceil(find_correction_dim( 13 | high_rot, dim, base, max_position_embeddings)) 14 | return max(low, 0), min(high, dim-1) # Clamp values just in case 15 | 16 | def linear_ramp_mask(min, max, dim): 17 | if min == max: 18 | max += 0.001 # Prevent singularity 19 | 20 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 21 | ramp_func = torch.clamp(linear_func, 0, 1) 22 | return ramp_func 23 | 24 | def get_mscale(scale=1): 25 | if scale <= 1: 26 | return 1.0 27 | return 0.1 * math.log(scale) + 1.0 28 | 29 | class LlamaDynamicYaRNScaledRotaryEmbedding(torch.nn.Module): 30 | def __init__(self, dim, max_position_embeddings=2048, base=10000, original_max_position_embeddings=2048, extrapolation_factor=1, attn_factor=1, beta_fast=32, 
beta_slow=1, finetuned=False, device=None): 31 | super().__init__() 32 | 33 | self.dim = dim 34 | self.max_position_embeddings = max_position_embeddings 35 | self.base = base 36 | self.original_max_position_embeddings = original_max_position_embeddings 37 | self.extrapolation_factor = extrapolation_factor 38 | self.attn_factor = attn_factor 39 | self.beta_fast = beta_fast 40 | self.beta_slow = beta_slow 41 | 42 | if finetuned: 43 | self.yarn(self.max_position_embeddings / self.original_max_position_embeddings, device) 44 | else: 45 | inv_freq = 1.0 / \ 46 | (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 47 | self.register_buffer("inv_freq", inv_freq) 48 | self.mscale = 1 49 | 50 | # Build here to make `torch.jit.trace` work. 51 | self.max_seq_len_cached = max_position_embeddings 52 | t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) 53 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 54 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 55 | emb = torch.cat((freqs, freqs), dim=-1) 56 | dtype = torch.get_default_dtype() 57 | 58 | self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, None, :, :].to(dtype), persistent=False) 59 | self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, None, :, :].to(dtype), persistent=False) 60 | 61 | def forward(self, x, seq_len=None): 62 | # x: [bs, num_attention_heads, seq_len, head_size] 63 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 64 | if seq_len > self.max_seq_len_cached: 65 | self.max_seq_len_cached = seq_len 66 | 67 | self.yarn(seq_len / self.original_max_position_embeddings, x.device) 68 | 69 | t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) 70 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 71 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 72 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 73 | 74 | self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, None, :, :].to(x.dtype), persistent=False) 75 | self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, None, :, :].to(x.dtype), persistent=False) 76 | return ( 77 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 78 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 79 | ) 80 | 81 | def yarn(self, scale, device): 82 | pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 83 | inv_freq_extrapolation = 1.0 / pos_freqs 84 | inv_freq_interpolation = 1.0 / (scale * pos_freqs) 85 | 86 | low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings) 87 | inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation 88 | inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask 89 | 90 | self.register_buffer("inv_freq", inv_freq) 91 | self.mscale = float(get_mscale(scale) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation -------------------------------------------------------------------------------- /eval/passkey.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import re 4 | import sys 5 | import torch 6 | import 
warnings 7 | from transformers import AutoTokenizer, pipeline 8 | from tqdm import tqdm, trange 9 | from tqdm.contrib import tenumerate 10 | from model_loader import * 11 | 12 | # from https://github.com/epfml/landmark-attention/blob/main/llama/run_test.py 13 | 14 | 15 | def generate_prompt(n_garbage): 16 | """Generates a text file and inserts an execute line at a random position.""" 17 | n_garbage_prefix = random.randint(0, n_garbage) 18 | n_garbage_suffix = n_garbage - n_garbage_prefix 19 | 20 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there." 21 | garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 22 | garbage_inf = " ".join([garbage] * 10000) 23 | assert len(garbage_inf) >= n_garbage 24 | garbage_prefix = garbage_inf[:n_garbage_prefix] 25 | garbage_suffix = garbage_inf[:n_garbage_suffix] 26 | pass_key = random.randint(1, 50000) 27 | information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key." 28 | final_question = "What is the pass key? The pass key is" 29 | lines = [ 30 | task_description, 31 | garbage_prefix, 32 | information_line, 33 | garbage_suffix, 34 | final_question 35 | ] 36 | return "\n".join(lines), pass_key 37 | 38 | 39 | def test_model(pipe, prompt_text, pass_key): 40 | response = pipe(prompt_text, num_return_sequences=1, max_new_tokens=10)[ 41 | 0]["generated_text"][len(prompt_text):] 42 | assert f"The pass key is {pass_key}" in prompt_text 43 | 44 | try: 45 | pass_key = int(re.search(r'\d+', response).group()) 46 | except: 47 | pass_key = response[:20] 48 | 49 | return pass_key 50 | 51 | 52 | def main(args): 53 | models = [x[0] for x in args.model] 54 | tokenizer = AutoTokenizer.from_pretrained( 55 | models[0], model_max_length=sys.maxsize, padding_side="right", trust_remote_code=True) 56 | 57 | if args.fixed_length: 58 | lengths = [args.fixed_length] 59 | tokens = [len(tokenizer.encode(generate_prompt(args.fixed_length)[0]))] 60 | print(f"Prompt is {tokens[0]} tokens") 61 | else: 62 | if args.tokens_step: 63 | tokens = [x for x in range( 64 | args.min_tokens, args.max_tokens + 1, args.tokens_step)] 65 | else: 66 | tokens = [args.min_tokens] 67 | while args.min_tokens < args.max_tokens: 68 | point = tokens[-1] * 2 69 | if point <= args.max_tokens: 70 | tokens.append(point) 71 | else: 72 | break 73 | 74 | lengths = [] 75 | last_n = 0 76 | for target in tqdm(tokens, desc="Determining sequence lengths"): 77 | num_tokens = 0 78 | n = last_n 79 | while num_tokens < target: 80 | last_n = n 81 | n += args.length_step 82 | prompt = generate_prompt(n)[0] 83 | num_tokens = len(tokenizer.encode(prompt)) 84 | lengths.append(last_n) 85 | 86 | results = [] 87 | for model in tqdm(models, desc="Model", leave=False): 88 | torch.cuda.empty_cache() 89 | 90 | loaded = load_model_and_apply_patches(model, args) 91 | 92 | pipe = pipeline("text-generation", model=loaded, 93 | tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id) 94 | 95 | result = [0] * len(lengths) 96 | for i, length in tenumerate(lengths, desc="Lengths", leave=False): 97 | for _ in trange(0, args.iterations, desc="Iterations", leave=False): 98 | prompt_text, pass_key = generate_prompt(length) 99 | num_tokens = len(pipe.tokenizer.encode(prompt_text)) 100 | answer = test_model(pipe, prompt_text, pass_key) 101 | if answer == pass_key: 102 | result[i] += 1 103 | result[i] /= args.iterations 104 | print(f"{model}: 
{tokens[i]}={int(result[i]*100)}%") 105 | 106 | result.insert(0, model) 107 | results.append(result) 108 | 109 | if args.output_file: 110 | with open(args.output_file, "w", encoding="utf-8") as f: 111 | f.write(f",{','.join([str(x) for x in tokens])}\n") 112 | for result in results: 113 | f.write(f"{','.join([str(x) for x in result])}\n") 114 | 115 | 116 | if __name__ == "__main__": 117 | warnings.simplefilter("ignore") 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("-m", "--model", action="append", nargs="+") 120 | parser.add_argument("--fixed-length", type=int) 121 | parser.add_argument("--max-tokens", type=int, default=8192) 122 | parser.add_argument("--min-tokens", type=int, default=256) 123 | parser.add_argument("--tokens-step", type=int) 124 | parser.add_argument("--length-step", type=int, default=128) 125 | parser.add_argument("--iterations", type=int, default=20) 126 | parser.add_argument("--output-file", type=str) 127 | main(add_args(parser).parse_args()) 128 | -------------------------------------------------------------------------------- /scaled_rope/FalconDynamicPartNTKScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) 5 | def rotate_half(x): 6 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 7 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 8 | 9 | def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048): 10 | return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations 11 | 12 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): 13 | low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings)) 14 | high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings)) 15 | return max(low, 0), min(high, dim-1) #Clamp values just in case 16 | 17 | def linear_ramp_mask(min, max, dim): 18 | if min == max: 19 | max += 0.001 #Prevent singularity 20 | 21 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 22 | ramp_func = torch.clamp(linear_func, 0, 1) 23 | return ramp_func 24 | 25 | def find_newbase_ntk(dim, base=10000, scale=1): 26 | return base * scale ** (dim / (dim-2)) 27 | 28 | class FalconDynamicPartNTKScaledRotaryEmbedding(torch.nn.Module): 29 | """Implementation of RotaryEmbedding from GPT-NeoX. 30 | This implementation is design to operate on queries and keys that are compatible with 31 | [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). 
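Frequencies are recomputed on the fly with dynamic part-NTK scaling once the incoming sequence length reaches max_position_embeddings.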
32 | """ 33 | 34 | def __init__( 35 | self, 36 | head_dim: int, 37 | base=10000, 38 | max_position_embeddings=2048, 39 | ntk_factor=1, 40 | extrapolation_factor=1, 41 | ): 42 | super().__init__() 43 | inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) 44 | self.register_buffer("inv_freq", inv_freq, persistent=False) 45 | self.head_dim = head_dim 46 | self.seq_len_cached = None 47 | self.batch_size_cached = None 48 | self.cos_cached: torch.Tensor | None = None 49 | self.sin_cached: torch.Tensor | None = None 50 | self.base = base 51 | self.ntk_factor = ntk_factor 52 | self.extrapolation_factor = extrapolation_factor 53 | self.max_position_embeddings = max_position_embeddings 54 | 55 | def cos_sin( 56 | self, 57 | seq_len: int, 58 | device="cuda", 59 | dtype=torch.bfloat16, 60 | ) -> torch.Tensor: 61 | if seq_len != self.seq_len_cached: 62 | self.seq_len_cached = seq_len 63 | 64 | if seq_len >= self.max_position_embeddings: 65 | #Interpolation constants found experimentally for LLaMA (might not be totally optimal though) 66 | #Do not change unless there is a good reason for doing so! 67 | beta_0 = 1.25 68 | beta_1 = 0.75 69 | gamma_0 = 16 70 | gamma_1 = 2 71 | 72 | # the "dynamic" part 73 | scale = seq_len / self.max_position_embeddings 74 | 75 | #Three RoPE extrapolation/interpolation methods 76 | inv_freq_base = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim)) 77 | inv_freq_linear = 1.0 / (scale * (self.base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))) 78 | inv_freq_ntk = 1.0 / (find_newbase_ntk(self.head_dim, self.base, scale) ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim)) 79 | 80 | #Combine NTK and Linear 81 | low, high = find_correction_range(beta_0, beta_1, self.head_dim, self.base, self.max_position_embeddings) 82 | inv_freq_mask = (1 - linear_ramp_mask(low, high, self.head_dim // 2).type(dtype).to(device)) * self.ntk_factor 83 | inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask 84 | 85 | #Combine Extrapolation and NTK and Linear 86 | low, high = find_correction_range(gamma_0, gamma_1, self.head_dim, self.base, self.max_position_embeddings) 87 | inv_freq_mask = (1 - linear_ramp_mask(low, high, self.head_dim // 2).type(dtype).to(device)) * self.extrapolation_factor 88 | inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask 89 | 90 | self.register_buffer("inv_freq", inv_freq, persistent=False) 91 | 92 | t = torch.arange(seq_len, device=device).type_as(self.inv_freq) 93 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 94 | emb = torch.cat((freqs, freqs), dim=-1).to(device) 95 | 96 | if dtype in [torch.float16, torch.bfloat16]: 97 | emb = emb.float() 98 | 99 | self.cos_cached = emb.cos()[None, :, :] 100 | self.sin_cached = emb.sin()[None, :, :] 101 | 102 | self.cos_cached = self.cos_cached.type(dtype) 103 | self.sin_cached = self.sin_cached.type(dtype) 104 | 105 | return self.cos_cached, self.sin_cached 106 | 107 | def forward(self, q, k): 108 | batch, seq_len, head_dim = q.shape 109 | cos, sin = self.cos_sin(seq_len, q.device, q.dtype) 110 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) -------------------------------------------------------------------------------- /data/proofpile-long-small.csv: -------------------------------------------------------------------------------- 1 | 
,2048,4096,6144,8192,10240,12288,14336,16384,18432,20480,22528,24576,26624,28672,30720,32768,34816,36864,38912,40960,43008,45056,47104,49152,51200,53248,55296,57344,59392,61440,63488,65536,67584,69632,71680,73728,75776,77824,79872,81920,83968,86016,88064,90112,92160,94208,96256,98304,100352,102400,104448,106496,108544,110592,112640,114688,116736,118784,120832,122880,124928,126976,129024,131072 2 | NousResearch/CodeLlama-13b-hf,4.273087502,3.782966375,3.747891188,3.541493893,3.299672842,3.163071394,3.088324308,2.97326231,2.869639635,2.830187321,2.766177654,2.732852697,2.71192503,2.675392389,2.651535273,2.626380444,2.595819712,2.57420063,2.565205336,2.54873085,2.548609257,2.531494856,2.513749123,2.49823904,2.494235754,2.47132349,2.451201439,2.441103935,2.429980993,2.42084527,2.411242485,2.406718254,2.401376247,2.392701626,2.383884907,2.382179976,2.376464128,2.372499466,2.366471529,2.36151576,2.366015911,2.365545034,2.366862535,2.367699623,2.362401485,2.360276222,2.362963676,2.368350983,2.375911474,2.381812572,2.387961864,2.390213966,2.393273354,2.40162158,2.408063889,2.417190552,2.431198835,2.448834896,2.460363865,2.475495815,2.492848873,2.504173756,2.522003889,2.539400101 3 | NousResearch/Yarn-Llama-2-13b-64k,3.806912899,3.401254654,3.407855988,3.251551151,3.067051172,2.947519541,2.885066032,2.790714741,2.706659317,2.674013376,2.620252371,2.590349197,2.572107077,2.539562941,2.519123554,2.495951176,2.469627142,2.448579311,2.439432859,2.42330575,2.422192097,2.404289961,2.387158632,2.372862101,2.367934942,2.348040104,2.33080554,2.320837259,2.30927825,2.300942659,2.291547298,2.288551807,2.289961576,2.297888994,2.351677656,2.57955718,3.007681847,3.473647833,4.030450344,4.767449379,5.804921627,7.070561409,8.51751709,10.1455431,12.00591183,14.04304886,16.28896713,18.62060165,21.09352684,23.62802124,26.40208626,29.29329681,32.39636993,35.62049103,38.97951889,42.42362595,46.07492828,50.00058746,53.85381699,58.07590485,62.6740799,67.68762207,73.16788483,78.88412476 4 | NousResearch/Yarn-Llama-2-13b-128k,3.860970736,3.445932865,3.448999405,3.290391445,3.10351944,2.983279943,2.921168089,2.826341867,2.741844416,2.709266186,2.654923916,2.624872208,2.606542826,2.573714733,2.552715063,2.528698444,2.501633883,2.480374336,2.470992088,2.454633236,2.453536034,2.435217381,2.417454004,2.40269208,2.397391081,2.377074718,2.3593328,2.348656893,2.336751699,2.327668667,2.317125559,2.311721802,2.305166483,2.294796467,2.284498453,2.280540466,2.27330637,2.268487453,2.26162529,2.254907846,2.254883528,2.252135038,2.24918437,2.246737003,2.238143206,2.231673002,2.22956872,2.230200291,2.231621504,2.232100725,2.232689619,2.230055332,2.227398157,2.228141785,2.225455523,2.224666595,2.226955652,2.229412079,2.228803635,2.231298923,2.234647036,2.23429656,2.237049818,2.239109278 5 | 
togethercomputer/LLaMA-2-7B-32K,4.155766964,3.695604801,3.69126153,3.506507397,3.287930489,3.151707888,3.082172155,2.973427534,2.87671876,2.838195324,2.775708914,2.741379976,2.720355511,2.682943344,2.659923792,2.635623455,3.138739586,4.002960682,5.53152132,7.911962032,10.89597988,14.6881752,19.60680199,25.84499931,34.02965927,44.28824234,58.05148315,75.64293671,97.0128479,122.897522,153.1299286,187.6531525,229.0131073,274.9129639,329.1715088,394.709259,467.8532715,549.4265137,643.2252808,742.9885254,857.225708,969.8421631,1095.236816,1230.364868,1362.941528,1507.267578,1692.917969,1920.873657,2199.592285,2507.515869,2852.171143,3203.650635,3574.250977,3996.98584,4537.44873,5184.702637,6004.256348,6606.742676,7038.255859,7528.726562,8116.672852,8968.637695,10048.38281,11458.37891 6 | NousResearch/CodeLlama-7b-hf,4.46213007,3.954604864,3.93326354,3.714056969,3.456939936,3.310533762,3.228844643,3.107144594,2.998012304,2.954304457,2.886570454,2.854606152,2.833822489,2.794210434,2.77005291,2.744535208,2.714225054,2.692722321,2.685347557,2.670629263,2.67561245,2.660611629,2.645694733,2.630815983,2.628546476,2.605154514,2.585397959,2.576208591,2.565323353,2.556936026,2.549165726,2.546676874,2.543354273,2.53632021,2.529965162,2.530725718,2.526273012,2.524932146,2.520926237,2.518762827,2.526155233,2.528147936,2.531188011,2.534967422,2.529437304,2.527860165,2.531913042,2.538416147,2.547965765,2.554723978,2.561625957,2.563698053,2.565192461,2.571644068,2.573992252,2.579402924,2.591320276,2.60629487,2.616500378,2.632394314,2.651987553,2.665081501,2.686140537,2.705837965 7 | NousResearch/Yarn-Llama-2-7b-64k,4.140674114,3.68923521,3.686704159,3.506547213,3.29269886,3.16106081,3.094652176,2.987483501,2.892051935,2.85571456,2.794786453,2.760071754,2.740739346,2.703416824,2.680655718,2.654691696,2.623630524,2.601610899,2.59185791,2.57383728,2.57201004,2.551998854,2.532728672,2.516901731,2.511267185,2.489012718,2.469356775,2.457770824,2.444830179,2.43555069,2.424718618,2.421703339,2.427140951,2.441857338,2.490459681,2.610041857,2.92416954,3.400960207,3.937628746,4.631416321,5.472755909,6.437920094,7.533967972,8.803408623,10.2223444,11.78959942,13.58889198,15.5106535,17.6680088,20.011549,22.51984406,25.25511551,28.12402344,31.22803879,34.65332413,38.36104584,42.42313004,46.62594604,50.83547211,,,,, 8 | NousResearch/Yarn-Llama-2-7b-128k,4.213690758,3.752393484,3.748307467,3.562077045,3.342643738,3.209561348,3.142821312,3.035613537,2.938964605,2.902445078,2.840379715,2.805624247,2.785522223,2.747549295,2.724321604,2.698002338,2.666688919,2.644493818,2.63421607,2.615665674,2.613620996,2.592806339,2.572926044,2.556559086,2.550908089,2.527721167,2.507618904,2.495225668,2.481578827,2.471572161,2.460015297,2.453372717,2.446470499,2.435064077,2.423389673,2.419371367,2.411088943,2.405922651,2.398278236,2.390570879,2.390710831,2.387418032,2.384321451,2.381688356,2.371665239,2.364724159,2.362744331,2.363505602,2.365161896,2.365767002,2.36630702,2.363711119,2.360804796,2.361681461,2.358406067,2.357491255,2.360047102,2.36299324,2.36221838,2.364833593,2.368324041,2.367149591,2.370092154,2.372046709 -------------------------------------------------------------------------------- /scaled_rope/LlamaDynamicPartNTKScaledRotaryEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048): 6 | # Inverse dim formula to find number of rotations 7 | return (dim * 
math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) 8 | 9 | 10 | def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): 11 | low = math.floor(find_correction_factor( 12 | low_rot, dim, base, max_position_embeddings)) 13 | high = math.ceil(find_correction_factor( 14 | high_rot, dim, base, max_position_embeddings)) 15 | return max(low, 0), min(high, dim-1) # Clamp values just in case 16 | 17 | 18 | def linear_ramp_mask(min, max, dim): 19 | if min == max: 20 | max += 0.001 # Prevent singularity 21 | 22 | linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 23 | ramp_func = torch.clamp(linear_func, 0, 1) 24 | return ramp_func 25 | 26 | 27 | def find_newbase_ntk(dim, base=10000, scale=1): 28 | return base * scale ** (dim / (dim-2)) 29 | 30 | 31 | class LlamaDynamicPartNTKScaledRotaryEmbedding(torch.nn.Module): 32 | def __init__(self, dim, max_position_embeddings=2048, original_max_position_embeddings=2048, base=10000, ntk_factor=1, extrapolation_factor=1, finetuned=False, device=None): 33 | super().__init__() 34 | self.dim = dim 35 | self.base = base 36 | self.ntk_factor = ntk_factor 37 | self.extrapolation_factor = extrapolation_factor 38 | self.max_position_embeddings = max_position_embeddings 39 | if finetuned: 40 | self.ntk(self.max_position_embeddings / original_max_position_embeddings, device) 41 | else: 42 | inv_freq = 1.0 / \ 43 | (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 44 | self.register_buffer("inv_freq", inv_freq) 45 | 46 | # Build here to make `torch.jit.trace` work. 47 | self.max_seq_len_cached = max_position_embeddings 48 | t = torch.arange(self.max_seq_len_cached, 49 | device=self.inv_freq.device, dtype=self.inv_freq.dtype) 50 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 51 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 52 | emb = torch.cat((freqs, freqs), dim=-1) 53 | dtype = torch.get_default_dtype() 54 | self.register_buffer("cos_cached", emb.cos()[ 55 | None, None, :, :].to(dtype), persistent=False) 56 | self.register_buffer("sin_cached", emb.sin()[ 57 | None, None, :, :].to(dtype), persistent=False) 58 | 59 | def forward(self, x, seq_len=None): 60 | # x: [bs, num_attention_heads, seq_len, head_size] 61 | # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 62 | if seq_len > self.max_seq_len_cached: 63 | self.max_seq_len_cached = seq_len 64 | 65 | self.ntk(seq_len / self.max_position_embeddings, x.device) 66 | 67 | t = torch.arange(self.max_seq_len_cached, 68 | device=x.device, dtype=self.inv_freq.dtype) 69 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 70 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 71 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 72 | self.register_buffer("cos_cached", emb.cos()[ 73 | None, None, :, :].to(x.dtype), persistent=False) 74 | self.register_buffer("sin_cached", emb.sin()[ 75 | None, None, :, :].to(x.dtype), persistent=False) 76 | return ( 77 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 78 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 79 | ) 80 | 81 | def ntk(self, scale, device): 82 | 83 | # Interpolation constants found experimentally for LLaMA (might not be totally optimal though) 84 | # Do not change unless there is a good reason for doing so! 
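# beta_0/beta_1 give the rotation-count range passed to find_correction_range() for the first ramp mask, which blends the linearly-interpolated frequencies with the NTK-scaled ones;
# gamma_0/gamma_1 give the range for the second ramp mask, which blends the original (unscaled) base frequencies back in so the highest-frequency dimensions are extrapolated rather than interpolated.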
85 | beta_0 = 1.25 86 | beta_1 = 0.75 87 | gamma_0 = 16 88 | gamma_1 = 2 89 | 90 | # Three RoPE extrapolation/interpolation methods 91 | inv_freq_base = 1.0 / \ 92 | (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 93 | inv_freq_linear = 1.0 / \ 94 | (scale * (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))) 95 | inv_freq_ntk = 1.0 / (find_newbase_ntk(self.dim, self.base, scale) 96 | ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 97 | 98 | current_dtype = inv_freq_ntk.dtype 99 | current_device = inv_freq_ntk.device 100 | 101 | # Combine NTK and Linear 102 | low, high = find_correction_range( 103 | beta_0, beta_1, self.dim, self.base, self.max_position_embeddings) 104 | inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 105 | 2).type(current_dtype).to(current_device)) * self.ntk_factor 106 | inv_freq = inv_freq_linear * \ 107 | (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask 108 | 109 | # Combine Extrapolation and NTK and Linear 110 | low, high = find_correction_range( 111 | gamma_0, gamma_1, self.dim, self.base, self.max_position_embeddings) 112 | inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type( 113 | current_dtype).to(current_device)) * self.extrapolation_factor 114 | inv_freq = inv_freq * (1 - inv_freq_mask) + \ 115 | inv_freq_base * inv_freq_mask 116 | 117 | self.register_buffer("inv_freq", inv_freq) 118 | -------------------------------------------------------------------------------- /eval/passkey_hard.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import re 4 | import sys 5 | import torch 6 | import warnings 7 | from transformers import AutoTokenizer, pipeline 8 | from tqdm import tqdm, trange 9 | from tqdm.contrib import tenumerate 10 | from model_loader import * 11 | from datasets import load_dataset 12 | import random 13 | #import pickle 14 | import json 15 | 16 | # from https://github.com/epfml/landmark-attention/blob/main/llama/run_test.py 17 | 18 | 19 | def order(i): 20 | if i % 10 == 1 and i % 100 != 11: 21 | return str(i) + "st" 22 | elif i % 10 == 2 and i % 100 != 12: 23 | return str(i) + "nd" 24 | elif i % 10 == 3 and i % 100 != 13: 25 | return str(i) + "rd" 26 | else: 27 | return str(i) + "th" 28 | 29 | def generate_prompt(docs, num_keys=1): 30 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there." 31 | pass_keys = [random.randint(1, 50000) for _ in range(num_keys)] 32 | start_pos = sorted([random.randint(1, len(docs)) for _ in range(num_keys)]) 33 | information_lines = [f"The {order(i+1)} pass key is {pass_key}. Remember it. {pass_key} is the {order(i+1)} pass key." for i, pass_key in enumerate(pass_keys)] 34 | retrieve_number = random.randint(0, num_keys - 1) 35 | final_question = f"What is the {order(retrieve_number + 1)} pass key?
The {order(retrieve_number + 1)} pass key is" 36 | lines = [task_description] 37 | prev = 0 38 | for line, pos in zip(information_lines, start_pos): 39 | lines.append("".join(docs[prev:pos])) 40 | lines.append(line) 41 | prev = pos 42 | lines.append("".join(docs[prev:])) 43 | lines.append(final_question) 44 | 45 | return "\n".join(lines), pass_keys, start_pos, retrieve_number 46 | 47 | 48 | def test_model(pipe, prompt_text): 49 | response = pipe(prompt_text, num_return_sequences=1, max_new_tokens=10)[ 50 | 0]["generated_text"][len(prompt_text):] 51 | 52 | try: 53 | pass_key = int(re.search(r'\d+', response).group()) 54 | except: 55 | pass_key = response[:20] 56 | 57 | return pass_key, response 58 | 59 | def construct_junk(data, length, tokenizer): 60 | token_count = 0 61 | docs = [] 62 | length = length or 8192 63 | 64 | while token_count < length: 65 | sample = random.choice(data)["text"] 66 | toks = tokenizer(sample, return_offsets_mapping=True) 67 | offsets = [(i, j) for i, j in toks["offset_mapping"] if i < j] 68 | num_tok_to_add = min(length - token_count, len(offsets)) 69 | pretokenized = [sample[i:j] for i, j in offsets[:num_tok_to_add]] 70 | docs.extend(pretokenized) 71 | token_count += num_tok_to_add 72 | 73 | return docs 74 | 75 | 76 | def main(args): 77 | models = [x[0] for x in args.model] 78 | tokenizer = AutoTokenizer.from_pretrained( 79 | models[0], model_max_length=sys.maxsize, padding_side="right", trust_remote_code=True) 80 | 81 | data = load_dataset(args.dataset)[args.split] 82 | junks = construct_junk(data, args.fixed_length, tokenizer) 83 | 84 | # We restrict tokens to a small subset: digits, eos and continuous spaces/linebreaks 85 | # This is to prevent continuations like " is a special number" blah blah blah... 86 | if args.restrict_tokens: 87 | vocab = tokenizer.vocab 88 | 89 | escape_char = "▁" # for Llama family 90 | 91 | digit_tokens = [vocab[a] for a in vocab.keys() if a.lstrip(escape_char).isdigit()] 92 | # Add EOS 93 | digit_tokens.append(vocab[tokenizer.eos_token]) 94 | # Add spaces/linebreaks 95 | extra = [vocab[a] for a in vocab.keys() if a.strip(" \n" + escape_char) == ""] 96 | digit_tokens.extend(extra) 97 | 98 | mask = torch.ones(tokenizer.vocab_size, dtype=torch.bool) 99 | mask[digit_tokens] = 0 100 | 101 | def filter_digits(module, input, output): 102 | output.logits[..., mask[:output.logits.size(-1)]] = -1e4 103 | 104 | print(f"Decoding restricted to {len(digit_tokens)} tokens.") 105 | 106 | 107 | results = [] 108 | success_count = 0 109 | for model in tqdm(models, desc="Model", leave=False): 110 | torch.cuda.empty_cache() 111 | 112 | loaded = load_model_and_apply_patches(model, args) 113 | if args.restrict_tokens: 114 | loaded.register_forward_hook(filter_digits) 115 | 116 | pipe = pipeline("text-generation", model=loaded, 117 | tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id) 118 | 119 | for _ in trange(0, args.iterations, desc="Iterations", leave=False): 120 | prompt_text, pass_keys, start_pos, target = generate_prompt(junks, args.num_keys) 121 | num_tokens = len(pipe.tokenizer.encode(prompt_text)) 122 | answer, return_text = test_model(pipe, prompt_text) 123 | passed = str(answer).startswith(str(pass_keys[target])) 124 | result = {"prompt_text": prompt_text, "start_pos": start_pos, "pass_keys": pass_keys, "return_text": return_text, "passed": passed} 125 | success_count += passed 126 | results.append(result) 127 | 128 | results.append({"original_prompt": junks}) 129 | print(f"Iteration: {args.iterations}") 130 | print(f"Successes: 
{success_count}") 131 | 132 | if args.output_file: 133 | with open(args.output_file, "w") as f: 134 | json.dump(results, f) 135 | 136 | 137 | if __name__ == "__main__": 138 | warnings.simplefilter("ignore") 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument("-m", "--model", action="append", nargs="+") 141 | parser.add_argument("--fixed-length", type=int, default=8192) 142 | parser.add_argument("--restrict-tokens", type=bool, default=True) 143 | parser.add_argument("--num-keys", type=int, default=1) 144 | parser.add_argument("--iterations", type=int, default=20) 145 | parser.add_argument("--output-file", type=str) 146 | parser.add_argument("--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample") 147 | parser.add_argument("--split", type=str, default="train") 148 | main(add_args(parser).parse_args()) 149 | -------------------------------------------------------------------------------- /scaled_rope/LlamaReRoPE.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | # ReRoPE (Rectified Rotary Position Embeddings) 3 | # 链接:https://kexue.fm/archives/9708 4 | # transformers 4.31.0 测试通过 5 | 6 | import torch 7 | import numpy as np 8 | 9 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 10 | from transformers.models.llama.modeling_llama import rotate_half 11 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 12 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 13 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 14 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 15 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 16 | q_embed = (q * cos[:, :, -q.shape[2]:]) + (rotate_half(q) * sin[:, :, -q.shape[2]:]) if q is not None else None 17 | k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None 18 | return q_embed, k_embed 19 | 20 | 21 | def forward_with_rerope( 22 | self, 23 | hidden_states: torch.Tensor, 24 | attention_mask=None, 25 | position_ids=None, 26 | past_key_value=None, 27 | output_attentions=False, 28 | use_cache=False, 29 | ): 30 | from transformers.models.llama.modeling_llama import repeat_kv, F, nn, math 31 | bsz, q_len, _ = hidden_states.size() 32 | 33 | if self.pretraining_tp > 1: 34 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp 35 | query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) 36 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 37 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 38 | 39 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] 40 | query_states = torch.cat(query_states, dim=-1) 41 | 42 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] 43 | key_states = torch.cat(key_states, dim=-1) 44 | 45 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] 46 | value_states = torch.cat(value_states, dim=-1) 47 | 48 | else: 49 | query_states = self.q_proj(hidden_states) 50 | key_states = self.k_proj(hidden_states) 51 | value_states = self.v_proj(hidden_states) 52 | 53 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 54 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 55 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 
self.head_dim).transpose(1, 2) 56 | query_states *= ((position_ids + 1)[:, None, :, None].log() / np.log(self.training_length)).clip(1).to(query_states.dtype) 57 | 58 | kv_seq_len = key_states.shape[-2] 59 | if past_key_value is not None: 60 | kv_seq_len += past_key_value[0].shape[-2] 61 | # reuse k, v, self_attention 62 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 63 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 64 | position_ids = torch.cat([past_key_value[2], position_ids], dim=1) 65 | 66 | past_key_value = (key_states, value_states, position_ids) if use_cache else None 67 | 68 | if q_len == 1: 69 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 70 | position_ids = (position_ids[:, -1] - position_ids).clip(max=self.window) 71 | _, key_states = apply_rotary_pos_emb(None, key_states, cos, -sin, position_ids) 72 | key_states = repeat_kv(key_states, self.num_key_value_groups) 73 | value_states = repeat_kv(value_states, self.num_key_value_groups) 74 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 75 | else: 76 | cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, self.window)) 77 | query_states1, key_states1 = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 78 | query_states2, _ = apply_rotary_pos_emb(query_states, None, cos, sin, position_ids * 0 + self.window) 79 | 80 | # repeat k/v heads if n_kv_heads < n_heads 81 | key_states1 = repeat_kv(key_states1, self.num_key_value_groups) 82 | key_states2 = repeat_kv(key_states, self.num_key_value_groups) 83 | value_states = repeat_kv(value_states, self.num_key_value_groups) 84 | 85 | attn_weights1 = torch.matmul(query_states1, key_states1.transpose(2, 3)) / math.sqrt(self.head_dim) 86 | attn_weights2 = torch.matmul(query_states2, key_states2.transpose(2, 3)) / math.sqrt(self.head_dim) 87 | rectified_mask = (position_ids[:, -q_len:, None] - position_ids[:, None]).abs() < self.window 88 | attn_weights = torch.where(rectified_mask, attn_weights1, attn_weights2) 89 | 90 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 91 | raise ValueError( 92 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 93 | f" {attn_weights.size()}" 94 | ) 95 | 96 | if attention_mask is not None: 97 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 100 | ) 101 | attn_weights = attn_weights + attention_mask 102 | 103 | # upcast attention to fp32 104 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 105 | attn_output = torch.matmul(attn_weights, value_states) 106 | 107 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 108 | raise ValueError( 109 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 110 | f" {attn_output.size()}" 111 | ) 112 | 113 | attn_output = attn_output.transpose(1, 2).contiguous() 114 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 115 | 116 | if self.pretraining_tp > 1: 117 | attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) 118 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) 119 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) 120 | else: 121 | attn_output = 
self.o_proj(attn_output) 122 | 123 | if not output_attentions: 124 | attn_weights = None 125 | 126 | return attn_output, attn_weights, past_key_value -------------------------------------------------------------------------------- /eval-harness.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LM_EVALUATION_HARNESS_PATH="../lm-evaluation-harness" 4 | ARGS="--model=hf-causal-experimental --batch_size 2" 5 | MODEL_ARGS="use_accelerate=True,dtype=bfloat16,trust_remote_code=True" 6 | ARC="--tasks=arc_challenge --num_fewshot=25" 7 | HELLASWAG="--tasks=hellaswag --num_fewshot=10" 8 | TRUTHFULQA="--tasks=truthfulqa_mc --num_fewshot=0" 9 | MMLU="--tasks=hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions --num_fewshot=5" 10 | 11 | ### ARC-Challenge 12 | 13 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 14 | ${ARC} \ 15 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-64k,${MODEL_ARGS}" \ 16 | --output_path="data/Yarn-Llama-2-7b-64k-arc.json" 17 | 18 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 19 | ${ARC} \ 20 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-128k,${MODEL_ARGS}" \ 21 | --output_path="data/Yarn-Llama-2-7b-128k-arc.json" 22 | 23 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 24 | ${ARC} \ 25 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-64k,${MODEL_ARGS}" \ 26 | --output_path="data/Yarn-Llama-2-13b-64k-arc.json" 27 | 28 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 29 | ${ARC} \ 30 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-128k,${MODEL_ARGS}" \ 31 | --output_path="data/Yarn-Llama-2-13b-128k-arc.json" 32 | 33 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 34 | ${ARC} \ 35 | 
--model_args="pretrained=NousResearch/Yarn-Mistral-7b-64k,${MODEL_ARGS}" \ 36 | --output_path="data/Yarn-Mistral-7b-64k-arc.json" 37 | 38 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 39 | ${ARC} \ 40 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-128k,${MODEL_ARGS}" \ 41 | --output_path="data/Yarn-Mistral-7b-128k-arc.json" 42 | 43 | ### Hellaswag 44 | 45 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 46 | ${HELLASWAG} \ 47 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-64k,${MODEL_ARGS}" \ 48 | --output_path="data/Yarn-Llama-2-7b-64k-hellaswag.json" 49 | 50 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 51 | ${HELLASWAG} \ 52 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-128k,${MODEL_ARGS}" \ 53 | --output_path="data/Yarn-Llama-2-7b-128k-hellaswag.json" 54 | 55 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 56 | ${HELLASWAG} \ 57 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-64k,${MODEL_ARGS}" \ 58 | --output_path="data/Yarn-Llama-2-13b-64k-hellaswag.json" 59 | 60 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 61 | ${HELLASWAG} \ 62 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-128k,${MODEL_ARGS}" \ 63 | --output_path="data/Yarn-Llama-2-13b-128k-hellaswag.json" 64 | 65 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 66 | ${HELLASWAG} \ 67 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-64k,${MODEL_ARGS}" \ 68 | --output_path="data/Yarn-Mistral-7b-64k-hellaswag.json" 69 | 70 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 71 | ${HELLASWAG} \ 72 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-128k,${MODEL_ARGS}" \ 73 | --output_path="data/Yarn-Mistral-7b-128k-hellaswag.json" 74 | 75 | ### MMLU 76 | 77 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 78 | ${MMLU} \ 79 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-64k,${MODEL_ARGS}" \ 80 | --output_path="data/Yarn-Llama-2-7b-64k-mmlu.json" 81 | 82 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 83 | ${MMLU} \ 84 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-128k,${MODEL_ARGS}" \ 85 | --output_path="data/Yarn-Llama-2-7b-128k-mmlu.json" 86 | 87 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 88 | ${MMLU} \ 89 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-64k,${MODEL_ARGS}" \ 90 | --output_path="data/Yarn-Llama-2-13b-64k-mmlu.json" 91 | 92 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 93 | ${MMLU} \ 94 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-128k,${MODEL_ARGS}" \ 95 | --output_path="data/Yarn-Llama-2-13b-128k-mmlu.json" 96 | 97 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 98 | ${MMLU} \ 99 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-64k,${MODEL_ARGS}" \ 100 | --output_path="data/Yarn-Mistral-7b-64k-mmlu.json" 101 | 102 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 103 | ${MMLU} \ 104 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-128k,${MODEL_ARGS}" \ 105 | --output_path="data/Yarn-Mistral-7b-128k-mmlu.json" 106 | 107 | ## TruthfulQA 108 | 109 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 110 | ${TRUTHFULQA} \ 111 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-64k,${MODEL_ARGS}" \ 112 | --output_path="data/Yarn-Llama-2-7b-64k-truthfulqa.json" 113 | 114 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 115 | ${TRUTHFULQA} \ 116 | --model_args="pretrained=NousResearch/Yarn-Llama-2-7b-128k,${MODEL_ARGS}" \ 117 | 
--output_path="data/Yarn-Llama-2-7b-128k-truthfulqa.json" 118 | 119 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 120 | ${TRUTHFULQA} \ 121 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-64k,${MODEL_ARGS}" \ 122 | --output_path="data/Yarn-Llama-2-13b-64k-truthfulqa.json" 123 | 124 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 125 | ${TRUTHFULQA} \ 126 | --model_args="pretrained=NousResearch/Yarn-Llama-2-13b-128k,${MODEL_ARGS}" \ 127 | --output_path="data/Yarn-Llama-2-13b-128k-truthfulqa.json" 128 | 129 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 130 | ${TRUTHFULQA} \ 131 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-64k,${MODEL_ARGS}" \ 132 | --output_path="data/Yarn-Mistral-7b-64k-truthfulqa.json" 133 | 134 | python ${LM_EVALUATION_HARNESS_PATH}/main.py ${ARGS} \ 135 | ${TRUTHFULQA} \ 136 | --model_args="pretrained=NousResearch/Yarn-Mistral-7b-128k,${MODEL_ARGS}" \ 137 | --output_path="data/Yarn-Mistral-7b-128k-truthfulqa.json" 138 | -------------------------------------------------------------------------------- /eval/perplexity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import gc 4 | import sys 5 | import torch 6 | import warnings 7 | from transformers import AutoTokenizer 8 | from tqdm import tqdm 9 | from model_loader import * 10 | 11 | 12 | def compute_perplexity( 13 | encodings, model, tokenizer, add_start_token: bool = True, device=None, max_length=None, sliding_window=256, truncate=False, aggressive_memory=False, hide_progress=False, 14 | ): 15 | r"""Compute "sliding window" perplexity on a dataset. Validated against the calculations reported in arXiv 2306.15595""" 16 | if device is not None: 17 | assert device in ["gpu", "cpu", 18 | "cuda"], "device should be either gpu or cpu." 19 | if device == "gpu": 20 | device = "cuda" 21 | else: 22 | device = "cuda" if torch.cuda.is_available() else "cpu" 23 | 24 | if add_start_token: 25 | # leave room for token to be added: 26 | assert ( 27 | tokenizer.bos_token is not None 28 | ), "Input model must already have a BOS token if using add_start_token=True. 
Please use a different model, or set add_start_token=False" 29 | max_tokenized_len = max_length - 1 30 | else: 31 | max_tokenized_len = max_length 32 | 33 | encoded_texts = encodings["input_ids"] 34 | attn_masks = encodings["attention_mask"] 35 | 36 | if max_length and truncate: 37 | encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts] 38 | attn_masks = [x[0:max_tokenized_len] for x in attn_masks] 39 | sliding_window = max_tokenized_len 40 | 41 | pbar = tqdm(total=len(encoded_texts), disable=hide_progress) 42 | nlls = [] 43 | for encoding_index in range(0, len(encoded_texts)): 44 | 45 | labels = torch.tensor(encoded_texts[encoding_index:encoding_index+1]) 46 | seq_len = labels.size(1) 47 | 48 | prev_end_loc = 0 49 | for begin_loc in range(0, seq_len, sliding_window): 50 | 51 | end_loc = min(begin_loc + max_tokenized_len, seq_len) 52 | trg_len = end_loc - prev_end_loc 53 | input_ids = labels[:, begin_loc:end_loc].to(device) 54 | 55 | if add_start_token: 56 | bos_tokens_tensor = torch.tensor( 57 | [[tokenizer.bos_token_id]] * input_ids.size(dim=0)).to(device) 58 | input_ids = torch.cat( 59 | [bos_tokens_tensor, input_ids], dim=1) 60 | 61 | target_ids = input_ids.clone() 62 | target_ids[:, :-trg_len] = -100 63 | 64 | with torch.no_grad(): 65 | outputs = model(input_ids, labels=target_ids) 66 | neg_log_likelihood = outputs.loss 67 | 68 | if aggressive_memory: 69 | outputs = None 70 | input_ids = None 71 | target_ids = None 72 | gc.collect() 73 | torch.cuda.empty_cache() 74 | 75 | nlls.append(neg_log_likelihood) 76 | 77 | ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu()) 78 | pbar.set_postfix(ppl=ppl) 79 | 80 | prev_end_loc = end_loc 81 | if end_loc == seq_len: 82 | break 83 | 84 | pbar.update(1) 85 | 86 | ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu()) 87 | return {"mean_perplexity": ppl} 88 | 89 | 90 | def main(args): 91 | models = [x[0] for x in args.model] 92 | tokenizer = AutoTokenizer.from_pretrained( 93 | models[0], model_max_length=sys.maxsize, trust_remote_code=True) 94 | tokenizer.pad_token = tokenizer.eos_token 95 | 96 | if args.tokenized: 97 | try: 98 | input_texts = datasets.load_from_disk(args.tokenized) 99 | except: 100 | input_texts = datasets.load_dataset( 101 | args.tokenized, name=args.subset, split=args.split) 102 | else: 103 | input_texts = datasets.load_dataset( 104 | args.dataset, name=args.subset, split=args.split) 105 | 106 | def tokenize(example): 107 | tokenized = tokenizer( 108 | example[args.feature], 109 | add_special_tokens=False, 110 | padding=True, 111 | truncation=False, 112 | max_length=sys.maxsize, 113 | return_attention_mask=True, 114 | ) 115 | example["input_ids"] = tokenized["input_ids"] 116 | example["attention_mask"] = tokenized["attention_mask"] 117 | example["tokenized_len"] = len(tokenized["input_ids"]) 118 | return example 119 | 120 | input_texts = input_texts.map(tokenize) 121 | if args.save_tokenized: 122 | input_texts.save_to_disk(args.save_tokenized) 123 | print(f"Saved tokenized dataset to {args.save_tokenized}") 124 | return 125 | 126 | if args.dataset_min_tokens: 127 | input_texts = input_texts.filter( 128 | lambda x: x["tokenized_len"] >= args.dataset_min_tokens) 129 | if args.samples: 130 | input_texts = input_texts[:args.samples] 131 | 132 | if args.tokens_step: 133 | tokens = [x for x in range( 134 | args.min_tokens, args.max_tokens + 1, args.tokens_step)] 135 | else: 136 | tokens = [args.min_tokens] 137 | while args.min_tokens < args.max_tokens: 138 | point = tokens[-1] * 2 139 | if point <= 
args.max_tokens: 140 | tokens.append(point) 141 | else: 142 | break 143 | 144 | results = [] 145 | for model in tqdm(models, desc="Model", leave=False, disable=args.hide_progress): 146 | torch.cuda.empty_cache() 147 | 148 | loaded = load_model_and_apply_patches(model, args) 149 | 150 | result = [] 151 | for max_length in tokens: 152 | ppl = compute_perplexity(model=loaded, tokenizer=tokenizer, encodings=input_texts, 153 | add_start_token=tokenizer.bos_token is not None, max_length=max_length, 154 | sliding_window=args.sliding_window, truncate=args.truncate, 155 | aggressive_memory=args.aggressive_memory, hide_progress=args.hide_progress)['mean_perplexity'] 156 | print(f"{model}: {max_length}={ppl}") 157 | result.append(ppl) 158 | 159 | result.insert(0, model) 160 | results.append(result) 161 | 162 | if args.output_file: 163 | with open(args.output_file, "w", encoding="utf-8") as f: 164 | f.write(f",{','.join([str(x) for x in tokens])}\n") 165 | for result in results: 166 | f.write(f"{','.join([str(x) for x in result])}\n") 167 | 168 | 169 | if __name__ == "__main__": 170 | warnings.simplefilter("ignore") 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("-m", "--model", action="append", nargs="+") 173 | parser.add_argument("-d", "--dataset", type=str) 174 | parser.add_argument("-s", "--subset", type=str) 175 | parser.add_argument("-f", "--feature", type=str) 176 | parser.add_argument("--max-tokens", type=int, default=8192) 177 | parser.add_argument("--min-tokens", type=int, default=256) 178 | parser.add_argument("--dataset-min-tokens", type=int) 179 | parser.add_argument("--tokens-step", type=int) 180 | parser.add_argument("--sliding-window", type=int, default=256) 181 | parser.add_argument("--truncate", action="store_true") 182 | parser.add_argument("--split", type=str, default="test") 183 | parser.add_argument("--samples", type=int) 184 | parser.add_argument("--save-tokenized", type=str) 185 | parser.add_argument("--tokenized", type=str) 186 | parser.add_argument("--output-file", type=str) 187 | parser.add_argument("--aggressive-memory", action="store_true") 188 | parser.add_argument("--hide-progress", action="store_true") 189 | main(add_args(parser).parse_args()) 190 | -------------------------------------------------------------------------------- /scaled_rope/configuration_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
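# NOTE: adapted from the Hugging Face Transformers Mistral configuration; in this repo `_rope_scaling_validation` below also accepts the "yarn" and "dynamic-yarn" rope scaling strategies.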
15 | """ Mistral model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", 25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", 26 | } 27 | 28 | 29 | class MistralConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an 32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration 33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. 34 | 35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) 36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 37 | 38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 39 | documentation from [`PretrainedConfig`] for more information. 40 | 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 32000): 44 | Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the 45 | `inputs_ids` passed when calling [`MistralModel`] 46 | hidden_size (`int`, *optional*, defaults to 4096): 47 | Dimension of the hidden representations. 48 | intermediate_size (`int`, *optional*, defaults to 14336): 49 | Dimension of the MLP representations. 50 | num_hidden_layers (`int`, *optional*, defaults to 32): 51 | Number of hidden layers in the Transformer encoder. 52 | num_attention_heads (`int`, *optional*, defaults to 32): 53 | Number of attention heads for each attention layer in the Transformer encoder. 54 | num_key_value_heads (`int`, *optional*, defaults to 8): 55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 57 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 59 | by meanpooling all the original heads within that group. For more details checkout [this 60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention 65 | allows sequence of up to 4096*32 tokens. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers. 70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 
73 | pad_token_id (`int`, *optional*): 74 | The id of the padding token. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | The id of the "beginning-of-sequence" token. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | The id of the "end-of-sequence" token. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether the model's input and output word embeddings should be tied. 81 | rope_scaling (`Dict`, *optional*): 82 | Dictionary containing the scaling configuration for the RoPE embeddings. Supports four scaling 83 | strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format 84 | is `{"type": strategy name, "factor": scaling factor}`; the yarn strategies additionally require `"original_max_position_embeddings"`. 85 | rope_theta (`float`, *optional*, defaults to 10000.0): 86 | The base period of the RoPE embeddings. 87 | sliding_window (`int`, *optional*, defaults to 4096): 88 | Sliding window attention window size. If not specified, will default to `4096`. 89 | 90 | 91 | ```python 92 | >>> from transformers import MistralModel, MistralConfig 93 | 94 | >>> # Initializing a Mistral 7B style configuration 95 | >>> configuration = MistralConfig() 96 | 97 | >>> # Initializing a model from the Mistral 7B style configuration 98 | >>> model = MistralModel(configuration) 99 | 100 | >>> # Accessing the model configuration 101 | >>> configuration = model.config 102 | ```""" 103 | 104 | model_type = "mistral" 105 | keys_to_ignore_at_inference = ["past_key_values"] 106 | 107 | def __init__( 108 | self, 109 | vocab_size=32000, 110 | hidden_size=4096, 111 | intermediate_size=14336, 112 | num_hidden_layers=32, 113 | num_attention_heads=32, 114 | num_key_value_heads=8, 115 | hidden_act="silu", 116 | max_position_embeddings=4096 * 32, 117 | initializer_range=0.02, 118 | rms_norm_eps=1e-6, 119 | use_cache=True, 120 | pad_token_id=None, 121 | bos_token_id=1, 122 | eos_token_id=2, 123 | tie_word_embeddings=False, 124 | rope_scaling=None, 125 | rope_theta=10000.0, 126 | sliding_window=4096, 127 | **kwargs, 128 | ): 129 | self.vocab_size = vocab_size 130 | self.max_position_embeddings = max_position_embeddings 131 | self.hidden_size = hidden_size 132 | self.intermediate_size = intermediate_size 133 | self.num_hidden_layers = num_hidden_layers 134 | self.num_attention_heads = num_attention_heads 135 | self.sliding_window = sliding_window 136 | 137 | # for backward compatibility 138 | if num_key_value_heads is None: 139 | num_key_value_heads = num_attention_heads 140 | 141 | self.num_key_value_heads = num_key_value_heads 142 | self.hidden_act = hidden_act 143 | self.initializer_range = initializer_range 144 | self.rms_norm_eps = rms_norm_eps 145 | self.use_cache = use_cache 146 | self.rope_scaling = rope_scaling 147 | self._rope_scaling_validation() 148 | self.rope_theta = rope_theta 149 | 150 | super().__init__( 151 | pad_token_id=pad_token_id, 152 | bos_token_id=bos_token_id, 153 | eos_token_id=eos_token_id, 154 | tie_word_embeddings=tie_word_embeddings, 155 | **kwargs, 156 | ) 157 | 158 | def _rope_scaling_validation(self): 159 | """ 160 | Validate the `rope_scaling` configuration.
161 | """ 162 | if self.rope_scaling is None: 163 | return 164 | 165 | if not isinstance(self.rope_scaling, dict): 166 | raise ValueError( 167 | "`rope_scaling` must be a dictionary, " 168 | f"got {self.rope_scaling}" 169 | ) 170 | rope_scaling_type = self.rope_scaling.get("type", None) 171 | rope_scaling_factor = self.rope_scaling.get("factor", None) 172 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: 173 | raise ValueError( 174 | f"`rope_scaling`'s name field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" 175 | ) 176 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 177 | raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") 178 | if rope_scaling_type == "yarn" or rope_scaling_type == "dynamic-yarn": 179 | original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) 180 | if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): 181 | raise ValueError(f"`rope_scaling.original_max_position_embeddings` must be set to an int when using yarn, and dynamic-yarn") -------------------------------------------------------------------------------- /scaled_rope/configuration_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LlamaConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaMA model. 
Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`LlamaModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer encoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer encoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details check out [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | pretraining_tp (`int`, *optional*, defaults to `1`): 62 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 63 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 64 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 65 | issue](https://github.com/pytorch/pytorch/issues/76232). 66 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 67 | The non-linear activation function (function or string) in the decoder. 68 | max_position_embeddings (`int`, *optional*, defaults to 2048): 69 | The maximum sequence length that this model might ever be used with. Typically set this to something large 70 | just in case (e.g., 512 or 1024 or 2048). 71 | initializer_range (`float`, *optional*, defaults to 0.02): 72 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 73 | rms_norm_eps (`float`, *optional*, defaults to 1e-6): 74 | The epsilon used by the rms normalization layers. 75 | use_cache (`bool`, *optional*, defaults to `True`): 76 | Whether or not the model should return the last key/values attentions (not used by all models). Only 77 | relevant if `config.is_decoder=True`. 78 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 79 | Whether to tie weight embeddings 80 | rope_scaling (`Dict`, *optional*): 81 | Dictionary containing the scaling configuration for the RoPE embeddings. Supports four scaling 82 | strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format 83 | is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 84 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 85 | these scaling strategies behave: 86 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 87 | experimental feature, subject to breaking API changes in future versions.
88 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 89 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 90 | attention_dropout (`float`, *optional*, defaults to 0.0): 91 | The dropout ratio for the attention probabilities. 92 | 93 | Example: 94 | 95 | ```python 96 | >>> from transformers import LlamaModel, LlamaConfig 97 | 98 | >>> # Initializing a LLaMA llama-7b style configuration 99 | >>> configuration = LlamaConfig() 100 | 101 | >>> # Initializing a model from the llama-7b style configuration 102 | >>> model = LlamaModel(configuration) 103 | 104 | >>> # Accessing the model configuration 105 | >>> configuration = model.config 106 | ```""" 107 | model_type = "llama" 108 | keys_to_ignore_at_inference = ["past_key_values"] 109 | 110 | def __init__( 111 | self, 112 | vocab_size=32000, 113 | hidden_size=4096, 114 | intermediate_size=11008, 115 | num_hidden_layers=32, 116 | num_attention_heads=32, 117 | num_key_value_heads=None, 118 | hidden_act="silu", 119 | max_position_embeddings=2048, 120 | initializer_range=0.02, 121 | rms_norm_eps=1e-6, 122 | use_cache=True, 123 | pad_token_id=0, 124 | bos_token_id=1, 125 | eos_token_id=2, 126 | pretraining_tp=1, 127 | tie_word_embeddings=False, 128 | rope_theta=10000, 129 | rope_scaling=None, 130 | attention_bias=False, 131 | attention_dropout=0.0, 132 | **kwargs, 133 | ): 134 | self.vocab_size = vocab_size 135 | self.max_position_embeddings = max_position_embeddings 136 | self.hidden_size = hidden_size 137 | self.intermediate_size = intermediate_size 138 | self.num_hidden_layers = num_hidden_layers 139 | self.num_attention_heads = num_attention_heads 140 | 141 | # for backward compatibility 142 | if num_key_value_heads is None: 143 | num_key_value_heads = num_attention_heads 144 | 145 | self.num_key_value_heads = num_key_value_heads 146 | self.hidden_act = hidden_act 147 | self.initializer_range = initializer_range 148 | self.rms_norm_eps = rms_norm_eps 149 | self.pretraining_tp = pretraining_tp 150 | self.use_cache = use_cache 151 | self.rope_theta = rope_theta 152 | self.rope_scaling = rope_scaling 153 | self._rope_scaling_validation() 154 | self.attention_bias = attention_bias 155 | self.attention_dropout = attention_dropout 156 | 157 | super().__init__( 158 | pad_token_id=pad_token_id, 159 | bos_token_id=bos_token_id, 160 | eos_token_id=eos_token_id, 161 | tie_word_embeddings=tie_word_embeddings, 162 | **kwargs, 163 | ) 164 | 165 | def _rope_scaling_validation(self): 166 | """ 167 | Validate the `rope_scaling` configuration. 
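Expected keys: `type` (one of "linear", "dynamic", "yarn" or "dynamic-yarn"), `factor` (a float greater than 1), and, for the yarn variants, `original_max_position_embeddings` (an int).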
168 | """ 169 | if self.rope_scaling is None: 170 | return 171 | 172 | if not isinstance(self.rope_scaling, dict): 173 | raise ValueError( 174 | "`rope_scaling` must be a dictionary, " 175 | f"got {self.rope_scaling}" 176 | ) 177 | rope_scaling_type = self.rope_scaling.get("type", None) 178 | rope_scaling_factor = self.rope_scaling.get("factor", None) 179 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]: 180 | raise ValueError( 181 | f"`rope_scaling`'s name field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}" 182 | ) 183 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 184 | raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") 185 | if rope_scaling_type == "yarn" or rope_scaling_type == "dynamic-yarn": 186 | original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None) 187 | if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): 188 | raise ValueError(f"`rope_scaling.original_max_position_embeddings` must be set to an int when using yarn, and dynamic-yarn") -------------------------------------------------------------------------------- /eval/model_loader.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig 3 | from scaled_rope.patch import * 4 | 5 | 6 | def load_model(model, args): 7 | if args.custom_model: 8 | from scaled_rope.modeling_llama_yarn import LlamaForCausalLM 9 | from scaled_rope.configuration_llama import LlamaConfig 10 | model_cls = LlamaForCausalLM 11 | config_cls = LlamaConfig 12 | elif args.custom_model_together: 13 | from scaled_rope.modeling_llama_together_yarn import LlamaForCausalLM 14 | from scaled_rope.configuration_llama import LlamaConfig 15 | model_cls = LlamaForCausalLM 16 | config_cls = LlamaConfig 17 | elif args.custom_model_mistral: 18 | from scaled_rope.modeling_mistral_yarn import MistralForCausalLM 19 | from scaled_rope.configuration_mistral import MistralConfig 20 | model_cls = MistralForCausalLM 21 | config_cls = MistralConfig 22 | else: 23 | model_cls = AutoModelForCausalLM 24 | config_cls = AutoConfig 25 | 26 | config = config_cls.from_pretrained( 27 | model, trust_remote_code=not args.custom_model) 28 | if args.max_position_embeddings: 29 | config.max_position_embeddings = args.max_position_embeddings 30 | if args.factor: 31 | config.rope_scaling["factor"] = args.factor 32 | if args.no_use_cache: 33 | config.use_cache = False 34 | else: 35 | config.use_cache = True 36 | if args.sliding_window_attention: 37 | config.sliding_window = args.sliding_window_attention 38 | if args.custom_model or args.custom_model_together or args.custom_model_mistral: 39 | if args.linear: 40 | config.rope_scaling = { 41 | "type": "linear", 42 | "factor": args.linear 43 | } 44 | elif args.dynamic_ntk: 45 | config.rope_scaling = { 46 | "type": "dynamic", 47 | "factor": args.dynamic_ntk 48 | } 49 | elif args.part_ntk: 50 | config.rope_scaling = { 51 | "type": "ntk-by-parts", 52 | "factor": args.part_ntk 53 | } 54 | elif args.yarn: 55 | config.rope_scaling = { 56 | "type": "yarn", 57 | "factor": args.yarn, 58 | "original_max_position_embeddings": args.original_max_position_embeddings, 59 | } 60 | elif args.dynamic_yarn: 61 | config.rope_scaling 
= { 62 | "type": "dynamic-yarn", 63 | "factor": args.factor if args.factor else (config.rope_scaling.get("factor", 1.0) if config.rope_scaling is not None else 1.0), 64 | "original_max_position_embeddings": args.original_max_position_embeddings if args.original_max_position_embeddings else config.rope_scaling["original_max_position_embeddings"], 65 | "finetuned": args.finetuned if args.finetuned else (config.rope_scaling.get("finetuned", False) if config.rope_scaling is not None else False) 66 | } 67 | else: 68 | if args.rerope: 69 | assert not args.custom_model and not args.custom_model_together 70 | from transformers.models.llama.modeling_llama import LlamaAttention 71 | from scaled_rope.LlamaReRoPE import forward_with_rerope 72 | LlamaAttention.forward = forward_with_rerope 73 | 74 | if args.load_in_8bit or args.load_in_4bit: 75 | quantization_config = BitsAndBytesConfig( 76 | load_in_4bit=args.load_in_4bit, 77 | load_in_8bit=args.load_in_8bit, 78 | llm_int8_threshold=6.0, 79 | llm_int8_has_fp16_weight=False, 80 | bnb_4bit_compute_dtype=torch.bfloat16, 81 | bnb_4bit_use_double_quant=True, 82 | bnb_4bit_quant_type="nf4", 83 | ) 84 | torch_dtype = None 85 | config.pretraining_tp = 1 86 | else: 87 | quantization_config = None 88 | torch_dtype = torch.bfloat16 89 | 90 | loaded = model_cls.from_pretrained( 91 | model, 92 | torch_dtype=torch_dtype, 93 | device_map="auto", 94 | trust_remote_code=not args.custom_model, 95 | config=config, 96 | quantization_config=quantization_config, 97 | use_flash_attention_2=args.flash_attention, 98 | ) 99 | 100 | return loaded 101 | 102 | 103 | def add_args(parser: ArgumentParser): 104 | parser.add_argument("--dynamic-linear", action="store_true") 105 | parser.add_argument("--dynamic-ntk", type=float) 106 | parser.add_argument("--dynamic-part-ntk", action="store_true") 107 | parser.add_argument("--dynamic-yarn", action="store_true") 108 | parser.add_argument("--ntk", type=float) 109 | parser.add_argument("--part-ntk", type=float) 110 | parser.add_argument("--linear", type=float) 111 | parser.add_argument("--yarn", type=float) 112 | parser.add_argument("--rerope", type=float) 113 | parser.add_argument("--factor", type=float) 114 | parser.add_argument("--load-in-8bit", action="store_true") 115 | parser.add_argument("--load-in-4bit", action="store_true") 116 | parser.add_argument("--finetuned", action="store_true") 117 | parser.add_argument("--gpt-neox-max-length", type=int) 118 | parser.add_argument("--adapter", type=str) 119 | parser.add_argument("--max-position-embeddings", type=int) 120 | parser.add_argument("--original-max-position-embeddings", type=int) 121 | parser.add_argument("--sliding-window-attention", type=int) 122 | parser.add_argument("--custom-model", action="store_true") 123 | parser.add_argument("--custom-model-together", action="store_true") 124 | parser.add_argument("--custom-model-mistral", action="store_true") 125 | parser.add_argument("--flash-attention", action="store_true") 126 | parser.add_argument("--no-use-cache", action="store_true") 127 | return parser 128 | 129 | 130 | def apply_patches(model, args): 131 | if not args.custom_model and not args.custom_model_together and not args.custom_model_mistral: 132 | if "GPTNeoXForCausalLM" in model.config.architectures: 133 | assert args.gpt_neox_max_length is not None 134 | patch_gptneox_for_longer_sequences(model, args.gpt_neox_max_length) 135 | if args.dynamic_linear: 136 | if "GPTNeoXForCausalLM" in model.config.architectures: 137 | patch_gptneox_for_scaled_rotary_embeddings(model) 138 | 
elif "LlamaForCausalLM" in model.config.architectures: 139 | patch_llama_for_dynamic_scaled_rotary_embeddings(model) 140 | else: 141 | raise RuntimeError( 142 | f"Unsupported architecture {model.config.architectures} for dynamic linear") 143 | elif args.dynamic_ntk: 144 | if "LlamaForCausalLM" in model.config.architectures: 145 | patch_llama_for_dynamic_scaled_rotary_embeddings( 146 | model, ntk=args.dynamic_ntk) 147 | else: 148 | raise RuntimeError( 149 | f"Unsupported architecture {model.config.architectures} for dynamic ntk") 150 | elif args.dynamic_part_ntk: 151 | if "LlamaForCausalLM" in model.config.architectures: 152 | patch_llama_for_dynamic_part_ntk_rotary_embeddings( 153 | model, args.finetuned) 154 | elif "RWForCausalLM" in model.config.architectures: 155 | patch_falcon_for_dynamic_part_ntk_rotary_embeddings(model) 156 | else: 157 | raise RuntimeError( 158 | f"Unsupported architecture {model.config.architectures} for dynamic part ntk") 159 | elif args.dynamic_yarn: 160 | if "LlamaForCausalLM" in model.config.architectures: 161 | patch_llama_for_dynamic_yarn_rotary_embeddings( 162 | model, args.original_max_position_embeddings, args.finetuned) 163 | else: 164 | raise RuntimeError( 165 | f"Unsupported architecture {model.config.architectures} for dynamic yarn") 166 | elif args.ntk: 167 | if "GPTNeoXForCausalLM" in model.config.architectures: 168 | patch_gptneox_for_ntk_scaled_rotary_embeddings( 169 | model, args.ntk) 170 | elif "LlamaForCausalLM" in model.config.architectures: 171 | patch_llama_for_ntk_scaled_rotary_embeddings(model, args.ntk) 172 | else: 173 | raise RuntimeError( 174 | f"Unsupported architecture {model.config.architectures} for ntk") 175 | elif args.linear: 176 | if "LlamaForCausalLM" in model.config.architectures: 177 | patch_llama_for_linear_scaled_rotary_embeddings( 178 | model, scale=args.linear) 179 | else: 180 | raise RuntimeError( 181 | f"Unsupported architecture {model.config.architectures} for linear") 182 | elif args.part_ntk: 183 | if "LlamaForCausalLM" in model.config.architectures: 184 | patch_llama_for_part_ntk_scaled_rotary_embeddings( 185 | model, scale=args.part_ntk) 186 | else: 187 | raise RuntimeError( 188 | f"Unsupported architecture {model.config.architectures} for part ntk") 189 | elif args.yarn: 190 | if "LlamaForCausalLM" in model.config.architectures: 191 | patch_llama_for_yarn_scaled_rotary_embeddings( 192 | model, scale=args.yarn, original_max_position_embeddings=args.original_max_position_embeddings) 193 | else: 194 | raise RuntimeError( 195 | f"Unsupported architecture {model.config.architectures} for YaRN") 196 | elif args.rerope: 197 | if "LlamaForCausalLM" in model.config.architectures: 198 | training_length = args.original_max_position_embeddings if args.original_max_position_embeddings else 4096 199 | window = args.rerope 200 | patch_llama_for_rerope( 201 | model, training_length=training_length, window=window) 202 | else: 203 | raise RuntimeError( 204 | f"Unsupported architecture {model.config.architectures} for ReRoPE") 205 | 206 | if args.adapter: 207 | from peft import PeftModel 208 | model = PeftModel.from_pretrained(model, args.adapter) 209 | model = model.merge_and_unload() 210 | 211 | return model 212 | 213 | 214 | def load_model_and_apply_patches(model, args): 215 | return apply_patches(load_model(model, args), args) 216 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 |
import copy 3 | import torch 4 | import os 5 | from datasets import load_dataset, load_from_disk, DatasetDict 6 | from datetime import timedelta 7 | from torch.utils.data import DataLoader 8 | from accelerate import Accelerator 9 | from accelerate.utils import InitProcessGroupKwargs, DummyOptim, DummyScheduler 10 | from tqdm import tqdm 11 | from transformers import set_seed, default_data_collator, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 12 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig 13 | 14 | 15 | def find_all_linear_names(model): 16 | lora_module_names = set() 17 | for name, module in model.named_modules(): 18 | if isinstance(module, torch.nn.Linear): 19 | names = name.split(".") 20 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 21 | 22 | if "lm_head" in lora_module_names: 23 | lora_module_names.remove("lm_head") 24 | 25 | return list(lora_module_names) 26 | 27 | 28 | def main(args): 29 | 30 | if args.output_dir: 31 | os.makedirs(args.output_dir, exist_ok=True) 32 | 33 | if args.wandb: 34 | import wandb 35 | wandb.login() 36 | 37 | set_seed(args.seed) 38 | 39 | timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000)) 40 | accelerator = Accelerator( 41 | gradient_accumulation_steps=args.gradient_accumulate_every, 42 | mixed_precision="bf16", 43 | log_with="wandb" if args.wandb else None, 44 | kwargs_handlers=[timeout] 45 | ) 46 | accelerator.init_trackers( 47 | project_name=args.wandb if args.wandb else "yarn", 48 | ) 49 | accelerator.print(f"Total GPUs: {accelerator.num_processes}") 50 | 51 | if args.architecture == "llama": 52 | from scaled_rope.modeling_llama_yarn import LlamaForCausalLM 53 | from scaled_rope.configuration_llama import LlamaConfig 54 | config_cls = LlamaConfig 55 | model_cls = LlamaForCausalLM 56 | original_max_position_embeddings = args.original_max_position_embeddings if args.original_max_position_embeddings else 4096 57 | elif args.architecture == "mistral": 58 | from scaled_rope.modeling_mistral_yarn import MistralForCausalLM 59 | from scaled_rope.configuration_mistral import MistralConfig 60 | config_cls = MistralConfig 61 | model_cls = MistralForCausalLM 62 | original_max_position_embeddings = args.original_max_position_embeddings if args.original_max_position_embeddings else 8192 63 | 64 | config = config_cls.from_pretrained(args.model) 65 | config.rope_scaling = { 66 | "type": args.scaling_type, 67 | "factor": args.scaling_factor, 68 | "original_max_position_embeddings": original_max_position_embeddings 69 | } 70 | config.rope_theta = args.rope_theta 71 | config.max_position_embeddings = int(args.scaling_factor * original_max_position_embeddings) \ 72 | if not args.max_position_embeddings else args.max_position_embeddings 73 | 74 | sliding_window_attention_schedule = [int(x) for x in args.sliding_window_attention_schedule.split(",")] \ 75 | if args.sliding_window_attention_schedule else None 76 | if sliding_window_attention_schedule is not None and len(sliding_window_attention_schedule) == 1: 77 | config.sliding_window = sliding_window_attention_schedule[0] 78 | accelerator.print( 79 | f"Sliding attention window set to {config.sliding_window}") 80 | 81 | model = model_cls.from_pretrained( 82 | args.model, 83 | torch_dtype=torch.bfloat16, 84 | config=config, 85 | use_flash_attention_2=True 86 | ) 87 | 88 | try: 89 | train_dataset = load_dataset(args.dataset) 90 | except Exception: 91 | train_dataset = load_from_disk(args.dataset) 92 | if 
isinstance(train_dataset, DatasetDict): 93 | train_dataset = train_dataset["train"] 94 | 95 | if "input_ids" not in train_dataset.column_names: 96 | raise RuntimeError("Dataset must include an `input_ids` feature") 97 | if "labels" not in train_dataset.column_names: 98 | def add_labels(sample): 99 | sample["labels"] = copy.deepcopy(sample["input_ids"]) 100 | return sample 101 | train_dataset = train_dataset.map( 102 | add_labels, desc="Adding labels", num_proc=args.num_proc) 103 | if "attention_mask" not in train_dataset.column_names: 104 | def add_attention_mask(sample): 105 | sample["attention_mask"] = torch.ones( 106 | len(sample["input_ids"]), dtype=torch.int8) 107 | return sample 108 | train_dataset = train_dataset.map( 109 | add_attention_mask, desc="Adding attention mask", num_proc=args.num_proc) 110 | 111 | if args.truncate: 112 | def truncate(sample): 113 | sample["input_ids"] = sample["input_ids"][0:args.truncate] 114 | sample["labels"] = sample["labels"][0:args.truncate] 115 | sample["attention_mask"] = sample["attention_mask"][0:args.truncate] 116 | return sample 117 | train_dataset = train_dataset.map( 118 | truncate, desc="Truncating", num_proc=args.num_proc) 119 | 120 | train_loader = DataLoader( 121 | train_dataset, 122 | collate_fn=default_data_collator, 123 | shuffle=True, 124 | batch_size=args.batch_size 125 | ) 126 | 127 | if args.lora: 128 | from peft import get_peft_model, LoraConfig, TaskType 129 | target_modules = find_all_linear_names(model) 130 | accelerator.print(f"LoRA target modules: {target_modules}") 131 | peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, 132 | r=16, lora_alpha=64, lora_dropout=0.05, target_modules=target_modules) 133 | model = get_peft_model(model, peft_config) 134 | model.print_trainable_parameters() 135 | 136 | if args.deepspeed: 137 | optim = DummyOptim(model.parameters(), lr=args.learning_rate) 138 | scheduler = DummyScheduler( 139 | optim, num_training_steps=args.max_train_steps, num_warmup_steps=args.warmup_steps) 140 | model, optim, train_loader, scheduler = accelerator.prepare( 141 | model, optim, train_loader, scheduler 142 | ) 143 | else: 144 | model = accelerator.prepare(model) 145 | optim = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) 146 | if args.lr_schedule == "linear": 147 | scheduler = get_linear_schedule_with_warmup( 148 | optim, num_training_steps=args.max_train_steps, num_warmup_steps=args.warmup_steps) 149 | elif args.lr_schedule == "constant": 150 | scheduler = get_constant_schedule_with_warmup( 151 | optim, num_warmup_steps=args.warmup_steps) 152 | optim, train_loader, scheduler = accelerator.prepare( 153 | optim, train_loader, scheduler) 154 | 155 | if not args.lora: 156 | model.gradient_checkpointing_enable() 157 | 158 | accelerator.register_for_checkpointing(scheduler) 159 | total_batch_size = ( 160 | args.batch_size * accelerator.num_processes * args.gradient_accumulate_every 161 | ) 162 | 163 | accelerator.print(f"Max train steps: {args.max_train_steps}") 164 | accelerator.print(f"Total batch size: {total_batch_size}") 165 | progress_bar = tqdm( 166 | range(args.max_train_steps), disable=not accelerator.is_local_main_process 167 | ) 168 | completed_steps = 0 169 | 170 | if args.resume_from_checkpoint: 171 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 172 | accelerator.print( 173 | f"Resuming from checkpoint {args.resume_from_checkpoint}") 174 | accelerator.load_state(args.resume_from_checkpoint) 175 | path = 
os.path.basename(args.resume_from_checkpoint) 176 | training_difference = os.path.splitext(path)[0] 177 | 178 | resume_step = ( 179 | int(training_difference.replace("step_", "")) 180 | ) 181 | 182 | if args.resume_from_checkpoint and resume_step is not None: 183 | train_loader = accelerator.skip_first_batches( 184 | train_loader, resume_step) 185 | completed_steps += resume_step 186 | progress_bar.update(resume_step) 187 | accelerator.print(f"Resuming training from step {resume_step}") 188 | 189 | loss_file = open(args.log_loss, "a" if args.resume_from_checkpoint else "w") if args.log_loss and accelerator.is_main_process else None 190 | 191 | if not args.save_only: 192 | model.train() 193 | for step, batch in enumerate(train_loader): 194 | if sliding_window_attention_schedule is not None: 195 | model.config.sliding_window = sliding_window_attention_schedule[completed_steps % len( 196 | sliding_window_attention_schedule)] 197 | 198 | loss_log = None 199 | with accelerator.accumulate(model): 200 | loss = model(**batch).loss 201 | accelerator.backward(loss) 202 | 203 | if accelerator.sync_gradients: 204 | loss_log = {"loss": loss.item()} 205 | accelerator.log(loss_log, step=completed_steps) 206 | if loss_file is not None: 207 | loss_file.write(f"{loss_log['loss']},") 208 | loss_file.flush() 209 | if isinstance(args.grad_norm, float): 210 | accelerator.clip_grad_norm_( 211 | model.parameters(), args.grad_norm) 212 | 213 | optim.step() 214 | scheduler.step() 215 | optim.zero_grad() 216 | 217 | if accelerator.sync_gradients: 218 | progress_bar.update(1) 219 | if loss_log is not None: 220 | progress_bar.set_postfix(loss_log) 221 | completed_steps += 1 222 | 223 | if isinstance(args.checkpointing_steps, int) and completed_steps > 0: 224 | if completed_steps % args.checkpointing_steps == 0: 225 | output_dir = f"step_{completed_steps}" 226 | if args.output_dir is not None: 227 | output_dir = os.path.join( 228 | args.output_dir, output_dir) 229 | accelerator.save_state(output_dir) 230 | 231 | if completed_steps >= args.max_train_steps: 232 | break 233 | 234 | accelerator.print(f"Training Finished") 235 | accelerator.end_training() 236 | 237 | if args.output_dir is not None: 238 | accelerator.print(f"Saving model to {args.output_dir}") 239 | 240 | accelerator.wait_for_everyone() 241 | 242 | if args.deepspeed: 243 | state_dict = accelerator.get_state_dict(model) 244 | else: 245 | full_state_dict_config = FullStateDictConfig( 246 | offload_to_cpu=True, rank0_only=True) 247 | with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config): 248 | state_dict = accelerator.get_state_dict(model, unwrap=False) 249 | 250 | accelerator.unwrap_model(model).save_pretrained( 251 | f"{args.output_dir}", 252 | is_main_process=accelerator.is_main_process, 253 | save_function=accelerator.save, 254 | state_dict=state_dict, 255 | ) 256 | 257 | accelerator.print(f"Saving Finished") 258 | 259 | 260 | if __name__ == "__main__": 261 | args = argparse.ArgumentParser() 262 | args.add_argument("--batch-size", type=int, default=1) 263 | args.add_argument("--gradient-accumulate-every", type=int, default=8) 264 | args.add_argument("--resume-from-checkpoint", type=str) 265 | args.add_argument("--checkpointing-steps", type=int) 266 | args.add_argument("--output-dir", type=str, required=True) 267 | args.add_argument("--wandb", type=str) 268 | args.add_argument("--seed", type=int, default=42) 269 | args.add_argument("--max-train-steps", type=int, default=400) 270 | args.add_argument("--warmup-steps", 
type=int, default=20) 271 | args.add_argument("--learning-rate", type=float, default=2e-5) 272 | args.add_argument("--grad-norm", type=float) 273 | args.add_argument("--lora", action="store_true") 274 | args.add_argument("--model", type=str, 275 | default="NousResearch/Llama-2-7b-hf") 276 | args.add_argument("--scaling-factor", type=float, default=16.0) 277 | args.add_argument("--scaling-type", type=str, default="yarn") 278 | args.add_argument("--rope-theta", type=float, default=10000.0) 279 | args.add_argument("--truncate", type=int) 280 | args.add_argument("--dataset", type=str, 281 | default="emozilla/pg_books-tokenized-bos-eos-chunked-65536") 282 | args.add_argument("--deepspeed", action="store_true") 283 | args.add_argument("--num-proc", type=int, default=32) 284 | args.add_argument("--architecture", type=str, 285 | choices=["llama", "mistral"], default="llama") 286 | args.add_argument("--max-position-embeddings", type=int) 287 | args.add_argument("--sliding-window-attention-schedule", type=str) 288 | args.add_argument("--lr-schedule", type=str, 289 | choices=["linear", "constant"], default="linear") 290 | args.add_argument("--save-only", action="store_true") 291 | args.add_argument("--log-loss", type=str) 292 | args.add_argument("--original-max-position-embeddings", type=int) 293 | main(args.parse_args()) 294 | -------------------------------------------------------------------------------- /data/Yarn-Llama-2-13b-64k-mmlu.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hendrycksTest-abstract_algebra": { 4 | "acc": 0.32, 5 | "acc_stderr": 0.046882617226215034, 6 | "acc_norm": 0.32, 7 | "acc_norm_stderr": 0.046882617226215034 8 | }, 9 | "hendrycksTest-anatomy": { 10 | "acc": 0.4740740740740741, 11 | "acc_stderr": 0.04313531696750574, 12 | "acc_norm": 0.4740740740740741, 13 | "acc_norm_stderr": 0.04313531696750574 14 | }, 15 | "hendrycksTest-astronomy": { 16 | "acc": 0.5263157894736842, 17 | "acc_stderr": 0.04063302731486671, 18 | "acc_norm": 0.5263157894736842, 19 | "acc_norm_stderr": 0.04063302731486671 20 | }, 21 | "hendrycksTest-business_ethics": { 22 | "acc": 0.53, 23 | "acc_stderr": 0.05016135580465919, 24 | "acc_norm": 0.53, 25 | "acc_norm_stderr": 0.05016135580465919 26 | }, 27 | "hendrycksTest-clinical_knowledge": { 28 | "acc": 0.5660377358490566, 29 | "acc_stderr": 0.030503292013342592, 30 | "acc_norm": 0.5660377358490566, 31 | "acc_norm_stderr": 0.030503292013342592 32 | }, 33 | "hendrycksTest-college_biology": { 34 | "acc": 0.5694444444444444, 35 | "acc_stderr": 0.04140685639111503, 36 | "acc_norm": 0.5694444444444444, 37 | "acc_norm_stderr": 0.04140685639111503 38 | }, 39 | "hendrycksTest-college_chemistry": { 40 | "acc": 0.41, 41 | "acc_stderr": 0.04943110704237101, 42 | "acc_norm": 0.41, 43 | "acc_norm_stderr": 0.04943110704237101 44 | }, 45 | "hendrycksTest-college_computer_science": { 46 | "acc": 0.46, 47 | "acc_stderr": 0.05009082659620333, 48 | "acc_norm": 0.46, 49 | "acc_norm_stderr": 0.05009082659620333 50 | }, 51 | "hendrycksTest-college_mathematics": { 52 | "acc": 0.27, 53 | "acc_stderr": 0.0446196043338474, 54 | "acc_norm": 0.27, 55 | "acc_norm_stderr": 0.0446196043338474 56 | }, 57 | "hendrycksTest-college_medicine": { 58 | "acc": 0.4913294797687861, 59 | "acc_stderr": 0.03811890988940412, 60 | "acc_norm": 0.4913294797687861, 61 | "acc_norm_stderr": 0.03811890988940412 62 | }, 63 | "hendrycksTest-college_physics": { 64 | "acc": 0.2549019607843137, 65 | "acc_stderr": 0.0433643270799318, 
66 | "acc_norm": 0.2549019607843137, 67 | "acc_norm_stderr": 0.0433643270799318 68 | }, 69 | "hendrycksTest-computer_security": { 70 | "acc": 0.7, 71 | "acc_stderr": 0.046056618647183814, 72 | "acc_norm": 0.7, 73 | "acc_norm_stderr": 0.046056618647183814 74 | }, 75 | "hendrycksTest-conceptual_physics": { 76 | "acc": 0.40425531914893614, 77 | "acc_stderr": 0.032081157507886836, 78 | "acc_norm": 0.40425531914893614, 79 | "acc_norm_stderr": 0.032081157507886836 80 | }, 81 | "hendrycksTest-econometrics": { 82 | "acc": 0.2719298245614035, 83 | "acc_stderr": 0.04185774424022056, 84 | "acc_norm": 0.2719298245614035, 85 | "acc_norm_stderr": 0.04185774424022056 86 | }, 87 | "hendrycksTest-electrical_engineering": { 88 | "acc": 0.4413793103448276, 89 | "acc_stderr": 0.04137931034482758, 90 | "acc_norm": 0.4413793103448276, 91 | "acc_norm_stderr": 0.04137931034482758 92 | }, 93 | "hendrycksTest-elementary_mathematics": { 94 | "acc": 0.34656084656084657, 95 | "acc_stderr": 0.024508777521028428, 96 | "acc_norm": 0.34656084656084657, 97 | "acc_norm_stderr": 0.024508777521028428 98 | }, 99 | "hendrycksTest-formal_logic": { 100 | "acc": 0.2698412698412698, 101 | "acc_stderr": 0.03970158273235172, 102 | "acc_norm": 0.2698412698412698, 103 | "acc_norm_stderr": 0.03970158273235172 104 | }, 105 | "hendrycksTest-global_facts": { 106 | "acc": 0.38, 107 | "acc_stderr": 0.048783173121456316, 108 | "acc_norm": 0.38, 109 | "acc_norm_stderr": 0.048783173121456316 110 | }, 111 | "hendrycksTest-high_school_biology": { 112 | "acc": 0.6161290322580645, 113 | "acc_stderr": 0.02766618207553965, 114 | "acc_norm": 0.6161290322580645, 115 | "acc_norm_stderr": 0.02766618207553965 116 | }, 117 | "hendrycksTest-high_school_chemistry": { 118 | "acc": 0.4039408866995074, 119 | "acc_stderr": 0.0345245390382204, 120 | "acc_norm": 0.4039408866995074, 121 | "acc_norm_stderr": 0.0345245390382204 122 | }, 123 | "hendrycksTest-high_school_computer_science": { 124 | "acc": 0.51, 125 | "acc_stderr": 0.05024183937956912, 126 | "acc_norm": 0.51, 127 | "acc_norm_stderr": 0.05024183937956912 128 | }, 129 | "hendrycksTest-high_school_european_history": { 130 | "acc": 0.6424242424242425, 131 | "acc_stderr": 0.03742597043806586, 132 | "acc_norm": 0.6424242424242425, 133 | "acc_norm_stderr": 0.03742597043806586 134 | }, 135 | "hendrycksTest-high_school_geography": { 136 | "acc": 0.6464646464646465, 137 | "acc_stderr": 0.03406086723547155, 138 | "acc_norm": 0.6464646464646465, 139 | "acc_norm_stderr": 0.03406086723547155 140 | }, 141 | "hendrycksTest-high_school_government_and_politics": { 142 | "acc": 0.7564766839378239, 143 | "acc_stderr": 0.03097543638684544, 144 | "acc_norm": 0.7564766839378239, 145 | "acc_norm_stderr": 0.03097543638684544 146 | }, 147 | "hendrycksTest-high_school_macroeconomics": { 148 | "acc": 0.47435897435897434, 149 | "acc_stderr": 0.025317649726448663, 150 | "acc_norm": 0.47435897435897434, 151 | "acc_norm_stderr": 0.025317649726448663 152 | }, 153 | "hendrycksTest-high_school_mathematics": { 154 | "acc": 0.25925925925925924, 155 | "acc_stderr": 0.02671924078371216, 156 | "acc_norm": 0.25925925925925924, 157 | "acc_norm_stderr": 0.02671924078371216 158 | }, 159 | "hendrycksTest-high_school_microeconomics": { 160 | "acc": 0.5378151260504201, 161 | "acc_stderr": 0.032385469487589795, 162 | "acc_norm": 0.5378151260504201, 163 | "acc_norm_stderr": 0.032385469487589795 164 | }, 165 | "hendrycksTest-high_school_physics": { 166 | "acc": 0.31125827814569534, 167 | "acc_stderr": 0.03780445850526733, 168 | "acc_norm": 
0.31125827814569534, 169 | "acc_norm_stderr": 0.03780445850526733 170 | }, 171 | "hendrycksTest-high_school_psychology": { 172 | "acc": 0.708256880733945, 173 | "acc_stderr": 0.019489300968876525, 174 | "acc_norm": 0.708256880733945, 175 | "acc_norm_stderr": 0.019489300968876525 176 | }, 177 | "hendrycksTest-high_school_statistics": { 178 | "acc": 0.4027777777777778, 179 | "acc_stderr": 0.03344887382997867, 180 | "acc_norm": 0.4027777777777778, 181 | "acc_norm_stderr": 0.03344887382997867 182 | }, 183 | "hendrycksTest-high_school_us_history": { 184 | "acc": 0.6813725490196079, 185 | "acc_stderr": 0.03270287181482081, 186 | "acc_norm": 0.6813725490196079, 187 | "acc_norm_stderr": 0.03270287181482081 188 | }, 189 | "hendrycksTest-high_school_world_history": { 190 | "acc": 0.6835443037974683, 191 | "acc_stderr": 0.030274974880218977, 192 | "acc_norm": 0.6835443037974683, 193 | "acc_norm_stderr": 0.030274974880218977 194 | }, 195 | "hendrycksTest-human_aging": { 196 | "acc": 0.6367713004484304, 197 | "acc_stderr": 0.03227790442850499, 198 | "acc_norm": 0.6367713004484304, 199 | "acc_norm_stderr": 0.03227790442850499 200 | }, 201 | "hendrycksTest-human_sexuality": { 202 | "acc": 0.6106870229007634, 203 | "acc_stderr": 0.04276486542814591, 204 | "acc_norm": 0.6106870229007634, 205 | "acc_norm_stderr": 0.04276486542814591 206 | }, 207 | "hendrycksTest-international_law": { 208 | "acc": 0.7024793388429752, 209 | "acc_stderr": 0.04173349148083499, 210 | "acc_norm": 0.7024793388429752, 211 | "acc_norm_stderr": 0.04173349148083499 212 | }, 213 | "hendrycksTest-jurisprudence": { 214 | "acc": 0.6759259259259259, 215 | "acc_stderr": 0.045245960070300496, 216 | "acc_norm": 0.6759259259259259, 217 | "acc_norm_stderr": 0.045245960070300496 218 | }, 219 | "hendrycksTest-logical_fallacies": { 220 | "acc": 0.6257668711656442, 221 | "acc_stderr": 0.03802068102899615, 222 | "acc_norm": 0.6257668711656442, 223 | "acc_norm_stderr": 0.03802068102899615 224 | }, 225 | "hendrycksTest-machine_learning": { 226 | "acc": 0.3125, 227 | "acc_stderr": 0.043994650575715215, 228 | "acc_norm": 0.3125, 229 | "acc_norm_stderr": 0.043994650575715215 230 | }, 231 | "hendrycksTest-management": { 232 | "acc": 0.6504854368932039, 233 | "acc_stderr": 0.04721188506097172, 234 | "acc_norm": 0.6504854368932039, 235 | "acc_norm_stderr": 0.04721188506097172 236 | }, 237 | "hendrycksTest-marketing": { 238 | "acc": 0.8034188034188035, 239 | "acc_stderr": 0.02603538609895129, 240 | "acc_norm": 0.8034188034188035, 241 | "acc_norm_stderr": 0.02603538609895129 242 | }, 243 | "hendrycksTest-medical_genetics": { 244 | "acc": 0.55, 245 | "acc_stderr": 0.05, 246 | "acc_norm": 0.55, 247 | "acc_norm_stderr": 0.05 248 | }, 249 | "hendrycksTest-miscellaneous": { 250 | "acc": 0.735632183908046, 251 | "acc_stderr": 0.015769984840690515, 252 | "acc_norm": 0.735632183908046, 253 | "acc_norm_stderr": 0.015769984840690515 254 | }, 255 | "hendrycksTest-moral_disputes": { 256 | "acc": 0.6069364161849711, 257 | "acc_stderr": 0.02629622791561367, 258 | "acc_norm": 0.6069364161849711, 259 | "acc_norm_stderr": 0.02629622791561367 260 | }, 261 | "hendrycksTest-moral_scenarios": { 262 | "acc": 0.34413407821229053, 263 | "acc_stderr": 0.015889221313307094, 264 | "acc_norm": 0.34413407821229053, 265 | "acc_norm_stderr": 0.015889221313307094 266 | }, 267 | "hendrycksTest-nutrition": { 268 | "acc": 0.5915032679738562, 269 | "acc_stderr": 0.028146405993096358, 270 | "acc_norm": 0.5915032679738562, 271 | "acc_norm_stderr": 0.028146405993096358 272 | }, 273 | 
"hendrycksTest-philosophy": { 274 | "acc": 0.6205787781350482, 275 | "acc_stderr": 0.02755994980234782, 276 | "acc_norm": 0.6205787781350482, 277 | "acc_norm_stderr": 0.02755994980234782 278 | }, 279 | "hendrycksTest-prehistory": { 280 | "acc": 0.5987654320987654, 281 | "acc_stderr": 0.027272582849839796, 282 | "acc_norm": 0.5987654320987654, 283 | "acc_norm_stderr": 0.027272582849839796 284 | }, 285 | "hendrycksTest-professional_accounting": { 286 | "acc": 0.425531914893617, 287 | "acc_stderr": 0.029494827600144373, 288 | "acc_norm": 0.425531914893617, 289 | "acc_norm_stderr": 0.029494827600144373 290 | }, 291 | "hendrycksTest-professional_law": { 292 | "acc": 0.4165580182529335, 293 | "acc_stderr": 0.01259115324505739, 294 | "acc_norm": 0.4165580182529335, 295 | "acc_norm_stderr": 0.01259115324505739 296 | }, 297 | "hendrycksTest-professional_medicine": { 298 | "acc": 0.43014705882352944, 299 | "acc_stderr": 0.030074971917302875, 300 | "acc_norm": 0.43014705882352944, 301 | "acc_norm_stderr": 0.030074971917302875 302 | }, 303 | "hendrycksTest-professional_psychology": { 304 | "acc": 0.5408496732026143, 305 | "acc_stderr": 0.020160213617222516, 306 | "acc_norm": 0.5408496732026143, 307 | "acc_norm_stderr": 0.020160213617222516 308 | }, 309 | "hendrycksTest-public_relations": { 310 | "acc": 0.6363636363636364, 311 | "acc_stderr": 0.04607582090719976, 312 | "acc_norm": 0.6363636363636364, 313 | "acc_norm_stderr": 0.04607582090719976 314 | }, 315 | "hendrycksTest-security_studies": { 316 | "acc": 0.6040816326530613, 317 | "acc_stderr": 0.03130802899065686, 318 | "acc_norm": 0.6040816326530613, 319 | "acc_norm_stderr": 0.03130802899065686 320 | }, 321 | "hendrycksTest-sociology": { 322 | "acc": 0.681592039800995, 323 | "acc_stderr": 0.032941184790540944, 324 | "acc_norm": 0.681592039800995, 325 | "acc_norm_stderr": 0.032941184790540944 326 | }, 327 | "hendrycksTest-us_foreign_policy": { 328 | "acc": 0.83, 329 | "acc_stderr": 0.0377525168068637, 330 | "acc_norm": 0.83, 331 | "acc_norm_stderr": 0.0377525168068637 332 | }, 333 | "hendrycksTest-virology": { 334 | "acc": 0.41566265060240964, 335 | "acc_stderr": 0.03836722176598052, 336 | "acc_norm": 0.41566265060240964, 337 | "acc_norm_stderr": 0.03836722176598052 338 | }, 339 | "hendrycksTest-world_religions": { 340 | "acc": 0.7426900584795322, 341 | "acc_stderr": 0.03352799844161865, 342 | "acc_norm": 0.7426900584795322, 343 | "acc_norm_stderr": 0.03352799844161865 344 | } 345 | }, 346 | "versions": { 347 | "hendrycksTest-abstract_algebra": 1, 348 | "hendrycksTest-anatomy": 1, 349 | "hendrycksTest-astronomy": 1, 350 | "hendrycksTest-business_ethics": 1, 351 | "hendrycksTest-clinical_knowledge": 1, 352 | "hendrycksTest-college_biology": 1, 353 | "hendrycksTest-college_chemistry": 1, 354 | "hendrycksTest-college_computer_science": 1, 355 | "hendrycksTest-college_mathematics": 1, 356 | "hendrycksTest-college_medicine": 1, 357 | "hendrycksTest-college_physics": 1, 358 | "hendrycksTest-computer_security": 1, 359 | "hendrycksTest-conceptual_physics": 1, 360 | "hendrycksTest-econometrics": 1, 361 | "hendrycksTest-electrical_engineering": 1, 362 | "hendrycksTest-elementary_mathematics": 1, 363 | "hendrycksTest-formal_logic": 1, 364 | "hendrycksTest-global_facts": 1, 365 | "hendrycksTest-high_school_biology": 1, 366 | "hendrycksTest-high_school_chemistry": 1, 367 | "hendrycksTest-high_school_computer_science": 1, 368 | "hendrycksTest-high_school_european_history": 1, 369 | "hendrycksTest-high_school_geography": 1, 370 | 
"hendrycksTest-high_school_government_and_politics": 1, 371 | "hendrycksTest-high_school_macroeconomics": 1, 372 | "hendrycksTest-high_school_mathematics": 1, 373 | "hendrycksTest-high_school_microeconomics": 1, 374 | "hendrycksTest-high_school_physics": 1, 375 | "hendrycksTest-high_school_psychology": 1, 376 | "hendrycksTest-high_school_statistics": 1, 377 | "hendrycksTest-high_school_us_history": 1, 378 | "hendrycksTest-high_school_world_history": 1, 379 | "hendrycksTest-human_aging": 1, 380 | "hendrycksTest-human_sexuality": 1, 381 | "hendrycksTest-international_law": 1, 382 | "hendrycksTest-jurisprudence": 1, 383 | "hendrycksTest-logical_fallacies": 1, 384 | "hendrycksTest-machine_learning": 1, 385 | "hendrycksTest-management": 1, 386 | "hendrycksTest-marketing": 1, 387 | "hendrycksTest-medical_genetics": 1, 388 | "hendrycksTest-miscellaneous": 1, 389 | "hendrycksTest-moral_disputes": 1, 390 | "hendrycksTest-moral_scenarios": 1, 391 | "hendrycksTest-nutrition": 1, 392 | "hendrycksTest-philosophy": 1, 393 | "hendrycksTest-prehistory": 1, 394 | "hendrycksTest-professional_accounting": 1, 395 | "hendrycksTest-professional_law": 1, 396 | "hendrycksTest-professional_medicine": 1, 397 | "hendrycksTest-professional_psychology": 1, 398 | "hendrycksTest-public_relations": 1, 399 | "hendrycksTest-security_studies": 1, 400 | "hendrycksTest-sociology": 1, 401 | "hendrycksTest-us_foreign_policy": 1, 402 | "hendrycksTest-virology": 1, 403 | "hendrycksTest-world_religions": 1 404 | }, 405 | "config": { 406 | "model": "hf-causal-experimental", 407 | "model_args": "pretrained=NousResearch/LLaMA-2-13B-YaRN-64K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 408 | "num_fewshot": 5, 409 | "batch_size": "2", 410 | "batch_sizes": [], 411 | "device": null, 412 | "no_cache": false, 413 | "limit": null, 414 | "bootstrap_iters": 100000, 415 | "description_dict": {} 416 | } 417 | } -------------------------------------------------------------------------------- /data/Yarn-Llama-2-7b-128k-mmlu.json: -------------------------------------------------------------------------------- 1 | { 2 | "results": { 3 | "hendrycksTest-abstract_algebra": { 4 | "acc": 0.33, 5 | "acc_stderr": 0.04725815626252605, 6 | "acc_norm": 0.33, 7 | "acc_norm_stderr": 0.04725815626252605 8 | }, 9 | "hendrycksTest-anatomy": { 10 | "acc": 0.42962962962962964, 11 | "acc_stderr": 0.04276349494376599, 12 | "acc_norm": 0.42962962962962964, 13 | "acc_norm_stderr": 0.04276349494376599 14 | }, 15 | "hendrycksTest-astronomy": { 16 | "acc": 0.375, 17 | "acc_stderr": 0.039397364351956274, 18 | "acc_norm": 0.375, 19 | "acc_norm_stderr": 0.039397364351956274 20 | }, 21 | "hendrycksTest-business_ethics": { 22 | "acc": 0.38, 23 | "acc_stderr": 0.048783173121456316, 24 | "acc_norm": 0.38, 25 | "acc_norm_stderr": 0.048783173121456316 26 | }, 27 | "hendrycksTest-clinical_knowledge": { 28 | "acc": 0.4075471698113208, 29 | "acc_stderr": 0.030242233800854494, 30 | "acc_norm": 0.4075471698113208, 31 | "acc_norm_stderr": 0.030242233800854494 32 | }, 33 | "hendrycksTest-college_biology": { 34 | "acc": 0.4375, 35 | "acc_stderr": 0.04148415739394154, 36 | "acc_norm": 0.4375, 37 | "acc_norm_stderr": 0.04148415739394154 38 | }, 39 | "hendrycksTest-college_chemistry": { 40 | "acc": 0.32, 41 | "acc_stderr": 0.046882617226215034, 42 | "acc_norm": 0.32, 43 | "acc_norm_stderr": 0.046882617226215034 44 | }, 45 | "hendrycksTest-college_computer_science": { 46 | "acc": 0.38, 47 | "acc_stderr": 0.048783173121456316, 48 | "acc_norm": 0.38, 49 | "acc_norm_stderr": 
0.048783173121456316 50 | }, 51 | "hendrycksTest-college_mathematics": { 52 | "acc": 0.22, 53 | "acc_stderr": 0.0416333199893227, 54 | "acc_norm": 0.22, 55 | "acc_norm_stderr": 0.0416333199893227 56 | }, 57 | "hendrycksTest-college_medicine": { 58 | "acc": 0.37572254335260113, 59 | "acc_stderr": 0.036928207672648664, 60 | "acc_norm": 0.37572254335260113, 61 | "acc_norm_stderr": 0.036928207672648664 62 | }, 63 | "hendrycksTest-college_physics": { 64 | "acc": 0.21568627450980393, 65 | "acc_stderr": 0.04092563958237654, 66 | "acc_norm": 0.21568627450980393, 67 | "acc_norm_stderr": 0.04092563958237654 68 | }, 69 | "hendrycksTest-computer_security": { 70 | "acc": 0.56, 71 | "acc_stderr": 0.04988876515698589, 72 | "acc_norm": 0.56, 73 | "acc_norm_stderr": 0.04988876515698589 74 | }, 75 | "hendrycksTest-conceptual_physics": { 76 | "acc": 0.4127659574468085, 77 | "acc_stderr": 0.03218471141400351, 78 | "acc_norm": 0.4127659574468085, 79 | "acc_norm_stderr": 0.03218471141400351 80 | }, 81 | "hendrycksTest-econometrics": { 82 | "acc": 0.2719298245614035, 83 | "acc_stderr": 0.04185774424022056, 84 | "acc_norm": 0.2719298245614035, 85 | "acc_norm_stderr": 0.04185774424022056 86 | }, 87 | "hendrycksTest-electrical_engineering": { 88 | "acc": 0.3931034482758621, 89 | "acc_stderr": 0.0407032901370707, 90 | "acc_norm": 0.3931034482758621, 91 | "acc_norm_stderr": 0.0407032901370707 92 | }, 93 | "hendrycksTest-elementary_mathematics": { 94 | "acc": 0.25396825396825395, 95 | "acc_stderr": 0.022418042891113946, 96 | "acc_norm": 0.25396825396825395, 97 | "acc_norm_stderr": 0.022418042891113946 98 | }, 99 | "hendrycksTest-formal_logic": { 100 | "acc": 0.29365079365079366, 101 | "acc_stderr": 0.040735243221471255, 102 | "acc_norm": 0.29365079365079366, 103 | "acc_norm_stderr": 0.040735243221471255 104 | }, 105 | "hendrycksTest-global_facts": { 106 | "acc": 0.28, 107 | "acc_stderr": 0.04512608598542128, 108 | "acc_norm": 0.28, 109 | "acc_norm_stderr": 0.04512608598542128 110 | }, 111 | "hendrycksTest-high_school_biology": { 112 | "acc": 0.4129032258064516, 113 | "acc_stderr": 0.02800913812540039, 114 | "acc_norm": 0.4129032258064516, 115 | "acc_norm_stderr": 0.02800913812540039 116 | }, 117 | "hendrycksTest-high_school_chemistry": { 118 | "acc": 0.3448275862068966, 119 | "acc_stderr": 0.03344283744280459, 120 | "acc_norm": 0.3448275862068966, 121 | "acc_norm_stderr": 0.03344283744280459 122 | }, 123 | "hendrycksTest-high_school_computer_science": { 124 | "acc": 0.41, 125 | "acc_stderr": 0.04943110704237101, 126 | "acc_norm": 0.41, 127 | "acc_norm_stderr": 0.04943110704237101 128 | }, 129 | "hendrycksTest-high_school_european_history": { 130 | "acc": 0.6484848484848484, 131 | "acc_stderr": 0.037282069986826503, 132 | "acc_norm": 0.6484848484848484, 133 | "acc_norm_stderr": 0.037282069986826503 134 | }, 135 | "hendrycksTest-high_school_geography": { 136 | "acc": 0.3787878787878788, 137 | "acc_stderr": 0.03456088731993747, 138 | "acc_norm": 0.3787878787878788, 139 | "acc_norm_stderr": 0.03456088731993747 140 | }, 141 | "hendrycksTest-high_school_government_and_politics": { 142 | "acc": 0.538860103626943, 143 | "acc_stderr": 0.03597524411734578, 144 | "acc_norm": 0.538860103626943, 145 | "acc_norm_stderr": 0.03597524411734578 146 | }, 147 | "hendrycksTest-high_school_macroeconomics": { 148 | "acc": 0.34102564102564104, 149 | "acc_stderr": 0.024035489676335068, 150 | "acc_norm": 0.34102564102564104, 151 | "acc_norm_stderr": 0.024035489676335068 152 | }, 153 | "hendrycksTest-high_school_mathematics": { 154 | "acc": 
0.2518518518518518, 155 | "acc_stderr": 0.02646611753895991, 156 | "acc_norm": 0.2518518518518518, 157 | "acc_norm_stderr": 0.02646611753895991 158 | }, 159 | "hendrycksTest-high_school_microeconomics": { 160 | "acc": 0.3739495798319328, 161 | "acc_stderr": 0.03142946637883708, 162 | "acc_norm": 0.3739495798319328, 163 | "acc_norm_stderr": 0.03142946637883708 164 | }, 165 | "hendrycksTest-high_school_physics": { 166 | "acc": 0.31788079470198677, 167 | "acc_stderr": 0.038020397601079024, 168 | "acc_norm": 0.31788079470198677, 169 | "acc_norm_stderr": 0.038020397601079024 170 | }, 171 | "hendrycksTest-high_school_psychology": { 172 | "acc": 0.5155963302752293, 173 | "acc_stderr": 0.02142689153920805, 174 | "acc_norm": 0.5155963302752293, 175 | "acc_norm_stderr": 0.02142689153920805 176 | }, 177 | "hendrycksTest-high_school_statistics": { 178 | "acc": 0.25462962962962965, 179 | "acc_stderr": 0.029711275860005357, 180 | "acc_norm": 0.25462962962962965, 181 | "acc_norm_stderr": 0.029711275860005357 182 | }, 183 | "hendrycksTest-high_school_us_history": { 184 | "acc": 0.5, 185 | "acc_stderr": 0.03509312031717982, 186 | "acc_norm": 0.5, 187 | "acc_norm_stderr": 0.03509312031717982 188 | }, 189 | "hendrycksTest-high_school_world_history": { 190 | "acc": 0.569620253164557, 191 | "acc_stderr": 0.032230171959375976, 192 | "acc_norm": 0.569620253164557, 193 | "acc_norm_stderr": 0.032230171959375976 194 | }, 195 | "hendrycksTest-human_aging": { 196 | "acc": 0.4977578475336323, 197 | "acc_stderr": 0.033557465352232634, 198 | "acc_norm": 0.4977578475336323, 199 | "acc_norm_stderr": 0.033557465352232634 200 | }, 201 | "hendrycksTest-human_sexuality": { 202 | "acc": 0.45038167938931295, 203 | "acc_stderr": 0.04363643698524779, 204 | "acc_norm": 0.45038167938931295, 205 | "acc_norm_stderr": 0.04363643698524779 206 | }, 207 | "hendrycksTest-international_law": { 208 | "acc": 0.5785123966942148, 209 | "acc_stderr": 0.04507732278775087, 210 | "acc_norm": 0.5785123966942148, 211 | "acc_norm_stderr": 0.04507732278775087 212 | }, 213 | "hendrycksTest-jurisprudence": { 214 | "acc": 0.48148148148148145, 215 | "acc_stderr": 0.04830366024635331, 216 | "acc_norm": 0.48148148148148145, 217 | "acc_norm_stderr": 0.04830366024635331 218 | }, 219 | "hendrycksTest-logical_fallacies": { 220 | "acc": 0.3496932515337423, 221 | "acc_stderr": 0.03746668325470021, 222 | "acc_norm": 0.3496932515337423, 223 | "acc_norm_stderr": 0.03746668325470021 224 | }, 225 | "hendrycksTest-machine_learning": { 226 | "acc": 0.32142857142857145, 227 | "acc_stderr": 0.044328040552915206, 228 | "acc_norm": 0.32142857142857145, 229 | "acc_norm_stderr": 0.044328040552915206 230 | }, 231 | "hendrycksTest-management": { 232 | "acc": 0.3883495145631068, 233 | "acc_stderr": 0.0482572933735639, 234 | "acc_norm": 0.3883495145631068, 235 | "acc_norm_stderr": 0.0482572933735639 236 | }, 237 | "hendrycksTest-marketing": { 238 | "acc": 0.6410256410256411, 239 | "acc_stderr": 0.03142616993791924, 240 | "acc_norm": 0.6410256410256411, 241 | "acc_norm_stderr": 0.03142616993791924 242 | }, 243 | "hendrycksTest-medical_genetics": { 244 | "acc": 0.47, 245 | "acc_stderr": 0.050161355804659205, 246 | "acc_norm": 0.47, 247 | "acc_norm_stderr": 0.050161355804659205 248 | }, 249 | "hendrycksTest-miscellaneous": { 250 | "acc": 0.5696040868454662, 251 | "acc_stderr": 0.0177058687762924, 252 | "acc_norm": 0.5696040868454662, 253 | "acc_norm_stderr": 0.0177058687762924 254 | }, 255 | "hendrycksTest-moral_disputes": { 256 | "acc": 0.47109826589595377, 257 | "acc_stderr": 
0.026874085883518348, 258 | "acc_norm": 0.47109826589595377, 259 | "acc_norm_stderr": 0.026874085883518348 260 | }, 261 | "hendrycksTest-moral_scenarios": { 262 | "acc": 0.23798882681564246, 263 | "acc_stderr": 0.014242630070574915, 264 | "acc_norm": 0.23798882681564246, 265 | "acc_norm_stderr": 0.014242630070574915 266 | }, 267 | "hendrycksTest-nutrition": { 268 | "acc": 0.42483660130718953, 269 | "acc_stderr": 0.028304576673141117, 270 | "acc_norm": 0.42483660130718953, 271 | "acc_norm_stderr": 0.028304576673141117 272 | }, 273 | "hendrycksTest-philosophy": { 274 | "acc": 0.5530546623794212, 275 | "acc_stderr": 0.028237769422085335, 276 | "acc_norm": 0.5530546623794212, 277 | "acc_norm_stderr": 0.028237769422085335 278 | }, 279 | "hendrycksTest-prehistory": { 280 | "acc": 0.4845679012345679, 281 | "acc_stderr": 0.02780749004427619, 282 | "acc_norm": 0.4845679012345679, 283 | "acc_norm_stderr": 0.02780749004427619 284 | }, 285 | "hendrycksTest-professional_accounting": { 286 | "acc": 0.36524822695035464, 287 | "acc_stderr": 0.028723863853281278, 288 | "acc_norm": 0.36524822695035464, 289 | "acc_norm_stderr": 0.028723863853281278 290 | }, 291 | "hendrycksTest-professional_law": { 292 | "acc": 0.3455019556714472, 293 | "acc_stderr": 0.012145303004087202, 294 | "acc_norm": 0.3455019556714472, 295 | "acc_norm_stderr": 0.012145303004087202 296 | }, 297 | "hendrycksTest-professional_medicine": { 298 | "acc": 0.39705882352941174, 299 | "acc_stderr": 0.029722152099280072, 300 | "acc_norm": 0.39705882352941174, 301 | "acc_norm_stderr": 0.029722152099280072 302 | }, 303 | "hendrycksTest-professional_psychology": { 304 | "acc": 0.43300653594771243, 305 | "acc_stderr": 0.020045442473324227, 306 | "acc_norm": 0.43300653594771243, 307 | "acc_norm_stderr": 0.020045442473324227 308 | }, 309 | "hendrycksTest-public_relations": { 310 | "acc": 0.5272727272727272, 311 | "acc_stderr": 0.04782001791380061, 312 | "acc_norm": 0.5272727272727272, 313 | "acc_norm_stderr": 0.04782001791380061 314 | }, 315 | "hendrycksTest-security_studies": { 316 | "acc": 0.363265306122449, 317 | "acc_stderr": 0.03078905113903081, 318 | "acc_norm": 0.363265306122449, 319 | "acc_norm_stderr": 0.03078905113903081 320 | }, 321 | "hendrycksTest-sociology": { 322 | "acc": 0.5223880597014925, 323 | "acc_stderr": 0.035319879302087305, 324 | "acc_norm": 0.5223880597014925, 325 | "acc_norm_stderr": 0.035319879302087305 326 | }, 327 | "hendrycksTest-us_foreign_policy": { 328 | "acc": 0.68, 329 | "acc_stderr": 0.04688261722621505, 330 | "acc_norm": 0.68, 331 | "acc_norm_stderr": 0.04688261722621505 332 | }, 333 | "hendrycksTest-virology": { 334 | "acc": 0.42168674698795183, 335 | "acc_stderr": 0.03844453181770917, 336 | "acc_norm": 0.42168674698795183, 337 | "acc_norm_stderr": 0.03844453181770917 338 | }, 339 | "hendrycksTest-world_religions": { 340 | "acc": 0.6432748538011696, 341 | "acc_stderr": 0.03674013002860954, 342 | "acc_norm": 0.6432748538011696, 343 | "acc_norm_stderr": 0.03674013002860954 344 | } 345 | }, 346 | "versions": { 347 | "hendrycksTest-abstract_algebra": 1, 348 | "hendrycksTest-anatomy": 1, 349 | "hendrycksTest-astronomy": 1, 350 | "hendrycksTest-business_ethics": 1, 351 | "hendrycksTest-clinical_knowledge": 1, 352 | "hendrycksTest-college_biology": 1, 353 | "hendrycksTest-college_chemistry": 1, 354 | "hendrycksTest-college_computer_science": 1, 355 | "hendrycksTest-college_mathematics": 1, 356 | "hendrycksTest-college_medicine": 1, 357 | "hendrycksTest-college_physics": 1, 358 | "hendrycksTest-computer_security": 1, 359 
| "hendrycksTest-conceptual_physics": 1, 360 | "hendrycksTest-econometrics": 1, 361 | "hendrycksTest-electrical_engineering": 1, 362 | "hendrycksTest-elementary_mathematics": 1, 363 | "hendrycksTest-formal_logic": 1, 364 | "hendrycksTest-global_facts": 1, 365 | "hendrycksTest-high_school_biology": 1, 366 | "hendrycksTest-high_school_chemistry": 1, 367 | "hendrycksTest-high_school_computer_science": 1, 368 | "hendrycksTest-high_school_european_history": 1, 369 | "hendrycksTest-high_school_geography": 1, 370 | "hendrycksTest-high_school_government_and_politics": 1, 371 | "hendrycksTest-high_school_macroeconomics": 1, 372 | "hendrycksTest-high_school_mathematics": 1, 373 | "hendrycksTest-high_school_microeconomics": 1, 374 | "hendrycksTest-high_school_physics": 1, 375 | "hendrycksTest-high_school_psychology": 1, 376 | "hendrycksTest-high_school_statistics": 1, 377 | "hendrycksTest-high_school_us_history": 1, 378 | "hendrycksTest-high_school_world_history": 1, 379 | "hendrycksTest-human_aging": 1, 380 | "hendrycksTest-human_sexuality": 1, 381 | "hendrycksTest-international_law": 1, 382 | "hendrycksTest-jurisprudence": 1, 383 | "hendrycksTest-logical_fallacies": 1, 384 | "hendrycksTest-machine_learning": 1, 385 | "hendrycksTest-management": 1, 386 | "hendrycksTest-marketing": 1, 387 | "hendrycksTest-medical_genetics": 1, 388 | "hendrycksTest-miscellaneous": 1, 389 | "hendrycksTest-moral_disputes": 1, 390 | "hendrycksTest-moral_scenarios": 1, 391 | "hendrycksTest-nutrition": 1, 392 | "hendrycksTest-philosophy": 1, 393 | "hendrycksTest-prehistory": 1, 394 | "hendrycksTest-professional_accounting": 1, 395 | "hendrycksTest-professional_law": 1, 396 | "hendrycksTest-professional_medicine": 1, 397 | "hendrycksTest-professional_psychology": 1, 398 | "hendrycksTest-public_relations": 1, 399 | "hendrycksTest-security_studies": 1, 400 | "hendrycksTest-sociology": 1, 401 | "hendrycksTest-us_foreign_policy": 1, 402 | "hendrycksTest-virology": 1, 403 | "hendrycksTest-world_religions": 1 404 | }, 405 | "config": { 406 | "model": "hf-causal-experimental", 407 | "model_args": "pretrained=NousResearch/LLaMA-2-7B-YaRN-128K,use_accelerate=True,dtype=bfloat16,trust_remote_code=True", 408 | "num_fewshot": 5, 409 | "batch_size": "2", 410 | "batch_sizes": [], 411 | "device": null, 412 | "no_cache": false, 413 | "limit": null, 414 | "bootstrap_iters": 100000, 415 | "description_dict": {} 416 | } 417 | } --------------------------------------------------------------------------------