├── .coverage
├── .coveragerc
├── .github
├── FUNDING.yml
└── workflows
│ ├── python-app.yml
│ └── python-publish.yml
├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── examples
├── quickstart-mlx.py
├── quickstart-torch.py
├── simple_demo.ipynb
├── simple_demo.py
├── thinking_answer.ipynb
└── thinking_answer.py
├── logo.png
├── pse
├── structuring_engine.py
├── types
│ ├── array.py
│ ├── base
│ │ ├── any.py
│ │ ├── chain.py
│ │ ├── character.py
│ │ ├── encapsulated.py
│ │ ├── loop.py
│ │ ├── phrase.py
│ │ └── wait_for.py
│ ├── boolean.py
│ ├── enum.py
│ ├── grammar
│ │ ├── __init__.py
│ │ ├── default_grammars
│ │ │ ├── bash.lark
│ │ │ ├── bash.py
│ │ │ ├── python.lark
│ │ │ └── python.py
│ │ └── lark.py
│ ├── integer.py
│ ├── json
│ │ ├── __init__.py
│ │ ├── any_json_schema.py
│ │ ├── json_array.py
│ │ ├── json_key_value.py
│ │ ├── json_number.py
│ │ ├── json_object.py
│ │ ├── json_string.py
│ │ ├── json_value.py
│ │ └── schema_sources
│ │ │ ├── from_function.py
│ │ │ └── from_pydantic.py
│ ├── key_value.py
│ ├── misc
│ │ ├── fenced_freeform.py
│ │ └── freeform.py
│ ├── number.py
│ ├── object.py
│ ├── string.py
│ ├── whitespace.py
│ └── xml
│ │ ├── xml_encapsulated.py
│ │ └── xml_tag.py
└── util
│ ├── generate_mlx.py
│ ├── get_top_logits.py
│ ├── jax_mixin.py
│ ├── tf_mixin.py
│ └── torch_mixin.py
├── pyproject.toml
└── tests
├── functional
└── test_e2e.py
└── unit
├── test_structuring_engine.py
└── types
├── base
├── test_chain.py
├── test_character.py
├── test_encapsulated.py
├── test_loop.py
├── test_phrase.py
├── test_state_machine.py
└── test_wait_for.py
├── bash
├── test_bash_code.py
└── test_wrapped_bash.py
├── json
├── schema_sources
│ ├── test_from_function.py
│ └── test_from_pydantic.py
├── test_any_json_schema.py
├── test_json_array.py
├── test_json_key_value.py
├── test_json_number.py
├── test_json_object.py
├── test_json_string.py
├── test_json_to_state_machine.py
└── test_json_value.py
├── misc
├── test_fenced_freeform.py
└── test_freeform.py
├── python
├── test_python_code.py
└── test_wrapped_python.py
├── test_array.py
├── test_boolean.py
├── test_enum.py
├── test_integer.py
├── test_key_value.py
├── test_number.py
├── test_object.py
├── test_string.py
├── test_whitespace.py
└── xml
├── test_xml_encapsulated.py
└── test_xml_tag.py
/.coverage:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheProxyCompany/proxy-structuring-engine/1cb33d487126abc6b85a3f78833177840911d4e4/.coverage
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = pse
3 | omit =
4 | tests/*
5 | pse/util/*
6 |
7 | [report]
8 | omit =
9 | tests/*
10 | */__init__.py
11 | pse/util/*
12 | exclude_lines =
13 | def __repr__
14 | def __str__
15 | ^\s*if __name__ == ['"]__main__['"]:?
16 | pragma: no cover
17 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [TheProxyCompany]
2 | custom: ["https://www.theproxycompany.com/research#support"]
3 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Unit Tests
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | build:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v4
22 | - name: Set up Python 3.12
23 | uses: actions/setup-python@v4
24 | with:
25 | python-version: "3.12"
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install .[dev]
30 | - uses: astral-sh/ruff-action@v1
31 | with:
32 | changed-files: "true"
33 | - name: Lint with flake8
34 | run: |
35 | # stop the build if there are Python syntax errors or undefined names
36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
38 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
39 |
40 | - name: Run unit tests
41 | run: pytest tests/unit
42 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published, updated]
14 | # push:
15 | # branches: [ "main" ]
16 |
17 | permissions:
18 | contents: read
19 |
20 | jobs:
21 | deploy:
22 |
23 | runs-on: ubuntu-latest
24 |
25 | steps:
26 | - uses: actions/checkout@v4
27 | - name: Set up Python
28 | uses: actions/setup-python@v4
29 | with:
30 | python-version: '3.x'
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install build
35 | - name: Build package
36 | run: python -m build
37 | - name: Publish package
38 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
39 | with:
40 | user: __token__
41 | password: ${{ secrets.PYPI_API_TOKEN }}
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | dist
2 | __pycache__
3 | *.pyc
4 | .cursorrules
5 | .cursorignore
6 | .DS_Store
7 | .env
8 | .venv
9 | .pytest_cache
10 | Cargo.lock
11 | target
12 | .cache
13 | uv.lock
14 | */__pycache__
15 | .venv
16 | .pytest_cache
17 | CLAUDE.md
18 | .coverage.*
19 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Wind"
5 | given-names: "Jack"
6 | title: "Proxy Structuring Engine"
7 | version: 2025.06.1
8 | date-released: 2025-06-03
9 | url: "https://github.com/TheProxyCompany/proxy-structuring-engine"
10 | repository-code: "https://github.com/TheProxyCompany/proxy-structuring-engine"
11 | license: Apache-2.0
12 | publisher: "The Proxy Company"
13 |
--------------------------------------------------------------------------------
/examples/quickstart-mlx.py:
--------------------------------------------------------------------------------
"""Quickstart: constrain an MLX model's output to a JSON schema with PSE."""

import json
import logging
import sys

import mlx.core as mx
from mlx_lm.utils import generate_step, load  # type: ignore[reportMissingImports]

from pse.structuring_engine import StructuringEngine

# Send debug-level logs (including PSE's) to stdout.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

# Target structure: 1 to 3 thoughts, each at least 200 characters long.
ADVANCED_JSON_SCHEMA = {
    "type": "object",
    "description": "High-level thoughts, reasoning and internal dialogue.\n Used for step by step reasoning.",
    "properties": {
        "chain_of_thought": {
            "type": "array",
            "items": {
                "type": "string",
                "minLength": 200,
            },
            "minItems": 1,
            "maxItems": 3,
        },
    },
    "required": ["chain_of_thought"],
}

# The schema is also shown to the model so it knows what it is expected to emit.
schema_text = json.dumps(ADVANCED_JSON_SCHEMA, indent=2)
system_message = (
    "\n"
    "You are an AI that can think step by step.\n"
    "You must follow this schema when generating your response:\n"
    f"{schema_text}\n"
)

prompt = "This is a test - I want to see your private internal monologue."
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": prompt},
]

model_path_hf = "meta-llama/Llama-3.2-3B-Instruct"
model, tokenizer = load(model_path_hf)

# PSE is given the object wrapped by mlx_lm's tokenizer (presumably the
# underlying Hugging Face tokenizer — confirm against mlx_lm).
engine = StructuringEngine(tokenizer._tokenizer)  # noqa: SLF001
engine.configure(ADVANCED_JSON_SCHEMA)

token_ids = engine.tokenizer.apply_chat_template(
    conversation=messages,
    add_generation_prompt=True,
)


def constrained_sampler(logits):
    """Let the engine choose the next token, using greedy argmax as the base sampler."""
    return engine.sample(logits, mx.argmax)


# max_tokens=-1 is used as "no length cap"; the only exit from the loop below
# is the engine reporting that the output has reached an accept state.
generation = generate_step(
    prompt=mx.array(token_ids),
    model=model,
    logits_processors=[engine.process_logits],
    sampler=constrained_sampler,  # type: ignore [arg-type]
    max_tokens=-1,
)
for token, _ in generation:
    token_ids.append(token)  # type: ignore [attr-defined]
    if engine.has_reached_accept_state:
        break

output = engine.get_structured_output()
print(json.dumps(output, indent=2))
--------------------------------------------------------------------------------
/examples/quickstart-torch.py:
--------------------------------------------------------------------------------
"""Quickstart: generate a Pydantic-validated object with a PyTorch Llama model and PSE."""

import json

import torch  # type: ignore[reportMissingImports]
from pydantic import BaseModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from pse.structuring_engine import StructuringEngine
from pse.util.torch_mixin import PSETorchMixin


# 1. Define your desired output structure using Pydantic
class Product(BaseModel):
    name: str
    price: float
    description: str | None = None


# 2. Load your model and tokenizer. Apply the PSE mixin.
class PSE_Torch(PSETorchMixin, LlamaForCausalLM):
    pass


# Any Llama-architecture model works here (PSE_Torch subclasses LlamaForCausalLM).
model_path = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = PSE_Torch.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)  # Load to GPU, if available

# Ensure padding token is set for generation.
# Some model configs store a list of EOS ids, others a single int; indexing an
# int with [0] would raise TypeError, so normalise to the first id.
eos_token_id = model.config.eos_token_id
if isinstance(eos_token_id, (list, tuple)):
    eos_token_id = eos_token_id[0]
model.config.pad_token_id = eos_token_id
if model.generation_config:
    model.generation_config.pad_token_id = eos_token_id

# 3. Create a StructuringEngine and configure it with your schema
model.engine = StructuringEngine(tokenizer)
model.engine.configure(Product)

# 4. Create your prompt. Include the schema for the LLM's context.
prompt = f"""
You are a product catalog assistant. Create a product description in JSON format,
following this schema:

{Product.model_json_schema()}

Create a product description for a new type of noise-cancelling headphones.
"""

messages = [{"role": "user", "content": prompt}]
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
)
# 5. Generate!
assert isinstance(input_ids, torch.Tensor)
input_ids = input_ids.to(model.device)
assert isinstance(input_ids, torch.Tensor)
output = model.generate(
    input_ids,
    do_sample=True,
    max_new_tokens=200,
    top_k=10,
    top_p=None,
)
# 6. Parse the structured output (read back from the engine, not from `output`)
structured_output = model.engine.get_structured_output(Product)
print(json.dumps(structured_output.model_dump(), indent=2))
--------------------------------------------------------------------------------
/examples/simple_demo.py:
--------------------------------------------------------------------------------
"""Demo: configure the PSE StructuringEngine three ways on a Llama model.

1. A simple JSON schema (``dict``).
2. A JSON schema with array/string length constraints.
3. A Pydantic model.
"""

import json
import logging

import torch  # type: ignore[reportMissingImports]
from pydantic import BaseModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from pse.structuring_engine import StructuringEngine
from pse.util.torch_mixin import PSETorchMixin

# PSE debug logs are shown; raise this to logging.INFO (or higher) to hide them.
logging.basicConfig(level=logging.DEBUG)


class PSE_Torch(PSETorchMixin, LlamaForCausalLM):
    pass


model_path = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = PSE_Torch.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Some model configs store a list of EOS ids, others a single int; indexing an
# int with [0] would raise TypeError, so normalise to the first id.
eos_token_id = model.config.eos_token_id
if isinstance(eos_token_id, (list, tuple)):
    eos_token_id = eos_token_id[0]
model.config.pad_token_id = eos_token_id
if model.generation_config:
    model.generation_config.top_p = None
    model.generation_config.top_k = 8
    model.generation_config.do_sample = True
    model.generation_config.temperature = 1.0
    model.generation_config.max_new_tokens = 1000
    model.generation_config.pad_token_id = eos_token_id

# create structuring engine normally
model.engine = StructuringEngine(tokenizer, multi_token_sampling=True)

# --- Example 1: a simple JSON schema -------------------------------------------
SIMPLE_JSON_SCHEMA = {
    "type": "object",
    "properties": {"value": {"type": "number"}},
    "required": ["value"],
}
model.engine.configure(SIMPLE_JSON_SCHEMA)
prompt = (
    "Please generate a json object with the value 9.11, with the following schema:\n"
)
prompt += json.dumps(SIMPLE_JSON_SCHEMA, indent=2)

messages = [{"role": "user", "content": prompt}]
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
)
assert isinstance(input_ids, torch.Tensor)
input_ids = input_ids.to(model.device)
assert isinstance(input_ids, torch.Tensor)
output = model.generate(
    input_ids,
    do_sample=True,
)
# you can print the prompt + output:
# print(tokenizer.decode(output[0]))
structured_output = model.engine.get_structured_output()
print(100 * "-")
print(json.dumps(structured_output, indent=2))

# --- Example 2: a schema with array and string length constraints -------------
ADVANCED_JSON_SCHEMA = {
    "type": "object",
    "description": "High-level thoughts, reasoning and internal dialogue.\n Used for step by step reasoning.",
    "properties": {
        "chain_of_thought": {
            "type": "array",
            "items": {
                "type": "string",
                "minLength": 20,  # minimum length of a thought (optional)
            },
            "minItems": 1,  # floor the number of thoughts (optional)
            "maxItems": 3,  # limit the number of thoughts (optional)
        },
    },
    "required": ["chain_of_thought"],
}
model.engine.configure(ADVANCED_JSON_SCHEMA)
raw_prompt = (
    "This is a test of your thought process.\n"
    "I want to see your private internal monologue.\n"
    f"Please follow the following schema when generating your response:\n{json.dumps(ADVANCED_JSON_SCHEMA, indent=2)}\n"
)
messages = [{"role": "user", "content": raw_prompt}]
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
)
assert isinstance(input_ids, torch.Tensor)
input_ids = input_ids.to(model.device)
assert isinstance(input_ids, torch.Tensor)
# No explicit sampling arguments: the model.generation_config settings above
# apply (do_sample=True), so this is sampled rather than greedy decoding.
thought_output = model.generate(input_ids)
structured_output = model.engine.get_structured_output()
print(100 * "-")
print(json.dumps(structured_output, indent=2))


# --- Example 3: a Pydantic model ------------------------------------------------
class CursorPositionModel(BaseModel):
    """
    An object representing the position and click state of a cursor.

    Attributes:
        x_pos: The horizontal position of the cursor in pixels
        y_pos: The vertical position of the cursor in pixels
        left_click: Whether the left mouse button is currently pressed. Default is False.
    """

    x_pos: int
    y_pos: int
    left_click: bool = False


model.engine.configure(
    CursorPositionModel, delimiters=("", "")
)
prompt = (
    "Please use the following schema to generate a cursor position:\n"
    f"{json.dumps(CursorPositionModel.model_json_schema(), indent=2)}.\n"
    "Pretend to move the cursor to x = 100 and y = 100, with the left mouse button clicked.\n"
    "Wrap your response in CursorPositionModel."
)
messages = [{"role": "user", "content": prompt}]
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
)
assert isinstance(input_ids, torch.Tensor)
input_ids = input_ids.to(model.device)
assert isinstance(input_ids, torch.Tensor)
output = model.generate(
    input_ids,
    do_sample=True,
)
structured_output = model.engine.get_structured_output(CursorPositionModel)
print(100 * "-")
print(json.dumps(structured_output.model_dump(), indent=2))
--------------------------------------------------------------------------------
/examples/thinking_answer.py:
--------------------------------------------------------------------------------
"""Demo: enforce a "think, then answer" flow with custom PSE state machines."""

import logging

import torch  # type: ignore[reportMissingImports]
from pse_core.state_machine import StateMachine
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from pse.structuring_engine import StructuringEngine
from pse.types.base.loop import LoopStateMachine
from pse.types.misc.fenced_freeform import FencedFreeformStateMachine
from pse.util.torch_mixin import PSETorchMixin

# PSE debug logs are shown; raise this to logging.INFO (or higher) to hide them.
logging.basicConfig(level=logging.DEBUG)


class PSE_Torch(PSETorchMixin, LlamaForCausalLM):
    pass


# you can change the model path to any other Llama-architecture model on
# huggingface (PSE_Torch subclasses LlamaForCausalLM)
model_path = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = PSE_Torch.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Some model configs store a list of EOS ids, others a single int; indexing an
# int with [0] would raise TypeError, so normalise to the first id.
eos_token_id = model.config.eos_token_id
if isinstance(eos_token_id, (list, tuple)):
    eos_token_id = eos_token_id[0]
model.config.pad_token_id = eos_token_id
if model.generation_config:
    model.generation_config.top_p = None
    model.generation_config.top_k = 8
    model.generation_config.do_sample = True
    model.generation_config.temperature = 1.0
    model.generation_config.max_new_tokens = 1000
    model.generation_config.pad_token_id = eos_token_id

# create structuring engine normally
model.engine = StructuringEngine(tokenizer, multi_token_sampling=True)

# define custom state machines
thinking_delimiters = ("[thinking]", "[/thinking]")
answer_delimiters = ("[answer]", "[/answer]")

# encapsulated state machines are used to allow a language model
# to generate unstructured content before the structured output
# starts. This "scratchpad" is disabled by default (min_buffer_length=-1)
thinking_state_machine = FencedFreeformStateMachine("thinking", thinking_delimiters)
# the answer state machine is used to wrap the structured output
answer_state_machine = FencedFreeformStateMachine("answer", answer_delimiters)
# Configure the engine with a state machine that enforces the following flow:
#
# The model starts in the 'thinking' state where it can express its reasoning
# (one or two fenced blocks, per min_loop_count/max_loop_count below).
# From there, it can transition to providing its final answer.
#
#      ┌──────────────┐
#      │              │
#      ▼              │
# ┌──────────┐        │
# │          │        │
# │ thinking ├────────┘
# │          │
# └──────┬───┘
#        │
#        ▼
# ┌──────────┐
# │          │
# │  answer  │
# │          │
# └──────────┘
#
# This ensures the model follows a structured thought process before
# providing its final answer.

model.engine.configure(
    StateMachine(
        {
            "thinking": [
                (
                    LoopStateMachine(
                        thinking_state_machine,
                        min_loop_count=1,
                        max_loop_count=2,
                    ),
                    "answer",
                )
            ],
            "answer": [
                (
                    answer_state_machine,
                    "done",
                ),
            ],
        },
        start_state="thinking",
        end_states=["done"],
    )
)

system_prompt = (
    "Reason step by step using delimiters to separate your thought process.\n"
    "For example, when asked a question, you should think and then answer.\n"
    "Example:\n"
    f"{thinking_delimiters[0]}Thinking goes here{thinking_delimiters[1]}"
    f"{answer_delimiters[0]}Answer goes here{answer_delimiters[1]}\n"
    "you can think multiple times before providing your answer.\n\n"
)
prompt = "Please pick a favorite color. Think about it first."

input_ids = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ],
    return_tensors="pt",
    add_generation_prompt=True,
)
assert isinstance(input_ids, torch.Tensor)
input_ids = input_ids.to(model.device)
assert isinstance(input_ids, torch.Tensor)
output = model.generate(input_ids)

# Print each labeled segment the engine captured, wrapped in its label.
for label, text in model.engine.get_labeled_output():
    print("-" * 100)
    print(f"[{label}]")
    print(text)
    print(f"[/{label}]")
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheProxyCompany/proxy-structuring-engine/1cb33d487126abc6b85a3f78833177840911d4e4/logo.png
--------------------------------------------------------------------------------
/pse/types/array.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from pse_core import StateGraph, StateId
6 | from pse_core.state_machine import StateMachine
7 | from pse_core.stepper import Stepper
8 |
9 | from pse.types.base.chain import ChainStateMachine
10 | from pse.types.base.phrase import PhraseStateMachine
11 | from pse.types.json.json_value import JsonStateMachine
12 | from pse.types.whitespace import WhitespaceStateMachine
13 |
14 |
class ArrayStateMachine(StateMachine):
    """
    State machine that accepts a well-formed JSON array.

    Default state graph ("$" is the end state):

        0 --"["--------------------------> 1
        1 --whitespace-------------------> 2
        1 --"]"--------------------------> $   (empty array)
        2 --JsonStateMachine-------------> 3
        3 --whitespace-------------------> 4
        4 --"," followed by whitespace---> 2
        4 --"]"--------------------------> $

    A truthy ``state_graph`` passed by the caller replaces this default.
    """

    def __init__(self, state_graph: StateGraph | None = None) -> None:
        default_graph: StateGraph = {
            0: [(PhraseStateMachine("["), 1)],
            1: [
                (WhitespaceStateMachine(), 2),
                (PhraseStateMachine("]"), "$"),  # "[]" closes immediately
            ],
            2: [(JsonStateMachine(), 3)],
            3: [(WhitespaceStateMachine(), 4)],
            4: [
                (
                    ChainStateMachine(
                        [PhraseStateMachine(","), WhitespaceStateMachine()]
                    ),
                    2,
                ),
                (PhraseStateMachine("]"), "$"),
            ],
        }
        # None and an empty mapping both fall back to the default graph.
        chosen_graph = state_graph if state_graph else default_graph
        super().__init__(chosen_graph)

    def get_new_stepper(self, state: StateId | None = None) -> ArrayStepper:
        """Create an ArrayStepper for this machine; ``state`` is forwarded as its starting state."""
        return ArrayStepper(self, state)

    def __str__(self) -> str:
        return "Array"
49 |
50 |
class ArrayStepper(Stepper):
    """
    Stepper for ArrayStateMachine that collects the values of parsed elements.
    """

    def __init__(
        self,
        state_machine: ArrayStateMachine,
        current_state: StateId | None = None,
    ):
        super().__init__(state_machine, current_state)
        self.state_machine: ArrayStateMachine = state_machine
        # Element values collected so far, in the order they were recorded.
        self.value: list[Any] = []

    def clone(self) -> ArrayStepper:
        """
        Clone this stepper.

        The element list is shallow-copied so appends on one branch do not
        leak into the other.
        """
        cloned_stepper = super().clone()
        cloned_stepper.value = self.value[:]
        return cloned_stepper

    def is_within_value(self) -> bool:
        """
        Return True when the stepper is in state 3.

        In the default graph, state 3 is reached by the edge that consumes a
        JSON value (2 --JsonStateMachine--> 3).
        """
        return self.current_state == 3

    def add_to_history(self, stepper: Stepper) -> None:
        """
        Record a completed sub-stepper.

        If this stepper is in state 3 at the time, the sub-stepper's current
        value is appended to ``self.value`` as an array element.
        """
        if self.is_within_value():
            self.value.append(stepper.get_current_value())
        super().add_to_history(stepper)

    def get_current_value(self) -> list[Any]:
        """
        Get the elements of the JSON array parsed so far.

        Returns:
            list[Any]: The accumulated element values, in order (the same list
            object the stepper appends to, not a copy).
        """
        return self.value
82 |
--------------------------------------------------------------------------------
/pse/types/base/any.py:
--------------------------------------------------------------------------------
1 | from pse_core import StateId
2 | from pse_core.state_machine import StateMachine
3 | from pse_core.stepper import Stepper
4 |
5 |
class AnyStateMachine(StateMachine):
    """
    Accept input that matches any one of several alternative state machines.

    Each alternative becomes an edge from state 0 directly to the end state "$".
    """

    def __init__(self, state_machines: list[StateMachine]) -> None:
        self.state_machines: list[StateMachine] = state_machines
        alternatives = [(machine, "$") for machine in self.state_machines]
        super().__init__({0: alternatives})

    def get_steppers(self, state: StateId | None = None) -> list[Stepper]:
        """
        Gather the starting steppers of every alternative leaving ``state``.

        A falsy ``state`` (including None) is treated as state 0.
        """
        origin = state or 0
        return [
            stepper
            for machine, _ in self.get_edges(origin)
            for stepper in machine.get_steppers()
        ]

    def __str__(self) -> str:
        return "Any"
27 |
--------------------------------------------------------------------------------
/pse/types/base/chain.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 |
5 | from pse_core.state_machine import StateMachine
6 | from pse_core.stepper import Stepper
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
class ChainStateMachine(StateMachine):
    """
    Chain multiple StateMachines so that they must match one after another.

    State ``i`` has a single edge, through ``state_machines[i]``, to state
    ``i + 1``; state ``len(state_machines)`` is the only end state.
    """

    def __init__(self, state_machines: list[StateMachine], is_optional: bool = False) -> None:
        """
        Args:
            state_machines: State machines to be chained, in matching order.
            is_optional: Forwarded unchanged to ``StateMachine.__init__``.
        """
        graph = {}
        for index, machine in enumerate(state_machines):
            graph[index] = [(machine, index + 1)]
        super().__init__(
            state_graph=graph,
            end_states=[len(state_machines)],
            is_optional=is_optional,
        )

    def get_new_stepper(self, state: int | str | None = None) -> Stepper:
        """Create a ChainStepper for this chain, starting at ``state``."""
        return ChainStepper(self, state)

    def __str__(self) -> str:
        return "Chain"
35 |
36 |
class ChainStepper(Stepper):
    """
    Stepper that walks through a ChainStateMachine's machines in sequence.
    """

    def __init__(self, chain_state_machine: ChainStateMachine, *args, **kwargs) -> None:
        # All extra arguments (e.g. the starting state) go straight to Stepper.
        super().__init__(chain_state_machine, *args, **kwargs)
        # Re-bound here after the base initialiser runs, keeping the reference
        # typed as the chain machine.
        self.state_machine = chain_state_machine
45 |
--------------------------------------------------------------------------------
/pse/types/base/character.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Iterable
4 |
5 | from pse_core.state_machine import StateMachine
6 | from pse_core.stepper import Stepper
7 |
8 |
class CharacterStateMachine(StateMachine):
    """
    Accepts one or more valid characters.
    """

    def __init__(
        self,
        whitelist_charset: str | list[str] | Iterable[str] = "",
        graylist_charset: str | list[str] | Iterable[str] = "",
        blacklist_charset: str | list[str] | Iterable[str] = "",
        char_min: int | None = None,
        char_limit: int | None = None,
        is_optional: bool = False,
        case_sensitive: bool = True,
    ) -> None:
        """
        Initialize a CharacterStateMachine with character sets and constraints.

        Args:
            whitelist_charset: Characters that are explicitly allowed (any
                non-blacklisted character is allowed when empty)
            graylist_charset: Characters that are allowed but terminate the match if they follow other characters
            blacklist_charset: Characters that are explicitly forbidden
            char_min: Minimum number of characters required (0 if None)
            char_limit: Maximum number of characters allowed (unlimited if 0 or None)
            is_optional: Whether this state machine is optional
            case_sensitive: Whether character matching is case-sensitive; when
                False, all three charsets are stored lowercased
        """
        super().__init__(
            is_optional=is_optional,
            is_case_sensitive=case_sensitive,
        )
        self.char_min = char_min or 0
        self.char_limit = char_limit or 0

        def to_charset(chars: str | list[str] | Iterable[str]) -> set[str]:
            # Case-insensitive machines keep lowercase characters only; the
            # stepper lowercases its input before comparing.
            if case_sensitive:
                return set(chars)
            return {char.lower() for char in chars} if chars else set()

        self.charset: set[str] = to_charset(whitelist_charset)
        self.graylist_charset: set[str] = to_charset(graylist_charset)
        self.blacklist_charset: set[str] = to_charset(blacklist_charset)

    def get_new_stepper(self, state: int | str | None = None) -> CharacterStepper:
        """
        Create a fresh CharacterStepper for this machine.

        Args:
            state: Ignored; the stepper always starts with no consumed
                characters. Defaults to None, matching the other state
                machines' ``get_new_stepper`` signatures.
        """
        return CharacterStepper(self)

    def __str__(self) -> str:
        return "Character"
