├── .gitignore ├── .gitmodules ├── LICENSE ├── Makefile ├── README.md ├── benchmark ├── .gitignore ├── __init__.py ├── benchmark_base.py ├── dataset_download.py ├── longbench.py ├── metrics.py ├── ms_marco_v1_1.py ├── multi_news.py ├── profile_parser.py ├── profiles │ └── document_summary_simple.json ├── results │ └── README.md ├── schema │ ├── README.md │ └── test │ │ ├── empty.xml │ │ ├── prompt_mbti.xml │ │ ├── schema_code_generation.xml │ │ ├── schema_long_task_1.xml │ │ ├── schema_mbti.xml │ │ ├── schema_mbti_short.xml │ │ ├── schema_persona.xml │ │ └── schema_persona_long.xml ├── squad_v2.py └── utils.py ├── benchmark_memcpy.py ├── config ├── dataset_maxlen.json ├── dataset_prompt.json ├── llm_config_falcon_40b.json ├── llm_config_falcon_7b.json ├── llm_config_llama2_13b.json ├── llm_config_llama2_7b.json ├── llm_config_longchat_7b.json ├── llm_config_mpt_30b.json ├── llm_config_mpt_7b.json └── llm_config_vicuna_7b.json ├── demo.py ├── eval.py ├── eval_acc.py ├── eval_acc.slurm ├── eval_sys.py ├── eval_sys_a100-7b-cpu.slurm ├── eval_sys_a100-7b-gpu.slurm ├── eval_sys_a40-7b-cpu.slurm ├── eval_sys_a40-7b-gpu.slurm ├── examples ├── code_generation_bookstore.xml ├── code_generation_game.xml ├── parameterized_prompts.xml ├── persona_generation.xml └── personalization-education.xml ├── g.sh ├── get_scores.py ├── metrics.py ├── promptcache ├── __init__.py ├── cache_engine.py ├── compiler.py ├── conversation.py ├── generation_engine.py ├── inference.py ├── model │ ├── __init__.py │ ├── falcon.py │ ├── llama2.py │ └── mpt.py ├── prompt.py └── schema.py ├── requirements.txt ├── score.py ├── scripts ├── README.md ├── benchmark_setup.json ├── eval_llama2.slurm └── run_benchmarks.py └── tests ├── test.py └── test_document_summary.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | results/ 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # VSCode workspace 163 | *.code-workspace 164 | .vscode/ 165 | 166 | # Models 167 | meta-llama/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dependency/LongBench"] 2 | path = dependency/LongBench 3 | url = https://github.com/THUDM/LongBench.git 4 | ignore = dirty 5 | [submodule "dependency/bleurt"] 6 | path = dependency/bleurt 7 | url = https://github.com/google-research/bleurt 8 | ignore = dirty 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 In Gim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: eval 2 | 3 | # LLM_CONFIG_PATH ?= ./config/llm_config_falcon.json 4 | LLM_CONFIG_PATH ?= ./config/llm_config_llama2.json 5 | DATASET ?= squad_v2 6 | # DATASET ?= multi_news 7 | # DATASET ?= ms_marco 8 | # DATASET ?= narrativeqa 9 | ENABLE_CACHE ?= False 10 | SPLIT ?= 0,1 11 | TEST_LATENCY ?= False 12 | USE_CPU_FOR_INFERENCE ?= False 13 | eval: 14 | CUDA_VISIBLE_DEVICES=0 python3 eval.py \ 15 | --llm_config_path $(LLM_CONFIG_PATH) \ 16 | --dataset $(DATASET) \ 17 | --enable_cache $(ENABLE_CACHE) \ 18 | --split $(SPLIT) \ 19 | --test_latency $(TEST_LATENCY) \ 20 | --use_cpu_for_inference $(USE_CPU_FOR_INFERENCE) \ 21 | --verbose False 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt Cache 2 | 3 | This is the repository for the [Prompt Cache: Modular Attention Reuse For Low-Latency Inference](https://arxiv.org/abs/2311.04934) paper. This repository includes the implementation and evaluation tools to demonstrate prompt caching technique. 
4 | 5 | ## Getting Started 6 | 7 | ### Installation 8 | 9 | To begin using the Prompt Cache, you need to install the required dependencies: 10 | 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | For evaluations involving `LongBench`, additional installation of `bleurt` from the source is necessary: 16 | 17 | ```bash 18 | cd ./dependency/bleurt 19 | pip install . 20 | ``` 21 | 22 | ### Supported Architectures 23 | 24 | The Prompt Cache extends the `transformers` library and is compatible with several Large Language Model (LLM) architectures. The inference engine currently supports: 25 | 26 | - **Llama2**: For example, use `meta-llama/Llama-2-7b-chat-hf` or `codellama/CodeLlama-7b-Instruct-hf`. 27 | - **Falcon**: Example configuration is `tiiuae/falcon-7b-instruct`. 28 | - **MPT**: An example model is `mosaicml/mpt-7b-chat-8k`. 29 | 30 | Model weights for these architectures are automatically retrieved from the Hugging Face model hub. 31 | 32 | ### Demo 33 | 34 | Explore the capabilities with our demo: 35 | 36 | ```bash 37 | python demo.py 38 | ``` 39 | 40 | In the demo script `demo.py`, the `use_cache` flag can be toggled to enable or disable the Prompt Cache feature. You are encouraged to modify the script to experiment with different prompts and schemas. 41 | 42 | The demo leverages a schema from `./examples/code_generation_game.xml` with the following prompt setup: 43 | 44 | ```xml 45 | 46 | 47 | 48 | 49 | 50 | 51 | Create a main entry for the game: 52 | 53 | ``` 54 | 55 | ### Additional Schemas 56 | 57 | Find more examples of schemas in the `./examples` directory, which can serve as templates for creating your custom prompts and schemas. For instance, using the schema found at [/examples/persona_generation.xml](./benchmark/schema_persona_long.xml), you can create a prompt like the following: 58 | ```xml 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | Introduce about yourself. 81 | 82 | 83 | ``` 84 | 85 | Executing the above prompt will generate the following response, GPU: NVIDIA RTX 4090 / CPU: Intel i9-13900K 86 | 87 | - **Baseline (with KV cache)** (TTFT: GPU: 286.9ms, CPU: 17917.14 ms) 88 | 89 | ``` 90 | User: Introduce about yourself. 91 | Assistant: Hello! My name is Alex, and I am a 25-year-old woman who lives in a beautiful coastal town. I have recently completed my 92 | Ph.D. in computer science and am now working as a software engineer at a leading tech company. When I'm not working, I 93 | enjoy spending time outdoors, reading books, and trying out new recipes in the kitchen. I am also very passionate about 94 | my work and am constantly looking for opportunities to learn and grow in my 95 | ``` 96 | 97 | - **With PromptCache** (TTFT: GPU: 78.2ms, CPU: 511.31ms) 98 | 99 | ``` 100 | User: Introduce about yourself. 101 | Assistant: Hello! My name is Alex, and I'm a 28-year-old software developer who recently moved to the coastal town of Oceanview. I have a doctorate degree in computer science and specialize in artificial intelligence and machine learning. I work as a lead engineer at a tech startup that focuses on developing innovative AI solutions for various industries. 102 | In my free time, I enjoy exploring the beautiful beaches and hiking trails around Oceanview with my wife, Sarah. We met while we were both pursuing our graduate degrees at Stanford University, and we've been inseparable ever since. We have two adorable kids, Emily and Jack, who keep us busy and bring us endless joy. 
103 | As an introverted person, I often prefer spending time alone or with my close family members, but I also make an effort to connect with others through social events and community activities. I believe in being honest, kind, and respectful towards everyone, regardless of their background or beliefs. 104 | ``` 105 | 106 | ## Prompt Markup Language (PML) 107 | 108 | Schema and prompts are written in Prompt Markup Language (PML). PML is a simple XML-based language that allows users to define schemas and prompts for the Prompt Cache. 109 | 110 | ### Writing schema with PML 111 | 112 | ```xml 113 | 114 | 115 | 116 | 117 | 118 | 119 | 124 | 125 | Just some text with parameter: 126 | 127 | 128 | 129 | Nested 130 | 131 | 137 | 138 | 139 | 143 | 144 | System prompt type 1. 145 | 146 | 150 | System prompt type 2. 151 | System prompt type 3, 152 | with parameter: 153 | 154 | 155 | 156 | 157 | 158 | User 1 information 159 | User 2 information 160 | 161 | 162 | 163 | 164 | 165 | 166 | Task description 1 167 | 168 | 169 | Task description 1 170 | Task description 1 171 | 172 | 173 | 174 | 175 | 176 | ``` 177 | 178 | ### Writing Prompt with PML 179 | 180 | ```xml 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | What will be the next movement of the robot? 194 | It will move forward. 195 | What would be the speed then? 196 | 197 | ``` 198 | 199 | #### Compiling Python function into PML 200 | 201 | Compiler can be found in `promptcache/compiler.py`. You can use python decorators to compile python functions into PML. 202 | 203 | ```python 204 | @prompt 205 | def test_prompt(flag, item): 206 | "" 207 | "some system message" 208 | if flag: 209 | "only if flag is true" 210 | match item: 211 | case 1: 212 | "item is 1" 213 | case 2: 214 | "item is 2" 215 | "some user message" 216 | 217 | r = test_prompt(True, 1) 218 | 219 | print(r.get_schema()) 220 | print(r.get_prompt()) 221 | ``` 222 | 223 | 224 | 225 | ### Benchmark and Evaluation 226 | 227 | You can run accuracy benchmarks on LongBench with 228 | 229 | ```bash 230 | python eval_acc.py --help 231 | ``` 232 | 233 | To evaluate the inference time on LongBench, you can run the following command: 234 | 235 | ```bash 236 | python eval.py --help 237 | ``` 238 | 239 | ### Citation 240 | ``` 241 | @article{gim2023prompt, 242 | title={Prompt cache: Modular attention reuse for low-latency inference}, 243 | author={Gim, In and Chen, Guojun and Lee, Seung-seob and Sarda, Nikhil and Khandelwal, Anurag and Zhong, Lin}, 244 | journal={arXiv preprint arXiv:2311.04934}, 245 | year={2023} 246 | } 247 | ``` 248 | -------------------------------------------------------------------------------- /benchmark/.gitignore: -------------------------------------------------------------------------------- 1 | # Any xml file should be added with git add -f 2 | *.xml 3 | results* 4 | BLEURT-20/ 5 | -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /benchmark/benchmark_base.py: -------------------------------------------------------------------------------- 1 | # It provides APIs to load/initialize the dataset and evaluate the response from the llm 2 | 3 | # * Required API 4 | # * init() : download (one time) and load the dataset to run; do any preprocessing required for running this benchmark 5 | # * get_entry_count(): return the number 
of entries in the dataset. 6 | # * get_query(): return a list of Entry objects for the given range. 7 | import os, json 8 | import abc 9 | from typing import Tuple, List 10 | 11 | SCHEMA_FILE_DIRECTORY = "./benchmark/schema" 12 | DATASET_LIST = ["squad_v2", "multi_news", "wiki_qa", "pubmed_qa", "ms_marco", "narrativeqa", "qasper", 13 | "multifieldqa_en", "hotpotqa", "2wikimqa", "musique", "dureader", "gov_report", "qmsum", "multi_news_long", 14 | "vcsum", "trec", "triviaqa", "samsum", "lsht", "passage_count", "passage_retrieval_en", "lcc", 15 | "repobench-p"] 16 | 17 | DATASET_SUBSET = { 18 | "pubmed_qa": ["pqa_artificial", "pqa_labeled", "pqa_unlabeled"], 19 | "ms_marco": ["v1.1", "v2.1"] 20 | } 21 | 22 | class Entry: 23 | def __init__(self, schema, prompt, answer=None): 24 | """ 25 | Constructor to initialize any required variables. 26 | [schema: str] path to the schema file, usage: cache_engine.add_schema(read_file(schema, preproc)) 27 | [prompt: str] prompt text, which I should feed to the llm directly, it contains the used schema name and the question from dataset 28 | [answer: [str]] the potential answer list to the above question 29 | """ 30 | self.schema = schema 31 | self.prompt = prompt 32 | self.answer = answer 33 | 34 | def __repr__(self) -> str: 35 | return f"Entry(schema={self.schema}, prompt={self.prompt}, answer={self.answer})" 36 | 37 | 38 | class Benchmark(abc.ABC): 39 | def __init__(self, dataset_name: str): 40 | """ 41 | Constructor to initialize any required variables. 42 | """ 43 | if dataset_name not in DATASET_LIST: 44 | raise ValueError("Dataset name cannot be None, valid dataset names are: " + ", ".join(DATASET_LIST)) 45 | 46 | self.dataset_name = dataset_name 47 | self.dataset = None 48 | self.entries = [] 49 | self.schema_path = os.path.join(SCHEMA_FILE_DIRECTORY, dataset_name) 50 | if not os.path.exists(self.schema_path): 51 | os.makedirs(self.schema_path) 52 | 53 | self.load_prompt() 54 | 55 | def load_prompt(self): 56 | base_path = os.path.dirname(os.path.abspath(__file__)) 57 | dataset_prompt_path = os.path.join(base_path, '../config/dataset_prompt.json') 58 | with open(dataset_prompt_path, 'r') as f: 59 | self.dataset_prompt = json.load(f) 60 | self.dataset_prompt = self.dataset_prompt[self.dataset_name] 61 | 62 | @abc.abstractmethod 63 | def init(self, limit_entries=None): 64 | """ 65 | Download (one time) and load the dataset to run; 66 | Preprocess the dataset to be organized in the `Entry` format. 67 | """ 68 | raise NotImplementedError("This method should be overridden by subclass") 69 | 70 | def get_entry_count(self) -> int: 71 | """ 72 | Return the number of entries in the dataset. 73 | """ 74 | return len(self.entries) 75 | 76 | def get_query(self, range: Tuple[int, int]) -> List[Entry]: 77 | """ 78 | Return a list of Entry objects for the given range. 
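        Illustrative usage (assumes the SquadV2 subclass defined in benchmark/squad_v2.py):
            bench = SquadV2()
            bench.init(limit_entries=10)
            first_five = bench.get_query((0, 5))  # equivalent to bench.entries[0:5]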
79 | [range: (int, int)] the range of entries to return 80 | """ 81 | return self.entries[range[0]:range[1]] 82 | -------------------------------------------------------------------------------- /benchmark/dataset_download.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | ## I will simply list up datasets below 4 | 5 | datasets = {} 6 | 7 | 8 | def load_documentation_summary(): 9 | ## Summary dataset 10 | datasets['multi_news'] = load_dataset('multi_news') 11 | return datasets['multi_news'] 12 | # print("Multi News\n", datasets['multi_news']['train'][0]) 13 | 14 | 15 | def load_multidoc_qna(): 16 | ## Open domain question answering 17 | # = version 2.1 = 18 | # datasets['ms_marco'] = load_dataset('ms_marco', 'v2.1') 19 | # print("MS_Marco", datasets['ms_marco']['train'][0]) 20 | 21 | # = version 1.1 = 22 | datasets['ms_marco'] = load_dataset('ms_marco', 'v1.1') 23 | # print("MS_Marco", datasets['ms_marco']['train'][0]) 24 | return datasets['ms_marco'] 25 | 26 | 27 | pass 28 | -------------------------------------------------------------------------------- /benchmark/longbench.py: -------------------------------------------------------------------------------- 1 | # import re 2 | # 3 | # from .benchmark_base import Benchmark, Entry 4 | # from .utils import XMLSchemaBuilder 5 | # from datasets import load_dataset 6 | # import os 7 | # 8 | # 9 | # def escape_tags(input_str): 10 | # # pattern = r'<(?P.*?)>' 11 | # 12 | # # # The lambda function ensures only the first letter is capitalized 13 | # # def repl(match): 14 | # # return '(' + match.group("content").capitalize() + ')' 15 | # # 16 | # # return re.sub(pattern, repl, input_str) 17 | # return input_str.replace('<', '(').replace('>', ')') 18 | # 19 | # 20 | # class LongBench(Benchmark): 21 | # def __init__(self, subset_name: str): 22 | # super().__init__(subset_name) 23 | # 24 | # def init(self, limit_entries=None): 25 | # """ 26 | # Download (one time) and load the dataset to run; 27 | # Preprocess the dataset to be organized in the `Entry` format. 28 | # """ 29 | # self.dataset = load_dataset('THUDM/LongBench', self.dataset_name) 30 | # 31 | # count = 0 32 | # for split in self.dataset.values(): 33 | # for item in split: 34 | # if limit_entries is not None and count >= limit_entries: 35 | # break 36 | # schema_name = f"schema_{item['_id']}" 37 | # 38 | # fmt_schema = self.dataset_prompt["context"].format(context=escape_tags(item["context"])) 39 | # fmt_prompt = self.dataset_prompt["question"].format(input=escape_tags(item["input"])[:1000]) 40 | # fmt_question = self.dataset_prompt["answer"] 41 | # 42 | # schema = f""" 43 | # {fmt_schema} 44 | # """ 45 | # 46 | # prompt = f""" 47 | # {fmt_prompt}{fmt_question} 48 | # """ 49 | # self.entries.append(Entry(schema, prompt, item["answers"])) 50 | # 51 | # count += 1 52 | 53 | import re 54 | 55 | from .benchmark_base import Benchmark, Entry 56 | from .utils import XMLSchemaBuilder 57 | from datasets import load_dataset 58 | import os 59 | 60 | _system_description = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, \ 61 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with \ 62 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false \ 63 | or misleading information, and it caveats when it isn't entirely sure about the right answer. 
That said, the \ 64 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being \ 65 | useful." 66 | _user_description = "For the upcoming interaction, I would like you to answer some questions about the document." 67 | _assistant_description = "Sure. I have read the document. Please give me any question." 68 | 69 | 70 | def escape_tags(input_str): 71 | # pattern = r'<(?P.*?)>' 72 | 73 | # # The lambda function ensures only the first letter is capitalized 74 | # def repl(match): 75 | # return '(' + match.group("content").capitalize() + ')' 76 | # 77 | # return re.sub(pattern, repl, input_str) 78 | return input_str.replace('<', '(').replace('>', ')') 79 | 80 | 81 | class LongBench(Benchmark): 82 | def __init__(self, subset_name: str): 83 | super().__init__(subset_name) 84 | 85 | def init(self, limit_entries=None): 86 | """ 87 | Download (one time) and load the dataset to run; 88 | Preprocess the dataset to be organized in the `Entry` format. 89 | """ 90 | self.dataset = load_dataset('THUDM/LongBench', self.dataset_name) 91 | 92 | count = 0 93 | for split in self.dataset.values(): 94 | for item in split: 95 | if limit_entries is not None and count >= limit_entries: 96 | break 97 | id = item["_id"] 98 | schema_name = f"schema_{id}" 99 | 100 | fmt_schema = self.dataset_prompt["context"].format(context=escape_tags(item["context"])) 101 | 102 | schema_content = f""" 103 | 104 | 105 | 106 | {fmt_schema} 107 | 108 | 109 | """ 110 | # 111 | # 112 | # builder = XMLSchemaBuilder(schema_name) 113 | # context = item["context"] 114 | # # print(len(context)) 115 | # # title = item["title"] 116 | # question = item["input"] 117 | # answer = item["answers"] 118 | # builder.set_system_description(_system_description) 119 | # builder.set_user_description(_user_description) 120 | # builder.add_document_module("context", 121 | # self.dataset_prompt["context"].format(context=escape_tags(context))) 122 | # builder.set_assistant_description(_assistant_description) 123 | 124 | schema_file_name = f"{schema_name}.xml" 125 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f: 126 | f.write(schema_content) 127 | 128 | prompt = f""" 129 | 130 | 131 | {self.dataset_prompt["question"].format(input=escape_tags(item["input"])[:1000])}{self.dataset_prompt["answer"]} 132 | """ 133 | self.entries.append(Entry(schema_file_name, prompt, item["answers"])) 134 | 135 | count += 1 136 | 137 | 138 | if __name__ == '__main__': 139 | sq = LongBench('narrativeqa') 140 | sq.init() 141 | print(sq.get_entry_count()) 142 | print(sq.get_query((100, 101))) 143 | -------------------------------------------------------------------------------- /benchmark/metrics.py: -------------------------------------------------------------------------------- 1 | import os, requests, zipfile, io, sys 2 | from tqdm import tqdm 3 | from bleurt import score 4 | 5 | BLEURT_20_URL = "https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip" 6 | CHECKPOINT = "BLEURT-20" 7 | 8 | def download_bleurt_20(): 9 | if not os.path.exists('./BLEURT-20/'): 10 | print("Downloading BLEURT-20 checkpoint...") 11 | with requests.get(BLEURT_20_URL, stream=True) as response: 12 | total_size = int(response.headers.get('content-length', 0)) 13 | block_size = 1024 # 1 Kbyte 14 | buffer = io.BytesIO() 15 | # Initialize tqdm with the total file size, and use it as an iterable in a loop 16 | with tqdm(total=total_size, unit='B', unit_scale=True) as pbar: 17 | for data in response.iter_content(block_size): 18 | 
buffer.write(data) 19 | # Update tqdm about the downloaded data size 20 | pbar.update(len(data)) 21 | 22 | buffer.seek(0) 23 | 24 | # Unzip the file 25 | print("Unzipping BLEURT-20 checkpoint...") 26 | with zipfile.ZipFile(buffer) as zip_file: 27 | # Extract all the contents into the current working directory 28 | zip_file.extractall(path=".") 29 | 30 | class BleurtScorer: 31 | def __init__(self): 32 | download_bleurt_20() 33 | self.scorer = score.BleurtScorer(CHECKPOINT) 34 | 35 | def score(self, refs=[str], hyps=[str]): 36 | return self.scorer.score(references=refs, candidates=hyps) 37 | 38 | if __name__ == '__main__': 39 | bs = BleurtScorer() 40 | print(bs.score(["i'm leo"], ["my name is leo"])) 41 | -------------------------------------------------------------------------------- /benchmark/ms_marco_v1_1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Add the parent directory to the sys.path list 5 | document_summary_path = os.path.abspath(os.path.dirname(__file__)) 6 | # sys.path.append(os.path.abspath(os.path.join(document_summary_path, '..'))) 7 | 8 | from .benchmark_base import Benchmark, Entry 9 | from .dataset_download import load_multidoc_qna 10 | from .utils import XMLSchemaBuilder 11 | 12 | _document_schema_name = "multi_document_qna" 13 | _document_header = "Document" 14 | _document_dataset = "validation" 15 | _document_system_description = "Dialogues between a user and an AI about the document provided by the user with the aim of being helpful, aware, and accurate." 16 | _document_assistant_description = "Sure. I have read the documents separated by comma. Give me any instructions regarding query information based on the documents, and I will try to follow them." 17 | _document_user_summary = "Among the list of given documents separated by a comma, find the document that is most useful to answer the query and return its index starting at [0]. All documents are unrelated, return [-1]. Do not reply using a complete sentence, and only give the answer in the following format: [1]" # Take a deep breath and think step-by-step. 18 | 19 | MAX_DOCUMENT_LENGTH = 2560 20 | 21 | class MSMarcoV1(Benchmark): 22 | def __init__(self): 23 | super().__init__("ms_marco") 24 | self.next_query_idx = 0 25 | 26 | def init(self, limit_entries=None, verbose=False): 27 | """ 28 | Download (one time) and load the dataset to run; 29 | Do any preprocessing required for running this benchmark. 30 | """ 31 | self.dataset = load_multidoc_qna() 32 | if verbose: 33 | print("Dataset loaded. 
First entry below:") 34 | print(self.dataset[_document_dataset][1]) 35 | # Now we can generate xml files 36 | assert self.dataset is not None 37 | schema_file_name = "schema_summary_sample.xml" 38 | self._generate_xml(limit_entries) 39 | 40 | def _generate_xml(self, limit_entries): 41 | # Generate xml files 42 | # - In this version, we build the xml file per entry 43 | count = 0 44 | for document_idx in range(len(self.dataset[_document_dataset])): 45 | if limit_entries is not None and count >= limit_entries: 46 | break 47 | # Create an instance of XMLSchemaBuilder with the schema name "document_summary" 48 | query_idx = self.dataset[_document_dataset][document_idx]["query_id"] 49 | query_str = self.dataset[_document_dataset][document_idx]["query"] 50 | schema_name = f"{_document_schema_name}_{document_idx}_q{query_idx}" 51 | answer_list = self.dataset[_document_dataset][document_idx]["passages"]["is_selected"] 52 | answer_str = str(answer_list.index(True)) if True in answer_list else "-1" 53 | builder = XMLSchemaBuilder(schema_name) 54 | 55 | # Set the system description 56 | builder.set_system_description(_document_system_description) 57 | 58 | # Set the user description 59 | builder.set_user_description("") # _document_user_description 60 | # builder.add_document_module("DOC", "The given documents are the target for searching to answer the query. They can contain multiple sentences.") 61 | 62 | 63 | # Add document modules 64 | # - we need to collect passages into document str, separated by newline or \n 65 | document_str = "" 66 | 67 | for p_idx, passage in enumerate(self.dataset[_document_dataset][document_idx]["passages"]["passage_text"]): 68 | document_str += f"\"[{p_idx}] {passage},\n" 69 | document_str = document_str.replace("’", "'").replace("”", '"').replace("“", '"').replace("‘", "'").replace("…", "...").replace("–", "-") 70 | 71 | module_str = f"{_document_header}{document_idx}" 72 | builder.add_document_module(f"{module_str}", document_str) 73 | 74 | # Set assistant reply 75 | builder.set_assistant_description(_document_assistant_description) 76 | 77 | # Write the XML string to a file 78 | schema_file_name = f"{schema_name}.xml" 79 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f: 80 | f.write(builder.generate_xml()) 81 | 82 | # Prepare the entry 83 | prompt = \ 84 | f""" 85 | <{module_str}/> 86 | {_document_user_summary} Query: "{query_str}" 87 | 88 | """ 89 | self.entries.append(Entry(schema_file_name, prompt, answer_str)) 90 | 91 | count += 1 92 | 93 | def get_next_query(self): 94 | """ 95 | Return query_id (unsigned), query (string), and chosen modules (a list of string). 96 | """ 97 | assert self.dataset is not None 98 | assert self.dataset[_document_dataset] is not None 99 | assert self.next_query_idx < len(self.dataset[_document_dataset]) 100 | response = self.next_query_idx, "", [f"{_document_header}{self.next_query_idx}"] 101 | self.next_query_idx += 1 102 | return response 103 | 104 | def evaluate(self, query_id, response_from_llm): 105 | """ 106 | Take query_id and response_from_llm as parameters and return a score in the range [0,1]. 
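        Note: currently a stub -- the assertions below validate the inputs, after which
        NotImplementedError is raised. The stored ground truth for each entry is the index of the
        selected passage as a string (or "-1" when no passage is selected).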
107 | """ 108 | assert self.dataset is not None 109 | assert self.dataset[_document_dataset] is not None 110 | assert query_id < len(self.dataset[_document_dataset]) 111 | assert response_from_llm is not None 112 | assert response_from_llm != "" 113 | raise NotImplementedError( 114 | "This method should call utility function to measure how the response is closer to the expected answer.") 115 | 116 | if __name__ == '__main__': 117 | msm = MSMarcoV1() 118 | msm.init() 119 | print(msm.get_entry_count()) 120 | print(msm.get_query((0, 1))) 121 | -------------------------------------------------------------------------------- /benchmark/multi_news.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Add the parent directory to the sys.path list 5 | document_summary_path = os.path.abspath(os.path.dirname(__file__)) 6 | # sys.path.append(os.path.abspath(os.path.join(document_summary_path, '..'))) 7 | 8 | from .benchmark_base import Benchmark, Entry 9 | from .dataset_download import load_documentation_summary 10 | from .utils import XMLSchemaBuilder 11 | 12 | _document_header = "Document" 13 | _document_dataset = "validation" 14 | _document_system_description = "Dialogues between a user and an AI about the document provided by the user with the aim of being helpful, aware, and accurate." 15 | _document_assistant_description = "Sure. I have read the document. give me any instructions regarding summarization, and I will try to follow them." 16 | _document_user_summary = "Summarize the above document in around THREE sentences:" 17 | 18 | MAX_DOCUMENT_LENGTH = 2560 19 | 20 | class MultiNews(Benchmark): 21 | def __init__(self): 22 | super().__init__("multi_news") 23 | self.next_query_idx = 0 24 | 25 | def init(self, limit_entries=None, verbose=False): 26 | """ 27 | Download (one time) and load the dataset to run; 28 | Do any preprocessing required for running this benchmark. 29 | """ 30 | self.dataset = load_documentation_summary() 31 | if verbose: 32 | print("Dataset loaded. First entry below:") 33 | print(self.dataset[_document_dataset][1]) 34 | # Now we can generate xml files 35 | assert self.dataset is not None 36 | schema_file_name = "schema_summary_sample.xml" 37 | self._generate_xml(limit_entries) 38 | 39 | def _generate_xml(self, limit_entries): 40 | # Generate xml files 41 | # - In this version, we build the xml file per entry 42 | count = 0 43 | for document_idx in range(len(self.dataset[_document_dataset])): 44 | if limit_entries is not None and count >= limit_entries: 45 | break 46 | # Create an instance of XMLSchemaBuilder with the schema name "document_summary" 47 | schema_name = f"_document_schema_name_{document_idx}" 48 | builder = XMLSchemaBuilder(schema_name) 49 | 50 | # Set the system description 51 | builder.set_system_description(_document_system_description) 52 | 53 | # Set the user description 54 | builder.set_user_description("") # _document_user_description 55 | 56 | # Add document modules 57 | # builder.add_document_module("DOC", "The given documents are the target for the summarization task. 
They can contain multiple sentences.") 58 | document_str = self.dataset[_document_dataset][document_idx]["document"].replace("’", "'").replace("”", 59 | '"').replace( 60 | "“", '"').replace("‘", "'").replace("…", "...").replace("–", "-") 61 | if len(document_str) > MAX_DOCUMENT_LENGTH: 62 | document_str = document_str[:MAX_DOCUMENT_LENGTH] 63 | builder.add_document_module(f"{_document_header}{document_idx}", document_str) 64 | 65 | # Set assistant reply 66 | builder.set_assistant_description(_document_assistant_description) 67 | 68 | # Write the XML string to a file 69 | schema_file_name = f"{schema_name}.xml" 70 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f: 71 | f.write(builder.generate_xml()) 72 | 73 | # Prepare the entry 74 | prompt = f""" 75 | 76 | <{_document_header}{document_idx}/> 77 | {_document_user_summary} 78 | """ 79 | summary_str = self.dataset[_document_dataset][document_idx]["summary"].replace("’", "'").replace("”", 80 | '"').replace( 81 | "“", '"').replace("‘", "'").replace("…", "...").replace("–", "-") 82 | self.entries.append(Entry(schema_file_name, prompt, summary_str)) 83 | 84 | count += 1 85 | 86 | def get_next_query(self): 87 | """ 88 | Return query_id (unsigned), query (string), and chosen modules (a list of string). 89 | """ 90 | assert self.dataset is not None 91 | assert self.dataset[_document_dataset] is not None 92 | assert self.next_query_idx < len(self.dataset[_document_dataset]) 93 | response = self.next_query_idx, "", [f"{_document_header}{self.next_query_idx}"] 94 | self.next_query_idx += 1 95 | return response 96 | 97 | def evaluate(self, query_id, response_from_llm): 98 | """ 99 | Take query_id and response_from_llm as parameters and return a score in the range [0,1]. 100 | """ 101 | assert self.dataset is not None 102 | assert self.dataset[_document_dataset] is not None 103 | assert query_id < len(self.dataset[_document_dataset]) 104 | assert response_from_llm is not None 105 | assert response_from_llm != "" 106 | raise NotImplementedError( 107 | "This method should call utility function to measure how the response is closer to the expected answer.") 108 | 109 | if __name__ == '__main__': 110 | mn = MultiNews() 111 | mn.init() 112 | print(mn.get_entry_count()) 113 | -------------------------------------------------------------------------------- /benchmark/profile_parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | 6 | class JsonParser: 7 | def __init__(self, file_path): 8 | # add this directory to the sys.path list 9 | # print(sys.path) 10 | print(f'Parsing {file_path}...') 11 | assert os.path.isfile(file_path) 12 | self.file_path = file_path 13 | self.data = None 14 | 15 | def parse(self): 16 | with open(self.file_path) as f: 17 | self.data = json.load(f) 18 | 19 | def get_data(self): 20 | return self.data 21 | 22 | 23 | class BenchmarkSetupParser(JsonParser): 24 | dataset_size_str = 'dataset_sizes' 25 | 26 | def __init__(self, file_path): 27 | super().__init__(file_path) 28 | 29 | def parse(self): 30 | super().parse() 31 | 32 | def get_data_size(self, size_name: str): 33 | assert self.data is not None 34 | assert size_name in self.data[BenchmarkSetupParser.dataset_size_str] 35 | return int(self.data[BenchmarkSetupParser.dataset_size_str][size_name]) 36 | 37 | 38 | class BenchmarkProfileParser(JsonParser): 39 | def __init__(self, file_path): 40 | super().__init__(file_path) 41 | self.benchmark_setup_parser: BenchmarkSetupParser = None 
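        # Populated by parse(); because the attribute starts out as None, a stricter
        # annotation would be typing.Optional[BenchmarkSetupParser].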
42 | 43 | def parse(self): 44 | super().parse() 45 | # try to get the benchmark setup 46 | assert 'benchmark_dataset_name' in self.data 47 | setup_json_path = f'benchmark/{self.data["benchmark_dataset_name"]}/setup.json' 48 | self.benchmark_setup_parser = BenchmarkSetupParser(setup_json_path) 49 | self.benchmark_setup_parser.parse() 50 | print(f'Parsing benchmark setup from {setup_json_path}...') 51 | # print(f'Dataset size: {self.benchmark_setup_parser.get_data_size(self.data["dataset_size"])}') 52 | 53 | def get_benchmark_name(self): 54 | return self.data['benchmark_name'] 55 | 56 | def get_benchmark_description(self): 57 | return self.data['benchmark_description'] 58 | 59 | def get_benchmark_dataset_name(self): 60 | return self.data['benchmark_dataset_name'] 61 | 62 | def get_benchmark_dataset_comment(self): 63 | return self.data['benchmark_dataset_comment'] 64 | 65 | def get_dataset_size(self): 66 | return self.data['dataset_size'] 67 | 68 | def get_prompt_cache(self): 69 | return bool(self.data['prompt_cache']) 70 | -------------------------------------------------------------------------------- /benchmark/profiles/document_summary_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "benchmark_name": "Document summary simple", 3 | "benchmark_description": "Document summary benchmark with a few entries from multi news dataset", 4 | "benchmark_dataset_name": "document_summary", 5 | "benchmark_dataset_comment": "This is the path name of the benchmark dataset module; always under benchmark/", 6 | "dataset_size": "small", 7 | "prompt_cache": true 8 | } -------------------------------------------------------------------------------- /benchmark/results/README.md: -------------------------------------------------------------------------------- 1 | Benchmark results will be stored here. -------------------------------------------------------------------------------- /benchmark/schema/README.md: -------------------------------------------------------------------------------- 1 | Place the generated schema files here, please create a folder for your dataset, e.g. squad_v2, multi_news, wiki_qa, pubmed_qa, and ms_macro. -------------------------------------------------------------------------------- /benchmark/schema/test/empty.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | You are a helpful AI assistant. 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /benchmark/schema/test/prompt_mbti.xml: -------------------------------------------------------------------------------- 1 | 2 |

3 | -------------------------------------------------------------------------------- /benchmark/schema/test/schema_code_generation.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 13 | 14 | Just some text with parameter: 15 | 16 | 17 | 18 | Nested 19 | 20 | 26 | 27 | 28 | 32 | 33 | System prompt type 1. 34 | 35 | 39 | System prompt type 2. 40 | System prompt type 3, 41 | with parameter: 42 | 43 | 44 | 45 | 46 | 47 | User 1 information 48 | User 2 information 49 | 50 | 51 | 52 | 53 | 54 | 55 | Task description 1 56 | 57 | 58 | Task description 1 59 | Task description 1 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /benchmark/schema/test/schema_long_task_1.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 13 | 14 | Just some text with parameter: 15 | 16 | 17 | 18 | Nested 19 | 20 | 26 | 27 | 28 | 32 | 33 | System prompt type 1. 34 | 35 | 39 | System prompt type 2. 40 | System prompt type 3, 41 | with parameter: 42 | 43 | 44 | 45 | 46 | 47 | User 1 information 48 | User 2 information 49 | 50 | 51 | 52 | 53 | 54 | 55 | Task description 1 56 | 57 | 58 | Task description 1 59 | Task description 1 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /benchmark/schema/test/schema_mbti.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, 4 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with 5 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false 6 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the 7 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being 8 | useful. 9 | 10 | 11 | 12 | For the upcoming interaction, I would like you to create a hypothetical character based on a specific MBTI 13 | personality type. This personality type will include four indicators that shape the character's responses, 14 | demeanor, and overall approach to life. 15 | 16 | 17 | 18 | The Extroversion-Introversion indicator describes how a character responds to social situations and how 19 | they recharge their energy. This can be either Extroversion (E) or Introversion (I). 20 | 21 | 22 | This character has Extroversion (E) trait, and does not have Introversion (I) trait. 23 | Extroversion is characterized by a preference to focus on the world outside 24 | the self. Extraverts are energized by social gatherings, engaging in activities with others, and being 25 | expressive and assertive. They enjoy meeting new people, have a wide circle of friends, and are often 26 | perceived as popular and sociable. They are often more comfortable in groups, thrive in social 27 | situations, and prefer engaging with the external world. 28 | 29 | 30 | This character has Introversion (I) trait, and does not have Extroversion (E) trait. 31 | Introversion is characterized by a preference to focus on the internal world 32 | of thoughts and feelings. Introverts gain energy from spending time alone and prefer to interact with a 33 | small group of close friends. They are often perceived as reserved or reflective. 
They often prefer 34 | solitary activities or spending time with one or two close friends rather than in large social 35 | gatherings. They often prefer engaging with their internal world of thoughts and ideas. 36 | 37 | 38 | 39 | 40 | 41 | 42 | The Sensing-Intuition indicator describes how a character processes information and perceives the world 43 | around them. This can be either Sensing (S) or Intuition (N). 44 | 45 | 46 | This character has Sensing (S) trait, and does not have Intuition (N) trait. 47 | Sensing is characterized by a preference to focus on the present and on concrete information gained 48 | from the senses. Sensing types are often practical and realistic, and they prefer routine and order. 49 | They are often detail-oriented and observant and rely on their five senses to interpret the world. They 50 | prefer concrete, factual information rather than abstract concepts and theories. They are often 51 | pragmatic and grounded in reality. 52 | 53 | 54 | This character has Intuition (N) trait, and does not have Sensing (S) trait. 55 | Intuition is characterized by a preference to focus on the future and on possibilities. Intuitive 56 | types are often imaginative and creative, and they prefer new experiences and challenges. They are often 57 | more comfortable with theories and abstract concepts and enjoy discussing possibilities and what could 58 | be. They often rely on their intuition and are more interested in the big picture rather than the 59 | details. 60 | 61 | 62 | 63 | 64 | 65 | The Thinking-Feeling indicator describes how a character makes decisions and evaluates situations. This 66 | can be either Thinking (T) or Feeling (F). 67 | 68 | 69 | This character has Thinking (T) trait, and does not have Feeling (F) trait. 70 | Thinking is characterized by a preference to make decisions based on logic and objective analysis. 71 | Thinking types often prioritize fairness and efficiency in their decisions and may sometimes overlook 72 | the impact on people. They often approach problems and decisions logically and objectively, and they 73 | value truth and justice over harmony and cooperation. They often prefer to evaluate situations 74 | objectively and make decisions based on logic rather than emotion. 75 | 76 | 77 | This character has Feeling (F) trait, and does not have Thinking (T) trait. 78 | Feeling is characterized by a preference to make decisions based on personal values and the impact 79 | on others. Feeling types often prioritize harmony and empathy in their interactions and may sometimes 80 | overlook logical implications. They often approach problems and decisions with a people-centered 81 | perspective, and they value compassion and cooperation over efficiency and fairness. They often prefer 82 | to evaluate situations subjectively and make decisions based on personal values and beliefs. 83 | 84 | 85 | 86 | 87 | 88 | The Judging-Perceiving indicator describes how a character approaches life, either with a structured and 89 | planned approach or a flexible and adaptable one. This can be either Judging (J) or Perceiving (P). 90 | 91 | 92 | This character has Judging (J) trait, and does not have Perceiving (P) trait. 93 | Judging is characterized by a preference for structure and organization. Judging types often like 94 | to make plans and stick to them, and they prefer clarity and closure in their decisions. They are often 95 | decisive and organized, and they value predictability and stability. 
They often prefer to have a plan 96 | and follow it rather than being spontaneous and flexible. 97 | 98 | 99 | This character has Perceiving (P) trait, and does not have Judging (J) trait. 100 | Perceiving is characterized by a preference for flexibility and spontaneity. Perceiving types often 101 | like to keep their options open and adapt to new information as it comes in. They are often flexible and 102 | adaptable, and they value spontaneity and freedom. They often prefer to go with the flow and adapt to 103 | changes rather than sticking to a plan. 104 | 105 | 106 | 107 | Once you have created this character, I will ask questions about 'him/her' and you will respond based on the 108 | persona you have created. Remember to maintain the consistency of the character's traits throughout the 109 | conversation. Are you ready to create the character? 110 | 111 | 112 | 113 | Great, thank you for providing the MBTI personality type. Based on the provided personality traits, I have created 114 | a hypothetical character named Alex. Now feel free to ask any questions about him, and I will respond based on 115 | his persona. 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /benchmark/schema/test/schema_mbti_short.xml: -------------------------------------------------------------------------------- 1 | 2 | Dialogues between people and an AI with the aim of being helpful, aware, and accurate. 3 | 4 | Create a character based on MBTI. This character have: 5 | 6 | 7 | E (focuses on the outside world and is energized by social interactions.) 8 | I (prefers solitude and is energized by internal thoughts.) 9 | 10 | 11 | 12 | S (practical, detail-oriented, prefers concrete information.) 13 | N (imaginative, creative, focused on possibilities.) 14 | 15 | 16 | 17 | T (makes decisions based on logic.) 18 | F (values personal impact and harmony.) 19 | 20 | 21 | 22 | J (prefers structure and plans.) 23 | P (flexible and spontaneous.) 24 | 25 | Ready to create the character? 26 | 27 | A character named Alex is created based on MBTI. Ask anything about him. 28 | 29 | -------------------------------------------------------------------------------- /benchmark/schema/test/schema_persona.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, 5 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with 6 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false 7 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the 8 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being 9 | useful. 10 | 11 | 12 | 13 | For the upcoming interaction, I would like you to create a hypothetical character based on seven 14 | combination of traits: age, residence, education, occupation, martial-status, and personality. This will shape 15 | the character's responses, demeanor, and overall approach to life. 16 | 17 | 18 | 19 | 20 | 21 | This person is between 0 and 12 years old. This age group represents children from birth 22 | to pre-adolescence. They are typically dependent on parents or guardians and are in the early stages 23 | of 24 | physical and mental development. 
25 | 26 | 27 | This person is between 13 and 19 years old. This age group represents adolescents who are 28 | experiencing physical changes and emotional development. They are often in secondary school and 29 | beginning to gain some independence. 30 | 31 | 32 | This person is between 20 and 34 years old. This age group represents adults who are 33 | often completing their education, starting their careers, and may be living independently for the 34 | first 35 | time. They are exploring relationships and may be starting families. 36 | 37 | 38 | This person is between 35 and 54 years old. This age group represents adults who are 39 | often established in their careers and may have growing families. They may be experiencing changes 40 | in 41 | their physical health and may be taking care of aging parents. 42 | 43 | 44 | This person is 55+ years old. This age group represents older adults who may be 45 | retiring or already retired. They may have grown children and may be experiencing health challenges 46 | associated with aging. 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | This person lives in a city with a dense population, often with access to many amenities, public 56 | transportation, and cultural attractions. However, it may also come with challenges such as noise, 57 | pollution, and a higher cost of living. 58 | 59 | 60 | This person lives in the suburbs, areas that are often residential and located just outside of a 61 | city. Suburbs often offer a quieter environment, more green space, and may be more family-friendly. 62 | 63 | 64 | This person lives in the countryside, often in a less populated area with open spaces and natural 65 | surroundings. It may offer a slower pace of life but may also have fewer amenities and services 66 | available. 67 | 68 | 69 | This person lives in an area near the sea, often with access to beaches and water activities. It 70 | may offer a relaxed lifestyle and scenic views but may also come with challenges such as extreme 71 | weather 72 | conditions. 73 | 74 | 75 | This person lives in a mountainous area, often with access to outdoor activities such as hiking 76 | and skiing. It may offer a peaceful and natural environment but may also have challenges such as 77 | harsh 78 | winters and limited access to services. 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | This person had no formal education, which may limit job opportunities and earning potential. It 88 | may also affect one's ability to read and write or to access information and services. 89 | 90 | 91 | This person had completed high school, which is often the minimum requirement for many jobs. It 92 | indicates a basic level of education and the ability to read and write. 93 | 94 | 95 | This person had completed an undergraduate degree, which may open up more job opportunities and 96 | lead to higher earning potential. It indicates a higher level of education and specialized knowledge 97 | in 98 | a particular field. 99 | 100 | 101 | This person had completed a graduate degree, which may lead to specialized job opportunities and 102 | higher earning potential. It indicates an advanced level of education and expertise in a particular 103 | field. 104 | 105 | 106 | This person had completed a doctorate degree, which is often required for academic or research 107 | positions. It indicates the highest level of education and expertise in a particular field. 
108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | This person works in the medical field, which may include roles such as doctor, nurse, or medical 117 | technician. It often involves providing care for others and may require specialized training and 118 | certifications. 119 | 120 | 121 | This person works in the education field, which may include roles such as teacher, administrator, 122 | or counselor. It often involves working with children or young adults and may require specialized 123 | training and certifications. 124 | 125 | 126 | This person works in the technology field, which may include roles such as software developer, IT 127 | specialist, or network administrator. It often involves working with computers and may require 128 | specialized training and certifications. 129 | 130 | 131 | This person works in the arts and entertainment field, which may include roles such as artist, 132 | musician, or actor. It often involves creative expression and may require specialized training and 133 | talent. 134 | 135 | 136 | This person works in the finance and business field, which may include roles such as accountant, 137 | financial advisor, or business manager. It often involves managing money and may require specialized 138 | training and certifications. 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | This person has never been married, which may involve living independently or with others who 148 | are not a spouse. It may also involve focusing on personal goals and priorities. 149 | 150 | 151 | This person is currently married, which may involve sharing responsibilities and making 152 | decisions together with a spouse. It may also involve raising children together. 153 | 154 | 155 | This person had been previously married but now being single, which may involve adjusting to a 156 | new way of life and may involve co-parenting children with a former spouse. 157 | 158 | 159 | This person had lost a spouse to death, which may involve grieving and adjusting to a new way of 160 | life. It may also involve managing responsibilities alone that were once shared. 161 | 162 | 163 | This person is in a relationship but not being married, which may involve sharing some 164 | responsibilities and making decisions together with a partner. It may also involve negotiating 165 | boundaries and expectations. 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | This person has an extroverted personality, which often involves enjoying social interactions 175 | and feeling energized by being around others. Extroverts often seek out social situations and enjoy 176 | meeting new people. 177 | 178 | 179 | This person is currently married, which often involves preferring solitude or small group 180 | interactions. Introverts often feel drained by social interactions and need time alone to recharge. 181 | 182 | 183 | This person has a sensing personality, which often involves focusing on the present and relying 184 | on concrete, tangible information. Sensing types often prefer practical, realistic solutions. 185 | 186 | 187 | This person has an intuitive personality, which often involves focusing on the future and 188 | relying on abstract, theoretical information. Intuitive types often prefer creative, imaginative 189 | solutions. 190 | 191 | 192 | This person has a feeling personality, which often involves making decisions based on personal 193 | values and the impact on others. Feeling types often prioritize harmony and empathy in their 194 | interactions. 
195 | 196 | 197 | 198 | 199 | Once you have created this character, I will ask questions about 'him/her' and you will respond based on the 200 | persona you have created. Remember to maintain the consistency of the character's traits throughout the 201 | conversation. Are you ready to create the character? 202 | 203 | 204 | 205 | Great, thank you for providing the information about the character. Based on the provided traits, I have 206 | created a hypothetical character named Alex. Now feel free to ask any questions about him, and I will respond 207 | based on his persona. 208 | 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /benchmark/squad_v2.py: -------------------------------------------------------------------------------- 1 | from .benchmark_base import Benchmark, Entry 2 | from .utils import XMLSchemaBuilder 3 | from datasets import load_dataset 4 | import os 5 | 6 | _document_schema_name = "document_summary" 7 | _document_header = "Document" 8 | _document_system_description = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, \ 9 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with \ 10 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false \ 11 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the \ 12 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being \ 13 | useful." 14 | _document_user_description = "For the upcoming interaction, I would like you to answer some questions about the document." 15 | _document_assistant_description = "Sure. I have read the document. Please give me any question." 16 | 17 | 18 | class SquadV2(Benchmark): 19 | def __init__(self): 20 | super().__init__("squad_v2") 21 | 22 | def init(self, limit_entries=None): 23 | """ 24 | Download (one time) and load the dataset to run; 25 | Preprocess the dataset to be organized in the `Entry` format. 
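        For each item in the validation split this builds a per-entry schema with
        XMLSchemaBuilder (system description, user description, a single "context"
        document module, and an assistant acknowledgement), writes it to
        self.schema_path as schema_<id>.xml, skips items that have no ground-truth
        answer, and appends an Entry pairing the schema file name with a question
        prompt and the reference answers. limit_entries, when set, caps how many
        validation items are processed.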
26 | """ 27 | self.dataset = load_dataset(self.dataset_name) 28 | count = 0 29 | # for split in self.dataset.values(): 30 | # only use validation set 31 | for item in self.dataset['validation']: 32 | if limit_entries is not None and count >= limit_entries: 33 | break 34 | id = item["id"] 35 | schema_name = f"schema_{id}" 36 | builder = XMLSchemaBuilder(schema_name) 37 | context = item["context"] 38 | # title = item["title"] 39 | question = item["question"] 40 | answer = item["answers"]["text"] 41 | builder.set_system_description(_document_system_description) 42 | builder.set_user_description(_document_user_description) 43 | builder.add_document_module("context", self.dataset_prompt["context"].format(context=context)) 44 | builder.set_assistant_description(_document_assistant_description) 45 | 46 | schema_file_name = f"{schema_name}.xml" 47 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f: 48 | f.write(builder.generate_xml()) 49 | 50 | # skip entry without ground truth 51 | if len(answer) == 0: 52 | continue 53 | 54 | prompt = f""" 55 | 56 | 57 | {self.dataset_prompt["question"].format(input=question)} 58 | """ 59 | self.entries.append(Entry(schema_file_name, prompt, answer)) 60 | 61 | count += 1 62 | 63 | if __name__ == '__main__': 64 | sq = SquadV2() 65 | sq.init() 66 | print(sq.get_entry_count()) 67 | -------------------------------------------------------------------------------- /benchmark/utils.py: -------------------------------------------------------------------------------- 1 | from xml.etree.ElementTree import Element, SubElement, tostring, ElementTree 2 | from xml.dom import minidom 3 | 4 | 5 | class XMLSchemaBuilder: 6 | def __init__(self, schema_name): 7 | self.schema = Element('schema', name=schema_name) 8 | self.user = None 9 | self.user_union = None 10 | 11 | def set_system_description(self, description): 12 | system = SubElement(self.schema, 'system') 13 | system.text = description 14 | 15 | def set_user_description(self, description, user_union=False, scaffold_name="DOC"): 16 | self.user = SubElement(self.schema, 'user') 17 | self.user.text = description 18 | if user_union: 19 | self.user_union = SubElement(self.user, 'union', scaffold=scaffold_name) 20 | 21 | def add_document_module_union(self, module_name, content): 22 | assert self.user_union is not None 23 | module = SubElement(self.user_union, 'module', name=module_name) 24 | module.text = content.replace('\n', '\\n').replace("'", "\'").replace('"', '\"') 25 | 26 | def add_document_module(self, module_name, content): 27 | module = SubElement(self.user, 'module', name=module_name) 28 | module.text = content.replace('\n', '\\n').replace("'", "\'").replace('"', '\"') 29 | 30 | def set_assistant_description(self, description): 31 | assistant = SubElement(self.schema, 'assistant') 32 | assistant.text = description 33 | 34 | def generate_xml(self): 35 | rough_string = tostring(self.schema, 'utf-8') 36 | reparsed = minidom.parseString(rough_string) 37 | prettystr = reparsed.toprettyxml(indent="\t") 38 | return prettystr.replace('"', "'") 39 | -------------------------------------------------------------------------------- /benchmark_memcpy.py: -------------------------------------------------------------------------------- 1 | # On RTX 4090 2 | # Host-to-Host (CPU to CPU) Average Latency: 3.79 milliseconds 3 | # Host-to-Device (CPU to GPU) Average Latency: 5.34 milliseconds 4 | # Device-to-Device (GPU to GPU) Average Latency: 0.23 milliseconds 5 | # Device-to-Host (GPU to CPU) Average Latency: 5.88 
milliseconds 6 | 7 | 8 | import torch 9 | import time 10 | 11 | NUM_LAYERS = 30 12 | SEQ_LEN = 5000 13 | CACHE_DIM = (40, SEQ_LEN, 128) 14 | 15 | print('loaded') 16 | 17 | 18 | def create_cache(device): 19 | return [(torch.rand(CACHE_DIM, dtype=torch.float16, device=device), 20 | torch.rand(CACHE_DIM, dtype=torch.float16, device=device)) for _ in 21 | range(NUM_LAYERS)] 22 | 23 | 24 | def benchmark_transfer(src_cache, dst_cache, description): 25 | start_time = time.time() 26 | for src, dst in zip(src_cache, dst_cache): 27 | dst[0].copy_(src[0], non_blocking=True) 28 | dst[1].copy_(src[0], non_blocking=True) 29 | torch.cuda.synchronize() # Ensure CUDA operations are synchronized 30 | elapsed = (time.time() - start_time) / NUM_LAYERS 31 | print(f"{description} Average Latency: {elapsed * 1000:.2f} milliseconds") 32 | 33 | 34 | cpu_cache = create_cache('cpu') 35 | gpu_cache = create_cache('cuda') 36 | cpu_cache_clone = create_cache('cpu') 37 | gpu_cache_clone = create_cache('cuda') 38 | 39 | # Host-to-Host (CPU to CPU) Transfer 40 | benchmark_transfer(cpu_cache, cpu_cache_clone, "Host-to-Host (CPU to CPU)") 41 | 42 | # Host-to-Device (CPU to GPU) Transfer 43 | benchmark_transfer(cpu_cache, gpu_cache_clone, "Host-to-Device (CPU to GPU)") 44 | 45 | # Device-to-Device (GPU to GPU) Transfer 46 | benchmark_transfer(gpu_cache, gpu_cache_clone, "Device-to-Device (GPU to GPU)") 47 | 48 | # Device-to-Host (GPU to CPU) Transfer 49 | benchmark_transfer(gpu_cache, cpu_cache_clone, "Device-to-Host (GPU to CPU)") 50 | -------------------------------------------------------------------------------- /config/dataset_maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "vcsum": 512, 13 | "trec": 64, 14 | "triviaqa": 32, 15 | "samsum": 128, 16 | "lsht": 64, 17 | "passage_count": 32, 18 | "passage_retrieval_en": 32, 19 | "passage_retrieval_zh": 32, 20 | "lcc": 64, 21 | "repobench-p": 64, 22 | "multi_news": 512, 23 | "multi_news_long": 512, 24 | "squad_v2": 32, 25 | "wiki_qa": 128, 26 | "pubmd_qa": 128, 27 | "ms_marco": 32 28 | } -------------------------------------------------------------------------------- /config/dataset_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": { 3 | "context": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\n", 4 | "question": "Now, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation. Question: {input}\n\n", 5 | "answer": "The concise answer is:" 6 | }, 7 | "qasper": { 8 | "context": "You are given a scientific article and a question. Article: {context}", 9 | "question": "Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". 
Do not provide any explanation.\n\n Question: {input}\n\n", 10 | "answer": "The concise answer is:" 11 | }, 12 | "multifieldqa_en": { 13 | "context": "Read the following text and answer briefly.\n\n{context}", 14 | "question": "Question: {input}\n", 15 | "answer": "The concise answer is:" 16 | }, 17 | "multifieldqa_zh": { 18 | "context": "阅读以下文字并用中文简短回答:\n\n{context}", 19 | "question": "问题:{input}\n", 20 | "answer": "回答:" 21 | }, 22 | "hotpotqa": { 23 | "context": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}", 24 | "question": "Question: {input}\n", 25 | "answer": "The concise answer is:" 26 | }, 27 | "2wikimqa": { 28 | "context": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}", 29 | "question": "Question: {input}\n", 30 | "answer": "The concise answer is:" 31 | }, 32 | "musique": { 33 | "context": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}", 34 | "question": "Question: {input}\nAnswer:", 35 | "answer": "The concise answer is:" 36 | }, 37 | "dureader": { 38 | "context": "请基于给定的文章回答下述问题。\n\n文章:{context}", 39 | "question": "问题:{input}\n", 40 | "answer": "回答:" 41 | }, 42 | "gov_report": { 43 | "context": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}", 44 | "question": "", 45 | "answer": "Summary:" 46 | }, 47 | "qmsum": { 48 | "context": "You are given a meeting transcript and a query containing a question or instruction.\n\nTranscript:\n{context}", 49 | "question": "Query: {input}\n", 50 | "answer": "The concise answer is:" 51 | }, 52 | "vcsum": { 53 | "context": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}", 54 | "question": "", 55 | "answer": "会议总结:" 56 | }, 57 | "trec": { 58 | "context": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}", 59 | "question": "{input}", 60 | "answer": "Type of the question:" 61 | }, 62 | "triviaqa": { 63 | "context": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}", 64 | "question": "{input}", 65 | "answer": "The concise answer is:" 66 | }, 67 | "samsum": { 68 | "context": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}", 69 | "question": "{input}", 70 | "answer": "Summary:" 71 | }, 72 | "lsht": { 73 | "context": "请判断给定新闻的类别,下面是一些例子。\n\n{context}", 74 | "question": "{input}" 75 | }, 76 | "passage_count": { 77 | "context": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}", 78 | "question": "", 79 | "answer": "The concise answer is: " 80 | }, 81 | "passage_retrieval_en": { 82 | "context": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}", 83 | "question": "The following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. 
The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\n", 84 | "answer": "The answer is: " 85 | }, 86 | "passage_retrieval_zh": { 87 | "context": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}", 88 | "question": "下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n", 89 | "answer": "答案是:" 90 | }, 91 | "lcc": { 92 | "context": "Please complete the code given below. \n{context}", 93 | "question": "", 94 | "answer": "Next line of code:\n" 95 | }, 96 | "repobench-p": { 97 | "context": "Please complete the code given below. \n{context}", 98 | "question": "{input}", 99 | "answer": "Next line of code:\n" 100 | }, 101 | "multi_news": { 102 | "context": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\n", 103 | "question": "", 104 | "answer": "Summary:" 105 | }, 106 | "multi_news_long": { 107 | "context": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\n", 108 | "question": "", 109 | "answer": "Summary:" 110 | }, 111 | "squad_v2": { 112 | "context": "Read the following text and answer briefly.\n\n{context}", 113 | "question": "Question: {input}\n\n", 114 | "answer": "The concise answer is:" 115 | }, 116 | "wiki_qa": { 117 | "context": "Please complete the code given below. \n{context}", 118 | "question": "", 119 | "answer": "Next line of code:\n" 120 | }, 121 | "pubmd_qa": { 122 | "context": "Please complete the code given below. \n{context}", 123 | "question": "", 124 | "answer": "Next line of code:\n" 125 | }, 126 | "ms_marco": { 127 | "context": "", 128 | "question": "Among the list of given documents separated by a comma, find the document that is most useful to answer the query and return its index starting at [0]. All documents are unrelated, return [-1]. 
Do not reply using a complete sentence, and only give the answer in the following format: [1]:\n" 129 | } 130 | } -------------------------------------------------------------------------------- /config/llm_config_falcon_40b.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "falcon", 3 | "log_name": "falcon-40b", 4 | "name": "tiiuae/falcon-40b-instruct", 5 | "load_in_8bit": true, 6 | "device_map": "auto", 7 | "max_tokens": 1500, 8 | "max_ctx_length": 3000 9 | } -------------------------------------------------------------------------------- /config/llm_config_falcon_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "falcon", 3 | "log_name": "falcon-7b", 4 | "name": "tiiuae/falcon-7b-instruct", 5 | "load_in_8bit": true, 6 | "device_map": "auto", 7 | "max_tokens": 1500, 8 | "max_ctx_length": 3000 9 | } -------------------------------------------------------------------------------- /config/llm_config_llama2_13b.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "arch": "llama", 4 | "log_name": "llama2-13b", 5 | "name": "meta-llama/Llama-2-13b-chat-hf", 6 | "load_in_8bit": true, 7 | "device_map": "auto", 8 | "max_tokens": 3500, 9 | "max_ctx_length": 4096 10 | 11 | } -------------------------------------------------------------------------------- /config/llm_config_llama2_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "llama", 3 | "log_name": "llama2-7b", 4 | "name": "meta-llama/Llama-2-7b-chat-hf", 5 | "load_in_8bit": true, 6 | "device_map": "auto", 7 | "max_tokens": 3500, 8 | "max_ctx_length": 4096 9 | } -------------------------------------------------------------------------------- /config/llm_config_longchat_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "longchat", 3 | "log_name": "longchat-7b", 4 | "name": "lmsys/longchat-7b-v1.5-32k", 5 | "load_in_8bit": true, 6 | "device_map": "auto", 7 | "max_tokens": 8000, 8 | "max_ctx_length": 9186 9 | } -------------------------------------------------------------------------------- /config/llm_config_mpt_30b.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "arch": "mpt", 4 | "log_name": "mpt-30b", 5 | "name": "mosaicml/mpt-30b-chat", 6 | "load_in_8bit": true, 7 | "device_map": "auto", 8 | "max_tokens": 3500, 9 | "max_ctx_length": 4096 10 | 11 | } -------------------------------------------------------------------------------- /config/llm_config_mpt_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "mpt", 3 | "log_name": "mpt-7b", 4 | "name": "mosaicml/mpt-7b-chat-8k", 5 | "load_in_8bit": true, 6 | "device_map": "auto", 7 | "max_tokens": 3500, 8 | "max_ctx_length": 4096 9 | } -------------------------------------------------------------------------------- /config/llm_config_vicuna_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "vicuna", 3 | "log_name": "vicuna-7b", 4 | "name": "lmsys/vicuna-7b-v1.5-16k", 5 | "load_in_8bit": true, 6 | "device_map": "auto", 7 | "max_tokens": 8000, 8 | "max_ctx_length": 9000 9 | } -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import random 2 
| import re 3 | 4 | import numpy as np 5 | import torch.cuda 6 | import fire 7 | 8 | from promptcache.model import Llama2, Falcon, Mpt, CodeLlama 9 | 10 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \ 11 | GenerationEngine, GenerationParameters, llama2_template 12 | 13 | 14 | def escape_tags(input_str): 15 | pattern = r'<(?P.*?)>' 16 | 17 | def repl(match): 18 | return '(' + match.group("content").capitalize() + ')' 19 | 20 | return re.sub(pattern, repl, input_str) 21 | 22 | 23 | def main(enable_cache=True): 24 | enable_cpu_inference = False 25 | disable_prompt_cache = not enable_cache 26 | 27 | lm_for_cache = CodeLlama("codellama/CodeLlama-7b-Instruct-hf", 28 | load_in_8bit=True, 29 | device_map="auto") 30 | 31 | lm = lm_for_cache 32 | 33 | if enable_cpu_inference: 34 | lm = CodeLlama("codellama/CodeLlama-7b-Instruct-hf", 35 | load_in_8bit=False, 36 | device_map=None) 37 | 38 | # lm = Falcon("tiiuae/falcon-7b-instruct", 39 | # load_in_8bit=True if not disable_cuda else False, 40 | # device_map="auto" if not disable_cuda else None) 41 | 42 | # lm = Mpt("mosaicml/mpt-7b-chat-8k", 43 | # load_in_8bit=True if not disable_cuda else False, 44 | # device_map="auto" if not disable_cuda else None) 45 | 46 | cache_engine = CacheEngine(5000, lm_for_cache, target_device='cpu' if enable_cpu_inference else None) 47 | gen_engine = GenerationEngine(lm) 48 | 49 | preproc = [ 50 | # CompactSpaces(), 51 | lm.get_formatter() 52 | ] 53 | 54 | cache_engine.add_schema(read_file("./examples/code_generation_game.xml", preproc), max_tokens=800) 55 | 56 | parameter = GenerationParameters( 57 | temperature=1.0, 58 | repetition_penalty=1.0, 59 | top_p=0.95, 60 | top_k=-1, 61 | max_new_tokens=512, 62 | stop_token_ids=lm.stop_token_ids, 63 | stop_str=lm.stop_str 64 | ) 65 | 66 | prompt_text = f""" 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | Create a main entry for the game: 75 | 76 | 77 | """ 78 | 79 | prompt = Prompt(prompt_text, preproc) 80 | token_ids, position_ids, cache_time, cache = cache_engine.process(prompt, no_cache=disable_prompt_cache, 81 | return_full_position_ids=lm.use_full_position_ids) 82 | 83 | output_stream = gen_engine.generate(token_ids, position_ids, parameter, cache, stream_interval=2, 84 | use_full_position_ids=lm.use_full_position_ids) 85 | 86 | print(f"Assistant: ", end="", flush=True) 87 | 88 | resp = "" 89 | pre = 0 90 | for outputs in output_stream: 91 | output_text = outputs.new_text.strip().split(" ") 92 | now = len(output_text) - 1 93 | if now > pre: 94 | tt = " ".join(output_text[pre:now]) 95 | resp += tt + " " 96 | print(tt, end=" ", flush=True) 97 | pre = now 98 | tt = " ".join(output_text[pre:]) 99 | print(tt, flush=True) 100 | resp += tt 101 | 102 | print("\n") 103 | prompt_text += f"{resp}" 104 | 105 | 106 | def seed_everything(seed): 107 | torch.manual_seed(seed) 108 | torch.cuda.manual_seed(seed) 109 | np.random.seed(seed) 110 | random.seed(seed) 111 | torch.backends.cudnn.benchmark = False 112 | torch.backends.cudnn.deterministic = True 113 | torch.cuda.manual_seed_all(seed) 114 | 115 | 116 | if __name__ == "__main__": 117 | seed_everything(42) 118 | fire.Fire(main) 119 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch.cuda 5 | import fire 6 | import sys, json 7 | import os 8 | import datetime 9 | from tqdm import tqdm 10 | from benchmark.longbench import LongBench 11 | 
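# Example invocations (illustrative; the flags map onto main()'s keyword arguments through
# fire.Fire(main) at the bottom of this file):
#   python eval.py --dataset=squad_v2 --enable_cache=True
#   python eval.py --dataset=narrativeqa --enable_cache=False --split="(0, 4)" --verbose=True
#   python eval.py --dataset=gov_report --enable_cache=True --test_latency=True
# --split="(i, n)" runs the i-th of n equally sized shards of the dataset, and --test_latency
# switches from the generation benchmark to the single-forward-pass latency measurement.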
from promptcache.model import Llama2, Falcon, Mpt 12 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \ 13 | GenerationEngine, GenerationParameters 14 | 15 | from benchmark.benchmark_base import DATASET_LIST, SCHEMA_FILE_DIRECTORY 16 | from benchmark.squad_v2 import SquadV2 17 | from benchmark.multi_news import MultiNews 18 | from benchmark.ms_marco_v1_1 import MSMarcoV1 19 | 20 | BENCHMARK_PATH = "./benchmark" 21 | 22 | 23 | class Eval: 24 | def __init__(self, llm_config_path, dataset, enable_cache, use_cpu_for_inference=False): 25 | with open("./config/dataset_maxlen.json", 'r') as f: 26 | self.dataset_maxlen = json.load(f) 27 | 28 | with open(llm_config_path, 'r') as f: 29 | self.llm_config = json.load(f) 30 | self.enable_cache = enable_cache 31 | self.use_cpu_for_inference = use_cpu_for_inference 32 | 33 | self.model_name = self.llm_config["name"] 34 | if "llama" in self.model_name: 35 | self.model_name = "llama" 36 | self.lm_for_caching = Llama2(name=self.llm_config['name'], device_map="auto", load_in_8bit=True) 37 | elif "falcon" in self.model_name: 38 | self.model_name = "falcon" 39 | self.lm_for_caching = Falcon(name=self.llm_config['name'], device_map="auto", load_in_8bit=True) 40 | elif "mpt" in self.model_name: 41 | self.model_name = "mpt" 42 | self.lm_for_caching = Mpt(name=self.llm_config['name'], device_map="auto", load_in_8bit=True) 43 | else: 44 | raise ValueError("Invalid model name") 45 | 46 | if self.use_cpu_for_inference: 47 | if "llama" in self.model_name: 48 | self.lm = Llama2(name=self.llm_config['name'], device_map=None) 49 | elif "falcon" in self.model_name: 50 | self.lm = Falcon(name=self.llm_config['name'], device_map=None) 51 | elif "mpt" in self.model_name: 52 | self.lm = Mpt(name=self.llm_config['name'], device_map=None) 53 | else: 54 | self.lm = self.lm_for_caching 55 | 56 | self.cache_engine = CacheEngine(self.llm_config.get("max_ctx_length", 4096), self.lm_for_caching, 57 | target_device=self.lm.device) 58 | self.gen_engine = GenerationEngine(self.lm) 59 | self.preproc = [ 60 | # CompactSpaces(), 61 | self.lm.get_formatter() 62 | ] 63 | 64 | # self.parameter = GenerationParameters( 65 | # temperature=0.1, 66 | # repetition_penalty=1.17, 67 | # top_p=0.95, 68 | # top_k=-1, 69 | # max_new_tokens=512, 70 | # stop_token_ids=self.lm.stop_token_ids, 71 | # stop_str=self.lm.stop_str 72 | # ) 73 | 74 | self.parameter = GenerationParameters( 75 | temperature=1.0, 76 | repetition_penalty=1.0, 77 | top_p=0.95, 78 | top_k=-1, 79 | max_new_tokens=self.dataset_maxlen[dataset], 80 | stop_token_ids=self.lm.stop_token_ids, 81 | stop_str=self.lm.stop_str 82 | ) 83 | 84 | if dataset is None or dataset not in DATASET_LIST: 85 | raise ValueError("Dataset name cannot be None, valid dataset names are: " + ", ".join(DATASET_LIST)) 86 | 87 | match dataset: 88 | case "squad_v2": 89 | self.dataset = SquadV2() 90 | 91 | case "multi_news": 92 | self.dataset = MultiNews() 93 | 94 | case "narrativeqa": 95 | self.dataset = LongBench("narrativeqa") 96 | 97 | case "qasper": 98 | self.dataset = LongBench("qasper") 99 | 100 | case "multifieldqa_en": 101 | self.dataset = LongBench("multifieldqa_en") 102 | 103 | case "hotpotqa": 104 | self.dataset = LongBench("hotpotqa") 105 | 106 | case "2wikimqa": 107 | self.dataset = LongBench("2wikimqa") 108 | 109 | case "musique": 110 | self.dataset = LongBench("musique") 111 | 112 | case "dureader": 113 | self.dataset = LongBench("dureader") 114 | 115 | case "gov_report": 116 | self.dataset = LongBench("gov_report") 117 | 118 | 
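# The remaining cases map dataset names one-to-one onto LongBench loaders. Note the naming
# quirk: "multi_news" above uses the local MultiNews loader, whereas "multi_news_long" below
# wraps LongBench's "multi_news" split.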
case "qmsum": 119 | self.dataset = LongBench("qmsum") 120 | 121 | case "multi_news_long": 122 | self.dataset = LongBench("multi_news") 123 | 124 | case "vcsum": 125 | self.dataset = LongBench("vcsum") 126 | 127 | case "trec": 128 | self.dataset = LongBench("trec") 129 | 130 | case "triviaqa": 131 | self.dataset = LongBench("triviaqa") 132 | 133 | case "samsum": 134 | self.dataset = LongBench("samsum") 135 | 136 | case "lsht": 137 | self.dataset = LongBench("lsht") 138 | 139 | case "passage_count": 140 | self.dataset = LongBench("passage_count") 141 | 142 | case "passage_retrieval_en": 143 | self.dataset = LongBench("passage_retrieval_en") 144 | 145 | case "lcc": 146 | self.dataset = LongBench("lcc") 147 | 148 | case "repobench-p": 149 | self.dataset = LongBench("repobench-p") 150 | 151 | # for testing purpose, limit the entries to a small number 152 | self.dataset.init() 153 | 154 | # create result directory 155 | self.result_directory = os.path.join(BENCHMARK_PATH, "results", 156 | f"{self.model_name}-{self.dataset.dataset_name}") 157 | if not os.path.exists(self.result_directory): 158 | os.makedirs(self.result_directory) 159 | 160 | self.result_file_suffix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 161 | 162 | def store_results(self, results, split): 163 | if self.enable_cache: 164 | prefix = "with_cache" 165 | else: 166 | prefix = "no_cache" 167 | 168 | with open(os.path.join(self.result_directory, f"{prefix}_split_{split[0]}_{split[1]}_time_{self.result_file_suffix}.json"), "a") as f: 169 | json.dump(results, f) 170 | f.write("\n") 171 | 172 | @torch.inference_mode() 173 | def run_latency_eval(self): 174 | 175 | for entry in self.dataset.entries: 176 | 177 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, self.dataset.dataset_name, entry.schema) 178 | print(schema_file_path) 179 | if True: 180 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), max_tokens=3500) 181 | 182 | prompt = Prompt(entry.prompt, self.preproc) 183 | 184 | no_cache = not self.enable_cache 185 | 186 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, no_cache=no_cache, 187 | return_full_position_ids=self.lm.use_full_position_ids) 188 | 189 | if no_cache: 190 | assert cache is None 191 | 192 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long) 193 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long) 194 | # print(len(position_ids[0])) 195 | 196 | # add redundant batch dim 197 | if cache is not None: 198 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache] 199 | 200 | start = torch.cuda.Event(enable_timing=True) 201 | end = torch.cuda.Event(enable_timing=True) 202 | 203 | start.record() 204 | out = self.lm(input_ids=input_ids, 205 | position_ids=position_ids, 206 | past_key_values=cache, 207 | use_cache=True) 208 | end.record() 209 | torch.cuda.synchronize() 210 | response_time = start.elapsed_time(end) 211 | 212 | result = { 213 | "cache_time": cache_time, 214 | "response_time": response_time, 215 | } 216 | print(result) 217 | self.store_results(result) 218 | 219 | self.cache_engine.remove_all_schemas() 220 | 221 | def run(self, split, verbose=False): 222 | entry_count = self.dataset.get_entry_count() 223 | split_count = entry_count // split[1] 224 | 225 | start = split_count * split[0] 226 | end = split_count * (split[0] + 1) 227 | print(f"Running benchmark on {self.dataset.dataset_name}, start: {start}, end: {end}") 228 | 229 | for i in tqdm(range(start, end)): 230 | 
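# For each index in this shard: load the schema of every returned entry into the cache engine,
# then encode each prompt (reusing the schema's precomputed KV states unless caching is
# disabled), stream the generation, and record cache_time, response_time, the reference
# answers, and the generated text.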
entries = self.dataset.get_query((i, i + 1)) 231 | for entry in entries: 232 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, self.dataset.dataset_name, entry.schema) 233 | print(schema_file_path) 234 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), 235 | batch_size=self.llm_config.get("schema_load_batch", 1), 236 | max_tokens=self.llm_config.get("max_tokens", 3500)) 237 | 238 | for entry in entries: 239 | print(entry.prompt) 240 | prompt = Prompt(entry.prompt, self.preproc) 241 | no_cache = not self.enable_cache 242 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, no_cache=no_cache, 243 | return_full_position_ids=self.lm.use_full_position_ids) 244 | if no_cache: 245 | assert cache is None 246 | # for debugging 247 | if verbose: 248 | print("No caching; prompt:\n" + self.lm.decode(token_ids) + "\n") 249 | 250 | output_stream = self.gen_engine.generate(token_ids, position_ids, self.parameter, cache, 251 | stream_interval=2, 252 | use_full_position_ids=self.lm.use_full_position_ids) 253 | 254 | resp = "" 255 | pre = 0 256 | response_time = 0.0 257 | for outputs in output_stream: 258 | response_time = outputs.response_time 259 | output_text = outputs.new_text.strip().split(" ") 260 | now = len(output_text) - 1 261 | if now > pre: 262 | tt = " ".join(output_text[pre:now]) 263 | resp += tt + " " 264 | print(tt, end=" ", flush=True) 265 | pre = now 266 | 267 | tt = " ".join(output_text[pre:]) 268 | print(tt, flush=True) 269 | resp += tt 270 | 271 | result = { 272 | "cache_time": cache_time, 273 | "response_time": response_time, 274 | "answers": entry.answer, 275 | "response": resp 276 | } 277 | self.store_results(result, split) 278 | print("\n") 279 | 280 | self.cache_engine.remove_all_schemas() 281 | 282 | 283 | def seed_everything(seed): 284 | torch.manual_seed(seed) 285 | torch.cuda.manual_seed(seed) 286 | np.random.seed(seed) 287 | random.seed(seed) 288 | torch.backends.cudnn.benchmark = False 289 | torch.backends.cudnn.deterministic = True 290 | torch.cuda.manual_seed_all(seed) 291 | 292 | 293 | def main(llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"), 294 | dataset: str = "narrativeqa", enable_cache=False, cache_batch_size=1, split=(0, 1), 295 | test_latency=False, 296 | use_cpu_for_inference=False, 297 | verbose=False): 298 | seed_everything(42) 299 | 300 | eval = Eval(llm_config_path, dataset, enable_cache, use_cpu_for_inference) 301 | 302 | if test_latency: 303 | eval.run_latency_eval() 304 | else: 305 | eval.run(split, verbose) 306 | 307 | if __name__ == "__main__": 308 | fire.Fire(main) 309 | -------------------------------------------------------------------------------- /eval_acc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | import torch.cuda 6 | import fire 7 | import sys, json 8 | import os 9 | import datetime 10 | from tqdm import tqdm 11 | from benchmark.longbench import LongBench 12 | from promptcache.model import Llama2, Falcon, Mpt 13 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \ 14 | GenerationEngine, GenerationParameters 15 | 16 | from benchmark.benchmark_base import DATASET_LIST, SCHEMA_FILE_DIRECTORY 17 | from benchmark.squad_v2 import SquadV2 18 | from benchmark.multi_news import MultiNews 19 | from promptcache.prompt import apply_preproc 20 | from metrics import ( 21 | qa_f1_score, 22 | rouge_zh_score, 23 | qa_f1_zh_score, 24 | rouge_score, 25 | 
classification_score, 26 | retrieval_score, 27 | retrieval_zh_score, 28 | count_score, 29 | code_sim_score 30 | ) 31 | from multiprocessing import cpu_count, Process, Queue 32 | from concurrent.futures import ProcessPoolExecutor 33 | 34 | dataset2metric = { 35 | "narrativeqa": qa_f1_score, 36 | "qasper": qa_f1_score, 37 | "multifieldqa_en": qa_f1_score, 38 | "multifieldqa_zh": qa_f1_zh_score, 39 | "hotpotqa": qa_f1_score, 40 | "2wikimqa": qa_f1_score, 41 | "musique": qa_f1_score, 42 | "dureader": rouge_zh_score, 43 | "gov_report": rouge_score, 44 | "qmsum": rouge_score, 45 | "multi_news": rouge_score, 46 | "vcsum": rouge_zh_score, 47 | "trec": classification_score, 48 | "triviaqa": qa_f1_score, 49 | "samsum": rouge_score, 50 | "lsht": classification_score, 51 | "passage_retrieval_en": retrieval_score, 52 | "passage_count": count_score, 53 | "passage_retrieval_zh": retrieval_zh_score, 54 | "lcc": code_sim_score, 55 | "repobench-p": code_sim_score, 56 | } 57 | 58 | BENCHMARK_PATH = "./benchmark" 59 | 60 | 61 | class Eval: 62 | def __init__(self, gpu_id, llm_config_path, dataset_list, enable_cache): 63 | with open("./config/dataset_maxlen.json", 'r') as f: 64 | self.dataset_maxlen = json.load(f) 65 | 66 | with open(llm_config_path, 'r') as f: 67 | self.llm_config = json.load(f) 68 | 69 | self.enable_cache = enable_cache 70 | 71 | self.model_name = self.llm_config["name"] 72 | self.model_arch = self.llm_config["arch"] 73 | self.model_log_name = self.llm_config["log_name"] 74 | self.max_ctx_length = self.llm_config.get("max_ctx_length", 4096) 75 | self.max_tokens = self.llm_config.get("max_tokens", 3500) 76 | self.dataset_list = dataset_list 77 | 78 | if self.model_arch == "llama": 79 | self.lm = Llama2(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True) 80 | elif self.model_arch == "falcon": 81 | self.lm = Falcon(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True) 82 | elif self.model_arch == "mpt": 83 | self.lm = Mpt(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True) 84 | else: 85 | raise ValueError("Invalid model name") 86 | 87 | self.cache_engine = CacheEngine(self.max_ctx_length, self.lm, 88 | target_device=self.lm.device) 89 | self.gen_engine = GenerationEngine(self.lm) 90 | self.preproc = [ 91 | # CompactSpaces(), 92 | self.lm.get_formatter() 93 | ] 94 | 95 | # create result directory 96 | self.result_directory = os.path.join(BENCHMARK_PATH, "results_acc") 97 | if not os.path.exists(self.result_directory): 98 | os.makedirs(self.result_directory) 99 | 100 | def run(self): 101 | 102 | for dataset_name in self.dataset_list: 103 | 104 | dataset = self.dataset_list[dataset_name] 105 | dataset.init(limit_entries=3) 106 | 107 | results = [] 108 | 109 | for entry in tqdm(dataset.entries): 110 | # print(entry.prompt) 111 | 112 | schema = apply_preproc(entry.schema, self.preproc) 113 | prompt = Prompt(entry.prompt, self.preproc) 114 | 115 | self.cache_engine.add_schema(schema, max_tokens=self.max_tokens) 116 | 117 | no_cache = not self.enable_cache 118 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, 119 | no_cache=no_cache, 120 | return_full_position_ids=self.lm.use_full_position_ids) 121 | if no_cache: 122 | assert cache is None 123 | 124 | parameter = GenerationParameters( 125 | temperature=0.0, 126 | repetition_penalty=1.0, 127 | top_p=0.0, 128 | top_k=-1, 129 | max_new_tokens=self.dataset_maxlen[dataset_name], 130 | stop_token_ids=self.lm.stop_token_ids, 131 | stop_str=self.lm.stop_str 132 | ) 133 | 134 | output_stream 
= self.gen_engine.generate(token_ids, position_ids, parameter, cache, 135 | stream_interval=2, 136 | use_full_position_ids=self.lm.use_full_position_ids) 137 | 138 | resp = "" 139 | pre = 0 140 | response_time = 0.0 141 | for outputs in output_stream: 142 | response_time = outputs.response_time 143 | output_text = outputs.new_text.strip().split(" ") 144 | now = len(output_text) - 1 145 | if now > pre: 146 | tt = " ".join(output_text[pre:now]) 147 | resp += tt + " " 148 | # print(tt, end=" ", flush=True) 149 | pre = now 150 | 151 | tt = " ".join(output_text[pre:]) 152 | # print(tt, flush=True) 153 | resp += tt 154 | # print("\n") 155 | 156 | result = { 157 | "cache_time": cache_time, 158 | "response_time": response_time, 159 | "answers": entry.answer, 160 | "response": resp 161 | } 162 | print(result) 163 | 164 | results.append(result) 165 | self.cache_engine.remove_all_schemas() 166 | 167 | total_score = 0 168 | metric_fn = dataset2metric[dataset_name] 169 | for result in results: 170 | response = result["response"] 171 | answers = result["answers"] 172 | 173 | score = 0. 174 | for answer in answers: 175 | score = max(score, metric_fn(response, answer)) 176 | 177 | total_score += score 178 | 179 | total_score = total_score / len(results) * 100 180 | print(f"Total score: {total_score:.2f}") 181 | 182 | if self.enable_cache: 183 | prefix = "cache_enabled" 184 | else: 185 | prefix = "cache_disabled" 186 | filename = f"{self.model_log_name}-{dataset_name}-{prefix}.json" 187 | 188 | with open(os.path.join(self.result_directory, filename), "w") as f: 189 | json.dump({ 190 | "model_name": self.model_name, 191 | "model_arch": self.model_arch, 192 | "dataset_name": dataset_name, 193 | "enable_cache": self.enable_cache, 194 | "total_score": total_score, 195 | "results": results 196 | }, f) 197 | 198 | 199 | def run_eval(gpu_id, llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"), 200 | dataset: str = "narrativeqa", 201 | enable_cache=True, ): 202 | seed_everything(42) 203 | 204 | eval = Eval(gpu_id, llm_config_path, dataset, enable_cache) 205 | eval.run() 206 | 207 | 208 | def seed_everything(seed): 209 | torch.manual_seed(seed) 210 | torch.cuda.manual_seed(seed) 211 | np.random.seed(seed) 212 | random.seed(seed) 213 | torch.backends.cudnn.benchmark = False 214 | torch.backends.cudnn.deterministic = True 215 | torch.cuda.manual_seed_all(seed) 216 | 217 | 218 | def main(num_gpus=1, llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"), 219 | enable_cache=True, 220 | ): 221 | dataset_list = { 222 | "narrativeqa": LongBench("narrativeqa"), 223 | "qasper": LongBench("qasper"), 224 | "multifieldqa_en": LongBench("multifieldqa_en"), 225 | "hotpotqa": LongBench("hotpotqa"), 226 | "2wikimqa": LongBench("2wikimqa"), 227 | "musique": LongBench("musique"), 228 | "gov_report": LongBench("gov_report"), 229 | "qmsum": LongBench("qmsum"), 230 | "multi_news": LongBench("multi_news"), 231 | "trec": LongBench("trec"), 232 | "triviaqa": LongBench("triviaqa"), 233 | "samsum": LongBench("samsum"), 234 | "passage_count": LongBench("passage_count"), 235 | "passage_retrieval_en": LongBench("passage_retrieval_en"), 236 | "lcc": LongBench("lcc"), 237 | "repobench-p": LongBench("repobench-p") 238 | } 239 | 240 | dpg = int(math.ceil(len(dataset_list) / num_gpus)) 241 | 242 | jobs_list = [] 243 | nn = list(dataset_list.keys()) 244 | for i in range(num_gpus): 245 | dataset_names = nn[i * dpg:(i + 1) * dpg] 246 | jobs = {} 247 | for dn in dataset_names: 248 | jobs[dn] = 
dataset_list[dn] 249 | jobs_list.append(jobs) 250 | 251 | processes = [ 252 | Process(target=run_eval, args=(i, llm_config_path, jobs_list[i], enable_cache)) 253 | for i in range(num_gpus) 254 | ] 255 | 256 | for p in processes: 257 | p.start() 258 | 259 | seed_everything(42) 260 | 261 | for p in processes: 262 | p.join() 263 | 264 | 265 | if __name__ == "__main__": 266 | fire.Fire(main) 267 | -------------------------------------------------------------------------------- /eval_acc.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | #SBATCH --mem=128g 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --cpus-per-task=32 # <- match to OMP_NUM_THREADS 6 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8 7 | #SBATCH --account=bccn-delta-gpu 8 | #SBATCH --job-name=eval_acc 9 | #SBATCH --time=10:00:00 # hh:mm:ss for the job 10 | ### GPU options ### 11 | #SBATCH --gpus-per-node=4 12 | ##SBATCH --gpu-bind=none # <- or closest 13 | #SBATCH --mail-user=in.gim@yale.edu.edu 14 | #SBATCH --mail-type="BEGIN,END" 15 | 16 | module reset # drop modules and explicitly load the ones needed 17 | # (good job metadata and reproducibility) 18 | # $WORK and $SCRATCH are now set 19 | module load python # ... or any appropriate modules 20 | module list # job documentation and metadata 21 | echo "job is starting on `hostname`" 22 | srun python3 eval_acc.py --num_gpus=4 23 | 24 | -------------------------------------------------------------------------------- /eval_sys.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import torch.cuda 4 | import fire 5 | import sys, json 6 | import os 7 | import datetime 8 | from tqdm import tqdm 9 | from benchmark.longbench import LongBench 10 | from promptcache.model import Llama2, Falcon, Mpt 11 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \ 12 | GenerationEngine, GenerationParameters 13 | 14 | from benchmark.benchmark_base import SCHEMA_FILE_DIRECTORY 15 | 16 | BENCHMARK_PATH = "./benchmark" 17 | from torch.profiler import profile, record_function, ProfilerActivity 18 | 19 | 20 | class Eval: 21 | def __init__(self, memo, llm_config_path, use_cpu_for_inference=False): 22 | with open("./config/dataset_maxlen.json", 'r') as f: 23 | self.dataset_maxlen = json.load(f) 24 | 25 | with open(llm_config_path, 'r') as f: 26 | self.llm_config = json.load(f) 27 | self.memo = memo 28 | self.use_cpu_for_inference = use_cpu_for_inference 29 | self.repeat_times = 2 if use_cpu_for_inference else 3 30 | self.model_name = self.llm_config["name"] 31 | self.model_arch = self.llm_config["arch"] 32 | self.model_log_name = self.llm_config["log_name"] 33 | self.max_ctx_length = self.llm_config.get("max_ctx_length", 4096) 34 | self.max_tokens = self.llm_config.get("max_tokens", 3500) 35 | 36 | if self.model_arch == "llama": 37 | self.lm_for_caching = Llama2(name=self.model_name, device_map={"": 0}, load_in_8bit=True) 38 | elif self.model_arch == "falcon": 39 | self.lm_for_caching = Falcon(name=self.model_name, device_map={"": 0}, load_in_8bit=True) 40 | elif self.model_arch == "mpt": 41 | self.lm_for_caching = Mpt(name=self.model_name, device_map={"": 0}, load_in_8bit=True) 42 | else: 43 | raise ValueError("Invalid model name") 44 | 45 | if self.use_cpu_for_inference: 46 | if self.model_arch == "llama": 47 | self.lm = Llama2(name=self.model_name, device_map=None) 48 | elif self.model_arch == "falcon": 49 | self.lm 
= Falcon(name=self.model_name, device_map=None) 50 | elif self.model_arch == "mpt": 51 | self.lm = Mpt(name=self.model_name, device_map=None) 52 | else: 53 | self.lm = self.lm_for_caching 54 | 55 | self.cache_engine = CacheEngine(self.max_ctx_length, self.lm, 56 | target_device=self.lm.device) 57 | self.gen_engine = GenerationEngine(self.lm) 58 | self.preproc = [ 59 | # CompactSpaces(), 60 | self.lm.get_formatter() 61 | ] 62 | 63 | self.dataset_list = { 64 | "narrativeqa": LongBench("narrativeqa"), 65 | "qasper": LongBench("qasper"), 66 | "multifieldqa_en": LongBench("multifieldqa_en"), 67 | "hotpotqa": LongBench("hotpotqa"), 68 | "2wikimqa": LongBench("2wikimqa"), 69 | "musique": LongBench("musique"), 70 | "gov_report": LongBench("gov_report"), 71 | "qmsum": LongBench("qmsum"), 72 | "multi_news": LongBench("multi_news"), 73 | "triviaqa": LongBench("triviaqa"), 74 | "samsum": LongBench("samsum"), 75 | "passage_count": LongBench("passage_count"), 76 | "passage_retrieval_en": LongBench("passage_retrieval_en"), 77 | } 78 | 79 | # @torch.inference_mode() 80 | # def profile_cpu_inference(self): 81 | # 82 | # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: 83 | # with record_function("model_inference"): 84 | # model(inputs) 85 | 86 | # recomputation overhead vs mem trasnfer overhead 87 | @torch.inference_mode() 88 | def run_critical_point(self): 89 | 90 | def create_cache(seq_len): 91 | 92 | # # llama 2 13B 93 | num_layers = 40 94 | num_heads = 40 95 | head_dim = 128 96 | 97 | # # llama 2 7B 98 | # num_layers = 32 99 | # num_heads = 32 100 | # head_dim = 128 101 | 102 | return [(torch.rand((num_heads, seq_len, head_dim), dtype=torch.float16, device='cpu'), 103 | torch.rand((num_heads, seq_len, head_dim), dtype=torch.float16, device='cpu')) for _ in 104 | range(num_layers)] 105 | 106 | test_seq_len = [ 107 | 1, 108 | 2, 109 | 4, 110 | 8, 111 | 16, 112 | 32, 113 | 64, 114 | 128, 115 | 256, 116 | 512, 117 | 512 + 128 * 1, 118 | 512 + 128 * 2, 119 | 512 + 128 * 3, 120 | 1024, 121 | 1024 + 256 * 1, 122 | 1024 + 256 * 2, 123 | 1024 + 256 * 3, 124 | 2048, 125 | 2028 + 512 * 1, 126 | 2028 + 512 * 2, 127 | 2028 + 512 * 3, 128 | 4096, 129 | # 4096 + 1024 * 1, 130 | # 4096 + 1024 * 2, 131 | 132 | ] 133 | 134 | results = [] 135 | 136 | for seq_len in tqdm(test_seq_len): 137 | for _ in range(self.repeat_times): 138 | ## 1. compute gpu upload time 139 | kv_cache = create_cache(seq_len) 140 | 141 | torch.cuda.synchronize() 142 | start = torch.cuda.Event(enable_timing=True) 143 | end = torch.cuda.Event(enable_timing=True) 144 | 145 | start.record() 146 | # upload everything to GPU 147 | kv_cache_gpu = [ 148 | (k[0].to('cuda', non_blocking=True, copy=True), k[1].to('cuda', non_blocking=True, copy=True)) 149 | for k in kv_cache] 150 | 151 | end.record() 152 | torch.cuda.synchronize() 153 | gpu_upload_time = start.elapsed_time(end) 154 | 155 | del kv_cache_gpu, kv_cache 156 | gc.collect() 157 | torch.cuda.empty_cache() 158 | 159 | results.append({ 160 | "seq_len": seq_len, 161 | "time": gpu_upload_time, 162 | }) 163 | 164 | result_path = os.path.join(BENCHMARK_PATH, "results_latency") 165 | 166 | with open(os.path.join(result_path, f"{self.memo}-{self.model_log_name}-critical_point-upload.json"), 167 | "w") as f: 168 | json.dump( 169 | { 170 | 'model_name': self.model_name, 171 | 'results': results 172 | }, f) 173 | 174 | results = [] 175 | ## 2. 
compute recomputation time 176 | for seq_len in tqdm(test_seq_len): 177 | 178 | for _ in range(self.repeat_times): 179 | token_ids = [100] * seq_len 180 | position_ids = list(range(seq_len)) 181 | 182 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long) 183 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long) 184 | 185 | start = torch.cuda.Event(enable_timing=True) 186 | end = torch.cuda.Event(enable_timing=True) 187 | 188 | start.record() 189 | out = self.lm(input_ids=input_ids, 190 | position_ids=position_ids, 191 | past_key_values=None, 192 | use_cache=False) 193 | 194 | end.record() 195 | torch.cuda.synchronize() 196 | recomputation_time = start.elapsed_time(end) 197 | 198 | del out 199 | gc.collect() 200 | torch.cuda.empty_cache() 201 | 202 | results.append({ 203 | "seq_len": seq_len, 204 | "time": recomputation_time 205 | }) 206 | 207 | result_path = os.path.join(BENCHMARK_PATH, "results_latency") 208 | 209 | with open(os.path.join(result_path, f"{self.memo}-{self.model_log_name}-critical_point-recomputation.json"), 210 | "w") as f: 211 | json.dump( 212 | { 213 | 'model_name': self.model_log_name, 214 | 'results': results 215 | }, f) 216 | 217 | @torch.inference_mode() 218 | def run_critical_point22(self): 219 | 220 | 221 | test_seq_len = [ 222 | 1, 223 | 2, 224 | 4, 225 | 8, 226 | 16, 227 | 32, 228 | 64, 229 | 128, 230 | 256, 231 | 512, 232 | 512 + 128 * 1, 233 | 512 + 128 * 2, 234 | 512 + 128 * 3, 235 | 1024, 236 | 1024 + 256 * 1, 237 | 1024 + 256 * 2, 238 | 1024 + 256 * 3, 239 | 2048, 240 | 2028 + 512 * 1, 241 | 2028 + 512 * 2, 242 | 2028 + 512 * 3, 243 | #4096, 244 | # 4096 + 1024 * 1, 245 | # 4096 + 1024 * 2, 246 | 247 | ] 248 | 249 | results = [] 250 | 251 | for seq_len in tqdm(test_seq_len): 252 | for _ in range(self.repeat_times): 253 | ## 1. 
compute gpu upload time 254 | input_ids = torch.tensor([[100]], device=self.lm.device, dtype=torch.long) 255 | #position_ids = torch.tensor([[100]], device=self.lm.device, dtype=torch.long) 256 | 257 | device_cache = [ 258 | (torch.empty(1, 32, seq_len, 128, device=self.lm.device, dtype=torch.half), # key 259 | torch.empty(1, 32, seq_len, 128, device=self.lm.device, dtype=torch.half)) for _ in 260 | range(32)] 261 | 262 | torch.cuda.synchronize() 263 | start = torch.cuda.Event(enable_timing=True) 264 | end = torch.cuda.Event(enable_timing=True) 265 | 266 | start.record() 267 | # upload everything to GPU 268 | out = self.lm(input_ids=input_ids, 269 | #position_ids=position_ids, 270 | past_key_values=device_cache, 271 | use_cache=True) 272 | 273 | end.record() 274 | torch.cuda.synchronize() 275 | gpu_upload_time = start.elapsed_time(end) 276 | 277 | del device_cache 278 | gc.collect() 279 | torch.cuda.empty_cache() 280 | 281 | results.append({ 282 | "seq_len": seq_len, 283 | "time": gpu_upload_time, 284 | }) 285 | 286 | result_path = os.path.join(BENCHMARK_PATH, "aaa") 287 | 288 | with open(os.path.join(result_path, f"{self.memo}-{self.model_log_name}-critical_point-upload.json"), 289 | "w") as f: 290 | json.dump( 291 | { 292 | 'model_name': self.model_name, 293 | 'results': results 294 | }, f) 295 | 296 | results = [] 297 | 298 | 299 | @torch.inference_mode() 300 | def run_latency_eval(self, do_cache): 301 | 302 | for dataset_name in self.dataset_list: 303 | 304 | dataset = self.dataset_list[dataset_name] 305 | dataset.init(limit_entries=5) 306 | 307 | # create result directory 308 | device_used = "cpu" if self.use_cpu_for_inference else "gpu" 309 | cache_used = "cache" if do_cache else "no_cache" 310 | result_path = os.path.join(BENCHMARK_PATH, "results_latency") 311 | no_cache = not do_cache 312 | 313 | if not os.path.exists(result_path): 314 | os.makedirs(result_path) 315 | 316 | results = [] 317 | 318 | for entry in tqdm(dataset.entries[:5]): 319 | for _ in range(self.repeat_times): 320 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, dataset_name, entry.schema) 321 | 322 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), no_cache=no_cache, 323 | max_tokens=3500) 324 | 325 | prompt = Prompt(entry.prompt, self.preproc) 326 | 327 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, no_cache=no_cache, 328 | return_full_position_ids=self.lm.use_full_position_ids) 329 | 330 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long) 331 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long) 332 | # print(len(position_ids[0])) 333 | 334 | # add redundant batch dim 335 | if cache is not None: 336 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache] 337 | 338 | start = torch.cuda.Event(enable_timing=True) 339 | end = torch.cuda.Event(enable_timing=True) 340 | 341 | start.record() 342 | 343 | # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: 344 | # with record_function("model_inference"): 345 | out = self.lm(input_ids=input_ids, 346 | position_ids=position_ids, 347 | past_key_values=cache, 348 | use_cache=True) 349 | end.record() 350 | torch.cuda.synchronize() 351 | response_time = start.elapsed_time(end) 352 | 353 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) 354 | 355 | result = { 356 | "entry_schema": entry.schema, 357 | "cache_time": cache_time, 358 | "response_time": response_time, 359 | } 360 | # 
print(result) 361 | results.append(result) 362 | 363 | self.cache_engine.remove_all_schemas() 364 | 365 | with open( 366 | os.path.join(result_path, 367 | f"{self.memo}-{self.model_log_name}-{device_used}-{cache_used}-{dataset_name}.json"), 368 | "w") as f: 369 | json.dump( 370 | { 371 | 'model_name': self.model_log_name, 372 | 'device_used': device_used, 373 | 'cache_used': cache_used, 374 | 'dataset_name': dataset_name, 375 | 376 | 'results': results 377 | }, f) 378 | f.write("\n") 379 | 380 | @torch.inference_mode() 381 | def run_profile(self, do_cache): 382 | device_used = "cpu" if self.use_cpu_for_inference else "gpu" 383 | cache_used = "cache" if do_cache else "no_cache" 384 | 385 | for dataset_name in self.dataset_list: 386 | 387 | dataset = self.dataset_list[dataset_name] 388 | dataset.init(limit_entries=5) 389 | 390 | no_cache = not do_cache 391 | 392 | for entry in tqdm(dataset.entries[:5]): 393 | for _ in range(self.repeat_times): 394 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, dataset_name, entry.schema) 395 | 396 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), no_cache=no_cache, 397 | max_tokens=2500) 398 | 399 | prompt = Prompt(entry.prompt, self.preproc) 400 | 401 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, 402 | no_cache=no_cache, 403 | return_full_position_ids=self.lm.use_full_position_ids) 404 | 405 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long) 406 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long) 407 | # print(len(position_ids[0])) 408 | 409 | # add redundant batch dim 410 | if cache is not None: 411 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache] 412 | 413 | with profile(activities=[ProfilerActivity.CUDA], with_stack=True, 414 | experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof: 415 | with record_function("model_inference"): 416 | out = self.lm(input_ids=input_ids, 417 | position_ids=position_ids, 418 | past_key_values=cache, 419 | use_cache=True) 420 | 421 | prof.export_stacks(f"./profile/{device_used}_{cache_used}_self_cuda_time_total.txt", 422 | "self_cuda_time_total") 423 | self.cache_engine.remove_all_schemas() 424 | 425 | return 426 | 427 | 428 | def main(memo: str = "13900k-cpu", llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"), 429 | use_cpu_for_inference=True): 430 | eval = Eval(memo, llm_config_path, use_cpu_for_inference) 431 | 432 | # eval.run_latency_eval(False) 433 | # eval.run_latency_eval(True) 434 | #eval.run_profile(True) 435 | #eval.run_profile(False) 436 | 437 | eval.run_critical_point22() 438 | 439 | 440 | if __name__ == "__main__": 441 | fire.Fire(main) 442 | -------------------------------------------------------------------------------- /eval_sys_a100-7b-cpu.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | #SBATCH --exclusive 3 | #SBATCH --mem=0 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=64 6 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8 7 | #SBATCH --account=bccn-delta-gpu 8 | #SBATCH --job-name=a100_7b_cpu 9 | #SBATCH --time=47:00:00 # hh:mm:ss for the job 10 | #SBATCH --constraint="scratch" 11 | 12 | ### GPU options ### 13 | #SBATCH --gpus-per-node=1 14 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology 15 | #SBATCH --mail-user=in.gim@yale.edu 16 | #SBATCH --mail-type="BEGIN,END" 17 | 18 
| module reset # drop modules and explicitly load the ones needed 19 | # (good job metadata and reproducibility) 20 | # $WORK and $SCRATCH are now set 21 | module load python # ... or any appropriate modules 22 | module list # job documentation and metadata 23 | echo "job is starting on `hostname`" 24 | 25 | srun python3 eval_sys.py \ 26 | --memo=a100 \ 27 | --llm_config_path=./config/llm_config_llama2_7b.json \ 28 | --use_cpu_for_inference=True 29 | -------------------------------------------------------------------------------- /eval_sys_a100-7b-gpu.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | #SBATCH --exclusive 3 | #SBATCH --mem=0 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=16 7 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8 8 | #SBATCH --account=bccn-delta-gpu 9 | #SBATCH --job-name=a100_gpu 10 | #SBATCH --time=47:00:00 # hh:mm:ss for the job 11 | #SBATCH --constraint="scratch" 12 | 13 | ### GPU options ### 14 | #SBATCH --gpus-per-node=1 15 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology 16 | #SBATCH --mail-user=in.gim@yale.edu 17 | #SBATCH --mail-type="BEGIN,END" 18 | 19 | module reset # drop modules and explicitly load the ones needed 20 | # (good job metadata and reproducibility) 21 | # $WORK and $SCRATCH are now set 22 | module load python # ... or any appropriate modules 23 | module list # job documentation and metadata 24 | echo "job is starting on `hostname`" 25 | srun python3 eval_sys.py \ 26 | --memo=a100 \ 27 | --llm_config_path=./config/llm_config_llama2_7b.json \ 28 | --use_cpu_for_inference=False 29 | -------------------------------------------------------------------------------- /eval_sys_a40-7b-cpu.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | #SBATCH --exclusive 3 | #SBATCH --mem=0 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=64 6 | #SBATCH --partition=gpuA40x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8 7 | #SBATCH --account=bccn-delta-gpu 8 | #SBATCH --job-name=a40_7b_cpu 9 | #SBATCH --time=47:00:00 # hh:mm:ss for the job 10 | #SBATCH --constraint="scratch" 11 | 12 | ### GPU options ### 13 | #SBATCH --gpus-per-node=1 14 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology 15 | #SBATCH --mail-user=in.gim@yale.edu 16 | #SBATCH --mail-type="BEGIN,END" 17 | 18 | module reset # drop modules and explicitly load the ones needed 19 | # (good job metadata and reproducibility) 20 | # $WORK and $SCRATCH are now set 21 | module load python # ... 
or any appropriate modules 22 | module list # job documentation and metadata 23 | echo "job is starting on `hostname`" 24 | 25 | srun python3 eval_sys.py \ 26 | --memo=a40 \ 27 | --llm_config_path=./config/llm_config_llama2_7b.json \ 28 | --use_cpu_for_inference=True 29 | -------------------------------------------------------------------------------- /eval_sys_a40-7b-gpu.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | #SBATCH --exclusive 3 | #SBATCH --mem=0 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=16 7 | #SBATCH --partition=gpuA40x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8 8 | #SBATCH --account=bccn-delta-gpu 9 | #SBATCH --job-name=a40_7b_gpu 10 | #SBATCH --time=47:00:00 # hh:mm:ss for the job 11 | #SBATCH --constraint="scratch" 12 | 13 | ### GPU options ### 14 | #SBATCH --gpus-per-node=1 15 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology 16 | #SBATCH --mail-user=in.gim@yale.edu 17 | #SBATCH --mail-type="BEGIN,END" 18 | 19 | module reset # drop modules and explicitly load the ones needed 20 | # (good job metadata and reproducibility) 21 | # $WORK and $SCRATCH are now set 22 | module load python # ... or any appropriate modules 23 | module list # job documentation and metadata 24 | echo "job is starting on `hostname`" 25 | srun python3 eval_sys.py \ 26 | --memo=a40 \ 27 | --llm_config_path=./config/llm_config_llama2_7b.json \ 28 | --use_cpu_for_inference=False 29 | -------------------------------------------------------------------------------- /examples/code_generation_bookstore.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | You are a sophisticated language model assistant that can read and understand multiple source files 5 | simultaneously. Your current task is to examine the provided source files. These files contain Python 6 | classes and methods. 7 | 8 | Using the knowledge extracted from the source files, you are expected to generate code following the 9 | instructions that will be given. 10 | 11 | 12 | 13 | Please read the given source files, understand their structure and relationships. I'll provide you with my 14 | instruction. 
15 | 16 | 17 | class Book: 18 | def __init__(self, title, author, price, isbn): 19 | self.title = title 20 | self.author = author 21 | self.price = price 22 | self.isbn = isbn 23 | 24 | def get_details(self): 25 | return { 26 | "title": self.title, 27 | "author": self.author, 28 | "price": self.price, 29 | "isbn": self.isbn 30 | } 31 | 32 | def set_details(self, title, author, price, isbn): 33 | self.title = title 34 | self.author = author 35 | self.price = price 36 | self.isbn = isbn 37 | 38 | 39 | 40 | class User: 41 | def __init__(self, username, password): 42 | self.username = username 43 | self.password = password 44 | self.purchased_books = [] 45 | 46 | def register(self, username, password): 47 | self.username = username 48 | self.password = password 49 | 50 | def login(self, username, password): 51 | return self.username == username and self.password == password 52 | 53 | def purchase(self, book): 54 | self.purchased_books.append(book) 55 | 56 | def view_purchased_books(self): 57 | return self.purchased_books 58 | 59 | 60 | 61 | 62 | class Inventory: 63 | def __init__(self): 64 | self.books = [] 65 | 66 | def add_book(self, book): 67 | self.books.append(book) 68 | 69 | def remove_book(self, book): 70 | self.books.remove(book) 71 | 72 | def search_by_title(self, title): 73 | return [book for book in self.books if book.title == title] 74 | 75 | def search_by_author(self, author): 76 | return [book for book in self.books if book.author == author] 77 | 78 | 79 | 80 | from Book import Book 81 | from User import User 82 | from Inventory import Inventory 83 | 84 | class Store: 85 | def __init__(self): 86 | self.users = [] 87 | self.inventory = Inventory() 88 | 89 | def register_user(self, username, password): 90 | user = User(username, password) 91 | self.users.append(user) 92 | 93 | def login_user(self, username, password): 94 | for user in self.users: 95 | if user.login(username, password): 96 | return user 97 | return None 98 | 99 | def purchase_book(self, user, book): 100 | user.purchase(book) 101 | 102 | def view_inventory(self): 103 | return self.inventory.books 104 | 105 | def search_books(self, search_type, query): 106 | if search_type == "title": 107 | return self.inventory.search_by_title(query) 108 | elif search_type == "author": 109 | return self.inventory.search_by_author(query) 110 | return [] 111 | 112 | 113 | 114 | class Database: 115 | def __init__(self): 116 | self.users = [] 117 | self.books = [] 118 | 119 | def save_user(self, user): 120 | self.users.append(user) 121 | 122 | def save_book(self, book): 123 | self.books.append(book) 124 | 125 | def retrieve_all_users(self): 126 | return self.users 127 | 128 | def retrieve_all_books(self): 129 | return self.books 130 | 131 | def find_user(self, username): 132 | for user in self.users: 133 | if user.username == username: 134 | return user 135 | return None 136 | 137 | def find_book(self, isbn): 138 | for book in self.books: 139 | if book.isbn == isbn: 140 | return book 141 | return None 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | I have read and understood the source codes. I am ready to generate code. 150 | Give me the instructions. 151 | 152 | 153 | -------------------------------------------------------------------------------- /examples/parameterized_prompts.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | You are a world-renowned travel planner, having curated trips for both the enthusiastic explorer and the 5 | leisure-seeking traveler. 
Your adeptness in understanding unique travel preferences, combined with your 6 | extensive knowledge of global destinations, ensures an unforgettable journey. 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | I'm gearing up for a memorable escape for the duration of. 15 | 16 | 17 | 18 | 19 | My eyes are set on exploring the diverse tapestry of our own homeland. I've chosen the vibrant city 20 | of 21 | . Given its domestic charm, I anticipate an itinerary that beautifully 22 | blends 23 | affordability with rich experiences. I'm eager to soak in the city's heritage, dine on local 24 | delicacies, 25 | and engage with its heartwarming community. 26 | 27 | 28 | 29 | 30 | 31 | I'm yearning to tread international soils, to immerse myself in unfamiliar cultures, sceneries, and 32 | stories. 33 | 34 | 35 | 36 | The Maldives beckons me with its constellation of coral islands. Renowned for its ethereal 37 | underwater beauty and luxury resorts, it offers an unmatched tropical experience. From 38 | diving 39 | into the vibrant coral reefs to dining in an underwater restaurant, from watching 40 | bioluminescent 41 | phytoplankton light up the shores to enjoying traditional drumming and dancing, every moment 42 | promises to be a slice of paradise. 43 | 44 | 45 | 46 | 47 | 48 | The vast expanse of the Amazon rainforest stirs the adventurer in me. A treasure trove of 49 | biodiversity, it spans across several countries in South America. I dream of navigating the 50 | serpentine Amazon river, witnessing the cacophony of life in the jungle, and understanding 51 | the 52 | age-old traditions of the indigenous tribes. The enchanting flora, diverse fauna, and the 53 | mesmerizing sounds of the forest await my arrival. 54 | 55 | 56 | 57 | 58 | 59 | The golden embrace of the Sahara desert is an experience I seek. Beyond its mesmerizing 60 | dunes 61 | lies a world rich in history and culture. The tales of ancient caravans, the artistry of 62 | Berber 63 | music, and the thrill of dune bashing are just the tip of the sand dune. Nights under the 64 | vast 65 | desert sky, sharing stories around a campfire, and the hauntingly beautiful music of the 66 | desert 67 | dwellers promise an ethereal journey. 68 | 69 | 70 | 71 | 72 | 73 | Tokyo, Japan's bustling capital, is a harmonious blend of tradition and modernity. 74 | Skyscrapers 75 | stand alongside historic temples, and cutting-edge technology coexists with age-old customs. 76 | From the serenity of the Meiji Shrine to the electric vibes of Shibuya Crossing, from 77 | sumptuous 78 | sushi feasts to tranquil tea ceremonies, Tokyo promises a whirlwind of experiences. 79 | 80 | 81 | 82 | 83 | 84 | The eternal city of Rome, with its millennia of history, calls out to the history aficionado 85 | in 86 | me. Walking its streets feels like a journey through time, from ancient Roman ruins to 87 | Renaissance art. From tossing a coin in the Trevi Fountain to a leisurely stroll in the 88 | Roman 89 | Forum, from savoring a gelato by the Spanish Steps to attending a mass at St. Peter's 90 | Basilica, 91 | Rome promises a cultural feast. 92 | 93 | 94 | 95 | 96 | 97 | Cape Town, nestled at the foot of the majestic Table Mountain in South Africa, offers a 98 | cocktail 99 | of experiences. With its rich history, vibrant arts scene, and diverse landscapes ranging 100 | from 101 | beaches to vineyards, it's a traveler's paradise. 
I look forward to exploring the historic 102 | Robben Island, savoring wines in Stellenbosch, and witnessing the panoramic views from the 103 | Cape 104 | of Good Hope. 105 | 106 | 107 | 108 | 109 | 110 | Sydney, the shimmering jewel of Australia, where urban sophistication meets natural beauty. 111 | The 112 | iconic Sydney Opera House and the majestic Sydney Harbour Bridge dominate its skyline. 113 | Beaches 114 | like Bondi and Manly promise sun, surf, and relaxation, while the city's rich culinary 115 | scene, 116 | bustling markets, and world-class museums ensure an enriching urban adventure. 117 | 118 | 119 | 120 | 121 | 122 | Buenos Aires, Argentina's cosmopolitan capital, is a city that pulsates with passion and 123 | energy. 124 | Tango music fills its air, and its streets are a canvas of art and culture. From the 125 | historic 126 | San Telmo district to the vibrant La Boca neighborhood, from savoring a traditional asado to 127 | dancing the night away in a milonga, Buenos Aires promises a fiery and rhythmic journey. 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | I'd love to help. I've carefully read the city details and the trip duration. Please give me an instruction. 137 | 138 | 139 | -------------------------------------------------------------------------------- /g.sh: -------------------------------------------------------------------------------- 1 | srun --exclusive --mem=0 --nodes=1 --ntasks=1 --ntasks-per-node=1 --cpus-per-task=16 \ 2 | --partition=gpuA40x4-interactive --account=bccn-delta-gpu --gpu-bind=verbose,per_task:1 \ 3 | --gpus-per-node=1 --gpus-per-task=1 --constraint="scratch" --pty /bin/zsh -------------------------------------------------------------------------------- /get_scores.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from metrics import ( 8 | qa_f1_score, 9 | rouge_zh_score, 10 | qa_f1_zh_score, 11 | rouge_score, 12 | classification_score, 13 | retrieval_score, 14 | retrieval_zh_score, 15 | count_score, 16 | code_sim_score 17 | ) 18 | 19 | dataset2metric = { 20 | "squad_v2": qa_f1_score, 21 | "narrativeqa": qa_f1_score, 22 | "qasper": qa_f1_score, 23 | "multifieldqa_en": qa_f1_score, 24 | "multifieldqa_zh": qa_f1_zh_score, 25 | "hotpotqa": qa_f1_score, 26 | "2wikimqa": qa_f1_score, 27 | "musique": qa_f1_score, 28 | "dureader": rouge_zh_score, 29 | "gov_report": rouge_score, 30 | "qmsum": rouge_score, 31 | "multi_news": rouge_score, 32 | # "trec": classification_score, 33 | "triviaqa": qa_f1_score, 34 | "samsum": rouge_score, 35 | "passage_retrieval_en": retrieval_score, 36 | "passage_count": count_score, 37 | "passage_retrieval_zh": retrieval_zh_score, 38 | "lcc": code_sim_score, 39 | "repobench-p": code_sim_score, 40 | } 41 | 42 | 43 | def main(): 44 | dset_list = [ 45 | #"squad_v2", 46 | "narrativeqa", 47 | "qasper", 48 | "multifieldqa_en", 49 | "hotpotqa", 50 | "2wikimqa", 51 | "musique", 52 | "gov_report", 53 | "qmsum", 54 | "multi_news", 55 | # "trec", 56 | "triviaqa", 57 | # "samsum", 58 | "passage_count", 59 | "passage_retrieval_en", 60 | # "lcc", 61 | # "repobench-p", 62 | ] 63 | 64 | model_list = [ 65 | #"falcon", 66 | "llama", 67 | #"mpt" 68 | ] 69 | 70 | for dset in tqdm(dset_list): 71 | print(f"Dataset: {dset}------------------------------") 72 | for m in model_list: 73 | p = f"./results_13b/{m}-{dset}" 74 | with open(glob.glob(f"{p}/no_cache_*.json")[0], "r") as f: 75 | no_cache = [json.loads(line) 
for line in f] 76 | no_cache_score, nc_std = score(no_cache, dset) 77 | 78 | with open(glob.glob(f"{p}/with_cache_*.json")[0], "r") as f: 79 | with_cache = [json.loads(line) for line in f] 80 | with_cache_score, wc_std = score(with_cache, dset) 81 | 82 | print(f"{m}-{dset}: {no_cache_score:.2f} ({nc_std:.2f}) vs {with_cache_score:.2f} ({wc_std:.2f}) ") 83 | 84 | 85 | def score(results, dataset_name): 86 | scores_list = [] 87 | for result in results: 88 | response = result["response"].split('')[0].split('<|endoftext|>')[0].split('<|im_end|>')[0] 89 | answers = result["answers"] 90 | 91 | score = 0. 92 | for answer in answers: 93 | score = max(score, dataset2metric[dataset_name](response, answer)) 94 | 95 | scores_list.append(score) 96 | 97 | #print(scores_list) 98 | mean = np.mean(scores_list) * 100 99 | std = np.std(scores_list) 100 | 101 | return mean, std 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import jieba 5 | from fuzzywuzzy import fuzz 6 | import difflib 7 | 8 | from typing import List 9 | from collections import Counter 10 | from rouge import Rouge 11 | 12 | 13 | def normalize_answer(s): 14 | """Lower text and remove punctuation, articles and extra whitespace.""" 15 | 16 | def remove_articles(text): 17 | return re.sub(r"\b(a|an|the)\b", " ", text) 18 | 19 | def white_space_fix(text): 20 | return " ".join(text.split()) 21 | 22 | def remove_punc(text): 23 | exclude = set(string.punctuation) 24 | return "".join(ch for ch in text if ch not in exclude) 25 | 26 | def lower(text): 27 | return text.lower() 28 | 29 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 30 | 31 | 32 | def normalize_zh_answer(s): 33 | """Lower text and remove punctuation, extra whitespace.""" 34 | 35 | def white_space_fix(text): 36 | return "".join(text.split()) 37 | 38 | def remove_punc(text): 39 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 
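        # The next line merges ASCII punctuation with the CJK punctuation above into
        # a single set, so one membership test strips both kinds of characters.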
40 | all_punctuation = set(string.punctuation + cn_punctuation) 41 | return "".join(ch for ch in text if ch not in all_punctuation) 42 | 43 | def lower(text): 44 | return text.lower() 45 | 46 | return white_space_fix(remove_punc(lower(s))) 47 | 48 | 49 | def count_score(prediction, ground_truth, **kwargs): 50 | numbers = re.findall(r"\d+", prediction) 51 | right_num = 0 52 | for number in numbers: 53 | if str(number) == str(ground_truth): 54 | right_num += 1 55 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 56 | return float(final_score) 57 | 58 | 59 | def retrieval_score(prediction, ground_truth, **kwargs): 60 | pattern = r'Paragraph (\d+)' 61 | matches = re.findall(pattern, ground_truth) 62 | ground_truth_id = matches[0] 63 | numbers = re.findall(r"\d+", prediction) 64 | right_num = 0 65 | for number in numbers: 66 | if str(number) == str(ground_truth_id): 67 | right_num += 1 68 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 69 | return float(final_score) 70 | 71 | 72 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 73 | pattern = r'段落(\d+)' 74 | matches = re.findall(pattern, ground_truth) 75 | ground_truth_id = matches[0] 76 | numbers = re.findall(r"\d+", prediction) 77 | right_num = 0 78 | for number in numbers: 79 | if str(number) == str(ground_truth_id): 80 | right_num += 1 81 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 82 | return float(final_score) 83 | 84 | 85 | def code_sim_score(prediction, ground_truth, **kwargs): 86 | all_lines = prediction.lstrip('\n').split('\n') 87 | prediction = "" 88 | for line in all_lines: 89 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 90 | prediction = line 91 | break 92 | return (fuzz.ratio(prediction, ground_truth) / 100) 93 | 94 | 95 | def classification_score(prediction, ground_truth, **kwargs): 96 | em_match_list = [] 97 | all_classes = kwargs["all_classes"] 98 | for class_name in all_classes: 99 | if class_name in prediction: 100 | em_match_list.append(class_name) 101 | for match_term in em_match_list: 102 | if match_term in ground_truth and match_term != ground_truth: 103 | em_match_list.remove(match_term) 104 | if em_match_list != 0: 105 | if ground_truth in em_match_list: 106 | score = (1.0 / len(em_match_list)) 107 | else: 108 | score = 0.0 109 | else: 110 | best_match = None 111 | highest_similarity = 0 112 | for string in all_classes: 113 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio() 114 | if similarity > highest_similarity: 115 | highest_similarity = similarity 116 | best_match = string 117 | score = float(best_match == ground_truth) 118 | return score 119 | 120 | 121 | def rouge_score(prediction, ground_truth, **kwargs): 122 | rouge = Rouge() 123 | try: 124 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 125 | except: 126 | return 0.0 127 | return scores["rouge-l"]["f"] 128 | 129 | 130 | def rouge_zh_score(prediction, ground_truth, **kwargs): 131 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 132 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 133 | score = rouge_score(prediction, ground_truth) 134 | return score 135 | 136 | 137 | def f1_score(prediction, ground_truth, **kwargs): 138 | common = Counter(prediction) & Counter(ground_truth) 139 | num_same = sum(common.values()) 140 | if num_same == 0: 141 | return 0 142 | precision = 1.0 * num_same / len(prediction) 143 | recall = 1.0 * num_same / len(ground_truth) 144 | f1 = (2 * precision 
* recall) / (precision + recall) 145 | return f1 146 | 147 | 148 | def qa_f1_score(prediction, ground_truth, **kwargs): 149 | normalized_prediction = normalize_answer(prediction) 150 | normalized_ground_truth = normalize_answer(ground_truth) 151 | 152 | prediction_tokens = normalized_prediction.split() 153 | ground_truth_tokens = normalized_ground_truth.split() 154 | return f1_score(prediction_tokens, ground_truth_tokens) 155 | 156 | 157 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 158 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 159 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 160 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 161 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 162 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 163 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 164 | return f1_score(prediction_tokens, ground_truth_tokens) 165 | -------------------------------------------------------------------------------- /promptcache/__init__.py: -------------------------------------------------------------------------------- 1 | # re-export all modules 2 | 3 | from .cache_engine import CacheEngine 4 | from .generation_engine import GenerationEngine, GenerationParameters 5 | from .prompt import Prompt, CompactSpaces, read_file 6 | from .schema import Schema 7 | from .conversation import llama2_template -------------------------------------------------------------------------------- /promptcache/compiler.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | import astor 4 | from functools import wraps 5 | 6 | 7 | class SchemaBuilder: 8 | schema: list[str | tuple[bool, str], tuple[str, dict]] 9 | schema_name: str | None 10 | 11 | def __init__(self): 12 | self.schema = [] 13 | self.schema_name = None 14 | 15 | def text(self, value: str): 16 | self.schema.append(value) 17 | 18 | def cond(self, cond: bool, _): 19 | value = self.schema.pop() 20 | self.schema.append((cond, value)) 21 | 22 | def union(self, target: str, values: dict): 23 | self.schema.append((target, values)) 24 | 25 | def get_schema(self) -> str: 26 | 27 | uid = 0 28 | output = "" 29 | for item in self.schema: 30 | match item: 31 | case str(val): 32 | output += val + "\n" 33 | case (bool(), str(val)): 34 | output += f"{val}" 35 | uid += 1 36 | case (str(), dict(d)): 37 | output += "" 38 | for cond, val in d.items(): 39 | output += f"{val}" 40 | uid += 1 41 | output += "" 42 | 43 | self.schema_name = f'{abs(hash(output)):x}' 44 | 45 | output = f"" + output 46 | output += "" 47 | return output 48 | 49 | def get_prompt(self) -> str: 50 | 51 | if self.schema_name is None: 52 | self.get_schema() 53 | 54 | output = f"" 55 | uid = 0 56 | for item in self.schema: 57 | match item: 58 | case (bool(cond), str()): 59 | if cond: 60 | output += f"" 61 | uid += 1 62 | case (str(target), dict(d)): 63 | for cond, val in d.items(): 64 | if target == cond: 65 | output += f"" 66 | uid += 1 67 | case _: 68 | ... 
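        # get_schema() (called above when schema_name is still None) names the schema
        # after a hash of its emitted text, so the prompt assembled here stays tied to
        # the exact schema text produced by this builder instance.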
69 | output += "" 70 | return output 71 | 72 | 73 | class CodeTransformer(ast.NodeTransformer): 74 | 75 | def visit_Expr(self, node): 76 | if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str): 77 | # Safely format the string 78 | return ast.parse(f'builder.text({repr(node.value.value)})').body[0] 79 | return node 80 | 81 | def visit_If(self, node): 82 | condition = astor.to_source(node.test).strip() 83 | then_body = [self.visit(n) for n in node.body] 84 | content = astor.to_source(then_body[0]).strip() 85 | return ast.parse(f'builder.cond({condition}, {content})').body[0] 86 | 87 | def visit_Match(self, node): 88 | subject = astor.to_source(node.subject).strip() 89 | cases = {} 90 | for case in node.cases: 91 | key = astor.to_source(case.pattern.value).strip() 92 | value = astor.to_source(case.body[0]).strip() 93 | cases[key] = value 94 | cases_str = ", ".join([f'{key}: {value}' for key, value in cases.items()]) 95 | return ast.parse(f'builder.union({subject}, {{{cases_str}}})').body[0] 96 | 97 | 98 | cached_compiled_funcs = dict() 99 | 100 | 101 | def prompt(func): 102 | # 103 | @wraps(func) 104 | def wrapper(*args, **kwargs): 105 | global cached_compiled_funcs 106 | 107 | if func in cached_compiled_funcs: 108 | print('using cached') 109 | return cached_compiled_funcs[func](*args, **kwargs) 110 | 111 | source_code = inspect.getsource(func).splitlines() 112 | # strip decorators from the source itself 113 | source_code = "\n".join(line for line in source_code if not line.startswith('@')) 114 | 115 | tree = ast.parse(source_code) 116 | 117 | # Insert 'env = Env()' at the start of the function body 118 | env_init = ast.parse("builder = SchemaBuilder()").body[0] 119 | tree.body[0].body.insert(0, env_init) 120 | 121 | # Ensure the transformed function returns 'env' 122 | tree.body[0].body.append(ast.parse("return builder").body[0]) 123 | 124 | transformer = CodeTransformer() 125 | new_tree = transformer.visit(tree) 126 | 127 | # for debugging 128 | new_source_code = astor.to_source(new_tree) 129 | 130 | local_vars = {} 131 | exec(compile(new_tree, filename="", mode="exec"), func.__globals__, local_vars) 132 | modified_func = local_vars[func.__name__] 133 | cached_compiled_funcs[func] = modified_func 134 | return modified_func(*args, **kwargs) 135 | 136 | return wrapper 137 | -------------------------------------------------------------------------------- /promptcache/conversation.py: -------------------------------------------------------------------------------- 1 | from enum import auto, Enum, IntEnum 2 | from typing import List, Any, Dict 3 | import dataclasses 4 | 5 | import numpy as np 6 | import torch 7 | import transformers 8 | from transformers.trainer_pt_utils import LabelSmoother 9 | 10 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 11 | 12 | 13 | class SeparatorStyle(IntEnum): 14 | """Separator styles.""" 15 | 16 | ADD_COLON_SINGLE = auto() 17 | ADD_COLON_TWO = auto() 18 | ADD_COLON_SPACE_SINGLE = auto() 19 | NO_COLON_SINGLE = auto() 20 | NO_COLON_TWO = auto() 21 | ADD_NEW_LINE_SINGLE = auto() 22 | LLAMA2 = auto() 23 | CHATGLM = auto() 24 | CHATML = auto() 25 | CHATINTERN = auto() 26 | DOLLY = auto() 27 | RWKV = auto() 28 | PHOENIX = auto() 29 | ROBIN = auto() 30 | 31 | 32 | @dataclasses.dataclass 33 | class Conversation: 34 | """A class that manages prompt templates and keeps all conversation history.""" 35 | 36 | # The name of this template 37 | name: str 38 | # The system prompt 39 | system: str 40 | # Two roles 41 | roles: List[str] 42 | # All messages. 
Each item is (role, message). 43 | messages: List[List[str]] 44 | # The number of few shot examples 45 | offset: int 46 | # Separators 47 | sep_style: SeparatorStyle 48 | sep: str 49 | sep2: str = None 50 | # Stop criteria (the default one is EOS token) 51 | stop_str: str = None 52 | # Stops generation if meeting any token in this list 53 | stop_token_ids: List[int] = None 54 | 55 | def get_prompt(self) -> str: 56 | """Get the prompt for generation.""" 57 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 58 | ret = self.system + self.sep 59 | for role, message in self.messages: 60 | if message: 61 | ret += role + ": " + message + self.sep 62 | else: 63 | ret += role + ":" 64 | return ret 65 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 66 | seps = [self.sep, self.sep2] 67 | ret = self.system + seps[0] 68 | for i, (role, message) in enumerate(self.messages): 69 | if message: 70 | ret += role + ": " + message + seps[i % 2] 71 | else: 72 | ret += role + ":" 73 | return ret 74 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 75 | ret = self.system + self.sep 76 | for role, message in self.messages: 77 | if message: 78 | ret += role + ": " + message + self.sep 79 | else: 80 | ret += role + ": " # must be end with a space 81 | return ret 82 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 83 | ret = "" if self.system == "" else self.system + self.sep 84 | for role, message in self.messages: 85 | if message: 86 | ret += role + "\n" + message + self.sep 87 | else: 88 | ret += role + "\n" 89 | return ret 90 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 91 | ret = self.system 92 | for role, message in self.messages: 93 | if message: 94 | ret += role + message + self.sep 95 | else: 96 | ret += role 97 | return ret 98 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO: 99 | seps = [self.sep, self.sep2] 100 | ret = self.system 101 | for i, (role, message) in enumerate(self.messages): 102 | if message: 103 | ret += role + message + seps[i % 2] 104 | else: 105 | ret += role 106 | return ret 107 | elif self.sep_style == SeparatorStyle.RWKV: 108 | ret = self.system 109 | for i, (role, message) in enumerate(self.messages): 110 | if message: 111 | ret += ( 112 | role 113 | + ": " 114 | + message.replace("\r\n", "\n").replace("\n\n", "\n") 115 | ) 116 | ret += "\n\n" 117 | else: 118 | ret += role + ":" 119 | return ret 120 | elif self.sep_style == SeparatorStyle.LLAMA2: 121 | seps = [self.sep, self.sep2] 122 | ret = "" 123 | for i, (role, message) in enumerate(self.messages): 124 | if message: 125 | if i == 0: 126 | ret += self.system + message 127 | else: 128 | ret += role + " " + message + seps[i % 2] 129 | else: 130 | ret += role 131 | return ret 132 | elif self.sep_style == SeparatorStyle.CHATGLM: 133 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 134 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 135 | round_add_n = 1 if self.name == "chatglm2" else 0 136 | if self.system: 137 | ret = self.system + self.sep 138 | else: 139 | ret = "" 140 | 141 | for i, (role, message) in enumerate(self.messages): 142 | if i % 2 == 0: 143 | ret += f"[Round {i // 2 + round_add_n}]{self.sep}" 144 | 145 | if message: 146 | ret += f"{role}:{message}{self.sep}" 147 | else: 148 | ret += f"{role}:" 149 | return ret 150 | elif self.sep_style == SeparatorStyle.CHATML: 151 | ret = "" if self.system == "" else 
self.system + self.sep + "\n" 152 | for role, message in self.messages: 153 | if message: 154 | ret += role + "\n" + message + self.sep + "\n" 155 | else: 156 | ret += role + "\n" 157 | return ret 158 | elif self.sep_style == SeparatorStyle.CHATINTERN: 159 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 160 | seps = [self.sep, self.sep2] 161 | ret = self.system 162 | for i, (role, message) in enumerate(self.messages): 163 | if i % 2 == 0: 164 | ret += "" 165 | if message: 166 | ret += role + ":" + message + seps[i % 2] + "\n" 167 | else: 168 | ret += role + ":" 169 | return ret 170 | elif self.sep_style == SeparatorStyle.DOLLY: 171 | seps = [self.sep, self.sep2] 172 | ret = self.system 173 | for i, (role, message) in enumerate(self.messages): 174 | if message: 175 | ret += role + ":\n" + message + seps[i % 2] 176 | if i % 2 == 1: 177 | ret += "\n\n" 178 | else: 179 | ret += role + ":\n" 180 | return ret 181 | elif self.sep_style == SeparatorStyle.PHOENIX: 182 | ret = self.system 183 | for role, message in self.messages: 184 | if message: 185 | ret += role + ": " + "" + message + "" 186 | else: 187 | ret += role + ": " + "" 188 | return ret 189 | elif self.sep_style == SeparatorStyle.ROBIN: 190 | ret = self.system + self.sep 191 | for role, message in self.messages: 192 | if message: 193 | ret += role + ":\n" + message + self.sep 194 | else: 195 | ret += role + ":\n" 196 | return ret 197 | else: 198 | raise ValueError(f"Invalid style: {self.sep_style}") 199 | 200 | def append_message(self, role: str, message: str): 201 | """Append a new message.""" 202 | self.messages.append([role, message]) 203 | 204 | def update_last_message(self, message: str): 205 | """Update the last output. 206 | 207 | The last message is typically set to be None when constructing the prompt, 208 | so we need to update it in-place after getting the response from a model. 209 | """ 210 | self.messages[-1][1] = message 211 | 212 | def to_gradio_chatbot(self): 213 | """Convert the conversation to gradio chatbot format.""" 214 | ret = [] 215 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 216 | if i % 2 == 0: 217 | ret.append([msg, None]) 218 | else: 219 | ret[-1][-1] = msg 220 | return ret 221 | 222 | def to_openai_api_messages(self): 223 | """Convert the conversation to OpenAI chat completion format.""" 224 | ret = [{"role": "system", "content": self.system}] 225 | 226 | for i, (_, msg) in enumerate(self.messages[self.offset:]): 227 | if i % 2 == 0: 228 | ret.append({"role": "user", "content": msg}) 229 | else: 230 | if msg is not None: 231 | ret.append({"role": "assistant", "content": msg}) 232 | return ret 233 | 234 | def copy(self): 235 | return Conversation( 236 | name=self.name, 237 | system=self.system, 238 | roles=self.roles, 239 | messages=[[x, y] for x, y in self.messages], 240 | offset=self.offset, 241 | sep_style=self.sep_style, 242 | sep=self.sep, 243 | sep2=self.sep2, 244 | stop_str=self.stop_str, 245 | stop_token_ids=self.stop_token_ids, 246 | ) 247 | 248 | def dict(self): 249 | return { 250 | "template_name": self.name, 251 | "system": self.system, 252 | "roles": self.roles, 253 | "messages": self.messages, 254 | "offset": self.offset, 255 | } 256 | 257 | 258 | def vicuna_template(system: str = "A chat between a user and an artificial intelligence assistant. 
" 259 | "The assistant gives helpful and detailed answers to the user's questions."): 260 | return Conversation( 261 | name="vicuna_v1.1", 262 | system=system, 263 | roles=["USER", "ASSISTANT"], 264 | messages=[], 265 | offset=0, 266 | sep_style=SeparatorStyle.ADD_COLON_TWO, 267 | sep=" ", 268 | sep2="", 269 | ) 270 | 271 | 272 | def llama2_template( 273 | system: str = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " 274 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " 275 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 276 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " 277 | "If you don't know the answer to a question, please don't share false information.\n)"): 278 | return Conversation( 279 | name="llama-2", 280 | system=f" [INST] <>\n{system}<>\n\n ", 281 | roles=["[INST]", "[/INST]"], 282 | messages=[], 283 | offset=0, 284 | sep_style=SeparatorStyle.LLAMA2, 285 | sep=" ", 286 | sep2="", 287 | stop_token_ids=[2], 288 | ) 289 | 290 | # wittgenstein 291 | 292 | 293 | # 294 | # def preprocess( 295 | # conv, 296 | # source, 297 | # tokenizer: transformers.PreTrainedTokenizer, 298 | # ) -> (torch.Tensor, torch.Tensor): 299 | # roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 300 | # 301 | # # print(sources) 302 | # 303 | # # Apply prompt templates 304 | # 305 | # if roles[source[0]["from"]] != conv.roles[0]: 306 | # # Skip the first one if it is not from human 307 | # source = source[1:] 308 | # 309 | # conv.messages = [] 310 | # for j, sentence in enumerate(source): 311 | # role = roles[sentence["from"]] 312 | # assert role == conv.roles[j % 2], f"{i}" 313 | # conv.append_message(role, sentence["value"]) 314 | # 315 | # prompt = conv.get_prompt() 316 | # 317 | # # Tokenize conversations 318 | # input_ids = tokenizer( 319 | # conversations, 320 | # return_tensors="pt" 321 | # ).input_ids 322 | # 323 | # targets = input_ids.clone() 324 | # 325 | # # print(tokenizer.decode(input_ids[0])) 326 | # 327 | # # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO 328 | # 329 | # #Mask targets 330 | # sep = conv.sep + conv.roles[1] + ": " 331 | # for conversation, target in zip(conversations, targets): 332 | # total_len = int(target.ne(tokenizer.pad_token_id).sum()) 333 | # 334 | # rounds = conversation.split(conv.sep2) 335 | # cur_len = 1 336 | # target[:cur_len] = IGNORE_TOKEN_ID 337 | # for i, rou in enumerate(rounds): 338 | # if rou == "": 339 | # break 340 | # 341 | # parts = rou.split(sep) 342 | # if len(parts) != 2: 343 | # break 344 | # parts[0] += sep 345 | # round_len = len(tokenizer(rou).input_ids) 346 | # instruction_len = len(tokenizer(parts[0]).input_ids) - 2 347 | # 348 | # target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID 349 | # 350 | # cur_len += round_len 351 | # target[cur_len:] = IGNORE_TOKEN_ID 352 | # 353 | # if cur_len < tokenizer.model_max_length: 354 | # if cur_len != total_len: 355 | # target[:] = IGNORE_TOKEN_ID 356 | # print( 357 | # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
358 | # f" (ignored)" 359 | # ) 360 | # print(target) 361 | # 362 | # return input_ids, targets 363 | 364 | # print(len(input_ids[0])) 365 | # print(tokenizer.decode(input_ids[0], skip_special_tokens=True)) 366 | # print(len(targets[0])) 367 | # targets[0][targets[0] == IGNORE_TOKEN_ID] = tokenizer.pad_token_id 368 | # print(tokenizer.decode(targets[0], skip_special_tokens=True)) 369 | # 370 | # return dict( 371 | # input_ids=input_ids, 372 | # labels=targets, 373 | # attention_mask=input_ids.ne(tokenizer.pad_token_id), 374 | # ) 375 | -------------------------------------------------------------------------------- /promptcache/generation_engine.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import time 3 | from dataclasses import dataclass, field 4 | from typing import Optional, Generator, List, Tuple 5 | 6 | import torch 7 | 8 | from promptcache.cache_engine import KVCache 9 | 10 | from transformers.generation.logits_process import ( 11 | LogitsProcessorList, 12 | RepetitionPenaltyLogitsProcessor, 13 | TemperatureLogitsWarper, 14 | TopKLogitsWarper, 15 | TopPLogitsWarper, 16 | ) 17 | import termcolor 18 | from promptcache.model import LanguageModel 19 | 20 | 21 | @dataclass 22 | class GenerationParameters: 23 | temperature: float = 1.0 24 | repetition_penalty: float = 1.0 25 | top_p: float = 1.0 26 | top_k: int = -1 27 | max_new_tokens: int = 256 28 | stop_token_ids: List[int] = field(default_factory=lambda: []) 29 | stop_str: List[str] = field(default_factory=lambda: []) 30 | echo: bool = True 31 | 32 | def get_logits_processor(self): 33 | p = LogitsProcessorList() 34 | if self.temperature >= 1e-5 and self.temperature != 1.0: 35 | p.append(TemperatureLogitsWarper(self.temperature)) 36 | if self.repetition_penalty > 1.0: 37 | p.append(RepetitionPenaltyLogitsProcessor(self.repetition_penalty)) 38 | if 1e-8 <= self.top_p < 1.0: 39 | p.append(TopPLogitsWarper(self.top_p)) 40 | if self.top_k > 0: 41 | p.append(TopKLogitsWarper(self.top_k)) 42 | return p 43 | 44 | 45 | def is_partial_stop(output: str, stop_str: str): 46 | """Check whether the output contains a partial stop str.""" 47 | for i in range(0, min(len(output), len(stop_str))): 48 | if stop_str.startswith(output[-i:]): 49 | return True 50 | return False 51 | 52 | 53 | @dataclass 54 | class Output: 55 | text: str 56 | new_text: str 57 | response_time: float = 0.0 58 | elapsed_time: float = 0.0 59 | 60 | 61 | class GenerationEngine: 62 | lm: LanguageModel 63 | 64 | def __init__(self, lm: LanguageModel): 65 | self.lm = lm 66 | 67 | @torch.inference_mode() 68 | def generate(self, 69 | token_ids: List[int], 70 | position_ids: List[int], 71 | params: GenerationParameters, 72 | cache: Optional[KVCache] = None, 73 | stream_interval: int = 2, 74 | use_full_position_ids: bool = False) -> Generator[Output, None, None]: 75 | 76 | logits_processor = params.get_logits_processor() 77 | output_ids = list(token_ids) 78 | new_output_ids = list() 79 | 80 | device = self.lm.device 81 | 82 | position_offset = max(position_ids) + 1 83 | past_key_values = None 84 | new_token_id = 0 85 | 86 | inference_time = 0.0 87 | response_time = 0.0 88 | 89 | position_ids_og = position_ids 90 | 91 | for i in range(params.max_new_tokens): 92 | 93 | # initial phase 94 | if past_key_values is None: 95 | 96 | input_ids = torch.tensor([token_ids], device=device, dtype=torch.long) 97 | position_ids = torch.tensor([position_ids], device=device, dtype=torch.long) 98 | # print(len(position_ids[0])) 99 | 100 | # add 
redundant batch dim 101 | if cache is not None: 102 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache] 103 | 104 | start = torch.cuda.Event(enable_timing=True) 105 | end = torch.cuda.Event(enable_timing=True) 106 | 107 | start.record() 108 | out = self.lm(input_ids=input_ids, 109 | position_ids=position_ids, 110 | past_key_values=cache, 111 | use_cache=True) 112 | end.record() 113 | torch.cuda.synchronize() 114 | inference_time += start.elapsed_time(end) 115 | response_time = inference_time 116 | # print(f'Response time: {inference_time:.2f} ms') 117 | # pretty print using termcolor 118 | print(termcolor.colored(f'Prefill latency: {inference_time:.2f} ms', 'yellow')) 119 | 120 | logits = out.logits 121 | past_key_values = out.past_key_values 122 | 123 | else: 124 | # upload to the GPU 125 | input_ids = torch.tensor([[new_token_id]], device=device, dtype=torch.long) 126 | 127 | if use_full_position_ids: 128 | position_ids = torch.tensor([position_ids_og + list(range(position_offset, position_offset + i))], 129 | device=device, dtype=torch.long) 130 | 131 | else: 132 | position_ids = torch.tensor([[position_offset + i]], device=device, dtype=torch.long) 133 | 134 | start = torch.cuda.Event(enable_timing=True) 135 | end = torch.cuda.Event(enable_timing=True) 136 | 137 | start.record() 138 | out = self.lm(input_ids=input_ids, 139 | position_ids=position_ids, 140 | past_key_values=past_key_values, 141 | use_cache=True) 142 | end.record() 143 | torch.cuda.synchronize() 144 | inference_time += start.elapsed_time(end) 145 | 146 | logits = out.logits 147 | past_key_values = out.past_key_values 148 | 149 | if params.repetition_penalty > 1.0: 150 | tmp_output_ids = torch.as_tensor([output_ids], device=device) 151 | else: 152 | tmp_output_ids = None 153 | 154 | last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] 155 | 156 | ccc = [] 157 | 158 | if params.temperature < 1e-5 or params.top_p < 1e-8: # greedy 159 | new_token_id = int(torch.argmax(last_token_logits)) 160 | # _, indices = torch.topk(last_token_logits, 2) 161 | # ccc = [int(index) for index in indices.tolist()] 162 | 163 | else: 164 | probs = torch.softmax(last_token_logits, dim=-1) 165 | new_token_id = int(torch.multinomial(probs, num_samples=1)) 166 | # probs = torch.softmax(last_token_logits, dim=-1) 167 | # indices = torch.multinomial(probs, num_samples=2) 168 | # ccc = [int(token) for token in indices.tolist()] 169 | 170 | # new_token_id = ccc[1] 171 | 172 | output_ids.append(new_token_id) 173 | new_output_ids.append(new_token_id) 174 | 175 | # print(self.lm.decode([new_token_id])) 176 | 177 | if new_token_id in params.stop_token_ids: 178 | # print('Stopped', self.lm.decode([new_token_id]), self.lm.encode('. 
')) 179 | stopped = True 180 | else: 181 | stopped = False 182 | 183 | if i % stream_interval == 0 or i == params.max_new_tokens - 1 or stopped: 184 | output = self.lm.decode(output_ids) 185 | new_output = self.lm.decode(new_output_ids) 186 | 187 | partially_stopped = False 188 | 189 | for each_stop in params.stop_str: 190 | pos = new_output.rfind(each_stop, 0) 191 | if pos != -1: 192 | new_output = new_output[:pos] 193 | stopped = True 194 | break 195 | else: 196 | partially_stopped = is_partial_stop(output, each_stop) 197 | if partially_stopped: 198 | break 199 | 200 | if not partially_stopped: 201 | yield Output(output, new_output, inference_time, response_time) 202 | 203 | if stopped: 204 | break 205 | 206 | # clean 207 | del past_key_values, out 208 | gc.collect() 209 | torch.cuda.empty_cache() 210 | -------------------------------------------------------------------------------- /promptcache/inference.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import gc 3 | import math 4 | from typing import Iterable, Optional 5 | import sys 6 | import time 7 | import warnings 8 | 9 | from conversation import Conversation, SeparatorStyle, llama2_template 10 | 11 | import torch 12 | from transformers import ( 13 | AutoTokenizer, 14 | AutoModelForCausalLM, 15 | LlamaTokenizer, 16 | LlamaForCausalLM, 17 | AutoModel, 18 | AutoModelForSeq2SeqLM, 19 | T5Tokenizer, 20 | AutoConfig, 21 | ) 22 | from transformers.generation.logits_process import ( 23 | LogitsProcessorList, 24 | RepetitionPenaltyLogitsProcessor, 25 | TemperatureLogitsWarper, 26 | TopKLogitsWarper, 27 | TopPLogitsWarper, 28 | ) 29 | 30 | 31 | def prepare_logits_processor( 32 | temperature: float, repetition_penalty: float, top_p: float, top_k: int 33 | ) -> LogitsProcessorList: 34 | processor_list = LogitsProcessorList() 35 | # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. 
36 | if temperature >= 1e-5 and temperature != 1.0: 37 | processor_list.append(TemperatureLogitsWarper(temperature)) 38 | if repetition_penalty > 1.0: 39 | processor_list.append( 40 | RepetitionPenaltyLogitsProcessor(repetition_penalty)) 41 | if 1e-8 <= top_p < 1.0: 42 | processor_list.append(TopPLogitsWarper(top_p)) 43 | if top_k > 0: 44 | processor_list.append(TopKLogitsWarper(top_k)) 45 | return processor_list 46 | 47 | 48 | def partial_stop(output, stop_str): 49 | for i in range(0, min(len(output), len(stop_str))): 50 | if stop_str.startswith(output[-i:]): 51 | return True 52 | return False 53 | 54 | 55 | @torch.inference_mode() 56 | def generate_stream( 57 | model, tokenizer, params, device, context_len=2048, stream_interval=2 58 | ): 59 | prompt = params["prompt"] 60 | len_prompt = len(prompt) 61 | temperature = float(params.get("temperature", 1.0)) 62 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 63 | top_p = float(params.get("top_p", 1.0)) 64 | top_k = int(params.get("top_k", -1)) # -1 means disable 65 | max_new_tokens = int(params.get("max_new_tokens", 256)) 66 | stop_str = params.get("stop", None) 67 | echo = bool(params.get("echo", True)) 68 | stop_token_ids = params.get("stop_token_ids", None) or [] 69 | stop_token_ids.append(tokenizer.eos_token_id) 70 | 71 | logits_processor = prepare_logits_processor( 72 | temperature, repetition_penalty, top_p, top_k 73 | ) 74 | 75 | input_ids = tokenizer(prompt).input_ids 76 | input_echo_len = len(input_ids) 77 | output_ids = list(input_ids) 78 | 79 | max_src_len = context_len - max_new_tokens - 8 80 | 81 | input_ids = input_ids[-max_src_len:] 82 | 83 | # print(prompt) 84 | # aa = tokenizer.convert_ids_to_tokens(input_ids) 85 | # print(aa) 86 | 87 | past_key_values = out = None 88 | for i in range(max_new_tokens): 89 | if i == 0: 90 | print(torch.as_tensor( 91 | [input_ids], device=device)) 92 | 93 | out = model(torch.as_tensor([input_ids], device=device)) 94 | logits = out.logits 95 | past_key_values = out.past_key_values 96 | else: 97 | 98 | out = model( 99 | input_ids=torch.as_tensor([[token]], device=device), 100 | use_cache=True, 101 | past_key_values=past_key_values, 102 | ) 103 | logits = out.logits 104 | past_key_values = out.past_key_values 105 | 106 | if logits_processor: 107 | if repetition_penalty > 1.0: 108 | tmp_output_ids = torch.as_tensor( 109 | [output_ids], device=logits.device) 110 | else: 111 | tmp_output_ids = None 112 | last_token_logits = logits_processor( 113 | tmp_output_ids, logits[:, -1, :])[0] 114 | else: 115 | last_token_logits = logits[0, -1, :] 116 | 117 | if temperature < 1e-5 or top_p < 1e-8: # greedy 118 | token = int(torch.argmax(last_token_logits)) 119 | else: 120 | probs = torch.softmax(last_token_logits, dim=-1) 121 | token = int(torch.multinomial(probs, num_samples=1)) 122 | 123 | output_ids.append(token) 124 | 125 | if token in stop_token_ids: 126 | stopped = True 127 | else: 128 | stopped = False 129 | 130 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: 131 | if echo: 132 | tmp_output_ids = output_ids 133 | rfind_start = len_prompt 134 | else: 135 | tmp_output_ids = output_ids[input_echo_len:] 136 | rfind_start = 0 137 | 138 | output = tokenizer.decode( 139 | tmp_output_ids, 140 | skip_special_tokens=True, 141 | spaces_between_special_tokens=False, 142 | ) 143 | 144 | partially_stopped = False 145 | if stop_str: 146 | if isinstance(stop_str, str): 147 | pos = output.rfind(stop_str, rfind_start) 148 | if pos != -1: 149 | output = output[:pos] 150 | stopped = 
True 151 | else: 152 | partially_stopped = partial_stop(output, stop_str) 153 | elif isinstance(stop_str, Iterable): 154 | for each_stop in stop_str: 155 | pos = output.rfind(each_stop, rfind_start) 156 | if pos != -1: 157 | output = output[:pos] 158 | stopped = True 159 | break 160 | else: 161 | partially_stopped = partial_stop(output, each_stop) 162 | if partially_stopped: 163 | break 164 | else: 165 | raise ValueError("Invalid stop field type.") 166 | 167 | # prevent yielding partial stop sequence 168 | if not partially_stopped: 169 | yield { 170 | "text": output, 171 | "usage": { 172 | "prompt_tokens": input_echo_len, 173 | "completion_tokens": i, 174 | "total_tokens": input_echo_len + i, 175 | }, 176 | "finish_reason": None, 177 | } 178 | 179 | if stopped: 180 | break 181 | 182 | # finish stream event, which contains finish reason 183 | if i == max_new_tokens - 1: 184 | finish_reason = "length" 185 | elif stopped: 186 | finish_reason = "stop" 187 | else: 188 | finish_reason = None 189 | 190 | yield { 191 | "text": output, 192 | "usage": { 193 | "prompt_tokens": input_echo_len, 194 | "completion_tokens": i, 195 | "total_tokens": input_echo_len + i, 196 | }, 197 | "finish_reason": finish_reason, 198 | } 199 | 200 | # clean 201 | del past_key_values, out 202 | gc.collect() 203 | torch.cuda.empty_cache() 204 | 205 | 206 | class ChatIO(abc.ABC): 207 | @abc.abstractmethod 208 | def prompt_for_input(self, role: str) -> str: 209 | """Prompt for input from a role.""" 210 | 211 | @abc.abstractmethod 212 | def prompt_for_output(self, role: str): 213 | """Prompt for output from a role.""" 214 | 215 | @abc.abstractmethod 216 | def stream_output(self, output_stream): 217 | """Stream output.""" 218 | 219 | 220 | class SimpleChatIO(ChatIO): 221 | def prompt_for_input(self, role) -> str: 222 | return input(f"{role}: ") 223 | 224 | def prompt_for_output(self, role: str): 225 | print(f"{role}: ", end="", flush=True) 226 | 227 | def stream_output(self, output_stream): 228 | pre = 0 229 | for outputs in output_stream: 230 | output_text = outputs["text"] 231 | output_text = output_text.strip().split(" ") 232 | now = len(output_text) - 1 233 | if now > pre: 234 | print(" ".join(output_text[pre:now]), end=" ", flush=True) 235 | pre = now 236 | print(" ".join(output_text[pre:]), flush=True) 237 | return " ".join(output_text) 238 | 239 | 240 | def chat_loop( 241 | model_path: str, 242 | temperature: float, 243 | repetition_penalty: float, 244 | max_new_tokens: int, 245 | chatio: ChatIO, 246 | debug: bool, 247 | ): 248 | # Model 249 | # model, tokenizer = load_model( 250 | # model_path, 251 | # device, 252 | # num_gpus, 253 | # max_gpu_memory, 254 | # load_8bit, 255 | # cpu_offloading, 256 | # gptq_config, 257 | # revision, 258 | # debug, 259 | # ) 260 | 261 | tokenizer = AutoTokenizer.from_pretrained(model_path) 262 | 263 | model = LlamaForCausalLM.from_pretrained(model_path) 264 | 265 | device = 'cpu' 266 | 267 | conv = llama2_template() 268 | 269 | while True: 270 | try: 271 | inp = chatio.prompt_for_input(conv.roles[0]) 272 | except EOFError: 273 | inp = "" 274 | 275 | if inp == "!!exit" or not inp: 276 | print("exit...") 277 | break 278 | 279 | conv.append_message(conv.roles[0], inp) 280 | conv.append_message(conv.roles[1], None) 281 | 282 | prompt = conv.get_prompt() 283 | 284 | gen_params = { 285 | "model": model_path, 286 | "prompt": prompt, 287 | "temperature": temperature, 288 | "repetition_penalty": repetition_penalty, 289 | "max_new_tokens": max_new_tokens, 290 | "stop": conv.stop_str, 291 | 
"stop_token_ids": conv.stop_token_ids, 292 | "echo": False, 293 | } 294 | 295 | chatio.prompt_for_output(conv.roles[1]) 296 | output_stream = generate_stream(model, tokenizer, gen_params, device) 297 | t = time.time() 298 | outputs = chatio.stream_output(output_stream) 299 | duration = time.time() - t 300 | conv.update_last_message(outputs.strip()) 301 | 302 | if debug: 303 | num_tokens = len(tokenizer.encode(outputs)) 304 | msg = { 305 | "conv_template": conv.name, 306 | "prompt": prompt, 307 | "outputs": outputs, 308 | "speed (token/s)": round(num_tokens / duration, 2), 309 | } 310 | print(f"\n{msg}\n") 311 | 312 | 313 | if __name__ == "__main__": 314 | chat_loop( 315 | model_path="meta-llama/Llama-2-7b-chat-hf", 316 | temperature=0.7, 317 | repetition_penalty=1.0, 318 | max_new_tokens=512, 319 | chatio=SimpleChatIO(), 320 | debug=True, 321 | ) 322 | -------------------------------------------------------------------------------- /promptcache/model/__init__.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Callable, List, Optional, Tuple 3 | import torch 4 | import re 5 | 6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, \ 7 | PretrainedConfig, PreTrainedModel, CodeLlamaTokenizer 8 | 9 | from promptcache.model.falcon import FalconForCausalLM 10 | from promptcache.model.llama2 import LlamaForCausalLM 11 | from promptcache.model.mpt import MptForCausalLM 12 | from promptcache.prompt import Preprocessor, escape_xml, PreprocessorList 13 | 14 | 15 | # supported models 16 | # gpt2 17 | # aquilla 18 | # bloom 19 | # baichuan 20 | 21 | # llama 22 | # llama2 23 | # code llama 24 | # wizardlm 25 | # mpt-chat 26 | # alpaca 27 | # ghatglm 28 | # t5 29 | # falcon 30 | # dolly 31 | # 32 | # 33 | # Aquila (BAAI/Aquila-7B, BAAI/AquilaChat-7B, etc.) 34 | # Baichuan (baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Chat, etc.) 35 | # BLOOM (bigscience/bloom, bigscience/bloomz, etc.) 36 | # Falcon (tiiuae/falcon-7b, tiiuae/falcon-40b, tiiuae/falcon-rw-7b, etc.) 37 | # GPT-2 (gpt2, gpt2-xl, etc.) 38 | # GPT BigCode (bigcode/starcoder, bigcode/gpt_bigcode-santacoder, etc.) 39 | # GPT-J (EleutherAI/gpt-j-6b, nomic-ai/gpt4all-j, etc.) 40 | # GPT-NeoX (EleutherAI/gpt-neox-20b, databricks/dolly-v2-12b, stabilityai/stablelm-tuned-alpha-7b, etc.) 41 | # InternLM (internlm/internlm-7b, internlm/internlm-chat-7b, etc.) 42 | # LLaMA & LLaMA-2 (meta-llama/Llama-2-70b-hf, lmsys/vicuna-13b-v1.3, young-geng/koala, openlm-research/open_llama_13b, etc.) 43 | # MPT (mosaicml/mpt-7b, mosaicml/mpt-30b, etc.) 44 | # OPT (facebook/opt-66b, facebook/opt-iml-max-30b, etc.) 45 | # Qwen (Qwen/Qwen-7B, Qwen/Qwen-7B-Chat, etc.) 
46 | 47 | 48 | class FormatConversation(Preprocessor): 49 | system: (str, str, str) 50 | user: (str, str) 51 | assistant: (str, str) 52 | 53 | def __init__(self, system: (str, str, str), user: (str, str), assistant: (str, str)): 54 | super().__init__() 55 | 56 | self.system = (escape_xml(system[0]), escape_xml(system[1]), escape_xml(system[2])) 57 | self.user = (escape_xml(user[0]), escape_xml(user[1])) 58 | self.assistant = (escape_xml(assistant[0]), escape_xml(assistant[1])) 59 | 60 | def __call__(self, prompt: str) -> str: 61 | replacement_pairs = [ 62 | ("", self.system[0]), 63 | ("", self.system[1]), 64 | ("", self.system[2]), 65 | ("", self.user[0]), 66 | ("", self.user[1]), 67 | ("", self.assistant[0]), 68 | ("", self.assistant[1]) 69 | ] 70 | 71 | # remove space before 72 | prompt = re.sub(r' +', '', prompt) 73 | 74 | for old, new in replacement_pairs: 75 | prompt = prompt.replace(old, new) 76 | 77 | return prompt 78 | 79 | 80 | class FormatLlama2Conversation(FormatConversation): 81 | 82 | def __init__(self): 83 | super().__init__( 84 | system=("[INST] <>\n", "<>\n\n", "[INST]"), 85 | user=(" ", "[/INST]"), 86 | assistant=(" ", "[INST]") 87 | ) 88 | 89 | 90 | class LanguageModel(abc.ABC): 91 | name: str 92 | hf_tokenizer: PreTrainedTokenizer 93 | hf_model: PreTrainedModel 94 | stop_token_ids: List[int] 95 | stop_str: List[str] 96 | use_full_position_ids: bool = False 97 | 98 | def __init__(self, name: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, 99 | stop_token_ids: Optional[List[int]] = None, stop_str: Optional[List[str]] = None): 100 | self.name = name 101 | self.hf_tokenizer = tokenizer 102 | self.hf_model = model 103 | self.stop_token_ids = stop_token_ids if stop_token_ids is not None else [self.eos_token_id] 104 | self.stop_str = stop_str if stop_str is not None else [] 105 | 106 | @abc.abstractmethod 107 | def get_formatter(self) -> Callable[[str], str]: 108 | pass 109 | 110 | def get_cache_shape(self) -> Tuple[int, int, int]: 111 | num_head = self.config.num_attention_heads 112 | head_dim = self.config.hidden_size // self.config.num_attention_heads 113 | 114 | return self.config.num_hidden_layers, num_head, head_dim 115 | 116 | def store_k_hook(self, k_cache: torch.Tensor) -> torch.Tensor: 117 | return k_cache 118 | 119 | def store_v_hook(self, v_cache: torch.Tensor) -> torch.Tensor: 120 | return v_cache 121 | 122 | def read_k_hook(self, k_cache: torch.Tensor) -> torch.Tensor: 123 | return k_cache 124 | 125 | def read_v_hook(self, v_cache: torch.Tensor) -> torch.Tensor: 126 | return v_cache 127 | 128 | def __call__(self, **kwargs): 129 | return self.hf_model(**kwargs) 130 | 131 | def encode(self, text: str) -> List[int]: 132 | # Warning: this is a hack to remove bos_token 133 | token_ids = self.hf_tokenizer.encode(text, add_special_tokens=False) 134 | return token_ids 135 | 136 | def decode(self, token_ids: List[int]) -> str: 137 | return self.hf_tokenizer.decode(token_ids, skip_special_tokens=False, spaces_between_special_tokens=False) 138 | 139 | @property 140 | def unk_token(self) -> str: 141 | return self.hf_tokenizer.unk_token 142 | 143 | @property 144 | def unk_token_id(self) -> int: 145 | return self.hf_tokenizer.unk_token_id 146 | 147 | @property 148 | def eos_token(self) -> str: 149 | return self.hf_tokenizer.eos_token 150 | 151 | @property 152 | def eos_token_id(self) -> int: 153 | return self.hf_tokenizer.eos_token_id 154 | 155 | @property 156 | def device(self) -> torch.device: 157 | return self.hf_model.device 158 | 159 | @property 160 | def 
config(self) -> PretrainedConfig: 161 | return self.hf_model.config 162 | 163 | 164 | class CodeLlama(LanguageModel): 165 | 166 | def __init__(self, name: str = "codellama/CodeLlama-13b-Instruct-hf", **kwargs): 167 | tokenizer = CodeLlamaTokenizer.from_pretrained(name) 168 | model = LlamaForCausalLM.from_pretrained(name, **kwargs) 169 | 170 | self.formatter = FormatConversation( 171 | system=(" [INST] <>\n", "<>\n\n", " [INST] "), 172 | user=("", "[/INST]"), 173 | assistant=("", " [INST] ")) 174 | 175 | stop_token_ids = [tokenizer.eos_token_id] 176 | 177 | stop_str = [""] 178 | 179 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str) 180 | 181 | def get_formatter(self) -> Callable[[str], str]: 182 | return self.formatter 183 | 184 | 185 | class Llama2(LanguageModel): 186 | 187 | def __init__(self, name: str = "meta-llama/Llama-2-7b-chat-hf", **kwargs): 188 | tokenizer = LlamaTokenizer.from_pretrained(name) 189 | model = LlamaForCausalLM.from_pretrained(name, **kwargs) 190 | 191 | self.formatter = FormatConversation( 192 | system=(" [INST] <>\n", "<>\n\n", " [INST] "), 193 | user=("", "[/INST]"), 194 | assistant=("", " [INST] ")) 195 | 196 | stop_token_ids = [tokenizer.eos_token_id] 197 | 198 | stop_str = [""] 199 | 200 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str) 201 | 202 | def get_formatter(self) -> Callable[[str], str]: 203 | return self.formatter 204 | 205 | 206 | class Falcon(LanguageModel): 207 | def __init__(self, name="tiiuae/falcon-7b-instruct", **kwargs): 208 | tokenizer = AutoTokenizer.from_pretrained(name) 209 | model = FalconForCausalLM.from_pretrained(name, **kwargs) 210 | 211 | def rep(prompt: str) -> str: 212 | return prompt.replace("\r\n", "\n").replace("\n\n", "\n") 213 | 214 | # name = "falcon", 215 | # roles = ("User", "Assistant"), 216 | # messages = [], 217 | # sep_style = SeparatorStyle.RWKV, 218 | # sep = "\n", 219 | # sep2 = "<|endoftext|>", 220 | # stop_str = "\nUser", 221 | 222 | conv = FormatConversation( 223 | system=("", "\n\n", ""), 224 | user=("User: ", "\n\nAssistant:"), 225 | assistant=(" ", "\n\n")) 226 | # 227 | # conv = FormatConversation( 228 | # system=("", "\n\n", ""), 229 | # user=("<|prompt|>", "<|endoftext|>"), 230 | # assistant=("<|answer|>", "<|endoftext|>")) 231 | 232 | self.formatter = PreprocessorList([ 233 | rep, conv 234 | ]) 235 | 236 | stop_token_ids = [0, 237 | 1, 238 | 2, 239 | 3, 240 | 4, 241 | 5, 242 | 6, 243 | 7, 244 | 8, 245 | 9, 246 | 10, 247 | 11, ] 248 | 249 | stop_str = ["<|endoftext|>", "\nUser"] 250 | 251 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str) 252 | 253 | def get_formatter(self) -> Callable[[str], str]: 254 | return self.formatter 255 | 256 | def get_cache_shape(self) -> Tuple[int, int, int]: 257 | head_dim = self.hf_model.config.hidden_size // self.hf_model.config.num_attention_heads 258 | return self.hf_model.config.num_hidden_layers, 1, head_dim 259 | 260 | 261 | class Mpt(LanguageModel): 262 | def __init__(self, name="mosaicml/mpt-7b-chat", **kwargs): 263 | # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", trust_remote_code=True) 264 | tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True) 265 | 266 | model = MptForCausalLM.from_pretrained(name, max_seq_len=8192, 267 | **kwargs) 268 | 269 | conv = FormatConversation( 270 | system=("<|im_start|>system\n", "<|im_end|>\n", ""), 271 | user=("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"), 272 | assistant=("", "<|im_end|>\n")) 273 | 274 | self.formatter = 
conv 275 | self.use_full_position_ids = True 276 | 277 | stop_token_ids = [50278, 0] 278 | stop_str = [] 279 | 280 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str) 281 | 282 | def get_formatter(self) -> Callable[[str], str]: 283 | return self.formatter 284 | 285 | # https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/configuration_mpt.py 286 | def get_cache_shape(self) -> Tuple[int, int, int]: 287 | head_dim = self.hf_model.config.d_model // self.hf_model.config.n_heads, 288 | return self.hf_model.config.n_layers, self.hf_model.config.n_heads, head_dim[0] 289 | # 290 | # def store_k_hook(self, v_cache: torch.Tensor) -> torch.Tensor: 291 | # # batch, n_layers, seq_len, head_dim = v_cache.shape 292 | # return v_cache.transpose(2, 3) 293 | # 294 | # def read_k_hook(self, v_cache: torch.Tensor) -> torch.Tensor: 295 | # return v_cache.transpose(1, 2) 296 | -------------------------------------------------------------------------------- /promptcache/prompt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from abc import ABC, abstractmethod 5 | from typing import Union, List, Any, Optional 6 | from dataclasses import dataclass 7 | 8 | import lxml 9 | import lxml.etree 10 | 11 | import xml.sax.saxutils 12 | 13 | 14 | def repr_indent(obj: Any, indent: int = 1) -> str: 15 | return '\n'.join([indent * '\t' + s for s in repr(obj).split('\n')]) 16 | 17 | 18 | def read_file(filename: str, preprocessors: List[Preprocessor] = None) -> str: 19 | with open(filename, 'r') as f: 20 | 21 | text = f.read() 22 | 23 | if preprocessors is not None: 24 | for p in preprocessors: 25 | text = p(text) 26 | 27 | return text 28 | 29 | 30 | def apply_preproc(text: str, preprocessors: List[Preprocessor] = None) -> str: 31 | if preprocessors is not None: 32 | for p in preprocessors: 33 | text = p(text) 34 | 35 | return text 36 | 37 | 38 | def escape_xml(data): 39 | return xml.sax.saxutils.escape(data, entities={ 40 | "'": "'", 41 | "\"": """ 42 | }) 43 | 44 | 45 | # Replaces multiple leading and trailing whitespaces with a single space. 
46 | def compact_surrounding_spaces(text: str) -> str: 47 | return re.sub(r'^\s+|\s+$', ' ', text) 48 | 49 | 50 | def compact_spaces(text: str) -> str: 51 | return ' '.join(text.split()) 52 | 53 | 54 | class Preprocessor(ABC): 55 | """This class is used to preprocess the prompt before it is passed to the model.""" 56 | 57 | def __init__(self): 58 | pass 59 | 60 | @abstractmethod 61 | def __call__(self, prompt: str) -> str: 62 | raise NotImplementedError 63 | 64 | 65 | class PreprocessorList(Preprocessor): 66 | """This class is used to preprocess the prompt before it is passed to the model.""" 67 | pre: List[Preprocessor] 68 | 69 | def __init__(self, pre: List[Preprocessor]): 70 | self.pre = pre 71 | 72 | super().__init__() 73 | 74 | def __call__(self, prompt: str) -> str: 75 | for p in self.pre: 76 | prompt = p(prompt) 77 | return prompt 78 | 79 | 80 | class CompactSpaces(Preprocessor): 81 | """This class is used to remove all leading and trailing whitespaces.""" 82 | only_surrounding: bool 83 | 84 | def __init__(self, only_surrounding: bool = False): 85 | super().__init__() 86 | self.only_surrounding = only_surrounding 87 | 88 | def __call__(self, prompt: str) -> str: 89 | 90 | if self.only_surrounding: 91 | return compact_surrounding_spaces(prompt) 92 | else: 93 | return compact_spaces(prompt) 94 | 95 | 96 | class ModuleRef: 97 | """This class is used to represent a module reference in the prompt.""" 98 | name: str 99 | args: List[Argument] 100 | modules: List[ModuleRef] 101 | 102 | def __init__(self, spec: lxml.etree.Element = None): 103 | if spec is not None: 104 | self._process(spec) 105 | 106 | def _process(self, root: lxml.etree.Element): 107 | self.name = root.tag 108 | self.args = [Argument(name, value) for name, value in root.attrib.items()] 109 | 110 | self.modules = [] 111 | 112 | # leading text 113 | if root.text is not None and len(root.text.strip()) > 0: 114 | raise ValueError("Module reference cannot have text") 115 | 116 | for e in root: 117 | self.modules.append(ModuleRef(e)) 118 | if e.tail is not None and len(e.tail.strip()) > 0: 119 | raise ValueError("Module reference cannot have text") 120 | 121 | def __repr__(self) -> str: 122 | 123 | args = " ".join([f"{arg.name}={repr(arg.value)}" for arg in self.args]) 124 | if len(args) > 0: 125 | r = f"@{self.name}({args})" 126 | else: 127 | r = f"@{self.name}" 128 | 129 | for m in self.modules: 130 | r += '\n' + repr_indent(m) 131 | return r 132 | 133 | 134 | @dataclass 135 | class Argument: 136 | name: str 137 | value: str 138 | 139 | 140 | class Prompt(ModuleRef): 141 | """This class is used to represent a prompt.""" 142 | schema: str 143 | text: str 144 | 145 | preproc: List[Preprocessor] 146 | 147 | def __init__(self, spec: Union[str, lxml.etree.Element], preproc: Optional[List[Preprocessor]] = None): 148 | 149 | super().__init__() 150 | self.args = [] 151 | self.preproc = preproc if preproc is not None else [] 152 | self.text = "" 153 | if type(spec) == str: 154 | 155 | for p in self.preproc: 156 | spec = p(spec) 157 | 158 | parser = lxml.etree.XMLParser(recover=True) 159 | spec = lxml.etree.fromstring(spec, parser=parser) 160 | 161 | self._process(spec) 162 | 163 | def _process(self, root: lxml.etree.Element): 164 | 165 | assert root.tag == "prompt" 166 | 167 | if "schema" in root.attrib: 168 | self.schema = root.attrib["schema"] 169 | else: 170 | self.schema = "" # empty schema 171 | 172 | self.name = self.schema 173 | self.modules = [] 174 | 175 | # text only prompt 176 | if len(root) == 0: 177 | text = root.text 178 | 
if text is None: 179 | raise ValueError("Prompt cannot be empty") 180 | 181 | self.text = compact_surrounding_spaces(text) 182 | 183 | else: 184 | if root.text is not None and len(root.text.strip()) > 0: 185 | raise ValueError("Prompt cannot have leading text") 186 | 187 | tail = False 188 | for e in root: 189 | if tail: 190 | raise ValueError("Prompt cannot have text between module references") 191 | 192 | self.modules.append(ModuleRef(e)) 193 | 194 | if e.tail is not None: 195 | self.text = compact_surrounding_spaces(e.tail) 196 | 197 | self.text = self.text.strip() 198 | 199 | def add_text(self, text: str): 200 | for p in self.preproc: 201 | text = p(text) 202 | 203 | self.text += text 204 | 205 | def __repr__(self) -> str: 206 | r = f"Schema: @{self.name}" 207 | for m in self.modules: 208 | r += '\n' + repr_indent(m) 209 | r += '\nText: ' + repr(self.text) 210 | return r 211 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astor==0.8.1 2 | datasets==2.14.7 3 | fire==0.6.0 4 | fuzzywuzzy==0.18.0 5 | jieba==0.42.1 6 | lxml==5.2.1 7 | numpy==1.26.4 8 | Requests==2.31.0 9 | rouge==1.0.1 10 | termcolor==2.4.0 11 | torch==2.3.0 12 | tqdm==4.66.4 13 | transformers==4.34.0 -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import json 3 | 4 | from metrics import ( 5 | qa_f1_score, 6 | rouge_zh_score, 7 | qa_f1_zh_score, 8 | rouge_score, 9 | classification_score, 10 | retrieval_score, 11 | retrieval_zh_score, 12 | count_score, 13 | code_sim_score 14 | ) 15 | 16 | dataset2metric = { 17 | "narrativeqa": qa_f1_score, 18 | "qasper": qa_f1_score, 19 | "multifieldqa_en": qa_f1_score, 20 | "multifieldqa_zh": qa_f1_zh_score, 21 | "hotpotqa": qa_f1_score, 22 | "2wikimqa": qa_f1_score, 23 | "musique": qa_f1_score, 24 | "dureader": rouge_zh_score, 25 | "gov_report": rouge_score, 26 | "qmsum": rouge_score, 27 | "multi_news": rouge_score, 28 | "vcsum": rouge_zh_score, 29 | "trec": classification_score, 30 | "triviaqa": qa_f1_score, 31 | "samsum": rouge_score, 32 | "lsht": classification_score, 33 | "passage_retrieval_en": retrieval_score, 34 | "passage_count": count_score, 35 | "passage_retrieval_zh": retrieval_zh_score, 36 | "lcc": code_sim_score, 37 | "repobench-p": code_sim_score, 38 | } 39 | 40 | 41 | def main(result_file): 42 | dataset_name = result_file.split('/')[-3].split('-')[1] 43 | with open(result_file, 'r') as f: 44 | # load line by line 45 | results = [json.loads(line) for line in f] 46 | total_score = 0. 47 | for result in results: 48 | response = result["response"] 49 | answers = result["answers"] 50 | 51 | score = 0. 52 | for answer in answers: 53 | score = max(score, dataset2metric[dataset_name](response, answer)) 54 | 55 | total_score += score 56 | print(total_score / len(results) * 100) 57 | 58 | 59 | if __name__ == '__main__': 60 | fire.Fire(main) 61 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # How to run the scripts 2 | Inside this directory you can simply run: 3 | ```bash 4 | python run_benchmarks.py 5 | ``` 6 | It will automatically run all the benchmarks based on `benchmark_setup.json`. 7 | Note) You may want to clean up `../benchmark/results`. 
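For reference, `run_benchmarks.py` launches one worker per detected GPU and executes each benchmark from the repository root with `CUDA_VISIBLE_DEVICES` pinned to that GPU. Each entry in `benchmark_setup.json` expands into an `eval.py` command roughly of the following form (the exact flags depend on the setup file; the Llama 2 13B config is shown here as an example):
```bash
CUDA_VISIBLE_DEVICES=0 python3 eval.py \
    --llm_config_path ./config/llm_config_llama2_13b.json \
    --dataset narrativeqa \
    --split 0,1 \
    --enable_cache True \
    --test_latency False
```
Each benchmark is run twice, once with `--enable_cache True` and once with `--enable_cache False`, so cached and uncached results can be compared.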
8 |
-------------------------------------------------------------------------------- /scripts/benchmark_setup.json: --------------------------------------------------------------------------------
1 | {
2 |   "default": {
3 |     "llm_config_path": "./config/llm_config_llama2.json",
4 |     "split": "0,1"
5 |   },
6 |   "benchmarks": [
7 |
8 |     {"dataset": "narrativeqa"},
9 |     {"dataset": "qasper"},
10 |     {"dataset": "multifieldqa_en"},
11 |     {"dataset": "hotpotqa"},
12 |     {"dataset": "2wikimqa"},
13 |     {"dataset": "musique"},
14 |     {"dataset": "gov_report"},
15 |     {"dataset": "qmsum"},
16 |     {"dataset": "multi_news_long"},
17 |     {"dataset": "trec"},
18 |     {"dataset": "triviaqa"},
19 |     {"dataset": "passage_count"},
20 |     {"dataset": "passage_retrieval_en"}
21 |   ],
22 |   "okay_so_far_benchmarks": [
23 |
24 |   ],
25 |   "skipped_benchmarks": [
26 |     {"dataset": "multi_news", "reason": "We will use the LongBench version, multi_news_long"},
27 |     {"dataset": "lcc", "reason": "code syntax interferes with xml syntax of prompt cache"},
28 |     {"dataset": "repobench-p", "reason": "code syntax interferes with xml syntax of prompt cache"},
29 |     {"dataset": "samsum", "reason": "code syntax interferes with xml syntax of prompt cache"},
30 |     {"dataset": "lsht", "reason": "out of memory"},
31 |     {"dataset": "vcsum", "reason": "out of memory"},
32 |     {"dataset": "dureader", "reason": "out of memory"},
33 |     {"dataset": "ms_marco", "reason": "bad output"}
34 |   ],
35 |   "llm_list": [
36 |     {
37 |       "llm": "falcon",
38 |       "config_name": "./config/llm_config_falcon_40b.json"
39 |     },
40 |     {
41 |       "llm": "mpt",
42 |       "config_name": "./config/llm_config_mpt_30b.json"
43 |     },
44 |     {
45 |       "llm": "llama2",
46 |       "config_name": "./config/llm_config_llama2_13b.json"
47 |     }
48 |   ]
49 | }
50 |
-------------------------------------------------------------------------------- /scripts/eval_llama2.slurm: --------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --mem=128g
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --cpus-per-task=32 # <- match to OMP_NUM_THREADS
6 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
7 | #SBATCH --account=bccn-delta-gpu
8 | #SBATCH --job-name=eval_acc
9 | #SBATCH --time=10:00:00 # hh:mm:ss for the job
10 | ### GPU options ###
11 | #SBATCH --gpus-per-node=4
12 | ##SBATCH --gpu-bind=none # <- or closest
13 | #SBATCH --mail-user=in.gim@yale.edu
14 | #SBATCH --mail-type="BEGIN,END"
15 |
16 | module reset # drop modules and explicitly load the ones needed
17 | # (good job metadata and reproducibility)
18 | # $WORK and $SCRATCH are now set
19 | module load python # ...
or any appropriate modules 20 | module list # job documentation and metadata 21 | echo "job is starting on `hostname`" 22 | srun python3 run_benchmarks.py -------------------------------------------------------------------------------- /scripts/run_benchmarks.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import logging 3 | import json 4 | import os 5 | import sys 6 | import datetime 7 | from concurrent.futures import ThreadPoolExecutor 8 | import threading 9 | import queue 10 | 11 | 12 | def detect_nvidia_gpus(): 13 | try: 14 | result = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], capture_output=True, text=True) 15 | num_gpus = len(result.stdout.strip().split('\n')) 16 | return num_gpus 17 | except Exception as e: 18 | logging.error(f"An error occurred: {e}") 19 | return 0 20 | 21 | def read_args_from_json(json_path): 22 | with open(json_path, 'r') as f: 23 | return json.load(f) 24 | 25 | def construct_python_commands(default_args, benchmarks, llm_list): 26 | commands = [] 27 | for llm in llm_list: 28 | _llm_name = llm["llm"] 29 | llm_config = llm["config_name"] 30 | for benchmark in benchmarks: 31 | for enable_cache in [True, False]: 32 | split = benchmark.get("split", 1) 33 | for index in range(split): 34 | command = "python3 eval.py" 35 | merged_args = {**default_args, **benchmark, "llm_config_path": llm_config} 36 | for key, value in merged_args.items(): 37 | if key == "enable_cache": 38 | continue 39 | elif key == "split": 40 | command += f" --split {index},{split}" 41 | else: 42 | command += f" --{key} {value}" 43 | command += f" --enable_cache {enable_cache}" 44 | command += f" --test_latency False" 45 | commands.append(command) 46 | return commands 47 | 48 | global python_commands_list 49 | def gpu_worker(gpu_id, command_lock): 50 | while True: 51 | with command_lock: 52 | global next_command_index 53 | if next_command_index >= len(python_commands_list): 54 | return # All commands are executed, so exit the thread 55 | command = python_commands_list[next_command_index] 56 | next_command_index += 1 57 | 58 | env_command = f"CUDA_VISIBLE_DEVICES={gpu_id} " + command 59 | try: 60 | subprocess.run("cd .. 
&& " + env_command, shell=True, check=True) 61 | logging.info(f"Worker using GPU {gpu_id}: Command {command} completed successfully.") 62 | except subprocess.CalledProcessError as e: 63 | logging.error(f"Worker using GPU {gpu_id}: Command {command} failed: {e}") 64 | 65 | 66 | def main(): 67 | # logging.basicConfig(level=logging.INFO) 68 | logging.basicConfig(filename=f'benchmark_results_{datetime.date.today().strftime("%Y%m%d_%H%M%S")}.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 69 | num_gpus = detect_nvidia_gpus() 70 | logging.info(f"Detected {num_gpus} Nvidia GPUs.") 71 | 72 | # Read arguments from JSON file 73 | args_dict = read_args_from_json("benchmark_setup.json") 74 | 75 | global python_commands_list 76 | # Construct the Python commands 77 | python_commands_list = construct_python_commands(args_dict["default"], args_dict["benchmarks"], args_dict["llm_list"]) 78 | logging.info(f"Constructed {len(python_commands_list)} benchmarks.") 79 | 80 | global next_command_index 81 | next_command_index = 0 82 | command_lock = threading.Lock() 83 | # Start a thread for each GPU 84 | threads = [] 85 | for gpu_id in range(num_gpus): 86 | t = threading.Thread(target=gpu_worker, args=(gpu_id, command_lock)) 87 | t.start() 88 | threads.append(t) 89 | 90 | # Wait for all threads to finish 91 | for t in threads: 92 | t.join() 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device='cpu'): 5 | r""" 6 | Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it 7 | relies on a translation invariance of softmax for quick implementation. 
This implementation has been copied from 8 | the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi: 9 | https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292 10 | """ 11 | alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length) 12 | num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads)) 13 | 14 | base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device) 15 | base = base * (alibi_bias_max / num_heads_power_of_2) 16 | 17 | slopes = 1.0 / torch.pow(2, base) 18 | slopes = slopes.view(1, num_heads, 1, 1) 19 | 20 | if num_heads_power_of_2 != num_heads: 21 | slopes = torch.concat([slopes[1::2], slopes[::2]])[:num_heads] 22 | 23 | alibi = alibi * slopes 24 | return alibi.squeeze(0) 25 | 26 | 27 | table = build_mpt_alibi_tensor(16, 20) 28 | print(table.shape) 29 | print(table) -------------------------------------------------------------------------------- /tests/test_document_summary.py: -------------------------------------------------------------------------------- 1 | # Add the parent directory to the sys.path list 2 | import os, sys 3 | document_summary_path = os.path.abspath(os.path.dirname(__file__)) 4 | sys.path.append(os.path.abspath(os.path.join(document_summary_path, '..'))) 5 | 6 | import unittest 7 | from benchmark.multi_news import MultiNews 8 | 9 | ## The following code is for testing the MultiNews class 10 | class TestMultiNews(unittest.TestCase): 11 | def setUp(self): 12 | self.document_summary = MultiNews() 13 | self.document_summary.init(verbose=False) 14 | 15 | def test_get_next_query(self): 16 | query_id, query, modules = self.document_summary.get_next_query() 17 | self.assertIsInstance(query_id, int) 18 | self.assertGreaterEqual(query_id, 0) 19 | self.assertIsInstance(query, str) 20 | self.assertEqual(query, "") # query is empty string for summary benchmark 21 | self.assertIsInstance(modules, list) 22 | self.assertGreaterEqual(len(modules), 1) 23 | for module in modules: 24 | self.assertIsInstance(module, str) 25 | 26 | def test_evaluate(self): 27 | query_id, query, modules = self.document_summary.get_next_query() 28 | response = self.document_summary.evaluate(query_id, "response") 29 | self.assertIsInstance(response, float) 30 | self.assertGreaterEqual(response, 0.) 31 | self.assertLessEqual(response, 1.) 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | --------------------------------------------------------------------------------
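As a quick sanity check of the ALiBi slope schedule implemented in `tests/test.py` above, the per-head slopes for a small configuration can be computed directly (a minimal sketch assuming the same formula, with `num_heads=4` and `alibi_bias_max=8`):
```python
import math
import torch

num_heads, alibi_bias_max = 4, 8
n = 2 ** math.ceil(math.log2(num_heads))  # smallest power of two >= num_heads
base = torch.arange(1, n + 1, dtype=torch.float32) * (alibi_bias_max / n)
slopes = 1.0 / torch.pow(2, base)
print(slopes)  # tensor([0.2500, 0.0625, 0.0156, 0.0039]) -- geometric decay per head
```
Each head's bias is this slope times the (negative) token distance, matching the `alibi * slopes` product in the test.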