├── tests ├── test_parsing │ └── __init__.py ├── test_accept_string │ ├── __init__.py │ ├── test_json_arr.py │ ├── test_unicode.py │ ├── test_decode_utf8.py │ ├── test_accept_unicode_bytes.py │ ├── test_json.py │ ├── test_geo_query.py │ ├── test_overnight.py │ ├── test_smiles.py │ └── test_pddl.py ├── test_accept_token_sequence │ ├── __init__.py │ ├── test_gpt2.py │ ├── test_gemma2.py │ ├── test_llama.py │ ├── test_mistral.py │ ├── test_phi.py │ ├── test_deepseek.py │ ├── test_llama3.py │ ├── test_t5.py │ └── _test_accept_tokens_mixin.py ├── __init__.py ├── json_utils.py └── test_cli │ └── test_model_support.py ├── transformers_cfg ├── cli │ ├── __init__.py │ └── cli_main.py ├── metrics │ └── __init__.py ├── adapters │ ├── __init__.py │ └── llama_cpp_python.py ├── generation │ └── __init__.py ├── tokenization │ ├── __init__.py │ ├── mapping │ │ ├── __init__.py │ │ ├── ByteProxyMapping.py │ │ └── token2byte.py │ ├── SUPPORTED_TOKENIZERS.py │ ├── utils.py │ ├── tokenizer.py │ └── byte_trie.py ├── __init__.py ├── grammar_utils.py ├── logging_config.py ├── utils.py └── utf8_utils.py ├── examples ├── grammars │ ├── balanced_parentheses.ebnf │ ├── animal.ebnf │ ├── arabic.ebnf │ ├── arithmetic.ebnf │ ├── chinese.ebnf │ ├── russian.ebnf │ ├── korean.ebnf │ ├── japanese.ebnf │ ├── PDDL │ │ ├── blocks.ebnf │ │ ├── satellite.ebnf │ │ ├── satellite_typed.ebnf │ │ ├── depot.ebnf │ │ └── depot_typed.ebnf │ ├── emoji.ebnf │ ├── json_minimal.ebnf │ ├── unicode │ │ └── emoji_escape.ebnf │ ├── custom_json_grammars │ │ ├── schemas │ │ │ ├── student.json │ │ │ └── product_catalog.json │ │ ├── grammars │ │ │ ├── product_catalog.ebnf │ │ │ └── student.ebnf │ │ └── README.md │ ├── chess.ebnf │ ├── json.ebnf │ ├── json_arr.ebnf │ ├── overnight.ebnf │ ├── c.ebnf │ ├── cIE.ebnf │ ├── SMILES │ │ ├── generic.ebnf │ │ ├── chain_extenders.ebnf │ │ ├── isocyanates.ebnf │ │ └── acrylates.ebnf │ └── geo_query.ebnf ├── prompts │ └── json.txt ├── benchmarking │ ├── run_generation.sh │ └── time_benchmarking.py ├── accept.py ├── generate_llama_cpp_python.py ├── run_seq2seq_model.py ├── generate_json.py ├── pipeline_json.py ├── metrics │ └── run_constrained_decoding_metric.py ├── demo.sh ├── generate_smiles.py ├── generate_geo_query.py └── generate_pddl.py ├── docs ├── troubleshoot.md ├── assets │ ├── plots │ │ ├── benchmarking_results.png │ │ ├── benchmarking_smoothed.png │ │ └── arithmetic_grammar_viz.png │ └── screenshots │ │ └── vscode_ebnf_syntax_highlight.png ├── grammar&parser.md ├── json_grammar.md ├── contribute.md ├── add_new_model_support.md ├── benchmarking.md ├── supported_models.yaml └── debugging_custom_grammars.md ├── requirements.txt ├── .github └── workflows │ ├── test.yml │ └── pipy_release.yml ├── LICENSE ├── setup.py ├── .pre-commit-config.yaml └── .gitignore /tests/test_parsing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_cfg/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_accept_string/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_cfg/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /transformers_cfg/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /transformers_cfg/generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/mapping/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/grammars/balanced_parentheses.ebnf: -------------------------------------------------------------------------------- 1 | root ::= "(" root ")" | "" 2 | -------------------------------------------------------------------------------- /examples/grammars/animal.ebnf: -------------------------------------------------------------------------------- 1 | root ::= "The animal is a " animal "." 2 | animal ::= "cat" | "fish" 3 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers_cfg.logging_config import setup_logging 2 | 3 | setup_logging() 4 | -------------------------------------------------------------------------------- /docs/troubleshoot.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting 2 | 3 | ## Why doesn't the generation conform to the JSON format?
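A first thing to verify (a sketch based on `examples/generate_json.py` in this repository; the `gpt2` model id is just a placeholder): the grammar processor must actually be passed to `model.generate` via `logits_processor`, otherwise generation is completely unconstrained.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

# load the JSON grammar shipped with the repository
with open("examples/grammars/json.ebnf", "r") as f:
    grammar_str = f.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)

prompt_ids = tokenizer(
    "This is a valid json string for http request:", return_tensors="pt"
)["input_ids"]
output = model.generate(
    prompt_ids,
    max_new_tokens=60,
    # without this argument, the grammar is never applied
    logits_processor=[GrammarConstrainedLogitsProcessor(grammar)],
)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```

Also note that the decoded output still contains the prompt; only the newly generated tokens are constrained by the grammar, so inspect the text after the prompt when checking conformance.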
4 | -------------------------------------------------------------------------------- /transformers_cfg/__init__.py: -------------------------------------------------------------------------------- 1 | from .logging_config import setup_logging 2 | 3 | setup_logging() 4 | 5 | __version__ = "0.2.7" 6 | -------------------------------------------------------------------------------- /docs/assets/plots/benchmarking_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/transformers-CFG/HEAD/docs/assets/plots/benchmarking_results.png -------------------------------------------------------------------------------- /docs/assets/plots/benchmarking_smoothed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/transformers-CFG/HEAD/docs/assets/plots/benchmarking_smoothed.png -------------------------------------------------------------------------------- /docs/assets/plots/arithmetic_grammar_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/transformers-CFG/HEAD/docs/assets/plots/arithmetic_grammar_viz.png -------------------------------------------------------------------------------- /docs/assets/screenshots/vscode_ebnf_syntax_highlight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/transformers-CFG/HEAD/docs/assets/screenshots/vscode_ebnf_syntax_highlight.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | numpy>=1.24.2 3 | transformers>=4.37.2 4 | tokenizers>=0.19.0 5 | termcolor>=2.4.0 6 | sentencepiece>=0.1.99 7 | protobuf>=4.25.2 8 | setuptools>=69.0.3 9 | -------------------------------------------------------------------------------- /examples/grammars/arabic.ebnf: -------------------------------------------------------------------------------- 1 | # A probably incorrect grammar for Arabic text 2 | root ::= ar-char+ ([ \t\n] ar-char+)* 3 | ar-char ::= arabic | punctuation 4 | arabic ::= [٠-٩] 5 | punctuation ::= [،-۔] 6 | -------------------------------------------------------------------------------- /examples/grammars/arithmetic.ebnf: -------------------------------------------------------------------------------- 1 | root ::= (expr "=" ws term "\n")+ 2 | expr ::= term ([-+*/] term)* 3 | term ::= ident | num | "(" ws expr ")" ws 4 | ident ::= [a-z] [a-z0-9_]* ws 5 | num ::= [0-9]+ ws 6 | ws ::= [ \t\n]* 7 | -------------------------------------------------------------------------------- /examples/grammars/chinese.ebnf: -------------------------------------------------------------------------------- 1 | # An overly simplified grammar for Chinese text 2 | root ::= cn-char+ ([ \t\n] cn-char+)* 3 | cn-char ::= chinese | punctuation 4 | chinese ::= [一-鿿] 5 | punctuation ::= [、-〾] | [!?.,;:()"'`] 6 | -------------------------------------------------------------------------------- /examples/grammars/russian.ebnf: -------------------------------------------------------------------------------- 1 | # An overly simplified grammar for Russian text 2 | root ::= ru-char+ ([ \t\n] ru-char+)* 3 | ru-char ::= cyrillic | punctuation 4 | cyrillic ::= [А-Яа-я] 5 | punctuation ::= [、-〾] | [!?.,;:()"'`] 6 | 
-------------------------------------------------------------------------------- /examples/grammars/korean.ebnf: -------------------------------------------------------------------------------- 1 | # An overly simplified grammar for Korean 2 | root ::= kr-char+ ([ \t\n] kr-char+)* 3 | kr-char ::= hangul | hanja | punctuation 4 | hangul ::= [가-힣] 5 | hanja ::= [一-鿿] 6 | punctuation ::= [、-〾] | [!?.,;:()"'`] 7 | -------------------------------------------------------------------------------- /examples/grammars/japanese.ebnf: -------------------------------------------------------------------------------- 1 | # A probably incorrect grammar for Japanese 2 | root ::= jp-char+ ([ \t\n] jp-char+)* 3 | jp-char ::= hiragana | katakana | punctuation | cjk 4 | hiragana ::= [ぁ-ゟ] 5 | katakana ::= [ァ-ヿ] 6 | punctuation ::= [、-〾] 7 | cjk ::= [一-鿿] 8 | -------------------------------------------------------------------------------- /examples/grammars/PDDL/blocks.ebnf: -------------------------------------------------------------------------------- 1 | root ::= plan 2 | 3 | plan ::= action ( " " action )* 4 | 5 | action ::= "(" ( 6 | ( "put-down" | "pick-up" ) " " object | 7 | ( "unstack" | "pick-up-and-stack") " " object " " object | 8 | "unstack-and-stack" " " object " " object " " object 9 | ) ")" 10 | 11 | object ::= [a-e] 12 | -------------------------------------------------------------------------------- /tests/json_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def is_json_parsable(string): 5 | try: 6 | json.loads(string) 7 | return True 8 | except json.JSONDecodeError: 9 | return False 10 | except Exception as e: 11 | # You might want to handle or log other exceptions as well 12 | return False 13 | -------------------------------------------------------------------------------- /tests/test_cli/test_model_support.py: -------------------------------------------------------------------------------- 1 | from transformers_cfg.cli.cli_main import check_model_support 2 | 3 | 4 | def test_supported_model(): 5 | model = "gpt2" 6 | assert check_model_support(model) == True 7 | 8 | 9 | def test_unsupported_model(): 10 | model = "bigscience/bloom" 11 | assert check_model_support(model) == False 12 | -------------------------------------------------------------------------------- /transformers_cfg/grammar_utils.py: -------------------------------------------------------------------------------- 1 | from .token_grammar_recognizer import ( 2 | IncrementalTokenRecognizer, 3 | NonIncrementalTokenSeqRecognizer, 4 | ) 5 | 6 | 7 | # Old class name, kept for backward compatibility 8 | IncrementalGrammarConstraint = IncrementalTokenRecognizer 9 | 10 | NonIncrementalGrammarConstraint = NonIncrementalTokenSeqRecognizer 11 | -------------------------------------------------------------------------------- /docs/grammar&parser.md: -------------------------------------------------------------------------------- 1 | # Grammar 2 | 3 | 4 | 5 | #### Terminals 6 | 7 | - character, e.g. `a`, `b`, `c`, `2`, `里`, `ç`, `😀` 8 | - code point 9 | - hexcode (0-255) in the form of `\xNN`, e.g. `\x61` for `a` 10 | - unicode (including letters, numbers, and symbols), e.g. `U+0061` for `a` 11 | 12 | - Escapes: 13 | - 8-bit (\xXX), 14 | - 16-bit (\uXXXX), 15 | - 32-bit (\UXXXXXXXX).
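As a quick end-to-end check of the escape forms (a minimal sketch using `parse_ebnf` and `StringRecognizer` from this repository, assuming `\xNN` escapes are accepted inside character classes just as the `\UXXXXXXXX` form is in `examples/grammars/unicode/emoji_escape.ebnf`):

```python
from transformers_cfg.parser import parse_ebnf
from transformers_cfg.recognizer import StringRecognizer

# the range a-c written with 8-bit escapes instead of raw characters
grammar_str = "root ::= [\\x61-\\x63]+"
parsed_grammar = parse_ebnf(grammar_str)
start_rule_id = parsed_grammar.symbol_table["root"]
recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
print(recognizer._accept_string("abc"))  # expected: True
```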
16 | -------------------------------------------------------------------------------- /docs/json_grammar.md: -------------------------------------------------------------------------------- 1 | # JSON (JavaScript Object Notation) Grammar 2 | 3 | 4 | ## JSON standard 5 | 6 | https://datatracker.ietf.org/doc/html/rfc7159 7 | 8 | ## Clarification 9 | 10 | - JSON doesn't support comments. (JSON5 does, but it's not in Python's standard library.) 11 | - JSON doesn't support trailing commas. 12 | 13 | 14 | ## JSON5 vs. JSON 15 | 16 | https://spec.json5.org/ 17 | -------------------------------------------------------------------------------- /examples/prompts/json.txt: -------------------------------------------------------------------------------- 1 | Generate a valid JSON object that contains a list of all even numbers up to . 2 | Generate a valid JSON object that contains a list of all keys from key1 to key. 3 | Generate a valid JSON object that contains all key value pairs from (key1, value1) to (key, value) . 4 | This is a valid json string for a detailed http request: 5 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_gpt2.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2TokenizerFast 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | 7 | class TestGPT2Tokenizer(TokenizerTesterMixin): 8 | tokenizer_class = GPT2TokenizerFast 9 | pretrained_name = "gpt2" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_gemma2.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | 7 | class TestGemma2Tokenizer(TokenizerTesterMixin): 8 | tokenizer_class = AutoTokenizer 9 | pretrained_name = "Transformers-CFG/gemma-2-2b-it-tokenizer" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_llama.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaTokenizerFast 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | 7 | class TestLlamaTokenizer(TokenizerTesterMixin): 8 | tokenizer_class = LlamaTokenizerFast 9 | pretrained_name = "Transformers-CFG/llama-7B-tokenizer" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_mistral.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaTokenizerFast 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | 7 | class TestMistralTokenizer(TokenizerTesterMixin): 8 | tokenizer_class = LlamaTokenizerFast 9 | pretrained_name = "Transformers-CFG/Mistral-7B-v0.1-tokenizer" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | -------------------------------------------------------------------------------- /examples/grammars/emoji.ebnf:
-------------------------------------------------------------------------------- 1 | # This contains the Emoticons Unicode block (U+1F600–U+1F64F), which represents 80 emoticons (emoji) 2 | # that are commonly used in text messaging and social media. cf. https://en.wikipedia.org/wiki/Emoticons_(Unicode_block) 3 | # This doesn't include the Miscellaneous Symbols and Pictographs block (U+1F300–U+1F5FF) and 4 | # the Supplemental Symbols and Pictographs block (U+1F900–U+1F9FF). 5 | root ::= emoji+ 6 | emoji ::= [😀-🙏] 7 | -------------------------------------------------------------------------------- /examples/grammars/PDDL/satellite.ebnf: -------------------------------------------------------------------------------- 1 | root ::= plan 2 | 3 | plan ::= action ( " " action )* 4 | 5 | action ::= "(" ( 6 | ( "switch-on" | "switch-off" ) " " object " " object | 7 | ( "turn-to" | "calibrate" ) " " object " " object " " object | 8 | "take-image" " " object " " object " " object " " object 9 | ) ")" 10 | 11 | # Example, should be input dependent 12 | 13 | object ::= "instrument" [0-5] | "satellite" [0-3] | "direction" [0-5] | "mode" [0-5] 14 | -------------------------------------------------------------------------------- /examples/grammars/json_minimal.ebnf: -------------------------------------------------------------------------------- 1 | 2 | 3 | root ::= object 4 | 5 | object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" 6 | 7 | value ::= object | array | string | number | ("true" | "false" | "null") ws 8 | 9 | array ::= "[" ws ( value ("," ws value)* )? "]" ws 10 | 11 | string ::= "\"" [a-zA-Z0-9]* "\"" ws 12 | 13 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws 14 | 15 | 16 | ws ::= ([ \t\n] ws)? 17 | -------------------------------------------------------------------------------- /examples/grammars/unicode/emoji_escape.ebnf: -------------------------------------------------------------------------------- 1 | # This contains the Emoticons Unicode block (U+1F600–U+1F64F), which represents 80 emoticons (emoji) 2 | # that are commonly used in text messaging and social media. cf. https://en.wikipedia.org/wiki/Emoticons_(Unicode_block) 3 | # This doesn't include the Miscellaneous Symbols and Pictographs block (U+1F300–U+1F5FF) and 4 | # the Supplemental Symbols and Pictographs block (U+1F900–U+1F9FF).
5 | root ::= emoji+ 6 | emoji ::= [\U0001F600-\U0001F64F] 7 | -------------------------------------------------------------------------------- /examples/grammars/custom_json_grammars/schemas/student.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "properties": { 4 | "name": { 5 | "type": "string" 6 | }, 7 | "age": { 8 | "type": "number" 9 | }, 10 | "is_student": { 11 | "type": "boolean" 12 | }, 13 | "courses": { 14 | "type": "array", 15 | "items": { 16 | "type": "string" 17 | } 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_phi.py: -------------------------------------------------------------------------------- 1 | from transformers import CodeGenTokenizerFast 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | # @unittest.skip("CodeGen is not supported and will be removed") 7 | class TestPhiTokenizer(TokenizerTesterMixin): 8 | tokenizer_class = CodeGenTokenizerFast 9 | pretrained_name = "Transformers-CFG/phi-1_5-tokenizer" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_deepseek.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | # @unittest.skip("CodeGen is not supported and will be removed") 7 | class TestDeepSeekTokenizer(TokenizerTesterMixin): 8 | tokenizer_class = AutoTokenizer 9 | pretrained_name = "Transformers-CFG/deepseek-coder-1.3b-base-tokenizer" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_llama3.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | 7 | class TestLlama3Tokenizer(TokenizerTesterMixin): 8 | 9 | # This also applies to Llama3.1, Llama3.2, Llama3.3 as they share the same tokenizer 10 | tokenizer_class = AutoTokenizer 11 | pretrained_name = "Transformers-CFG/Meta-Llama-3-8B-tokenizer" 12 | 13 | def setup(self): 14 | self.setup_tokenizer() 15 | -------------------------------------------------------------------------------- /examples/grammars/PDDL/satellite_typed.ebnf: -------------------------------------------------------------------------------- 1 | root ::= plan 2 | 3 | plan ::= action ( " " action )* 4 | 5 | action ::= "(" ( 6 | ( "switch-on" | "switch-off" ) " " instrument " " satellite | 7 | "turn-to" " " satellite " " direction " " direction | 8 | "calibrate" " " satellite " " instrument " " direction | 9 | "take-image" " " satellite " " direction " " instrument " " mode 10 | ) ")" 11 | 12 | # Example, should be input dependent 13 | 14 | instrument ::= "instrument" [0-5] 15 | 16 | satellite ::= "satellite" [0-3] 17 | 18 | direction ::= "direction" [0-5] 19 | 20 | mode ::= "mode" [0-5] 21 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/SUPPORTED_TOKENIZERS.py: 
-------------------------------------------------------------------------------- 1 | from transformers import ( 2 | GPT2TokenizerFast, 3 | BartTokenizerFast, 4 | LlamaTokenizerFast, 5 | T5TokenizerFast, 6 | CodeGenTokenizerFast, 7 | PreTrainedTokenizerFast, 8 | GemmaTokenizerFast, 9 | Qwen2TokenizerFast, 10 | ByT5Tokenizer, 11 | ) 12 | 13 | SUPPORTED_TOKENIZERS = { 14 | GPT2TokenizerFast, 15 | BartTokenizerFast, 16 | LlamaTokenizerFast, 17 | T5TokenizerFast, 18 | CodeGenTokenizerFast, 19 | PreTrainedTokenizerFast, 20 | GemmaTokenizerFast, 21 | Qwen2TokenizerFast, 22 | ByT5Tokenizer, 23 | } 24 | -------------------------------------------------------------------------------- /examples/grammars/chess.ebnf: -------------------------------------------------------------------------------- 1 | # Specifies chess moves as a list in algebraic notation, using PGN conventions 2 | 3 | # Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern 4 | root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ 5 | move ::= (pawn | nonpawn | castle) [+#]? 6 | 7 | # piece type, optional file/rank, optional capture, dest file & rank 8 | nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] 9 | 10 | # optional file & capture, dest file & rank, optional promotion 11 | pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? 12 | 13 | castle ::= "O-O" "-O"? 14 | -------------------------------------------------------------------------------- /examples/grammars/PDDL/depot.ebnf: -------------------------------------------------------------------------------- 1 | 2 | root ::= plan 3 | 4 | plan ::= action ( " " action )* 5 | 6 | action ::= "(" ( 7 | "drive" " " object " " object " " object | 8 | ( "lift" | "load" | "unload" | "drive-and-load" | "drop" ) " " object " " object " " object " " object | 9 | ( "drive-and-lift" | "drive-and-unload" ) " " object " " object " " object " " object " " object | 10 | "lift-and-drive" " " object " " object " " object " " object " " object " " object 11 | ) ")" 12 | 13 | # Example, should be input dependent 14 | 15 | object ::= "truck" [0-1] | "hoist" [0-2] | "crate" [0-5] | "pallet" [0-5] | "depot" [0-1] | "distributor" [0-1] 16 | -------------------------------------------------------------------------------- /examples/grammars/json.ebnf: -------------------------------------------------------------------------------- 1 | # Grammar for subset of JSON 2 | # String doesn't support unicode and escape yet 3 | # If you don't need to generate unicode and escape, you can use this grammar 4 | # We are working to support unicode and escape 5 | 6 | root ::= object 7 | 8 | object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" 9 | 10 | value ::= object | array | string | number | ("true" | "false" | "null") ws 11 | 12 | array ::= "[" ws ( value ("," ws value)* )? "]" ws 13 | 14 | string ::= "\"" [ \t!#-\[\]-~]* "\"" ws 15 | 16 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws 17 | 18 | 19 | ws ::= ([ \t\n] ws)? 
20 | -------------------------------------------------------------------------------- /transformers_cfg/logging_config.py: -------------------------------------------------------------------------------- 1 | # logging_config.py 2 | import os 3 | import logging 4 | 5 | 6 | def setup_logging(): 7 | log_level_name = os.getenv( 8 | "TCFG_LOG_LEVEL", "WARNING" 9 | ).upper()  # Default to WARNING if not set 10 | log_levels = { 11 | "DEBUG": logging.DEBUG, 12 | "INFO": logging.INFO, 13 | "WARNING": logging.WARNING, 14 | "ERROR": logging.ERROR, 15 | "CRITICAL": logging.CRITICAL, 16 | } 17 | log_level = log_levels.get(log_level_name, logging.WARNING) 18 | # Create a logger for the library 19 | logger = logging.getLogger("transformers_cfg") 20 | # the level will propagate to loggers with submodule scope 21 | logger.setLevel(log_level) 22 | -------------------------------------------------------------------------------- /docs/contribute.md: -------------------------------------------------------------------------------- 1 | # Contribute 2 | 3 | We welcome contributions to the project. 4 | 5 | To contribute, please follow these steps: 6 | 1. Fork the repository. 7 | 2. Create a new branch for your changes with a descriptive name, e.g. `git checkout -b feature/add-support-for-xyz` or `git checkout -b fix/parsing-error-in-abc`. 8 | 3. Create an environment with the required dependencies via `pip install -r requirements.txt`. 9 | 4. Install `pre-commit` hooks via `pre-commit install`. 10 | 5. Make your changes and add tests to ensure your changes are correct. 11 | 6. Commit your changes; `pre-commit` will run automatically when you commit and will run the tests to verify your changes. 12 | 7. If all tests pass, push your changes to your fork and create a pull request.
13 | -------------------------------------------------------------------------------- /examples/grammars/custom_json_grammars/schemas/product_catalog.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Product", 3 | "description": "A product from my catalog", 4 | "type": "object", 5 | "properties": { 6 | "productId": { 7 | "description": "The unique identifier for a product", 8 | "type": "integer" 9 | }, 10 | "productName": { 11 | "description": "Name of the product", 12 | "type": "string" 13 | }, 14 | "price": { 15 | "description": "The price of the product", 16 | "type": "number", 17 | "exclusiveMinimum": 0 18 | } 19 | }, 20 | "required": [ 21 | "productId", 22 | "productName", 23 | "price" 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_json_arr.py: -------------------------------------------------------------------------------- 1 | from tests.json_utils import is_json_parsable 2 | from transformers_cfg.parser import parse_ebnf 3 | from transformers_cfg.recognizer import StringRecognizer 4 | 5 | 6 | def test_minimal_json_array(): 7 | """ 8 | Test that we can load a JSON array 9 | """ 10 | jsons = [ 11 | "[\\n]", 12 | "[\\n1]", 13 | "[\\n1,2]", 14 | "[\\n1,2,3]", 15 | ] 16 | with open("examples/grammars/json_arr.ebnf", "r") as file: 17 | input_text = file.read() 18 | parsed_grammar = parse_ebnf(input_text) 19 | start_rule_id = parsed_grammar.symbol_table["root"] 20 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 21 | 22 | for json in jsons: 23 | assert is_json_parsable(json) == recognizer._accept_prefix( 24 | json 25 | ), f"Failed on {json}" 26 | -------------------------------------------------------------------------------- /examples/grammars/json_arr.ebnf: -------------------------------------------------------------------------------- 1 | # This is the same as json.gbnf but we restrict whitespaces at the end of the root array 2 | # Useful for generating JSON arrays 3 | # String doesn't support unicode and escape yet 4 | 5 | root ::= arr 6 | value ::= object | array | string | number | ("true" | "false" | "null") ws 7 | 8 | arr ::= 9 | "[\n" ws ( 10 | value 11 | (",\n" ws value)* 12 | )? "]" 13 | 14 | object ::= 15 | "{" ws ( 16 | string ":" ws value 17 | ("," ws string ":" ws value)* 18 | )? "}" ws 19 | 20 | array ::= 21 | "[" ws ( 22 | value 23 | ("," ws value)* 24 | )? "]" ws 25 | 26 | string ::= "\"" [ \t!#-\[\]-~]* "\"" ws 27 | 28 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws 29 | 30 | # Optional space: by convention, applied in this grammar after literal chars when allowed 31 | ws ::= ([ \t\n] ws)? 
32 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '*' 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | jobs: 14 | test: 15 | name: Run Tests 16 | runs-on: ubuntu-22.04 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - name: Set up Python 3.9 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: '3.9' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -e .[dev] 30 | pip install huggingface_hub 31 | 32 | - name: Save Hugging Face API Token 33 | run: | 34 | python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('${{ secrets.HF_API_TOKEN }}')" 35 | 36 | - name: Run tests 37 | run: | 38 | pytest tests/ 39 | -------------------------------------------------------------------------------- /examples/grammars/PDDL/depot_typed.ebnf: -------------------------------------------------------------------------------- 1 | 2 | root ::= plan 3 | 4 | plan ::= action ( " " action )* 5 | 6 | action ::= "(" ( 7 | "drive" " " truck " " place " " place | 8 | ( "lift" | "drop" ) " " hoist " " crate " " surface " " place | 9 | ( "load" | "unload" ) " " hoist " " crate " " truck " " place | 10 | "drive-and-load" " " truck " " hoist " " crate " " place | 11 | ( "drive-and-lift" | "drive-and-unload" ) " " truck " " hoist " " crate " " surface " " place | 12 | "lift-and-drive" " " truck " " hoist " " crate " " surface " " place " " place 13 | ) ")" 14 | 15 | # Example, should be input dependent 16 | 17 | truck ::= "truck" [0-1] 18 | 19 | hoist ::= "hoist" [0-2] 20 | 21 | crate ::= "crate" [0-5] 22 | 23 | pallet ::= "pallet" [0-5] 24 | 25 | depot ::= "depot" [0-1] 26 | 27 | distributor ::= "distributor" [0-1] 28 | 29 | place ::= depot | distributor 30 | 31 | surface ::= pallet | crate 32 | -------------------------------------------------------------------------------- /examples/grammars/custom_json_grammars/grammars/product_catalog.ebnf: -------------------------------------------------------------------------------- 1 | char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) 2 | decimal-part ::= [0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? 3 | integer ::= ("-"? integral-part) space 4 | integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? 5 | number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space 6 | price-kv ::= "\"price\"" space ":" space number 7 | productId-kv ::= "\"productId\"" space ":" space integer 8 | productName-kv ::= "\"productName\"" space ":" space string 9 | root ::= "{" space productId-kv "," space productName-kv "," space price-kv "}" space 10 | space ::= " "? 
11 | string ::= "\"" char* "\"" space 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 EPFL Data Science Lab (dlab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/grammars/custom_json_grammars/grammars/student.ebnf: -------------------------------------------------------------------------------- 1 | age-kv ::= "\"age\"" space ":" space number 2 | age-rest ::= ( "," space courses-kv )? 3 | boolean ::= ("true" | "false") space 4 | char ::= [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) 5 | courses ::= "[" space (string ("," space string)*)? "]" space 6 | courses-kv ::= "\"courses\"" space ":" space courses 7 | decimal-part ::= [0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? 8 | integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? 9 | is-student-kv ::= "\"is_student\"" space ":" space boolean 10 | is-student-rest ::= ( "," space name-kv )? name-rest 11 | name-kv ::= "\"name\"" space ":" space string 12 | name-rest ::= ( "," space age-kv )? age-rest 13 | number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space 14 | root ::= "{" space (is-student-kv is-student-rest | name-kv name-rest | age-kv age-rest | courses-kv )? "}" space 15 | space ::= " "? 
16 | string ::= "\"" char* "\"" space 17 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Dict 3 | from transformers_cfg.tokenization.SUPPORTED_TOKENIZERS import SUPPORTED_TOKENIZERS 4 | 5 | 6 | def replace_hex(match): 7 | hex_value = match.group(1) 8 | return chr(int(hex_value, 16)) 9 | 10 | 11 | # This will collect all imported classes from the current module (globals()) 12 | def get_imported_tokenizer_classes(module_globals) -> Dict[str, type]: 13 | return { 14 | name: obj 15 | for name, obj in module_globals.items() 16 | if inspect.isclass(obj) and name.endswith("TokenizerFast") 17 | } 18 | 19 | 20 | def get_tokenizer_real_class(hf_tokenizer): 21 | return hf_tokenizer.__class__ 22 | 23 | 24 | def is_tokenizer_supported(hf_tokenizer_or_name): 25 | if isinstance(hf_tokenizer_or_name, str): 26 | from transformers import AutoTokenizer 27 | 28 | hf_tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_or_name) 29 | else: 30 | hf_tokenizer = hf_tokenizer_or_name 31 | return get_tokenizer_real_class(hf_tokenizer) in SUPPORTED_TOKENIZERS 32 | 33 | 34 | def get_tokenizer_charset(hf_tokenizer): 35 | return set("".join(hf_tokenizer.vocab.keys())) 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from transformers_cfg import __version__ 3 | 4 | setup( 5 | name="transformers_cfg", 6 | version=__version__, 7 | author="EPFL-dlab", 8 | author_email="saibo.geng@epfl.ch", 9 | description="Extension of Transformers library for Context-Free Grammar Constrained Decoding with EBNF grammars", 10 | long_description=open("README.md").read(), 11 | long_description_content_type="text/markdown", 12 | url="https://github.com/epfl-dlab/transformers-CFG", 13 | packages=find_packages(), 14 | license="MIT", 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ], 20 | install_requires=open("requirements.txt").read().splitlines(), 21 | package_data={ 22 | "transformers_cfg": ["examples/grammars/*.ebnf"], 23 | }, 24 | include_package_data=True, 25 | entry_points={ 26 | "console_scripts": [ 27 | "transformers-cfg-cli=transformers_cfg.cli.cli_main:main", 28 | ], 29 | }, 30 | extras_require={ 31 | "dev": [ 32 | "pre-commit", 33 | "pytest", 34 | "twine", 35 | ], 36 | }, 37 | ) 38 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_unicode.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers_cfg.recognizer import StringRecognizer 3 | from transformers_cfg.parser import parse_ebnf 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def japanese_recognizer(): 8 | with open("examples/grammars/japanese.ebnf", "r") as file: 9 | input_text = file.read() 10 | parsed_grammar = parse_ebnf(input_text) 11 | start_rule_id = parsed_grammar.symbol_table["root"] 12 | return StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def emoji_recognizer(): 17 | with open("examples/grammars/emoji.ebnf", "r") as file: 18 | input_text = file.read() 19 | parsed_grammar = 
parse_ebnf(input_text) 20 | start_rule_id = parsed_grammar.symbol_table["root"] 21 | return StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 22 | 23 | 24 | def test_accept_japanese(japanese_recognizer): 25 | """ 26 | Test that we can accept japanese characters 27 | """ 28 | japanese = "こんにちは世界" 29 | assert japanese_recognizer._accept_prefix(japanese) 30 | 31 | 32 | def test_emoji(emoji_recognizer): 33 | """ 34 | Test that we can accept emoji 35 | """ 36 | emoji = "😀😄😂" 37 | assert emoji_recognizer._accept_prefix(emoji) 38 | -------------------------------------------------------------------------------- /.github/workflows/pipy_release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' # Trigger on any tagging event 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python 🐍 distributions 📦 to PyPI 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v3 # Updated to a specific version of the checkout action 14 | with: 15 | ref: main # Ensure it checks out the main branch 16 | - name: Set up Python 3.9 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: '3.9' 20 | - name: Install setuptools and wheel 21 | run: python -m pip install setuptools wheel 22 | - name: Extract tag name 23 | id: tag 24 | run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3) 25 | - name: Extract version from __init__.py 26 | id: get_version 27 | run: | 28 | VERSION=$(grep '__version__' transformers_cfg/__init__.py | cut -d'"' -f2) 29 | echo "VERSION=$VERSION" >> $GITHUB_ENV 30 | - name: Build a binary wheel 31 | run: >- 32 | python setup.py sdist bdist_wheel 33 | - name: Publish distribution 📦 to PyPI 34 | uses: pypa/gh-action-pypi-publish@master 35 | with: 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /examples/benchmarking/run_generation.sh: -------------------------------------------------------------------------------- 1 | grammar_path=$1 2 | grammar_name=$(basename $grammar_path) 3 | prompts_path=$2 4 | model_id=${3:-"openai-community/gpt2"} 5 | model_name=$(echo $model_id | sed 's/\//_/g') 6 | device=${4:-"cpu"} 7 | 8 | current_date="`date +%Y:%m:%d-%H:%M:%S`" 9 | logs_file="logs/$grammar_name-$model_name-$device-$current_date.tsv" 10 | tmp_file="tmp_$current_date.txt" 11 | echo $logs_file 12 | 13 | touch $logs_file 14 | echo -e "prompt\tn_tokens\trun_id\ttotal_time\ttime_per_token\tdevice\tmodel_id\tconstrained_time\tunconstrained_time" >> $logs_file 15 | for max_new_tokens in 1 2 4 8 16 32 64 128 256 16 | do 17 | echo "Max new tokens: $max_new_tokens" 18 | while IFS= read -r prompt 19 | do 20 | echo "Prompt: $prompt" 21 | for run_id in {1..5} 22 | do 23 | echo "Measurement: $run_id" 24 | kernprof -b --skip-zero -v time_benchmarking.py $grammar_path "$prompt" $max_new_tokens $model_id > $tmp_file 25 | unconstrained_time=$(cat $tmp_file | grep "Unconstrained time: " | awk '{print $3;}') 26 | constrained_time=$(cat $tmp_file | grep "Constrained time: " | awk '{print $3;}') 27 | (cat $tmp_file | grep "(process_logits)" | awk -v ut=$unconstrained_time -v ct=$constrained_time -v p="$prompt" -v rid=$run_id -v mid=$model_id -v d=$device '{OFS = "\t"} {print p,$1,rid,$4,$5,d,mid,ct,ut}') >> $logs_file 28 | done; 29 | done < "$prompts_path" 30 | done; 31 | rm $tmp_file 32 |
-------------------------------------------------------------------------------- /examples/grammars/overnight.ebnf: -------------------------------------------------------------------------------- 1 | root ::= "(listValue " list_value ")" 2 | 3 | 4 | list_value ::= "(filter " ( list_value " " PROPERTY | list_value " " PROPERTY OP list_value | list_value " " "(ensureNumericProperty " PROPERTY ")" OP "(ensureNumericEntity " list_value ")" ) ")" | 5 | "(superlative " list_value AGGREGATE "(ensureNumericProperty " PROPERTY "))" | 6 | "(countSuperlative " list_value AGGREGATE PROPERTY ( " " list_value)? ")" | 7 | "(countComparative " list_value " " PROPERTY OP list_value ( " " list_value)? ")" | 8 | "(_size " list_value ")" | 9 | "(aggregate" AGGREGATE list_value ")" | 10 | "(getProperty " ( list_value " " PROPERTY | "(singleton " SINGLETON_VALUE ") " "!type" ) ")" | 11 | "(concat " ( ENTITY_VALUE " " ENTITY_VALUE | NUMBER_VALUE " " NUMBER_VALUE ) ")" | 12 | ENTITY_VALUE | NUMBER_VALUE 13 | 14 | PROPERTY ::= "shape" | "color" | "length" | "is_special" | "width" | "height" | "left" | "right" | "above" | "below" | 15 | "(reverse " ( "left" | "right" | "above" | "below" ) ")" 16 | 17 | 18 | SINGLETON_VALUE ::= "en.block" | "en.shape" | "en.color" 19 | 20 | ENTITY_VALUE ::= "en.block.block1" | "en.block.block2" | "en.shape.pyramid" | "en.shape.cube" | "en.color.red" | "en.color.green" 21 | 22 | NUMBER_VALUE ::= ( "3" | "6" ) " " "en.inch" | "2" 23 | 24 | 25 | OP ::= " " ( "=" | ">" | "<" | ">=" | "<=" | "!=" ) " " 26 | 27 | AGGREGATE ::= " " ("sum" | "max" | "min" | "avg" ) " " 28 | -------------------------------------------------------------------------------- /examples/grammars/c.ebnf: -------------------------------------------------------------------------------- 1 | root ::= (declaration)* 2 | 3 | declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" 4 | 5 | dataType ::= "int" ws | "float" ws | "char" ws 6 | identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* 7 | 8 | parameter ::= dataType identifier 9 | 10 | statement ::= 11 | ( dataType identifier ws "=" ws expression ";" ) | 12 | ( identifier ws "=" ws expression ";" ) | 13 | ( identifier ws "(" argList? ")" ";" ) | 14 | ( "return" ws expression ";" ) | 15 | ( "while" "(" condition ")" "{" statement* "}" ) | 16 | ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | 17 | ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | 18 | ( singleLineComment ) | 19 | ( multiLineComment ) 20 | 21 | forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression 22 | forUpdate ::= identifier ws "=" ws expression 23 | 24 | condition ::= expression relationOperator expression 25 | relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") 26 | 27 | expression ::= term (("+" | "-") term)* 28 | term ::= factor(("*" | "/") factor)* 29 | 30 | factor ::= identifier | number | unaryTerm | funcCall | parenExpression 31 | unaryTerm ::= "-" factor 32 | funcCall ::= identifier "(" argList? 
")" 33 | parenExpression ::= "(" ws expression ws ")" 34 | 35 | argList ::= expression ("," ws expression)* 36 | 37 | number ::= [0-9]+ 38 | 39 | singleLineComment ::= "//" [^\n]* "\n" 40 | multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" 41 | 42 | ws ::= ([ \t\n]+) 43 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/test_t5.py: -------------------------------------------------------------------------------- 1 | from transformers import T5TokenizerFast 2 | from tests.test_accept_token_sequence._test_accept_tokens_mixin import ( 3 | TokenizerTesterMixin, 4 | ) 5 | 6 | # @unittest.skip("T5Tokenizer's mapping is not well defined, not working") 7 | class TestT5Tokenizer(TokenizerTesterMixin): 8 | tokenizer_class = T5TokenizerFast 9 | pretrained_name = "t5-small" 10 | 11 | def setup(self): 12 | self.setup_tokenizer() 13 | 14 | 15 | class TestT5TokenizerUnkToken: 16 | def test_unk_token(self): 17 | tokenizer = T5TokenizerFast.from_pretrained("t5-small") 18 | 19 | unk_token_id = tokenizer.unk_token_id 20 | unk_token = tokenizer.unk_token 21 | 22 | # open curly brace is an unk token 23 | curly_brace_open = "{" 24 | # we take the 2nd token because the first token is the space token 25 | curly_brace_open_id = tokenizer.encode(curly_brace_open)[1] 26 | assert curly_brace_open_id == unk_token_id 27 | 28 | curly_brace_close = "}" 29 | curly_brace_close_id = tokenizer.encode(curly_brace_close)[1] 30 | assert curly_brace_close_id == unk_token_id 31 | 32 | eos_token_id = tokenizer.eos_token_id 33 | # tab in t5 signifies the end of a line 34 | tab = "\t" 35 | tab_id = tokenizer.encode(tab)[0] 36 | assert tab_id == eos_token_id 37 | 38 | # newline in t5 signifies the end of a line 39 | newline = "\n" 40 | newline_id = tokenizer.encode(newline)[0] 41 | assert newline_id == eos_token_id 42 | -------------------------------------------------------------------------------- /examples/grammars/cIE.ebnf: -------------------------------------------------------------------------------- 1 | # This is just for illustration purposes. Depending on the actual use case, the set of entities and relations can be extended to include more entities and relations. 
2 | root ::= triplet ( delim triplet )* 3 | triplet ::= "[s] " subject " [r] " predicate " [o] " object 4 | delim ::= " [e] " 5 | subject ::= subject_entity 6 | predicate ::= relation 7 | object ::= object_entity 8 | subject_entity ::= "Rene Descartes" | "Isaac Newton" | "Albert Einstein" | "Stephen Hawking" | "Galileo Galilei" | "Nikola Tesla" | "Leonardo da Vinci" | "Aristotle" | "Plato" | "Socrates" | "Pythagoras" | "Euclid" | "Archimedes" | "Hippocrates" | "Ptolemy" | "Nicolaus Copernicus" | "Johannes Kepler" | "Galileo Galilei" | "Isaac Newton" | "Albert Einstein" | "Stephen Hawking" | "Nikola Tesla" | "Leonardo da Vinci" | "Aristotle" 9 | object_entity ::= "France" | "England" | "Germany" | "Italy" | "Greece" | "Egypt" | "China" | "India" | "Russia" | "USA" | "Canada" | "Brazil" | "Australia" | "Japan" | "South Africa" | "Mexico" | "Argentina" | "Spain" | "Portugal" | "Netherlands" | "Belgium" | "Sweden" | "Norway" | "Denmark" | "Finland" | "Poland" | "Czech Republic" | "Slovakia" | "Hungary" | "Romania" | "Bulgaria" | "Greece" | "Turkey" | "Iran" | "Iraq" | "Syria" 10 | relation ::= "was born in" | "died in" | "lived in" | "worked in" | "studied in" | "invented" | "discovered" | "wrote" | "painted" | "sculpted" | "composed" | "played" | "sang" | "acted" | "directed" | "produced" | "won" | "lost" | "was awarded" | "was nominated" | "was married to" | "was divorced from" | "had children with" | "was friends with" | "was enemies with" 11 | -------------------------------------------------------------------------------- /examples/accept.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | from transformers_cfg.parser import parse_ebnf 5 | from transformers_cfg.recognizer import StringRecognizer 6 | 7 | logging.basicConfig(level=logging.DEBUG) 8 | 9 | 10 | def main(args): 11 | 12 | with open(args.grammar_file_path, "r") as file: 13 | grammar_str = file.read() 14 | parsed_grammar = parse_ebnf(grammar_str) 15 | start_rule_id = parsed_grammar.symbol_table["root"] 16 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 17 | 18 | if args.mode == "prefix": 19 | result = recognizer._accept_prefix(args.sentence) 20 | else: 21 | result = recognizer._accept_string(args.sentence) 22 | 23 | print(result) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser( 28 | description="Generate text with grammar constraints." 
29 | ) 30 | parser.add_argument( 31 | "-g", 32 | "--grammar_file_path", 33 | type=str, 34 | required=True, 35 | help="Path to the grammar file (supports both relative and absolute paths)", 36 | ) 37 | parser.add_argument( 38 | "-s", "--sentence", type=str, required=True, help="Prefix prompt for generation" 39 | ) 40 | parser.add_argument( 41 | "-m", 42 | "--mode", 43 | type=str, 44 | choices=["prefix", "sentence"], 45 | default="prefix", 46 | help="Mode of operation, " 47 | "prefix mode accepts a prefix string, sentence mode only accepts a full sentence", 48 | ) 49 | 50 | args = parser.parse_args() 51 | main(args) 52 | -------------------------------------------------------------------------------- /examples/generate_llama_cpp_python.py: -------------------------------------------------------------------------------- 1 | import io 2 | import torch 3 | import logging 4 | from contextlib import redirect_stderr 5 | from llama_cpp import Llama 6 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 7 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 8 | from transformers import AutoTokenizer 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | 12 | # Define your EBNF grammar (you can replace this with your own) 13 | ebnf_grammar = """ 14 | 15 | root ::= "The animal is a " animal "." 16 | 17 | animal ::= "cat" | "fish" 18 | 19 | """ 20 | 21 | # Load the tokenizer matching your model 22 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b") 23 | 24 | # Redirect stderr and load the model via llama-cpp-python 25 | f = io.StringIO() 26 | with redirect_stderr(f): 27 | model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False) 28 | 29 | # Create the grammar constraint and the logits processor with the new parameter. 30 | grammar_constraint = IncrementalGrammarConstraint(ebnf_grammar, "root", tokenizer) 31 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint, adapter="llama-cpp-python") 32 | 33 | # Define a prompt. 34 | prompt = """The text says, "The animal is a dog." The answer is obvious. """ 35 | 36 | # Use the text completion API with the logits processor. 37 | response = model.create_completion( 38 | stream=True, 39 | prompt=prompt, 40 | logits_processor=[grammar_processor], 41 | max_tokens=100, 42 | ) 43 | 44 | for token in response: 45 | token_text = token["choices"][0]["text"] 46 | print(token_text, end="", flush=True) 47 | -------------------------------------------------------------------------------- /examples/run_seq2seq_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | t5 tokenizer has a lot of unk tokens, such as open curly brace, close curly brace, tab, newline, etc. 
3 | 4 | """ 5 | import transformers 6 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 7 | from transformers_cfg.grammar_utils import ( 8 | IncrementalGrammarConstraint, 9 | ) 10 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 11 | 12 | import logging 13 | 14 | logging.basicConfig(level=logging.DEBUG) 15 | transformers.logging.set_verbosity_debug() 16 | 17 | 18 | if __name__ == "__main__": 19 | 20 | model_name = "facebook/bart-base" 21 | # model_name = "google-t5/t5-base" 22 | 23 | # Load model and tokenizer 24 | tokenizer = AutoTokenizer.from_pretrained(model_name) 25 | # tokenizer.pad_token = tokenizer.eos_token 26 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name) 27 | # resize the embedding layer to match the tokenizer 28 | model.resize_token_embeddings(len(tokenizer)) 29 | 30 | # Load json grammar 31 | with open("examples/grammars/cIE.ebnf", "r") as file: 32 | grammar_str = file.read() 33 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 34 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 35 | 36 | # Generate 37 | prefix1 = " entity1, relation1 , entity2 => " 38 | input_ids = tokenizer( 39 | [prefix1], add_special_tokens=False, return_tensors="pt", padding=True 40 | )["input_ids"] 41 | 42 | output = model.generate( 43 | input_ids, 44 | do_sample=False, 45 | max_new_tokens=60, 46 | num_beams=1, 47 | logits_processor=[grammar_processor], 48 | num_return_sequences=1, 49 | ) 50 | # decode output 51 | generations = tokenizer.batch_decode(output, skip_special_tokens=True) 52 | print(generations) 53 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.4.0 # Use the ref you want to point at 4 | hooks: 5 | - id: trailing-whitespace # Trims trailing whitespace. 6 | - id: end-of-file-fixer # Makes sure files end in a newline and only a newline. 7 | - id: check-yaml # Checks yaml files for syntax validity 8 | - id: check-added-large-files # Checks for large files being added to git, > 10MB 9 | args: ['--maxkb=1000'] 10 | exclude: "notebooks" 11 | - id: check-case-conflict # Checks for files whose names differ only in case 12 | - id: check-merge-conflict # Checks for files that contain merge conflict strings. 13 | - id: debug-statements # Prevents the accidental commit of breakpoint(), pdb.set_trace(), etc. 14 | - id: check-ast # Simply check whether files parse as valid python. 15 | - id: check-docstring-first # Checks for a common error of placing code before the docstring. 16 | - id: check-executables-have-shebangs # Checks that executable files have a shebang. 17 | - id: detect-private-key # Detects the presence of private keys. 
18 | - id: detect-aws-credentials 19 | args: 20 | - --allow-missing-credentials 21 | - repo: https://github.com/ambv/black # automatic format code style 22 | rev: 22.3.0 23 | hooks: 24 | - id: black 25 | - repo: local 26 | hooks: 27 | - id: pytest 28 | name: pytest 29 | entry: pytest tests 30 | language: system 31 | types: [ python ] 32 | pass_filenames: false 33 | stages: [ commit ] 34 | 35 | - repo: https://github.com/PyCQA/autoflake # remove unused imports 36 | rev: v2.2.1 37 | hooks: 38 | - id: autoflake 39 | args: [--remove-all-unused-imports, --in-place] 40 | 41 | - repo: https://github.com/compilerla/conventional-pre-commit 42 | rev: v3.2.0 43 | hooks: 44 | - id: conventional-pre-commit 45 | stages: [commit-msg] 46 | args: [] 47 | -------------------------------------------------------------------------------- /docs/add_new_model_support.md: -------------------------------------------------------------------------------- 1 | # Add new model support 2 | 3 | If you want to use a new model that is not supported yet, here is a guide to adding support for it. 4 | 5 | In the following guide, we will use the newly released `meta-llama/Meta-Llama-3-8B` model as an example. 6 | 7 | 8 | ### Step 1: Check if the model is supported 9 | 10 | `transformers-cfg-cli` is a command-line tool that can be used to check if a model is supported by `transformers-cfg`. 11 | 12 | ```bash 13 | transformers-cfg-cli check meta-llama/Meta-Llama-3-8B 14 | # Model meta-llama/Meta-Llama-3-8B is not supported. 15 | # OR 16 | # Model meta-llama/Meta-Llama-3-8B is supported. 17 | ``` 18 | 19 | 20 | 21 | 22 | ### Step 2: Find the underlying tokenizer class 23 | 24 | 25 | ```python 26 | from transformers import AutoTokenizer 27 | 28 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") 29 | 30 | print(tokenizer.__class__) 31 | # transformers.tokenization_utils_fast.PreTrainedTokenizerFast 32 | ``` 33 | 34 | As you can see here, the tokenizer class is `PreTrainedTokenizerFast`. 35 | 36 | There are several caveats to this: 37 | 38 | 39 | 1. Many models can share the same tokenizer class, even though HF sometimes makes wrapper classes to make the tokenizer class more user-friendly. 40 | 41 | For example, both `mistralai/Mistral-7B-v0.1` and `meta-llama/Llama-2-7b-hf` use `LlamaTokenizerFast` as their tokenizer class. 42 | 43 | ```python 44 | from transformers import AutoTokenizer 45 | 46 | mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 47 | llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 48 | 49 | print(mistral_tokenizer.__class__) 50 | # transformers.models.llama.tokenization_llama.LlamaTokenizerFast 51 | 52 | print(llama_tokenizer.__class__) 53 | # transformers.models.llama.tokenization_llama.LlamaTokenizerFast 54 | ``` 55 | 56 | 2. Two models in the same family but different generations can have different tokenizer classes. 57 | 58 | For example, `meta-llama/Meta-Llama-3-8B` uses `PreTrainedTokenizerFast` while `meta-llama/Llama-2-7b-hf` uses `LlamaTokenizerFast`.
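Because support is determined by the underlying tokenizer class, you can also run this check programmatically (a small sketch using the `is_tokenizer_supported` helper from `transformers_cfg/tokenization/utils.py`, which accepts either a tokenizer instance or a model name):

```python
from transformers import AutoTokenizer
from transformers_cfg.tokenization.utils import is_tokenizer_supported

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# True once the tokenizer's real class is listed in SUPPORTED_TOKENIZERS
print(is_tokenizer_supported(tokenizer))

# passing the model name works too; the helper loads the tokenizer itself
print(is_tokenizer_supported("meta-llama/Meta-Llama-3-8B"))
```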
59 | 60 | 61 | 62 | ### Step 3: See if the tokenizer class is already supported 63 | -------------------------------------------------------------------------------- /examples/generate_json.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 4 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | print(f"Using device: {device}") 11 | 12 | model_id = "mistralai/Mistral-7B-v0.1" 13 | 14 | # Load model and tokenizer 15 | tokenizer = AutoTokenizer.from_pretrained(model_id) 16 | tokenizer.pad_token = tokenizer.eos_token 17 | 18 | model = AutoModelForCausalLM.from_pretrained(model_id).to( 19 | device 20 | ) # Load model to defined device 21 | model.generation_config.pad_token_id = model.generation_config.eos_token_id 22 | 23 | # Load grammar 24 | with open("examples/grammars/json.ebnf", "r") as file: 25 | grammar_str = file.read() 26 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 27 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 28 | 29 | # Generate 30 | prefix1 = "This is a valid json string for http request:" 31 | prefix2 = "This is a valid json string for shopping cart:" 32 | input_ids = tokenizer( 33 | [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True 34 | )["input_ids"].to( 35 | device 36 | ) # Move input_ids to the same device as model 37 | 38 | output = model.generate( 39 | input_ids, 40 | do_sample=False, 41 | max_new_tokens=60, 42 | logits_processor=[grammar_processor], 43 | repetition_penalty=1.1, 44 | num_return_sequences=1, 45 | ) 46 | # decode output 47 | generations = tokenizer.batch_decode(output, skip_special_tokens=True) 48 | print(generations) 49 | 50 | """ 51 | 'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }} 52 | 'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 } 53 | """ 54 | -------------------------------------------------------------------------------- /examples/grammars/SMILES/generic.ebnf: -------------------------------------------------------------------------------- 1 | root ::= smiles 2 | 3 | smiles ::= atom ( chain | branch )* 4 | 5 | chain ::= (dot atom | bond? ( atom | ring_closure ) )+ 6 | 7 | branch ::= "(" ( ( dot | bond )? smiles )+ ")" 8 | 9 | atom ::= organic_symbol | aromatic_symbol | atom_spec | wildcard 10 | 11 | bond ::= "-" | "=" | "#" | "$" | ":" | "/" | "\\" 12 | 13 | dot ::= "." 14 | 15 | wildcard ::= "*" 16 | 17 | atom_spec ::= "[" isotope? ( "se" | "as" | aromatic_symbol | element_symbol | wildcard ) chiral_class? h_count? ( charge | class? ) "]" 18 | 19 | organic_symbol ::= "B" | "C" | "N" | "O" | "P" | "S" | "F" | "I" | "Br" | "Cl" | "At" | "Ts" 20 | 21 | aromatic_symbol ::= "b" | "c" | "n" | "o" | "p" | "s" 22 | 23 | element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | 24 | "B" ( "a" | "e" | "h" | "i" | "k" | "r" )? | 25 | "C" ( "a" | "d" | "e" | "f" | "l" | "m" | "n" | "o" | "r" | "s" | "u" )? | 26 | "D" ( "b" | "s" | "y" ) | 27 | "E" ( "r" | "s" | "u" ) | 28 | "F" ( "e" | "l" | "m" | "r" )?
| 29 | "G" ( "a" | "d" | "e" ) | 30 | "H" ( "e" | "f" | "g" | "o" | "s" )? | 31 | "I" ( "n" | "r" )? | 32 | "K" "r"? | 33 | "L" ( "a" | "i" | "r" | "u" | "v" ) | 34 | "M" ( "c" | "g" | "n" | "o" | "t" ) | 35 | "N" ( "a" | "b" | "d" | "e" | "h" | "i" | "o" | "p" )? | 36 | "O" ( "g" | "s" )? | 37 | "P" ( "a" | "b" | "d" | "m" | "o" | "r" | "t" | "u" )? | 38 | "R" ( "a" | "b" | "e" | "f" | "g" | "h" | "n" | "u" ) | 39 | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | 40 | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | 41 | "U" | "V" | "W" | "Xe" | "Y" "b"? | 42 | "Z" ( "n" | "r" ) 43 | 44 | ring_closure ::= "%" [1-9] [0-9] | [0-9] 45 | 46 | chiral_class ::= ( "@" ( "@" | "TH" [1-2] | "AL" [1-2] | "SP" [1-3] | "TB" ( "1" [0-9]? | "2" "0"? | [3-9] ) | "OH" ( "1" [0-9]? | "2" [0-9]? | "3" "0"? | [4-9] ) )? )? 47 | 48 | charge ::= "-" ( "-" | "0" | "1" [0-5]? | [2-9] )? | "+" ( "+" | "0" | "1" [0-5]? | [2-9] )? 49 | 50 | h_count ::= "H" [0-9]? 51 | 52 | class ::= ":" [0-9]+ 53 | 54 | isotope ::= [1-9] [0-9]? [0-9]? 55 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_decode_utf8.py: -------------------------------------------------------------------------------- 1 | from transformers_cfg.utf8_utils import ( 2 | decode_utf8, 3 | PartialUTF8, 4 | ) # Make sure to import your function and class 5 | 6 | 7 | def test_decode_single_character(): 8 | """Test decoding a single character.""" 9 | utf8_bytes = b"\xe2\x82\xac" # Euro sign 10 | expected_code_points = [ 11 | 8364, 12 | ] # Euro sign code point followed by terminating 0 13 | result, _ = decode_utf8(utf8_bytes, PartialUTF8()) 14 | assert result == expected_code_points 15 | 16 | 17 | def test_decode_multiple_characters(): 18 | """Test decoding a string with multiple UTF-8 characters.""" 19 | utf8_bytes = b"Hello, \xe2\x82\xac!" # "Hello, €!" 
20 | expected_code_points = [ 21 | 72, 22 | 101, 23 | 108, 24 | 108, 25 | 111, 26 | 44, 27 | 32, 28 | 8364, 29 | 33, 30 | ] 31 | result, _ = decode_utf8(utf8_bytes, PartialUTF8()) 32 | assert result == expected_code_points 33 | 34 | 35 | def test_handle_incomplete_sequence(): 36 | """Test handling of an incomplete UTF-8 sequence.""" 37 | utf8_bytes = b"\xe2" # Incomplete sequence for the Euro sign 38 | expected_code_points = [] # Expect no code points because the sequence is incomplete 39 | result, partial = decode_utf8(utf8_bytes, PartialUTF8()) 40 | assert result == expected_code_points 41 | starting_value = int.from_bytes(b"\xe2", "big") # 226 42 | offset_value = int.from_bytes(b"\xe0", "big") # 224 43 | assert partial.value == starting_value - offset_value # 226-224=2 44 | assert ( 45 | partial.n_remain == 2 46 | ) # Expect n_remain to be 2 because 2 more bytes are needed 47 | 48 | 49 | def test_continue_incomplete_sequence(): 50 | """Test continuation of decoding with a previously incomplete sequence.""" 51 | utf8_bytes = b"\x82\xac" # Continuation of the Euro sign 52 | partial_start = PartialUTF8( 53 | value=2, n_remain=2 54 | ) # Simulate a previous state expecting 2 more bytes 55 | expected_code_points = [ 56 | 8364, 57 | ] # Completed Euro sign code point 58 | result, _ = decode_utf8(utf8_bytes, partial_start) 59 | assert result == expected_code_points 60 | 61 | 62 | def test_empty_string(): 63 | """Test handling of an empty string.""" 64 | utf8_bytes = b"" 65 | expected_code_points = [] 66 | result, _ = decode_utf8(utf8_bytes, PartialUTF8()) 67 | assert result == expected_code_points 68 | -------------------------------------------------------------------------------- /examples/grammars/custom_json_grammars/README.md: -------------------------------------------------------------------------------- 1 | # Custom json grammars 2 | 3 | You can use custom grammars to constrain the output of a language model to generate valid json objects. This is useful when you want to generate json objects for specific applications, such as http requests or shopping carts. 4 | 5 | ## Quickstart 6 | 7 | There are multiple ways to represent json schemas. 8 | We provide recommendations on how to do this for two common formats: Typescript and json. 9 | 10 |
11 | **Example of a Typescript schema for a Student object** 12 | 13 | ```Typescript 14 | interface Student{ 15 | name: string; 16 | age: number; 17 | is_student : boolean; 18 | courses: string[]; 19 | } 20 | ``` 21 |
22 | 23 |
24 | **Example of a json schema for a Student object** 25 | 26 | ```json 27 | { 28 | "type": "object", 29 | "properties": { 30 | "name": {"type": "string"}, 31 | "age": {"type": "number"}, 32 | "is_student": {"type": "boolean"}, 33 | "courses": { 34 | "type": "array", 35 | "items": { "type": "string"} 36 | } 37 | } 38 | } 39 | ``` 40 |
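Whichever of the two routes below you take, the generated `.ebnf` grammar is consumed the same way. Here is a minimal usage sketch, assuming the grammar for the Student schema above has been written to `grammars/student.ebnf` with a `root` start rule (the model and prompt are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Load the grammar generated from the Student schema
with open("grammars/student.ebnf", "r") as file:
    grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)

input_ids = tokenizer(["A student record in json: "], return_tensors="pt")["input_ids"]
output = model.generate(
    input_ids,
    max_new_tokens=60,
    logits_processor=[GrammarConstrainedLogitsProcessor(grammar)],
)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```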
41 | 42 | 43 | ### From Typescript 44 | 45 | To generate custom json grammars from Typescript schemas, you can use [this online tool](https://grammar.intrinsiclabs.ai/) or [this Typescript generator](https://github.com/IntrinsicLabsAI/gbnfgen) from Intrinsic AI. Then, simply copy-paste the resulting grammar into a text file and use it with the `IncrementalGrammarConstraint`. 46 | 47 | 48 | ### From json schemas 49 | 50 | Alternatively, you can generate custom json grammars from json-format schemas using the `json_schema_to_grammar.py` script, analogous to [the one in the llama.cpp repository](https://github.com/ggerganov/llama.cpp/blob/ab9a3240a9da941fdef5cd4a25f2b97c2f5a67aa/examples/json_schema_to_grammar.py). 51 | 52 | 53 | To generate a grammar from a json schema, run the following command: 54 | 55 | ```bash 56 | python3 json_schema_to_grammar.py -i schemas/product_catalog.json -o grammars/product_catalog.ebnf 57 | ``` 58 | This script generates a grammar from a json schema file (see examples of json schemas in `/schemas` and the corresponding grammars in `/grammars`). The generated grammar is in the Extended Backus-Naur Form (EBNF) format and can be directly used with the `IncrementalGrammarConstraint`. 59 | 60 | Additional arguments allow you to specify the property order of the json object as well as string formatting parameters. 61 | -------------------------------------------------------------------------------- /examples/pipeline_json.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import pipeline 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 6 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description="Generate json strings with a Hugging Face pipeline" 12 | ) 13 | parser.add_argument( 14 | "--model-id", 15 | type=str, 16 | default="/dlabdata1/llm_hub/Mistral-7B-v0.1", 17 | help="Model ID", 18 | ) 19 | parser.add_argument("--device", type=str, help="Device to put the model on") 20 | return parser.parse_args() 21 | 22 | 23 | def main(): 24 | args = parse_args() 25 | model_id = args.model_id 26 | 27 | # Detect if GPU is available, otherwise use CPU 28 | device = torch.device( 29 | args.device or ("cuda" if torch.cuda.is_available() else "cpu") 30 | ) 31 | print(f"Using device: {device}") 32 | 33 | # Load model and tokenizer 34 | tokenizer = AutoTokenizer.from_pretrained(model_id) 35 | tokenizer.pad_token = tokenizer.eos_token 36 | # Load model to defined device 37 | model = AutoModelForCausalLM.from_pretrained(model_id).to(device) 38 | 39 | # Load grammar 40 | with open("examples/grammars/json.ebnf", "r") as file: 41 | grammar_str = file.read() 42 | 43 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 44 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 45 | 46 | # Initialize pipeline 47 | pipe = pipeline( 48 | "text-generation", 49 | model=model, 50 | tokenizer=tokenizer, 51 | device_map="auto", 52 | max_length=50, 53 | batch_size=2, 54 | ) 55 | # outputs = pipe("This is a valid json string for http request:", do_sample=False, max_length=50) 56 | generations = pipe( 57 | [ 58 | "This is a valid json string for http request: ", 59 | "This is a valid json string for shopping cart: ", 60 | ], 61 | do_sample=False, 62 |
logits_processor=[grammar_processor], 63 | ) 64 | 65 | print(generations) 66 | 67 | """ 68 | This is a valid json string for http request: {"name":"John","age":30,"city":"New York"} 69 | This is a valid json string for shopping cart: {"items":[{"id":"1","quantity":"1"},{"id":"2","quantity":"2"}]} 70 | """ 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /docs/benchmarking.md: -------------------------------------------------------------------------------- 1 | # Benchmarking constrained generation overhead in transformers-CFG 2 | 3 | This document provides guidelines on benchmarking grammar-constrained decoding when working with the `transformers_cfg` library. 4 | 5 | ## Table of Contents 6 | 7 | - [Benchmarking ](#benchmarking-) 8 | - [Analyzing the results ](#analyzing-the-results-) 9 | 10 | 11 | ## Benchmarking 12 | 13 | To measure the overhead of grammar-constrained generation, one can use `transformers-CFG/examples/benchmarking/run_generation.sh`. 14 | 15 | It is designed to calculate per-token logits-processing latency for different grammars, with varying generation lengths and prompts. 16 | 17 | To run the benchmarking script, you can use the following command: 18 | 19 | ```bash 20 | ./run_generation.sh grammar_path prompts_path hg_model 21 | 22 | ``` 23 | 24 | Where the arguments are: 25 | - `grammar_path`: the path to the grammar file in .ebnf format. (see `examples/grammars`) 26 | - `prompts_path`: the path to the prompts file in .txt format. (see `examples/prompts`) 27 | - `hg_model`: the [Hugging Face Transformers](https://github.com/huggingface/transformers) model name or path to the model. (e.g. `openai-community/gpt2`) 28 | 29 | The output of the script will be saved in the `transformers_cfg/examples/benchmarking/logs` directory in .tsv format. 30 | 31 | The output contains the following columns: 32 | 33 | - `prompt`: the text of the prompt (see more on the benchmarking prompt design in the `examples/benchmarking/process_benchmarking_logs.ipynb`) 34 | - `n_tokens`: number of tokens generated (can be affected by the `max_new_tokens` parameter) 35 | - `run_id`: run id (each generation is performed 5 times per prompt to account for noise in the execution time measurement) 36 | - `total_time`: total overhead (depends on the complexity of the grammar, the model, the prompt and the device) 37 | - `time_per_token`: per-token overhead 38 | - `device`: device 39 | - `model_id`: the [Hugging Face Transformers](https://github.com/huggingface/transformers) model name or path to the model 40 | - `constrained_time`: total time of constrained generation (including forward passes) 41 | - `unconstrained_time`: total time of unconstrained generation (including forward passes) 42 | 43 | ## Analyzing the results 44 | 45 | To aggregate and visualize the results, you can use the `transformers_cfg/examples/benchmarking/process_benchmarking_logs.ipynb` notebook; a minimal standalone sketch of the same aggregation is shown below.
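For a quick look at the numbers outside the notebook, the same aggregation can be done in a few lines of pandas. This is a sketch, assuming the log directory and the .tsv column layout described above (the glob pattern is illustrative):

```python
import glob

import pandas as pd

# Load every benchmarking log produced by the script above
frames = [
    pd.read_csv(path, sep="\t")
    for path in glob.glob("examples/benchmarking/logs/*.tsv")
]
logs = pd.concat(frames, ignore_index=True)

# Average the per-token overhead over the 5 runs of each prompt
summary = logs.groupby(["model_id", "prompt"])["time_per_token"].mean()
print(summary)
```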
46 | 47 | The notebook will load the logs from the `transformers_cfg/examples/benchmarking/logs` directory and provide you with the following visualization: 48 | 49 | ![Json grammar benchmarking](assets/plots/benchmarking_results.png) 50 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_accept_unicode_bytes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from transformers_cfg.recognizer import StringRecognizer 3 | from transformers_cfg.parser import parse_ebnf 4 | 5 | 6 | def test_accept_japanese(): 7 | """ 8 | Test that we can accept japanese characters 9 | """ 10 | japanese = "こんにちは世界" 11 | with open("examples/grammars/japanese.ebnf", "r") as file: 12 | input_text = file.read() 13 | parsed_grammar = parse_ebnf(input_text) 14 | 15 | start_rule_id = parsed_grammar.symbol_table["root"] 16 | 17 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 18 | 19 | bytes_japanese = bytes(japanese, "utf-8") 20 | logging.debug(f"bytes_japanese: {bytes_japanese} of length {len(bytes_japanese)}") 21 | # こんにちは世界 22 | 23 | head_bytes = bytes_japanese[:8] 24 | parsing_state = recognizer._update_state_with_bytes(head_bytes) 25 | 26 | # non empty stack means that the bytes were accepted 27 | assert len(parsing_state.stacks) > 0 28 | 29 | 30 | def test_accept_japanese_progressive(): 31 | ####################### 32 | # Now consider the case of progressive matching 33 | ####################### 34 | 35 | japanese = "こんにちは世界" 36 | with open("examples/grammars/japanese.ebnf", "r") as file: 37 | input_text = file.read() 38 | parsed_grammar = parse_ebnf(input_text) 39 | 40 | start_rule_id = parsed_grammar.symbol_table["root"] 41 | 42 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 43 | 44 | bytes_japanese = bytes(japanese, "utf-8") 45 | logging.debug(f"bytes_japanese: {bytes_japanese} of length {len(bytes_japanese)}") 46 | 47 | byte_tokens = [bytes_japanese[i] for i in range(len(bytes_japanese))] 48 | # cast into bytes 49 | byte_tokens = [bytes([byte]) for byte in byte_tokens] 50 | 51 | parsing_state = recognizer.get_initial_parsing_state() 52 | 53 | for i, byte in enumerate(byte_tokens): 54 | parsing_state = recognizer._update_state_with_bytes(byte, parsing_state) 55 | assert len(parsing_state.stacks) > 0 56 | 57 | 58 | def test_accept_emoji(): 59 | """ 60 | Test that we can accept emoji 61 | """ 62 | emoji = "😀😄😂" 63 | with open("examples/grammars/emoji.ebnf", "r") as file: 64 | input_text = file.read() 65 | parsed_grammar = parse_ebnf(input_text) 66 | 67 | start_rule_id = parsed_grammar.symbol_table["root"] 68 | 69 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 70 | 71 | bytes_emoji = bytes(emoji, "utf-8") 72 | logging.debug(f"bytes_emoji: {bytes_emoji} of length {len(bytes_emoji)}") 73 | # 😀😄😂 74 | 75 | parsing_state = recognizer._update_state_with_bytes(bytes_emoji) 76 | # non empty stack means that the bytes were accepted 77 | assert len(parsing_state.stacks) > 0 78 | -------------------------------------------------------------------------------- /examples/grammars/SMILES/chain_extenders.ebnf: -------------------------------------------------------------------------------- 1 | root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? 
smiles)* 2 | 3 | group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" 4 | 5 | group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" 6 | 7 | group_bond ::= ( "-" | "\\" | "/" ) 8 | 9 | group_symbol_left ::= "OC" | "N" 10 | 11 | group_symbol_right ::= "CO" 12 | 13 | smiles ::= atom ( chain | branch )* 14 | 15 | chain ::= (dot atom | bond? ( atom | ring_closure ) )+ 16 | 17 | branch ::= "(" ( ( dot | bond )? smiles )+ ")" 18 | 19 | 20 | atom ::= organic_symbol | aromatic_symbol | atom_spec | wildcard | group_symbol_left 21 | 22 | bond ::= "-" | "=" | "#" | "$" | ":" | "/" | "\\" 23 | 24 | dot ::= "." 25 | 26 | atom_spec ::= "[" isotope? ( "se" | "as" | aromatic_symbol | element_symbol | wildcard ) chiral_class? h_count? ( charge? | class? ) "]" 27 | 28 | organic_symbol ::= "Br" | "Cl" | "N" | "O" | "P" | "S" | "F" | "I" | "B" | "C" 29 | 30 | aromatic_symbol ::= "b" | "c" | "n" | "o" | "p" | "s" 31 | 32 | wildcard ::= "*" 33 | 34 | element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | 35 | "B" ( "a" | "e" | "h" | "i" | "k" | "r" )? | 36 | "C" ( "a" | "d" | "e" | "f" | "l" | "m" | "n" | "o" | "r" | "s" | "u" )? | 37 | "D" ( "b" | "s" | "y" ) | 38 | "E" ( "r" | "s" | "u" ) | 39 | "F" ( "e" | "l" | "m" | "r" )? | 40 | "G" ( "a" | "d" | "e" ) | 41 | "H" ( "e" | "f" | "g" | "o" | "s" )? | 42 | "I" ( "n" | "r" )? | 43 | "K" "r"? | 44 | "L" ( "a" | "i" | "r" | "u" | "v" ) | 45 | "M" ( "c" | "g" | "n" | "o" | "t" ) | 46 | "N" ( "a" | "b" | "d" | "e" | "h" | "i" | "o" | "p" )? | 47 | "O" ( "g" | "s" )? | 48 | "P" ( "a" | "b" | "d" | "m" | "o" | "r" | "t" | "u" )? | 49 | "R" ( "a" | "b" | "e" | "f" | "g" | "h" | "n" | "u" ) | 50 | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | 51 | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | 52 | "U" | "V" | "W" | "Xe" | "Y" "b"? | 53 | "Z" ( "n" | "r" ) 54 | 55 | 56 | ring_closure ::= "%" [1-9] [0-9] | [0-9] 57 | 58 | chiral_class ::= ( "@" ( "@" | "TH" [1-2] | "AL" [1-2] | "SP" [1-3] | "TB" ( "1" [0-9]? | "2" "0"? | [3-9] ) | "OH" ( "1" [0-9]? | "2" [0-9]? | "3" "0"? | [4-9] ) )? )? 59 | 60 | charge ::= "-" ( "-" | "0" | "1" [0-5]? | [2-9] )? | "+" ( "+" | "0" | "1" [0-5]? | [2-9] )? 61 | 62 | h_count ::= "H" [0-9]? 63 | 64 | class ::= ":" [0-9]+ 65 | 66 | isotope ::= [1-9] [0-9]? [0-9]? 67 | -------------------------------------------------------------------------------- /examples/grammars/SMILES/isocyanates.ebnf: -------------------------------------------------------------------------------- 1 | root ::= ( group_symbol_left group_bond? | (smiles bond?)* group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right (bond? smiles)* | group_bond? group_symbol_right ) 2 | 3 | group_radical_left ::= "(" ( group_symbol_left group_bond? smiles* )+ ")" 4 | 5 | group_radical_right ::= "(" ( smiles* group_bond? group_symbol_right )+ ")" 6 | 7 | group_bond ::= ( "-" | "\\" | "/" ) 8 | 9 | group_symbol_left ::= "O=C=N" 10 | 11 | group_symbol_right ::= "N=C=O" 12 | 13 | smiles ::= atom ( chain | branch )* 14 | 15 | chain ::= (dot atom | bond? ( atom | ring_closure ) )+ 16 | 17 | branch ::= "(" ( ( dot | bond )? smiles )+ ")" 18 | 19 | 20 | atom ::= organic_symbol | aromatic_symbol | atom_spec | wildcard | group_symbol_left 21 | 22 | bond ::= "-" | "=" | "#" | "$" | ":" | "/" | "\\" 23 | 24 | dot ::= "." 25 | 26 | atom_spec ::= "[" isotope? ( "se" | "as" | aromatic_symbol | element_symbol | wildcard ) chiral_class? h_count? ( charge? | class? 
) "]" 27 | 28 | organic_symbol ::= "Br" | "Cl" | "N" | "O" | "P" | "S" | "F" | "I" | "B" | "C" 29 | 30 | aromatic_symbol ::= "b" | "c" | "n" | "o" | "p" | "s" 31 | 32 | wildcard ::= "*" 33 | 34 | element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | 35 | "B" ( "a" | "e" | "h" | "i" | "k" | "r" )? | 36 | "C" ( "a" | "d" | "e" | "f" | "l" | "m" | "n" | "o" | "r" | "s" | "u" )? | 37 | "D" ( "b" | "s" | "y" ) | 38 | "E" ( "r" | "s" | "u" ) | 39 | "F" ( "e" | "l" | "m" | "r" )? | 40 | "G" ( "a" | "d" | "e" ) | 41 | "H" ( "e" | "f" | "g" | "o" | "s" )? | 42 | "I" ( "n" | "r" )? | 43 | "K" "r"? | 44 | "L" ( "a" | "i" | "r" | "u" | "v" ) | 45 | "M" ( "c" | "g" | "n" | "o" | "t" ) | 46 | "N" ( "a" | "b" | "d" | "e" | "h" | "i" | "o" | "p" )? | 47 | "O" ( "g" | "s" )? | 48 | "P" ( "a" | "b" | "d" | "m" | "o" | "r" | "t" | "u" )? | 49 | "R" ( "a" | "b" | "e" | "f" | "g" | "h" | "n" | "u" ) | 50 | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | 51 | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | 52 | "U" | "V" | "W" | "Xe" | "Y" "b"? | 53 | "Z" ( "n" | "r" ) 54 | 55 | 56 | ring_closure ::= "%" [1-9] [0-9] | [0-9] 57 | 58 | chiral_class ::= ( "@" ( "@" | "TH" [1-2] | "AL" [1-2] | "SP" [1-3] | "TB" ( "1" [0-9]? | "2" "0"? | [3-9] ) | "OH" ( "1" [0-9]? | "2" [0-9]? | "3" "0"? | [4-9] ) )? )? 59 | 60 | charge ::= "-" ( "-" | "0" | "1" [0-5]? | [2-9] )? | "+" ( "+" | "0" | "1" [0-5]? | [2-9] )? 61 | 62 | h_count ::= "H" [0-9]? 63 | 64 | class ::= ":" [0-9]+ 65 | 66 | isotope ::= [1-9] [0-9]? [0-9]? 67 | -------------------------------------------------------------------------------- /docs/supported_models.yaml: -------------------------------------------------------------------------------- 1 | transformers.models.codegen.tokenization_codegen_fast.CodeGenTokenizerFast: 2 | - microsoft/phi-1_5 3 | transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast: 4 | - gpt2 5 | - distilgpt2 6 | - facebook/opt-125m 7 | - openai-community/gpt2-large 8 | - openai-community/gpt2-xl 9 | - openai-community/gpt2-medium 10 | - EleutherAI/gpt-neo-125m 11 | - microsoft/DialoGPT-medium 12 | - nferruz/ProtGPT2 13 | - sshleifer/tiny-gpt2 14 | - facebook/opt-2.7b 15 | - tiiuae/falcon-rw-1b 16 | - EleutherAI/gpt-neo-1.3B 17 | - EleutherAI/gpt-j-6b 18 | - microsoft/DialoGPT-large 19 | - facebook/opt-350m 20 | - EleutherAI/gpt-neo-2.7B 21 | - togethercomputer/GPT-JT-6B-v1 22 | - facebook/opt-1.3b 23 | - facebook/opt-13b 24 | - KoboldAI/GPT-J-6B-Janeway 25 | - KoboldAI/OPT-6B-nerys-v2 26 | - succinctly/text2image-prompt-generator 27 | - KoboldAI/OPT-13B-Erebus 28 | - KoboldAI/OPT-6.7B-Erebus 29 | - ai-forever/mGPT 30 | transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast: 31 | - mistralai/Mistral-7B-v0.1 32 | - teknium/OpenHermes-2-Mistral-7B 33 | - davidkim205/komt-mistral-7b-v1 34 | - h2oai/h2ogpt-4096-llama2-7b-chat 35 | - HuggingFaceM4/tiny-random-LlamaForCausalLM 36 | - lmsys/vicuna-7b-v1.5 37 | - TinyPixel/Llama-2-7B-bf16-sharded 38 | - OpenAssistant/llama2-13b-orca-8k-3319 39 | - mistralai/Mistral-7B-Instruct-v0.2 40 | - mistralai/Mistral-7B-Instruct-v0.1 41 | - NousResearch/Llama-2-7b-chat-hf 42 | - NousResearch/Nous-Hermes-Llama2-13b 43 | - HuggingFaceH4/zephyr-7b-beta 44 | - TheBloke/Llama-2-13B-chat-GPTQ 45 | - NousResearch/Llama-2-7b-hf 46 | - echarlaix/tiny-random-mistral 47 | - NousResearch/Yarn-Mistral-7b-128k 48 | - fxmarty/tiny-llama-fast-tokenizer 49 | - liuhaotian/llava-v1.5-7b 50 | - Gryphe/MythoMax-L2-13b 51 | - TheBloke/Llama-2-7B-Chat-GPTQ 52 
| - Open-Orca/Mistral-7B-OpenOrca 53 | - cognitivecomputations/dolphin-2.2.1-mistral-7b 54 | - DiscoResearch/mixtral-7b-8expert 55 | - lmsys/vicuna-13b-v1.5 56 | - huggyllama/llama-7b 57 | - TinyPixel/CodeLlama-7B-Python-bf16-sharded 58 | - HuggingFaceH4/zephyr-7b-alpha 59 | - Riiid/sheep-duck-llama-2 60 | - liuhaotian/llava-v1.5-13b 61 | - togethercomputer/LLaMA-2-7B-32K 62 | - teknium/OpenHermes-2.5-Mistral-7B 63 | - JackFram/llama-68m 64 | - openlm-research/open_llama_7b_v2 65 | - TinyLlama/TinyLlama-1.1B-Chat-v0.3 66 | - NousResearch/Nous-Hermes-llama-2-7b 67 | - TheBloke/Llama-2-7B-Chat-AWQ 68 | - garage-bAInd/Platypus2-7B 69 | - h2oai/h2ogpt-4096-llama2-13b-chat 70 | - TinyLlama/TinyLlama-1.1B-intermediate-step-715k-1.5T 71 | - togethercomputer/Llama-2-7B-32K-Instruct 72 | - TheBloke/Llama-2-7B-GPTQ 73 | - LeoLM/leo-hessianai-7b 74 | transformers.models.t5.tokenization_t5_fast.T5TokenizerFast: 75 | - Vamsi/T5_Paraphrase_Paws 76 | transformers.tokenization_utils_fast.PreTrainedTokenizerFast: 77 | - tiiuae/falcon-40b-instruct 78 | - tiiuae/falcon-7b-instruct 79 | - tiiuae/falcon-7b 80 | - fxmarty/really-tiny-falcon-testing 81 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_json.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers_cfg.parser import parse_ebnf 3 | from transformers_cfg.recognizer import StringRecognizer 4 | from tests.json_utils import is_json_parsable 5 | 6 | json_examples = { 7 | # Simple Nested Object 8 | "simple_nested": '{"name": "John", "age": 30, "address": {"street": "21 2nd Street", "city": "New York"}}', 9 | # Array of Objects 10 | "array_of_objects": '{"employees": [{"firstName": "John", "lastName": "Doe"}, {"firstName": "Anna", "lastName": "Smith"}]}', 11 | # Nested Arrays and Objects 12 | "nested_arrays_objects": '{"company": "OpenAI", "departments": [{"name": "Research", "members": [{"name": "Alice"}, {"name": "Bob"}]}, {"name": "Engineering", "members": [{"name": "Charlie"}]}]}', 13 | # Mixed Data Types 14 | "mixed_types": '{"name": "Alice", "age": 25, "isEmployee": true, "salary": null, "projects": ["NLP", "AI"]}', 15 | # Empty Object 16 | "empty_object": "{}", 17 | # Deeply Nested Object 18 | "deeply_nested": '{"level1": {"level2": {"level3": {"level4": {"message": "Deep"}}}}}', 19 | # Object with Numbers and Booleans 20 | "numbers_booleans": '{"temperature": 22.5, "isActive": false, "count": 10}', 21 | # Object with Array of Mixed Types 22 | "array_mixed_types": '{"data": [1, "two", true, null, {"nested": "object"}]}', 23 | # Complex Object with All Elements 24 | "complex_all_elements": '{"id": 101, "isActive": true, "info": {"name": "John Doe", "emails": ["john@example.com", "doe@example.com"], "address": {"city": "New York", "zip": "10001"}}, "tags": ["admin", "user"], "history": [{"login": "2023-01-01", "duration": 3600}, {"login": "2023-01-02", "duration": 2700}]}', 25 | # Object with Special Characters in Strings TODO fails 26 | # "escape_characters": '{"greeting": "Hello, \\"World\\"!", "path": "C:\\\\Program Files\\\\Test"}', 27 | } 28 | 29 | 30 | @pytest.fixture(scope="module") 31 | def recognizer(): 32 | with open("examples/grammars/json.ebnf", "r") as file: 33 | input_text = file.read() 34 | parsed_grammar = parse_ebnf(input_text) 35 | start_rule_id = parsed_grammar.symbol_table["root"] 36 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 37 | return recognizer 38 | 39 | 40 | 
def test_minimal_json_object(recognizer): 41 | """ 42 | Test that we can load a JSON object 43 | """ 44 | json = '{"foo": "bar", "baz": "bat"}' 45 | 46 | assert is_json_parsable(json) == recognizer._accept_prefix(json) 47 | assert is_json_parsable(json) == recognizer._accept_string(json) 48 | 49 | prefix_json = json[: len(json) // 2] 50 | assert recognizer._accept_prefix(prefix_json) 51 | assert not recognizer._accept_string(prefix_json) 52 | 53 | 54 | def test_systematic_examples(recognizer): 55 | for name, json_object in json_examples.items(): 56 | assert is_json_parsable(json_object) == recognizer._accept_prefix( 57 | json_object 58 | ), f"Failed on {name}, {json_object}" 59 | -------------------------------------------------------------------------------- /examples/grammars/SMILES/acrylates.ebnf: -------------------------------------------------------------------------------- 1 | root ::= (smiles bond?)* ( group_symbol_left group_bond? | group_radical_left bond? ) smiles+ | smiles+ ( bond? group_radical_right | group_bond? group_symbol_right) (bond? smiles)* 2 | 3 | group_radical_left ::= "(" ( group_symbol_left (group_bond smiles+)? )+ ")" 4 | 5 | group_radical_right ::= "(" ( (smiles+ group_bond )? group_symbol_right )+ ")" 6 | 7 | group_bond ::= ( "-" | "\\" | "/" ) 8 | 9 | group_symbol_left ::= "C=CC(=O)O" | "C=CC(O)=O" | "C(=C)C(=O)O" | "C(=C)C(O)=O" | "CC(=C)(=O)O" | "CC(=C)(O)=O" 10 | 11 | group_symbol_right ::= "OC(=O)C=C" | "O=C(O)C=C" | "OC(=O)C(=C)" | "O=C(O)C(=C)" | "O(O=)(C=)CC" | "O=(O)(C=)CC" 12 | 13 | smiles ::= atom ( chain | branch )* 14 | 15 | chain ::= (dot atom | bond? ( atom | ring_closure ) )+ 16 | 17 | branch ::= "(" ( ( dot | bond )? smiles )+ ")" 18 | 19 | atom ::= organic_symbol | aromatic_symbol | atom_spec | wildcard | group_symbol_left | group_symbol_right 20 | 21 | bond ::= "-" | "=" | "#" | "$" | ":" | "/" | "\\" 22 | 23 | dot ::= "." 24 | 25 | atom_spec ::= "[" isotope? ( "se" | "as" | aromatic_symbol | element_symbol | wildcard ) chiral_class? h_count? ( charge? | class? ) "]" 26 | 27 | organic_symbol ::= "Br" | "Cl" | "N" | "O" | "P" | "S" | "F" | "I" | "B" | "C" 28 | 29 | aromatic_symbol ::= "b" | "c" | "n" | "o" | "p" | "s" 30 | 31 | wildcard ::= "*" 32 | 33 | element_symbol ::= "A" ( "c" | "g" | "l" | "m" | "r" | "s" | "t" | "u" ) | 34 | "B" ( "a" | "e" | "h" | "i" | "k" | "r" )? | 35 | "C" ( "a" | "d" | "e" | "f" | "l" | "m" | "n" | "o" | "r" | "s" | "u" )? | 36 | "D" ( "b" | "s" | "y" ) | 37 | "E" ( "r" | "s" | "u" ) | 38 | "F" ( "e" | "l" | "m" | "r" )? | 39 | "G" ( "a" | "d" | "e" ) | 40 | "H" ( "e" | "f" | "g" | "o" | "s" )? | 41 | "I" ( "n" | "r" )? | 42 | "K" "r"? | 43 | "L" ( "a" | "i" | "r" | "u" | "v" ) | 44 | "M" ( "c" | "g" | "n" | "o" | "t" ) | 45 | "N" ( "a" | "b" | "d" | "e" | "h" | "i" | "o" | "p" )? | 46 | "O" ( "g" | "s" )? | 47 | "P" ( "a" | "b" | "d" | "m" | "o" | "r" | "t" | "u" )? | 48 | "R" ( "a" | "b" | "e" | "f" | "g" | "h" | "n" | "u" ) | 49 | "S" ( "b" | "c" | "e" | "g" | "i" | "m" | "n" | "r" )? | 50 | "T" ( "a" | "b" | "c" | "e" | "h" | "i" | "l" | "m" | "s" ) | 51 | "U" | "V" | "W" | "Xe" | "Y" "b"? | 52 | "Z" ( "n" | "r" ) 53 | 54 | 55 | ring_closure ::= "%" [1-9] [0-9] | [0-9] 56 | 57 | chiral_class ::= ( "@" ( "@" | "TH" [1-2] | "AL" [1-2] | "SP" [1-3] | "TB" ( "1" [0-9]? | "2" "0"? | [3-9] ) | "OH" ( "1" [0-9]? | "2" [0-9]? | "3" "0"? | [4-9] ) )? )? 58 | 59 | charge ::= "-" ( "-" | "0" | "1" [0-5]? | [2-9] )? | "+" ( "+" | "0" | "1" [0-5]? | [2-9] )? 60 | 61 | h_count ::= "H" [0-9]? 
62 | 63 | class ::= ":" [0-9]+ 64 | 65 | isotope ::= [1-9] [0-9]? [0-9]? 66 | -------------------------------------------------------------------------------- /examples/metrics/run_constrained_decoding_metric.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | 3 | from transformers_cfg.generation import GrammarConstrainedLogitsProcessor 4 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 5 | from transformers_cfg.metrics import ConstrainedDecodingMetric 6 | from transformers_cfg.metrics.metrics import ConstrainedDecodingMetricOutput 7 | 8 | if __name__ == "__main__": 9 | metric = ConstrainedDecodingMetric() 10 | 11 | model_id = "gpt2" 12 | 13 | # Load model and tokenizer 14 | tokenizer = AutoTokenizer.from_pretrained(model_id) 15 | tokenizer.pad_token = tokenizer.eos_token 16 | model = AutoModelForCausalLM.from_pretrained(model_id) 17 | model.generation_config.pad_token_id = model.generation_config.eos_token_id 18 | 19 | # Load json grammar 20 | with open("examples/grammars/json.ebnf", "r") as file: 21 | grammar_str = file.read() 22 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 23 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 24 | 25 | # Generate 26 | input_ids = tokenizer( 27 | [ 28 | "This is a valid json string for http request:", 29 | "This is a valid json string for shopping cart:", 30 | ], 31 | add_special_tokens=False, 32 | return_tensors="pt", 33 | padding=True, 34 | )["input_ids"] 35 | output = model.generate( 36 | input_ids, 37 | max_length=30, 38 | logits_processor=[grammar_processor], 39 | repetition_penalty=1, 40 | num_return_sequences=1, 41 | return_dict_in_generate=True, 42 | output_scores=True, 43 | output_logits=True, 44 | ) 45 | 46 | # decode output 47 | generations = tokenizer.batch_decode(output["sequences"], skip_special_tokens=True) 48 | print(generations) 49 | 50 | result: ConstrainedDecodingMetricOutput = metric.compute_from_model_output(output) 51 | print("Original token probabilities:") 52 | print(result.df["original_token_probs"].head()) 53 | 54 | print("Renormalised token probabilities:") 55 | print(result.df["renormalised_token_probs"].head()) 56 | 57 | print("Total rejection probability gain:") 58 | print(result.df["total_rejection_prob_gain"].head()) 59 | 60 | """ 61 | Original token probabilities: 62 | Batch 1 Batch 2 63 | Step 1 0.002313 0.002031 64 | Step 2 0.250744 0.165687 65 | Step 3 0.104876 0.090784 66 | Step 4 0.419352 0.691826 67 | Step 5 0.203741 0.875446 68 | Renormalised token probabilities: 69 | Batch 1 Batch 2 70 | Step 1 0.620275 0.690802 71 | Step 2 0.762278 0.493411 72 | Step 3 0.105184 0.091143 73 | Step 4 0.428187 0.741440 74 | Step 5 0.205163 0.902042 75 | Total rejection probability gain: 76 | Batch 1 Batch 2 77 | Step 1 0.996271 0.997059 78 | Step 2 0.671060 0.664201 79 | Step 3 0.002922 0.003944 80 | Step 4 0.020609 0.066890 81 | Step 5 0.006915 0.029484 82 | 83 | """ 84 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_geo_query.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers_cfg.parser import parse_ebnf 3 | from transformers_cfg.recognizer import StringRecognizer 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class GeoQueryTestCase: 9 | name: str 10 | geo_query: str 11 | 12 | 13 | valid_geo_query_sentences = 
[ 14 | GeoQueryTestCase("simple_answer", "answer(smallest(state(all)))"), 15 | GeoQueryTestCase("state_id", "answer(highest(place(loc_2(stateid('hawaii')))))"), 16 | GeoQueryTestCase("river", "answer(river(all))"), 17 | GeoQueryTestCase("state", "answer(loc_1(major(river(all))))"), 18 | GeoQueryTestCase("next_to_2", "answer(state(next_to_2(stateid('texas'))))"), 19 | GeoQueryTestCase( 20 | "intersection", 21 | "answer(intersection(state(next_to_2(stateid('texas'))), loc_1(major(river(all)))))", 22 | ), 23 | GeoQueryTestCase("space_in_name", "answer(population_1(stateid('new york')))"), 24 | GeoQueryTestCase( 25 | "exclude", 26 | "answer(count(exclude(river(all), traverse_2(state(loc_1(capital(cityid('albany', _))))))))", 27 | ), 28 | GeoQueryTestCase( 29 | "city_id_with_state", "answer(population_1(cityid('washington', 'dc')))" 30 | ), 31 | ] 32 | 33 | valid_geo_query_prefixes = [ 34 | GeoQueryTestCase("empty_string", ""), 35 | GeoQueryTestCase("unbalanced_parentheses", "answer(count(major(city(all"), 36 | ] 37 | 38 | invalid_geo_query_sentences = [ 39 | GeoQueryTestCase("no_answer", "highest(place(loc_2(stateid('kansas'))))"), 40 | GeoQueryTestCase("fake_country", "answer(major(city(loc_2(countryid('xx')))))"), 41 | GeoQueryTestCase("nonexistent_function", "answer(population_2(stateid('hawaii')))"), 42 | GeoQueryTestCase("empty_operator", "answer(highest(place(loc_2())))"), 43 | GeoQueryTestCase("empty_parentheses", "()"), 44 | GeoQueryTestCase( 45 | "missing_argument", "answer(intersection(state(next_to_2(stateid('texas'))), )" 46 | ), 47 | ] 48 | 49 | 50 | @pytest.fixture(scope="module") 51 | def recognizer(): 52 | with open("examples/grammars/geo_query.ebnf", "r") as file: 53 | input_text = file.read() 54 | parsed_grammar = parse_ebnf(input_text) 55 | start_rule_id = parsed_grammar.symbol_table["root"] 56 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 57 | print("SetUp successful!", flush=True) 58 | return recognizer 59 | 60 | 61 | def test_valid_sentence(recognizer): 62 | for geo_query_test_case in valid_geo_query_sentences: 63 | assert recognizer._accept_string( 64 | geo_query_test_case.geo_query 65 | ), f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}" 66 | 67 | for geo_query_test_case in valid_geo_query_prefixes + invalid_geo_query_sentences: 68 | assert not recognizer._accept_string( 69 | geo_query_test_case.geo_query 70 | ), f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}" 71 | 72 | 73 | def test_valid_prefixes(recognizer): 74 | for geo_query_test_case in valid_geo_query_sentences + valid_geo_query_prefixes: 75 | assert recognizer._accept_prefix( 76 | geo_query_test_case.geo_query 77 | ), f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}" 78 | 79 | for geo_query_test_case in invalid_geo_query_sentences: 80 | assert not recognizer._accept_prefix( 81 | geo_query_test_case.geo_query 82 | ), f"Failed on {geo_query_test_case.name}, {geo_query_test_case.geo_query}" 83 | -------------------------------------------------------------------------------- /transformers_cfg/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | from termcolor import colored 5 | 6 | 7 | def ints2bytes(sequence: List[int]) -> bytes: 8 | # check in the range of 0-255 9 | for item in sequence: 10 | if not 0 <= item <= 255: 11 | raise ValueError(f"item: {item} is not in the range [0, 255]") 12
| return bytes(sequence) 13 | 14 | 15 | def bytes2ints(byte_sequence: bytes) -> List[int]: 16 | return list(byte_sequence) 17 | 18 | 19 | def intervals_intersect(low1, high1, low2, high2): 20 | """ 21 | Check if two intervals [low1, high1] and [low2, high2] intersect. 22 | 23 | :param low1: Low bound of the first interval. 24 | :param high1: High bound of the first interval. 25 | :param low2: Low bound of the second interval. 26 | :param high2: High bound of the second interval. 27 | :return: True if the intervals intersect, False otherwise. 28 | """ 29 | # Check if one interval is completely to the right of the other 30 | if low1 > high2 or low2 > high1: 31 | return False 32 | 33 | # If the above condition is not met, the intervals intersect 34 | return True 35 | 36 | 37 | def pprint_token_ids(tokenizer, token_ids=None, text=None): 38 | if token_ids is None and text is None: 39 | raise ValueError("Either token_ids or text should be provided") 40 | if token_ids is None: 41 | token_ids = tokenizer.encode(text, add_special_tokens=False) 42 | special_token_ids = tokenizer.all_special_ids 43 | special_tokens = tokenizer.all_special_tokens 44 | special_id2token = { 45 | id: token for id, token in zip(special_token_ids, special_tokens) 46 | } 47 | # loop over token_ids and color the special tokens 48 | colored_token_ids = [] 49 | 50 | for token_id in token_ids: 51 | if token_id in special_id2token: 52 | colored_token_ids.append(colored(token_id, "red", attrs=["bold"])) 53 | else: 54 | colored_token_ids.append(str(token_id)) 55 | colored_token_ids_str = [str(item) for item in colored_token_ids] 56 | print("[" + ", ".join(colored_token_ids_str) + "]") 57 | 58 | 59 | def get_tokenizer_model_type(model: str = "gpt2"): 60 | """ 61 | reference https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_fast.py#L729 62 | :param model: 63 | :return: BPE, Unigram, WordPiece, WordLevel 64 | SentencePiece is used in conjunction with Unigram 65 | """ 66 | from transformers import AutoTokenizer 67 | 68 | # if the tokenizer is not in the repo, it will raise OSError 69 | # OSError: Can't load tokenizer for 'xxx' 70 | # This happens when the model reuses the tokenizer of another model 71 | if isinstance(model, str): 72 | try: 73 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) 74 | # check if the tokenizer is fast 75 | except OSError: 76 | return None 77 | else: 78 | tokenizer = model 79 | 80 | if not tokenizer.is_fast: 81 | raise ValueError(f"The tokenizer {model} is not a fast tokenizer") 82 | tokenizer_json = json.loads(tokenizer._tokenizer.to_str()) 83 | model_type = tokenizer_json["model"]["type"] 84 | if ( 85 | model_type == "BPE" 86 | and tokenizer_json["pre_tokenizer"] is not None 87 | and ( 88 | tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel" 89 | or ( 90 | "pretokenizers" in tokenizer_json["pre_tokenizer"] 91 | and tokenizer_json["pre_tokenizer"]["pretokenizers"][1]["type"] 92 | == "ByteLevel" 93 | ) 94 | ) 95 | ): 96 | model_type = "ByteLevelBPE" 97 | return model_type 98 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/mapping/ByteProxyMapping.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | 4 | from typing import Dict, List 5 | 6 | 7 | class ByteProxyMapping: 8 | def __init__(self, tokenizer): 9 | # check if the tokenizer is fast, if so, convert it to slow 10 | if tokenizer.is_fast: 11 | tokenizer =
AutoTokenizer.from_pretrained( 12 | tokenizer.name_or_path, use_fast=False 13 | ) 14 | self.tokenizer = tokenizer 15 | 16 | # if tokenizer doesn't have byte_encoder(which is the case for llama-3), use gpt2_tokenizer 17 | if not hasattr(tokenizer, "byte_encoder"): 18 | gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False) 19 | tokenizer.byte_encoder = gpt2_tokenizer.byte_encoder 20 | tokenizer.byte_decoder = gpt2_tokenizer.byte_decoder 21 | 22 | self.byte2proxychar: Dict[int, str] = tokenizer.byte_encoder 23 | self.proxychar2byte: Dict[str, int] = tokenizer.byte_decoder 24 | 25 | # code point to byte 26 | self.cdp2byte: Dict[int, int] = { 27 | ord(c): b for c, b in self.proxychar2byte.items() 28 | } 29 | self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()} 30 | self.PROXY_CDP_SET = set(self.cdp2byte.keys()) 31 | # [33, 126] and [161,172, [174, 323], in total 94 + 12 + 150 = 256(N.B. 173 is a control character) 32 | 33 | def encode_byte2proxychar_cdp(self, byte: int) -> int: 34 | assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)" 35 | return ord(self.byte2proxychar[byte]) 36 | 37 | def decode_proxychar2byte_cdp(self, cdp: int) -> int: 38 | byte_int: int = self.cdp2byte[cdp] 39 | assert 0 <= byte_int < 256, f"byte: {byte_int} is not in the range [0, 256)" 40 | return byte_int 41 | 42 | def decode_proxytoken2bytes(self, proxy_token: str) -> bytes: 43 | bytes_seq: List[int] = [ 44 | self.decode_proxychar2byte_cdp(ord(c)) for c in proxy_token 45 | ] 46 | return bytes(bytes_seq) 47 | 48 | def map(self, proxy_token: str) -> bytes: 49 | return self.decode_proxytoken2bytes(proxy_token) 50 | 51 | def token2bytes(self, token: str) -> bytes: 52 | bytes_seq: List[int] = [self.proxychar2byte[c] for c in token] 53 | return bytes(bytes_seq) 54 | 55 | 56 | class LLAMAByteProxyMapping: 57 | def __init__(self): 58 | pass 59 | 60 | def map(self, proxy_token: str) -> bytes: 61 | return self.decode_proxytoken2bytes(proxy_token) 62 | 63 | def decode_proxytoken2bytes(self, proxy_token: str) -> bytes: 64 | if proxy_token.startswith("<0x"): 65 | hex_value: str = proxy_token[3:-1] 66 | return bytes.fromhex(hex_value) 67 | else: 68 | # ad hoc fix for BPE 69 | if proxy_token.startswith("▁"): 70 | proxy_token = proxy_token.replace("▁", " ") 71 | return proxy_token.encode("utf-8") 72 | 73 | 74 | if __name__ == "__main__": 75 | 76 | gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2") 77 | 78 | # gpt2_tokenizer.encode("´") 79 | 80 | byteproxymapper = ByteProxyMapping(gpt2_tokenizer) 81 | 82 | for i in range(256): 83 | print(f"{i}: {byteproxymapper.encode_byte2proxychar_cdp(i)}") 84 | 85 | # decode a byte 86 | 87 | byte = 162 88 | print(f"proxy code point set: {byteproxymapper.PROXY_CDP_SET}") 89 | print(f"len(proxy code point set): {len(byteproxymapper.PROXY_CDP_SET)}") 90 | for i in range(33, 127): 91 | if i not in byteproxymapper.PROXY_CDP_SET: 92 | print(f"{i} not in proxy code point set") 93 | for i in range(161, 324): 94 | if i not in byteproxymapper.PROXY_CDP_SET: 95 | print(f"{i} not in proxy code point set") 96 | -------------------------------------------------------------------------------- /examples/benchmarking/time_benchmarking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 4 | from transformers_cfg.generation.logits_process import 
GrammarConstrainedLogitsProcessor 5 | import time 6 | import sys 7 | from dataclasses import dataclass 8 | 9 | 10 | @dataclass 11 | class BenchmarkingArguments: 12 | grammar_filepath: str 13 | prompt: str 14 | max_new_tokens: int = 50 15 | model_id: str = "/dlabdata1/llm_hub/Mistral-7B-v0.1" 16 | device: str = "cpu" 17 | 18 | 19 | MAX_NEW_TOKEN_PLACEHOLDER = "<MAX_NEW_TOKENS>"  # marker in the prompt replaced by the actual token budget 20 | 21 | 22 | def parse_args(): 23 | raw_args = sys.argv[1:] 24 | n_passed = len(raw_args) 25 | if n_passed < 2: 26 | print("Usage: python time_benchmarking.py <grammar_filepath> <prompt> [<max_new_tokens>] [<model_id>] [<device>]") 27 | sys.exit(1) 28 | if n_passed > 2: 29 | raw_args[2] = int(raw_args[2]) 30 | args = BenchmarkingArguments(*raw_args) 31 | return args 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | 37 | model_id = args.model_id 38 | 39 | # Detect if GPU is available, otherwise use CPU 40 | device = torch.device( 41 | args.device or ("cuda" if torch.cuda.is_available() else "cpu") 42 | ) 43 | print(f"Using device: {device}, max new tokens: {args.max_new_tokens}") 44 | 45 | # Load model and tokenizer 46 | tokenizer = AutoTokenizer.from_pretrained(model_id) 47 | tokenizer.pad_token = tokenizer.eos_token 48 | # Load model to defined device 49 | model = AutoModelForCausalLM.from_pretrained(model_id).to(device) 50 | 51 | # Load grammar 52 | with open(args.grammar_filepath, "r") as file: 53 | grammar_str = file.read() 54 | 55 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 56 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 57 | 58 | # Generate 59 | args.prompt = args.prompt.replace( 60 | MAX_NEW_TOKEN_PLACEHOLDER, str(args.max_new_tokens) 61 | ) 62 | 63 | input_ids = tokenizer( 64 | [args.prompt], add_special_tokens=False, return_tensors="pt", padding=True 65 | )["input_ids"].to(device) 66 | 67 | unconstrained_st = time.perf_counter() 68 | unconstrained_output = model.generate( 69 | input_ids, 70 | do_sample=False, 71 | max_new_tokens=args.max_new_tokens, 72 | num_return_sequences=1, 73 | ) 74 | unconstrained_tot = time.perf_counter() - unconstrained_st 75 | 76 | constrained_st = time.perf_counter() 77 | constrained_output = model.generate( 78 | input_ids, 79 | do_sample=False, 80 | max_new_tokens=args.max_new_tokens, 81 | logits_processor=[grammar_processor], 82 | num_return_sequences=1, 83 | ) 84 | 85 | constrained_tot = time.perf_counter() - constrained_st 86 | print(f"Unconstrained time: {unconstrained_tot:.2f}") 87 | print(f"Constrained time: {constrained_tot:.2f}") 88 | 89 | # decode outputs (possibly of different lengths across decoding modes) 90 | generations = tokenizer.batch_decode( 91 | unconstrained_output, skip_special_tokens=True 92 | ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) 93 | print() 94 | 95 | n_examples = len(input_ids) 96 | for i in range(n_examples): 97 | unconstrained_generation = generations[i] 98 | constrained_generation = generations[i + n_examples] 99 | 100 | for generation, generation_type in zip( 101 | [unconstrained_generation, constrained_generation], 102 | ["unconstrained", "constrained"], 103 | ): 104 | print(f"The {generation_type} generation:\n{generation}") 105 | print() 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /examples/demo.sh: -------------------------------------------------------------------------------- 1 | 2 | ################ 3 | # 4 | # JSON generation: object and array 5 | # 6 | ################ 7 | 8 | # generate json object 9 | transformers-cfg-cli generate \ 10 | -m
"microsoft/Phi-3-mini-4k-instruct" \ 11 | -g "examples/grammars/json.ebnf" \ 12 | -p "This is a valid json string for http request:" \ 13 | --use_4bit \ 14 | --max_new_tokens 60 \ 15 | --repetition_penalty 1.1 16 | 17 | # generate json array 18 | 19 | transformers-cfg-cli generate \ 20 | -m "microsoft/Phi-3-mini-4k-instruct" \ 21 | -g "examples/grammars/json_arr.ebnf" \ 22 | -p "Put my shopping list into a json array:" \ 23 | --use_4bit \ 24 | --max_new_tokens 60 \ 25 | --repetition_penalty 1.1 26 | 27 | ################ 28 | # 29 | # Code generation: Python, C 30 | # 31 | ################ 32 | 33 | # generate C code 34 | transformers-cfg-cli generate \ 35 | -m "microsoft/Phi-3-mini-4k-instruct" \ 36 | -g "examples/grammars/c.ebnf" \ 37 | -p "#include \n" \ 38 | --use_4bit \ 39 | --max_new_tokens 20 \ 40 | --repetition_penalty 3.0 41 | 42 | ################ 43 | # 44 | # NLP tasks: relation extraction 45 | # 46 | ################ 47 | 48 | # generate relation extraction triples 49 | transformers-cfg-cli generate \ 50 | -m "microsoft/Phi-3-mini-4k-instruct" \ 51 | -g "examples/grammars/cIE.ebnf" \ 52 | -p "Extract relations from the following sentence: René Descartes was a French philosopher, scientist, and mathematician" \ 53 | --use_8bit \ 54 | --max_new_tokens 60 \ 55 | --repetition_penalty 1.1 56 | 57 | 58 | ################ 59 | # 60 | # Semantic parsing: CalFlow, GeoQuery, overnight, etc. 61 | # 62 | ################ 63 | 64 | transformers-cfg-cli generate \ 65 | -m "microsoft/Phi-3-mini-4k-instruct" \ 66 | -g "examples/grammars/calflow.ebnf" \ 67 | -p 'Generate 3 CalFlow strings: 1.(Yield (toRecipient (CurrentUser))) 2.(Yield (CreateCommitEventWrapper (CreatePreflightEventWrapper (Event.subject_? (?= "choose the meeting"))))) 3.' \ 68 | --use_4bit \ 69 | --max_new_tokens 60 \ 70 | --repetition_penalty 1.1 71 | 72 | transformers-cfg-cli generate \ 73 | -m "microsoft/Phi-3-mini-4k-instruct" \ 74 | -g "examples/grammars/geo_query.ebnf" \ 75 | -p "Translate the following sentence into GeoQuery: What is the population of the largest city in California?" \ 76 | --use_4bit \ 77 | --max_new_tokens 60 \ 78 | --repetition_penalty 1.1 79 | 80 | transformers-cfg-cli generate \ 81 | -m "microsoft/Phi-3-mini-4k-instruct" \ 82 | -g "examples/grammars/overnight.ebnf" \ 83 | -p """Translate natural language to DSL: 84 | Q: which brick is no wider than 3 inches 85 | A: listValue (filter (getProperty (singleton en.block) !type) (ensureNumericProperty width) <= (ensureNumericEntity 3 en.inch))) 86 | Q: which block is above block 1 87 | A: (listValue (filter (filter (getProperty (singleton en.block) !type) (reverse above) = en.block.block1) above = en.block.block1)) 88 | Q: what block is longer than 3 inches 89 | A: """ \ 90 | --use_4bit \ 91 | --max_new_tokens 60 \ 92 | --repetition_penalty 1.1 93 | 94 | 95 | 96 | ################ 97 | # 98 | # Unicode support, Chinese, Emoji, etc. 99 | # 100 | ################ 101 | 102 | transformers-cfg-cli generate \ 103 | -m "microsoft/Phi-3-mini-4k-instruct" \ 104 | -g "examples/grammars/chinese.ebnf" \ 105 | -p "Translate the following sentence into Chinese: My neighbor is a very nice person. -> " \ 106 | --use_4bit \ 107 | --max_new_tokens 60 \ 108 | --repetition_penalty 1.1 109 | 110 | 111 | transformers-cfg-cli generate \ 112 | -m "microsoft/Phi-3-mini-4k-instruct" \ 113 | -g "examples/grammars/emoji.ebnf" \ 114 | -p "Translate the following sentence into emoji: I am very happy today. 
-> " \ 115 | --use_4bit \ 116 | --max_new_tokens 60 \ 117 | --repetition_penalty 1.1 118 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | examples/grammars/debug 2 | examples/grammars/wip 3 | examples/jupyter 4 | profiling 5 | debug 6 | dev 7 | git_diff.txt 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | *.patch 13 | .secrets 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | .idea/ 170 | 171 | # VSCode 172 | launch.json 173 | 174 | # Profiling 175 | *.prof 176 | 177 | # MacOS 178 | .DS_Store 179 | transformers_cfg/__init__.py 180 | -------------------------------------------------------------------------------- /examples/generate_smiles.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 5 | from transformers_cfg.recognizer import StringRecognizer 6 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 7 | from transformers_cfg.parser import parse_ebnf 8 | import time 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Generate SMILES strings") 13 | parser.add_argument( 14 | "--model-id", 15 | type=str, 16 | default="/dlabdata1/llm_hub/Mistral-7B-v0.1", 17 | help="Model ID", 18 | ) 19 | parser.add_argument("--device", type=str, help="Device to put the model on") 20 | parser.add_argument( 21 | "--smiles-type", 22 | type=str, 23 | choices=["generic", "isocyanates", "acrylates", "chain_extenders"], 24 | default="generic", 25 | help="Type of SMILES to generate", 26 | ) 27 | return parser.parse_args() 28 | 29 | 30 | if __name__ == "__main__": 31 | args = parse_args() 32 | model_id = args.model_id 33 | 34 | # Detect if GPU is available, otherwise use CPU 35 | device = torch.device( 36 | args.device or ("cuda" if torch.cuda.is_available() else "cpu") 37 | ) 38 | print(f"Using device: {device}") 39 | 40 | # Load model and tokenizer 41 | tokenizer = AutoTokenizer.from_pretrained(model_id) 42 | tokenizer.pad_token = tokenizer.eos_token 43 | print(f"N tokens: {len(tokenizer.get_vocab())}") 44 | # Load model to defined device 45 | model = AutoModelForCausalLM.from_pretrained(model_id).to(device) 46 | 47 | # Load grammar 48 | grammar_name = args.smiles_type 49 | with open(f"examples/grammars/SMILES/{grammar_name}.ebnf", "r") as file: 50 | grammar_str = file.read() 51 | 52 | parsed_grammar = parse_ebnf(grammar_str) 53 | first_rule = grammar_str.split("\n")[0] 54 | print(f"{grammar_name}: {first_rule}") 55 | 56 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 57 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 58 | 59 | # Generate 60 | 
prefix1 = f"This is a {grammar_name} SMILES string:" 61 | 62 | input_ids = tokenizer( 63 | [prefix1], add_special_tokens=False, return_tensors="pt", padding=True 64 | )["input_ids"].to( 65 | device 66 | ) # Move input_ids to the same device as model 67 | 68 | max_new_tokens = 20 69 | unconstrained_output = model.generate( 70 | input_ids, 71 | do_sample=False, 72 | max_new_tokens=max_new_tokens, 73 | repetition_penalty=1.9, 74 | num_return_sequences=1, 75 | ) 76 | 77 | start = time.time() 78 | constrained_output = model.generate( 79 | input_ids, 80 | do_sample=False, 81 | max_new_tokens=max_new_tokens, 82 | logits_processor=[grammar_processor], 83 | repetition_penalty=1.9, 84 | num_return_sequences=1, 85 | ) 86 | 87 | string_grammar = StringRecognizer( 88 | parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] 89 | ) 90 | 91 | res = tokenizer.decode( 92 | constrained_output[0], 93 | skip_special_tokens=True, 94 | ) 95 | 96 | # decode output 97 | generations = tokenizer.batch_decode( 98 | torch.concat([unconstrained_output, constrained_output]), 99 | skip_special_tokens=True, 100 | ) 101 | 102 | print(f"Total decoding time: {time.time()-start:.2f}s") 103 | 104 | for generation, gen_type in zip(generations, ["Unconstrained:", "Constrained:"]): 105 | print(gen_type) 106 | print(generation) 107 | assert string_grammar._accept_prefix( 108 | res[len(prefix1) :] 109 | ), f"The generated prefix does not match the grammar: {string_grammar._accept_prefix(res[len(prefix1):])}" 110 | print( 111 | f"The generation matches the grammar: {string_grammar._accept_string(generation[len(prefix1):])}" 112 | ) 113 | 114 | #### 115 | # 116 | # Unconstrained: 117 | # This is a generic SMILES string: 118 | # C1=CC(C2)NC3c4cccc5cc6[n 119 | # Constrained: 120 | # This is a generic SMILES string:[102as]-Oc(=C)NCCCNCO.S 121 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | from transformers import ( 3 | GPT2TokenizerFast, 4 | BartTokenizerFast, 5 | LlamaTokenizerFast, 6 | T5TokenizerFast, 7 | CodeGenTokenizerFast, 8 | PreTrainedTokenizerFast, 9 | GemmaTokenizerFast, 10 | Qwen2TokenizerFast, 11 | ByT5Tokenizer, 12 | ) 13 | 14 | from transformers_cfg.tokenization.SUPPORTED_TOKENIZERS import SUPPORTED_TOKENIZERS 15 | 16 | 17 | def get_TCFG_tokenizer_class(model_name_or_tokenizer): 18 | from transformers import AutoTokenizer 19 | 20 | if isinstance(model_name_or_tokenizer, str): 21 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_tokenizer) 22 | else: 23 | tokenizer = model_name_or_tokenizer 24 | 25 | return TCFG_Tokenizer.from_hf_tokenizer(tokenizer).__class__ 26 | 27 | 28 | class TCFG_Tokenizer: 29 | def __init__(self, hf_tokenizer): 30 | self.hf_tokenizer = hf_tokenizer 31 | self.special_token_ids = set(hf_tokenizer.all_special_ids) 32 | 33 | def real_vocab_size(self): 34 | return len(self.hf_tokenizer.get_vocab()) 35 | 36 | @classmethod 37 | def from_hf_tokenizer(cls, hf_tokenizer): 38 | assert ( 39 | type(hf_tokenizer) in SUPPORTED_TOKENIZERS 40 | ), f"Tokenizer not supported: {hf_tokenizer.__class__.__name__}, supported tokenizers: {SUPPORTED_TOKENIZERS}" 41 | 42 | if isinstance( 43 | hf_tokenizer, 44 | (GPT2TokenizerFast, BartTokenizerFast, Qwen2TokenizerFast, ByT5Tokenizer), 45 | ): 46 | return TCFG_Tokenizer(hf_tokenizer) 47 | elif isinstance( 48 | hf_tokenizer, (LlamaTokenizerFast, GemmaTokenizerFast, 
T5TokenizerFast) 49 | ): 50 | return TCFG_LlamaTokenizer(hf_tokenizer) 51 | elif isinstance(hf_tokenizer, CodeGenTokenizerFast): 52 | # phi reuses the codegen tokenizer 53 | return TCFG_PhiTokenizer(hf_tokenizer) 54 | elif ( 55 | isinstance(hf_tokenizer, PreTrainedTokenizerFast) 56 | and "Llama-3" 57 | in hf_tokenizer.name_or_path # this includes llama-3/llama-3.1/llama-3.2/llama-3.3 58 | ): 59 | return TCFG_LlamaTokenizer(hf_tokenizer) 60 | else: 61 | raise NotImplementedError( 62 | f"Tokenizer not supported: {hf_tokenizer.__class__.__name__}" 63 | ) 64 | 65 | # will be extended by the subclasses 66 | def get_special_token_ids_to_excluded(self) -> Set[int]: 67 | return self.special_token_ids 68 | 69 | 70 | class TCFG_LlamaTokenizer(TCFG_Tokenizer): 71 | def __init__(self, hf_tokenizer): 72 | super().__init__(hf_tokenizer) 73 | 74 | def get_special_token_ids_to_excluded(self): 75 | if "deepseek-coder" in self.hf_tokenizer.name_or_path: 76 | # deepseek has in total 22 special tokens, with token_ids from 32000 to 32021 77 | # with first 13 being characters for bytes: {'õ': 32000, '÷': 32001, 'Á': 32002, 'ý': 32003, 'À': 32004, 'ÿ': 32005, 'ø': 32006, 'ú': 32007, 'þ': 32008, 'ü': 32009, 'ù': 32010, 'ö': 32011, 'û': 32012} 78 | # the rest are special tokens for the tokenizer: { '<|begin▁of▁sentence|>': 32013, '<|end▁of▁sentence|>': 32014, '<|fim▁hole|>': 32015, '<|fim▁begin|>': 32016, '<|fim▁end|>': 32017, '': 32018, '<|User|>': 32019, '<|Assistant|>': 32020, '<|EOT|>': 32021} 79 | added_vocab_dict = self.hf_tokenizer.get_added_vocab() 80 | added_tokens_id_to_excluded = set( 81 | [ 82 | token_id 83 | for tok, token_id in added_vocab_dict.items() 84 | if tok.startswith("<|") 85 | ] 86 | ) 87 | return self.special_token_ids.union(added_tokens_id_to_excluded) 88 | return self.special_token_ids 89 | 90 | 91 | class TCFG_CharacterTokenizer(TCFG_Tokenizer): 92 | """ 93 | Not yet used, but can be used for character level tokenization (even though rarely used in practice) 94 | """ 95 | 96 | def __init__(self, hf_tokenizer): 97 | super().__init__(hf_tokenizer) 98 | 99 | 100 | class TCFG_PhiTokenizer(TCFG_Tokenizer): 101 | def __init__(self, hf_tokenizer): 102 | super().__init__(hf_tokenizer) 103 | 104 | def real_vocab_size(self): 105 | return 50257 # 50 k tokens + 256 for bytes + 1 for EOS 106 | -------------------------------------------------------------------------------- /examples/generate_geo_query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 5 | from transformers_cfg.recognizer import StringRecognizer 6 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 7 | from transformers_cfg.parser import parse_ebnf 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Generate geo query strings") 12 | parser.add_argument( 13 | "--model-id", 14 | type=str, 15 | default="/dlabdata1/llm_hub/Mistral-7B-v0.1", 16 | help="Model ID", 17 | ) 18 | parser.add_argument("--device", type=str, help="Device to put the model on") 19 | return parser.parse_args() 20 | 21 | 22 | def main(): 23 | args = parse_args() 24 | model_id = args.model_id 25 | 26 | # Detect if GPU is available, otherwise use CPU 27 | device = torch.device( 28 | args.device or ("cuda" if torch.cuda.is_available() else "cpu") 29 | ) 30 | print(f"Using device: {device}") 31 | 
32 | # Load model and tokenizer 33 | tokenizer = AutoTokenizer.from_pretrained(model_id) 34 | tokenizer.pad_token = tokenizer.eos_token 35 | # Load model to defined device 36 | model = AutoModelForCausalLM.from_pretrained(model_id).to(device) 37 | 38 | # Load grammar 39 | with open(f"examples/grammars/geo_query.ebnf", "r") as file: 40 | grammar_str = file.read() 41 | 42 | parsed_grammar = parse_ebnf(grammar_str) 43 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 44 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 45 | 46 | # Generate 47 | prompts = [ 48 | "which state contains most rivers ? ", 49 | "number of citizens in boulder ? ", 50 | "what are the major cities of the us ? ", 51 | "what is the smallest city in washington ? ", 52 | "how many states border colorado and border new mexico ? ", 53 | ] 54 | 55 | input_ids = tokenizer( 56 | prompts, add_special_tokens=False, return_tensors="pt", padding=True 57 | )["input_ids"].to( 58 | device 59 | ) # Move input_ids to the same device as model 60 | 61 | n_examples = input_ids.shape[0] 62 | 63 | max_new_tokens = 50 64 | unconstrained_output = model.generate( 65 | input_ids, 66 | do_sample=False, 67 | max_new_tokens=max_new_tokens, 68 | repetition_penalty=1.9, 69 | num_return_sequences=1, 70 | ) 71 | constrained_output = model.generate( 72 | input_ids, 73 | do_sample=False, 74 | max_new_tokens=max_new_tokens, 75 | logits_processor=[grammar_processor], 76 | repetition_penalty=1.9, 77 | num_return_sequences=1, 78 | ) 79 | 80 | parsed_grammar = parse_ebnf(grammar_str) 81 | string_grammar = StringRecognizer( 82 | parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] 83 | ) 84 | 85 | # decode outputs (possibly of different lengths across decoding modes) 86 | generations = tokenizer.batch_decode( 87 | unconstrained_output, skip_special_tokens=True 88 | ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) 89 | print() 90 | for i in range(n_examples): 91 | print(f"Unconstrained: {generations[i]}") 92 | constrained_generation = generations[i + n_examples] 93 | print(f"Constrained: {generations[i + n_examples]}") 94 | print( 95 | f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}" 96 | ) 97 | print( 98 | f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}" 99 | ) 100 | print() 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | 106 | 107 | ########################## 108 | # Example output: 109 | # 110 | # Unconstrained: how many states border colorado and border new mexico ? 1. 111 | # - How long is the drive from denver to albuquerque? The distance between Denver, Colorado (CO) & Alburqueque New Mexico(NM). Driving directions for your road trip or vacation: Get driving 112 | # Constrained: how many states border colorado and border new mexico ? answer(smallest_one(area_1(stateid('colorado')))) 113 | # 114 | ########################## 115 | -------------------------------------------------------------------------------- /transformers_cfg/adapters/llama_cpp_python.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def llama_cpp_python(processor): 9 | """ 10 | Adapter function for llama-cpp-python. 
11 | 12 | Args: 13 | processor: A GrammarConstrainedLogitsProcessor instance 14 | 15 | Returns: 16 | A function that can be used as a logits processor with llama-cpp-python 17 | """ 18 | reinit_attempts = 0 19 | reinit_max = 3 20 | accumulated_tokens = [] 21 | 22 | def _force_eos(scores): 23 | eos_token = processor.grammar_constraint.tokenizer.eos_token_id 24 | logger.warning(f"Forcing EOS token: {eos_token}") 25 | mask = torch.full_like(scores, fill_value=-float("inf")) 26 | if scores.dim() == 2: 27 | mask[:, eos_token] = 0 28 | else: 29 | mask[eos_token] = 0 30 | return mask 31 | 32 | def adapter_func(input_ids, scores): 33 | nonlocal reinit_attempts, accumulated_tokens 34 | 35 | # Normalize input_ids to a list of token sequences 36 | if np.isscalar(input_ids): 37 | input_ids = [int(input_ids)] 38 | elif isinstance(input_ids, np.ndarray): 39 | input_ids = input_ids.tolist() 40 | elif isinstance(input_ids, list): 41 | input_ids = [int(i) if isinstance(i, np.generic) else i for i in input_ids] 42 | elif isinstance(input_ids, np.generic): 43 | input_ids = [int(input_ids)] 44 | 45 | # Ensure we have a batch (list of token lists) 46 | if input_ids and isinstance(input_ids[0], int): 47 | input_ids = [input_ids] 48 | 49 | # Convert scores to a torch.Tensor if needed 50 | if not isinstance(scores, torch.Tensor): 51 | scores = torch.tensor(scores) 52 | 53 | # Ensure scores is 2D: [batch, vocab_size] 54 | if scores.dim() == 1: 55 | scores = scores.unsqueeze(0) 56 | 57 | # Track tokens for debugging 58 | if len(input_ids[0]) > len(accumulated_tokens): 59 | new_token = input_ids[0][-1] 60 | accumulated_tokens.append(new_token) 61 | try: 62 | token_text = processor.grammar_constraint.tokenizer.decode([new_token]) 63 | logger.debug(f"Added token: {new_token} ({token_text})") 64 | except Exception: 65 | logger.debug(f"Added token: {new_token} (cannot decode)") 66 | 67 | # Check for consistency: if the length of our input token sequence 68 | # does not match what the grammar expects, then reinitialize 69 | current_length = len(input_ids[0]) 70 | if hasattr(processor.grammar_constraint, "last_size") and processor.grammar_constraint.last_size is not None: 71 | expected_length = processor.grammar_constraint.last_size + 1 72 | if current_length != expected_length: 73 | logger.warning(f"Length mismatch: current={current_length}, expected={expected_length}. Reinitializing.") 74 | processor.reset() 75 | reinit_attempts = 0 76 | 77 | try: 78 | processed_scores = processor.process_logits(input_ids, scores) 79 | reinit_attempts = 0 80 | except ValueError as e: 81 | error_msg = str(e) 82 | if "All stacks are empty" in error_msg: 83 | # Try to recover by reinitializing the grammar constraint 84 | if reinit_attempts < reinit_max: 85 | logger.warning(f"Grammar constraint error: {error_msg}. Attempt {reinit_attempts+1}/{reinit_max} to recover.") 86 | processor.reset() 87 | reinit_attempts += 1 88 | try: 89 | processed_scores = processor.process_logits(input_ids, scores) 90 | except ValueError as e2: 91 | logger.error(f"Recovery failed: {str(e2)}") 92 | processed_scores = _force_eos(scores) 93 | else: 94 | # If reinitialization has already been attempted enough times, 95 | # treat the output as complete and force EOS 96 | logger.error(f"Max retries ({reinit_max}) exceeded. 
Current text: {processor.grammar_constraint.tokenizer.decode(accumulated_tokens)}") 97 | processed_scores = _force_eos(scores) 98 | else: 99 | logger.error(f"Unexpected error: {error_msg}") 100 | raise e 101 | 102 | # Remove the batch dimension if present 103 | if processed_scores.dim() == 2 and processed_scores.size(0) == 1: 104 | processed_scores = processed_scores.squeeze(0) 105 | return processed_scores.detach().cpu().numpy() 106 | 107 | return adapter_func 108 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/byte_trie.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import lru_cache 3 | from typing import Dict, List, Set, Tuple, Optional 4 | from collections import deque 5 | 6 | from transformers_cfg.tokenization.mapping.token2byte import ( 7 | Token2ByteMapping, 8 | ) 9 | from transformers_cfg.tokenization.tokenizer import TCFG_Tokenizer 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class TrieNode: 15 | def __init__(self): 16 | self.children: Dict[int, "TrieNode"] = {} 17 | self.is_end_of_word: bool = False 18 | self.token_id: Optional[int] = None 19 | 20 | 21 | class ByteTrie: 22 | def __init__(self): 23 | self.root = TrieNode() 24 | 25 | def insert(self, word, token_id=None): 26 | node = self.root 27 | for char in word: 28 | if char not in node.children: 29 | node.children[char] = TrieNode() 30 | node = node.children[char] 31 | node.is_end_of_word = True 32 | node.token_id = token_id 33 | 34 | @classmethod 35 | def from_tokenizer(cls, tokenizer): 36 | vocab: Dict[str, int] = tokenizer.get_vocab() 37 | trie = cls() 38 | mapping = Token2ByteMapping.from_hf_tokenizer(tokenizer) 39 | TCFG_tokenizer = TCFG_Tokenizer.from_hf_tokenizer(tokenizer) 40 | 41 | token_ids_to_ignore: Set[ 42 | int 43 | ] = TCFG_tokenizer.get_special_token_ids_to_excluded() 44 | for token_id in range(TCFG_tokenizer.real_vocab_size()): 45 | if token_id not in token_ids_to_ignore: 46 | byte_repr = mapping.map(token_id) 47 | trie.insert(byte_repr, token_id) 48 | trie.vocab_size = len(vocab) 49 | return trie 50 | 51 | @lru_cache(maxsize=128) 52 | def __len__(self): 53 | # return len(self.dfs(verbose=False)) 54 | return self.vocab_size 55 | 56 | def bfs( 57 | self, predicate=lambda x: True, verbose=False 58 | ) -> List[Tuple[List[int], int]]: 59 | queue = deque([(self.root, [])]) 60 | valid_byte_seqs: List[Tuple[List[int], int]] = [] 61 | counter = {"visited": 0, "pruned": 0} 62 | 63 | while queue: 64 | counter["visited"] += 1 65 | node, byte_seq = queue.popleft() 66 | if predicate(byte_seq): 67 | if node.is_end_of_word: 68 | valid_byte_seqs.append((byte_seq, node.token_id)) 69 | for char, next_node in node.children.items(): 70 | new_byte_seq: List[int] = byte_seq.copy() 71 | new_byte_seq.append(char) 72 | queue.append((next_node, new_byte_seq)) 73 | else: 74 | counter["pruned"] += 1 75 | return valid_byte_seqs 76 | 77 | def get_next_token_acceptance( 78 | self, accept=lambda x: True, accept_eos=True, eos_token_id=None 79 | ) -> List[bool]: 80 | valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True) 81 | valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs] 82 | token_acceptance: List[bool] = [False] * (len(self)) 83 | 84 | for token_id in valid_token_ids: 85 | token_acceptance[token_id] = True 86 | if not accept_eos: 87 | # eos_token is mapped to an empty string, so it's always accepted regardless of the accept function 
88 | # this can be undesirable, so we can set it to False to ignore it 89 | token_acceptance[eos_token_id] = False 90 | return token_acceptance 91 | 92 | def visualize(self, max_depth=3): 93 | def _visualize(node, prefix, depth): 94 | if depth > max_depth: 95 | return 96 | for char, next_node in node.children.items(): 97 | print(f"{prefix}{char} (Token ID: {next_node.token_id})") 98 | _visualize(next_node, prefix + " ", depth + 1) 99 | 100 | print("Visualizing ByteTrie:") 101 | _visualize(self.root, "", 1) 102 | 103 | 104 | if __name__ == "__main__": 105 | import logging 106 | 107 | # Configure logging 108 | logging.basicConfig(level=logging.INFO) 109 | from transformers import AutoTokenizer 110 | 111 | tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True) 112 | 113 | trie = ByteTrie.from_tokenizer(tokenizer) 114 | print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}") 115 | 116 | trie.visualize(max_depth=0) 117 | 118 | # 119 | # print(trie.search("hello")) # Example, replace with actual words from the vocab 120 | # print(trie.start_with_prefix("hell")) 121 | # 122 | # # Example Usage 123 | # words = trie.dfs(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0) 124 | # for word in words: 125 | # print(bytes(word[0]).decode("utf-8")) 126 | # 127 | # # Example Usage 128 | # words = trie.bfs(predicate=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0) 129 | # for word in words: 130 | # print(bytes(word[0]).decode("utf-8")) 131 | # 132 | # token_acceptance = trie.get_next_token_acceptance(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0) 133 | # print(sum(token_acceptance)) 134 | # assert sum(token_acceptance) == len(words) 135 | 136 | ######################## 137 | # UTF-8 138 | ######################## 139 | 140 | # from transformers import AutoTokenizer 141 | # 142 | # japanese = "こんにちは世界" 143 | # with open("examples/grammars/japanese.ebnf", "r") as file: 144 | # input_text = file.read() 145 | # parsed_grammar = parse_ebnf(input_text) 146 | # 147 | # start_rule_id = parsed_grammar.symbol_table["root"] 148 | # 149 | # recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id) 150 | # parsing_state = recognizer.init_parsing_state() 151 | # token_acc = trie.get_next_token_acceptance(accept=lambda x: recognizer._probe_bytes_partial_match(x, parsing_state=parsing_state)) 152 | -------------------------------------------------------------------------------- /tests/test_accept_token_sequence/_test_accept_tokens_mixin.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import pytest 3 | from transformers import PreTrainedTokenizer 4 | from transformers_cfg.token_grammar_recognizer import IncrementalTokenRecognizer 5 | from transformers_cfg.utils import pprint_token_ids 6 | 7 | 8 | class TokenizerTesterMixin: 9 | tokenizer_class = None 10 | pretrained_name = None 11 | rust_tokenizer_class = None 12 | test_slow_tokenizer = True 13 | test_rust_tokenizer = True 14 | space_between_special_tokens = False 15 | from_pretrained_kwargs = None 16 | from_pretrained_filter = None 17 | from_pretrained_vocab_key = "vocab_file" 18 | test_seq2seq = True 19 | 20 | # set to True to test a sentencepiece tokenizer 21 | test_sentencepiece = False 22 | 23 | # set to True to ignore casing when testing a sentencepiece tokenizer 24 | # test_sentencepiece must also be set to True 25 | test_sentencepiece_ignore_case = False 26 | 27 | @pytest.fixture(autouse=True) 28 | def setup_tokenizer(self): 29 | self.tokenizer = 
self.get_tokenizer()
 30 | 
 31 |     def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
 32 |         use_fast = kwargs.pop("use_fast", True)
 33 |         return self.tokenizer_class.from_pretrained(
 34 |             self.pretrained_name, use_fast=use_fast, **kwargs
 35 |         )
 36 | 
 37 |     def test_json_parsable(self):
 38 |         # Test that we can load a JSON object
 39 |         with open("examples/grammars/json.ebnf", "r") as file:
 40 |             input_text = file.read()
 41 |         JsontokenRecognizer = IncrementalTokenRecognizer(
 42 |             grammar_str=input_text, start_rule_name="root", tokenizer=self.tokenizer
 43 |         )
 44 | 
 45 |         valid_json = '{"foo": "bar", "baz": "bat"}'
 46 |         token_ids = self.tokenizer.encode(valid_json)
 47 |         pprint_token_ids(self.tokenizer, token_ids)
 48 | 
 49 |         # check if there is unk token
 50 |         for token_id in token_ids:
 51 |             if token_id == self.tokenizer.unk_token_id:
 52 |                 warnings.warn(
 53 |                     f"unk token found in input_token_ids: {token_ids}, skipping test"
 54 |                 )
 55 |                 return
 56 | 
 57 |         acc_state = JsontokenRecognizer._update_state_with_single_token_seq(
 58 |             token_ids, as_string=False
 59 |         )
 60 |         # the json object is complete, so the stacks should be empty
 61 |         assert acc_state.stacks == set() or acc_state.stacks == {
 62 |             tuple()
 63 |         }, f"stacks: {acc_state.stacks}, not empty"
 64 | 
 65 |     def test_balanced_parentheses(self):
 66 |         # Test that we can recognize balanced parentheses
 67 |         with open("examples/grammars/balanced_parentheses.ebnf", "r") as file:
 68 |             input_text = file.read()
 69 |         recognizer = IncrementalTokenRecognizer(
 70 |             grammar_str=input_text, start_rule_name="root", tokenizer=self.tokenizer
 71 |         )
 72 | 
 73 |         balanced_parentheses = "((((((((()))))))))"
 74 |         token_ids = self.tokenizer.encode(balanced_parentheses)
 75 |         pprint_token_ids(self.tokenizer, token_ids)
 76 | 
 77 |         # check if there is unk token
 78 |         for token_id in token_ids:
 79 |             if token_id == self.tokenizer.unk_token_id:
 80 |                 warnings.warn(
 81 |                     f"unk token found in input_token_ids: {token_ids}, skipping test"
 82 |                 )
 83 |                 return
 84 |         parsing_state = recognizer._update_state_with_single_token_seq(
 85 |             token_ids, as_string=False
 86 |         )
 87 |         # the parenthesis sequence is complete, so the stacks should be empty
 88 |         assert parsing_state.stacks == set() or parsing_state.stacks == {
 89 |             tuple()
 90 |         }, f"stacks: {parsing_state.stacks}, not empty"
 91 | 
 92 |     def test_forcing_sequence(self):
 93 | 
 94 |         string_to_force = "12345 678 90"
 95 | 
 96 |         grammar_str = f"""
 97 |         root ::= "{string_to_force}"
 98 | 
 99 |         """
 100 | 
 101 |         tokenRecognizer = IncrementalTokenRecognizer(
 102 |             grammar_str=grammar_str, start_rule_name="root", tokenizer=self.tokenizer
 103 |         )
 104 | 
 105 |         token_ids = self.tokenizer.encode(string_to_force)
 106 |         pprint_token_ids(self.tokenizer, token_ids)
 107 | 
 108 |         # check if there is unk token
 109 |         for token_id in token_ids:
 110 |             if token_id == self.tokenizer.unk_token_id:
 111 |                 warnings.warn(
 112 |                     f"unk token found in input_token_ids: {token_ids}, skipping test"
 113 |                 )
 114 |                 return
 115 | 
 116 |         acc_state = tokenRecognizer._update_state_with_single_token_seq(
 117 |             token_ids, as_string=False
 118 |         )
 119 |         # the forced string is complete, so the stacks should be empty
 120 |         assert acc_state.stacks == set() or acc_state.stacks == {
 121 |             tuple()
 122 |         }, f"stacks: {acc_state.stacks}, not empty"
 123 | 
 124 |     def test_emoji(self):
 125 |         """
 126 |         Test that we can accept emoji
 127 |         """
 128 | 
 129 |         with open("examples/grammars/emoji.ebnf", "r") as file:
 130 |             input_text = file.read()
 131 | 
 132 |         tokenRecognizer = IncrementalTokenRecognizer(
 133 |             grammar_str=input_text, start_rule_name="root", tokenizer=self.tokenizer
 134 |         )
 135 | 
 136 |         emoji = "😀😄😂"
 137 |         token_ids = self.tokenizer.encode(emoji)
 138 |         pprint_token_ids(self.tokenizer, token_ids)
 139 | 
 140 |         # check if there is unk token
 141 |         for token_id in token_ids:
 142 |             if token_id == self.tokenizer.unk_token_id:
 143 |                 warnings.warn(
 144 |                     f"unk token found in input_token_ids: {token_ids}, skipping test"
 145 |                 )
 146 |                 return
 147 | 
 148 |         acceptance = tokenRecognizer.accept_token_ids(token_ids, as_string=False)
 149 | 
 150 |         assert acceptance, f"emoji: {emoji} not accepted, but it should be"
 151 | 
--------------------------------------------------------------------------------
/tests/test_accept_string/test_overnight.py:
--------------------------------------------------------------------------------
 1 | import pytest
 2 | from transformers_cfg.parser import parse_ebnf
 3 | from transformers_cfg.recognizer import StringRecognizer
 4 | from dataclasses import dataclass
 5 | 
 6 | 
 7 | @dataclass
 8 | class OvernightTestCase:
 9 |     name: str
 10 |     overnight: str
 11 | 
 12 | 
 13 | valid_overnight_sentences = [
 14 |     OvernightTestCase(
 15 |         "simple_request", "(listValue (getProperty en.block.block1 width))"
 16 |     ),
 17 |     OvernightTestCase(
 18 |         "simple_filter",
 19 |         "(listValue (filter (getProperty (singleton en.block) !type) height = 3 en.inch))",
 20 |     ),
 21 |     OvernightTestCase(
 22 |         "count_values",
 23 |         "(listValue (countComparative (getProperty (singleton en.block) !type) shape >= 2))",
 24 |     ),
 25 |     OvernightTestCase(
 26 |         "ensure_property",
 27 |         "(listValue (filter (getProperty (singleton en.block) !type) (ensureNumericProperty width) <= (ensureNumericEntity 3 en.inch)))",
 28 |     ),
 29 |     OvernightTestCase(
 30 |         "above",
 31 |         "(listValue (filter (filter (getProperty (singleton en.block) !type) (reverse above) = en.block.block1) above = en.block.block1))",
 32 |     ),
 33 |     OvernightTestCase(
 34 |         "reverse right",
 35 |         "(listValue (filter (getProperty (singleton en.block) !type) right = (filter (getProperty (singleton en.block) !type) (reverse right) = en.block.block1)))",
 36 |     ),
 37 |     OvernightTestCase(
 38 |         "agg",
 39 |         "(listValue (superlative (getProperty (singleton en.block) !type) max (ensureNumericProperty length)))",
 40 |     ),
 41 |     OvernightTestCase(
 42 |         "nested_filters",
 43 |         "(listValue (filter (filter (getProperty (singleton en.block) !type) (reverse above) = en.block.block1) (reverse right) = en.block.block1))",
 44 |     ),
 45 |     OvernightTestCase(
 46 |         "shape",
 47 |         "(listValue (filter (getProperty (singleton en.block) !type) shape != en.shape.pyramid))",
 48 |     ),
 49 |     OvernightTestCase(
 50 |         "is_special",
 51 |         "(listValue (filter (filter (getProperty (singleton en.block) !type) is_special) left = en.block.block1))",
 52 |     ),
 53 |     OvernightTestCase(
 54 |         "two_blocks",
 55 |         "(listValue (filter (getProperty (singleton en.block) !type) left = (concat en.block.block1 en.block.block2)))",
 56 |     ),
 57 |     OvernightTestCase(
 58 |         "count_superlative",
 59 |         "(listValue (countSuperlative (getProperty (singleton en.block) !type) min (reverse above) (getProperty (singleton en.block) !type)))",
 60 |     ),
 61 |     OvernightTestCase(
 62 |         "long_query",
 63 |         "(listValue (filter (getProperty (singleton en.block) !type) (ensureNumericProperty height) > (ensureNumericEntity (getProperty en.block.block1 height))))",
 64 |     ),
 65 |     OvernightTestCase(
 66 |         "2_value",
 67 |         "(listValue (countComparative (getProperty (singleton en.block) !type) left < 2 (getProperty (singleton en.block) !type)))",
 68 |     ),
 69 |     OvernightTestCase(
 70 |         "concat_shapes",
 71 |         "(listValue (filter (getProperty (singleton en.block) !type) shape = (concat en.shape.pyramid en.shape.cube)))",
 72 |     ),
 73 | ]
 74 | 
 75 | 
 76 | valid_overnight_prefixes = [
 77 |     OvernightTestCase("empty_string", ""),
 78 |     OvernightTestCase(
 79 |         "unbalanced_parenthesis", "(listValue (getProperty en.block.block1 width"
 80 |     ),
 81 |     OvernightTestCase("undefined_argument", "(listValue (getProperty en.block.block1"),
 82 |     OvernightTestCase(
 83 |         "left_comparison",
 84 |         "(listValue (filter (getProperty (singleton en.block) !type) (ensureNumericProperty length) >=",
 85 |     ),
 86 | ]
 87 | 
 88 | invalid_overnight_sentences = [
 89 |     OvernightTestCase(
 90 |         "unknown_property", "(listValue (getProperty en.block.block1 sparkliness))"
 91 |     ),
 92 |     OvernightTestCase("property", "(getProperty en.block.block1 width)"),
 93 |     OvernightTestCase("number_value", "3 en.inch"),
 94 |     OvernightTestCase(
 95 |         "extra_space", "(listValue ( getProperty en.block.block1 width))"
 96 |     ),
 97 |     OvernightTestCase("empty_operator", "(listValue (getProperty ))"),
 98 |     OvernightTestCase("empty_parenthesis", "()"),
 99 |     OvernightTestCase("missing_argument", "(listValue (getProperty en.block.block1 ))"),
 100 |     OvernightTestCase(
 101 |         "nonexistent_shape",
 102 |         "(listValue (filter (getProperty (singleton en.block) !type) shape = (concat en.shape.pyramid en.shape.sphere)))",
 103 |     ),
 104 | ]
 105 | 
 106 | 
 107 | @pytest.fixture(scope="module")
 108 | def recognizer():
 109 |     with open("examples/grammars/overnight.ebnf", "r") as file:
 110 |         input_text = file.read()
 111 |     parsed_grammar = parse_ebnf(input_text)
 112 |     start_rule_id = parsed_grammar.symbol_table["root"]
 113 |     recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
 114 |     return recognizer
 115 | 
 116 | 
 117 | def test_valid_sentence(recognizer):
 118 |     for overnight_test_case in valid_overnight_sentences:
 119 |         assert (
 120 |             recognizer._accept_string(overnight_test_case.overnight) == True
 121 |         ), f"Failed on {overnight_test_case.name}, {overnight_test_case.overnight}"
 122 | 
 123 |     for overnight_test_case in valid_overnight_prefixes + invalid_overnight_sentences:
 124 |         assert (
 125 |             recognizer._accept_string(overnight_test_case.overnight) == False
 126 |         ), f"Failed on {overnight_test_case.name}, {overnight_test_case.overnight}"
 127 | 
 128 | 
 129 | def test_valid_prefixes(recognizer):
 130 |     for overnight_test_case in valid_overnight_sentences + valid_overnight_prefixes:
 131 |         assert (
 132 |             recognizer._accept_prefix(overnight_test_case.overnight) == True
 133 |         ), f"Failed on {overnight_test_case.name}, {overnight_test_case.overnight}"
 134 | 
 135 |     for overnight_test_case in invalid_overnight_sentences:
 136 |         assert (
 137 |             recognizer._accept_prefix(overnight_test_case.overnight) == False
 138 |         ), f"Failed on {overnight_test_case.name}, {overnight_test_case.overnight}"
 139 | 
--------------------------------------------------------------------------------
/transformers_cfg/utf8_utils.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import dataclass
 2 | from typing import Tuple
 3 | 
 4 | 
 5 | 
 6 | 
 7 | @dataclass
 8 | class PartialUTF8:
 9 |     """
 10 |     A data class representing the state of a partially decoded UTF-8 sequence.
 11 | 
 12 |     Attributes:
 13 |     - value (int): The current accumulated value of the partially decoded Unicode code point.
 14 |                    This attribute stores the bits that have been decoded so far.
For a fully decoded 15 | character or before any partial decoding has started, this would typically be `0`. 16 | 17 | - n_remain (int): The number of bytes remaining to complete the current UTF-8 encoded character. 18 | A value of `-1` indicates that there is no ongoing partial decoding, i.e., 19 | either decoding has not started, or the last character was fully decoded. 20 | 21 | This class is used to handle situations where UTF-8 encoded data may end in the middle of a character 22 | sequence, allowing for the decoding process to be resumed when more data becomes available. 23 | """ 24 | 25 | value: int = 0 # Default to 0, indicating no partial value accumulated 26 | n_remain: int = ( 27 | -1 28 | ) # Default to -1, indicating no bytes are currently expected to complete the character 29 | 30 | def __hash__(self): 31 | return hash((self.value, self.n_remain)) 32 | 33 | def __eq__(self, other): 34 | if not isinstance(other, PartialUTF8): 35 | return NotImplemented 36 | return self.value == other.value and self.n_remain == other.n_remain 37 | 38 | 39 | from typing import List, Tuple 40 | from functools import lru_cache 41 | 42 | 43 | @lru_cache(maxsize=3000000) 44 | def decode_utf8( 45 | src: bytes, partial_start: PartialUTF8 46 | ) -> Tuple[List[int], PartialUTF8]: 47 | # Lookup table for determining the total bytes based on the first byte's high 4 bits 48 | lookup = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4] 49 | pos = 0 # Position in the src bytes to start decoding from 50 | code_points = [] # List to store the decoded Unicode code points 51 | value = partial_start.value # Start with any previously partial decoded value 52 | n_remain = partial_start.n_remain # Number of bytes remaining from a partial decode 53 | 54 | # If there's a partial sequence left from last decode, try to continue decoding it 55 | while pos < len(src) and n_remain > 0: 56 | next_byte = src[pos] # Get the next byte to process 57 | # Check if the continuation byte format is correct (`10xxxxxx`) 58 | if (next_byte >> 6) != 2: 59 | # If not, it's an invalid sequence. Abort and return a special error state. 60 | code_points = [0] 61 | return code_points, PartialUTF8(0, -1) 62 | 63 | # Accumulate the value by shifting left and adding the relevant 6 bits 64 | value = (value << 6) + (next_byte & 0x3F) 65 | pos += 1 # Move to the next byte 66 | n_remain -= 1 # Decrement the number of remaining bytes 67 | 68 | # If we've completed a partial sequence, add its value to the code points 69 | if partial_start.n_remain > 0 and n_remain == 0: 70 | code_points.append(value) 71 | 72 | # Process the rest of src as complete or new UTF-8 sequences 73 | while pos < len(src): 74 | first_byte = src[pos] # Get the first byte of the next sequence 75 | highbits = first_byte >> 4 # Extract the high 4 bits for the lookup table 76 | n_remain = lookup[highbits] - 1 # Determine remaining bytes in this sequence 77 | 78 | # If lookup returns an invalid number, it's an invalid sequence. Abort. 
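        # (High bits 0x0-0x7 mark single-byte ASCII lead bytes; 0x8-0xB are
        # continuation bytes, which map to 0 in the lookup table and thus
        # yield n_remain == -1 here; 0xC-0xD, 0xE and 0xF start 2-, 3- and
        # 4-byte sequences respectively.)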
79 | if n_remain < 0: 80 | # raise ValueError("Invalid UTF-8 sequence") 81 | code_points = [0] 82 | return code_points, PartialUTF8(0, -1) 83 | 84 | # Calculate the mask to isolate significant bits from the first byte 85 | mask = (1 << (7 - n_remain)) - 1 86 | value = first_byte & mask # Apply the mask to get the initial value 87 | pos += 1 # Move to the next byte 88 | 89 | # Process the continuation bytes 90 | while pos < len(src) and n_remain > 0: 91 | next_byte = src[pos] 92 | # Shift the accumulated value and add the next 6 significant bits 93 | value = (value << 6) + (next_byte & 0x3F) 94 | pos += 1 # Move to the next byte 95 | n_remain -= 1 # Decrement the number of remaining bytes 96 | 97 | # If the sequence is complete, add its decoded value to the code points 98 | if n_remain == 0: 99 | code_points.append(value) 100 | 101 | # # Append a terminating value to indicate the end (following llama-cpp implementation) 102 | # code_points.append(0) 103 | # the following line is crucial for LRU cache to work, as it reset to the initial state 104 | if n_remain == 0: 105 | n_remain = -1 106 | value = 0 107 | 108 | # Return the decoded code points and the state of any partial decoding 109 | return code_points, PartialUTF8(value, n_remain) 110 | 111 | 112 | def decode_utf8_leading_char(src: bytes) -> tuple: 113 | first_byte = src[0] 114 | highbits = first_byte >> 4 115 | lookup = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4] 116 | char_len = lookup[highbits] 117 | 118 | # Extract the relevant bytes for the UTF-8 character 119 | utf8_char_bytes = src[:char_len] 120 | 121 | # Decode the character 122 | char = utf8_char_bytes.decode("utf-8") 123 | 124 | # Use ord() to convert the single character to its Unicode code point 125 | code_point = ord(char) 126 | 127 | # Remaining bytes 128 | remaining_bytes = src[char_len:] 129 | 130 | return code_point, remaining_bytes 131 | 132 | 133 | def decode_utf8_string(utf8_bytes: bytes) -> list: 134 | code_points = [] 135 | while utf8_bytes: 136 | code_point, utf8_bytes = decode_utf8_leading_char(utf8_bytes) 137 | code_points.append(code_point) 138 | return code_points 139 | 140 | 141 | if __name__ == "__main__": 142 | # Given string 143 | my_string = "€Hello" # The Euro symbol followed by "Hello" 144 | 145 | # Get UTF-8 encoded bytes 146 | utf8_bytes = my_string.encode("utf-8") 147 | 148 | assert utf8_bytes == b"\xe2\x82\xacHello" 149 | 150 | # Example usage with the Euro symbol followed by more characters 151 | code_point, remaining_bytes = decode_utf8_leading_char(utf8_bytes) 152 | 153 | print(f"Code Point: {code_point}") # Expected Output: 8364 (Euro symbol) 154 | print(f"Remaining Bytes: {remaining_bytes}") # Expected Output: b'Hello' 155 | 156 | # Example usage with the entire string 157 | code_points = decode_utf8_string(utf8_bytes) 158 | 159 | print( 160 | f"Code Points: {code_points}" 161 | ) # Expected Output: [8364, 72, 101, 108, 108, 111] 162 | 163 | print("-" * 50) 164 | 165 | # Example usage: 166 | utf8_bytes = b"\xe2\x82\xacHello" # UTF-8 encoded string (Euro symbol + "Hello") 167 | partial_start = PartialUTF8() # Assuming start with no partial sequence 168 | code_points, partial_utf8 = decode_utf8(utf8_bytes, partial_start) 169 | 170 | print("Code Points:", code_points) 171 | print("Remaining UTF-8 State:", partial_utf8.value, partial_utf8.n_remain) 172 | -------------------------------------------------------------------------------- /docs/debugging_custom_grammars.md: 
--------------------------------------------------------------------------------
 1 | # Debugging custom grammars
 2 | 
 3 | This document provides best practices for debugging custom grammars when using the `transformers_cfg` library. It offers strategies to help identify and resolve common issues during grammar creation or modification.
 4 | 
 5 | ## Table of contents
 6 | 
 7 | - [Introduction](#introduction)
 8 | - [Syntax highlighting](#syntax-highlighting)
 9 | - [EBNF and variants](#ebnf-and-variants)
 10 | - [Check parse](#check-parse)
 11 | - [Test with input](#test-with-input)
 12 | - [Debug mode](#debug-mode)
 13 | - [Tips and tricks](#tips-and-tricks)
 14 |   - [Incremental development](#incremental-development)
 15 |   - [Isolate grammar components](#isolate-grammar-components)
 16 |   - [Test with language model](#test-with-language-model)
 17 | 
 18 | ## Introduction
 19 | 
 20 | Context-free grammars (CFGs) involve complex syntax and semantics. This guide outlines strategies and tools to help debug custom grammars effectively. The `transformers_cfg` library uses EBNF notation for grammar definition and aligns with the grammar module of [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/grammars). For an introduction to EBNF, refer to the [llama.cpp documentation](https://github.com/ggerganov/llama.cpp/tree/master/grammars), where EBNF is referred to as `gbnf` for its integration with the project.
 21 | 
 22 | ## Syntax highlighting
 23 | 
 24 | The Visual Studio Code extension EBNF offers syntax highlighting for EBNF grammars.
 25 | 
 26 | 
*Figure 1: EBNF syntax highlighting*
 30 | 
 31 | For IDEs that use the Open VSX marketplace, such as Trae, the [W3C EBNF extension](https://open-vsx.org/extension/mfederczuk/w3c-ebnf) provides similar functionality. In JetBrains IDEs, the [Context Free Grammar plugin](https://plugins.jetbrains.com/plugin/10162-context-free-grammar) is available.
 32 | 
 33 | These extensions are third-party tools and are not affiliated with `transformers_cfg`. Use them at your own discretion, and feel free to suggest alternatives.
 34 | 
 35 | ## EBNF and variants
 36 | 
 37 | EBNF is a notation with several variants, each featuring slightly different syntax while preserving the underlying semantics.
 38 | 
 39 | The two major variants are:
 40 | 
 41 | - [ISO/IEC 14977](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form): The original standard for EBNF.
 42 | - [W3C EBNF](https://www.w3.org/TR/REC-xml/#sec-notation): The variant used in the W3C XML specification.
 43 | 
 44 | The EBNF variant in `transformers_cfg` aligns mostly with the W3C version, with one exception: the negation operator (`^`) is not yet supported but will be added in a future update.
 45 | 
 46 | ## Check parse
 47 | 
 48 | To verify that an EBNF grammar is correct, use the `transformers_cfg/parser::parse_ebnf` function. If Graphviz is installed, generate a parse tree by adding the `--graph` option.
 49 | 
 50 | ```terminal
 51 | python -m transformers_cfg.parser --grammar-file examples/grammars/your_grammar.ebnf
 52 | ```
 53 | 
 54 | Example output for the Japanese grammar (`examples/grammars/japanese.ebnf`) is:
 55 | 
 56 | ```terminal
 57 | Grammar Rules:
 58 | <0>root_2 ::= <2>jp-char <4>root_2 | <8>jp-char
 59 | <12>root_4 ::= <14>jp-char <16>root_4 | <20>jp-char
 60 | <24>root_3 ::= <26>[ - -
 61 | -
 62 | ] <33>root_4
 63 | <37>root_5 ::= <39>root_3 <41>root_5 | 
 64 | <47>root ::= <49>root_2 <51>root_5
 65 | <55>jp-char ::= <57>hiragana | <61>katakana | <65>punctuation | <69>cjk
 66 | <73>hiragana ::= <75>[ぁ-ゟ]
 67 | <80>katakana ::= <82>[ァ-ヿ]
 68 | <87>punctuation ::= <89>[、-〾]
 69 | <94>cjk ::= <96>[一-鿿]
 70 | 
 71 | Grammar Hex representation:
 72 | 0002 0005 0001 0001 0001 0002 0000 0003 0001 0001 0000 0000 0004 0005 0001 0001 0001 0004 0000 0003 0001 0001 0000 0000 0003 000a 0006 0020 0020 0009 0009 000a 000a 0001 0004 0000 0000 0005 0005 0001 0003 0001 0005 0000 0001 0000 0000 0000 0005 0001 0002 0001 0005 0000 0000 0001 0003 0001 0006 0000 0003 0001 0007 0000 0003 0001 0008 0000 0003 0001 0009 0000 0000 0006 0004 0002 3041 309f 0000 0000 0007 0004 0002 30a1 30ff 0000 0000 0008 0004 0002 3001 303e 0000 0000 0009 0004 0002 4e00 9fff 0000 0000 ffff
 73 | 
 74 | Rules Decimal representation:
 75 | <2> [[5, 1, 1, 1, 2, 0], [3, 1, 1, 0]]
 76 | <4> [[5, 1, 1, 1, 4, 0], [3, 1, 1, 0]]
 77 | <3> [[10, 6, 32, 32, 9, 9, 10, 10, 1, 4, 0]]
 78 | <5> [[5, 1, 3, 1, 5, 0], [1, 0]]
 79 | <0> [[5, 1, 2, 1, 5, 0]]
 80 | <1> [[3, 1, 6, 0], [3, 1, 7, 0], [3, 1, 8, 0], [3, 1, 9, 0]]
 81 | <6> [[4, 2, 12353, 12447, 0]]
 82 | <7> [[4, 2, 12449, 12543, 0]]
 83 | <8> [[4, 2, 12289, 12350, 0]]
 84 | <9> [[4, 2, 19968, 40959, 0]]
 85 | symbol_ids:
 86 | {'root': 0, 'jp-char': 1, 'root_2': 2, 'root_3': 3, 'root_4': 4, 'root_5': 5, 'hiragana': 6, 'katakana': 7, 'punctuation': 8, 'cjk': 9}
 87 | ```
 88 | 
 89 | A successful parse confirms the grammar is syntactically correct.
 90 | 
 91 | 
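The visualization below was generated from the arithmetic grammar. A command along the following lines produces it (a sketch assuming Graphviz is available on the system; the output location may differ):

```terminal
python -m transformers_cfg.parser --grammar-file examples/grammars/arithmetic.ebnf --graph
```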
*Figure 2: Graph visualization of the arithmetic grammar*
 95 | 
 96 | ## Test with input
 97 | 
 98 | After verifying that the grammar can be parsed, test it with a simple input to confirm the expected output. The following script demonstrates this process:
 99 | 
 100 | ```python
 101 | from transformers_cfg.parser import parse_ebnf
 102 | from transformers_cfg.recognizer import StringRecognizer
 103 | 
 104 | with open("examples/grammars/json.ebnf", "r") as file:
 105 |     input_text = file.read()
 106 | parsed_grammar = parse_ebnf(input_text)
 107 | 
 108 | start_rule_id = parsed_grammar.symbol_table["root"]
 109 | recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
 110 | 
 111 | # Test the grammar with a simple input.
 112 | json_input = '{"foo": "bar", "baz": "bat"}'
 113 | is_accepted = recognizer._accept_prefix(json_input)
 114 | print(is_accepted)
 115 | ```
 116 | 
 117 | If the script prints `True`, the grammar accepts the input as a valid prefix of a sentence. A result of `False` means the parser rejects the input at some point. To identify the failure point, try testing with a partial input:
 118 | 
 119 | ```python
 120 | json_input = '{"foo": "bar"'
 121 | is_accepted = recognizer._accept_prefix(json_input)
 122 | print(is_accepted)
 123 | ```
 124 | 
 125 | To verify that the input string is a complete sentence of the grammar, not just a prefix, use the `_accept_string` method, which returns `True` for a complete string and `False` otherwise.
 126 | 
 127 | ## Debug mode
 128 | 
 129 | Enable debug mode to observe the parsing process in detail by setting the environment variable:
 130 | 
 131 | ```bash
 132 | export TCFG_LOG_LEVEL=DEBUG
 133 | ```
 134 | 
 135 | The output will log each accepted code point. For example:
 136 | 
 137 | ```terminal
 138 | DEBUG:root:code point [123] corresponding to { is accepted
 139 | DEBUG:root:code point [123, 34] corresponding to " is accepted
 140 | ...
 141 | DEBUG:root:code point [123, 34, 102, 111, 111, 34, 58, 32, 34, 98, 97, 116, 34, 125] corresponding to } is accepted
 142 | ```
 143 | 
 144 | This log assists in identifying where the parser accepts or rejects input characters.
 145 | 
 146 | ## Tips and tricks
 147 | 
 148 | ### Incremental development
 149 | 
 150 | Begin with a minimal grammar rule and gradually add more rules. This incremental approach simplifies error detection as the grammar evolves.
 151 | 
 152 | ### Isolate grammar components
 153 | 
 154 | If the grammar does not behave as expected, isolate individual components to determine the source of the issue. Remove or comment out parts of the grammar and reintroduce them gradually until the problem is identified.
 155 | 
 156 | ### Test with language model
 157 | 
 158 | Once the grammar is confirmed to be correct, remaining issues likely pertain to other aspects of the system. Testing with a language model is important at this stage, although it falls outside the scope of grammar verification; a minimal sketch is shown below.
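As a reference point, the example scripts under `examples/` all follow the same pattern, which doubles as a minimal end-to-end check (a sketch assuming a small model such as `gpt2`; substitute your own model ID, prompt, and grammar path):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Load a small model and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Build the grammar constraint from the EBNF file under test.
with open("examples/grammars/json.ebnf", "r") as file:
    grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Generate with the grammar constraint applied at every decoding step.
input_ids = tokenizer(["Here is a JSON object: "], return_tensors="pt")["input_ids"]
output = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=30,
    logits_processor=[grammar_processor],
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

If the grammar passes the checks above but constrained generation still misbehaves, the cause is more likely in tokenization or decoding settings than in the grammar itself.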
159 | -------------------------------------------------------------------------------- /transformers_cfg/tokenization/mapping/token2byte.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from transformers_cfg.tokenization.SUPPORTED_TOKENIZERS import SUPPORTED_TOKENIZERS 5 | from .ByteProxyMapping import ByteProxyMapping, LLAMAByteProxyMapping 6 | import logging 7 | from transformers import ( 8 | GPT2TokenizerFast, 9 | BartTokenizerFast, 10 | T5TokenizerFast, 11 | CodeGenTokenizerFast, 12 | LlamaTokenizerFast, 13 | PreTrainedTokenizerFast, 14 | GemmaTokenizerFast, 15 | Qwen2TokenizerFast, 16 | ByT5Tokenizer, 17 | ) 18 | 19 | from transformers_cfg.tokenization.utils import get_tokenizer_charset 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | class Token2ByteMapping(ABC): 25 | def __init__(self, tokenizer): 26 | self.eos_token_id = tokenizer.eos_token_id 27 | self.bos_token_id = tokenizer.bos_token_id 28 | self.tokenizer = tokenizer 29 | self.special = tokenizer.all_special_ids 30 | self._length = len(self.tokenizer.get_vocab()) 31 | 32 | def __len__(self): 33 | return self._length 34 | 35 | @abstractmethod 36 | def map(self, token_id: int, verbose=False) -> bytes: 37 | pass 38 | 39 | @classmethod 40 | def from_hf_tokenizer(cls, hf_tokenizer): 41 | assert ( 42 | type(hf_tokenizer) in SUPPORTED_TOKENIZERS 43 | ), f"Tokenizer not supported: {hf_tokenizer.__class__.__name__}, supported tokenizers: {SUPPORTED_TOKENIZERS}" 44 | if isinstance( 45 | hf_tokenizer, 46 | ( 47 | GPT2TokenizerFast, 48 | BartTokenizerFast, 49 | CodeGenTokenizerFast, 50 | Qwen2TokenizerFast, 51 | ), 52 | ): 53 | return GPT2Token2ByteMapping(hf_tokenizer) 54 | elif isinstance(hf_tokenizer, (LlamaTokenizerFast, GemmaTokenizerFast)): 55 | # deepseek, though inheriting from LlamaTokenizerFast, is actually a GPT2TokenizerFast 56 | # check https://github.com/epfl-dlab/transformers-CFG/issues/72 57 | if "deepseek-coder" in hf_tokenizer.name_or_path: 58 | return GPT2Token2ByteMapping(hf_tokenizer) 59 | return LLAMA1Token2ByteMapping(hf_tokenizer) 60 | elif isinstance(hf_tokenizer, T5TokenizerFast): 61 | return T5Token2ByteMapping(hf_tokenizer) 62 | elif ( 63 | isinstance(hf_tokenizer, PreTrainedTokenizerFast) 64 | and "Llama-3" 65 | in hf_tokenizer.name_or_path # this includes llama-3/llama-3.1/llama-3.2/llama-3.3 66 | ): 67 | return GPT2Token2ByteMapping(hf_tokenizer) 68 | elif isinstance(hf_tokenizer, ByT5Tokenizer): 69 | return ByT5Token2ByteMapping(hf_tokenizer) 70 | else: 71 | raise NotImplementedError( 72 | f"Tokenizer not supported: {hf_tokenizer.__class__.__name__}" 73 | ) 74 | 75 | @staticmethod 76 | def auto_infer(hf_tokenizer): 77 | "beta version, not sure if it will work for all cases" 78 | charset = get_tokenizer_charset(hf_tokenizer) 79 | size = len(charset) 80 | if size >= 256 and size < 256 + 30: 81 | return GPT2Token2ByteMapping(hf_tokenizer) 82 | elif "▁" in charset: 83 | return LLAMA1Token2ByteMapping(hf_tokenizer) 84 | else: 85 | raise NotImplementedError( 86 | f"Tokenizer not supported: {hf_tokenizer.__class__.__name__}" 87 | ) 88 | 89 | 90 | class GPT2Token2ByteMapping(Token2ByteMapping): 91 | def __init__(self, tokenizer): 92 | super().__init__(tokenizer) 93 | self.byte_proxy_mapping = ByteProxyMapping(tokenizer) 94 | 95 | def map2proxy_token(self, token_id: int) -> str: 96 | # This is the case for BOS, 97 | if token_id in self.special: 98 | return "" 99 | # if token_id is tensor, convert it to int 
100 | if isinstance(token_id, torch.Tensor): 101 | token_id = token_id.item() 102 | proxy_token = self.tokenizer.convert_ids_to_tokens(token_id) 103 | return proxy_token 104 | 105 | def map(self, token_id: int, verbose=False) -> bytes: 106 | proxy_token = self.map2proxy_token(token_id) 107 | if verbose: 108 | log.debug(f"token_id: {token_id}, token: {proxy_token}") 109 | 110 | return self.byte_proxy_mapping.map(proxy_token) 111 | 112 | 113 | class LLAMA1Token2ByteMapping(Token2ByteMapping): 114 | def __init__(self, tokenizer): 115 | super().__init__(tokenizer) 116 | self.last_token_id = None 117 | self.byte_proxy_mapping = LLAMAByteProxyMapping() 118 | 119 | def map(self, token_id: int, verbose=False) -> bytes: 120 | # we need to check if the token is at the beginning of the sentence to remove the space 121 | # specific to BPE 122 | at_bos = False 123 | if self.last_token_id is not None and self.last_token_id == self.bos_token_id: 124 | at_bos = True 125 | self.last_token_id = token_id 126 | 127 | # This is the case for BOS, 128 | if token_id in self.special: 129 | return b"" 130 | # if token_id is tensor, convert it to int 131 | if isinstance(token_id, torch.Tensor): 132 | token_id = token_id.item() 133 | proxy_token = self.tokenizer.convert_ids_to_tokens(token_id) 134 | 135 | token_bytes = self.byte_proxy_mapping.map(proxy_token) 136 | 137 | # check if the first byte is a space 138 | if token_bytes[0] == 32 and at_bos: 139 | # remove space at the beginning of the sentence 140 | token_bytes = token_bytes[1:] 141 | 142 | return token_bytes 143 | 144 | 145 | class T5Token2ByteMapping(Token2ByteMapping): 146 | def __init__(self, tokenizer): 147 | super().__init__(tokenizer) 148 | self.at_bos = True 149 | self.byte_proxy_mapper = LLAMAByteProxyMapping() 150 | 151 | def map(self, token_id: int, verbose=False) -> bytes: 152 | # we need to check if the token is at the beginning of the sentence to remove the space 153 | # specific to BPE 154 | 155 | # This is the case for BOS, 156 | if token_id in self.special: 157 | self.at_bos = False 158 | return b"" 159 | # if token_id is tensor, convert it to int 160 | if isinstance(token_id, torch.Tensor): 161 | token_id = token_id.item() 162 | proxy_token = self.tokenizer.convert_ids_to_tokens(token_id) 163 | 164 | token_bytes = self.byte_proxy_mapper.map(proxy_token) 165 | 166 | # check if the first byte is a space 167 | if token_bytes[0] == 32 and self.at_bos: 168 | # remove space at the beginning of the sentence 169 | token_bytes = token_bytes[1:] 170 | 171 | self.at_bos = False 172 | return token_bytes 173 | 174 | 175 | class ByT5Token2ByteMapping(Token2ByteMapping): 176 | def __init__(self, tokenizer): 177 | super().__init__(tokenizer) 178 | self.token_id_to_bytes = {} 179 | 180 | def map(self, token_id: int, verbose=False) -> bytes: 181 | # By inspecting the token vocab, we can see that the first 3 tokens are special tokens 182 | # and the tokens after 258 are also special tokens 183 | # only the tokens between 3 and 258 are valid tokens, 256 bytes 184 | if isinstance(token_id, torch.Tensor): 185 | token_id = token_id.item() 186 | if token_id in self.token_id_to_bytes: 187 | return self.token_id_to_bytes[token_id] 188 | if 3 <= token_id <= 258: 189 | return ord(self.tokenizer.convert_ids_to_tokens(token_id)).to_bytes( 190 | 1, "big" 191 | ) 192 | else: 193 | # return empty bytes for special tokens 194 | return bytes() 195 | -------------------------------------------------------------------------------- /examples/generate_pddl.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 5 | from transformers_cfg.recognizer import StringRecognizer 6 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 7 | from transformers_cfg.parser import parse_ebnf 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Generate PDDL strings") 12 | parser.add_argument( 13 | "--model-id", 14 | type=str, 15 | default="/dlabdata1/llm_hub/Mistral-7B-v0.1", 16 | help="Model ID", 17 | ) 18 | parser.add_argument("--device", type=str, help="Device to put the model on") 19 | parser.add_argument( 20 | "--pddl-type", 21 | type=str, 22 | choices=["blocks", "depot", "satellite", "depot_typed", "satellite_typed"], 23 | default="blocks", 24 | help="Type of PDDL to generate", 25 | ) 26 | return parser.parse_args() 27 | 28 | 29 | one_shot_prompts = { 30 | "blocks": "(put-down a) (unstack-and-stack c b d) (pick-up-and-stack b c)", 31 | "depot": "(drive truck0 depot0 distributor0) (lift-and-drive truck0 hoist0 crate0 pallet0 depot0 depot0) (lift hoist2 crate2 crate1 distributor1)", 32 | "satellite": "(switch-on instrument1 satellite3) (turn-to satellite1 direction4 direction0)", 33 | } 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | model_id = args.model_id 39 | 40 | # Detect if GPU is available, otherwise use CPU 41 | device = torch.device( 42 | args.device or ("cuda" if torch.cuda.is_available() else "cpu") 43 | ) 44 | print(f"Using device: {device}") 45 | 46 | # Load model and tokenizer 47 | tokenizer = AutoTokenizer.from_pretrained(model_id) 48 | tokenizer.pad_token = tokenizer.eos_token 49 | # Load model to defined device 50 | model = AutoModelForCausalLM.from_pretrained(model_id).to(device) 51 | 52 | # Load grammar 53 | with open(f"examples/grammars/PDDL/{args.pddl_type}.ebnf", "r") as file: 54 | grammar_str = file.read() 55 | 56 | parsed_grammar = parse_ebnf(grammar_str) 57 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 58 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 59 | 60 | # Generate 61 | pddl_domain = args.pddl_type.split("_")[0] 62 | prompts = [ 63 | f"Give me two examples of the {pddl_domain} command sequence:\n" 64 | + f"1. {one_shot_prompts[pddl_domain]}\n2. 
" 65 | ] 66 | 67 | input_ids = tokenizer( 68 | prompts, add_special_tokens=False, return_tensors="pt", padding=True 69 | )["input_ids"].to( 70 | device 71 | ) # Move input_ids to the same device as model 72 | 73 | n_examples = input_ids.shape[0] 74 | 75 | max_new_tokens = 50 76 | unconstrained_output = model.generate( 77 | input_ids, 78 | do_sample=False, 79 | max_new_tokens=max_new_tokens, 80 | repetition_penalty=1.9, 81 | num_return_sequences=1, 82 | ) 83 | constrained_output = model.generate( 84 | input_ids, 85 | do_sample=False, 86 | max_new_tokens=max_new_tokens, 87 | logits_processor=[grammar_processor], 88 | repetition_penalty=1.9, 89 | num_return_sequences=1, 90 | ) 91 | 92 | parsed_grammar = parse_ebnf(grammar_str) 93 | string_grammar = StringRecognizer( 94 | parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] 95 | ) 96 | 97 | # decode outputs (possibly of different lengths across decoding modes) 98 | generations = tokenizer.batch_decode( 99 | unconstrained_output, skip_special_tokens=True 100 | ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) 101 | print() 102 | 103 | for i in range(n_examples): 104 | unconstrained_generation = generations[i] 105 | constrained_generation = generations[i + n_examples] 106 | prompt = prompts[i] 107 | 108 | for generation, generation_type in zip( 109 | [unconstrained_generation, constrained_generation], 110 | ["unconstrained", "constrained"], 111 | ): 112 | print(f"The {generation_type} generation:\n{generation}") 113 | print( 114 | f"The {generation_type} generation is a valid prefix for the grammar: {string_grammar._accept_prefix(generation[len(prompt):])}" 115 | ) 116 | print( 117 | f"The {generation_type} generation is a valid sentence for the grammar: {string_grammar._accept_string(generation[len(prompt):])}" 118 | ) 119 | print() 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | 125 | 126 | ########################## 127 | # Example output: 128 | # 129 | # BLOCKS: 130 | # 131 | # The unconstrained generation: 132 | # Give me two examples of the blocks command sequence: 133 | # 1. (put-down a) (unstack-and-stack c b d) (pick-up-and-stack b c) 134 | # 2. 30,45,(move left),(turn right)(go forward). The first example is an action that can be performed by any robot with three arms and four objects in its workspace; it does not depend on what those particular arm/object 135 | # The unconstrained generation is a valid prefix for the grammar: False 136 | # The unconstrained generation is a valid sentence for the grammar: False 137 | 138 | # The constrained generation: 139 | # Give me two examples of the blocks command sequence: 140 | # 1. (put-down a) (unstack-and-stack c b d) (pick-up-and-stack b c) 141 | # 2. (pick-up e) (put-down e) (unstack-and-stack b c d) (put-down a) (unstack-and-stack c b 142 | # The constrained generation is a valid prefix for the grammar: True 143 | # The constrained generation is a valid sentence for the grammar: False 144 | # 145 | # DEPOT: 146 | # The unconstrained generation: 147 | # Give me two examples of the depot command sequence: 148 | # 1. (drive truck0 depot0 distributor0) (lift-and-drive truck0 hoist0 crate0 pallet0 depot0 depot0) (lift hoist2 crate2 crate1 distributor1) 149 | # 2. 3567894(move robot arm to position A)(pick up object from table and place it on shelf B). 
The first example is a simple one, where we have three trucks that are going back into their respective parking spots after 150 | # The unconstrained generation is a valid prefix for the grammar: False 151 | # The unconstrained generation is a valid sentence for the grammar: False 152 | 153 | # The constrained generation: 154 | # Give me two examples of the depot command sequence: 155 | # 1. (drive truck0 depot0 distributor0) (lift-and-drive truck0 hoist0 crate0 pallet0 depot0 depot0) (lift hoist2 crate2 crate1 distributor1) 156 | # 2. (load crate3 crate4 distributor1 distributor1) (unload truck0 pallet5 depot0 distributor1) (drop truck 157 | # The constrained generation is a valid prefix for the grammar: True 158 | # The constrained generation is a valid sentence for the grammar: False 159 | 160 | # SATELLITE: 161 | # The unconstrained generation: 162 | # Give me two examples of the satellite command sequence: 163 | # 1. (switch-on instrument1 satellite3) (turn-to satellite1 direction4 direction0) 164 | # 2. ......(move to position5 distance6 angle7 )... The first example is a simple one, but it shows how we can use satellites as instruments and also move them around in space using their own commands for movement or orientation change etc., 165 | # The unconstrained generation is a valid prefix for the grammar: False 166 | # The unconstrained generation is a valid sentence for the grammar: False 167 | 168 | # The constrained generation: 169 | # Give me two examples of the satellite command sequence: 170 | # 1. (switch-on instrument1 satellite3) (turn-to satellite1 direction4 direction0) 171 | # 2. (take-image satellite1 instrument5 direction0 direction1) (calibrate instrument5 satellite1 direction0) (switch-off instrument5 satell 172 | # The constrained generation is a valid prefix for the grammar: True 173 | # The constrained generation is a valid sentence for the grammar: False 174 | 175 | ########################## 176 | -------------------------------------------------------------------------------- /examples/grammars/geo_query.ebnf: -------------------------------------------------------------------------------- 1 | root ::= "answer(" answer_type ")" 2 | 3 | answer_type ::= city | state | num | place | river | country 4 | 5 | coma_sep ::= " "* "," " "* 6 | 7 | city ::= "city(" city ")" | 8 | "cityid('" CITYNAME "', '" STATEABBREV "')" | 9 | "cityid('" CITYNAME "', _)" | 10 | "capital(" city ")" | 11 | "major(" city ")" | 12 | "capital_1(" state ")" | 13 | "loc_2(" state ")" | 14 | "loc_2(" country ")" | 15 | "largest(" city ")" | 16 | "smallest(" city ")" | 17 | "intersection(" city coma_sep city ")" | 18 | "exclude(" city coma_sep city ")" | 19 | "largest_one(population_1(" city "))" | 20 | "largest_one(density_1(" city "))" | 21 | "smallest_one(population_1(" city "))" | 22 | "smallest_one(density_1(" city "))" | 23 | ALL_CITY 24 | 25 | place ::= "placeid('" PLACENAME "')" | 26 | "lake(" place ")" | 27 | "mountain(" place ")" | 28 | "place(" place ")" | 29 | "high_point_1(" state ")" | 30 | "low_point_1(" state ")" | 31 | "higher_2(" place ")" | 32 | "lower_2(" place ")" | 33 | "lowest(" place ")" | 34 | "highest(" place ")" | 35 | "capital(" place ")" | 36 | "loc_2(" city ")" | 37 | "loc_2(" state ")" | 38 | "loc_2(" country ")" | 39 | "major(" place ")" | 40 | "elevation_2(" NUM ")" | 41 | "exclude(" place coma_sep place ")" | 42 | ALL_PLACE 43 | 44 | river ::= "river(" river ")" | 45 | "riverid('" RIVERNAME "')" | 46 | "major(" river ")" | 47 | "loc_2(" country ")" | 48 | 
"loc_2(" state ")" | 49 | "longer(" river ")" | 50 | "traverse_2(" city ")" | 51 | "traverse_2(" state ")" | 52 | "traverse_2(" country ")" | 53 | "longest(" river ")" | 54 | "shortest(" river ")" | 55 | "most(" river ")" | 56 | "fewest(" river ")" | 57 | "intersection(" river coma_sep river ")" | 58 | "exclude(" river coma_sep river ")" | 59 | ALL_RIVER 60 | 61 | state ::= "state(" state ")" | 62 | "stateid('" STATENAME "')" | 63 | "capital_2(" city ")" | 64 | "highest_point_2(" place ")" | 65 | "loc_1(" place ")" | 66 | "loc_1(" city ")" | 67 | "loc_1(" river ")" | 68 | "loc_2(" country ")" | 69 | "next_to_1(" state ")" | 70 | "next_to_2(" state ")" | 71 | "next_to_2(" river ")" | 72 | "traverse_1(" river ")" | 73 | "largest(" state ")" | 74 | "largest_one(area_1(" state "))" | 75 | "largest_one(density_1(" state "))" | 76 | "largest_one(population_1(" state "))" | 77 | "smallest_one(area_1(" state "))" | 78 | "smallest_one(density_1(" state "))" | 79 | "smallest_one(population_1(" state "))" | 80 | "highest(" state ")" | 81 | "lowest(" state ")" | 82 | "most(" state ")" | 83 | "fewest(" state ")" | 84 | "smallest(" state ")" | 85 | "high_point_2(" place ")" | 86 | "low_point_2(" place ")" | 87 | "intersection(" state coma_sep state ")" | 88 | "exclude(" state coma_sep state ")" | 89 | ALL_STATE 90 | 91 | num ::= NUM | 92 | "area_1(" city ")" | 93 | "area_1(" state ")" | 94 | "area_1(" country ")" | 95 | "density_1(" city ")" | 96 | "density_1(" state ")" | 97 | "density_1(" country ")" | 98 | "elevation_1(" place ")" | 99 | "len(" river ")" | 100 | "population_1(" state ")" | 101 | "population_1(" city ")" | 102 | "population_1(" country ")" | 103 | "size(" state ")" | 104 | "size(" city ")" | 105 | "count(" city ")" | 106 | "count(" state ")" | 107 | "count(" river ")" | 108 | "sum(" num ")" | 109 | "smallest(" num ")" 110 | 111 | country ::= "countryid('" COUNTRYNAME "')" | "loc_1(" state ")" 112 | 113 | ALL_STATE ::= "state(all)" 114 | 115 | ALL_CITY ::= "city(all)" | "capital(all)" 116 | 117 | ALL_PLACE ::= "place(all)" | "mountain(all)" | "lake(all)" 118 | 119 | ALL_RIVER ::= "river(all)" 120 | 121 | NUM ::= "0" 122 | 123 | CITYNAME ::= "new york" | 124 | "guadalupe peak" | 125 | "durham" | 126 | "tempe" | 127 | "sacramento" | 128 | "albany" | 129 | "rochester" | 130 | "salem" | 131 | "portland" | 132 | "miami" | 133 | "san diego" | 134 | "spokane" | 135 | "erie" | 136 | "austin" | 137 | "new orleans" | 138 | "dallas" | 139 | "boulder" | 140 | "plano" | 141 | "fort wayne" | 142 | "boston" | 143 | "springfield" | 144 | "seattle" | 145 | "dover" | 146 | "minneapolis" | 147 | "denver" | 148 | "tucson" | 149 | "montgomery" | 150 | "san jose" | 151 | "atlanta" | 152 | "salt lake city" | 153 | "kalamazoo" | 154 | "flint" | 155 | "chicago" | 156 | "indiannapolis" | 157 | "pittsburgh" | 158 | "scotts valley" | 159 | "baton rouge" | 160 | "riverside" | 161 | "san francisco" | 162 | "des moines" | 163 | "columbus" | 164 | "houston" | 165 | "detroit" | 166 | "washington" | 167 | "indianapolis" | 168 | "san antonio" 169 | 170 | PLACENAME ::= "death valley" | "mount whitney" | "guadalupe peak" | "mount mckinley" 171 | 172 | STATEABBREV ::= "ga" | "wa" | "mn" | "az" | "pa" | "dc" | "tx" | "sd" | "me" | "ma" | "mo" 173 | 174 | STATENAME ::= "north carolina" | 175 | "washington" | 176 | "minnesota" | 177 | "florida" | 178 | "virginia" | 179 | "arkansas" | 180 | "iowa" | 181 | "california" | 182 | "delaware" | 183 | "new jersey" | 184 | "rhode island" | 185 | "nevada" | 186 | "nebraska" | 187 | 
"indiana" | 188 | "wisconsin" | 189 | "oklahoma" | 190 | "new mexico" | 191 | "idaho" | 192 | "ohio" | 193 | "montana" | 194 | "arizona" | 195 | "louisiana" | 196 | "tennessee" | 197 | "pennsylvania" | 198 | "new hampshire" | 199 | "south carolina" | 200 | "michigan" | 201 | "utah" | 202 | "vermont" | 203 | "kansas" | 204 | "oregon" | 205 | "wyoming" | 206 | "maryland" | 207 | "alaska" | 208 | "georgia" | 209 | "mississippi" | 210 | "illinois" | 211 | "texas" | 212 | "south dakota" | 213 | "north dakota" | 214 | "alabama" | 215 | "kentucky" | 216 | "hawaii" | 217 | "maine" | 218 | "west virginia" | 219 | "colorado" | 220 | "new york" | 221 | "massachusetts" | 222 | "missouri" 223 | 224 | COUNTRYNAME ::= "usa" 225 | 226 | RIVERNAME ::= "ohio" | "north platte" | "red" | "chattahoochee" | "mississippi" | "potomac"| "colorado" | "missouri" | "rio grande" | "delaware" 227 | -------------------------------------------------------------------------------- /tests/test_accept_string/test_smiles.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers_cfg.parser import parse_ebnf 3 | from transformers_cfg.recognizer import StringRecognizer 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class MoleculeTestCase: 9 | name: str 10 | molecule: str 11 | 12 | 13 | valid_smiles_sentences = [ 14 | MoleculeTestCase("simple_atom", "C"), 15 | MoleculeTestCase("single_bond_no_hyphen", "CC"), 16 | MoleculeTestCase("double_bond", "C=O"), 17 | MoleculeTestCase("dot", "C.O"), 18 | MoleculeTestCase("radical", "CC(C)C"), 19 | MoleculeTestCase("isotope", "[14c]"), 20 | MoleculeTestCase("aromatic_no_hyphen", "C1CC1"), 21 | MoleculeTestCase("interleaved_cycle_explicit", "C1=CC=CC=C1"), 22 | MoleculeTestCase("interleaved_cycle_colon", "C1:C:C1"), 23 | MoleculeTestCase("interleaved_cycle_lower_case", "c1cc1"), 24 | MoleculeTestCase("cis_bond_right", "F/C=C\\F"), 25 | MoleculeTestCase("trans_bond_left", "F\\C=C\\F"), 26 | MoleculeTestCase("d_alanine", "C[C@H](C(=O)O)N"), 27 | MoleculeTestCase("l_alanine", "C[C@@H](C(=O)O)N"), 28 | MoleculeTestCase("nested_cycles", "C12(CCC1)CCC2"), 29 | MoleculeTestCase("charge", "[Cu+2].[O-]S(=O)[O-]"), 30 | MoleculeTestCase("mix_of_cases", "Cc(cc1)ccc1C#N"), 31 | MoleculeTestCase("mix_of_bonds_and_cycles", "C1CC/C=C1/C=C/C=C/C2=C(C)/CCC2"), 32 | MoleculeTestCase("wildcard", "Oc1c(*)cccc1"), 33 | ] 34 | 35 | valid_smiles_prefixes = [ 36 | MoleculeTestCase("empty_string", ""), 37 | MoleculeTestCase("simple_atom_dangling_bond", "C#"), 38 | MoleculeTestCase("unbalanced_paranthesis", "C(C(C)"), 39 | # Failure cases of SMILES in general 40 | # MoleculeTestCase("lowercase_outside_cycle", "c"), 41 | # MoleculeTestCase("unclosed cycle", "C1CC/C=C1/C=C/C=C/C2=C(C)/CCC"), 42 | # MoleculeTestCase("unclosed cycle", "C1CCC"), 43 | ] 44 | 45 | invalid_smiles_sentences = [ 46 | MoleculeTestCase("fake_atom", "L"), 47 | MoleculeTestCase("fake_molecule_in_brackets", "[Xx]"), 48 | MoleculeTestCase("bond_outside_parentheses", "CCC=(O)O"), 49 | MoleculeTestCase("double_double_bond", "C==C"), 50 | MoleculeTestCase("empty_paranthesis", "()"), 51 | MoleculeTestCase("invalid_charge", "[Cu+20].[O-]S(=O)(=O)[O-]"), 52 | # Failure cases of SMILES in general 53 | # MoleculeTestCase("two_bonds_same_atom", "C12C2CCC1"), 54 | # MoleculeTestCase("self-bond", "C11"), 55 | ] 56 | 57 | valid_isocyanite_sentences = [ 58 | MoleculeTestCase("short_isocyanate", "O=C=NCCCCCCN=C=O"), 59 | MoleculeTestCase("right_group", 
"CC1=C(C=C(C=C1)CN=C=O)N=C=O"), 60 | MoleculeTestCase("trans_bond_right", "Cc1ccc(cc1\\N=C=O)\\N=C=O"), 61 | MoleculeTestCase("group_radical", "O=C=NC1CCC(CC2CCC(CC2)N=C=O)CC1"), 62 | MoleculeTestCase("trans_bond_left", "O=C=N\\C1CC(C\\N=C=O)(CC(C1)(C)C)C"), 63 | MoleculeTestCase("trans_bond", "O=C=N\\CCCCCC/N=C=O"), 64 | MoleculeTestCase("group_radicals", "CCOC(C(N=C=O)CCCCN=C=O)=O"), 65 | MoleculeTestCase( 66 | "simple_atom", "O=C=NC1=CC=CC(CC2=CC=C(C=C2N=C=O)CC3=CC=C(C=C3)N=C=O)=C1" 67 | ), 68 | MoleculeTestCase( 69 | "single_bond_no_hyphen", 70 | "O=C=NC1=CC(CC2=C(C=C(C=C2)CC3=CC=C(C=C3N=C=O)CC4=CC=C(C=C4)N=C=O)N=C=O)=CC=C1", 71 | ), 72 | MoleculeTestCase( 73 | "double_bond", 74 | "O=C=NC1=CC=C(C=C1)CC2=CC=C(C=C2N=C=O)CC3=C(C=C(C=C3)CC4=CC=C(C=C4N=C=O)CC5=CC=C(C=C5)N=C=O)N=C=O", 75 | ), 76 | MoleculeTestCase("interleaved_cycle_explicit", "CC1(CC(CC(CN=C=O)(C1)C)N=C=O)C"), 77 | MoleculeTestCase("interleaved_cycle_colon", "CC1=C(C=C(C=C1)CN=C=O)N=C=O"), 78 | MoleculeTestCase("cycles", "O=C=N\\c1ccc(cc1)Cc2ccc(\\N=C=O)cc2"), 79 | ] 80 | 81 | valid_acrylate_sentences = [ 82 | MoleculeTestCase("simple_acrylate", "COC(=O)C=C"), 83 | MoleculeTestCase("simple_acrylate", "C=CC(=O)OC1=CC=CC=C1"), 84 | MoleculeTestCase("simple_acrylate_group_variation", "CC(=C)C(=O)OC1=CC=CC=C1"), 85 | MoleculeTestCase("", "C=CC(=O)OCCC1=CC=CC=C1"), 86 | MoleculeTestCase("", "CCC(C)OC(=O)C=C"), 87 | MoleculeTestCase("", "C=CC(=O)OC1=C(C(=C(C(=C1F)F)F)F)F"), 88 | MoleculeTestCase("", "CC(C)COC(=O)C(=C)C"), 89 | MoleculeTestCase("", "CCC(C)OC(=O)C(=C)C"), 90 | MoleculeTestCase("", "CCCOC(=O)C(=C)C"), 91 | MoleculeTestCase("", "CC1CC(CC(C1)(C)C)OC(=O)C(=C)C"), 92 | MoleculeTestCase("", "CCCOC(=O)C=C"), 93 | MoleculeTestCase("", "COCCOC(=O)C=C"), 94 | MoleculeTestCase("", "CC(=C)C(=O)OCCOC1=CC=CC=C1"), 95 | MoleculeTestCase("", "CCCCCCOC(=O)C=C"), 96 | MoleculeTestCase("", "CCCCOCCOC(=O)C(=C)C"), 97 | MoleculeTestCase("", "CC(=C)C(=O)OC"), 98 | MoleculeTestCase("", "CCCCOC(=O)C=C"), 99 | MoleculeTestCase("", "CCOCCOC(=O)C(=C)C"), 100 | MoleculeTestCase("", "CC(=C)C(=O)OC1CC2CCC1(C2(C)C)C"), 101 | MoleculeTestCase("", "CCCCC(CC)COC(=O)C(=C)C"), 102 | MoleculeTestCase("", "CC(C)(COCCCOC(=O)C=C)COCCCOC(=O)C=C"), 103 | MoleculeTestCase("", "C=CC(=O)OCCCCCCOC(=O)C=C"), 104 | MoleculeTestCase("", "C=CC(=O)OCC(CO)(COC(=O)C=C)COC(=O)C=C"), 105 | MoleculeTestCase("", "CCC(COCCCOC(=O)C=C)(COCCCOC(=O)C=C)COCCCOC(=O)C=C"), 106 | MoleculeTestCase("", "CCC(COCC(CC)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C"), 107 | MoleculeTestCase( 108 | "", "C=CC(=O)OCC(CO)(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)COC(=O)C=C" 109 | ), 110 | MoleculeTestCase( 111 | "", "C=CC(=O)OCC(COCC(COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C)(COC(=O)C=C)COC(=O)C=C" 112 | ), 113 | ] 114 | 115 | valid_chain_extender_sentences = [ 116 | MoleculeTestCase("simplest_chain_extender", "OCCO"), 117 | MoleculeTestCase("", "OC(C)CCO"), 118 | MoleculeTestCase("", "OCCOCO"), 119 | MoleculeTestCase("", "OCCNC(=O)NCCCCCCNC(=O)NCCO"), 120 | MoleculeTestCase("", "OC(=O)C(N)CCCCN"), 121 | MoleculeTestCase("", "Oc1ccc(cc1)CCC(=O)OCCOC(=O)CCc1ccc(cc1)O"), 122 | MoleculeTestCase("", "OC(=O)C(N)CCN"), 123 | MoleculeTestCase("", "N1CCNCC1"), 124 | MoleculeTestCase("", "Nc1ccc(cc1)SSc2ccc(cc2)N"), 125 | MoleculeTestCase("", "Nc1ccc(cc1)Cc2ccc(cc2)N"), 126 | ] 127 | 128 | TestCases = { 129 | "generic": ( 130 | valid_smiles_sentences, 131 | valid_smiles_prefixes, 132 | invalid_smiles_sentences, 133 | ), 134 | "isocyanates": ( 135 | valid_isocyanite_sentences, 136 | valid_smiles_sentences, 137 
| invalid_smiles_sentences, 138 | ), 139 | "acrylates": ( 140 | valid_acrylate_sentences, 141 | valid_smiles_sentences, 142 | invalid_smiles_sentences, 143 | ), 144 | "chain_extenders": ( 145 | valid_chain_extender_sentences, 146 | valid_smiles_sentences, 147 | invalid_smiles_sentences, 148 | ), 149 | } 150 | 151 | 152 | @pytest.fixture(scope="module") 153 | def recognizers(): 154 | recognizers = {} 155 | for grammar_name in TestCases: 156 | with open(f"examples/grammars/SMILES/{grammar_name}.ebnf", "r") as file: 157 | input_text = file.read() 158 | parsed_grammar = parse_ebnf(input_text) 159 | start_rule_id = parsed_grammar.symbol_table["root"] 160 | recognizers[grammar_name] = StringRecognizer( 161 | parsed_grammar.grammar_encoding, start_rule_id 162 | ) 163 | return recognizers 164 | 165 | 166 | def test_valid_sentence(recognizers): 167 | for grammar_name, recognizer in recognizers.items(): 168 | valid_full, valid_prefix, invalid = TestCases[grammar_name] 169 | 170 | for molecule_test_case in valid_full: 171 | fail_msg = ( 172 | f"{grammar_name.capitalize()}: " 173 | + f"Failed on {molecule_test_case.name}, {molecule_test_case.molecule}" 174 | ) 175 | assert ( 176 | recognizer._accept_string(molecule_test_case.molecule) 177 | ), fail_msg 178 | 179 | for molecule_test_case in valid_prefix + invalid: 180 | fail_msg = ( 181 | f"{grammar_name.capitalize()}: " 182 | + f"Failed on {molecule_test_case.name}, {molecule_test_case.molecule}" 183 | ) 184 | assert not ( 185 | recognizer._accept_string(molecule_test_case.molecule) 186 | ), fail_msg 187 | 188 | 189 | def test_valid_prefixes(recognizers): 190 | for grammar_name, recognizer in recognizers.items(): 191 | valid_full, valid_prefix, invalid = TestCases[grammar_name] 192 | 193 | for molecule_test_case in valid_full + valid_prefix: 194 | fail_msg = ( 195 | f"{grammar_name.capitalize()}: " 196 | + f"Failed on {molecule_test_case.name}, {molecule_test_case.molecule}" 197 | ) 198 | assert ( 199 | recognizer._accept_prefix(molecule_test_case.molecule) 200 | ), fail_msg 201 | 202 | for molecule_test_case in invalid: 203 | fail_msg = ( 204 | f"{grammar_name.capitalize()}: " 205 | + f"Failed on {molecule_test_case.name}, {molecule_test_case.molecule}" 206 | ) 207 | assert not ( 208 | recognizer._accept_prefix(molecule_test_case.molecule) 209 | ), fail_msg 210 | --------------------------------------------------------------------------------
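The recognizer pattern used by the fixture above also works standalone, which is handy when debugging a grammar outside pytest. A minimal sketch against the generic SMILES grammar, assuming it is run from the repository root (the probe strings are drawn from the test cases above):

from transformers_cfg.parser import parse_ebnf
from transformers_cfg.recognizer import StringRecognizer

# Build a string recognizer from the generic SMILES grammar.
with open("examples/grammars/SMILES/generic.ebnf", "r") as file:
    parsed_grammar = parse_ebnf(file.read())
recognizer = StringRecognizer(
    parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"]
)

print(recognizer._accept_string("C=O"))     # True: a complete sentence
print(recognizer._accept_prefix("C(C(C)"))  # True: extendable to a sentence
print(recognizer._accept_string("C==C"))    # False: rejected outright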
PDDLTestCase("wrong_number_of_arguments", "(pick-up c c)"), 35 | PDDLTestCase("l_unbalanced_paranthesis", "(pick-up c))"), 36 | PDDLTestCase("unexisitng_operator", "(unstack-and-pick-up c b d)"), 37 | PDDLTestCase("empty_paranthesis", "()"), 38 | ] 39 | 40 | 41 | valid_depot_sentences = [ 42 | PDDLTestCase("simplest_operator", "(drive truck0 depot0 distributor0)"), 43 | PDDLTestCase("lift", "(lift hoist0 crate0 pallet0 distributor0)"), 44 | PDDLTestCase("drive_and_load", "(drive-and-load truck1 hoist0 crate0 depot0)"), 45 | PDDLTestCase( 46 | "drive_and_lift", 47 | "(drive-and-lift truck0 hoist0 crate0 pallet0 distributor0)", 48 | ), 49 | PDDLTestCase( 50 | "lift_and_drive", "(lift-and-drive truck0 hoist0 crate0 pallet0 depot0 depot1)" 51 | ), 52 | PDDLTestCase( 53 | "multiple_actions", 54 | "(lift-and-drive truck0 hoist0 crate0 pallet0 depot0 depot0) (lift hoist2 crate2 crate1 distributor1)", 55 | ), 56 | PDDLTestCase( 57 | "long_realistic", 58 | "(lift-and-drive truck0 hoist0 crate0 pallet0 depot0 depot0) (lift hoist2 crate2 crate1 distributor1) " 59 | + "(drive truck0 depot0 distributor0) (drive-and-lift truck0 hoist1 crate1 pallet2 distributor0) " 60 | + "(drop hoist1 crate1 crate3 distributor0) (drive-and-load truck1 hoist0 crate0 depot0) " 61 | + "(drive-and-unload truck1 hoist0 crate0 pallet2 depot0) (drive truck1 depot0 distributor1) " 62 | + "(drive-and-load truck1 hoist2 crate2 distributor1) (drive-and-unload truck1 hoist2 crate2 pallet0 distributor1)", 63 | ), 64 | PDDLTestCase( 65 | "long_real", 66 | "(lift-and-drive truck1 hoist0 crate5 pallet0 depot0 depot0) (drive-and-load truck1 hoist0 crate5 depot0) " 67 | + "(drive-and-lift truck0 hoist2 crate4 crate0 distributor1) (lift hoist1 crate0 pallet4 distributor0) " 68 | + "(drive-and-lift truck1 hoist0 crate1 pallet5 depot0) (drive-and-load truck1 hoist0 crate1 depot0) " 69 | + "(drive-and-lift truck1 hoist0 crate3 crate2 depot0) (drive-and-load truck1 hoist0 crate3 depot0) " 70 | + "(drive-and-unload truck1 hoist0 crate3 pallet1 depot0) (drop hoist2 crate4 pallet5 distributor1) " 71 | + "(drive-and-unload truck1 hoist0 crate1 crate2 depot0) (drive-and-lift truck0 hoist2 crate1 crate2 distributor1) " 72 | + "(drop hoist2 crate1 crate4 distributor1) (drive-and-unload truck1 hoist0 crate5 crate1 depot0) " 73 | + "(drop hoist1 crate0 pallet3 distributor0)", 74 | ), 75 | ] 76 | 77 | 78 | valid_depot_prefixes = [ 79 | PDDLTestCase("empty_string", ""), 80 | PDDLTestCase("one_action_spaced", "(load hoist0 crate0 truck0 distributor0) "), 81 | PDDLTestCase( 82 | "r_unbalanced_paranthesis", "(unload hoist2 crate5 truck1 distributor0" 83 | ), 84 | ] 85 | 86 | invalid_depot_sentences = [ 87 | PDDLTestCase("undefined_object", "(lift moist0 crate0 pallet0 distributor0)"), 88 | PDDLTestCase("wrong_number_of_arguments", "(lift hoist0)"), 89 | PDDLTestCase("l_unbalanced_paranthesis", "(drive truck0 depot0 distributor0))"), 90 | PDDLTestCase( 91 | "unexisitng_operator", 92 | "(load-and-drive truck0 hoist0 crate0 pallet0 depot0 depot1)", 93 | ), 94 | PDDLTestCase("empty_paranthesis", "()"), 95 | ] 96 | 97 | invalid_depot_typed_sentences = [ 98 | PDDLTestCase("load_wrong_type", "(load hoist0 crate0 pallet0 distributor0)"), 99 | PDDLTestCase("unload_wrong_type", "(unload hoist2 crate5 pallet5 distributor0)"), 100 | PDDLTestCase("lift", "(lift truck0 truck0 truck0 truck0)"), 101 | ] 102 | 103 | valid_satellite_sentences = [ 104 | PDDLTestCase("simple_operator", "(switch-on instrument1 satellite3)"), 105 | PDDLTestCase( 106 | "complex_only", 107 | 
"(switch-on instrument1 satellite3) (turn-to satellite1 direction4 direction0)", 108 | ), 109 | PDDLTestCase( 110 | "take_image", 111 | "(take-image satellite1 direction4 instrument2 mode1)", 112 | ), 113 | ] 114 | 115 | valid_satellite_prefixes = [ 116 | PDDLTestCase("empty_string", ""), 117 | PDDLTestCase("one_action_spaced", "(switch-off instrument2 satellite3) "), 118 | PDDLTestCase( 119 | "r_unbalanced_paranthesis", "(calibrate satellite1 instrument2 direction4" 120 | ), 121 | ] 122 | 123 | invalid_satellite_sentences = [ 124 | PDDLTestCase("undefined_object", "(switch-on instrument8 satellite3)"), 125 | PDDLTestCase( 126 | "wrong_number_of_arguments", "(take-image satellite1 instrument2 mode1)" 127 | ), 128 | PDDLTestCase( 129 | "l_unbalanced_paranthesis", "(calibrate satellite1 instrument2 direction4))" 130 | ), 131 | PDDLTestCase("unexisitng_operator", "(turn satellite1 direction4 direction0)"), 132 | PDDLTestCase("empty_paranthesis", "()"), 133 | ] 134 | 135 | invalid_satellite_typed_sentences = [ 136 | PDDLTestCase("switch_off_wrong_type", "(switch-off satellite3 instrument2)"), 137 | PDDLTestCase("turn_wrong_type", "(turn-to satellite1 satellite1 satellite1)"), 138 | ] 139 | 140 | 141 | TestCases = { 142 | "blocks": ( 143 | valid_blocks_sentences, 144 | valid_blocks_prefixes, 145 | invalid_blocks_sentences, 146 | ), 147 | "depot": ( 148 | valid_depot_sentences + invalid_depot_typed_sentences, 149 | valid_depot_prefixes, 150 | invalid_depot_sentences, 151 | ), 152 | "depot_typed": ( 153 | valid_depot_sentences, 154 | valid_depot_prefixes, 155 | invalid_depot_sentences + invalid_depot_typed_sentences, 156 | ), 157 | "satellite": ( 158 | valid_satellite_sentences + invalid_satellite_typed_sentences, 159 | valid_satellite_prefixes, 160 | invalid_satellite_sentences, 161 | ), 162 | "satellite_typed": ( 163 | valid_satellite_sentences, 164 | valid_satellite_prefixes, 165 | invalid_satellite_sentences + invalid_satellite_typed_sentences, 166 | ), 167 | } 168 | 169 | 170 | @pytest.fixture(scope="module") 171 | def recognizers(): 172 | recognizers = {} 173 | for grammar_name in TestCases: 174 | with open(f"examples/grammars/PDDL/{grammar_name}.ebnf", "r") as file: 175 | input_text = file.read() 176 | parsed_grammar = parse_ebnf(input_text) 177 | start_rule_id = parsed_grammar.symbol_table["root"] 178 | recognizers[grammar_name] = StringRecognizer( 179 | parsed_grammar.grammar_encoding, start_rule_id 180 | ) 181 | return recognizers 182 | 183 | 184 | def test_valid_sentences(recognizers): 185 | for grammar_name, recognizer in recognizers.items(): 186 | valid_full, valid_prefix, invalid = TestCases[grammar_name] 187 | 188 | for PDDL_test_case in valid_full: 189 | fail_msg = ( 190 | f"{grammar_name.capitalize()}:" 191 | + f"Failed on {PDDL_test_case.name}, {PDDL_test_case.PDDL}" 192 | ) 193 | assert recognizer._accept_string(PDDL_test_case.PDDL) == True, fail_msg 194 | 195 | for PDDL_test_case in valid_prefix + invalid: 196 | fail_msg = ( 197 | f"{grammar_name.capitalize()}:" 198 | + f"Failed on {PDDL_test_case.name}, {PDDL_test_case.PDDL}" 199 | ) 200 | assert recognizer._accept_string(PDDL_test_case.PDDL) == False, fail_msg 201 | 202 | 203 | def test_valid_prefixes(recognizers): 204 | for grammar_name, recognizer in recognizers.items(): 205 | valid_full, valid_prefix, invalid = TestCases[grammar_name] 206 | 207 | for PDDL_test_case in valid_full + valid_prefix: 208 | fail_msg = ( 209 | f"{grammar_name.capitalize()}:" 210 | + f"Failed on {PDDL_test_case.name}, {PDDL_test_case.PDDL}" 211 | 
/transformers_cfg/cli/cli_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from importlib import import_module 5 | from transformers_cfg.tokenization.utils import is_tokenizer_supported 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig 7 | from transformers_cfg.grammar_utils import IncrementalGrammarConstraint 8 | from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor 9 | import torch 10 | 11 | # Define ANSI escape codes for colors 12 | RED = "\033[91m" 13 | BLUE = "\033[94m" 14 | RESET = "\033[0m" 15 | 16 | 17 | def parse_arguments(args=None): 18 | parser = argparse.ArgumentParser(description="Transformers-CFG CLI") 19 | subparsers = parser.add_subparsers(dest="command", help="Sub-commands help") 20 | 21 | # Sub-command: check 22 | check_parser = subparsers.add_parser("check", help="Check if a model is supported") 23 | check_parser.add_argument( 24 | "model", type=str, help="The unique model name on HF hub." 25 | ) 26 | 27 | # Sub-command: generate 28 | generate_parser = subparsers.add_parser( 29 | "generate", help="Generate text with grammar constraints" 30 | ) 31 | generate_parser.add_argument( 32 | "-m", 33 | "--model_id", 34 | type=str, 35 | required=True, 36 | help="Model identifier for loading the tokenizer and model", 37 | ) 38 | generate_parser.add_argument( 39 | "-g", 40 | "--grammar_file_path", 41 | type=str, 42 | required=True, 43 | help="Path to the grammar file", 44 | ) 45 | generate_parser.add_argument( 46 | "-p", 47 | "--prompt", 48 | type=str, 49 | required=True, 50 | help="Prompt for generation", 51 | ) 52 | generate_parser.add_argument( 53 | "-d", 54 | "--device", 55 | type=str, 56 | default="cuda" if torch.cuda.is_available() else "cpu", 57 | choices=["cpu", "cuda"], 58 | help="Device to run the model on", 59 | ) 60 | generate_parser.add_argument( 61 | "-n", 62 | "--max_new_tokens", 63 | type=int, 64 | default=20, 65 | help="Maximum number of new tokens to generate", 66 | ) 67 | generate_parser.add_argument( 68 | "--repetition_penalty", 69 | type=float, 70 | default=1.1, 71 | help="Penalty for token repetition", 72 | ) 73 | generate_parser.add_argument( 74 | "--use_4bit", 75 | action="store_true", 76 | help="Load the model in 4-bit mode using bitsandbytes", 77 | ) 78 | generate_parser.add_argument( 79 | "--use_8bit", 80 | action="store_true", 81 | help="Load the model in 8-bit mode using bitsandbytes", 82 | ) 83 | generate_parser.add_argument( 84 | "--no_contrast_mode", 85 | action="store_true", 86 | help="Disable contrast mode (enabled by default)", 87 | ) 88 | generate_parser.add_argument( 89 | "--save_to", 90 | type=str, 91 | help="File path to save the generated text", 92 | ) 93 | generate_parser.add_argument( 94 | "--use_mlx", 95 | action="store_true", 96 | help="Use MLX on Mac to speed up generation", 97 | ) 98 | 99 | return parser.parse_args(args) 100 | 101 | 102 | def check_model_support(model_name): 103 | # Check if the model tokenizer is supported 104 | if 
is_tokenizer_supported(model_name): 105 | print(f"{model_name} is supported") 106 | return True 107 | else: 108 | print(f"{model_name} is not supported") 109 | return False 110 | 111 | 112 | def generate_text(args): 113 | # Store results for optional file output 114 | result = f"Prompt: {args.prompt}\n\n" 115 | 116 | # Load model and tokenizer 117 | tokenizer = AutoTokenizer.from_pretrained(args.model_id) 118 | tokenizer.pad_token = tokenizer.eos_token 119 | 120 | # Load grammar 121 | with open(args.grammar_file_path, "r") as file: 122 | grammar_str = file.read() 123 | grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) 124 | grammar_processor = GrammarConstrainedLogitsProcessor(grammar) 125 | 126 | if args.use_mlx: 127 | try: 128 | import_module("mlx_lm") 129 | except ImportError: 130 | raise ImportError( 131 | "You need to install mlx to use MLX. Install it with `pip install 'git+https://github.com/nathanrchn/mlx-examples.git@logits_processor#subdirectory=llms'`." 132 | ) 133 | 134 | import numpy as np 135 | import mlx.core as mx 136 | from mlx_lm import load, stream_generate 137 | 138 | model, _ = load(args.model_id) 139 | 140 | if not args.no_contrast_mode: 141 | print(RED + "Unconstrained Generation:" + RESET) 142 | result += "Unconstrained Generation:\n" 143 | generation_stream = stream_generate( 144 | model, 145 | tokenizer, 146 | prompt=args.prompt, 147 | max_tokens=args.max_new_tokens, 148 | repetition_penalty=args.repetition_penalty, 149 | ) 150 | 151 | for token in generation_stream: 152 | result += token 153 | print(RED + token, end="", flush=True) 154 | 155 | print(RESET) 156 | 157 | def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array: 158 | torch_input_ids = torch.tensor( 159 | np.array(input_ids[None, :]), device=args.device 160 | ) 161 | torch_logits = torch.tensor(np.array(logits), device=args.device) 162 | 163 | torch_processed_logits = grammar_processor(torch_input_ids, torch_logits) 164 | return mx.array(torch_processed_logits.cpu().numpy()) 165 | 166 | generation_stream = stream_generate( 167 | model, 168 | tokenizer, 169 | prompt=args.prompt, 170 | max_tokens=args.max_new_tokens, 171 | repetition_penalty=args.repetition_penalty, 172 | logits_processor=logits_processor, 173 | ) 174 | 175 | # print prompt first in color 176 | print("\033[92m" + "Prompt:" + args.prompt + RESET) 177 | 178 | print(BLUE + "Constrained Generation:" + RESET) 179 | result += "Constrained Generation:\n" 180 | for token in generation_stream: 181 | result += token 182 | print(token, end="", flush=True) 183 | 184 | print() 185 | 186 | if args.save_to: 187 | with open(args.save_to, "w") as f: 188 | f.write(result) 189 | print(f"\nResults saved to {args.save_to}") 190 | 191 | return 192 | 193 | # Load the model with bitsandbytes if 8bit or 4bit flag is set 194 | if args.use_8bit or args.use_4bit: 195 | try: 196 | import_module("bitsandbytes") 197 | except ImportError: 198 | raise ImportError( 199 | "You need to install bitsandbytes to use 8-bit or 4-bit modes. Install it with `pip install bitsandbytes`." 
200 | ) 201 | 202 | bnb_config = BitsAndBytesConfig( 203 | load_in_8bit=args.use_8bit, 204 | load_in_4bit=args.use_4bit, 205 | bnb_4bit_compute_dtype=torch.bfloat16, 206 | ) 207 | 208 | model = AutoModelForCausalLM.from_pretrained( 209 | args.model_id, quantization_config=bnb_config, device_map="auto" 210 | ) 211 | else: 212 | model = AutoModelForCausalLM.from_pretrained(args.model_id).to(args.device) 213 | 214 | # set special tokens in generation config 215 | model.generation_config.pad_token_id = tokenizer.pad_token_id 216 | 217 | inputs = tokenizer( 218 | args.prompt, add_special_tokens=False, return_tensors="pt", padding=True 219 | ) 220 | input_ids = inputs["input_ids"].to(args.device) 221 | attention_mask = inputs["attention_mask"].to(args.device) 222 | 223 | # Generate with grammar constraints 224 | constrained_output = model.generate( 225 | input_ids, 226 | attention_mask=attention_mask, 227 | do_sample=False, 228 | max_new_tokens=args.max_new_tokens, 229 | logits_processor=[grammar_processor], 230 | repetition_penalty=args.repetition_penalty, 231 | num_return_sequences=1, 232 | ) 233 | 234 | # remove prefix from the output 235 | constrained_output = constrained_output[:, len(input_ids[0]) :] 236 | 237 | constrained_generations = tokenizer.batch_decode( 238 | constrained_output, skip_special_tokens=True 239 | ) 240 | 241 | # print prompt first in color 242 | print("\033[92m" + "Prompt:" + args.prompt + RESET) 243 | 244 | # Generate without grammar constraints (if contrast mode is enabled) 245 | if not args.no_contrast_mode: 246 | unconstrained_output = model.generate( 247 | input_ids, 248 | attention_mask=attention_mask, 249 | do_sample=False, 250 | max_new_tokens=args.max_new_tokens, 251 | repetition_penalty=args.repetition_penalty, 252 | num_return_sequences=1, 253 | ) 254 | # remove prefix from the output 255 | unconstrained_output = unconstrained_output[:, len(input_ids[0]) :] 256 | unconstrained_generations = tokenizer.batch_decode( 257 | unconstrained_output, skip_special_tokens=True 258 | ) 259 | 260 | # Print results in different colors 261 | print("\n" + "#" * 30) 262 | print("\033[91mUnconstrained Generation" + RESET) 263 | print("#" * 30 + "\n") 264 | result += "Unconstrained Generation:\n" 265 | for generation in unconstrained_generations: 266 | print(RED + generation + RESET) 267 | result += generation + "\n" 268 | 269 | print("\n" + "#" * 30) 270 | print("\033[94mConstrained Generation" + RESET) 271 | print("#" * 30 + "\n") 272 | result += "Constrained Generation:\n" 273 | for generation in constrained_generations: 274 | print(BLUE + generation + RESET) 275 | result += generation + "\n" 276 | 277 | # Save to file if save_to is provided 278 | if args.save_to: 279 | with open(args.save_to, "w") as f: 280 | f.write(result) 281 | print(f"\nResults saved to {args.save_to}") 282 | 283 | 284 | def main(args=None): 285 | args = parse_arguments(args) 286 | 287 | if args.command == "check": 288 | check_model_support(args.model) 289 | elif args.command == "generate": 290 | generate_text(args) 291 | 292 | 293 | if __name__ == "__main__": 294 | main() 295 | 296 | # TODO, add support for device selection for parsing 297 | --------------------------------------------------------------------------------
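Because main() simply forwards its argument list to parse_arguments(), the CLI above can also be driven programmatically; a minimal sketch, assuming it is run from the repository root (gpt2 and the JSON grammar path are illustrative choices — any hub model whose tokenizer passes the check sub-command should work):

from transformers_cfg.cli.cli_main import main

# Verify that the model's tokenizer is supported before generating.
main(["check", "gpt2"])

# Grammar-constrained generation with a small token budget;
# --no_contrast_mode skips the unconstrained baseline run.
main(
    [
        "generate",
        "-m", "gpt2",
        "-g", "examples/grammars/json.ebnf",
        "-p", "Here is a JSON object:",
        "-n", "30",
        "--no_contrast_mode",
    ]
)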