├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── README.md
├── openlogprobs
│   ├── __init__.py
│   ├── extract.py
│   ├── models.py
│   └── utils.py
├── pyproject.toml
├── requirements.txt
├── setup.py
├── test
│   └── test_logprobs.py
└── vis.png
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Test with PyTest
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.9.18, 3.12.1]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip setuptools wheel
29 | pip install numpy openai scipy transformers
30 | pip install pytest pytest-xdist # Testing packages
31 | pip install -e . # Install openlogprobs
32 | - name: Import openlogprobs
33 | run: |
34 | printf "import openlogprobs\n" | python
35 | - name: Test with pytest
36 | run: |
37 | pytest -vx --dist=loadfile -n auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | build/
3 | dist/
4 | *.egg-info/
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # openlogprobs
2 |
3 | #### 🪄 openlogprobs is a Python API for extracting log-probabilities from language model APIs 🪄
4 |
5 |
6 | ```bash
7 | pip install openlogprobs
8 | ```
9 |
10 |
11 |
12 | 
13 |
14 | 
15 |
16 | Many API-based language model services hide the log-probability outputs of their models. One reason is security: language model outputs can reveal information about their inputs and can be used for efficient model distillation. Another reason is practical: serving ~30,000 floats (or however large the vocabulary is) via an API would be far too much data for a typical API request. So this information is hidden from you.
17 |
18 | However, most APIs also allow a 'logit bias' argument to positively or negatively influence the likelihood of certain tokens in language model output. It turns out, though, that we can use this logit bias on individual tokens to reverse-engineer their log probabilities. We developed an algorithm to do this efficiently, which effectively allows us to extract *full probability vectors* via APIs such as the OpenAI API. For more information, read the below section about the algorithm, or read the code in openlogprobs/extract.py.
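For intuition, here is a minimal sketch of the kind of biased API call the algorithm builds on, using the OpenAI Python client (the prompt, the probed token, and the bias value are arbitrary placeholders):

```python
import os

import openai
import tiktoken

client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
enc = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")
token_id = enc.encode(" pie")[0]  # the token whose probability we want to probe

# Push the token up by 10 logits and check whether it becomes the greedy output.
response = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="i like",
    temperature=0,
    max_tokens=1,
    logit_bias={token_id: 10},
)
print(response.choices[0].text)
```

By varying the bias and observing when the token does or does not win the argmax, we can infer how far its logit sits below the top token's logit.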
19 |
20 |
21 | ## Usage
22 |
23 | ### topk search
24 |
25 | If the API exposes the top-k log-probabilities, we can efficiently extract the next-token probabilities via our 'topk' algorithm:
26 |
27 | ```python
28 | from openlogprobs import OpenAIModel, extract_logprobs
29 | extract_logprobs(OpenAIModel("gpt-3.5-turbo-instruct"), "i like pie", method="topk")
30 | ```
31 |
32 | ### exact solution
33 |
34 | If the API exposes the top-k log-probabilities, we can efficiently extract the next-token probabilities k-at-a-time via our 'exact' algorithm:
35 |
36 | ```python
37 | from openlogprobs import OpenAIModel, extract_logprobs
38 | extract_logprobs(OpenAIModel("gpt-3.5-turbo-instruct"), "i like pie", method="exact", parallel=True)
39 | ```
40 |
41 | This method requires fewer API calls than the top-k algorithm (only 1 call per k tokens).
42 |
43 | ### binary search
44 |
45 | If the API does not expose top-k logprobs, we can still extract the distribution, but it takes more language model calls:
46 |
47 | ```python
48 | from openlogprobs import OpenAIModel, extract_logprobs
49 | extract_logprobs(OpenAIModel("gpt-3.5-turbo-instruct"), "i like pie", method="bisection")
50 | ```
51 |
52 | ### Future work (help wanted!)
53 |
54 | - support multiple logprobs (concurrent binary search)
55 | - estimate costs for various APIs
56 | - support checkpointing
57 |
58 | ## Algorithms
59 |
60 | ### Bisection and top-k
61 |
62 | Our algorithm is essentially a binary search (technically univariate bisection on a continuous variable): we apply different amounts of logit bias to a token until it becomes likely enough to appear in the generation. This lets us estimate the probability of any token relative to the most likely token, and running the search over every token in the vocabulary yields the full probability vector. Note that essentially all model APIs support logit bias; the only requirement for this to work is that the API lets us bias arbitrary tokens in the vocabulary.
63 |
64 | Here's a crude visualization of how our algorithm works for a single token:
65 |
66 | 
67 |
68 | Each API call (purple) brings us successively closer to the true token probability (green).
69 |
70 | ### Exact solution
71 |
72 | Our exact solution algorithm solves directly for the logprobs.
73 | To understand the math, see [this outline](https://mattf1n.github.io/openlogprobs.html).
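Here is a minimal numpy sketch of the identity behind `exact_solve` (a made-up five-token vocabulary stands in for the API, and the biased top-k response is simulated directly):

```python
import numpy as np
from scipy.special import log_softmax, logsumexp

logits = np.array([3.0, 1.0, 0.0, -1.0, -2.0])  # hidden model logits
true_logprobs = log_softmax(logits)

bias = 5.0
queried = [2, 3]  # tokens we bias upward in a single call

# What the API would report: logprobs of the queried tokens under the biased distribution.
biased_logits = logits.copy()
biased_logits[queried] += bias
biased_logprobs = log_softmax(biased_logits)[queried]

# Undo the bias in closed form: p_i = p'_i / (e^bias * (1 - Q) + Q),
# where Q is the total biased probability mass on the queried tokens.
log_q = logsumexp(biased_logprobs)
recovered = biased_logprobs - np.logaddexp(bias + np.log1p(-np.exp(log_q)), log_q)

print(np.allclose(recovered, true_logprobs[queried]))  # True
```

Because each call recovers up to k tokens exactly (as long as they all land in the biased top-k), covering the full vocabulary takes roughly vocab_size / k API calls.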
74 |
75 | ## Language Model Inversion paper
76 |
77 | This algorithm was developed mainly by Justin Chiu to facilitate the paper [*Language Model Inversion*](https://arxiv.org/abs/2311.13647). The exact solution algorithm was contributed by [Matthew Finlayson](https://mattf1n.github.io).
78 |
79 | If you're using our algorithm in academic research, please cite our paper:
80 |
81 | ```
82 | @misc{morris2023language,
83 | title={Language Model Inversion},
84 | author={John X. Morris and Wenting Zhao and Justin T. Chiu and Vitaly Shmatikov and Alexander M. Rush},
85 | year={2023},
86 | eprint={2311.13647},
87 | archivePrefix={arXiv},
88 | primaryClass={cs.CL}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/openlogprobs/__init__.py:
--------------------------------------------------------------------------------
1 | from .extract import extract_logprobs
2 | from .models import OpenAIModel
--------------------------------------------------------------------------------
/openlogprobs/extract.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import tiktoken
3 | import numpy as np
4 | from scipy.special import logsumexp
5 | import math
6 | from functools import partial, reduce
7 | from operator import or_ as union
8 | from typing import Literal, Optional
9 |
10 | from concurrent.futures import ThreadPoolExecutor
11 |
12 | from openlogprobs.models import Model
13 | from openlogprobs.utils import batched
14 |
15 |
16 | def exact_solve(
17 | model: Model,
18 | prefix: str,
19 | idx: list[int],
20 | bias: float = 5.0,
21 | top_logprob: Optional[float] = None,
22 | ) -> tuple[dict[int, float], set[int], int]:
23 | """Parallel exact solve based on https://mattf1n.github.io/openlogprobs.html"""
24 | logit_bias = {i: bias for i in idx}
25 | topk_words = model.topk(prefix, logit_bias)
26 | if all(i in topk_words for i in idx):
27 | biased_logprobs = np.array([topk_words[i] for i in idx])
28 | log_biased_prob = logsumexp(biased_logprobs)
29 | logprobs = biased_logprobs - np.logaddexp(
30 | bias + np.log1p(-np.exp(log_biased_prob)), log_biased_prob
31 | )
32 | return dict(zip(idx, logprobs)), set(), 1
33 | else:
34 | if top_logprob is None:
35 | missing_tokens = set(idx) - set(topk_words)
36 | raise TypeError(
37 | f"Tokens {missing_tokens} not in top-k with bias {bias}."
38 | "Either increase bias or provide top unbiased logprob (top_logprob)"
39 | )
40 | success_idxs = list(i for i in idx if i in topk_words)
41 | fail_idxs = set(idx) - set(topk_words)
42 | biased_top_logprob = max(
43 | logprob for i, logprob in topk_words.items() if i not in idx
44 | )
45 | biased_logprobs = np.array([topk_words[i] for i in success_idxs])
46 | logprobs = biased_logprobs - biased_top_logprob + top_logprob - bias
47 | return dict(zip(success_idxs, logprobs)), fail_idxs, 1
48 |
49 |
50 | def bisection_search(
51 | model: Model, prefix: str, idx: int, k=1, low=0, high=32, eps=1e-8
52 | ):
53 | # check if idx is the argmax
54 | num_calls = k
55 | if model.argmax(prefix) == idx:
56 | return 0, num_calls
57 |
58 | # initialize high
59 | logit_bias = {idx: high}
60 | while model.argmax(prefix, logit_bias) != idx:
61 | logit_bias[idx] *= 2
62 | num_calls += k
63 | high = logit_bias[idx]
64 |
65 | # improve estimate
66 | mid = (high + low) / 2
67 | while high >= low + eps:
68 | logit_bias[idx] = mid
69 | if model.argmax(prefix, logit_bias) == idx:
70 | high = mid
71 | else:
72 | low = mid
73 | mid = (high + low) / 2
74 | num_calls += k
75 | return -mid, num_calls
76 |
77 |
78 | def topk_search(model: Model, prefix: str, idx: int, k=1, high=40):
79 | # get raw topk, could be done outside and passed in
80 | topk_words = model.topk(prefix)
81 | highest_idx = list(topk_words.keys())[np.argmax(list(topk_words.values()))]
82 | if idx == highest_idx:
83 | return topk_words[idx], k
84 | num_calls = k
85 |
86 | # initialize high
87 | logit_bias = {idx: high}
88 | new_max_idx = model.argmax(prefix, logit_bias)
89 | num_calls += k
90 | while new_max_idx != idx:
91 | logit_bias[idx] *= 2
92 | new_max_idx = model.argmax(prefix, logit_bias)
93 | num_calls += k
94 | high = logit_bias[idx]
95 |
96 | output = model.topk(prefix, logit_bias)
97 | num_calls += k
98 |
99 | # compute normalizing constant
100 | diff = topk_words[highest_idx] - output[highest_idx]
101 | logZ = high - math.log(math.exp(diff) - 1)
102 | fv = output[idx] + math.log(math.exp(logZ) + math.exp(high)) - high
103 | logprob = fv - logZ
104 |
105 | return logprob, num_calls
106 |
107 |
108 | def extract_logprobs(
109 | model: Model,
110 | prefix: str,
111 | method: Literal["bisection", "topk", "exact"] = "bisection",
112 | k: int = 5,
113 | eps: float = 1e-6,
114 | multithread: bool = False,
115 | bias: float = 5.0,
116 | parallel: bool = False,
117 | ):
118 | vocab_size = model.vocab_size
119 |
120 | if method == "exact":
121 | logprob_dict = model.topk(prefix)
122 | top_logprob = max(logprob_dict.values())
123 | bias += top_logprob - min(logprob_dict.values())
124 | remaining = set(range(vocab_size)) - set(logprob_dict)
125 | total_calls = 0
126 | if multithread:
127 | executor = ThreadPoolExecutor(max_workers=8)
128 | map_func = executor.map
129 | else:
130 | map_func = map
131 | while remaining:
132 | search_results = map_func(
133 | partial(
134 | exact_solve,
135 | model,
136 | prefix,
137 | bias=bias,
138 | top_logprob=top_logprob,
139 | ),
140 | batched(remaining, k),
141 | )
142 | logprob_dicts, skipped, calls = zip(*search_results)
143 | logprob_dict |= reduce(union, logprob_dicts)
144 | remaining = set.union(*skipped)
145 | total_calls += sum(calls)
146 | bias += 5
147 | if multithread:
148 | executor.shutdown()
149 | logprobs = np.array([logprob_dict[i] for i in range(vocab_size)])
150 | return logprobs, total_calls
151 | else:
152 | search_func = topk_search if method == "topk" else bisection_search
153 | search = partial(search_func, model, prefix, k=k)
154 | vocab = list(range(vocab_size))
155 | if multithread:
156 | with ThreadPoolExecutor(max_workers=8) as executor:
157 | search_results = executor.map(search, tqdm.tqdm(vocab))
158 | else:
159 | search_results = map(search, tqdm.tqdm(vocab))
160 | logit_list, calls = zip(*search_results)
161 | logits = np.array(logit_list)
162 | return logits - logsumexp(logits), sum(calls)
163 |
--------------------------------------------------------------------------------
/openlogprobs/models.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import abc
4 | import os
5 |
6 | import openai
7 | import numpy as np
8 | import tiktoken
9 |
10 | class Model(abc.ABC):
11 | """This class wraps the model API. It can take text and a logit bias and return text outputs."""
12 |
13 |     @abc.abstractproperty
14 |     def vocab_size(self) -> int:
15 |         raise NotImplementedError
16 |
17 | @abc.abstractmethod
18 |     def argmax(self, prefix: str, logit_bias: Dict[int, float] = {}) -> int:
19 | raise NotImplementedError
20 |
21 | @abc.abstractmethod
22 |     def topk(self, prefix: str, logit_bias: Dict[int, float] = {}) -> Dict[int, float]:
23 | raise NotImplementedError
24 |
25 | def median_topk(self, k, *args, **kwargs):
26 | """Runs the same topk query multiple times and returns the median. Useful
27 | to combat API nondeterminism when calling topk()."""
28 | results = [self.topk(*args, **kwargs) for _ in range(k)]
29 | return {
30 | word: np.median([result[word] for result in results])
31 | for word in results[0].keys()
32 | }
33 | def median_argmax(self, k, *args, **kwargs):
34 | """Runs the same argmax query multiple times and returns the median. Useful
35 | to combat API nondeterminism when calling argmax()."""
36 | results = [self.argmax(*args, **kwargs) for _ in range(k)]
37 |         return int(np.median(results))
38 |
39 |
40 | class OpenAIModel(Model):
41 | """Model wrapper for OpenAI API."""
42 |     def __init__(self, model: str, system: Optional[str] = None):
43 |         self.model = model
44 |         self.encoding = tiktoken.encoding_for_model(model)
45 |         self.client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
46 |         self.system = system or "You are a helpful assistant."
47 | @property
48 | def vocab_size(self) -> int:
49 | return self.encoding.n_vocab
50 |
51 |     def argmax(self, prefix: str, logit_bias: Dict[int, float] = {}) -> int:
52 | model = self.model
53 | system = self.system
54 |         enc = self.encoding
55 | if model == "gpt-3.5-turbo-instruct":
56 | if logit_bias is not None:
57 | response = self.client.completions.create(
58 | model=model,
59 | prompt=prefix,
60 | temperature=0,
61 | max_tokens=1,
62 | logit_bias=logit_bias,
63 | n=1,
64 | )
65 | else:
66 | response = self.client.completions.create(
67 | model=model,
68 | prompt=prefix,
69 | temperature=0,
70 | max_tokens=1,
71 | n=1,
72 | )
73 | output = response.choices[0].text
74 | eos_idx = enc.encode(
75 | "<|endoftext|>", allowed_special={"<|endoftext|>", "<|im_start|>"}
76 | )[0]
77 | outputs = [choice.text for choice in response.choices]
78 | else:
79 | if logit_bias is not None:
80 | response = self.client.chat.completions.create(
81 | model=model,
82 | messages=[
83 | {"role": "system", "content": system},
84 | {"role": "user", "content": prefix},
85 | ],
86 | temperature=0,
87 | max_tokens=1,
88 | logit_bias=logit_bias,
89 | n=1,
90 | )
91 | else:
92 | response = self.client.chat.completions.create(
93 | model=model,
94 | messages=[
95 | {"role": "system", "content": system},
96 | {"role": "user", "content": prefix},
97 | ],
98 | temperature=0,
99 | max_tokens=1,
100 | n=1,
101 | )
102 |             output = response.choices[0].message.content
103 |             outputs = [choice.message.content for choice in response.choices]
104 | eos_idx = enc.encode(
105 | "<|endoftext|>", allowed_special={"<|endoftext|>", "<|im_start|>"}
106 | )[0]
107 |
108 | return enc.encode(output)[0]
109 |
110 |     def topk(self, prefix: str, logit_bias: Dict[int, float] = {}) -> Dict[int, float]:
111 | enc = self.encoding
112 | model = self.model
113 | system = self.system
114 | if model == "gpt-3.5-turbo-instruct":
115 | if logit_bias is not None:
116 | response = self.client.completions.create(
117 | model=model,
118 | prompt=prefix,
119 | temperature=1,
120 | max_tokens=1,
121 | logit_bias=logit_bias,
122 | logprobs=5,
123 | )
124 | else:
125 | response = self.client.completions.create(
126 | model=model,
127 | prompt=prefix,
128 | temperature=1,
129 | max_tokens=1,
130 | logprobs=5,
131 | )
132 | else:
133 | raise NotImplementedError(f"Tried to get topk logprobs for: {model}")
134 | topk_dict = response.choices[0].logprobs.top_logprobs[0]
135 | return {enc.encode(x)[0]: y for x, y in topk_dict.items()}
136 |
--------------------------------------------------------------------------------
/openlogprobs/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import islice
2 | import sys
3 |
4 | if sys.version_info < (3, 12):
5 |
6 | def batched(iterable, n):
7 | """From https://docs.python.org/3.11/library/itertools.html#itertools-recipes"""
8 | "Batch data into tuples of length n. The last batch may be shorter."
9 | # batched('ABCDEFG', 3) --> ABC DEF G
10 | if n < 1:
11 | raise ValueError("n must be at least one")
12 | it = iter(iterable)
13 | while batch := tuple(islice(it, n)):
14 | yield batch
15 |
16 | else:
17 | from itertools import batched
18 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "openlogprobs"
3 | version = "0.0.0"
4 | description = "Logprob estimation from LLM APIs"
5 | authors = ["Justin Chiu <jtc257@cornell.edu>"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.9"
10 | numpy = "*"
11 | pyparsing = "*"
12 | tqdm = "*"
13 | tiktoken = "*"
14 | tokenizers = ">=0.13.3"
15 | openai = "*"
16 | tenacity = "*"
17 | scipy = "*"
18 | datasets = "*"
19 | pytest = "*"
20 |
21 | [build-system]
22 | requires = ["poetry-core"]
23 | build-backend = "poetry.core.masonry.api"
24 |
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | openai
3 | scipy
4 | tiktoken
5 | tqdm
6 | transformers
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="openlogprobs",
5 | version="0.0.2",
6 | description="extract log-probabilities from APIs",
7 | author="Justin Chiu & Jack Morris",
8 | author_email="jtc257@cornell.edu",
9 | packages=find_packages(),
10 |     install_requires=open("requirements.txt").read().splitlines(),
11 | # install_requires=[],
12 | )
--------------------------------------------------------------------------------
/test/test_logprobs.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 | import pytest
5 | from scipy.special import log_softmax
6 | import transformers
7 |
8 | from openlogprobs import (
9 | extract_logprobs,
10 | # OpenAIModel,
11 | )
12 | from openlogprobs.extract import (
13 | bisection_search,
14 | topk_search,
15 | )
16 | from openlogprobs.models import Model
17 |
18 | prefix = "Should i take this class or not? The professor of this class is not good at all. He doesn't teach well and he is always late for class."
19 |
20 |
21 | def load_fake_logits(vocab_size: int) -> np.ndarray:
22 | np.random.seed(42)
23 | logits = np.random.randn(vocab_size)
24 | logits[1] += 10
25 | logits[12] += 20
26 | logits[13] += 30
27 | logits[24] += 30
28 | logits[35] += 30
29 | return logits
30 |
31 |
32 | class FakeModel(Model):
33 | """Represents a fake API with a temperature of 1. Used for testing."""
34 |
35 | def __init__(self, vocab_size: int = 100, get_logits=None):
36 | self.tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
37 | self.fake_vocab_size = vocab_size
38 | if get_logits is None:
39 | self.logits = load_fake_logits(self.vocab_size)[:vocab_size]
40 | else:
41 | self.logits = get_logits(vocab_size)
42 |
43 | @property
44 | def vocab_size(self):
45 | return self.fake_vocab_size
46 |
47 | def _idx_to_str(self, idx: int) -> str:
48 | return self.tokenizer.decode([idx], skip_special_tokens=True)
49 |
50 | def _add_logit_bias(self, logit_bias: Dict[str, float]) -> np.ndarray:
51 | logits = self.logits.copy()
52 | for token_idx, bias in logit_bias.items():
53 | logits[token_idx] += bias
54 | logits = logits.astype(np.double)
55 | return log_softmax(logits)
56 |
57 | def argmax(self, prefix: str, logit_bias: Dict[str, float] = {}) -> int:
58 | logits = self._add_logit_bias(logit_bias)
59 | return logits.argmax()
60 |
61 | def topk(self, prefix: str, logit_bias: Dict[str, float] = {}) -> Dict[int, float]:
62 | k = 5 # TODO: what topk?
63 | logits = self._add_logit_bias(logit_bias)
64 | topk = logits.argsort()[-k:]
65 | return {k: logits[k] for k in topk}
66 |
67 |
68 | @pytest.fixture
69 | def model():
70 | # return OpenAIModel("gpt-3.5-turbo-instruct")
71 | return FakeModel()
72 |
73 |
74 | @pytest.fixture
75 | def uniform_model():
76 | # return OpenAIModel("gpt-3.5-turbo-instruct")
77 | return FakeModel(get_logits=np.ones)
78 |
79 |
80 | @pytest.fixture
81 | def topk_words(model):
82 | return model.topk(prefix)
83 |
84 |
85 | def test_bisection(model, topk_words):
86 | true_sorted_logprobs = np.array(sorted(topk_words.values()))
87 | true_diffs = true_sorted_logprobs - true_sorted_logprobs.max()
88 |
89 | estimated_diffs = {
90 | word: bisection_search(model, prefix, word) for word in topk_words.keys()
91 | }
92 | estimated_diffs = np.array(sorted([x[0] for x in estimated_diffs.values()]))
93 | assert np.allclose(true_diffs, estimated_diffs, atol=1e-5)
94 |
95 |
96 | def test_topk(model, topk_words):
97 | true_probs = np.array(sorted(topk_words.values()))
98 |
99 | estimated_probs = {
100 | word: topk_search(model, prefix, word) for word in topk_words.keys()
101 | }
102 | estimated_probs = np.array(sorted([x[0] for x in estimated_probs.values()]))
103 | assert np.allclose(true_probs, estimated_probs, atol=1e-5)
104 |
105 |
106 | def test_topk_consistency(model, topk_words):
107 | true_probs = np.array(sorted(topk_words.values()))
108 |
109 | probs = []
110 | for _trial in range(10):
111 | estimated_probs = {
112 | word: topk_search(model, prefix, word) for word in topk_words.keys()
113 | }
114 | estimated_probs = np.array(sorted([x[0] for x in estimated_probs.values()]))
115 | probs.append(estimated_probs)
116 | probs = np.stack(probs)
117 | assert np.allclose(true_probs, np.median(probs, 0), atol=1e-5)
118 |
119 |
120 | def test_extract_topk(model):
121 | true_logprobs = log_softmax(model.logits)
122 | extracted_logprobs, num_calls = extract_logprobs(
123 | model, prefix="test", method="topk", multithread=False, k=1
124 | )
125 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
126 | assert num_calls == 298
127 |
128 |
129 | def test_extract_bisection(model):
130 | true_logprobs = log_softmax(model.logits)
131 | extracted_logprobs, num_calls = extract_logprobs(
132 | model, prefix="test", method="bisection", multithread=False, k=1
133 | )
134 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
135 | assert num_calls == 3270
136 |
137 |
138 | def test_extract_exact(model):
139 | true_logprobs = log_softmax(model.logits)
140 | extracted_logprobs, num_calls = extract_logprobs(
141 | model, prefix="test", method="exact", multithread=False
142 | )
143 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
144 | assert num_calls < len(true_logprobs)
145 |
146 |
147 | def test_extract_exact_parallel(model):
148 | true_logprobs = log_softmax(model.logits)
149 | extracted_logprobs, num_calls = extract_logprobs(
150 | model,
151 | prefix="test",
152 | method="exact",
153 | multithread=False,
154 | parallel=True,
155 | )
156 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
157 | assert num_calls < len(true_logprobs)
158 |
159 |
160 | def test_extract_topk_multithread(model):
161 | true_logprobs = log_softmax(model.logits)
162 | extracted_logprobs, num_calls = extract_logprobs(
163 | model, prefix="test", method="topk", multithread=True, k=1
164 | )
165 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
166 | assert num_calls == 298
167 |
168 |
169 | def test_extract_exact_multithread(model):
170 | true_logprobs = log_softmax(model.logits)
171 | extracted_logprobs, num_calls = extract_logprobs(
172 | model, prefix="test", method="exact", multithread=True
173 | )
174 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
175 | assert num_calls < len(true_logprobs)
176 |
177 |
178 | def test_extract_exact_parallel_multithread(model):
179 | true_logprobs = log_softmax(model.logits)
180 | extracted_logprobs, num_calls = extract_logprobs(
181 | model, prefix="test", method="exact", multithread=True, parallel=True
182 | )
183 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
184 | assert num_calls < len(true_logprobs)
185 |
186 |
187 | def test_extract_exact_parallel_multithread_uniform(uniform_model):
188 | true_logprobs = log_softmax(uniform_model.logits)
189 | extracted_logprobs, num_calls = extract_logprobs(
190 | uniform_model,
191 | prefix="test",
192 | method="exact",
193 | parallel=True,
194 | )
195 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
196 | assert num_calls < len(true_logprobs)
197 |
--------------------------------------------------------------------------------
/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinchiu/openlogprobs/0d5aabde327f51589921e0d012e1c19620dc21c2/vis.png
--------------------------------------------------------------------------------