├── .gitignore
├── __init__.py
├── sampler.py
├── readme.md
├── tokenizer.py
├── debug_tools.py
├── bnf_simple.py
├── example.py
├── penalty.py
├── easy_schema.py
├── bnf_complex.py
├── json_schema.py
├── bnf.py
├── grammar_pipeline.py
└── pipeline.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .penalty import GlobalPenalty, SlidingPenalty
2 | from .tokenizer import Plain, StringTokenizer, HFTokenizer, TikTokenizer, RWKVTokenizer
3 | from .pipeline import Pipeline, StatefulPipeline, GenerationArgs
4 | 
5 | from . import tokenizer
6 | from . import penalty
7 | from . import pipeline
8 | 
--------------------------------------------------------------------------------
/sampler.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | 
3 | 
4 | class LogitsTransformer:
5 |     """
6 |     Transforms a tensor of logits into a new tensor of logits.
7 | 
8 |     An example is filtering the logits so that only tokens allowed by a grammar remain.
9 |     """
10 | 
11 |     def transform(self, logits: Tensor) -> Tensor:
12 |         """
13 |         Transform the logits.
14 |         """
15 |         ...
16 | 
17 | 
18 | class LogitsSampler:
19 |     """
20 |     Samples a token from the logits.
21 |     """
22 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # rwkv_contrib
2 | 
3 | A collection of RWKV-related Python modules I use in my projects.
4 | 
5 | The project setup is extremely weird and homegrown, so it is advised not to use this as a dependency; instead, read the source code and copy the parts you need.
6 | 
7 | ## BNF-like grammar
8 | 
9 | A BNF-like grammar was implemented to constrain the output of the model. It is not a full BNF implementation, but rather a subset of it. (And at least a bit more intuitive than full BNF.)
10 | 
11 | Please refer to `bnf.py`, `bnf_complex.py`, and `bnf_simple.py` for the implementation. For the grammars defined on top of it, check `easy_schema.py` and `json_schema.py`. For an actual pipeline implementation, check `grammar_pipeline.py`.
12 | 
13 | ## Debugging tools
14 | 
15 | `debug_tools.py` contains a few debugging tools I use in my projects. It is mainly used to dump the state of the model.
16 | 
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | from os import PathLike
4 | from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
5 | from typing import Generic, List, TypeVar, TYPE_CHECKING
6 | 
7 | if TYPE_CHECKING:
8 |     from tokenizers import Tokenizer
9 |     import tiktoken
10 | 
11 | T = TypeVar('T')
12 | 
13 | 
14 | class Tokenizer(Generic[T]):
15 |     def encode(self, x: T) -> List[int]: ...
16 |     def decode(self, tokens: List[int]) -> T: ...
17 |     def validate(self, x: T) -> bool: ...
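18 | 
19 | # A hypothetical minimal implementation of the protocol above (illustration
20 | # only, not part of the original module): a byte-level tokenizer where every
21 | # byte is its own token id.
22 | #
23 | #     class ByteTokenizer(Tokenizer[bytes]):
24 | #         def encode(self, x: bytes) -> List[int]: return list(x)
25 | #         def decode(self, tokens: List[int]) -> bytes: return bytes(tokens)
26 | #         def validate(self, x: bytes) -> bool: return True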
27 | 
28 | 
29 | class Plain(Tokenizer[List[int]]):
30 |     def encode(self, x: List[int]) -> List[int]: return x
31 |     def decode(self, tokens: List[int]) -> List[int]: return tokens
32 |     def validate(self, x: List[int]) -> bool: return True
33 | 
34 | 
35 | class StringTokenizer(Tokenizer[str]):
36 |     def validate(self, x: str) -> bool: return '\ufffd' not in x
37 | 
38 | 
39 | try:
40 |     from tokenizers import Tokenizer as _Tokenizer
41 | 
42 |     class HFTokenizer(StringTokenizer):
43 |         def __init__(self, path: PathLike):
44 |             self.tokenizer = _Tokenizer.from_file(str(path))  # from_file expects a str path
45 | 
46 |         def encode(self, x: str) -> List[int]:
47 |             return self.tokenizer.encode(x).ids
48 | 
49 |         def decode(self, tokens: List[int]) -> str:
50 |             return self.tokenizer.decode(tokens)
51 | 
52 | except ImportError:
53 |     ...
54 | 
55 | try:
56 |     import tiktoken
57 | 
58 |     class TikTokenizer(StringTokenizer):
59 |         def __init__(self):
60 |             self.tokenizer = tiktoken.get_encoding("cl100k_base")
61 |             self.encode = self.tokenizer.encode
62 |             self.decode = self.tokenizer.decode
63 | 
64 | except ImportError:
65 |     ...
66 | 
67 | import inspect
68 | 
69 | RWKV_BASE = os.path.dirname(inspect.getfile(TRIE_TOKENIZER))
70 | 
71 | class RWKVTokenizer(StringTokenizer):
72 |     def __init__(self):
73 |         self.tokenizer = TRIE_TOKENIZER(os.path.join(RWKV_BASE, 'rwkv_vocab_v20230424.txt'))
74 |         self.encode = self.tokenizer.encode
75 |         self.decode = self.tokenizer.decode
76 | 
--------------------------------------------------------------------------------
/debug_tools.py:
--------------------------------------------------------------------------------
1 | """
2 | debug_tools
3 | ===========
4 | 
5 | This module contains tools for debugging and testing.
6 | """
7 | 
8 | 
9 | from copy import deepcopy
10 | from numpy import ndarray
11 | import numpy as np
12 | from torch import Tensor
13 | from typing import TYPE_CHECKING, Optional
14 | from pathlib import Path
15 | 
16 | if TYPE_CHECKING:
17 |     from rwkv_contrib.pipeline import StatefulPipeline, GenerationArgs
18 | 
19 | 
20 | class StateDump:
21 |     """
22 |     Represents a dump of the per-layer state of RWKV-LM (e.g. all 32 layers of the 7B model).
23 |     """
24 | 
25 |     layers: list[list[float]]
26 |     labels: list[str]
27 |     output_prefix: str
28 |     token_count: int
29 | 
30 |     def __init__(self, output_prefix: str) -> None:
31 |         self.layers = []
32 |         self.labels = []
33 |         self.output_prefix = output_prefix
34 |         self.token_count = 0
35 | 
36 |     def wraps(self, pipeline: "StatefulPipeline") -> None:
37 |         """
38 |         Wraps the 'infer' method of a pipeline.
39 |         """
40 | 
41 |         pipeline_infer = pipeline.infer
42 | 
43 |         def infer(tokens: list[int], args: "GenerationArgs" = None) -> tuple[Optional[int], list[Tensor]]:
44 |             """
45 |             Wrapped 'infer' that also records a snapshot of each layer's state.
46 |             """
47 |             token, state = pipeline_infer(tokens, args)
48 |             state: list[Tensor]
49 |             state_chunks: list[tuple[Tensor, ...]] = list(zip(*[iter(deepcopy(state))] * 5))  # group the flat state into per-layer chunks of 5 tensors
50 |             for idx, (att_xx, att_aa, att_bb, att_pp, ffn_xx) in enumerate(state_chunks):
51 |                 # att_pp is excluded because it increases monotonically, which makes the UMAP representation show a weird linear pattern.
52 |                 concatenated = att_aa.float().tolist() + att_bb.float().tolist() + att_xx.float().tolist() + ffn_xx.float().tolist()
53 |                 self.layers.append(concatenated)
54 |                 self.labels.append(f"{self.token_count}-l-{idx}")
55 |             self.token_count += 1
56 |             return token, state
57 | 
58 |         pipeline.infer = infer
59 | 
60 |     def dumps(self) -> None:
61 |         """
62 |         Dumps the state to disk.
63 | 
64 |         Will clear the state after dumping.
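65 | 
66 |         Hypothetical usage (the pipeline object here is assumed, not part of
67 |         this module):
68 | 
69 |             dump = StateDump("run1")
70 |             dump.wraps(pipeline)   # record a snapshot on every infer()
71 |             dump.dumps()           # writes run1.layers.npy and run1.labels.npy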
65 | """ 66 | layers_output = Path(f"{self.output_prefix}.layers.npy") 67 | labels_output = Path(f"{self.output_prefix}.labels.npy") 68 | 69 | np.save(layers_output, np.array(self.layers)) 70 | np.save(labels_output, np.array(self.labels)) 71 | 72 | 73 | self.layers = [] 74 | self.labels = [] 75 | self.token_count = 0 76 | 77 | def loads(self) -> tuple[ndarray, ndarray]: 78 | """ 79 | Loads the state from disk. 80 | 81 | Returns a tuple of (layers, labels). 82 | """ 83 | 84 | layers_output = Path(f"{self.output_prefix}.layers.npy") 85 | labels_output = Path(f"{self.output_prefix}.labels.npy") 86 | 87 | return np.load(layers_output), np.load(labels_output) 88 | -------------------------------------------------------------------------------- /bnf_simple.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | from rwkv_contrib.bnf import Action, BNFTree, Node, token_size, token_table 3 | 4 | 5 | class SimpleNode(Node): 6 | def get_sub(self, idx: int | None) -> int | None: 7 | return None 8 | 9 | def handle_match(self, callstack: list[tuple[int, int]]) -> Action: 10 | callstack.pop() # pop current node as we are done with it 11 | return Action.CONSUME 12 | 13 | def handle_fail(self, callstack: list[tuple[int, int]]) -> Action: 14 | callstack.pop() 15 | return Action.FAIL 16 | 17 | def add_to_stack(self, callstack: list[tuple[int, int]]) -> None: 18 | callstack.append((self.bnf_index, 0)) 19 | 20 | def matched(self, callstack: list[tuple[int, int]]) -> bool: 21 | return False # if simple node is on the stack, it is not matched yet or it would have been popped 22 | 23 | 24 | class CharNode(SimpleNode): 25 | """ 26 | Represents a node that accepts a single character. 27 | """ 28 | 29 | char: bytes 30 | 31 | def __init__(self, bnf_tree: BNFTree, char: bytes) -> None: 32 | self.char = char 33 | super().__init__(bnf_tree) 34 | 35 | def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 36 | return self.char == char 37 | 38 | @cache 39 | def get_logits(self) -> set[int]: 40 | return {k for k, v in token_table.items() if v.startswith(self.char)} 41 | 42 | 43 | class NotCharNode(SimpleNode): 44 | """ 45 | Represents a node that accepts any character except a single character. 46 | """ 47 | 48 | chars: list[bytes] 49 | 50 | def __init__(self, bnf_tree: BNFTree, char: bytes) -> None: 51 | self.chars = [bytes([x]) for x in char] 52 | super().__init__(bnf_tree) 53 | 54 | def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 55 | return char not in self.chars 56 | 57 | @cache 58 | def get_logits(self) -> set[int]: 59 | return {k for k, v in token_table.items() if all(not v.startswith(x) for x in self.chars)} 60 | 61 | 62 | class WildcardNode(SimpleNode): 63 | """ 64 | Represents a node that accepts any character. 65 | """ 66 | 67 | def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 68 | return char is not None # any character is accepted but None (end of string) is not 69 | 70 | @cache 71 | def get_logits(self) -> set[int]: 72 | return {x for x in token_table.items()} 73 | 74 | 75 | class DigitNode(SimpleNode): 76 | """ 77 | Represents a node that accepts any digit. 
78 | """ 79 | 80 | def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 81 | return char is not None and char.isdigit() # any digit is accepted but None (end of string) is not 82 | 83 | @cache 84 | def get_logits(self) -> set[int]: 85 | return {x for x in range(token_size) if x in token_table and token_table[x].isdigit()} 86 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | import warnings 3 | from bnf_complex import RepeatNode, SequenceNode 4 | from bnf_simple import NotCharNode 5 | from easy_schema import literal 6 | 7 | from pipeline import GenerationArgs 8 | 9 | # suppress warnings in v5 model 10 | warnings.filterwarnings("ignore", category=UserWarning) 11 | if True: 12 | # setup RWKV_JIT_ON and RWKV_CUDA_ON 13 | environ["RWKV_CUDA_ON"] = "1" 14 | environ["RWKV_JIT_ON"] = "1" 15 | from rwkv.model import RWKV 16 | 17 | from bnf import BNFTree, setup_tokens 18 | from tokenizer import RWKVTokenizer 19 | import bnf 20 | from grammar_pipeline import BNFPipeline 21 | 22 | summary_tree = BNFTree() 23 | 24 | # setup tokenizers for bnf to use 25 | # as the implementation will filter available logits 26 | bnf.base_tokenizer = RWKVTokenizer().tokenizer 27 | setup_tokens() 28 | 29 | # state is frozen in the pipeline to reduce overhead for prompt. 30 | firestarter = ":" 31 | summary_prompt = f""" 32 | Instruction: Summarize the input email in 1-2 sentences, and suggest what to do next. 33 | 34 | Input: 35 | Title: Let’s start your Green Diet! 36 | Email: Follow Conference Lodge on Facebook and Instagram to keep posted of the latest 37 | news and promotions 38 | \------------------------------------------------------- 39 | Please do not reply to this email message. 40 | \------------------------------------------------------- 41 | To unsubscribe from this communication, please visit webpage: 42 | https://myaccount.ust.hk/refreshable_lists 43 | 44 | Response 45 | """.strip() 46 | 47 | # define the grammar 48 | # basically, it matches a following pattern: 49 | # This email is about (...). It says (...). 50 | # So, as a wise assistant, I think you should (...). 51 | # 52 | summary_matcher = SequenceNode( 53 | summary_tree, 54 | [ 55 | literal(summary_tree, b" This email is about "), 56 | RepeatNode(summary_tree, NotCharNode(summary_tree, b",.!?\n")), 57 | literal(summary_tree, b". 
It says "), 58 | RepeatNode(summary_tree, NotCharNode(summary_tree, b".,\n")), 59 | literal(summary_tree, b".\nSo, as a wise assistant, I think you should "), 60 | RepeatNode(summary_tree, NotCharNode(summary_tree, b".!?\n")), 61 | literal(summary_tree, b".\n"), 62 | ], 63 | ) 64 | 65 | 66 | # create rwkv instance, look at the temperature and top_p here 67 | # also don't forget to set a correct model path 68 | rwkv = RWKV(model="../models/RWKV-4-World-7B-v1-20230626-ctx4096.pth", strategy="cuda fp16") 69 | args = GenerationArgs( 70 | temperature=2.5, 71 | top_p=0.6, 72 | alpha_frequency=0.3, 73 | alpha_presence=0.3, 74 | ) 75 | 76 | # create the pipeline 77 | # a model can be used for multiple pipelines 78 | # the logits_cache is used to cache the allowed logits for each node 79 | summary_pipeline = BNFPipeline( 80 | rwkv, 81 | summary_tree, 82 | summary_matcher, 83 | logits_cache="logits_cache_summary.npz", 84 | default_args=args, 85 | ) 86 | 87 | # infer the state from the prompt 88 | # so for a continuing generation, the state can be reused for multiple times 89 | print(summary_prompt + firestarter, end="") 90 | _, state = summary_pipeline.infer(summary_pipeline.encode(summary_prompt)) 91 | 92 | # actually generate the summary, note that the state is modified in-place 93 | # if multiple generations are needed, the state should be deepcopied 94 | for partial in summary_pipeline.generate(firestarter, 256, state=state): 95 | print(partial, end="", flush=True) 96 | 97 | # dump the logits cache, by default it will be saved to the cache in the pipeline 98 | summary_pipeline.dump_logits() 99 | -------------------------------------------------------------------------------- /penalty.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import TYPE_CHECKING, Any, Deque, Dict 3 | from abc import ABCMeta, abstractmethod 4 | import numpy as np 5 | from torch import Tensor 6 | from collections import deque 7 | 8 | 9 | if TYPE_CHECKING: 10 | from .pipeline import GenerationArgs 11 | 12 | 13 | class Penalty(metaclass=ABCMeta): 14 | @abstractmethod 15 | def transform(self, out: Tensor, args: "GenerationArgs") -> Tensor: 16 | """ 17 | Transform the logits with the penalty. 18 | """ 19 | 20 | @abstractmethod 21 | def update(self, token: int, args: "GenerationArgs"): 22 | """ 23 | Update the penalty with the token. 24 | """ 25 | 26 | @abstractmethod 27 | def clear(self): 28 | """ 29 | Clear the penalty. 30 | """ 31 | 32 | @abstractmethod 33 | def copy(self) -> "Penalty": 34 | """ 35 | Copy the penalty, for history. 
36 | """ 37 | 38 | 39 | class GlobalPenalty(Penalty): 40 | def __init__(self) -> None: 41 | self.token_occurrences = {} 42 | 43 | def transform(self, out: Tensor, args: "GenerationArgs") -> Tensor: 44 | for n in self.token_occurrences: 45 | out[n] -= args.alpha_presence + self.token_occurrences[n] * args.alpha_frequency 46 | return out 47 | 48 | def update(self, token: int, args: "GenerationArgs"): 49 | if token not in self.token_occurrences: 50 | self.token_occurrences[token] = 1 51 | else: 52 | self.token_occurrences[token] += 1 53 | 54 | def clear(self): 55 | self.token_occurrences = {} 56 | 57 | def copy(self) -> "GlobalPenalty": 58 | ret = GlobalPenalty() 59 | ret.token_occurrences = self.token_occurrences.copy() 60 | return ret 61 | 62 | 63 | class SlidingPenalty(Penalty): 64 | def __init__(self, maxlen: int = 512) -> None: 65 | self.maxlen = maxlen 66 | self.token_occurrences: Deque[int] = deque() 67 | self.occurrences: Dict[int, int] = {} 68 | 69 | def transform(self, out: Tensor, args: "GenerationArgs") -> Tensor: 70 | for n in self.occurrences: 71 | out[n] -= args.alpha_presence + self.occurrences[n] * args.alpha_frequency 72 | return out 73 | 74 | def update(self, token: int, args: "GenerationArgs"): 75 | self.token_occurrences.appendleft(token) 76 | if token not in self.occurrences: 77 | self.occurrences[token] = 1 78 | else: 79 | self.occurrences[token] += 1 80 | 81 | if len(self.token_occurrences) > self.maxlen: 82 | while len(self.token_occurrences) > self.maxlen: 83 | token = self.token_occurrences.pop() 84 | self.occurrences[token] -= 1 85 | 86 | def clear(self): 87 | self.token_occurrences.clear() 88 | self.occurrences = {} 89 | 90 | def copy(self) -> "SlidingPenalty": 91 | ret = SlidingPenalty(self.maxlen) 92 | ret.token_occurrences = self.token_occurrences.copy() 93 | ret.occurrences = self.occurrences.copy() 94 | return ret 95 | 96 | 97 | class LogPenalty(Penalty): 98 | token_penalties: np.ndarray[Any, np.dtype[np.float32]] 99 | tensor_penalties: Tensor | None 100 | table_size: int 101 | 102 | def __init__(self, table_size: int = 65536) -> None: 103 | self.token_penalties = np.zeros(table_size, dtype=np.float32) 104 | self.table_size = table_size 105 | self.tensor_penalties = None 106 | 107 | def transform(self, out: Tensor, args: "GenerationArgs") -> Tensor: 108 | # create a tensor from penalties 109 | if self.tensor_penalties is None: 110 | penalties = Tensor(self.token_penalties) 111 | if out.device != penalties.device: 112 | penalties = penalties.to(out.device) 113 | self.tensor_penalties = penalties 114 | out -= self.tensor_penalties 115 | return out 116 | 117 | def update(self, token: int, args: "GenerationArgs"): 118 | to_update = self.token_penalties if self.tensor_penalties is None else self.tensor_penalties 119 | to_update *= args.alpha_decay 120 | to_update[token] += args.alpha_presence 121 | 122 | def clear(self): 123 | self.token_penalties = np.zeros(self.table_size, dtype=np.float32) 124 | self.tensor_penalties = None 125 | 126 | def copy(self) -> "LogPenalty": 127 | ret = LogPenalty(self.table_size) 128 | ret.token_penalties = deepcopy(self.token_penalties) 129 | ret.tensor_penalties = None if self.tensor_penalties is None else self.tensor_penalties.clone() 130 | return ret 131 | 132 | -------------------------------------------------------------------------------- /easy_schema.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | from rwkv_contrib.bnf import BNFTree, Node 3 | from 
7 | from rwkv_contrib.bnf_simple import CharNode, NotCharNode
8 | 
9 | 
10 | 
11 | @cache
12 | def literal(tree: BNFTree, literal: bytes):
13 |     literal_parsed = [bytes([x]) for x in literal]
14 |     return SequenceNode(tree, [CharNode(tree, x) for x in literal_parsed])
15 | 
16 | 
17 | @cache
18 | def non_newline(tree: BNFTree) -> Node:
19 |     return RepeatNode(tree, NotCharNode(tree, b"\n"))
20 | 
21 | 
22 | @cache
23 | def quoted(tree: BNFTree, quote: bytes, inner: Node = None) -> Node:
24 |     if inner is None:
25 |         inner = NotCharNode(tree, b"\n")
26 |     quote = [bytes([x]) for x in quote]
27 |     start, end = quote
28 |     return SequenceNode(
29 |         tree,
30 |         [
31 |             CharNode(tree, start),
32 |             RepeatNode(
33 |                 tree,
34 |                 OrNode(
35 |                     tree,
36 |                     [
37 |                         PopNode(tree, CharNode(tree, end), depth=2),
38 |                         inner,
39 |                     ],
40 |                 ),
41 |             ),
42 |         ],
43 |     )
44 | 
45 | 
46 | @cache
47 | def asterisks(tree: BNFTree, inner: Node = None) -> Node:
48 |     return quoted(tree, b"**", inner)
49 | 
50 | 
51 | @cache
52 | def parentheses(tree: BNFTree, inner: Node = None) -> Node:
53 |     return quoted(tree, b"()", inner)
54 | 
55 | 
56 | @cache
57 | def square_brackets(tree: BNFTree, inner: Node = None) -> Node:
58 |     return quoted(tree, b"[]", inner)
59 | 
60 | 
61 | @cache
62 | def curly_brackets(tree: BNFTree, inner: Node = None) -> Node:
63 |     return quoted(tree, b"{}", inner)
64 | 
65 | 
66 | @cache
67 | def infinite_list(tree: BNFTree, bullet: bytes, item_constraint: Node = None) -> Node:
68 |     """
69 |     Represents an unnumbered list of items.
70 | 
71 |     Such a list is defined as:
72 | 
73 |     ```
74 |     {bullet} {item_constraint}{newline}...
75 |     ```
76 | 
77 |     Note that this repeats indefinitely.
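78 | 
79 |     For example, with bullet=b"-" and the default item constraint, it accepts
80 |     output shaped like:
81 | 
82 |         - first item
83 |         - second item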
78 | """ 79 | 80 | if item_constraint is None: 81 | item_constraint = non_newline(tree) 82 | 83 | return RepeatNode( 84 | tree, 85 | SequenceNode( 86 | tree, 87 | [ 88 | CharNode(tree, bullet), 89 | CharNode(tree, b" "), 90 | item_constraint, 91 | CharNode(tree, b"\n"), 92 | ], 93 | ), 94 | ) 95 | 96 | 97 | @cache 98 | def finite_list(tree: BNFTree, bullet: bytes, item_count: int, item_constraint: Node = None) -> Node: 99 | if item_constraint is None: 100 | item_constraint = non_newline(tree) 101 | single_item = SequenceNode( 102 | tree, 103 | [ 104 | CharNode(tree, bullet), 105 | CharNode(tree, b" "), 106 | item_constraint, 107 | CharNode(tree, b"\n"), 108 | ], 109 | ) 110 | return SequenceNode(tree, [single_item] * item_count) 111 | 112 | 113 | @cache 114 | def numbered_list(tree: BNFTree, item_count: int, item_constraint: Node = None) -> Node: 115 | if item_constraint is None: 116 | item_constraint = non_newline(tree) 117 | 118 | number_items = [] 119 | for i in range(1, item_count + 1): 120 | number_items.append( 121 | SequenceNode( 122 | tree, 123 | [ 124 | CharNode(tree, str(i).encode()), 125 | CharNode(tree, b"."), 126 | CharNode(tree, b" "), 127 | item_constraint, 128 | CharNode(tree, b"\n"), 129 | ], 130 | ) 131 | ) 132 | 133 | return SequenceNode(tree, number_items) 134 | 135 | 136 | def named_list(tree: BNFTree, bullet: bytes | None, nodes: dict[str, Node | None]): 137 | node_sequence = [] 138 | for key, value in nodes.items(): 139 | node_sequence.append(literal(tree, (bullet + b" " if bullet is not None else b"") + key.encode() + b": ")) 140 | if value is None: 141 | value = non_newline(tree) 142 | node_sequence.append(value) 143 | node_sequence.append(CharNode(tree, b"\n")) 144 | 145 | return SequenceNode(tree, node_sequence) 146 | 147 | 148 | @cache 149 | def choices(tree, *args: str): 150 | args: list[list[bytes]] = [[bytes([y]) for y in x.encode()] for x in args] 151 | # make trie of all strings 152 | trie = {} 153 | for arg in args: 154 | current = trie 155 | for char in arg: 156 | if char not in current: 157 | current[char] = {} 158 | current = current[char] 159 | 160 | def trie_to_node(trie): 161 | if trie == {}: 162 | raise ValueError("Empty trie") 163 | return OrNode( 164 | tree, 165 | [ 166 | SequenceNode( 167 | tree, 168 | [CharNode(tree, x), trie_to_node(trie[x])], 169 | ) 170 | if trie[x] != {} 171 | else CharNode(tree, x) 172 | for x in trie 173 | ], 174 | ) 175 | 176 | return trie_to_node(trie) 177 | 178 | 179 | def list_line(tree, item_constraint: Node = None): 180 | if item_constraint is None: 181 | item_constraint = non_newline(tree) 182 | return SequenceNode( 183 | tree, 184 | [ 185 | item_constraint, 186 | RepeatNode( 187 | tree, 188 | SequenceNode( 189 | tree, 190 | [ 191 | CharNode(tree, b","), 192 | CharNode(tree, b" "), 193 | item_constraint, 194 | ], 195 | ), 196 | ), 197 | ], 198 | ) 199 | -------------------------------------------------------------------------------- /bnf_complex.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | from rwkv_contrib.bnf import BNFTree, Node, Action 3 | 4 | 5 | class SequenceNode(Node): 6 | nodes: list[Node] 7 | 8 | def __init__(self, bnf_tree: BNFTree, nodes: list[Node]) -> None: 9 | self.nodes = nodes 10 | super().__init__(bnf_tree) 11 | 12 | def add_to_stack(self, callstack: list[tuple[int, int]]) -> None: 13 | callstack.append((self.bnf_index, 0)) 14 | self.bnf_tree[self.get_sub(None)].add_to_stack(callstack) 15 | 16 | def accept_character(self, 
char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 17 | return True 18 | 19 | def handle_match(self, callstack: list[tuple[int, int]]) -> Action: 20 | # last item in the callstack is always current sequence node 21 | cur_idx, cur_sub = callstack[-1] 22 | next_sub = cur_sub + 1 23 | next_sub_idx = self.get_sub(next_sub) 24 | if next_sub_idx is None: 25 | # we are at the end of the sequence node, so we need to pop current node 26 | # and leave the character for parent node 27 | callstack.pop() 28 | else: 29 | # we are not at the end of the sequence node, so we step to the next subnode 30 | # and hand the character to it 31 | callstack[-1] = (cur_idx, next_sub) 32 | self.bnf_tree[next_sub_idx].add_to_stack(callstack) 33 | # in both cases we need to re-evaluate the current node 34 | return Action.RE_EVAL 35 | 36 | def handle_fail(self, callstack: list[tuple[int, int]]) -> Action: 37 | callstack.pop() 38 | return Action.FAIL 39 | 40 | def get_sub(self, idx: int | None) -> int | None: 41 | if idx is None: 42 | return self.nodes[0].bnf_index 43 | elif idx < len(self.nodes): 44 | return self.nodes[idx].bnf_index 45 | else: 46 | return None 47 | 48 | def matched(self, callstack: list[tuple[int, int]]) -> bool: 49 | cur_idx, cur_sub = callstack[-1] 50 | return cur_sub == len(self.nodes) - 1 # we are already at the end of the sequence node 51 | 52 | @cache 53 | def get_logits(self) -> set[int]: 54 | logits = set() 55 | for node in self.nodes: 56 | logits.update(node.get_logits()) 57 | return logits 58 | 59 | 60 | class OrNode(SequenceNode): 61 | def handle_match(self, callstack: list[tuple[int, int]]) -> Action: 62 | callstack.pop() # pop current node as we are done with it 63 | return Action.RE_EVAL 64 | 65 | def handle_fail(self, callstack: list[tuple[int, int]]) -> Action: 66 | cur_idx, cur_sub = callstack[-1] 67 | next_sub_idx = self.get_sub(cur_sub + 1) 68 | if next_sub_idx is None: # we tried all subnodes and failed, so we need to pop current node and fail 69 | callstack.pop() 70 | return Action.FAIL 71 | else: 72 | # we are not at the end of the sequence node, so we step to the next subnode to try it 73 | callstack[-1] = (cur_idx, cur_sub + 1) 74 | self.bnf_tree[next_sub_idx].add_to_stack(callstack) 75 | return Action.RE_EVAL 76 | 77 | def matched(self, callstack: list[tuple[int, int]]) -> bool: 78 | return True # if the node is on top of the stack, it is matched since the matched subnode is popped 79 | 80 | @cache 81 | def get_logits(self) -> set[int]: 82 | logits = set() 83 | for node in self.nodes: 84 | logits |= node.get_logits() 85 | return logits 86 | 87 | 88 | class PopNode(Node): 89 | depth: int 90 | node: Node 91 | 92 | def __init__(self, bnf_tree: BNFTree, node: Node, depth=1) -> None: 93 | self.depth = depth 94 | self.node = node 95 | super().__init__(bnf_tree) 96 | 97 | def add_to_stack(self, callstack: list[tuple[int, int]]) -> None: 98 | callstack.append((self.bnf_index, 0)) 99 | self.node.add_to_stack(callstack) 100 | 101 | def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 102 | return True 103 | 104 | def handle_match(self, callstack: list[tuple[int, int]]) -> Action: 105 | callstack.pop() 106 | depth = self.depth 107 | while depth > 0: 108 | callstack.pop() 109 | depth -= 1 110 | return Action.RE_EVAL 111 | 112 | def handle_fail(self, callstack: list[tuple[int, int]]) -> Action: 113 | callstack.pop() 114 | return Action.FAIL 115 | 116 | def matched(self, callstack: list[tuple[int, int]]) -> bool: 117 | depth = self.depth 118 
|         while depth > 0:
119 |             callstack.pop()
120 |             depth -= 1
121 |         return True  # a pop node unwinds through its parent frames and counts as matched during deflation
122 | 
123 |     @cache
124 |     def get_logits(self) -> set[int]:
125 |         return self.node.get_logits()
126 | 
127 | 
128 | class RepeatNode(Node):
129 |     node: Node
130 | 
131 |     def __init__(self, bnf_tree: BNFTree, node: Node) -> None:
132 |         self.node = node
133 |         super().__init__(bnf_tree)
134 | 
135 |     def add_to_stack(self, callstack: list[tuple[int, int]]) -> None:
136 |         callstack.append((self.bnf_index, 0))
137 |         self.node.add_to_stack(callstack)
138 | 
139 |     def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool:
140 |         return True
141 | 
142 |     def handle_match(self, callstack: list[tuple[int, int]]) -> Action:
143 |         self.node.add_to_stack(callstack)
144 |         return Action.RE_EVAL
145 | 
146 |     def handle_fail(self, callstack: list[tuple[int, int]]) -> Action:
147 |         callstack.pop()
148 |         return Action.RE_EVAL
149 | 
150 |     def matched(self, callstack: list[tuple[int, int]]) -> bool:
151 |         return False  # a repeat node still on the stack is mid-iteration, so it is never considered matched
152 | 
153 |     @cache
154 |     def get_logits(self) -> set[int]:
155 |         return self.node.get_logits()
156 | 
157 | 
158 | class OptionalNode(Node):
159 |     node: Node
160 | 
161 |     def __init__(self, bnf_tree: BNFTree, node: Node) -> None:
162 |         self.node = node
163 |         super().__init__(bnf_tree)
164 | 
165 |     def add_to_stack(self, callstack: list[tuple[int, int]]) -> None:
166 |         callstack.append((self.bnf_index, 0))
167 |         self.node.add_to_stack(callstack)
168 | 
169 |     def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool:
170 |         return True
171 | 
172 |     def handle_match(self, callstack: list[tuple[int, int]]) -> Action:
173 |         callstack.pop()
174 |         return Action.RE_EVAL
175 | 
176 |     def handle_fail(self, callstack: list[tuple[int, int]]) -> Action:
177 |         callstack.pop()
178 |         return Action.RE_EVAL
179 | 
180 |     def matched(self, callstack: list[tuple[int, int]]) -> bool:
181 |         return True  # an optional node is always considered matched
182 | 
183 |     @cache
184 |     def get_logits(self) -> set[int]:
185 |         return self.node.get_logits()
--------------------------------------------------------------------------------
/json_schema.py:
--------------------------------------------------------------------------------
1 | from functools import cache, lru_cache
2 | from typing import Any
3 | import ujson
4 | 
5 | from rwkv_contrib.bnf import BNFTree, Node
6 | from rwkv_contrib.bnf_complex import OptionalNode, SequenceNode, OrNode, RepeatNode, PopNode
7 | from rwkv_contrib.bnf_simple import CharNode, DigitNode, NotCharNode, WildcardNode
8 | 
9 | 
10 | @cache
11 | def json_string(tree: BNFTree):
12 |     return SequenceNode(
13 |         tree,
14 |         [
15 |             CharNode(tree, b'"'),
16 |             RepeatNode(
17 |                 tree,
18 |                 OrNode(
19 |                     tree,
20 |                     [
21 |                         SequenceNode(tree, [CharNode(tree, b"\\"), NotCharNode(tree, b"\n")]),
22 |                         PopNode(tree, CharNode(tree, b'"'), depth=2),
23 |                         NotCharNode(tree, b'\n"'),
24 |                     ],
25 |                 ),
26 |             ),
27 |         ],
28 |     )
29 | 
30 | 
31 | @cache
32 | def literal(tree: BNFTree, literal: bytes):
33 |     literal_parsed = [bytes([x]) for x in literal]
34 |     return SequenceNode(tree, [CharNode(tree, x) for x in literal_parsed])
35 | 
36 | 
37 | @cache
38 | def json_string_literal(tree: BNFTree, literal_string: bytes):
39 |     jsonified = ujson.dumps(literal_string, reject_bytes=False).encode()
40 |     return literal(tree, jsonified)
41 | 
42 | 
43 | @cache
44 | def json_string_enum(tree, *args: str):
45 |     args = [ujson.dumps(x) for x in args]
46 |     args: list[list[bytes]] = [[bytes([y]) for y in x.encode()] for x in args]
47 |     # make trie of all strings
48 |     trie = {}
49 |     for arg in args:
50 |         current = trie
51 |         for char in arg:
52 |             if char not in current:
53 |                 current[char] = {}
54 |             current = current[char]
55 | 
56 |     def trie_to_node(trie):
57 |         if trie == {}:
58 |             raise ValueError("Empty trie")
59 |         return OrNode(
60 |             tree,
61 |             [
62 |                 SequenceNode(
63 |                     tree,
64 |                     [CharNode(tree, x), trie_to_node(trie[x])],
65 |                 )
66 |                 if trie[x] != {}
67 |                 else CharNode(tree, x)
68 |                 for x in trie
69 |             ],
70 |         )
71 | 
72 |     return trie_to_node(trie)
73 | 
74 | 
75 | @cache
76 | def json_boolean(tree: BNFTree):
77 |     return OrNode(tree, [literal(tree, b"true"), literal(tree, b"false")])
78 | 
79 | 
80 | @cache
81 | def json_number(tree):
82 |     return SequenceNode(
83 |         tree,
84 |         [
85 |             OptionalNode(
86 |                 tree,
87 |                 CharNode(tree, b"-"),
88 |             ),
89 |             RepeatNode(
90 |                 tree,
91 |                 OrNode(
92 |                     tree,
93 |                     [
94 |                         DigitNode(tree),
95 |                         PopNode(
96 |                             tree,
97 |                             CharNode(tree, b"."),
98 |                             depth=2,
99 |                         ),
100 |                     ],
101 |                 ),
102 |             ),
103 |             RepeatNode(
104 |                 tree,
105 |                 DigitNode(tree),
106 |             ),
107 |         ],
108 |     )
109 | 
110 | 
111 | @cache
112 | def json_date(tree):
113 |     digit_match = DigitNode(tree)
114 |     return SequenceNode(
115 |         tree,
116 |         [
117 |             CharNode(tree, b'"'),
118 |             digit_match,
119 |             digit_match,
120 |             digit_match,
121 |             digit_match,
122 |             CharNode(tree, b"-"),
123 |             digit_match,
124 |             digit_match,
125 |             CharNode(tree, b"-"),
126 |             digit_match,
127 |             digit_match,
128 |             CharNode(tree, b'"'),
129 |         ],
130 |     )
131 | 
132 | 
133 | @cache
134 | def json_time(tree):
135 |     digit_match = DigitNode(tree)
136 |     return SequenceNode(
137 |         tree,
138 |         [
139 |             CharNode(tree, b'"'),
140 |             digit_match,
141 |             digit_match,
142 |             CharNode(tree, b":"),
143 |             digit_match,
144 |             digit_match,
145 |             CharNode(tree, b":"),
146 |             digit_match,
147 |             digit_match,
148 |             CharNode(tree, b'"'),
149 |         ],
150 |     )
151 | 
152 | 
153 | @cache
154 | def json_null(tree):
155 |     return literal(tree, b"null")
156 | 
157 | 
158 | @cache
159 | def json_nullable(tree, node: Node):
160 |     return OrNode(tree, [json_null(tree), node])
161 | 
162 | 
163 | @cache
164 | def json_array(tree, node: Node):
165 |     return SequenceNode(
166 |         tree,
167 |         [
168 |             CharNode(tree, b"["),
169 |             OrNode(
170 |                 tree,
171 |                 [
172 |                     CharNode(tree, b"]"),
173 |                     SequenceNode(
174 |                         tree,
175 |                         [
176 |                             node,
177 |                             RepeatNode(
178 |                                 tree,
179 |                                 SequenceNode(
180 |                                     tree,
181 |                                     [
182 |                                         literal(tree, b", "),
183 |                                         node,
184 |                                     ],
185 |                                 ),
186 |                             ),
187 |                             CharNode(tree, b"]"),
188 |                         ],
189 |                     ),
190 |                 ],
191 |             ),
192 |         ],
193 |     )
194 | 
195 | 
196 | def json_object(tree, nodes: dict[str, Node]):
197 |     node_sequence = []
198 |     for key, value in nodes.items():
199 |         node_sequence.append(json_string_literal(tree, key.encode()))
200 |         node_sequence.append(literal(tree, b": "))
201 |         node_sequence.append(value)
202 |         node_sequence.append(literal(tree, b", "))
203 |     node_sequence.pop()  # drop the trailing ", "
204 | 
205 |     return SequenceNode(
206 |         tree,
207 |         [
208 |             CharNode(tree, b"{"),
209 |             SequenceNode(
210 |                 tree,
211 |                 node_sequence,
212 |             ),
213 |             CharNode(tree, b"}"),
b"}"), 233 | ], 234 | ) 235 | 236 | 237 | def load_from_schema(tree: BNFTree, json: dict[str, Any]): 238 | element_type = json["type"] 239 | if element_type == "string": 240 | return json_string(tree) 241 | elif element_type == "number": 242 | return json_number(tree) 243 | elif element_type == "boolean": 244 | return json_boolean(tree) 245 | elif element_type == "array": 246 | return json_array(tree, load_from_schema(tree, json["items"])) 247 | elif element_type == "object": 248 | return json_object( 249 | tree, 250 | {key: load_from_schema(tree, value) for key, value in json["properties"].items()}, 251 | ) 252 | elif element_type == "enum": 253 | return json_string_enum(tree, *json["values"]) 254 | elif element_type == "date": 255 | return json_date(tree) 256 | elif element_type == "time": 257 | return json_time(tree) 258 | elif element_type == "nullable": 259 | return json_nullable(tree, load_from_schema(tree, json["value"])) 260 | -------------------------------------------------------------------------------- /bnf.py: -------------------------------------------------------------------------------- 1 | """ 2 | bnf.py 3 | 4 | This module contains the classes and functions for parsing BNF grammars for RWKV-LM. 5 | """ 6 | from abc import abstractmethod 7 | from collections import defaultdict 8 | from copy import deepcopy 9 | from enum import Enum 10 | from functools import cache 11 | from typing import Any, Optional 12 | import numpy as np 13 | from pytrie import SortedStringTrie 14 | 15 | from rwkv.rwkv_tokenizer import TRIE_TOKENIZER 16 | import torch 17 | 18 | base_tokenizer: TRIE_TOKENIZER = None 19 | token_size = 65536 20 | impossible = -1e9 21 | token_table: dict[int, bytes] = {} 22 | tensor_dtype = torch.float16 23 | 24 | 25 | def setup_tokens(): 26 | for i in range(token_size): 27 | try: 28 | token_table[i] = base_tokenizer.decodeBytes([i]) 29 | except: 30 | pass 31 | 32 | 33 | class BNFException(Exception): 34 | ... 35 | 36 | 37 | class Action(Enum): 38 | """ 39 | Represents an action in the automata. 40 | """ 41 | 42 | RE_EVAL = 0 # Re-evaluate the current node. 43 | CONSUME = 1 # Consumes current character. 44 | FAIL = 2 # Tell the parent node that it failed. 45 | 46 | 47 | class BNFTree: 48 | nodes: list["Node"] 49 | 50 | logits_cache: dict[tuple[tuple[int, int]], np.ndarray[Any, np.dtype[bool]]] 51 | 52 | def __init__(self) -> None: 53 | self.nodes = [] 54 | self.logits_cache = {} 55 | 56 | def __getitem__(self, idx: int) -> "Node": 57 | return self.nodes[idx] 58 | 59 | def __len__(self) -> int: 60 | return len(self.nodes) 61 | 62 | def register(self, node: "Node") -> int: 63 | """ 64 | Registers a node to the tree. Returns the index of the node. 
65 | """ 66 | self.nodes.append(node) 67 | return len(self.nodes) - 1 68 | 69 | def eval_once(self, char: bytes | None, callstack: list[tuple[int, int]], propagate_fail=False) -> Action: 70 | cur_idx, cur_sub = callstack[-1] 71 | node = self[cur_idx] 72 | if node.accept_character(char, callstack) and not propagate_fail: 73 | action = node.handle_match(callstack) 74 | else: 75 | action = node.handle_fail(callstack) 76 | return action 77 | 78 | def eval(self, char: bytes | None, callstack: list[tuple[int, int]]) -> None: 79 | action = self.eval_once(char, callstack) 80 | if callstack == []: 81 | raise BNFException("Callstack is empty.") 82 | while action != Action.CONSUME: 83 | action = self.eval_once(char, callstack, propagate_fail=action == Action.FAIL) 84 | if callstack == []: 85 | raise BNFException("Callstack is empty.") 86 | 87 | def eval_bytes(self, text: bytes, callstack: list[tuple[int, int]]) -> None: 88 | for char in text: 89 | self.eval(bytes([char]), callstack) 90 | 91 | def eval_token(self, token: int, callstack: list[tuple[int, int]]) -> None: 92 | token_bytes = token_table[token] 93 | self.eval_bytes(token_bytes, callstack) 94 | 95 | def deflate(self, callstack: list[tuple[int, int]]) -> bool: 96 | """ 97 | Deflates the callstack. Returns True if the callstack is empty (all resolved). 98 | 99 | It will remove all nodes that are resolved from the callstack. 100 | """ 101 | if not callstack: 102 | return True 103 | while callstack: 104 | cur_idx, cur_sub = callstack[-1] 105 | node = self[cur_idx] 106 | if not node.matched(callstack): 107 | return False 108 | else: 109 | callstack.pop() # pop current node as we are done with it 110 | return True 111 | 112 | def get_logits(self, callstack: list[tuple[int, int]]) -> np.ndarray[Any, np.dtype[bool]]: 113 | """ 114 | Get the logits for the current callstack. 115 | """ 116 | 117 | callstack = list(callstack) 118 | if tuple(callstack) not in self.logits_cache: 119 | logits_filter = np.zeros(token_size, dtype=bool) 120 | for logit, logit_bytes in token_table.items(): 121 | try: 122 | cur_callstack = deepcopy(callstack) 123 | self.eval_bytes(logit_bytes, cur_callstack) 124 | logits_filter[logit] = True 125 | except BNFException as e: 126 | pass 127 | self.logits_cache[tuple(callstack)] = logits_filter 128 | return self.logits_cache[tuple(callstack)] 129 | 130 | def dump_logits(self, path) -> None: 131 | """ 132 | Dump the logits cache to a file. 133 | """ 134 | import numpy as np 135 | 136 | ks: list[np.ndarray] = [] 137 | vs = [] 138 | for k, v in self.logits_cache.items(): 139 | k = np.array(k).flatten() 140 | ks.append(k) 141 | vs.append(v) 142 | 143 | max_len = max([len(k) for k in ks]) 144 | for k in ks: 145 | original_len = len(k) 146 | k.resize(max_len, refcheck=False) 147 | k[original_len:] = -1 148 | np.savez_compressed(path, ks=ks, vs=vs) 149 | 150 | def load_logits(self, path) -> None: 151 | """ 152 | Load the logits cache from a file. 153 | """ 154 | import numpy as np 155 | 156 | data = np.load(path) 157 | ks = data["ks"] 158 | vs = data["vs"] 159 | 160 | for k, v in zip(ks, vs): 161 | kt = list() 162 | for i in range(0, len(k), 2): 163 | if k[i] == -1: 164 | break 165 | kt.append((k[i], k[i + 1])) 166 | self.logits_cache[tuple(kt)] = v 167 | 168 | @cache 169 | def get_tensor(self, device: str, callstack: tuple[tuple[int, int]]) -> torch.Tensor: 170 | """ 171 | Get the filter tensor for the current callstack. 
172 | """ 173 | logits = self.get_logits(callstack) 174 | # set false to impossible, true to 0 175 | logits = np.where(logits, 0, impossible) 176 | return torch.tensor(logits, dtype=tensor_dtype, device=device).float() 177 | 178 | def filter_tensor(self, tensor: torch.Tensor, callstack: list[tuple[int, int]]) -> torch.Tensor: 179 | """ 180 | Filter the tensor by the callstack. 181 | """ 182 | filter_tensor = self.get_tensor(tensor.device, tuple(callstack)) 183 | return tensor + filter_tensor 184 | 185 | 186 | class Node: 187 | """ 188 | Represents a node in the automata. 189 | """ 190 | 191 | bnf_tree: BNFTree 192 | bnf_index: int 193 | 194 | def __init__(self, bnf_tree: BNFTree) -> None: 195 | self.bnf_tree = bnf_tree 196 | self.bnf_index = bnf_tree.register(self) 197 | self.get_logits() # cache the logits 198 | 199 | @abstractmethod 200 | def add_to_stack(self, callstack: list[tuple[int, int]]) -> None: 201 | """ 202 | Adds the node to the stack. 203 | """ 204 | 205 | @abstractmethod 206 | def accept_character(self, char: bytes | None, callstack: list[tuple[int, int]]) -> bool: 207 | """ 208 | Accepts a character. 209 | """ 210 | 211 | @abstractmethod 212 | def handle_match(self, callstack: list[tuple[int, int]]) -> Action: 213 | """ 214 | Get the action to take after the character is accepted. 215 | """ 216 | 217 | def handle_fail(self, callstack: list[tuple[int, int]]) -> Action: 218 | """ 219 | Get the action to take after the character is rejected. 220 | """ 221 | 222 | @abstractmethod 223 | def get_sub(self, idx: Optional[int]) -> Optional[int]: 224 | """ 225 | Returns the subnode at the given sub index. 226 | 227 | If the sub index is None, returns the first subnode. 228 | """ 229 | 230 | @abstractmethod 231 | def matched(self, callstack: list[tuple[int, int]]) -> bool: 232 | """ 233 | Check if the node is matched. 234 | """ 235 | 236 | @abstractmethod 237 | def get_logits(self) -> set[int]: 238 | """ 239 | Get the logits for the current node. 240 | """ 241 | 242 | @property 243 | @cache 244 | def complex(self) -> bool: 245 | return self.get_sub(None) is not None 246 | -------------------------------------------------------------------------------- /grammar_pipeline.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import TYPE_CHECKING, Any, Generator, Optional 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor 7 | from os import path 8 | from rwkv_contrib.bnf import BNFTree, Node 9 | from rwkv_contrib.penalty import GlobalPenalty, Penalty 10 | from rwkv_contrib.pipeline import GenerationArgs 11 | from rwkv_contrib.tokenizer import RWKVTokenizer, Tokenizer 12 | 13 | import torch.nn.functional as F 14 | 15 | if TYPE_CHECKING: 16 | from rwkv.model import RWKV 17 | 18 | 19 | class BNFPipeline: 20 | """ 21 | A stateless pipeline for RWKV. 22 | 23 | Output is restricted by a BNF grammar. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | model: "RWKV", 29 | tree: BNFTree, 30 | initial_node: Node, 31 | tokenizer: Tokenizer[str] = RWKVTokenizer(), 32 | penalty: Penalty = None, 33 | default_args: GenerationArgs = None, 34 | logits_cache: str = None, 35 | ) -> None: 36 | penalty = penalty or GlobalPenalty() 37 | default_args = default_args or GenerationArgs() 38 | 39 | self.model = model 40 | self.tokenizer = tokenizer 41 | self.penalty = penalty 42 | self.default_args = default_args 43 | 44 | self.encode = tokenizer.encode 45 | self.decode = tokenizer.decode 46 | 47 | self.tree = tree 48 | self.initial_node = initial_node 49 | self.callstack: list[tuple[int, int]] = [] 50 | 51 | self.logits_cache = logits_cache 52 | if logits_cache is not None and path.exists(logits_cache): 53 | self.tree.load_logits(logits_cache) 54 | 55 | def sample_logits(self, logits: Tensor, args: GenerationArgs) -> Tensor: 56 | """ 57 | Sample logits. 58 | """ 59 | args = args or self.default_args 60 | 61 | probs = F.softmax(logits, dim=-1) 62 | top_k = args.top_k 63 | if probs.device == torch.device("cpu"): 64 | probs = probs.numpy() 65 | sorted_ids = np.argsort(probs) 66 | sorted_probs = probs[sorted_ids][::-1] 67 | cumulative_probs = np.cumsum(sorted_probs) 68 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > args.top_p)]) 69 | probs[probs < cutoff] = 0 70 | if top_k < len(probs) and top_k > 0: 71 | probs[sorted_ids[:-top_k]] = 0 72 | if args.temperature != 1.0: 73 | probs = probs ** (1.0 / args.temperature) 74 | probs = probs / np.sum(probs) 75 | out = np.random.choice(a=len(probs), p=probs) 76 | token = int(out) 77 | else: 78 | sorted_ids = torch.argsort(probs) 79 | sorted_probs = probs[sorted_ids] 80 | sorted_probs = torch.flip(sorted_probs, dims=(0,)) 81 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 82 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > args.top_p)]) 83 | probs[probs < cutoff] = 0 84 | if top_k < len(probs) and top_k > 0: 85 | probs[sorted_ids[:-top_k]] = 0 86 | if args.temperature != 1.0: 87 | probs = probs ** (1.0 / args.temperature) 88 | out = torch.multinomial(probs, num_samples=1)[0] 89 | token = int(out) 90 | 91 | return token 92 | 93 | def dump_logits(self, logits_cache: str = None) -> None: 94 | logits_cache = logits_cache or self.logits_cache 95 | self.tree.dump_logits(logits_cache) 96 | 97 | def infer( 98 | self, 99 | tokens: list[int], 100 | *, 101 | state: Any = None, 102 | args: GenerationArgs = None, 103 | penalty: Penalty = None, 104 | update_tokens_penalty: bool = True, 105 | intialize_callstack: bool = True, 106 | ) -> tuple[Optional[int], Any]: 107 | """ 108 | Infer the next token from a list of tokens. 109 | 110 | If the input is a list, and the first element is an integer, it is assumed to be a list of tokens. 111 | 112 | None is returned if stop tokens are generated. 
113 | """ 114 | 115 | args = args or self.default_args 116 | penalty = penalty or self.penalty 117 | 118 | if intialize_callstack: 119 | self.callstack = [] 120 | self.initial_node.add_to_stack(self.callstack) 121 | 122 | if update_tokens_penalty: 123 | for token in tokens: 124 | penalty.update(token, args) 125 | 126 | for i in range(0, len(tokens), args.chunk_len): 127 | chunk = tokens[i : i + args.chunk_len] 128 | out, state = self.model.forward(chunk, state=state) 129 | 130 | for n in args.token_ban: 131 | out[n] = -float("inf") 132 | 133 | out = penalty.transform(out, args) 134 | out = self.tree.filter_tensor(out, self.callstack) 135 | 136 | token = self.sample_logits(out, args=args) 137 | self.tree.eval_token(token, self.callstack) 138 | if token in args.token_stop: 139 | return None, state 140 | 141 | return token, state 142 | 143 | def generate(self, ctx: str, generation_length: int = 100, *, state=None, args: GenerationArgs = None, clear_penalty: bool = True) -> Generator[str, None, None]: 144 | self.callstack = [] 145 | self.initial_node.add_to_stack(self.callstack) 146 | 147 | if args is None: 148 | args = self.default_args 149 | 150 | if clear_penalty: 151 | self.penalty.clear() 152 | 153 | tokens_tmp = [] 154 | token, state = self.infer(self.encode(ctx), state=state, args=args, intialize_callstack=False) 155 | while token is not None and generation_length > 0: 156 | generation_length -= 1 157 | tokens_tmp.append(token) 158 | tmp = self.decode(tokens_tmp) 159 | if self.tokenizer.validate(tmp): 160 | yield tmp 161 | tokens_tmp = [] 162 | if self.tree.deflate(self.callstack): 163 | break 164 | token, state = self.infer([token], state=state, args=args, intialize_callstack=False) 165 | 166 | 167 | class StatefulBNFPipeline(BNFPipeline): 168 | state: Any 169 | 170 | def __init__( 171 | self, 172 | model: "RWKV", 173 | tree: BNFTree, 174 | initial_node: Node, 175 | tokenizer: Tokenizer[str] = RWKVTokenizer(), 176 | penalty: Penalty = None, 177 | default_args: GenerationArgs = None, 178 | logits_cache: str = None, 179 | init_state: Any = None, 180 | init_prompt: str = None, 181 | ) -> None: 182 | super().__init__(model, tree, initial_node, tokenizer, penalty, default_args, logits_cache) 183 | 184 | self.state = init_state 185 | if init_prompt is not None: 186 | self.push(init_prompt) 187 | 188 | def infer(self, tokens: list[int], *, state: Any = None, args: GenerationArgs = None, penalty: Penalty = None) -> tuple[int | None, Any]: 189 | if state is None: 190 | state = self.state 191 | token, self.state = super().infer(tokens, state=state, args=args, penalty=penalty) 192 | return token, self.state 193 | return super().infer(tokens, state=state, args=args, penalty=penalty) 194 | 195 | def push(self, ctx: str): 196 | tokens = self.encode(ctx) 197 | _, self.state = self.infer(tokens, state=self.state, args=self.default_args, penalty=self.penalty) 198 | 199 | 200 | class RecallableBNFPipeline(StatefulBNFPipeline): 201 | history: deque[Any] 202 | 203 | def __init__( 204 | self, 205 | model: "RWKV", 206 | tree: BNFTree, 207 | initial_node: Node, 208 | tokenizer: Tokenizer[str] = RWKVTokenizer(), 209 | penalty: Penalty = None, 210 | default_args: GenerationArgs = None, 211 | logits_cache: str = None, 212 | max_history: int = 10, 213 | init_state: Any = None, 214 | init_prompt: str = None, 215 | ) -> None: 216 | super().__init__(model, tree, initial_node, tokenizer, penalty, default_args, logits_cache, init_state, init_prompt) 217 | self.history = deque(maxlen=max_history) 218 | 
229 |         self.history.append(self.state)
230 | 
231 |     def recall(self, depth=1) -> None:
232 |         for _ in range(depth):
233 |             if len(self.history) == 0:
234 |                 raise IndexError("Cannot recall empty history")
235 |             self.state = self.history.pop()
236 | 
237 |     def push(self, ctx: str):
238 |         self.history.append(self.state)
239 |         return super().push(ctx)
240 | 
241 |     def generate(self, ctx: str, generation_length: int = 100, *, state=None, args: GenerationArgs = None, clear_penalty: bool = True) -> Generator[str, None, None]:
242 |         self.history.append(self.state)
243 |         return super().generate(ctx, generation_length, state=state, args=args, clear_penalty=clear_penalty)
--------------------------------------------------------------------------------
/pipeline.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from copy import deepcopy
3 | from typing import Any, Deque, Generator, Generic, Optional, TypeVar, List, TYPE_CHECKING
4 | from dataclasses import dataclass, field
5 | import torch
6 | from torch import Tensor
7 | import numpy as np
8 | from .penalty import GlobalPenalty, Penalty, SlidingPenalty
9 | from .tokenizer import RWKVTokenizer, Tokenizer
10 | import torch.nn.functional as F
11 | 
12 | if TYPE_CHECKING:
13 |     from rwkv.model import RWKV
14 | 
15 | 
16 | @dataclass
17 | class GenerationArgs:
18 |     """
19 |     Data holder for generation arguments.
20 |     """
21 | 
22 |     temperature: float = 1.0
23 |     top_p: float = 0.85
24 |     top_k: int = 0
25 |     alpha_frequency: float = 0.2
26 |     alpha_presence: float = 0.2
27 |     alpha_decay: float = 0.9
28 |     token_ban: List[int] = field(default_factory=list)
29 |     token_stop: List[int] = field(default_factory=list)
30 |     chunk_len: int = 256
31 |     linear_mode: bool = False
32 | 
33 | 
34 | T = TypeVar("T")
35 | 
36 | 
37 | class Pipeline(Generic[T]):
38 | 
39 |     """
40 |     A stateless pipeline for RWKV.
41 | 
42 |     GlobalPenalty is used by default.
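43 | 
44 |     Hypothetical usage (model construction omitted; see example.py):
45 | 
46 |         pipeline = Pipeline(model)
47 |         for part in pipeline.generate("prompt", 100):
48 |             print(part, end="")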
43 | """ 44 | 45 | def __init__( 46 | self, 47 | model: "RWKV", 48 | tokenizer: Tokenizer[T] = RWKVTokenizer(), 49 | penalty: Penalty = None, 50 | default_args: GenerationArgs = None, 51 | ) -> None: 52 | penalty = penalty or GlobalPenalty() 53 | default_args = default_args or GenerationArgs() 54 | 55 | self.model = model 56 | self.tokenizer = tokenizer 57 | self.penalty = penalty 58 | self.default_args = default_args 59 | 60 | self.encode = tokenizer.encode 61 | self.decode = tokenizer.decode 62 | 63 | def sample_logits(self, logits, args: GenerationArgs = None): 64 | args = args or self.default_args 65 | 66 | probs = F.softmax(logits.float(), dim=-1) 67 | top_k = args.top_k 68 | if probs.device == torch.device("cpu"): 69 | probs = probs.numpy() 70 | sorted_ids = np.argsort(probs) 71 | sorted_probs = probs[sorted_ids][::-1] 72 | cumulative_probs = np.cumsum(sorted_probs) 73 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > args.top_p)]) 74 | probs[probs < cutoff] = 0 75 | if top_k < len(probs) and top_k > 0: 76 | probs[sorted_ids[:-top_k]] = 0 77 | if args.temperature != 1.0: 78 | probs = probs ** (1.0 / args.temperature) 79 | probs = probs / np.sum(probs) 80 | out = np.random.choice(a=len(probs), p=probs) 81 | return int(out) 82 | else: 83 | sorted_ids = torch.argsort(probs) 84 | sorted_probs = probs[sorted_ids] 85 | sorted_probs = torch.flip(sorted_probs, dims=(0,)) 86 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 87 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > args.top_p)]) 88 | probs[probs < cutoff] = 0 89 | if top_k < len(probs) and top_k > 0: 90 | probs[sorted_ids[:-top_k]] = 0 91 | if args.temperature != 1.0: 92 | probs = probs ** (1.0 / args.temperature) 93 | out = torch.multinomial(probs, num_samples=1)[0] 94 | return int(out) 95 | 96 | def infer(self, tokens: List[int], *, state: Any = None, args: GenerationArgs = None, penalty: Penalty = None) -> tuple[Optional[int], Any]: 97 | """ 98 | Infer the next token from a list of tokens. 99 | 100 | If the input is a list, and the first element is an integer, it is assumed to be a list of tokens. 101 | 102 | None is returned if stop tokens are generated. 103 | """ 104 | 105 | args = args or self.default_args 106 | penalty = penalty or self.penalty 107 | 108 | for token in tokens: 109 | penalty.update(token, args) 110 | 111 | for i in range(0, len(tokens), args.chunk_len): 112 | chunk = tokens[i : i + args.chunk_len] 113 | out, state = self.model.forward(chunk, state=state) 114 | 115 | for n in args.token_ban: 116 | out[n] = -float("inf") 117 | 118 | out = penalty.transform(out, args) 119 | token = self.sample_logits(out, args=args) 120 | if token in args.token_stop: 121 | return None, state 122 | 123 | return token, state 124 | 125 | def generate(self, ctx: T, generation_length: int = 100, *, state=None, args: GenerationArgs = None, clear_penalty: bool = True) -> Generator[T, None, None]: 126 | if args is None: 127 | args = self.default_args 128 | 129 | if clear_penalty: 130 | self.penalty.clear() 131 | 132 | tokens_tmp = [] 133 | token, state = self.infer(self.encode(ctx), state=state, args=args) 134 | while token is not None and generation_length > 0: 135 | generation_length -= 1 136 | tokens_tmp.append(token) 137 | tmp = self.decode(tokens_tmp) 138 | if self.tokenizer.validate(tmp): 139 | yield tmp 140 | tokens_tmp = [] 141 | token, state = self.infer([token], state=state, args=args) 142 | 143 | 144 | class StatefulPipeline(Generic[T]): 145 | 146 | """ 147 | A stateful pipeline for RWKV. 
154 | 
155 |     The pipeline holds the state that can act as 'memory' for the model.
156 | 
157 |     SlidingPenalty (with its default maxlen of 512) is used by default.
158 |     """
159 | 
160 |     state: List[Tensor]
161 |     last_token: Optional[int]
162 | 
163 |     def __init__(
164 |         self,
165 |         model: "RWKV",
166 |         tokenizer: Tokenizer[T] = RWKVTokenizer(),
167 |         penalty: Penalty = None,
168 |         default_args: GenerationArgs = None,
169 |         initial_state: List[Tensor] = None,
170 |         initial_prompt: T = None,
171 |     ) -> None:
172 |         if initial_prompt is not None and initial_state is not None:
173 |             raise ValueError("Cannot provide both initial_state and initial_prompt")
174 | 
175 |         penalty = penalty or SlidingPenalty()
176 |         default_args = default_args or GenerationArgs()
177 | 
178 |         self.model = model
179 |         self.tokenizer = tokenizer
180 |         self.penalty = penalty
181 |         self.default_args = default_args
182 | 
183 |         self.encode = tokenizer.encode
184 |         self.decode = tokenizer.decode
185 | 
186 |         self.state = initial_state
187 |         self.last_token = None
188 |         if initial_prompt is not None:
189 |             self.push(initial_prompt)
190 | 
191 |     def push(self, ctx: T, args: GenerationArgs = None) -> None:
192 |         """
193 |         Push a context into the state.
194 | 
195 |         The last token is inferred from the context and will be used as the first token for the continuation.
196 |         """
197 |         args = args or self.default_args
198 |         if args.linear_mode:
199 |             tokens = self.encode(ctx)
200 |             for token in tokens:
201 |                 self.last_token, self.state = self.infer([token], args=args)
202 |         else:
203 |             self.last_token, self.state = self.infer(self.encode(ctx), args=args)
204 | 
205 |     def sample_logits(self, logits: Tensor, args: GenerationArgs = None):
206 |         args = args or self.default_args
207 | 
208 |         probs = F.softmax(logits.float(), dim=-1)
209 |         top_k = args.top_k
210 |         if probs.device == torch.device("cpu"):
211 |             probs = probs.numpy()
212 |             sorted_ids = np.argsort(probs)
213 |             sorted_probs = probs[sorted_ids][::-1]
214 |             cumulative_probs = np.cumsum(sorted_probs)
215 |             cutoff = float(sorted_probs[np.argmax(cumulative_probs > args.top_p)])
216 |             probs[probs < cutoff] = 0
217 |             if top_k < len(probs) and top_k > 0:
218 |                 probs[sorted_ids[:-top_k]] = 0
219 |             if args.temperature != 1.0:
220 |                 probs = probs ** (1.0 / args.temperature)
221 |             probs = probs / np.sum(probs)
222 |             out = np.random.choice(a=len(probs), p=probs)
223 |             return int(out)
224 |         else:
225 |             sorted_ids = torch.argsort(probs)
226 |             sorted_probs = probs[sorted_ids]
227 |             sorted_probs = torch.flip(sorted_probs, dims=(0,))
228 |             cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
229 |             cutoff = float(sorted_probs[np.argmax(cumulative_probs > args.top_p)])
230 |             probs[probs < cutoff] = 0
231 |             if top_k < len(probs) and top_k > 0:
232 |                 probs[sorted_ids[:-top_k]] = 0
233 |             if args.temperature != 1.0:
234 |                 probs = probs ** (1.0 / args.temperature)
235 |             out = torch.multinomial(probs, num_samples=1)[0]
236 |             return int(out)
237 | 
238 |     def infer(self, tokens: List[int], args: GenerationArgs = None) -> tuple[Optional[int], Any]:
239 |         """
240 |         Generate exactly one token from the given list of tokens.
241 | 
242 |         `None` is returned if the token is a stop token.
243 | 
244 |         The state is updated after the generation.
245 | 
246 |         The pipeline's internal state is used and updated; there is no separate `state` argument.
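247 | 
248 |         A hypothetical single-step loop (names assumed):
249 | 
250 |             token, _ = pipeline.infer([pipeline.last_token])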
241 | """ 242 | args = args or self.default_args 243 | 244 | for token in tokens: 245 | self.penalty.update(token, args) 246 | 247 | for i in range(0, len(tokens), args.chunk_len): 248 | chunk = tokens[i : i + args.chunk_len] 249 | out, self.state = self.model.forward(chunk, state=self.state) 250 | 251 | for n in args.token_ban: 252 | out[n] = -float("inf") 253 | 254 | out = self.penalty.transform(out, args) 255 | token = self.sample_logits(out, args=args) 256 | if token in args.token_stop: 257 | return None, self.state 258 | 259 | self.last_token = token 260 | return token, self.state 261 | 262 | def _generate(self, ctx: List[int], generation_length: int = 100, *, args: GenerationArgs = None) -> Generator[List[int], None, None]: 263 | args = args or self.default_args 264 | 265 | tokens_tmp = [] 266 | if args.linear_mode: 267 | for ct in ctx: 268 | token, self.state = self.infer([ct], args=args) 269 | else: 270 | token, self.state = self.infer(ctx, args=args) 271 | while token is not None and generation_length > 0: 272 | generation_length -= 1 273 | tokens_tmp.append(token) 274 | tmp = self.decode(tokens_tmp) 275 | if self.tokenizer.validate(tmp): 276 | yield tmp 277 | if token == 0: 278 | break 279 | tokens_tmp = [] 280 | if token in args.token_stop: 281 | break 282 | token, self.state = self.infer([token], args=args) 283 | 284 | def generate(self, ctx: T, generation_length: int = 100, *, args: GenerationArgs = None) -> Generator[T, None, None]: 285 | return self._generate(self.encode(ctx), generation_length=generation_length, args=args) 286 | 287 | def continue_generation(self, generation_length: int = 100, *, args: GenerationArgs = None) -> Generator[T, None, None]: 288 | """ 289 | Continue the generation from the last token. 290 | 291 | The return value is a generator that yields the generated parts of the string. 292 | 293 | The state is updated as generation. 294 | """ 295 | return self._generate([self.last_token], generation_length=generation_length, args=args) 296 | 297 | 298 | class RecallablePipeline(StatefulPipeline[T]): 299 | """ 300 | A stateful pipeline that can recall the last generations. 301 | 302 | However, only generation is supported, infer is not supported as memory would explode. 303 | """ 304 | 305 | history: Deque[List[Tensor]] 306 | history_tokens: Deque[int] 307 | history_penalties: Deque[Penalty] 308 | 309 | def __init__( 310 | self, 311 | model: "RWKV", 312 | tokenizer: Tokenizer = RWKVTokenizer(), 313 | penalty: Penalty = None, 314 | default_args: GenerationArgs = None, 315 | initial_state: List[Tensor] = None, 316 | initial_prompt: T = None, 317 | max_history: int = 16, 318 | ) -> None: 319 | self.history = deque(maxlen=max_history) 320 | self.history_tokens = deque(maxlen=max_history) 321 | self.history_penalties = deque(maxlen=max_history) 322 | super().__init__(model, tokenizer, penalty, default_args, initial_state, initial_prompt) 323 | 324 | def recall(self, times: int = 1) -> None: 325 | """ 326 | Recall the last generation. 
327 | """ 328 | for _ in range(times): 329 | if len(self.history) == 0: 330 | raise IndexError("Cannot recall empty history") 331 | self.state = self.history.pop() 332 | self.last_token = self.history_tokens.pop() 333 | self.penalty = self.history_penalties.pop() 334 | 335 | def push(self, ctx: T) -> None: 336 | if self.state is not None: 337 | self.history.append(deepcopy(self.state)) 338 | self.history_tokens.append(self.last_token) 339 | self.history_penalties.append(self.penalty.copy()) 340 | return super().push(ctx) 341 | 342 | def _generate(self, ctx: List[int], generation_length: int = 100, *, args: GenerationArgs = None) -> Generator[Any, None, None]: 343 | self.history.append(deepcopy(self.state)) 344 | self.history_tokens.append(self.last_token) 345 | self.history_penalties.append(self.penalty.copy()) 346 | return super()._generate(ctx, generation_length, args=args) 347 | --------------------------------------------------------------------------------