59 |
60 |
class CharacterStepper(Stepper):
    """
    Stepper that consumes characters for a CharacterStateMachine.

    Each call to ``consume`` accepts the longest prefix of the token that
    satisfies the machine's whitelist, blacklist, graylist and character limit.
    """

    def __init__(
        self,
        state_machine: CharacterStateMachine,
        value: str | None = None,
    ) -> None:
        """
        Initialize the stepper.

        Args:
            state_machine: The machine whose character constraints apply.
            value: Characters already accumulated, if any. Defaults to None.
        """
        super().__init__(state_machine)
        self.target_state = "$"
        self.state_machine: CharacterStateMachine = state_machine
        self._raw_value = value
        if value:
            self.consumed_character_count = len(value)

    def accepts_any_token(self) -> bool:
        """Return True when no whitelist is configured (the blacklist is not consulted here)."""
        return len(self.state_machine.charset) == 0

    def get_valid_continuations(self, depth: int = 0) -> list[str]:
        """Return the whitelisted characters as a list; ``depth`` is unused."""
        return [*self.state_machine.charset]

    def can_accept_more_input(self) -> bool:
        """Return False only when a positive character limit has already been reached."""
        limit = self.state_machine.char_limit
        return limit <= 0 or self.consumed_character_count < limit

    def should_start_step(self, token: str) -> bool:
        """
        Decide whether ``token`` may begin a step.

        Only the token's first character is examined: it must not be
        blacklisted and, if a whitelist exists, must be in it. An empty token
        or an exhausted character limit always yields False.
        """
        if not token:
            return False
        machine = self.state_machine
        if machine.char_limit > 0 and self.consumed_character_count >= machine.char_limit:
            return False

        lead = token[0] if machine.is_case_sensitive else token[0].lower()
        if lead in machine.blacklist_charset:
            return False
        # Without a whitelist, any non-blacklisted character may start a step.
        return not machine.charset or lead in machine.charset

    def should_complete_step(self) -> bool:
        """Return True unless the consumed count exceeds the limit or falls short of the minimum."""
        machine = self.state_machine
        if machine.char_limit > 0 and self.consumed_character_count > machine.char_limit:
            return False
        return not (machine.char_min > 0 and self.consumed_character_count < machine.char_min)

    def consume(self, token: str) -> list[Stepper]:
        """
        Advance over as much of ``token`` as the constraints allow.

        For case-insensitive machines the token is lowercased first, so both
        the accepted prefix and the leftover input are lowercase.

        Args:
            token: The input string to consume.

        Returns:
            A single-element list with the advanced stepper (leftover input is
            None when the whole token was accepted), or an empty list when the
            token cannot start a step.
        """
        if not token or not self.should_start_step(token):
            return []

        machine = self.state_machine
        text = token if machine.is_case_sensitive else token.lower()

        charset = machine.charset
        blacklist = machine.blacklist_charset
        graylist = machine.graylist_charset
        limit = machine.char_limit
        already_consumed = self.consumed_character_count

        accepted = 0
        for char in text:
            if char in blacklist:
                break
            if charset and char not in charset:
                break
            if limit > 0 and accepted + already_consumed >= limit:
                break
            # NOTE(review): only characters accepted from *this* token count
            # here, so a graylisted character that starts a later token is
            # accepted even after earlier tokens were consumed — confirm this
            # token-boundary dependence is intended.
            if graylist and accepted > 0 and char in graylist:
                break
            accepted += 1

        prefix = text[:accepted]
        leftover = text[accepted:] or None
        return [self.step(self.get_raw_value() + prefix, leftover)]
200 |
--------------------------------------------------------------------------------
/pse/types/base/encapsulated.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from typing import Self
5 |
6 | from pse_core import StateId
7 | from pse_core.state_machine import StateMachine
8 | from pse_core.stepper import Stepper
9 |
10 | from pse.types.base.phrase import PhraseStateMachine
11 | from pse.types.base.wait_for import WaitFor
12 |
13 |
class EncapsulatedStateMachine(StateMachine):
    """
    Recognizes content framed by an opening and a closing delimiter.

    The graph has three states:

    - state 0 waits (via ``WaitFor``) for the opening delimiter;
    - state 1 runs the wrapped state machine;
    - state 2 expects the closing delimiter before the accept state ``"$"``.
    """

    def __init__(
        self,
        state_machine: StateMachine,
        delimiters: tuple[str, str] | None,
        buffer_length: int = -1,
        is_optional: bool = False,
    ) -> None:
        """
        Args:
            state_machine: The state machine wrapped by this one.
            delimiters: The (opening, closing) delimiter pair; a falsy value
                falls back to triple backticks on both sides.
            buffer_length: Forwarded as ``buffer_length`` to the ``WaitFor``
                that watches for the opening delimiter.
            is_optional: Forwarded to the base ``StateMachine``.
        """
        self.inner_state_machine = state_machine
        self.delimiters = delimiters or ("```", "```")

        wait_for_opening = WaitFor(
            PhraseStateMachine(self.delimiters[0]),
            buffer_length=buffer_length,
        )
        state_graph = {
            0: [(wait_for_opening, 1)],
            1: [(state_machine, 2)],
            2: [(PhraseStateMachine(self.delimiters[1]), "$")],
        }
        super().__init__(state_graph, is_optional=is_optional)

    def get_new_stepper(self, state: StateId | None = None) -> EncapsulatedStepper:
        """Create an ``EncapsulatedStepper`` positioned at ``state``."""
        return EncapsulatedStepper(self, state)
54 |
class EncapsulatedStepper(Stepper):
    """
    Stepper for an ``EncapsulatedStateMachine``.

    Keeps the stepper recorded while in state 2 as ``inner_stepper``, and
    produces token-safe output with the surrounding delimiters removed.
    """

    def __init__(
        self,
        state_machine: EncapsulatedStateMachine,
        state: StateId | None = None,
    ) -> None:
        super().__init__(state_machine, state)
        self.state_machine: EncapsulatedStateMachine = state_machine
        self.inner_stepper: Stepper | None = None

    def clone(self) -> Self:
        clone = super().clone()
        if self.inner_stepper:
            clone.inner_stepper = self.inner_stepper.clone()
        return clone

    def is_within_value(self) -> bool:
        # While waiting for the opening delimiter (state 0), defer to the
        # active sub-stepper; any later state counts as inside the value.
        if self.current_state == 0 and self.sub_stepper:
            return self.sub_stepper.is_within_value()

        return self.current_state != 0

    def add_to_history(self, stepper: Stepper) -> None:
        # NOTE(review): assumes current_state has already advanced to 2 when
        # the wrapped machine's finished stepper is recorded - confirm.
        if self.current_state == 2:
            self.inner_stepper = stepper

        return super().add_to_history(stepper)

    def get_invalid_continuations(self) -> list[str]:
        # Until an inner stepper has been recorded, the closing delimiter is
        # reported as an invalid continuation.
        if not self.inner_stepper:
            return [self.state_machine.delimiters[1]]
        return super().get_invalid_continuations()

    def get_final_state(self) -> list[Stepper]:
        return [self]

    def get_token_safe_output(self, decode_function: Callable[[list[int]], str]) -> str:
        """
        Retrieve the token-safe output with delimiters removed.

        The decoded token history is whitespace-stripped. The opening and
        closing delimiters are then removed when they appear in full, in
        whitespace-trimmed form, or truncated (generation stopped part-way
        through a delimiter). Content characters that merely happen to occur
        in a delimiter are never stripped.

        Args:
            decode_function: Function to decode token IDs into a string.

        Returns:
            Processed string with delimiters stripped.
        """
        token_ids = self.get_token_ids_history()
        token_safe_output: str = decode_function(token_ids).strip()

        start_delim, end_delim = self.state_machine.delimiters
        token_safe_output = self._strip_opening_delimiter(token_safe_output, start_delim)
        return self._strip_closing_delimiter(token_safe_output, end_delim)

    @staticmethod
    def _strip_opening_delimiter(text: str, delimiter: str) -> str:
        """Remove a full, whitespace-trimmed, or truncated ``delimiter`` from the start of ``text``."""
        if delimiter and text.startswith(delimiter):
            return text[len(delimiter):]
        # The decoded output is whitespace-stripped, which can also remove
        # whitespace that belongs to the delimiter (e.g. "```python\n").
        core = delimiter.strip()
        if core and text.startswith(core):
            return text[len(core):]
        # The whole output is only a truncated opening delimiter.
        if delimiter.startswith(text):
            return ""
        return text

    @staticmethod
    def _strip_closing_delimiter(text: str, delimiter: str) -> str:
        """Remove a full, whitespace-trimmed, or truncated ``delimiter`` from the end of ``text``."""
        if delimiter and text.endswith(delimiter):
            return text[: -len(delimiter)]
        core = delimiter.strip()
        if core and text.endswith(core):
            return text[: -len(core)]
        # Output stopped part-way through the closing delimiter: drop the
        # longest suffix of text that is a proper prefix of the delimiter.
        for size in range(len(delimiter) - 1, 0, -1):
            if text.endswith(delimiter[:size]):
                return text[:-size]
        return text
127 |
--------------------------------------------------------------------------------
/pse/types/base/loop.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import Self
5 |
6 | from pse_core import StateGraph
7 | from pse_core.state_machine import StateMachine
8 | from pse_core.stepper import Stepper
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
class LoopStateMachine(StateMachine):
    """
    Repeat a single state machine, optionally with a separator between iterations.
    """

    def __init__(
        self,
        state_machine: StateMachine,
        min_loop_count: int = 1,
        max_loop_count: int = -1,
        separator_state_machine: StateMachine | None = None,
        track_separator: bool = True,
    ) -> None:
        """
        Args:
            state_machine: State machine matched on every iteration.
            min_loop_count: Minimum number of iterations. ``0`` makes the loop
                optional (the stored ``min_loop_count`` then becomes ``1``).
            max_loop_count: Maximum number of iterations (default ``-1``).
            separator_state_machine: Optional state machine matched between
                iterations.
            track_separator: Whether separator steppers are recorded in the
                ``LoopStepper``'s history.
        """
        self.separator_state_machine = separator_state_machine
        self.track_separator = track_separator

        separator = self.separator_state_machine
        state_graph: StateGraph
        if separator:
            # body -> separator -> body -> separator ...
            state_graph = {
                0: [(state_machine, 1)],
                1: [(separator, 2)],
                2: [(state_machine, 1)],
            }
        else:
            # body -> body -> ...
            state_graph = {
                0: [(state_machine, 1)],
                1: [(state_machine, 0)],
            }

        super().__init__(
            state_graph=state_graph,
            is_optional=min_loop_count == 0,
        )
        self.min_loop_count = min_loop_count or 1
        self.max_loop_count = max_loop_count

    def get_new_stepper(self, state: int | str | None = None) -> Stepper:
        """Create a ``LoopStepper`` positioned at ``state``."""
        return LoopStepper(self, state)

    def __str__(self) -> str:
        return "Loop"
56 |
57 |
class LoopStepper(Stepper):
    """
    A stepper that loops through a single StateMachine.

    ``loop_count`` is incremented for every recorded stepper that is not the
    separator. A non-positive ``max_loop_count`` on the state machine (the
    default is ``-1``) means the loop is unbounded.
    """

    def __init__(self, loop_state_machine: LoopStateMachine, *args, **kwargs) -> None:
        super().__init__(loop_state_machine, *args, **kwargs)
        self.state_machine: LoopStateMachine = loop_state_machine
        self.loop_count = 0

    def clone(self) -> Self:
        clone = super().clone()
        clone.loop_count = self.loop_count
        return clone

    def _is_below_max_loop_count(self, loop_count: int | None = None) -> bool:
        """
        Return True if ``loop_count`` (default: this stepper's) is under the cap.

        A non-positive ``max_loop_count`` means there is no cap. Without this,
        the default ``-1`` would make every comparison fail and the loop
        could never start.
        """
        max_loop_count = self.state_machine.max_loop_count
        if max_loop_count <= 0:
            return True
        count = self.loop_count if loop_count is None else loop_count
        return count < max_loop_count

    def should_branch(self) -> bool:
        return super().should_branch() and self._is_below_max_loop_count()

    def has_reached_accept_state(self) -> bool:
        """
        Determines if this stepper is in an accept state.

        A loop stepper is in an accept state if:
        1. It has completed at least the minimum required iterations, and
        2. If it's currently processing a sub-stepper, that sub-stepper is in an accept state

        Returns:
            True if the stepper is in an accept state, False otherwise
        """
        if self.loop_count < self.state_machine.min_loop_count:
            return False

        # Mid-value, acceptance depends on the sub-stepper finishing its value.
        if self.sub_stepper is not None and self.sub_stepper.is_within_value():
            return self.sub_stepper.has_reached_accept_state()

        return True

    def _validate_loop_stepper(self, stepper: LoopStepper) -> LoopStepper | None:
        """
        Check that ``stepper`` respects the max loop count.

        Returns:
            The stepper, settled onto its target state if it has just reached
            the cap; ``None`` if it has exceeded the cap.
        """
        max_loop_count = self.state_machine.max_loop_count
        if max_loop_count <= 0 or stepper.loop_count < max_loop_count:
            return stepper

        if stepper.loop_count == max_loop_count:
            # At the cap the stepper must not stay pending on a transition
            # into another iteration.
            if stepper.target_state is not None:
                stepper.current_state = stepper.target_state
                stepper.target_state = None
                stepper.sub_stepper = None
            return stepper

        return None

    def consume(self, token: str) -> list[Stepper]:
        new_steppers: list[Stepper] = []
        # Explicitly check that the new steppers respect the max loop count.
        for new_stepper in super().consume(token):
            assert isinstance(new_stepper, LoopStepper)
            valid_stepper = self._validate_loop_stepper(new_stepper)
            if valid_stepper is not None:
                new_steppers.append(valid_stepper)

        return new_steppers

    def can_accept_more_input(self) -> bool:
        if not super().can_accept_more_input():
            return False

        return self._is_below_max_loop_count()

    def should_start_step(self, token: str) -> bool:
        if not self._is_below_max_loop_count():
            return False

        return super().should_start_step(token)

    def add_to_history(self, stepper: Stepper) -> None:
        if (
            self.state_machine.separator_state_machine
            and stepper.state_machine == self.state_machine.separator_state_machine
        ):
            # Untracked separators are dropped from the history entirely.
            if not self.state_machine.track_separator:
                return
        else:
            # Any non-separator stepper counts as one completed iteration.
            self.loop_count += 1

        return super().add_to_history(stepper)

    def get_final_state(self) -> list[Stepper]:
        """
        Gets the final state representation for this stepper.

        - If the sub-stepper is within a value (and is not an untracked
          separator), delegate to the sub-stepper's final state.
        - Otherwise, return this stepper's history.

        Returns:
            List of steppers representing the final state
        """
        if not self.sub_stepper or not self.sub_stepper.is_within_value():
            return self.history

        separator_state_machine = self.state_machine.separator_state_machine
        is_separator = (
            separator_state_machine is not None
            and self.sub_stepper.state_machine == separator_state_machine
        )

        if is_separator and not self.state_machine.track_separator:
            return self.history

        return self.sub_stepper.get_final_state()
188 |
--------------------------------------------------------------------------------
/pse/types/base/phrase.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 |
5 | from pse_core.state_machine import StateMachine
6 | from pse_core.stepper import Stepper
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
class PhraseStateMachine(StateMachine):
    """
    Accepts a predefined sequence of characters, validating input against the specified text.

    Attributes:
        phrase (str): The target string that this state_machine is validating against.
    """

    def __init__(
        self,
        phrase: str,
        is_optional: bool = False,
        is_case_sensitive: bool = True,
    ):
        """
        Initialize a new PhraseStateMachine instance with the specified text.

        Args:
            phrase (str): The string of characters that this state_machine will validate.
                Must be a non-empty string.
            is_optional (bool): Forwarded to the base ``StateMachine``.
            is_case_sensitive (bool): Forwarded to the base ``StateMachine``.

        Raises:
            ValueError: If the provided text is empty.
        """
        # Validate before initializing the base so bad input fails fast.
        if not phrase:
            raise ValueError("Phrase must be a non-empty string.")

        super().__init__(
            is_optional=is_optional,
            is_case_sensitive=is_case_sensitive,
        )
        self.phrase = phrase

    def get_new_stepper(self, state: int | str | None = None) -> PhraseStepper:
        """Create a fresh ``PhraseStepper``; ``state`` is ignored."""
        return PhraseStepper(self)

    def __str__(self) -> str:
        """
        Provide a string representation of the PhraseStateMachine.

        Returns:
            str: A string representation of the PhraseStateMachine.
        """
        return f"Phrase({self.phrase!r})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, PhraseStateMachine)
            and self.phrase == other.phrase
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None, making instances
        # unhashable; hash on the same field equality compares.
        return hash(self.phrase)
62 |
63 |
class PhraseStepper(Stepper):
    """
    Stepper that walks through ``PhraseStateMachine.phrase`` character by character.

    ``consumed_character_count`` is the position reached in the phrase. The
    raw value is always the matched slice of the phrase itself.
    """

    def __init__(
        self,
        state_machine: PhraseStateMachine,
        consumed_character_count: int | None = None,
    ):
        super().__init__(state_machine)
        if consumed_character_count is not None and consumed_character_count < 0:
            raise ValueError("Consumed character count must be non-negative")

        self.consumed_character_count = consumed_character_count or 0
        self.state_machine: PhraseStateMachine = state_machine
        self.target_state = "$"

    def can_accept_more_input(self) -> bool:
        """
        Check if the stepper can accept more input.
        """
        return self.consumed_character_count < len(self.state_machine.phrase)

    def should_start_step(self, token: str) -> bool:
        """
        Start a transition if the token is not empty and matches the remaining text.
        """
        if not token:
            return False

        return self._get_valid_match_length(token) > 0

    def should_complete_step(self) -> bool:
        return self.consumed_character_count == len(self.state_machine.phrase)

    def get_valid_continuations(self, depth: int = 0) -> list[str]:
        if self.consumed_character_count >= len(self.state_machine.phrase):
            return []

        remaining_text = self.state_machine.phrase[self.consumed_character_count :]
        return [remaining_text]

    def consume(self, token: str) -> list[Stepper]:
        """
        Advances the stepper if the token matches the expected text at the current position.

        Args:
            token (str): The string to match against the expected text.

        Returns:
            list[Stepper]: A stepper if the token matches, empty otherwise.
        """
        valid_length = self._get_valid_match_length(token)
        if valid_length <= 0:
            return []

        start = self.consumed_character_count
        # Extend with the phrase's own characters: identical to the token for
        # case-sensitive matching, and keeps the phrase's canonical casing
        # when matching case-insensitively.
        matched = self.state_machine.phrase[start : start + valid_length]
        new_value = self.get_raw_value() + matched
        remaining_input = token[valid_length:] or None
        new_stepper = self.step(new_value, remaining_input)
        return [new_stepper]

    def get_raw_value(self) -> str:
        return self.state_machine.phrase[: self.consumed_character_count]

    def _get_valid_match_length(self, token: str, pos: int | None = None) -> int:
        """
        Calculate the length of the matching prefix between the token and the target phrase.

        Characters are compared exactly, or lowercased when the state machine
        is not case sensitive.

        Args:
            token: The input string to check against the target phrase.
            pos: Starting position in the phrase (defaults to current consumed count).

        Returns:
            Length of the matching prefix.
        """
        # `pos or ...` would wrongly discard an explicit pos=0.
        if pos is None:
            pos = self.consumed_character_count

        remaining_phrase = self.state_machine.phrase[pos:]
        case_sensitive = getattr(self.state_machine, "is_case_sensitive", True)

        # Single pass over both strings; zip stops at the shorter one.
        length = 0
        for token_char, phrase_char in zip(token, remaining_phrase):
            if token_char != phrase_char and (
                case_sensitive or token_char.lower() != phrase_char.lower()
            ):
                break
            length += 1
        return length
153 |
--------------------------------------------------------------------------------
/pse/types/base/wait_for.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import Self
5 |
6 | from pse_core import StateId
7 | from pse_core.state_machine import StateMachine
8 | from pse_core.stepper import Stepper
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
class WaitFor(StateMachine):
    """
    Accept free-form text until a segment triggers a nested state machine.

    Text that the nested state machine does not start on is accumulated in
    the stepper's buffer. Once a segment starts the nested machine, the
    input is handed to it.

    This is useful for allowing free-form text until a specific delimiter or
    pattern is detected, such as when parsing output from language models
    that encapsulate JSON within markdown code blocks.
    """

    def __init__(
        self,
        state_machine: StateMachine,
        buffer_length: int = -1,
        strict: bool = True,
    ):
        """
        Initialize with a target nested StateMachine.

        Args:
            state_machine (StateMachine): The nested StateMachine to watch for.
            buffer_length (int): Stored as ``min_buffer_length``.
                - ``-1`` disallows buffered text before the nested machine
                  starts.
                - A positive value is the minimum buffer length before the
                  nested machine may start.
            strict (bool): If True, input the nested machine rejects after it
                has started matching ends that path instead of being buffered.
        """
        super().__init__()

        self.min_buffer_length = buffer_length
        self.strict = strict
        self.wait_for_sm = state_machine

    def get_transitions(self, _: Stepper) -> list[tuple[Stepper, StateId]]:
        """Every stepper of the nested machine leads straight to the accept state."""
        return [(stepper, "$") for stepper in self.wait_for_sm.get_steppers()]

    def get_new_stepper(self, _: StateId | None = None) -> Stepper:
        return WaitForStepper(self)

    def get_steppers(self, _: StateId | None = None) -> list[Stepper]:
        return self.branch_stepper(self.get_new_stepper())

    def __str__(self) -> str:
        return f"WaitFor({self.wait_for_sm})"
60 |
61 |
class WaitForStepper(Stepper):
    """
    Stepper for ``WaitFor``.

    Rejected text is kept in ``buffer`` until the nested state machine can
    start on the remaining input.
    """

    def __init__(self, state_machine: WaitFor):
        super().__init__(state_machine)
        self.target_state = "$"
        self.state_machine: WaitFor = state_machine
        self.buffer = ""

    def clone(self) -> Self:
        duplicate = super().clone()
        duplicate.buffer = self.buffer
        return duplicate

    def accepts_any_token(self) -> bool:
        """
        Determine whether this stepper can accept any token.

        While the sub-stepper is inside a value, the answer is the
        sub-stepper's. Otherwise any token is acceptable, whether or not the
        buffer has reached ``min_buffer_length``.
        """
        active = self.sub_stepper
        if active and active.is_within_value():
            return active.accepts_any_token()
        return True

    def get_valid_continuations(self) -> list[str]:
        """
        Once the buffer is long enough, defer to the base continuations;
        before that, report none.
        """
        if len(self.buffer) < self.state_machine.min_buffer_length:
            return []
        return super().get_valid_continuations()

    def get_invalid_continuations(self) -> list[str]:
        """
        While the buffer is shorter than ``min_buffer_length``, the
        sub-stepper's valid continuations are forbidden so the buffer grows.

        Otherwise there are no invalid continuations.
        """
        buffer_too_short = len(self.buffer) < self.state_machine.min_buffer_length
        if not (buffer_too_short and self.sub_stepper):
            return []
        return self.sub_stepper.get_valid_continuations()

    def should_start_step(self, token: str) -> bool:
        """
        Decide whether to start processing ``token``.

        Never starts while leftover input is pending.

        With a non-positive ``min_buffer_length``, it starts when the base
        check passes or the stepper is not yet within a value.

        With a positive minimum, a token the base check accepts starts only
        once the buffer is long enough; any other token starts only outside
        a value.
        """
        if self.remaining_input:
            return False

        minimum = self.state_machine.min_buffer_length
        base_allows = super().should_start_step(token)

        if minimum <= 0:
            return base_allows or not self.is_within_value()

        if base_allows:
            return len(self.buffer) >= minimum
        return not self.is_within_value()

    def consume(self, token: str) -> list[Stepper]:
        active = self.sub_stepper
        if not active:
            return []

        # Skip leading characters until the sub-stepper can start on the rest.
        split = 0
        while split < len(token) and not active.should_start_step(token[split:]):
            split += 1
        rejected, rest = token[:split], token[split:]

        if not rejected:
            return self.state_machine.advance_stepper(self, rest)

        inside_value = self.is_within_value()
        if inside_value and self.state_machine.strict:
            return []
        if not inside_value and self.state_machine.min_buffer_length == -1:
            return []

        branch = self.clone()
        branch.buffer += rejected
        if not rest:
            return [branch]
        return self.state_machine.advance_stepper(branch, rest)
192 |
--------------------------------------------------------------------------------
/pse/types/boolean.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pse_core import StateId
4 | from pse_core.state_machine import StateMachine
5 | from pse_core.stepper import Stepper
6 |
7 | from pse.types.base.phrase import PhraseStateMachine
8 |
9 |
class BooleanStateMachine(StateMachine):
    """
    Accepts a JSON boolean value: true, false.
    """

    def __init__(self) -> None:
        literals = ("true", "false")
        super().__init__(
            {0: [(PhraseStateMachine(word), "$") for word in literals]}
        )

    def get_steppers(self, state: StateId | None = None) -> list[Stepper]:
        """Flatten the steppers of every edge leaving ``state`` (0 when falsy)."""
        return [
            stepper
            for edge, _ in self.get_edges(state or 0)
            for stepper in edge.get_steppers()
        ]
30 |
--------------------------------------------------------------------------------
/pse/types/enum.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pse_core import StateGraph, StateId
4 | from pse_core.state_machine import StateMachine
5 | from pse_core.stepper import Stepper
6 |
7 | from pse.types.base.chain import ChainStateMachine
8 | from pse.types.base.phrase import PhraseStateMachine
9 |
10 |
class EnumStateMachine(StateMachine):
    """
    Accept one of several constant strings.

    Each distinct value becomes an alternative edge from state 0 to the
    accept state ``"$"``. When ``require_quotes`` is True, the value must
    be wrapped in double quotes.
    """

    def __init__(self, enum_values: list[str], require_quotes: bool = True) -> None:
        """
        Args:
            enum_values: The accepted values; duplicates are ignored.
            require_quotes: Whether each value must be surrounded by ``"``.

        Raises:
            ValueError: If ``enum_values`` is empty.
        """
        if not enum_values:
            raise ValueError("Enum values must be provided.")

        state_graph: StateGraph = {0: []}
        # dict.fromkeys de-duplicates while keeping first-seen order. A set
        # would make the edge order depend on string hash randomization,
        # i.e. vary between interpreter runs.
        for value in dict.fromkeys(enum_values):
            if require_quotes:
                sm = ChainStateMachine(
                    [
                        PhraseStateMachine('"'),
                        PhraseStateMachine(value),
                        PhraseStateMachine('"'),
                    ]
                )
            else:
                sm = PhraseStateMachine(value)
            state_graph[0].append((sm, "$"))

        super().__init__(state_graph)

    def get_steppers(self, state: StateId | None = None) -> list[Stepper]:
        """Flatten the steppers of every edge leaving ``state`` (0 when falsy)."""
        return [
            stepper
            for edge, _ in self.get_edges(state or 0)
            for stepper in edge.get_steppers()
        ]
