├── .github └── workflows │ ├── docs.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── benchmark └── benchmark_backend.py ├── codecov.yml ├── docs ├── anatomy.md ├── batching.md ├── caching.md ├── gen_reference_page.py ├── getting_started.md ├── immutability.md ├── index.md ├── performance.md ├── transformers.md └── visualization.md ├── examples ├── __init__.py ├── grammar_constraint.py ├── haiku.py └── hard_constraints.py ├── html ├── results │ └── output.json └── smc.html ├── llamppl ├── __init__.py ├── chunks.py ├── distributions │ ├── __init__.py │ ├── bernoulli.py │ ├── distribution.py │ ├── geometric.py │ ├── lmcontext.py │ ├── logcategorical.py │ ├── tokencategorical.py │ └── transformer.py ├── inference │ ├── __init__.py │ ├── smc_record.py │ ├── smc_standard.py │ └── smc_steer.py ├── llms.py ├── modeling.py └── util.py ├── mkdocs.yml ├── pyproject.toml └── tests ├── test_examples.py └── test_lmcontext.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: 3.x 16 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 17 | - uses: actions/cache@v3 18 | with: 19 | key: mkdocs-material-${{ env.cache_id }} 20 | path: .cache 21 | restore-keys: | 22 | mkdocs-material- 23 | - run: pip install mkdocs-material mkdocstrings mkdocs-literate-nav mkdocs-section-index mkdocs-gen-files mkdocstrings-python 24 | - run: mkdocs gh-deploy --force 25 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | 8 | jobs: 9 | release: 10 | runs-on: ubuntu-22.04 11 | 12 | # Add "id-token" with the intended permissions. 13 | permissions: 14 | contents: 'read' 15 | id-token: 'write' 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | # This is here so that the versioning plugin will be able to see tags 21 | # and version using them. 
22 | fetch-depth: 0 23 | 24 | - uses: actions/setup-python@v4 25 | with: 26 | python-version: 3.11.5 27 | 28 | - name: Build package 29 | run: | 30 | python3 -m pip install --upgrade build 31 | python3 -m build 32 | 33 | - name: Publish to PyPI 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | with: 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # Builds the `llamppl` environment and runs all tests 2 | 3 | name: Codebase tests 4 | 5 | on: 6 | pull_request: 7 | push: 8 | branches: 9 | - main 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build: 16 | runs-on: ParallelHoss 17 | 18 | steps: 19 | - name: Check out repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up python 23 | id: setup-python 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.11.5' 27 | 28 | - name: Run Tests 29 | run: | 30 | python -m venv venv 31 | source venv/bin/activate 32 | pip install -e .[dev,examples] 33 | # Add the project root to the PYTHONPATH for examples 34 | PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json 35 | 36 | - name: Upload coverage to Codecov 37 | uses: codecov/codecov-action@v5 38 | with: 39 | fail_ci_if_error: false 40 | disable_search: true 41 | token: ${{ secrets.CODECOV_TOKEN }} 42 | files: ./coverage.json 43 | slug: genlm/llamppl 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.json 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-yaml 6 | args: [--unsafe] 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.9.9 12 | hooks: 13 | - id: ruff-format 14 | types_or: [ python, pyi, jupyter ] 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLaMPPL 2 | 3 | [![docs](https://github.com/genlm/llamppl/actions/workflows/docs.yml/badge.svg)](https://genlm.github.io/llamppl) 4 | [![Tests](https://github.com/genlm/llamppl/actions/workflows/tests.yml/badge.svg)](https://github.com/genlm/llamppl/actions/workflows/tests.yml) 5 | [![codecov](https://codecov.io/gh/genlm/llamppl/graph/badge.svg?token=pgVQBiqCuM)](https://codecov.io/gh/genlm/llamppl) 6 | 7 | 8 | LLaMPPL is a research prototype for language model probabilistic programming: specifying language generation tasks by writing probabilistic programs that combine calls to LLMs, symbolic program logic, and probabilistic conditioning. 
To solve these tasks, LLaMPPL uses a specialized sequential Monte Carlo inference algorithm. This technique, SMC steering, is described in [our recent workshop abstract](https://arxiv.org/abs/2306.03081). 9 | 10 | This library was formerly known as `hfppl`. 11 | 12 | ## Installation 13 | 14 | If you just want to try out LLaMPPL, check out our [demo notebook on Colab](https://colab.research.google.com/drive/1uJEC-U8dcwsTWccCDGVexpgXexzZ642n?usp=sharing), which performs a simple constrained generation task using GPT-2. (Larger models may require more RAM or GPU resources than Colab's free version provides.) 15 | 16 | To get started on your own machine, you can install this library from PyPI: 17 | 18 | ``` 19 | pip install llamppl 20 | ``` 21 | 22 | ### Local installation 23 | 24 | For local development, clone this repository and run `pip install -e ".[dev,examples]"` to install `llamppl` and its development dependencies. 25 | 26 | ``` 27 | git clone https://github.com/genlm/llamppl 28 | cd llamppl 29 | pip install -e ".[dev,examples]" 30 | ``` 31 | 32 | Then, try running an example. Note that this will cause the weights of a HuggingFace model to be downloaded. 33 | 34 | ``` 35 | python examples/hard_constraints.py 36 | ``` 37 | 38 | If everything is working, you should see the model generate political news using words that are at most five letters long (e.g., "Dr. Jill Biden may still be a year away from the White House but she is set to make her first trip to the U.N. today."). 39 | 40 | ## Modeling with LLaMPPL 41 | 42 | A LLaMPPL program is a subclass of the `llamppl.Model` class. 43 | 44 | ```python 45 | from llamppl import Model, LMContext, CachedCausalLM 46 | 47 | # A LLaMPPL model subclasses the Model class 48 | class MyModel(Model): 49 | 50 | # The __init__ method is used to process arguments 51 | # and initialize instance variables. 52 | def __init__(self, lm, prompt, forbidden_letter): 53 | super().__init__() 54 | 55 | # A stateful context object for the LLM, initialized with the prompt 56 | self.context = LMContext(lm, prompt) 57 | self.eos_token = lm.tokenizer.eos_token_id 58 | 59 | # The forbidden letter 60 | self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab) 61 | if forbidden_letter in v) 62 | 63 | # The step method is used to perform a single 'step' of generation. 64 | # This might be a single token, a single phrase, or any other division. 65 | # Here, we generate one token at a time. 66 | async def step(self): 67 | # Condition on the next token *not* being a forbidden token. 68 | await self.observe(self.context.mask_dist(self.forbidden_tokens), False) 69 | 70 | # Sample the next token from the LLM -- automatically extends `self.context`. 71 | token = await self.sample(self.context.next_token()) 72 | 73 | # Check for EOS or end of sentence 74 | if token.token_id == self.eos_token or str(token) in ['.', '!', '?']: 75 | # Finish generation 76 | self.finish() 77 | 78 | # To improve performance, a hint that `self.forbidden_tokens` is immutable 79 | def immutable_properties(self): 80 | return set(['forbidden_tokens']) 81 | ``` 82 | 83 | The Model class provides a number of useful methods for specifying a LLaMPPL program: 84 | 85 | * `self.sample(dist[, proposal])` samples from the given distribution. Providing a proposal does not modify the task description, but can improve inference. Here, for example, we use a proposal that pre-emptively avoids the forbidden letter. 86 | * `self.condition(cond)` conditions on the given Boolean expression. 
87 | * `self.finish()` indicates that generation is complete.
88 | * `self.observe(dist, obs)` performs a form of 'soft conditioning' on the given distribution. It is equivalent to (but more efficient than) sampling a value `v` from `dist` and then immediately running `condition(v == obs)`.
89 | 
90 | To run inference, we use the `smc_steer` or `smc_standard` methods:
91 | 
92 | ```python
93 | import asyncio
94 | from llamppl import smc_steer
95 | 
96 | # Initialize the language model
97 | lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
98 | 
99 | # Create a model instance
100 | model = MyModel(lm, "The weather today is expected to be", "e")
101 | 
102 | # Run inference
103 | particles = asyncio.run(smc_steer(model, 5, 3)) # number of particles N, and beam factor K
104 | ```
105 | 
106 | Sample output:
107 | 
108 | ```
109 | sunny.
110 | sunny and cool.
111 | 34° (81°F) in Chicago with winds at 5mph.
112 | 34° (81°F) in Chicago with winds at 2-9 mph.
113 | hot and humid with a possibility of rain, which is not uncommon for this part of Mississippi.
114 | ```
115 | 
116 | Further documentation can be found at https://genlm.github.io/llamppl.
117 | 
-------------------------------------------------------------------------------- /benchmark/benchmark_backend.py: --------------------------------------------------------------------------------
1 | """
2 | Requires pytest and pytest-benchmark (pip install pytest pytest-benchmark)
3 | 
4 | Example usage: pytest benchmark/benchmark_backend.py --benchmark-only --benchmark-group-by=func -v
5 | """
6 | 
7 | import asyncio
8 | 
9 | import pytest
10 | import torch
11 | 
12 | from examples.haiku import run_example as run_haiku
13 | from examples.hard_constraints import run_example as run_hard_constraints
14 | from llamppl.llms import CachedCausalLM
15 | 
16 | backends = [
17 |     "hf",
18 |     pytest.param(
19 |         "vllm",
20 |         marks=pytest.mark.skipif(
21 |             not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
22 |         ),
23 |     ),
24 | ]
25 | 
26 | 
27 | @pytest.fixture
28 | def LLM(backend):
29 |     # Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU
30 |     kwargs = (
31 |         {"engine_opts": {"gpu_memory_utilization": 0.45}, "cache_size": 100}
32 |         if backend == "vllm"
33 |         else {}
34 |     )
35 |     return CachedCausalLM.from_pretrained(
36 |         "meta-llama/Meta-Llama-3-8B", backend=backend, **kwargs
37 |     )
38 | 
39 | 
40 | @pytest.mark.parametrize("backend", backends)
41 | def test_hard_constraints_benchmark(LLM, benchmark, n_particles=20, max_tokens=50):
42 |     def run_with_clear_cache():
43 |         LLM.clear_cache()
44 |         return asyncio.run(
45 |             run_hard_constraints(LLM, max_tokens=max_tokens, n_particles=n_particles)
46 |         )
47 | 
48 |     # warmup
49 |     run_with_clear_cache()
50 | 
51 |     benchmark.pedantic(
52 |         run_with_clear_cache,
53 |         iterations=1,
54 |         rounds=3,
55 |     )
56 | 
57 | 
58 | @pytest.mark.parametrize("backend", backends)
59 | def test_haiku_benchmark(LLM, benchmark, n_particles=20):
60 |     def run_with_clear_cache():
61 |         LLM.clear_cache()
62 |         return asyncio.run(
63 |             run_haiku(LLM, poem_title="The beauty of testing", n_particles=n_particles)
64 |         )
65 | 
66 |     # warmup
67 |     run_with_clear_cache()
68 | 
69 |     benchmark.pedantic(
70 |         run_with_clear_cache,
71 |         iterations=1,
72 |         rounds=3,
73 |     )
74 | 
-------------------------------------------------------------------------------- /codecov.yml: --------------------------------------------------------------------------------
1 | coverage:
2 |   status:
3 |     project:
4 |       default:
5 |         target: auto
6 | 
threshold: 3% 7 | -------------------------------------------------------------------------------- /docs/anatomy.md: -------------------------------------------------------------------------------- 1 | # Anatomy of a LLaMPPL model 2 | -------------------------------------------------------------------------------- /docs/batching.md: -------------------------------------------------------------------------------- 1 | # Auto-Batching 2 | 3 | If running in a GPU-accelerated environment, LLaMPPL supports **auto-batching**. 4 | 5 | The `step` method of a LLaMPPL model describes how to advance a *single* particle one step of generation. 6 | But inference methods must maintain many particles at once. 7 | 8 | With auto-batching, LLaMPPL will execute particles' `step` methods concurrently, and automatically batch calls 9 | to large language models. This batching is handled by the `CachedCausalLM` object. 10 | 11 | If you are using the `vllm` backend, all batching decisions are handled internally by `vllm`. 12 | 13 | If you are using the huggingface backend, the behavior is controlled by two parameters: 14 | 15 | * `lm.batch_size`: the maximum number of requests to batch. The default value is 20. 16 | * `lm.timeout`: if `lm.timeout` seconds pass with no new request, the current batch is processed even if not full. The default value is 0.02. 17 | 18 | You may want to set the batch size (`#!python lm.batch_size`) to the number of particles you are using (if the number of particles is not too large). 19 | -------------------------------------------------------------------------------- /docs/caching.md: -------------------------------------------------------------------------------- 1 | # Caching in LLaMPPL 2 | 3 | LLaMPPL performs two kinds of caching to improve performance. The caching behavior is dependent on the backend you are using with your `CachedCausalLM`. 4 | 5 | ## Log probability caching 6 | With the huggingface backend, next-token log probabilities are always cached, whenever they are computed. 7 | This way, if different particles make exactly the same log probability queries, 8 | the Transformer is run only once. This is primarily beneficial when: 9 | 10 | * particles are cloned during resampling 11 | 12 | * cloned particles happen to sample the same next token: if the next-token distribution is concentrated, 13 | it is likely that multiple copies of a particle will sample the same next token. Log probability caching 14 | allows them to sample the _following_ token using only a single call to the language model. 15 | 16 | The log probability cache can be cleared using the [`lm.clear_cache()`][llamppl.llms.CachedCausalLM.clear_cache] method. Note that for the huggingface backend, this method will also clear the KV cache. 17 | 18 | With the `vllm` backend, the log probability cache can be turned on by passing in a `cache_size` parameter to the `CachedCausalLM.from_pretrained` method; for example, `CachedCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", cache_size=100)`. By default, `cache_size` is set to 0, which means that the log probability cache is disabled. 19 | 20 | ## Key-value caching 21 | Key-value caching caches the key and value vectors computed by each layer of a Transformer, 22 | for reuse when processing new tokens at the end of a previously evaluated sequence. 23 | 24 | In principle, key-value caching is most useful when: 25 | 26 | * There is a long common *prompt* from which all particles are generating. 
27 | In this case, the prompt's tokens can be evaluated just once by the language model, 28 | and each subsequent call only has to pay for the new tokens generated after the prompt. 29 | 30 | * Generations from the model are very long. In this case, it may be worth paying the memory 31 | cost to cache *different* key-value sequences for *each* particle, to speed up future next-token 32 | queries. 33 | 34 | When using the `vllm` backend, both types of caching are handled automatically. 35 | 36 | With the huggingface backend, only the first use case is well-supported by the LLaMPPL library, via the 37 | [`lm.cache_kv(prompt)`][llamppl.llms.CachedCausalLM.cache_kv] method. This method computes and caches key and value vectors 38 | for every token in `prompt`. Future calls to [`lm.next_token_logprobs`][llamppl.llms.CachedCausalLM.next_token_logprobs] and [`lm.next_token_logprobs_unbatched`][llamppl.llms.CachedCausalLM.next_token_logprobs_unbatched] 39 | will automatically recognize when `prompt` is a prefix of the new query, and automatically 40 | exploit incremental computation. Multiple prompts can be cached, and [`lm.clear_kv_cache()`][llamppl.llms.CachedCausalLM.clear_kv_cache] can 41 | be used to clear the KV-cache without clearing the log probability cache. 42 | 43 | Because [`lm.cache_kv`][llamppl.llms.CachedCausalLM.cache_kv] is not a batched call, 44 | it is not well-suited to caching 45 | different strings for different particles. 46 | Rather, it is best used in the `__init__` method of a model--or even 47 | outside of a model--on fixed prompt strings that every particle will share. 48 | -------------------------------------------------------------------------------- /docs/gen_reference_page.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.Nav() 8 | 9 | for path in sorted(Path("llamppl").rglob("*.py")): 10 | if any(part.startswith(".") for part in path.parts): 11 | continue 12 | 13 | module_path = path.relative_to(".").with_suffix("") 14 | doc_path = path.relative_to(".").with_suffix(".md") 15 | full_doc_path = Path("reference", doc_path) 16 | 17 | parts = tuple(module_path.parts) 18 | 19 | if parts[-1] == "__init__": 20 | print(f"init, making parts {parts[:-1]}") 21 | parts = parts[:-1] 22 | elif parts[-1] == "__main__": 23 | continue 24 | 25 | nav[parts] = doc_path.as_posix() # 26 | 27 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 28 | ident = ".".join(parts) 29 | fd.write(f"::: {ident}") 30 | 31 | mkdocs_gen_files.set_edit_path(full_doc_path, path) 32 | 33 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: # 34 | nav_file.writelines(nav.build_literate_nav()) # 35 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Colab 4 | 5 | One easy way to try LLaMPPL out is to use a Colab notebook. We have [a demo notebook](https://colab.research.google.com/drive/1uJEC-U8dcwsTWccCDGVexpgXexzZ642n?usp=sharing) that performs constrained generation with GPT-2, a small enough model that the RAM and GPU constraints of Colab's free version should not prevent you from running the demo. 6 | 7 | ## Installing LLaMPPL 8 | 9 | To get started, clone the `llamppl` repository and install the `llamppl` package. 
10 | 
11 | ```bash
12 | git clone https://github.com/genlm/llamppl
13 | cd llamppl
14 | poetry install
15 | ```
16 | 
17 | We use [poetry](https://python-poetry.org/) to manage dependencies. If you don't have poetry installed, you can install it with `pip install poetry`.
18 | 
19 | You can then run an example. The first time you run it, the example will download model weights from the HuggingFace model repository.
20 | 
21 | ```bash
22 | poetry run python examples/hard_constraints.py
23 | ```
24 | 
25 | Depending on your available GPU memory, you may wish to edit the example to change parameters such as the batch size, or which HuggingFace model to use. The `hard_constraints.py` example has been run successfully on an NVIDIA L4 GPU (with 24 GB of VRAM) on Google Cloud.
26 | 
27 | ## Your First Model
28 | 
29 | Let's write a LLaMPPL model to generate according to the hard constraint that completions do not use the lowercase letter `e`.
30 | 
31 | To do so, we subclass the [`Model`](llamppl.modeling.Model) class:
32 | 
33 | ```python
34 | # examples/no_e.py
35 | 
36 | from llamppl import Model, LMContext, CachedCausalLM
37 | 
38 | # A LLaMPPL model subclasses the Model class
39 | class MyModel(Model):
40 | 
41 |     # The __init__ method is used to process arguments
42 |     # and initialize instance variables.
43 |     def __init__(self, lm, prompt, forbidden_letter):
44 |         super().__init__()
45 | 
46 |         # A stateful context object for the LLM, initialized with the prompt
47 |         self.context = LMContext(lm, prompt)
48 |         self.eos_token = lm.tokenizer.eos_token_id
49 | 
50 |         # The forbidden letter
51 |         self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
52 |                                     if forbidden_letter in v)
53 | 
54 |     # The step method is used to perform a single 'step' of generation.
55 |     # This might be a single token, a single phrase, or any other division.
56 |     # Here, we generate one token at a time.
57 |     async def step(self):
58 |         # Condition on the next token *not* being a forbidden token.
59 |         await self.observe(self.context.mask_dist(self.forbidden_tokens), False)
60 | 
61 |         # Sample the next token from the LLM -- automatically extends `self.context`.
62 |         token = await self.sample(self.context.next_token())
63 | 
64 |         # Check for EOS or end of sentence
65 |         if token.token_id == self.eos_token or str(token) in ['.', '!', '?']:
66 |             # Finish generation
67 |             self.finish()
68 | 
69 |     # To improve performance, a hint that `self.forbidden_tokens` is immutable
70 |     def immutable_properties(self):
71 |         return set(['forbidden_tokens'])
72 | ```
73 | 
74 | To run the model, we use an inference method, like `smc_steer`:
75 | 
76 | ```python
77 | import asyncio
78 | from llamppl import smc_steer
79 | 
80 | # Initialize the HuggingFace model
81 | lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
82 | 
83 | # Create a model instance
84 | model = MyModel(lm, "The weather today is expected to be", "e")
85 | 
86 | # Run inference
87 | particles = asyncio.run(smc_steer(model, 5, 3)) # number of particles N, and beam factor K
88 | ```
89 | 
90 | Each returned particle is an instance of the `MyModel` class that has been `step`-ped to completion.
91 | The generated strings can be printed along with the particle weights: 92 | 93 | ```python 94 | for particle in particles: 95 | print(f"{particle.context.s} (weight: {particle.weight})") 96 | ``` 97 | 98 | 99 | ## Learning more 100 | 101 | For more intuition on language model probabilistic programming, see [our paper](https://arxiv.org/abs/2306.03081), or the rest of this documentation. 102 | -------------------------------------------------------------------------------- /docs/immutability.md: -------------------------------------------------------------------------------- 1 | # Immutability 2 | 3 | When a particle is promising, the sequential Monte Carlo algorithm may _clone_ it, by calling `copy.deepcopy`. 4 | 5 | Depending on your model, this may be more or less expensive. 6 | 7 | To make it faster, override the `immutable_properties(self)` method of your Model class, to return a `set[str]` of property names that are guaranteed not to change during `step`. For all properties in this set, LLaMPPL will use shared memory across particles, and avoid copying when cloning particles. 8 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Home 2 | 3 | [LLaMPPL](https://github.com/genlm/llamppl) is a research prototype for language model probabilistic programming: specifying language generation tasks by writing probabilistic programs that combine calls to LLMs, symbolic program logic, and probabilistic conditioning. To solve these tasks, LLaMPPL uses a specialized sequential Monte Carlo inference algorithm. 4 | 5 | This technique, SMC steering, is described in our workshop abstract, [Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs](https://arxiv.org/abs/2306.03081). 6 | -------------------------------------------------------------------------------- /docs/performance.md: -------------------------------------------------------------------------------- 1 | # Improving performance of LLaMPPL models 2 | 3 | If your LLaMPPL model is running slowly, consider exploiting the following features to improve performance: 4 | 5 | - [Auto-Batching](batching.md) — to run multiple particles concurrently, with batched LLM calls 6 | - [Caching](caching.md) - to cache key and value vectors for long prompts 7 | - [Immutability hinting](immutability.md) - to significantly speed up the bookkeeping performed by SMC inference 8 | -------------------------------------------------------------------------------- /docs/transformers.md: -------------------------------------------------------------------------------- 1 | # Working with Transformers 2 | 3 | ## Load your Transformer as a `CachedCausalLM` 4 | 5 | The easiest way to load a Transformer model is to use the [`CachedCausalLM.from_pretrained`][llamppl.llms.CachedCausalLM.from_pretrained] static method, which accepts as input a HuggingFace model identifier. This loads the model's weights into memory, and also loads the appropriate tokenizer. Note that if the model in question requires HuggingFace authorization (e.g., Meta's Llama 2 models), you will need to login via the [`huggingface-cli` command line tool](https://huggingface.co/docs/huggingface_hub/en/guides/cli). 6 | 7 | ## Use the LLM within your model via the `Transformer` distribution 8 | 9 | Within a model, you can `sample` or `observe` from the [`Transformer`][llamppl.distributions.transformer.Transformer] distribution. 
It accepts as arguments a [`CachedCausalLM`][llamppl.llms.CachedCausalLM] instance, as well as a list of integer token ids specifying the context. It returns a distribution over next tokens. The [`Transformer`][llamppl.distributions.transformer.Transformer] distribution is stateless, and so your model will need to manually extend the context with newly sampled tokens.
10 | 
11 | ## Use the LLM within your model via the `LMContext` class
12 | 
13 | Alternatively, you can initialize an [`LMContext`][llamppl.distributions.lmcontext.LMContext] object with a [`CachedCausalLM`][llamppl.llms.CachedCausalLM] instance and a string-valued prompt. It maintains a growing context as state, and exposes a [`next_token`][llamppl.distributions.lmcontext.LMContext.next_token] distribution that, when sampled, observed, or intervened, grows the context. It also supports a form of 'sub-token' generation, via the [`mask_dist`][llamppl.distributions.lmcontext.LMContext.mask_dist] distribution.
14 | 
15 | ## Create custom token distributions with `TokenCategorical`
16 | 
17 | You may also create a custom distribution over the vocabulary of a language model using the [`TokenCategorical`][llamppl.distributions.tokencategorical.TokenCategorical] distribution. It is parameterized by a [`CachedCausalLM`][llamppl.llms.CachedCausalLM] instance, and an array of logits equal in length to the language model's vocabulary size.
18 | This distribution is particularly useful as a proposal distribution; for example, a model might `sample` with `dist` set
19 | to the LM's next token distribution, but with `proposal` set to a modified distribution that uses a heuristic to upweight
20 | 'good' tokens and downweight 'bad' ones.
21 | 
-------------------------------------------------------------------------------- /docs/visualization.md: --------------------------------------------------------------------------------
1 | # Visualization
2 | 
3 | We provide a Web interface for visualizing the execution of a sequential Monte Carlo algorithm,
4 | based on contributions from Maddy Bowers and Jacob Hoover.
5 | 
6 | First, update your model to support visualization by implementing the [`string_for_serialization`](llamppl.modeling.Model.string_for_serialization) method.
7 | Return a string that summarizes the particle's current state.
8 | 
9 | To run the interface, change to the `html` directory and run `python -m http.server`. This will start serving
10 | the files in the `html` directory at localhost:8000. (If you are SSH-ing onto a remote machine, you may need
11 | port forwarding. Visual Studio Code automatically handles this for some ports, including 8000.)
12 | Then, when calling [`smc_standard`](llamppl.inference.smc_standard), set `visualization_dir`
13 | to the path to the `html` directory. A JSON record of the run will automatically be saved
14 | to that directory, and a URL will be printed to the console (`http://localhost:8000/smc.html?path=$json_file`).
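
For concreteness, here is a minimal sketch that ties these steps together. It reuses `ConstraintModel` from `examples/hard_constraints.py` (which already implements `string_for_serialization`) and mirrors that example's positional arguments to `smc_standard`; it assumes you run it from the repository root while the `html` directory is being served as described above.

```python
import asyncio

from examples.hard_constraints import ConstraintModel, prompt
from llamppl import CachedCausalLM, smc_standard

# Load a language model and build a model whose string_for_serialization()
# summarizes each particle's state for the visualization record.
lm = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = ConstraintModel(lm, prompt, max_tokens=50)

# The last two arguments follow examples/hard_constraints.py: the directory
# served by `python -m http.server`, and the JSON record written inside it.
particles = asyncio.run(smc_standard(model, 20, 0.5, "html", "results/output.json"))

# The run can then be viewed at:
#   http://localhost:8000/smc.html?path=results/output.json
```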
15 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genlm/llamppl/eebf893f24a5cc4386f41e72e5a524e9d0af82f6/examples/__init__.py -------------------------------------------------------------------------------- /examples/grammar_constraint.py: -------------------------------------------------------------------------------- 1 | """SMC Steering with Grammar Constraints 2 | 3 | Author: Gabriel Grand (grandg@mit.edu) 4 | 5 | This example illustrates grammar-constrained inference with SMC Steering. 6 | `GrammarConstrainedSMC` takes as input a grammar in Lark format. 7 | We use the Synchromesh (Poesia et al., 2022) to align the grammar with the 8 | language model vocabulary. 9 | 10 | Requires synchromesh (github.com/kanishkg/synchromesh) 11 | """ 12 | 13 | import asyncio 14 | 15 | from synchromesh.completion_engine import LarkCompletionEngine 16 | from synchromesh.synchromesh import StreamingCSD 17 | 18 | from llamppl.distributions import LMContext 19 | from llamppl.inference import smc_standard 20 | from llamppl.llms import CachedCausalLM 21 | from llamppl.modeling import Model 22 | 23 | 24 | class GrammarConstrainedSMC(Model): 25 | def __init__( 26 | self, 27 | lm: CachedCausalLM, 28 | grammar: str, 29 | start_rule: str, 30 | prompt: str = None, 31 | allow_ws: bool = False, 32 | max_tokens: int = 32, 33 | verbose: bool = False, 34 | ): 35 | super().__init__() 36 | self.lm = lm 37 | self.grammar = grammar 38 | self.context = LMContext(lm, prompt) 39 | self.vocab = self.lm.str_vocab 40 | self.eos_token_id = self.lm.tokenizer.eos_token_id 41 | 42 | self.comp_engine = LarkCompletionEngine( 43 | grammar, start_token=start_rule, allow_ws=allow_ws 44 | ) 45 | self.csd = StreamingCSD( 46 | completion_engine=self.comp_engine, 47 | lm_vocabulary=self.vocab, 48 | enforce_token_maximality=False, 49 | ) 50 | 51 | self.max_tokens = max_tokens 52 | self.n_tokens = 0 53 | 54 | self.verbose = verbose 55 | 56 | async def step(self): 57 | # Get valid tokens for next step 58 | valid_token_ids = self.csd.get_valid_tokens() 59 | 60 | # If generation is a complete derivation, allow the end-of-string token 61 | if self.csd.is_complete(): 62 | valid_token_ids += [self.eos_token_id] 63 | 64 | # If no valid next tokens, reject and terminate 65 | if len(valid_token_ids) == 0: 66 | self.condition(False) 67 | return 68 | 69 | # Sample a token from the valid tokens 70 | await self.observe(self.context.mask_dist(set(valid_token_ids)), True) 71 | token = await self.sample(self.context.next_token()) 72 | 73 | # If the token is the end-of-string token, accept and terminate 74 | if token.token_id == self.eos_token_id: 75 | self.finish() 76 | return 77 | 78 | # Feed the token to StreamingCSD 79 | self.csd.feed_prediction(token.token_id) 80 | self.n_tokens += 1 81 | 82 | if self.verbose: 83 | print(str(self.context)) 84 | 85 | # Max tokens reached 86 | if self.n_tokens >= self.max_tokens: 87 | self.condition(False) 88 | self.finish() 89 | 90 | def immutable_properties(self): 91 | return set( 92 | [ 93 | "grammar", 94 | "max_tokens", 95 | "verbose", 96 | ] 97 | ) 98 | 99 | 100 | EXAMPLE_PROMPT = """Paraphrase the following sentences 101 | Human:who teaches CSE101? 102 | Bot:instructor of CSE101 103 | Human:how many students can enroll in PSY456? 104 | Bot:capacity of PSY456 105 | Human:at what school is BIO433 taught? 
106 | Bot:""" 107 | 108 | EXAMPLE_GRAMMAR = r""" 109 | ?start: " "? function " of " dept code 110 | function: "instructor" | "students" | "capacity" | "department" | "school" | "college" 111 | dept: /[A-Z]{3}/ 112 | code: /[0-9]{3}/ 113 | """ 114 | 115 | 116 | async def run_generation( 117 | model: str, 118 | grammar: str, 119 | start_rule: str, 120 | prompt: str = None, 121 | allow_ws: bool = False, 122 | n_particles: int = 5, 123 | max_tokens: int = 32, 124 | verbose: bool = False, 125 | ): 126 | LLM = CachedCausalLM.from_pretrained(args.model) 127 | if LLM.backend == "hf": 128 | LLM.batch_size = args.batch_size 129 | model = GrammarConstrainedSMC( 130 | lm=LLM, 131 | grammar=grammar, 132 | start_rule=start_rule, 133 | prompt=prompt, 134 | max_tokens=max_tokens, 135 | allow_ws=allow_ws, 136 | verbose=verbose, 137 | ) 138 | particles = await smc_standard(model, n_particles=n_particles) 139 | particles_sorted = sorted(particles, key=lambda p: p.weight, reverse=True) 140 | print([(p.weight, str(p.context)) for p in particles_sorted]) 141 | 142 | 143 | if __name__ == "__main__": 144 | import argparse 145 | 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument( 148 | "--model", 149 | type=str, 150 | default="codellama/CodeLlama-7b-hf", 151 | help="Name of the HuggingFace model to use", 152 | ) 153 | parser.add_argument( 154 | "--grammar", 155 | type=str, 156 | default=None, 157 | help="Path to the grammar file", 158 | ) 159 | parser.add_argument( 160 | "--start-rule", 161 | type=str, 162 | default="start", 163 | help="Name of the start rule in the grammar", 164 | ) 165 | parser.add_argument( 166 | "--prompt", 167 | type=str, 168 | default=None, 169 | help="Prompt to start generation from", 170 | ) 171 | parser.add_argument( 172 | "--n-particles", 173 | type=int, 174 | default=5, 175 | help="Number of particles to use in SMC", 176 | ) 177 | parser.add_argument( 178 | "--max-tokens", 179 | type=int, 180 | default=32, 181 | help="Maximum number of tokens to generate", 182 | ) 183 | parser.add_argument( 184 | "--allow-ws", 185 | action="store_true", 186 | help="Allow whitespace", 187 | ) 188 | parser.add_argument( 189 | "--verbose", 190 | action="store_true", 191 | help="Print intermediate generations", 192 | ) 193 | args = parser.parse_args() 194 | 195 | if args.grammar is not None: 196 | # Load the grammar 197 | with open(args.grammar, "r") as f: 198 | grammar = f.read() 199 | else: 200 | grammar = EXAMPLE_GRAMMAR 201 | 202 | prompt = args.prompt or EXAMPLE_PROMPT 203 | 204 | asyncio.run( 205 | run_generation( 206 | model=args.model, 207 | grammar=grammar, 208 | start_rule=args.start_rule, 209 | prompt=prompt, 210 | n_particles=args.n_particles, 211 | max_tokens=args.max_tokens, 212 | allow_ws=args.allow_ws, 213 | verbose=args.verbose, 214 | ) 215 | ) 216 | -------------------------------------------------------------------------------- /examples/haiku.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import nltk 4 | 5 | from llamppl import CachedCausalLM 6 | from llamppl import LMContext 7 | from llamppl import Model 8 | from llamppl import sample_word 9 | from llamppl import smc_standard 10 | 11 | # download the CMU pronunciation dictionary (if we haven't already) 12 | nltk.download("cmudict") 13 | 14 | # Load the CMU pronunciation dictionary and use it for syllable counting 15 | from nltk.corpus import cmudict 16 | 17 | CMUDICT = cmudict.dict() 18 | 19 | 20 | def count_syllables(word, unknown_word_syllables=100): 21 | 
# Use the dictionary to get the list of possible phonetic representations for the word 22 | phonetic_transcriptions = CMUDICT.get(word.strip().lower(), []) 23 | 24 | # Count the number of syllables based on the number of phonetic transcriptions 25 | syllable_count = min( 26 | [ 27 | len([ph for ph in transcription if ph[-1].isdigit()]) 28 | for transcription in phonetic_transcriptions 29 | ], 30 | default=unknown_word_syllables, 31 | ) 32 | 33 | return syllable_count 34 | 35 | 36 | # Example poems for the prompt. 37 | # Authors: 38 | # - Amy Lowell 39 | # - Sonia Sanchez 40 | # - Katsushika Hokusai 41 | # - Matsuo Basho 42 | # Note that not all of these follow the syllabic constraints of a Haiku; the goal is 43 | # to encode a certain 'poetic style' but to leave the syllabic constraints to be enforced 44 | # by the probabilistic program (enabling generalization to other syllabic constraints). 45 | EXAMPLE_POEMS = """Example poems. Note how they tend to end on a somewhat surprising or otherwise satisfying note, and are not repetitive at the end. 46 | 47 | 1. "Portrait" 48 | Sweet smell of wet flowers 49 | Over an evening garden. 50 | Your portrait, perhaps? 51 | 52 | 2. "River of Love" 53 | love between us is 54 | speech and breath. loving you is 55 | a long river running. 56 | 57 | 3. "Practice" 58 | I write, erase, rewrite 59 | Erase again, and then 60 | A poppy blooms. 61 | 62 | 4. "Caterpillar" 63 | A caterpillar, 64 | this deep in fall, 65 | still not a butterfly.""" 66 | 67 | 68 | # LLaMPPL model 69 | class Haiku(Model): 70 | def __init__(self, LLM, prompt, syllable_pattern=[5, 7, 5]): 71 | super().__init__() 72 | self.context = LMContext(LLM, prompt) 73 | self.syllable_pattern = syllable_pattern 74 | self.previous_string = str(self.context) 75 | self.newline_token = LLM.str_vocab.index("\n") 76 | self.eos_token = LLM.tokenizer.eos_token_id 77 | 78 | async def step(self): 79 | self.previous_string = str(self.context) 80 | 81 | # Get the number of syllables required in the next line 82 | syllables_remaining = self.syllable_pattern.pop(0) 83 | 84 | # Loop to sample words until this line is over 85 | while syllables_remaining > 0: 86 | # Sample a word 87 | word, punctuation = await self.call(sample_word(self.context)) 88 | 89 | # Subtract syllables from the remaining count 90 | syllables_remaining -= count_syllables(word) 91 | 92 | # Reject if we overshot 93 | self.condition(syllables_remaining == 0) 94 | 95 | # If there are no more lines, finish 96 | if not self.syllable_pattern: 97 | await self.observe(self.context.next_token(), self.eos_token) 98 | self.finish() 99 | return 100 | 101 | # Otherwise, observe a line break 102 | await self.observe(self.context.next_token(), self.newline_token) 103 | 104 | # Print current result 105 | print(str(self.context)) 106 | 107 | def string_for_serialization(self): 108 | # Replace newlines with slashes in str(self.context) 109 | s = ( 110 | self.previous_string 111 | + "<<<" 112 | + str(self.context)[len(self.previous_string) :] 113 | + ">>>" 114 | ) 115 | return s.replace("\n", "/") 116 | 117 | 118 | async def run_example( 119 | LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=20, ess_threshold=0.5 120 | ): 121 | # Construct prompt 122 | prompt = f"""{EXAMPLE_POEMS} 123 | 124 | 5. 
"{poem_title}" 125 | """ 126 | 127 | # Cache the key value vectors for the prompt 128 | LLM.cache_kv(LLM.tokenizer.encode(prompt)) 129 | 130 | # Initialize the Model 131 | haiku_model = Haiku(LLM, prompt, syllable_pattern) 132 | 133 | # Run inference 134 | particles = await smc_standard( 135 | haiku_model, n_particles, ess_threshold, "html", "results/haiku.json" 136 | ) 137 | 138 | return particles 139 | 140 | 141 | def main(): 142 | # Load the language model. 143 | # Mistral is an open model; to use a model with restricted access, like LLaMA 3, 144 | # authenticate using the Huggingface CLI. 145 | LLM = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B") 146 | # LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") 147 | 148 | # Set batch size if using HuggingFace backend 149 | if LLM.backend == "hf": 150 | LLM.batch_size = 40 151 | 152 | # Get poem title from user 153 | poem_title = input("Enter a title for your Haiku: ") 154 | 155 | syllables_per_line = [5, 7, 5] # [5, 3, 5] for a Lune 156 | 157 | # Run the example 158 | particles = asyncio.run( 159 | run_example(LLM, poem_title, syllable_pattern=syllables_per_line) 160 | ) 161 | 162 | print("--------") 163 | for i, particle in enumerate(particles): 164 | print(f"\nPoem {i} (weight {particle.weight}):") 165 | print(f"{particle.context}") 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /examples/hard_constraints.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import string 3 | 4 | from llamppl import CachedCausalLM 5 | from llamppl import LMContext 6 | from llamppl import Model 7 | from llamppl import smc_standard 8 | 9 | 10 | def make_masks(LLM): 11 | return { 12 | i: set( 13 | j 14 | for (j, v) in enumerate(LLM.str_vocab) 15 | if j != LLM.tokenizer.eos_token_id 16 | and "\n" not in v 17 | and any(c.isalpha() or c in string.punctuation for c in v) 18 | and len(v.strip()) <= 5 19 | and (not v[0].isalpha() or i + len(v) <= 5) 20 | ) 21 | for i in range(6) 22 | } 23 | 24 | 25 | class ConstraintModel(Model): 26 | def __init__(self, LLM, prompt, max_tokens): 27 | super().__init__() 28 | self.context = LMContext(LLM, prompt) 29 | self.max_tokens = max_tokens 30 | self.masks = make_masks(LLM) 31 | self.eos_token_id = LLM.tokenizer.eos_token_id 32 | 33 | async def start(self): 34 | mask = self.active_constraint_mask() 35 | await self.observe(self.context.mask_dist(mask), True) 36 | 37 | async def step(self): 38 | # Generate proposed token. 39 | token = await self.sample(self.context.next_token()) 40 | 41 | # Reduce number of max tokens remaining 42 | self.max_tokens -= 1 43 | 44 | print(f"{self.context}") 45 | 46 | # Check if done 47 | if token == self.eos_token_id or self.max_tokens == 0: 48 | self.finish() 49 | return 50 | 51 | # Observe that next token follows the constraint. 52 | mask = self.active_constraint_mask() 53 | await self.observe(self.context.mask_dist(mask), True) 54 | 55 | def active_constraint_mask(self): 56 | string_so_far = str(self.context) 57 | words = string_so_far.split() 58 | last_word = words[-1] if len(words) > 0 else "" 59 | return self.masks[min(5, len(last_word))] 60 | 61 | def string_for_serialization(self): 62 | return f"{self.context}" 63 | 64 | def immutable_properties(self): 65 | return ["masks"] 66 | 67 | 68 | # From Politico.com 69 | prompt = """3 things to watch … 70 | 71 | 1. 
The return of the House means new energy for the GOP’s Biden impeachment push, and Democrats are starting their pushback early. Rep. Jamie Raskin (D-Md.) is out this morning with a 14-page rebuttal memo that seeks to paint the GOP campaign as a “complete and total bust” and an attempt at distracting from the “overwhelming evidence of [Trump’s] criminal and corrupt conduct during his term of office.” 72 | 73 | 2. The Senate is back this evening for a bed-check vote. With Minority Leader Mitch McConnell having successfully quieted (public) chatter about his health, expect senators to be quizzed anew about Sen. Tommy Tuberville’s (R-Ala.) Pentagon nominee blockade, especially with the Joint Chiefs chair, Gen. Mark Milley, just weeks away from retirement and the confirmation of his successor, Gen. C.Q. Brown, in limbo. 74 | 75 | 3.""" 76 | 77 | 78 | async def run_example(LLM, max_tokens=50, n_particles=20, ess_threshold=0.5): 79 | # Cache the key value vectors for the prompt. 80 | LLM.cache_kv(LLM.tokenizer.encode(prompt)) 81 | 82 | # Initialize the Model. 83 | constraint_model = ConstraintModel(LLM, prompt, max_tokens) 84 | 85 | # Run inference. 86 | particles = await smc_standard( 87 | constraint_model, n_particles, ess_threshold, "html", "results/output.json" 88 | ) 89 | for p in particles: 90 | print(f"{p.context}") 91 | 92 | return particles 93 | 94 | 95 | def main(): 96 | # Load the language model. 97 | # Mistral and Vicuna are open models; to use a model with restricted access, like LLaMA 3, 98 | # authenticate using the Huggingface CLI. 99 | LLM = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B") 100 | # LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5") 101 | # LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") 102 | 103 | # Set batch size if provided. This operation is only valid for the HuggingFace backend. 104 | if LLM.backend == "hf": 105 | LLM.batch_size = 40 106 | 107 | # Run the example. 108 | asyncio.run(run_example(LLM)) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /html/smc.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 107 | 108 | 109 | 110 | 111 | 112 |

[html/smc.html is an interactive "Sequential Monte Carlo - Visualization" web page; its HTML, CSS, and JavaScript markup is omitted from this dump.]
125 | 126 | 127 | 128 | 365 | 366 | 367 | 368 | 369 | -------------------------------------------------------------------------------- /llamppl/__init__.py: -------------------------------------------------------------------------------- 1 | """Probabilistic programming with Large Language Models.""" 2 | 3 | from .chunks import * 4 | from .distributions import * 5 | from .inference import * 6 | from .llms import * 7 | from .modeling import * 8 | from .util import * 9 | -------------------------------------------------------------------------------- /llamppl/chunks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import string 3 | 4 | from .modeling import submodel 5 | 6 | 7 | @submodel 8 | async def sample_word(self, context, max_tokens=5, allow_punctuation=True): 9 | """Sample a word from the `LMContext` object `context`.""" 10 | last_token = ( 11 | context.lm.str_vocab[context.tokens[-1]] if len(context.tokens) > 0 else "" 12 | ) 13 | last_character = last_token[-1] if len(last_token) > 0 else "" 14 | needs_space = last_character not in string.whitespace and last_character not in [ 15 | "-", 16 | "'", 17 | '"', 18 | ] 19 | if needs_space: 20 | starts_word_mask = context.lm.masks.STARTS_NEW_WORD 21 | else: 22 | starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD 23 | 24 | # Force model to start a new word 25 | await self.observe(context.mask_dist(starts_word_mask), True) 26 | 27 | word = "" 28 | num_tokens = 0 29 | while True: 30 | token = await self.sample(context.next_token()) 31 | word += context.lm.str_vocab[token.token_id] 32 | num_tokens += 1 33 | 34 | if num_tokens == max_tokens: 35 | await self.observe( 36 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False 37 | ) 38 | break 39 | 40 | if not ( 41 | await self.sample( 42 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD) 43 | ) 44 | ): 45 | break 46 | 47 | # Sample punctuation, if desired 48 | punctuation = "" 49 | if allow_punctuation and await self.sample( 50 | context.mask_dist(context.lm.masks.PUNCTUATION) 51 | ): 52 | punctuation_token = await self.sample(context.next_token()) 53 | punctuation = context.lm.str_vocab[punctuation_token.token_id] 54 | 55 | return word, punctuation 56 | 57 | 58 | @submodel 59 | async def sample_word_2( 60 | self, 61 | context, 62 | max_chars: int = None, 63 | allow_mid_punctuation: bool = True, 64 | allow_end_punctuation: bool = True, 65 | ): 66 | """Sample a word from the `LMContext` object `context`. 67 | 68 | Unlike sample_word() above, this method allows for character-level control over the length of the word. 69 | It also allows for control over the presence of punctuation in the middle and at the end of the word. 70 | 71 | Args: 72 | max_chars (int): Maximum number of characters in the word. If None, the model will sample a word of any length. 73 | allow_mid_punctuation (bool): If True, the model may sample punctuation in the middle of the word. 74 | allow_end_punctuation (bool): If True, the model may sample punctuation at the end of the word. 75 | 76 | Returns: 77 | Tuple[str, str]: The sampled word and punctuation 78 | """ 79 | # NOTE: Yields control back to the event loop. Necessary to allow timeouts to work correctly when this method is called in a loop. 
80 | await asyncio.sleep(0) 81 | 82 | # This approach sometimes breaks with max_chars = 1 83 | if max_chars is not None: 84 | assert max_chars > 1 85 | 86 | last_token = ( 87 | context.lm.str_vocab[context.tokens[-1]] if len(context.tokens) > 0 else "" 88 | ) 89 | last_character = last_token[-1] if len(last_token) > 0 else "" 90 | needs_space = last_character not in string.whitespace and last_character not in [ 91 | "-", 92 | "'", 93 | '"', 94 | ] 95 | if needs_space: 96 | starts_word_mask = context.lm.masks.STARTS_NEW_WORD 97 | else: 98 | starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD 99 | 100 | # Force model to start a new word 101 | await self.observe(context.mask_dist(starts_word_mask), True) 102 | 103 | word = "" 104 | while True: 105 | # Force model to sample a token with an appropriate number of characters 106 | if max_chars is not None: 107 | await self.observe( 108 | context.mask_dist( 109 | context.lm.masks.MAX_TOKEN_LENGTH[max_chars - len(word.strip())] 110 | ), 111 | True, 112 | ) 113 | 114 | token = await self.sample(context.next_token()) 115 | word += context.lm.str_vocab[token.token_id] 116 | 117 | # If we ran out of chars, break 118 | if max_chars is not None and len(word.strip()) >= max_chars: 119 | await self.observe( 120 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False 121 | ) 122 | break 123 | 124 | # If the model wants to end the word, break 125 | if not ( 126 | await self.sample( 127 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD) 128 | ) 129 | ): 130 | break 131 | 132 | # Sample punctuation, if desired 133 | mid_punctuation, end_punctuation = "", "" 134 | 135 | mask = set() 136 | if allow_mid_punctuation: 137 | mask = mask | context.lm.masks.MID_PUNCTUATION 138 | if allow_end_punctuation: 139 | mask = mask | context.lm.masks.END_PUNCTUATION 140 | 141 | if mask and await self.sample(context.mask_dist(mask)): 142 | token = await self.sample(context.next_token()) 143 | if token.token_id in context.lm.masks.MID_PUNCTUATION: 144 | mid_punctuation = context.lm.str_vocab[token.token_id] 145 | if token.token_id in context.lm.masks.END_PUNCTUATION: 146 | end_punctuation = context.lm.str_vocab[token.token_id] 147 | 148 | return word, mid_punctuation, end_punctuation 149 | -------------------------------------------------------------------------------- /llamppl/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | """Exposes distributions for use with `sample`, `observe`, and `intervene` methods in LLaMPPL models. 
2 | 3 | Currently supported distributions: 4 | 5 | * `Bernoulli(p: float) -> bool` 6 | * `Geometric(p: float) -> int` 7 | * `LogCategorical(logits: array) -> int` 8 | * `TokenCategorical(lm: llamppl.llms.CachedCausalLM, logits: array) -> llamppl.llms.Token` 9 | * `Transformer(lm: llamppl.llms.CachedCausalLM) -> llamppl.llms.Token` 10 | * `LMContext(lm: llamppl.llms.CachedCausalLM, prompt: list[int]).next_token() -> llamppl.llms.Token` 11 | * `LMContext(lm: llamppl.llms.CachedCausalLM, prompt: list[int]).mask_dist(mask: set[int]) -> bool` 12 | """ 13 | 14 | from .bernoulli import Bernoulli 15 | from .distribution import Distribution 16 | from .geometric import Geometric 17 | from .lmcontext import LMContext 18 | from .logcategorical import LogCategorical 19 | from .tokencategorical import TokenCategorical 20 | from .transformer import Transformer 21 | -------------------------------------------------------------------------------- /llamppl/distributions/bernoulli.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .distribution import Distribution 4 | 5 | 6 | class Bernoulli(Distribution): 7 | """A Bernoulli distribution.""" 8 | 9 | def __init__(self, p): 10 | """Create a Bernoulli distribution. 11 | 12 | Args: 13 | p: the probability-of-True for the Bernoulli distribution. 14 | """ 15 | self.p = p 16 | 17 | async def sample(self): 18 | b = np.random.rand() < self.p 19 | return (b, await self.log_prob(b)) 20 | 21 | async def log_prob(self, value): 22 | return np.log(self.p) if value else np.log1p(-self.p) 23 | 24 | async def argmax(self, idx): 25 | return (self.p > 0.5) if idx == 0 else (self.p < 0.5) 26 | -------------------------------------------------------------------------------- /llamppl/distributions/distribution.py: -------------------------------------------------------------------------------- 1 | class Distribution: 2 | """Abstract base class for a distribution.""" 3 | 4 | async def sample(self): 5 | """Generate a random sample from the distribution. 6 | 7 | Returns: 8 | x: a value randomly sampled from the distribution.""" 9 | raise NotImplementedError() 10 | 11 | async def log_prob(self, x): 12 | """Compute the log probability of a value under this distribution, 13 | or the log probability density if the distribution is continuous. 14 | 15 | Args: 16 | x: the point at which to evaluate the log probability. 17 | Returns: 18 | logprob (float): the log probability of `x`.""" 19 | raise NotImplementedError() 20 | 21 | async def argmax(self, n): 22 | """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution). 23 | 24 | Args: 25 | n (int): which value to return to, indexed from most probable (n=0) to least probable (n=|support|). 26 | Returns: 27 | x: the nth most probable outcome from this distribution.""" 28 | raise NotImplementedError() 29 | -------------------------------------------------------------------------------- /llamppl/distributions/geometric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .distribution import Distribution 4 | 5 | 6 | class Geometric(Distribution): 7 | """A Geometric distribution.""" 8 | 9 | def __init__(self, p): 10 | """Create a Geometric distribution. 11 | 12 | Args: 13 | p: the rate of the Geometric distribution. 
14 | """ 15 | self.p = p 16 | 17 | async def sample(self): 18 | n = np.random.geometric(self.p) 19 | return n, await self.log_prob(n) 20 | 21 | async def log_prob(self, value): 22 | return np.log(self.p) + np.log(1 - self.p) * (value - 1) 23 | 24 | async def argmax(self, idx): 25 | return idx - 1 # Most likely outcome is 0, then 1, etc. 26 | -------------------------------------------------------------------------------- /llamppl/distributions/lmcontext.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | from ..llms import Token 6 | from ..util import log_softmax 7 | from ..util import logsumexp 8 | from .distribution import Distribution 9 | 10 | 11 | class LMNextToken(Distribution): 12 | def __init__(self, ctx): 13 | self.ctx = ctx 14 | 15 | async def log_prob(self, x): 16 | if isinstance(x, Token): 17 | x = x.token_id 18 | 19 | lp = self.ctx.next_token_logprobs[x] 20 | self.ctx.tokens.append(x) 21 | updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) 22 | self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp) 23 | self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS 24 | 25 | return lp 26 | 27 | async def sample(self): 28 | probs = np.exp(self.ctx.next_token_logprobs) 29 | probs /= np.sum(probs) # Renormalize to fix floating point errors 30 | token_id = np.random.choice(len(probs), p=(probs)) 31 | self.ctx.tokens.append(token_id) 32 | logprob = self.ctx.next_token_logprobs[token_id] 33 | 34 | # Reset mask and update logprobs 35 | self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS 36 | updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) 37 | self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp) 38 | 39 | t = Token( 40 | self.ctx.lm, token_id, self.ctx.lm.tokenizer.convert_ids_to_tokens(token_id) 41 | ) 42 | return t, logprob 43 | 44 | 45 | class LMTokenMask(Distribution): 46 | def __init__(self, ctx, mask): 47 | self.ctx = ctx 48 | self.mask = mask 49 | 50 | async def sample(self): 51 | newly_bad_tokens = [i for i in self.ctx.model_mask if i not in self.mask] 52 | good_tokens = [i for i in self.ctx.model_mask if i in self.mask] 53 | logprob_no_mask = logsumexp(self.ctx.next_token_logprobs[newly_bad_tokens]) 54 | if logprob_no_mask > 0: 55 | logprob_yes_mask = float("-inf") 56 | else: 57 | # When logprob_no_mask is very close to 0.0, np.log1p can raise a "divide by zero" 58 | # warning before returning -inf. We suppress this warning, because returning -inf 59 | # is the desired behavior (the LLM places no mass on 'yes'). 
60 | with np.errstate(divide="ignore"): 61 | logprob_yes_mask = np.log1p(-np.exp(logprob_no_mask)) 62 | decide_no_mask = np.random.rand() < np.exp(logprob_no_mask) 63 | if decide_no_mask: 64 | self.ctx.model_mask = self.ctx.model_mask - self.mask 65 | self.ctx.next_token_logprobs[good_tokens] = float("-inf") 66 | self.ctx.next_token_logprobs -= logprob_no_mask 67 | return False, logprob_no_mask 68 | else: 69 | self.ctx.model_mask = self.ctx.model_mask.intersection(self.mask) 70 | self.ctx.next_token_logprobs[newly_bad_tokens] = float("-inf") 71 | self.ctx.next_token_logprobs -= logprob_yes_mask 72 | return True, logprob_yes_mask 73 | 74 | async def log_prob(self, v): 75 | good_tokens = ( 76 | self.ctx.model_mask.intersection(self.mask) 77 | if v 78 | else self.ctx.model_mask - self.mask 79 | ) 80 | if len(good_tokens) == 0: 81 | # If there are no good tokens, the log probability of v under the mask is -inf 82 | # However, since this method updates the model_mask as a side-effect, 83 | # this will put the context in an invalid state, so we instead raise an exception. 84 | raise NullMask( 85 | "Unable to compute log probability of mask that rules out all tokens." 86 | ) 87 | else: 88 | logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) 89 | 90 | bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] 91 | self.ctx.next_token_logprobs[bad_tokens] = float("-inf") 92 | self.ctx.next_token_logprobs -= logprob_good 93 | self.ctx.model_mask = good_tokens 94 | return logprob_good 95 | 96 | 97 | class NullMask(Exception): 98 | pass 99 | 100 | 101 | class LMContext: 102 | """Represents a generation-in-progress from a language model. 103 | 104 | The state tracks two pieces of information: 105 | 106 | * A sequence of tokens — the ever-growing context for the language model. 107 | * A *current mask* — a set of tokens that have not yet been ruled out as the next token. 108 | 109 | Storing a mask enables _sub-token_ generation: models can use `LMContext` to sample 110 | the next token in _stages_, first deciding, e.g., whether to use an upper-case or lower-case 111 | first letter, and only later deciding which upper-case or lower-case token to generate. 112 | 113 | The state of a `LMContext` can be advanced in two ways: 114 | 115 | 1. Sampling, observing, or intervening the `next_token()` distribution. This causes a token 116 | to be added to the growing sequence of tokens. Supports auto-batching. 117 | 2. Sampling, observing, or intervening the `mask_dist(mask)` distribution for a given mask (set of 118 | token ids). This changes the current mask. 119 | 120 | Attributes: 121 | lm (llamppl.llms.CachedCausalLM): the language model for which this is a context 122 | tokens (list[int]): the underlying sequence of tokens, including prompt, in this context 123 | next_token_logprobs (numpy.array): numpy array holding the log probabilities for the next token. Unlike the log probabilities reported by `CachedCausalLM.next_token_logprobs`, these probabilities are rescaled for this `LMContext`'s temperature parameter, and for any active masks. This vector is managed by the `LMContext` object internally; do not mutate. 124 | temp (float): temperature for next-token distribution (0 < temp < float('inf')) 125 | model_mask (set[int]): set of tokens that have not been ruled out as the next token. This mask is managed by the `LMContext` object internally; do not mutate.
126 | show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`. 127 | """ 128 | 129 | def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True): 130 | """Create a new `LMContext` with a given prompt and temperature. 131 | 132 | Args: 133 | lm (llamppl.llms.CachedCausalLM): the language model for which this is a context. 134 | prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`. 135 | temp (float): temperature for next-token distribution (0 < temp < float('inf')) 136 | 137 | Note: 138 | For async initialization of LMContext, use LMContext.create(). 139 | """ 140 | self._init_common(lm, prompt, temp, show_prompt, show_eos) 141 | self.next_token_logprobs = log_softmax( 142 | lm.next_token_logprobs_unbatched(self.tokens) / temp 143 | ) 144 | 145 | @classmethod 146 | async def create(cls, lm, prompt, temp=1.0, show_prompt=False, show_eos=True): 147 | """Asynchronously create a new `LMContext` with a given prompt and temperature.""" 148 | self = cls.__new__(cls) 149 | self._init_common(lm, prompt, temp, show_prompt, show_eos) 150 | logprobs = await lm.next_token_logprobs(self.tokens) 151 | self.next_token_logprobs = log_softmax(logprobs / temp) 152 | return self 153 | 154 | def _init_common(self, lm, prompt, temp, show_prompt, show_eos): 155 | """Initialize common attributes shared between __init__ and create.""" 156 | self.lm = lm 157 | self.tokens = lm.tokenizer.encode(prompt) 158 | self.temp = temp 159 | self.model_mask = lm.masks.ALL_TOKENS 160 | self.prompt_string_length = len(lm.tokenizer.decode(self.tokens)) 161 | self.prompt_token_count = len(self.tokens) 162 | self.show_prompt = show_prompt 163 | self.show_eos = show_eos 164 | 165 | def next_token(self): 166 | """Distribution over the next token. 167 | 168 | Sampling or observing from this distribution advances the state of this `LMContext` instance. 169 | """ 170 | return LMNextToken(self) 171 | 172 | def mask_dist(self, mask): 173 | """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs 174 | to the given mask. 175 | 176 | Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that 177 | the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from 178 | the given mask. 179 | 180 | Args: 181 | mask: a `set(int)` specifying which token ids are included within the mask.
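        Example:
            A sketch of two-stage (sub-token) generation, assuming `ctx` is an `LMContext` held by the enclosing model:

            ```python
            # Constrain the next token to start a new word, then sample it.
            await self.observe(ctx.mask_dist(ctx.lm.masks.STARTS_NEW_WORD), True)
            token = await self.sample(ctx.next_token())
            ```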
182 | """ 183 | return LMTokenMask(self, mask) 184 | 185 | @property 186 | def token_count(self): 187 | return len(self.tokens) - self.prompt_token_count 188 | 189 | def __str__(self): 190 | full_string = self.lm.tokenizer.decode(self.tokens) 191 | if not self.show_prompt: 192 | full_string = full_string[self.prompt_string_length :] 193 | if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token): 194 | full_string = full_string[: -len(self.lm.tokenizer.eos_token)] 195 | return full_string 196 | 197 | def __deepcopy__(self, memo): 198 | cpy = type(self).__new__(type(self)) 199 | 200 | for k, v in self.__dict__.items(): 201 | if k in set(["lm"]): 202 | setattr(cpy, k, v) 203 | else: 204 | setattr(cpy, k, copy.deepcopy(v, memo)) 205 | 206 | return cpy 207 | -------------------------------------------------------------------------------- /llamppl/distributions/logcategorical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..util import log_softmax 4 | from .distribution import Distribution 5 | 6 | 7 | class LogCategorical(Distribution): 8 | """A Geometric distribution.""" 9 | 10 | def __init__(self, logits): 11 | """Create a Categorical distribution from unnormalized log probabilities (logits). 12 | Given an array of logits, takes their `softmax` and samples an integer in `range(len(logits))` 13 | from the resulting categorical. 14 | 15 | Args: 16 | logits (np.array): a numpy array of unnormalized log probabilities. 17 | """ 18 | self.log_probs = log_softmax(logits) 19 | 20 | async def sample(self): 21 | n = np.random.choice(len(self.log_probs), p=np.exp(self.log_probs)) 22 | return n, await self.log_prob(n) 23 | 24 | async def log_prob(self, value): 25 | return self.log_probs[value] 26 | 27 | async def argmax(self, idx): 28 | return np.argsort(self.log_probs)[-idx] 29 | -------------------------------------------------------------------------------- /llamppl/distributions/tokencategorical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..llms import Token 5 | from ..util import log_softmax 6 | from .distribution import Distribution 7 | 8 | 9 | class TokenCategorical(Distribution): 10 | def __init__(self, lm, logits): 11 | """Create a Categorical distribution whose values are Tokens, not integers. 12 | Given a language model `lm` and an array of unnormalized log probabilities (of length `len(lm.vocab)`), 13 | uses softmax to normalize them and samples a Token from the resulting categorical. 14 | 15 | Args: 16 | lm (llamppl.llms.CachedCausalLM): the language model whose vocabulary is to be generated from. 17 | logits (np.array): a numpy array of unnormalized log probabilities. 18 | """ 19 | self.lm = lm 20 | self.log_probs = log_softmax(logits) 21 | if self.lm.tokenizer.vocab_size != len(logits): 22 | raise RuntimeError( 23 | f"TokenCategorical: vocab size is {self.lm.tokenizer.vocab_size} but provided {len(logits)} logits." 
24 | ) 25 | 26 | async def sample(self): 27 | n = np.random.choice(len(self.log_probs), p=(np.exp(self.log_probs))) 28 | return ( 29 | Token(self.lm, n, self.lm.tokenizer.convert_ids_to_tokens(n)), 30 | self.log_probs[n], 31 | ) 32 | 33 | async def log_prob(self, value): 34 | return self.log_probs[value.token_id] 35 | 36 | async def argmax(self, idx): 37 | tok = torch.argsort(self.log_probs)[-idx] 38 | return ( 39 | Token(self.lm, tok, self.lm.tokenizer.convert_ids_to_tokens(tok)), 40 | self.log_probs[tok], 41 | ) 42 | -------------------------------------------------------------------------------- /llamppl/distributions/transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..llms import Token 4 | from ..llms import TokenSequence 5 | from .distribution import Distribution 6 | 7 | 8 | # Transformer(lm, prompt) -- where prompt can either be a string or a list of Tokens. 9 | class Transformer(Distribution): 10 | def __init__(self, lm, prompt, temp=1.0): 11 | """Create a Categorical distribution whose values are Tokens, with probabilities given 12 | by a language model. Supports auto-batching. 13 | 14 | Args: 15 | lm (llamppl.llms.CachedCausalLM): the language model. 16 | prompt (str | llamppl.llms.TokenSequence): the sequence of tokens to use as the prompt. If a string, `lm.tokenizer` is used to encode it. 17 | temp (float): temperature at which to generate (0 < `temp` < `float('inf')`). 18 | """ 19 | self.lm = lm 20 | self.temp = temp 21 | 22 | # prompt will be a list of ints 23 | if isinstance(prompt, str): 24 | prompt = self.lm.tokenizer.encode(prompt) 25 | elif isinstance(prompt, TokenSequence): 26 | prompt = prompt.seq 27 | 28 | self.prompt = prompt 29 | 30 | async def log_prob(self, x): 31 | log_probs = await self.lm.next_token_logprobs(self.prompt) 32 | log_probs = log_probs / self.temp 33 | 34 | if isinstance(x, Token): 35 | x = x.token_id 36 | 37 | return log_probs[x] 38 | 39 | async def sample(self): 40 | log_probs = await self.lm.next_token_logprobs(self.prompt) 41 | log_probs = log_probs / self.temp 42 | probs = np.exp(log_probs) 43 | token_id = np.random.choice(len(probs), p=(probs)) 44 | logprob = log_probs[token_id] 45 | return ( 46 | Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), 47 | logprob, 48 | ) 49 | 50 | 51 | # def argmax(self, idx): 52 | # token_id = np.argsort(self.log_probs)[-idx] 53 | # return Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), log_probs[token_id] 54 | -------------------------------------------------------------------------------- /llamppl/inference/__init__.py: -------------------------------------------------------------------------------- 1 | """Provides inference methods for use with LLaMPPL models. 2 | 3 | This module currently provides the following inference methods: 4 | 5 | * `smc_standard(model, num_particles, ess_threshold=0.5)`: Standard SMC with multinomial resampling. 6 | 7 | * `smc_steer(model, num_beams, num_expansions)`: a without-replacement SMC algorithm that resembles beam search. 
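Both inference methods are coroutines; a typical invocation looks like the sketch below, where `MyModel` stands in for a user-defined `llamppl.modeling.Model` subclass (an assumption of this example):

```python
import asyncio

from llamppl.inference import smc_standard

# Run standard SMC with 10 particles and collect the completed particles.
particles = asyncio.run(smc_standard(MyModel(), n_particles=10))
```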
8 | """ 9 | 10 | from .smc_standard import smc_standard 11 | from .smc_steer import smc_steer 12 | -------------------------------------------------------------------------------- /llamppl/inference/smc_record.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class SMCRecord: 5 | def __init__(self, n): 6 | self.history = [] 7 | self.most_recent_weights = [0.0 for _ in range(n)] 8 | self.step_num = 1 9 | 10 | def prepare_string(self, s): 11 | # If the string doesn't have <<< and >>>, prepend <<<>>> at the front. 12 | if "<<<" not in s and ">>>" not in s: 13 | return f"<<<>>>{s}" 14 | return s 15 | 16 | def particle_dict(self, particles): 17 | return [ 18 | { 19 | "contents": self.prepare_string(p.string_for_serialization()), 20 | "logweight": ( 21 | "-Infinity" if p.weight == float("-inf") else str(float(p.weight)) 22 | ), 23 | "weight_incr": str( 24 | float(p.weight) - float(self.most_recent_weights[i]) 25 | ), 26 | } 27 | for (i, p) in enumerate(particles) 28 | ] 29 | 30 | def add_init(self, particles): 31 | self.history.append( 32 | { 33 | "step": self.step_num, 34 | "mode": "init", 35 | "particles": self.particle_dict(particles), 36 | } 37 | ) 38 | self.most_recent_weights = [p.weight for p in particles] 39 | 40 | def add_smc_step(self, particles): 41 | self.step_num += 1 42 | self.history.append( 43 | { 44 | "step": self.step_num, 45 | "mode": "smc_step", 46 | "particles": self.particle_dict(particles), 47 | } 48 | ) 49 | self.most_recent_weights = [p.weight for p in particles] 50 | 51 | def add_resample(self, ancestor_indices, particles): 52 | self.step_num += 1 53 | self.most_recent_weights = [ 54 | self.most_recent_weights[i] for i in ancestor_indices 55 | ] 56 | 57 | self.history.append( 58 | { 59 | "mode": "resample", 60 | "step": self.step_num, 61 | "ancestors": [int(a) for a in ancestor_indices], 62 | "particles": self.particle_dict(particles), 63 | } 64 | ) 65 | 66 | self.most_recent_weights = [p.weight for p in particles] 67 | 68 | def to_json(self): 69 | return json.dumps(self.history) 70 | -------------------------------------------------------------------------------- /llamppl/inference/smc_standard.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import copy 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | 7 | from ..util import logsumexp 8 | from .smc_record import SMCRecord 9 | 10 | 11 | async def smc_standard( 12 | model, n_particles, ess_threshold=0.5, visualization_dir=None, json_file=None 13 | ): 14 | """ 15 | Standard sequential Monte Carlo algorithm with multinomial resampling. 16 | 17 | Args: 18 | model (llamppl.modeling.Model): The model to perform inference on. 19 | n_particles (int): Number of particles to execute concurrently. 20 | ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`. 21 | visualization_dir (str): Path to the directory where the visualization server is running. 22 | json_file (str): Path to the JSON file to save the record of the inference, relative to `visualization_dir` if provided. 23 | 24 | Returns: 25 | particles (list[llamppl.modeling.Model]): The completed particles after inference. 
26 | """ 27 | particles = [copy.deepcopy(model) for _ in range(n_particles)] 28 | await asyncio.gather(*[p.start() for p in particles]) 29 | record = visualization_dir is not None or json_file is not None 30 | history = SMCRecord(n_particles) if record else None 31 | 32 | ancestor_indices = list(range(n_particles)) 33 | did_resample = False 34 | while any(map(lambda p: not p.done_stepping(), particles)): 35 | # Step each particle 36 | for p in particles: 37 | p.untwist() 38 | await asyncio.gather(*[p.step() for p in particles if not p.done_stepping()]) 39 | 40 | # Record history 41 | if record: 42 | if len(history.history) == 0: 43 | history.add_init(particles) 44 | elif did_resample: 45 | history.add_resample(ancestor_indices, particles) 46 | else: 47 | history.add_smc_step(particles) 48 | 49 | # Normalize weights 50 | W = np.array([p.weight for p in particles]) 51 | w_sum = logsumexp(W) 52 | normalized_weights = W - w_sum 53 | 54 | # Resample if necessary 55 | if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log( 56 | n_particles 57 | ): 58 | # Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now 59 | probs = np.exp(normalized_weights) 60 | ancestor_indices = [ 61 | np.random.choice(range(len(particles)), p=probs) 62 | for _ in range(n_particles) 63 | ] 64 | 65 | if record: 66 | # Sort the ancestor indices 67 | ancestor_indices.sort() 68 | 69 | particles = [copy.deepcopy(particles[i]) for i in ancestor_indices] 70 | avg_weight = w_sum - np.log(n_particles) 71 | for p in particles: 72 | p.weight = avg_weight 73 | 74 | did_resample = True 75 | else: 76 | did_resample = False 77 | 78 | if record: 79 | # Figure out path to save JSON. 80 | if visualization_dir is None: 81 | json_path = json_file 82 | else: 83 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 84 | json_relative = ( 85 | json_file 86 | if json_file is not None 87 | else f"{model.__class__.__name__}-{timestamp}.json" 88 | ) 89 | json_path = f"{visualization_dir}/{json_file}" 90 | 91 | # Save JSON 92 | with open(json_path, "w") as f: 93 | f.write(history.to_json()) 94 | 95 | # Web path is the part of the path after the html directory 96 | if visualization_dir is not None: 97 | print(f"Visualize at http://localhost:8000/smc.html?path={json_relative}") 98 | else: 99 | print(f"Saved record to {json_path}") 100 | 101 | return particles 102 | -------------------------------------------------------------------------------- /llamppl/inference/smc_steer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import copy 3 | 4 | import numpy as np 5 | 6 | from ..util import logsumexp 7 | from ..util import softmax 8 | 9 | 10 | def find_c(weights, N): 11 | # Sort the weights 12 | sorted_weights = np.sort(weights) 13 | # Find the smallest chi 14 | B_val = 0.0 15 | A_val = len(weights) 16 | for i in range(len(sorted_weights)): 17 | chi = sorted_weights[i] 18 | # Calculate A_val -- number of weights larger than chi 19 | A_val -= 1 20 | # Update B_val -- add the sum of weights smaller than or equal to chi 21 | B_val += chi 22 | if B_val / chi + A_val - N <= 1e-12: 23 | return (N - A_val) / B_val 24 | return N 25 | 26 | 27 | def resample_optimal(weights, N): 28 | c = find_c(weights, N) 29 | # Weights for which c * w >= 1 are deterministically resampled 30 | deterministic = np.where(c * weights >= 1)[0] 31 | # Weights for which c * w <= 1 are stochastically resampled 32 | stochastic = np.where(c * 
weights < 1)[0] 33 | # Stratified sampling to generate N-len(deterministic) indices 34 | # from the stochastic weights 35 | n_stochastic = len(stochastic) 36 | n_resample = N - len(deterministic) 37 | if n_resample == 0: 38 | return deterministic, np.array([], dtype=int), c 39 | K = np.sum(weights[stochastic]) / (n_resample) 40 | u = np.random.uniform(0, K) 41 | i = 0 42 | stoch_resampled = np.array([], dtype=int) 43 | while i < n_stochastic: 44 | u = u - weights[stochastic[i]] 45 | if u <= 0: 46 | # Add stochastic[i] to resampled indices 47 | stoch_resampled = np.append(stoch_resampled, stochastic[i]) 48 | # Update u 49 | u = u + K 50 | i = i + 1 51 | else: 52 | i += 1 53 | # Concatenate the deterministic and stochastic resampled indices 54 | # resampled = np.concatenate((deterministic, stoch_resampled)) 55 | # return resampled 56 | return deterministic, stoch_resampled, c 57 | 58 | 59 | async def smc_steer(model, n_particles, n_beam): 60 | """ 61 | Modified sequential Monte Carlo algorithm that uses without-replacement resampling, 62 | as described in [our workshop abstract](https://arxiv.org/abs/2306.03081). 63 | 64 | Args: 65 | model (llamppl.modeling.Model): The model to perform inference on. 66 | n_particles (int): Number of particles to maintain. 67 | n_beam (int): Number of continuations to consider for each particle. 68 | 69 | Returns: 70 | particles (list[llamppl.modeling.Model]): The completed particles after inference. 71 | """ 72 | # Create n_particles copies of the model 73 | particles = [copy.deepcopy(model) for _ in range(n_particles)] 74 | await asyncio.gather(*[p.start() for p in particles]) 75 | 76 | while any(map(lambda p: not p.done_stepping(), particles)): 77 | # Count the number of finished particles 78 | n_finished = sum(map(lambda p: p.done_stepping(), particles)) 79 | n_total = n_finished + (n_particles - n_finished) * n_beam 80 | 81 | # Create a super-list of particles that has n_beam copies of each 82 | super_particles = [] 83 | for p in particles: 84 | p.untwist() 85 | super_particles.append(p) 86 | if p.done_stepping(): 87 | p.weight += np.log(n_total) - np.log(n_particles) 88 | else: 89 | p.weight += np.log(n_total) - np.log(n_particles) - np.log(n_beam) 90 | super_particles.extend([copy.deepcopy(p) for _ in range(n_beam - 1)]) 91 | 92 | # Step each super-particle 93 | await asyncio.gather( 94 | *[p.step() for p in super_particles if not p.done_stepping()] 95 | ) 96 | 97 | # Use optimal resampling to resample 98 | W = np.array([p.weight for p in super_particles]) 99 | W_tot = logsumexp(W) 100 | W_normalized = softmax(W) 101 | det_indices, stoch_indices, c = resample_optimal(W_normalized, n_particles) 102 | particles = [ 103 | super_particles[i] for i in np.concatenate((det_indices, stoch_indices)) 104 | ] 105 | # For deterministic particles: w = w * N/N' 106 | for i in det_indices: 107 | super_particles[i].weight += np.log(n_particles) - np.log(n_total) 108 | # For stochastic particles: w = 1/c * total sum(stoch weights) / num_stoch = sum(stoch weights / total) / num_stoch * total * N/M 109 | for i in stoch_indices: 110 | super_particles[i].weight = ( 111 | W_tot - np.log(c) + np.log(n_particles) - np.log(n_total) 112 | ) 113 | 114 | # Return the particles 115 | return particles 116 | -------------------------------------------------------------------------------- /llamppl/llms.py: -------------------------------------------------------------------------------- 1 | """Utilities for working with language models.""" 2 | 3 | import string 4 | import warnings 5 | 
from collections import defaultdict 6 | 7 | import torch 8 | from genlm.backend.llm import AsyncTransformer 9 | from genlm.backend.llm import AsyncVirtualLM 10 | from genlm.backend.llm import MockAsyncLM 11 | 12 | VLLM_AVAILABLE = True 13 | try: 14 | import vllm 15 | except ImportError: 16 | VLLM_AVAILABLE = False 17 | 18 | warnings.filterwarnings("once", category=DeprecationWarning) 19 | warnings.filterwarnings("once", category=RuntimeWarning) 20 | 21 | 22 | class Masks: 23 | def __init__(self, lm): 24 | self.ALL_TOKENS = set(range(len(lm.str_vocab))) 25 | self.STARTS_NEW_WORD = set( 26 | i 27 | for (i, v) in enumerate(lm.str_vocab) 28 | if v[0] == " " 29 | and len(v) > 1 30 | and v[1] not in string.whitespace 31 | and v[1] not in string.punctuation 32 | ) 33 | self.CONTINUES_CURRENT_WORD = set( 34 | i 35 | for (i, v) in enumerate(lm.str_vocab) 36 | if all(c in "'" or c.isalpha() for c in v) 37 | ) 38 | self.MID_PUNCTUATION = set( 39 | i for (i, v) in enumerate(lm.str_vocab) if v in (",", ":", ";", "-", '"') 40 | ) 41 | self.END_PUNCTUATION = set( 42 | i for (i, v) in enumerate(lm.str_vocab) if v in (".", "!", "?") 43 | ) 44 | self.PUNCTUATION = self.MID_PUNCTUATION | self.END_PUNCTUATION 45 | self.CONTAINS_WHITESPACE = set( 46 | i 47 | for (i, v) in enumerate(lm.str_vocab) 48 | if any(c in string.whitespace for c in v) 49 | ) 50 | self.EOS = set([lm.tokenizer.eos_token_id]) 51 | 52 | self.MAX_TOKEN_LENGTH = self.precompute_token_length_masks(lm) 53 | 54 | def precompute_token_length_masks(self, lm): 55 | """Precompute masks for tokens of different lengths. 56 | 57 | Each mask is a set of token ids that are of the given length or shorter.""" 58 | max_token_length = max([len(t) for t in lm.str_vocab]) 59 | 60 | masks = defaultdict(lambda: self.ALL_TOKENS) 61 | masks[0] = set([lm.tokenizer.eos_token_id]) 62 | for token_length in range(1, max_token_length + 1): 63 | masks[token_length] = set( 64 | i 65 | for (i, v) in enumerate(lm.str_vocab) 66 | if len(v) <= token_length and i != lm.tokenizer.eos_token_id 67 | ) 68 | 69 | return masks 70 | 71 | 72 | class TokenSequence: 73 | """A sequence of tokens. 74 | 75 | Supports addition (via `+` or mutating `+=`) with: 76 | 77 | * other `TokenSequence` instances (concatenation) 78 | * individual tokens, represented as integers or `Token` instances 79 | * strings, which are tokenized by `lm.tokenizer` 80 | 81 | Attributes: 82 | lm (llamppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from. 83 | seq (list[llamppl.llms.Token]): the sequence of tokens.""" 84 | 85 | def __init__(self, lm, seq=None): 86 | """Create a `TokenSequence` from a language model and a sequence. 87 | 88 | Args: 89 | lm (llamppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from. 90 | seq (str | list[int]): the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing a bos token. 
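        Example:
            A sketch of the supported forms of concatenation (assuming `lm` is a `CachedCausalLM`):

            ```python
            seq = TokenSequence(lm, "The quick brown")
            seq += " fox"                      # strings are tokenized and appended
            seq += lm.tokenizer.eos_token_id   # plain token ids are appended as-is
            text = str(seq)                    # decodes the full sequence back to a string
            ```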
91 | """ 92 | self.lm = lm 93 | if seq is None: 94 | self.seq = [lm.tokenizer.bos_token_id] 95 | elif isinstance(seq, str): 96 | self.seq = self.lm.tokenizer.encode(seq) 97 | else: 98 | self.seq = seq 99 | 100 | def __str__(self): 101 | return self.lm.tokenizer.decode(self.seq) 102 | 103 | def __iadd__(self, other): 104 | if isinstance(other, Token): 105 | assert other.lm is self.lm 106 | self.seq.append(other.token_id) 107 | elif isinstance(other, TokenSequence): 108 | assert other.lm is self.lm 109 | self.seq.extend(other.seq) 110 | elif isinstance(other, str): 111 | self.seq.extend(self.lm.tokenizer.encode(other, add_special_tokens=False)) 112 | elif isinstance(other, int): 113 | self.seq.append(other) 114 | else: 115 | raise RuntimeError(f"Addition not supported on {type(other)}") 116 | return self 117 | 118 | def __radd__(self, other): 119 | if isinstance(other, Token): 120 | assert other.lm is self.lm 121 | return TokenSequence(self.lm, [other.token_id, *self.seq]) 122 | elif isinstance(other, TokenSequence): 123 | assert other.lm is self.lm 124 | return TokenSequence(self.lm, other.seq + self.seq) 125 | elif isinstance(other, str): 126 | return TokenSequence( 127 | self.lm, 128 | self.lm.tokenizer.encode(other, add_special_tokens=False) + self.seq, 129 | ) 130 | elif isinstance(other, int): 131 | return TokenSequence(self.lm, [other, *self.seq]) 132 | else: 133 | raise RuntimeError(f"Addition not supported on {type(other)}") 134 | 135 | def __add__(self, other): 136 | s = TokenSequence(self.lm, self.seq) 137 | s += other 138 | return s 139 | 140 | 141 | class Token: 142 | """Class representing a token. 143 | 144 | Attributes: 145 | lm (llamppl.llms.CachedCausalLM): the language model for which this is a Token. 146 | token_id (int): the integer token id (an index into the vocabulary). 147 | token_str (str): a string, which the token represents—equal to `lm.str_vocab[token_id]`. 148 | """ 149 | 150 | def __init__(self, lm, token_id, token_str): 151 | self.lm = lm 152 | self.token_id = token_id 153 | self.token_str = token_str 154 | 155 | # Adding tokens 156 | def __add__(self, other): 157 | s = TokenSequence(self.lm, [self.token_id]) 158 | s += other 159 | return s 160 | 161 | def __radd__(self, other): 162 | s = TokenSequence(self.lm, [self.token_id]) 163 | return other + s 164 | 165 | # Support checking for EOS 166 | def __eq__(self, other): 167 | if isinstance(other, Token): 168 | return self.lm is other.lm and self.token_id == other.token_id 169 | elif isinstance(other, int): 170 | return self.token_id == other 171 | else: 172 | return self.token_str == other 173 | 174 | def __int__(self): 175 | return self.token_id 176 | 177 | def __str__(self): 178 | return self.token_str 179 | 180 | def __repr__(self): 181 | return f"<{self.token_str}|{self.token_id}>" 182 | 183 | 184 | class CachedCausalLM: 185 | """Wrapper around a [`genlm.backend.llm.AsyncLM`](https://genlm.github.io/genlm-backend/reference/genlm/backend/llm/__init__/). 186 | 187 | Attributes: 188 | model (genlm_backend.llm.AsyncLM): The underlying language model (either `AsyncVirtualLM` or `AsyncTransformer`). 189 | str_vocab (list[str]): List mapping token IDs to their string representations. 190 | byte_vocab (list[bytes]): List mapping token IDs to their byte representations. 191 | masks (Masks): Token masks for filtering logits during generation. 192 | """ 193 | 194 | @classmethod 195 | def from_pretrained(cls, model_id, backend=None, **kwargs): 196 | """Create a CachedCausalLM from a HuggingFace model name. 
197 | 198 | This is a convenience method that instantiates the underlying `AsyncLM` from a HuggingFace model name. 199 | 200 | Args: 201 | model_id (str): Name or path of the HuggingFace pretrained model to load. 202 | backend (str, optional): `AsyncLM` backend to use: 203 | - 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage 204 | - 'hf' for an `AsyncTransformer`; ideal for CPU usage 205 | - 'mock' for a `MockAsyncLM`; ideal for testing. 206 | Defaults to 'vllm' if CUDA is available, otherwise 'hf'. 207 | **kwargs: Additional keyword arguments passed to the `AsyncLM` constructor. 208 | See [`AsyncLM` documentation](https://probcomp.github.io/genlm-backend/reference/genlm_backend/llm/__init__/). 209 | 210 | Returns: 211 | CachedCausalLM: The llamppl-compatible interface to the `AsyncLM` model. 212 | """ 213 | backend = backend or ( 214 | "vllm" if (torch.cuda.is_available() and VLLM_AVAILABLE) else "hf" 215 | ) 216 | 217 | if backend == "vllm": 218 | if not VLLM_AVAILABLE: 219 | raise ValueError( 220 | "vLLM backend requested but vLLM is not installed. " 221 | "Please install vLLM with `pip install vllm`." 222 | ) 223 | model_cls = AsyncVirtualLM 224 | elif backend == "hf": 225 | model_cls = AsyncTransformer 226 | elif backend == "mock": 227 | model_cls = MockAsyncLM 228 | else: 229 | raise ValueError( 230 | f"Unknown backend: {backend}. Must be one of ['vllm', 'hf', 'mock']" 231 | ) 232 | 233 | # Handle legacy auth_token parameter. The ability to pass in the auth_token should 234 | # be removed in a future version since it is not supported by the vllm backend. 235 | # Users should authenticate with the HuggingFace CLI. 236 | auth_token = kwargs.pop("auth_token", None) 237 | if auth_token: 238 | if backend == "vllm": 239 | raise ValueError( 240 | "Explicitly passing auth_token is not compatible with the vLLM AsyncLM backend. " 241 | "Authenticate using `huggingface-cli login` instead." 242 | ) 243 | 244 | if "hf_opts" not in kwargs: 245 | kwargs["hf_opts"] = {} 246 | kwargs["hf_opts"]["token"] = auth_token 247 | 248 | warnings.warn( 249 | "Passing auth_token directly is deprecated and will be removed in a future version. " 250 | "Please authenticate using `huggingface-cli login` instead.", 251 | DeprecationWarning, 252 | stacklevel=2, 253 | ) 254 | 255 | load_in_8bit = kwargs.pop("load_in_8bit", False) 256 | if load_in_8bit: 257 | if "bitsandbytes_opts" not in kwargs: 258 | kwargs["bitsandbytes_opts"] = {} 259 | kwargs["bitsandbytes_opts"]["load_in_8bit"] = True 260 | 261 | warnings.warn( 262 | "load_in_8bit is deprecated and will be removed in a future version. " 263 | "Please pass `bitsandbytes_opts` instead.", 264 | DeprecationWarning, 265 | stacklevel=2, 266 | ) 267 | 268 | model = model_cls.from_name(model_id, **kwargs) 269 | 270 | return cls(model) 271 | 272 | def __init__(self, model): 273 | """ 274 | Create a `CachedCausalLM` from an `AsyncLM`. 275 | 276 | Args: 277 | model (genlm_backend.llm.AsyncLM): an `AsyncLM` instance. 278 | """ 279 | if isinstance(model, AsyncVirtualLM): 280 | self.backend = "vllm" 281 | elif isinstance(model, AsyncTransformer): 282 | self.backend = "hf" 283 | elif isinstance(model, MockAsyncLM): 284 | self.backend = "mock" 285 | else: 286 | raise ValueError( 287 | f"Unknown model type: {type(model)}. 
Must be one of [AsyncVirtualLM, AsyncTransformer, MockAsyncLM]" 288 | ) 289 | 290 | self.model = model 291 | self.tokenizer = model.tokenizer 292 | self.str_vocab = model.str_vocab 293 | self.byte_vocab = model.byte_vocab 294 | self.masks = Masks(self) 295 | 296 | @property 297 | def vocab(self): 298 | """Legacy accessor for string vocabulary. Prefer using `.str_vocab` directly for access to the model's string vocabulary.""" 299 | warnings.warn( 300 | "Accessing .vocab directly is deprecated and will be removed in a future version. Use .str_vocab or .byte_vocab instead.", 301 | DeprecationWarning, 302 | stacklevel=2, 303 | ) 304 | return self.model.str_vocab 305 | 306 | def __deepcopy__(self, memo): 307 | return self 308 | 309 | async def next_token_logprobs(self, token_ids): 310 | """Request log probabilities of the next token. This version is asynchronous and supports auto-batching of concurrent requests; use with `await`. 311 | 312 | Args: 313 | token_ids (list[int]): a list of token ids, representing a prompt to the language model. 314 | 315 | Returns: 316 | logprobs (numpy.array): a numpy array of length `len(str_vocab)` (equivalently `len(byte_vocab)`) with the language model's log (normalized) probabilities for the next token following the prompt. 317 | """ 318 | logprobs = await self.model.next_token_logprobs(token_ids) 319 | return logprobs.float().cpu().numpy() 320 | 321 | def next_token_logprobs_unbatched(self, token_ids): 322 | """Request log probabilities of the next token. Not asynchronous, and does not support auto-batching. 323 | 324 | Args: 325 | token_ids (list[int]): a list of token ids, representing a prompt to the language model. 326 | 327 | Returns: 328 | logprobs (numpy.array): a numpy array of length `len(str_vocab)` (equivalently `len(byte_vocab)`) with the language model's log (normalized) probabilities for the next token following the prompt. 329 | """ 330 | return self.model.next_token_logprobs_sync(token_ids).float().cpu().numpy() 331 | 332 | def clear_cache(self): 333 | """Clear the cache of log probabilities and key/value pairs. 334 | 335 | For HuggingFace backend: Clears both logprob cache and KV cache. 336 | 337 | For vLLM backend: Only clears logprob cache (KV cache is managed internally by vLLM). 338 | """ 339 | self.model.clear_cache() 340 | 341 | def clear_kv_cache(self): 342 | """Clear any key and value vectors from the cache.""" 343 | if self.backend == "hf": 344 | self.model.clear_kv_cache() 345 | elif self.backend == "vllm": 346 | warnings.warn( 347 | "clear_kv_cache() is only supported for the HuggingFace backend. The KV cache for the vLLM backend is handled internally by vLLM. No operation performed.", 348 | RuntimeWarning, 349 | stacklevel=2, 350 | ) 351 | elif self.backend == "mock": 352 | pass 353 | else: 354 | raise RuntimeError( 355 | f"clear_kv_cache() is not implemented for backend type {type(self.model)}" 356 | ) 357 | 358 | def reset_async_queries(self): 359 | """Clear any pending language model queries from the queue.""" 360 | if self.backend == "hf": 361 | self.model.reset_async_queries() 362 | elif self.backend == "vllm": 363 | warnings.warn( 364 | "reset_async_queries() is only supported for the HuggingFace backend. 
No operation performed.", 365 | RuntimeWarning, 366 | stacklevel=2, 367 | ) 368 | elif self.backend == "mock": 369 | pass 370 | else: 371 | raise RuntimeError( 372 | f"reset_async_queries() is not implemented for backend type {type(self.model)}" 373 | ) 374 | 375 | def cache_kv(self, prompt_tokens): 376 | """Cache the key and value vectors for a prompt. 377 | 378 | Args: 379 | prompt_tokens (list[int]): token ids for the prompt to cache. 380 | """ 381 | if self.backend == "hf": 382 | self.model.cache_kv(prompt_tokens) 383 | elif self.backend == "vllm": 384 | warnings.warn( 385 | "cache_kv() is only supported for the HuggingFace backend. The KV cache for the vLLM backend is handled internally by vLLM. No operation performed.", 386 | RuntimeWarning, 387 | stacklevel=2, 388 | ) 389 | elif self.backend == "mock": 390 | pass 391 | else: 392 | raise RuntimeError( 393 | f"cache_kv() is not implemented for backend type {type(self.model)}" 394 | ) 395 | -------------------------------------------------------------------------------- /llamppl/modeling.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | class SubModel: 5 | def __init__(self): 6 | self.parent = None 7 | 8 | async def run_with_parent(self, parent): 9 | old_parent = self.parent 10 | self.parent = parent 11 | val = await self.forward() 12 | self.parent = old_parent 13 | return val 14 | 15 | async def forward(self): 16 | raise NotImplementedError( 17 | "SubModel.forward() must be implemented by subclasses" 18 | ) 19 | 20 | async def sample(self, dist, proposal=None): 21 | return await self.parent.sample(dist, proposal) 22 | 23 | async def observe(self, dist, x): 24 | return await self.parent.observe(dist, x) 25 | 26 | async def intervene(self, dist, x): 27 | return await self.parent.intervene(dist, x) 28 | 29 | def condition(self, b): 30 | return self.parent.condition(b) 31 | 32 | def score(self, score): 33 | return self.parent.score(score) 34 | 35 | def twist(self, amt): 36 | return self.parent.twist(amt) 37 | 38 | async def call(self, submodel): 39 | return await submodel.run_with_parent(self.parent) 40 | 41 | 42 | # For use as a decorator 43 | import functools 44 | 45 | 46 | def submodel(f): 47 | """Decorator to create a SubModel implementation from an async function. 48 | 49 | For example: 50 | 51 | ```python 52 | @submodel 53 | async def sample_two_tokens(self, context): 54 | token1 = await self.sample(context.next_token()) 55 | token2 = await self.sample(context.next_token()) 56 | return token1, token2 57 | ``` 58 | 59 | This SubModel can then be used from another model or submodel, using the syntax `await self.call(sample_two_tokens(context))`. 60 | """ 61 | 62 | @functools.wraps(f, updated=()) # unclear if this is the best way to do it 63 | class SubModelImpl(SubModel): 64 | def __init__(self, *args, **kwargs): 65 | super().__init__() 66 | self.args = args 67 | self.kwargs = kwargs 68 | 69 | async def forward(self): 70 | return await f(self, *self.args, **self.kwargs) 71 | 72 | return SubModelImpl 73 | 74 | 75 | class Model: 76 | """Base class for all LLaMPPL models. 77 | 78 | Your models should subclass this class. Minimally, you should provide an `__init__` method 79 | that calls `super().__init__()`, and a `step` method. 
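    Example:
        A minimal sketch of a model that extends an `LMContext` one token per step until the language model emits its end-of-sequence token (the constructor argument `lm` and the prompt below are illustrative assumptions, not part of the library):

        ```python
        from llamppl.distributions import LMContext
        from llamppl.modeling import Model

        class MyModel(Model):
            def __init__(self, lm):
                super().__init__()
                self.context = LMContext(lm, "Once upon a time")

            async def step(self):
                # Sample one token and stop once EOS is generated.
                token = await self.sample(self.context.next_token())
                if token.token_id == self.context.lm.tokenizer.eos_token_id:
                    self.finish()
        ```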
80 | """ 81 | 82 | def __init__(self): 83 | self.weight = 0.0 84 | self.finished = False 85 | self.mode = "sample" 86 | self.beam_idx = 0 87 | self.force_eos = False 88 | self.twist_amount = 0.0 89 | 90 | def reset(self): 91 | self.weight = 0.0 92 | self.finished = False 93 | self.mode = "sample" 94 | self.beam_idx = 0 95 | self.force_eos = False 96 | self.twist_amount = 0.0 97 | 98 | def immutable_properties(self): 99 | """Return a `set[str]` of properties that LLaMPPL may assume do not change during execution of `step`. 100 | This set is empty by default but can be overridden by subclasses to speed up inference. 101 | 102 | Returns: 103 | properties (set[str]): a set of immutable property names""" 104 | return set() 105 | 106 | def __deepcopy__(self, memo): 107 | cpy = type(self).__new__(type(self)) 108 | immutable = self.immutable_properties() 109 | 110 | for k, v in self.__dict__.items(): 111 | if k in immutable: 112 | setattr(cpy, k, v) 113 | else: 114 | setattr(cpy, k, copy.deepcopy(v, memo)) 115 | 116 | return cpy 117 | 118 | def twist(self, amt): 119 | """Multiply this particle's weight by `exp(amt)`, but divide it back out before the next `step`. 120 | 121 | Use this method to provide heuristic guidance about whether a particle is "on the right track" 122 | without changing the ultimate target distribution. 123 | 124 | Args: 125 | amt: the logarithm of the amount by which to (temporarily) multiply this particle's weight. 126 | """ 127 | self.twist_amount += amt 128 | self.score(amt) 129 | 130 | def untwist(self): 131 | self.score(-self.twist_amount) 132 | self.twist_amount = 0.0 133 | 134 | def finish(self): 135 | self.untwist() 136 | self.finished = True 137 | 138 | def done_stepping(self): 139 | return self.finished 140 | 141 | async def step(self): 142 | """Defines the computation performed in each step of the model. 143 | 144 | All subclasses should override this method.""" 145 | 146 | if not self.done_stepping(): 147 | raise NotImplementedError("Model.step() must be implemented by subclasses") 148 | 149 | def __str__(self): 150 | return "Particle" 151 | 152 | async def start(self): 153 | pass 154 | 155 | def score(self, score): 156 | """Multiply this particle's weight by `exp(score)`. 157 | 158 | The `score` method is a low-level way to change the target distribution. 159 | For many use cases, it is sufficient to use `sample`, `observe`, `condition`, 160 | and `twist`, all of which are implemented in terms of `score`. 161 | 162 | Args: 163 | score: logarithm of the amount by which the particle's weight should be multiplied. 164 | """ 165 | self.weight += score 166 | 167 | def condition(self, b): 168 | """Constrain a given Boolean expression to be `True`. 169 | 170 | If the condition is False, the particle's weight is set to zero and `self.finish()` 171 | is called, so that no further `step` calls are made. 172 | 173 | Args: 174 | b: the Boolean expression whose value is constrained to be True. 175 | """ 176 | if not b: 177 | self.score(float("-inf")) 178 | self.finish() 179 | 180 | async def intervene(self, dist, x): 181 | """Force the distribution to take on the value `x`, but do not _condition_ on this result. 182 | 183 | This is useful primarily with distributions that have side effects (e.g., modifying some state). 
184 | For example, a model with the code 185 | 186 | ```python 187 | token_1 = await self.sample(self.stateful_lm.next_token()) 188 | await self.observe(self.stateful_lm.next_token(), token_2) 189 | ``` 190 | 191 | encodes a posterior inference problem, to find `token_1` values that *likely preceded* `token_2`. By contrast, 192 | 193 | ```python 194 | token_1 = await self.sample(stateful_lm.next_token()) 195 | await self.intervene(self.stateful_lm.next_token(), token_2) 196 | ``` 197 | 198 | encodes a much easier task: freely generate `token_1` and then force-feed `token_2` as the following token. 199 | 200 | Args: 201 | dist (llamppl.distributions.distribution.Distribution): the distribution on which to intervene. 202 | x: the value to intervene with. 203 | """ 204 | await dist.log_prob(x) 205 | return x 206 | 207 | async def observe(self, dist, x): 208 | """Condition the model on the value `x` being sampled from the distribution `dist`. 209 | 210 | For discrete distributions `dist`, `await self.observe(dist, x)` specifies the same constraint as 211 | ``` 212 | val = await self.sample(dist) 213 | self.condition(val == x) 214 | ``` 215 | but can be much more efficient. 216 | 217 | Args: 218 | dist: a `Distribution` object from which to observe 219 | x: the value observed from `dist` 220 | """ 221 | p = await dist.log_prob(x) 222 | self.score(p) 223 | return x 224 | 225 | async def sample(self, dist, proposal=None): 226 | """Extend the model with a sample from a given `Distribution`, with support for autobatching. 227 | If specified, the Distribution `proposal` is used during inference to generate informed hypotheses. 228 | 229 | Args: 230 | dist: the `Distribution` object from which to sample 231 | proposal: if provided, inference algorithms will use this `Distribution` object to generate proposed samples, rather than `dist`. 232 | However, importance weights will be adjusted so that the target posterior is independent of the proposal. 233 | 234 | Returns: 235 | value: the value sampled from the distribution. 236 | """ 237 | # Special logic for beam search 238 | # if self.mode == "beam": 239 | # d = dist if proposal is None else proposal 240 | # x, w = d.argmax(self.beam_idx) 241 | # if proposal is not None: 242 | # self.score(dist.log_prob(x)) 243 | # else: 244 | # self.score(w) 245 | # return x 246 | 247 | if proposal is None: 248 | x, _ = await dist.sample() 249 | return x 250 | else: 251 | x, q = await proposal.sample() 252 | p = await dist.log_prob(x) 253 | self.score(p - q) 254 | return x 255 | 256 | async def call(self, submodel): 257 | return await submodel.run_with_parent(self) 258 | 259 | def string_for_serialization(self): 260 | """Return a string representation of the particle for serialization purposes. 261 | 262 | Returns: 263 | str: a string representation of the particle. 264 | """ 265 | return str(self) 266 | -------------------------------------------------------------------------------- /llamppl/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions""" 2 | 3 | import numpy as np 4 | 5 | 6 | def logsumexp(nums): 7 | m = np.max(nums) 8 | return np.log(np.sum(np.exp(nums - m))) + m 9 | 10 | 11 | def log_softmax(nums): 12 | """Compute log(softmax(nums)). 13 | 14 | Args: 15 | nums: a vector or numpy array of unnormalized log probabilities. 16 | 17 | Returns: 18 | np.array: an array of log (normalized) probabilities. 
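        The returned array satisfies `log_softmax(nums) == nums - logsumexp(nums)`; for instance, `log_softmax(np.array([0.0, 0.0]))` is `[-log(2), -log(2)]`, i.e. a uniform distribution over two outcomes in log space.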
19 | """ 20 | return nums - logsumexp(nums) 21 | 22 | 23 | def softmax(nums): 24 | return np.exp(log_softmax(nums)) 25 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: LLaMPPL Docs 2 | 3 | theme: 4 | name: "material" 5 | palette: 6 | scheme: slate 7 | 8 | markdown_extensions: 9 | - pymdownx.highlight: 10 | anchor_linenums: true 11 | line_spans: __span 12 | pygments_lang_class: true 13 | - pymdownx.inlinehilite 14 | - pymdownx.snippets 15 | - pymdownx.superfences 16 | 17 | plugins: 18 | - search 19 | - mkdocstrings 20 | - gen-files: 21 | scripts: 22 | - docs/gen_reference_page.py 23 | - literate-nav: 24 | nav_file: SUMMARY.md 25 | - section-index 26 | 27 | nav: 28 | - index.md 29 | - Getting Started: 30 | - getting_started.md 31 | - anatomy.md 32 | - transformers.md 33 | - Performance Engineering: 34 | - performance.md 35 | - batching.md 36 | - caching.md 37 | - immutability.md 38 | - Visualization: 39 | - visualization.md 40 | - Code Reference: reference/ 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "llamppl" 3 | dynamic = ["version"] 4 | description = "Probabilistic programming with Large Language Models." 5 | authors = [ 6 | {name = "Alex Lew", email = "alexlew@mit.edu"}, 7 | {name = "Gabriel Grand", email = "grandg@mit.edu"}, 8 | {name = "Ben LeBrun", email = "benlebrun1@gmail.com"}, 9 | ] 10 | license = {text = "MIT"} 11 | readme = "README.md" 12 | requires-python = ">=3.10" 13 | dependencies = [ 14 | "torch>=2.1.2", 15 | "numpy>=1.26.2", 16 | "scipy>=1.11.4", 17 | "protobuf>=5.27.2", 18 | "pre-commit>=3.7.1", 19 | "ipykernel>=6.29.5", 20 | "genlm-backend>=0.1.0a1", 21 | ] 22 | 23 | [project.optional-dependencies] 24 | vllm = ["vllm>=0.6.6"] 25 | dev = [ 26 | "pytest", 27 | "pytest-benchmark", 28 | "pytest-cov", 29 | "pre-commit>=3.6.0", 30 | "ruff>=0.9.9", 31 | "jupyterlab>=4.0.9", 32 | "ipywidgets>=8.1.1", 33 | "matplotlib>=3.9.1", 34 | "seaborn>=0.13.2", 35 | ] 36 | yelp = [ 37 | "yake>=0.4.8", 38 | "datasets>=2.20.0", 39 | ] 40 | collie = [ 41 | "collie-bench>=0.1.0", 42 | "nltk>=3.8.1", 43 | "dill>=0.3.8", 44 | "evaluate>=0.4.2", 45 | ] 46 | examples = ["nltk>=3.8.1"] 47 | 48 | [tool.setuptools.packages.find] 49 | include = ["llamppl*"] 50 | 51 | [build-system] 52 | requires = ["setuptools>=64.0", "setuptools-scm>=8"] 53 | build-backend = "setuptools.build_meta" 54 | 55 | [tool.setuptools_scm] 56 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | import torch 5 | 6 | from examples.haiku import run_example as run_haiku 7 | from examples.hard_constraints import run_example as run_hard_constraints 8 | from llamppl.llms import CachedCausalLM 9 | 10 | backends = [ 11 | "mock", 12 | "hf", 13 | pytest.param( 14 | "vllm", 15 | marks=pytest.mark.skipif( 16 | not torch.cuda.is_available(), reason="vLLM backend requires CUDA" 17 | ), 18 | ), 19 | ] 20 | 21 | 22 | @pytest.fixture 23 | def LLM(backend): 24 | # Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU 25 | kwargs = ( 26 | {"engine_opts": {"gpu_memory_utilization": 0.45}} if backend == "vllm" else {} 27 | ) 28 | return 
CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs) 29 | 30 | 31 | @pytest.mark.parametrize("backend", backends) 32 | def test_hard_constraints(LLM, n_particles=20, max_tokens=25): 33 | particles = asyncio.run( 34 | run_hard_constraints(LLM, max_tokens=max_tokens, n_particles=n_particles) 35 | ) 36 | assert len(particles) == n_particles 37 | 38 | 39 | @pytest.mark.parametrize("backend", backends) 40 | def test_haiku(LLM, n_particles=20): 41 | particles = asyncio.run( 42 | run_haiku(LLM, poem_title="The beauty of testing", n_particles=n_particles) 43 | ) 44 | assert len(particles) == n_particles 45 | -------------------------------------------------------------------------------- /tests/test_lmcontext.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from llamppl.distributions.lmcontext import LMContext 8 | from llamppl.llms import CachedCausalLM 9 | 10 | backends = [ 11 | "mock", 12 | "hf", 13 | pytest.param( 14 | "vllm", 15 | marks=pytest.mark.skipif( 16 | not torch.cuda.is_available(), reason="vLLM backend requires CUDA" 17 | ), 18 | ), 19 | ] 20 | 21 | 22 | @pytest.fixture 23 | def lm(backend): 24 | return CachedCausalLM.from_pretrained("gpt2", backend=backend) 25 | 26 | 27 | @pytest.mark.parametrize("backend", backends) 28 | def test_init(lm): 29 | prompt = "Hello, world!" 30 | lmcontext = LMContext(lm, prompt) 31 | assert lmcontext.tokens == lm.tokenizer.encode(prompt) 32 | logprobs = lm.next_token_logprobs_unbatched(lmcontext.tokens) 33 | np.testing.assert_allclose( 34 | lmcontext.next_token_logprobs, 35 | logprobs, 36 | rtol=1e-5, 37 | err_msg="Sync context __init__", 38 | ) 39 | 40 | async def async_context(): 41 | return LMContext(lm, prompt) 42 | 43 | lmcontext = asyncio.run(async_context()) 44 | np.testing.assert_allclose( 45 | lmcontext.next_token_logprobs, 46 | logprobs, 47 | rtol=1e-5, 48 | err_msg="Async context __init__", 49 | ) 50 | 51 | async def async_context_create(): 52 | return await LMContext.create(lm, prompt) 53 | 54 | lmcontext = asyncio.run(async_context_create()) 55 | np.testing.assert_allclose( 56 | lmcontext.next_token_logprobs, 57 | logprobs, 58 | rtol=1e-5, 59 | err_msg="Async context create", 60 | ) 61 | --------------------------------------------------------------------------------