├── examples ├── __init__.py ├── hard_constraints.py ├── haiku.py └── grammar_constraint.py ├── docs ├── anatomy.md ├── performance.md ├── immutability.md ├── index.md ├── visualization.md ├── batching.md ├── gen_reference_page.py ├── transformers.md ├── caching.md └── getting_started.md ├── codecov.yml ├── llamppl ├── __init__.py ├── inference │ ├── __init__.py │ ├── smc_record.py │ ├── smc_standard.py │ └── smc_steer.py ├── util.py ├── distributions │ ├── geometric.py │ ├── bernoulli.py │ ├── __init__.py │ ├── logcategorical.py │ ├── distribution.py │ ├── tokencategorical.py │ ├── transformer.py │ └── lmcontext.py ├── chunks.py ├── modeling.py └── llms.py ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── docs.yml │ ├── release.yml │ └── tests.yml ├── mkdocs.yml ├── pyproject.toml ├── tests ├── test_examples.py └── test_lmcontext.py ├── benchmark └── benchmark_backend.py ├── .gitignore ├── README.md ├── LICENSE └── html └── smc.html /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/anatomy.md: -------------------------------------------------------------------------------- 1 | # Anatomy of a LLaMPPL model 2 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto 6 | threshold: 3% 7 | -------------------------------------------------------------------------------- /llamppl/__init__.py: -------------------------------------------------------------------------------- 1 | """Probabilistic programming with Large Language Models.""" 2 | 3 | from .chunks import * 4 | from .distributions import * 5 | from .inference import * 6 | from .llms import * 7 | from .modeling import * 8 | from .util import * 9 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-yaml 6 | args: [--unsafe] 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.9.9 12 | hooks: 13 | - id: ruff-format 14 | types_or: [ python, pyi, jupyter ] 15 | -------------------------------------------------------------------------------- /docs/performance.md: -------------------------------------------------------------------------------- 1 | # Improving performance of LLaMPPL models 2 | 3 | If your LLaMPPL model is running slowly, consider exploiting the following features to improve performance: 4 | 5 | - [Auto-Batching](batching.md) — to run multiple particles concurrently, with batched LLM calls 6 | - [Caching](caching.md) - to cache key and value vectors for long prompts 7 | - [Immutability hinting](immutability.md) - to significantly speed up the bookkeeping performed by SMC inference 8 | -------------------------------------------------------------------------------- /llamppl/inference/__init__.py: -------------------------------------------------------------------------------- 1 | """Provides inference methods for use with LLaMPPL models. 
2 | 3 | This module currently provides the following inference methods: 4 | 5 | * `smc_standard(model, num_particles, ess_threshold=0.5)`: Standard SMC with multinomial resampling. 6 | 7 | * `smc_steer(model, num_beams, num_expansions)`: a without-replacement SMC algorithm that resembles beam search. 8 | """ 9 | 10 | from .smc_standard import smc_standard 11 | from .smc_steer import smc_steer 12 | -------------------------------------------------------------------------------- /docs/immutability.md: -------------------------------------------------------------------------------- 1 | # Immutability 2 | 3 | When a particle is promising, the sequential Monte Carlo algorithm may _clone_ it, by calling `copy.deepcopy`. 4 | 5 | Depending on your model, this may be more or less expensive. 6 | 7 | To make it faster, override the `immutable_properties(self)` method of your Model class, to return a `set[str]` of property names that are guaranteed not to change during `step`. For all properties in this set, LLaMPPL will use shared memory across particles, and avoid copying when cloning particles. 8 | -------------------------------------------------------------------------------- /llamppl/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions""" 2 | 3 | import numpy as np 4 | 5 | 6 | def logsumexp(nums): 7 | m = np.max(nums) 8 | return np.log(np.sum(np.exp(nums - m))) + m 9 | 10 | 11 | def log_softmax(nums): 12 | """Compute log(softmax(nums)). 13 | 14 | Args: 15 | nums: a vector or numpy array of unnormalized log probabilities. 16 | 17 | Returns: 18 | np.array: an array of log (normalized) probabilities. 19 | """ 20 | return nums - logsumexp(nums) 21 | 22 | 23 | def softmax(nums): 24 | return np.exp(log_softmax(nums)) 25 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Home 2 | 3 | [LLaMPPL](https://github.com/genlm/llamppl) is a research prototype for language model probabilistic programming: specifying language generation tasks by writing probabilistic programs that combine calls to LLMs, symbolic program logic, and probabilistic conditioning. To solve these tasks, LLaMPPL uses a specialized sequential Monte Carlo inference algorithm. 4 | 5 | This technique, SMC steering, is described in our workshop abstract, [Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs](https://arxiv.org/abs/2306.03081). 6 | -------------------------------------------------------------------------------- /llamppl/distributions/geometric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .distribution import Distribution 4 | 5 | 6 | class Geometric(Distribution): 7 | """A Geometric distribution.""" 8 | 9 | def __init__(self, p): 10 | """Create a Geometric distribution. 11 | 12 | Args: 13 | p: the rate of the Geometric distribution. 14 | """ 15 | self.p = p 16 | 17 | async def sample(self): 18 | n = np.random.geometric(self.p) 19 | return n, await self.log_prob(n) 20 | 21 | async def log_prob(self, value): 22 | return np.log(self.p) + np.log(1 - self.p) * (value - 1) 23 | 24 | async def argmax(self, idx): 25 | return idx - 1 # Most likely outcome is 0, then 1, etc. 
26 | -------------------------------------------------------------------------------- /llamppl/distributions/bernoulli.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .distribution import Distribution 4 | 5 | 6 | class Bernoulli(Distribution): 7 | """A Bernoulli distribution.""" 8 | 9 | def __init__(self, p): 10 | """Create a Bernoulli distribution. 11 | 12 | Args: 13 | p: the probability-of-True for the Bernoulli distribution. 14 | """ 15 | self.p = p 16 | 17 | async def sample(self): 18 | b = np.random.rand() < self.p 19 | return (b, await self.log_prob(b)) 20 | 21 | async def log_prob(self, value): 22 | return np.log(self.p) if value else np.log1p(-self.p) 23 | 24 | async def argmax(self, idx): 25 | return (self.p > 0.5) if idx == 0 else (self.p < 0.5) 26 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: 3.x 16 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 17 | - uses: actions/cache@v3 18 | with: 19 | key: mkdocs-material-${{ env.cache_id }} 20 | path: .cache 21 | restore-keys: | 22 | mkdocs-material- 23 | - run: pip install mkdocs-material mkdocstrings mkdocs-literate-nav mkdocs-section-index mkdocs-gen-files mkdocstrings-python 24 | - run: mkdocs gh-deploy --force 25 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: LLaMPPL Docs 2 | 3 | theme: 4 | name: "material" 5 | palette: 6 | scheme: slate 7 | 8 | markdown_extensions: 9 | - pymdownx.highlight: 10 | anchor_linenums: true 11 | line_spans: __span 12 | pygments_lang_class: true 13 | - pymdownx.inlinehilite 14 | - pymdownx.snippets 15 | - pymdownx.superfences 16 | 17 | plugins: 18 | - search 19 | - mkdocstrings 20 | - gen-files: 21 | scripts: 22 | - docs/gen_reference_page.py 23 | - literate-nav: 24 | nav_file: SUMMARY.md 25 | - section-index 26 | 27 | nav: 28 | - index.md 29 | - Getting Started: 30 | - getting_started.md 31 | - anatomy.md 32 | - transformers.md 33 | - Performance Engineering: 34 | - performance.md 35 | - batching.md 36 | - caching.md 37 | - immutability.md 38 | - Visualization: 39 | - visualization.md 40 | - Code Reference: reference/ 41 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | 8 | jobs: 9 | release: 10 | runs-on: ubuntu-22.04 11 | 12 | # Add "id-token" with the intended permissions. 13 | permissions: 14 | contents: 'read' 15 | id-token: 'write' 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | # This is here so that the versioning plugin will be able to see tags 21 | # and version using them. 
22 | fetch-depth: 0 23 | 24 | - uses: actions/setup-python@v4 25 | with: 26 | python-version: 3.11.5 27 | 28 | - name: Build package 29 | run: | 30 | python3 -m pip install --upgrade build 31 | python3 -m build 32 | 33 | - name: Publish to PyPI 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | with: 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /llamppl/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | """Exposes distributions for use with `sample`, `observe`, and `intervene` methods in LLaMPPL models. 2 | 3 | Currently supported distributions: 4 | 5 | * `Bernoulli(p: float) -> bool` 6 | * `Geometric(p: float) -> int` 7 | * `LogCategorical(logits: array) -> int` 8 | * `TokenCategorical(lm: llamppl.llms.CachedCausalLM, logits: array) -> llamppl.llms.Token` 9 | * `Transformer(lm: llamppl.llms.CachedCausalLM) -> llamppl.llms.Token` 10 | * `LMContext(lm: llamppl.llms.CachedCausalLM, prompt: list[int]).next_token() -> llamppl.llms.Token` 11 | * `LMContext(lm: llamppl.llms.CachedCausalLM, prompt: list[int]).mask_dist(mask: set[int]) -> bool` 12 | """ 13 | 14 | from .bernoulli import Bernoulli 15 | from .distribution import Distribution 16 | from .geometric import Geometric 17 | from .lmcontext import LMContext 18 | from .logcategorical import LogCategorical 19 | from .tokencategorical import TokenCategorical 20 | from .transformer import Transformer 21 | -------------------------------------------------------------------------------- /llamppl/distributions/logcategorical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..util import log_softmax 4 | from .distribution import Distribution 5 | 6 | 7 | class LogCategorical(Distribution): 8 | """A Geometric distribution.""" 9 | 10 | def __init__(self, logits): 11 | """Create a Categorical distribution from unnormalized log probabilities (logits). 12 | Given an array of logits, takes their `softmax` and samples an integer in `range(len(logits))` 13 | from the resulting categorical. 14 | 15 | Args: 16 | logits (np.array): a numpy array of unnormalized log probabilities. 17 | """ 18 | self.log_probs = log_softmax(logits) 19 | 20 | async def sample(self): 21 | n = np.random.choice(len(self.log_probs), p=np.exp(self.log_probs)) 22 | return n, await self.log_prob(n) 23 | 24 | async def log_prob(self, value): 25 | return self.log_probs[value] 26 | 27 | async def argmax(self, idx): 28 | return np.argsort(self.log_probs)[-idx] 29 | -------------------------------------------------------------------------------- /docs/visualization.md: -------------------------------------------------------------------------------- 1 | # Visualization 2 | 3 | We provide a Web interface for visualizing the execution of a sequential Monte Carlo algorithm, 4 | based on contributions from Maddy Bowers and Jacob Hoover. 5 | 6 | First, update your model to support visualization by implementing the [`string_for_serialization`](llamppl.modeling.Model.string_for_serialization) method. 7 | Return a string that summarizes the particle's current state. 8 | 9 | To run the interface, change to the `html` directory and run `python -m http.server`. This will start serving 10 | the files in the `html` directory at localhost:8000. (If you are SSH-ing onto a remote machine, you may need 11 | port forwarding. 
Visual Studio Code automatically handles this for some ports, including 8000.) 12 | Then, when calling [`smc_standard`](llamppl.inference.smc_standard), set `visualization_dir` 13 | to the path to the `html` directory. A JSON record of the run will automatically be saved 14 | to that directory, and a URL will be printed to the console (`http://localhost:8000/smc.html?path=$json_file`). 15 | -------------------------------------------------------------------------------- /docs/batching.md: -------------------------------------------------------------------------------- 1 | # Auto-Batching 2 | 3 | If running in a GPU-accelerated environment, LLaMPPL supports **auto-batching**. 4 | 5 | The `step` method of a LLaMPPL model describes how to advance a *single* particle one step of generation. 6 | But inference methods must maintain many particles at once. 7 | 8 | With auto-batching, LLaMPPL will execute particles' `step` methods concurrently, and automatically batch calls 9 | to large language models. This batching is handled by the `CachedCausalLM` object. 10 | 11 | If you are using the `vllm` backend, all batching decisions are handled internally by `vllm`. 12 | 13 | If you are using the huggingface backend, the behavior is controlled by two parameters: 14 | 15 | * `lm.batch_size`: the maximum number of requests to batch. The default value is 20. 16 | * `lm.timeout`: if `lm.timeout` seconds pass with no new request, the current batch is processed even if not full. The default value is 0.02. 17 | 18 | You may want to set the batch size (`#!python lm.batch_size`) to the number of particles you are using (if the number of particles is not too large). 19 | -------------------------------------------------------------------------------- /docs/gen_reference_page.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.Nav() 8 | 9 | for path in sorted(Path("llamppl").rglob("*.py")): 10 | if any(part.startswith(".") for part in path.parts): 11 | continue 12 | 13 | module_path = path.relative_to(".").with_suffix("") 14 | doc_path = path.relative_to(".").with_suffix(".md") 15 | full_doc_path = Path("reference", doc_path) 16 | 17 | parts = tuple(module_path.parts) 18 | 19 | if parts[-1] == "__init__": 20 | print(f"init, making parts {parts[:-1]}") 21 | parts = parts[:-1] 22 | elif parts[-1] == "__main__": 23 | continue 24 | 25 | nav[parts] = doc_path.as_posix() # 26 | 27 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 28 | ident = ".".join(parts) 29 | fd.write(f"::: {ident}") 30 | 31 | mkdocs_gen_files.set_edit_path(full_doc_path, path) 32 | 33 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: # 34 | nav_file.writelines(nav.build_literate_nav()) # 35 | -------------------------------------------------------------------------------- /llamppl/distributions/distribution.py: -------------------------------------------------------------------------------- 1 | class Distribution: 2 | """Abstract base class for a distribution.""" 3 | 4 | async def sample(self): 5 | """Generate a random sample from the distribution. 6 | 7 | Returns: 8 | x: a value randomly sampled from the distribution.""" 9 | raise NotImplementedError() 10 | 11 | async def log_prob(self, x): 12 | """Compute the log probability of a value under this distribution, 13 | or the log probability density if the distribution is continuous. 
14 | 15 | Args: 16 | x: the point at which to evaluate the log probability. 17 | Returns: 18 | logprob (float): the log probability of `x`.""" 19 | raise NotImplementedError() 20 | 21 | async def argmax(self, n): 22 | """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution). 23 | 24 | Args: 25 | n (int): which value to return to, indexed from most probable (n=0) to least probable (n=|support|). 26 | Returns: 27 | x: the nth most probable outcome from this distribution.""" 28 | raise NotImplementedError() 29 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "llamppl" 3 | dynamic = ["version"] 4 | description = "Probabilistic programming with Large Language Models." 5 | authors = [ 6 | {name = "Alex Lew", email = "alexlew@mit.edu"}, 7 | {name = "Gabriel Grand", email = "grandg@mit.edu"}, 8 | {name = "Ben LeBrun", email = "benlebrun1@gmail.com"}, 9 | ] 10 | license = {text = "MIT"} 11 | readme = "README.md" 12 | requires-python = ">=3.10" 13 | dependencies = [ 14 | "torch>=2.1.2", 15 | "numpy>=1.26.2", 16 | "scipy>=1.11.4", 17 | "protobuf>=5.27.2", 18 | "pre-commit>=3.7.1", 19 | "ipykernel>=6.29.5", 20 | "genlm-backend>=0.1.0a1", 21 | ] 22 | 23 | [project.optional-dependencies] 24 | vllm = ["vllm>=0.6.6"] 25 | mlx = ["genlm-backend[mlx]>=0.1.7"] 26 | dev = [ 27 | "pytest", 28 | "pytest-benchmark", 29 | "pytest-cov", 30 | "pre-commit>=3.6.0", 31 | "ruff>=0.9.9", 32 | "jupyterlab>=4.0.9", 33 | "ipywidgets>=8.1.1", 34 | "matplotlib>=3.9.1", 35 | "seaborn>=0.13.2", 36 | ] 37 | yelp = [ 38 | "yake>=0.4.8", 39 | "datasets>=2.20.0", 40 | ] 41 | collie = [ 42 | "collie-bench>=0.1.0", 43 | "nltk>=3.8.1", 44 | "dill>=0.3.8", 45 | "evaluate>=0.4.2", 46 | ] 47 | examples = ["nltk>=3.8.1"] 48 | 49 | [tool.setuptools.packages.find] 50 | include = ["llamppl*"] 51 | 52 | [build-system] 53 | requires = ["setuptools>=64.0", "setuptools-scm>=8"] 54 | build-backend = "setuptools.build_meta" 55 | 56 | [tool.setuptools_scm] 57 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | import torch 5 | 6 | from examples.haiku import run_example as run_haiku 7 | from examples.hard_constraints import run_example as run_hard_constraints 8 | from llamppl.llms import CachedCausalLM, MLX_AVAILABLE 9 | 10 | if MLX_AVAILABLE: 11 | backends = ["mock", "mlx"] 12 | else: 13 | backends = [ 14 | "mock", 15 | "hf", 16 | pytest.param( 17 | "vllm", 18 | marks=pytest.mark.skipif( 19 | not torch.cuda.is_available(), reason="vLLM backend requires CUDA" 20 | ), 21 | ), 22 | ] 23 | 24 | 25 | @pytest.fixture 26 | def LLM(backend): 27 | # Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU 28 | kwargs = ( 29 | {"engine_opts": {"gpu_memory_utilization": 0.45}} 30 | if backend == "vllm" 31 | else {"cache_size": 10} 32 | if backend == "mlx" 33 | else {} 34 | ) 35 | return CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs) 36 | 37 | 38 | @pytest.mark.parametrize("backend", backends) 39 | def test_hard_constraints(LLM, n_particles=20, max_tokens=25): 40 | particles = asyncio.run( 41 | run_hard_constraints(LLM, max_tokens=max_tokens, n_particles=n_particles) 42 | ) 43 | assert len(particles) == n_particles 44 | 45 | 46 | 
@pytest.mark.parametrize("backend", backends) 47 | def test_haiku(LLM, n_particles=20): 48 | particles = asyncio.run( 49 | run_haiku(LLM, poem_title="The beauty of testing", n_particles=n_particles) 50 | ) 51 | assert len(particles) == n_particles 52 | -------------------------------------------------------------------------------- /llamppl/distributions/tokencategorical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..llms import Token 5 | from ..util import log_softmax 6 | from .distribution import Distribution 7 | 8 | 9 | class TokenCategorical(Distribution): 10 | def __init__(self, lm, logits): 11 | """Create a Categorical distribution whose values are Tokens, not integers. 12 | Given a language model `lm` and an array of unnormalized log probabilities (of length `len(lm.vocab)`), 13 | uses softmax to normalize them and samples a Token from the resulting categorical. 14 | 15 | Args: 16 | lm (llamppl.llms.CachedCausalLM): the language model whose vocabulary is to be generated from. 17 | logits (np.array): a numpy array of unnormalized log probabilities. 18 | """ 19 | self.lm = lm 20 | self.log_probs = log_softmax(logits) 21 | if self.lm.tokenizer.vocab_size != len(logits): 22 | raise RuntimeError( 23 | f"TokenCategorical: vocab size is {self.lm.tokenizer.vocab_size} but provided {len(logits)} logits." 24 | ) 25 | 26 | async def sample(self): 27 | n = np.random.choice(len(self.log_probs), p=(np.exp(self.log_probs))) 28 | return ( 29 | Token(self.lm, n, self.lm.tokenizer.convert_ids_to_tokens(n)), 30 | self.log_probs[n], 31 | ) 32 | 33 | async def log_prob(self, value): 34 | return self.log_probs[value.token_id] 35 | 36 | async def argmax(self, idx): 37 | tok = torch.argsort(self.log_probs)[-idx] 38 | return ( 39 | Token(self.lm, tok, self.lm.tokenizer.convert_ids_to_tokens(tok)), 40 | self.log_probs[tok], 41 | ) 42 | -------------------------------------------------------------------------------- /tests/test_lmcontext.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from llamppl.distributions.lmcontext import LMContext 8 | from llamppl.llms import CachedCausalLM, MLX_AVAILABLE 9 | 10 | if MLX_AVAILABLE: 11 | backends = ["mock", "mlx"] 12 | else: 13 | backends = [ 14 | "mock", 15 | "hf", 16 | pytest.param( 17 | "vllm", 18 | marks=pytest.mark.skipif( 19 | not torch.cuda.is_available(), reason="vLLM backend requires CUDA" 20 | ), 21 | ), 22 | ] 23 | 24 | 25 | @pytest.fixture 26 | def lm(backend): 27 | kwargs = {"cache_size": 10} if backend == "mlx" else {} 28 | return CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs) 29 | 30 | 31 | @pytest.mark.parametrize("backend", backends) 32 | def test_init(lm): 33 | prompt = "Hello, world!" 
34 | lmcontext = LMContext(lm, prompt) 35 | assert lmcontext.tokens == lm.tokenizer.encode(prompt) 36 | logprobs = lm.next_token_logprobs_unbatched(lmcontext.tokens) 37 | np.testing.assert_allclose( 38 | lmcontext.next_token_logprobs, 39 | logprobs, 40 | rtol=5e-4, 41 | err_msg="Sync context __init__", 42 | ) 43 | 44 | async def async_context(): 45 | return LMContext(lm, prompt) 46 | 47 | lmcontext = asyncio.run(async_context()) 48 | np.testing.assert_allclose( 49 | lmcontext.next_token_logprobs, 50 | logprobs, 51 | rtol=5e-4, 52 | err_msg="Async context __init__", 53 | ) 54 | 55 | async def async_context_create(): 56 | return await LMContext.create(lm, prompt) 57 | 58 | lmcontext = asyncio.run(async_context_create()) 59 | np.testing.assert_allclose( 60 | lmcontext.next_token_logprobs, 61 | logprobs, 62 | rtol=5e-4, 63 | err_msg="Async context create", 64 | ) 65 | -------------------------------------------------------------------------------- /llamppl/distributions/transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..llms import Token 4 | from ..llms import TokenSequence 5 | from .distribution import Distribution 6 | 7 | 8 | # Transformer(lm, prompt) -- where prompt can either be a string or a list of Tokens. 9 | class Transformer(Distribution): 10 | def __init__(self, lm, prompt, temp=1.0): 11 | """Create a Categorical distribution whose values are Tokens, with probabilities given 12 | by a language model. Supports auto-batching. 13 | 14 | Args: 15 | lm (llamppl.llms.CachedCausalLM): the language model. 16 | prompt (str | llamppl.llms.TokenSequence): the sequence of tokens to use as the prompt. If a string, `lm.tokenizer` is used to encode it. 17 | temp (float): temperature at which to generate (0 < `temp` < `float('inf')`). 
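
        Example (an illustrative sketch, not from the library's own documentation;
        assumes it is used inside a `llamppl.modeling.Model` whose `self.lm` is a
        `CachedCausalLM` and whose `self.prompt_tokens` is a list of token ids):

            token = await self.sample(Transformer(self.lm, self.prompt_tokens, temp=0.7))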
18 | """ 19 | self.lm = lm 20 | self.temp = temp 21 | 22 | # prompt will be a list of ints 23 | if isinstance(prompt, str): 24 | prompt = self.lm.tokenizer.encode(prompt) 25 | elif isinstance(prompt, TokenSequence): 26 | prompt = prompt.seq 27 | 28 | self.prompt = prompt 29 | 30 | async def log_prob(self, x): 31 | log_probs = await self.lm.next_token_logprobs(self.prompt) 32 | log_probs = log_probs / self.temp 33 | 34 | if isinstance(x, Token): 35 | x = x.token_id 36 | 37 | return log_probs[x] 38 | 39 | async def sample(self): 40 | log_probs = await self.lm.next_token_logprobs(self.prompt) 41 | log_probs = log_probs / self.temp 42 | probs = np.exp(log_probs) 43 | token_id = np.random.choice(len(probs), p=(probs)) 44 | logprob = log_probs[token_id] 45 | return ( 46 | Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), 47 | logprob, 48 | ) 49 | 50 | 51 | # def argmax(self, idx): 52 | # token_id = np.argsort(self.log_probs)[-idx] 53 | # return Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), log_probs[token_id] 54 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # Builds the `llamppl` environment and runs all tests 2 | 3 | name: Codebase tests 4 | 5 | on: 6 | pull_request: 7 | push: 8 | branches: 9 | - main 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build: 16 | runs-on: ParallelHoss 17 | 18 | steps: 19 | - name: Check out repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up python 23 | id: setup-python 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.11.5' 27 | 28 | - name: Run Tests 29 | run: | 30 | python -m venv venv 31 | source venv/bin/activate 32 | pip install -e .[dev,examples] 33 | pip install --force-reinstall 'triton==3.2.0' 34 | # Add the project root to the PYTHONPATH for examples 35 | PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json 36 | 37 | - name: Upload coverage to Codecov 38 | uses: codecov/codecov-action@v5 39 | with: 40 | fail_ci_if_error: false 41 | disable_search: true 42 | token: ${{ secrets.CODECOV_TOKEN }} 43 | files: ./coverage.json 44 | slug: genlm/llamppl 45 | 46 | 47 | test_mlx: 48 | runs-on: macos-14 49 | 50 | steps: 51 | - name: Check out repository 52 | uses: actions/checkout@v4 53 | with: 54 | fetch-depth: 1 55 | 56 | - name: Set up Python 57 | uses: actions/setup-python@v5 58 | with: 59 | python-version: '3.11.5' 60 | cache: 'pip' 61 | 62 | - name: Run tests with MLX extras 63 | run: | 64 | python -m venv venv 65 | source venv/bin/activate 66 | pip install -e .[mlx,dev,examples] 67 | PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json 68 | 69 | - name: Upload MLX coverage to Codecov 70 | uses: codecov/codecov-action@v5 71 | with: 72 | fail_ci_if_error: false 73 | disable_search: true 74 | token: ${{ secrets.CODECOV_TOKEN }} 75 | files: ./coverage.json 76 | slug: genlm/llamppl 77 | -------------------------------------------------------------------------------- /benchmark/benchmark_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Requires pytest and pytest-benchmark (pip install pytest pytest-benchmark) 3 | 4 | Example usage: pytest benchmark/benchmark_backend.py --benchmark-only --benchmark-group-by=func -v 5 | """ 6 | 7 | import asyncio 8 | 9 | import pytest 10 | import torch 11 | 12 | from 
examples.haiku import run_example as run_haiku 13 | from examples.hard_constraints import run_example as run_hard_constraints 14 | from llamppl.llms import CachedCausalLM, MLX_AVAILABLE 15 | 16 | backends = [ 17 | "hf", 18 | pytest.param( 19 | "vllm", 20 | marks=pytest.mark.skipif( 21 | not torch.cuda.is_available(), reason="vLLM backend requires CUDA" 22 | ), 23 | ), 24 | pytest.param( 25 | "mlx", 26 | marks=pytest.mark.skipif( 27 | not MLX_AVAILABLE, reason="MLX backend requires MLX-LM" 28 | ), 29 | ), 30 | ] 31 | 32 | 33 | @pytest.fixture 34 | def LLM(backend): 35 | # Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU 36 | kwargs = ( 37 | {"engine_opts": {"gpu_memory_utilization": 0.45}, "cache_size": 100} 38 | if backend == "vllm" 39 | else {} 40 | ) 41 | return CachedCausalLM.from_pretrained( 42 | "meta-llama/Meta-Llama-3-8B", backend=backend, **kwargs 43 | ) 44 | 45 | 46 | @pytest.mark.parametrize("backend", backends) 47 | def test_hard_constraints_benchmark(LLM, benchmark, n_particles=20, max_tokens=50): 48 | def run_with_clear_cache(): 49 | LLM.clear_cache() 50 | return asyncio.run( 51 | run_hard_constraints(LLM, max_tokens=max_tokens, n_particles=n_particles) 52 | ) 53 | 54 | # warmup 55 | run_with_clear_cache() 56 | 57 | benchmark.pedantic( 58 | run_with_clear_cache, 59 | iterations=1, 60 | rounds=3, 61 | ) 62 | 63 | 64 | @pytest.mark.parametrize("backend", backends) 65 | def test_haiku_benchmark(LLM, benchmark, n_particles=20): 66 | def run_with_clear_cache(): 67 | LLM.clear_cache() 68 | return asyncio.run( 69 | run_haiku(LLM, poem_title="The beauty of testing", n_particles=n_particles) 70 | ) 71 | 72 | # warmup 73 | run_with_clear_cache() 74 | 75 | benchmark.pedantic( 76 | run_with_clear_cache, 77 | iterations=1, 78 | rounds=3, 79 | ) 80 | -------------------------------------------------------------------------------- /llamppl/inference/smc_record.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class SMCRecord: 5 | def __init__(self, n): 6 | self.history = [] 7 | self.most_recent_weights = [0.0 for _ in range(n)] 8 | self.step_num = 1 9 | 10 | def prepare_string(self, s): 11 | # If the string doesn't have <<< and >>>, prepend <<<>>> at the front. 
12 | if "<<<" not in s and ">>>" not in s: 13 | return f"<<<>>>{s}" 14 | return s 15 | 16 | def particle_dict(self, particles): 17 | return [ 18 | { 19 | "contents": self.prepare_string(p.string_for_serialization()), 20 | "logweight": ( 21 | "-Infinity" if p.weight == float("-inf") else str(float(p.weight)) 22 | ), 23 | "weight_incr": str( 24 | float(p.weight) - float(self.most_recent_weights[i]) 25 | ), 26 | } 27 | for (i, p) in enumerate(particles) 28 | ] 29 | 30 | def add_init(self, particles): 31 | self.history.append( 32 | { 33 | "step": self.step_num, 34 | "mode": "init", 35 | "particles": self.particle_dict(particles), 36 | } 37 | ) 38 | self.most_recent_weights = [p.weight for p in particles] 39 | 40 | def add_smc_step(self, particles): 41 | self.step_num += 1 42 | self.history.append( 43 | { 44 | "step": self.step_num, 45 | "mode": "smc_step", 46 | "particles": self.particle_dict(particles), 47 | } 48 | ) 49 | self.most_recent_weights = [p.weight for p in particles] 50 | 51 | def add_resample(self, ancestor_indices, particles): 52 | self.step_num += 1 53 | self.most_recent_weights = [ 54 | self.most_recent_weights[i] for i in ancestor_indices 55 | ] 56 | 57 | self.history.append( 58 | { 59 | "mode": "resample", 60 | "step": self.step_num, 61 | "ancestors": [int(a) for a in ancestor_indices], 62 | "particles": self.particle_dict(particles), 63 | } 64 | ) 65 | 66 | self.most_recent_weights = [p.weight for p in particles] 67 | 68 | def to_json(self): 69 | return json.dumps(self.history) 70 | -------------------------------------------------------------------------------- /docs/transformers.md: -------------------------------------------------------------------------------- 1 | # Working with Transformers 2 | 3 | ## Load your Transformer as a `CachedCausalLM` 4 | 5 | The easiest way to load a Transformer model is to use the [`CachedCausalLM.from_pretrained`][llamppl.llms.CachedCausalLM.from_pretrained] static method, which accepts as input a HuggingFace model identifier. This loads the model's weights into memory, and also loads the appropriate tokenizer. Note that if the model in question requires HuggingFace authorization (e.g., Meta's Llama 2 models), you will need to login via the [`huggingface-cli` command line tool](https://huggingface.co/docs/huggingface_hub/en/guides/cli). 6 | 7 | ## Use the LLM within your model via the `Transformer` distribution 8 | 9 | Within a model, you can `sample` or `observe` from the [`Transformer`][llamppl.distributions.transformer.Transformer] distribution. It accepts as arguments a [`CachedCausalLM`][llamppl.llms.CachedCausalLM] instance, as well as a list of integer token ids specifying the context. It returns a distribution over next tokens. The [`Transformer`][llamppl.distributions.transformer.Transformer] distirbution is stateless, and so your model will need to manually extend the context with newly sampled tokens. 10 | 11 | ## Use the LLM within your model via the `LMContext` class 12 | 13 | Alternatively, you can initialize an [`LMContext`][llamppl.distributions.lmcontext.LMContext] object with a [`CachedCausalLM`][llamppl.llms.CachedCausalLM] instance instance and a string-valued prompt. It maintains a growing context as state, and exposes a [`next_token`][llamppl.distributions.lmcontext.LMContext.next_token] distribution that, when sampled, observed, or intervened, grows the context. It also supports a form of 'sub-token' generation, via the [`mask_dist`][llamppl.distributions.lmcontext.LMContext.mask_dist] distribution. 
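
For example, a model's `step` method can combine these two distributions. The following is a minimal sketch, assuming that `self.context = LMContext(lm, prompt)`, `self.eos_token = lm.tokenizer.eos_token_id`, and a set of token ids `self.forbidden_tokens` were created in the model's `__init__` (mirroring the Getting Started guide):

```python
async def step(self):
    # Condition on the next token *not* being one of the forbidden tokens.
    await self.observe(self.context.mask_dist(self.forbidden_tokens), False)

    # Sampling from next_token() extends self.context with the sampled token.
    token = await self.sample(self.context.next_token())

    # Stop once the model emits the end-of-sequence token.
    if token.token_id == self.eos_token:
        self.finish()
```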
14 | 15 | ## Create custom token distributions with `TokenCategorical` 16 | 17 | You may also create a custom distribution over the vocabulary of a language model using the [`TokenCategorical`][llamppl.distributions.tokencategorical.TokenCategorical] distribution. It is parameterized by a [`CachedCausalLM`][llamppl.llms.CachedCausalLM] instance, and an array of logits equal in length to the language model's vocabulary size. 18 | This distribution is particularly useful as a proposal distribution; for example, a model might `sample` with `dist` set 19 | to the LM's next token distribution, but with `proposal` set to a modified distribution that uses a heuristic to upweight 20 | 'good' tokens and downweight 'bad' ones. 21 | -------------------------------------------------------------------------------- /docs/caching.md: -------------------------------------------------------------------------------- 1 | # Caching in LLaMPPL 2 | 3 | LLaMPPL performs two kinds of caching to improve performance. The caching behavior is dependent on the backend you are using with your `CachedCausalLM`. 4 | 5 | ## Log probability caching 6 | With the huggingface backend, next-token log probabilities are always cached, whenever they are computed. 7 | This way, if different particles make exactly the same log probability queries, 8 | the Transformer is run only once. This is primarily beneficial when: 9 | 10 | * particles are cloned during resampling 11 | 12 | * cloned particles happen to sample the same next token: if the next-token distribution is concentrated, 13 | it is likely that multiple copies of a particle will sample the same next token. Log probability caching 14 | allows them to sample the _following_ token using only a single call to the language model. 15 | 16 | The log probability cache can be cleared using the [`lm.clear_cache()`][llamppl.llms.CachedCausalLM.clear_cache] method. Note that for the huggingface backend, this method will also clear the KV cache. 17 | 18 | With the `vllm` backend, the log probability cache can be turned on by passing in a `cache_size` parameter to the `CachedCausalLM.from_pretrained` method; for example, `CachedCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", cache_size=100)`. By default, `cache_size` is set to 0, which means that the log probability cache is disabled. 19 | 20 | ## Key-value caching 21 | Key-value caching caches the key and value vectors computed by each layer of a Transformer, 22 | for reuse when processing new tokens at the end of a previously evaluated sequence. 23 | 24 | In principle, key-value caching is most useful when: 25 | 26 | * There is a long common *prompt* from which all particles are generating. 27 | In this case, the prompt's tokens can be evaluated just once by the language model, 28 | and each subsequent call only has to pay for the new tokens generated after the prompt. 29 | 30 | * Generations from the model are very long. In this case, it may be worth paying the memory 31 | cost to cache *different* key-value sequences for *each* particle, to speed up future next-token 32 | queries. 33 | 34 | When using the `vllm` backend, both types of caching are handled automatically. 35 | 36 | With the huggingface backend, only the first use case is well-supported by the LLaMPPL library, via the 37 | [`lm.cache_kv(prompt)`][llamppl.llms.CachedCausalLM.cache_kv] method. This method computes and caches key and value vectors 38 | for every token in `prompt`. 
Future calls to [`lm.next_token_logprobs`][llamppl.llms.CachedCausalLM.next_token_logprobs] and [`lm.next_token_logprobs_unbatched`][llamppl.llms.CachedCausalLM.next_token_logprobs_unbatched] 39 | will automatically recognize when `prompt` is a prefix of the new query, and automatically 40 | exploit incremental computation. Multiple prompts can be cached, and [`lm.clear_kv_cache()`][llamppl.llms.CachedCausalLM.clear_kv_cache] can 41 | be used to clear the KV-cache without clearing the log probability cache. 42 | 43 | Because [`lm.cache_kv`][llamppl.llms.CachedCausalLM.cache_kv] is not a batched call, 44 | it is not well-suited to caching 45 | different strings for different particles. 46 | Rather, it is best used in the `__init__` method of a model--or even 47 | outside of a model--on fixed prompt strings that every particle will share. 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.json 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /llamppl/inference/smc_standard.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import copy 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | 7 | from ..util import logsumexp 8 | from .smc_record import SMCRecord 9 | 10 | 11 | async def smc_standard( 12 | model, n_particles, ess_threshold=0.5, visualization_dir=None, json_file=None 13 | ): 14 | """ 15 | Standard sequential Monte Carlo algorithm with multinomial resampling. 16 | 17 | Args: 18 | model (llamppl.modeling.Model): The model to perform inference on. 19 | n_particles (int): Number of particles to execute concurrently. 20 | ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`. 21 | visualization_dir (str): Path to the directory where the visualization server is running. 22 | json_file (str): Path to the JSON file to save the record of the inference, relative to `visualization_dir` if provided. 23 | 24 | Returns: 25 | particles (list[llamppl.modeling.Model]): The completed particles after inference. 
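
    Example (an illustrative sketch, not from the library's own documentation;
    `MyModel` stands for a user-defined `llamppl.modeling.Model` subclass):

        import asyncio

        model = MyModel(...)  # construct your model with its own arguments
        particles = asyncio.run(smc_standard(model, n_particles=10))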
26 | """ 27 | particles = [copy.deepcopy(model) for _ in range(n_particles)] 28 | await asyncio.gather(*[p.start() for p in particles]) 29 | record = visualization_dir is not None or json_file is not None 30 | history = SMCRecord(n_particles) if record else None 31 | 32 | ancestor_indices = list(range(n_particles)) 33 | did_resample = False 34 | while any(map(lambda p: not p.done_stepping(), particles)): 35 | # Step each particle 36 | for p in particles: 37 | p.untwist() 38 | await asyncio.gather(*[p.step() for p in particles if not p.done_stepping()]) 39 | 40 | # Record history 41 | if record: 42 | if len(history.history) == 0: 43 | history.add_init(particles) 44 | elif did_resample: 45 | history.add_resample(ancestor_indices, particles) 46 | else: 47 | history.add_smc_step(particles) 48 | 49 | # Normalize weights 50 | W = np.array([p.weight for p in particles]) 51 | w_sum = logsumexp(W) 52 | normalized_weights = W - w_sum 53 | 54 | # Resample if necessary 55 | if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log( 56 | n_particles 57 | ): 58 | # Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now 59 | probs = np.exp(normalized_weights) 60 | ancestor_indices = [ 61 | np.random.choice(range(len(particles)), p=probs) 62 | for _ in range(n_particles) 63 | ] 64 | 65 | if record: 66 | # Sort the ancestor indices 67 | ancestor_indices.sort() 68 | 69 | particles = [copy.deepcopy(particles[i]) for i in ancestor_indices] 70 | avg_weight = w_sum - np.log(n_particles) 71 | for p in particles: 72 | p.weight = avg_weight 73 | 74 | did_resample = True 75 | else: 76 | did_resample = False 77 | 78 | if record: 79 | # Figure out path to save JSON. 80 | if visualization_dir is None: 81 | json_path = json_file 82 | else: 83 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 84 | json_relative = ( 85 | json_file 86 | if json_file is not None 87 | else f"{model.__class__.__name__}-{timestamp}.json" 88 | ) 89 | json_path = f"{visualization_dir}/{json_file}" 90 | 91 | # Save JSON 92 | with open(json_path, "w") as f: 93 | f.write(history.to_json()) 94 | 95 | # Web path is the part of the path after the html directory 96 | if visualization_dir is not None: 97 | print(f"Visualize at http://localhost:8000/smc.html?path={json_relative}") 98 | else: 99 | print(f"Saved record to {json_path}") 100 | 101 | return particles 102 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Colab 4 | 5 | One easy way to try LLaMPPL out is to use a Colab notebook. We have [a demo notebook](https://colab.research.google.com/drive/1uJEC-U8dcwsTWccCDGVexpgXexzZ642n?usp=sharing) that performs constrained generation with GPT-2, a small enough model that the RAM and GPU constraints of Colab's free version should not prevent you from running the demo. 6 | 7 | ## Installing LLaMPPL 8 | 9 | To get started, clone the `llamppl` repository and install the `llamppl` package. 10 | 11 | ```bash 12 | git clone https://github.com/genlm/llamppl 13 | cd llamppl 14 | poetry install 15 | ``` 16 | 17 | We use [poetry](https://python-poetry.org/) to manage dependencies. If you don't have poetry installed, you can install it with `pip install poetry`. 18 | 19 | You can then run an example. 
The first time you run it, the example may ask to download model weights from the HuggingFace model repository.
20 |
21 | ```bash
22 | poetry run python examples/hard_constraints.py
23 | ```
24 |
25 | Depending on your available GPU memory, you may wish to edit the example to change parameters such as the batch size, or which HuggingFace model to use. The `hard_constraints.py` example has been run successfully on an NVIDIA L4 GPU (with 24 GB of VRAM) on Google Cloud.
26 |
27 | ## Your First Model
28 |
29 | Let's write a LLaMPPL model to generate according to the hard constraint that completions do not use the lowercase letter `e`.
30 |
31 | To do so, we subclass the [`Model`](llamppl.modeling.Model) class:
32 |
33 | ```python
34 | # examples/no_e.py
35 |
36 | from llamppl import Model, LMContext, CachedCausalLM
37 |
38 | # A LLaMPPL model subclasses the Model class
39 | class MyModel(Model):
40 |
41 |     # The __init__ method is used to process arguments
42 |     # and initialize instance variables.
43 |     def __init__(self, lm, prompt, forbidden_letter):
44 |         super().__init__()
45 |
46 |         # A stateful context object for the LLM, initialized with the prompt
47 |         self.context = LMContext(lm, prompt)
48 |         self.eos_token = lm.tokenizer.eos_token_id
49 |
50 |         # The forbidden letter
51 |         self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
52 |                                     if forbidden_letter in v)
53 |
54 |     # The step method is used to perform a single 'step' of generation.
55 |     # This might be a single token, a single phrase, or any other division.
56 |     # Here, we generate one token at a time.
57 |     async def step(self):
58 |         # Condition on the next token *not* being a forbidden token.
59 |         await self.observe(self.context.mask_dist(self.forbidden_tokens), False)
60 |
61 |         # Sample the next token from the LLM -- automatically extends `self.context`.
62 |         token = await self.sample(self.context.next_token())
63 |
64 |         # Check for EOS or end of sentence
65 |         if token.token_id == self.eos_token or str(token) in ['.', '!', '?']:
66 |             # Finish generation
67 |             self.finish()
68 |
69 |     # To improve performance, a hint that `self.forbidden_tokens` is immutable
70 |     def immutable_properties(self):
71 |         return set(['forbidden_tokens'])
72 | ```
73 |
74 | To run the model, we use an inference method, like `smc_steer`:
75 |
76 | ```python
77 | import asyncio
78 | from llamppl import smc_steer
79 |
80 | # Initialize the HuggingFace model
81 | lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=<your_auth_token>)
82 |
83 | # Create a model instance
84 | model = MyModel(lm, "The weather today is expected to be", "e")
85 |
86 | # Run inference
87 | particles = asyncio.run(smc_steer(model, 5, 3)) # number of particles N, and beam factor K
88 | ```
89 |
90 | Each returned particle is an instance of the `MyModel` class that has been `step`-ped to completion.
91 | The generated strings can be printed along with the particle weights:
92 |
93 | ```python
94 | for particle in particles:
95 |     print(f"{particle.context.s} (weight: {particle.weight})")
96 | ```
97 |
98 |
99 | ## Learning more
100 |
101 | For more intuition on language model probabilistic programming, see [our paper](https://arxiv.org/abs/2306.03081), or the rest of this documentation.
102 | -------------------------------------------------------------------------------- /examples/hard_constraints.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import string 3 | 4 | from llamppl import CachedCausalLM 5 | from llamppl import LMContext 6 | from llamppl import Model 7 | from llamppl import smc_standard 8 | 9 | 10 | def make_masks(LLM): 11 | return { 12 | i: set( 13 | j 14 | for (j, v) in enumerate(LLM.str_vocab) 15 | if j != LLM.tokenizer.eos_token_id 16 | and "\n" not in v 17 | and any(c.isalpha() or c in string.punctuation for c in v) 18 | and len(v.strip()) <= 5 19 | and (not v[0].isalpha() or i + len(v) <= 5) 20 | ) 21 | for i in range(6) 22 | } 23 | 24 | 25 | class ConstraintModel(Model): 26 | def __init__(self, LLM, prompt, max_tokens): 27 | super().__init__() 28 | self.context = LMContext(LLM, prompt) 29 | self.max_tokens = max_tokens 30 | self.masks = make_masks(LLM) 31 | self.eos_token_id = LLM.tokenizer.eos_token_id 32 | 33 | async def start(self): 34 | mask = self.active_constraint_mask() 35 | await self.observe(self.context.mask_dist(mask), True) 36 | 37 | async def step(self): 38 | # Generate proposed token. 39 | token = await self.sample(self.context.next_token()) 40 | 41 | # Reduce number of max tokens remaining 42 | self.max_tokens -= 1 43 | 44 | print(f"{self.context}") 45 | 46 | # Check if done 47 | if token == self.eos_token_id or self.max_tokens == 0: 48 | self.finish() 49 | return 50 | 51 | # Observe that next token follows the constraint. 52 | mask = self.active_constraint_mask() 53 | await self.observe(self.context.mask_dist(mask), True) 54 | 55 | def active_constraint_mask(self): 56 | string_so_far = str(self.context) 57 | words = string_so_far.split() 58 | last_word = words[-1] if len(words) > 0 else "" 59 | return self.masks[min(5, len(last_word))] 60 | 61 | def string_for_serialization(self): 62 | return f"{self.context}" 63 | 64 | def immutable_properties(self): 65 | return ["masks"] 66 | 67 | 68 | # From Politico.com 69 | prompt = """3 things to watch … 70 | 71 | 1. The return of the House means new energy for the GOP’s Biden impeachment push, and Democrats are starting their pushback early. Rep. Jamie Raskin (D-Md.) is out this morning with a 14-page rebuttal memo that seeks to paint the GOP campaign as a “complete and total bust” and an attempt at distracting from the “overwhelming evidence of [Trump’s] criminal and corrupt conduct during his term of office.” 72 | 73 | 2. The Senate is back this evening for a bed-check vote. With Minority Leader Mitch McConnell having successfully quieted (public) chatter about his health, expect senators to be quizzed anew about Sen. Tommy Tuberville’s (R-Ala.) Pentagon nominee blockade, especially with the Joint Chiefs chair, Gen. Mark Milley, just weeks away from retirement and the confirmation of his successor, Gen. C.Q. Brown, in limbo. 74 | 75 | 3.""" 76 | 77 | 78 | async def run_example(LLM, max_tokens=50, n_particles=20, ess_threshold=0.5): 79 | # Cache the key value vectors for the prompt. 80 | LLM.cache_kv(LLM.tokenizer.encode(prompt)) 81 | 82 | # Initialize the Model. 83 | constraint_model = ConstraintModel(LLM, prompt, max_tokens) 84 | 85 | # Run inference. 86 | particles = await smc_standard( 87 | constraint_model, n_particles, ess_threshold, "html", "results/output.json" 88 | ) 89 | for p in particles: 90 | print(f"{p.context}") 91 | 92 | return particles 93 | 94 | 95 | def main(): 96 | # Load the language model. 
97 | # Mistral and Vicuna are open models; to use a model with restricted access, like LLaMA 3, 98 | # authenticate using the Huggingface CLI. 99 | LLM = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B") 100 | # LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5") 101 | # LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") 102 | 103 | # Set batch size if provided. This operation is only valid for the HuggingFace backend. 104 | if LLM.backend == "hf": 105 | LLM.batch_size = 40 106 | 107 | # Run the example. 108 | asyncio.run(run_example(LLM)) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /llamppl/inference/smc_steer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import copy 3 | 4 | import numpy as np 5 | 6 | from ..util import logsumexp 7 | from ..util import softmax 8 | 9 | 10 | def find_c(weights, N): 11 | # Sort the weights 12 | sorted_weights = np.sort(weights) 13 | # Find the smallest chi 14 | B_val = 0.0 15 | A_val = len(weights) 16 | for i in range(len(sorted_weights)): 17 | chi = sorted_weights[i] 18 | # Calculate A_val -- number of weights larger than chi 19 | A_val -= 1 20 | # Update B_val -- add the sum of weights smaller than or equal to chi 21 | B_val += chi 22 | if B_val / chi + A_val - N <= 1e-12: 23 | return (N - A_val) / B_val 24 | return N 25 | 26 | 27 | def resample_optimal(weights, N): 28 | c = find_c(weights, N) 29 | # Weights for which c * w >= 1 are deterministically resampled 30 | deterministic = np.where(c * weights >= 1)[0] 31 | # Weights for which c * w <= 1 are stochastically resampled 32 | stochastic = np.where(c * weights < 1)[0] 33 | # Stratified sampling to generate N-len(deterministic) indices 34 | # from the stochastic weights 35 | n_stochastic = len(stochastic) 36 | n_resample = N - len(deterministic) 37 | if n_resample == 0: 38 | return deterministic, np.array([], dtype=int), c 39 | K = np.sum(weights[stochastic]) / (n_resample) 40 | u = np.random.uniform(0, K) 41 | i = 0 42 | stoch_resampled = np.array([], dtype=int) 43 | while i < n_stochastic: 44 | u = u - weights[stochastic[i]] 45 | if u <= 0: 46 | # Add stochastic[i] to resampled indices 47 | stoch_resampled = np.append(stoch_resampled, stochastic[i]) 48 | # Update u 49 | u = u + K 50 | i = i + 1 51 | else: 52 | i += 1 53 | # Concatenate the deterministic and stochastic resampled indices 54 | # resampled = np.concatenate((deterministic, stoch_resampled)) 55 | # return resampled 56 | return deterministic, stoch_resampled, c 57 | 58 | 59 | async def smc_steer(model, n_particles, n_beam): 60 | """ 61 | Modified sequential Monte Carlo algorithm that uses without-replacement resampling, 62 | as described in [our workshop abstract](https://arxiv.org/abs/2306.03081). 63 | 64 | Args: 65 | model (llamppl.modeling.Model): The model to perform inference on. 66 | n_particles (int): Number of particles to maintain. 67 | n_beam (int): Number of continuations to consider for each particle. 68 | 69 | Returns: 70 | particles (list[llamppl.modeling.Model]): The completed particles after inference. 
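
    Example (an illustrative sketch, not from the library's own documentation;
    `MyModel` stands for a user-defined `llamppl.modeling.Model` subclass):

        import asyncio

        model = MyModel(...)  # construct your model with its own arguments
        particles = asyncio.run(smc_steer(model, n_particles=5, n_beam=3))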
71 | """ 72 | # Create n_particles copies of the model 73 | particles = [copy.deepcopy(model) for _ in range(n_particles)] 74 | await asyncio.gather(*[p.start() for p in particles]) 75 | 76 | while any(map(lambda p: not p.done_stepping(), particles)): 77 | # Count the number of finished particles 78 | n_finished = sum(map(lambda p: p.done_stepping(), particles)) 79 | n_total = n_finished + (n_particles - n_finished) * n_beam 80 | 81 | # Create a super-list of particles that has n_beam copies of each 82 | super_particles = [] 83 | for p in particles: 84 | p.untwist() 85 | super_particles.append(p) 86 | if p.done_stepping(): 87 | p.weight += np.log(n_total) - np.log(n_particles) 88 | else: 89 | p.weight += np.log(n_total) - np.log(n_particles) - np.log(n_beam) 90 | super_particles.extend([copy.deepcopy(p) for _ in range(n_beam - 1)]) 91 | 92 | # Step each super-particle 93 | await asyncio.gather( 94 | *[p.step() for p in super_particles if not p.done_stepping()] 95 | ) 96 | 97 | # Use optimal resampling to resample 98 | W = np.array([p.weight for p in super_particles]) 99 | W_tot = logsumexp(W) 100 | W_normalized = softmax(W) 101 | det_indices, stoch_indices, c = resample_optimal(W_normalized, n_particles) 102 | particles = [ 103 | super_particles[i] for i in np.concatenate((det_indices, stoch_indices)) 104 | ] 105 | # For deterministic particles: w = w * N/N' 106 | for i in det_indices: 107 | super_particles[i].weight += np.log(n_particles) - np.log(n_total) 108 | # For stochastic particles: w = 1/c * total sum(stoch weights) / num_stoch = sum(stoch weights / total) / num_stoch * total * N/M 109 | for i in stoch_indices: 110 | super_particles[i].weight = ( 111 | W_tot - np.log(c) + np.log(n_particles) - np.log(n_total) 112 | ) 113 | 114 | # Return the particles 115 | return particles 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLaMPPL 2 | 3 | [![docs](https://github.com/genlm/llamppl/actions/workflows/docs.yml/badge.svg)](https://genlm.github.io/llamppl) 4 | [![Tests](https://github.com/genlm/llamppl/actions/workflows/tests.yml/badge.svg)](https://github.com/genlm/llamppl/actions/workflows/tests.yml) 5 | [![codecov](https://codecov.io/gh/genlm/llamppl/graph/badge.svg?token=pgVQBiqCuM)](https://codecov.io/gh/genlm/llamppl) 6 | 7 | 8 | LLaMPPL is a research prototype for language model probabilistic programming: specifying language generation tasks by writing probabilistic programs that combine calls to LLMs, symbolic program logic, and probabilistic conditioning. To solve these tasks, LLaMPPL uses a specialized sequential Monte Carlo inference algorithm. This technique, SMC steering, is described in [our recent workshop abstract](https://arxiv.org/abs/2306.03081). 9 | 10 | This library was formerly known as `hfppl`. 11 | 12 | ## Installation 13 | 14 | If you just want to try out LLaMPPL, check out our [demo notebook on Colab](https://colab.research.google.com/drive/1uJEC-U8dcwsTWccCDGVexpgXexzZ642n?usp=sharing), which performs a simple constrained generation task using GPT-2. (Larger models may require more RAM or GPU resources than Colab's free version provides.) 
15 | 16 | To get started on your own machine, you can install this library from PyPI: 17 | 18 | ``` 19 | pip install llamppl 20 | ``` 21 | 22 | For faster inference on Apple Silicon devices, you can install with MLX backend: 23 | 24 | ```bash 25 | pip install llamppl[mlx] 26 | ``` 27 | 28 | ### Local installation 29 | 30 | For local development, clone this repository and run `pip install -e ".[dev,examples]"` to install `llamppl` and its development dependencies. 31 | 32 | ``` 33 | git clone https://github.com/genlm/llamppl 34 | cd llamppl 35 | pip install -e ".[dev,examples]" 36 | ``` 37 | 38 | Then, try running an example. Note that this will cause the weights of a HuggingFace model to be downloaded. 39 | 40 | ``` 41 | python examples/hard_constraints.py 42 | ``` 43 | 44 | If everything is working, you should see the model generate political news using words that are at most five letters long (e.g., "Dr. Jill Biden may still be a year away from the White House but she is set to make her first trip to the U.N. today."). 45 | 46 | ## Modeling with LLaMPPL 47 | 48 | A LLaMPPL program is a subclass of the `llamppl.Model` class. 49 | 50 | ```python 51 | from llamppl import Model, LMContext, CachedCausalLM 52 | 53 | # A LLaMPPL model subclasses the Model class 54 | class MyModel(Model): 55 | 56 | # The __init__ method is used to process arguments 57 | # and initialize instance variables. 58 | def __init__(self, lm, prompt, forbidden_letter): 59 | super().__init__() 60 | 61 | # A stateful context object for the LLM, initialized with the prompt 62 | self.context = LMContext(lm, prompt) 63 | self.eos_token = lm.tokenizer.eos_token_id 64 | 65 | # The forbidden letter 66 | self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab) 67 | if forbidden_letter in v) 68 | 69 | # The step method is used to perform a single 'step' of generation. 70 | # This might be a single token, a single phrase, or any other division. 71 | # Here, we generate one token at a time. 72 | async def step(self): 73 | # Condition on the next token *not* being a forbidden token. 74 | await self.observe(self.context.mask_dist(self.forbidden_tokens), False) 75 | 76 | # Sample the next token from the LLM -- automatically extends `self.context`. 77 | token = await self.sample(self.context.next_token()) 78 | 79 | # Check for EOS or end of sentence 80 | if token.token_id == self.eos_token or str(token) in ['.', '!', '?']: 81 | # Finish generation 82 | self.finish() 83 | 84 | # To improve performance, a hint that `self.forbidden_tokens` is immutable 85 | def immutable_properties(self): 86 | return set(['forbidden_tokens']) 87 | ``` 88 | 89 | The Model class provides a number of useful methods for specifying a LLaMPPL program: 90 | 91 | * `self.sample(dist[, proposal])` samples from the given distribution. Providing a proposal does not modify the task description, but can improve inference. Here, for example, we use a proposal that pre-emptively avoids the forbidden letter. 92 | * `self.condition(cond)` conditions on the given Boolean expression. 93 | * `self.finish()` indicates that generation is complete. 94 | * `self.observe(dist, obs)` performs a form of 'soft conditioning' on the given distribution. It is equivalent to (but more efficient than) sampling a value `v` from `dist` and then immediately running `condition(v == obs)`. 
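As a rough illustration of the last point, the constraint in `MyModel.step` above could also have been written by sampling a token first and conditioning afterwards. The two snippets below, which reuse `self.context` and `self.forbidden_tokens` from `MyModel` and would live inside an async `step` method, describe the same generation task:

```python
# Sample a token freely, then reject the particle if it contains the forbidden letter.
token = await self.sample(self.context.next_token())
self.condition(token.token_id not in self.forbidden_tokens)

# Equivalent, but more efficient (this is what MyModel.step does): observe that the
# next token avoids the forbidden set *before* sampling it, so the sampled token is
# guaranteed to satisfy the constraint.
await self.observe(self.context.mask_dist(self.forbidden_tokens), False)
token = await self.sample(self.context.next_token())
```

In SMC terms, the `observe` version multiplies each particle's weight by the probability that the constraint holds, while the `condition` version zeroes out the weight of any particle that happens to violate it.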
95 | 96 | To run inference, we use the `smc_steer` or `smc_standard` methods: 97 | 98 | ```python 99 | import asyncio 100 | from llamppl import smc_steer 101 | 102 | # Initialize the language model 103 | lm = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") 104 | 105 | # Create a model instance 106 | model = MyModel(lm, "The weather today is expected to be", "e") 107 | 108 | # Run inference 109 | particles = asyncio.run(smc_steer(model, 5, 3)) # number of particles N, and beam factor K 110 | ``` 111 | 112 | Sample output: 113 | 114 | ``` 115 | sunny. 116 | sunny and cool. 117 | 34° (81°F) in Chicago with winds at 5mph. 118 | 34° (81°F) in Chicago with winds at 2-9 mph. 119 | hot and humid with a possibility of rain, which is not uncommon for this part of Mississippi. 120 | ``` 121 | 122 | Further documentation can be found at https://genlm.github.io/llamppl. 123 | -------------------------------------------------------------------------------- /examples/haiku.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import nltk 4 | 5 | from llamppl import CachedCausalLM 6 | from llamppl import LMContext 7 | from llamppl import Model 8 | from llamppl import sample_word 9 | from llamppl import smc_standard 10 | 11 | # download the CMU pronunciation dictionary (if we haven't already) 12 | nltk.download("cmudict") 13 | 14 | # Load the CMU pronunciation dictionary and use it for syllable counting 15 | from nltk.corpus import cmudict 16 | 17 | CMUDICT = cmudict.dict() 18 | 19 | 20 | def count_syllables(word, unknown_word_syllables=100): 21 | # Use the dictionary to get the list of possible phonetic representations for the word 22 | phonetic_transcriptions = CMUDICT.get(word.strip().lower(), []) 23 | 24 | # Count the number of syllables based on the number of phonetic transcriptions 25 | syllable_count = min( 26 | [ 27 | len([ph for ph in transcription if ph[-1].isdigit()]) 28 | for transcription in phonetic_transcriptions 29 | ], 30 | default=unknown_word_syllables, 31 | ) 32 | 33 | return syllable_count 34 | 35 | 36 | # Example poems for the prompt. 37 | # Authors: 38 | # - Amy Lowell 39 | # - Sonia Sanchez 40 | # - Katsushika Hokusai 41 | # - Matsuo Basho 42 | # Note that not all of these follow the syllabic constraints of a Haiku; the goal is 43 | # to encode a certain 'poetic style' but to leave the syllabic constraints to be enforced 44 | # by the probabilistic program (enabling generalization to other syllabic constraints). 45 | EXAMPLE_POEMS = """Example poems. Note how they tend to end on a somewhat surprising or otherwise satisfying note, and are not repetitive at the end. 46 | 47 | 1. "Portrait" 48 | Sweet smell of wet flowers 49 | Over an evening garden. 50 | Your portrait, perhaps? 51 | 52 | 2. "River of Love" 53 | love between us is 54 | speech and breath. loving you is 55 | a long river running. 56 | 57 | 3. "Practice" 58 | I write, erase, rewrite 59 | Erase again, and then 60 | A poppy blooms. 61 | 62 | 4. 
"Caterpillar" 63 | A caterpillar, 64 | this deep in fall, 65 | still not a butterfly.""" 66 | 67 | 68 | # LLaMPPL model 69 | class Haiku(Model): 70 | def __init__(self, LLM, prompt, syllable_pattern=[5, 7, 5]): 71 | super().__init__() 72 | self.context = LMContext(LLM, prompt) 73 | self.syllable_pattern = syllable_pattern 74 | self.previous_string = str(self.context) 75 | self.newline_token = LLM.str_vocab.index("\n") 76 | self.eos_token = LLM.tokenizer.eos_token_id 77 | 78 | async def step(self): 79 | self.previous_string = str(self.context) 80 | 81 | # Get the number of syllables required in the next line 82 | syllables_remaining = self.syllable_pattern.pop(0) 83 | 84 | # Loop to sample words until this line is over 85 | while syllables_remaining > 0: 86 | # Sample a word 87 | word, punctuation = await self.call(sample_word(self.context)) 88 | 89 | # Subtract syllables from the remaining count 90 | syllables_remaining -= count_syllables(word) 91 | 92 | # Reject if we overshot 93 | self.condition(syllables_remaining == 0) 94 | 95 | # If there are no more lines, finish 96 | if not self.syllable_pattern: 97 | await self.observe(self.context.next_token(), self.eos_token) 98 | self.finish() 99 | return 100 | 101 | # Otherwise, observe a line break 102 | await self.observe(self.context.next_token(), self.newline_token) 103 | 104 | # Print current result 105 | print(str(self.context)) 106 | 107 | def string_for_serialization(self): 108 | # Replace newlines with slashes in str(self.context) 109 | s = ( 110 | self.previous_string 111 | + "<<<" 112 | + str(self.context)[len(self.previous_string) :] 113 | + ">>>" 114 | ) 115 | return s.replace("\n", "/") 116 | 117 | 118 | async def run_example( 119 | LLM, poem_title, syllable_pattern=[5, 7, 5], n_particles=20, ess_threshold=0.5 120 | ): 121 | # Construct prompt 122 | prompt = f"""{EXAMPLE_POEMS} 123 | 124 | 5. "{poem_title}" 125 | """ 126 | 127 | # Cache the key value vectors for the prompt 128 | LLM.cache_kv(LLM.tokenizer.encode(prompt)) 129 | 130 | # Initialize the Model 131 | haiku_model = Haiku(LLM, prompt, syllable_pattern) 132 | 133 | # Run inference 134 | particles = await smc_standard( 135 | haiku_model, n_particles, ess_threshold, "html", "results/haiku.json" 136 | ) 137 | 138 | return particles 139 | 140 | 141 | def main(): 142 | # Load the language model. 143 | # Mistral is an open model; to use a model with restricted access, like LLaMA 3, 144 | # authenticate using the Huggingface CLI. 
145 | LLM = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B") 146 | # LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") 147 | 148 | # Set batch size if using HuggingFace backend 149 | if LLM.backend == "hf": 150 | LLM.batch_size = 40 151 | 152 | # Get poem title from user 153 | poem_title = input("Enter a title for your Haiku: ") 154 | 155 | syllables_per_line = [5, 7, 5] # [5, 3, 5] for a Lune 156 | 157 | # Run the example 158 | particles = asyncio.run( 159 | run_example(LLM, poem_title, syllable_pattern=syllables_per_line) 160 | ) 161 | 162 | print("--------") 163 | for i, particle in enumerate(particles): 164 | print(f"\nPoem {i} (weight {particle.weight}):") 165 | print(f"{particle.context}") 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /llamppl/chunks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import string 3 | 4 | from .modeling import submodel 5 | 6 | 7 | @submodel 8 | async def sample_word(self, context, max_tokens=5, allow_punctuation=True): 9 | """Sample a word from the `LMContext` object `context`.""" 10 | last_token = ( 11 | context.lm.str_vocab[context.tokens[-1]] if len(context.tokens) > 0 else "" 12 | ) 13 | last_character = last_token[-1] if len(last_token) > 0 else "" 14 | needs_space = last_character not in string.whitespace and last_character not in [ 15 | "-", 16 | "'", 17 | '"', 18 | ] 19 | if needs_space: 20 | starts_word_mask = context.lm.masks.STARTS_NEW_WORD 21 | else: 22 | starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD 23 | 24 | # Force model to start a new word 25 | await self.observe(context.mask_dist(starts_word_mask), True) 26 | 27 | word = "" 28 | num_tokens = 0 29 | while True: 30 | token = await self.sample(context.next_token()) 31 | word += context.lm.str_vocab[token.token_id] 32 | num_tokens += 1 33 | 34 | if num_tokens == max_tokens: 35 | await self.observe( 36 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False 37 | ) 38 | break 39 | 40 | if not ( 41 | await self.sample( 42 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD) 43 | ) 44 | ): 45 | break 46 | 47 | # Sample punctuation, if desired 48 | punctuation = "" 49 | if allow_punctuation and await self.sample( 50 | context.mask_dist(context.lm.masks.PUNCTUATION) 51 | ): 52 | punctuation_token = await self.sample(context.next_token()) 53 | punctuation = context.lm.str_vocab[punctuation_token.token_id] 54 | 55 | return word, punctuation 56 | 57 | 58 | @submodel 59 | async def sample_word_2( 60 | self, 61 | context, 62 | max_chars: int = None, 63 | allow_mid_punctuation: bool = True, 64 | allow_end_punctuation: bool = True, 65 | ): 66 | """Sample a word from the `LMContext` object `context`. 67 | 68 | Unlike sample_word() above, this method allows for character-level control over the length of the word. 69 | It also allows for control over the presence of punctuation in the middle and at the end of the word. 70 | 71 | Args: 72 | max_chars (int): Maximum number of characters in the word. If None, the model will sample a word of any length. 73 | allow_mid_punctuation (bool): If True, the model may sample punctuation in the middle of the word. 74 | allow_end_punctuation (bool): If True, the model may sample punctuation at the end of the word. 75 | 76 | Returns: 77 | Tuple[str, str]: The sampled word and punctuation 78 | """ 79 | # NOTE: Yields control back to the event loop. 
Necessary to allow timeouts to work correctly when this method is called in a loop. 80 | await asyncio.sleep(0) 81 | 82 | # This approach sometimes breaks with max_chars = 1 83 | if max_chars is not None: 84 | assert max_chars > 1 85 | 86 | last_token = ( 87 | context.lm.str_vocab[context.tokens[-1]] if len(context.tokens) > 0 else "" 88 | ) 89 | last_character = last_token[-1] if len(last_token) > 0 else "" 90 | needs_space = last_character not in string.whitespace and last_character not in [ 91 | "-", 92 | "'", 93 | '"', 94 | ] 95 | if needs_space: 96 | starts_word_mask = context.lm.masks.STARTS_NEW_WORD 97 | else: 98 | starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD 99 | 100 | # Force model to start a new word 101 | await self.observe(context.mask_dist(starts_word_mask), True) 102 | 103 | word = "" 104 | while True: 105 | # Force model to sample a token with an appropriate number of characters 106 | if max_chars is not None: 107 | await self.observe( 108 | self.context.mask_dist( 109 | self.context.lm.masks.token_length_mask( 110 | max_chars=max_chars - len(word.strip()) 111 | ) 112 | ), 113 | True, 114 | ) 115 | 116 | token = await self.sample(context.next_token()) 117 | word += context.lm.str_vocab[token.token_id] 118 | 119 | # If we ran out of chars, break 120 | if max_chars is not None and len(word.strip()) >= max_chars: 121 | await self.observe( 122 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False 123 | ) 124 | break 125 | 126 | # If the model wants to end the word, break 127 | if not ( 128 | await self.sample( 129 | context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD) 130 | ) 131 | ): 132 | break 133 | 134 | # Sample punctuation, if desired 135 | mid_punctuation, end_punctuation = "", "" 136 | 137 | mask = set() 138 | if allow_mid_punctuation: 139 | mask = mask | context.lm.masks.MID_PUNCTUATION 140 | if allow_end_punctuation: 141 | mask = mask | context.lm.masks.END_PUNCTUATION 142 | 143 | if mask and await self.sample(context.mask_dist(mask)): 144 | token = await self.sample(context.next_token()) 145 | if token.token_id in context.lm.masks.MID_PUNCTUATION: 146 | mid_punctuation = context.lm.str_vocab[token.token_id] 147 | if token.token_id in context.lm.masks.END_PUNCTUATION: 148 | end_punctuation = context.lm.str_vocab[token.token_id] 149 | 150 | return word, mid_punctuation, end_punctuation 151 | -------------------------------------------------------------------------------- /examples/grammar_constraint.py: -------------------------------------------------------------------------------- 1 | """SMC Steering with Grammar Constraints 2 | 3 | Author: Gabriel Grand (grandg@mit.edu) 4 | 5 | This example illustrates grammar-constrained inference with SMC Steering. 6 | `GrammarConstrainedSMC` takes as input a grammar in Lark format. 7 | We use the Synchromesh (Poesia et al., 2022) to align the grammar with the 8 | language model vocabulary. 
9 | 10 | Requires synchromesh (github.com/kanishkg/synchromesh) 11 | """ 12 | 13 | import asyncio 14 | 15 | from synchromesh.completion_engine import LarkCompletionEngine 16 | from synchromesh.synchromesh import StreamingCSD 17 | 18 | from llamppl.distributions import LMContext 19 | from llamppl.inference import smc_standard 20 | from llamppl.llms import CachedCausalLM 21 | from llamppl.modeling import Model 22 | 23 | 24 | class GrammarConstrainedSMC(Model): 25 | def __init__( 26 | self, 27 | lm: CachedCausalLM, 28 | grammar: str, 29 | start_rule: str, 30 | prompt: str = None, 31 | allow_ws: bool = False, 32 | max_tokens: int = 32, 33 | verbose: bool = False, 34 | ): 35 | super().__init__() 36 | self.lm = lm 37 | self.grammar = grammar 38 | self.context = LMContext(lm, prompt) 39 | self.vocab = self.lm.str_vocab 40 | self.eos_token_id = self.lm.tokenizer.eos_token_id 41 | 42 | self.comp_engine = LarkCompletionEngine( 43 | grammar, start_token=start_rule, allow_ws=allow_ws 44 | ) 45 | self.csd = StreamingCSD( 46 | completion_engine=self.comp_engine, 47 | lm_vocabulary=self.vocab, 48 | enforce_token_maximality=False, 49 | ) 50 | 51 | self.max_tokens = max_tokens 52 | self.n_tokens = 0 53 | 54 | self.verbose = verbose 55 | 56 | async def step(self): 57 | # Get valid tokens for next step 58 | valid_token_ids = self.csd.get_valid_tokens() 59 | 60 | # If generation is a complete derivation, allow the end-of-string token 61 | if self.csd.is_complete(): 62 | valid_token_ids += [self.eos_token_id] 63 | 64 | # If no valid next tokens, reject and terminate 65 | if len(valid_token_ids) == 0: 66 | self.condition(False) 67 | return 68 | 69 | # Sample a token from the valid tokens 70 | await self.observe(self.context.mask_dist(set(valid_token_ids)), True) 71 | token = await self.sample(self.context.next_token()) 72 | 73 | # If the token is the end-of-string token, accept and terminate 74 | if token.token_id == self.eos_token_id: 75 | self.finish() 76 | return 77 | 78 | # Feed the token to StreamingCSD 79 | self.csd.feed_prediction(token.token_id) 80 | self.n_tokens += 1 81 | 82 | if self.verbose: 83 | print(str(self.context)) 84 | 85 | # Max tokens reached 86 | if self.n_tokens >= self.max_tokens: 87 | self.condition(False) 88 | self.finish() 89 | 90 | def immutable_properties(self): 91 | return set( 92 | [ 93 | "grammar", 94 | "max_tokens", 95 | "verbose", 96 | ] 97 | ) 98 | 99 | 100 | EXAMPLE_PROMPT = """Paraphrase the following sentences 101 | Human:who teaches CSE101? 102 | Bot:instructor of CSE101 103 | Human:how many students can enroll in PSY456? 104 | Bot:capacity of PSY456 105 | Human:at what school is BIO433 taught? 106 | Bot:""" 107 | 108 | EXAMPLE_GRAMMAR = r""" 109 | ?start: " "? 
function " of " dept code 110 | function: "instructor" | "students" | "capacity" | "department" | "school" | "college" 111 | dept: /[A-Z]{3}/ 112 | code: /[0-9]{3}/ 113 | """ 114 | 115 | 116 | async def run_generation( 117 | model: str, 118 | grammar: str, 119 | start_rule: str, 120 | prompt: str = None, 121 | allow_ws: bool = False, 122 | n_particles: int = 5, 123 | max_tokens: int = 32, 124 | verbose: bool = False, 125 | ): 126 | LLM = CachedCausalLM.from_pretrained(args.model) 127 | if LLM.backend == "hf": 128 | LLM.batch_size = args.batch_size 129 | model = GrammarConstrainedSMC( 130 | lm=LLM, 131 | grammar=grammar, 132 | start_rule=start_rule, 133 | prompt=prompt, 134 | max_tokens=max_tokens, 135 | allow_ws=allow_ws, 136 | verbose=verbose, 137 | ) 138 | particles = await smc_standard(model, n_particles=n_particles) 139 | particles_sorted = sorted(particles, key=lambda p: p.weight, reverse=True) 140 | print([(p.weight, str(p.context)) for p in particles_sorted]) 141 | 142 | 143 | if __name__ == "__main__": 144 | import argparse 145 | 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument( 148 | "--model", 149 | type=str, 150 | default="codellama/CodeLlama-7b-hf", 151 | help="Name of the HuggingFace model to use", 152 | ) 153 | parser.add_argument( 154 | "--grammar", 155 | type=str, 156 | default=None, 157 | help="Path to the grammar file", 158 | ) 159 | parser.add_argument( 160 | "--start-rule", 161 | type=str, 162 | default="start", 163 | help="Name of the start rule in the grammar", 164 | ) 165 | parser.add_argument( 166 | "--prompt", 167 | type=str, 168 | default=None, 169 | help="Prompt to start generation from", 170 | ) 171 | parser.add_argument( 172 | "--n-particles", 173 | type=int, 174 | default=5, 175 | help="Number of particles to use in SMC", 176 | ) 177 | parser.add_argument( 178 | "--max-tokens", 179 | type=int, 180 | default=32, 181 | help="Maximum number of tokens to generate", 182 | ) 183 | parser.add_argument( 184 | "--allow-ws", 185 | action="store_true", 186 | help="Allow whitespace", 187 | ) 188 | parser.add_argument( 189 | "--verbose", 190 | action="store_true", 191 | help="Print intermediate generations", 192 | ) 193 | args = parser.parse_args() 194 | 195 | if args.grammar is not None: 196 | # Load the grammar 197 | with open(args.grammar, "r") as f: 198 | grammar = f.read() 199 | else: 200 | grammar = EXAMPLE_GRAMMAR 201 | 202 | prompt = args.prompt or EXAMPLE_PROMPT 203 | 204 | asyncio.run( 205 | run_generation( 206 | model=args.model, 207 | grammar=grammar, 208 | start_rule=args.start_rule, 209 | prompt=prompt, 210 | n_particles=args.n_particles, 211 | max_tokens=args.max_tokens, 212 | allow_ws=args.allow_ws, 213 | verbose=args.verbose, 214 | ) 215 | ) 216 | -------------------------------------------------------------------------------- /llamppl/modeling.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | class SubModel: 5 | def __init__(self): 6 | self.parent = None 7 | 8 | async def run_with_parent(self, parent): 9 | old_parent = self.parent 10 | self.parent = parent 11 | val = await self.forward() 12 | self.parent = old_parent 13 | return val 14 | 15 | async def forward(self): 16 | raise NotImplementedError( 17 | "SubModel.forward() must be implemented by subclasses" 18 | ) 19 | 20 | async def sample(self, dist, proposal=None): 21 | return await self.parent.sample(dist, proposal) 22 | 23 | async def observe(self, dist, x): 24 | return await self.parent.observe(dist, x) 25 | 26 | async 
def intervene(self, dist, x): 27 | return await self.parent.intervene(dist, x) 28 | 29 | def condition(self, b): 30 | return self.parent.condition(b) 31 | 32 | def score(self, score): 33 | return self.parent.score(score) 34 | 35 | def twist(self, amt): 36 | return self.parent.twist(amt) 37 | 38 | async def call(self, submodel): 39 | return await submodel.run_with_parent(self.parent) 40 | 41 | 42 | # For use as a decorator 43 | import functools 44 | 45 | 46 | def submodel(f): 47 | """Decorator to create a SubModel implementation from an async function. 48 | 49 | For example: 50 | 51 | ```python 52 | @submodel 53 | async def sample_two_tokens(self, context): 54 | token1 = await self.sample(context.next_token()) 55 | token2 = await self.sample(context.next_token()) 56 | return token1, token2 57 | ``` 58 | 59 | This SubModel can then be used from another model or submodel, using the syntax `await self.call(sample_two_tokens(context))`. 60 | """ 61 | 62 | @functools.wraps(f, updated=()) # unclear if this is the best way to do it 63 | class SubModelImpl(SubModel): 64 | def __init__(self, *args, **kwargs): 65 | super().__init__() 66 | self.args = args 67 | self.kwargs = kwargs 68 | 69 | async def forward(self): 70 | return await f(self, *self.args, **self.kwargs) 71 | 72 | return SubModelImpl 73 | 74 | 75 | class Model: 76 | """Base class for all LLaMPPL models. 77 | 78 | Your models should subclass this class. Minimally, you should provide an `__init__` method 79 | that calls `super().__init__(self)`, and a `step` method. 80 | """ 81 | 82 | def __init__(self): 83 | self.weight = 0.0 84 | self.finished = False 85 | self.mode = "sample" 86 | self.beam_idx = 0 87 | self.force_eos = False 88 | self.twist_amount = 0.0 89 | 90 | def reset(self): 91 | self.weight = 0.0 92 | self.finished = False 93 | self.mode = "sample" 94 | self.beam_idx = 0 95 | self.force_eos = False 96 | self.twist_amount = 0.0 97 | 98 | def immutable_properties(self): 99 | """Return a `set[str]` of properties that LLaMPPL may assume do not change during execution of `step`. 100 | This set is empty by default but can be overridden by subclasses to speed up inference. 101 | 102 | Returns: 103 | properties (set[str]): a set of immutable property names""" 104 | return set() 105 | 106 | def __deepcopy__(self, memo): 107 | cpy = type(self).__new__(type(self)) 108 | immutable = self.immutable_properties() 109 | 110 | for k, v in self.__dict__.items(): 111 | if k in immutable: 112 | setattr(cpy, k, v) 113 | else: 114 | setattr(cpy, k, copy.deepcopy(v, memo)) 115 | 116 | return cpy 117 | 118 | def twist(self, amt): 119 | """Multiply this particle's weight by `exp(amt)`, but divide it back out before the next `step`. 120 | 121 | Use this method to provide heuristic guidance about whether a particle is "on the right track" 122 | without changing the ultimate target distribution. 123 | 124 | Args: 125 | amt: the logarithm of the amount by which to (temporarily) multiply this particle's weight. 126 | """ 127 | self.twist_amount += amt 128 | self.score(amt) 129 | 130 | def untwist(self): 131 | self.score(-self.twist_amount) 132 | self.twist_amount = 0.0 133 | 134 | def finish(self): 135 | self.untwist() 136 | self.finished = True 137 | 138 | def done_stepping(self): 139 | return self.finished 140 | 141 | async def step(self): 142 | """Defines the computation performed in each step of the model. 
143 | 144 | All subclasses should override this method.""" 145 | 146 | if not self.done_stepping(): 147 | raise NotImplementedError("Model.step() must be implemented by subclasses") 148 | 149 | def __str__(self): 150 | return "Particle" 151 | 152 | async def start(self): 153 | pass 154 | 155 | def score(self, score): 156 | """Multiply this particle's weight by `exp(score)`. 157 | 158 | The `score` method is a low-level way to change the target distribution. 159 | For many use cases, it is sufficient to use `sample`, `observe`, `condition`, 160 | and `twist`, all of which are implemented in terms of `score`. 161 | 162 | Args: 163 | score: logarithm of the amount by which the particle's weight should be multiplied. 164 | """ 165 | self.weight += score 166 | 167 | def condition(self, b): 168 | """Constrain a given Boolean expression to be `True`. 169 | 170 | If the condition is False, the particle's weight is set to zero and `self.finish()` 171 | is called, so that no further `step` calls are made. 172 | 173 | Args: 174 | b: the Boolean expression whose value is constrained to be True. 175 | """ 176 | if not b: 177 | self.score(float("-inf")) 178 | self.finish() 179 | 180 | async def intervene(self, dist, x): 181 | """Force the distribution to take on the value `x`, but do not _condition_ on this result. 182 | 183 | This is useful primarily with distributions that have side effects (e.g., modifying some state). 184 | For example, a model with the code 185 | 186 | ```python 187 | token_1 = await self.sample(self.stateful_lm.next_token()) 188 | await self.observe(self.stateful_lm.next_token(), token_2) 189 | ``` 190 | 191 | encodes a posterior inference problem, to find `token_1` values that *likely preceded* `token_2`. By contrast, 192 | 193 | ```python 194 | token_1 = await self.sample(stateful_lm.next_token()) 195 | await self.intervene(self.stateful_lm.next_token(), token_2) 196 | ``` 197 | 198 | encodes a much easier task: freely generate `token_1` and then force-feed `token_2` as the following token. 199 | 200 | Args: 201 | dist (llamppl.distributions.distribution.Distribution): the distribution on which to intervene. 202 | x: the value to intervene with. 203 | """ 204 | await dist.log_prob(x) 205 | return x 206 | 207 | async def observe(self, dist, x): 208 | """Condition the model on the value `x` being sampled from the distribution `dist`. 209 | 210 | For discrete distributions `dist`, `await self.observe(dist, x)` specifies the same constraint as 211 | ``` 212 | val = await self.sample(dist) 213 | self.condition(val == x) 214 | ``` 215 | but can be much more efficient. 216 | 217 | Args: 218 | dist: a `Distribution` object from which to observe 219 | x: the value observed from `dist` 220 | """ 221 | p = await dist.log_prob(x) 222 | self.score(p) 223 | return x 224 | 225 | async def sample(self, dist, proposal=None): 226 | """Extend the model with a sample from a given `Distribution`, with support for autobatching. 227 | If specified, the Distribution `proposal` is used during inference to generate informed hypotheses. 228 | 229 | Args: 230 | dist: the `Distribution` object from which to sample 231 | proposal: if provided, inference algorithms will use this `Distribution` object to generate proposed samples, rather than `dist`. 232 | However, importance weights will be adjusted so that the target posterior is independent of the proposal. 233 | 234 | Returns: 235 | value: the value sampled from the distribution. 
236 | """ 237 | # Special logic for beam search 238 | # if self.mode == "beam": 239 | # d = dist if proposal is None else proposal 240 | # x, w = d.argmax(self.beam_idx) 241 | # if proposal is not None: 242 | # self.score(dist.log_prob(x)) 243 | # else: 244 | # self.score(w) 245 | # return x 246 | 247 | if proposal is None: 248 | x, _ = await dist.sample() 249 | return x 250 | else: 251 | x, q = await proposal.sample() 252 | p = await dist.log_prob(x) 253 | self.score(p - q) 254 | return x 255 | 256 | async def call(self, submodel): 257 | return await submodel.run_with_parent(self) 258 | 259 | def string_for_serialization(self): 260 | """Return a string representation of the particle for serialization purposes. 261 | 262 | Returns: 263 | str: a string representation of the particle. 264 | """ 265 | return str(self) 266 | -------------------------------------------------------------------------------- /llamppl/distributions/lmcontext.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | from ..llms import Token 6 | from ..util import log_softmax 7 | from ..util import logsumexp 8 | from .distribution import Distribution 9 | 10 | 11 | class LMNextToken(Distribution): 12 | def __init__(self, ctx): 13 | self.ctx = ctx 14 | 15 | async def log_prob(self, x): 16 | if isinstance(x, Token): 17 | x = x.token_id 18 | 19 | lp = self.ctx.next_token_logprobs[x] 20 | self.ctx.tokens.append(x) 21 | updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) 22 | self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp) 23 | self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS 24 | 25 | return lp 26 | 27 | async def sample(self): 28 | probs = np.exp(self.ctx.next_token_logprobs) 29 | probs /= np.sum(probs) # Renormalize to fix floating point errors 30 | token_id = np.random.choice(len(probs), p=(probs)) 31 | self.ctx.tokens.append(token_id) 32 | logprob = self.ctx.next_token_logprobs[token_id] 33 | 34 | # Reset mask and update logprobs 35 | self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS 36 | updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) 37 | self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp) 38 | 39 | t = Token( 40 | self.ctx.lm, token_id, self.ctx.lm.tokenizer.convert_ids_to_tokens(token_id) 41 | ) 42 | return t, logprob 43 | 44 | 45 | class LMTokenMask(Distribution): 46 | def __init__(self, ctx, mask): 47 | self.ctx = ctx 48 | self.mask = mask 49 | 50 | async def sample(self): 51 | newly_bad_tokens = [i for i in self.ctx.model_mask if i not in self.mask] 52 | good_tokens = [i for i in self.ctx.model_mask if i in self.mask] 53 | logprob_no_mask = logsumexp(self.ctx.next_token_logprobs[newly_bad_tokens]) 54 | if logprob_no_mask > 0: 55 | logprob_yes_mask = float("-inf") 56 | else: 57 | # When logprob_no_mask is very close to 0.0, np.log1p can raise a "divide by zero" 58 | # warning before returning -inf. We suppress this warning, because returning -inf 59 | # is the desired behavior (the LLM places no mass on 'yes'). 
60 | with np.errstate(divide="ignore"): 61 | logprob_yes_mask = np.log1p(-np.exp(logprob_no_mask)) 62 | decide_no_mask = np.random.rand() < np.exp(logprob_no_mask) 63 | if decide_no_mask: 64 | self.ctx.model_mask = self.ctx.model_mask - self.mask 65 | self.ctx.next_token_logprobs[good_tokens] = float("-inf") 66 | self.ctx.next_token_logprobs -= logprob_no_mask 67 | return False, logprob_no_mask 68 | else: 69 | self.ctx.model_mask = self.ctx.model_mask.intersection(self.mask) 70 | self.ctx.next_token_logprobs[newly_bad_tokens] = float("-inf") 71 | self.ctx.next_token_logprobs -= logprob_yes_mask 72 | return True, logprob_yes_mask 73 | 74 | async def log_prob(self, v): 75 | good_tokens = ( 76 | self.ctx.model_mask.intersection(self.mask) 77 | if v 78 | else self.ctx.model_mask - self.mask 79 | ) 80 | if len(good_tokens) == 0: 81 | # If there are no good tokens, the log probability of v under the mask is -inf 82 | # However, since this method updates the model_mask as a side-effect, 83 | # this will put the context in an invalid state, so we instead raise an exception. 84 | raise NullMask( 85 | "Unable to compute log probability of mask that rules out all tokens." 86 | ) 87 | else: 88 | logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) 89 | 90 | bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] 91 | self.ctx.next_token_logprobs[bad_tokens] = float("-inf") 92 | self.ctx.next_token_logprobs -= logprob_good 93 | self.ctx.model_mask = good_tokens 94 | return logprob_good 95 | 96 | 97 | class NullMask(Exception): 98 | pass 99 | 100 | 101 | class LMContext: 102 | """Represents a generation-in-progress from a language model. 103 | 104 | The state tracks two pieces of information: 105 | 106 | * A sequence of tokens — the ever-growing context for the language model. 107 | * A *current mask* — a set of tokens that have not yet been ruled out as the next token. 108 | 109 | Storing a mask enables _sub-token_ generation: models can use `LMContext` to sample 110 | the next token in _stages_, first deciding, e.g., whether to use an upper-case or lower-case 111 | first letter, and only later deciding which upper-case or lower-case token to generate. 112 | 113 | The state of a `LMContext` can be advanced in two ways: 114 | 115 | 1. Sampling, observing, or intervening the `next_token()` distribution. This causes a token 116 | to be added to the growing sequence of tokens. Supports auto-batching. 117 | 2. Sampling, observing, or intervening the `mask_dist(mask)` distribution for a given mask (set of 118 | token ids). This changes the current mask. 119 | 120 | Attributes: 121 | lm (llamppl.llms.CachedCausalLM): the language model for which this is a context 122 | tokens (list[int]): the underlying sequence of tokens, including prompt, in this context 123 | next_token_logprobs (numpy.array): numpy array holding the log probabilities for the next token. Unlike the log probabilities reported by `CachedCausalLM.next_token_logprobs`, these probabilities are rescaled for this `LMContext`'s temperature parameter, and for any active masks. This vector is managed by the `LMContext` object internally; do not mutate. 124 | temp (float): temeprature for next-token distribution (0 < temp < float('inf')) 125 | model_mask (set[int]): set of tokens that have not been ruled out as the next token. This mask is managed by the `LMContext` object internally; do not mutate. 
126 | show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`. 127 | """ 128 | 129 | def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True): 130 | """Create a new `LMContext` with a given prompt and temperature. 131 | 132 | Args: 133 | lm (llamppl.llms.CachedCausalLM): the language model for which this is a context. 134 | prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`. 135 | temp (float): temeprature for next-token distribution (0 < temp < float('inf')) 136 | 137 | Note: 138 | For async initialization of LMContext, use LMContext.create(). 139 | """ 140 | self._init_common(lm, prompt, temp, show_prompt, show_eos) 141 | self.next_token_logprobs = log_softmax( 142 | lm.next_token_logprobs_unbatched(self.tokens) / temp 143 | ) 144 | 145 | @classmethod 146 | async def create(cls, lm, prompt, temp=1.0, show_prompt=False, show_eos=True): 147 | """Asynchronously create a new `LMContext` with a given prompt and temperature.""" 148 | self = cls.__new__(cls) 149 | self._init_common(lm, prompt, temp, show_prompt, show_eos) 150 | logprobs = await lm.next_token_logprobs(self.tokens) 151 | self.next_token_logprobs = log_softmax(logprobs / temp) 152 | return self 153 | 154 | def _init_common(self, lm, prompt, temp, show_prompt, show_eos): 155 | """Initialize common attributes shared between __init__ and create.""" 156 | self.lm = lm 157 | self.tokens = lm.tokenizer.encode(prompt) 158 | self.temp = temp 159 | self.model_mask = lm.masks.ALL_TOKENS 160 | self.prompt_string_length = len(lm.tokenizer.decode(self.tokens)) 161 | self.prompt_token_count = len(self.tokens) 162 | self.show_prompt = show_prompt 163 | self.show_eos = show_eos 164 | 165 | def next_token(self): 166 | """Distribution over the next token. 167 | 168 | Sampling or observing from this distribution advances the state of this `LMContext` instance. 169 | """ 170 | return LMNextToken(self) 171 | 172 | def mask_dist(self, mask): 173 | """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs 174 | to the given mask. 175 | 176 | Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that 177 | the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from 178 | the given mask. 179 | 180 | Args: 181 | mask: a `set(int)` specifying which token ids are included within the mask. 
182 | """ 183 | return LMTokenMask(self, mask) 184 | 185 | @property 186 | def token_count(self): 187 | return len(self.tokens) - self.prompt_token_count 188 | 189 | def __str__(self): 190 | full_string = self.lm.tokenizer.decode(self.tokens) 191 | if not self.show_prompt: 192 | full_string = full_string[self.prompt_string_length :] 193 | if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token): 194 | full_string = full_string[: -len(self.lm.tokenizer.eos_token)] 195 | return full_string 196 | 197 | def __deepcopy__(self, memo): 198 | cpy = type(self).__new__(type(self)) 199 | 200 | for k, v in self.__dict__.items(): 201 | if k in set(["lm"]): 202 | setattr(cpy, k, v) 203 | else: 204 | setattr(cpy, k, copy.deepcopy(v, memo)) 205 | 206 | return cpy 207 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /html/smc.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 107 | 108 | 109 | 110 | 111 | 112 |

Sequential Monte Carlo - Visualization
[remaining HTML markup for the visualization page omitted]
125 | 126 | 127 | 128 | 365 | 366 | 367 | 368 | 369 | -------------------------------------------------------------------------------- /llamppl/llms.py: -------------------------------------------------------------------------------- 1 | """Utilities for working with language models.""" 2 | 3 | import string 4 | import warnings 5 | from collections import defaultdict 6 | 7 | import torch 8 | from genlm.backend.llm import AsyncTransformer 9 | from genlm.backend.llm import AsyncVirtualLM 10 | from genlm.backend.llm import MockAsyncLM 11 | from genlm.backend.llm import AsyncMlxLM 12 | 13 | VLLM_AVAILABLE = True 14 | try: 15 | import vllm 16 | except ImportError: 17 | VLLM_AVAILABLE = False 18 | 19 | MLX_AVAILABLE = True 20 | try: 21 | import mlx_lm 22 | except ImportError: 23 | MLX_AVAILABLE = False 24 | 25 | warnings.filterwarnings("once", category=DeprecationWarning) 26 | warnings.filterwarnings("once", category=RuntimeWarning) 27 | 28 | 29 | class Masks: 30 | def __init__(self, lm): 31 | self.ALL_TOKENS = set(range(len(lm.str_vocab))) 32 | self.STARTS_NEW_WORD = set( 33 | i 34 | for (i, v) in enumerate(lm.str_vocab) 35 | if v[0] == " " 36 | and len(v) > 1 37 | and v[1] not in string.whitespace 38 | and v[1] not in string.punctuation 39 | ) 40 | self.CONTINUES_CURRENT_WORD = set( 41 | i 42 | for (i, v) in enumerate(lm.str_vocab) 43 | if all(c in "'" or c.isalpha() for c in v) 44 | ) 45 | self.MID_PUNCTUATION = set( 46 | i for (i, v) in enumerate(lm.str_vocab) if v in (",", ":", ";", "-", '"') 47 | ) 48 | self.END_PUNCTUATION = set( 49 | i for (i, v) in enumerate(lm.str_vocab) if v in (".", "!", "?") 50 | ) 51 | self.PUNCTUATION = self.MID_PUNCTUATION | self.END_PUNCTUATION 52 | self.CONTAINS_WHITESPACE = set( 53 | i 54 | for (i, v) in enumerate(lm.str_vocab) 55 | if any(c in string.whitespace for c in v) 56 | ) 57 | self.EOS = set([lm.tokenizer.eos_token_id]) 58 | 59 | self.precompute_token_lengths(lm) 60 | 61 | def precompute_token_lengths(self, lm): 62 | """Precompute the length of each token. Special tokens are considered to have length 0.""" 63 | self._token_lengths = {i: len(v) for (i, v) in enumerate(lm.str_vocab)} 64 | for i in lm.tokenizer.all_special_ids: 65 | self._token_lengths[i] = 0 66 | 67 | def token_length_mask(self, min: int = None, max: int = None): 68 | if min is None: 69 | min = 0 70 | if max is None: 71 | max = float("inf") 72 | return set( 73 | [i for i, v_len in self._token_lengths.items() if min <= v_len <= max] 74 | ) 75 | 76 | 77 | class TokenSequence: 78 | """A sequence of tokens. 79 | 80 | Supports addition (via `+` or mutating `+=`) with: 81 | 82 | * other `TokenSequence` instances (concatenation) 83 | * individual tokens, represented as integers or `Token` instances 84 | * strings, which are tokenized by `lm.tokenizer` 85 | 86 | Attributes: 87 | lm (llamppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from. 88 | seq (list[llamppl.llms.Token]): the sequence of tokens.""" 89 | 90 | def __init__(self, lm, seq=None): 91 | """Create a `TokenSequence` from a language model and a sequence. 92 | 93 | Args: 94 | lm (llamppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from. 95 | seq (str | list[int]): the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing a bos token. 
96 | """ 97 | self.lm = lm 98 | if seq is None: 99 | self.seq = [lm.tokenizer.bos_token_id] 100 | elif isinstance(seq, str): 101 | self.seq = self.lm.tokenizer.encode(seq) 102 | else: 103 | self.seq = seq 104 | 105 | def __str__(self): 106 | return self.lm.tokenizer.decode(self.seq) 107 | 108 | def __iadd__(self, other): 109 | if isinstance(other, Token): 110 | assert other.lm is self.lm 111 | self.seq.append(other.token_id) 112 | elif isinstance(other, TokenSequence): 113 | assert other.lm is self.lm 114 | self.seq.extend(other.seq) 115 | elif isinstance(other, str): 116 | self.seq.extend(self.lm.tokenizer.encode(other, add_special_tokens=False)) 117 | elif isinstance(other, int): 118 | self.seq.append(other) 119 | else: 120 | raise RuntimeError(f"Addition not supported on {type(other)}") 121 | return self 122 | 123 | def __radd__(self, other): 124 | if isinstance(other, Token): 125 | assert other.lm is self.lm 126 | return TokenSequence(self.lm, [other.token_id, *self.seq]) 127 | elif isinstance(other, TokenSequence): 128 | assert other.lm is self.lm 129 | return TokenSequence(self.lm, other.seq + self.seq) 130 | elif isinstance(other, str): 131 | return TokenSequence( 132 | self.lm, 133 | self.lm.tokenizer.encode(other, add_special_tokens=False) + self.seq, 134 | ) 135 | elif isinstance(other, int): 136 | return TokenSequence(self.lm, [other, *self.seq]) 137 | else: 138 | raise RuntimeError(f"Addition not supported on {type(other)}") 139 | 140 | def __add__(self, other): 141 | s = TokenSequence(self.lm, self.seq) 142 | s += other 143 | return s 144 | 145 | 146 | class Token: 147 | """Class representing a token. 148 | 149 | Attributes: 150 | lm (llamppl.llms.CachedCausalLM): the language model for which this is a Token. 151 | token_id (int): the integer token id (an index into the vocabulary). 152 | token_str (str): a string, which the token represents—equal to `lm.str_vocab[token_id]`. 153 | """ 154 | 155 | def __init__(self, lm, token_id, token_str): 156 | self.lm = lm 157 | self.token_id = token_id 158 | self.token_str = token_str 159 | 160 | # Adding tokens 161 | def __add__(self, other): 162 | s = TokenSequence(self.lm, [self.token_id]) 163 | s += other 164 | return s 165 | 166 | def __radd__(self, other): 167 | s = TokenSequence(self.lm, [self.token_id]) 168 | return other + s 169 | 170 | # Support checking for EOS 171 | def __eq__(self, other): 172 | if isinstance(other, Token): 173 | return self.lm is other.lm and self.token_id == other.token_id 174 | elif isinstance(other, int): 175 | return self.token_id == other 176 | else: 177 | return self.token_str == other 178 | 179 | def __int__(self): 180 | return self.token_id 181 | 182 | def __str__(self): 183 | return self.token_str 184 | 185 | def __repr__(self): 186 | return f"<{self.token_str}|{self.token_id}>" 187 | 188 | 189 | class CachedCausalLM: 190 | """Wrapper around a [`genlm.backend.llm.AsyncLM`](https://genlm.github.io/genlm-backend/reference/genlm/backend/llm/__init__/). 191 | 192 | Attributes: 193 | model (genlm_backend.llm.AsyncLM): The underlying language model (either `AsyncVirtualLM` or `AsyncTransformer`). 194 | str_vocab (list[str]): List mapping token IDs to their string representations. 195 | byte_vocab (list[bytes]): List mapping token IDs to their byte representations. 196 | masks (Masks): Token masks for filtering logits during generation. 197 | """ 198 | 199 | @classmethod 200 | def from_pretrained(cls, model_id, backend=None, **kwargs): 201 | """Create a CachedCausalLM from a HuggingFace model name. 
202 | 203 | This is a convenience method that instantiates the underlying `AsyncLM` from a HuggingFace model name. 204 | 205 | Args: 206 | model_id (str): Name or path of the HuggingFace pretrained model to load. 207 | backend (str, optional): `AsyncLM` backend to use: 208 | - 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage 209 | - 'hf' for an `AsyncTransformer`; ideal for CPU usage 210 | - 'mock' for a `MockAsyncLM`; ideal for testing. 211 | - 'mlx' for an `AsyncMlxLM`; ideal for usage on devices with Apple silicon. 212 | Defaults to 'vllm' if CUDA is available, otherwise 'hf'. 213 | **kwargs: Additional keyword arguments passed to the `AsyncLM` constructor. 214 | See [`AsyncLM` documentation](https://probcomp.github.io/genlm-backend/reference/genlm_backend/llm/__init__/). 215 | 216 | Returns: 217 | CachedCausalLM: The llamppl-compatible interface to the `AsyncLM` model. 218 | """ 219 | backend = backend or ( 220 | "vllm" if (torch.cuda.is_available() and VLLM_AVAILABLE) else "hf" 221 | ) 222 | 223 | if backend == "vllm": 224 | if not VLLM_AVAILABLE: 225 | raise ValueError( 226 | "vLLM backend requested but vLLM is not installed. " 227 | "Please install vLLM with `pip install vllm`." 228 | ) 229 | model_cls = AsyncVirtualLM 230 | elif backend == "hf": 231 | model_cls = AsyncTransformer 232 | elif backend == "mock": 233 | model_cls = MockAsyncLM 234 | elif backend == "mlx": 235 | model_cls = AsyncMlxLM 236 | else: 237 | raise ValueError( 238 | f"Unknown backend: {backend}. Must be one of ['vllm', 'hf', 'mock', 'mlx']" 239 | ) 240 | 241 | # Handle legacy auth_token parameter. The ability to pass in the auth_token should 242 | # be removed in a future version since it is not supported by the vllm backend. 243 | # Users should authenticate with the HuggingFace CLI. 244 | auth_token = kwargs.pop("auth_token", None) 245 | if auth_token: 246 | if backend == "vllm": 247 | raise ValueError( 248 | "Explicitly passing auth_token is not compatible with the vLLM AsyncLM backend. " 249 | "Authenticate using `huggingface-cli login` instead." 250 | ) 251 | 252 | if "hf_opts" not in kwargs: 253 | kwargs["hf_opts"] = {} 254 | kwargs["hf_opts"]["token"] = auth_token 255 | 256 | warnings.warn( 257 | "Passing auth_token directly is deprecated and will be removed in a future version. " 258 | "Please authenticate using `huggingface-cli login` instead.", 259 | DeprecationWarning, 260 | stacklevel=2, 261 | ) 262 | 263 | load_in_8bit = kwargs.pop("load_in_8bit", False) 264 | if load_in_8bit: 265 | if "bitsandbytes_opts" not in kwargs: 266 | kwargs["bitsandbytes_opts"] = {} 267 | kwargs["bitsandbytes_opts"]["load_in_8bit"] = True 268 | 269 | warnings.warn( 270 | "load_in_8bit is deprecated and will be removed in a future version. " 271 | "Please pass `bitsandbytes_opts` instead.", 272 | DeprecationWarning, 273 | stacklevel=2, 274 | ) 275 | 276 | model = model_cls.from_name(model_id, **kwargs) 277 | 278 | return cls(model) 279 | 280 | def __init__(self, model): 281 | """ 282 | Create a `CachedCausalLM` from an `AsyncLM`. 283 | 284 | Args: 285 | model (genlm_backend.llm.AsyncLM): an `AsyncLM` instance. 286 | """ 287 | if isinstance(model, AsyncVirtualLM): 288 | self.backend = "vllm" 289 | elif isinstance(model, AsyncTransformer): 290 | self.backend = "hf" 291 | elif isinstance(model, MockAsyncLM): 292 | self.backend = "mock" 293 | elif isinstance(model, AsyncMlxLM): 294 | self.backend = "mlx" 295 | else: 296 | raise ValueError( 297 | f"Unknown model type: {type(model)}. 
Must be one of [AsyncVirtualLM, AsyncTransformer, MockAsyncLM, AsyncMlxLM]" 298 | ) 299 | 300 | self.model = model 301 | self.tokenizer = model.tokenizer 302 | self.str_vocab = model.str_vocab 303 | self.byte_vocab = model.byte_vocab 304 | self.masks = Masks(self) 305 | 306 | @property 307 | def vocab(self): 308 | """Legacy accessor for string vocabulary. Prefer using `.str_vocab` directly for access to the model's string vocabulary.""" 309 | warnings.warn( 310 | "Accessing .vocab directly is deprecated and will be removed in a future version. Use .str_vocab or .byte_vocab instead.", 311 | DeprecationWarning, 312 | stacklevel=2, 313 | ) 314 | return self.model.str_vocab 315 | 316 | def __deepcopy__(self, memo): 317 | return self 318 | 319 | async def next_token_logprobs(self, token_ids): 320 | """Request log probabilities of the next token. This version is asynchronous and supports auto-batching of concurrent requests; use with `await`. 321 | 322 | Args: 323 | token_ids (list[int]): a list of token ids, representing a prompt to the language model. 324 | 325 | Returns: 326 | logprobs (numpy.array): a numpy array of length `len(str_vocab)` (equivalently `len(byte_vocab)`) with the language model's log (normalized) probabilities for the next token following the prompt. 327 | """ 328 | logprobs = await self.model.next_token_logprobs(token_ids) 329 | return logprobs.float().cpu().numpy() 330 | 331 | def next_token_logprobs_unbatched(self, token_ids): 332 | """Request log probabilities of the next token. Not asynchronous, and does not support auto-batching. 333 | 334 | Args: 335 | token_ids (list[int]): a list of token ids, representing a prompt to the language model. 336 | 337 | Returns: 338 | logprobs (numpy.array): a numpy array of length `len(str_vocab)` (equivalently `len(byte_vocab)`) with the language model's log (normalized) probabilities for the next token following the prompt. 339 | """ 340 | return self.model.next_token_logprobs_sync(token_ids).float().cpu().numpy() 341 | 342 | def clear_cache(self): 343 | """Clear the cache of log probabilities and key/value pairs. 344 | 345 | For HuggingFace backend: Clears both logprob cache and KV cache. 346 | 347 | For vLLM backend: Only clears logprob cache (KV cache is managed internally by vLLM). 348 | """ 349 | self.model.clear_cache() 350 | 351 | def clear_kv_cache(self): 352 | """Clear any key and value vectors from the cache.""" 353 | if self.backend == "hf": 354 | self.model.clear_kv_cache() 355 | elif self.backend == "mlx": 356 | self.model.clear_cache() 357 | elif self.backend == "vllm": 358 | warnings.warn( 359 | "clear_kv_cache() is only supported for the HuggingFace and MLX backends. The KV cache for the vLLM backend is handled internally by vLLM. No operation performed.", 360 | RuntimeWarning, 361 | stacklevel=2, 362 | ) 363 | elif self.backend == "mock": 364 | pass 365 | else: 366 | raise RuntimeError( 367 | f"clear_kv_cache() is not implemented for backend type {type(self.model)}" 368 | ) 369 | 370 | def reset_async_queries(self): 371 | """Clear any pending language model queries from the queue.""" 372 | if self.backend in ["hf", "mlx"]: 373 | self.model.reset_async_queries() 374 | elif self.backend == "vllm": 375 | warnings.warn( 376 | "reset_async_queries() is only supported for the HuggingFace and MLX backends. 
No operation performed.", 377 | RuntimeWarning, 378 | stacklevel=2, 379 | ) 380 | elif self.backend == "mock": 381 | pass 382 | else: 383 | raise RuntimeError( 384 | f"reset_async_queries() is not implemented for backend type {type(self.model)}" 385 | ) 386 | 387 | def cache_kv(self, prompt_tokens): 388 | """Cache the key and value vectors for a prompt. 389 | 390 | Args: 391 | prompt_tokens (list[int]): token ids for the prompt to cache. 392 | """ 393 | if self.backend in ["hf", "mlx"]: 394 | self.model.cache_kv(prompt_tokens) 395 | elif self.backend == "vllm": 396 | warnings.warn( 397 | "cache_kv() is only supported for the HuggingFace and MLX backends. The KV cache for the vLLM backend is handled internally by vLLM. No operation performed.", 398 | RuntimeWarning, 399 | stacklevel=2, 400 | ) 401 | elif self.backend == "mock": 402 | pass 403 | else: 404 | raise RuntimeError( 405 | f"cache_kv() is not implemented for backend type {type(self.model)}" 406 | ) 407 | --------------------------------------------------------------------------------
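As a quick orientation to `CachedCausalLM.from_pretrained` and its backend selection, here is a minimal usage sketch. It assumes the package is importable as `llamppl` and that the placeholder model id "gpt2" can be downloaded; the backend choice merely mirrors, rather than reproduces, the default logic in `from_pretrained`:

    import torch
    from llamppl.llms import CachedCausalLM

    # Mirror the default choice in from_pretrained: 'vllm' needs CUDA and an installed vllm;
    # 'hf' runs anywhere on top of transformers; 'mock' is intended for tests; 'mlx' targets Apple silicon.
    backend = "vllm" if torch.cuda.is_available() else "hf"

    # "gpt2" is a placeholder model id, not one mandated by the library.
    lm = CachedCausalLM.from_pretrained("gpt2", backend=backend)

    print(lm.backend)          # which AsyncLM wrapper was selected
    print(len(lm.str_vocab))   # vocabulary size; the logprob arrays below have this length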
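Continuing the sketch above (same assumed `lm`), the asynchronous scoring path combines `cache_kv` for long prompts with the auto-batched `next_token_logprobs`; the prompt text and the helper name `score_next_token` are illustrative:

    import asyncio
    import numpy as np

    async def score_next_token(lm, prompt_text):
        # Build the prompt ids, prepending BOS as TokenSequence does by default.
        prompt = [lm.tokenizer.bos_token_id, *lm.tokenizer.encode(prompt_text, add_special_tokens=False)]

        # On the 'hf'/'mlx' backends this caches the prompt's key/value vectors so repeated
        # next_token_logprobs calls do not recompute them; on 'vllm' it warns and is a no-op.
        lm.cache_kv(prompt)

        # Awaitable and auto-batched: concurrent requests issued from several coroutines
        # are gathered into batched LLM calls.
        logprobs = await lm.next_token_logprobs(prompt)
        best = int(np.argmax(logprobs))
        return lm.str_vocab[best], float(logprobs[best])

    token_str, logprob = asyncio.run(score_next_token(lm, "An apple a day keeps the"))
    lm.clear_kv_cache()   # drop cached key/value vectors once the prompt is no longer needed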