43 |
--------------------------------------------------------------------------------
/pse/types/grammar/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from lark import Lark
4 | from lark.exceptions import UnexpectedCharacters, UnexpectedEOF, UnexpectedToken
5 |
6 |
class LarkGrammar(ABC):
    """
    Base class pairing a named Lark parser with optional fence delimiters.

    Subclasses must implement ``validate``. The base implementation below
    can be reused via ``super().validate(...)``.
    """

    name: str
    lark_grammar: Lark
    delimiters: tuple[str, str] | None = None

    def __init__(
        self,
        name: str,
        lark_grammar: Lark,
        delimiters: tuple[str, str] | None = None,
    ):
        self.name = name
        self.lark_grammar = lark_grammar
        self.delimiters = delimiters

    @abstractmethod
    def validate(
        self,
        input: str,
        strict: bool = False,
        start: str | None = None,
    ) -> bool:
        """
        Validate the input against the grammar.

        Args:
            input (str): The input to validate.
            strict (bool): If True, any parse error means invalid. If False,
                unexpected characters, unexpected end of input, and
                unexpected-token errors at ``$END`` are tolerated.
            start (str): The start rule to use.

        Returns:
            bool: True if the input is valid, False otherwise.
        """
        try:
            self.lark_grammar.parse(input, start=start)
        except Exception as error:
            if strict:
                return False
            if isinstance(error, UnexpectedEOF | UnexpectedCharacters):
                return True
            return isinstance(error, UnexpectedToken) and error.token.type == "$END"
        return True
52 |
53 | from pse.types.grammar.default_grammars.bash import BashGrammar # noqa: E402
54 | from pse.types.grammar.default_grammars.python import PythonGrammar # noqa: E402
55 | from pse.types.grammar.lark import LarkGrammarStateMachine # noqa: E402
56 |
# Ready-made, module-level state machines for the bundled grammars.
PythonStateMachine = LarkGrammarStateMachine(PythonGrammar())
BashStateMachine = LarkGrammarStateMachine(BashGrammar())

# LarkGrammar is the public base class for custom grammars, so it is
# exported alongside the ready-made state machines.
__all__ = [
    "BashStateMachine",
    "LarkGrammar",
    "LarkGrammarStateMachine",
    "PythonStateMachine",
]
65 |
--------------------------------------------------------------------------------
/pse/types/grammar/default_grammars/bash.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import os
5 |
6 | from lark import Lark
7 | from lark.exceptions import UnexpectedCharacters, UnexpectedToken
8 |
9 | from pse.types.grammar import LarkGrammar
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
class BashGrammar(LarkGrammar):
    """
    Bash grammar loaded from the bundled ``bash.lark`` file.

    The parser is LALR with the basic lexer. Code is fenced with
    ``"```bash\\n"`` and ``"\\n```"``.
    """

    def __init__(self):
        # The grammar file lives next to this module.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        grammar_path = os.path.join(current_dir, "bash.lark")

        # Read explicitly as UTF-8 so loading does not depend on the
        # platform's default encoding.
        with open(grammar_path, encoding="utf-8") as f:
            bash_grammar_content = f.read()

        bash_lark_grammar = Lark(
            bash_grammar_content,
            start="start",
            parser="lalr",
            lexer="basic",
        )

        super().__init__(
            name="Bash",
            lark_grammar=bash_lark_grammar,
            delimiters=("```bash\n", "\n```"),
        )

    def validate(
        self,
        input: str,
        strict: bool = False,
        start: str | None = None,
    ) -> bool:
        """
        Validate Bash code using the Lark parser.

        Args:
            input: The Bash code to validate.
            strict: If True, any parse error means invalid. If False, input
                that is merely incomplete is accepted: an unexpected token at
                end of input, a parser state expecting ``ESAC``, or the lexer
                stopping on a quote character.
            start: The start rule to use (defaults to ``"start"``).

        Returns:
            True if the input is valid (or acceptably incomplete in
            non-strict mode); False for empty/whitespace-only input or
            invalid code.
        """
        if not input.strip():
            return False

        try:
            self.lark_grammar.parse(input, start=start or "start")
        except Exception as e:
            if strict:
                return False
            if isinstance(e, UnexpectedToken):
                return e.token.type == "$END" or "ESAC" in e.expected
            if isinstance(e, UnexpectedCharacters):
                # Special case for unclosed quotes.
                return e.char in ("'", '"')
            return False
        return True
68 |
--------------------------------------------------------------------------------
/pse/types/grammar/default_grammars/python.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import os
5 |
6 | from lark import Lark
7 | from lark.exceptions import UnexpectedCharacters, UnexpectedToken
8 | from lark.indenter import PythonIndenter
9 |
10 | from pse.types.grammar import LarkGrammar
11 |
12 | logger = logging.getLogger(__name__)
13 |
class PythonGrammar(LarkGrammar):
    """
    Python grammar loaded from the bundled ``python.lark`` file.

    The parser is LALR with the basic lexer and a ``PythonIndenter``
    post-lexer. The start rule is ``file_input``. Code is fenced with
    ``"```python\\n"`` and ``"\\n```"``.
    """

    def __init__(self):
        # The grammar file lives next to this module.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        grammar_path = os.path.join(current_dir, "python.lark")

        # Read explicitly as UTF-8 so loading does not depend on the
        # platform's default encoding.
        with open(grammar_path, encoding="utf-8") as f:
            python_grammar_content = f.read()

        python_lark_grammar = Lark(
            python_grammar_content,
            parser="lalr",
            lexer="basic",
            postlex=PythonIndenter(),
            start=["file_input"],
        )

        super().__init__(
            name="Python",
            lark_grammar=python_lark_grammar,
            delimiters=("```python\n", "\n```"),
        )

    def validate(
        self,
        input: str,
        strict: bool = False,
        start: str | None = None,
    ) -> bool:
        """
        Validate Python code using the Lark parser.

        Args:
            input: The Python code to validate.
            strict: If True, any parse error means invalid (a trailing
                newline is appended first if missing). If False, input that
                is merely incomplete is accepted: an unexpected ``_DEDENT``
                or end-of-input token, or the lexer stopping on a quote
                character.
            start: The start rule to use (Lark's default when None).

        Returns:
            True if the input is valid (or acceptably incomplete in
            non-strict mode), False otherwise.
        """
        # Presumably the indenter-based grammar needs a trailing newline to
        # terminate the final statement - hence the strict-mode padding.
        if strict and not input.endswith("\n"):
            input += "\n"

        try:
            self.lark_grammar.parse(input, start=start)
        except Exception as e:
            if strict:
                return False
            if isinstance(e, UnexpectedToken):
                return e.token.type in ("_DEDENT", "$END")
            if isinstance(e, UnexpectedCharacters):
                # Special case for unclosed quotes.
                return e.char in ("'", '"')
            return False
        return True
65 |
--------------------------------------------------------------------------------
/pse/types/grammar/lark.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 |
5 | from pse_core import StateId
6 | from pse_core.stepper import Stepper
7 |
8 | from pse.types.base.character import CharacterStateMachine, CharacterStepper
9 | from pse.types.grammar import LarkGrammar
10 |
11 | logger = logging.getLogger(__name__)
12 |
class LarkGrammarStateMachine(CharacterStateMachine):
    """
    State machine whose accepted text is described by a ``LarkGrammar``.

    The wrapped grammar provides the display name, the optional delimiters and
    the validation used by ``LarkGrammarStepper``.
    """

    def __init__(self, grammar: LarkGrammar):
        # char_min=1 is forwarded to CharacterStateMachine unchanged.
        super().__init__(char_min=1)
        self.grammar = grammar

    @property
    def delimiters(self) -> tuple[str, str] | None:
        """The grammar's (opening, closing) delimiter pair, or None."""
        grammar = self.grammar
        return grammar.delimiters

    def get_new_stepper(self, state: StateId | None) -> LarkGrammarStepper:
        """
        Build a fresh stepper bound to this machine.

        ``state`` is part of the interface but is not used.
        """
        stepper = LarkGrammarStepper(self)
        return stepper

    def __str__(self) -> str:
        name = self.grammar.name
        return name
30 |
31 |
class LarkGrammarStepper(CharacterStepper):
    """
    A stepper for the grammar state machine.

    Each incoming token is cut down to its shortest non-empty prefix that the
    grammar accepts in non-strict mode. Any text after that prefix is handed on
    as remaining input.
    """

    def __init__(self, state_machine: LarkGrammarStateMachine):
        """
        Initialize the grammar stepper with a state machine.

        Args:
            state_machine: The grammar state machine that defines the valid transitions
        """
        super().__init__(state_machine)
        self.state_machine: LarkGrammarStateMachine = state_machine

    def get_identifier(self) -> str | None:
        """Return the grammar's name in lowercase."""
        return self.state_machine.grammar.name.lower()

    def should_start_step(self, token: str) -> bool:
        """
        Return True if some non-empty prefix of ``token`` keeps the grammar valid.
        """
        valid_prefix, _ = self.get_valid_prefix(token)
        return valid_prefix is not None

    def has_reached_accept_state(self) -> bool:
        """
        Return True if the accumulated input passes strict grammar validation.
        """
        valid_input = self.get_raw_value()
        return self.state_machine.grammar.validate(valid_input, strict=True)

    def consume(self, token: str) -> list[Stepper]:
        """
        Consume the input token and return possible next states.

        Args:
            token: The input string to consume

        Returns:
            A list of new steppers after consuming the token.
            Returns empty list if no valid transitions are possible.
        """
        valid_input, remaining_input = self.get_valid_prefix(token)
        if valid_input is None:
            # Honour the documented contract instead of asserting. An assert is
            # stripped under `python -O`, and the concatenation below would then
            # raise TypeError.
            return []
        return [
            self.step(
                self.get_raw_value() + valid_input,
                remaining_input or None,
            )
        ]

    def get_valid_prefix(self, new_input: str) -> tuple[str | None, str]:
        """
        Get the shortest non-empty prefix of the new input that keeps the grammar valid.

        Prefixes are tried from length 1 upward. The first one that non-strict
        validation accepts, appended to the current raw value, is returned.

        Args:
            new_input: The input string to validate

        Returns:
            A tuple of (valid_prefix, remaining_input) where valid_prefix is None if no
            valid prefix exists (remaining_input is then "").
        """
        candidate_base = self.get_raw_value()
        validate = self.state_machine.grammar.validate
        for end in range(1, len(new_input) + 1):
            if validate(candidate_base + new_input[:end], False):
                return new_input[:end], new_input[end:]
        return None, ""
107 |
--------------------------------------------------------------------------------
/pse/types/integer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from pse.types.base.character import (
6 | CharacterStateMachine,
7 | CharacterStepper,
8 | )
9 |
10 |
class IntegerStateMachine(CharacterStateMachine):
    """
    Accepts an integer as per JSON specification.

    Built on CharacterStateMachine with the character set ``"0123456789"``.

    Args:
        drop_leading_zeros: When True (default), ``IntegerStepper`` reports the
            value as an ``int``. When False, it reports the raw digit string.
    """

    def __init__(self, drop_leading_zeros: bool = True) -> None:
        allowed_digits = "0123456789"
        super().__init__(allowed_digits)
        self.drop_leading_zeros = drop_leading_zeros

    def get_new_stepper(self, state: int | str) -> IntegerStepper:
        """Create a new IntegerStepper; ``state`` is not used."""
        new_stepper = IntegerStepper(self)
        return new_stepper

    def __str__(self) -> str:
        label = "Integer"
        return label
25 |
26 |
class IntegerStepper(CharacterStepper):
    """Stepper for IntegerStateMachine that converts the digits it has read."""

    def __init__(
        self, state_machine: IntegerStateMachine, value: str | None = None
    ) -> None:
        super().__init__(state_machine, value)
        self.state_machine: IntegerStateMachine = state_machine

    def get_current_value(self) -> Any:
        """
        Return None before any input. Otherwise return an ``int`` when the machine
        drops leading zeros, or the raw digit string when it does not.
        """
        raw_digits = self._raw_value
        if raw_digits is None:
            return None
        if not self.state_machine.drop_leading_zeros:
            return raw_digits
        return int(raw_digits)
38 |
--------------------------------------------------------------------------------
/pse/types/json/any_json_schema.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from pse_core import StateId
4 | from pse_core.state_machine import StateMachine
5 | from pse_core.stepper import Stepper
6 |
7 |
class AnySchemaStateMachine(StateMachine):
    """
    Accepts JSON input that complies with any of several provided JSON schemas
    """

    def __init__(self, schemas: list[dict[str, Any]], context: dict[str, Any]) -> None:
        """
        Build one sub-machine per schema and make each an edge from state 0 to "$".

        Args:
            schemas (List[Dict[str, Any]]): A list of JSON schemas to validate against.
            context (Dict[str, Any]): Contextual information for schema definitions and paths.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import with pse.types.json).
        from pse.types.json import _json_schema_to_state_machine

        self.state_machines: list[StateMachine] = [
            _json_schema_to_state_machine(schema, context) for schema in schemas
        ]
        graph = {0: [(machine, "$") for machine in self.state_machines]}
        super().__init__(graph)

    def get_steppers(self, state: StateId | None = None) -> list[Stepper]:
        """Collect the steppers of every schema sub-machine (a falsy state means 0)."""
        start_state = state or 0
        return [
            stepper
            for machine, _target in self.get_edges(start_state)
            for stepper in machine.get_steppers()
        ]

    def __str__(self) -> str:
        return "AnyJsonSchema"
42 |
--------------------------------------------------------------------------------
/pse/types/json/json_array.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from pse_core import StateId
6 | from pse_core.stepper import Stepper
7 |
8 | from pse.types.array import ArrayStateMachine, ArrayStepper
9 | from pse.types.base.chain import ChainStateMachine
10 | from pse.types.base.phrase import PhraseStateMachine
11 | from pse.types.json import _json_schema_to_state_machine
12 | from pse.types.whitespace import WhitespaceStateMachine
13 |
14 |
class ArraySchemaStateMachine(ArrayStateMachine):
    """
    JSON array state machine constrained by a JSON schema.

    The edge graph: 0 --"["--> 1; 1 --whitespace--> 2 or --"]"--> "$";
    2 --item--> 3; 3 --whitespace--> 4; 4 --","+whitespace--> 2 or --"]"--> "$".
    ``minItems``/``maxItems`` are enforced in ``get_transitions``.
    ``uniqueItems`` is exposed through ``unique_items()``.
    """

    def __init__(self, schema: dict[str, Any], context: dict[str, Any]) -> None:
        self.schema = schema
        self.context = context
        # NOTE(review): raises KeyError if the schema has no "items" keyword.
        super().__init__(
            {
                0: [
                    (PhraseStateMachine("["), 1),
                ],
                1: [
                    (WhitespaceStateMachine(), 2),
                    (PhraseStateMachine("]"), "$"),
                ],
                2: [
                    (
                        _json_schema_to_state_machine(
                            self.schema["items"], self.context
                        ),
                        3,
                    ),
                ],
                3: [
                    (WhitespaceStateMachine(), 4),
                ],
                4: [
                    (
                        ChainStateMachine(
                            [PhraseStateMachine(","), WhitespaceStateMachine()]
                        ),
                        2,
                    ),
                    (PhraseStateMachine("]"), "$"),
                ],
            }
        )

    def get_transitions(self, stepper: Stepper) -> list[tuple[Stepper, StateId]]:
        """Retrieve transition steppers from the stepper's current state.

        State 4 offers "]" only once ``minItems`` is met and "," only while the
        item count is below ``maxItems``. State 1 omits the immediate "]" when
        ``minItems`` > 0. Every other state uses the parent implementation.

        Args:
            stepper: The stepper initiating the transition.

        Returns:
            list[tuple[Stepper, StateId]]: A list of (stepper, target state) pairs.
        """
        if stepper.current_state == 4:
            transitions: list[tuple[Stepper, StateId]] = []
            item_count = len(stepper.get_current_value())
            if item_count >= self.min_items():
                for transition in PhraseStateMachine("]").get_steppers():
                    transitions.append((transition, "$"))

            if item_count < self.max_items():
                for transition in ChainStateMachine(
                    [PhraseStateMachine(","), WhitespaceStateMachine()]
                ).get_steppers():
                    transitions.append((transition, 2))

            return transitions

        if stepper.current_state == 1 and self.min_items() > 0:
            # An empty array would violate minItems, so only the path to an item is offered.
            return [
                (transition, 2)
                for transition in WhitespaceStateMachine().get_steppers()
            ]

        return super().get_transitions(stepper)

    def get_new_stepper(self, state: StateId | None = None) -> ArraySchemaStepper:
        return ArraySchemaStepper(self, state)

    def min_items(self) -> int:
        """
        Returns the minimum number of items in the array, according to the schema
        """
        return self.schema.get("minItems", 0)

    def max_items(self) -> int:
        """
        Returns the maximum number of items in the array, according to the schema
        (2**32 when the schema sets no limit).
        """
        return self.schema.get("maxItems", 2**32)

    def unique_items(self) -> bool:
        """
        Returns whether the items in the array must be unique, according to the schema
        """
        return self.schema.get("uniqueItems", False)

    def __str__(self) -> str:
        return "JSON" + super().__str__()
106 |
107 |
class ArraySchemaStepper(ArrayStepper):
    """
    Stepper for ArraySchemaStateMachine that applies the schema's uniqueItems flag.
    """

    def __init__(
        self,
        state_machine: ArraySchemaStateMachine,
        current_state: StateId | None = None,
    ):
        super().__init__(state_machine, current_state)
        self.state_machine: ArraySchemaStateMachine = state_machine

    def add_to_history(self, stepper: Stepper) -> None:
        """
        Adds an item to the array.

        The item is silently skipped when all of these hold: uniqueItems is set,
        the stepper is within the array value, and an equal item is already in
        ``self.value``.
        """
        candidate = stepper.get_current_value()
        is_duplicate = (
            self.state_machine.unique_items()
            and self.is_within_value()
            and candidate in self.value
        )
        if is_duplicate:
            return
        super().add_to_history(stepper)
129 |
--------------------------------------------------------------------------------
/pse/types/json/json_key_value.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from pse_core import StateId
6 |
7 | from pse.types.base.chain import ChainStateMachine
8 | from pse.types.base.phrase import PhraseStateMachine
9 | from pse.types.json import _json_schema_to_state_machine
10 | from pse.types.key_value import KeyValueStateMachine, KeyValueStepper
11 | from pse.types.string import StringStateMachine
12 | from pse.types.whitespace import WhitespaceStateMachine
13 |
14 |
class KeyValueSchemaStateMachine(KeyValueStateMachine):
    """
    A ``"key": value`` pair whose value must satisfy a property schema.

    The pair is optional when the property schema is nullable or has a default.

    Args:
        prop_name (str): The name of the property. When falsy, any JSON string
            is accepted as the key.
        prop_schema (Dict[str, Any]): The schema of the property.
        context (Dict[str, Any]): The parsing context ("defs" and "path" are read).
    """

    def __init__(
        self,
        prop_name: str | None,
        prop_schema: dict[str, Any],
        context: dict[str, Any],
    ):
        self.prop_name = prop_name
        self.prop_schema = prop_schema
        parent_path = context.get("path", "")
        self.prop_context = {
            "defs": context.get("defs", {}),
            # NOTE(review): a None prop_name produces a path ending in "/None".
            "path": f"{parent_path}/{prop_name}",
        }

        if not self.prop_name:
            key_machine = StringStateMachine()
        else:
            # NOTE(review): the name is matched verbatim, not JSON-escaped; a name
            # containing '"' or '\\' would produce an invalid JSON key.
            key_machine = ChainStateMachine(
                [
                    PhraseStateMachine('"'),
                    PhraseStateMachine(self.prop_name),
                    PhraseStateMachine('"'),
                ]
            )

        nullable = self.prop_schema.get("nullable", False)
        is_optional = nullable or "default" in self.prop_schema
        parts = [
            key_machine,
            WhitespaceStateMachine(),
            PhraseStateMachine(":"),
            WhitespaceStateMachine(),
            _json_schema_to_state_machine(self.prop_schema, self.prop_context),
        ]
        super().__init__(parts, is_optional=is_optional)

    def get_new_stepper(self, state: StateId | None = None) -> KeyValueSchemaStepper:
        stepper = KeyValueSchemaStepper(self, state)
        return stepper
60 |
61 |
class KeyValueSchemaStepper(KeyValueStepper):
    """Stepper for KeyValueSchemaStateMachine; narrows the ``state_machine`` type."""

    def __init__(
        self,
        state_machine: KeyValueSchemaStateMachine,
        current_state: StateId | None = None,
    ):
        super().__init__(state_machine, current_state)
        # Re-annotated so type checkers see the schema-specific machine.
        self.state_machine: KeyValueSchemaStateMachine = state_machine
70 |
--------------------------------------------------------------------------------
/pse/types/json/json_number.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pse_core import StateId
4 | from pse_core.stepper import Stepper
5 |
6 | from pse.types.number import NumberStateMachine
7 |
8 |
class NumberSchemaStateMachine(NumberStateMachine):
    """
    Accept a JSON number that conforms to a JSON schema

    Checked constraints: ``minimum``, ``exclusiveMinimum``, ``maximum``,
    ``exclusiveMaximum``, ``multipleOf`` and ``"type": "integer"``.
    """

    def __init__(self, schema):
        super().__init__()
        self.schema = schema
        self.is_integer = schema["type"] == "integer"
        # True when the schema carries any numeric range or multipleOf constraint.
        self.requires_validation = any(
            constraint in schema
            for constraint in [
                "minimum",
                "exclusiveMinimum",
                "maximum",
                "exclusiveMaximum",
                "multipleOf",
            ]
        )

    def get_new_stepper(self, state: StateId | None = None) -> NumberSchemaStepper:
        return NumberSchemaStepper(self, state)

    def validate_value(self, value: float) -> bool:
        """
        Validate the number value according to the schema

        Values that are not ``int``/``float`` pass unchecked.
        """
        if not isinstance(value, int | float):
            return True

        if "minimum" in self.schema and value < self.schema["minimum"]:
            return False

        if (
            "exclusiveMinimum" in self.schema
            and value <= self.schema["exclusiveMinimum"]
        ):
            return False
        if "maximum" in self.schema and value > self.schema["maximum"]:
            return False
        if (
            "exclusiveMaximum" in self.schema
            and value >= self.schema["exclusiveMaximum"]
        ):
            return False
        if "multipleOf" in self.schema and not self._is_multiple_of(
            value, self.schema["multipleOf"]
        ):
            return False

        if self.is_integer and not (isinstance(value, int) or value.is_integer()):
            return False

        return True

    @staticmethod
    def _is_multiple_of(value: int | float, divisor: int | float) -> bool:
        """
        Return True if ``value`` is an exact multiple of ``divisor``.

        A plain float test (``value / divisor == value // divisor``) rejects valid
        decimal multiples, e.g. 0.3 with multipleOf 0.1, because
        0.3 / 0.1 == 2.9999999999999996. Integers are compared exactly. Floats
        are compared via their shortest decimal repr, which matches the JSON text.
        """
        import decimal
        import math

        if isinstance(value, int) and isinstance(divisor, int):
            return value % divisor == 0
        if not math.isfinite(value):
            # inf/nan are never multiples (the float test rejected them as well).
            return False
        try:
            return decimal.Decimal(str(value)) % decimal.Decimal(str(divisor)) == 0
        except decimal.InvalidOperation:
            # e.g. the integer quotient exceeds the decimal context precision;
            # fall back to the float comparison.
            return value / divisor == value // divisor

    def __str__(self) -> str:
        return "JSON" + super().__str__()
66 |
67 |
class NumberSchemaStepper(Stepper):
    """Stepper for NumberSchemaStateMachine that applies the schema's value checks."""

    def __init__(
        self,
        state_machine: NumberSchemaStateMachine,
        current_state: StateId | None = None,
    ):
        super().__init__(state_machine, current_state)
        self.state_machine: NumberSchemaStateMachine = state_machine

    def should_start_step(self, token: str) -> bool:
        # Integer schemas never step toward target state 3 (presumably the
        # fractional part of the number; confirm against NumberStateMachine).
        blocked = self.state_machine.is_integer and self.target_state == 3
        if blocked:
            return False
        return super().should_start_step(token)

    def should_complete_step(self) -> bool:
        """Complete only if the parent allows it and the current value satisfies the schema."""
        parent_allows = super().should_complete_step()
        if not parent_allows:
            return False
        return self.state_machine.validate_value(self.get_current_value())

    def has_reached_accept_state(self) -> bool:
        """Accept only if the parent accepts and the current value satisfies the schema."""
        if not super().has_reached_accept_state():
            return False
        return self.state_machine.validate_value(self.get_current_value())
96 |
--------------------------------------------------------------------------------
/pse/types/json/json_object.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from pse_core import StateId
6 | from pse_core.state_machine import StateMachine
7 | from pse_core.stepper import Stepper
8 |
9 | from pse.types.base.chain import ChainStateMachine
10 | from pse.types.base.phrase import PhraseStateMachine
11 | from pse.types.json.json_key_value import KeyValueSchemaStateMachine
12 | from pse.types.key_value import KeyValueStateMachine
13 | from pse.types.object import ObjectStateMachine
14 | from pse.types.whitespace import WhitespaceStateMachine
15 |
16 |
class ObjectSchemaStateMachine(ObjectStateMachine):
    """
    JSON object state machine constrained by a JSON schema.

    Reads ``properties``, ``required``, ``additionalProperties`` (default ``{}``,
    i.e. no extra properties), ``orderedProperties`` (default True) and
    ``nullable`` from the schema. Required properties that are nullable or
    have a default are treated as optional.

    Raises:
        ValueError: if a required property is missing from ``properties``.
    """

    def __init__(
        self,
        schema: dict[str, Any],
        context: dict[str, Any],
    ):
        self.schema = schema
        self.context = context
        self.properties: dict[str, Any] = schema.get("properties", {})
        # Copy the list: entries are removed below, and the caller's schema
        # (possibly reused elsewhere) must not be mutated.
        self.required_property_names: list[str] = list(schema.get("required", []))
        self.additional_properties: dict[str, Any] | bool = schema.get(
            "additionalProperties", {}
        )
        self.ordered_properties: bool = schema.get("orderedProperties", True)
        if any(prop not in self.properties for prop in self.required_property_names):
            raise ValueError("Required property not defined in schema")

        for property_name, property_schema in self.properties.items():
            if property_name in self.required_property_names and property_schema:
                if (
                    property_schema.get("nullable", False)
                    or "default" in property_schema
                ):
                    self.required_property_names.remove(property_name)

        super().__init__(schema.get("nullable", False))

    def get_transitions(self, stepper: Stepper) -> list[tuple[Stepper, StateId]]:
        """Retrieve transition steppers from the current state.

        State 2 offers the next allowed property key-value pair(s), leading to
        state 3. State 4 offers "}" once every required property is present.
        It offers "," (back to state 2) while schema properties remain or
        additional properties are allowed. Other states use the parent
        implementation.

        Returns:
            list[tuple[Stepper, StateId]]: A list of tuples representing transitions.
        """
        value = stepper.get_current_value()
        transitions: list[tuple[Stepper, StateId]] = []
        if stepper.current_state == 2:
            for property_sm in self.get_property_state_machines(value):
                for transition in property_sm.get_steppers():
                    transitions.append((transition, 3))

        elif stepper.current_state == 4:
            if all(prop_name in value for prop_name in self.required_property_names):
                for transition in PhraseStateMachine("}").get_steppers():
                    transitions.append((transition, "$"))

            if len(value) < len(self.properties) or self.additional_properties:
                for transition in ChainStateMachine(
                    [PhraseStateMachine(","), WhitespaceStateMachine()]
                ).get_steppers():
                    transitions.append((transition, 2))
        else:
            return super().get_transitions(stepper)

        return transitions

    def get_property_state_machines(self, value: dict[str, Any]) -> list[StateMachine]:
        """
        Return state machines for the properties that may come next.

        Covers schema properties not yet in ``value``; with ``ordered_properties``
        only the first such property is offered. Once all required properties
        are present and additional properties are allowed, a free-form key-value
        machine is added (schema-constrained when additionalProperties is a dict).
        """
        property_state_machines: list[StateMachine] = []
        for prop_name, prop_schema in self.properties.items():
            if prop_name not in value:
                property_sm = KeyValueSchemaStateMachine(
                    prop_name,
                    prop_schema,
                    self.context,
                )
                property_state_machines.append(property_sm)
                if self.ordered_properties:
                    break

        if (
            all(prop_name in value for prop_name in self.required_property_names)
            and self.additional_properties
        ):
            # non-schema kv property to represent the additional properties
            if isinstance(self.additional_properties, dict):
                extra_sm = KeyValueSchemaStateMachine(
                    None,
                    self.additional_properties,
                    self.context,
                )
            else:
                extra_sm = KeyValueStateMachine()
            property_state_machines.append(extra_sm)

        return property_state_machines

    # NOTE(review): defining __eq__ without __hash__ sets __hash__ to None for this
    # class (Python data model), so instances are unhashable. Confirm nothing hashes them.
    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, ObjectSchemaStateMachine)
            and super().__eq__(other)
            and self.schema == other.schema
        )

    def __str__(self) -> str:
        return "JSON" + super().__str__()
