├── .gitattributes ├── assistants_api ├── app │ ├── __init__.py │ ├── lib │ │ ├── db │ │ │ ├── __init__.py │ │ │ ├── database.py │ │ │ ├── schemas.py │ │ │ └── models.py │ │ ├── fs │ │ │ ├── schemas.py │ │ │ ├── store.py │ │ │ └── actions.py │ │ ├── mb │ │ │ ├── actions.py │ │ │ └── broker.py │ │ └── wv │ │ │ ├── client.py │ │ │ └── actions.py │ ├── models │ │ ├── __init__.py │ │ └── thread.py │ ├── routers │ │ ├── __init__.py │ │ ├── runsteps_router.py │ │ ├── ops │ │ │ ├── run_ops_router.py │ │ │ ├── runsteps_ops_router.py │ │ │ └── web_retrieval_ops_router.py │ │ ├── threads_router.py │ │ ├── file_router.py │ │ ├── message_router.py │ │ ├── run_router.py │ │ ├── assistant_router.py │ │ └── vectorstore_router.py │ ├── utils │ │ ├── __init__.py │ │ ├── document_loader.py │ │ ├── tranformers.py │ │ └── crawling.py │ └── main.py ├── assets │ ├── test.pdf │ ├── openapi-generator.py │ ├── my_information.txt │ ├── test.txt │ └── code-reference.txt ├── openai-1.26.0-py3-none-any.whl ├── Dockerfile ├── requirements.txt └── tests │ ├── ops │ ├── test_runs_ops.py │ ├── test_web_crawling_ops.py │ └── test_runsteps_ops.py │ ├── test_threads.py │ ├── test_files.py │ ├── test_assistant.py │ ├── test_messages.py │ └── test_run_function_calling.py ├── run_executor_worker ├── src │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── context.py │ │ ├── weaviate_utils.py │ │ ├── tools.py │ │ ├── coala.py │ │ ├── ops_api_handler.py │ │ └── openai_clients.py │ ├── data_models │ │ ├── __init__.py │ │ ├── runstep.py │ │ └── run.py │ ├── run_executor │ │ ├── __init__.py │ │ └── main.py │ ├── constants.py │ ├── actions │ │ ├── web_retrieval.py │ │ ├── file_search.py │ │ └── function_calling_tool.py │ ├── consumer.py │ └── agents │ │ └── router.py ├── openai-1.26.0-py3-none-any.whl ├── Dockerfile ├── requirements.txt └── scripts │ └── watcher.py ├── .dockerignore ├── .gitignore ├── pyproject.toml ├── .flake8 ├── .pre-commit-config.yaml ├── .env.example ├── LICENSE ├── rundev.sh ├── docker-compose.dev.yml ├── README.md └── examples └── compounding_demo.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf -------------------------------------------------------------------------------- /assistants_api/app/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assistants_api/app/lib/db/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assistants_api/app/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assistants_api/app/routers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assistants_api/app/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /run_executor_worker/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /run_executor_worker/src/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /run_executor_worker/src/data_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /run_executor_worker/src/run_executor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # backend 2 | # __pycache__ 3 | .venv 4 | venv 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__ 3 | .env* 4 | .pytest_cache 5 | !.env.example -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | skip-string-normalization = true -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203, W503 4 | exclude = .git,__pycache__,old,build,dist,.venv 5 | -------------------------------------------------------------------------------- /assistants_api/app/lib/fs/schemas.py: -------------------------------------------------------------------------------- 1 | from openai.types import FileObject, FileDeleted 2 | 3 | FileObject 4 | FileDeleted 5 | -------------------------------------------------------------------------------- /assistants_api/assets/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGPTs-platform/assistants-api/HEAD/assistants_api/assets/test.pdf -------------------------------------------------------------------------------- /assistants_api/openai-1.26.0-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGPTs-platform/assistants-api/HEAD/assistants_api/openai-1.26.0-py3-none-any.whl -------------------------------------------------------------------------------- /run_executor_worker/openai-1.26.0-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGPTs-platform/assistants-api/HEAD/run_executor_worker/openai-1.26.0-py3-none-any.whl -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.12.1 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/PyCQA/flake8 7 | rev: 6.1.0 8 | hooks: 9 | - id: flake8 10 | -------------------------------------------------------------------------------- /assistants_api/app/lib/mb/actions.py: -------------------------------------------------------------------------------- 1 | from lib.mb.broker import RabbitMQBroker 2 | 3 | 4 | def enqueue_run(broker: RabbitMQBroker, queue_name: str, run_id: str): 5 | """Enqueue a run ID to a specified queue.""" 6
| broker.publish(queue_name=queue_name, message=run_id) 7 | -------------------------------------------------------------------------------- /run_executor_worker/src/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pydantic import BaseModel 3 | 4 | 5 | class PromptKeys(Enum): 6 | TRANSITION = "" 7 | 8 | 9 | class WebRetrievalResult(BaseModel): 10 | url: str 11 | content: str 12 | depth: int 13 | -------------------------------------------------------------------------------- /assistants_api/app/lib/wv/client.py: -------------------------------------------------------------------------------- 1 | import weaviate 2 | import os 3 | 4 | WEAVIATE_HOST = os.getenv("WEAVIATE_HOST") 5 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 6 | 7 | client = weaviate.connect_to_local( 8 | host=WEAVIATE_HOST, 9 | port=8080, 10 | grpc_port=50051, 11 | headers={ 12 | "X-OpenAI-Api-Key": os.getenv("OPENAI_API_KEY"), 13 | }, 14 | ) 15 | -------------------------------------------------------------------------------- /run_executor_worker/src/utils/context.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | 3 | 4 | def context_trimmer( 5 | item_list: List[Any], max_length: int, trim_start: bool 6 | ) -> List[Any]: 7 | def calculate_length(items: List[Any]) -> int: 8 | return len(str(items)) 9 | 10 | trimmed_list = item_list[:] 11 | 12 | while calculate_length(trimmed_list) > max_length: 13 | if trim_start: 14 | trimmed_list.pop(0) 15 | else: 16 | trimmed_list.pop() 17 | 18 | return trimmed_list 19 | -------------------------------------------------------------------------------- /assistants_api/assets/openapi-generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsonref 3 | import requests 4 | import yaml 5 | 6 | url = 'https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml' 7 | response = requests.get(url) 8 | yaml_data = yaml.safe_load(response.text) 9 | 10 | string_json_data = json.dumps(yaml_data) 11 | 12 | dereferenced = jsonref.loads(string_json_data) 13 | 14 | # save prettified in final dereferenced file 15 | with open('openai-openapi-dereferenced.json', 'w') as f: 16 | f.write(json.dumps(dereferenced, indent=2)) 17 | -------------------------------------------------------------------------------- /assistants_api/app/lib/fs/store.py: -------------------------------------------------------------------------------- 1 | import os 2 | from minio import Minio 3 | 4 | ACCESS_KEY = os.getenv('MINIO_ACCESS_KEY') 5 | SECRET_KEY = os.getenv('MINIO_SECRET_KEY') 6 | MINIO_ENDPOINT = os.getenv('MINIO_ENDPOINT') 7 | 8 | BUCKET_NAME = "store" 9 | 10 | 11 | # dependency 12 | def minio_client(): 13 | minio_client = Minio( 14 | "minio:9000", 15 | access_key=ACCESS_KEY, 16 | secret_key=SECRET_KEY, 17 | secure=False, 18 | ) 19 | # Create bucket if it doesn't exist 20 | found = minio_client.bucket_exists(BUCKET_NAME) 21 | if not found: 22 | minio_client.make_bucket(BUCKET_NAME) 23 | return minio_client 24 | -------------------------------------------------------------------------------- /run_executor_worker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use the official Python image from the Docker Hub 2 | FROM python:3.10-slim 3 | 4 | # Set the working directory 5 | WORKDIR /app 6 | 7 | # Install the necessary system dependencies 8 | RUN 
apt-get update && apt-get install -y libpq-dev gcc git 9 | 10 | # Copy the requirements file into the container 11 | COPY requirements.txt . 12 | COPY openai-1.26.0-py3-none-any.whl . 13 | 14 | # Install the dependencies 15 | RUN pip install --no-cache-dir -r requirements.txt 16 | 17 | # Copy the rest of the application code into the container 18 | COPY . . 19 | 20 | # Set environment variables 21 | ENV PYTHONUNBUFFERED=1 22 | 23 | # Command to run the watcher script 24 | CMD ["python", "scripts/watcher.py"] 25 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Postgres 2 | POSTGRES_HOST=postgres 3 | POSTGRES_PORT=5432 4 | POSTGRES_USER=devuser 5 | POSTGRES_PASSWORD=devpass 6 | POSTGRES_DB=mydatabase 7 | 8 | # Minio 9 | MINIO_ENDPOINT=minio 10 | MINIO_ACCESS_KEY=access_key 11 | MINIO_SECRET_KEY=secret_key 12 | 13 | # RabbitMQ 14 | RABBITMQ_DEFAULT_USER=devuser 15 | RABBITMQ_DEFAULT_PASS=devpass 16 | RABBITMQ_HOST=rabbitmq 17 | RABBITMQ_PORT=5672 18 | 19 | # Weaviate 20 | WEAVIATE_HOST=localhost 21 | 22 | # Used for the embedding model and/or testing against the OpenAI Assistants API 23 | # NEEDED vv 24 | OPENAI_API_KEY= 25 | 26 | # Generator (NEEDS some combination of these) 27 | LITELLM_API_URL= 28 | LITELLM_API_KEY= 29 | LITELLM_MODEL= 30 | 31 | # Function Calling Generator (currently) 32 | # NEEDED vv 33 | FC_API_URL= 34 | FC_API_KEY= 35 | # NEEDED vv 36 | FC_MODEL= -------------------------------------------------------------------------------- /assistants_api/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM python:3.10-slim 3 | 4 | # Set the working directory inside the container 5 | WORKDIR / 6 | 7 | # Install the necessary system dependencies 8 | RUN apt-get update && apt-get install -y libpq-dev gcc git 9 | 10 | # Copy the requirements file into the container 11 | COPY requirements.txt . 12 | COPY openai-1.26.0-py3-none-any.whl . 13 | 14 | # Install any needed packages specified in requirements.txt 15 | RUN pip install --no-cache-dir -r requirements.txt 16 | 17 | # Copy the current directory contents into the container at /app 18 | COPY . .
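# Note: requirements.txt references the bundled wheel (./openai-1.26.0-py3-none-any.whl),
# which is why the wheel is copied into the image before `pip install` runs.
# The repo-level .dockerignore excludes .venv/venv, keeping this full-context COPY lean.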
19 | 20 | # Make port 8000 available to the world outside this container 21 | EXPOSE 8000 22 | 23 | # Navigate to the app directory (if your FastAPI app is inside an 'app' folder) 24 | WORKDIR ./app 25 | 26 | # Run main.py when the container launches 27 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] 28 | -------------------------------------------------------------------------------- /run_executor_worker/requirements.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.6.0 2 | anyio==4.3.0 3 | Authlib==1.3.0 4 | certifi==2024.2.2 5 | cffi==1.16.0 6 | cfgv==3.4.0 7 | charset-normalizer==3.3.2 8 | colorama==0.4.6 9 | cryptography==3.4.7 10 | distlib==0.3.8 11 | distro==1.9.0 12 | exceptiongroup==1.2.0 13 | filelock==3.14.0 14 | grpcio==1.62.2 15 | grpcio-health-checking==1.62.2 16 | grpcio-tools==1.62.2 17 | h11==0.14.0 18 | httpcore==1.0.5 19 | httpx==0.27.0 20 | identify==2.5.36 21 | idna==3.6 22 | nodeenv==1.8.0 23 | ./openai-1.26.0-py3-none-any.whl 24 | pika==1.3.2 25 | platformdirs==4.2.2 26 | pre-commit==3.7.1 27 | protobuf==4.25.3 28 | pycparser==2.22 29 | pydantic==2.6.4 30 | pydantic_core==2.16.3 31 | python-dateutil==2.9.0.post0 32 | python-dotenv==1.0.1 33 | PyYAML==6.0.1 34 | requests==2.31.0 35 | sniffio==1.3.1 36 | tqdm==4.66.2 37 | typing_extensions==4.11.0 38 | urllib3==2.2.1 39 | validators==0.28.0 40 | virtualenv==20.26.2 41 | watchfiles==0.21.0 42 | weaviate-client==4.5.6 43 | -------------------------------------------------------------------------------- /run_executor_worker/src/data_models/runstep.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Literal, Optional, Dict, Any, Union 3 | from openai.types.beta.threads.runs import ( 4 | RunStep, 5 | MessageCreationStepDetails, 6 | ToolCallsStepDetails, 7 | ) 8 | 9 | RunStep 10 | 11 | StepDetails = Union[MessageCreationStepDetails, ToolCallsStepDetails] 12 | 13 | 14 | class RunStepUpdate(BaseModel): 15 | assistant_id: Optional[str] = None 16 | cancelled_at: Optional[int] = None 17 | completed_at: Optional[int] = None 18 | expired_at: Optional[int] = None 19 | failed_at: Optional[int] = None 20 | last_error: Optional[Dict[str, Any]] = None 21 | metadata: Optional[Dict[str, Any]] = None 22 | status: Literal[ 23 | "in_progress", "cancelled", "failed", "completed", "expired" 24 | ] = None 25 | step_details: StepDetails = None 26 | type: Literal["message_creation", "tool_calls"] = None 27 | usage: Optional[Dict[str, Any]] = None 28 | -------------------------------------------------------------------------------- /assistants_api/app/routers/runsteps_router.py: -------------------------------------------------------------------------------- 1 | # routers/run_steps.py 2 | from fastapi import APIRouter, Depends 3 | from sqlalchemy.orm import Session 4 | from utils.tranformers import db_to_pydantic_runstep 5 | from lib.db import crud, schemas 6 | from lib.db.database import get_db 7 | 8 | router = APIRouter() 9 | 10 | 11 | @router.get( 12 | "/threads/{thread_id}/runs/{run_id}/steps", 13 | response_model=schemas.SyncCursorPage[schemas.RunStep], 14 | ) 15 | def get_run_steps( 16 | thread_id: str, 17 | run_id: str, 18 | limit: int = 20, 19 | order: str = "desc", 20 | after: str = None, 21 | before: str = None, 22 | db: Session = Depends(get_db), 23 | ): 24 | db_run_steps = crud.get_run_steps( 25 | db, thread_id, run_id, limit, order, after, before 26 | ) 
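    # db_to_pydantic_runstep (utils/tranformers.py) drops SQLAlchemy's
    # _sa_instance_state and maps the _metadata column back to "metadata"
    # before each row is validated against the pydantic RunStep schema.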
27 | 28 | run_steps = [db_to_pydantic_runstep(run_step) for run_step in db_run_steps] 29 | 30 | paginated_run_steps = schemas.SyncCursorPage(data=run_steps) 31 | 32 | return paginated_run_steps 33 | -------------------------------------------------------------------------------- /assistants_api/app/lib/db/database.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sqlalchemy import create_engine 3 | from sqlalchemy.ext.declarative import declarative_base 4 | from sqlalchemy.orm import sessionmaker 5 | 6 | POSTGRES_USER = os.getenv("POSTGRES_USER") 7 | POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") 8 | POSTGRES_HOST = os.getenv("POSTGRES_HOST") 9 | POSTGRES_PORT = os.getenv("POSTGRES_PORT") 10 | POSTGRES_DB = os.getenv("POSTGRES_DB") 11 | 12 | database_url = ( 13 | f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@" 14 | + f"{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}" 15 | ) 16 | 17 | engine = create_engine(database_url) 18 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 19 | 20 | Base = declarative_base() 21 | 22 | 23 | # Dependency 24 | def get_db(): 25 | db = SessionLocal() 26 | try: 27 | yield db 28 | finally: 29 | db.close() 30 | 31 | 32 | def reset_db(): 33 | Base.metadata.drop_all(bind=engine) 34 | Base.metadata.create_all(bind=engine) 35 | return True 36 | -------------------------------------------------------------------------------- /assistants_api/app/routers/ops/run_ops_router.py: -------------------------------------------------------------------------------- 1 | # In your FastAPI router file 2 | from fastapi import APIRouter, Body, Depends, HTTPException, Path 3 | from sqlalchemy.orm import Session 4 | from lib.db import ( 5 | crud, 6 | schemas, 7 | database, 8 | ) # Import your CRUD handlers, schemas, and models 9 | from utils.tranformers import db_to_pydantic_run 10 | 11 | router = APIRouter() 12 | 13 | 14 | @router.post( 15 | "/ops/threads/{thread_id}/runs/{run_id}", response_model=schemas.Run 16 | ) 17 | def update_run( 18 | thread_id: str = Path(..., title="The ID of the thread"), 19 | run_id: str = Path(..., title="The ID of the run to update"), 20 | run_update: schemas.RunUpdate = Body(..., title="The fields to update"), 21 | db: Session = Depends(database.get_db), 22 | ): 23 | db_run = crud.update_run( 24 | db, 25 | thread_id=thread_id, 26 | run_id=run_id, 27 | run_update=run_update.model_dump(exclude_none=True), 28 | ) 29 | if db_run is None: 30 | raise HTTPException(status_code=404, detail="Run not found") 31 | return db_to_pydantic_run(db_run) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OPENGPT LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assistants_api/app/lib/mb/broker.py: -------------------------------------------------------------------------------- 1 | import pika 2 | import os 3 | 4 | RABBITMQ_DEFAULT_USER = os.getenv("RABBITMQ_DEFAULT_USER") 5 | RABBITMQ_DEFAULT_PASS = os.getenv("RABBITMQ_DEFAULT_PASS") 6 | RABBITMQ_HOST = os.getenv("RABBITMQ_HOST") 7 | RABBITMQ_PORT = os.getenv("RABBITMQ_PORT") 8 | 9 | 10 | class RabbitMQBroker: 11 | def __init__(self): 12 | credentials = pika.PlainCredentials( 13 | RABBITMQ_DEFAULT_USER, RABBITMQ_DEFAULT_PASS 14 | ) 15 | self.connection = pika.BlockingConnection( 16 | pika.ConnectionParameters( 17 | host=RABBITMQ_HOST, 18 | port=RABBITMQ_PORT, 19 | credentials=credentials, 20 | heartbeat=30, 21 | ) 22 | ) 23 | self.channel = self.connection.channel() 24 | 25 | def publish(self, queue_name: str, message: str): 26 | self.channel.queue_declare(queue=queue_name, durable=True) 27 | self.channel.basic_publish( 28 | exchange='', 29 | routing_key=queue_name, 30 | body=message, 31 | properties=pika.BasicProperties( 32 | delivery_mode=1, 33 | ), 34 | ) 35 | 36 | def close_connection(self): 37 | self.connection.close() 38 | 39 | 40 | def get_broker(): 41 | return RabbitMQBroker() 42 | -------------------------------------------------------------------------------- /rundev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Load .env file from current directory 4 | if [ -f .env ]; then 5 | echo "Loading .env file from current directory" 6 | export $(cat .env | sed 's/#.*//g' | xargs) 7 | fi 8 | 9 | # Set Docker Compose HTTP timeout 10 | export COMPOSE_HTTP_TIMEOUT=200 11 | 12 | # Function to bring up Docker Compose with retries 13 | function start_docker_compose { 14 | local retries=3 15 | local count=0 16 | 17 | while [ $count -lt $retries ]; do 18 | docker-compose -f docker-compose.dev.yml down 19 | docker-compose -f docker-compose.dev.yml up 20 | if [ $? -eq 0 ]; then 21 | echo "Docker Compose started successfully" 22 | return 0 23 | else 24 | echo "Error starting Docker Compose, retrying... ($((count+1))/$retries)" 25 | count=$((count+1)) 26 | sleep 10 27 | fi 28 | done 29 | 30 | echo "Failed to start Docker Compose after $retries attempts" 31 | return 1 32 | } 33 | 34 | # Start Docker Compose with retries 35 | if start_docker_compose; then 36 | # Sleep for a specified duration to allow services to initialize 37 | sleep 100 38 | # Show the logs and follow them 39 | docker-compose -f docker-compose.dev.yml logs -f 40 | else 41 | echo "Exiting script due to failure in starting Docker Compose." 
42 | exit 1 43 | fi 44 | -------------------------------------------------------------------------------- /assistants_api/app/lib/wv/actions.py: -------------------------------------------------------------------------------- 1 | import weaviate.classes as wvc 2 | from lib.wv.client import client as weaviate_client 3 | from utils.document_loader import DocumentLoader 4 | from weaviate.collections import Collection 5 | 6 | 7 | def id_to_string(id: int) -> str: 8 | # need to remove all the - from the uuid 9 | return str(id).replace("-", "") 10 | 11 | 12 | def create_collection(name: str) -> Collection: 13 | collection = weaviate_client.collections.create( 14 | name=id_to_string(name), 15 | vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai(), 16 | generative_config=wvc.config.Configure.Generative.openai(), 17 | ) 18 | 19 | return collection 20 | 21 | 22 | def delete_collection(name: str) -> None: 23 | weaviate_client.collections.delete(name=id_to_string(name)) 24 | 25 | 26 | def upload_file_chunks( 27 | file_data: bytes, file_name: str, file_id: str, vector_store_id: str 28 | ) -> int: 29 | document = DocumentLoader(file_data=file_data, file_name=file_name) 30 | document.read() 31 | chunks = document.split(text_length=300, text_overlap=100) 32 | 33 | collection = weaviate_client.collections.get( 34 | name=id_to_string(vector_store_id) 35 | ) 36 | 37 | data = [{"text": chunk, "file_id": file_id} for chunk in chunks] 38 | collection.data.insert_many(data) 39 | 40 | return len(chunks) 41 | -------------------------------------------------------------------------------- /assistants_api/app/models/thread.py: -------------------------------------------------------------------------------- 1 | # from enum import Enum 2 | # from openai.types.beta.assistant import ( 3 | # ToolCodeInterpreter, 4 | # ToolRetrieval, 5 | # ) 6 | 7 | # # from openai.types.beta.thread import Thread 8 | # from openai.types.beta.threads import ThreadMessage as OpenaiThreadMessage 9 | # from openai.types.beta.threads.runs import RunStep 10 | # from pydantic import BaseModel 11 | # from typing import Dict, List, Union 12 | 13 | # Tool = Union[ToolCodeInterpreter, ToolRetrieval] 14 | 15 | 16 | # class ThreadMessage(OpenaiThreadMessage): 17 | # pass 18 | 19 | 20 | # class RunStatus(Enum): 21 | # QUEUED = "queued" 22 | # IN_PROGRESS = "in_progress" 23 | # REQUIRES_ACTION = "requires_action" 24 | # CANCELLING = "cancelling" 25 | # CANCELLED = "cancelled" 26 | # FAILED = "failed" 27 | # COMPLETED = "completed" 28 | # EXPIRED = "expired" 29 | 30 | 31 | # # Thread 32 | # class ThreadMetadata(BaseModel): 33 | # gpt_id: str = None 34 | # user_id: str = None 35 | # title: str = None 36 | # last_updated: str = None 37 | 38 | 39 | # class CustomThread(BaseModel): 40 | # id: str 41 | # created_at: int 42 | # metadata: ThreadMetadata 43 | 44 | 45 | # class UpsertCustomThread(BaseModel): 46 | # gpt_id: str 47 | 48 | 49 | # class CreateThreadMessage(BaseModel): 50 | # content: str 51 | 52 | 53 | # class CreateThread(BaseModel): 54 | # title: str 55 | 56 | 57 | # class MessagesRunStepResponse(BaseModel): 58 | # messages: List[ThreadMessage] 59 | # runs_steps: Dict[str, List[RunStep]] 60 | -------------------------------------------------------------------------------- /assistants_api/assets/my_information.txt: -------------------------------------------------------------------------------- 1 | - **Lead Satellite Systems Engineer at Orbital Dynamics:** Directed the design and deployment of communication and weather satellites, 
improving global data coverage and transmission capabilities. 2 | 3 | - **Developed Advanced Earth Observation Satellites:** Spearheaded the creation of high-resolution Earth observation satellites, providing critical data for climate monitoring, disaster response, and agricultural management. 4 | 5 | - **Architect of Satellite Constellation Networks:** Designed and implemented a network of small satellites (CubeSats) for space debris tracking and management, significantly reducing collision risks in low Earth orbit. 6 | 7 | - **Pioneered Interplanetary Mission Designs:** Played a key role in the conceptualization and planning of missions to Mars and beyond, focusing on spacecraft navigation and landing technologies for harsh planetary environments. 8 | 9 | - **Conducted Research on Space Weather Effects:** Led a research team studying the impact of space weather on satellite operations, contributing to the development of more robust satellite designs that withstand solar and cosmic radiation. 10 | 11 | - **Developed Satellite Ground Control Software:** Created a comprehensive software suite for satellite ground control operations, enhancing the efficiency and reliability of satellite monitoring and control from Earth-based stations. 12 | 13 | - **Recipient of the Space Technology Excellence Award:** Honored for significant contributions to satellite technology and space exploration, particularly in advancing satellite propulsion and interplanetary mission design. -------------------------------------------------------------------------------- /run_executor_worker/scripts/watcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import subprocess 4 | import platform 5 | from watchfiles import watch 6 | 7 | 8 | class Watcher: 9 | def __init__(self, command): 10 | self.command = command 11 | self.process = None 12 | 13 | def run(self): 14 | print("Current directory:", os.getcwd()) 15 | self.start_process() 16 | try: 17 | for changes in watch('.', recursive=True): 18 | print(f"Changes detected: {changes}") 19 | self.restart() 20 | except KeyboardInterrupt: 21 | self.stop_process() 22 | print("Shutting down gracefully...") 23 | 24 | def start_process(self): 25 | if platform.system() == 'Windows': 26 | self.process = subprocess.Popen( 27 | self.command, 28 | shell=True, 29 | creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, 30 | ) 31 | else: 32 | self.process = subprocess.Popen( 33 | self.command, shell=True, preexec_fn=os.setsid 34 | ) 35 | 36 | def stop_process(self): 37 | if self.process: 38 | if platform.system() == 'Windows': 39 | self.process.send_signal(signal.CTRL_BREAK_EVENT) 40 | else: 41 | os.killpg(os.getpgid(self.process.pid), signal.SIGTERM) 42 | self.process.wait() 43 | 44 | def restart(self): 45 | print("File change detected. 
Restarting process...") 46 | self.stop_process() 47 | self.start_process() 48 | 49 | 50 | if __name__ == "__main__": 51 | command = "python src/consumer.py" 52 | watcher = Watcher(command) 53 | watcher.run() 54 | -------------------------------------------------------------------------------- /run_executor_worker/src/data_models/run.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Literal, Optional, List, Any, Union 3 | from enum import Enum 4 | from openai.types.beta.threads.run import Run, RequiredAction 5 | from openai.types.beta.threads.runs import ( 6 | RunStep, 7 | MessageCreationStepDetails, 8 | ToolCallsStepDetails, 9 | ) 10 | 11 | Run 12 | RunStep 13 | RequiredAction 14 | 15 | 16 | class RunUpdate(BaseModel): 17 | assistant_id: Optional[str] = None 18 | cancelled_at: Optional[int] = None 19 | completed_at: Optional[int] = None 20 | expires_at: Optional[int] = None 21 | failed_at: Optional[int] = None 22 | file_ids: Optional[List[str]] = None 23 | instructions: Optional[str] = None 24 | last_error: Optional[Any] = None 25 | model: Optional[str] = None 26 | started_at: Optional[int] = None 27 | status: Optional[str] = None 28 | tools: Optional[Any] = None 29 | usage: Optional[Any] = None 30 | required_action: Optional[RequiredAction] = None 31 | 32 | 33 | class RunStatus(str, Enum): 34 | QUEUED = "queued" 35 | IN_PROGRESS = "in_progress" 36 | REQUIRES_ACTION = "requires_action" 37 | CANCELLING = "cancelling" 38 | CANCELLED = "cancelled" 39 | FAILED = "failed" 40 | COMPLETED = "completed" 41 | EXPIRED = "expired" 42 | 43 | 44 | StepDetails = Union[MessageCreationStepDetails, ToolCallsStepDetails] 45 | 46 | 47 | class RunStepCreate(BaseModel): 48 | # Define the fields required for creating a RunStep 49 | assistant_id: str 50 | type: Literal["message_creation", "tool_calls"] 51 | status: Literal[ 52 | "in_progress", "cancelled", "failed", "completed", "expired" 53 | ] 54 | step_details: StepDetails 55 | -------------------------------------------------------------------------------- /run_executor_worker/src/utils/weaviate_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import weaviate 3 | import os 4 | import math 5 | 6 | WEAVIATE_HOST = os.getenv("WEAVIATE_HOST") 7 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 8 | weaviate_client = weaviate.connect_to_local( 9 | host=WEAVIATE_HOST, 10 | port=8080, 11 | grpc_port=50051, 12 | headers={ 13 | "X-OpenAI-Api-Key": OPENAI_API_KEY, 14 | }, 15 | ) 16 | 17 | LIMIT = 2 18 | 19 | 20 | def id_to_string(id: int) -> str: 21 | # need to remove all the - from the uuid 22 | return str(id).replace("-", "") 23 | 24 | 25 | def retrieve_file_chunks(vector_store_ids: List[str], query: str) -> List[str]: 26 | chunks = [] 27 | for vector_store_id in vector_store_ids: 28 | collection = None 29 | if weaviate_client.collections.exists( 30 | name=id_to_string(vector_store_id) 31 | ): 32 | collection = weaviate_client.collections.get( 33 | name=id_to_string(vector_store_id) 34 | ) 35 | else: 36 | raise Exception(f"Collection {vector_store_id} does not exist.") 37 | 38 | query_result = collection.query.near_text( 39 | query=query, 40 | limit=math.ceil(LIMIT / len(vector_store_ids)), 41 | ) 42 | print("RETRIEVE FILE CHUNKS: ", query_result) 43 | 44 | chunks.extend( 45 | [ 46 | chunk.properties["text"] 47 | for chunk in query_result.objects 48 | ] 49 | )
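        # Each store contributes ceil(LIMIT / n) chunks, so with LIMIT = 2 a
        # single store yields 2 chunks, two stores yield 1 each, and three or
        # more stores can slightly overshoot (e.g. 3 stores x 1 chunk = 3 > LIMIT).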
50 | 51 | return chunks 52 | 53 | 54 | def get_web_retrieval_description() -> str: 55 | collection = weaviate_client.collections.get(name="web_retrieval") 56 | config = collection.config.get() 57 | return config.description 58 | -------------------------------------------------------------------------------- /assistants_api/assets/test.txt: -------------------------------------------------------------------------------- 1 | This is a text file containing text. 2 | Here is a second line of text. 3 | 123 4 | 5 | Lorem ipsum dolor sit amet, consectetur adipiscing elit 6 | sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. 7 | Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 8 | Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 9 | Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. 10 | Lorem ipsum dolor sit amet, consectetur adipiscing elit 11 | sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. 12 | Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 13 | Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 14 | Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. 15 | Lorem ipsum dolor sit amet, consectetur adipiscing elit 16 | sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. 17 | Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 18 | Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 19 | Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. 20 | Lorem ipsum dolor sit amet, consectetur adipiscing elit 21 | sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. 22 | Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 23 | Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. 24 | Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. 
25 | -------------------------------------------------------------------------------- /assistants_api/app/utils/document_loader.py: -------------------------------------------------------------------------------- 1 | import fitz 2 | from langchain_text_splitters import RecursiveCharacterTextSplitter 3 | 4 | 5 | class DocumentLoader: 6 | def __init__(self, file_data: bytes, file_name: str): 7 | self.file_data = file_data 8 | self.file_name = file_name 9 | self.text = "" 10 | 11 | def read(self): 12 | # Determine the file type and process accordingly 13 | if self.file_name.endswith('.pdf'): 14 | self.text = self._read_pdf() 15 | elif self.file_name.endswith('.txt'): 16 | self.text = self._read_text() 17 | else: 18 | raise ValueError("Unsupported file type") 19 | 20 | return self.text 21 | 22 | def _read_pdf(self): 23 | text = "" 24 | try: 25 | # Load PDF from bytes 26 | doc = fitz.open("pdf", self.file_data) 27 | for page in doc: 28 | text += page.get_text() 29 | doc.close() 30 | except Exception as e: 31 | raise RuntimeError(f"Failed to process PDF: {e}") 32 | return text 33 | 34 | def _read_text(self): 35 | try: 36 | # Decode binary data assuming UTF-8 encoding 37 | text = self.file_data.decode('utf-8') 38 | except Exception as e: 39 | raise RuntimeError(f"Failed to process text file: {e}") 40 | return text 41 | 42 | def split(self, text_length: int, text_overlap: int): 43 | if not self.text: 44 | raise ValueError( 45 | "No text available. Ensure 'read' is called first." 46 | ) 47 | 48 | text_splitter = RecursiveCharacterTextSplitter( 49 | chunk_size=text_length, 50 | chunk_overlap=text_overlap, 51 | ) 52 | documents = text_splitter.create_documents([self.text]) 53 | texts = [doc.page_content for doc in documents] 54 | return texts 55 | -------------------------------------------------------------------------------- /assistants_api/app/routers/ops/runsteps_ops_router.py: -------------------------------------------------------------------------------- 1 | # In your FastAPI router file 2 | from fastapi import APIRouter, Body, Depends, HTTPException, Path 3 | from sqlalchemy.orm import Session 4 | from lib.db import ( 5 | crud, 6 | schemas, 7 | database, 8 | ) # Import your CRUD handlers, schemas, and models 9 | from utils.tranformers import db_to_pydantic_runstep 10 | 11 | router = APIRouter() 12 | 13 | 14 | # TODO: improve test to actually inspect the run steps 15 | @router.post( 16 | "/ops/threads/{thread_id}/runs/{run_id}/steps", 17 | response_model=schemas.RunStep, 18 | ) 19 | def create_run_step( 20 | thread_id: str = Path(..., title="The ID of the thread"), 21 | run_id: str = Path(..., title="The ID of the run"), 22 | run_step: schemas.RunStepCreate = Body(..., title="Run step details"), 23 | db: Session = Depends(database.get_db), 24 | ): 25 | # Logic to create a run step 26 | db_run_step = crud.create_run_step( 27 | db=db, thread_id=thread_id, run_id=run_id, run_step=run_step 28 | ) 29 | if db_run_step is None: 30 | raise HTTPException(status_code=500, detail="Run step creation failed") 31 | return db_to_pydantic_runstep(db_run_step) 32 | 33 | 34 | @router.post( 35 | "/ops/threads/{thread_id}/runs/{run_id}/steps/{step_id}", 36 | response_model=schemas.RunStep, 37 | ) 38 | def update_run_step( 39 | thread_id: str = Path(..., title="The ID of the thread"), 40 | run_id: str = Path(..., title="The ID of the run"), 41 | step_id: str = Path(..., title="The ID of the run step to update"), 42 | run_step_update: schemas.RunStepUpdate = Body( 43 | ..., title="Fields to update" 44 | ), 45 | db: 
Session = Depends(database.get_db), 46 | ): 47 | db_run_step = crud.update_run_step( 48 | db=db, 49 | thread_id=thread_id, 50 | run_id=run_id, 51 | step_id=step_id, 52 | run_step_update=run_step_update.model_dump(exclude_none=True), 53 | ) 54 | if db_run_step is None: 55 | raise HTTPException(status_code=404, detail="Run step not found") 56 | return db_to_pydantic_runstep(db_run_step) 57 | -------------------------------------------------------------------------------- /assistants_api/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.5 2 | aiosignal==1.3.1 3 | annotated-types==0.6.0 4 | anyio==4.2.0 5 | argon2-cffi==23.1.0 6 | argon2-cffi-bindings==21.2.0 7 | async-timeout==4.0.3 8 | attrs==23.2.0 9 | Authlib==1.3.0 10 | beautifulsoup4==4.12.3 11 | certifi==2024.2.2 12 | cffi==1.16.0 13 | cfgv==3.4.0 14 | charset-normalizer==3.3.2 15 | click==8.1.7 16 | colorama==0.4.6 17 | cryptography==42.0.5 18 | dataclasses-json==0.6.4 19 | distlib==0.3.8 20 | distro==1.9.0 21 | dnspython==2.6.0 22 | email-validator==2.1.0.post1 23 | exceptiongroup==1.2.0 24 | fastapi==0.109.2 25 | filelock==3.13.1 26 | frozenlist==1.4.1 27 | greenlet==3.0.3 28 | grpcio==1.62.2 29 | grpcio-health-checking==1.62.2 30 | grpcio-tools==1.62.2 31 | h11==0.14.0 32 | httpcore==1.0.3 33 | httptools==0.6.1 34 | httpx==0.27.0 35 | identify==2.5.35 36 | idna==3.6 37 | iniconfig==2.0.0 38 | itsdangerous==2.1.2 39 | Jinja2==3.1.3 40 | jsonpatch==1.33 41 | jsonpointer==2.4 42 | jsonref==1.1.0 43 | langchain==0.1.16 44 | langchain-community==0.0.34 45 | langchain-core==0.1.46 46 | langchain-text-splitters==0.0.1 47 | langsmith==0.1.51 48 | lxml==5.2.1 49 | MarkupSafe==2.1.5 50 | marshmallow==3.21.1 51 | minio==7.2.4 52 | multidict==6.0.5 53 | mypy-extensions==1.0.0 54 | nodeenv==1.8.0 55 | numpy==1.26.4 56 | ./openai-1.26.0-py3-none-any.whl 57 | orjson==3.9.14 58 | packaging==23.2 59 | pika==1.3.2 60 | platformdirs==4.2.0 61 | pluggy==1.4.0 62 | pre-commit==3.6.2 63 | protobuf==4.25.3 64 | psycopg2==2.9.9 65 | pycparser==2.21 66 | pycryptodome==3.20.0 67 | pydantic==2.6.1 68 | pydantic-extra-types==2.5.0 69 | pydantic-settings==2.2.0 70 | pydantic_core==2.16.2 71 | PyMuPDF==1.24.2 72 | PyMuPDFb==1.24.1 73 | pytest==8.0.1 74 | pytest-dependency==0.6.0 75 | python-dotenv==1.0.1 76 | python-multipart==0.0.9 77 | PyYAML==6.0.1 78 | requests==2.31.0 79 | sniffio==1.3.0 80 | SQLAlchemy==2.0.27 81 | starlette==0.36.3 82 | tenacity==8.2.3 83 | tomli==2.0.1 84 | tqdm==4.66.2 85 | typing-inspect==0.9.0 86 | typing_extensions==4.9.0 87 | ujson==5.9.0 88 | urllib3==2.2.1 89 | uvicorn==0.27.1 90 | validators==0.22.0 91 | virtualenv==20.25.0 92 | watchfiles==0.21.0 93 | weaviate-client==4.5.5 94 | websockets==12.0 95 | yarl==1.9.4 -------------------------------------------------------------------------------- /assistants_api/app/lib/fs/actions.py: -------------------------------------------------------------------------------- 1 | from fastapi import UploadFile 2 | from minio import Minio 3 | from .schemas import FileObject 4 | import time 5 | import uuid 6 | import io 7 | 8 | 9 | def upload_file( 10 | minio_client: Minio, bucket_name: str, file: UploadFile, file_data: bytes 11 | ) -> FileObject: 12 | file_id = str(uuid.uuid4()) # Generate a unique file ID 13 | file_name = file.filename 14 | file_size = len(file_data) 15 | file_stream = io.BytesIO(file_data) # Create a stream from the byte data 16 | 17 | # Save file to MinIO 18 | minio_client.put_object( 19 | bucket_name, 20 | 
file_id, 21 | file_stream, 22 | file_size, 23 | metadata={"filename": file_name}, 24 | ) 25 | 26 | return FileObject( 27 | id=file_id, 28 | bytes=file_size, 29 | created_at=int(time.time()), 30 | filename=file_name, 31 | object="file", 32 | purpose="assistants", 33 | status="uploaded", 34 | ) 35 | 36 | 37 | def get_file( 38 | minio_client: Minio, bucket_name: str, file_id: str 39 | ) -> FileObject: 40 | # Retrieve file information 41 | try: 42 | file_stat = minio_client.stat_object(bucket_name, file_id) 43 | # Optional: Retrieve the metadata if needed 44 | filename = file_stat.metadata["x-amz-meta-filename"] 45 | 46 | return FileObject( 47 | id=file_id, 48 | bytes=file_stat.size, 49 | created_at=int(file_stat.last_modified.timestamp()), 50 | filename=filename, 51 | object="file", 52 | purpose="assistants", # Adjust as necessary, possibly from metadata 53 | status="uploaded", # Adjust as necessary 54 | ) 55 | except Exception as e: 56 | print(f"Error retrieving file: {str(e)}") 57 | raise e 58 | 59 | 60 | def get_file_binary( 61 | minio_client: Minio, bucket_name: str, file_id: str 62 | ) -> bytes: 63 | return minio_client.get_object(bucket_name, file_id).read() 64 | 65 | 66 | def delete_file(minio_client: Minio, bucket_name: str, file_id: str) -> None: 67 | minio_client.remove_object(bucket_name, file_id) 68 | -------------------------------------------------------------------------------- /assistants_api/tests/ops/test_runs_ops.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import pytest 3 | from openai import OpenAI 4 | from openai.types.beta.threads import Run 5 | import os 6 | import time 7 | 8 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 9 | 10 | 11 | @pytest.fixture 12 | def openai_client(): 13 | return OpenAI( 14 | base_url="http://localhost:8000", 15 | api_key=api_key, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def thread_id(openai_client: OpenAI): 21 | thread_metadata = {"example_key": "example_value"} 22 | response = openai_client.beta.threads.create(metadata=thread_metadata) 23 | return response.id 24 | 25 | 26 | @pytest.fixture 27 | def assistant_id(openai_client: OpenAI): 28 | response = openai_client.beta.assistants.create( 29 | instructions="You are an AI designed to provide examples.", 30 | name="Example Assistant", 31 | tools=[{"type": "code_interpreter"}], 32 | model="gpt-3.5-turbo", 33 | ) 34 | return response.id 35 | 36 | 37 | @pytest.fixture 38 | def run_id(openai_client: OpenAI, thread_id: str, assistant_id: str): 39 | response = openai_client.beta.threads.runs.create( 40 | thread_id=thread_id, 41 | assistant_id=assistant_id, 42 | ) 43 | return response.id 44 | 45 | 46 | @pytest.mark.dependency() 47 | def test_update_run(thread_id: str, run_id: str): 48 | # Assuming you have a way to get a run_id, perhaps from the test_create_run test 49 | curr_time = int(time.time()) 50 | update_url = f"http://localhost:8000/ops/threads/{thread_id}/runs/{run_id}" 51 | update_data = { 52 | "status": "completed", 53 | "completed_at": curr_time, 54 | } 55 | 56 | response = requests.post(update_url, json=update_data) 57 | 58 | # Verify the response status code and the updated fields 59 | assert response.status_code == 200 60 | updated_run = response.json() 61 | updated_run = Run(**updated_run) 62 | assert updated_run.status == "completed" 63 | assert updated_run.completed_at == curr_time 64 | assert updated_run.id == run_id 65 | 66 | # You might want to fetch the updated run again using a 
GET request to double-check 67 | -------------------------------------------------------------------------------- /assistants_api/app/routers/threads_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException 2 | from sqlalchemy.orm import Session 3 | from lib.db import schemas, database, crud 4 | from utils.tranformers import db_to_pydantic_thread 5 | 6 | router = APIRouter() 7 | 8 | 9 | @router.post("/threads", response_model=schemas.Thread) 10 | def create_thread( 11 | thread_data: schemas.ThreadCreate, db: Session = Depends(database.get_db) 12 | ): 13 | """ 14 | Create a new thread. 15 | - **metadata**: Set of 16 key-value pairs that can be attached to the thread. 16 | """ 17 | 18 | db_thread = crud.create_thread(db, thread_data) 19 | return db_to_pydantic_thread(db_thread) 20 | 21 | 22 | @router.get("/threads/{thread_id}", response_model=schemas.Thread) 23 | def get_thread(thread_id: str, db: Session = Depends(database.get_db)): 24 | """ 25 | Retrieve a specific thread by its ID. 26 | - **thread_id**: The ID of the thread to retrieve. 27 | """ 28 | db_thread = crud.get_thread(db, thread_id=thread_id) 29 | if db_thread is None: 30 | raise HTTPException(status_code=404, detail="No thread found") 31 | 32 | return db_to_pydantic_thread(db_thread) 33 | 34 | 35 | @router.post("/threads/{thread_id}", response_model=schemas.Thread) 36 | def update_thread( 37 | thread_id: str, 38 | thread_data: schemas.ThreadUpdate, 39 | db: Session = Depends(database.get_db), 40 | ): 41 | """ 42 | Update a specific thread by its ID. 43 | - **thread_id**: The ID of the thread to update. 44 | - **metadata**: Set of 16 key-value pairs that can be attached to the thread. 45 | """ 46 | db_thread = crud.update_thread( 47 | db, thread_id, thread_data.model_dump(exclude_none=True) 48 | ) 49 | if db_thread is None: 50 | raise HTTPException(status_code=404, detail="No thread found") 51 | 52 | return db_to_pydantic_thread(db_thread) 53 | 54 | 55 | @router.delete("/threads/{thread_id}", response_model=schemas.ThreadDeleted) 56 | def delete_thread(thread_id: str, db: Session = Depends(database.get_db)): 57 | """ 58 | Delete a specific thread by its ID. 59 | - **thread_id**: The ID of the thread to delete. 
60 | """ 61 | is_deleted = crud.delete_thread(db, thread_id) 62 | if not is_deleted: 63 | raise HTTPException(status_code=404, detail="No thread found") 64 | 65 | return schemas.ThreadDeleted( 66 | id=thread_id, deleted=is_deleted, object="thread.deleted" 67 | ) 68 | -------------------------------------------------------------------------------- /assistants_api/app/utils/tranformers.py: -------------------------------------------------------------------------------- 1 | from lib.db import models 2 | from lib.db import schemas 3 | 4 | 5 | def db_to_pydantic_assistant( 6 | db_assistant: models.Assistant, 7 | ) -> schemas.Assistant: 8 | assistant_dict = db_assistant.__dict__ 9 | assistant_dict = assistant_dict.copy() 10 | del assistant_dict["_sa_instance_state"] 11 | assistant_dict["metadata"] = assistant_dict["_metadata"] 12 | del assistant_dict["_metadata"] 13 | return schemas.Assistant(**assistant_dict) 14 | 15 | 16 | def db_to_pydantic_thread( 17 | db_thread: models.Thread, 18 | ) -> schemas.Thread: 19 | thread_dict = db_thread.__dict__ 20 | thread_dict = thread_dict.copy() 21 | del thread_dict["_sa_instance_state"] 22 | thread_dict["metadata"] = thread_dict["_metadata"] 23 | del thread_dict["_metadata"] 24 | return schemas.Thread(**thread_dict) 25 | 26 | 27 | def db_to_pydantic_message( 28 | db_message: models.Message, 29 | ) -> schemas.Message: 30 | message_dict = db_message.__dict__ 31 | message_dict = message_dict.copy() 32 | del message_dict["_sa_instance_state"] 33 | message_dict["metadata"] = message_dict["_metadata"] 34 | del message_dict["_metadata"] 35 | return schemas.Message(**message_dict) 36 | 37 | 38 | def db_to_pydantic_run( 39 | db_run: models.Run, 40 | ) -> schemas.Run: 41 | run_dict = db_run.__dict__ 42 | run_dict = run_dict.copy() 43 | del run_dict["_sa_instance_state"] 44 | run_dict["metadata"] = run_dict["_metadata"] 45 | del run_dict["_metadata"] 46 | return schemas.Run(**run_dict) 47 | 48 | 49 | def db_to_pydantic_runstep( 50 | db_run_step: models.RunStep, 51 | ) -> schemas.RunStep: 52 | run_step_dict = db_run_step.__dict__ 53 | run_step_dict = run_step_dict.copy() 54 | del run_step_dict["_sa_instance_state"] 55 | run_step_dict["metadata"] = run_step_dict["_metadata"] 56 | del run_step_dict["_metadata"] 57 | return schemas.RunStep(**run_step_dict) 58 | 59 | 60 | def db_to_pydantic_vector_store( 61 | db_vector_store: models.VectorStore, 62 | ) -> schemas.VectorStore: 63 | vector_store_dict = db_vector_store.__dict__ 64 | vector_store_dict = vector_store_dict.copy() 65 | del vector_store_dict["_sa_instance_state"] 66 | vector_store_dict["metadata"] = vector_store_dict["_metadata"] 67 | del vector_store_dict["_metadata"] 68 | return schemas.VectorStore(**vector_store_dict) 69 | 70 | 71 | def db_to_pydantic_vector_store_file_batch( 72 | db_vector_store_file_batch: models.VectorStoreFileBatch, 73 | ) -> schemas.VectorStoreFileBatch: 74 | vector_store_file_batch_dict = db_vector_store_file_batch.__dict__ 75 | vector_store_file_batch_dict = vector_store_file_batch_dict.copy() 76 | del vector_store_file_batch_dict["_sa_instance_state"] 77 | return schemas.VectorStoreFileBatch(**vector_store_file_batch_dict) 78 | -------------------------------------------------------------------------------- /assistants_api/app/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request 2 | from routers import ( 3 | assistant_router, 4 | file_router, 5 | threads_router, 6 | message_router, 7 | run_router, 8 | 
runsteps_router, 9 | vectorstore_router, 10 | ) 11 | from routers.ops import ( 12 | run_ops_router, 13 | runsteps_ops_router, 14 | web_retrieval_ops_router, 15 | ) 16 | from lib.db.database import engine 17 | from lib.db import models 18 | from fastapi.middleware.cors import CORSMiddleware 19 | from dotenv import load_dotenv 20 | from starlette.middleware.base import BaseHTTPMiddleware 21 | from lib.wv.client import client as wv_client 22 | import weaviate 23 | 24 | 25 | class RawBodyMiddleware(BaseHTTPMiddleware): 26 | async def dispatch(self, request: Request, call_next): 27 | body = await request.body() 28 | print(f"Raw body: {body}") 29 | response = await call_next(request) 30 | return response 31 | 32 | 33 | load_dotenv() 34 | 35 | app = FastAPI() 36 | 37 | app.add_middleware( 38 | RawBodyMiddleware, 39 | ) 40 | 41 | app.add_middleware( 42 | CORSMiddleware, 43 | allow_origins=["*"], 44 | allow_credentials=True, 45 | allow_methods=["*"], 46 | allow_headers=["*"], 47 | ) 48 | 49 | if not wv_client.collections.exists(name="web_retrieval"): 50 | print("Creating web retrieval collection...") 51 | wv_client.collections.create( 52 | name=web_retrieval_ops_router.COLLECTION_NAME, 53 | description=web_retrieval_ops_router.DEFAULT_WEB_RETRIEVAL_DESCRIPTION, 54 | generative_config=weaviate.classes.config.Configure.Generative.openai(), 55 | properties=[ 56 | weaviate.classes.config.Property( 57 | name="url", data_type=weaviate.classes.config.DataType.TEXT 58 | ), 59 | weaviate.classes.config.Property( 60 | name="content", 61 | data_type=weaviate.classes.config.DataType.TEXT, 62 | ), 63 | weaviate.classes.config.Property( 64 | name="depth", 65 | data_type=weaviate.classes.config.DataType.NUMBER, 66 | ), 67 | ], 68 | vectorizer_config=[ 69 | weaviate.classes.config.Configure.NamedVectors.text2vec_openai( 70 | name="content_and_url", 71 | source_properties=["content", "url"], 72 | ) 73 | ], 74 | ) 75 | 76 | # TODO: Remove this in production 77 | models.Base.metadata.drop_all(bind=engine) 78 | 79 | # Create the database tables 80 | models.Base.metadata.create_all(bind=engine) 81 | 82 | app.include_router(assistant_router.router) 83 | app.include_router(file_router.router) 84 | app.include_router(threads_router.router) 85 | app.include_router(message_router.router) 86 | app.include_router(run_router.router) 87 | app.include_router(runsteps_router.router) 88 | app.include_router(vectorstore_router.router) 89 | 90 | # ops routers 91 | app.include_router(run_ops_router.router) 92 | app.include_router(runsteps_ops_router.router) 93 | app.include_router(web_retrieval_ops_router.router) 94 | -------------------------------------------------------------------------------- /assistants_api/tests/ops/test_web_crawling_ops.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import pytest 3 | import weaviate 4 | import os 5 | 6 | 7 | # Setup Weaviate client and test URLs 8 | WEAVIATE_HOST = "localhost" 9 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 10 | 11 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 12 | use_openai = True if os.getenv("USE_OPENAI") else False 13 | base_url = "http://localhost:8000" 14 | 15 | 16 | @pytest.fixture 17 | def openai_client(): 18 | if use_openai: 19 | return OpenAI( 20 | api_key=api_key, 21 | ) 22 | else: 23 | return OpenAI( 24 | base_url=base_url, 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def weaviate_client(): 30 | return weaviate.connect_to_local( 31 | host=WEAVIATE_HOST, 32 | 
port=8080, 33 | grpc_port=50051, 34 | headers={"X-OpenAI-Api-Key": OPENAI_API_KEY}, 35 | ) 36 | 37 | 38 | @pytest.mark.dependency() 39 | def test_weaviate_integration( 40 | openai_client: OpenAI, weaviate_client: weaviate.client.WeaviateClient 41 | ): 42 | crawl = openai_client.ops.web_retrieval.crawl_and_upsert( 43 | root_urls=["https://quotes.toscrape.com/"], 44 | constrain_to_root_domain=True, 45 | max_depth=1, 46 | ) 47 | 48 | assert crawl.message == "Crawling completed successfully." 49 | assert len(crawl.crawl_infos) == 47 50 | 51 | err_count = 0 52 | for crawl_info in crawl.crawl_infos: 53 | assert crawl_info.content == "" 54 | if crawl_info.error: 55 | err_count += 1 56 | assert err_count <= 8 57 | 58 | # Check collection existence 59 | collection_name = "web_retrieval" 60 | assert weaviate_client.collections.exists(name=collection_name) is True 61 | 62 | # Insert and retrieve data to ensure functionality 63 | collection = weaviate_client.collections.get(name=collection_name) 64 | query_result = collection.query.near_text( 65 | query="Oscar Winning Films", limit=1 66 | ) 67 | 68 | assert len(query_result.objects) == 1 69 | assert "content" in query_result.objects[0].properties 70 | 71 | 72 | @pytest.mark.dependency(depends=["test_weaviate_integration"]) 73 | @pytest.mark.skip(reason="This test is skipped unless explicitly run.") 74 | def test_change_description_and_delete_collection( 75 | openai_client: OpenAI, weaviate_client: weaviate.client.WeaviateClient 76 | ): 77 | # Test change description 78 | collection_name = "web_retrieval" 79 | test_description = "This is a test description." 80 | openai_client.ops.web_retrieval.crawl_and_upsert( 81 | root_urls=["https://quotes.toscrape.com/"], 82 | constrain_to_root_domain=True, 83 | max_depth=0, 84 | description=test_description, 85 | ) 86 | 87 | collection = weaviate_client.collections.get(name=collection_name) 88 | config = collection.config.get() 89 | assert config.description == test_description 90 | 91 | # test delete (which behaves more like a reset) 92 | openai_client.ops.web_retrieval.delete() 93 | 94 | config = collection.config.get() 95 | assert not (config.description == test_description) 96 | -------------------------------------------------------------------------------- /run_executor_worker/src/utils/tools.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List 3 | from openai.types.beta.web_retrieval_tool import WebRetrievalTool 4 | from openai.types.beta.assistant_tool import AssistantTool 5 | from openai.types.beta.file_search_tool import FileSearchTool 6 | from openai.types.beta.code_interpreter_tool import CodeInterpreterTool 7 | from openai.types.beta.function_tool import FunctionTool 8 | from pydantic import BaseModel 9 | 10 | 11 | class Actions(Enum): 12 | # function, retrieval, code_interpreter, text_generation, completion 13 | FUNCTION = "function" 14 | WEB_RETRIEVAL = "web_retrieval" 15 | FILE_SEARCH = "file_search" 16 | CODE_INTERPRETER = "code_interpreter" 17 | TEXT_GENERATION = "text_generation" 18 | COMPLETION = "completion" 19 | FAILURE = "failure" 20 | 21 | 22 | class ActionItem(BaseModel): 23 | type: str 24 | description: str 25 | 26 | 27 | def actions_to_map(actions: List[str]) -> dict[str, ActionItem]: 28 | """ 29 | Converts a list of AssistantTool objects to a dictionary. 
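    (In practice each entry in `actions` is a plain action-name string such as
    "text_generation" or "completion", not an AssistantTool instance.)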
-------------------------------------------------------------------------------- /run_executor_worker/src/utils/tools.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List 3 | from openai.types.beta.web_retrieval_tool import WebRetrievalTool 4 | from openai.types.beta.assistant_tool import AssistantTool 5 | from openai.types.beta.file_search_tool import FileSearchTool 6 | from openai.types.beta.code_interpreter_tool import CodeInterpreterTool 7 | from openai.types.beta.function_tool import FunctionTool 8 | from pydantic import BaseModel 9 | 10 | 11 | class Actions(Enum): 12 | # function, retrieval, code_interpreter, text_generation, completion 13 | FUNCTION = "function" 14 | WEB_RETRIEVAL = "web_retrieval" 15 | FILE_SEARCH = "file_search" 16 | CODE_INTERPRETER = "code_interpreter" 17 | TEXT_GENERATION = "text_generation" 18 | COMPLETION = "completion" 19 | FAILURE = "failure" 20 | 21 | 22 | class ActionItem(BaseModel): 23 | type: str 24 | description: str 25 | 26 | 27 | def actions_to_map(actions: List[str]) -> dict[str, ActionItem]: 28 | """ 29 | Converts a list of action names to a dictionary of ActionItem objects. 30 | """ 31 | actions_map = {} 32 | for action in actions: 33 | if action == Actions.TEXT_GENERATION.value: 34 | actions_map[action] = ActionItem( 35 | type=Actions.TEXT_GENERATION.value, 36 | description="Communicate to the user either to summarize or express the next tasks to be executed.", # noqa 37 | ) 38 | elif action == Actions.COMPLETION.value: 39 | actions_map[action] = ActionItem( 40 | type=Actions.COMPLETION.value, 41 | description="Finish the process, generate the final answer", 42 | ) 43 | return actions_map 44 | 45 | 46 | def tools_to_map( 47 | tools: List[AssistantTool], web_retrieval_description: str 48 | ) -> dict[str, ActionItem]: 49 | """ 50 | Converts a list of AssistantTool objects to a dictionary. 51 | """ 52 | tools_map: dict[str, ActionItem] = {} 53 | for tool in tools: 54 | if isinstance(tool, FunctionTool): 55 | if not tools_map.get(tool.type): 56 | tools_map[tool.type] = ActionItem( 57 | type=tool.type, 58 | description="Function calls available to you are: ", 59 | ) 60 | tools_map[tool.type].description += f"{tool.function.model_dump()}" 61 | 62 | elif isinstance(tool, FileSearchTool): 63 | tools_map[tool.type] = ActionItem( 64 | type=tool.type, 65 | description="Retrieves information from files provided.", 66 | ) 67 | elif isinstance(tool, WebRetrievalTool): 68 | tools_map[tool.type] = ActionItem( 69 | type=tool.type, 70 | description=web_retrieval_description, 71 | ) 72 | elif isinstance(tool, CodeInterpreterTool): 73 | tools_map[tool.type] = ActionItem( 74 | type=tool.type, 75 | description="Interprets and executes code.", 76 | ) 77 | return tools_map 78 | 79 | 80 | text_generation_tool = ActionItem( 81 | type="text_generation", 82 | description="General text response.", 83 | ) 84 | 85 | completion_tool = ActionItem( 86 | type="completion", 87 | description="Completes the task.", 88 | ) 89 |
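A quick usage sketch for these helpers (imports assume the worker's `src/` directory is on `PYTHONPATH`; the printed strings are the descriptions hard-coded above):

```python
from openai.types.beta.file_search_tool import FileSearchTool
from utils.tools import Actions, actions_to_map, tools_to_map

tools_map = tools_to_map(
    [FileSearchTool(type="file_search")],
    web_retrieval_description="(not used in this example)",
)
print(tools_map["file_search"].description)
# -> "Retrieves information from files provided."

actions_map = actions_to_map(
    [Actions.TEXT_GENERATION.value, Actions.COMPLETION.value]
)
print(sorted(actions_map))
# -> ['completion', 'text_generation']
```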
-------------------------------------------------------------------------------- /run_executor_worker/src/actions/web_retrieval.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from utils.ops_api_handler import create_web_retrieval_runstep 3 | from utils.openai_clients import litellm_client 4 | from data_models import run 5 | import os 6 | from agents import coala 7 | from utils.weaviate_utils import weaviate_client 8 | from constants import WebRetrievalResult 9 | 10 | 11 | class WebRetrieval: 12 | def __init__( 13 | self, 14 | coala_class: "coala.CoALA", 15 | amt_documents: int = 2, 16 | ): 17 | self.coala_class = coala_class 18 | self.amt_documents = amt_documents 19 | 20 | def query(self, query: str, site: str = None) -> List[WebRetrievalResult]: 21 | collection_name = "web_retrieval" 22 | if weaviate_client.collections.exists(name=collection_name): 23 | collection = weaviate_client.collections.get(name=collection_name) 24 | else: 25 | raise Exception(f"Collection {collection_name} does not exist.") 26 | 27 | query_result = collection.query.hybrid( 28 | query=query, 29 | limit=self.amt_documents, 30 | target_vector="content_and_url", 31 | ) 32 | 33 | return [ 34 | WebRetrievalResult( 35 | url=chunk.properties["url"], 36 | content=chunk.properties["content"], 37 | depth=chunk.properties["depth"], 38 | ) 39 | for chunk in query_result.objects 40 | ] 41 | 42 | def generate( 43 | self, 44 | ) -> run.RunStep: 45 | # get relevant retrieval query 46 | user_instruction = self.coala_class.compose_user_instruction() 47 | instructions = f"""{user_instruction}Your role is to generate a query for semantic search according to current working memory. 48 | Even if there is no relevant information in the working memory, you should still generate a query relevant to the provided information. 49 | Only respond with the query itself, NOTHING ELSE. 50 | 51 | """ # noqa 52 | 53 | messages = [ 54 | { 55 | "role": "user", 56 | "content": instructions + self.compose_query_system_prompt(), 57 | }, 58 | ] 59 | response = litellm_client.chat.completions.create( 60 | model=os.getenv("LITELLM_MODEL"), 61 | messages=messages, 62 | max_tokens=200, 63 | ) 64 | query = response.choices[0].message.content 65 | print(f"\n\n\nQuery generated: {query}") 66 | 67 | # Retrieve documents based on the query 68 | try: 69 | retrieved_items: List[WebRetrievalResult] = self.query(query) 70 | except Exception as e: 71 | raise Exception( 72 | f"Error: web_retrieval is not instantiated yet.\n{e}" 73 | ) 74 | 75 | run_step = create_web_retrieval_runstep( 76 | self.coala_class.thread_id, 77 | self.coala_class.run_id, 78 | self.coala_class.assistant_id, 79 | query, 80 | retrieved_items, 81 | site=", ".join([item.url for item in retrieved_items]), 82 | ) 83 | return run_step 84 | 85 | def compose_query_system_prompt(self) -> str: 86 | trace = self.coala_class.compose_react_trace() 87 | 88 | composed_instruction = f"""Current working memory: 89 | {trace}""" 90 | return composed_instruction 91 | -------------------------------------------------------------------------------- /assistants_api/app/routers/file_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Path 2 | from minio import Minio, S3Error 3 | from lib.fs import actions 4 | from lib.fs.store import minio_client, BUCKET_NAME 5 | from lib.fs.schemas import FileObject, FileDeleted 6 | from lib.db.database import get_db 7 | from typing_extensions import Literal 8 | from sqlalchemy.orm import Session 9 | from lib.db import crud 10 | 11 | router = APIRouter() 12 | 13 | 14 | @router.post("/files", response_model=FileObject) 15 | async def create_file( 16 | file: UploadFile = File(...), 17 | purpose: Literal["fine-tune", "assistants"] = File(...), 18 | db: Session = Depends(get_db), 19 | minio_client: Minio = Depends(minio_client), 20 | ): 21 | if purpose not in ["fine-tune", "assistants"]: 22 | raise HTTPException(status_code=400, detail="Invalid purpose") 23 | 24 | # Ensure file data is read before any operation that might exhaust the stream 25 | file_data_bytes = await file.read() 26 | 27 | # Check if file data is empty 28 | if not file_data_bytes: 29 | raise HTTPException(status_code=400, detail="File is empty") 30 | 31 | uploaded_file = actions.upload_file( 32 | minio_client=minio_client, 33 | bucket_name=BUCKET_NAME, 34 | file=file, 35 | file_data=file_data_bytes, 36 | ) 37 | 38 | crud.create_file(db=db, file=uploaded_file) 39 | 40 | # # File data is passed here after ensuring it's not empty 41 | # wv_actions.upload_file_chunks( 42 | # file_data=file_data_bytes, 43 | # file_name=file.filename, 44 | # file_id=uploaded_file.id, 45 | # ) 46 | 47 | return uploaded_file 48 | 49 | 50 | @router.get("/files/{file_id}", response_model=FileObject) 51 | async def get_file( 52 | file_id: str = Path(..., description="The ID of the file to retrieve"), 53 | db: Session = Depends(get_db), 54 | minio_client: Minio = Depends(minio_client), 55 | ): 56 | # Retrieve file metadata from the database 57 | file_metadata = crud.get_file(db=db, file_id=file_id) 58 | if not file_metadata: 59 |
raise HTTPException(status_code=404, detail="File not found") 60 | 61 | # # Optional: Retrieve file contents from MinIO 62 | # response = minio_client.get_object(BUCKET_NAME, file_id) 63 | # file_content = response.read() 64 | 65 | # Return file metadata 66 | return file_metadata 67 | 68 | 69 | @router.delete("/files/{file_id}", response_model=FileDeleted) 70 | async def delete_file( 71 | file_id: str = Path(..., description="The ID of the file to delete"), 72 | db: Session = Depends(get_db), 73 | minio_client: Minio = Depends(minio_client), 74 | ): 75 | # Verify if the file exists in the database 76 | file_metadata = crud.get_file(db=db, file_id=file_id) 77 | if not file_metadata: 78 | raise HTTPException(status_code=404, detail="File not found") 79 | 80 | # Attempt to delete the file from MinIO 81 | try: 82 | actions.delete_file( 83 | minio_client=minio_client, bucket_name=BUCKET_NAME, file_id=file_id 84 | ) 85 | except S3Error as e: 86 | raise HTTPException( 87 | status_code=500, detail=f"Failed to delete file from storage: {e}" 88 | ) 89 | 90 | # Delete the file metadata from the database 91 | crud.delete_file(db=db, file_id=file_id) 92 | 93 | return FileDeleted(id=file_id, deleted=True, object="file") 94 | -------------------------------------------------------------------------------- /assistants_api/tests/test_threads.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from openai import OpenAI 3 | from openai.types.beta.thread import Thread 4 | import os 5 | import time 6 | 7 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 8 | use_openai = True if os.getenv("USE_OPENAI") else False 9 | base_url = "http://localhost:8000" 10 | 11 | 12 | @pytest.fixture 13 | def openai_client(): 14 | if use_openai: 15 | return OpenAI( 16 | api_key=api_key, 17 | ) 18 | else: 19 | return OpenAI( 20 | base_url=base_url, 21 | ) 22 | 23 | 24 | @pytest.fixture 25 | def thread_id(openai_client: OpenAI): 26 | thread_metadata = {"example_key": "example_value"} 27 | response = openai_client.beta.threads.create(metadata=thread_metadata) 28 | return response.id 29 | 30 | 31 | @pytest.mark.dependency() 32 | def test_create_thread_without_messages(openai_client: OpenAI): 33 | thread_metadata = {"example_key": "example_value"} 34 | response = openai_client.beta.threads.create(metadata=thread_metadata) 35 | assert isinstance(response, Thread) 36 | assert response.id is not None 37 | assert response.object == "thread" 38 | assert response.created_at is not None 39 | assert response.metadata == thread_metadata 40 | 41 | 42 | @pytest.mark.dependency(depends=["test_create_thread_without_messages"]) 43 | def test_get_thread(openai_client: OpenAI, thread_id: str): 44 | metadata = {"example_key": "example_value"} 45 | response = openai_client.beta.threads.retrieve(thread_id=thread_id) 46 | assert isinstance(response, Thread) 47 | assert response.id == thread_id 48 | assert response.object == "thread" 49 | assert response.created_at is not None 50 | assert response.metadata == metadata 51 | 52 | 53 | @pytest.mark.dependency(depends=["test_create_thread_without_messages"]) 54 | def test_update_thread_metadata(openai_client: OpenAI, thread_id: str): 55 | metadata_update = {"new_key": "new_value"} 56 | openai_client.beta.threads.update(thread_id, metadata=metadata_update) 57 | time.sleep(0.5) 58 | metadata_to_be = {**metadata_update, "example_key": "example_value"} 59 | response = openai_client.beta.threads.retrieve(thread_id=thread_id) 60 | 
assert isinstance(response, Thread) 61 | assert response.id == thread_id 62 | assert response.metadata == metadata_to_be 63 | 64 | 65 | @pytest.mark.dependency( 66 | depends=["test_create_thread_without_messages", "test_get_thread"] 67 | ) 68 | def test_delete_thread(openai_client: OpenAI, thread_id: str): 69 | # verify that the thread exists 70 | response = openai_client.beta.threads.retrieve(thread_id=thread_id) 71 | assert isinstance(response, Thread) 72 | assert response.id == thread_id 73 | # delete the thread 74 | response = openai_client.beta.threads.delete(thread_id=thread_id) 75 | assert response.id == thread_id 76 | assert response.deleted is True 77 | # verify that the thread has been deleted 78 | try: 79 | openai_client.beta.threads.retrieve(thread_id=thread_id) 80 | except Exception as e: 81 | assert e.status_code == 404 82 | assert "No thread found" in str(e) 83 | else: 84 | raise AssertionError("Thread was not deleted") 85 | 86 | 87 | # @pytest.fixture(scope="session", autouse=True) 88 | # def cleanup(request): 89 | # # THIS REQUIRES A WAY TO RETRIEVE ALL THREADS WHICH CURRENTLY DOES NOT EXIST IN THE API # noqa 90 | -------------------------------------------------------------------------------- /run_executor_worker/src/consumer.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | import pika 3 | import os 4 | from dotenv import load_dotenv 5 | from run_executor.main import ExecuteRun 6 | import json 7 | import time 8 | 9 | load_dotenv() 10 | 11 | MAX_WORKERS = int(os.getenv("MAX_WORKERS", 4)) 12 | RABBITMQ_DEFAULT_USER = os.getenv("RABBITMQ_DEFAULT_USER") 13 | RABBITMQ_DEFAULT_PASS = os.getenv("RABBITMQ_DEFAULT_PASS") 14 | RABBITMQ_HOST = os.getenv("RABBITMQ_HOST") 15 | RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", 5672)) 16 | 17 | 18 | class RabbitMQConsumer: 19 | def __init__(self, max_workers=4): 20 | self.max_workers = max_workers 21 | self.executor = ThreadPoolExecutor(max_workers=max_workers) 22 | self.connect() 23 | 24 | def connect(self): 25 | credentials = pika.PlainCredentials( 26 | RABBITMQ_DEFAULT_USER, RABBITMQ_DEFAULT_PASS 27 | ) 28 | while True: 29 | try: 30 | self.connection = pika.BlockingConnection( 31 | pika.ConnectionParameters( 32 | host=RABBITMQ_HOST, 33 | port=RABBITMQ_PORT, 34 | credentials=credentials, 35 | heartbeat=600, 36 | ) 37 | ) 38 | self.channel = self.connection.channel() 39 | self.channel.basic_qos(prefetch_count=self.max_workers) 40 | break 41 | except pika.exceptions.AMQPConnectionError as e: 42 | print(f"Connection error: {e}, retrying in 5 seconds...") 43 | time.sleep(5) 44 | 45 | def process_message(self, body): 46 | try: 47 | message = body.decode("utf-8") 48 | data = json.loads(message) 49 | print(f"\n\nProcessing {data}") 50 | run = ExecuteRun(data["thread_id"], data["run_id"]) 51 | run.execute() 52 | except json.JSONDecodeError as e: 53 | print(f"Failed to decode JSON: {e}") 54 | 55 | def callback(self, ch, method, properties, body): 56 | self.executor.submit(self.process_message_and_ack, body, ch, method) 57 | 58 | def process_message_and_ack(self, body, ch, method): 59 | try: 60 | self.process_message(body) 61 | ch.basic_ack(delivery_tag=method.delivery_tag) 62 | except Exception as e: 63 | print(f"Failed to process message {body}: {e}") 64 | ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False) 65 | 66 | def start_consuming(self, queue_name): 67 | while True: 68 | try: 69 | self.channel.queue_declare(queue=queue_name, durable=True) 70 | self.channel.basic_consume( 71 | queue=queue_name, 72 | on_message_callback=self.callback, 73 | auto_ack=False, 74 | ) 75 | print("Waiting for messages. To exit press CTRL+C") 76 | self.channel.start_consuming() 77 | except pika.exceptions.ConnectionClosedByBroker: 78 | print("Connection closed by broker, reconnecting...") 79 | self.connect() 80 | except pika.exceptions.StreamLostError: 81 | print("Stream lost, reconnecting...") 82 | self.connect() 83 | except Exception as e: 84 | print(f"Exception in consuming: {e}") 85 | self.connect() 86 | 87 | 88 | if __name__ == "__main__": 89 | consumer = RabbitMQConsumer(max_workers=MAX_WORKERS) 90 | consumer.start_consuming("runs_queue") 91 |
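For reference, a standalone producer sketch that publishes a message in the shape `process_message` expects. The queue name and JSON fields mirror what `run_router.create_run` publishes; the host, port, and `user`/`pass` credentials are placeholders for your `.env` values:

```python
import json

import pika

# Hypothetical producer for the worker's queue.
connection = pika.BlockingConnection(
    pika.ConnectionParameters(
        host="localhost",
        port=5672,
        credentials=pika.PlainCredentials("user", "pass"),
    )
)
channel = connection.channel()
channel.queue_declare(queue="runs_queue", durable=True)
channel.basic_publish(
    exchange="",
    routing_key="runs_queue",
    body=json.dumps({"thread_id": "thread_abc", "run_id": "run_123"}),
)
connection.close()
```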
-------------------------------------------------------------------------------- /assistants_api/app/routers/message_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query 2 | from typing import Optional 3 | from sqlalchemy.orm import Session 4 | from lib.db import crud, schemas, database 5 | from utils.tranformers import db_to_pydantic_message 6 | 7 | router = APIRouter() 8 | 9 | 10 | @router.post("/threads/{thread_id}/messages", response_model=schemas.Message) 11 | def create_message_in_thread( 12 | thread_id: str, 13 | message_inp: schemas.MessageInput, 14 | db: Session = Depends(database.get_db), 15 | ): 16 | db_thread = crud.get_thread(db, thread_id=thread_id) 17 | if db_thread is None: 18 | raise HTTPException(status_code=404, detail="No thread found") 19 | 20 | db_message = crud.create_message( 21 | db=db, thread_id=thread_id, message_inp=message_inp 22 | ) 23 | return db_to_pydantic_message(db_message) 24 | 25 | 26 | @router.get( 27 | "/threads/{thread_id}/messages", 28 | response_model=schemas.SyncCursorPage[schemas.Message], 29 | ) 30 | def get_messages_in_thread( 31 | thread_id: str, 32 | db: Session = Depends(database.get_db), 33 | limit: int = Query(default=20, le=100), 34 | order: str = Query(default="desc", regex="^(asc|desc)$"), 35 | after: Optional[str] = None, 36 | before: Optional[str] = None, 37 | ): 38 | """ 39 | List messages in a thread with optional pagination and ordering. 40 | - **limit**: Maximum number of results to return. 41 | - **order**: Sort order based on the creation time ('asc' or 'desc'). 42 | - **after**: ID to start the list from (for pagination). 43 | - **before**: ID to list up to (for pagination). 44 | """ 45 | db_messages = crud.get_messages( 46 | db=db, 47 | thread_id=thread_id, 48 | limit=limit, 49 | order=order, 50 | after=after, 51 | before=before, 52 | ) 53 | 54 | messages = [db_to_pydantic_message(message) for message in db_messages] 55 | paginated_messages = schemas.SyncCursorPage(data=messages) 56 | 57 | return paginated_messages 58 | 59 | 60 | @router.get( 61 | "/threads/{thread_id}/messages/{message_id}", 62 | response_model=schemas.Message, 63 | ) 64 | def get_message( 65 | thread_id: str, 66 | message_id: str, 67 | db: Session = Depends(database.get_db), 68 | ): 69 | """ 70 | Retrieve a specific message from a thread. 71 | - **thread_id**: The ID of the thread. 72 | - **message_id**: The ID of the message to retrieve.
73 | """ 74 | message_db = crud.get_message_by_id( 75 | db, thread_id=thread_id, message_id=message_id 76 | ) 77 | if not message_db: 78 | raise HTTPException(status_code=404, detail="Message not found") 79 | return db_to_pydantic_message(message_db) 80 | 81 | 82 | @router.post( 83 | "/threads/{thread_id}/messages/{message_id}", 84 | response_model=schemas.Message, 85 | ) 86 | def modify_message( 87 | thread_id: str = Path( 88 | ..., description="The ID of the thread to which this message belongs." 89 | ), 90 | message_id: str = Path( 91 | ..., description="The ID of the message to modify." 92 | ), 93 | update_data: schemas.MessageUpdate = Body(...), 94 | db: Session = Depends(database.get_db), 95 | ): 96 | """ 97 | Modifies a message. 98 | - **thread_id**: The ID of the thread. 99 | - **message_id**: The ID of the message to modify. 100 | - **update_data**: Data for updating the message. 101 | """ 102 | db_message = crud.update_message( 103 | db, thread_id, message_id, update_data.model_dump(exclude_none=True) 104 | ) 105 | if db_message is None: 106 | raise HTTPException(status_code=404, detail="Message not found") 107 | return db_to_pydantic_message(db_message) 108 | -------------------------------------------------------------------------------- /docker-compose.dev.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | 3 | services: 4 | postgres: 5 | image: postgres:14 6 | restart: always 7 | environment: 8 | POSTGRES_HOST: $POSTGRES_HOST 9 | POSTGRES_PORT: $POSTGRES_PORT 10 | POSTGRES_USER: $POSTGRES_USER 11 | POSTGRES_PASSWORD: $POSTGRES_PASSWORD 12 | POSTGRES_DB: $POSTGRES_DB 13 | ports: 14 | - "5432:5432" 15 | volumes: 16 | - postgres_data:/var/lib/postgresql/data 17 | 18 | minio: 19 | image: minio/minio 20 | restart: always 21 | environment: 22 | MINIO_ACCESS_KEY: $MINIO_ACCESS_KEY 23 | MINIO_SECRET_KEY: $MINIO_SECRET_KEY 24 | ports: 25 | - "9000:9000" 26 | volumes: 27 | - minio_data:/data 28 | command: server /data 29 | 30 | rabbitmq: 31 | image: "rabbitmq:3-management" 32 | environment: 33 | RABBITMQ_DEFAULT_USER: $RABBITMQ_DEFAULT_USER 34 | RABBITMQ_DEFAULT_PASS: $RABBITMQ_DEFAULT_PASS 35 | ports: 36 | - "5672:5672" 37 | - "15672:15672" 38 | volumes: 39 | - rabbitmq_data:/var/lib/rabbitmq 40 | 41 | weaviate: 42 | image: cr.weaviate.io/semitechnologies/weaviate:1.24.17 43 | ports: 44 | - "8080:8080" 45 | - "50051:50051" 46 | volumes: 47 | - weaviate_data:/var/lib/weaviate 48 | restart: on-failure:0 49 | environment: 50 | QUERY_DEFAULTS_LIMIT: 25 51 | AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: "true" 52 | PERSISTENCE_DATA_PATH: "/var/lib/weaviate" 53 | DEFAULT_VECTORIZER_MODULE: "none" 54 | ENABLE_MODULES: "text2vec-cohere,text2vec-huggingface,text2vec-palm,text2vec-openai,generative-openai,generative-cohere,generative-palm,ref2vec-centroid,reranker-cohere,qna-openai" 55 | CLUSTER_HOSTNAME: "node1" 56 | command: 57 | - --host 58 | - 0.0.0.0 59 | - --port 60 | - "8080" 61 | - --scheme 62 | - http 63 | 64 | assistants_api: 65 | build: 66 | context: ./assistants_api 67 | dockerfile: Dockerfile 68 | volumes: 69 | - ./assistants_api/app:/app 70 | ports: 71 | - "8000:8000" 72 | depends_on: 73 | - postgres 74 | - minio 75 | - rabbitmq 76 | - weaviate 77 | environment: 78 | POSTGRES_HOST: $POSTGRES_HOST 79 | POSTGRES_PORT: $POSTGRES_PORT 80 | POSTGRES_USER: $POSTGRES_USER 81 | POSTGRES_PASSWORD: $POSTGRES_PASSWORD 82 | POSTGRES_DB: $POSTGRES_DB 83 | OPENAI_API_KEY: $OPENAI_API_KEY 84 | MINIO_ENDPOINT: minio 85 | MINIO_ACCESS_KEY: 
$MINIO_ACCESS_KEY 86 | MINIO_SECRET_KEY: $MINIO_SECRET_KEY 87 | RABBITMQ_DEFAULT_USER: $RABBITMQ_DEFAULT_USER 88 | RABBITMQ_DEFAULT_PASS: $RABBITMQ_DEFAULT_PASS 89 | RABBITMQ_HOST: rabbitmq 90 | RABBITMQ_PORT: $RABBITMQ_PORT 91 | WEAVIATE_HOST: weaviate 92 | command: sh -c "sleep 10 && uvicorn main:app --host 0.0.0.0 --port 8000 --reload" 93 | run_executor_worker: 94 | build: 95 | context: ./run_executor_worker 96 | dockerfile: Dockerfile 97 | volumes: 98 | - ./run_executor_worker:/app 99 | depends_on: 100 | - postgres 101 | - minio 102 | - rabbitmq 103 | - weaviate 104 | environment: 105 | MAX_WORKERS: 12 106 | RABBITMQ_DEFAULT_USER: $RABBITMQ_DEFAULT_USER 107 | RABBITMQ_DEFAULT_PASS: $RABBITMQ_DEFAULT_PASS 108 | RABBITMQ_HOST: rabbitmq 109 | RABBITMQ_PORT: $RABBITMQ_PORT 110 | OPENAI_API_KEY: $OPENAI_API_KEY 111 | ASSISTANTS_API_URL: http://assistants_api:8000 112 | LITELLM_API_URL: $LITELLM_API_URL 113 | LITELLM_API_KEY: $LITELLM_API_KEY 114 | LITELLM_MODEL: $LITELLM_MODEL 115 | WEAVIATE_HOST: weaviate 116 | FC_API_URL: $FC_API_URL 117 | FC_API_KEY: $FC_API_KEY 118 | FC_MODEL: $FC_MODEL 119 | 120 | command: sh -c "sleep 10 && python scripts/watcher.py" 121 | 122 | volumes: 123 | postgres_data: 124 | minio_data: 125 | rabbitmq_data: 126 | weaviate_data: 127 | -------------------------------------------------------------------------------- /run_executor_worker/src/utils/coala.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from openai.pagination import SyncCursorPage 3 | from data_models import run 4 | from openai.types.beta.threads import ThreadMessage 5 | from utils.tools import Actions 6 | from utils.tools import ActionItem 7 | 8 | 9 | class CoALA: 10 | def __init__( 11 | self, 12 | runsteps: SyncCursorPage[run.RunStep], 13 | messages: SyncCursorPage[ThreadMessage], 14 | job_summary: str, 15 | tools_map: dict[str, ActionItem], 16 | ): 17 | """ 18 | CoALA class to setup the CoALA prompt 19 | messages: episodic memory 20 | runsteps: working memory 21 | job_summary: objective 22 | tools_map: external actions 23 | """ 24 | self.runsteps = runsteps 25 | self.messages = messages 26 | self.job_summary = job_summary 27 | self.tools_map = tools_map 28 | 29 | def compose_trace(self): 30 | """ 31 | Compose the trace prompt of the current task 32 | """ 33 | trace_prompt = [] 34 | for step in self.runsteps: 35 | if step.type == "tool_calls": 36 | trace_prompt.append( 37 | f"Action: {step.step_details.tool_calls[0].type}" 38 | ) 39 | trace_prompt.append( 40 | f"Observation: {step.step_details.tool_calls[0].model_dump()}" 41 | ) 42 | if step.type == "message_creation": 43 | message = next( 44 | ( 45 | msg.content[0].text.value 46 | for msg in self.messages.data 47 | if msg.id 48 | == step.step_details.message_creation.message_id 49 | ), 50 | None, 51 | ) 52 | trace_prompt.append(f"Thought: {message}") 53 | return "\n".join(trace_prompt) 54 | 55 | def compose_actions(self): 56 | """ 57 | Compose the tools prompt for the CoALA task 58 | """ 59 | action_prompts = [] 60 | for tool in self.tools_map: 61 | action_prompts.append( 62 | f"- {tool} ({self.tools_map[tool].description})" 63 | ) 64 | action_prompts.append( 65 | f"- {Actions.COMPLETION.value} (Finish the process, generate the final answer)" # noqa 66 | ) 67 | return "\n".join(action_prompts) 68 | 69 | def compose_prompt( 70 | self, type: Literal["action", "thought", "final_answer"] 71 | ): 72 | """ 73 | Compose the prompt for the CoALA task 74 | """ 75 | base_prompt = None 76 
| if type == "action": 77 | base_prompt = """Your role is to determine which "Action" to use next. 78 | You must always begin with "Action: ..." """ 79 | elif type == "thought": 80 | base_prompt = """Your role is to provide a "Thought" response to the user. 81 | You must always begin with "Thought: ..." and finish with "Action: " """ 82 | elif type == "final_answer": 83 | base_prompt = """Your role is to provide the "Final Answer" to the user. 84 | You must always begin with "Final Answer: ..." """ 85 | trace_prompt = self.compose_trace() 86 | actions_prompt = self.compose_actions() 87 | actions_list = list(self.tools_map.keys()) + [Actions.COMPLETION.value] 88 | 89 | coala_prompt = f"""{base_prompt} 90 | You will observe that there are already steps after "Begin!". 91 | The actions available to you are: 92 | 93 | {actions_prompt} 94 | 95 | Continue the generation. 96 | Only reply with the single next step. 97 | Do not respond with more than the immediate next step. 98 | Use the following format: 99 | 100 | Question: the input question you must answer 101 | Thought: you should always think about what to do 102 | Action: the action to take, should be one of {actions_list} 103 | Observation: the result of the action 104 | ... (this Thought/Action/Observation can repeat N times) 105 | Thought: I now know the final answer 106 | Final Answer: the final answer to the original input question 107 | 108 | Begin! 109 | 110 | Question: {self.job_summary} 111 | {trace_prompt}""" 112 | 113 | return coala_prompt 114 |
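A minimal sketch of composing the ReAct-style prompt for a run that has produced no steps yet. The empty `SyncCursorPage` objects stand in for real API pages, and the `job_summary` and tool description are illustrative:

```python
from openai.pagination import SyncCursorPage
from utils.coala import CoALA
from utils.tools import ActionItem

coala = CoALA(
    runsteps=SyncCursorPage(data=[]),
    messages=SyncCursorPage(data=[]),
    job_summary="Summarize the attached files.",
    tools_map={
        "file_search": ActionItem(
            type="file_search",
            description="Retrieves information from files provided.",
        )
    },
)
# Prints the Question/Thought/Action/Observation scaffold, ending with
# "Question: Summarize the attached files."
print(coala.compose_prompt("action"))
```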
-------------------------------------------------------------------------------- /assistants_api/tests/test_files.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from openai import OpenAI 3 | from openai.types import FileObject 4 | from minio import Minio 5 | import os 6 | 7 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 8 | weaviate_url = os.getenv("WEAVIATE_URL") if os.getenv("WEAVIATE_URL") else None 9 | use_openai = True if os.getenv("USE_OPENAI") else False 10 | base_url = "http://localhost:8000" 11 | 12 | 13 | current_dir = os.path.dirname(__file__) 14 | test_txt_file_path = os.path.join(current_dir, '..', 'assets', 'test.txt') 15 | test_pdf_file_path = os.path.join(current_dir, '..', 'assets', 'test.pdf') 16 | 17 | ACCESS_KEY = os.getenv('MINIO_ACCESS_KEY') 18 | SECRET_KEY = os.getenv('MINIO_SECRET_KEY') 19 | MINIO_URL = "localhost:9000" 20 | BUCKET_NAME = "store" 21 | 22 | 23 | @pytest.fixture 24 | def minio_client(): 25 | minio_client = Minio( 26 | MINIO_URL, 27 | access_key=ACCESS_KEY, 28 | secret_key=SECRET_KEY, 29 | secure=False, 30 | ) 31 | # Create bucket if it doesn't exist 32 | found = minio_client.bucket_exists(BUCKET_NAME) 33 | if not found: 34 | minio_client.make_bucket(BUCKET_NAME) 35 | return minio_client 36 | 37 | 38 | @pytest.fixture 39 | def openai_client(): 40 | if use_openai: 41 | return OpenAI( 42 | api_key=api_key, 43 | ) 44 | else: 45 | return OpenAI( 46 | base_url=base_url, 47 | ) 48 | 49 | 50 | @pytest.mark.dependency() 51 | def test_create_file(openai_client: OpenAI, minio_client: Minio): 52 | with open(test_txt_file_path, 'rb') as file: 53 | response = openai_client.files.create(file=file, purpose="assistants") 54 | assert isinstance(response, FileObject) 55 | assert response.id is not None 56 | assert response.bytes is not None 57 | assert response.created_at is not None 58 | assert response.filename == "test.txt" 59 | assert response.purpose == "assistants" 60 | 61 | if not use_openai: 62 | file_stat = minio_client.stat_object(BUCKET_NAME, response.id) 63 | assert file_stat.size > 1800 64 | assert file_stat.metadata["x-amz-meta-filename"] == "test.txt" 65 | 66 | 67 | def test_create_file_pdf(openai_client: OpenAI): 68 | with open(test_pdf_file_path, 'rb') as file: 69 | response = openai_client.files.create(file=file, purpose="assistants") 70 | assert isinstance(response, FileObject) 71 | assert response.id is not None 72 | assert response.bytes is not None 73 | assert response.created_at is not None 74 | assert response.filename == "test.pdf" 75 | assert response.purpose == "assistants" 76 | 77 | 78 | @pytest.mark.dependency(depends=["test_create_file"]) 79 | def test_retrieve_file(openai_client: OpenAI): 80 | # Assuming you have a file ID to test with 81 | with open(test_txt_file_path, 'rb') as file: 82 | file_created = openai_client.files.create( 83 | file=file, purpose="assistants" 84 | ) 85 | response = openai_client.files.retrieve(file_created.id) 86 | 87 | assert isinstance(response, FileObject) 88 | assert response.id == file_created.id 89 | assert response.bytes is not None 90 | assert response.created_at is not None 91 | assert response.filename == "test.txt" 92 | assert response.purpose == "assistants" 93 | 94 | 95 | @pytest.mark.dependency(depends=["test_create_file", "test_retrieve_file"]) 96 | def test_delete_file(openai_client: OpenAI): 97 | # Step 1: Create a file 98 | with open(test_txt_file_path, 'rb') as file: 99 | create_response = openai_client.files.create( 100 | file=file, purpose="assistants" 101 | ) 102 | assert create_response.id is not None 103 | 104 | # Step 2: Retrieve the created file 105 | retrieve_response = openai_client.files.retrieve(create_response.id) 106 | assert retrieve_response.id == create_response.id 107 | 108 | # Step 3: Delete the file 109 | delete_response = openai_client.files.delete(create_response.id) 110 | assert delete_response.deleted is True 111 | assert delete_response.id == create_response.id 112 | 113 | # Step 4: Attempt to retrieve the deleted file 114 | with pytest.raises(Exception): 115 | openai_client.files.retrieve(create_response.id) 116 | -------------------------------------------------------------------------------- /run_executor_worker/src/actions/file_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from utils.weaviate_utils import retrieve_file_chunks 3 | from utils.ops_api_handler import create_retrieval_runstep 4 | from utils.openai_clients import litellm_client, assistants_client 5 | from openai.types.beta.vector_store import VectorStore 6 | from data_models import run 7 | import json 8 | import os 9 | 10 | # import coala 11 | from agents import coala 12 | 13 | 14 | class FileSearch: 15 | def __init__( 16 | self, 17 | coala_class: "coala.CoALA", 18 | ): 19 | self.coala_class = coala_class 20 | self.vector_stores: List[VectorStore] = [] 21 | 22 | def retrieve_vector_stores(self): 23 | vector_store_ids = ( 24 | self.coala_class.assistant.tool_resources.file_search.vector_store_ids 25 | ) 26 | for vector_store_id in vector_store_ids: 27 | vector_store = assistants_client.beta.vector_stores.retrieve( 28 | vector_store_id 29 | ) 30 | self.vector_stores.append(vector_store) 31 | 32 | def generate( 33 | self, 34 | ) -> run.RunStep: 35 | # get relevant retrieval query 36 | user_instruction = self.coala_class.compose_user_instruction() 37 | instruction = f"""{user_instruction}Your role is to generate a query for
semantic search to retrieve important information according to current working memory and the available files. 38 | Even if there is no relevant information in the working memory, you should still generate a query to retrieve the most relevant information from the available files. 39 | Only respond with the query itself, NOTHING ELSE. 40 | 41 | """ # noqa 42 | if not len(self.vector_stores): 43 | self.retrieve_vector_stores() 44 | 45 | messages = [ 46 | { 47 | "role": "user", 48 | "content": instruction + self.compose_query_system_prompt(), 49 | }, 50 | ] 51 | response = litellm_client.chat.completions.create( 52 | model=os.getenv( 53 | "LITELLM_MODEL" 54 | ), # Replace with your model of choice 55 | messages=messages, 56 | max_tokens=200, # You may adjust the token limit as necessary 57 | ) 58 | query = response.choices[0].message.content 59 | # TODO: retrieve from db, and delete mock retrieval document 60 | vector_store_ids = ( 61 | self.coala_class.assistant.tool_resources.file_search.vector_store_ids 62 | ) 63 | retrieved_documents = retrieve_file_chunks( 64 | vector_store_ids, 65 | query, 66 | ) 67 | 68 | run_step = create_retrieval_runstep( 69 | self.coala_class.thread_id, 70 | self.coala_class.run_id, 71 | self.coala_class.assistant_id, 72 | retrieved_documents, 73 | ) 74 | return run_step 75 | 76 | def compose_file_list( 77 | self, 78 | ) -> str: 79 | files_names = [] 80 | 81 | 82 | 83 | file_ids = ( 84 | [] 85 | ) # NOTE: this only works natively with OpenGPTs-platform/assistants-api otherwise you need to make sure to manually manage the metadata["_file_ids"] inside the vector stores # noqa 86 | 87 | for vector_store in self.vector_stores: 88 | vector_store_file_ids = ( 89 | json.loads(vector_store.metadata["_file_ids"]) 90 | if "_file_ids" in vector_store.metadata 91 | else [] 92 | ) 93 | file_ids.extend(vector_store_file_ids) 94 | 95 | if not file_ids: 96 | print("\n\nNO FILES AVAILABLE: ", file_ids) 97 | return "" 98 | for file_id in file_ids: 99 | file = assistants_client.files.retrieve(file_id) 100 | files_names.append(f"- {file.filename}") 101 | return "\n".join(files_names) 102 | 103 | def compose_query_system_prompt(self) -> str: 104 | composed_instruction = "" 105 | trace = self.coala_class.compose_react_trace() 106 | 107 | file_list_str = self.compose_file_list() 108 | if file_list_str: 109 | composed_instruction += f"""The files currently available to you are: 110 | {self.compose_file_list()} 111 | 112 | """ 113 | 114 | composed_instruction += f"""Current working memory: 115 | {trace}""" 116 | return composed_instruction 117 | -------------------------------------------------------------------------------- /assistants_api/app/routers/run_router.py: -------------------------------------------------------------------------------- 1 | # FastAPI router for runs 2 | from fastapi import APIRouter, Body, Depends, HTTPException, Path 3 | from sqlalchemy.orm import Session 4 | from lib.db import ( 5 | crud, 6 | schemas, 7 | database, 8 | ) # Import your CRUD handlers, schemas, and models 9 | from lib.mb.broker import RabbitMQBroker, get_broker 10 | from utils.tranformers import db_to_pydantic_run 11 | import json 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.post("/threads/{thread_id}/runs", response_model=schemas.Run) 17 | def create_run( 18 | thread_id: str = Path(..., title="The ID of the thread to run"), 19 | run: schemas.RunContent = Body(..., title="The run content"), 20 | db: Session = Depends(database.get_db), 21 | broker: RabbitMQBroker = Depends(get_broker), 22 |
): 23 | """ 24 | Create a new run within a specified thread. 25 | 26 | This endpoint creates a new run associated with a given thread, using the provided run content. 27 | It ensures that the specified thread exists and validates the run content against the associated assistant. 28 | If the creation is successful, the new run's ID is published to a RabbitMQ queue for further processing. 29 | 30 | Parameters: 31 | - thread_id (str): The ID of the thread in which the run is to be created. 32 | - run (schemas.RunContent): The content of the run, including any specific instructions, model ID, and tools. 33 | 34 | Returns: 35 | - The newly created run as a Pydantic model, conforming to schemas.Run. 36 | 37 | Raises: 38 | - HTTPException: If the run creation fails, an HTTP 500 error is returned with a failure detail. 39 | """ # noqa 40 | db_run = crud.create_run(db=db, thread_id=thread_id, run_params=run) 41 | if db_run is None: 42 | raise HTTPException(status_code=500, detail="Run creation failed") 43 | 44 | # After successful creation, publish the run ID to the RabbitMQ queue 45 | data = {"thread_id": thread_id, "run_id": str(db_run.id)} 46 | message = json.dumps(data) 47 | broker.publish("runs_queue", message) 48 | broker.close_connection() 49 | 50 | return db_to_pydantic_run(db_run) 51 | 52 | 53 | @router.get("/threads/{thread_id}/runs/{run_id}", response_model=schemas.Run) 54 | def read_run( 55 | thread_id: str, run_id: str, db: Session = Depends(database.get_db) 56 | ): 57 | db_run = crud.get_run(db, thread_id=thread_id, run_id=run_id) 58 | if db_run is None: 59 | raise HTTPException(status_code=404, detail="Run not found") 60 | return db_to_pydantic_run(db_run) 61 | 62 | 63 | @router.post( 64 | "/threads/{thread_id}/runs/{run_id}/cancel", response_model=schemas.Run 65 | ) 66 | def cancel_run( 67 | thread_id: str, run_id: str, db: Session = Depends(database.get_db) 68 | ): 69 | run = crud.cancel_run(db, thread_id=thread_id, run_id=run_id) 70 | if run is None: 71 | raise HTTPException(status_code=404, detail="Run not found") 72 | return db_to_pydantic_run(run) 73 | 74 | 75 | @router.post( 76 | "/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", 77 | response_model=schemas.Run, 78 | ) 79 | def submit_tool_outputs( 80 | *, 81 | thread_id: str = Path( 82 | ..., description="The ID of the thread to which this run belongs." 83 | ), 84 | run_id: str = Path( 85 | ..., 86 | description="The ID of the run that requires the tool output submission.", 87 | ), 88 | body: schemas.SubmitToolOutputsRunRequest = Body( 89 | ..., description="Request body containing tool outputs." 
90 | ), 91 | db: Session = Depends(database.get_db), 92 | broker: RabbitMQBroker = Depends(get_broker), 93 | ): 94 | # Logic to handle the submission of tool outputs 95 | # This will involve updating the database and performing necessary actions 96 | try: 97 | db_run = crud.submit_tool_outputs( 98 | db=db, 99 | thread_id=thread_id, 100 | run_id=run_id, 101 | tool_outputs=body.tool_outputs, 102 | ) 103 | # After successful creation, publish the run ID to the RabbitMQ queue 104 | data = {"thread_id": thread_id, "run_id": str(db_run.id)} 105 | message = json.dumps(data) 106 | broker.publish("runs_queue", message) 107 | broker.close_connection() 108 | 109 | return db_to_pydantic_run(db_run) 110 | 111 | except Exception as e: 112 | raise HTTPException(status_code=400, detail=str(e)) 113 | -------------------------------------------------------------------------------- /run_executor_worker/src/actions/function_calling_tool.py: -------------------------------------------------------------------------------- 1 | import json 2 | from utils.ops_api_handler import ( 3 | create_function_runstep, 4 | create_message_runstep, 5 | ) 6 | from utils.openai_clients import fc_client 7 | from openai.types.beta.threads.runs.function_tool_call import Function 8 | from data_models import run 9 | import os 10 | from agents import coala 11 | from openai.types.chat.chat_completion_message import ( 12 | ChatCompletionMessage, 13 | ) 14 | 15 | 16 | class FunctionCallingTool: 17 | def __init__( 18 | self, 19 | coala_class: "coala.CoALA", 20 | ): 21 | self.coala_class = coala_class 22 | self.function_tools = [ 23 | tool.model_dump() 24 | for tool in self.coala_class.assistant.tools 25 | if tool.type == "function" 26 | ] 27 | 28 | def generate_tool_call(self) -> run.RunStep: 29 | # get all tools of type function 30 | instructions = """Your role is to call one of the functions that are provided to you.\n""" # noqa 31 | tool_call = fc_client.chat.completions.create( 32 | messages=[ 33 | { 34 | "role": "user", 35 | "content": instructions 36 | + self.compose_query_system_prompt(), 37 | } 38 | ], 39 | model=os.getenv("FC_MODEL"), 40 | tools=self.function_tools, 41 | ) 42 | print("\n\ntool_call:\n", tool_call) 43 | function = tool_call.choices[0].message.tool_calls[0].function 44 | # cast to run steps function 45 | function = Function( 46 | name=function.name, 47 | arguments=function.arguments, 48 | ) 49 | # create run step 50 | run_step = create_function_runstep( 51 | self.coala_class.thread_id, 52 | self.coala_class.run_id, 53 | self.coala_class.assistant_id, 54 | function, 55 | ) 56 | 57 | return run_step 58 | 59 | def generate_tool_summary(self, function_step: run.RunStep) -> run.RunStep: 60 | generator_messages = [ 61 | {"role": message.role, "content": message.content[0].text.value} 62 | for message in self.coala_class.messages.data 63 | ] 64 | # add function call and response to messages 65 | tool_calls = [] 66 | tool_results = [] 67 | for tool_call in function_step.step_details.tool_calls: 68 | tool_calls.append( 69 | ChatCompletionMessage( 70 | role="assistant", 71 | tool_calls=[ 72 | { 73 | "id": tool_call.id, 74 | "function": { 75 | "arguments": tool_call.function.arguments, 76 | "name": tool_call.function.name, 77 | }, 78 | "type": tool_call.type, 79 | } 80 | ], 81 | ) 82 | ) 83 | tool_results.append( 84 | { 85 | "role": "tool", 86 | "tool_call_id": tool_call.id, 87 | "name": tool_call.function.name, 88 | "content": tool_call.function.output, 89 | } 90 | ) 91 | 92 | generator_messages += tool_calls +
tool_results 93 | 94 | fc_summary = fc_client.chat.completions.create( 95 | messages=generator_messages, 96 | model=os.getenv("FC_MODEL"), 97 | tools=self.function_tools, 98 | ) 99 | 100 | print("\n\nfc_summary:\n", fc_summary) 101 | 102 | message_rs = create_message_runstep( 103 | self.coala_class.thread_id, 104 | self.coala_class.run_id, 105 | self.coala_class.assistant_id, 106 | fc_summary.choices[0].message.content, 107 | ) 108 | react_step = coala.ReactStep( 109 | step_type=coala.ReactStepType.THOUGHT, 110 | content=json.dumps(fc_summary.choices[0].message.content), 111 | ) 112 | self.coala_class.react_steps.append(react_step) 113 | 114 | return message_rs 115 | 116 | def compose_query_system_prompt(self) -> str: 117 | trace = self.coala_class.compose_react_trace() 118 | 119 | composed_instruction = f"""Current working memory: 120 | {trace}""" 121 | return composed_instruction 122 | -------------------------------------------------------------------------------- /assistants_api/app/routers/ops/web_retrieval_ops_router.py: -------------------------------------------------------------------------------- 1 | # ops/web_retrieval.py 2 | from fastapi import APIRouter, Body, HTTPException 3 | from utils.crawling import ( 4 | crawl_websites, 5 | content_preprocess, 6 | ) 7 | from lib.wv.client import client 8 | import weaviate 9 | from lib.db import schemas 10 | 11 | router = APIRouter() 12 | 13 | COLLECTION_NAME = "web_retrieval" 14 | DEFAULT_WEB_RETRIEVAL_DESCRIPTION = "web_retrieval has not been initiated yet. Do not use this tool. To initiate it use `client.ops.web_retrieval.crawl_and_upsert(...)`" # noqa 15 | 16 | 17 | async def success_callback( 18 | crawl_info: schemas.CrawlInfo, collection: weaviate.collections.Collection 19 | ): 20 | print(f"Callback for URL: {crawl_info.url}\n") 21 | try: 22 | collection.data.delete_many( 23 | where=weaviate.classes.query.Filter.by_property("url").equal( 24 | crawl_info.url 25 | ) 26 | ) 27 | processed_data = content_preprocess(crawl_info) 28 | data_to_insert = [ 29 | {"url": info.url, "content": info.content, "depth": info.depth} 30 | for info in processed_data 31 | ] 32 | collection.data.insert_many(data_to_insert) 33 | except Exception as e: 34 | print(f"Error during callback for URL {crawl_info.url}: {e}") 35 | 36 | 37 | @router.post("/ops/web_retrieval", response_model=schemas.WebRetrievalResponse) 38 | async def start_crawl( 39 | data: schemas.WebRetrievalCreate = Body( 40 | ..., title="Root URLs and max depth" 41 | ), 42 | ): 43 | if data.description == DEFAULT_WEB_RETRIEVAL_DESCRIPTION: 44 | data.description = "Web Retrieval contains information scraped from specific website domains. Use this when precise information in a website may need to be retrieved." # noqa 45 | print( 46 | f"\n\nWARNING: WEB_RETRIEVAL_DESCRIPTION is not set. 
Defaulting to \"{data.description}\"" # noqa 47 | ) # noqa 48 | collection = client.collections.get(name=COLLECTION_NAME) 49 | if data.description: 50 | collection.config.update(description=data.description) 51 | 52 | print("Starting web retrieval...") 53 | try: 54 | crawl_infos = await crawl_websites( 55 | data.root_urls, 56 | data.constrain_to_root_domain, 57 | data.max_depth, 58 | lambda x: success_callback(x, collection), 59 | ) 60 | 61 | print(f"\n\nTotal crawls: {len(crawl_infos)}") 62 | no_error_crawls = [c for c in crawl_infos if c.error is None] 63 | print(f"Successful crawls count: {len(no_error_crawls)}") 64 | 65 | # clear content from crawl_infos 66 | for crawl_info in crawl_infos: 67 | crawl_info.content = "" 68 | return schemas.WebRetrievalResponse( 69 | message="Crawling completed successfully.", 70 | crawl_infos=crawl_infos, 71 | ) 72 | except Exception as e: 73 | raise HTTPException(status_code=500, detail=str(e)) 74 | 75 | 76 | # Behaves more like a reset: the collection is deleted and recreated empty 77 | @router.delete("/ops/web_retrieval", response_model=schemas.DeleteResponse) 78 | async def delete_collection(): 79 | try: 80 | if client.collections.exists(name=COLLECTION_NAME): 81 | client.collections.delete(name=COLLECTION_NAME) 82 | # recreate the collection with no items 83 | del_res = schemas.DeleteResponse( 84 | message=f"Collection '{COLLECTION_NAME}' deleted successfully." 85 | ) 86 | else: 87 | del_res = schemas.DeleteResponse( 88 | message=f"Collection '{COLLECTION_NAME}' does not exist." 89 | ) 90 | except Exception as e: 91 | del_res = schemas.DeleteResponse(message=f"Error: {str(e)}") 92 | client.collections.create( 93 | name=COLLECTION_NAME, 94 | description=DEFAULT_WEB_RETRIEVAL_DESCRIPTION, 95 | generative_config=weaviate.classes.config.Configure.Generative.openai(), 96 | properties=[ 97 | weaviate.classes.config.Property( 98 | name="url", data_type=weaviate.classes.config.DataType.TEXT 99 | ), 100 | weaviate.classes.config.Property( 101 | name="content", 102 | data_type=weaviate.classes.config.DataType.TEXT, 103 | ), 104 | weaviate.classes.config.Property( 105 | name="depth", 106 | data_type=weaviate.classes.config.DataType.NUMBER, 107 | ), 108 | ], 109 | vectorizer_config=[ 110 | weaviate.classes.config.Configure.NamedVectors.text2vec_openai( 111 | name="content_and_url", 112 | source_properties=["content", "url"], 113 | ) 114 | ], 115 | ) 116 | return del_res 117 | -------------------------------------------------------------------------------- /assistants_api/app/routers/assistant_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException, Query 2 | from typing import Optional 3 | from utils.tranformers import db_to_pydantic_assistant 4 | 5 | from lib.db.database import get_db 6 | from sqlalchemy.orm import Session 7 | from lib.db import crud, schemas 8 | 9 | router = APIRouter() 10 | 11 | 12 | @router.post("/assistants", response_model=schemas.Assistant) 13 | def create_assistant( 14 | assistant: schemas.AssistantCreate, 15 | db: Session = Depends(get_db), 16 | ): 17 | """ 18 | Create a new assistant. 19 | - **model**: ID of the model to use. 20 | - **name**: The name of the assistant. 21 | - **description**: The description of the assistant. 22 | - **instructions**: The system instructions that the assistant uses. 23 | - **tools**: A list of tools enabled on the assistant. 24 | - **file_ids**: A list of file IDs attached to this assistant.
25 | - **metadata**: Set of 16 key-value pairs that can be attached to the assistant. 26 | """ 27 | 28 | db_assistant = crud.create_assistant(db=db, assistant=assistant) 29 | return db_to_pydantic_assistant(db_assistant) 30 | 31 | 32 | @router.get( 33 | "/assistants", response_model=schemas.SyncCursorPage[schemas.Assistant] 34 | ) 35 | def list_assistants( 36 | db: Session = Depends(get_db), 37 | limit: int = Query(default=20, le=100), 38 | order: str = Query(default="desc", regex="^(asc|desc)$"), 39 | after: Optional[str] = None, 40 | before: Optional[str] = None, 41 | ): 42 | """ 43 | List assistants with optional pagination and ordering. 44 | - **limit**: Maximum number of results to return. 45 | - **order**: Sort order based on the creation time ('asc' or 'desc'). 46 | - **after**: ID to start the list from (for pagination). 47 | - **before**: ID to list up to (for pagination). 48 | """ 49 | db_assistants = crud.get_assistants( 50 | db=db, limit=limit, order=order, after=after, before=before 51 | ) 52 | 53 | assistants = [ 54 | db_to_pydantic_assistant(assistant) for assistant in db_assistants 55 | ] 56 | paginated_assistants = schemas.SyncCursorPage(data=assistants) 57 | 58 | return paginated_assistants 59 | 60 | 61 | @router.get("/assistants/{assistant_id}", response_model=schemas.Assistant) 62 | def get_assistant(assistant_id: str, db: Session = Depends(get_db)): 63 | """ 64 | Retrieves an assistant by its unique ID. 65 | 66 | - **assistant_id**: UUID of the assistant to retrieve. 67 | """ 68 | db_assistant = crud.get_assistant_by_id(db=db, assistant_id=assistant_id) 69 | if db_assistant is None: 70 | raise HTTPException(status_code=404, detail="No assistant found") 71 | return db_to_pydantic_assistant(db_assistant) 72 | 73 | 74 | @router.post("/assistants/{assistant_id}", response_model=schemas.Assistant) 75 | def update_assistant( 76 | assistant_id: str, 77 | assistant_update: schemas.AssistantUpdate, 78 | db: Session = Depends(get_db), 79 | ): 80 | """ 81 | Updates specified fields of an existing assistant. 82 | 83 | - **assistant_id**: UUID of the assistant to update. 84 | - **assistant_update**: JSON body containing fields to update on the assistant. 85 | - `name`: Optional. New name of the assistant (max length: 256). 86 | - `description`: Optional. New description of the assistant (max length: 512). 87 | - `model`: Optional. Model ID to use for the assistant (max length: 256). 88 | - `instructions`: Optional. System instructions for the assistant (max length: 32768). 89 | - `tools`: Optional. List of tools enabled on the assistant. 90 | - `metadata`: Optional. Metadata key-value pairs attached to the assistant. 91 | """ # noqa 92 | # Update the assistant with new values 93 | updated_assistant = crud.update_assistant( 94 | db=db, 95 | assistant_id=assistant_id, 96 | assistant_update=assistant_update.model_dump(exclude_none=True), 97 | ) 98 | 99 | if updated_assistant is None: 100 | raise HTTPException(status_code=404, detail="No assistant found") 101 | 102 | return db_to_pydantic_assistant(updated_assistant) 103 | 104 | 105 | @router.delete( 106 | "/assistants/{assistant_id}", 107 | response_model=schemas.AssistantDeleted, 108 | ) 109 | def delete_assistant( 110 | assistant_id: str, 111 | db: Session = Depends(get_db), 112 | ): 113 | """ 114 | Deletes an assistant by its unique ID. 115 | 116 | - **assistant_id**: UUID of the assistant to delete. 
117 | """ 118 | deletion_success = crud.delete_assistant(db=db, assistant_id=assistant_id) 119 | if not deletion_success: 120 | raise HTTPException(status_code=404, detail="No assistant found") 121 | return {"id": assistant_id, "deleted": True, "object": "assistant.deleted"} 122 | -------------------------------------------------------------------------------- /run_executor_worker/src/agents/router.py: -------------------------------------------------------------------------------- 1 | from constants import PromptKeys 2 | from utils.context import context_trimmer 3 | from utils.openai_clients import ( 4 | fc_client, 5 | litellm_client, 6 | ChatCompletion, 7 | ) 8 | import os 9 | from run_executor import main 10 | import json 11 | 12 | 13 | class RouterAgent: 14 | def __init__( 15 | self, 16 | execute_run_class: "main.ExecuteRun", 17 | ): 18 | self.execute_run_class = execute_run_class 19 | 20 | def compose_system_prompt(self) -> str: 21 | return f"""USER_INSTRUCTION:```{self.execute_run_class.assistant.instructions}```""" # noqa 22 | 23 | # TODO: add assistant and base tools off of assistant 24 | def generate( 25 | self, 26 | ) -> str: 27 | """ 28 | Generates a response based on the run's chat history and the assistant's instructions. 29 | 30 | Uses: 31 | self.execute_run_class.tools_map: The tools available to the agent. 32 | self.execute_run_class.messages: The chat history. 33 | 34 | Returns: 35 | str: It either returns `{PromptKeys.TRANSITION.value}` or a generated response. 36 | """ # noqa 37 | 38 | # Build messages to send to the model 39 | cleaned_messages = [] 40 | for message in self.execute_run_class.messages.data: 41 | cleaned_messages.append( 42 | { 43 | "role": message.role, 44 | "content": message.content[0].text.value, 45 | } 46 | ) 47 | 48 | trimmed_messages = cleaned_messages 49 | if self.execute_run_class.run.max_prompt_tokens: 50 | trimmed_messages = context_trimmer( 51 | item_list=cleaned_messages, 52 | max_length=self.execute_run_class.run.max_prompt_tokens * 3, 53 | trim_start=True, 54 | ) 55 | 56 | messages = [ 57 | { 58 | "role": "system", 59 | "content": self.compose_system_prompt(), 60 | } 61 | ] + trimmed_messages 62 | try: 63 | tools_list = "\n".join( 64 | [ 65 | f"- {tool.type}: {tool.description}" 66 | for _, tool in self.execute_run_class.tools_map.items() 67 | ] 68 | ) 69 | tools_needed_response: ChatCompletion = fc_client.chat.completions.create( 70 | model=os.getenv("FC_MODEL"), 71 | messages=messages, 72 | tools=[ 73 | { 74 | 'type': 'function', 75 | 'function': { 76 | 'name': 'determine_tools_needed', 77 | 'description': f"""The following tools are available to you:```{tools_list}``` 78 | Determine if those tools are needed to respond to the user's message.""", # noqa 79 | 'parameters': { 80 | 'type': 'object', 81 | 'properties': { 82 | 'tools_needed': { 83 | 'type': 'boolean', 84 | 'description': 'Are the tools necessary.', # noqa 85 | } 86 | }, 87 | 'required': ['tools_needed'], 88 | }, 89 | }, 90 | } 91 | ], 92 | max_tokens=28, 93 | tool_choice={ 94 | "type": "function", 95 | "function": {"name": "determine_tools_needed"}, 96 | }, 97 | ) 98 | 99 | # parse the response to get the arguments 100 | print( 101 | "\n\nTool needed response:\n", 102 | tools_needed_response.choices[0] 103 | .message.tool_calls[0] 104 | .function, 105 | ) 106 | tools_needed_args = json.loads( 107 | tools_needed_response.choices[0] 108 | .message.tool_calls[0] 109 | .function.arguments 110 | ) 111 | if tools_needed_args["tools_needed"]: 112 | return PromptKeys.TRANSITION.value 113 | else: 114 | pass
115 | except Exception as e: 116 | print("Error with tools_needed_response:", e) 117 | 118 | response = litellm_client.chat.completions.create( 119 | model=os.getenv("LITELLM_MODEL"), 120 | messages=messages, 121 | max_tokens=2000, 122 | ) 123 | 124 | print("GENERATION: ", response.choices[0].message.content) 125 | 126 | return response.choices[0].message.content 127 | -------------------------------------------------------------------------------- /assistants_api/tests/ops/test_runsteps_ops.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import pytest 3 | from openai import OpenAI 4 | from openai.types.beta.threads.runs import RunStep 5 | import os 6 | import time 7 | 8 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 9 | 10 | 11 | @pytest.fixture 12 | def openai_client(): 13 | return OpenAI( 14 | base_url="http://localhost:8000", 15 | api_key=api_key, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def thread_id(openai_client: OpenAI): 21 | thread_metadata = {"example_key": "example_value"} 22 | response = openai_client.beta.threads.create(metadata=thread_metadata) 23 | return response.id 24 | 25 | 26 | @pytest.fixture 27 | def assistant_id(openai_client: OpenAI): 28 | response = openai_client.beta.assistants.create( 29 | instructions="You are an AI designed to provide examples.", 30 | name="Example Assistant", 31 | tools=[{"type": "code_interpreter"}], 32 | model="gpt-3.5-turbo", 33 | ) 34 | return response.id 35 | 36 | 37 | @pytest.fixture 38 | def run_id(openai_client: OpenAI, thread_id: str, assistant_id: str): 39 | response = openai_client.beta.threads.runs.create( 40 | thread_id=thread_id, 41 | assistant_id=assistant_id, 42 | ) 43 | return response.id 44 | 45 | 46 | @pytest.mark.dependency() 47 | def test_create_run_step(assistant_id: str, thread_id: str, run_id: str): 48 | create_url = ( 49 | f"http://localhost:8000/ops/threads/{thread_id}/runs/{run_id}/steps" 50 | ) 51 | step_data = { 52 | "assistant_id": assistant_id, 53 | "type": "tool_calls", 54 | "status": "in_progress", 55 | "step_details": {"tool_calls": [], "type": "tool_calls"}, 56 | } 57 | 58 | response = requests.post(create_url, json=step_data) 59 | 60 | # Verify the response status code and the returned fields 61 | assert response.status_code == 200 62 | created_step = response.json() 63 | created_step = RunStep(**created_step) 64 | assert created_step.status == "in_progress" 65 | assert created_step.type == "tool_calls" 66 | assert created_step.assistant_id == assistant_id 67 | assert created_step.thread_id == thread_id 68 | assert created_step.run_id == run_id 69 | assert created_step.id is not None 70 | 71 | return created_step.id 72 | 73 | 74 | @pytest.mark.dependency(depends=["test_create_run_step"]) 75 | def test_update_run_step(assistant_id: str, thread_id: str, run_id: str): 76 | step_id = test_create_run_step(assistant_id, thread_id, run_id) 77 | update_url = f"http://localhost:8000/ops/threads/{thread_id}/runs/{run_id}/steps/{step_id}" # noqa 78 | curr_time = int(time.time()) 79 | update_data = { 80 | "status": "completed", 81 | "completed_at": curr_time, 82 | } 83 | 84 | response = requests.post(update_url, json=update_data) 85 | 86 | # Verify the response status code and the updated fields 87 | assert response.status_code == 200 88 | updated_step = response.json() 89 | updated_step = RunStep(**updated_step) 90 | assert updated_step.status == "completed" 91 | assert updated_step.completed_at == curr_time 92 | assert 
updated_step.id == step_id
93 | 
94 | 
95 | # only works with access to steps (therefore does not work with OpenAI hosted Assistants API)  # noqa
96 | @pytest.mark.dependency(depends=["test_create_run_step"])
97 | def test_create_multiple_run_steps(
98 |     openai_client: OpenAI, assistant_id: str, thread_id: str, run_id: str
99 | ):
100 |     create_url = (
101 |         f"http://localhost:8000/ops/threads/{thread_id}/runs/{run_id}/steps"
102 |     )
103 |     step_data_tool = {
104 |         "assistant_id": assistant_id,
105 |         "type": "tool_calls",
106 |         "status": "in_progress",
107 |         "step_details": {"tool_calls": [], "type": "tool_calls"},
108 |     }
109 |     step_data_message = {
110 |         "assistant_id": assistant_id,
111 |         "type": "message_creation",
112 |         "status": "in_progress",
113 |         "step_details": {
114 |             "message_creation": {"message_id": "msg_6iTjazdBj74xg3yVbjrZye9P"},
115 |             "type": "message_creation",
116 |         },
117 |     }
118 | 
119 |     requests.post(create_url, json=step_data_tool)
120 |     requests.post(create_url, json=step_data_message)
121 | 
122 |     response = openai_client.beta.threads.runs.steps.list(
123 |         thread_id=thread_id, run_id=run_id
124 |     )
125 |     assert response is not None
126 |     assert len(response.data) == 2
127 |     assert isinstance(response.data[0], RunStep)
128 |     assert response.data[0].id is not None
129 |     assert response.data[0].step_details.type == "tool_calls"
130 |     assert response.data[1].id is not None
131 |     assert response.data[1].step_details.type == "message_creation"
132 |     assert (
133 |         response.data[1].step_details.message_creation.message_id
134 |         == "msg_6iTjazdBj74xg3yVbjrZye9P"
135 |     )
136 | 
--------------------------------------------------------------------------------
/assistants_api/app/utils/crawling.py:
--------------------------------------------------------------------------------
1 | # crawling.py
2 | import httpx
3 | from bs4 import BeautifulSoup
4 | from urllib.parse import urljoin, urlparse
5 | import asyncio
6 | import fitz  # PyMuPDF
7 | from langchain_text_splitters import (
8 |     RecursiveCharacterTextSplitter,
9 |     HTMLHeaderTextSplitter,
10 | )
11 | from lib.db.schemas import CrawlInfo
12 | 
13 | 
14 | async def fetch_url(client, url, current_depth, retries=1, timeout=10.0):
15 |     for attempt in range(retries):
16 |         try:
17 |             response = await client.get(url, timeout=timeout)
18 |             response.raise_for_status()
19 | 
20 |             content_type = response.headers.get("content-type", "").lower()
21 |             if "application/pdf" in content_type:
22 |                 print(f"Fetched PDF {url} at depth {current_depth}")
23 |                 pdf_content = await fetch_pdf_content(response.content)
24 |                 return pdf_content, None  # Return content with no error
25 |             else:
26 |                 print(f"Fetched HTML {url} at depth {current_depth}")
27 |                 return response.text, None  # Return content with no error
28 |         except (
29 |             httpx.RequestError,
30 |             httpx.HTTPStatusError,
31 |             httpx.TimeoutException,
32 |         ) as e:
33 |             print(f"Error fetching {url} (attempt {attempt + 1} of {retries}): {e}")
34 |             if attempt == retries - 1:  # only give up once all attempts fail
35 |                 return None, str(e)  # Return no content with error message
36 | 
37 |     print(f"Failed to fetch {url} after {retries} attempts.")
38 |     return None, "Failed after retries"
39 | 
40 | 
41 | async def fetch_pdf_content(pdf_bytes):
42 |     try:
43 |         document = fitz.open(stream=pdf_bytes, filetype="pdf")
44 |         text = ""
45 |         for page in document:
46 |             text += page.get_text()
47 |         return text
48 |     except Exception as e:
49 |         print(f"Error processing PDF: {e}")
50 |         return None
51 | 
52 | 
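# The helpers below implement a breadth-first crawl: process_url fetches a
# single URL and extracts its outbound links, while crawl_websites drives the
# frontier queue one depth level at a time, deduplicating via the shared
# `visited` set.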
53 | async def process_url(
54 |     client,
55 |     url,
56 |     current_depth,
57 |     root_url,
58 |     visited,
59 |     max_depth,
60 |     constrain_to_root_domain,
61 |     success_callback,
62 | ):
63 |     if url in visited or (max_depth is not None and current_depth > max_depth):
64 |         return None, None, None
65 | 
66 |     visited.add(url)
67 |     content, error = await fetch_url(client, url, current_depth)
68 |     if content is None and error is not None:
69 |         crawl_info = CrawlInfo(
70 |             url=url, content="", error=error, depth=current_depth
71 |         )
72 |         return crawl_info, root_url, None
73 | 
74 |     crawl_info = CrawlInfo(url=url, content=content, depth=current_depth)
75 |     if success_callback:
76 |         await success_callback(crawl_info)  # Call the success callback
77 | 
78 |     if not content.startswith(
79 |         "%PDF"
80 |     ):  # PDF content will not contain HTML links
81 |         soup = BeautifulSoup(content, "lxml")
82 |         links = [
83 |             urljoin(url, a_tag["href"])
84 |             for a_tag in soup.find_all("a", href=True)
85 |         ]
86 |         if constrain_to_root_domain:
87 |             valid_links = [
88 |                 link
89 |                 for link in links
90 |                 if urlparse(link).netloc == urlparse(root_url).netloc
91 |             ]
92 |         else:
93 |             valid_links = links
94 |         print(f"\nFound {len(valid_links)} valid links on {url}")
95 |         return crawl_info, root_url, valid_links
96 |     else:
97 |         return crawl_info, root_url, []
98 | 
99 | 
100 | async def crawl_websites(
101 |     root_urls,
102 |     constrain_to_root_domain,
103 |     max_depth=None,
104 |     success_callback=None,
105 | ) -> list[CrawlInfo]:
106 |     visited = set()
107 |     queue = [(url, url, 0) for url in root_urls]  # (url, root_url, depth)
108 |     all_data = []
109 | 
110 |     async with httpx.AsyncClient() as client:
111 |         while queue:
112 |             print("\n\nQueue:\n", queue)
113 |             tasks = [
114 |                 process_url(
115 |                     client,
116 |                     url,
117 |                     depth,
118 |                     root_url,
119 |                     visited,
120 |                     max_depth,
121 |                     constrain_to_root_domain,
122 |                     success_callback,
123 |                 )
124 |                 for url, root_url, depth in queue
125 |             ]
126 |             results = await asyncio.gather(*tasks)
127 | 
128 |             new_queue = []
129 |             for data, root_url, links in results:
130 |                 if data is not None:
131 |                     all_data.append(data)
132 |                 if links is not None:
133 |                     new_queue.extend(
134 |                         (link, root_url, data.depth + 1) for link in links
135 |                     )
136 | 
137 |             queue = new_queue
138 | 
139 |     return all_data
140 | 
141 | 
142 | # Split crawled content into chunks suitable for embedding and upsert
143 | def content_preprocess(crawl_info: CrawlInfo):
144 |     chunk_size = 2000
145 |     chunk_overlap = 200
146 |     documents = []
147 |     if crawl_info.url.endswith(".pdf"):
148 |         text_splitter = RecursiveCharacterTextSplitter(
149 |             chunk_size=chunk_size,
150 |             chunk_overlap=chunk_overlap,
151 |             length_function=len,
152 |             is_separator_regex=False,
153 |         )
154 | 
155 |         documents = text_splitter.create_documents([crawl_info.content])
156 | 
157 |     else:  # otherwise treat as HTML content
158 |         headers_to_split_on = [("h1", "Header 1"), ("h2", "Header 2")]
159 |         splitter = HTMLHeaderTextSplitter(
160 |             headers_to_split_on=headers_to_split_on
161 |         )
162 |         header_split_docs = splitter.split_text(crawl_info.content)
163 | 
164 |         text_splitter = RecursiveCharacterTextSplitter(
165 |             chunk_size=chunk_size, chunk_overlap=chunk_overlap
166 |         )
167 |         documents = text_splitter.split_documents(header_split_docs)
168 | 
169 |     crawl_info_docs = [
170 |         CrawlInfo(
171 |             url=crawl_info.url,
172 |             content=doc.page_content,
173 |             depth=crawl_info.depth,
174 |         )
175 |         for doc in documents
176 |     ]
177 | 
178 |     return crawl_info_docs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # assistants-api
2 | ### [DISCORD](https://discord.gg/jZSVhtwTz6)
3 | ### [Video Demo](https://www.youtube.com/watch?v=cAKcsK7uYro&t=19s)
4 | 
5 | Replicate and improve the OpenAI Assistants API
6 | Note: currently supports the client `openai==1.26.0` (excluding custom tools like `web_retrieval`), or use our fork via `pip install git+https://github.com/OpenGPTs-platform/openai-python.git`
7 | 
8 | 
9 | ### [Video Demo (full OpenGPTs-platform)](https://youtu.be/yPdIEKb3jWc)
10 | 
11 | Architecture
12 | ![image](https://github.com/OpenGPTs-platform/assistants-api/assets/37946988/faa5a4b2-1186-49b8-a80b-39c4fc00b772)
13 | 
14 | ### Quickstart
15 | 0. Clone the repo `git clone https://github.com/OpenGPTs-platform/assistants-api.git` and navigate into the `assistants-api` directory
16 | 1. Create a copy of [`.env.example`](./.env.example) and name it `.env`. Fill in the necessary values.
17 | 2. Start docker-compose with `docker-compose -f .\docker-compose.dev.yml up`
18 | 3. It's running 🥳!
19 | 4. In a new directory and environment, install the `openai` client fork with `pip install IPython git+https://github.com/OpenGPTs-platform/openai-python.git`, and try it with the following demo (NOTE: update `YOUR_FILE_PATH` with the path of the file you want to test retrieval with).
20 | ```py
21 | from openai import OpenAI
22 | client = OpenAI(
23 |     base_url="http://localhost:8000",
24 |     api_key="NO_KEY_NEEDED",
25 | )
26 | 
27 | # Upload file to the server
28 | file = client.files.create(
29 |     file=open('YOUR_FILE_PATH', 'rb'),  # Input your file path (currently accepts .txt and .pdf files)
30 |     purpose='assistants'
31 | )
32 | 
33 | # Create vector store with the file id
34 | vs = client.beta.vector_stores.create(
35 |     name='my_info',
36 |     file_ids=[file.id]
37 | )
38 | 
39 | # Create an assistant with the vector store passed in
40 | asst = client.beta.assistants.create(
41 |     name="Demo Assistant",
42 |     instructions="Always start your responses with the word 'APPLE'",
43 |     model="gpt-4-turbo",
44 |     tools=[{"type": "file_search"}, {"type": "web_retrieval"}],
45 |     tool_resources={
46 |         "file_search": {
47 |             "vector_store_ids": [vs.id]
48 |         }
49 |     },
50 | )
51 | 
52 | # Create a thread with or without messages
53 | thr = client.beta.threads.create(
54 |     messages=[
55 |         {
56 |             "role": "user",
57 |             "content": "I am curious what is in the file I provided"
58 |         }
59 |     ],
60 | )
61 | 
62 | # Execute the run (adds the run to RabbitMQ to be dequeued and processed by run_executor_worker)
63 | run = client.beta.threads.runs.create(
64 |     thread_id=thr.id,
65 |     assistant_id=asst.id
66 | )
67 | 
68 | # Poll the response until complete (streaming not yet supported)
69 | from IPython.display import clear_output
70 | import time
71 | while run.status not in ['completed', 'failed']:
72 |     time.sleep(1)
73 |     clear_output(wait=True)
74 |     run = client.beta.threads.runs.retrieve(thread_id=thr.id, run_id=run.id)
75 |     print("RUN STATUS:\n", run.status)
76 |     messages = client.beta.threads.messages.list(thread_id=thr.id, order='desc')
77 |     print("THREAD MESSAGES:\n", messages.model_dump_json(indent=2))
78 | ```
79 | 
80 | ## [More Comprehensive Demo](./examples/compounding_demo.ipynb)
81 | ## [assistants_api](./assistants_api)
82 | ![image](https://github.com/OpenGPTs-platform/assistants-api/assets/37946988/c5eac63b-b1bb-4504-ab02-4c8814d81e8d)
83 | [_View full Figma spec_](https://www.figma.com/file/RBobTMUNS6EtelpTDyYqnA/Open-GPTs?type=whiteboard&node-id=0%3A1&t=Ga2G6MUOUiNjqe3l-1)
84 | 
85 | Handle the business logic (store and retrieve data, store files, enqueue runs) for the Assistants API according to the official [OpenAI OpenAPI specification](https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml).
86 | 
87 | The [OpenAI OpenAPI specification](https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml) is the source of truth for this API.
88 | 
89 | ## [run_executor_worker](./run_executor_worker)
90 | ![image](https://github.com/OpenGPTs-platform/HexAmerous/assets/37946988/610c60fe-ad01-4231-aec2-84c9a295ed30)
91 | [_View full Figma spec_](https://www.figma.com/file/RBobTMUNS6EtelpTDyYqnA/Open-GPTs?type=whiteboard&node-id=0%3A1&t=Ga2G6MUOUiNjqe3l-1)
92 | 
93 | Agent that executes runs according to the CoALA architecture and the ReAct prompting strategy.
94 | 
95 | ## Major Objectives
96 | 1. Function calling for [run_executor_worker](./run_executor_worker) using [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
97 | 2. Connect [chat-ui](https://github.com/OpenGPTs-platform/chat-ui)
98 | 3. Optimize prompting in [run_executor_worker](./run_executor_worker)
99 | 4. Open-source `web_retrieval` and add `annotations` to messages for citation purposes
100 | 
101 | ## Helper ["Assistants API" Dev Assistant GPT](https://chat.openai.com/g/g-VxH4qXfuJ-assistants-api-assistant)
102 | Helper assistant for developing the "Assistants API". Normally a conversation will flow like so:
103 | ```txt
104 | Human: Lets work on /assistants GET endpoint, begin with a test. Here is an example of what I have so far:
105 | 
106 | 
107 | Assistant: 
108 | 
109 | Human: Ok lets move on to the endpoint. Here is what I have so far:
110 | 
111 | 
112 | 
113 | 
114 | Assistant: 
115 | 
116 | THEN WHEN YOU REPEAT WITH THE CURRENT CHAT YOU SHOULD NOT NEED ALL THE EXAMPLES
117 | ```
118 | 
119 | ### Instruction
120 | 
121 | The user has the goal to build a FastAPI python server according to the OpenAPI specification in your knowledge "openai-openapi-dereferenced.json". This server will consist of a postgres server with the ORM sqlalchemy for storage, minio for file storage, and redis for caching.
122 | 
123 | Your objective is to facilitate the development of the server for the user by following these steps, ALWAYS FOLLOW THESE STEPS IN THE CORRESPONDING ORDER:
124 | 
125 | 1. Using the code_interpreter tool, download "openai-openapi-dereferenced.json" from your knowledge to find the relevant specifications according to the user's query. Programmatically navigate the JSON. DO NOT RECALL FROM YOUR KNOWLEDGE; INSTEAD DOWNLOAD THE FILE IN A PYTHON SCRIPT AND NAVIGATE IT PROGRAMMATICALLY.
126 | 2. Ask the user questions if more information is needed.
127 | 3. Following Test Driven Development methodology, create an e2e test using pytest and the OpenAI client. You will learn how to use the OpenAI Assistants API by visiting this link to their documentation https://platform.openai.com/docs/assistants/overview with the web browsing tool. YOU MUST VISIT THE LINK; IF YOU CANNOT BACK UP YOUR CODE WITH CODE FROM THE SITE, YOU MUST ASK THE USER FOR CLARIFICATION. (The OpenAI client sends a request to the server, which handles the logic and returns a response). YOU MUST WRITE THE TEST.
128 | 4. Ask the user questions if more information is needed.
129 | 5. Create a plan for execution and provide the code.
130 | 
131 | ### Knowledge
132 | 
133 | Add [openai-openapi-dereferenced.json](./assets/openai-openapi-dereferenced.json)
--------------------------------------------------------------------------------
/examples/compounding_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "!pip uninstall openai -y\n",
10 |     "!pip install git+https://github.com/OpenGPTs-platform/openai-python"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "markdown",
15 |    "metadata": {},
16 |    "source": [
17 |     "## Configure Client"
18 |    ]
19 |   },
20 |   {
21 |    "cell_type": "code",
22 |    "execution_count": null,
23 |    "metadata": {},
24 |    "outputs": [],
25 |    "source": [
26 |     "from openai import OpenAI\n",
27 |     "import json\n",
28 |     "\n",
29 |     "client = OpenAI(\n",
30 |     "    base_url=\"http://localhost:8000\",\n",
31 |     "    api_key=\"api_key\",\n",
32 |     ")"
33 |    ]
34 |   },
35 |   {
36 |    "cell_type": "code",
37 |    "execution_count": null,
38 |    "metadata": {},
39 |    "outputs": [],
40 |    "source": [
41 |     "client.ops.web_retrieval.delete()"
42 |    ]
43 |   },
44 |   {
45 |    "cell_type": "markdown",
46 |    "metadata": {},
47 |    "source": [
48 |     "## Base Demo Without Tools"
49 |    ]
50 |   },
51 |   {
52 |    "cell_type": "code",
53 |    "execution_count": null,
54 |    "metadata": {},
55 |    "outputs": [],
56 |    "source": [
57 |     "assistant = client.beta.assistants.create(\n",
58 |     "    instructions=\"\"\"Your job is to assume the role of a news provider and inform the user of current news.\n",
59 |     "Always direct them to where they can learn more by providing the corresponding link.\"\"\",  # noqa\n",
60 |     "    name=\"News Provider Assistant\",\n",
61 |     "    model=\"gpt-3.5-turbo\",\n",
62 |     ")"
63 |    ]
64 |   },
65 |   {
66 |    "cell_type": "code",
67 |    "execution_count": null,
68 |    "metadata": {},
69 |    "outputs": [],
70 |    "source": [
71 |     "thread = client.beta.threads.create(\n",
72 |     "    messages=[\n",
73 |     "        {\n",
74 |     "            \"role\": \"user\",\n",
75 |     "            \"content\": \"Show me interesting space themed news from ycombinator.\",  # noqa\n",
76 |     "        },\n",
77 |     "    ],\n",
78 |     ")\n",
79 |     "run = client.beta.threads.runs.create_and_poll(\n",
80 |     "    thread_id=thread.id,\n",
81 |     "    assistant_id=assistant.id,\n",
82 |     ")\n",
83 |     "messages = client.beta.threads.messages.list(\n",
84 |     "    thread_id=thread.id, order='desc'\n",
85 |     ")\n",
86 |     "print(json.dumps(messages.model_dump(), indent=2))"
87 |    ]
88 |   },
89 |   {
90 |    "cell_type": "markdown",
91 |    "metadata": {},
92 |    "source": [
93 |     "## Demo With `web_retrieval` Tool"
94 |    ]
95 |   },
96 |   {
97 |    "cell_type": "code",
98 |    "execution_count": null,
99 |    "metadata": {},
100 |    "outputs": [],
101 |    "source": [
102 |     "assistant = client.beta.assistants.update(\n",
103 |     "    assistant_id=assistant.id,\n",
104 |     "    tools=[\n",
105 |     "        {\"type\": \"web_retrieval\"}\n",
106 |     "    ],\n",
107 |     ")"
108 |    ]
109 |   },
110 |   {
111 |    "cell_type": "code",
112 |    "execution_count": null,
113 |    "metadata": {},
114 |    "outputs": [],
115 |    "source": [
116 |     "crawl = client.ops.web_retrieval.crawl_and_upsert(\n",
117 |     "    root_urls=[\"https://news.ycombinator.com/\"],\n",
118 |     "    max_depth=1,\n",
119 |     "    description=\"Live news from ycombinator, a news feed centered on science and technology.\",\n",
120 |     "    constrain_to_root_domain=False,\n",
121 |     ")\n",
122 |     "successful_crawls = [ci for ci in crawl.crawl_infos if ci.error is None]\n",
123 |     "print(\"Successful crawls: \", len(successful_crawls))\n",
124 |     "print(crawl.model_dump_json(indent=2))\n"
125 |    ]
126 |   },
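  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: `crawl_and_upsert` is part of the OpenGPTs-platform `openai` client fork (under `client.ops.web_retrieval`), not the official package. It crawls the given root URLs breadth-first up to `max_depth`, chunks the fetched pages, and upserts the chunks into the web-retrieval store; each attempted URL is reported in `crawl_infos` with an `error` field, which is why failures can be filtered out as shown above."
   ]
  },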
127 |   {
128 |    "cell_type": "code",
129 |    "execution_count": null,
130 |    "metadata": {},
131 |    "outputs": [],
132 |    "source": [
133 |     "thread = client.beta.threads.create(\n",
134 |     "    messages=[\n",
135 |     "        {\n",
136 |     "            \"role\": \"user\",\n",
137 |     "            \"content\": \"Show me interesting space themed news from ycombinator.\",  # noqa\n",
138 |     "        },\n",
139 |     "    ],\n",
140 |     ")\n",
141 |     "run = client.beta.threads.runs.create_and_poll(\n",
142 |     "    thread_id=thread.id,\n",
143 |     "    assistant_id=assistant.id,\n",
144 |     ")\n",
145 |     "messages = client.beta.threads.messages.list(\n",
146 |     "    thread_id=thread.id, order='desc'\n",
147 |     ")\n",
148 |     "print(json.dumps(messages.model_dump(), indent=2))"
149 |    ]
150 |   },
151 |   {
152 |    "cell_type": "markdown",
153 |    "metadata": {},
154 |    "source": [
155 |     "## Demo With Tools `web_retrieval` and `file_search`"
156 |    ]
157 |   },
158 |   {
159 |    "cell_type": "code",
160 |    "execution_count": null,
161 |    "metadata": {},
162 |    "outputs": [],
163 |    "source": [
164 |     "file = client.files.create(\n",
165 |     "    purpose=\"assistants\",\n",
166 |     "    file=open(\"../assistants_api/assets/my_information.txt\", \"rb\"),\n",
167 |     ")\n",
168 |     "vector_store = client.beta.vector_stores.create(\n",
169 |     "    name=\"Information About Me\",\n",
170 |     "    file_ids=[file.id],\n",
171 |     ")\n",
172 |     "assistant = client.beta.assistants.update(\n",
173 |     "    assistant_id=assistant.id,\n",
174 |     "    instructions=\"\"\"Your job is to assume the role of a news provider and inform the user of current news.\n",
175 |     "Always direct them to where they can learn more by providing the corresponding link.\n",
176 |     "You must begin by searching through the user's files to find information about them.\n",
177 |     "Only then can you look for relevant news.\"\"\",\n",
178 |     "    tools=[\n",
179 |     "        {\"type\": \"web_retrieval\"},\n",
180 |     "        {\"type\": \"file_search\"}\n",
181 |     "    ],\n",
182 |     "    tool_resources={\n",
183 |     "        \"file_search\": {\n",
184 |     "            \"vector_store_ids\": [vector_store.id]\n",
185 |     "        }\n",
186 |     "    }\n",
187 |     ")\n",
188 |     " "
189 |    ]
190 |   },
191 |   {
192 |    "cell_type": "code",
193 |    "execution_count": null,
194 |    "metadata": {},
195 |    "outputs": [],
196 |    "source": [
197 |     "thread = client.beta.threads.create(\n",
198 |     "    messages=[\n",
199 |     "        {\n",
200 |     "            \"role\": \"user\",\n",
201 |     "            \"content\": \"Find me interesting news.\",  # noqa\n",
202 |     "        },\n",
203 |     "    ],\n",
204 |     ")\n",
205 |     "run = client.beta.threads.runs.create_and_poll(\n",
206 |     "    thread_id=thread.id,\n",
207 |     "    assistant_id=assistant.id,\n",
208 |     ")\n",
209 |     "messages = client.beta.threads.messages.list(\n",
210 |     "    thread_id=thread.id, order='desc'\n",
211 |     ")\n",
212 |     "print(json.dumps(messages.model_dump(), indent=2))"
213 |    ]
214 |   }
215 |  ],
216 |  "metadata": {
217 |   "kernelspec": {
218 |    "display_name": ".venv",
219 |    "language": "python",
220 |    "name": "python3"
221 |   },
222 |   "language_info": {
223 |    "codemirror_mode": {
224 |     "name": "ipython",
225 |     "version": 3
226 |    },
227 |    "file_extension": ".py",
228 |    "mimetype": "text/x-python",
229 |    "name": "python",
230 |    "nbconvert_exporter": "python",
231 |    "pygments_lexer": "ipython3",
232 |    "version": "3.10.2"
233 |   }
234 |  },
235 |  "nbformat": 4,
236 |  "nbformat_minor": 2
237 | }
--------------------------------------------------------------------------------
/assistants_api/tests/test_assistant.py:
-------------------------------------------------------------------------------- 1 | import pytest 2 | from openai import OpenAI 3 | from openai.pagination import SyncCursorPage 4 | from openai.types.beta.assistant import Assistant 5 | from openai.types.beta.code_interpreter_tool import CodeInterpreterTool 6 | from datetime import datetime 7 | import os 8 | 9 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 10 | use_openai = True if os.getenv("USE_OPENAI") else False 11 | base_url = "http://localhost:8000" 12 | 13 | 14 | @pytest.fixture 15 | def openai_client(): 16 | if use_openai: 17 | return OpenAI( 18 | api_key=api_key, 19 | ) 20 | else: 21 | return OpenAI( 22 | base_url=base_url, 23 | ) 24 | 25 | 26 | # # TODO: cleanup are causing issues with the tests, uncomment when fixed or using OpenAI API # noqa 27 | # @pytest.fixture(scope="session", autouse=True) 28 | # def cleanup(request): 29 | # openai_client = OpenAI( 30 | # base_url="http://localhost:8000", 31 | # ) 32 | 33 | # def remove_all_assistants(): 34 | # for assistant in openai_client.beta.assistants.list().data: 35 | # openai_client.beta.assistants.delete(assistant.id) 36 | 37 | # request.addfinalizer(remove_all_assistants) 38 | 39 | 40 | # /assistants POST 41 | @pytest.mark.dependency() 42 | def test_create_assistant(openai_client: OpenAI): 43 | metadata = { 44 | "str": "string", 45 | "int": "1", 46 | "bool": "True", 47 | "list": "[1, 2, 3]", 48 | } 49 | response = openai_client.beta.assistants.create( 50 | instructions="You are an AI designed to provide examples.", 51 | name="Example Assistant", 52 | tools=[{"type": "code_interpreter"}], 53 | model="gpt-4", 54 | metadata=metadata, 55 | ) 56 | assert isinstance(response, Assistant) 57 | assert response.id is not None 58 | assert response.created_at is not None 59 | assert ( 60 | response.instructions == "You are an AI designed to provide examples." 
61 |     )
62 |     assert response.name == "Example Assistant"
63 |     assert isinstance(response.tools[0], CodeInterpreterTool)
64 |     assert response.model == "gpt-4"
65 |     assert response.metadata == metadata
66 | 
67 | 
68 | # /assistants GET
69 | @pytest.mark.dependency(depends=["test_create_assistant"])
70 | def test_list_assistants_after_creation(openai_client: OpenAI):
71 |     response = openai_client.beta.assistants.list()
72 |     assert isinstance(response, SyncCursorPage)
73 |     assert len(response.data) > 0
74 |     assert all(isinstance(item, Assistant) for item in response.data)
75 | 
76 | 
77 | # /assistants GET
78 | @pytest.mark.dependency(depends=["test_create_assistant"])
79 | def test_list_assistants_limit(openai_client: OpenAI):
80 |     limit = 1
81 |     response = openai_client.beta.assistants.list(limit=limit)
82 |     assert isinstance(response, SyncCursorPage)
83 |     assert len(response.data) <= limit
84 |     assert all(isinstance(item, Assistant) for item in response.data)
85 | 
86 | 
87 | # /assistants GET
88 | @pytest.mark.dependency(depends=["test_create_assistant"])
89 | def test_list_assistants_order(openai_client: OpenAI):
90 |     response_desc = openai_client.beta.assistants.list(order="desc")
91 |     response_asc = openai_client.beta.assistants.list(order="asc")
92 |     assert isinstance(response_desc, SyncCursorPage) and isinstance(
93 |         response_asc, SyncCursorPage
94 |     )
95 |     assert len(response_desc.data) > 0 and len(response_asc.data) > 0
96 |     desc_first_created_at = response_desc.data[0].created_at
97 |     asc_first_created_at = response_asc.data[0].created_at
98 |     assert (
99 |         desc_first_created_at >= asc_first_created_at
100 |     ), "Ordering does not match expected results"
101 | 
102 | 
103 | # /assistants/{assistant_id} GET
104 | @pytest.mark.dependency(depends=["test_create_assistant"])
105 | def test_get_assistant(openai_client: OpenAI):
106 |     # Create a fresh assistant so the test does not depend on earlier state
107 |     template = {
108 |         "model": "gpt-4",
109 |         "name": "Example Assistant",
110 |     }
111 |     new_assistant = openai_client.beta.assistants.create(**template)
112 | 
113 |     response = openai_client.beta.assistants.retrieve(new_assistant.id)
114 | 
115 |     # Validate the response structure and data
116 |     assert isinstance(response, Assistant)
117 |     assert response.id == new_assistant.id
118 |     assert response.model == template["model"]
119 |     assert response.name == template["name"]
120 |     assert response.object == "assistant"
121 |     assert isinstance(response.created_at, int)
122 |     assert datetime.utcfromtimestamp(
123 |         response.created_at
124 |     )  # Checks if `created_at` is a valid timestamp
125 | 
126 | 
127 | # /assistants/{assistant_id} POST
128 | @pytest.mark.dependency(depends=["test_create_assistant"])
129 | def test_modify_assistant(openai_client: OpenAI):
130 |     metadata = {"str": "string", "int": "1", "list": "[1, 2, 3]"}
131 |     template = {
132 |         "model": "gpt-4",
133 |         "name": "Example Assistant",
134 |         "instructions": "You are an AI designed to provide examples.",
135 |         "metadata": metadata,
136 |     }
137 |     new_assistant = openai_client.beta.assistants.create(**template)
138 | 
139 |     metadata_to_be = {**template["metadata"], "bool": "True"}
140 |     updated_template = {
141 |         "instructions": "Updated instructions for the assistant.",
142 |         "tools": [{"type": "code_interpreter"}],
143 |         "metadata": {"bool": "True"},
144 |     }
145 |     # Perform the update operation
146 |     response = openai_client.beta.assistants.update(
147 |         new_assistant.id,
148 |         **updated_template,
149 |     )
150 | 
151 |     # Verify the response
152 |     assert isinstance(response, Assistant)
153 |     assert response.id == new_assistant.id
154 |     assert response.instructions == updated_template["instructions"]
155 |     assert isinstance(response.tools[0], CodeInterpreterTool)
156 |     assert response.name == template["name"]
157 |     assert response.metadata == metadata_to_be
158 | 
159 | 
160 | @pytest.mark.dependency(depends=["test_create_assistant"])
161 | def test_delete_assistant(openai_client: OpenAI):
162 |     # Create an assistant so there is a known ID to delete
163 |     metadata = {
164 |         "str": "string",
165 |         "int": "1",
166 |         "bool": "True",
167 |         "list": "[1, 2, 3]",
168 |     }
169 |     template = {
170 |         "model": "gpt-4",
171 |         "name": "Example Assistant",
172 |         "instructions": "You are an AI designed to provide examples.",
173 |         "metadata": metadata,
174 |     }
175 |     new_assistant = openai_client.beta.assistants.create(**template)
176 | 
177 |     # Perform the delete operation
178 |     response = openai_client.beta.assistants.delete(new_assistant.id)
179 | 
180 |     # Verify the deletion response
181 |     assert response.id == new_assistant.id
182 |     assert response.deleted is True
183 |     assert response.object == "assistant.deleted"
184 | 
185 |     # try and retrieve the assistant again
186 |     try:
187 |         openai_client.beta.assistants.retrieve(new_assistant.id)
188 |     except Exception as e:
189 |         assert e.status_code == 404
190 |         assert "No assistant found" in str(e)
191 |     else:
192 |         raise AssertionError("Assistant was not deleted")
--------------------------------------------------------------------------------
/run_executor_worker/src/utils/ops_api_handler.py:
--------------------------------------------------------------------------------
1 | # api_handler.py
2 | from typing import List, Literal
3 | import uuid
4 | import requests
5 | import os
6 | from data_models import run
7 | from openai.types.beta.threads.message import Message
8 | from openai.types.beta.threads.runs import FileSearchToolCall
9 | from openai.types.beta.threads.runs.web_retrieval_tool_call import (
10 |     WebRetrievalToolCall,
11 | )
12 | from openai.types.beta.threads.runs.function_tool_call import Function
13 | from constants import WebRetrievalResult
14 | from utils.openai_clients import assistants_client
15 | 
16 | 
17 | # TODO: create run script that imports env vars
18 | BASE_URL = os.getenv("ASSISTANTS_API_URL")
19 | 
20 | 
21 | def update_run(
22 |     thread_id: str, run_id: str, run_update: run.RunUpdate
23 | ) -> run.Run:
24 |     """
25 |     Update the status of a Run.
26 | 
27 |     Parameters:
28 |         thread_id (str): The ID of the thread.
29 |         run_id (str): The ID of the run.
30 |         run_update (run.RunUpdate): The fields to update on the run.
31 | 
32 |     Returns:
33 |         run.Run: The updated run, or None if the update failed.
34 |     """
35 |     update_url = f"{BASE_URL}/ops/threads/{thread_id}/runs/{run_id}"
36 |     update_data = run_update.model_dump(exclude_none=True)
37 | 
38 |     response = requests.post(update_url, json=update_data)
39 | 
40 |     if response.status_code == 200:
41 |         return run.Run(**response.json())
42 |     else:
43 |         return None
44 | 
45 | 
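# The helpers below are the worker's write path into the assistants_api "ops"
# endpoints: each one records a RunStep (message creation, file search, web
# retrieval, or function call) so that run progress is persisted and visible
# through the public Assistants API while the run executes.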
138 | query=query, 139 | retrieval=[item.model_dump() for item in retreived_content], 140 | type="web_retrieval", 141 | ) 142 | 143 | # Prepare run step details with the tool call 144 | run_step_details = { 145 | "assistant_id": assistant_id, 146 | "step_details": { 147 | "type": "tool_calls", 148 | "tool_calls": [ 149 | tool_call.model_dump() 150 | ], # Serialize `ToolCall` to a dict 151 | }, 152 | "type": "tool_calls", 153 | "status": "completed", 154 | } 155 | 156 | # This model dumping part would be dependent on how you're handling Pydantic models, showing a conceptual example: # noqa 157 | run_step_details = run.RunStepCreate(**run_step_details).model_dump( 158 | exclude_none=True 159 | ) 160 | 161 | # Post request to create a run step 162 | response = requests.post( 163 | f"{BASE_URL}/ops/threads/{thread_id}/runs/{run_id}/steps", 164 | json=run_step_details, 165 | ) 166 | if response.status_code != 200: 167 | raise Exception(f"Failed to create run step: {response.text}") 168 | 169 | return run.RunStep(**response.json()) 170 | 171 | 172 | def create_function_runstep( 173 | thread_id: str, 174 | run_id: str, 175 | assistant_id: str, 176 | function: Function, 177 | ) -> dict: 178 | tool_call_id = "call_" + str(uuid.uuid4())[:-5] 179 | # Prepare run step details with the tool call 180 | run_step_details = { 181 | "assistant_id": assistant_id, 182 | "step_details": { 183 | "type": "tool_calls", 184 | "tool_calls": [ 185 | { 186 | "id": tool_call_id, # This should be a unique identifier. 187 | "function": function, 188 | "type": "function", 189 | } 190 | ], # Serialize `ToolCall` to a dict 191 | }, 192 | "type": "tool_calls", 193 | "status": "in_progress", 194 | } 195 | 196 | # This model dumping part would be dependent on how you're handling Pydantic models, showing a conceptual example: # noqa 197 | run_step_details = run.RunStepCreate(**run_step_details).model_dump( 198 | exclude_none=True 199 | ) 200 | 201 | # Post request to create a run step 202 | response = requests.post( 203 | f"{BASE_URL}/ops/threads/{thread_id}/runs/{run_id}/steps", 204 | json=run_step_details, 205 | ) 206 | if response.status_code != 200: 207 | raise Exception(f"Failed to create run step: {response.text}") 208 | 209 | return run.RunStep(**response.json()) 210 | -------------------------------------------------------------------------------- /assistants_api/assets/code-reference.txt: -------------------------------------------------------------------------------- 1 | Below is a portion of a test for you to reference: 2 | import pytest 3 | from openai import OpenAI 4 | from openai.pagination import SyncCursorPage 5 | from openai.types.beta.assistant import Assistant, ToolCodeInterpreter 6 | from datetime import datetime 7 | 8 | # import os 9 | 10 | 11 | @pytest.fixture 12 | def openai_client(): 13 | # Replace "your_api_key_here" with your actual OpenAI API key 14 | return OpenAI( 15 | base_url="http://localhost:8000", 16 | # api_key=os.getenv("OPENAI_API_KEY"), 17 | ) 18 | 19 | 20 | @pytest.fixture(scope="session", autouse=True) 21 | def cleanup(request): 22 | openai_client = OpenAI( 23 | base_url="http://localhost:8000", 24 | ) 25 | 26 | def remove_all_assistants(): 27 | for assistant in openai_client.beta.assistants.list().data: 28 | openai_client.beta.assistants.delete(assistant.id) 29 | 30 | request.addfinalizer(remove_all_assistants) 31 | 32 | 33 | # /assistants POST 34 | @pytest.mark.dependency() 35 | def test_create_assistant(openai_client: OpenAI): 36 | openai_client.beta.threads.runs.steps.list() 37 | 
response = openai_client.beta.assistants.create( 38 | instructions="You are an AI designed to provide examples.", 39 | name="Example Assistant", 40 | tools=[{"type": "code_interpreter"}], 41 | model="gpt-4", 42 | metadata={"str": "string", "int": 1, "bool": True, "list": [1, 2, 3]}, 43 | ) 44 | assert isinstance(response, Assistant) 45 | assert response.id is not None 46 | assert response.created_at is not None 47 | assert ( 48 | response.instructions == "You are an AI designed to provide examples." 49 | ) 50 | assert response.name == "Example Assistant" 51 | assert isinstance(response.tools[0], ToolCodeInterpreter) 52 | assert response.model == "gpt-4" 53 | assert response.metadata == { 54 | "str": "string", 55 | "int": 1, 56 | "bool": True, 57 | "list": [1, 2, 3], 58 | } 59 | ... 60 | 61 | Below is a portion of a router with endpoints for you to reference: 62 | from fastapi import APIRouter, Depends, HTTPException, Query 63 | from typing import Optional 64 | from utils.tranformers import db_to_pydantic_assistant 65 | 66 | from lib.db.database import get_db 67 | from sqlalchemy.orm import Session 68 | from lib.db import crud, schemas 69 | from openai.pagination import SyncCursorPage 70 | 71 | router = APIRouter() 72 | 73 | 74 | @router.post("/assistants", response_model=schemas.Assistant) 75 | def create_assistant( 76 | assistant: schemas.AssistantCreate, 77 | db: Session = Depends(get_db), 78 | ): 79 | """ 80 | Create a new assistant. 81 | - **model**: ID of the model to use. 82 | - **name**: The name of the assistant. 83 | - **description**: The description of the assistant. 84 | - **instructions**: The system instructions that the assistant uses. 85 | - **tools**: A list of tools enabled on the assistant. 86 | - **file_ids**: A list of file IDs attached to this assistant. 87 | - **metadata**: Set of 16 key-value pairs that can be attached to the assistant. 88 | """ 89 | db_assistant = crud.create_assistant(db=db, assistant=assistant) 90 | return db_assistant 91 | 92 | 93 | @router.get("/assistants", response_model=SyncCursorPage[schemas.Assistant]) 94 | def list_assistants( 95 | db: Session = Depends(get_db), 96 | limit: int = Query(default=20, le=100), 97 | order: str = Query(default="desc", regex="^(asc|desc)$"), 98 | after: Optional[str] = None, 99 | before: Optional[str] = None, 100 | ): 101 | """ 102 | List assistants with optional pagination and ordering. 103 | - **limit**: Maximum number of results to return. 104 | - **order**: Sort order based on the creation time ('asc' or 'desc'). 105 | - **after**: ID to start the list from (for pagination). 106 | - **before**: ID to list up to (for pagination). 107 | """ 108 | db_assistants = crud.get_assistants( 109 | db=db, limit=limit, order=order, after=after, before=before 110 | ) 111 | 112 | assistants = [ 113 | db_to_pydantic_assistant(assistant) for assistant in db_assistants 114 | ] 115 | paginated_assistants = SyncCursorPage(data=assistants) 116 | 117 | return paginated_assistants 118 | ... 119 | 120 | Below is a portion of crud for you to reference: 121 | from sqlalchemy.orm import Session 122 | import time 123 | from sqlalchemy import desc, asc 124 | 125 | from lib.fs.schemas import FileObject 126 | from . 
import models, schemas 127 | import uuid 128 | 129 | 130 | # ASSISTANT 131 | def create_assistant(db: Session, assistant: schemas.AssistantCreate): 132 | tools = [tool.model_dump() for tool in assistant.tools] 133 | # Generate a unique ID for the new assistant 134 | db_assistant = models.Assistant( 135 | id=str(uuid.uuid4()), 136 | object="assistant", 137 | name=assistant.name, 138 | description=assistant.description, 139 | model=assistant.model, 140 | instructions=assistant.instructions, 141 | tools=tools, # Ensure your model and schema correctly handle serialization/deserialization # noqa 142 | file_ids=assistant.file_ids, 143 | metadata=assistant.metadata, 144 | created_at=int(time.time()), # Assuming UNIX timestamp for created_at 145 | # Include other fields as necessary 146 | ) 147 | db.add(db_assistant) 148 | db.commit() 149 | db.refresh(db_assistant) 150 | return db_assistant 151 | ... 152 | 153 | Below is a portion of the schemas for you to reference: 154 | from pydantic import BaseModel, Field 155 | from typing import Optional, List, Dict, Any 156 | from openai.types.beta.assistant import Assistant, Tool 157 | from openai.types.beta.assistant_deleted import AssistantDeleted 158 | 159 | Assistant 160 | AssistantDeleted 161 | 162 | 163 | class AssistantCreate(BaseModel): 164 | name: Optional[str] = Field(None, max_length=256) 165 | description: Optional[str] = Field(None, max_length=512) 166 | model: str 167 | instructions: Optional[str] = Field(None, max_length=32768) 168 | tools: List[Tool] = [] 169 | file_ids: List[str] = [] 170 | metadata: Optional[Dict[str, Any]] = None 171 | 172 | Below is an example of a model for you to reference: 173 | from sqlalchemy import Column, String, Integer, JSON, Enum 174 | from .database import Base 175 | 176 | 177 | class Assistant(Base): 178 | __tablename__ = "assistants" 179 | 180 | id = Column(String, primary_key=True, index=True) 181 | object = Column( 182 | Enum("assistant", name="assistant_object"), 183 | nullable=False, 184 | default="assistant", 185 | ) # Since "object" is a reserved keyword in Python, consider renaming or handle appropriately # noqa 186 | created_at = Column(Integer, nullable=False) 187 | name = Column(String(256), nullable=True) 188 | description = Column(String(512), nullable=True) 189 | model = Column(String, nullable=False) 190 | instructions = Column(String(32768), nullable=True) 191 | tools = Column( 192 | JSON, default=[] 193 | ) # Ensure your database supports JSON type; otherwise, consider storing as String and serializing/deserializing # noqa 194 | file_ids = Column(JSON, default=[]) 195 | _metadata = Column("metadata", JSON, nullable=True) 196 | 197 | # # If there's a relationship with users (assuming one assistant can belong to one user) # noqa 198 | # user_id = Column(String, ForeignKey('users.id')) 199 | # owner = relationship("User", back_populates="user_gpts") -------------------------------------------------------------------------------- /assistants_api/tests/test_messages.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from openai import OpenAI 3 | import os 4 | import time 5 | 6 | api_key = os.getenv("OPENAI_API_KEY") if os.getenv("OPENAI_API_KEY") else None 7 | use_openai = True if os.getenv("USE_OPENAI") else False 8 | base_url = "http://localhost:8000" 9 | 10 | 11 | @pytest.fixture 12 | def openai_client(): 13 | if use_openai: 14 | return OpenAI( 15 | api_key=api_key, 16 | ) 17 | else: 18 | return OpenAI( 19 | base_url=base_url, 20 | ) 21 | 
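# These tests run against a live server (local by default, or the hosted
# OpenAI API when USE_OPENAI is set), so the fixtures create real threads
# and messages rather than mocks.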
22 | 
23 | @pytest.fixture
24 | def thread_id(openai_client: OpenAI):
25 |     thread_metadata = {"example_key": "example_value"}
26 |     response = openai_client.beta.threads.create(metadata=thread_metadata)
27 |     return response.id
28 | 
29 | 
30 | @pytest.mark.dependency()
31 | def test_create_message_in_thread(openai_client: OpenAI, thread_id: str):
32 |     # The thread_id fixture creates a thread and returns its ID
33 |     message_data = {
34 |         "role": "user",
35 |         "content": "Hello, World!",
36 |         "metadata": {"example_key": "example_value"},
37 |     }
38 | 
39 |     response = openai_client.beta.threads.messages.create(
40 |         thread_id=thread_id, **message_data
41 |     )
42 | 
43 |     assert response.id is not None
44 |     assert response.role == message_data["role"]
45 |     assert response.content[0].text.value == message_data["content"]
46 |     assert response.attachments == []
47 |     assert response.metadata == message_data["metadata"]
48 | 
49 | 
50 | @pytest.mark.dependency(depends=["test_create_message_in_thread"])
51 | def test_get_messages_in_thread(openai_client: OpenAI, thread_id: str):
52 |     # Create some messages in the thread for testing
53 |     message_data_1 = {
54 |         "role": "user",
55 |         "content": "First message",
56 |         "metadata": {"example_key": "example_value"},
57 |     }
58 |     message_data_2 = {
59 |         "role": "user",
60 |         "content": "Second message",
61 |     }
62 |     openai_client.beta.threads.messages.create(
63 |         thread_id=thread_id, **message_data_1
64 |     )
65 |     time.sleep(
66 |         1
67 |     )  # TODO: remove this. It adds a gap between created_at values to ensure a stable order  # noqa
68 |     openai_client.beta.threads.messages.create(
69 |         thread_id=thread_id, **message_data_2
70 |     )
71 | 
72 |     # Retrieve messages from the thread
73 |     response = openai_client.beta.threads.messages.list(
74 |         thread_id=thread_id, limit=2, order='desc'
75 |     )
76 | 
77 |     assert len(response.data) == 2
78 |     assert response.data[1].role == message_data_1["role"]
79 |     assert response.data[1].content[0].text.value == message_data_1["content"]
80 |     assert response.data[1].attachments == []
81 |     assert response.data[1].metadata == message_data_1["metadata"]
82 | 
83 | 
84 | @pytest.mark.dependency(
85 |     depends=["test_create_message_in_thread", "test_get_messages_in_thread"]
86 | )
87 | def test_create_thread_with_message(openai_client: OpenAI):
88 |     # Create a thread with a message
89 |     message_data = {
90 |         "role": "user",
91 |         "content": "Hello, World!",
92 |         "metadata": {"example_key": "example_value"},
93 |     }
94 | 
95 |     create_thread = openai_client.beta.threads.create(messages=[message_data])
96 | 
97 |     assert create_thread.id is not None
98 | 
99 |     get_messages = openai_client.beta.threads.messages.list(
100 |         thread_id=create_thread.id
101 |     )
102 | 
103 |     assert len(get_messages.data) == 1
104 |     assert get_messages.data[0].role == message_data["role"]
105 |     assert (
106 |         get_messages.data[0].content[0].text.value == message_data["content"]
107 |     )
108 |     assert get_messages.data[0].attachments == []
109 |     assert get_messages.data[0].metadata == message_data["metadata"]
110 | 
111 | 
112 | @pytest.mark.dependency(
113 |     depends=["test_create_message_in_thread", "test_get_messages_in_thread"]
114 | )
115 | def test_get_specific_message_in_thread(openai_client: OpenAI, thread_id: str):
116 |     # First, create a message in the thread for testing
117 |     message_data = {
118 |         "role": "user",
119 |         "content": "Test message content",
120 |         "metadata": {"example_key": "example_value"},
121 |     }
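    # Create the message, then fetch it back by ID to verify persistence.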
122 | message_response = openai_client.beta.threads.messages.create( 123 | thread_id=thread_id, **message_data 124 | ) 125 | message_id = message_response.id 126 | 127 | # Retrieve the specific message from the thread 128 | retrieved_message = openai_client.beta.threads.messages.retrieve( 129 | thread_id=thread_id, message_id=message_id 130 | ) 131 | 132 | # Verify the retrieved message details 133 | assert retrieved_message.id == message_id 134 | assert retrieved_message.thread_id == thread_id 135 | assert retrieved_message.role == message_data["role"] 136 | assert retrieved_message.content[0].text.value == message_data["content"] 137 | assert retrieved_message.attachments == [] 138 | assert retrieved_message.metadata == message_data.get("metadata", {}) 139 | 140 | # Optionally, cleanup by deleting the message and thread if necessary 141 | 142 | 143 | @pytest.mark.dependency( 144 | depends=["test_create_message_in_thread", "test_get_messages_in_thread"] 145 | ) 146 | def test_modify_message_in_thread(openai_client: OpenAI, thread_id: str): 147 | # Create a message in the thread for testing 148 | message_data = { 149 | "role": "user", 150 | "content": "Initial message content", 151 | } 152 | message_response = openai_client.beta.threads.messages.create( 153 | thread_id=thread_id, **message_data 154 | ) 155 | message_id = message_response.id 156 | 157 | retrieved_message = openai_client.beta.threads.messages.retrieve( 158 | thread_id=thread_id, message_id=message_id 159 | ) 160 | 161 | assert retrieved_message.id == message_id 162 | assert retrieved_message.thread_id == thread_id 163 | assert retrieved_message.role == message_data["role"] 164 | assert retrieved_message.content[0].text.value == message_data["content"] 165 | assert retrieved_message.attachments == [] 166 | assert retrieved_message.metadata == {} 167 | 168 | # Data for modification 169 | updated_metadata = {"modified": "true", "user": "abc123"} 170 | 171 | # Modify the message 172 | modified_message = openai_client.beta.threads.messages.update( 173 | thread_id=thread_id, 174 | message_id=message_id, 175 | metadata=updated_metadata, 176 | ) 177 | 178 | assert modified_message.id == message_id 179 | assert modified_message.thread_id == thread_id 180 | assert modified_message.role == message_data["role"] 181 | assert modified_message.content[0].text.value == message_data["content"] 182 | assert not modified_message.attachments 183 | assert modified_message.metadata == updated_metadata 184 | 185 | retrieved_updated_message = openai_client.beta.threads.messages.retrieve( 186 | thread_id=thread_id, message_id=message_id 187 | ) 188 | 189 | # Verify the response 190 | assert retrieved_updated_message.id == message_id 191 | assert retrieved_updated_message.thread_id == thread_id 192 | assert retrieved_updated_message.role == message_data["role"] 193 | assert ( 194 | retrieved_updated_message.content[0].text.value 195 | == message_data["content"] 196 | ) 197 | assert retrieved_updated_message.attachments == [] 198 | assert retrieved_updated_message.metadata == updated_metadata 199 | -------------------------------------------------------------------------------- /assistants_api/app/routers/vectorstore_router.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks 3 | from sqlalchemy.orm import Session 4 | from utils.tranformers import ( 5 | db_to_pydantic_vector_store, 6 | 
db_to_pydantic_vector_store_file_batch, 7 | ) 8 | from lib.db import crud, schemas, database 9 | from minio import Minio 10 | from lib.wv import actions as wv_actions 11 | from lib.fs.store import minio_client, BUCKET_NAME 12 | from lib.fs import actions as fs_actions 13 | import json 14 | 15 | 16 | router = APIRouter() 17 | 18 | 19 | @router.post("/vector_stores", response_model=schemas.VectorStore) 20 | def create_vector_store( 21 | background_tasks: BackgroundTasks, 22 | vector_store: schemas.VectorStoreCreate, 23 | db: Session = Depends(database.get_db), 24 | minio_client: Minio = Depends(minio_client), 25 | ): 26 | db_vector_store = crud.create_vector_store( 27 | db=db, vector_store=vector_store 28 | ) 29 | vector_store_model = db_to_pydantic_vector_store(db_vector_store) 30 | wv_actions.create_collection(vector_store_model.id) 31 | 32 | # Adding the file processing to background tasks 33 | if vector_store.file_ids: 34 | background_tasks.add_task( 35 | process_files, 36 | vector_store_model, 37 | vector_store.file_ids, 38 | db, 39 | minio_client, 40 | ) 41 | 42 | return vector_store_model 43 | 44 | 45 | @router.post( 46 | "/vector_stores/{vector_store_id}/file_batches", 47 | response_model=schemas.VectorStoreFileBatch, 48 | ) 49 | def create_vector_store_file_batch( 50 | background_tasks: BackgroundTasks, 51 | vector_store_id: str, 52 | file_batch: schemas.CreateVectorStoreFileBatchRequest, 53 | db: Session = Depends(database.get_db), 54 | minio_client: Minio = Depends(minio_client), 55 | ): 56 | # Check if vector store exists 57 | db_vector_store = crud.get_vector_store(db, vector_store_id) 58 | if not db_vector_store: 59 | raise HTTPException(status_code=404, detail="Vector store not found") 60 | vector_store = db_to_pydantic_vector_store(db_vector_store) 61 | 62 | # Create file batch 63 | db_file_batch = crud.create_file_batch( 64 | db, vector_store.id, file_batch.file_ids 65 | ) 66 | file_batch_model = db_to_pydantic_vector_store_file_batch(db_file_batch) 67 | 68 | # Process files in the background 69 | background_tasks.add_task( 70 | process_files, 71 | vector_store, 72 | file_batch.file_ids, 73 | db, 74 | minio_client, 75 | file_batch_model, 76 | ) 77 | 78 | return file_batch_model 79 | 80 | 81 | def process_files( 82 | vector_store_model: schemas.VectorStore, 83 | file_ids, 84 | db, 85 | minio_client, 86 | vector_store_file_batch: Optional[schemas.VectorStoreFileBatch] = None, 87 | ): 88 | # Retrieve the existing vector store to update 89 | vector_store_model.file_counts.in_progress = len(file_ids) 90 | 91 | if vector_store_file_batch: 92 | vector_store_file_batch.file_counts.in_progress = len(file_ids) 93 | 94 | status = ( 95 | "in_progress" 96 | if vector_store_model.file_counts.in_progress > 0 97 | else "completed" 98 | ) 99 | usage_bytes = 0 100 | 101 | # Process each file 102 | for file_id in file_ids: 103 | crud.update_vector_store( 104 | db, 105 | vector_store_model.id, 106 | { 107 | "file_counts": vector_store_model.file_counts.model_dump(), 108 | "usage_bytes": usage_bytes, 109 | "status": status, 110 | "metadata": vector_store_model.metadata, 111 | }, 112 | ) 113 | if vector_store_file_batch: 114 | crud.update_file_batch( 115 | db, 116 | vector_store_file_batch.id, 117 | { 118 | "file_counts": vector_store_file_batch.file_counts.model_dump(), 119 | "status": status, 120 | }, 121 | ) 122 | try: 123 | file_data = fs_actions.get_file_binary( 124 | minio_client, BUCKET_NAME, file_id 125 | ) 126 | file_metadata = fs_actions.get_file( 127 | minio_client, BUCKET_NAME, 
file_id
128 |             )
129 | 
130 |             wv_actions.upload_file_chunks(
131 |                 file_data,
132 |                 file_metadata.filename,
133 |                 file_id,
134 |                 vector_store_model.id,
135 |             )
136 |             usage_bytes += len(file_data)
137 |             vector_store_model.file_counts.completed += 1
138 |             # update the _file_ids list stored in metadata
139 |             stored_file_ids: List[str] = json.loads(
140 |                 vector_store_model.metadata["_file_ids"]
141 |             )
142 |             stored_file_ids.append(file_id)
143 |             vector_store_model.metadata["_file_ids"] = json.dumps(stored_file_ids)
144 |             if vector_store_file_batch:
145 |                 vector_store_file_batch.file_counts.completed += 1
146 |         except Exception as e:
147 |             print(
148 |                 f"Error processing file '{file_id}': {str(e)}"
149 |             )
150 |             vector_store_model.file_counts.failed += 1
151 |             if vector_store_file_batch:
152 |                 vector_store_file_batch.file_counts.failed += 1
153 |         finally:
154 |             vector_store_model.file_counts.in_progress -= 1
155 |             vector_store_model.file_counts.total += 1
156 |             if vector_store_file_batch:
157 |                 vector_store_file_batch.file_counts.in_progress -= 1
158 |                 vector_store_file_batch.file_counts.total += 1
159 |             status = (
160 |                 "completed"
161 |                 if vector_store_model.file_counts.in_progress == 0
162 |                 else "in_progress"
163 |             )
164 | 
165 |     status = (
166 |         "in_progress"
167 |         if vector_store_model.file_counts.in_progress > 0
168 |         else "completed"
169 |     )
170 | 
171 |     # Update the vector store with the final counts and usage bytes
172 |     crud.update_vector_store(
173 |         db,
174 |         vector_store_model.id,
175 |         {
176 |             "file_counts": vector_store_model.file_counts.model_dump(),
177 |             "usage_bytes": usage_bytes,
178 |             "status": status,
179 |             "metadata": vector_store_model.metadata,
180 |         },
181 |     )
182 | 
183 | 
184 | @router.get(
185 |     "/vector_stores/{vector_store_id}", response_model=schemas.VectorStore
186 | )
187 | def read_vector_store(
188 |     vector_store_id: str, db: Session = Depends(database.get_db)
189 | ):
190 |     db_vector_store = crud.get_vector_store(
191 |         db, vector_store_id=vector_store_id
192 |     )
193 |     if db_vector_store is None:
194 |         raise HTTPException(status_code=404, detail="Vector store not found")
195 |     return db_to_pydantic_vector_store(db_vector_store)
196 | 
197 | 
198 | @router.get(
199 |     "/vector_stores",
200 |     response_model=schemas.SyncCursorPage[schemas.VectorStore],
201 | )
202 | def list_vector_stores(
203 |     db: Session = Depends(database.get_db),
204 |     limit: int = Query(default=20, le=100),
205 |     order: str = Query(default="desc", regex="^(asc|desc)$"),
206 |     after: Optional[str] = None,
207 |     before: Optional[str] = None,
208 | ):
209 |     """
210 |     List vector stores with optional pagination and ordering.
211 |     - **limit**: Maximum number of results to return.
212 |     - **order**: Sort order based on the creation time ('asc' or 'desc').
213 |     - **after**: ID to start the list from (for pagination).
214 |     - **before**: ID to list up to (for pagination).
215 | """ 216 | vector_stores = crud.get_vector_stores( 217 | db=db, limit=limit, order=order, after=after, before=before 218 | ) 219 | 220 | vector_store_data = [ 221 | db_to_pydantic_vector_store(store) for store in vector_stores 222 | ] 223 | paginated_vector_stores = schemas.SyncCursorPage(data=vector_store_data) 224 | 225 | return paginated_vector_stores 226 | -------------------------------------------------------------------------------- /assistants_api/app/lib/db/schemas.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Iterable, Literal, Optional, List, Dict, Any, Union 3 | from openai.types.beta.assistant import Assistant, AssistantTool 4 | from openai.types.beta import Thread 5 | from openai.types.beta.threads import Message 6 | from enum import Enum 7 | 8 | from openai.types.beta.thread_deleted import ThreadDeleted 9 | from openai.types.beta.assistant_deleted import AssistantDeleted 10 | from openai.types.beta.vector_store import ( 11 | VectorStore, 12 | FileCounts, 13 | ExpiresAfter, 14 | ) 15 | from openai.types.beta.vector_stores.vector_store_file_batch import ( 16 | VectorStoreFileBatch, 17 | ) 18 | 19 | from openai.pagination import SyncCursorPage 20 | from openai.types.beta.threads.message_create_params import Attachment 21 | from openai.types.beta import assistant_update_params, assistant_create_params 22 | from openai.types.beta.threads.runs import ( 23 | RunStep, 24 | MessageCreationStepDetails, 25 | ToolCallsStepDetails, 26 | ) 27 | from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput 28 | from openai.types.beta.threads.run import Run, RequiredAction 29 | from openai.types.beta.threads.text_content_block import TextContentBlock 30 | from openai.types.beta.threads.text import Text 31 | from openai.types.beta.threads import run_create_params 32 | from openai.types.beta.assistant_tool_param import AssistantToolParam 33 | from openai.types.beta.assistant_tool_choice_option_param import ( 34 | AssistantToolChoiceOptionParam, 35 | ) 36 | from openai.types.beta.assistant_response_format_option_param import ( 37 | AssistantResponseFormatOptionParam, 38 | ) 39 | from openai.resources.ops.web_retireval import ( 40 | CrawlInfo, 41 | WebRetrieval, 42 | WebRetrievalResponse, 43 | DeleteResponse, 44 | ) 45 | 46 | 47 | Assistant 48 | AssistantDeleted 49 | Thread 50 | ThreadDeleted 51 | Message # database stored message, typically used for output 52 | SyncCursorPage 53 | Run 54 | RunStep 55 | VectorStore 56 | FileCounts, 57 | ExpiresAfter, 58 | VectorStoreFileBatch 59 | TextContentBlock 60 | Text 61 | ToolOutput 62 | CrawlInfo, 63 | WebRetrieval, 64 | WebRetrievalResponse, 65 | DeleteResponse, 66 | 67 | StepDetails = Union[MessageCreationStepDetails, ToolCallsStepDetails] 68 | 69 | 70 | class AssistantCreate(BaseModel): 71 | name: Optional[str] = Field(None, max_length=256) 72 | description: Optional[str] = Field(None, max_length=512) 73 | model: str # This field is required 74 | instructions: Optional[str] = Field(None, max_length=32768) 75 | tools: List[AssistantTool] = [] 76 | metadata: Optional[Dict[str, Any]] = None 77 | response_format: Optional[str] = None 78 | temperature: Optional[float] = None 79 | tool_resources: Optional[assistant_create_params.ToolResources] = None 80 | top_p: Optional[float] = None 81 | 82 | 83 | class AssistantUpdate(BaseModel): 84 | name: Optional[str] = Field(None, max_length=256) 85 | description: Optional[str] = Field(None, 
86 |     model: Optional[str] = Field(None, max_length=256)
87 |     instructions: Optional[str] = Field(None, max_length=32768)
88 |     metadata: Optional[Dict[str, Any]] = None
89 |     tools: Optional[List[AssistantTool]] = None  # Simplified for example
90 |     response_format: Optional[str] = None
91 |     temperature: Optional[float] = None
92 |     tool_resources: Optional[assistant_update_params.ToolResources] = None
93 |     top_p: Optional[float] = None
94 | 
95 | 
96 | class MessageInput(BaseModel):  # Input for message data
97 |     role: Literal["user", "assistant"]
98 |     content: str
99 |     metadata: Optional[Dict[str, str]] = Field(default_factory=dict)
100 |     attachments: Optional[List[Attachment]] = Field(default_factory=list)
101 | 
102 | 
103 | class ThreadCreate(BaseModel):
104 |     messages: Optional[List[MessageInput]] = Field(default=[])
105 |     metadata: Optional[Dict[str, str]] = Field(default={})
106 | 
107 | 
108 | class ThreadUpdate(BaseModel):
109 |     messages: Optional[List[MessageInput]] = Field(default=[])
110 |     metadata: Optional[Dict[str, str]] = Field(default={})
111 | 
112 | 
113 | class MessageUpdate(BaseModel):
114 |     metadata: Optional[Dict[str, str]] = Field(default={})
115 | 
116 | 
117 | class RunContent(BaseModel):
118 |     assistant_id: str
119 |     stream: bool = False
120 |     additional_instructions: Optional[str] = None
121 |     additional_messages: Optional[
122 |         Iterable[run_create_params.AdditionalMessage]
123 |     ] = None
124 |     instructions: Optional[str] = None
125 |     max_completion_tokens: Optional[int] = None
126 |     max_prompt_tokens: Optional[int] = None
127 |     metadata: Optional[dict] = None
128 |     model: Optional[str] = None
129 |     response_format: Optional[AssistantResponseFormatOptionParam] = None
130 |     temperature: Optional[float] = None
131 |     tool_choice: Optional[AssistantToolChoiceOptionParam] = None
132 |     tools: Optional[Iterable[AssistantToolParam]] = None
133 |     top_p: Optional[float] = None
134 |     truncation_strategy: Optional[run_create_params.TruncationStrategy] = None
135 |     extra_headers: Optional[dict] = None
136 |     extra_query: Optional[dict] = None
137 |     extra_body: Optional[dict] = None
138 |     timeout: Optional[float] = None
139 | 
140 | 
141 | class RunStatus(str, Enum):
142 |     QUEUED = "queued"
143 |     IN_PROGRESS = "in_progress"
144 |     REQUIRES_ACTION = "requires_action"
145 |     CANCELLING = "cancelling"
146 |     CANCELLED = "cancelled"
147 |     FAILED = "failed"
148 |     COMPLETED = "completed"
149 |     EXPIRED = "expired"
150 | 
151 | 
152 | class RunUpdate(BaseModel):
153 |     assistant_id: Optional[str] = None
154 |     cancelled_at: Optional[int] = None
155 |     completed_at: Optional[int] = None
156 |     expires_at: Optional[int] = None
157 |     failed_at: Optional[int] = None
158 |     file_ids: Optional[List[str]] = None
159 |     instructions: Optional[str] = None
160 |     last_error: Optional[Any] = None
161 |     model: Optional[str] = None
162 |     started_at: Optional[int] = None
163 |     status: Optional[str] = None
164 |     tools: Optional[Any] = None
165 |     usage: Optional[Any] = None
166 |     required_action: Optional[RequiredAction] = None
167 | 
168 | 
169 | class RunStepCreate(BaseModel):
170 |     # Define the fields required for creating a RunStep
171 |     assistant_id: str
172 |     type: Literal["message_creation", "tool_calls"]
173 |     status: Literal[
174 |         "in_progress", "cancelled", "failed", "completed", "expired"
175 |     ]
176 |     step_details: StepDetails
177 | 
178 | 
179 | class RunStepUpdate(BaseModel):
180 |     assistant_id: Optional[str] = None
181 |     cancelled_at: Optional[int] = None
182 |     completed_at: Optional[int] = None
183 |     expired_at: Optional[int] = None
184 |     failed_at: Optional[int] = None
185 |     last_error: Optional[Dict[str, Any]] = None
186 |     metadata: Optional[Dict[str, Any]] = None
187 |     status: Optional[
188 |         Literal["in_progress", "cancelled", "failed", "completed", "expired"]
189 |     ] = None
190 |     step_details: Optional[StepDetails] = None
191 |     type: Optional[Literal["message_creation", "tool_calls"]] = None
192 |     usage: Optional[Dict[str, Any]] = None
193 | 
194 | 
195 | class VectorStoreCreate(BaseModel):
196 |     file_ids: Optional[List[str]] = Field(
197 |         default=[], description="A list of file IDs for the vector store."
198 |     )
199 |     name: str = Field(..., description="The name of the vector store.")
200 |     expires_after: Optional[ExpiresAfter] = Field(
201 |         None, description="The expiration policy for the vector store."
202 |     )
203 |     metadata: Optional[Dict[str, str]] = Field(
204 |         {}, description="Metadata for additional structured information."
205 |     )
206 | 
207 | 
208 | class CreateVectorStoreFileBatchRequest(BaseModel):
209 |     file_ids: List[str] = Field(
210 |         ..., min_items=1, max_items=500, example=["file-abc123", "file-abc456"]
211 |     )
212 | 
213 | 
214 | class SubmitToolOutputsRunRequest(BaseModel):
215 |     tool_outputs: List[ToolOutput]
216 |     stream: Optional[bool] = None
217 | 
218 | 
219 | class WebRetrievalCreate(BaseModel):
220 |     root_urls: List[str]
221 |     constrain_to_root_domain: bool
222 |     max_depth: int
223 |     description: Optional[str] = None
224 | 
--------------------------------------------------------------------------------
/run_executor_worker/src/utils/openai_clients.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | from openai.types.chat import (
3 |     completion_create_params,
4 |     chat_completion_tool_choice_option_param,
5 |     ChatCompletionMessageParam,
6 | )
7 | from openai.types.chat.chat_completion import ChatCompletion, Choice
8 | from openai.types.chat.chat_completion_message import ChatCompletionMessage
9 | from openai.types.chat.chat_completion_message_tool_call import (
10 |     ChatCompletionMessageToolCall,
11 |     Function,
12 | )
13 | 
14 | from openai._types import NOT_GIVEN
15 | from typing import Iterable, Union, Optional, Dict, List, Literal
16 | 
17 | import requests
18 | import os
19 | import json
20 | from dateutil.parser import isoparse
21 | import re
22 | 
23 | # ASSISTANTS_API_URL is required; LITELLM_API_URL and FC_API_URL are optional
24 | # and fall back to OpenAI inference when unset.
25 | if not os.getenv("LITELLM_API_URL"):
26 |     print("LITELLM_API_URL is not set. Defaulting to OpenAI inference.")
27 | if not os.getenv("ASSISTANTS_API_URL"):
28 |     raise ValueError("ASSISTANTS_API_URL is not set")
29 | if not os.getenv("FC_API_URL"):
30 |     print("FC_API_URL is not set. Defaulting to OpenAI inference.")
Defaulting to OpenAI inference.") 31 | 32 | litellm_client = None 33 | if os.getenv("LITELLM_API_URL"): 34 | litellm_client = OpenAI( 35 | api_key=os.getenv("LITELLM_API_KEY"), 36 | base_url=os.getenv("LITELLM_API_URL", None), 37 | ) 38 | else: 39 | litellm_client = OpenAI( 40 | api_key=os.getenv("LITELLM_API_KEY"), 41 | ) 42 | 43 | assistants_client = OpenAI( 44 | base_url=os.getenv("ASSISTANTS_API_URL"), 45 | ) 46 | 47 | if os.getenv("FC_API_URL"): 48 | fc_client = OpenAI( 49 | base_url=os.getenv("FC_API_URL"), 50 | api_key=os.getenv("FC_API_KEY"), 51 | ) 52 | else: 53 | fc_client = OpenAI( 54 | api_key=os.getenv("FC_API_KEY"), 55 | ) 56 | 57 | 58 | def chat_completion_inputs_to_prompt( 59 | messages: Iterable[ChatCompletionMessageParam], 60 | tools: Iterable[completion_create_params.ChatCompletionToolParam], 61 | ) -> str: 62 | # Convert messages to prompt format 63 | formatted_prompt = "" 64 | # count the amount of non assistant messages at the end 65 | final_user_messages = 0 66 | for message in reversed(messages): 67 | if message["role"] == "user" or message["role"] == "system": 68 | final_user_messages -= 1 69 | else: 70 | break 71 | for idx, message in enumerate(messages[:final_user_messages]): 72 | if message["role"] == "user" or message["role"] == "system": 73 | if idx != 0 and ( 74 | messages[idx - 1]["role"] == "user" 75 | or messages[idx - 1]["role"] == "system" 76 | ): 77 | formatted_prompt += f" {message['content']}" 78 | else: 79 | if idx != 0: 80 | formatted_prompt += "" 81 | formatted_prompt += f"[INST] {message['content']}" 82 | elif message["role"] == "assistant": 83 | formatted_prompt += f"[/INST]{message['content']}" 84 | else: 85 | raise ValueError(f"Invalid message type: {type(message)}") 86 | 87 | # Convert tools to prompt format 88 | formatted_prompt += ( 89 | f"[AVAILABLE_TOOLS] {json.dumps(tools)}[/AVAILABLE_TOOLS]" 90 | ) 91 | for idx, message in enumerate(messages[final_user_messages:]): 92 | if idx != 0 and ( 93 | messages[idx - 1]["role"] == "user" 94 | or messages[idx - 1]["role"] == "system" 95 | ): 96 | formatted_prompt += f" {message['content']}" 97 | else: 98 | formatted_prompt += f"[INST] {message['content']}" 99 | formatted_prompt += "[/INST]" 100 | 101 | return formatted_prompt 102 | 103 | 104 | def find_and_parse_json_objects(text): 105 | # Regular expression to find JSON arrays in the string 106 | json_pattern = re.compile(r'\[.*?\]') 107 | json_objects = [] 108 | 109 | match = json_pattern.search(text) 110 | if match: 111 | json_str = match.group() 112 | try: 113 | json_obj = json.loads(json_str) 114 | json_objects.append(json_obj) 115 | except json.JSONDecodeError as e: 116 | print(f"JSONDecodeError: {e} with string: {json_str}") 117 | 118 | return json_objects 119 | 120 | 121 | def fc_chat_completions_create( 122 | messages: Iterable[ChatCompletionMessageParam], 123 | model: Union[str, completion_create_params.ChatModel], 124 | function_call: completion_create_params.FunctionCall = NOT_GIVEN, 125 | functions: Iterable[completion_create_params.Function] = NOT_GIVEN, 126 | logit_bias: Optional[Dict[str, int]] = NOT_GIVEN, 127 | logprobs: Optional[bool] = NOT_GIVEN, 128 | max_tokens: Optional[int] = NOT_GIVEN, 129 | n: Optional[int] = NOT_GIVEN, 130 | presence_penalty: Optional[float] = NOT_GIVEN, 131 | response_format: completion_create_params.ResponseFormat = NOT_GIVEN, 132 | seed: Optional[int] = NOT_GIVEN, 133 | stop: Union[Optional[str], List[str]] = NOT_GIVEN, 134 | stream: Optional[Literal[False]] = NOT_GIVEN, 135 | stream_options: Optional[ 
136 |         completion_create_params.ChatCompletionStreamOptionsParam
137 |     ] = NOT_GIVEN,
138 |     temperature: Optional[float] = NOT_GIVEN,
139 |     tool_choice: chat_completion_tool_choice_option_param.ChatCompletionToolChoiceOptionParam = NOT_GIVEN,  # noqa
140 |     tools: Iterable[
141 |         completion_create_params.ChatCompletionToolParam
142 |     ] = NOT_GIVEN,
143 |     top_logprobs: Optional[int] = NOT_GIVEN,
144 |     top_p: Optional[float] = NOT_GIVEN,
145 |     user: str = NOT_GIVEN,
146 |     extra_headers: Optional[Dict[str, str]] = None,
147 |     extra_query: Optional[Dict[str, str]] = None,
148 |     extra_body: Optional[Dict[str, Union[str, int, float]]] = None,
149 |     timeout: Optional[float] = None,
150 | ) -> ChatCompletion:
151 |     function_signature_name = tools[0]["function"]["name"]
152 |     messages = list(messages) + [
153 |         {
154 |             "role": "user",
155 |             "content": f" YOU MUST REPLY STRICTLY FOLLOWING THE SPECIFIC JSON SCHEMA FORMAT FROM {function_signature_name} DO NOT RESPOND WITH ANYTHING ELSE.",  # noqa
156 |         }
157 |     ]
158 |     # transform inputs to prompt
159 |     prompt = chat_completion_inputs_to_prompt(messages, tools)
160 |     # make request to mistral ollama raw endpoint
161 |     body = {
162 |         "model": os.getenv("FC_MODEL"),
163 |         "prompt": prompt,
164 |         "raw": True,
165 |         "stream": False,
166 |     }
167 |     if max_tokens:
168 |         body["options"] = {"num_predict": max_tokens}
169 |     response = requests.post(
170 |         os.getenv('FC_API_URL'),
171 |         headers={
172 |             "Authorization": f"Bearer {os.getenv('FC_API_KEY')}",
173 |             "Content-Type": "application/json",
174 |         },
175 |         json=body,
176 |     )
177 |     response.raise_for_status()
178 |     json_response = response.json()
179 | 
180 |     created_at_str = json_response["created_at"]
181 |     created_at = int(isoparse(created_at_str).timestamp())
182 | 
183 |     text_response = json_response["response"]
184 |     print("\n\nFunction calling response:\n", text_response)
185 |     text_response_postfix = text_response.split("\n\n")[0]
186 |     text_response_postfix = text_response_postfix.replace("'", '"')
187 | 
188 |     json_text_response = find_and_parse_json_objects(text_response_postfix)[0]
189 | 
190 |     chat_completion = ChatCompletion(
191 |         choices=[
192 |             Choice(
193 |                 finish_reason="stop",
194 |                 index=0,
195 |                 message=ChatCompletionMessage(
196 |                     tool_calls=[
197 |                         ChatCompletionMessageToolCall(
198 |                             id="toolcall-1234567890abcdefg",
199 |                             function=Function(
200 |                                 arguments=json.dumps(
201 |                                     json_text_response[0]["arguments"]
202 |                                 ),
203 |                                 name=json_text_response[0]["name"],
204 |                             ),
205 |                             type="function",
206 |                         )
207 |                     ],
208 |                     role="assistant",
209 |                 ),
210 |             )
211 |         ],
212 |         created=created_at,
213 |         id="chatcmpl-1234567890abcdefg",
214 |         model=model,
215 |         object="chat.completion",
216 |     )
217 |     return chat_completion
218 | 
--------------------------------------------------------------------------------
/assistants_api/app/lib/db/models.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from sqlalchemy import (
3 |     ARRAY,
4 |     BigInteger,
5 |     Column,
6 |     Float,
7 |     ForeignKey,
8 |     String,
9 |     Integer,
10 |     JSON,
11 |     Enum,
12 | )
13 | from sqlalchemy.orm import relationship
14 | from .database import Base
15 | 
16 | 
17 | class Assistant(Base):
18 |     __tablename__ = "assistants"
19 | 
20 |     id = Column(String, primary_key=True)
21 |     object = Column(String, nullable=False, default="assistant")
22 |     created_at = Column(Integer, nullable=False)
23 |     name = Column(String(256))
24 |     description = Column(String(512))
25 |     model = Column(String(256), nullable=False)
26 |     instructions = Column(String(32768), default="")
27 |     tools = Column(JSON)
28 |     _metadata = Column("metadata", JSON, nullable=True)
29 |     response_format = Column(
30 |         String(256)
31 |     )  # Assuming simple string to represent the format
32 |     temperature = Column(Float)
33 |     tool_resources = Column(JSON)
34 |     top_p = Column(Float)
35 | 
36 |     # # If there's a relationship with users (assuming one assistant can belong to one user)  # noqa
37 |     # user_id = Column(String, ForeignKey('users.id'))
38 |     # owner = relationship("User", back_populates="user_gpts")
39 | 
40 | 
41 | class FilePurpose(enum.Enum):  # Python enum, not the sqlalchemy Enum column type
42 |     FINE_TUNE = "fine-tune"
43 |     ASSISTANTS = "assistants"
44 | 
45 | 
46 | class FileStatus(enum.Enum):
47 |     UPLOADED = "uploaded"
48 |     PROCESSED = "processed"
49 |     ERROR = "error"
50 | 
51 | 
52 | class File(Base):
53 |     __tablename__ = "files"
54 | 
55 |     id = Column(String, primary_key=True, index=True)
56 |     bytes = Column(Integer, nullable=False)
57 |     created_at = Column(Integer, nullable=False)
58 |     filename = Column(String(256), nullable=False)
59 |     object = Column(
60 |         Enum("file", name="file_object"),
61 |         nullable=False,
62 |         default="file",
63 |     )
64 |     purpose = Column(
65 |         Enum("assistants", name="file_purpose"),
66 |         nullable=False,
67 |     )
68 |     status = Column(
69 |         Enum("uploaded", name="file_status"),
70 |         nullable=False,
71 |     )
72 |     status_details = Column(String(512), nullable=True)
73 | 
74 | 
75 | class Thread(Base):
76 |     __tablename__ = "threads"
77 | 
78 |     id = Column(String, primary_key=True, index=True)
79 |     created_at = Column(Integer, nullable=False)
80 |     object = Column(String, nullable=False, default="thread")
81 |     _metadata = Column("metadata", JSON, nullable=True)
82 | 
83 | 
84 | class Message(Base):
85 |     __tablename__ = "messages"
86 | 
87 |     id = Column(String, primary_key=True, index=True)
88 |     object = Column(String, nullable=False, default="thread.message")
89 |     created_at = Column(
90 |         BigInteger, nullable=False
91 |     )  # BigInteger to ensure no repeating timestamps
92 |     thread_id = Column(String, ForeignKey('threads.id'))
93 |     role = Column(Enum('user', 'assistant', name='role_types'), nullable=False)
94 |     content = Column(
95 |         ARRAY(JSON), nullable=False
96 |     )  # Structured content (text/images)
97 |     attachments = Column(JSON, nullable=True)
98 |     assistant_id = Column(String, nullable=True)
99 |     run_id = Column(String, nullable=True)
100 |     _metadata = Column("metadata", JSON, nullable=True)
101 |     status = Column(
102 |         Enum('in_progress', 'incomplete', 'completed', name='status_types'),
103 |         nullable=False,
104 |         default='in_progress',
105 |     )
106 |     completed_at = Column(Integer, nullable=True)
107 |     incomplete_at = Column(Integer, nullable=True)
108 |     incomplete_details = Column(JSON, nullable=True)
109 | 
110 |     thread = relationship("Thread", back_populates="messages")
111 | 
112 | 
113 | Thread.messages = relationship(
114 |     "Message", order_by=Message.created_at, back_populates="thread"
115 | )
116 | 
117 | 
118 | class Run(Base):
119 |     __tablename__ = 'runs'
120 | 
121 |     id = Column(String, primary_key=True, index=True)
122 |     assistant_id = Column(String, index=False)
123 |     cancelled_at = Column(Integer, nullable=True)
124 |     completed_at = Column(Integer, nullable=True)
125 |     created_at = Column(Integer, nullable=False)
126 |     expires_at = Column(
127 |         Integer, nullable=True
128 |     )  # Changed from nullable=False to nullable=True
129 |     failed_at = Column(Integer, nullable=True)
130 |     incomplete_details = Column(JSON, nullable=True)  # Added field
131 |     instructions = Column(String, nullable=False, default="")
132 |     last_error = Column(JSON, nullable=True)
133 |     max_completion_tokens = Column(Integer, nullable=True)  # Added field
134 |     max_prompt_tokens = Column(Integer, nullable=True)  # Added field
135 |     _metadata = Column(
136 |         "metadata", JSON, nullable=True
137 |     )  # the Python attribute _metadata maps to the "metadata" column
138 |     model = Column(String, nullable=False)
139 |     object = Column(String, nullable=False, default="thread.run")
140 |     required_action = Column(JSON, nullable=True)  # Added field
141 |     response_format = Column(JSON, nullable=True)  # Added field
142 |     started_at = Column(Integer, nullable=True)
143 |     status = Column(String, nullable=False)
144 |     thread_id = Column(String, ForeignKey('threads.id'))
145 |     tool_choice = Column(JSON, nullable=True)  # Added field
146 |     tools = Column(
147 |         JSON, nullable=True, default=[]
148 |     )  # Modified default to match list in Pydantic schema
149 |     truncation_strategy = Column(JSON, nullable=True)  # Added field
150 |     usage = Column(JSON, nullable=True)
151 |     temperature = Column(Float, nullable=True)  # Added field
152 |     top_p = Column(Float, nullable=True)  # Added field
153 | 
154 |     thread = relationship("Thread", back_populates="runs")
155 | 
156 | 
157 | Thread.runs = relationship(
158 |     "Run", order_by=Run.created_at, back_populates="thread"
159 | )
160 | 
161 | 
162 | class RunStep(Base):
163 |     __tablename__ = "run_steps"
164 | 
165 |     id = Column(String, primary_key=True, index=True)
166 |     assistant_id = Column(String, ForeignKey('assistants.id'))
167 |     cancelled_at = Column(Integer, nullable=True)
168 |     completed_at = Column(Integer, nullable=True)
169 |     created_at = Column(Integer, nullable=False)
170 |     expired_at = Column(Integer, nullable=True)
171 |     failed_at = Column(Integer, nullable=True)
172 |     last_error = Column(JSON, nullable=True)
173 |     _metadata = Column("metadata", JSON, nullable=True)
174 |     object = Column(String, nullable=False, default="thread.run.step")
175 |     run_id = Column(String, ForeignKey('runs.id'))
176 |     status = Column(
177 |         Enum(
178 |             "in_progress",
179 |             "cancelled",
180 |             "failed",
181 |             "completed",
182 |             "expired",
183 |             name="run_step_status",
184 |         ),
185 |         nullable=False,
186 |     )
187 |     step_details = Column(
188 |         JSON, nullable=False
189 |     )  # To store details refer to https://github.com/OpenGPTs-platform/assistants-api/issues/12  # noqa
190 |     thread_id = Column(String, ForeignKey('threads.id'))
191 |     type = Column(
192 |         Enum("message_creation", "tool_calls", name="run_step_type"),
193 |         nullable=False,
194 |     )
195 |     usage = Column(JSON, nullable=True)
196 | 
197 |     # assistant = relationship("Assistant", back_populates="run_steps")
198 |     # run = relationship("Run", back_populates="run_steps")
199 |     thread = relationship("Thread", back_populates="run_steps")
200 | 
201 | 
202 | Thread.run_steps = relationship(
203 |     "RunStep", order_by=RunStep.created_at, back_populates="thread"
204 | )
205 | 
206 | 
207 | class VectorStore(Base):
208 |     __tablename__ = 'vector_stores'
209 | 
210 |     id = Column(String, primary_key=True, index=True)
211 |     created_at = Column(Integer, nullable=False)
212 |     last_active_at = Column(Integer, nullable=True)
213 |     _metadata = Column("metadata", JSON, nullable=True)
214 |     name = Column(String(256), nullable=False)
215 |     object = Column(String, nullable=False, default="vector_store")
216 |     status = Column(
217 |         Enum(
218 |             "in_progress",
219 |             "completed",
220 |             "expired",
221 |             name="vector_store_status",
222 |         ),
223 |         nullable=False,
224 |     )
225 |     usage_bytes = Column(Integer, nullable=False)
226 |     file_counts = Column(JSON, nullable=False)
227 |     expires_after = Column(JSON, nullable=True)
228 |     expires_at = Column(Integer, nullable=True)
229 | 
230 | 
231 | class VectorStoreFileBatch(Base):
232 |     __tablename__ = "vector_store_file_batches"
233 | 
234 |     id = Column(String, primary_key=True, index=True)
235 |     created_at = Column(Integer, nullable=False)
236 |     vector_store_id = Column(String, index=True, nullable=False)
237 |     object = Column(String, nullable=False, default="vector_store.files_batch")
238 |     status = Column(
239 |         Enum(
240 |             "in_progress",
241 |             "completed",
242 |             "cancelled",
243 |             "failed",
244 |             name="batch_status",
245 |         ),
246 |         default="in_progress",
247 |     )
248 |     file_counts = Column(JSON, nullable=False)
249 | 
--------------------------------------------------------------------------------
/assistants_api/tests/test_run_function_calling.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from openai import OpenAI
3 | 
4 | 
5 | from openai.types.beta.function_tool import FunctionTool
6 | from openai.types.beta.threads.required_action_function_tool_call import (
7 |     RequiredActionFunctionToolCall,
8 | )
9 | from openai.types.beta.threads.runs.run_step import RunStep
10 | from openai.types.beta.threads.runs.tool_calls_step_details import (
11 |     ToolCallsStepDetails,
12 | )
13 | from openai.types.beta.threads.runs.message_creation_step_details import (
14 |     MessageCreationStepDetails,
15 | )
16 | 
17 | # from openai.types.beta.threads.runs import RunStep
18 | import os
19 | import json
20 | 
21 | api_key = os.getenv("OPENAI_API_KEY")  # None when the variable is unset
22 | use_openai = bool(os.getenv("USE_OPENAI"))
23 | base_url = "http://localhost:8000"
24 | 
25 | current_dir = os.path.dirname(__file__)
26 | code_reference_file_path = os.path.join(
27 |     current_dir, '..', 'assets', 'code-reference.txt'
28 | )
29 | 
30 | 
31 | @pytest.fixture
32 | def openai_client():
33 |     if use_openai:
34 |         return OpenAI(
35 |             api_key=api_key,
36 |         )
37 |     else:
38 |         return OpenAI(
39 |             base_url=base_url,
40 |         )
41 | 
42 | 
43 | @pytest.fixture
44 | def thread_id(openai_client: OpenAI):
45 |     response = openai_client.beta.threads.create(
46 |         messages=[
47 |             {
48 |                 "role": "user",
49 |                 "content": "What's the weather in San Francisco today?",  # noqa
50 |             }
51 |         ],
52 |     )
53 |     return response.id
54 | 
55 | 
56 | get_current_temperature_sig = {
57 |     "name": "get_current_temperature",
58 |     "description": "Get the current temperature for a specific location",
59 |     "parameters": {
60 |         "type": "object",
61 |         "properties": {
62 |             "location": {
63 |                 "type": "string",
64 |                 "description": "The city and state, e.g., San Francisco, CA",
65 |             },
66 |             "unit": {
67 |                 "type": "string",
68 |                 "enum": ["Celsius", "Fahrenheit"],
69 |                 "description": "The temperature unit to use. Infer this from the user's location.",  # noqa
70 |             },
71 |         },
72 |         "required": ["location", "unit"],
73 |     },
74 | }
75 | 
76 | 
77 | @pytest.fixture
78 | def assistant_id(openai_client: OpenAI):
79 |     response = openai_client.beta.assistants.create(
80 |         name="Tool Assistant",
81 |         model="gpt-3.5-turbo",
82 |         tools=[
83 |             {
84 |                 "type": "function",
85 |                 "function": get_current_temperature_sig,
86 |             },
87 |         ],
88 |     )
89 | 
90 |     return response.id
91 | 
92 | 
93 | @pytest.fixture
94 | def run_id(openai_client: OpenAI, thread_id: str, assistant_id: str):
95 |     response = openai_client.beta.threads.runs.create(
96 |         thread_id=thread_id,
97 |         assistant_id=assistant_id,
98 |     )
99 |     return response.id
100 | 
101 | 
102 | @pytest.mark.dependency()
103 | def test_create_asst_with_fc(openai_client: OpenAI):
104 |     response = openai_client.beta.assistants.create(
105 |         name="Tool Assistant",
106 |         model="gpt-3.5-turbo",
107 |         tools=[
108 |             {
109 |                 "type": "function",
110 |                 "function": {
111 |                     "name": "get_current_temperature",
112 |                     "description": "Get the current temperature for a specific location",  # noqa
113 |                     "parameters": {
114 |                         "type": "object",
115 |                         "properties": {
116 |                             "location": {
117 |                                 "type": "string",
118 |                                 "description": "The city and state, e.g., San Francisco, CA",  # noqa
119 |                             },
120 |                             "unit": {
121 |                                 "type": "string",
122 |                                 "enum": ["Celsius", "Fahrenheit"],
123 |                                 "description": "The temperature unit to use. Infer this from the user's location.",  # noqa
124 |                             },
125 |                         },
126 |                         "required": ["location", "unit"],
127 |                     },
128 |                 },
129 |             },
130 |         ],
131 |     )
132 |     asst = openai_client.beta.assistants.retrieve(assistant_id=response.id)
133 |     assert asst.tools[0].type == "function"
134 |     assert isinstance(asst.tools[0], FunctionTool)
135 |     assert asst.tools[0].function.name == get_current_temperature_sig["name"]
136 | 
137 | 
138 | @pytest.mark.dependency(depends=["test_create_asst_with_fc"])
139 | def test_execute_fc_run_to_tool_call(
140 |     openai_client: OpenAI, assistant_id: str, thread_id: str
141 | ):
142 |     run = openai_client.beta.threads.runs.create_and_poll(
143 |         thread_id=thread_id,
144 |         assistant_id=assistant_id,
145 |     )
146 |     assert run.status == "requires_action"
147 |     assert isinstance(
148 |         run.required_action.submit_tool_outputs.tool_calls[0],
149 |         RequiredActionFunctionToolCall,
150 |     )
151 |     assert (
152 |         run.required_action.submit_tool_outputs.tool_calls[0].function.name
153 |         == "get_current_temperature"
154 |     )
155 |     json_args = json.loads(
156 |         run.required_action.submit_tool_outputs.tool_calls[
157 |             0
158 |         ].function.arguments
159 |     )
160 |     for param in get_current_temperature_sig["parameters"]["properties"]:
161 |         assert param in json_args
162 | 
163 |     # test run_steps
164 |     run_steps = openai_client.beta.threads.runs.steps.list(
165 |         run_id=run.id,
166 |         thread_id=thread_id,
167 |     )
168 |     assert len(run_steps.data) >= 1
169 |     latest_step: RunStep = run_steps.data[0]
170 |     assert latest_step.status == "in_progress"
171 |     assert isinstance(latest_step.step_details, ToolCallsStepDetails)
172 | 
173 | 
174 | @pytest.mark.dependency(depends=["test_execute_fc_run_to_tool_call"])
175 | def test_execute_fc_to_submit_tool_output(
176 |     openai_client: OpenAI, assistant_id: str, thread_id: str
177 | ):
178 |     run = openai_client.beta.threads.runs.create_and_poll(
179 |         thread_id=thread_id,
180 |         assistant_id=assistant_id,
181 |     )
182 |     assert run.status == "requires_action"
183 | 
184 |     tool_outputs = [
185 |         {
186 |             "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[
187 |                 0
188 |             ].id,
189 |             "output": "57",
190 |         }
191 |     ]
192 | 
193 |     run = openai_client.beta.threads.runs.submit_tool_outputs(
194 |         thread_id=thread_id, run_id=run.id, tool_outputs=tool_outputs
195 |     )
196 | 
197 |     assert run.status == "queued"
198 |     assert run.required_action is None
199 | 
200 |     run_steps = openai_client.beta.threads.runs.steps.list(
201 |         run_id=run.id,
202 |         thread_id=thread_id,
203 |     )
204 |     tool_call_step: RunStep = run_steps.data[0]
205 |     assert tool_call_step.status == "completed"
206 |     assert isinstance(tool_call_step.step_details, ToolCallsStepDetails)
207 |     assert tool_call_step.step_details.tool_calls[0].function.output == "57"
208 | 
209 | 
210 | @pytest.mark.dependency(depends=["test_execute_fc_to_submit_tool_output"])
211 | def test_execute_full_fc_run(
212 |     openai_client: OpenAI, assistant_id: str, thread_id: str
213 | ):
214 |     run = openai_client.beta.threads.runs.create_and_poll(
215 |         thread_id=thread_id,
216 |         assistant_id=assistant_id,
217 |     )
218 |     assert run.status == "requires_action"
219 | 
220 |     tool_outputs = [
221 |         {
222 |             "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[
223 |                 0
224 |             ].id,
225 |             "output": "57",
226 |         }
227 |     ]
228 | 
229 |     run = openai_client.beta.threads.runs.submit_tool_outputs_and_poll(
230 |         thread_id=thread_id, run_id=run.id, tool_outputs=tool_outputs
231 |     )
232 | 
233 |     assert run.status == "completed"
234 | 
235 |     run_steps = openai_client.beta.threads.runs.steps.list(
236 |         run_id=run.id,
237 |         thread_id=thread_id,
238 |     )
239 |     assert len(run_steps.data) > 1
240 |     message_step: RunStep = run_steps.data[0]
241 |     assert isinstance(message_step.step_details, MessageCreationStepDetails)
242 |     message = openai_client.beta.threads.messages.retrieve(
243 |         thread_id=thread_id,
244 |         message_id=message_step.step_details.message_creation.message_id,
245 |     )
246 |     assert "57" in message.model_dump_json()
247 | 
248 |     # find the first tool_call_step.step_details of type ToolCallsStepDetails
249 |     tool_call_step = None
250 |     for step in run_steps.data:
251 |         if step.type == "tool_calls":
252 |             tool_call_step = step
253 |             break
254 |     assert tool_call_step.status == "completed"
255 |     assert isinstance(tool_call_step.step_details, ToolCallsStepDetails)
256 |     assert "57" in tool_call_step.step_details.model_dump_json()
257 | 
--------------------------------------------------------------------------------
/run_executor_worker/src/run_executor/main.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, Optional
2 | from constants import PromptKeys
3 | from utils.weaviate_utils import get_web_retrieval_description
4 | from utils.tools import ActionItem, Actions, tools_to_map
5 | from utils.ops_api_handler import create_message_runstep, update_run
6 | from data_models import run
7 | from openai.types.beta.threads.message import Message
8 | from utils.openai_clients import assistants_client
9 | from openai.types.beta.thread import Thread
10 | from openai.types.beta import Assistant
11 | from openai.pagination import SyncCursorPage
12 | from agents import router
13 | from utils import coala  # coala lives in utils (see src/utils/coala.py)
14 | from actions import function_calling_tool
15 | import json
16 | import datetime
17 | 
18 | # TODO: add assistant and base tools off of assistant
19 | 
20 | 
21 | class ExecuteRun:
22 |     def __init__(
23 |         self, thread_id: str, run_id: str, run_config: Dict[str, Any] = {}
24 |     ):
25 |         self.run_id = run_id
26 |         self.thread_id = thread_id
27 |         self.assistant_id: Optional[str] = None
28 |         self.run_config = run_config
29 | 
30 |         self.run: Optional[run.Run] = None
31 |         self.messages: Optional[SyncCursorPage[Message]] = None
32 |         self.thread: Optional[Thread] = None
33 |         self.assistant: Optional[Assistant] = None
34 |         self.tools_map: Optional[dict[str, ActionItem]] = None
35 |         self.runsteps: Optional[SyncCursorPage[run.RunStep]] = None
36 |         self.web_retrieval_description: Optional[str] = None
37 |         # TODO: add assistant and base tools off of assistant
38 | 
39 |     def execute(self):
40 |         # Create an instance of the RunUpdate schema with the new status
41 |         run_update = run.RunUpdate(status=run.RunStatus.IN_PROGRESS.value)
42 | 
43 |         # Call the API handler to update the run status
44 |         updated_run = update_run(self.thread_id, self.run_id, run_update)
45 | 
46 |         if not updated_run:
47 |             print(
48 |                 f"Error updating run status for {self.run_id}. Aborting execution."
49 |             )
50 |             return
51 | 
52 |         try:
53 |             self.run = updated_run
54 |             self.runsteps = assistants_client.beta.threads.runs.steps.list(
55 |                 run_id=self.run_id, thread_id=self.thread_id, order="desc"
56 |             )
57 |             print("\n\nExecuting run: ", self.run, "\n\n")
58 | 
59 |             # Get the thread messages
60 |             # TODO: should only populate these entities once
61 |             thread = assistants_client.beta.threads.retrieve(
62 |                 thread_id=self.thread_id,
63 |             )
64 |             self.thread = thread
65 | 
66 |             assistant = assistants_client.beta.assistants.retrieve(
67 |                 assistant_id=self.run.assistant_id,
68 |             )
69 |             self.assistant_id = assistant.id
70 |             self.assistant = assistant
71 | 
72 |             self.web_retrieval_description = get_web_retrieval_description()
73 |             self.tools_map = tools_to_map(
74 |                 self.assistant.tools, self.web_retrieval_description
75 |             )
76 | 
77 |             messages = assistants_client.beta.threads.messages.list(
78 |                 thread_id=self.thread_id, order="asc"
79 |             )
80 |             self.messages = messages
81 | 
82 |             if (
83 |                 len(self.runsteps.data)
84 |                 and self.runsteps.data[0].status == "completed"
85 |                 and self.runsteps.data[0].type == "tool_calls"
86 |                 and self.runsteps.data[0].step_details.tool_calls[0].type
87 |                 == "function"
88 |             ):
89 |                 router_response = "tool_response"
90 |             else:
91 |                 router_agent = router.RouterAgent(self)  # semantic router
92 |                 router_response = router_agent.generate()
93 | 
94 |             if (
95 |                 router_response != PromptKeys.TRANSITION.value
96 |                 and router_response != "tool_response"
97 |             ):
98 |                 create_message_runstep(
99 |                     self.thread_id,
100 |                     self.run_id,
101 |                     self.run.assistant_id,
102 |                     router_response,
103 |                 )
104 |                 print(f"\n\nFinal response:\n{router_response}")
105 |                 update_run(
106 |                     self.thread_id,
107 |                     self.run_id,
108 |                     run.RunUpdate(
109 |                         status=run.RunStatus.COMPLETED.value,
110 |                         completed_at=int(datetime.datetime.now().timestamp()),
111 |                     ),
112 |                 )
113 |                 print(
114 |                     f"""\n\nFinished executing run with status {run.RunStatus.COMPLETED.value}."""  # noqa
115 |                 )
116 |                 return
117 |             print("\n\nTRANSITIONING TO COALA\n\n")
118 | 
119 |             coala_class = coala.CoALA(
120 |                 self.run_id, self.thread_id, self.assistant_id, self
121 |             )
122 |             self.run = coala_class.retrieve_run()
123 |             self.assistant = coala_class.retrieve_assistant()
124 |             self.messages = coala_class.retrieve_messages()
125 |             self.runsteps = coala_class.retrieve_runsteps()
126 |             coala_class.set_assistant_tools()
127 | 
128 |             if router_response == PromptKeys.TRANSITION.value:
129 |                 coala_class.generate_question()
130 | 
131 |             if router_response == "tool_response":
132 |                 coala_class.load_trace()
133 | 
134 |             max_steps = 8
135 |             curr_step = 0
136 | 
137 |             requires_action = False
138 |             while (
139 |                 coala_class.react_steps[-1].step_type
140 |                 != coala.ReactStepType.FINAL_ANSWER
141 |             ):
142 |                 if router_response == "tool_response":
143 |                     router_response = PromptKeys.TRANSITION.value
144 |                     fc_tool = function_calling_tool.FunctionCallingTool(
145 |                         coala_class
146 |                     )
147 |                     fc_tool.generate_tool_summary(self.runsteps.data[-1])
148 |                     continue
149 | 
150 |                 self.messages = coala_class.retrieve_messages()
151 |                 self.runsteps = coala_class.retrieve_runsteps()
152 |                 if (
153 |                     coala_class.react_steps[-1].step_type
154 |                     != coala.ReactStepType.THOUGHT
155 |                 ):
156 |                     coala_class.generate_thought()
157 |                 coala_class.generate_action()
158 |                 current_action = Actions(coala_class.react_steps[-1].content)
159 |                 coala_class.execute_action(current_action)
160 | 
161 |                 print(
162 |                     f"""\n\nStep {curr_step} completed.
163 | with react steps:
164 | {json.dumps([step.model_dump() for step in coala_class.react_steps], indent=2)}"""
165 |                 )
166 |                 curr_step += 1
167 | 
168 |                 if current_action == Actions.FUNCTION:
169 |                     requires_action = True
170 |                     break
171 |                 if curr_step >= max_steps:
172 |                     break
173 |             # Final status by loop exit: FINAL_ANSWER -> completed, FUNCTION -> requires_action, step limit -> failed  # noqa
174 |             if (
175 |                 coala_class.react_steps[-1].step_type
176 |                 == coala.ReactStepType.FINAL_ANSWER
177 |             ):
178 |                 run_update = run.RunUpdate(
179 |                     status=run.RunStatus.COMPLETED.value,
180 |                     completed_at=int(datetime.datetime.now().timestamp()),
181 |                 )
182 |             elif requires_action:
183 |                 latest_react_step = coala_class.react_steps[-1]
184 |                 run_update = run.RunUpdate(
185 |                     status=run.RunStatus.REQUIRES_ACTION.value,
186 |                     required_action=run.RequiredAction.model_validate(
187 |                         {
188 |                             "submit_tool_outputs": {
189 |                                 "tool_calls": [
190 |                                     json.loads(latest_react_step.content)
191 |                                 ]
192 |                             },
193 |                             "type": "submit_tool_outputs",
194 |                         }
195 |                     ),
196 |                 )
197 |             else:
198 |                 run_update = run.RunUpdate(
199 |                     status=run.RunStatus.FAILED.value,
200 |                     failed_at=int(datetime.datetime.now().timestamp()),
201 |                 )
202 |             updated_run = update_run(self.thread_id, self.run_id, run_update)
203 | 
204 |             print(
205 |                 f"""\n\nFinished executing run with status {run_update.status} after {curr_step} steps."""  # noqa
206 |             )
207 |         except Exception as e:
208 |             print(f"Error executing run: {e}")
209 |             run_update = run.RunUpdate(
210 |                 status=run.RunStatus.FAILED.value,
211 |                 failed_at=int(datetime.datetime.now().timestamp()),
212 |             )
213 |             updated_run = update_run(self.thread_id, self.run_id, run_update)
214 |             print(f"Run failed: {updated_run}")
215 |             return
216 | 
--------------------------------------------------------------------------------
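Usage note: tests/test_run_function_calling.py above doubles as documentation for the function-calling round trip. The following is a minimal sketch distilled from those tests, assuming the stack is running locally with the Assistants API on http://localhost:8000 (as in the test module); the tool output "57" is a stand-in for a real temperature lookup.

from openai import OpenAI

# Point the stock OpenAI client at the locally hosted Assistants API.
client = OpenAI(base_url="http://localhost:8000")

get_current_temperature_sig = {
    "name": "get_current_temperature",
    "description": "Get the current temperature for a specific location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {"type": "string"},
            "unit": {"type": "string", "enum": ["Celsius", "Fahrenheit"]},
        },
        "required": ["location", "unit"],
    },
}

assistant = client.beta.assistants.create(
    name="Tool Assistant",
    model="gpt-3.5-turbo",
    tools=[{"type": "function", "function": get_current_temperature_sig}],
)
thread = client.beta.threads.create(
    messages=[
        {"role": "user", "content": "What's the weather in San Francisco today?"}
    ]
)

# The worker pauses the run in `requires_action` once it emits a tool call.
run = client.beta.threads.runs.create_and_poll(
    thread_id=thread.id, assistant_id=assistant.id
)
tool_call = run.required_action.submit_tool_outputs.tool_calls[0]

# Submitting the tool result re-queues the run; the worker then writes the
# final assistant message that references the output.
run = client.beta.threads.runs.submit_tool_outputs_and_poll(
    thread_id=thread.id,
    run_id=run.id,
    tool_outputs=[{"tool_call_id": tool_call.id, "output": "57"}],
)
print(run.status)  # "completed" when the full round trip succeeds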
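On the worker side, fc_chat_completions_create (run_executor_worker/src/utils/openai_clients.py above) mimics the chat.completions.create interface while posting a raw Mistral-style [INST]/[AVAILABLE_TOOLS] prompt to the FC_API_URL endpoint. The sketch below is a hedged illustration of a call site, not code from the repository; it reuses get_current_temperature_sig from the previous sketch and assumes FC_API_URL, FC_API_KEY, and FC_MODEL are configured as suggested by .env.example.

from utils.openai_clients import fc_chat_completions_create

# Returns a ChatCompletion whose first choice carries the parsed tool call.
completion = fc_chat_completions_create(
    messages=[
        {"role": "user", "content": "What's the weather in San Francisco today?"}
    ],
    model="mistral",  # hypothetical label; the endpoint actually serves FC_MODEL
    tools=[{"type": "function", "function": get_current_temperature_sig}],
)
tool_call = completion.choices[0].message.tool_calls[0]
print(tool_call.function.name)       # e.g. "get_current_temperature"
print(tool_call.function.arguments)  # JSON string of the parsed arguments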