├── prompt_fab ├── __init__.py ├── lm_openai.py └── templates.py ├── .gitignore ├── example.py └── README.md /prompt_fab/__init__.py: -------------------------------------------------------------------------------- 1 | from .templates import * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .DS_Store 3 | openai_api_key.txt -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from prompt_fab import * 2 | from prompt_fab.lm_openai import get_template_completion_tokens_and_logprobs 3 | 4 | template = Prefix( 5 | 'Answer "Yes" or "No" to the following questions.\n\n', 6 | Repeat( 7 | Record( 8 | question=Affix('Q: ', SENTENCE, EOL), 9 | answer=Prefix('A: ', YES_NO) 10 | ), 11 | delimiter='\n\n' 12 | ) 13 | ) 14 | 15 | context_examples = [ 16 | {'question': 'Is the sky blue?', 'answer': True}, 17 | {'question': 'Can fish play basketball?', 'answer': False} 18 | ] 19 | 20 | query = 'Can you eat soup with a spoon?' 21 | 22 | # Pass in both partial data (just the context examples and query) 23 | # as well as the full data including the target label that we want 24 | # the likelihood of. Only one API call is made. 
def _completion_start_index(text_offsets, prompt_length):
    """
    Return the index of the first token that belongs to the completion.

    text_offsets holds the starting offset of every echoed token and is
    compared against len(prompt_text), i.e. it is treated as character
    offsets into prompt_text+completion_text. When the prompt/completion
    boundary falls inside a token (no token starts exactly at prompt_length),
    the token straddling the boundary is counted as part of the completion.
    """
    for i, offset in enumerate(text_offsets):
        if offset == prompt_length:
            return i
        if offset > prompt_length:
            # The previous token straddles the boundary.
            return max(i - 1, 0)
    # Every token starts before the boundary: only the last one can reach past it.
    return len(text_offsets) - 1


def get_completion_tokens_and_logprobs(prompt_text: str, completion_text: str, model: Optional[str]=None):
    """
    Accepts a prompt prefix and a desired completion and returns a tuple of the
    tokens and conditional token log-likelihoods corresponding to the provided completion.

    prompt_text: The text the completion is conditioned on.
    completion_text: The text to score. If it is empty there is nothing to
        score, so ([], []) is returned without making an API call.
    model: Name of the OpenAI model to query; DEFAULT_MODEL is used when None.

    Returns (tokens, token_logprobs), two parallel lists taken from the API
    response, starting at the first token that overlaps the completion.
    """
    if not completion_text:
        # Previously this fell through and reported the prompt's last token
        # as if it were the completion.
        return [], []
    if model is None:
        model = DEFAULT_MODEL
    # max_tokens=0 with echo=True asks the API to score the supplied text
    # rather than generate anything new.
    response = openai.Completion.create(
        model=model,
        prompt=prompt_text+completion_text,
        max_tokens=0,
        logprobs=0,
        echo=True,
        n=1
    )
    logprobs_obj = response["choices"][0]["logprobs"]
    start = _completion_start_index(logprobs_obj["text_offset"], len(prompt_text))
    return logprobs_obj["tokens"][start:], logprobs_obj["token_logprobs"][start:]


def get_template_completion_tokens_and_logprobs(template: "Template", data_partial, data_full, model: Optional[str]=None):
    """
    Accepts a prompt template and two versions of a data object:
        data_partial: The data specifying just the prompt prefix
        data_full: The full data, which also specifies the target completion
    Returns a tuple containing the tokens and conditional token log-likelihoods corresponding
    to only the portion of the filled template that is not specified by data_partial.

    Raises ValueError if the text filled from data_full does not start with the
    text filled from data_partial, since the completion is then undefined.
    """
    prompt_text = template.fill(data_partial)
    full_text = template.fill(data_full)
    # An explicit raise (rather than assert) keeps the check active under -O.
    if not full_text.startswith(prompt_text):
        raise ValueError(
            'The template filled with data_full must start with the template filled with data_partial'
        )
    completion_text = full_text[len(prompt_text):]
    return get_completion_tokens_and_logprobs(prompt_text, completion_text, model=model)