111 |
--------------------------------------------------------------------------------
/pse/types/json/json_string.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import re
5 |
6 | import regex
7 | from pse_core import StateId
8 |
9 | from pse.types.string import StringStateMachine, StringStepper
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
class StringSchemaStateMachine(StringStateMachine):
    """
    Accept a JSON string that conforms to a JSON schema, including 'pattern' and 'format' constraints.

    Raises:
        ValueError: if the schema's ``pattern`` does not compile, or its
            ``format`` is not one of SUPPORTED_FORMATS.
    """

    # Class-level constants
    EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
    SUPPORTED_FORMATS = frozenset(["email", "date-time", "uri"])

    def __init__(
        self,
        schema: dict,
    ):
        super().__init__(
            min_length=schema.get("minLength"),
            max_length=schema.get("maxLength"),
        )
        self.schema = schema or {}
        self.pattern: re.Pattern | None = None
        self.format: str | None = None

        if "pattern" in self.schema:
            try:
                self.pattern = re.compile(self.schema["pattern"])
            except re.error as exc:
                raise ValueError(f"Invalid pattern in schema: {exc}") from exc

        if "format" in self.schema:
            self.format = self.schema["format"]
            if self.format not in self.SUPPORTED_FORMATS:
                supported = ", ".join(self.SUPPORTED_FORMATS)
                raise ValueError(
                    f"Format '{self.format}' not supported. Supported formats: {supported}"
                )

    def get_new_stepper(self, state: StateId | None = None) -> StringSchemaStepper:
        stepper = StringSchemaStepper(self, state)
        return stepper

    def min_length(self) -> int:
        """The schema's ``minLength``, or 0 when absent."""
        return self.schema.get("minLength", 0)

    def max_length(self) -> int:
        """The schema's ``maxLength``, or 10000 when absent."""
        return self.schema.get("maxLength", 10000)

    def validate_email(self, value: str) -> bool:
        """True if the whole value matches EMAIL_PATTERN."""
        return self.EMAIL_PATTERN.fullmatch(value) is not None

    def validate_date_time(self, value: str) -> bool:
        """True if ``datetime.fromisoformat`` accepts the value."""
        from datetime import datetime

        try:
            datetime.fromisoformat(value)
        except ValueError:
            return False
        return True

    def validate_uri(self, value: str) -> bool:
        """
        Validate that the value is a valid URI.

        NOTE(review): ``urlparse`` fields are always ``str``, never ``None``, so
        this only returns False when ``urlparse`` raises ValueError. Tightening
        it may reject partial values, since the stepper also validates
        in-progress strings. Confirm the intended behaviour before changing it.
        """
        from urllib.parse import urlparse

        try:
            parsed = urlparse(value)
        except ValueError:
            return False
        return parsed.scheme is not None and parsed.netloc is not None

    def __str__(self) -> str:
        return "JSON" + super().__str__()
95 |
96 |
class StringSchemaStepper(StringStepper):
    """
    Stepper for StringSchemaStateMachine.

    While inside the string value, tokens are trimmed to the longest prefix
    that can still partially match the schema's ``pattern`` (if any), and
    ``validate_value`` checks maxLength, pattern and format constraints.
    """

    def __init__(
        self,
        state_machine: StringSchemaStateMachine,
        current_state: StateId | None = None,
    ):
        super().__init__(state_machine, current_state)
        self.state_machine: StringSchemaStateMachine = state_machine
        # Maps each supported schema "format" to its validator on the state machine.
        self._format_validators = {
            "email": self.state_machine.validate_email,
            "date-time": self.state_machine.validate_date_time,
            "uri": self.state_machine.validate_uri,
        }

    def should_start_step(self, token: str) -> bool:
        """
        Return True if the parent allows the step and, when inside the string
        value, the pattern-trimmed token still passes ``validate_value``.
        """
        if super().should_start_step(token):
            if self.is_within_value():
                valid_prefix = self.get_valid_prefix(token)
                # NOTE(review): validate_value also runs format validators on this
                # partial value, so formats such as "email" may reject incomplete
                # prefixes. Confirm this is intended.
                return self.validate_value(valid_prefix)
            return True

        return False

    def consume(self, token: str):
        """
        Consume the token and return the new stepper.

        Inside the string value only the pattern-valid prefix is consumed. The
        rest of the token is attached to each resulting stepper as
        ``remaining_input``. An empty or missing prefix yields no steppers.
        """
        if self.is_within_value():
            valid_prefix = self.get_valid_prefix(token)
            if not valid_prefix:
                return []
        else:
            valid_prefix = token

        steppers = super().consume(valid_prefix)
        for stepper in steppers:
            if token != valid_prefix:
                # NOTE(review): if the token began with '"', get_valid_prefix may
                # return a prefix of the quote-stripped text, so this slice of the
                # original token would be off by one. Confirm.
                stepper.remaining_input = token[len(valid_prefix) :]

        return steppers

    def clean_value(self, value: str) -> str:
        """
        Clean and normalize the input value by removing bounding quotes.

        A leading '"' is dropped. If the result ends with '"', everything from
        the first remaining '"' onward is dropped.

        Args:
            value: The string value to clean.

        Returns:
            str: The cleaned string with bounding quotes removed.
        """
        if value.startswith('"'):
            value = value[1:]
        if value.endswith('"'):
            # NOTE(review): index() finds the FIRST quote, which may be an escaped
            # \" inside the string rather than the closing quote.
            first_quote = value.index('"')
            value = value[: first_quote]
        return value

    def get_valid_prefix(self, s: str) -> str | None:
        """
        Check whether the string 's' can be a prefix of any string matching the pattern.
        Uses binary search for efficiency.

        Binary search is sound because a partial regex match of a string also
        partially matches every shorter prefix of it.

        Returns:
            ``s`` unchanged when not inside the value, when there is no pattern,
            or when there is no sub-stepper. Also ``s`` when the whole
            quote-stripped token is valid. Otherwise the longest valid prefix of
            the quote-stripped token, or None if no prefix is valid.
        """
        if (
            not self.is_within_value()
            or not self.state_machine.pattern
            or not self.sub_stepper
        ):
            return s

        current_value = self.sub_stepper.get_raw_value()
        quotes_removed_s = self.clean_value(s)

        left, right = 0, len(quotes_removed_s)
        best_match = None

        while left <= right:
            mid = (left + right) // 2
            working_s = quotes_removed_s[:mid]
            # The third-party `regex` module supports partial=True, which accepts
            # strings that could still be extended into a full match.
            match = regex.match(
                self.state_machine.pattern.pattern,
                current_value + working_s,
                partial=True,
            )
            if match:
                best_match = working_s
                left = mid + 1  # Try a longer prefix
            else:
                right = mid - 1  # Try a shorter prefix

        if best_match is not None:
            if best_match == quotes_removed_s:
                return s  # Return original if the whole string is a valid prefix
            return best_match

        return None

    def validate_value(self, value: str | None = None) -> bool:
        """
        Validate the string value according to the schema.

        Args:
            value: Optional string to append to current value before validation

        Returns:
            bool: True if the value meets all schema constraints

        Raises:
            ValueError: if the machine's format has no registered validator.

        Note:
            Validates length, pattern, and format constraints in sequence. An
            empty cleaned value is rejected. The full pattern match is only
            checked once the stepper is no longer inside the value.
        """
        value = self.clean_value(self.get_raw_value() + (value or ""))
        if not value:
            return False

        # Length validation
        if len(value) > self.state_machine.max_length():
            return False

        # Pattern validation
        if not self.is_within_value() and self.state_machine.pattern:
            if not self.state_machine.pattern.match(value):
                return False

        # Format validation
        if self.state_machine.format:
            validator = self._format_validators.get(self.state_machine.format)
            if not validator:
                raise ValueError(
                    f"No validator found for format: {self.state_machine.format}"
                )
            if not validator(value):
                return False

        return True
231 |
--------------------------------------------------------------------------------
/pse/types/json/json_value.py:
--------------------------------------------------------------------------------
1 | from pse_core import Edge, StateId
2 | from pse_core.state_machine import StateMachine
3 | from pse_core.stepper import Stepper
4 |
5 |
class JsonStateMachine(StateMachine):
    """
    Accepts any single JSON value: object, array, string, null, boolean or number.

    State 0 has one edge per value kind, each leading to the end state "$".
    Every other state has no edges.
    """

    def get_edges(self, state: StateId) -> list[Edge]:
        if state != 0:
            return []

        # Imported lazily (presumably to avoid circular imports between the type modules).
        from pse.types.array import ArrayStateMachine
        from pse.types.base.phrase import PhraseStateMachine
        from pse.types.boolean import BooleanStateMachine
        from pse.types.number import NumberStateMachine
        from pse.types.object import ObjectStateMachine
        from pse.types.string import StringStateMachine

        value_machines = [
            ObjectStateMachine(),
            ArrayStateMachine(),
            StringStateMachine(),
            PhraseStateMachine("null"),
            BooleanStateMachine(),
            NumberStateMachine(),
        ]
        return [(machine, "$") for machine in value_machines]

    def get_steppers(self, state: StateId | None = None) -> list[Stepper]:
        """Collect the steppers of every edge from ``state`` (a falsy state means 0)."""
        start_state = state or 0
        return [
            stepper
            for machine, _target in self.get_edges(start_state)
            for stepper in machine.get_steppers()
        ]
31 |
--------------------------------------------------------------------------------
/pse/types/json/schema_sources/from_function.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import inspect
3 | import logging
4 | from collections.abc import Callable
5 | from typing import Any, get_args, get_origin
6 |
7 | from docstring_parser import Docstring, DocstringParam, parse
8 | from pydantic import BaseModel
9 |
10 | from pse.types.json.schema_sources.from_pydantic import pydantic_to_schema
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
def callable_to_schema(function: Callable) -> dict[str, Any]:
    """
    Generate a schema for the specified Python function.

    This takes a callable and parses its signature and docstring,
    and constructs a schema representing the function's parameters.

    Rules applied:
    - A ``self`` parameter is skipped.
    - Docstring entries are matched to parameters by name, with a positional
      fallback for entries whose names match no parameter.
    - A parameter is required when it has no default, is not nullable, and is
      not ``*args``/``**kwargs``.
    - If nothing is required, ``required`` is removed and the parameters
      object is marked nullable.

    Args:
        function (Callable): The Python function to generate a schema for.

    Returns:
        dict[str, Any]: A dictionary representing the JSON schema of the function's parameters.
    """
    sig = inspect.signature(function)
    docstring = parse(function.__doc__ or "No docstring provided")

    signature_names = set(sig.parameters)
    # Index docstring entries by name (first entry wins). Purely positional
    # matching attaches the wrong description when the docstring skips or
    # reorders parameters.
    documented = {}
    for doc_param in docstring.params:
        documented.setdefault(doc_param.arg_name, doc_param)
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    schema: dict[str, Any] = {
        "name": function.__name__,
        "description": docstring.description,
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    }

    param_index = 0
    for param in sig.parameters.values():
        if param.name == "self":
            continue  # skip 'self' parameter

        param_docstring = documented.get(param.name)
        if param_docstring is None and param_index < len(docstring.params):
            # Positional fallback for entries whose names don't match the signature.
            # Never borrow an entry that documents a different real parameter.
            candidate = docstring.params[param_index]
            if candidate.arg_name not in signature_names:
                param_docstring = candidate

        param_schema = parameter_to_schema(param, param_docstring, docstring)
        schema["parameters"]["properties"][param.name] = param_schema

        # *args / **kwargs can always be omitted by the caller.
        is_required = (
            param.kind not in variadic_kinds
            and param.default is inspect.Parameter.empty
            and param_schema.get("nullable", False) is False
        )
        if is_required:
            schema["parameters"]["required"].append(param.name)

        param_index += 1

    # Handle the case when all parameters are nullable with defaults
    if not schema["parameters"]["required"]:
        del schema["parameters"]["required"]
        schema["parameters"]["nullable"] = True

    return schema
65 |
66 |
def parameter_to_schema(
    param: inspect.Parameter,
    param_docstring: DocstringParam | None,
    docstring: Docstring
) -> dict[str, Any]:
    """
    Generate a schema for a function parameter.

    Args:
        param (inspect.Parameter): The parameter to generate a schema for.
        param_docstring (DocstringParam | None): The docstring entry documenting
            this parameter, if any.
        docstring (Docstring): The docstring for the function; its short
            description is the fallback when the parameter entry has none.

    Returns:
        A schema dict:
        - Pydantic ``BaseModel`` annotations return the model's schema.
        - Unannotated parameters get ``"type": "any"``.
        - A ``None`` default marks the schema nullable; any other default is
          recorded under ``"default"``.
        - For a union, each non-None member contributes one entry. With
          several entries, ``"type"`` holds the list of member schemas.
    """
    parameter_schema: dict[str, Any] = {}
    if param_docstring:
        parameter_schema["description"] = param_docstring.description or docstring.short_description or ""

    if inspect.isclass(param.annotation) and issubclass(param.annotation, BaseModel):
        # Use Pydantic model if the parameter is a BaseModel subclass
        return pydantic_to_schema(param.annotation)
    elif param.annotation is inspect.Parameter.empty:
        logger.warning(f"Parameter '{param.name}' lacks type annotation.")
        parameter_schema["type"] = "any"
        return parameter_schema
    elif param.default is not inspect.Parameter.empty:
        default_value = param.default
        if default_value is None:
            parameter_schema["nullable"] = True
        else:
            parameter_schema["default"] = default_value

    origin = get_origin(param.annotation)

    # Special handling for direct dict type
    if origin is dict:
        dict_schema: dict[str, Any] = {
            "type": "object",
            "additionalProperties": {"type": "any"}
        }
        args = get_args(param.annotation)
        if len(args) > 1:
            value_type = args[1]
            dict_schema["additionalProperties"] = {
                "type": get_type(value_type)
            }
        # Preserve the description if it exists
        if param_docstring:
            dict_schema["description"] = param_docstring.description or ""
        # Keep the default/nullable information computed above.
        for key in ("nullable", "default"):
            if key in parameter_schema:
                dict_schema[key] = parameter_schema[key]
        return dict_schema

    # Only the members of e.g. a union are alternatives. A directly
    # parameterised container such as list[int] must be classified as the
    # container (its element type is handled in the "array"/"set" branch), not
    # as its element type.
    if origin in (list, set):
        candidate_types = [param.annotation]
    else:
        candidate_types = list(get_args(param.annotation)) or [param.annotation]

    # Process union types and other types.
    parameter_type_schemas: list[dict[str, Any]] = []
    for argument in candidate_types:
        parameter_type_schema: dict[str, Any] = {}
        arg_origin = get_origin(argument)
        parameter_type = get_type(argument)

        if arg_origin is dict:
            parameter_type_schema["type"] = "object"
            args = get_args(argument)
            if len(args) > 1:
                parameter_type_schema["additionalProperties"] = {
                    "type": get_type(args[1])
                }
        elif parameter_type == "null":
            parameter_schema["nullable"] = True
            continue  # None only makes the parameter nullable; no type entry.
        elif parameter_type in ("array", "set"):
            parameter_type_schema["type"] = parameter_type
            if args := get_args(argument):
                parameter_type_schema["items"] = {"type": get_type(args[0])}
        elif parameter_type == "enum" and issubclass(argument, enum.Enum):
            parameter_type_schema["enum"] = [member.value for member in argument]
        elif parameter_type:
            parameter_type_schema["type"] = parameter_type

        if parameter_type_schema:
            parameter_type_schemas.append(parameter_type_schema)

    if len(parameter_type_schemas) > 1:
        parameter_schema["type"] = parameter_type_schemas
    elif parameter_type_schemas:
        parameter_schema.update(parameter_type_schemas[0])
    else:
        # No concrete type was found (e.g. a bare None annotation).
        parameter_schema["type"] = "any"

    return parameter_schema
169 |
def get_type(python_type: Any) -> str:
    """
    Map a Python type (or type hint) to a JSON schema type name.

    Rules, in order:
    - Parametrized generics are resolved through their origin
      (``list[int]`` -> ``"array"``).
    - Pydantic model classes, including subclasses of ``BaseModel``, map to
      ``"object"``.
    - A value that is not itself a mapped type is mapped by its own type
      (e.g. an enum class via its metaclass -> ``"enum"``, the literal
      ``"a"`` -> ``"string"``).
    - Anything else logs a warning and maps to ``"any"``.
    """
    if python_type is type(None):
        return "null"

    type_map: dict[type | Any, str] = {
        int: "integer",
        str: "string",
        bool: "boolean",
        float: "number",
        list: "array",
        dict: "object",
        tuple: "array",
        set: "set",
        enum.EnumType: "enum",
        type(None): "null",
        BaseModel: "object",
        Any: "any",
    }
    type_name = get_origin(python_type) or python_type
    if type_name in type_map:
        return type_map[type_name]

    # Model subclasses are not keys of the map, but they are still objects.
    if inspect.isclass(type_name) and issubclass(type_name, BaseModel):
        return "object"

    value_type = type(python_type)
    if value_type in type_map:
        return type_map[value_type]

    logger.warning(f"Unknown type: {python_type}")
    return "any"
198 |
--------------------------------------------------------------------------------
/pse/types/json/schema_sources/from_pydantic.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any
3 |
4 | from docstring_parser import parse
5 | from pydantic import BaseModel
6 |
# Module-scoped logger for the Pydantic schema source.
logger: logging.Logger = logging.getLogger(name=__name__)
8 |
9 |
def pydantic_to_schema(model: type[BaseModel]) -> dict[str, Any]:
    """
    Convert a Pydantic model class to a standardized schema format.

    The model's generated JSON schema is the base. Each field's description
    comes from ``Field(description=...)``, falling back to the matching
    parameter entry in the model docstring. Required fields are listed under
    ``"required"``. They use the same property keys as ``"properties"``
    (aliases when the schema is keyed by alias), in a deterministic order.

    Args:
        model (type[BaseModel]): The Pydantic model class to convert.

    Returns:
        dict[str, Any]: A dictionary with ``"name"`` and ``"description"``
        followed by the remaining top-level keys of the model's JSON schema.
    """
    # Get base schema from Pydantic model
    schema = model.model_json_schema()

    # Extract docstring info
    docstring = parse(model.__doc__ or "")
    docstring_params = {
        param.arg_name: param.description
        for param in docstring.params
        if param.description
    }

    # Get description from schema or docstring
    description = schema.get("description") or docstring.short_description or ""

    # Extract parameters, excluding metadata fields
    parameters = {k: v for k, v in schema.items() if k not in {"title", "description"}}

    properties = parameters.get("properties", {})
    assert isinstance(properties, dict)

    # Keep the order produced by the schema and append extras after it; a set
    # would reorder the list between interpreter runs (string hash randomization).
    required_fields: list[str] = list(parameters.get("required", []))

    for field_name, field in model.model_fields.items():
        # model_json_schema() keys properties by alias when one is set, so
        # resolve the key that is actually present in "properties".
        schema_key = field_name
        for candidate in (field.validation_alias, field.alias):
            if isinstance(candidate, str) and candidate in properties:
                schema_key = candidate
                break

        if field.is_required() and schema_key not in required_fields:
            required_fields.append(schema_key)

        field_schema: dict[str, Any] = properties.get(schema_key, {})

        # Set description from field or docstring
        field_schema["description"] = field.description or docstring_params.get(
            field_name, ""
        )

        # Merge dict-valued json_schema_extra into the field schema.
        if isinstance(field.json_schema_extra, dict):
            field_schema.update(field.json_schema_extra)

    parameters["required"] = required_fields

    return {
        "name": schema.get("title", model.__name__),
        "description": description,
        **parameters,
    }
66 |
--------------------------------------------------------------------------------
/pse/types/key_value.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import logging
5 | from typing import Any
6 |
7 | from pse_core import StateId
8 | from pse_core.state_machine import StateMachine
9 | from pse_core.stepper import Stepper
10 |
11 | from pse.types.base.chain import ChainStateMachine
12 | from pse.types.base.phrase import PhraseStateMachine
13 | from pse.types.string import StringStateMachine
14 | from pse.types.whitespace import WhitespaceStateMachine
15 |
# Module-scoped logger (not the root logger), consistent with the other pse
# modules, so this module's output can be configured independently.
logger = logging.getLogger(__name__)
17 |
18 |
class KeyValueStateMachine(ChainStateMachine):
    """
    Chain state machine for a single ``key: value`` pair.

    The default chain is: a JSON string key, whitespace, ``":"``, whitespace,
    and a JSON value. A non-empty ``sequence`` replaces that chain entirely;
    ``None`` or an empty list selects the default.
    """

    def __init__(self, sequence: list[StateMachine] | None = None, is_optional: bool = False) -> None:
        # Local import; presumably avoids an import cycle with the JSON types.
        from pse.types.json.json_value import JsonStateMachine

        if not sequence:
            sequence = [
                StringStateMachine(),
                WhitespaceStateMachine(),
                PhraseStateMachine(":"),
                WhitespaceStateMachine(),
                JsonStateMachine(),
            ]
        super().__init__(sequence, is_optional=is_optional)

    def get_new_stepper(self, state: StateId | None = None) -> KeyValueStepper:
        """Create a KeyValueStepper for this machine, optionally at ``state``."""
        return KeyValueStepper(self, state)

    def __str__(self) -> str:
        return "KeyValue"
40 |
41 |
class KeyValueStepper(Stepper):
    """
    Stepper for KeyValueStateMachine that records the decoded key and value.
    """

    def __init__(
        self,
        state_machine: KeyValueStateMachine,
        current_step_id: StateId | None = None,
    ) -> None:
        super().__init__(state_machine, current_step_id)
        self.prop_name = ""
        self.prop_value: Any | None = None

    def clone(self) -> KeyValueStepper:
        """Clone the stepper, carrying over the parsed key and value."""
        cloned_stepper = super().clone()
        cloned_stepper.prop_name = self.prop_name
        cloned_stepper.prop_value = self.prop_value
        return cloned_stepper

    def should_complete_step(self) -> bool:
        """
        Handle the completion of a transition by setting the property name and value.

        When the transition targets state 1, the sub-stepper's raw text is
        decoded as the key. When it targets an end state, the raw text is
        decoded as the value.

        Returns:
            bool: False if the parent rejects the step, there is no
            sub-stepper, or the raw text cannot be decoded; True otherwise.
        """
        if not super().should_complete_step() or not self.sub_stepper:
            return False

        try:
            if self.target_state == 1:
                self.prop_name = json.loads(self.sub_stepper.get_raw_value())
            elif self.target_state in self.state_machine.end_states:
                self.prop_value = json.loads(self.sub_stepper.get_raw_value())
        except Exception:
            # Deliberate best-effort: undecodable text rejects this step.
            return False

        return True

    def get_current_value(self) -> tuple[str, Any]:
        """
        Get the parsed property as a key-value pair.

        The empty string is a valid JSON key, so its value is returned rather
        than discarded. Before anything is parsed this is ``("", None)``,
        because both attributes still hold their initial values.

        Returns:
            tuple[str, Any]: The property name and its corresponding value.
        """
        return (self.prop_name, self.prop_value)
