├── requirements.txt
├── .gitattributes
├── doc
│   └── AssistantTUI.mp4
├── domain
│   ├── tools
│   │   ├── tool.py
│   │   ├── tool_handler.py
│   │   ├── code_interpreter.py
│   │   ├── retrieval.py
│   │   └── function.py
│   ├── steps
│   │   ├── step_handler.py
│   │   ├── message_creation.py
│   │   ├── tool_call.py
│   │   ├── step_list.py
│   │   └── step.py
│   ├── thread_list.py
│   ├── assistant.py
│   ├── message.py
│   ├── thread.py
│   └── run.py
├── main.py
├── .vscode
│   ├── settings.json
│   └── launch.json
├── callable.py
├── log.py
├── app
│   └── terminal
│       ├── thread
│       │   ├── thread_container.py
│       │   ├── thread_list.py
│       │   └── thread_messages.py
│       ├── assistant_app.py
│       ├── app.css
│       └── assistant_container.py
├── README.md
└── .gitignore

/requirements.txt:
--------------------------------------------------------------------------------
1 | textual==0.46.0
2 | openai==1.6.0
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/doc/AssistantTUI.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dalssoft/assistant-tui/HEAD/doc/AssistantTUI.mp4
--------------------------------------------------------------------------------
/domain/tools/tool.py:
--------------------------------------------------------------------------------
1 | class Tool:
2 |     """Base wrapper around one raw tool call taken from a run step.
3 | 
4 |     Subclasses in domain/tools define a ``type`` string, a static
5 |     ``is_type(tool_call)`` predicate and a ``debug()`` formatter;
6 |     ToolHandler.from_raw picks the subclass whose ``is_type`` matches.
7 |     """
8 | 
9 |     def __init__(self, tool_call):
10 |         # Raw tool call as returned by the API (SDK object or plain dict).
11 |         self.tool_call = tool_call
12 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from app.terminal.assistant_app import AssistantApp
2 | 
3 | if __name__ == "__main__":
4 |     app = AssistantApp()
5 |     app.run()
6 | 
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "python.analysis.autoImportCompletions": true,
3 |     "python.analysis.typeCheckingMode": "basic",
4 |     "[python]": {
5 |         "editor.defaultFormatter": "ms-python.black-formatter",
6 |         "editor.formatOnSave": true
7 |     },
8 |     "terminal.integrated.minimumContrastRatio": 1
9 | }
--------------------------------------------------------------------------------
/callable.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import asyncio
3 | 
4 | 
5 | class Callable:
6 |     """Wrap a callback so sync and async callables can be invoked alike.
7 | 
8 |     Calling the wrapper calls ``fn``. A plain result is returned as-is; an
9 |     awaitable result (e.g. from an ``async def``) is scheduled on the event
10 |     loop with ``asyncio.ensure_future`` and the resulting task is returned.
11 |     """
12 | 
13 |     # Strong references to scheduled tasks: the event loop keeps only weak
14 |     # references, so an otherwise unreferenced task may be garbage-collected
15 |     # before it finishes.
16 |     _background_tasks = set()
17 | 
18 |     def __init__(self, fn):
19 |         self.fn = fn
20 | 
21 |     def __call__(self, *args, **kwargs):
22 |         result = self.fn(*args, **kwargs)
23 |         if not inspect.isawaitable(result):
24 |             return result
25 |         task = asyncio.ensure_future(result)
26 |         Callable._background_tasks.add(task)
27 |         # Drop the reference once the task is done so the set does not grow.
28 |         task.add_done_callback(Callable._background_tasks.discard)
29 |         return task
30 | 
--------------------------------------------------------------------------------
/domain/tools/tool_handler.py:
--------------------------------------------------------------------------------
1 | from domain.tools.code_interpreter import CodeInterpreterTool
2 | from domain.tools.function import FunctionTool
3 | from domain.tools.retrieval import RetrievalTool
4 | 
5 | 
6 | class ToolHandler:
7 |     """Factory that wraps a raw tool call in the matching Tool subclass."""
8 | 
9 |     # Candidate wrappers, tried in order until one recognizes the call.
10 |     tool_types = (CodeInterpreterTool, FunctionTool, RetrievalTool)
11 | 
12 |     @staticmethod
13 |     def from_raw(tool_call):
14 |         """Return the Tool wrapper whose ``is_type`` accepts ``tool_call``.
15 | 
16 |         Raises:
17 |             ValueError: if no known Tool subclass matches the call.
18 |         """
19 |         for tool_type in ToolHandler.tool_types:
20 |             if tool_type.is_type(tool_call):
21 |                 return tool_type(tool_call)
22 |         # Fail loudly: returning None would make callers such as
23 |         # ToolCallStep.debug() crash with an obscure AttributeError.
24 |         raise ValueError(f"Unsupported tool call: {tool_call!r}")
25 | 
--------------------------------------------------------------------------------
/domain/tools/code_interpreter.py:
--------------------------------------------------------------------------------
1 | from domain.tools.tool import Tool
2 | 
3 | 
4 | class CodeInterpreterTool(Tool):
5 |     """Tool wrapper for ``code_interpreter`` calls."""
6 | 
7 |     type = "code_interpreter"
8 | 
9 |     @staticmethod
10 |     def is_type(tool_call):
11 |         """Return True if ``tool_call`` (SDK object or dict) is a code interpreter call."""
12 |         call_type = tool_call.type if hasattr(tool_call, "type") else tool_call["type"]
13 |         return call_type == CodeInterpreterTool.type
14 | 
15 |     def debug(self):
16 |         """Format the interpreter input and outputs for the debug view."""
17 |         details = self.tool_call.code_interpreter
18 |         return f"""
19 | | type: {self.type}
20 | | input: {details.input}
21 | | outputs: {details.outputs}
22 | """
23 | 
--------------------------------------------------------------------------------
/domain/tools/retrieval.py:
--------------------------------------------------------------------------------
1 | from domain.tools.tool import Tool
2 | 
3 | 
4 | class RetrievalTool(Tool):
5 |     """Tool wrapper for ``retrieval`` calls."""
6 | 
7 |     type = "retrieval"
8 | 
9 |     @staticmethod
10 |     def is_type(tool_call):
11 |         """Return True if ``tool_call`` (SDK object or dict) is a retrieval call."""
12 |         call_type = tool_call.type if hasattr(tool_call, "type") else tool_call["type"]
13 |         return call_type == RetrievalTool.type
14 | 
15 |     def debug(self):
16 |         """Format the retrieval payload for the debug view."""
17 |         # NOTE(review): openai 1.6.0 appears to type `retrieval` as a plain
18 |         # (currently empty) object, so a `.retrieval.retrieval` lookup would
19 |         # raise AttributeError on a non-empty dict payload. Show the payload
20 |         # itself, unwrapping a nested `retrieval` attribute only if present.
21 |         payload = getattr(self.tool_call, "retrieval", None)
22 |         retrieval = getattr(payload, "retrieval", payload) if payload else None
23 |         return f"""
24 | | type: {self.type}
25 | | retrieval: {retrieval}
26 | """
27 | 
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 |     // Use IntelliSense to learn about possible attributes.
3 |     // Hover to view descriptions of existing attributes.
4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 |     "version": "0.2.0",
6 |     "configurations": [
7 |         {
8 |             "name": "Init Terminal App",
9 |             "type": "python",
10 |             "request": "launch",
11 |             "program": "./main.py",
12 |             "args": [],
13 |             "console": "integratedTerminal",
14 |             "justMyCode": true,
15 |         },
16 |     ]
17 | }
--------------------------------------------------------------------------------
/domain/steps/step_handler.py:
--------------------------------------------------------------------------------
1 | from domain.steps.message_creation import MessageCreationStep
2 | from domain.steps.tool_call import ToolCallStep
3 | 
4 | 
5 | class StepHandler:
6 |     """Factory that wraps a raw run step in the matching Step subclass."""
7 | 
8 |     # Candidate wrappers, tried in order until one recognizes the step.
9 |     step_types = (MessageCreationStep, ToolCallStep)
10 | 
11 |     @staticmethod
12 |     def from_raw(thread_run, thread_run_step):
13 |         """Wrap ``thread_run_step`` (a step of ``thread_run``) in its Step subclass.
14 | 
15 |         The raw step and its current status are copied onto the wrapper.
16 | 
17 |         Raises:
18 |             ValueError: if no known Step subclass matches the step's type.
19 |         """
20 |         for step_type in StepHandler.step_types:
21 |             if step_type.is_type(thread_run_step):
22 |                 break
23 |         else:
24 |             # Fail loudly: an unknown type leaves no wrapper to populate.
25 |             raise ValueError(
26 |                 f"Unsupported step type: {thread_run_step.step_details.type!r}"
27 |             )
28 |         step = step_type(thread_run, thread_run_step)
29 |         step.thread_run_step = thread_run_step
30 |         step.status = thread_run_step.status
31 |         return step
32 | 
--------------------------------------------------------------------------------
/domain/tools/function.py:
--------------------------------------------------------------------------------
1 | from domain.tools.tool import Tool
2 | 
3 | 
4 | def _field(obj, name, default=None):
5 |     """Read ``name`` from a dict or from an attribute-style (SDK) object."""
6 |     if isinstance(obj, dict):
7 |         return obj.get(name, default)
8 |     return getattr(obj, name, default)
9 | 
10 | 
11 | class FunctionTool(Tool):
12 |     """Tool wrapper for ``function`` calls."""
13 | 
14 |     type = "function"
15 | 
16 |     @staticmethod
17 |     def is_type(tool_call):
18 |         """Return True if ``tool_call`` (SDK object or dict) is a function call."""
19 |         call_type = tool_call.type if hasattr(tool_call, "type") else tool_call["type"]
20 |         return call_type == FunctionTool.type
21 | 
22 |     def debug(self):
23 |         """Format the function name, arguments and output for the debug view."""
24 |         # is_type() accepts attribute-style SDK objects as well as dicts, so
25 |         # read fields through _field() instead of subscripting, which only
26 |         # works for dicts.
27 |         function = _field(self.tool_call, "function")
28 |         # NOTE(review): the SDK appears to name the result `output`; the
29 |         # `outputs` key is kept as a fallback for dict-shaped calls.
30 |         output = _field(function, "output")
31 |         if output is None:
32 |             output = _field(function, "outputs")
33 |         return f"""
34 | | type: {self.type}
35 | | function: {_field(function, "name")}
36 | | arguments: {_field(function, "arguments")}
37 | | outputs: {output}
38 | """
39 | 
--------------------------------------------------------------------------------
/domain/steps/message_creation.py:
--------------------------------------------------------------------------------
1 | from domain.steps.step import Step
2 | 
3 | 
4 | class MessageCreationStep(Step):
5 |     type = "message_creation"
6 | 
7 |     def __init__(self, thread_run, thread_run_step):
8 |         super().__init__(thread_run, thread_run_step.id)
9 |         self.thread_run_step = thread_run_step
10 | 
11 |     @staticmethod
12 |     def is_type(step):
13 |         return step.step_details.type == MessageCreationStep.type
14 | 
15 |     def message_id(self):
16 |         return self.thread_run_step.step_details.message_creation.message_id
17 | 
18 |     def debug(self):
19 |         return (
20 |             super().debug()
21 |             + f"""
22 | | message_id: {self.message_id()}
23 | """
24 |         )
25 | 
--------------------------------------------------------------------------------
/domain/steps/tool_call.py:
--------------------------------------------------------------------------------
1 | from domain.steps.step import Step
2 | from domain.tools.tool_handler import ToolHandler
3 | 
4 | 
5 | class ToolCallStep(Step):
6 |     """Run step in which the assistant called one or more tools."""
7 | 
8 |     type = "tool_calls"
9 | 
10 |     def __init__(self, thread_run, thread_run_step):
11 |         super().__init__(thread_run, thread_run_step.id)
12 |         # Keep the raw step, as MessageCreationStep does; Step.debug() reads it.
13 |         self.thread_run_step = thread_run_step
14 |         self.tool_calls = thread_run_step.step_details.tool_calls
15 | 
16 |     @staticmethod
17 |     def is_type(step):
18 |         return step.step_details.type == ToolCallStep.type
19 | 
20 |     def debug(self):
21 |         """Append the debug text of every tool call to the base step debug."""
22 |         debugs = [ToolHandler.from_raw(call).debug() for call in self.tool_calls]
23 |         return (
24 |             super().debug()
25 |             + f"""
26 | | tool_calls: {"".join(debugs)}
27 | """
28 |         )
29 | 
--------------------------------------------------------------------------------
/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | logging.basicConfig(
4 |     level=logging.INFO,
5 |     filemode="a",  # append
6 |     filename="assistant.log",
7 |     format="[%(asctime)s:%(levelname)s:%(name)s]%(message)s",
8 | )
9 | 
10 | 
11 | def log(*messages):
12 |     """Log all messages, one per line, as a single INFO record."""
13 |     logging.info("\n".join(str(message) for message in messages))
14 | 
15 | 
16 | def log_action(instance, action, *messages):
17 |     """Log ``action`` performed by ``instance`` followed by ``messages``.
18 | 
19 |     ``instance`` may be a name string (e.g. "Thread"), a class (e.g.
20 |     ``ThreadList``) or an object; objects with an ``id`` attribute get
21 |     ``:id=<id>`` appended to their label.
22 |     """
23 |     if isinstance(instance, str):
24 |         # instance is already a label (ex: "Thread")
25 |         instance_name = instance
26 |     elif isinstance(instance, type):
27 |         # instance is a class (ex: ThreadList); __class__.__name__ would be "type"
28 |         instance_name = instance.__name__
29 |     else:
30 |         # instance is the self of the class
31 |         instance_name = instance.__class__.__name__
32 |         if hasattr(instance, "id"):
33 |             instance_name += f":id={instance.id}"
34 | 
35 |     logging.info(
36 |         f"[{instance_name}:{action}]" + "\n".join(str(message) for message in messages)
37 |     )
38 | 
--------------------------------------------------------------------------------
/domain/steps/step_list.py:
--------------------------------------------------------------------------------
1 | from log import log_action
2 | from openai import OpenAI
3 | from domain.steps.step_handler import StepHandler
4 | 
5 | 
6 | class StepList:
7 |     """Tracks the steps of one run and notifies watchers about new ones."""
8 | 
9 |     def __init__(self, thread_run):
10 |         self.client = OpenAI()
11 |         self.thread_run = thread_run
12 |         self.steps = []
13 |         self.callbacks = {
14 |             "new_step": [],
15 |         }
16 | 
17 |     async def refresh(self):
18 |         """Fetch the run's steps and await ``new_step`` callbacks for unseen ones."""
19 |         # NOTE(review): synchronous client call; blocks the event loop while
20 |         # the request is in flight.
21 |         steps = self.client.beta.threads.runs.steps.list(
22 |             thread_id=self.thread_run.thread.id,
23 |             run_id=self.thread_run.id,
24 |         )
25 | 
26 |         # Build the set of known ids once rather than a fresh list per step.
27 |         known_ids = {s.id for s in self.steps}
28 |         for step in steps.data:
29 |             if step.id in known_ids:
30 |                 continue
31 |             known_ids.add(step.id)
32 |             new_step = StepHandler.from_raw(self.thread_run, step)
33 |             await self._add_new_step(new_step)
34 | 
35 |         log_action(self, "refresh", steps)
36 | 
37 |         return self
38 | 
39 |     def watch_for_new_step(self, callback):
40 |         self.callbacks["new_step"].append(callback)
41 | 
42 |     async def _on_new_step(self, new_step):
43 |         for callback in self.callbacks["new_step"]:
44 |             await callback(new_step)
45 | 
46 |     async def _add_new_step(self, new_step):
47 |         self.steps.append(new_step)
48 |         await self._on_new_step(new_step)
49 | 
--------------------------------------------------------------------------------
/app/terminal/thread/thread_container.py:
--------------------------------------------------------------------------------
1 | from textual.containers import Container, Horizontal
2 | from textual.reactive import reactive
3 | from app.terminal.thread.thread_list import ThreadListContainer
4 | from app.terminal.thread.thread_messages import ThreadMessagesContainer
5 | 
6 | 
7 | class ThreadContainer(Container):
8 |     assistant = reactive(None)
9 | 
10 |     def compose(self):
11 |         yield Horizontal(
12 |             ThreadListContainer(id="thread_list_container"),
13 |             ThreadMessagesContainer(id="thread_messages_container"),
14 |             id="thread_container_inner",
15 |         )
16 | 
17 |     def watch_assistant(self, assistant):
18 |         thread_list_container = self.query_one("#thread_list_container")
19 |         thread_list_container.assistant = assistant
20 |         thread_messages_container = self.query_one("#thread_messages_container")
21 |         thread_messages_container.assistant = assistant
22 | 
23 |     async def on_thread_list_container_thread_selected(self, event):
24 |         thread_messages_container_old = self.query_one("#thread_messages_container")
25 |         await thread_messages_container_old.remove()
26 | 
27 |         thread_messages_container = ThreadMessagesContainer(
28 |             id="thread_messages_container"
29 |         )
30 | 
31 |         thread_container = self.query_one("#thread_container_inner")
32 |         await thread_container.mount(thread_messages_container)
33 |         thread_messages_container.assistant = self.assistant
34 |         thread_messages_container.thread = event.thread
35 | 
--------------------------------------------------------------------------------
/app/terminal/assistant_app.py:
--------------------------------------------------------------------------------
1 | from textual import on
2 | from textual.app import App, ComposeResult
3 | from textual.widgets import Header, Footer
4 | from log import log_action
5 | from app.terminal.assistant_container import AssistantContainer
6 | from app.terminal.thread.thread_container import ThreadContainer
7 | 
8 | 
9 | class AppState:
10 |     assistant = None
11 |     thread = None
12 | 
13 | 
14 | class AssistantApp(App):
15 |     CSS_PATH = "app.css"
16 |     BINDINGS = [("d", "toggle_dark", "Toggle dark mode"), ("q", "quit", "Quit")]
17 |     TITLE = "Assistant Terminal 🤖"
18 | 
19 |     app_state = AppState()
20 | 
21 |     def compose(self) -> ComposeResult:
22 |         yield Header()
23 |         yield Footer()
24 |         yield AssistantContainer(id="assistant_container")
25 |         yield ThreadContainer(id="thread_container", classes="remove")
26 | 
27 |     def action_toggle_dark(self) -> None:
28 |         """An action to toggle dark mode."""
29 |         self.dark = not self.dark
30 | 
31 |     def action_quit(self) -> None:
32 |         """An action to quit the app."""
33 |         self.exit()
34 | 
35 |     def on_assistant_container_assistant_selected(self, event):
36 |         self.app_state.assistant = event.assistant
37 |         self.query_one("#assistant_container").display = False
38 |         thread_container = self.query_one("#thread_container")
39 |         thread_container.display = True
40 |         thread_container.assistant = event.assistant
41 |         log_action(self, "on_assistant_container_assistant_selected", event.assistant)
42 | 
43 | 
44 | if __name__ == "__main__":
45 |     app = AssistantApp()
46 |     app.run()
47 | 
--------------------------------------------------------------------------------
/domain/thread_list.py:
--------------------------------------------------------------------------------
1 | from domain.thread import Thread
2 | import json
3 | from log import log_action
4 | 
5 | 
6 | class ThreadList:
7 |     """Local JSON-file registry of the threads created under one scope.
8 | 
9 |     Threads are stored as ``{"id": ..., "name": ...}`` records in
10 |     ``threads_<scope>.json``; the terminal app uses the assistant id as scope.
11 |     """
12 | 
13 |     def __init__(self, scope: str):
14 |         self.scope = scope
15 |         self.file_name = f"threads_{self.scope}.json"
16 | 
17 |     def _persist(self, threads):
18 |         """Overwrite the registry file with ``threads`` (dicts or Thread objects)."""
19 |         with open(self.file_name, "w", encoding="utf-8") as file:
20 |             # Thread objects are serialized through Thread.to_json().
21 |             file.write(json.dumps(threads, default=lambda x: x.to_json()))
22 | 
23 |     def _read(self):
24 |         """Return the stored thread records, or [] if none can be loaded."""
25 |         try:
26 |             with open(self.file_name, "r", encoding="utf-8") as file:
27 |                 return json.loads(file.read())
28 |         except (OSError, ValueError):
29 |             # Missing/unreadable file or invalid JSON (JSONDecodeError is a
30 |             # ValueError) means an empty registry; other exceptions such as
31 |             # KeyboardInterrupt propagate.
32 |             return []
33 | 
34 |     def _add_thread(self, thread):
35 |         threads = self._read()
36 |         threads.append(thread)
37 |         self._persist(threads)
38 | 
39 |     def create_thread(self, name):
40 |         thread = Thread(id=None)
41 |         thread.create(name)
42 |         self._add_thread(thread)
43 |         log_action(ThreadList, "create_thread", thread)
44 |         return thread
45 | 
46 |     def remove_thread(self, id):
47 |         threads = self._read()
48 |         log_action(ThreadList, "remove_thread", id)
49 |         for thread in threads:
50 |             if thread["id"] == id:
51 |                 threads.remove(thread)
52 |                 Thread(id=id).delete()
53 |                 self._persist(threads)
54 |                 return True
55 |         return False
56 | 
57 |     def list_all(self):
58 |         threads_data = self._read()
59 |         threads = []
60 |         for thread_data in threads_data:
61 |             thread = Thread(id=thread_data["id"])
62 |             thread.name = thread_data["name"]
63 |             threads.append(thread)
64 |         log_action(ThreadList, "list_all", threads)
65 |         return threads
66 | 
--------------------------------------------------------------------------------
/domain/assistant.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | from log import log_action
3 | 
4 | 
5 | class Assistant:
6 |     def __init__(self, id, assistant=None):
7 |         self.client = OpenAI()
8 |         self.id = id
9 |         self.assistant = assistant
10 |         self.name = assistant.name if assistant else None
11 | 
12 |     @staticmethod
13 |     def create(name, instructions):
14 |         # Create a new assistant
15 |         client = OpenAI()
16 |         my_assistant = client.beta.assistants.create(
17 |             instructions=instructions,
18 |             name=name,
19 |             tools=[{"type": "code_interpreter"}],
20 |             model="gpt-4-1106-preview",
21 |         )
22 |         assistant = Assistant.from_raw(my_assistant)
23 |         log_action("Assistant", "create", assistant)
24 |         return assistant
25 | 
26 |     def retrieve(self):
27 |         my_assistant = self.client.beta.assistants.retrieve(self.id)
28 |         self.assistant = my_assistant
29 |         # Keep the cached name in sync, as __init__ does for a passed-in assistant.
30 |         self.name = my_assistant.name
31 |         log_action(self, "retrieve", my_assistant)
32 |         return self
33 | 
34 |     def delete(self):
35 |         response = self.client.beta.assistants.delete(self.id)
36 |         log_action(self, "delete", response)
37 | 
38 |     @staticmethod
39 |     def list_all():
40 |         client = OpenAI()
41 | 
42 |         my_assistants = client.beta.assistants.list(
43 |             order="desc",
44 |             limit=100,
45 |         )
46 | 
47 |         # transform raw assistants into Assistant objects
48 |         assistants = []
49 |         for assistant in my_assistants.data:
50 |             assistant = Assistant.from_raw(assistant)
51 |             assistants.append(assistant)
52 | 
53 |         log_action("Assistant", "list_all", assistants)
54 | 
55 |         return assistants
56 | 
57 |     @staticmethod
58 |     def from_raw(assistant):
59 |         return Assistant(assistant.id, assistant)
60 | 
--------------------------------------------------------------------------------
/domain/message.py:
--------------------------------------------------------------------------------
1 | from log import log_action
2 | from openai import OpenAI
3 | 
4 | 
5 | class Message:
6 |     def __init__(self, thread, id):
7 |         self.client = OpenAI()
8 |         self.thread = thread
9 |         self.id = id
10 |         self.thread_message = None
11 |         self.content = []
12 |         self.role = None
13 | 
14 |     def _text_contents(self):
15 |         """Return the content parts that carry a ``.text`` payload.
16 | 
17 |         Parts without one (presumably non-text parts such as image files)
18 |         are skipped so text()/annotations() do not raise AttributeError.
19 |         """
20 |         return [content for content in self.content if hasattr(content, "text")]
21 | 
22 |     def text(self):
23 |         """Join the text values of all text parts with single spaces."""
24 |         return " ".join(content.text.value for content in self._text_contents())
25 | 
26 |     def annotations(self):
27 |         """Return one annotations list per text part."""
28 |         # TODO: Currently the API is not returning annotations
29 |         return [content.text.annotations for content in self._text_contents()]
30 | 
31 |     @staticmethod
32 |     def create(thread, text):
33 |         # ValueError is a subclass of Exception, so existing
34 |         # `except Exception` handlers keep working.
35 |         if thread is None:
36 |             raise ValueError("Thread is required")
37 |         if not text:
38 |             raise ValueError("Text is required")
39 | 
40 |         client = OpenAI()
41 |         thread_message = client.beta.threads.messages.create(
42 |             thread.id,
43 |             role="user",
44 |             content=text,
45 |         )
46 |         new_message = Message.from_raw(thread, thread_message)
47 |         log_action(new_message, "create", new_message)
48 |         return new_message
49 | 
50 |     @staticmethod
51 |     def retrieve(thread, id):
52 |         client = OpenAI()
53 |         thread_message = client.beta.threads.messages.retrieve(
54 |             message_id=id,
55 |             thread_id=thread.id,
56 |         )
57 |         message = Message.from_raw(thread, thread_message)
58 |         # Label the entry with the Message wrapper, as create() does.
59 |         log_action(message, "retrieve", thread_message)
60 |         return message
61 | 
62 |     @staticmethod
63 |     def from_raw(thread, thread_message):
64 |         message = Message(thread, thread_message.id)
65 |         message.thread_message = thread_message
66 |         message.content = thread_message.content
67 |         message.role = thread_message.role
68 |         return message
69 | 
--------------------------------------------------------------------------------
/domain/thread.py:
--------------------------------------------------------------------------------
1 | from log import log_action
2 | from openai import OpenAI
3 | from domain.message import Message
4 | 
5 | 
6 | class Thread:
7 |     def __init__(self, id):
8 |         self.client = OpenAI()
9 |         self.id = id
10 |         self.name = None
11 |         self._thread = None
12 |         self.messages = []
13 |         self.callbacks = {
14 |             "new_message": [],
15 |         }
16 | 
17 |     def create(self, name):
18 |         empty_thread = self.client.beta.threads.create()
19 |         self.id = empty_thread.id
20 |         self.name = name
21 |         self._thread = empty_thread
22 |         log_action(self, "create", empty_thread)
23 | 
24 |     def retrieve(self):
25 |         thread = self.client.beta.threads.retrieve(self.id)
26 |         self._thread = thread
27 |         log_action(self, "retrieve", thread)
28 |         return self
29 | 
30 |     def delete(self):
31 |         response = self.client.beta.threads.delete(self.id)
32 |         log_action(self, "delete", response)
33 | 
34 |     def retrieve_messages(
35 |         self,
36 |         limit=None,
37 |         order=None,
38 |         before=None,
39 |         after=None,
40 |     ):
41 |         limit = limit or 100
42 |         order = order or "desc"
43 | 
44 |         thread_messages = self.client.beta.threads.messages.list(
45 |             thread_id=self.id,
46 |             limit=limit,
47 |             order=order,
48 |             before=before,
49 |             after=after,
50 |         )
51 | 
52 |         messages = []
53 |         for thread_message in thread_messages.data:
54 |             message = Message.from_raw(self, thread_message)
55 |             messages.append(message)
56 | 
57 |         self.messages = messages
58 | 
59 |         log_action(self, "retrieve_messages", messages)
60 | 
61 |         return messages
62 | 
63 |     async def retrieve_message_and_append(self, message_id):
64 |         message = Message.retrieve(self, message_id)
65 |         await self._add_new_message(message)
66 |         return message
67 | 
68 |     def to_json(self):
69 |         return {"id": self.id, "name": self.name}
70 | 
71 |     def watch_for_new_message(self, callback):
72 |         self.callbacks["new_message"].append(callback)
73 | 
74 |     async def _on_new_message(self, new_message):
75 |         for callback in self.callbacks["new_message"]:
76 |             await callback(new_message)
77 | 
78 |     async def _add_new_message(self, new_message):
79 |         self.messages.append(new_message)
80 |         await self._on_new_message(new_message)
81 | 
--------------------------------------------------------------------------------
/domain/steps/step.py:
--------------------------------------------------------------------------------
1 | from log import log_action
2 | from openai import OpenAI
3 | import asyncio
4 | import time
5 | 
6 | 
7 | class Step:
8 |     run_status = {
9 |         "in_progress": "in_progress",
10 |         "cancelled": "cancelled",
11 |         "failed": "failed",
12 |         "completed": "completed",
13 |         "expired": "expired",
14 |     }
15 | 
16 |     time_to_wait = 0.1
17 |     time_to_timeout = 120
18 | 
19 |     def __init__(self, thread_run, id=None):
20 |         self.client = OpenAI()
21 |         self.thread_run = thread_run
22 |         self.id = id
23 |         self.thread_run_step = None
24 |         self.status = self.run_status["in_progress"]
25 |         self.callbacks = {
26 |             "status_change": [],
27 |         }
28 | 
29 |     def refresh(self):
30 |         step = self.client.beta.threads.runs.steps.retrieve(
31 |             thread_id=self.thread_run.thread.id,
32 |             run_id=self.thread_run.id,
33 |             step_id=self.id,
34 |         )
35 |         self.thread_run_step = step
36 |         self.status = step.status
37 |         self.type = step.step_details.type
38 |         log_action(self, "refresh", step)
39 |         return self
40 | 
41 |     async def wait_for_completion(self):
42 |         """Poll until the step leaves ``in_progress`` or the timeout elapses.
43 | 
44 |         Awaits the status-change callbacks whenever a refresh reports a new
45 |         status; on timeout the local status is set to ``expired``.
46 |         """
47 |         # monotonic() is not affected by wall-clock adjustments.
48 |         start_time = time.monotonic()
49 |         previous_status = self.status
50 |         while self.status == self.run_status["in_progress"]:
51 |             # handle timeout
52 |             is_timeout = time.monotonic() - start_time > self.time_to_timeout
53 |             if is_timeout:
54 |                 self.status = self.run_status["expired"]
55 |                 log_action(self, "timeout")
56 |                 break
57 | 
58 |             # retrieve and check status
59 |             # NOTE(review): refresh() is a synchronous HTTP call and still
60 |             # blocks the loop while the request is in flight.
61 |             self.refresh()
62 | 
63 |             # handle status change
64 |             if self.status != previous_status:
65 |                 previous_status = self.status
66 |                 await self._on_status_change()
67 | 
68 |             # wait for next iteration without blocking the event loop
69 |             # (time.sleep() here would stall every other task on the loop)
70 |             await asyncio.sleep(self.time_to_wait)
71 | 
72 |         return self
73 | 
74 |     def has_completed(self):
75 |         return self.status != self.run_status["in_progress"]
76 | 
77 |     def watch_for_status_change(self, callback):
78 |         self.callbacks["status_change"].append(callback)
79 | 
80 |     async def _on_status_change(self):
81 |         for callback in self.callbacks["status_change"]:
82 |             await callback(self)
83 | 
84 |     def debug(self):
85 |         return f"""
86 | | id: {self.id}
87 | | status: {self.status}
88 | | type: {self.thread_run_step.step_details.type}"""
89 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ATUI - Assistant Textual User Interface
2 | 
3 | This is a [textual](https://textual.textualize.io/) user interface for OpenAI's [Assistant](https://platform.openai.com/docs/assistants/overview) API.
4 | 
5 | https://github.com/dalssoft/textual-assistant/assets/209287/71d28cd9-1dd0-445d-a644-d62b255fe192
6 | 
7 | # How to use
8 | 
9 | ## Setup
10 | 
11 | ```bash
12 | pyenv install 3.12.0 # install python version
13 | pyenv local 3.12.0 # set python version
14 | python -m venv venv # create virtual environment
15 | source ./venv/bin/activate # activate virtual environment
16 | pip3 install -r requirements.txt # install dependencies
17 | ```
18 | 
19 | ## Run
20 | 
21 | ```bash
22 | export OPENAI_API_KEY=sk-.... # set the OpenAI API key. Alternatively, put it in a .env file and load that file into your shell first (the app itself does not read .env)
23 | python main.py
24 | ```
25 | 
26 | # Features
27 | 
28 | - [x] Textual User Interface
29 |     - [x] Assistant
30 |         - [ ] Create
31 |         - [x] Select
32 |         - [x] Details
33 |             - [x] Info
34 |             - [ ] Functions Details
35 |     - [x] Thread
36 |         - [x] Create
37 |         - [x] Select
38 |         - [x] Details
39 |             - [x] Message
40 |                 - [x] List
41 |                 - [ ] Annotations
42 |                 - [ ] Attachments
43 |             - [x] Tools
44 |                 - [x] Code Interpreter
45 |                 - [x] Retrieve
46 |                 - [ ] Function
47 |             - [x] Debug: Message / Runs / Steps
48 |     - [x] Log
49 | 
50 | - [x] Domain
51 |     - [x] Assistant
52 |         - [x] Create
53 |         - [x] Retrieve
54 |         - [x] Delete
55 |         - [x] List All
56 |         - [ ] Update
57 |         - [ ] Files
58 |     - [x] Thread
59 |         - [x] Create
60 |         - [x] Retrieve
61 |         - [x] Delete
62 |         - [x] List All
63 |         - [ ] Update
64 |         - [x] Retrieve Messages
65 |         - [x] Events
66 |             - [x] New Message
67 |         - [x] Persist
68 |             - [x] Local
69 |             - [ ] DB
70 |     - [x] Message
71 |         - [x] Create
72 |         - [x] Retrieve
73 |         - [ ] Update
74 |         - [ ] Files
75 |     - [x] Run
76 |         - [x] Create
77 |         - [x] Retrieve
78 |         - [x] Cancel
79 |         - [ ] List All
80 |         - [ ] Update
81 |         - [x] Polling
82 |         - [x] Events
83 |             - [x] Status Change
84 |     - [x] Step
85 |         - [x] Types
86 |             - [x] Message Creation
87 |             - [x] Tool Call
88 |         - [x] Retrieve
89 |         - [x] Polling
90 |         - [x] Events
91 |             - [x] New Step
92 |             - [x] Status Change
93 |     - [x] Tools
94 |         - [x] Code Interpreter
95 |         - [x] Retrieve
96 |         - [x] Function
97 |             - [ ] Handle Function Call
98 | 
--------------------------------------------------------------------------------
/app/terminal/thread/thread_list.py:
--------------------------------------------------------------------------------
1 | from textual import on
2 | from textual.containers import ScrollableContainer, Vertical
3 | from textual.widgets import Label, ListView, ListItem, Button, Static
4 | from textual.reactive import reactive
5 | from textual.message import Message
6 | from domain.thread_list import ThreadList
7 | from log import log_action
8 | 
9 | 
10 | class ThreadListContainer(ScrollableContainer):
11 |     """Sidebar listing the selected assistant's threads.
12 | 
13 |     Setting ``assistant`` reloads ``threads`` from the local ThreadList;
14 |     changing ``current_thread`` posts a ``ThreadSelected`` message.
15 |     """
16 | 
17 |     assistant = reactive(None)
18 |     threads = reactive([])
19 |     current_thread = reactive(None)
20 | 
21 |     def compose(self):
22 |         yield Vertical(
23 |             Static(id="assistant_name"),
24 |             Label("Threads:"),
25 |             ListView(id="thread_list"),
26 |             Button("Create Thread", id="create_thread_button"),
27 |         )
28 | 
29 |     @on(Button.Pressed, "#create_thread_button")
30 |     def create_thread_click(self, event):
31 |         self.create_thread()
32 | 
33 |     @on(ListView.Selected, "#thread_list")
34 |     def thread_list_selected(self, event):
35 |         self.select_thread(event.item.id)
36 | 
37 |     def watch_assistant(self, assistant):
38 |         self.update_assistant(assistant)
39 |         self.list_all_threads(assistant)
40 | 
41 |     def watch_threads(self, threads):
42 |         self.update_threads_list(threads)
43 | 
44 |     def watch_current_thread(self, thread):
45 |         self.post_message(self.ThreadSelected(thread))
46 |         log_action(self, "watch_current_thread", thread)
47 | 
48 |     def update_assistant(self, assistant):
49 |         if assistant is None:
50 |             return
51 |         self.query_one("#assistant_name").update(assistant.name)
52 | 
53 |     def update_threads_list(self, threads):
54 |         if not threads:
55 |             return
56 |         list_view = self.query_one("#thread_list")
57 |         list_view.clear()
58 |         for thread in threads:
59 |             list_view.append(ListItem(Label(thread.name), id=thread.id))
60 | 
61 |         log_action(self, "update_threads_list")
62 | 
63 |     def list_all_threads(self, assistant):
64 |         if assistant is None:
65 |             return
66 |         self.threads = ThreadList(assistant.id).list_all()
67 |         log_action(self, "list_all_threads")
68 | 
69 |     def create_thread(self):
70 |         # Threads are stored per assistant, so there is nothing to do before
71 |         # one has been selected (self.assistant.id would raise).
72 |         if self.assistant is None:
73 |             return
74 |         thread_name = "New Thread - " + str(len(self.threads) + 1)
75 |         self.current_thread = ThreadList(self.assistant.id).create_thread(thread_name)
76 |         self.list_all_threads(self.assistant)
77 |         # Highlight the newly created thread in the list.
78 |         list_view = self.query_one("#thread_list")
79 |         for index, item in enumerate(list_view.children):
80 |             if item.id == self.current_thread.id:
81 |                 list_view.index = index
82 |                 break
83 | 
84 |         log_action(self, "create_thread")
85 | 
86 |     def select_thread(self, thread_id):
87 |         selected = next(
88 |             (thread for thread in self.threads if thread.id == thread_id), None
89 |         )
90 |         # Ignore ids that match no loaded thread instead of letting next()
91 |         # raise StopIteration out of the event handler.
92 |         if selected is None:
93 |             return
94 |         self.current_thread = selected
95 |         log_action(self, "select_thread")
96 | 
97 |     class ThreadSelected(Message):
98 |         """Posted when ``current_thread`` changes; carries the Thread."""
99 | 
100 |         def __init__(self, thread):
101 |             super().__init__()
102 |             self.thread = thread
103 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .DS_Store 163 | 164 | threads_*.json -------------------------------------------------------------------------------- /app/terminal/app.css: -------------------------------------------------------------------------------- 1 | /* Terminal App using Textual Python */ 2 | 3 | #assistant_container { 4 | width: 30%; 5 | height: 100%; 6 | border-right: tall gray; 7 | padding: 1; 8 | } 9 | 10 | PlaceholderWithLabel { 11 | height: auto; 12 | margin-top: 1; 13 | } 14 | 15 | AssistantContainer #empty_details Center Static{ 16 | width: auto; 17 | border: tall $surface-lighten-2; 18 | padding-left: 1; 19 | padding-right: 1; 20 | margin: 2; 21 | margin-top: 4; 22 | } 23 | 24 | AssistantContainer #empty_details Static { 25 | margin: 1; 26 | } 27 | 28 | AssistantContainer .label { 29 | text-style: bold underline; 30 | width: auto; 31 | } 32 | 33 | PlaceholderWithLabel .placeholder { 34 | background: rgb(22, 22, 22); 35 | width: 100%; 36 | height: auto; 37 | margin-left: 1; 38 | } 39 | 40 | AssistantContainer .tools { 41 | height: auto; 42 | margin-top: 1; 43 | } 44 | 45 | AssistantContainer ToolSwitch { 46 | align: left middle; 47 | height: auto; 48 | margin-top: 1; 49 | } 50 | 51 | AssistantContainer ToolSwitch Switch { 52 | width: auto; 53 | } 54 | 55 | AssistantContainer ToolSwitch Middle { 56 | height: 100%; 57 | } 58 | 59 | AssistantContainer ToolSwitch Static { 60 | width: auto; 61 | } 62 | 63 | AssistantContainer .tools PlaceholderWithLabel { 64 | height: auto; 65 | margin-left: 10; 66 | } 67 | 68 | Button { 69 | margin-top: 1; 70 | padding-left: 2; 71 | padding-right: 2; 72 | } 73 | 74 | .remove { 75 | display: none; 76 | } 77 | 78 | .singleline { 79 | layout: horizontal; 80 | } 81 | 82 | .multiline { 83 | layout: vertical; 84 | } 85 | 86 | ThreadListContainer { 87 | align: left top; 88 | width: 30%; 89 | border-right: tall gray; 90 | padding: 1; 91 | } 92 | 93 | ThreadListContainer #create_thread_button { 94 | dock: bottom; 95 | width: 100%; 96 
| height: auto; 97 | } 98 | 99 | ThreadListContainer #assistant_name { 100 | dock: top; 101 | padding-left: 2; 102 | margin: 1; 103 | border: tall $surface-lighten-2; 104 | } 105 | 106 | ThreadListContainer Label { 107 | text-style: bold underline; 108 | } 109 | 110 | ThreadListContainer ListView { 111 | width: 100%; 112 | height: 100%; 113 | } 114 | 115 | ThreadListContainer ListItem { 116 | padding: 1; 117 | } 118 | 119 | ThreadListContainer ListItem Label { 120 | text-style: none; 121 | } 122 | 123 | ThreadMessagesContainer { 124 | width: 70%; 125 | height: 100%; 126 | padding-left: 4; 127 | padding-right: 4; 128 | } 129 | 130 | ThreadMessagesContainer NewMessage { 131 | width: 93%; 132 | height: 7; 133 | padding: 1; 134 | } 135 | 136 | ThreadMessagesContainer NewMessage TextArea { 137 | height: 100%; 138 | border: tall black; 139 | } 140 | 141 | ThreadMessagesContainer NewMessage Button { 142 | height: 100%; 143 | margin-left: 1; 144 | margin-right: 1; 145 | } 146 | 147 | ThreadMessagesContainer NewMessage Checkbox { 148 | margin-top: 1; 149 | } 150 | 151 | ThreadMessagesContainer MessageList { 152 | align: center bottom; 153 | } 154 | 155 | ThreadMessagesContainer MessageList ThreadMessage { 156 | height: auto; 157 | layout: horizontal; 158 | } 159 | 160 | ThreadMessagesContainer MessageList ThreadMessage #message_container { 161 | height: auto; 162 | layout: vertical; 163 | align: left middle; 164 | } 165 | 166 | ThreadMessagesContainer MessageList .avatar { 167 | margin-top: 1; 168 | width: auto; 169 | } 170 | 171 | ThreadMessagesContainer MessageList ThreadMessage .message { 172 | width: auto; 173 | max-width: 80%; 174 | border: round gray; 175 | } 176 | 177 | ThreadMessagesContainer MessageList UserMessage .message { 178 | border: round rgb(107, 198, 107); 179 | } 180 | 181 | ThreadMessagesContainer MessageList DebugMessage .message { 182 | border: round rgb(100, 100, 100); 183 | color: rgb(100, 100, 100); 184 | } 185 | 186 | ThreadMessagesContainer 
MessageList DebugMessage .avatar { 187 | color: rgb(100, 100, 100); 188 | } -------------------------------------------------------------------------------- /domain/run.py: -------------------------------------------------------------------------------- 1 | from log import log_action 2 | from openai import OpenAI 3 | from threading import Thread 4 | from domain.steps.step_list import StepList 5 | from domain.steps.message_creation import MessageCreationStep 6 | import time 7 | import asyncio 8 | 9 | 10 | class Run: 11 | run_status = { 12 | "queued": "queued", 13 | "in_progress": "in_progress", 14 | "requires_action": "requires_action", 15 | "cancelling": "cancelling", 16 | "cancelled": "cancelled", 17 | "failed": "failed", 18 | "completed": "completed", 19 | "expired": "expired", 20 | } 21 | 22 | def __init__(self, assistant, thread, id=None): 23 | self.client = OpenAI() 24 | self.assistant = assistant 25 | self.thread = thread 26 | self.id = id 27 | self.thread_run = None 28 | self.step_list = StepList(self) 29 | self.step_list.watch_for_new_step(self._on_new_step) 30 | self.status = self.run_status["queued"] 31 | self.callbacks = { 32 | "status_change": [], 33 | "new_step": [], 34 | "step_status_change": [], 35 | } 36 | 37 | async def _update_from_raw(self, thread_run): 38 | self.thread_run = thread_run 39 | previous_status = self.status 40 | self.status = thread_run.status 41 | if previous_status != self.status: 42 | await self._on_status_change() 43 | 44 | async def create(self): 45 | run = self.client.beta.threads.runs.create( 46 | thread_id=self.thread.id, 47 | assistant_id=self.assistant.id, 48 | ) 49 | self.id = run.id 50 | await self._update_from_raw(run) 51 | 52 | log_action(self, "create", run) 53 | 54 | return self 55 | 56 | async def refresh(self): 57 | run = self.client.beta.threads.runs.retrieve( 58 | thread_id=self.thread.id, 59 | run_id=self.id, 60 | ) 61 | self.thread_run = run 62 | await self._update_from_raw(run) 63 | 64 | log_action(self, 
"refresh", run) 65 | 66 | return self 67 | 68 | async def cancel(self): 69 | run = self.client.beta.threads.runs.cancel( 70 | thread_id=self.thread.id, 71 | run_id=self.id, 72 | ) 73 | await self._update_from_raw(run) 74 | 75 | log_action(self, "cancel", run) 76 | 77 | return self 78 | 79 | def watch_for_status_change(self, callback): 80 | self.callbacks["status_change"].append(callback) 81 | 82 | async def _on_status_change(self): 83 | for callback in self.callbacks["status_change"]: 84 | await callback(self) 85 | 86 | def watch_for_new_step(self, callback): 87 | self.callbacks["new_step"].append(callback) 88 | 89 | def watch_for_step_status_change(self, callback): 90 | self.callbacks["step_status_change"].append(callback) 91 | 92 | async def _on_new_step(self, new_step): 93 | for callback in self.callbacks["new_step"]: 94 | await callback(new_step) 95 | 96 | for callback in self.callbacks["step_status_change"]: 97 | new_step.watch_for_status_change(callback) 98 | 99 | await new_step.wait_for_completion() 100 | 101 | if isinstance(new_step, MessageCreationStep) and new_step.has_completed(): 102 | await self.thread.retrieve_message_and_append(new_step.message_id()) 103 | 104 | async def wait_for_completion(self): 105 | thread = Thread(target=asyncio.run, args=(self._polling(),)) 106 | thread.start() 107 | # await self._polling() 108 | 109 | async def _polling(self): 110 | run_status_to_watch = [ 111 | self.run_status["queued"], 112 | self.run_status["in_progress"], 113 | self.run_status["requires_action"], 114 | self.run_status["cancelling"], 115 | ] 116 | 117 | while self.status in run_status_to_watch: 118 | log_action(self, "_polling") 119 | await self.refresh() 120 | await self.step_list.refresh() 121 | time.sleep(0.1) 122 | -------------------------------------------------------------------------------- /app/terminal/assistant_container.py: -------------------------------------------------------------------------------- 1 | from os import name 2 | from typing 
import Container 3 | from textual import on 4 | from textual.app import App, ComposeResult 5 | from textual.containers import ScrollableContainer, Container 6 | from textual.widgets import Header, Footer, Static, Select, Label, Button, Switch 7 | from textual.containers import Horizontal, Vertical, Middle, Center 8 | from textual.message import Message 9 | from domain.assistant import Assistant 10 | from log import log_action 11 | from textual.reactive import reactive 12 | import webbrowser 13 | 14 | 15 | class PlaceholderWithLabel(Static): 16 | def __init__(self, label, placeholder_id, placeholder_text, classes="singleline"): 17 | self.label = label 18 | self.placeholder_id = placeholder_id 19 | self.placeholder_text = placeholder_text 20 | super().__init__(classes=classes) 21 | 22 | def compose(self) -> ComposeResult: 23 | yield Label(self.label, classes="label") 24 | yield Label( 25 | self.placeholder_text, id=self.placeholder_id, classes="placeholder" 26 | ) 27 | 28 | 29 | class ToolSwitch(Horizontal): 30 | value = reactive(False) 31 | switch = None 32 | 33 | def __init__(self, label, id): 34 | self.label = label 35 | super().__init__(id=id) 36 | 37 | def compose(self): 38 | self.switch = Switch(value=self.value, disabled=True) 39 | yield self.switch 40 | yield Middle(Static(self.label)) 41 | 42 | def watch_value(self, value): 43 | if self.switch is None: 44 | return 45 | self.switch.value = value 46 | 47 | 48 | class AssistantContainer(ScrollableContainer): 49 | assistants = [] 50 | assistant = None 51 | 52 | def compose(self) -> ComposeResult: 53 | yield Select( 54 | options=self.assistants, 55 | id="assistant_name", 56 | prompt="Select an assistant", 57 | ) 58 | with Vertical(id="empty_details"): 59 | link = "https://platform.openai.com/assistants" 60 | yield Center(Static("[b]Assistant Text User Interface[/b]")) 61 | yield Static("Select an existing assistant from the list above.") 62 | yield Static( 63 | f"Or you can create an assistant using the OpenAI 
platform on the web: {link}" 64 | ) 65 | 66 | with Vertical(id="assistant_details", classes="remove"): 67 | yield PlaceholderWithLabel( 68 | "Instructions:", 69 | "instructions", 70 | "...", 71 | classes="multiline", 72 | ) 73 | yield PlaceholderWithLabel("Model:", "model", "") 74 | yield Vertical( 75 | Label("Tools:", classes="label"), 76 | ToolSwitch( 77 | "Code Interpreter", 78 | id="code_interpreter", 79 | ), 80 | ToolSwitch( 81 | "Retrival", 82 | id="retrieval", 83 | ), 84 | ToolSwitch( 85 | "Function", 86 | id="function", 87 | ), 88 | PlaceholderWithLabel("Functions:", "functions", ""), 89 | classes="tools", 90 | ) 91 | yield PlaceholderWithLabel("Files:", "files", "") 92 | yield Center(Button("Use this assistant", id="use_assistant")) 93 | 94 | def action_open_url(self, url): 95 | webbrowser.open(url) 96 | 97 | def on_mount(self): 98 | self.list_all_assistants() 99 | 100 | @on(Select.Changed, "#assistant_name") 101 | def assistant_name_select_changed(self, event): 102 | empty_details = self.query_one("#empty_details") 103 | empty_details.add_class("remove") 104 | assistant_details = self.query_one("#assistant_details") 105 | assistant_details.remove_class("remove") 106 | self.select_assistant(event.value) 107 | 108 | @on(Button.Pressed, "#use_assistant") 109 | def use_assistant_button_clicked(self, event): 110 | event.stop() 111 | self.use_assistant() 112 | 113 | def list_all_assistants(self): 114 | self.assistants = Assistant.list_all() 115 | assistants = list( 116 | (assistant.name, assistant.id) for assistant in self.assistants 117 | ) 118 | self.query_one("#assistant_name").set_options(assistants) 119 | log_action(self, "list_all_assistants", self.assistants) 120 | 121 | def select_assistant(self, assistant_id): 122 | assistant = next( 123 | (assistant for assistant in self.assistants if assistant.id == assistant_id) 124 | ) 125 | self.assistant = assistant 126 | inner_assistant = assistant.assistant 127 | 
self.query_one("#instructions").update(inner_assistant.instructions) 128 | self.query_one("#model").update(inner_assistant.model) 129 | self.query_one("#code_interpreter").value = any( 130 | tool.type == "code_interpreter" for tool in inner_assistant.tools 131 | ) 132 | self.query_one("#retrieval").value = any( 133 | tool.type == "retrieval" for tool in inner_assistant.tools 134 | ) 135 | self.query_one("#function").value = any( 136 | tool.type == "function" for tool in inner_assistant.tools 137 | ) 138 | functions = list( 139 | tool.function.name 140 | for tool in inner_assistant.tools 141 | if tool.type == "function" 142 | ) 143 | self.query_one("#functions").update(", ".join(functions)) 144 | 145 | self.query_one("#files").update(", ".join(inner_assistant.file_ids)) 146 | log_action(self, "select_assistant", assistant) 147 | 148 | def use_assistant(self): 149 | self.add_class("assistant_selected") 150 | self.post_message(self.AssistantSelected(self.assistant)) 151 | log_action(self, "use_assistant", self.assistant) 152 | 153 | class AssistantSelected(Message): 154 | def __init__(self, assistant): 155 | super().__init__() 156 | self.assistant = assistant 157 | -------------------------------------------------------------------------------- /app/terminal/thread/thread_messages.py: -------------------------------------------------------------------------------- 1 | from ast import Add 2 | from concurrent.futures import thread 3 | from email import message 4 | from textual import on 5 | from textual.containers import ScrollableContainer 6 | from textual.widgets import Button, Static, TextArea, Checkbox 7 | from textual.containers import Vertical, Container, Horizontal, Center 8 | from textual.reactive import reactive 9 | from textual.message import Message 10 | from domain.message import Message as Msg 11 | from domain.run import Run 12 | from log import log_action 13 | import time 14 | 15 | 16 | class ThreadMessage(Container): 17 | avatar = "" 18 | 
container_classes = "" 19 | title = "" 20 | 21 | def __init__(self, text, **kwargs): 22 | self.text = text 23 | super().__init__(**kwargs) 24 | 25 | def compose(self): 26 | message = Static(self.text, classes="message") 27 | message.border_title = self.title 28 | yield Container( 29 | Static(self.avatar, classes="avatar"), 30 | message, 31 | classes=self.container_classes, 32 | id="message_container", 33 | ) 34 | 35 | 36 | class AssistantMessage(ThreadMessage): 37 | def __init__(self, text, annotations, **kwargs): 38 | self.annotations = annotations 39 | super().__init__(text, **kwargs) 40 | 41 | def compose(self): 42 | self.avatar = "🤖 Assistant" 43 | self.title = "" 44 | self.container_classes = "assistant" 45 | yield from super().compose() 46 | 47 | 48 | class UserMessage(ThreadMessage): 49 | def compose(self): 50 | self.avatar = "👤 You" 51 | self.title = "" 52 | self.container_classes = "user" 53 | yield from super().compose() 54 | 55 | 56 | class DebugMessage(ThreadMessage): 57 | def compose(self): 58 | self.avatar = f"🐞 Debug" 59 | self.title = "" 60 | self.container_classes = "debug" 61 | yield from super().compose() 62 | 63 | 64 | class MessageList(ScrollableContainer): 65 | thread = None 66 | 67 | async def clear(self): 68 | await self.remove_children() 69 | 70 | async def add_user_message(self, message): 71 | user_message = UserMessage(message.text(), id=message.id) 72 | await self.mount(user_message) 73 | return user_message 74 | 75 | async def add_assistant_message(self, message): 76 | assistant_message = AssistantMessage( 77 | message.text(), message.annotations(), id=message.id 78 | ) 79 | await self.mount(assistant_message) 80 | return assistant_message 81 | 82 | async def add_debug_message(self, text): 83 | current_date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 84 | text = f"{text}\n{current_date}" 85 | debug_message = DebugMessage(text) 86 | await self.mount(debug_message) 87 | return debug_message 88 | 89 | async def fill(self, 
thread): 90 | if thread is None: 91 | return 92 | self.thread = thread 93 | await self.clear() 94 | self.loading = True 95 | messages = thread.retrieve_messages(order="asc") 96 | for message in messages: 97 | add_message = { 98 | "user": self.add_user_message, 99 | "assistant": self.add_assistant_message, 100 | } 101 | await add_message[message.role](message) 102 | 103 | self.scroll_end(animate=True, duration=0.2) 104 | self.loading = False 105 | 106 | 107 | class NewMessage(ScrollableContainer): 108 | def compose(self): 109 | text_area = TextArea(id="message_textarea") 110 | text_area.show_line_numbers = False 111 | yield Horizontal( 112 | Checkbox("Debug", id="debug_checkbox"), 113 | text_area, 114 | Button("Send", id="send_message_button"), 115 | classes="new_message_container", 116 | ) 117 | 118 | @on(Button.Pressed, "#send_message_button") 119 | def send_message_click(self, event): 120 | self.send_message() 121 | 122 | def send_message(self): 123 | text = self.query_one("#message_textarea").text 124 | self.post_message(self.MessageSent(text)) 125 | self.query_one("#message_textarea").text = "" 126 | 127 | @on(Checkbox.Changed, "#debug_checkbox") 128 | def debug_checkbox_changed(self, event): 129 | self.post_message(self.DebugChanged(event.checkbox.value)) 130 | 131 | class MessageSent(Message): 132 | def __init__(self, text): 133 | super().__init__() 134 | self.text = text 135 | 136 | class DebugChanged(Message): 137 | def __init__(self, debug): 138 | super().__init__() 139 | self.debug = debug 140 | 141 | 142 | class ThreadMessagesContainer(ScrollableContainer): 143 | assistant = reactive(None) 144 | thread = reactive(None) 145 | is_debug = reactive(False) 146 | 147 | def compose(self): 148 | yield MessageList(id="message_list") 149 | yield Center(NewMessage(id="new_message")) 150 | 151 | def message_list(self): 152 | return self.query_one("#message_list") 153 | 154 | async def on_new_message_message_sent(self, event): 155 | await 
self.new_message(self.assistant, self.thread, event.text) 156 | 157 | async def on_new_message_debug_changed(self, event): 158 | self.is_debug = event.debug 159 | 160 | async def watch_thread(self, thread): 161 | if thread is None: 162 | self.add_class("remove") 163 | return 164 | self.remove_class("remove") 165 | thread.watch_for_new_message(self._on_new_message) 166 | await self.fill_message_list(thread) 167 | 168 | async def fill_message_list(self, thread): 169 | if thread is None: 170 | return 171 | await self.message_list().fill(thread) 172 | 173 | async def new_message(self, assistant, thread, text): 174 | button = self.query_one("#send_message_button") 175 | button.loading = True 176 | button.disabled = True 177 | message = Msg.create(thread, text) 178 | ui_user_message = await self.message_list().add_user_message(message) 179 | ui_user_message.scroll_visible(animate=True) 180 | run = Run(assistant, thread) 181 | run.watch_for_status_change(self._on_status_change) 182 | run.watch_for_new_step(self._on_new_step) 183 | run.watch_for_step_status_change(self._on_step_status_change) 184 | await run.create() 185 | await run.wait_for_completion() 186 | 187 | async def _on_new_message(self, message): 188 | if self.is_debug: 189 | await self.message_list().add_debug_message( 190 | f"""New message 191 | | id: {message.id} 192 | | role: {message.role} 193 | """ 194 | ) 195 | self.message_list().scroll_end() 196 | 197 | message_list = self.message_list() 198 | await message_list.add_assistant_message(message) 199 | message_list.scroll_end(animate=True, duration=0.2) 200 | button = self.query_one("#send_message_button") 201 | button.loading = False 202 | button.disabled = False 203 | 204 | async def _on_status_change(self, thread_run): 205 | if self.is_debug: 206 | await self.message_list().add_debug_message( 207 | f"""Run status changed 208 | | id: {thread_run.id} 209 | | status: {thread_run.status} 210 | """ 211 | ) 212 | self.message_list().scroll_end() 213 | 214 | 
async def _on_new_step(self, thread_run_step): 215 | if self.is_debug: 216 | await self.message_list().add_debug_message( 217 | f"""New step 218 | {thread_run_step.debug()} 219 | """ 220 | ) 221 | self.message_list().scroll_end() 222 | 223 | async def _on_step_status_change(self, thread_run_step): 224 | if self.is_debug: 225 | await self.message_list().add_debug_message( 226 | f"""Step status changed 227 | {thread_run_step.debug()} 228 | """ 229 | ) 230 | self.message_list().scroll_end() 231 | --------------------------------------------------------------------------------