4 | 5 | PromptFab lets you write prompts that express the same underlying data schema in multiple formats without a bunch of boilerplate. 6 | 7 | ## Examples 8 | 9 | ```python 10 | from prompt_fab import * 11 | 12 | template = Prefix( 13 | 'Please answer the following questions.\n\n', 14 | Repeat( 15 | Record( 16 | question=Affix( 17 | 'Q: ', SENTENCE, EOL 18 | ), 19 | answer=Affix( 20 | 'A: ', SENTENCE 21 | ) 22 | ), 23 | delimiter='\n\n' 24 | ) 25 | ) 26 | 27 | data = [ 28 | { 29 | 'question': 'What year was Aubrey Plaza born?', 30 | 'answer': '1984' 31 | }, 32 | { 33 | 'question': 'What should I have for breakfast?', 34 | 'answer': 'You should have a banana.' 35 | } 36 | ] 37 | 38 | print(template.fill(data)) 39 | ``` 40 | Output: 41 | ``` 42 | Please answer the following questions. 43 | 44 | Q: What year was Aubrey Plaza born? 45 | A: 1984 46 | 47 | Q: What should I have for breakfast? 48 | A: You should have a banana. 49 | ``` 50 | PromptFab templates also allow parsing surface strings back into the data schema: 51 | ```python 52 | assert data == template.match(template.fill(data)) 53 | ``` 54 | 55 | An example of a prompt for a discriminative task: 56 | ```python 57 | from prompt_fab import * 58 | 59 | data = { 60 | 'premise': "I'm the best magician in the world.", 61 | 'hypothesis': "I can do magic.", 62 | 'label': True 63 | } 64 | 65 | template = Record( 66 | premise=Suffix(SENTENCE, EOL), 67 | hypothesis=Affix('Hypothesis: ', SENTENCE, EOL), 68 | label=Prefix('Does this hypothesis follow? ', 69 | Option({True: 'Yes', False: 'No'}) 70 | ) 71 | ) 72 | 73 | print(template.fill(data)) 74 | ``` 75 | Output: 76 | ``` 77 | I'm the best magician in the world. 78 | Hypothesis: I can do magic. 79 | Does this hypothesis follow? Yes 80 | ``` 81 | 82 | ## Computing scores 83 | 84 | PromptFab also includes helper functions in the `prompt_fab.lm_openai` module that let you compute the model scores of particular prompt completions using OpenAI model API calls. 
Make sure to have the `OPENAI_API_KEY` environment variable set, set `openai.api_key_path` in your own code, or place your OpenAI API key in a file `openai_api_key.txt` in this module's root directory. 85 | 86 | ```python 87 | from prompt_fab import * 88 | from prompt_fab.lm_openai import get_template_completion_tokens_and_logprobs 89 | 90 | template = Prefix( 91 | 'Answer "Yes" or "No" to the following questions.\n\n', 92 | Repeat( 93 | Record( 94 | question=Affix('Q: ', SENTENCE, EOL), 95 | answer=Prefix('A: ', YES_NO) 96 | ), 97 | delimiter='\n\n' 98 | ) 99 | ) 100 | 101 | context_examples = [ 102 | {'question': 'Is the sky blue?', 'answer': True}, 103 | {'question': 'Can fish play basketball?', 'answer': False} 104 | ] 105 | 106 | query = 'Can you eat soup with a spoon?' 107 | 108 | # Pass in both partial data (just the context examples and query) 109 | # as well as the full data including the target label that we want 110 | # the likelihood of. Only one API call is made. 111 | tokens, scores = get_template_completion_tokens_and_logprobs( 112 | template, 113 | context_examples+[{'question': query, 'answer': None}], 114 | context_examples+[{'question': query, 'answer': True}] 115 | ) 116 | print(tokens, scores) 117 | ``` 118 | Output: 119 | ``` 120 | [' Yes'] [-0.03768289] 121 | ``` 122 | This log-likelihood corresponds to a probability of 96%, so it looks like GPT-3 agrees that you can eat soup with a spoon. 123 | 124 | For more details on the provided template building blocks, 125 | refer to the docstrings in `prompt_fab.templates`. 
"""Composable templates that render data as prompt text and parse it back."""

from typing import Any, Union, Optional, Mapping
import re

__all__ = [
    "Template",
    "Fixed",
    "Pattern",
    "Integer",
    "Append",
    "Repeat",
    "NumberedList",
    "Affix",
    "Suffix",
    "Prefix",
    "Record",
    "Option",
    "SENTENCE",
    "EOL",
    "SPACE",
    "YES_NO",
    "NUM"
]


class StringPos:
    """
    Helper class to allow position in a string to be passed by reference.
    Nested templates share one instance and move its pos forward as they
    consume text.
    """
    def __init__(self, s, pos=0):
        self.s = s
        self.pos = pos

    def advance(self, n):
        self.pos += n