88 |
--------------------------------------------------------------------------------
/pse/types/misc/fenced_freeform.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pse_core import StateId
4 |
5 | from pse.types.base.character import CharacterStateMachine
6 | from pse.types.base.encapsulated import EncapsulatedStateMachine, EncapsulatedStepper
7 |
8 |
class FencedFreeformStateMachine(EncapsulatedStateMachine):
    """
    Freeform text enclosed between an opening and a closing delimiter.

    Without explicit delimiters, a Markdown-style fence is used. It opens
    with three backticks, the optional identifier and a newline, and closes
    with a newline and three backticks. Inside the fence, the first character
    of the opening delimiter is blacklisted and the characters of the closing
    delimiter are graylisted.
    """

    def __init__(self,
                 identifier: str | None = None,
                 delimiter: tuple[str, str] | None = None,
                 buffer_length: int = -1,
                 char_min: int = 1,
                 char_max: int = -1,
                 is_optional: bool = False):
        if delimiter is None:
            delimiter = (f'```{identifier or ""}\n', '\n```')

        opening_fence = delimiter[0]
        closing_fence = delimiter[1]
        content_machine = CharacterStateMachine(
            whitelist_charset="",
            graylist_charset=set(closing_fence),
            blacklist_charset=opening_fence[0],
            char_min=char_min,
            char_limit=char_max,
        )
        super().__init__(content_machine, delimiter, buffer_length, is_optional)
        self.identifier = identifier

    def get_new_stepper(self, state: StateId | None = None) -> FencedFreeformStepper:
        """Create a FencedFreeformStepper for this machine, optionally at ``state``."""
        return FencedFreeformStepper(self, state)
36 |
class FencedFreeformStepper(EncapsulatedStepper):
    """Stepper for FencedFreeformStateMachine."""

    def __init__(
        self,
        state_machine: FencedFreeformStateMachine,
        state: StateId | None = None,
    ) -> None:
        super().__init__(state_machine, state)
        self.state_machine: FencedFreeformStateMachine = state_machine

    def get_identifier(self) -> str | None:
        """Return the identifier the owning state machine was created with."""
        return self.state_machine.identifier

    def get_invalid_continuations(self) -> list[str]:
        """
        Return continuations that must not be produced at this point.

        While there is no inner stepper, only the closing delimiter is reported.
        Otherwise the parent class decides.
        """
        if self.inner_stepper:
            return super().get_invalid_continuations()
        closing_delimiter = self.state_machine.delimiters[1]
        return [closing_delimiter]
54 |
--------------------------------------------------------------------------------
/pse/types/misc/freeform.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from typing import Any
5 |
6 | from pse_core import StateId
7 |
8 | from pse.types.base.wait_for import WaitFor, WaitForStepper
9 | from pse.types.enum import EnumStateMachine
10 |
11 |
class FreeformStateMachine(WaitFor):
    """
    Freeform text that ends at one of several end delimiters.

    The delimiters are matched by an unquoted EnumStateMachine that this
    machine waits for.
    """

    def __init__(
        self,
        end_delimiters: list[str],
        char_min: int | None = None,
    ):
        self.end_delimiters = end_delimiters
        # buffer_length is char_min, falling back to 1 when it is None or 0.
        minimum_buffer = char_min if char_min else 1
        super().__init__(
            EnumStateMachine(self.end_delimiters, require_quotes=False),
            buffer_length=minimum_buffer,
        )

    def get_new_stepper(self, _: StateId | None = None) -> FreeformStepper:
        """Create a FreeformStepper; the requested state is ignored."""
        return FreeformStepper(self)

    def __str__(self) -> str:
        return "FreeformText"
34 |
class FreeformStepper(WaitForStepper):
    """Stepper for FreeformStateMachine."""

    def __init__(
        self,
        state_machine: FreeformStateMachine,
    ):
        super().__init__(state_machine)
        self.state_machine: FreeformStateMachine = state_machine

    def get_raw_value(self) -> str:
        """
        Get the raw value: the buffer followed by the delimiter text seen so far.

        The delimiter text comes from the active sub-stepper if there is one,
        otherwise from the most recent history entry. With neither, only the
        buffer is returned.
        """
        if self.sub_stepper:
            return self.buffer + self.sub_stepper.get_raw_value()
        if self.history:
            return self.buffer + self.history[-1].get_raw_value()
        return self.buffer

    def get_current_value(self) -> Any:
        """Return the buffer only (unlike get_raw_value, nothing is appended)."""
        return self.buffer

    def get_token_safe_output(
        self,
        decode_function: Callable[[list[int]], str],
    ) -> str:
        """
        Get the token safe output of the buffer with a trailing end delimiter removed.

        Delimiters are checked in ``end_delimiters`` order and only the first
        one the output ends with is stripped. Empty delimiters are skipped:
        every string ends with "", and ``output[:-0]`` would wipe the output.
        """
        safe_output = super().get_token_safe_output(decode_function)
        for end_delimiter in self.state_machine.end_delimiters:
            if end_delimiter and safe_output.endswith(end_delimiter):
                return safe_output[:-len(end_delimiter)]

        return safe_output
72 |
--------------------------------------------------------------------------------
/pse/types/number.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 |
5 | from pse_core import Edge, StateId
6 | from pse_core.state_machine import StateMachine
7 |
8 | from pse.types.base.chain import ChainStateMachine
9 | from pse.types.base.character import CharacterStateMachine
10 | from pse.types.base.phrase import PhraseStateMachine
11 | from pse.types.integer import IntegerStateMachine
12 |
# Module-scoped logger for the number state machine.
logger: logging.Logger = logging.getLogger(name=__name__)
14 |
15 |
16 | class NumberStateMachine(StateMachine):
17 | """
18 | Accepts a well-formed JSON number.
19 |
20 | This state_machine defines the state transitions for parsing JSON numbers, handling integer,
21 | decimal, and exponential formats as specified by the JSON standard.
22 | """
23 |
24 | def __init__(self):
25 | super().__init__(
26 | {
27 | 0: [
28 | (PhraseStateMachine("-", is_optional=True), 1),
29 | ],
30 | 1: [
31 | (IntegerStateMachine(), 2),
32 | ],
33 | 2: [
34 | (
35 | ChainStateMachine(
36 | [
37 | PhraseStateMachine("."),
38 | IntegerStateMachine(drop_leading_zeros=False),
39 | ],
40 | ),
41 | 3,
42 | ),
43 | ],
44 | 3: [
45 | (CharacterStateMachine("eE", char_limit=1), 4),
46 | ],
47 | 4: [
48 | (CharacterStateMachine("+-", char_limit=1), 5),
49 | ],
50 | 5: [
51 | (IntegerStateMachine(), "$"),
52 | ],
53 | },
54 | end_states=[2, 3, "$"],
55 | )
56 |
57 | def get_edges(self, state: StateId) -> list[Edge]:
58 | """
59 | Get the edges for a given state.
60 | """
61 | if state == 2:
62 | return [*super().get_edges(state), *super().get_edges(3)]
63 | elif state == 4:
64 | return [*super().get_edges(5), *super().get_edges(state)]
65 | else:
66 | return [*super().get_edges(state)]
67 |
68 | def __str__(self) -> str:
69 | return "Number"
70 |
--------------------------------------------------------------------------------
/pse/types/object.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import Any
5 |
6 | from pse_core import StateId
7 | from pse_core.state_machine import StateMachine
8 | from pse_core.stepper import Stepper
9 |
10 | from pse.types.base.chain import ChainStateMachine
11 | from pse.types.base.phrase import PhraseStateMachine
12 | from pse.types.key_value import KeyValueStateMachine
13 | from pse.types.whitespace import WhitespaceStateMachine
14 |
# Module-scoped logger (not the root logger), consistent with the other pse
# modules, so this module's output can be configured independently.
logger = logging.getLogger(__name__)
16 |
17 |
class ObjectStateMachine(StateMachine):
    """
    Accepts a well-formed JSON object and manages state transitions during parsing.

    Transitions:
        0 -> 1: "{"
        1 -> 2: whitespace
        2 -> 3: a key-value pair
        3 -> 4: whitespace
        4 -> 2: "," followed by whitespace (another pair)
        4 -> $: "}"

    When ``is_optional`` is set, ``get_transitions`` also allows "}" straight
    from state 1, i.e. the empty object.
    NOTE(review): without ``is_optional``, state 2 always requires a
    key-value pair, so "{}" is rejected. Confirm this is intended.
    """

    def __init__(self, is_optional: bool = False) -> None:
        """Build the transition graph for parsing JSON objects."""
        # Sub-machines are built in the same order as the original table.
        open_brace = PhraseStateMachine("{")
        leading_space = WhitespaceStateMachine()
        pair = KeyValueStateMachine()
        trailing_space = WhitespaceStateMachine()
        separator = ChainStateMachine(
            [PhraseStateMachine(","), WhitespaceStateMachine()]
        )
        close_brace = PhraseStateMachine("}")

        super().__init__(
            {
                0: [(open_brace, 1)],
                1: [(leading_space, 2)],
                2: [(pair, 3)],
                3: [(trailing_space, 4)],
                4: [
                    (separator, 2),
                    (close_brace, "$"),  # End of object
                ],
            },
            is_optional=is_optional,
        )

    def get_new_stepper(self, state: StateId | None = None) -> ObjectStepper:
        """Create an ObjectStepper for this machine, optionally at ``state``."""
        return ObjectStepper(self, state)

    def get_transitions(self, stepper: Stepper) -> list[tuple[Stepper, StateId]]:
        """Return the base transitions plus, for optional objects in state 1, a closing "}"."""
        transitions = super().get_transitions(stepper)
        if stepper.current_state == 1 and self.is_optional:
            transitions.extend(
                (closing, "$") for closing in PhraseStateMachine("}").get_steppers()
            )
        return transitions

    def __str__(self) -> str:
        return "Object"
70 |
71 |
class ObjectStepper(Stepper):
    """Stepper for ObjectStateMachine that accumulates parsed key-value pairs."""

    def __init__(
        self, state_machine: ObjectStateMachine, current_state: StateId | None = None
    ) -> None:
        super().__init__(state_machine, current_state)
        self.value: dict[str, Any] = {}

    def clone(self) -> ObjectStepper:
        """Clone the stepper with a shallow copy of the accumulated object."""
        cloned_stepper = super().clone()
        # Copy so the clone's additions do not leak into this stepper.
        cloned_stepper.value = self.value.copy()
        return cloned_stepper

    def add_to_history(self, stepper: Stepper) -> None:
        """
        Record a completed sub-stepper.

        In state 3 the sub-stepper's ``(name, value)`` pair is stored in
        ``self.value``. The key-value edge targets state 3; this presumably
        relies on current_state already having advanced (TODO confirm in pse_core).
        """
        if self.current_state == 3:
            prop_name, prop_value = stepper.get_current_value()
            # Only build the message when debug logging is enabled: formatting
            # self.value on every property costs O(size of the object).
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"🟢 Adding {prop_name}: {prop_value} to {self.value}")
            self.value[prop_name] = prop_value
        super().add_to_history(stepper)

    def get_current_value(self) -> dict[str, Any]:
        """
        Get the current parsed JSON object.

        Returns:
            dict[str, Any]: The accumulated key-value pairs, or a new empty
            dict when nothing has been consumed yet.
        """
        if not self.get_raw_value():
            return {}
        return self.value
101 |
--------------------------------------------------------------------------------
/pse/types/string.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pse_core import StateId
4 | from pse_core.state_machine import StateMachine
5 | from pse_core.stepper import Stepper
6 |
7 | from pse.types.base.character import CharacterStateMachine
8 | from pse.types.base.phrase import PhraseStateMachine
9 |
# Characters that may not appear unescaped inside a JSON string: the control
# characters U+0000 through U+001F, the double quote, and the backslash.
INVALID_CHARS: set[str] = set(map(chr, range(0x20))) | {'"', "\\"}
11 |
12 |
class StringStateMachine(StateMachine):
    """
    Accepts a well-formed JSON string.

    The length of the string is measured excluding the surrounding quotation marks.

    NOTE(review): the regular-character machine loops back to STRING_CONTENTS.
    The closing quote is also allowed directly from STRING_CONTENTS. So
    ``min_length`` / ``max_length`` appear to bound each run of unescaped
    characters rather than the whole string. Confirm against CharacterStateMachine.
    """

    # StateId constants
    STRING_CONTENTS = 1
    ESCAPED_SEQUENCE = 2
    HEX_CODE = 3

    def __init__(self, min_length: int | None = None, max_length: int | None = None):
        """
        Configure the machine to parse JSON strings, including the escape
        sequences \\" \\\\ \\/ \\b \\f \\n \\r \\t and \\uXXXX.
        """
        # Sub-machines are created in the same order as the transition table.
        opening_quote = PhraseStateMachine('"')
        regular_chars = CharacterStateMachine(
            blacklist_charset=INVALID_CHARS,
            char_min=min_length,
            char_limit=max_length,
        )
        closing_quote = PhraseStateMachine('"')
        backslash = PhraseStateMachine("\\")
        escaped_char = CharacterStateMachine('"\\/bfnrt', char_limit=1)
        unicode_prefix = PhraseStateMachine("u")
        hex_digits = CharacterStateMachine(
            "0123456789ABCDEFabcdef",
            char_min=4,
            char_limit=4,
        )

        transitions = {
            0: [(opening_quote, self.STRING_CONTENTS)],
            self.STRING_CONTENTS: [
                (regular_chars, self.STRING_CONTENTS),  # Regular characters
                (closing_quote, "$"),  # End quote
                (backslash, self.ESCAPED_SEQUENCE),  # Escape character
            ],
            self.ESCAPED_SEQUENCE: [
                (escaped_char, self.STRING_CONTENTS),  # Escaped characters
                (unicode_prefix, self.HEX_CODE),  # Unicode escape sequence
            ],
            self.HEX_CODE: [(hex_digits, self.STRING_CONTENTS)],
        }
        super().__init__(transitions)

    def get_new_stepper(self, state: int | str | None = None) -> Stepper:
        """Create a StringStepper for this machine, optionally at ``state``."""
        return StringStepper(self, state)

    def __str__(self) -> str:
        return "String"
75 |
76 |
class StringStepper(Stepper):
    """Stepper for StringStateMachine."""

    def __init__(
        self, state_machine: StringStateMachine, current_state: StateId | None = None
    ) -> None:
        super().__init__(state_machine, current_state)
        self.state_machine: StringStateMachine = state_machine

    def is_within_value(self) -> bool:
        """
        Determines if the stepper is currently within the string value.

        That means it has left state 0 (the opening quote was consumed) and
        its target state is not an end state (the closing quote is not being
        consumed).
        """
        if self.current_state == 0:
            return False
        return self.target_state not in self.state_machine.end_states
89 |
--------------------------------------------------------------------------------
/pse/types/whitespace.py:
--------------------------------------------------------------------------------
1 | """Whitespace state machine for parsing optional whitespace in structured data.
2 |
3 | This module provides a state machine for recognizing and parsing whitespace
4 | characters in structured data formats like JSON.
5 | """
6 |
7 | from __future__ import annotations
8 |
9 | from pse.types.base.character import CharacterStateMachine
10 |
# Insignificant whitespace allowed by the JSON grammar (RFC 8259):
# space, horizontal tab, line feed, carriage return.
WHITESPACE_CHARS = "\x20\t\n\r"
13 |
14 |
class WhitespaceStateMachine(CharacterStateMachine):
    """Character state machine over the JSON whitespace characters."""

    def __init__(self, min_whitespace: int = 0, max_whitespace: int = 20):
        """Initialize the whitespace state machine with configurable limits.

        Args:
            min_whitespace: Minimum allowable whitespace characters.
                Defaults to 0. When it is 0 the machine is marked optional.
            max_whitespace: Maximum allowable whitespace characters.
                Defaults to 20.
        """
        allows_none = min_whitespace == 0
        super().__init__(
            WHITESPACE_CHARS,
            char_min=min_whitespace,
            char_limit=max_whitespace,
            is_optional=allows_none,
        )

    def __str__(self) -> str:
        """Return a string representation of this state machine."""
        return "Whitespace"
37 |
--------------------------------------------------------------------------------
/pse/types/xml/xml_encapsulated.py:
--------------------------------------------------------------------------------
1 | from pse_core import StateId
2 | from pse_core.state_machine import StateMachine
3 |
4 | from pse.types.base.encapsulated import EncapsulatedStepper
5 | from pse.types.base.wait_for import WaitFor
6 | from pse.types.xml.xml_tag import XMLTagStateMachine
7 |
8 |
class XMLEncapsulatedStateMachine(StateMachine):
    """
    A state machine that wraps a state machine in XML tags.

    It waits for ``<tag_name>``, then runs the inner state machine, then
    requires ``</tag_name>``.
    """

    def __init__(
        self,
        state_machine: StateMachine,
        tag_name: str,
        min_buffer_length: int = -1,
        is_optional: bool = False,
    ) -> None:
        """
        Args:
            state_machine: The state_machine wrapped by this state machine.
            tag_name: The name of the tag to wrap the state machine in.
            min_buffer_length: Passed as ``buffer_length`` to the WaitFor
                that scans for the opening tag.
            is_optional: Whether this state machine is optional.
        """
        self.inner_state_machine = state_machine
        # The closing delimiter carries the "</" prefix, like a real XML end tag.
        self.xml_delimiters = (f"<{tag_name}>", f"</{tag_name}>")
        super().__init__(
            {
                0: [
                    (
                        WaitFor(
                            XMLTagStateMachine(tag_name),
                            buffer_length=min_buffer_length,
                        ),
                        1,
                    ),
                ],
                1: [(state_machine, 2)],
                2: [(XMLTagStateMachine(tag_name, closing_tag=True), "$")],
            },
            is_optional=is_optional,
        )

    def get_new_stepper(self, state: StateId | None = None) -> EncapsulatedStepper:
        """Create a stepper for this machine, optionally at ``state``."""
        # NOTE(review): EncapsulatedStepper is shared with EncapsulatedStateMachine;
        # confirm it does not read attributes (e.g. `delimiters`) this class lacks.
        return EncapsulatedStepper(self, state)
48 |
--------------------------------------------------------------------------------
/pse/types/xml/xml_tag.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pse.types.base.chain import ChainStateMachine, ChainStepper
4 | from pse.types.base.phrase import PhraseStateMachine
5 |
6 |
class XMLTagStateMachine(ChainStateMachine):
    """
    A state machine that recognizes XML tags.

    It matches ``<tag_name>``, or ``</tag_name>`` when ``closing_tag`` is True.
    """

    def __init__(self, tag_name: str, closing_tag: bool = False) -> None:
        self.tag_name = tag_name
        # Opening tags start with "<", closing tags with "</".
        tag_prefix = "</" if closing_tag else "<"
        self.xml_tag = f"{tag_prefix}{tag_name}>"
        super().__init__(
            [
                PhraseStateMachine(tag_prefix),
                PhraseStateMachine(tag_name),
                PhraseStateMachine(">"),
            ]
        )

    def get_new_stepper(self, state: int | str | None = None) -> XMLTagStepper:
        """Create an XMLTagStepper for this machine, optionally at ``state``."""
        return XMLTagStepper(self, state)

    def __str__(self) -> str:
        return self.tag_name
28 |
class XMLTagStepper(ChainStepper):
    """Stepper for XMLTagStateMachine."""

    def __init__(self, state_machine: XMLTagStateMachine, *args, **kwargs) -> None:
        super().__init__(state_machine, *args, **kwargs)
        self.state_machine: XMLTagStateMachine = state_machine

    def get_valid_continuations(self) -> list[str]:
        """
        Return the complete tag text as the only valid continuation.

        NOTE(review): the full tag is returned regardless of how much of it
        has already been matched. Confirm that is intended.
        """
        full_tag = self.state_machine.xml_tag
        return [full_tag]
