├── .git-blame-ignore-revs
├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   └── workflows
│       ├── call_cpu_tests.yml
│       ├── call_gpu_tests.yml
│       ├── ci_credentials.yml
│       ├── ci_linux.yml
│       ├── ci_macos.yml
│       ├── ci_windows.yml
│       ├── notebook_tests.yml
│       ├── pull_request.yml
│       └── pypi_upload.yml
├── .gitignore
├── CONTRIBUTING.md
├── GOVERNANCE.md
├── LICENSE.md
├── MAINTAINERS.md
├── MANIFEST.in
├── README.md
├── client
│   └── graphpaper-inline
│       ├── .gitignore
│       ├── TODO.txt
│       ├── build-to-guidance.sh
│       ├── dist
│       │   └── .gitignore
│       ├── package.json
│       ├── pnpm-lock.yaml
│       ├── postcss.config.js
│       ├── rollup.config.mjs
│       ├── src
│       │   ├── App.svelte
│       │   ├── CustomAudio.svelte
│       │   ├── CustomVideo.svelte
│       │   ├── MetricRecord.svelte
│       │   ├── ResizeListener.svelte
│       │   ├── Select.svelte
│       │   ├── Sparkline.svelte
│       │   ├── StitchHandler.svelte
│       │   ├── TokenGrid.svelte
│       │   ├── TokenGridItem.svelte
│       │   ├── clickoutside.ts
│       │   ├── interfaces.ts
│       │   ├── longhover.ts
│       │   ├── main.css
│       │   ├── main.js
│       │   ├── metrics.ts
│       │   ├── mocks.ts
│       │   ├── stitch.ts
│       │   └── template.html
│       ├── tailwind.config.js
│       └── tsconfig.json
├── docs
│   ├── .readthedocs.yaml
│   ├── Makefile
│   ├── _static
│   │   └── css
│   │       └── styles.css
│   ├── api.rst
│   ├── api_examples.rst
│   ├── art_of_prompt_design.rst
│   ├── conf.py
│   ├── figures
│   │   ├── anachronism.png
│   │   ├── await1.png
│   │   ├── await2.png
│   │   ├── capture_example.png
│   │   ├── chat1.png
│   │   ├── chat_animation.gif
│   │   ├── chat_reading.png
│   │   ├── demo_output.png
│   │   ├── favicon.ico
│   │   ├── favicon.png
│   │   ├── function.png
│   │   ├── gen_loop_demo.png
│   │   ├── generate_select.png
│   │   ├── generation1.png
│   │   ├── get_started_button.png
│   │   ├── guidance_logo.svg
│   │   ├── guidance_logo_blue.svg
│   │   ├── guidance_logo_blue_dark.svg
│   │   ├── guidance_logo_light_blue.svg
│   │   ├── guidance_logo_white_dark.svg
│   │   ├── hidden1.png
│   │   ├── json_animation.gif
│   │   ├── json_syntax_variables.png
│   │   ├── perfect_syntax.png
│   │   ├── proverb_animation.gif
│   │   ├── proverb_output.png
│   │   ├── select.png
│   │   ├── simple_fstring_llama2_7b.png
│   │   ├── simple_gen_llama2_7b.png
│   │   ├── simple_select_llama2_7b.png
│   │   ├── simple_streaming_example.gif
│   │   ├── template_objs.png
│   │   ├── url_with_space.png
│   │   ├── url_without_space.png
│   │   └── watch_demo_button.png
│   ├── index.rst
│   ├── make.bat
│   └── tutorials.rst
├── guidance
│   ├── __init__.py
│   ├── _ast.py
│   ├── _bg
│   │   └── __init__.py
│   ├── _grammar.py
│   ├── _guidance.py
│   ├── _guidance.pyi
│   ├── _parser.py
│   ├── _schema.py
│   ├── _utils.py
│   ├── bench
│   │   ├── __init__.py
│   │   ├── _api.py
│   │   ├── _powerlift.py
│   │   └── _utils.py
│   ├── chat.py
│   ├── library
│   │   ├── __init__.py
│   │   ├── _audio.py
│   │   ├── _block.py
│   │   ├── _capture.py
│   │   ├── _ebnf.py
│   │   ├── _gen.py
│   │   ├── _image.py
│   │   ├── _json.py
│   │   ├── _optional.py
│   │   ├── _pydantic.py
│   │   ├── _role.py
│   │   ├── _sequences.py
│   │   ├── _subgrammar.py
│   │   ├── _substring.py
│   │   ├── _tool.py
│   │   └── _video.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── _metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── _azureai.py
│   │   ├── _base
│   │   │   ├── __init__.py
│   │   │   ├── _interpreter.py
│   │   │   ├── _model.py
│   │   │   └── _state.py
│   │   ├── _byte_tokenizer.py
│   │   ├── _engine
│   │   │   ├── __init__.py
│   │   │   ├── _engine.py
│   │   │   ├── _interpreter.py
│   │   │   ├── _state.py
│   │   │   └── _tokenizer.py
│   │   ├── _llama_cpp.py
│   │   ├── _mock.py
│   │   ├── _openai.py
│   │   ├── _openai_base.py
│   │   ├── _transformers.py
│   │   ├── broken_models
│   │   │   ├── README.MD
│   │   │   ├── _Gemini.py
│   │   │   ├── _anthropic.py
│   │   │   ├── _azure_openai.py
│   │   │   ├── _azureai_studio.py
│   │   │   ├── _cohere.py
│   │   │   ├── _googleai.py
│   │   │   ├── _lite_llm.py
│   │   │   ├── _togetherai.py
│   │   │   └── _vertexai.py
│   │   └── experimental
│   │       ├── __init__.py
│   │       └── _vllm.py
│   ├── py.typed
│   ├── registry
│   │   ├── __init__.py
│   │   └── _registry.py
│   ├── resources
│   │   ├── graphpaper-inline.html
│   │   ├── main.js
│   │   ├── sample_audio.wav
│   │   ├── sample_image.png
│   │   └── sample_video.mp4
│   ├── selectors.py
│   ├── trace
│   │   ├── __init__.py
│   │   └── _trace.py
│   └── visual
│       ├── __init__.py
│       ├── _environment.py
│       ├── _exchange.py
│       ├── _jupyter.py
│       ├── _message.py
│       ├── _renderer.py
│       └── _trace.py
├── notebooks
│   ├── anachronism.ipynb
│   ├── api_examples
│   │   ├── library
│   │   │   └── gen.ipynb
│   │   └── models
│   │       ├── AzureOpenAI.ipynb
│   │       ├── OpenAI.ipynb
│   │       └── TogetherAI.ipynb
│   ├── art_of_prompt_design
│   │   ├── prompt_boundaries_and_token_healing.ipynb
│   │   ├── rag.ipynb
│   │   ├── react.ipynb
│   │   ├── tool_use.ipynb
│   │   └── use_clear_syntax.ipynb
│   ├── benchmarks
│   │   └── json_output_bench.ipynb
│   ├── chatgpt_vs_open_source_on_harder_tasks.ipynb
│   ├── guaranteeing_valid_syntax.ipynb
│   ├── proverb.ipynb
│   ├── testing_lms.ipynb
│   ├── tutorials
│   │   ├── adding_new_models.ipynb
│   │   ├── chat.ipynb
│   │   ├── code_generation.ipynb
│   │   ├── guidance_acceleration.ipynb
│   │   ├── intro_to_guidance.ipynb
│   │   ├── regex_constraints.ipynb
│   │   └── token_healing.ipynb
│   └── unstable
│       ├── .gitignore
│       └── State Debugging.ipynb
├── packages
│   └── python
│       └── stitch
│           ├── .coveragerc
│           ├── .eslintignore
│           ├── .eslintrc.js
│           ├── .github
│           │   └── workflows
│           │       └── build.yml
│           ├── .gitignore
│           ├── .npmignore
│           ├── .prettierignore
│           ├── .prettierrc
│           ├── .yarnrc.yml
│           ├── LICENSE.txt
│           ├── MANIFEST.in
│           ├── README.md
│           ├── babel.config.js
│           ├── codecov.yml
│           ├── css
│           │   └── widget.css
│           ├── docs
│           │   ├── Makefile
│           │   ├── environment.yml
│           │   ├── make.bat
│           │   └── source
│           │       ├── _static
│           │       │   └── helper.js
│           │       ├── conf.py
│           │       ├── develop-install.rst
│           │       ├── examples
│           │       │   ├── index.rst
│           │       │   └── introduction.nblink
│           │       ├── index.rst
│           │       ├── installing.rst
│           │       └── introduction.rst
│           ├── examples
│           │   └── introduction.ipynb
│           ├── install.json
│           ├── jest.config.js
│           ├── package.json
│           ├── pyproject.toml
│           ├── pytest.ini
│           ├── readthedocs.yml
│           ├── setup.py
│           ├── src
│           │   ├── __tests__
│           │   │   ├── index.spec.ts
│           │   │   └── utils.ts
│           │   ├── extension.ts
│           │   ├── index.ts
│           │   ├── plugin.ts
│           │   ├── version.ts
│           │   └── widget.ts
│           ├── stitch.json
│           ├── stitch
│           │   ├── __init__.py
│           │   ├── _frontend.py
│           │   ├── _version.py
│           │   ├── nbextension
│           │   │   └── extension.js
│           │   ├── stitch.py
│           │   └── tests
│           │       ├── __init__.py
│           │       ├── conftest.py
│           │       ├── test_example.py
│           │       └── test_nbextension_path.py
│           ├── tsconfig.eslint.json
│           ├── tsconfig.json
│           ├── webpack.config.js
│           └── yarn.lock
├── pyproject.toml
├── setup.py
└── tests
    ├── ReadMe.md
    ├── __init__.py
    ├── bench
    │   ├── __init__.py
    │   ├── test_api.py
    │   ├── test_powerlift.py
    │   └── test_utils.py
    ├── conftest.py
    ├── model_integration
    │   ├── __init__.py
    │   ├── library
    │   │   ├── test_gen.py
    │   │   ├── test_subgrammar.py
    │   │   └── test_substring.py
    │   ├── test_grammar.py
    │   ├── test_model.py
    │   └── test_tokenizers.py
    ├── model_specific
    │   ├── __init__.py
    │   ├── common_chat_testing.py
    │   ├── llama_cpp_tests
    │   │   ├── __init__.py
    │   │   ├── test_chat_templates.py
    │   │   └── test_llama_cpp.py
    │   ├── test_transformers.py
    │   └── test_visual.py
    ├── need_credentials
    │   ├── __init__.py
    │   ├── test_anthropic.py
    │   ├── test_azureai_openai.py
    │   ├── test_azureai_studio.py
    │   ├── test_chat_templates.py
    │   ├── test_cohere.py
    │   ├── test_googleai.py
    │   ├── test_lite_llm.py
    │   ├── test_openai.py
    │   ├── test_togetherai.py
    │   ├── test_tokenizers.py
    │   └── test_vertexai.py
    ├── notebooks
    │   ├── __init__.py
    │   └── test_notebooks.py
    ├── tokenizer_common.py
    ├── unit
    │   ├── __init__.py
    │   ├── library
    │   │   ├── __init__.py
    │   │   ├── json
    │   │   │   ├── __init__.py
    │   │   │   ├── test_allOf.py
    │   │   │   ├── test_json.py
    │   │   │   ├── test_refs.py
    │   │   │   ├── test_string_format.py
    │   │   │   └── utils.py
    │   │   ├── test_block.py
    │   │   ├── test_capture.py
    │   │   ├── test_gen.py
    │   │   ├── test_image.py
    │   │   ├── test_one_or_more.py
    │   │   ├── test_pydantic.py
    │   │   ├── test_regex.py
    │   │   ├── test_sequences.py
    │   │   ├── test_subgrammar.py
    │   │   └── test_substring.py
    │   ├── test_ast.py
    │   ├── test_decorator.py
    │   ├── test_grammar.py
    │   ├── test_ll.py
    │   ├── test_model.py
    │   ├── test_parser.py
    │   ├── test_trace.py
    │   └── test_visual.py
    └── utils.py
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | # .git-blame-ignore-revs
2 | # Ran black on major files to standardize codebase
3 | 57da386795bc94a34275b333da586f171f96d7c8
4 | # Ran black on tests and other ancillary python files in code
5 | 083fb9877b507ed27136441c683ce051edf37e81
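6 |
7 | # Note (usage hint, not part of the revision list): to have a local `git blame`
8 | # honor this file, one option is:
9 | #   git config blame.ignoreRevsFile .git-blame-ignore-revs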
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **The bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Give a full working code snippet that can be pasted into a notebook cell or python file. Make sure to include the LLM load step so we know which model you are using.
15 | ```python
16 | # put your code snippet here
17 | ```
18 |
19 | **System info (please complete the following information):**
20 | - OS (e.g. Ubuntu, Windows 11, Mac OS, etc.):
21 | - Guidance Version (`guidance.__version__`):
22 |
--------------------------------------------------------------------------------
/.github/workflows/call_cpu_tests.yml:
--------------------------------------------------------------------------------
1 | name: call_cpu_tests
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       os:
7 |         required: true
8 |         type: string
9 |       python-version:
10 |         required: true
11 |         type: string
12 |       model:
13 |         required: true
14 |         type: string
15 |     secrets:
16 |       HF_TOKEN:
17 |         required: false
18 |   workflow_dispatch:
19 |     inputs:
20 |       os:
21 |         required: false
22 |         type: string
23 |         default: "Large_Linux" # can instead use "Large_Windows" or the default OSes like "macos-latest"
24 |       python-version:
25 |         required: false
26 |         type: string
27 |         default: "3.12"
28 |       model:
29 |         required: false
30 |         type: string
31 |         default: "transformers_gpt2_cpu" # also try "llamacpp_llama2_7b_cpu", etc
32 |       commit_id:
33 |         description: 'Branch or Commit ID (optional)'
34 |         required: false
35 |         type: string
36 |
37 | jobs:
38 |   cpu_tests:
39 |     runs-on: ${{ inputs.os }}
40 |     steps:
41 |       - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
42 |         uses: actions/checkout@v4
43 |         with:
44 |           ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
45 |       - name: Set up Python ${{ inputs.python-version }}
46 |         uses: actions/setup-python@v5
47 |         with:
48 |           python-version: ${{ inputs.python-version }}
49 |       - name: Install guidance in ${{ inputs.os }}
50 |         shell: bash
51 |         run: |
52 |           python -m pip install --upgrade pip
53 |           python -m pip install -e .[llamacpp,transformers,test]
54 |           python -m pip install accelerate # required if using smaller quantizations
55 |       - name: cpu_tests for ${{ inputs.model }}
56 |         shell: bash
57 |         env:
58 |           HF_TOKEN: ${{ secrets.HF_TOKEN }}
59 |         run: |
60 |           pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \
61 |             --selected_model ${{ inputs.model }} \
62 |             ./tests/model_integration ./tests/model_specific
63 |       - name: Upload coverage reports to Codecov
64 |         uses: codecov/codecov-action@v4
65 |         with:
66 |           token: ${{ secrets.CODECOV_TOKEN }}
67 |
--------------------------------------------------------------------------------
/.github/workflows/call_gpu_tests.yml:
--------------------------------------------------------------------------------
1 | name: call_gpu_tests
2 |
3 | on:
4 |   workflow_call:
5 |     inputs:
6 |       os:
7 |         required: true
8 |         type: string
9 |       python-version:
10 |         required: true
11 |         type: string
12 |       model:
13 |         required: true
14 |         type: string
15 |     secrets:
16 |       HF_TOKEN:
17 |         required: false
18 |   workflow_dispatch:
19 |     inputs:
20 |       os:
21 |         required: false
22 |         type: string
23 |         default: "gpu-runner"
24 |       python-version:
25 |         required: false
26 |         type: string
27 |         default: "3.12"
28 |       model:
29 |         required: false
30 |         type: string
31 |         default: "llamacpp_llama2_7b_gpu" # also try "transformers_gpt2_gpu", "transformers_phi2_gpu", etc
32 |       commit_id:
33 |         description: 'Branch or Commit ID (optional)'
34 |         required: false
35 |         type: string
36 |
37 | jobs:
38 |   gpu_tests:
39 |     runs-on: ${{ inputs.os }}
40 |     steps:
41 |       - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
42 |         uses: actions/checkout@v4
43 |         with:
44 |           ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
45 |       - name: Set up Python ${{ inputs.python-version }}
46 |         uses: actions/setup-python@v5
47 |         with:
48 |           python-version: ${{ inputs.python-version }}
49 |       - name: Install NVIDIA SDK
50 |         shell: bash
51 |         run: |
52 |           nvidia-smi
53 |           sudo apt-get --yes update
54 |           sudo apt-get --yes install cuda-toolkit-12.6
55 |           echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH
56 |       - name: Upgrade pip
57 |         shell: bash
58 |         run: |
59 |           python -m pip install --upgrade pip
60 |       - name: Install other packages
61 |         shell: bash
62 |         run: |
63 |           python -m pip install accelerate
64 |       - name: Install guidance in ${{ inputs.os }}
65 |         shell: bash
66 |         run: |
67 |           CMAKE_ARGS="-DGGML_CUDA=on" python -m pip install -e .[llamacpp,transformers,test]
68 |       - name: Check GPU available
69 |         shell: bash
70 |         run: |
71 |           python -c "import torch; assert torch.cuda.is_available()"
72 |       - name: gpu_tests for ${{ inputs.model }}
73 |         shell: bash
74 |         env:
75 |           HF_TOKEN: ${{ secrets.HF_TOKEN }}
76 |         run: |
77 |           pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \
78 |             --selected_model ${{ inputs.model }} \
79 |             ./tests/model_integration ./tests/model_specific
80 |       - name: Upload coverage reports to Codecov
81 |         uses: codecov/codecov-action@v4
82 |         with:
83 |           token: ${{ secrets.CODECOV_TOKEN }}
84 |
--------------------------------------------------------------------------------
/.github/workflows/ci_linux.yml:
--------------------------------------------------------------------------------
1 | # CI Tests which run on Linux machines
2 |
3 | # These access secrets, so should only be run on local branches.
4 |
5 | # Ideally, the CI tests would be a single workflow, but several issues
6 | # (especially varied OS support) mean that it is hard to keep a single
7 | # workflow green.
8 |
9 | name: CI Tests - Linux
10 | permissions:
11 |   contents: read
12 |
13 | on:
14 |   workflow_dispatch:
15 |     inputs:
16 |       commit_id:
17 |         description: 'Branch or Commit ID (optional)'
18 |         required: false
19 |         type: string
20 |   schedule:
21 |     # * is a special character in YAML so we quote this string
22 |     # Run at 09:00 UTC every day
23 |     - cron: '00 09 * * *'
24 |
25 |
26 | jobs:
27 |   cpu_small:
28 |     strategy:
29 |       fail-fast: false # Don't cancel all on first failure
30 |       matrix:
31 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
32 |         model:
33 |           - "transformers_gpt2_cpu"
34 |           - "llamacpp_phi3_mini_4k_instruct_cpu"
35 |           - "llamacpp_gemma2_9b_cpu"
36 |     uses: ./.github/workflows/call_cpu_tests.yml
37 |     with:
38 |       os: Large_Linux
39 |       python-version: ${{ matrix.python-version }}
40 |       model: ${{ matrix.model }}
41 |     secrets:
42 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
43 |
44 |   cpu_big:
45 |     strategy:
46 |       fail-fast: false # Don't cancel all on first failure
47 |       matrix:
48 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
49 |         model:
50 |           - "llamacpp_llama2_7b_cpu"
51 |           - "transformers_llama3_8b_cpu"
52 |           - "transformers_phi4_mini_cpu"
53 |     uses: ./.github/workflows/call_cpu_tests.yml
54 |     with:
55 |       os: Large_Linux
56 |       python-version: ${{ matrix.python-version }}
57 |       model: ${{ matrix.model }}
58 |     secrets:
59 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
60 |
61 |   gpu_tests:
62 |     strategy:
63 |       fail-fast: false # Don't cancel all on first failure
64 |       matrix:
65 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
66 |         model:
67 |           - "transformers_gpt2_gpu"
68 |           - "transformers_gemma2_9b_gpu"
69 |           - "llamacpp_llama2_7b_gpu"
70 |           - "transformers_gemma2_9b_cpu" # CUDA is required for this model
71 |           - "transformers_phi4_mini_gpu"
72 |     uses: ./.github/workflows/call_gpu_tests.yml
73 |     with:
74 |       os: "gpu-runner"
75 |       python-version: ${{ matrix.python-version }}
76 |       model: ${{ matrix.model }}
77 |     secrets:
78 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
79 |
--------------------------------------------------------------------------------
/.github/workflows/ci_macos.yml:
--------------------------------------------------------------------------------
1 | # CI Tests which run on MacOS machines
2 |
3 | # These access secrets, so should only be run on local branches.
4 |
5 | # Ideally, the CI tests would be a single workflow, but several issues
6 | # (especially varied OS support) mean that it is hard to keep a single
7 | # workflow green.
8 |
9 | # MacOS has been particularly troublesome due to the small disk space
10 | # allocations on all the VMs, leading to the --selected_model
11 | # machinery
12 |
13 | name: CI Tests - MacOS
14 | permissions:
15 |   contents: read
16 |
17 | on:
18 |   workflow_dispatch:
19 |     inputs:
20 |       commit_id:
21 |         description: 'Branch or Commit ID (optional)'
22 |         required: false
23 |         type: string
24 |   schedule:
25 |     # * is a special character in YAML so we quote this string
26 |     # Run at 09:10 UTC every day
27 |     - cron: '10 09 * * *'
28 |
29 | jobs:
30 |   cpu_small:
31 |     strategy:
32 |       fail-fast: false # Don't cancel all on first failure
33 |       matrix:
34 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
35 |         model:
36 |           - "transformers_gpt2_cpu"
37 |           - "llamacpp_phi3_mini_4k_instruct_cpu"
38 |           - "llamacpp_gemma2_9b_cpu"
39 |     uses: ./.github/workflows/call_cpu_tests.yml
40 |     with:
41 |       os: "macos-latest"
42 |       python-version: ${{ matrix.python-version }}
43 |       model: ${{ matrix.model }}
44 |     secrets:
45 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/ci_windows.yml:
--------------------------------------------------------------------------------
1 | # CI Tests which run on Windows machines
2 |
3 | # These access secrets, so should only be run on local branches.
4 |
5 | # Ideally, the CI tests would be a single workflow, but several issues
6 | # (especially varied OS support) mean that it is hard to keep a single
7 | # workflow green. If there is one OS likely to lag slightly in support,
8 | # it is Windows
9 |
10 | name: CI Tests - Windows
11 | permissions:
12 |   contents: read
13 |
14 | on:
15 |   workflow_dispatch:
16 |     inputs:
17 |       commit_id:
18 |         description: 'Branch or Commit ID (optional)'
19 |         required: false
20 |         type: string
21 |   schedule:
22 |     # * is a special character in YAML so we quote this string
23 |     # Run at 09:30 UTC every day
24 |     - cron: '30 09 * * *'
25 |
26 | jobs:
27 |   cpu_small:
28 |     strategy:
29 |       fail-fast: false # Don't cancel all on first failure
30 |       matrix:
31 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
32 |         model:
33 |           - "transformers_gpt2_cpu"
34 |           - "llamacpp_phi3_mini_4k_instruct_cpu"
35 |           - "llamacpp_gemma2_9b_cpu"
36 |     uses: ./.github/workflows/call_cpu_tests.yml
37 |     with:
38 |       os: "Large_Windows"
39 |       python-version: ${{ matrix.python-version }}
40 |       model: ${{ matrix.model }}
41 |     secrets:
42 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
43 |
44 |   cpu_big:
45 |     strategy:
46 |       fail-fast: false # Don't cancel all on first failure
47 |       matrix:
48 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
49 |         model:
50 |           - "llamacpp_llama2_7b_cpu"
51 |           - "transformers_llama3_8b_cpu"
52 |           - "transformers_phi4_mini_cpu"
53 |     uses: ./.github/workflows/call_cpu_tests.yml
54 |     with:
55 |       os: "Large_Windows"
56 |       python-version: ${{ matrix.python-version }}
57 |       model: ${{ matrix.model }}
58 |     secrets:
59 |       HF_TOKEN: ${{ secrets.HF_TOKEN }}
60 |
--------------------------------------------------------------------------------
/.github/workflows/notebook_tests.yml:
--------------------------------------------------------------------------------
1 | # These should only be run on main, because they access secrets
2 | # Not part of the regular CI run, since notebook tests seem
3 | # particularly flaky
4 |
5 | name: Notebook Tests
6 |
7 | on:
8 |   workflow_dispatch:
9 |     inputs:
10 |       commit_id:
11 |         description: 'Branch or Commit ID (optional)'
12 |         required: false
13 |         type: string
14 |   schedule:
15 |     # * is a special character in YAML so we quote this string
16 |     # Run at 10:00 UTC every day
17 |     - cron: '00 10 * * *'
18 |
19 | jobs:
20 |   notebook_tests:
21 |     runs-on: "Large_Linux"
22 |     environment: test
23 |     strategy:
24 |       fail-fast: false # Don't cancel all on first failure
25 |       matrix:
26 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
27 |     permissions:
28 |       id-token: write # for Azure CLI login
29 |     steps:
30 |       - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
31 |         uses: actions/checkout@v4
32 |         with:
33 |           ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
34 |       - name: Set up Python ${{ matrix.python-version }}
35 |         uses: actions/setup-python@v5
36 |         with:
37 |           python-version: ${{ matrix.python-version }}
38 |       - name: Install guidance
39 |         shell: bash
40 |         run: |
41 |           python -m pip install --upgrade pip
42 |           python -m pip install -e .[all,llamacpp,test]
43 |       - name: Azure login
44 |         uses: azure/login@v2
45 |         with:
46 |           client-id: ${{ secrets.AZURE_CLIENT_ID }}
47 |           tenant-id: ${{ secrets.AZURE_TENANT_ID }}
48 |           subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }}
49 |       - name: 'Run Azure CLI commands'
50 |         shell: bash
51 |         run: |
52 |           az account show
53 |           az group list
54 |       - name: Notebook tests
55 |         shell: bash
56 |         env:
57 |           HF_TOKEN: ${{ secrets.HF_TOKEN }}
58 |           # Configure endpoints
59 |           AZUREAI_OPENAI_CHAT_ENDPOINT: ${{ vars.AZUREAI_OPENAI_CHAT_ENDPOINT }}
60 |           AZUREAI_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREAI_OPENAI_CHAT_DEPLOYMENT_NAME }}
61 |           AZUREAI_OPENAI_CHAT_MODEL: ${{ vars.AZUREAI_OPENAI_CHAT_MODEL }}
62 |           AZUREAI_OPENAI_CHAT_API_VERSION: ${{ vars.AZUREAI_OPENAI_CHAT_API_VERSION }}
63 |         run: |
64 |           pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \
65 |             ./tests/notebooks
66 |       - name: Upload coverage reports to Codecov
67 |         uses: codecov/codecov-action@v4
68 |         with:
69 |           token: ${{ secrets.CODECOV_TOKEN }}
70 |
--------------------------------------------------------------------------------
/.github/workflows/pull_request.yml:
--------------------------------------------------------------------------------
1 | name: Pull Request
2 |
3 | on:
4 |   pull_request:
5 |   workflow_dispatch:
6 |     inputs:
7 |       commit_id:
8 |         description: 'Branch or Commit ID (optional)'
9 |         required: false
10 |         type: string
11 |   schedule:
12 |     # Run at 10:00 UTC every day
13 |     - cron: "00 10 * * *"
14 |
15 | jobs:
16 |   unit_tests:
17 |     strategy:
18 |       fail-fast: false # Don't cancel all on first failure
19 |       matrix:
20 |         os: [ubuntu-latest, windows-latest, macos-latest]
21 |         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
22 |     runs-on: ${{ matrix.os }}
23 |     steps:
24 |       - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
25 |         uses: actions/checkout@v4
26 |         with:
27 |           ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }}
28 |       - name: Set up Python ${{ matrix.python-version }}
29 |         uses: actions/setup-python@v5
30 |         with:
31 |           python-version: ${{ matrix.python-version }}
32 |       - name: Minimal install
33 |         run: |
34 |           python -m pip install --upgrade pip
35 |           python -m pip install -e .
36 |       - name: Attempt import
37 |         run: |
38 |           python -c "import guidance"
39 |       - name: Bigger install
40 |         run: |
41 |           python -m pip install -e .[unittest]
42 |       - name: Unit Tests
43 |         shell: bash
44 |         run: |
45 |           pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \
46 |             ./tests/unit
47 |       - name: Upload coverage reports to Codecov
48 |         uses: codecov/codecov-action@v4
49 |         with:
50 |           token: ${{ secrets.CODECOV_TOKEN }}
51 |
52 |   cpu_tests:
53 |     strategy:
54 |       fail-fast: false # Don't cancel all on first failure
55 |       matrix:
56 |         os: ["Large_Linux"] # , "Large_Windows"]
57 |         python-version: ["3.9", "3.13"]
58 |         model:
59 |           - "transformers_gpt2_cpu"
60 |           - "llamacpp_phi3_mini_4k_instruct_cpu"
61 |     uses: ./.github/workflows/call_cpu_tests.yml
62 |     with:
63 |       os: ${{ matrix.os }}
64 |       python-version: ${{ matrix.python-version }}
65 |       model: ${{ matrix.model }}
66 |
67 |   # gpu_tests:
68 |   #   strategy:
69 |   #     fail-fast: false # Don't cancel all on first failure
70 |   #     matrix:
71 |   #       os: ["gpu-runner"]
72 |   #       python-version: ["3.9", "3.12"]
73 |   #       model:
74 |   #         - "transformers_gpt2_gpu"
75 |   #         - "llamacpp_llama2_7b_gpu"
76 |   #   uses: ./.github/workflows/call_gpu_tests.yml
77 |   #   with:
78 |   #     os: ${{ matrix.os }}
79 |   #     python-version: ${{ matrix.python-version }}
80 |   #     model: ${{ matrix.model }}
81 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | notebooks/local_scratch
2 | __pycache__/
3 | .vscode
4 | .vs
5 | .idea/
6 | /build
7 | /dist
8 | *.egg-info
9 | *.diskcache
10 | .ipynb_checkpoints
11 | node_modules
12 | .eggs/
13 | .env
14 | .DS_Store
15 | venv/
16 |
17 | # Ignore native library built by setup
18 | guidance/*.so
19 | guidance/_rust/*.so
20 | guidance/_rust/target/
21 | guidance/_rust/Cargo.lock
22 | *.pyd
23 |
24 | notebooks/**/*.papermill_out.ipynb
25 |
26 | .mypy_cache/*
27 |
28 | **/scratch.*
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) The Guidance Contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MAINTAINERS.md:
--------------------------------------------------------------------------------
1 | # Maintainers
2 |
3 | This document lists the Maintainers of the Project. Maintainers may be added once approved by the existing maintainers as described in the [Governance document](./GOVERNANCE.md). By adding your name to this list you are agreeing to abide by the Project governance documents and to abide by all of the Organization's policies, including the [code of conduct](https://github.com/guidance-ai/governance/blob/main/CODE-OF-CONDUCT.md), [trademark policy](https://github.com/guidance-ai/governance/blob/main/TRADEMARKS.md), and [antitrust policy](https://github.com/guidance-ai/governance/blob/main/ANTITRUST.md). If you are participating because of your affiliation with another organization (designated below), you represent that you have the authority to bind that organization to these policies.
4 |
5 | | **NAME** | **Handle** | **Affiliated Organization** |
6 | | --- | --- | --- |
7 | | Scott Lundberg | [slundberg](https://github.com/slundberg) | |
8 | | Harsha Nori | [Harsha-Nori](https://github.com/Harsha-Nori) | Microsoft |
9 | | Marco Tulio Ribeiro | [marcotcr](https://github.com/marcotcr) | Google |
10 |
11 | ---
12 | Part of MVG-0.1-beta.
13 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/).
14 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include resources/graphpaper-inline.html
2 | include resources/sample_audio.wav
3 | include resources/sample_video.mp4
4 | include resources/sample_image.png
5 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules/
2 | build/
3 | .DS_Store
4 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/TODO.txt:
--------------------------------------------------------------------------------
1 | - Remove CDN font links (googlefonts)
2 | - Image integration
3 | - Testing
--------------------------------------------------------------------------------
/client/graphpaper-inline/build-to-guidance.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | npm run build
5 | cp dist/index.html ../../guidance/resources/graphpaper-inline.html
--------------------------------------------------------------------------------
/client/graphpaper-inline/dist/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/client/graphpaper-inline/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "graphpaper",
3 |   "version": "0.0.1",
4 |   "scripts": {
5 |     "build": "rollup -c",
6 |     "dev": "rollup -c -w",
7 |     "start": "sirv dist"
8 |   },
9 |   "devDependencies": {
10 |     "@rollup/plugin-commonjs": "^26.0.1",
11 |     "@rollup/plugin-node-resolve": "^15.2.3",
12 |     "@rollup/plugin-terser": "^0.4.4",
13 |     "@rollup/plugin-typescript": "^11.1.6",
14 |     "@types/d3-scale": "^4.0.8",
15 |     "@types/d3-scale-chromatic": "^3.0.3",
16 |     "@types/dompurify": "^3.0.5",
17 |     "@types/video.js": "^7.3.58",
18 |     "autoprefixer": "^10.4.20",
19 |     "cssnano": "^7.0.5",
20 |     "postcss": "^8.4.41",
21 |     "rollup": "^4.21.0",
22 |     "rollup-plugin-copy": "^3.5.0",
23 |     "rollup-plugin-html-bundle": "^0.0.3",
24 |     "rollup-plugin-livereload": "^2.0.5",
25 |     "rollup-plugin-postcss": "^4.0.2",
26 |     "rollup-plugin-serve": "^1.1.1",
27 |     "rollup-plugin-svelte": "^7.2.2",
28 |     "sirv-cli": "^2.0.2",
29 |     "svelte": "^4.2.18",
30 |     "svelte-preprocess": "^6.0.2",
31 |     "tailwindcss": "^3.4.10",
32 |     "tslib": "^2.6.3",
33 |     "typescript": "^5.5.4"
34 |   },
35 |   "dependencies": {
36 |     "d3-interpolate": "^3.0.1",
37 |     "d3-scale": "^4.0.2",
38 |     "d3-scale-chromatic": "^3.1.0",
39 |     "dompurify": "^3.1.7",
40 |     "tailwind-scrollbar": "^4.0.0",
41 |     "video.js": "^8.21.0"
42 |   }
43 | }
--------------------------------------------------------------------------------
/client/graphpaper-inline/postcss.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 |   plugins: {
3 |     tailwindcss: {},
4 |     autoprefixer: {},
5 |     cssnano: { preset: 'default' }
6 |   }
7 | }
--------------------------------------------------------------------------------
/client/graphpaper-inline/rollup.config.mjs:
--------------------------------------------------------------------------------
1 | import svelte from 'rollup-plugin-svelte';
2 | import { sveltePreprocess } from 'svelte-preprocess';
3 | import resolve from '@rollup/plugin-node-resolve';
4 | import commonjs from '@rollup/plugin-commonjs';
5 | import terser from '@rollup/plugin-terser';
6 | import typescript from '@rollup/plugin-typescript';
7 | import postcss from 'rollup-plugin-postcss';
8 | import livereload from 'rollup-plugin-livereload';
9 | // @ts-ignore
10 | import serve from 'rollup-plugin-serve';
11 | // @ts-ignore
12 | import htmlBundle from 'rollup-plugin-html-bundle';
13 | import copy from 'rollup-plugin-copy';
14 |
15 | const production = !process.env.ROLLUP_WATCH;
16 |
17 | export default [{
18 |   input: 'src/main.js',
19 |   output: {
20 |     format: 'iife',
21 |     name: 'app',
22 |     file: 'build/bundle.js',
23 |     sourcemap: !production,
24 |   },
25 |   plugins: [
26 |     typescript(),
27 |     svelte({
28 |       compilerOptions: {
29 |         dev: !production
30 |       },
31 |       preprocess: sveltePreprocess()
32 |     }),
33 |     resolve({
34 |       browser: true,
35 |       dedupe: importee => importee === 'svelte' || importee.startsWith('svelte/'),
36 |       extensions: ['.svelte', '.mjs', '.ts', '.js', '.json', '.node']
37 |     }),
38 |     commonjs(),
39 |     postcss(),
40 |     copy({
41 |       targets: [
42 |         { src: 'src/template.html', dest: 'build' }
43 |       ]
44 |     }),
45 |     htmlBundle({
46 |       template: 'build/template.html',
47 |       target: production ? 'dist/index.html' : 'build/index.html',
48 |       targetElement: 'body',
49 |       inline: production
50 |     }),
51 |     !production && serve('build'),
52 |     !production && livereload('build'),
53 |     production && terser()
54 |   ],
55 |   watch: {
56 |     clearScreen: false
57 |   }
58 | }];
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/CustomVideo.svelte:
--------------------------------------------------------------------------------
1 |
43 |
44 |
45 |
48 |
49 |
50 |
55 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/MetricRecord.svelte:
--------------------------------------------------------------------------------
1 |
2 |
16 |
17 |
24 |
25 |
26 |
27 | {#if value.constructor === Array}
28 | {metricDef.name}
29 |
30 | {:else}
31 | {metricDef.name}
32 | {#if typeof value === "number"}
33 | {value.toFixed(metricDef.precision)}
34 | {#if metricDef.units !== ''}
35 | {metricDef.units}
36 | {/if}
37 |
38 | {:else}
39 | {value}
40 | {#if metricDef.units !== ''}
41 | {metricDef.units}
42 | {/if}
43 |
44 | {/if}
45 | {/if}
46 |
47 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/ResizeListener.svelte:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/Select.svelte:
--------------------------------------------------------------------------------
1 |
2 |
27 |
28 |
29 |
39 | {#if showList}
40 |
41 | {#each values as value, i}
42 | - selectOption(value)} on:keypress={(_) => {}}>{value}
43 | {/each}
44 |
45 | {/if}
46 |
47 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/Sparkline.svelte:
--------------------------------------------------------------------------------
1 |
2 |
3 |
31 |
32 |
33 |
38 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/StitchHandler.svelte:
--------------------------------------------------------------------------------
1 |
2 |
47 |
48 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/TokenGridItem.svelte:
--------------------------------------------------------------------------------
1 |
2 |
10 |
11 | {#each token.text as ch, i}
12 | {#if ch === ' '}
13 |
14 | {#if i === 0}
15 |
16 | {token.role}
17 |
18 | {/if}
19 |
20 |
21 | {:else if ch === '\t'}
22 |
23 | {#if i === 0}
24 |
25 | {token.role}
26 |
27 | {/if}
28 | \t
29 |
30 | {:else if ch === '\n'}
31 |
32 | {#if i === 0}
33 |
34 | {token.role}
35 |
36 | {/if}
37 | \n
38 |
39 |
40 | {:else}
41 |
42 | {#if i === 0}
43 |
44 | {token.role}
45 |
46 | {/if}
47 | {ch}
48 |
49 | {/if}
50 | {/each}
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/clickoutside.ts:
--------------------------------------------------------------------------------
1 | // Action for clicking outside an element.
2 |
3 | export function clickOutside(node: HTMLElement) {
4 |   const handleClick = (event: MouseEvent) => {
5 |     let target = event.target as HTMLElement;
6 |     if (!node.contains(target)) {
7 |       node.dispatchEvent(new CustomEvent('outclick'));
8 |     }
9 |   };
10 |
11 |   document.addEventListener('click', handleClick, true);
12 |
13 |   return {
14 |     destroy() {
15 |       document.removeEventListener('click', handleClick, true);
16 |     }
17 |   };
18 | }
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/interfaces.ts:
--------------------------------------------------------------------------------
1 | // Interfaces used within the client. This is separate from the messaging interfaces.
2 |
3 | import type {GenToken, RoleOpenerInput} from "./stitch";
4 |
5 | export interface MetricDef {
6 |   name: string,
7 |   units: string,
8 |   description: string,
9 |   isScalar: boolean,
10 |   precision: number,
11 | }
12 |
13 | export type MetricVal = string | number | Array<number>;
14 |
15 | export interface Token {
16 |   text: string,
17 |   prob: number,
18 |   latency_ms: number,
19 |   is_input: boolean,
20 |   is_force_forwarded: boolean,
21 |   is_generated: boolean,
22 |   role: string,
23 |   special: boolean,
24 |   top_k?: Array<GenToken>
25 | }
26 | export declare type TokenCallback = (token: Token) => string;
27 |
28 | export interface MediaNodeContext {
29 |   roleStack: RoleOpenerInput[];
30 |   index: number;
31 | }
32 |
33 | export type MediaType = "audio" | "video" | "image";
34 |
35 | export interface MediaNode {
36 |   type: MediaType;
37 |   value: any;
38 |   format: string;
39 |   context: MediaNodeContext;
40 | }
41 |
42 | export type MultimodalNode =
43 |   | { type: 'token', data: Token }
44 |   | { type: 'media', data: MediaNode };
45 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/longhover.ts:
--------------------------------------------------------------------------------
1 | // Action for long mouse hovers.
2 |
3 | export function longhover(node: HTMLElement, duration: number) {
4 |   let timer: any;
5 |
6 |   const handleMouseOver = (event: MouseEvent) => {
7 |     timer = setTimeout(() => {
8 |       node.dispatchEvent(new CustomEvent('longmouseover', {detail: event}));
9 |     }, duration);
10 |   };
11 |   const handleMouseOut = (event: MouseEvent) => {
12 |     clearTimeout(timer);
13 |     node.dispatchEvent(new CustomEvent('longmouseout', {detail: event}));
14 |   }
15 |
16 |   node.addEventListener('mouseover', handleMouseOver);
17 |   node.addEventListener('mouseout', handleMouseOut);
18 |
19 |   return {
20 |     update(newDuration: number) {
21 |       duration = newDuration
22 |     },
23 |     destroy() {
24 |       node.removeEventListener('mouseover', handleMouseOver);
25 |       node.removeEventListener('mouseout', handleMouseOut);
26 |     }
27 |   };
28 | }
29 |
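30 | // Example usage in a Svelte component (hypothetical markup, for illustration only):
31 | //   <span use:longhover={500} on:longmouseover={show} on:longmouseout={hide}>...</span>
32 | // 'longmouseover' fires once the pointer has hovered for `duration` ms;
33 | // 'longmouseout' fires as soon as the pointer leaves.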
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/main.css:
--------------------------------------------------------------------------------
1 | /* Custom CSS for web app. */
2 | @tailwind base;
3 | @tailwind components;
4 | @tailwind utilities;
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/main.js:
--------------------------------------------------------------------------------
1 | // Entrypoint for web app.
2 |
3 | import App from './App.svelte';
4 |
5 | const app = new App({
6 |   target: document.body,
7 | });
8 |
9 | export default app;
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/metrics.ts:
--------------------------------------------------------------------------------
1 | // Metrics and their definitions.
2 |
3 | import type { MetricDef } from './interfaces';
4 |
5 | export const metricDefs: Record<string, MetricDef> = {
6 |   'status': {
7 |     name: '',
8 |     units: '',
9 |     description: 'Indicates whether the engine is running, completed, or in error.',
10 |     isScalar: true,
11 |     precision: 0
12 |   },
13 |   'cpu': {
14 |     name: 'CPU',
15 |     units: '%',
16 |     description: 'Average utilization across CPU cores.',
17 |     isScalar: false,
18 |     precision: 1
19 |   },
20 |   'gpu': {
21 |     name: 'GPU',
22 |     units: '%',
23 |     description: 'Average utilization across GPUs.',
24 |     isScalar: false,
25 |     precision: 1
26 |   },
27 |   'ram': {
28 |     name: 'RAM',
29 |     units: 'GB',
30 |     description: 'Utilization of RAM.',
31 |     isScalar: true,
32 |     precision: 1
33 |   },
34 |   'vram': {
35 |     name: 'VRAM',
36 |     units: 'GB',
37 |     description: 'Utilization of video RAM.',
38 |     isScalar: true,
39 |     precision: 1
40 |   },
41 |   'wall time': {
42 |     name: 'Time',
43 |     units: 's',
44 |     description: 'Time taken from initial display to engine completion.',
45 |     isScalar: true,
46 |     precision: 1
47 |   },
48 |   'avg latency': {
49 |     name: 'Latency',
50 |     units: 'ms',
51 |     description: 'Average roundtrip latency per token.',
52 |     isScalar: true,
53 |     precision: 0
54 |   },
55 |   'consumed': {
56 |     name: 'Used',
57 |     units: 'tkn',
58 |     description: 'Total tokens consumed by language model.',
59 |     isScalar: true,
60 |     precision: 0
61 |   },
62 |   'token reduction': {
63 |     name: 'Reduced',
64 |     units: '%',
65 |     description: 'Total tokens consumed by language model divided by total tokens.',
66 |     isScalar: true,
67 |     precision: 0
68 |   }
69 | };
70 |
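71 | // Illustrative sketch (not part of the original file): a metric value can be
72 | // formatted by pairing it with its definition, e.g.
73 | //   const def = metricDefs['avg latency'];
74 | //   const label = `${(42.3).toFixed(def.precision)}${def.units}`; // "42ms"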
--------------------------------------------------------------------------------
/client/graphpaper-inline/src/template.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/client/graphpaper-inline/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | module.exports = {
3 |   content: ["./src/**/*.{html,ts,js,svelte}"],
4 |   theme: {
5 |     extend: {
6 |       fontFamily: {
7 |         'token': ['JetBrains Mono'],
8 |       },
9 |       keyframes: {
10 |         'cpulse': {
11 |           '50%': { opacity: 0.0 }
12 |         }
13 |       },
14 |       animation: {
15 |         'cpulse': 'cpulse 3.5s cubic-bezier(0.4, 0, 0.6, 1) infinite'
16 |       }
17 |     }
18 |   },
19 |   plugins: [
20 |     require('tailwind-scrollbar'),
21 |   ],
22 | }
23 |
24 |
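25 | // The 'cpulse' keyframes/animation pair above is surfaced by Tailwind as the
26 | // `animate-cpulse` utility class, e.g. <div class="animate-cpulse">.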
--------------------------------------------------------------------------------
/client/graphpaper-inline/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   "compilerOptions": {
3 |     "allowJs": true,
4 |     "checkJs": true,
5 |     "esModuleInterop": true,
6 |     "forceConsistentCasingInFileNames": true,
7 |     "resolveJsonModule": true,
8 |     "skipLibCheck": true,
9 |     "sourceMap": true,
10 |     "strict": true,
11 |     "verbatimModuleSyntax": true,
12 |     "module": "ESNext",
13 |     "moduleResolution": "bundler"
14 |   }
15 | }
--------------------------------------------------------------------------------
/docs/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # version: 2
2 |
3 | # Read the Docs configuration file for Sphinx projects
4 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
5 |
6 | # Required
7 | version: 2
8 |
9 | # Set the OS, Python version and other tools you might need
10 | build:
11 |   os: ubuntu-22.04
12 |   tools:
13 |     python: "3.10"
14 |     # You can also specify other tool versions:
15 |     # nodejs: "20"
16 |     # rust: "1.70"
17 |     # golang: "1.20"
18 |
19 | # Build documentation in the "docs/" directory with Sphinx
20 | sphinx:
21 |   configuration: docs/conf.py
22 |   # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs
23 |   # builder: "dirhtml"
24 |   # Fail on all warnings to avoid broken references
25 |   # fail_on_warning: true
26 |
27 | # Optionally build your docs in additional formats such as PDF and ePub
28 | # formats:
29 | #   - pdf
30 | #   - epub
31 |
32 | # Optional but recommended, declare the Python requirements required
33 | # to build your documentation
34 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
35 | # python:
36 | #   install:
37 | #     - requirements: docs/requirements.txt
38 | # version: 3.8
39 | python:
40 |   install:
41 |     - method: pip
42 |       path: .
43 |       extra_requirements:
44 |         - docs
45 |
--------------------------------------------------------------------------------
/docs/_static/css/styles.css:
--------------------------------------------------------------------------------
1 | .wy-side-nav-search > a img.logo, .wy-side-nav-search .wy-dropdown > a img.logo {
2 |     width: 250px;
3 |     margin-top: 20px;
4 |     margin-bottom: 15px;
5 | }
6 |
7 | .wy-side-nav-search>div.version {
8 |     color: black;
9 | }
10 | @media screen and (min-width: 767px) {
11 |     .wy-table-responsive table td {
12 |         white-space: normal;
13 |     }
14 |     .wy-table-responsive {
15 |         overflow: visible;
16 |     }
17 | }
18 |
19 | /* .wy-side-nav-search .wy-dropdown>a img.logo,.wy-side-nav-search>a img.logo {
20 |     max-width: 40%;
21 | } */
22 |
23 | .wy-side-nav-search>div.version {
24 |     color: #d9d9d9;
25 | }
26 |
27 | .wy-nav-top {
28 |     background: #343131;
29 | }
30 |
31 | .highlight {
32 |     background: #f7f7f7;
33 | }
34 |
35 | .wy-side-nav-search input[type=text] {
36 |     border-color: #666666;
37 | }
38 |
39 | a {
40 |     color: #008bfb;
41 | }
42 |
43 | a:hover {
44 |     color: #008bfb;
45 | }
46 |
47 | a:visited {
48 |     color: #008bfb;
49 | }
50 |
51 | html.writer-html4 .rst-content dl:not(.docutils)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt {
52 |     background: #008bfb11;
53 |     color: #0086f6;
54 |     border-top: 3px solid #008bfbaa;
55 | }
56 |
57 | .rst-versions .rst-current-version {
58 |     color: #fcfcfc;
59 | }
60 |
61 | .wy-menu-vertical a {
62 |     color: #d9d9d9;
63 | }
64 |
65 | section h2 {
66 |     margin-top: 30px;
67 | }
68 |
69 | .rst-content code.literal, .rst-content tt.literal {
70 |     color: #008bfb;
71 | }
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: guidance
2 |
3 | API Reference
4 | =============
5 | This page contains the API reference for public objects and functions in Guidance.
6 |
7 |
8 | .. _functions_api:
9 |
10 | functions
11 | ---------
12 | .. autosummary::
13 |     :toctree: generated/
14 |
15 |     guidance.gen
16 |     guidance.select
17 |     guidance.json
18 |
19 |
20 | .. _contexts_api:
21 |
22 | context blocks
23 | --------------
24 | .. autosummary::
25 |     :toctree: generated/
26 |
27 |     guidance.instruction
28 |     guidance.system
29 |     guidance.user
30 |     guidance.assistant
31 |
32 |
33 | .. _models_api:
34 |
35 | models
36 | ------
37 | .. autosummary::
38 |     :toctree: generated/
39 |
40 |     guidance.models.Model
41 |     guidance.models.LlamaCpp
42 |     guidance.models.Transformers
43 |     guidance.models.Anthropic
44 |     guidance.models.AzureOpenAI
45 |     guidance.models.Cohere
46 |     guidance.models.GoogleAI
47 |     guidance.models.LiteLLM
48 |     guidance.models.OpenAI
49 |     guidance.models.VertexAI
50 |
--------------------------------------------------------------------------------
/docs/api_examples.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: guidance
2 |
3 | .. _api_examples:
4 |
5 | API Examples
6 | ------------
7 |
8 | These examples parallel the namespace structure of Guidance. Each object or function in Guidance has a
9 | corresponding example notebook here that demonstrates its API usage. The source notebooks
10 | are `available on GitHub `_.
11 |
12 |
13 | .. _functions_examples:
14 |
15 | functions
16 | =========
17 | .. Examples for built-in guidance functions.
18 |
19 | .. toctree::
20 |     :glob:
21 |     :maxdepth: 1
22 |
23 |     example_notebooks/api_examples/library/*
24 |
25 |
26 | .. _models_examples:
27 |
28 | models
29 | ======
30 | .. Examples for members of :ref:`guidance.models <models_api>`.
31 |
32 | .. toctree::
33 |     :glob:
34 |     :maxdepth: 1
35 |
36 |     example_notebooks/api_examples/models/*
37 |
--------------------------------------------------------------------------------
/docs/art_of_prompt_design.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: guidance
2 |
3 | .. _art_of_prompt_design:
4 |
5 | The Art of Prompt Design
6 | ------------------------
7 |
8 | These notebooks demonstrate how to design effective prompts and guidance programs; they also cover common useful
9 | design patterns. The source notebooks are `available on GitHub `_.
10 |
11 |
12 | .. toctree::
13 |     :glob:
14 |     :maxdepth: 1
15 |
16 |     example_notebooks/art_of_prompt_design/use_clear_syntax.ipynb
17 |     example_notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb
18 |     example_notebooks/art_of_prompt_design/tool_use.ipynb
19 |     example_notebooks/art_of_prompt_design/react.ipynb
20 |     example_notebooks/art_of_prompt_design/rag.ipynb
--------------------------------------------------------------------------------
/docs/figures/anachronism.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/anachronism.png
--------------------------------------------------------------------------------
/docs/figures/await1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/await1.png
--------------------------------------------------------------------------------
/docs/figures/await2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/await2.png
--------------------------------------------------------------------------------
/docs/figures/capture_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/capture_example.png
--------------------------------------------------------------------------------
/docs/figures/chat1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/chat1.png
--------------------------------------------------------------------------------
/docs/figures/chat_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/chat_animation.gif
--------------------------------------------------------------------------------
/docs/figures/chat_reading.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/chat_reading.png
--------------------------------------------------------------------------------
/docs/figures/demo_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/demo_output.png
--------------------------------------------------------------------------------
/docs/figures/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/favicon.ico
--------------------------------------------------------------------------------
/docs/figures/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/favicon.png
--------------------------------------------------------------------------------
/docs/figures/function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/function.png
--------------------------------------------------------------------------------
/docs/figures/gen_loop_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/gen_loop_demo.png
--------------------------------------------------------------------------------
/docs/figures/generate_select.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/generate_select.png
--------------------------------------------------------------------------------
/docs/figures/generation1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/generation1.png
--------------------------------------------------------------------------------
/docs/figures/get_started_button.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/get_started_button.png
--------------------------------------------------------------------------------
/docs/figures/guidance_logo_blue.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
41 |
--------------------------------------------------------------------------------
/docs/figures/guidance_logo_blue_dark.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
44 |
--------------------------------------------------------------------------------
/docs/figures/hidden1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/hidden1.png
--------------------------------------------------------------------------------
/docs/figures/json_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/json_animation.gif
--------------------------------------------------------------------------------
/docs/figures/json_syntax_variables.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/json_syntax_variables.png
--------------------------------------------------------------------------------
/docs/figures/perfect_syntax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/perfect_syntax.png
--------------------------------------------------------------------------------
/docs/figures/proverb_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/proverb_animation.gif
--------------------------------------------------------------------------------
/docs/figures/proverb_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/proverb_output.png
--------------------------------------------------------------------------------
/docs/figures/select.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/select.png
--------------------------------------------------------------------------------
/docs/figures/simple_fstring_llama2_7b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_fstring_llama2_7b.png
--------------------------------------------------------------------------------
/docs/figures/simple_gen_llama2_7b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_gen_llama2_7b.png
--------------------------------------------------------------------------------
/docs/figures/simple_select_llama2_7b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_select_llama2_7b.png
--------------------------------------------------------------------------------
/docs/figures/simple_streaming_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_streaming_example.gif
--------------------------------------------------------------------------------
/docs/figures/template_objs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/template_objs.png
--------------------------------------------------------------------------------
/docs/figures/url_with_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/url_with_space.png
--------------------------------------------------------------------------------
/docs/figures/url_without_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/url_without_space.png
--------------------------------------------------------------------------------
/docs/figures/watch_demo_button.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/watch_demo_button.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 |
2 | .. image:: figures/guidance_logo_blue.svg
3 | :width: 300px
4 | :align: center
5 | |
6 |
7 | **Guidance** enables you to control modern language models more effectively and efficiently than traditional prompting or chaining. Guidance programs allow you to interleave generation, prompting, and logical control into a single continuous flow matching how the language model actually processes the text.
8 |
9 | Install
10 | =======
11 |
12 | Guidance can be installed from `PyPI <https://pypi.org/project/guidance/>`_::
13 |
14 | pip install guidance
15 |
16 |
17 | Contents
18 | ========
19 |
20 | .. toctree::
21 | :maxdepth: 2
22 |
23 |    Tutorials <tutorials>
24 |    API reference <api>
25 |    API examples <api_examples>
26 |    The Art of Prompt Design <art_of_prompt_design>
27 |
--------------------------------------------------------------------------------
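A minimal sketch of the interleaved generation/control flow described in index.rst above, using the `Mock` model that ships in this repository (the prompt text and capture names are illustrative):

    from guidance import models, gen, select

    lm = models.Mock()  # stand-in engine; swap in a real backend such as models.OpenAI(...)
    lm += "Is a tomato a fruit or a vegetable? "
    lm += select(["fruit", "vegetable"], name="kind")    # constrained choice
    lm += " Reason: " + gen(name="why", max_tokens=15)   # free-form generation
    print(lm["kind"], lm["why"])
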
/docs/tutorials.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: guidance
2 |
3 | .. _tutorials:
4 |
5 | Tutorials
6 | ----------------
7 |
8 | These notebooks demonstrate various features of ``guidance``. The source notebooks
9 | are `available on GitHub <https://github.com/guidance-ai/guidance/tree/main/notebooks>`_.
10 |
11 |
12 | .. toctree::
13 | :glob:
14 | :maxdepth: 1
15 |
16 | example_notebooks/tutorials/intro_to_guidance.ipynb
17 | example_notebooks/tutorials/token_healing.ipynb
18 | example_notebooks/tutorials/regex_constraints.ipynb
19 | example_notebooks/tutorials/guidance_acceleration.ipynb
20 | example_notebooks/tutorials/code_generation.ipynb
21 | example_notebooks/tutorials/chat.ipynb
--------------------------------------------------------------------------------
/guidance/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.1"
2 |
3 | import sys
4 | import types
5 |
6 | from . import models
7 | from ._guidance import guidance
8 |
9 | from ._ast import GrammarNode, Function
10 | from ._utils import strip_multiline_string_indents
11 |
12 |
13 | # This makes the guidance module callable
14 | class _Guidance(types.ModuleType):
15 | def __call__(
16 | self, f=None, *, stateless=False, cache=None, dedent=True, model=models.Model
17 | ):
18 | return guidance(
19 | f, stateless=stateless, cache=cache, dedent=dedent, model=model
20 | )
21 |
22 |
23 | sys.modules[__name__].__class__ = _Guidance
24 |
25 | # we expose all the library functions at the top level of the module
26 | from .library import *
--------------------------------------------------------------------------------
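Because `sys.modules[__name__].__class__` is swapped to `_Guidance` above, the `guidance` module itself is callable, so both decorator spellings below route through the same `guidance(...)` entry point (a sketch; the function bodies are illustrative):

    import guidance

    @guidance                  # bare form: the module is called as guidance(f)
    def stateful(lm, topic):
        return lm + f"Write about {topic}: "

    @guidance(stateless=True)  # parameterized form: guidance(None, stateless=True)(f)
    def yes_or_no(lm):
        return lm + guidance.select(["yes", "no"])
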
/guidance/_bg/__init__.py:
--------------------------------------------------------------------------------
1 | """ Background thread for asyncio handling.
2 |
3 | This is currently being used for messaging, visualization and metrics.
4 | """
5 |
6 | import asyncio
7 | import threading
8 | from asyncio import AbstractEventLoop, Task
9 | from concurrent.futures import Future
10 | from typing import Any, Coroutine, TypeVar
11 |
12 | T = TypeVar('T')
13 |
14 | def _start_asyncio_loop(loop: AbstractEventLoop):
15 | asyncio.set_event_loop(loop)
16 | loop.run_forever()
17 |
18 |
19 | def _asyncio_background_thread() -> tuple[threading.Thread, AbstractEventLoop]:
20 | loop = asyncio.new_event_loop()
21 | thread = threading.Thread(target=_start_asyncio_loop, args=(loop,))
22 | thread.daemon = True
23 | return thread, loop
24 |
25 |
26 | class BackgroundAsync:
27 |     """ Runs a background thread that has an asyncio event loop."""
28 |
29 | def __init__(self):
30 | """ Initializes. """
31 | self._loop = None
32 | self._thread = None
33 |
34 | def _thread_and_loop(self) -> tuple[threading.Thread, AbstractEventLoop]:
35 | if self._loop is None:
36 | self._thread, self._loop = _asyncio_background_thread()
37 | self._thread.start()
38 | return self._thread, self._loop
39 |
40 | def call_soon_threadsafe(self, cb, *args, context = None):
41 | """ Fires callback in background thread."""
42 |
43 | _, loop = self._thread_and_loop()
44 | return loop.call_soon_threadsafe(cb, *args, context=context)
45 |
46 | def run_async_coroutine(self, coroutine: Coroutine[Any, Any, T]) -> Future[T]:
47 |         """ Runs an asynchronous coroutine on the background thread.
48 | 
49 |         Args:
50 |             coroutine: Coroutine to be run on the background thread.
51 |
52 | Returns:
53 | Future of coroutine.
54 | """
55 | _, loop = self._thread_and_loop()
56 | future = asyncio.run_coroutine_threadsafe(coroutine, loop)
57 | return future
58 |
59 | @staticmethod
60 | async def async_task(coroutine: Coroutine[Any, Any, T]) -> Task[T]:
61 | """ Creates an asyncio task from coroutine.
62 |
63 | Args:
64 | coroutine: Coroutine within task.
65 |
66 | Returns:
67 | Asyncio task.
68 | """
69 | task = asyncio.create_task(coroutine)
70 | return task
71 |
72 | @staticmethod
73 | async def print_all_tasks(): # pragma: no cover
74 | """Prints all tasks running in background thread loop (for debugging purposes)."""
75 | for task in asyncio.all_tasks():
76 | print(task)
77 |
--------------------------------------------------------------------------------
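A usage sketch for `BackgroundAsync` above: `run_async_coroutine` returns a `concurrent.futures.Future`, so a caller on the main thread can block on `.result()` while the coroutine runs on the daemon thread's event loop.

    import asyncio
    from guidance._bg import BackgroundAsync

    async def compute() -> int:
        await asyncio.sleep(0.1)   # executes on the background loop
        return 42

    bg = BackgroundAsync()
    future = bg.run_async_coroutine(compute())  # concurrent.futures.Future
    print(future.result())                      # blocks the caller until done -> 42
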
/guidance/_guidance.pyi:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import (
3 | Any,
4 | Callable,
5 | Literal,
6 | TypeVar,
7 | Union,
8 | overload,
9 | )
10 | from contextvars import ContextVar
11 | if sys.version_info >= (3, 10):
12 | from typing import ParamSpec, TypeAlias, Concatenate
13 | else:
14 | from typing_extensions import ParamSpec, TypeAlias, Concatenate
15 |
16 | from ._ast import RuleNode, Function
17 | from .models import Model
18 |
19 | _in_stateless_context: ContextVar[bool]
20 |
21 | P = ParamSpec("P")
22 | M: TypeAlias = Any # sort of Union[Model, GrammarNode]?
23 | R = TypeVar("R", bound = Union[Function, RuleNode])
24 | GuidanceWrappable = Callable[Concatenate[M, P], M]
25 | GuidanceFunction = Callable[P, R]
26 | StatefulGuidanceFunction = GuidanceFunction[P, Function]
27 | StatelessGuidanceFunction = GuidanceFunction[P, RuleNode]
28 |
29 | @overload
30 | def guidance(
31 | f: GuidanceWrappable[P],
32 | *,
33 | stateless: Literal[False] = False,
34 | cache: bool = ...,
35 | dedent: bool = ...,
36 | model: type[Model] = ...,
37 | ) -> StatefulGuidanceFunction[P]:
38 | ...
39 |
40 |
41 | @overload
42 | def guidance(
43 | f: None = None,
44 | *,
45 | stateless: Literal[False] = False,
46 | cache: bool = ...,
47 | dedent: bool = ...,
48 | model: type[Model] = ...,
49 | ) -> Callable[[GuidanceWrappable[P]], StatefulGuidanceFunction[P]]:
50 | ...
51 |
52 |
53 | @overload
54 | def guidance(
55 | f: GuidanceWrappable[P],
56 | *,
57 | stateless: Literal[True],
58 | cache: bool = ...,
59 | dedent: bool = ...,
60 | model: type[Model] = ...,
61 | ) -> StatelessGuidanceFunction[P]:
62 | ...
63 |
64 |
65 | @overload
66 | def guidance(
67 | f: None = None,
68 | *,
69 | stateless: Literal[True],
70 | cache: bool = ...,
71 | dedent: bool = ...,
72 | model: type[Model] = ...,
73 | ) -> Callable[[GuidanceWrappable[P]], StatelessGuidanceFunction[P]]:
74 | ...
75 |
76 |
77 | @overload
78 | def guidance(
79 | f: GuidanceWrappable[P],
80 | *,
81 | stateless: Callable[..., bool],
82 | cache: bool = ...,
83 | dedent: bool = ...,
84 | model: type[Model] = ...,
85 | ) -> GuidanceFunction[P, Union[Function, RuleNode]]:
86 | ...
87 |
88 |
89 | @overload
90 | def guidance(
91 | f: None = None,
92 | *,
93 | stateless: Callable[..., bool],
94 | cache: bool = ...,
95 | dedent: bool = ...,
96 | model: type[Model] = ...,
97 | ) -> Callable[[GuidanceWrappable[P]], GuidanceFunction[P, Union[Function, RuleNode]]]:
98 | ...
99 |
--------------------------------------------------------------------------------
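Roughly how a type checker resolves the overloads above: a literal `stateless` value pins the return type, while a callable `stateless` falls through to the union overloads (illustrative stubs only):

    from guidance import guidance

    @guidance(stateless=True)
    def yes_no(lm): ...             # inferred StatelessGuidanceFunction, returns RuleNode

    @guidance(stateless=False)
    def turn(lm, prompt: str): ...  # inferred StatefulGuidanceFunction, returns Function

    @guidance(stateless=lambda *args, **kwargs: True)
    def either(lm): ...             # GuidanceFunction returning Union[Function, RuleNode]
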
/guidance/bench/__init__.py:
--------------------------------------------------------------------------------
1 | """Elementary benchmarking for `guidance` development purposes.
2 |
3 | `guidance` lives in a fast-paced LLM environment, has complex dependencies, and is tricky to implement.
4 | These benchmarks are designed to focus on key use cases where regressions can create havoc.
5 |
6 | General guidelines:
7 | - Simplicity first, then customization - reproducibility by the community is encouraged
8 | - Everything takes forever - allow a pathway to scale horizontally
9 | - Goalposts shift - some of the code for benchmarking will change frequently and that's okay
10 |
11 | Implementation:
12 |
13 | The `bench` function is provided for no-frills benchmarking and is designed for
14 | automated testing.
15 |
16 | For customization, we provide a notebook demonstration of how to run custom benchmarks
17 | that are near mirror versions of what is available in the `bench` function provided.
18 |
19 | Not implemented yet, but we intend to provide an avenue of running the benchmarks via
20 | docker containers that have GPU resourcing to scale horizontally.
21 | """
22 |
23 | from guidance.bench._powerlift import (
24 | retrieve_langchain,
25 | langchain_chat_extract_runner,
26 | langchain_chat_extract_filter_template,
27 | )
28 | from guidance.bench._api import bench
29 |
30 | # TODO(nopdive): Enable docker containers to execute benchmarking easily
31 |
--------------------------------------------------------------------------------
/guidance/bench/_api.py:
--------------------------------------------------------------------------------
1 | """User facing API for benchmarking."""
2 |
3 | from typing import List, Optional, Tuple, Union
4 | from pathlib import Path
5 | from guidance.bench._utils import lib_bench_dir
6 |
7 | """Available models to run benchmark against."""
8 | AVAILABLE_MODELS = [
9 | "guidance-mistral-7b-instruct",
10 | "base-mistral-7b-instruct",
11 | "guidance-phi-3-mini-4k-instruct",
12 | "base-phi-3-mini-4k-instruct",
13 | "guidance-llama2-7b-32k-instruct",
14 | "base-llama2-7b-32k-instruct",
15 | ]
16 |
17 |
18 | def bench(
19 | db_url: str,
20 | experiment_name: str,
21 | models: List[str] = AVAILABLE_MODELS,
22 | force_recreate: bool = False,
23 | timeout: int = 3600,
24 | cache_dir: Union[str, Path] = lib_bench_dir() / "cache",
25 | debug_mode: bool = False,
26 | ) -> Tuple[object, object]:
27 | """Benchmarks guidance against preset tasks.
28 |
29 | This runs on a single machine, one trial at a time.
30 | To run this the first time you will need API_LANGCHAIN_KEY set as an environment variable.
31 |
32 | Args:
33 | db_url (str): Database connection string.
34 | experiment_name (str): Name of experiment to create / run.
35 | models (List[str], optional): Models to benchmark. Defaults to AVAILABLE_MODELS.
36 | force_recreate (bool, optional): Recreate the database before benchmarking. Defaults to False.
37 | timeout (int, optional): Max execution time per trial. Defaults to 3600.
38 | cache_dir (Union[str, Path], optional): Cache to store external datasets. Defaults to lib_bench_dir() / "cache".
39 | debug_mode (bool): Set this when you require a debugger to step line by line in the trial_runner.
40 |
41 | Returns:
42 | Tuple[object, object]: (status, results) data frames where status relates to trials, results are wide form aggregates of each model.
43 | """
44 | from guidance.bench._powerlift import bench as inner_bench
45 |
46 | status_df, result_df = inner_bench(
47 | db_url, experiment_name, models, force_recreate, timeout, cache_dir, debug_mode
48 | )
49 | return status_df, result_df
50 |
--------------------------------------------------------------------------------
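A minimal invocation sketch for `bench` above (the connection string and experiment name are hypothetical placeholders; per the docstring the return values are data frames):

    from guidance.bench import bench

    status_df, result_df = bench(
        db_url="sqlite:///bench.db",          # hypothetical connection string
        experiment_name="regression-check",   # hypothetical experiment name
        models=["guidance-mistral-7b-instruct", "base-mistral-7b-instruct"],
        timeout=3600,
    )
    print(result_df.head())
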
/guidance/bench/_utils.py:
--------------------------------------------------------------------------------
1 | """Shared utility functions for module."""
2 |
3 | import os
4 | from pathlib import Path
5 |
6 | def lib_bench_dir() -> Path:
7 | """Library directory to store configurations and cached assets for benchmarking.
8 |
9 | If the library directory does not exist, it is created as a side effect.
10 |
11 | The library bench directory path can also be set via env var `GUIDANCE_BENCH_DIR`.
12 |
13 | Returns:
14 | Path: Library's directory path for benchmarking.
15 | """
16 |
17 | env_lib_path = os.environ.get("GUIDANCE_BENCH_DIR", None)
18 | if env_lib_path is None:
19 | lib_path = Path.home() / ".guidance-bench"
20 | else:
21 | lib_path = Path(env_lib_path)
22 |     lib_path.mkdir(parents=True, exist_ok=True)
23 |
24 | return lib_path
25 |
--------------------------------------------------------------------------------
/guidance/library/__init__.py:
--------------------------------------------------------------------------------
1 | # import functions that can be called directly
2 | from ._gen import gen, call_tool, regex
3 | from ._image import image, gen_image
4 | from ._audio import audio, gen_audio
5 | from ._video import video, gen_video
6 | from ._capture import capture
7 |
8 | # core grammar functions
9 | from .._grammar import select
10 | from .._grammar import with_temperature
11 | from .._grammar import string
12 | from .._grammar import token_limit
13 |
14 | # context blocks
15 | from ._block import block
16 | from ._role import role, system, assistant, user #, function, instruction, indent_roles
17 |
18 | # from ..models._model import context_free
19 |
20 | # stateless library functions
21 | from ._sequences import one_or_more, zero_or_more, at_most_n_repeats, exactly_n_repeats, sequence
22 | from ._substring import substring
23 | from ._optional import optional
24 | from ._tool import Tool
25 | from ._json import json
26 | from ._ebnf import lark, gbnf_to_lark
27 |
--------------------------------------------------------------------------------
/guidance/library/_audio.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import typing
3 | import base64
4 |
5 | from .._guidance import guidance
6 | from .._utils import bytes_from
7 | from .._ast import AudioBlob, GenAudio
8 |
9 |
10 | @guidance
11 | def audio(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = True):
12 | bytes_data = bytes_from(src, allow_local=allow_local)
13 | base64_string = base64.b64encode(bytes_data).decode('utf-8')
14 | lm += AudioBlob(data=base64_string)
15 | return lm
16 |
17 |
18 | @guidance
19 | def gen_audio(lm, **kwargs):
20 | return lm + GenAudio(kwargs=kwargs)
21 |
--------------------------------------------------------------------------------
/guidance/library/_block.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from typing import Optional, Union
3 | from .._ast import ASTNode, Function
4 | from ..models._base._model import _active_blocks
5 | from .._guidance import _in_stateless_context
6 |
7 | class Block:
8 | def __init__(self, name: Optional[str], opener: Union[str, Function, ASTNode], closer: Union[str, Function, ASTNode]):
9 | self.name = name
10 | self.opener = opener
11 | self.closer = closer
12 |
13 |
14 | @contextmanager
15 | def block(name=None, opener=None, closer=None):
16 | if _in_stateless_context.get():
17 | raise RuntimeError("Cannot use roles or other blocks when stateless=True")
18 | current_blocks = _active_blocks.get()
19 | new_block = Block(name=name, opener=opener, closer=closer)
20 | token = _active_blocks.set(current_blocks + (new_block,))
21 | try:
22 | yield
23 | finally:
24 | _active_blocks.reset(token)
25 |
--------------------------------------------------------------------------------
/guidance/library/_capture.py:
--------------------------------------------------------------------------------
1 | from .._guidance import guidance
2 | from .._grammar import capture as grammar_capture, GrammarNode
3 | from ._block import block
4 |
5 | @guidance(stateless=lambda *args, **kwargs: isinstance(args[0], GrammarNode))
6 | def capture(lm, value, name):
7 | if isinstance(value, GrammarNode):
8 | return lm + grammar_capture(value, name)
9 | else:
10 | with block(name):
11 | lm += value
12 | return lm
13 |
--------------------------------------------------------------------------------
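The `stateless` lambda above makes `capture` behave two ways: a `GrammarNode` argument stays a pure grammar transform, while anything else (e.g. a stateful guidance Function) is replayed inside a named `block`. A usage sketch:

    from guidance import capture, models, select

    lm = models.Mock()
    lm += capture(select(["cat", "dog"]), name="pet")  # GrammarNode -> grammar_capture path
    print(lm["pet"])
    # A stateful @guidance function would instead take the else-branch and be
    # wrapped in `with block(name): ...` above.
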
/guidance/library/_ebnf.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from llguidance.gbnf_to_lark import gbnf_to_lark as _gbnf_to_lark
4 |
5 | from .._ast import RuleNode, LarkNode, GrammarNode
6 | from .._grammar import capture, token_limit, with_temperature
7 |
8 |
9 | def lark(
10 | lark_grammar: str,
11 | *,
12 | name: Optional[str] = None,
13 | temperature: Optional[float] = None,
14 | max_tokens: Optional[int] = None,
15 | ) -> GrammarNode:
16 | """
17 | Builds a guidance grammar from (a variant of) the EBNF syntax used by the Lark parsing toolkit.
18 |
19 | See documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md for more
20 | details.
21 | """
22 | node = RuleNode(
23 | name=name or "lark",
24 | value=LarkNode(lark_grammar=lark_grammar)
25 | )
26 | if temperature is not None:
27 | node = with_temperature(node, temperature)
28 | if max_tokens is not None:
29 | node = token_limit(node, max_tokens)
30 | if name is not None:
31 | node = capture(node, name)
32 |
33 | return node
34 |
35 |
36 | def gbnf_to_lark(gbnf_grammar: str) -> str:
37 | """
38 | Converts a GBNF (llama.cpp) grammar to Lark(-like) syntax. This is a best-effort
39 | conversion and may not work for all grammars. We recommend using this function
40 | as a starting point and then manually editing the resulting Lark grammar to suit
41 | your needs.
42 |
43 | See documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md
44 | for more information on the output syntax's semantics.
45 | """
46 | return _gbnf_to_lark(gbnf_grammar)
47 |
--------------------------------------------------------------------------------
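A usage sketch for `lark` above (the grammar body follows the llguidance Lark-like syntax linked in the docstring; the rule itself is illustrative):

    from guidance import lark

    answer = lark(
        r'''
        start: "yes" | "no"
        ''',
        name="answer",    # also registers a capture under this name
        temperature=0.0,
        max_tokens=4,
    )
    # lm += answer; lm["answer"] then holds the matched text
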
/guidance/library/_image.py:
--------------------------------------------------------------------------------
1 | import importlib.resources
2 | import pathlib
3 | import typing
4 | import base64
5 | import re
6 |
7 | from .._guidance import guidance
8 | from .._utils import bytes_from
9 | from .._ast import ImageBlob, ImageUrl
10 | from ..trace._trace import ImageOutput
11 |
12 |
13 | @guidance
14 | def image(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = True):
15 | if isinstance(src, str) and re.match(r"^(?!file://)[^:/]+://", src):
16 | lm += ImageUrl(url=src)
17 | else:
18 | bytes_data = bytes_from(src, allow_local=allow_local)
19 | base64_string = base64.b64encode(bytes_data).decode('utf-8')
20 | lm += ImageBlob(data=base64_string)
21 | return lm
22 |
23 |
24 | @guidance
25 | def gen_image(lm):
26 | # TODO(nopdive): Mock for testing. Remove all of this code later.
27 | with importlib.resources.files("guidance").joinpath("resources/sample_image.png").open("rb") as f:
28 | bytes_data = f.read()
29 | base64_string = base64.b64encode(bytes_data).decode('utf-8')
30 | lm += ImageOutput(value=base64_string, is_input=False)
31 | return lm
32 |
--------------------------------------------------------------------------------
/guidance/library/_optional.py:
--------------------------------------------------------------------------------
1 | from .._guidance import guidance
2 | from .._grammar import repeat
3 |
4 | @guidance(stateless=True)
5 | def optional(lm, value):
6 | return lm + repeat(value, 0, 1)
7 |
--------------------------------------------------------------------------------
/guidance/library/_pydantic.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from typing import Any, Dict, Type, Union
3 |
4 | import pydantic
5 |
6 |
7 | class GenerateJsonSchemaSafe(pydantic.json_schema.GenerateJsonSchema):
8 | """
9 | Subclass pydantic's GenerateJsonSchema to catch pydantic schemas that will not
10 | translate properly to json schemas used for generation.
11 |
12 | In particular, JSON schemas do not offer a way to specify "key type",
13 | so we need to raise an exception if users attempt to specify non-string
14 | keys through pydantic. Otherwise, they may get unexpected output from
15 | model generation.
16 | """
17 |
18 | def generate_inner(self, schema):
19 | if schema["type"] == "dict":
20 | key_type = schema["keys_schema"]["type"]
21 | if key_type != "str":
22 | raise TypeError(
23 | f"JSON does not support non-string keys, got type {key_type}"
24 | )
25 | return super().generate_inner(schema)
26 |
27 |
28 | def pydantic_to_json_schema(
29 | schema: Union[Type["pydantic.BaseModel"], "pydantic.TypeAdapter"]
30 | ) -> Dict[str, Any]:
31 | if inspect.isclass(schema) and issubclass(schema, pydantic.BaseModel):
32 | return schema.model_json_schema(schema_generator=GenerateJsonSchemaSafe)
33 | if isinstance(schema, pydantic.TypeAdapter):
34 | return schema.json_schema(schema_generator=GenerateJsonSchemaSafe)
35 | raise TypeError(f"Cannot generate json schema from type {type(schema)}")
36 |
--------------------------------------------------------------------------------
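A sketch of the guard above: string-keyed models convert cleanly, while a non-string dict key is rejected before it can silently mis-constrain generation.

    import pydantic
    from guidance.library._pydantic import pydantic_to_json_schema

    class Person(pydantic.BaseModel):
        name: str
        friends: dict[str, int]   # str keys: fine

    schema = pydantic_to_json_schema(Person)

    class Bad(pydantic.BaseModel):
        scores: dict[int, str]    # int keys: GenerateJsonSchemaSafe raises

    # pydantic_to_json_schema(Bad) -> TypeError: JSON does not support non-string keys ...
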
/guidance/library/_role.py:
--------------------------------------------------------------------------------
1 | from contextlib import AbstractContextManager
2 | from .._ast import RoleEnd, RoleStart
3 | from ._block import block
4 |
5 |
6 | # TODO HN: Add a docstring to better describe arbitrary role functions
7 | def role(role: str) -> AbstractContextManager:
8 | return block(
9 | name=None,
10 | opener=RoleStart(role),
11 | closer=RoleEnd(role),
12 | )
13 |
14 |
15 | def system() -> AbstractContextManager:
16 | """Indicate the 'system' prompt
17 |
18 | A convention has grown up around 'chat' APIs that
19 | prompts are split into three parts: system, user
20 | and assistant.
21 | This indicates the start of a 'system' block, which
22 | provides background information to the LLM.
23 |
24 | >>> with system():
25 | >>> lm += "A system prompt"
26 |
27 | """
28 | return role("system")
29 |
30 |
31 | def user() -> AbstractContextManager:
32 | """Indicate the 'user' prompt
33 |
34 | A convention has grown up around 'chat' APIs that
35 | prompts are split into three parts: system, user
36 | and assistant.
37 | This indicates the start of a 'user' block, which
38 | provides input to the LLM from the user.
39 |
40 | >>> with user():
41 | >>> lm += "What the user said"
42 |
43 | """
44 | return role("user")
45 |
46 |
47 | def assistant() -> AbstractContextManager:
48 | """Indicate the 'assistant' prompt
49 |
50 | A convention has grown up around 'chat' APIs that
51 | prompts are split into three parts: system, user
52 | and assistant.
53 | This indicates the start of an 'assistant' block, which
54 | marks LLM response (or where the LLM will generate
55 | the next response).
56 |
57 | >>> with assistant():
58 | >>> lm += gen(name="model_output", max_tokens=20)
59 |
60 | """
61 | return role("assistant")
62 |
--------------------------------------------------------------------------------
/guidance/library/_sequences.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from .._grammar import repeat
4 | from .._guidance import guidance
5 |
6 |
7 | @guidance(stateless=True)
8 | def exactly_n_repeats(model, value, n_repeats: int):
9 | return model + repeat(value, min=n_repeats, max=n_repeats)
10 |
11 |
12 | @guidance(stateless=True)
13 | def at_most_n_repeats(model, value, n_repeats: int):
14 | return model + repeat(value, min=0, max=n_repeats)
15 |
16 |
17 | @guidance(stateless=True)
18 | def sequence(model, value, min_length: int = 0, max_length: Union[int, None] = None):
19 | # Just an alias for repeat for now -- TODO: remove?
20 | return model + repeat(value, min=min_length, max=max_length)
21 |
22 |
23 | @guidance(stateless=True)
24 | def one_or_more(model, value):
25 | return model + repeat(value, min=1)
26 |
27 |
28 | @guidance(stateless=True)
29 | def zero_or_more(model, value):
30 | return model + repeat(value, min=0)
31 |
--------------------------------------------------------------------------------
/guidance/library/_subgrammar.py:
--------------------------------------------------------------------------------
1 | from .._ast import GrammarNode, RuleNode
2 | from .._grammar import subgrammar, regex
3 |
4 | __all__ = ["subgrammar", "regex", "as_regular_grammar", "lexeme"]
5 |
6 | def as_regular_grammar(node: GrammarNode, lexeme=False):
7 | # TODO: Remove this assertion-only check?
8 | if isinstance(node, RuleNode):
9 | rule = node
10 | else:
11 | rule = RuleNode("dummy", node)
12 | assert rule.is_allowed_in_lark_terminal
13 | return node
14 |
15 | def lexeme(body_regex: str, json_string: bool = False):
16 | if json_string:
17 | raise NotImplementedError("JSON strings are not supported")
18 | return regex(body_regex)
19 |
--------------------------------------------------------------------------------
/guidance/library/_substring.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Callable, Iterable, Literal, Optional, Union
3 |
4 | from .._ast import RuleNode, SubstringNode
5 |
6 |
7 | def chunk_on_word(text: str) -> list[str]:
8 | return re.findall(r"(\s+|\w+|[^\s\w]+)", text)
9 |
10 |
11 | def substring(
12 | target_string: str,
13 | *,
14 | chunk: Union[Literal["word", "character"], Callable[[str], Iterable[str]]] = "word",
15 | name: Optional[str] = None,
16 | ) -> RuleNode:
17 | chunks: Iterable[str]
18 | if chunk == "word":
19 | chunks = chunk_on_word(target_string)
20 | elif chunk == "character":
21 | chunks = tuple(target_string)
22 | elif callable(chunk):
23 |         chunks = tuple(chunk(target_string))  # materialize: chunk may return a one-shot iterator
24 |         if "".join(chunks) != target_string:
25 |             raise ValueError(
26 |                 "chunk function must return a sequence of strings that can be joined to form the target string"
27 | )
28 | else:
29 | raise ValueError(f"Invalid `chunk` value: {chunk!r}. Expected 'word', 'character', or a function.")
30 |
31 | return RuleNode(
32 | name=name or "substring",
33 | value=SubstringNode(tuple(chunks)),
34 | capture=name,
35 | )
36 |
--------------------------------------------------------------------------------
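How the default chunker above splits: `chunk_on_word` keeps whitespace and punctuation runs as their own chunks, so the pieces re-join to the exact target string. A sketch:

    from guidance.library._substring import chunk_on_word, substring

    chunk_on_word("Hello, world")   # -> ['Hello', ',', ' ', 'world']
    quote = substring("The quick brown fox", chunk="word", name="quote")
    # lm += quote; lm["quote"] is then a contiguous word-aligned substring of the target
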
/guidance/library/_tool.py:
--------------------------------------------------------------------------------
1 | from .._guidance import guidance
2 | from .._grammar import select
3 | from ._optional import optional
4 | from ._sequences import zero_or_more
5 | from ._subgrammar import lexeme, subgrammar
6 |
7 | class Tool:
8 | def __init__(self, call_grammar=None, tool_call=None, callable=None):
9 | # call_grammar specifies how the tool can be called. Crucially, it has to capture the args in variable 'tool_args'
10 |         # tool_call is a guidance function that actually calls the tool, and returns an lm object with whatever outputs it wants
11 | # callable: guidance function or regular callable, will be converted to grammar
12 | # TODO: hidden is not working yet
13 | first_option = (call_grammar is not None) and (tool_call is not None)
14 | second_option = callable is not None
15 | # either both are true or both false
16 | if first_option == second_option:
17 | raise Exception(
18 |                 "Must pass either (call_grammar, tool_call) or callable, but not both or neither"
19 | )
20 | if second_option:
21 | call_grammar, tool_call = fn_to_grammar_call(callable)
22 | self.call_grammar = call_grammar
23 | self.tool_call = tool_call
24 |
25 |
26 | def basic_func_grammar(name):
27 | arg = lexeme(r"[^,=)]+")
28 | kwarg = arg + "=" + arg
29 | args = arg + zero_or_more("," + arg)
30 | kwargs = kwarg + zero_or_more("," + kwarg)
31 |
32 | obj = name + "("
33 | obj += subgrammar(
34 | name="tool_args",
35 | body=optional(
36 | select([
37 | args,
38 | kwargs,
39 | args + "," + kwargs,
40 | ])
41 | ),
42 | skip_regex=r" *"
43 | )
44 | obj += ")"
45 | return obj
46 |
47 |
48 | def fn_to_grammar_call(callable):
49 | # TODO later: validate the call. Here is code to get required and optional args of 'guidance_fn':
50 | # name = guidance_fn.__name__
51 | # required_args = []
52 | # optional_args = []
53 | # sig = inspect.signature(guidance_fn)
54 | # for i, x in enumerate(sig.parameters.values()):
55 | # if i == 0:
56 | # continue
57 | # if x.default is x.empty:
58 | # required_args.append(x.name)
59 | # else:
60 | # optional_args.append(x.name)
61 | name = callable.__name__
62 | call_grammar = basic_func_grammar(name)
63 |
64 | @guidance(dedent=False)
65 | def basic_tool_call(lm):
66 | args = lm["tool_args"]
67 | args = args.split(",")
68 | positional = [x.strip() for x in args if "=" not in x]
69 | kwargs = dict([tuple(x.strip().split("=")) for x in args if "=" in x])
70 | lm += callable(*positional, **kwargs)
71 | return lm
72 |
73 | return call_grammar, basic_tool_call
74 |
--------------------------------------------------------------------------------
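A usage sketch for `Tool(callable=...)` above: the wrapped callable is invoked with the captured, comma-split `tool_args`, and whatever it returns is appended to the model state, so returning a string works:

    from guidance import Tool

    def add(a, b):                 # args arrive as strings parsed out of 'tool_args'
        return f" = {int(a) + int(b)}"

    calc = Tool(callable=add)      # call_grammar now matches e.g. "add(2, 3)"
    # calc.call_grammar captures "2, 3" as tool_args; calc.tool_call then runs
    # add("2", "3") and appends " = 5" to the lm.
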
/guidance/library/_video.py:
--------------------------------------------------------------------------------
1 | import importlib.resources
2 | import pathlib
3 | import typing
4 | import base64
5 |
6 |
7 | from .._guidance import guidance
8 | from .._utils import bytes_from
9 | # from ..trace._trace import VideoInput
10 | from ..trace._trace import VideoOutput
11 |
12 |
13 | @guidance
14 | def video(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = True):
15 | # TODO(nopdive): Mock for testing. Remove all of this code later.
16 | bytes_data = bytes_from(src, allow_local=allow_local)
17 | base64_string = base64.b64encode(bytes_data).decode('utf-8')
18 | lm += VideoOutput(value=base64_string, is_input=True)
19 | # lm += VideoInput(value=base64_string)
20 | return lm
21 |
22 |
23 | @guidance
24 | def gen_video(lm):
25 | # TODO(nopdive): Mock for testing. Remove all of this code later.
26 |     with importlib.resources.files("guidance").joinpath("resources/sample_video.mp4").open("rb") as f:
27 | bytes_data = f.read()
28 | base64_string = base64.b64encode(bytes_data).decode('utf-8')
29 | lm += VideoOutput(value=base64_string, is_input=False)
30 | return lm
31 |
--------------------------------------------------------------------------------
/guidance/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | """ Metrics that arise from both language models and their execution environment."""
2 |
3 | from ._metrics import ALL_METRICS, METRICS_TOPIC
4 | from ._metrics import MonitoringMetric, PostExecMetrics, Monitor, PeriodicMetricsGenerator
--------------------------------------------------------------------------------
/guidance/models/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Model
2 |
3 | # from ._engine import Instruct, Chat
4 |
5 | # local models
6 | from ._transformers import Transformers, TransformersTokenizer
7 | from ._llama_cpp import LlamaCpp
8 | from ._mock import Mock # , MockChat
9 |
10 | # from .vertexai._vertexai import (
11 | # VertexAI,
12 | # VertexAIChat,
13 | # VertexAICompletion,
14 | # VertexAIInstruct,
15 | # )
16 | # from ._azure_openai import (
17 | # AzureOpenAI,
18 | # )
19 | # from ._azureai_studio import AzureAIStudioChat
20 | from ._openai import OpenAI
21 |
22 | # from ._lite_llm import LiteLLM, LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
23 | # from ._cohere import Cohere, CohereCompletion, CohereInstruct
24 | # from ._anthropic import Anthropic
25 | # from ._googleai import GoogleAI, GoogleAIChat
26 | # from ._togetherai import (
27 | # TogetherAI,
28 | # TogetherAIChat,
29 | # TogetherAIInstruct,
30 | # TogetherAICompletion,
31 | # )
32 | from . import experimental
33 |
--------------------------------------------------------------------------------
/guidance/models/_base/__init__.py:
--------------------------------------------------------------------------------
1 | from ._interpreter import Interpreter
2 | from ._model import Model
3 | from ._state import State
4 |
5 | __all__ = [
6 |     "Model",
7 |     "State",
8 |     "Interpreter",
9 |     # The names below are not defined in this package and would break
10 |     # `from guidance.models._base import *`; kept commented for reference:
11 |     # "role",
12 |     # "Message",
13 |     # "ASTNode",
14 |     # "ContentChunk", "MessageChunk",
14 | ]
15 |
--------------------------------------------------------------------------------
/guidance/models/_base/_state.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, TypedDict, Union
3 |
4 | from ...trace import CaptureOutput
5 |
6 |
7 | class CaptureVar(TypedDict):
8 | value: str
9 | log_prob: Optional[float]
10 |
11 |
12 | class State(ABC):
13 | def __init__(self) -> None:
14 | self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {}
15 | self.active_role: Optional[str] = None
16 |
17 | @abstractmethod
18 | def __str__(self) -> str:
19 | pass
20 |
21 | def apply_capture(
22 |         self, name: str, value: Optional[str], log_prob: Optional[float] = None, is_append: bool = False
23 | ) -> CaptureOutput:
24 | if value is None:
25 | # A "reset" signal
26 | self.captures.pop(name)
27 | else:
28 | var = CaptureVar(value=value, log_prob=log_prob)
29 | if is_append:
30 | vars = self.captures.get(name, [])
31 | if not isinstance(vars, list):
32 | vars = [vars]
33 | vars.append(var)
34 | self.captures[name] = vars
35 | else:
36 | self.captures[name] = var
37 |
38 | return CaptureOutput(
39 | name=name,
40 | value=value,
41 | log_probs=log_prob or float("nan"),
42 | is_append=is_append,
43 | )
44 |
--------------------------------------------------------------------------------
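A sketch of the capture semantics implemented by `apply_capture` above, using the concrete `EngineState` defined later in this listing:

    from guidance.models._engine import EngineState

    s = EngineState()
    s.apply_capture("x", "a", log_prob=-0.5)                  # scalar capture
    s.apply_capture("x", "b", log_prob=-0.1, is_append=True)  # promoted to a list
    assert isinstance(s.captures["x"], list) and len(s.captures["x"]) == 2
    s.apply_capture("x", None)                                # reset signal: key removed
    assert "x" not in s.captures
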
/guidance/models/_byte_tokenizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ._engine import Tokenizer
3 | from ._engine._tokenizer import TokenizerWrappable
4 | from ..chat import load_template_class
5 | from typing import List
6 |
7 | class ByteTokenizer(Tokenizer):
8 | def __init__(self, chat_template=None):
9 | # directly map integer values to byte strings
10 | all_bytes = [bytes([i]) for i in range(256)]
11 |         bos = b"<s>"
12 | tokens = np.array(all_bytes + [bos], dtype="object")
13 | ll_tokenizer = TokenizerWrappable(
14 | eos_token_id=256,
15 | bos_token_id=256,
16 | tokens=tokens,
17 | special_token_ids=[],
18 | # ENCODE MUST BE OVERRIDDEN
19 | encode_callable=self.encode,
20 | ).as_ll_tokenizer()
21 |
22 | super().__init__(
23 | ll_tokenizer=ll_tokenizer,
24 | chat_template=chat_template,
25 | bos_token_id=256,
26 | )
27 |
28 | def encode(self, byte_string: bytes) -> List[int]:
29 | """Returns a list of tokens that represent the given byte string."""
30 | if isinstance(byte_string, str):
31 | byte_string = byte_string.encode("utf8")
32 | i = 0
33 | result = []
34 | while i < len(byte_string):
35 |             if byte_string[i:i+3] == b'<s>':
36 |                 result.append(256)
37 |                 i += 3  # Skip the next two characters as part of '<s>'
38 | else:
39 | result.append(byte_string[i])
40 | i += 1
41 | return result
42 |
--------------------------------------------------------------------------------
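Expected behavior of `ByteTokenizer.encode` above: plain bytes map to their own integer values, and the literal three-byte BOS marker collapses to the sentinel id 256.

    from guidance.models._byte_tokenizer import ByteTokenizer

    tok = ByteTokenizer()
    tok.encode(b"hi")      # -> [104, 105]
    tok.encode(b"<s>hi")   # -> [256, 104, 105]
    tok.encode("é")        # str input is utf-8 encoded first -> [195, 169]
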
/guidance/models/_engine/__init__.py:
--------------------------------------------------------------------------------
1 | from ._tokenizer import Tokenizer # isort:skip
2 | from ._interpreter import EngineInterpreter, Llama3VisionInterpreter, Phi3VisionInterpreter
3 | from ._engine import Engine
4 | from ._state import EngineState
5 |
6 | __all__ = [
7 | "Tokenizer",
8 | "Engine",
9 | "EngineInterpreter",
10 | "EngineState",
11 | "Llama3VisionInterpreter",
12 | "Phi3VisionInterpreter",
13 | ]
14 |
--------------------------------------------------------------------------------
/guidance/models/_engine/_state.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from .._base import State
4 |
5 |
6 | class EngineState(State):
7 | def __init__(self) -> None:
8 | super().__init__()
9 | self.prompt: str = ""
10 | self.images: list[Any] = []
11 | self.audio: list[Any] = []
12 | self.videos: list[Any] = []
13 |
14 | def __str__(self) -> str:
15 | return self.prompt
16 |
--------------------------------------------------------------------------------
/guidance/models/_openai.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 |
4 | from ._base import Model
5 | from ._openai_base import (
6 | BaseOpenAIInterpreter,
7 | Message,
8 | OpenAIAudioMixin,
9 | OpenAIImageMixin,
10 | OpenAIJSONMixin,
11 | OpenAIRegexMixin,
12 | OpenAIRuleMixin,
13 | )
14 |
15 |
16 | class OpenAIInterpreter(OpenAIRuleMixin, OpenAIJSONMixin, OpenAIRegexMixin, BaseOpenAIInterpreter):
17 | def __init__(
18 | self,
19 | model: str,
20 | api_key: Optional[str] = None,
21 | **kwargs,
22 | ):
23 | try:
24 | import openai
25 | except ImportError:
26 | raise Exception(
27 | "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!"
28 | )
29 | client = openai.OpenAI(api_key=api_key, **kwargs)
30 | super().__init__(model=model, client=client)
31 |
32 |
33 | class OpenAI(Model):
34 | def __init__(
35 | self,
36 | model: str,
37 | echo: bool = True,
38 | *,
39 | api_key: Optional[str] = None,
40 | **kwargs,
41 | ):
42 | """Build a new OpenAI model object that represents a model in a given state.
43 |
44 | Parameters
45 | ----------
46 | model : str
47 | The name of the OpenAI model to use (e.g. gpt-4o-mini).
48 | echo : bool
49 | If true the final result of creating this model state will be displayed (as HTML in a notebook).
50 | api_key : None or str
51 | The OpenAI API key to use for remote requests, passed directly to the `openai.OpenAI` constructor.
52 |
53 | **kwargs :
54 | All extra keyword arguments are passed directly to the `openai.OpenAI` constructor. Commonly used argument
55 |             names include `base_url` and `organization`.
56 | """
57 |
58 | if "audio-preview" in model:
59 | interpreter_cls = type(
60 | "OpenAIAudioInterpreter", (OpenAIAudioMixin, OpenAIInterpreter), {}
61 | )
62 | elif model.startswith("gpt-4o") or model.startswith("o1"):
63 | interpreter_cls = type(
64 | "OpenAIImageInterpreter", (OpenAIImageMixin, OpenAIInterpreter), {}
65 | )
66 | else:
67 | interpreter_cls = OpenAIInterpreter
68 |
69 | super().__init__(interpreter=interpreter_cls(model, api_key=api_key, **kwargs), echo=echo)
70 |
--------------------------------------------------------------------------------
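A usage sketch for the `OpenAI` model class above (requires a real API key; the prompt is illustrative):

    from guidance import assistant, gen, models, system, user

    lm = models.OpenAI("gpt-4o-mini")   # api_key=... or the OPENAI_API_KEY env var
    with system():
        lm += "You are terse."
    with user():
        lm += "Name one fruit."
    with assistant():
        lm += gen(name="reply", max_tokens=10)
    print(lm["reply"])
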
/guidance/models/broken_models/README.MD:
--------------------------------------------------------------------------------
1 | These model files use an older version of guidance's internal API design and need to be updated. They are kept here, commented out, solely as reference documentation to help with the migration; they cannot and should not be imported from this repository.
--------------------------------------------------------------------------------
/guidance/models/broken_models/_cohere.py:
--------------------------------------------------------------------------------
1 | # from ._lite_llm import LiteLLMEngine, LiteLLM, LiteLLMCompletion, LiteLLMInstruct
2 |
3 |
4 | # class Cohere(LiteLLM):
5 | # def __init__(
6 | # self,
7 | # model,
8 | # tokenizer=None,
9 | # echo=True,
10 | # timeout=0.5,
11 | # compute_log_probs=False,
12 | # max_streaming_tokens=1000,
13 | # ):
14 | #         """Build a new Cohere model object that represents a model in a given state."""
15 | # try:
16 | # import tokenizers
17 | # except ModuleNotFoundError:
18 | # raise Exception(
19 | # "Please install the HuggingFace tokenizers package using `pip install tokenizers -U` in order to use guidance.models.Cohere!"
20 | # )
21 |
22 | # # get the tokenizer
23 | # if tokenizer is None:
24 | # try:
25 | # tokenizer = tokenizers.Tokenizer.from_pretrained("Cohere/" + model)
26 | # except:
27 | # tokenizer = tokenizers.Tokenizer.from_pretrained(
28 | # "Cohere/command-nightly"
29 | # )
30 |
31 | # super().__init__(
32 | # model,
33 | # tokenizer=tokenizer,
34 | # echo=echo,
35 | # timeout=timeout,
36 | # max_streaming_tokens=max_streaming_tokens,
37 | # compute_log_probs=compute_log_probs,
38 | # )
39 |
40 |
41 | # class CohereCompletion(Cohere, LiteLLMCompletion):
42 | # pass
43 |
44 |
45 | # class CohereInstruct(Cohere, LiteLLMInstruct):
46 | # pass
47 |
--------------------------------------------------------------------------------
/guidance/models/broken_models/_togetherai.py:
--------------------------------------------------------------------------------
1 | # import os
2 | # from .._engine._engine import Chat, Instruct
3 | # from .._openai import (
4 | # OpenAI,
5 | # OpenAIEngine,
6 | # )
7 | # from ..transformers import TransformersTokenizer
8 |
9 |
10 | # class TogetherAI(OpenAI):
11 | # def __init__(
12 | # self,
13 | # model,
14 | # tokenizer=None,
15 | # echo=True,
16 | # api_key=None,
17 | # max_streaming_tokens=1000,
18 | # timeout=0.5,
19 | # compute_log_probs=False,
20 | # engine_class=None,
21 | # **kwargs,
22 | # ):
23 | # """
24 | # Build a new TogetherAI model object that represents a model in a given state.
25 | # """
26 |
27 | # tokenizer = TransformersTokenizer(
28 | # model=model, tokenizer=tokenizer, ignore_bos_token=True
29 | # )
30 |
31 | # # Default base_url is the together.ai endpoint
32 | # if not "base_url" in kwargs:
33 | # kwargs["base_url"] = "https://api.together.xyz"
34 | # # TogetherAI uses TOGETHERAI_API_KEY env value instead of OPENAI_API_KEY
35 | # # We pass explicitly to avoid OpenAI class complaining about a missing key
36 | # if api_key is None:
37 | # api_key = os.environ.get("TOGETHERAI_API_KEY", None)
38 | # if api_key is None:
39 | # raise Exception(
40 | # "The api_key client option must be set either by passing api_key to the client or by setting the TOGETHERAI_API_KEY environment variable"
41 | # )
42 |
43 | # if engine_class is None:
44 | # engine_map = {
45 | # TogetherAICompletion: OpenAIEngine,
46 | # TogetherAIInstruct: OpenAIEngine,
47 | # TogetherAIChat: OpenAIEngine,
48 | # TogetherAI: OpenAIEngine,
49 | # }
50 | # for k in engine_map:
51 | # if issubclass(self.__class__, k):
52 | # engine_class = engine_map[k]
53 | # break
54 |
55 | # super().__init__(
56 | # model,
57 | # tokenizer,
58 | # echo,
59 | # api_key,
60 | # max_streaming_tokens,
61 | # timeout,
62 | # compute_log_probs,
63 | # engine_class,
64 | # **kwargs,
65 | # )
66 |
67 |
68 | # class TogetherAICompletion(TogetherAI):
69 | # pass
70 |
71 |
72 | # class TogetherAIInstruct(TogetherAI, Instruct):
73 | # """
74 | # Utilizes chat endpoints to simulate a single instruction query
75 | # together.ai will format in correct prompt template for model on their end
76 | # """
77 |
78 | # def get_role_start(self, name):
79 | # if name == "instruction":
80 | # return "<|im_start|>user\n"
81 | # else:
82 | # raise Exception(
83 | # f"The TogetherAIInstruct model does not know about the {name} role type!"
84 | # )
85 |
86 | # def get_role_end(self, name):
87 | # if name == "instruction":
88 | # return "<|im_end|>"
89 | # else:
90 | # raise Exception(
91 | # f"The TogetherAIInstruct model does not know about the {name} role type!"
92 | # )
93 |
94 |
95 | # class TogetherAIChat(TogetherAI, Chat):
96 | # pass
97 |
--------------------------------------------------------------------------------
/guidance/models/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | from ._vllm import VLLMModel
--------------------------------------------------------------------------------
/guidance/models/experimental/_vllm.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, Optional, TYPE_CHECKING
2 | import wave
3 | import base64
4 | from io import BytesIO
5 |
6 | if TYPE_CHECKING:
7 | from openai.types.chat import ChatCompletionChunk
8 |
9 | from ..._ast import GrammarNode
10 | from ...trace import OutputAttr, TextOutput
11 | from ...trace._trace import AudioOutput
12 | from .._openai_base import (
13 | BaseOpenAIInterpreter,
14 | AssistantAudio,
15 | )
16 | from .._base import Model
17 |
18 |
19 | class VLLMInterpreter(BaseOpenAIInterpreter):
20 | def __init__(
21 | self,
22 | model: str,
23 | base_url: Optional[str] = None,
24 | api_key: Optional[str] = None,
25 | **kwargs,
26 | ):
27 | try:
28 | import openai
29 | except ImportError:
30 | raise Exception(
31 |                 "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.experimental.VLLMModel!"
32 | )
33 | client = openai.OpenAI(base_url=base_url, api_key=api_key, **kwargs)
34 | super().__init__(model=model, client=client)
35 |
36 | def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]:
37 | buffer: str = ""
38 | for attr in self._run(
39 | extra_body=dict(
40 | guided_decoding_backend="guidance",
41 | guided_grammar=node.ll_grammar(),
42 | )
43 | ):
44 | if isinstance(attr, TextOutput):
45 | buffer += attr.value
46 | yield attr
47 | matches = node.match(
48 | buffer,
49 | raise_exceptions=False,
50 |             # Turn off max_tokens since we don't have access to the tokenizer
51 | enforce_max_tokens=False,
52 | )
53 | if matches is None:
54 | # TODO: should probably raise...
55 | # raise ValueError("vLLM failed to constrain the grammar")
56 | pass
57 | else:
58 | for name, value in matches.captures.items():
59 | log_probs = matches.log_probs[name]
60 | if isinstance(value, list):
61 | assert isinstance(log_probs, list)
62 | assert len(value) == len(log_probs)
63 | for v, l in zip(value, log_probs):
64 | yield self.state.apply_capture(
65 | name=name, value=v, log_prob=l, is_append=True
66 | )
67 | else:
68 | yield self.state.apply_capture(
69 | name=name, value=value, log_prob=log_probs, is_append=False
70 | )
71 |
72 |
73 | class VLLMModel(Model):
74 | def __init__(self, model: str, echo=True, **kwargs):
75 | super().__init__(
76 | interpreter=VLLMInterpreter(model=model, **kwargs),
77 | echo=echo,
78 | )
79 |
--------------------------------------------------------------------------------
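A connection sketch for `VLLMModel` above; the base_url and model name are hypothetical placeholders for a locally served vLLM instance (extra keyword arguments flow through to `openai.OpenAI`):

    from guidance import gen
    from guidance.models.experimental import VLLMModel

    lm = VLLMModel(
        "meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical served model name
        base_url="http://localhost:8000/v1",    # hypothetical vLLM endpoint
        api_key="EMPTY",
    )
    lm += "2 + 2 = " + gen(name="ans", max_tokens=4)
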
/guidance/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/py.typed
--------------------------------------------------------------------------------
/guidance/registry/__init__.py:
--------------------------------------------------------------------------------
1 | """Registry module that contains singletons."""
2 |
3 | from ._registry import get_renderer, set_renderer, get_trace_handler, get_exchange
4 | from ._registry import get_bg_async, get_monitor
--------------------------------------------------------------------------------
/guidance/registry/_registry.py:
--------------------------------------------------------------------------------
1 | # NOTE(nopdive): Consider moving singleton factories to registry static class.
2 |
3 | import threading
4 |
5 | from .._schema import GuidanceEngineMetrics
6 | from ..metrics import Monitor, PeriodicMetricsGenerator
7 | from ..trace import TraceHandler
8 | from ..visual import AutoRenderer, Renderer, TopicExchange
9 | from .._bg import BackgroundAsync
10 |
11 | _monitor_lock = threading.Lock()
12 | _monitor = None
13 | _periodic_metrics_gen = None
14 |
15 | _bg_async_lock = threading.Lock()
16 | _bg_async = None
17 |
18 | _exchange_lock = threading.Lock()
19 | _exchange = None
20 |
21 | _trace_handler_lock = threading.Lock()
22 | _trace_handler = None
23 |
24 | _renderer_lock = threading.Lock()
25 | _renderer = None
26 |
27 |
28 | def get_monitor() -> Monitor:
29 | global _monitor
30 | global _monitor_lock
31 | global _periodic_metrics_gen
32 |
33 | with _monitor_lock:
34 | if _monitor is None:
35 | _monitor = Monitor(GuidanceEngineMetrics())
36 | _monitor.start()
37 | _periodic_metrics_gen = PeriodicMetricsGenerator(_monitor)
38 | _periodic_metrics_gen.start()
39 | return _monitor
40 |
41 |
42 | def get_bg_async() -> BackgroundAsync:
43 | global _bg_async
44 | global _bg_async_lock
45 |
46 | with _bg_async_lock:
47 | if _bg_async is None:
48 | _bg_async = BackgroundAsync()
49 | return _bg_async
50 |
51 |
52 | def get_exchange() -> TopicExchange:
53 | global _exchange
54 | global _exchange_lock
55 |
56 | with _exchange_lock:
57 | if _exchange is None:
58 | _exchange = TopicExchange()
59 | return _exchange
60 |
61 |
62 | def get_trace_handler() -> TraceHandler:
63 | global _trace_handler
64 | global _trace_handler_lock
65 |
66 | with _trace_handler_lock:
67 | if _trace_handler is None:
68 | _trace_handler = TraceHandler()
69 | return _trace_handler
70 |
71 |
72 | def get_renderer() -> Renderer:
73 | global _renderer
74 | global _renderer_lock
75 |
76 | with _renderer_lock:
77 | trace_handler = get_trace_handler()
78 | if _renderer is None:
79 | _renderer = AutoRenderer(trace_handler)
80 | return _renderer
81 |
82 |
83 | def set_renderer(renderer: Renderer) -> None:
84 | global _renderer
85 | global _renderer_lock
86 |
87 | with _renderer_lock:
88 | _renderer = renderer
89 |
--------------------------------------------------------------------------------
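Each `get_*` accessor above is a lock-guarded lazy singleton, so repeated calls return the same instance; `set_renderer` swaps the visual backend in place. A sketch:

    from guidance.registry import get_bg_async, get_renderer, set_renderer

    assert get_bg_async() is get_bg_async()  # constructed once, then cached
    renderer = get_renderer()                # AutoRenderer over the shared TraceHandler
    set_renderer(renderer)                   # or install a custom Renderer implementation
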
/guidance/resources/sample_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/resources/sample_audio.wav
--------------------------------------------------------------------------------
/guidance/resources/sample_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/resources/sample_image.png
--------------------------------------------------------------------------------
/guidance/resources/sample_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/resources/sample_video.mp4
--------------------------------------------------------------------------------
/guidance/trace/__init__.py:
--------------------------------------------------------------------------------
1 | """Trace tree of inputs & outputs generated from a guidance program.
2 |
3 | The first implementation aims for simplicity.
4 | Once benchmark figures are out, we'll figure out what to optimize.
5 |
6 | The most critical class is the trace handler. See its documentation for trace design & motivations.
7 | """
8 |
9 | from ._trace import NodeAttr, InputAttr, OutputAttr, StatefulGuidanceInput, StatelessGuidanceInput
10 | from ._trace import LiteralInput, EmbeddedInput, ImageInput, RoleCloserInput, RoleOpenerInput
11 | from ._trace import TextOutput, ImageOutput, CaptureOutput, TraceNode, TraceHandler
12 |
--------------------------------------------------------------------------------
/guidance/visual/__init__.py:
--------------------------------------------------------------------------------
1 | """UI and other visual UX considerations.
2 |
3 | Users should have few reasons to access this module.
4 | """
5 |
6 | from ._message import GuidanceMessage, TraceMessage, ResetDisplayMessage, ClientReadyMessage, ClientReadyAckMessage
7 | from ._message import ExecutionCompletedMessage, TokensMessage, MetricMessage, OutputRequestMessage
8 | from ._message import ExecutionStartedMessage
9 | from ._renderer import AutoRenderer, JupyterWidgetRenderer, Renderer
10 | from ._message import serialize_message, deserialize_message
11 | from ._trace import trace_node_to_str, display_trace_tree, trace_node_to_html
12 | from ._exchange import TopicExchange
--------------------------------------------------------------------------------
/guidance/visual/_exchange.py:
--------------------------------------------------------------------------------
1 | """ Poor man's exchange for routing messages. """
2 |
3 | from collections import defaultdict
4 | from typing import Callable
5 | from ..visual import GuidanceMessage
6 | import re
7 | import logging
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | DEFAULT_TOPIC = "/default"
12 | WILDCARD_PATTERN = r".*"
13 |
14 |
15 | class TopicExchange:
16 | """ Queue-less topic exchange for routing messages.
17 |
18 | This is not as comprehensive as a full distributed topic exchange.
19 | It is specific to a single process, with no queues and less generalized routing keys.
20 | """
21 |
22 | def __init__(self):
23 | """ Initializes."""
24 | self._observers = defaultdict(list)
25 |
26 | def subscribe(self, callback: Callable[[GuidanceMessage], None], topic_pat: str = WILDCARD_PATTERN) -> None:
27 | """ Subscribes to incoming messages.
28 |
29 | Args:
30 | callback: Callback to handle incoming messages.
31 |             topic_pat: Topic pattern to subscribe to.
32 | """
33 | logger.debug(f"EXCHANGE:pre_subscribe:{self._observers[topic_pat]}")
34 | self._observers[topic_pat].append(callback)
35 | logger.debug(f"EXCHANGE:post_subscribe:{self._observers[topic_pat]}")
36 |
37 | def unsubscribe(self, callback: Callable[[GuidanceMessage], None], topic_pat: str = WILDCARD_PATTERN) -> None:
38 | """ Unsubscribes from incoming messages.
39 |
40 | Args:
41 | callback: Callback to remove.
42 | topic_pat: Topic pattern.
43 | """
44 | logger.debug(f"EXCHANGE:pre_unsubscribe:{self._observers[topic_pat]}")
45 | try:
46 | self._observers[topic_pat].remove(callback)
47 | except ValueError as _:
48 | logger.warning(f"EXCHANGE:cb at '{topic_pat}' already removed.")
49 | logger.debug(f"EXCHANGE:post_unsubscribe:{self._observers[topic_pat]}")
50 |
51 | if len(self._observers[topic_pat]) == 0:
52 | logger.debug(f"EXCHANGE:delete_entry:{topic_pat}")
53 | del self._observers[topic_pat]
54 |
55 | def publish(self, message: GuidanceMessage, topic: str = DEFAULT_TOPIC):
56 |         """ Notifies all subscribers whose topic pattern matches the topic of an incoming message.
57 | 
58 |         Args:
59 |             message: Incoming message.
60 |             topic: Topic to publish on.
61 | """
62 | # logger.debug(f"EXCHANGE:publish:{message}")
63 | for obs_topic_pat, observers in self._observers.items():
64 | if re.match(obs_topic_pat, topic):
65 | for observer in observers:
66 | observer(message)
67 |
68 |
69 | __all__ = ["TopicExchange"]
70 |
--------------------------------------------------------------------------------
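A routing sketch for `TopicExchange` above: subscriber patterns are regexes matched against the publish topic (this assumes `GuidanceMessage()` is constructible with defaults):

    from guidance.visual import GuidanceMessage, TopicExchange

    ex = TopicExchange()
    seen = []
    ex.subscribe(seen.append, topic_pat=r"/metrics/.*")
    ex.publish(GuidanceMessage(), topic="/metrics/cpu")  # pattern matches -> delivered
    ex.publish(GuidanceMessage(), topic="/default")      # no match -> dropped
    assert len(seen) == 1
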
/guidance/visual/_jupyter.py:
--------------------------------------------------------------------------------
1 | """ Jupyter specific utilities."""
2 |
3 |
4 | from typing import Callable, Any, Tuple, Optional
5 | import logging
6 | from uuid import uuid4
7 |
8 | try:
9 | from IPython import get_ipython
10 | except ImportError:
11 | pass
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | IPythonCallback = Callable[[Any], None]
16 |
17 |
18 | def ipy_handle_event_once(cb: IPythonCallback, event_name: str) -> Tuple[Optional[IPythonCallback], str]:
19 | ipy = get_ipython()
20 | cell_session_id = str(uuid4())
21 |
22 | if ipy is None:
23 | return None, ""
24 |
25 | def cb_closure(msg):
26 | cb(info=msg)
27 | ipy.events.unregister(event_name, cb_closure)
28 | ipy.events.register(event_name, cb_closure)
29 |
30 | return cb_closure, cell_session_id
--------------------------------------------------------------------------------
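A sketch of using `ipy_handle_event_once` above, assuming an active IPython session; note that the wrapped callback is invoked with a single `info` keyword argument, and `'post_run_cell'` is one of IPython's standard event names:

```python
from guidance.visual._jupyter import ipy_handle_event_once

def on_cell_done(info) -> None:
    # `info` is IPython's ExecutionResult for the cell that just ran.
    print("cell finished:", info)

# Fires on the next post_run_cell event, then unregisters itself.
handle, cell_session_id = ipy_handle_event_once(on_cell_done, "post_run_cell")
if handle is None:
    print("Not running under IPython; nothing was registered.")
```

--------------------------------------------------------------------------------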
/notebooks/unstable/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/packages/python/stitch/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit = stitch/tests/*
3 |
--------------------------------------------------------------------------------
/packages/python/stitch/.eslintignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | dist
3 | coverage
4 | **/*.d.ts
5 | tests
--------------------------------------------------------------------------------
/packages/python/stitch/.eslintrc.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 |   extends: [
3 |     'eslint:recommended',
4 |     'plugin:@typescript-eslint/eslint-recommended',
5 |     'plugin:@typescript-eslint/recommended',
6 |     'plugin:prettier/recommended'
7 |   ],
8 |   parser: '@typescript-eslint/parser',
9 |   parserOptions: {
10 |     project: 'tsconfig.eslint.json',
11 |     sourceType: 'module'
12 |   },
13 |   plugins: ['@typescript-eslint'],
14 |   rules: {
15 |     '@typescript-eslint/no-unused-vars': ['warn', { args: 'none' }],
16 |     '@typescript-eslint/no-explicit-any': 'off',
17 |     '@typescript-eslint/no-namespace': 'off',
18 |     '@typescript-eslint/no-use-before-define': 'off',
19 |     '@typescript-eslint/quotes': [
20 |       'error',
21 |       'single',
22 |       { avoidEscape: true, allowTemplateLiterals: false }
23 |     ],
24 |     curly: ['error', 'all'],
25 |     eqeqeq: 'error',
26 |     'prefer-arrow-callback': 'error'
27 |   }
28 | };
--------------------------------------------------------------------------------
/packages/python/stitch/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 |   push:
5 |     branches: main
6 |   pull_request:
7 |     branches: "*"
8 |
9 | jobs:
10 |   build:
11 |     runs-on: ${{ matrix.os }}
12 |     strategy:
13 |       fail-fast: false
14 |       matrix:
15 |         os: [ubuntu-latest, windows-latest, macos-latest]
16 |         python-version: ["3.7", "3.10"]
17 |     steps:
18 |       - name: Checkout
19 |         uses: actions/checkout@v2
20 |
21 |       - uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1
22 |
23 |       - name: Install dependencies
24 |         run: |
25 |           python -m pip install -U codecov
26 |           npm install -g codecov
27 |       - name: Test the extension
28 |         run: |
29 |           python -m pip install --upgrade -v -e ".[test, examples, docs]"
30 |           python -m pytest
31 |           yarn run test
32 |
33 |       - name: Linting
34 |         if: ${{ matrix.os == 'ubuntu-latest' }}
35 |         run: |
36 |           yarn run lint:check
37 |
38 |       - name: Check docs can be built + links
39 |         if: ${{ matrix.os == 'ubuntu-latest' }}
40 |         working-directory: docs
41 |         run: |
42 |           sudo apt install -y pandoc
43 |           make html
44 |           python -m pytest --check-links
45 |
--------------------------------------------------------------------------------
/packages/python/stitch/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask instance folder
57 | instance/
58 |
59 | # Scrapy stuff:
60 | .scrapy
61 |
62 | # Sphinx documentation
63 | docs/_build/
64 | docs/source/_static/embed-bundle.js
65 | docs/source/_static/embed-bundle.js.map
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # IPython Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # dotenv
80 | .env
81 |
82 | # virtualenv
83 | venv/
84 | ENV/
85 |
86 | # Spyder project settings
87 | .spyderproject
88 |
89 | # Rope project settings
90 | .ropeproject
91 |
92 | # =========================
93 | # Operating System Files
94 | # =========================
95 |
96 | # OSX
97 | # =========================
98 |
99 | .DS_Store
100 | .AppleDouble
101 | .LSOverride
102 |
103 | # Thumbnails
104 | ._*
105 |
106 | # Files that might appear in the root of a volume
107 | .DocumentRevisions-V100
108 | .fseventsd
109 | .Spotlight-V100
110 | .TemporaryItems
111 | .Trashes
112 | .VolumeIcon.icns
113 |
114 | # Directories potentially created on remote AFP share
115 | .AppleDB
116 | .AppleDesktop
117 | Network Trash Folder
118 | Temporary Items
119 | .apdisk
120 |
121 | # Windows
122 | # =========================
123 |
124 | # Windows image file caches
125 | Thumbs.db
126 | ehthumbs.db
127 |
128 | # Folder config file
129 | Desktop.ini
130 |
131 | # Recycle Bin used on file shares
132 | $RECYCLE.BIN/
133 |
134 | # Windows Installer files
135 | *.cab
136 | *.msi
137 | *.msm
138 | *.msp
139 |
140 | # Windows shortcuts
141 | *.lnk
142 |
143 |
144 | # NPM
145 | # ----
146 |
147 | **/node_modules/
148 | stitch/nbextension/index.*
149 | .yarn/
150 |
151 | # Coverage data
152 | # -------------
153 | **/coverage/
154 |
155 | # Packed lab extensions
156 | stitch/labextension
157 |
--------------------------------------------------------------------------------
/packages/python/stitch/.npmignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | node_modules/
3 | tests/
4 | .jshintrc
5 | # Ignore any build output from python:
6 | dist/*.tar.gz
7 | dist/*.wheel
8 |
--------------------------------------------------------------------------------
/packages/python/stitch/.prettierignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | **/node_modules
3 | **/lib
4 | **/package.json
--------------------------------------------------------------------------------
/packages/python/stitch/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "singleQuote": true
3 | }
--------------------------------------------------------------------------------
/packages/python/stitch/.yarnrc.yml:
--------------------------------------------------------------------------------
1 | nodeLinker: node-modules
2 |
--------------------------------------------------------------------------------
/packages/python/stitch/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 Guidance Contributors
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | 3. Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/packages/python/stitch/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE.txt
2 | include README.md
3 |
4 | include setup.py
5 | include pyproject.toml
6 | include pytest.ini
7 | include .coveragerc
8 |
9 | include tsconfig.json
10 | include package.json
11 | include webpack.config.js
12 | include stitch/labextension/*.tgz
13 |
14 | # Documentation
15 | graft docs
16 | exclude docs/\#*
17 | prune docs/build
18 | prune docs/gh-pages
19 | prune docs/dist
20 |
21 | # Examples
22 | graft examples
23 |
24 | # Tests
25 | graft tests
26 | prune tests/build
27 |
28 | # Javascript files
29 | graft stitch/nbextension
30 | graft src
31 | graft css
32 | prune **/node_modules
33 | prune coverage
34 | prune lib
35 |
36 | # Patterns to exclude from any directory
37 | global-exclude *~
38 | global-exclude *.pyc
39 | global-exclude *.pyo
40 | global-exclude .git
41 | global-exclude .ipynb_checkpoints
42 |
--------------------------------------------------------------------------------
/packages/python/stitch/README.md:
--------------------------------------------------------------------------------
1 |
2 | # stitch
3 |
4 | [](https://travis-ci.org/guidance-ai/stitch)
5 | [](https://codecov.io/gh/guidance-ai/stitch)
6 |
7 |
8 | Bidirectional comms for Jupyter and JavaScript.
9 |
10 | ## Installation
11 |
12 | You can install using `pip`:
13 |
14 | ```bash
15 | pip install guidance-stitch
16 | ```
17 |
18 | If you are using Jupyter Notebook 5.2 or earlier, you may also need to enable
19 | the nbextension:
20 | ```bash
21 | jupyter nbextension enable --py [--sys-prefix|--user|--system] guidance-stitch
22 | ```
23 |
24 | ## Development Installation
25 |
26 | Create a dev environment:
27 | ```bash
28 | conda create -n stitch-dev -c conda-forge nodejs python jupyterlab=4.0.11
29 | conda activate stitch-dev
30 | ```
31 |
32 | Install the Python package. This will also build the TS package.
33 | ```bash
34 | pip install -e ".[test, examples]"
35 | ```
36 |
37 | When developing your extension, you need to manually enable it in the
38 | notebook / lab frontend. For lab, this is done by the command:
39 |
40 | ```bash
41 | jupyter labextension develop --overwrite .
42 | jlpm run build
43 | ```
44 |
45 | For classic notebook, you need to run:
46 |
47 | ```bash
48 | jupyter nbextension install --sys-prefix --symlink --overwrite --py guidance-stitch
49 | jupyter nbextension enable --sys-prefix --py guidance-stitch
50 | ```
51 |
52 | Note that the `--symlink` flag doesn't work on Windows, so there you will have to run
53 | the `install` command every time you rebuild your extension. For certain installations
54 | you might also need another flag instead of `--sys-prefix`, but we won't cover the meaning
55 | of those flags here.
56 |
57 | ### How to see your changes
58 | #### Typescript:
59 | If you use JupyterLab for development, you can run JupyterLab and a watcher on the extension's
60 | source directory in separate terminals, so the widget is rebuilt automatically on every change.
61 |
62 | ```bash
63 | # Watch the source directory in one terminal, automatically rebuilding when needed
64 | jlpm run watch
65 | # Run JupyterLab in another terminal
66 | jupyter lab
67 | ```
68 |
69 | After a change, wait for the build to finish, then refresh your browser; the changes should take effect.
70 |
71 | #### Python:
72 | If you make a change to the Python code, you will need to restart the notebook kernel for it to take effect.
73 |
74 | ## Updating the version
75 |
76 | To update the version, install tbump and use it to bump the version.
77 | By default it will also create a tag.
78 |
79 | ```bash
80 | pip install tbump
81 | tbump
82 | ```
83 |
84 |
--------------------------------------------------------------------------------
/packages/python/stitch/babel.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 |   sourceMap: 'inline',
3 |   presets: [
4 |     [
5 |       '@babel/preset-env',
6 |       {
7 |         targets: {
8 |           node: 'current',
9 |         },
10 |       },
11 |     ],
12 |   ],
13 | };
14 |
--------------------------------------------------------------------------------
/packages/python/stitch/codecov.yml:
--------------------------------------------------------------------------------
1 | comment: off
2 | # show coverage in CI status, but never consider it a failure
3 | coverage:
4 |   status:
5 |     project:
6 |       default:
7 |         target: 0%
8 |     patch:
9 |       default:
10 |         target: 0%
11 | ignore:
12 |   - "stitch/tests"
13 |
--------------------------------------------------------------------------------
/packages/python/stitch/css/widget.css:
--------------------------------------------------------------------------------
1 | .custom-widget {
2 |   background-color: lightseagreen;
3 |   padding: 0px 2px;
4 | }
5 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = stitch
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/environment.yml:
--------------------------------------------------------------------------------
1 |
2 | name: stitch_docs
3 | channels:
4 |   - conda-forge
5 | dependencies:
6 |   - python=3.*
7 |   - nodejs
8 |   - jupyter_sphinx
9 |   - sphinx
10 |   - sphinx_rtd_theme
11 |   - nbsphinx
12 |   - nbsphinx-link
13 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | set SPHINXPROJ=stitch
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/_static/helper.js:
--------------------------------------------------------------------------------
1 | var cache_require = window.require;
2 |
3 | window.addEventListener('load', function() {
4 |   window.require = cache_require;
5 | });
6 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/develop-install.rst:
--------------------------------------------------------------------------------
1 |
2 | Developer install
3 | =================
4 |
5 |
6 | To install a developer version of stitch, you will first need to clone
7 | the repository::
8 |
9 |     git clone https://github.com/guidance-ai/stitch
10 |     cd stitch
11 |
12 | Next, install it with a develop install using pip::
13 |
14 |     pip install -e .
15 |
16 |
17 | If you are planning on working on the JS/frontend code, you should also do
18 | a link installation of the extension::
19 |
20 |     jupyter nbextension install [--sys-prefix / --user / --system] --symlink --py stitch
21 |
22 |     jupyter nbextension enable [--sys-prefix / --user / --system] --py stitch
23 |
24 | with the `appropriate flag`_. Or, if you are using Jupyterlab::
25 |
26 |     jupyter labextension install .
27 |
28 |
29 | .. links
30 |
31 | .. _`appropriate flag`: https://jupyter-notebook.readthedocs.io/en/stable/extending/frontend_extensions.html#installing-and-enabling-extensions
32 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/examples/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Examples
3 | ========
4 |
5 | This section contains several examples generated from Jupyter notebooks.
6 | The widgets have been embedded into the page for demonstrative purposes.
7 |
8 | .. todo::
9 |
10 |     Add links to notebooks in examples folder similar to the initial
11 |     one. This is a manual step to ensure only those examples that
12 |     are suited for inclusion are used.
13 |
14 |
15 | .. toctree::
16 |     :glob:
17 |
18 |     *
19 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/examples/introduction.nblink:
--------------------------------------------------------------------------------
1 | {
2 | "path": "../../../examples/introduction.ipynb"
3 | }
4 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/index.rst:
--------------------------------------------------------------------------------
1 |
2 | stitch
3 | =====================================
4 |
5 | Version: |release|
6 |
7 | Bidirectional comms for Jupyter and JavaScript.
8 |
9 |
10 | Quickstart
11 | ----------
12 |
13 | To get started with stitch, install with pip::
14 |
15 |     pip install stitch
16 |
17 | or with conda::
18 |
19 |     conda install stitch
20 |
21 |
22 | Contents
23 | --------
24 |
25 | .. toctree::
26 |     :maxdepth: 2
27 |     :caption: Installation and usage
28 |
29 |     installing
30 |     introduction
31 |
32 | .. toctree::
33 |     :maxdepth: 1
34 |
35 |     examples/index
36 |
37 |
38 | .. toctree::
39 |     :maxdepth: 2
40 |     :caption: Development
41 |
42 |     develop-install
43 |
44 |
45 | .. links
46 |
47 | .. _`Jupyter widgets`: https://jupyter.org/widgets.html
48 |
49 | .. _`notebook`: https://jupyter-notebook.readthedocs.io/en/latest/
50 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/installing.rst:
--------------------------------------------------------------------------------
1 |
2 | .. _installation:
3 |
4 | Installation
5 | ============
6 |
7 |
8 | The simplest way to install stitch is via pip::
9 |
10 |     pip install stitch
11 |
12 | or via conda::
13 |
14 |     conda install stitch
15 |
16 |
17 | If you installed via pip, and notebook version < 5.3, you will also have to
18 | install / configure the front-end extension. If you are using classic
19 | notebook (as opposed to Jupyterlab), run::
20 |
21 |     jupyter nbextension install [--sys-prefix / --user / --system] --py stitch
22 |
23 |     jupyter nbextension enable [--sys-prefix / --user / --system] --py stitch
24 |
25 | with the `appropriate flag`_. If you are using Jupyterlab, install the extension
26 | with::
27 |
28 |     jupyter labextension install @guidance-ai/stitch
29 |
30 | If you are installing using conda, these commands should be unnecessary, but if
31 | you need to run them the commands should be the same (just make sure you choose the
32 | `--sys-prefix` flag).
33 |
34 |
35 | .. links
36 |
37 | .. _`appropriate flag`: https://jupyter-notebook.readthedocs.io/en/stable/extending/frontend_extensions.html#installing-and-enabling-extensions
38 |
--------------------------------------------------------------------------------
/packages/python/stitch/docs/source/introduction.rst:
--------------------------------------------------------------------------------
1 | =============
2 | Introduction
3 | =============
4 |
5 | .. todo::
6 |
7 |     add prose explaining project purpose and usage here
8 |
--------------------------------------------------------------------------------
/packages/python/stitch/install.json:
--------------------------------------------------------------------------------
1 | {
2 | "packageManager": "python",
3 | "packageName": "stitch",
4 | "uninstallInstructions": "Use your Python package manager (pip, conda, etc.) to uninstall the package stitch"
5 | }
6 |
--------------------------------------------------------------------------------
/packages/python/stitch/jest.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 |   automock: false,
3 |   moduleNameMapper: {
4 |     '\\.(css|less|sass|scss)$': 'identity-obj-proxy',
5 |   },
6 |   preset: 'ts-jest/presets/js-with-babel',
7 |   moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
8 |   testPathIgnorePatterns: ['/lib/', '/node_modules/'],
9 |   testRegex: '/__tests__/.*.spec.ts[x]?$',
10 |   transformIgnorePatterns: ['/node_modules/(?!(@jupyter(lab|-widgets)/.*)/)'],
11 |   globals: {
12 |     'ts-jest': {
13 |       tsconfig: '<rootDir>/tsconfig.json',
14 |     },
15 |   },
16 | };
17 |
--------------------------------------------------------------------------------
/packages/python/stitch/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | testpaths = stitch/tests examples
3 | norecursedirs = node_modules .ipynb_checkpoints
4 | addopts = --nbval --current-env
5 |
--------------------------------------------------------------------------------
/packages/python/stitch/readthedocs.yml:
--------------------------------------------------------------------------------
1 | type: sphinx
2 | python:
3 |   version: 3.5
4 |   pip_install: true
5 |   extra_requirements:
6 |     - examples
7 |     - docs
8 | conda:
9 |   file: docs/environment.yml
10 |
--------------------------------------------------------------------------------
/packages/python/stitch/setup.py:
--------------------------------------------------------------------------------
1 | # setup.py shim for use with applications that require it.
2 | __import__("setuptools").setup()
3 |
--------------------------------------------------------------------------------
/packages/python/stitch/src/__tests__/index.spec.ts:
--------------------------------------------------------------------------------
1 | /** @jest-environment jsdom */
2 | // Copyright (c) Jupyter Development Team.
3 | // Distributed under the terms of the Modified BSD License.
4 |
5 | // Add any needed widget imports here (or from controls)
6 | // import {} from '@jupyter-widgets/base';
7 |
8 | // NOTE(nopdive): Workaround for jsdom drag event failure.
9 | Object.defineProperty(window, 'DragEvent', {
10 |   value: class DragEvent {},
11 | });
12 |
13 | import { createTestModel } from './utils';
14 | import { StitchModel } from '..';
15 |
16 | describe('Example', () => {
17 |   describe('StitchModel', () => {
18 |     it('should be createable', () => {
19 |       const model = createTestModel(StitchModel);
20 |       expect(model).toBeInstanceOf(StitchModel);
21 |       expect(model.get('srcdoc')).toEqual(
22 |         '<p>srcdoc should be defined by the user</p>',
23 |       );
24 |     });
25 |
26 |     it('should be createable with a value', () => {
27 |       const state = { srcdoc: 'it is alright' };
28 |       const model = createTestModel(StitchModel, state);
29 |       expect(model).toBeInstanceOf(StitchModel);
30 |       expect(model.get('srcdoc')).toEqual('it is alright');
31 |     });
32 |   });
33 | });
34 |
--------------------------------------------------------------------------------
/packages/python/stitch/src/__tests__/utils.ts:
--------------------------------------------------------------------------------
1 | // Copyright (c) Jupyter Development Team.
2 | // Distributed under the terms of the Modified BSD License.
3 |
4 | import * as widgets from '@jupyter-widgets/base';
5 | import * as baseManager from '@jupyter-widgets/base-manager';
6 | import * as services from '@jupyterlab/services';
7 |
8 | let numComms = 0;
9 |
10 | export class MockComm implements widgets.IClassicComm {
11 |   constructor() {
12 |     this.comm_id = `mock-comm-id-${numComms}`;
13 |     numComms += 1;
14 |   }
15 |   on_close(fn: ((x?: any) => void) | null): void {
16 |     this._on_close = fn;
17 |   }
18 |   on_msg(fn: (x?: any) => void): void {
19 |     this._on_msg = fn;
20 |   }
21 |   _process_msg(msg: services.KernelMessage.ICommMsgMsg): void | Promise<void> {
22 |     if (this._on_msg) {
23 |       return this._on_msg(msg);
24 |     } else {
25 |       return Promise.resolve();
26 |     }
27 |   }
28 |   close(): string {
29 |     if (this._on_close) {
30 |       this._on_close();
31 |     }
32 |     return 'dummy';
33 |   }
34 |   send(): string {
35 |     return 'dummy';
36 |   }
37 |
38 |   open(): string {
39 |     return 'dummy';
40 |   }
41 |
42 |   comm_id: string;
43 |   target_name = 'dummy';
44 |   _on_msg: ((x?: any) => void) | null = null;
45 |   _on_close: ((x?: any) => void) | null = null;
46 | }
47 |
48 | export class DummyManager extends baseManager.ManagerBase {
49 |   constructor() {
50 |     super();
51 |     this.el = window.document.createElement('div');
52 |   }
53 |
54 |   display_view(
55 |     msg: services.KernelMessage.IMessage,
56 |     view: widgets.DOMWidgetView,
57 |     options: any
58 |   ) {
59 |     // TODO: make this a spy
60 |     // TODO: return an html element
61 |     return Promise.resolve(view).then((view) => {
62 |       this.el.appendChild(view.el);
63 |       view.on('remove', () => console.log('view removed', view));
64 |       return view.el;
65 |     });
66 |   }
67 |
68 |   protected loadClass(
69 |     className: string,
70 |     moduleName: string,
71 |     moduleVersion: string
72 |   ): Promise<any> {
73 |     if (moduleName === '@jupyter-widgets/base') {
74 |       if ((widgets as any)[className]) {
75 |         return Promise.resolve((widgets as any)[className]);
76 |       } else {
77 |         return Promise.reject(`Cannot find class ${className}`);
78 |       }
79 |     } else if (moduleName === 'jupyter-datawidgets') {
80 |       if (this.testClasses[className]) {
81 |         return Promise.resolve(this.testClasses[className]);
82 |       } else {
83 |         return Promise.reject(`Cannot find class ${className}`);
84 |       }
85 |     } else {
86 |       return Promise.reject(`Cannot find module ${moduleName}`);
87 |     }
88 |   }
89 |
90 |   _get_comm_info() {
91 |     return Promise.resolve({});
92 |   }
93 |
94 |   _create_comm() {
95 |     return Promise.resolve(new MockComm());
96 |   }
97 |
98 |   el: HTMLElement;
99 |
100 |   testClasses: { [key: string]: any } = {};
101 | }
102 |
103 | export interface Constructor<T> {
104 |   new (attributes?: any, options?: any): T;
105 | }
106 |
107 | export function createTestModel<T extends widgets.WidgetModel>(
108 |   constructor: Constructor<T>,
109 |   attributes?: any
110 | ): T {
111 |   const id = widgets.uuid();
112 |   const widget_manager = new DummyManager();
113 |   const modelOptions = {
114 |     widget_manager: widget_manager,
115 |     model_id: id,
116 |   };
117 |
118 |   return new constructor(attributes, modelOptions);
119 | }
120 |
--------------------------------------------------------------------------------
/packages/python/stitch/src/extension.ts:
--------------------------------------------------------------------------------
1 | // Copyright (c) Jupyter Development Team.
2 | // Distributed under the terms of the Modified BSD License.
3 |
4 | // Entry point for the notebook bundle containing custom model definitions.
5 | //
6 | // Setup notebook base URL
7 | //
8 | // Some static assets may be required by the custom widget javascript. The base
9 | // url for the notebook is not known at build time and is therefore computed
10 | // dynamically.
11 | // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
12 | (window as any).__webpack_public_path__ =
13 |   document.querySelector('body')!.getAttribute('data-base-url') +
14 |   'nbextensions/stitch';
15 |
16 | export * from './index';
17 |
--------------------------------------------------------------------------------
/packages/python/stitch/src/index.ts:
--------------------------------------------------------------------------------
1 | // Copyright (c) Guidance Contributors
2 | // Distributed under the terms of the Modified BSD License.
3 |
4 | export * from './version';
5 | export * from './widget';
6 |
--------------------------------------------------------------------------------
/packages/python/stitch/src/plugin.ts:
--------------------------------------------------------------------------------
1 | // Copyright (c) Guidance Contributors
2 | // Distributed under the terms of the Modified BSD License.
3 |
4 | import { Application, IPlugin } from '@lumino/application';
5 |
6 | import { Widget } from '@lumino/widgets';
7 |
8 | import { IJupyterWidgetRegistry } from '@jupyter-widgets/base';
9 |
10 | import * as widgetExports from './widget';
11 |
12 | import { MODULE_NAME, MODULE_VERSION } from './version';
13 |
14 | const EXTENSION_ID = '@guidance-ai/stitch:plugin';
15 |
16 | /**
17 |  * The example plugin.
18 |  */
19 | const examplePlugin: IPlugin<Application<Widget>, void> = {
20 |   id: EXTENSION_ID,
21 |   requires: [IJupyterWidgetRegistry],
22 |   activate: activateWidgetExtension,
23 |   autoStart: true,
24 | } as unknown as IPlugin<Application<Widget>, void>;
25 | // the "as unknown as ..." typecast above is solely to support JupyterLab 1
26 | // and 2 in the same codebase and should be removed when we migrate to Lumino.
27 |
28 | export default examplePlugin;
29 |
30 | /**
31 |  * Activate the widget extension.
32 |  */
33 | function activateWidgetExtension(
34 |   app: Application<Widget>,
35 |   registry: IJupyterWidgetRegistry
36 | ): void {
37 |   registry.registerWidget({
38 |     name: MODULE_NAME,
39 |     version: MODULE_VERSION,
40 |     exports: widgetExports,
41 |   });
42 | }
43 |
--------------------------------------------------------------------------------
/packages/python/stitch/src/version.ts:
--------------------------------------------------------------------------------
1 | // Copyright (c) Guidance Contributors
2 | // Distributed under the terms of the Modified BSD License.
3 |
4 | // eslint-disable-next-line @typescript-eslint/ban-ts-comment
5 | // @ts-ignore
6 | // eslint-disable-next-line @typescript-eslint/no-var-requires
7 | const data = require('../package.json');
8 |
9 | /**
10 |  * The _model_module_version/_view_module_version this package implements.
11 |  *
12 |  * The html widget manager assumes that this is the same as the npm package
13 |  * version number.
14 |  */
15 | export const MODULE_VERSION = data.version;
16 |
17 | /*
18 |  * The current package name.
19 |  */
20 | export const MODULE_NAME = data.name;
21 |
--------------------------------------------------------------------------------
/packages/python/stitch/stitch.json:
--------------------------------------------------------------------------------
1 | {
2 | "load_extensions": {
3 | "stitch/extension": true
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 | from .stitch import StitchWidget
8 | from ._version import __version__, version_info
9 |
10 | def _jupyter_labextension_paths():
11 |     """Called by Jupyter Lab Server to detect if it is a valid labextension and
12 |     to install the widget
13 |     Returns
14 |     =======
15 |     src: Source directory name to copy files from. Webpack outputs generated files
16 |         into this directory and Jupyter Lab copies from this directory during
17 |         widget installation
18 |     dest: Destination directory name to install widget files to. Jupyter Lab copies
19 |         from `src` directory into <jupyter path>/labextensions/<dest> directory
20 |         during widget installation
21 |     """
22 |     return [{
23 |         'src': 'labextension',
24 |         'dest': '@guidance-ai/stitch',
25 |     }]
26 |
27 |
28 | def _jupyter_nbextension_paths():
29 |     """Called by Jupyter Notebook Server to detect if it is a valid nbextension and
30 |     to install the widget
31 |     Returns
32 |     =======
33 |     section: The section of the Jupyter Notebook Server to change.
34 |         Must be 'notebook' for widget extensions
35 |     src: Source directory name to copy files from. Webpack outputs generated files
36 |         into this directory and Jupyter Notebook copies from this directory during
37 |         widget installation
38 |     dest: Destination directory name to install widget files to. Jupyter Notebook copies
39 |         from `src` directory into <jupyter path>/nbextensions/<dest> directory
40 |         during widget installation
41 |     require: Path to importable AMD Javascript module inside the
42 |         <jupyter path>/nbextensions/<dest> directory
43 |     """
44 |     return [{
45 |         'section': 'notebook',
46 |         'src': 'nbextension',
47 |         'dest': 'stitch',
48 |         'require': 'stitch/extension'
49 |     }]
50 |
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/_frontend.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 | """
8 | Information about the frontend package of the widgets.
9 | """
10 |
11 | module_name = "@guidance-ai/stitch"
12 | module_version = "^0.1.4"
13 |
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/_version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 | version_info = (0, 1, 4)
8 | __version__ = ".".join(map(str, version_info))
9 |
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/nbextension/extension.js:
--------------------------------------------------------------------------------
1 | // Entry point for the notebook bundle containing custom model definitions.
2 | //
3 | define(function() {
4 |   "use strict";
5 |
6 |   window['requirejs'].config({
7 |     map: {
8 |       '*': {
9 |         '@guidance-ai/stitch': 'nbextensions/stitch/index',
10 |       },
11 |     }
12 |   });
13 |   // Export the required load_ipython_extension function
14 |   return {
15 |     load_ipython_extension : function() {}
16 |   };
17 | });
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/stitch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 | """
8 | Stitch Widget that allows bidirectional comms from Jupyter and JavaScript.
9 | """
10 |
11 | from ipywidgets import DOMWidget
12 | from traitlets import Unicode
13 | from ._frontend import module_name, module_version
14 |
15 |
16 | class StitchWidget(DOMWidget):
17 |     """Widget that purely handles communication between an iframe and kernel via postMessage."""
18 |
19 |     _model_name = Unicode('StitchModel').tag(sync=True)
20 |     _model_module = Unicode(module_name).tag(sync=True)
21 |     _model_module_version = Unicode(module_version).tag(sync=True)
22 |     _view_name = Unicode('StitchView').tag(sync=True)
23 |     _view_module = Unicode(module_name).tag(sync=True)
24 |     _view_module_version = Unicode(module_version).tag(sync=True)
25 |
26 |     kernelmsg = Unicode("").tag(sync=True)
27 |     clientmsg = Unicode("").tag(sync=True)
28 |     srcdoc = Unicode("<p>srcdoc should be defined by the user</p>").tag(sync=True)
29 |     initial_height = Unicode("1px").tag(sync=True)
30 |     initial_width = Unicode("1px").tag(sync=True)
31 |     initial_border = Unicode("0").tag(sync=True)
32 |
33 |     # NOTE(nopdive): Should we sync or not? There are overheads when we deal with bandwidth on real time applications.
34 |     state = Unicode("").tag(sync=True)
--------------------------------------------------------------------------------
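A minimal sketch of driving the `StitchWidget` above from a notebook cell; the HTML payload and message contents are illustrative. The synced `kernelmsg` and `clientmsg` traits carry kernel-to-client and client-to-kernel messages respectively:

```python
from stitch import StitchWidget

widget = StitchWidget()
widget.srcdoc = "<html><body><script>/* client-side postMessage handling */</script></body></html>"

def on_clientmsg(change) -> None:
    # Messages posted by the iframe arrive via the synced `clientmsg` trait.
    print("client said:", change["new"])

widget.observe(on_clientmsg, names="clientmsg")

# Kernel -> client: assign the synced trait and the frontend picks it up.
widget.kernelmsg = '{"event": "ping"}'
widget  # display the widget to activate the iframe view
```

--------------------------------------------------------------------------------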
/packages/python/stitch/stitch/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/packages/python/stitch/stitch/tests/__init__.py
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/tests/conftest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 | import pytest
8 |
9 | from ipykernel.comm import Comm
10 | from ipywidgets import Widget
11 |
12 | class MockComm(Comm):
13 |     """A mock Comm object.
14 |
15 |     Can be used to inspect calls to Comm's open/send/close methods.
16 |     """
17 |     comm_id = 'a-b-c-d'
18 |     kernel = 'Truthy'
19 |
20 |     def __init__(self, *args, **kwargs):
21 |         self.log_open = []
22 |         self.log_send = []
23 |         self.log_close = []
24 |         super(MockComm, self).__init__(*args, **kwargs)
25 |
26 |     def open(self, *args, **kwargs):
27 |         self.log_open.append((args, kwargs))
28 |
29 |     def send(self, *args, **kwargs):
30 |         self.log_send.append((args, kwargs))
31 |
32 |     def close(self, *args, **kwargs):
33 |         self.log_close.append((args, kwargs))
34 |
35 | _widget_attrs = {}
36 | undefined = object()
37 |
38 |
39 | @pytest.fixture
40 | def mock_comm():
41 |     _widget_attrs['_comm_default'] = getattr(Widget, '_comm_default', undefined)
42 |     Widget._comm_default = lambda self: MockComm()
43 |     _widget_attrs['_ipython_display_'] = Widget._ipython_display_
44 |     def raise_not_implemented(*args, **kwargs):
45 |         raise NotImplementedError()
46 |     Widget._ipython_display_ = raise_not_implemented
47 |
48 |     yield MockComm()
49 |
50 |     for attr, value in _widget_attrs.items():
51 |         if value is undefined:
52 |             delattr(Widget, attr)
53 |         else:
54 |             setattr(Widget, attr, value)
55 |
--------------------------------------------------------------------------------
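A hedged sketch of how the `mock_comm` fixture above might be used; the test name and assertion are illustrative, not part of the suite. While the fixture is active, any widget constructed gets a `MockComm` as its comm, so open/send/close traffic can be inspected without a running kernel:

```python
from stitch import StitchWidget

def test_sync_is_sent_over_comm(mock_comm):
    w = StitchWidget()
    w.kernelmsg = "ping"  # changing a synced trait should trigger a comm send
    assert w.comm.log_send, "expected the trait change to be recorded by MockComm"
```

--------------------------------------------------------------------------------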
/packages/python/stitch/stitch/tests/test_example.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 | import pytest
8 |
9 | from ..stitch import StitchWidget
10 |
11 |
12 | def test_example_creation_blank():
13 |     w = StitchWidget()
14 |     assert w.kernelmsg == ""
15 |     assert w.clientmsg == ""
16 |     assert w.srcdoc == "<p>srcdoc should be defined by the user</p>"
17 |
--------------------------------------------------------------------------------
/packages/python/stitch/stitch/tests/test_nbextension_path.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) Guidance Contributors.
5 | # Distributed under the terms of the Modified BSD License.
6 |
7 |
8 | def test_nbextension_path():
9 |     # Check that magic function can be imported from package root:
10 |     from stitch import _jupyter_nbextension_paths
11 |     # Ensure that it can be called without incident:
12 |     path = _jupyter_nbextension_paths()
13 |     # Some sanity checks:
14 |     assert len(path) == 1
15 |     assert isinstance(path[0], dict)
16 |
--------------------------------------------------------------------------------
/packages/python/stitch/tsconfig.eslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "./tsconfig.json",
3 | "include": ["src/**/*.ts", "src/**/*.tsx"],
4 | "exclude": []
5 | }
--------------------------------------------------------------------------------
/packages/python/stitch/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   "compilerOptions": {
3 |     "declaration": true,
4 |     "esModuleInterop": true,
5 |     "lib": ["es2015", "dom"],
6 |     "module": "commonjs",
7 |     "moduleResolution": "node",
8 |     "noEmitOnError": true,
9 |     "noUnusedLocals": true,
10 |     "outDir": "lib",
11 |     "resolveJsonModule": true,
12 |     "rootDir": "src",
13 |     "skipLibCheck": true,
14 |     "sourceMap": true,
15 |     "strict": true,
16 |     "strictPropertyInitialization": false,
17 |     "target": "es2015",
18 |     "types": ["jest"]
19 |   },
20 |   "include": [
21 |     "src/**/*.ts",
22 |     "src/**/*.tsx",
23 |   ],
24 |   "exclude": ["src/**/__tests__"]
25 | }
26 |
--------------------------------------------------------------------------------
/packages/python/stitch/webpack.config.js:
--------------------------------------------------------------------------------
1 | const path = require('path');
2 | const version = require('./package.json').version;
3 |
4 | // Custom webpack rules
5 | const rules = [
6 |   { test: /\.ts$/, loader: 'ts-loader' },
7 |   { test: /\.js$/, loader: 'source-map-loader' },
8 |   { test: /\.css$/, use: ['style-loader', 'css-loader']}
9 | ];
10 |
11 | // Packages that shouldn't be bundled but loaded at runtime
12 | const externals = ['@jupyter-widgets/base'];
13 |
14 | const resolve = {
15 |   // Add '.ts' and '.tsx' as resolvable extensions.
16 |   extensions: [".webpack.js", ".web.js", ".ts", ".js"]
17 | };
18 |
19 | module.exports = [
20 |   /**
21 |    * Notebook extension
22 |    *
23 |    * This bundle only contains the part of the JavaScript that is run on load of
24 |    * the notebook.
25 |    */
26 |   {
27 |     entry: './src/extension.ts',
28 |     output: {
29 |       filename: 'index.js',
30 |       path: path.resolve(__dirname, 'stitch', 'nbextension'),
31 |       libraryTarget: 'amd',
32 |       publicPath: '',
33 |     },
34 |     module: {
35 |       rules: rules
36 |     },
37 |     devtool: 'source-map',
38 |     externals,
39 |     resolve,
40 |   },
41 |
42 |   /**
43 |    * Embeddable @guidance-ai/stitch bundle
44 |    *
45 |    * This bundle is almost identical to the notebook extension bundle. The only
46 |    * difference is in the configuration of the webpack public path for the
47 |    * static assets.
48 |    *
49 |    * The target bundle is always `dist/index.js`, which is the path required by
50 |    * the custom widget embedder.
51 |    */
52 |   {
53 |     entry: './src/index.ts',
54 |     output: {
55 |       filename: 'index.js',
56 |       path: path.resolve(__dirname, 'dist'),
57 |       libraryTarget: 'amd',
58 |       library: "@guidance-ai/stitch",
59 |       publicPath: 'https://unpkg.com/@guidance-ai/stitch@' + version + '/dist/'
60 |     },
61 |     devtool: 'source-map',
62 |     module: {
63 |       rules: rules
64 |     },
65 |     externals,
66 |     resolve,
67 |   },
68 |
69 |
70 |   /**
71 |    * Documentation widget bundle
72 |    *
73 |    * This bundle is used to embed widgets in the package documentation.
74 |    */
75 |   {
76 |     entry: './src/index.ts',
77 |     output: {
78 |       filename: 'embed-bundle.js',
79 |       path: path.resolve(__dirname, 'docs', 'source', '_static'),
80 |       library: "stitch",
81 |       libraryTarget: 'amd'
82 |     },
83 |     module: {
84 |       rules: rules
85 |     },
86 |     devtool: 'source-map',
87 |     externals,
88 |     resolve,
89 |   }
90 |
91 | ];
92 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel",
5 | "pybind11>=2.10.0",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
9 | [tool.pytest.ini_options]
10 | addopts = "--strict-markers"
11 | markers = [
12 | "resource_intensive: Tests which readily exceed resources such as memory",
13 | "asyncio: Make async tests to run"
14 | ]
15 |
16 | [tool.black]
17 | line-length = 99
18 | target_version = ['py39', 'py310', 'py311', 'py312']
19 |
20 | [tool.isort]
21 | profile = "black"
22 |
23 | [tool.mypy]
24 | no_implicit_optional = false # TODO: PEP484, set to true
25 | strict = false # TODO: set to true
26 | exclude = ["tests"]
27 |
28 | [[tool.mypy.overrides]]
29 | module = "pybind11.*"
30 | ignore_missing_imports = true
31 |
32 | [[tool.mypy.overrides]]
33 | module = "vertexai.*"
34 | ignore_missing_imports = true
35 |
36 | [[tool.mypy.overrides]]
37 | module = "google.*"
38 | ignore_missing_imports = true
39 |
40 | [[tool.mypy.overrides]]
41 | module = "llama_cpp.*"
42 | ignore_missing_imports = true
43 |
44 | [[tool.mypy.overrides]]
45 | module = "anthropic.*"
46 | ignore_missing_imports = true
47 |
48 | [[tool.mypy.overrides]]
49 | module = "litellm.*"
50 | ignore_missing_imports = true
51 |
52 | [[tool.mypy.overrides]]
53 | module = "transformers.*"
54 | ignore_missing_imports = true
55 |
56 | [[tool.mypy.overrides]]
57 | module = "diskcache.*"
58 | ignore_missing_imports = true
59 |
60 | [[tool.mypy.overrides]]
61 | module = "tokenizers.*"
62 | ignore_missing_imports = true
--------------------------------------------------------------------------------
/tests/ReadMe.md:
--------------------------------------------------------------------------------
1 | # Testing
2 |
3 | ## Organisation
4 |
5 | The tests are arranged into the following directories:
6 |
7 | - `unit` tests do not depend on LLMs (but may use `models.Mock`)
8 | - `model_integration` tests should run on any (fully supported) model, supplied by the `selected_model` fixture
9 | - `model_specific` tests are for isolating particular issues with individual LLMs
10 | - `need_credentials` tests are for tests which need access to various credentials (mainly `Grammarless` models for endpoints without full Guidance support)
11 | - `notebook` tests are for notebooks
12 |
13 | The `model_specific` tests should make use of the `selected_model` machinery, but skip themselves if the appropriate model is not supplied.
14 | A sample means of achieving this:
15 |
16 | ```python
17 | @pytest.fixture(scope="module")
18 | def phi3_model(selected_model, selected_model_name):
19 |     if selected_model_name in ["transformers_phi3_mini_4k_instruct_cpu"]:
20 |         return selected_model
21 |     else:
22 |         pytest.skip("Requires Phi3 model")
23 | ```
24 |
25 | ## Selecting a model
26 |
27 | To select a particular model when running the tests, use the `--selected_model` command line option.
28 | For example:
29 |
30 | ```bash
31 | python -m pytest --selected_model transformers_gemma2_9b_cpu ./tests/model_integration/
32 | ```
33 |
34 | The allowed values for `--selected_model` are in the [`conftest.py`](./conftest.py) file, and are defined in the `selected_model` function.
35 | Alternatively, the `GUIDANCE_SELECTED_MODEL` environment variable can be used to override the default value for `--selected_model` (which can be useful when using a debugger).
36 |
37 | ### A Note on Credentials
38 |
39 | As noted above the `need_credentials` tests are mainly for `Grammarless` models - those for remote endpoints which do not support Guidance grammars (there are a few exceptions, which is why the directory isn't simply named `grammarless`).
40 | As endpoints with Guidance grammar support come online, their tests should *not* go in there; these should go into `model_integration` and `model_specific`, but will only be run in CI builds.
41 | Similarly, some models (e.g. Llama3) require credentials in order to download their weights from Hugging Face.
42 | These should also be covered by the `model_integration` and `model_specific` tests, but that run will happen from the CI build, and hence have credential access.
43 |
44 | ## Testing Goal
45 |
46 | Ideally, when creating a new feature, most of the tests should go into the `unit` directory, and make use of `models.Mock` if needed.
47 | These should always be runnable with a fairly minimal Guidance installation (plus `pytest`, obviously).
48 | These tests should be fast, and facilitate a developer experience built around running
49 |
50 | ```bash
51 | pytest tests/unit
52 | ```
53 | very frequently.
54 |
55 | There should also be a handful of tests in `model_integration`, which should work with _any_ fully supported Guidance model.
56 | Finally, if any model quirks are noted (and _especially_ if workarounds are required in the code), tests to characterise these should go into `model_specific`.
57 |
58 | In this paradigm, no tests in `unit` or `model_integration` should be using `pytest.skip` (or its variants).
59 | Those in `model_specific` will use `pytest.skip` for when the `selected_model` fixture is not of the appropriate type.
60 |
--------------------------------------------------------------------------------
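To make the `unit` guidance above concrete, here is a sketch of a Mock-backed unit test; the test name and prompt text are illustrative:

```python
from guidance import models, select

def test_select_quality_with_mock():
    # models.Mock needs no real LLM, and select() still constrains the
    # output to one of the listed options.
    lm = models.Mock() + "Scott is quite " + select(["bad", "ok", "good"], name="quality")
    assert lm["quality"] in ["bad", "ok", "good"]
```

--------------------------------------------------------------------------------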
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/__init__.py
--------------------------------------------------------------------------------
/tests/bench/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/bench/__init__.py
--------------------------------------------------------------------------------
/tests/bench/test_api.py:
--------------------------------------------------------------------------------
1 | from guidance.bench._api import bench, AVAILABLE_MODELS
2 | from pathlib import Path
3 | import tempfile
4 | import pytest
5 |
6 | @pytest.mark.skip("Waiting on CI upgrades. Need access to env var LANGCHAIN_API_KEY.")
7 | def test_bench():
8 |     # TODO(nopdive): Parameterize models once CI is upgraded.
9 |     with tempfile.TemporaryDirectory() as tmp_dir:
10 |         db_path = Path(tmp_dir) / "bench.db"
11 |         db_url = f"sqlite:///{db_path}"
12 |         status_df, result_df = bench(db_url, "bench-test", models=AVAILABLE_MODELS[:1], debug_mode=True)
13 |
14 |         assert len(status_df) > 0
15 |         assert len(result_df) > 0
16 |         assert (status_df['status'] == 'COMPLETE').all()
--------------------------------------------------------------------------------
/tests/bench/test_powerlift.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tempfile
3 |
4 | from guidance.bench._powerlift import retrieve_langchain
5 | from pathlib import Path
6 |
7 | def test_retrieve_langchain_err(monkeypatch):
8 |     monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False)
9 |     with pytest.raises(ValueError):
10 |         gen = retrieve_langchain()
11 |         _ = list(gen)
12 |
13 | @pytest.mark.skip("Waiting on CI upgrades. Need access to env var LANGCHAIN_API_KEY.")
14 | def test_retrieve_langchain_basic():
15 |     with tempfile.TemporaryDirectory() as tmp_dir:
16 |         # Run once
17 |         first_results = list(retrieve_langchain(cache_dir=tmp_dir))
18 |         langchain_cache_path = Path(tmp_dir, "langchain")
19 |         assert Path.exists(langchain_cache_path)
20 |
21 |         # Run another time to trigger the cache
22 |         second_results = list(retrieve_langchain(cache_dir=tmp_dir))
23 |         for first, second in zip(first_results, second_results):
24 |             assert first.inputs.equals(second.inputs)
--------------------------------------------------------------------------------
/tests/bench/test_utils.py:
--------------------------------------------------------------------------------
1 | from guidance.bench._utils import lib_bench_dir
2 | import tempfile
3 | from pathlib import Path
4 |
5 | def test_lib_bench_dir_basic():
6 |     expected_dir = Path.home() / ".guidance-bench"
7 |     actual_dir = lib_bench_dir()
8 |
9 |     assert expected_dir == actual_dir
10 |     assert Path.exists(actual_dir)
11 |
12 |
13 | def test_lib_bench_dir_env_var(monkeypatch):
14 |     with tempfile.TemporaryDirectory() as tmp_dir:
15 |         expected_dir = Path(tmp_dir) / "guidance-bench"
16 |         monkeypatch.setenv("GUIDANCE_BENCH_DIR", expected_dir)
17 |
18 |         actual_dir = lib_bench_dir()
19 |         assert expected_dir == actual_dir
20 |         assert Path.exists(actual_dir)
--------------------------------------------------------------------------------
/tests/model_integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/model_integration/__init__.py
--------------------------------------------------------------------------------
/tests/model_integration/library/test_subgrammar.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import pytest
4 | from jsonschema import validate
5 | import json
6 |
7 | import guidance
8 | from guidance import (
9 |     gen,
10 |     select,
11 |     optional,
12 |     one_or_more,
13 | )
14 | from guidance.library._subgrammar import subgrammar, lexeme
15 |
16 |
17 | @guidance(stateless=True)
18 | def json_string(lm):
19 |     return lm + lexeme(r'"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)*"')
20 |
21 |
22 | @guidance(stateless=True)
23 | def json_number(lm):
24 |     return lm + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?")
25 |
26 |
27 | @guidance(stateless=True)
28 | def json_value(lm):
29 |     return lm + select(
30 |         [
31 |             json_string(),
32 |             json_number(),
33 |             json_object(),
34 |             json_array(),
35 |             "true",
36 |             "false",
37 |             "null",
38 |         ]
39 |     )
40 |
41 |
42 | @guidance(stateless=True)
43 | def json_member(lm):
44 |     return lm + json_string() + ":" + json_value()
45 |
46 |
47 | @guidance(stateless=True)
48 | def json_object(lm):
49 |     return lm + "{" + optional(json_member() + one_or_more("," + json_member())) + "}"
50 |
51 |
52 | @guidance(stateless=True)
53 | def json_array(lm):
54 |     return lm + "[" + optional(json_value() + one_or_more("," + json_value())) + "]"
55 |
56 |
57 | @guidance(stateless=True)
58 | def gen_json_object(lm, name: str, max_tokens=100000000):
59 |     grm = subgrammar(
60 |         body=json_object(),
61 |         name=name,
62 |         skip_regex=r"[\x20\x0A\x0D\x09]+",
63 |         max_tokens=max_tokens
64 |     )
65 |     return lm + grm
66 |
67 |
68 | def test_greedy_json_object(selected_model: guidance.models.Model):
69 |     lm = selected_model
70 |     lm += "John Doe's name, age, and birthday:\n"
71 |     lm += gen_json_object("hacker", max_tokens=1000)
72 |     lm += "\nScore: " + gen("score", regex="[1-3]")
73 |     # make sure it parses as JSON
74 |     obj = json.loads(lm["hacker"])
75 |     assert isinstance(obj, dict)
76 |     assert lm["score"] in ["1", "2", "3"]
77 |
78 |
79 | def test_greedy_single_terminal(selected_model: guidance.models.Model):
80 |     lm = selected_model
81 |     lm += "A number: "
82 |     lm += subgrammar(body=lexeme(r"[0-9]{3}"))
83 |     assert re.search(r": [0-9]{3}$", str(lm))
84 |
--------------------------------------------------------------------------------
/tests/model_integration/library/test_substring.py:
--------------------------------------------------------------------------------
1 | from guidance import gen, models, substring
2 |
3 |
4 | def test_substring_equal_unconstrained(selected_model: models.Model):
5 |     target_model = selected_model
6 |     lm = target_model + "ae galera " + gen(max_tokens=10, name="test")
7 |     lm2 = target_model + "ae galera " + substring(lm["test"], name="capture")
8 |     assert lm2["capture"] in lm["test"]
9 |
--------------------------------------------------------------------------------
/tests/model_integration/test_grammar.py:
--------------------------------------------------------------------------------
1 | from guidance import models, select
2 |
3 |
4 | def test_select_simple(selected_model: models.Model):
5 |     lm = selected_model
6 |     options = ["baad I think", "bad I think", "bad"]
7 |     lm = lm + "Scott is quite " + select(name="bad", options=options)
8 |     assert lm["bad"] in options
9 |
--------------------------------------------------------------------------------
/tests/model_integration/test_tokenizers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from tests.tokenizer_common import TOKENIZER_ROUND_TRIP_STRINGS, BaseTestTransformerTokenizers
3 |
4 | # These are not _strictly_ unit tests, since they refer
5 | # to specific tokenisers. However, tokenisers are small,
6 | # so if the tokeniser can be loaded separately from the
7 | # model, then this is a good place to have them live.
8 |
9 | # The LlamaCpp tokenisers have tests in test_llamacpp.py
10 | # since those tokenisers cannot be loaded separately from
11 | # their models.
12 |
13 | # The transformer tests have an authenticated version under
14 | # need_credentials
15 |
16 |
17 | class TestUnauthenticatedTransformerTokenizers(BaseTestTransformerTokenizers):
18 | TRANSFORMER_MODELS = [
19 | "gpt2",
20 | "microsoft/phi-2",
21 | "microsoft/Phi-3-small-8k-instruct",
22 | "microsoft/Phi-3-mini-4k-instruct",
23 | ]
24 |
25 | @pytest.mark.parametrize(
26 | "model_name",
27 | TRANSFORMER_MODELS,
28 | )
29 | def test_smoke(self, model_name: str):
30 | self.base_smoke(model_name)
31 |
32 | @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS)
33 | @pytest.mark.parametrize("target_string", TOKENIZER_ROUND_TRIP_STRINGS)
34 | def test_string_roundtrip(self, model_name: str, target_string: str):
35 | self.base_string_roundtrip(model_name, target_string)
36 |
37 | @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS)
38 | def test_eos_bos_token_round_trip(self, model_name: str):
39 | self.base_eos_bos_token_round_trip(model_name)
40 |
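41 |
42 | # Illustrative sketch (underscore-prefixed so pytest does not collect it): what
43 | # "loading the tokeniser separately from the model" means concretely, using the
44 | # helper exercised in tests/tokenizer_common.py; "gpt2" is just an example model.
45 | def _sketch_standalone_tokenizer_roundtrip():
46 |     from guidance import models
47 |
48 |     tok = models.TransformersTokenizer.from_pretrained("gpt2")
49 |     ids = tok.encode(b"two words")  # encode takes bytes and returns token ids
50 |     assert tok.decode(ids) == b"two words"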
--------------------------------------------------------------------------------
/tests/model_specific/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/model_specific/__init__.py
--------------------------------------------------------------------------------
/tests/model_specific/common_chat_testing.py:
--------------------------------------------------------------------------------
1 | from guidance import assistant, gen, models, system, user
2 |
3 |
4 | def smoke_chat(lm: models.Model, has_system_role: bool = True):
5 | # lm.engine.reset_metrics()
6 | if has_system_role:
7 | with system():
8 | lm += "You are a math wiz."
9 |
10 | with user():
11 | lm += "What is 1 + 1?"
12 |
13 | with assistant():
14 | lm += gen(max_tokens=10, name="text", temperature=0.5)
15 |
16 | print(str(lm))
17 | # print(f"{lm.engine.metrics=}")
18 | assert len(lm["text"]) > 0
19 | # assert lm.engine.metrics.engine_input_tokens > 2, "Expect some input tokens"
20 | # assert lm.engine.metrics.engine_output_tokens > 0, "Expect some output tokens"
21 |
22 |
23 | def longer_chat_1(lm: models.Model, has_system_role: bool = True):
24 | if has_system_role:
25 | with system():
26 | lm += "You are a math wiz."
27 |
28 | with user():
29 | lm += "What is 1 + 1?"
30 |
31 | with assistant():
32 | lm += gen(max_tokens=10, name="text")
33 |
34 | print(str(lm))
35 | assert len(lm["text"]) > 0
36 |
37 | with user():
38 | lm += "10. Now you pick a number between 0 and 20"
39 |
40 | with assistant():
41 | lm += gen(max_tokens=2, name="number")
42 |
43 | print(str(lm))
44 | assert len(lm["number"]) > 0
45 |
46 |
47 | def longer_chat_2(lm: models.Model, has_system_role: bool = True):
48 | if has_system_role:
49 | with system():
50 | lm += "You are a math wiz."
51 |
52 | with user():
53 | lm += "What is 1 + 1?"
54 |
55 | # This is the new part compared to longer_chat_1
56 | with assistant():
57 | lm += "2"
58 |
59 | with user():
60 | lm += "What is 2 + 3?"
61 |
62 | # Resume the previous
63 | with assistant():
64 | lm += gen(max_tokens=10, name="text")
65 |
66 | print(str(lm))
67 | assert len(lm["text"]) > 0
68 |
69 | with user():
70 | lm += "10. Now you pick a number between 0 and 20"
71 |
72 | with assistant():
73 | lm += gen(max_tokens=2, name="number")
74 |
75 | print(str(lm))
76 | assert len(lm["number"]) > 0
77 |
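78 |
79 | # For reference: these helpers are driven by the credentialed endpoint tests;
80 | # tests/need_credentials/test_azureai_studio.py, for example, does roughly:
81 | #
82 | #     lm = _get_chat_model("phi4")
83 | #     common_chat_testing.smoke_chat(lm, has_system_role=True)
84 | #     common_chat_testing.longer_chat_1(lm)
85 | #     common_chat_testing.longer_chat_2(lm)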
--------------------------------------------------------------------------------
/tests/model_specific/llama_cpp_tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Set up for pytest
--------------------------------------------------------------------------------
/tests/model_specific/llama_cpp_tests/test_chat_templates.py:
--------------------------------------------------------------------------------
1 | import jinja2
2 | import pytest
3 |
4 | import guidance
5 |
6 | from guidance.chat import CHAT_TEMPLATE_CACHE
7 |
8 |
9 | def test_chat_format_smoke(llamacpp_model: guidance.models.LlamaCpp, selected_model_name):
10 | # Retrieve the template string
11 | if (
12 | hasattr(llamacpp_model.engine.model_obj, "metadata")
13 | and "tokenizer.chat_template" in llamacpp_model.engine.model_obj.metadata
14 | ):
15 | model_chat_template = llamacpp_model.engine.model_obj.metadata["tokenizer.chat_template"]
16 | else:
17 | pytest.skip("Chat template not available from LlamaCpp object")
18 |
19 | lm = guidance.models.Mock("")
20 | lm._interpreter.chat_template = CHAT_TEMPLATE_CACHE[model_chat_template]()
21 |
22 | messages = [
23 | {"role": "user", "content": "Good_day_to_you!"},
24 | {"role": "assistant", "content": "Hello!"},
25 | ]
26 |
27 |     # Note that llama-cpp-python does provide a llama_chat_apply_template
28 |     # function, but documentation on its use is thin on the ground, and according to
29 |     # https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
30 |     # it does its own thing internally.
31 | jinja2_template = jinja2.Environment(loader=jinja2.BaseLoader()).from_string(
32 | model_chat_template
33 | )
34 | jinja2_render = jinja2_template.render(
35 | messages=messages,
36 | bos_token=llamacpp_model.engine.tokenizer.bos_token.decode(),
37 | eos_token=llamacpp_model.engine.tokenizer.eos_token.decode(),
38 | )
39 |
40 | with guidance.user():
41 | lm += "Good_day_to_you!"
42 | with guidance.assistant():
43 | lm += "Hello!"
44 | # Only check substring due to BOS/EOS tokens
45 | if selected_model_name == "llamacpp_mistral_7b_cpu":
46 | # The templates extracted via Transformers and GGUF are somewhat
47 | # different for Mistral. This is showing up in slightly
48 | # different spacing (our template is putting in a few extra spaces)
49 | # so at least make sure the 'tags' are correct
50 | assert str(lm).replace(" ", "") in jinja2_render.replace(" ", "")
51 | else:
52 | assert str(lm) in jinja2_render
53 |
--------------------------------------------------------------------------------
/tests/model_specific/test_visual.py:
--------------------------------------------------------------------------------
1 | from guidance.registry import get_renderer
2 |
3 |
4 | def test_repeat_simple_model():
5 | from guidance.models import Transformers
6 | from guidance import gen
7 | from guidance.registry import set_renderer, get_trace_handler
8 | from guidance.visual import JupyterWidgetRenderer
9 |
10 | trace_handler = get_trace_handler()
11 | original_renderer = get_renderer()
12 | for i in range(2):
13 | set_renderer(JupyterWidgetRenderer(trace_handler))
14 |
15 | lm = Transformers('gpt2')
16 | lm += 'Hi hi hi'
17 | lm += gen(max_tokens=5)
18 |
19 | set_renderer(original_renderer)
20 |
21 |     assert True  # smoke test: passing means the renderer swap loop ran without errors
22 |
23 |
24 | def test_roles():
25 | from guidance.models import Transformers
26 | from guidance import gen, user, system
27 |
28 | m0 = Transformers("gpt2")
29 | with system():
30 | m1 = m0 + "You are responsible for writing an epic poem."
31 | with user():
32 | m2 = m1 + "Roses are red and " + gen(name="suffix", regex=r'[\w\s]{20,30}', max_tokens=30)
33 |
34 | assert m2 is not None
--------------------------------------------------------------------------------
/tests/need_credentials/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/need_credentials/__init__.py
--------------------------------------------------------------------------------
/tests/need_credentials/test_anthropic.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import guidance
4 | from guidance import assistant, capture, gen, select, system, user
5 |
6 | from ..utils import get_model
7 |
8 |
9 | def test_anthropic_chat():
10 | try:
11 | lm = guidance.models.Anthropic(model="claude-3-haiku-20240307")
12 |     except Exception:
13 | pytest.skip("Skipping Anthropic test because we can't load the model!")
14 | with system():
15 | lm += "You are a math wiz."
16 |
17 | with user():
18 | lm += "What is 1 + 1?"
19 |
20 | with assistant():
21 | lm += gen(max_tokens=10, name="text")
22 | lm += "Pick a number: "
23 |
24 | assert len(lm["text"]) > 0
25 |
26 |
27 | def test_anthropic_select():
28 | try:
29 | lm = guidance.models.Anthropic(model="claude-instant-1.2")
30 |     except Exception:
31 | pytest.skip("Skipping Anthropic test because we can't load the model!")
32 |
33 | # We can't meaningfully test or enforce select on this model
34 | with pytest.raises(guidance.models._model.ConstraintException):
35 | with user():
36 | lm += "Write the next number in the list: 1,2,3,4,5,6,"
37 | with assistant():
38 | lm += select(
39 | ["harsha", "scott", "marco"], name="the number"
40 | )
41 |
42 |
43 | def test_anthropic_chat_loop():
44 | # tests issue #509
45 | try:
46 | model = guidance.models.Anthropic(model="claude-3-haiku-20240307")
47 |     except Exception:
48 | pytest.skip("Skipping Anthropic test because we can't load the model!")
49 |
50 | for i in range(2):
51 |
52 | with system():
53 | lm = model + "You will just return whatever number I give you"
54 |
55 | with user():
56 | lm += f"The number is: {i}"
57 |
58 | with assistant():
59 | lm += gen(name="answer", max_tokens=2)
60 |
61 | # def test_direct_anthropic_api():
62 | # import anthropic
63 |
64 | # client = anthropic.Anthropic()
65 |
66 | # with client.messages.stream(
67 | # max_tokens=10,
68 | # system="You are a counting robot. Do nothing but continue counting numbers in the same format the user presented.",
69 | # messages=[{"role": "user", "content": "1,2,3,4,5,"}],
70 | # model="claude-3-haiku-20240307",
71 | # ) as stream:
72 | # text_list = []
73 | # for text in stream.text_stream:
74 | # print(text, end="", flush=True)
75 | # text_list.append(text)
76 |
77 | # assert len(text_list) > 0
--------------------------------------------------------------------------------
/tests/need_credentials/test_azureai_studio.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from guidance import models
4 | from guidance.models._azureai import create_azure_aifoundry_model
5 |
6 |
7 | from ..model_specific import common_chat_testing
8 | from ..utils import env_or_fail, slowdown
9 |
10 | # pytest.skip("Deployments temporarily deleted", allow_module_level=True)
11 |
12 | # How to fill out the environment variables to
13 | # set up the models
14 | # Temporarily remove mistral pending endpoint investigation
15 | # _chat_models = {"phi3": "PHI3", "llama3": "LLAMA3_CHAT"}
16 | _chat_models = {"phi4": "PHI4"}
17 |
18 |
19 | def _get_chat_model(model_name: str):
20 | env_string = _chat_models[model_name]
21 |
22 | azureai_studio_endpoint = env_or_fail(f"AZUREAI_STUDIO_{env_string}_ENDPOINT")
23 | azureai_studio_model_name = env_or_fail(f"AZUREAI_STUDIO_{env_string}_MODEL_NAME")
24 | azureai_studio_key = env_or_fail(f"AZUREAI_STUDIO_{env_string}_KEY")
25 |
26 | lm = create_azure_aifoundry_model(
27 | azure_endpoint=azureai_studio_endpoint,
28 | api_key=azureai_studio_key,
29 | # token_credential=DefaultAzureCredential(),
30 | model_name=azureai_studio_model_name,
31 | )
32 | assert isinstance(lm, models.Model)
33 | return lm
34 |
35 |
36 | @pytest.mark.parametrize("chat_model_name", _chat_models.keys())
37 | def test_azureai_chat_smoke(chat_model_name: str):
38 | slowdown()
39 |
40 | lm = _get_chat_model(chat_model_name)
41 |
42 | common_chat_testing.smoke_chat(lm, chat_model_name != "mistral")
43 |
44 |
45 | @pytest.mark.parametrize("chat_model_name", _chat_models.keys())
46 | def test_azureai_chat_longer_1(chat_model_name: str):
47 | slowdown()
48 |
49 | lm = _get_chat_model(chat_model_name)
50 | common_chat_testing.longer_chat_1(lm, chat_model_name != "mistral")
51 |
52 |
53 | @pytest.mark.parametrize("chat_model_name", _chat_models.keys())
54 | def test_azureai_chat_longer_2(chat_model_name: str):
55 | slowdown()
56 |
57 | lm = _get_chat_model(chat_model_name)
58 | common_chat_testing.longer_chat_2(lm, chat_model_name != "mistral")
59 |
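60 | # For reference: with the "phi4" entry in _chat_models above, _get_chat_model
61 | # reads these environment variables:
62 | #
63 | #     AZUREAI_STUDIO_PHI4_ENDPOINT
64 | #     AZUREAI_STUDIO_PHI4_MODEL_NAME
65 | #     AZUREAI_STUDIO_PHI4_KEY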
--------------------------------------------------------------------------------
/tests/need_credentials/test_cohere.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import guidance
4 | from guidance import assistant, capture, gen, role, select, system, user
5 |
6 |
7 | def test_cohere_basic():
8 | try:
9 | lm = guidance.models.CohereCompletion("command-nightly")
10 |     except Exception:
11 | pytest.skip("Skipping Cohere test because we can't load the model!")
12 | lm += "Count to 20: 1,2,3,4,"
13 | nl = "\n"
14 | lm += f"""\
15 | 5,6,7"""
16 | lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
17 | assert str(lm)[-5:] == "aaaaa"
18 |
19 |
20 | def test_cohere_instruct():
21 | try:
22 | lm = guidance.models.CohereInstruct("command-nightly")
23 |     except Exception:
24 |         pytest.skip("Skipping Cohere test because we can't load the model!")
25 | with role("instruction"):
26 | lm += "Count to 20."
27 | lm += gen("val", max_tokens=1)
28 | assert len(lm["val"]) > 0
29 |
--------------------------------------------------------------------------------
/tests/need_credentials/test_googleai.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from guidance import gen, models, select
4 |
5 | from ..utils import get_model
6 |
7 |
8 | def test_gemini_pro():
9 | from guidance import assistant, gen, models, system, user
10 |
11 | try:
12 | vmodel = models.GoogleAI("gemini-pro")
13 |     except Exception:
14 | pytest.skip("Skipping GoogleAI test because we can't load the model!")
15 |
16 | lm = vmodel
17 |
18 | with user():
19 | lm += "The economy is crashing!"
20 |
21 | with assistant():
22 | lm += gen("test1", max_tokens=100)
23 |
24 | with user():
25 | lm += "What is the best again?"
26 |
27 | with assistant():
28 | lm += gen("test2", max_tokens=100)
29 |
30 | assert len(lm["test1"]) > 0
31 | assert len(lm["test2"]) > 0
32 |
33 | # second time to make sure cache reuse is okay
34 | lm = vmodel
35 |
36 | with user():
37 | lm += "The economy is crashing!"
38 |
39 | with assistant():
40 | lm += gen("test1", max_tokens=100)
41 |
42 | with user():
43 | lm += "What is the best again?"
44 |
45 | with assistant():
46 | lm += gen("test2", max_tokens=100)
47 |
48 | assert len(lm["test1"]) > 0
49 | assert len(lm["test2"]) > 0
50 | assert lm["test1"].find("<|im_end|>") < 0
51 |
--------------------------------------------------------------------------------
/tests/need_credentials/test_lite_llm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import guidance
4 | from guidance import assistant, capture, gen, select, system, user
5 |
6 | from ..utils import get_model
7 |
8 |
9 | def test_lite_llm_basic_openai():
10 | try:
11 | lm = guidance.models.LiteLLMCompletion("gpt-3.5-turbo-instruct")
12 |     except Exception:
13 | pytest.skip("Skipping LiteLLM test because we can't load the model!")
14 | lm += "Count to 20: 1,2,3,4,"
15 | nl = "\n"
16 | lm += f"""\
17 | 5,6,7"""
18 | lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
19 | assert str(lm)[-5:] == "aaaaa"
20 |
21 |
22 | def test_lite_llm_basic_cohere():
23 | try:
24 | lm = guidance.models.LiteLLMCompletion("command-nightly")
25 |     except Exception:
26 | pytest.skip("Skipping LiteLLM test because we can't load the model!")
27 | lm += "Count to 20: 1,2,3,4,"
28 | nl = "\n"
29 | lm += f"""\
30 | 5,6,7"""
31 | lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa"""
32 | assert str(lm)[-5:] == "aaaaa"
33 |
34 |
35 | def test_lite_llm_select():
36 | try:
37 | lm = guidance.models.LiteLLMCompletion("gpt-3.5-turbo-instruct")
38 |     except Exception:
39 | pytest.skip("Skipping LiteLLM test because we can't load the model!")
40 | lm += "Pick a number: "
41 | lm += select(
42 | ["1", "11", "111", "1111", "11111", "111111", "1111111"], name="the number"
43 | )
44 | assert str(lm)[-1] in "123"
45 |
46 |
47 | def test_lite_llm_chat():
48 | try:
49 | lm = guidance.models.LiteLLMChat("gpt-3.5-turbo")
50 |     except Exception:
51 | pytest.skip("Skipping LiteLLM test because we can't load the model!")
52 | with system():
53 | lm += "You are a math wiz."
54 |
55 | with user():
56 | lm += "What is 1 + 1?"
57 |
58 | with assistant():
59 | lm += gen(max_tokens=10, name="text")
60 | lm += "Pick a number: "
61 |
62 | assert len(lm["text"]) > 0
63 |
--------------------------------------------------------------------------------
/tests/need_credentials/test_togetherai.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import guidance
4 | from guidance import assistant, gen, select, system, user
5 |
6 |
7 | def test_togetherai_basic():
8 | try:
9 | lm = guidance.models.TogetherAI("mistralai/Mistral-7B-v0.1")
10 |     except Exception:
11 | pytest.skip("Skipping TogetherAI test because we can't load the model!")
12 | lm += "Count to 20: 1,2,3,4,"
13 | stop = "\n"
14 | lm += f"""{gen(max_tokens=1, stop=stop, name="text")}"""
15 | assert str(lm)[-1] == "5"
16 |
17 |
18 | def test_togetherai_select():
19 | try:
20 | lm = guidance.models.TogetherAI("mistralai/Mistral-7B-v0.1")
21 |     except Exception:
22 | pytest.skip("Skipping TogetherAI test because we can't load the model!")
23 | nums = ["1", "11", "111", "1111", "11111", "111111", "1111111"]
24 | lm += "Pick a number: "
25 | lm += select(nums, name="number")
26 | assert str(lm["number"]) in nums
27 |
28 |
29 | def test_togetherai_chat():
30 | try:
31 | lm = guidance.models.TogetherAIChat("teknium/OpenHermes-2-Mistral-7B")
32 |     except Exception:
33 | pytest.skip("Skipping TogetherAI test because we can't load the model!")
34 | with system():
35 | lm += "You are a math wiz."
36 |
37 | with user():
38 | lm += "What is 1 + 1?"
39 |
40 | with assistant():
41 | lm += gen(max_tokens=10, name="text")
42 | lm += "Pick a number: "
43 |
44 | assert len(lm["text"]) > 0
45 |
46 |
47 | def test_togetherai_chat_without_roles():
48 | try:
49 | lm = guidance.models.TogetherAIChat("teknium/OpenHermes-2-Mistral-7B")
50 |     except Exception:
51 | pytest.skip("Skipping TogetherAI test because we can't load the model!")
52 |     with pytest.raises(ValueError):
53 | lm += "You are a math wiz. What is 1+1?" + gen(max_tokens=10, name="text")
54 |
55 |
56 | def test_togetherai_chat_loop():
57 | try:
58 | model = guidance.models.TogetherAIChat(
59 | "teknium/OpenHermes-2-Mistral-7B", echo=False
60 | )
61 |     except Exception:
62 | pytest.skip("Skipping TogetherAI test because we can't load the model!")
63 |
64 | with system():
65 | lm = model + "You will just return whatever number I give you"
66 |
67 | for i in range(2):
68 | with user():
69 | lm += f"The number is: {i}"
70 |
71 | with assistant():
72 | lm += gen(name="answer", max_tokens=10)
73 | assert len(lm["answer"]) > 0
74 |
--------------------------------------------------------------------------------
/tests/need_credentials/test_tokenizers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from guidance import models
4 |
5 | from tests.tokenizer_common import TOKENIZER_ROUND_TRIP_STRINGS, BaseTestTransformerTokenizers
6 |
7 | # Since this is under 'need_credentials' we can assume that HF_TOKEN
8 | # will be available when run
9 |
10 |
11 | class TestAuthenticatedTransformerTokenizers(BaseTestTransformerTokenizers):
12 | TRANSFORMER_MODELS = [
13 | # "google/gemma-2-9b-it", # Works locally, fails in build
14 | "meta-llama/Meta-Llama-3-8B-Instruct",
15 | ]
16 |
17 | @pytest.mark.parametrize(
18 | "model_name",
19 | TRANSFORMER_MODELS,
20 | )
21 | def test_smoke(self, model_name: str):
22 | try:
23 | self.base_smoke(model_name)
24 | except OSError:
25 | pytest.skip("HuggingFace raises OSError if user is not authenticated.")
26 |
27 | @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS)
28 | @pytest.mark.parametrize("target_string", TOKENIZER_ROUND_TRIP_STRINGS)
29 | def test_string_roundtrip(self, model_name: str, target_string: str):
30 | try:
31 | self.base_string_roundtrip(model_name, target_string)
32 | except OSError:
33 | pytest.skip("HuggingFace raises OSError if user is not authenticated.")
34 |
--------------------------------------------------------------------------------
/tests/need_credentials/test_vertexai.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from guidance import gen, role, models, select
4 |
5 | from ..utils import get_model
6 |
7 |
8 | def test_palm2_instruct():
9 | try:
10 | vmodel = models.VertexAI("text-bison@001")
11 |     except Exception:
12 | pytest.skip("Skipping VertexAI test because we can't load the model!")
13 |
14 | with role("instruction"):
15 | lm = vmodel + "this is a test about"
16 | lm += gen("test", max_tokens=100)
17 | assert len(lm["test"]) > 0
18 |
19 |
20 | def test_palm2_chat():
21 | from guidance import assistant, gen, models, system, user
22 |
23 | try:
24 | vmodel = models.VertexAI("chat-bison@001")
25 |     except Exception:
26 | pytest.skip("Skipping VertexAI test because we can't load the model!")
27 |
28 | with system():
29 | lm = vmodel + "You are an always-happy agent no matter what."
30 |
31 | with user():
32 | lm += "The economy is crashing!"
33 |
34 | with assistant():
35 | lm += gen("test1", max_tokens=100)
36 |
37 | with user():
38 | lm += "What is the best again?"
39 |
40 | with assistant():
41 | lm += gen("test2", max_tokens=100)
42 |
43 | assert len(lm["test1"]) > 0
44 | assert len(lm["test2"]) > 0
45 |
46 | # second time to make sure cache reuse is okay
47 | with system():
48 | lm = vmodel + "You are an always-happy agent no matter what."
49 |
50 | with user():
51 | lm += "The economy is crashing!"
52 |
53 | with assistant():
54 | lm += gen("test1", max_tokens=100)
55 |
56 | with user():
57 | lm += "What is the best again?"
58 |
59 | with assistant():
60 | lm += gen("test2", max_tokens=100)
61 |
62 | assert len(lm["test1"]) > 0
63 | assert len(lm["test2"]) > 0
64 | assert lm["test1"].find("<|im_end|>") < 0
65 |
66 |
67 | def test_gemini_chat():
68 | from guidance import assistant, gen, models, system, user
69 |
70 | try:
71 | vmodel = models.VertexAI("gemini-pro")
72 |     except Exception:
73 | pytest.skip("Skipping VertexAI test because we can't load the model!")
74 |
75 | lm = vmodel
76 |
77 | with user():
78 | lm += "The economy is crashing!"
79 |
80 | with assistant():
81 | lm += gen("test1", max_tokens=100)
82 |
83 | with user():
84 | lm += "What is the best again?"
85 |
86 | with assistant():
87 | lm += gen("test2", max_tokens=100)
88 |
89 | assert len(lm["test1"]) > 0
90 | assert len(lm["test2"]) > 0
91 |
92 | # second time to make sure cache reuse is okay
93 | lm = vmodel
94 |
95 | with user():
96 | lm += "The economy is crashing!"
97 |
98 | with assistant():
99 | lm += gen("test1", max_tokens=100)
100 |
101 | with user():
102 | lm += "What is the best again?"
103 |
104 | with assistant():
105 | lm += gen("test2", max_tokens=100)
106 |
107 | assert len(lm["test1"]) > 0
108 | assert len(lm["test2"]) > 0
109 | assert lm["test1"].find("<|im_end|>") < 0
110 |
--------------------------------------------------------------------------------
/tests/notebooks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/notebooks/__init__.py
--------------------------------------------------------------------------------
/tests/tokenizer_common.py:
--------------------------------------------------------------------------------
1 | from guidance import models
2 |
3 | TOKENIZER_ROUND_TRIP_STRINGS = [
4 | "",
5 | " ",
6 | "hello",
7 | " hello",
8 | "two words",
9 | " two words",
10 | " two words ",
11 | "two words ",
12 | "’",
13 | "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨",
14 | ]
15 |
16 |
17 | class BaseTestTransformerTokenizers:
18 | def base_smoke(self, model_name: str):
19 | my_tok = models.TransformersTokenizer.from_pretrained(
20 | model_name, trust_remote_code=True,
21 | )
22 | assert my_tok is not None
23 |
24 | def base_string_roundtrip(self, model_name: str, target_string: str):
25 | my_tok = models.TransformersTokenizer.from_pretrained(
26 | model_name,
27 | trust_remote_code=True,
28 | )
29 |
30 | encoded = my_tok.encode(target_string.encode())
31 | decoded = my_tok.decode(encoded)
32 | final_string = decoded.decode()
33 |
34 | assert final_string == target_string
35 |
36 | def base_eos_bos_token_round_trip(
37 | self, model_name: str
38 | ):
39 | my_tok = models.TransformersTokenizer.from_pretrained(
40 | model_name,
41 | trust_remote_code=True,
42 | )
43 |
44 | assert my_tok.eos_token == my_tok.decode([my_tok.eos_token_id])
45 | assert my_tok.encode(my_tok.eos_token) == [my_tok.eos_token_id]
46 |
47 | if my_tok.bos_token is not None:
48 | assert my_tok.bos_token == my_tok.decode([my_tok.bos_token_id])
49 | assert my_tok.encode(my_tok.bos_token) == [my_tok.bos_token_id]
50 |
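51 |
52 | # For reference: concrete test classes subclass this base and parametrize over
53 | # their own model lists (see tests/model_integration/test_tokenizers.py and
54 | # tests/need_credentials/test_tokenizers.py); in sketch form:
55 | #
56 | #     class TestMyTokenizers(BaseTestTransformerTokenizers):
57 | #         TRANSFORMER_MODELS = ["gpt2"]
58 | #
59 | #         @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS)
60 | #         def test_smoke(self, model_name: str):
61 | #             self.base_smoke(model_name)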
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests/unit/library/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/unit/library/__init__.py
--------------------------------------------------------------------------------
/tests/unit/library/json/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/unit/library/json/__init__.py
--------------------------------------------------------------------------------
/tests/unit/library/json/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from functools import partial
3 | from json import dumps as json_dumps
4 | from json import loads as json_loads
5 | from typing import Any, Optional, Union
6 |
7 | from jsonschema import validate
8 |
9 | from guidance import json as gen_json
10 | from guidance.library._json import JSONSchema
11 |
12 | from ....utils import check_match_failure as _check_match_failure
13 | from ....utils import check_run_with_temperature
14 | from ....utils import generate_and_check as _generate_and_check
15 |
16 |
17 | def generate_and_check(
18 | target_obj: Any,
19 | schema_obj: Union[str, JSONSchema],
20 | desired_temperature: Optional[float] = None,
21 | ):
22 | if isinstance(schema_obj, str):
23 | schema_obj = json_loads(schema_obj)
24 |
25 | # Sanity check what we're being asked
26 | validate(instance=target_obj, schema=schema_obj)
27 | prepared_json = json_dumps(target_obj)
28 | assert json.loads(prepared_json) == target_obj
29 |
30 | # Now test that the grammar can recognize and generate prepared_json
31 | # We partial in the grammar_callable
32 | if desired_temperature is not None:
33 | grammar_callable = partial(gen_json, schema=schema_obj, temperature=desired_temperature)
34 | else:
35 | grammar_callable = partial(gen_json, schema=schema_obj)
36 |
37 | lm = _generate_and_check(
38 | grammar_callable,
39 | test_string=prepared_json,
40 | )
41 | check_run_with_temperature(lm, desired_temperature)
42 |
43 |
44 | def check_match_failure(
45 | *,
46 | bad_string: str,
47 | good_bytes: Optional[bytes] = None,
48 | failure_byte: Optional[bytes] = None,
49 | allowed_bytes: Optional[set[bytes]] = None,
50 | schema_obj: Union[str, JSONSchema],
51 | ):
52 | grammar = gen_json(schema=schema_obj)
53 |
54 | _check_match_failure(
55 | bad_string=bad_string,
56 | good_bytes=good_bytes,
57 | failure_byte=failure_byte,
58 | allowed_bytes=allowed_bytes,
59 | grammar=grammar,
60 | )
61 |
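62 |
63 | # Illustrative sketch (underscore-prefixed so pytest does not collect it):
64 | # typical use of the two helpers above; the schema is just an example.
65 | def _sketch_helper_usage():
66 |     schema = {"type": "object", "properties": {"a": {"type": "integer"}}, "required": ["a"]}
67 |     # Round-trip a conforming object through the grammar
68 |     generate_and_check({"a": 1}, schema)
69 |     # A string where an integer is required should fail to match
70 |     check_match_failure(bad_string='{"a": "x"}', schema_obj=schema)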
--------------------------------------------------------------------------------
/tests/unit/library/test_block.py:
--------------------------------------------------------------------------------
1 | from guidance import regex, block, models
2 | import pytest
3 |
4 |
5 | def test_text_opener():
6 | model = models.Mock("open texta")
7 | with block(opener="open text"):
8 | model += regex(r".")
9 | assert str(model) == "open texta"
10 |
11 |
12 | def test_text_closer():
13 | # NOTE(nopdive): Behavioral change, no longer need closer for str call.
14 | model = models.Mock("a")
15 | model += ""
16 | with block(closer="close text"):
17 | model += regex(r".")
18 | assert str(model) == "a"
19 |
20 |
21 | def test_grammar_opener():
22 | model = models.Mock("open texta")
23 | with block(opener="open tex" + regex(r".")):
24 | model += regex(r".")
25 | assert str(model) == "open texta"
26 |
27 |
28 | # TODO(nopdive): Review this exception later -- how should we be going about grammars in blocks overall.
29 | @pytest.mark.skip(reason="requires review")
30 | def test_grammar_closer():
31 | model = models.Mock(["aclose text", "close text"])
32 | model += ""
33 | try:
34 | with block(closer=regex(r".") + "lose text"):
35 | model += regex(r".")
36 |     except Exception:
37 | return # we expect an exception
38 | assert (
39 | False
40 | ), "We should have thrown an exception using a context (prompt) based grammar in the closer!"
41 |
42 |
43 | def test_block_name_capture():
44 | model = models.Mock("open texta")
45 | with block("my_data"):
46 | model += "open text"
47 | model += regex(r".")
48 | assert model["my_data"] == "open texta"
49 |
50 |
51 | def test_block_name_capture_closed():
52 | model = models.Mock("open texta")
53 | with block("my_data"):
54 | model += "open text"
55 | model += regex(r".")
56 | model += "tmp"
57 | assert model["my_data"] == "open texta"
58 |
--------------------------------------------------------------------------------
/tests/unit/library/test_capture.py:
--------------------------------------------------------------------------------
1 | from guidance import capture, models, one_or_more, select, guidance
2 |
3 |
4 | def test_capture():
5 | model = models.Mock()
6 | model += "This is" + capture(select(options=["bad", "quite bad"]), name="my_var")
7 | assert model["my_var"] in ["bad", "quite bad"]
8 |
9 |
10 | def test_capture_star():
11 | lm = models.Mock(b"1234233234")
12 | grammar = capture(one_or_more(select(["1", "2"])), name="test")
13 | lm2 = lm + grammar
14 | assert lm2["test"] == "12"
15 |
16 |
17 | def test_capture_raw_function():
18 | lm = models.Mock(b"1234233234")
19 | lm += select(["1", "2"], name="state")
20 |
21 | @guidance
22 | def raw_fn(lm):
23 | if lm["state"] == "1":
24 | lm += select(["3", "4"], name="state_1")
25 | elif lm["state"] == "2":
26 | lm += select(["5", "6"], name="state_2")
27 | return lm
28 |
29 | lm_nocap = lm + "the beginning|" + raw_fn() + "|the end"
30 |     lm_cap_arg = lm + "the beginning|" + capture("" + raw_fn() + "", "cap_arg") + "|the end"
31 | lm_cap_kwarg = lm + "the beginning|" + capture("" + raw_fn() + "", name="cap_kwarg") + "|the end"
32 |
33 | # Bunch of random tests
34 | assert "state_1" in lm_nocap or "state_2" in lm_nocap
35 | assert "cap_arg" in lm_cap_arg
36 | assert "cap_kwarg" in lm_cap_kwarg
37 | assert lm_cap_arg["cap_arg"].startswith("")
38 | assert lm_cap_arg["cap_arg"].endswith("")
39 | assert lm_cap_kwarg["cap_kwarg"].startswith("")
40 | assert lm_cap_kwarg["cap_kwarg"].endswith("")
41 | assert len(lm_cap_arg["cap_arg"]) == len(lm_cap_kwarg["cap_kwarg"])
42 |
43 | assert str(lm_nocap).endswith("|the end")
44 | assert str(lm_cap_arg).endswith("|the end")
45 | assert str(lm_cap_kwarg).endswith("|the end")
--------------------------------------------------------------------------------
/tests/unit/library/test_image.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import uuid
3 | import requests
4 | import tempfile
5 | import pathlib
6 |
7 | from urllib.error import HTTPError, URLError
8 | from guidance import models, image
9 | from ...utils import remote_image_url
10 |
11 | #################################################################################
12 | # The tests below need to be rewritten once multimodal support is complete
13 | # A pseudocode description has been written in comments to preserve notes about
14 | # what was tested, for reference, in case we want to reproduct it in the new system
15 | #################################################################################
16 |
17 | def test_local_image():
18 | # 1. Create a mock model
19 | # 2. Add an image from the local filesystem to the model's prompt
20 | # 3. Validate that the model contains the image in its prompt
21 | pass
22 |
23 |
24 | def test_local_image_not_found():
25 | # 1. Create a mock model
26 | # 2. Try to add a non-existing image from the local filesystem to the model's prompt
27 | # 3. Check for a file not found error, or other appropriate exception, to be thrown
28 | pass
29 |
30 |
31 | def test_remote_image():
32 | # 1. Create a mock model
33 | # 2. Add a remote image from picsum using remote_image_url() utility function
34 | # 3. Validate that the model contains the image in its prompt
35 | pass
36 |
37 |
38 | def test_remote_image_not_found():
39 | # 1. Create a mock model
40 | # 2. Try to add a non-existing remote image
41 | # 3. Catch an HTTPError or URLError from the model trying to fetch the image, which should result in a 404
42 | pass
43 |
44 |
45 | def test_image_from_bytes():
46 | # 1. Create a mock model
47 | # 2. Download an image from remote_image_url() and save it as a binary file
48 | # 3. Read the binary file and add it to the model's prompt as an image
49 |     # 4. Validate that the model contains the image in its prompt
50 | pass
51 |
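52 | # Illustrative sketch of scenario 1 above, kept as a comment because the exact
53 | # multimodal call shape is an assumption pending the rework:
54 | #
55 | #     lm = models.Mock()
56 | #     lm += "Look at this: " + image("path/to/local.jpg")
57 | #     # ...then assert that the image bytes ended up in the model state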
--------------------------------------------------------------------------------
/tests/unit/library/test_one_or_more.py:
--------------------------------------------------------------------------------
1 | from guidance import models, one_or_more, regex
2 |
3 |
4 | def test_string():
5 | model = models.Mock("aaabc")
6 | assert str(model + "" + one_or_more("a")) == "aaa"
7 |
8 |
9 | def test_grammar():
10 | model = models.Mock("bac")
11 | assert str(model + "" + one_or_more(regex(r"[ab]"))) == "ba"
12 |
13 |
14 | def test_at_least_one():
15 | model = models.Mock("cbac")
16 | assert not str(model + "" + one_or_more(regex(r"[ab]"))).startswith("c")
17 |
--------------------------------------------------------------------------------
/tests/unit/library/test_subgrammar.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from guidance.library._subgrammar import subgrammar, lexeme
3 |
4 | class TestEndingLexemeAmbiguous:
5 | @pytest.mark.parametrize(
6 | "skip_rx",
7 | [None, r"\s", r"\s+", r"\s*"]
8 | )
9 | @pytest.mark.parametrize(
10 | "string",
11 | ["123"]
12 | )
13 | def test_lexeme_can_be_done_even_if_could_match_more(self, string, skip_rx):
14 | g1 = subgrammar(body=lexeme(r"\d+"), skip_regex=skip_rx, name="mycap")
15 | assert (m := g1.match(string)) is not None and m.captures["mycap"] == string
16 | g2 = g1 + "x"
17 | assert (m := g2.match(f"{string}x")) is not None and m.captures["mycap"] == string
18 |
19 | @pytest.mark.parametrize(
20 | "string",
21 | ["1", "123", "1x", "123x"]
22 | )
23 | def test_nullable_final_lexeme(self, string):
24 | g = subgrammar(body=lexeme(r"\d+")+lexeme(r"x?"), name="mycap")
25 | match = g.match(string)
26 | assert match is not None
27 | assert match.captures["mycap"] == string
28 |
--------------------------------------------------------------------------------
/tests/unit/library/test_substring.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from guidance import models, substring
4 |
5 |
6 | @pytest.mark.parametrize(
7 | ("mock_string", "target_string", "expected_string"),
8 | [
9 | ("abc", "abc", "abc"),
10 | ("ab", "abc", "ab"),
11 | ("bc", "abc", "bc"),
12 | ("a", "abc", "a"),
13 | ("b", "abc", "b"),
14 | ("c", "abc", "c"),
15 | ("abc", "def", ""), # This is a 'failure' case
16 | (
17 | "long string",
18 | "This is long string, only take part of this long string",
19 | "long string",
20 | ),
21 | ],
22 | )
23 | def test_mocked_substring(mock_string, target_string, expected_string):
24 |     m = models.Mock(mock_string)
25 |
26 | lm = m + substring(target_string, chunk="character", name="result")
27 | assert lm["result"] == expected_string
28 |
--------------------------------------------------------------------------------
/tests/unit/test_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import guidance
3 | from guidance import gen, models, user, system
4 |
5 | def test_call_embeddings():
6 | """This tests calls embedded in strings."""
7 | model = models.Mock()
8 |
9 | @guidance(dedent=False)
10 | def bla(lm, bla):
11 | lm += bla + "ae" + gen(max_tokens=10)
12 | return lm
13 |
14 | @guidance(dedent=False)
15 | def ble(lm):
16 | lm += f"""
17 | ae galera! {bla('33')}
18 | let's do more stuff!!""" + gen(
19 | max_tokens=10
20 | )
21 | return lm
22 |
23 | assert "{{G|" not in str(model + ble())
24 |
25 |
26 | @pytest.mark.xfail(
27 | reason="llguidance currently emits an additional empty capture group when no explicit stop is provided"
28 | )
29 | def test_model_set():
30 | model = models.Mock()
31 | model = model.set("num", "4")
32 | assert "num" in model
33 | assert model["num"] == "4"
34 | assert model.log_prob("num") is not None
35 |
36 | model = model.set("list_num", ['1', '2'])
37 | assert "list_num" in model
38 | assert model["list_num"] == ['1', '2']
39 | assert model.log_prob("list_num") is not None
40 |
41 | model += gen("list_num", max_tokens=10, list_append=True)
42 | assert len(model['list_num']) == 3
43 |
44 |
45 | def test_trace():
46 | from guidance import system, user, gen, models
47 | m0 = models.Mock()
48 |
49 | with system():
50 | m1 = m0 + "You are responsible for autocompleting a sentence."
51 | with user():
52 | m2 = m1 + "Roses are red and " + gen(name="suffix", regex='[A-Za-z]{2,5}', max_tokens=5)
53 |
54 | assert m2['suffix'] is not None
--------------------------------------------------------------------------------
/tests/unit/test_trace.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from guidance.trace._trace import WeakRefList, TraceHandler, LiteralInput, TextOutput, RoleCloserInput
4 | from guidance.trace import TraceNode, StatelessGuidanceInput, StatefulGuidanceInput, ImageInput, EmbeddedInput, \
5 | RoleOpenerInput, ImageOutput, CaptureOutput
6 |
7 |
8 | def test_weak_ref_list():
9 | class EmptyClass:
10 | pass
11 |
12 | a = EmptyClass()
13 | b = EmptyClass()
14 | li = WeakRefList()
15 | li.append(a)
16 | li.append(b)
17 |
18 | del a
19 | with pytest.raises(ReferenceError):
20 | _ = li[0]
21 |
22 | # Does not remove dead entries
23 | _ = li[1]
24 | assert len(li) == 2
25 |
26 | # Remove works as expected
27 | li.remove(b)
28 | assert len(li) == 1
29 |
30 | # Iter goes over live entries only
31 | for el in li:
32 | _ = el
33 |
34 |
35 | def test_trace_node():
36 | root = TraceNode()
37 | child1 = TraceNode()
38 | child2 = TraceNode()
39 | root.add_child(child1)
40 | root.add_child(child2)
41 |
42 | assert root.root() is root
43 | assert list(root.ancestors()) == []
44 | assert list(root.path()) == [root]
45 | assert list(root.traverse()) == [root, child1, child2]
46 |
47 | assert child1.root() is root
48 | assert list(child1.ancestors()) == [root]
49 | assert list(child1.path()) == [root, child1]
50 | assert list(child1.traverse()) == [child1]
51 |
52 |
53 | def test_trace_handler():
54 | trace_handler = TraceHandler()
55 | root = trace_handler.update_node(0, None, None)
56 | child1 = trace_handler.update_node(1, 0, None)
57 | inp = LiteralInput(value="")
58 | out = TextOutput(value="")
59 | pre_child2 = trace_handler.update_node(2, 0, inp)
60 | child2 = trace_handler.update_node(2, 0, out)
61 |
62 | assert pre_child2 == child2
63 | assert child2.input == inp
64 | assert child2.output == out
65 | assert child2.root() == root
66 | assert child1 not in child2.path()
67 |
68 |
69 |
70 | @pytest.mark.parametrize(
71 | 'node',
72 | [
73 | StatelessGuidanceInput(value=None),
74 | StatefulGuidanceInput(value=None),
75 | LiteralInput(value=""),
76 | ImageInput(value=b""),
77 | EmbeddedInput(value=""),
78 | RoleOpenerInput(name=""),
79 | RoleCloserInput(name=""),
80 | TextOutput(value=""),
81 | ImageOutput(value=b""),
82 | CaptureOutput(name=""),
83 | ]
84 | )
85 | def test_node_format_smoke(node):
86 | node.__repr__()
87 | node.__str__()
88 |
--------------------------------------------------------------------------------
/tests/unit/test_visual.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from guidance._schema import GenTokenExtra, GenToken
3 | from guidance.registry import get_bg_async
4 | from guidance.trace import TraceHandler, LiteralInput, TextOutput
5 | from guidance.visual import TraceMessage, MetricMessage, ExecutionCompletedMessage, \
6 | TokensMessage, ResetDisplayMessage, ClientReadyMessage, OutputRequestMessage, \
7 | ClientReadyAckMessage, trace_node_to_html, display_trace_tree, trace_node_to_str, TopicExchange, GuidanceMessage
8 | from guidance.visual import serialize_message, deserialize_message
9 | from guidance.visual._environment import Environment
10 | import asyncio
11 |
12 | from guidance.visual._exchange import DEFAULT_TOPIC
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "message",
17 | [
18 | TraceMessage(trace_id=0),
19 | MetricMessage(name="name", value="value"),
20 | ExecutionCompletedMessage(last_trace_id=0),
21 | TokensMessage(trace_id=0, text="text", tokens=[
22 | GenTokenExtra(token_id=0, prob=0, top_k=[GenToken(token_id=0, prob=0)])
23 | ]),
24 | ResetDisplayMessage(),
25 | ClientReadyMessage(),
26 | ClientReadyAckMessage(),
27 | OutputRequestMessage(),
28 | ]
29 | )
30 | def test_serialization(message):
31 | ser = serialize_message(message)
32 | deser = deserialize_message(ser)
33 | assert deser.model_dump() == message.model_dump()
34 |
35 |
36 | def test_async():
37 | _, loop = get_bg_async()._thread_and_loop()
38 | assert loop != asyncio.get_event_loop()
39 |
40 | async def f():
41 | return True
42 |
43 | task = get_bg_async().run_async_coroutine(get_bg_async().async_task(f())).result()
44 | assert task.result() is True
45 |
46 |
47 | def test_str_method_smoke():
48 | trace_handler = TraceHandler()
49 | trace_handler.update_node(1, 0, None)
50 | inp = LiteralInput(value="Hi there!")
51 | out = TextOutput(value="Hi there!")
52 | trace_handler.update_node(2, 0, inp)
53 | child_node = trace_handler.update_node(2, 0, out)
54 |
55 | assert trace_node_to_html(child_node) != ""
56 | assert trace_node_to_str(child_node) != ""
57 | assert display_trace_tree(trace_handler) is None
58 |
59 |
60 | def test_environment():
61 | env = Environment()
62 | assert not env.is_cloud()
63 | assert not env.is_notebook()
64 | assert env.is_terminal()
65 | assert "ipython-zmq" not in env.detected_envs
66 |
67 |
68 | def test_exchange():
69 | exchange = TopicExchange()
70 | assert len(exchange._observers) == 0
71 |
72 | count = 0
73 | def inc(_: GuidanceMessage):
74 | nonlocal count
75 | count += 1
76 |
77 | # Defaults
78 | exchange.subscribe(inc)
79 | exchange.publish(GuidanceMessage())
80 | exchange.unsubscribe(inc)
81 | assert count == 1
82 | assert len(exchange._observers) == 0
83 |
84 | # Topic pattern set
85 | topic_pat = "no"
86 | exchange.subscribe(inc, topic_pat)
87 | exchange.publish(GuidanceMessage(), topic_pat)
88 | exchange.unsubscribe(inc, topic_pat)
89 | assert count == 2
90 | assert len(exchange._observers) == 0
91 |
92 | # Missed topic
93 | topic_pat = "what"
94 | exchange.subscribe(inc, topic_pat)
95 | exchange.publish(GuidanceMessage())
96 | exchange.unsubscribe(inc, topic_pat)
97 | assert count == 2
98 | assert len(exchange._observers) == 0
--------------------------------------------------------------------------------