class Template:
    """
    Template base class.
    Templates can convert data to string representations via .fill,
    or parse data from string representations via .match

    Custom Template subclasses should override ._match, not .match
    """
    def _match(self, sp: StringPos):
        raise NotImplementedError()

    def match(self, s: str, pos=0):
        """
        Parse s starting at index pos and return the parsed data, or whatever
        the subclass's ._match reports on failure (None for most templates).
        Text left over after the parsed portion is not checked.
        """
        return self._match(StringPos(s, pos=pos))

    def fill(self, *args):
        raise NotImplementedError()


class Fixed(Template):
    """
    Template that always surfaces a fixed string.

    default: The literal text produced by .fill.
    accepted_pattern: Optional regular expression describing what .match
        accepts. When omitted, only the literal default text is accepted.
    """
    def __init__(self, default, accepted_pattern=None):
        self.default = default
        if accepted_pattern is None:
            # The default is literal text, not a regex: escape it so that
            # characters such as '?', '.', '(' or '+' match themselves.
            accepted_pattern = re.escape(default)
        self.accepted_pattern = re.compile(accepted_pattern)

    def fill(self, *args):
        return self.default

    def _match(self, sp: StringPos):
        m = self.accepted_pattern.match(sp.s, pos=sp.pos)
        if m is None:
            return None
        sp.advance(len(m.group(0)))
        return m.group(0)


def str_to_fixed(str_or_template: Union[str, Template]):
    """
    Return templates unchanged and wrap plain strings in Fixed.
    Raises ValueError for any other type.
    """
    if isinstance(str_or_template, Template):
        return str_or_template
    elif isinstance(str_or_template, str):
        return Fixed(str_or_template)
    else:
        raise ValueError('Expected string or template')


NOTHING = Fixed('')


class Pattern(Template):
    """
    Template that surfaces strings. When parsing via .match,
    a Pattern instance will only consume strings that match its
    regular expression (specified by the pattern argument).
    If the expression has capture groups, the first group is returned
    instead of the whole match.
    """
    def __init__(self, pattern: str):
        self.pattern = re.compile(pattern)
        self.group_idx = 0 if self.pattern.groups == 0 else 1

    def _match(self, sp: StringPos):
        m = self.pattern.match(sp.s, pos=sp.pos)
        if m is None:
            return None
        # Always consume the whole match, even when only a group is returned.
        sp.advance(len(m.group(0)))
        return m.group(self.group_idx)

    def fill(self, s):
        if s is None:
            return ''
        return s


class Integer(Pattern):
    """
    Template that maps integers to strings.
    """
    def __init__(self):
        super().__init__(r'-?[0-9]+')

    def _match(self, sp: StringPos):
        revert_pos = sp.pos
        s = super()._match(sp)
        try:
            # s is None when the pattern did not match, which raises TypeError.
            return int(s, base=10)
        except (ValueError, TypeError):
            sp.pos = revert_pos
            return None

    def fill(self, i):
        if i is None:
            return ''
        return super().fill(str(i))


class Append(Template):
    """
    Template that maps a fixed-length sequence to a string by applying a series
    of item_templates to the corresponding entries in the sequence.
    Parsing returns a tuple with one entry per item template, or None as soon
    as one of them fails.
    """
    def __init__(self, *item_templates: Union[str, Template]):
        self.item_templates = tuple(map(str_to_fixed, item_templates))

    def _match(self, sp: StringPos):
        matches = []
        for item_template in self.item_templates:
            if (m := item_template._match(sp)) is None:
                return None
            matches.append(m)
        return tuple(matches)

    def fill(self, items):
        if items is None:
            return ''
        return ''.join(item_template.fill(item) for item_template, item in zip(self.item_templates, items))


class Repeat(Template):
    """
    Template that maps a sequence to a string by applying item_template to each
    element of the sequence, separating elements with the provided delimiter.
    If trailing_delimiter is set, the delimiter is also written after the last
    element, and a delimiter found after the last parsed element is consumed.
    """
    def __init__(self, item_template: Template, delimiter: Union[str, Fixed], trailing_delimiter=False):
        self.item_template = item_template
        self.delimiter = str_to_fixed(delimiter)
        self.trailing_delimiter = trailing_delimiter

    def _match(self, sp: StringPos):
        vals = []
        i_revert = sp.pos
        last_dm = None
        while True:
            iteration_start = sp.pos
            if (m := self.item_template._match(sp)) is None:
                break
            vals.append(m)
            i_revert = sp.pos
            if (last_dm := self.delimiter._match(sp)) is None:
                break
            if sp.pos == iteration_start:
                # Item and delimiter both matched without consuming anything;
                # repeating would produce the same result forever.
                break
        if (last_dm is not None) and not self.trailing_delimiter:
            # Hand back a delimiter that was not followed by another item.
            sp.pos = i_revert
        return vals

    def fill(self, values):
        if values is None:
            return ''
        d = self.delimiter.fill()
        filled = d.join(
            self.item_template.fill(v) for v in values
        )
        return filled+d if self.trailing_delimiter else filled