37 |
--------------------------------------------------------------------------------
/pse/util/generate_mlx.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections.abc import Callable
3 | from typing import Any
4 |
5 | import mlx.core as mx
6 | import mlx.nn as nn
7 | from mlx_proxy.generate_step import generate_step
8 | from mlx_proxy.samplers import make_sampler
9 |
10 | from pse.structuring_engine import StructuringEngine
11 |
# Module-scoped logger for the MLX generation helpers.
logger: logging.Logger = logging.getLogger(name=__name__)
13 |
def generate(
    prompt: str,
    model: nn.Module,
    engine: StructuringEngine,
    prefill: str | None = None,
) -> str:
    """
    Run engine-constrained generation for a single user prompt.

    Steps:
    1. Wrap the prompt as one user message with the tokenizer's chat template
       (generation prompt appended, untokenized) and append ``prefill``.
    2. Sample tokens with ``generate_step``, using the engine as logits
       processor and sampler, until the engine reports an accept state.

    Args:
        prompt: The user message.
        model: The MLX model to sample from.
        engine: The structuring engine that constrains and samples tokens.
        prefill: Optional text appended after the templated prompt and also
            prepended to the returned text.

    Returns:
        The decoded generated text, prefixed with ``prefill`` when given.
    """
    mx.metal.clear_cache()
    conversation = [{"role": "user", "content": prompt}]
    templated = engine.tokenizer.apply_chat_template(
        conversation=conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
    assert isinstance(templated, str)
    templated += prefill or ""
    logger.info(templated)

    prompt_ids = engine.tokenizer.encode(templated, add_special_tokens=False)
    generated_ids: list[int] = []
    steps = generate_step(
        prompt=prompt_ids,
        model=model,
        logits_processors=[engine.process_logits],
        sampler=sampler(engine),
        max_tokens=-1,
    )
    for tokens, _ in steps:
        assert isinstance(tokens, mx.array)
        # Single-element arrays are unwrapped with item(), longer ones with tolist().
        step_ids = [tokens.item()] if tokens.shape[0] <= 1 else tokens.tolist()
        # NOTE(review): this grows the list passed to generate_step as `prompt`;
        # it is not read again here. Confirm generate_step does not depend on it.
        prompt_ids.extend(step_ids)  # type: ignore[arg-type]
        generated_ids.extend(step_ids)  # type: ignore[arg-type]
        if engine.has_reached_accept_state:
            break

    decoded = engine.tokenizer.decode(generated_ids)
    if prefill:
        return prefill + decoded
    return decoded
49 |
def sampler(engine: StructuringEngine, **kwargs: Any) -> Callable[..., Any]:
    """
    Return a sampling function that routes through the structuring engine.

    A base sampler is built with ``make_sampler(kwargs.get("temp", 0.7))``.
    The returned callable hands its logits argument to ``engine.sample``
    together with that base sampler.

    Args:
        engine: The structuring engine used for sampling.
        **kwargs: Only ``temp`` is read (default 0.7).

    Returns:
        A callable ``f(logits)`` returning ``engine.sample(logits, base_sampler)``.
    """
    # Named distinctly so it does not shadow this function's own name.
    base_sampler = make_sampler(kwargs.get("temp", 0.7))
    return lambda x: engine.sample(x, base_sampler)
--------------------------------------------------------------------------------
/pse/util/jax_mixin.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from jax import lax
6 | from transformers import FlaxGenerationMixin, FlaxLogitsProcessorList
7 | from transformers.generation.flax_utils import FlaxSampleOutput, SampleState
8 |
9 | from pse.structuring_engine import StructuringEngine
10 |
11 |
class PSEFlaxMixin(FlaxGenerationMixin):
    """
    Flax generation mixin whose sampling loop is driven by a StructuringEngine.

    Differences from the standard sampling loop:
    - ``engine.process_logits`` is inserted at the front of the logits processors.
    - Tokens are chosen with ``engine.sample``.
    - The loop also stops once ``engine.has_reached_accept_state`` is true.
    """

    engine: StructuringEngine

    @staticmethod
    def make_sampler(prng_key: jnp.ndarray) -> Callable:
        """Return a categorical sampler over the last axis driven by ``prng_key``."""
        return lambda x: jax.random.categorical(prng_key, x, axis=-1)

    def _sample(
        self,
        input_ids: jnp.ndarray,
        max_length: int | None = None,
        pad_token_id: int | None = None,
        eos_token_id: int | None = None,
        prng_key: jnp.ndarray | None = None,
        logits_processor: FlaxLogitsProcessorList | None = None,
        logits_warper: FlaxLogitsProcessorList | None = None,
        trace: bool = True,
        params: dict[str, jnp.ndarray] | None = None,
        model_kwargs: dict[str, jnp.ndarray] | None = None,
    ):
        """
        Sample sequences, consulting the structuring engine at every step.

        Returns:
            FlaxSampleOutput: The generated ``sequences``, padded with
            ``pad_token_id`` up to ``max_length``.
        """
        # init values
        max_length = max_length if max_length is not None else self.generation_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape  # type: ignore [arg-type]

        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32) if eos_token_id is not None else None  # type: ignore [arg-type]
        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)  # type: ignore [arg-type]
        cur_len = jnp.array(cur_len)  # type: ignore [arg-type]

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)  # type: ignore [arg-type]
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))  # type: ignore [arg-type]

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
        if not logits_processor or self.engine.process_logits not in logits_processor:
            # insert the engine at the beginning of the list
            if logits_processor is None:
                logits_processor = FlaxLogitsProcessorList()
            logits_processor.insert(0, self.engine.process_logits)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self  # type: ignore [attr-defined]

        assert isinstance(model, Callable)
        # initialize model specific kwargs; a missing dict means "no extra kwargs".
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **(model_kwargs or {}))  # type: ignore [arg-type]

        # initialize state
        state = SampleState(
            cur_len=cur_len,  # type: ignore [arg-type]
            sequences=sequences,  # type: ignore [arg-type]
            running_token=input_ids,  # type: ignore [arg-type]
            is_sent_finished=is_sent_finished,  # type: ignore [arg-type]
            prng_key=prng_key,  # type: ignore [arg-type]
            model_kwargs=model_kwargs,  # type: ignore [arg-type]
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def engine_aware_cond_fn(state):
            """Continue while generation is unfinished and the engine has not accepted."""
            # NOTE(review): `and` needs a concrete boolean. This works in the
            # debug loop, but under lax.while_loop tracing the condition is a
            # tracer, so bool() on it fails. Confirm the traced path is supported.
            return sample_search_cond_fn(state) and not self.engine.has_reached_accept_state

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]
            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature; processors and warpers both take
            # (input_ids, scores, cur_len).
            if logits_warper:
                logits = logits_warper(state.sequences, logits, state.cur_len)

            sampler = PSEFlaxMixin.make_sampler(prng_key)
            next_token = self.engine.sample(logits, sampler)

            next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)  # type: ignore [attr-defined]

            return SampleState(
                # Exactly one token is appended per sequence each step;
                # len(next_token) would be the batch size.
                cur_len=state.cur_len + 1,  # type: ignore [arg-type]
                sequences=next_sequences,  # type: ignore [arg-type]
                running_token=next_token,  # type: ignore [arg-type]
                is_sent_finished=next_is_sent_finished,  # type: ignore [arg-type]
                model_kwargs=next_model_kwargs,  # type: ignore [arg-type]
                prng_key=prng_key_next,  # type: ignore [arg-type]
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:  # type: ignore [arg-type]
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(engine_aware_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(engine_aware_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
130 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "pse"
7 | version = "2025.06.1"
8 | authors = [
9 | { name = "Jack Wind", email = "jckwind11@gmail.com" },
10 | { name = "The Proxy Company", email = "contact@theproxy.company" }
11 | ]
12 | description = "Proxy Structuring Engine: Stateful AI-generated output with a focus on creativity, speed, and structure."
13 | readme = "README.md"
14 | requires-python = ">=3.11"
15 | license = { file = "LICENSE" }
16 |
17 | dependencies = [
18 | "docstring-parser",
19 | "protobuf",
20 | "pse-core",
21 | "pydantic",
22 | "regex",
23 | "tokenizers",
24 | "transformers",
25 | "tqdm",
26 | "typing_extensions",
27 | "wheel",
28 | "lark",
29 | ]
30 |
31 | [project.optional-dependencies]
32 | dev = [
33 | "coverage",
34 | "flake8",
35 | "pytest",
36 | "ruff",
37 | "sentencepiece",
38 | "ipykernel",
39 | "pytest-cov",
40 | ]
41 | mlx = [
42 | "mlx",
43 | "mlx-proxy",
44 | ]
45 | torch = [
46 | "torch",
47 | "accelerate",
48 | ]
49 | tensorflow = [
50 | "tensorflow",
51 | "accelerate",
52 | ]
53 | jax = [
54 | "jax",
55 | "accelerate",
56 | ]
57 |
58 | [project.urls]
59 | homepage = "https://github.com/TheProxyCompany/proxy-structuring-engine"
60 | documentation = "https://github.com/TheProxyCompany/proxy-structuring-engine#readme"
61 | source = "https://github.com/TheProxyCompany/proxy-structuring-engine"
62 |
63 | [tool.ruff.lint]
64 | extend-select = [
65 | "B", # flake8-bugbear
66 | "I", # isort
67 | "PGH", # pygrep-hooks
68 | "RUF", # Ruff-specific
69 | "UP", # pyupgrade
70 | "SLF", # flake8-self (private member access)
71 | "F8", # Pyflakes F8xx (undefined / redefined / unused names)
72 | ]
73 |
74 | [tool.hatch.build.targets.sdist]
75 | include = [
76 | "LICENSE",
77 | "README.md",
78 | "pyproject.toml",
79 | "pse"
80 | ]
81 |
82 | [tool.hatch.build.targets.wheel]
83 | packages = ["pse"]
84 | include = ["pse/**"]
85 | optimize = true
86 | ignore-vcs = true
87 | python-tag = "py311"
88 | repair-wheel = true
89 |
90 | [tool.hatch.envs.default]
91 | python = "3.11"
92 | env-vars = { PYTHONOPTIMIZE = "2" }
93 |
94 | [tool.pytest.ini_options]
95 | log_cli = false
96 | log_cli_level = "WARNING"
97 | log_cli_format = "in %(filename)s:%(lineno)d [%(levelname)s] %(message)s"
98 | log_cli_date_format = "%H:%M:%S"
99 | addopts = "--cov=pse"
100 |
--------------------------------------------------------------------------------
/tests/unit/types/base/test_wait_for.py:
--------------------------------------------------------------------------------
1 | from pse_core.trie import TrieMap
2 |
3 | from pse.types.base.phrase import PhraseStateMachine, PhraseStepper
4 | from pse.types.base.wait_for import (
5 | WaitFor,
6 | WaitForStepper,
7 | )
8 |
9 |
def test_default_wait_for_acceptor() -> None:
    """WaitFor with buffer_length=0 wraps the phrase stepper and accepts the phrase."""
    phrase_sm = PhraseStateMachine("Hello World")
    wait_for_sm = WaitFor(phrase_sm, buffer_length=0)

    initial = list(wait_for_sm.get_steppers())
    assert len(initial) == 1
    first = initial[0]

    # The single stepper is a WaitForStepper that delegates to a PhraseStepper.
    assert isinstance(first, WaitForStepper)
    assert first.accepts_any_token()
    inner = first.sub_stepper
    assert inner
    assert inner.state_machine == phrase_sm
    assert isinstance(inner, PhraseStepper)
    assert not first.is_within_value()

    # Matching the first word puts the stepper inside the value...
    after_hello = wait_for_sm.advance_all_basic(initial, "Hello ")
    assert len(after_hello) == 1
    assert after_hello[0].is_within_value()

    # ...and the second word completes it.
    after_world = wait_for_sm.advance_all_basic(after_hello, "World")
    assert len(after_world) == 1
    assert after_world[0].has_reached_accept_state()
29 |
30 |
def test_basic_wait_for_acceptor() -> None:
    """Feeding the whole target phrase in one chunk reaches the accept state.

    The accepted stepper must also report no invalid continuations.
    """
    text_acceptor = PhraseStateMachine("Hello World")
    state_machine = WaitFor(text_acceptor)
    steppers = list(state_machine.get_steppers())
    steppers = state_machine.advance_all_basic(steppers, "Hello World")
    assert len(steppers) == 1
    assert steppers[0].has_reached_accept_state()
    assert not steppers[0].get_invalid_continuations()
40 |
41 |
def test_interrupted_wait_for_acceptor() -> None:
    """In strict mode, derailing a partially matched phrase leaves no steppers."""
    strict_sm = WaitFor(PhraseStateMachine("Hello World"), strict=True)

    partial = strict_sm.advance_all_basic(strict_sm.get_steppers(), "Hello ")
    assert len(partial) == 1
    assert partial[0].is_within_value()

    derailed = strict_sm.advance_all_basic(partial, "I'm gonna mess up the pattern!")
    assert not derailed
54 |
55 |
def test_wait_for_acceptor_with_break() -> None:
    """In non-strict mode, an interruption after a partial match is tolerated.

    After "Hello " is matched, unrelated text does not eliminate the stepper,
    and a later "World" still brings it to the accept state.
    """
    text_acceptor = PhraseStateMachine("Hello World")
    state_machine = WaitFor(text_acceptor, strict=False)
    steppers = list(state_machine.get_steppers())
    steppers = state_machine.advance_all_basic(steppers, "Hello ")
    assert len(steppers) == 1

    steppers = state_machine.advance_all_basic(
        steppers, "I'm gonna mess up the pattern! But i'll still be accepted!"
    )
    assert len(steppers) == 1

    steppers = state_machine.advance_all_basic(steppers, "World")
    assert len(steppers) == 1
    assert steppers[0].has_reached_accept_state()
72 |
73 |
def test_wait_for_acceptor_with_partial_match():
    """Token healing: the input '"*' (not in the trie) is healed to the token '"'.

    Every resulting delta must report that it was healed, carry the healed
    token, and hold a stepper that has consumed only '"'. That stepper is
    therefore not yet in an accept state for the phrase '"hello"'.
    """
    text_acceptor = PhraseStateMachine('"hello"')
    state_machine = WaitFor(text_acceptor)
    steppers = list(state_machine.get_steppers())
    assert len(steppers) == 1

    trie_map = TrieMap()
    items = [
        ('"hello', 1),
        ('"', 2),
        ("hello", 3),
        ('"c', 4),
    ]
    trie_map = trie_map.insert_all(items)

    stepper_deltas = list(state_machine.advance_all(steppers, '"*', trie_map))
    # Without this, an empty result would make every check below vacuous.
    assert stepper_deltas
    for stepper_delta in stepper_deltas:
        assert stepper_delta.was_healed
        assert stepper_delta.token == '"'
        assert stepper_delta.stepper.get_current_value() == '"'
        # Check the advanced stepper, not the untouched initial one.
        assert not stepper_delta.stepper.has_reached_accept_state()
94 |
95 |
def test_get_valid_continuations_buffer_too_short():
    """Below buffer_length, nothing is a valid continuation and the target is invalid."""
    wait_for_sm = WaitFor(PhraseStateMachine("Hello"), buffer_length=10)

    candidates = list(wait_for_sm.get_steppers())
    assert len(candidates) == 1
    only = candidates[0]

    # Nothing has been buffered yet, which is shorter than the 10 required.
    assert only.get_valid_continuations() == []
    assert only.get_invalid_continuations() == ["Hello"]
110 |
def test_should_start_step_with_remaining_input():
    """should_start_step is falsy while the stepper still has remaining_input."""
    wait_for_sm = WaitFor(PhraseStateMachine("Hello"))

    candidates = list(wait_for_sm.get_steppers())
    assert len(candidates) == 1
    stepper = candidates[0]

    # Simulate leftover, not-yet-consumed input on the stepper.
    stepper.remaining_input = "something"

    assert not stepper.should_start_step("Hello")
125 |
126 |
def test_consume_with_no_sub_stepper():
    """consume() returns an empty list when sub_stepper has been cleared."""
    stepper = WaitForStepper(WaitFor(PhraseStateMachine("Hello")))
    stepper.sub_stepper = None  # force the "no inner stepper" branch

    assert stepper.consume("any token") == []
138 |
139 |
def test_consume_with_min_buffer_length_negative():
    """With buffer_length=-1 and no value started, a non-matching token is rejected."""
    wait_for_sm = WaitFor(PhraseStateMachine("Hello"), buffer_length=-1)

    candidates = list(wait_for_sm.get_steppers())
    assert len(candidates) == 1
    stepper = candidates[0]

    # Precondition: the phrase has not started matching yet.
    assert not stepper.is_within_value()

    # "NotHello" does not begin the pattern, so nothing survives.
    assert stepper.consume("NotHello") == []
156 |
--------------------------------------------------------------------------------
/tests/unit/types/bash/test_bash_code.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pse_core.state_machine import StateMachine
3 |
4 | from pse.types.grammar import BashStateMachine
5 |
6 |
@pytest.mark.parametrize(
    "code, should_accept",
    [
        # Basic commands
        ("echo 'Hello, World!'", True),
        ("ls -la", True),
        ("cd /home/user", True),
        ("grep -r 'pattern' .", True),
        ("cat file.txt", True),
        # Variables and assignments
        ("x=1", True),
        ("NAME='John Doe'", True),
        ("export PATH=$PATH:/usr/local/bin", True),
        ("readonly VAR=value", True),
        # Control structures
        ("if [ $x -eq 1 ]; then echo 'yes'; fi", True),
        ("for i in 1 2 3; do echo $i; done", True),
        ("while [ $count -lt 10 ]; do echo $count; ((count++)); done", True),
        ("case $var in a) echo 'a';; b) echo 'b';; esac", True),
        # Functions
        ("function greet() { echo 'Hello'; }", True),
        ("myfunc() { local var=1; echo $var; }", True),
        # Pipes and redirections
        ("cat file.txt | grep pattern", True),
        ("ls > output.txt", True),
        ("cat < input.txt", True),
        ("command >> log.txt 2>&1", True),
        # Command substitution
        ("echo $(date)", True),
        ("files=$(ls -la)", True),
        # Arithmetic
        ("echo $((1 + 2))", True),
        ("((x = y + 3))", True),
        # Comments
        ("# This is a comment", True),
        ("echo 'Hello' # inline comment", True),
        # Multiline commands
        ("echo 'line 1'\necho 'line 2'", True),
        ("if true; then\n echo 'true'\nfi", True),
        # Invalid syntax
        ("if then fi", False),
        ("for in do done", False),
        ("case esac", False),
        ("echo 'unterminated string", False),
        ("function () {}", False),
        ("ls | | grep pattern", False),
    ],
)
def test_bash_source_validation(code, should_accept):
    """A snippet reaches an accept state exactly when it is valid Bash."""
    wrapper = StateMachine({0: [(BashStateMachine, "$")]})
    final_steppers = wrapper.advance_all_basic(wrapper.get_steppers(), code)
    accepted = any(s.has_reached_accept_state() for s in final_steppers)

    if should_accept:
        assert accepted, f"Should accept valid Bash code: {code}"
    else:
        assert not accepted, f"Should not accept invalid Bash code: {code}"
73 |
74 |
def test_incremental_parsing():
    """Bash code fed in several chunks is parsed incrementally and accepted.

    The command is split inside its quoted argument, so the open string must
    survive across ``advance_all_basic`` calls before the input is accepted.
    """
    source_code_sm = StateMachine(
        {
            0: [(BashStateMachine, "$")],
        }
    )

    # Concatenated, the chunks spell "echo 'Hello, World!'".
    chunks = ("echo 'Hello", ", World!'")

    steppers = source_code_sm.get_steppers()
    for chunk in chunks:
        steppers = source_code_sm.advance_all_basic(steppers, chunk)
        assert len(steppers) > 0, f"Should keep parsing after chunk: {chunk!r}"

    assert any(stepper.has_reached_accept_state() for stepper in steppers)
92 |
93 |
def test_empty_input():
    """Empty input is never treated as complete Bash code."""
    wrapper = StateMachine({0: [(BashStateMachine, "$")]})
    result = wrapper.advance_all_basic(wrapper.get_steppers(), "")
    assert not any(s.has_reached_accept_state() for s in result)
104 |
105 |
@pytest.mark.parametrize(
    "incomplete_code",
    [
        "if [ $x -eq 1 ]; then",
        "for i in 1 2 3; do",
        "while true; do",
        "case $var in",
        "function name() {",
        "echo 'Hello' |",
        "ls -la >",
        "cat <",
        "grep pattern &&",
        "find . -name '*.txt' ||",
    ],
)
def test_incomplete_but_valid_code(incomplete_code):
    """A valid but unfinished prefix leaves live, open, non-accepting steppers."""
    wrapper = StateMachine({0: [(BashStateMachine, "$")]})
    live = wrapper.advance_all_basic(wrapper.get_steppers(), incomplete_code)

    assert len(live) > 0
    # Every surviving stepper must still take more characters...
    assert all(s.can_accept_more_input() for s in live)
    # ...and none may consider the snippet finished.
    assert not any(s.has_reached_accept_state() for s in live)
135 |
136 |
@pytest.mark.parametrize(
    "code",
    [
        "echo 'string with unmatched quote",
        "if [ $x -eq 1 ]",
        "for i in $(seq 1 10)",
        "cat file.txt | grep 'pattern' |",
        "function name() { echo 'incomplete",
        "case $var in a) echo 'a';;",
        "while [ $count -lt 10 ]",
        "ls -la 2>",
    ],
)
def test_bash_specific_incomplete_constructs(code):
    """Unfinished Bash-specific constructs stay parseable but are not accepted yet."""
    wrapper = StateMachine({0: [(BashStateMachine, "$")]})
    live = wrapper.advance_all_basic(wrapper.get_steppers(), code)

    assert len(live) > 0, f"Should accept incomplete Bash construct: {code}"
    assert not any(s.has_reached_accept_state() for s in live)
    assert all(s.can_accept_more_input() for s in live)
163 |
def test_bash_code_validate_no_code():
    """The Bash grammar's validate() rejects an empty string."""
    bash_grammar = BashStateMachine.grammar
    is_valid = bash_grammar.validate("")
    assert not is_valid
167 |
--------------------------------------------------------------------------------
/tests/unit/types/bash/test_wrapped_bash.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pse.types.base.encapsulated import EncapsulatedStateMachine
4 | from pse.types.grammar import BashStateMachine
5 |
6 |
def test_basic_bash_block():
    """A Bash block in its default delimiters is accepted after the closing one."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    live = fenced.get_steppers()

    # Opening delimiter, then the command itself: parsing must stay alive.
    for chunk in ("```bash\n", "echo 'Hello, World!'"):
        live = fenced.advance_all_basic(live, chunk)
        assert len(live) > 0

    # Closing delimiter completes the block.
    live = fenced.advance_all_basic(live, "\n```")
    assert any(s.has_reached_accept_state() for s in live)
24 |
25 |
@pytest.mark.parametrize(
    "code_block",
    [
        "```bash\necho 'Hello, World!'\n```",
        "```bash\nls -la\ngrep pattern file.txt\n```",
        "```bash\nfor i in 1 2 3; do\n echo $i\ndone\n```",
        "```bash\nif [ -f file.txt ]; then\n cat file.txt\nelse\n echo 'File not found'\nfi\n```",
        "```bash\nfunction greet() {\n echo 'Hello, $1!'\n}\n\ngreet 'World'\n```",
    ],
)
def test_complete_bash_blocks(code_block):
    """Each complete, delimited Bash block reaches an accept state."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    result = fenced.advance_all_basic(fenced.get_steppers(), code_block)
    assert any(s.has_reached_accept_state() for s in result)
44 |
45 |
def test_custom_delimiters():
    """With empty delimiters, bare Bash code is accepted by the wrapper."""
    bare = EncapsulatedStateMachine(BashStateMachine, delimiters=("", ""))
    result = bare.advance_all_basic(bare.get_steppers(), "echo 'Hello, World!'")
    assert any(s.has_reached_accept_state() for s in result)
56 |
57 |
def test_stepper_clone():
    """Cloning a stepper positioned inside a Bash block.

    The clone must be a distinct object that reports the same current value.
    """
    bash_sm = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    steppers = bash_sm.get_steppers()
    steppers = bash_sm.advance_all_basic(steppers, "```bash\necho 'Hello'\n")
    # Fail with a clear assertion, not an IndexError, if parsing died early.
    assert steppers, "expected at least one live stepper after a partial block"

    original_stepper = steppers[0]
    cloned_stepper = original_stepper.clone()

    assert original_stepper.get_current_value() == cloned_stepper.get_current_value()
    assert original_stepper is not cloned_stepper
71 |
72 |
@pytest.mark.parametrize(
    "invalid_block",
    [
        "```bash\nif then fi\n```",
        "```bash\nfor in do done\n```",
        "```bash\ncase esac\n```",
        "```bash\nfunction () {}\n```",
        "```bash\nls | | grep pattern\n```",
    ],
)
def test_invalid_bash_blocks(invalid_block):
    """Syntactically invalid Bash inside delimiters never reaches an accept state."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    result = fenced.advance_all_basic(fenced.get_steppers(), invalid_block)
    assert not any(s.has_reached_accept_state() for s in result)
91 |
92 |
def test_incremental_parsing():
    """A delimited Bash block streamed line by line is accepted at the end."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    live = fenced.get_steppers()

    chunks = (
        "```bash\n",
        "for i in 1 2 3; do\n",
        " echo $i\n",
        "done\n",
        "```",
    )
    for chunk in chunks:
        live = fenced.advance_all_basic(live, chunk)
        assert len(live) > 0  # parsing must survive every chunk

    assert any(s.has_reached_accept_state() for s in live)
113 |
114 |
def test_can_accept_more_input():
    """Steppers stay open inside a block and accept once the closing delimiter arrives."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )

    live = fenced.advance_all_basic(fenced.get_steppers(), "```bash\n")
    assert len(live) > 0

    # Inside the block, with and without an open quote, more input is allowed.
    for chunk in ("echo 'Hello", "'"):
        live = fenced.advance_all_basic(live, chunk)
        assert len(live) > 0
        assert all(s.can_accept_more_input() for s in live)

    live = fenced.advance_all_basic(live, "\n```")
    assert len(live) > 0
    assert any(s.has_reached_accept_state() for s in live)
132 |
133 |
@pytest.mark.parametrize(
    "code_block",
    [
        "```bash\n#!/bin/bash\n\n# This is a comment\necho 'Hello, World!'\n```",
        "```bash\nVAR='value'\necho $VAR\n```",
        "```bash\nls -la | grep '.txt' | sort\n```",
        "```bash\nif [ -d /tmp ]; then\n echo 'Directory exists'\nfi\n```",
    ],
)
def test_bash_specific_features(code_block):
    """Shebangs, comments, variables, pipelines and tests are all accepted."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    result = fenced.advance_all_basic(fenced.get_steppers(), code_block)
    assert any(s.has_reached_accept_state() for s in result)
151 |
152 |
@pytest.mark.parametrize(
    "incomplete_block",
    [
        "```bash\nif [ $x -eq 1 ]; then\n",
        "```bash\nfor i in 1 2 3; do\n",
        "```bash\nwhile true; do\n",
        "```bash\ncase $var in\n",
        "```bash\nfunction name() {\n",
    ],
)
def test_incomplete_bash_blocks(incomplete_block):
    """An unfinished construct inside an open block leaves live, open steppers."""
    fenced = EncapsulatedStateMachine(
        BashStateMachine, delimiters=BashStateMachine.delimiters
    )
    live = fenced.advance_all_basic(fenced.get_steppers(), incomplete_block)
    assert len(live) > 0
    assert all(s.can_accept_more_input() for s in live)
172 |
--------------------------------------------------------------------------------
/tests/unit/types/json/schema_sources/test_from_pydantic.py:
--------------------------------------------------------------------------------
1 | """Tests for the from_pydantic module."""
2 | from pydantic import BaseModel, Field
3 |
4 | from pse.types.json.schema_sources.from_pydantic import pydantic_to_schema
5 |
6 |
class SimpleModel(BaseModel):
    """A simple model with basic fields.

    This model contains string and integer fields.
    """

    # Both fields are explicitly required (``...``) and carry a description.
    # The class docstring above is read by pydantic_to_schema, so keep it as-is.
    name: str = Field(..., description="The name field")
    age: int = Field(..., description="The age field")
15 |
16 |
class ModelWithDefaults(BaseModel):
    """A model with default values.

    Args:
        name: A name field with default
        age: An age field that's required
    """

    # ``name`` falls back to "John"; ``age`` has no default and is required.
    name: str = Field("John", description="The name field")
    age: int = Field(..., description="The age field")
27 |
28 |
class ComplexModel(BaseModel):
    """A complex model with nested types.

    Args:
        name: The user's name
        age: The user's age
        tags: A list of string tags
        address: An optional address
    """

    # Required scalar fields.
    name: str
    age: int
    # Optional fields: a fresh empty list per instance, and a nullable string.
    tags: list[str] = Field(default_factory=list)
    address: str | None = Field(default=None)
43 |
44 |
class NestedModel(BaseModel):
    """A model with nested models.

    Args:
        title: The title
        user: A user object
    """

    title: str = Field(...)
    # A nested SimpleModel instance.
    user: SimpleModel = Field(...)
55 |
56 |
class ModelWithoutDocs(BaseModel):
    # Intentionally has no docstring and no field descriptions:
    # test_pydantic_to_schema_without_docs checks the empty-description fallback.
    name: str = Field(...)
    age: int = Field(...)
60 |
61 |
class ModelWithJsonExtra(BaseModel):
    """A model with json_schema_extra in fields."""

    # Extra JSON-schema keywords attached to each required field.
    name: str = Field(
        ...,
        description="The name field",
        json_schema_extra={"example": "John Doe", "pattern": "^[a-zA-Z ]+$"},
    )
    age: int = Field(
        ...,
        description="The age field",
        json_schema_extra={"minimum": 0, "maximum": 120},
    )
73 |
74 |
def test_pydantic_to_schema_simple():
    """A flat model yields its name, description, field descriptions and required list."""
    schema = pydantic_to_schema(SimpleModel)

    assert schema["name"] == "SimpleModel"
    assert "description" in schema
    assert "properties" in schema
    props = schema["properties"]
    for field_name in ("name", "age"):
        assert field_name in props
    assert props["name"]["description"] == "The name field"
    assert props["age"]["description"] == "The age field"

    assert "required" in schema
    for field_name in ("name", "age"):
        assert field_name in schema["required"]
89 |
90 |
def test_pydantic_to_schema_with_defaults():
    """A field with a default exposes that default and is not listed as required."""
    schema = pydantic_to_schema(ModelWithDefaults)

    assert "properties" in schema
    props = schema["properties"]
    assert "name" in props
    assert "age" in props
    name_prop = props["name"]
    assert "default" in name_prop
    assert name_prop["default"] == "John"

    assert "required" in schema
    required = schema["required"]
    assert "age" in required
    assert "name" not in required
103 |
104 |
def test_pydantic_to_schema_complex():
    """List, optional and defaulted fields are reflected in the generated schema."""
    schema = pydantic_to_schema(ComplexModel)

    assert "properties" in schema
    props = schema["properties"]
    for field_name in ("name", "age", "tags", "address"):
        assert field_name in props

    # list[str] -> JSON array with an items schema.
    tags = props["tags"]
    assert tags["type"] == "array"
    assert "items" in tags

    # str | None: Pydantic v2 emits anyOf with a null branch; otherwise expect nullable.
    address = props["address"]
    if "anyOf" not in address:
        assert address.get("nullable", False) is True
    else:
        assert any(option.get("type") == "null" for option in address["anyOf"])

    assert "required" in schema
    required = schema["required"]
    assert "name" in required
    assert "age" in required
    assert "address" not in required
    assert "tags" not in required
136 |
137 |
def test_pydantic_to_schema_nested():
    """A nested model is either referenced through $ref/$defs or inlined."""
    schema = pydantic_to_schema(NestedModel)

    assert "properties" in schema
    assert "title" in schema["properties"]
    assert "user" in schema["properties"]

    user = schema["properties"]["user"]
    if "$ref" not in user:
        # Inlined sub-schema (older Pydantic versions).
        assert "properties" in user
        assert "name" in user["properties"]
        assert "age" in user["properties"]
        assert "required" in user
        assert "name" in user["required"]
        assert "age" in user["required"]
        return

    # Referenced sub-schema (Pydantic v2): the target must live under $defs.
    assert "$defs" in schema
    assert "SimpleModel" in schema["$defs"] or user["$ref"].split("/")[-1] in schema["$defs"]
162 |
163 |
def test_pydantic_to_schema_without_docs():
    """Missing docstrings and field descriptions fall back to empty strings."""
    schema = pydantic_to_schema(ModelWithoutDocs)

    assert schema["name"] == "ModelWithoutDocs"
    assert schema["description"] == ""
    assert "properties" in schema
    props = schema["properties"]
    assert "name" in props
    assert "age" in props
    name_prop = props["name"]
    assert "description" in name_prop
    assert name_prop["description"] == ""
175 |
176 |
def test_pydantic_to_schema_with_json_extra():
    """json_schema_extra entries on a Field appear in that property's schema."""
    schema = pydantic_to_schema(ModelWithJsonExtra)

    assert "properties" in schema
    props = schema["properties"]
    assert "name" in props
    assert "age" in props

    expected_extras = {
        "name": {"example": "John Doe", "pattern": "^[a-zA-Z ]+$"},
        "age": {"minimum": 0, "maximum": 120},
    }
    for field_name, extras in expected_extras.items():
        field_schema = props[field_name]
        for key, value in extras.items():
            assert key in field_schema
            assert field_schema[key] == value
197 |
--------------------------------------------------------------------------------
/tests/unit/types/json/test_any_json_schema.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Any
3 |
4 | import pytest
5 |
6 | from pse.types.json.any_json_schema import AnySchemaStateMachine
7 |
8 |
@pytest.fixture
def context():
    """Fresh context: a ``defs`` defaultdict of dicts and an empty ``path``."""
    fresh_defs = defaultdict(dict)
    return {"defs": fresh_defs, "path": ""}
13 |
14 |
def parse_input(state_machine: AnySchemaStateMachine, json_string: str) -> Any:
    """
    Feed a JSON string through ``state_machine`` and return the parsed value.

    Args:
        state_machine (AnySchemaStateMachine): The state machine to drive.
        json_string (str): The complete JSON text, consumed in a single call.

    Returns:
        Any: The current value of the first stepper that reached an accept state.

    Raises:
        ValueError: If no stepper is in an accept state after consuming the input.
    """
    steppers = state_machine.get_steppers()
    steppers = state_machine.advance_all_basic(steppers, json_string)
    # A plain loop (rather than next() with a default) because a parsed JSON
    # null is a legitimate return value and must not be mistaken for "no match".
    for stepper in steppers:
        if stepper.has_reached_accept_state():
            return stepper.get_current_value()

    raise ValueError(f"Invalid JSON input for AnySchemaStateMachine: {json_string}")
36 |
37 |
@pytest.mark.parametrize(
    "schemas, token, expected_result",
    [
        (
            [{"type": "number", "minimum": 0}, {"type": "string", "maxLength": 5}],
            "10",
            10,
        ),
        (
            [{"type": "number", "minimum": 0}, {"type": "string", "maxLength": 5}],
            '"test"',
            "test",
        ),
    ],
)
def test_accept_input_matching_single_schema(context, schemas, token, expected_result):
    """Input valid for exactly one branch of the union parses to that branch's value."""
    machine = AnySchemaStateMachine(schemas=schemas, context=context)
    parsed = parse_input(machine, token)
    assert parsed == expected_result, (
        f"AnyOfAcceptor should accept valid input {token}."
    )
60 |
61 |
def test_accept_input_matching_multiple_schemas(context):
    """Input satisfying both overlapping number schemas is accepted."""
    bounded = {"type": "number", "minimum": 0, "maximum": 100}
    multiple_of_five = {"type": "number", "multipleOf": 5}
    machine = AnySchemaStateMachine(schemas=[bounded, multiple_of_five], context=context)

    # 25 lies in [0, 100] and is divisible by 5, so both branches match.
    parsed = parse_input(machine, "25")
    assert parsed == 25, "AnyOfAcceptor should accept input matching multiple schemas."
72 |
73 |
def test_reject_input_not_matching_any_schema(context):
    """A boolean-or-null union rejects both a number and a string."""
    machine = AnySchemaStateMachine(
        schemas=[{"type": "boolean"}, {"type": "null"}], context=context
    )

    for bad_input in ("1", '"test"'):
        with pytest.raises(ValueError):
            parse_input(machine, bad_input)
90 |
91 |
def test_complex_nested_schemas(context):
    """An object-or-array union accepts a valid instance of either shape."""
    person_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number", "minimum": 0},
        },
        "required": ["name", "age"],
    }
    strings_schema = {"type": "array", "items": {"type": "string"}, "minItems": 1}
    machine = AnySchemaStateMachine(schemas=[person_schema, strings_schema], context=context)

    parsed_object = parse_input(machine, '{"name": "Alice", "age": 30}')
    assert parsed_object == {
        "name": "Alice",
        "age": 30,
    }, "AnyOfAcceptor should accept valid object input."

    parsed_array = parse_input(machine, '["apple", "banana"]')
    assert parsed_array == [
        "apple",
        "banana",
    ], "AnyOfAcceptor should accept valid array input."
121 |
122 |
def test_partial_input(context):
    """A string shorter than ``minLength`` is rejected.

    ``"test"`` is a complete JSON string but has only 4 characters, so the
    ``minLength: 5`` constraint must prevent it from reaching an accept state.
    """
    schema = {"type": "string", "minLength": 5}
    state_machine = AnySchemaStateMachine(schemas=[schema], context=context)

    too_short_input = '"test"'

    with pytest.raises(ValueError):
        parse_input(state_machine, too_short_input)
132 |
133 |
@pytest.mark.parametrize(
    "token, expected_result",
    [
        ('"test"', "test"),  # Matches string schema
        ("123", 123),  # Matches number schema
    ],
)
def test_multiple_accepted_steppers(context, token, expected_result):
    """A string-or-number union yields the value of whichever branch matches."""
    union_machine = AnySchemaStateMachine(
        schemas=[{"type": "string"}, {"type": "number"}], context=context
    )
    assert parse_input(union_machine, token) == expected_result, (
        f"AnyOfAcceptor should accept input {token}."
    )
149 |
--------------------------------------------------------------------------------
/tests/unit/types/json/test_json_key_value.py:
--------------------------------------------------------------------------------
1 | from pse.types.json.json_key_value import (
2 | KeyValueSchemaStateMachine,
3 | KeyValueSchemaStepper,
4 | )
5 |
6 |
def test_property_parsing():
    """Stream a fixed key/value pair in chunks; exactly one stepper survives each chunk."""
    machine = KeyValueSchemaStateMachine(
        prop_name="type",
        prop_schema={"type": "string"},
        context={"defs": {}},
    )
    active = list(machine.get_steppers())

    for chunk in ('"', "type", '": "hi"'):
        active = machine.advance_all_basic(active, chunk)
        assert len(active) == 1

    assert active[0].has_reached_accept_state()
    assert active[0].get_current_value() == ("type", "hi")
23 |
24 |
def test_property_parsing_with_string_sm():
    """With prop_name=None, any quoted key is accepted and paired with the parsed value."""
    machine = KeyValueSchemaStateMachine(
        prop_name=None,  # exercises the branch that uses StringStateMachine for the key
        prop_schema={"type": "string"},
        context={"defs": {}, "path": "/parent"},
    )

    live = list(machine.get_steppers())
    live = machine.advance_all_basic(live, '"dynamic_key"')
    assert len(live) > 0

    live = machine.advance_all_basic(live, ': "value"')
    assert len(live) > 0
    assert any(s.has_reached_accept_state() for s in live)

    finished = [s for s in live if s.has_reached_accept_state()]
    for s in finished:
        assert s.get_current_value() == ("dynamic_key", "value")
44 |
45 |
def test_key_value_schema_stepper_initialization():
    """get_new_stepper returns a KeyValueSchemaStepper bound to the machine that made it."""
    machine = KeyValueSchemaStateMachine(
        prop_name="test_prop",
        prop_schema={"type": "string"},
        context={"defs": {}},
    )

    fresh = machine.get_new_stepper()
    assert isinstance(fresh, KeyValueSchemaStepper), (
        "Should return a KeyValueSchemaStepper instance"
    )
    assert fresh.state_machine is machine, (
        "Stepper should reference the correct state machine"
    )
57 |
def test_key_value_schema_stepper_equality():
    """Steppers that start equal and consume identical input remain equal."""
    machine = KeyValueSchemaStateMachine(
        prop_name="test_prop",
        prop_schema={"type": "string"},
        context={"defs": {}},
    )

    first = machine.get_steppers()
    second = machine.get_steppers()
    assert first[0] == second[0]

    payload = '"test_prop": "value"'
    first = machine.advance_all_basic(first, payload)
    assert len(first) == 1
    second = machine.advance_all_basic(second, payload)
    assert len(second) == 1
    assert first[0] == second[0]

    # A brand-new stepper has consumed nothing, so it must not equal an advanced one.
    untouched = machine.get_new_stepper()
    assert first[0] != untouched
78 |
--------------------------------------------------------------------------------
/tests/unit/types/json/test_json_to_state_machine.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Any
3 |
4 | import pytest
5 |
6 | from pse.types.base.chain import ChainStateMachine
7 | from pse.types.enum import EnumStateMachine
8 | from pse.types.json import _json_schema_to_state_machine
9 | from pse.types.json.any_json_schema import AnySchemaStateMachine
10 | from pse.types.json.json_array import ArraySchemaStateMachine
11 | from pse.types.json.json_number import NumberSchemaStateMachine
12 | from pse.types.json.json_object import ObjectSchemaStateMachine
13 | from pse.types.json.json_string import StringSchemaStateMachine
14 |
15 |
@pytest.mark.parametrize(
    "schema, expected_acceptor_cls, acceptor_len",
    [
        ({"type": "number", "minimum": 0}, NumberSchemaStateMachine, None),
        ({"type": "string", "nullable": True}, AnySchemaStateMachine, 2),
        ({"type": ["string", "number"], "minimum": 0}, AnySchemaStateMachine, 2),
        (
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "number", "minimum": 0},
                },
                "required": ["name", "age"],
            },
            ObjectSchemaStateMachine,
            None,
        ),
        (
            {"type": "array", "items": {"type": "string"}, "minItems": 1},
            ArraySchemaStateMachine,
            None,
        ),
        ({"enum": ["red", "green", "blue"]}, EnumStateMachine, None),
        (
            {"allOf": [{"type": "string"}, {"minLength": 5}]},
            StringSchemaStateMachine,
            None,
        ),
        ({"oneOf": [{"type": "string"}, {"type": "number"}]}, AnySchemaStateMachine, 2),
        ({"type": "string", "const": "fixed_value"}, ChainStateMachine, None),
    ],
)
def test_get_acceptor_schema_types(
    schema: dict[str, Any],
    expected_acceptor_cls: type[Any],
    acceptor_len: int | None,
) -> None:
    """_json_schema_to_state_machine builds the expected state machine class per schema.

    When ``acceptor_len`` is given, the result must be an AnySchemaStateMachine
    holding exactly that many branch state machines.
    """
    machine = _json_schema_to_state_machine(schema)
    assert isinstance(machine, expected_acceptor_cls), (
        f"Expected {expected_acceptor_cls.__name__} for schema {schema}"
    )
    if acceptor_len is None:
        return

    assert isinstance(machine, AnySchemaStateMachine)
    assert len(machine.state_machines) == acceptor_len, (
        f"Expected state_machine length {acceptor_len} for schema {schema}"
    )
