├── COPYING.txt
├── README.md
├── pyproject.toml
├── src
│   └── gptcmd
│       ├── __init__.py
│       ├── cli.py
│       ├── config.py
│       ├── config_sample.toml
│       ├── llm
│       │   ├── README.md
│       │   ├── __init__.py
│       │   └── openai.py
│       └── message.py
└── tests
    ├── test_llm.py
    └── test_message.py
/README.md:
--------------------------------------------------------------------------------
1 | # Gptcmd
2 | Gptcmd allows you to interact with large language models, such as OpenAI's GPT, efficiently in your terminal. Gptcmd can manage multiple concurrent "threads" of conversation, allowing for free and easy prompt experimentation and iteration. Individual messages can be manipulated, loaded from, and saved to files (both plain text and JSON), and API parameters are fully customizable. In short, Gptcmd is simple yet flexible, useful for both basic conversation and more involved prototyping.
3 |
4 | ## Getting started
5 | Gptcmd requires [Python](https://python.org) 3.8.6 or later. It is available on PyPI and can, for instance, be installed with `pip install gptcmd` at a command line shell. Running `gptcmd` at a shell starts the application. If Python's `bin` or `scripts` directory isn't on your `PATH`, you may need to launch the application with a command like `~/.local/bin/gptcmd` (depending on your system configuration). In most cases, though, `gptcmd` should "just work".
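
For example, on a Unix-like system (a sketch; the version banner and default model depend on your installation and configuration):

```
$ pip install gptcmd
$ gptcmd
Welcome to Gptcmd 2.2.0! Type help or ? to list commands.
(gpt-4o) quit
```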
6 |
7 | If you'd like to use OpenAI models and you don't have an OpenAI account, you'll need to create one and [add some credit](https://platform.openai.com/account/billing/overview). $5 or so goes very far, [especially on `gpt-4.1-mini` or `gpt-4.1-nano`](#model-selection).
8 |
9 | Gptcmd searches for provider credentials in its configuration file, falling back to the `OPENAI_API_KEY` environment variable if no key is provided in its configuration. If you'd like to use OpenAI models and you don't have an API key, you'll need to [generate a key](https://platform.openai.com/account/api-keys).
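
If you'd rather use the environment variable than the configuration file, a minimal sketch for Unix-like shells follows (on Windows, use `set` or `setx` instead):

```
$ export OPENAI_API_KEY="sk-..."
$ gptcmd
```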
10 |
11 | Once Gptcmd starts, it presents a prompt containing the name of the currently active model and waits for user input. Running the `quit` command (typing `quit` at the prompt and pressing Return) exits the program.
12 |
13 | Gptcmd has a help facility that provides a list of available commands and brief usage hints for each. The `help` command with no arguments lists the available commands; passing a command as an argument to `help` returns information on that command.
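
For instance (a sketch assuming Python's standard `cmd` help layout; the command listing is abbreviated, and the `help thread` output is reflowed slightly here):

```
(gpt-4o) help

Documented commands (type help <topic>):
========================================
...

(gpt-4o) help thread
Switch to the thread passed as argument, creating it as a clone of the
current thread if the supplied name does not exist. With no argument,
switch to the detached thread.
example: "thread messages" switches to the thread named "messages",
creating it if necessary.
```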
14 |
15 | ### Configuring Gptcmd
16 | When Gptcmd starts for the first time, it generates a configuration file whose location depends on your operating system:
17 |
18 | Platform | Config location
19 | --- | ---
20 | Windows | `%appdata%\gptcmd\config.toml`
21 | macOS | `~/Library/Application Support/gptcmd/config.toml`
22 | Other | `$XDG_CONFIG_HOME/gptcmd/config.toml` or `~/.config/gptcmd/config.toml`
23 |
24 | You may open Gptcmd's configuration file in a text editor to change application settings. The file contains comments that describe the available options. Configuration changes take effect the next time Gptcmd starts.
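
As a rough sketch, settings look like this (these two keys appear in Gptcmd's source, though their exact placement in the file may differ; the comments in your generated file are authoritative):

```
# Template for the interactive prompt
prompt = "{thread}({model}) "
# Whether appending empty messages is allowed: "always", "ask", or "never"
allow_add_empty_messages = "never"
```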
25 |
26 | ### Simple conversation
27 | The `say` command sends a message to the model:
28 |
29 | ```
30 | (gpt-4o) say Hello, world!
31 | ...
32 | Hello! How can I assist you today?
33 | ```
34 |
35 | Gptcmd sends the entire conversation every time, never deleting history unless told to do so.
36 |
37 | ```
38 | (gpt-4o) say I'm good! How are you?
39 | ...
40 | I'm just a program, so I don't have feelings, but I'm here and ready to help you with anything you need!
41 | (gpt-4o) say That's alright. Count from 1 to 5.
42 | ...
43 | Sure! Here you go: 1, 2, 3, 4, 5.
44 | (gpt-4o) say What are the next two numbers after that?
45 | ...
46 | The next two numbers are 6 and 7.
47 | ```
48 |
49 | The conversation can be cleared with the `clear` command, at which point any previous context will no longer be made available to the model:
50 |
51 | ```
52 | (gpt-4o) clear
53 | Delete 8 messages? (y/n)y
54 | Cleared
55 | (gpt-4o) say What are the next two numbers after that?
56 | ...
57 | I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
58 | ```
59 |
60 | ### Viewing messages
61 | The `first` and `last` commands view the first and last messages in the conversation respectively:
62 |
63 | ```
64 | (gpt-4o) say Write a limerick about generative AI.
65 | ...
66 | In the land of the silicon chip,
67 | Generative AI took a trip.
68 | With words it would play,
69 | In a curious way,
70 | Creating tales at the click of a lip.
71 | (gpt-4o) first
72 | user: What are the next two numbers after that?
73 | (gpt-4o) last
74 | assistant: In the land of the silicon chip,
75 | Generative AI took a trip.
76 | With words it would play,
77 | In a curious way,
78 | Creating tales at the click of a lip.
79 | ```
80 |
81 | Providing an integer `k` as an argument shows the first or last `k` messages:
82 |
83 | ```
84 | (gpt-4o) first 2
85 | user: What are the next two numbers after that?
86 | assistant: I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
87 | (gpt-4o) last 3
88 | assistant: I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
89 | user: Write a limerick about generative AI.
90 | assistant: In the land of the silicon chip,
91 | Generative AI took a trip.
92 | With words it would play,
93 | In a curious way,
94 | Creating tales at the click of a lip.
95 | ```
96 |
97 | The `view` command shows the entire conversation:
98 |
99 | ```
100 | (gpt-4o) view
101 | user: What are the next two numbers after that?
102 | assistant: I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
103 | user: Write a limerick about generative AI.
104 | assistant: In the land of the silicon chip,
105 | Generative AI took a trip.
106 | With words it would play,
107 | In a curious way,
108 | Creating tales at the click of a lip.
109 | ```
110 |
111 | ### Message ranges
112 | Various Gptcmd commands work over ranges of messages in a conversation. A range is specified either as the index (position) of a single message, or as a space-separated pair of the inclusive indices of the beginning and end of an interval of messages. Unlike in many programming languages, messages are one-indexed (i.e. `1` refers to the first message, `2` to the second, etc.). A dot (`.`) on its own refers to the entire conversation; in place of a numeric index, it refers to the beginning or end of the conversation. Negative indexing is supported (`-1` refers to the last message, `-2` to the penultimate, and so on).
113 |
114 | The `view` command accepts a range of messages as an argument. When provided, it only shows messages in the indicated range. Some example message range specifications follow:
115 |
116 | ```
117 | (gpt-4o) view 1
118 | user: What are the next two numbers after that?
119 | (gpt-4o) view 2
120 | assistant: I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
121 | (gpt-4o) view 2 3
122 | assistant: I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
123 | user: Write a limerick about generative AI.
124 | (gpt-4o) view -1
125 | assistant: In the land of the silicon chip,
126 | Generative AI took a trip.
127 | With words it would play,
128 | In a curious way,
129 | Creating tales at the click of a lip.
130 | (gpt-4o) view 3 .
131 | user: Write a limerick about generative AI.
132 | assistant: In the land of the silicon chip,
133 | Generative AI took a trip.
134 | With words it would play,
135 | In a curious way,
136 | Creating tales at the click of a lip.
137 | (gpt-4o) view .
138 | user: What are the next two numbers after that?
139 | assistant: I apologize, but your request is unclear. Could you please provide more information or context? For example, if you're referring to a numerical pattern or sequence, sharing the sequence would help me assist you better.
140 | user: Write a limerick about generative AI.
141 | assistant: In the land of the silicon chip,
142 | Generative AI took a trip.
143 | With words it would play,
144 | In a curious way,
145 | Creating tales at the click of a lip.
146 | ```
147 |
148 | ### Conversation roles
149 | Gptcmd messages can have one of three "roles": user, representing the end-user; assistant, representing the large language model; and system, used to inform the model's behaviour. System messages can be appended to the conversation with the `system` command:
150 |
151 | ```
152 | (gpt-4o) system As a philosophical proof assistant, you prefer to use the technique of proof by contradiction at any opportunity. Therefore, whenever you are asked to prove something, you first assume that the proposition is false, then given this assumption, you attempt to find a contradiction.
153 | 'As a philosophical proof assistant, you prefer to use the...' added as system
154 | (gpt-4o) say Prove that all cacti are succulents.
155 | ...
156 | To prove that all cacti are succulents, let's employ proof by contradiction. We begin by assuming the opposite of what we want to prove—that not all cacti are succulents. This assumption asserts that there is at least one cactus that is not a succulent.
157 |
158 | Now, let's examine the definitions involved:
159 |
160 | 1. **Cactus**: A cactus is a type of plant belonging to the family Cactaceae, which is well-adapted to arid environments. Cacti typically have thick, fleshy parts that store water, and they often have needles instead of leaves.
161 |
162 | 2. **Succulent**: A succulent is a plant that has thick, fleshy parts developed to retain water in arid climates or soil conditions. This is a broader category that includes many different plant families, including Cactaceae.
163 |
164 | Under our assumption, there exists a cactus that does not fit the definition of a succulent. In other words, this hypothetical cactus would not be adapted for water storage, lacking the thick, fleshy parts that characterize succulents. However, such characteristics are fundamental to the biological adaptation seen in all cacti, given their native environments, which require water conservation.
165 |
166 | Thus, the assumption leads to a contradiction with the basic biological characteristics and adaptations of cacti. As all known cacti possess the necessary attributes to be classified as succulents, our assumption is false.
167 |
168 | Therefore, we conclude that all cacti are indeed succulents.
169 | ```
170 |
171 | Similarly, user and assistant messages can be added with the `user` and `assistant` commands respectively. Since models like GPT are agnostic of the source of messages (i.e. they don't track their own responses and expect downstream applications to manage context), Gptcmd allows you to inject your own arbitrary conversation history to which the model can respond:
172 |
173 | ```
174 | (gpt-4o) user What are the first five Fibonacci numbers?
175 | 'What are the first five Fibonacci numbers?' added as user
176 | (gpt-4o) assistant 1, 1, 2, 3, 5.
177 | '1, 1, 2, 3, 5.' added as assistant
178 | (gpt-4o) say And the next five?
179 | ...
180 | The next five Fibonacci numbers after 1, 1, 2, 3, 5 are 8, 13, 21, 34, and 55.
181 | ```
182 |
183 | The `send` command sends the conversation in its current state to the model and requests a response:
184 |
185 | ```
186 | (gpt-4o) user What are the first ten digits of pi?
187 | 'What are the first ten digits of pi?' added as user
188 | (gpt-4o) send
189 | ...
190 | The first ten digits of pi are 3.141592653.
191 | ```
192 |
193 | In fact, the `say` command just adds a user message and sends the conversation:
194 |
195 | ```
196 | (gpt-4o) say What are the first ten digits of pi?
197 | ...
198 | The first ten digits of pi are 3.141592653.
199 | ```
200 |
201 | With no arguments, the `user`, `assistant`, `system`, and `say` commands open an external text editor (based on your system or Gptcmd configuration) for message composition.
202 |
203 | ### Working with images
204 | OpenAI's latest models, such as `gpt-4o`, support images alongside text content. Images can be attached to messages with the `image` command, which accepts two arguments: the location of the image, either a URL or a path to a local file; and the index of the message to which the image should be attached (if unspecified, it defaults to the last message). We'll ask GPT to describe an image by creating a user message and attaching an image from Wikimedia Commons:
205 |
206 | ```
207 | (gpt-4o) user What's in this image?
208 | "What's in this image?" added as user
209 | (gpt-4o) image https://upload.wikimedia.org/wikipedia/commons/c/ce/Long_cane.jpg
210 | Image added to "What's in this image?"
211 | ```
212 |
213 | When viewing the conversation, an at sign before a message indicates an attachment (multiple at signs indicate multiple attachments):
214 |
215 | ```
216 | (gpt-4o) view
217 | @user: What's in this image?
218 | ```
219 |
220 | Now, we can `send` our message to get a description:
221 |
222 | ```
223 | (gpt-4o) send
224 | ...
225 | This is a white cane, often used by individuals who are blind or visually impaired to aid in mobility and navigation. It has a handle, a long shaft, and a rounded tip.
226 | ```
227 |
228 | ### Managing messages
229 | The `pop` command with no argument deletes the last message of a conversation:
230 |
231 | ```
232 | (gpt-4o) say Responding with only one word, tell me a female given name.
233 | ...
234 | Alice.
235 | (gpt-4o) pop
236 | 'Alice.' deleted
237 | (gpt-4o) send
238 | ...
239 | Sophia
240 | (gpt-4o) pop
241 | 'Sophia' deleted
242 | (gpt-4o) send
243 | ...
244 | Emily
245 | (gpt-4o) view
246 | user: Responding with only one word, tell me a female given name.
247 | assistant: Emily
248 | ```
249 |
250 | Deleting the last message and resending the conversation is a very common action while experimenting with large language models, so Gptcmd includes a shortcut: the `retry` command:
251 |
252 | ```
253 | (gpt-4o) say Responding with only one word, tell me a male given name.
254 | ...
255 | David
256 | (gpt-4o) retry
257 | ...
258 | John
259 | (gpt-4o) retry
260 | ...
261 | William
262 | ```
263 |
264 | Providing the index of a single message as an argument to `pop` deletes the specified message:
265 |
266 | ```
267 | (gpt-4o) user abc
268 | 'abc' added as user
269 | (gpt-4o) user def
270 | 'def' added as user
271 | (gpt-4o) user ghi
272 | 'ghi' added as user
273 | (gpt-4o) pop 2
274 | 'def' deleted
275 | (gpt-4o) view
276 | user: abc
277 | user: ghi
278 | ```
279 |
280 | The `move` command moves the message at the start of a range to the end of that range. This is often useful for adding a system message to conversations that previously didn't have one:
281 |
282 | ```
283 | (gpt-4o) say Prove that pi is not exactly equal to 3.14628 without the use of any mathematical symbols.
284 | ...
285 | To demonstrate that pi is not exactly equal to three point one four six two eight, we can refer to the nature of pi and how it has been studied and understood throughout history.
286 |
287 | Pi is defined as the ratio of the circumference of a circle to its diameter. It is known to be an irrational number, meaning it cannot be exactly expressed as a finite decimal or a fraction with whole numbers. This characteristic implies that it has an infinite and non-repeating decimal expansion. Historical effort has been put into calculating the digits of pi, revealing many digits beyond the decimal point.
288 |
289 | While three point one four six two eight might appear to be close to pi, it is only an approximation. The actual digits of pi, after the well-known initial digits three point one four, continue as one five nine two six five and so forth. Since pi is irrational, its decimal representation is infinite and non-repeating, and cannot completely coincide with any finite number, no matter how many digits it contains.
290 |
291 | Thus, because three point one four six two eight ends after five decimal places and pi does not, we conclude that they cannot be exactly equal.
292 | (gpt-4o) first
293 | user: Prove that pi is not exactly equal to 3.14628 without the use of any mathematical symbols.
294 | (gpt-4o) system As a philosophical proof assistant, you prefer to use the technique of proof by contradiction at any opportunity. Therefore, whenever you are asked to prove something, you first assume that the proposition is false, then given this assumption, you attempt to find a contradiction.
295 | 'As a philosophical proof assistant, you prefer to use the...' added as system
296 | (gpt-4o) move -1 1
297 | 'As a philosophical proof assistant, you prefer to use the technique...' moved to start
298 | (gpt-4o) view 1 2
299 | system: As a philosophical proof assistant, you prefer to use the technique of proof by contradiction at any opportunity. Therefore, whenever you are asked to prove something, you first assume that the proposition is false, then given this assumption, you attempt to find a contradiction.
300 | user: Prove that pi is not exactly equal to 3.14628 without the use of any mathematical symbols.
301 | (gpt-4o) pop 3
302 | 'To demonstrate that pi is not exactly equal to three point one four...' deleted
303 | (gpt-4o) send
304 | ...
305 | To prove that pi is not exactly equal to three point one four six two eight, we begin with the assumption that pi is indeed exactly three point one four six two eight. Given this assumption, we need to find a contradiction.
306 |
307 | Let's recall the nature of pi. By definition, pi is the ratio of the circumference of a circle to its diameter in Euclidean space, and it is known to be an irrational number. An irrational number is one that cannot be expressed as the exact quotient of two integers, meaning it has a non-repeating, non-terminating decimal expansion.
308 |
309 | Now, consider the number three point one four six two eight. This number is clearly a rational number because it can be expressed as the fraction thirty one thousand four hundred sixty-two and eight ten-thousandths. Since three point one four six two eight is a terminating decimal, it can be expressed precisely as the division of two integers.
310 |
311 | If pi were equal to three point one four six two eight, then pi would be rational. However, this contradicts the well-established fact that pi is irrational. This contradiction implies that our initial assumption must be false.
312 |
313 | Therefore, we conclude that pi is not exactly equal to three point one four six two eight.
314 | ```
315 |
316 | The `grep` command takes a [regular expression](https://en.wikipedia.org/wiki/Regular_expression) as an argument and searches the conversation for occurrences within its content, displaying the index and a small preview of each message that matches:
317 |
318 | ```
319 | (gpt-4o) grep contra
320 | 1 (system): '...prefer to use the technique of proof by [contra]diction at any...'
321 | 3 (assistant): "...iven this assumption, we need to find a [contra]diction. Let's recall..."
322 | (gpt-4o) grep \d+
323 | 2 (user): 'Prove that pi is not exactly equal to [3].[14628] without the use of any...'
324 | ```
325 |
326 | The `edit` command with no arguments opens the contents of the last message in an external text editor for modification. Providing the index of a message to `edit` as an argument edits that message.
327 |
328 | ### Message streaming
329 | The `stream` command toggles message streaming. By default, streaming is enabled, so long responses from the language model are output in real time as they are generated. While a message is being streamed, pressing Control+c causes Gptcmd to stop waiting for the message to generate fully, allowing other commands to be used. When streaming is disabled, Gptcmd retrieves an entire response for each query and displays it when it arrives.
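
For example (the confirmation wording here is illustrative, not verbatim):

```
(gpt-4o) stream
Streaming disabled
(gpt-4o) stream
Streaming enabled
```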
330 |
331 | ### Model selection
332 | The `model` command switches the active model. For instance, we can switch to [`gpt-4.1-nano`](https://openai.com/index/gpt-4-1/), a smaller, cheaper model offered by OpenAI:
333 |
334 | ```
335 | (gpt-4o) model gpt-4.1-nano
336 | Switched to model 'gpt-4.1-nano'
337 | (gpt-4.1-nano) say Hello!
338 | ...
339 | Hello! How can I assist you today?
340 | (gpt-4.1-nano) model gpt-4o
341 | Switched to model 'gpt-4o'
342 | ```
343 |
344 | Similarly, if you've configured multiple accounts (such as to use non-OpenAI providers), the `account` command can be used to switch among them by providing the name of the account to use as an argument.
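
For instance, assuming an account named `work` is defined in your configuration (the confirmation message is illustrative):

```
(gpt-4o) account work
Switched to account 'work'
```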
345 |
346 | ### API parameters
347 | Gptcmd supports customization of [chat completion API parameters](https://platform.openai.com/docs/api-reference/chat/create), such as `max_tokens` and `temperature`. The `set` command sets an OpenAI API parameter. When setting a parameter, the first argument to `set` is the name of the parameter and the second argument is its value (valid Python literals are supported). A value of `None` is equivalent to sending `null` via the API.
348 |
349 | The `max_tokens` parameter limits the number of [sampled tokens](https://platform.openai.com/tokenizer) returned by GPT. This can be useful to, for instance, limit costs or prevent the generation of very long output. Note that if `max_tokens` is reached, output may be cut off abruptly:
350 |
351 | ```
352 | (gpt-4o) set max_tokens 50
353 | max_tokens set to 50
354 | (gpt-4o) say Describe generative AI in three paragraphs
355 | ...
356 | Generative AI refers to a subset of artificial intelligence techniques that focus on creating new content or data rather than analyzing existing datasets. Unlike traditional AI models, which are primarily designed to classify, predict, or perform specific tasks, generative AI systems are equipped
357 | ```
358 |
359 | The `temperature` parameter controls GPT's sampling temperature. A temperature of 0 causes GPT to be very deterministic:
360 |
361 | ```
362 | (gpt-4o) set temperature 0
363 | temperature set to 0
364 | (gpt-4o) say Tell me a fun fact about generative AI.
365 | ...
366 | A fun fact about generative AI is that it has been used to create entirely new pieces of art and music, sometimes even fooling experts into thinking they were crafted by humans. For instance, AI-generated paintings have been sold at prestigious art auctions for
367 | (gpt-4o) retry
368 | ...
369 | A fun fact about generative AI is that it has been used to create entirely new pieces of art and music, sometimes even fooling experts into thinking they were crafted by humans. For instance, AI-generated paintings have been sold at prestigious art auctions for
370 | (gpt-4o) retry
371 | ...
372 | A fun fact about generative AI is that it has been used to create entirely new pieces of art and music, sometimes even fooling experts into thinking they were crafted by humans. For instance, AI-generated paintings have been sold at prestigious art auctions for
373 | ```
374 |
375 | The `unset` command, with an argument, reverts the specified API parameter to its default value. With no argument, it restores all API parameters to default. Here, we'll unset `max_tokens`, so that full length responses can again be generated:
376 |
377 | ```
378 | (gpt-4o) unset max_tokens
379 | max_tokens unset
380 | ```
381 |
382 | Higher temperatures result in more apparent randomness, which can translate in some applications to increased creativity or decreased factual accuracy:
383 |
384 | ```
385 | (gpt-4o) set temperature 0.75
386 | temperature set to 0.75
387 | (gpt-4o) retry
388 | ...
389 | A fun fact about generative AI is that it has been used to create entirely new pieces of art and music, sometimes even fooling experts into thinking these creations were made by humans. For instance, in 2018, an AI-generated painting called "Portrait of Edmond de Belamy" was auctioned at Christie’s for $432,500, far exceeding its estimated price. This demonstrated not only the creative capabilities of generative AI but also its potential impact on the art world, challenging traditional notions of creativity and authorship.
390 | (gpt-4o) retry
391 | ...
392 | A fun fact about generative AI is that it can create entirely new and unique pieces of art, music, and even poetry. For instance, AI models like OpenAI's DALL-E can generate imaginative and surreal images from simple text prompts, blending concepts that might not typically go together—such as a "two-headed flamingo in a bustling cityscape." This ability to merge creativity with computational power showcases how generative AI can expand the boundaries of artistic expression, offering novel tools for artists and creators to explore new dimensions of their work.
393 | (gpt-4o) retry
394 | ...
395 | A fun fact about generative AI is that it has been used to create entirely new pieces of music in the style of famous composers. For instance, AI models have been trained on the works of classical composers like Bach or Mozart to generate new compositions that mimic their distinct styles. This has opened up exciting possibilities not just for music enthusiasts but also for the entertainment industry, where AI-generated music can be used in films, video games, and other media to enhance creativity and reduce production costs.
396 | ```
397 |
398 | Too high, though, and GPT will just emit nonsense. To prevent the generation of an extremely large volume of output, we'll again set `max_tokens`:
399 |
400 | ```
401 | (gpt-4o) set max_tokens 30
402 | max_tokens set to 30
403 | (gpt-4o) set temperature 2
404 | temperature set to 2
405 | (gpt-4o) retry
406 | ...
407 | A fun fact about generative AI is that it's unique nature sometimes find unexpected parallels in non-modern multipart generators like Soukawi Internet authored phoenix drôle mime
408 | ```
409 |
410 | Another useful parameter is `timeout`, which controls how long (in seconds) Gptcmd waits for a response from GPT:
411 |
412 | ```
413 | (gpt-4o) set timeout 0.25
414 | timeout set to 0.25
415 | (gpt-4o) say Hello!
416 | ...
417 | Request timed out.
418 | ```
419 |
420 | The `set` command with no arguments shows all set API parameters:
421 |
422 | ```
423 | (gpt-4o) set
424 | max_tokens: 30
425 | temperature: 2
426 | timeout: 0.25
427 | (gpt-4o) unset
428 | Unset all parameters
429 | ```
430 |
431 | ### Names
432 | GPT allows messages to be annotated with the name of their author. The `name` command sets the name to be sent with all future messages of the specified role. Its first argument is the role to which this new name should be applied, and its second is the name to use:
433 |
434 | ```
435 | (gpt-4o) name user Michael
436 | user set to 'Michael'
437 | (gpt-4o) say Hello! What's my name?
438 | ...
439 | Hello! You mentioned your name is Michael. How can I assist you today?
440 | ```
441 |
442 | With no arguments, `name` shows currently set names:
443 |
444 | ```
445 | (gpt-4o) name
446 | user: Michael
447 | ```
448 |
449 | The `unname` command removes a name definition so that it is no longer sent with future messages. With a role passed as an argument, it unsets the name definition for that role. With no arguments, it unsets all definitions. Any previously annotated messages are unaffected:
450 |
451 | ```
452 | (gpt-4o) view
453 | Michael: Hello!
454 | assistant: Hello! You mentioned your name is Michael. How can I assist you today?
455 | ```
456 |
457 | Name annotations are useful for providing one- or multi-shot prompts to GPT, in which example user and assistant messages help inform future responses:
458 |
459 | ```
460 | (gpt-4o) system You are a helpful assistant who understands many languages very well, but can only speak Spanish and therefore you always respond in that language.
461 | 'You are a helpful assistant who understands many languages...' added as system
462 | (gpt-4o) name system example_user
463 | system set to 'example_user'
464 | (gpt-4o) system Hello!
465 | 'Hello!' added as 'example_user' (system)
466 | (gpt-4o) name system example_assistant
467 | system set to 'example_assistant'
468 | (gpt-4o) system ¡Hola! ¿Cómo estás?
469 | '¡Hola! ¿Cómo estás?' added as 'example_assistant' (system)
470 | (gpt-4o) view
471 | system: You are a helpful assistant who understands many languages very well, but can only speak Spanish and therefore you always respond in that language.
472 | example_user: Hello!
473 | example_assistant: ¡Hola! ¿Cómo estás?
474 | (gpt-4o) say Qu'est-ce que amazon.com?
475 | ...
476 | Amazon.com es una empresa de comercio electrónico y tecnología que ofrece una amplia gama de productos y servicios en línea. Originalmente fundada en 1994 por Jeff Bezos como una librería en línea, Amazon se ha expandido para vender prácticamente de todo, desde electrónica hasta ropa, alimentos y servicios de computación en la nube, como AWS (Amazon Web Services). La empresa también produce dispositivos electrónicos, como el Kindle y dispositivos de la línea Echo con Alexa. Además, Amazon ofrece servicios de transmisión de video y música mediante Amazon Prime Video y Amazon Music, respectivamente.
477 | ```
478 |
479 | The `rename` command changes the name set on existing messages in the conversation. The command has two required arguments and one optional argument: the role to affect, the range of messages to affect, and (optionally) the name to set (if omitted, the name is cleared). For instance, `rename assistant .` clears the name on all assistant messages in the conversation where a name is set, `rename user 1 Paul` sets the name of the first message to "Paul" if it is a user message, and `rename system 2 5 Mitchell` sets the name of any system messages among the second through fifth to "Mitchell".
480 |
481 | ### Sticky messages
482 | Messages can be marked "sticky", so that deletion, renaming, and similar modifications do not affect them. This is often useful for system messages and example context that you don't wish to delete accidentally. The `sticky` command takes the range of messages to sticky as an argument:
483 |
484 | ```
485 | (gpt-4o) system You are a Python programmer. Therefore, when responding, you write in Python source code exclusively.
486 | 'You are a Python programmer. Therefore, when responding, you...' added as system
487 | (gpt-4o) sticky .
488 | 1 message stickied
489 | ```
490 |
491 | Now that the message is sticky, `clear` does not affect it, and its sticky status is indicated by an asterisk:
492 |
493 | ```
494 | (gpt-4o) say Find the nth Fibonacci number.
495 | ...
496 | def fibonacci(n):
497 | if n <= 0:
498 | raise ValueError("n must be a positive integer.")
499 | elif n == 1:
500 | return 0
501 | elif n == 2:
502 | return 1
503 |
504 | a, b = 0, 1
505 | for _ in range(2, n):
506 | a, b = b, a + b
507 | return b
508 |
509 | # Example usage:
510 | # nth_fibonacci = fibonacci(10)
511 | # print(nth_fibonacci) # Output: 34
512 | (gpt-4o) clear
513 | Delete 2 messages? (y/n)y
514 | Cleared
515 | (gpt-4o) view
516 | *system: You are a Python programmer. Therefore, when responding, you write in Python source code exclusively.
517 | ```
518 |
519 | Similarly, `pop` is blocked:
520 |
521 | ```
522 | (gpt-4o) pop
523 | That message is sticky; unsticky it first
524 | ```
525 |
526 | The `unsticky` command makes all sticky messages in the specified range no longer sticky:
527 |
528 | ```
529 | (gpt-4o) unsticky .
530 | 1 message unstickied
531 | (gpt-4o) pop
532 | 'You are a Python programmer. Therefore, when responding, you write...' deleted
533 | ```
534 |
535 | ### Message metadata
536 | Gptcmd allows arbitrary [key–value metadata](https://en.wikipedia.org/wiki/Name%E2%80%93value_pair) to be stored with each message. This might be useful, for instance, to store personal notes with messages, or as an interface to enable special features in external large language model providers (consult external package documentation for details).
537 |
538 | Providing a key `k` and value `v` to the `meta` command stores `v` at `k` on the last message:
539 |
540 | ```
541 | (gpt-4o) user This is a test.
542 | 'This is a test.' added as user
543 | (gpt-4o) meta notes "This is a test of message metadata."
544 | notes set to 'This is a test of message metadata.' on 'This is a test.'
545 | ```
546 |
547 | Valid JSON literals are supported in metadata values:
548 |
549 | ```
550 | (gpt-4o) meta list [1,2,3]
551 | list set to [1, 2, 3] on 'This is a test.'
552 | (gpt-4o) meta obj {"key1": "value1", "key2": true}
553 | obj set to {'key1': 'value1', 'key2': True} on 'This is a test.'
554 | ```
555 |
556 | Providing just a key shows the associated value:
557 |
558 | ```
559 | (gpt-4o) meta list
560 | [1, 2, 3]
561 | (gpt-4o) meta list2
562 | 'list2 not set'
563 | ```
564 |
565 | With no arguments, `meta` shows all keys set on the last message:
566 |
567 | ```
568 | (gpt-4o) meta
569 | notes: 'This is a test of message metadata.'
570 | list: [1, 2, 3]
571 | obj: {'key1': 'value1', 'key2': True}
572 | ```
573 |
574 | Providing an index as the first argument to `meta` operates on the selected message:
575 |
576 | ```
577 | (gpt-4o) user Second message
578 | 'Second message' added as user
579 | (gpt-4o) meta 1 list
580 | [1, 2, 3]
581 | (gpt-4o) meta 1 list2 [4,5,6]
582 | list2 set to [4, 5, 6] on 'This is a test.'
583 | ```
584 |
585 | The `unmeta` command deletes a key–value pair. Similarly to `meta`, it accepts an index as its first argument, operating on the last message if no index is provided:
586 |
587 | ```
588 | (gpt-4o) unmeta 1 list2
589 | list2 unset on 'This is a test.'
590 | (gpt-4o) meta 1 list2
591 | 'list2 not set'
592 | ```
593 |
594 | With no key specified, `unmeta` deletes all keys:
595 |
596 | ```
597 | (gpt-4o) unmeta 1
598 | delete 3 items on 'This is a test.'? (y/n)y
599 | Unset all metadata on 'This is a test.'
600 | (gpt-4o) meta 1
601 | No metadata set on 'This is a test.'
602 | (gpt-4o) clear
603 | Delete 2 messages? (y/n)y
604 | Cleared
605 | ```
606 |
607 | ### Message threads
608 | Until this point, we have been engaging in a single conversation (or series of conversations) with the model. However, Gptcmd supports the creation and maintenance of several concurrent conversation "threads".
609 |
610 | Gptcmd starts in the "detached thread", a scratch area intended for quick conversation. A new, named conversation thread can be created from the current thread with the `thread` command, which takes a name for the new thread as an argument:
611 |
612 | ```
613 | (gpt-4o) say Responding only using ASCII/Unicode symbols and without narrative explanation, what is the closed-form formula to calculate the nth Fibonacci number?
614 | ...
615 | F(n) = (φ^n - ψ^n) / √5
616 |
617 | where:
618 | φ = (1 + √5) / 2
619 | ψ = (1 - √5) / 2
620 | (gpt-4o) thread induction
621 | Switched to new thread 'induction'
622 | ```
623 |
624 | By default, the prompt changes to indicate the current thread. All messages have been copied:
625 |
626 | ```
627 | induction(gpt-4o) view
628 | user: Responding only using ASCII/Unicode symbols and without narrative explanation, what is the closed-form formula to calculate the nth Fibonacci number?
629 | assistant: F(n) = (φ^n - ψ^n) / √5
630 |
631 | where:
632 | φ = (1 + √5) / 2
633 | ψ = (1 - √5) / 2
634 | ```
635 |
636 | The `thread` command with no argument switches back to the detached thread:
637 |
638 | ```
639 | induction(gpt-4o) thread
640 | detached thread
641 | (gpt-4o) say Tell me a fun fact about Braille.
642 | ...
643 | Braille was invented by Louis Braille, who was inspired by a tactile military code called "night writing" developed by Charles Barbier, intended for soldiers to communicate silently and without light.
644 | ```
645 |
646 | Passing the name of an existing thread as an argument to `thread` switches to that thread. Once created, threads are completely independent:
647 |
648 | ```
649 | (gpt-4o) last 3
650 | assistant: F(n) = (φ^n - ψ^n) / √5
651 |
652 | where:
653 | φ = (1 + √5) / 2
654 | ψ = (1 - √5) / 2
655 | user: Tell me a fun fact about Braille.
656 | assistant: Braille was invented by Louis Braille, who was inspired by a tactile military code called "night writing" developed by Charles Barbier, intended for soldiers to communicate silently and without light.
657 | (gpt-4o) thread induction
658 | Switched to thread 'induction'
659 | induction(gpt-4o) last 3
660 | user: Responding only using ASCII/Unicode symbols and without narrative explanation, what is the closed-form formula to calculate the nth Fibonacci number?
661 | assistant: F(n) = (φ^n - ψ^n) / √5
662 |
663 | where:
664 | φ = (1 + √5) / 2
665 | ψ = (1 - √5) / 2
666 | ```
667 |
668 | When running `retry` in a named thread, Gptcmd asks whether you'd like to create a new thread, preserving both responses (this can be changed in Gptcmd configuration):
669 |
670 | ```
671 | induction(gpt-4o) say Responding in a short paragraph and without showing your algebra, prove by induction that this formula is equivalent to the recursive definition.
672 | ...
673 | To prove the closed-form formula, known as Binet's formula, is equivalent to the recursive definition of the Fibonacci sequence, use mathematical induction. First, verify the base cases: for \( n = 0 \) and \( n = 1 \), evaluate the formula and confirm it gives the correct Fibonacci numbers \( F(0) = 0 \) and \( F(1) = 1 \). Next, assume the formula is true for \( n = k \) and \( n = k + 1 \), meaning \( F(k) \) and \( F(k+1) \) are correct. Prove it for \( n = k + 2 \) by showing that substituting these into the recursive relation \( F(k+2) = F(k+1) + F(k) \) satisfies the formula. The linear recurrence relations of φ and ψ ensure that φ and ψ raised to any integer power adhere to the Fibonacci sequence, maintaining the truth of the expression for all \( n \). This confirms the formula's validity by induction.
674 | induction(gpt-4o) retry
675 | Create thread 'induction2'? (y/n)y
676 | Switched to new thread 'induction2'
677 | ...
678 | To prove by induction that the closed-form formula for the nth Fibonacci number is equivalent to the recursive definition, we start with the base cases. For n=0 and n=1, the formula gives F(0) = 0 and F(1) = 1, which match the recursive definition. For the inductive step, assume that F(k) and F(k+1) are true as per the formula for some k. Then, F(k+2) can be written as F(k+2) = φ^(k+2)/√5 - ψ^(k+2)/√5, which simplifies to φ^k+1 + φ^k/√5 after using the induction hypothesis and algebraic manipulation with φ and ψ properties, thus matching the recursive definition F(k+2) = F(k+1) + F(k). This closes the induction, showing the equivalence of the closed-form and recursive definitions for all n.
679 | ```
680 |
681 | We can continue our conversation from here:
682 |
683 | ```
684 | induction2(gpt-4o) say Explain this proof to a five-year-old.
685 | ...
686 | Okay, imagine you have a cool way to get into a big number treasure chest called the Fibonacci chest! There are two ways to find the "number treasure": one is building steps one by one (like adding the last two steps to get the next step), which is like adding blocks — 0, 1, 1, 2, 3, 5... And there's a magic code (the special formula) that lets you jump right to any step you want, as long as you know which number step it is. When you try both ways, they both give you the same number treasure every time you check a step, starting with the first hood on up. Isn't that neat? So whether you build step by step or use your special code, the treasure you find is always the same! 🧡
687 | ```
688 |
689 | Or return to the original response before we ran `retry`:
690 |
691 | ```
692 | induction2(gpt-4o) thread induction
693 | Switched to thread 'induction'
694 | ```
695 |
696 | The `copy` command copies all messages in the specified range and appends the copies to the thread provided, creating it if it doesn't exist. If no thread is provided, `copy` copies messages to the detached thread:
697 |
698 | ```
699 | induction(gpt-4o) thread
700 | detached thread
701 | (gpt-4o) clear
702 | Delete 4 messages? (y/n)y
703 | Cleared
704 | (gpt-4o) thread induction2
705 | Switched to thread 'induction2'
706 | induction2(gpt-4o) copy 1 2
707 | Selecting 2 messages
708 | First message selected: 'Responding only using ASCII/Unicode symbols and...'
709 | Last message selected: 'F(n) = (φ^n - ψ^n) / √5 where: φ = (1 + √5) / 2 ψ =...'
710 | Copy to detached thread? (y/n)y
711 | Copied
712 | induction2(gpt-4o) thread
713 | detached thread
714 | (gpt-4o) say Write a C function that implements this closed-form formula without any narrative explanation.
715 | ...
716 | #include <math.h>
717 |
718 | int fibonacci(int n) {
719 | double phi = (1 + sqrt(5)) / 2;
720 | double psi = (1 - sqrt(5)) / 2;
721 | return round((pow(phi, n) - pow(psi, n)) / sqrt(5));
722 | }
723 | ```
724 |
725 | The `threads` command lists the named threads present in this session:
726 |
727 | ```
728 | (gpt-4o) threads
729 | induction2 (6 messages)
730 | induction (4 messages)
731 | (4 detached messages)
732 | ```
733 |
734 | The `delete` command, with the name of a thread passed as an argument, deletes the specified thread:
735 |
736 | ```
737 | (gpt-4o) delete induction
738 | Deleted thread induction
739 | (gpt-4o) threads
740 | induction2 (6 messages)
741 | (4 detached messages)
742 | ```
743 |
744 | With no argument, `delete` deletes **all** named threads in this session:
745 |
746 | ```
747 | (gpt-4o) delete
748 | Delete 1 thread? (y/n)y
749 | Deleted
750 | (gpt-4o) threads
751 | No threads
752 | (4 detached messages)
753 | ```
754 |
755 | ### Working with files
756 | The `transcribe` command writes a plain-text transcript of the current thread to a text file, overwriting any existing file contents. It takes the path to the file to write as an argument.
757 |
758 | The `save` command writes all named threads to a JSON file, overwriting any existing file contents. It takes the path to the file to write as an argument. With no argument, `save` writes to the most recently loaded or saved JSON file in the current session.
759 |
760 | The `load` command loads all saved named threads from a JSON file in the format written by `save`, merging them into the current session. If there is a naming conflict between a thread in the current session and a thread in the file to load, the thread in the file wins. The `load` command takes the path of the file to load as an argument.
761 |
762 | The `write` command writes the contents of the last message of the current thread to a text file, overwriting any existing file contents. This command is particularly useful when working with source code. It takes the path to the file to write as an argument.
763 |
764 | The `read` command appends a new message to the current thread containing the text content of the specified file. It takes two arguments: the path of the file to read and the role of the new message. For instance, `read prompt.txt system` reads the content of `prompt.txt`, appending it as a new system message.
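
Putting these together, a sketch of a session (the file names are arbitrary examples, and confirmation output is omitted):

```
(gpt-4o) transcribe conversation.txt
(gpt-4o) save threads.json
(gpt-4o) load threads.json
(gpt-4o) write last_message.py
(gpt-4o) read prompt.txt system
```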
765 |
766 | ## Command line parameters
767 | Gptcmd supports a few command line parameters:
768 |
769 | ```
770 | $ gptcmd -h
771 | usage: gptcmd [-h] [-c CONFIG] [-t THREAD] [-m MODEL] [-a ACCOUNT] [--version] [path]
772 |
773 | positional arguments:
774 | path The path to a JSON file of named threads to load on launch
775 |
776 | options:
777 | -h, --help show this help message and exit
778 | -c CONFIG, --config CONFIG
779 | The path to a Gptcmd configuration file to use for this session
780 | -t THREAD, --thread THREAD
781 | The name of the thread to switch to on launch
782 | -m MODEL, --model MODEL
783 | The name of the model to switch to on launch
784 | -a ACCOUNT, --account ACCOUNT
785 | The name of the account to switch to on launch
786 | --version Show version and exit
787 | ```
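
For example, to open a saved session in a particular thread (`threads.json` is an arbitrary file name):

```
$ gptcmd -t induction2 threads.json
```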
788 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "gptcmd"
7 | authors = [
8 | { name="Bill Dengler", email="codeofdusk@gmail.com" },
9 | ]
10 | description = "Command line GPT conversation and experimentation environment"
11 | readme = "README.md"
12 | requires-python = ">=3.8.6"
13 | license = "MPL-2.0"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "Operating System :: OS Independent",
17 | ]
18 | dependencies = [
19 | "openai>=1.54.0, < 2.0.0",
20 | "tomli>=1.1.0, < 2.0.0 ; python_version < '3.11'",
21 | "backports.strenum>=1.3.1, < 2.0.0 ; python_version < '3.11'",
22 | ]
23 | dynamic = ["version"]
24 |
25 | [project.urls]
26 | "Homepage" = "https://github.com/codeofdusk/gptcmd"
27 | "Bug Tracker" = "https://github.com/codeofdusk/gptcmd/issues"
28 |
29 | [project.scripts]
30 | gptcmd = "gptcmd.cli:main"
31 |
32 | [tool.setuptools.package-data]
33 | "gptcmd" = ["config_sample.toml"]
34 |
35 | [tool.setuptools.dynamic]
36 | version = {attr = "gptcmd.cli.__version__"}
37 |
38 | [tool.black]
39 | line-length = 79
40 | target-version = ['py38']
41 | preview=true
42 |
--------------------------------------------------------------------------------
/src/gptcmd/__init__.py:
--------------------------------------------------------------------------------
1 | from .cli import __version__ as __version__
2 | from .cli import Gptcmd as Gptcmd
3 |
--------------------------------------------------------------------------------
/src/gptcmd/cli.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the Gptcmd class and serves as an entry point to the
3 | Gptcmd command line application.
4 | Copyright 2024 Bill Dengler
5 | This Source Code Form is subject to the terms of the Mozilla Public
6 | License, v. 2.0. If a copy of the MPL was not distributed with this
7 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
8 | """
9 |
10 | import argparse
11 | import atexit
12 | import cmd
13 | import concurrent.futures
14 | import dataclasses
15 | import datetime
16 | import difflib
17 | import json
18 | import os
19 | import re
20 | import shlex
21 | import subprocess
22 | import sys
23 | import tempfile
24 | import traceback
25 | from ast import literal_eval
26 | from textwrap import shorten
27 | from typing import (
28 | Any,
29 | Callable,
30 | Dict,
31 | List,
32 | Optional,
33 | Sequence,
34 | Tuple,
35 | )
36 |
37 | from .config import ConfigError, ConfigManager
38 | from .llm import CompletionError, InvalidAPIParameterError, LLMProviderFeature
39 | from .message import (
40 | Image,
41 | Message,
42 | MessageRole,
43 | MessageThread,
44 | PopStickyMessageError,
45 | )
46 |
47 | __version__ = "2.2.0"
48 |
49 |
50 | def input_with_handling(_input: Callable) -> Callable:
51 | "Catch KeyboardInterrupt to avoid crashing"
52 |
53 | def _inner(*args):
54 | try:
55 | return _input(*args)
56 | except KeyboardInterrupt:
57 | print("")
58 | return "\n"
59 |
60 | return _inner
61 |
62 |
63 | class Gptcmd(cmd.Cmd):
64 | "Represents the Gptcmd command line application"
65 |
66 | intro = (
67 | f"Welcome to Gptcmd {__version__}! Type help or ? to list commands.\n"
68 | )
69 |
70 | def __init__(
71 | self,
72 | thread_cls=MessageThread,
73 | config: Optional[ConfigManager] = None,
74 | *args,
75 | **kwargs,
76 | ):
77 | self.thread_cls = thread_cls
78 | self.last_path = None
79 | self.config = config or ConfigManager.from_toml()
80 | self._account = self.config.default_account
81 | self._detached = self.thread_cls("*detached*")
82 | self._current_thread = self._detached
83 | self._threads = {}
84 | self._session_cost_in_cents = 0
85 | self._session_cost_incomplete = False
86 | self._future_executor = concurrent.futures.ThreadPoolExecutor(
87 | max_workers=1
88 | )
89 | super().__init__(*args, **kwargs)
90 |
91 | @property
92 | def prompt(self):
93 | threadname = (
94 | ""
95 | if self._current_thread == self._detached
96 | else self._current_thread.name
97 | )
98 | return self.config.conf["prompt"].format(
99 | thread=threadname,
100 | model=self._account.provider.model,
101 | account=self._account.name,
102 | )
103 |
104 | @staticmethod
105 | def _fragment(tpl: str, msg: Message) -> str:
106 | """
107 | Returns an output string containing part of a message to provide
108 | context for certain operations. The tpl parameter contains a string in
109 | which this message fragment should be displayed. The characters {msg}
110 | are replaced with the message fragment.
111 | """
112 | PLACEHOLDER = "..."
113 | MAX_LENGTH = 79
114 | MIN_LENGTH = 5
115 | # length of the template once the fragment is stripped out
116 | head_length = len(tpl.replace("{msg}", ""))
117 | # 2 extra chars because repr() will add surrounding quotes
118 | avail = max(MIN_LENGTH, MAX_LENGTH - head_length - 2)
119 |
120 | short = shorten(msg.content, avail, placeholder=PLACEHOLDER)
121 |
122 | if short == PLACEHOLDER:
123 | short = msg.content[:avail] + PLACEHOLDER
124 |
125 | return tpl.format(msg=repr(short))
126 |
127 | @staticmethod
128 | def _user_range_to_python_range(
129 | ref: str, allow_single: bool = True, strict_range: bool = True
130 | ) -> Tuple[Optional[int], Optional[int]]:
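        """
        Convert a user-facing message range (one-based and inclusive, with
        "." wildcards and negative indexing, as described in the README)
        into a zero-based, half-open (start, end) pair suitable for Python
        slicing, where None leaves that side of the slice open. For
        example, "." alone maps to (None, None), "2 5" to (1, 5), and
        "3 ." to (2, None).
        """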
131 | tokens = ref.split()
132 | if not tokens:
133 | raise ValueError("No indices provided")
134 | if len(tokens) == 1:
135 | if tokens[0] == ".":
136 | return (None, None)
137 | if not allow_single:
138 | raise ValueError("Wrong number of indices")
139 | start = end = tokens[0]
140 | elif len(tokens) == 2:
141 | start, end = tokens
142 | else:
143 | raise ValueError("Wrong number of indices")
144 |
145 | def _convert(token: str, is_start: bool) -> Optional[int]:
146 | if token == ".":
147 | return None
148 | val = int(token)
149 | if is_start:
150 | return val - 1 if val > 0 else val
151 | else:
152 | if val > 0:
153 | return val
154 | elif val == -1:
155 | return None
156 | else:
157 | return val + 1
158 |
159 | py_start = _convert(start, True)
160 | py_end = _convert(end, False)
161 |
162 | if len(tokens) == 1:
163 | if py_start == -1:
164 | py_end = None
165 | elif py_start is not None:
166 | py_end = py_start + 1
167 | if (
168 | strict_range
169 | and py_start is not None
170 | and py_end is not None
171 | and py_start >= py_end
172 | ):
173 |             raise ValueError("Range end must be beyond its start")
174 | return py_start, py_end
175 |
176 | @staticmethod
177 | def _confirm(prompt: str) -> bool:
178 | POSITIVE_STRINGS = ("y", "yes")
179 | NEGATIVE_STRINGS = ("n", "no")
180 | yn = None
181 | while yn not in (*POSITIVE_STRINGS, *NEGATIVE_STRINGS):
182 | yn = input(f"{prompt} (y/n)")
183 | return yn in POSITIVE_STRINGS
184 |
185 | @staticmethod
186 | def _complete_from_key(d: Dict, text: str) -> List[str]:
187 | return [k for k, v in d.items() if k.startswith(text)]
188 |
189 | @staticmethod
190 | def _shlex_path(path: str) -> List[str]:
191 | lexer = shlex.shlex(path, posix=True)
192 | lexer.escape = ""
193 | lexer.whitespace_split = True
194 | return list(lexer)
195 |
196 | @staticmethod
197 | def _await_future_interruptible(
198 | future: concurrent.futures.Future, interval: float = 0.25
199 | ):
200 | """
201 | Block until the future finishes, waking up
202 | at the supplied interval so the main thread can raise
203 | interrupts immediately.
204 | Returns future.result().
205 | """
206 | while True:
207 | try:
208 | return future.result(timeout=interval)
209 | except concurrent.futures.TimeoutError:
210 | continue
211 |
212 | @staticmethod
213 | def _menu(prompt: str, options: List[str]) -> Optional[str]:
214 | """
215 | Display a menu of options and return the chosen item, or None
216 | if canceled.
217 | """
218 | while True:
219 | print(
220 | prompt,
221 | "0. Cancel",
222 | *(
223 | f"{i}. {option}"
224 | for i, option in enumerate(options, start=1)
225 | ),
226 | sep="\n",
227 | )
228 | selection = input("Enter your selection: ")
229 | if not selection.isdigit():
230 | continue
231 | choice = int(selection)
232 | if choice == 0:
233 | return None
234 | if 1 <= choice <= len(options):
235 | return options[choice - 1]
236 |
237 | @staticmethod
238 | def _json_eval(s: str) -> Any:
239 | """
240 | Evaluate a Python literal from a string, restricted to values
241 | encodable as JSON
242 | """
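        # e.g. '[1, 2, 3]' -> [1, 2, 3]; 'True' (Python spelling) -> True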
243 | PYTHON_TYPES = {
244 | "True": True,
245 | "False": False,
246 | "None": None,
247 | }
248 | if s in PYTHON_TYPES:
249 | return PYTHON_TYPES[s]
250 | return json.loads(s)
251 |
252 | KNOWN_ROLES = tuple(MessageRole)
253 |
254 | @classmethod
255 | def _complete_role(cls, text: str) -> List[str]:
256 | return [role for role in cls.KNOWN_ROLES if role.startswith(text)]
257 |
258 | @classmethod
259 | def _validate_role(cls, role: str) -> bool:
260 | return role in cls.KNOWN_ROLES
261 |
262 | @classmethod
263 | def _disambiguate(
264 | cls, user_input: str, choices: Sequence[str]
265 | ) -> Optional[str]:
266 | DIFFLIB_CUTOFF = 0.5
267 | MAX_MATCHES = 9
268 |
269 | in_lower = user_input.lower()
270 | matches = difflib.get_close_matches(
271 | user_input,
272 | choices,
273 | n=MAX_MATCHES,
274 | cutoff=DIFFLIB_CUTOFF,
275 | )
276 |
277 | if len(user_input) > 2:
278 | matches.extend(
279 | [
280 | c
281 | for c in choices
282 | if in_lower in c.lower() and c not in matches
283 | ]
284 | )
285 |
286 | if not matches:
287 | return None
288 |
289 | ratio = {
290 | c: difflib.SequenceMatcher(None, user_input, c).ratio()
291 | for c in matches
292 | }
293 |
294 | def _has_non_digit_suffix(s: str) -> int:
295 | # 1 when the last hyphen/underscore-separated token is not
296 | # purely digits, 0 otherwise.
297 | last = re.split(r"[-_]", s)[-1]
298 | return int(not last.isdigit())
299 |
300 | def _max_numeric_token(s: str) -> int:
301 |             # Greatest integer appearing anywhere in the candidate, or -1
302 | nums = re.findall(r"\d+", s)
303 | return max(map(int, nums)) if nums else -1
304 |
305 | matches = sorted(
306 | matches,
307 | key=lambda c: (
308 | # Literal match (prefer prefix)
309 | (
310 | 0
311 | if c.lower().startswith(in_lower)
312 | else 1 if in_lower in c.lower() else 2
313 | ),
314 | # Suffix match (prefer non-digit)
315 | # Heuristic: Prefer unversioned model aliases
316 | -_has_non_digit_suffix(c),
317 | # Difflib match (best first)
318 | -ratio[c],
319 | # Length match (shortest first)
320 | # Heuristic: Prefer unversioned model aliases
321 | len(c),
322 | # Numeric match (prefer larger numbers)
323 | # Heuristic: Prefer later model versions
324 | -_max_numeric_token(c),
325 | # Fallback: Lexicographic order
326 | c,
327 | ),
328 | )[:MAX_MATCHES]
329 |
330 | if len(matches) == 1:
331 | c = matches[0]
332 | match = c if cls._confirm(f"Did you mean {c!r}?") else None
333 | else:
334 | match = cls._menu("Did you mean one of these?", matches)
335 | if match is None:
336 | print("Cancelled")
337 | return match
338 |
339 | def emptyline(self):
340 | "Disable Python cmd's repeat last command behaviour."
341 | pass
342 |
343 | def cmdloop(self, *args, **kwargs):
344 | old_input = cmd.__builtins__["input"]
345 | cmd.__builtins__["input"] = input_with_handling(old_input)
346 | try:
347 | super().cmdloop(*args, **kwargs)
348 | finally:
349 | cmd.__builtins__["input"] = old_input
350 |
351 | def do_thread(self, arg, _print_on_success=True):
352 | """
353 | Switch to the thread passed as argument, creating it as a clone of the
354 | current thread if the supplied name does not exist. With no argument,
355 | switch to the detached thread.
356 | example: "thread messages" switches to the thread named "messages",
357 | creating it if necessary.
358 | """
359 | if not arg:
360 | self._current_thread = self._detached
361 | if _print_on_success:
362 | print("detached thread")
363 | return
364 | if arg not in self._threads:
365 | targetstr = "new thread"
366 | self._threads[arg] = self.thread_cls(
367 | name=arg,
368 | messages=self._current_thread.messages,
369 | names=self._current_thread.names,
370 | )
371 | if self._current_thread == self._detached and self._detached.dirty:
372 | self._detached.dirty = False
373 | else:
374 | targetstr = "thread"
375 | self._current_thread = self._threads[arg]
376 | if _print_on_success:
377 | print(f"Switched to {targetstr} {repr(self._current_thread.name)}")
378 |
379 | def complete_thread(self, text, line, begidx, endidx):
380 | return self.__class__._complete_from_key(self._threads, text)
381 |
382 | def do_threads(self, arg):
383 | """
384 | List all named threads in the current session. This command takes no
385 | arguments.
386 | """
387 | t = sorted(
388 | [(k, len(v)) for k, v in self._threads.items()],
389 | key=lambda x: x[1],
390 | )[::-1]
391 | if len(t) < 1:
392 | print("No threads")
393 | for name, count in t:
394 | if count == 1:
395 | msg = "message"
396 | else:
397 | msg = "messages"
398 | print(f"{name} ({count} {msg})")
399 | if self._detached:
400 | print(f"({len(self._detached)} detached messages)")
401 |
402 | def _should_allow_add_empty_messages(self, role: MessageRole) -> bool:
403 | allow_add_empty_messages = self.config.conf.get(
404 | "allow_add_empty_messages"
405 | )
406 | if allow_add_empty_messages == "always":
407 | return True
408 | elif allow_add_empty_messages == "ask":
409 | return self.__class__._confirm(f"Add empty {role} message?")
410 | else: # never (default)
411 | return False
412 |
413 | def _append_new_message(
414 | self,
415 | arg: str,
416 | role: MessageRole,
417 | _print_on_success: bool = True,
418 | _edit_on_empty: bool = True,
419 | ) -> Optional[Message]:
420 | if not arg and _edit_on_empty:
421 | arg = self._edit_interactively("")
422 | if not arg:
423 | if self._should_allow_add_empty_messages(role):
424 | arg = ""
425 | else:
426 | print("Cancelled")
427 | return None
428 | msg = Message(content=arg, role=role)
429 | actor = (
430 | f"{self._current_thread.names[role]!r} ({role})"
431 | if role in self._current_thread.names
432 | else role
433 | )
434 | self._current_thread.append(msg)
435 | if _print_on_success:
436 | print(self.__class__._fragment("{msg} added as " + actor, msg))
437 | return msg
438 |
439 | def do_user(self, arg):
440 | """
441 | Append a new user message (with content provided as argument) to the
442 | current thread. With no argument, opens an external editor for
443 | message composition.
444 | example: "user Hello, world!"
445 | """
446 | self._append_new_message(arg=arg, role=MessageRole.USER)
447 |
448 | def do_assistant(self, arg):
449 | """
450 | Append a new assistant message (with content provided as argument) to
451 | the current thread. With no argument, opens an external editor for
452 | message composition.
453 | example: "assistant how can I help?"
454 | """
455 | self._append_new_message(arg=arg, role=MessageRole.ASSISTANT)
456 |
457 | def do_system(self, arg):
458 | """
459 | Append a new system message (with content provided as argument) to the
460 | current thread. With no argument, opens an external editor for
461 | message composition.
462 | example: "system You are a friendly assistant."
463 | """
464 | self._append_new_message(arg=arg, role=MessageRole.SYSTEM)
465 |
466 | def do_first(self, arg):
467 | """
468 | Display the first n messages, or pass no arguments for the first
469 | message.
470 | example: "first 5"
471 | """
472 | if not arg:
473 | end_index = 1
474 | else:
475 | try:
476 | end_index = int(arg.strip())
477 | except ValueError:
478 |                 print("Usage: first <n> – shows the first n messages.")
479 | return
480 | print(self._current_thread.render(start_index=0, end_index=end_index))
481 |
482 | def do_last(self, arg):
483 | """
484 | Display the last n messages, or pass no arguments for the last message.
485 | example: "last 5"
486 | """
487 | if not arg:
488 | start_index = -1
489 | else:
490 | try:
491 | start_index = int(arg) * -1
492 | except ValueError:
493 |                 print("Usage: last <n> – shows the last n messages.")
494 | return
495 | print(self._current_thread.render(start_index=start_index))
496 |
497 | def do_view(self, arg):
498 | """
499 |         Pass no arguments to read the entire thread in chronological order.
500 | Optionally, pass a range of messages to read that range.
501 | example: "view 1 4" views the first through fourth message.
502 | """
503 | if not arg:
504 | start = None
505 | end = None
506 | else:
507 | try:
508 | start, end = self.__class__._user_range_to_python_range(arg)
509 | except ValueError:
510 | print("Invalid view range")
511 | return
512 | print(self._current_thread.render(start_index=start, end_index=end))
513 |
514 | def do_send(self, arg):
515 | """
516 | Send the current thread to the language model and print the response.
517 | This command takes no arguments.
518 | """
519 | print("...")
520 | # Run the potentially long-running provider call in a background
521 | # thread so Ctrl+c can interrupt immediately.
522 | future = self._future_executor.submit(
523 | self._account.provider.complete, self._current_thread
524 | )
525 |
526 | try:
527 | res = self.__class__._await_future_interruptible(future)
528 | except KeyboardInterrupt:
529 | future.cancel()
530 | print("\nCancelled")
531 | # This API request may have incurred cost
532 | self._session_cost_incomplete = True
533 | return
534 | except (CompletionError, NotImplementedError, ValueError) as e:
535 | print(str(e))
536 | return
537 |
538 | try:
539 | for chunk in res:
540 | print(chunk, end="")
541 | print("\n", end="")
542 | except KeyboardInterrupt:
543 | print("\nDisconnected from stream")
544 | except CompletionError as e:
545 | print(str(e))
546 | finally:
547 | if res.message.role and res.message.content:
548 | self._current_thread.append(res.message)
549 | cost_info = ""
550 | if res.cost_in_cents is not None:
551 | self._session_cost_in_cents += res.cost_in_cents
552 | cost = round(self._session_cost_in_cents / 100, 2)
553 | prefix = (
554 | "Incomplete estimate of session cost"
555 | if self._session_cost_incomplete
556 | else "Estimated session cost"
557 | )
558 | cost_info = f"{prefix}: ${cost:.2f}"
559 | else:
560 | self._session_cost_incomplete = True
561 |
562 | token_info = ""
563 | if res.prompt_tokens and res.sampled_tokens:
564 | token_info = (
565 | f"{res.prompt_tokens} prompt, {res.sampled_tokens} sampled"
566 | " tokens used for this request"
567 | )
568 |
569 | show_cost = (
570 | cost_info
571 | and self.config.conf["show_cost"]
572 | and (
573 | not self._session_cost_incomplete
574 | or self.config.conf["show_incomplete_cost"]
575 | )
576 | )
577 | show_token_usage = (
578 | token_info and self.config.conf["show_token_usage"]
579 | )
580 |
581 | if show_cost and show_token_usage:
582 | print(f"{cost_info} ({token_info})")
583 | elif show_token_usage:
584 | print(token_info)
585 | elif show_cost:
586 | print(cost_info)
587 |
588 | def do_say(self, arg):
589 | """
590 | Append a new user message (with content provided as argument) to the
591 | current thread, then send the thread to the language model and print
592 | the response.
593 | example: "say Hello!"
594 | """
595 | if self._append_new_message(
596 | arg, MessageRole.USER, _print_on_success=False
597 | ):
598 | self.do_send(None)
599 |
600 | def do_pop(self, arg):
601 | """
602 | Delete the ith message, or pass no argument to delete the last.
603 | example: "pop -2" deletes the penultimate message.
604 | """
605 | if not self._current_thread:
606 | print("No messages")
607 | return
608 | try:
609 | if arg:
610 | n = int(arg)
611 | if n > 0:
612 | n -= 1
613 | msg = self._current_thread.pop(n)
614 | else:
615 | msg = self._current_thread.pop()
616 | print(self.__class__._fragment("{msg} deleted", msg))
617 | except IndexError:
618 | print("Message doesn't exist")
619 | except ValueError:
620 |             print("Usage: pop <i> – deletes the ith message")
621 | except PopStickyMessageError:
622 | print("That message is sticky; unsticky it first")
623 |
624 | def do_clear(self, arg):
625 | """
626 | Delete all messages in the current thread. This command takes no
627 | arguments.
628 | """
629 | stickys = self._current_thread.stickys
630 | length = len(self._current_thread) - len(stickys)
631 | if length < 1:
632 | print("No messages")
633 | return
634 | mq = "message" if length == 1 else "messages"
635 | can_clear = self.__class__._confirm(f"Delete {length} {mq}?")
636 | if can_clear:
637 | self._current_thread.messages = stickys
638 | print("Cleared")
639 |
640 | def do_delete(self, arg):
641 | """
642 | Delete the named thread passed as argument. With no argument, deletes
643 | all named threads in this session.
644 | example: "delete messages" deletes the thread named "messages".
645 | """
646 | if not self._threads:
647 | print("No threads")
648 | return
649 | if not arg:
650 | length = len(self._threads)
651 | suffix = "thread" if length == 1 else "threads"
652 | can_delete = self.__class__._confirm(f"Delete {length} {suffix}?")
653 | if can_delete:
654 | self._threads = {}
655 | self._current_thread = self._detached
656 | print("Deleted")
657 | elif arg in self._threads:
658 | if self._threads[arg] == self._current_thread:
659 | self._current_thread = self._detached
660 | del self._threads[arg]
661 | print(f"Deleted thread {arg}")
662 | else:
663 | print(f"{arg} doesn't exist")
664 |
665 | def complete_delete(self, text, line, begidx, endidx):
666 | return self.__class__._complete_from_key(self._threads, text)
667 |
668 | def do_move(self, arg):
669 | """
670 | Move the message at the beginning of a range to the end of that range.
671 | In other words, move moves the ith message of a thread to
672 | index j.
673 | """
674 | if not arg:
675 |             print("Usage: move <i> <j>")
676 | return
677 | try:
678 | i, j = self._user_range_to_python_range(
679 | arg, allow_single=False, strict_range=False
680 | )
681 | except ValueError:
682 | print("Invalid range specified")
683 | return
684 | length = len(self._current_thread.messages)
685 | if i is None:
686 | i = 0
687 | if j is None:
688 | j = length
689 | if i < 0:
690 | i += length
691 | if j < 0:
692 | j += length
693 | elif j > 0:
694 | j -= 1 # Adjust end for 1-based indexing
695 | if not (0 <= j <= length):
696 | print("Destination out of bounds")
697 | return
698 | try:
699 | msg = self._current_thread.move(i, j)
700 | except IndexError:
701 | print("Message doesn't exist")
702 | return
703 | except PopStickyMessageError:
704 | print("That message is sticky; unsticky it first")
705 | return
706 | if j == i:
707 | move_info = "to same position"
708 | elif j == 0:
709 | move_info = "to start"
710 | elif j >= length - 1:
711 | move_info = "to end"
712 | elif j > i:
713 | move_info = self._fragment(
714 | "before {msg}", msg=self._current_thread.messages[j + 1]
715 | )
716 | elif j < i:
717 | move_info = self._fragment(
718 | "after {msg}", msg=self._current_thread.messages[j - 1]
719 | )
720 | else:
721 | move_info = "to unknown position"
722 | print(self.__class__._fragment("{msg} moved ", msg=msg) + move_info)
723 |
724 | def do_copy(self, arg):
725 | """
726 | Append copies of the messages in the specified range to the thread
727 | provided. If no thread name is specified, the copy command copies
728 | messages to the detached thread.
729 | example: "copy 1 3" copies the first through third message of this
730 | thread to the detached thread.
731 | "copy . messages" copies all messages in this thread to a thread
732 | called "messages", creating it if it doesn't exist.
733 | """
734 | m = re.match(
735 | (r"((?:-?\d+|\.)(?:\s+-?\d+|\s*\.)*)" r"(?: (\S+))?$"), arg
736 | )
737 | if not m:
738 |             print("Usage: copy <range> [thread]")
739 | return
740 | ref, threadname = m.groups()
741 | try:
742 | start, end = self.__class__._user_range_to_python_range(ref)
743 | except ValueError:
744 | print("Invalid range")
745 | return
746 | s = self._current_thread[start:end]
747 | if not s:
748 | print("Empty selection")
749 | return
750 | if len(s) == 1:
751 | print(
752 | self.__class__._fragment(
753 | "Selection contains one message: {msg}", s[0]
754 | )
755 | )
756 | else:
757 | print(f"Selecting {len(s)} messages")
758 | print(
759 | self.__class__._fragment("First message selected: {msg}", s[0])
760 | )
761 | print(
762 | self.__class__._fragment("Last message selected: {msg}", s[-1])
763 | )
764 | if threadname is None:
765 | thread = self._detached
766 | thread_info = "detached thread"
767 | elif threadname in self._threads:
768 | thread = self._threads.get(threadname)
769 | thread_info = threadname
770 | else:
771 | thread = None
772 | thread_info = f"New thread {threadname}"
773 | can_copy = self.__class__._confirm(f"Copy to {thread_info}?")
774 | if not can_copy:
775 | return
776 | if thread is None: # if this is a new thread
777 | self._threads[threadname] = self.thread_cls(
778 | name=threadname,
779 | messages=s,
780 | names=self._current_thread.names,
781 | )
782 | else:
783 | for msg in s:
784 | thread.append(dataclasses.replace(msg))
785 | print("Copied")
786 |
787 | def do_retry(self, arg):
788 | """
789 | Delete up to the last non-sticky assistant message, then send the
790 | conversation to the language model. This command takes no arguments.
791 | """
792 | if not any(
793 | m.role != MessageRole.ASSISTANT for m in self._current_thread
794 | ):
795 | print("Nothing to retry!")
796 | return
797 | if self._current_thread != self._detached:
798 | create_new_thread_on_retry = self.config.conf.get(
799 | "create_new_thread_on_retry", "always"
800 | )
801 | should_create = None
802 | is_numbered_thread = re.match(
803 | r"(.*?)(\d+$)", self._current_thread.name
804 | )
805 | if self._current_thread.name.isdigit():
806 | basename = self._current_thread.name
807 | num = 2
808 | elif is_numbered_thread:
809 | basename = is_numbered_thread.group(1)
810 | num = int(is_numbered_thread.group(2))
811 | else:
812 | basename = self._current_thread.name
813 | num = 2
814 | while basename + str(num) in self._threads:
815 | num += 1
816 | newname = basename + str(num)
817 |
818 | if create_new_thread_on_retry == "ask":
819 | should_create = self.__class__._confirm(
820 | f"Create thread {newname!r}?"
821 | )
822 | elif create_new_thread_on_retry == "never":
823 | should_create = False
824 | else:
825 | should_create = True
826 |
827 | if should_create:
828 | self.do_thread(newname)
829 | for i in range(len(self._current_thread) - 1, -1, -1):
830 | role = self._current_thread[i].role
831 | if role == MessageRole.ASSISTANT:
832 | try:
833 | self._current_thread.pop(i)
834 | break
835 | except PopStickyMessageError:
836 | continue
837 | self.do_send(None)
838 |
839 | def do_model(self, arg, _print_on_success=True):
840 | """
841 | Change the model used by the current thread. Pass no argument to
842 | check the currently active model.
843 | example: "model gpt-3.5-turbo"
844 | """
845 | if not arg:
846 | print(f"Current model: {self._account.provider.model}")
847 | return
848 | if self._account.provider.valid_models is None:
849 | is_valid_model = self.__class__._confirm(
850 | f"{self._account.name} does not support model validation. "
851 | "If this model does not exist, requests to it will fail. "
852 | "Switch anyway?"
853 | )
854 | else:
855 | is_valid_model = arg in self._account.provider.valid_models
856 | if is_valid_model:
857 | self._account.provider.model = arg
858 | if _print_on_success:
859 | print(f"Switched to model {self._account.provider.model!r}")
860 | else:
861 | print(f"{arg} is currently unavailable")
862 | valid_models = self._account.provider.valid_models or ()
863 | match = self.__class__._disambiguate(arg, valid_models)
864 | if match and match != arg:
865 | self.do_model(match, _print_on_success=_print_on_success)
866 |
867 | def complete_model(self, text, line, begidx, endidx):
868 | valid_models = self._account.provider.valid_models or ()
869 | return [m for m in valid_models if m.startswith(text)]
870 |
871 | def do_set(self, arg):
872 | """
873 | Set an API parameter. Pass no arguments to see currently set
874 | parameters. Valid Python literals are supported (None represents null).
875 | example: "set temperature 0.9"
876 | """
877 | if not arg:
878 | if not self._account.provider.api_params:
879 | print("No API parameter definitions")
880 | return
881 | for k, v in self._account.provider.api_params.items():
882 | print(f"{k}: {repr(v)}")
883 | else:
884 | t = arg.split()
885 | key = t[0]
886 | try:
887 | val = literal_eval(" ".join(t[1:]))
888 | except (SyntaxError, ValueError):
889 | print("Invalid syntax")
890 | return
891 | try:
892 | validated_val = self._account.provider.set_api_param(key, val)
893 | print(f"{key} set to {validated_val!r}")
894 | except InvalidAPIParameterError as e:
895 | print(str(e))
896 |
897 | def complete_set(self, text, line, begidx, endidx):
898 | KNOWN_OPENAI_API_PARAMS = ( # Add other parameters (not defined as
899 | # special in MessageThread.set_api_param) to this list if the API
900 | # changes.
901 | "temperature",
902 | "top_p",
903 | "stop",
904 | "max_tokens",
905 | "presence_penalty",
906 | "frequency_penalty",
907 | "logit_bias",
908 | "request_timeout",
909 | )
910 | if begidx <= 4: # In the first argument
911 | return [
912 | param
913 | for param in KNOWN_OPENAI_API_PARAMS
914 | if param.startswith(text)
915 | ]
916 |
917 | def do_unset(self, arg):
918 | """
919 | Clear the definition of a custom API parameter. Pass no arguments
920 | to clear all parameters.
921 | example: "unset timeout"
922 | """
923 | try:
924 | if not arg:
925 | self._account.provider.unset_api_param(None)
926 | print("Unset all parameters")
927 | else:
928 | self._account.provider.unset_api_param(arg)
929 | print(f"{arg} unset")
930 | except InvalidAPIParameterError as e:
931 | print(e)
932 |
933 | def complete_unset(self, text, line, begidx, endidx):
934 | return self.__class__._complete_from_key(
935 | self._account.provider.api_params, text
936 | )
937 |
938 | def do_stream(self, arg):
939 | """
940 | Toggle streaming, which allows responses to be displayed as they are
941 | generated. This command takes no arguments.
942 | """
943 | try:
944 | self._account.provider.stream = not self._account.provider.stream
945 | if self._account.provider.stream:
946 | print("On")
947 | else:
948 | print("Off")
949 | except NotImplementedError as e:
950 | print(str(e))
951 |
952 | def do_name(self, arg):
953 | """
954 | Set a name to send to the language model for all future messages of
955 | the specified role. First argument is the role
956 | (user/assistant/system), second is the name to send. Pass no arguments
957 | to see all set names in this thread.
958 | example: "name user Bill"
959 | """
960 | if not arg:
961 | for k, v in self._current_thread.names.items():
962 | print(f"{k}: {v}")
963 | return
964 | if (
965 | LLMProviderFeature.MESSAGE_NAME_FIELD
966 | not in self._account.provider.SUPPORTED_FEATURES
967 | ):
968 | print("Name definition not supported")
969 | return
970 | t = arg.split()
971 | if len(t) != 2 or not self.__class__._validate_role(t[0]):
972 | print(
973 |                 f"Usage: name <{'|'.join(self.__class__.KNOWN_ROLES)}> "
974 |                 "<name>"
975 |             )
976 | return
977 | role = MessageRole(t[0])
978 | name = " ".join(t[1:])
979 | self._current_thread.names[role] = name
980 | print(f"{role} set to {name!r}")
981 |
982 | def complete_name(self, text, line, begidx, endidx):
983 | if begidx <= 5: # In the first argument
984 | return self.__class__._complete_role(text)
985 |
986 | def do_unname(self, arg):
987 | """
988 | Clear the definition of a name. Pass no arguments to clear all
989 | names.
990 | """
991 | if (
992 | LLMProviderFeature.MESSAGE_NAME_FIELD
993 | not in self._account.provider.SUPPORTED_FEATURES
994 | ):
995 | print("Name definition not supported")
996 | return
997 |         if not arg:
998 |             self._current_thread.names = {}
999 |             print("Unset all names")
1000 |             return
1001 |         name = self._current_thread.names.pop(arg, None)
1002 |         if name is None:
1003 |             print(f"{arg} not set")
1004 |         else:
1005 |             print(f"{arg} is no longer {name!r}")
1006 |
1007 | def complete_unname(self, text, line, begidx, endidx):
1008 | return self.__class__._complete_from_key(
1009 | self._current_thread.names, text
1010 | )
1011 |
1012 | def do_rename(self, arg):
1013 | """
1014 | Change the name for the specified role over a range of non-sticky
1015 | messages in the current thread. This command takes three arguments:
1016 | the role (user/assistant/system), the range to affect, and an optional
1017 | name (omitting the name clears it).
1018 | examples:
1019 |         "rename assistant 1 5 AI" (sets the name to "AI" on any assistant
1020 |         messages in the first through fifth message)
1021 |         "rename assistant 1 3" (unsets names on assistant messages in the
1022 |         first through third message)
1023 | "rename user ." (unsets all names on user messages in the current
1024 | thread)
1025 | """
1026 | m = re.match(
1027 | (
1028 | f"^({'|'.join(self.__class__.KNOWN_ROLES)})\\s+"
1029 | r"((?:-?\d+|\.)(?:\s+-?\d+|\s*\.)*)"
1030 | r"(?:\s+([a-zA-Z0-9_-]{1,64}))?$"
1031 | ),
1032 | arg,
1033 | )
1034 | if not m:
1035 | print(
1036 | f"Usage: rename <{'|'.join(self.__class__.KNOWN_ROLES)}>"
1037 |                 " <range> [name]"
1038 | )
1039 | return
1040 | if (
1041 | LLMProviderFeature.MESSAGE_NAME_FIELD
1042 | not in self._account.provider.SUPPORTED_FEATURES
1043 | ):
1044 | print("Name definition not supported")
1045 | return
1046 | role, ref, name = m.groups()
1047 | try:
1048 | start, end = self.__class__._user_range_to_python_range(ref)
1049 | except ValueError:
1050 | print("Invalid rename range")
1051 | return
1052 | t = self._current_thread.rename(
1053 | role=role, name=name, start_index=start, end_index=end
1054 | )
1055 | mp = "message" if len(t) == 1 else "messages"
1056 | print(f"{len(t)} {mp} renamed")
1057 |
1058 | def complete_rename(self, text, line, begidx, endidx):
1059 | if begidx <= 7: # In the first argument
1060 | return self.__class__._complete_role(text)
1061 |
1062 | def do_sticky(self, arg):
1063 | """
1064 | Sticky the messages in the specified range, so that deletion commands
1065 | in the current thread (pop, clear, etc.) and rename do not affect them.
1066 | example: "sticky 1 5"
1067 | """
1068 | try:
1069 | start, end = self.__class__._user_range_to_python_range(arg)
1070 | except ValueError:
1071 | print("Invalid sticky range")
1072 | return
1073 | t = self._current_thread.sticky(start, end, True)
1074 | mp = "message" if len(t) == 1 else "messages"
1075 | print(f"{len(t)} {mp} stickied")
1076 |
1077 | def do_unsticky(self, arg):
1078 | """
1079 |         Unsticky any sticky messages in the specified range, so that deletion
1080 | and rename commands once again affect them.
1081 | example: "unsticky 1 5"
1082 | """
1083 | try:
1084 | start, end = self.__class__._user_range_to_python_range(arg)
1085 | except ValueError:
1086 | print("Invalid unsticky range")
1087 | return
1088 | t = self._current_thread.sticky(start, end, False)
1089 | mp = "message" if len(t) == 1 else "messages"
1090 | print(f"{len(t)} {mp} unstickied")
1091 |
1092 | def do_save(
1093 | self,
1094 | arg: str,
1095 | _extra_metadata: Optional[Dict[str, Any]] = None,
1096 | _print_on_success: bool = True,
1097 | ):
1098 | """
1099 |         Save all named threads to the specified JSON file. With no argument,
1100 | save to the most recently loaded/saved JSON file in this session.
1101 | """
1102 | args = self.__class__._shlex_path(arg)
1103 | if len(args) > 1:
1104 | print("Usage: save [path]")
1105 | return
1106 | if self._detached.dirty:
1107 | print(
1108 | f"Warning: {len(self._detached)} detached messages will not"
1109 | " be saved. If you wish to save them, create a named"
1110 | " thread."
1111 | )
1112 | if not self._threads:
1113 | print("No threads to save!")
1114 | return
1115 | if not args:
1116 | if self.last_path is None:
1117 | print("No file specified")
1118 | return
1119 | path = self.last_path
1120 | else:
1121 | path = args[0]
1122 | res = {}
1123 | if _extra_metadata is None:
1124 | res["_meta"] = {}
1125 | else:
1126 | res["_meta"] = _extra_metadata.copy()
1127 | res["_meta"]["version"] = __version__
1128 | res["threads"] = {k: v.to_dict() for k, v in self._threads.items()}
1129 | try:
1130 | with open(path, "w", encoding="utf-8") as cam:
1131 | json.dump(res, cam, indent=2)
1132 | except (OSError, UnicodeEncodeError) as e:
1133 | print(str(e))
1134 | return
1135 | for thread in self._threads.values():
1136 | thread.dirty = False
1137 | if _print_on_success:
1138 | print(f"{os.path.abspath(path)} saved")
1139 | self.last_path = path
1140 |
1141 | def do_load(self, arg, _print_on_success=True):
1142 |         "Load all threads from the specified JSON file."
1143 | if not arg:
1144 |             print("Usage: load <path>\n")
1145 | return
1146 | try:
1147 | args = self.__class__._shlex_path(arg)
1148 | except ValueError as e:
1149 | print(e)
1150 | return
1151 | if len(args) != 1:
1152 |             print("Usage: load <path>")
1153 | return
1154 | path = args[0]
1155 | try:
1156 | with open(path, encoding="utf-8") as fin:
1157 | d = json.load(fin)
1158 | except (
1159 | OSError,
1160 | json.JSONDecodeError,
1161 | UnicodeDecodeError,
1162 | ) as e:
1163 | print(f"Cannot load: {str(e)}")
1164 | return
1165 | if "_meta" not in d:
1166 | print("Cannot load: malformed or very old file!")
1167 | return
1168 | my_major = int(__version__.split(".")[0])
1169 | their_major = int(d["_meta"]["version"].split(".")[0])
1170 | if my_major < their_major:
1171 | print(
1172 | "Cannot load: this file requires Gptcmd version"
1173 | f" {their_major}.0.0 or later!"
1174 | )
1175 | return
1176 | self._threads.update(
1177 | {
1178 | k: self.thread_cls.from_dict(v, name=k)
1179 | for k, v in d["threads"].items()
1180 | }
1181 | )
1182 | if self._current_thread != self._detached:
1183 | # If a thread is loaded with the same name as the current thread,
1184 | # the current thread might become unreachable.
1185 | # Re-sync the current thread with reality.
1186 | self._current_thread = self._threads.get(
1187 | self._current_thread.name, self._detached
1188 | )
1189 |         self.last_path = path
1190 | if _print_on_success:
1191 |             print(f"{path} loaded")
1192 |
1193 | def do_read(self, arg):
1194 | """
1195 | Read the contents of the file (first argument) as a new message with
1196 | the specified role (second argument).
1197 | example: "read /path/to/prompt.txt system"
1198 | """
1199 | try:
1200 | args = self.__class__._shlex_path(arg)
1201 | except ValueError as e:
1202 | print(e)
1203 | return
1204 | if len(args) < 2 or not self.__class__._validate_role(args[-1]):
1205 | print(
1206 |                 f"Usage: read <path> <{'|'.join(self.__class__.KNOWN_ROLES)}>"
1207 | )
1208 | return
1209 | path = " ".join(args[:-1])
1210 | role = MessageRole(args[-1])
1211 | try:
1212 | with open(path, encoding="utf-8", errors="ignore") as fin:
1213 | self._append_new_message(
1214 | arg=fin.read(), role=role, _edit_on_empty=False
1215 | )
1216 | except (OSError, UnicodeDecodeError) as e:
1217 | print(str(e))
1218 | return
1219 |
1220 | def complete_read(self, text, line, begidx, endidx):
1221 | if begidx > 5: # Passed the first argument
1222 | return self.__class__._complete_role(text)
1223 |
1224 | def do_write(self, arg):
1225 | "Write the contents of the last message to the specified file."
1226 | try:
1227 | args = self.__class__._shlex_path(arg)
1228 | except ValueError as e:
1229 | print(e)
1230 | return
1231 | if len(args) != 1:
1232 |             print("Usage: write <path>")
1233 | return
1234 | path = args[0]
1235 | try:
1236 | with open(path, "w", encoding="utf-8", errors="ignore") as cam:
1237 | msg = self._current_thread.messages[-1]
1238 | cam.write(msg.content)
1239 | print(
1240 | self.__class__._fragment(
1241 | "{msg} written to " + os.path.abspath(path), msg
1242 | )
1243 | )
1244 | except (OSError, UnicodeEncodeError) as e:
1245 | print(str(e))
1246 | return
1247 |
1248 | def complete_write(self, text, line, begidx, endidx):
1249 | if begidx > 6: # Passed the first argument
1250 | return self.__class__._complete_role(text)
1251 |
1252 | def do_transcribe(self, arg):
1253 | """
1254 | Write the entire thread (as a human-readable transcript) to the
1255 | specified file.
1256 | """
1257 | try:
1258 | args = self.__class__._shlex_path(arg)
1259 | except ValueError as e:
1260 | print(e)
1261 | return
1262 | if len(args) != 1:
1263 |             print("Usage: transcribe <path>")
1264 | return
1265 | path = args[0]
1266 | try:
1267 | with open(path, "w", encoding="utf-8", errors="ignore") as cam:
1268 | cam.write(
1269 | self._current_thread.render(display_indicators=False)
1270 | )
1271 | print(f"Transcribed to {os.path.abspath(path)}")
1272 | except (OSError, UnicodeEncodeError) as e:
1273 | print(str(e))
1274 | return
1275 |
1276 | def do_image(self, arg):
1277 |         "Attach an image (from a file path or URL) to a message"
1278 |         USAGE = "Usage: image <location> [message]"
1279 | m = re.match(r"^(.*?)(?:\s(-?\d+))?$", arg)
1280 | if not m:
1281 | print(USAGE)
1282 | return
1283 | location, ref = m.groups()
1284 | if not location or location.isspace():
1285 | print(USAGE)
1286 | return
1287 | try:
1288 | idx = (
1289 | -1
1290 | if ref is None
1291 | else self.__class__._user_range_to_python_range(ref)[0]
1292 | )
1293 | except ValueError:
1294 | print("Invalid message specification")
1295 | return
1296 | if location.startswith("http"):
1297 | img = Image(url=location)
1298 | else:
1299 | try:
1300 | img = Image.from_path(self.__class__._shlex_path(location)[0])
1301 | except (OSError, ValueError) as e:
1302 | print(e)
1303 | return
1304 | try:
1305 | msg = self._current_thread[idx]
1306 | msg.attachments.append(img)
1307 | print(self.__class__._fragment("Image added to {msg}", msg))
1308 | except IndexError:
1309 | print("Message doesn't exist")
1310 | return
1311 |
1312 | def do_account(self, arg, _print_on_success: bool = True):
1313 | "Switch between configured accounts."
1314 | if not arg:
1315 | others = [
1316 | v.name
1317 | for v in self.config.accounts.values()
1318 | if v != self._account
1319 | ]
1320 | print(f"Active account: {self._account.name}")
1321 | if others:
1322 | print(f"Available accounts: {', '.join(others)}")
1323 | return
1324 | if arg in self.config.accounts:
1325 | candidate = self.config.accounts[arg]
1326 | try:
1327 | _ = candidate.provider # Attempt to instantiate
1328 | except ConfigError as e:
1329 | print(str(e))
1330 | return
1331 | self._account = candidate
1332 | if _print_on_success:
1333 | print(f"Switched to account {self._account.name!r}")
1334 | else:
1335 | print(f"{arg} is not configured")
1336 |
1337 | def complete_account(self, text, line, begidx, endidx):
1338 | return self.__class__._complete_from_key(self.config.accounts, text)
1339 |
1340 | def _edit_interactively(
1341 | self, initial_text: str, filename_prefix: str = "gptcmd"
1342 | ) -> Optional[str]:
1343 | try:
1344 | with tempfile.NamedTemporaryFile(
1345 | prefix=filename_prefix,
1346 | mode="w",
1347 | delete=False,
1348 | encoding="utf-8",
1349 | errors="ignore",
1350 | ) as cam:
1351 | cam.write(initial_text)
1352 | tempname = cam.name
1353 | except (OSError, UnicodeEncodeError) as e:
1354 | print(e)
1355 | return None
1356 | except KeyboardInterrupt:
1357 | return None
1358 | try:
1359 | mtime_before = os.path.getmtime(tempname)
1360 | subprocess.run((*self.config.editor, tempname), check=True)
1361 | mtime_after = os.path.getmtime(tempname)
1362 | if mtime_after == mtime_before:
1363 | # File was not changed
1364 | return None
1365 | with open(tempname, encoding="utf-8") as fin:
1366 | return fin.read()
1367 | except FileNotFoundError:
1368 | editor_cmd = " ".join(self.config.editor)
1369 | print(f"Editor {editor_cmd} could not be found")
1370 | return None
1371 | except (
1372 | UnicodeDecodeError,
1373 | subprocess.CalledProcessError,
1374 | ConfigError,
1375 | ) as e:
1376 | print(e)
1377 | return None
1378 | except KeyboardInterrupt:
1379 | return None
1380 | finally:
1381 | # Clean up the temporary file
1382 | os.unlink(tempname)
1383 |
1384 | def do_edit(self, arg):
1385 | """
1386 | Opens the content of the specified message in an external editor for
1387 | modification. With no argument, edits the last message.
1388 | """
1389 | try:
1390 | idx = (
1391 | -1
1392 | if not arg
1393 | else self.__class__._user_range_to_python_range(arg)[0]
1394 | )
1395 | except ValueError:
1396 | print(
1397 |                 "Usage: edit [message]\n"
1398 | "With no argument, the edit command edits the last message. "
1399 | "With a message number provided as an argument, the edit "
1400 | "command edits that message."
1401 | )
1402 | return
1403 | try:
1404 | msg = self._current_thread.messages[idx]
1405 | new = self._edit_interactively(msg.content)
1406 | if new:
1407 | msg.content = new
1408 | print("Edited")
1409 | else:
1410 | print("Cancelled")
1411 | except IndexError:
1412 | print("Message doesn't exist")
1413 |
1414 | def do_grep(self, arg):
1415 | """
1416 | Search the current thread for messages whose content matches the
1417 | supplied regex.
1418 | """
1419 | if not arg:
1420 |             print("Usage: grep <regex>")
1421 | return
1422 | try:
1423 | expr = re.compile(arg)
1424 | except re.error as e:
1425 | print(e)
1426 | return
1427 |
1428 | def _show(content: str, m: re.Match, width: int = 40) -> str:
1429 | start = max(0, m.start() - width)
1430 | end = min(len(content), m.end() + width)
1431 | res = content[start:end]
1432 | if start > 0:
1433 | res = "..." + res
1434 | if end < len(content):
1435 | res += "..."
1436 | return expr.sub(lambda x: f"[{x.group(0)}]", res)
1437 |
1438 | found = False
1439 | for idx, msg in enumerate(self._current_thread.messages, start=1):
1440 | if not (m := expr.search(msg.content)):
1441 | continue
1442 | if m.end() == m.start():
1443 | continue
1444 | preview = self.__class__._fragment(
1445 | "{msg}", Message(content=_show(msg.content, m), role=msg.role)
1446 | )
1447 | name = msg.name if msg.name else msg.role
1448 | print(f"{msg.display_indicators}{idx} ({name}): {preview}")
1449 | found = True
1450 |
1451 | if not found:
1452 | print("No hits!")
1453 |
1454 | def _parse_meta_args(
1455 | self, arg: str
1456 | ) -> Tuple["Message", Optional[str], Optional[str]]:
1457 | arg = arg.strip()
1458 | if not arg:
1459 | # bare `meta` / `unmeta` operate on last message
1460 | return (self._current_thread[-1], None, None)
1461 |
1462 | tokens = arg.split()
1463 |
1464 | idx_token = None
1465 | if tokens and (tokens[0] == "." or tokens[0].lstrip("-").isdigit()):
1466 | idx_token = tokens.pop(0)
1467 |
1468 | if idx_token is None:
1469 | idx = -1
1470 | else:
1471 | start, _ = self.__class__._user_range_to_python_range(
1472 | idx_token, allow_single=True
1473 | )
1474 | idx = -1 if start is None else start
1475 |
1476 | # will raise IndexError if message is absent, handled by caller
1477 | target_msg = self._current_thread[idx]
1478 |
1479 | if not tokens:
1480 | return (target_msg, None, None)
1481 |
1482 | key = tokens.pop(0)
1483 | val = " ".join(tokens) if tokens else None
1484 | return (target_msg, key, val)
1485 |
1486 | def do_meta(self, arg):
1487 | """
1488 | Get or set metadata on a message.
1489 | """
1490 |         USAGE = "Usage: meta [message] <key> [value]"
1491 | try:
1492 | msg, key, val = self._parse_meta_args(arg)
1493 | except ValueError:
1494 | print(USAGE)
1495 | return
1496 | except IndexError:
1497 | print("message doesn't exist")
1498 | return
1499 |
1500 | if key is None:
1501 | if msg.metadata:
1502 | for k, v in msg.metadata.items():
1503 | print(f"{k}: {v!r}")
1504 | else:
1505 | print(
1506 | self.__class__._fragment("No metadata set on {msg}", msg)
1507 | )
1508 | return
1509 |
1510 | if val is None:
1511 | print(repr(msg.metadata.get(key, f"{key} not set")))
1512 | return
1513 |
1514 | try:
1515 | validated_val = self.__class__._json_eval(val)
1516 | except (json.JSONDecodeError, UnicodeDecodeError):
1517 | print("Invalid syntax")
1518 | return
1519 | msg.metadata[key] = validated_val
1520 | self._current_thread.dirty = True
1521 |
1522 | printable_val = (
1523 | repr(validated_val).replace("{", "{{").replace("}", "}}")
1524 | )
1525 | print(
1526 | self.__class__._fragment(
1527 | f"{key} set to {printable_val} on {{msg}}", msg
1528 | )
1529 | )
1530 |
1531 | def complete_meta(self, text, line, begidx, endidx):
1532 | if text.lstrip("-").isdigit():
1533 | return []
1534 | try:
1535 | msg = self._current_thread[-1]
1536 | return self.__class__._complete_from_key(msg.metadata, text)
1537 | except IndexError:
1538 | return []
1539 |
1540 | def do_unmeta(self, arg):
1541 | """
1542 | Delete a metadata key from a message.
1543 | """
1544 |         USAGE = "Usage: unmeta [message] <key>"
1545 | try:
1546 | msg, key, val = self._parse_meta_args(arg)
1547 | except ValueError:
1548 | print(USAGE)
1549 | return
1550 | except IndexError:
1551 | print("message doesn't exist")
1552 | return
1553 | if val is not None: # malformed syntax
1554 | print(USAGE)
1555 | return
1556 | if key is None:
1557 | if not msg.metadata:
1558 | print(
1559 | self.__class__._fragment("No metadata set on {msg}", msg)
1560 | )
1561 | return
1562 | n = len(msg.metadata)
1563 | items = "item" if n == 1 else "items"
1564 | prompt = self.__class__._fragment(
1565 | f"delete {n} {items} on {{msg}}?", msg
1566 | )
1567 | if not self.__class__._confirm(prompt):
1568 | return
1569 | msg.metadata.clear()
1570 | self._current_thread.dirty = True
1571 | print(self.__class__._fragment("Unset all metadata on {msg}", msg))
1572 | return
1573 | if key in msg.metadata:
1574 | msg.metadata.pop(key)
1575 | self._current_thread.dirty = True
1576 | print(self.__class__._fragment(f"{key} unset on {{msg}}", msg))
1577 | else:
1578 | print(self.__class__._fragment(f"{key} not set on {{msg}}", msg))
1579 |
1580 | def complete_unmeta(self, text, line, begidx, endidx):
1581 | try:
1582 | msg = self._current_thread[-1]
1583 | return self.__class__._complete_from_key(msg.metadata, text)
1584 | except IndexError:
1585 | return []
1586 |
1587 | def do_quit(self, arg):
1588 | "Exit the program."
1589 | warn = ""
1590 | if self._detached.dirty:
1591 | warn += "All unsaved detached messages will be lost.\n"
1592 | for threadname, thread in self._threads.items():
1593 | if thread.dirty:
1594 | warn += f"{threadname} has unsaved changes.\n"
1595 | if warn:
1596 | can_exit = self.__class__._confirm(
1597 | f"{warn}\nAre you sure that you wish to exit?"
1598 | )
1599 | else:
1600 | can_exit = True
1601 | if can_exit:
1602 | self._future_executor.shutdown(wait=False)
1603 | return can_exit # Truthy return values cause the cmdloop to stop
1604 |
1605 |
1606 | def _write_crash_dump(shell: Gptcmd, exc: Exception) -> Optional[str]:
1607 | """
1608 | Serialize the current shell into a JSON file and return its absolute
1609 | path.
1610 | """
1611 | detached_added = False
1612 | try:
1613 | ts = (
1614 | datetime.datetime.now()
1615 | .isoformat(timespec="seconds")
1616 | .replace(":", "-")
1617 | )
1618 | filename = f"gptcmd-{ts}.json"
1619 | tb_text = "".join(
1620 | traceback.format_exception(type(exc), exc, exc.__traceback__)
1621 | )
1622 | if shell._detached:
1623 | original_dirty = shell._detached.dirty
1624 | detached_base = "__detached__"
1625 | detached_key = detached_base
1626 | i = 1
1627 | while detached_key in shell._threads:
1628 | i += 1
1629 | detached_key = f"{detached_base}{i}"
1630 | shell._detached.dirty = False
1631 | shell._threads[detached_key] = shell._detached
1632 | detached_added = True
1633 | shell.do_save(
1634 | filename,
1635 | _extra_metadata={"crash_traceback": tb_text},
1636 | _print_on_success=False,
1637 | )
1638 | return os.path.abspath(filename)
1639 | except Exception as e:
1640 | print(f"Failed to write crash dump: {e}", file=sys.stderr)
1641 | return None
1642 | finally:
1643 | if detached_added:
1644 | shell._detached.dirty = original_dirty
1645 | shell._threads.pop(detached_key, None)
1646 |
1647 |
1648 | def main() -> bool:
1649 | """
1650 | Setuptools requires a callable entry point to build an installable script
1651 | """
1652 | parser = argparse.ArgumentParser()
1653 | parser.add_argument(
1654 | "path",
1655 | help="The path to a JSON file of named threads to load on launch",
1656 | nargs="?",
1657 | )
1658 | parser.add_argument(
1659 | "-c",
1660 | "--config",
1661 | help="The path to a Gptcmd configuration file to use for this session",
1662 | )
1663 | parser.add_argument(
1664 | "-t",
1665 | "--thread",
1666 | help="The name of the thread to switch to on launch",
1667 | )
1668 | parser.add_argument(
1669 | "-m",
1670 | "--model",
1671 | help="The name of the model to switch to on launch",
1672 | )
1673 | parser.add_argument(
1674 | "-a",
1675 | "--account",
1676 | help="The name of the account to switch to on launch",
1677 | )
1678 | parser.add_argument(
1679 | "--version", help="Show version and exit", action="store_true"
1680 | )
1681 | args = parser.parse_args()
1682 | if args.version:
1683 | print(f"Gptcmd {__version__}")
1684 | return True
1685 | try:
1686 | if args.config:
1687 | config = ConfigManager.from_toml(args.config)
1688 | else:
1689 | config = None
1690 | shell = Gptcmd(config=config)
1691 | except ConfigError as e:
1692 | print(f"Couldn't read config: {e}")
1693 | return False
1694 | if args.path:
1695 | shell.do_load(args.path, _print_on_success=False)
1696 | if args.thread:
1697 | shell.do_thread(args.thread, _print_on_success=False)
1698 | if args.account:
1699 | shell.do_account(args.account, _print_on_success=False)
1700 | if args.model:
1701 | shell.do_model(args.model, _print_on_success=False)
1702 | try:
1703 | shell.cmdloop()
1704 | except Exception as e:
1705 | # Does any thread contain messages?
1706 | should_save = (shell._detached and shell._detached.dirty) or any(
1707 | t and t.dirty for t in shell._threads.values()
1708 | )
1709 | if should_save:
1710 | dump_path = _write_crash_dump(shell, e)
1711 | if dump_path:
1712 | # Hack: Print the "crash dump" notice after the traceback
1713 | atexit.register(
1714 | lambda p=dump_path: print(
1715 | f"Crash dump written to {p}",
1716 | file=sys.stderr,
1717 | flush=True,
1718 | )
1719 | )
1720 | raise
1721 | return True
1722 |
1723 |
1724 | if __name__ == "__main__":
1725 | success = main()
1726 | if success:
1727 | sys.exit(0)
1728 | else:
1729 | sys.exit(1)
1730 |
--------------------------------------------------------------------------------
/src/gptcmd/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the ConfigManager class, which controls Gptcmd's
3 | config system.
4 | Copyright 2024 Bill Dengler
5 | This Source Code Form is subject to the terms of the Mozilla Public
6 | License, v. 2.0. If a copy of the MPL was not distributed with this
7 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
8 | """
9 |
10 | import dataclasses
11 | import os
12 | import sys
13 | import platform
14 | import shlex
15 | import shutil
16 | from functools import cached_property
17 | from importlib import resources
18 | from importlib.metadata import entry_points
19 | from typing import Any, Dict, List, Optional, Type, Union
20 |
21 | if sys.version_info >= (3, 11):
22 | import tomllib
23 | else:
24 | import tomli as tomllib
25 |
26 | from .llm import LLMProvider
27 | from .llm.openai import AzureAI, OpenAI
28 |
29 | DEFAULT_PROVIDERS: Dict[str, Type[LLMProvider]] = {
30 | "openai": OpenAI,
31 | "azure": AzureAI,
32 | }
33 |
34 |
35 | class ConfigError(Exception):
36 | pass
37 |
38 |
39 | @dataclasses.dataclass(frozen=True)
40 | class Account:
41 | name: str
42 | provider: dataclasses.InitVar[Union[LLMProvider, Type[LLMProvider]]]
43 | _conf: Dict[str, Any] = dataclasses.field(
44 | default_factory=dict, repr=False, compare=False, hash=False
45 | )
46 | _provider_cls: Type[LLMProvider] = dataclasses.field(
47 | init=False, repr=False, compare=False, hash=False
48 | )
49 |
50 | def __post_init__(self, provider):
51 | if isinstance(provider, LLMProvider):
52 | object.__setattr__(self, "_provider_cls", type(provider))
53 | # bypass cached_property
54 | object.__setattr__(self, "provider", provider)
55 | elif isinstance(provider, type) and issubclass(provider, LLMProvider):
56 | object.__setattr__(self, "_provider_cls", provider)
57 | else:
58 | raise TypeError(
59 | "provider must be an LLMProvider instance or subclass"
60 | )
61 |
62 | @cached_property
63 | def provider(self) -> LLMProvider:
64 | return self._provider_cls.from_config(self._conf)
65 |
66 |
67 | class ConfigManager:
68 | "Handles Gptcmd's configuration system."
69 |
70 | def __init__(
71 | self,
72 | config: Dict,
73 | providers: Optional[Dict[str, Type[LLMProvider]]] = None,
74 | ):
75 | """
76 | Initialize the ConfigManager with a configuration dictionary.
77 | """
78 | # Validate the provided config
79 | if "schema_version" not in config:
80 | raise ConfigError("Missing 'schema_version'")
81 |
82 | conf = self._load_sample_config()
83 |
84 | my_major = int(conf.pop("schema_version").split(".")[0])
85 | their_major = int(config["schema_version"].split(".")[0])
86 | if their_major > my_major:
87 | raise ConfigError(
88 | "This configuration is too new for the current version"
89 | " of Gptcmd!"
90 | )
91 |
92 | conf.update(config)
93 | self.conf = conf
94 | if providers is None:
95 | providers = self.__class__._discover_external_providers(
96 | initial_providers=DEFAULT_PROVIDERS
97 | )
98 | self.accounts = self._configure_accounts(
99 | self.conf["accounts"], providers
100 | )
101 | # Validate the default account immediately; others stay lazy-loaded
102 | _ = self.default_account.provider
103 |
104 | @staticmethod
105 | def _discover_external_providers(
106 | initial_providers: Optional[Dict[str, Type[LLMProvider]]] = None,
107 | ) -> Dict[str, Type[LLMProvider]]:
108 | """
109 | Discover external providers registered via entry points.
110 | """
111 | res: Dict[str, Type[LLMProvider]] = {}
112 | if initial_providers:
113 | res.update(initial_providers)
114 | eps = entry_points()
115 | ENTRY_POINT_GROUP = "gptcmd.providers"
116 | if hasattr(eps, "select"):
117 | selected_eps = eps.select(group=ENTRY_POINT_GROUP)
118 | else:
119 | selected_eps = eps.get(ENTRY_POINT_GROUP, ())
120 | for ep in selected_eps:
121 | provider_cls = ep.load()
122 | if ep.name in res:
123 |
124 | def fully_qualified_name(cls):
125 | return cls.__module__ + "." + cls.__qualname__
126 |
127 | raise ConfigError(
128 | f"Duplicate registration for {ep.name}:"
129 | f" {fully_qualified_name(res[ep.name])} and"
130 | f" {fully_qualified_name(provider_cls)}"
131 | )
132 | else:
133 | res[ep.name] = provider_cls
134 | return res
135 |
136 | @classmethod
137 | def from_toml(cls, path: Optional[str] = None):
138 | """
139 | Create a ConfigManager instance from a TOML file.
140 | """
141 | if path is None:
142 | config_root = cls._get_config_root()
143 | config_path = os.path.join(config_root, "config.toml")
144 | if not os.path.exists(config_path):
145 | os.makedirs(config_root, exist_ok=True)
146 | with resources.path(
147 | "gptcmd", "config_sample.toml"
148 | ) as sample_path:
149 | shutil.copy(sample_path, config_path)
150 | else:
151 | config_path = path
152 |
153 | try:
154 | with open(config_path, "rb") as fin:
155 | return cls(tomllib.load(fin))
156 | except (OSError, tomllib.TOMLDecodeError) as e:
157 | raise ConfigError(str(e)) from e
158 |
159 | def _configure_accounts(
160 | self, account_config: Dict, providers: Dict[str, Type[LLMProvider]]
161 | ) -> Dict[str, Account]:
162 | res = {}
163 | for name, conf in account_config.items():
164 | if "provider" not in conf:
165 | raise ConfigError(f"Account {name} has no provider specified")
166 | provider_cls = providers.get(conf["provider"])
167 | if not provider_cls:
168 | raise ConfigError(
169 | f"Provider {conf['provider']} is not available. Perhaps"
170 | " you need to install it?"
171 | )
172 | res[name] = Account(
173 | name=name,
174 | provider=provider_cls,
175 | _conf=conf.copy(),
176 | )
177 | return res
178 |
179 | @property
180 | def default_account(self) -> Account:
181 | try:
182 | return self.accounts.get(
183 | "default",
184 | next(
185 | iter(self.accounts.values())
186 | ), # The first configured account
187 | )
188 | except StopIteration:
189 | raise ConfigError("No default account configured")
190 |
191 | @property
192 | def editor(self) -> List[str]:
193 | posix = platform.system().lower() != "windows"
194 | editor = (
195 | self.conf.get("editor") or self.__class__._get_default_editor()
196 | )
197 | return shlex.split(editor, posix=posix)
198 |
199 | @staticmethod
200 | def _get_config_root():
201 | """Get the root directory for the configuration file."""
202 | system = platform.system().lower()
203 | if system == "windows":
204 | base_path = os.environ.get("APPDATA") or os.path.expanduser("~")
205 | elif system == "darwin":
206 | base_path = os.path.expanduser("~/Library/Application Support")
207 | else:
208 | base_path = os.environ.get(
209 | "XDG_CONFIG_HOME"
210 | ) or os.path.expanduser("~/.config")
211 | return os.path.join(base_path, "gptcmd")
212 |
213 | @staticmethod
214 | def _load_sample_config():
215 | "Load the sample configuration file as a dict"
216 | with resources.open_binary("gptcmd", "config_sample.toml") as fin:
217 | return tomllib.load(fin)
218 |
219 | @staticmethod
220 | def _get_default_editor():
221 | system = platform.system().lower()
222 | if system == "windows":
223 | # On Windows, default to notepad
224 | return "notepad"
225 | else:
226 | # On Unix-like systems, use the EDITOR environment variable if set
227 | editor = os.environ.get("EDITOR")
228 | if editor:
229 | return editor
230 | else:
231 | # Try common editors in order of preference
232 | for cmd in ("nano", "emacs", "vim", "ed", "vi"):
233 | if shutil.which(cmd):
234 | return cmd
235 | raise ConfigError("No editor available")
236 |
--------------------------------------------------------------------------------
/src/gptcmd/config_sample.toml:
--------------------------------------------------------------------------------
1 | # Gptcmd configuration
2 |
3 | # This option is used by the application for version tracking.
4 | schema_version = "1.1.0"
5 |
6 | # This option controls the formatting of the prompt.
7 | # The following keywords (when placed in braces) are replaced by:
8 | # model: the name of the active model
9 | # thread: the name of the active thread (if not the detached thread)
10 | # account: the name of the active account
11 | # Python escape sequences are supported.
12 | # Any other characters placed in this string are printed literally.
13 | prompt = "{thread}({model}) "
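# For example, a prompt that also displays the active account might be
# configured like this (an illustration, not the default):
# prompt = "{account} {thread}({model}) "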
14 |
15 | # This option controls whether estimated session cost is displayed, when
16 | # available, after each successful request.
17 | show_cost = true
18 |
19 | # Sometimes, such as when switching to a model that doesn't have cost
20 | # information available, cost estimation is unsupported.
21 | # Since these requests aren't counted in the session cost estimate, when
22 | # switching back to a scenario that does support cost estimation, the reported
23 | # estimated cost will be incomplete.
24 | # This option controls whether these incomplete estimates are displayed.
25 | show_incomplete_cost = false
26 |
27 | # This option controls whether the number of prompt (input) and sampled
28 | # (generated) tokens used for each request is displayed when available.
29 | show_token_usage = true
30 |
31 | # This option specifies the external editor Gptcmd uses for commands that require one.
32 | # If this option is not set, Gptcmd uses Notepad on Windows.
33 | # On Unix-like systems, Gptcmd uses the default configured editor, typically
34 | # determined by the EDITOR environment variable.
35 | # To specify a custom editor, uncomment the line setting the editor option
36 | # below and set it to an editor of your choice.
37 | # For example, to use Notepad++ on Windows:
38 | # editor = "C:\\Program Files (x86)\\Notepad++\\notepad++.exe -multiInst -notabbar -nosession -noPlugin"
39 |
40 | # This option controls how Gptcmd handles situations when the user invokes an
41 | # external editor to add a message but then closes the editor without entering
42 | # any content.
43 | # By default, this option is set to "never", meaning Gptcmd will cancel the
44 | # operation if no content is entered.
45 | # When this option is set to "ask", Gptcmd will prompt the user to confirm
46 | # whether to add an empty message or cancel.
47 | # Setting this option to "always" will add an empty message without prompting,
48 | # replicating Gptcmd's behaviour before version 2.0.0.
49 | # Unless you know that you have a specific need to create empty messages,
50 | # "never" is recommended.
51 | allow_add_empty_messages = "never"
52 |
53 | # This option controls what Gptcmd does when the user runs `retry` from
54 | # a named thread.
55 | # When this option is set to "always", a new thread will be created on retry,
56 | # replicating Gptcmd's behaviour before version 2.1.0.
57 | # When this option is set to "ask", Gptcmd will prompt the user whether to
58 | # create a new thread for this retried query or to overwrite the
59 | # existing contents, similar to Gptcmd's behaviour in the detached thread.
60 | # When this option is set to "never", Gptcmd always overwrites previous
61 | # assistant contents with the retried query in both detached and named threads.
62 | create_new_thread_on_retry = "ask"
63 |
64 | # Account Configuration
65 | # The following sections configure Gptcmd's connections to large language model provider accounts.
66 | # By default, Gptcmd uses the [accounts.default] section on startup.
67 | # If this section doesn't exist, Gptcmd uses the first account section it finds.
68 | # You can add multiple accounts by creating additional sections:
69 | # [accounts.first]
70 | # [accounts.second]
71 | # [accounts.custom_name]
72 | # Each account section should contain connection details similar to [accounts.default].
73 |
74 | # Within each account section (placed between its header and the next account's header), you can specify the following options:
75 |
76 | # provider: Specifies the large language model provider; must be "openai",
77 | # "azure", or the name of an external provider.
78 | # Example:
79 | # provider = "openai"
80 |
81 | # model: The OpenAI model or Azure deployment Gptcmd should use when this account is activated.
82 | # Example:
83 | # model = "gpt-4o-mini"
84 |
85 | # endpoint: For Azure accounts, the Azure endpoint URL.
86 | # Example:
87 | # endpoint = "https://contoso.openai.azure.com/"
88 |
89 | # api_key: The API key to use. If omitted, Gptcmd reads it from the OPENAI_API_KEY (for OpenAI accounts) or AZURE_OPENAI_API_KEY (for Azure accounts) environment variable.
90 | # Example:
91 | # api_key = "sk-xxxxxx"
92 |
93 | # base_url: For OpenAI accounts, the endpoint URL to which Gptcmd should connect.
94 | # With the "model" option, this option can be used to connect Gptcmd to third-party OpenAI-compatible APIs.
95 | # Example:
96 | # base_url = "https://openrouter.ai/api/v1"
97 |
98 | # Any additional options are passed directly to the Python OpenAI client's constructor for this account.
99 |
100 | [accounts.default]
101 | provider="openai"
102 |
--------------------------------------------------------------------------------
/src/gptcmd/llm/README.md:
--------------------------------------------------------------------------------
1 | # Large language model providers
2 |
3 | Gptcmd uses instances of the `LLMProvider` abstract class to interact with large language models (LLMs). This document describes the `LLMProvider` abstract class and supporting infrastructure and demonstrates how to implement a simple custom provider.
4 |
5 | ## Overview
6 |
7 | `gptcmd.llm.LLMProvider` is an [abstract base class](https://docs.python.org/3/glossary.html#term-abstract-base-class) located in `src/gptcmd/llm/__init__.py`. It defines the interface that all LLM providers must implement to work with Gptcmd. Below is an overview of the main components.
8 |
9 | ### Key methods
10 |
11 | #### `from_config(cls, conf: Dict) -> LLMProvider`
12 |
13 | * **Purpose**: A class method that instantiates the LLMProvider from a user configuration dictionary.
14 | * **Usage**: This class method is used by the configuration system to instantiate `LLMProvider` classes.
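
For illustration, a minimal `from_config` might look like the following sketch. The `EchoProvider` class is hypothetical, and the sketch assumes the base constructor accepts a `model` keyword:

```
from gptcmd.config import ConfigError
from gptcmd.llm import LLMProvider


class EchoProvider(LLMProvider):
    @classmethod
    def from_config(cls, conf):
        # conf is this account's section of config.toml as a dict.
        # Raise ConfigError for missing or invalid required options.
        if "model" not in conf:
            raise ConfigError("No model specified for this account")
        # Assumption: the base constructor accepts a model keyword.
        return cls(model=conf["model"])
```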
15 |
16 | #### `complete(self, messages: Sequence[Message]) -> LLMResponse`
17 |
18 | * **Purpose**: Generate a response from the LLM given a collection of `Message` objects.
19 | * **Usage**: This method should contain the logic that calls your LLM API and converts its response into an `LLMResponse` object.
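
As a sketch, an echo provider's `complete` might simply parrot the newest message back as the assistant's reply. This assumes `LLMResponse`'s optional fields default to `None`; a real implementation would call its API here:

```
from gptcmd.llm import LLMProvider, LLMResponse
from gptcmd.message import Message, MessageRole


class EchoProvider(LLMProvider):  # continuing the hypothetical example
    def complete(self, messages):
        # messages is a sequence of Message objects, ending with the
        # newest message in the conversation.
        text = messages[-1].content if messages else ""
        return LLMResponse(
            message=Message(content=text, role=MessageRole.ASSISTANT)
        )
```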
20 |
21 | #### `validate_api_params(self, params: Dict[str, Any]) -> Dict[str, Any]`
22 |
23 | * **Purpose**: Validate and sanitize API parameters provided by the user.
24 | * **Usage**: Ensure that only valid parameters are accepted and that they are within acceptable ranges or formats. This method should raise `InvalidAPIParameterError` for unknown parameter values or values that cannot be sanitized programmatically.
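
For example, a provider that only understands `temperature` might validate parameters like this (a sketch; the accepted range is illustrative):

```
from gptcmd.llm import InvalidAPIParameterError, LLMProvider


class EchoProvider(LLMProvider):  # continuing the hypothetical example
    def validate_api_params(self, params):
        validated = {}
        for key, val in params.items():
            if key != "temperature":
                raise InvalidAPIParameterError(f"{key} is not supported")
            if not 0 <= val <= 2:
                raise InvalidAPIParameterError("temperature must be 0-2")
            validated[key] = val
        return validated
```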
25 |
26 | #### `get_best_model(self) -> str`
27 |
28 | * **Purpose**: Return the name of the most capable model offered by this provider.
29 | * **Usage**: This method helps in selecting a default model if none is otherwise configured.
30 |
31 | #### `valid_models(self) -> Optional[Iterable[str]]`
32 |
33 | * **Purpose**: Provide a collection of valid model names that can be used with this provider. If a list of valid models cannot be determined in this session, return `None`.
34 | * **Usage**: Implemented as a read-only property; used during validation when switching the active model.
35 |
36 | ### Supporting classes and exceptions
37 |
38 | #### `gptcmd.message.Message`
39 |
40 | * **Purpose**: A [`dataclass`](https://docs.python.org/3/library/dataclasses.html) representing a message written by the user or LLM.
41 | * **Usage**: Used throughout the application.
42 |
43 | ##### Key fields
44 |
45 | Field | Type | Description
46 | --- | --- | ---
47 | `content` | `str` | The text of the message.
48 | `role` | `gptcmd.message.MessageRole` | The conversational role of this message, such as `gptcmd.message.MessageRole.USER`.
49 | `name` | `Optional[str]` | The user-provided name for this message.
50 | `attachments` | `List[gptcmd.message.MessageAttachment]` | A list of rich attachments, such as images, associated with this message.
51 | `metadata` | `Dict[str, Any]` | A dictionary of arbitrary metadata associated with this message, which can be used to get or set data particular to a specific provider (such as reasoning text, a digital signature, user requests for special handling, etc.). Since `metadata` is a field on `Message`, it can be accessed by any provider: it may be wise to, say, prefix metadata keys with the `LLMProvider`'s entry point name and an underscore for namespacing. Metadata values must be JSON serializable.
52 |
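53 | For example, a provider registered under the entry point name `echo` might namespace a provider-specific value like this (the key here is purely illustrative):
54 |
55 | ``` python
56 | msg.metadata["echo_reversed"] = True
57 | ```
58 |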
53 | #### `gptcmd.llm.LLMResponse`
54 |
55 | * **Purpose**: A [`dataclass`](https://docs.python.org/3/library/dataclasses.html) containing a `Message` generated by the `LLMProvider` in response to a user request, as well as optional metadata like token counts and cost estimates.
56 | * **Usage**: Return this from your `complete` method.
57 |
58 | ##### Key fields
59 |
60 | Field | Type | Description
61 | --- | --- | ---
62 | `message` | `gptcmd.message.Message` | The `Message` object containing the LLM's response to a user query.
63 | `prompt_tokens` | `Optional[int]` | The number of tokens, as determined by the LLM's tokenizer, in the request (context) that generated this response.
64 | `sampled_tokens` | `Optional[int]` | The number of tokens, as determined by the LLM's tokenizer, in this response.
65 | `cost_in_cents` | `Optional[Union[int, Decimal]]` | An estimate of the cost, in US cents, of the request that generated this response.
66 |
67 | #### Exceptions
68 |
69 | * **`gptcmd.config.ConfigError`**: Raised by the `from_config` method when the provider cannot be configured.
70 | * **`gptcmd.llm.CompletionError`**: Raised by the `complete` method when the LLM cannot generate a response.
71 | * **`gptcmd.llm.InvalidAPIParameterError`**: Raised by the `validate_api_params` method when invalid API parameters are provided.
72 |
73 | ## Building an `LLMProvider`
74 |
75 | To show how the process works, we'll build a simple `LLMProvider` implementation that mostly just responds with a copy of the user's request. To start, create a directory called `gptcmd-echo-provider`.
76 |
77 | ### Packaging metadata
78 |
79 | In your `gptcmd-echo-provider` directory, create a file called `pyproject.toml` with the following content:
80 |
81 | ``` toml
82 | [build-system]
83 | requires = ["setuptools>=61.0"]
84 | build-backend = "setuptools.build_meta"
85 |
86 | [project]
87 | name = "gptcmd-echo-provider"
88 | version = "0.1.0"
89 | dependencies = ["gptcmd >= 2.0.0"]
90 | ```
91 |
92 | More information about the `pyproject.toml` format can be found in the [relevant section of the Python Packaging Tutorial](https://packaging.python.org/en/latest/guides/writing-pyproject-toml/).
93 |
94 | Gptcmd uses a [packaging entry point](https://packaging.python.org/en/latest/specifications/entry-points/) to find external providers. The name of the entry point corresponds to the value of the `provider` option used to select it in a user account configuration. Add this to the end of `pyproject.toml`, which will make our new provider selectable with `provider="echo"` in an account configuration table:
95 |
96 | ``` toml
97 | [project.entry-points."gptcmd.providers"]
98 | echo = "gptcmd_echo_provider.echo:EchoProvider"
99 | ```
100 |
101 | Create an `src` directory inside `gptcmd-echo-provider`. Inside that directory, create a subdirectory called `gptcmd_echo_provider`. Create an empty file at `gptcmd-echo-provider/src/gptcmd_echo_provider/__init__.py` so that Python considers this directory a package.
102 |
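103 | Once `echo.py` is added in the next section, the project will be laid out like this:
104 |
105 | ```
106 | gptcmd-echo-provider/
107 | ├── pyproject.toml
108 | └── src/
109 |     └── gptcmd_echo_provider/
110 |         ├── __init__.py
111 |         └── echo.py
112 | ```
113 |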
103 | ### Provider implementation
104 |
105 | Create a new file, `gptcmd-echo-provider/src/gptcmd_echo_provider/echo.py`, with the following content:
106 |
107 | ``` python
108 | from typing import Any, Dict, Iterable, Sequence
109 |
110 | from gptcmd.llm import (
111 | CompletionError,
112 | InvalidAPIParameterError,
113 | LLMProvider,
114 | LLMResponse,
115 | )
116 | from gptcmd.message import Message, MessageRole
117 |
118 |
119 | class EchoProvider(LLMProvider):
120 | @classmethod
121 | def from_config(cls, conf: Dict[str, Any]) -> "EchoProvider":
122 | # No config options supported
123 | return cls()
124 |
125 | def validate_api_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
126 | raise InvalidAPIParameterError("API parameters are unsupported")
127 | ```
128 |
129 | #### Implementing `complete`
130 |
131 | For this provider, the `complete` method just returns a copy of whatever the user said last. Add this to `echo.py` (inside the `EchoProvider` class):
132 |
133 | ``` python
134 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
135 | for msg in reversed(messages):
136 | if msg.role == MessageRole.USER:
137 | return LLMResponse(
138 | Message(content=msg.content, role=MessageRole.ASSISTANT)
139 | )
140 | # We never reached a user message, so just throw an error
141 | raise CompletionError("Nothing to echo!")
142 | ```
143 |
144 | #### Implementing model selection
145 |
146 | Since this provider is just an example, we'll only support one model called `echo-1`. If this provider made multiple models available, we would provide the full list of options in the `valid_models` method. The currently selected model is available on the `model` attribute of an `LLMProvider` instance. Add this to `echo.py` (inside the class):
147 |
148 | ``` python
149 | def get_best_model(self) -> str:
150 | return "echo-1"
151 |
152 | @property
153 | def valid_models(self) -> Iterable[str]:
154 | return ("echo-1",)
155 | ```
156 |
157 | ### Testing the provider
158 |
159 | Let's install, configure, and try out the new provider. From the `gptcmd-echo-provider` directory, run `pip install .` to install the provider package. During provider development, you might want to do an [editable install](https://pip.pypa.io/en/latest/topics/local-project-installs/) (`pip install -e .`) so that you don't need to reinstall the package after each change.
160 |
161 | After the provider is installed, add a new account to your configuration file:
162 |
163 | ``` toml
164 | [accounts.echotest]
165 | provider="echo"
166 | ```
167 |
168 | Start Gptcmd and test the provider:
169 |
170 | ```
171 | (gpt-4o) account echotest
172 | Switched to account 'echotest'
173 | (echo-1) say Hello, world!
174 | ...
175 | Hello, world!
176 | ```
177 |
178 | ### Optional features
179 |
180 | #### User configuration
181 |
182 | Configuration values can be extracted in `from_config` and passed to the `LLMProvider` constructor. For instance, we can add a configuration option to echo messages in reverse. First, add a constructor to the `EchoProvider` class:
183 |
184 | ``` python
185 | def __init__(self, backwards: bool = False, *args, **kwargs):
186 | self.backwards = backwards
187 | super().__init__(*args, **kwargs)
188 | ```
189 |
190 | Then, replace the `from_config` method with:
191 |
192 | ``` python
193 | @classmethod
194 | def from_config(cls, conf: Dict[str, Any]) -> "EchoProvider":
195 |     return cls(backwards=conf.get("backwards", False))
196 | ```
197 |
198 | In this example, `from_config` always succeeds. If `from_config` might fail (for instance, due to invalid user input, failed network requests, etc.), it should raise `ConfigError` (from the `gptcmd.config` module).
199 |
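200 | For instance, a stricter version of `from_config` might validate the option's type, raising `ConfigError` on bad input (a sketch, assuming `from gptcmd.config import ConfigError` at the top of `echo.py`):
201 |
202 | ``` python
203 | @classmethod
204 | def from_config(cls, conf: Dict[str, Any]) -> "EchoProvider":
205 |     backwards = conf.get("backwards", False)
206 |     if not isinstance(backwards, bool):
207 |         raise ConfigError("backwards must be true or false")
208 |     return cls(backwards=backwards)
209 | ```
210 |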
200 | Now, modify `complete`:
201 |
202 | ``` python
203 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
204 | for msg in reversed(messages):
205 | if msg.role == MessageRole.USER:
206 | content = msg.content[::-1] if self.backwards else msg.content
207 | return LLMResponse(
208 | Message(content=content, role=MessageRole.ASSISTANT)
209 | )
210 | # We never reached a user message, so just throw an error
211 | raise CompletionError("Nothing to echo!")
212 | ```
213 |
214 | By default, the provider outputs messages as-is (in forward order):
215 |
216 | ```
217 | (echo-1) say Hello, world!
218 | ...
219 | Hello, world!
220 | ```
221 |
222 | However, when we add `backwards=true` to the account configuration:
223 |
224 | ``` toml
225 | [accounts.echotest]
226 | provider="echo"
227 | backwards=true
228 | ```
229 |
230 | We get:
231 |
232 | ```
233 | (echo-1) say Hello, world!
234 | ...
235 | !dlrow ,olleH
236 | (echo-1) say Was it Eliot's toilet I saw?
237 | ...
238 | ?was I teliot s'toilE ti saW
239 | ```
240 |
241 | #### API parameters
242 |
243 | To support API parameters, implement `validate_api_params`. We'll add a parameter to control how many times the user's message is echoed. Replace the `validate_api_params` method with:
244 |
245 | ``` python
246 | def validate_api_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
247 | # Examine the provided parameters
248 | for param, value in params.items():
249 | if param == "repeat":
250 | if not isinstance(value, int):
251 | raise InvalidAPIParameterError("Repeat must be an integer")
252 | # We must echo at least one time
253 | # If the user provides zero or a negative number, set it to 1
254 | params["repeat"] = max(1, value)
255 | else: # An unsupported parameter
256 | raise InvalidAPIParameterError(
257 | f"Parameter {param!r} not supported"
258 | )
259 | return params
260 | ```
261 |
262 | Implement support for the new parameter in `complete`:
263 |
264 | ``` python
265 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
266 | for msg in reversed(messages):
267 | if msg.role == MessageRole.USER:
268 | content = msg.content
269 | if "repeat" in self.api_params:
270 | content *= self.api_params["repeat"]
271 | return LLMResponse(
272 | Message(content=content, role=MessageRole.ASSISTANT)
273 | )
274 | # We never reached a user message, so just throw an error
275 | raise CompletionError("Nothing to echo!")
276 | ```
277 |
278 | Test it:
279 |
280 | ```
281 | (echo-1) say hello
282 | ...
283 | hello
284 | (echo-1) set repeat 3
285 | repeat set to 3
286 | (echo-1) retry
287 | ...
288 | hellohellohello
289 | (echo-1) set repeat -1
290 | repeat set to 1
291 | (echo-1) retry
292 | ...
293 | hello
294 | (echo-1) unset repeat
295 | repeat unset
296 | (echo-1) retry
297 | ...
298 | hello
299 | ```
300 |
301 | ##### Default parameters
302 |
303 | To define default values for API parameters, update the constructor to set them, and override `unset_api_param` to restore the default value when a default parameter is unset. We'll set a default value of 1 for the `repeat` parameter. Add a class variable to the `EchoProvider` class:
304 |
305 | ``` python
306 | class EchoProvider(LLMProvider):
307 | DEFAULT_API_PARAMS: Dict[str, Any] = {"repeat": 1}
308 | # ...
309 | ```
310 |
311 | Next, add a constructor (if you added the `backwards` option earlier, extend that constructor instead):
312 |
313 | ``` python
314 | def __init__(self, *args, **kwargs):
315 | super().__init__(*args, **kwargs)
316 | self.update_api_params(self.__class__.DEFAULT_API_PARAMS)
317 | ```
318 |
319 | And override `unset_api_param` (you'll also need to add `Optional` to the imports from `typing`):
320 |
321 | ``` python
322 | def unset_api_param(self, key: Optional[str] = None) -> None:
323 | super().unset_api_param(key)
324 | if key in self.__class__.DEFAULT_API_PARAMS:
325 |         self.set_api_param(key, self.__class__.DEFAULT_API_PARAMS[key])
326 | elif key is None: # Unset all parameters
327 | self.update_api_params(self.__class__.DEFAULT_API_PARAMS)
328 | ```
329 |
330 | Then, we can simplify `complete`:
331 |
332 | ``` python
333 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
334 | for msg in reversed(messages):
335 | if msg.role == MessageRole.USER:
336 | return LLMResponse(
337 | Message(
338 | content=msg.content * self.api_params["repeat"],
339 | role=MessageRole.ASSISTANT,
340 | )
341 | )
342 | # We never reached a user message, so just throw an error
343 | raise CompletionError("Nothing to echo!")
344 | ```
345 |
346 | #### Message name field
347 |
348 | If your `LLMProvider` implementation processes the `name` field set on `Message` objects, you'll need to advertise this support. Add a class variable called `SUPPORTED_FEATURES` containing the appropriate member of the `gptcmd.llm.LLMProviderFeature` [flag enumeration](https://docs.python.org/3/library/enum.html#enum.Flag). Import `LLMProviderFeature` from `gptcmd.llm` in your provider's module, then add this inside the class:
349 |
350 | ``` python
351 | SUPPORTED_FEATURES = LLMProviderFeature.MESSAGE_NAME_FIELD
352 | ```
353 |
354 | If your class implements support for multiple `LLMProviderFeature`s, separate them with a pipe (`|`) character.
355 |
356 | #### Message attachments
357 |
358 | To implement support for message attachments, write a formatter function for each attachment type you support, decorated with the `register_attachment_formatter` decorator on your provider class. For `EchoProvider`, we'll convert images to a simple string representation. First, import `Image` from `gptcmd.message`, then add this function to `echo.py` (outside the class):
359 |
360 | ``` python
361 | @EchoProvider.register_attachment_formatter(Image)
362 | def format_image(img: Image) -> str:
363 | return f"img={img.url}"
364 | ```
365 |
366 | Now, modify `complete` to process attachments. For `EchoProvider`, we'll just append them to the response content. A more realistic provider might add them to an API request body:
367 |
368 | ``` python
369 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
370 | for msg in reversed(messages):
371 | if msg.role == MessageRole.USER:
372 | content = msg.content
373 | for a in msg.attachments:
374 | content += "\n" + self.format_attachment(a)
375 | return LLMResponse(
376 | Message(content=content, role=MessageRole.ASSISTANT)
377 | )
378 | # We never reached a user message, so just throw an error
379 | raise CompletionError("Nothing to echo!")
380 | ```
381 |
382 | Now, when we attach images, their URLs are echoed:
383 |
384 | ```
385 | (echo-1) user hello!
386 | 'hello!' added as user
387 | (echo-1) image http://example.com/image.jpg
388 | Image added to 'hello!'
389 | (echo-1) send
390 | ...
391 | hello!
392 | img=http://example.com/image.jpg
393 | ```
394 |
395 | #### Streamed responses
396 |
397 | If your `LLMProvider` implementation can stream parts of a response as they are generated, you'll need to advertise this support as well. As with the [message name field](#message-name-field), add the appropriate member of the `gptcmd.llm.LLMProviderFeature` [flag enumeration](https://docs.python.org/3/library/enum.html#enum.Flag) to your provider's `SUPPORTED_FEATURES` class variable:
398 |
399 | ``` python
400 | SUPPORTED_FEATURES = LLMProviderFeature.RESPONSE_STREAMING
401 | ```
402 |
403 | If your class implements support for multiple `LLMProviderFeature`s, separate them with a pipe (`|`) character.
404 |
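405 | For example, the built-in OpenAI provider advertises both the message name field and response streaming:
406 |
407 | ``` python
408 | SUPPORTED_FEATURES = (
409 |     LLMProviderFeature.MESSAGE_NAME_FIELD
410 |     | LLMProviderFeature.RESPONSE_STREAMING
411 | )
412 | ```
413 |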
414 | The `stream` property on the `LLMProvider` instance will be set to `True` when a response should be streamed. If your `LLMProvider` only supports streaming under certain conditions (such as when certain models are used but not others), override the `stream` property getter to return `False` and the setter to raise `NotImplementedError` with an appropriate message in unsupported scenarios (you can use the already defined `LLMProvider._stream` attribute as a backing field). To enable streaming by default, set the property in your provider's constructor.
415 |
416 | Then, subclass `LLMResponse` to handle streams. In your subclass, create a backing `Message` that you update as the response streams in, so that user disconnections and runtime errors are handled gracefully, and implement an iterator that updates this message and yields the text of the next chunk of the stream as a string. In `complete`, return your custom `LLMResponse` when `self.stream` is `True`.
417 |
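418 | For instance, here is a minimal sketch of a streamed response for `EchoProvider` that yields the echoed text one word at a time (`StreamedEchoResponse` is a name invented for this example; add it to `echo.py` outside the class):
419 |
420 | ``` python
421 | class StreamedEchoResponse(LLMResponse):
422 |     def __init__(self, text: str):
423 |         # Create the backing Message up front and update it as chunks
424 |         # stream in, so partial output survives errors or disconnection
425 |         super().__init__(Message(content="", role=MessageRole.ASSISTANT))
426 |         self._words = iter(text.split(" "))
427 |         self._first = True
428 |
429 |     def __iter__(self):
430 |         return self
431 |
432 |     def __next__(self):
433 |         word = next(self._words)  # StopIteration ends the stream
434 |         chunk = word if self._first else " " + word
435 |         self._first = False
436 |         self.message.content += chunk
437 |         return chunk
438 | ```
439 |
440 | With `SUPPORTED_FEATURES = LLMProviderFeature.RESPONSE_STREAMING` set as above, modify `complete` to return `StreamedEchoResponse(msg.content)` in place of the plain `LLMResponse` when `self.stream` is `True`.
441 |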
--------------------------------------------------------------------------------
/src/gptcmd/llm/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains the LLMProvider class and supporting infrastructure.
3 | Copyright 2024 Bill Dengler
4 | This Source Code Form is subject to the terms of the Mozilla Public
5 | License, v. 2.0. If a copy of the MPL was not distributed with this
6 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
7 | """
8 |
9 | import dataclasses
10 | from abc import ABC, abstractmethod
11 | from decimal import Decimal
12 | from enum import Flag, auto
13 | from typing import (
14 | Any,
15 | Callable,
16 | Dict,
17 | Iterable,
18 | Optional,
19 | Sequence,
20 | Type,
21 | Union,
22 | )
23 |
24 | from ..message import Message, MessageAttachment, UnknownAttachment
25 |
26 |
27 | @dataclasses.dataclass
28 | class LLMResponse:
29 | message: Message
30 | prompt_tokens: Optional[int] = None
31 | sampled_tokens: Optional[int] = None
32 | cost_in_cents: Optional[Union[int, Decimal]] = None
33 |
34 | def __iter__(self):
35 | """The default iterator for non-streaming LLMResponse objects."""
36 | yield self.message.content
37 |
38 |
39 | class InvalidAPIParameterError(Exception):
40 | pass
41 |
42 |
43 | class CompletionError(Exception):
44 | pass
45 |
46 |
47 | class LLMProviderFeature(Flag):
48 | """
49 | An enum representing optional features that an LLMProvider might
50 | implement.
51 | """
52 |
53 | # Whether this LLM implements support for the name attribute
54 | # on Message objects. If this flag is not set, message names are likely
55 | # to be ignored.
56 | MESSAGE_NAME_FIELD = auto()
57 |
58 | # Whether this LLM implements support for streamed responses
59 | RESPONSE_STREAMING = auto()
60 |
61 |
62 | class LLMProvider(ABC):
63 | """
64 | An object which generates the most likely next Message
65 | given a sequence of Messages.
66 | """
67 |
68 | SUPPORTED_FEATURES: LLMProviderFeature = LLMProviderFeature(0)
69 |
70 | def __init__(self, model: Optional[str] = None):
71 | self.model: Optional[str] = model or self.get_best_model()
72 | self._api_params: Dict[str, Any] = {}
73 | self._stream: bool = False
74 |
75 | def __init_subclass__(cls):
76 | cls._attachment_formatters: Dict[
77 | Type[MessageAttachment],
78 | Callable[[MessageAttachment], Dict[str, Any]],
79 | ] = {}
80 |
81 | @abstractmethod
82 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
83 | pass
84 |
85 | @abstractmethod
86 | def validate_api_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
87 | """
88 |         Validate a dict of API parameters and return the validated dict.
89 |         Raises InvalidAPIParameterError if this provider doesn't support
90 |         a parameter defined in the dictionary.
91 |         If a user-provided value is out of range or in the incorrect
92 |         format, this method adjusts the value accordingly.
93 | """
94 | pass
95 |
96 | @property
97 | @abstractmethod
98 | def valid_models(self) -> Optional[Iterable[str]]:
99 | """
100 | A collection of model names that can be set on this LLM provider
101 | """
102 | pass
103 |
104 | @classmethod
105 | @abstractmethod
106 | def from_config(cls, conf: Dict):
107 | "Instantiate this object from a dict of configuration file parameters."
108 | pass
109 |
110 | @property
111 | def stream(self) -> bool:
112 | return (
113 | self._stream
114 | and LLMProviderFeature.RESPONSE_STREAMING
115 | in self.SUPPORTED_FEATURES
116 | )
117 |
118 | @stream.setter
119 | def stream(self, val: bool):
120 | if (
121 | LLMProviderFeature.RESPONSE_STREAMING
122 | not in self.SUPPORTED_FEATURES
123 | ):
124 | raise NotImplementedError(
125 | "Response streaming is not supported by this LLM"
126 | )
127 | self._stream = val
128 |
129 | @property
130 | def api_params(self) -> Dict[str, Any]:
131 | return self._api_params.copy()
132 |
133 | def set_api_param(self, key: str, value: Any) -> Any:
134 | """Set an API parameter after validating it."""
135 | new_params = self._api_params.copy()
136 | new_params[key] = value
137 | validated_params = self.validate_api_params(new_params)
138 | self._api_params = validated_params
139 | return validated_params.get(key)
140 |
141 | def unset_api_param(self, key: Optional[str] = None) -> None:
142 | if key is None:
143 | self._api_params = {}
144 | else:
145 | try:
146 | del self._api_params[key]
147 | except KeyError:
148 | raise InvalidAPIParameterError(f"{key} not set")
149 |
150 | def update_api_params(self, params: Dict[str, Any]) -> None:
151 | """Update multiple API parameters at once after validating them."""
152 | new_params = self._api_params.copy()
153 | new_params.update(params)
154 | validated_params = self.validate_api_params(new_params)
155 | self._api_params = validated_params
156 |
157 | @abstractmethod
158 | def get_best_model(self) -> str:
159 | """
160 | This method returns the name of the most capable model offered by
161 | this provider.
162 | """
163 | pass
164 |
165 | @classmethod
166 | def register_attachment_formatter(
167 | cls, attachment_type: Type[MessageAttachment]
168 | ):
169 | def decorator(func: Callable[[MessageAttachment], Dict[str, Any]]):
170 | cls._attachment_formatters[attachment_type] = func
171 | return func
172 |
173 | return decorator
174 |
175 | def format_attachment(
176 | self, attachment: MessageAttachment
177 | ) -> Dict[str, Any]:
178 | if isinstance(attachment, UnknownAttachment):
179 | raise ValueError(
180 | f"{attachment.type} attachments are not supported. Perhaps you"
181 | " need to update Gptcmd or install a package?"
182 | )
183 | for cls in self.__class__.__mro__:
184 | formatter = getattr(cls, "_attachment_formatters", {}).get(
185 | type(attachment)
186 | )
187 | if formatter:
188 | return formatter(attachment)
189 | raise ValueError(
190 | f"{type(attachment).__name__} attachments aren't supported by"
191 | " this LLM"
192 | )
193 |
--------------------------------------------------------------------------------
/src/gptcmd/llm/openai.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains implementations of LLMProvider for OpenAI and Azure.
3 | Copyright 2024 Bill Dengler
4 | This Source Code Form is subject to the terms of the Mozilla Public
5 | License, v. 2.0. If a copy of the MPL was not distributed with this
6 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
7 | """
8 |
9 | import inspect
10 | import re
11 |
12 | from collections import namedtuple
13 | from decimal import Decimal
14 | from typing import Any, Dict, Iterable, Optional, Sequence
15 |
16 | from . import (
17 | CompletionError,
18 | InvalidAPIParameterError,
19 | LLMProviderFeature,
20 | LLMProvider,
21 | LLMResponse,
22 | )
23 | from ..message import Image, Message, MessageRole
24 |
25 | import openai
26 |
27 | ModelCostInfo = namedtuple(
28 | "ModelCostInfo", ("prompt_scale", "sampled_scale", "cache_discount_factor")
29 | )
30 |
31 |
32 | class OpenAI(LLMProvider):
33 | SUPPORTED_FEATURES = (
34 | LLMProviderFeature.MESSAGE_NAME_FIELD
35 | | LLMProviderFeature.RESPONSE_STREAMING
36 | )
37 |     # Reasoning models (o1, o3, ...) use "developer" for system prompts
38 |     _reasoning_model_expr = re.compile(r"^o\d.*$")
39 |
40 | def __init__(self, client, *args, **kwargs):
41 | self._client = client
42 | try:
43 | self._models = {m.id for m in self._client.models.list().data}
44 | except openai.NotFoundError:
45 | # Some OpenAI-like APIs implement a chat completions endpoint but
46 | # don't offer a list of models.
47 | # For these APIs, disable model validation.
48 | self._models = None
49 | super().__init__(*args, **kwargs)
50 | self._stream = True
51 |
52 | @classmethod
53 | def from_config(cls, conf: Dict):
54 | SPECIAL_OPTS = (
55 | "model",
56 | "provider",
57 | )
58 | model = conf.get("model")
59 | client_opts = {k: v for k, v in conf.items() if k not in SPECIAL_OPTS}
60 | try:
61 | client = openai.OpenAI(**client_opts)
62 | except openai.OpenAIError as e:
63 | # Import late to avoid circular import
64 | from ..config import ConfigError
65 |
66 | raise ConfigError(str(e)) from e
67 | return cls(client, model=model)
68 |
69 | def _message_to_openai(self, msg: Message) -> Dict[str, Any]:
70 | res = {
71 | "role": (
72 | "developer"
73 | if self.__class__._reasoning_model_expr.match(self.model)
74 | and msg.role == MessageRole.SYSTEM
75 | else msg.role
76 | )
77 | }
78 | if msg.name:
79 | res["name"] = msg.name
80 | if msg.attachments:
81 | res["content"] = [
82 | {"type": "text", "text": msg.content},
83 | *[self.format_attachment(a) for a in msg.attachments],
84 | ]
85 | else:
86 | res["content"] = msg.content
87 | return res
88 |
89 | @staticmethod
90 | def _estimate_cost_in_cents(
91 | model: str,
92 | prompt_tokens: int,
93 | cached_prompt_tokens: int,
94 | sampled_tokens: int,
95 | ) -> Optional[Decimal]:
96 | COST_PER_PROMPT_SAMPLED: Dict[str, ModelCostInfo] = {
97 | "o3-2025-04-16": ModelCostInfo(
98 | Decimal("10") / Decimal("1000000"),
99 | Decimal("40") / Decimal("1000000"),
100 | Decimal("0.25"),
101 | ),
102 | "o1-2024-12-17": ModelCostInfo(
103 | Decimal("15") / Decimal("1000000"),
104 | Decimal("60") / Decimal("1000000"),
105 | Decimal("0.5"),
106 | ),
107 | "o1-preview-2024-09-12": ModelCostInfo(
108 | Decimal("15") / Decimal("1000000"),
109 | Decimal("60") / Decimal("1000000"),
110 | Decimal("0.5"),
111 | ),
112 | "o4-mini-2025-04-16": ModelCostInfo(
113 | Decimal("1.1") / Decimal("1000000"),
114 | Decimal("4.4") / Decimal("1000000"),
115 | Decimal("0.25"),
116 | ),
117 | "o3-mini-2025-01-31": ModelCostInfo(
118 | Decimal("1.1") / Decimal("1000000"),
119 | Decimal("4.4") / Decimal("1000000"),
120 | Decimal("0.5"),
121 | ),
122 | "o1-mini-2024-09-12": ModelCostInfo(
123 | Decimal("3") / Decimal("1000000"),
124 | Decimal("12") / Decimal("1000000"),
125 | Decimal("0.5"),
126 | ),
127 | "gpt-4.1-2025-04-14": ModelCostInfo(
128 | Decimal("2") / Decimal("1000000"),
129 | Decimal("8") / Decimal("1000000"),
130 | Decimal("0.25"),
131 | ),
132 | "gpt-4.5-preview-2025-02-27": ModelCostInfo(
133 | Decimal("75") / Decimal("1000000"),
134 | Decimal("150") / Decimal("1000000"),
135 | Decimal("0.5"),
136 | ),
137 | "gpt-4o-2024-11-20": ModelCostInfo(
138 | Decimal("2.5") / Decimal("1000000"),
139 | Decimal("10") / Decimal("1000000"),
140 | Decimal("0.5"),
141 | ),
142 | "gpt-4o-2024-08-06": ModelCostInfo(
143 | Decimal("2.5") / Decimal("1000000"),
144 | Decimal("10") / Decimal("1000000"),
145 | Decimal("0.5"),
146 | ),
147 | "gpt-4o-2024-05-13": ModelCostInfo(
148 | Decimal("5") / Decimal("1000000"),
149 | Decimal("15") / Decimal("1000000"),
150 | Decimal("0.5"),
151 | ),
152 | "gpt-4.1-mini-2025-04-14": ModelCostInfo(
153 | Decimal("0.4") / Decimal("1000000"),
154 | Decimal("1.6") / Decimal("1000000"),
155 | Decimal("0.25"),
156 | ),
157 | "gpt-4.1-nano-2025-04-14": ModelCostInfo(
158 | Decimal("0.1") / Decimal("1000000"),
159 | Decimal("0.4") / Decimal("1000000"),
160 | Decimal("0.25"),
161 | ),
162 | "gpt-4o-mini-2024-07-18": ModelCostInfo(
163 | Decimal("0.15") / Decimal("1000000"),
164 | Decimal("0.6") / Decimal("1000000"),
165 | Decimal("0.5"),
166 | ),
167 | "gpt-4-turbo-2024-04-09": ModelCostInfo(
168 | Decimal("10") / Decimal("1000000"),
169 | Decimal("30") / Decimal("1000000"),
170 | Decimal("0.5"),
171 | ),
172 | "gpt-4-0125-preview": ModelCostInfo(
173 | Decimal("10") / Decimal("1000000"),
174 | Decimal("30") / Decimal("1000000"),
175 | Decimal("0.5"),
176 | ),
177 | "gpt-4-1106-preview": ModelCostInfo(
178 | Decimal("10") / Decimal("1000000"),
179 | Decimal("30") / Decimal("1000000"),
180 | Decimal("0.5"),
181 | ),
182 | "gpt-4-1106-vision-preview": ModelCostInfo(
183 | Decimal("10") / Decimal("1000000"),
184 | Decimal("30") / Decimal("1000000"),
185 | Decimal("0.5"),
186 | ),
187 | "gpt-4-0613": ModelCostInfo(
188 | Decimal("30") / Decimal("1000000"),
189 | Decimal("60") / Decimal("1000000"),
190 | Decimal("0.5"),
191 | ),
192 | "gpt-3.5-turbo-0125": ModelCostInfo(
193 | Decimal("0.5") / Decimal("1000000"),
194 | Decimal("1.5") / Decimal("1000000"),
195 | Decimal("0.5"),
196 | ),
197 | "gpt-3.5-turbo-1106": ModelCostInfo(
198 | Decimal("1") / Decimal("1000000"),
199 | Decimal("2") / Decimal("1000000"),
200 | Decimal("0.5"),
201 | ),
202 | "gpt-3.5-turbo-0613": ModelCostInfo(
203 | Decimal("1.5") / Decimal("1000000"),
204 | Decimal("2") / Decimal("1000000"),
205 | Decimal("0.5"),
206 | ),
207 | "gpt-3.5-turbo-16k-0613": ModelCostInfo(
208 | Decimal("3") / Decimal("1000000"),
209 | Decimal("4") / Decimal("1000000"),
210 | Decimal("0.5"),
211 | ),
212 | "gpt-3.5-turbo-0301": ModelCostInfo(
213 | Decimal("1.5") / Decimal("1000000"),
214 | Decimal("2") / Decimal("1000000"),
215 | Decimal("0.5"),
216 | ),
217 | }
218 |
219 | if model not in COST_PER_PROMPT_SAMPLED:
220 | return None
221 | info = COST_PER_PROMPT_SAMPLED[model]
222 | cached_prompt_scale = info.prompt_scale * info.cache_discount_factor
223 | uncached_prompt_tokens = prompt_tokens - cached_prompt_tokens
224 | return (
225 | Decimal(uncached_prompt_tokens) * info.prompt_scale
226 | + Decimal(cached_prompt_tokens) * cached_prompt_scale
227 | + Decimal(sampled_tokens) * info.sampled_scale
228 | ) * Decimal("100")
229 |
230 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
231 | kwargs = {
232 | "model": self.model,
233 | "messages": [self._message_to_openai(m) for m in messages],
234 | "stream": self.stream,
235 | **self.validate_api_params(self.api_params),
236 | }
237 | if self.stream:
238 | # Enable usage statistics
239 | kwargs["stream_options"] = {"include_usage": True}
240 | if kwargs["model"] == "gpt-4-vision-preview":
241 | # For some unknown reason, OpenAI sets a very low
242 | # default max_tokens. For consistency with other models,
243 | # set it to the maximum if not overridden by the user.
244 | kwargs.setdefault("max_tokens", 4096)
245 | try:
246 | resp = self._client.chat.completions.create(**kwargs)
247 | except openai.OpenAIError as e:
248 | raise CompletionError(str(e)) from e
249 | if isinstance(resp, openai.Stream):
250 | return StreamedOpenAIResponse(resp, self)
251 | if resp.choices is None:
252 | raise CompletionError("Empty response (no choices specified)")
253 | elif len(resp.choices) != 1:
254 | raise CompletionError(
255 | f"Unexpected number of choices ({len(resp.choices)}) from"
256 | " OpenAI response"
257 | )
258 | choice = resp.choices[0]
259 | prompt_tokens = resp.usage.prompt_tokens
260 | prompt_tokens_details = getattr(
261 | resp.usage, "prompt_tokens_details", None
262 | )
263 | if prompt_tokens_details is None:
264 | cached_prompt_tokens = 0
265 | else:
266 | cached_prompt_tokens = prompt_tokens_details.cached_tokens
267 | sampled_tokens = resp.usage.completion_tokens
268 |
269 | return LLMResponse(
270 | message=Message(
271 | content=choice.message.content,
272 | role=MessageRole(choice.message.role),
273 | ),
274 | prompt_tokens=prompt_tokens,
275 | sampled_tokens=sampled_tokens,
276 | cost_in_cents=self.__class__._estimate_cost_in_cents(
277 | model=resp.model,
278 | prompt_tokens=prompt_tokens,
279 | cached_prompt_tokens=cached_prompt_tokens,
280 | sampled_tokens=sampled_tokens,
281 | ),
282 | )
283 |
284 | def validate_api_params(self, params):
285 | SPECIAL_OPTS = frozenset(
286 | ("model", "messages", "stream", "n", "stream_options")
287 | )
288 | valid_opts = (
289 | frozenset(
290 | inspect.signature(
291 | self._client.chat.completions.create
292 | ).parameters.keys()
293 | )
294 | - SPECIAL_OPTS
295 | )
296 | for opt in params:
297 | if opt not in valid_opts:
298 | raise InvalidAPIParameterError(f"Unknown parameter {opt}")
299 | return params
300 |
301 | @property
302 | def stream(self) -> bool:
303 | return self._stream
304 |
305 | @stream.setter
306 | def stream(self, val: bool):
307 | self._stream = val
308 |
309 | @property
310 | def valid_models(self) -> Iterable[str]:
311 | return self._models
312 |
313 | def get_best_model(self):
314 | BEST_MODELS = (
315 | "gpt-4.1",
316 | "gpt-4o",
317 | "gpt-4-turbo",
318 | "gpt-4.1-mini",
319 | "gpt-4o-mini",
320 | "gpt-4",
321 | "gpt-3.5-turbo",
322 | )
323 |         # valid_models may be None (OpenAI-like API without a model list)
324 |         valid = self.valid_models or ()
325 |         res = next(
326 |             (model for model in BEST_MODELS if model in valid),
327 |             None,
328 |         )
327 | if res is None:
328 | raise RuntimeError(
329 | "No known GPT model available! If this is an OpenAI-like API, "
330 | "set the model explicitly"
331 | )
332 | else:
333 | return res
334 |
335 |
336 | class AzureAI(OpenAI):
337 | AZURE_API_VERSION = "2024-06-01"
338 |
339 | @classmethod
340 | def from_config(cls, conf):
341 | SPECIAL_OPTS = (
342 | "model",
343 | "provider",
344 | "api_version",
345 | )
346 | model = conf.get("model")
347 | client_opts = {k: v for k, v in conf.items() if k not in SPECIAL_OPTS}
348 | client_opts["api_version"] = cls.AZURE_API_VERSION
349 | endpoint = client_opts.pop("endpoint", None)
350 | if endpoint:
351 | client_opts["azure_endpoint"] = endpoint
352 |         try:
353 |             client = openai.AzureOpenAI(**client_opts)
354 |         except openai.OpenAIError as e:
355 |             # Import late to avoid circular import
356 |             from ..config import ConfigError
357 |
358 |             raise ConfigError(str(e)) from e
359 |         return cls(client, model=model)
354 |
355 |
356 | @OpenAI.register_attachment_formatter(Image)
357 | def format_image_for_openai(img: Image) -> Dict[str, Any]:
358 | res = {"type": "image_url", "image_url": {"url": img.url}}
359 | if img.detail is not None:
360 | res["image_url"]["detail"] = img.detail
361 | return res
362 |
363 |
364 | class StreamedOpenAIResponse(LLMResponse):
365 | def __init__(self, backing_stream: openai.Stream, provider: OpenAI):
366 | self._stream = backing_stream
367 | self._provider = provider
368 |         # Backing message; role and content accumulate from stream deltas
369 |         m = Message(content="", role="")
370 | super().__init__(m)
371 |
372 | def __iter__(self):
373 | return self
374 |
375 | def __next__(self):
376 | try:
377 | chunk = next(self._stream)
378 | except openai.OpenAIError as e:
379 | raise CompletionError(str(e)) from e
380 | if chunk is None:
381 | return ""
382 | if chunk.usage:
383 | prompt_tokens = chunk.usage.prompt_tokens
384 | prompt_tokens_details = getattr(
385 | chunk.usage, "prompt_tokens_details", None
386 | )
387 | if prompt_tokens_details is None:
388 | cached_prompt_tokens = 0
389 | else:
390 | cached_prompt_tokens = prompt_tokens_details.cached_tokens
391 | sampled_tokens = chunk.usage.completion_tokens
392 | self.prompt_tokens = prompt_tokens
393 | self.sampled_tokens = sampled_tokens
394 | self.cost_in_cents = (
395 | self._provider.__class__._estimate_cost_in_cents(
396 | model=chunk.model,
397 | prompt_tokens=prompt_tokens,
398 | cached_prompt_tokens=cached_prompt_tokens,
399 | sampled_tokens=sampled_tokens,
400 | )
401 | )
402 | if chunk.choices is None or len(chunk.choices) != 1:
403 | return ""
404 | delta = chunk.choices[0].delta
405 | if delta is None:
406 | return ""
407 | if delta.role and delta.role != self.message.role:
408 |                 self.message.role += delta.role  # sent in first chunk
409 | if delta.content:
410 | self.message.content += delta.content
411 | return delta.content
412 | else:
413 | return ""
414 |
--------------------------------------------------------------------------------
/src/gptcmd/message.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains classes and types for interacting with messages and
3 | message threads.
4 | Copyright 2024 Bill Dengler
5 | This Source Code Form is subject to the terms of the Mozilla Public
6 | License, v. 2.0. If a copy of the MPL was not distributed with this
7 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
8 | """
9 |
10 | import base64
11 | import dataclasses
12 | import mimetypes
13 | import sys
14 | from abc import ABC, abstractmethod
15 | from collections.abc import Sequence
16 | from enum import auto
17 | from typing import (
18 | Any,
19 | Callable,
20 | Dict,
21 | Generic,
22 | Iterable,
23 | List,
24 | Optional,
25 | Tuple,
26 | Type,
27 | TypeVar,
28 | )
29 |
30 | if sys.version_info >= (3, 11):
31 | from enum import StrEnum
32 | else:
33 | from backports.strenum import StrEnum
34 |
35 |
36 | T = TypeVar("T")
37 |
38 |
39 | class TwoWayRegistrar(Generic[T]):
40 | """
41 | A registrar that maintains both forward and reverse mappings between keys
42 | and classes.
43 | Ensures a one-to-one relationship and provides reverse lookup.
44 | """
45 |
46 | def __init__(self):
47 | self._registry: Dict[str, Type[T]] = {}
48 | self._reverse_registry: Dict[Type[T], str] = {}
49 |
50 | def __contains__(self, key: str) -> bool:
51 | return key in self._registry
52 |
53 | def register(self, key: str) -> Callable[[Type[T]], Type[T]]:
54 | """
55 | Decorator to register a class with the passed-in key.
56 | """
57 |
58 | def decorator(cls: Type[T]) -> Type[T]:
59 | if key in self._registry:
60 | raise ValueError(f"{key} is already registered")
61 | elif cls in self._reverse_registry:
62 | raise ValueError(f"{cls} is already registered")
63 | self._registry[key] = cls
64 | self._reverse_registry[cls] = key
65 | return cls
66 |
67 | return decorator
68 |
69 | def get(self, key: str) -> Type[T]:
70 | """
71 | Retrieve a class from the registry by key.
72 | """
73 | if key not in self._registry:
74 | raise KeyError(f"{key} is not registered")
75 | return self._registry[key]
76 |
77 | def reverse_get(self, cls: Type[T]) -> str:
78 | """
79 | Retrieve the key associated with a class from the reverse registry.
80 | """
81 | if cls not in self._reverse_registry:
82 | raise KeyError(f"Class '{cls.__name__}' is not registered")
83 | return self._reverse_registry[cls]
84 |
85 |
86 | attachment_type_registrar: TwoWayRegistrar["MessageAttachment"] = (
87 | TwoWayRegistrar()
88 | )
89 |
90 |
91 | class MessageAttachment(ABC):
92 | """
93 | A non-text component that can be associated with a Message, such as an
94 | image for vision models.
95 | """
96 |
97 | @classmethod
98 | def from_dict(cls, d: Dict[str, Any]) -> "MessageAttachment":
99 | """
100 | Instantiate a MessageAttachment from a dict in the format returned by
101 | MessageAttachment.to_dict()
102 | """
103 | attachment_type_key = d.get("type")
104 | attachment_data = d.get("data", {})
105 | try:
106 | attachment_type = attachment_type_registrar.get(
107 | attachment_type_key
108 | )
109 | except KeyError:
110 | return UnknownAttachment(
111 | _type=attachment_type_key, _data=attachment_data
112 | )
113 | return attachment_type._deserialize(attachment_data)
114 |
115 | @classmethod
116 | @abstractmethod
117 | def _deserialize(cls, d: Dict[str, Any]) -> "MessageAttachment":
118 | "Deserialize a dict into a MessageAttachment subclass instance"
119 | pass
120 |
121 | def to_dict(self) -> Dict[str, Any]:
122 | "Exports this attachment as a serializable dict"
123 | return {
124 | "type": attachment_type_registrar.reverse_get(self.__class__),
125 | "data": self._serialize(),
126 | }
127 |
128 | @abstractmethod
129 | def _serialize(self) -> Dict[str, Any]:
130 | "Serialize this attachment into a dict"
131 | pass
132 |
133 | def __eq__(self, other):
134 | return self.to_dict() == other.to_dict()
135 |
136 |
137 | @attachment_type_registrar.register("image_url")
138 | class Image(MessageAttachment):
139 | "An image reachable by URL that can be fetched by the LLM API."
140 |
141 | def __init__(self, url: str, detail: Optional[str] = None):
142 | self.url = url
143 | self.detail = detail
144 |
145 | @classmethod
146 | def from_path(cls, path: str, *args, **kwargs):
147 | "Instantiate an Image from a file"
148 | with open(path, "rb") as fin:
149 | b64data = base64.b64encode(fin.read()).decode("utf-8")
150 |         mimetype = mimetypes.guess_type(path)[0] or "application/octet-stream"
151 | return cls(*args, url=f"data:{mimetype};base64,{b64data}", **kwargs)
152 |
153 | @classmethod
154 | def _deserialize(cls, d: Dict[str, Any]) -> "Image":
155 | return cls(url=d["url"], detail=d.get("detail"))
156 |
157 | def _serialize(self) -> Dict[str, Any]:
158 | res = {"url": self.url}
159 | if self.detail is not None:
160 | res["detail"] = self.detail
161 | return res
162 |
163 |
164 | class UnknownAttachment(MessageAttachment):
165 | """
166 | A MessageAttachment created when a dict in the form returned by
167 | MessageAttachment.to_dict contains an unknown or ambiguous type.
168 | This class should not be instantiated directly.
169 | """
170 |
171 | def __init__(self, _type: str, _data: Dict):
172 | self.data = _data
173 | self.type = _type
174 |
175 | @classmethod
176 | def _deserialize(cls, data):
177 | return cls(_type=None, _data=data)
178 |
179 | def _serialize(self):
180 | return self.data.copy()
181 |
182 | def to_dict(self):
183 | # Since these attachments are explicitly not registered, use our
184 | # internal type field instead of the registrar.
185 | return {"type": self.type, "data": self._serialize()}
186 |
187 |
188 | class MessageRole(StrEnum):
189 | """
190 | An enumeration defining valid values for the role attribute on
191 | Message objects
192 | """
193 |
194 | USER = auto()
195 | ASSISTANT = auto()
196 | SYSTEM = auto()
197 |
198 |
199 | @dataclasses.dataclass
200 | class Message:
201 | """A message sent to or received from an LLM."""
202 |
203 | #: The text content of the message
204 | content: str
205 | #: a member of MessageRole that defines the conversational role of the
206 | #: author of this message
207 | role: MessageRole
208 | #: The name of the author of this message
209 | name: Optional[str] = None
210 | #: Whether this message is "sticky" (not affected by thread-level deletion
211 | #: operations)
212 | sticky: bool = False
213 | #: A collection of attached objects, such as images
214 | attachments: List[MessageAttachment] = dataclasses.field(
215 | default_factory=list
216 | )
217 | #: Arbitrary metadata for this message
218 | metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)
219 |
220 | @classmethod
221 | def from_dict(cls, d: Dict[str, Any]):
222 | """
223 | Instantiate a Message from a dict in the format returned by
224 | Message.to_dict()
225 | """
226 | valid_keys = [f.name for f in dataclasses.fields(cls)]
227 | kwargs = {}
228 | for k, v in d.items():
229 | if k == "attachments":
230 | kwargs[k] = [MessageAttachment.from_dict(i) for i in v]
231 | elif k == "role":
232 | kwargs[k] = MessageRole(v)
233 | elif k == "_sticky": # v1 sticky field
234 | kwargs["sticky"] = v
235 | elif k in valid_keys:
236 | kwargs[k] = v
237 | return cls(**kwargs)
238 |
239 | def to_dict(self) -> Dict[str, Any]:
240 | "Exports this message as a serializable dict"
241 | res = {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
242 | res["attachments"] = [a.to_dict() for a in self.attachments]
243 | return res
244 |
245 | @property
246 | def display_indicators(self):
247 | """
248 | Returns indicators for various states (sticky, has attachments, etc.)
249 | for use in thread rendering and similar display scenarios.
250 | """
251 | return "*" * self.sticky + "@" * len(self.attachments)
252 |
253 |
254 | class PopStickyMessageError(Exception):
255 | "Thrown when attempting to pop a Message marked sticky"
256 |
257 | pass
258 |
259 |
260 | class MessageThread(Sequence):
261 | def __init__(
262 | self,
263 | name: str,
264 | messages: Optional[Iterable[Message]] = None,
265 |         names: Optional[Dict[MessageRole, str]] = None,
266 | ):
267 | """A conversation thread
268 |
269 | args:
270 | name: The display name of this thread
271 | messages: An iterable of Message objects from which to populate
272 | this thread
273 | names: Mapping of roles to names that should be set on
274 | future messages added to this thread
275 | """
276 | self.name: str = name
277 | self._messages: List[Message] = (
278 | [dataclasses.replace(m) for m in messages]
279 | if messages is not None
280 | else []
281 | )
282 | self.names: Dict[MessageRole, str] = names if names is not None else {}
283 | self.dirty: bool = False
284 |
285 | @classmethod
286 | def from_dict(cls, d: Dict[str, Any], name: str):
287 | """
288 | Instantiate a MessageThread from a dict in the format returned by
289 | MessageThread.to_dict()
290 | """
291 | messages = [Message.from_dict(m) for m in d.get("messages", [])]
292 | names = d.get("names")
293 | if names:
294 | names = {MessageRole(k): v for k, v in names.items()}
295 | res = cls(name=name, messages=messages, names=names)
296 | return res
297 |
298 | def __repr__(self) -> str:
299 | return f"<{self.name} MessageThread {self._messages!r}>"
300 |
301 | def __getitem__(self, n):
302 | return self._messages[n]
303 |
304 | def __len__(self) -> int:
305 | return len(self._messages)
306 |
307 | @property
308 | def messages(self) -> Tuple[Message, ...]:
309 | return tuple(self._messages)
310 |
311 | @messages.setter
312 | def messages(self, val: Iterable[Message]):
313 | self._messages = list(val)
314 | self.dirty = True
315 |
316 | @property
317 | def stickys(self) -> List[Message]:
318 | return [m for m in self._messages if m.sticky]
319 |
320 | def to_dict(self) -> Dict[str, Any]:
321 | "Exports this thread to a serializable dict."
322 | return {
323 | "messages": [m.to_dict() for m in self._messages],
324 | "names": self.names.copy(),
325 | }
326 |
327 | def append(self, message: Message) -> None:
328 | "Adds a new message to the end of this thread"
329 | if not isinstance(message, Message):
330 | raise TypeError("append requires a Message object")
331 | message.name = self.names.get(message.role)
332 | self._messages.append(message)
333 | self.dirty = True
334 |
335 | def render(
336 | self,
337 | start_index: Optional[int] = None,
338 | end_index: Optional[int] = None,
339 | display_indicators: bool = True,
340 | ) -> str:
341 | """Renders this thread as a human-readable transcript
342 |
343 | args:
344 | start_index: the beginning of the range of messages to render
345 | end_index: the end of the range of messages to render
346 | display_indicators: Output symbols to indicate particular message
347 | states (such as an asterisk for sticky messages)
348 | """
349 | lines = (
350 | (msg.display_indicators if display_indicators else "")
351 | + (msg.name if msg.name is not None else msg.role)
352 | + ": "
353 | + msg.content
354 | for msg in self._messages[start_index:end_index]
355 | )
356 | return "\n".join(lines)
357 |
358 | def pop(self, n: Optional[int] = None) -> Message:
359 | "Remove the nth message from this thread and return it"
360 | if n is None:
361 | n = -1
362 | if self._messages[n].sticky:
363 | raise PopStickyMessageError
364 | res = self._messages.pop(n)
365 | self.dirty = True
366 | return res
367 |
368 | def clear(self) -> None:
369 | "Remove *all* messages (except those marked sticky) from this thread"
370 | if self._messages:
371 | self.dirty = True
372 | self._messages = self.stickys
373 |
374 | def move(self, i: Optional[int], j: Optional[int]) -> Message:
375 | """Pop the message at index i and re-insert it at index j"""
376 | msg = self.pop(i)
377 | if j is None:
378 | j = len(self)
379 | self._messages.insert(j, msg)
380 | return msg
381 |
382 | def rename(
383 | self,
384 | role: MessageRole,
385 | name: str,
386 | start_index: Optional[int] = None,
387 | end_index: Optional[int] = None,
388 | ) -> List[Message]:
389 | """
390 | Changes the name set on all non-sticky messages of the specified role
391 | in this thread. If start_index or end_index is specified, only
392 | messages in the specified range are affected
393 | """
394 | res = []
395 | for msg in self._messages[start_index:end_index]:
396 | if msg.role == role and not msg.sticky:
397 | msg.name = name
398 | res.append(msg)
399 | if res:
400 | self.dirty = True
401 | return res
402 |
403 | def sticky(
404 | self, start_index: Optional[int], end_index: Optional[int], state: bool
405 | ) -> List[Message]:
406 | """
407 | Stickys or unstickys (depending on the state parameter) all messages
408 | in this thread. If start_index or end_index is specified, only
409 | messages in the specified range are affected. Returns a list of
410 | messages affected by this operation.
411 | """
412 | res = []
413 | for m in self._messages[start_index:end_index]:
414 | if m.sticky != state:
415 | m.sticky = state
416 | res.append(m)
417 | if res:
418 | self.dirty = True
419 | return res
420 |
--------------------------------------------------------------------------------
/tests/test_llm.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from gptcmd.llm import InvalidAPIParameterError, LLMProvider, LLMResponse
4 | from gptcmd.message import Message
5 | from typing import Dict, Sequence
6 |
7 |
8 | class CactusProvider(LLMProvider):
9 | @classmethod
10 | def from_config(cls, conf: Dict):
11 | return cls(**conf)
12 |
13 | def complete(self, messages: Sequence[Message]) -> LLMResponse:
14 | return LLMResponse(Message(content="Cactus cactus!", role="assistant"))
15 |
16 | def validate_api_params(self, params):
17 | if "invalid_param" in params:
18 | raise InvalidAPIParameterError("Invalid parameter")
19 | return params
20 |
21 | @property
22 | def valid_models(self):
23 | return ["saguaro-1", "saguaro-2"]
24 |
25 | def get_best_model(self):
26 | return "saguaro-2"
27 |
28 |
29 | class TestLLMProvider(unittest.TestCase):
30 | def setUp(self):
31 | self.llm = CactusProvider()
32 |
33 | def test_init(self):
34 | self.assertEqual(self.llm.model, "saguaro-2")
35 | self.assertEqual(self.llm._api_params, {})
36 | self.assertFalse(self.llm.stream)
37 |
38 | def test_set_api_param_valid(self):
39 | self.llm.set_api_param("temperature", 0.8)
40 | self.assertEqual(self.llm._api_params["temperature"], 0.8)
41 |
42 | def test_set_api_param_invalid(self):
43 | with self.assertRaises(InvalidAPIParameterError):
44 | self.llm.set_api_param("invalid_param", "value")
45 |
46 | def test_update_api_params_valid(self):
47 | self.llm.update_api_params({"temperature": 0.8, "max_tokens": 100})
48 | self.assertEqual(self.llm._api_params["temperature"], 0.8)
49 | self.assertEqual(self.llm._api_params["max_tokens"], 100)
50 |
51 | def test_update_api_params_invalid(self):
52 | with self.assertRaises(InvalidAPIParameterError):
53 | self.llm.update_api_params(
54 | {"temperature": 0.8, "invalid_param": "value"}
55 | )
56 |
57 | def test_complete(self):
58 | messages = [Message(content="Hello", role="user")]
59 | response = self.llm.complete(messages)
60 | self.assertIsInstance(response, LLMResponse)
61 | self.assertEqual(response.message.content, "Cactus cactus!")
62 | self.assertEqual(response.message.role, "assistant")
63 |
64 | def test_default_text_iter(self):
65 | messages = [Message(content="Testing testing", role="user")]
66 | response = self.llm.complete(messages)
67 | self.assertIsInstance(response, LLMResponse)
68 | buf = ""
69 | for chunk in response:
70 | buf += chunk
71 | self.assertEqual(buf, "Cactus cactus!")
72 |
73 | def test_valid_models(self):
74 | self.assertEqual(self.llm.valid_models, ["saguaro-1", "saguaro-2"])
75 |
76 | def test_get_best_model(self):
77 | self.assertEqual(self.llm.get_best_model(), "saguaro-2")
78 |
79 |
80 | if __name__ == "__main__":
81 | unittest.main()
82 |
--------------------------------------------------------------------------------
/tests/test_message.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains unit tests for MessageThread and related objects.
3 | Copyright 2023 Bill Dengler
4 | This Source Code Form is subject to the terms of the Mozilla Public
5 | License, v. 2.0. If a copy of the MPL was not distributed with this
6 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
7 | """
8 |
9 | import unittest
10 | from gptcmd.message import (
11 |     Image,
12 |     Message,
13 |     MessageRole,
14 |     MessageThread,
15 |     PopStickyMessageError,
16 |     UnknownAttachment,
17 | )
18 |
19 |
20 | class TestMessageThreadInit(unittest.TestCase):
21 | def test_init_empty(self):
22 | thread = MessageThread(name="test")
23 | self.assertEqual(thread.name, "test")
24 | self.assertEqual(len(thread), 0)
25 | self.assertEqual(thread.dirty, False)
26 |
27 | def test_init_with_messages(self):
28 | messages = [
29 | Message(content="Hello", role=MessageRole.USER),
30 | Message(content="Hi", role=MessageRole.ASSISTANT),
31 | ]
32 | thread = MessageThread(name="test", messages=messages)
33 | self.assertEqual(len(thread), 2)
34 | self.assertEqual(thread[0].content, "Hello")
35 | self.assertEqual(thread[1].content, "Hi")
36 |
37 | def test_init_with_names(self):
38 | names = {MessageRole.USER: "Alice", MessageRole.ASSISTANT: "Mila"}
39 | thread = MessageThread(name="test", names=names)
40 | self.assertEqual(thread.names, names)
41 |
42 |
43 | class TestMessageThread(unittest.TestCase):
44 | def setUp(self):
45 | self.thread = MessageThread(name="test")
46 |
47 | def test_append(self):
48 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
49 | self.assertEqual(len(self.thread), 1)
50 | self.assertEqual(self.thread[0].content, "Hello")
51 | self.assertEqual(self.thread[0].role, MessageRole.USER)
52 | self.assertTrue(self.thread.dirty)
53 |
54 | def test_render(self):
55 | self.thread.append(
56 | Message(content="What is a cactus?", role=MessageRole.USER)
57 | )
58 | self.thread.append(
59 | Message(
60 | content=(
61 | "A desert plant with thick, fleshy stems, sharp spines,"
62 | " and beautiful, short-lived flowers."
63 | ),
64 | role=MessageRole.ASSISTANT,
65 | )
66 | )
67 | self.assertEqual(
68 | self.thread.render(),
69 | "user: What is a cactus?\nassistant: A desert plant with thick,"
70 | " fleshy stems, sharp spines, and beautiful, short-lived flowers.",
71 | )
72 |
73 | def test_render_custom_names(self):
74 | self.thread.names = {
75 | MessageRole.USER: "Bill",
76 | MessageRole.ASSISTANT: "Kevin",
77 | }
78 | self.thread.append(
79 | Message(content="What is a cactus?", role=MessageRole.USER)
80 | )
81 | self.thread.append(
82 | Message(
83 | content=(
84 | "A desert plant with thick, fleshy stems, sharp spines,"
85 | " and beautiful, short-lived flowers."
86 | ),
87 | role=MessageRole.ASSISTANT,
88 | )
89 | )
90 | self.assertEqual(
91 | self.thread.render(),
92 | "Bill: What is a cactus?\nKevin: A desert plant with thick, fleshy"
93 | " stems, sharp spines, and beautiful, short-lived flowers.",
94 | )
95 |
96 | def test_pop(self):
97 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
98 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
99 | popped = self.thread.pop()
100 | self.assertEqual(len(self.thread), 1)
101 | self.assertEqual(popped.content, "Hi")
102 | self.assertEqual(popped.role, MessageRole.ASSISTANT)
103 | self.thread.pop()
104 | with self.assertRaises(IndexError):
105 | self.thread.pop()
106 |
107 | def test_pop_sticky(self):
108 | self.thread.append(
109 | Message(content="Hello", role=MessageRole.USER, sticky=True)
110 | )
111 | with self.assertRaises(PopStickyMessageError):
112 | self.thread.pop()
113 |
114 | def test_clear(self):
115 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
116 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
117 | self.thread.clear()
118 | self.assertEqual(len(self.thread), 0)
119 |
120 | def test_clear_sticky(self):
121 | self.thread.append(
122 | Message(content="Hello", role=MessageRole.USER, sticky=True)
123 | )
124 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
125 | self.thread.clear()
126 | self.assertEqual(len(self.thread), 1)
127 |
128 | def test_flip(self):
129 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
130 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
131 | flipped = self.thread.move(-1, 0)
132 | self.assertEqual(flipped.content, "Hi")
133 | self.assertEqual(self.thread[0].content, "Hi")
134 | self.assertEqual(self.thread[0].role, MessageRole.ASSISTANT)
135 | self.assertEqual(self.thread[1].content, "Hello")
136 | self.assertEqual(self.thread[1].role, MessageRole.USER)
137 |
138 | def test_rename(self):
139 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
140 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
141 | self.thread.rename(role=MessageRole.ASSISTANT, name="GPT")
142 | self.assertEqual(self.thread[1].name, "GPT")
143 |
144 | def test_rename_limited_range(self):
145 | self.thread.append(Message(content="abc", role=MessageRole.USER))
146 | self.thread.append(Message(content="def", role=MessageRole.ASSISTANT))
147 | self.thread.append(Message(content="ghi", role=MessageRole.USER))
148 | self.thread.append(Message(content="jkl", role=MessageRole.USER))
149 | self.thread.rename(
150 | role=MessageRole.USER, name="Kevin", start_index=0, end_index=2
151 |         )  # end_index is exclusive, so only the user message at 0 is renamed
152 | self.assertEqual(self.thread[0].name, "Kevin")
153 | self.assertIsNone(self.thread[1].name)
154 | self.assertIsNone(self.thread[2].name)
155 | self.assertIsNone(self.thread[3].name)
156 |
157 | def test_sticky(self):
158 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
159 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
160 |         self.thread.sticky(0, 1, True)  # end index is exclusive: just message 0
161 | self.assertTrue(self.thread[0].sticky)
162 | self.assertFalse(self.thread[1].sticky)
163 |
164 | def test_messages_property(self):
165 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
166 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
167 |         messages = self.thread.messages  # returns a tuple snapshot
168 | self.assertIsInstance(messages, tuple)
169 | self.assertEqual(len(messages), 2)
170 | self.assertEqual(messages[0].content, "Hello")
171 | self.assertEqual(messages[1].content, "Hi")
172 |
173 | def test_to_dict(self):
174 | self.thread.append(Message(content="Hello", role=MessageRole.USER))
175 | self.thread.append(Message(content="Hi", role=MessageRole.ASSISTANT))
176 | thread_dict = self.thread.to_dict()
177 | self.assertIn("messages", thread_dict)
178 | self.assertIn("names", thread_dict)
179 | self.assertEqual(len(thread_dict["messages"]), 2)
180 |
181 | def test_from_dict(self):
182 | thread_dict = {
183 | "messages": [
184 | {"content": "Hello", "role": "user"},
185 | {"content": "Hi", "role": "assistant"},
186 | ],
187 | "names": {"user": "Alice", "assistant": "Mila"},
188 | }
189 | thread = MessageThread.from_dict(thread_dict, name="test")
190 | self.assertEqual(thread.name, "test")
191 | self.assertEqual(len(thread), 2)
192 | self.assertEqual(thread[0].content, "Hello")
193 | self.assertEqual(thread[0].role, MessageRole.USER)
194 | self.assertEqual(thread[1].content, "Hi")
195 | self.assertEqual(thread[1].role, MessageRole.ASSISTANT)
196 | self.assertEqual(
197 | thread.names,
198 | {MessageRole.USER: "Alice", MessageRole.ASSISTANT: "Mila"},
199 | )
200 |
201 |
202 | class TestMessage(unittest.TestCase):
203 | def test_message_creation(self):
204 | message = Message(content="Hello", role=MessageRole.USER)
205 | self.assertEqual(message.content, "Hello")
206 | self.assertEqual(message.role, MessageRole.USER)
207 | self.assertIsNone(message.name)
208 | self.assertFalse(message.sticky)
209 | self.assertEqual(message.attachments, [])
210 |
211 | def test_message_with_attachment(self):
212 | image = Image(url="http://example.com/image.jpg")
213 | message = Message(
214 | content="What's in this image?",
215 | role=MessageRole.USER,
216 | attachments=[image],
217 | )
218 | self.assertEqual(len(message.attachments), 1)
219 | self.assertIsInstance(message.attachments[0], Image)
220 |
221 | def test_message_to_dict(self):
222 | message = Message(content="Hello", role=MessageRole.USER, name="Bill")
223 | message_dict = message.to_dict()
224 | self.assertEqual(message_dict["content"], "Hello")
225 | self.assertEqual(message_dict["role"], MessageRole.USER)
226 | self.assertEqual(message_dict["name"], "Bill")
227 |
228 | def test_message_from_dict(self):
229 | message_dict = {
230 | "content": "Hello",
231 | "role": "user",
232 | "name": "Bill",
233 | "sticky": True,
234 | "attachments": [
235 | {
236 | "type": "image_url",
237 | "data": {"url": "http://example.com/image.jpg"},
238 | }
239 | ],
240 | }
241 | message = Message.from_dict(message_dict)
242 | self.assertEqual(message.content, "Hello")
243 | self.assertEqual(message.role, MessageRole.USER)
244 | self.assertEqual(message.name, "Bill")
245 | self.assertTrue(message.sticky)
246 | self.assertEqual(len(message.attachments), 1)
247 | self.assertIsInstance(message.attachments[0], Image)
248 |
249 | def test_message_unknown_attachment(self):
250 | message_dict = {
251 | "content": "",
252 | "role": "user",
253 | "attachments": [
254 | {
255 | "type": "nonexistent_attachment",
256 | "data": {"username": "kwebb"},
257 | }
258 | ],
259 | }
260 | message = Message.from_dict(message_dict)
261 | self.assertEqual(len(message.attachments), 1)
262 | self.assertIsInstance(message.attachments[0], UnknownAttachment)
263 |         serialized_message = message.to_dict()  # round-trips unchanged
264 | self.assertEqual(
265 | message_dict["attachments"][0],
266 | serialized_message["attachments"][0],
267 | )
268 |
269 |
270 | class TestImage(unittest.TestCase):
271 | def test_image_creation(self):
272 | image = Image(url="http://example.com/image.jpg", detail="high")
273 | self.assertEqual(image.url, "http://example.com/image.jpg")
274 | self.assertEqual(image.detail, "high")
275 |
276 | def test_image_to_dict(self):
277 | image = Image(url="http://example.com/image.jpg", detail="high")
278 | image_dict = image.to_dict()
279 | self.assertEqual(image_dict["type"], "image_url")
280 | self.assertEqual(
281 | image_dict["data"]["url"], "http://example.com/image.jpg"
282 | )
283 | self.assertEqual(image_dict["data"]["detail"], "high")
284 |
285 | def test_image_from_dict(self):
286 | image_dict = {
287 | "type": "image_url",
288 | "data": {"url": "http://example.com/image.jpg", "detail": "high"},
289 | }
290 | image = Image.from_dict(image_dict)
291 | self.assertEqual(image.url, "http://example.com/image.jpg")
292 | self.assertEqual(image.detail, "high")
293 |
294 |
295 | if __name__ == "__main__":
296 | unittest.main()
297 |
--------------------------------------------------------------------------------