class NumberedList(Repeat):
    """
    Template that maps a sequence to a numbered list. This template accepts
    a label_template that should map integers to strings,
    and an item_template which should surface the actual list entries.
    Numbering starts at start_idx. Parsing returns only the list entries;
    the parsed labels are discarded.
    Example:
    >>> t = NumberedList(Suffix(NUM, '. '), SENTENCE, EOL)
    >>> print(t.fill(['This is Sentence One.', 'Copy that, Sentence One - Sentence Two here.']))
    1. This is Sentence One.
    2. Copy that, Sentence One - Sentence Two here.
    """
    def __init__(self, label_template: Template, item_template: Template, delimiter: Union[str, Fixed], trailing_delimiter=False, start_idx=1):
        super().__init__(Append(label_template, item_template), delimiter, trailing_delimiter=trailing_delimiter)
        self.start_idx = start_idx

    def _match(self, sp):
        sm = super()._match(sp)
        return None if sm is None else [pair[1] for pair in sm]

    def fill(self, values):
        if values is None:
            return ''
        return super().fill(zip(map(str, range(self.start_idx, len(values)+self.start_idx)), values))


class Affix(Template):
    """
    Template that adds fixed decoration to either (or both) sides of a given content template.
    When filling with None, the prefix is still written but the suffix is left
    out. When parsing, the suffix is consumed if present but is not required.
    """
    def __init__(self, prefix: Union[str, Fixed], content: Template, suffix: Optional[Union[str, Fixed]] = None):
        self.prefix = str_to_fixed(prefix)
        self.content = content
        self.suffix = None if suffix is None else str_to_fixed(suffix)

    def _match(self, sp: StringPos):
        pm = self.prefix._match(sp)
        if pm is None:
            return None
        cm = self.content._match(sp)
        if cm is None:
            # The prefix stays consumed on purpose: this lets an enclosing
            # Record step over a field that was filled with None.
            return None
        if self.suffix is not None:
            self.suffix._match(sp)
        return cm

    def fill(self, c):
        filled = self.prefix.fill()+self.content.fill(c)
        if self.suffix is not None and c is not None:
            filled += self.suffix.fill()
        return filled


class Suffix(Affix):
    """
    Affix with no prefix: content followed by a fixed suffix.
    """
    def __init__(self, content: Template, suffix: Union[str, Fixed]):
        super().__init__(NOTHING, content, suffix=suffix)


Prefix = Affix


class Record(Template):
    """
    Template that maps the named fields of a record to a sequence of concatenated templates.
    Fields absent from the record are skipped when filling, and fields that
    fail to parse are left out of the parsed record.
    Example:
    >>> t = Record(
    ...     a=Suffix(NUM, ' '),
    ...     b=NUM
    ... )
    >>> print(t.fill({'a': 1, 'b': 2}))
    1 2
    >>> t.match('1 2')
    {'a': 1, 'b': 2}
    """
    def __init__(self, **fields: Union[str, Template]):
        self.fields = {name: str_to_fixed(str_or_template) for name, str_or_template in fields.items()}

    def _match(self, sp: StringPos):
        rec = {}
        for field_name, field_template in self.fields.items():
            m = field_template._match(sp)
            if m is not None:
                rec[field_name] = m
        return rec

    def fill(self, rec):
        if rec is None:
            return ''
        return ''.join(f.fill(rec.get(f_name)) for f_name, f in self.fields.items() if f_name in rec)


class Option(Template):
    """
    Template that maps between a set of values and arbitrary strings corresponding to each value.
    When parsing, the options are tried in the order given and the first one
    that matches wins.
    """
    def __init__(self, value_templates: Mapping[Any, Union[str, Fixed]]):
        self.value_templates = {value: str_to_fixed(str_or_fixed) for value, str_or_fixed in value_templates.items()}

    def _match(self, sp: StringPos):
        for val, val_template in self.value_templates.items():
            m = val_template._match(sp)
            if m is not None:
                return val
        return None

    def fill(self, val):
        if val is None:
            return ''
        return self.value_templates[val].fill()


# Any string that (optionally) ends in a sentence-ending punctuation mark
# and doesn't run longer than a single line
SENTENCE = Pattern(r'[^\n\.\?\!]+[\.\?\!]?')
EOL = Fixed('\n')
SPACE = Fixed(' ', accepted_pattern=r' +')
YES_NO = Option({False: 'No', True: 'Yes'})
NUM = Integer()