56 | git clone git@github.com:YOUR_NAME/structured-logprobs.git
57 | ```
58 |
59 | 3. Install the development environment. Navigate into the project directory:
60 |
61 | ```bash
62 | cd structured-logprobs
63 | ```
64 |
65 | Then, install the dependencies with:
66 |
67 | ```bash
68 | uv sync
69 | ```
70 |
71 | 4. Install pre-commit to run linters/formatters at commit time:
72 |
73 | ```bash
74 | uv run pre-commit install
75 | ```
76 |
77 | 5. Create a branch for local development:
78 |
79 | ```bash
80 | git checkout -b name-of-your-bugfix-or-feature
81 | ```
82 |
83 | Now you can make your changes locally.
84 |
85 | 6. Don't forget to add test cases for your added functionality to the `tests` directory.
86 |
87 | 7. When you're done making changes, check that your changes pass the formatting tests.
88 |
89 | ```bash
90 | make check
91 | ```
92 |
93 | 8. Validate that all unit tests pass:
94 |
95 | ```bash
96 | make test
97 | ```
98 |
99 | 9. Before raising a pull request, you should also run tox.
100 | This will run the tests across different versions of Python:
101 |
102 | ```bash
103 | tox
104 | ```
105 |
106 | This requires multiple versions of Python to be installed locally.
107 | This step also runs in the CI/CD pipeline, so you may choose to skip it locally.
108 |
109 | 10. Commit your changes and push your branch to GitHub:
110 |
111 | ```bash
112 | git add .
113 | git commit -m "Your detailed description of your changes."
114 | git push origin name-of-your-bugfix-or-feature
115 | ```
116 |
117 | 11. Submit a pull request through the GitHub website.
118 |
119 | # Pull Request Guidelines
120 |
121 | Before you submit a pull request, check that it meets these guidelines:
122 |
123 | 1. The pull request should include tests.
124 |
125 | 2. If the pull request adds functionality, the docs should be updated.
126 | Put your new functionality into a function with a docstring, and add the feature to the list in `README.md`.
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | [](https://github.com/arena-ai/structured-logprobs/actions/workflows/main.yml)
4 | [](https://github.com/arena-ai/structured-logprobs/actions/workflows/on-release-main.yml)
5 |
6 | 
7 |
8 | This Python library is designed to enhance OpenAI chat completion responses by adding detailed information about token log probabilities.
9 | This library works with OpenAI [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs), which is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema, so you don't need to worry about the model omitting a required key, or hallucinating an invalid enum value.
10 | It provides utilities to analyze and incorporate token-level log probabilities into structured outputs, helping developers understand the reliability of structured data extracted from OpenAI models.
11 |
12 | ## Objective
13 |
14 | 
15 |
16 | The primary goal of **structured-logprobs** is to provide insights into the reliability of extracted data. By analyzing token-level log probabilities, the library helps assess how likely each value generated from an LLM's structured outputs is.
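
Since these are natural-log probabilities, a value close to 0 means the model was nearly certain of the generated value. A log probability can be converted back into a plain probability with `math.exp`; a minimal sketch, using an illustrative value:

```python
import math

# A token log-probability, as would appear in the `log_probs` field
# (illustrative value, not taken from a real response)
logprob = -0.0033997903

# exp(logprob) recovers the model's probability for that value
probability = math.exp(logprob)
print(round(probability, 4))  # → 0.9966
```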
17 |
18 | ## Key Features
19 |
20 | The module contains a function for mapping characters to token indices (`map_characters_to_token_indices`) and two functions for incorporating log probabilities:
21 |
22 | 1. Adding log probabilities as a separate field in the response (`add_logprobs`).
23 | 2. Embedding log probabilities inline within the message content (`add_logprobs_inline`).
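
The character-to-token mapping repeats each token's index once per character of that token's text, so every character of the concatenated JSON output can be traced back to the token that produced it. A minimal sketch of the idea on plain strings (the library itself operates on `ChatCompletionTokenLogprob` objects; the helper name below is hypothetical):

```python
def char_to_token_indices(tokens: list[str]) -> list[int]:
    """Map each character of the concatenated output to the index of its source token."""
    indices: list[int] = []
    for token_idx, token_text in enumerate(tokens):
        # One entry per character of this token's text
        indices.extend([token_idx] * len(token_text))
    return indices

print(char_to_token_indices(['{"', 'a', '": 1}']))  # → [0, 0, 1, 2, 2, 2, 2, 2]
```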
24 |
25 | ## Example
26 |
27 | To use this library, first create a chat completion response with the OpenAI Python SDK, then enhance the response with log probabilities.
28 | Here is an example of how to do that:
29 |
30 | ```python
31 | import json
32 | from openai import OpenAI
33 | from openai.types import ResponseFormatJSONSchema
34 | from structured_logprobs import add_logprobs, add_logprobs_inline
35 |
36 | # Initialize the OpenAI client
37 | client = OpenAI(api_key="your-api-key")
38 |
39 | schema_path = "path-to-your-json-schema"
40 | with open(schema_path) as f:
41 |     schema_content = json.load(f)
42 |
43 | # Validate the schema content
44 | response_schema = ResponseFormatJSONSchema.model_validate(schema_content)
45 | # Create a chat completion request
46 | completion = client.chat.completions.create(
47 |     model="gpt-4o-2024-08-06",
48 |     messages=[
49 |         {
50 |             "role": "system",
51 |             "content": (
52 |                 "I have three questions. The first question is: What is the capital of France? "
53 |                 "The second question is: Which are the two nicest colors? "
54 |                 "The third question is: Can you roll a die and tell me which number comes up?"
55 |             ),
56 |         }
57 |     ],
58 |     logprobs=True,
59 |     response_format=response_schema.model_dump(by_alias=True),
60 | )
61 |
62 | chat_completion = add_logprobs(completion)
63 | chat_completion_inline = add_logprobs_inline(completion)
64 | print(chat_completion.log_probs[0])
65 | # {'capital_of_France': -5.5122365e-07, 'the_two_nicest_colors': [-0.0033997903, -0.011364183612649998], 'die_shows': -0.48048785}
66 | print(chat_completion_inline.choices[0].message.content)
67 | # {"capital_of_France": "Paris", "capital_of_France_logprob": -6.704273e-07, "the_two_nicest_colors": ["blue", "green"], "die_shows": 5.0, "die_shows_logprob": -2.3782086}
68 | ```
69 |
70 | ## Example JSON Schema
71 |
72 | The `response_format` parameter of the request body is an object specifying the format that the model must output. Setting it to `{ "type": "json_schema", "json_schema": {...} }` ensures the model's output will match your supplied [JSON schema](https://json-schema.org/overview/what-is-jsonschema).
73 |
74 | Below is an example of a JSON file that defines the schema used for validating the responses.
75 |
76 | ```json
77 | {
78 |   "type": "json_schema",
79 |   "json_schema": {
80 |     "name": "answers",
81 |     "description": "Response to questions in JSON format",
82 |     "schema": {
83 |       "type": "object",
84 |       "properties": {
85 |         "capital_of_France": { "type": "string" },
86 |         "the_two_nicest_colors": {
87 |           "type": "array",
88 |           "items": {
89 |             "type": "string",
90 |             "enum": ["red", "blue", "green", "yellow", "purple"]
91 |           }
92 |         },
93 |         "die_shows": { "type": "number" }
94 |       },
95 |       "required": ["capital_of_France", "the_two_nicest_colors", "die_shows"],
96 |       "additionalProperties": false
97 |     },
98 |     "strict": true
99 |   }
100 | }
101 | ```
102 |
--------------------------------------------------------------------------------
/structured_logprobs/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any
3 |
4 | from openai.types.chat.chat_completion import ChatCompletion
5 | from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
6 | from pydantic import BaseModel
7 |
8 | from structured_logprobs.helpers import extract_json_data, extract_json_data_inline
9 |
10 | MISSING_LOGPROBS_MESSAGE = "The 'logprobs' field is missing"
11 |
12 | """
13 |
14 | This module provides utilities to work with OpenAI chat completion responses,
15 | enhancing them by embedding log probabilities into the data.
16 | The module contains a function for mapping characters to token indices (`map_characters_to_token_indices`) and two methods for incorporating log probabilities:
17 | 1. Adding log probabilities as a separate field in the response (`add_logprobs`).
18 | 2. Embedding log probabilities inline within the message content (`add_logprobs_inline`).
19 |
20 | Classes:
21 | - ChatCompletionWithLogProbs: Represents a chat completion response with added log probabilities.
22 |
23 | """
24 |
25 |
26 | class ChatCompletionWithLogProbs(BaseModel):
27 | value: ChatCompletion
28 | log_probs: list[Any]
29 |
30 |
31 | def map_characters_to_token_indices(extracted_data_token: list[ChatCompletionTokenLogprob]) -> list[int]:
32 | """
33 | Maps each character in the JSON string output to its corresponding token index.
34 |
35 | Args:
36 | extracted_data_token : A list of `TokenLogprob` objects, where each object represents a token and its associated data.
37 |
38 | Returns:
39 | A list of integers where each position corresponds to a character in the concatenated JSON string,
40 | and the integer at each position is the index of the token responsible for generating that specific character.
41 | Example:
42 | >>> tokens = [ChatCompletionTokenLogprob(token='{'),
43 | ChatCompletionTokenLogprob(token='"key1"'),
44 | ChatCompletionTokenLogprob(token=': '),
45 | ChatCompletionTokenLogprob(token='"value1"'),
46 | ChatCompletionTokenLogprob(token='}')]
47 | >>> map_characters_to_token_indices(tokens)
48 | [0, 1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4]
49 | """
50 |
51 | token_indices = []
52 |
53 | for token_idx, token_data in enumerate(extracted_data_token):
54 | token_text = token_data.token
55 | token_indices.extend([token_idx] * len(token_text))
56 |
57 | return token_indices
58 |
59 |
60 | def add_logprobs(chat_completion_response: ChatCompletion) -> ChatCompletionWithLogProbs:
61 | """
62 | Adds log probabilities to the chat completion response and returns a
63 | ChatCompletionWithLogProbs object.
64 |
65 | Args:
66 | chat_completion_response: The OpenAI chat completion response.
67 |
68 | Returns:
69 | An object containing:
70 | - The original chat completion response.
71 | - A `log_probs` field, structured like the message.content of the response,
72 | where values are replaced with their respective log-probabilities.
73 | Raises:
74 | AttributeError: If any 'choice' in the response does not contain 'logprobs'.
75 |
76 | """
77 |
78 | logprobs_data = []
79 | for choice in chat_completion_response.choices:
80 | # Check if the 'logprobs' field is present
81 | if hasattr(choice, "logprobs") and choice.logprobs is not None and choice.logprobs.content is not None:
82 | extracted_data = choice.message.content
83 | logprobs_list = choice.logprobs.content
84 | token_indices = map_characters_to_token_indices(logprobs_list) if logprobs_list else []
85 | json_dict = extract_json_data(extracted_data, logprobs_list, token_indices) if extracted_data else {}
86 | logprobs_data.append(json_dict)
87 | else:
88 | raise AttributeError(MISSING_LOGPROBS_MESSAGE)
89 |
90 | chat_completion_with_logprobs = ChatCompletionWithLogProbs(value=chat_completion_response, log_probs=logprobs_data)
91 | return chat_completion_with_logprobs
92 |
93 |
94 | def add_logprobs_inline(chat_completion_response: ChatCompletion) -> ChatCompletion:
95 | """
96 | Embeds inline log probabilities into the content of the message in the chat completion response.
97 |
98 | Args:
99 | chat_completion_response: The OpenAI chat completion response.
100 |
101 | Returns:
102 | ChatCompletion: The modified chat completion response object, where the content of the message
103 | is replaced with a JSON string that also includes inline log probabilities for atomic values.
104 |
105 | Raises:
106 | AttributeError: If the 'logprobs' field is not present in the response.
107 | """
108 |
109 | for choice in chat_completion_response.choices:
110 | # Check if the 'logprobs' field is present
111 | if hasattr(choice, "logprobs") and choice.logprobs is not None and choice.logprobs.content is not None:
112 | extracted_data = choice.message.content
113 | logprobs_list = choice.logprobs.content
114 | token_indices = map_characters_to_token_indices(logprobs_list) if logprobs_list else []
115 | json_dict = extract_json_data_inline(extracted_data, logprobs_list, token_indices) if extracted_data else {}
116 | choice.message.content = json.dumps(json_dict)
117 | else:
118 | raise AttributeError(MISSING_LOGPROBS_MESSAGE)
119 |
120 | return chat_completion_response
121 |
122 |
123 | if __name__ == "__main__": # pragma: no cover
124 | pass
125 |
--------------------------------------------------------------------------------
/tests/resources/simple_parsed_completion.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "chatcmpl-AigSM81aLFRN07IUezlC3zlZ9zdaU",
3 | "choices": [
4 | {
5 | "finish_reason": "stop",
6 | "index": 0,
7 | "logprobs": {
8 | "content": [
9 | {
10 | "token": "{\"",
11 | "bytes": [123, 34],
12 | "logprob": -0.000012590794,
13 | "top_logprobs": []
14 | },
15 | {
16 | "token": "name",
17 | "bytes": [110, 97, 109, 101],
18 | "logprob": 0.0,
19 | "top_logprobs": []
20 | },
21 | {
22 | "token": "\":\"",
23 | "bytes": [34, 58, 34],
24 | "logprob": -6.704273e-7,
25 | "top_logprobs": []
26 | },
27 | {
28 | "token": "Science",
29 | "bytes": [83, 99, 105, 101, 110, 99, 101],
30 | "logprob": -0.00012964146,
31 | "top_logprobs": []
32 | },
33 | {
34 | "token": " Fair",
35 | "bytes": [32, 70, 97, 105, 114],
36 | "logprob": -0.000058603408,
37 | "top_logprobs": []
38 | },
39 | {
40 | "token": "\",\"",
41 | "bytes": [34, 44, 34],
42 | "logprob": -0.0018461747,
43 | "top_logprobs": []
44 | },
45 | {
46 | "token": "date",
47 | "bytes": [100, 97, 116, 101],
48 | "logprob": 0.0,
49 | "top_logprobs": []
50 | },
51 | {
52 | "token": "\":\"",
53 | "bytes": [34, 58, 34],
54 | "logprob": -4.9617593e-6,
55 | "top_logprobs": []
56 | },
57 | {
58 | "token": "Friday",
59 | "bytes": [70, 114, 105, 100, 97, 121],
60 | "logprob": -0.09504829,
61 | "top_logprobs": []
62 | },
63 | {
64 | "token": "\",\"",
65 | "bytes": [34, 44, 34],
66 | "logprob": -0.0026011032,
67 | "top_logprobs": []
68 | },
69 | {
70 | "token": "participants",
71 | "bytes": [
72 | 112, 97, 114, 116, 105, 99, 105, 112, 97, 110, 116,
73 | 115
74 | ],
75 | "logprob": -1.9361265e-7,
76 | "top_logprobs": []
77 | },
78 | {
79 | "token": "\":[\"",
80 | "bytes": [34, 58, 91, 34],
81 | "logprob": 0.0,
82 | "top_logprobs": []
83 | },
84 | {
85 | "token": "Alice",
86 | "bytes": [65, 108, 105, 99, 101],
87 | "logprob": 0.0,
88 | "top_logprobs": []
89 | },
90 | {
91 | "token": "\",\"",
92 | "bytes": [34, 44, 34],
93 | "logprob": -1.2664457e-6,
94 | "top_logprobs": []
95 | },
96 | {
97 | "token": "Bob",
98 | "bytes": [66, 111, 98],
99 | "logprob": -7.89631e-7,
100 | "top_logprobs": []
101 | },
102 | {
103 | "token": "\"]",
104 | "bytes": [34, 93],
105 | "logprob": -1.9361265e-7,
106 | "top_logprobs": []
107 | },
108 | {
109 | "token": "}",
110 | "bytes": [125],
111 | "logprob": 0.0,
112 | "top_logprobs": []
113 | }
114 | ],
115 | "refusal": null
116 | },
117 | "message": {
118 | "content": "{\"name\":\"Science Fair\",\"date\":\"Friday\",\"participants\":[\"Alice\",\"Bob\"]}",
119 | "refusal": null,
120 | "role": "assistant",
121 | "audio": null,
122 | "function_call": null,
123 | "tool_calls": [],
124 | "parsed": {
125 | "name": "Science Fair",
126 | "date": "Friday",
127 | "participants": ["Alice", "Bob"]
128 | }
129 | }
130 | }
131 | ],
132 | "created": 1735212998,
133 | "model": "gpt-4o-2024-08-06",
134 | "object": "chat.completion",
135 | "service_tier": null,
136 | "system_fingerprint": "fp_5f20662549",
137 | "usage": {
138 | "completion_tokens": 18,
139 | "prompt_tokens": 92,
140 | "total_tokens": 110,
141 | "completion_tokens_details": {
142 | "accepted_prediction_tokens": 0,
143 | "audio_tokens": 0,
144 | "reasoning_tokens": 0,
145 | "rejected_prediction_tokens": 0
146 | },
147 | "prompt_tokens_details": {
148 | "audio_tokens": 0,
149 | "cached_tokens": 0
150 | }
151 | }
152 | }
153 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Any
5 |
6 | import pytest
7 | from openai import OpenAI
8 | from openai.types import ResponseFormatJSONSchema
9 | from openai.types.chat.chat_completion import ChatCompletion
10 | from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
11 | from pydantic import BaseModel
12 |
13 |
14 | class CalendarEvent(BaseModel):
15 | name: str
16 | date: str | None
17 | participants: list[str]
18 |
19 |
20 | @pytest.fixture
21 | def chat_completion(pytestconfig) -> ChatCompletion:
22 | client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
23 | base_path = Path(pytestconfig.rootdir) # Base directory where pytest was run
24 | schema_path = base_path / "tests" / "resources" / "questions_json_schema.json"
25 | with open(schema_path) as f:
26 | schema_content = json.load(f)
27 |
28 | # Validate the schema content
29 | response_schema = ResponseFormatJSONSchema.model_validate(schema_content)
30 |
31 | completion = client.chat.completions.create(
32 | model="gpt-4o-2024-08-06",
33 | messages=[
34 | {
35 | "role": "system",
36 | "content": (
37 | "I have three questions. The first question is: What is the capital of France? "
38 | "The second question is: Which are the two nicest colors? "
39 | "The third question is: Can you roll a die and tell me which number comes up?"
40 | ),
41 | }
42 | ],
43 | logprobs=True,
44 | # Serialize using alias names to match OpenAI API's expected format.
45 | # This ensures that the field 'schema_' is serialized as 'schema' to meet the API's naming conventions.
46 | response_format=response_schema.model_dump(by_alias=True),
47 | )
48 | return completion
49 |
50 |
51 | @pytest.fixture
52 | def parsed_chat_completion() -> ParsedChatCompletion:
53 | client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
54 |
55 | # A simple data model
56 | class CalendarEvent(BaseModel):
57 | name: str
58 | date: str
59 | participants: list[str]
60 |
61 | # A request with structured output
62 | completion = client.beta.chat.completions.parse(
63 | model="gpt-4o-2024-08-06",
64 | messages=[
65 | {"role": "system", "content": "Extract the event information."},
66 | {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
67 | ],
68 | logprobs=True,
69 | response_format=CalendarEvent,
70 | )
71 | return completion
72 |
73 |
74 | @pytest.fixture
75 | def simple_parsed_completion(pytestconfig) -> ParsedChatCompletion[CalendarEvent] | None:
76 | base_path = Path(pytestconfig.rootdir) # Base directory where pytest was run
77 | with open(base_path / "tests" / "resources" / "simple_parsed_completion.json") as f:
78 | return ParsedChatCompletion[CalendarEvent].model_validate_json(f.read())
80 |
81 |
82 | @pytest.fixture
83 | def json_output() -> dict[str, Any]:
84 | return {"name": -0.0001889152953, "date": -0.09505325175929999, "participants": [0.0, -2.0560767000000003e-06]}
85 |
86 |
87 | @pytest.fixture
88 | def json_output_inline() -> str:
89 | return json.dumps({
90 | "name": "Science Fair",
91 | "name_logprob": -0.0001889152953,
92 | "date": "Friday",
93 | "date_logprob": -0.09505325175929999,
94 | "participants": ["Alice", "Bob"],
95 | })
96 |
97 |
98 | class TokenLogprob:
99 | def __init__(self, token: str, logprob: float):
100 | self.token = token
101 | self.logprob = logprob
102 |
103 |
104 | @pytest.fixture
105 | def data_token() -> list[TokenLogprob]:
106 | return [
107 | TokenLogprob(token="{", logprob=-1.9365e-07), # Token index 0
108 | TokenLogprob(token='"a"', logprob=-0.01117), # Token index 1
109 | TokenLogprob(token=': "', logprob=-0.00279), # Token index 2
110 | TokenLogprob(token="he", logprob=-1.1472e-06), # Token index 3
111 | TokenLogprob(token='llo"', logprob=-0.00851), # Token index 4
112 | TokenLogprob(token=', "', logprob=-0.00851), # Token index 5
113 | TokenLogprob(token="b", logprob=-0.00851), # Token index 6
114 | TokenLogprob(token='": ', logprob=-0.00851), # Token index 7
115 | TokenLogprob(token="12", logprob=-0.00851), # Token index 8
116 | TokenLogprob(token=', "', logprob=-1.265e-07), # Token index 9
117 | TokenLogprob(token='c"', logprob=-0.00851), # Token index 10
118 | TokenLogprob(token=': [{"', logprob=-0.00851), # Token index 11
119 | TokenLogprob(token="d", logprob=-1.265e-07), # Token index 12
120 | TokenLogprob(token='":', logprob=-0.00851), # Token index 13
121 | TokenLogprob(token="42", logprob=-0.00851), # Token index 14
122 | TokenLogprob(token="}, ", logprob=-1.265e-07), # Token index 15
123 | TokenLogprob(token="11", logprob=-0.00851), # Token index 16
124 | TokenLogprob(token="]}", logprob=-1.265e-07), # Token index 17
125 | ]
126 |
127 |
128 | @pytest.fixture
129 | def token_indices() -> list[int]:
130 | return [
131 | 0,
132 | 1,
133 | 1,
134 | 1,
135 | 2,
136 | 2,
137 | 2,
138 | 3,
139 | 3,
140 | 4,
141 | 4,
142 | 4,
143 | 4,
144 | 5,
145 | 5,
146 | 5,
147 | 6,
148 | 7,
149 | 7,
150 | 7,
151 | 8,
152 | 8,
153 | 9,
154 | 9,
155 | 9,
156 | 10,
157 | 10,
158 | 11,
159 | 11,
160 | 11,
161 | 11,
162 | 11,
163 | 12,
164 | 13,
165 | 13,
166 | 14,
167 | 14,
168 | 15,
169 | 15,
170 | 15,
171 | 16,
172 | 16,
173 | 17,
174 | 17,
175 | ]
176 |
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 | import pytest
6 | from dotenv import load_dotenv
7 | from openai import OpenAI
8 | from openai.types import ResponseFormatJSONSchema
9 | from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
10 |
11 | from structured_logprobs.main import add_logprobs, add_logprobs_inline, map_characters_to_token_indices
12 |
13 | load_dotenv()
14 |
15 |
16 | def test_map_characters_to_token_indices(data_token, token_indices):
17 | result = map_characters_to_token_indices(data_token)
18 |
19 | assert result == token_indices
20 | assert result.count(1) == len(data_token[1].token)
21 |
22 |
23 | @pytest.mark.skip(reason="We do not want to automate this as no OPENAI_API_KEY is on github yet")
24 | def test_add_logprobs_with_openai(chat_completion):
25 | completion = add_logprobs(chat_completion)
26 | assert list(completion.log_probs[0].keys()) == ["capital_of_France", "the_two_nicest_colors", "die_shows"]
27 | assert isinstance(list(completion.log_probs[0].values())[0], float)
28 | assert isinstance(list(completion.log_probs[0].values())[1], list)
29 | assert isinstance(list(completion.log_probs[0].values())[1][0], float)
30 | assert isinstance(list(completion.log_probs[0].values())[2], float)
31 |
32 |
33 | @pytest.mark.skip(reason="We do not want to automate this as no OPENAI_API_KEY is on github yet")
34 | def test_add_logprobs_inline_with_openai(chat_completion):
35 | completion_inline = add_logprobs_inline(chat_completion)
36 | message_content = json.loads(completion_inline.choices[0].message.content)
37 | assert list(message_content.keys()) == [
38 | "capital_of_France",
39 | "capital_of_France_logprob",
40 | "the_two_nicest_colors",
41 | "die_shows",
42 | "die_shows_logprob",
43 | ]
44 | assert json.loads(completion_inline.choices[0].message.content)["capital_of_France"] == "Paris"
45 | assert isinstance(list(message_content.values())[0], str)
46 | assert isinstance(list(message_content.values())[1], float)
47 | assert isinstance(list(message_content.values())[2], list)
48 | assert isinstance(list(message_content.values())[2][1], str)
49 | assert isinstance(list(message_content.values())[3], float)
50 | assert isinstance(list(message_content.values())[4], float)
51 |
52 |
53 | @pytest.mark.skip(reason="We do not want to automate this as no OPENAI_API_KEY is on github yet")
54 | def test_generic_completion_with_openai(pytestconfig, json_output):
55 | client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
56 | base_path = Path(pytestconfig.rootdir) # Base directory where pytest was run
57 | schema_path = base_path / "tests" / "resources" / "simple_json_schema.json"
58 | with open(schema_path) as f:
59 | schema_content = json.load(f)
60 |
61 | # Validate the schema content
62 | response_schema = ResponseFormatJSONSchema.model_validate(schema_content)
63 |
64 | completion = client.chat.completions.create(
65 | model="gpt-4o-2024-08-06",
66 | messages=[
67 | {"role": "system", "content": "Extract the event information."},
68 | {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
69 | ],
70 | logprobs=True,
71 | # Serialize using alias names to match OpenAI API's expected format.
72 | # This ensures that the field 'schema_' is serialized as 'schema' to meet the API's naming conventions.
73 | response_format=response_schema.model_dump(by_alias=True),
74 | )
75 | chat_completion = add_logprobs(completion)
76 | _ = add_logprobs_inline(completion)
77 | assert list(chat_completion.log_probs[0].keys()) == list(json_output.keys())
78 |
79 |
80 | @pytest.mark.skip(reason="We do not want to automate this as no OPENAI_API_KEY is on github yet")
81 | def test_add_logprobs_parsed_completion_with_openai(parsed_chat_completion, json_output):
82 | completion = add_logprobs(parsed_chat_completion)
83 | event = completion.value.choices[0].message.parsed
84 | assert event.name == "Science Fair"
85 | assert list(completion.log_probs[0].keys()) == list(json_output.keys())
86 | assert type(list(completion.log_probs[0].values())[0]) is type(list(json_output.values())[0])
87 | assert type(list(completion.log_probs[0].values())[1]) is type(list(json_output.values())[1])
88 | assert type(list(completion.log_probs[0].values())[2]) is type(list(json_output.values())[2])
89 | assert type(list(completion.log_probs[0].values())[2][1]) is type(list(json_output.values())[2][1])
90 |
91 |
92 | @pytest.mark.skip(reason="We do not want to automate this as no OPENAI_API_KEY is on github yet")
93 | def test_add_logprobs_inline_parsed_completion_with_openai(parsed_chat_completion, json_output_inline):
94 | completion_inline = add_logprobs_inline(parsed_chat_completion)
95 | message_content = json.loads(completion_inline.choices[0].message.content)
96 | assert list(message_content.keys()) == list(json.loads(json_output_inline).keys())
97 | assert list(message_content.values())[0] == "Science Fair"
98 | assert isinstance(list(message_content.values())[1], float)
99 | assert list(message_content.values())[2] == "Friday"
100 | assert isinstance(list(message_content.values())[3], float)
101 | assert list(message_content.values())[4] == ["Alice", "Bob"]
102 |
103 |
104 | def test_add_logprobs(simple_parsed_completion, json_output):
105 | completion = add_logprobs(simple_parsed_completion)
106 | if isinstance(completion.value, ParsedChatCompletion):
107 | event = completion.value.choices[0].message.parsed
108 | assert event.name == "Science Fair"
109 | assert completion.log_probs[0] == json_output
110 |
111 |
112 | def test_add_logprobs_inline(simple_parsed_completion, json_output_inline):
113 | completion = add_logprobs_inline(simple_parsed_completion)
114 | if isinstance(completion, ParsedChatCompletion):
115 | event = completion.choices[0].message.parsed
116 | assert event.name == "Science Fair"
117 | assert completion.choices[0].message.content == json_output_inline
118 |
--------------------------------------------------------------------------------
/structured_logprobs/helpers.py:
--------------------------------------------------------------------------------
1 | from typing import Any, TypeAlias
2 |
3 | from lark import Lark, Token, Transformer_NonRecursive, Tree, v_args
4 | from lark.tree import Meta
5 | from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
6 | from pydantic import BaseModel
7 |
8 | PyTree: TypeAlias = Any # a tree-like structure built out of container-like Python objects.
9 |
10 |
11 | class HasProb(BaseModel):
12 | value: Any
13 | start: int
14 | end: int
15 | logprob: float
16 |
17 |
18 | # Define a grammar for JSON
19 | json_grammar = r"""
20 | start: value
21 |
22 | ?value: object #'?' is a Lark convention indicating that the rule can return the value directly instead of creating a separate parse tree node.
23 | | array
24 | | string
25 | | SIGNED_NUMBER -> number #'-> number' specifies an alias for the rule
26 | | true
27 | | false
28 | | null
29 |
30 | true: "true"
31 | false: "false"
32 | null: "null"
33 | array : "[" [value ("," value)*] "]"
34 | object : "{" [pair ("," pair)*] "}"
35 | pair : key ":" value
36 | key : ESCAPED_STRING
37 |
38 | string : ESCAPED_STRING
39 |
40 | %import common.ESCAPED_STRING
41 | %import common.SIGNED_NUMBER
42 | %import common.WS
43 | %ignore WS
44 | """
45 |
46 |
47 | # Transformer that processes the tree and substitutes each atomic value with the cumulative log-probability of its tokens
48 | @v_args(meta=True)
49 | class Extractor(Transformer_NonRecursive):
50 | def __init__(self, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]):
51 | super().__init__()
52 | self.tokens = tokens
53 | self.token_indices = token_indices
54 |
55 | def _compute_logprob_sum(self, start: int, end: int) -> float:
56 | token_start = self.token_indices[start]
57 | token_end = self.token_indices[end]
58 | sum_logprob = sum(self.tokens[i].logprob for i in range(token_start, token_end))
59 | return sum_logprob
60 |
61 | def number(self, meta: Meta, children: list[Token]) -> float:
62 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
63 | return logprob_sum
64 |
65 | def string(self, meta: Meta, children: list[Token]) -> float:
66 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
67 | return logprob_sum
68 |
69 | def true(self, meta: Meta, children: list[Token]) -> float:
70 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
71 | return logprob_sum
72 |
73 | def false(self, meta: Meta, children: list[Token]) -> float:
74 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
75 | return logprob_sum
76 |
77 | def null(self, meta: Meta, children: list[Token]) -> float:
78 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
79 | return logprob_sum
80 |
81 | def array(self, meta: Meta, children: list[Any]) -> list[float]:
82 | return children
83 |
84 | def object(self, meta: Meta, children: list[tuple[str, Any]]) -> dict[str, Any]:
85 | result = {}
86 | for key, value in children:
87 | result[key] = value
88 | return result
89 |
90 | def pair(self, meta: Meta, children: list[Any]) -> tuple[str, Any]:
91 | value = children[1]
92 | key = children[0]
93 | if isinstance(value, Tree) and not value.children: # ['b', Tree(Token('RULE', 'value'), [])]
94 | value = None
95 | return key, value
96 |
97 | def key(self, meta: Meta, children: list[Token]) -> str:
98 | return children[0][1:-1]
99 |
100 | def start(self, meta: Meta, children: list[dict[str, Any]]) -> dict[str, Any]:
101 | return children[0]
102 |
103 |
104 | def extract_json_data(json_string: str, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]) -> PyTree:
105 | json_parser = Lark(json_grammar, parser="lalr", propagate_positions=True, maybe_placeholders=False)
106 | tree = json_parser.parse(json_string)
107 | extractor = Extractor(tokens, token_indices)
108 | return extractor.transform(tree)
109 |
110 |
111 | # Transformer that embeds log-probabilities for atomic values as in-line fields in dictionaries
112 | @v_args(meta=True)
113 | class ExtractorInline(Transformer_NonRecursive):
114 | def __init__(self, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]):
115 | super().__init__()
116 | self.tokens = tokens
117 | self.token_indices = token_indices
118 |
119 | def _compute_logprob_sum(self, start: int, end: int) -> float:
120 | token_start = self.token_indices[start]
121 | token_end = self.token_indices[end]
122 | sum_logprob = sum(self.tokens[i].logprob for i in range(token_start, token_end))
123 | return sum_logprob
124 |
125 | def number(self, meta: Meta, children: list[Token]) -> HasProb:
126 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
127 | return HasProb(value=float(children[0]), start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum)
128 |
129 | def string(self, meta: Meta, children: list[Token]) -> HasProb:
130 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
131 | return HasProb(value=children[0][1:-1], start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum)
132 |
133 | def true(self, meta: Meta, children: list[Token]) -> HasProb:
134 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
135 | return HasProb(value=True, start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum)
136 |
137 | def false(self, meta: Meta, children: list[Token]) -> HasProb:
138 | logprob_sum = self._compute_logprob_sum(meta.start_pos, meta.end_pos)
139 | return HasProb(value=False, start=meta.start_pos, end=meta.end_pos, logprob=logprob_sum)
140 |
141 | def null(self, meta: Meta, children: list[Token]) -> None:
142 | return None
143 |
144 | def array(self, meta: Meta, children: list[dict[str, Any] | Any]) -> list[dict[str, Any] | Any]:
145 | return [child.value if isinstance(child, HasProb) else child for child in children]
146 |
147 | def object(self, meta: Meta, children: list[tuple[str, Any]]) -> dict[str, Any]:
148 | result = {}
149 | for key, value in children:
150 | if isinstance(value, HasProb):
151 | result[key] = value.value
152 | result[f"{key}_logprob"] = value.logprob
153 | else:
154 | result[key] = value
155 | return result
156 |
157 | def pair(self, meta: Meta, children: list[str | Any]) -> tuple[str, Any]:
158 | value = children[1]
159 | key = children[0]
160 | if isinstance(value, Tree) and not value.children: # ['b', Tree(Token('RULE', 'value'), [])]
161 | value = None
162 | return key, value
163 |
164 | def key(self, meta: Meta, children: list[Token]) -> str:
165 | return children[0][1:-1]
166 |
167 | def start(self, meta: Meta, children: list[dict[str, Any]]) -> dict[str, Any]:
168 | return children[0]
169 |
170 |
171 | def extract_json_data_inline(
172 | json_string: str, tokens: list[ChatCompletionTokenLogprob], token_indices: list[int]
173 | ) -> PyTree:
174 | json_parser = Lark(json_grammar, parser="lalr", propagate_positions=True, maybe_placeholders=False)
175 | tree = json_parser.parse(json_string)
176 | extractor = ExtractorInline(tokens, token_indices)
177 | return extractor.transform(tree)
178 |
--------------------------------------------------------------------------------
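The character-to-token bookkeeping that both extractors rely on can be illustrated without Lark or the OpenAI client. The sketch below uses made-up tokens and log-probs (illustrative values only, not real API output) to build the same kind of `token_indices` map — one entry per character of the JSON string, pointing at the token that produced it — and to sum log-probs over a value's character span, as `_compute_logprob_sum` does:

```python
import math

# Hypothetical tokenization of '{"a": "Paris"}' with made-up per-token
# log-probs (illustrative values only, not real API output).
tokens = ['{"', 'a', '":', ' "', 'Par', 'is', '"}']
logprobs = [-1e-5, -2e-4, -1e-5, -3e-5, -0.01, -0.002, -1e-5]

# token_indices[c] = index of the token that produced character c,
# the same map the extractors receive alongside the parsed tree.
token_indices: list[int] = []
for i, tok in enumerate(tokens):
    token_indices.extend([i] * len(tok))

json_string = "".join(tokens)


def logprob_sum(start: int, end: int) -> float:
    """Sum the log-probs of the tokens spanning characters [start, end)."""
    return sum(logprobs[i] for i in range(token_indices[start], token_indices[end]))


# Character span of the quoted value "Paris" inside the JSON string.
start = json_string.index('"Paris"')
end = start + len('"Paris"')
value_logprob = logprob_sum(start, end)
value_prob = math.exp(value_logprob)
```

Exponentiating the summed log-prob gives the probability the model assigned to the whole value, which is the per-field number that `add_logprobs` and `add_logprobs_inline` report.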
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/docs/notebooks/notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "ySrdbsAImmvN"
7 | },
8 | "source": [
9 | "[View on GitHub](https://github.com/arena-ai/structured-logprobs/blob/main/docs/notebooks/notebook.ipynb) [Open in Colab](https://colab.research.google.com/github/arena-ai/structured-logprobs/blob/main/docs/notebooks/notebook.ipynb)\n",
10 | "\n",
11 | ""
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {
17 | "id": "dwszpJwbmmvP"
18 | },
19 | "source": [
20 | "This notebook provides a practical guide on using the `structured-logprobs` library with OpenAI's API to generate structured responses enriched with token-level log-probabilities."
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "VNRqF0lummvP"
27 | },
28 | "source": [
29 | "## Install the library\n",
30 | "\n",
31 | "`structured-logprobs` is available on PyPI and can simply be installed with pip."
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "id": "LEzOxBTuz17L"
39 | },
40 | "outputs": [],
41 | "source": [
42 | "!pip install structured-logprobs~=0.1"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {
48 | "id": "h3HpFDTPmmvR"
49 | },
50 | "source": [
51 | "Let's import the required libraries."
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 2,
57 | "metadata": {
58 | "id": "zoq76ttC0nBS"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "import getpass\n",
63 | "import json\n",
64 | "import math\n",
65 | "\n",
66 | "from openai import OpenAI\n",
67 | "from openai.types import ResponseFormatJSONSchema\n",
68 | "from rich import print, print_json\n",
69 | "\n",
70 | "from structured_logprobs.main import add_logprobs, add_logprobs_inline"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "id": "QNgxKrqummvS"
77 | },
78 | "source": [
79 | "## Setting Up the OpenAI API Client\n",
80 | "\n",
81 | "An OpenAI API key is required to authenticate access to OpenAI's API. It is used to initialize the OpenAI Python client, enabling you to send requests to the API and receive responses.\n",
82 | "\n",
83 | "In this notebook, you will be prompted to enter your OPENAI_API_KEY securely using Python's getpass module. This ensures that your key is not hardcoded, reducing the risk of accidental exposure."
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {
90 | "id": "2QNRWDY6-dTg"
91 | },
92 | "outputs": [],
93 | "source": [
94 | "api_key = getpass.getpass(prompt=\"Enter your OPENAI_API_KEY: \")"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {
100 | "id": "XEFmv6H9mmvS"
101 | },
102 | "source": [
103 | "Let's initialize the OpenAI client."
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "metadata": {
110 | "id": "bS5uatkr0m3x"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "client = OpenAI(api_key=api_key)"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {
120 | "id": "fk459nuxmmvT"
121 | },
122 | "source": [
123 | "## Create a chat completion request\n",
124 | "\n",
125 | "The first step is to define the JSON schema, used to structure the chat request to OpenAI. This schema helps OpenAI understand exactly how the response should be formatted and organized.\n",
126 | "\n",
127 | "Below is an example JSON schema used in this notebook. To learn more about JSON Schema, refer to [this overview](https://json-schema.org/overview/what-is-jsonschema)."
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 5,
133 | "metadata": {
134 | "id": "aZ389dNjBZ19"
135 | },
136 | "outputs": [],
137 | "source": [
138 | "schema_content = {\n",
139 | " \"type\": \"json_schema\",\n",
140 | " \"json_schema\": {\n",
141 | "        \"name\": \"answers\",\n",
142 | " \"description\": \"Response to questions in JSON format\",\n",
143 | " \"schema\": {\n",
144 | " \"type\": \"object\",\n",
145 | " \"properties\": {\n",
146 | " \"capital_of_France\": {\"type\": \"string\"},\n",
147 | " \"the_two_nicest_colors\": {\n",
148 | " \"type\": \"array\",\n",
149 | " \"items\": {\"type\": \"string\", \"enum\": [\"red\", \"blue\", \"green\", \"yellow\", \"purple\"]},\n",
150 | " },\n",
151 | " \"die_shows\": {\"type\": \"integer\"},\n",
152 | " },\n",
153 | " \"required\": [\"capital_of_France\", \"the_two_nicest_colors\", \"die_shows\"],\n",
154 | " \"additionalProperties\": False,\n",
155 | " },\n",
156 | " \"strict\": True,\n",
157 | " },\n",
158 | "}"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {
164 | "id": "yPzGl5iYmmvT"
165 | },
166 | "source": [
167 | "The schema must be validated before being used as a parameter in the request to OpenAI."
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 6,
173 | "metadata": {
174 | "id": "gRX19vy5ANZb"
175 | },
176 | "outputs": [],
177 | "source": [
178 | "response_schema = ResponseFormatJSONSchema.model_validate(schema_content)"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {
184 | "id": "cIQD_KoFmmvU"
185 | },
186 | "source": [
187 | "Additionally, to create the chat completion, you must set up the model, input messages, and other parameters such as logprobs and response_format."
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 7,
193 | "metadata": {
194 | "id": "sdvFNe2b-MLE"
195 | },
196 | "outputs": [],
197 | "source": [
198 | "completion = client.chat.completions.create(\n",
199 | " model=\"gpt-4o-2024-08-06\",\n",
200 | " messages=[\n",
201 | " {\n",
202 | " \"role\": \"system\",\n",
203 | " \"content\": (\n",
204 | " \"I have three questions. The first question is: What is the capital of France? \"\n",
205 | " \"The second question is: Which are the two nicest colors? \"\n",
206 | " \"The third question is: Can you roll a die and tell me which number comes up?\"\n",
207 | " ),\n",
208 | " }\n",
209 | " ],\n",
210 | " logprobs=True,\n",
211 | " response_format=response_schema.model_dump(by_alias=True),\n",
212 | ")"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "yadvuQqgmmvU"
219 | },
220 | "source": [
221 | "If you print the response, you can observe how OpenAI organizes the logprobs. These logprobs are associated with individual tokens, which may not be convenient if you are looking for the log probability of the full value extracted for each requested field.\n",
222 | "\n",
223 | "```python\n",
224 | "ChatCompletion(\n",
225 | " id='chatcmpl-ApHuoaVGaxOoPUX6syvQt9XkfSkCe',\n",
226 | " choices=[\n",
227 | " Choice(\n",
228 | " finish_reason='stop',\n",
229 | " index=0,\n",
230 | " logprobs=ChoiceLogprobs(\n",
231 | " content=[\n",
232 | " ChatCompletionTokenLogprob(\n",
233 | " token='{\"',\n",
234 | " bytes=[123, 34],\n",
235 | " logprob=-1.50940705e-05\n",
236 | " ),\n",
238 | " ChatCompletionTokenLogprob(\n",
239 | " token='capital',\n",
240 | " bytes=[99, 97, 112, 105, 116, 97, 108],\n",
241 | " logprob=-7.226629e-06\n",
242 | " ),\n",
243 | " #...\n",
244 | " ],\n",
245 | " refusal=None\n",
246 | " ),\n",
247 | " message=ChatCompletionMessage(\n",
248 | "        content='{\"capital_of_France\":\"Paris\",\"the_two_nicest_colors\":[\"blue\",\"green\"],\n",
249 | "\"die_shows\":4}',\n",
250 | " refusal=None,\n",
251 | " role='assistant',\n",
252 | " audio=None,\n",
253 | " function_call=None,\n",
254 | " tool_calls=None\n",
255 | " )\n",
256 | " )\n",
257 | " ],\n",
258 | " created=1736786958,\n",
259 | " model='gpt-4o-2024-08-06',\n",
260 | " object='chat.completion',\n",
261 | " service_tier='default',\n",
262 | " system_fingerprint='fp_703d4ff298',\n",
263 | " usage=CompletionUsage(\n",
264 | " completion_tokens=27,\n",
265 | " prompt_tokens=133,\n",
266 | " total_tokens=160,\n",
267 | " completion_tokens_details=CompletionTokensDetails(\n",
268 | " accepted_prediction_tokens=0,\n",
269 | " audio_tokens=0,\n",
270 | " reasoning_tokens=0,\n",
271 | " rejected_prediction_tokens=0\n",
272 | " ),\n",
273 | " prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)\n",
274 | " )\n",
275 | ")\n",
276 | "```"
277 | ]
278 | },
279 | {
280 | "cell_type": "markdown",
281 | "metadata": {
282 | "id": "BpxIfEjJmmvU"
283 | },
284 | "source": [
285 | "## Enhance the chat completion result with log probabilities\n",
286 | "\n",
287 | "The strategy for aggregating log-probabilities involves mapping each character in the generated message's content to its corresponding token. Instead of focusing on individual token probabilities, the log probabilities of all tokens that form a given value are summed. This approach generates a more meaningful probability for all JSON elements."
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 8,
293 | "metadata": {
294 | "id": "M49qluzkmmvU"
295 | },
296 | "outputs": [],
297 | "source": [
298 | "chat_completion = add_logprobs(completion)"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {
304 | "id": "NafNeu3JmmvV"
305 | },
306 | "source": [
307 | "Now if you print the response you can see that it is a new Python object, which contains the original OpenAI response under the 'value' field, and a 'log_probs' field where the message values are replaced with their respective log probabilities.\n",
308 | "\n",
309 | "```python\n",
310 | "ChatCompletionWithLogProbs(\n",
311 | " value=ChatCompletion(\n",
312 | " id='chatcmpl-ApHuoaVGaxOoPUX6syvQt9XkfSkCe',\n",
313 | " choices=[\n",
314 | " Choice(\n",
315 | " finish_reason='stop',\n",
316 | " index=0,\n",
317 | " logprobs=ChoiceLogprobs(\n",
318 | " content=[\n",
319 | " ChatCompletionTokenLogprob(\n",
320 | " token='{\"',\n",
321 | " bytes=[123, 34],\n",
322 | " logprob=-1.50940705e-05,\n",
323 | " top_logprobs=[]\n",
324 | " ),\n",
325 | " #...\n",
326 | " ],\n",
327 | " refusal=None\n",
328 | " ),\n",
329 | " message=ChatCompletionMessage(\n",
330 | " content='{\"capital_of_France\":\"Paris\",\"the_two_nicest_colors\":[\"blue\",\"green\"],\"die_shows\":4}',\n",
331 | " refusal=None,\n",
332 | " role='assistant',\n",
333 | " audio=None,\n",
334 | " function_call=None,\n",
335 | " tool_calls=None\n",
336 | " )\n",
337 | " )\n",
338 | " ],\n",
339 | " created=1736786958,\n",
340 | " model='gpt-4o-2024-08-06',\n",
341 | " object='chat.completion',\n",
342 | " service_tier='default',\n",
343 | " system_fingerprint='fp_703d4ff298',\n",
344 | " usage=CompletionUsage(\n",
345 | " completion_tokens=27,\n",
346 | " prompt_tokens=133,\n",
347 | " total_tokens=160,\n",
348 | " completion_tokens_details=CompletionTokensDetails(\n",
349 | " accepted_prediction_tokens=0,\n",
350 | " audio_tokens=0,\n",
351 | " reasoning_tokens=0,\n",
352 | " rejected_prediction_tokens=0\n",
353 | " ),\n",
354 | " prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)\n",
355 | " )\n",
356 | " ),\n",
357 | " log_probs=[\n",
358 | " {\n",
359 | " 'capital_of_France': -1.22165105e-06,\n",
360 | " 'the_two_nicest_colors': [-0.00276869551265, -0.00539924761265],\n",
361 | " 'die_shows': -0.44008404\n",
362 | " }\n",
363 | " ]\n",
364 | ")\n",
365 | "```"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": 9,
371 | "metadata": {
372 | "colab": {
373 | "base_uri": "https://localhost:8080/",
374 | "height": 145
375 | },
376 | "id": "8YmboSB2vtXb",
377 | "outputId": "f4fe01a2-332e-4953-8ef7-37d43b0161a5"
378 | },
379 | "outputs": [
380 | {
381 | "data": {
382 | "text/html": [
383 | "{\n",
384 | " \"capital_of_France\": \"Paris\",\n",
385 | " \"the_two_nicest_colors\": [\n",
386 | " \"blue\",\n",
387 | " \"green\"\n",
388 | " ],\n",
389 | " \"die_shows\": 4\n",
390 | "}\n",
391 | "\n"
392 | ],
393 | "text/plain": [
394 | "\u001b[1m{\u001b[0m\n",
395 | " \u001b[1;34m\"capital_of_France\"\u001b[0m: \u001b[32m\"Paris\"\u001b[0m,\n",
396 | " \u001b[1;34m\"the_two_nicest_colors\"\u001b[0m: \u001b[1m[\u001b[0m\n",
397 | " \u001b[32m\"blue\"\u001b[0m,\n",
398 | " \u001b[32m\"green\"\u001b[0m\n",
399 | " \u001b[1m]\u001b[0m,\n",
400 | " \u001b[1;34m\"die_shows\"\u001b[0m: \u001b[1;36m4\u001b[0m\n",
401 | "\u001b[1m}\u001b[0m\n"
402 | ]
403 | },
404 | "metadata": {},
405 | "output_type": "display_data"
406 | }
407 | ],
408 | "source": [
409 | "print_json(chat_completion.value.choices[0].message.content)"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": 10,
415 | "metadata": {
416 | "colab": {
417 | "base_uri": "https://localhost:8080/",
418 | "height": 97
419 | },
420 | "id": "ioGyilZAmmvW",
421 | "outputId": "c4ecf3f3-3f54-4a65-e1fb-a4c1532b1b81"
422 | },
423 | "outputs": [
424 | {
425 | "data": {
426 | "text/html": [
427 | "{\n",
428 | " 'capital_of_France': -1.10244729e-06,\n",
429 | " 'the_two_nicest_colors': [-0.0022088558126500003, -0.01012725961265],\n",
430 | " 'die_shows': -0.43754107\n",
431 | "}\n",
432 | "\n"
433 | ],
434 | "text/plain": [
435 | "\u001b[1m{\u001b[0m\n",
436 | " \u001b[32m'capital_of_France'\u001b[0m: \u001b[1;36m-1.10244729e-06\u001b[0m,\n",
437 | " \u001b[32m'the_two_nicest_colors'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1;36m-0.0022088558126500003\u001b[0m, \u001b[1;36m-0.01012725961265\u001b[0m\u001b[1m]\u001b[0m,\n",
438 | " \u001b[32m'die_shows'\u001b[0m: \u001b[1;36m-0.43754107\u001b[0m\n",
439 | "\u001b[1m}\u001b[0m\n"
440 | ]
441 | },
442 | "metadata": {},
443 | "output_type": "display_data"
444 | }
445 | ],
446 | "source": [
447 | "print(chat_completion.log_probs[0])"
448 | ]
449 | },
450 | {
451 | "cell_type": "markdown",
452 | "metadata": {
453 | "id": "Vrqpmd5xmmvW"
454 | },
455 | "source": [
456 | "By applying the exponential function to logprobs, you can easily convert\n",
457 | "them to probabilities."
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": 11,
463 | "metadata": {
464 | "colab": {
465 | "base_uri": "https://localhost:8080/",
466 | "height": 33
467 | },
468 | "id": "SU0AoPbpmmvW",
469 | "outputId": "ee22f156-8072-4346-91c8-13d03ed701a7"
470 | },
471 | "outputs": [
472 | {
473 | "data": {
474 | "text/html": [
475 | "{'capital_of_France_prob': 1.0, 'the_two_nicest_colors_prob': [1.0, 0.99], 'die_shows_prob': 0.65}\n",
476 | "\n"
477 | ],
478 | "text/plain": [
479 | "\u001b[1m{\u001b[0m\u001b[32m'capital_of_France_prob'\u001b[0m: \u001b[1;36m1.0\u001b[0m, \u001b[32m'the_two_nicest_colors_prob'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1;36m1.0\u001b[0m, \u001b[1;36m0.99\u001b[0m\u001b[1m]\u001b[0m, \u001b[32m'die_shows_prob'\u001b[0m: \u001b[1;36m0.65\u001b[0m\u001b[1m}\u001b[0m\n"
480 | ]
481 | },
482 | "metadata": {},
483 | "output_type": "display_data"
484 | }
485 | ],
486 | "source": [
487 | "data = chat_completion.log_probs[0]\n",
488 | "transformed_data = {\n",
489 | " key + \"_prob\": [round(math.exp(log_prob), 2) for log_prob in value]\n",
490 | " if isinstance(value, list)\n",
491 | " else round(math.exp(value), 2)\n",
492 | " for key, value in data.items()\n",
493 | "}\n",
494 | "print(transformed_data)"
495 | ]
496 | },
497 | {
498 | "cell_type": "markdown",
499 | "metadata": {
500 | "id": "ZRmQCvmOmmvW"
501 | },
502 | "source": [
503 | "## Enhance the chat completion result with in-line log probabilities\n",
504 | "\n",
505 | "With the `add_logprobs_inline` method you can embed log probabilities directly within the content of the message. Instead of keeping log probabilities in a separate field, this function integrates them into the content of the chat completion response itself, so each atomic value is accompanied by its log probability."
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 12,
511 | "metadata": {
512 | "id": "7KBaRghOEJKr"
513 | },
514 | "outputs": [],
515 | "source": [
516 | "chat_completion_inline = add_logprobs_inline(completion)"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "metadata": {
522 | "id": "I4Yx8PdUmmvX"
523 | },
524 | "source": [
525 | "If you now print the response, you can see that the content of the message is replaced with a dictionary that also includes inline log probabilities for atomic values.\n",
526 | "\n",
527 | "```python\n",
528 | "ChatCompletion(\n",
529 | " id='chatcmpl-ApIDdbCuAJ8EHM6RDNgGR3mEQZTBH',\n",
530 | " choices=[\n",
531 | " Choice(\n",
532 | " finish_reason='stop',\n",
533 | " index=0,\n",
534 | " logprobs=ChoiceLogprobs(\n",
535 | " content=[\n",
536 | " ChatCompletionTokenLogprob(\n",
537 | " token='{\"',\n",
538 | " bytes=[123, 34],\n",
539 | " logprob=-2.3795938e-05,\n",
540 | " top_logprobs=[]\n",
541 | " ),\n",
542 | " #...\n",
543 | " ],\n",
544 | " refusal=None\n",
545 | " ),\n",
546 | " message=ChatCompletionMessage(\n",
547 | " content='{\"capital_of_France\": \"Paris\", \"capital_of_France_logprob\": -7.448363e-07,\n",
548 | "\"the_two_nicest_colors\": [\"blue\", \"green\"], \"die_shows\": 4.0, \"die_shows_logprob\": -0.46062052}',\n",
549 | " refusal=None,\n",
550 | " role='assistant',\n",
551 | " audio=None,\n",
552 | " function_call=None,\n",
553 | " tool_calls=None\n",
554 | " )\n",
555 | " )\n",
556 | " ],\n",
557 | " created=1736788125,\n",
558 | " model='gpt-4o-2024-08-06',\n",
559 | " object='chat.completion',\n",
560 | " service_tier='default',\n",
561 | " system_fingerprint='fp_703d4ff298',\n",
562 | " usage=CompletionUsage(\n",
563 | " completion_tokens=27,\n",
564 | " prompt_tokens=133,\n",
565 | " total_tokens=160,\n",
566 | " completion_tokens_details=CompletionTokensDetails(\n",
567 | " accepted_prediction_tokens=0,\n",
568 | " audio_tokens=0,\n",
569 | " reasoning_tokens=0,\n",
570 | " rejected_prediction_tokens=0\n",
571 | " ),\n",
572 | " prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0)\n",
573 | " )\n",
574 | ")\n",
575 | "```"
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": 13,
581 | "metadata": {
582 | "colab": {
583 | "base_uri": "https://localhost:8080/",
584 | "height": 177
585 | },
586 | "id": "e0kRhM1a-MF0",
587 | "outputId": "a9f345cb-aaf0-4e5d-a888-57bf6e7515b7"
588 | },
589 | "outputs": [
590 | {
591 | "data": {
592 | "text/html": [
593 | "{\n",
594 | " \"capital_of_France\": \"Paris\",\n",
595 | " \"capital_of_France_logprob\": -1.10244729e-06,\n",
596 | " \"the_two_nicest_colors\": [\n",
597 | " \"blue\",\n",
598 | " \"green\"\n",
599 | " ],\n",
600 | " \"die_shows\": 4.0,\n",
601 | " \"die_shows_logprob\": -0.43754107\n",
602 | "}\n",
603 | "\n"
604 | ],
605 | "text/plain": [
606 | "\u001b[1m{\u001b[0m\n",
607 | " \u001b[1;34m\"capital_of_France\"\u001b[0m: \u001b[32m\"Paris\"\u001b[0m,\n",
608 | " \u001b[1;34m\"capital_of_France_logprob\"\u001b[0m: \u001b[1;36m-1.10244729e-06\u001b[0m,\n",
609 | " \u001b[1;34m\"the_two_nicest_colors\"\u001b[0m: \u001b[1m[\u001b[0m\n",
610 | " \u001b[32m\"blue\"\u001b[0m,\n",
611 | " \u001b[32m\"green\"\u001b[0m\n",
612 | " \u001b[1m]\u001b[0m,\n",
613 | " \u001b[1;34m\"die_shows\"\u001b[0m: \u001b[1;36m4.0\u001b[0m,\n",
614 | " \u001b[1;34m\"die_shows_logprob\"\u001b[0m: \u001b[1;36m-0.43754107\u001b[0m\n",
615 | "\u001b[1m}\u001b[0m\n"
616 | ]
617 | },
618 | "metadata": {},
619 | "output_type": "display_data"
620 | }
621 | ],
622 | "source": [
623 | "print_json(chat_completion_inline.choices[0].message.content)"
624 | ]
625 | },
626 | {
627 | "cell_type": "markdown",
628 | "metadata": {
629 | "id": "GtY9zebnmmvX"
630 | },
631 | "source": [
632 | "The probability can easily be obtained by exponentiating the log-probability."
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 14,
638 | "metadata": {
639 | "colab": {
640 | "base_uri": "https://localhost:8080/",
641 | "height": 129
642 | },
643 | "id": "0qbyoD4BmmvX",
644 | "outputId": "454a2a65-349b-4e98-e273-094876f128d3"
645 | },
646 | "outputs": [
647 | {
648 | "data": {
649 | "text/html": [
650 | "{\n",
651 | " 'capital_of_France': 'Paris',\n",
652 | " 'capital_of_France_prob': 1.0,\n",
653 | " 'the_two_nicest_colors': ['blue', 'green'],\n",
654 | " 'die_shows': 4.0,\n",
655 | " 'die_shows_prob': 0.65\n",
656 | "}\n",
657 | "\n"
658 | ],
659 | "text/plain": [
660 | "\u001b[1m{\u001b[0m\n",
661 | " \u001b[32m'capital_of_France'\u001b[0m: \u001b[32m'Paris'\u001b[0m,\n",
662 | " \u001b[32m'capital_of_France_prob'\u001b[0m: \u001b[1;36m1.0\u001b[0m,\n",
663 | " \u001b[32m'the_two_nicest_colors'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'blue'\u001b[0m, \u001b[32m'green'\u001b[0m\u001b[1m]\u001b[0m,\n",
664 | " \u001b[32m'die_shows'\u001b[0m: \u001b[1;36m4.0\u001b[0m,\n",
665 | " \u001b[32m'die_shows_prob'\u001b[0m: \u001b[1;36m0.65\u001b[0m\n",
666 | "\u001b[1m}\u001b[0m\n"
667 | ]
668 | },
669 | "metadata": {},
670 | "output_type": "display_data"
671 | }
672 | ],
673 | "source": [
674 | "data = json.loads(chat_completion_inline.choices[0].message.content)\n",
675 | "transformed_data = {\n",
676 | " (key[:-8] + \"_prob\" if key.endswith(\"_logprob\") else key): (\n",
677 | " round(math.exp(value), 2) if key.endswith(\"_logprob\") else value\n",
678 | " )\n",
679 | " for key, value in data.items()\n",
680 | "}\n",
681 | "print(transformed_data)"
682 | ]
683 | }
684 | ],
685 | "metadata": {
686 | "colab": {
687 | "provenance": []
688 | },
689 | "kernelspec": {
690 | "display_name": "Python 3",
691 | "name": "python3"
692 | },
693 | "language_info": {
694 | "name": "python"
695 | }
696 | },
697 | "nbformat": 4,
698 | "nbformat_minor": 0
699 | }
700 |
--------------------------------------------------------------------------------
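The notebook's final cell converts the inline `_logprob` fields back to probabilities by exponentiating them. As a standalone sketch of that transform (the message content below is copied from the example output shown earlier, not fetched from the API):

```python
import json
import math

# Message content in the inline-logprob shape shown in the notebook output above.
content = (
    '{"capital_of_France": "Paris", "capital_of_France_logprob": -7.448363e-07, '
    '"the_two_nicest_colors": ["blue", "green"], '
    '"die_shows": 4.0, "die_shows_logprob": -0.46062052}'
)

data = json.loads(content)

# Rename every "<field>_logprob" key to "<field>_prob" and exponentiate its
# value; all other keys pass through unchanged.
transformed = {
    (k[: -len("_logprob")] + "_prob" if k.endswith("_logprob") else k): (
        round(math.exp(v), 2) if k.endswith("_logprob") else v
    )
    for k, v in data.items()
}
print(transformed)
```

Since log probabilities are at most 0, `math.exp` maps them into `(0, 1]`; a value very close to 0 (like the capital-of-France token) rounds to a probability of 1.0.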