├── .coverage ├── .coveragerc ├── .github ├── FUNDING.yml └── workflows │ ├── python-app.yml │ └── python-publish.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── examples ├── quickstart-mlx.py ├── quickstart-torch.py ├── simple_demo.ipynb ├── simple_demo.py ├── thinking_answer.ipynb └── thinking_answer.py ├── logo.png ├── pse ├── structuring_engine.py ├── types │ ├── array.py │ ├── base │ │ ├── any.py │ │ ├── chain.py │ │ ├── character.py │ │ ├── encapsulated.py │ │ ├── loop.py │ │ ├── phrase.py │ │ └── wait_for.py │ ├── boolean.py │ ├── enum.py │ ├── grammar │ │ ├── __init__.py │ │ ├── default_grammars │ │ │ ├── bash.lark │ │ │ ├── bash.py │ │ │ ├── python.lark │ │ │ └── python.py │ │ └── lark.py │ ├── integer.py │ ├── json │ │ ├── __init__.py │ │ ├── any_json_schema.py │ │ ├── json_array.py │ │ ├── json_key_value.py │ │ ├── json_number.py │ │ ├── json_object.py │ │ ├── json_string.py │ │ ├── json_value.py │ │ └── schema_sources │ │ │ ├── from_function.py │ │ │ └── from_pydantic.py │ ├── key_value.py │ ├── misc │ │ ├── fenced_freeform.py │ │ └── freeform.py │ ├── number.py │ ├── object.py │ ├── string.py │ ├── whitespace.py │ └── xml │ │ ├── xml_encapsulated.py │ │ └── xml_tag.py └── util │ ├── generate_mlx.py │ ├── get_top_logits.py │ ├── jax_mixin.py │ ├── tf_mixin.py │ └── torch_mixin.py ├── pyproject.toml └── tests ├── functional └── test_e2e.py └── unit ├── test_structuring_engine.py └── types ├── base ├── test_chain.py ├── test_character.py ├── test_encapsulated.py ├── test_loop.py ├── test_phrase.py ├── test_state_machine.py └── test_wait_for.py ├── bash ├── test_bash_code.py └── test_wrapped_bash.py ├── json ├── schema_sources │ ├── test_from_function.py │ └── test_from_pydantic.py ├── test_any_json_schema.py ├── test_json_array.py ├── test_json_key_value.py ├── test_json_number.py ├── test_json_object.py ├── test_json_string.py ├── test_json_to_state_machine.py └── test_json_value.py ├── misc ├── test_fenced_freeform.py └── 
test_freeform.py ├── python ├── test_python_code.py └── test_wrapped_python.py ├── test_array.py ├── test_boolean.py ├── test_enum.py ├── test_integer.py ├── test_key_value.py ├── test_number.py ├── test_object.py ├── test_string.py ├── test_whitespace.py └── xml ├── test_xml_encapsulated.py └── test_xml_tag.py /.coverage: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheProxyCompany/proxy-structuring-engine/1cb33d487126abc6b85a3f78833177840911d4e4/.coverage -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = pse 3 | omit = 4 | tests/* 5 | pse/util/* 6 | 7 | [report] 8 | omit = 9 | tests/* 10 | */__init__.py 11 | pse/util/* 12 | exclude_lines = 13 | def __repr__ 14 | def __str__ 15 | ^\s*if __name__ == ['"]__main__['"]:? 16 | pragma: no cover 17 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [TheProxyCompany] 2 | custom: ["https://www.theproxycompany.com/research#support"] 3 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Unit Tests 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python 3.12 23 | uses: 
actions/setup-python@v4 24 | with: 25 | python-version: "3.12" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install .[dev] 30 | - uses: astral-sh/ruff-action@v1 31 | with: 32 | changed-files: "true" 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | 40 | - name: Run unit tests 41 | run: pytest tests/unit 42 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published, updated] 14 | # push: 15 | # branches: [ "main" ] 16 | 17 | permissions: 18 | contents: read 19 | 20 | jobs: 21 | deploy: 22 | 23 | runs-on: ubuntu-latest 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | - name: Set up Python 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: '3.x' 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install build 35 | - name: Build package 36 | run: python -m build 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 39 | with: 40 | user: __token__ 41 | password: ${{ secrets.PYPI_API_TOKEN }} 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | __pycache__ 3 | *.pyc 4 | .cursorrules 5 | .cursorignore 6 | .DS_Store 7 | .env 8 | .venv 9 | .pytest_cache 10 | Cargo.lock 11 | target 12 | .cache 13 | uv.lock 14 | */__pycache__ 15 | .venv 16 | .pytest_cache 17 | CLAUDE.md 18 | .coverage.* 19 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Wind" 5 | given-names: "Jack" 6 | title: "Proxy Structuring Engine" 7 | version: 2025.06.1 8 | date-released: 2025-06-03 9 | url: "https://github.com/TheProxyCompany/proxy-structuring-engine" 10 | repository-code: "https://github.com/TheProxyCompany/proxy-structuring-engine" 11 | license: Apache-2.0 12 | publisher: "The Proxy Company" 13 | -------------------------------------------------------------------------------- /examples/quickstart-mlx.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | 5 | import mlx.core as mx 6 | from mlx_lm.utils import generate_step, load # type: ignore[reportMissingImports] 7 | 8 | from pse.structuring_engine import StructuringEngine 9 | 10 | logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)  # verbose PSE debug logs on stdout; raise the level to quiet them 11 | 12 | ADVANCED_JSON_SCHEMA = {  # schema the engine enforces on the generated output 13 | "type": "object", 14 | "description": "High-level thoughts, reasoning and internal dialogue.\n Used for step by step reasoning.", 15 | "properties": { 16 | "chain_of_thought": { 17 | "type": "array", 18 | "items": { 19 | "type": "string", 20 | "minLength": 200, 21 | }, 22 | "minItems": 1, 23 | "maxItems": 3, 24 | }, 25 | }, 26 | "required": ["chain_of_thought"], 27 | } 28 | 29 | system_message = f""" 30 | You are an AI that can think step by step. 31 | You must follow this schema when generating your response: 32 | {json.dumps(ADVANCED_JSON_SCHEMA, indent=2)} 33 | """ 34 | 35 | prompt = "This is a test - I want to see your private internal monologue."
36 | messages = [ 37 | {"role": "system", "content": system_message}, 38 | {"role": "user", "content": prompt}, 39 | ] 40 | model_path_hf = "meta-llama/Llama-3.2-3B-Instruct" 41 | model, tokenizer = load(model_path_hf) 42 | engine = StructuringEngine(tokenizer._tokenizer) # noqa: SLF001 43 | engine.configure(ADVANCED_JSON_SCHEMA)  # constrain decoding to the schema 44 | 45 | encoded_prompt = engine.tokenizer.apply_chat_template( 46 | conversation=messages, 47 | add_generation_prompt=True, 48 | ) 49 | # generate token by token until the engine reaches an accept state 50 | for tokens, _ in generate_step( 51 | prompt=mx.array(encoded_prompt), 52 | model=model, 53 | logits_processors=[engine.process_logits], 54 | sampler=lambda x: engine.sample(x, mx.argmax), # type: ignore [arg-type] 55 | max_tokens=-1, 56 | ): 57 | encoded_prompt.append(tokens) # type: ignore [attr-defined] 58 | if engine.has_reached_accept_state: 59 | break 60 | # NOTE(review): nothing reads encoded_prompt after mx.array(encoded_prompt) is built, so the append in the loop appears to have no effect on generation — confirm before removing 61 | output = engine.get_structured_output() 62 | print(json.dumps(output, indent=2)) 63 | -------------------------------------------------------------------------------- /examples/quickstart-torch.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch # type: ignore[reportMissingImports] 4 | from pydantic import BaseModel 5 | from transformers.models.auto.tokenization_auto import AutoTokenizer 6 | from transformers.models.llama.modeling_llama import LlamaForCausalLM 7 | 8 | from pse.structuring_engine import StructuringEngine 9 | from pse.util.torch_mixin import PSETorchMixin 10 | 11 | 12 | # 1. Define your desired output structure using Pydantic 13 | class Product(BaseModel): 14 | name: str 15 | price: float 16 | description: str | None = None 17 | 18 | 19 | # 2. Load your model and tokenizer. Apply the PSE mixin.
20 | class PSE_Torch(PSETorchMixin, LlamaForCausalLM): 21 | pass  # mixin listed first so its methods take precedence over LlamaForCausalLM's in the MRO 22 | 23 | 24 | model_path = "meta-llama/Llama-3.2-1B-Instruct" # Or any model 25 | tokenizer = AutoTokenizer.from_pretrained(model_path) 26 | model = PSE_Torch.from_pretrained( 27 | model_path, torch_dtype=torch.bfloat16, device_map="auto" 28 | ) # Load to GPU, if available 29 | 30 | # Ensure padding token is set for generation (indexing [0] assumes config.eos_token_id is a list — TODO confirm for the chosen model) 31 | model.config.pad_token_id = model.config.eos_token_id[0] 32 | if model.generation_config: 33 | model.generation_config.pad_token_id = model.config.eos_token_id[0] 34 | 35 | # 3. Create a StructuringEngine and configure it with your schema 36 | model.engine = StructuringEngine(tokenizer) 37 | model.engine.configure(Product) 38 | 39 | # 4. Create your prompt. Include the schema for the LLM's context. 40 | prompt = f""" 41 | You are a product catalog assistant. Create a product description in JSON format, 42 | following this schema: 43 | 44 | {Product.model_json_schema()} 45 | 46 | Create a product description for a new type of noise-cancelling headphones. 47 | """ 48 | 49 | messages = [{"role": "user", "content": prompt}] 50 | input_ids = tokenizer.apply_chat_template( 51 | messages, return_tensors="pt", add_generation_prompt=True 52 | ) 53 | # 5. Generate! 54 | assert isinstance(input_ids, torch.Tensor) 55 | input_ids = input_ids.to(model.device) 56 | assert isinstance(input_ids, torch.Tensor) 57 | output = model.generate(  # the structured result is read back from model.engine after generation 58 | input_ids, 59 | do_sample=True, 60 | max_new_tokens=200, 61 | top_k=10, 62 | top_p=None, 63 | ) 64 | # 6.
Parse the structured output 65 | structured_output = model.engine.get_structured_output(Product) 66 | print(json.dumps(structured_output.model_dump(), indent=2)) 67 | -------------------------------------------------------------------------------- /examples/simple_demo.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | import torch # type: ignore[reportMissingImports] 5 | from pydantic import BaseModel 6 | from transformers.models.auto.tokenization_auto import AutoTokenizer 7 | from transformers.models.llama.modeling_llama import LlamaForCausalLM 8 | 9 | from pse.structuring_engine import StructuringEngine 10 | from pse.util.torch_mixin import PSETorchMixin 11 | 12 | # PSE debug logging is on; set the level to logging.INFO (or higher) to hide the PSE debug logs. 13 | logging.basicConfig(level=logging.DEBUG) 14 | 15 | 16 | class PSE_Torch(PSETorchMixin, LlamaForCausalLM): 17 | pass 18 | 19 | 20 | model_path = "meta-llama/Llama-3.2-1B-Instruct" 21 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 22 | model = PSE_Torch.from_pretrained( 23 | model_path, 24 | torch_dtype=torch.bfloat16, 25 | device_map="auto", 26 | ) 27 | 28 | model.config.pad_token_id = model.config.eos_token_id[0]  # assumes eos_token_id is a list — TODO confirm for the chosen model 29 | if model.generation_config: 30 | model.generation_config.top_p = None 31 | model.generation_config.top_k = 8 32 | model.generation_config.do_sample = True 33 | model.generation_config.temperature = 1.0 34 | model.generation_config.max_new_tokens = 1000 35 | model.generation_config.pad_token_id = model.config.eos_token_id[0] 36 | 37 | # create structuring engine normally 38 | model.engine = StructuringEngine(tokenizer, multi_token_sampling=True) 39 | SIMPLE_JSON_SCHEMA = { 40 | "type": "object", 41 | "properties": {"value": {"type": "number"}}, 42 | "required": ["value"], 43 | } 44 | model.engine.configure(SIMPLE_JSON_SCHEMA) 45 | prompt = ( 46 | "Please generate a json object with the value 9.11, with the following
schema:\n" 47 | ) 48 | prompt += json.dumps(SIMPLE_JSON_SCHEMA, indent=2) 49 | 50 | messages = [{"role": "user", "content": prompt}] 51 | input_ids = tokenizer.apply_chat_template( 52 | messages, return_tensors="pt", add_generation_prompt=True 53 | ) 54 | assert isinstance(input_ids, torch.Tensor) 55 | input_ids = input_ids.to(model.device) 56 | assert isinstance(input_ids, torch.Tensor) 57 | output = model.generate( 58 | input_ids, 59 | do_sample=True, 60 | ) 61 | # you can print the prompt + output: 62 | # print(tokenizer.decode(output[0])) 63 | structured_output = model.engine.get_structured_output() 64 | print(100 * "-") 65 | print(json.dumps(structured_output, indent=2)) 66 | 67 | ADVANCED_JSON_SCHEMA = { 68 | "type": "object", 69 | "description": "High-level thoughts, reasoning and internal dialogue.\n Used for step by step reasoning.", 70 | "properties": { 71 | "chain_of_thought": { 72 | "type": "array", 73 | "items": { 74 | "type": "string", 75 | "minLength": 20, # minimum length of a thought (optional) 76 | }, 77 | "minItems": 1, # floor the number of thoughts (optional) 78 | "maxItems": 3, # limit the number of thoughts (optional) 79 | }, 80 | }, 81 | "required": ["chain_of_thought"], 82 | } 83 | model.engine.configure(ADVANCED_JSON_SCHEMA)  # reconfigure the same engine with a second schema 84 | raw_prompt = ( 85 | f"This is a test of your thought process.\n" 86 | f"I want to see your private internal monologue.\n" 87 | f"Please follow the following schema when generating your response:\n{json.dumps(ADVANCED_JSON_SCHEMA, indent=2)}\n" 88 | ) 89 | messages = [{"role": "user", "content": raw_prompt}] 90 | input_ids = tokenizer.apply_chat_template( 91 | messages, return_tensors="pt", add_generation_prompt=True 92 | ) 93 | assert isinstance(input_ids, torch.Tensor) 94 | input_ids = input_ids.to(model.device) 95 | assert isinstance(input_ids, torch.Tensor) 96 | greedy_output = model.generate(input_ids)  # NOTE: generation_config.do_sample was set to True above (when present), so despite the name this presumably samples rather than decoding greedily 97 | structured_output = model.engine.get_structured_output() 98 | print(100 * "-") 99 |
print(json.dumps(structured_output, indent=2)) 100 | 101 | 102 | class CursorPositionModel(BaseModel): 103 | """ 104 | An object representing the position and click state of a cursor. 105 | 106 | Attributes: 107 | x_pos: The horizontal position of the cursor in pixels 108 | y_pos: The vertical position of the cursor in pixels 109 | left_click: Whether the left mouse button is currently pressed. Default is False. 110 | """ 111 | 112 | x_pos: int 113 | y_pos: int 114 | left_click: bool = False 115 | 116 | # NOTE(review): both delimiters below are empty strings and the prompt's wrap instruction names no wrapper; the delimiter text appears to have been lost (e.g. stripped markup) — confirm against the upstream file 117 | model.engine.configure( 118 | CursorPositionModel, delimiters=("", "") 119 | ) 120 | prompt = ( 121 | "Please use the following schema to generate a cursor position:\n" 122 | f"{json.dumps(CursorPositionModel.model_json_schema(), indent=2)}.\n" 123 | "Pretend to move the cursor to x = 100 and y = 100, with the left mouse button clicked.\n" 124 | "Wrap your response in CursorPositionModel." 125 | ) 126 | messages = [{"role": "user", "content": prompt}] 127 | input_ids = tokenizer.apply_chat_template( 128 | messages, return_tensors="pt", add_generation_prompt=True 129 | ) 130 | assert isinstance(input_ids, torch.Tensor) 131 | input_ids = input_ids.to(model.device) 132 | assert isinstance(input_ids, torch.Tensor) 133 | output = model.generate( 134 | input_ids, 135 | do_sample=True, 136 | ) 137 | structured_output = model.engine.get_structured_output(CursorPositionModel) 138 | print(100 * "-") 139 | print(json.dumps(structured_output.model_dump(), indent=2)) 140 | -------------------------------------------------------------------------------- /examples/thinking_answer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch # type: ignore[reportMissingImports] 4 | from pse_core.state_machine import StateMachine 5 | from transformers.models.auto.tokenization_auto import AutoTokenizer 6 | from transformers.models.llama.modeling_llama import LlamaForCausalLM 7 | 8 | from pse.structuring_engine
import StructuringEngine 9 | from pse.types.base.loop import LoopStateMachine 10 | from pse.types.misc.fenced_freeform import FencedFreeformStateMachine 11 | from pse.util.torch_mixin import PSETorchMixin 12 | 13 | # PSE debug logging is on; set the level to logging.INFO (or higher) to hide the PSE debug logs. 14 | logging.basicConfig(level=logging.DEBUG) 15 | 16 | 17 | class PSE_Torch(PSETorchMixin, LlamaForCausalLM): 18 | pass 19 | 20 | 21 | # you can change the model path to any other model on the Hugging Face Hub 22 | model_path = "meta-llama/Llama-3.2-1B-Instruct" 23 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 24 | model = PSE_Torch.from_pretrained( 25 | model_path, 26 | torch_dtype=torch.bfloat16, 27 | device_map="auto", 28 | ) 29 | 30 | model.config.pad_token_id = model.config.eos_token_id[0] 31 | if model.generation_config: 32 | model.generation_config.top_p = None 33 | model.generation_config.top_k = 8 34 | model.generation_config.do_sample = True 35 | model.generation_config.temperature = 1.0 36 | model.generation_config.max_new_tokens = 1000 37 | model.generation_config.pad_token_id = model.config.eos_token_id[0] 38 | 39 | # create structuring engine normally 40 | model.engine = StructuringEngine(tokenizer, multi_token_sampling=True) 41 | 42 | # define custom state machines 43 | thinking_delimiters = ("[thinking]", "[/thinking]") 44 | answer_delimiters = ("[answer]", "[/answer]") 45 | 46 | # encapsulated state machines are used to allow a language model 47 | # to generate unstructured content before the structured output 48 | # starts.
This "scratchpad" is disabled by default (buffer length -1; NOTE(review): EncapsulatedStateMachine names this parameter buffer_length — confirm the name FencedFreeformStateMachine exposes) 49 | thinking_state_machine = FencedFreeformStateMachine("thinking", thinking_delimiters) 50 | # the answer state machine is used to wrap the structured output 51 | answer_state_machine = FencedFreeformStateMachine("answer", answer_delimiters) 52 | # Configure the engine with a state machine that enforces the following flow: 53 | # 54 | # The model starts in the 'thinking' state where it can express its reasoning. 55 | # From there, it can transition to providing its final answer. 56 | # 57 | # ┌──────────────┐ 58 | # │ │ 59 | # ▼ │ 60 | # ┌──────────┐ │ 61 | # │ │ │ 62 | # │ thinking ├────────┘ 63 | # │ │ 64 | # └──────┬───┘ 65 | # │ 66 | # ▼ 67 | # ┌──────────┐ 68 | # │ │ 69 | # │ answer │ 70 | # │ │ 71 | # └──────────┘ 72 | # 73 | # This ensures the model follows a structured thought process before 74 | # providing its final answer. 75 | # the thinking block is looped with min_loop_count=1 and max_loop_count=2 (presumably at least one and at most two thinking blocks) 76 | model.engine.configure( 77 | StateMachine( 78 | { 79 | "thinking": [ 80 | ( 81 | LoopStateMachine( 82 | thinking_state_machine, 83 | min_loop_count=1, 84 | max_loop_count=2, 85 | ), 86 | "answer", 87 | ) 88 | ], 89 | "answer": [ 90 | ( 91 | answer_state_machine, 92 | "done", 93 | ), 94 | ], 95 | }, 96 | start_state="thinking", 97 | end_states=["done"], 98 | ) 99 | ) 100 | 101 | system_prompt = ( 102 | f"Reason step by step using delimiters to seperate your thought process.\n" 103 | "For example, when asked a question, you should think and then answer.\n" 104 | "Example:\n" 105 | f"{thinking_delimiters[0]}Thinking goes here{thinking_delimiters[1]}" 106 | f"{answer_delimiters[0]}Answer goes here{answer_delimiters[1]}\n" 107 | "you can think multiple times before providing your answer.\n\n" 108 | ) 109 | prompt = "Please pick a favorite color. Think about it first."
110 | 111 | input_ids = tokenizer.apply_chat_template( 112 | [ 113 | {"role": "system", "content": system_prompt}, 114 | {"role": "user", "content": prompt}, 115 | ], 116 | return_tensors="pt", 117 | add_generation_prompt=True, 118 | ) 119 | assert isinstance(input_ids, torch.Tensor) 120 | input_ids = input_ids.to(model.device) 121 | assert isinstance(input_ids, torch.Tensor) 122 | output = model.generate(input_ids) 123 | # NOTE: the loop below rebinds `output`, shadowing the generate() result above 124 | for label, output in model.engine.get_labeled_output(): 125 | print("-" * 100) 126 | print(f"[{label}]") 127 | print(output) 128 | print(f"[/{label}]") 129 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheProxyCompany/proxy-structuring-engine/1cb33d487126abc6b85a3f78833177840911d4e4/logo.png -------------------------------------------------------------------------------- /pse/types/array.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from pse_core import StateGraph, StateId 6 | from pse_core.state_machine import StateMachine 7 | from pse_core.stepper import Stepper 8 | 9 | from pse.types.base.chain import ChainStateMachine 10 | from pse.types.base.phrase import PhraseStateMachine 11 | from pse.types.json.json_value import JsonStateMachine 12 | from pse.types.whitespace import WhitespaceStateMachine 13 | 14 | 15 | class ArrayStateMachine(StateMachine): 16 | """ 17 | Accepts a well-formed JSON array and handles state transitions during parsing. 18 | 19 | This state_machine manages the parsing of JSON arrays by defining the state transitions 20 | and maintaining the current array values being parsed.
21 | """ 22 | 23 | def __init__(self, state_graph: StateGraph | None = None) -> None: 24 | base_array_state_graph: StateGraph = { 25 | 0: [(PhraseStateMachine("["), 1)], 26 | 1: [ 27 | (WhitespaceStateMachine(), 2), 28 | (PhraseStateMachine("]"), "$"), # Allow empty array 29 | ], 30 | 2: [(JsonStateMachine(), 3)],  # one element value 31 | 3: [(WhitespaceStateMachine(), 4)], 32 | 4: [ 33 | ( 34 | ChainStateMachine( 35 | [PhraseStateMachine(","), WhitespaceStateMachine()] 36 | ), 37 | 2, 38 | ), 39 | (PhraseStateMachine("]"), "$"), 40 | ], 41 | } 42 | super().__init__(state_graph or base_array_state_graph)  # a caller-supplied graph replaces the default one 43 | 44 | def get_new_stepper(self, state: StateId | None = None) -> ArrayStepper: 45 | return ArrayStepper(self, state) 46 | 47 | def __str__(self) -> str: 48 | return "Array" 49 | 50 | 51 | class ArrayStepper(Stepper): 52 | def __init__( 53 | self, 54 | state_machine: ArrayStateMachine, 55 | current_state: StateId | None = None, 56 | ): 57 | super().__init__(state_machine, current_state) 58 | self.state_machine: ArrayStateMachine = state_machine 59 | self.value: list[Any] = [] 60 | 61 | def clone(self) -> ArrayStepper: 62 | cloned_stepper = super().clone() 63 | cloned_stepper.value = self.value[:]  # copy so clones do not share the element list 64 | return cloned_stepper 65 | 66 | def is_within_value(self) -> bool: 67 | return self.current_state == 3  # state 3 follows an element value (edge 2 -> 3 in the graph) 68 | 69 | def add_to_history(self, stepper: Stepper) -> None: 70 | if self.is_within_value(): 71 | self.value.append(stepper.get_current_value()) 72 | super().add_to_history(stepper) 73 | 74 | def get_current_value(self) -> list[Any]: 75 | """ 76 | Get the current parsed JSON array. 77 | 78 | Returns: 79 | list[Any]: The element values accumulated so far.
80 | """ 81 | return self.value 82 | -------------------------------------------------------------------------------- /pse/types/base/any.py: -------------------------------------------------------------------------------- 1 | from pse_core import StateId 2 | from pse_core.state_machine import StateMachine 3 | from pse_core.stepper import Stepper 4 | 5 | 6 | class AnyStateMachine(StateMachine): 7 | def __init__(self, state_machines: list[StateMachine]) -> None: 8 | """Accept input matched by any one of state_machines; each gets an edge from state 0 to "$".""" 9 | self.state_machines: list[StateMachine] = state_machines 10 | super().__init__( 11 | { 12 | 0: [ 13 | (state_machine, "$") 14 | for state_machine in self.state_machines 15 | ] 16 | } 17 | ) 18 | 19 | def get_steppers(self, state: StateId | None = None) -> list[Stepper]: 20 | steppers: list[Stepper] = [] 21 | for edge, _ in self.get_edges(state or 0):  # a None (or otherwise falsy) state falls back to start state 0 22 | steppers.extend(edge.get_steppers())  # gather steppers from every alternative 23 | return steppers 24 | 25 | def __str__(self) -> str: 26 | return "Any" 27 | -------------------------------------------------------------------------------- /pse/types/base/chain.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from pse_core.state_machine import StateMachine 6 | from pse_core.stepper import Stepper 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ChainStateMachine(StateMachine): 12 | """ 13 | Chain multiple StateMachines in a specific order.
14 | """ 15 | 16 | def __init__(self, state_machines: list[StateMachine], is_optional: bool = False) -> None: 17 | """Build a linear chain in which state i advances to state i + 1 via state_machines[i]. 18 | Args: 19 | state_machines: State machines to be chained in sequence; the final index, len(state_machines), is the only end state. 20 | is_optional: Forwarded to the StateMachine base class. """ 21 | super().__init__( 22 | state_graph={ 23 | i: [(state_machine, i + 1)] 24 | for i, state_machine in enumerate(state_machines) 25 | }, 26 | end_states=[len(state_machines)], 27 | is_optional=is_optional, 28 | ) 29 | 30 | def get_new_stepper(self, state: int | str | None = None) -> ChainStepper: 31 | return ChainStepper(self, state) 32 | 33 | def __str__(self) -> str: 34 | return "Chain" 35 | 36 | 37 | class ChainStepper(Stepper): 38 | """ 39 | A stepper that chains multiple steppers in a specific sequence. 40 | """ 41 | 42 | def __init__(self, chain_state_machine: ChainStateMachine, *args, **kwargs) -> None: 43 | super().__init__(chain_state_machine, *args, **kwargs) 44 | self.state_machine: ChainStateMachine = chain_state_machine 45 | -------------------------------------------------------------------------------- /pse/types/base/character.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Iterable 4 | 5 | from pse_core.state_machine import StateMachine 6 | from pse_core.stepper import Stepper 7 | 8 | 9 | class CharacterStateMachine(StateMachine): 10 | """ 11 | Accepts one or more valid characters. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | whitelist_charset: str | list[str] | Iterable[str] = "", 17 | graylist_charset: str | list[str] | Iterable[str] = "", 18 | blacklist_charset: str | list[str] | Iterable[str] = "", 19 | char_min: int | None = None, 20 | char_limit: int | None = None, 21 | is_optional: bool = False, 22 | case_sensitive: bool = True, 23 | ) -> None: 24 | """ 25 | Initialize a CharacterStateMachine with character sets and constraints.
26 | 27 | Args: 28 | whitelist_charset: Characters that are explicitly allowed 29 | graylist_charset: Characters that are allowed but terminate the match if they follow other characters in the same token 30 | blacklist_charset: Characters that are explicitly forbidden 31 | char_min: Minimum number of characters required (0 if None) 32 | char_limit: Maximum number of characters allowed (unlimited if 0 or None) 33 | is_optional: Whether this state machine is optional 34 | case_sensitive: Whether character matching is case-sensitive 35 | """ 36 | super().__init__( 37 | is_optional=is_optional, 38 | is_case_sensitive=case_sensitive, 39 | ) 40 | self.char_min = char_min or 0 41 | self.char_limit = char_limit or 0 42 | self.charset: set[str] = set() 43 | self.graylist_charset: set[str] = set() 44 | self.blacklist_charset: set[str] = set() 45 | 46 | # Normalize every charset to a set, lowercased when matching is case-insensitive 47 | def convert_to_set(chars: Iterable[str]) -> set[str]: 48 | return set(char.lower() for char in chars) if chars else set() 49 | 50 | self.charset = set(whitelist_charset) if case_sensitive else convert_to_set(whitelist_charset) 51 | self.graylist_charset = set(graylist_charset) if case_sensitive else convert_to_set(graylist_charset) 52 | self.blacklist_charset = set(blacklist_charset) if case_sensitive else convert_to_set(blacklist_charset) 53 | 54 | def get_new_stepper(self, state: int | str) -> CharacterStepper:  # state is unused; the stepper always targets "$" 55 | return CharacterStepper(self) 56 | 57 | def __str__(self) -> str: 58 | return "Character" 59 | 60 | 61 | class CharacterStepper(Stepper): 62 | """ 63 | Stepper for navigating through characters in CharacterStateMachine. 64 | """ 65 | 66 | def __init__( 67 | self, 68 | state_machine: CharacterStateMachine, 69 | value: str | None = None, 70 | ) -> None: 71 | """ 72 | Initialize a stepper for state_machine, optionally seeded with value. 73 | 74 | Args: 75 | value (Optional[str]): The accumulated string value; a non-empty value also sets consumed_character_count. Defaults to None.
76 | """ 77 | super().__init__(state_machine) 78 | self.target_state = "$" 79 | self.state_machine: CharacterStateMachine = state_machine 80 | self._raw_value = value 81 | if value: 82 | self.consumed_character_count = len(value) 83 | 84 | def accepts_any_token(self) -> bool: 85 | return not self.state_machine.charset  # an empty whitelist allows any character 86 | 87 | def get_valid_continuations(self, depth: int = 0) -> list[str]: 88 | """ 89 | Return the whitelisted characters (empty when any character is allowed); depth is unused. 90 | """ 91 | return list(self.state_machine.charset) 92 | 93 | def can_accept_more_input(self) -> bool: 94 | """ 95 | Determines if the stepper can accept more input based on the character limit. 96 | """ 97 | if ( 98 | self.state_machine.char_limit > 0 99 | and self.consumed_character_count >= self.state_machine.char_limit 100 | ): 101 | return False 102 | 103 | return True 104 | 105 | def should_start_step(self, token: str) -> bool: 106 | """ 107 | Determines if a transition should start with the given token. 108 | 109 | Args: 110 | token (str): The input token to check. 111 | 112 | Returns: 113 | bool: True if the token can start a transition, False otherwise. 114 | """ 115 | if not token or ( 116 | self.state_machine.char_limit > 0 117 | and self.consumed_character_count >= self.state_machine.char_limit 118 | ): 119 | return False 120 | 121 | first_char = token[0] 122 | if not self.state_machine.is_case_sensitive: 123 | first_char = first_char.lower() 124 | 125 | if first_char in self.state_machine.blacklist_charset: 126 | return False 127 | 128 | if self.state_machine.charset: 129 | return first_char in self.state_machine.charset 130 | 131 | return True 132 | 133 | def should_complete_step(self) -> bool: 134 | """ 135 | Determines if the transition should be completed based on the character limit and the char_min minimum.
136 | """ 137 | if ( 138 | self.state_machine.char_limit > 0 139 | and self.consumed_character_count > self.state_machine.char_limit 140 | ): 141 | return False 142 | 143 | if ( 144 | self.state_machine.char_min > 0 145 | and self.consumed_character_count < self.state_machine.char_min 146 | ): 147 | return False 148 | 149 | return True 150 | 151 | def consume(self, token: str) -> list[Stepper]: 152 | """ 153 | Advance the stepper with the given input. 154 | 155 | This method processes the input token and determines how much of it can be consumed 156 | based on character constraints. It stops consuming at the first invalid character. 157 | 158 | Args: 159 | token: The input string to consume 160 | 161 | Returns: 162 | A single-element list holding the advanced stepper, or an empty list if the token cannot start a step 163 | """ 164 | if not token or not self.should_start_step(token): 165 | return [] 166 | 167 | # Apply case sensitivity. NOTE(review): the lowercased token also feeds new_value and remaining_input below, so case-insensitive matching drops the caller's casing for both the consumed text and the unconsumed remainder passed on — confirm this is intended 168 | token = token.lower() if not self.state_machine.is_case_sensitive else token 169 | 170 | # Cache frequently used properties for performance 171 | charset = self.state_machine.charset 172 | blacklist = self.state_machine.blacklist_charset 173 | graylist = self.state_machine.graylist_charset 174 | char_limit = self.state_machine.char_limit 175 | consumed_count = self.consumed_character_count 176 | 177 | # Find the longest valid prefix efficiently 178 | valid_prefix_len = 0 179 | for char in token: 180 | # Stop at first invalid character or limit 181 | is_blacklisted = char in blacklist 182 | not_in_charset = len(charset) > 0 and char not in charset 183 | exceeds_limit = char_limit > 0 and valid_prefix_len + consumed_count >= char_limit 184 | is_graylisted = len(graylist) > 0 and valid_prefix_len > 0 and char in graylist  # only after this token has contributed a character 185 | 186 | if is_blacklisted or not_in_charset or exceeds_limit or is_graylisted: 187 | break 188 | 189 | valid_prefix_len += 1 190 | 191 | # Extract the valid portion using string slicing 192 | valid_prefix = token[:valid_prefix_len] 193 |
194 | # Create new stepper with updated state 195 | new_value = self.get_raw_value() + valid_prefix 196 | remaining_input = token[len(valid_prefix):] or None  # None when the whole token was consumed 197 | new_stepper = self.step(new_value, remaining_input) 198 | 199 | return [new_stepper] 200 | -------------------------------------------------------------------------------- /pse/types/base/encapsulated.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from typing import Self 5 | 6 | from pse_core import StateId 7 | from pse_core.state_machine import StateMachine 8 | from pse_core.stepper import Stepper 9 | 10 | from pse.types.base.phrase import PhraseStateMachine 11 | from pse.types.base.wait_for import WaitFor 12 | 13 | 14 | class EncapsulatedStateMachine(StateMachine): 15 | """ 16 | Wraps a state machine so that its content must be framed by 17 | an opening and a closing delimiter. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | state_machine: StateMachine, 23 | delimiters: tuple[str, str] | None, 24 | buffer_length: int = -1, 25 | is_optional: bool = False, 26 | ) -> None: 27 | """Build the graph: wait for the opening delimiter, run state_machine, then match the closing delimiter. 28 | 29 | Args: 30 | state_machine: The state_machine wrapped by this state machine. 31 | delimiters: The (opening, closing) delimiter pair; None falls back to ("```", "```").
32 | buffer_length: Passed to WaitFor for the opening delimiter (is_optional is forwarded to StateMachine). """ 33 | self.inner_state_machine = state_machine 34 | self.delimiters = delimiters or ("```", "```") 35 | super().__init__( 36 | { 37 | 0: [ 38 | ( 39 | WaitFor( 40 | PhraseStateMachine(self.delimiters[0]), 41 | buffer_length=buffer_length, 42 | ), 43 | 1, 44 | ), 45 | ], 46 | 1: [(state_machine, 2)], 47 | 2: [(PhraseStateMachine(self.delimiters[1]), "$")], 48 | }, 49 | is_optional=is_optional, 50 | ) 51 | 52 | def get_new_stepper(self, state: StateId | None = None) -> EncapsulatedStepper: 53 | return EncapsulatedStepper(self, state) 54 | 55 | class EncapsulatedStepper(Stepper): 56 | 57 | def __init__( 58 | self, 59 | state_machine: EncapsulatedStateMachine, 60 | state: StateId | None = None, 61 | ) -> None: 62 | super().__init__(state_machine, state) 63 | self.state_machine: EncapsulatedStateMachine = state_machine 64 | self.inner_stepper: Stepper | None = None 65 | 66 | def clone(self) -> Self: 67 | clone = super().clone() 68 | if self.inner_stepper: 69 | clone.inner_stepper = self.inner_stepper.clone()  # clone the inner stepper too so the copies advance independently 70 | return clone 71 | 72 | def is_within_value(self) -> bool: 73 | if self.current_state == 0 and self.sub_stepper: 74 | return self.sub_stepper.is_within_value() 75 | 76 | return self.current_state != 0 77 | 78 | def add_to_history(self, stepper: Stepper) -> None: 79 | if self.current_state == 2: 80 | self.inner_stepper = stepper  # presumably the wrapped state_machine's stepper (edge 1 -> 2) — confirm add_to_history runs after the state update 81 | 82 | return super().add_to_history(stepper) 83 | 84 | def get_invalid_continuations(self) -> list[str]: 85 | if not self.inner_stepper: 86 | return [self.state_machine.delimiters[1]]  # the closing delimiter is invalid until the wrapped content has been recorded 87 | return super().get_invalid_continuations() 88 | 89 | def get_final_state(self) -> list[Stepper]: 90 | return [self] 91 | 92 | def get_token_safe_output(self, decode_function: Callable[[list[int]], str]) -> str: 93 | """ 94 | Retrieve the token-safe output with delimiters removed. 95 | 96 | This method processes the raw output by removing the encapsulating delimiters, 97 | handling both complete and partial delimiter occurrences efficiently.
98 | 99 | Args: 100 | decode_function: Function to decode token IDs into a string 101 | 102 | Returns: 103 | Processed string with delimiters stripped 104 | """ 105 | # Get and decode the token history 106 | token_ids = self.get_token_ids_history() 107 | token_safe_output: str = decode_function(token_ids).strip() 108 | 109 | # Extract delimiters 110 | start_delim, end_delim = self.state_machine.delimiters 111 | 112 | # Remove start delimiter - optimize by checking exact match first 113 | # This is faster than always using lstrip 114 | if token_safe_output.startswith(start_delim): 115 | token_safe_output = token_safe_output[len(start_delim):] 116 | else: 117 | token_safe_output = token_safe_output.lstrip(start_delim) 118 | 119 | # Remove end delimiter - optimize by checking exact match first 120 | # This is faster than always using rstrip 121 | if token_safe_output.endswith(end_delim): 122 | token_safe_output = token_safe_output[:-len(end_delim)] 123 | else: 124 | token_safe_output = token_safe_output.rstrip(end_delim) 125 | 126 | return token_safe_output 127 | -------------------------------------------------------------------------------- /pse/types/base/loop.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Self 5 | 6 | from pse_core import StateGraph 7 | from pse_core.state_machine import StateMachine 8 | from pse_core.stepper import Stepper 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class LoopStateMachine(StateMachine): 14 | """ 15 | Loop through a single StateMachine. 
16 | """ 17 | 18 | def __init__( 19 | self, 20 | state_machine: StateMachine, 21 | min_loop_count: int = 1, 22 | max_loop_count: int = -1, 23 | separator_state_machine: StateMachine | None = None, 24 | track_separator: bool = True, 25 | ) -> None: 26 | """ 27 | Args: 28 | state_machine: State machine to be looped through; max_loop_count <= 0 means no upper bound on iterations 29 | """ 30 | self.separator_state_machine = separator_state_machine 31 | self.track_separator = track_separator 32 | if self.separator_state_machine: 33 | state_graph: StateGraph = { 34 | 0: [(state_machine, 1)], 35 | 1: [(self.separator_state_machine, 2)], 36 | 2: [(state_machine, 1)], 37 | } 38 | else: 39 | state_graph: StateGraph = { 40 | 0: [(state_machine, 1)], 41 | 1: [(state_machine, 0)], 42 | } 43 | 44 | super().__init__( 45 | state_graph=state_graph, 46 | is_optional=min_loop_count == 0, 47 | ) 48 | self.min_loop_count = min_loop_count or 1 49 | self.max_loop_count = max_loop_count 50 | 51 | def get_new_stepper(self, state: int | str | None = None) -> Stepper: 52 | return LoopStepper(self, state) 53 | 54 | def __str__(self) -> str: 55 | return "Loop" 56 | 57 | 58 | class LoopStepper(Stepper): 59 | """ 60 | A stepper that loops through a single StateMachine. 61 | """ 62 | 63 | def __init__(self, loop_state_machine: LoopStateMachine, *args, **kwargs) -> None: 64 | super().__init__(loop_state_machine, *args, **kwargs) 65 | self.state_machine: LoopStateMachine = loop_state_machine 66 | self.loop_count = 0 67 | 68 | def clone(self) -> Self: 69 | clone = super().clone() 70 | clone.loop_count = self.loop_count 71 | return clone 72 | 73 | def should_branch(self) -> bool: 74 | return super().should_branch() and (self.state_machine.max_loop_count <= 0 or self.loop_count < self.state_machine.max_loop_count) 75 | 76 | def has_reached_accept_state(self) -> bool: 77 | """ 78 | Determines if this stepper is in an accept state. 79 | 80 | A loop stepper is in an accept state if: 81 | 1. 
It has completed at least the minimum required iterations, and 82 | 2.
If it's currently processing a sub-stepper, that sub-stepper is in an accept state 83 | 84 | Returns: 85 | True if the stepper is in an accept state, False otherwise 86 | """ 87 | # Early exit if we haven't reached the minimum loop count 88 | if self.loop_count < self.state_machine.min_loop_count: 89 | return False 90 | 91 | # If we're currently processing a value through a sub-stepper, 92 | # delegate to that sub-stepper's accept state 93 | if self.sub_stepper is not None and self.sub_stepper.is_within_value(): 94 | return self.sub_stepper.has_reached_accept_state() 95 | 96 | # Otherwise, we're in an accept state if we've met the minimum loop count 97 | return True 98 | 99 | def consume(self, token: str) -> list[Stepper]: 100 | new_steppers: list[Stepper] = [] 101 | 102 | def _validate_loop_stepper(stepper: LoopStepper) -> Stepper | None: 103 | """ 104 | Validate that the loop stepper respects the max loop count. 105 | """ 106 | if self.state_machine.max_loop_count <= 0 or stepper.loop_count < self.state_machine.max_loop_count: 107 | # if the loop is unbounded (max_loop_count <= 0) or the loop count 108 | # is still below the max, we can just return the stepper 109 | return stepper 110 | elif stepper.loop_count == self.state_machine.max_loop_count: 111 | # if the loop count is equal to the max loop count, 112 | # we need to make sure the stepper is not expecting a transition 113 | # to a new state 114 | if stepper.target_state is not None: 115 | stepper.current_state = stepper.target_state 116 | stepper.target_state = None 117 | stepper.sub_stepper = None 118 | return stepper 119 | else: 120 | # otherwise, the loop stepper is invalid 121 | return None 122 | 123 | # explicitly check that the new steppers respect the max loop count 124 | for new_stepper in super().consume(token): 125 | assert isinstance(new_stepper, LoopStepper) 126 | if valid_stepper := _validate_loop_stepper(new_stepper): 127 | new_steppers.append(valid_stepper) 128 | 129 | return new_steppers 130 | 131 | def can_accept_more_input(self) -> bool: 132 | if not
super().can_accept_more_input(): 133 | return False 134 | 135 | if self.state_machine.max_loop_count > 0: 136 | return self.loop_count < self.state_machine.max_loop_count 137 | 138 | return True 139 | 140 | def should_start_step(self, token: str) -> bool: 141 | if 0 < self.state_machine.max_loop_count <= self.loop_count: # bounded loop already exhausted; max_loop_count <= 0 means unbounded 142 | return False 143 | 144 | return super().should_start_step(token) 145 | 146 | def add_to_history(self, stepper: Stepper) -> None: 147 | if ( 148 | self.state_machine.separator_state_machine 149 | and stepper.state_machine == self.state_machine.separator_state_machine 150 | ): 151 | if not self.state_machine.track_separator: 152 | return 153 | else: 154 | self.loop_count += 1 155 | 156 | return super().add_to_history(stepper) 157 | 158 | def get_final_state(self) -> list[Stepper]: 159 | """ 160 | Gets the final state representation for this stepper. 161 | 162 | This method decides which stepper(s) to return as the final state: 163 | - If we're within a value in the sub-stepper (and it's not a whitespace separator), 164 | delegate to the sub-stepper's final state 165 | - Otherwise, return this stepper's history 166 | 167 | Returns: 168 | List of steppers representing the final state 169 | """ 170 | # If no sub-stepper exists, return history 171 | # If sub-stepper isn't processing a value, return history 172 | if not self.sub_stepper or not self.sub_stepper.is_within_value(): 173 | return self.history 174 | 175 | # Check if sub-stepper is a separator we should ignore 176 | separator_state_machine = self.state_machine.separator_state_machine 177 | is_separator = ( 178 | separator_state_machine is not None 179 | and self.sub_stepper.state_machine == separator_state_machine 180 | ) 181 | 182 | # If it's a separator we don't want to track, return history 183 | if is_separator and not self.state_machine.track_separator: 184 | return self.history 185 | 186 | # Otherwise, delegate to the sub-stepper's final state 187 | return
self.sub_stepper.get_final_state() 188 | -------------------------------------------------------------------------------- /pse/types/base/phrase.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from pse_core.state_machine import StateMachine 6 | from pse_core.stepper import Stepper 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class PhraseStateMachine(StateMachine): 12 | """ 13 | Accepts a predefined sequence of characters, validating input against the specified text. 14 | 15 | Attributes: 16 | phrase (str): The target string that this state_machine is validating against. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | phrase: str, 22 | is_optional: bool = False, 23 | is_case_sensitive: bool = True, 24 | ): 25 | """ 26 | Initialize a new PhraseStateMachine instance with the specified text. 27 | 28 | Args: 29 | phrase (str): The string of characters that this state_machine will validate. 30 | Must be a non-empty string. 31 | 32 | Raises: 33 | ValueError: If the provided text is empty. 34 | """ 35 | super().__init__( 36 | is_optional=is_optional, 37 | is_case_sensitive=is_case_sensitive, 38 | ) 39 | 40 | if not phrase: 41 | raise ValueError("Phrase must be a non-empty string.") 42 | 43 | self.phrase = phrase 44 | 45 | def get_new_stepper(self, state: int | str | None = None) -> PhraseStepper: 46 | return PhraseStepper(self) 47 | 48 | def __str__(self) -> str: 49 | """ 50 | Provide a string representation of the PhraseStateMachine. 51 | 52 | Returns: 53 | str: A string representation of the PhraseStateMachine. 
54 | """ 55 | return f"Phrase({self.phrase!r})" 56 | 57 | def __eq__(self, other: object) -> bool: 58 | return ( 59 | isinstance(other, PhraseStateMachine) 60 | and self.phrase == other.phrase 61 | ) 62 | 63 | 64 | class PhraseStepper(Stepper): 65 | 66 | def __init__( 67 | self, 68 | state_machine: PhraseStateMachine, 69 | consumed_character_count: int | None = None, 70 | ): 71 | super().__init__(state_machine) 72 | if consumed_character_count is not None and consumed_character_count < 0: 73 | raise ValueError("Consumed character count must be non-negative") 74 | 75 | self.consumed_character_count = consumed_character_count or 0 76 | self.state_machine: PhraseStateMachine = state_machine 77 | self.target_state = "$" 78 | 79 | def can_accept_more_input(self) -> bool: 80 | """ 81 | Check if the stepper can accept more input. 82 | """ 83 | return self.consumed_character_count < len(self.state_machine.phrase) 84 | 85 | def should_start_step(self, token: str) -> bool: 86 | """ 87 | Start a transition if the token is not empty and matches the remaining text. 88 | """ 89 | if not token: 90 | return False 91 | 92 | valid_length = self._get_valid_match_length(token) 93 | return valid_length > 0 94 | 95 | def should_complete_step(self) -> bool: 96 | return self.consumed_character_count == len(self.state_machine.phrase) 97 | 98 | def get_valid_continuations(self, depth: int = 0) -> list[str]: 99 | if self.consumed_character_count >= len(self.state_machine.phrase): 100 | return [] 101 | 102 | remaining_text = self.state_machine.phrase[self.consumed_character_count :] 103 | return [remaining_text] 104 | 105 | def consume(self, token: str) -> list[Stepper]: 106 | """ 107 | Advances the stepper if the token matches the expected text at the current position. 108 | Args: 109 | token (str): The string to match against the expected text. 110 | 111 | Returns: 112 | list[Stepper]: A stepper if the token matches, empty otherwise. 
113 | """ 114 | valid_length = self._get_valid_match_length(token) 115 | if valid_length <= 0: 116 | return [] 117 | 118 | new_value = self.get_raw_value() + token[:valid_length] 119 | remaining_input = token[valid_length:] if valid_length < len(token) else None 120 | new_stepper = self.step(new_value, remaining_input) 121 | return [new_stepper] 122 | 123 | def get_raw_value(self) -> str: 124 | return self.state_machine.phrase[: self.consumed_character_count] 125 | 126 | def _get_valid_match_length(self, token: str, pos: int | None = None) -> int: 127 | """ 128 | Calculate the length of the matching prefix between the token and the target phrase. 129 | 130 | Args: 131 | token: The input string to check against the target phrase 132 | pos: Starting position in the phrase (defaults to current consumed count) 133 | 134 | Returns: 135 | Length of the matching prefix 136 | """ 137 | # Use current position only when not specified (an explicit 0 is honored) 138 | pos = self.consumed_character_count if pos is None else pos 139 | 140 | # Get the remaining portion of the phrase to match against 141 | remaining_phrase = self.state_machine.phrase[pos:] 142 | 143 | # Determine maximum possible match length 144 | max_length = min(len(token), len(remaining_phrase)) 145 | 146 | # Find the longest matching prefix: compare characters in order and 147 | # stop at the first mismatch (linear; no re-slicing per iteration) 148 | for i in range(max_length): 149 | if token[i] != remaining_phrase[i]: 150 | return i 151 | 152 | return max_length 153 | -------------------------------------------------------------------------------- /pse/types/base/wait_for.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Self 5 | 6 | from pse_core import StateId 7 | from pse_core.state_machine import StateMachine 8 | from pse_core.stepper import Stepper 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class WaitFor(StateMachine): 14
| """ 15 | Accept all text until a segment triggers a nested state machine. 16 | 17 | Accumulates text in a buffer until a segment triggers the nested state machine. 18 | 19 | This is particularly useful for allowing free-form text until a specific 20 | delimiter or pattern is detected, such as when parsing output from 21 | language models that encapsulate JSON within markdown code blocks. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | state_machine: StateMachine, 27 | buffer_length: int = -1, 28 | strict: bool = True, 29 | ): 30 | """ 31 | Initialize with a target nested StateMachine. 32 | 33 | Args: 34 | state_machine (StateMachine): The nested StateMachine to watch for. 35 | buffer_length (int): 36 | The minimum number of characters to buffer before the nested StateMachine may start. 37 | strict (bool): 38 | If True, a branch whose nested StateMachine has already started matching is dropped when invalid input is detected. 39 | """ 40 | super().__init__() 41 | 42 | self.min_buffer_length = buffer_length 43 | self.strict = strict 44 | self.wait_for_sm = state_machine 45 | 46 | def get_transitions(self, _: Stepper) -> list[tuple[Stepper, StateId]]: 47 | transitions = [] 48 | for transition in self.wait_for_sm.get_steppers(): 49 | transitions.append((transition, "$")) 50 | return transitions 51 | 52 | def get_new_stepper(self, _: StateId | None = None) -> Stepper: 53 | return WaitForStepper(self) 54 | 55 | def get_steppers(self, _: StateId | None = None) -> list[Stepper]: 56 | return self.branch_stepper(self.get_new_stepper()) 57 | 58 | def __str__(self) -> str: 59 | return f"WaitFor({self.wait_for_sm})" 60 | 61 | 62 | class WaitForStepper(Stepper): 63 | def __init__(self, state_machine: WaitFor): 64 | super().__init__(state_machine) 65 | self.target_state = "$" 66 | self.state_machine: WaitFor = state_machine 67 | self.buffer = "" 68 | 69 | def clone(self) -> Self: 70 | clone = super().clone() 71 | clone.buffer = self.buffer 72 | return clone 73 | 74 | def accepts_any_token(self) -> bool: 75 | """ 76 | 
Determines if this stepper can accept
any token based on buffer state. 77 | 78 | The stepper accepts any token if: 79 | 1. The buffer meets minimum length requirements, or 80 | 2. The sub-stepper is active and accepts any token 81 | 82 | Returns: 83 | True if the stepper can accept any token, False otherwise 84 | """ 85 | # Cache min_buffer_length for performance 86 | min_buffer_length = self.state_machine.min_buffer_length 87 | 88 | # Delegate to sub_stepper if it's active 89 | if self.sub_stepper and self.sub_stepper.is_within_value(): 90 | return self.sub_stepper.accepts_any_token() 91 | 92 | # If the buffer is not long enough, we can accept any token 93 | if len(self.buffer) < min_buffer_length: 94 | return True 95 | 96 | # Otherwise, check the size of the buffer 97 | return len(self.buffer) >= min_buffer_length 98 | 99 | def get_valid_continuations(self) -> list[str]: 100 | """ 101 | If the buffer is long enough, we can accept any valid continuations. 102 | 103 | If the buffer is not long enough, we can accept everything. 104 | """ 105 | if len(self.buffer) >= self.state_machine.min_buffer_length: 106 | return super().get_valid_continuations() 107 | return [] 108 | 109 | def get_invalid_continuations(self) -> list[str]: 110 | """ 111 | If the buffer is not long enough yet, 112 | any valid continuation is inversed and 113 | invalid to allow the buffer to grow. 114 | 115 | If the buffer is long enough, there are no invalid continuations. 116 | """ 117 | if len(self.buffer) < self.state_machine.min_buffer_length and self.sub_stepper: 118 | return self.sub_stepper.get_valid_continuations() 119 | return [] 120 | 121 | def should_start_step(self, token: str) -> bool: 122 | """ 123 | Determines if the stepper should start processing the token. 124 | 125 | This method decides whether to start a step based on: 126 | 1. Whether we have remaining input from a previous token 127 | 2. The buffer length requirements 128 | 3. 
The stepper's current state 129 | 130 | Args: 131 | token: The token to potentially process 132 | 133 | Returns: 134 | True if the step should start, False otherwise 135 | """ 136 | # Never start a step if we have remaining input 137 | if self.remaining_input: 138 | return False 139 | 140 | # Cache frequently accessed values 141 | required_buffer_length = self.state_machine.min_buffer_length 142 | should_start = super().should_start_step(token) 143 | 144 | # Handle unlimited buffer length case 145 | if required_buffer_length <= 0: 146 | return should_start or not self.is_within_value() 147 | 148 | # For cases with a positive buffer length requirement 149 | buffer_length = len(self.buffer) 150 | is_in_value = self.is_within_value() 151 | 152 | # Return True if either: 153 | # 1. super().should_start_step() returns True and we have enough buffer, or 154 | # 2. super().should_start_step() returns False but we're not within a value 155 | if should_start and buffer_length >= required_buffer_length: 156 | return True 157 | if not should_start and not is_in_value: 158 | return True 159 | 160 | return False 161 | 162 | def consume(self, token: str) -> list[Stepper]: 163 | # No sub_stepper means we can't process anything 164 | if not self.sub_stepper: 165 | return [] 166 | 167 | # Try to find the longest valid prefix that the sub_stepper will accept 168 | invalid_prefix = "" 169 | valid_suffix = token 170 | 171 | while valid_suffix and not self.sub_stepper.should_start_step(valid_suffix): 172 | invalid_prefix += valid_suffix[0] 173 | valid_suffix = valid_suffix[1:] 174 | 175 | if self.state_machine.strict and self.is_within_value() and invalid_prefix: 176 | return [] 177 | 178 | if invalid_prefix and ( 179 | not self.is_within_value() or not self.state_machine.strict 180 | ): 181 | if not self.is_within_value() and self.state_machine.min_buffer_length == -1: 182 | return [] 183 | 184 | clone = self.clone() 185 | clone.buffer += invalid_prefix 186 | if valid_suffix: 187 | 
return self.state_machine.advance_stepper(clone, valid_suffix) 188 | else: 189 | return [clone] 190 | 191 | return self.state_machine.advance_stepper(self, valid_suffix) 192 | -------------------------------------------------------------------------------- /pse/types/boolean.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pse_core import StateId 4 | from pse_core.state_machine import StateMachine 5 | from pse_core.stepper import Stepper 6 | 7 | from pse.types.base.phrase import PhraseStateMachine 8 | 9 | 10 | class BooleanStateMachine(StateMachine): 11 | """ 12 | Accepts a JSON boolean value: true, false. 13 | """ 14 | 15 | def __init__(self) -> None: 16 | super().__init__( 17 | { 18 | 0: [ 19 | (PhraseStateMachine("true"), "$"), 20 | (PhraseStateMachine("false"), "$"), 21 | ] 22 | } 23 | ) 24 | 25 | def get_steppers(self, state: StateId | None = None) -> list[Stepper]: 26 | steppers = [] 27 | for edge, _ in self.get_edges(state or 0): 28 | steppers.extend(edge.get_steppers()) 29 | return steppers 30 | -------------------------------------------------------------------------------- /pse/types/enum.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pse_core import StateGraph, StateId 4 | from pse_core.state_machine import StateMachine 5 | from pse_core.stepper import Stepper 6 | 7 | from pse.types.base.chain import ChainStateMachine 8 | from pse.types.base.phrase import PhraseStateMachine 9 | 10 | 11 | class EnumStateMachine(StateMachine): 12 | """ 13 | Accept one of several constant strings. 
14 | """ 15 | 16 | def __init__(self, enum_values: list[str], require_quotes: bool = True) -> None: 17 | if not enum_values: 18 | raise ValueError("Enum values must be provided.") 19 | 20 | state_graph: StateGraph = {0: []} 21 | unique_enum_values = list(dict.fromkeys(enum_values)) # de-duplicate, keeping first-seen order (deterministic across runs, unlike set) 22 | for value in unique_enum_values: 23 | sm = ( 24 | PhraseStateMachine(value) 25 | if not require_quotes 26 | else ChainStateMachine( 27 | [ 28 | PhraseStateMachine('"'), 29 | PhraseStateMachine(value), 30 | PhraseStateMachine('"'), 31 | ] 32 | ) 33 | ) 34 | state_graph[0].append((sm, "$")) 35 | 36 | super().__init__(state_graph) 37 | 38 | def get_steppers(self, state: StateId | None = None) -> list[Stepper]: 39 | steppers = [] 40 | for edge, _ in self.get_edges(state or 0): 41 | steppers.extend(edge.get_steppers()) 42 | return steppers 43 | -------------------------------------------------------------------------------- /pse/types/grammar/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from lark import Lark 4 | from lark.exceptions import UnexpectedCharacters, UnexpectedEOF, UnexpectedToken 5 | 6 | 7 | class LarkGrammar(ABC): 8 | 9 | name: str 10 | lark_grammar: Lark 11 | delimiters: tuple[str, str] | None = None 12 | 13 | def __init__( 14 | self, 15 | name: str, 16 | lark_grammar: Lark, 17 | delimiters: tuple[str, str] | None = None, 18 | ): 19 | self.name = name 20 | self.lark_grammar = lark_grammar 21 | self.delimiters = delimiters 22 | 23 | @abstractmethod 24 | def validate( 25 | self, 26 | input: str, 27 | strict: bool = False, 28 | start: str | None = None, 29 | ) -> bool: 30 | """ 31 | Validate the input against the grammar. 32 | 33 | Args: 34 | input (str): The input to validate. 35 | strict (bool): Whether to use strict validation. 36 | start (str): The start rule to use. 37 | 38 | Returns: 39 | bool: True if the input is valid, False otherwise.
40 | """ 41 | try: 42 | self.lark_grammar.parse(input, start=start) 43 | return True 44 | except Exception as e: 45 | if not strict: 46 | if isinstance(e, UnexpectedEOF | UnexpectedCharacters): 47 | return True 48 | elif isinstance(e, UnexpectedToken) and e.token.type == "$END": 49 | return True 50 | 51 | return False 52 | 53 | from pse.types.grammar.default_grammars.bash import BashGrammar # noqa: E402 54 | from pse.types.grammar.default_grammars.python import PythonGrammar # noqa: E402 55 | from pse.types.grammar.lark import LarkGrammarStateMachine # noqa: E402 56 | 57 | PythonStateMachine = LarkGrammarStateMachine(PythonGrammar()) 58 | BashStateMachine = LarkGrammarStateMachine(BashGrammar()) 59 | 60 | __all__ = [ 61 | "BashStateMachine", 62 | "LarkGrammarStateMachine", 63 | "PythonStateMachine", 64 | ] 65 | -------------------------------------------------------------------------------- /pse/types/grammar/default_grammars/bash.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | 6 | from lark import Lark 7 | from lark.exceptions import UnexpectedCharacters, UnexpectedToken 8 | 9 | from pse.types.grammar import LarkGrammar 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class BashGrammar(LarkGrammar): 15 | def __init__(self): 16 | # Get the path to the bash.lark file 17 | current_dir = os.path.dirname(os.path.abspath(__file__)) 18 | grammar_path = os.path.join(current_dir, "bash.lark") 19 | 20 | # Read the Lark file 21 | with open(grammar_path) as f: 22 | bash_grammar_content = f.read() 23 | 24 | bash_lark_grammar = Lark( 25 | bash_grammar_content, 26 | start="start", 27 | parser="lalr", 28 | lexer="basic", 29 | ) 30 | 31 | super().__init__( 32 | name="Bash", 33 | lark_grammar=bash_lark_grammar, 34 | delimiters=("```bash\n", "\n```"), 35 | ) 36 | 37 | def validate( 38 | self, 39 | input: str, 40 | strict: bool = False, 41 | start: str | 
None = None, 42 | ) -> bool: 43 | """ 44 | Validate Bash code using the Lark parser. 45 | 46 | Args: 47 | input: The Bash code to validate. 48 | strict: Whether to use strict validation. 49 | start: The start rule to use. 50 | """ 51 | # If code is empty, it's not valid bash 52 | if not input.strip(): 53 | return False 54 | 55 | try: 56 | # Try to parse the input normally 57 | self.lark_grammar.parse(input, start=start or "start") 58 | return True 59 | except Exception as e: 60 | if not strict and isinstance(e, UnexpectedToken): 61 | return e.token.type == "$END" or "ESAC" in e.expected 62 | 63 | if not strict and isinstance(e, UnexpectedCharacters): 64 | # special case for unclosed quotes 65 | return e.char == "'" or e.char == '"' 66 | 67 | return False 68 | -------------------------------------------------------------------------------- /pse/types/grammar/default_grammars/python.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | 6 | from lark import Lark 7 | from lark.exceptions import UnexpectedCharacters, UnexpectedToken 8 | from lark.indenter import PythonIndenter 9 | 10 | from pse.types.grammar import LarkGrammar 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class PythonGrammar(LarkGrammar): 15 | 16 | def __init__(self): 17 | current_dir = os.path.dirname(os.path.abspath(__file__)) 18 | grammar_path = os.path.join(current_dir, "python.lark") 19 | 20 | # Read the Lark file 21 | with open(grammar_path) as f: 22 | python_grammar_content = f.read() 23 | python_lark_grammar = Lark( 24 | python_grammar_content, 25 | parser="lalr", 26 | lexer="basic", 27 | postlex=PythonIndenter(), 28 | start=["file_input"], 29 | ) 30 | 31 | super().__init__( 32 | name="Python", 33 | lark_grammar=python_lark_grammar, 34 | delimiters=("```python\n", "\n```"), 35 | ) 36 | 37 | def validate( 38 | self, 39 | input: str, 40 | strict: bool = False, 41 | start: str | 
None = None, 42 | ) -> bool: 43 | """ 44 | Validate Python code using the Lark parser. 45 | 46 | Args: 47 | input: The Python code to validate. 48 | strict: Whether to use strict validation (a trailing newline is appended and incomplete code is rejected). 49 | start: The start rule to use. 50 | """ 51 | if strict and not input.endswith("\n"): 52 | input += "\n" 53 | 54 | try: 55 | self.lark_grammar.parse(input, start=start) 56 | return True 57 | except Exception as e: 58 | if not strict and isinstance(e, UnexpectedToken): 59 | return e.token.type == "_DEDENT" or e.token.type == "$END" 60 | if not strict and isinstance(e, UnexpectedCharacters): 61 | # special case for unclosed quotes 62 | return e.char == "'" or e.char == '"' 63 | 64 | return False 65 | -------------------------------------------------------------------------------- /pse/types/grammar/lark.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from pse_core import StateId 6 | from pse_core.stepper import Stepper 7 | 8 | from pse.types.base.character import CharacterStateMachine, CharacterStepper 9 | from pse.types.grammar import LarkGrammar 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class LarkGrammarStateMachine(CharacterStateMachine): 14 | def __init__(self, grammar: LarkGrammar): 15 | super().__init__(char_min=1) 16 | self.grammar = grammar 17 | 18 | @property 19 | def delimiters(self) -> tuple[str, str] | None: 20 | return self.grammar.delimiters 21 | 22 | def get_new_stepper(self, state: StateId | None) -> LarkGrammarStepper: 23 | """ 24 | Get a new stepper for the grammar. 25 | """ 26 | return LarkGrammarStepper(self) 27 | 28 | def __str__(self) -> str: 29 | return self.grammar.name 30 | 31 | 32 | class LarkGrammarStepper(CharacterStepper): 33 | """ 34 | A stepper for the grammar state machine. 35 | """ 36 | 37 | def __init__(self, state_machine: LarkGrammarStateMachine): 38 | """ 39 | Initialize the grammar stepper with a state machine.
40 | 41 | Args: 42 | state_machine: The grammar state machine that defines the valid transitions 43 | """ 44 | super().__init__(state_machine) 45 | self.state_machine: LarkGrammarStateMachine = state_machine 46 | 47 | def get_identifier(self) -> str | None: 48 | return self.state_machine.grammar.name.lower() 49 | 50 | def should_start_step(self, token: str) -> bool: 51 | """ 52 | Should the stepper start a new step? 53 | """ 54 | valid_prefix, _ = self.get_valid_prefix(token) 55 | return valid_prefix is not None 56 | 57 | def has_reached_accept_state(self) -> bool: 58 | """ 59 | Has the stepper reached the accept state? 60 | """ 61 | valid_input = self.get_raw_value() 62 | return self.state_machine.grammar.validate(valid_input, strict=True) 63 | 64 | def consume(self, token: str) -> list[Stepper]: 65 | """ 66 | Consume the input token and return possible next states. 67 | 68 | Args: 69 | token: The input string to consume 70 | 71 | Returns: 72 | A list of new steppers after consuming the token. 73 | Returns empty list if no valid transitions are possible. 74 | """ 75 | valid_input, remaining_input = self.get_valid_prefix(token) 76 | assert valid_input is not None 77 | return [ 78 | self.step( 79 | self.get_raw_value() + valid_input, 80 | remaining_input or None, 81 | ) 82 | ] 83 | 84 | def get_valid_prefix(self, new_input: str) -> tuple[str | None, str]: 85 | """ 86 | Get the first prefix of the new input that maintains a valid grammar state. 
87 | 88 | Args: 89 | new_input: The input string to validate 90 | 91 | Returns: 92 | A tuple of (valid_prefix, remaining_input) where valid_prefix is None if no 93 | valid prefix exists 94 | """ 95 | candidate_base = self.get_raw_value() 96 | max_valid_index = None 97 | for i in range(1, len(new_input) + 1): 98 | candidate = candidate_base + new_input[:i] 99 | if self.state_machine.grammar.validate(candidate, False): 100 | max_valid_index = i 101 | break 102 | 103 | if max_valid_index is not None: 104 | return new_input[:max_valid_index], new_input[max_valid_index:] 105 | else: 106 | return None, "" 107 | -------------------------------------------------------------------------------- /pse/types/integer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from pse.types.base.character import ( 6 | CharacterStateMachine, 7 | CharacterStepper, 8 | ) 9 | 10 | 11 | class IntegerStateMachine(CharacterStateMachine): 12 | """ 13 | Accepts an integer as per JSON specification. 
14 | """ 15 | 16 | def __init__(self, drop_leading_zeros: bool = True) -> None: 17 | super().__init__("0123456789") 18 | self.drop_leading_zeros = drop_leading_zeros 19 | 20 | def get_new_stepper(self, state: int | str) -> IntegerStepper: 21 | return IntegerStepper(self) 22 | 23 | def __str__(self) -> str: 24 | return "Integer" 25 | 26 | 27 | class IntegerStepper(CharacterStepper): 28 | def __init__( 29 | self, state_machine: IntegerStateMachine, value: str | None = None 30 | ) -> None: 31 | super().__init__(state_machine, value) 32 | self.state_machine: IntegerStateMachine = state_machine 33 | 34 | def get_current_value(self) -> Any: 35 | if self._raw_value is None: 36 | return None 37 | return int(self._raw_value) if self.state_machine.drop_leading_zeros else self._raw_value 38 | -------------------------------------------------------------------------------- /pse/types/json/any_json_schema.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pse_core import StateId 4 | from pse_core.state_machine import StateMachine 5 | from pse_core.stepper import Stepper 6 | 7 | 8 | class AnySchemaStateMachine(StateMachine): 9 | """ 10 | Accepts JSON input that complies with any of several provided JSON schemas 11 | """ 12 | 13 | def __init__(self, schemas: list[dict[str, Any]], context: dict[str, Any]) -> None: 14 | """ 15 | This state_machine will validate JSON input against any of the provided schemas. 16 | 17 | Args: 18 | schemas (List[Dict[str, Any]]): A list of JSON schemas to validate against. 19 | context (Dict[str, Any]): Contextual information for schema definitions and paths. 20 | """ 21 | from pse.types.json import _json_schema_to_state_machine 22 | 23 | # Construct the state machine graph with an initial state `0` that transitions 24 | # to the end state `$` for each schema state_machine. 
25 | self.state_machines: list[StateMachine] = [] 26 | for schema in schemas: 27 | sm = _json_schema_to_state_machine(schema, context) 28 | self.state_machines.append(sm) 29 | 30 | super().__init__( 31 | {0: [(state_machine, "$") for state_machine in self.state_machines]} 32 | ) 33 | 34 | def get_steppers(self, state: StateId | None = None) -> list[Stepper]: 35 | steppers = [] 36 | for edge, _ in self.get_edges(state or 0): 37 | steppers.extend(edge.get_steppers()) 38 | return steppers 39 | 40 | def __str__(self) -> str: 41 | return "AnyJsonSchema" 42 | -------------------------------------------------------------------------------- /pse/types/json/json_array.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from pse_core import StateId 6 | from pse_core.stepper import Stepper 7 | 8 | from pse.types.array import ArrayStateMachine, ArrayStepper 9 | from pse.types.base.chain import ChainStateMachine 10 | from pse.types.base.phrase import PhraseStateMachine 11 | from pse.types.json import _json_schema_to_state_machine 12 | from pse.types.whitespace import WhitespaceStateMachine 13 | 14 | 15 | class ArraySchemaStateMachine(ArrayStateMachine): 16 | def __init__(self, schema: dict[str, Any], context: dict[str, Any]) -> None: 17 | self.schema = schema 18 | self.context = context 19 | super().__init__( 20 | { 21 | 0: [ 22 | (PhraseStateMachine("["), 1), 23 | ], 24 | 1: [ 25 | (WhitespaceStateMachine(), 2), 26 | (PhraseStateMachine("]"), "$"), 27 | ], 28 | 2: [ 29 | ( 30 | _json_schema_to_state_machine( 31 | self.schema["items"], self.context 32 | ), 33 | 3, 34 | ), 35 | ], 36 | 3: [ 37 | (WhitespaceStateMachine(), 4), 38 | ], 39 | 4: [ 40 | ( 41 | ChainStateMachine( 42 | [PhraseStateMachine(","), WhitespaceStateMachine()] 43 | ), 44 | 2, 45 | ), 46 | (PhraseStateMachine("]"), "$"), 47 | ], 48 | } 49 | ) 50 | 51 | def get_transitions(self, stepper: Stepper) -> 
list[tuple[Stepper, StateId]]: 52 | """Retrieve transition steppers from the current state. 53 | 54 | For each edge from the current state, returns steppers that can traverse that edge. 55 | Args: 56 | stepper: The stepper initiating the transition. 57 | state: Optional starting state. If None, uses the stepper's current state. 58 | 59 | Returns: 60 | list[tuple[Stepper, StateId]]: A list of tuples representing transitions. 61 | """ 62 | if stepper.current_state == 4: 63 | transitions: list[tuple[Stepper, StateId]] = [] 64 | if len(stepper.get_current_value()) >= self.min_items(): 65 | for transition in PhraseStateMachine("]").get_steppers(): 66 | transitions.append((transition, "$")) 67 | 68 | if len(stepper.get_current_value()) < self.max_items(): 69 | for transition in ChainStateMachine( 70 | [PhraseStateMachine(","), WhitespaceStateMachine()] 71 | ).get_steppers(): 72 | transitions.append((transition, 2)) 73 | 74 | return transitions 75 | elif stepper.current_state == 1 and self.min_items() > 0: 76 | transitions = [] 77 | for transition in WhitespaceStateMachine().get_steppers(): 78 | transitions.append((transition, 2)) 79 | return transitions 80 | else: 81 | return super().get_transitions(stepper) 82 | 83 | def get_new_stepper(self, state: StateId | None = None) -> ArraySchemaStepper: 84 | return ArraySchemaStepper(self, state) 85 | 86 | def min_items(self) -> int: 87 | """ 88 | Returns the minimum number of items in the array, according to the schema 89 | """ 90 | return self.schema.get("minItems", 0) 91 | 92 | def max_items(self) -> int: 93 | """ 94 | Returns the maximum number of items in the array, according to the schema 95 | """ 96 | return self.schema.get("maxItems", 2**32) 97 | 98 | def unique_items(self) -> bool: 99 | """ 100 | Returns whether the items in the array must be unique, according to the schema 101 | """ 102 | return self.schema.get("uniqueItems", False) 103 | 104 | def __str__(self) -> str: 105 | return "JSON" + super().__str__() 106 | 107 
| 108 | class ArraySchemaStepper(ArrayStepper): 109 | """ """ 110 | 111 | def __init__( 112 | self, 113 | state_machine: ArraySchemaStateMachine, 114 | current_state: StateId | None = None, 115 | ): 116 | super().__init__(state_machine, current_state) 117 | self.state_machine: ArraySchemaStateMachine = state_machine 118 | 119 | def add_to_history(self, stepper: Stepper) -> None: 120 | """ 121 | Adds an item to the array. 122 | """ 123 | item = stepper.get_current_value() 124 | if self.state_machine.unique_items() and self.is_within_value(): 125 | if item in self.value: 126 | return 127 | 128 | super().add_to_history(stepper) 129 | -------------------------------------------------------------------------------- /pse/types/json/json_key_value.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from pse_core import StateId 6 | 7 | from pse.types.base.chain import ChainStateMachine 8 | from pse.types.base.phrase import PhraseStateMachine 9 | from pse.types.json import _json_schema_to_state_machine 10 | from pse.types.key_value import KeyValueStateMachine, KeyValueStepper 11 | from pse.types.string import StringStateMachine 12 | from pse.types.whitespace import WhitespaceStateMachine 13 | 14 | 15 | class KeyValueSchemaStateMachine(KeyValueStateMachine): 16 | """ 17 | Args: 18 | prop_name (str): The name of the property. 19 | prop_schema (Dict[str, Any]): The schema of the property. 20 | context (Dict[str, Any]): The parsing context. 
21 | """ 22 | 23 | def __init__( 24 | self, 25 | prop_name: str | None, 26 | prop_schema: dict[str, Any], 27 | context: dict[str, Any], 28 | ): 29 | self.prop_name = prop_name 30 | self.prop_schema = prop_schema 31 | self.prop_context = { 32 | "defs": context.get("defs", {}), 33 | "path": f"{context.get('path', '')}/{prop_name}", 34 | } 35 | if self.prop_name: 36 | key_value_sm = ChainStateMachine( 37 | [ 38 | PhraseStateMachine('"'), 39 | PhraseStateMachine(self.prop_name), 40 | PhraseStateMachine('"'), 41 | ] 42 | ) 43 | else: 44 | key_value_sm = StringStateMachine() 45 | 46 | is_optional = self.prop_schema.get("nullable", False) or "default" in self.prop_schema 47 | super().__init__( 48 | [ 49 | key_value_sm, 50 | WhitespaceStateMachine(), 51 | PhraseStateMachine(":"), 52 | WhitespaceStateMachine(), 53 | _json_schema_to_state_machine(self.prop_schema, self.prop_context), 54 | ], 55 | is_optional=is_optional, 56 | ) 57 | 58 | def get_new_stepper(self, state: StateId | None = None) -> KeyValueSchemaStepper: 59 | return KeyValueSchemaStepper(self, state) 60 | 61 | 62 | class KeyValueSchemaStepper(KeyValueStepper): 63 | def __init__( 64 | self, 65 | state_machine: KeyValueSchemaStateMachine, 66 | current_state: StateId | None = None, 67 | ): 68 | super().__init__(state_machine, current_state) 69 | self.state_machine: KeyValueSchemaStateMachine = state_machine 70 | -------------------------------------------------------------------------------- /pse/types/json/json_number.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pse_core import StateId 4 | from pse_core.stepper import Stepper 5 | 6 | from pse.types.number import NumberStateMachine 7 | 8 | 9 | class NumberSchemaStateMachine(NumberStateMachine): 10 | """ 11 | Accept a JSON number that conforms to a JSON schema 12 | """ 13 | 14 | def __init__(self, schema): 15 | super().__init__() 16 | self.schema = schema 17 | 
self.is_integer = schema["type"] == "integer" 18 | self.requires_validation = any( 19 | constraint in schema 20 | for constraint in [ 21 | "minimum", 22 | "exclusiveMinimum", 23 | "maximum", 24 | "exclusiveMaximum", 25 | "multipleOf", 26 | ] 27 | ) 28 | 29 | def get_new_stepper(self, state: StateId | None = None) -> NumberSchemaStepper: 30 | return NumberSchemaStepper(self, state) 31 | 32 | def validate_value(self, value: float) -> bool: 33 | """ 34 | Validate the number value according to the schema 35 | """ 36 | if not isinstance(value, int | float): 37 | return True 38 | 39 | if "minimum" in self.schema and value < self.schema["minimum"]: 40 | return False 41 | 42 | if ( 43 | "exclusiveMinimum" in self.schema 44 | and value <= self.schema["exclusiveMinimum"] 45 | ): 46 | return False 47 | if "maximum" in self.schema and value > self.schema["maximum"]: 48 | return False 49 | if ( 50 | "exclusiveMaximum" in self.schema 51 | and value >= self.schema["exclusiveMaximum"] 52 | ): 53 | return False 54 | if "multipleOf" in self.schema: 55 | divisor = self.schema["multipleOf"] 56 | if value / divisor != value // divisor: 57 | return False 58 | 59 | if self.is_integer and not (isinstance(value, int) or value.is_integer()): 60 | return False 61 | 62 | return True 63 | 64 | def __str__(self) -> str: 65 | return "JSON" + super().__str__() 66 | 67 | 68 | class NumberSchemaStepper(Stepper): 69 | """ """ 70 | 71 | def __init__( 72 | self, 73 | state_machine: NumberSchemaStateMachine, 74 | current_state: StateId | None = None, 75 | ): 76 | super().__init__(state_machine, current_state) 77 | self.state_machine: NumberSchemaStateMachine = state_machine 78 | 79 | def should_start_step(self, token: str) -> bool: 80 | if self.state_machine.is_integer and self.target_state == 3: 81 | return False 82 | 83 | return super().should_start_step(token) 84 | 85 | def should_complete_step(self) -> bool: 86 | if not super().should_complete_step(): 87 | return False 88 | 89 | return 
self.state_machine.validate_value(self.get_current_value()) 90 | 91 | def has_reached_accept_state(self) -> bool: 92 | if super().has_reached_accept_state(): 93 | return self.state_machine.validate_value(self.get_current_value()) 94 | 95 | return False 96 | -------------------------------------------------------------------------------- /pse/types/json/json_object.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from pse_core import StateId 6 | from pse_core.state_machine import StateMachine 7 | from pse_core.stepper import Stepper 8 | 9 | from pse.types.base.chain import ChainStateMachine 10 | from pse.types.base.phrase import PhraseStateMachine 11 | from pse.types.json.json_key_value import KeyValueSchemaStateMachine 12 | from pse.types.key_value import KeyValueStateMachine 13 | from pse.types.object import ObjectStateMachine 14 | from pse.types.whitespace import WhitespaceStateMachine 15 | 16 | 17 | class ObjectSchemaStateMachine(ObjectStateMachine): 18 | def __init__( 19 | self, 20 | schema: dict[str, Any], 21 | context: dict[str, Any], 22 | ): 23 | self.schema = schema 24 | self.context = context 25 | self.properties: dict[str, Any] = schema.get("properties", {}) 26 | self.required_property_names: list[str] = schema.get("required", []) 27 | self.additional_properties: dict[str, Any] | bool = schema.get( 28 | "additionalProperties", {} 29 | ) 30 | self.ordered_properties: bool = schema.get("orderedProperties", True) 31 | if any(prop not in self.properties for prop in self.required_property_names): 32 | raise ValueError("Required property not defined in schema") 33 | 34 | for property_name, property_schema in self.properties.items(): 35 | if property_name in self.required_property_names and property_schema: 36 | if ( 37 | property_schema.get("nullable", False) 38 | or "default" in property_schema 39 | ): 40 | 
self.required_property_names.remove(property_name) 41 | 42 | super().__init__(schema.get("nullable", False)) 43 | 44 | def get_transitions(self, stepper: Stepper) -> list[tuple[Stepper, StateId]]: 45 | """Retrieve transition steppers from the current state. 46 | 47 | Returns: 48 | list[tuple[Stepper, StateId]]: A list of tuples representing transitions. 49 | """ 50 | value = stepper.get_current_value() 51 | transitions: list[tuple[Stepper, StateId]] = [] 52 | if stepper.current_state == 2: 53 | for property in self.get_property_state_machines(value): 54 | for transition in property.get_steppers(): 55 | transitions.append((transition, 3)) 56 | 57 | elif stepper.current_state == 4: 58 | if all(prop_name in value for prop_name in self.required_property_names): 59 | for transition in PhraseStateMachine("}").get_steppers(): 60 | transitions.append((transition, "$")) 61 | 62 | if len(value) < len(self.properties) or self.additional_properties: 63 | for transition in ChainStateMachine( 64 | [PhraseStateMachine(","), WhitespaceStateMachine()] 65 | ).get_steppers(): 66 | transitions.append((transition, 2)) 67 | else: 68 | return super().get_transitions(stepper) 69 | 70 | return transitions 71 | 72 | def get_property_state_machines(self, value: dict[str, Any]) -> list[StateMachine]: 73 | property_state_machines: list[StateMachine] = [] 74 | for prop_name, prop_schema in self.properties.items(): 75 | if prop_name not in value: 76 | property = KeyValueSchemaStateMachine( 77 | prop_name, 78 | prop_schema, 79 | self.context, 80 | ) 81 | property_state_machines.append(property) 82 | if self.ordered_properties: 83 | break 84 | 85 | if ( 86 | all(prop_name in value for prop_name in self.required_property_names) 87 | and self.additional_properties 88 | ): 89 | # non-schema kv property to represent the additional properties 90 | if isinstance(self.additional_properties, dict): 91 | property = KeyValueSchemaStateMachine( 92 | None, 93 | self.additional_properties, 94 | self.context, 
95 | ) 96 | else: 97 | property = KeyValueStateMachine() 98 | property_state_machines.append(property) 99 | 100 | return property_state_machines 101 | 102 | def __eq__(self, other: object) -> bool: 103 | return ( 104 | isinstance(other, ObjectSchemaStateMachine) 105 | and super().__eq__(other) 106 | and self.schema == other.schema 107 | ) 108 | 109 | def __str__(self) -> str: 110 | return "JSON" + super().__str__() 111 | -------------------------------------------------------------------------------- /pse/types/json/json_string.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import re 5 | 6 | import regex 7 | from pse_core import StateId 8 | 9 | from pse.types.string import StringStateMachine, StringStepper 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class StringSchemaStateMachine(StringStateMachine): 15 | """ 16 | Accept a JSON string that conforms to a JSON schema, including 'pattern' and 'format' constraints. 17 | """ 18 | 19 | # Class-level constants 20 | EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") 21 | SUPPORTED_FORMATS = frozenset(["email", "date-time", "uri"]) 22 | 23 | def __init__( 24 | self, 25 | schema: dict, 26 | ): 27 | super().__init__( 28 | min_length=schema.get("minLength"), 29 | max_length=schema.get("maxLength"), 30 | ) 31 | self.schema = schema or {} 32 | self.pattern: re.Pattern | None = None 33 | self.format: str | None = None 34 | 35 | if "pattern" in self.schema: 36 | try: 37 | self.pattern = re.compile(self.schema["pattern"]) 38 | except re.error as e: 39 | raise ValueError(f"Invalid pattern in schema: {e}") from e 40 | 41 | if "format" in self.schema: 42 | self.format = self.schema["format"] 43 | if self.format not in self.SUPPORTED_FORMATS: 44 | raise ValueError( 45 | f"Format '{self.format}' not supported. 
Supported formats: {', '.join(self.SUPPORTED_FORMATS)}" 46 | ) 47 | 48 | def get_new_stepper(self, state: StateId | None = None) -> StringSchemaStepper: 49 | return StringSchemaStepper(self, state) 50 | 51 | def min_length(self) -> int: 52 | """ 53 | Returns the minimum string length according to the schema. 54 | """ 55 | return self.schema.get("minLength", 0) 56 | 57 | def max_length(self) -> int: 58 | """ 59 | Returns the maximum string length according to the schema. 60 | """ 61 | return self.schema.get("maxLength", 10000) 62 | 63 | def validate_email(self, value: str) -> bool: 64 | """ 65 | Validate that the value is a valid email address. 66 | """ 67 | return bool(self.EMAIL_PATTERN.fullmatch(value)) 68 | 69 | def validate_date_time(self, value: str) -> bool: 70 | """ 71 | Validate that the value is a valid ISO 8601 date-time. 72 | """ 73 | from datetime import datetime 74 | 75 | try: 76 | datetime.fromisoformat(value) 77 | return True 78 | except ValueError: 79 | return False 80 | 81 | def validate_uri(self, value: str) -> bool: 82 | """ 83 | Validate that the value is a valid URI. 
84 | """ 85 | from urllib.parse import urlparse 86 | 87 | try: 88 | result = urlparse(value) 89 | return result.scheme is not None and result.netloc is not None 90 | except ValueError: 91 | return False 92 | 93 | def __str__(self) -> str: 94 | return "JSON" + super().__str__() 95 | 96 | 97 | class StringSchemaStepper(StringStepper): 98 | def __init__( 99 | self, 100 | state_machine: StringSchemaStateMachine, 101 | current_state: StateId | None = None, 102 | ): 103 | super().__init__(state_machine, current_state) 104 | self.state_machine: StringSchemaStateMachine = state_machine 105 | self._format_validators = { 106 | "email": self.state_machine.validate_email, 107 | "date-time": self.state_machine.validate_date_time, 108 | "uri": self.state_machine.validate_uri, 109 | } 110 | 111 | def should_start_step(self, token: str) -> bool: 112 | if super().should_start_step(token): 113 | if self.is_within_value(): 114 | valid_prefix = self.get_valid_prefix(token) 115 | return self.validate_value(valid_prefix) 116 | return True 117 | 118 | return False 119 | 120 | def consume(self, token: str): 121 | """ 122 | Consume the token and return the new stepper. 123 | """ 124 | if self.is_within_value(): 125 | valid_prefix = self.get_valid_prefix(token) 126 | if not valid_prefix: 127 | return [] 128 | else: 129 | valid_prefix = token 130 | 131 | steppers = super().consume(valid_prefix) 132 | for stepper in steppers: 133 | if token != valid_prefix: 134 | stepper.remaining_input = token[len(valid_prefix) :] 135 | 136 | return steppers 137 | 138 | def clean_value(self, value: str) -> str: 139 | """ 140 | Clean and normalize the input value by removing bounding quotes. 141 | 142 | Args: 143 | value: The string value to clean. 144 | 145 | Returns: 146 | str: The cleaned string with bounding quotes removed. 
147 | """ 148 | if value.startswith('"'): 149 | value = value[1:] 150 | if value.endswith('"'): 151 | first_quote = value.index('"') 152 | value = value[: first_quote] 153 | return value 154 | 155 | def get_valid_prefix(self, s: str) -> str | None: 156 | """ 157 | Check whether the string 's' can be a prefix of any string matching the pattern. 158 | Uses binary search for efficiency. 159 | """ 160 | if ( 161 | not self.is_within_value() 162 | or not self.state_machine.pattern 163 | or not self.sub_stepper 164 | ): 165 | return s 166 | 167 | current_value = self.sub_stepper.get_raw_value() 168 | quotes_removed_s = self.clean_value(s) 169 | 170 | left, right = 0, len(quotes_removed_s) 171 | best_match = None 172 | 173 | while left <= right: 174 | mid = (left + right) // 2 175 | working_s = quotes_removed_s[:mid] 176 | match = regex.match( 177 | self.state_machine.pattern.pattern, 178 | current_value + working_s, 179 | partial=True, 180 | ) 181 | if match: 182 | best_match = working_s 183 | left = mid + 1 # Try a longer prefix 184 | else: 185 | right = mid - 1 # Try a shorter prefix 186 | 187 | if best_match is not None: 188 | if best_match == quotes_removed_s: 189 | return s # Return original if the whole string is a valid prefix 190 | return best_match 191 | 192 | return None 193 | 194 | def validate_value(self, value: str | None = None) -> bool: 195 | """ 196 | Validate the string value according to the schema. 
197 | 198 | Args: 199 | value: Optional string to append to current value before validation 200 | 201 | Returns: 202 | bool: True if the value meets all schema constraints 203 | 204 | Note: 205 | Validates length, pattern, and format constraints in sequence 206 | """ 207 | value = self.clean_value(self.get_raw_value() + (value or "")) 208 | if not value: 209 | return False 210 | 211 | # Length validation 212 | if len(value) > self.state_machine.max_length(): 213 | return False 214 | 215 | # Pattern validation 216 | if not self.is_within_value() and self.state_machine.pattern: 217 | if not self.state_machine.pattern.match(value): 218 | return False 219 | 220 | # Format validation 221 | if self.state_machine.format: 222 | validator = self._format_validators.get(self.state_machine.format) 223 | if not validator: 224 | raise ValueError( 225 | f"No validator found for format: {self.state_machine.format}" 226 | ) 227 | if not validator(value): 228 | return False 229 | 230 | return True 231 | -------------------------------------------------------------------------------- /pse/types/json/json_value.py: -------------------------------------------------------------------------------- 1 | from pse_core import Edge, StateId 2 | from pse_core.state_machine import StateMachine 3 | from pse_core.stepper import Stepper 4 | 5 | 6 | class JsonStateMachine(StateMachine): 7 | def get_edges(self, state: StateId) -> list[Edge]: 8 | if state == 0: 9 | from pse.types.array import ArrayStateMachine 10 | from pse.types.base.phrase import PhraseStateMachine 11 | from pse.types.boolean import BooleanStateMachine 12 | from pse.types.number import NumberStateMachine 13 | from pse.types.object import ObjectStateMachine 14 | from pse.types.string import StringStateMachine 15 | 16 | return [ 17 | (ObjectStateMachine(), "$"), 18 | (ArrayStateMachine(), "$"), 19 | (StringStateMachine(), "$"), 20 | (PhraseStateMachine("null"), "$"), 21 | (BooleanStateMachine(), "$"), 22 | (NumberStateMachine(), 
"$"), 23 | ] 24 | return [] 25 | 26 | def get_steppers(self, state: StateId | None = None) -> list[Stepper]: 27 | steppers = [] 28 | for edge, _ in self.get_edges(state or 0): 29 | steppers.extend(edge.get_steppers()) 30 | return steppers 31 | -------------------------------------------------------------------------------- /pse/types/json/schema_sources/from_function.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import inspect 3 | import logging 4 | from collections.abc import Callable 5 | from typing import Any, get_args, get_origin 6 | 7 | from docstring_parser import Docstring, DocstringParam, parse 8 | from pydantic import BaseModel 9 | 10 | from pse.types.json.schema_sources.from_pydantic import pydantic_to_schema 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def callable_to_schema(function: Callable) -> dict[str, Any]: 16 | """ 17 | Generate a schema for the specified Python function. 18 | 19 | This takes a callable and parses its signature and docstring, 20 | and constructs a schema representing the function's parameters. 21 | 22 | Args: 23 | function (Callable): The Python function to generate a schema for. 24 | 25 | Returns: 26 | dict[str, Any]: A dictionary representing the JSON schema of the function's parameters. 
27 | """ 28 | sig = inspect.signature(function) 29 | docstring = parse(function.__doc__ or "No docstring provided") 30 | 31 | schema: dict[str, Any] = { 32 | "name": function.__name__, 33 | "description": docstring.description, 34 | "parameters": { 35 | "type": "object", 36 | "properties": {}, 37 | "required": [], 38 | }, 39 | } 40 | 41 | param_index = 0 42 | for param in sig.parameters.values(): 43 | if param.name == "self": 44 | continue # skip 'self' parameter 45 | 46 | param_docstring = ( 47 | docstring.params[param_index] 48 | if len(docstring.params) > param_index 49 | else None 50 | ) 51 | param_schema = parameter_to_schema(param, param_docstring, docstring) 52 | schema["parameters"]["properties"][param.name] = param_schema 53 | 54 | if param.default is inspect.Parameter.empty and param_schema.get("nullable", False) is False: 55 | schema["parameters"]["required"].append(param.name) 56 | 57 | param_index += 1 58 | 59 | # Handle the case when all parameters are nullable with defaults 60 | if not schema["parameters"]["required"]: 61 | del schema["parameters"]["required"] 62 | schema["parameters"]["nullable"] = True 63 | 64 | return schema 65 | 66 | 67 | def parameter_to_schema( 68 | param: inspect.Parameter, 69 | param_docstring: DocstringParam | None, 70 | docstring: Docstring 71 | ) -> dict[str, Any]: 72 | """ 73 | Generate a schema for a function parameter. 74 | 75 | Args: 76 | param (inspect.Parameter): The parameter to generate a schema for. 77 | docstring (Docstring): The docstring for the function. 
78 | """ 79 | 80 | parameter_schema: dict[str, Any] = {} 81 | if param_docstring: 82 | parameter_schema["description"] = param_docstring.description or docstring.short_description or "" 83 | 84 | if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel): 85 | # Use Pydantic model if the parameter is a BaseModel subclass 86 | return pydantic_to_schema(param.annotation) 87 | elif param.annotation == inspect.Parameter.empty: 88 | logger.warning(f"Parameter '{param.name}' lacks type annotation.") 89 | parameter_schema["type"] = "any" 90 | return parameter_schema 91 | elif param.default is not inspect.Parameter.empty: 92 | default_value = param.default 93 | if default_value is None: 94 | parameter_schema["nullable"] = True 95 | else: 96 | parameter_schema["default"] = default_value 97 | ####### 98 | parameter_type_schemas = [] 99 | parameter_arguments = get_args(param.annotation) 100 | 101 | # Special handling for direct dict type 102 | origin = get_origin(param.annotation) 103 | if origin is dict: 104 | dict_schema = { 105 | "type": "object", 106 | "additionalProperties": {"type": "any"} 107 | } 108 | args = get_args(param.annotation) 109 | if len(args) > 1: 110 | value_type = args[1] 111 | dict_schema["additionalProperties"] = { 112 | "type": get_type(value_type) 113 | } 114 | # Preserve the description if it exists 115 | if param_docstring: 116 | dict_schema["description"] = param_docstring.description or "" 117 | return dict_schema 118 | 119 | # Process union types and other types. 120 | parameter_type_schemas: list[dict[str, Any]] = [] 121 | for argument in parameter_arguments or [param.annotation]: 122 | parameter_type_schema: dict[str, Any] = {} 123 | arg_origin = get_origin(argument) 124 | parameter_type = get_type(argument) 125 | 126 | if arg_origin is dict: 127 | parameter_type_schema["type"] = "object" 128 | args = get_args(argument) 129 | # Consider using get, or a guard clause. 
130 | if len(args) > 1: 131 | parameter_type_schema["additionalProperties"] = { 132 | "type": get_type(args[1]) 133 | } 134 | elif parameter_type == "null": 135 | parameter_schema["nullable"] = True 136 | continue # Skip adding to type_schemas. 137 | elif parameter_type in ("array", "set"): 138 | parameter_type_schema["type"] = parameter_type 139 | if args := get_args(argument): 140 | parameter_type_schema["items"] = {"type": get_type(args[0])} 141 | elif parameter_type == "enum" and issubclass(argument, enum.Enum): 142 | parameter_type_schema["enum"] = [ 143 | member.value for member in argument 144 | ] # More concisely. 145 | elif parameter_type: 146 | parameter_type_schema["type"] = parameter_type 147 | 148 | if parameter_type_schema: 149 | parameter_type_schemas.append(parameter_type_schema) 150 | 151 | # Simplify the logic for setting the final schema type, handling edge cases. 152 | match len(parameter_type_schemas): 153 | case 0: 154 | # If no types were added and it wasn't nullable, default to "any". 155 | if "nullable" not in parameter_schema: 156 | parameter_schema["type"] = "any" 157 | case 1: 158 | # Merge the single schema into the main schema. 
159 | parameter_schema.update(parameter_type_schemas[0]) 160 | 161 | if len(parameter_type_schemas) > 1: 162 | parameter_schema["type"] = parameter_type_schemas 163 | elif parameter_type_schemas: 164 | parameter_schema.update(**parameter_type_schemas[0]) 165 | else: 166 | parameter_schema["type"] = "any" 167 | 168 | return parameter_schema 169 | 170 | def get_type(python_type: Any) -> str: 171 | """Map a Python type to a JSON schema type.""" 172 | if python_type is type(None): 173 | return "null" 174 | 175 | type_name = get_origin(python_type) or python_type 176 | type_map: dict[type | Any, str] = { 177 | int: "integer", 178 | str: "string", 179 | bool: "boolean", 180 | float: "number", 181 | list: "array", 182 | dict: "object", 183 | tuple: "array", 184 | set: "set", 185 | enum.EnumType: "enum", 186 | type(None): "null", 187 | BaseModel: "object", 188 | Any: "any", 189 | } 190 | if type_name not in type_map: 191 | if type(python_type) in type_map: 192 | return type_map[type(python_type)] 193 | 194 | logger.warning(f"Unknown type: {python_type}") 195 | return "any" 196 | 197 | return type_map[type_name] 198 | -------------------------------------------------------------------------------- /pse/types/json/schema_sources/from_pydantic.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | from docstring_parser import parse 5 | from pydantic import BaseModel 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def pydantic_to_schema(model: type[BaseModel]) -> dict[str, Any]: 11 | """ 12 | Convert a Pydantic model class to a standardized schema format. 13 | 14 | Returns: 15 | dict[str, Any]: A dictionary representing the schema. 
16 | """ 17 | # Get base schema from Pydantic model 18 | schema = model.model_json_schema() 19 | 20 | # Extract docstring info 21 | docstring = parse(model.__doc__ or "") 22 | docstring_params = { 23 | param.arg_name: param.description 24 | for param in docstring.params 25 | if param.description 26 | } 27 | 28 | # Get description from schema or docstring 29 | description = schema.get("description") or docstring.short_description or "" 30 | 31 | # Extract parameters, excluding metadata fields 32 | parameters = {k: v for k, v in schema.items() if k not in {"title", "description"}} 33 | 34 | # Process properties and required fields 35 | properties = parameters.get("properties", {}) 36 | required_fields = set(parameters.get("required", [])) 37 | 38 | assert isinstance(properties, dict) 39 | 40 | # Update field schemas with descriptions and requirements 41 | for field_name, field in model.model_fields.items(): 42 | field_schema: dict[str, Any] = properties.get(field_name, {}) 43 | 44 | # Add field to required list if needed 45 | if field.is_required(): 46 | required_fields.add(field_name) 47 | 48 | # Set description from field or docstring 49 | field_schema["description"] = field.description or docstring_params.get( 50 | field_name, "" 51 | ) 52 | 53 | # Add any extra schema properties 54 | if extra := field.json_schema_extra: 55 | if isinstance(extra, dict): 56 | # For dictionaries, update the schema with the extra fields 57 | field_schema.update(extra) 58 | 59 | parameters["required"] = list(required_fields) 60 | 61 | return { 62 | "name": schema.get("title", model.__name__), 63 | "description": description, 64 | **parameters, 65 | } 66 | -------------------------------------------------------------------------------- /pse/types/key_value.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import logging 5 | from typing import Any 6 | 7 | from pse_core import StateId 8 | 
from pse_core.state_machine import StateMachine 9 | from pse_core.stepper import Stepper 10 | 11 | from pse.types.base.chain import ChainStateMachine 12 | from pse.types.base.phrase import PhraseStateMachine 13 | from pse.types.string import StringStateMachine 14 | from pse.types.whitespace import WhitespaceStateMachine 15 | 16 | logger = logging.getLogger() 17 | 18 | 19 | class KeyValueStateMachine(ChainStateMachine): 20 | def __init__(self, sequence: list[StateMachine] | None = None, is_optional: bool = False) -> None: 21 | from pse.types.json.json_value import JsonStateMachine 22 | 23 | super().__init__( 24 | sequence 25 | or [ 26 | StringStateMachine(), 27 | WhitespaceStateMachine(), 28 | PhraseStateMachine(":"), 29 | WhitespaceStateMachine(), 30 | JsonStateMachine(), 31 | ], 32 | is_optional=is_optional, 33 | ) 34 | 35 | def get_new_stepper(self, state: StateId | None = None) -> KeyValueStepper: 36 | return KeyValueStepper(self, state) 37 | 38 | def __str__(self) -> str: 39 | return "KeyValue" 40 | 41 | 42 | class KeyValueStepper(Stepper): 43 | def __init__( 44 | self, 45 | state_machine: KeyValueStateMachine, 46 | current_step_id: StateId | None = None, 47 | ) -> None: 48 | super().__init__(state_machine, current_step_id) 49 | self.prop_name = "" 50 | self.prop_value: Any | None = None 51 | 52 | def clone(self) -> KeyValueStepper: 53 | cloned_stepper = super().clone() 54 | cloned_stepper.prop_name = self.prop_name 55 | cloned_stepper.prop_value = self.prop_value 56 | return cloned_stepper 57 | 58 | def should_complete_step(self) -> bool: 59 | """ 60 | Handle the completion of a transition by setting the property name and value. 61 | 62 | Returns: 63 | bool: True if the transition was successful, False otherwise. 
64 | """ 65 | if not super().should_complete_step() or not self.sub_stepper: 66 | return False 67 | 68 | try: 69 | if self.target_state == 1: 70 | self.prop_name = json.loads(self.sub_stepper.get_raw_value()) 71 | elif self.target_state in self.state_machine.end_states: 72 | self.prop_value = json.loads(self.sub_stepper.get_raw_value()) 73 | except Exception: 74 | return False 75 | 76 | return True 77 | 78 | def get_current_value(self) -> tuple[str, Any]: 79 | """ 80 | Get the parsed property as a key-value pair. 81 | 82 | Returns: 83 | Tuple[str, Any]: A tuple containing the property name and its corresponding value. 84 | """ 85 | if not self.prop_name: 86 | return ("", None) 87 | return (self.prop_name, self.prop_value) 88 | -------------------------------------------------------------------------------- /pse/types/misc/fenced_freeform.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pse_core import StateId 4 | 5 | from pse.types.base.character import CharacterStateMachine 6 | from pse.types.base.encapsulated import EncapsulatedStateMachine, EncapsulatedStepper 7 | 8 | 9 | class FencedFreeformStateMachine(EncapsulatedStateMachine): 10 | """ 11 | A state machine that can be used to parse freeform text that is enclosed in a pair of delimiters. 
12 | """ 13 | def __init__(self, 14 | identifier: str | None = None, 15 | delimiter: tuple[str, str] | None = None, 16 | buffer_length: int = -1, 17 | char_min: int = 1, 18 | char_max: int = -1, 19 | is_optional: bool = False): 20 | 21 | if delimiter is None: 22 | delimiter = (f'```{identifier or ""}\n', '\n```') 23 | 24 | freeform_state_machine = CharacterStateMachine( 25 | whitelist_charset="", 26 | graylist_charset=set(delimiter[1]), 27 | blacklist_charset=delimiter[0][0], 28 | char_min=char_min, 29 | char_limit=char_max, 30 | ) 31 | super().__init__(freeform_state_machine, delimiter, buffer_length, is_optional) 32 | self.identifier = identifier 33 | 34 | def get_new_stepper(self, state: StateId | None = None) -> FencedFreeformStepper: 35 | return FencedFreeformStepper(self, state) 36 | 37 | class FencedFreeformStepper(EncapsulatedStepper): 38 | 39 | def __init__( 40 | self, 41 | state_machine: FencedFreeformStateMachine, 42 | state: StateId | None = None, 43 | ) -> None: 44 | super().__init__(state_machine, state) 45 | self.state_machine: FencedFreeformStateMachine = state_machine 46 | 47 | def get_identifier(self) -> str | None: 48 | return self.state_machine.identifier 49 | 50 | def get_invalid_continuations(self) -> list[str]: 51 | if not self.inner_stepper: 52 | return [self.state_machine.delimiters[1]] 53 | return super().get_invalid_continuations() 54 | -------------------------------------------------------------------------------- /pse/types/misc/freeform.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from typing import Any 5 | 6 | from pse_core import StateId 7 | 8 | from pse.types.base.wait_for import WaitFor, WaitForStepper 9 | from pse.types.enum import EnumStateMachine 10 | 11 | 12 | class FreeformStateMachine(WaitFor): 13 | """ 14 | A state machine that can be used to parse freeform text that has an ending delimiter. 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | end_delimiters: list[str], 20 | char_min: int | None = None, 21 | ): 22 | self.end_delimiters = end_delimiters 23 | delimiter_state_machine = EnumStateMachine(self.end_delimiters, require_quotes=False) 24 | super().__init__( 25 | delimiter_state_machine, 26 | buffer_length=char_min or 1, 27 | ) 28 | 29 | def get_new_stepper(self, _: StateId | None = None) -> FreeformStepper: 30 | return FreeformStepper(self) 31 | 32 | def __str__(self) -> str: 33 | return "FreeformText" 34 | 35 | class FreeformStepper(WaitForStepper): 36 | 37 | def __init__( 38 | self, 39 | state_machine: FreeformStateMachine, 40 | ): 41 | super().__init__(state_machine) 42 | self.state_machine: FreeformStateMachine = state_machine 43 | 44 | def get_raw_value(self) -> str: 45 | """ 46 | Get the raw value of the buffer. 47 | """ 48 | if self.sub_stepper: 49 | return self.buffer + self.sub_stepper.get_raw_value() 50 | elif self.history: 51 | accepted_raw_value = self.history[-1].get_raw_value() 52 | return self.buffer + accepted_raw_value 53 | 54 | return self.buffer 55 | 56 | def get_current_value(self) -> Any: 57 | return self.buffer 58 | 59 | def get_token_safe_output( 60 | self, 61 | decode_function: Callable[[list[int]], str], 62 | ) -> str: 63 | """ 64 | Get the token safe output of the buffer. 
65 | """ 66 | safe_output = super().get_token_safe_output(decode_function) 67 | for end_delimiter in self.state_machine.end_delimiters: 68 | if safe_output.endswith(end_delimiter): 69 | return safe_output[:-len(end_delimiter)] 70 | 71 | return safe_output 72 | -------------------------------------------------------------------------------- /pse/types/number.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from pse_core import Edge, StateId 6 | from pse_core.state_machine import StateMachine 7 | 8 | from pse.types.base.chain import ChainStateMachine 9 | from pse.types.base.character import CharacterStateMachine 10 | from pse.types.base.phrase import PhraseStateMachine 11 | from pse.types.integer import IntegerStateMachine 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class NumberStateMachine(StateMachine): 17 | """ 18 | Accepts a well-formed JSON number. 19 | 20 | This state_machine defines the state transitions for parsing JSON numbers, handling integer, 21 | decimal, and exponential formats as specified by the JSON standard. 22 | """ 23 | 24 | def __init__(self): 25 | super().__init__( 26 | { 27 | 0: [ 28 | (PhraseStateMachine("-", is_optional=True), 1), 29 | ], 30 | 1: [ 31 | (IntegerStateMachine(), 2), 32 | ], 33 | 2: [ 34 | ( 35 | ChainStateMachine( 36 | [ 37 | PhraseStateMachine("."), 38 | IntegerStateMachine(drop_leading_zeros=False), 39 | ], 40 | ), 41 | 3, 42 | ), 43 | ], 44 | 3: [ 45 | (CharacterStateMachine("eE", char_limit=1), 4), 46 | ], 47 | 4: [ 48 | (CharacterStateMachine("+-", char_limit=1), 5), 49 | ], 50 | 5: [ 51 | (IntegerStateMachine(), "$"), 52 | ], 53 | }, 54 | end_states=[2, 3, "$"], 55 | ) 56 | 57 | def get_edges(self, state: StateId) -> list[Edge]: 58 | """ 59 | Get the edges for a given state. 
60 | """ 61 | if state == 2: 62 | return [*super().get_edges(state), *super().get_edges(3)] 63 | elif state == 4: 64 | return [*super().get_edges(5), *super().get_edges(state)] 65 | else: 66 | return [*super().get_edges(state)] 67 | 68 | def __str__(self) -> str: 69 | return "Number" 70 | -------------------------------------------------------------------------------- /pse/types/object.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Any 5 | 6 | from pse_core import StateId 7 | from pse_core.state_machine import StateMachine 8 | from pse_core.stepper import Stepper 9 | 10 | from pse.types.base.chain import ChainStateMachine 11 | from pse.types.base.phrase import PhraseStateMachine 12 | from pse.types.key_value import KeyValueStateMachine 13 | from pse.types.whitespace import WhitespaceStateMachine 14 | 15 | logger = logging.getLogger() 16 | 17 | 18 | class ObjectStateMachine(StateMachine): 19 | """ 20 | Accepts a well-formed JSON object and manages state transitions during parsing. 21 | 22 | This state_machine handles the parsing of JSON objects by defining state transitions 23 | and maintaining the current object properties being parsed. 24 | """ 25 | 26 | def __init__(self, is_optional: bool = False) -> None: 27 | """ 28 | 29 | Sets up the state transition graph for parsing JSON objects. 
30 | """ 31 | super().__init__( 32 | { 33 | 0: [ 34 | (PhraseStateMachine("{"), 1), 35 | ], 36 | 1: [ 37 | (WhitespaceStateMachine(), 2), 38 | ], 39 | 2: [ 40 | (KeyValueStateMachine(), 3), 41 | ], 42 | 3: [ 43 | (WhitespaceStateMachine(), 4), 44 | ], 45 | 4: [ 46 | ( 47 | ChainStateMachine( 48 | [PhraseStateMachine(","), WhitespaceStateMachine()] 49 | ), 50 | 2, 51 | ), 52 | (PhraseStateMachine("}"), "$"), # End of object 53 | ], 54 | }, 55 | is_optional=is_optional, 56 | ) 57 | 58 | def get_new_stepper(self, state: StateId | None = None) -> ObjectStepper: 59 | return ObjectStepper(self, state) 60 | 61 | def get_transitions(self, stepper: Stepper) -> list[tuple[Stepper, StateId]]: 62 | transitions = super().get_transitions(stepper) 63 | if stepper.current_state == 1 and self.is_optional: 64 | for transition in PhraseStateMachine("}").get_steppers(): 65 | transitions.append((transition, "$")) 66 | return transitions 67 | 68 | def __str__(self) -> str: 69 | return "Object" 70 | 71 | 72 | class ObjectStepper(Stepper): 73 | def __init__( 74 | self, state_machine: ObjectStateMachine, current_state: StateId | None = None 75 | ) -> None: 76 | super().__init__(state_machine, current_state) 77 | self.value: dict[str, Any] = {} 78 | 79 | def clone(self) -> ObjectStepper: 80 | cloned_stepper = super().clone() 81 | cloned_stepper.value = self.value.copy() 82 | return cloned_stepper 83 | 84 | def add_to_history(self, stepper: Stepper) -> None: 85 | if self.current_state == 3: 86 | prop_name, prop_value = stepper.get_current_value() 87 | logger.debug(f"🟢 Adding {prop_name}: {prop_value} to {self.value}") 88 | self.value[prop_name] = prop_value 89 | super().add_to_history(stepper) 90 | 91 | def get_current_value(self) -> dict[str, Any]: 92 | """ 93 | Get the current parsed JSON object. 94 | 95 | Returns: 96 | dict[str, Any]: The accumulated key-value pairs representing the JSON object. 
97 | """ 98 | if not self.get_raw_value(): 99 | return {} 100 | return self.value 101 | -------------------------------------------------------------------------------- /pse/types/string.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pse_core import StateId 4 | from pse_core.state_machine import StateMachine 5 | from pse_core.stepper import Stepper 6 | 7 | from pse.types.base.character import CharacterStateMachine 8 | from pse.types.base.phrase import PhraseStateMachine 9 | 10 | INVALID_CHARS: set[str] = {chr(c) for c in range(0, 0x20)} | {'"', "\\"} 11 | 12 | 13 | class StringStateMachine(StateMachine): 14 | """ 15 | Accepts a well-formed JSON string. 16 | 17 | The length of the string is measured excluding the surrounding quotation marks. 18 | """ 19 | 20 | # StateId constants 21 | STRING_CONTENTS = 1 22 | ESCAPED_SEQUENCE = 2 23 | HEX_CODE = 3 24 | 25 | def __init__(self, min_length: int | None = None, max_length: int | None = None): 26 | """ 27 | The state machine is configured to parse JSON strings, handling escape sequences 28 | and Unicode characters appropriately. 
29 | """ 30 | super().__init__( 31 | { 32 | 0: [ 33 | (PhraseStateMachine('"'), self.STRING_CONTENTS), 34 | ], 35 | self.STRING_CONTENTS: [ 36 | ( 37 | CharacterStateMachine( 38 | blacklist_charset=INVALID_CHARS, 39 | char_min=min_length, 40 | char_limit=max_length, 41 | ), 42 | self.STRING_CONTENTS, 43 | ), # Regular characters 44 | (PhraseStateMachine('"'), "$"), # End quote 45 | ( 46 | PhraseStateMachine("\\"), 47 | self.ESCAPED_SEQUENCE, 48 | ), # Escape character 49 | ], 50 | self.ESCAPED_SEQUENCE: [ 51 | ( 52 | CharacterStateMachine('"\\/bfnrt', char_limit=1), 53 | self.STRING_CONTENTS, 54 | ), # Escaped characters 55 | (PhraseStateMachine("u"), self.HEX_CODE), # Unicode escape sequence 56 | ], 57 | self.HEX_CODE: [ 58 | ( 59 | CharacterStateMachine( 60 | "0123456789ABCDEFabcdef", 61 | char_min=4, 62 | char_limit=4, 63 | ), 64 | self.STRING_CONTENTS, 65 | ), 66 | ], 67 | } 68 | ) 69 | 70 | def get_new_stepper(self, state: int | str | None = None) -> Stepper: 71 | return StringStepper(self, state) 72 | 73 | def __str__(self) -> str: 74 | return "String" 75 | 76 | 77 | class StringStepper(Stepper): 78 | def __init__( 79 | self, state_machine: StringStateMachine, current_state: StateId | None = None 80 | ) -> None: 81 | super().__init__(state_machine, current_state) 82 | self.state_machine: StringStateMachine = state_machine 83 | 84 | def is_within_value(self) -> bool: 85 | """ 86 | Determines if the stepper is currently within the string value (after opening quote, before closing quote). 87 | """ 88 | return self.current_state != 0 and self.target_state not in self.state_machine.end_states 89 | -------------------------------------------------------------------------------- /pse/types/whitespace.py: -------------------------------------------------------------------------------- 1 | """Whitespace state machine for parsing optional whitespace in structured data. 
2 | 3 | This module provides a state machine for recognizing and parsing whitespace 4 | characters in structured data formats like JSON. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from pse.types.base.character import CharacterStateMachine 10 | 11 | # Whitespace characters as defined by the JSON standard 12 | WHITESPACE_CHARS = " \t\n\r" 13 | 14 | 15 | class WhitespaceStateMachine(CharacterStateMachine): 16 | """Optional whitespace state machine using TokenTrie for efficient matching.""" 17 | 18 | def __init__(self, min_whitespace: int = 0, max_whitespace: int = 20): 19 | """Initialize the whitespace state machine with configurable limits. 20 | 21 | Args: 22 | min_whitespace: Minimum allowable whitespace characters. 23 | Defaults to 0. 24 | max_whitespace: Maximum allowable whitespace characters. 25 | Defaults to 20. 26 | """ 27 | super().__init__( 28 | WHITESPACE_CHARS, 29 | char_min=min_whitespace, 30 | char_limit=max_whitespace, 31 | is_optional=(min_whitespace == 0), 32 | ) 33 | 34 | def __str__(self) -> str: 35 | """Return a string representation of this state machine.""" 36 | return "Whitespace" 37 | -------------------------------------------------------------------------------- /pse/types/xml/xml_encapsulated.py: -------------------------------------------------------------------------------- 1 | from pse_core import StateId 2 | from pse_core.state_machine import StateMachine 3 | 4 | from pse.types.base.encapsulated import EncapsulatedStepper 5 | from pse.types.base.wait_for import WaitFor 6 | from pse.types.xml.xml_tag import XMLTagStateMachine 7 | 8 | 9 | class XMLEncapsulatedStateMachine(StateMachine): 10 | """ 11 | A state machine that wraps a state machine in XML tags. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | state_machine: StateMachine, 17 | tag_name: str, 18 | min_buffer_length: int = -1, 19 | is_optional: bool = False, 20 | ) -> None: 21 | """ 22 | 23 | Args: 24 | state_machine: The state_machine wrapped by this state machine. 
25 | tag_name: The name of the tag to wrap the state machine in. 26 | """ 27 | self.inner_state_machine = state_machine 28 | self.xml_delimiters = (f"<{tag_name}>", f"") 29 | super().__init__( 30 | { 31 | 0: [ 32 | ( 33 | WaitFor( 34 | XMLTagStateMachine(tag_name), 35 | buffer_length=min_buffer_length, 36 | ), 37 | 1, 38 | ), 39 | ], 40 | 1: [(state_machine, 2)], 41 | 2: [(XMLTagStateMachine(tag_name, closing_tag=True), "$")], 42 | }, 43 | is_optional=is_optional, 44 | ) 45 | 46 | def get_new_stepper(self, state: StateId | None = None) -> EncapsulatedStepper: 47 | return EncapsulatedStepper(self, state) 48 | -------------------------------------------------------------------------------- /pse/types/xml/xml_tag.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pse.types.base.chain import ChainStateMachine, ChainStepper 4 | from pse.types.base.phrase import PhraseStateMachine 5 | 6 | 7 | class XMLTagStateMachine(ChainStateMachine): 8 | """ 9 | A state machine that recognizes XML tags. 
10 | """ 11 | 12 | def __init__(self, tag_name: str, closing_tag: bool = False) -> None: 13 | self.tag_name = tag_name 14 | self.xml_tag = ("<" if not closing_tag else "" 15 | super().__init__( 16 | [ 17 | PhraseStateMachine("<" if not closing_tag else ""), 20 | ] 21 | ) 22 | 23 | def get_new_stepper(self, state: int | str | None = None) -> XMLTagStepper: 24 | return XMLTagStepper(self, state) 25 | 26 | def __str__(self) -> str: 27 | return self.tag_name 28 | 29 | class XMLTagStepper(ChainStepper): 30 | 31 | def __init__(self, state_machine: XMLTagStateMachine, *args, **kwargs) -> None: 32 | super().__init__(state_machine, *args, **kwargs) 33 | self.state_machine: XMLTagStateMachine = state_machine 34 | 35 | def get_valid_continuations(self) -> list[str]: 36 | return [self.state_machine.xml_tag] 37 | -------------------------------------------------------------------------------- /pse/util/generate_mlx.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections.abc import Callable 3 | from typing import Any 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | from mlx_proxy.generate_step import generate_step 8 | from mlx_proxy.samplers import make_sampler 9 | 10 | from pse.structuring_engine import StructuringEngine 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | def generate( 15 | prompt: str, 16 | model: nn.Module, 17 | engine: StructuringEngine, 18 | prefill: str | None = None, 19 | ) -> str: 20 | mx.metal.clear_cache() 21 | messages = [{"role": "user", "content": prompt}] 22 | formatted_prompt = engine.tokenizer.apply_chat_template( 23 | conversation=messages, 24 | add_generation_prompt=True, 25 | tokenize=False 26 | ) 27 | assert isinstance(formatted_prompt, str) 28 | formatted_prompt = formatted_prompt + (prefill or "") 29 | logger.info(formatted_prompt) 30 | 31 | encoded_prompt = engine.tokenizer.encode(formatted_prompt, add_special_tokens=False) 32 | output_tokens: list[int] = [] 
33 | for tokens, _ in generate_step( 34 | prompt=encoded_prompt, 35 | model=model, 36 | logits_processors=[engine.process_logits], 37 | sampler=sampler(engine), 38 | max_tokens=-1, 39 | ): 40 | assert isinstance(tokens, mx.array) 41 | token_list = tokens.tolist() if tokens.shape[0] > 1 else [tokens.item()] 42 | encoded_prompt.extend(token_list) # type: ignore[arg-type] 43 | output_tokens.extend(token_list) # type: ignore[arg-type] 44 | if engine.has_reached_accept_state: 45 | break 46 | 47 | output = engine.tokenizer.decode(output_tokens) 48 | return prefill + output if prefill else output 49 | 50 | def sampler(engine: StructuringEngine, **kwargs: Any) -> Callable[..., Any]: 51 | """ 52 | Return a sampler function. 53 | If structured is True, use the structured sampler. 54 | Otherwise, use the simple sampler. 55 | """ 56 | sampler = make_sampler(kwargs.get("temp", 0.7)) 57 | return lambda x: engine.sample(x, sampler) 58 | -------------------------------------------------------------------------------- /pse/util/jax_mixin.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import lax 6 | from transformers import FlaxGenerationMixin, FlaxLogitsProcessorList 7 | from transformers.generation.flax_utils import FlaxSampleOutput, SampleState 8 | 9 | from pse.structuring_engine import StructuringEngine 10 | 11 | 12 | class PSEFlaxMixin(FlaxGenerationMixin): 13 | engine: StructuringEngine 14 | 15 | @staticmethod 16 | def make_sampler(prng_key: jnp.ndarray) -> Callable: 17 | return lambda x: jax.random.categorical(prng_key, x, axis=-1) 18 | 19 | def _sample( 20 | self, 21 | input_ids: None, 22 | max_length: int | None = None, 23 | pad_token_id: int | None = None, 24 | eos_token_id: int | None = None, 25 | prng_key: jnp.ndarray | None = None, 26 | logits_processor: FlaxLogitsProcessorList | None = None, 27 | logits_warper: 
FlaxLogitsProcessorList | None = None, 28 | trace: bool = True, 29 | params: dict[str, jnp.ndarray] | None = None, 30 | model_kwargs: dict[str, jnp.ndarray] | None = None, 31 | ): 32 | # init values 33 | max_length = max_length if max_length is not None else self.generation_config.max_length 34 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 35 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 36 | prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) 37 | 38 | batch_size, cur_len = input_ids.shape # type: ignore [arg-type] 39 | 40 | eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32) if eos_token_id is not None else None # type: ignore [arg-type] 41 | pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) # type: ignore [arg-type] 42 | cur_len = jnp.array(cur_len) # type: ignore [arg-type] 43 | 44 | # per batch-item holding current token in loop. 45 | sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32) # type: ignore [arg-type] 46 | sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0)) # type: ignore [arg-type] 47 | 48 | # per batch-item state bit indicating if sentence has finished. 49 | is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_) 50 | if not logits_processor or self.engine.process_logits not in logits_processor: 51 | # insert the engine at the beginning of the list 52 | if logits_processor is None: 53 | logits_processor = FlaxLogitsProcessorList() 54 | logits_processor.insert(0, self.engine.process_logits) 55 | 56 | # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop 57 | # and pass it the `encoder_outputs`, which are part of the `model_kwargs`. 
58 | model = self.decode if self.config.is_encoder_decoder else self # type: ignore [attr-defined] 59 | 60 | assert isinstance(model, Callable) 61 | # initialize model specific kwargs 62 | model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs) # type: ignore [arg-type] 63 | 64 | # initialize state 65 | state = SampleState( 66 | cur_len=cur_len, # type: ignore [arg-type] 67 | sequences=sequences, # type: ignore [arg-type] 68 | running_token=input_ids, # type: ignore [arg-type] 69 | is_sent_finished=is_sent_finished, # type: ignore [arg-type] 70 | prng_key=prng_key, # type: ignore [arg-type] 71 | model_kwargs=model_kwargs, # type: ignore [arg-type] 72 | ) 73 | 74 | def sample_search_cond_fn(state): 75 | """state termination condition fn.""" 76 | has_reached_max_length = state.cur_len == max_length 77 | all_sequence_finished = jnp.all(state.is_sent_finished) 78 | finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished) 79 | return ~finish_generation 80 | 81 | def sample_search_body_fn(state): 82 | """state update fn.""" 83 | prng_key, prng_key_next = jax.random.split(state.prng_key) 84 | model_outputs = model(state.running_token, params=params, **state.model_kwargs) 85 | 86 | logits = model_outputs.logits[:, -1] 87 | # apply min_length, ... 
88 | logits = logits_processor(state.sequences, logits, state.cur_len) 89 | # apply top_p, top_k, temperature 90 | if logits_warper: 91 | logits = logits_warper(logits, logits, state.cur_len) 92 | 93 | sampler = PSEFlaxMixin.make_sampler(prng_key) 94 | next_token = self.engine.sample(logits, sampler) 95 | 96 | next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished 97 | next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) 98 | next_token = next_token[:, None] 99 | 100 | next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) 101 | next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) # type: ignore [attr-defined] 102 | 103 | return SampleState( 104 | cur_len=state.cur_len + len(next_token), # type: ignore [arg-type] 105 | sequences=next_sequences, # type: ignore [arg-type] 106 | running_token=next_token, # type: ignore [arg-type] 107 | is_sent_finished=next_is_sent_finished, # type: ignore [arg-type] 108 | model_kwargs=next_model_kwargs, # type: ignore [arg-type] 109 | prng_key=prng_key_next, # type: ignore [arg-type] 110 | ) 111 | 112 | # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU 113 | if input_ids.shape[1] > 1: # type: ignore [arg-type] 114 | state = sample_search_body_fn(state) 115 | 116 | if not trace: 117 | state = self._run_loop_in_debug( 118 | lambda state: sample_search_cond_fn(state) and not self.engine.has_reached_accept_state, 119 | sample_search_body_fn, 120 | state, 121 | ) 122 | else: 123 | state = lax.while_loop( 124 | lambda state: sample_search_cond_fn(state) and not self.engine.has_reached_accept_state, 125 | sample_search_body_fn, 126 | state, 127 | ) 128 | 129 | return FlaxSampleOutput(sequences=state.sequences) 130 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "pse" 7 | version = "2025.06.1" 8 | authors = [ 9 | { name = "Jack Wind", email = "jckwind11@gmail.com" }, 10 | { name = "The Proxy Company", email = "contact@theproxy.company" } 11 | ] 12 | description = "Proxy Structuring Engine: Stateful AI-generated output with a focus on creativity, speed, and structure." 13 | readme = "README.md" 14 | requires-python = ">=3.11" 15 | license = { file = "LICENSE" } 16 | 17 | dependencies = [ 18 | "docstring-parser", 19 | "protobuf", 20 | "pse-core", 21 | "pydantic", 22 | "regex", 23 | "tokenizers", 24 | "transformers", 25 | "tqdm", 26 | "typing_extensions", 27 | "wheel", 28 | "lark", 29 | ] 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | "coverage", 34 | "flake8", 35 | "pytest", 36 | "ruff", 37 | "sentencepiece", 38 | "ipykernel", 39 | "pytest-cov", 40 | ] 41 | mlx = [ 42 | "mlx", 43 | "mlx-proxy", 44 | ] 45 | torch = [ 46 | "torch", 47 | "accelerate", 48 | ] 49 | tensorflow = [ 50 | "tensorflow", 51 | "accelerate", 52 | ] 53 | jax = [ 54 | "jax", 55 | "accelerate", 56 | ] 57 | 58 | [project.urls] 59 | homepage = "https://github.com/TheProxyCompany/proxy-structuring-engine" 60 | documentation = "https://github.com/TheProxyCompany/proxy-structuring-engine#readme" 61 | source = "https://github.com/TheProxyCompany/proxy-structuring-engine" 62 | 63 | [tool.ruff.lint] 64 | extend-select = [ 65 | "B", # flake8-bugbear 66 | "I", # isort 67 | "PGH", # pygrep-hooks 68 | "RUF", # Ruff-specific 69 | "UP", # pyupgrade 70 | "SLF", # string-literal-format 71 | "F8", # flake8-comprehensions 72 | ] 73 | 74 | [tool.hatch.build.targets.sdist] 75 | include = [ 76 | "LICENSE", 77 | "README.md", 78 | "pyproject.toml", 79 | "pse" 80 | ] 81 | 82 | [tool.hatch.build.targets.wheel] 83 | packages = ["pse"] 84 | include = ["pse/**"] 85 | optimize = 
true 86 | ignore-vcs = true 87 | python-tag = "py311" 88 | repair-wheel = true 89 | 90 | [tool.hatch.envs.default] 91 | python = "3.11" 92 | env-vars = { PYTHONOPTIMIZE = "2" } 93 | 94 | [tool.pytest.ini_options] 95 | log_cli = false 96 | log_cli_level = "WARNING" 97 | log_cli_format = "in %(filename)s:%(lineno)d [%(levelname)s] %(message)s" 98 | log_cli_date_format = "%H:%M:%S" 99 | addopts = "--cov=pse" 100 | -------------------------------------------------------------------------------- /tests/unit/types/base/test_wait_for.py: -------------------------------------------------------------------------------- 1 | from pse_core.trie import TrieMap 2 | 3 | from pse.types.base.phrase import PhraseStateMachine, PhraseStepper 4 | from pse.types.base.wait_for import ( 5 | WaitFor, 6 | WaitForStepper, 7 | ) 8 | 9 | 10 | def test_default_wait_for_acceptor() -> None: 11 | text_acceptor = PhraseStateMachine("Hello World") 12 | state_machine = WaitFor(text_acceptor, buffer_length=0) 13 | 14 | steppers = list(state_machine.get_steppers()) 15 | assert len(steppers) == 1 16 | stepper = steppers[0] 17 | assert isinstance(stepper, WaitForStepper) 18 | assert stepper.accepts_any_token() 19 | assert stepper.sub_stepper 20 | assert stepper.sub_stepper.state_machine == text_acceptor 21 | assert isinstance(stepper.sub_stepper, PhraseStepper) 22 | assert not stepper.is_within_value() 23 | steppers = state_machine.advance_all_basic(steppers, "Hello ") 24 | assert len(steppers) == 1 25 | assert steppers[0].is_within_value() 26 | steppers = state_machine.advance_all_basic(steppers, "World") 27 | assert len(steppers) == 1 28 | assert steppers[0].has_reached_accept_state() 29 | 30 | 31 | def test_basic_wait_for_acceptor() -> None: 32 | """Test that the WaitForAcceptor can accept any token.""" 33 | text_acceptor = PhraseStateMachine("Hello World") 34 | state_machine = WaitFor(text_acceptor) 35 | steppers = list(state_machine.get_steppers()) 36 | steppers = 
state_machine.advance_all_basic(steppers, "Hello World") 37 | assert len(steppers) == 1 38 | assert steppers[0].has_reached_accept_state() 39 | assert not steppers[0].get_invalid_continuations() 40 | 41 | 42 | def test_interrupted_wait_for_acceptor() -> None: 43 | text_acceptor = PhraseStateMachine("Hello World") 44 | state_machine = WaitFor(text_acceptor, strict=True) 45 | 46 | steppers = state_machine.get_steppers() 47 | steppers = state_machine.advance_all_basic(steppers, "Hello ") 48 | assert len(steppers) == 1 49 | assert steppers[0].is_within_value() 50 | steppers = state_machine.advance_all_basic( 51 | steppers, "I'm gonna mess up the pattern!" 52 | ) 53 | assert not steppers 54 | 55 | 56 | def test_wait_for_acceptor_with_break() -> None: 57 | """Test that the WaitForAcceptor can accept any token.""" 58 | text_acceptor = PhraseStateMachine("Hello World") 59 | state_machine = WaitFor(text_acceptor, strict=False) 60 | steppers = list(state_machine.get_steppers()) 61 | steppers = state_machine.advance_all_basic(steppers, "Hello ") 62 | assert len(steppers) == 1 63 | 64 | steppers = state_machine.advance_all_basic( 65 | steppers, "I'm gonna mess up the pattern! But i'll still be accepted!" 
66 | ) 67 | assert len(steppers) == 1 68 | 69 | steppers = state_machine.advance_all_basic(steppers, "World") 70 | assert len(steppers) == 1 71 | assert steppers[0].has_reached_accept_state() 72 | 73 | 74 | def test_wait_for_acceptor_with_partial_match(): 75 | """Test that the WaitForAcceptor can accept any token.""" 76 | text_acceptor = PhraseStateMachine('"hello"') 77 | state_machine = WaitFor(text_acceptor) 78 | steppers = list(state_machine.get_steppers()) 79 | trie_map = TrieMap() 80 | items = [ 81 | ('"hello', 1), 82 | ('"', 2), 83 | ("hello", 3), 84 | ('"c', 4), 85 | ] 86 | trie_map = trie_map.insert_all(items) 87 | stepper_deltas = state_machine.advance_all(steppers, '"*', trie_map) 88 | for stepper_delta in stepper_deltas: 89 | assert stepper_delta.was_healed 90 | assert stepper_delta.token == '"' 91 | assert stepper_delta.stepper.get_current_value() == '"' 92 | assert len(steppers) == 1 93 | assert not steppers[0].has_reached_accept_state() 94 | 95 | 96 | def test_get_valid_continuations_buffer_too_short(): 97 | """Test get_valid_continuations when buffer is shorter than min_buffer_length.""" 98 | text_acceptor = PhraseStateMachine("Hello") 99 | state_machine = WaitFor(text_acceptor, buffer_length=10) 100 | 101 | steppers = list(state_machine.get_steppers()) 102 | assert len(steppers) == 1 103 | stepper = steppers[0] 104 | 105 | # Buffer is empty, should return empty list 106 | continuations = stepper.get_valid_continuations() 107 | assert continuations == [] 108 | invalid_continuations = stepper.get_invalid_continuations() 109 | assert invalid_continuations == ["Hello"] 110 | 111 | def test_should_start_step_with_remaining_input(): 112 | """Test should_start_step when remaining_input is not None.""" 113 | text_acceptor = PhraseStateMachine("Hello") 114 | state_machine = WaitFor(text_acceptor) 115 | 116 | steppers = list(state_machine.get_steppers()) 117 | assert len(steppers) == 1 118 | stepper = steppers[0] 119 | 120 | # Set remaining_input 121 | 
stepper.remaining_input = "something" 122 | 123 | # should_start_step should be False with remaining_input 124 | assert not stepper.should_start_step("Hello") 125 | 126 | 127 | def test_consume_with_no_sub_stepper(): 128 | """Test consume method when sub_stepper is None.""" 129 | state_machine = WaitFor(PhraseStateMachine("Hello")) 130 | stepper = WaitForStepper(state_machine) 131 | 132 | # Force sub_stepper to be None 133 | stepper.sub_stepper = None 134 | 135 | # consume should return empty list 136 | result = stepper.consume("any token") 137 | assert result == [] 138 | 139 | 140 | def test_consume_with_min_buffer_length_negative(): 141 | """Test consume when min_buffer_length is -1 and not within value.""" 142 | text_acceptor = PhraseStateMachine("Hello") 143 | state_machine = WaitFor(text_acceptor, buffer_length=-1) 144 | 145 | steppers = list(state_machine.get_steppers()) 146 | assert len(steppers) == 1 147 | stepper = steppers[0] 148 | 149 | # Make sure we're not within a value 150 | assert not stepper.is_within_value() 151 | 152 | # Try to consume a token that doesn't start the pattern 153 | # With buffer_length = -1, this should return empty list 154 | result = stepper.consume("NotHello") 155 | assert result == [] 156 | -------------------------------------------------------------------------------- /tests/unit/types/bash/test_bash_code.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pse_core.state_machine import StateMachine 3 | 4 | from pse.types.grammar import BashStateMachine 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "code, should_accept", 9 | [ 10 | # Basic commands 11 | ("echo 'Hello, World!'", True), 12 | ("ls -la", True), 13 | ("cd /home/user", True), 14 | ("grep -r 'pattern' .", True), 15 | ("cat file.txt", True), 16 | # Variables and assignments 17 | ("x=1", True), 18 | ("NAME='John Doe'", True), 19 | ("export PATH=$PATH:/usr/local/bin", True), 20 | ("readonly VAR=value", True), 21 | 
# Control structures 22 | ("if [ $x -eq 1 ]; then echo 'yes'; fi", True), 23 | ("for i in 1 2 3; do echo $i; done", True), 24 | ("while [ $count -lt 10 ]; do echo $count; ((count++)); done", True), 25 | ("case $var in a) echo 'a';; b) echo 'b';; esac", True), 26 | # Functions 27 | ("function greet() { echo 'Hello'; }", True), 28 | ("myfunc() { local var=1; echo $var; }", True), 29 | # Pipes and redirections 30 | ("cat file.txt | grep pattern", True), 31 | ("ls > output.txt", True), 32 | ("cat < input.txt", True), 33 | ("command >> log.txt 2>&1", True), 34 | # Command substitution 35 | ("echo $(date)", True), 36 | ("files=$(ls -la)", True), 37 | # Arithmetic 38 | ("echo $((1 + 2))", True), 39 | ("((x = y + 3))", True), 40 | # Comments 41 | ("# This is a comment", True), 42 | ("echo 'Hello' # inline comment", True), 43 | # Multiline commands 44 | ("echo 'line 1'\necho 'line 2'", True), 45 | ("if true; then\n echo 'true'\nfi", True), 46 | # Invalid syntax 47 | ("if then fi", False), 48 | ("for in do done", False), 49 | ("case esac", False), 50 | ("echo 'unterminated string", False), 51 | ("function () {}", False), 52 | ("ls | | grep pattern", False), 53 | ], 54 | ) 55 | def test_bash_source_validation(code, should_accept): 56 | """Test validation of Bash source code.""" 57 | source_code_sm = StateMachine( 58 | { 59 | 0: [(BashStateMachine, "$")], 60 | } 61 | ) 62 | steppers = source_code_sm.get_steppers() 63 | steppers = source_code_sm.advance_all_basic(steppers, code) 64 | 65 | if should_accept: 66 | assert any(stepper.has_reached_accept_state() for stepper in steppers), ( 67 | f"Should accept valid Bash code: {code}" 68 | ) 69 | else: 70 | assert not any(stepper.has_reached_accept_state() for stepper in steppers), ( 71 | f"Should not accept invalid Bash code: {code}" 72 | ) 73 | 74 | 75 | def test_incremental_parsing(): 76 | """Test incremental parsing of Bash code.""" 77 | source_code_sm = StateMachine( 78 | { 79 | 0: [(BashStateMachine, "$")], 80 | } 81 | ) 82 | 
83 | # Test that we can parse a simple echo command 84 | # This is known to work from the other tests 85 | complete_code = "echo 'Hello, World!'" 86 | steppers = source_code_sm.get_steppers() 87 | steppers = source_code_sm.advance_all_basic(steppers, complete_code) 88 | 89 | # The simple echo command should be valid 90 | assert len(steppers) > 0, "Should accept simple echo command" 91 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 92 | 93 | 94 | def test_empty_input(): 95 | """Test handling of empty input.""" 96 | source_code_sm = StateMachine( 97 | { 98 | 0: [(BashStateMachine, "$")], 99 | } 100 | ) 101 | steppers = source_code_sm.get_steppers() 102 | steppers = source_code_sm.advance_all_basic(steppers, "") 103 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 104 | 105 | 106 | @pytest.mark.parametrize( 107 | "incomplete_code", 108 | [ 109 | "if [ $x -eq 1 ]; then", 110 | "for i in 1 2 3; do", 111 | "while true; do", 112 | "case $var in", 113 | "function name() {", 114 | "echo 'Hello' |", 115 | "ls -la >", 116 | "cat <", 117 | "grep pattern &&", 118 | "find . 
-name '*.txt' ||", 119 | ], 120 | ) 121 | def test_incomplete_but_valid_code(incomplete_code): 122 | """Test handling of incomplete but syntactically valid Bash code.""" 123 | source_code_sm = StateMachine( 124 | { 125 | 0: [(BashStateMachine, "$")], 126 | } 127 | ) 128 | steppers = source_code_sm.get_steppers() 129 | steppers = source_code_sm.advance_all_basic(steppers, incomplete_code) 130 | assert len(steppers) > 0 131 | # Incomplete code should be able to accept more input 132 | assert all(stepper.can_accept_more_input() for stepper in steppers) 133 | # Incomplete code should not be in an accept state 134 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 135 | 136 | 137 | @pytest.mark.parametrize( 138 | "code", 139 | [ 140 | "echo 'string with unmatched quote", 141 | "if [ $x -eq 1 ]", 142 | "for i in $(seq 1 10)", 143 | "cat file.txt | grep 'pattern' |", 144 | "function name() { echo 'incomplete", 145 | "case $var in a) echo 'a';;", 146 | "while [ $count -lt 10 ]", 147 | "ls -la 2>", 148 | ], 149 | ) 150 | def test_bash_specific_incomplete_constructs(code): 151 | """Test Bash-specific incomplete constructs that should be considered valid during incremental parsing.""" 152 | source_code_sm = StateMachine( 153 | { 154 | 0: [(BashStateMachine, "$")], 155 | } 156 | ) 157 | steppers = source_code_sm.get_steppers() 158 | steppers = source_code_sm.advance_all_basic(steppers, code) 159 | assert len(steppers) > 0, f"Should accept incomplete Bash construct: {code}" 160 | # Incomplete constructs should not be in an accept state 161 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 162 | assert all(stepper.can_accept_more_input() for stepper in steppers) 163 | 164 | def test_bash_code_validate_no_code(): 165 | """Test that validate returns False for empty code.""" 166 | assert not BashStateMachine.grammar.validate("") 167 | -------------------------------------------------------------------------------- 
/tests/unit/types/bash/test_wrapped_bash.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pse.types.base.encapsulated import EncapsulatedStateMachine 4 | from pse.types.grammar import BashStateMachine 5 | 6 | 7 | def test_basic_bash_block(): 8 | """Test basic Bash code block parsing.""" 9 | bash_sm = EncapsulatedStateMachine( 10 | BashStateMachine, delimiters=BashStateMachine.delimiters 11 | ) 12 | steppers = bash_sm.get_steppers() 13 | 14 | # Test opening delimiter 15 | steppers = bash_sm.advance_all_basic(steppers, "```bash\n") 16 | assert len(steppers) > 0 17 | # Test Bash code 18 | steppers = bash_sm.advance_all_basic(steppers, "echo 'Hello, World!'") 19 | assert len(steppers) > 0 20 | 21 | # Test closing delimiter 22 | steppers = bash_sm.advance_all_basic(steppers, "\n```") 23 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "code_block", 28 | [ 29 | "```bash\necho 'Hello, World!'\n```", 30 | "```bash\nls -la\ngrep pattern file.txt\n```", 31 | "```bash\nfor i in 1 2 3; do\n echo $i\ndone\n```", 32 | "```bash\nif [ -f file.txt ]; then\n cat file.txt\nelse\n echo 'File not found'\nfi\n```", 33 | "```bash\nfunction greet() {\n echo 'Hello, $1!'\n}\n\ngreet 'World'\n```", 34 | ], 35 | ) 36 | def test_complete_bash_blocks(code_block): 37 | """Test various complete Bash code blocks.""" 38 | bash_sm = EncapsulatedStateMachine( 39 | BashStateMachine, delimiters=BashStateMachine.delimiters 40 | ) 41 | steppers = bash_sm.get_steppers() 42 | steppers = bash_sm.advance_all_basic(steppers, code_block) 43 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 44 | 45 | 46 | def test_custom_delimiters(): 47 | """Test BashStateMachine with custom delimiters.""" 48 | sm = EncapsulatedStateMachine( 49 | BashStateMachine, delimiters=("", "") 50 | ) 51 | steppers = sm.get_steppers() 52 | 53 | code_block = "echo 'Hello, World!'" 
54 | steppers = sm.advance_all_basic(steppers, code_block) 55 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 56 | 57 | 58 | def test_stepper_clone(): 59 | """Test cloning of BashStepper.""" 60 | bash_sm = EncapsulatedStateMachine( 61 | BashStateMachine, delimiters=BashStateMachine.delimiters 62 | ) 63 | steppers = bash_sm.get_steppers() 64 | steppers = bash_sm.advance_all_basic(steppers, "```bash\necho 'Hello'\n") 65 | 66 | original_stepper = steppers[0] 67 | cloned_stepper = original_stepper.clone() 68 | 69 | assert original_stepper.get_current_value() == cloned_stepper.get_current_value() 70 | assert original_stepper is not cloned_stepper 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "invalid_block", 75 | [ 76 | "```bash\nif then fi\n```", 77 | "```bash\nfor in do done\n```", 78 | "```bash\ncase esac\n```", 79 | "```bash\nfunction () {}\n```", 80 | "```bash\nls | | grep pattern\n```", 81 | ], 82 | ) 83 | def test_invalid_bash_blocks(invalid_block): 84 | """Test handling of invalid Bash code blocks.""" 85 | bash_sm = EncapsulatedStateMachine( 86 | BashStateMachine, delimiters=BashStateMachine.delimiters 87 | ) 88 | steppers = bash_sm.get_steppers() 89 | steppers = bash_sm.advance_all_basic(steppers, invalid_block) 90 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 91 | 92 | 93 | def test_incremental_parsing(): 94 | """Test incremental parsing of Bash code block.""" 95 | bash_sm = EncapsulatedStateMachine( 96 | BashStateMachine, delimiters=BashStateMachine.delimiters 97 | ) 98 | steppers = bash_sm.get_steppers() 99 | 100 | parts = [ 101 | "```bash\n", 102 | "for i in 1 2 3; do\n", 103 | " echo $i\n", 104 | "done\n", 105 | "```", 106 | ] 107 | 108 | for part in parts: 109 | steppers = bash_sm.advance_all_basic(steppers, part) 110 | assert len(steppers) > 0 111 | 112 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 113 | 114 | 115 | def test_can_accept_more_input(): 116 | """Test that the 
stepper can accept more input.""" 117 | bash_sm = EncapsulatedStateMachine( 118 | BashStateMachine, delimiters=BashStateMachine.delimiters 119 | ) 120 | steppers = bash_sm.get_steppers() 121 | steppers = bash_sm.advance_all_basic(steppers, "```bash\n") 122 | assert len(steppers) > 0 123 | steppers = bash_sm.advance_all_basic(steppers, "echo 'Hello") 124 | assert len(steppers) > 0 125 | assert all(stepper.can_accept_more_input() for stepper in steppers) 126 | steppers = bash_sm.advance_all_basic(steppers, "'") 127 | assert len(steppers) > 0 128 | assert all(stepper.can_accept_more_input() for stepper in steppers) 129 | steppers = bash_sm.advance_all_basic(steppers, "\n```") 130 | assert len(steppers) > 0 131 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 132 | 133 | 134 | @pytest.mark.parametrize( 135 | "code_block", 136 | [ 137 | "```bash\n#!/bin/bash\n\n# This is a comment\necho 'Hello, World!'\n```", 138 | "```bash\nVAR='value'\necho $VAR\n```", 139 | "```bash\nls -la | grep '.txt' | sort\n```", 140 | "```bash\nif [ -d /tmp ]; then\n echo 'Directory exists'\nfi\n```", 141 | ], 142 | ) 143 | def test_bash_specific_features(code_block): 144 | """Test Bash-specific features.""" 145 | bash_sm = EncapsulatedStateMachine( 146 | BashStateMachine, delimiters=BashStateMachine.delimiters 147 | ) 148 | steppers = bash_sm.get_steppers() 149 | steppers = bash_sm.advance_all_basic(steppers, code_block) 150 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 151 | 152 | 153 | @pytest.mark.parametrize( 154 | "incomplete_block", 155 | [ 156 | "```bash\nif [ $x -eq 1 ]; then\n", 157 | "```bash\nfor i in 1 2 3; do\n", 158 | "```bash\nwhile true; do\n", 159 | "```bash\ncase $var in\n", 160 | "```bash\nfunction name() {\n", 161 | ], 162 | ) 163 | def test_incomplete_bash_blocks(incomplete_block): 164 | """Test handling of incomplete Bash code blocks.""" 165 | bash_sm = EncapsulatedStateMachine( 166 | BashStateMachine, 
delimiters=BashStateMachine.delimiters 167 | ) 168 | steppers = bash_sm.get_steppers() 169 | steppers = bash_sm.advance_all_basic(steppers, incomplete_block) 170 | assert len(steppers) > 0 171 | assert all(stepper.can_accept_more_input() for stepper in steppers) 172 | -------------------------------------------------------------------------------- /tests/unit/types/json/schema_sources/test_from_pydantic.py: -------------------------------------------------------------------------------- 1 | """Tests for the from_pydantic module.""" 2 | from pydantic import BaseModel, Field 3 | 4 | from pse.types.json.schema_sources.from_pydantic import pydantic_to_schema 5 | 6 | 7 | class SimpleModel(BaseModel): 8 | """A simple model with basic fields. 9 | 10 | This model contains string and integer fields. 11 | """ 12 | 13 | name: str = Field(description="The name field") 14 | age: int = Field(description="The age field") 15 | 16 | 17 | class ModelWithDefaults(BaseModel): 18 | """A model with default values. 19 | 20 | Args: 21 | name: A name field with default 22 | age: An age field that's required 23 | """ 24 | 25 | name: str = Field(default="John", description="The name field") 26 | age: int = Field(description="The age field") 27 | 28 | 29 | class ComplexModel(BaseModel): 30 | """A complex model with nested types. 31 | 32 | Args: 33 | name: The user's name 34 | age: The user's age 35 | tags: A list of string tags 36 | address: An optional address 37 | """ 38 | 39 | name: str 40 | age: int 41 | tags: list[str] = Field(default_factory=list) 42 | address: str | None = None 43 | 44 | 45 | class NestedModel(BaseModel): 46 | """A model with nested models. 
47 | 48 | Args: 49 | title: The title 50 | user: A user object 51 | """ 52 | 53 | title: str 54 | user: SimpleModel 55 | 56 | 57 | class ModelWithoutDocs(BaseModel): 58 | name: str 59 | age: int 60 | 61 | 62 | class ModelWithJsonExtra(BaseModel): 63 | """A model with json_schema_extra in fields.""" 64 | 65 | name: str = Field( 66 | description="The name field", 67 | json_schema_extra={"example": "John Doe", "pattern": "^[a-zA-Z ]+$"} 68 | ) 69 | age: int = Field( 70 | description="The age field", 71 | json_schema_extra={"minimum": 0, "maximum": 120} 72 | ) 73 | 74 | 75 | def test_pydantic_to_schema_simple(): 76 | """Test pydantic_to_schema with a simple model.""" 77 | schema = pydantic_to_schema(SimpleModel) 78 | 79 | assert schema["name"] == "SimpleModel" 80 | assert "description" in schema 81 | assert "properties" in schema 82 | assert "name" in schema["properties"] 83 | assert "age" in schema["properties"] 84 | assert schema["properties"]["name"]["description"] == "The name field" 85 | assert schema["properties"]["age"]["description"] == "The age field" 86 | assert "required" in schema 87 | assert "name" in schema["required"] 88 | assert "age" in schema["required"] 89 | 90 | 91 | def test_pydantic_to_schema_with_defaults(): 92 | """Test pydantic_to_schema with a model having default values.""" 93 | schema = pydantic_to_schema(ModelWithDefaults) 94 | 95 | assert "properties" in schema 96 | assert "name" in schema["properties"] 97 | assert "age" in schema["properties"] 98 | assert "default" in schema["properties"]["name"] 99 | assert schema["properties"]["name"]["default"] == "John" 100 | assert "required" in schema 101 | assert "age" in schema["required"] 102 | assert "name" not in schema["required"] 103 | 104 | 105 | def test_pydantic_to_schema_complex(): 106 | """Test pydantic_to_schema with a model having complex types.""" 107 | schema = pydantic_to_schema(ComplexModel) 108 | 109 | assert "properties" in schema 110 | assert "name" in schema["properties"] 111 | 
assert "age" in schema["properties"] 112 | assert "tags" in schema["properties"] 113 | assert "address" in schema["properties"] 114 | 115 | # Check array type 116 | assert schema["properties"]["tags"]["type"] == "array" 117 | assert "items" in schema["properties"]["tags"] 118 | 119 | # Check nullable field - Pydantic v2 uses anyOf for nullable fields 120 | address_schema = schema["properties"]["address"] 121 | if "anyOf" in address_schema: 122 | # Check if one of the types is null 123 | has_null_type = any( 124 | item.get("type") == "null" for item in address_schema["anyOf"] 125 | ) 126 | assert has_null_type 127 | else: 128 | assert address_schema.get("nullable", False) is True 129 | 130 | # Check required fields 131 | assert "required" in schema 132 | assert "name" in schema["required"] 133 | assert "age" in schema["required"] 134 | assert "address" not in schema["required"] 135 | assert "tags" not in schema["required"] 136 | 137 | 138 | def test_pydantic_to_schema_nested(): 139 | """Test pydantic_to_schema with nested models.""" 140 | schema = pydantic_to_schema(NestedModel) 141 | 142 | assert "properties" in schema 143 | assert "title" in schema["properties"] 144 | assert "user" in schema["properties"] 145 | 146 | # Check nested model - Pydantic v2 uses $ref 147 | user_schema = schema["properties"]["user"] 148 | 149 | # For Pydantic v2, it might use $ref 150 | if "$ref" in user_schema: 151 | # If using $ref, then SimpleModel should be in $defs 152 | assert "$defs" in schema 153 | assert "SimpleModel" in schema["$defs"] or user_schema["$ref"].split("/")[-1] in schema["$defs"] 154 | else: 155 | # For direct embedding of model (older versions) 156 | assert "properties" in user_schema 157 | assert "name" in user_schema["properties"] 158 | assert "age" in user_schema["properties"] 159 | assert "required" in user_schema 160 | assert "name" in user_schema["required"] 161 | assert "age" in user_schema["required"] 162 | 163 | 164 | def 
test_pydantic_to_schema_without_docs(): 165 | """Test pydantic_to_schema with a model without docstrings.""" 166 | schema = pydantic_to_schema(ModelWithoutDocs) 167 | 168 | assert schema["name"] == "ModelWithoutDocs" 169 | assert schema["description"] == "" 170 | assert "properties" in schema 171 | assert "name" in schema["properties"] 172 | assert "age" in schema["properties"] 173 | assert "description" in schema["properties"]["name"] 174 | assert schema["properties"]["name"]["description"] == "" 175 | 176 | 177 | def test_pydantic_to_schema_with_json_extra(): 178 | """Test pydantic_to_schema with json_schema_extra in fields.""" 179 | schema = pydantic_to_schema(ModelWithJsonExtra) 180 | 181 | assert "properties" in schema 182 | assert "name" in schema["properties"] 183 | assert "age" in schema["properties"] 184 | 185 | # Check extra schema properties 186 | name_schema = schema["properties"]["name"] 187 | assert "example" in name_schema 188 | assert name_schema["example"] == "John Doe" 189 | assert "pattern" in name_schema 190 | assert name_schema["pattern"] == "^[a-zA-Z ]+$" 191 | 192 | age_schema = schema["properties"]["age"] 193 | assert "minimum" in age_schema 194 | assert age_schema["minimum"] == 0 195 | assert "maximum" in age_schema 196 | assert age_schema["maximum"] == 120 197 | -------------------------------------------------------------------------------- /tests/unit/types/json/test_any_json_schema.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Any 3 | 4 | import pytest 5 | 6 | from pse.types.json.any_json_schema import AnySchemaStateMachine 7 | 8 | 9 | @pytest.fixture 10 | def context(): 11 | """Provide the context fixture.""" 12 | return {"defs": defaultdict(dict), "path": ""} 13 | 14 | 15 | def parse_input(state_machine: AnySchemaStateMachine, json_string: str) -> Any: 16 | """ 17 | Helper function to parse a JSON string using the AnyOfAcceptor. 
18 | 19 | Args: 20 | state_machine (AnyOfAcceptor): The state_machine instance to use for parsing. 21 | json_string (str): The JSON string to parse. 22 | 23 | Returns: 24 | Any: The parsed JSON value. 25 | 26 | Raises: 27 | JSONParsingError: If the JSON input is invalid or does not match any schema. 28 | """ 29 | steppers = state_machine.get_steppers() 30 | steppers = state_machine.advance_all_basic(steppers, json_string) 31 | for stepper in steppers: 32 | if stepper.has_reached_accept_state(): 33 | return stepper.get_current_value() 34 | 35 | raise ValueError(f"Invalid JSON input for AnyOfAcceptor: {json_string}") 36 | 37 | 38 | @pytest.mark.parametrize( 39 | "schemas, token, expected_result", 40 | [ 41 | ( 42 | [{"type": "number", "minimum": 0}, {"type": "string", "maxLength": 5}], 43 | "10", 44 | 10, 45 | ), 46 | ( 47 | [{"type": "number", "minimum": 0}, {"type": "string", "maxLength": 5}], 48 | '"test"', 49 | "test", 50 | ), 51 | ], 52 | ) 53 | def test_accept_input_matching_single_schema(context, schemas, token, expected_result): 54 | """Test that input matching a single schema is accepted.""" 55 | state_machine = AnySchemaStateMachine(schemas=schemas, context=context) 56 | result = parse_input(state_machine, token) 57 | assert result == expected_result, ( 58 | f"AnyOfAcceptor should accept valid input {token}." 59 | ) 60 | 61 | 62 | def test_accept_input_matching_multiple_schemas(context): 63 | """Test that input matching multiple schemas is accepted.""" 64 | # Define overlapping schemas 65 | schema1 = {"type": "number", "minimum": 0, "maximum": 100} 66 | schema2 = {"type": "number", "multipleOf": 5} 67 | state_machine = AnySchemaStateMachine(schemas=[schema1, schema2], context=context) 68 | 69 | valid_input = "25" # Matches both schemas 70 | result = parse_input(state_machine, valid_input) 71 | assert result == 25, "AnyOfAcceptor should accept input matching multiple schemas." 
72 | 73 | 74 | def test_reject_input_not_matching_any_schema(context): 75 | """Test that input not matching any schema is rejected.""" 76 | schema1 = {"type": "boolean"} 77 | schema2 = {"type": "null"} 78 | state_machine = AnySchemaStateMachine(schemas=[schema1, schema2], context=context) 79 | 80 | invalid_input_number = "1" 81 | invalid_input_string = '"test"' 82 | 83 | # Test with invalid number 84 | with pytest.raises(ValueError): 85 | parse_input(state_machine, invalid_input_number) 86 | 87 | # Test with invalid string 88 | with pytest.raises(ValueError): 89 | parse_input(state_machine, invalid_input_string) 90 | 91 | 92 | def test_complex_nested_schemas(context): 93 | """Test AnyOfAcceptor with complex nested schemas.""" 94 | schema1 = { 95 | "type": "object", 96 | "properties": { 97 | "name": {"type": "string"}, 98 | "age": {"type": "number", "minimum": 0}, 99 | }, 100 | "required": ["name", "age"], 101 | } 102 | schema2 = {"type": "array", "items": {"type": "string"}, "minItems": 1} 103 | state_machine = AnySchemaStateMachine(schemas=[schema1, schema2], context=context) 104 | 105 | valid_object = '{"name": "Alice", "age": 30}' 106 | valid_array = '["apple", "banana"]' 107 | 108 | # Test with valid object 109 | result_object = parse_input(state_machine, valid_object) 110 | assert result_object == { 111 | "name": "Alice", 112 | "age": 30, 113 | }, "AnyOfAcceptor should accept valid object input." 114 | 115 | # Test with valid array 116 | result_array = parse_input(state_machine, valid_array) 117 | assert result_array == [ 118 | "apple", 119 | "banana", 120 | ], "AnyOfAcceptor should accept valid array input." 
121 | 122 | 123 | def test_partial_input(context): 124 | """Test that partial input does not result in acceptance.""" 125 | schema = {"type": "string", "minLength": 5} 126 | state_machine = AnySchemaStateMachine(schemas=[schema], context=context) 127 | 128 | partial_input = '"test"' 129 | 130 | with pytest.raises(ValueError): 131 | parse_input(state_machine, partial_input) 132 | 133 | 134 | @pytest.mark.parametrize( 135 | "token, expected_result", 136 | [ 137 | ('"test"', "test"), # Matches string schema 138 | ("123", 123), # Matches number schema 139 | ], 140 | ) 141 | def test_multiple_accepted_steppers(context, token, expected_result): 142 | """Test that AnyOfAcceptor handles multiple accepted steppers correctly.""" 143 | schema1 = {"type": "string"} 144 | schema2 = {"type": "number"} 145 | state_machine = AnySchemaStateMachine(schemas=[schema1, schema2], context=context) 146 | 147 | result = parse_input(state_machine, token) 148 | assert result == expected_result, f"AnyOfAcceptor should accept input {token}." 
149 | -------------------------------------------------------------------------------- /tests/unit/types/json/test_json_key_value.py: -------------------------------------------------------------------------------- 1 | from pse.types.json.json_key_value import ( 2 | KeyValueSchemaStateMachine, 3 | KeyValueSchemaStepper, 4 | ) 5 | 6 | 7 | def test_property_parsing(): 8 | state_machine = KeyValueSchemaStateMachine( 9 | prop_name="type", 10 | prop_schema={"type": "string"}, 11 | context={"defs": {}}, 12 | ) 13 | steppers = list(state_machine.get_steppers()) 14 | steppers = state_machine.advance_all_basic(steppers, '"') 15 | 16 | assert len(steppers) == 1 17 | steppers = state_machine.advance_all_basic(steppers, "type") 18 | assert len(steppers) == 1 19 | steppers = state_machine.advance_all_basic(steppers, '": "hi"') 20 | assert len(steppers) == 1 21 | assert steppers[0].has_reached_accept_state() 22 | assert steppers[0].get_current_value() == ("type", "hi") 23 | 24 | 25 | def test_property_parsing_with_string_sm(): 26 | """Test KeyValueSchemaStateMachine when prop_name is None, using StringStateMachine.""" 27 | state_machine = KeyValueSchemaStateMachine( 28 | prop_name=None, # This tests the branch that uses StringStateMachine 29 | prop_schema={"type": "string"}, 30 | context={"defs": {}, "path": "/parent"}, 31 | ) 32 | 33 | steppers = list(state_machine.get_steppers()) 34 | steppers = state_machine.advance_all_basic(steppers, '"dynamic_key"') 35 | assert len(steppers) > 0 36 | 37 | steppers = state_machine.advance_all_basic(steppers, ': "value"') 38 | assert len(steppers) > 0 39 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 40 | 41 | for stepper in steppers: 42 | if stepper.has_reached_accept_state(): 43 | assert stepper.get_current_value() == ("dynamic_key", "value") 44 | 45 | 46 | def test_key_value_schema_stepper_initialization(): 47 | """Test KeyValueSchemaStepper initialization and state machine access.""" 48 | state_machine = 
KeyValueSchemaStateMachine( 49 | prop_name="test_prop", 50 | prop_schema={"type": "string"}, 51 | context={"defs": {}}, 52 | ) 53 | 54 | stepper = state_machine.get_new_stepper() 55 | assert isinstance(stepper, KeyValueSchemaStepper), "Should return a KeyValueSchemaStepper instance" 56 | assert stepper.state_machine is state_machine, "Stepper should reference the correct state machine" 57 | 58 | def test_key_value_schema_stepper_equality(): 59 | """Test KeyValueSchemaStepper equality.""" 60 | state_machine = KeyValueSchemaStateMachine( 61 | prop_name="test_prop", 62 | prop_schema={"type": "string"}, 63 | context={"defs": {}}, 64 | ) 65 | 66 | steppers1 = state_machine.get_steppers() 67 | steppers2 = state_machine.get_steppers() 68 | assert steppers1[0] == steppers2[0] 69 | 70 | steppers1 = state_machine.advance_all_basic(steppers1, '"test_prop": "value"') 71 | assert len(steppers1) == 1 72 | steppers2 = state_machine.advance_all_basic(steppers2, '"test_prop": "value"') 73 | assert len(steppers2) == 1 74 | assert steppers1[0] == steppers2[0] 75 | 76 | stepper3 = state_machine.get_new_stepper() 77 | assert steppers1[0] != stepper3 78 | -------------------------------------------------------------------------------- /tests/unit/types/json/test_json_to_state_machine.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Any 3 | 4 | import pytest 5 | 6 | from pse.types.base.chain import ChainStateMachine 7 | from pse.types.enum import EnumStateMachine 8 | from pse.types.json import _json_schema_to_state_machine 9 | from pse.types.json.any_json_schema import AnySchemaStateMachine 10 | from pse.types.json.json_array import ArraySchemaStateMachine 11 | from pse.types.json.json_number import NumberSchemaStateMachine 12 | from pse.types.json.json_object import ObjectSchemaStateMachine 13 | from pse.types.json.json_string import StringSchemaStateMachine 14 | 15 | 16 | 
@pytest.mark.parametrize( 17 | "schema, expected_acceptor_cls, acceptor_len", 18 | [ 19 | ({"type": "number", "minimum": 0}, NumberSchemaStateMachine, None), 20 | ({"type": "string", "nullable": True}, AnySchemaStateMachine, 2), 21 | ({"type": ["string", "number"], "minimum": 0}, AnySchemaStateMachine, 2), 22 | ( 23 | { 24 | "type": "object", 25 | "properties": { 26 | "name": {"type": "string"}, 27 | "age": {"type": "number", "minimum": 0}, 28 | }, 29 | "required": ["name", "age"], 30 | }, 31 | ObjectSchemaStateMachine, 32 | None, 33 | ), 34 | ( 35 | {"type": "array", "items": {"type": "string"}, "minItems": 1}, 36 | ArraySchemaStateMachine, 37 | None, 38 | ), 39 | ({"enum": ["red", "green", "blue"]}, EnumStateMachine, None), 40 | ( 41 | {"allOf": [{"type": "string"}, {"minLength": 5}]}, 42 | StringSchemaStateMachine, 43 | None, 44 | ), 45 | ({"oneOf": [{"type": "string"}, {"type": "number"}]}, AnySchemaStateMachine, 2), 46 | ({"type": "string", "const": "fixed_value"}, ChainStateMachine, None), 47 | ], 48 | ) 49 | def test_get_acceptor_schema_types( 50 | schema: dict[str, Any], 51 | expected_acceptor_cls: type[Any], 52 | acceptor_len: int | None, 53 | ) -> None: 54 | """Test get_json_acceptor with various schema types and expected acceptors.""" 55 | state_machine = _json_schema_to_state_machine(schema) 56 | assert isinstance(state_machine, expected_acceptor_cls), ( 57 | f"Expected {expected_acceptor_cls.__name__} for schema {schema}" 58 | ) 59 | if acceptor_len is not None: 60 | assert isinstance(state_machine, AnySchemaStateMachine) 61 | assert len(state_machine.state_machines) == acceptor_len, ( 62 | f"Expected state_machine length {acceptor_len} for schema {schema}" 63 | ) 64 | 65 | 66 | @pytest.fixture 67 | def context_with_definition() -> dict[str, Any]: 68 | """Fixture providing context with predefined definitions.""" 69 | context = {"defs": defaultdict(dict), "path": ""} 70 | context["defs"]["#/definitions/address"] = { 71 | "type": "object", 72 | 
"properties": {"street": {"type": "string"}, "city": {"type": "string"}}, 73 | "required": ["street", "city"], 74 | } 75 | return context 76 | 77 | 78 | def test_get_acceptor_with_ref_schema(context_with_definition: dict[str, Any]) -> None: 79 | """Test get_json_acceptor with a $ref schema referencing a definition.""" 80 | schema = {"$ref": "#/definitions/address"} 81 | state_machine = _json_schema_to_state_machine(schema, context_with_definition) 82 | assert isinstance( 83 | state_machine, 84 | ObjectSchemaStateMachine, 85 | ), ( 86 | "get_json_acceptor should return an ObjectSchemaAcceptor for $ref schemas referencing object definitions." 87 | ) 88 | -------------------------------------------------------------------------------- /tests/unit/types/json/test_json_value.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | 5 | from pse.types.json.json_value import JsonStateMachine 6 | 7 | 8 | @pytest.fixture 9 | def json_acceptor(): 10 | return JsonStateMachine() 11 | 12 | 13 | @pytest.fixture 14 | def parse_json(json_acceptor: JsonStateMachine): 15 | def parser(json_string: str) -> Any: 16 | steppers = list(json_acceptor.get_steppers()) 17 | for char in json_string: 18 | steppers = json_acceptor.advance_all_basic(steppers, char) 19 | if not steppers: 20 | raise AssertionError("No steppers after parsing") 21 | accepted_values = [ 22 | stepper.get_current_value() 23 | for stepper in steppers 24 | if stepper.has_reached_accept_state() 25 | ] 26 | if not accepted_values: 27 | raise AssertionError("No accepted steppers after parsing") 28 | return accepted_values[0] 29 | 30 | return parser 31 | 32 | 33 | def test_json_acceptor_initialization(json_acceptor): 34 | assert isinstance(json_acceptor, JsonStateMachine), ( 35 | "JsonAcceptor instance was not created." 
36 | ) 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "json_string, expected", 41 | [ 42 | ( 43 | '{"name": "John", "age": 30, "city": "New York"}', 44 | {"name": "John", "age": 30, "city": "New York"}, 45 | ), 46 | ('["apple", "banana", "cherry"]', ["apple", "banana", "cherry"]), 47 | ('{"message": "Hello, \\nWorld! \\t😊"}', {"message": "Hello, \nWorld! \t😊"}), 48 | ( 49 | '{"greeting": "こんにちは世界", "emoji": "🚀🌟"}', 50 | {"greeting": "こんにちは世界", "emoji": "🚀🌟"}, 51 | ), 52 | ], 53 | ) 54 | def test_parse_valid_json(json_string, expected, parse_json): 55 | parsed = parse_json(json_string) 56 | assert parsed == expected 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "json_string", 61 | [ 62 | '{"name": "John", "age": 30,, "city": "New York"}', # Invalid syntax 63 | "", # Empty JSON string 64 | ], 65 | ) 66 | def test_parse_invalid_json(json_string, parse_json): 67 | with pytest.raises(AssertionError): 68 | parse_json(json_string) 69 | 70 | 71 | def test_non_zero_state(): 72 | sm = JsonStateMachine() 73 | edges = sm.get_edges(1) 74 | assert not edges 75 | -------------------------------------------------------------------------------- /tests/unit/types/misc/test_fenced_freeform.py: -------------------------------------------------------------------------------- 1 | from pse.types.misc.fenced_freeform import FencedFreeformStateMachine 2 | 3 | 4 | def test_fenced_freeform_default_delimiters(): 5 | """Test FencedFreeformStateMachine with default delimiters.""" 6 | sm = FencedFreeformStateMachine() 7 | input_sequence = "```\nSome freeform text\n```" 8 | steppers = sm.get_steppers() 9 | steppers = sm.advance_all_basic(steppers, input_sequence) 10 | 11 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 12 | 13 | 14 | def test_fenced_freeform_custom_delimiter(): 15 | """Test FencedFreeformStateMachine with custom delimiter.""" 16 | sm = FencedFreeformStateMachine(identifier="json") 17 | input_sequence = "```json\n{\"key\": \"value\"}\n```" 18 | steppers = 
sm.get_steppers() 19 | steppers = sm.advance_all_basic(steppers, input_sequence) 20 | 21 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 22 | assert steppers[0].get_identifier() == "json" 23 | 24 | 25 | def test_fenced_freeform_missing_open_delimiter(): 26 | """Test rejection when open delimiter is missing.""" 27 | sm = FencedFreeformStateMachine() 28 | input_sequence = "Some freeform text\n```" 29 | steppers = sm.get_steppers() 30 | steppers = sm.advance_all_basic(steppers, input_sequence) 31 | 32 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 33 | 34 | 35 | def test_fenced_freeform_missing_close_delimiter(): 36 | """Test rejection when close delimiter is missing.""" 37 | sm = FencedFreeformStateMachine() 38 | input_sequence = "```\nSome freeform text" 39 | steppers = sm.get_steppers() 40 | steppers = sm.advance_all_basic(steppers, input_sequence) 41 | 42 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 43 | for stepper in steppers: 44 | if stepper.sub_stepper and not stepper.sub_stepper.is_within_value(): 45 | assert not stepper.get_invalid_continuations() 46 | 47 | 48 | def test_fenced_freeform_partial_delimiter(): 49 | """Test rejection when delimiter is partially provided.""" 50 | sm = FencedFreeformStateMachine() 51 | input_sequence = "``\nSome freeform text\n```" 52 | steppers = sm.get_steppers() 53 | steppers = sm.advance_all_basic(steppers, input_sequence) 54 | 55 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 56 | 57 | 58 | def test_fenced_freeform_respects_char_min(): 59 | """Test that char_min is respected.""" 60 | sm = FencedFreeformStateMachine(char_min=10) 61 | input_sequence = "```\nshort\n```" 62 | steppers = sm.get_steppers() 63 | steppers = sm.advance_all_basic(steppers, input_sequence) 64 | 65 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 66 | 67 | input_sequence_valid = "```\nlong enough text\n```" 68 | 
steppers = sm.get_steppers() 69 | steppers = sm.advance_all_basic(steppers, input_sequence_valid) 70 | 71 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 72 | 73 | 74 | def test_fenced_freeform_respects_char_max(): 75 | """Test that char_max is respected.""" 76 | sm = FencedFreeformStateMachine(char_max=10) 77 | input_sequence = "```\nthis text is too long\n```" 78 | steppers = sm.get_steppers() 79 | steppers = sm.advance_all_basic(steppers, input_sequence) 80 | 81 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 82 | 83 | input_sequence_valid = "```\nshort\n```" 84 | steppers = sm.get_steppers() 85 | steppers = sm.advance_all_basic(steppers, input_sequence_valid) 86 | 87 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 88 | 89 | 90 | def test_fenced_freeform_invalid_continuations(): 91 | """Test invalid continuations within fenced freeform.""" 92 | sm = FencedFreeformStateMachine() 93 | steppers = sm.get_steppers() 94 | steppers = sm.advance_all_basic(steppers, "```\nSome text") 95 | assert len(steppers) > 0 96 | for stepper in steppers: 97 | if stepper.target_state == 2: 98 | invalid_continuations = stepper.get_invalid_continuations() 99 | assert "\n```" in invalid_continuations 100 | 101 | def test_fenced_freeform_buffer_length(): 102 | """Test buffer_length constraint.""" 103 | sm = FencedFreeformStateMachine(buffer_length=5) 104 | steppers = sm.get_steppers() 105 | assert not any(stepper.should_start_step("```\n") for stepper in steppers) 106 | 107 | steppers = sm.advance_all_basic(steppers, "12345") 108 | assert any(stepper.should_start_step("```\n") for stepper in steppers) 109 | -------------------------------------------------------------------------------- /tests/unit/types/misc/test_freeform.py: -------------------------------------------------------------------------------- 1 | from pse.types.misc.freeform import FreeformStateMachine, FreeformStepper 2 | 3 | 4 | def 
test_freeform_basic(): 5 | """Test basic functionality of FreeformStateMachine.""" 6 | sm = FreeformStateMachine(end_delimiters=["END"]) 7 | input_sequence = "Some freeform textEND" 8 | steppers = sm.get_steppers() 9 | steppers = sm.advance_all_basic(steppers, input_sequence) 10 | 11 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 12 | accepted_stepper = next(stepper for stepper in steppers if stepper.has_reached_accept_state()) 13 | assert accepted_stepper.get_raw_value() == "Some freeform textEND" 14 | 15 | # Test token_safe_output removes delimiter 16 | assert accepted_stepper.get_token_safe_output(lambda x: "".join(chr(i) for i in x)) == "Some freeform text" 17 | 18 | 19 | def test_freeform_multiple_delimiters(): 20 | """Test FreeformStateMachine with multiple end delimiters.""" 21 | sm = FreeformStateMachine(end_delimiters=["END", "STOP", "FINISH"]) 22 | 23 | # Test with first delimiter 24 | input_sequence = "Text with first delimiterEND" 25 | steppers = sm.get_steppers() 26 | steppers = sm.advance_all_basic(steppers, input_sequence) 27 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 28 | 29 | # Test with second delimiter 30 | input_sequence = "Text with second delimiterSTOP" 31 | steppers = sm.get_steppers() 32 | steppers = sm.advance_all_basic(steppers, input_sequence) 33 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 34 | 35 | # Test with third delimiter 36 | input_sequence = "Text with third delimiterFINISH" 37 | steppers = sm.get_steppers() 38 | steppers = sm.advance_all_basic(steppers, input_sequence) 39 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 40 | 41 | 42 | def test_freeform_missing_delimiter(): 43 | """Test that text without an ending delimiter is not accepted.""" 44 | sm = FreeformStateMachine(end_delimiters=["END"]) 45 | input_sequence = "Text without delimiter" 46 | steppers = sm.get_steppers() 47 | steppers = sm.advance_all_basic(steppers, 
input_sequence) 48 | 49 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 50 | 51 | 52 | def test_freeform_partial_delimiter(): 53 | """Test that partial delimiters aren't accepted.""" 54 | sm = FreeformStateMachine(end_delimiters=["END"]) 55 | input_sequence = "Text with partial delimiterEN" 56 | steppers = sm.get_steppers() 57 | steppers = sm.advance_all_basic(steppers, input_sequence) 58 | 59 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 60 | 61 | 62 | def test_freeform_respects_char_min(): 63 | """Test that char_min is respected.""" 64 | sm = FreeformStateMachine(end_delimiters=["END"], char_min=10) 65 | 66 | # Test with text shorter than char_min 67 | input_sequence = "shortEND" 68 | steppers = sm.get_steppers() 69 | steppers = sm.advance_all_basic(steppers, input_sequence) 70 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 71 | 72 | # Test with text meeting char_min requirement 73 | input_sequence = "long enough textEND" 74 | steppers = sm.get_steppers() 75 | steppers = sm.advance_all_basic(steppers, input_sequence) 76 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 77 | 78 | 79 | def test_freeform_stepper_initialization(): 80 | """Test the initialization of FreeformStepper.""" 81 | sm = FreeformStateMachine(end_delimiters=["END"]) 82 | stepper = sm.get_new_stepper() 83 | 84 | assert isinstance(stepper, FreeformStepper) 85 | assert stepper.state_machine == sm 86 | assert stepper.buffer == "" 87 | 88 | 89 | def test_freeform_token_safe_output(): 90 | """Test the token_safe_output method.""" 91 | sm = FreeformStateMachine(end_delimiters=["END", "STOP"]) 92 | 93 | # Test with first delimiter 94 | steppers = sm.get_steppers() 95 | steppers = sm.advance_all_basic(steppers, "Some textEND") 96 | accepted_stepper = next(stepper for stepper in steppers if stepper.has_reached_accept_state()) 97 | assert accepted_stepper.get_token_safe_output(lambda x: 
"".join(chr(i) for i in x)) == "Some text" 98 | 99 | # Test with second delimiter 100 | steppers = sm.get_steppers() 101 | steppers = sm.advance_all_basic(steppers, "Other textSTOP") 102 | accepted_stepper = next(stepper for stepper in steppers if stepper.has_reached_accept_state()) 103 | assert accepted_stepper.get_token_safe_output(lambda x: "".join(chr(i) for i in x)) == "Other text" 104 | 105 | 106 | def test_freeform_get_raw_value(): 107 | """Test the get_raw_value method.""" 108 | sm = FreeformStateMachine(end_delimiters=["END"]) 109 | input_sequence = "Raw text valueEND" 110 | steppers = sm.get_steppers() 111 | steppers = sm.advance_all_basic(steppers, input_sequence) 112 | 113 | accepted_stepper = next(stepper for stepper in steppers if stepper.has_reached_accept_state()) 114 | assert accepted_stepper.get_raw_value() == input_sequence 115 | 116 | # Ensure the raw value includes the delimiter 117 | assert accepted_stepper.get_raw_value().endswith("END") 118 | 119 | 120 | def test_freeform_string_representation(): 121 | """Test the string representation of FreeformStateMachine.""" 122 | sm = FreeformStateMachine(end_delimiters=["END"]) 123 | assert str(sm) == "FreeformText" 124 | -------------------------------------------------------------------------------- /tests/unit/types/python/test_python_code.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pse_core.state_machine import StateMachine 3 | 4 | from pse.types.grammar import PythonStateMachine 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "code, should_accept", 9 | [ 10 | ("x = 1", True), 11 | ("x =! 
1", False), 12 | ("x != 1", True), 13 | ("def foo():\n pass", True), 14 | ("class Test:\n pass", True), 15 | ("print('Hello, World!')", True), 16 | ("if True:\n print('test')", True), 17 | # Test incomplete code 18 | ("1 + ", False), 19 | ("def", False), 20 | ("for i in", False), 21 | ("x =", False), 22 | ("", False), 23 | ("class:", False), 24 | ("x y z", False), 25 | # Test expressions 26 | ("1 + 2", True), 27 | ("len([1, 2, 3])", True), 28 | ("'test'.upper()", True), 29 | # Test multiline code 30 | ("x = 1\ny = 2\nz = x + y", True), 31 | ("def test():\n x = 1\n return x", True), 32 | # Test comments 33 | ("# This is a comment", True), 34 | ("x = 1 # inline comment", True), 35 | ("'''docstring'''", True), 36 | # Test complex Python features 37 | ("lambda x: x * 2", True), 38 | ("try:\n x()\nexcept:\n pass", True), 39 | ("with open('file') as f:\n pass", True), 40 | ("[x for x in range(10)]", True), 41 | # Test invalid syntax 42 | ("def def", False), 43 | ("class class", False), 44 | ("return return", False), 45 | ("import import", False), 46 | ], 47 | ) 48 | def test_python_source_validation(code, should_accept): 49 | """Test validation of Python source code.""" 50 | source_code_sm = StateMachine( 51 | { 52 | 0: [(PythonStateMachine, "$")], 53 | } 54 | ) 55 | steppers = source_code_sm.get_steppers() 56 | steppers = source_code_sm.advance_all_basic(steppers, code) 57 | 58 | if should_accept: 59 | assert any(stepper.has_reached_accept_state() for stepper in steppers), ( 60 | f"Should accept valid Python code: {code}" 61 | ) 62 | else: 63 | assert not any(stepper.has_reached_accept_state() for stepper in steppers), ( 64 | f"Should not accept invalid Python code: {code}" 65 | ) 66 | 67 | 68 | def test_incremental_parsing(): 69 | """Test incremental parsing of Python code.""" 70 | source_code_sm = StateMachine( 71 | { 72 | 0: [(PythonStateMachine, "$")], 73 | } 74 | ) 75 | steppers = source_code_sm.get_steppers() 76 | 77 | # Test valid incremental input 78 | 
code_parts = [ 79 | "def test", 80 | " ", 81 | "(", 82 | ")", 83 | ":", 84 | "\n", 85 | " ", 86 | "x", 87 | " ", 88 | "=", 89 | " ", 90 | "1", 91 | "\n", 92 | " ", 93 | "return", 94 | " ", 95 | ] 96 | 97 | for part in code_parts: 98 | steppers = source_code_sm.advance_all_basic(steppers, part) 99 | assert len(steppers) > 0, f"Should accept partial input: {part}" 100 | 101 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 102 | 103 | 104 | def test_empty_input(): 105 | """Test handling of empty input.""" 106 | source_code_sm = StateMachine( 107 | { 108 | 0: [(PythonStateMachine, "$")], 109 | } 110 | ) 111 | steppers = source_code_sm.get_steppers() 112 | steppers = source_code_sm.advance_all_basic(steppers, "") 113 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 114 | 115 | 116 | @pytest.mark.parametrize( 117 | "incomplete_code", 118 | [ 119 | "if x == 1:\n ", 120 | "def test():\n x = ", 121 | "class MyClass:\n def ", 122 | "try:\n x = 1\nexcept ", 123 | "with open('file') as ", 124 | "def foo():", 125 | "if True:", 126 | "for x in range(10):", 127 | ], 128 | ) 129 | def test_incomplete_but_valid_code(incomplete_code): 130 | """Test handling of incomplete but syntactically valid code.""" 131 | source_code_sm = StateMachine( 132 | { 133 | 0: [(PythonStateMachine, "$")], 134 | } 135 | ) 136 | steppers = source_code_sm.get_steppers() 137 | steppers = source_code_sm.advance_all_basic(steppers, incomplete_code) 138 | assert len(steppers) == 1 139 | assert all(stepper.can_accept_more_input() for stepper in steppers) 140 | 141 | 142 | def test_identifer(): 143 | """Test identifier of PythonStateMachine.""" 144 | assert PythonStateMachine.get_new_stepper(None).get_identifier() == "python" 145 | 146 | def test_invalid_code_advance(): 147 | """Test invalid code advance.""" 148 | sm = StateMachine( 149 | { 150 | 0: [(PythonStateMachine, "$")], 151 | } 152 | ) 153 | steppers = sm.get_steppers() 154 | steppers = 
sm.advance_all_basic(steppers, "print(") 155 | assert len(steppers) == 1 156 | # advance with invalid characters 157 | steppers = sm.advance_all_basic(steppers, '!!!!!!!!!') 158 | assert not steppers 159 | -------------------------------------------------------------------------------- /tests/unit/types/python/test_wrapped_python.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pse.types.base.encapsulated import EncapsulatedStateMachine 4 | from pse.types.grammar import PythonStateMachine 5 | 6 | 7 | def test_basic_python_block(): 8 | """Test basic Python code block parsing.""" 9 | python_sm = EncapsulatedStateMachine( 10 | PythonStateMachine, delimiters=PythonStateMachine.delimiters 11 | ) 12 | steppers = python_sm.get_steppers() 13 | 14 | # Test opening delimiter 15 | steppers = python_sm.advance_all_basic(steppers, "```python\n") 16 | assert len(steppers) > 0 17 | # Test Python code 18 | steppers = python_sm.advance_all_basic(steppers, "x = 1") 19 | assert len(steppers) > 0 20 | 21 | # Test closing delimiter 22 | steppers = python_sm.advance_all_basic(steppers, "\n```") 23 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "code_block", 28 | [ 29 | "```python\nprint('Hello')\n\n```", 30 | "```python\nx = 1\ny = 2\nprint(x + y)\n\n```", 31 | "```python\ndef test():\n return True\n\n```", 32 | "```python\nclass Test:\n pass\n\n```", 33 | ], 34 | ) 35 | def test_complete_python_blocks(code_block): 36 | """Test various complete Python code blocks.""" 37 | python_sm = EncapsulatedStateMachine( 38 | PythonStateMachine, delimiters=PythonStateMachine.delimiters 39 | ) 40 | steppers = python_sm.get_steppers() 41 | steppers = python_sm.advance_all_basic(steppers, code_block) 42 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 43 | 44 | 45 | def test_custom_delimiters(): 46 | """Test PythonStateMachine with custom 
delimiters.""" 47 | sm = EncapsulatedStateMachine( 48 | PythonStateMachine, delimiters=("", "") 49 | ) 50 | steppers = sm.get_steppers() 51 | 52 | code_block = "x = 1" 53 | steppers = sm.advance_all_basic(steppers, code_block) 54 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 55 | 56 | 57 | def test_stepper_clone(): 58 | """Test cloning of PythonStepper.""" 59 | python_sm = EncapsulatedStateMachine( 60 | PythonStateMachine, delimiters=PythonStateMachine.delimiters 61 | ) 62 | steppers = python_sm.get_steppers() 63 | steppers = python_sm.advance_all_basic(steppers, "```python\nx = 1\n") 64 | 65 | original_stepper = steppers[0] 66 | cloned_stepper = original_stepper.clone() 67 | 68 | assert original_stepper.get_current_value() == cloned_stepper.get_current_value() 69 | assert original_stepper is not cloned_stepper 70 | 71 | 72 | @pytest.mark.parametrize( 73 | "invalid_block", 74 | [ 75 | "```python\ndef invalid syntax\n\n```", 76 | "```python\nclass:\n\n```", 77 | "```python\nwhile:\n\n```", 78 | "```python\nprint('no closing parenthesis'\n\n```", 79 | ], 80 | ) 81 | def test_invalid_python_blocks(invalid_block): 82 | """Test handling of invalid Python code blocks.""" 83 | python_sm = EncapsulatedStateMachine( 84 | PythonStateMachine, delimiters=PythonStateMachine.delimiters 85 | ) 86 | steppers = python_sm.get_steppers() 87 | steppers = python_sm.advance_all_basic(steppers, invalid_block) 88 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 89 | 90 | 91 | def test_incremental_parsing(): 92 | """Test incremental parsing of Python code block.""" 93 | python_sm = EncapsulatedStateMachine( 94 | PythonStateMachine, delimiters=PythonStateMachine.delimiters 95 | ) 96 | steppers = python_sm.get_steppers() 97 | 98 | parts = [ 99 | "```python\n", 100 | "def test():\n", 101 | " x = 1\n", 102 | " return x\n", 103 | "\n```", 104 | ] 105 | 106 | for part in parts: 107 | steppers = python_sm.advance_all_basic(steppers, part) 108 
| assert len(steppers) > 0 109 | 110 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 111 | 112 | 113 | def test_can_accept_more_input(): 114 | """Test that the stepper can accept more input.""" 115 | python_sm = EncapsulatedStateMachine( 116 | PythonStateMachine, delimiters=PythonStateMachine.delimiters 117 | ) 118 | steppers = python_sm.get_steppers() 119 | steppers = python_sm.advance_all_basic(steppers, "```python\n") 120 | assert len(steppers) > 0 121 | steppers = python_sm.advance_all_basic(steppers, "print('Hello") 122 | assert len(steppers) > 0 123 | assert all(stepper.can_accept_more_input() for stepper in steppers) 124 | steppers = python_sm.advance_all_basic(steppers, "')") 125 | assert len(steppers) > 0 126 | assert all(stepper.can_accept_more_input() for stepper in steppers) 127 | steppers = python_sm.advance_all_basic(steppers, "\n```") 128 | assert len(steppers) > 0 129 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 130 | -------------------------------------------------------------------------------- /tests/unit/types/test_array.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | 5 | from pse.types.array import ArrayStateMachine 6 | 7 | 8 | @pytest.fixture 9 | def state_machine(): 10 | """Fixture that provides an ArrayAcceptor instance.""" 11 | return ArrayStateMachine() 12 | 13 | 14 | def parse_array(state_machine: ArrayStateMachine, json_string: str) -> list[Any]: 15 | """ 16 | Helper function to parse a JSON array string using the ArrayAcceptor. 17 | 18 | Args: 19 | state_machine (ArrayAcceptor): The ArrayAcceptor instance. 20 | json_string (str): The JSON array string to parse. 21 | 22 | Returns: 23 | list[Any]: The parsed array. 24 | 25 | Raises: 26 | AssertionError: If the JSON array is invalid. 
27 | """ 28 | steppers = state_machine.get_steppers() 29 | for char in json_string: 30 | steppers = state_machine.advance_all_basic(steppers, char) 31 | if not any(stepper.has_reached_accept_state() for stepper in steppers): 32 | raise AssertionError("No stepper in accepted state") 33 | # Assuming the first accepted stepper contains the parsed value 34 | for stepper in steppers: 35 | if stepper.has_reached_accept_state(): 36 | return stepper.get_current_value() 37 | return [] 38 | 39 | 40 | # Parameterized tests for valid arrays 41 | @pytest.mark.parametrize( 42 | "json_string, expected", 43 | [ 44 | ("[]", []), 45 | ("[1]", [1]), 46 | ("[123]", [123]), 47 | ('[123, 456, "789"]', [123, 456, "789"]), 48 | ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), 49 | ], 50 | ) 51 | def test_valid_arrays( 52 | state_machine: ArrayStateMachine, json_string: str, expected: list[Any] 53 | ): 54 | """Test parsing of valid JSON arrays.""" 55 | assert parse_array(state_machine, json_string) == expected 56 | 57 | 58 | # Parameterized tests for invalid arrays 59 | @pytest.mark.parametrize( 60 | "json_string", 61 | [ 62 | "[123, 456", # Missing closing bracket 63 | "[123, 456, ]", # Trailing comma 64 | ], 65 | ) 66 | def test_invalid_arrays(state_machine: ArrayStateMachine, json_string: str): 67 | """Test that an AssertionError is raised for invalid arrays.""" 68 | with pytest.raises(AssertionError): 69 | parse_array(state_machine, json_string) 70 | -------------------------------------------------------------------------------- /tests/unit/types/test_boolean.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pse_core.state_machine import StateMachine 3 | from pse_core.stepper import Stepper 4 | 5 | from pse.types.boolean import BooleanStateMachine 6 | 7 | 8 | # Fixture for BooleanAcceptor 9 | @pytest.fixture 10 | def boolean_acceptor(): 11 | return BooleanStateMachine() 12 | 13 | 14 | # Helper function to process input for 
BooleanAcceptor 15 | def process_input(state_machine: StateMachine, token: str) -> list[Stepper]: 16 | steppers = state_machine.get_steppers() 17 | return state_machine.advance_all_basic(steppers, token) 18 | 19 | 20 | # Test for BooleanAcceptor 21 | def test_accept_true() -> None: 22 | acceptor = BooleanStateMachine() 23 | steppers = acceptor.get_steppers() 24 | accepted_steppers = acceptor.advance_all_basic(steppers, "true") 25 | assert any(stepper.get_current_value() is True for stepper in accepted_steppers), ( 26 | "Should have a stepper with value True" 27 | ) 28 | 29 | 30 | def test_accept_false(boolean_acceptor: BooleanStateMachine) -> None: 31 | steppers = boolean_acceptor.get_steppers() 32 | accepted_steppers = boolean_acceptor.advance_all_basic(steppers, "false") 33 | assert any(stepper.get_current_value() is False for stepper in accepted_steppers), ( 34 | "Should have a stepper with value False" 35 | ) 36 | 37 | 38 | @pytest.mark.parametrize("token", ["tru", "fals", "True", "False", "TRUE", "FALSE"]) 39 | def test_reject_invalid_boolean(boolean_acceptor, token): 40 | accepted_steppers = list(process_input(boolean_acceptor, token)) 41 | assert not any( 42 | stepper.has_reached_accept_state() for stepper in accepted_steppers 43 | ), f"Should not accept '{token}' as a valid boolean." 44 | 45 | 46 | @pytest.mark.parametrize("token", [" true", "false ", " true ", " false "]) 47 | def test_accept_with_whitespace(boolean_acceptor, token): 48 | accepted_steppers = list(process_input(boolean_acceptor, token)) 49 | assert not any( 50 | stepper.has_reached_accept_state() for stepper in accepted_steppers 51 | ), f"Should not accept '{token}' with whitespace." 
52 | 53 | 54 | @pytest.mark.parametrize("token", ["truex", "falsey", "true123", "false!"]) 55 | def test_extra_characters(boolean_acceptor, token): 56 | accepted_steppers = list(process_input(boolean_acceptor, token)) 57 | assert not any( 58 | stepper.has_reached_accept_state() for stepper in accepted_steppers 59 | ), f"Should not accept '{token}' with extra characters." 60 | -------------------------------------------------------------------------------- /tests/unit/types/test_enum.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pse.types.enum import EnumStateMachine 4 | 5 | 6 | def test_accept_valid_enum_value(): 7 | """Test that the state machine correctly accepts a value present in the enum.""" 8 | sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False) 9 | steppers = sm.get_steppers() 10 | steppers = sm.advance_all_basic(steppers, "value1") 11 | 12 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 13 | for stepper in steppers: 14 | if stepper.has_reached_accept_state(): 15 | assert stepper.get_current_value() == "value1" 16 | 17 | 18 | def test_reject_invalid_enum_value(): 19 | """Test that the state machine correctly rejects a value not present in the enum.""" 20 | sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False) 21 | steppers = sm.get_steppers() 22 | steppers = sm.advance_all_basic(steppers, "invalid_value") 23 | 24 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 25 | 26 | 27 | @pytest.mark.parametrize("value", ["value1", "value2", "value3"]) 28 | def test_accept_multiple_enum_values(value): 29 | """Test that the state machine correctly accepts multiple different valid enum values.""" 30 | sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False) 31 | steppers = sm.get_steppers() 32 | steppers = sm.advance_all_basic(steppers, value) 33 | 34 | assert 
any(stepper.has_reached_accept_state() for stepper in steppers) 35 | for stepper in steppers: 36 | if stepper.has_reached_accept_state(): 37 | assert stepper.get_current_value() == value 38 | 39 | 40 | def test_partial_enum_value_rejection(): 41 | """Test that the state machine does not accept prefixes of valid enum values.""" 42 | sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False) 43 | steppers = sm.get_steppers() 44 | steppers = sm.advance_all_basic(steppers, "val") 45 | 46 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 47 | 48 | 49 | def test_init_with_empty_enum(): 50 | """Test initializing EnumStateMachine with empty enum values raises ValueError.""" 51 | with pytest.raises(ValueError): 52 | EnumStateMachine(enum_values=[]) 53 | 54 | 55 | @pytest.mark.parametrize("special_value", ["val!@#", "val-123", "val_😊"]) 56 | def test_accept_enum_with_special_characters(special_value): 57 | """Test that the state machine correctly handles enum values with special characters.""" 58 | sm = EnumStateMachine([special_value], require_quotes=False) 59 | steppers = sm.get_steppers() 60 | steppers = sm.advance_all_basic(steppers, special_value) 61 | 62 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 63 | for stepper in steppers: 64 | if stepper.has_reached_accept_state(): 65 | assert stepper.get_current_value() == special_value 66 | 67 | 68 | def test_char_by_char_enum_parsing(): 69 | """Test parsing enum values character by character.""" 70 | sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False) 71 | steppers = sm.get_steppers() 72 | 73 | for char in "value1": 74 | steppers = sm.advance_all_basic(steppers, char) 75 | if not steppers: 76 | break 77 | 78 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 79 | for stepper in steppers: 80 | if stepper.has_reached_accept_state(): 81 | assert stepper.get_current_value() == "value1" 82 | 83 | 84 | 
# NOTE(review): this region is a flattened repository digest. Each source line
# is prefixed with its "N |" number, and the region covers the end of
# tests/unit/types/test_enum.py, tests/unit/types/test_key_value.py,
# tests/unit/types/xml/test_xml_encapsulated.py and
# tests/unit/types/xml/test_xml_tag.py.
# NOTE(review): in the two XML test files, text that looked like markup
# ("<...>") appears to have been stripped during extraction. As a result,
# several string literals are split across lines or empty here. For example,
# test_incremental_parsing's `parts` reads ["", "hel", "lo", ""], and
# test_xml_tag.py jumps from line 46 to line 55 mid-statement. Treat this copy
# as lossy and regenerate it from the repository rather than editing it.
# NOTE(review): test_various_content_scenarios' f-string input_sequence shows
# no closing tag in this copy; presumably "</{tag_name}>" was stripped as
# well. Confirm against the repository.
# NOTE(review): test_enum_with_quotes asserts only inside
# `for stepper in steppers`, so it passes vacuously if no stepper survives.
# An `assert steppers` / `any(...)` guard before the loop would close that gap.
# NOTE(review): test_various_content_scenarios takes a `should_accept`
# parameter but never reads it; every case asserts acceptance.
# NOTE(review): test_property_parsing expects value None for the input
# '"": "empty_key"', even though the value text is "empty_key". This looks
# deliberate but is worth confirming.
# NOTE(review): test_key_value_equality builds sm3 exactly like sm1/sm2, so
# its three assertions exercise the same property.
@pytest.mark.parametrize( 85 | "value", 86 | [ 87 | '"test"', 88 | "'test'", 89 | ], 90 | ) 91 | def test_enum_with_quotes(value): 92 | """Test enum values with quotes requirement (default behavior).""" 93 | sm = EnumStateMachine(["test"]) # require_quotes defaults to True 94 | steppers = sm.get_steppers() 95 | steppers = sm.advance_all_basic(steppers, value) 96 | 97 | for stepper in steppers: 98 | assert stepper.has_reached_accept_state() 99 | assert stepper.get_current_value() == "test" 100 | 101 | 102 | def test_enum_without_quotes(): 103 | """Test enum values without quotes requirement.""" 104 | sm = EnumStateMachine(["test"], require_quotes=False) 105 | steppers = sm.get_steppers() 106 | steppers = sm.advance_all_basic(steppers, "test") 107 | 108 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 109 | for stepper in steppers: 110 | if stepper.has_reached_accept_state(): 111 | assert stepper.get_current_value() == "test" 112 | 113 | 114 | def test_enum_requires_quotes_by_default(): 115 | """Test that enum values require quotes by default.""" 116 | sm = EnumStateMachine(["test"]) # require_quotes defaults to True 117 | steppers = sm.get_steppers() 118 | steppers = sm.advance_all_basic(steppers, "test") # no quotes 119 | 120 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 121 | -------------------------------------------------------------------------------- /tests/unit/types/test_key_value.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pse.types.base.phrase import PhraseStateMachine 4 | from pse.types.key_value import KeyValueStateMachine 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "input_string, expected_name, expected_value", 9 | [ 10 | ('"key": "value"', "key", "value"), 11 | ('"complex_key": {"nested": "value"}', "complex_key", {"nested": "value"}), 12 | ('"": "empty_key"', "", None), 13 | ('"unicode_key": "unicode_value🎉"', "unicode_key", "
unicode_value🎉"), 14 | ('"spaced_key" : "spaced_value"', "spaced_key", "spaced_value"), 15 | ], 16 | ) 17 | def test_property_parsing(input_string, expected_name, expected_value): 18 | sm = KeyValueStateMachine() 19 | steppers = sm.get_steppers() 20 | for char in input_string: 21 | steppers = sm.advance_all_basic(steppers, char) 22 | 23 | accepted_steppers = [ 24 | stepper for stepper in steppers if stepper.has_reached_accept_state() 25 | ] 26 | assert accepted_steppers, ( 27 | f"No stepper reached an accepted state for: {input_string}" 28 | ) 29 | 30 | for stepper in accepted_steppers: 31 | name, value = stepper.get_current_value() 32 | assert name == expected_name 33 | assert value == expected_value 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "invalid_input", 38 | [ 39 | 'key: "value"', # missing quotes around key 40 | '"key" "value"', # missing colon 41 | '"key":', # missing value 42 | '"key": value', # unquoted value 43 | ':"value"', # missing key 44 | ], 45 | ) 46 | def test_invalid_property_formats(invalid_input): 47 | sm = KeyValueStateMachine() 48 | steppers = sm.get_steppers() 49 | for char in invalid_input: 50 | steppers = sm.advance_all_basic(steppers, char) 51 | 52 | assert not any(stepper.has_reached_accept_state() for stepper in steppers), ( 53 | f"Stepper should not reach accepted state for invalid input: {invalid_input}" 54 | ) 55 | 56 | 57 | def test_empty_key_value(): 58 | sm = KeyValueStateMachine() 59 | steppers = sm.get_steppers() 60 | assert len(steppers) == 1 61 | assert not steppers[0].should_complete_step() 62 | assert steppers[0].get_current_value() == ("", None) 63 | 64 | 65 | def test_invalid_sub_stepper_json(): 66 | sm = KeyValueStateMachine() 67 | steppers = sm.get_steppers() 68 | assert len(steppers) == 1 69 | stepper = steppers[0] 70 | stepper.sub_stepper = PhraseStateMachine('"invalid json').get_new_stepper() 71 | steppers = sm.advance_all_basic(steppers, '"invalid json') 72 | assert len(steppers) == 1 73 | assert not
steppers[0].should_complete_step() 74 | 75 | def test_key_value_equality(): 76 | sm1 = KeyValueStateMachine() 77 | sm2 = KeyValueStateMachine() 78 | sm3 = KeyValueStateMachine() 79 | 80 | assert sm1 == sm2 81 | assert sm1 == sm3 82 | assert sm2 == sm3 83 | -------------------------------------------------------------------------------- /tests/unit/types/xml/test_xml_encapsulated.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from pse.types.base.phrase import PhraseStateMachine 6 | from pse.types.xml.xml_encapsulated import XMLEncapsulatedStateMachine 7 | 8 | 9 | def test_basic_wrapped_content() -> None: 10 | """Test recognition of basic wrapped content.""" 11 | inner_sm = PhraseStateMachine("content") 12 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div") 13 | steppers = wrapped_sm.get_steppers() 14 | 15 | input_sequence = "
content
" 16 | steppers = wrapped_sm.advance_all_basic(steppers, input_sequence) 17 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 18 | 19 | 20 | def test_incremental_parsing() -> None: 21 | """Test incremental parsing of wrapped content.""" 22 | inner_sm = PhraseStateMachine("hello") 23 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "greeting") 24 | steppers = wrapped_sm.get_steppers() 25 | 26 | parts = ["", "hel", "lo", ""] 27 | 28 | for part in parts: 29 | steppers = wrapped_sm.advance_all_basic(steppers, part) 30 | assert len(steppers) > 0 31 | 32 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 33 | 34 | 35 | def test_nested_wrapped_content() -> None: 36 | """Test nested XML wrapped content.""" 37 | inner_content = PhraseStateMachine("text") 38 | inner_wrapped = XMLEncapsulatedStateMachine(inner_content, "span") 39 | outer_wrapped = XMLEncapsulatedStateMachine(inner_wrapped, "div") 40 | 41 | input_sequence = "
text
" 42 | steppers = outer_wrapped.get_steppers() 43 | steppers = outer_wrapped.advance_all_basic(steppers, input_sequence) 44 | 45 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 46 | 47 | 48 | def test_min_buffer_length() -> None: 49 | """Test that min_buffer_length is respected.""" 50 | inner_sm = PhraseStateMachine("content") 51 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div", min_buffer_length=10) 52 | steppers = wrapped_sm.get_steppers() 53 | 54 | # Should not start with opening tag yet 55 | assert not any(stepper.should_start_step("
") for stepper in steppers) 56 | 57 | # Add sufficient buffer 58 | buffer_content = "x" * 10 59 | steppers = wrapped_sm.advance_all_basic(steppers, buffer_content) 60 | assert any(stepper.should_start_step("
") for stepper in steppers) 61 | 62 | 63 | def test_invalid_content() -> None: 64 | """Test rejection of invalid inner content.""" 65 | inner_sm = PhraseStateMachine("expected") 66 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div") 67 | steppers = wrapped_sm.get_steppers() 68 | 69 | input_sequence = "
unexpected
" 70 | steppers = wrapped_sm.advance_all_basic(steppers, input_sequence) 71 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 72 | 73 | 74 | def test_missing_closing_tag() -> None: 75 | """Test rejection when closing tag is missing.""" 76 | inner_sm = PhraseStateMachine("content") 77 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div") 78 | steppers = wrapped_sm.get_steppers() 79 | 80 | input_sequence = "
content" 81 | steppers = wrapped_sm.advance_all_basic(steppers, input_sequence) 82 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 83 | 84 | 85 | def test_mismatched_tags() -> None: 86 | """Test rejection of mismatched opening and closing tags.""" 87 | inner_sm = PhraseStateMachine("content") 88 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div") 89 | steppers = wrapped_sm.get_steppers() 90 | 91 | input_sequence = "
content" 92 | steppers = wrapped_sm.advance_all_basic(steppers, input_sequence) 93 | assert not any(stepper.has_reached_accept_state() for stepper in steppers) 94 | 95 | 96 | @pytest.mark.parametrize( 97 | "tag_name, content, should_accept", 98 | [ 99 | ("div", "content", True), 100 | ("p", "multi\nline\ncontent", True), 101 | ("custom-tag", "content with spaces", True), 102 | ], 103 | ) 104 | def test_various_content_scenarios( 105 | tag_name: str, content: str, should_accept: bool 106 | ) -> None: 107 | """Test various content scenarios.""" 108 | inner_sm = PhraseStateMachine(content) 109 | wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, tag_name) 110 | steppers = wrapped_sm.get_steppers() 111 | 112 | input_sequence = f"<{tag_name}>{content}" 113 | steppers = wrapped_sm.advance_all_basic(steppers, input_sequence) 114 | assert any(stepper.has_reached_accept_state() for stepper in steppers) 115 | -------------------------------------------------------------------------------- /tests/unit/types/xml/test_xml_tag.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pse.types.base.chain import ChainStateMachine 4 | from pse.types.xml.xml_tag import XMLTagStateMachine 5 | 6 | 7 | def test_basic_tag_initialization() -> None: 8 | """Test basic initialization of XMLTagStateMachine.""" 9 | tag_machine = XMLTagStateMachine("div") 10 | assert isinstance(tag_machine, ChainStateMachine) 11 | steppers = tag_machine.get_steppers() 12 | assert len(steppers) == 1 13 | assert not steppers[0].has_reached_accept_state() 14 | 15 | 16 | def test_closing_tag_initialization() -> None: 17 | """Test initialization of closing XMLTagStateMachine.""" 18 | tag_machine = XMLTagStateMachine("div", closing_tag=True) 19 | steppers = tag_machine.get_steppers() 20 | assert len(steppers) == 1 21 | assert not steppers[0].has_reached_accept_state() 22 | 23 | 24 | def test_basic_tag_recognition() -> None: 25 | """Test recognition of a
basic XML tag.""" 26 | tag_machine = XMLTagStateMachine("div") 27 | steppers = tag_machine.get_steppers() 28 | steppers = tag_machine.advance_all_basic(steppers, "
") 29 | assert len(steppers) == 1 30 | assert steppers[0].has_reached_accept_state() 31 | 32 | 33 | def test_closing_tag_recognition() -> None: 34 | """Test recognition of a closing XML tag.""" 35 | tag_machine = XMLTagStateMachine("div", closing_tag=True) 36 | steppers = tag_machine.get_steppers() 37 | steppers = tag_machine.advance_all_basic(steppers, "
") 38 | assert len(steppers) == 1 39 | assert steppers[0].has_reached_accept_state() 40 | 41 | 42 | def test_partial_tag_recognition() -> None: 43 | """Test partial tag recognition behavior.""" 44 | tag_machine = XMLTagStateMachine("div") 45 | steppers = tag_machine.get_steppers() 46 | steppers = tag_machine.advance_all_basic(steppers, "", True), 55 | ("span", "
", False), 56 | ("p", "

", False), # Case sensitivity test 57 | ("input", "", True), 58 | ("br", "
", False), # Doesn't handle self-closing tags 59 | ], 60 | ) 61 | def test_various_tag_scenarios( 62 | tag_name: str, input_text: str, should_accept: bool 63 | ) -> None: 64 | """Test various tag scenarios including invalid ones.""" 65 | machine = XMLTagStateMachine(tag_name) 66 | steppers = machine.get_steppers() 67 | steppers = machine.advance_all_basic(steppers, input_text) 68 | 69 | if should_accept: 70 | assert len(steppers) == 1 71 | assert steppers[0].has_reached_accept_state() 72 | else: 73 | assert not any(s.has_reached_accept_state() for s in steppers) 74 | 75 | 76 | def test_empty_tag_name() -> None: 77 | """Test that empty tag names are not allowed.""" 78 | with pytest.raises(ValueError): 79 | XMLTagStateMachine("") 80 | 81 | 82 | def test_whitespace_handling() -> None: 83 | """Test that whitespace is not allowed within tags.""" 84 | tag_machine = XMLTagStateMachine("div") 85 | steppers = tag_machine.get_steppers() 86 | steppers = tag_machine.advance_all_basic(steppers, "< div>") 87 | assert not any(s.has_reached_accept_state() for s in steppers) 88 | 89 | 90 | def test_get_valid_continuations() -> None: 91 | """Test the get_valid_continuations method.""" 92 | tag_machine = XMLTagStateMachine("div") 93 | stepper = tag_machine.get_new_stepper() 94 | # It's a chain, so it should check the ultimate result. 95 | assert stepper.get_valid_continuations() == ["

"] 96 | 97 | tag_machine = XMLTagStateMachine("span", closing_tag=True) 98 | stepper = tag_machine.get_new_stepper() 99 | assert stepper.get_valid_continuations() == [""] 100 | 101 | # Check that it works even if we advance the stepper. 102 | tag_machine = XMLTagStateMachine("div") 103 | steppers = tag_machine.get_steppers() 104 | steppers = tag_machine.advance_all_basic(steppers, "<") 105 | assert len(steppers) == 1 106 | assert steppers[0].get_valid_continuations() == ["
"] 107 | # And again. 108 | steppers = tag_machine.advance_all_basic(steppers, "d") 109 | assert len(steppers) == 1 110 | assert steppers[0].get_valid_continuations() == ["
"] 111 | 112 | # Finally check it returns nothing if we've gone too far. 113 | steppers = tag_machine.advance_all_basic(steppers, "i") 114 | steppers = tag_machine.advance_all_basic(steppers, "v") 115 | steppers = tag_machine.advance_all_basic(steppers, ">") 116 | assert len(steppers) == 1 117 | assert steppers[0].get_valid_continuations() == ["
"] 118 | --------------------------------------------------------------------------------