64 |
65 |
@pytest.fixture
def context_with_definition() -> dict[str, Any]:
    """Provide a parsing context whose ``defs`` already holds an address object schema."""
    address_schema = {
        "type": "object",
        "properties": {"street": {"type": "string"}, "city": {"type": "string"}},
        "required": ["street", "city"],
    }
    definitions = defaultdict(dict)
    definitions["#/definitions/address"] = address_schema
    return {"defs": definitions, "path": ""}
76 |
77 |
def test_get_acceptor_with_ref_schema(context_with_definition: dict[str, Any]) -> None:
    """A $ref pointing at an object definition resolves to an ObjectSchemaStateMachine.

    The failure message names the real function and class under test and
    reports the type that was actually produced.
    """
    schema = {"$ref": "#/definitions/address"}
    state_machine = _json_schema_to_state_machine(schema, context_with_definition)
    assert isinstance(state_machine, ObjectSchemaStateMachine), (
        "_json_schema_to_state_machine should return an ObjectSchemaStateMachine "
        "for $ref schemas referencing object definitions, "
        f"got {type(state_machine).__name__}."
    )
88 |
--------------------------------------------------------------------------------
/tests/unit/types/json/test_json_value.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 |
5 | from pse.types.json.json_value import JsonStateMachine
6 |
7 |
@pytest.fixture
def json_acceptor():
    """Provide a fresh, default-constructed JsonStateMachine for each test."""
    machine = JsonStateMachine()
    return machine
11 |
12 |
@pytest.fixture
def parse_json(json_acceptor: JsonStateMachine):
    """Return a callable that feeds text one character at a time and yields a parsed value.

    The callable raises AssertionError if every stepper dies mid-input, or if
    none of the surviving steppers is in an accept state at the end. Otherwise
    it returns the value of the first accepting stepper.
    """

    def parser(json_string: str) -> Any:
        live = list(json_acceptor.get_steppers())
        for ch in json_string:
            live = json_acceptor.advance_all_basic(live, ch)
            if not live:
                raise AssertionError("No steppers after parsing")

        values = [s.get_current_value() for s in live if s.has_reached_accept_state()]
        if values:
            return values[0]
        raise AssertionError("No accepted steppers after parsing")

    return parser
31 |
32 |
def test_json_acceptor_initialization(json_acceptor):
    """The json_acceptor fixture yields a JsonStateMachine."""
    created_ok = isinstance(json_acceptor, JsonStateMachine)
    assert created_ok, "JsonAcceptor instance was not created."
37 |
38 |
@pytest.mark.parametrize(
    "json_string, expected",
    [
        (
            '{"name": "John", "age": 30, "city": "New York"}',
            {"name": "John", "age": 30, "city": "New York"},
        ),
        ('["apple", "banana", "cherry"]', ["apple", "banana", "cherry"]),
        ('{"message": "Hello, \\nWorld! \\t😊"}', {"message": "Hello, \nWorld! \t😊"}),
        (
            '{"greeting": "こんにちは世界", "emoji": "🚀🌟"}',
            {"greeting": "こんにちは世界", "emoji": "🚀🌟"},
        ),
    ],
)
def test_parse_valid_json(json_string, expected, parse_json):
    """Objects, arrays, escape sequences and non-ASCII text decode to the expected values."""
    assert parse_json(json_string) == expected
57 |
58 |
@pytest.mark.parametrize(
    "json_string",
    [
        '{"name": "John", "age": 30,, "city": "New York"}',  # Invalid syntax
        "",  # Empty JSON string
    ],
)
def test_parse_invalid_json(json_string, parse_json):
    """Malformed or empty input makes the parse_json helper raise AssertionError."""
    with pytest.raises(AssertionError):
        _ = parse_json(json_string)
69 |
70 |
def test_non_zero_state():
    """A JsonStateMachine reports no edges for state 1."""
    assert not JsonStateMachine().get_edges(1)
75 |
--------------------------------------------------------------------------------
/tests/unit/types/misc/test_fenced_freeform.py:
--------------------------------------------------------------------------------
1 | from pse.types.misc.fenced_freeform import FencedFreeformStateMachine
2 |
3 |
def test_fenced_freeform_default_delimiters():
    """Text wrapped in plain ``` fences is accepted with the default configuration."""
    machine = FencedFreeformStateMachine()
    live = machine.advance_all_basic(
        machine.get_steppers(), "```\nSome freeform text\n```"
    )
    assert any(s.has_reached_accept_state() for s in live)
12 |
13 |
def test_fenced_freeform_custom_delimiter():
    """Test FencedFreeformStateMachine with a custom identifier after the opening fence.

    The identifier is checked on the steppers that actually reached an accept
    state. ``steppers[0]`` may be some other surviving stepper, so its
    identifier says nothing about the accepted parse.
    """
    sm = FencedFreeformStateMachine(identifier="json")
    input_sequence = "```json\n{\"key\": \"value\"}\n```"
    steppers = sm.get_steppers()
    steppers = sm.advance_all_basic(steppers, input_sequence)

    accepted = [stepper for stepper in steppers if stepper.has_reached_accept_state()]
    assert accepted, "Expected at least one stepper in an accept state"
    assert all(stepper.get_identifier() == "json" for stepper in accepted)
23 |
24 |
def test_fenced_freeform_missing_open_delimiter():
    """Text that lacks the opening fence never reaches an accept state."""
    machine = FencedFreeformStateMachine()
    live = machine.advance_all_basic(machine.get_steppers(), "Some freeform text\n```")
    accepted = any(s.has_reached_accept_state() for s in live)
    assert not accepted
33 |
34 |
def test_fenced_freeform_missing_close_delimiter():
    """Without a closing fence nothing accepts.

    Steppers whose sub-stepper is not inside a value must report no invalid
    continuations.
    """
    machine = FencedFreeformStateMachine()
    live = machine.advance_all_basic(machine.get_steppers(), "```\nSome freeform text")

    assert not any(s.has_reached_accept_state() for s in live)
    for s in live:
        sub = s.sub_stepper
        # Skip steppers with no sub-stepper, or whose sub-stepper is mid-value.
        if not sub or sub.is_within_value():
            continue
        assert not s.get_invalid_continuations()
46 |
47 |
def test_fenced_freeform_partial_delimiter():
    """An opening fence of only two backticks is not enough to accept."""
    machine = FencedFreeformStateMachine()
    live = machine.advance_all_basic(machine.get_steppers(), "``\nSome freeform text\n```")
    accepted = any(s.has_reached_accept_state() for s in live)
    assert not accepted
56 |
57 |
def test_fenced_freeform_respects_char_min():
    """With char_min=10, "short" is rejected and "long enough text" is accepted."""
    machine = FencedFreeformStateMachine(char_min=10)
    cases = (
        ("```\nshort\n```", False),
        ("```\nlong enough text\n```", True),
    )
    for text, should_accept in cases:
        live = machine.advance_all_basic(machine.get_steppers(), text)
        assert any(s.has_reached_accept_state() for s in live) is should_accept
72 |
73 |
def test_fenced_freeform_respects_char_max():
    """With char_max=10, long content is rejected and "short" is accepted."""
    machine = FencedFreeformStateMachine(char_max=10)
    cases = (
        ("```\nthis text is too long\n```", False),
        ("```\nshort\n```", True),
    )
    for text, should_accept in cases:
        live = machine.advance_all_basic(machine.get_steppers(), text)
        assert any(s.has_reached_accept_state() for s in live) is should_accept
88 |
89 |
def test_fenced_freeform_invalid_continuations():
    """Steppers with target_state 2 list the newline-prefixed fence as an invalid continuation.

    NOTE(review): if no surviving stepper has target_state 2, the loop below
    checks nothing and the test passes vacuously.
    """
    machine = FencedFreeformStateMachine()
    live = machine.advance_all_basic(machine.get_steppers(), "```\nSome text")
    assert len(live) > 0

    state_two = [s for s in live if s.target_state == 2]
    for s in state_two:
        assert "\n```" in s.get_invalid_continuations()
100 |
def test_fenced_freeform_buffer_length():
    """With buffer_length=5, no stepper starts on the fence initially; after 5 characters one does."""
    fence = "```\n"
    machine = FencedFreeformStateMachine(buffer_length=5)

    live = machine.get_steppers()
    assert not any(s.should_start_step(fence) for s in live)

    live = machine.advance_all_basic(live, "12345")
    assert any(s.should_start_step(fence) for s in live)
109 |
--------------------------------------------------------------------------------
/tests/unit/types/misc/test_freeform.py:
--------------------------------------------------------------------------------
1 | from pse.types.misc.freeform import FreeformStateMachine, FreeformStepper
2 |
3 |
def test_freeform_basic():
    """Text ending in the delimiter is accepted.

    The raw value keeps the delimiter; the token-safe output drops it.
    """
    machine = FreeformStateMachine(end_delimiters=["END"])
    live = machine.advance_all_basic(machine.get_steppers(), "Some freeform textEND")

    assert any(s.has_reached_accept_state() for s in live)
    winner = next(s for s in live if s.has_reached_accept_state())
    assert winner.get_raw_value() == "Some freeform textEND"

    def decode(token_ids):
        # Map each integer id to the character with that code point.
        return "".join(map(chr, token_ids))

    # The token-safe output excludes the delimiter.
    assert winner.get_token_safe_output(decode) == "Some freeform text"
17 |
18 |
def test_freeform_multiple_delimiters():
    """Each configured end delimiter can terminate the text on its own."""
    machine = FreeformStateMachine(end_delimiters=["END", "STOP", "FINISH"])
    samples = (
        "Text with first delimiterEND",
        "Text with second delimiterSTOP",
        "Text with third delimiterFINISH",
    )
    for sample in samples:
        live = machine.advance_all_basic(machine.get_steppers(), sample)
        assert any(s.has_reached_accept_state() for s in live)
40 |
41 |
def test_freeform_missing_delimiter():
    """Text that never reaches the end delimiter is not accepted."""
    machine = FreeformStateMachine(end_delimiters=["END"])
    live = machine.advance_all_basic(machine.get_steppers(), "Text without delimiter")
    accepted = any(s.has_reached_accept_state() for s in live)
    assert not accepted
50 |
51 |
def test_freeform_partial_delimiter():
    """Text ending in only a prefix of the delimiter ("EN") is not accepted."""
    machine = FreeformStateMachine(end_delimiters=["END"])
    live = machine.advance_all_basic(
        machine.get_steppers(), "Text with partial delimiterEN"
    )
    accepted = any(s.has_reached_accept_state() for s in live)
    assert not accepted
60 |
61 |
def test_freeform_respects_char_min():
    """With char_min=10, "short" is rejected and "long enough text" is accepted."""
    machine = FreeformStateMachine(end_delimiters=["END"], char_min=10)
    cases = (
        ("shortEND", False),  # shorter than char_min
        ("long enough textEND", True),  # meets char_min
    )
    for text, should_accept in cases:
        live = machine.advance_all_basic(machine.get_steppers(), text)
        assert any(s.has_reached_accept_state() for s in live) is should_accept
77 |
78 |
def test_freeform_stepper_initialization():
    """A new stepper is a FreeformStepper tied to its machine and starts with an empty buffer."""
    machine = FreeformStateMachine(end_delimiters=["END"])
    fresh = machine.get_new_stepper()

    assert isinstance(fresh, FreeformStepper)
    assert fresh.state_machine == machine
    assert fresh.buffer == ""
87 |
88 |
def test_freeform_token_safe_output():
    """The token-safe output strips whichever delimiter ended the text."""
    machine = FreeformStateMachine(end_delimiters=["END", "STOP"])

    def decode(token_ids):
        # Map each integer id to the character with that code point.
        return "".join(chr(i) for i in token_ids)

    cases = (
        ("Some textEND", "Some text"),  # first delimiter
        ("Other textSTOP", "Other text"),  # second delimiter
    )
    for text, expected in cases:
        live = machine.advance_all_basic(machine.get_steppers(), text)
        winner = next(s for s in live if s.has_reached_accept_state())
        assert winner.get_token_safe_output(decode) == expected
104 |
105 |
def test_freeform_get_raw_value():
    """get_raw_value returns the whole input, including the trailing delimiter."""
    machine = FreeformStateMachine(end_delimiters=["END"])
    text = "Raw text valueEND"
    live = machine.advance_all_basic(machine.get_steppers(), text)

    winner = next(s for s in live if s.has_reached_accept_state())
    assert winner.get_raw_value() == text
    # The delimiter is part of the raw value.
    assert winner.get_raw_value().endswith("END")
118 |
119 |
def test_freeform_string_representation():
    """str() of a FreeformStateMachine is the label "FreeformText"."""
    rendered = str(FreeformStateMachine(end_delimiters=["END"]))
    assert rendered == "FreeformText"
124 |
--------------------------------------------------------------------------------
/tests/unit/types/python/test_python_code.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pse_core.state_machine import StateMachine
3 |
4 | from pse.types.grammar import PythonStateMachine
5 |
6 |
@pytest.mark.parametrize(
    "code, should_accept",
    [
        ("x = 1", True),
        ("x =! 1", False),
        ("x != 1", True),
        ("def foo():\n pass", True),
        ("class Test:\n pass", True),
        ("print('Hello, World!')", True),
        ("if True:\n print('test')", True),
        # Test incomplete code
        ("1 + ", False),
        ("def", False),
        ("for i in", False),
        ("x =", False),
        ("", False),
        ("class:", False),
        ("x y z", False),
        # Test expressions
        ("1 + 2", True),
        ("len([1, 2, 3])", True),
        ("'test'.upper()", True),
        # Test multiline code
        ("x = 1\ny = 2\nz = x + y", True),
        ("def test():\n x = 1\n return x", True),
        # Test comments
        ("# This is a comment", True),
        ("x = 1 # inline comment", True),
        ("'''docstring'''", True),
        # Test complex Python features
        ("lambda x: x * 2", True),
        ("try:\n x()\nexcept:\n pass", True),
        ("with open('file') as f:\n pass", True),
        ("[x for x in range(10)]", True),
        # Test invalid syntax
        ("def def", False),
        ("class class", False),
        ("return return", False),
        ("import import", False),
    ],
)
def test_python_source_validation(code, should_accept):
    """Check that accepting the whole input matches the expected verdict for each snippet."""
    wrapper = StateMachine({0: [(PythonStateMachine, "$")]})
    live = wrapper.advance_all_basic(wrapper.get_steppers(), code)
    accepted = any(s.has_reached_accept_state() for s in live)

    if should_accept:
        assert accepted, f"Should accept valid Python code: {code}"
    else:
        assert not accepted, f"Should not accept invalid Python code: {code}"
66 |
67 |
def test_incremental_parsing():
    """Python source fed in small fragments keeps steppers alive and ends accepted."""
    wrapper = StateMachine({0: [(PythonStateMachine, "$")]})
    live = wrapper.get_steppers()

    fragments = (
        "def test", " ", "(", ")", ":", "\n",
        " ", "x", " ", "=", " ", "1", "\n",
        " ", "return", " ",
    )
    for fragment in fragments:
        live = wrapper.advance_all_basic(live, fragment)
        assert len(live) > 0, f"Should accept partial input: {fragment}"

    assert any(s.has_reached_accept_state() for s in live)
102 |
103 |
def test_empty_input():
    """Advancing over an empty string leaves no stepper in an accept state."""
    wrapper = StateMachine({0: [(PythonStateMachine, "$")]})
    live = wrapper.advance_all_basic(wrapper.get_steppers(), "")
    accepted = any(s.has_reached_accept_state() for s in live)
    assert not accepted
114 |
115 |
@pytest.mark.parametrize(
    "incomplete_code",
    [
        "if x == 1:\n ",
        "def test():\n x = ",
        "class MyClass:\n def ",
        "try:\n x = 1\nexcept ",
        "with open('file') as ",
        "def foo():",
        "if True:",
        "for x in range(10):",
    ],
)
def test_incomplete_but_valid_code(incomplete_code):
    """A valid-but-unfinished prefix leaves one stepper that can still take more input."""
    wrapper = StateMachine({0: [(PythonStateMachine, "$")]})
    live = wrapper.advance_all_basic(wrapper.get_steppers(), incomplete_code)

    assert len(live) == 1
    assert all(s.can_accept_more_input() for s in live)
140 |
141 |
def test_identifer():
    """The stepper from PythonStateMachine.get_new_stepper(None) reports the identifier "python"."""
    stepper = PythonStateMachine.get_new_stepper(None)
    assert stepper.get_identifier() == "python"
145 |
def test_invalid_code_advance():
    """A run of "!" after an open call leaves no live steppers."""
    wrapper = StateMachine({0: [(PythonStateMachine, "$")]})
    live = wrapper.advance_all_basic(wrapper.get_steppers(), "print(")
    assert len(live) == 1

    # Characters that cannot continue the call expression.
    live = wrapper.advance_all_basic(live, '!!!!!!!!!')
    assert not live
159 |
--------------------------------------------------------------------------------
/tests/unit/types/python/test_wrapped_python.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pse.types.base.encapsulated import EncapsulatedStateMachine
4 | from pse.types.grammar import PythonStateMachine
5 |
6 |
def test_basic_python_block():
    """Opening fence, code, and closing fence streamed separately end in an accept state."""
    block_sm = EncapsulatedStateMachine(
        PythonStateMachine, delimiters=PythonStateMachine.delimiters
    )
    live = block_sm.get_steppers()

    # The opening delimiter, then the code, must each keep steppers alive.
    for piece in ("```python\n", "x = 1"):
        live = block_sm.advance_all_basic(live, piece)
        assert len(live) > 0

    # The closing delimiter completes the block.
    live = block_sm.advance_all_basic(live, "\n```")
    assert any(s.has_reached_accept_state() for s in live)
