├── LICENSE ├── README.md ├── grammar.py ├── script.py ├── state_machine.py ├── symbols.py ├── test_tokenizer.py ├── tests.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 im-not-tom 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GBNF grammar extension for text-generation-webui 2 | 3 | Implementation of [GBNF grammar](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) 4 | for text-generation-webui. 5 | 6 | Allows forcing the output generated by an LLM to conform to an expected format, making it easier to parse. For example, 7 | with this extension you can get any LLM to generate proper JSON or YAML. 8 | 9 | This is done by building a state machine which, every time a new token is about to be generated, decides which 10 | tokens are allowed by the defined grammar and bans (sets the probability to -inf) all other tokens. 11 | 12 | To illustrate the problem, here is [an example prompt and output generated without the extension](https://rentry.org/yxg7s) 13 | and here is [the same prompt with output generated using the grammar](https://rentry.org/4tyci). 14 | 15 | ### Differences from GBNF in llama.cpp 16 | 17 | This aims to be a fully conforming implementation, but adds two extra features (illustrated below): 18 | - The `.*` regexp can be used to effectively turn the extension off at some point and let the LLM generate the rest of the string normally. 19 | - Terminals can be defined with both "" quotes and '' quotes. This was done by mistake and then kept for convenience.
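For example, both extras can appear in one grammar (a minimal sketch, not part of the repository):

```
root   ::= prefix rest
prefix ::= 'Alice: ' ("- " | '"')
rest   ::= .*
```

Here `prefix` mixes single- and double-quoted terminals, and `rest ::= .*` hands control back to the LLM once the prefix has been generated.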
20 | -------------------------------------------------------------------------------- /grammar.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Dict, Optional, Set 2 | from extensions.output_template.symbols import Symbol, Terminal, NonTerminal, Sequence, Alternative, Repeat, RegExp, \ 3 | AnyToken 4 | from extensions.output_template.state_machine import Advance, Matcher 5 | from extensions.output_template.utils import shared, AllowedTokens 6 | import torch, re 7 | 8 | RE_RULE = re.compile(r'\s*([-a-z]+)\s*::=\s*(.*)', re.MULTILINE | re.DOTALL) 9 | RE_NEWLINE = re.compile(r'[ \t]*\n[ \t\n]*(.*)', re.MULTILINE | re.DOTALL) 10 | RE_NONTERMINAL = re.compile(r'[ \t]*([-a-z_]+)[ \t]*(.*)', re.DOTALL) 11 | RE_ANYTOKEN = re.compile(r'[ \t]*\.[ \t]*\*[ \t]*(.*)', re.DOTALL) 12 | RE_OR = re.compile(r'[ \t\n]*\|[ \t]*(.*)', re.MULTILINE | re.DOTALL) 13 | RE_COMMENT = re.compile(r'([^#]*)#[^\n]*(.*)', re.MULTILINE | re.DOTALL) 14 | 15 | 16 | class Grammar: 17 | """ Grammar and also state machine used to match LLM output """ 18 | 19 | def __init__(self, definition: str): 20 | self.rules: Dict[str, Symbol] = {} 21 | self.active_matcher: Optional[Matcher] = None 22 | self.reset(definition) 23 | 24 | def stop(self): 25 | self.active_matcher = None 26 | 27 | def reset(self, definition: str = None): 28 | self.stop() 29 | if definition: 30 | text = definition 31 | self.rules = {} 32 | 33 | # Strip comments 34 | m = RE_COMMENT.match(text) 35 | while m: 36 | text = m.group(1) + m.group(2) 37 | m = RE_COMMENT.match(text) 38 | 39 | while text: 40 | m = RE_RULE.match(text) 41 | if not m: 42 | raise GrammarError("expected rule") 43 | rule_name, text = m.groups() 44 | if rule_name in self.rules: 45 | raise ValidationError(f"duplicate rule '{rule_name}'") 46 | self.rules[rule_name], text = parse_rule(text) 47 | 48 | if "root" not in self.rules: 49 | raise ValidationError("missing 'root' rule") 50 | for rule in self.rules.values(): 51 | rule.validate(self) 52 | 53 | self.enter_rule("root") 54 | 55 | def resolve(self, symbol: "Symbol") -> "Symbol": 56 | # Resolves NonTerminal into rule and returns Symbol it represents 57 | dont_loop: Set[Terminal] = set([symbol]) 58 | while isinstance(symbol, NonTerminal): 59 | dont_loop.add(symbol) 60 | if symbol.name not in self.rules: 61 | raise ValidationError(f"invalid rule name: '{symbol.name}'") 62 | if self.rules[symbol.name] in dont_loop: 63 | raise ValidationError(f"infinite loop detected at symbol '{symbol.name}'") 64 | symbol = self.rules[symbol.name] 65 | return symbol 66 | 67 | def enter_rule(self, name: str): 68 | """ 69 | Sets active symbol to specific rule. 70 | Used for testing. 71 | """ 72 | if name not in self.rules: 73 | raise ValueError(f"invalid rule name: '{name}'") 74 | self.active_matcher = self.resolve(self.rules[name]).enter(self) 75 | 76 | def get_rule_name(self, symbol: "Symbol"): 77 | for (name, s) in self.rules.items(): 78 | if symbol is s: 79 | return name 80 | return None 81 | 82 | def get_effective_matcher(self) -> Optional["Matcher"]: 83 | """ 84 | Recursively descends rule hierarchy and returns symbol 85 | that will effectively decide on next token 86 | """ 87 | return self.active_matcher.get_effective_matcher() if self.active_matcher else None 88 | 89 | def update_scores(self, scores: torch.FloatTensor) -> torch.FloatTensor: 90 | """ 91 | Calculates probability scores of next token according to current state. 
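Tokens that cannot continue the grammar from the current state get their score set to -inf; if the resulting AllowedTokens still has look_ahead set (everything remaining in the grammar is optional), the EOS token is allowed as well.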
92 | May update and return same object as one that was passed as argument. 93 | """ 94 | # TODO: how to cache or optimize this? 95 | if self.active_matcher: 96 | allowed = self.active_matcher.get_allowed_tokens(self) 97 | if allowed.look_ahead: 98 | allowed.allow_eos = True 99 | 100 | allowed.apply(scores) 101 | else: 102 | # Grammar reached terminal token. Force EOS 103 | AllowedTokens(allowed={int(shared.tokenizer.eos_token_id)}, allow_eos=True).apply(scores) 104 | return scores 105 | 106 | def advance(self, token_id: int): 107 | try: 108 | if self.active_matcher: 109 | # from extensions.output_template.script import logger 110 | # logger.warning(f"Feeding {token_id} into {self.active_matcher}.") 111 | a = self.active_matcher.advance(self, token_id) 112 | if a == Advance.Reject: 113 | if token_id == shared.tokenizer.eos_token_id: 114 | self.active_matcher = None 115 | else: 116 | raise GenerationError 117 | elif a == Advance.Done: 118 | self.active_matcher = None 119 | except GenerationError as e: 120 | from extensions.output_template.script import logger 121 | logger.warning("LLM failed to generate token conforming to grammar") 122 | self.active_matcher = None 123 | 124 | 125 | class GrammarError(ValueError): 126 | pass 127 | 128 | 129 | class ValidationError(GrammarError): 130 | pass 131 | 132 | 133 | class GenerationError(Exception): 134 | pass 135 | 136 | 137 | def find_unescaped_index(haystack: str, needle: str, start=0) -> int: 138 | index = start 139 | while True: 140 | index, index2 = haystack.find(needle, index), haystack.find("\\", index) 141 | if index < 0: 142 | return len(haystack) 143 | if index2 >= 0 and index2 < index: 144 | index = index2 + 2 145 | else: 146 | return index 147 | 148 | 149 | def parse_sequence(text: str, parentheses=False) -> Tuple[Sequence, str]: 150 | seq = [] 151 | while text: 152 | if text[0] in '"\'': 153 | # Terminal rule 154 | try: 155 | end_index = find_unescaped_index(text, text[0], 1) 156 | t = text[1:end_index].encode("utf-8").decode("unicode_escape") 157 | seq.append(Terminal(t)) 158 | text = text[end_index+1:] 159 | except ValueError: 160 | raise GrammarError(f"unmatched {text[0]}") 161 | elif RE_NONTERMINAL.match(text): 162 | # Non-terminal rule 163 | t, text = RE_NONTERMINAL.match(text).groups() 164 | seq.append(NonTerminal(t)) 165 | elif text[0] in " \t": 166 | # Whitespace 167 | text = text[1:] 168 | elif text[0] == "[": 169 | # Regexp rule 170 | try: 171 | end_index = find_unescaped_index(text, "]", 1) 172 | except ValueError: 173 | raise GrammarError(f"unmatched {text[0]}") 174 | try: 175 | seq.append(RegExp(text[0:end_index+1])) 176 | text = text[end_index+1:] 177 | except ValueError: 178 | raise GrammarError(f"invalid pattern {text[0:end_index]}") 179 | elif text[0] == "(": 180 | # Parenthesized rule 181 | text = text[1:] 182 | t, text = parse_rule(text, parentheses=True) 183 | seq.append(t) 184 | pass 185 | elif parentheses and text[0] == ")": 186 | text = text[1:] 187 | break 188 | elif RE_ANYTOKEN.match(text): 189 | # '.*' rule 190 | text, = RE_ANYTOKEN.match(text).groups() 191 | seq.append(AnyToken()) 192 | elif text[0] in "*?+": 193 | # Repeat rule 194 | if not seq: 195 | raise GrammarError(f"unexpected '{text[0]}'") 196 | left = seq.pop() 197 | if text[0] == "+": 198 | # A+ is converted into sequence (A A*) 199 | if text[0] in "+" and isinstance(left, RegExp): 200 | # If child is regexp, extend its rule so multi-character tokens are matched 201 | left = RegExp(left.value + text[0]) 202 | seq.append(Sequence([left, Repeat("*", 
left)])) 203 | else: 204 | seq.append(Repeat(text[0], left)) 205 | text = text[1:] 206 | elif RE_OR.match(text): 207 | text, = RE_OR.match(text).groups() 208 | if not seq: 209 | raise GrammarError(f"unexpected '|'") 210 | left = seq.pop() 211 | right, text = parse_rule(text, parentheses=parentheses) 212 | seq.append(Alternative([left, right])) 213 | break 214 | elif RE_NEWLINE.match(text): 215 | # Newline 216 | text, = RE_NEWLINE.match(text).groups() 217 | if not parentheses: 218 | break 219 | else: 220 | raise GrammarError(f"unexpected '{text[0:5]}'...") 221 | 222 | return Sequence(seq), text 223 | 224 | 225 | def parse_rule(text: str, parentheses=False) -> Tuple[Symbol, str]: 226 | rv, text = parse_sequence(text, parentheses=parentheses) 227 | if len(rv.items) == 1: 228 | rv = rv.items[0] 229 | return rv, text 230 | -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from extensions.output_template.grammar import Grammar 3 | from extensions.output_template.utils import shared 4 | from functools import partial 5 | import torch, transformers 6 | try: 7 | from modules.logging_colors import logger 8 | except ModuleNotFoundError: 9 | # Allows testing by running script outside text-generation-ui 10 | from logging import Logger 11 | logger = Logger(__name__) 12 | transformers = None 13 | class LogitsProcessor: 14 | pass 15 | 16 | 17 | EMPTY_GRAMMAR = "root ::= .*" 18 | 19 | params = { 20 | "grammar": Grammar(EMPTY_GRAMMAR), 21 | "enabled": False, 22 | "template": "", 23 | "token_dictionary": None, 24 | "used_tokenizer": None, 25 | "scores_size": 0, 26 | } 27 | 28 | 29 | class TemplatingLogitsProcessor(transformers.LogitsProcessor): 30 | 31 | def __init__(self): 32 | super().__init__() 33 | self.last_input_size = 0 34 | 35 | def __call__(self, input_ids: Optional[torch.LongTensor], scores: torch.FloatTensor): 36 | if params["enabled"]: 37 | params["scores_size"] = len(scores[0]) 38 | grammar: Grammar = params["grammar"] 39 | 40 | if input_ids is not None: 41 | # input_ids are None when running from tests. 42 | input_size = len(input_ids[0]) 43 | if input_size <= self.last_input_size: 44 | logger.warning("output_template: input size unexpectedly decreased. Restarting grammar (except wrong output)") 45 | grammar.reset() 46 | elif self.last_input_size != 0: 47 | for token_id in input_ids[0][self.last_input_size:]: 48 | grammar.advance(int(token_id)) 49 | self.last_input_size = input_size 50 | 51 | return grammar.update_scores(scores) 52 | return scores 53 | 54 | 55 | def logits_processor_modifier(processor_list, input_ids): 56 | """ 57 | Adds logits processors to the list, allowing you to access and modify 58 | the next token probabilities. 59 | Only used by loaders that use the transformers library for sampling. 60 | """ 61 | processor_list.append(TemplatingLogitsProcessor()) 62 | return processor_list 63 | 64 | 65 | def token_generated_callback(input_ids, scores): 66 | if params["enabled"]: 67 | grammar: Grammar = params["grammar"] 68 | new_token = input_ids[0][-1] 69 | grammar.advance(int(new_token)) 70 | 71 | 72 | def input_modifier(string, state, is_chat=False): 73 | """ 74 | Initializes template and appends initial simple text to input. 75 | Note: In chat_mode, this extension does nothing. 
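The grammar definition is taken from state["grammar"] when that key is present (with EMPTY_GRAMMAR as fallback for an empty value); otherwise the template entered in the UI (params["template"]) is used.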
76 | """ 77 | if not is_chat: 78 | if "grammar" in state or params["template"]: 79 | grammar: Grammar = params["grammar"] 80 | if "grammar" in state: 81 | grammar.reset(state["grammar"] or EMPTY_GRAMMAR) 82 | else: 83 | grammar.reset(params["template"]) 84 | params["enabled"] = True 85 | else: 86 | params["enabled"] = False 87 | return string 88 | 89 | 90 | def ui(): 91 | import gradio as gr 92 | output_template = gr.Textbox(value="", placeholder="Enter output template", label="Output Template", 93 | info='Output Template to use', lines=5) 94 | 95 | def update_output_template(x): 96 | logger.info("output_template updated") 97 | params.update({"template": x}) 98 | 99 | output_template.change(update_output_template, output_template, None) 100 | -------------------------------------------------------------------------------- /state_machine.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Set 2 | from extensions.output_template.utils import get_token_dictionary, AllowedTokens 3 | from enum import IntEnum 4 | 5 | 6 | class Advance(IntEnum): 7 | Again = 0 8 | Done = 1 9 | Reject = 2 10 | TryNext = 3 11 | 12 | 13 | class Matcher: 14 | ''' Piece of state machine used to match the Symbol ''' 15 | 16 | def __init__(self, symbol: "Symbol"): 17 | self.symbol = symbol 18 | 19 | def __repr__(self): 20 | return f"<{self.__class__.__name__} {self.debug()}>" 21 | 22 | def debug(self) -> str: 23 | """ 24 | Returns string describing current state. 25 | There's no logic here, just what seem'd most readable at the time 26 | """ 27 | return repr(self.symbol) 28 | 29 | def get_effective_matcher(self) -> Optional["Matcher"]: 30 | raise NotImplementedError(f"get_effective_matcher on {self.__class__.__name__}") 31 | 32 | def advance(self, g: "Grammar", token_id: int) -> Advance: 33 | """ 34 | Returns Advance.Again if given token matches symbol only partially and more tokens are expected. 35 | Returns Advance.Done if given token (at current state) matches rest of this symbol. 36 | Method should not be called again after returning Advance.Done. 
37 | Returns Advance.Reject if token doesn't match 38 | Returns Advance.TryNext if not matching token is expected and next symbol should be tried 39 | """ 40 | raise NotImplementedError(f"advance on {self.__class__.__name__}") 41 | 42 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 43 | raise NotImplementedError(f"get_allowed_tokens on {self.__class__.__name__}") 44 | 45 | 46 | class TerminalMatcher(Matcher): 47 | symbol: "Terminal" 48 | 49 | def __init__(self, t: "Terminal"): 50 | super().__init__(t) 51 | self.index = 0 52 | 53 | def debug(self) -> str: 54 | if self.index <= 0 or self.index >= len(self.symbol.value): 55 | return repr(self.symbol) 56 | return f"""t'{repr(self.symbol.value[:self.index]).strip("'")}[{repr(self.symbol.value[self.index:]).strip("'")}]'""" 57 | 58 | def get_effective_matcher(self) -> "Matcher": 59 | return self 60 | 61 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 62 | if self.index not in self.symbol.allowed_cache: 63 | d = get_token_dictionary() 64 | allowed = set() 65 | for token_id in d: 66 | if d[token_id] and d[token_id] == self.symbol.value[self.index:self.index+len(d[token_id])]: 67 | allowed.add(token_id) 68 | self.symbol.allowed_cache[self.index] = allowed 69 | return AllowedTokens(allowed=self.symbol.allowed_cache[self.index]) 70 | 71 | def enter_in_middle(self, g: "Grammar", token_id: int) -> Advance: 72 | if self.index == 0: 73 | # Special case, allow entering mid-token 74 | d = get_token_dictionary() 75 | t = get_suffix_prefix(d[token_id], self.symbol.value) 76 | if not t: 77 | return Advance.Reject 78 | self.index += len(t) 79 | if self.index >= len(self.symbol.value): 80 | return Advance.Done 81 | return Advance.Again 82 | 83 | def advance(self, g: "Grammar", token_id: int) -> Advance: 84 | d = get_token_dictionary() 85 | t = d[token_id] 86 | if t != self.symbol.value[self.index:self.index + len(t)]: 87 | return Advance.Reject 88 | self.index += len(t) 89 | if self.index >= len(self.symbol.value): 90 | return Advance.Done 91 | return Advance.Again 92 | 93 | 94 | def get_suffix_prefix(suffix_from, prefix_from) -> str: 95 | i = 1 96 | while i <= min(len(suffix_from), len(prefix_from)): 97 | if suffix_from[-i:] != prefix_from[:i]: 98 | break 99 | i += 1 100 | return prefix_from[:i-1] 101 | 102 | 103 | class RegExpMatcher(Matcher): 104 | symbol: "RegExp" 105 | 106 | def __init__(self, t: "RegExp"): 107 | super().__init__(t) 108 | 109 | def debug(self) -> str: 110 | return f"""r{self.symbol.value}""" 111 | 112 | def get_effective_matcher(self) -> "Matcher": 113 | return self 114 | 115 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 116 | if self.symbol.negative: 117 | if self.symbol.value not in self.symbol.banned_cache: 118 | d = get_token_dictionary() 119 | banned = set() 120 | for token_id in d: 121 | if self.symbol.re.search(d[token_id]): 122 | if self.symbol.next: 123 | t = d[token_id] 124 | # Check if there's prefix of next terminal that is also suffix of this token 125 | s = get_suffix_prefix(t, self.symbol.next.value) 126 | # If yes, check if rest of this token can be allowed 127 | if s and len(s) < len(t) and not self.symbol.re.search(t[0:-len(s)]): 128 | # Yes, allow that token 129 | pass 130 | else: 131 | # No, ban entire token 132 | banned.add(token_id) 133 | else: 134 | banned.add(token_id) 135 | self.symbol.banned_cache[self.symbol.value] = banned 136 | return AllowedTokens(banned=self.symbol.banned_cache[self.symbol.value]) 137 | else: 138 | if self.symbol.value not in 
self.symbol.allowed_cache: 139 | d = get_token_dictionary() 140 | allowed = set() 141 | for token_id in d: 142 | if self.symbol.re.match(d[token_id]): 143 | allowed.add(token_id) 144 | self.symbol.allowed_cache[self.symbol.value] = allowed 145 | return AllowedTokens(allowed=self.symbol.allowed_cache[self.symbol.value]) 146 | 147 | def advance(self, g: "Grammar", token_id: int) -> Advance: 148 | d = get_token_dictionary() 149 | if self.symbol.negative: 150 | if self.symbol.re.search(d[token_id]): 151 | return Advance.Reject 152 | elif not self.symbol.re.match(d[token_id]): 153 | return Advance.Reject 154 | # TODO: use index? How to deal with tokens that match partially? 155 | return Advance.Done 156 | 157 | 158 | class AnyTokenMatcher(Matcher): 159 | symbol: "AnyToken" 160 | 161 | def debug(self) -> str: 162 | return ".*" 163 | 164 | def get_effective_matcher(self) -> "Matcher": 165 | return self 166 | 167 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 168 | return AllowedTokens(allow_eos=True) 169 | 170 | def advance(self, g: "Grammar", token_id: int) -> Advance: 171 | return Advance.Again 172 | 173 | 174 | class SequenceMatcher(Matcher): 175 | symbol: "Sequence" 176 | 177 | def __init__(self, symbol: "Sequence", items: List[Optional[Matcher]]): 178 | super().__init__(symbol) 179 | self.items = items 180 | self.index = 0 181 | 182 | def debug(self) -> str: 183 | return f'''({" ".join([ 184 | f"[{self.items[i].debug()}]" if i == self.index and self.items[i] 185 | else repr(self.symbol.items[i]) 186 | for i in range(len(self.symbol.items)) 187 | ])})''' 188 | 189 | def get_effective_matcher(self) -> Optional[Matcher]: 190 | assert self.index < len(self.items) 191 | if self.items[self.index]: 192 | return self.items[self.index].get_effective_matcher() 193 | return None 194 | 195 | def ensure_matcher(self, g: "Grammar", i=0) -> "SequenceMatcher": 196 | if not self.items[i]: 197 | self.items[i] = g.resolve(self.symbol.items[i]).enter(g) 198 | return self 199 | 200 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 201 | assert self.index < len(self.items) 202 | rv = self.items[self.index].get_allowed_tokens(g) 203 | if rv.look_ahead: 204 | i = self.index 205 | ahead = rv 206 | while i < len(self.symbol.items) - 1: 207 | i += 1 208 | self.ensure_matcher(g, i) 209 | ahead = self.items[i].get_allowed_tokens(g) 210 | rv = rv.combine(ahead) 211 | if not ahead.look_ahead: 212 | break 213 | if not ahead.look_ahead: 214 | rv.look_ahead = False 215 | return rv 216 | 217 | def advance(self, g: "Grammar", token_id: int) -> Advance: 218 | a = self.items[self.index].advance(g, token_id) 219 | while True: 220 | if a in (Advance.Done, Advance.TryNext): 221 | if self.index < len(self.symbol.items) - 1: 222 | self.index += 1 223 | self.ensure_matcher(g, self.index) 224 | if a == Advance.TryNext: 225 | if (True 226 | and isinstance(self.items[self.index - 1], RepeatMatcher) 227 | and isinstance(self.items[self.index], TerminalMatcher) 228 | ): 229 | a = self.items[self.index].enter_in_middle(g, token_id) 230 | continue # yep, this is a goto 231 | return self.advance(g, token_id) 232 | a = Advance.Again 233 | break 234 | return a 235 | 236 | 237 | class AlternativeMatcher(Matcher): 238 | symbol: "Alternative" 239 | 240 | def __init__(self, symbol: "Alternative", items: Set[Matcher]): 241 | super().__init__(symbol) 242 | self.items = items 243 | 244 | def debug(self): 245 | return f'({" | ".join([x.debug() for x in self.items])})' 246 | 247 | def get_effective_matcher(self) -> 
"Matcher": 248 | if len(self.items) == 1: 249 | return list(self.items)[0].get_effective_matcher() 250 | return self 251 | 252 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 253 | rv = None 254 | for i in self.items: 255 | a = i.get_allowed_tokens(g) 256 | rv = rv.combine(a) if rv else a 257 | # TODO: should this return 'ban everything' if no alternative is left? 258 | # TODO: should such state be even possible? 259 | return rv or AllowedTokens() 260 | 261 | def advance(self, g: "Grammar", token_id: int) -> Advance: 262 | best_a = Advance.Reject 263 | for i in list(self.items): 264 | a = i.advance(g, token_id) 265 | if a in (Advance.Reject, Advance.TryNext): 266 | if a == Advance.TryNext and best_a == Advance.Reject: 267 | best_a = Advance.TryNext 268 | self.items.remove(i) 269 | else: 270 | best_a = Advance.Done 271 | if a == Advance.Done: 272 | self.items.remove(i) 273 | if len(self.items) == 0: 274 | return best_a 275 | return Advance.Again 276 | 277 | 278 | class RepeatMatcher(Matcher): 279 | symbol: "Repeat" 280 | 281 | def __init__(self, symbol: "Symbol", effective_item: Matcher): 282 | super().__init__(symbol) 283 | self.effective_item = effective_item 284 | self.inside = False 285 | 286 | def get_effective_matcher(self) -> "Matcher": 287 | return self.effective_item.get_effective_matcher() if self.inside else self 288 | 289 | def get_allowed_tokens(self, g: "Grammar") -> AllowedTokens: 290 | rv = self.effective_item.get_allowed_tokens(g) 291 | if not self.inside: 292 | return rv.set_ahead() 293 | return rv 294 | 295 | def advance(self, g: "Grammar", token_id: int) -> Advance: 296 | a = self.effective_item.advance(g, token_id) 297 | if a == Advance.Reject: 298 | if self.inside: 299 | return a 300 | return Advance.TryNext 301 | elif a == Advance.Done: 302 | if self.symbol.mode == "*": 303 | self.effective_item = g.resolve(self.symbol.item).enter(g) 304 | self.inside = False 305 | return Advance.Again 306 | else: # mode == "?" 307 | return Advance.Done 308 | elif a == Advance.Again: 309 | self.inside = True 310 | return a 311 | -------------------------------------------------------------------------------- /symbols.py: -------------------------------------------------------------------------------- 1 | from extensions.output_template.state_machine import * 2 | import re 3 | 4 | 5 | class Symbol: 6 | def validate(self, g: "Grammar"): 7 | """ 8 | Validates against grammar. 9 | Currently just checks that for each Terminal there is rule defined. 
10 | """ 11 | pass 12 | 13 | def enter(self, g: "Grammar") -> Matcher: 14 | raise NotImplementedError(f"enter on {self.__class__.__name__}") 15 | 16 | 17 | class NonTerminal(Symbol): 18 | def __init__(self, name: str): 19 | self.name = name 20 | 21 | def __repr__(self): 22 | return self.name 23 | 24 | def validate(self, g: "Grammar"): 25 | g.resolve(self).validate(g) 26 | 27 | def enter(self, g: "Grammar") -> Matcher: 28 | return g.resolve(self).enter(g) 29 | 30 | 31 | class Terminal(Symbol): 32 | def __init__(self, value: str): 33 | self.value = value 34 | self.allowed_cache = {} 35 | 36 | def __repr__(self): 37 | return f't{repr(self.value)}' 38 | 39 | def validate(self, g: "Grammar"): 40 | if not self.value: 41 | raise ValueError("empty terminal") 42 | 43 | def enter(self, g: "Grammar") -> Matcher: 44 | return TerminalMatcher(self) 45 | 46 | 47 | class RegExp(Symbol): 48 | allowed_cache = {} 49 | banned_cache = {} 50 | 51 | def __init__(self, value: str): 52 | self.value = value 53 | self.next: Optional[Terminal] = None 54 | if self.value.startswith("[^"): 55 | # To prevent generating giant set of almost all tokens, 56 | # tokens matching negative are banned instead 57 | self.negative = True 58 | self.re = re.compile("[" + value[2:], re.MULTILINE | re.DOTALL) 59 | else: 60 | self.negative = False 61 | self.re = re.compile("^" + value + "$", re.MULTILINE | re.DOTALL) 62 | 63 | def make_re(self): 64 | # https://youtu.be/iQrjbRz3y7A 65 | if self.value.startswith("[^"): 66 | # To prevent generating giant set of almost all tokens, 67 | # tokens matching this rest is banned 68 | self.negative = True 69 | r = "[" + self.value[2:] 70 | if self.next: 71 | r += "(" + re.escape(self.next.value) + ")?" 72 | self.re = re.compile(r, re.MULTILINE | re.DOTALL) 73 | else: 74 | self.negative = False 75 | self.re = re.compile("^" + self.value + "$", re.MULTILINE | re.DOTALL) 76 | 77 | def allow_next(self, t: Terminal): 78 | """ 79 | When regexp is followed by terminal, regexp is configured to also match tokens that include that terminal. 80 | This prevents banning perfectly good tokens and biasing LLM output. 81 | (also see 'test_allow_next') 82 | """ 83 | self.next = t 84 | 85 | def __repr__(self): 86 | return f'r{self.value}' 87 | 88 | def enter(self, g: "Grammar") -> Matcher: 89 | return RegExpMatcher(self) 90 | 91 | 92 | class AnyToken(Symbol): 93 | """ 94 | Special symbol to which '.*' is translated. 95 | Just matches anything, basically turning grammar off. 
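Its matcher allows every token (including EOS), so from this point on generation is unconstrained and may stop at any time.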
96 | """ 97 | 98 | def __repr__(self): 99 | return f'.*' 100 | 101 | def enter(self, g: "Grammar") -> Matcher: 102 | return AnyTokenMatcher(self) 103 | 104 | 105 | class Collection(Symbol): 106 | def __init__(self, items: List[Symbol]): 107 | self.items = items 108 | 109 | 110 | class Sequence(Collection): 111 | def __init__(self, items: List[Symbol]): 112 | super().__init__(items) 113 | self.effective = [] 114 | self.index = 0 115 | 116 | def validate(self, g: "Grammar"): 117 | super().validate(g) 118 | for i in range(len(self.items) - 1): 119 | if (True 120 | and isinstance(self.items[i], Repeat) 121 | and isinstance(self.items[i].item, RegExp) 122 | and isinstance(g.resolve(self.items[i + 1]), Terminal) 123 | ): 124 | # See test_allow_next 125 | self.items[i].item.allow_next(g.resolve(self.items[i + 1])) 126 | 127 | def __repr__(self): 128 | return f'({" ".join([repr(x) for x in self.items])})' 129 | 130 | def enter(self, g: "Grammar") -> Matcher: 131 | return SequenceMatcher(self, [ 132 | None 133 | for m in self.items 134 | ]).ensure_matcher(g) 135 | 136 | 137 | class Alternative(Collection): 138 | def __init__(self, items: List[Matcher]): 139 | super().__init__([]) 140 | for i in items: 141 | if isinstance(i, Alternative): 142 | self.items += i.items 143 | else: 144 | self.items.append(i) 145 | self.possible = set() 146 | 147 | def __repr__(self): 148 | return f'({" | ".join([repr(x) for x in self.items])})' 149 | 150 | def validate(self, g: "Grammar"): 151 | for item in self.items: 152 | g.resolve(item).validate(g) 153 | 154 | def enter(self, g: "Grammar") -> "Matcher": 155 | return AlternativeMatcher(self, { 156 | g.resolve(m).enter(g) 157 | for m in self.items 158 | }) 159 | 160 | 161 | class Repeat(Symbol): 162 | def __init__(self, mode: str, item: Symbol): 163 | assert mode in "*?" 164 | self.item = item 165 | self.mode = mode 166 | 167 | def __repr__(self): 168 | if isinstance(self.item, (Terminal, NonTerminal)): 169 | return f'({repr(self.item)}){self.mode}' 170 | else: 171 | return f'{repr(self.item)}{self.mode}' 172 | 173 | def validate(self, g: "Grammar"): 174 | self.item.validate(g) 175 | 176 | def enter(self, g: "Grammar") -> Matcher: 177 | return RepeatMatcher(self, g.resolve(self.item).enter(g)) 178 | -------------------------------------------------------------------------------- /test_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fake tokenizer implementation used for testing. 3 | Uses only 127 tokens with few IDs reserved for multi-character tokens 4 | to simulate possible issues. 5 | """ 6 | from typing import List 7 | 8 | tokens = { 9 | 0: "\x00", # EOS 10 | 10: "\n", # newline 11 | # Randomly chosen strings 12 | # Ordered from longest to shortest and order is important. 
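# (encode() below returns the first entry whose text is a prefix of the remaining input, so multi-character tokens must come before the single characters added at the end.)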
13 | 1: "Universe", 14 | 2: "token", 15 | 3: "world", 16 | 4: "stock", 17 | 5: '..."', 18 | 6: "hall", 19 | 7: "the", 20 | 8: "...", 21 | 9: "and", 22 | 11: "com", 23 | 12: "Neg", 24 | 13: " ", 25 | 14: "end", 26 | 15: "six", 27 | 16: "tab", 28 | 17: "- [", 29 | 18: "gg", 30 | 19: "He", 31 | 20: "- ", 32 | 21: "ni", 33 | 22: "oo", 34 | 23: "[]", 35 | 24: "or", 36 | 25: '."', 37 | 26: "),", 38 | 27: "of", 39 | 28: "to", 40 | 29: "by", 41 | 30: "++", 42 | 31: "],", 43 | # From space above, all printable characters 44 | **{i: chr(i) for i in range(32, 127)}, 45 | } 46 | 47 | 48 | def encode(text) -> List[int]: 49 | rv = [] 50 | while text: 51 | for i in tokens: 52 | if text.startswith(tokens[i]): 53 | rv.append(i) 54 | text = text[len(tokens[i]):] 55 | break 56 | else: 57 | text = text[1:] 58 | return rv 59 | 60 | 61 | def decode(token_ids: List[int]) -> str: 62 | return "".join([tokens[i] for i in token_ids]) 63 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | 4 | os.environ["OT_TESTING"] = "1" 5 | from extensions.output_template.state_machine import AnyTokenMatcher, TerminalMatcher 6 | from extensions.output_template.script import TemplatingLogitsProcessor, params 7 | from extensions.output_template.utils import encode, decode, shared, MINUS_INF 8 | from extensions.output_template.grammar import Grammar, Repeat, RegExp 9 | from torch import Tensor 10 | import math, random, json 11 | 12 | EOS = shared.tokenizer.eos_token_id 13 | TEMPLATE = """ 14 | root ::= "Alice:" space action 15 | greeting ::= "hello \\"world\\" \\U00003041" 16 | action ::= speech | bullet | command 17 | space ::= " " 18 | bullet ::= ("- " "I'll go" space "to" space location space "and do") [^\n]+ 19 | command ::= "/go " location 20 | location ::= ("hall" | "kitchen") 21 | speech ::= "\\"" [^"\n]+ "\\"" 22 | """ 23 | 24 | 25 | def random_scores(): 26 | return Tensor([[ 0.0001 + (math.floor(random.random() * 100) / 100.0) for _ in range(127) ]]) 27 | 28 | 29 | def set_score(token_id: Union[str, int, list], scores, value=1000.0): 30 | if type(token_id) is list: 31 | for i in token_id: 32 | scores = set_score(i, scores, value) 33 | return scores 34 | if type(token_id) is str: 35 | token_id = encode(token_id)[0] 36 | scores[..., token_id] = value 37 | return scores 38 | 39 | 40 | def scores_to_text(scores): 41 | scores = scores[0] 42 | return " ".join([ 43 | f"{repr(chr(i))}:{scores[i]}" 44 | for i in range(len(scores)) 45 | if scores[i] > 0 46 | ]) 47 | 48 | 49 | def sample_test(scores) -> int: 50 | # Returns single generated token 51 | TemplatingLogitsProcessor()(None, scores) 52 | best = int(scores.argmax()) 53 | grammar: Grammar = params["grammar"] 54 | grammar.advance(best) 55 | return best 56 | 57 | 58 | def get_text(until=EOS, score_fn=random_scores) -> str: 59 | tokens = [] 60 | while True: 61 | t = sample_test(score_fn()) 62 | if t != EOS: 63 | tokens.append(t) 64 | if type(until) is str and until in decode([t]): 65 | break 66 | elif t == until: 67 | break 68 | 69 | return decode(tokens) 70 | 71 | 72 | def test_grammar_parser(): 73 | g = Grammar(TEMPLATE) 74 | assert len(g.rules) == 8 75 | 76 | g.reset(""" 77 | root ::= "hi" 78 | # Testing case when grammar ends with non-terminated line with comment""") 79 | 80 | 81 | def test_terminal(): 82 | grammar: Grammar = params["grammar"] 83 | grammar.reset("""root ::= 'Hello world' [\n]+""") 84 | EOL = 
encode("\n")[0] 85 | t = get_text(EOL) 86 | assert t == "Hello world\n" 87 | assert grammar.active_matcher 88 | 89 | grammar.reset() 90 | scores = grammar.update_scores(random_scores()) 91 | assert len(encode("He")) == 1 92 | assert scores[..., encode("He")] > MINUS_INF 93 | assert scores[..., encode("H")] > MINUS_INF 94 | grammar.advance(encode("H")[0]) 95 | assert ord('e') == sample_test(random_scores()) 96 | matcher = grammar.get_effective_matcher() 97 | while grammar.get_effective_matcher() is matcher: 98 | sample_test(random_scores()) 99 | assert isinstance(grammar.get_effective_matcher().symbol, (RegExp, Repeat)) 100 | 101 | 102 | def test_alternate(): 103 | grammar: Grammar = params["grammar"] 104 | grammar.reset(TEMPLATE) 105 | grammar.enter_rule("action") 106 | sample_test(set_score("/", random_scores())) 107 | text = get_text() 108 | assert text in ("go hall", "go kitchen") 109 | 110 | 111 | def test_sequence(): 112 | grammar: Grammar = params["grammar"] 113 | grammar.reset(TEMPLATE) 114 | get_text(encode(" ")[0]) 115 | sample_test(set_score("-", random_scores())) 116 | assert 32 == sample_test(random_scores()) 117 | tokens = [] 118 | while not isinstance(grammar.get_effective_matcher().symbol, (Repeat, RegExp)): 119 | t = sample_test(random_scores()) 120 | if t == EOS: 121 | break 122 | tokens.append(t) 123 | 124 | assert decode(tokens) in ( 125 | "I'll go to hall and do", 126 | "I'll go to kitchen and do", 127 | ) 128 | 129 | 130 | def test_regexp(): 131 | grammar: Grammar = params["grammar"] 132 | grammar.reset("root ::= [a-z]") 133 | scores = TemplatingLogitsProcessor()(None, random_scores()) 134 | assert len([x for x in scores[0] if x > MINUS_INF]) == 26 135 | assert scores[..., EOS] == MINUS_INF # Also make sure that EOS is banned in any case 136 | 137 | # Tests that combination of repetition and regexp allows multi-char tokens 138 | grammar.reset("root ::= [a-z]+") 139 | scores = TemplatingLogitsProcessor()(None, random_scores()) 140 | assert len([x for x in scores[0] if x > MINUS_INF]) > 26 141 | assert scores[..., encode("world")] > MINUS_INF 142 | assert scores[..., EOS] == MINUS_INF 143 | 144 | # Tests banning specific character 145 | grammar.reset("root ::= [^\n]+") 146 | scores = TemplatingLogitsProcessor()(None, random_scores()) 147 | assert scores[..., 10] < 0 148 | assert scores[..., EOS] == MINUS_INF 149 | 150 | 151 | def test_repeat(): 152 | # Tests "speech" rule in TEMPLATE grammar 153 | grammar: Grammar = params["grammar"] 154 | grammar.reset(TEMPLATE) 155 | get_text(encode(" ")[0]) 156 | sample_test(set_score('"', random_scores())) 157 | q = encode('"')[0] 158 | for _ in range(10): 159 | # Just generate few random chars 160 | assert sample_test(set_score('"', random_scores(), MINUS_INF)) not in (EOS, q) 161 | # Test that " is still allowed to be generated 162 | scores = TemplatingLogitsProcessor()(None, set_score('"', random_scores())) 163 | assert scores[..., q] > MINUS_INF 164 | assert q == sample_test(scores) 165 | assert EOS == sample_test(random_scores()) 166 | 167 | # Tests grammar that generates repeated combination of foobars 168 | grammar: Grammar = params["grammar"] 169 | grammar.reset(""" 170 | root ::= many 171 | many ::= one one one+ 172 | one ::= foo | bar 173 | foo ::= "foo" 174 | bar ::= "b" "a"+ "r" 175 | """) 176 | text = get_text() 177 | assert "foo" in text or "ba" in text 178 | 179 | # Actual case that was broken originally 180 | grammar: Grammar = params["grammar"] 181 | grammar.reset(""" 182 | root ::= line line line 183 | line ::= 
"Alice:" space action newline 184 | action ::= speech | bullet | command 185 | space ::= " " 186 | newline ::= "\n" 187 | bullet ::= ("- " "I'll go" space "to" space location space "and do") gibberish 188 | command ::= "/go" space location 189 | location ::= ("hall" | "out" | "room") 190 | speech ::= "\\"" ">" gibberish "\\"" 191 | gibberish ::= [^-/">\n]+ 192 | """) 193 | text = get_text() 194 | commands = len(text.split("/go")) - 1 195 | bullets = len(text.split("-")) - 1 196 | speech = len(text.split(">")) - 1 197 | assert 3 == commands + bullets + speech 198 | 199 | 200 | def test_json(): 201 | grammar: Grammar = params["grammar"] 202 | grammar.reset(""" 203 | root ::= object 204 | value ::= object | array | string | number | ("true" | "false" | "null") ws 205 | 206 | object ::= 207 | "{" ws ( 208 | string ":" ws value 209 | ("," ws string ":" ws value)* 210 | )? "}" ws 211 | 212 | array ::= 213 | "[" ws ( 214 | value 215 | ("," ws value)* 216 | )? "]" ws 217 | 218 | string ::= 219 | "\\"" ( 220 | [^"\\\\\n] | 221 | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes 222 | )* "\\"" ws 223 | 224 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws 225 | 226 | # Optional space: by convention, applied in this grammar after literal chars when allowed 227 | ws ::= ([ \\t\\n] ws)? 228 | """) 229 | 230 | random.seed(2342343231) 231 | # 1st token has to be { 232 | assert ord("{") == sample_test(random_scores()) 233 | # Any number of whitespace has to be allowed 234 | assert ord(" ") == sample_test(set_score(32, random_scores(), 1000.0)) 235 | assert ord(" ") == sample_test(set_score(32, random_scores(), 1000.0)) 236 | assert ord(" ") == sample_test(set_score(32, random_scores(), 1000.0)) 237 | assert ord("\n") == sample_test(set_score(10, random_scores(), 1000.0)) 238 | # Ban }. If not whitespace, only '"' may follow 239 | assert ord('"') == sample_test(set_score(list("} \n"), random_scores(), MINUS_INF)) 240 | # Grab some tokens and force string to finish 241 | for _ in range(5): 242 | sample_test(set_score('"', random_scores(), MINUS_INF)) 243 | assert ord('"') == sample_test(set_score('"', random_scores(), 1000)) 244 | # Now only whitespace and ':' should be allowed 245 | scores = random_scores() 246 | TemplatingLogitsProcessor()(None, scores) 247 | assert scores[..., ord(':')] > 0 248 | assert len([x for x in scores[0] if x > 0]) == 3 # ':', space and newline 249 | # Go over ':' and ban whitespace. Now start-of-value tokens should be allowed 250 | sample_test(set_score(':', random_scores())) 251 | scores = set_score(list(' \n'), random_scores(), MINUS_INF) 252 | TemplatingLogitsProcessor()(None, scores) 253 | # 14 characters = 0-9, minus sign, {, [ and '"'. 
254 | # Additionally, 3 characters for n, t and f of null, true and false 255 | assert len([ x for x in scores[0] if x > 0 ]) == 17 256 | # Force 'true' and verify repetition 257 | sample_test(set_score('t', random_scores())) 258 | assert "true" == 't' + get_text('e') 259 | assert ord(',') == sample_test(set_score(list('} \n'), random_scores(), MINUS_INF)) 260 | # Force another key and make sure it generates quotes properly (this was failing before) 261 | sample_test(set_score('"', random_scores())) 262 | 263 | a = get_text('"') 264 | assert '"' not in a[:-1] 265 | scores = random_scores() 266 | TemplatingLogitsProcessor()(None, scores) 267 | assert len([x for x in scores[0] if x > 0]) == 3 # ':', space and newline 268 | assert ord(':') == sample_test(set_score(list(' \n'), random_scores(), MINUS_INF)) 269 | 270 | # Now just restart grammar and test it generates some proper json 271 | for _ in range(100): 272 | # (and do it few times, cos it tends to generate just {}) 273 | grammar.reset() 274 | a = get_text() 275 | json.loads(a) 276 | 277 | 278 | def test_any_token(): 279 | """ Test nonstandard rule to disable grammar """ 280 | grammar: Grammar = params["grammar"] 281 | grammar.reset(""" 282 | root ::= (donotend) 283 | donotend ::= (.*) 284 | """) 285 | for i in range(255): 286 | z = sample_test(random_scores()) 287 | assert isinstance(grammar.get_effective_matcher(), AnyTokenMatcher) 288 | 289 | 290 | def test_allow_next(): 291 | """ 292 | Tests workaround for following issue: 293 | Given grammar rule like root ::= [^"]* '"' expressing "anything until quotation mark, then quotation mark", 294 | LogitsProcessor actually bans most of the tokens ending in " as they don't match 1st rule in sequence. 295 | This would be technically okay and would generate correct output, but banning all those tokens would 296 | result in changing LLM behaviour too much and cause it to generate much longer quoted strings. 297 | """ 298 | # TODO: it may be good idea to do similar workaround for regexp followed by Alternative, but for now this 299 | # TODO: should fix the biggest issue. 
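    # Below, the multi-character token '."' would normally be banned by [^"]*; thanks to allow_next its trailing '"' counts as the closing quote, so after it is consumed the matcher should sit on the final 'H' terminal.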
300 | grammar: Grammar = params["grammar"] 301 | grammar.reset(""" 302 | root ::= qm [^"]* qm 'H' 303 | qm ::= '"' 304 | """) 305 | 306 | # Sanity check for test tokenizer tokenizer 307 | assert len(encode('."')) == 1 308 | TOKEN = encode('."')[0] 309 | Q = encode('"')[0] 310 | 311 | # Go over " to reach 'anything but " part of seuquence' 312 | scores = TemplatingLogitsProcessor()(None, random_scores()) 313 | assert scores[..., Q] > MINUS_INF 314 | grammar.advance(Q) 315 | # Check that '."' token is allowed as it fits above grammar 316 | scores = TemplatingLogitsProcessor()(None, random_scores()) 317 | assert scores[..., TOKEN] > MINUS_INF 318 | grammar.advance(TOKEN) 319 | 320 | # Check that grammar moved to last rule 321 | assert isinstance(grammar.get_effective_matcher(), TerminalMatcher) 322 | assert grammar.get_effective_matcher().symbol.value == "H" 323 | 324 | 325 | if __name__ == "__main__": 326 | params["scores_size"] = 127 327 | params["enabled"] = True 328 | test_grammar_parser() 329 | test_terminal() 330 | test_alternate() 331 | test_sequence() 332 | test_regexp() 333 | test_repeat() 334 | test_json() 335 | test_any_token() 336 | test_allow_next() 337 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set, Dict 2 | import torch, os 3 | MINUS_INF = -float("inf") 4 | 5 | 6 | if "OT_TESTING" in os.environ: 7 | from collections import namedtuple 8 | shared = namedtuple("Shared", ['tokenizer'])(namedtuple("Tokenizer", ['eos_token_id'])(0)) 9 | from extensions.output_template.test_tokenizer import encode, decode 10 | else: 11 | from modules import shared 12 | 13 | def encode(text) -> List[int]: 14 | return shared.tokenizer.encode(str(text), add_special_tokens=False) 15 | 16 | def decode(token_ids: List[int]) -> str: 17 | return shared.tokenizer.decode(token_ids) 18 | 19 | 20 | class AllowedTokens: 21 | """ 22 | Utility class used to combine and eventually allow (or ban) generation of those tokens that may match 23 | (or for sure don't match) template. 24 | 25 | When applied on scores, following order is done: 26 | 1. If 'banned' is not empty, all but banned tokens are allowed. 27 | If token is both allowed and banned, it's allowed. 28 | 2. if 'allowed' is not empty (and banned is), all but allowed tokens are banned. 29 | 3. if 'allow_eos' is False, end-of-string token is banned in any case. 30 | 31 | 'look_ahead' is used by Repeat symbol to signal that next symbol should also be considered. 
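For example, combining allowed={3} with allowed={5} gives allowed={3, 5}, while combining banned={3, 5} with banned={5, 9} gives banned={5}: a token stays banned only if every combined branch bans it.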
32 | """ 33 | def __init__(self, *, allowed=None, banned=None, look_ahead=False, allow_eos=False): 34 | self.allowed: Set[int] = allowed or set() 35 | self.banned: Set[int] = banned or set() 36 | assert (self.allowed and not self.banned) or (self.banned and not self.allowed) or not (self.allowed and self.banned) 37 | self.look_ahead = look_ahead 38 | self.allow_eos = allow_eos 39 | 40 | def combine(self, other: "AllowedTokens") -> "AllowedTokens": 41 | """ Returns new instance which is combination of self and other """ 42 | allowed = set() 43 | banned = set() 44 | if (not self.allowed and not self.banned) or (not other.allowed and not other.banned): 45 | # One of self/other is 'allow all' 46 | pass 47 | elif self.allowed and other.allowed: 48 | # Both are 'allow only these' 49 | assert not self.banned and not other.banned 50 | allowed = self.allowed.union(other.allowed) 51 | elif self.banned and other.banned: 52 | # Both are 'ban only these' 53 | assert not self.allowed and not other.allowed 54 | banned = self.banned.intersection(other.banned) 55 | elif self.allowed and other.banned: 56 | # I have allowed tokens, other has banned tokens. 57 | # Allow everything but those we both banned 58 | assert not self.banned and not other.allowed 59 | banned = other.banned - self.allowed 60 | elif other.allowed and self.banned: 61 | # As above but reversed 62 | return other.combine(self) 63 | else: 64 | assert False, "impossible combination" 65 | 66 | return AllowedTokens( 67 | allow_eos=self.allow_eos or other.allow_eos, 68 | look_ahead=self.look_ahead or other.look_ahead, 69 | allowed=allowed, 70 | banned=banned, 71 | ) 72 | 73 | def set_ahead(self): 74 | """ Returns copy of self with 'look_ahead' set to True """ 75 | return AllowedTokens( 76 | allow_eos=self.allow_eos, 77 | allowed=self.allowed, 78 | banned=self.banned, 79 | look_ahead=True, 80 | ) 81 | 82 | def __repr__(self): 83 | data = [] 84 | if self.look_ahead or self.allow_eos: 85 | data.append(",".join([ 86 | "ahead" if self.look_ahead else "", 87 | "eos" if self.allow_eos else "" 88 | ]).strip(",")) 89 | data.append(f"allowed={self.allowed}") 90 | data.append(f"banned={self.banned}") 91 | return f"" 92 | 93 | def apply(self, scores: torch.FloatTensor): 94 | if self.allowed and not self.banned: 95 | s = scores.new_full(scores.shape, False, dtype=torch.bool) 96 | for a in self.allowed: 97 | s[..., a] = True 98 | if self.allow_eos: 99 | s[..., shared.tokenizer.eos_token_id] = True 100 | scores[~s] = MINUS_INF 101 | if self.banned or not self.allow_eos: 102 | s = scores.new_full(scores.shape, True, dtype=torch.bool) 103 | for a in self.banned: 104 | if a not in self.allowed: 105 | s[..., a] = False 106 | if not self.allow_eos: 107 | s[..., shared.tokenizer.eos_token_id] = False 108 | scores[~s] = MINUS_INF 109 | 110 | 111 | def get_token_dictionary() -> Dict[int, str]: 112 | from extensions.output_template.script import params, logger 113 | if not params["token_dictionary"] or params["used_tokenizer"] is not shared.tokenizer: 114 | assert params["scores_size"] 115 | logger.info("output_template: Creating token dictionary. 
This takes a few seconds, but is done only once.") 116 | if "OT_TESTING" in os.environ: 117 | params["token_dictionary"] = {i: decode([i]) for i in range(params["scores_size"])} 118 | else: 119 | def convert_ids_to_tokens(i): 120 | # Wraps shared.tokenizer.convert_ids_to_tokens to work around possibly missing tokens 121 | try: 122 | return shared.tokenizer.convert_ids_to_tokens(i) 123 | except IndexError: 124 | return None 125 | params["token_dictionary"] = { 126 | token_id: ( 127 | shared.tokenizer.decode([token_id]) 128 | if "▁" not in tmp 129 | else tmp.replace("▁", " ") 130 | ) 131 | for (token_id, tmp) in ( 132 | (i, convert_ids_to_tokens(i)) 133 | for i in range(params["scores_size"]) 134 | ) 135 | if tmp 136 | } 137 | # import json 138 | # open("/tmp/dict.json", "w").write(json.dumps(params["token_dictionary"])) 139 | params["used_tokenizer"] = shared.tokenizer 140 | logger.info("output_template: Done creating token dictionary.") 141 | return params["token_dictionary"] 142 | --------------------------------------------------------------------------------