├── .git-blame-ignore-revs ├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ ├── call_cpu_tests.yml │ ├── call_gpu_tests.yml │ ├── ci_credentials.yml │ ├── ci_linux.yml │ ├── ci_macos.yml │ ├── ci_windows.yml │ ├── notebook_tests.yml │ ├── pull_request.yml │ └── pypi_upload.yml ├── .gitignore ├── CONTRIBUTING.md ├── GOVERNANCE.md ├── LICENSE.md ├── MAINTAINERS.md ├── MANIFEST.in ├── README.md ├── client └── graphpaper-inline │ ├── .gitignore │ ├── TODO.txt │ ├── build-to-guidance.sh │ ├── dist │ └── .gitignore │ ├── package.json │ ├── pnpm-lock.yaml │ ├── postcss.config.js │ ├── rollup.config.mjs │ ├── src │ ├── App.svelte │ ├── CustomAudio.svelte │ ├── CustomVideo.svelte │ ├── MetricRecord.svelte │ ├── ResizeListener.svelte │ ├── Select.svelte │ ├── Sparkline.svelte │ ├── StitchHandler.svelte │ ├── TokenGrid.svelte │ ├── TokenGridItem.svelte │ ├── clickoutside.ts │ ├── interfaces.ts │ ├── longhover.ts │ ├── main.css │ ├── main.js │ ├── metrics.ts │ ├── mocks.ts │ ├── stitch.ts │ └── template.html │ ├── tailwind.config.js │ └── tsconfig.json ├── docs ├── .readthedocs.yaml ├── Makefile ├── _static │ └── css │ │ └── styles.css ├── api.rst ├── api_examples.rst ├── art_of_prompt_design.rst ├── conf.py ├── figures │ ├── anachronism.png │ ├── await1.png │ ├── await2.png │ ├── capture_example.png │ ├── chat1.png │ ├── chat_animation.gif │ ├── chat_reading.png │ ├── demo_output.png │ ├── favicon.ico │ ├── favicon.png │ ├── function.png │ ├── gen_loop_demo.png │ ├── generate_select.png │ ├── generation1.png │ ├── get_started_button.png │ ├── guidance_logo.svg │ ├── guidance_logo_blue.svg │ ├── guidance_logo_blue_dark.svg │ ├── guidance_logo_light_blue.svg │ ├── guidance_logo_white_dark.svg │ ├── hidden1.png │ ├── json_animation.gif │ ├── json_syntax_variables.png │ ├── perfect_syntax.png │ ├── proverb_animation.gif │ ├── proverb_output.png │ ├── select.png │ ├── simple_fstring_llama2_7b.png │ ├── simple_gen_llama2_7b.png │ ├── simple_select_llama2_7b.png │ ├── simple_streaming_example.gif │ ├── template_objs.png │ ├── url_with_space.png │ ├── url_without_space.png │ └── watch_demo_button.png ├── index.rst ├── make.bat └── tutorials.rst ├── guidance ├── __init__.py ├── _ast.py ├── _bg │ └── __init__.py ├── _grammar.py ├── _guidance.py ├── _guidance.pyi ├── _parser.py ├── _schema.py ├── _utils.py ├── bench │ ├── __init__.py │ ├── _api.py │ ├── _powerlift.py │ └── _utils.py ├── chat.py ├── library │ ├── __init__.py │ ├── _audio.py │ ├── _block.py │ ├── _capture.py │ ├── _ebnf.py │ ├── _gen.py │ ├── _image.py │ ├── _json.py │ ├── _optional.py │ ├── _pydantic.py │ ├── _role.py │ ├── _sequences.py │ ├── _subgrammar.py │ ├── _substring.py │ ├── _tool.py │ └── _video.py ├── metrics │ ├── __init__.py │ └── _metrics.py ├── models │ ├── __init__.py │ ├── _azureai.py │ ├── _base │ │ ├── __init__.py │ │ ├── _interpreter.py │ │ ├── _model.py │ │ └── _state.py │ ├── _byte_tokenizer.py │ ├── _engine │ │ ├── __init__.py │ │ ├── _engine.py │ │ ├── _interpreter.py │ │ ├── _state.py │ │ └── _tokenizer.py │ ├── _llama_cpp.py │ ├── _mock.py │ ├── _openai.py │ ├── _openai_base.py │ ├── _transformers.py │ ├── broken_models │ │ ├── README.MD │ │ ├── _Gemini.py │ │ ├── _anthropic.py │ │ ├── _azure_openai.py │ │ ├── _azureai_studio.py │ │ ├── _cohere.py │ │ ├── _googleai.py │ │ ├── _lite_llm.py │ │ ├── _togetherai.py │ │ └── _vertexai.py │ └── experimental │ │ ├── __init__.py │ │ └── _vllm.py ├── py.typed ├── registry │ ├── __init__.py │ └── _registry.py ├── resources │ ├── graphpaper-inline.html │ ├── 
main.js │ ├── sample_audio.wav │ ├── sample_image.png │ └── sample_video.mp4 ├── selectors.py ├── trace │ ├── __init__.py │ └── _trace.py └── visual │ ├── __init__.py │ ├── _environment.py │ ├── _exchange.py │ ├── _jupyter.py │ ├── _message.py │ ├── _renderer.py │ └── _trace.py ├── notebooks ├── anachronism.ipynb ├── api_examples │ ├── library │ │ └── gen.ipynb │ └── models │ │ ├── AzureOpenAI.ipynb │ │ ├── OpenAI.ipynb │ │ └── TogetherAI.ipynb ├── art_of_prompt_design │ ├── prompt_boundaries_and_token_healing.ipynb │ ├── rag.ipynb │ ├── react.ipynb │ ├── tool_use.ipynb │ └── use_clear_syntax.ipynb ├── benchmarks │ └── json_output_bench.ipynb ├── chatgpt_vs_open_source_on_harder_tasks.ipynb ├── guaranteeing_valid_syntax.ipynb ├── proverb.ipynb ├── testing_lms.ipynb ├── tutorials │ ├── adding_new_models.ipynb │ ├── chat.ipynb │ ├── code_generation.ipynb │ ├── guidance_acceleration.ipynb │ ├── intro_to_guidance.ipynb │ ├── regex_constraints.ipynb │ └── token_healing.ipynb └── unstable │ ├── .gitignore │ └── State Debugging.ipynb ├── packages └── python │ └── stitch │ ├── .coveragerc │ ├── .eslintignore │ ├── .eslintrc.js │ ├── .github │ └── workflows │ │ └── build.yml │ ├── .gitignore │ ├── .npmignore │ ├── .prettierignore │ ├── .prettierrc │ ├── .yarnrc.yml │ ├── LICENSE.txt │ ├── MANIFEST.in │ ├── README.md │ ├── babel.config.js │ ├── codecov.yml │ ├── css │ └── widget.css │ ├── docs │ ├── Makefile │ ├── environment.yml │ ├── make.bat │ └── source │ │ ├── _static │ │ └── helper.js │ │ ├── conf.py │ │ ├── develop-install.rst │ │ ├── examples │ │ ├── index.rst │ │ └── introduction.nblink │ │ ├── index.rst │ │ ├── installing.rst │ │ └── introduction.rst │ ├── examples │ └── introduction.ipynb │ ├── install.json │ ├── jest.config.js │ ├── package.json │ ├── pyproject.toml │ ├── pytest.ini │ ├── readthedocs.yml │ ├── setup.py │ ├── src │ ├── __tests__ │ │ ├── index.spec.ts │ │ └── utils.ts │ ├── extension.ts │ ├── index.ts │ ├── plugin.ts │ ├── version.ts │ └── widget.ts │ ├── stitch.json │ ├── stitch │ ├── __init__.py │ ├── _frontend.py │ ├── _version.py │ ├── nbextension │ │ └── extension.js │ ├── stitch.py │ └── tests │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_example.py │ │ └── test_nbextension_path.py │ ├── tsconfig.eslint.json │ ├── tsconfig.json │ ├── webpack.config.js │ └── yarn.lock ├── pyproject.toml ├── setup.py └── tests ├── ReadMe.md ├── __init__.py ├── bench ├── __init__.py ├── test_api.py ├── test_powerlift.py └── test_utils.py ├── conftest.py ├── model_integration ├── __init__.py ├── library │ ├── test_gen.py │ ├── test_subgrammar.py │ └── test_substring.py ├── test_grammar.py ├── test_model.py └── test_tokenizers.py ├── model_specific ├── __init__.py ├── common_chat_testing.py ├── llama_cpp_tests │ ├── __init__.py │ ├── test_chat_templates.py │ └── test_llama_cpp.py ├── test_transformers.py └── test_visual.py ├── need_credentials ├── __init__.py ├── test_anthropic.py ├── test_azureai_openai.py ├── test_azureai_studio.py ├── test_chat_templates.py ├── test_cohere.py ├── test_googleai.py ├── test_lite_llm.py ├── test_openai.py ├── test_togetherai.py ├── test_tokenizers.py └── test_vertexai.py ├── notebooks ├── __init__.py └── test_notebooks.py ├── tokenizer_common.py ├── unit ├── __init__.py ├── library │ ├── __init__.py │ ├── json │ │ ├── __init__.py │ │ ├── test_allOf.py │ │ ├── test_json.py │ │ ├── test_refs.py │ │ ├── test_string_format.py │ │ └── utils.py │ ├── test_block.py │ ├── test_capture.py │ ├── test_gen.py │ ├── test_image.py │ ├── test_one_or_more.py │ ├── 
test_pydantic.py │ ├── test_regex.py │ ├── test_sequences.py │ ├── test_subgrammar.py │ └── test_substring.py ├── test_ast.py ├── test_decorator.py ├── test_grammar.py ├── test_ll.py ├── test_model.py ├── test_parser.py ├── test_trace.py └── test_visual.py └── utils.py /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # .git-blame-ignore-revs 2 | # Ran black on major files to standardize codebase 3 | 57da386795bc94a34275b333da586f171f96d7c8 4 | # Ran black on tests and other ancillary python files in code 5 | 083fb9877b507ed27136441c683ce051edf37e81 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **The bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Give a full working code snippet that can be pasted into a notebook cell or python file. Make sure to include the LLM load step so we know which model you are using. 15 | ```python 16 | # put your code snippet here 17 | ``` 18 | 19 | **System info (please complete the following information):** 20 | - OS (e.g. Ubuntu, Windows 11, Mac OS, etc.): 21 | - Guidance Version (`guidance.__version__`): 22 | -------------------------------------------------------------------------------- /.github/workflows/call_cpu_tests.yml: -------------------------------------------------------------------------------- 1 | name: call_cpu_tests 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | os: 7 | required: true 8 | type: string 9 | python-version: 10 | required: true 11 | type: string 12 | model: 13 | required: true 14 | type: string 15 | secrets: 16 | HF_TOKEN: 17 | required: false 18 | workflow_dispatch: 19 | inputs: 20 | os: 21 | required: false 22 | type: string 23 | default: "Large_Linux" # can instead use "Large_Windows" or the default OSes like "macos-latest" 24 | python-version: 25 | required: false 26 | type: string 27 | default: "3.12" 28 | model: 29 | required: false 30 | type: string 31 | default: "transformers_gpt2_cpu" # also try "llamacpp_llama2_7b_cpu", etc 32 | commit_id: 33 | description: 'Branch or Commit ID (optional)' 34 | required: false 35 | type: string 36 | 37 | jobs: 38 | cpu_tests: 39 | runs-on: ${{ inputs.os }} 40 | steps: 41 | - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 42 | uses: actions/checkout@v4 43 | with: 44 | ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 45 | - name: Set up Python ${{ inputs.python-version }} 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: ${{ inputs.python-version }} 49 | - name: Install guidance in ${{ inputs.os }} 50 | shell: bash 51 | run: | 52 | python -m pip install --upgrade pip 53 | python -m pip install -e .[llamacpp,transformers,test] 54 | python -m pip install accelerate # required if using smaller quantizations 55 | - name: cpu_tests for ${{ inputs.model }} 56 | shell: bash 57 | env: 58 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 59 | run: | 60 | pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \ 61 | --selected_model ${{ inputs.model }} \ 62 | ./tests/model_integration ./tests/model_specific 63 | - name: Upload coverage reports to Codecov 64 | uses: codecov/codecov-action@v4 
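# The coverage upload below assumes a CODECOV_TOKEN secret is configured on the repository.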
65 | with: 66 | token: ${{ secrets.CODECOV_TOKEN }} 67 | -------------------------------------------------------------------------------- /.github/workflows/call_gpu_tests.yml: -------------------------------------------------------------------------------- 1 | name: call_gpu_tests 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | os: 7 | required: true 8 | type: string 9 | python-version: 10 | required: true 11 | type: string 12 | model: 13 | required: true 14 | type: string 15 | secrets: 16 | HF_TOKEN: 17 | required: false 18 | workflow_dispatch: 19 | inputs: 20 | os: 21 | required: false 22 | type: string 23 | default: "gpu-runner" 24 | python-version: 25 | required: false 26 | type: string 27 | default: "3.12" 28 | model: 29 | required: false 30 | type: string 31 | default: "llamacpp_llama2_7b_gpu" # also try "transformers_gpt2_gpu", "transformers_phi2_gpu", etc 32 | commit_id: 33 | description: 'Branch or Commit ID (optional)' 34 | required: false 35 | type: string 36 | 37 | jobs: 38 | gpu_tests: 39 | runs-on: ${{ inputs.os }} 40 | steps: 41 | - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 42 | uses: actions/checkout@v4 43 | with: 44 | ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 45 | - name: Set up Python ${{ inputs.python-version }} 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: ${{ inputs.python-version }} 49 | - name: Install NVIDIA SDK 50 | shell: bash 51 | run: | 52 | nvidia-smi 53 | sudo apt-get --yes update 54 | sudo apt-get --yes install cuda-toolkit-12.6 55 | echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH 56 | - name: Upgrade pip 57 | shell: bash 58 | run : | 59 | python -m pip install --upgrade pip 60 | - name: Install other packages 61 | shell: bash 62 | run: | 63 | python -m pip install accelerate 64 | - name: Install guidance in ${{ inputs.os }} 65 | shell: bash 66 | run: | 67 | CMAKE_ARGS="-DGGML_CUDA=on" python -m pip install -e .[llamacpp,transformers,test] 68 | - name: Check GPU available 69 | shell: bash 70 | run: | 71 | python -c "import torch; assert torch.cuda.is_available()" 72 | - name: gpu_tests for ${{ inputs.model }} 73 | shell: bash 74 | env: 75 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 76 | run: | 77 | pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \ 78 | --selected_model ${{ inputs.model }} \ 79 | ./tests/model_integration ./tests/model_specific 80 | - name: Upload coverage reports to Codecov 81 | uses: codecov/codecov-action@v4 82 | with: 83 | token: ${{ secrets.CODECOV_TOKEN }} 84 | -------------------------------------------------------------------------------- /.github/workflows/ci_linux.yml: -------------------------------------------------------------------------------- 1 | # CI Tests which run on Linux machines 2 | 3 | # These access secrets, so should only be run on local branches. 4 | 5 | # Ideally, the CI tests would be a single workflow, but several issues 6 | # (especially varied OS support) mean that it is hard to keep a single 7 | # workflow green. 
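# Hence each OS gets its own workflow here, each fanning out into the reusable call_cpu_tests.yml / call_gpu_tests.yml workflows.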
8 | 9 | name: CI Tests - Linux 10 | permissions: 11 | contents: read 12 | 13 | on: 14 | workflow_dispatch: 15 | inputs: 16 | commit_id: 17 | description: 'Branch or Commit ID (optional)' 18 | required: false 19 | type: string 20 | schedule: 21 | # * is a special character in YAML so we quote this string 22 | # Run at 09:00 UTC every day 23 | - cron: '00 09 * * *' 24 | 25 | 26 | jobs: 27 | cpu_small: 28 | strategy: 29 | fail-fast: false # Don't cancel all on first failure 30 | matrix: 31 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 32 | model: 33 | - "transformers_gpt2_cpu" 34 | - "llamacpp_phi3_mini_4k_instruct_cpu" 35 | - "llamacpp_gemma2_9b_cpu" 36 | uses: ./.github/workflows/call_cpu_tests.yml 37 | with: 38 | os: Large_Linux 39 | python-version: ${{ matrix.python-version }} 40 | model: ${{ matrix.model }} 41 | secrets: 42 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 43 | 44 | cpu_big: 45 | strategy: 46 | fail-fast: false # Don't cancel all on first failure 47 | matrix: 48 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 49 | model: 50 | - "llamacpp_llama2_7b_cpu" 51 | - "transformers_llama3_8b_cpu" 52 | - "transformers_phi4_mini_cpu" 53 | uses: ./.github/workflows/call_cpu_tests.yml 54 | with: 55 | os: Large_Linux 56 | python-version: ${{ matrix.python-version }} 57 | model: ${{ matrix.model }} 58 | secrets: 59 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 60 | 61 | gpu_tests: 62 | strategy: 63 | fail-fast: false # Don't cancel all on first failure 64 | matrix: 65 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 66 | model: 67 | - "transformers_gpt2_gpu" 68 | - "transformers_gemma2_9b_gpu" 69 | - "llamacpp_llama2_7b_gpu" 70 | - "transformers_gemma2_9b_cpu" # CUDA is required for this model 71 | - "transformers_phi4_mini_gpu" 72 | uses: ./.github/workflows/call_gpu_tests.yml 73 | with: 74 | os: "gpu-runner" 75 | python-version: ${{ matrix.python-version }} 76 | model: ${{ matrix.model }} 77 | secrets: 78 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 79 | -------------------------------------------------------------------------------- /.github/workflows/ci_macos.yml: -------------------------------------------------------------------------------- 1 | # CI Tests which run on MacOS machines 2 | 3 | # These access secrets, so should only be run on local branches. 4 | 5 | # Ideally, the CI tests would be a single workflow, but several issues 6 | # (especially varied OS support) mean that it is hard to keep a single 7 | # workflow green. 
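# Only the smaller CPU models run here (a single cpu_small job); the bigger CPU and GPU matrices live in the Linux and Windows workflows.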
8 | 9 | # MacOS has been a particular trouble due to the small disk space 10 | # allocations on all the VMs, leading to the --selected_model 11 | # machinery 12 | 13 | name: CI Tests - MacOS 14 | permissions: 15 | contents: read 16 | 17 | on: 18 | workflow_dispatch: 19 | inputs: 20 | commit_id: 21 | description: 'Branch or Commit ID (optional)' 22 | required: false 23 | type: string 24 | schedule: 25 | # * is a special character in YAML so we quote this string 26 | # Run at 09:10 UTC every day 27 | - cron: '10 09 * * *' 28 | 29 | jobs: 30 | cpu_small: 31 | strategy: 32 | fail-fast: false # Don't cancel all on first failure 33 | matrix: 34 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 35 | model: 36 | - "transformers_gpt2_cpu" 37 | - "llamacpp_phi3_mini_4k_instruct_cpu" 38 | - "llamacpp_gemma2_9b_cpu" 39 | uses: ./.github/workflows/call_cpu_tests.yml 40 | with: 41 | os: "macos-latest" 42 | python-version: ${{ matrix.python-version }} 43 | model: ${{ matrix.model }} 44 | secrets: 45 | HF_TOKEN: ${{ secrets.HF_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/ci_windows.yml: -------------------------------------------------------------------------------- 1 | # CI Tests which run on Windows machines 2 | 3 | # These access secrets, so should only be run on local branches. 4 | 5 | # Ideally, the CI tests would be a single workflow, but several issues 6 | # (especially varied OS support) mean that it is hard to keep a single 7 | # workflow green. If there is one OS likely to lag slightly in support 8 | # it is Windows 9 | 10 | name: CI Tests - Windows 11 | permissions: 12 | contents: read 13 | 14 | on: 15 | workflow_dispatch: 16 | inputs: 17 | commit_id: 18 | description: 'Branch or Commit ID (optional)' 19 | required: false 20 | type: string 21 | schedule: 22 | # * is a special character in YAML so we quote this string 23 | # Run at 09:30 UTC every day 24 | - cron: '30 09 * * *' 25 | 26 | jobs: 27 | cpu_small: 28 | strategy: 29 | fail-fast: false # Don't cancel all on first failure 30 | matrix: 31 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 32 | model: 33 | - "transformers_gpt2_cpu" 34 | - "llamacpp_phi3_mini_4k_instruct_cpu" 35 | - "llamacpp_gemma2_9b_cpu" 36 | uses: ./.github/workflows/call_cpu_tests.yml 37 | with: 38 | os: "Large_Windows" 39 | python-version: ${{ matrix.python-version }} 40 | model: ${{ matrix.model }} 41 | secrets: 42 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 43 | 44 | cpu_big: 45 | strategy: 46 | fail-fast: false # Don't cancel all on first failure 47 | matrix: 48 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 49 | model: 50 | - "llamacpp_llama2_7b_cpu" 51 | - "transformers_llama3_8b_cpu" 52 | - "transformers_phi4_mini_cpu" 53 | uses: ./.github/workflows/call_cpu_tests.yml 54 | with: 55 | os: "Large_Windows" 56 | python-version: ${{ matrix.python-version }} 57 | model: ${{ matrix.model }} 58 | secrets: 59 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 60 | -------------------------------------------------------------------------------- /.github/workflows/notebook_tests.yml: -------------------------------------------------------------------------------- 1 | # These should only be run on main, because they access secrets 2 | # Not part of the regular CI run, since notebook tests seem 3 | # particularly flaky 4 | 5 | name: Notebook Tests 6 | 7 | on: 8 | workflow_dispatch: 9 | inputs: 10 | commit_id: 11 | description: 'Branch or Commit ID (optional)' 12 | required: false 13 | type: string 14 | 
schedule: 15 | # * is a special character in YAML so we quote this string 16 | # Run at 10:00 UTC every day 17 | - cron: '00 10 * * *' 18 | 19 | jobs: 20 | notebook_tests: 21 | runs-on: "Large_Linux" 22 | environment: test 23 | strategy: 24 | fail-fast: false # Don't cancel all on first failure 25 | matrix: 26 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 27 | permissions: 28 | id-token: write # for Azure CLI login 29 | steps: 30 | - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 31 | uses: actions/checkout@v4 32 | with: 33 | ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 34 | - name: Set up Python ${{ matrix.python-version }} 35 | uses: actions/setup-python@v5 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | - name: Install guidance 39 | shell: bash 40 | run: | 41 | python -m pip install --upgrade pip 42 | python -m pip install -e .[all,llamacpp,test] 43 | - name: Azure login 44 | uses: azure/login@v2 45 | with: 46 | client-id: ${{ secrets.AZURE_CLIENT_ID }} 47 | tenant-id: ${{ secrets.AZURE_TENANT_ID }} 48 | subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} 49 | - name: 'Run Azure CLI commands' 50 | shell: bash 51 | run: | 52 | az account show 53 | az group list 54 | - name: Notebook tests 55 | shell: bash 56 | env: 57 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 58 | # Configure endpoints 59 | AZUREAI_OPENAI_CHAT_ENDPOINT: ${{ vars.AZUREAI_OPENAI_CHAT_ENDPOINT }} 60 | AZUREAI_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREAI_OPENAI_CHAT_DEPLOYMENT_NAME }} 61 | AZUREAI_OPENAI_CHAT_MODEL: ${{ vars.AZUREAI_OPENAI_CHAT_MODEL }} 62 | AZUREAI_OPENAI_CHAT_API_VERSION: ${{ vars.AZUREAI_OPENAI_CHAT_API_VERSION }} 63 | run: | 64 | pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \ 65 | ./tests/notebooks 66 | - name: Upload coverage reports to Codecov 67 | uses: codecov/codecov-action@v4 68 | with: 69 | token: ${{ secrets.CODECOV_TOKEN }} 70 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | inputs: 7 | commit_id: 8 | description: 'Branch or Commit ID (optional)' 9 | required: false 10 | type: string 11 | schedule: 12 | # Run at 10:00 UTC every day 13 | - cron: "00 10 * * *" 14 | 15 | jobs: 16 | unit_tests: 17 | strategy: 18 | fail-fast: false # Don't cancel all on first failure 19 | matrix: 20 | os: [ubuntu-latest, windows-latest, macos-latest] 21 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 22 | runs-on: ${{ matrix.os }} 23 | steps: 24 | - name: Checkout repo at ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 25 | uses: actions/checkout@v4 26 | with: 27 | ref: ${{ github.event_name == 'workflow_dispatch' && inputs.commit_id || github.sha }} 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Minimal install 33 | run: | 34 | python -m pip install --upgrade pip 35 | python -m pip install -e . 
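# Smoke test first: the bare install (no extras) must at least be importable.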
36 | - name: Attempt import 37 | run: | 38 | python -c "import guidance" 39 | - name: Bigger install 40 | run: | 41 | python -m pip install -e .[unittest] 42 | - name: Unit Tests 43 | shell: bash 44 | run: | 45 | pytest -vv --cov=guidance --cov-report=xml --cov-report=term-missing \ 46 | ./tests/unit 47 | - name: Upload coverage reports to Codecov 48 | uses: codecov/codecov-action@v4 49 | with: 50 | token: ${{ secrets.CODECOV_TOKEN }} 51 | 52 | cpu_tests: 53 | strategy: 54 | fail-fast: false # Don't cancel all on first failure 55 | matrix: 56 | os: ["Large_Linux"] # , "Large_Windows"] 57 | python-version: ["3.9", "3.13"] 58 | model: 59 | - "transformers_gpt2_cpu" 60 | - "llamacpp_phi3_mini_4k_instruct_cpu" 61 | uses: ./.github/workflows/call_cpu_tests.yml 62 | with: 63 | os: ${{ matrix.os }} 64 | python-version: ${{ matrix.python-version }} 65 | model: ${{ matrix.model }} 66 | 67 | # gpu_tests: 68 | # strategy: 69 | # fail-fast: false # Don't cancel all on first failure 70 | # matrix: 71 | # os: ["gpu-runner"] 72 | # python-version: ["3.9", "3.12"] 73 | # model: 74 | # - "transformers_gpt2_gpu" 75 | # - "llamacpp_llama2_7b_gpu" 76 | # uses: ./.github/workflows/call_gpu_tests.yml 77 | # with: 78 | # os: ${{ matrix.os }} 79 | # python-version: ${{ matrix.python-version }} 80 | # model: ${{ matrix.model }} 81 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | notebooks/local_scratch 2 | __pycache__/ 3 | .vscode 4 | .vs 5 | .idea/ 6 | /build 7 | /dist 8 | *.egg-info 9 | *.diskcache 10 | .ipynb_checkpoints 11 | node_modules 12 | .eggs/ 13 | .env 14 | .DS_Store 15 | venv/ 16 | 17 | # Ignore native library built by setup 18 | guidance/*.so 19 | guidance/_rust/*.so 20 | guidance/_rust/target/ 21 | guidance/_rust/Cargo.lock 22 | *.pyd 23 | 24 | notebooks/**/*.papermill_out.ipynb 25 | 26 | .mypy_cache/* 27 | 28 | **/scratch.* -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) The Guidance Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # Maintainers 2 | 3 | This document lists the Maintainers of the Project. Maintainers may be added once approved by the existing maintainers as described in the [Governance document](./GOVERNANCE.md). By adding your name to this list you are agreeing to abide by the Project governance documents and to abide by all of the Organization's policies, including the [code of conduct](https://github.com/guidance-ai/governance/blob/main/CODE-OF-CONDUCT.md), [trademark policy](https://github.com/guidance-ai/governance/blob/main/TRADEMARKS.md), and [antitrust policy](https://github.com/guidance-ai/governance/blob/main/ANTITRUST.md). If you are participating because of your affiliation with another organization (designated below), you represent that you have the authority to bind that organization to these policies. 4 | 5 | | **NAME** | **Handle** | **Affiliated Organization** | 6 | | --- | --- | --- | 7 | | Scott Lundberg | [slundberg](https://github.com/slundberg) | | 8 | | Harsha Nori | [Harsha-Nori](https://github.com/Harsha-Nori) | Microsoft | 9 | | Marco Tulio Ribeiro | [marcotcr](https://github.com/marcotcr) | Google | 10 | 11 | --- 12 | Part of MVG-0.1-beta. 13 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/). 14 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include resources/graphpaper-inline.html 2 | include resources/sample_audio.wav 3 | include resources/sample_video.mp4 4 | include resources/sample_image.png 5 | -------------------------------------------------------------------------------- /client/graphpaper-inline/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | build/ 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /client/graphpaper-inline/TODO.txt: -------------------------------------------------------------------------------- 1 | - Remove CDN font links (googlefonts) 2 | - Image integration 3 | - Testing -------------------------------------------------------------------------------- /client/graphpaper-inline/build-to-guidance.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | npm run build 5 | cp dist/index.html ../../guidance/resources/graphpaper-inline.html -------------------------------------------------------------------------------- /client/graphpaper-inline/dist/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /client/graphpaper-inline/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "graphpaper", 3 | "version": "0.0.1", 4 | "scripts": { 5 | "build": "rollup -c", 6 | "dev": "rollup -c -w", 7 | "start": "sirv dist" 8 | }, 9 | "devDependencies": { 10 | "@rollup/plugin-commonjs": "^26.0.1", 11 | "@rollup/plugin-node-resolve": "^15.2.3", 12 | "@rollup/plugin-terser": "^0.4.4", 13 | "@rollup/plugin-typescript": "^11.1.6", 14 | "@types/d3-scale": "^4.0.8", 15 | "@types/d3-scale-chromatic": "^3.0.3", 16 | 
"@types/dompurify": "^3.0.5", 17 | "@types/video.js": "^7.3.58", 18 | "autoprefixer": "^10.4.20", 19 | "cssnano": "^7.0.5", 20 | "postcss": "^8.4.41", 21 | "rollup": "^4.21.0", 22 | "rollup-plugin-copy": "^3.5.0", 23 | "rollup-plugin-html-bundle": "^0.0.3", 24 | "rollup-plugin-livereload": "^2.0.5", 25 | "rollup-plugin-postcss": "^4.0.2", 26 | "rollup-plugin-serve": "^1.1.1", 27 | "rollup-plugin-svelte": "^7.2.2", 28 | "sirv-cli": "^2.0.2", 29 | "svelte": "^4.2.18", 30 | "svelte-preprocess": "^6.0.2", 31 | "tailwindcss": "^3.4.10", 32 | "tslib": "^2.6.3", 33 | "typescript": "^5.5.4" 34 | }, 35 | "dependencies": { 36 | "d3-interpolate": "^3.0.1", 37 | "d3-scale": "^4.0.2", 38 | "d3-scale-chromatic": "^3.1.0", 39 | "dompurify": "^3.1.7", 40 | "tailwind-scrollbar": "^4.0.0", 41 | "video.js": "^8.21.0" 42 | } 43 | } -------------------------------------------------------------------------------- /client/graphpaper-inline/postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | cssnano: { preset: 'default' } 6 | } 7 | } -------------------------------------------------------------------------------- /client/graphpaper-inline/rollup.config.mjs: -------------------------------------------------------------------------------- 1 | import svelte from 'rollup-plugin-svelte'; 2 | import { sveltePreprocess } from 'svelte-preprocess'; 3 | import resolve from '@rollup/plugin-node-resolve'; 4 | import commonjs from '@rollup/plugin-commonjs'; 5 | import terser from '@rollup/plugin-terser'; 6 | import typescript from '@rollup/plugin-typescript'; 7 | import postcss from 'rollup-plugin-postcss'; 8 | import livereload from 'rollup-plugin-livereload'; 9 | // @ts-ignore 10 | import serve from 'rollup-plugin-serve'; 11 | // @ts-ignore 12 | import htmlBundle from 'rollup-plugin-html-bundle'; 13 | import copy from 'rollup-plugin-copy'; 14 | 15 | const production = !process.env.ROLLUP_WATCH; 16 | 17 | export default [{ 18 | input: 'src/main.js', 19 | output: { 20 | format: 'iife', 21 | name: 'app', 22 | file: 'build/bundle.js', 23 | sourcemap: !production, 24 | }, 25 | plugins: [ 26 | typescript(), 27 | svelte({ 28 | compilerOptions: { 29 | dev: !production 30 | }, 31 | preprocess: sveltePreprocess() 32 | }), 33 | resolve({ 34 | browser: true, 35 | dedupe: importee => importee === 'svelte' || importee.startsWith('svelte/'), 36 | extensions: ['.svelte', '.mjs', '.ts', '.js', '.json', '.node'] 37 | }), 38 | commonjs(), 39 | postcss(), 40 | copy({ 41 | targets: [ 42 | { src: 'src/template.html', dest: 'build' } 43 | ] 44 | }), 45 | htmlBundle({ 46 | template: 'build/template.html', 47 | target: production ? 'dist/index.html' : 'build/index.html', 48 | targetElement: 'body', 49 | inline: production 50 | }), 51 | !production && serve('build'), 52 | !production && livereload('build'), 53 | production && terser() 54 | ], 55 | watch: { 56 | clearScreen: false 57 | } 58 | }]; -------------------------------------------------------------------------------- /client/graphpaper-inline/src/CustomVideo.svelte: -------------------------------------------------------------------------------- 1 | 43 | 44 |
45 | 48 |
49 | 50 | 55 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/MetricRecord.svelte: -------------------------------------------------------------------------------- 1 | 2 | 16 | 17 | 24 | 25 | 26 | 27 | {#if value.constructor === Array} 28 | {metricDef.name} 29 | 30 | {:else} 31 | {metricDef.name} 32 | {#if typeof value === "number"} 33 | {value.toFixed(metricDef.precision)} 34 | {#if metricDef.units !== ''} 35 | {metricDef.units} 36 | {/if} 37 | 38 | {:else} 39 | {value} 40 | {#if metricDef.units !== ''} 41 | {metricDef.units} 42 | {/if} 43 | 44 | {/if} 45 | {/if} 46 | 47 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/ResizeListener.svelte: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/Select.svelte: -------------------------------------------------------------------------------- 1 | 2 | 27 | 28 |
29 | 39 | {#if showList} 40 | 45 | {/if} 46 |
47 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/Sparkline.svelte: -------------------------------------------------------------------------------- 1 | 2 | 3 | 31 | 32 |
33 | 34 | 35 | 36 | 37 | 38 |
-------------------------------------------------------------------------------- /client/graphpaper-inline/src/StitchHandler.svelte: -------------------------------------------------------------------------------- 1 | 2 | 47 | 48 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/TokenGridItem.svelte: -------------------------------------------------------------------------------- 1 | 2 | 10 | 11 | {#each token.text as ch, i} 12 | {#if ch === ' '} 13 | 14 | {#if i === 0} 15 | 16 | {token.role} 17 | 18 | {/if} 19 |   20 | 21 | {:else if ch === '\t'} 22 | 23 | {#if i === 0} 24 | 25 | {token.role} 26 | 27 | {/if} 28 | \t   29 | 30 | {:else if ch === '\n'} 31 | 32 | {#if i === 0} 33 | 34 | {token.role} 35 | 36 | {/if} 37 | \n 38 | 39 |
40 | {:else} 41 | 42 | {#if i === 0} 43 | 44 | {token.role} 45 | 46 | {/if} 47 | {ch} 48 | 49 | {/if} 50 | {/each} -------------------------------------------------------------------------------- /client/graphpaper-inline/src/clickoutside.ts: -------------------------------------------------------------------------------- 1 | // Action for clicking outside an element. 2 | 3 | export function clickOutside(node: HTMLElement) { 4 | const handleClick = (event: MouseEvent) => { 5 | let target = event.target as HTMLElement; 6 | if (!node.contains(target)) { 7 | node.dispatchEvent(new CustomEvent('outclick')); 8 | } 9 | }; 10 | 11 | document.addEventListener('click', handleClick, true); 12 | 13 | return { 14 | destroy() { 15 | document.removeEventListener('click', handleClick, true); 16 | } 17 | }; 18 | } -------------------------------------------------------------------------------- /client/graphpaper-inline/src/interfaces.ts: -------------------------------------------------------------------------------- 1 | // Interfaces used within the client. This is separate to messaging interfaces. 2 | 3 | import type {GenToken, RoleOpenerInput} from "./stitch"; 4 | 5 | export interface MetricDef { 6 | name: string, 7 | units: string, 8 | description: string, 9 | isScalar: boolean, 10 | precision: number, 11 | } 12 | 13 | export type MetricVal = string | number | Array; 14 | 15 | export interface Token { 16 | text: string, 17 | prob: number, 18 | latency_ms: number, 19 | is_input: boolean, 20 | is_force_forwarded: boolean, 21 | is_generated: boolean, 22 | role: string, 23 | special: boolean, 24 | top_k?: Array 25 | } 26 | export declare type TokenCallback = (token: Token) => string; 27 | 28 | export interface MediaNodeContext { 29 | roleStack: RoleOpenerInput[]; 30 | index: number; 31 | } 32 | 33 | export type MediaType = "audio" | "video" | "image"; 34 | 35 | export interface MediaNode { 36 | type: MediaType; 37 | value: any; 38 | format: string; 39 | context: MediaNodeContext; 40 | } 41 | 42 | export type MultimodalNode = 43 | | { type: 'token', data: Token } 44 | | { type: 'media', data: MediaNode }; 45 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/longhover.ts: -------------------------------------------------------------------------------- 1 | // Action for long mouse hovers. 2 | 3 | export function longhover(node: HTMLElement, duration: number) { 4 | let timer: any; 5 | 6 | const handleMouseOver = (event: MouseEvent) => { 7 | timer = setTimeout(() => { 8 | node.dispatchEvent(new CustomEvent('longmouseover', {detail: event})); 9 | }, duration); 10 | }; 11 | const handleMouseOut = (event: MouseEvent) => { 12 | clearTimeout(timer); 13 | node.dispatchEvent(new CustomEvent('longmouseout', {detail: event})); 14 | } 15 | 16 | node.addEventListener('mouseover', handleMouseOver); 17 | node.addEventListener('mouseout', handleMouseOut); 18 | 19 | return { 20 | update(newDuration: number) { 21 | duration = newDuration 22 | }, 23 | destroy() { 24 | node.removeEventListener('mouseover', handleMouseOver); 25 | node.removeEventListener('mouseout', handleMouseOut); 26 | } 27 | }; 28 | } 29 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/main.css: -------------------------------------------------------------------------------- 1 | /* Custom CSS for web app. 
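It only pulls in the three Tailwind layers below; theme customization lives in tailwind.config.js.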
*/ 2 | @tailwind base; 3 | @tailwind components; 4 | @tailwind utilities; -------------------------------------------------------------------------------- /client/graphpaper-inline/src/main.js: -------------------------------------------------------------------------------- 1 | // Entrypoint for web app. 2 | 3 | import App from './App.svelte'; 4 | 5 | const app = new App({ 6 | target: document.body, 7 | }); 8 | 9 | export default app; -------------------------------------------------------------------------------- /client/graphpaper-inline/src/metrics.ts: -------------------------------------------------------------------------------- 1 | // Metrics and their definitions. 2 | 3 | import type { MetricDef } from './interfaces'; 4 | 5 | export const metricDefs: Record = { 6 | 'status': { 7 | name: '', 8 | units: '', 9 | description: 'Determines whether engine is running, completed or in error.', 10 | isScalar: true, 11 | precision: 0 12 | }, 13 | 'cpu': { 14 | name: 'CPU', 15 | units: '%', 16 | description: 'Average utilization across CPU cores.', 17 | isScalar: false, 18 | precision: 1 19 | }, 20 | 'gpu': { 21 | name: 'GPU', 22 | units: '%', 23 | description: 'Average utilization across GPUs.', 24 | isScalar: false, 25 | precision: 1 26 | }, 27 | 'ram': { 28 | name: 'RAM', 29 | units: 'GB', 30 | description: 'Utilization of RAM.', 31 | isScalar: true, 32 | precision: 1 33 | }, 34 | 'vram': { 35 | name: 'VRAM', 36 | units: 'GB', 37 | description: 'Utilization of video RAM.', 38 | isScalar: true, 39 | precision: 1 40 | }, 41 | 'wall time': { 42 | name: 'Time', 43 | units: 's', 44 | description: 'Time taken from initial display to engine completion.', 45 | isScalar: true, 46 | precision: 1 47 | }, 48 | 'avg latency': { 49 | name: 'Latency', 50 | units: 'ms', 51 | description: 'Average roundtrip latency per token', 52 | isScalar: true, 53 | precision: 0 54 | }, 55 | 'consumed': { 56 | name: 'Used', 57 | units: 'tkn', 58 | description: 'Total tokens consumed by language model.', 59 | isScalar: true, 60 | precision: 0 61 | }, 62 | 'token reduction': { 63 | name: 'Reduced', 64 | units: '%', 65 | description: 'Total tokens consumed by language model divided by total tokens.', 66 | isScalar: true, 67 | precision: 0 68 | } 69 | }; 70 | -------------------------------------------------------------------------------- /client/graphpaper-inline/src/template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /client/graphpaper-inline/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: ["./src/**/*.{html,ts,js,svelte}"], 4 | theme: { 5 | extend: { 6 | fontFamily: { 7 | 'token': ['JetBrains Mono'], 8 | }, 9 | keyframes: { 10 | 'cpulse': { 11 | '50%': { opacity: 0.0 } 12 | } 13 | }, 14 | animation: { 15 | 'cpulse': 'cpulse 3.5s cubic-bezier(0.4, 0, 0.6, 1) infinite' 16 | } 17 | } 18 | }, 19 | plugins: [ 20 | require('tailwind-scrollbar'), 21 | ], 22 | } 23 | 24 | -------------------------------------------------------------------------------- /client/graphpaper-inline/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "allowJs": true, 4 | "checkJs": true, 5 | "esModuleInterop": true, 6 | "forceConsistentCasingInFileNames": true, 7 
| "resolveJsonModule": true, 8 | "skipLibCheck": true, 9 | "sourceMap": true, 10 | "strict": true, 11 | "verbatimModuleSyntax": true, 12 | "module": "ESNext", 13 | "moduleResolution": "bundler" 14 | } 15 | } -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # version: 2 2 | 3 | # Read the Docs configuration file for Sphinx projects 4 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 5 | 6 | # Required 7 | version: 2 8 | 9 | # Set the OS, Python version and other tools you might need 10 | build: 11 | os: ubuntu-22.04 12 | tools: 13 | python: "3.10" 14 | # You can also specify other tool versions: 15 | # nodejs: "20" 16 | # rust: "1.70" 17 | # golang: "1.20" 18 | 19 | # Build documentation in the "docs/" directory with Sphinx 20 | sphinx: 21 | configuration: docs/conf.py 22 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 23 | # builder: "dirhtml" 24 | # Fail on all warnings to avoid broken references 25 | # fail_on_warning: true 26 | 27 | # Optionally build your docs in additional formats such as PDF and ePub 28 | # formats: 29 | # - pdf 30 | # - epub 31 | 32 | # Optional but recommended, declare the Python requirements required 33 | # to build your documentation 34 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 35 | # python: 36 | # install: 37 | # - requirements: docs/requirements.txt 38 | # version: 3.8 39 | python: 40 | install: 41 | - method: pip 42 | path: . 43 | extra_requirements: 44 | - docs 45 | -------------------------------------------------------------------------------- /docs/_static/css/styles.css: -------------------------------------------------------------------------------- 1 | .wy-side-nav-search > a img.logo, .wy-side-nav-search .wy-dropdown > a img.logo { 2 | width: 250px; 3 | margin-top: 20px; 4 | margin-bottom: 15px; 5 | } 6 | 7 | .wy-side-nav-search>div.version { 8 | color: black; 9 | } 10 | @media screen and (min-width: 767px) { 11 | .wy-table-responsive table td { 12 | white-space: normal; 13 | } 14 | .wy-table-responsive { 15 | overflow: visible; 16 | } 17 | } 18 | 19 | /* .wy-side-nav-search .wy-dropdown>a img.logo,.wy-side-nav-search>a img.logo { 20 | max-width: 40%; 21 | } */ 22 | 23 | .wy-side-nav-search>div.version { 24 | color: #d9d9d9; 25 | } 26 | 27 | .wy-nav-top { 28 | background: #343131; 29 | } 30 | 31 | .highlight { 32 | background: #f7f7f7; 33 | } 34 | 35 | .wy-side-nav-search input[type=text] { 36 | border-color: #666666; 37 | } 38 | 39 | a { 40 | color: #008bfb; 41 | } 42 | 43 | a:hover { 44 | color: #008bfb; 45 | } 46 | 47 | a:visited { 48 | color: #008bfb; 49 | } 50 | 51 | html.writer-html4 .rst-content dl:not(.docutils)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt { 52 | background: #008bfb11; 53 | color: #0086f6; 54 | border-top: 3px solid #008bfbaa; 55 | } 56 | 57 | .rst-versions .rst-current-version { 58 | color: #fcfcfc; 59 | } 60 | 61 | .wy-menu-vertical a { 62 | color: #d9d9d9; 63 | } 64 | 65 | section h2 { 66 | margin-top: 30px; 67 | } 68 | 69 | .rst-content code.literal, .rst-content tt.literal { 70 | color: #008bfb; 71 | } -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 
| .. currentmodule:: guidance 2 | 3 | API Reference 4 | ============= 5 | This page contains the API reference for public objects and functions in Guidance. 6 | 7 | 8 | .. _functions_api: 9 | 10 | functions 11 | --------- 12 | .. autosummary:: 13 | :toctree: generated/ 14 | 15 | guidance.gen 16 | guidance.select 17 | guidance.json 18 | 19 | 20 | .. _contexts_api: 21 | 22 | context blocks 23 | -------------- 24 | .. autosummary:: 25 | :toctree: generated/ 26 | 27 | guidance.instruction 28 | guidance.system 29 | guidance.user 30 | guidance.assistant 31 | 32 | 33 | .. _models_api: 34 | 35 | models 36 | ------ 37 | .. autosummary:: 38 | :toctree: generated/ 39 | 40 | guidance.models.Model 41 | guidance.models.LlamaCpp 42 | guidance.models.Transformers 43 | guidance.models.Anthropic 44 | guidance.models.AzureOpenAI 45 | guidance.models.Cohere 46 | guidance.models.GoogleAI 47 | guidance.models.LiteLLM 48 | guidance.models.OpenAI 49 | guidance.models.VertexAI 50 | -------------------------------------------------------------------------------- /docs/api_examples.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: guidance 2 | 3 | .. _api_examples: 4 | 5 | API Examples 6 | ------------ 7 | 8 | These examples parallel the namespace structure of Guidance. Each object or function in Guidance has a 9 | corresponding example notebook here that demonstrates its API usage. The source notebooks 10 | are `available on GitHub `_. 11 | 12 | 13 | .. _functions_examples: 14 | 15 | functions 16 | ========= 17 | .. Examples for built-in guidance functions. 18 | 19 | .. toctree:: 20 | :glob: 21 | :maxdepth: 1 22 | 23 | example_notebooks/api_examples/library/* 24 | 25 | 26 | .. _models_examples: 27 | 28 | models 29 | ====== 30 | .. Examples for members of :ref:`guidance.models `. 31 | 32 | .. toctree:: 33 | :glob: 34 | :maxdepth: 1 35 | 36 | example_notebooks/api_examples/models/* 37 | -------------------------------------------------------------------------------- /docs/art_of_prompt_design.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: guidance 2 | 3 | .. _art_of_prompt_design: 4 | 5 | The Art of Prompt Design 6 | ------------------------ 7 | 8 | These notebooks demonstrate how to design effective prompts and guidance programs; they also cover commonly useful 9 | design patterns. The source notebooks are `available on GitHub `_. 10 | 11 | 12 | .. 
toctree:: 13 | :glob: 14 | :maxdepth: 1 15 | 16 | example_notebooks/art_of_prompt_design/use_clear_syntax.ipynb 17 | example_notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb 18 | example_notebooks/art_of_prompt_design/tool_use.ipynb 19 | example_notebooks/art_of_prompt_design/react.ipynb 20 | example_notebooks/art_of_prompt_design/rag.ipynb -------------------------------------------------------------------------------- /docs/figures/anachronism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/anachronism.png -------------------------------------------------------------------------------- /docs/figures/await1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/await1.png -------------------------------------------------------------------------------- /docs/figures/await2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/await2.png -------------------------------------------------------------------------------- /docs/figures/capture_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/capture_example.png -------------------------------------------------------------------------------- /docs/figures/chat1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/chat1.png -------------------------------------------------------------------------------- /docs/figures/chat_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/chat_animation.gif -------------------------------------------------------------------------------- /docs/figures/chat_reading.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/chat_reading.png -------------------------------------------------------------------------------- /docs/figures/demo_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/demo_output.png -------------------------------------------------------------------------------- /docs/figures/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/favicon.ico -------------------------------------------------------------------------------- /docs/figures/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/favicon.png 
-------------------------------------------------------------------------------- /docs/figures/function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/function.png -------------------------------------------------------------------------------- /docs/figures/gen_loop_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/gen_loop_demo.png -------------------------------------------------------------------------------- /docs/figures/generate_select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/generate_select.png -------------------------------------------------------------------------------- /docs/figures/generation1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/generation1.png -------------------------------------------------------------------------------- /docs/figures/get_started_button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/get_started_button.png -------------------------------------------------------------------------------- /docs/figures/guidance_logo_blue.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | 9 | 10 | 13 | 19 | 21 | 23 | 26 | 29 | 31 | 33 | 35 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /docs/figures/guidance_logo_blue_dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 9 | 10 | 11 | 14 | 20 | 22 | 24 | 27 | 30 | 32 | 35 | 38 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /docs/figures/hidden1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/hidden1.png -------------------------------------------------------------------------------- /docs/figures/json_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/json_animation.gif -------------------------------------------------------------------------------- /docs/figures/json_syntax_variables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/json_syntax_variables.png -------------------------------------------------------------------------------- /docs/figures/perfect_syntax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/perfect_syntax.png 
-------------------------------------------------------------------------------- /docs/figures/proverb_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/proverb_animation.gif -------------------------------------------------------------------------------- /docs/figures/proverb_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/proverb_output.png -------------------------------------------------------------------------------- /docs/figures/select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/select.png -------------------------------------------------------------------------------- /docs/figures/simple_fstring_llama2_7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_fstring_llama2_7b.png -------------------------------------------------------------------------------- /docs/figures/simple_gen_llama2_7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_gen_llama2_7b.png -------------------------------------------------------------------------------- /docs/figures/simple_select_llama2_7b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_select_llama2_7b.png -------------------------------------------------------------------------------- /docs/figures/simple_streaming_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/simple_streaming_example.gif -------------------------------------------------------------------------------- /docs/figures/template_objs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/template_objs.png -------------------------------------------------------------------------------- /docs/figures/url_with_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/url_with_space.png -------------------------------------------------------------------------------- /docs/figures/url_without_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/url_without_space.png -------------------------------------------------------------------------------- /docs/figures/watch_demo_button.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/docs/figures/watch_demo_button.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | 
2 | .. image:: figures/guidance_logo_blue.svg
3 | :width: 300px
4 | :align: center
5 | |
6 | 
7 | **Guidance** enables you to control modern language models more effectively and efficiently than traditional prompting or chaining. Guidance programs allow you to interleave generation, prompting, and logical control into a single continuous flow matching how the language model actually processes the text.
8 | 
9 | Install
10 | =======
11 | 
12 | Guidance can be installed from `PyPI <https://pypi.org/project/guidance>`_::
13 | 
14 | pip install guidance
15 | 
16 | 
17 | Contents
18 | ========
19 | 
20 | .. toctree::
21 | :maxdepth: 2
22 | 
23 | Tutorials <tutorials>
24 | API reference <api>
25 | API examples <api_examples>
26 | The Art of Prompt Design <art_of_prompt_design>
27 | 
--------------------------------------------------------------------------------
/docs/tutorials.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: guidance
2 | 
3 | .. _tutorials:
4 | 
5 | Tutorials
6 | ----------------
7 | 
8 | These notebooks demonstrate various features of ``guidance``. The source notebooks
9 | are `available on GitHub <https://github.com/guidance-ai/guidance>`_.
10 | 
11 | 
12 | .. toctree::
13 | :glob:
14 | :maxdepth: 1
15 | 
16 | example_notebooks/tutorials/intro_to_guidance.ipynb
17 | example_notebooks/tutorials/token_healing.ipynb
18 | example_notebooks/tutorials/regex_constraints.ipynb
19 | example_notebooks/tutorials/guidance_acceleration.ipynb
20 | example_notebooks/tutorials/code_generation.ipynb
21 | example_notebooks/tutorials/chat.ipynb
--------------------------------------------------------------------------------
/guidance/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.1"
2 | 
3 | import sys
4 | import types
5 | 
6 | from . import models
7 | from ._guidance import guidance
8 | 
9 | from ._ast import GrammarNode, Function
10 | from ._utils import strip_multiline_string_indents
11 | 
12 | 
13 | # This makes the guidance module callable
14 | class _Guidance(types.ModuleType):
15 | def __call__(
16 | self, f=None, *, stateless=False, cache=None, dedent=True, model=models.Model
17 | ):
18 | return guidance(
19 | f, stateless=stateless, cache=cache, dedent=dedent, model=model
20 | )
21 | 
22 | 
23 | sys.modules[__name__].__class__ = _Guidance
24 | 
25 | # we expose all the library functions at the top level of the module
26 | from .library import *
--------------------------------------------------------------------------------
/guidance/_bg/__init__.py:
--------------------------------------------------------------------------------
1 | """ Background thread for asyncio handling.
2 | 
3 | This is currently being used for messaging, visualization and metrics.
4 | """ 5 | 6 | import asyncio 7 | import threading 8 | from asyncio import AbstractEventLoop, Task 9 | from concurrent.futures import Future 10 | from typing import Any, Coroutine, TypeVar 11 | 12 | T = TypeVar('T') 13 | 14 | def _start_asyncio_loop(loop: AbstractEventLoop): 15 | asyncio.set_event_loop(loop) 16 | loop.run_forever() 17 | 18 | 19 | def _asyncio_background_thread() -> tuple[threading.Thread, AbstractEventLoop]: 20 | loop = asyncio.new_event_loop() 21 | thread = threading.Thread(target=_start_asyncio_loop, args=(loop,)) 22 | thread.daemon = True 23 | return thread, loop 24 | 25 | 26 | class BackgroundAsync: 27 | """ Runs background thread that has an asyncio event loop.""" 28 | 29 | def __init__(self): 30 | """ Initializes. """ 31 | self._loop = None 32 | self._thread = None 33 | 34 | def _thread_and_loop(self) -> tuple[threading.Thread, AbstractEventLoop]: 35 | if self._loop is None: 36 | self._thread, self._loop = _asyncio_background_thread() 37 | self._thread.start() 38 | return self._thread, self._loop 39 | 40 | def call_soon_threadsafe(self, cb, *args, context = None): 41 | """ Fires callback in background thread.""" 42 | 43 | _, loop = self._thread_and_loop() 44 | return loop.call_soon_threadsafe(cb, *args, context=context) 45 | 46 | def run_async_coroutine(self, coroutine: Coroutine[Any, Any, T]) -> Future[T]: 47 | """ Runs an asynchronous coroutine in the visual thread. 48 | 49 | Args: 50 | coroutine: Coroutine to be run on visual thread. 51 | 52 | Returns: 53 | Future of coroutine. 54 | """ 55 | _, loop = self._thread_and_loop() 56 | future = asyncio.run_coroutine_threadsafe(coroutine, loop) 57 | return future 58 | 59 | @staticmethod 60 | async def async_task(coroutine: Coroutine[Any, Any, T]) -> Task[T]: 61 | """ Creates an asyncio task from coroutine. 62 | 63 | Args: 64 | coroutine: Coroutine within task. 65 | 66 | Returns: 67 | Asyncio task. 68 | """ 69 | task = asyncio.create_task(coroutine) 70 | return task 71 | 72 | @staticmethod 73 | async def print_all_tasks(): # pragma: no cover 74 | """Prints all tasks running in background thread loop (for debugging purposes).""" 75 | for task in asyncio.all_tasks(): 76 | print(task) 77 | -------------------------------------------------------------------------------- /guidance/_guidance.pyi: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import ( 3 | Any, 4 | Callable, 5 | Literal, 6 | TypeVar, 7 | Union, 8 | overload, 9 | ) 10 | from contextvars import ContextVar 11 | if sys.version_info >= (3, 10): 12 | from typing import ParamSpec, TypeAlias, Concatenate 13 | else: 14 | from typing_extensions import ParamSpec, TypeAlias, Concatenate 15 | 16 | from ._ast import RuleNode, Function 17 | from .models import Model 18 | 19 | _in_stateless_context: ContextVar[bool] 20 | 21 | P = ParamSpec("P") 22 | M: TypeAlias = Any # sort of Union[Model, GrammarNode]? 23 | R = TypeVar("R", bound = Union[Function, RuleNode]) 24 | GuidanceWrappable = Callable[Concatenate[M, P], M] 25 | GuidanceFunction = Callable[P, R] 26 | StatefulGuidanceFunction = GuidanceFunction[P, Function] 27 | StatelessGuidanceFunction = GuidanceFunction[P, RuleNode] 28 | 29 | @overload 30 | def guidance( 31 | f: GuidanceWrappable[P], 32 | *, 33 | stateless: Literal[False] = False, 34 | cache: bool = ..., 35 | dedent: bool = ..., 36 | model: type[Model] = ..., 37 | ) -> StatefulGuidanceFunction[P]: 38 | ... 
39 | 40 | 41 | @overload 42 | def guidance( 43 | f: None = None, 44 | *, 45 | stateless: Literal[False] = False, 46 | cache: bool = ..., 47 | dedent: bool = ..., 48 | model: type[Model] = ..., 49 | ) -> Callable[[GuidanceWrappable[P]], StatefulGuidanceFunction[P]]: 50 | ... 51 | 52 | 53 | @overload 54 | def guidance( 55 | f: GuidanceWrappable[P], 56 | *, 57 | stateless: Literal[True], 58 | cache: bool = ..., 59 | dedent: bool = ..., 60 | model: type[Model] = ..., 61 | ) -> StatelessGuidanceFunction[P]: 62 | ... 63 | 64 | 65 | @overload 66 | def guidance( 67 | f: None = None, 68 | *, 69 | stateless: Literal[True], 70 | cache: bool = ..., 71 | dedent: bool = ..., 72 | model: type[Model] = ..., 73 | ) -> Callable[[GuidanceWrappable[P]], StatelessGuidanceFunction[P]]: 74 | ... 75 | 76 | 77 | @overload 78 | def guidance( 79 | f: GuidanceWrappable[P], 80 | *, 81 | stateless: Callable[..., bool], 82 | cache: bool = ..., 83 | dedent: bool = ..., 84 | model: type[Model] = ..., 85 | ) -> GuidanceFunction[P, Union[Function, RuleNode]]: 86 | ... 87 | 88 | 89 | @overload 90 | def guidance( 91 | f: None = None, 92 | *, 93 | stateless: Callable[..., bool], 94 | cache: bool = ..., 95 | dedent: bool = ..., 96 | model: type[Model] = ..., 97 | ) -> Callable[[GuidanceWrappable[P]], GuidanceFunction[P, Union[Function, RuleNode]]]: 98 | ... 99 | -------------------------------------------------------------------------------- /guidance/bench/__init__.py: -------------------------------------------------------------------------------- 1 | """Elementary benchmarking for `guidance` development purposes. 2 | 3 | `guidance` lives in a fast paced LLM environment, has complex dependencies and is tricky to implement. 4 | These benchmarks are designed to focus on key use cases, where regressions can create havoc. 5 | 6 | General guidelines: 7 | - Simplicity first, then customization - reproducibility by the community is encouraged 8 | - Everything takes forever - allow a pathway to scale horizontally 9 | - Goalposts shift - some of the code for benchmarking will change frequently and that's okay 10 | 11 | Implementation: 12 | 13 | The `bench` function is provided for no frills benchmarking that is designated for 14 | automated testing. 15 | 16 | For customization, we provide a notebook demonstration of how to run custom benchmarks 17 | that are near mirror versions of what is available in the `bench` function provided. 18 | 19 | Not implemented yet, but we intend to provide an avenue of running the benchmarks via 20 | docker containers that have GPU resourcing to scale horizontally. 
21 | """ 22 | 23 | from guidance.bench._powerlift import ( 24 | retrieve_langchain, 25 | langchain_chat_extract_runner, 26 | langchain_chat_extract_filter_template, 27 | ) 28 | from guidance.bench._api import bench 29 | 30 | # TODO(nopdive): Enable docker containers to execute benchmarking easily 31 | -------------------------------------------------------------------------------- /guidance/bench/_api.py: -------------------------------------------------------------------------------- 1 | """User facing API for benchmarking.""" 2 | 3 | from typing import List, Optional, Tuple, Union 4 | from pathlib import Path 5 | from guidance.bench._utils import lib_bench_dir 6 | 7 | """Available models to run benchmark against.""" 8 | AVAILABLE_MODELS = [ 9 | "guidance-mistral-7b-instruct", 10 | "base-mistral-7b-instruct", 11 | "guidance-phi-3-mini-4k-instruct", 12 | "base-phi-3-mini-4k-instruct", 13 | "guidance-llama2-7b-32k-instruct", 14 | "base-llama2-7b-32k-instruct", 15 | ] 16 | 17 | 18 | def bench( 19 | db_url: str, 20 | experiment_name: str, 21 | models: List[str] = AVAILABLE_MODELS, 22 | force_recreate: bool = False, 23 | timeout: int = 3600, 24 | cache_dir: Union[str, Path] = lib_bench_dir() / "cache", 25 | debug_mode: bool = False, 26 | ) -> Tuple[object, object]: 27 | """Benchmarks guidance against preset tasks. 28 | 29 | This runs on a single machine, one trial at a time. 30 | To run this the first time you will need API_LANGCHAIN_KEY set as an environment variable. 31 | 32 | Args: 33 | db_url (str): Database connection string. 34 | experiment_name (str): Name of experiment to create / run. 35 | models (List[str], optional): Models to benchmark. Defaults to AVAILABLE_MODELS. 36 | force_recreate (bool, optional): Recreate the database before benchmarking. Defaults to False. 37 | timeout (int, optional): Max execution time per trial. Defaults to 3600. 38 | cache_dir (Union[str, Path], optional): Cache to store external datasets. Defaults to lib_bench_dir() / "cache". 39 | debug_mode (bool): Set this when you require a debugger to step line by line in the trial_runner. 40 | 41 | Returns: 42 | Tuple[object, object]: (status, results) data frames where status relates to trials, results are wide form aggregates of each model. 43 | """ 44 | from guidance.bench._powerlift import bench as inner_bench 45 | 46 | status_df, result_df = inner_bench( 47 | db_url, experiment_name, models, force_recreate, timeout, cache_dir, debug_mode 48 | ) 49 | return status_df, result_df 50 | -------------------------------------------------------------------------------- /guidance/bench/_utils.py: -------------------------------------------------------------------------------- 1 | """Shared utility functions for module.""" 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | def lib_bench_dir() -> Path: 7 | """Library directory to store configurations and cached assets for benchmarking. 8 | 9 | If the library directory does not exist, it is created as a side effect. 10 | 11 | The library bench directory path can also be set via env var `GUIDANCE_BENCH_DIR`. 12 | 13 | Returns: 14 | Path: Library's directory path for benchmarking. 
15 | """ 16 | 17 | env_lib_path = os.environ.get("GUIDANCE_BENCH_DIR", None) 18 | if env_lib_path is None: 19 | lib_path = Path.home() / ".guidance-bench" 20 | else: 21 | lib_path = Path(env_lib_path) 22 | Path.mkdir(lib_path, parents=True, exist_ok=True) 23 | 24 | return lib_path 25 | -------------------------------------------------------------------------------- /guidance/library/__init__.py: -------------------------------------------------------------------------------- 1 | # import functions that can be called directly 2 | from ._gen import gen, call_tool, regex 3 | from ._image import image, gen_image 4 | from ._audio import audio, gen_audio 5 | from ._video import video, gen_video 6 | from ._capture import capture 7 | 8 | # core grammar functions 9 | from .._grammar import select 10 | from .._grammar import with_temperature 11 | from .._grammar import string 12 | from .._grammar import token_limit 13 | 14 | # context blocks 15 | from ._block import block 16 | from ._role import role, system, assistant, user #, function, instruction, indent_roles 17 | 18 | # from ..models._model import context_free 19 | 20 | # stateless library functions 21 | from ._sequences import one_or_more, zero_or_more, at_most_n_repeats, exactly_n_repeats, sequence 22 | from ._substring import substring 23 | from ._optional import optional 24 | from ._tool import Tool 25 | from ._json import json 26 | from ._ebnf import lark, gbnf_to_lark 27 | -------------------------------------------------------------------------------- /guidance/library/_audio.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import typing 3 | import base64 4 | 5 | from .._guidance import guidance 6 | from .._utils import bytes_from 7 | from .._ast import AudioBlob, GenAudio 8 | 9 | 10 | @guidance 11 | def audio(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = True): 12 | bytes_data = bytes_from(src, allow_local=allow_local) 13 | base64_string = base64.b64encode(bytes_data).decode('utf-8') 14 | lm += AudioBlob(data=base64_string) 15 | return lm 16 | 17 | 18 | @guidance 19 | def gen_audio(lm, **kwargs): 20 | return lm + GenAudio(kwargs=kwargs) 21 | -------------------------------------------------------------------------------- /guidance/library/_block.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Optional, Union 3 | from .._ast import ASTNode, Function 4 | from ..models._base._model import _active_blocks 5 | from .._guidance import _in_stateless_context 6 | 7 | class Block: 8 | def __init__(self, name: Optional[str], opener: Union[str, Function, ASTNode], closer: Union[str, Function, ASTNode]): 9 | self.name = name 10 | self.opener = opener 11 | self.closer = closer 12 | 13 | 14 | @contextmanager 15 | def block(name=None, opener=None, closer=None): 16 | if _in_stateless_context.get(): 17 | raise RuntimeError("Cannot use roles or other blocks when stateless=True") 18 | current_blocks = _active_blocks.get() 19 | new_block = Block(name=name, opener=opener, closer=closer) 20 | token = _active_blocks.set(current_blocks + (new_block,)) 21 | try: 22 | yield 23 | finally: 24 | _active_blocks.reset(token) 25 | -------------------------------------------------------------------------------- /guidance/library/_capture.py: -------------------------------------------------------------------------------- 1 | from .._guidance import guidance 2 | from .._grammar import 
capture as grammar_capture, GrammarNode 3 | from ._block import block 4 | 5 | @guidance(stateless=lambda *args, **kwargs: isinstance(args[0], GrammarNode)) 6 | def capture(lm, value, name): 7 | if isinstance(value, GrammarNode): 8 | return lm + grammar_capture(value, name) 9 | else: 10 | with block(name): 11 | lm += value 12 | return lm 13 | -------------------------------------------------------------------------------- /guidance/library/_ebnf.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from llguidance.gbnf_to_lark import gbnf_to_lark as _gbnf_to_lark 4 | 5 | from .._ast import RuleNode, LarkNode, GrammarNode 6 | from .._grammar import capture, token_limit, with_temperature 7 | 8 | 9 | def lark( 10 | lark_grammar: str, 11 | *, 12 | name: Optional[str] = None, 13 | temperature: Optional[float] = None, 14 | max_tokens: Optional[int] = None, 15 | ) -> GrammarNode: 16 | """ 17 | Builds a guidance grammar from (a variant of) the EBNF syntax used by the Lark parsing toolkit. 18 | 19 | See documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md for more 20 | details. 21 | """ 22 | node = RuleNode( 23 | name=name or "lark", 24 | value=LarkNode(lark_grammar=lark_grammar) 25 | ) 26 | if temperature is not None: 27 | node = with_temperature(node, temperature) 28 | if max_tokens is not None: 29 | node = token_limit(node, max_tokens) 30 | if name is not None: 31 | node = capture(node, name) 32 | 33 | return node 34 | 35 | 36 | def gbnf_to_lark(gbnf_grammar: str) -> str: 37 | """ 38 | Converts a GBNF (llama.cpp) grammar to Lark(-like) syntax. This is a best-effort 39 | conversion and may not work for all grammars. We recommend using this function 40 | as a starting point and then manually editing the resulting Lark grammar to suit 41 | your needs. 42 | 43 | See documentation at https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md 44 | for more information on the output syntax's semantics. 45 | """ 46 | return _gbnf_to_lark(gbnf_grammar) 47 | -------------------------------------------------------------------------------- /guidance/library/_image.py: -------------------------------------------------------------------------------- 1 | import importlib.resources 2 | import pathlib 3 | import typing 4 | import base64 5 | import re 6 | 7 | from .._guidance import guidance 8 | from .._utils import bytes_from 9 | from .._ast import ImageBlob, ImageUrl 10 | from ..trace._trace import ImageOutput 11 | 12 | 13 | @guidance 14 | def image(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = True): 15 | if isinstance(src, str) and re.match(r"^(?!file://)[^:/]+://", src): 16 | lm += ImageUrl(url=src) 17 | else: 18 | bytes_data = bytes_from(src, allow_local=allow_local) 19 | base64_string = base64.b64encode(bytes_data).decode('utf-8') 20 | lm += ImageBlob(data=base64_string) 21 | return lm 22 | 23 | 24 | @guidance 25 | def gen_image(lm): 26 | # TODO(nopdive): Mock for testing. Remove all of this code later. 
27 | with importlib.resources.files("guidance").joinpath("resources/sample_image.png").open("rb") as f: 28 | bytes_data = f.read() 29 | base64_string = base64.b64encode(bytes_data).decode('utf-8') 30 | lm += ImageOutput(value=base64_string, is_input=False) 31 | return lm 32 | -------------------------------------------------------------------------------- /guidance/library/_optional.py: -------------------------------------------------------------------------------- 1 | from .._guidance import guidance 2 | from .._grammar import repeat 3 | 4 | @guidance(stateless=True) 5 | def optional(lm, value): 6 | return lm + repeat(value, 0, 1) 7 | -------------------------------------------------------------------------------- /guidance/library/_pydantic.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Dict, Type, Union 3 | 4 | import pydantic 5 | 6 | 7 | class GenerateJsonSchemaSafe(pydantic.json_schema.GenerateJsonSchema): 8 | """ 9 | Subclass pydantic's GenerateJsonSchema to catch pydantic schemas that will not 10 | translate properly to json schemas used for generation. 11 | 12 | In particular, JSON schemas do not offer a way to specify "key type", 13 | so we need to raise an exception if users attempt to specify non-string 14 | keys through pydantic. Otherwise, they may get unexpected output from 15 | model generation. 16 | """ 17 | 18 | def generate_inner(self, schema): 19 | if schema["type"] == "dict": 20 | key_type = schema["keys_schema"]["type"] 21 | if key_type != "str": 22 | raise TypeError( 23 | f"JSON does not support non-string keys, got type {key_type}" 24 | ) 25 | return super().generate_inner(schema) 26 | 27 | 28 | def pydantic_to_json_schema( 29 | schema: Union[Type["pydantic.BaseModel"], "pydantic.TypeAdapter"] 30 | ) -> Dict[str, Any]: 31 | if inspect.isclass(schema) and issubclass(schema, pydantic.BaseModel): 32 | return schema.model_json_schema(schema_generator=GenerateJsonSchemaSafe) 33 | if isinstance(schema, pydantic.TypeAdapter): 34 | return schema.json_schema(schema_generator=GenerateJsonSchemaSafe) 35 | raise TypeError(f"Cannot generate json schema from type {type(schema)}") 36 | -------------------------------------------------------------------------------- /guidance/library/_role.py: -------------------------------------------------------------------------------- 1 | from contextlib import AbstractContextManager 2 | from .._ast import RoleEnd, RoleStart 3 | from ._block import block 4 | 5 | 6 | # TODO HN: Add a docstring to better describe arbitrary role functions 7 | def role(role: str) -> AbstractContextManager: 8 | return block( 9 | name=None, 10 | opener=RoleStart(role), 11 | closer=RoleEnd(role), 12 | ) 13 | 14 | 15 | def system() -> AbstractContextManager: 16 | """Indicate the 'system' prompt 17 | 18 | A convention has grown up around 'chat' APIs that 19 | prompts are split into three parts: system, user 20 | and assistant. 21 | This indicates the start of a 'system' block, which 22 | provides background information to the LLM. 23 | 24 | >>> with system(): 25 | >>> lm += "A system prompt" 26 | 27 | """ 28 | return role("system") 29 | 30 | 31 | def user() -> AbstractContextManager: 32 | """Indicate the 'user' prompt 33 | 34 | A convention has grown up around 'chat' APIs that 35 | prompts are split into three parts: system, user 36 | and assistant. 37 | This indicates the start of a 'user' block, which 38 | provides input to the LLM from the user. 
39 | 40 | >>> with user(): 41 | >>> lm += "What the user said" 42 | 43 | """ 44 | return role("user") 45 | 46 | 47 | def assistant() -> AbstractContextManager: 48 | """Indicate the 'assistant' prompt 49 | 50 | A convention has grown up around 'chat' APIs that 51 | prompts are split into three parts: system, user 52 | and assistant. 53 | This indicates the start of an 'assistant' block, which 54 | marks LLM response (or where the LLM will generate 55 | the next response). 56 | 57 | >>> with assistant(): 58 | >>> lm += gen(name="model_output", max_tokens=20) 59 | 60 | """ 61 | return role("assistant") 62 | -------------------------------------------------------------------------------- /guidance/library/_sequences.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from .._grammar import repeat 4 | from .._guidance import guidance 5 | 6 | 7 | @guidance(stateless=True) 8 | def exactly_n_repeats(model, value, n_repeats: int): 9 | return model + repeat(value, min=n_repeats, max=n_repeats) 10 | 11 | 12 | @guidance(stateless=True) 13 | def at_most_n_repeats(model, value, n_repeats: int): 14 | return model + repeat(value, min=0, max=n_repeats) 15 | 16 | 17 | @guidance(stateless=True) 18 | def sequence(model, value, min_length: int = 0, max_length: Union[int, None] = None): 19 | # Just an alias for repeat for now -- TODO: remove? 20 | return model + repeat(value, min=min_length, max=max_length) 21 | 22 | 23 | @guidance(stateless=True) 24 | def one_or_more(model, value): 25 | return model + repeat(value, min=1) 26 | 27 | 28 | @guidance(stateless=True) 29 | def zero_or_more(model, value): 30 | return model + repeat(value, min=0) 31 | -------------------------------------------------------------------------------- /guidance/library/_subgrammar.py: -------------------------------------------------------------------------------- 1 | from .._ast import GrammarNode, RuleNode 2 | from .._grammar import subgrammar, regex 3 | 4 | __all__ = ["subgrammar", "regex", "as_regular_grammar", "lexeme"] 5 | 6 | def as_regular_grammar(node: GrammarNode, lexeme=False): 7 | # TODO: Remove this assertion-only check? 
8 | if isinstance(node, RuleNode):
9 | rule = node
10 | else:
11 | rule = RuleNode("dummy", node)
12 | assert rule.is_allowed_in_lark_terminal
13 | return node
14 | 
15 | def lexeme(body_regex: str, json_string: bool = False):
16 | if json_string:
17 | raise NotImplementedError("JSON strings are not supported")
18 | return regex(body_regex)
19 | 
--------------------------------------------------------------------------------
/guidance/library/_substring.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Callable, Iterable, Literal, Optional, Union
3 | 
4 | from .._ast import RuleNode, SubstringNode
5 | 
6 | 
7 | def chunk_on_word(text: str) -> list[str]:
8 | return re.findall(r"(\s+|\w+|[^\s\w]+)", text)
9 | 
10 | 
11 | def substring(
12 | target_string: str,
13 | *,
14 | chunk: Union[Literal["word", "character"], Callable[[str], Iterable[str]]] = "word",
15 | name: Optional[str] = None,
16 | ) -> RuleNode:
17 | chunks: Iterable[str]
18 | if chunk == "word":
19 | chunks = chunk_on_word(target_string)
20 | elif chunk == "character":
21 | chunks = tuple(target_string)
22 | elif callable(chunk):
23 | # materialize the chunks first, since chunk() may return a one-shot iterator
24 | chunks = tuple(chunk(target_string))
25 | if "".join(chunks) != target_string:
26 | raise ValueError(
27 | "chunk function must return a sequence of strings that can be joined to form the target string"
28 | )
29 | else:
30 | raise ValueError(f"Invalid `chunk` value: {chunk!r}. Expected 'word', 'character', or a function.")
31 | 
32 | return RuleNode(
33 | name=name or "substring",
34 | value=SubstringNode(tuple(chunks)),
35 | capture=name,
36 | )
--------------------------------------------------------------------------------
/guidance/library/_tool.py:
--------------------------------------------------------------------------------
1 | from .._guidance import guidance
2 | from .._grammar import select
3 | from ._optional import optional
4 | from ._sequences import zero_or_more
5 | from ._subgrammar import lexeme, subgrammar
6 | 
7 | class Tool:
8 | def __init__(self, call_grammar=None, tool_call=None, callable=None):
9 | # call_grammar specifies how the tool can be called. Crucially, it has to capture the args in the variable 'tool_args'
10 | # tool_call is a guidance function that actually calls the tool, and returns an lm object with whatever outputs it wants
11 | # callable: a guidance function or regular callable; will be converted to a grammar
12 | # TODO: hidden is not working yet
13 | first_option = (call_grammar is not None) and (tool_call is not None)
14 | second_option = callable is not None
15 | # exactly one of the two options must be provided
16 | if first_option == second_option:
17 | raise Exception(
18 | "Must pass either (call_grammar, tool_call) or callable, but not both or neither"
19 | )
20 | if second_option:
21 | call_grammar, tool_call = fn_to_grammar_call(callable)
22 | self.call_grammar = call_grammar
23 | self.tool_call = tool_call
24 | 
25 | 
26 | def basic_func_grammar(name):
27 | arg = lexeme(r"[^,=)]+")
28 | kwarg = arg + "=" + arg
29 | args = arg + zero_or_more("," + arg)
30 | kwargs = kwarg + zero_or_more("," + kwarg)
31 | 
32 | obj = name + "("
33 | obj += subgrammar(
34 | name="tool_args",
35 | body=optional(
36 | select([
37 | args,
38 | kwargs,
39 | args + "," + kwargs,
40 | ])
41 | ),
42 | skip_regex=r" *"
43 | )
44 | obj += ")"
45 | return obj
46 | 
47 | 
48 | def fn_to_grammar_call(callable):
49 | # TODO later: validate the call.
Here is code to get required and optional args of 'guidance_fn': 50 | # name = guidance_fn.__name__ 51 | # required_args = [] 52 | # optional_args = [] 53 | # sig = inspect.signature(guidance_fn) 54 | # for i, x in enumerate(sig.parameters.values()): 55 | # if i == 0: 56 | # continue 57 | # if x.default is x.empty: 58 | # required_args.append(x.name) 59 | # else: 60 | # optional_args.append(x.name) 61 | name = callable.__name__ 62 | call_grammar = basic_func_grammar(name) 63 | 64 | @guidance(dedent=False) 65 | def basic_tool_call(lm): 66 | args = lm["tool_args"] 67 | args = args.split(",") 68 | positional = [x.strip() for x in args if "=" not in x] 69 | kwargs = dict([tuple(x.strip().split("=")) for x in args if "=" in x]) 70 | lm += callable(*positional, **kwargs) 71 | return lm 72 | 73 | return call_grammar, basic_tool_call 74 | -------------------------------------------------------------------------------- /guidance/library/_video.py: -------------------------------------------------------------------------------- 1 | import importlib.resources 2 | import pathlib 3 | import typing 4 | import base64 5 | 6 | 7 | from .._guidance import guidance 8 | from .._utils import bytes_from 9 | # from ..trace._trace import VideoInput 10 | from ..trace._trace import VideoOutput 11 | 12 | 13 | @guidance 14 | def video(lm, src: typing.Union[str, pathlib.Path, bytes], allow_local: bool = True): 15 | # TODO(nopdive): Mock for testing. Remove all of this code later. 16 | bytes_data = bytes_from(src, allow_local=allow_local) 17 | base64_string = base64.b64encode(bytes_data).decode('utf-8') 18 | lm += VideoOutput(value=base64_string, is_input=True) 19 | # lm += VideoInput(value=base64_string) 20 | return lm 21 | 22 | 23 | @guidance 24 | def gen_video(lm): 25 | # TODO(nopdive): Mock for testing. Remove all of this code later. 
26 | with importlib.resources.files("guidance").joinpath("resources/sample_video.mp4").open("rb") as f:
27 | bytes_data = f.read()
28 | base64_string = base64.b64encode(bytes_data).decode('utf-8')
29 | lm += VideoOutput(value=base64_string, is_input=False)
30 | return lm
31 | 
--------------------------------------------------------------------------------
/guidance/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | """ Metrics that arise from both language models and their execution environment."""
2 | 
3 | from ._metrics import ALL_METRICS, METRICS_TOPIC
4 | from ._metrics import MonitoringMetric, PostExecMetrics, Monitor, PeriodicMetricsGenerator
--------------------------------------------------------------------------------
/guidance/models/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Model
2 | 
3 | # from ._engine import Instruct, Chat
4 | 
5 | # local models
6 | from ._transformers import Transformers, TransformersTokenizer
7 | from ._llama_cpp import LlamaCpp
8 | from ._mock import Mock # , MockChat
9 | 
10 | # from .vertexai._vertexai import (
11 | # VertexAI,
12 | # VertexAIChat,
13 | # VertexAICompletion,
14 | # VertexAIInstruct,
15 | # )
16 | # from ._azure_openai import (
17 | # AzureOpenAI,
18 | # )
19 | # from ._azureai_studio import AzureAIStudioChat
20 | from ._openai import OpenAI
21 | 
22 | # from ._lite_llm import LiteLLM, LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
23 | # from ._cohere import Cohere, CohereCompletion, CohereInstruct
24 | # from ._anthropic import Anthropic
25 | # from ._googleai import GoogleAI, GoogleAIChat
26 | # from ._togetherai import (
27 | # TogetherAI,
28 | # TogetherAIChat,
29 | # TogetherAIInstruct,
30 | # TogetherAICompletion,
31 | # )
32 | from .
import experimental
33 | 
--------------------------------------------------------------------------------
/guidance/models/_base/__init__.py:
--------------------------------------------------------------------------------
1 | from ._interpreter import Interpreter
2 | from ._model import Model
3 | from ._state import State
4 | 
5 | __all__ = [
6 | "Model",
7 | "role",
8 | "State",
9 | "Message",
10 | "Interpreter",
11 | "ASTNode",
12 | "ContentChunk",
13 | "MessageChunk",
14 | ]
15 | 
--------------------------------------------------------------------------------
/guidance/models/_base/_state.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, TypedDict, Union
3 | 
4 | from ...trace import CaptureOutput
5 | 
6 | 
7 | class CaptureVar(TypedDict):
8 | value: str
9 | log_prob: Optional[float]
10 | 
11 | 
12 | class State(ABC):
13 | def __init__(self) -> None:
14 | self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {}
15 | self.active_role: Optional[str] = None
16 | 
17 | @abstractmethod
18 | def __str__(self) -> str:
19 | pass
20 | 
21 | def apply_capture(
22 | self, name: str, value: Optional[str], log_prob: Optional[float] = None, is_append: bool = False
23 | ) -> CaptureOutput:
24 | if value is None:
25 | # A "reset" signal
26 | self.captures.pop(name)
27 | else:
28 | var = CaptureVar(value=value, log_prob=log_prob)
29 | if is_append:
30 | vars = self.captures.get(name, [])
31 | if not isinstance(vars, list):
32 | vars = [vars]
33 | vars.append(var)
34 | self.captures[name] = vars
35 | else:
36 | self.captures[name] = var
37 | 
38 | return CaptureOutput(
39 | name=name,
40 | value=value,
41 | log_probs=log_prob or float("nan"),
42 | is_append=is_append,
43 | )
44 | 
--------------------------------------------------------------------------------
/guidance/models/_byte_tokenizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ._engine import Tokenizer
3 | from ._engine._tokenizer import TokenizerWrappable
4 | from ..chat import load_template_class
5 | from typing import List
6 | 
7 | class ByteTokenizer(Tokenizer):
8 | def __init__(self, chat_template=None):
9 | # directly map integer values to byte strings
10 | all_bytes = [bytes([i]) for i in range(256)]
11 | bos = b"<s>"
12 | tokens = np.array(all_bytes + [bos], dtype="object")
13 | ll_tokenizer = TokenizerWrappable(
14 | eos_token_id=256,
15 | bos_token_id=256,
16 | tokens=tokens,
17 | special_token_ids=[],
18 | # ENCODE MUST BE OVERRIDDEN
19 | encode_callable=self.encode,
20 | ).as_ll_tokenizer()
21 | 
22 | super().__init__(
23 | ll_tokenizer=ll_tokenizer,
24 | chat_template=chat_template,
25 | bos_token_id=256,
26 | )
27 | 
28 | def encode(self, byte_string: bytes) -> List[int]:
29 | """Returns a list of tokens that represent the given byte string."""
30 | if isinstance(byte_string, str):
31 | byte_string = byte_string.encode("utf8")
32 | i = 0
33 | result = []
34 | while i < len(byte_string):
35 | if byte_string[i:i+3] == b'<s>':
36 | result.append(256)
37 | i += 3 # Skip the next two characters as part of '<s>'
38 | else:
39 | result.append(byte_string[i])
40 | i += 1
41 | return result
42 | 
--------------------------------------------------------------------------------
/guidance/models/_engine/__init__.py:
--------------------------------------------------------------------------------
1 | from ._tokenizer import Tokenizer # isort:skip
2 | from ._interpreter import
EngineInterpreter, Llama3VisionInterpreter, Phi3VisionInterpreter 3 | from ._engine import Engine 4 | from ._state import EngineState 5 | 6 | __all__ = [ 7 | "Tokenizer", 8 | "Engine", 9 | "EngineInterpreter", 10 | "EngineState", 11 | "Llama3VisionInterpreter", 12 | "Phi3VisionInterpreter", 13 | ] 14 | -------------------------------------------------------------------------------- /guidance/models/_engine/_state.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .._base import State 4 | 5 | 6 | class EngineState(State): 7 | def __init__(self) -> None: 8 | super().__init__() 9 | self.prompt: str = "" 10 | self.images: list[Any] = [] 11 | self.audio: list[Any] = [] 12 | self.videos: list[Any] = [] 13 | 14 | def __str__(self) -> str: 15 | return self.prompt 16 | -------------------------------------------------------------------------------- /guidance/models/_openai.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | from ._base import Model 5 | from ._openai_base import ( 6 | BaseOpenAIInterpreter, 7 | Message, 8 | OpenAIAudioMixin, 9 | OpenAIImageMixin, 10 | OpenAIJSONMixin, 11 | OpenAIRegexMixin, 12 | OpenAIRuleMixin, 13 | ) 14 | 15 | 16 | class OpenAIInterpreter(OpenAIRuleMixin, OpenAIJSONMixin, OpenAIRegexMixin, BaseOpenAIInterpreter): 17 | def __init__( 18 | self, 19 | model: str, 20 | api_key: Optional[str] = None, 21 | **kwargs, 22 | ): 23 | try: 24 | import openai 25 | except ImportError: 26 | raise Exception( 27 | "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" 28 | ) 29 | client = openai.OpenAI(api_key=api_key, **kwargs) 30 | super().__init__(model=model, client=client) 31 | 32 | 33 | class OpenAI(Model): 34 | def __init__( 35 | self, 36 | model: str, 37 | echo: bool = True, 38 | *, 39 | api_key: Optional[str] = None, 40 | **kwargs, 41 | ): 42 | """Build a new OpenAI model object that represents a model in a given state. 43 | 44 | Parameters 45 | ---------- 46 | model : str 47 | The name of the OpenAI model to use (e.g. gpt-4o-mini). 48 | echo : bool 49 | If true the final result of creating this model state will be displayed (as HTML in a notebook). 50 | api_key : None or str 51 | The OpenAI API key to use for remote requests, passed directly to the `openai.OpenAI` constructor. 52 | 53 | **kwargs : 54 | All extra keyword arguments are passed directly to the `openai.OpenAI` constructor. Commonly used argument 55 | names include `base_url` and `organization` 56 | """ 57 | 58 | if "audio-preview" in model: 59 | interpreter_cls = type( 60 | "OpenAIAudioInterpreter", (OpenAIAudioMixin, OpenAIInterpreter), {} 61 | ) 62 | elif model.startswith("gpt-4o") or model.startswith("o1"): 63 | interpreter_cls = type( 64 | "OpenAIImageInterpreter", (OpenAIImageMixin, OpenAIInterpreter), {} 65 | ) 66 | else: 67 | interpreter_cls = OpenAIInterpreter 68 | 69 | super().__init__(interpreter=interpreter_cls(model, api_key=api_key, **kwargs), echo=echo) 70 | -------------------------------------------------------------------------------- /guidance/models/broken_models/README.MD: -------------------------------------------------------------------------------- 1 | These model files use an older version of guidance's internal API design, and need to be updated. They're kept here commented in the codebase solely as reference documentation to help with the migration. 
They cannot and should not be imported from this repository. -------------------------------------------------------------------------------- /guidance/models/broken_models/_cohere.py: -------------------------------------------------------------------------------- 1 | # from ._lite_llm import LiteLLMEngine, LiteLLM, LiteLLMCompletion, LiteLLMInstruct 2 | 3 | 4 | # class Cohere(LiteLLM): 5 | # def __init__( 6 | # self, 7 | # model, 8 | # tokenizer=None, 9 | # echo=True, 10 | # timeout=0.5, 11 | # compute_log_probs=False, 12 | # max_streaming_tokens=1000, 13 | # ): 14 | # """Build a new Anthropic model object that represents a model in a given state.""" 15 | # try: 16 | # import tokenizers 17 | # except ModuleNotFoundError: 18 | # raise Exception( 19 | # "Please install the HuggingFace tokenizers package using `pip install tokenizers -U` in order to use guidance.models.Cohere!" 20 | # ) 21 | 22 | # # get the tokenizer 23 | # if tokenizer is None: 24 | # try: 25 | # tokenizer = tokenizers.Tokenizer.from_pretrained("Cohere/" + model) 26 | # except: 27 | # tokenizer = tokenizers.Tokenizer.from_pretrained( 28 | # "Cohere/command-nightly" 29 | # ) 30 | 31 | # super().__init__( 32 | # model, 33 | # tokenizer=tokenizer, 34 | # echo=echo, 35 | # timeout=timeout, 36 | # max_streaming_tokens=max_streaming_tokens, 37 | # compute_log_probs=compute_log_probs, 38 | # ) 39 | 40 | 41 | # class CohereCompletion(Cohere, LiteLLMCompletion): 42 | # pass 43 | 44 | 45 | # class CohereInstruct(Cohere, LiteLLMInstruct): 46 | # pass 47 | -------------------------------------------------------------------------------- /guidance/models/broken_models/_togetherai.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # from .._engine._engine import Chat, Instruct 3 | # from .._openai import ( 4 | # OpenAI, 5 | # OpenAIEngine, 6 | # ) 7 | # from ..transformers import TransformersTokenizer 8 | 9 | 10 | # class TogetherAI(OpenAI): 11 | # def __init__( 12 | # self, 13 | # model, 14 | # tokenizer=None, 15 | # echo=True, 16 | # api_key=None, 17 | # max_streaming_tokens=1000, 18 | # timeout=0.5, 19 | # compute_log_probs=False, 20 | # engine_class=None, 21 | # **kwargs, 22 | # ): 23 | # """ 24 | # Build a new TogetherAI model object that represents a model in a given state. 
25 | # """ 26 | 27 | # tokenizer = TransformersTokenizer( 28 | # model=model, tokenizer=tokenizer, ignore_bos_token=True 29 | # ) 30 | 31 | # # Default base_url is the together.ai endpoint 32 | # if not "base_url" in kwargs: 33 | # kwargs["base_url"] = "https://api.together.xyz" 34 | # # TogetherAI uses TOGETHERAI_API_KEY env value instead of OPENAI_API_KEY 35 | # # We pass explicitly to avoid OpenAI class complaining about a missing key 36 | # if api_key is None: 37 | # api_key = os.environ.get("TOGETHERAI_API_KEY", None) 38 | # if api_key is None: 39 | # raise Exception( 40 | # "The api_key client option must be set either by passing api_key to the client or by setting the TOGETHERAI_API_KEY environment variable" 41 | # ) 42 | 43 | # if engine_class is None: 44 | # engine_map = { 45 | # TogetherAICompletion: OpenAIEngine, 46 | # TogetherAIInstruct: OpenAIEngine, 47 | # TogetherAIChat: OpenAIEngine, 48 | # TogetherAI: OpenAIEngine, 49 | # } 50 | # for k in engine_map: 51 | # if issubclass(self.__class__, k): 52 | # engine_class = engine_map[k] 53 | # break 54 | 55 | # super().__init__( 56 | # model, 57 | # tokenizer, 58 | # echo, 59 | # api_key, 60 | # max_streaming_tokens, 61 | # timeout, 62 | # compute_log_probs, 63 | # engine_class, 64 | # **kwargs, 65 | # ) 66 | 67 | 68 | # class TogetherAICompletion(TogetherAI): 69 | # pass 70 | 71 | 72 | # class TogetherAIInstruct(TogetherAI, Instruct): 73 | # """ 74 | # Utilizes chat endpoints to simulate a single instruction query 75 | # together.ai will format in correct prompt template for model on their end 76 | # """ 77 | 78 | # def get_role_start(self, name): 79 | # if name == "instruction": 80 | # return "<|im_start|>user\n" 81 | # else: 82 | # raise Exception( 83 | # f"The TogetherAIInstruct model does not know about the {name} role type!" 84 | # ) 85 | 86 | # def get_role_end(self, name): 87 | # if name == "instruction": 88 | # return "<|im_end|>" 89 | # else: 90 | # raise Exception( 91 | # f"The TogetherAIInstruct model does not know about the {name} role type!" 92 | # ) 93 | 94 | 95 | # class TogetherAIChat(TogetherAI, Chat): 96 | # pass 97 | -------------------------------------------------------------------------------- /guidance/models/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from ._vllm import VLLMModel -------------------------------------------------------------------------------- /guidance/models/experimental/_vllm.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional, TYPE_CHECKING 2 | import wave 3 | import base64 4 | from io import BytesIO 5 | 6 | if TYPE_CHECKING: 7 | from openai.types.chat import ChatCompletionChunk 8 | 9 | from ..._ast import GrammarNode 10 | from ...trace import OutputAttr, TextOutput 11 | from ...trace._trace import AudioOutput 12 | from .._openai_base import ( 13 | BaseOpenAIInterpreter, 14 | AssistantAudio, 15 | ) 16 | from .._base import Model 17 | 18 | 19 | class VLLMInterpreter(BaseOpenAIInterpreter): 20 | def __init__( 21 | self, 22 | model: str, 23 | base_url: Optional[str] = None, 24 | api_key: Optional[str] = None, 25 | **kwargs, 26 | ): 27 | try: 28 | import openai 29 | except ImportError: 30 | raise Exception( 31 | "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" 
32 | )
33 | client = openai.OpenAI(base_url=base_url, api_key=api_key, **kwargs)
34 | super().__init__(model=model, client=client)
35 | 
36 | def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]:
37 | buffer: str = ""
38 | for attr in self._run(
39 | extra_body=dict(
40 | guided_decoding_backend="guidance",
41 | guided_grammar=node.ll_grammar(),
42 | )
43 | ):
44 | if isinstance(attr, TextOutput):
45 | buffer += attr.value
46 | yield attr
47 | matches = node.match(
48 | buffer,
49 | raise_exceptions=False,
50 | # Turn off max_tokens since we don't have access to the tokenizer
51 | enforce_max_tokens=False,
52 | )
53 | if matches is None:
54 | # TODO: should probably raise...
55 | # raise ValueError("vLLM failed to constrain the grammar")
56 | pass
57 | else:
58 | for name, value in matches.captures.items():
59 | log_probs = matches.log_probs[name]
60 | if isinstance(value, list):
61 | assert isinstance(log_probs, list)
62 | assert len(value) == len(log_probs)
63 | for v, l in zip(value, log_probs):
64 | yield self.state.apply_capture(
65 | name=name, value=v, log_prob=l, is_append=True
66 | )
67 | else:
68 | yield self.state.apply_capture(
69 | name=name, value=value, log_prob=log_probs, is_append=False
70 | )
71 | 
72 | 
73 | class VLLMModel(Model):
74 | def __init__(self, model: str, echo=True, **kwargs):
75 | super().__init__(
76 | interpreter=VLLMInterpreter(model=model, **kwargs),
77 | echo=echo,
78 | )
79 | 
--------------------------------------------------------------------------------
/guidance/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/py.typed
--------------------------------------------------------------------------------
/guidance/registry/__init__.py:
--------------------------------------------------------------------------------
1 | """Registry module that contains singletons."""
2 | 
3 | from ._registry import get_renderer, set_renderer, get_trace_handler, get_exchange
4 | from ._registry import get_bg_async, get_monitor
--------------------------------------------------------------------------------
/guidance/registry/_registry.py:
--------------------------------------------------------------------------------
1 | # NOTE(nopdive): Consider moving singleton factories to registry static class.
2 | 3 | import threading 4 | 5 | from .._schema import GuidanceEngineMetrics 6 | from ..metrics import Monitor, PeriodicMetricsGenerator 7 | from ..trace import TraceHandler 8 | from ..visual import AutoRenderer, Renderer, TopicExchange 9 | from .._bg import BackgroundAsync 10 | 11 | _monitor_lock = threading.Lock() 12 | _monitor = None 13 | _periodic_metrics_gen = None 14 | 15 | _bg_async_lock = threading.Lock() 16 | _bg_async = None 17 | 18 | _exchange_lock = threading.Lock() 19 | _exchange = None 20 | 21 | _trace_handler_lock = threading.Lock() 22 | _trace_handler = None 23 | 24 | _renderer_lock = threading.Lock() 25 | _renderer = None 26 | 27 | 28 | def get_monitor() -> Monitor: 29 | global _monitor 30 | global _monitor_lock 31 | global _periodic_metrics_gen 32 | 33 | with _monitor_lock: 34 | if _monitor is None: 35 | _monitor = Monitor(GuidanceEngineMetrics()) 36 | _monitor.start() 37 | _periodic_metrics_gen = PeriodicMetricsGenerator(_monitor) 38 | _periodic_metrics_gen.start() 39 | return _monitor 40 | 41 | 42 | def get_bg_async() -> BackgroundAsync: 43 | global _bg_async 44 | global _bg_async_lock 45 | 46 | with _bg_async_lock: 47 | if _bg_async is None: 48 | _bg_async = BackgroundAsync() 49 | return _bg_async 50 | 51 | 52 | def get_exchange() -> TopicExchange: 53 | global _exchange 54 | global _exchange_lock 55 | 56 | with _exchange_lock: 57 | if _exchange is None: 58 | _exchange = TopicExchange() 59 | return _exchange 60 | 61 | 62 | def get_trace_handler() -> TraceHandler: 63 | global _trace_handler 64 | global _trace_handler_lock 65 | 66 | with _trace_handler_lock: 67 | if _trace_handler is None: 68 | _trace_handler = TraceHandler() 69 | return _trace_handler 70 | 71 | 72 | def get_renderer() -> Renderer: 73 | global _renderer 74 | global _renderer_lock 75 | 76 | with _renderer_lock: 77 | trace_handler = get_trace_handler() 78 | if _renderer is None: 79 | _renderer = AutoRenderer(trace_handler) 80 | return _renderer 81 | 82 | 83 | def set_renderer(renderer: Renderer) -> None: 84 | global _renderer 85 | global _renderer_lock 86 | 87 | with _renderer_lock: 88 | _renderer = renderer 89 | -------------------------------------------------------------------------------- /guidance/resources/sample_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/resources/sample_audio.wav -------------------------------------------------------------------------------- /guidance/resources/sample_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/resources/sample_image.png -------------------------------------------------------------------------------- /guidance/resources/sample_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/guidance/resources/sample_video.mp4 -------------------------------------------------------------------------------- /guidance/trace/__init__.py: -------------------------------------------------------------------------------- 1 | """Trace tree of inputs & outputs generated from a guidance program. 2 | 3 | The first implementation aims for simplicity. 4 | Once benchmark figures are out, we'll figure out what to optimize. 
5 | 6 | The most critical class is the trace handler. See its documentation for trace design & motivations. 7 | """ 8 | 9 | from ._trace import NodeAttr, InputAttr, OutputAttr, StatefulGuidanceInput, StatelessGuidanceInput 10 | from ._trace import LiteralInput, EmbeddedInput, ImageInput, RoleCloserInput, RoleOpenerInput 11 | from ._trace import TextOutput, ImageOutput, CaptureOutput, TraceNode, TraceHandler 12 | -------------------------------------------------------------------------------- /guidance/visual/__init__.py: -------------------------------------------------------------------------------- 1 | """UI and other visual UX considerations. 2 | 3 | Users should have few reasons to be accessing this module. 4 | """ 5 | 6 | from ._message import GuidanceMessage, TraceMessage, ResetDisplayMessage, ClientReadyMessage, ClientReadyAckMessage 7 | from ._message import ExecutionCompletedMessage, TokensMessage, MetricMessage, OutputRequestMessage 8 | from ._message import ExecutionStartedMessage 9 | from ._renderer import AutoRenderer, JupyterWidgetRenderer, Renderer 10 | from ._message import serialize_message, deserialize_message 11 | from ._trace import trace_node_to_str, display_trace_tree, trace_node_to_html 12 | from ._exchange import TopicExchange -------------------------------------------------------------------------------- /guidance/visual/_exchange.py: -------------------------------------------------------------------------------- 1 | """ Poor man's exchanges for routing messages. """ 2 | 3 | from collections import defaultdict 4 | from typing import Callable 5 | from ..visual import GuidanceMessage 6 | import re 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | DEFAULT_TOPIC = "/default" 12 | WILDCARD_PATTERN = r".*" 13 | 14 | 15 | class TopicExchange: 16 | """ Queue-less topic exchange for routing messages. 17 | 18 | This is not as comprehensive as a full distributed topic exchange. 19 | It is specific to a single process, with no queues and less generalized routing keys. 20 | """ 21 | 22 | def __init__(self): 23 | """ Initializes.""" 24 | self._observers = defaultdict(list) 25 | 26 | def subscribe(self, callback: Callable[[GuidanceMessage], None], topic_pat: str = WILDCARD_PATTERN) -> None: 27 | """ Subscribes to incoming messages. 28 | 29 | Args: 30 | callback: Callback to handle incoming messages. 31 | topic_pat: Topic to notify. 32 | """ 33 | logger.debug(f"EXCHANGE:pre_subscribe:{self._observers[topic_pat]}") 34 | self._observers[topic_pat].append(callback) 35 | logger.debug(f"EXCHANGE:post_subscribe:{self._observers[topic_pat]}") 36 | 37 | def unsubscribe(self, callback: Callable[[GuidanceMessage], None], topic_pat: str = WILDCARD_PATTERN) -> None: 38 | """ Unsubscribes from incoming messages. 39 | 40 | Args: 41 | callback: Callback to remove. 42 | topic_pat: Topic pattern. 43 | """ 44 | logger.debug(f"EXCHANGE:pre_unsubscribe:{self._observers[topic_pat]}") 45 | try: 46 | self._observers[topic_pat].remove(callback) 47 | except ValueError as _: 48 | logger.warning(f"EXCHANGE:cb at '{topic_pat}' already removed.") 49 | logger.debug(f"EXCHANGE:post_unsubscribe:{self._observers[topic_pat]}") 50 | 51 | if len(self._observers[topic_pat]) == 0: 52 | logger.debug(f"EXCHANGE:delete_entry:{topic_pat}") 53 | del self._observers[topic_pat] 54 | 55 | def publish(self, message: GuidanceMessage, topic: str = DEFAULT_TOPIC): 56 | """ Notifies all subscribers to topic pattern of an incoming message. 57 | 58 | Args: 59 | message: Incoming message. 
60 | topic: Topics to notify. 61 | """ 62 | # logger.debug(f"EXCHANGE:publish:{message}") 63 | for obs_topic_pat, observers in self._observers.items(): 64 | if re.match(obs_topic_pat, topic): 65 | for observer in observers: 66 | observer(message) 67 | 68 | 69 | __all__ = ["TopicExchange"] 70 | -------------------------------------------------------------------------------- /guidance/visual/_jupyter.py: -------------------------------------------------------------------------------- 1 | """ Jupyter specific utilities.""" 2 | 3 | 4 | from typing import Callable, Any, Tuple, Optional 5 | import logging 6 | from uuid import uuid4 7 | 8 | try: 9 | from IPython import get_ipython 10 | except ImportError: 11 | pass 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | IPythonCallback = Callable[[Any], None] 16 | 17 | 18 | def ipy_handle_event_once(cb: IPythonCallback, event_name: str) -> Tuple[Optional[IPythonCallback], str]: 19 | ipy = get_ipython() 20 | cell_session_id = str(uuid4()) 21 | 22 | if ipy is None: 23 | return None, "" 24 | 25 | def cb_closure(msg): 26 | cb(info=msg) 27 | ipy.events.unregister(event_name, cb_closure) 28 | ipy.events.register(event_name, cb_closure) 29 | 30 | return cb_closure, cell_session_id -------------------------------------------------------------------------------- /notebooks/unstable/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /packages/python/stitch/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = stitch/tests/* 3 | -------------------------------------------------------------------------------- /packages/python/stitch/.eslintignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | dist 3 | coverage 4 | **/*.d.ts 5 | tests -------------------------------------------------------------------------------- /packages/python/stitch/.eslintrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | extends: [ 3 | 'eslint:recommended', 4 | 'plugin:@typescript-eslint/eslint-recommended', 5 | 'plugin:@typescript-eslint/recommended', 6 | 'plugin:prettier/recommended' 7 | ], 8 | parser: '@typescript-eslint/parser', 9 | parserOptions: { 10 | project: 'tsconfig.eslint.json', 11 | sourceType: 'module' 12 | }, 13 | plugins: ['@typescript-eslint'], 14 | rules: { 15 | '@typescript-eslint/no-unused-vars': ['warn', { args: 'none' }], 16 | '@typescript-eslint/no-explicit-any': 'off', 17 | '@typescript-eslint/no-namespace': 'off', 18 | '@typescript-eslint/no-use-before-define': 'off', 19 | '@typescript-eslint/quotes': [ 20 | 'error', 21 | 'single', 22 | { avoidEscape: true, allowTemplateLiterals: false } 23 | ], 24 | curly: ['error', 'all'], 25 | eqeqeq: 'error', 26 | 'prefer-arrow-callback': 'error' 27 | } 28 | }; -------------------------------------------------------------------------------- /packages/python/stitch/.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: "*" 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest, windows-latest, macos-latest] 16 | python-version: ["3.7", "3.10"] 17 | steps: 18 | - name: Checkout 19 | uses: 
actions/checkout@v2 20 | 21 | - uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install -U codecov 26 | npm install -g codecov 27 | - name: Test the extension 28 | run: | 29 | python -m pip install --upgrade -v -e ".[test, examples, docs]" 30 | python -m pytest 31 | yarn run test 32 | 33 | - name: Linting 34 | if: ${{ matrix.os == 'ubuntu-latest' }} 35 | run: | 36 | yarn run lint:check 37 | 38 | - name: Check docs can be built + links 39 | if: ${{ matrix.os == 'ubuntu-latest' }} 40 | working-directory: docs 41 | run: | 42 | sudo apt install -y pandoc 43 | make html 44 | python -m pytest --check-links 45 | -------------------------------------------------------------------------------- /packages/python/stitch/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask instance folder 57 | instance/ 58 | 59 | # Scrapy stuff: 60 | .scrapy 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | docs/source/_static/embed-bundle.js 65 | docs/source/_static/embed-bundle.js.map 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | # ========================= 93 | # Operating System Files 94 | # ========================= 95 | 96 | # OSX 97 | # ========================= 98 | 99 | .DS_Store 100 | .AppleDouble 101 | .LSOverride 102 | 103 | # Thumbnails 104 | ._* 105 | 106 | # Files that might appear in the root of a volume 107 | .DocumentRevisions-V100 108 | .fseventsd 109 | .Spotlight-V100 110 | .TemporaryItems 111 | .Trashes 112 | .VolumeIcon.icns 113 | 114 | # Directories potentially created on remote AFP share 115 | .AppleDB 116 | .AppleDesktop 117 | Network Trash Folder 118 | Temporary Items 119 | .apdisk 120 | 121 | # Windows 122 | # ========================= 123 | 124 | # Windows image file caches 125 | Thumbs.db 126 | ehthumbs.db 127 | 128 | # Folder config file 129 | Desktop.ini 130 | 131 | # Recycle Bin used on file shares 132 | $RECYCLE.BIN/ 133 | 134 | # Windows Installer files 135 | *.cab 136 | *.msi 137 | *.msm 138 | *.msp 139 | 140 | # Windows shortcuts 141 | *.lnk 142 | 143 | 144 | # NPM 145 | # ---- 146 | 147 | **/node_modules/ 148 | stitch/nbextension/index.* 149 | .yarn/ 150 | 151 | # Coverage data 152 | # 
------------- 153 | **/coverage/ 154 | 155 | # Packed lab extensions 156 | stitch/labextension 157 | -------------------------------------------------------------------------------- /packages/python/stitch/.npmignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules/ 3 | tests/ 4 | .jshintrc 5 | # Ignore any build output from python: 6 | dist/*.tar.gz 7 | dist/*.wheel 8 | -------------------------------------------------------------------------------- /packages/python/stitch/.prettierignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | **/node_modules 3 | **/lib 4 | **/package.json -------------------------------------------------------------------------------- /packages/python/stitch/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "singleQuote": true 3 | } -------------------------------------------------------------------------------- /packages/python/stitch/.yarnrc.yml: -------------------------------------------------------------------------------- 1 | nodeLinker: node-modules 2 | -------------------------------------------------------------------------------- /packages/python/stitch/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 Guidance Contributors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /packages/python/stitch/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE.txt 2 | include README.md 3 | 4 | include setup.py 5 | include pyproject.toml 6 | include pytest.ini 7 | include .coveragerc 8 | 9 | include tsconfig.json 10 | include package.json 11 | include webpack.config.js 12 | include stitch/labextension/*.tgz 13 | 14 | # Documentation 15 | graft docs 16 | exclude docs/\#* 17 | prune docs/build 18 | prune docs/gh-pages 19 | prune docs/dist 20 | 21 | # Examples 22 | graft examples 23 | 24 | # Tests 25 | graft tests 26 | prune tests/build 27 | 28 | # Javascript files 29 | graft stitch/nbextension 30 | graft src 31 | graft css 32 | prune **/node_modules 33 | prune coverage 34 | prune lib 35 | 36 | # Patterns to exclude from any directory 37 | global-exclude *~ 38 | global-exclude *.pyc 39 | global-exclude *.pyo 40 | global-exclude .git 41 | global-exclude .ipynb_checkpoints 42 | -------------------------------------------------------------------------------- /packages/python/stitch/README.md: -------------------------------------------------------------------------------- 1 | 2 | # stitch 3 | 4 | [![Build Status](https://travis-ci.org/guidance-ai/stitch.svg?branch=master)](https://travis-ci.org/guidance-ai/stitch) 5 | [![codecov](https://codecov.io/gh/guidance-ai/stitch/branch/master/graph/badge.svg)](https://codecov.io/gh/guidance-ai/stitch) 6 | 7 | 8 | Bidirectional comms for Jupyter and JavaScript. 9 | 10 | ## Installation 11 | 12 | You can install using `pip`: 13 | 14 | ```bash 15 | pip install guidance-stitch 16 | ``` 17 | 18 | If you are using Jupyter Notebook 5.2 or earlier, you may also need to enable 19 | the nbextension: 20 | ```bash 21 | jupyter nbextension enable --py [--sys-prefix|--user|--system] guidance-stitch 22 | ``` 23 | 24 | ## Development Installation 25 | 26 | Create a dev environment: 27 | ```bash 28 | conda create -n stitch-dev -c conda-forge nodejs python jupyterlab=4.0.11 29 | conda activate stitch-dev 30 | ``` 31 | 32 | Install the Python package. This will also build the TS package. 33 | ```bash 34 | pip install -e ".[test, examples]" 35 | ``` 36 | 37 | When developing your extension, you need to manually enable it in the 38 | notebook / lab frontend. For lab, this is done by the command: 39 | 40 | ``` 41 | jupyter labextension develop --overwrite . 42 | jlpm run build 43 | ``` 44 | 45 | For classic notebook, you need to run: 46 | 47 | ``` 48 | jupyter nbextension install --sys-prefix --symlink --overwrite --py guidance-stitch 49 | jupyter nbextension enable --sys-prefix --py guidance-stitch 50 | ``` 51 | 52 | Note that the `--symlink` flag doesn't work on Windows, so there you will have to run 53 | the `install` command every time you rebuild your extension. For certain installations 54 | you might also need another flag instead of `--sys-prefix`, but we won't cover the meaning 55 | of those flags here. 56 | 57 | ### How to see your changes 58 | #### Typescript: 59 | If you use JupyterLab to develop, you can watch the source directory and run JupyterLab at the same time in different 60 | terminals, so changes in the extension's source are picked up and the widget is rebuilt automatically. 
61 | 62 | ```bash 63 | # Watch the source directory in one terminal, automatically rebuilding when needed 64 | jlpm run watch 65 | # Run JupyterLab in another terminal 66 | jupyter lab 67 | ``` 68 | 69 | After a change, wait for the build to finish, then refresh your browser; the changes should take effect. 70 | 71 | #### Python: 72 | If you make a change to the Python code, you will need to restart the notebook kernel for it to take effect. 73 | 74 | ## Updating the version 75 | 76 | To update the version, install tbump and use it to bump the version. 77 | By default it will also create a tag. 78 | 79 | ```bash 80 | pip install tbump 81 | tbump <new-version> 82 | ``` 83 | 84 | -------------------------------------------------------------------------------- /packages/python/stitch/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | sourceMap: 'inline', 3 | presets: [ 4 | [ 5 | '@babel/preset-env', 6 | { 7 | targets: { 8 | node: 'current', 9 | }, 10 | }, 11 | ], 12 | ], 13 | }; 14 | -------------------------------------------------------------------------------- /packages/python/stitch/codecov.yml: -------------------------------------------------------------------------------- 1 | comment: off 2 | # show coverage in CI status, but never consider it a failure 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: 0% 8 | patch: 9 | default: 10 | target: 0% 11 | ignore: 12 | - "stitch/tests" 13 | -------------------------------------------------------------------------------- /packages/python/stitch/css/widget.css: -------------------------------------------------------------------------------- 1 | .custom-widget { 2 | background-color: lightseagreen; 3 | padding: 0px 2px; 4 | } 5 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = stitch 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/environment.yml: -------------------------------------------------------------------------------- 1 | 2 | name: stitch_docs 3 | channels: 4 | - conda-forge 5 | dependencies: 6 | - python=3.* 7 | - nodejs 8 | - jupyter_sphinx 9 | - sphinx 10 | - sphinx_rtd_theme 11 | - nbsphinx 12 | - nbsphinx-link 13 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=stitch 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/_static/helper.js: -------------------------------------------------------------------------------- 1 | var cache_require = window.require; 2 | 3 | window.addEventListener('load', function() { 4 | window.require = cache_require; 5 | }); 6 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/develop-install.rst: -------------------------------------------------------------------------------- 1 | 2 | Developer install 3 | ================= 4 | 5 | 6 | To install a developer version of stitch, you will first need to clone 7 | the repository:: 8 | 9 | git clone https://github.com/guidance-ai/stitch 10 | cd stitch 11 | 12 | Next, install it with a develop install using pip:: 13 | 14 | pip install -e . 15 | 16 | 17 | If you are planning on working on the JS/frontend code, you should also do 18 | a link installation of the extension:: 19 | 20 | jupyter nbextension install [--sys-prefix / --user / --system] --symlink --py stitch 21 | 22 | jupyter nbextension enable [--sys-prefix / --user / --system] --py stitch 23 | 24 | with the `appropriate flag`_. Or, if you are using Jupyterlab:: 25 | 26 | jupyter labextension install . 27 | 28 | 29 | .. links 30 | 31 | .. _`appropriate flag`: https://jupyter-notebook.readthedocs.io/en/stable/extending/frontend_extensions.html#installing-and-enabling-extensions 32 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/examples/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Examples 3 | ======== 4 | 5 | This section contains several examples generated from Jupyter notebooks. 6 | The widgets have been embedded into the page for demonstrative purposes. 
7 | 8 | .. todo:: 9 | 10 | Add links to notebooks in examples folder similar to the initial 11 | one. This is a manual step to ensure only those examples that 12 | are suited for inclusion are used. 13 | 14 | 15 | .. toctree:: 16 | :glob: 17 | 18 | * 19 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/examples/introduction.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../../examples/introduction.ipynb" 3 | } 4 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | stitch 3 | ===================================== 4 | 5 | Version: |release| 6 | 7 | Bidirectional comms for Jupyter and JavaScript. 8 | 9 | 10 | Quickstart 11 | ---------- 12 | 13 | To get started with stitch, install with pip:: 14 | 15 | pip install stitch 16 | 17 | or with conda:: 18 | 19 | conda install stitch 20 | 21 | 22 | Contents 23 | -------- 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | :caption: Installation and usage 28 | 29 | installing 30 | introduction 31 | 32 | .. toctree:: 33 | :maxdepth: 1 34 | 35 | examples/index 36 | 37 | 38 | .. toctree:: 39 | :maxdepth: 2 40 | :caption: Development 41 | 42 | develop-install 43 | 44 | 45 | .. links 46 | 47 | .. _`Jupyter widgets`: https://jupyter.org/widgets.html 48 | 49 | .. _`notebook`: https://jupyter-notebook.readthedocs.io/en/latest/ 50 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/installing.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _installation: 3 | 4 | Installation 5 | ============ 6 | 7 | 8 | The simplest way to install stitch is via pip:: 9 | 10 | pip install stitch 11 | 12 | or via conda:: 13 | 14 | conda install stitch 15 | 16 | 17 | If you installed via pip, and notebook version < 5.3, you will also have to 18 | install / configure the front-end extension. If you are using classic 19 | notebook (as opposed to Jupyterlab), run:: 20 | 21 | jupyter nbextension install [--sys-prefix / --user / --system] --py stitch 22 | 23 | jupyter nbextension enable [--sys-prefix / --user / --system] --py stitch 24 | 25 | with the `appropriate flag`_. If you are using Jupyterlab, install the extension 26 | with:: 27 | 28 | jupyter labextension install @guidance-ai/stitch 29 | 30 | If you are installing using conda, these commands should be unnecessary, but if 31 | you need to run them, the commands should be the same (just make sure you choose the 32 | `--sys-prefix` flag). 33 | 34 | 35 | .. links 36 | 37 | .. _`appropriate flag`: https://jupyter-notebook.readthedocs.io/en/stable/extending/frontend_extensions.html#installing-and-enabling-extensions 38 | -------------------------------------------------------------------------------- /packages/python/stitch/docs/source/introduction.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Introduction 3 | ============= 4 | 5 | .. 
todo:: 6 |  7 |    add prose explaining project purpose and usage here -------------------------------------------------------------------------------- /packages/python/stitch/install.json: -------------------------------------------------------------------------------- 1 | { 2 | "packageManager": "python", 3 | "packageName": "stitch", 4 | "uninstallInstructions": "Use your Python package manager (pip, conda, etc.) to uninstall the package stitch" 5 | } 6 | -------------------------------------------------------------------------------- /packages/python/stitch/jest.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | automock: false, 3 | moduleNameMapper: { 4 | '\\.(css|less|sass|scss)$': 'identity-obj-proxy', 5 | }, 6 | preset: 'ts-jest/presets/js-with-babel', 7 | moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], 8 | testPathIgnorePatterns: ['/lib/', '/node_modules/'], 9 | testRegex: '/__tests__/.*.spec.ts[x]?$', 10 | transformIgnorePatterns: ['/node_modules/(?!(@jupyter(lab|-widgets)/.*)/)'], 11 | globals: { 12 | 'ts-jest': { 13 | tsconfig: '<rootDir>/tsconfig.json', 14 | }, 15 | }, 16 | }; 17 | -------------------------------------------------------------------------------- /packages/python/stitch/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = stitch/tests examples 3 | norecursedirs = node_modules .ipynb_checkpoints 4 | addopts = --nbval --current-env 5 | -------------------------------------------------------------------------------- /packages/python/stitch/readthedocs.yml: -------------------------------------------------------------------------------- 1 | type: sphinx 2 | python: 3 | version: 3.5 4 | pip_install: true 5 | extra_requirements: 6 | - examples 7 | - docs 8 | conda: 9 | file: docs/environment.yml 10 | -------------------------------------------------------------------------------- /packages/python/stitch/setup.py: -------------------------------------------------------------------------------- 1 | # setup.py shim for use with applications that require it. 2 | __import__("setuptools").setup() 3 | -------------------------------------------------------------------------------- /packages/python/stitch/src/__tests__/index.spec.ts: -------------------------------------------------------------------------------- 1 | /** @jest-environment jsdom */ 2 | // Copyright (c) Jupyter Development Team. 3 | // Distributed under the terms of the Modified BSD License. 4 | 5 | // Add any needed widget imports here (or from controls) 6 | // import {} from '@jupyter-widgets/base'; 7 | 8 | // NOTE(nopdive): Workaround for jsdom drag event failure. 9 | Object.defineProperty(window, 'DragEvent', { 10 | value: class DragEvent {}, 11 | }); 12 | 13 | import { createTestModel } from './utils'; 14 | import { StitchModel } from '..'; 15 | 16 | describe('Example', () => { 17 | describe('StitchModel', () => { 18 | it('should be createable', () => { 19 | const model = createTestModel(StitchModel); 20 | expect(model).toBeInstanceOf(StitchModel); 21 | expect(model.get('srcdoc')).toEqual( 22 | 'srcdoc should be defined by the user
', 23 | ); 24 | }); 25 | 26 | it('should be createable with a value', () => { 27 | const state = { srcdoc: 'it is alright' }; 28 | const model = createTestModel(StitchModel, state); 29 | expect(model).toBeInstanceOf(StitchModel); 30 | expect(model.get('srcdoc')).toEqual('it is alright'); 31 | }); 32 | }); 33 | }); 34 | -------------------------------------------------------------------------------- /packages/python/stitch/src/__tests__/utils.ts: -------------------------------------------------------------------------------- 1 | // Copyright (c) Jupyter Development Team. 2 | // Distributed under the terms of the Modified BSD License. 3 | 4 | import * as widgets from '@jupyter-widgets/base'; 5 | import * as baseManager from '@jupyter-widgets/base-manager'; 6 | import * as services from '@jupyterlab/services'; 7 | 8 | let numComms = 0; 9 | 10 | export class MockComm implements widgets.IClassicComm { 11 | constructor() { 12 | this.comm_id = `mock-comm-id-${numComms}`; 13 | numComms += 1; 14 | } 15 | on_close(fn: ((x?: any) => void) | null): void { 16 | this._on_close = fn; 17 | } 18 | on_msg(fn: (x?: any) => void): void { 19 | this._on_msg = fn; 20 | } 21 | _process_msg(msg: services.KernelMessage.ICommMsgMsg): void | Promise { 22 | if (this._on_msg) { 23 | return this._on_msg(msg); 24 | } else { 25 | return Promise.resolve(); 26 | } 27 | } 28 | close(): string { 29 | if (this._on_close) { 30 | this._on_close(); 31 | } 32 | return 'dummy'; 33 | } 34 | send(): string { 35 | return 'dummy'; 36 | } 37 | 38 | open(): string { 39 | return 'dummy'; 40 | } 41 | 42 | comm_id: string; 43 | target_name = 'dummy'; 44 | _on_msg: ((x?: any) => void) | null = null; 45 | _on_close: ((x?: any) => void) | null = null; 46 | } 47 | 48 | export class DummyManager extends baseManager.ManagerBase { 49 | constructor() { 50 | super(); 51 | this.el = window.document.createElement('div'); 52 | } 53 | 54 | display_view( 55 | msg: services.KernelMessage.IMessage, 56 | view: widgets.DOMWidgetView, 57 | options: any 58 | ) { 59 | // TODO: make this a spy 60 | // TODO: return an html element 61 | return Promise.resolve(view).then((view) => { 62 | this.el.appendChild(view.el); 63 | view.on('remove', () => console.log('view removed', view)); 64 | return view.el; 65 | }); 66 | } 67 | 68 | protected loadClass( 69 | className: string, 70 | moduleName: string, 71 | moduleVersion: string 72 | ): Promise { 73 | if (moduleName === '@jupyter-widgets/base') { 74 | if ((widgets as any)[className]) { 75 | return Promise.resolve((widgets as any)[className]); 76 | } else { 77 | return Promise.reject(`Cannot find class ${className}`); 78 | } 79 | } else if (moduleName === 'jupyter-datawidgets') { 80 | if (this.testClasses[className]) { 81 | return Promise.resolve(this.testClasses[className]); 82 | } else { 83 | return Promise.reject(`Cannot find class ${className}`); 84 | } 85 | } else { 86 | return Promise.reject(`Cannot find module ${moduleName}`); 87 | } 88 | } 89 | 90 | _get_comm_info() { 91 | return Promise.resolve({}); 92 | } 93 | 94 | _create_comm() { 95 | return Promise.resolve(new MockComm()); 96 | } 97 | 98 | el: HTMLElement; 99 | 100 | testClasses: { [key: string]: any } = {}; 101 | } 102 | 103 | export interface Constructor { 104 | new (attributes?: any, options?: any): T; 105 | } 106 | 107 | export function createTestModel( 108 | constructor: Constructor, 109 | attributes?: any 110 | ): T { 111 | const id = widgets.uuid(); 112 | const widget_manager = new DummyManager(); 113 | const modelOptions = { 114 | widget_manager: 
widget_manager, 115 | model_id: id, 116 | }; 117 | 118 | return new constructor(attributes, modelOptions); 119 | } 120 | -------------------------------------------------------------------------------- /packages/python/stitch/src/extension.ts: -------------------------------------------------------------------------------- 1 | // Copyright (c) Jupyter Development Team. 2 | // Distributed under the terms of the Modified BSD License. 3 | 4 | // Entry point for the notebook bundle containing custom model definitions. 5 | // 6 | // Setup notebook base URL 7 | // 8 | // Some static assets may be required by the custom widget javascript. The base 9 | // url for the notebook is not known at build time and is therefore computed 10 | // dynamically. 11 | // eslint-disable-next-line @typescript-eslint/no-non-null-assertion 12 | (window as any).__webpack_public_path__ = 13 | document.querySelector('body')!.getAttribute('data-base-url') + 14 | 'nbextensions/stitch'; 15 | 16 | export * from './index'; 17 | -------------------------------------------------------------------------------- /packages/python/stitch/src/index.ts: -------------------------------------------------------------------------------- 1 | // Copyright (c) Guidance Contributors 2 | // Distributed under the terms of the Modified BSD License. 3 | 4 | export * from './version'; 5 | export * from './widget'; 6 | -------------------------------------------------------------------------------- /packages/python/stitch/src/plugin.ts: -------------------------------------------------------------------------------- 1 | // Copyright (c) Guidance Contributors 2 | // Distributed under the terms of the Modified BSD License. 3 | 4 | import { Application, IPlugin } from '@lumino/application'; 5 | 6 | import { Widget } from '@lumino/widgets'; 7 | 8 | import { IJupyterWidgetRegistry } from '@jupyter-widgets/base'; 9 | 10 | import * as widgetExports from './widget'; 11 | 12 | import { MODULE_NAME, MODULE_VERSION } from './version'; 13 | 14 | const EXTENSION_ID = '@guidance-ai/stitch:plugin'; 15 | 16 | /** 17 | * The example plugin. 18 | */ 19 | const examplePlugin: IPlugin, void> = { 20 | id: EXTENSION_ID, 21 | requires: [IJupyterWidgetRegistry], 22 | activate: activateWidgetExtension, 23 | autoStart: true, 24 | } as unknown as IPlugin, void>; 25 | // the "as unknown as ..." typecast above is solely to support JupyterLab 1 26 | // and 2 in the same codebase and should be removed when we migrate to Lumino. 27 | 28 | export default examplePlugin; 29 | 30 | /** 31 | * Activate the widget extension. 32 | */ 33 | function activateWidgetExtension( 34 | app: Application, 35 | registry: IJupyterWidgetRegistry 36 | ): void { 37 | registry.registerWidget({ 38 | name: MODULE_NAME, 39 | version: MODULE_VERSION, 40 | exports: widgetExports, 41 | }); 42 | } 43 | -------------------------------------------------------------------------------- /packages/python/stitch/src/version.ts: -------------------------------------------------------------------------------- 1 | // Copyright (c) Guidance Contributors 2 | // Distributed under the terms of the Modified BSD License. 3 | 4 | // eslint-disable-next-line @typescript-eslint/ban-ts-comment 5 | // @ts-ignore 6 | // eslint-disable-next-line @typescript-eslint/no-var-requires 7 | const data = require('../package.json'); 8 | 9 | /** 10 | * The _model_module_version/_view_module_version this package implements. 11 | * 12 | * The html widget manager assumes that this is the same as the npm package 13 | * version number. 
14 | */ 15 | export const MODULE_VERSION = data.version; 16 | 17 | /* 18 | * The current package name. 19 | */ 20 | export const MODULE_NAME = data.name; 21 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch.json: -------------------------------------------------------------------------------- 1 | { 2 | "load_extensions": { 3 | "stitch/extension": true 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 6 | 7 | from .stitch import StitchWidget 8 | from ._version import __version__, version_info 9 | 10 | def _jupyter_labextension_paths(): 11 | """Called by Jupyter Lab Server to detect if it is a valid labextension and 12 | to install the widget 13 | Returns 14 | ======= 15 | src: Source directory name to copy files from. Webpack outputs generated files 16 | into this directory and Jupyter Lab copies from this directory during 17 | widget installation 18 | dest: Destination directory name to install widget files to. Jupyter Lab copies 19 | from `src` directory into <jupyter path>/labextensions/<dest> directory 20 | during widget installation 21 | """ 22 | return [{ 23 | 'src': 'labextension', 24 | 'dest': '@guidance-ai/stitch', 25 | }] 26 | 27 | 28 | def _jupyter_nbextension_paths(): 29 | """Called by Jupyter Notebook Server to detect if it is a valid nbextension and 30 | to install the widget 31 | Returns 32 | ======= 33 | section: The section of the Jupyter Notebook Server to change. 34 | Must be 'notebook' for widget extensions 35 | src: Source directory name to copy files from. Webpack outputs generated files 36 | into this directory and Jupyter Notebook copies from this directory during 37 | widget installation 38 | dest: Destination directory name to install widget files to. Jupyter Notebook copies 39 | from `src` directory into <jupyter path>/nbextensions/<dest> directory 40 | during widget installation 41 | require: Path to importable AMD Javascript module inside the 42 | <jupyter path>/nbextensions/<dest> directory 43 | """ 44 | return [{ 45 | 'section': 'notebook', 46 | 'src': 'nbextension', 47 | 'dest': 'stitch', 48 | 'require': 'stitch/extension' 49 | }] 50 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch/_frontend.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 6 | 7 | """ 8 | Information about the frontend package of the widgets. 9 | """ 10 | 11 | module_name = "@guidance-ai/stitch" 12 | module_version = "^0.1.4" 13 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch/_version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 
6 | 7 | version_info = (0, 1, 4) 8 | __version__ = ".".join(map(str, version_info)) 9 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch/nbextension/extension.js: -------------------------------------------------------------------------------- 1 | // Entry point for the notebook bundle containing custom model definitions. 2 | // 3 | define(function() { 4 | "use strict"; 5 | 6 | window['requirejs'].config({ 7 | map: { 8 | '*': { 9 | '@guidance-ai/stitch': 'nbextensions/stitch/index', 10 | }, 11 | } 12 | }); 13 | // Export the required load_ipython_extension function 14 | return { 15 | load_ipython_extension : function() {} 16 | }; 17 | }); -------------------------------------------------------------------------------- /packages/python/stitch/stitch/stitch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 6 | 7 | """ 8 | Stitch Widget that allows bidirectional comms between Jupyter and JavaScript. 9 | """ 10 | 11 | from ipywidgets import DOMWidget 12 | from traitlets import Unicode 13 | from ._frontend import module_name, module_version 14 | 15 | 16 | class StitchWidget(DOMWidget): 17 | """Widget that purely handles communication between an iframe and kernel via postMessage.""" 18 | 19 | _model_name = Unicode('StitchModel').tag(sync=True) 20 | _model_module = Unicode(module_name).tag(sync=True) 21 | _model_module_version = Unicode(module_version).tag(sync=True) 22 | _view_name = Unicode('StitchView').tag(sync=True) 23 | _view_module = Unicode(module_name).tag(sync=True) 24 | _view_module_version = Unicode(module_version).tag(sync=True) 25 | 26 | kernelmsg = Unicode("").tag(sync=True) 27 | clientmsg = Unicode("").tag(sync=True) 28 | srcdoc = Unicode("
srcdoc should be defined by the user
").tag(sync=True) 29 | initial_height = Unicode("1px").tag(sync=True) 30 | initial_width = Unicode("1px").tag(sync=True) 31 | initial_border = Unicode("0").tag(sync=True) 32 | 33 | # NOTE(nopdive): Should we sync or not? There are overheads when we deal with bandwidth on real time applications. 34 | state = Unicode("").tag(sync=True) -------------------------------------------------------------------------------- /packages/python/stitch/stitch/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/packages/python/stitch/stitch/tests/__init__.py -------------------------------------------------------------------------------- /packages/python/stitch/stitch/tests/conftest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 6 | 7 | import pytest 8 | 9 | from ipykernel.comm import Comm 10 | from ipywidgets import Widget 11 | 12 | class MockComm(Comm): 13 | """A mock Comm object. 14 | 15 | Can be used to inspect calls to Comm's open/send/close methods. 16 | """ 17 | comm_id = 'a-b-c-d' 18 | kernel = 'Truthy' 19 | 20 | def __init__(self, *args, **kwargs): 21 | self.log_open = [] 22 | self.log_send = [] 23 | self.log_close = [] 24 | super(MockComm, self).__init__(*args, **kwargs) 25 | 26 | def open(self, *args, **kwargs): 27 | self.log_open.append((args, kwargs)) 28 | 29 | def send(self, *args, **kwargs): 30 | self.log_send.append((args, kwargs)) 31 | 32 | def close(self, *args, **kwargs): 33 | self.log_close.append((args, kwargs)) 34 | 35 | _widget_attrs = {} 36 | undefined = object() 37 | 38 | 39 | @pytest.fixture 40 | def mock_comm(): 41 | _widget_attrs['_comm_default'] = getattr(Widget, '_comm_default', undefined) 42 | Widget._comm_default = lambda self: MockComm() 43 | _widget_attrs['_ipython_display_'] = Widget._ipython_display_ 44 | def raise_not_implemented(*args, **kwargs): 45 | raise NotImplementedError() 46 | Widget._ipython_display_ = raise_not_implemented 47 | 48 | yield MockComm() 49 | 50 | for attr, value in _widget_attrs.items(): 51 | if value is undefined: 52 | delattr(Widget, attr) 53 | else: 54 | setattr(Widget, attr, value) 55 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch/tests/test_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 6 | 7 | import pytest 8 | 9 | from ..stitch import StitchWidget 10 | 11 | 12 | def test_example_creation_blank(): 13 | w = StitchWidget() 14 | assert w.kernelmsg == "" 15 | assert w.clientmsg == "" 16 | assert w.srcdoc == "
srcdoc should be defined by the user
" 17 | -------------------------------------------------------------------------------- /packages/python/stitch/stitch/tests/test_nbextension_path.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) Guidance Contributors. 5 | # Distributed under the terms of the Modified BSD License. 6 | 7 | 8 | def test_nbextension_path(): 9 | # Check that magic function can be imported from package root: 10 | from stitch import _jupyter_nbextension_paths 11 | # Ensure that it can be called without incident: 12 | path = _jupyter_nbextension_paths() 13 | # Some sanity checks: 14 | assert len(path) == 1 15 | assert isinstance(path[0], dict) 16 | -------------------------------------------------------------------------------- /packages/python/stitch/tsconfig.eslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "./tsconfig.json", 3 | "include": ["src/**/*.ts", "src/**/*.tsx"], 4 | "exclude": [] 5 | } -------------------------------------------------------------------------------- /packages/python/stitch/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "declaration": true, 4 | "esModuleInterop":true, 5 | "lib": ["es2015", "dom"], 6 | "module": "commonjs", 7 | "moduleResolution": "node", 8 | "noEmitOnError": true, 9 | "noUnusedLocals": true, 10 | "outDir": "lib", 11 | "resolveJsonModule": true, 12 | "rootDir": "src", 13 | "skipLibCheck": true, 14 | "sourceMap": true, 15 | "strict": true, 16 | "strictPropertyInitialization": false, 17 | "target": "es2015", 18 | "types": ["jest"] 19 | }, 20 | "include": [ 21 | "src/**/*.ts", 22 | "src/**/*.tsx", 23 | ], 24 | "exclude": ["src/**/__tests__"] 25 | } 26 | -------------------------------------------------------------------------------- /packages/python/stitch/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const version = require('./package.json').version; 3 | 4 | // Custom webpack rules 5 | const rules = [ 6 | { test: /\.ts$/, loader: 'ts-loader' }, 7 | { test: /\.js$/, loader: 'source-map-loader' }, 8 | { test: /\.css$/, use: ['style-loader', 'css-loader']} 9 | ]; 10 | 11 | // Packages that shouldn't be bundled but loaded at runtime 12 | const externals = ['@jupyter-widgets/base']; 13 | 14 | const resolve = { 15 | // Add '.ts' and '.tsx' as resolvable extensions. 16 | extensions: [".webpack.js", ".web.js", ".ts", ".js"] 17 | }; 18 | 19 | module.exports = [ 20 | /** 21 | * Notebook extension 22 | * 23 | * This bundle only contains the part of the JavaScript that is run on load of 24 | * the notebook. 25 | */ 26 | { 27 | entry: './src/extension.ts', 28 | output: { 29 | filename: 'index.js', 30 | path: path.resolve(__dirname, 'stitch', 'nbextension'), 31 | libraryTarget: 'amd', 32 | publicPath: '', 33 | }, 34 | module: { 35 | rules: rules 36 | }, 37 | devtool: 'source-map', 38 | externals, 39 | resolve, 40 | }, 41 | 42 | /** 43 | * Embeddable @guidance-ai/stitch bundle 44 | * 45 | * This bundle is almost identical to the notebook extension bundle. The only 46 | * difference is in the configuration of the webpack public path for the 47 | * static assets. 48 | * 49 | * The target bundle is always `dist/index.js`, which is the path required by 50 | * the custom widget embedder. 
51 | */ 52 | { 53 | entry: './src/index.ts', 54 | output: { 55 | filename: 'index.js', 56 | path: path.resolve(__dirname, 'dist'), 57 | libraryTarget: 'amd', 58 | library: "@guidance-ai/stitch", 59 | publicPath: 'https://unpkg.com/@guidance-ai/stitch@' + version + '/dist/' 60 | }, 61 | devtool: 'source-map', 62 | module: { 63 | rules: rules 64 | }, 65 | externals, 66 | resolve, 67 | }, 68 | 69 | 70 | /** 71 | * Documentation widget bundle 72 | * 73 | * This bundle is used to embed widgets in the package documentation. 74 | */ 75 | { 76 | entry: './src/index.ts', 77 | output: { 78 | filename: 'embed-bundle.js', 79 | path: path.resolve(__dirname, 'docs', 'source', '_static'), 80 | library: "stitch", 81 | libraryTarget: 'amd' 82 | }, 83 | module: { 84 | rules: rules 85 | }, 86 | devtool: 'source-map', 87 | externals, 88 | resolve, 89 | } 90 | 91 | ]; 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel", 5 | "pybind11>=2.10.0", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.pytest.ini_options] 10 | addopts = "--strict-markers" 11 | markers = [ 12 | "resource_intensive: Tests which readily exceed resources such as memory", 13 | "asyncio: Mark async tests to run" 14 | ] 15 | 16 | [tool.black] 17 | line-length = 99 18 | target_version = ['py39', 'py310', 'py311', 'py312'] 19 | 20 | [tool.isort] 21 | profile = "black" 22 | 23 | [tool.mypy] 24 | no_implicit_optional = false # TODO: PEP484, set to true 25 | strict = false # TODO: set to true 26 | exclude = ["tests"] 27 | 28 | [[tool.mypy.overrides]] 29 | module = "pybind11.*" 30 | ignore_missing_imports = true 31 | 32 | [[tool.mypy.overrides]] 33 | module = "vertexai.*" 34 | ignore_missing_imports = true 35 | 36 | [[tool.mypy.overrides]] 37 | module = "google.*" 38 | ignore_missing_imports = true 39 | 40 | [[tool.mypy.overrides]] 41 | module = "llama_cpp.*" 42 | ignore_missing_imports = true 43 | 44 | [[tool.mypy.overrides]] 45 | module = "anthropic.*" 46 | ignore_missing_imports = true 47 | 48 | [[tool.mypy.overrides]] 49 | module = "litellm.*" 50 | ignore_missing_imports = true 51 | 52 | [[tool.mypy.overrides]] 53 | module = "transformers.*" 54 | ignore_missing_imports = true 55 | 56 | [[tool.mypy.overrides]] 57 | module = "diskcache.*" 58 | ignore_missing_imports = true 59 | 60 | [[tool.mypy.overrides]] 61 | module = "tokenizers.*" 62 | ignore_missing_imports = true -------------------------------------------------------------------------------- /tests/ReadMe.md: -------------------------------------------------------------------------------- 1 | # Testing 2 | 3 | ## Organisation 4 | 5 | The tests are arranged into the following directories: 6 | 7 | - `unit` tests do not depend on LLMs (but may use `model.Mock`) 8 | - `model_integration` tests should run on any (fully supported) model, supplied by the `selected_model` fixture 9 | - `model_specific` tests are for isolating particular issues with individual LLMs 10 | - `need_credentials` tests need access to various credentials (mainly `Grammarless` models for endpoints without full Guidance support) 11 | - `notebooks` tests are for notebooks 12 | 13 | The `model_specific` tests should make use of the `selected_model` machinery, but skip themselves if the appropriate model is not supplied. 
14 | A sample means of achieving this: 15 | 16 | ```python 17 | @pytest.fixture(scope="module") 18 | def phi3_model(selected_model, selected_model_name): 19 | if selected_model_name in ["transformers_phi3_mini_4k_instruct_cpu"]: 20 | return selected_model 21 | else: 22 | pytest.skip("Requires Phi3 model") 23 | ``` 24 | 25 | ## Selecting a model 26 | 27 | To select a particular model when running the tests, use the `--selected_model` command line option. 28 | For example: 29 | 30 | ```bash 31 | python -m pytest --selected_model transformers_gemma2_9b_cpu ./tests/model_integration/ 32 | ``` 33 | 34 | The allowed values for `--selected_model` are in the [`conftest.py`](./conftest.py) file, and are defined in the `selected_model` function. 35 | Alternatively, the `GUIDANCE_SELECTED_MODEL` environment variable can be used to override the default value for `--selected_model` (which can be useful when using a debugger); a usage sketch is given at the end of this document. 36 | 37 | ### A Note on Credentials 38 | 39 | As noted above, the `need_credentials` tests are mainly for `Grammarless` models - those for remote endpoints which do not support Guidance grammars (there are a few exceptions, which is why the directory isn't simply named `grammarless`). 40 | As endpoints with Guidance grammar support come online, their tests should *not* go in there; these should go into `model_integration` and `model_specific`, but will only be run in CI builds. 41 | Similarly, some models (e.g. Llama3) require credentials in order to download their weights from Hugging Face. 42 | These should be run through the `model_integration` and `model_specific` tests, but this run will happen from the CI build, and hence have credential access. 43 | 44 | ## Testing Goal 45 | 46 | Ideally, when creating a new feature, most of the tests should go into the `unit` directory, and make use of `model.Mock` if needed. 47 | It should always be possible to run these with quite a minimal Guidance installation (plus `pytest`, obviously). 48 | These tests should be fast, and facilitate a developer experience built around running 49 | 50 | ```bash 51 | pytest tests/unit 52 | ``` 53 | very frequently. 54 | 55 | There should also be a handful of tests in `model_integration`, which should work with _any_ fully supported Guidance model. 56 | Finally, if any model quirks are noted (and _especially_ if workarounds are required in the code), tests to characterise these should go into `model_specific`. 57 | 58 | In this paradigm, no tests in `unit` or `model_integration` should use `pytest.skip` (or its variants). 59 | Those in `model_specific` will use `pytest.skip` when the `selected_model` fixture is not of the appropriate type. 
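As a concrete illustration of the environment variable override mentioned under "Selecting a model" above, here is a minimal sketch; `transformers_gpt2_cpu` is assumed to be one of the model names defined in `conftest.py`, so substitute whichever model you need:

```bash
# Equivalent to passing --selected_model on the pytest command line
export GUIDANCE_SELECTED_MODEL=transformers_gpt2_cpu
python -m pytest ./tests/model_integration/
```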
60 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/__init__.py -------------------------------------------------------------------------------- /tests/bench/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/bench/__init__.py -------------------------------------------------------------------------------- /tests/bench/test_api.py: -------------------------------------------------------------------------------- 1 | from guidance.bench._api import bench, AVAILABLE_MODELS 2 | from pathlib import Path 3 | import tempfile 4 | import pytest 5 | 6 | @pytest.mark.skip("Waiting on CI upgrades. Need access to env var LANGCHAIN_API_KEY.") 7 | def test_bench(): 8 | # TODO(nopdive): Parameterize models once CI is upgraded. 9 | with tempfile.TemporaryDirectory() as tmp_dir: 10 | db_path = Path(tmp_dir) / "bench.db" 11 | db_url = f"sqlite:///{db_path}" 12 | status_df, result_df = bench(db_url, "bench-test", models=AVAILABLE_MODELS[:1], debug_mode=True) 13 | 14 | assert len(status_df) > 0 15 | assert len(result_df) > 0 16 | assert (status_df['status'] == 'COMPLETE').all() -------------------------------------------------------------------------------- /tests/bench/test_powerlift.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tempfile 3 | 4 | from guidance.bench._powerlift import retrieve_langchain 5 | from pathlib import Path 6 | 7 | def test_retrieve_langchain_err(monkeypatch): 8 | monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False) 9 | with pytest.raises(ValueError): 10 | gen = retrieve_langchain() 11 | _ = list(gen) 12 | 13 | @pytest.mark.skip("Waiting on CI upgrades. 
Need access to env var LANGCHAIN_API_KEY.") 14 | def test_retrieve_langchain_basic(): 15 | with tempfile.TemporaryDirectory() as tmp_dir: 16 | # Run once 17 | first_results = list(retrieve_langchain(cache_dir=tmp_dir)) 18 | langchain_cache_path = Path(tmp_dir, "langchain") 19 | assert Path.exists(langchain_cache_path) 20 | 21 | # Run another time to trigger the cache 22 | second_results = list(retrieve_langchain(cache_dir=tmp_dir)) 23 | for first, second in zip(first_results, second_results): 24 | assert first.inputs.equals(second.inputs) -------------------------------------------------------------------------------- /tests/bench/test_utils.py: -------------------------------------------------------------------------------- 1 | from guidance.bench._utils import lib_bench_dir 2 | import tempfile 3 | from pathlib import Path 4 | 5 | def test_lib_bench_dir_basic(): 6 | expected_dir = Path.home() / ".guidance-bench" 7 | actual_dir = lib_bench_dir() 8 | 9 | assert expected_dir == actual_dir 10 | assert Path.exists(actual_dir) 11 | 12 | 13 | def test_lib_bench_dir_env_var(monkeypatch): 14 | with tempfile.TemporaryDirectory() as tmp_dir: 15 | expected_dir = Path(tmp_dir) / "guidance-bench" 16 | monkeypatch.setenv("GUIDANCE_BENCH_DIR", expected_dir) 17 | 18 | actual_dir = lib_bench_dir() 19 | assert expected_dir == actual_dir 20 | assert Path.exists(actual_dir) -------------------------------------------------------------------------------- /tests/model_integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/model_integration/__init__.py -------------------------------------------------------------------------------- /tests/model_integration/library/test_subgrammar.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import pytest 4 | from jsonschema import validate 5 | import json 6 | 7 | import guidance 8 | from guidance import ( 9 | gen, 10 | select, 11 | optional, 12 | one_or_more, 13 | ) 14 | from guidance.library._subgrammar import subgrammar, lexeme 15 | 16 | 17 | @guidance(stateless=True) 18 | def json_string(lm): 19 | return lm + lexeme(r'"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)*"') 20 | 21 | 22 | @guidance(stateless=True) 23 | def json_number(lm): 24 | return lm + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?") 25 | 26 | 27 | @guidance(stateless=True) 28 | def json_value(lm): 29 | return lm + select( 30 | [ 31 | json_string(), 32 | json_number(), 33 | json_object(), 34 | json_array(), 35 | "true", 36 | "false", 37 | "null", 38 | ] 39 | ) 40 | 41 | 42 | @guidance(stateless=True) 43 | def json_member(lm): 44 | return lm + json_string() + ":" + json_value() 45 | 46 | 47 | @guidance(stateless=True) 48 | def json_object(lm): 49 | return lm + "{" + optional(json_member() + one_or_more("," + json_member())) + "}" 50 | 51 | 52 | @guidance(stateless=True) 53 | def json_array(lm): 54 | return lm + "[" + optional(json_value() + one_or_more("," + json_value())) + "]" 55 | 56 | 57 | @guidance(stateless=True) 58 | def gen_json_object(lm, name: str, max_tokens=100000000): 59 | grm = subgrammar( 60 | body=json_object(), 61 | name=name, 62 | skip_regex=r"[\x20\x0A\x0D\x09]+", 63 | max_tokens=max_tokens 64 | ) 65 | return lm + grm 66 | 67 | 68 | def test_greedy_json_object(selected_model: guidance.models.Model): 69 | lm = selected_model 70 | lm += 
"John Doe's name, age, and birthday:\n" 71 | lm += gen_json_object("hacker", max_tokens=1000) 72 | lm += "\nScore: " + gen("score", regex="[1-3]") 73 | # make sure it parses as JSON 74 | obj = json.loads(lm["hacker"]) 75 | assert isinstance(obj, dict) 76 | assert lm["score"] in ["1", "2", "3"] 77 | 78 | 79 | def test_greedy_single_terminal(selected_model: guidance.models.Model): 80 | lm = selected_model 81 | lm += "A number: " 82 | lm += subgrammar(body=lexeme(r"[0-9]{3}")) 83 | assert re.search(r": [0-9]{3}$", str(lm)) 84 | -------------------------------------------------------------------------------- /tests/model_integration/library/test_substring.py: -------------------------------------------------------------------------------- 1 | from guidance import gen, models, substring 2 | 3 | 4 | def test_substring_equal_unconstrained(selected_model: models.Model): 5 | target_model = selected_model 6 | lm = target_model + "ae galera " + gen(max_tokens=10, name="test") 7 | lm2 = target_model + "ae galera " + substring(lm["test"], name="capture") 8 | assert lm2["capture"] in lm["test"] 9 | -------------------------------------------------------------------------------- /tests/model_integration/test_grammar.py: -------------------------------------------------------------------------------- 1 | from guidance import models, select 2 | 3 | 4 | def test_select_simple(selected_model: models.Model): 5 | lm = selected_model 6 | options = ["baad I think", "bad I think", "bad"] 7 | lm = lm + "Scott is quite " + select(name="bad", options=options) 8 | assert lm["bad"] in options 9 | -------------------------------------------------------------------------------- /tests/model_integration/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tests.tokenizer_common import TOKENIZER_ROUND_TRIP_STRINGS, BaseTestTransformerTokenizers 3 | 4 | # These are not _strictly_ unit tests, since they refer 5 | # to specific tokenisers. However, tokenisers are small, 6 | # so if the tokeniser can be loaded separately from the 7 | # model, then this is a good place to have them live. 8 | 9 | # The LlamaCpp tokenisers have tests in test_llamacpp.py 10 | # since those tokenisers cannot be loaded separately from 11 | # their models. 
12 | 13 | # The transformer tests have an authenticated version under 14 | # need_credentials 15 | 16 | 17 | class TestUnauthenticatedTransformerTokenizers(BaseTestTransformerTokenizers): 18 | TRANSFORMER_MODELS = [ 19 | "gpt2", 20 | "microsoft/phi-2", 21 | "microsoft/Phi-3-small-8k-instruct", 22 | "microsoft/Phi-3-mini-4k-instruct", 23 | ] 24 | 25 | @pytest.mark.parametrize( 26 | "model_name", 27 | TRANSFORMER_MODELS, 28 | ) 29 | def test_smoke(self, model_name: str): 30 | self.base_smoke(model_name) 31 | 32 | @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS) 33 | @pytest.mark.parametrize("target_string", TOKENIZER_ROUND_TRIP_STRINGS) 34 | def test_string_roundtrip(self, model_name: str, target_string: str): 35 | self.base_string_roundtrip(model_name, target_string) 36 | 37 | @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS) 38 | def test_eos_bos_token_round_trip(self, model_name: str): 39 | self.base_eos_bos_token_round_trip(model_name) 40 | -------------------------------------------------------------------------------- /tests/model_specific/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/model_specific/__init__.py -------------------------------------------------------------------------------- /tests/model_specific/common_chat_testing.py: -------------------------------------------------------------------------------- 1 | from guidance import assistant, gen, models, system, user 2 | 3 | 4 | def smoke_chat(lm: models.Model, has_system_role: bool = True): 5 | # lm.engine.reset_metrics() 6 | if has_system_role: 7 | with system(): 8 | lm += "You are a math wiz." 9 | 10 | with user(): 11 | lm += "What is 1 + 1?" 12 | 13 | with assistant(): 14 | lm += gen(max_tokens=10, name="text", temperature=0.5) 15 | 16 | print(str(lm)) 17 | # print(f"{lm.engine.metrics=}") 18 | assert len(lm["text"]) > 0 19 | # assert lm.engine.metrics.engine_input_tokens > 2, "Expect some input tokens" 20 | # assert lm.engine.metrics.engine_output_tokens > 0, "Expect some output tokens" 21 | 22 | 23 | def longer_chat_1(lm: models.Model, has_system_role: bool = True): 24 | if has_system_role: 25 | with system(): 26 | lm += "You are a math wiz." 27 | 28 | with user(): 29 | lm += "What is 1 + 1?" 30 | 31 | with assistant(): 32 | lm += gen(max_tokens=10, name="text") 33 | 34 | print(str(lm)) 35 | assert len(lm["text"]) > 0 36 | 37 | with user(): 38 | lm += "10. Now you pick a number between 0 and 20" 39 | 40 | with assistant(): 41 | lm += gen(max_tokens=2, name="number") 42 | 43 | print(str(lm)) 44 | assert len(lm["number"]) > 0 45 | 46 | 47 | def longer_chat_2(lm: models.Model, has_system_role: bool = True): 48 | if has_system_role: 49 | with system(): 50 | lm += "You are a math wiz." 51 | 52 | with user(): 53 | lm += "What is 1 + 1?" 54 | 55 | # This is the new part compared to longer_chat_1 56 | with assistant(): 57 | lm += "2" 58 | 59 | with user(): 60 | lm += "What is 2 + 3?" 61 | 62 | # Resume the previous 63 | with assistant(): 64 | lm += gen(max_tokens=10, name="text") 65 | 66 | print(str(lm)) 67 | assert len(lm["text"]) > 0 68 | 69 | with user(): 70 | lm += "10. 
Now you pick a number between 0 and 20" 71 | 72 | with assistant(): 73 | lm += gen(max_tokens=2, name="number") 74 | 75 | print(str(lm)) 76 | assert len(lm["number"]) > 0 77 | -------------------------------------------------------------------------------- /tests/model_specific/llama_cpp_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Set up for pytest -------------------------------------------------------------------------------- /tests/model_specific/llama_cpp_tests/test_chat_templates.py: -------------------------------------------------------------------------------- 1 | import jinja2 2 | import pytest 3 | 4 | import guidance 5 | 6 | from guidance.chat import CHAT_TEMPLATE_CACHE 7 | 8 | 9 | def test_chat_format_smoke(llamacpp_model: guidance.models.LlamaCpp, selected_model_name): 10 | # Retrieve the template string 11 | if ( 12 | hasattr(llamacpp_model.engine.model_obj, "metadata") 13 | and "tokenizer.chat_template" in llamacpp_model.engine.model_obj.metadata 14 | ): 15 | model_chat_template = llamacpp_model.engine.model_obj.metadata["tokenizer.chat_template"] 16 | else: 17 | pytest.skip("Chat template not available from LlamaCpp object") 18 | 19 | lm = guidance.models.Mock("") 20 | lm._interpreter.chat_template = CHAT_TEMPLATE_CACHE[model_chat_template]() 21 | 22 | messages = [ 23 | {"role": "user", "content": "Good_day_to_you!"}, 24 | {"role": "assistant", "content": "Hello!"}, 25 | ] 26 | 27 | # Note that llama-cpp-python does provide a llama_chat_apply_template function 28 | # but details about its use are thin on the ground and according to 29 | # https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template 30 | # it does its own thing internally 31 | jinja2_template = jinja2.Environment(loader=jinja2.BaseLoader()).from_string( 32 | model_chat_template 33 | ) 34 | jinja2_render = jinja2_template.render( 35 | messages=messages, 36 | bos_token=llamacpp_model.engine.tokenizer.bos_token.decode(), 37 | eos_token=llamacpp_model.engine.tokenizer.eos_token.decode(), 38 | ) 39 | 40 | with guidance.user(): 41 | lm += "Good_day_to_you!" 42 | with guidance.assistant(): 43 | lm += "Hello!" 44 | # Only check substring due to BOS/EOS tokens 45 | if selected_model_name == "llamacpp_mistral_7b_cpu": 46 | # The templates extracted via Transformers and GGUF are somewhat 47 | # different for Mistral. 
This is showing up in slightly 48 | # different spacing (our template is putting in a few extra spaces) 49 | # so at least make sure the 'tags' are correct 50 | assert str(lm).replace(" ", "") in jinja2_render.replace(" ", "") 51 | else: 52 | assert str(lm) in jinja2_render 53 | -------------------------------------------------------------------------------- /tests/model_specific/test_visual.py: -------------------------------------------------------------------------------- 1 | from guidance.registry import get_renderer 2 | 3 | 4 | def test_repeat_simple_model(): 5 | from guidance.models import Transformers 6 | from guidance import gen 7 | from guidance.registry import set_renderer, get_trace_handler 8 | from guidance.visual import JupyterWidgetRenderer 9 | 10 | trace_handler = get_trace_handler() 11 | original_renderer = get_renderer() 12 | for i in range(2): 13 | set_renderer(JupyterWidgetRenderer(trace_handler)) 14 | 15 | lm = Transformers('gpt2') 16 | lm += 'Hi hi hi' 17 | lm += gen(max_tokens=5) 18 | 19 | set_renderer(original_renderer) 20 | 21 | assert True 22 | 23 | 24 | def test_roles(): 25 | from guidance.models import Transformers 26 | from guidance import gen, user, system 27 | 28 | m0 = Transformers("gpt2") 29 | with system(): 30 | m1 = m0 + "You are responsible for writing an epic poem." 31 | with user(): 32 | m2 = m1 + "Roses are red and " + gen(name="suffix", regex=r'[\w\s]{20,30}', max_tokens=30) 33 | 34 | assert m2 is not None -------------------------------------------------------------------------------- /tests/need_credentials/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/need_credentials/__init__.py -------------------------------------------------------------------------------- /tests/need_credentials/test_anthropic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import guidance 4 | from guidance import assistant, capture, gen, select, system, user 5 | 6 | from ..utils import get_model 7 | 8 | 9 | def test_anthropic_chat(): 10 | try: 11 | lm = guidance.models.Anthropic(model="claude-3-haiku-20240307") 12 | except: 13 | pytest.skip("Skipping Anthropic test because we can't load the model!") 14 | with system(): 15 | lm += "You are a math wiz." 16 | 17 | with user(): 18 | lm += "What is 1 + 1?" 
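# The constrained generation below runs inside the assistant role; its output
# is captured under the name "text" and asserted non-empty at the end.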
19 | 20 | with assistant(): 21 | lm += gen(max_tokens=10, name="text") 22 | lm += "Pick a number: " 23 | 24 | assert len(lm["text"]) > 0 25 | 26 | 27 | def test_anthropic_select(): 28 | try: 29 | lm = guidance.models.Anthropic(model="claude-instant-1.2") 30 | except: 31 | pytest.skip("Skipping Anthropic test because we can't load the model!") 32 | 33 | # We can't meaningfully test or enforce select on this model 34 | with pytest.raises(guidance.models._model.ConstraintException): 35 | with user(): 36 | lm += "Write the next number in the list: 1,2,3,4,5,6," 37 | with assistant(): 38 | lm += select( 39 | ["harsha", "scott", "marco"], name="the number" 40 | ) 41 | 42 | 43 | def test_anthropic_chat_loop(): 44 | # tests issue #509 45 | try: 46 | model = guidance.models.Anthropic(model="claude-3-haiku-20240307") 47 | except: 48 | pytest.skip("Skipping Anthropic test because we can't load the model!") 49 | 50 | for i in range(2): 51 | 52 | with system(): 53 | lm = model + "You will just return whatever number I give you" 54 | 55 | with user(): 56 | lm += f"The number is: {i}" 57 | 58 | with assistant(): 59 | lm += gen(name="answer", max_tokens=2) 60 | 61 | # def test_direct_anthropic_api(): 62 | # import anthropic 63 | 64 | # client = anthropic.Anthropic() 65 | 66 | # with client.messages.stream( 67 | # max_tokens=10, 68 | # system="You are a counting robot. Do nothing but continue counting numbers in the same format the user presented.", 69 | # messages=[{"role": "user", "content": "1,2,3,4,5,"}], 70 | # model="claude-3-haiku-20240307", 71 | # ) as stream: 72 | # text_list = [] 73 | # for text in stream.text_stream: 74 | # print(text, end="", flush=True) 75 | # text_list.append(text) 76 | 77 | # assert len(text_list) > 0 -------------------------------------------------------------------------------- /tests/need_credentials/test_azureai_studio.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from guidance import models 4 | from guidance.models._azureai import create_azure_aifoundry_model 5 | 6 | 7 | from ..model_specific import common_chat_testing 8 | from ..utils import env_or_fail, slowdown 9 | 10 | # pytest.skip("Deployments temporarily deleted", allow_module_level=True) 11 | 12 | # How to fill out the environment variables to 13 | # set up the models 14 | # Temporarily remove mistral pending endpoint investigation 15 | # _chat_models = {"phi3": "PHI3", "llama3": "LLAMA3_CHAT"} 16 | _chat_models = {"phi4": "PHI4"} 17 | 18 | 19 | def _get_chat_model(model_name: str): 20 | env_string = _chat_models[model_name] 21 | 22 | azureai_studio_endpoint = env_or_fail(f"AZUREAI_STUDIO_{env_string}_ENDPOINT") 23 | azureai_studio_model_name = env_or_fail(f"AZUREAI_STUDIO_{env_string}_MODEL_NAME") 24 | azureai_studio_key = env_or_fail(f"AZUREAI_STUDIO_{env_string}_KEY") 25 | 26 | lm = create_azure_aifoundry_model( 27 | azure_endpoint=azureai_studio_endpoint, 28 | api_key=azureai_studio_key, 29 | # token_credential=DefaultAzureCredential(), 30 | model_name=azureai_studio_model_name, 31 | ) 32 | assert isinstance(lm, models.Model) 33 | return lm 34 | 35 | 36 | @pytest.mark.parametrize("chat_model_name", _chat_models.keys()) 37 | def test_azureai_chat_smoke(chat_model_name: str): 38 | slowdown() 39 | 40 | lm = _get_chat_model(chat_model_name) 41 | 42 | common_chat_testing.smoke_chat(lm, chat_model_name != "mistral") 43 | 44 | 45 | @pytest.mark.parametrize("chat_model_name", _chat_models.keys()) 46 | def test_azureai_chat_longer_1(chat_model_name: 
str): 47 | slowdown() 48 | 49 | lm = _get_chat_model(chat_model_name) 50 | common_chat_testing.longer_chat_1(lm, chat_model_name != "mistral") 51 | 52 | 53 | @pytest.mark.parametrize("chat_model_name", _chat_models.keys()) 54 | def test_azureai_chat_longer_2(chat_model_name: str): 55 | slowdown() 56 | 57 | lm = _get_chat_model(chat_model_name) 58 | common_chat_testing.longer_chat_2(lm, chat_model_name != "mistral") 59 | -------------------------------------------------------------------------------- /tests/need_credentials/test_cohere.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import guidance 4 | from guidance import assistant, capture, gen, role, select, system, user 5 | 6 | 7 | def test_cohere_basic(): 8 | try: 9 | lm = guidance.models.CohereCompletion("command-nightly") 10 | except: 11 | pytest.skip("Skipping Cohere test because we can't load the model!") 12 | lm += "Count to 20: 1,2,3,4," 13 | nl = "\n" 14 | lm += f"""\ 15 | 5,6,7""" 16 | lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa""" 17 | assert str(lm)[-5:] == "aaaaa" 18 | 19 | 20 | def test_cohere_instruct(): 21 | try: 22 | lm = guidance.models.CohereInstruct("command-nightly") 23 | except: 24 | pytest.skip("Skipping Cohere test because we can't load the model!") 25 | with role("instruction"): 26 | lm += "Count to 20." 27 | lm += gen("val", max_tokens=1) 28 | assert len(lm["val"]) > 0 29 | -------------------------------------------------------------------------------- /tests/need_credentials/test_googleai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from guidance import gen, models, select 4 | 5 | from ..utils import get_model 6 | 7 | 8 | def test_gemini_pro(): 9 | from guidance import assistant, gen, models, system, user 10 | 11 | try: 12 | vmodel = models.GoogleAI("gemini-pro") 13 | except: 14 | pytest.skip("Skipping GoogleAI test because we can't load the model!") 15 | 16 | lm = vmodel 17 | 18 | with user(): 19 | lm += "The economy is crashing!" 20 | 21 | with assistant(): 22 | lm += gen("test1", max_tokens=100) 23 | 24 | with user(): 25 | lm += "What is the best again?" 26 | 27 | with assistant(): 28 | lm += gen("test2", max_tokens=100) 29 | 30 | assert len(lm["test1"]) > 0 31 | assert len(lm["test2"]) > 0 32 | 33 | # second time to make sure cache reuse is okay 34 | lm = vmodel 35 | 36 | with user(): 37 | lm += "The economy is crashing!" 38 | 39 | with assistant(): 40 | lm += gen("test1", max_tokens=100) 41 | 42 | with user(): 43 | lm += "What is the best again?"
44 | 45 | with assistant(): 46 | lm += gen("test2", max_tokens=100) 47 | 48 | assert len(lm["test1"]) > 0 49 | assert len(lm["test2"]) > 0 50 | assert lm["test1"].find("<|im_end|>") < 0 51 | -------------------------------------------------------------------------------- /tests/need_credentials/test_lite_llm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import guidance 4 | from guidance import assistant, capture, gen, select, system, user 5 | 6 | from ..utils import get_model 7 | 8 | 9 | def test_lite_llm_basic_openai(): 10 | try: 11 | lm = guidance.models.LiteLLMCompletion("gpt-3.5-turbo-instruct") 12 | except: 13 | pytest.skip("Skipping LiteLLM test because we can't load the model!") 14 | lm += "Count to 20: 1,2,3,4," 15 | nl = "\n" 16 | lm += f"""\ 17 | 5,6,7""" 18 | lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa""" 19 | assert str(lm)[-5:] == "aaaaa" 20 | 21 | 22 | def test_lite_llm_basic_cohere(): 23 | try: 24 | lm = guidance.models.LiteLLMCompletion("command-nightly") 25 | except: 26 | pytest.skip("Skipping LiteLLM test because we can't load the model!") 27 | lm += "Count to 20: 1,2,3,4," 28 | nl = "\n" 29 | lm += f"""\ 30 | 5,6,7""" 31 | lm += f"""{gen(max_tokens=1, suffix=nl)}aaaaaa""" 32 | assert str(lm)[-5:] == "aaaaa" 33 | 34 | 35 | def test_lite_llm_select(): 36 | try: 37 | lm = guidance.models.LiteLLMCompletion("gpt-3.5-turbo-instruct") 38 | except: 39 | pytest.skip("Skipping LiteLLM test because we can't load the model!") 40 | lm += "Pick a number: " 41 | lm += select( 42 | ["1", "11", "111", "1111", "11111", "111111", "1111111"], name="the number" 43 | ) 44 | assert str(lm)[-1] in "123" 45 | 46 | 47 | def test_lite_llm_chat(): 48 | try: 49 | lm = guidance.models.LiteLLMChat("gpt-3.5-turbo") 50 | except: 51 | pytest.skip("Skipping LiteLLM test because we can't load the model!") 52 | with system(): 53 | lm += "You are a math wiz." 54 | 55 | with user(): 56 | lm += "What is 1 + 1?" 57 | 58 | with assistant(): 59 | lm += gen(max_tokens=10, name="text") 60 | lm += "Pick a number: " 61 | 62 | assert len(lm["text"]) > 0 63 | -------------------------------------------------------------------------------- /tests/need_credentials/test_togetherai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import guidance 4 | from guidance import assistant, gen, select, system, user 5 | 6 | 7 | def test_togetherai_basic(): 8 | try: 9 | lm = guidance.models.TogetherAI("mistralai/Mistral-7B-v0.1") 10 | except: 11 | pytest.skip("Skipping TogetherAI test because we can't load the model!") 12 | lm += "Count to 20: 1,2,3,4," 13 | stop = "\n" 14 | lm += f"""{gen(max_tokens=1, stop=stop, name="text")}""" 15 | assert str(lm)[-1] == "5" 16 | 17 | 18 | def test_togetherai_select(): 19 | try: 20 | lm = guidance.models.TogetherAI("mistralai/Mistral-7B-v0.1") 21 | except: 22 | pytest.skip("Skipping TogetherAI test because we can't load the model!") 23 | nums = ["1", "11", "111", "1111", "11111", "111111", "1111111"] 24 | lm += "Pick a number: " 25 | lm += select(nums, name="number") 26 | assert str(lm["number"]) in nums 27 | 28 | 29 | def test_togetherai_chat(): 30 | try: 31 | lm = guidance.models.TogetherAIChat("teknium/OpenHermes-2-Mistral-7B") 32 | except: 33 | pytest.skip("Skipping TogetherAI test because we can't load the model!") 34 | with system(): 35 | lm += "You are a math wiz." 36 | 37 | with user(): 38 | lm += "What is 1 + 1?" 
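# Roles are mandatory for this chat model: mixing prompt and gen() without
# them raises ValueError (see test_togetherai_chat_without_roles below).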
39 | 40 | with assistant(): 41 | lm += gen(max_tokens=10, name="text") 42 | lm += "Pick a number: " 43 | 44 | assert len(lm["text"]) > 0 45 | 46 | 47 | def test_togetherai_chat_without_roles(): 48 | try: 49 | lm = guidance.models.TogetherAIChat("teknium/OpenHermes-2-Mistral-7B") 50 | except: 51 | pytest.skip("Skipping TogetherAI test because we can't load the model!") 52 | with pytest.raises(ValueError) as error_info: 53 | lm += "You are a math wiz. What is 1+1?" + gen(max_tokens=10, name="text") 54 | 55 | 56 | def test_togetherai_chat_loop(): 57 | try: 58 | model = guidance.models.TogetherAIChat( 59 | "teknium/OpenHermes-2-Mistral-7B", echo=False 60 | ) 61 | except: 62 | pytest.skip("Skipping TogetherAI test because we can't load the model!") 63 | 64 | with system(): 65 | lm = model + "You will just return whatever number I give you" 66 | 67 | for i in range(2): 68 | with user(): 69 | lm += f"The number is: {i}" 70 | 71 | with assistant(): 72 | lm += gen(name="answer", max_tokens=10) 73 | assert len(lm["answer"]) > 0 74 | -------------------------------------------------------------------------------- /tests/need_credentials/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from guidance import models 4 | 5 | from tests.tokenizer_common import TOKENIZER_ROUND_TRIP_STRINGS, BaseTestTransformerTokenizers 6 | 7 | # Since this is under 'need_credentials' we can assume that HF_TOKEN 8 | # will be available when run 9 | 10 | 11 | class TestAuthenticatedTransformerTokenizers(BaseTestTransformerTokenizers): 12 | TRANSFORMER_MODELS = [ 13 | # "google/gemma-2-9b-it", # Works locally, fails in build 14 | "meta-llama/Meta-Llama-3-8B-Instruct", 15 | ] 16 | 17 | @pytest.mark.parametrize( 18 | "model_name", 19 | TRANSFORMER_MODELS, 20 | ) 21 | def test_smoke(self, model_name: str): 22 | try: 23 | self.base_smoke(model_name) 24 | except OSError: 25 | pytest.skip("HuggingFace raises OSError if user is not authenticated.") 26 | 27 | @pytest.mark.parametrize("model_name", TRANSFORMER_MODELS) 28 | @pytest.mark.parametrize("target_string", TOKENIZER_ROUND_TRIP_STRINGS) 29 | def test_string_roundtrip(self, model_name: str, target_string: str): 30 | try: 31 | self.base_string_roundtrip(model_name, target_string) 32 | except OSError: 33 | pytest.skip("HuggingFace raises OSError if user is not authenticated.") 34 | -------------------------------------------------------------------------------- /tests/need_credentials/test_vertexai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from guidance import gen, role, models, select 4 | 5 | from ..utils import get_model 6 | 7 | 8 | def test_palm2_instruct(): 9 | try: 10 | vmodel = models.VertexAI("text-bison@001") 11 | except: 12 | pytest.skip("Skipping VertexAI test because we can't load the model!") 13 | 14 | with role("instruction"): 15 | lm = vmodel + "this is a test about" 16 | lm += gen("test", max_tokens=100) 17 | assert len(lm["test"]) > 0 18 | 19 | 20 | def test_palm2_chat(): 21 | from guidance import assistant, gen, models, system, user 22 | 23 | try: 24 | vmodel = models.VertexAI("chat-bison@001") 25 | except: 26 | pytest.skip("Skipping VertexAI test because we can't load the model!") 27 | 28 | with system(): 29 | lm = vmodel + "You are an always-happy agent no matter what." 30 | 31 | with user(): 32 | lm += "The economy is crashing!" 
33 | 34 | with assistant(): 35 | lm += gen("test1", max_tokens=100) 36 | 37 | with user(): 38 | lm += "What is the best again?" 39 | 40 | with assistant(): 41 | lm += gen("test2", max_tokens=100) 42 | 43 | assert len(lm["test1"]) > 0 44 | assert len(lm["test2"]) > 0 45 | 46 | # second time to make sure cache reuse is okay 47 | with system(): 48 | lm = vmodel + "You are an always-happy agent no matter what." 49 | 50 | with user(): 51 | lm += "The economy is crashing!" 52 | 53 | with assistant(): 54 | lm += gen("test1", max_tokens=100) 55 | 56 | with user(): 57 | lm += "What is the best again?" 58 | 59 | with assistant(): 60 | lm += gen("test2", max_tokens=100) 61 | 62 | assert len(lm["test1"]) > 0 63 | assert len(lm["test2"]) > 0 64 | assert lm["test1"].find("<|im_end|>") < 0 65 | 66 | 67 | def test_gemini_chat(): 68 | from guidance import assistant, gen, models, system, user 69 | 70 | try: 71 | vmodel = models.VertexAI("gemini-pro") 72 | except: 73 | pytest.skip("Skipping VertexAI test because we can't load the model!") 74 | 75 | lm = vmodel 76 | 77 | with user(): 78 | lm += "The economy is crashing!" 79 | 80 | with assistant(): 81 | lm += gen("test1", max_tokens=100) 82 | 83 | with user(): 84 | lm += "What is the best again?" 85 | 86 | with assistant(): 87 | lm += gen("test2", max_tokens=100) 88 | 89 | assert len(lm["test1"]) > 0 90 | assert len(lm["test2"]) > 0 91 | 92 | # second time to make sure cache reuse is okay 93 | lm = vmodel 94 | 95 | with user(): 96 | lm += "The economy is crashing!" 97 | 98 | with assistant(): 99 | lm += gen("test1", max_tokens=100) 100 | 101 | with user(): 102 | lm += "What is the best again?" 103 | 104 | with assistant(): 105 | lm += gen("test2", max_tokens=100) 106 | 107 | assert len(lm["test1"]) > 0 108 | assert len(lm["test2"]) > 0 109 | assert lm["test1"].find("<|im_end|>") < 0 110 | -------------------------------------------------------------------------------- /tests/notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/notebooks/__init__.py -------------------------------------------------------------------------------- /tests/tokenizer_common.py: -------------------------------------------------------------------------------- 1 | from guidance import models 2 | 3 | TOKENIZER_ROUND_TRIP_STRINGS = [ 4 | "", 5 | " ", 6 | "hello", 7 | " hello", 8 | "two words", 9 | " two words", 10 | " two words ", 11 | "two words ", 12 | "’", 13 | "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨", 14 | ] 15 | 16 | 17 | class BaseTestTransformerTokenizers: 18 | def base_smoke(self, model_name: str): 19 | my_tok = models.TransformersTokenizer.from_pretrained( 20 | model_name, trust_remote_code=True, 21 | ) 22 | assert my_tok is not None 23 | 24 | def base_string_roundtrip(self, model_name: str, target_string: str): 25 | my_tok = models.TransformersTokenizer.from_pretrained( 26 | model_name, 27 | trust_remote_code=True, 28 | ) 29 | 30 | encoded = my_tok.encode(target_string.encode()) 31 | decoded = my_tok.decode(encoded) 32 | final_string = decoded.decode() 33 | 34 | assert final_string == target_string 35 | 36 | def base_eos_bos_token_round_trip( 37 | self, model_name: str 38 | ): 39 | my_tok = models.TransformersTokenizer.from_pretrained( 40 | model_name, 41 | trust_remote_code=True, 42 | ) 43 | 44 | assert my_tok.eos_token == my_tok.decode([my_tok.eos_token_id]) 45 | assert my_tok.encode(my_tok.eos_token) == [my_tok.eos_token_id] 46 | 47 | if 
my_tok.bos_token is not None: 48 | assert my_tok.bos_token == my_tok.decode([my_tok.bos_token_id]) 49 | assert my_tok.encode(my_tok.bos_token) == [my_tok.bos_token_id] 50 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/unit/library/__init__.py -------------------------------------------------------------------------------- /tests/unit/library/json/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidance-ai/guidance/e5be7b6bf41e93faeb540ae169509568a1941c9a/tests/unit/library/json/__init__.py -------------------------------------------------------------------------------- /tests/unit/library/json/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | from json import dumps as json_dumps 4 | from json import loads as json_loads 5 | from typing import Any, Optional, Union 6 | 7 | from jsonschema import validate 8 | 9 | from guidance import json as gen_json 10 | from guidance.library._json import JSONSchema 11 | 12 | from ....utils import check_match_failure as _check_match_failure 13 | from ....utils import check_run_with_temperature 14 | from ....utils import generate_and_check as _generate_and_check 15 | 16 | 17 | def generate_and_check( 18 | target_obj: Any, 19 | schema_obj: Union[str, JSONSchema], 20 | desired_temperature: Optional[float] = None, 21 | ): 22 | if isinstance(schema_obj, str): 23 | schema_obj = json_loads(schema_obj) 24 | 25 | # Sanity check what we're being asked 26 | validate(instance=target_obj, schema=schema_obj) 27 | prepared_json = json_dumps(target_obj) 28 | assert json.loads(prepared_json) == target_obj 29 | 30 | # Now test that the grammar can recognize and generate prepared_json 31 | # We partial in the grammar_callable 32 | if desired_temperature is not None: 33 | grammar_callable = partial(gen_json, schema=schema_obj, temperature=desired_temperature) 34 | else: 35 | grammar_callable = partial(gen_json, schema=schema_obj) 36 | 37 | lm = _generate_and_check( 38 | grammar_callable, 39 | test_string=prepared_json, 40 | ) 41 | check_run_with_temperature(lm, desired_temperature) 42 | 43 | 44 | def check_match_failure( 45 | *, 46 | bad_string: str, 47 | good_bytes: Optional[bytes] = None, 48 | failure_byte: Optional[bytes] = None, 49 | allowed_bytes: Optional[set[bytes]] = None, 50 | schema_obj: Union[str, JSONSchema], 51 | ): 52 | grammar = gen_json(schema=schema_obj) 53 | 54 | _check_match_failure( 55 | bad_string=bad_string, 56 | good_bytes=good_bytes, 57 | failure_byte=failure_byte, 58 | allowed_bytes=allowed_bytes, 59 | grammar=grammar, 60 | ) 61 | -------------------------------------------------------------------------------- /tests/unit/library/test_block.py: -------------------------------------------------------------------------------- 1 | from guidance import regex, block, models 2 | import pytest 3 | 4 | 5 | def test_text_opener(): 6 | model = 
models.Mock("open texta") 7 | with block(opener="open text"): 8 | model += regex(r".") 9 | assert str(model) == "open texta" 10 | 11 | 12 | def test_text_closer(): 13 | # NOTE(nopdive): Behavioral change, no longer need closer for str call. 14 | model = models.Mock("a") 15 | model += "" 16 | with block(closer="close text"): 17 | model += regex(r".") 18 | assert str(model) == "a" 19 | 20 | 21 | def test_grammar_opener(): 22 | model = models.Mock("open texta") 23 | with block(opener="open tex" + regex(r".")): 24 | model += regex(r".") 25 | assert str(model) == "open texta" 26 | 27 | 28 | # TODO(nopdive): Review this exception later -- how should we be going about grammars in blocks overall. 29 | @pytest.mark.skip(reason="requires review") 30 | def test_grammar_closer(): 31 | model = models.Mock(["aclose text", "close text"]) 32 | model += "" 33 | try: 34 | with block(closer=regex(r".") + "lose text"): 35 | model += regex(r".") 36 | except: 37 | return # we expect an exception 38 | assert ( 39 | False 40 | ), "We should have thrown an exception using a context (prompt) based grammar in the closer!" 41 | 42 | 43 | def test_block_name_capture(): 44 | model = models.Mock("open texta") 45 | with block("my_data"): 46 | model += "open text" 47 | model += regex(r".") 48 | assert model["my_data"] == "open texta" 49 | 50 | 51 | def test_block_name_capture_closed(): 52 | model = models.Mock("open texta") 53 | with block("my_data"): 54 | model += "open text" 55 | model += regex(r".") 56 | model += "tmp" 57 | assert model["my_data"] == "open texta" 58 | -------------------------------------------------------------------------------- /tests/unit/library/test_capture.py: -------------------------------------------------------------------------------- 1 | from guidance import capture, models, one_or_more, select, guidance 2 | 3 | 4 | def test_capture(): 5 | model = models.Mock() 6 | model += "This is" + capture(select(options=["bad", "quite bad"]), name="my_var") 7 | assert model["my_var"] in ["bad", "quite bad"] 8 | 9 | 10 | def test_capture_star(): 11 | lm = models.Mock(b"1234233234") 12 | grammar = capture(one_or_more(select(["1", "2"])), name="test") 13 | lm2 = lm + grammar 14 | assert lm2["test"] == "12" 15 | 16 | 17 | def test_capture_raw_function(): 18 | lm = models.Mock(b"1234233234") 19 | lm += select(["1", "2"], name="state") 20 | 21 | @guidance 22 | def raw_fn(lm): 23 | if lm["state"] == "1": 24 | lm += select(["3", "4"], name="state_1") 25 | elif lm["state"] == "2": 26 | lm += select(["5", "6"], name="state_2") 27 | return lm 28 | 29 | lm_nocap = lm + "the beginning|" + raw_fn() + "|the end" 30 | lm_cap_arg = lm + "the beginning|" + capture("" + raw_fn() + "" , "cap_arg") + "|the end" 31 | lm_cap_kwarg = lm + "the beginning|" + capture("" + raw_fn() + "", name="cap_kwarg") + "|the end" 32 | 33 | # Bunch of random tests 34 | assert "state_1" in lm_nocap or "state_2" in lm_nocap 35 | assert "cap_arg" in lm_cap_arg 36 | assert "cap_kwarg" in lm_cap_kwarg 37 | assert lm_cap_arg["cap_arg"].startswith("") 38 | assert lm_cap_arg["cap_arg"].endswith("") 39 | assert lm_cap_kwarg["cap_kwarg"].startswith("") 40 | assert lm_cap_kwarg["cap_kwarg"].endswith("") 41 | assert len(lm_cap_arg["cap_arg"]) == len(lm_cap_kwarg["cap_kwarg"]) 42 | 43 | assert str(lm_nocap).endswith("|the end") 44 | assert str(lm_cap_arg).endswith("|the end") 45 | assert str(lm_cap_kwarg).endswith("|the end") -------------------------------------------------------------------------------- /tests/unit/library/test_image.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import uuid 3 | import requests 4 | import tempfile 5 | import pathlib 6 | 7 | from urllib.error import HTTPError, URLError 8 | from guidance import models, image 9 | from ...utils import remote_image_url 10 | 11 | ################################################################################# 12 | # The tests below need to be rewritten once multimodal support is complete 13 | # A pseudocode description has been written in comments to preserve notes about 14 | # what was tested, for reference, in case we want to reproduce it in the new system 15 | ################################################################################# 16 | 17 | def test_local_image(): 18 | # 1. Create a mock model 19 | # 2. Add an image from the local filesystem to the model's prompt 20 | # 3. Validate that the model contains the image in its prompt 21 | pass 22 | 23 | 24 | def test_local_image_not_found(): 25 | # 1. Create a mock model 26 | # 2. Try to add a non-existing image from the local filesystem to the model's prompt 27 | # 3. Check for a file not found error, or other appropriate exception, to be thrown 28 | pass 29 | 30 | 31 | def test_remote_image(): 32 | # 1. Create a mock model 33 | # 2. Add a remote image from picsum using the remote_image_url() utility function 34 | # 3. Validate that the model contains the image in its prompt 35 | pass 36 | 37 | 38 | def test_remote_image_not_found(): 39 | # 1. Create a mock model 40 | # 2. Try to add a non-existing remote image 41 | # 3. Catch an HTTPError or URLError from the model trying to fetch the image, which should result in a 404 42 | pass 43 | 44 | 45 | def test_image_from_bytes(): 46 | # 1. Create a mock model 47 | # 2. Download an image from remote_image_url() and save it as a binary file 48 | # 3. Read the binary file and add it to the model's prompt as an image 49 | # 4. 
Validate that the model contains the image in its prompt 50 | pass 51 | -------------------------------------------------------------------------------- /tests/unit/library/test_one_or_more.py: -------------------------------------------------------------------------------- 1 | from guidance import models, one_or_more, regex 2 | 3 | 4 | def test_string(): 5 | model = models.Mock("aaabc") 6 | assert str(model + "" + one_or_more("a")) == "aaa" 7 | 8 | 9 | def test_grammar(): 10 | model = models.Mock("bac") 11 | assert str(model + "" + one_or_more(regex(r"[ab]"))) == "ba" 12 | 13 | 14 | def test_at_least_one(): 15 | model = models.Mock("cbac") 16 | assert not str(model + "" + one_or_more(regex(r"[ab]"))).startswith("c") 17 | -------------------------------------------------------------------------------- /tests/unit/library/test_subgrammar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from guidance.library._subgrammar import subgrammar, lexeme 3 | 4 | class TestEndingLexemeAmbiguous: 5 | @pytest.mark.parametrize( 6 | "skip_rx", 7 | [None, r"\s", r"\s+", r"\s*"] 8 | ) 9 | @pytest.mark.parametrize( 10 | "string", 11 | ["123"] 12 | ) 13 | def test_lexeme_can_be_done_even_if_could_match_more(self, string, skip_rx): 14 | g1 = subgrammar(body=lexeme(r"\d+"), skip_regex=skip_rx, name="mycap") 15 | assert (m := g1.match(string)) is not None and m.captures["mycap"] == string 16 | g2 = g1 + "x" 17 | assert (m := g2.match(f"{string}x")) is not None and m.captures["mycap"] == string 18 | 19 | @pytest.mark.parametrize( 20 | "string", 21 | ["1", "123", "1x", "123x"] 22 | ) 23 | def test_nullable_final_lexeme(self, string): 24 | g = subgrammar(body=lexeme(r"\d+")+lexeme(r"x?"), name="mycap") 25 | match = g.match(string) 26 | assert match is not None 27 | assert match.captures["mycap"] == string 28 | -------------------------------------------------------------------------------- /tests/unit/library/test_substring.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from guidance import models, substring 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ("mock_string", "target_string", "expected_string"), 8 | [ 9 | ("abc", "abc", "abc"), 10 | ("ab", "abc", "ab"), 11 | ("bc", "abc", "bc"), 12 | ("a", "abc", "a"), 13 | ("b", "abc", "b"), 14 | ("c", "abc", "c"), 15 | ("abc", "def", ""), # This is a 'failure' case 16 | ( 17 | "long string", 18 | "This is long string, only take part of this long string", 19 | "long string", 20 | ), 21 | ], 22 | ) 23 | def test_mocked_substring(mock_string, target_string, expected_string): 24 | m = models.Mock(f"{mock_string}") 25 | 26 | lm = m + substring(target_string, chunk="character", name="result") 27 | assert lm["result"] == expected_string 28 | -------------------------------------------------------------------------------- /tests/unit/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import guidance 3 | from guidance import gen, models, user, system 4 | 5 | def test_call_embeddings(): 6 | """This tests calls embedded in strings.""" 7 | model = models.Mock() 8 | 9 | @guidance(dedent=False) 10 | def bla(lm, bla): 11 | lm += bla + "ae" + gen(max_tokens=10) 12 | return lm 13 | 14 | @guidance(dedent=False) 15 | def ble(lm): 16 | lm += f""" 17 | ae galera! 
{bla('33')} 18 | let's do more stuff!!""" + gen( 19 | max_tokens=10 20 | ) 21 | return lm 22 | 23 | assert "{{G|" not in str(model + ble()) 24 | 25 | 26 | @pytest.mark.xfail( 27 | reason="llguidance currently emits an additional empty capture group when no explicit stop is provided" 28 | ) 29 | def test_model_set(): 30 | model = models.Mock() 31 | model = model.set("num", "4") 32 | assert "num" in model 33 | assert model["num"] == "4" 34 | assert model.log_prob("num") is not None 35 | 36 | model = model.set("list_num", ['1', '2']) 37 | assert "list_num" in model 38 | assert model["list_num"] == ['1', '2'] 39 | assert model.log_prob("list_num") is not None 40 | 41 | model += gen("list_num", max_tokens=10, list_append=True) 42 | assert len(model['list_num']) == 3 43 | 44 | 45 | def test_trace(): 46 | from guidance import system, user, gen, models 47 | m0 = models.Mock() 48 | 49 | with system(): 50 | m1 = m0 + "You are responsible for autocompleting a sentence." 51 | with user(): 52 | m2 = m1 + "Roses are red and " + gen(name="suffix", regex='[A-Za-z]{2,5}', max_tokens=5) 53 | 54 | assert m2['suffix'] is not None -------------------------------------------------------------------------------- /tests/unit/test_trace.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from guidance.trace._trace import WeakRefList, TraceHandler, LiteralInput, TextOutput, RoleCloserInput 4 | from guidance.trace import TraceNode, StatelessGuidanceInput, StatefulGuidanceInput, ImageInput, EmbeddedInput, \ 5 | RoleOpenerInput, ImageOutput, CaptureOutput 6 | 7 | 8 | def test_weak_ref_list(): 9 | class EmptyClass: 10 | pass 11 | 12 | a = EmptyClass() 13 | b = EmptyClass() 14 | li = WeakRefList() 15 | li.append(a) 16 | li.append(b) 17 | 18 | del a 19 | with pytest.raises(ReferenceError): 20 | _ = li[0] 21 | 22 | # Does not remove dead entries 23 | _ = li[1] 24 | assert len(li) == 2 25 | 26 | # Remove works as expected 27 | li.remove(b) 28 | assert len(li) == 1 29 | 30 | # Iter goes over live entries only 31 | for el in li: 32 | _ = el 33 | 34 | 35 | def test_trace_node(): 36 | root = TraceNode() 37 | child1 = TraceNode() 38 | child2 = TraceNode() 39 | root.add_child(child1) 40 | root.add_child(child2) 41 | 42 | assert root.root() is root 43 | assert list(root.ancestors()) == [] 44 | assert list(root.path()) == [root] 45 | assert list(root.traverse()) == [root, child1, child2] 46 | 47 | assert child1.root() is root 48 | assert list(child1.ancestors()) == [root] 49 | assert list(child1.path()) == [root, child1] 50 | assert list(child1.traverse()) == [child1] 51 | 52 | 53 | def test_trace_handler(): 54 | trace_handler = TraceHandler() 55 | root = trace_handler.update_node(0, None, None) 56 | child1 = trace_handler.update_node(1, 0, None) 57 | inp = LiteralInput(value="") 58 | out = TextOutput(value="") 59 | pre_child2 = trace_handler.update_node(2, 0, inp) 60 | child2 = trace_handler.update_node(2, 0, out) 61 | 62 | assert pre_child2 == child2 63 | assert child2.input == inp 64 | assert child2.output == out 65 | assert child2.root() == root 66 | assert child1 not in child2.path() 67 | 68 | 69 | 70 | @pytest.mark.parametrize( 71 | 'node', 72 | [ 73 | StatelessGuidanceInput(value=None), 74 | StatefulGuidanceInput(value=None), 75 | LiteralInput(value=""), 76 | ImageInput(value=b""), 77 | EmbeddedInput(value=""), 78 | RoleOpenerInput(name=""), 79 | RoleCloserInput(name=""), 80 | TextOutput(value=""), 81 | ImageOutput(value=b""), 82 | CaptureOutput(name=""), 83 | ] 84 
| ) 85 | def test_node_format_smoke(node): 86 | node.__repr__() 87 | node.__str__() 88 | -------------------------------------------------------------------------------- /tests/unit/test_visual.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from guidance._schema import GenTokenExtra, GenToken 3 | from guidance.registry import get_bg_async 4 | from guidance.trace import TraceHandler, LiteralInput, TextOutput 5 | from guidance.visual import TraceMessage, MetricMessage, ExecutionCompletedMessage, \ 6 | TokensMessage, ResetDisplayMessage, ClientReadyMessage, OutputRequestMessage, \ 7 | ClientReadyAckMessage, trace_node_to_html, display_trace_tree, trace_node_to_str, TopicExchange, GuidanceMessage 8 | from guidance.visual import serialize_message, deserialize_message 9 | from guidance.visual._environment import Environment 10 | import asyncio 11 | 12 | from guidance.visual._exchange import DEFAULT_TOPIC 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "message", 17 | [ 18 | TraceMessage(trace_id=0), 19 | MetricMessage(name="name", value="value"), 20 | ExecutionCompletedMessage(last_trace_id=0), 21 | TokensMessage(trace_id=0, text="text", tokens=[ 22 | GenTokenExtra(token_id=0, prob=0, top_k=[GenToken(token_id=0, prob=0)]) 23 | ]), 24 | ResetDisplayMessage(), 25 | ClientReadyMessage(), 26 | ClientReadyAckMessage(), 27 | OutputRequestMessage(), 28 | ] 29 | ) 30 | def test_serialization(message): 31 | ser = serialize_message(message) 32 | deser = deserialize_message(ser) 33 | assert deser.model_dump() == message.model_dump() 34 | 35 | 36 | def test_async(): 37 | _, loop = get_bg_async()._thread_and_loop() 38 | assert loop != asyncio.get_event_loop() 39 | 40 | async def f(): 41 | return True 42 | 43 | task = get_bg_async().run_async_coroutine(get_bg_async().async_task(f())).result() 44 | assert task.result() is True 45 | 46 | 47 | def test_str_method_smoke(): 48 | trace_handler = TraceHandler() 49 | trace_handler.update_node(1, 0, None) 50 | inp = LiteralInput(value="Hi there!") 51 | out = TextOutput(value="Hi there!") 52 | trace_handler.update_node(2, 0, inp) 53 | child_node = trace_handler.update_node(2, 0, out) 54 | 55 | assert trace_node_to_html(child_node) != "" 56 | assert trace_node_to_str(child_node) != "" 57 | assert display_trace_tree(trace_handler) is None 58 | 59 | 60 | def test_environment(): 61 | env = Environment() 62 | assert not env.is_cloud() 63 | assert not env.is_notebook() 64 | assert env.is_terminal() 65 | assert "ipython-zmq" not in env.detected_envs 66 | 67 | 68 | def test_exchange(): 69 | exchange = TopicExchange() 70 | assert len(exchange._observers) == 0 71 | 72 | count = 0 73 | def inc(_: GuidanceMessage): 74 | nonlocal count 75 | count += 1 76 | 77 | # Defaults 78 | exchange.subscribe(inc) 79 | exchange.publish(GuidanceMessage()) 80 | exchange.unsubscribe(inc) 81 | assert count == 1 82 | assert len(exchange._observers) == 0 83 | 84 | # Topic pattern set 85 | topic_pat = "no" 86 | exchange.subscribe(inc, topic_pat) 87 | exchange.publish(GuidanceMessage(), topic_pat) 88 | exchange.unsubscribe(inc, topic_pat) 89 | assert count == 2 90 | assert len(exchange._observers) == 0 91 | 92 | # Missed topic 93 | topic_pat = "what" 94 | exchange.subscribe(inc, topic_pat) 95 | exchange.publish(GuidanceMessage()) 96 | exchange.unsubscribe(inc, topic_pat) 97 | assert count == 2 98 | assert len(exchange._observers) == 0 --------------------------------------------------------------------------------
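# A minimal usage sketch of the TopicExchange pub/sub API exercised in
# test_exchange above, assuming only the subscribe/publish/unsubscribe calls
# shown there (the "metrics" topic name is illustrative, not from the suite):
#
#     from guidance.visual import TopicExchange, GuidanceMessage
#
#     exchange = TopicExchange()
#
#     def on_message(msg: GuidanceMessage) -> None:
#         print(type(msg).__name__)  # react to each published message
#
#     exchange.subscribe(on_message, "metrics")       # observe one topic
#     exchange.publish(GuidanceMessage(), "metrics")  # delivered to on_message
#     exchange.unsubscribe(on_message, "metrics")     # detach the observer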