├── LICENSE ├── README.md ├── constants.py ├── datasets_loader.py ├── experiment_manager.py ├── logits_processor.py ├── model_loaders.py ├── modeling_gpt2_with_pcw.py ├── modeling_llama_with_pcw.py ├── pcw_wrapper.py ├── requirements.in ├── requirements.txt ├── run_evaluation.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parallel Context Windows (PCW) 2 | 3 | This repo contains the code for reproducing the classification experiments from [AI21 Labs](https://www.ai21.com/)' paper [Parallel Context Windows for Large Language Models 4 | ](https://arxiv.org/abs/2212.10947). 5 | The code was tested with python 3.10, for CPU, GPU and multiple GPU runs. Currently, the code supports using GPT2 and LLaMa model families. 
6 | 7 | ## Setup 8 | 9 | To install the required libraries, run: 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | To use a PyTorch build that matches your CUDA version, [install](https://pytorch.org/) it before running the command above. 14 | 15 | ## Evaluation 16 | Because the paper's results were based on an earlier implementation of PCW rather than [HuggingFace Transformers](https://huggingface.co/docs/transformers/index), results produced with this code may differ slightly from those shown in the paper. 17 | To reproduce results similar to those shown in the appendix for GPT2-XL on a specific dataset (for example, SST2), simply run: 18 | ```bash 19 | python run_evaluation.py \ 20 | --dataset sst2 \ 21 | --model gpt2-xl \ 22 | --n-windows 1 \ 23 | --n-windows 3 \ 24 | --subsample-test-set 250 \ 25 | --n-runs 30 \ 26 | --output-dir $OUTPUT_DIR 27 | ``` 28 | In this run, PCW's performance is evaluated on a subsample (250 samples) of the full test set. 29 | The experiment is repeated 30 times (with different random samples of training examples) for each number of windows (here, one and three). 30 | By default, the script fits as many examples into each window as possible. 31 | Note that using a single window is equivalent to the regular ICL setting, so this run should give results similar to those shown in Table 5 for SST2 with GPT2-XL. 32 | 33 | The evaluation output is a numpy file (shaped `[2,30]`), written to `$OUTPUT_DIR`, with the mean accuracy for each repetition and number of windows. 34 | You can read the file directly with `np.load`, or use the helper functions in `utils.py` to load and plot the results. 35 | See `--help` for further instructions. 36 |
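For instance, a minimal sketch for inspecting such a results file (the file name below is illustrative; use the `.npy` file that the run writes to your `$OUTPUT_DIR`):
```python
import numpy as np

# Illustrative path -- substitute the .npy file found in your $OUTPUT_DIR.
accuracies = np.load('output/sst2_gpt2-xl.npy')  # shape: [n_window_counts, n_runs], e.g. [2, 30]
for n_windows, accs in zip([1, 3], accuracies):
    print(f'{n_windows} window(s): mean accuracy {accs.mean():.3f} (std {accs.std():.3f})')
```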
37 | ## PCW Usage examples 38 | In the evaluation code, only classification tasks are performed. 39 | The code snippet below shows how PCW can be used for both classification and generation: 40 | 41 | ```python 42 | import numpy as np 43 | 44 | from model_loaders import load_pcw_wrapper 45 | from logits_processor import RestrictiveTokensLogitsProcessor 46 | 47 | from utils import encode_labels 48 | 49 | wrapper = load_pcw_wrapper('gpt2-large', n_windows=2) 50 | 51 | # example: use PCW with few-shot classification: 52 | labels_input_ids = np.array(encode_labels(wrapper.tokenizer, ['positive', 'negative'])) 53 | # using RestrictiveTokensLogitsProcessor forces the output to be one of the labels: 54 | logit_processor = RestrictiveTokensLogitsProcessor(labels_input_ids, eos_token_id=wrapper.tokenizer.eos_token_id) 55 | output = wrapper.pcw_generate(contexts=["Review: Great movie!\nSentiment: positive\n", 56 | "Review: Horrible film\nSentiment: negative\n"], 57 | task_text="Review: I liked it\nSentiment:", 58 | restrictive_logit_preprocessor=logit_processor, 59 | temperature=0, 60 | max_new_tokens=1) 61 | print(output.strip()) 62 | # use PCW for generation: 63 | output = wrapper.pcw_generate(contexts=["Review: Great movie!\n", "Review: Horrible film\n"], 64 | task_text="Review:", 65 | temperature=1, 66 | do_sample=True, 67 | max_new_tokens=16) 68 | print(output) 69 | ``` 70 | 71 | ## Citation 72 | 73 | If you find our paper or code helpful, please consider citing our paper: 74 | ``` 75 | @misc{ratner2023parallel, 76 | title={Parallel Context Windows for Large Language Models}, 77 | author={Nir Ratner and Yoav Levine and Yonatan Belinkov and Ori Ram and Inbal Magar and Omri Abend and Ehud Karpas and Amnon Shashua and Kevin Leyton-Brown and Yoav Shoham}, 78 | year={2023}, 79 | eprint={2212.10947}, 80 | archivePrefix={arXiv}, 81 | primaryClass={cs.CL} 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | SPLIT_TOKEN = "==" 2 | TEXT_BETWEEN_SHOTS = f"\n{SPLIT_TOKEN}\n" 3 | N_TOKENS = 'n_tokens' 4 | PROMPTS = 'prompts' 5 | -------------------------------------------------------------------------------- /datasets_loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | from typing import Dict, Optional 4 | 5 | import pandas as pd 6 | from datasets import load_dataset 7 | 8 | from constants import PROMPTS 9 | 10 | 11 | UTTERANCE_PREFIX = 'utterance: ' 12 | 13 | INTENT_PREFIX = 'intent: ' 14 | 15 | LABEL_TOKENS = 'label_tokens' 16 | 17 | _logger = logging.getLogger(__name__) 18 | logging.basicConfig(level=logging.INFO, format='%(message)s') 19 | 20 | 21 | class ClassificationDatasetAccess(ABC): 22 | name: str 23 | dataset: Optional[str] = None 24 | subset: Optional[str] = None 25 | x_column: str = 'text' 26 | y_label: str = 'label' 27 | x_prefix: str = "Review: " 28 | y_prefix: str = "Sentiment: " 29 | label_mapping: Optional[Dict] = None 30 | map_labels: bool = True 31 | 32 | def __init__(self): 33 | super().__init__() 34 | if self.dataset is None: 35 | self.dataset = self.name 36 | train_dataset, test_dataset = self._load_dataset() 37 | train_df = train_dataset.to_pandas() 38 | test_df = test_dataset.to_pandas() 39 | _logger.info(f"loaded {len(train_df)} training samples & {len(test_df)} test samples") 40 | 41 | if self.map_labels: 42 | hf_default_labels = train_dataset.features[self.y_label] 43 | default_label_mapping = dict(enumerate(hf_default_labels.names)) if hasattr( 44 | train_dataset.features[self.y_label], 'names') else None 45 | self._initialize_label_mapping(default_label_mapping) 46 | 47 | self.train_df = self.apply_format(train_df) 48 | self.test_df = self.apply_format(test_df, test=True) 49 | 50 | def _initialize_label_mapping(self, default_label_mapping): 51 | if self.label_mapping: 52 | _logger.info("overriding default label mapping") 53 | if default_label_mapping: 54 | _logger.info([f"{default_label_mapping[k]} -> " 55 | f"{self.label_mapping[k]}" for k in self.label_mapping.keys()]) 56 | else: 57 | _logger.info(f"using default label mapping: {default_label_mapping}") 58 | self.label_mapping = default_label_mapping 59 | 60 | def _load_dataset(self): 61 | if self.subset is not None: 62 | dataset = load_dataset(self.dataset, 
self.subset) 63 | else: 64 | dataset = load_dataset(self.dataset) 65 | if 'validation' in dataset: 66 | return dataset['train'], dataset['validation'] 67 | if 'test' not in dataset: 68 | _logger.info("no test or validation found, splitting train set instead") 69 | dataset = dataset['train'].train_test_split(seed=42) 70 | 71 | return dataset['train'], dataset['test'] 72 | 73 | def generate_x_text(self, df: pd.DataFrame) -> pd.DataFrame: 74 | return df 75 | 76 | def generate_y_token_labels(self, df, test): 77 | if self.map_labels: 78 | df[LABEL_TOKENS] = df[self.y_label].map(self.label_mapping) 79 | else: 80 | df[LABEL_TOKENS] = df[self.y_label] 81 | return df 82 | 83 | @property 84 | def labels(self): 85 | if self.map_labels: 86 | return self.label_mapping.values() 87 | else: 88 | return self.test_df[LABEL_TOKENS].unique() 89 | 90 | def apply_format(self, df, test=False): 91 | df = self.generate_x_text(df) 92 | df = self.generate_y_token_labels(df, test) 93 | if test: 94 | df[PROMPTS] = df.apply(lambda x: f"{self.x_prefix}{x[self.x_column]}\n{self.y_prefix}".rstrip(), axis=1) 95 | else: 96 | df[PROMPTS] = df.apply(lambda x: f"{self.x_prefix}{x[self.x_column]}\n{self.y_prefix}{x[LABEL_TOKENS]}", 97 | axis=1) 98 | return df 99 | 100 | 101 | class SST5(ClassificationDatasetAccess): 102 | name = 'sst5' 103 | dataset = 'SetFit/sst5' 104 | label_mapping = {0: 'terrible', 1: 'bad', 2: 'okay', 3: 'good', 4: 'great'} 105 | 106 | 107 | class RTE(ClassificationDatasetAccess): 108 | name = 'rte' 109 | dataset = 'super_glue' 110 | subset = 'rte' 111 | x_prefix = '' 112 | y_prefix = 'prediction: ' 113 | label_mapping = {0: 'True', 1: 'False'} 114 | 115 | def generate_x_text(self, df: pd.DataFrame) -> pd.DataFrame: 116 | df['text'] = df.apply(lambda x: f"premise: {x['premise']}\nhypothesis: {x['hypothesis']}", axis=1) 117 | return df 118 | 119 | 120 | class CB(RTE): 121 | name = 'cb' 122 | subset = 'cb' 123 | label_mapping = {0: 'true', 1: 'false', 2: 'neither'} 124 | 125 | 126 | class SUBJ(ClassificationDatasetAccess): 127 | name = 'subj' 128 | dataset = 'SetFit/subj' 129 | label_mapping = {0: 'objective', 1: 'subjective'} 130 | x_prefix = 'Input: ' 131 | y_prefix = 'Type: ' 132 | 133 | 134 | class CR(ClassificationDatasetAccess): 135 | name = 'cr' 136 | dataset = 'SetFit/CR' 137 | label_mapping = {0: 'negative', 1: 'positive'} 138 | 139 | 140 | class AGNEWS(ClassificationDatasetAccess): 141 | name = 'agnews' 142 | dataset = 'ag_news' 143 | label_mapping = {0: 'world', 1: 'sports', 2: 'business', 3: 'technology'} 144 | x_prefix = 'input: ' 145 | y_prefix = 'type: ' 146 | 147 | 148 | class DBPEDIA(ClassificationDatasetAccess): 149 | name = 'dbpedia' 150 | dataset = 'dbpedia_14' 151 | label_mapping = {0: 'company', 152 | 1: 'school', 153 | 2: 'artist', 154 | 3: 'athlete', 155 | 4: 'politics', 156 | 5: 'transportation', 157 | 6: 'building', 158 | 7: 'nature', 159 | 8: 'village', 160 | 9: 'animal', 161 | 10: 'plant', 162 | 11: 'album', 163 | 12: 'film', 164 | 13: 'book'} 165 | x_prefix = 'input: ' 166 | y_prefix = 'type: ' 167 | 168 | def generate_x_text(self, df: pd.DataFrame) -> pd.DataFrame: 169 | df['text'] = df['content'] 170 | return df 171 | 172 | 173 | class SST2(ClassificationDatasetAccess): 174 | name = 'sst2' 175 | 176 | def generate_x_text(self, df: pd.DataFrame) -> pd.DataFrame: 177 | df['text'] = df['sentence'] 178 | return df 179 | 180 | 181 | class TREC(ClassificationDatasetAccess): 182 | name = 'trec' 183 | y_label = 'coarse_label' 184 | x_prefix = "Question: " 185 | y_prefix = "Type: " 186 
| label_mapping = {0: "abbreviation", 1: "entity", 2: "description", 3: "human", 4: "location", 5: 'numeric'} 187 | 188 | 189 | class TRECFINE(ClassificationDatasetAccess): 190 | name = 'trecfine' 191 | dataset = 'trec' 192 | y_label = 'fine_label' 193 | x_prefix = "Question: " 194 | y_prefix = "Type: " 195 | # labels mapping based on: https://aclanthology.org/C16-1116.pdf, https://aclanthology.org/C02-1150.pdf 196 | label_mapping = {0: 'abbreviation abbreviation', 197 | 1: 'abbreviation expansion', 198 | 2: 'entity animal', 199 | 3: 'entity body', 200 | 4: 'entity color', 201 | 5: 'entity creation', 202 | 6: 'entity currency', 203 | 7: 'entity disease', 204 | 8: 'entity event', 205 | 9: 'entity food', 206 | 10: 'entity instrument', 207 | 11: 'entity language', 208 | 12: 'entity letter', 209 | 13: 'entity other', 210 | 14: 'entity plant', 211 | 15: 'entity product', 212 | 16: 'entity religion', 213 | 17: 'entity sport', 214 | 18: 'entity substance', 215 | 19: 'entity symbol', 216 | 20: 'entity technique', 217 | 21: 'entity term', 218 | 22: 'entity vehicle', 219 | 23: 'entity word', 220 | 24: 'description definition', 221 | 25: 'description description', 222 | 26: 'description manner', 223 | 27: 'description reason', 224 | 28: 'human group', 225 | 29: 'human individual', 226 | 30: 'human title', 227 | 31: 'human description', 228 | 32: 'location city', 229 | 33: 'location country', 230 | 34: 'location mountain', 231 | 35: 'location other', 232 | 36: 'location state', 233 | 37: 'numeric code', 234 | 38: 'numeric count', 235 | 39: 'numeric date', 236 | 40: 'numeric distance', 237 | 41: 'numeric money', 238 | 42: 'numeric order', 239 | 43: 'numeric other', 240 | 44: 'numeric period', 241 | 45: 'numeric percent', 242 | 46: 'numeric speed', 243 | 47: 'numeric temperature', 244 | 48: 'numeric size', 245 | 49: 'numeric weight'} 246 | 247 | 248 | class YELP(ClassificationDatasetAccess): 249 | name = 'yelp' 250 | dataset = 'yelp_review_full' 251 | x_prefix = 'review: ' 252 | y_prefix = 'stars: ' 253 | label_mapping = {0: '1', 1: '2', 2: '3', 3: '4', 4: '5'} 254 | 255 | 256 | class BANKING77(ClassificationDatasetAccess): 257 | name = 'banking77' 258 | x_prefix = 'query: ' 259 | y_prefix = INTENT_PREFIX 260 | 261 | def _initialize_label_mapping(self, default_label_mapping): 262 | default_label_mapping = {k: v.replace('_', ' ') for k, v in default_label_mapping.items()} 263 | super()._initialize_label_mapping(default_label_mapping) 264 | 265 | 266 | class NLU(ClassificationDatasetAccess): 267 | name = 'nlu' 268 | dataset = 'nlu_evaluation_data' 269 | x_prefix = UTTERANCE_PREFIX 270 | y_prefix = INTENT_PREFIX 271 | label_mapping = {0: 'alarm query', 1: 'alarm remove', 2: 'alarm set', 3: 'audio volume down', 272 | 4: 'audio volume mute', 5: 'audio volume other', 6: 'audio volume up', 7: 'calendar query', 273 | 8: 'calendar remove', 9: 'calendar set', 10: 'cooking query', 11: 'cooking recipe', 274 | 12: 'datetime convert', 13: 'datetime query', 14: 'email add contact', 15: 'email query', 275 | 16: 'email query contact', 17: 'email sendemail', 18: 'general affirm', 19: 'general command stop', 276 | 20: 'general confirm', 21: 'general dont care', 22: 'general explain', 23: 'general greet', 277 | 24: 'general joke', 25: 'general negate', 26: 'general praise', 27: 'general quirky', 278 | 28: 'general repeat', 29: 'iot cleaning', 30: 'iot coffee', 31: 'iot hue light change', 279 | 32: 'iot hue light dim', 33: 'iot hue light off', 34: 'iot hue lighton', 35: 'iot hue light up', 280 | 36: 'iot wemo off', 37: 'iot 
wemo on', 38: 'lists create or add', 39: 'lists query', 281 | 40: 'lists remove', 41: 'music dislikeness', 42: 'music likeness', 43: 'music query', 282 | 44: 'music settings', 45: 'news query', 46: 'play audiobook', 47: 'play game', 48: 'play music', 283 | 49: 'play podcasts', 50: 'play radio', 51: 'qa currency', 52: 'qa definition', 53: 'qa factoid', 284 | 54: 'qa maths', 55: 'qa stock', 56: 'recommendation events', 57: 'recommendation locations', 285 | 58: 'recommendation movies', 59: 'social post', 60: 'social query', 61: 'takeaway order', 286 | 62: 'takeaway query', 63: 'transport query', 64: 'transport taxi', 65: 'transport ticket', 287 | 66: 'transport traffic', 67: 'weather query'} 288 | 289 | 290 | class NLUSCENARIO(ClassificationDatasetAccess): 291 | name = 'nluscenario' 292 | dataset = 'nlu_evaluation_data' 293 | x_prefix = UTTERANCE_PREFIX 294 | y_prefix = 'scenario: ' 295 | y_label = 'scenario' 296 | map_labels = False 297 | 298 | 299 | class CLINIC150(BANKING77): 300 | name = "clinic150" 301 | dataset = 'clinc_oos' 302 | subset = 'plus' 303 | y_label = "intent" 304 | x_prefix = UTTERANCE_PREFIX 305 | y_prefix = INTENT_PREFIX 306 | 307 | 308 | DATASET_NAMES2LOADERS = {'sst5': SST5, 'sst2': SST2, 'agnews': AGNEWS, 'dbpedia': DBPEDIA, 'trec': TREC, 'cr': CR, 309 | 'cb': CB, 'rte': RTE, 'subj': SUBJ, 'yelp': YELP, 'banking77': BANKING77, 310 | 'nlu': NLU, 'nluscenario': NLUSCENARIO, 'trecfine': TRECFINE, 311 | 'clinic150': CLINIC150} 312 | 313 | if __name__ == '__main__': 314 | for ds_name, da in DATASET_NAMES2LOADERS.items(): 315 | _logger.info(ds_name) 316 | _logger.info(da().train_df[PROMPTS].iloc[0]) 317 | -------------------------------------------------------------------------------- /experiment_manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import Dict, List, Optional 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | from constants import TEXT_BETWEEN_SHOTS, N_TOKENS, PROMPTS 11 | from datasets_loader import LABEL_TOKENS 12 | from pcw_wrapper import PCWModelWrapper 13 | from logits_processor import RestrictiveTokensLogitsProcessor 14 | from utils import n_tokens_in_prompt, encode_labels, encode_stop_seq 15 | 16 | _logger = logging.getLogger(__name__) 17 | logging.basicConfig(level=logging.INFO, format='%(message)s') 18 | 19 | STOP_SEQUENCE = '\n' 20 | 21 | 22 | class ExperimentManager: 23 | def __init__(self, test_df: pd.DataFrame, train_df: pd.DataFrame, model: PCWModelWrapper, 24 | labels: List[str] = None, random_seed: int = 42, subsample_test_set: Optional[int] = 250, 25 | n_shots_per_window: Optional[int] = None): 26 | if subsample_test_set and subsample_test_set < len(test_df):  # None -> evaluate on the full test set 27 | np.random.seed(random_seed) 28 | test_df = test_df.sample(subsample_test_set) 29 | self.test_df = test_df 30 | self.train_df = train_df 31 | self.model = model 32 | self.base_random_seed = random_seed 33 | self.n_shots_per_window = n_shots_per_window 34 | self.tokenizer = model.tokenizer 35 | self._initialize_labels_and_logit_processor(labels) 36 | 37 | def _initialize_labels_and_logit_processor(self, labels: List[str]) -> None: 38 | _logger.info(f"Provided labels: {labels}") 39 | labels_tokens = encode_labels(self.tokenizer, labels) 40 | labels_tokens_array = self.minimize_labels_tokens(labels_tokens) 41 | _logger.info(f"Provided labels average n_tokens: {np.round(np.mean([len(lt) for lt in labels_tokens]), 3)}") 42 | # we fix the labels accordingly in the test set: 
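# (illustration: minimize_labels_tokens above shortens each label to its shortest unique token prefix -- e.g., 'positive'/'negative' already differ in their first token -- and below the test-set labels are remapped to these shortened forms)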
43 | shorten_label_tokens = [t[t != self.tokenizer.eos_token_id].tolist() for t in labels_tokens_array] 44 | _logger.info( 45 | f"shortened labels average n_tokens: {np.round(np.mean([len(lt) for lt in shorten_label_tokens]), 3)}") 46 | # Moving the test set label tokens to their shorter version: 47 | map_labels = {old_label: self.tokenizer.decode(t).lstrip() for old_label, t in 48 | zip(labels, shorten_label_tokens)} 49 | self.test_df[LABEL_TOKENS] = self.test_df[LABEL_TOKENS].map(map_labels) 50 | pad = len(max(shorten_label_tokens, key=len)) 51 | labels_tokens_array = np.array( 52 | [i + [self.tokenizer.eos_token_id] * (pad - len(i)) for i in shorten_label_tokens]) 53 | self.max_n_tokens = pad 54 | labels_tokens_array = self.pad_contained_labels_with_stop_seq(shorten_label_tokens, labels_tokens_array) 55 | self.logit_processor = RestrictiveTokensLogitsProcessor(restrictive_token_ids=labels_tokens_array, 56 | eos_token_id=self.tokenizer.eos_token_id) 57 | self.possible_labels = set(map_labels.values()) 58 | 59 | def minimize_labels_tokens(self, labels_tokens: List[List[int]]) -> npt.NDArray[int]: 60 | """ 61 | Minimize the number of tokens per label to be the shortest possible unique one. 62 | """ 63 | pad = len(max(labels_tokens, key=len)) 64 | labels_tokens_array = np.array([i + [self.tokenizer.eos_token_id] * (pad - len(i)) for i in labels_tokens]) 65 | for i, tokens in enumerate(labels_tokens): 66 | for j in range(len(tokens)): 67 | labels_with_shared_beginnings = np.sum( 68 | np.all(labels_tokens_array[:, :j] == np.array(tokens[:j]), axis=1)) 69 | if labels_with_shared_beginnings == 1: 70 | labels_tokens_array[i, j:] = self.tokenizer.eos_token_id 71 | break 72 | return labels_tokens_array 73 | 74 | def pad_contained_labels_with_stop_seq(self, labels_tokens: List, labels_tokens_array: npt.NDArray[int]) \ 75 | -> npt.NDArray[int]: 76 | """ 77 | In case we have two labels, where one label contains the other label (for example: "A" and "A B") we need 78 | to allow the restrictive decoding to produce the output "A". We support it by adding "\n" to the shorter label. 79 | """ 80 | stop_seq_token_id = encode_stop_seq(self.tokenizer, STOP_SEQUENCE) 81 | for i, tokens in enumerate(labels_tokens): 82 | labels_with_shared_beginnings = np.sum( 83 | np.all(labels_tokens_array[:, :len(tokens)] == np.array(tokens), axis=1)) 84 | if labels_with_shared_beginnings > 1: 85 | _logger.info(f"label {self.tokenizer.decode(tokens)} is the beginning of one of the other labels, " 86 | f"adding stop sequence to its end") 87 | labels_tokens_array[i, len(tokens)] = stop_seq_token_id 88 | return labels_tokens_array 89 | 90 | def _set_random_seed(self, random_seed: int) -> None: 91 | np.random.seed(random_seed) 92 | random.seed(random_seed) 93 | 94 | def get_few_shots_acc(self, windows_few_shot: List[str]) -> float: 95 | predicted_labels = self.get_predicted_labels(windows_few_shot) 96 | return self.calc_acc(predicted_labels) 97 | 98 | def get_predicted_labels(self, windows_few_shots: List[str]) -> List[str]: 99 | windows_cache = self.model.get_contexts_cache(windows_few_shots) 100 | predicted_labels = [] 101 | for q in self.test_df[PROMPTS]: 102 | predicted_label = self.predict_label(TEXT_BETWEEN_SHOTS + q, windows_cache) 103 | predicted_labels.append(predicted_label) 104 | assert set(predicted_labels).issubset(self.possible_labels) 105 | return predicted_labels 106 | 107 | def predict_label(self, task_text: str, cache: Dict) -> str: 108 | assert task_text == task_text.rstrip(), "prompt ends with a space!" 
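# decoding below is greedy (temperature=0), restricted by the logit processor to the (shortened) label token sequences, and capped at the longest label's token count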
109 | res = self.model.pcw_generate(task_text=task_text, 110 | contexts_cache=cache, 111 | restrictive_logit_preprocessor=self.logit_processor, 112 | temperature=0, 113 | max_new_tokens=self.max_n_tokens) 114 | 115 | return res.lstrip().strip(STOP_SEQUENCE) 116 | 117 | def calc_acc(self, predicted_labels: List) -> float: 118 | predicted_labels = pd.Series(predicted_labels, index=self.test_df.index) 119 | acc = np.mean(predicted_labels == self.test_df[LABEL_TOKENS]) 120 | _logger.info(f"accuracy = {np.round(acc, 3)}") 121 | return acc 122 | 123 | def run_experiment_across_shots(self, n_shots_to_test: List[int], n_runs: int, 124 | too_long_patience: float = 0.2): 125 | accuracies = np.zeros((len(n_shots_to_test), n_runs)) 126 | for i, n_shots in enumerate(tqdm(n_shots_to_test)): 127 | _logger.info(f"starting with n = {n_shots}") 128 | self._set_random_seed(self.base_random_seed + n_shots) 129 | j = 0 130 | n_errors = 0 131 | while j < n_runs: 132 | few_shots_idx = self.sample_n_shots(n_shots) 133 | few_shots_prompts = list(self.train_df.loc[few_shots_idx, PROMPTS]) 134 | windows_few_shots = self.build_windows_few_shots_text(few_shots_prompts, self.n_shots_per_window) 135 | longest_window_n_tokens = max(n_tokens_in_prompt(self.tokenizer, window) 136 | for window in windows_few_shots) 137 | n_tokens_between_shots = n_tokens_in_prompt(self.tokenizer, TEXT_BETWEEN_SHOTS) 138 | if (longest_window_n_tokens + n_tokens_between_shots + self.test_df[N_TOKENS].max() 139 | + self.max_n_tokens) > self.model.context_window_size: 140 | _logger.warning("Drawn training shots were too long, trying again") 141 | n_errors += 1 142 | assert n_errors <= too_long_patience * n_runs, "too many long inputs were drawn!" 143 | continue 144 | accuracies[i, j] = self.get_few_shots_acc(windows_few_shots) 145 | j += 1 146 | return accuracies 147 | 148 | def sample_n_shots(self, n_shots: int) -> npt.NDArray[int]: 149 | few_shots_df = self.train_df.sample(n_shots) 150 | assert few_shots_df.index.is_unique, "few shots samples were not unique!" 
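# when shots are split across several windows, rebalance them below so each window ends up with a similar total token count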
151 | window_size = self.n_shots_per_window or n_shots 152 | n_windows = int(len(few_shots_df) / window_size) 153 | if not self.n_shots_per_window or n_windows == 1: 154 | return few_shots_df.index 155 | return self.balance_windows_sizes(n_windows, few_shots_df) 156 | 157 | def balance_windows_sizes(self, n_windows: int, few_shots_df: pd.DataFrame) -> npt.NDArray[int]: 158 | few_shots_df.sort_values(by=N_TOKENS, inplace=True, ascending=False) 159 | shape = (self.n_shots_per_window, n_windows) 160 | indexes = np.array(few_shots_df.index).reshape(shape) 161 | sizes = few_shots_df.loc[indexes.flatten()].n_tokens.values.reshape(indexes.shape) 162 | for i in range(1, self.n_shots_per_window): 163 | order = np.argsort((np.sum(sizes[:i, :], axis=0))) 164 | sizes[i, :] = sizes[i, order] 165 | indexes[i, :] = indexes[i, order] 166 | # shuffle the order in each window: 167 | for i in range(n_windows): 168 | np.random.shuffle(indexes[:, i]) 169 | indexes = indexes.T.flatten() 170 | return indexes 171 | 172 | @staticmethod 173 | def build_windows_few_shots_text(few_shots_prompts: List, window_size: int) -> List[str]: 174 | if window_size is None: 175 | window_size = len(few_shots_prompts) 176 | return [TEXT_BETWEEN_SHOTS.join(few_shots_prompts[i: i + window_size]) for i in 177 | range(0, len(few_shots_prompts), window_size)] 178 | -------------------------------------------------------------------------------- /logits_processor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from numpy import typing as npt 4 | from transformers import LogitsProcessor 5 | 6 | LOGIT_BIAS = 100 7 | 8 | 9 | class RestrictiveTokensLogitsProcessor(LogitsProcessor): 10 | """ Restrictive decoding is done by adding logits_bias to the relevant tokens. Based on: 11 | https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability 12 | """ 13 | 14 | def __init__(self, 15 | restrictive_token_ids: npt.NDArray[int], 16 | eos_token_id: int, 17 | prompt_length_to_skip: int = 0, 18 | logits_bias: int = LOGIT_BIAS): 19 | self.restrictive_token_ids = restrictive_token_ids 20 | self.eos_token_id = eos_token_id 21 | self.logits_bias = logits_bias 22 | self.prompt_length_to_skip = prompt_length_to_skip 23 | self.mask = np.ones(restrictive_token_ids.shape[0], dtype=bool) 24 | 25 | self._preprocess_restrictive_array() 26 | 27 | def _preprocess_restrictive_array(self): 28 | # extend restrictive_token_ids to include eos as last token for each sequence 29 | if not (self.restrictive_token_ids[:, -1] == self.eos_token_id).all(): 30 | self.restrictive_token_ids = np.column_stack( 31 | (self.restrictive_token_ids, np.ones(self.restrictive_token_ids.shape[0]) * self.eos_token_id)). 
\ 32 | astype(int) 33 | 34 | def update_new_prompt_length_to_skip(self, prompt_length_to_skip: int): 35 | self.prompt_length_to_skip = prompt_length_to_skip 36 | self.mask = np.ones(self.restrictive_token_ids.shape[0], dtype=bool) 37 | 38 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 39 | assert input_ids.shape[0] == 1, "This implementation doesn't support batching" 40 | new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip 41 | if new_tokens_length > 0: 42 | self.mask = self.mask & (self.restrictive_token_ids[:, new_tokens_length - 1] == input_ids[ 43 | 0, -1].item()) 44 | scores[:, self.restrictive_token_ids[self.mask, new_tokens_length]] += self.logits_bias 45 | return scores 46 | -------------------------------------------------------------------------------- /model_loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, LlamaTokenizer, GPT2Tokenizer, PreTrainedTokenizerBase 3 | 4 | from modeling_gpt2_with_pcw import GPT2LMHeadPCW 5 | from pcw_wrapper import PCWModelWrapper 6 | from modeling_llama_with_pcw import LlamaForCausalLMPCW 7 | 8 | GPT2_WINDOW_SIZE = 1024 9 | LLAMA_WINDOW_SIZE = 2048 10 | 11 | 12 | def validate_model_name(model_name: str) -> None: 13 | assert 'llama' in model_name or 'gpt2' in model_name, f"Unknown model: {model_name}" 14 | 15 | 16 | def load_tokenizer(model_name: str) -> PreTrainedTokenizerBase: 17 | if 'llama' in model_name: 18 | if model_name == 'seanmor5/tiny-llama-test' or 'decapoda-research' in model_name: # debug mode: 19 | tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf") 20 | # In case you load those models, we must override an incorrect config: 21 | # see: https://huggingface.co/decapoda-research/llama-7b-hf/discussions/12 22 | tokenizer.bos_token_id = 1 23 | tokenizer.eos_token_id = 2 24 | else: 25 | tokenizer = LlamaTokenizer.from_pretrained(model_name) 26 | else: 27 | # In our experiments we added a bos token to gpt2: 28 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True) 29 | return tokenizer 30 | 31 | 32 | def load_pcw_wrapper(model_name: str, cache_dir: str = None, 33 | right_indentation: bool = False, n_windows: int = 1) -> PCWModelWrapper: 34 | validate_model_name(model_name) 35 | config = AutoConfig.from_pretrained(model_name) 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | multi_gpus = torch.cuda.device_count() > 1 38 | model_args = { 39 | "cache_dir": cache_dir 40 | } 41 | if multi_gpus: 42 | model_args["device_map"] = "auto" 43 | model_args["low_cpu_mem_usage"] = True 44 | if hasattr(config, "torch_dtype") and config.torch_dtype is not None: 45 | model_args["torch_dtype"] = config.torch_dtype 46 | 47 | if 'gpt2' in model_name: 48 | # we override n_positions to bypass the model's context window size restriction 49 | # (for gpt2, n_positions determines the causal attention mask matrix dimension). 50 | # The correct position embeddings (i.e., gpt2's 1024 trained position embeddings) are re-inserted to the model 51 | # in GPT2LMHeadWithPCWModel initialization. 
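# (e.g., with n_windows=3 the causal mask covers 3 * 1024 positions, so three 1024-token windows can be cached side by side, while each individual window still uses gpt2's original 1024 trained position embeddings)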
52 | model_args['ignore_mismatched_sizes'] = True 53 | model_args['n_positions'] = GPT2_WINDOW_SIZE * n_windows 54 | model_obj = GPT2LMHeadPCW 55 | context_window_size = GPT2_WINDOW_SIZE 56 | else: 57 | # Note that some LLaMa versions located in HF have an incorrect token mapping, we correct it here: 58 | # see: https://huggingface.co/decapoda-research/llama-7b-hf/discussions/12 59 | # also: https://github.com/tloen/alpaca-lora/issues/279 60 | model_args['bos_token_id'] = 1 61 | model_args['eos_token_id'] = 2 62 | model_obj = LlamaForCausalLMPCW 63 | context_window_size = LLAMA_WINDOW_SIZE 64 | 65 | tokenizer = load_tokenizer(model_name) 66 | model = model_obj.from_pretrained(model_name, **model_args).eval() 67 | if not multi_gpus: 68 | model = model.to(device) 69 | 70 | return PCWModelWrapper(model, tokenizer, device, context_window_size, right_indentation) 71 | -------------------------------------------------------------------------------- /modeling_gpt2_with_pcw.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Tuple, Optional, Dict 3 | 4 | import torch 5 | from transformers import GPT2LMHeadModel 6 | from transformers.configuration_utils import PretrainedConfig 7 | 8 | from pcw_wrapper import generate_pcw_position_ids 9 | 10 | 11 | class GPT2LMHeadPCW(GPT2LMHeadModel, ABC): 12 | def __init__(self, config: PretrainedConfig): 13 | super().__init__(config) 14 | self._adapt_weights() 15 | 16 | def _adapt_weights(self): 17 | # We need to override the regular loading of wpe weight since we are adding support to longer contexts. 18 | self.transformer.wpe = GPT2LMHeadModel.from_pretrained(self.config.name_or_path).transformer.wpe 19 | 20 | def prepare_inputs_for_generation(self, 21 | input_ids: torch.LongTensor, 22 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 23 | windows_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 24 | max_window_size: Optional[int] = None, 25 | sum_windows_size: Optional[int] = None, 26 | **kwargs 27 | ) -> Dict: 28 | """input_ids: 29 | ids of task_tokens. 30 | attention_mask: 31 | concatenation of windows + task tokens attentions masks. 32 | 33 | Note (past_key_values vs windows_key_values): 34 | In the first token generation, past_key_values is None while windows_key_values contains the combined past 35 | key values of context windows. During following generations, past_key_values is the concatenation of 36 | windows_key_values + previous generations. Thus, windows_key_values is practically ignored. 
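Worked example (hypothetical sizes): with two cached windows of 5 and 3 tokens, the non-longest window's bos entry is dropped, so the combined past holds 5 + 2 key/value entries; 4 task tokens then receive position ids 5, 6, 7, 8, continuing from the longest window.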
37 | """ 38 | 39 | token_type_ids = kwargs.get("token_type_ids") 40 | # only last token for inputs_ids if past_key_values is defined in kwargs 41 | if past_key_values: 42 | input_ids = input_ids[:, -1].unsqueeze(-1) 43 | if token_type_ids is not None: 44 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 45 | 46 | attention_mask = kwargs.get("attention_mask") 47 | position_ids = kwargs.get("position_ids") 48 | 49 | if attention_mask is not None and position_ids is None: 50 | # create PCW's position_ids on the fly 51 | position_ids = generate_pcw_position_ids(attention_mask, max_window_size, past_key_values, 52 | sum_windows_size, windows_key_values) 53 | else: 54 | position_ids = None 55 | 56 | if windows_key_values and not past_key_values: 57 | past_key_values = windows_key_values 58 | return { 59 | "input_ids": input_ids, 60 | "past_key_values": past_key_values, 61 | "use_cache": kwargs.get("use_cache"), 62 | "position_ids": position_ids, 63 | "attention_mask": attention_mask, 64 | "token_type_ids": token_type_ids, 65 | } 66 | -------------------------------------------------------------------------------- /modeling_llama_with_pcw.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC 3 | from typing import Optional, Tuple, Dict 4 | 5 | import torch 6 | from torch import nn 7 | from transformers import LlamaConfig 8 | from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaRMSNorm, \ 9 | LlamaDecoderLayer, LlamaModel, LlamaForCausalLM 10 | 11 | from pcw_wrapper import generate_pcw_position_ids 12 | 13 | """ 14 | The following code is mainly copy+paste from the original modelling_llama.py: 15 | LlamaAttention uses a caching mechanism for the positional rotation vectors (using LlamaRotaryEmbedding). 16 | This mechanism forces us to override LLaMa attention layer, which in turn forces us to override the decoder, 17 | and model (so that the correct forward function would be called). 18 | """ 19 | 20 | 21 | class LlamaForCausalLMPCW(LlamaForCausalLM, ABC): 22 | _no_split_modules = ["LlamaDecoderLayerPCW"] 23 | 24 | def __init__(self, config: LlamaConfig): 25 | super(LlamaForCausalLM, self).__init__(config) 26 | # using our Llama model variant: 27 | self.model = LlamaModelPCW(config) 28 | 29 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 30 | 31 | # Initialize weights and apply final processing 32 | self.post_init() 33 | 34 | def prepare_inputs_for_generation(self, 35 | input_ids: torch.LongTensor, 36 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 37 | windows_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 38 | max_window_size: Optional[int] = None, 39 | sum_windows_size: Optional[int] = None, 40 | **kwargs 41 | ) -> Dict: 42 | """input_ids: 43 | ids of task_tokens. 44 | attention_mask: 45 | concatenation of windows + task tokens attentions masks. 46 | 47 | Note (past_key_values vs windows_key_values): 48 | In the first token generation, past_key_values is None while windows_key_values contains the combined past 49 | key values of context windows. During following generations, past_key_values is the concatenation of 50 | windows_key_values + previous generations. Thus, windows_key_values is practically ignored. 
51 | """ 52 | 53 | # only last token for inputs_ids if past_key_values is defined in kwargs 54 | if past_key_values: 55 | input_ids = input_ids[:, -1:] 56 | attention_mask = kwargs.get("attention_mask") 57 | position_ids = kwargs.get("position_ids", None) 58 | if attention_mask is not None and position_ids is None: 59 | # create PCW's position_ids on the fly 60 | position_ids = generate_pcw_position_ids(attention_mask, max_window_size, past_key_values, 61 | sum_windows_size, windows_key_values) 62 | 63 | if windows_key_values and not past_key_values: 64 | past_key_values = windows_key_values 65 | 66 | return { 67 | "input_ids": input_ids, 68 | "past_key_values": past_key_values, 69 | "use_cache": kwargs.get("use_cache"), 70 | "position_ids": position_ids, 71 | "attention_mask": attention_mask, 72 | } 73 | 74 | 75 | class LlamaModelPCW(LlamaModel, ABC): 76 | """ 77 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 78 | 79 | Args: 80 | config: LlamaConfig 81 | """ 82 | 83 | def __init__(self, config: LlamaConfig): 84 | super(LlamaModel, self).__init__(config) 85 | self.padding_idx = config.pad_token_id 86 | self.vocab_size = config.vocab_size 87 | 88 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 89 | # using the alternative decoder layer: 90 | self.layers = nn.ModuleList([LlamaDecoderLayerPCW(config) for _ in range(config.num_hidden_layers)]) 91 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 92 | 93 | self.gradient_checkpointing = False 94 | # Initialize weights and apply final processing 95 | self.post_init() 96 | 97 | 98 | class LlamaDecoderLayerPCW(LlamaDecoderLayer): 99 | def __init__(self, config: LlamaConfig): 100 | super().__init__(config) 101 | # overriding attention: 102 | self.self_attn = LlamaAttentionPCW(config=config) 103 | 104 | 105 | class LlamaAttentionPCW(LlamaAttention): 106 | # we have to override the forward attention due to the rotary embeddings caching mechanism 107 | def forward( 108 | self, 109 | hidden_states: torch.Tensor, 110 | attention_mask: Optional[torch.Tensor] = None, 111 | position_ids: Optional[torch.LongTensor] = None, 112 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 113 | output_attentions: bool = False, 114 | use_cache: bool = False, 115 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 116 | bsz, q_len, _ = hidden_states.size() 117 | 118 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 119 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 120 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 121 | 122 | kv_seq_len = key_states.shape[-2] 123 | if past_key_value is not None: 124 | kv_seq_len += past_key_value[0].shape[-2] 125 | 126 | # *** changes to the original code to accommodate PCW: 127 | # making sure that the model generates rotary embeddings in the correct length: 128 | seq_len = kv_seq_len if position_ids is None else int(torch.max(position_ids) + 1) 129 | cos, sin = self.rotary_emb(value_states, seq_len=seq_len) 130 | # *** End of changes due to PCW, the rest of the function is copy-paste from the original transformer package. 
131 | 132 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 133 | # [bsz, nh, t, hd] 134 | 135 | if past_key_value is not None: 136 | # reuse k, v, self_attention 137 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 138 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 139 | 140 | past_key_value = (key_states, value_states) if use_cache else None 141 | 142 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 143 | 144 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 145 | raise ValueError( 146 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 147 | f" {attn_weights.size()}" 148 | ) 149 | 150 | if attention_mask is not None: 151 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 152 | raise ValueError( 153 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 154 | ) 155 | attn_weights = attn_weights + attention_mask 156 | attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) 157 | 158 | # upcast attention to fp32 159 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 160 | attn_output = torch.matmul(attn_weights, value_states).to(query_states.dtype) 161 | 162 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 163 | raise ValueError( 164 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 165 | f" {attn_output.size()}" 166 | ) 167 | 168 | attn_output = attn_output.transpose(1, 2) 169 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 170 | 171 | attn_output = self.o_proj(attn_output) 172 | 173 | if not output_attentions: 174 | attn_weights = None 175 | 176 | return attn_output, attn_weights, past_key_value 177 | -------------------------------------------------------------------------------- /pcw_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional, Dict 2 | 3 | import numpy as np 4 | import torch 5 | from transformers import PreTrainedTokenizerBase, PreTrainedModel 6 | 7 | from logits_processor import RestrictiveTokensLogitsProcessor 8 | from utils import n_tokens_in_prompt 9 | 10 | 11 | def combine_past_key_values(past_lst: List[Tuple[Tuple[torch.Tensor]]], longest_window_id: int) -> \ 12 | Tuple[Tuple[torch.Tensor, torch.Tensor]]: 13 | # We eliminate all but one bos token from windows to avoid multiple bos, which hurt our results. 
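# past_lst holds one past_key_values per window; each layer entry is a (key, value) pair shaped [batch, n_heads, window_seq_len, head_dim], concatenated here along the sequence axis (dim=2) with every non-longest window's leading bos entry sliced off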
14 | n_layers = len(past_lst[0]) 15 | longest_window = past_lst[longest_window_id] 16 | all_windows_except_longest = past_lst[:longest_window_id] + past_lst[longest_window_id + 1:] 17 | return tuple( 18 | (torch.cat([longest_window[i][0]] + [c[i][0][:, :, 1:, :] for c in all_windows_except_longest], dim=2), 19 | torch.cat([longest_window[i][1]] + [c[i][1][:, :, 1:, :] for c in all_windows_except_longest], dim=2)) 20 | for i in range(n_layers)) 21 | 22 | 23 | def generate_pcw_position_ids(attention_mask: torch.Tensor, max_window_size: int, 24 | past_key_values: Tuple[Tuple[torch.Tensor]], 25 | sum_windows_size: int, windows_key_values: Tuple[Tuple[torch.Tensor]]) -> torch.Tensor: 26 | position_ids = attention_mask.long().cumsum(-1) - 1 27 | n_task_tokens = position_ids.shape[1] - sum_windows_size 28 | position_ids[0, -n_task_tokens:] = torch.arange(max_window_size, max_window_size + n_task_tokens, 1) 29 | position_ids.masked_fill_(attention_mask == 0, 1) 30 | if past_key_values: # i.e., first token is already generated 31 | position_ids = position_ids[:, -1].unsqueeze(-1) 32 | elif windows_key_values: # i.e., we are in the first token generation 33 | position_ids = position_ids[:, sum_windows_size:] 34 | return position_ids 35 | 36 | 37 | class PCWModelWrapper: 38 | def __init__(self, 39 | model: PreTrainedModel, 40 | tokenizer: PreTrainedTokenizerBase, 41 | device: str, 42 | context_window_size: int, 43 | right_indentation: bool = False 44 | ): 45 | self.model = model 46 | self.tokenizer = tokenizer 47 | self.context_window_size = context_window_size 48 | self.device = device 49 | # Left indentation is the default behavior as explained in the paper. 50 | self.right_indentation = right_indentation 51 | 52 | def _get_windows(self, texts: List[str]) -> List[Dict]: 53 | windows = [] 54 | if self.right_indentation: 55 | max_window_size = max(n_tokens_in_prompt(self.tokenizer, t, add_special_tokens=True) for t in texts) 56 | 57 | for text in texts: 58 | encoded_input_window = self.tokenizer(text, return_tensors='pt').to(self.device) 59 | window_size = encoded_input_window['input_ids'].shape[1] 60 | if self.right_indentation: 61 | shift = max_window_size - window_size 62 | encoded_input_window["position_ids"] = encoded_input_window["attention_mask"].cumsum(-1) - 1 + shift 63 | with torch.no_grad(): 64 | output = self.model(**encoded_input_window) 65 | windows.append({'text': text, 66 | 'encoded_input': encoded_input_window, 67 | 'attention_mask': encoded_input_window['attention_mask'], 68 | 'window_size': window_size, 69 | 'output': output, 70 | 'past': output['past_key_values']}) 71 | return windows 72 | 73 | def get_contexts_cache(self, contexts: List[str]) -> Dict: 74 | windows = self._get_windows(contexts) 75 | windows_sizes = [window['window_size'] for window in windows] 76 | j = np.argmax(windows_sizes) 77 | # Windows contain bos tokens, we remove all but one to avoid multiple bos 78 | return {'past_key_values': combine_past_key_values([window['past'] for window in windows], j), 79 | 'max_window_size': max(windows_sizes), 80 | 'past_attention_mask': torch.cat( 81 | [windows[j]['attention_mask']] + [window['attention_mask'][:, 1:] for window in 82 | windows[:j] + windows[j + 1:]], dim=1), 83 | 'sum_windows_size': sum(windows_sizes) - (len(windows) - 1)} 84 | 85 | def pcw_generate(self, 86 | contexts: Optional[List[str]] = None, 87 | task_text: Optional[str] = None, 88 | contexts_cache: Optional[Dict] = None, 89 | restrictive_logit_preprocessor: Optional[RestrictiveTokensLogitsProcessor] = 
None, 90 | **kwargs 91 | ) -> str: 92 | """Note: Batching is not supported by PCW at the moment. """ 93 | assert (contexts is None) != ( 94 | contexts_cache is None), "pcw_generate should work with contexts or cache, not with both!" 95 | cache = contexts_cache or self.get_contexts_cache(contexts) 96 | encoded_task_text = self.tokenizer(task_text, add_special_tokens=False, return_tensors='pt').to(self.device) 97 | if restrictive_logit_preprocessor: 98 | restrictive_logit_preprocessor.update_new_prompt_length_to_skip(encoded_task_text['input_ids'].shape[1]) 99 | kwargs['logits_processor'] = [restrictive_logit_preprocessor] 100 | combined_attention_mask = torch.cat((cache['past_attention_mask'], encoded_task_text['attention_mask']), 101 | dim=1).to(self.device) 102 | with torch.no_grad(): 103 | res = self.model.generate(input_ids=encoded_task_text['input_ids'], 104 | attention_mask=combined_attention_mask, 105 | windows_key_values=cache['past_key_values'], 106 | max_window_size=cache['max_window_size'], 107 | sum_windows_size=cache['sum_windows_size'], 108 | pad_token_id=self.tokenizer.eos_token_id, 109 | **kwargs)[0] 110 | res = res[:-1] if res[-1] == self.tokenizer.eos_token_id else res 111 | return self.tokenizer.decode(res[encoded_task_text['input_ids'].shape[1]:]) 112 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | numpy~=1.24.2 2 | transformers==4.28.1 3 | matplotlib 4 | pandas~=1.5.3 5 | datasets~=2.9.0 6 | tqdm 7 | accelerate==0.18.0 8 | sentencepiece==0.1.99 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.8 3 | # by the following command: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | --no-binary grpcio 8 | 9 | accelerate==0.18.0 10 | # via -r requirements.in 11 | aiohttp==3.8.4 12 | # via 13 | # datasets 14 | # fsspec 15 | aiosignal==1.3.1 16 | # via aiohttp 17 | async-timeout==4.0.2 18 | # via aiohttp 19 | attrs==23.1.0 20 | # via aiohttp 21 | certifi==2023.5.7 22 | # via requests 23 | charset-normalizer==3.1.0 24 | # via 25 | # aiohttp 26 | # requests 27 | contourpy==1.1.0 28 | # via matplotlib 29 | cycler==0.11.0 30 | # via matplotlib 31 | datasets==2.9.0 32 | # via -r requirements.in 33 | dill==0.3.6 34 | # via 35 | # datasets 36 | # multiprocess 37 | filelock==3.12.2 38 | # via 39 | # huggingface-hub 40 | # torch 41 | # transformers 42 | fonttools==4.40.0 43 | # via matplotlib 44 | frozenlist==1.3.3 45 | # via 46 | # aiohttp 47 | # aiosignal 48 | fsspec[http]==2023.6.0 49 | # via 50 | # datasets 51 | # huggingface-hub 52 | huggingface-hub==0.15.1 53 | # via 54 | # datasets 55 | # transformers 56 | idna==3.4 57 | # via 58 | # requests 59 | # yarl 60 | importlib-resources==5.12.0 61 | # via matplotlib 62 | jinja2==3.1.2 63 | # via torch 64 | kiwisolver==1.4.4 65 | # via matplotlib 66 | markupsafe==2.1.3 67 | # via jinja2 68 | matplotlib==3.7.1 69 | # via -r requirements.in 70 | mpmath==1.3.0 71 | # via sympy 72 | multidict==6.0.4 73 | # via 74 | # aiohttp 75 | # yarl 76 | multiprocess==0.70.14 77 | # via datasets 78 | networkx==3.1 79 | # via torch 80 | numpy==1.24.3 81 | # via 82 | # -r requirements.in 83 | # accelerate 84 | # contourpy 85 | # datasets 86 | # matplotlib 87 | # pandas 88 | # pyarrow 89 | # transformers 90 | packaging==23.1 91 | 
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | numpy~=1.24.2
2 | transformers==4.28.1
3 | matplotlib
4 | pandas~=1.5.3
5 | datasets~=2.9.0
6 | tqdm
7 | accelerate==0.18.0
8 | sentencepiece==0.1.99
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.8
3 | # by the following command:
4 | #
5 | #    pip-compile requirements.in
6 | #
7 | --no-binary grpcio
8 |
9 | accelerate==0.18.0
10 |     # via -r requirements.in
11 | aiohttp==3.8.4
12 |     # via
13 |     #   datasets
14 |     #   fsspec
15 | aiosignal==1.3.1
16 |     # via aiohttp
17 | async-timeout==4.0.2
18 |     # via aiohttp
19 | attrs==23.1.0
20 |     # via aiohttp
21 | certifi==2023.5.7
22 |     # via requests
23 | charset-normalizer==3.1.0
24 |     # via
25 |     #   aiohttp
26 |     #   requests
27 | contourpy==1.1.0
28 |     # via matplotlib
29 | cycler==0.11.0
30 |     # via matplotlib
31 | datasets==2.9.0
32 |     # via -r requirements.in
33 | dill==0.3.6
34 |     # via
35 |     #   datasets
36 |     #   multiprocess
37 | filelock==3.12.2
38 |     # via
39 |     #   huggingface-hub
40 |     #   torch
41 |     #   transformers
42 | fonttools==4.40.0
43 |     # via matplotlib
44 | frozenlist==1.3.3
45 |     # via
46 |     #   aiohttp
47 |     #   aiosignal
48 | fsspec[http]==2023.6.0
49 |     # via
50 |     #   datasets
51 |     #   huggingface-hub
52 | huggingface-hub==0.15.1
53 |     # via
54 |     #   datasets
55 |     #   transformers
56 | idna==3.4
57 |     # via
58 |     #   requests
59 |     #   yarl
60 | importlib-resources==5.12.0
61 |     # via matplotlib
62 | jinja2==3.1.2
63 |     # via torch
64 | kiwisolver==1.4.4
65 |     # via matplotlib
66 | markupsafe==2.1.3
67 |     # via jinja2
68 | matplotlib==3.7.1
69 |     # via -r requirements.in
70 | mpmath==1.3.0
71 |     # via sympy
72 | multidict==6.0.4
73 |     # via
74 |     #   aiohttp
75 |     #   yarl
76 | multiprocess==0.70.14
77 |     # via datasets
78 | networkx==3.1
79 |     # via torch
80 | numpy==1.24.3
81 |     # via
82 |     #   -r requirements.in
83 |     #   accelerate
84 |     #   contourpy
85 |     #   datasets
86 |     #   matplotlib
87 |     #   pandas
88 |     #   pyarrow
89 |     #   transformers
90 | packaging==23.1
91 |     # via
92 |     #   accelerate
93 |     #   datasets
94 |     #   huggingface-hub
95 |     #   matplotlib
96 |     #   transformers
97 | pandas==1.5.3
98 |     # via
99 |     #   -r requirements.in
100 |     #   datasets
101 | pillow==9.5.0
102 |     # via matplotlib
103 | psutil==5.9.5
104 |     # via accelerate
105 | pyarrow==12.0.1
106 |     # via datasets
107 | pyparsing==3.0.9
108 |     # via matplotlib
109 | python-dateutil==2.8.2
110 |     # via
111 |     #   matplotlib
112 |     #   pandas
113 | pytz==2023.3
114 |     # via pandas
115 | pyyaml==6.0
116 |     # via
117 |     #   accelerate
118 |     #   datasets
119 |     #   huggingface-hub
120 |     #   transformers
121 | regex==2023.6.3
122 |     # via transformers
123 | requests==2.31.0
124 |     # via
125 |     #   datasets
126 |     #   fsspec
127 |     #   huggingface-hub
128 |     #   responses
129 |     #   transformers
130 | responses==0.18.0
131 |     # via datasets
132 | sentencepiece==0.1.99
133 |     # via -r requirements.in
134 | six==1.16.0
135 |     # via python-dateutil
136 | sympy==1.12
137 |     # via torch
138 | tokenizers==0.13.3
139 |     # via transformers
140 | torch==2.0.1
141 |     # via accelerate
142 | tqdm==4.65.0
143 |     # via
144 |     #   -r requirements.in
145 |     #   datasets
146 |     #   huggingface-hub
147 |     #   transformers
148 | transformers==4.28.1
149 |     # via -r requirements.in
150 | typing-extensions==4.6.3
151 |     # via
152 |     #   huggingface-hub
153 |     #   torch
154 | urllib3==2.0.3
155 |     # via
156 |     #   requests
157 |     #   responses
158 | xxhash==3.2.0
159 |     # via datasets
160 | yarl==1.9.2
161 |     # via aiohttp
162 | zipp==3.15.0
163 |     # via importlib-resources
164 |
--------------------------------------------------------------------------------
/run_evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import List, Optional, Tuple
4 |
5 | import pandas as pd
6 | from transformers import PreTrainedTokenizerBase
7 |
8 | from datasets_loader import DATASET_NAMES2LOADERS
9 | from experiment_manager import ExperimentManager
10 | from model_loaders import load_pcw_wrapper
11 | from utils import get_max_n_shots, filter_extremely_long_samples, save_results
12 |
13 | _logger = logging.getLogger(__name__)
14 | logging.basicConfig(level=logging.INFO, format='%(message)s')
15 |
16 |
17 | def get_dataset(dataset: str, tokenizer: PreTrainedTokenizerBase) -> Tuple[pd.DataFrame, pd.DataFrame, List]:
18 |     da = DATASET_NAMES2LOADERS[dataset]()
19 |     # Filter extremely long samples from both train and test samples:
20 |     _logger.info("filtering test set:")
21 |     test_df = filter_extremely_long_samples(da.test_df, tokenizer)
22 |     _logger.info("filtering train set:")
23 |     train_df = filter_extremely_long_samples(da.train_df, tokenizer)
24 |     return test_df, train_df, da.labels
25 |
26 |
27 | def run_pcw_experiment(dataset: str, model: str, cache_dir: str, subsample_test_set: int, output_dir: str,
28 |                        n_windows: List[int], n_shots_per_window: Optional[int], n_runs: int,
29 |                        random_seed: int, right_indentation: bool) -> None:
30 |     pcw_model = load_pcw_wrapper(model, cache_dir, right_indentation, max(n_windows))
31 |
32 |     test_df, train_df, labels = get_dataset(dataset, pcw_model.tokenizer)
33 |
34 |     if n_shots_per_window is None:
35 |         # default behaviour: we take the maximum number of samples per window
36 |         n_shots_per_window = get_max_n_shots(train_df, test_df, pcw_model.tokenizer, pcw_model.context_window_size)
37 |         _logger.info(f"Found max n shots per window = {n_shots_per_window}")
38 |
39 |     n_shots = [i * n_shots_per_window for i in n_windows]
40 |
41 |     em = ExperimentManager(test_df, train_df, pcw_model, labels, random_seed=random_seed,
42 |                            n_shots_per_window=n_shots_per_window, subsample_test_set=subsample_test_set)
43 |
44 |     accuracies = em.run_experiment_across_shots(n_shots, n_runs)
45 |     save_results(dataset, n_shots, accuracies, output_dir, model)
46 |
47 |
48 | if __name__ == '__main__':
49 |     parser = argparse.ArgumentParser()
50 |     parser.add_argument('--dataset', dest='dataset', action='store', required=True,
51 |                         help='Name of dataset (for example sst2).'
52 |                              f' The supported datasets are: {", ".join(DATASET_NAMES2LOADERS.keys())}')
53 |     parser.add_argument('--model', dest='model', action='store', default='gpt2',
54 |                         help='HF model name to use, either gpt2 or LLaMA family models')
55 |     parser.add_argument('--subsample-test-set', dest='subsample_test_set', action='store', required=False, type=int,
56 |                         help='Size of the test subset used to speed up evaluation. Omit to use the full test set.')
57 |     parser.add_argument('--output-dir', dest='output_dir', required=False, help="Directory for saving the results",
58 |                         default='./temp', action='store', type=str)
59 |     parser.add_argument('--cache-dir', help="Hugging Face cache dir", type=str, default=None, dest='cache_dir')
60 |     parser.add_argument('--random-seed', dest='random_seed', required=False, default=42, action='store', type=int)
61 |     parser.add_argument('--n-runs', dest='n_runs',
62 |                         help="Number of times experiments are repeated for every number of windows", action='store',
63 |                         type=int, default=1)
64 |     parser.add_argument('-n', '--n-windows', dest='n_windows', help="Number of parallel context windows",
65 |                         action='append', type=int)
66 |     parser.add_argument('--n-shots-per-window', dest='n_shots_per_window',
67 |                         help="Number of examples to fit in each window", type=int, default=None)
68 |     parser.add_argument('--right-indentation', dest='right_indentation', help="indent all windows to the right",
69 |                         action='store_true', default=False)
70 |     args = parser.parse_args()
71 |     run_pcw_experiment(**vars(args))
72 |
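An illustrative invocation of the script above (the dataset, window counts, and subset size are example values):

python run_evaluation.py --dataset sst2 --model gpt2 -n 1 -n 2 -n 3 --n-runs 3 --subsample-test-set 250 --output-dir ./temp

Because -n/--n-windows uses action='append', repeating the flag evaluates several window counts in a single run. When --n-shots-per-window is omitted, the maximum number of demonstrations that fits in one context window is derived automatically via get_max_n_shots (see the sketch after utils.py below).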
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import List, Tuple
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from matplotlib import pyplot as plt
8 | from numpy import typing as npt
9 | from torch import distributed as dist
10 | from transformers import PreTrainedTokenizerBase, LlamaTokenizer
11 |
12 | from constants import TEXT_BETWEEN_SHOTS, N_TOKENS, PROMPTS
13 |
14 | _logger = logging.getLogger(__name__)
15 | logging.basicConfig(level=logging.INFO, format='%(message)s')
16 |
17 |
18 | def get_max_n_shots(train_df: pd.DataFrame, test_df: pd.DataFrame, tokenizer: PreTrainedTokenizerBase,
19 |                     prompt_size: int) -> int:
20 |     n_tokens_between_shots = n_tokens_in_prompt(tokenizer, TEXT_BETWEEN_SHOTS)
21 |     shot_lengths = train_df[N_TOKENS] + n_tokens_between_shots
22 |     prompt_length_percentile = shot_lengths.quantile(0.9)  # 90th percentile of shot length, incl. separator tokens
23 |     longest_test_prompt = test_df[N_TOKENS].max()
24 |     _logger.info(f"longest_test_prompt = {longest_test_prompt}")
25 |     max_possible_shots_length = prompt_size - longest_test_prompt
26 |     return int(np.floor(max_possible_shots_length / prompt_length_percentile))
27 |
28 |
29 | def filter_extremely_long_samples(df: pd.DataFrame, tokenizer: PreTrainedTokenizerBase) -> pd.DataFrame:
30 |     df[N_TOKENS] = df[PROMPTS].map(lambda x: n_tokens_in_prompt(tokenizer, x))
31 |     mask = df[N_TOKENS] <= df[N_TOKENS].quantile(0.99)
32 |     _logger.info(f"filtered {sum(~mask)} samples from the dataset due to extreme length")
33 |     df = df.loc[mask].copy()
34 |     _logger.info(f"longest remaining prompt according to tokenizer: {df[N_TOKENS].max()}")
35 |     return df
36 |
37 |
38 | def n_tokens_in_prompt(tokenizer: PreTrainedTokenizerBase, prompt: str, add_special_tokens=False) -> int:
39 |     return len(tokenizer.encode(prompt, add_special_tokens=add_special_tokens))
40 |
41 |
42 | def plot_results_graph(results, dataset_name, n_shots, model='') -> None:
43 |     plt.figure()
44 |     plt.errorbar(n_shots, np.mean(results, axis=1), np.std(results, axis=1), fmt='*')
45 |     plt.xlabel("# shots")
46 |     plt.xticks(n_shots)
47 |     metric = 'Accuracy'
48 |     plt.ylabel(f"{dataset_name} {metric}")
49 |     plt.title(f"{metric} {dataset_name} {model}")
50 |
51 |
52 | def load_results(dataset_name: str, output_dir: str, plot=False) -> Tuple[npt.NDArray[np.float64], List[int]]:
53 |     all_results = os.listdir(output_dir)
54 |     results_path = [r for r in all_results if r.startswith(f'{dataset_name}_')]
55 |     if len(results_path) != 1:
56 |         raise ValueError(f"Expected exactly one results file for {dataset_name}, found {len(results_path)}!")
57 |     results_path = results_path[0]
58 |     results = np.load(os.path.join(output_dir, results_path))
59 |     n_shots = [int(d) for d in results_path.split('.')[-2].split('_') if d.isdigit()]
60 |     if plot:
61 |         plot_results_graph(results, dataset_name, n_shots)
62 |     return results, n_shots
63 |
64 |
65 | def save_results(dataset: str, n_shots: List[int], results: npt.NDArray[np.float64], output_dir: str,
66 |                  model: str = '', plot_results: bool = True) -> None:
67 |     if plot_results:
68 |         plot_results_graph(results, dataset, n_shots, model)
69 |         plt.show()
70 |     if not dist.is_initialized() or dist.get_rank() == 0:
71 |         # in case we use multiple GPUs - only rank 0 saves the results file
72 |         os.makedirs(output_dir, exist_ok=True)
73 |         output_path = f"{output_dir}/{dataset}_n_shots_results_{'_'.join([str(i) for i in n_shots])}.npy"
74 |         np.save(output_path, results)
75 |
76 |
77 | def encode_labels(tokenizer: PreTrainedTokenizerBase, labels: List[str]) -> List[List[int]]:
78 |     if isinstance(tokenizer, LlamaTokenizer):
79 |         # SentencePiece already prepends a space at the beginning of the sentence
80 |         return [tokenizer.encode(f'{label.lstrip()}', add_special_tokens=False) for label in labels]
81 |
82 |     return [tokenizer.encode(f' {label.lstrip()}', add_special_tokens=False) for label in labels]
83 |
84 |
85 | def encode_stop_seq(tokenizer: PreTrainedTokenizerBase, stop_seq: str) -> int:
86 |     stop_seq_token_id = tokenizer.encode(stop_seq, add_special_tokens=False)
87 |     if isinstance(tokenizer, LlamaTokenizer):
88 |         assert len(stop_seq_token_id) == 2  # SentencePiece adds an extra leading token
89 |     else:
90 |         assert len(stop_seq_token_id) == 1
91 |     return stop_seq_token_id[-1]
92 |
--------------------------------------------------------------------------------
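To make the window-sizing logic in get_max_n_shots concrete, here is the same arithmetic on invented numbers (real values come from the tokenized train/test DataFrames; nothing below is measured on an actual dataset):

import numpy as np

# invented example values
prompt_size = 1024             # context window size of the model (e.g. GPT-2)
longest_test_prompt = 60       # test_df[N_TOKENS].max() after filtering
prompt_length_percentile = 48  # shot_lengths.quantile(0.9), incl. TEXT_BETWEEN_SHOTS tokens

max_possible_shots_length = prompt_size - longest_test_prompt               # 964
print(int(np.floor(max_possible_shots_length / prompt_length_percentile)))  # 20

With n_windows = [1, 2, 4], run_evaluation.py would then evaluate n_shots = [20, 40, 80] demonstrations in total.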