├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── README.md
├── openlogprobs
│   ├── __init__.py
│   ├── extract.py
│   ├── models.py
│   └── utils.py
├── pyproject.toml
├── requirements.txt
├── setup.py
├── test
│   └── test_logprobs.py
└── vis.png

/.github/workflows/test.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Test with PyTest

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.9.18, 3.12.1]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install numpy openai scipy transformers
          pip install pytest pytest-xdist  # Testing packages
          pip install -e .  # Install openlogprobs
      - name: Import openlogprobs
        run: |
          printf "import openlogprobs\n" | python
      - name: Test with pytest
        run: |
          pytest -vx --dist=loadfile -n auto

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
build/
dist/
*.egg-info/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# openlogprobs

#### 🪄 openlogprobs is a Python API for extracting log-probabilities from language model APIs 🪄<br>

```bash
pip install openlogprobs
```

![openlogprobs on pypi](https://badge.fury.io/py/openlogprobs.svg)

![Test with PyTest](https://github.com/justinchiu/openlogprobs/workflows/Test%20with%20PyTest/badge.svg)

Many API-based language model services hide the log-probability outputs of their models. One reason is security: language model outputs can reveal information about their inputs and can be used for efficient model distillation. Another reason is practical: serving 30,000 (or however large the vocabulary is) floats via an API would be far too much data for a typical API request. So this information is hidden from you.

However, most APIs also accept a 'logit bias' argument that positively or negatively influences the likelihood of certain tokens in the model's output. It turns out that we can use this logit bias on individual tokens to reverse-engineer their log probabilities. We developed an algorithm to do this efficiently, which effectively allows us to extract *full probability vectors* via APIs such as the OpenAI API. For more information, read the Algorithms section below, or read the code in openlogprobs/extract.py.


## Usage

### topk search

If the API exposes the top-k log-probabilities, we can efficiently extract the next-token probabilities via our 'topk' algorithm:

```python
from openlogprobs import extract_logprobs
extract_logprobs("gpt-3.5-turbo-instruct", "i like pie", method="topk")
```

### exact solution

If the API exposes the top-k log-probabilities, we can extract the next-token probabilities exactly, k tokens at a time, via our 'exact' algorithm:

```python
from openlogprobs import extract_logprobs
extract_logprobs("gpt-3.5-turbo-instruct", "i like pie", method="exact", parallel=True)
```

This method requires fewer API calls than the top-k algorithm (only 1 call per k tokens).

### binary search

If the API does not expose top-k logprobs, we can still extract the distribution, but it takes more language model calls:

```python
from openlogprobs import extract_logprobs
extract_logprobs("gpt-3.5-turbo-instruct", "i like pie", method="bisection")
```

### Future work (help wanted!)

- support multiple logprobs (concurrent binary search)
- estimate costs for various APIs
- support checkpointing

## Algorithms

### Bisection and top-k

Our algorithm is essentially a binary search (technically 'univariate bisection' on a continuous variable) where we apply different amounts of logit bias to make a given token likely enough to appear in the generation. This lets us estimate the probability of any token relative to the most likely token. To obtain the full vector of probabilities, we run this search for every token in the vocabulary. Note that essentially all model APIs support logit bias, and because a bias can be applied to any token in the vocabulary, the probability of every token can be recovered this way.

Here's a crude visualization of how our algorithm works for a single token:

![Visualization of bisection converging on a single token's probability](vis.png)

Each API call (purple) brings us successively closer to the true token probability (green).

### Exact solution

Our exact solution algorithm solves directly for the logprobs.
To understand the math, see [this outline](https://mattf1n.github.io/openlogprobs.html).
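
To make the exact solution concrete, here is a self-contained toy sketch of the underlying identity (an illustration, not the package's implementation: `fake_topk_api` and the toy distribution are made up stand-ins for a real API). We bias a batch of tokens strongly enough that they all enter the top-k, then invert the biased softmax in closed form, recovering their exact unbiased log-probabilities from a single call:

```python
import numpy as np
from scipy.special import log_softmax, logsumexp

# Toy next-token distribution standing in for the model behind the API.
rng = np.random.default_rng(0)
true_logprobs = log_softmax(rng.normal(size=50))

def fake_topk_api(logit_bias: dict, k: int = 5) -> dict:
    """Return the top-k logprobs after applying a logit bias (mimics a real API)."""
    biased = true_logprobs.copy()
    for token, b in logit_bias.items():
        biased[token] += b
    biased = log_softmax(biased)  # re-normalize after biasing
    return {int(t): float(biased[t]) for t in np.argsort(biased)[-k:]}

# One call: bias k tokens hard enough that they fill the entire top-k.
tokens = [3, 17, 29, 31, 42]
bias = 10.0
response = fake_topk_api({t: bias for t in tokens})
biased_logprobs = np.array([response[t] for t in tokens])

# Invert the bias: with Q = logsumexp of the biased set's logprobs,
# log p_i = q_i - logaddexp(bias + log(1 - exp(Q)), Q).
log_biased_mass = logsumexp(biased_logprobs)
recovered = biased_logprobs - np.logaddexp(
    bias + np.log1p(-np.exp(log_biased_mass)), log_biased_mass
)

print(np.allclose(recovered, true_logprobs[tokens]))  # True
```

Each call resolves k tokens exactly, so covering the vocabulary takes roughly vocab_size / k calls; `extract_logprobs(..., method="exact")` wraps this loop and retries any tokens that fail to reach the top-k with a larger bias.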

## Language Model Inversion paper

This algorithm was developed mainly by Justin Chiu to facilitate the paper [*Language Model Inversion*](https://arxiv.org/abs/2311.13647). The exact solution algorithm was contributed by [Matthew Finlayson](https://mattf1n.github.io). If you're using our algorithm in academic research, please cite our paper:

```
@misc{morris2023language,
  title={Language Model Inversion},
  author={John X. Morris and Wenting Zhao and Justin T. Chiu and Vitaly Shmatikov and Alexander M. Rush},
  year={2023},
  eprint={2311.13647},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```

--------------------------------------------------------------------------------
/openlogprobs/__init__.py:
--------------------------------------------------------------------------------
from .extract import extract_logprobs
from .models import OpenAIModel

--------------------------------------------------------------------------------
/openlogprobs/extract.py:
--------------------------------------------------------------------------------
import tqdm
import numpy as np
from scipy.special import logsumexp
import math
from functools import partial, reduce
from operator import or_ as union
from typing import Literal, Optional

from concurrent.futures import ThreadPoolExecutor

from openlogprobs.models import Model
from openlogprobs.utils import batched


def exact_solve(
    model: Model,
    prefix: str,
    idx: list[int],
    bias: float = 5.0,
    top_logprob: Optional[float] = None,
) -> tuple[dict[int, float], set[int], int]:
    """Parallel exact solve based on https://mattf1n.github.io/openlogprobs.html"""
    logit_bias = {i: bias for i in idx}
    topk_words = model.topk(prefix, logit_bias)
    if all(i in topk_words for i in idx):
        biased_logprobs = np.array([topk_words[i] for i in idx])
        log_biased_prob = logsumexp(biased_logprobs)
        logprobs = biased_logprobs - np.logaddexp(
            bias + np.log1p(-np.exp(log_biased_prob)), log_biased_prob
        )
        return dict(zip(idx, logprobs)), set(), 1
    else:
        if top_logprob is None:
            missing_tokens = set(idx) - set(topk_words)
            raise ValueError(
                f"Tokens {missing_tokens} not in top-k with bias {bias}. "
38 | "Either increase bias or provide top unbiased logprob (top_logprob)" 39 | ) 40 | success_idxs = list(i for i in idx if i in topk_words) 41 | fail_idxs = set(idx) - set(topk_words) 42 | biased_top_logprob = max( 43 | logprob for i, logprob in topk_words.items() if i not in idx 44 | ) 45 | biased_logprobs = np.array([topk_words[i] for i in success_idxs]) 46 | logprobs = biased_logprobs - biased_top_logprob + top_logprob - bias 47 | return dict(zip(success_idxs, logprobs)), fail_idxs, 1 48 | 49 | 50 | def bisection_search( 51 | model: Model, prefix: str, idx: int, k=1, low=0, high=32, eps=1e-8 52 | ): 53 | # check if idx is the argmax 54 | num_calls = k 55 | if model.argmax(prefix) == idx: 56 | return 0, num_calls 57 | 58 | # initialize high 59 | logit_bias = {idx: high} 60 | while model.argmax(prefix, logit_bias) != idx: 61 | logit_bias[idx] *= 2 62 | num_calls += k 63 | high = logit_bias[idx] 64 | 65 | # improve estimate 66 | mid = (high + low) / 2 67 | while high >= low + eps: 68 | logit_bias[idx] = mid 69 | if model.argmax(prefix, logit_bias) == idx: 70 | high = mid 71 | else: 72 | low = mid 73 | mid = (high + low) / 2 74 | num_calls += k 75 | return -mid, num_calls 76 | 77 | 78 | def topk_search(model: Model, prefix: str, idx: int, k=1, high=40): 79 | # get raw topk, could be done outside and passed in 80 | topk_words = model.topk(prefix) 81 | highest_idx = list(topk_words.keys())[np.argmax(list(topk_words.values()))] 82 | if idx == highest_idx: 83 | return topk_words[idx], k 84 | num_calls = k 85 | 86 | # initialize high 87 | logit_bias = {idx: high} 88 | new_max_idx = model.argmax(prefix, logit_bias) 89 | num_calls += k 90 | while new_max_idx != idx: 91 | logit_bias[idx] *= 2 92 | new_max_idx = model.argmax(prefix, logit_bias) 93 | num_calls += k 94 | high = logit_bias[idx] 95 | 96 | output = model.topk(prefix, logit_bias) 97 | num_calls += k 98 | 99 | # compute normalizing constant 100 | diff = topk_words[highest_idx] - output[highest_idx] 101 | logZ = high - math.log(math.exp(diff) - 1) 102 | fv = output[idx] + math.log(math.exp(logZ) + math.exp(high)) - high 103 | logprob = fv - logZ 104 | 105 | return logprob, num_calls 106 | 107 | 108 | def extract_logprobs( 109 | model: Model, 110 | prefix: str, 111 | method: Literal["bisection", "topk", "exact"] = "bisection", 112 | k: int = 5, 113 | eps: float = 1e-6, 114 | multithread: bool = False, 115 | bias: float = 5.0, 116 | parallel: bool = False, 117 | ): 118 | vocab_size = model.vocab_size 119 | 120 | if method == "exact": 121 | logprob_dict = model.topk(prefix) 122 | top_logprob = max(logprob_dict.values()) 123 | bias += top_logprob - min(logprob_dict.values()) 124 | remaining = set(range(vocab_size)) - set(logprob_dict) 125 | total_calls = 0 126 | if multithread: 127 | executor = ThreadPoolExecutor(max_workers=8) 128 | map_func = executor.map 129 | else: 130 | map_func = map 131 | while remaining: 132 | search_results = map_func( 133 | partial( 134 | exact_solve, 135 | model, 136 | prefix, 137 | bias=bias, 138 | top_logprob=top_logprob, 139 | ), 140 | batched(remaining, k), 141 | ) 142 | logprob_dicts, skipped, calls = zip(*search_results) 143 | logprob_dict |= reduce(union, logprob_dicts) 144 | remaining = set.union(*skipped) 145 | total_calls += sum(calls) 146 | bias += 5 147 | if multithread: 148 | executor.shutdown() 149 | logprobs = np.array([logprob_dict[i] for i in range(vocab_size)]) 150 | return logprobs, total_calls 151 | else: 152 | search_func = topk_search if method == "topk" else bisection_search 153 | search = 
partial(search_func, model, prefix, k=k)
        vocab = list(range(vocab_size))
        if multithread:
            with ThreadPoolExecutor(max_workers=8) as executor:
                search_results = executor.map(search, tqdm.tqdm(vocab))
        else:
            search_results = map(search, tqdm.tqdm(vocab))
        logit_list, calls = zip(*search_results)
        logits = np.array(logit_list)
        return logits - logsumexp(logits), sum(calls)

--------------------------------------------------------------------------------
/openlogprobs/models.py:
--------------------------------------------------------------------------------
from typing import Dict, Optional

import abc
import os

import openai
import numpy as np
import tiktoken


class Model(abc.ABC):
    """This class wraps the model API. It can take text and a logit bias and return text outputs."""

    @property
    @abc.abstractmethod
    def vocab_size(self) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def argmax(self, prefix: str, logit_bias: Dict[int, float] = {}) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def topk(self, prefix: str, logit_bias: Dict[int, float] = {}) -> Dict[int, float]:
        raise NotImplementedError

    def median_topk(self, k, *args, **kwargs):
        """Runs the same topk query multiple times and returns the median. Useful
        to combat API nondeterminism when calling topk()."""
        results = [self.topk(*args, **kwargs) for _ in range(k)]
        return {
            word: np.median([result[word] for result in results])
            for word in results[0].keys()
        }

    def median_argmax(self, k, *args, **kwargs):
        """Runs the same argmax query multiple times and returns the median. Useful
        to combat API nondeterminism when calling argmax()."""
        results = [self.argmax(*args, **kwargs) for _ in range(k)]
        return int(np.median(results))


class OpenAIModel(Model):
    """Model wrapper for the OpenAI API."""

    def __init__(self, model: str, system: Optional[str] = None):
        self.model = model
        self.encoding = tiktoken.encoding_for_model(model)
        self.client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self.system = system or "You are a helpful assistant."

    @property
    def vocab_size(self) -> int:
        return self.encoding.n_vocab

    def argmax(self, prefix: str, logit_bias: Dict[int, float] = {}) -> int:
        model = self.model
        system = self.system
        enc = self.encoding
        if model == "gpt-3.5-turbo-instruct":
            if logit_bias:
                response = self.client.completions.create(
                    model=model,
                    prompt=prefix,
                    temperature=0,
                    max_tokens=1,
                    logit_bias=logit_bias,
                    n=1,
                )
            else:
                response = self.client.completions.create(
                    model=model,
                    prompt=prefix,
                    temperature=0,
                    max_tokens=1,
                    n=1,
                )
            output = response.choices[0].text
            eos_idx = enc.encode(
                "<|endoftext|>", allowed_special={"<|endoftext|>", "<|im_start|>"}
            )[0]
            outputs = [choice.text for choice in response.choices]
        else:
            if logit_bias:
                response = self.client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": prefix},
                    ],
                    temperature=0,
                    max_tokens=1,
                    logit_bias=logit_bias,
                    n=1,
                )
            else:
                response = self.client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": prefix},
                    ],
                    temperature=0,
                    max_tokens=1,
                    n=1,
                )
            output = response.choices[0].message.content
            outputs = [choice.message.content for choice in response.choices]
            eos_idx = enc.encode(
                "<|endoftext|>", allowed_special={"<|endoftext|>", "<|im_start|>"}
            )[0]

        return enc.encode(output)[0]

    def topk(self, prefix: str, logit_bias: Dict[int, float] = {}) -> Dict[int, float]:
        enc = self.encoding
        model = self.model
        if model == "gpt-3.5-turbo-instruct":
            if logit_bias:
                response = self.client.completions.create(
                    model=model,
                    prompt=prefix,
                    temperature=1,
                    max_tokens=1,
                    logit_bias=logit_bias,
                    logprobs=5,
                )
            else:
                response = self.client.completions.create(
                    model=model,
                    prompt=prefix,
                    temperature=1,
                    max_tokens=1,
                    logprobs=5,
                )
        else:
            raise NotImplementedError(f"Tried to get topk logprobs for: {model}")
        topk_dict = response.choices[0].logprobs.top_logprobs[0]
        return {enc.encode(x)[0]: y for x, y in topk_dict.items()}

--------------------------------------------------------------------------------
/openlogprobs/utils.py:
--------------------------------------------------------------------------------
from itertools import islice
import sys

if sys.version_info < (3, 12):

    def batched(iterable, n):
        """Batch data into tuples of length n. The last batch may be shorter.

        From https://docs.python.org/3.11/library/itertools.html#itertools-recipes
        """
        # batched('ABCDEFG', 3) --> ABC DEF G
        if n < 1:
            raise ValueError("n must be at least one")
        it = iter(iterable)
        while batch := tuple(islice(it, n)):
            yield batch

else:
    from itertools import batched

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.poetry]
name = "openlogprobs"
version = "0.0.0"
description = "Logprob estimation from LLM APIs"
authors = ["Justin Chiu <jtc257@cornell.edu>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9"
numpy = "*"
pyparsing = "*"
tqdm = "*"
tiktoken = "*"
tokenizers = ">=0.13.3"
openai = "*"
tenacity = "*"
scipy = "*"
datasets = "*"
pytest = "*"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
openai
scipy
tiktoken
tqdm
transformers

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

setup(
    name="openlogprobs",
    version="0.0.2",
    description="extract log-probabilities from APIs",
    author="Justin Chiu & Jack Morris",
    author_email="jtc257@cornell.edu",
    packages=find_packages(),
    install_requires=open("requirements.txt").read().splitlines(),
)

--------------------------------------------------------------------------------
/test/test_logprobs.py:
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import pytest 5 | from scipy.special import log_softmax 6 | import transformers 7 | 8 | from openlogprobs import ( 9 | extract_logprobs, 10 | # OpenAIModel, 11 | ) 12 | from openlogprobs.extract import ( 13 | bisection_search, 14 | topk_search, 15 | ) 16 | from openlogprobs.models import Model 17 | 18 | prefix = "Should i take this class or not? The professor of this class is not good at all. He doesn't teach well and he is always late for class." 19 | 20 | 21 | def load_fake_logits(vocab_size: int) -> np.ndarray: 22 | np.random.seed(42) 23 | logits = np.random.randn(vocab_size) 24 | logits[1] += 10 25 | logits[12] += 20 26 | logits[13] += 30 27 | logits[24] += 30 28 | logits[35] += 30 29 | return logits 30 | 31 | 32 | class FakeModel(Model): 33 | """Represents a fake API with a temperature of 1. Used for testing.""" 34 | 35 | def __init__(self, vocab_size: int = 100, get_logits=None): 36 | self.tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2") 37 | self.fake_vocab_size = vocab_size 38 | if get_logits is None: 39 | self.logits = load_fake_logits(self.vocab_size)[:vocab_size] 40 | else: 41 | self.logits = get_logits(vocab_size) 42 | 43 | @property 44 | def vocab_size(self): 45 | return self.fake_vocab_size 46 | 47 | def _idx_to_str(self, idx: int) -> str: 48 | return self.tokenizer.decode([idx], skip_special_tokens=True) 49 | 50 | def _add_logit_bias(self, logit_bias: Dict[str, float]) -> np.ndarray: 51 | logits = self.logits.copy() 52 | for token_idx, bias in logit_bias.items(): 53 | logits[token_idx] += bias 54 | logits = logits.astype(np.double) 55 | return log_softmax(logits) 56 | 57 | def argmax(self, prefix: str, logit_bias: Dict[str, float] = {}) -> int: 58 | logits = self._add_logit_bias(logit_bias) 59 | return logits.argmax() 60 | 61 | def topk(self, prefix: str, logit_bias: Dict[str, float] = {}) -> Dict[int, float]: 62 | k = 5 # TODO: what topk? 
63 | logits = self._add_logit_bias(logit_bias) 64 | topk = logits.argsort()[-k:] 65 | return {k: logits[k] for k in topk} 66 | 67 | 68 | @pytest.fixture 69 | def model(): 70 | # return OpenAIModel("gpt-3.5-turbo-instruct") 71 | return FakeModel() 72 | 73 | 74 | @pytest.fixture 75 | def uniform_model(): 76 | # return OpenAIModel("gpt-3.5-turbo-instruct") 77 | return FakeModel(get_logits=np.ones) 78 | 79 | 80 | @pytest.fixture 81 | def topk_words(model): 82 | return model.topk(prefix) 83 | 84 | 85 | def test_bisection(model, topk_words): 86 | true_sorted_logprobs = np.array(sorted(topk_words.values())) 87 | true_diffs = true_sorted_logprobs - true_sorted_logprobs.max() 88 | 89 | estimated_diffs = { 90 | word: bisection_search(model, prefix, word) for word in topk_words.keys() 91 | } 92 | estimated_diffs = np.array(sorted([x[0] for x in estimated_diffs.values()])) 93 | assert np.allclose(true_diffs, estimated_diffs, atol=1e-5) 94 | 95 | 96 | def test_topk(model, topk_words): 97 | true_probs = np.array(sorted(topk_words.values())) 98 | 99 | estimated_probs = { 100 | word: topk_search(model, prefix, word) for word in topk_words.keys() 101 | } 102 | estimated_probs = np.array(sorted([x[0] for x in estimated_probs.values()])) 103 | assert np.allclose(true_probs, estimated_probs, atol=1e-5) 104 | 105 | 106 | def test_topk_consistency(model, topk_words): 107 | true_probs = np.array(sorted(topk_words.values())) 108 | 109 | probs = [] 110 | for _trial in range(10): 111 | estimated_probs = { 112 | word: topk_search(model, prefix, word) for word in topk_words.keys() 113 | } 114 | estimated_probs = np.array(sorted([x[0] for x in estimated_probs.values()])) 115 | probs.append(estimated_probs) 116 | probs = np.stack(probs) 117 | assert np.allclose(true_probs, np.median(probs, 0), atol=1e-5) 118 | 119 | 120 | def test_extract_topk(model): 121 | true_logprobs = log_softmax(model.logits) 122 | extracted_logprobs, num_calls = extract_logprobs( 123 | model, prefix="test", method="topk", multithread=False, k=1 124 | ) 125 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 126 | assert num_calls == 298 127 | 128 | 129 | def test_extract_bisection(model): 130 | true_logprobs = log_softmax(model.logits) 131 | extracted_logprobs, num_calls = extract_logprobs( 132 | model, prefix="test", method="bisection", multithread=False, k=1 133 | ) 134 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 135 | assert num_calls == 3270 136 | 137 | 138 | def test_extract_exact(model): 139 | true_logprobs = log_softmax(model.logits) 140 | extracted_logprobs, num_calls = extract_logprobs( 141 | model, prefix="test", method="exact", multithread=False 142 | ) 143 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 144 | assert num_calls < len(true_logprobs) 145 | 146 | 147 | def test_extract_exact_parallel(model): 148 | true_logprobs = log_softmax(model.logits) 149 | extracted_logprobs, num_calls = extract_logprobs( 150 | model, 151 | prefix="test", 152 | method="exact", 153 | multithread=False, 154 | parallel=True, 155 | ) 156 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 157 | assert num_calls < len(true_logprobs) 158 | 159 | 160 | def test_extract_topk_multithread(model): 161 | true_logprobs = log_softmax(model.logits) 162 | extracted_logprobs, num_calls = extract_logprobs( 163 | model, prefix="test", method="topk", multithread=True, k=1 164 | ) 165 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 166 | assert num_calls == 298 167 | 168 | 169 | def 
test_extract_exact_multithread(model): 170 | true_logprobs = log_softmax(model.logits) 171 | extracted_logprobs, num_calls = extract_logprobs( 172 | model, prefix="test", method="exact", multithread=True 173 | ) 174 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 175 | assert num_calls < len(true_logprobs) 176 | 177 | 178 | def test_extract_exact_parallel_multithread(model): 179 | true_logprobs = log_softmax(model.logits) 180 | extracted_logprobs, num_calls = extract_logprobs( 181 | model, prefix="test", method="exact", multithread=True, parallel=True 182 | ) 183 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 184 | assert num_calls < len(true_logprobs) 185 | 186 | 187 | def test_extract_exact_parallel_multithread_uniform(uniform_model): 188 | true_logprobs = log_softmax(uniform_model.logits) 189 | extracted_logprobs, num_calls = extract_logprobs( 190 | uniform_model, 191 | prefix="test", 192 | method="exact", 193 | parallel=True, 194 | ) 195 | np.testing.assert_allclose(true_logprobs, extracted_logprobs) 196 | assert num_calls < len(true_logprobs) 197 | -------------------------------------------------------------------------------- /vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinchiu/openlogprobs/0d5aabde327f51589921e0d012e1c19620dc21c2/vis.png --------------------------------------------------------------------------------