├── src └── aiaio │ ├── app │ ├── __init__.py │ ├── templates │ │ └── index.html │ └── app.py │ ├── __init__.py │ ├── cli │ ├── __init__.py │ ├── aiaio.py │ └── run_app.py │ ├── logging.py │ ├── prompts.py │ └── db.py ├── ui.png ├── Makefile ├── Dockerfile ├── .github └── workflows │ ├── code_quality.yml │ └── build_mac.yml ├── pyproject.toml ├── .dockerignore ├── .gitignore ├── README.md └── LICENSE /src/aiaio/app/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhishekkrthakur/aiaio/HEAD/ui.png -------------------------------------------------------------------------------- /src/aiaio/__init__.py: -------------------------------------------------------------------------------- 1 | from aiaio.logging import Logger 2 | 3 | 4 | __version__ = "0.0.9" 5 | logger = Logger().get_logger() 6 | -------------------------------------------------------------------------------- /src/aiaio/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from argparse import ArgumentParser 3 | 4 | 5 | class BaseCLICommand(ABC): 6 | @staticmethod 7 | @abstractmethod 8 | def register_subcommand(parser: ArgumentParser): 9 | raise NotImplementedError() 10 | 11 | @abstractmethod 12 | def run(self): 13 | raise NotImplementedError() 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style test 2 | 3 | quality: 4 | black --check --line-length 119 --target-version py310 --exclude .venv . 5 | isort --check-only --skip .venv . 6 | flake8 --max-line-length 119 --exclude .venv 7 | 8 | style: 9 | black --line-length 119 --target-version py310 --exclude .venv . 10 | isort --skip .venv . 11 | 12 | test: 13 | pytest -sv ./src/ 14 | 15 | pip: 16 | rm -rf build/ 17 | rm -rf dist/ 18 | make style && make quality 19 | python -m build 20 | twine upload dist/* --verbose --repository aiaio 21 | 22 | docker-build: 23 | docker build -t aiaio . 24 | 25 | docker-run: 26 | docker run --network=host -it --rm aiaio -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | WORKDIR /app 4 | 5 | # Set Python to run in unbuffered mode 6 | ENV PYTHONUNBUFFERED=1 7 | 8 | # Install system dependencies 9 | RUN apt-get update && \ 10 | apt-get install -y --no-install-recommends \ 11 | build-essential \ 12 | && rm -rf /var/lib/apt/lists/* 13 | 14 | # Copy only the necessary files 15 | COPY pyproject.toml . 16 | COPY README.md . 17 | COPY LICENSE . 18 | COPY Makefile . 19 | COPY src/ src/ 20 | 21 | # Install dependencies and package 22 | RUN pip install --no-cache-dir -e ".[dev]" 23 | 24 | # Expose the port from environment variable 25 | EXPOSE 9000 26 | 27 | CMD ["aiaio", "app", "--host", "0.0.0.0", "--port", "9000"] 28 | -------------------------------------------------------------------------------- /.github/workflows/code_quality.yml: -------------------------------------------------------------------------------- 1 | name: Code quality 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | release: 11 | types: 12 | - created 13 | 14 | jobs: 15 | check_code_quality: 16 | name: Check code quality 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python 3.10 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: 3.10.16 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | python -m pip install flake8 black isort 28 | - name: Make quality 29 | run: | 30 | make quality -------------------------------------------------------------------------------- /src/aiaio/cli/aiaio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from aiaio import __version__ 4 | from aiaio.cli.run_app import RunAppCommand 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser( 9 | "aiaio cli", 10 | usage="aiaio []", 11 | epilog="For more information about a command, run: `aiaio --help`", 12 | ) 13 | parser.add_argument("--version", "-v", help="Display version", action="store_true") 14 | commands_parser = parser.add_subparsers(help="commands") 15 | 16 | # Register commands 17 | RunAppCommand.register_subcommand(commands_parser) 18 | 19 | args = parser.parse_args() 20 | 21 | if args.version: 22 | print(__version__) 23 | exit(0) 24 | 25 | if not hasattr(args, "func"): 26 | parser.print_help() 27 | exit(1) 28 | 29 | command = args.func(args) 30 | command.run() 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /.github/workflows/build_mac.yml: -------------------------------------------------------------------------------- 1 | name: Build Mac App 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build: 10 | runs-on: macos-14 # Explicitly target Apple Silicon (M1) 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Install uv 16 | uses: astral-sh/setup-uv@v3 17 | 18 | - name: Set up Python 19 | run: uv python install 3.10 20 | 21 | - name: Install dependencies 22 | run: | 23 | uv venv 24 | uv pip install . 25 | # Ensure build dependencies are installed explicitly just in case 26 | uv pip install pywebview pyinstaller pillow 27 | 28 | - name: Build Mac App 29 | run: | 30 | # Activate venv 31 | source .venv/bin/activate 32 | bash build_mac.sh 33 | 34 | - name: Upload Artifact 35 | uses: actions/upload-artifact@v4 36 | with: 37 | name: aiaio-mac-arm64-dmg 38 | path: dist/aiaio.dmg 39 | -------------------------------------------------------------------------------- /src/aiaio/logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | 4 | from loguru import logger 5 | 6 | 7 | @dataclass 8 | class Logger: 9 | """ 10 | A custom logger class that sets up and manages logging configuration. 11 | 12 | Methods 13 | ------- 14 | __post_init__(): 15 | Initializes the logger with a specific format and sets up the logger. 16 | 17 | setup_logger(): 18 | Configures the logger to output to stdout with the specified format and filter. 19 | 20 | get_logger(): 21 | Returns the configured logger instance. 22 | """ 23 | 24 | def __post_init__(self): 25 | self.log_format = ( 26 | "{level: <8} | " 27 | "{time:YYYY-MM-DD HH:mm:ss} | " 28 | "{name}:{function}:{line} - " 29 | "{message}" 30 | ) 31 | self.logger = logger 32 | self.setup_logger() 33 | 34 | def setup_logger(self): 35 | self.logger.remove() 36 | self.logger.add(sys.stdout, format=self.log_format) 37 | 38 | def get_logger(self): 39 | return self.logger 40 | -------------------------------------------------------------------------------- /src/aiaio/prompts.py: -------------------------------------------------------------------------------- 1 | # This file is kept for reference. The prompts are now also stored in the database. 2 | 3 | SUMMARY_PROMPT = """ 4 | you are a bot that summarizes user messages in less than 50 characters. 5 | just write a summary of the conversation. dont write this is a summary. 6 | dont answer the question, just summarize the conversation. 7 | the user wants to know what the conversation is about, not the answers. 8 | 9 | Examples: 10 | input: {'role': 'user', 'content': "['how to inverse a string in python?']"} 11 | output: reverse a string in python 12 | 13 | input: {'role': 'user', 'content': "['hi', 'how are you?', 'how do i install pandas?']"} 14 | output: greeting, install pandas 15 | 16 | input: {'role': 'user', 'content': "['hi']"} 17 | output: greeting 18 | 19 | input: {'role': 'user', 'content': "['hi', 'how are you?']"} 20 | output: greeting 21 | 22 | input: {'role': 'user', 'content': "['write a python snake game', 'thank you']"} 23 | output: python snake game 24 | """ 25 | 26 | DEFAULT_SYSTEM_PROMPT = """ 27 | You are a helpful bot that assists users with their queries. 28 | You should provide a helpful response to the user's query. 29 | """ 30 | 31 | SYSTEM_PROMPTS = { 32 | "summary": SUMMARY_PROMPT, 33 | "default": DEFAULT_SYSTEM_PROMPT, 34 | } 35 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "aiaio" 7 | version = "0.10.0" 8 | description = "aiaio" 9 | readme = "README.md" 10 | authors = [ 11 | { name = "abhishek" } 12 | ] 13 | classifiers = [ 14 | "Development Status :: 5 - Production/Stable", 15 | "License :: OSI Approved :: Apache Software License", 16 | "Operating System :: OS Independent", 17 | "Programming Language :: Python :: 3.10", 18 | "Programming Language :: Python :: 3.11", 19 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 20 | ] 21 | keywords = ["aiaio"] 22 | dependencies = ["fastapi", "uvicorn", "loguru", "jinja2", "python-multipart", "openai", "websockets", "pywebview", "pyinstaller", "pillow"] 23 | 24 | [project.scripts] 25 | aiaio = "aiaio.cli.aiaio:main" 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/abhishekkrthakur/aiaio" 29 | Issues = "https://github.com/abhishekkrthakur/aiaio/issues" 30 | 31 | [tool.setuptools.package-data] 32 | aiaio = [ 33 | "app/static/*", 34 | "app/templates/*" 35 | ] 36 | 37 | [tool.isort] 38 | ensure_newline_before_comments = true 39 | force_grid_wrap = 0 40 | include_trailing_comma = true 41 | line_length = 119 42 | lines_after_imports = 2 43 | multi_line_output = 3 44 | use_parentheses = true 45 | 46 | [tool.flake8] 47 | ignore = ["E203", "E501", "W503"] 48 | max-line-length = 119 49 | per-file-ignores = { "__init__.py" = ["F401", "E402"] } 50 | exclude = [".git", ".venv", "__pycache__", "dist", "build"] 51 | 52 | [tool.pytest.ini_options] 53 | addopts = "--maxfail=1 --disable-warnings" 54 | testpaths = ["tests"] 55 | -------------------------------------------------------------------------------- /src/aiaio/cli/run_app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from argparse import ArgumentParser 3 | 4 | import uvicorn 5 | 6 | from aiaio import logger 7 | 8 | from . import BaseCLICommand 9 | 10 | 11 | def run_app_command_factory(args): 12 | return RunAppCommand(args.port, args.host, args.workers) 13 | 14 | 15 | class RunAppCommand(BaseCLICommand): 16 | @staticmethod 17 | def register_subcommand(parser: ArgumentParser): 18 | run_app_parser = parser.add_parser( 19 | "app", 20 | description="✨ Run app", 21 | ) 22 | run_app_parser.add_argument( 23 | "--port", 24 | type=int, 25 | default=10000, 26 | help="Port to run the app on", 27 | required=False, 28 | ) 29 | run_app_parser.add_argument( 30 | "--host", 31 | type=str, 32 | default="127.0.0.1", 33 | help="Host to run the app on", 34 | required=False, 35 | ) 36 | run_app_parser.add_argument( 37 | "--workers", 38 | type=int, 39 | default=1, 40 | help="Number of workers to run the app with", 41 | required=False, 42 | ) 43 | run_app_parser.set_defaults(func=run_app_command_factory) 44 | 45 | def __init__(self, port, host, workers): 46 | self.port = port 47 | self.host = host 48 | self.workers = workers 49 | 50 | def run(self): 51 | 52 | logger.info("Starting aiaio server.") 53 | 54 | try: 55 | uvicorn.run("aiaio.app.app:app", host=self.host, port=self.port, workers=self.workers) 56 | except KeyboardInterrupt: 57 | logger.warning("Server terminated by user.") 58 | sys.exit(0) 59 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | pythonenv3.10/ 2 | pythonenv3.8/ 3 | *.db 4 | 5 | # data files 6 | *.bin 7 | *.csv 8 | 9 | # vscode 10 | .vscode/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | pythonenv3.10/ 2 | pythonenv3.8/ 3 | *.db 4 | uv.lock 5 | .DS_Store 6 | 7 | # data files 8 | *.bin 9 | *.csv 10 | 11 | # vscode 12 | .vscode/ 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # aiaio (AI-AI-O) 2 | 3 | A lightweight, privacy-focused web UI for interacting with AI models. Supports both local and remote LLM deployments through OpenAI-compatible APIs. 4 | 5 | ![Screenshot](https://github.com/abhishekkrthakur/aiaio/blob/main/ui.png?raw=true) 6 | 7 | ## Features 8 | 9 | - 🌓 Dark/Light mode support 10 | - 💾 Local SQLite database for conversation storage 11 | - 📁 File upload and processing (images, documents, etc.) 12 | - ⚙️ Configurable model parameters through UI 13 | - 🔒 Privacy-focused (all data stays local) 14 | - 📱 Responsive design for mobile/desktop 15 | - 🎨 Syntax highlighting for code blocks 16 | - 📋 One-click code block copying 17 | - 🔄 Real-time conversation updates 18 | - 📝 Automatic conversation summarization 19 | - 🎯 Customizable system prompts 20 | - 🌐 WebSocket support for real-time updates 21 | - 📦 Docker support for easy deploymen 22 | - 📦 Multiple API endpoint support 23 | - 📦 Multiple system prompt support 24 | 25 | ## Requirements 26 | 27 | 28 | - Python 3.10+ 29 | - An OpenAI-compatible API endpoint (local or remote) 30 | 31 | ## Supported API Endpoints 32 | 33 | aiaio works with any OpenAI-compatible API endpoint, including: 34 | 35 | - OpenAI API 36 | - vLLM 37 | - Text Generation Inference (TGI) 38 | - Hugging Face Inference Endpoints 39 | - llama.cpp server 40 | - LocalAI 41 | - Custom OpenAI-compatible APIs 42 | 43 | For example, you can serve llama 8b using vLLM using: 44 | 45 | ```bash 46 | vllm serve Meta-Llama-3.1-8B-Instruct.Q4_K_M.gguf --tokenizer meta-llama/Llama-3.1-8B-Instruct --max_model_len 125000 47 | ``` 48 | 49 | and once the api is running, you can access it using aiaio ui. 50 | 51 | ## Installation 52 | 53 | ### Using pip 54 | 55 | ```bash 56 | pip install aiaio 57 | ``` 58 | 59 | ### From source 60 | 61 | ```bash 62 | git clone https://github.com/abhishekkrthakur/aiaio.git 63 | cd aiaio 64 | pip install -e . 65 | ``` 66 | 67 | ## Quick Start 68 | 69 | 1. Start the server: 70 | ```bash 71 | aiaio app --host 127.0.0.1 --port 5000 72 | ``` 73 | 74 | 2. Open your browser and navigate to `http://127.0.0.1:5000` 75 | 76 | 3. Configure your API endpoint and model settings in the UI 77 | 78 | ## Docker Usage 79 | 80 | 1. Build the Docker image: 81 | ```bash 82 | docker build -t aiaio . 83 | ``` 84 | 85 | 2. Run the container: 86 | ```bash 87 | docker run --network host \ 88 | -v /path/to/data:/data \ 89 | aiaio 90 | ``` 91 | 92 | The `/data` volume mount is optional but recommended for persistent storage of the SQLite database and uploaded files. 93 | 94 | ## UI Configuration 95 | 96 | All model and API settings can be configured through the UI: 97 | 98 | ### Model Parameters 99 | - **Temperature** (0-2): Controls response randomness. Higher values make output more creative but less focused 100 | - **Max Tokens** (1-32k): Maximum length of generated responses 101 | - **Top P** (0-1): Controls diversity via nucleus sampling. Lower values make output more focused 102 | - **Model Name**: Name/path of the model to use (depends on your API endpoint) 103 | 104 | ### API Configuration 105 | - **Host**: URL of your OpenAI-compatible API endpoint 106 | - **API Key**: Authentication key if required by your endpoint 107 | 108 | These settings are stored in the local SQLite database and persist between sessions. 109 | 110 | ## File Handling 111 | 112 | aiaio supports uploading and processing various file types, depending on the model's capabilities: 113 | 114 | - Images (PNG, JPG, GIF, etc.) 115 | - Documents (PDF, DOC, DOCX) 116 | - Text files (TXT, CSV, JSON) 117 | - Audio files (depends on model capabilities) 118 | - Video files (depends on model capabilities) 119 | 120 | Uploaded files are stored temporarily and can be referenced in conversations. 121 | 122 | ## Database Schema 123 | 124 | aiaio uses SQLite for storage with the following main tables: 125 | 126 | - `conversations`: Stores chat histories and summaries 127 | - `messages`: Stores individual messages within conversations 128 | - `attachments`: Stores file attachment metadata 129 | - `settings`: Stores UI and model configuration 130 | 131 | ## Advanced Usage 132 | 133 | ### Custom System Prompts 134 | 135 | Each conversation can have its own system prompt that guides the AI's behavior. Click the "System Prompt" section above the chat to customize it. 136 | 137 | ### Conversation Management 138 | 139 | - Create new conversations using the "+ New Chat" button 140 | - Switch between conversations in the left sidebar 141 | - Delete conversations using the trash icon 142 | - View conversation summaries in the sidebar 143 | 144 | ### Keyboard Shortcuts 145 | 146 | - `Ctrl/Cmd + Enter`: Send message 147 | - `Esc`: Clear input 148 | - `Ctrl/Cmd + K`: Focus chat input 149 | - `Ctrl/Cmd + /`: Toggle settings sidebar 150 | 151 | ## Development 152 | 153 | ```bash 154 | # Clone the repository 155 | git clone https://github.com/abhishekkrthakur/aiaio.git 156 | cd aiaio 157 | 158 | # Create a virtual environment 159 | python -m venv venv 160 | source venv/bin/activate # or `venv\Scripts\activate` on Windows 161 | 162 | # Install development dependencies 163 | pip install -e ".[dev]" 164 | 165 | # Run tests 166 | pytest 167 | 168 | # Run with auto-reload for development 169 | uvicorn aiaio.app.app:app --reload --port 5000 170 | ``` 171 | 172 | ## Contributing 173 | 174 | Contributions are welcome! Please: 175 | 176 | 1. Fork the repository 177 | 2. Create a feature branch 178 | 3. Make your changes 179 | 4. Run the tests (`pytest`) 180 | 5. Submit a pull request 181 | 182 | ## License 183 | 184 | Apache License 2.0 - see LICENSE file for details 185 | 186 | ## Acknowledgements 187 | 188 | This project was primarily written with GitHub Copilot's assistance. While the human guided the development, Copilot generated much of the actual code. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /src/aiaio/db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | import time 4 | import uuid 5 | from typing import Dict, List, Optional 6 | 7 | from .prompts import SYSTEM_PROMPTS 8 | 9 | 10 | # SQL schema for creating database tables 11 | _DB = """ 12 | CREATE TABLE conversations ( 13 | conversation_id TEXT PRIMARY KEY, 14 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 15 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')), 16 | last_updated REAL DEFAULT (strftime('%s.%f', 'now')), 17 | summary TEXT, 18 | project_id TEXT REFERENCES projects(project_id) 19 | ); 20 | 21 | CREATE TABLE messages ( 22 | message_id TEXT PRIMARY KEY, 23 | conversation_id TEXT, 24 | role TEXT CHECK(role IN ('user', 'assistant', 'system')), 25 | content_type TEXT CHECK(content_type IN ('text', 'image', 'audio', 'video', 'file')), 26 | content TEXT, 27 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 28 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')), 29 | FOREIGN KEY (conversation_id) REFERENCES conversations(conversation_id) 30 | ); 31 | 32 | CREATE TABLE attachments ( 33 | attachment_id TEXT PRIMARY KEY, 34 | message_id TEXT, 35 | file_name TEXT, 36 | file_path TEXT, 37 | file_type TEXT, 38 | file_size INTEGER, 39 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 40 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')), 41 | FOREIGN KEY (message_id) REFERENCES messages(message_id) 42 | ); 43 | 44 | CREATE TABLE providers ( 45 | id INTEGER PRIMARY KEY AUTOINCREMENT, 46 | name TEXT NOT NULL UNIQUE, 47 | is_default BOOLEAN NOT NULL DEFAULT false, 48 | temperature REAL DEFAULT 1.0, 49 | top_p REAL DEFAULT 0.95, 50 | reasoning_effort TEXT DEFAULT 'none', 51 | use_for_summarization BOOLEAN DEFAULT false, 52 | host TEXT NOT NULL, 53 | api_key TEXT DEFAULT '', 54 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 55 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')) 56 | ); 57 | 58 | CREATE TABLE models ( 59 | id INTEGER PRIMARY KEY AUTOINCREMENT, 60 | provider_id INTEGER NOT NULL, 61 | model_name TEXT NOT NULL, 62 | is_default BOOLEAN DEFAULT false, 63 | is_multimodal BOOLEAN DEFAULT false, 64 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 65 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')), 66 | FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE, 67 | UNIQUE(provider_id, model_name) 68 | ); 69 | 70 | CREATE TABLE system_prompts ( 71 | id INTEGER PRIMARY KEY AUTOINCREMENT, 72 | prompt_name TEXT NOT NULL UNIQUE, 73 | prompt_text TEXT NOT NULL, 74 | is_active BOOLEAN DEFAULT false, 75 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 76 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')) 77 | ); 78 | 79 | CREATE TABLE projects ( 80 | project_id TEXT PRIMARY KEY, 81 | name TEXT NOT NULL, 82 | description TEXT, 83 | system_prompt TEXT, 84 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 85 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')) 86 | ); 87 | """ 88 | 89 | 90 | class ChatDatabase: 91 | """A class to manage chat-related database operations. 92 | 93 | This class handles all database interactions for conversations, messages, 94 | attachments, and settings using SQLite. 95 | 96 | Attributes: 97 | db_path (str): Path to the SQLite database file 98 | """ 99 | 100 | def __init__(self, db_path: str = "chatbot.db"): 101 | """Initialize the database connection. 102 | 103 | Args: 104 | db_path (str, optional): Path to the SQLite database file. Defaults to "chatbot.db". 105 | """ 106 | self.db_path = db_path 107 | self._init_db() 108 | 109 | def _init_db(self): 110 | """Initialize the database schema. 111 | 112 | Creates tables if they don't exist or if the database is new. 113 | Also handles schema migrations for existing databases. 114 | """ 115 | with sqlite3.connect(self.db_path) as conn: 116 | # Check if main table exists 117 | cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'") 118 | table_exists = cursor.fetchone() is not None 119 | 120 | if not table_exists: 121 | # Execute schema 122 | conn.executescript(_DB) 123 | 124 | # Insert default providers and models 125 | providers_count = conn.execute("SELECT COUNT(*) FROM providers").fetchone()[0] 126 | if providers_count == 0: 127 | # Custom provider 128 | cursor = conn.execute( 129 | """INSERT INTO providers 130 | (name, is_default, temperature, top_p, reasoning_effort, use_for_summarization, host, api_key) 131 | VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", 132 | ("Custom", True, 1.0, 0.95, "none", True, "http://localhost:8000/v1", ""), 133 | ) 134 | custom_id = cursor.lastrowid 135 | conn.execute( 136 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 137 | (custom_id, "meta-llama/Llama-3.2-1B-Instruct", True, False), 138 | ) 139 | 140 | # Google provider 141 | cursor = conn.execute( 142 | """INSERT INTO providers 143 | (name, is_default, temperature, top_p, reasoning_effort, use_for_summarization, host, api_key) 144 | VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", 145 | ( 146 | "Google", 147 | False, 148 | 1.0, 149 | 0.95, 150 | "low", 151 | False, 152 | "https://generativelanguage.googleapis.com/v1beta", 153 | "", 154 | ), 155 | ) 156 | google_id = cursor.lastrowid 157 | conn.execute( 158 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 159 | (google_id, "gemini-3-pro-preview", True, True), 160 | ) 161 | conn.execute( 162 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 163 | (google_id, "gemini-2.5-pro", False, True), 164 | ) 165 | conn.execute( 166 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 167 | (google_id, "gemini-2.5-flash", False, True), 168 | ) 169 | conn.execute( 170 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 171 | (google_id, "gemini-2.5-flash-lite", False, True), 172 | ) 173 | 174 | # OpenAI provider 175 | cursor = conn.execute( 176 | """INSERT INTO providers 177 | (name, is_default, temperature, top_p, reasoning_effort, use_for_summarization, host, api_key) 178 | VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", 179 | ("OpenAI", False, 1.0, 0.95, "low", False, "https://api.openai.com/v1", ""), 180 | ) 181 | openai_id = cursor.lastrowid 182 | conn.execute( 183 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 184 | (openai_id, "gpt-5.1-2025-11-13", True, True), 185 | ) 186 | conn.execute( 187 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 188 | (openai_id, "gpt-5-mini-2025-08-07", False, True), 189 | ) 190 | conn.execute( 191 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 192 | (openai_id, "gpt-5-nano-2025-08-07", False, True), 193 | ) 194 | conn.execute( 195 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 196 | (openai_id, "gpt-oss-120b", False, False), 197 | ) 198 | 199 | # Insert system prompts 200 | conn.execute( 201 | """INSERT INTO system_prompts (prompt_name, prompt_text, is_active) 202 | VALUES (?, ?, ?)""", 203 | ("summary", SYSTEM_PROMPTS["summary"].strip(), False), 204 | ) 205 | conn.execute( 206 | """INSERT INTO system_prompts (prompt_name, prompt_text, is_active) 207 | VALUES (?, ?, ?)""", 208 | ("default", SYSTEM_PROMPTS["default"].strip(), True), 209 | ) 210 | else: 211 | # Check if summary column exists 212 | columns = conn.execute("PRAGMA table_info(conversations)").fetchall() 213 | column_names = [col[1] for col in columns] 214 | if "summary" not in column_names: 215 | conn.execute("ALTER TABLE conversations ADD COLUMN summary TEXT") 216 | 217 | # Check if project_id column exists 218 | if "project_id" not in column_names: 219 | conn.execute( 220 | "ALTER TABLE conversations ADD COLUMN project_id TEXT REFERENCES projects(project_id)" 221 | ) 222 | 223 | # Ensure default project exists 224 | # Check if projects table exists first (it might not if we are migrating from very old version, but schema above creates it) 225 | # If we just created schema, it exists. If we are migrating, we might need to create it? 226 | # The original code assumed if db exists, we just migrate columns. 227 | # But projects table was added later. 228 | # Let's check if projects table exists 229 | projects_table = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='projects'").fetchone() 230 | if not projects_table: 231 | conn.execute(""" 232 | CREATE TABLE projects ( 233 | project_id TEXT PRIMARY KEY, 234 | name TEXT NOT NULL, 235 | description TEXT, 236 | system_prompt TEXT, 237 | created_at REAL DEFAULT (strftime('%s.%f', 'now')), 238 | updated_at REAL DEFAULT (strftime('%s.%f', 'now')) 239 | ); 240 | """) 241 | 242 | projects_count = conn.execute("SELECT COUNT(*) FROM projects").fetchone()[0] 243 | if projects_count == 0: 244 | default_project_id = str(uuid.uuid4()) 245 | conn.execute( 246 | "INSERT INTO projects (project_id, name, description, system_prompt) VALUES (?, ?, ?, ?)", 247 | ( 248 | default_project_id, 249 | "General", 250 | "Default project for general conversations", 251 | SYSTEM_PROMPTS["default"].strip(), 252 | ), 253 | ) 254 | 255 | # Migrate existing conversations to default project 256 | # Check if project_id column exists in conversations (we added it above if missing) 257 | conn.execute("UPDATE conversations SET project_id = ? WHERE project_id IS NULL", (default_project_id,)) 258 | 259 | def create_conversation(self, project_id: Optional[str] = None) -> str: 260 | """Create a new conversation. 261 | 262 | Args: 263 | project_id (str, optional): ID of the project the conversation belongs to. 264 | 265 | Returns: 266 | str: Unique identifier for the created conversation. 267 | """ 268 | conversation_id = str(uuid.uuid4()) 269 | with sqlite3.connect(self.db_path) as conn: 270 | if project_id: 271 | conn.execute( 272 | "INSERT INTO conversations (conversation_id, project_id) VALUES (?, ?)", 273 | (conversation_id, project_id), 274 | ) 275 | else: 276 | # Fallback to default project if none specified 277 | # Find a default project or the first one 278 | project = conn.execute("SELECT project_id FROM projects ORDER BY created_at ASC LIMIT 1").fetchone() 279 | if project: 280 | conn.execute( 281 | "INSERT INTO conversations (conversation_id, project_id) VALUES (?, ?)", 282 | (conversation_id, project[0]), 283 | ) 284 | else: 285 | # Should not happen due to init_db, but safe fallback 286 | conn.execute("INSERT INTO conversations (conversation_id) VALUES (?)", (conversation_id,)) 287 | return conversation_id 288 | 289 | def add_message( 290 | self, 291 | conversation_id: str, 292 | role: str, 293 | content: str, 294 | content_type: str = "text", 295 | attachments: Optional[List[Dict]] = None, 296 | ) -> str: 297 | """Add a new message to a conversation. 298 | 299 | Args: 300 | conversation_id (str): ID of the conversation 301 | role (str): Role of the message sender ('user', 'assistant', or 'system') 302 | content (str): Content of the message 303 | content_type (str, optional): Type of content. Defaults to "text". 304 | attachments (Optional[List[Dict]], optional): List of attachment metadata. Defaults to None. 305 | 306 | Returns: 307 | str: Unique identifier for the created message 308 | """ 309 | message_id = str(uuid.uuid4()) 310 | current_time = time.time() 311 | 312 | with sqlite3.connect(self.db_path) as conn: 313 | conn.execute( 314 | """INSERT INTO messages 315 | (message_id, conversation_id, role, content_type, content, created_at) 316 | VALUES (?, ?, ?, ?, ?, ?)""", 317 | (message_id, conversation_id, role, content_type, content, current_time), 318 | ) 319 | 320 | conn.execute( 321 | """UPDATE conversations 322 | SET last_updated = ? 323 | WHERE conversation_id = ?""", 324 | (current_time, conversation_id), 325 | ) 326 | 327 | if attachments: 328 | for att in attachments: 329 | conn.execute( 330 | """INSERT INTO attachments 331 | (attachment_id, message_id, file_name, file_path, file_type, file_size, created_at) 332 | VALUES (?, ?, ?, ?, ?, ?, ?)""", 333 | ( 334 | str(uuid.uuid4()), 335 | message_id, 336 | att["name"], 337 | att["path"], 338 | att["type"], 339 | att["size"], 340 | current_time, 341 | ), 342 | ) 343 | 344 | return message_id 345 | 346 | def get_conversation_history(self, conversation_id: str) -> List[Dict]: 347 | """Retrieve the full history of a conversation including attachments. 348 | 349 | Args: 350 | conversation_id (str): ID of the conversation 351 | 352 | Returns: 353 | List[Dict]: List of messages with their attachments in chronological order 354 | """ 355 | with sqlite3.connect(self.db_path) as conn: 356 | conn.row_factory = sqlite3.Row 357 | messages = conn.execute( 358 | """SELECT m.*, a.attachment_id, a.file_name, a.file_path, a.file_type, a.file_size 359 | FROM messages m 360 | LEFT JOIN attachments a ON m.message_id = a.message_id 361 | WHERE m.conversation_id = ? 362 | ORDER BY m.created_at ASC""", 363 | (conversation_id,), 364 | ).fetchall() 365 | 366 | # Group attachments by message_id 367 | message_dict = {} 368 | for row in messages: 369 | message_id = row["message_id"] 370 | if message_id not in message_dict: 371 | message_dict[message_id] = { 372 | key: row[key] 373 | for key in ["message_id", "conversation_id", "role", "content_type", "content", "created_at"] 374 | } 375 | message_dict[message_id]["attachments"] = [] 376 | 377 | if row["attachment_id"]: 378 | message_dict[message_id]["attachments"].append( 379 | { 380 | "attachment_id": row["attachment_id"], 381 | "file_name": row["file_name"], 382 | "file_path": row["file_path"], 383 | "file_type": row["file_type"], 384 | "file_size": row["file_size"], 385 | } 386 | ) 387 | 388 | return list(message_dict.values()) 389 | 390 | def get_attachment(self, attachment_id: str) -> Optional[Dict]: 391 | """Retrieve attachment details by ID.""" 392 | with sqlite3.connect(self.db_path) as conn: 393 | conn.row_factory = sqlite3.Row 394 | row = conn.execute("SELECT * FROM attachments WHERE attachment_id = ?", (attachment_id,)).fetchone() 395 | if row: 396 | return dict(row) 397 | return None 398 | 399 | def get_conversation_history_upto_message_id(self, conversation_id: str, message_id: str) -> List[Dict]: 400 | """Retrieve the full history of a conversation including attachments up to but not including a message_id. 401 | 402 | Args: 403 | conversation_id (str): ID of the conversation 404 | message_id (str): ID of the message 405 | 406 | Returns: 407 | List[Dict]: List of messages with their attachments in chronological order 408 | """ 409 | with sqlite3.connect(self.db_path) as conn: 410 | conn.row_factory = sqlite3.Row 411 | messages = conn.execute( 412 | """SELECT m.*, a.attachment_id, a.file_name, a.file_path, a.file_type, a.file_size 413 | FROM messages m 414 | LEFT JOIN attachments a ON m.message_id = a.message_id 415 | WHERE m.conversation_id = ? AND m.created_at < ( 416 | SELECT created_at FROM messages WHERE message_id = ? 417 | ) 418 | ORDER BY m.created_at ASC""", 419 | (conversation_id, message_id), 420 | ).fetchall() 421 | 422 | # Group attachments by message_id 423 | message_dict = {} 424 | for row in messages: 425 | message_id = row["message_id"] 426 | if message_id not in message_dict: 427 | message_dict[message_id] = { 428 | key: row[key] 429 | for key in ["message_id", "conversation_id", "role", "content_type", "content", "created_at"] 430 | } 431 | message_dict[message_id]["attachments"] = [] 432 | 433 | if row["attachment_id"]: 434 | message_dict[message_id]["attachments"].append( 435 | { 436 | "attachment_id": row["attachment_id"], 437 | "file_name": row["file_name"], 438 | "file_path": row["file_path"], 439 | "file_type": row["file_type"], 440 | "file_size": row["file_size"], 441 | } 442 | ) 443 | 444 | return list(message_dict.values()) 445 | 446 | def delete_conversation(self, conversation_id: str): 447 | """Delete a conversation and all its associated messages and attachments. 448 | 449 | Args: 450 | conversation_id (str): ID of the conversation to delete 451 | """ 452 | with sqlite3.connect(self.db_path) as conn: 453 | conn.execute( 454 | """DELETE FROM attachments 455 | WHERE message_id IN ( 456 | SELECT message_id FROM messages WHERE conversation_id = ? 457 | )""", 458 | (conversation_id,), 459 | ) 460 | conn.execute("DELETE FROM messages WHERE conversation_id = ?", (conversation_id,)) 461 | conn.execute("DELETE FROM conversations WHERE conversation_id = ?", (conversation_id,)) 462 | 463 | def get_all_conversations(self, project_id: Optional[str] = None) -> List[Dict]: 464 | """Retrieve all conversations with their message counts and last activity. 465 | 466 | Args: 467 | project_id (str, optional): Filter by project ID. 468 | 469 | Returns: 470 | List[Dict]: List of conversations with their metadata 471 | """ 472 | with sqlite3.connect(self.db_path) as conn: 473 | conn.row_factory = sqlite3.Row 474 | 475 | query = """SELECT c.*, 476 | COUNT(m.message_id) as message_count, 477 | MAX(m.created_at) as last_message_at 478 | FROM conversations c 479 | LEFT JOIN messages m ON c.conversation_id = m.conversation_id""" 480 | 481 | params = [] 482 | if project_id: 483 | query += " WHERE c.project_id = ?" 484 | params.append(project_id) 485 | 486 | query += """ GROUP BY c.conversation_id 487 | ORDER BY c.created_at ASC""" 488 | 489 | conversations = conn.execute(query, tuple(params)).fetchall() 490 | 491 | return [dict(conv) for conv in conversations] 492 | 493 | def get_project_for_conversation(self, conversation_id: str) -> Optional[Dict]: 494 | """Get the project associated with a conversation. 495 | 496 | Args: 497 | conversation_id (str): ID of the conversation 498 | 499 | Returns: 500 | Optional[Dict]: Project data if found, None otherwise 501 | """ 502 | with sqlite3.connect(self.db_path) as conn: 503 | conn.row_factory = sqlite3.Row 504 | project = conn.execute( 505 | """SELECT p.* FROM projects p 506 | JOIN conversations c ON p.project_id = c.project_id 507 | WHERE c.conversation_id = ?""", 508 | (conversation_id,), 509 | ).fetchone() 510 | return dict(project) if project else None 511 | 512 | # Project CRUD methods 513 | def create_project(self, name: str, description: str = "", system_prompt: str = "") -> str: 514 | """Create a new project. 515 | 516 | Args: 517 | name (str): Project name 518 | description (str): Project description 519 | system_prompt (str): Default system prompt for the project 520 | 521 | Returns: 522 | str: Project ID 523 | """ 524 | project_id = str(uuid.uuid4()) 525 | with sqlite3.connect(self.db_path) as conn: 526 | conn.execute( 527 | """INSERT INTO projects (project_id, name, description, system_prompt) 528 | VALUES (?, ?, ?, ?)""", 529 | (project_id, name, description, system_prompt), 530 | ) 531 | return project_id 532 | 533 | def get_projects(self) -> List[Dict]: 534 | """Get all projects.""" 535 | with sqlite3.connect(self.db_path) as conn: 536 | conn.row_factory = sqlite3.Row 537 | projects = conn.execute("SELECT * FROM projects ORDER BY created_at ASC").fetchall() 538 | return [dict(p) for p in projects] 539 | 540 | def get_project(self, project_id: str) -> Optional[Dict]: 541 | """Get a project by ID.""" 542 | with sqlite3.connect(self.db_path) as conn: 543 | conn.row_factory = sqlite3.Row 544 | project = conn.execute("SELECT * FROM projects WHERE project_id = ?", (project_id,)).fetchone() 545 | return dict(project) if project else None 546 | 547 | def update_project(self, project_id: str, name: str, description: str, system_prompt: str) -> bool: 548 | """Update a project.""" 549 | with sqlite3.connect(self.db_path) as conn: 550 | cursor = conn.execute( 551 | """UPDATE projects 552 | SET name = ?, description = ?, system_prompt = ?, updated_at = strftime('%s.%f', 'now') 553 | WHERE project_id = ?""", 554 | (name, description, system_prompt, project_id), 555 | ) 556 | return cursor.rowcount > 0 557 | 558 | def delete_project(self, project_id: str) -> bool: 559 | """Delete a project and its conversations.""" 560 | with sqlite3.connect(self.db_path) as conn: 561 | # Delete messages for all conversations in project 562 | conn.execute( 563 | """DELETE FROM messages WHERE conversation_id IN 564 | (SELECT conversation_id FROM conversations WHERE project_id = ?)""", 565 | (project_id,), 566 | ) 567 | # Delete conversations 568 | conn.execute("DELETE FROM conversations WHERE project_id = ?", (project_id,)) 569 | # Delete project 570 | cursor = conn.execute("DELETE FROM projects WHERE project_id = ?", (project_id,)) 571 | return cursor.rowcount > 0 572 | 573 | # Provider CRUD methods 574 | def get_default_provider(self) -> Optional[Dict]: 575 | """Get the default provider with its settings.""" 576 | with sqlite3.connect(self.db_path) as conn: 577 | conn.row_factory = sqlite3.Row 578 | provider = conn.execute("SELECT * FROM providers WHERE is_default = true").fetchone() 579 | return dict(provider) if provider else None 580 | 581 | def get_all_providers(self) -> List[Dict]: 582 | """Get all providers.""" 583 | with sqlite3.connect(self.db_path) as conn: 584 | conn.row_factory = sqlite3.Row 585 | providers = conn.execute("SELECT * FROM providers ORDER BY name").fetchall() 586 | return [dict(p) for p in providers] 587 | 588 | def get_provider_by_id(self, provider_id: int) -> Optional[Dict]: 589 | """Get provider by ID.""" 590 | with sqlite3.connect(self.db_path) as conn: 591 | conn.row_factory = sqlite3.Row 592 | provider = conn.execute("SELECT * FROM providers WHERE id = ?", (provider_id,)).fetchone() 593 | return dict(provider) if provider else None 594 | 595 | def add_provider(self, provider: Dict) -> int: 596 | """Add a new provider.""" 597 | with sqlite3.connect(self.db_path) as conn: 598 | cursor = conn.execute( 599 | """INSERT INTO providers 600 | (name, temperature, top_p, reasoning_effort, use_for_summarization, host, api_key) 601 | VALUES (?, ?, ?, ?, ?, ?, ?)""", 602 | ( 603 | provider.get("name"), 604 | provider.get("temperature", 1.0), 605 | provider.get("top_p", 0.95), 606 | provider.get("reasoning_effort", "none"), 607 | provider.get("use_for_summarization", False), 608 | provider.get("host"), 609 | provider.get("api_key", ""), 610 | ), 611 | ) 612 | return cursor.lastrowid 613 | 614 | def update_provider(self, provider_id: int, provider: Dict) -> bool: 615 | """Update provider settings.""" 616 | with sqlite3.connect(self.db_path) as conn: 617 | cursor = conn.execute( 618 | """UPDATE providers 619 | SET name = ?, temperature = ?, top_p = ?, reasoning_effort = ?, use_for_summarization = ?, 620 | host = ?, api_key = ?, updated_at = strftime('%s.%f', 'now') 621 | WHERE id = ?""", 622 | ( 623 | provider.get("name"), 624 | provider.get("temperature", 1.0), 625 | provider.get("top_p", 0.95), 626 | provider.get("reasoning_effort", "none"), 627 | provider.get("use_for_summarization", False), 628 | provider.get("host"), 629 | provider.get("api_key", ""), 630 | provider_id, 631 | ), 632 | ) 633 | return cursor.rowcount > 0 634 | 635 | def delete_provider(self, provider_id: int) -> bool: 636 | """Delete a provider (cascade deletes models).""" 637 | with sqlite3.connect(self.db_path) as conn: 638 | cursor = conn.execute("DELETE FROM providers WHERE id = ?", (provider_id,)) 639 | return cursor.rowcount > 0 640 | 641 | def set_default_provider(self, provider_id: int) -> bool: 642 | """Set a provider as default.""" 643 | with sqlite3.connect(self.db_path) as conn: 644 | conn.execute("UPDATE providers SET is_default = false WHERE is_default = true") 645 | cursor = conn.execute("UPDATE providers SET is_default = true WHERE id = ?", (provider_id,)) 646 | return cursor.rowcount > 0 647 | 648 | # Model CRUD methods 649 | def get_models_by_provider(self, provider_id: int) -> List[Dict]: 650 | """Get all models for a provider.""" 651 | with sqlite3.connect(self.db_path) as conn: 652 | conn.row_factory = sqlite3.Row 653 | models = conn.execute( 654 | "SELECT * FROM models WHERE provider_id = ? ORDER BY is_default DESC, model_name", (provider_id,) 655 | ).fetchall() 656 | return [dict(m) for m in models] 657 | 658 | def get_default_model(self, provider_id: int) -> Optional[Dict]: 659 | """Get the default model for a provider.""" 660 | with sqlite3.connect(self.db_path) as conn: 661 | conn.row_factory = sqlite3.Row 662 | model = conn.execute( 663 | "SELECT * FROM models WHERE provider_id = ? AND is_default = true", (provider_id,) 664 | ).fetchone() 665 | return dict(model) if model else None 666 | 667 | def add_model( 668 | self, provider_id: int, model_name: str, is_default: bool = False, is_multimodal: bool = False 669 | ) -> int: 670 | """Add a model to a provider.""" 671 | with sqlite3.connect(self.db_path) as conn: 672 | # Check if this is the first model for the provider 673 | existing_models = conn.execute( 674 | "SELECT COUNT(*) FROM models WHERE provider_id = ?", (provider_id,) 675 | ).fetchone()[0] 676 | 677 | # If this is the first model, make it default automatically 678 | if existing_models == 0: 679 | is_default = True 680 | 681 | # If setting as default, unset other defaults for this provider 682 | if is_default: 683 | conn.execute( 684 | "UPDATE models SET is_default = false WHERE provider_id = ? AND is_default = true", (provider_id,) 685 | ) 686 | 687 | cursor = conn.execute( 688 | "INSERT INTO models (provider_id, model_name, is_default, is_multimodal) VALUES (?, ?, ?, ?)", 689 | (provider_id, model_name, is_default, is_multimodal), 690 | ) 691 | return cursor.lastrowid 692 | 693 | def delete_model(self, model_id: int) -> bool: 694 | """Delete a model.""" 695 | with sqlite3.connect(self.db_path) as conn: 696 | cursor = conn.execute("DELETE FROM models WHERE id = ?", (model_id,)) 697 | return cursor.rowcount > 0 698 | 699 | def set_default_model(self, model_id: int) -> bool: 700 | """Set a model as default for its provider.""" 701 | with sqlite3.connect(self.db_path) as conn: 702 | # Get provider_id for this model 703 | provider_id = conn.execute("SELECT provider_id FROM models WHERE id = ?", (model_id,)).fetchone() 704 | if not provider_id: 705 | return False 706 | 707 | # Unset other defaults for this provider 708 | conn.execute( 709 | "UPDATE models SET is_default = false WHERE provider_id = ? AND is_default = true", (provider_id[0],) 710 | ) 711 | 712 | # Set this model as default 713 | cursor = conn.execute("UPDATE models SET is_default = true WHERE id = ?", (model_id,)) 714 | return cursor.rowcount > 0 715 | 716 | def update_conversation_summary(self, conversation_id: str, summary: str): 717 | """Update the summary of a conversation. 718 | 719 | Args: 720 | conversation_id (str): ID of the conversation 721 | summary (str): New summary text for the conversation 722 | """ 723 | with sqlite3.connect(self.db_path) as conn: 724 | conn.execute( 725 | """UPDATE conversations 726 | SET summary = ?, updated_at = strftime('%s.%f', 'now') 727 | WHERE conversation_id = ?""", 728 | (summary, conversation_id), 729 | ) 730 | 731 | def add_system_prompt(self, name: str, text: str) -> int: 732 | """Add a new system prompt. 733 | 734 | Args: 735 | name (str): Name of the prompt 736 | text (str): Prompt text 737 | 738 | Returns: 739 | int: ID of the newly created prompt 740 | """ 741 | with sqlite3.connect(self.db_path) as conn: 742 | cursor = conn.execute("INSERT INTO system_prompts (prompt_name, prompt_text) VALUES (?, ?)", (name, text)) 743 | return cursor.lastrowid 744 | 745 | def edit_system_prompt(self, prompt_id: int, name: str, text: str) -> bool: 746 | """Edit an existing system prompt. 747 | 748 | Args: 749 | prompt_id (int): ID of the prompt to edit 750 | name (str): New name for the prompt 751 | text (str): New prompt text 752 | 753 | Returns: 754 | bool: True if successful, False otherwise 755 | """ 756 | with sqlite3.connect(self.db_path) as conn: 757 | cursor = conn.execute( 758 | """UPDATE system_prompts 759 | SET prompt_name = ?, 760 | prompt_text = ?, 761 | updated_at = strftime('%s.%f', 'now') 762 | WHERE id = ?""", 763 | (name, text, prompt_id), 764 | ) 765 | return cursor.rowcount > 0 766 | 767 | def set_active_prompt(self, prompt_id: int) -> bool: 768 | """Set a prompt as active and deactivate all others. 769 | 770 | Args: 771 | prompt_id (int): ID of the prompt to activate 772 | 773 | Returns: 774 | bool: True if successful, False otherwise 775 | """ 776 | with sqlite3.connect(self.db_path) as conn: 777 | conn.execute("UPDATE system_prompts SET is_active = false") 778 | cursor = conn.execute("UPDATE system_prompts SET is_active = true WHERE id = ?", (prompt_id,)) 779 | return cursor.rowcount > 0 780 | 781 | def get_active_prompt(self) -> Optional[Dict]: 782 | """Get the currently active system prompt. 783 | 784 | Returns: 785 | Optional[Dict]: Active prompt data if found, None otherwise 786 | """ 787 | with sqlite3.connect(self.db_path) as conn: 788 | conn.row_factory = sqlite3.Row 789 | prompt = conn.execute("SELECT * FROM system_prompts WHERE is_active = true").fetchone() 790 | return dict(prompt) if prompt else None 791 | 792 | def get_all_prompts(self) -> List[Dict]: 793 | """Get all system prompts. 794 | 795 | Returns: 796 | List[Dict]: List of all prompts 797 | """ 798 | with sqlite3.connect(self.db_path) as conn: 799 | conn.row_factory = sqlite3.Row 800 | prompts = conn.execute("SELECT * FROM system_prompts").fetchall() 801 | return [dict(prompt) for prompt in prompts] 802 | 803 | def get_prompt_by_id(self, prompt_id: int) -> Optional[Dict]: 804 | """Get a specific system prompt by ID. 805 | 806 | Args: 807 | prompt_id (int): ID of the prompt to retrieve 808 | 809 | Returns: 810 | Optional[Dict]: Prompt data if found, None otherwise 811 | """ 812 | with sqlite3.connect(self.db_path) as conn: 813 | conn.row_factory = sqlite3.Row 814 | prompt = conn.execute("SELECT * FROM system_prompts WHERE id = ?", (prompt_id,)).fetchone() 815 | return dict(prompt) if prompt else None 816 | 817 | def delete_system_prompt(self, prompt_id: int) -> bool: 818 | """Delete a system prompt. 819 | 820 | Args: 821 | prompt_id (int): ID of the prompt to delete 822 | 823 | Returns: 824 | bool: True if successful, False otherwise 825 | """ 826 | with sqlite3.connect(self.db_path) as conn: 827 | cursor = conn.execute("DELETE FROM system_prompts WHERE id = ? AND prompt_name != 'default'", (prompt_id,)) 828 | return cursor.rowcount > 0 829 | 830 | def edit_message(self, message_id: str, new_content: str) -> bool: 831 | """Edit an existing message's content. 832 | 833 | Args: 834 | message_id (str): ID of the message to edit 835 | new_content (str): New message content 836 | 837 | Returns: 838 | bool: True if successful, False if message not found 839 | 840 | Raises: 841 | ValueError: If trying to edit a system message 842 | """ 843 | with sqlite3.connect(self.db_path) as conn: 844 | # Check if message exists and isn't a system message 845 | message = conn.execute("SELECT role FROM messages WHERE message_id = ?", (message_id,)).fetchone() 846 | 847 | if not message: 848 | return False 849 | 850 | if message[0] == "system": 851 | raise ValueError("System messages cannot be edited") 852 | 853 | cursor = conn.execute( 854 | """UPDATE messages 855 | SET content = ?, updated_at = strftime('%s.%f', 'now') 856 | WHERE message_id = ?""", 857 | (new_content, message_id), 858 | ) 859 | return cursor.rowcount > 0 860 | -------------------------------------------------------------------------------- /src/aiaio/app/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | aiaio 8 | 9 | 51 | 52 | 53 | 54 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 228 | 229 | 230 | 231 | 232 |
233 | 234 |
235 | 236 | 314 | 315 | 316 |
317 | 318 |
320 | 323 | aiaio 324 |
325 |
326 | 327 | 328 |
330 |
332 | 333 |
334 | 335 | Loading... 336 |
337 |
338 |
339 | 340 | ... 341 |
342 |
343 |
344 | 345 | ... 346 |
347 |
348 |
349 | 350 | 351 |
352 |
353 | 354 |
355 |
356 | 357 | 358 |
360 |
361 | 362 | 363 | 364 | 365 |
366 |
368 |
369 |
371 | 372 | 377 | 378 | 379 | 382 | 383 | 384 | 388 | 389 | 390 | 394 |
395 |
396 |
397 |

AI can make mistakes. Please verify 398 | important information.

399 |
400 |
401 |
402 | 403 | 404 | 409 |
410 | 411 | 412 | 578 | 579 | 580 | 583 |
584 | 585 | 586 | 588 | 589 | 590 | 605 | 606 | 607 | 629 | 630 | 631 | 632 | 656 | 657 | 658 | 707 | 708 | 709 | 710 | 711 | 712 | -------------------------------------------------------------------------------- /src/aiaio/app/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import os 4 | import re 5 | import sqlite3 6 | import tempfile 7 | import time 8 | from contextvars import ContextVar 9 | from dataclasses import dataclass 10 | from pathlib import Path 11 | from typing import Dict, List, Optional 12 | 13 | from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect 14 | from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse 15 | from fastapi.staticfiles import StaticFiles 16 | from fastapi.templating import Jinja2Templates 17 | from openai import OpenAI 18 | from pydantic import BaseModel 19 | 20 | from aiaio import __version__, logger 21 | from aiaio.db import ChatDatabase 22 | from aiaio.prompts import SUMMARY_PROMPT 23 | 24 | 25 | logger.info("aiaio...") 26 | 27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | app = FastAPI() 29 | static_path = os.path.join(BASE_DIR, "static") 30 | app.mount("/static", StaticFiles(directory=static_path), name="static") 31 | templates_path = os.path.join(BASE_DIR, "templates") 32 | templates = Jinja2Templates(directory=templates_path) 33 | 34 | # Create temp directory for uploads 35 | TEMP_DIR = Path(tempfile.gettempdir()) / "aiaio_uploads" 36 | TEMP_DIR.mkdir(exist_ok=True) 37 | 38 | # Initialize database 39 | db = ChatDatabase() 40 | 41 | 42 | class ConnectionManager: 43 | def __init__(self): 44 | self.active_connections: Dict[str, WebSocket] = {} # Use dict instead of list 45 | self.active_generations: Dict[str, bool] = {} # Track active generations 46 | 47 | async def connect(self, websocket: WebSocket, client_id: str): 48 | await websocket.accept() 49 | self.active_connections[client_id] = websocket 50 | self.active_generations[client_id] = False 51 | 52 | def disconnect(self, client_id: str): 53 | if client_id in self.active_connections: 54 | del self.active_connections[client_id] 55 | if client_id in self.active_generations: 56 | del self.active_generations[client_id] 57 | 58 | def set_generating(self, client_id: str, is_generating: bool): 59 | self.active_generations[client_id] = is_generating 60 | 61 | def should_stop(self, client_id: str) -> bool: 62 | return not self.active_generations.get(client_id, False) 63 | 64 | async def broadcast(self, message: dict): 65 | for connection in self.active_connections.values(): 66 | try: 67 | await connection.send_json(message) 68 | except Exception: 69 | # If sending fails, we'll handle it in the main websocket route 70 | pass 71 | 72 | 73 | manager = ConnectionManager() 74 | 75 | 76 | class FileAttachment(BaseModel): 77 | """ 78 | Pydantic model for handling file attachments in messages. 79 | 80 | Attributes: 81 | name (str): Name of the file 82 | type (str): MIME type of the file 83 | data (str): Base64 encoded file data 84 | """ 85 | 86 | name: str 87 | type: str 88 | data: str 89 | 90 | 91 | class MessageContent(BaseModel): 92 | """ 93 | Pydantic model for message content including optional file attachments. 94 | 95 | Attributes: 96 | text (str): The text content of the message 97 | files (List[FileAttachment]): Optional list of file attachments 98 | """ 99 | 100 | text: str 101 | files: Optional[List[FileAttachment]] = None 102 | 103 | 104 | class ChatInput(BaseModel): 105 | """ 106 | Pydantic model for chat input data. 107 | 108 | Attributes: 109 | message (str): The user's message content 110 | system_prompt (str): Instructions for the AI model 111 | conversation_id (str, optional): ID of the conversation 112 | """ 113 | 114 | message: str 115 | system_prompt: str 116 | conversation_id: Optional[str] = None 117 | 118 | 119 | class MessageInput(BaseModel): 120 | """ 121 | Pydantic model for message input data. 122 | 123 | Attributes: 124 | role (str): The role of the message sender (e.g., 'user', 'assistant', 'system') 125 | content (str): The message content 126 | content_type (str): Type of content, defaults to "text" 127 | attachments (List[Dict], optional): List of file attachments 128 | """ 129 | 130 | role: str 131 | content: str 132 | content_type: str = "text" 133 | attachments: Optional[List[Dict]] = None 134 | 135 | 136 | class ProviderInput(BaseModel): 137 | """ 138 | Pydantic model for provider configuration. 139 | 140 | Attributes: 141 | name (str): Name of the provider 142 | temperature (float): Controls randomness in responses 143 | top_p (float): Controls diversity via nucleus sampling 144 | reasoning_effort (str): Reasoning effort level (none, low, medium, high) 145 | use_for_summarization (bool): Whether to use this provider for summarization 146 | host (str): API endpoint URL 147 | api_key (str): Authentication key for the API 148 | """ 149 | 150 | name: str 151 | temperature: Optional[float] = 1.0 152 | top_p: Optional[float] = 0.95 153 | reasoning_effort: Optional[str] = "none" 154 | use_for_summarization: Optional[bool] = False 155 | host: str 156 | api_key: Optional[str] = "" 157 | 158 | 159 | class ModelInput(BaseModel): 160 | """ 161 | Pydantic model for model input. 162 | 163 | Attributes: 164 | model_name (str): Name of the model 165 | """ 166 | 167 | model_name: str 168 | is_multimodal: Optional[bool] = False 169 | 170 | 171 | class PromptInput(BaseModel): 172 | """ 173 | Pydantic model for system prompt input. 174 | 175 | Attributes: 176 | name (str): Name of the prompt 177 | text (str): The prompt text content 178 | """ 179 | 180 | name: str 181 | text: str 182 | 183 | 184 | class ProjectInput(BaseModel): 185 | """ 186 | Pydantic model for project input. 187 | 188 | Attributes: 189 | name (str): Project name 190 | description (str): Project description 191 | system_prompt (str): Default system prompt for the project 192 | """ 193 | 194 | name: str 195 | description: Optional[str] = "" 196 | system_prompt: Optional[str] = "" 197 | 198 | 199 | class MessageEdit(BaseModel): 200 | """ 201 | Pydantic model for message edit requests. 202 | 203 | Attributes: 204 | content (str): New message content 205 | """ 206 | 207 | content: str 208 | 209 | 210 | @dataclass 211 | class RequestContext: 212 | is_disconnected: bool = False 213 | 214 | 215 | # Create a context variable to track request state 216 | request_context: ContextVar[RequestContext] = ContextVar("request_context", default=RequestContext()) 217 | 218 | 219 | async def text_streamer(messages: List[Dict[str, str]], client_id: str): 220 | """Stream text responses from the AI model.""" 221 | # Get default provider and model 222 | provider = db.get_default_provider() 223 | if not provider: 224 | raise HTTPException(status_code=404, detail="No default provider found") 225 | 226 | default_model = db.get_default_model(provider["id"]) 227 | if not default_model: 228 | raise HTTPException(status_code=404, detail="No default model found for provider") 229 | 230 | client = OpenAI( 231 | api_key=provider["api_key"] if provider["api_key"] != "" else "empty", 232 | base_url=provider["host"], 233 | ) 234 | 235 | formatted_messages = [] 236 | 237 | for msg in messages: 238 | formatted_msg = {"role": msg["role"]} 239 | attachments = msg.get("attachments", []) 240 | 241 | if attachments: 242 | # Handle messages with attachments 243 | content = [] 244 | if msg["content"]: 245 | content.append({"type": "text", "text": msg["content"]}) 246 | 247 | for att in attachments: 248 | file_type = att.get("file_type", "").split("/")[0] 249 | file_path = att["file_path"] 250 | mime_type = att.get("file_type", "application/octet-stream") 251 | 252 | # For all file types, encode as base64 and let the API handle it 253 | with open(file_path, "rb") as f: 254 | file_data = base64.b64encode(f.read()).decode() 255 | 256 | # Handle different file types 257 | if file_type == "image": 258 | content.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{file_data}"}}) 259 | elif file_type == "video": 260 | content.append({"type": "video_url", "video_url": {"url": f"data:{mime_type};base64,{file_data}"}}) 261 | elif file_type == "audio": 262 | content.append( 263 | {"type": "input_audio", "input_audio": {"url": f"data:{mime_type};base64,{file_data}"}} 264 | ) 265 | else: 266 | # For documents (PDF, etc), send as image_url with proper MIME type 267 | # Many APIs support this for document understanding 268 | content.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{file_data}"}}) 269 | 270 | formatted_msg["content"] = content 271 | else: 272 | # Handle text-only messages 273 | formatted_msg["content"] = msg["content"] 274 | 275 | formatted_messages.append(formatted_msg) 276 | 277 | stream = None 278 | try: 279 | manager.set_generating(client_id, True) 280 | 281 | # Prepare API call parameters 282 | api_params = { 283 | "messages": formatted_messages, 284 | "model": default_model["model_name"], 285 | "temperature": provider["temperature"], 286 | "top_p": provider["top_p"], 287 | "stream": True, 288 | } 289 | 290 | # Add reasoning_effort only if not 'none' 291 | if provider.get("reasoning_effort") and provider["reasoning_effort"] != "none": 292 | api_params["reasoning_effort"] = provider["reasoning_effort"] 293 | 294 | stream = client.chat.completions.create(**api_params) 295 | 296 | for message in stream: 297 | if manager.should_stop(client_id): 298 | logger.info(f"Stopping generation for client {client_id}") 299 | break 300 | 301 | if message.choices and len(message.choices) > 0: 302 | # Handle regular content 303 | if message.choices[0].delta.content is not None: 304 | yield message.choices[0].delta.content 305 | 306 | except Exception as e: 307 | logger.error(f"Error in text_streamer: {e}") 308 | # Yield error message with special marker so frontend can display it 309 | error_message = f"__ERROR__:{str(e)}" 310 | yield error_message 311 | raise 312 | 313 | finally: 314 | manager.set_generating(client_id, False) 315 | if stream and hasattr(stream, "response"): 316 | stream.response.close() 317 | 318 | 319 | @app.get("/", response_class=HTMLResponse) 320 | async def load_index(request: Request): 321 | """ 322 | Serve the main application page. 323 | 324 | Args: 325 | request (Request): FastAPI request object 326 | 327 | Returns: 328 | TemplateResponse: Rendered HTML template 329 | """ 330 | return templates.TemplateResponse( 331 | "index.html", 332 | { 333 | "request": request, 334 | "time": time.strftime("%Y-%m-%d %H:%M:%S"), 335 | }, 336 | ) 337 | 338 | 339 | @app.get("/version") 340 | async def version(): 341 | """ 342 | Get the application version. 343 | 344 | Returns: 345 | dict: Version information 346 | """ 347 | return {"version": __version__} 348 | 349 | 350 | @app.get("/conversations") 351 | async def get_conversations(project_id: Optional[str] = None): 352 | """ 353 | Retrieve all conversations. 354 | 355 | Args: 356 | project_id (str, optional): Filter by project ID 357 | 358 | Returns: 359 | dict: List of all conversations 360 | 361 | Raises: 362 | HTTPException: If database operation fails 363 | """ 364 | try: 365 | conversations = db.get_all_conversations(project_id) 366 | return {"conversations": conversations} 367 | except Exception as e: 368 | raise HTTPException(status_code=500, detail=str(e)) 369 | 370 | 371 | @app.get("/conversations/{conversation_id}") 372 | async def get_conversation(conversation_id: str): 373 | """ 374 | Retrieve a specific conversation's history. 375 | 376 | Args: 377 | conversation_id (str): ID of the conversation to retrieve 378 | 379 | Returns: 380 | dict: Conversation messages 381 | 382 | Raises: 383 | HTTPException: If conversation not found or operation fails 384 | """ 385 | try: 386 | history = db.get_conversation_history(conversation_id) 387 | if not history: 388 | raise HTTPException(status_code=404, detail="Conversation not found") 389 | return {"messages": history} 390 | except Exception as e: 391 | raise HTTPException(status_code=500, detail=str(e)) 392 | 393 | 394 | class CreateConversationInput(BaseModel): 395 | project_id: Optional[str] = None 396 | 397 | 398 | @app.post("/create_conversation") 399 | async def create_conversation(input: Optional[CreateConversationInput] = None): 400 | """ 401 | Create a new conversation. 402 | 403 | Returns: 404 | dict: New conversation ID 405 | 406 | Raises: 407 | HTTPException: If creation fails 408 | """ 409 | try: 410 | project_id = input.project_id if input else None 411 | conversation_id = db.create_conversation(project_id) 412 | await manager.broadcast({"type": "conversation_created", "conversation_id": conversation_id}) 413 | return {"conversation_id": conversation_id} 414 | except Exception as e: 415 | raise HTTPException(status_code=500, detail=str(e)) 416 | 417 | 418 | @app.post("/conversations/{conversation_id}/messages") 419 | async def add_message(conversation_id: str, message: MessageInput): 420 | """ 421 | Add a message to a conversation. 422 | 423 | Args: 424 | conversation_id (str): Target conversation ID 425 | message (MessageInput): Message data to add 426 | 427 | Returns: 428 | dict: Added message ID 429 | 430 | Raises: 431 | HTTPException: If operation fails 432 | """ 433 | try: 434 | message_id = db.add_message( 435 | conversation_id=conversation_id, 436 | role=message.role, 437 | content=message.content, 438 | content_type=message.content_type, 439 | attachments=message.attachments, 440 | ) 441 | return {"message_id": message_id} 442 | except Exception as e: 443 | raise HTTPException(status_code=500, detail=str(e)) 444 | 445 | 446 | @app.put("/messages/{message_id}") 447 | async def edit_message(message_id: str, edit: MessageEdit): 448 | """ 449 | Edit an existing message. 450 | 451 | Args: 452 | message_id (str): ID of the message to edit 453 | edit (MessageEdit): New message content 454 | 455 | Returns: 456 | dict: Operation status 457 | 458 | Raises: 459 | HTTPException: If message not found or edit not allowed 460 | """ 461 | try: 462 | success = db.edit_message(message_id, edit.content) 463 | if not success: 464 | raise HTTPException(status_code=404, detail="Message not found") 465 | 466 | # Get message role to send in broadcast 467 | with sqlite3.connect(db.db_path) as conn: 468 | conn.row_factory = sqlite3.Row 469 | msg = conn.execute("SELECT role FROM messages WHERE message_id = ?", (message_id,)).fetchone() 470 | 471 | # Broadcast update to all connected clients 472 | await manager.broadcast( 473 | {"type": "message_edited", "message_id": message_id, "content": edit.content, "role": msg["role"]} 474 | ) 475 | 476 | return {"status": "success"} 477 | except ValueError as e: 478 | raise HTTPException(status_code=403, detail=str(e)) 479 | except Exception as e: 480 | raise HTTPException(status_code=500, detail=str(e)) 481 | 482 | 483 | @app.get("/messages/{message_id}/raw") 484 | async def get_raw_message(message_id: str): 485 | """Get the raw content of a message. 486 | 487 | Args: 488 | message_id (str): ID of the message to retrieve 489 | 490 | Returns: 491 | dict: Message content 492 | 493 | Raises: 494 | HTTPException: If message not found 495 | """ 496 | try: 497 | with sqlite3.connect(db.db_path) as conn: 498 | conn.row_factory = sqlite3.Row 499 | message = conn.execute("SELECT content FROM messages WHERE message_id = ?", (message_id,)).fetchone() 500 | 501 | if not message: 502 | raise HTTPException(status_code=404, detail="Message not found") 503 | 504 | return {"content": message["content"]} 505 | except Exception as e: 506 | raise HTTPException(status_code=500, detail=str(e)) 507 | 508 | 509 | @app.delete("/conversations/{conversation_id}") 510 | async def delete_conversation(conversation_id: str): 511 | """ 512 | Delete a conversation. 513 | 514 | Args: 515 | conversation_id (str): ID of conversation to delete 516 | 517 | Returns: 518 | dict: Operation status 519 | 520 | Raises: 521 | HTTPException: If deletion fails 522 | """ 523 | try: 524 | db.delete_conversation(conversation_id) 525 | await manager.broadcast({"type": "conversation_deleted", "conversation_id": conversation_id}) 526 | return {"status": "success"} 527 | except Exception as e: 528 | raise HTTPException(status_code=500, detail=str(e)) 529 | 530 | 531 | class ConversationTitleUpdate(BaseModel): 532 | """Model for updating conversation title.""" 533 | 534 | title: str 535 | 536 | 537 | @app.put("/conversations/{conversation_id}/title") 538 | async def update_conversation_title(conversation_id: str, update: ConversationTitleUpdate): 539 | """ 540 | Update a conversation's title/summary. 541 | 542 | Args: 543 | conversation_id (str): ID of conversation to update 544 | update (ConversationTitleUpdate): New title 545 | 546 | Returns: 547 | dict: Operation status 548 | 549 | Raises: 550 | HTTPException: If update fails 551 | """ 552 | try: 553 | db.update_conversation_summary(conversation_id, update.title) 554 | await manager.broadcast( 555 | {"type": "summary_updated", "conversation_id": conversation_id, "summary": update.title} 556 | ) 557 | return {"status": "success"} 558 | except Exception as e: 559 | raise HTTPException(status_code=500, detail=str(e)) 560 | 561 | 562 | # Provider endpoints 563 | @app.get("/providers") 564 | async def get_all_providers(): 565 | """Get all providers.""" 566 | try: 567 | providers = db.get_all_providers() 568 | return {"providers": providers} 569 | except Exception as e: 570 | raise HTTPException(status_code=500, detail=str(e)) 571 | 572 | 573 | @app.get("/providers/{provider_id}") 574 | async def get_provider(provider_id: int): 575 | """Get provider by ID.""" 576 | try: 577 | provider = db.get_provider_by_id(provider_id) 578 | if not provider: 579 | raise HTTPException(status_code=404, detail="Provider not found") 580 | return provider 581 | except Exception as e: 582 | raise HTTPException(status_code=500, detail=str(e)) 583 | 584 | 585 | @app.post("/providers") 586 | async def create_provider(provider: ProviderInput): 587 | """Create a new provider.""" 588 | try: 589 | provider_dict = provider.model_dump() 590 | provider_id = db.add_provider(provider_dict) 591 | return {"status": "success", "id": provider_id} 592 | except sqlite3.IntegrityError as e: 593 | if "unique" in str(e).lower(): 594 | raise HTTPException(status_code=409, detail="A provider with this name already exists") 595 | raise HTTPException(status_code=400, detail=str(e)) 596 | except Exception as e: 597 | raise HTTPException(status_code=500, detail=str(e)) 598 | 599 | 600 | @app.put("/providers/{provider_id}") 601 | async def update_provider(provider_id: int, provider: ProviderInput): 602 | """Update a provider.""" 603 | try: 604 | provider_dict = provider.model_dump() 605 | success = db.update_provider(provider_id, provider_dict) 606 | if not success: 607 | raise HTTPException(status_code=404, detail="Provider not found") 608 | return {"status": "success"} 609 | except Exception as e: 610 | raise HTTPException(status_code=500, detail=str(e)) 611 | 612 | 613 | @app.delete("/providers/{provider_id}") 614 | async def delete_provider(provider_id: int): 615 | """Delete a provider.""" 616 | try: 617 | success = db.delete_provider(provider_id) 618 | if not success: 619 | raise HTTPException(status_code=404, detail="Provider not found") 620 | return {"status": "success"} 621 | except Exception as e: 622 | raise HTTPException(status_code=500, detail=str(e)) 623 | 624 | 625 | @app.post("/providers/{provider_id}/set_default") 626 | async def set_default_provider(provider_id: int): 627 | """Set a provider as default.""" 628 | try: 629 | success = db.set_default_provider(provider_id) 630 | if not success: 631 | raise HTTPException(status_code=404, detail="Provider not found") 632 | return {"status": "success"} 633 | except Exception as e: 634 | raise HTTPException(status_code=500, detail=str(e)) 635 | 636 | 637 | # Model endpoints 638 | @app.get("/providers/{provider_id}/models") 639 | async def get_provider_models(provider_id: int): 640 | """Get all models for a provider.""" 641 | try: 642 | models = db.get_models_by_provider(provider_id) 643 | return {"models": models} 644 | except Exception as e: 645 | raise HTTPException(status_code=500, detail=str(e)) 646 | 647 | 648 | @app.post("/providers/{provider_id}/models") 649 | async def add_provider_model(provider_id: int, model: ModelInput): 650 | """Add a model to a provider.""" 651 | try: 652 | model_id = db.add_model(provider_id, model.model_name, is_multimodal=model.is_multimodal) 653 | return {"status": "success", "id": model_id} 654 | except sqlite3.IntegrityError as e: 655 | if "unique" in str(e).lower(): 656 | raise HTTPException(status_code=409, detail="This model already exists for this provider") 657 | raise HTTPException(status_code=400, detail=str(e)) 658 | except Exception as e: 659 | raise HTTPException(status_code=500, detail=str(e)) 660 | 661 | 662 | @app.delete("/models/{model_id}") 663 | async def delete_model(model_id: int): 664 | """Delete a model.""" 665 | try: 666 | success = db.delete_model(model_id) 667 | if not success: 668 | raise HTTPException(status_code=404, detail="Model not found") 669 | return {"status": "success"} 670 | except Exception as e: 671 | raise HTTPException(status_code=500, detail=str(e)) 672 | 673 | 674 | @app.post("/models/{model_id}/set_default") 675 | async def set_default_model(model_id: int): 676 | """Set a model as default for its provider.""" 677 | try: 678 | success = db.set_default_model(model_id) 679 | if not success: 680 | raise HTTPException(status_code=404, detail="Model not found") 681 | return {"status": "success"} 682 | except Exception as e: 683 | raise HTTPException(status_code=500, detail=str(e)) 684 | 685 | 686 | @app.get("/default_provider") 687 | async def get_default_provider(): 688 | """Get the default provider.""" 689 | try: 690 | provider = db.get_default_provider() 691 | if not provider: 692 | raise HTTPException(status_code=404, detail="No default provider found") 693 | return provider 694 | except Exception as e: 695 | raise HTTPException(status_code=500, detail=str(e)) 696 | 697 | 698 | # Project endpoints 699 | @app.get("/projects") 700 | async def get_projects(): 701 | """Get all projects.""" 702 | try: 703 | projects = db.get_projects() 704 | return {"projects": projects} 705 | except Exception as e: 706 | raise HTTPException(status_code=500, detail=str(e)) 707 | 708 | 709 | @app.get("/projects/{project_id}") 710 | async def get_project(project_id: str): 711 | """Get project by ID.""" 712 | try: 713 | project = db.get_project(project_id) 714 | if not project: 715 | raise HTTPException(status_code=404, detail="Project not found") 716 | return project 717 | except Exception as e: 718 | raise HTTPException(status_code=500, detail=str(e)) 719 | 720 | 721 | @app.post("/projects") 722 | async def create_project(project: ProjectInput): 723 | """Create a new project.""" 724 | try: 725 | project_id = db.create_project( 726 | name=project.name, description=project.description, system_prompt=project.system_prompt 727 | ) 728 | return {"status": "success", "id": project_id} 729 | except Exception as e: 730 | raise HTTPException(status_code=500, detail=str(e)) 731 | 732 | 733 | @app.put("/projects/{project_id}") 734 | async def update_project(project_id: str, project: ProjectInput): 735 | """Update a project.""" 736 | try: 737 | success = db.update_project( 738 | project_id=project_id, 739 | name=project.name, 740 | description=project.description, 741 | system_prompt=project.system_prompt, 742 | ) 743 | if not success: 744 | raise HTTPException(status_code=404, detail="Project not found") 745 | return {"status": "success"} 746 | except Exception as e: 747 | raise HTTPException(status_code=500, detail=str(e)) 748 | 749 | 750 | @app.delete("/projects/{project_id}") 751 | async def delete_project(project_id: str): 752 | """Delete a project.""" 753 | try: 754 | success = db.delete_project(project_id) 755 | if not success: 756 | raise HTTPException(status_code=404, detail="Project not found") 757 | return {"status": "success"} 758 | except Exception as e: 759 | raise HTTPException(status_code=500, detail=str(e)) 760 | 761 | 762 | def generate_safe_filename(original_filename: str) -> str: 763 | """ 764 | Generate a safe filename with timestamp to prevent collisions. 765 | 766 | Args: 767 | original_filename (str): Original filename to be sanitized 768 | 769 | Returns: 770 | str: Sanitized filename with timestamp 771 | """ 772 | # Get timestamp 773 | timestamp = time.strftime("%Y%m%d_%H%M%S") 774 | 775 | # Get file extension 776 | ext = Path(original_filename).suffix 777 | 778 | # Get base name and sanitize it 779 | base = Path(original_filename).stem 780 | # Remove special characters and spaces 781 | base = re.sub(r"[^\w\-_]", "_", base) 782 | 783 | # Create new filename 784 | return f"{base}_{timestamp}{ext}" 785 | 786 | 787 | @app.get("/get_system_prompt", response_class=JSONResponse) 788 | async def get_system_prompt(conversation_id: str = None): 789 | """ 790 | Get the system prompt for a conversation. 791 | 792 | Args: 793 | conversation_id (str, optional): ID of the conversation 794 | 795 | Returns: 796 | JSONResponse: System prompt text 797 | 798 | Raises: 799 | HTTPException: If retrieval fails 800 | """ 801 | try: 802 | if conversation_id: 803 | history = db.get_conversation_history(conversation_id) 804 | if history: 805 | system_role_messages = [m for m in history if m["role"] == "system"] 806 | last_system_message = ( 807 | system_role_messages[-1]["content"] if system_role_messages else "You are a helpful assistant." 808 | ) 809 | return {"system_prompt": last_system_message} 810 | 811 | # Default system prompt for new conversations or when no conversation_id is provided 812 | # If conversation_id is provided, check its project 813 | if conversation_id: 814 | project = db.get_project_for_conversation(conversation_id) 815 | if project and project.get("system_prompt"): 816 | return {"system_prompt": project["system_prompt"]} 817 | 818 | active_prompt = db.get_active_prompt() 819 | return {"system_prompt": active_prompt["prompt_text"]} 820 | except Exception as e: 821 | raise HTTPException(status_code=500, detail=str(e)) 822 | 823 | 824 | @app.post("/chat", response_class=StreamingResponse) 825 | async def chat( 826 | message: str = Form(...), 827 | system_prompt: str = Form(...), 828 | conversation_id: str = Form(...), # Now required 829 | client_id: str = Form(...), # Add client_id parameter 830 | files: List[UploadFile] = File(None), 831 | request: Request = None, 832 | ): 833 | """ 834 | Handle chat requests with support for file uploads and streaming responses. 835 | 836 | Args: 837 | message (str): User's message 838 | system_prompt (str): System instructions for the AI 839 | conversation_id (str): Unique identifier for the conversation 840 | files (List[UploadFile]): Optional list of uploaded files 841 | 842 | Returns: 843 | StreamingResponse: Server-sent events stream of AI responses 844 | 845 | Raises: 846 | HTTPException: If there's an error processing the request 847 | """ 848 | try: 849 | logger.info(f"Chat request: message='{message}' conv_id={conversation_id} system_prompt='{system_prompt}'") 850 | 851 | # Create new context for this request 852 | ctx = RequestContext() 853 | token = request_context.set(ctx) 854 | 855 | try: 856 | # Verify conversation exists 857 | history = db.get_conversation_history(conversation_id) 858 | if history: 859 | system_role_messages = [m for m in history if m["role"] == "system"] 860 | last_system_message = system_role_messages[-1]["content"] if system_role_messages else "" 861 | if last_system_message != system_prompt: 862 | db.add_message(conversation_id=conversation_id, role="system", content=system_prompt) 863 | 864 | # Handle multiple file uploads 865 | file_info_list = [] 866 | if files: 867 | for file in files: 868 | if file is None: 869 | continue 870 | 871 | # Get file size by reading the file into memory 872 | contents = await file.read() 873 | file_size = len(contents) 874 | 875 | # Generate safe unique filename 876 | safe_filename = generate_safe_filename(file.filename) 877 | temp_file = TEMP_DIR / safe_filename 878 | 879 | try: 880 | # Save uploaded file 881 | with open(temp_file, "wb") as f: 882 | f.write(contents) 883 | file_info = { 884 | "name": file.filename, # Original name for display 885 | "path": str(temp_file), # Path to saved file 886 | "type": file.content_type, 887 | "size": file_size, 888 | } 889 | file_info_list.append(file_info) 890 | logger.info(f"Saved uploaded file: {temp_file} ({file_size} bytes)") 891 | except Exception as e: 892 | logger.error(f"Failed to save uploaded file: {e}") 893 | raise HTTPException(status_code=500, detail=f"Failed to process uploaded file: {str(e)}") 894 | 895 | # Try to read file content if it's text 896 | try: 897 | with open(temp_file, "r", encoding="utf-8") as f: 898 | text_content = f.read() 899 | # Append text content to message 900 | message += f"\n\n--- File: {file.filename} ---\n{text_content}" 901 | except UnicodeDecodeError: 902 | # Not a text file, skip appending content 903 | pass 904 | except Exception as e: 905 | logger.warning(f"Failed to read file content: {e}") 906 | 907 | if not history: 908 | db.add_message(conversation_id=conversation_id, role="system", content=system_prompt) 909 | 910 | db.add_message( 911 | conversation_id=conversation_id, 912 | role="user", 913 | content=message, 914 | attachments=file_info_list if file_info_list else None, 915 | ) 916 | 917 | # get updated conversation history 918 | history = db.get_conversation_history(conversation_id) 919 | 920 | async def process_and_stream(): 921 | """ 922 | Inner generator function to process the chat and stream responses. 923 | 924 | Yields: 925 | str: Chunks of the AI response 926 | """ 927 | full_response = "" 928 | 929 | # Removed canned response generator 930 | 931 | try: 932 | async for chunk in text_streamer(history, client_id): 933 | if ctx.is_disconnected: 934 | logger.info("Client disconnected, stopping generation") 935 | # Don't save partial response on user-initiated stop 936 | return 937 | full_response += chunk 938 | yield chunk 939 | await asyncio.sleep(0) # Ensure chunks are flushed immediately 940 | except asyncio.CancelledError: 941 | # Request was cancelled, save what we have so far 942 | logger.info("Request cancelled by client, saving partial response") 943 | if full_response: 944 | db.add_message(conversation_id=conversation_id, role="assistant", content=full_response) 945 | raise 946 | except Exception as e: 947 | logger.error(f"Error in process_and_stream: {e}") 948 | raise 949 | 950 | # Only store complete response if not cancelled 951 | message_id = db.add_message(conversation_id=conversation_id, role="assistant", content=full_response) 952 | 953 | # Broadcast update after storing the response 954 | await manager.broadcast( 955 | { 956 | "type": "message_added", 957 | "conversation_id": conversation_id, 958 | "message_id": message_id, 959 | } 960 | ) 961 | 962 | # Generate and store summary after assistant's response but only if its the first user message 963 | if len(history) == 2 and history[1]["role"] == "user": 964 | try: 965 | # Get the default provider to check use_for_summarization setting 966 | provider = db.get_default_provider() 967 | 968 | if provider and provider.get("use_for_summarization", False): 969 | # Use model to generate summary 970 | all_user_messages = [m["content"] for m in history if m["role"] == "user"] 971 | summary_messages = [ 972 | {"role": "system", "content": SUMMARY_PROMPT}, 973 | {"role": "user", "content": str(all_user_messages)}, 974 | ] 975 | summary = "" 976 | logger.info(summary_messages) 977 | async for chunk in text_streamer(summary_messages, client_id): 978 | summary += chunk 979 | db.update_conversation_summary(conversation_id, summary.strip()) 980 | else: 981 | # Use first few words of user message as summary 982 | user_message = history[1]["content"] 983 | # Get first 50 characters or until first newline 984 | summary = user_message.split("\n")[0][:50] 985 | if len(user_message.split("\n")[0]) > 50: 986 | summary += "..." 987 | db.update_conversation_summary(conversation_id, summary) 988 | 989 | # After summary update 990 | await manager.broadcast( 991 | { 992 | "type": "summary_updated", 993 | "conversation_id": conversation_id, 994 | "summary": summary if isinstance(summary, str) else summary.strip(), 995 | } 996 | ) 997 | except Exception as e: 998 | logger.error(f"Failed to generate summary: {e}") 999 | 1000 | response = StreamingResponse( 1001 | process_and_stream(), 1002 | media_type="text/event-stream", 1003 | headers={ 1004 | "Cache-Control": "no-cache", 1005 | "Connection": "keep-alive", 1006 | "X-Accel-Buffering": "no", # Disable Nginx buffering 1007 | }, 1008 | ) 1009 | 1010 | # Set up disconnection detection using response closure 1011 | async def on_disconnect(): 1012 | logger.info("Client disconnected, setting disconnected flag") 1013 | ctx.is_disconnected = True 1014 | 1015 | response.background = on_disconnect 1016 | return response 1017 | 1018 | finally: 1019 | # Reset context when done 1020 | request_context.reset(token) 1021 | 1022 | except Exception as e: 1023 | logger.error(f"Error in chat endpoint: {str(e)}") 1024 | raise HTTPException(status_code=500, detail=str(e)) 1025 | 1026 | 1027 | @app.post("/regenerate_response", response_class=StreamingResponse) 1028 | async def chat_again( 1029 | message: str = Form(...), 1030 | system_prompt: str = Form(...), 1031 | conversation_id: str = Form(...), 1032 | message_id: str = Form(...), 1033 | client_id: str = Form(...), # Add client_id parameter 1034 | ): 1035 | """ 1036 | This endpoint is used to regenerate the response of a message in a conversation at any point in time. 1037 | 1038 | Args: 1039 | message (str): User's message 1040 | system_prompt (str): System instructions for the AI 1041 | conversation_id (str): Unique identifier for the conversation 1042 | message_id (str): ID of the message to regenerate. This message will be replaced with the new response. 1043 | 1044 | Returns: 1045 | StreamingResponse: Server-sent events stream of AI responses 1046 | 1047 | Raises: 1048 | HTTPException: If there's an error processing the request 1049 | """ 1050 | try: 1051 | logger.info( 1052 | f"Regenerate request: message='{message}' conv_id={conversation_id} system_prompt='{system_prompt}'" 1053 | ) 1054 | 1055 | # Verify conversation exists 1056 | history = db.get_conversation_history_upto_message_id(conversation_id, message_id) 1057 | logger.info(history) 1058 | 1059 | if not history: 1060 | logger.error("No conversation history found") 1061 | raise HTTPException(status_code=404, detail="No conversation history found") 1062 | 1063 | system_role_messages = [m for m in history if m["role"] == "system"] 1064 | last_system_message = system_role_messages[-1]["content"] if system_role_messages else "" 1065 | if last_system_message != system_prompt: 1066 | db.add_message(conversation_id=conversation_id, role="system", content=system_prompt) 1067 | 1068 | async def process_and_stream(): 1069 | """ 1070 | Inner generator function to process the chat and stream responses. 1071 | 1072 | Yields: 1073 | str: Chunks of the AI response 1074 | """ 1075 | full_response = "" 1076 | async for chunk in text_streamer(history, client_id): 1077 | full_response += chunk 1078 | yield chunk 1079 | await asyncio.sleep(0) # Ensure chunks are flushed immediately 1080 | 1081 | # Store the complete response 1082 | db.edit_message(message_id, full_response) 1083 | 1084 | # Broadcast update after storing the response 1085 | await manager.broadcast( 1086 | { 1087 | "type": "message_added", 1088 | "conversation_id": conversation_id, 1089 | } 1090 | ) 1091 | 1092 | return StreamingResponse( 1093 | process_and_stream(), 1094 | media_type="text/event-stream", 1095 | headers={ 1096 | "Cache-Control": "no-cache", 1097 | "Connection": "keep-alive", 1098 | "X-Accel-Buffering": "no", # Disable Nginx buffering 1099 | }, 1100 | ) 1101 | 1102 | except Exception as e: 1103 | logger.error(f"Error in chat endpoint: {str(e)}") 1104 | raise HTTPException(status_code=500, detail=str(e)) 1105 | 1106 | 1107 | @app.post("/conversations/{conversation_id}/summary") 1108 | async def update_conversation_summary(conversation_id: str, summary: str = Form(...)): 1109 | """ 1110 | Update the summary of a conversation. 1111 | 1112 | Args: 1113 | conversation_id (str): ID of the conversation 1114 | summary (str): New summary text 1115 | 1116 | Returns: 1117 | dict: Operation status 1118 | 1119 | Raises: 1120 | HTTPException: If update fails 1121 | """ 1122 | try: 1123 | db.update_conversation_summary(conversation_id, summary) 1124 | return {"status": "success"} 1125 | except Exception as e: 1126 | raise HTTPException(status_code=500, detail=str(e)) 1127 | 1128 | 1129 | @app.get("/prompts") 1130 | async def get_all_prompts(): 1131 | """Get all system prompts.""" 1132 | try: 1133 | prompts = db.get_all_prompts() 1134 | formatted_prompts = [] 1135 | for prompt in prompts: 1136 | formatted_prompts.append( 1137 | { 1138 | "id": prompt["id"], 1139 | "name": prompt["prompt_name"], 1140 | "content": prompt["prompt_text"], 1141 | "is_active": bool(prompt["is_active"]), # Ensure boolean type 1142 | } 1143 | ) 1144 | return {"prompts": formatted_prompts} 1145 | except Exception as e: 1146 | logger.error(f"Error getting prompts: {str(e)}") 1147 | raise HTTPException(status_code=500, detail=str(e)) 1148 | 1149 | 1150 | @app.get("/prompts/{prompt_id}") 1151 | async def get_prompt(prompt_id: int): 1152 | """Get a specific prompt.""" 1153 | try: 1154 | prompt = db.get_prompt_by_id(prompt_id) 1155 | if not prompt: 1156 | raise HTTPException(status_code=404, detail="Prompt not found") 1157 | return { 1158 | "id": prompt["id"], # Changed from tuple index to dict key 1159 | "name": prompt["prompt_name"], 1160 | "content": prompt["prompt_text"], 1161 | } 1162 | except Exception as e: 1163 | logger.error(f"Error getting prompt {prompt_id}: {str(e)}") 1164 | raise HTTPException(status_code=500, detail=str(e)) 1165 | 1166 | 1167 | @app.post("/prompts") 1168 | async def create_prompt(prompt: PromptInput): 1169 | """Create a new prompt.""" 1170 | try: 1171 | prompt_id = db.add_system_prompt(prompt.name, prompt.text) 1172 | return {"id": prompt_id} 1173 | except Exception as e: 1174 | raise HTTPException(status_code=500, detail=str(e)) 1175 | 1176 | 1177 | @app.put("/prompts/{prompt_id}") 1178 | async def update_prompt(prompt_id: int, prompt: PromptInput): 1179 | """Update an existing prompt.""" 1180 | try: 1181 | success = db.edit_system_prompt(prompt_id, prompt.name, prompt.text) 1182 | if not success: 1183 | raise HTTPException(status_code=404, detail="Prompt not found") 1184 | return {"status": "success"} 1185 | except Exception as e: 1186 | raise HTTPException(status_code=500, detail=str(e)) 1187 | 1188 | 1189 | @app.delete("/prompts/{prompt_id}") 1190 | async def delete_prompt(prompt_id: int): 1191 | """ 1192 | Delete a system prompt. 1193 | 1194 | Args: 1195 | prompt_id (int): ID of the prompt to delete 1196 | 1197 | Returns: 1198 | dict: Operation status 1199 | 1200 | Raises: 1201 | HTTPException: If deletion fails or prompt is protected 1202 | """ 1203 | try: 1204 | # Get prompt to check if it's the default one 1205 | prompt = db.get_prompt_by_id(prompt_id) 1206 | if not prompt: 1207 | raise HTTPException(status_code=404, detail="Prompt not found") 1208 | 1209 | if prompt["prompt_name"] == "default": 1210 | raise HTTPException(status_code=403, detail="Cannot delete the default prompt") 1211 | 1212 | success = db.delete_system_prompt(prompt_id) 1213 | if not success: 1214 | raise HTTPException(status_code=500, detail="Failed to delete prompt") 1215 | return {"status": "success"} 1216 | except Exception as e: 1217 | raise HTTPException(status_code=500, detail=str(e)) 1218 | 1219 | 1220 | @app.post("/prompts/{prompt_id}/activate") 1221 | async def activate_prompt(prompt_id: int): 1222 | """ 1223 | Set a prompt as the active system prompt. 1224 | 1225 | Args: 1226 | prompt_id (int): ID of the prompt to activate 1227 | 1228 | Returns: 1229 | dict: Operation status 1230 | 1231 | Raises: 1232 | HTTPException: If activation fails or prompt not found 1233 | """ 1234 | try: 1235 | success = db.set_active_prompt(prompt_id) 1236 | if not success: 1237 | raise HTTPException(status_code=404, detail="Prompt not found") 1238 | return {"status": "success"} 1239 | except Exception as e: 1240 | raise HTTPException(status_code=500, detail=str(e)) 1241 | 1242 | 1243 | @app.get("/prompts/active") 1244 | async def get_active_prompt(): 1245 | """Get the currently active system prompt.""" 1246 | try: 1247 | prompt = db.get_active_prompt() 1248 | if not prompt: 1249 | # If no active prompt, get default 1250 | prompt = db.get_prompt_by_name("default") 1251 | if prompt: 1252 | # Make default prompt active 1253 | db.set_active_prompt(prompt["id"]) 1254 | 1255 | if not prompt: 1256 | raise HTTPException(status_code=404, detail="No active or default prompt found") 1257 | 1258 | return {"id": prompt["id"], "name": prompt["prompt_name"], "content": prompt["prompt_text"]} 1259 | except Exception as e: 1260 | logger.error(f"Error getting active prompt: {str(e)}") 1261 | raise HTTPException(status_code=500, detail=str(e)) 1262 | 1263 | 1264 | @app.websocket("/ws/{client_id}") 1265 | async def websocket_endpoint(websocket: WebSocket, client_id: str): 1266 | await manager.connect(websocket, client_id) 1267 | try: 1268 | while True: 1269 | message = await websocket.receive_text() 1270 | if message == "stop_generation": 1271 | manager.set_generating(client_id, False) 1272 | logger.info(f"Received stop signal for client {client_id}") 1273 | else: 1274 | # Handle other WebSocket messages 1275 | pass 1276 | except WebSocketDisconnect: 1277 | manager.disconnect(client_id) 1278 | 1279 | 1280 | @app.get("/attachments/{attachment_id}") 1281 | async def get_attachment(attachment_id: str): 1282 | """Serve attachment files.""" 1283 | attachment = db.get_attachment(attachment_id) 1284 | if not attachment: 1285 | raise HTTPException(status_code=404, detail="Attachment not found") 1286 | 1287 | file_path = Path(attachment["file_path"]) 1288 | if not file_path.exists(): 1289 | raise HTTPException(status_code=404, detail="File not found") 1290 | 1291 | return FileResponse(file_path, filename=attachment["file_name"], media_type=attachment["file_type"]) 1292 | --------------------------------------------------------------------------------