├── .gitignore
├── .gitmodules
├── LICENSE
├── Makefile
├── README.md
├── benchmark
│   ├── .gitignore
│   ├── __init__.py
│   ├── benchmark_base.py
│   ├── dataset_download.py
│   ├── longbench.py
│   ├── metrics.py
│   ├── ms_marco_v1_1.py
│   ├── multi_news.py
│   ├── profile_parser.py
│   ├── profiles
│   │   └── document_summary_simple.json
│   ├── results
│   │   └── README.md
│   ├── schema
│   │   ├── README.md
│   │   └── test
│   │       ├── empty.xml
│   │       ├── prompt_mbti.xml
│   │       ├── schema_code_generation.xml
│   │       ├── schema_long_task_1.xml
│   │       ├── schema_mbti.xml
│   │       ├── schema_mbti_short.xml
│   │       ├── schema_persona.xml
│   │       └── schema_persona_long.xml
│   ├── squad_v2.py
│   └── utils.py
├── benchmark_memcpy.py
├── config
│   ├── dataset_maxlen.json
│   ├── dataset_prompt.json
│   ├── llm_config_falcon_40b.json
│   ├── llm_config_falcon_7b.json
│   ├── llm_config_llama2_13b.json
│   ├── llm_config_llama2_7b.json
│   ├── llm_config_longchat_7b.json
│   ├── llm_config_mpt_30b.json
│   ├── llm_config_mpt_7b.json
│   └── llm_config_vicuna_7b.json
├── demo.py
├── eval.py
├── eval_acc.py
├── eval_acc.slurm
├── eval_sys.py
├── eval_sys_a100-7b-cpu.slurm
├── eval_sys_a100-7b-gpu.slurm
├── eval_sys_a40-7b-cpu.slurm
├── eval_sys_a40-7b-gpu.slurm
├── examples
│   ├── code_generation_bookstore.xml
│   ├── code_generation_game.xml
│   ├── parameterized_prompts.xml
│   ├── persona_generation.xml
│   └── personalization-education.xml
├── g.sh
├── get_scores.py
├── metrics.py
├── promptcache
│   ├── __init__.py
│   ├── cache_engine.py
│   ├── compiler.py
│   ├── conversation.py
│   ├── generation_engine.py
│   ├── inference.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── falcon.py
│   │   ├── llama2.py
│   │   └── mpt.py
│   ├── prompt.py
│   └── schema.py
├── requirements.txt
├── score.py
├── scripts
│   ├── README.md
│   ├── benchmark_setup.json
│   ├── eval_llama2.slurm
│   └── run_benchmarks.py
└── tests
    ├── test.py
    └── test_document_summary.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | results/
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # VSCode workspace
163 | *.code-workspace
164 | .vscode/
165 |
166 | # Models
167 | meta-llama/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "dependency/LongBench"]
2 | path = dependency/LongBench
3 | url = https://github.com/THUDM/LongBench.git
4 | ignore = dirty
5 | [submodule "dependency/bleurt"]
6 | path = dependency/bleurt
7 | url = https://github.com/google-research/bleurt
8 | ignore = dirty
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 In Gim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: eval
2 |
3 | # LLM_CONFIG_PATH ?= ./config/llm_config_falcon_7b.json
4 | LLM_CONFIG_PATH ?= ./config/llm_config_llama2_7b.json
5 | DATASET ?= squad_v2
6 | # DATASET ?= multi_news
7 | # DATASET ?= ms_marco
8 | # DATASET ?= narrativeqa
9 | ENABLE_CACHE ?= False
10 | SPLIT ?= 0,1
11 | TEST_LATENCY ?= False
12 | USE_CPU_FOR_INFERENCE ?= False
13 | eval:
14 | CUDA_VISIBLE_DEVICES=0 python3 eval.py \
15 | --llm_config_path $(LLM_CONFIG_PATH) \
16 | --dataset $(DATASET) \
17 | --enable_cache $(ENABLE_CACHE) \
18 | --split $(SPLIT) \
19 | --test_latency $(TEST_LATENCY) \
20 | --use_cpu_for_inference $(USE_CPU_FOR_INFERENCE) \
21 | --verbose False
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt Cache
2 |
3 | This is the repository for the [Prompt Cache: Modular Attention Reuse For Low-Latency Inference](https://arxiv.org/abs/2311.04934) paper. It includes the implementation and the evaluation tools used to demonstrate the prompt caching technique.
4 |
5 | ## Getting Started
6 |
7 | ### Installation
8 |
9 | To begin using the Prompt Cache, you need to install the required dependencies:
10 |
11 | ```bash
12 | pip install -r requirements.txt
13 | ```
14 |
15 | For evaluations involving `LongBench`, you also need to install `bleurt` from source:
16 |
17 | ```bash
18 | cd ./dependency/bleurt
19 | pip install .
20 | ```
21 |
22 | ### Supported Architectures
23 |
24 | The Prompt Cache extends the `transformers` library and is compatible with several Large Language Model (LLM) architectures. The inference engine currently supports:
25 |
26 | - **Llama2**: For example, use `meta-llama/Llama-2-7b-chat-hf` or `codellama/CodeLlama-7b-Instruct-hf`.
27 | - **Falcon**: Example configuration is `tiiuae/falcon-7b-instruct`.
28 | - **MPT**: An example model is `mosaicml/mpt-7b-chat-8k`.
29 |
30 | Model weights for these architectures are automatically retrieved from the Hugging Face model hub.
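
For reference, checkpoints for all three architectures resolve through the standard `transformers` loaders; a minimal sketch (this only illustrates weight retrieval and is separate from the Prompt Cache wrappers under `promptcache/model/`):

```python
# Minimal sketch: weights are downloaded from the Hugging Face hub on first use.
# This is plain transformers usage, not the Prompt Cache inference path.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b-instruct"  # any of the example checkpoints above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
```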
31 |
32 | ### Demo
33 |
34 | Explore the capabilities with our demo:
35 |
36 | ```bash
37 | python demo.py
38 | ```
39 |
40 | In the demo script `demo.py`, the `use_cache` flag can be toggled to enable or disable the Prompt Cache feature. You are encouraged to modify the script to experiment with different prompts and schemas.
41 |
42 | The demo leverages a schema from `./examples/code_generation_game.xml` with the following prompt setup:
43 |
44 | ```xml
45 |
46 |
47 |
48 |
49 |
50 |
51 | Create a main entry for the game:
52 |
53 | ```
54 |
55 | ### Additional Schemas
56 |
57 | Find more example schemas in the `./examples` directory; they can serve as templates for creating your own prompts and schemas. For instance, using the schema found at [/examples/persona_generation.xml](./examples/persona_generation.xml), you can create a prompt like the following:
58 | ```xml
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 | Introduce about yourself.
81 |
82 |
83 | ```
84 |
85 | Executing the above prompt generates the following responses (GPU: NVIDIA RTX 4090, CPU: Intel i9-13900K):
86 |
87 | - **Baseline (with KV cache)** (TTFT: GPU: 286.9ms, CPU: 17917.14 ms)
88 |
89 | ```
90 | User: Introduce about yourself.
91 | Assistant: Hello! My name is Alex, and I am a 25-year-old woman who lives in a beautiful coastal town. I have recently completed my
92 | Ph.D. in computer science and am now working as a software engineer at a leading tech company. When I'm not working, I
93 | enjoy spending time outdoors, reading books, and trying out new recipes in the kitchen. I am also very passionate about
94 | my work and am constantly looking for opportunities to learn and grow in my
95 | ```
96 |
97 | - **With PromptCache** (TTFT: GPU: 78.2ms, CPU: 511.31ms)
98 |
99 | ```
100 | User: Introduce about yourself.
101 | Assistant: Hello! My name is Alex, and I'm a 28-year-old software developer who recently moved to the coastal town of Oceanview. I have a doctorate degree in computer science and specialize in artificial intelligence and machine learning. I work as a lead engineer at a tech startup that focuses on developing innovative AI solutions for various industries.
102 | In my free time, I enjoy exploring the beautiful beaches and hiking trails around Oceanview with my wife, Sarah. We met while we were both pursuing our graduate degrees at Stanford University, and we've been inseparable ever since. We have two adorable kids, Emily and Jack, who keep us busy and bring us endless joy.
103 | As an introverted person, I often prefer spending time alone or with my close family members, but I also make an effort to connect with others through social events and community activities. I believe in being honest, kind, and respectful towards everyone, regardless of their background or beliefs.
104 | ```
105 |
106 | ## Prompt Markup Language (PML)
107 |
108 | Schema and prompts are written in Prompt Markup Language (PML). PML is a simple XML-based language that allows users to define schemas and prompts for the Prompt Cache.
109 |
110 | ### Writing schema with PML
111 |
112 | ```xml
113 |
114 |
115 |
116 |
117 |
118 |
119 |
124 |
125 | Just some text with parameter:
126 |
127 |
128 |
129 | Nested
130 |
131 |
137 |
138 |
139 |
143 |
144 | System prompt type 1.
145 |
146 |
150 | System prompt type 2.
151 | System prompt type 3,
152 | with parameter:
153 |
154 |
155 |
156 |
157 |
158 | User 1 information
159 | User 2 information
160 |
161 |
162 |
163 |
164 |
165 |
166 | Task description 1
167 |
168 |
169 | Task description 1
170 | Task description 1
171 |
172 |
173 |
174 |
175 |
176 | ```
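
The schema grammar is implemented in `promptcache/schema.py`. As a rough sketch (element names follow the paper; attribute details here are assumptions), a schema declares a name, reusable `<module>` blocks, `<union>` groups of mutually exclusive modules, and `<parameter>` placeholders, organized into `<system>`, `<user>`, and `<assistant>` turns:

```xml
<!-- Rough sketch only; see promptcache/schema.py and the generated schemas
     under benchmark/schema/ for the authoritative grammar. -->
<schema name="example-schema">
    <system>Shared system prompt for every prompt that uses this schema.</system>
    <user>
        <module name="task-1">Task description 1</module>
        <union>
            <module name="user-1">User 1 information</module>
            <module name="user-2">User 2 information</module>
        </union>
        <module name="with-parameter">
            Just some text with parameter: <parameter name="param" length="8"/>
        </module>
    </user>
    <assistant>Acknowledgement from the assistant.</assistant>
</schema>
```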
177 |
178 | ### Writing Prompt with PML
179 |
180 | ```xml
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 | What will be the next movement of the robot?
194 | It will move forward.
195 | What would be the speed then?
196 |
197 | ```
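
A prompt names the schema it builds on, invokes the cached modules it wants to reuse as self-closing tags (the same pattern as the `<Document0/>`-style invocations generated in `benchmark/ms_marco_v1_1.py`), and appends new turns. A rough sketch against the schema above, with tag and attribute details again being assumptions (see `promptcache/prompt.py`):

```xml
<!-- Rough sketch only; see promptcache/prompt.py for the authoritative grammar. -->
<prompt schema="example-schema">
    <task-1/>
    <user-2/>
    <user>What will be the next movement of the robot?</user>
</prompt>
```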
198 |
199 | #### Compiling Python function into PML
200 |
201 | The compiler can be found in `promptcache/compiler.py`. You can use Python decorators to compile Python functions into PML.
202 |
203 | ```python
204 | @prompt
205 | def test_prompt(flag, item):
206 | ""
207 | "some system message"
208 | if flag:
209 | "only if flag is true"
210 | match item:
211 | case 1:
212 | "item is 1"
213 | case 2:
214 | "item is 2"
215 | "some user message"
216 |
217 | r = test_prompt(True, 1)
218 |
219 | print(r.get_schema())
220 | print(r.get_prompt())
221 | ```
222 |
223 |
224 |
225 | ## Benchmark and Evaluation
226 |
227 | You can run accuracy benchmarks on LongBench with:
228 |
229 | ```bash
230 | python eval_acc.py --help
231 | ```
232 |
233 | To evaluate the inference time on LongBench, you can run the following command:
234 |
235 | ```bash
236 | python eval.py --help
237 | ```
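
The `eval` target in the Makefile shows a typical invocation. For example, a cached latency run on `squad_v2` with one of the configs under `./config/` looks like:

```bash
CUDA_VISIBLE_DEVICES=0 python3 eval.py \
    --llm_config_path ./config/llm_config_llama2_7b.json \
    --dataset squad_v2 \
    --enable_cache True \
    --split 0,1 \
    --test_latency True \
    --use_cpu_for_inference False \
    --verbose False
```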
238 |
239 | ## Citation
240 | ```
241 | @article{gim2023prompt,
242 | title={Prompt cache: Modular attention reuse for low-latency inference},
243 | author={Gim, In and Chen, Guojun and Lee, Seung-seob and Sarda, Nikhil and Khandelwal, Anurag and Zhong, Lin},
244 | journal={arXiv preprint arXiv:2311.04934},
245 | year={2023}
246 | }
247 | ```
248 |
--------------------------------------------------------------------------------
/benchmark/.gitignore:
--------------------------------------------------------------------------------
1 | # Any xml file should be added with git add -f
2 | *.xml
3 | results*
4 | BLEURT-20/
5 |
--------------------------------------------------------------------------------
/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/benchmark/benchmark_base.py:
--------------------------------------------------------------------------------
1 | # This module provides APIs to load/initialize the dataset and evaluate the response from the LLM.
2 |
3 | # * Required API
4 | # * init() : download (one time) and load the dataset to run; do any preprocessing required for running this benchmark
5 | # * get_entry_count(): return the number of entries in the dataset.
6 | # * get_query(): return a list of Entry objects for the given range.
7 | import os, json
8 | import abc
9 | from typing import Tuple, List
10 |
11 | SCHEMA_FILE_DIRECTORY = "./benchmark/schema"
12 | DATASET_LIST = ["squad_v2", "multi_news", "wiki_qa", "pubmed_qa", "ms_marco", "narrativeqa", "qasper",
13 | "multifieldqa_en", "hotpotqa", "2wikimqa", "musique", "dureader", "gov_report", "qmsum", "multi_news_long",
14 | "vcsum", "trec", "triviaqa", "samsum", "lsht", "passage_count", "passage_retrieval_en", "lcc",
15 | "repobench-p"]
16 |
17 | DATASET_SUBSET = {
18 | "pubmed_qa": ["pqa_artificial", "pqa_labeled", "pqa_unlabeled"],
19 | "ms_marco": ["v1.1", "v2.1"]
20 | }
21 |
22 | class Entry:
23 | def __init__(self, schema, prompt, answer=None):
24 | """
25 | Constructor to initialize any required variables.
26 | [schema: str] path to the schema file, usage: cache_engine.add_schema(read_file(schema, preproc))
27 | [prompt: str] prompt text to feed directly to the LLM; it contains the name of the schema in use and the question from the dataset
28 | [answer: [str]] the list of potential answers to the above question
29 | """
30 | self.schema = schema
31 | self.prompt = prompt
32 | self.answer = answer
33 |
34 | def __repr__(self) -> str:
35 | return f"Entry(schema={self.schema}, prompt={self.prompt}, answer={self.answer})"
36 |
37 |
38 | class Benchmark(abc.ABC):
39 | def __init__(self, dataset_name: str):
40 | """
41 | Constructor to initialize any required variables.
42 | """
43 | if dataset_name not in DATASET_LIST:
44 | raise ValueError("Dataset name cannot be None, valid dataset names are: " + ", ".join(DATASET_LIST))
45 |
46 | self.dataset_name = dataset_name
47 | self.dataset = None
48 | self.entries = []
49 | self.schema_path = os.path.join(SCHEMA_FILE_DIRECTORY, dataset_name)
50 | if not os.path.exists(self.schema_path):
51 | os.makedirs(self.schema_path)
52 |
53 | self.load_prompt()
54 |
55 | def load_prompt(self):
56 | base_path = os.path.dirname(os.path.abspath(__file__))
57 | dataset_prompt_path = os.path.join(base_path, '../config/dataset_prompt.json')
58 | with open(dataset_prompt_path, 'r') as f:
59 | self.dataset_prompt = json.load(f)
60 | self.dataset_prompt = self.dataset_prompt[self.dataset_name]
61 |
62 | @abc.abstractmethod
63 | def init(self, limit_entries=None):
64 | """
65 | Download (one time) and load the dataset to run;
66 | Preprocess the dataset to be organized in the `Entry` format.
67 | """
68 | raise NotImplementedError("This method should be overridden by subclass")
69 |
70 | def get_entry_count(self) -> int:
71 | """
72 | Return the number of entries in the dataset.
73 | """
74 | return len(self.entries)
75 |
76 | def get_query(self, range: Tuple[int, int]) -> List[Entry]:
77 | """
78 | Return a list of Entry objects for the given range.
79 | [range: (int, int)] the range of entries to return
80 | """
81 | return self.entries[range[0]:range[1]]
82 |
--------------------------------------------------------------------------------
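
The `Entry`/`Benchmark` API above is what the dataset modules in this directory (`ms_marco_v1_1.py`, `multi_news.py`, `longbench.py`, `squad_v2.py`) implement. A minimal sketch of such a subclass, assuming `wiki_qa` has an entry in `config/dataset_prompt.json` and using placeholder schema and prompt strings:

```python
# Illustrative sketch only, not part of the benchmark suite: a minimal
# Benchmark subclass following the API documented in benchmark_base.py.
from datasets import load_dataset

from .benchmark_base import Benchmark, Entry


class WikiQA(Benchmark):
    def __init__(self):
        super().__init__("wiki_qa")  # must be listed in DATASET_LIST

    def init(self, limit_entries=None):
        # Download (one time) and load the dataset, then build Entry objects.
        self.dataset = load_dataset("wiki_qa")
        for idx, item in enumerate(self.dataset["validation"]):
            if limit_entries is not None and idx >= limit_entries:
                break
            schema_file_name = f"schema_wiki_qa_{idx}.xml"  # placeholder schema path
            prompt = item["question"]                       # placeholder prompt text
            self.entries.append(Entry(schema_file_name, prompt, [item["answer"]]))
```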
/benchmark/dataset_download.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 |
3 | ## Dataset loading helpers are listed below
4 |
5 | datasets = {}
6 |
7 |
8 | def load_documentation_summary():
9 | ## Summary dataset
10 | datasets['multi_news'] = load_dataset('multi_news')
11 | return datasets['multi_news']
12 | # print("Multi News\n", datasets['multi_news']['train'][0])
13 |
14 |
15 | def load_multidoc_qna():
16 | ## Open domain question answering
17 | # = version 2.1 =
18 | # datasets['ms_marco'] = load_dataset('ms_marco', 'v2.1')
19 | # print("MS_Marco", datasets['ms_marco']['train'][0])
20 |
21 | # = version 1.1 =
22 | datasets['ms_marco'] = load_dataset('ms_marco', 'v1.1')
23 | # print("MS_Marco", datasets['ms_marco']['train'][0])
24 | return datasets['ms_marco']
25 |
26 |
27 | pass
28 |
--------------------------------------------------------------------------------
/benchmark/longbench.py:
--------------------------------------------------------------------------------
1 | # import re
2 | #
3 | # from .benchmark_base import Benchmark, Entry
4 | # from .utils import XMLSchemaBuilder
5 | # from datasets import load_dataset
6 | # import os
7 | #
8 | #
9 | # def escape_tags(input_str):
10 | # # pattern = r'<(?P.*?)>'
11 | #
12 | # # # The lambda function ensures only the first letter is capitalized
13 | # # def repl(match):
14 | # # return '(' + match.group("content").capitalize() + ')'
15 | # #
16 | # # return re.sub(pattern, repl, input_str)
17 | # return input_str.replace('<', '(').replace('>', ')')
18 | #
19 | #
20 | # class LongBench(Benchmark):
21 | # def __init__(self, subset_name: str):
22 | # super().__init__(subset_name)
23 | #
24 | # def init(self, limit_entries=None):
25 | # """
26 | # Download (one time) and load the dataset to run;
27 | # Preprocess the dataset to be organized in the `Entry` format.
28 | # """
29 | # self.dataset = load_dataset('THUDM/LongBench', self.dataset_name)
30 | #
31 | # count = 0
32 | # for split in self.dataset.values():
33 | # for item in split:
34 | # if limit_entries is not None and count >= limit_entries:
35 | # break
36 | # schema_name = f"schema_{item['_id']}"
37 | #
38 | # fmt_schema = self.dataset_prompt["context"].format(context=escape_tags(item["context"]))
39 | # fmt_prompt = self.dataset_prompt["question"].format(input=escape_tags(item["input"])[:1000])
40 | # fmt_question = self.dataset_prompt["answer"]
41 | #
42 | # schema = f"""
43 | # {fmt_schema}
44 | # """
45 | #
46 | # prompt = f"""
47 | # {fmt_prompt}{fmt_question}
48 | # """
49 | # self.entries.append(Entry(schema, prompt, item["answers"]))
50 | #
51 | # count += 1
52 |
53 | import re
54 |
55 | from .benchmark_base import Benchmark, Entry
56 | from .utils import XMLSchemaBuilder
57 | from datasets import load_dataset
58 | import os
59 |
60 | _system_description = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, \
61 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with \
62 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false \
63 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the \
64 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being \
65 | useful."
66 | _user_description = "For the upcoming interaction, I would like you to answer some questions about the document."
67 | _assistant_description = "Sure. I have read the document. Please give me any question."
68 |
69 |
70 | def escape_tags(input_str):
71 | # pattern = r'<(?P.*?)>'
72 |
73 | # # The lambda function ensures only the first letter is capitalized
74 | # def repl(match):
75 | # return '(' + match.group("content").capitalize() + ')'
76 | #
77 | # return re.sub(pattern, repl, input_str)
78 | return input_str.replace('<', '(').replace('>', ')')
79 |
80 |
81 | class LongBench(Benchmark):
82 | def __init__(self, subset_name: str):
83 | super().__init__(subset_name)
84 |
85 | def init(self, limit_entries=None):
86 | """
87 | Download (one time) and load the dataset to run;
88 | Preprocess the dataset to be organized in the `Entry` format.
89 | """
90 | self.dataset = load_dataset('THUDM/LongBench', self.dataset_name)
91 |
92 | count = 0
93 | for split in self.dataset.values():
94 | for item in split:
95 | if limit_entries is not None and count >= limit_entries:
96 | break
97 | id = item["_id"]
98 | schema_name = f"schema_{id}"
99 |
100 | fmt_schema = self.dataset_prompt["context"].format(context=escape_tags(item["context"]))
101 |
102 | schema_content = f"""
103 |
104 |
105 |
106 | {fmt_schema}
107 |
108 |
109 | """
110 | #
111 | #
112 | # builder = XMLSchemaBuilder(schema_name)
113 | # context = item["context"]
114 | # # print(len(context))
115 | # # title = item["title"]
116 | # question = item["input"]
117 | # answer = item["answers"]
118 | # builder.set_system_description(_system_description)
119 | # builder.set_user_description(_user_description)
120 | # builder.add_document_module("context",
121 | # self.dataset_prompt["context"].format(context=escape_tags(context)))
122 | # builder.set_assistant_description(_assistant_description)
123 |
124 | schema_file_name = f"{schema_name}.xml"
125 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f:
126 | f.write(schema_content)
127 |
128 | prompt = f"""
129 |
130 |
131 | {self.dataset_prompt["question"].format(input=escape_tags(item["input"])[:1000])}{self.dataset_prompt["answer"]}
132 | """
133 | self.entries.append(Entry(schema_file_name, prompt, item["answers"]))
134 |
135 | count += 1
136 |
137 |
138 | if __name__ == '__main__':
139 | sq = LongBench('narrativeqa')
140 | sq.init()
141 | print(sq.get_entry_count())
142 | print(sq.get_query((100, 101)))
143 |
--------------------------------------------------------------------------------
/benchmark/metrics.py:
--------------------------------------------------------------------------------
1 | import os, requests, zipfile, io, sys
2 | from tqdm import tqdm
3 | from bleurt import score
4 |
5 | BLEURT_20_URL = "https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip"
6 | CHECKPOINT = "BLEURT-20"
7 |
8 | def download_bleurt_20():
9 | if not os.path.exists('./BLEURT-20/'):
10 | print("Downloading BLEURT-20 checkpoint...")
11 | with requests.get(BLEURT_20_URL, stream=True) as response:
12 | total_size = int(response.headers.get('content-length', 0))
13 | block_size = 1024 # 1 Kbyte
14 | buffer = io.BytesIO()
15 | # Initialize tqdm with the total file size, and use it as an iterable in a loop
16 | with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
17 | for data in response.iter_content(block_size):
18 | buffer.write(data)
19 | # Update tqdm about the downloaded data size
20 | pbar.update(len(data))
21 |
22 | buffer.seek(0)
23 |
24 | # Unzip the file
25 | print("Unzipping BLEURT-20 checkpoint...")
26 | with zipfile.ZipFile(buffer) as zip_file:
27 | # Extract all the contents into the current working directory
28 | zip_file.extractall(path=".")
29 |
30 | class BleurtScorer:
31 | def __init__(self):
32 | download_bleurt_20()
33 | self.scorer = score.BleurtScorer(CHECKPOINT)
34 |
35 | def score(self, refs: list[str], hyps: list[str]):
36 | return self.scorer.score(references=refs, candidates=hyps)
37 |
38 | if __name__ == '__main__':
39 | bs = BleurtScorer()
40 | print(bs.score(["i'm leo"], ["my name is leo"]))
41 |
--------------------------------------------------------------------------------
/benchmark/ms_marco_v1_1.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | # Add the parent directory to the sys.path list
5 | document_summary_path = os.path.abspath(os.path.dirname(__file__))
6 | # sys.path.append(os.path.abspath(os.path.join(document_summary_path, '..')))
7 |
8 | from .benchmark_base import Benchmark, Entry
9 | from .dataset_download import load_multidoc_qna
10 | from .utils import XMLSchemaBuilder
11 |
12 | _document_schema_name = "multi_document_qna"
13 | _document_header = "Document"
14 | _document_dataset = "validation"
15 | _document_system_description = "Dialogues between a user and an AI about the document provided by the user with the aim of being helpful, aware, and accurate."
16 | _document_assistant_description = "Sure. I have read the documents separated by comma. Give me any instructions regarding query information based on the documents, and I will try to follow them."
17 | _document_user_summary = "Among the list of given documents separated by a comma, find the document that is most useful to answer the query and return its index starting at [0]. All documents are unrelated, return [-1]. Do not reply using a complete sentence, and only give the answer in the following format: [1]" # Take a deep breath and think step-by-step.
18 |
19 | MAX_DOCUMENT_LENGTH = 2560
20 |
21 | class MSMarcoV1(Benchmark):
22 | def __init__(self):
23 | super().__init__("ms_marco")
24 | self.next_query_idx = 0
25 |
26 | def init(self, limit_entries=None, verbose=False):
27 | """
28 | Download (one time) and load the dataset to run;
29 | Do any preprocessing required for running this benchmark.
30 | """
31 | self.dataset = load_multidoc_qna()
32 | if verbose:
33 | print("Dataset loaded. First entry below:")
34 | print(self.dataset[_document_dataset][1])
35 | # Now we can generate xml files
36 | assert self.dataset is not None
37 | schema_file_name = "schema_summary_sample.xml"
38 | self._generate_xml(limit_entries)
39 |
40 | def _generate_xml(self, limit_entries):
41 | # Generate xml files
42 | # - In this version, we build the xml file per entry
43 | count = 0
44 | for document_idx in range(len(self.dataset[_document_dataset])):
45 | if limit_entries is not None and count >= limit_entries:
46 | break
47 | # Create an instance of XMLSchemaBuilder with the schema name "document_summary"
48 | query_idx = self.dataset[_document_dataset][document_idx]["query_id"]
49 | query_str = self.dataset[_document_dataset][document_idx]["query"]
50 | schema_name = f"{_document_schema_name}_{document_idx}_q{query_idx}"
51 | answer_list = self.dataset[_document_dataset][document_idx]["passages"]["is_selected"]
52 | answer_str = str(answer_list.index(True)) if True in answer_list else "-1"
53 | builder = XMLSchemaBuilder(schema_name)
54 |
55 | # Set the system description
56 | builder.set_system_description(_document_system_description)
57 |
58 | # Set the user description
59 | builder.set_user_description("") # _document_user_description
60 | # builder.add_document_module("DOC", "The given documents are the target for searching to answer the query. They can contain multiple sentences.")
61 |
62 |
63 | # Add document modules
64 | # - we need to collect passages into document str, separated by newline or \n
65 | document_str = ""
66 |
67 | for p_idx, passage in enumerate(self.dataset[_document_dataset][document_idx]["passages"]["passage_text"]):
68 | document_str += f"\"[{p_idx}] {passage},\n"
69 | document_str = document_str.replace("’", "'").replace("”", '"').replace("“", '"').replace("‘", "'").replace("…", "...").replace("–", "-")
70 |
71 | module_str = f"{_document_header}{document_idx}"
72 | builder.add_document_module(f"{module_str}", document_str)
73 |
74 | # Set assistant reply
75 | builder.set_assistant_description(_document_assistant_description)
76 |
77 | # Write the XML string to a file
78 | schema_file_name = f"{schema_name}.xml"
79 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f:
80 | f.write(builder.generate_xml())
81 |
82 | # Prepare the entry
83 | prompt = \
84 | f"""
85 | <{module_str}/>
86 | {_document_user_summary} Query: "{query_str}"
87 |
88 | """
89 | self.entries.append(Entry(schema_file_name, prompt, answer_str))
90 |
91 | count += 1
92 |
93 | def get_next_query(self):
94 | """
95 | Return query_id (unsigned), query (string), and chosen modules (a list of string).
96 | """
97 | assert self.dataset is not None
98 | assert self.dataset[_document_dataset] is not None
99 | assert self.next_query_idx < len(self.dataset[_document_dataset])
100 | response = self.next_query_idx, "", [f"{_document_header}{self.next_query_idx}"]
101 | self.next_query_idx += 1
102 | return response
103 |
104 | def evaluate(self, query_id, response_from_llm):
105 | """
106 | Take query_id and response_from_llm as parameters and return a score in the range [0,1].
107 | """
108 | assert self.dataset is not None
109 | assert self.dataset[_document_dataset] is not None
110 | assert query_id < len(self.dataset[_document_dataset])
111 | assert response_from_llm is not None
112 | assert response_from_llm != ""
113 | raise NotImplementedError(
114 | "This method should call utility function to measure how the response is closer to the expected answer.")
115 |
116 | if __name__ == '__main__':
117 | msm = MSMarcoV1()
118 | msm.init()
119 | print(msm.get_entry_count())
120 | print(msm.get_query((0, 1)))
121 |
--------------------------------------------------------------------------------
/benchmark/multi_news.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | # Add the parent directory to the sys.path list
5 | document_summary_path = os.path.abspath(os.path.dirname(__file__))
6 | # sys.path.append(os.path.abspath(os.path.join(document_summary_path, '..')))
7 |
8 | from .benchmark_base import Benchmark, Entry
9 | from .dataset_download import load_documentation_summary
10 | from .utils import XMLSchemaBuilder
11 |
12 | _document_header = "Document"
13 | _document_dataset = "validation"
14 | _document_system_description = "Dialogues between a user and an AI about the document provided by the user with the aim of being helpful, aware, and accurate."
15 | _document_assistant_description = "Sure. I have read the document. give me any instructions regarding summarization, and I will try to follow them."
16 | _document_user_summary = "Summarize the above document in around THREE sentences:"
17 |
18 | MAX_DOCUMENT_LENGTH = 2560
19 |
20 | class MultiNews(Benchmark):
21 | def __init__(self):
22 | super().__init__("multi_news")
23 | self.next_query_idx = 0
24 |
25 | def init(self, limit_entries=None, verbose=False):
26 | """
27 | Download (one time) and load the dataset to run;
28 | Do any preprocessing required for running this benchmark.
29 | """
30 | self.dataset = load_documentation_summary()
31 | if verbose:
32 | print("Dataset loaded. First entry below:")
33 | print(self.dataset[_document_dataset][1])
34 | # Now we can generate xml files
35 | assert self.dataset is not None
36 | schema_file_name = "schema_summary_sample.xml"
37 | self._generate_xml(limit_entries)
38 |
39 | def _generate_xml(self, limit_entries):
40 | # Generate xml files
41 | # - In this version, we build the xml file per entry
42 | count = 0
43 | for document_idx in range(len(self.dataset[_document_dataset])):
44 | if limit_entries is not None and count >= limit_entries:
45 | break
46 | # Create an instance of XMLSchemaBuilder with the schema name "document_summary"
47 | schema_name = f"_document_schema_name_{document_idx}"
48 | builder = XMLSchemaBuilder(schema_name)
49 |
50 | # Set the system description
51 | builder.set_system_description(_document_system_description)
52 |
53 | # Set the user description
54 | builder.set_user_description("") # _document_user_description
55 |
56 | # Add document modules
57 | # builder.add_document_module("DOC", "The given documents are the target for the summarization task. They can contain multiple sentences.")
58 | document_str = self.dataset[_document_dataset][document_idx]["document"].replace("’", "'").replace("”",
59 | '"').replace(
60 | "“", '"').replace("‘", "'").replace("…", "...").replace("–", "-")
61 | if len(document_str) > MAX_DOCUMENT_LENGTH:
62 | document_str = document_str[:MAX_DOCUMENT_LENGTH]
63 | builder.add_document_module(f"{_document_header}{document_idx}", document_str)
64 |
65 | # Set assistant reply
66 | builder.set_assistant_description(_document_assistant_description)
67 |
68 | # Write the XML string to a file
69 | schema_file_name = f"{schema_name}.xml"
70 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f:
71 | f.write(builder.generate_xml())
72 |
73 | # Prepare the entry
74 | prompt = f"""
75 |
76 | <{_document_header}{document_idx}/>
77 | {_document_user_summary}
78 | """
79 | summary_str = self.dataset[_document_dataset][document_idx]["summary"].replace("’", "'").replace("”",
80 | '"').replace(
81 | "“", '"').replace("‘", "'").replace("…", "...").replace("–", "-")
82 | self.entries.append(Entry(schema_file_name, prompt, summary_str))
83 |
84 | count += 1
85 |
86 | def get_next_query(self):
87 | """
88 | Return query_id (unsigned), query (string), and chosen modules (a list of string).
89 | """
90 | assert self.dataset is not None
91 | assert self.dataset[_document_dataset] is not None
92 | assert self.next_query_idx < len(self.dataset[_document_dataset])
93 | response = self.next_query_idx, "", [f"{_document_header}{self.next_query_idx}"]
94 | self.next_query_idx += 1
95 | return response
96 |
97 | def evaluate(self, query_id, response_from_llm):
98 | """
99 | Take query_id and response_from_llm as parameters and return a score in the range [0,1].
100 | """
101 | assert self.dataset is not None
102 | assert self.dataset[_document_dataset] is not None
103 | assert query_id < len(self.dataset[_document_dataset])
104 | assert response_from_llm is not None
105 | assert response_from_llm != ""
106 | raise NotImplementedError(
107 | "This method should call utility function to measure how the response is closer to the expected answer.")
108 |
109 | if __name__ == '__main__':
110 | mn = MultiNews()
111 | mn.init()
112 | print(mn.get_entry_count())
113 |
--------------------------------------------------------------------------------
/benchmark/profile_parser.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 |
5 |
6 | class JsonParser:
7 | def __init__(self, file_path):
8 | # add this directory to the sys.path list
9 | # print(sys.path)
10 | print(f'Parsing {file_path}...')
11 | assert os.path.isfile(file_path)
12 | self.file_path = file_path
13 | self.data = None
14 |
15 | def parse(self):
16 | with open(self.file_path) as f:
17 | self.data = json.load(f)
18 |
19 | def get_data(self):
20 | return self.data
21 |
22 |
23 | class BenchmarkSetupParser(JsonParser):
24 | dataset_size_str = 'dataset_sizes'
25 |
26 | def __init__(self, file_path):
27 | super().__init__(file_path)
28 |
29 | def parse(self):
30 | super().parse()
31 |
32 | def get_data_size(self, size_name: str):
33 | assert self.data is not None
34 | assert size_name in self.data[BenchmarkSetupParser.dataset_size_str]
35 | return int(self.data[BenchmarkSetupParser.dataset_size_str][size_name])
36 |
37 |
38 | class BenchmarkProfileParser(JsonParser):
39 | def __init__(self, file_path):
40 | super().__init__(file_path)
41 | self.benchmark_setup_parser: BenchmarkSetupParser = None
42 |
43 | def parse(self):
44 | super().parse()
45 | # try to get the benchmark setup
46 | assert 'benchmark_dataset_name' in self.data
47 | setup_json_path = f'benchmark/{self.data["benchmark_dataset_name"]}/setup.json'
48 | self.benchmark_setup_parser = BenchmarkSetupParser(setup_json_path)
49 | self.benchmark_setup_parser.parse()
50 | print(f'Parsing benchmark setup from {setup_json_path}...')
51 | # print(f'Dataset size: {self.benchmark_setup_parser.get_data_size(self.data["dataset_size"])}')
52 |
53 | def get_benchmark_name(self):
54 | return self.data['benchmark_name']
55 |
56 | def get_benchmark_description(self):
57 | return self.data['benchmark_description']
58 |
59 | def get_benchmark_dataset_name(self):
60 | return self.data['benchmark_dataset_name']
61 |
62 | def get_benchmark_dataset_comment(self):
63 | return self.data['benchmark_dataset_comment']
64 |
65 | def get_dataset_size(self):
66 | return self.data['dataset_size']
67 |
68 | def get_prompt_cache(self):
69 | return bool(self.data['prompt_cache'])
70 |
--------------------------------------------------------------------------------
/benchmark/profiles/document_summary_simple.json:
--------------------------------------------------------------------------------
1 | {
2 | "benchmark_name": "Document summary simple",
3 | "benchmark_description": "Document summary benchmark with a few entries from multi news dataset",
4 | "benchmark_dataset_name": "document_summary",
5 | "benchmark_dataset_comment": "This is the path name of the benchmark dataset module; always under benchmark/",
6 | "dataset_size": "small",
7 | "prompt_cache": true
8 | }
--------------------------------------------------------------------------------
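
A short usage sketch for this profile with `BenchmarkProfileParser` from `benchmark/profile_parser.py`; note that `parse()` additionally expects a matching `benchmark/<benchmark_dataset_name>/setup.json` to exist:

```python
# Illustrative usage; parse() also loads benchmark/document_summary/setup.json
# through BenchmarkSetupParser, so that file must be present.
from benchmark.profile_parser import BenchmarkProfileParser

parser = BenchmarkProfileParser("benchmark/profiles/document_summary_simple.json")
parser.parse()
print(parser.get_benchmark_name())   # "Document summary simple"
print(parser.get_dataset_size())     # "small"
print(parser.get_prompt_cache())     # True
```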
/benchmark/results/README.md:
--------------------------------------------------------------------------------
1 | Benchmark results will be stored here.
--------------------------------------------------------------------------------
/benchmark/schema/README.md:
--------------------------------------------------------------------------------
1 | Place the generated schema files here. Please create a folder for your dataset, e.g. squad_v2, multi_news, wiki_qa, pubmed_qa, or ms_marco.
--------------------------------------------------------------------------------
/benchmark/schema/test/empty.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | You are a helpful AI assistant.
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/benchmark/schema/test/prompt_mbti.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/benchmark/schema/test/schema_code_generation.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
13 |
14 | Just some text with parameter:
15 |
16 |
17 |
18 | Nested
19 |
20 |
26 |
27 |
28 |
32 |
33 | System prompt type 1.
34 |
35 |
39 | System prompt type 2.
40 | System prompt type 3,
41 | with parameter:
42 |
43 |
44 |
45 |
46 |
47 | User 1 information
48 | User 2 information
49 |
50 |
51 |
52 |
53 |
54 |
55 | Task description 1
56 |
57 |
58 | Task description 1
59 | Task description 1
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/benchmark/schema/test/schema_long_task_1.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
13 |
14 | Just some text with parameter:
15 |
16 |
17 |
18 | Nested
19 |
20 |
26 |
27 |
28 |
32 |
33 | System prompt type 1.
34 |
35 |
39 | System prompt type 2.
40 | System prompt type 3,
41 | with parameter:
42 |
43 |
44 |
45 |
46 |
47 | User 1 information
48 | User 2 information
49 |
50 |
51 |
52 |
53 |
54 |
55 | Task description 1
56 |
57 |
58 | Task description 1
59 | Task description 1
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/benchmark/schema/test/schema_mbti.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite,
4 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with
5 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false
6 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the
7 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being
8 | useful.
9 |
10 |
11 |
12 | For the upcoming interaction, I would like you to create a hypothetical character based on a specific MBTI
13 | personality type. This personality type will include four indicators that shape the character's responses,
14 | demeanor, and overall approach to life.
15 |
16 |
17 |
18 | The Extroversion-Introversion indicator describes how a character responds to social situations and how
19 | they recharge their energy. This can be either Extroversion (E) or Introversion (I).
20 |
21 |
22 | This character has Extroversion (E) trait, and does not have Introversion (I) trait.
23 | Extroversion is characterized by a preference to focus on the world outside
24 | the self. Extraverts are energized by social gatherings, engaging in activities with others, and being
25 | expressive and assertive. They enjoy meeting new people, have a wide circle of friends, and are often
26 | perceived as popular and sociable. They are often more comfortable in groups, thrive in social
27 | situations, and prefer engaging with the external world.
28 |
29 |
30 | This character has Introversion (I) trait, and does not have Extroversion (E) trait.
31 | Introversion is characterized by a preference to focus on the internal world
32 | of thoughts and feelings. Introverts gain energy from spending time alone and prefer to interact with a
33 | small group of close friends. They are often perceived as reserved or reflective. They often prefer
34 | solitary activities or spending time with one or two close friends rather than in large social
35 | gatherings. They often prefer engaging with their internal world of thoughts and ideas.
36 |
37 |
38 |
39 |
40 |
41 |
42 | The Sensing-Intuition indicator describes how a character processes information and perceives the world
43 | around them. This can be either Sensing (S) or Intuition (N).
44 |
45 |
46 | This character has Sensing (S) trait, and does not have Intuition (N) trait.
47 | Sensing is characterized by a preference to focus on the present and on concrete information gained
48 | from the senses. Sensing types are often practical and realistic, and they prefer routine and order.
49 | They are often detail-oriented and observant and rely on their five senses to interpret the world. They
50 | prefer concrete, factual information rather than abstract concepts and theories. They are often
51 | pragmatic and grounded in reality.
52 |
53 |
54 | This character has Intuition (N) trait, and does not have Sensing (S) trait.
55 | Intuition is characterized by a preference to focus on the future and on possibilities. Intuitive
56 | types are often imaginative and creative, and they prefer new experiences and challenges. They are often
57 | more comfortable with theories and abstract concepts and enjoy discussing possibilities and what could
58 | be. They often rely on their intuition and are more interested in the big picture rather than the
59 | details.
60 |
61 |
62 |
63 |
64 |
65 | The Thinking-Feeling indicator describes how a character makes decisions and evaluates situations. This
66 | can be either Thinking (T) or Feeling (F).
67 |
68 |
69 | This character has Thinking (T) trait, and does not have Feeling (F) trait.
70 | Thinking is characterized by a preference to make decisions based on logic and objective analysis.
71 | Thinking types often prioritize fairness and efficiency in their decisions and may sometimes overlook
72 | the impact on people. They often approach problems and decisions logically and objectively, and they
73 | value truth and justice over harmony and cooperation. They often prefer to evaluate situations
74 | objectively and make decisions based on logic rather than emotion.
75 |
76 |
77 | This character has Feeling (F) trait, and does not have Thinking (T) trait.
78 | Feeling is characterized by a preference to make decisions based on personal values and the impact
79 | on others. Feeling types often prioritize harmony and empathy in their interactions and may sometimes
80 | overlook logical implications. They often approach problems and decisions with a people-centered
81 | perspective, and they value compassion and cooperation over efficiency and fairness. They often prefer
82 | to evaluate situations subjectively and make decisions based on personal values and beliefs.
83 |
84 |
85 |
86 |
87 |
88 | The Judging-Perceiving indicator describes how a character approaches life, either with a structured and
89 | planned approach or a flexible and adaptable one. This can be either Judging (J) or Perceiving (P).
90 |
91 |
92 | This character has Judging (J) trait, and does not have Perceiving (P) trait.
93 | Judging is characterized by a preference for structure and organization. Judging types often like
94 | to make plans and stick to them, and they prefer clarity and closure in their decisions. They are often
95 | decisive and organized, and they value predictability and stability. They often prefer to have a plan
96 | and follow it rather than being spontaneous and flexible.
97 |
98 |
99 | This character has Perceiving (P) trait, and does not have Judging (J) trait.
100 | Perceiving is characterized by a preference for flexibility and spontaneity. Perceiving types often
101 | like to keep their options open and adapt to new information as it comes in. They are often flexible and
102 | adaptable, and they value spontaneity and freedom. They often prefer to go with the flow and adapt to
103 | changes rather than sticking to a plan.
104 |
105 |
106 |
107 | Once you have created this character, I will ask questions about 'him/her' and you will respond based on the
108 | persona you have created. Remember to maintain the consistency of the character's traits throughout the
109 | conversation. Are you ready to create the character?
110 |
111 |
112 |
113 | Great, thank you for providing the MBTI personality type. Based on the provided personality traits, I have created
114 | a hypothetical character named Alex. Now feel free to ask any questions about him, and I will respond based on
115 | his persona.
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/benchmark/schema/test/schema_mbti_short.xml:
--------------------------------------------------------------------------------
1 |
2 | Dialogues between people and an AI with the aim of being helpful, aware, and accurate.
3 |
4 | Create a character based on MBTI. This character have:
5 |
6 |
7 | E (focuses on the outside world and is energized by social interactions.)
8 | I (prefers solitude and is energized by internal thoughts.)
9 |
10 |
11 |
12 | S (practical, detail-oriented, prefers concrete information.)
13 | N (imaginative, creative, focused on possibilities.)
14 |
15 |
16 |
17 | T (makes decisions based on logic.)
18 | F (values personal impact and harmony.)
19 |
20 |
21 |
22 | J (prefers structure and plans.)
23 | P (flexible and spontaneous.)
24 |
25 | Ready to create the character?
26 |
27 | A character named Alex is created based on MBTI. Ask anything about him.
28 |
29 |
--------------------------------------------------------------------------------
/benchmark/schema/test/schema_persona.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite,
5 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with
6 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false
7 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the
8 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being
9 | useful.
10 |
11 |
12 |
13 | For the upcoming interaction, I would like you to create a hypothetical character based on a
14 | combination of traits: age, residence, education, occupation, marital-status, and personality. This will shape
15 | the character's responses, demeanor, and overall approach to life.
16 |
17 |
18 |
19 |
20 |
21 | This person is between 0 and 12 years old. This age group represents children from birth
22 | to pre-adolescence. They are typically dependent on parents or guardians and are in the early stages
23 | of
24 | physical and mental development.
25 |
26 |
27 | This person is between 13 and 19 years old. This age group represents adolescents who are
28 | experiencing physical changes and emotional development. They are often in secondary school and
29 | beginning to gain some independence.
30 |
31 |
32 | This person is between 20 and 34 years old. This age group represents adults who are
33 | often completing their education, starting their careers, and may be living independently for the
34 | first
35 | time. They are exploring relationships and may be starting families.
36 |
37 |
38 | This person is between 35 and 54 years old. This age group represents adults who are
39 | often established in their careers and may have growing families. They may be experiencing changes
40 | in
41 | their physical health and may be taking care of aging parents.
42 |
43 |
44 | This person is 55+ years old. This age group represents older adults who may be
45 | retiring or already retired. They may have grown children and may be experiencing health challenges
46 | associated with aging.
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | This person lives in a city with a dense population, often with access to many amenities, public
56 | transportation, and cultural attractions. However, it may also come with challenges such as noise,
57 | pollution, and a higher cost of living.
58 |
59 |
60 | This person lives in the suburbs, areas that are often residential and located just outside of a
61 | city. Suburbs often offer a quieter environment, more green space, and may be more family-friendly.
62 |
63 |
64 | This person lives in the countryside, often in a less populated area with open spaces and natural
65 | surroundings. It may offer a slower pace of life but may also have fewer amenities and services
66 | available.
67 |
68 |
69 | This person lives in an area near the sea, often with access to beaches and water activities. It
70 | may offer a relaxed lifestyle and scenic views but may also come with challenges such as extreme
71 | weather
72 | conditions.
73 |
74 |
75 | This person lives in a mountainous area, often with access to outdoor activities such as hiking
76 | and skiing. It may offer a peaceful and natural environment but may also have challenges such as
77 | harsh
78 | winters and limited access to services.
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 | This person had no formal education, which may limit job opportunities and earning potential. It
88 | may also affect one's ability to read and write or to access information and services.
89 |
90 |
91 | This person had completed high school, which is often the minimum requirement for many jobs. It
92 | indicates a basic level of education and the ability to read and write.
93 |
94 |
95 | This person had completed an undergraduate degree, which may open up more job opportunities and
96 | lead to higher earning potential. It indicates a higher level of education and specialized knowledge
97 | in
98 | a particular field.
99 |
100 |
101 | This person had completed a graduate degree, which may lead to specialized job opportunities and
102 | higher earning potential. It indicates an advanced level of education and expertise in a particular
103 | field.
104 |
105 |
106 | This person had completed a doctorate degree, which is often required for academic or research
107 | positions. It indicates the highest level of education and expertise in a particular field.
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 | This person works in the medical field, which may include roles such as doctor, nurse, or medical
117 | technician. It often involves providing care for others and may require specialized training and
118 | certifications.
119 |
120 |
121 | This person works in the education field, which may include roles such as teacher, administrator,
122 | or counselor. It often involves working with children or young adults and may require specialized
123 | training and certifications.
124 |
125 |
126 | This person works in the technology field, which may include roles such as software developer, IT
127 | specialist, or network administrator. It often involves working with computers and may require
128 | specialized training and certifications.
129 |
130 |
131 | This person works in the arts and entertainment field, which may include roles such as artist,
132 | musician, or actor. It often involves creative expression and may require specialized training and
133 | talent.
134 |
135 |
136 | This person works in the finance and business field, which may include roles such as accountant,
137 | financial advisor, or business manager. It often involves managing money and may require specialized
138 | training and certifications.
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 | This person has never been married, which may involve living independently or with others who
148 | are not a spouse. It may also involve focusing on personal goals and priorities.
149 |
150 |
151 | This person is currently married, which may involve sharing responsibilities and making
152 | decisions together with a spouse. It may also involve raising children together.
153 |
154 |
155 | This person had been previously married but now being single, which may involve adjusting to a
156 | new way of life and may involve co-parenting children with a former spouse.
157 |
158 |
159 | This person has lost a spouse, which may involve grieving and adjusting to a new way of
160 | life. It may also involve managing responsibilities alone that were once shared.
161 |
162 |
163 | This person is in a relationship but not married, which may involve sharing some
164 | responsibilities and making decisions together with a partner. It may also involve negotiating
165 | boundaries and expectations.
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 | This person has an extroverted personality, which often involves enjoying social interactions
175 | and feeling energized by being around others. Extroverts often seek out social situations and enjoy
176 | meeting new people.
177 |
178 |
179 | This person has an introverted personality, which often involves preferring solitude or small group
180 | interactions. Introverts often feel drained by social interactions and need time alone to recharge.
181 |
182 |
183 | This person has a sensing personality, which often involves focusing on the present and relying
184 | on concrete, tangible information. Sensing types often prefer practical, realistic solutions.
185 |
186 |
187 | This person has an intuitive personality, which often involves focusing on the future and
188 | relying on abstract, theoretical information. Intuitive types often prefer creative, imaginative
189 | solutions.
190 |
191 |
192 | This person has a feeling personality, which often involves making decisions based on personal
193 | values and the impact on others. Feeling types often prioritize harmony and empathy in their
194 | interactions.
195 |
196 |
197 |
198 |
199 | Once you have created this character, I will ask questions about 'him/her' and you will respond based on the
200 | persona you have created. Remember to maintain the consistency of the character's traits throughout the
201 | conversation. Are you ready to create the character?
202 |
203 |
204 |
205 | Great, thank you for providing the information about the character. Based on the provided traits, I have
206 | created a hypothetical character named Alex. Now feel free to ask any questions about him, and I will respond
207 | based on his persona.
208 |
209 |
210 |
211 |
212 |
--------------------------------------------------------------------------------
/benchmark/squad_v2.py:
--------------------------------------------------------------------------------
1 | from .benchmark_base import Benchmark, Entry
2 | from .utils import XMLSchemaBuilder
3 | from datasets import load_dataset
4 | import os
5 |
6 | _document_schema_name = "document_summary"
7 | _document_header = "Document"
8 | _document_system_description = "Below are a series of dialogues between various people and an AI assistant. The AI tries to be helpful, polite, \
9 | honest, sophisticated, emotionally aware, and humble-but-knowledgeable. The assistant is happy to help with \
10 | almost anything, and will do its best to understand exactly what is needed. It also tries to avoid giving false \
11 | or misleading information, and it caveats when it isn't entirely sure about the right answer. That said, the \
12 | assistant is practical and really does its best, and doesn't let caution get too much in the way of being \
13 | useful."
14 | _document_user_description = "For the upcoming interaction, I would like you to answer some questions about the document."
15 | _document_assistant_description = "Sure. I have read the document. Please give me any question."
16 |
17 |
18 | class SquadV2(Benchmark):
19 | def __init__(self):
20 | super().__init__("squad_v2")
21 |
22 | def init(self, limit_entries=None):
23 | """
24 | Download (one time) and load the dataset to run;
25 | Preprocess the dataset to be organized in the `Entry` format.
26 | """
27 | self.dataset = load_dataset(self.dataset_name)
28 | count = 0
29 | # for split in self.dataset.values():
30 | # only use validation set
31 | for item in self.dataset['validation']:
32 | if limit_entries is not None and count >= limit_entries:
33 | break
34 | id = item["id"]
35 | schema_name = f"schema_{id}"
36 | builder = XMLSchemaBuilder(schema_name)
37 | context = item["context"]
38 | # title = item["title"]
39 | question = item["question"]
40 | answer = item["answers"]["text"]
41 | builder.set_system_description(_document_system_description)
42 | builder.set_user_description(_document_user_description)
43 | builder.add_document_module("context", self.dataset_prompt["context"].format(context=context))
44 | builder.set_assistant_description(_document_assistant_description)
45 |
46 | schema_file_name = f"{schema_name}.xml"
47 | with open(os.path.join(self.schema_path, schema_file_name), "w") as f:
48 | f.write(builder.generate_xml())
49 |
50 | # skip entry without ground truth
51 | if len(answer) == 0:
52 | continue
53 |
54 | prompt = f"""
55 |
56 |
57 | {self.dataset_prompt["question"].format(input=question)}
58 | """
59 | self.entries.append(Entry(schema_file_name, prompt, answer))
60 |
61 | count += 1
62 |
63 | if __name__ == '__main__':
64 | sq = SquadV2()
65 | sq.init()
66 | print(sq.get_entry_count())
67 |
--------------------------------------------------------------------------------
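A minimal driver sketch for the SquadV2 benchmark above, assuming the Benchmark base class supplies dataset_name, dataset_prompt, schema_path, and the entries list that init() relies on:

from benchmark.squad_v2 import SquadV2

sq = SquadV2()
sq.init(limit_entries=10)   # only materialize the first 10 validation items
entry = sq.entries[0]
print(entry.schema)         # e.g. "schema_<id>.xml", written under the benchmark's schema path
print(entry.prompt)         # user turn that asks the question against the cached context module
print(entry.answer)         # list of ground-truth answer strings from SQuAD v2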
/benchmark/utils.py:
--------------------------------------------------------------------------------
1 | from xml.etree.ElementTree import Element, SubElement, tostring, ElementTree
2 | from xml.dom import minidom
3 |
4 |
5 | class XMLSchemaBuilder:
6 | def __init__(self, schema_name):
7 | self.schema = Element('schema', name=schema_name)
8 | self.user = None
9 | self.user_union = None
10 |
11 | def set_system_description(self, description):
12 | system = SubElement(self.schema, 'system')
13 | system.text = description
14 |
15 | def set_user_description(self, description, user_union=False, scaffold_name="DOC"):
16 | self.user = SubElement(self.schema, 'user')
17 | self.user.text = description
18 | if user_union:
19 | self.user_union = SubElement(self.user, 'union', scaffold=scaffold_name)
20 |
21 | def add_document_module_union(self, module_name, content):
22 | assert self.user_union is not None
23 | module = SubElement(self.user_union, 'module', name=module_name)
24 |         module.text = content.replace('\n', '\\n')  # keep literal newlines out of the schema text
25 |
26 | def add_document_module(self, module_name, content):
27 | module = SubElement(self.user, 'module', name=module_name)
28 |         module.text = content.replace('\n', '\\n')  # keep literal newlines out of the schema text
29 |
30 | def set_assistant_description(self, description):
31 | assistant = SubElement(self.schema, 'assistant')
32 | assistant.text = description
33 |
34 | def generate_xml(self):
35 | rough_string = tostring(self.schema, 'utf-8')
36 | reparsed = minidom.parseString(rough_string)
37 | prettystr = reparsed.toprettyxml(indent="\t")
38 | return prettystr.replace('"', "'")
39 |
--------------------------------------------------------------------------------
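A usage sketch for XMLSchemaBuilder in isolation (the schema, module, and scaffold names here are made up for illustration); generate_xml() pretty-prints the schema and rewrites double quotes to single quotes:

from benchmark.utils import XMLSchemaBuilder

builder = XMLSchemaBuilder("schema_example")
builder.set_system_description("You are a helpful assistant.")
builder.set_user_description("Answer questions about the documents.",
                             user_union=True, scaffold_name="DOC")
builder.add_document_module_union("doc0", "First document text...")
builder.set_assistant_description("Sure. I have read the documents.")
print(builder.generate_xml())   # pretty-printed, single-quoted XML schema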
/benchmark_memcpy.py:
--------------------------------------------------------------------------------
1 | # On RTX 4090
2 | # Host-to-Host (CPU to CPU) Average Latency: 3.79 milliseconds
3 | # Host-to-Device (CPU to GPU) Average Latency: 5.34 milliseconds
4 | # Device-to-Device (GPU to GPU) Average Latency: 0.23 milliseconds
5 | # Device-to-Host (GPU to CPU) Average Latency: 5.88 milliseconds
6 |
7 |
8 | import torch
9 | import time
10 |
11 | NUM_LAYERS = 30
12 | SEQ_LEN = 5000
13 | CACHE_DIM = (40, SEQ_LEN, 128)
14 |
15 | print('loaded')
16 |
17 |
18 | def create_cache(device):
19 | return [(torch.rand(CACHE_DIM, dtype=torch.float16, device=device),
20 | torch.rand(CACHE_DIM, dtype=torch.float16, device=device)) for _ in
21 | range(NUM_LAYERS)]
22 |
23 |
24 | def benchmark_transfer(src_cache, dst_cache, description):
25 | start_time = time.time()
26 | for src, dst in zip(src_cache, dst_cache):
27 | dst[0].copy_(src[0], non_blocking=True)
28 |         dst[1].copy_(src[1], non_blocking=True)
29 | torch.cuda.synchronize() # Ensure CUDA operations are synchronized
30 | elapsed = (time.time() - start_time) / NUM_LAYERS
31 | print(f"{description} Average Latency: {elapsed * 1000:.2f} milliseconds")
32 |
33 |
34 | cpu_cache = create_cache('cpu')
35 | gpu_cache = create_cache('cuda')
36 | cpu_cache_clone = create_cache('cpu')
37 | gpu_cache_clone = create_cache('cuda')
38 |
39 | # Host-to-Host (CPU to CPU) Transfer
40 | benchmark_transfer(cpu_cache, cpu_cache_clone, "Host-to-Host (CPU to CPU)")
41 |
42 | # Host-to-Device (CPU to GPU) Transfer
43 | benchmark_transfer(cpu_cache, gpu_cache_clone, "Host-to-Device (CPU to GPU)")
44 |
45 | # Device-to-Device (GPU to GPU) Transfer
46 | benchmark_transfer(gpu_cache, gpu_cache_clone, "Device-to-Device (GPU to GPU)")
47 |
48 | # Device-to-Host (GPU to CPU) Transfer
49 | benchmark_transfer(gpu_cache, cpu_cache_clone, "Device-to-Host (GPU to CPU)")
50 |
--------------------------------------------------------------------------------
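One variation worth noting for the transfer benchmark above: non_blocking=True can only overlap host-to-device copies when the host tensors are page-locked, so a pinned variant of create_cache (a sketch, not in the script) would look like:

def create_pinned_cache():
    # pinned (page-locked) host buffers let the async copies actually overlap with other work
    return [(torch.rand(CACHE_DIM, dtype=torch.float16).pin_memory(),
             torch.rand(CACHE_DIM, dtype=torch.float16).pin_memory())
            for _ in range(NUM_LAYERS)]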
/config/dataset_maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": 128,
3 | "qasper": 128,
4 | "multifieldqa_en": 64,
5 | "multifieldqa_zh": 64,
6 | "hotpotqa": 32,
7 | "2wikimqa": 32,
8 | "musique": 32,
9 | "dureader": 128,
10 | "gov_report": 512,
11 | "qmsum": 512,
12 | "vcsum": 512,
13 | "trec": 64,
14 | "triviaqa": 32,
15 | "samsum": 128,
16 | "lsht": 64,
17 | "passage_count": 32,
18 | "passage_retrieval_en": 32,
19 | "passage_retrieval_zh": 32,
20 | "lcc": 64,
21 | "repobench-p": 64,
22 | "multi_news": 512,
23 | "multi_news_long": 512,
24 | "squad_v2": 32,
25 | "wiki_qa": 128,
26 | "pubmd_qa": 128,
27 | "ms_marco": 32
28 | }
--------------------------------------------------------------------------------
/config/dataset_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": {
3 | "context": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question as concisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\n",
4 | "question": "Now, answer the question based on the story as concisely as you can, using a single phrase if possible. Do not provide any explanation. Question: {input}\n\n",
5 | "answer": "The concise answer is:"
6 | },
7 | "qasper": {
8 | "context": "You are given a scientific article and a question. Article: {context}",
9 | "question": "Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\n Question: {input}\n\n",
10 | "answer": "The concise answer is:"
11 | },
12 | "multifieldqa_en": {
13 | "context": "Read the following text and answer briefly.\n\n{context}",
14 | "question": "Question: {input}\n",
15 | "answer": "The concise answer is:"
16 | },
17 | "multifieldqa_zh": {
18 | "context": "阅读以下文字并用中文简短回答:\n\n{context}",
19 | "question": "问题:{input}\n",
20 | "answer": "回答:"
21 | },
22 | "hotpotqa": {
23 | "context": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}",
24 | "question": "Question: {input}\n",
25 | "answer": "The concise answer is:"
26 | },
27 | "2wikimqa": {
28 | "context": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}",
29 | "question": "Question: {input}\n",
30 | "answer": "The concise answer is:"
31 | },
32 | "musique": {
33 | "context": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}",
34 | "question": "Question: {input}\nAnswer:",
35 | "answer": "The concise answer is:"
36 | },
37 | "dureader": {
38 | "context": "请基于给定的文章回答下述问题。\n\n文章:{context}",
39 | "question": "问题:{input}\n",
40 | "answer": "回答:"
41 | },
42 | "gov_report": {
43 | "context": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}",
44 | "question": "",
45 | "answer": "Summary:"
46 | },
47 | "qmsum": {
48 | "context": "You are given a meeting transcript and a query containing a question or instruction.\n\nTranscript:\n{context}",
49 | "question": "Query: {input}\n",
50 | "answer": "The concise answer is:"
51 | },
52 | "vcsum": {
53 | "context": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}",
54 | "question": "",
55 | "answer": "会议总结:"
56 | },
57 | "trec": {
58 | "context": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}",
59 | "question": "{input}",
60 | "answer": "Type of the question:"
61 | },
62 | "triviaqa": {
63 | "context": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}",
64 | "question": "{input}",
65 | "answer": "The concise answer is:"
66 | },
67 | "samsum": {
68 | "context": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}",
69 | "question": "{input}",
70 | "answer": "Summary:"
71 | },
72 | "lsht": {
73 | "context": "请判断给定新闻的类别,下面是一些例子。\n\n{context}",
74 | "question": "{input}"
75 | },
76 | "passage_count": {
77 | "context": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}",
78 | "question": "",
79 | "answer": "The concise answer is: "
80 | },
81 | "passage_retrieval_en": {
82 | "context": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}",
83 | "question": "The following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\n",
84 | "answer": "The answer is: "
85 | },
86 | "passage_retrieval_zh": {
87 | "context": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}",
88 | "question": "下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n",
89 | "answer": "答案是:"
90 | },
91 | "lcc": {
92 | "context": "Please complete the code given below. \n{context}",
93 | "question": "",
94 | "answer": "Next line of code:\n"
95 | },
96 | "repobench-p": {
97 | "context": "Please complete the code given below. \n{context}",
98 | "question": "{input}",
99 | "answer": "Next line of code:\n"
100 | },
101 | "multi_news": {
102 | "context": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\n",
103 | "question": "",
104 | "answer": "Summary:"
105 | },
106 | "multi_news_long": {
107 | "context": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\n",
108 | "question": "",
109 | "answer": "Summary:"
110 | },
111 | "squad_v2": {
112 | "context": "Read the following text and answer briefly.\n\n{context}",
113 | "question": "Question: {input}\n\n",
114 | "answer": "The concise answer is:"
115 | },
116 | "wiki_qa": {
117 | "context": "Please complete the code given below. \n{context}",
118 | "question": "",
119 | "answer": "Next line of code:\n"
120 | },
121 | "pubmd_qa": {
122 | "context": "Please complete the code given below. \n{context}",
123 | "question": "",
124 | "answer": "Next line of code:\n"
125 | },
126 | "ms_marco": {
127 | "context": "",
128 | "question": "Among the list of given documents separated by a comma, find the document that is most useful to answer the query and return its index starting at [0]. All documents are unrelated, return [-1]. Do not reply using a complete sentence, and only give the answer in the following format: [1]:\n"
129 | }
130 | }
--------------------------------------------------------------------------------
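The templates above are plain str.format strings keyed by dataset name; a small sketch of how a benchmark might fill them in (the placeholder text is illustrative):

import json

with open("./config/dataset_prompt.json") as f:
    prompts = json.load(f)

context_prompt = prompts["squad_v2"]["context"].format(context="...document text...")
question_prompt = prompts["squad_v2"]["question"].format(input="Who wrote the report?")
answer_prefix = prompts["squad_v2"]["answer"]   # "The concise answer is:"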
/config/llm_config_falcon_40b.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "falcon",
3 | "log_name": "falcon-40b",
4 | "name": "tiiuae/falcon-40b-instruct",
5 | "load_in_8bit": true,
6 | "device_map": "auto",
7 | "max_tokens": 1500,
8 | "max_ctx_length": 3000
9 | }
--------------------------------------------------------------------------------
/config/llm_config_falcon_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "falcon",
3 | "log_name": "falcon-7b",
4 | "name": "tiiuae/falcon-7b-instruct",
5 | "load_in_8bit": true,
6 | "device_map": "auto",
7 | "max_tokens": 1500,
8 | "max_ctx_length": 3000
9 | }
--------------------------------------------------------------------------------
/config/llm_config_llama2_13b.json:
--------------------------------------------------------------------------------
1 | {
2 |
3 | "arch": "llama",
4 | "log_name": "llama2-13b",
5 | "name": "meta-llama/Llama-2-13b-chat-hf",
6 | "load_in_8bit": true,
7 | "device_map": "auto",
8 | "max_tokens": 3500,
9 | "max_ctx_length": 4096
10 |
11 | }
--------------------------------------------------------------------------------
/config/llm_config_llama2_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "llama",
3 | "log_name": "llama2-7b",
4 | "name": "meta-llama/Llama-2-7b-chat-hf",
5 | "load_in_8bit": true,
6 | "device_map": "auto",
7 | "max_tokens": 3500,
8 | "max_ctx_length": 4096
9 | }
--------------------------------------------------------------------------------
/config/llm_config_longchat_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "longchat",
3 | "log_name": "longchat-7b",
4 | "name": "lmsys/longchat-7b-v1.5-32k",
5 | "load_in_8bit": true,
6 | "device_map": "auto",
7 | "max_tokens": 8000,
8 | "max_ctx_length": 9186
9 | }
--------------------------------------------------------------------------------
/config/llm_config_mpt_30b.json:
--------------------------------------------------------------------------------
1 | {
2 |
3 | "arch": "mpt",
4 | "log_name": "mpt-30b",
5 | "name": "mosaicml/mpt-30b-chat",
6 | "load_in_8bit": true,
7 | "device_map": "auto",
8 | "max_tokens": 3500,
9 | "max_ctx_length": 4096
10 |
11 | }
--------------------------------------------------------------------------------
/config/llm_config_mpt_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "mpt",
3 | "log_name": "mpt-7b",
4 | "name": "mosaicml/mpt-7b-chat-8k",
5 | "load_in_8bit": true,
6 | "device_map": "auto",
7 | "max_tokens": 3500,
8 | "max_ctx_length": 4096
9 | }
--------------------------------------------------------------------------------
/config/llm_config_vicuna_7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "vicuna",
3 | "log_name": "vicuna-7b",
4 | "name": "lmsys/vicuna-7b-v1.5-16k",
5 | "load_in_8bit": true,
6 | "device_map": "auto",
7 | "max_tokens": 8000,
8 | "max_ctx_length": 9000
9 | }
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 |
4 | import numpy as np
5 | import torch.cuda
6 | import fire
7 |
8 | from promptcache.model import Llama2, Falcon, Mpt, CodeLlama
9 |
10 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
11 | GenerationEngine, GenerationParameters, llama2_template
12 |
13 |
14 | def escape_tags(input_str):
15 |     pattern = r'<(?P<content>.*?)>'
16 |
17 | def repl(match):
18 | return '(' + match.group("content").capitalize() + ')'
19 |
20 | return re.sub(pattern, repl, input_str)
21 |
22 |
23 | def main(enable_cache=True):
24 | enable_cpu_inference = False
25 | disable_prompt_cache = not enable_cache
26 |
27 | lm_for_cache = CodeLlama("codellama/CodeLlama-7b-Instruct-hf",
28 | load_in_8bit=True,
29 | device_map="auto")
30 |
31 | lm = lm_for_cache
32 |
33 | if enable_cpu_inference:
34 | lm = CodeLlama("codellama/CodeLlama-7b-Instruct-hf",
35 | load_in_8bit=False,
36 | device_map=None)
37 |
38 | # lm = Falcon("tiiuae/falcon-7b-instruct",
39 | # load_in_8bit=True if not disable_cuda else False,
40 | # device_map="auto" if not disable_cuda else None)
41 |
42 | # lm = Mpt("mosaicml/mpt-7b-chat-8k",
43 | # load_in_8bit=True if not disable_cuda else False,
44 | # device_map="auto" if not disable_cuda else None)
45 |
46 | cache_engine = CacheEngine(5000, lm_for_cache, target_device='cpu' if enable_cpu_inference else None)
47 | gen_engine = GenerationEngine(lm)
48 |
49 | preproc = [
50 | # CompactSpaces(),
51 | lm.get_formatter()
52 | ]
53 |
54 | cache_engine.add_schema(read_file("./examples/code_generation_game.xml", preproc), max_tokens=800)
55 |
56 | parameter = GenerationParameters(
57 | temperature=1.0,
58 | repetition_penalty=1.0,
59 | top_p=0.95,
60 | top_k=-1,
61 | max_new_tokens=512,
62 | stop_token_ids=lm.stop_token_ids,
63 | stop_str=lm.stop_str
64 | )
65 |
66 | prompt_text = f"""
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 | Create a main entry for the game:
75 |
76 |
77 | """
78 |
79 | prompt = Prompt(prompt_text, preproc)
80 | token_ids, position_ids, cache_time, cache = cache_engine.process(prompt, no_cache=disable_prompt_cache,
81 | return_full_position_ids=lm.use_full_position_ids)
82 |
83 | output_stream = gen_engine.generate(token_ids, position_ids, parameter, cache, stream_interval=2,
84 | use_full_position_ids=lm.use_full_position_ids)
85 |
86 |     print("Assistant: ", end="", flush=True)
87 |
88 | resp = ""
89 | pre = 0
90 | for outputs in output_stream:
91 | output_text = outputs.new_text.strip().split(" ")
92 | now = len(output_text) - 1
93 | if now > pre:
94 | tt = " ".join(output_text[pre:now])
95 | resp += tt + " "
96 | print(tt, end=" ", flush=True)
97 | pre = now
98 | tt = " ".join(output_text[pre:])
99 | print(tt, flush=True)
100 | resp += tt
101 |
102 | print("\n")
103 | prompt_text += f"{resp}"
104 |
105 |
106 | def seed_everything(seed):
107 | torch.manual_seed(seed)
108 | torch.cuda.manual_seed(seed)
109 | np.random.seed(seed)
110 | random.seed(seed)
111 | torch.backends.cudnn.benchmark = False
112 | torch.backends.cudnn.deterministic = True
113 | torch.cuda.manual_seed_all(seed)
114 |
115 |
116 | if __name__ == "__main__":
117 | seed_everything(42)
118 | fire.Fire(main)
119 |
--------------------------------------------------------------------------------
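The word-by-word streaming loop in demo.py treats outputs.new_text as the full decoded text generated so far and prints only the newly completed words; under that same assumption, a non-streaming helper (illustrative, not part of the repo) reduces to:

def collect_response(output_stream):
    # keep only the latest snapshot; new_text already contains everything decoded so far
    text = ""
    for outputs in output_stream:
        text = outputs.new_text
    return text.strip()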
/eval.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch.cuda
5 | import fire
6 | import sys, json
7 | import os
8 | import datetime
9 | from tqdm import tqdm
10 | from benchmark.longbench import LongBench
11 | from promptcache.model import Llama2, Falcon, Mpt
12 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
13 | GenerationEngine, GenerationParameters
14 |
15 | from benchmark.benchmark_base import DATASET_LIST, SCHEMA_FILE_DIRECTORY
16 | from benchmark.squad_v2 import SquadV2
17 | from benchmark.multi_news import MultiNews
18 | from benchmark.ms_marco_v1_1 import MSMarcoV1
19 |
20 | BENCHMARK_PATH = "./benchmark"
21 |
22 |
23 | class Eval:
24 | def __init__(self, llm_config_path, dataset, enable_cache, use_cpu_for_inference=False):
25 | with open("./config/dataset_maxlen.json", 'r') as f:
26 | self.dataset_maxlen = json.load(f)
27 |
28 | with open(llm_config_path, 'r') as f:
29 | self.llm_config = json.load(f)
30 | self.enable_cache = enable_cache
31 | self.use_cpu_for_inference = use_cpu_for_inference
32 |
33 | self.model_name = self.llm_config["name"]
34 | if "llama" in self.model_name:
35 | self.model_name = "llama"
36 | self.lm_for_caching = Llama2(name=self.llm_config['name'], device_map="auto", load_in_8bit=True)
37 | elif "falcon" in self.model_name:
38 | self.model_name = "falcon"
39 | self.lm_for_caching = Falcon(name=self.llm_config['name'], device_map="auto", load_in_8bit=True)
40 | elif "mpt" in self.model_name:
41 | self.model_name = "mpt"
42 | self.lm_for_caching = Mpt(name=self.llm_config['name'], device_map="auto", load_in_8bit=True)
43 | else:
44 | raise ValueError("Invalid model name")
45 |
46 | if self.use_cpu_for_inference:
47 | if "llama" in self.model_name:
48 | self.lm = Llama2(name=self.llm_config['name'], device_map=None)
49 | elif "falcon" in self.model_name:
50 | self.lm = Falcon(name=self.llm_config['name'], device_map=None)
51 | elif "mpt" in self.model_name:
52 | self.lm = Mpt(name=self.llm_config['name'], device_map=None)
53 | else:
54 | self.lm = self.lm_for_caching
55 |
56 | self.cache_engine = CacheEngine(self.llm_config.get("max_ctx_length", 4096), self.lm_for_caching,
57 | target_device=self.lm.device)
58 | self.gen_engine = GenerationEngine(self.lm)
59 | self.preproc = [
60 | # CompactSpaces(),
61 | self.lm.get_formatter()
62 | ]
63 |
64 | # self.parameter = GenerationParameters(
65 | # temperature=0.1,
66 | # repetition_penalty=1.17,
67 | # top_p=0.95,
68 | # top_k=-1,
69 | # max_new_tokens=512,
70 | # stop_token_ids=self.lm.stop_token_ids,
71 | # stop_str=self.lm.stop_str
72 | # )
73 |
74 | self.parameter = GenerationParameters(
75 | temperature=1.0,
76 | repetition_penalty=1.0,
77 | top_p=0.95,
78 | top_k=-1,
79 | max_new_tokens=self.dataset_maxlen[dataset],
80 | stop_token_ids=self.lm.stop_token_ids,
81 | stop_str=self.lm.stop_str
82 | )
83 |
84 | if dataset is None or dataset not in DATASET_LIST:
85 |             raise ValueError("Invalid or missing dataset name; valid dataset names are: " + ", ".join(DATASET_LIST))
86 |
87 | match dataset:
88 | case "squad_v2":
89 | self.dataset = SquadV2()
90 |
91 | case "multi_news":
92 | self.dataset = MultiNews()
93 |
94 | case "narrativeqa":
95 | self.dataset = LongBench("narrativeqa")
96 |
97 | case "qasper":
98 | self.dataset = LongBench("qasper")
99 |
100 | case "multifieldqa_en":
101 | self.dataset = LongBench("multifieldqa_en")
102 |
103 | case "hotpotqa":
104 | self.dataset = LongBench("hotpotqa")
105 |
106 | case "2wikimqa":
107 | self.dataset = LongBench("2wikimqa")
108 |
109 | case "musique":
110 | self.dataset = LongBench("musique")
111 |
112 | case "dureader":
113 | self.dataset = LongBench("dureader")
114 |
115 | case "gov_report":
116 | self.dataset = LongBench("gov_report")
117 |
118 | case "qmsum":
119 | self.dataset = LongBench("qmsum")
120 |
121 | case "multi_news_long":
122 | self.dataset = LongBench("multi_news")
123 |
124 | case "vcsum":
125 | self.dataset = LongBench("vcsum")
126 |
127 | case "trec":
128 | self.dataset = LongBench("trec")
129 |
130 | case "triviaqa":
131 | self.dataset = LongBench("triviaqa")
132 |
133 | case "samsum":
134 | self.dataset = LongBench("samsum")
135 |
136 | case "lsht":
137 | self.dataset = LongBench("lsht")
138 |
139 | case "passage_count":
140 | self.dataset = LongBench("passage_count")
141 |
142 | case "passage_retrieval_en":
143 | self.dataset = LongBench("passage_retrieval_en")
144 |
145 | case "lcc":
146 | self.dataset = LongBench("lcc")
147 |
148 | case "repobench-p":
149 | self.dataset = LongBench("repobench-p")
150 |
151 | # for testing purpose, limit the entries to a small number
152 | self.dataset.init()
153 |
154 | # create result directory
155 | self.result_directory = os.path.join(BENCHMARK_PATH, "results",
156 | f"{self.model_name}-{self.dataset.dataset_name}")
157 | if not os.path.exists(self.result_directory):
158 | os.makedirs(self.result_directory)
159 |
160 | self.result_file_suffix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
161 |
162 | def store_results(self, results, split):
163 | if self.enable_cache:
164 | prefix = "with_cache"
165 | else:
166 | prefix = "no_cache"
167 |
168 | with open(os.path.join(self.result_directory, f"{prefix}_split_{split[0]}_{split[1]}_time_{self.result_file_suffix}.json"), "a") as f:
169 | json.dump(results, f)
170 | f.write("\n")
171 |
172 | @torch.inference_mode()
173 | def run_latency_eval(self):
174 |
175 | for entry in self.dataset.entries:
176 |
177 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, self.dataset.dataset_name, entry.schema)
178 | print(schema_file_path)
179 |             # register the schema for this entry with the cache engine
180 |             self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), max_tokens=3500)
181 |
182 | prompt = Prompt(entry.prompt, self.preproc)
183 |
184 | no_cache = not self.enable_cache
185 |
186 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, no_cache=no_cache,
187 | return_full_position_ids=self.lm.use_full_position_ids)
188 |
189 | if no_cache:
190 | assert cache is None
191 |
192 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long)
193 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long)
194 | # print(len(position_ids[0]))
195 |
196 | # add redundant batch dim
197 | if cache is not None:
198 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache]
199 |
200 | start = torch.cuda.Event(enable_timing=True)
201 | end = torch.cuda.Event(enable_timing=True)
202 |
203 | start.record()
204 | out = self.lm(input_ids=input_ids,
205 | position_ids=position_ids,
206 | past_key_values=cache,
207 | use_cache=True)
208 | end.record()
209 | torch.cuda.synchronize()
210 | response_time = start.elapsed_time(end)
211 |
212 | result = {
213 | "cache_time": cache_time,
214 | "response_time": response_time,
215 | }
216 | print(result)
217 |             self.store_results(result, (0, 1))  # latency eval covers the whole dataset, so use the trivial split
218 |
219 | self.cache_engine.remove_all_schemas()
220 |
221 | def run(self, split, verbose=False):
222 | entry_count = self.dataset.get_entry_count()
223 | split_count = entry_count // split[1]
224 |
225 | start = split_count * split[0]
226 | end = split_count * (split[0] + 1)
227 | print(f"Running benchmark on {self.dataset.dataset_name}, start: {start}, end: {end}")
228 |
229 | for i in tqdm(range(start, end)):
230 | entries = self.dataset.get_query((i, i + 1))
231 | for entry in entries:
232 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, self.dataset.dataset_name, entry.schema)
233 | print(schema_file_path)
234 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc),
235 | batch_size=self.llm_config.get("schema_load_batch", 1),
236 | max_tokens=self.llm_config.get("max_tokens", 3500))
237 |
238 | for entry in entries:
239 | print(entry.prompt)
240 | prompt = Prompt(entry.prompt, self.preproc)
241 | no_cache = not self.enable_cache
242 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, no_cache=no_cache,
243 | return_full_position_ids=self.lm.use_full_position_ids)
244 | if no_cache:
245 | assert cache is None
246 | # for debugging
247 | if verbose:
248 | print("No caching; prompt:\n" + self.lm.decode(token_ids) + "\n")
249 |
250 | output_stream = self.gen_engine.generate(token_ids, position_ids, self.parameter, cache,
251 | stream_interval=2,
252 | use_full_position_ids=self.lm.use_full_position_ids)
253 |
254 | resp = ""
255 | pre = 0
256 | response_time = 0.0
257 | for outputs in output_stream:
258 | response_time = outputs.response_time
259 | output_text = outputs.new_text.strip().split(" ")
260 | now = len(output_text) - 1
261 | if now > pre:
262 | tt = " ".join(output_text[pre:now])
263 | resp += tt + " "
264 | print(tt, end=" ", flush=True)
265 | pre = now
266 |
267 | tt = " ".join(output_text[pre:])
268 | print(tt, flush=True)
269 | resp += tt
270 |
271 | result = {
272 | "cache_time": cache_time,
273 | "response_time": response_time,
274 | "answers": entry.answer,
275 | "response": resp
276 | }
277 | self.store_results(result, split)
278 | print("\n")
279 |
280 | self.cache_engine.remove_all_schemas()
281 |
282 |
283 | def seed_everything(seed):
284 | torch.manual_seed(seed)
285 | torch.cuda.manual_seed(seed)
286 | np.random.seed(seed)
287 | random.seed(seed)
288 | torch.backends.cudnn.benchmark = False
289 | torch.backends.cudnn.deterministic = True
290 | torch.cuda.manual_seed_all(seed)
291 |
292 |
293 | def main(llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"),
294 | dataset: str = "narrativeqa", enable_cache=False, cache_batch_size=1, split=(0, 1),
295 | test_latency=False,
296 | use_cpu_for_inference=False,
297 | verbose=False):
298 | seed_everything(42)
299 |
300 | eval = Eval(llm_config_path, dataset, enable_cache, use_cpu_for_inference)
301 |
302 | if test_latency:
303 | eval.run_latency_eval()
304 | else:
305 | eval.run(split, verbose)
306 |
307 | if __name__ == "__main__":
308 | fire.Fire(main)
309 |
--------------------------------------------------------------------------------
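Since eval.py exposes main() through fire.Fire, its keyword arguments map directly to command-line flags; an illustrative invocation (flag values are examples, not the only supported ones):

python3 eval.py \
    --llm_config_path=./config/llm_config_llama2_7b.json \
    --dataset=squad_v2 \
    --enable_cache=True \
    --split="(0, 1)" \
    --verbose=False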
/eval_acc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import numpy as np
5 | import torch.cuda
6 | import fire
7 | import sys, json
8 | import os
9 | import datetime
10 | from tqdm import tqdm
11 | from benchmark.longbench import LongBench
12 | from promptcache.model import Llama2, Falcon, Mpt
13 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
14 | GenerationEngine, GenerationParameters
15 |
16 | from benchmark.benchmark_base import DATASET_LIST, SCHEMA_FILE_DIRECTORY
17 | from benchmark.squad_v2 import SquadV2
18 | from benchmark.multi_news import MultiNews
19 | from promptcache.prompt import apply_preproc
20 | from metrics import (
21 | qa_f1_score,
22 | rouge_zh_score,
23 | qa_f1_zh_score,
24 | rouge_score,
25 | classification_score,
26 | retrieval_score,
27 | retrieval_zh_score,
28 | count_score,
29 | code_sim_score
30 | )
31 | from multiprocessing import cpu_count, Process, Queue
32 | from concurrent.futures import ProcessPoolExecutor
33 |
34 | dataset2metric = {
35 | "narrativeqa": qa_f1_score,
36 | "qasper": qa_f1_score,
37 | "multifieldqa_en": qa_f1_score,
38 | "multifieldqa_zh": qa_f1_zh_score,
39 | "hotpotqa": qa_f1_score,
40 | "2wikimqa": qa_f1_score,
41 | "musique": qa_f1_score,
42 | "dureader": rouge_zh_score,
43 | "gov_report": rouge_score,
44 | "qmsum": rouge_score,
45 | "multi_news": rouge_score,
46 | "vcsum": rouge_zh_score,
47 | "trec": classification_score,
48 | "triviaqa": qa_f1_score,
49 | "samsum": rouge_score,
50 | "lsht": classification_score,
51 | "passage_retrieval_en": retrieval_score,
52 | "passage_count": count_score,
53 | "passage_retrieval_zh": retrieval_zh_score,
54 | "lcc": code_sim_score,
55 | "repobench-p": code_sim_score,
56 | }
57 |
58 | BENCHMARK_PATH = "./benchmark"
59 |
60 |
61 | class Eval:
62 | def __init__(self, gpu_id, llm_config_path, dataset_list, enable_cache):
63 | with open("./config/dataset_maxlen.json", 'r') as f:
64 | self.dataset_maxlen = json.load(f)
65 |
66 | with open(llm_config_path, 'r') as f:
67 | self.llm_config = json.load(f)
68 |
69 | self.enable_cache = enable_cache
70 |
71 | self.model_name = self.llm_config["name"]
72 | self.model_arch = self.llm_config["arch"]
73 | self.model_log_name = self.llm_config["log_name"]
74 | self.max_ctx_length = self.llm_config.get("max_ctx_length", 4096)
75 | self.max_tokens = self.llm_config.get("max_tokens", 3500)
76 | self.dataset_list = dataset_list
77 |
78 | if self.model_arch == "llama":
79 | self.lm = Llama2(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True)
80 | elif self.model_arch == "falcon":
81 | self.lm = Falcon(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True)
82 | elif self.model_arch == "mpt":
83 | self.lm = Mpt(name=self.model_name, device_map={"": gpu_id}, load_in_8bit=True)
84 | else:
85 | raise ValueError("Invalid model name")
86 |
87 | self.cache_engine = CacheEngine(self.max_ctx_length, self.lm,
88 | target_device=self.lm.device)
89 | self.gen_engine = GenerationEngine(self.lm)
90 | self.preproc = [
91 | # CompactSpaces(),
92 | self.lm.get_formatter()
93 | ]
94 |
95 | # create result directory
96 | self.result_directory = os.path.join(BENCHMARK_PATH, "results_acc")
97 | if not os.path.exists(self.result_directory):
98 | os.makedirs(self.result_directory)
99 |
100 | def run(self):
101 |
102 | for dataset_name in self.dataset_list:
103 |
104 | dataset = self.dataset_list[dataset_name]
105 | dataset.init(limit_entries=3)
106 |
107 | results = []
108 |
109 | for entry in tqdm(dataset.entries):
110 | # print(entry.prompt)
111 |
112 | schema = apply_preproc(entry.schema, self.preproc)
113 | prompt = Prompt(entry.prompt, self.preproc)
114 |
115 | self.cache_engine.add_schema(schema, max_tokens=self.max_tokens)
116 |
117 | no_cache = not self.enable_cache
118 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt,
119 | no_cache=no_cache,
120 | return_full_position_ids=self.lm.use_full_position_ids)
121 | if no_cache:
122 | assert cache is None
123 |
124 | parameter = GenerationParameters(
125 | temperature=0.0,
126 | repetition_penalty=1.0,
127 | top_p=0.0,
128 | top_k=-1,
129 | max_new_tokens=self.dataset_maxlen[dataset_name],
130 | stop_token_ids=self.lm.stop_token_ids,
131 | stop_str=self.lm.stop_str
132 | )
133 |
134 | output_stream = self.gen_engine.generate(token_ids, position_ids, parameter, cache,
135 | stream_interval=2,
136 | use_full_position_ids=self.lm.use_full_position_ids)
137 |
138 | resp = ""
139 | pre = 0
140 | response_time = 0.0
141 | for outputs in output_stream:
142 | response_time = outputs.response_time
143 | output_text = outputs.new_text.strip().split(" ")
144 | now = len(output_text) - 1
145 | if now > pre:
146 | tt = " ".join(output_text[pre:now])
147 | resp += tt + " "
148 | # print(tt, end=" ", flush=True)
149 | pre = now
150 |
151 | tt = " ".join(output_text[pre:])
152 | # print(tt, flush=True)
153 | resp += tt
154 | # print("\n")
155 |
156 | result = {
157 | "cache_time": cache_time,
158 | "response_time": response_time,
159 | "answers": entry.answer,
160 | "response": resp
161 | }
162 | print(result)
163 |
164 | results.append(result)
165 | self.cache_engine.remove_all_schemas()
166 |
167 | total_score = 0
168 | metric_fn = dataset2metric[dataset_name]
169 | for result in results:
170 | response = result["response"]
171 | answers = result["answers"]
172 |
173 | score = 0.
174 | for answer in answers:
175 | score = max(score, metric_fn(response, answer))
176 |
177 | total_score += score
178 |
179 | total_score = total_score / len(results) * 100
180 | print(f"Total score: {total_score:.2f}")
181 |
182 | if self.enable_cache:
183 | prefix = "cache_enabled"
184 | else:
185 | prefix = "cache_disabled"
186 | filename = f"{self.model_log_name}-{dataset_name}-{prefix}.json"
187 |
188 | with open(os.path.join(self.result_directory, filename), "w") as f:
189 | json.dump({
190 | "model_name": self.model_name,
191 | "model_arch": self.model_arch,
192 | "dataset_name": dataset_name,
193 | "enable_cache": self.enable_cache,
194 | "total_score": total_score,
195 | "results": results
196 | }, f)
197 |
198 |
199 | def run_eval(gpu_id, llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"),
200 |              dataset_list=None,
201 |              enable_cache=True):
202 | seed_everything(42)
203 |
204 |     eval = Eval(gpu_id, llm_config_path, dataset_list, enable_cache)
205 | eval.run()
206 |
207 |
208 | def seed_everything(seed):
209 | torch.manual_seed(seed)
210 | torch.cuda.manual_seed(seed)
211 | np.random.seed(seed)
212 | random.seed(seed)
213 | torch.backends.cudnn.benchmark = False
214 | torch.backends.cudnn.deterministic = True
215 | torch.cuda.manual_seed_all(seed)
216 |
217 |
218 | def main(num_gpus=1, llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"),
219 | enable_cache=True,
220 | ):
221 | dataset_list = {
222 | "narrativeqa": LongBench("narrativeqa"),
223 | "qasper": LongBench("qasper"),
224 | "multifieldqa_en": LongBench("multifieldqa_en"),
225 | "hotpotqa": LongBench("hotpotqa"),
226 | "2wikimqa": LongBench("2wikimqa"),
227 | "musique": LongBench("musique"),
228 | "gov_report": LongBench("gov_report"),
229 | "qmsum": LongBench("qmsum"),
230 | "multi_news": LongBench("multi_news"),
231 | "trec": LongBench("trec"),
232 | "triviaqa": LongBench("triviaqa"),
233 | "samsum": LongBench("samsum"),
234 | "passage_count": LongBench("passage_count"),
235 | "passage_retrieval_en": LongBench("passage_retrieval_en"),
236 | "lcc": LongBench("lcc"),
237 | "repobench-p": LongBench("repobench-p")
238 | }
239 |
240 | dpg = int(math.ceil(len(dataset_list) / num_gpus))
241 |
242 | jobs_list = []
243 | nn = list(dataset_list.keys())
244 | for i in range(num_gpus):
245 | dataset_names = nn[i * dpg:(i + 1) * dpg]
246 | jobs = {}
247 | for dn in dataset_names:
248 | jobs[dn] = dataset_list[dn]
249 | jobs_list.append(jobs)
250 |
251 | processes = [
252 | Process(target=run_eval, args=(i, llm_config_path, jobs_list[i], enable_cache))
253 | for i in range(num_gpus)
254 | ]
255 |
256 | for p in processes:
257 | p.start()
258 |
259 | seed_everything(42)
260 |
261 | for p in processes:
262 | p.join()
263 |
264 |
265 | if __name__ == "__main__":
266 | fire.Fire(main)
267 |
--------------------------------------------------------------------------------
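eval_acc.py shards the LongBench datasets across GPUs with one process per GPU; the slurm script below launches it the same way, and a direct invocation mirroring main()'s parameters looks like:

python3 eval_acc.py \
    --num_gpus=4 \
    --llm_config_path=./config/llm_config_llama2_7b.json \
    --enable_cache=True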
/eval_acc.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --mem=128g
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --cpus-per-task=32 # <- match to OMP_NUM_THREADS
6 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
7 | #SBATCH --account=bccn-delta-gpu
8 | #SBATCH --job-name=eval_acc
9 | #SBATCH --time=10:00:00 # hh:mm:ss for the job
10 | ### GPU options ###
11 | #SBATCH --gpus-per-node=4
12 | ##SBATCH --gpu-bind=none # <- or closest
13 | #SBATCH --mail-user=in.gim@yale.edu
14 | #SBATCH --mail-type="BEGIN,END"
15 |
16 | module reset # drop modules and explicitly load the ones needed
17 | # (good job metadata and reproducibility)
18 | # $WORK and $SCRATCH are now set
19 | module load python # ... or any appropriate modules
20 | module list # job documentation and metadata
21 | echo "job is starting on `hostname`"
22 | srun python3 eval_acc.py --num_gpus=4
23 |
24 |
--------------------------------------------------------------------------------
/eval_sys.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import torch.cuda
4 | import fire
5 | import sys, json
6 | import os
7 | import datetime
8 | from tqdm import tqdm
9 | from benchmark.longbench import LongBench
10 | from promptcache.model import Llama2, Falcon, Mpt
11 | from promptcache import Prompt, CompactSpaces, read_file, CacheEngine, \
12 | GenerationEngine, GenerationParameters
13 |
14 | from benchmark.benchmark_base import SCHEMA_FILE_DIRECTORY
15 |
16 | BENCHMARK_PATH = "./benchmark"
17 | from torch.profiler import profile, record_function, ProfilerActivity
18 |
19 |
20 | class Eval:
21 | def __init__(self, memo, llm_config_path, use_cpu_for_inference=False):
22 | with open("./config/dataset_maxlen.json", 'r') as f:
23 | self.dataset_maxlen = json.load(f)
24 |
25 | with open(llm_config_path, 'r') as f:
26 | self.llm_config = json.load(f)
27 | self.memo = memo
28 | self.use_cpu_for_inference = use_cpu_for_inference
29 | self.repeat_times = 2 if use_cpu_for_inference else 3
30 | self.model_name = self.llm_config["name"]
31 | self.model_arch = self.llm_config["arch"]
32 | self.model_log_name = self.llm_config["log_name"]
33 | self.max_ctx_length = self.llm_config.get("max_ctx_length", 4096)
34 | self.max_tokens = self.llm_config.get("max_tokens", 3500)
35 |
36 | if self.model_arch == "llama":
37 | self.lm_for_caching = Llama2(name=self.model_name, device_map={"": 0}, load_in_8bit=True)
38 | elif self.model_arch == "falcon":
39 | self.lm_for_caching = Falcon(name=self.model_name, device_map={"": 0}, load_in_8bit=True)
40 | elif self.model_arch == "mpt":
41 | self.lm_for_caching = Mpt(name=self.model_name, device_map={"": 0}, load_in_8bit=True)
42 | else:
43 | raise ValueError("Invalid model name")
44 |
45 | if self.use_cpu_for_inference:
46 | if self.model_arch == "llama":
47 | self.lm = Llama2(name=self.model_name, device_map=None)
48 | elif self.model_arch == "falcon":
49 | self.lm = Falcon(name=self.model_name, device_map=None)
50 | elif self.model_arch == "mpt":
51 | self.lm = Mpt(name=self.model_name, device_map=None)
52 | else:
53 | self.lm = self.lm_for_caching
54 |
55 | self.cache_engine = CacheEngine(self.max_ctx_length, self.lm,
56 | target_device=self.lm.device)
57 | self.gen_engine = GenerationEngine(self.lm)
58 | self.preproc = [
59 | # CompactSpaces(),
60 | self.lm.get_formatter()
61 | ]
62 |
63 | self.dataset_list = {
64 | "narrativeqa": LongBench("narrativeqa"),
65 | "qasper": LongBench("qasper"),
66 | "multifieldqa_en": LongBench("multifieldqa_en"),
67 | "hotpotqa": LongBench("hotpotqa"),
68 | "2wikimqa": LongBench("2wikimqa"),
69 | "musique": LongBench("musique"),
70 | "gov_report": LongBench("gov_report"),
71 | "qmsum": LongBench("qmsum"),
72 | "multi_news": LongBench("multi_news"),
73 | "triviaqa": LongBench("triviaqa"),
74 | "samsum": LongBench("samsum"),
75 | "passage_count": LongBench("passage_count"),
76 | "passage_retrieval_en": LongBench("passage_retrieval_en"),
77 | }
78 |
79 | # @torch.inference_mode()
80 | # def profile_cpu_inference(self):
81 | #
82 | # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
83 | # with record_function("model_inference"):
84 | # model(inputs)
85 |
86 |     # recomputation overhead vs. memory transfer overhead
87 | @torch.inference_mode()
88 | def run_critical_point(self):
89 |
90 | def create_cache(seq_len):
91 |
92 | # # llama 2 13B
93 | num_layers = 40
94 | num_heads = 40
95 | head_dim = 128
96 |
97 | # # llama 2 7B
98 | # num_layers = 32
99 | # num_heads = 32
100 | # head_dim = 128
101 |
102 | return [(torch.rand((num_heads, seq_len, head_dim), dtype=torch.float16, device='cpu'),
103 | torch.rand((num_heads, seq_len, head_dim), dtype=torch.float16, device='cpu')) for _ in
104 | range(num_layers)]
105 |
106 | test_seq_len = [
107 | 1,
108 | 2,
109 | 4,
110 | 8,
111 | 16,
112 | 32,
113 | 64,
114 | 128,
115 | 256,
116 | 512,
117 | 512 + 128 * 1,
118 | 512 + 128 * 2,
119 | 512 + 128 * 3,
120 | 1024,
121 | 1024 + 256 * 1,
122 | 1024 + 256 * 2,
123 | 1024 + 256 * 3,
124 | 2048,
125 |             2048 + 512 * 1,
126 |             2048 + 512 * 2,
127 |             2048 + 512 * 3,
128 | 4096,
129 | # 4096 + 1024 * 1,
130 | # 4096 + 1024 * 2,
131 |
132 | ]
133 |
134 | results = []
135 |
136 | for seq_len in tqdm(test_seq_len):
137 | for _ in range(self.repeat_times):
138 | ## 1. compute gpu upload time
139 | kv_cache = create_cache(seq_len)
140 |
141 | torch.cuda.synchronize()
142 | start = torch.cuda.Event(enable_timing=True)
143 | end = torch.cuda.Event(enable_timing=True)
144 |
145 | start.record()
146 | # upload everything to GPU
147 | kv_cache_gpu = [
148 | (k[0].to('cuda', non_blocking=True, copy=True), k[1].to('cuda', non_blocking=True, copy=True))
149 | for k in kv_cache]
150 |
151 | end.record()
152 | torch.cuda.synchronize()
153 | gpu_upload_time = start.elapsed_time(end)
154 |
155 | del kv_cache_gpu, kv_cache
156 | gc.collect()
157 | torch.cuda.empty_cache()
158 |
159 | results.append({
160 | "seq_len": seq_len,
161 | "time": gpu_upload_time,
162 | })
163 |
164 | result_path = os.path.join(BENCHMARK_PATH, "results_latency")
165 |
166 | with open(os.path.join(result_path, f"{self.memo}-{self.model_log_name}-critical_point-upload.json"),
167 | "w") as f:
168 | json.dump(
169 | {
170 | 'model_name': self.model_name,
171 | 'results': results
172 | }, f)
173 |
174 | results = []
175 | ## 2. compute recomputation time
176 | for seq_len in tqdm(test_seq_len):
177 |
178 | for _ in range(self.repeat_times):
179 | token_ids = [100] * seq_len
180 | position_ids = list(range(seq_len))
181 |
182 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long)
183 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long)
184 |
185 | start = torch.cuda.Event(enable_timing=True)
186 | end = torch.cuda.Event(enable_timing=True)
187 |
188 | start.record()
189 | out = self.lm(input_ids=input_ids,
190 | position_ids=position_ids,
191 | past_key_values=None,
192 | use_cache=False)
193 |
194 | end.record()
195 | torch.cuda.synchronize()
196 | recomputation_time = start.elapsed_time(end)
197 |
198 | del out
199 | gc.collect()
200 | torch.cuda.empty_cache()
201 |
202 | results.append({
203 | "seq_len": seq_len,
204 | "time": recomputation_time
205 | })
206 |
207 | result_path = os.path.join(BENCHMARK_PATH, "results_latency")
208 |
209 | with open(os.path.join(result_path, f"{self.memo}-{self.model_log_name}-critical_point-recomputation.json"),
210 | "w") as f:
211 | json.dump(
212 | {
213 | 'model_name': self.model_log_name,
214 | 'results': results
215 | }, f)
216 |
217 | @torch.inference_mode()
218 | def run_critical_point22(self):
219 |
220 |
221 | test_seq_len = [
222 | 1,
223 | 2,
224 | 4,
225 | 8,
226 | 16,
227 | 32,
228 | 64,
229 | 128,
230 | 256,
231 | 512,
232 | 512 + 128 * 1,
233 | 512 + 128 * 2,
234 | 512 + 128 * 3,
235 | 1024,
236 | 1024 + 256 * 1,
237 | 1024 + 256 * 2,
238 | 1024 + 256 * 3,
239 | 2048,
240 |             2048 + 512 * 1,
241 |             2048 + 512 * 2,
242 |             2048 + 512 * 3,
243 | #4096,
244 | # 4096 + 1024 * 1,
245 | # 4096 + 1024 * 2,
246 |
247 | ]
248 |
249 | results = []
250 |
251 | for seq_len in tqdm(test_seq_len):
252 | for _ in range(self.repeat_times):
253 | ## 1. compute gpu upload time
254 | input_ids = torch.tensor([[100]], device=self.lm.device, dtype=torch.long)
255 | #position_ids = torch.tensor([[100]], device=self.lm.device, dtype=torch.long)
256 |
257 | device_cache = [
258 | (torch.empty(1, 32, seq_len, 128, device=self.lm.device, dtype=torch.half), # key
259 | torch.empty(1, 32, seq_len, 128, device=self.lm.device, dtype=torch.half)) for _ in
260 | range(32)]
261 |
262 | torch.cuda.synchronize()
263 | start = torch.cuda.Event(enable_timing=True)
264 | end = torch.cuda.Event(enable_timing=True)
265 |
266 | start.record()
267 | # upload everything to GPU
268 | out = self.lm(input_ids=input_ids,
269 | #position_ids=position_ids,
270 | past_key_values=device_cache,
271 | use_cache=True)
272 |
273 | end.record()
274 | torch.cuda.synchronize()
275 | gpu_upload_time = start.elapsed_time(end)
276 |
277 | del device_cache
278 | gc.collect()
279 | torch.cuda.empty_cache()
280 |
281 | results.append({
282 | "seq_len": seq_len,
283 | "time": gpu_upload_time,
284 | })
285 |
286 | result_path = os.path.join(BENCHMARK_PATH, "aaa")
287 |
288 | with open(os.path.join(result_path, f"{self.memo}-{self.model_log_name}-critical_point-upload.json"),
289 | "w") as f:
290 | json.dump(
291 | {
292 | 'model_name': self.model_name,
293 | 'results': results
294 | }, f)
295 |
296 | results = []
297 |
298 |
299 | @torch.inference_mode()
300 | def run_latency_eval(self, do_cache):
301 |
302 | for dataset_name in self.dataset_list:
303 |
304 | dataset = self.dataset_list[dataset_name]
305 | dataset.init(limit_entries=5)
306 |
307 | # create result directory
308 | device_used = "cpu" if self.use_cpu_for_inference else "gpu"
309 | cache_used = "cache" if do_cache else "no_cache"
310 | result_path = os.path.join(BENCHMARK_PATH, "results_latency")
311 | no_cache = not do_cache
312 |
313 | if not os.path.exists(result_path):
314 | os.makedirs(result_path)
315 |
316 | results = []
317 |
318 | for entry in tqdm(dataset.entries[:5]):
319 | for _ in range(self.repeat_times):
320 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, dataset_name, entry.schema)
321 |
322 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), no_cache=no_cache,
323 | max_tokens=3500)
324 |
325 | prompt = Prompt(entry.prompt, self.preproc)
326 |
327 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt, no_cache=no_cache,
328 | return_full_position_ids=self.lm.use_full_position_ids)
329 |
330 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long)
331 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long)
332 | # print(len(position_ids[0]))
333 |
334 | # add redundant batch dim
335 | if cache is not None:
336 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache]
337 |
338 | start = torch.cuda.Event(enable_timing=True)
339 | end = torch.cuda.Event(enable_timing=True)
340 |
341 | start.record()
342 |
343 | # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
344 | # with record_function("model_inference"):
345 | out = self.lm(input_ids=input_ids,
346 | position_ids=position_ids,
347 | past_key_values=cache,
348 | use_cache=True)
349 | end.record()
350 | torch.cuda.synchronize()
351 | response_time = start.elapsed_time(end)
352 |
353 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
354 |
355 | result = {
356 | "entry_schema": entry.schema,
357 | "cache_time": cache_time,
358 | "response_time": response_time,
359 | }
360 | # print(result)
361 | results.append(result)
362 |
363 | self.cache_engine.remove_all_schemas()
364 |
365 | with open(
366 | os.path.join(result_path,
367 | f"{self.memo}-{self.model_log_name}-{device_used}-{cache_used}-{dataset_name}.json"),
368 | "w") as f:
369 | json.dump(
370 | {
371 | 'model_name': self.model_log_name,
372 | 'device_used': device_used,
373 | 'cache_used': cache_used,
374 | 'dataset_name': dataset_name,
375 |
376 | 'results': results
377 | }, f)
378 | f.write("\n")
379 |
380 | @torch.inference_mode()
381 | def run_profile(self, do_cache):
382 | device_used = "cpu" if self.use_cpu_for_inference else "gpu"
383 | cache_used = "cache" if do_cache else "no_cache"
384 |
385 | for dataset_name in self.dataset_list:
386 |
387 | dataset = self.dataset_list[dataset_name]
388 | dataset.init(limit_entries=5)
389 |
390 | no_cache = not do_cache
391 |
392 | for entry in tqdm(dataset.entries[:5]):
393 | for _ in range(self.repeat_times):
394 | schema_file_path = os.path.join(SCHEMA_FILE_DIRECTORY, dataset_name, entry.schema)
395 |
396 | self.cache_engine.add_schema(read_file(schema_file_path, self.preproc), no_cache=no_cache,
397 | max_tokens=2500)
398 |
399 | prompt = Prompt(entry.prompt, self.preproc)
400 |
401 | token_ids, position_ids, cache_time, cache = self.cache_engine.process(prompt,
402 | no_cache=no_cache,
403 | return_full_position_ids=self.lm.use_full_position_ids)
404 |
405 | input_ids = torch.tensor([token_ids], device=self.lm.device, dtype=torch.long)
406 | position_ids = torch.tensor([position_ids], device=self.lm.device, dtype=torch.long)
407 | # print(len(position_ids[0]))
408 |
409 | # add redundant batch dim
410 | if cache is not None:
411 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache]
412 |
413 | with profile(activities=[ProfilerActivity.CUDA], with_stack=True,
414 | experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof:
415 | with record_function("model_inference"):
416 | out = self.lm(input_ids=input_ids,
417 | position_ids=position_ids,
418 | past_key_values=cache,
419 | use_cache=True)
420 |
421 | prof.export_stacks(f"./profile/{device_used}_{cache_used}_self_cuda_time_total.txt",
422 | "self_cuda_time_total")
423 | self.cache_engine.remove_all_schemas()
424 |
425 | return
426 |
427 |
428 | def main(memo: str = "13900k-cpu", llm_config_path: str = os.path.join('./', "config/llm_config_llama2_7b.json"),
429 | use_cpu_for_inference=True):
430 | eval = Eval(memo, llm_config_path, use_cpu_for_inference)
431 |
432 | # eval.run_latency_eval(False)
433 | # eval.run_latency_eval(True)
434 | #eval.run_profile(True)
435 | #eval.run_profile(False)
436 |
437 | eval.run_critical_point22()
438 |
439 |
440 | if __name__ == "__main__":
441 | fire.Fire(main)
442 |
--------------------------------------------------------------------------------
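As written, main() in eval_sys.py only runs run_critical_point22(); the latency and profiling sweeps are left commented out. The slurm scripts below invoke it with the same flags, e.g.:

python3 eval_sys.py \
    --memo=a100 \
    --llm_config_path=./config/llm_config_llama2_7b.json \
    --use_cpu_for_inference=False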
/eval_sys_a100-7b-cpu.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --exclusive
3 | #SBATCH --mem=0
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=64
6 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
7 | #SBATCH --account=bccn-delta-gpu
8 | #SBATCH --job-name=a100_7b_cpu
9 | #SBATCH --time=47:00:00 # hh:mm:ss for the job
10 | #SBATCH --constraint="scratch"
11 |
12 | ### GPU options ###
13 | #SBATCH --gpus-per-node=1
14 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology
15 | #SBATCH --mail-user=in.gim@yale.edu
16 | #SBATCH --mail-type="BEGIN,END"
17 |
18 | module reset # drop modules and explicitly load the ones needed
19 | # (good job metadata and reproducibility)
20 | # $WORK and $SCRATCH are now set
21 | module load python # ... or any appropriate modules
22 | module list # job documentation and metadata
23 | echo "job is starting on `hostname`"
24 |
25 | srun python3 eval_sys.py \
26 | --memo=a100 \
27 | --llm_config_path=./config/llm_config_llama2_7b.json \
28 | --use_cpu_for_inference=True
29 |
--------------------------------------------------------------------------------
/eval_sys_a100-7b-gpu.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --exclusive
3 | #SBATCH --mem=0
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=16
7 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
8 | #SBATCH --account=bccn-delta-gpu
9 | #SBATCH --job-name=a100_gpu
10 | #SBATCH --time=47:00:00 # hh:mm:ss for the job
11 | #SBATCH --constraint="scratch"
12 |
13 | ### GPU options ###
14 | #SBATCH --gpus-per-node=1
15 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology
16 | #SBATCH --mail-user=in.gim@yale.edu
17 | #SBATCH --mail-type="BEGIN,END"
18 |
19 | module reset # drop modules and explicitly load the ones needed
20 | # (good job metadata and reproducibility)
21 | # $WORK and $SCRATCH are now set
22 | module load python # ... or any appropriate modules
23 | module list # job documentation and metadata
24 | echo "job is starting on `hostname`"
25 | srun python3 eval_sys.py \
26 | --memo=a100 \
27 | --llm_config_path=./config/llm_config_llama2_7b.json \
28 | --use_cpu_for_inference=False
29 |
--------------------------------------------------------------------------------
/eval_sys_a40-7b-cpu.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --exclusive
3 | #SBATCH --mem=0
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=64
6 | #SBATCH --partition=gpuA40x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
7 | #SBATCH --account=bccn-delta-gpu
8 | #SBATCH --job-name=a40_7b_cpu
9 | #SBATCH --time=47:00:00 # hh:mm:ss for the job
10 | #SBATCH --constraint="scratch"
11 |
12 | ### GPU options ###
13 | #SBATCH --gpus-per-node=1
14 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology
15 | #SBATCH --mail-user=in.gim@yale.edu
16 | #SBATCH --mail-type="BEGIN,END"
17 |
18 | module reset # drop modules and explicitly load the ones needed
19 | # (good job metadata and reproducibility)
20 | # $WORK and $SCRATCH are now set
21 | module load python # ... or any appropriate modules
22 | module list # job documentation and metadata
23 | echo "job is starting on `hostname`"
24 |
25 | srun python3 eval_sys.py \
26 | --memo=a40 \
27 | --llm_config_path=./config/llm_config_llama2_7b.json \
28 | --use_cpu_for_inference=True
29 |
--------------------------------------------------------------------------------
/eval_sys_a40-7b-gpu.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --exclusive
3 | #SBATCH --mem=0
4 | #SBATCH --nodes=1
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=16
7 | #SBATCH --partition=gpuA40x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
8 | #SBATCH --account=bccn-delta-gpu
9 | #SBATCH --job-name=a40_7b_gpu
10 | #SBATCH --time=47:00:00 # hh:mm:ss for the job
11 | #SBATCH --constraint="scratch"
12 |
13 | ### GPU options ###
14 | #SBATCH --gpus-per-node=1
15 | #SBATCH --gpu-bind=closest # select a cpu close to gpu on pci bus topology
16 | #SBATCH --mail-user=in.gim@yale.edu
17 | #SBATCH --mail-type="BEGIN,END"
18 |
19 | module reset # drop modules and explicitly load the ones needed
20 | # (good job metadata and reproducibility)
21 | # $WORK and $SCRATCH are now set
22 | module load python # ... or any appropriate modules
23 | module list # job documentation and metadata
24 | echo "job is starting on `hostname`"
25 | srun python3 eval_sys.py \
26 | --memo=a40 \
27 | --llm_config_path=./config/llm_config_llama2_7b.json \
28 | --use_cpu_for_inference=False
29 |
--------------------------------------------------------------------------------
/examples/code_generation_bookstore.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | You are a sophisticated language model assistant that can read and understand multiple source files
5 | simultaneously. Your current task is to examine the provided source files. These files contain Python
6 | classes and methods.
7 |
8 | Using the knowledge extracted from the source files, you are expected to generate code following the
9 | instructions that will be given.
10 |
11 |
12 |
13 | Please read the given source files, understand their structure and relationships. I'll provide you with my
14 | instruction.
15 |
16 |
17 | class Book:
18 | def __init__(self, title, author, price, isbn):
19 | self.title = title
20 | self.author = author
21 | self.price = price
22 | self.isbn = isbn
23 |
24 | def get_details(self):
25 | return {
26 | "title": self.title,
27 | "author": self.author,
28 | "price": self.price,
29 | "isbn": self.isbn
30 | }
31 |
32 | def set_details(self, title, author, price, isbn):
33 | self.title = title
34 | self.author = author
35 | self.price = price
36 | self.isbn = isbn
37 |
38 |
39 |
40 | class User:
41 | def __init__(self, username, password):
42 | self.username = username
43 | self.password = password
44 | self.purchased_books = []
45 |
46 | def register(self, username, password):
47 | self.username = username
48 | self.password = password
49 |
50 | def login(self, username, password):
51 | return self.username == username and self.password == password
52 |
53 | def purchase(self, book):
54 | self.purchased_books.append(book)
55 |
56 | def view_purchased_books(self):
57 | return self.purchased_books
58 |
59 |
60 |
61 |
62 | class Inventory:
63 | def __init__(self):
64 | self.books = []
65 |
66 | def add_book(self, book):
67 | self.books.append(book)
68 |
69 | def remove_book(self, book):
70 | self.books.remove(book)
71 |
72 | def search_by_title(self, title):
73 | return [book for book in self.books if book.title == title]
74 |
75 | def search_by_author(self, author):
76 | return [book for book in self.books if book.author == author]
77 |
78 |
79 |
80 | from Book import Book
81 | from User import User
82 | from Inventory import Inventory
83 |
84 | class Store:
85 | def __init__(self):
86 | self.users = []
87 | self.inventory = Inventory()
88 |
89 | def register_user(self, username, password):
90 | user = User(username, password)
91 | self.users.append(user)
92 |
93 | def login_user(self, username, password):
94 | for user in self.users:
95 | if user.login(username, password):
96 | return user
97 | return None
98 |
99 | def purchase_book(self, user, book):
100 | user.purchase(book)
101 |
102 | def view_inventory(self):
103 | return self.inventory.books
104 |
105 | def search_books(self, search_type, query):
106 | if search_type == "title":
107 | return self.inventory.search_by_title(query)
108 | elif search_type == "author":
109 | return self.inventory.search_by_author(query)
110 | return []
111 |
112 |
113 |
114 | class Database:
115 | def __init__(self):
116 | self.users = []
117 | self.books = []
118 |
119 | def save_user(self, user):
120 | self.users.append(user)
121 |
122 | def save_book(self, book):
123 | self.books.append(book)
124 |
125 | def retrieve_all_users(self):
126 | return self.users
127 |
128 | def retrieve_all_books(self):
129 | return self.books
130 |
131 | def find_user(self, username):
132 | for user in self.users:
133 | if user.username == username:
134 | return user
135 | return None
136 |
137 | def find_book(self, isbn):
138 | for book in self.books:
139 | if book.isbn == isbn:
140 | return book
141 | return None
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |             I have read and understood the source code. I am ready to generate code.
150 | Give me the instructions.
151 |
152 |
153 |
--------------------------------------------------------------------------------
/examples/parameterized_prompts.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | You are a world-renowned travel planner, having curated trips for both the enthusiastic explorer and the
5 | leisure-seeking traveler. Your adeptness in understanding unique travel preferences, combined with your
6 | extensive knowledge of global destinations, ensures an unforgettable journey.
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | I'm gearing up for a memorable escape for the duration of.
15 |
16 |
17 |
18 |
19 | My eyes are set on exploring the diverse tapestry of our own homeland. I've chosen the vibrant city
20 | of
21 | . Given its domestic charm, I anticipate an itinerary that beautifully
22 | blends
23 | affordability with rich experiences. I'm eager to soak in the city's heritage, dine on local
24 | delicacies,
25 | and engage with its heartwarming community.
26 |
27 |
28 |
29 |
30 |
31 | I'm yearning to tread international soils, to immerse myself in unfamiliar cultures, sceneries, and
32 | stories.
33 |
34 |
35 |
36 | The Maldives beckons me with its constellation of coral islands. Renowned for its ethereal
37 | underwater beauty and luxury resorts, it offers an unmatched tropical experience. From
38 | diving
39 | into the vibrant coral reefs to dining in an underwater restaurant, from watching
40 | bioluminescent
41 | phytoplankton light up the shores to enjoying traditional drumming and dancing, every moment
42 | promises to be a slice of paradise.
43 |
44 |
45 |
46 |
47 |
48 | The vast expanse of the Amazon rainforest stirs the adventurer in me. A treasure trove of
49 | biodiversity, it spans across several countries in South America. I dream of navigating the
50 | serpentine Amazon river, witnessing the cacophony of life in the jungle, and understanding
51 | the
52 | age-old traditions of the indigenous tribes. The enchanting flora, diverse fauna, and the
53 | mesmerizing sounds of the forest await my arrival.
54 |
55 |
56 |
57 |
58 |
59 | The golden embrace of the Sahara desert is an experience I seek. Beyond its mesmerizing
60 | dunes
61 | lies a world rich in history and culture. The tales of ancient caravans, the artistry of
62 | Berber
63 | music, and the thrill of dune bashing are just the tip of the sand dune. Nights under the
64 | vast
65 | desert sky, sharing stories around a campfire, and the hauntingly beautiful music of the
66 | desert
67 | dwellers promise an ethereal journey.
68 |
69 |
70 |
71 |
72 |
73 | Tokyo, Japan's bustling capital, is a harmonious blend of tradition and modernity.
74 | Skyscrapers
75 | stand alongside historic temples, and cutting-edge technology coexists with age-old customs.
76 | From the serenity of the Meiji Shrine to the electric vibes of Shibuya Crossing, from
77 | sumptuous
78 | sushi feasts to tranquil tea ceremonies, Tokyo promises a whirlwind of experiences.
79 |
80 |
81 |
82 |
83 |
84 | The eternal city of Rome, with its millennia of history, calls out to the history aficionado
85 | in
86 | me. Walking its streets feels like a journey through time, from ancient Roman ruins to
87 | Renaissance art. From tossing a coin in the Trevi Fountain to a leisurely stroll in the
88 | Roman
89 | Forum, from savoring a gelato by the Spanish Steps to attending a mass at St. Peter's
90 | Basilica,
91 | Rome promises a cultural feast.
92 |
93 |
94 |
95 |
96 |
97 | Cape Town, nestled at the foot of the majestic Table Mountain in South Africa, offers a
98 | cocktail
99 | of experiences. With its rich history, vibrant arts scene, and diverse landscapes ranging
100 | from
101 | beaches to vineyards, it's a traveler's paradise. I look forward to exploring the historic
102 | Robben Island, savoring wines in Stellenbosch, and witnessing the panoramic views from the
103 | Cape
104 | of Good Hope.
105 |
106 |
107 |
108 |
109 |
110 | Sydney, the shimmering jewel of Australia, where urban sophistication meets natural beauty.
111 | The
112 | iconic Sydney Opera House and the majestic Sydney Harbour Bridge dominate its skyline.
113 | Beaches
114 | like Bondi and Manly promise sun, surf, and relaxation, while the city's rich culinary
115 | scene,
116 | bustling markets, and world-class museums ensure an enriching urban adventure.
117 |
118 |
119 |
120 |
121 |
122 | Buenos Aires, Argentina's cosmopolitan capital, is a city that pulsates with passion and
123 | energy.
124 | Tango music fills its air, and its streets are a canvas of art and culture. From the
125 | historic
126 | San Telmo district to the vibrant La Boca neighborhood, from savoring a traditional asado to
127 | dancing the night away in a milonga, Buenos Aires promises a fiery and rhythmic journey.
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 | I'd love to help. I've carefully read the city details and the trip duration. Please give me an instruction.
137 |
138 |
139 |
--------------------------------------------------------------------------------
/g.sh:
--------------------------------------------------------------------------------
1 | srun --exclusive --mem=0 --nodes=1 --ntasks=1 --ntasks-per-node=1 --cpus-per-task=16 \
2 | --partition=gpuA40x4-interactive --account=bccn-delta-gpu --gpu-bind=verbose,per_task:1 \
3 | --gpus-per-node=1 --gpus-per-task=1 --constraint="scratch" --pty /bin/zsh
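4 | # Requests an exclusive interactive shell with a single A40 GPU on the
5 | # gpuA40x4-interactive partition (handy for ad-hoc debugging of the evaluation
6 | # scripts on Delta).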
--------------------------------------------------------------------------------
/get_scores.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from metrics import (
8 | qa_f1_score,
9 | rouge_zh_score,
10 | qa_f1_zh_score,
11 | rouge_score,
12 | classification_score,
13 | retrieval_score,
14 | retrieval_zh_score,
15 | count_score,
16 | code_sim_score
17 | )
18 |
19 | dataset2metric = {
20 | "squad_v2": qa_f1_score,
21 | "narrativeqa": qa_f1_score,
22 | "qasper": qa_f1_score,
23 | "multifieldqa_en": qa_f1_score,
24 | "multifieldqa_zh": qa_f1_zh_score,
25 | "hotpotqa": qa_f1_score,
26 | "2wikimqa": qa_f1_score,
27 | "musique": qa_f1_score,
28 | "dureader": rouge_zh_score,
29 | "gov_report": rouge_score,
30 | "qmsum": rouge_score,
31 | "multi_news": rouge_score,
32 | # "trec": classification_score,
33 | "triviaqa": qa_f1_score,
34 | "samsum": rouge_score,
35 | "passage_retrieval_en": retrieval_score,
36 | "passage_count": count_score,
37 | "passage_retrieval_zh": retrieval_zh_score,
38 | "lcc": code_sim_score,
39 | "repobench-p": code_sim_score,
40 | }
41 |
42 |
43 | def main():
44 | dset_list = [
45 | #"squad_v2",
46 | "narrativeqa",
47 | "qasper",
48 | "multifieldqa_en",
49 | "hotpotqa",
50 | "2wikimqa",
51 | "musique",
52 | "gov_report",
53 | "qmsum",
54 | "multi_news",
55 | # "trec",
56 | "triviaqa",
57 | # "samsum",
58 | "passage_count",
59 | "passage_retrieval_en",
60 | # "lcc",
61 | # "repobench-p",
62 | ]
63 |
64 | model_list = [
65 | #"falcon",
66 | "llama",
67 | #"mpt"
68 | ]
69 |
70 | for dset in tqdm(dset_list):
71 | print(f"Dataset: {dset}------------------------------")
72 | for m in model_list:
73 | p = f"./results_13b/{m}-{dset}"
74 | with open(glob.glob(f"{p}/no_cache_*.json")[0], "r") as f:
75 | no_cache = [json.loads(line) for line in f]
76 | no_cache_score, nc_std = score(no_cache, dset)
77 |
78 | with open(glob.glob(f"{p}/with_cache_*.json")[0], "r") as f:
79 | with_cache = [json.loads(line) for line in f]
80 | with_cache_score, wc_std = score(with_cache, dset)
81 |
82 | print(f"{m}-{dset}: {no_cache_score:.2f} ({nc_std:.2f}) vs {with_cache_score:.2f} ({wc_std:.2f}) ")
83 |
84 |
85 | def score(results, dataset_name):
86 | scores_list = []
87 | for result in results:
88 |         response = result["response"].split('</s>')[0].split('<|endoftext|>')[0].split('<|im_end|>')[0]
89 | answers = result["answers"]
90 |
91 | score = 0.
92 | for answer in answers:
93 | score = max(score, dataset2metric[dataset_name](response, answer))
94 |
95 | scores_list.append(score)
96 |
97 | #print(scores_list)
98 | mean = np.mean(scores_list) * 100
99 | std = np.std(scores_list)
100 |
101 | return mean, std
102 |
103 |
104 | if __name__ == "__main__":
105 | main()
106 |
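107 | # Note on the expected on-disk layout (inferred from the reads above): each
108 | # results file is JSON Lines, one object per line with at least "response" and
109 | # "answers" fields, stored under ./results_13b/<model>-<dataset>/ as
110 | # no_cache_*.json and with_cache_*.json.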
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | import jieba
5 | from fuzzywuzzy import fuzz
6 | import difflib
7 |
8 | from typing import List
9 | from collections import Counter
10 | from rouge import Rouge
11 |
12 |
13 | def normalize_answer(s):
14 | """Lower text and remove punctuation, articles and extra whitespace."""
15 |
16 | def remove_articles(text):
17 | return re.sub(r"\b(a|an|the)\b", " ", text)
18 |
19 | def white_space_fix(text):
20 | return " ".join(text.split())
21 |
22 | def remove_punc(text):
23 | exclude = set(string.punctuation)
24 | return "".join(ch for ch in text if ch not in exclude)
25 |
26 | def lower(text):
27 | return text.lower()
28 |
29 | return white_space_fix(remove_articles(remove_punc(lower(s))))
30 |
31 |
32 | def normalize_zh_answer(s):
33 | """Lower text and remove punctuation, extra whitespace."""
34 |
35 | def white_space_fix(text):
36 | return "".join(text.split())
37 |
38 | def remove_punc(text):
39 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
40 | all_punctuation = set(string.punctuation + cn_punctuation)
41 | return "".join(ch for ch in text if ch not in all_punctuation)
42 |
43 | def lower(text):
44 | return text.lower()
45 |
46 | return white_space_fix(remove_punc(lower(s)))
47 |
48 |
49 | def count_score(prediction, ground_truth, **kwargs):
50 | numbers = re.findall(r"\d+", prediction)
51 | right_num = 0
52 | for number in numbers:
53 | if str(number) == str(ground_truth):
54 | right_num += 1
55 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
56 | return float(final_score)
57 |
58 |
59 | def retrieval_score(prediction, ground_truth, **kwargs):
60 | pattern = r'Paragraph (\d+)'
61 | matches = re.findall(pattern, ground_truth)
62 | ground_truth_id = matches[0]
63 | numbers = re.findall(r"\d+", prediction)
64 | right_num = 0
65 | for number in numbers:
66 | if str(number) == str(ground_truth_id):
67 | right_num += 1
68 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
69 | return float(final_score)
70 |
71 |
72 | def retrieval_zh_score(prediction, ground_truth, **kwargs):
73 | pattern = r'段落(\d+)'
74 | matches = re.findall(pattern, ground_truth)
75 | ground_truth_id = matches[0]
76 | numbers = re.findall(r"\d+", prediction)
77 | right_num = 0
78 | for number in numbers:
79 | if str(number) == str(ground_truth_id):
80 | right_num += 1
81 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
82 | return float(final_score)
83 |
84 |
85 | def code_sim_score(prediction, ground_truth, **kwargs):
86 | all_lines = prediction.lstrip('\n').split('\n')
87 | prediction = ""
88 | for line in all_lines:
89 | if ('`' not in line) and ('#' not in line) and ('//' not in line):
90 | prediction = line
91 | break
92 | return (fuzz.ratio(prediction, ground_truth) / 100)
93 |
94 |
95 | def classification_score(prediction, ground_truth, **kwargs):
96 | em_match_list = []
97 | all_classes = kwargs["all_classes"]
98 | for class_name in all_classes:
99 | if class_name in prediction:
100 | em_match_list.append(class_name)
101 |     # drop partial matches that are only substrings of the ground truth
102 |     em_match_list = [term for term in em_match_list
103 |                      if not (term in ground_truth and term != ground_truth)]
104 |     if len(em_match_list) != 0:
105 | if ground_truth in em_match_list:
106 | score = (1.0 / len(em_match_list))
107 | else:
108 | score = 0.0
109 | else:
110 | best_match = None
111 | highest_similarity = 0
112 | for string in all_classes:
113 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
114 | if similarity > highest_similarity:
115 | highest_similarity = similarity
116 | best_match = string
117 | score = float(best_match == ground_truth)
118 | return score
119 |
120 |
121 | def rouge_score(prediction, ground_truth, **kwargs):
122 | rouge = Rouge()
123 | try:
124 | scores = rouge.get_scores([prediction], [ground_truth], avg=True)
125 |     except Exception:
126 | return 0.0
127 | return scores["rouge-l"]["f"]
128 |
129 |
130 | def rouge_zh_score(prediction, ground_truth, **kwargs):
131 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
132 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
133 | score = rouge_score(prediction, ground_truth)
134 | return score
135 |
136 |
137 | def f1_score(prediction, ground_truth, **kwargs):
138 | common = Counter(prediction) & Counter(ground_truth)
139 | num_same = sum(common.values())
140 | if num_same == 0:
141 | return 0
142 | precision = 1.0 * num_same / len(prediction)
143 | recall = 1.0 * num_same / len(ground_truth)
144 | f1 = (2 * precision * recall) / (precision + recall)
145 | return f1
146 |
147 |
148 | def qa_f1_score(prediction, ground_truth, **kwargs):
149 | normalized_prediction = normalize_answer(prediction)
150 | normalized_ground_truth = normalize_answer(ground_truth)
151 |
152 | prediction_tokens = normalized_prediction.split()
153 | ground_truth_tokens = normalized_ground_truth.split()
154 | return f1_score(prediction_tokens, ground_truth_tokens)
155 |
156 |
157 | def qa_f1_zh_score(prediction, ground_truth, **kwargs):
158 | prediction_tokens = list(jieba.cut(prediction, cut_all=False))
159 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
160 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
161 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
162 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
163 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
164 | return f1_score(prediction_tokens, ground_truth_tokens)
165 |
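166 | 
167 | if __name__ == "__main__":
168 |     # Minimal sanity check (illustrative only, not part of the benchmark
169 |     # pipeline): normalize_answer() drops the article "the", leaving five
170 |     # prediction tokens, one of which matches the single ground-truth token,
171 |     # so precision is 1/5, recall is 1, and the F1 score is about 0.33.
172 |     print(qa_f1_score("The Eiffel Tower is in Paris", "Paris"))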
--------------------------------------------------------------------------------
/promptcache/__init__.py:
--------------------------------------------------------------------------------
1 | # re-export all modules
2 |
3 | from .cache_engine import CacheEngine
4 | from .generation_engine import GenerationEngine, GenerationParameters
5 | from .prompt import Prompt, CompactSpaces, read_file
6 | from .schema import Schema
7 | from .conversation import llama2_template
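8 | # Typical import surface (illustrative): from promptcache import CacheEngine,
9 | # GenerationEngine, GenerationParameters, Prompt, Schema, read_file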
--------------------------------------------------------------------------------
/promptcache/compiler.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import inspect
3 | import astor
4 | from functools import wraps
5 |
6 |
7 | class SchemaBuilder:
8 |     schema: list[str | tuple[bool, str] | tuple[str, dict]]
9 | schema_name: str | None
10 |
11 | def __init__(self):
12 | self.schema = []
13 | self.schema_name = None
14 |
15 | def text(self, value: str):
16 | self.schema.append(value)
17 |
18 | def cond(self, cond: bool, _):
19 | value = self.schema.pop()
20 | self.schema.append((cond, value))
21 |
22 | def union(self, target: str, values: dict):
23 | self.schema.append((target, values))
24 |
25 | def get_schema(self) -> str:
26 |
27 | uid = 0
28 | output = ""
29 | for item in self.schema:
30 | match item:
31 | case str(val):
32 | output += val + "\n"
33 | case (bool(), str(val)):
34 | output += f"{val}"
35 | uid += 1
36 | case (str(), dict(d)):
37 | output += ""
38 | for cond, val in d.items():
39 | output += f"{val}"
40 | uid += 1
41 | output += ""
42 |
43 | self.schema_name = f'{abs(hash(output)):x}'
44 |
45 | output = f"" + output
46 | output += ""
47 | return output
48 |
49 | def get_prompt(self) -> str:
50 |
51 | if self.schema_name is None:
52 | self.get_schema()
53 |
54 | output = f""
55 | uid = 0
56 | for item in self.schema:
57 | match item:
58 | case (bool(cond), str()):
59 | if cond:
60 | output += f""
61 | uid += 1
62 | case (str(target), dict(d)):
63 | for cond, val in d.items():
64 | if target == cond:
65 | output += f""
66 | uid += 1
67 | case _:
68 | ...
69 | output += ""
70 | return output
71 |
72 |
73 | class CodeTransformer(ast.NodeTransformer):
74 |
75 | def visit_Expr(self, node):
76 | if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str):
77 | # Safely format the string
78 | return ast.parse(f'builder.text({repr(node.value.value)})').body[0]
79 | return node
80 |
81 | def visit_If(self, node):
82 | condition = astor.to_source(node.test).strip()
83 | then_body = [self.visit(n) for n in node.body]
84 | content = astor.to_source(then_body[0]).strip()
85 | return ast.parse(f'builder.cond({condition}, {content})').body[0]
86 |
87 | def visit_Match(self, node):
88 | subject = astor.to_source(node.subject).strip()
89 | cases = {}
90 | for case in node.cases:
91 | key = astor.to_source(case.pattern.value).strip()
92 | value = astor.to_source(case.body[0]).strip()
93 | cases[key] = value
94 | cases_str = ", ".join([f'{key}: {value}' for key, value in cases.items()])
95 | return ast.parse(f'builder.union({subject}, {{{cases_str}}})').body[0]
96 |
97 |
98 | cached_compiled_funcs = dict()
99 |
100 |
101 | def prompt(func):
102 | #
103 | @wraps(func)
104 | def wrapper(*args, **kwargs):
105 | global cached_compiled_funcs
106 |
107 | if func in cached_compiled_funcs:
108 | print('using cached')
109 | return cached_compiled_funcs[func](*args, **kwargs)
110 |
111 | source_code = inspect.getsource(func).splitlines()
112 | # strip decorators from the source itself
113 | source_code = "\n".join(line for line in source_code if not line.startswith('@'))
114 |
115 | tree = ast.parse(source_code)
116 |
117 |         # Insert 'builder = SchemaBuilder()' at the start of the function body
118 | env_init = ast.parse("builder = SchemaBuilder()").body[0]
119 | tree.body[0].body.insert(0, env_init)
120 |
121 |         # Ensure the transformed function returns 'builder'
122 | tree.body[0].body.append(ast.parse("return builder").body[0])
123 |
124 | transformer = CodeTransformer()
125 | new_tree = transformer.visit(tree)
126 |
127 | # for debugging
128 | new_source_code = astor.to_source(new_tree)
129 |
130 | local_vars = {}
131 | exec(compile(new_tree, filename="", mode="exec"), func.__globals__, local_vars)
132 | modified_func = local_vars[func.__name__]
133 | cached_compiled_funcs[func] = modified_func
134 | return modified_func(*args, **kwargs)
135 |
136 | return wrapper
137 |
--------------------------------------------------------------------------------
/promptcache/conversation.py:
--------------------------------------------------------------------------------
1 | from enum import auto, Enum, IntEnum
2 | from typing import List, Any, Dict
3 | import dataclasses
4 |
5 | import numpy as np
6 | import torch
7 | import transformers
8 | from transformers.trainer_pt_utils import LabelSmoother
9 |
10 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
11 |
12 |
13 | class SeparatorStyle(IntEnum):
14 | """Separator styles."""
15 |
16 | ADD_COLON_SINGLE = auto()
17 | ADD_COLON_TWO = auto()
18 | ADD_COLON_SPACE_SINGLE = auto()
19 | NO_COLON_SINGLE = auto()
20 | NO_COLON_TWO = auto()
21 | ADD_NEW_LINE_SINGLE = auto()
22 | LLAMA2 = auto()
23 | CHATGLM = auto()
24 | CHATML = auto()
25 | CHATINTERN = auto()
26 | DOLLY = auto()
27 | RWKV = auto()
28 | PHOENIX = auto()
29 | ROBIN = auto()
30 |
31 |
32 | @dataclasses.dataclass
33 | class Conversation:
34 | """A class that manages prompt templates and keeps all conversation history."""
35 |
36 | # The name of this template
37 | name: str
38 | # The system prompt
39 | system: str
40 | # Two roles
41 | roles: List[str]
42 | # All messages. Each item is (role, message).
43 | messages: List[List[str]]
44 | # The number of few shot examples
45 | offset: int
46 | # Separators
47 | sep_style: SeparatorStyle
48 | sep: str
49 | sep2: str = None
50 | # Stop criteria (the default one is EOS token)
51 | stop_str: str = None
52 | # Stops generation if meeting any token in this list
53 | stop_token_ids: List[int] = None
54 |
55 | def get_prompt(self) -> str:
56 | """Get the prompt for generation."""
57 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
58 | ret = self.system + self.sep
59 | for role, message in self.messages:
60 | if message:
61 | ret += role + ": " + message + self.sep
62 | else:
63 | ret += role + ":"
64 | return ret
65 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
66 | seps = [self.sep, self.sep2]
67 | ret = self.system + seps[0]
68 | for i, (role, message) in enumerate(self.messages):
69 | if message:
70 | ret += role + ": " + message + seps[i % 2]
71 | else:
72 | ret += role + ":"
73 | return ret
74 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
75 | ret = self.system + self.sep
76 | for role, message in self.messages:
77 | if message:
78 | ret += role + ": " + message + self.sep
79 | else:
80 | ret += role + ": " # must be end with a space
81 | return ret
82 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
83 | ret = "" if self.system == "" else self.system + self.sep
84 | for role, message in self.messages:
85 | if message:
86 | ret += role + "\n" + message + self.sep
87 | else:
88 | ret += role + "\n"
89 | return ret
90 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
91 | ret = self.system
92 | for role, message in self.messages:
93 | if message:
94 | ret += role + message + self.sep
95 | else:
96 | ret += role
97 | return ret
98 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
99 | seps = [self.sep, self.sep2]
100 | ret = self.system
101 | for i, (role, message) in enumerate(self.messages):
102 | if message:
103 | ret += role + message + seps[i % 2]
104 | else:
105 | ret += role
106 | return ret
107 | elif self.sep_style == SeparatorStyle.RWKV:
108 | ret = self.system
109 | for i, (role, message) in enumerate(self.messages):
110 | if message:
111 | ret += (
112 | role
113 | + ": "
114 | + message.replace("\r\n", "\n").replace("\n\n", "\n")
115 | )
116 | ret += "\n\n"
117 | else:
118 | ret += role + ":"
119 | return ret
120 | elif self.sep_style == SeparatorStyle.LLAMA2:
121 | seps = [self.sep, self.sep2]
122 | ret = ""
123 | for i, (role, message) in enumerate(self.messages):
124 | if message:
125 | if i == 0:
126 | ret += self.system + message
127 | else:
128 | ret += role + " " + message + seps[i % 2]
129 | else:
130 | ret += role
131 | return ret
132 | elif self.sep_style == SeparatorStyle.CHATGLM:
133 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
134 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
135 | round_add_n = 1 if self.name == "chatglm2" else 0
136 | if self.system:
137 | ret = self.system + self.sep
138 | else:
139 | ret = ""
140 |
141 | for i, (role, message) in enumerate(self.messages):
142 | if i % 2 == 0:
143 | ret += f"[Round {i // 2 + round_add_n}]{self.sep}"
144 |
145 | if message:
146 | ret += f"{role}:{message}{self.sep}"
147 | else:
148 | ret += f"{role}:"
149 | return ret
150 | elif self.sep_style == SeparatorStyle.CHATML:
151 | ret = "" if self.system == "" else self.system + self.sep + "\n"
152 | for role, message in self.messages:
153 | if message:
154 | ret += role + "\n" + message + self.sep + "\n"
155 | else:
156 | ret += role + "\n"
157 | return ret
158 | elif self.sep_style == SeparatorStyle.CHATINTERN:
159 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
160 | seps = [self.sep, self.sep2]
161 | ret = self.system
162 | for i, (role, message) in enumerate(self.messages):
163 | if i % 2 == 0:
164 | ret += ""
165 | if message:
166 | ret += role + ":" + message + seps[i % 2] + "\n"
167 | else:
168 | ret += role + ":"
169 | return ret
170 | elif self.sep_style == SeparatorStyle.DOLLY:
171 | seps = [self.sep, self.sep2]
172 | ret = self.system
173 | for i, (role, message) in enumerate(self.messages):
174 | if message:
175 | ret += role + ":\n" + message + seps[i % 2]
176 | if i % 2 == 1:
177 | ret += "\n\n"
178 | else:
179 | ret += role + ":\n"
180 | return ret
181 | elif self.sep_style == SeparatorStyle.PHOENIX:
182 | ret = self.system
183 | for role, message in self.messages:
184 | if message:
185 |                     ret += role + ": " + "<s>" + message + "</s>"
186 |                 else:
187 |                     ret += role + ": " + "<s>"
188 | return ret
189 | elif self.sep_style == SeparatorStyle.ROBIN:
190 | ret = self.system + self.sep
191 | for role, message in self.messages:
192 | if message:
193 | ret += role + ":\n" + message + self.sep
194 | else:
195 | ret += role + ":\n"
196 | return ret
197 | else:
198 | raise ValueError(f"Invalid style: {self.sep_style}")
199 |
200 | def append_message(self, role: str, message: str):
201 | """Append a new message."""
202 | self.messages.append([role, message])
203 |
204 | def update_last_message(self, message: str):
205 | """Update the last output.
206 |
207 | The last message is typically set to be None when constructing the prompt,
208 | so we need to update it in-place after getting the response from a model.
209 | """
210 | self.messages[-1][1] = message
211 |
212 | def to_gradio_chatbot(self):
213 | """Convert the conversation to gradio chatbot format."""
214 | ret = []
215 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
216 | if i % 2 == 0:
217 | ret.append([msg, None])
218 | else:
219 | ret[-1][-1] = msg
220 | return ret
221 |
222 | def to_openai_api_messages(self):
223 | """Convert the conversation to OpenAI chat completion format."""
224 | ret = [{"role": "system", "content": self.system}]
225 |
226 | for i, (_, msg) in enumerate(self.messages[self.offset:]):
227 | if i % 2 == 0:
228 | ret.append({"role": "user", "content": msg})
229 | else:
230 | if msg is not None:
231 | ret.append({"role": "assistant", "content": msg})
232 | return ret
233 |
234 | def copy(self):
235 | return Conversation(
236 | name=self.name,
237 | system=self.system,
238 | roles=self.roles,
239 | messages=[[x, y] for x, y in self.messages],
240 | offset=self.offset,
241 | sep_style=self.sep_style,
242 | sep=self.sep,
243 | sep2=self.sep2,
244 | stop_str=self.stop_str,
245 | stop_token_ids=self.stop_token_ids,
246 | )
247 |
248 | def dict(self):
249 | return {
250 | "template_name": self.name,
251 | "system": self.system,
252 | "roles": self.roles,
253 | "messages": self.messages,
254 | "offset": self.offset,
255 | }
256 |
257 |
258 | def vicuna_template(system: str = "A chat between a user and an artificial intelligence assistant. "
259 | "The assistant gives helpful and detailed answers to the user's questions."):
260 | return Conversation(
261 | name="vicuna_v1.1",
262 | system=system,
263 | roles=["USER", "ASSISTANT"],
264 | messages=[],
265 | offset=0,
266 | sep_style=SeparatorStyle.ADD_COLON_TWO,
267 | sep=" ",
268 |         sep2="</s>",
269 | )
270 |
271 |
272 | def llama2_template(
273 | system: str = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
274 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
275 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
276 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
277 | "If you don't know the answer to a question, please don't share false information.\n)"):
278 | return Conversation(
279 | name="llama-2",
280 | system=f" [INST] <>\n{system}<>\n\n ",
281 | roles=["[INST]", "[/INST]"],
282 | messages=[],
283 | offset=0,
284 | sep_style=SeparatorStyle.LLAMA2,
285 | sep=" ",
286 |         sep2=" </s><s>",
287 | stop_token_ids=[2],
288 | )
289 |
290 | # wittgenstein
291 |
292 |
293 | #
294 | # def preprocess(
295 | # conv,
296 | # source,
297 | # tokenizer: transformers.PreTrainedTokenizer,
298 | # ) -> (torch.Tensor, torch.Tensor):
299 | # roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
300 | #
301 | # # print(sources)
302 | #
303 | # # Apply prompt templates
304 | #
305 | # if roles[source[0]["from"]] != conv.roles[0]:
306 | # # Skip the first one if it is not from human
307 | # source = source[1:]
308 | #
309 | # conv.messages = []
310 | # for j, sentence in enumerate(source):
311 | # role = roles[sentence["from"]]
312 | # assert role == conv.roles[j % 2], f"{i}"
313 | # conv.append_message(role, sentence["value"])
314 | #
315 | # prompt = conv.get_prompt()
316 | #
317 | # # Tokenize conversations
318 | # input_ids = tokenizer(
319 | # conversations,
320 | # return_tensors="pt"
321 | # ).input_ids
322 | #
323 | # targets = input_ids.clone()
324 | #
325 | # # print(tokenizer.decode(input_ids[0]))
326 | #
327 | # # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
328 | #
329 | # #Mask targets
330 | # sep = conv.sep + conv.roles[1] + ": "
331 | # for conversation, target in zip(conversations, targets):
332 | # total_len = int(target.ne(tokenizer.pad_token_id).sum())
333 | #
334 | # rounds = conversation.split(conv.sep2)
335 | # cur_len = 1
336 | # target[:cur_len] = IGNORE_TOKEN_ID
337 | # for i, rou in enumerate(rounds):
338 | # if rou == "":
339 | # break
340 | #
341 | # parts = rou.split(sep)
342 | # if len(parts) != 2:
343 | # break
344 | # parts[0] += sep
345 | # round_len = len(tokenizer(rou).input_ids)
346 | # instruction_len = len(tokenizer(parts[0]).input_ids) - 2
347 | #
348 | # target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
349 | #
350 | # cur_len += round_len
351 | # target[cur_len:] = IGNORE_TOKEN_ID
352 | #
353 | # if cur_len < tokenizer.model_max_length:
354 | # if cur_len != total_len:
355 | # target[:] = IGNORE_TOKEN_ID
356 | # print(
357 | # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
358 | # f" (ignored)"
359 | # )
360 | # print(target)
361 | #
362 | # return input_ids, targets
363 |
364 | # print(len(input_ids[0]))
365 | # print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
366 | # print(len(targets[0]))
367 | # targets[0][targets[0] == IGNORE_TOKEN_ID] = tokenizer.pad_token_id
368 | # print(tokenizer.decode(targets[0], skip_special_tokens=True))
369 | #
370 | # return dict(
371 | # input_ids=input_ids,
372 | # labels=targets,
373 | # attention_mask=input_ids.ne(tokenizer.pad_token_id),
374 | # )
375 |
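376 | # Illustrative usage sketch: build a single-turn Llama-2 style prompt string.
377 | #
378 | #   conv = llama2_template()
379 | #   conv.append_message(conv.roles[0], "Summarize the document below.")
380 | #   conv.append_message(conv.roles[1], None)
381 | #   print(conv.get_prompt())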
--------------------------------------------------------------------------------
/promptcache/generation_engine.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import time
3 | from dataclasses import dataclass, field
4 | from typing import Optional, Generator, List, Tuple
5 |
6 | import torch
7 |
8 | from promptcache.cache_engine import KVCache
9 |
10 | from transformers.generation.logits_process import (
11 | LogitsProcessorList,
12 | RepetitionPenaltyLogitsProcessor,
13 | TemperatureLogitsWarper,
14 | TopKLogitsWarper,
15 | TopPLogitsWarper,
16 | )
17 | import termcolor
18 | from promptcache.model import LanguageModel
19 |
20 |
21 | @dataclass
22 | class GenerationParameters:
23 | temperature: float = 1.0
24 | repetition_penalty: float = 1.0
25 | top_p: float = 1.0
26 | top_k: int = -1
27 | max_new_tokens: int = 256
28 | stop_token_ids: List[int] = field(default_factory=lambda: [])
29 | stop_str: List[str] = field(default_factory=lambda: [])
30 | echo: bool = True
31 |
32 | def get_logits_processor(self):
33 | p = LogitsProcessorList()
34 | if self.temperature >= 1e-5 and self.temperature != 1.0:
35 | p.append(TemperatureLogitsWarper(self.temperature))
36 | if self.repetition_penalty > 1.0:
37 | p.append(RepetitionPenaltyLogitsProcessor(self.repetition_penalty))
38 | if 1e-8 <= self.top_p < 1.0:
39 | p.append(TopPLogitsWarper(self.top_p))
40 | if self.top_k > 0:
41 | p.append(TopKLogitsWarper(self.top_k))
42 | return p
43 |
44 |
45 | def is_partial_stop(output: str, stop_str: str):
46 | """Check whether the output contains a partial stop str."""
47 | for i in range(0, min(len(output), len(stop_str))):
48 | if stop_str.startswith(output[-i:]):
49 | return True
50 | return False
51 |
52 |
53 | @dataclass
54 | class Output:
55 | text: str
56 | new_text: str
57 | response_time: float = 0.0
58 | elapsed_time: float = 0.0
59 |
60 |
61 | class GenerationEngine:
62 | lm: LanguageModel
63 |
64 | def __init__(self, lm: LanguageModel):
65 | self.lm = lm
66 |
67 | @torch.inference_mode()
68 | def generate(self,
69 | token_ids: List[int],
70 | position_ids: List[int],
71 | params: GenerationParameters,
72 | cache: Optional[KVCache] = None,
73 | stream_interval: int = 2,
74 | use_full_position_ids: bool = False) -> Generator[Output, None, None]:
75 |
76 | logits_processor = params.get_logits_processor()
77 | output_ids = list(token_ids)
78 | new_output_ids = list()
79 |
80 | device = self.lm.device
81 |
82 | position_offset = max(position_ids) + 1
83 | past_key_values = None
84 | new_token_id = 0
85 |
86 | inference_time = 0.0
87 | response_time = 0.0
88 |
89 | position_ids_og = position_ids
90 |
91 | for i in range(params.max_new_tokens):
92 |
93 | # initial phase
94 | if past_key_values is None:
95 |
96 | input_ids = torch.tensor([token_ids], device=device, dtype=torch.long)
97 | position_ids = torch.tensor([position_ids], device=device, dtype=torch.long)
98 | # print(len(position_ids[0]))
99 |
100 | # add redundant batch dim
101 | if cache is not None:
102 | cache = [(k[0].unsqueeze(0), k[1].unsqueeze(0)) for k in cache]
103 |
104 | start = torch.cuda.Event(enable_timing=True)
105 | end = torch.cuda.Event(enable_timing=True)
106 |
107 | start.record()
108 | out = self.lm(input_ids=input_ids,
109 | position_ids=position_ids,
110 | past_key_values=cache,
111 | use_cache=True)
112 | end.record()
113 | torch.cuda.synchronize()
114 | inference_time += start.elapsed_time(end)
115 | response_time = inference_time
116 | # print(f'Response time: {inference_time:.2f} ms')
117 | # pretty print using termcolor
118 | print(termcolor.colored(f'Prefill latency: {inference_time:.2f} ms', 'yellow'))
119 |
120 | logits = out.logits
121 | past_key_values = out.past_key_values
122 |
123 | else:
124 | # upload to the GPU
125 | input_ids = torch.tensor([[new_token_id]], device=device, dtype=torch.long)
126 |
127 | if use_full_position_ids:
128 | position_ids = torch.tensor([position_ids_og + list(range(position_offset, position_offset + i))],
129 | device=device, dtype=torch.long)
130 |
131 | else:
132 | position_ids = torch.tensor([[position_offset + i]], device=device, dtype=torch.long)
133 |
134 | start = torch.cuda.Event(enable_timing=True)
135 | end = torch.cuda.Event(enable_timing=True)
136 |
137 | start.record()
138 | out = self.lm(input_ids=input_ids,
139 | position_ids=position_ids,
140 | past_key_values=past_key_values,
141 | use_cache=True)
142 | end.record()
143 | torch.cuda.synchronize()
144 | inference_time += start.elapsed_time(end)
145 |
146 | logits = out.logits
147 | past_key_values = out.past_key_values
148 |
149 | if params.repetition_penalty > 1.0:
150 | tmp_output_ids = torch.as_tensor([output_ids], device=device)
151 | else:
152 | tmp_output_ids = None
153 |
154 | last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
155 |
156 | ccc = []
157 |
158 | if params.temperature < 1e-5 or params.top_p < 1e-8: # greedy
159 | new_token_id = int(torch.argmax(last_token_logits))
160 | # _, indices = torch.topk(last_token_logits, 2)
161 | # ccc = [int(index) for index in indices.tolist()]
162 |
163 | else:
164 | probs = torch.softmax(last_token_logits, dim=-1)
165 | new_token_id = int(torch.multinomial(probs, num_samples=1))
166 | # probs = torch.softmax(last_token_logits, dim=-1)
167 | # indices = torch.multinomial(probs, num_samples=2)
168 | # ccc = [int(token) for token in indices.tolist()]
169 |
170 | # new_token_id = ccc[1]
171 |
172 | output_ids.append(new_token_id)
173 | new_output_ids.append(new_token_id)
174 |
175 | # print(self.lm.decode([new_token_id]))
176 |
177 | if new_token_id in params.stop_token_ids:
178 | # print('Stopped', self.lm.decode([new_token_id]), self.lm.encode('. '))
179 | stopped = True
180 | else:
181 | stopped = False
182 |
183 | if i % stream_interval == 0 or i == params.max_new_tokens - 1 or stopped:
184 | output = self.lm.decode(output_ids)
185 | new_output = self.lm.decode(new_output_ids)
186 |
187 | partially_stopped = False
188 |
189 | for each_stop in params.stop_str:
190 | pos = new_output.rfind(each_stop, 0)
191 | if pos != -1:
192 | new_output = new_output[:pos]
193 | stopped = True
194 | break
195 | else:
196 | partially_stopped = is_partial_stop(output, each_stop)
197 | if partially_stopped:
198 | break
199 |
200 | if not partially_stopped:
201 | yield Output(output, new_output, inference_time, response_time)
202 |
203 | if stopped:
204 | break
205 |
206 | # clean
207 | del past_key_values, out
208 | gc.collect()
209 | torch.cuda.empty_cache()
210 |
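211 | # Illustrative usage sketch (argument values are placeholders): stream tokens
212 | # for a prompt already processed by the cache engine (as in eval_sys.py), using
213 | # its token_ids, position_ids and prefilled cache.
214 | #
215 | #   engine = GenerationEngine(lm)
216 | #   params = GenerationParameters(temperature=0.0, max_new_tokens=64,
217 | #                                 stop_token_ids=lm.stop_token_ids)
218 | #   for output in engine.generate(token_ids, position_ids, params, cache=cache,
219 | #                                 use_full_position_ids=lm.use_full_position_ids):
220 | #       print(output.new_text)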
--------------------------------------------------------------------------------
/promptcache/inference.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import gc
3 | import math
4 | from typing import Iterable, Optional
5 | import sys
6 | import time
7 | import warnings
8 |
9 | from promptcache.conversation import Conversation, SeparatorStyle, llama2_template
10 |
11 | import torch
12 | from transformers import (
13 | AutoTokenizer,
14 | AutoModelForCausalLM,
15 | LlamaTokenizer,
16 | LlamaForCausalLM,
17 | AutoModel,
18 | AutoModelForSeq2SeqLM,
19 | T5Tokenizer,
20 | AutoConfig,
21 | )
22 | from transformers.generation.logits_process import (
23 | LogitsProcessorList,
24 | RepetitionPenaltyLogitsProcessor,
25 | TemperatureLogitsWarper,
26 | TopKLogitsWarper,
27 | TopPLogitsWarper,
28 | )
29 |
30 |
31 | def prepare_logits_processor(
32 | temperature: float, repetition_penalty: float, top_p: float, top_k: int
33 | ) -> LogitsProcessorList:
34 | processor_list = LogitsProcessorList()
35 | # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
36 | if temperature >= 1e-5 and temperature != 1.0:
37 | processor_list.append(TemperatureLogitsWarper(temperature))
38 | if repetition_penalty > 1.0:
39 | processor_list.append(
40 | RepetitionPenaltyLogitsProcessor(repetition_penalty))
41 | if 1e-8 <= top_p < 1.0:
42 | processor_list.append(TopPLogitsWarper(top_p))
43 | if top_k > 0:
44 | processor_list.append(TopKLogitsWarper(top_k))
45 | return processor_list
46 |
47 |
48 | def partial_stop(output, stop_str):
49 | for i in range(0, min(len(output), len(stop_str))):
50 | if stop_str.startswith(output[-i:]):
51 | return True
52 | return False
53 |
54 |
55 | @torch.inference_mode()
56 | def generate_stream(
57 | model, tokenizer, params, device, context_len=2048, stream_interval=2
58 | ):
59 | prompt = params["prompt"]
60 | len_prompt = len(prompt)
61 | temperature = float(params.get("temperature", 1.0))
62 | repetition_penalty = float(params.get("repetition_penalty", 1.0))
63 | top_p = float(params.get("top_p", 1.0))
64 | top_k = int(params.get("top_k", -1)) # -1 means disable
65 | max_new_tokens = int(params.get("max_new_tokens", 256))
66 | stop_str = params.get("stop", None)
67 | echo = bool(params.get("echo", True))
68 | stop_token_ids = params.get("stop_token_ids", None) or []
69 | stop_token_ids.append(tokenizer.eos_token_id)
70 |
71 | logits_processor = prepare_logits_processor(
72 | temperature, repetition_penalty, top_p, top_k
73 | )
74 |
75 | input_ids = tokenizer(prompt).input_ids
76 | input_echo_len = len(input_ids)
77 | output_ids = list(input_ids)
78 |
79 | max_src_len = context_len - max_new_tokens - 8
80 |
81 | input_ids = input_ids[-max_src_len:]
82 |
83 | # print(prompt)
84 | # aa = tokenizer.convert_ids_to_tokens(input_ids)
85 | # print(aa)
86 |
87 | past_key_values = out = None
88 | for i in range(max_new_tokens):
89 | if i == 0:
90 | print(torch.as_tensor(
91 | [input_ids], device=device))
92 |
93 | out = model(torch.as_tensor([input_ids], device=device))
94 | logits = out.logits
95 | past_key_values = out.past_key_values
96 | else:
97 |
98 | out = model(
99 | input_ids=torch.as_tensor([[token]], device=device),
100 | use_cache=True,
101 | past_key_values=past_key_values,
102 | )
103 | logits = out.logits
104 | past_key_values = out.past_key_values
105 |
106 | if logits_processor:
107 | if repetition_penalty > 1.0:
108 | tmp_output_ids = torch.as_tensor(
109 | [output_ids], device=logits.device)
110 | else:
111 | tmp_output_ids = None
112 | last_token_logits = logits_processor(
113 | tmp_output_ids, logits[:, -1, :])[0]
114 | else:
115 | last_token_logits = logits[0, -1, :]
116 |
117 | if temperature < 1e-5 or top_p < 1e-8: # greedy
118 | token = int(torch.argmax(last_token_logits))
119 | else:
120 | probs = torch.softmax(last_token_logits, dim=-1)
121 | token = int(torch.multinomial(probs, num_samples=1))
122 |
123 | output_ids.append(token)
124 |
125 | if token in stop_token_ids:
126 | stopped = True
127 | else:
128 | stopped = False
129 |
130 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
131 | if echo:
132 | tmp_output_ids = output_ids
133 | rfind_start = len_prompt
134 | else:
135 | tmp_output_ids = output_ids[input_echo_len:]
136 | rfind_start = 0
137 |
138 | output = tokenizer.decode(
139 | tmp_output_ids,
140 | skip_special_tokens=True,
141 | spaces_between_special_tokens=False,
142 | )
143 |
144 | partially_stopped = False
145 | if stop_str:
146 | if isinstance(stop_str, str):
147 | pos = output.rfind(stop_str, rfind_start)
148 | if pos != -1:
149 | output = output[:pos]
150 | stopped = True
151 | else:
152 | partially_stopped = partial_stop(output, stop_str)
153 | elif isinstance(stop_str, Iterable):
154 | for each_stop in stop_str:
155 | pos = output.rfind(each_stop, rfind_start)
156 | if pos != -1:
157 | output = output[:pos]
158 | stopped = True
159 | break
160 | else:
161 | partially_stopped = partial_stop(output, each_stop)
162 | if partially_stopped:
163 | break
164 | else:
165 | raise ValueError("Invalid stop field type.")
166 |
167 | # prevent yielding partial stop sequence
168 | if not partially_stopped:
169 | yield {
170 | "text": output,
171 | "usage": {
172 | "prompt_tokens": input_echo_len,
173 | "completion_tokens": i,
174 | "total_tokens": input_echo_len + i,
175 | },
176 | "finish_reason": None,
177 | }
178 |
179 | if stopped:
180 | break
181 |
182 | # finish stream event, which contains finish reason
183 | if i == max_new_tokens - 1:
184 | finish_reason = "length"
185 | elif stopped:
186 | finish_reason = "stop"
187 | else:
188 | finish_reason = None
189 |
190 | yield {
191 | "text": output,
192 | "usage": {
193 | "prompt_tokens": input_echo_len,
194 | "completion_tokens": i,
195 | "total_tokens": input_echo_len + i,
196 | },
197 | "finish_reason": finish_reason,
198 | }
199 |
200 | # clean
201 | del past_key_values, out
202 | gc.collect()
203 | torch.cuda.empty_cache()
204 |
205 |
206 | class ChatIO(abc.ABC):
207 | @abc.abstractmethod
208 | def prompt_for_input(self, role: str) -> str:
209 | """Prompt for input from a role."""
210 |
211 | @abc.abstractmethod
212 | def prompt_for_output(self, role: str):
213 | """Prompt for output from a role."""
214 |
215 | @abc.abstractmethod
216 | def stream_output(self, output_stream):
217 | """Stream output."""
218 |
219 |
220 | class SimpleChatIO(ChatIO):
221 | def prompt_for_input(self, role) -> str:
222 | return input(f"{role}: ")
223 |
224 | def prompt_for_output(self, role: str):
225 | print(f"{role}: ", end="", flush=True)
226 |
227 | def stream_output(self, output_stream):
228 | pre = 0
229 | for outputs in output_stream:
230 | output_text = outputs["text"]
231 | output_text = output_text.strip().split(" ")
232 | now = len(output_text) - 1
233 | if now > pre:
234 | print(" ".join(output_text[pre:now]), end=" ", flush=True)
235 | pre = now
236 | print(" ".join(output_text[pre:]), flush=True)
237 | return " ".join(output_text)
238 |
239 |
240 | def chat_loop(
241 | model_path: str,
242 | temperature: float,
243 | repetition_penalty: float,
244 | max_new_tokens: int,
245 | chatio: ChatIO,
246 | debug: bool,
247 | ):
248 | # Model
249 | # model, tokenizer = load_model(
250 | # model_path,
251 | # device,
252 | # num_gpus,
253 | # max_gpu_memory,
254 | # load_8bit,
255 | # cpu_offloading,
256 | # gptq_config,
257 | # revision,
258 | # debug,
259 | # )
260 |
261 | tokenizer = AutoTokenizer.from_pretrained(model_path)
262 |
263 | model = LlamaForCausalLM.from_pretrained(model_path)
264 |
265 | device = 'cpu'
266 |
267 | conv = llama2_template()
268 |
269 | while True:
270 | try:
271 | inp = chatio.prompt_for_input(conv.roles[0])
272 | except EOFError:
273 | inp = ""
274 |
275 | if inp == "!!exit" or not inp:
276 | print("exit...")
277 | break
278 |
279 | conv.append_message(conv.roles[0], inp)
280 | conv.append_message(conv.roles[1], None)
281 |
282 | prompt = conv.get_prompt()
283 |
284 | gen_params = {
285 | "model": model_path,
286 | "prompt": prompt,
287 | "temperature": temperature,
288 | "repetition_penalty": repetition_penalty,
289 | "max_new_tokens": max_new_tokens,
290 | "stop": conv.stop_str,
291 | "stop_token_ids": conv.stop_token_ids,
292 | "echo": False,
293 | }
294 |
295 | chatio.prompt_for_output(conv.roles[1])
296 | output_stream = generate_stream(model, tokenizer, gen_params, device)
297 | t = time.time()
298 | outputs = chatio.stream_output(output_stream)
299 | duration = time.time() - t
300 | conv.update_last_message(outputs.strip())
301 |
302 | if debug:
303 | num_tokens = len(tokenizer.encode(outputs))
304 | msg = {
305 | "conv_template": conv.name,
306 | "prompt": prompt,
307 | "outputs": outputs,
308 | "speed (token/s)": round(num_tokens / duration, 2),
309 | }
310 | print(f"\n{msg}\n")
311 |
312 |
313 | if __name__ == "__main__":
314 | chat_loop(
315 | model_path="meta-llama/Llama-2-7b-chat-hf",
316 | temperature=0.7,
317 | repetition_penalty=1.0,
318 | max_new_tokens=512,
319 | chatio=SimpleChatIO(),
320 | debug=True,
321 | )
322 |
--------------------------------------------------------------------------------
/promptcache/model/__init__.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Callable, List, Optional, Tuple
3 | import torch
4 | import re
5 |
6 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, \
7 | PretrainedConfig, PreTrainedModel, CodeLlamaTokenizer
8 |
9 | from promptcache.model.falcon import FalconForCausalLM
10 | from promptcache.model.llama2 import LlamaForCausalLM
11 | from promptcache.model.mpt import MptForCausalLM
12 | from promptcache.prompt import Preprocessor, escape_xml, PreprocessorList
13 |
14 |
15 | # supported models
16 | # gpt2
17 | # aquilla
18 | # bloom
19 | # baichuan
20 |
21 | # llama
22 | # llama2
23 | # code llama
24 | # wizardlm
25 | # mpt-chat
26 | # alpaca
27 | # chatglm
28 | # t5
29 | # falcon
30 | # dolly
31 | #
32 | #
33 | # Aquila (BAAI/Aquila-7B, BAAI/AquilaChat-7B, etc.)
34 | # Baichuan (baichuan-inc/Baichuan-7B, baichuan-inc/Baichuan-13B-Chat, etc.)
35 | # BLOOM (bigscience/bloom, bigscience/bloomz, etc.)
36 | # Falcon (tiiuae/falcon-7b, tiiuae/falcon-40b, tiiuae/falcon-rw-7b, etc.)
37 | # GPT-2 (gpt2, gpt2-xl, etc.)
38 | # GPT BigCode (bigcode/starcoder, bigcode/gpt_bigcode-santacoder, etc.)
39 | # GPT-J (EleutherAI/gpt-j-6b, nomic-ai/gpt4all-j, etc.)
40 | # GPT-NeoX (EleutherAI/gpt-neox-20b, databricks/dolly-v2-12b, stabilityai/stablelm-tuned-alpha-7b, etc.)
41 | # InternLM (internlm/internlm-7b, internlm/internlm-chat-7b, etc.)
42 | # LLaMA & LLaMA-2 (meta-llama/Llama-2-70b-hf, lmsys/vicuna-13b-v1.3, young-geng/koala, openlm-research/open_llama_13b, etc.)
43 | # MPT (mosaicml/mpt-7b, mosaicml/mpt-30b, etc.)
44 | # OPT (facebook/opt-66b, facebook/opt-iml-max-30b, etc.)
45 | # Qwen (Qwen/Qwen-7B, Qwen/Qwen-7B-Chat, etc.)
46 |
47 |
48 | class FormatConversation(Preprocessor):
49 |     system: Tuple[str, str, str]
50 |     user: Tuple[str, str]
51 |     assistant: Tuple[str, str]
52 | 
53 |     def __init__(self, system: Tuple[str, str, str], user: Tuple[str, str], assistant: Tuple[str, str]):
54 | super().__init__()
55 |
56 | self.system = (escape_xml(system[0]), escape_xml(system[1]), escape_xml(system[2]))
57 | self.user = (escape_xml(user[0]), escape_xml(user[1]))
58 | self.assistant = (escape_xml(assistant[0]), escape_xml(assistant[1]))
59 |
60 | def __call__(self, prompt: str) -> str:
61 | replacement_pairs = [
62 | ("", self.system[0]),
63 | ("", self.system[1]),
64 | ("", self.system[2]),
65 | ("", self.user[0]),
66 | ("", self.user[1]),
67 | ("", self.assistant[0]),
68 | ("", self.assistant[1])
69 | ]
70 |
71 | # remove space before
72 | prompt = re.sub(r' +', '', prompt)
73 |
74 | for old, new in replacement_pairs:
75 | prompt = prompt.replace(old, new)
76 |
77 | return prompt
78 |
79 |
80 | class FormatLlama2Conversation(FormatConversation):
81 |
82 | def __init__(self):
83 | super().__init__(
84 |             system=("[INST] <<SYS>>\n", "<</SYS>>\n\n", "[INST]"),
85 | user=(" ", "[/INST]"),
86 | assistant=(" ", "[INST]")
87 | )
88 |
89 |
90 | class LanguageModel(abc.ABC):
91 | name: str
92 | hf_tokenizer: PreTrainedTokenizer
93 | hf_model: PreTrainedModel
94 | stop_token_ids: List[int]
95 | stop_str: List[str]
96 | use_full_position_ids: bool = False
97 |
98 | def __init__(self, name: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
99 | stop_token_ids: Optional[List[int]] = None, stop_str: Optional[List[str]] = None):
100 | self.name = name
101 | self.hf_tokenizer = tokenizer
102 | self.hf_model = model
103 | self.stop_token_ids = stop_token_ids if stop_token_ids is not None else [self.eos_token_id]
104 | self.stop_str = stop_str if stop_str is not None else []
105 |
106 | @abc.abstractmethod
107 | def get_formatter(self) -> Callable[[str], str]:
108 | pass
109 |
110 | def get_cache_shape(self) -> Tuple[int, int, int]:
111 | num_head = self.config.num_attention_heads
112 | head_dim = self.config.hidden_size // self.config.num_attention_heads
113 |
114 | return self.config.num_hidden_layers, num_head, head_dim
115 |
116 | def store_k_hook(self, k_cache: torch.Tensor) -> torch.Tensor:
117 | return k_cache
118 |
119 | def store_v_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
120 | return v_cache
121 |
122 | def read_k_hook(self, k_cache: torch.Tensor) -> torch.Tensor:
123 | return k_cache
124 |
125 | def read_v_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
126 | return v_cache
127 |
128 | def __call__(self, **kwargs):
129 | return self.hf_model(**kwargs)
130 |
131 | def encode(self, text: str) -> List[int]:
132 | # Warning: this is a hack to remove bos_token
133 | token_ids = self.hf_tokenizer.encode(text, add_special_tokens=False)
134 | return token_ids
135 |
136 | def decode(self, token_ids: List[int]) -> str:
137 | return self.hf_tokenizer.decode(token_ids, skip_special_tokens=False, spaces_between_special_tokens=False)
138 |
139 | @property
140 | def unk_token(self) -> str:
141 | return self.hf_tokenizer.unk_token
142 |
143 | @property
144 | def unk_token_id(self) -> int:
145 | return self.hf_tokenizer.unk_token_id
146 |
147 | @property
148 | def eos_token(self) -> str:
149 | return self.hf_tokenizer.eos_token
150 |
151 | @property
152 | def eos_token_id(self) -> int:
153 | return self.hf_tokenizer.eos_token_id
154 |
155 | @property
156 | def device(self) -> torch.device:
157 | return self.hf_model.device
158 |
159 | @property
160 | def config(self) -> PretrainedConfig:
161 | return self.hf_model.config
162 |
163 |
164 | class CodeLlama(LanguageModel):
165 |
166 | def __init__(self, name: str = "codellama/CodeLlama-13b-Instruct-hf", **kwargs):
167 | tokenizer = CodeLlamaTokenizer.from_pretrained(name)
168 | model = LlamaForCausalLM.from_pretrained(name, **kwargs)
169 |
170 | self.formatter = FormatConversation(
171 |             system=(" [INST] <<SYS>>\n", "<</SYS>>\n\n", " [INST] "),
172 | user=("", "[/INST]"),
173 | assistant=("", " [INST] "))
174 |
175 | stop_token_ids = [tokenizer.eos_token_id]
176 |
177 |         stop_str = ["</s>"]
178 |
179 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str)
180 |
181 | def get_formatter(self) -> Callable[[str], str]:
182 | return self.formatter
183 |
184 |
185 | class Llama2(LanguageModel):
186 |
187 | def __init__(self, name: str = "meta-llama/Llama-2-7b-chat-hf", **kwargs):
188 | tokenizer = LlamaTokenizer.from_pretrained(name)
189 | model = LlamaForCausalLM.from_pretrained(name, **kwargs)
190 |
191 | self.formatter = FormatConversation(
192 |             system=(" [INST] <<SYS>>\n", "<</SYS>>\n\n", " [INST] "),
193 | user=("", "[/INST]"),
194 | assistant=("", " [INST] "))
195 |
196 | stop_token_ids = [tokenizer.eos_token_id]
197 |
198 |         stop_str = ["</s>"]
199 |
200 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str)
201 |
202 | def get_formatter(self) -> Callable[[str], str]:
203 | return self.formatter
204 |
205 |
206 | class Falcon(LanguageModel):
207 | def __init__(self, name="tiiuae/falcon-7b-instruct", **kwargs):
208 | tokenizer = AutoTokenizer.from_pretrained(name)
209 | model = FalconForCausalLM.from_pretrained(name, **kwargs)
210 |
211 | def rep(prompt: str) -> str:
212 | return prompt.replace("\r\n", "\n").replace("\n\n", "\n")
213 |
214 | # name = "falcon",
215 | # roles = ("User", "Assistant"),
216 | # messages = [],
217 | # sep_style = SeparatorStyle.RWKV,
218 | # sep = "\n",
219 | # sep2 = "<|endoftext|>",
220 | # stop_str = "\nUser",
221 |
222 | conv = FormatConversation(
223 | system=("", "\n\n", ""),
224 | user=("User: ", "\n\nAssistant:"),
225 | assistant=(" ", "\n\n"))
226 | #
227 | # conv = FormatConversation(
228 | # system=("", "\n\n", ""),
229 | # user=("<|prompt|>", "<|endoftext|>"),
230 | # assistant=("<|answer|>", "<|endoftext|>"))
231 |
232 | self.formatter = PreprocessorList([
233 | rep, conv
234 | ])
235 |
236 | stop_token_ids = [0,
237 | 1,
238 | 2,
239 | 3,
240 | 4,
241 | 5,
242 | 6,
243 | 7,
244 | 8,
245 | 9,
246 | 10,
247 | 11, ]
248 |
249 | stop_str = ["<|endoftext|>", "\nUser"]
250 |
251 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str)
252 |
253 | def get_formatter(self) -> Callable[[str], str]:
254 | return self.formatter
255 |
256 | def get_cache_shape(self) -> Tuple[int, int, int]:
257 | head_dim = self.hf_model.config.hidden_size // self.hf_model.config.num_attention_heads
258 | return self.hf_model.config.num_hidden_layers, 1, head_dim
259 |
260 |
261 | class Mpt(LanguageModel):
262 | def __init__(self, name="mosaicml/mpt-7b-chat", **kwargs):
263 | # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", trust_remote_code=True)
264 | tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
265 |
266 | model = MptForCausalLM.from_pretrained(name, max_seq_len=8192,
267 | **kwargs)
268 |
269 | conv = FormatConversation(
270 | system=("<|im_start|>system\n", "<|im_end|>\n", ""),
271 | user=("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
272 | assistant=("", "<|im_end|>\n"))
273 |
274 | self.formatter = conv
275 | self.use_full_position_ids = True
276 |
277 | stop_token_ids = [50278, 0]
278 | stop_str = []
279 |
280 | super().__init__(name, model, tokenizer, stop_token_ids, stop_str)
281 |
282 | def get_formatter(self) -> Callable[[str], str]:
283 | return self.formatter
284 |
285 | # https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/configuration_mpt.py
286 | def get_cache_shape(self) -> Tuple[int, int, int]:
287 |         head_dim = self.hf_model.config.d_model // self.hf_model.config.n_heads
288 |         return self.hf_model.config.n_layers, self.hf_model.config.n_heads, head_dim
289 | #
290 | # def store_k_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
291 | # # batch, n_layers, seq_len, head_dim = v_cache.shape
292 | # return v_cache.transpose(2, 3)
293 | #
294 | # def read_k_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
295 | # return v_cache.transpose(1, 2)
296 |
--------------------------------------------------------------------------------
/promptcache/prompt.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 | from abc import ABC, abstractmethod
5 | from typing import Union, List, Any, Optional
6 | from dataclasses import dataclass
7 |
8 | import lxml
9 | import lxml.etree
10 |
11 | import xml.sax.saxutils
12 |
13 |
14 | def repr_indent(obj: Any, indent: int = 1) -> str:
15 | return '\n'.join([indent * '\t' + s for s in repr(obj).split('\n')])
16 |
17 |
18 | def read_file(filename: str, preprocessors: List[Preprocessor] = None) -> str:
19 | with open(filename, 'r') as f:
20 |
21 | text = f.read()
22 |
23 | if preprocessors is not None:
24 | for p in preprocessors:
25 | text = p(text)
26 |
27 | return text
28 |
29 |
30 | def apply_preproc(text: str, preprocessors: List[Preprocessor] = None) -> str:
31 | if preprocessors is not None:
32 | for p in preprocessors:
33 | text = p(text)
34 |
35 | return text
36 |
37 |
38 | def escape_xml(data):
39 | return xml.sax.saxutils.escape(data, entities={
40 |         "'": "&apos;",
41 |         "\"": "&quot;"
42 | })
43 |
44 |
45 | # Replaces multiple leading and trailing whitespaces with a single space.
46 | def compact_surrounding_spaces(text: str) -> str:
47 | return re.sub(r'^\s+|\s+$', ' ', text)
48 |
49 |
50 | def compact_spaces(text: str) -> str:
51 | return ' '.join(text.split())
52 |
53 |
54 | class Preprocessor(ABC):
55 | """This class is used to preprocess the prompt before it is passed to the model."""
56 |
57 | def __init__(self):
58 | pass
59 |
60 | @abstractmethod
61 | def __call__(self, prompt: str) -> str:
62 | raise NotImplementedError
63 |
64 |
65 | class PreprocessorList(Preprocessor):
66 |     """Applies a sequence of preprocessors to the prompt in order."""
67 | pre: List[Preprocessor]
68 |
69 | def __init__(self, pre: List[Preprocessor]):
70 | self.pre = pre
71 |
72 | super().__init__()
73 |
74 | def __call__(self, prompt: str) -> str:
75 | for p in self.pre:
76 | prompt = p(prompt)
77 | return prompt
78 |
79 |
80 | class CompactSpaces(Preprocessor):
81 |     """Compacts whitespace: all runs by default, or only the leading/trailing runs if only_surrounding is True."""
82 | only_surrounding: bool
83 |
84 | def __init__(self, only_surrounding: bool = False):
85 | super().__init__()
86 | self.only_surrounding = only_surrounding
87 |
88 | def __call__(self, prompt: str) -> str:
89 |
90 | if self.only_surrounding:
91 | return compact_surrounding_spaces(prompt)
92 | else:
93 | return compact_spaces(prompt)
94 |
95 |
96 | class ModuleRef:
97 | """This class is used to represent a module reference in the prompt."""
98 | name: str
99 | args: List[Argument]
100 | modules: List[ModuleRef]
101 |
102 | def __init__(self, spec: lxml.etree.Element = None):
103 | if spec is not None:
104 | self._process(spec)
105 |
106 | def _process(self, root: lxml.etree.Element):
107 | self.name = root.tag
108 | self.args = [Argument(name, value) for name, value in root.attrib.items()]
109 |
110 | self.modules = []
111 |
112 | # leading text
113 | if root.text is not None and len(root.text.strip()) > 0:
114 | raise ValueError("Module reference cannot have text")
115 |
116 | for e in root:
117 | self.modules.append(ModuleRef(e))
118 | if e.tail is not None and len(e.tail.strip()) > 0:
119 | raise ValueError("Module reference cannot have text")
120 |
121 | def __repr__(self) -> str:
122 |
123 | args = " ".join([f"{arg.name}={repr(arg.value)}" for arg in self.args])
124 | if len(args) > 0:
125 | r = f"@{self.name}({args})"
126 | else:
127 | r = f"@{self.name}"
128 |
129 | for m in self.modules:
130 | r += '\n' + repr_indent(m)
131 | return r
132 |
133 |
134 | @dataclass
135 | class Argument:
136 | name: str
137 | value: str
138 |
139 |
140 | class Prompt(ModuleRef):
141 | """This class is used to represent a prompt."""
142 | schema: str
143 | text: str
144 |
145 | preproc: List[Preprocessor]
146 |
147 | def __init__(self, spec: Union[str, lxml.etree.Element], preproc: Optional[List[Preprocessor]] = None):
148 |
149 | super().__init__()
150 | self.args = []
151 | self.preproc = preproc if preproc is not None else []
152 | self.text = ""
153 | if type(spec) == str:
154 |
155 | for p in self.preproc:
156 | spec = p(spec)
157 |
158 | parser = lxml.etree.XMLParser(recover=True)
159 | spec = lxml.etree.fromstring(spec, parser=parser)
160 |
161 | self._process(spec)
162 |
163 | def _process(self, root: lxml.etree.Element):
164 |
165 | assert root.tag == "prompt"
166 |
167 | if "schema" in root.attrib:
168 | self.schema = root.attrib["schema"]
169 | else:
170 | self.schema = "" # empty schema
171 |
172 | self.name = self.schema
173 | self.modules = []
174 |
175 | # text only prompt
176 | if len(root) == 0:
177 | text = root.text
178 | if text is None:
179 | raise ValueError("Prompt cannot be empty")
180 |
181 | self.text = compact_surrounding_spaces(text)
182 |
183 | else:
184 | if root.text is not None and len(root.text.strip()) > 0:
185 | raise ValueError("Prompt cannot have leading text")
186 |
187 | tail = False
188 | for e in root:
189 | if tail:
190 | raise ValueError("Prompt cannot have text between module references")
191 |
192 | self.modules.append(ModuleRef(e))
193 |
194 |                 if e.tail is not None and len(e.tail.strip()) > 0:
195 |                     self.text = compact_surrounding_spaces(e.tail)
196 |                     tail = True
196 |
197 | self.text = self.text.strip()
198 |
199 | def add_text(self, text: str):
200 | for p in self.preproc:
201 | text = p(text)
202 |
203 | self.text += text
204 |
205 | def __repr__(self) -> str:
206 | r = f"Schema: @{self.name}"
207 | for m in self.modules:
208 | r += '\n' + repr_indent(m)
209 | r += '\nText: ' + repr(self.text)
210 | return r
211 |
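212 | 
213 | # Usage sketch (illustrative only; the schema and module names below are hypothetical):
214 | if __name__ == '__main__':
215 |     example = Prompt('<prompt schema="persona"><background/> Describe this persona.</prompt>',
216 |                      preproc=[CompactSpaces(only_surrounding=True)])
217 |     # Prints the parsed structure: schema @persona, one @background module reference,
218 |     # and the trailing text 'Describe this persona.'
219 |     print(example)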
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | astor==0.8.1
2 | datasets==2.14.7
3 | fire==0.6.0
4 | fuzzywuzzy==0.18.0
5 | jieba==0.42.1
6 | lxml==5.2.1
7 | numpy==1.26.4
8 | Requests==2.31.0
9 | rouge==1.0.1
10 | termcolor==2.4.0
11 | torch==2.3.0
12 | tqdm==4.66.4
13 | transformers==4.34.0
--------------------------------------------------------------------------------
/score.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import json
3 |
4 | from metrics import (
5 | qa_f1_score,
6 | rouge_zh_score,
7 | qa_f1_zh_score,
8 | rouge_score,
9 | classification_score,
10 | retrieval_score,
11 | retrieval_zh_score,
12 | count_score,
13 | code_sim_score
14 | )
15 |
16 | dataset2metric = {
17 | "narrativeqa": qa_f1_score,
18 | "qasper": qa_f1_score,
19 | "multifieldqa_en": qa_f1_score,
20 | "multifieldqa_zh": qa_f1_zh_score,
21 | "hotpotqa": qa_f1_score,
22 | "2wikimqa": qa_f1_score,
23 | "musique": qa_f1_score,
24 | "dureader": rouge_zh_score,
25 | "gov_report": rouge_score,
26 | "qmsum": rouge_score,
27 | "multi_news": rouge_score,
28 | "vcsum": rouge_zh_score,
29 | "trec": classification_score,
30 | "triviaqa": qa_f1_score,
31 | "samsum": rouge_score,
32 | "lsht": classification_score,
33 | "passage_retrieval_en": retrieval_score,
34 | "passage_count": count_score,
35 | "passage_retrieval_zh": retrieval_zh_score,
36 | "lcc": code_sim_score,
37 | "repobench-p": code_sim_score,
38 | }
39 |
40 |
41 | def main(result_file):
42 | dataset_name = result_file.split('/')[-3].split('-')[1]
43 | with open(result_file, 'r') as f:
44 | # load line by line
45 | results = [json.loads(line) for line in f]
46 | total_score = 0.
47 | for result in results:
48 | response = result["response"]
49 | answers = result["answers"]
50 |
51 | score = 0.
52 | for answer in answers:
53 | score = max(score, dataset2metric[dataset_name](response, answer))
54 |
55 | total_score += score
56 | print(total_score / len(results) * 100)
57 |
58 |
59 | if __name__ == '__main__':
60 | fire.Fire(main)
61 |
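62 | # Example invocation (illustrative; the path below is hypothetical). main() takes the
63 | # dataset name from the second '-'-separated field of the third-to-last path component,
64 | # then averages the per-example metric and prints it scaled to 0-100:
65 | #   python score.py results/llama2-narrativeqa-cache/split_0/results.jsonl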
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # How to run the scripts
2 | Inside this directory you can simply run:
3 | ```bash
4 | python run_benchmarks.py
5 | ```
6 | It will automatically run all the benchmarks defined in `benchmark_setup.json`.
7 | Note: you may want to clean up `../benchmark/results` before re-running (see the example below).
8 |
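9 | For example, to start from a clean slate (a sketch, assuming the default relative paths):
10 | ```bash
11 | # remove stale outputs from earlier runs, then run every benchmark in benchmark_setup.json
12 | rm -rf ../benchmark/results
13 | python run_benchmarks.py
14 | ```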
--------------------------------------------------------------------------------
/scripts/benchmark_setup.json:
--------------------------------------------------------------------------------
1 | {
2 | "default": {
3 | "llm_config_path": "./config/llm_config_llama2.json",
4 | "split": "0,1"
5 | },
6 | "benchmarks": [
7 |
8 | {"dataset": "narrativeqa"},
9 | {"dataset": "qasper"},
10 | {"dataset": "multifieldqa_en"},
11 | {"dataset": "hotpotqa"},
12 | {"dataset": "2wikimqa"},
13 | {"dataset": "musique"},
14 | {"dataset": "gov_report"},
15 | {"dataset": "qmsum"},
16 | {"dataset": "multi_news_long"},
17 | {"dataset": "trec"},
18 | {"dataset": "triviaqa"},
19 | {"dataset": "passage_count"},
20 | {"dataset": "passage_retrieval_en"}
21 | ],
22 | "okay_so_far_benchmarks" :[
23 |
24 | ],
25 | "skipped_benchmarks": [
26 |     {"dataset": "multi_news", "reason": "We will use the LongBench version, multi_news_long"},
27 | {"dataset": "lcc", "reason": "code syntax interferes with xml syntax of prompt cache"},
28 | {"dataset": "repobench-p", "reason": "code syntax interferes with xml syntax of prompt cache"},
29 | {"dataset": "samsum", "reason": "code syntax interferes with xml syntax of prompt cache"},
30 | {"dataset": "lsht", "reason": "out of memory"},
31 | {"dataset": "vcsum", "reason": "out of memory"},
32 | {"dataset": "dureader", "reason": "out of memory"},
33 | {"dataset": "ms_marco", "reason": "bad output"}
34 | ],
35 | "llm_list": [
36 | {
37 | "llm": "falcon",
38 | "config_name": "./config/llm_config_falcon_40b.json"
39 | },
40 | {
41 | "llm": "mpt",
42 | "config_name": "./config/llm_config_mpt_30b.json"
43 | },
44 | {
45 | "llm": "llama2",
46 | "config_name": "./config/llm_config_llama2_13b.json"
47 | }
48 | ]
49 | }
50 |
--------------------------------------------------------------------------------
/scripts/eval_llama2.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | #SBATCH --mem=128g
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --cpus-per-task=32 # <- match to OMP_NUM_THREADS
6 | #SBATCH --partition=gpuA100x4 # <- or one of: gpuA100x4 gpuA40x4 gpuA100x8 gpuMI100x8
7 | #SBATCH --account=bccn-delta-gpu
8 | #SBATCH --job-name=eval_acc
9 | #SBATCH --time=10:00:00 # hh:mm:ss for the job
10 | ### GPU options ###
11 | #SBATCH --gpus-per-node=4
12 | ##SBATCH --gpu-bind=none # <- or closest
13 | #SBATCH --mail-user=in.gim@yale.edu
14 | #SBATCH --mail-type="BEGIN,END"
15 |
16 | module reset # drop modules and explicitly load the ones needed
17 | # (good job metadata and reproducibility)
18 | # $WORK and $SCRATCH are now set
19 | module load python # ... or any appropriate modules
20 | module list # job documentation and metadata
21 | echo "job is starting on `hostname`"
22 | srun python3 run_benchmarks.py
--------------------------------------------------------------------------------
/scripts/run_benchmarks.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import logging
3 | import json
4 | import os
5 | import sys
6 | import datetime
7 | from concurrent.futures import ThreadPoolExecutor
8 | import threading
9 | import queue
10 |
11 |
12 | def detect_nvidia_gpus():
13 | try:
14 | result = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], capture_output=True, text=True)
15 | num_gpus = len(result.stdout.strip().split('\n'))
16 | return num_gpus
17 | except Exception as e:
18 | logging.error(f"An error occurred: {e}")
19 | return 0
20 |
21 | def read_args_from_json(json_path):
22 | with open(json_path, 'r') as f:
23 | return json.load(f)
24 |
25 | def construct_python_commands(default_args, benchmarks, llm_list):
26 | commands = []
27 | for llm in llm_list:
28 | _llm_name = llm["llm"]
29 | llm_config = llm["config_name"]
30 | for benchmark in benchmarks:
31 | for enable_cache in [True, False]:
32 | split = benchmark.get("split", 1)
33 | for index in range(split):
34 | command = "python3 eval.py"
35 | merged_args = {**default_args, **benchmark, "llm_config_path": llm_config}
36 | for key, value in merged_args.items():
37 | if key == "enable_cache":
38 | continue
39 | elif key == "split":
40 | command += f" --split {index},{split}"
41 | else:
42 | command += f" --{key} {value}"
43 | command += f" --enable_cache {enable_cache}"
44 | command += f" --test_latency False"
45 | commands.append(command)
46 | return commands
47 |
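48 | # Example of a command constructed above (illustrative; uses the shipped defaults in
49 | # benchmark_setup.json and the falcon entry from llm_list):
50 | #   python3 eval.py --llm_config_path ./config/llm_config_falcon_40b.json --split 0,1 --dataset narrativeqa --enable_cache True --test_latency False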
48 | global python_commands_list
49 | def gpu_worker(gpu_id, command_lock):
50 | while True:
51 | with command_lock:
52 | global next_command_index
53 | if next_command_index >= len(python_commands_list):
54 | return # All commands are executed, so exit the thread
55 | command = python_commands_list[next_command_index]
56 | next_command_index += 1
57 |
58 | env_command = f"CUDA_VISIBLE_DEVICES={gpu_id} " + command
59 | try:
60 | subprocess.run("cd .. && " + env_command, shell=True, check=True)
61 | logging.info(f"Worker using GPU {gpu_id}: Command {command} completed successfully.")
62 | except subprocess.CalledProcessError as e:
63 | logging.error(f"Worker using GPU {gpu_id}: Command {command} failed: {e}")
64 |
65 |
66 | def main():
67 | # logging.basicConfig(level=logging.INFO)
68 |     logging.basicConfig(filename=f'benchmark_results_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
69 | num_gpus = detect_nvidia_gpus()
70 | logging.info(f"Detected {num_gpus} Nvidia GPUs.")
71 |
72 | # Read arguments from JSON file
73 | args_dict = read_args_from_json("benchmark_setup.json")
74 |
75 | global python_commands_list
76 | # Construct the Python commands
77 | python_commands_list = construct_python_commands(args_dict["default"], args_dict["benchmarks"], args_dict["llm_list"])
78 | logging.info(f"Constructed {len(python_commands_list)} benchmarks.")
79 |
80 | global next_command_index
81 | next_command_index = 0
82 | command_lock = threading.Lock()
83 | # Start a thread for each GPU
84 | threads = []
85 | for gpu_id in range(num_gpus):
86 | t = threading.Thread(target=gpu_worker, args=(gpu_id, command_lock))
87 | t.start()
88 | threads.append(t)
89 |
90 | # Wait for all threads to finish
91 | for t in threads:
92 | t.join()
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device='cpu'):
5 | r"""
6 | Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
7 | relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
8 | the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
9 | https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
10 | """
11 | alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
12 | num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
13 |
14 | base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device)
15 | base = base * (alibi_bias_max / num_heads_power_of_2)
16 |
17 | slopes = 1.0 / torch.pow(2, base)
18 | slopes = slopes.view(1, num_heads, 1, 1)
19 |
20 | if num_heads_power_of_2 != num_heads:
21 | slopes = torch.concat([slopes[1::2], slopes[::2]])[:num_heads]
22 |
23 | alibi = alibi * slopes
24 | return alibi.squeeze(0)
25 |
26 |
27 | table = build_mpt_alibi_tensor(16, 20)
28 | print(table.shape)
29 | print(table)
--------------------------------------------------------------------------------
/tests/test_document_summary.py:
--------------------------------------------------------------------------------
1 | # Add the parent directory to the sys.path list
2 | import os, sys
3 | document_summary_path = os.path.abspath(os.path.dirname(__file__))
4 | sys.path.append(os.path.abspath(os.path.join(document_summary_path, '..')))
5 |
6 | import unittest
7 | from benchmark.multi_news import MultiNews
8 |
9 | ## The following code is for testing the MultiNews class
10 | class TestMultiNews(unittest.TestCase):
11 | def setUp(self):
12 | self.document_summary = MultiNews()
13 | self.document_summary.init(verbose=False)
14 |
15 | def test_get_next_query(self):
16 | query_id, query, modules = self.document_summary.get_next_query()
17 | self.assertIsInstance(query_id, int)
18 | self.assertGreaterEqual(query_id, 0)
19 | self.assertIsInstance(query, str)
20 | self.assertEqual(query, "") # query is empty string for summary benchmark
21 | self.assertIsInstance(modules, list)
22 | self.assertGreaterEqual(len(modules), 1)
23 | for module in modules:
24 | self.assertIsInstance(module, str)
25 |
26 | def test_evaluate(self):
27 | query_id, query, modules = self.document_summary.get_next_query()
28 |         score = self.document_summary.evaluate(query_id, "response")
29 |         self.assertIsInstance(score, float)
30 |         self.assertGreaterEqual(score, 0.)
31 |         self.assertLessEqual(score, 1.)
32 |
33 | if __name__ == '__main__':
34 | unittest.main()
35 |
--------------------------------------------------------------------------------