├── requirements.txt
├── .gitignore
├── vis.png
├── openlogprobs
│   ├── __init__.py
│   ├── utils.py
│   ├── models.py
│   └── extract.py
├── setup.py
├── pyproject.toml
├── .github
│   └── workflows
│       └── test.yml
├── README.md
└── test
    └── test_logprobs.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | openai
3 | scipy
4 | tiktoken
5 | tqdm
6 | transformers
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | build/
3 | dist/
4 | *.egg-info/
5 |
--------------------------------------------------------------------------------
/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinchiu/openlogprobs/HEAD/vis.png
--------------------------------------------------------------------------------
/openlogprobs/__init__.py:
--------------------------------------------------------------------------------
1 | from .extract import extract_logprobs
2 | from .models import OpenAIModel
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="openlogprobs",
5 | version="0.0.2",
6 | description="extract log-probabilities from APIs",
7 | author="Justin Chiu & Jack Morris",
8 | author_email="jtc257@cornell.edu",
9 | packages=find_packages(),
10 | install_requires=open("requirements.txt").read().splitlines(),
11 | # install_requires=[],
12 | )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "openlogprobs"
3 | version = "0.0.0"
4 | description = "Logprob estimation from LLM APIs"
5 | authors = ["Justin Chiu
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
67 | 
68 | Each API call (purple) brings us successively closer to the true token probability (green).
69 |
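For intuition, here is a small self-contained simulation of the search that `bisection_search` (in `openlogprobs/extract.py`) runs against a real API. The `argmax_api` stub and the logits below are made up for illustration: we binary-search over the logit bias needed to make a target token the argmax, and the crossover bias equals the gap between that token's logit and the top logit.

```python
import numpy as np

# Stand-in for an API that only reveals the argmax token, optionally under a logit bias.
logits = np.array([2.0, 0.0, -1.0, -3.0])

def argmax_api(logit_bias={}):
    biased = logits.copy()
    for i, b in logit_bias.items():
        biased[i] += b
    return int(biased.argmax())

idx = 3                 # token whose logprob we want
low, high = 0.0, 32.0   # bias `high` must already make idx the argmax
                        # (the package doubles `high` until this holds)
while high - low > 1e-8:
    mid = (high + low) / 2
    if argmax_api({idx: mid}) == idx:
        high = mid      # bias was enough: the crossover is below mid
    else:
        low = mid       # bias was not enough: the crossover is above mid

# -mid estimates the logit (equivalently, logprob) gap to the top token.
print(-mid, logits[idx] - logits.max())  # both ≈ -5.0
```
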
70 | ### Exact solution
71 |
72 | Our exact solution algorithm solves directly for the logprobs.
73 | To understand the math, see [this outline](https://mattf1n.github.io/openlogprobs.html).
74 |
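As a quick worked example of the recovery step inside `exact_solve` (the numbers and the token set here are made up; in the package the biased logprobs come from a top-k API call made with a `logit_bias`):

```python
import numpy as np
from scipy.special import log_softmax, logsumexp

logits = np.array([3.0, 1.0, 0.5, -1.0, -2.0])   # toy next-token logits
true_logprobs = log_softmax(logits)

# Bias two tokens into the top-k, as a logit_bias-capable API would do internally.
idx, bias = [3, 4], 5.0
biased = logits.copy()
biased[idx] += bias
biased_logprobs = log_softmax(biased)[idx]        # what the biased top-k call reports

# Undo the renormalization induced by the bias (the same formula exact_solve uses).
log_biased_prob = logsumexp(biased_logprobs)
recovered = biased_logprobs - np.logaddexp(
    bias + np.log1p(-np.exp(log_biased_prob)), log_biased_prob
)
assert np.allclose(recovered, true_logprobs[idx])
```

In the package this step is driven by `extract_logprobs(model, prefix, method="exact")`, which sweeps the whole vocabulary in batches of `k` tokens per API call.
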
75 | ## Language Model Inversion paper
76 |
77 | This algorithm was developed mainly by Justin Chiu to facilitate the paper [*Language Model Inversion*](https://arxiv.org/abs/2311.13647). The exact solution algorithm was contributed by [Matthew Finlayson](https://mattf1n.github.io).
78 | 
79 | If you're using our algorithm in academic research, please cite our paper:
80 | 
81 | ```
82 | @misc{morris2023language,
83 | title={Language Model Inversion},
84 | author={John X. Morris and Wenting Zhao and Justin T. Chiu and Vitaly Shmatikov and Alexander M. Rush},
85 | year={2023},
86 | eprint={2311.13647},
87 | archivePrefix={arXiv},
88 | primaryClass={cs.CL}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/openlogprobs/models.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import abc
4 | import os
5 |
6 | import openai
7 | import numpy as np
8 | import tiktoken
9 |
10 | class Model(abc.ABC):
11 | """This class wraps the model API. It can take text and a logit bias and return text outputs."""
12 |
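# Concrete subclasses only need to implement vocab_size, argmax, and topk;
# extract.py recovers full logprob vectors from these three primitives alone.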
13 | @abc.abstractproperty
14 | def vocab_size(self) -> int:
15 | return -1
16 |
17 | @abc.abstractmethod
18 | def argmax(self, prefix: str, logit_bias: Dict[str, float] = {}) -> int:
19 | raise NotImplementedError
20 |
21 | @abc.abstractmethod
22 | def topk(self, prefix: str, logit_bias: Dict[str, float] = {}) -> Dict[int, float]:
23 | raise NotImplementedError
24 |
25 | def median_topk(self, k, *args, **kwargs):
26 | """Runs the same topk query multiple times and returns the median. Useful
27 | to combat API nondeterminism when calling topk()."""
28 | results = [self.topk(*args, **kwargs) for _ in range(k)]
29 | return {
30 | word: np.median([result[word] for result in results])
31 | for word in results[0].keys()
32 | }
33 | def median_argmax(self, k, *args, **kwargs):
34 | """Runs the same argmax query multiple times and returns the median. Useful
35 | to combat API nondeterminism when calling argmax()."""
36 | results = [self.argmax(*args, **kwargs) for _ in range(k)]
37 | return int(np.median(results))
38 |
39 |
40 | class OpenAIModel(Model):
41 | """Model wrapper for OpenAI API."""
42 | def __init__(self, model: str, system: Optional[str] = None):
43 | self.model = model
44 | self.encoding = tiktoken.encoding_for_model(model)
45 | self.client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
46 | self.system = system or "You are a helpful assistant."
46 |
47 | @property
48 | def vocab_size(self) -> int:
49 | return self.encoding.n_vocab
50 |
51 | def argmax(self, prefix: str, logit_bias: Dict[str, float] = {}) -> int:
52 | model = self.model
53 | system = self.system
54 | enc = self.encoding
55 | if model == "gpt-3.5-turbo-instruct":
56 | if logit_bias is not None:
57 | response = self.client.completions.create(
58 | model=model,
59 | prompt=prefix,
60 | temperature=0,
61 | max_tokens=1,
62 | logit_bias=logit_bias,
63 | n=1,
64 | )
65 | else:
66 | response = self.client.completions.create(
67 | model=model,
68 | prompt=prefix,
69 | temperature=0,
70 | max_tokens=1,
71 | n=1,
72 | )
73 | output = response.choices[0].text
74 | eos_idx = enc.encode(
75 | "<|endoftext|>", allowed_special={"<|endoftext|>", "<|im_start|>"}
76 | )[0]
77 | outputs = [choice.text for choice in response.choices]
78 | else:
79 | if logit_bias is not None:
80 | response = self.client.chat.completions.create(
81 | model=model,
82 | messages=[
83 | {"role": "system", "content": system},
84 | {"role": "user", "content": prefix},
85 | ],
86 | temperature=0,
87 | max_tokens=1,
88 | logit_bias=logit_bias,
89 | n=1,
90 | )
91 | else:
92 | response = self.client.chat.completions.create(
93 | model=model,
94 | messages=[
95 | {"role": "system", "content": system},
96 | {"role": "user", "content": prefix},
97 | ],
98 | temperature=0,
99 | max_tokens=1,
100 | n=1,
101 | )
102 | output = response.choices[0].message.content
103 | outputs = [choice.message.content for choice in response.choices]
104 | eos_idx = enc.encode(
105 | "<|endoftext|>", allowed_special={"<|endoftext|>", "<|im_start|>"}
106 | )[0]
107 |
108 | return enc.encode(output)[0]
109 |
110 | def topk(self, prefix: str, logit_bias: Dict[str, float] = {}) -> Dict[int, float]:
111 | enc = self.encoding
112 | model = self.model
113 | system = self.system
114 | if model == "gpt-3.5-turbo-instruct":
115 | if logit_bias is not None:
116 | response = self.client.completions.create(
117 | model=model,
118 | prompt=prefix,
119 | temperature=1,
120 | max_tokens=1,
121 | logit_bias=logit_bias,
122 | logprobs=5,
123 | )
124 | else:
125 | response = self.client.completions.create(
126 | model=model,
127 | prompt=prefix,
128 | temperature=1,
129 | max_tokens=1,
130 | logprobs=5,
131 | )
132 | else:
133 | raise NotImplementedError(f"Tried to get topk logprobs for: {model}")
134 | topk_dict = response.choices[0].logprobs.top_logprobs[0]
135 | return {enc.encode(x)[0]: y for x, y in topk_dict.items()}
136 |
--------------------------------------------------------------------------------
/openlogprobs/extract.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import tiktoken
3 | import numpy as np
4 | from scipy.special import logsumexp
5 | import math
6 | from functools import partial, reduce
7 | from operator import or_ as union
8 | from typing import Literal, Optional
9 |
10 | from concurrent.futures import ThreadPoolExecutor
11 |
12 | from openlogprobs.models import Model
13 | from openlogprobs.utils import batched
14 |
15 |
16 | def exact_solve(
17 | model: Model,
18 | prefix: str,
19 | idx: list[int],
20 | bias: float = 5.0,
21 | top_logprob: Optional[float] = None,
22 | ) -> tuple[dict[int, float], set[int], int]:
23 | """Parallel exact solve based on https://mattf1n.github.io/openlogprobs.html"""
24 | logit_bias = {i: bias for i in idx}
25 | topk_words = model.topk(prefix, logit_bias)
26 | if all(i in topk_words for i in idx):
27 | biased_logprobs = np.array([topk_words[i] for i in idx])
28 | log_biased_prob = logsumexp(biased_logprobs)
29 | logprobs = biased_logprobs - np.logaddexp(
30 | bias + np.log1p(-np.exp(log_biased_prob)), log_biased_prob
31 | )
32 | return dict(zip(idx, logprobs)), set(), 1
33 | else:
34 | if top_logprob is None:
35 | missing_tokens = set(idx) - set(topk_words)
36 | raise TypeError(
37 | f"Tokens {missing_tokens} not in top-k with bias {bias}. "
38 | "Either increase bias or provide top unbiased logprob (top_logprob)"
39 | )
40 | success_idxs = list(i for i in idx if i in topk_words)
41 | fail_idxs = set(idx) - set(topk_words)
42 | biased_top_logprob = max(
43 | logprob for i, logprob in topk_words.items() if i not in idx
44 | )
45 | biased_logprobs = np.array([topk_words[i] for i in success_idxs])
46 | logprobs = biased_logprobs - biased_top_logprob + top_logprob - bias
47 | return dict(zip(success_idxs, logprobs)), fail_idxs, 1
48 |
49 |
50 | def bisection_search(
51 | model: Model, prefix: str, idx: int, k=1, low=0, high=32, eps=1e-8
52 | ):
53 | # check if idx is the argmax
54 | num_calls = k
55 | if model.argmax(prefix) == idx:
56 | return 0, num_calls
57 |
58 | # initialize high
59 | logit_bias = {idx: high}
60 | while model.argmax(prefix, logit_bias) != idx:
61 | logit_bias[idx] *= 2
62 | num_calls += k
63 | high = logit_bias[idx]
64 |
65 | # improve estimate
66 | mid = (high + low) / 2
67 | while high >= low + eps:
68 | logit_bias[idx] = mid
69 | if model.argmax(prefix, logit_bias) == idx:
70 | high = mid
71 | else:
72 | low = mid
73 | mid = (high + low) / 2
74 | num_calls += k
75 | return -mid, num_calls
76 |
77 |
78 | def topk_search(model: Model, prefix: str, idx: int, k=1, high=40):
79 | # get raw topk, could be done outside and passed in
80 | topk_words = model.topk(prefix)
81 | highest_idx = list(topk_words.keys())[np.argmax(list(topk_words.values()))]
82 | if idx == highest_idx:
83 | return topk_words[idx], k
84 | num_calls = k
85 |
86 | # initialize high
87 | logit_bias = {idx: high}
88 | new_max_idx = model.argmax(prefix, logit_bias)
89 | num_calls += k
90 | while new_max_idx != idx:
91 | logit_bias[idx] *= 2
92 | new_max_idx = model.argmax(prefix, logit_bias)
93 | num_calls += k
94 | high = logit_bias[idx]
95 |
96 | output = model.topk(prefix, logit_bias)
97 | num_calls += k
98 |
99 | # compute the normalizing constant from the drop in the top token's logprob, then cancel it
100 | diff = topk_words[highest_idx] - output[highest_idx]  # how far the top token's logprob fell under the bias
101 | logZ = high - math.log(math.exp(diff) - 1)
102 | fv = output[idx] + math.log(math.exp(logZ) + math.exp(high)) - high
103 | logprob = fv - logZ  # unbiased logprob of the target token
104 |
105 | return logprob, num_calls
106 |
107 |
108 | def extract_logprobs(
109 | model: Model,
110 | prefix: str,
111 | method: Literal["bisection", "topk", "exact"] = "bisection",
112 | k: int = 5,
113 | eps: float = 1e-6,
114 | multithread: bool = False,
115 | bias: float = 5.0,
116 | parallel: bool = False,
117 | ):
118 | vocab_size = model.vocab_size
119 |
120 | if method == "exact":
121 | logprob_dict = model.topk(prefix)
122 | top_logprob = max(logprob_dict.values())
123 | bias += top_logprob - min(logprob_dict.values())
124 | remaining = set(range(vocab_size)) - set(logprob_dict)
125 | total_calls = 0
126 | if multithread:
127 | executor = ThreadPoolExecutor(max_workers=8)
128 | map_func = executor.map
129 | else:
130 | map_func = map
131 | while remaining:
132 | search_results = map_func(
133 | partial(
134 | exact_solve,
135 | model,
136 | prefix,
137 | bias=bias,
138 | top_logprob=top_logprob,
139 | ),
140 | batched(remaining, k),
141 | )
142 | logprob_dicts, skipped, calls = zip(*search_results)
143 | logprob_dict |= reduce(union, logprob_dicts)
144 | remaining = set.union(*skipped)
145 | total_calls += sum(calls)
146 | bias += 5
147 | if multithread:
148 | executor.shutdown()
149 | logprobs = np.array([logprob_dict[i] for i in range(vocab_size)])
150 | return logprobs, total_calls
151 | else:
152 | search_func = topk_search if method == "topk" else bisection_search
153 | search = partial(search_func, model, prefix, k=k)
154 | vocab = list(range(vocab_size))
155 | if multithread:
156 | with ThreadPoolExecutor(max_workers=8) as executor:
157 | search_results = executor.map(search, tqdm.tqdm(vocab))
158 | else:
159 | search_results = map(search, tqdm.tqdm(vocab))
160 | logit_list, calls = zip(*search_results)
161 | logits = np.array(logit_list)
162 | return logits - logsumexp(logits), sum(calls)
163 |
--------------------------------------------------------------------------------
/test/test_logprobs.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 | import pytest
5 | from scipy.special import log_softmax
6 | import transformers
7 |
8 | from openlogprobs import (
9 | extract_logprobs,
10 | # OpenAIModel,
11 | )
12 | from openlogprobs.extract import (
13 | bisection_search,
14 | topk_search,
15 | )
16 | from openlogprobs.models import Model
17 |
18 | prefix = "Should i take this class or not? The professor of this class is not good at all. He doesn't teach well and he is always late for class."
19 |
20 |
21 | def load_fake_logits(vocab_size: int) -> np.ndarray:
22 | np.random.seed(42)
23 | logits = np.random.randn(vocab_size)
24 | logits[1] += 10
25 | logits[12] += 20
26 | logits[13] += 30
27 | logits[24] += 30
28 | logits[35] += 30
29 | return logits
30 |
31 |
32 | class FakeModel(Model):
33 | """Represents a fake API with a temperature of 1. Used for testing."""
34 |
35 | def __init__(self, vocab_size: int = 100, get_logits=None):
36 | self.tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
37 | self.fake_vocab_size = vocab_size
38 | if get_logits is None:
39 | self.logits = load_fake_logits(self.vocab_size)[:vocab_size]
40 | else:
41 | self.logits = get_logits(vocab_size)
42 |
43 | @property
44 | def vocab_size(self):
45 | return self.fake_vocab_size
46 |
47 | def _idx_to_str(self, idx: int) -> str:
48 | return self.tokenizer.decode([idx], skip_special_tokens=True)
49 |
50 | def _add_logit_bias(self, logit_bias: Dict[str, float]) -> np.ndarray:
51 | logits = self.logits.copy()
52 | for token_idx, bias in logit_bias.items():
53 | logits[token_idx] += bias
54 | logits = logits.astype(np.double)
55 | return log_softmax(logits)
56 |
57 | def argmax(self, prefix: str, logit_bias: Dict[str, float] = {}) -> int:
58 | logits = self._add_logit_bias(logit_bias)
59 | return logits.argmax()
60 |
61 | def topk(self, prefix: str, logit_bias: Dict[str, float] = {}) -> Dict[int, float]:
62 | k = 5 # TODO: what topk?
63 | logits = self._add_logit_bias(logit_bias)
64 | topk = logits.argsort()[-k:]
65 | return {k: logits[k] for k in topk}
66 |
67 |
68 | @pytest.fixture
69 | def model():
70 | # return OpenAIModel("gpt-3.5-turbo-instruct")
71 | return FakeModel()
72 |
73 |
74 | @pytest.fixture
75 | def uniform_model():
76 | # return OpenAIModel("gpt-3.5-turbo-instruct")
77 | return FakeModel(get_logits=np.ones)
78 |
79 |
80 | @pytest.fixture
81 | def topk_words(model):
82 | return model.topk(prefix)
83 |
84 |
85 | def test_bisection(model, topk_words):
86 | true_sorted_logprobs = np.array(sorted(topk_words.values()))
87 | true_diffs = true_sorted_logprobs - true_sorted_logprobs.max()
88 |
89 | estimated_diffs = {
90 | word: bisection_search(model, prefix, word) for word in topk_words.keys()
91 | }
92 | estimated_diffs = np.array(sorted([x[0] for x in estimated_diffs.values()]))
93 | assert np.allclose(true_diffs, estimated_diffs, atol=1e-5)
94 |
95 |
96 | def test_topk(model, topk_words):
97 | true_probs = np.array(sorted(topk_words.values()))
98 |
99 | estimated_probs = {
100 | word: topk_search(model, prefix, word) for word in topk_words.keys()
101 | }
102 | estimated_probs = np.array(sorted([x[0] for x in estimated_probs.values()]))
103 | assert np.allclose(true_probs, estimated_probs, atol=1e-5)
104 |
105 |
106 | def test_topk_consistency(model, topk_words):
107 | true_probs = np.array(sorted(topk_words.values()))
108 |
109 | probs = []
110 | for _trial in range(10):
111 | estimated_probs = {
112 | word: topk_search(model, prefix, word) for word in topk_words.keys()
113 | }
114 | estimated_probs = np.array(sorted([x[0] for x in estimated_probs.values()]))
115 | probs.append(estimated_probs)
116 | probs = np.stack(probs)
117 | assert np.allclose(true_probs, np.median(probs, 0), atol=1e-5)
118 |
119 |
120 | def test_extract_topk(model):
121 | true_logprobs = log_softmax(model.logits)
122 | extracted_logprobs, num_calls = extract_logprobs(
123 | model, prefix="test", method="topk", multithread=False, k=1
124 | )
125 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
126 | assert num_calls == 298
127 |
128 |
129 | def test_extract_bisection(model):
130 | true_logprobs = log_softmax(model.logits)
131 | extracted_logprobs, num_calls = extract_logprobs(
132 | model, prefix="test", method="bisection", multithread=False, k=1
133 | )
134 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
135 | assert num_calls == 3270
136 |
137 |
138 | def test_extract_exact(model):
139 | true_logprobs = log_softmax(model.logits)
140 | extracted_logprobs, num_calls = extract_logprobs(
141 | model, prefix="test", method="exact", multithread=False
142 | )
143 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
144 | assert num_calls < len(true_logprobs)
145 |
146 |
147 | def test_extract_exact_parallel(model):
148 | true_logprobs = log_softmax(model.logits)
149 | extracted_logprobs, num_calls = extract_logprobs(
150 | model,
151 | prefix="test",
152 | method="exact",
153 | multithread=False,
154 | parallel=True,
155 | )
156 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
157 | assert num_calls < len(true_logprobs)
158 |
159 |
160 | def test_extract_topk_multithread(model):
161 | true_logprobs = log_softmax(model.logits)
162 | extracted_logprobs, num_calls = extract_logprobs(
163 | model, prefix="test", method="topk", multithread=True, k=1
164 | )
165 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
166 | assert num_calls == 298
167 |
168 |
169 | def test_extract_exact_multithread(model):
170 | true_logprobs = log_softmax(model.logits)
171 | extracted_logprobs, num_calls = extract_logprobs(
172 | model, prefix="test", method="exact", multithread=True
173 | )
174 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
175 | assert num_calls < len(true_logprobs)
176 |
177 |
178 | def test_extract_exact_parallel_multithread(model):
179 | true_logprobs = log_softmax(model.logits)
180 | extracted_logprobs, num_calls = extract_logprobs(
181 | model, prefix="test", method="exact", multithread=True, parallel=True
182 | )
183 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
184 | assert num_calls < len(true_logprobs)
185 |
186 |
187 | def test_extract_exact_parallel_multithread_uniform(uniform_model):
188 | true_logprobs = log_softmax(uniform_model.logits)
189 | extracted_logprobs, num_calls = extract_logprobs(
190 | uniform_model,
191 | prefix="test",
192 | method="exact",
193 | parallel=True,
194 | )
195 | np.testing.assert_allclose(true_logprobs, extracted_logprobs)
196 | assert num_calls < len(true_logprobs)
197 |
--------------------------------------------------------------------------------