24 |
25 |
@pytest.mark.parametrize(
    "code_block",
    [
        "```python\nprint('Hello')\n\n```",
        "```python\nx = 1\ny = 2\nprint(x + y)\n\n```",
        "```python\ndef test():\n return True\n\n```",
        "```python\nclass Test:\n pass\n\n```",
    ],
)
def test_complete_python_blocks(code_block):
    """A full fenced Python block, fed in one go, reaches an accept state."""
    block_sm = EncapsulatedStateMachine(
        PythonStateMachine, delimiters=PythonStateMachine.delimiters
    )
    live = block_sm.advance_all_basic(block_sm.get_steppers(), code_block)
    assert any(s.has_reached_accept_state() for s in live)
43 |
44 |
def test_custom_delimiters():
    """With empty delimiters, the encapsulated machine accepts bare Python code."""
    bare_sm = EncapsulatedStateMachine(PythonStateMachine, delimiters=("", ""))
    live = bare_sm.advance_all_basic(bare_sm.get_steppers(), "x = 1")
    assert any(s.has_reached_accept_state() for s in live)
55 |
56 |
def test_stepper_clone():
    """Test cloning a stepper of the encapsulated Python state machine.

    A clone must be a distinct object that reports the same current value.
    The test first asserts that some stepper survived, so a parsing
    regression fails with a clear message rather than an IndexError.
    """
    python_sm = EncapsulatedStateMachine(
        PythonStateMachine, delimiters=PythonStateMachine.delimiters
    )
    steppers = python_sm.get_steppers()
    steppers = python_sm.advance_all_basic(steppers, "```python\nx = 1\n")
    assert steppers, "Expected at least one live stepper before cloning"

    original_stepper = steppers[0]
    cloned_stepper = original_stepper.clone()

    assert original_stepper.get_current_value() == cloned_stepper.get_current_value()
    assert original_stepper is not cloned_stepper
70 |
71 |
@pytest.mark.parametrize(
    "invalid_block",
    [
        "```python\ndef invalid syntax\n\n```",
        "```python\nclass:\n\n```",
        "```python\nwhile:\n\n```",
        "```python\nprint('no closing parenthesis'\n\n```",
    ],
)
def test_invalid_python_blocks(invalid_block):
    """A fenced block containing syntactically invalid Python is never accepted."""
    block_sm = EncapsulatedStateMachine(
        PythonStateMachine, delimiters=PythonStateMachine.delimiters
    )
    live = block_sm.advance_all_basic(block_sm.get_steppers(), invalid_block)
    accepted = any(s.has_reached_accept_state() for s in live)
    assert not accepted
89 |
90 |
def test_incremental_parsing():
    """A fenced Python block streamed line by line stays alive and ends accepted."""
    block_sm = EncapsulatedStateMachine(
        PythonStateMachine, delimiters=PythonStateMachine.delimiters
    )
    live = block_sm.get_steppers()

    chunks = (
        "```python\n",
        "def test():\n",
        " x = 1\n",
        " return x\n",
        "\n```",
    )
    for chunk in chunks:
        live = block_sm.advance_all_basic(live, chunk)
        assert len(live) > 0

    assert any(s.has_reached_accept_state() for s in live)
111 |
112 |
def test_can_accept_more_input():
    """Steppers accept more input while the block is open, and accept once the fence closes."""
    block_sm = EncapsulatedStateMachine(
        PythonStateMachine, delimiters=PythonStateMachine.delimiters
    )
    live = block_sm.advance_all_basic(block_sm.get_steppers(), "```python\n")
    assert len(live) > 0

    # Mid-string, then after the call is closed: more input must still be allowed.
    for chunk in ("print('Hello", "')"):
        live = block_sm.advance_all_basic(live, chunk)
        assert len(live) > 0
        assert all(s.can_accept_more_input() for s in live)

    live = block_sm.advance_all_basic(live, "\n```")
    assert len(live) > 0
    assert any(s.has_reached_accept_state() for s in live)
130 |
--------------------------------------------------------------------------------
/tests/unit/types/test_array.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 |
5 | from pse.types.array import ArrayStateMachine
6 |
7 |
@pytest.fixture
def state_machine():
    """Provide a fresh ArrayStateMachine for each test."""
    machine = ArrayStateMachine()
    return machine
12 |
13 |
def parse_array(state_machine: ArrayStateMachine, json_string: str) -> list[Any]:
    """
    Parse a JSON array string by feeding it to an ArrayStateMachine one character at a time.

    Args:
        state_machine (ArrayStateMachine): The state machine that drives parsing.
        json_string (str): The JSON array text to feed.

    Returns:
        list[Any]: The current value of the first stepper that reached an accept state.

    Raises:
        AssertionError: If no stepper ends in an accept state (invalid or incomplete array).
    """
    steppers = state_machine.get_steppers()
    for char in json_string:
        steppers = state_machine.advance_all_basic(steppers, char)

    # The first accepting stepper holds the parsed value.
    accepted = next(
        (stepper for stepper in steppers if stepper.has_reached_accept_state()), None
    )
    if accepted is None:
        raise AssertionError("No stepper in accepted state")
    return accepted.get_current_value()
38 |
39 |
# Well-formed arrays paired with the Python lists they should decode to
@pytest.mark.parametrize(
    "json_string, expected",
    [
        ("[]", []),
        ("[1]", [1]),
        ("[123]", [123]),
        ('[123, 456, "789"]', [123, 456, "789"]),
        ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]),
    ],
)
def test_valid_arrays(
    state_machine: ArrayStateMachine, json_string: str, expected: list[Any]
):
    """parse_array decodes empty, flat, mixed-type and nested arrays."""
    parsed = parse_array(state_machine, json_string)
    assert parsed == expected
56 |
57 |
# Malformed arrays that parse_array must reject
@pytest.mark.parametrize(
    "json_string",
    [
        "[123, 456",  # Missing closing bracket
        "[123, 456, ]",  # Trailing comma
    ],
)
def test_invalid_arrays(state_machine: ArrayStateMachine, json_string: str):
    """parse_array raises AssertionError when no stepper accepts the input."""
    with pytest.raises(AssertionError):
        _ = parse_array(state_machine, json_string)
70 |
--------------------------------------------------------------------------------
/tests/unit/types/test_boolean.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pse_core.state_machine import StateMachine
3 | from pse_core.stepper import Stepper
4 |
5 | from pse.types.boolean import BooleanStateMachine
6 |
7 |
# A fresh BooleanStateMachine per test
@pytest.fixture
def boolean_acceptor():
    """Provide a new BooleanStateMachine."""
    machine = BooleanStateMachine()
    return machine
12 |
13 |
# Helper: run a whole token through a state machine from its initial steppers
def process_input(state_machine: StateMachine, token: str) -> list[Stepper]:
    """Advance the initial steppers of *state_machine* over *token* and return the result."""
    initial = state_machine.get_steppers()
    survivors = state_machine.advance_all_basic(initial, token)
    return survivors
18 |
19 |
def test_accept_true() -> None:
    """'true' must drive some stepper into an accept state holding True."""
    acceptor = BooleanStateMachine()
    steppers = acceptor.get_steppers()
    steppers = acceptor.advance_all_basic(steppers, "true")
    # Require acceptance as well as the value. The rejection tests in this
    # module use has_reached_accept_state() as the success signal, so the
    # positive test should check the same thing.
    assert any(
        stepper.has_reached_accept_state() and stepper.get_current_value() is True
        for stepper in steppers
    ), "Should have a stepper with value True"
28 |
29 |
def test_accept_false(boolean_acceptor: BooleanStateMachine) -> None:
    """'false' must drive some stepper into an accept state holding False."""
    steppers = boolean_acceptor.get_steppers()
    steppers = boolean_acceptor.advance_all_basic(steppers, "false")
    # Check acceptance explicitly, matching how the rejection tests in this
    # module judge success, rather than relying on the value alone.
    assert any(
        stepper.has_reached_accept_state() and stepper.get_current_value() is False
        for stepper in steppers
    ), "Should have a stepper with value False"
36 |
37 |
@pytest.mark.parametrize("token", ["tru", "fals", "True", "False", "TRUE", "FALSE"])
def test_reject_invalid_boolean(boolean_acceptor, token):
    """Truncated or wrongly-cased spellings must not be accepted."""
    steppers = process_input(boolean_acceptor, token)
    accepted = [s for s in steppers if s.has_reached_accept_state()]
    assert not accepted, f"Should not accept '{token}' as a valid boolean."
44 |
45 |
@pytest.mark.parametrize("token", [" true", "false ", " true ", " false "])
def test_accept_with_whitespace(boolean_acceptor, token):
    """Leading or trailing whitespace makes the token invalid.

    Despite its name, this is a rejection test: none of these padded
    tokens may reach an accept state.
    """
    reached = [s.has_reached_accept_state() for s in process_input(boolean_acceptor, token)]
    assert not any(reached), f"Should not accept '{token}' with whitespace."
52 |
53 |
@pytest.mark.parametrize("token", ["truex", "falsey", "true123", "false!"])
def test_extra_characters(boolean_acceptor, token):
    """A valid literal followed by trailing characters must be rejected."""
    for stepper in process_input(boolean_acceptor, token):
        assert not stepper.has_reached_accept_state(), (
            f"Should not accept '{token}' with extra characters."
        )
60 |
--------------------------------------------------------------------------------
/tests/unit/types/test_enum.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pse.types.enum import EnumStateMachine
4 |
5 |
def test_accept_valid_enum_value():
    """A listed value is accepted and reported back unchanged."""
    sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False)
    steppers = sm.advance_all_basic(sm.get_steppers(), "value1")

    accepted = [stepper for stepper in steppers if stepper.has_reached_accept_state()]
    assert accepted
    for stepper in accepted:
        assert stepper.get_current_value() == "value1"
16 |
17 |
def test_reject_invalid_enum_value():
    """A value that is not one of the enum members must not be accepted."""
    sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False)
    steppers = sm.advance_all_basic(sm.get_steppers(), "invalid_value")

    assert all(not stepper.has_reached_accept_state() for stepper in steppers)
25 |
26 |
@pytest.mark.parametrize("value", ["value1", "value2", "value3"])
def test_accept_multiple_enum_values(value):
    """Every enum member is accepted and reported back as itself."""
    sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False)
    steppers = sm.advance_all_basic(sm.get_steppers(), value)

    accepted_values = [
        stepper.get_current_value()
        for stepper in steppers
        if stepper.has_reached_accept_state()
    ]
    assert accepted_values
    assert all(found == value for found in accepted_values)
38 |
39 |
def test_partial_enum_value_rejection():
    """A bare prefix of enum members ("val") is not itself accepted."""
    sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False)
    steppers = sm.advance_all_basic(sm.get_steppers(), "val")

    accepted = [s for s in steppers if s.has_reached_accept_state()]
    assert not accepted
47 |
48 |
def test_init_with_empty_enum():
    """Constructing an EnumStateMachine with no values raises ValueError."""
    no_values: list[str] = []
    with pytest.raises(ValueError):
        EnumStateMachine(enum_values=no_values)
53 |
54 |
@pytest.mark.parametrize("special_value", ["val!@#", "val-123", "val_😊"])
def test_accept_enum_with_special_characters(special_value):
    """Members containing punctuation, hyphens or emoji are accepted and echoed back."""
    sm = EnumStateMachine([special_value], require_quotes=False)
    steppers = sm.advance_all_basic(sm.get_steppers(), special_value)

    accepted = [s for s in steppers if s.has_reached_accept_state()]
    assert accepted
    for stepper in accepted:
        assert stepper.get_current_value() == special_value
66 |
67 |
def test_char_by_char_enum_parsing():
    """Feeding "value1" one character at a time still ends in acceptance."""
    sm = EnumStateMachine(["value1", "value2", "value3"], require_quotes=False)
    steppers = sm.get_steppers()

    for ch in "value1":
        steppers = sm.advance_all_basic(steppers, ch)
        if not steppers:
            # Every path has died; there is nothing left to feed.
            break

    accepted = [s for s in steppers if s.has_reached_accept_state()]
    assert accepted
    for stepper in accepted:
        assert stepper.get_current_value() == "value1"
82 |
83 |
@pytest.mark.parametrize(
    "value",
    [
        '"test"',
        "'test'",
    ],
)
def test_enum_with_quotes(value):
    """Quoted values (double or single quotes) are accepted by default.

    Every surviving stepper must be in an accept state and report the
    unquoted member ``"test"``.
    """
    sm = EnumStateMachine(["test"])  # require_quotes defaults to True
    steppers = sm.get_steppers()
    steppers = sm.advance_all_basic(steppers, value)

    # Without this guard the loop below checks nothing, and the test passes
    # vacuously, when advancing leaves no steppers at all.
    assert steppers, f"No steppers survived input {value!r}"
    for stepper in steppers:
        assert stepper.has_reached_accept_state()
        assert stepper.get_current_value() == "test"
100 |
101 |
def test_enum_without_quotes():
    """With require_quotes=False a bare member is accepted as-is."""
    sm = EnumStateMachine(["test"], require_quotes=False)
    steppers = sm.advance_all_basic(sm.get_steppers(), "test")

    accepted_values = [
        s.get_current_value() for s in steppers if s.has_reached_accept_state()
    ]
    assert accepted_values
    for found in accepted_values:
        assert found == "test"
112 |
113 |
def test_enum_requires_quotes_by_default():
    """With default settings an unquoted member must be rejected."""
    sm = EnumStateMachine(["test"])  # require_quotes defaults to True
    unquoted = "test"
    steppers = sm.advance_all_basic(sm.get_steppers(), unquoted)

    assert all(not s.has_reached_accept_state() for s in steppers)
121 |
--------------------------------------------------------------------------------
/tests/unit/types/test_key_value.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pse.types.base.phrase import PhraseStateMachine
4 | from pse.types.key_value import KeyValueStateMachine
5 |
6 |
@pytest.mark.parametrize(
    "input_string, expected_name, expected_value",
    [
        ('"key": "value"', "key", "value"),
        ('"complex_key": {"nested": "value"}', "complex_key", {"nested": "value"}),
        # NOTE(review): the value "empty_key" is expected to come back as None
        # for an empty key; confirm this is intended rather than a known quirk.
        ('"": "empty_key"', "", None),
        ('"unicode_key": "unicode_value🎉"', "unicode_key", "unicode_value🎉"),
        ('"spaced_key" : "spaced_value"', "spaced_key", "spaced_value"),
    ],
)
def test_property_parsing(input_string, expected_name, expected_value):
    """Feed a key/value pair one character at a time and check the parsed pair."""
    sm = KeyValueStateMachine()
    steppers = sm.get_steppers()
    for ch in input_string:
        steppers = sm.advance_all_basic(steppers, ch)

    accepted_steppers = [s for s in steppers if s.has_reached_accept_state()]
    assert accepted_steppers, (
        f"No stepper reached an accepted state for: {input_string}"
    )

    for stepper in accepted_steppers:
        parsed_name, parsed_value = stepper.get_current_value()
        assert parsed_name == expected_name
        assert parsed_value == expected_value
34 |
35 |
@pytest.mark.parametrize(
    "invalid_input",
    [
        'key: "value"',  # missing quotes around key
        '"key" "value"',  # missing colon
        '"key":',  # missing value
        '"key": value',  # unquoted value
        ':"value"',  # missing key
    ],
)
def test_invalid_property_formats(invalid_input):
    """Malformed key/value text must never reach an accept state."""
    sm = KeyValueStateMachine()
    steppers = sm.get_steppers()
    for ch in invalid_input:
        steppers = sm.advance_all_basic(steppers, ch)

    accepted = [s for s in steppers if s.has_reached_accept_state()]
    assert not accepted, (
        f"Stepper should not reach accepted state for invalid input: {invalid_input}"
    )
55 |
56 |
def test_empty_key_value():
    """A fresh machine yields one stepper holding ("", None) that cannot complete yet."""
    steppers = KeyValueStateMachine().get_steppers()
    assert len(steppers) == 1

    first = steppers[0]
    assert not first.should_complete_step()
    assert first.get_current_value() == ("", None)
63 |
64 |
def test_invalid_sub_stepper_json():
    """Advancing through a hand-installed phrase sub-stepper leaves one incomplete stepper."""
    sm = KeyValueStateMachine()
    steppers = sm.get_steppers()
    assert len(steppers) == 1

    # Replace the only stepper's sub_stepper with one that expects the literal
    # text '"invalid json', then feed exactly that text.
    only = steppers[0]
    only.sub_stepper = PhraseStateMachine('"invalid json').get_new_stepper()

    steppers = sm.advance_all_basic(steppers, '"invalid json')
    assert len(steppers) == 1
    assert not steppers[0].should_complete_step()
74 |
def test_key_value_equality():
    """Independently constructed KeyValueStateMachines compare equal pairwise."""
    first, second, third = (KeyValueStateMachine() for _ in range(3))

    assert first == second
    assert first == third
    assert second == third
83 |
--------------------------------------------------------------------------------
/tests/unit/types/xml/test_xml_encapsulated.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from pse.types.base.phrase import PhraseStateMachine
6 | from pse.types.xml.xml_encapsulated import XMLEncapsulatedStateMachine
7 |
8 |
def test_basic_wrapped_content() -> None:
    """Test recognition of basic wrapped content."""
    inner_sm = PhraseStateMachine("content")
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div")
    steppers = wrapped_sm.get_steppers()

    # Opening tag, the inner phrase, then the matching closing tag.
    input_sequence = "<div>content</div>"
    steppers = wrapped_sm.advance_all_basic(steppers, input_sequence)
    assert any(stepper.has_reached_accept_state() for stepper in steppers)
18 |
19 |
def test_incremental_parsing() -> None:
    """Test incremental parsing of wrapped content."""
    inner_sm = PhraseStateMachine("hello")
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "greeting")
    steppers = wrapped_sm.get_steppers()

    # The opening tag, the inner phrase split across two chunks, and the
    # closing tag, each fed as a separate chunk.
    parts = ["<greeting>", "hel", "lo", "</greeting>"]

    for part in parts:
        steppers = wrapped_sm.advance_all_basic(steppers, part)
        # Some stepper must stay alive after every chunk.
        assert len(steppers) > 0

    assert any(stepper.has_reached_accept_state() for stepper in steppers)
33 |
34 |
def test_nested_wrapped_content() -> None:
    """Test nested XML wrapped content."""
    inner_content = PhraseStateMachine("text")
    inner_wrapped = XMLEncapsulatedStateMachine(inner_content, "span")
    outer_wrapped = XMLEncapsulatedStateMachine(inner_wrapped, "div")

    # The span wrapper sits inside the div wrapper.
    input_sequence = "<div><span>text</span></div>"
    steppers = outer_wrapped.get_steppers()
    steppers = outer_wrapped.advance_all_basic(steppers, input_sequence)

    assert any(stepper.has_reached_accept_state() for stepper in steppers)
46 |
47 |
def test_min_buffer_length() -> None:
    """Test that min_buffer_length is respected."""
    inner_sm = PhraseStateMachine("content")
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div", min_buffer_length=10)
    steppers = wrapped_sm.get_steppers()

    # Should not start with opening tag yet
    # NOTE(review): this probe passes "" rather than the opening tag; confirm
    # it exercises the buffer check as the comment above intends.
    assert not any(stepper.should_start_step("") for stepper in steppers)

    # Add sufficient buffer (exactly min_buffer_length characters).
    buffer_content = "x" * 10
    steppers = wrapped_sm.advance_all_basic(steppers, buffer_content)
    # With the buffer filled, some stepper may begin the opening tag.
    assert any(stepper.should_start_step("<div>") for stepper in steppers)
61 |
62 |
def test_invalid_content() -> None:
    """Test rejection of invalid inner content."""
    inner_sm = PhraseStateMachine("expected")
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div")
    steppers = wrapped_sm.get_steppers()

    # Tags are correct, but the inner text is not the phrase "expected".
    input_sequence = "<div>unexpected</div>"
    steppers = wrapped_sm.advance_all_basic(steppers, input_sequence)
    assert not any(stepper.has_reached_accept_state() for stepper in steppers)
72 |
73 |
def test_missing_closing_tag() -> None:
    """Test rejection when closing tag is missing."""
    inner_sm = PhraseStateMachine("content")
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div")
    steppers = wrapped_sm.get_steppers()

    # Input stops right after the inner phrase; no "</div>" follows.
    input_sequence = "<div>content"
    steppers = wrapped_sm.advance_all_basic(steppers, input_sequence)
    assert not any(stepper.has_reached_accept_state() for stepper in steppers)
83 |
84 |
def test_mismatched_tags() -> None:
    """Test rejection of mismatched opening and closing tags."""
    inner_sm = PhraseStateMachine("content")
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, "div")
    steppers = wrapped_sm.get_steppers()

    # Opened as <div> but closed as </span>; this must differ from the
    # missing-closing-tag case, which supplies no closing tag at all.
    input_sequence = "<div>content</span>"
    steppers = wrapped_sm.advance_all_basic(steppers, input_sequence)
    assert not any(stepper.has_reached_accept_state() for stepper in steppers)
94 |
95 |
@pytest.mark.parametrize(
    "tag_name, content, should_accept",
    [
        ("div", "content", True),
        ("p", "multi\nline\ncontent", True),
        ("custom-tag", "content with spaces", True),
    ],
)
def test_various_content_scenarios(
    tag_name: str, content: str, should_accept: bool
) -> None:
    """Wrap *content* in <tag_name>...</tag_name> and check the expected outcome.

    ``should_accept`` selects whether acceptance or rejection is expected,
    so negative rows can be added to the table without touching the body.
    """
    inner_sm = PhraseStateMachine(content)
    wrapped_sm = XMLEncapsulatedStateMachine(inner_sm, tag_name)
    steppers = wrapped_sm.get_steppers()

    input_sequence = f"<{tag_name}>{content}</{tag_name}>"
    steppers = wrapped_sm.advance_all_basic(steppers, input_sequence)
    accepted = any(stepper.has_reached_accept_state() for stepper in steppers)
    assert accepted == should_accept
115 |
--------------------------------------------------------------------------------
/tests/unit/types/xml/test_xml_tag.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pse.types.base.chain import ChainStateMachine
4 | from pse.types.xml.xml_tag import XMLTagStateMachine
5 |
6 |
def test_basic_tag_initialization() -> None:
    """An opening-tag machine is a ChainStateMachine with one non-accepting stepper."""
    tag_machine = XMLTagStateMachine("div")
    assert isinstance(tag_machine, ChainStateMachine)

    steppers = tag_machine.get_steppers()
    assert len(steppers) == 1
    first = steppers[0]
    assert not first.has_reached_accept_state()
14 |
15 |
def test_closing_tag_initialization() -> None:
    """A closing-tag machine starts with one stepper that has not accepted."""
    steppers = XMLTagStateMachine("div", closing_tag=True).get_steppers()
    assert len(steppers) == 1
    first = steppers[0]
    assert not first.has_reached_accept_state()
22 |
23 |
def test_basic_tag_recognition() -> None:
    """Test recognition of a basic XML tag."""
    tag_machine = XMLTagStateMachine("div")
    steppers = tag_machine.get_steppers()
    # The complete opening tag for "div".
    steppers = tag_machine.advance_all_basic(steppers, "<div>")
    assert len(steppers) == 1
    assert steppers[0].has_reached_accept_state()
31 |
32 |
def test_closing_tag_recognition() -> None:
    """Test recognition of a closing XML tag."""
    tag_machine = XMLTagStateMachine("div", closing_tag=True)
    steppers = tag_machine.get_steppers()
    # The complete closing tag for "div".
    steppers = tag_machine.advance_all_basic(steppers, "</div>")
    assert len(steppers) == 1
    assert steppers[0].has_reached_accept_state()
40 |
41 |
42 | def test_partial_tag_recognition() -> None:
43 | """Test partial tag recognition behavior."""
44 | tag_machine = XMLTagStateMachine("div")
45 | steppers = tag_machine.get_steppers()
46 | steppers = tag_machine.advance_all_basic(steppers, "
", True),
55 | ("span", "", False),
56 | ("p", "
", False), # Case sensitivity test
57 | ("input", "", True),
58 | ("br", "
", False), # Doesn't handle self-closing tags
59 | ],
60 | )
61 | def test_various_tag_scenarios(
62 | tag_name: str, input_text: str, should_accept: bool
63 | ) -> None:
64 | """Test various tag scenarios including invalid ones."""
65 | machine = XMLTagStateMachine(tag_name)
66 | steppers = machine.get_steppers()
67 | steppers = machine.advance_all_basic(steppers, input_text)
68 |
69 | if should_accept:
70 | assert len(steppers) == 1
71 | assert steppers[0].has_reached_accept_state()
72 | else:
73 | assert not any(s.has_reached_accept_state() for s in steppers)
74 |
75 |
def test_empty_tag_name() -> None:
    """Constructing a tag machine with an empty name raises ValueError."""
    empty_name = ""
    with pytest.raises(ValueError):
        XMLTagStateMachine(empty_name)
80 |
81 |
def test_whitespace_handling() -> None:
    """A space between '<' and the tag name prevents acceptance."""
    tag_machine = XMLTagStateMachine("div")
    steppers = tag_machine.advance_all_basic(tag_machine.get_steppers(), "< div>")
    assert all(not s.has_reached_accept_state() for s in steppers)
88 |
89 |
def test_get_valid_continuations() -> None:
    """Test the get_valid_continuations method."""
    tag_machine = XMLTagStateMachine("div")
    stepper = tag_machine.get_new_stepper()
    # It's a chain, so it should check the ultimate result.
    assert stepper.get_valid_continuations() == ["<div>"]

    tag_machine = XMLTagStateMachine("span", closing_tag=True)
    stepper = tag_machine.get_new_stepper()
    assert stepper.get_valid_continuations() == ["</span>"]

    # Check that it works even if we advance the stepper.
    tag_machine = XMLTagStateMachine("div")
    steppers = tag_machine.get_steppers()
    steppers = tag_machine.advance_all_basic(steppers, "<")
    assert len(steppers) == 1
    assert steppers[0].get_valid_continuations() == ["<div>"]
    # And again.
    steppers = tag_machine.advance_all_basic(steppers, "d")
    assert len(steppers) == 1
    assert steppers[0].get_valid_continuations() == ["<div>"]

    # Feed the rest of the tag; the full tag is still reported afterwards.
    # NOTE(review): an earlier comment claimed nothing is returned at this
    # point, which contradicts asserting a non-empty list; confirm the
    # intended expectation.
    steppers = tag_machine.advance_all_basic(steppers, "i")
    steppers = tag_machine.advance_all_basic(steppers, "v")
    steppers = tag_machine.advance_all_basic(steppers, ">")
    assert len(steppers) == 1
    assert steppers[0].get_valid_continuations() == ["<div>"]
118 |
--------------------------------------------------------------------------------