├── tests ├── __init__.py ├── myla │ ├── __init__.py │ ├── llms │ │ ├── __init__.py │ │ └── mock_test.py │ ├── vectorstores │ │ ├── __init__.py │ │ ├── xinference_embeddings_test.py │ │ ├── record_test.py │ │ └── faiss_group_test.py │ ├── utils_test.py │ ├── threads_test.py │ ├── files_test.py │ ├── persistence_test.py │ ├── assistants_test.py │ └── users_test.py ├── tools │ ├── __init__.py │ └── sync_tool.py ├── perf │ ├── .gitignore │ └── vs.py └── run.sh ├── myla ├── extensions │ ├── __init__.py │ └── tools │ │ ├── __init__.py │ │ ├── qa_summary.py │ │ ├── rephrase.py │ │ └── qa_retrieval.py ├── _version.py ├── _logging.py ├── webui │ ├── statics │ │ ├── welcome.md │ │ └── images │ │ │ └── screenshot.png │ ├── __init__.py │ ├── _web_template.py │ └── templates │ │ └── index.html ├── projects.py ├── __init__.py ├── _env.py ├── vectorstores │ ├── loaders.py │ ├── pandas_loader.py │ ├── _embeddings.py │ ├── xinference_embeddings.py │ ├── pdf_loader.py │ ├── sentence_transformers_embeddings.py │ ├── _base.py │ ├── chromadb_vectorstore.py │ ├── lancedb_vectorstore.py │ ├── __init__.py │ ├── faiss_vectorstore.py │ └── faiss_group.py ├── llms │ ├── utils.py │ ├── mock.py │ ├── backend.py │ ├── __init__.py │ ├── chatglm.py │ └── openai.py ├── permissions.py ├── _tools.py ├── iur.py ├── persistence.py ├── _auth.py ├── tools.py ├── utils.py ├── files.py ├── assistants.py ├── __main__.py ├── _run_scheduler.py ├── threads.py ├── _entry.py ├── retrieval.py ├── messages.py ├── runs.py ├── _models.py └── _llm.py ├── setup.py ├── requirements_dev.txt ├── js ├── build.sh ├── .gitignore ├── public │ └── index.html ├── src │ ├── index.css │ ├── index.js │ ├── org_selector.js │ ├── user.js │ ├── settings.js │ ├── secret_key.js │ ├── user_admin.js │ ├── members.js │ └── chat.js └── package.json ├── pyproject.toml ├── .pre-commit-config.yaml ├── requirements.txt ├── examples ├── llm_sync.py ├── upload_file.py ├── embeddings.py ├── llm_debug.py ├── docs_summary.py ├── lancedb_vs.py 
├── chatbot-openai-sdk.py └── chatbot.py ├── scripts ├── update_0.2.22.sql └── update_0.2.26.sql ├── LICENSE ├── env-example.txt ├── setup.cfg ├── README_zh_CN.md ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/myla/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /myla/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/myla/llms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /myla/extensions/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/myla/vectorstores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /myla/_version.py: -------------------------------------------------------------------------------- 1 | VERSION = '0.2.35' 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup() 4 | -------------------------------------------------------------------------------- 
/tests/perf/.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | vs 3 | data 4 | *.pkl 5 | *.md -------------------------------------------------------------------------------- /tests/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | python -m unittest discover -p "*_test.py" -------------------------------------------------------------------------------- /myla/_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger('myla') 4 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | build 2 | twine 3 | flake8 4 | autopep8 5 | isort 6 | pre-commit 7 | pydoc-markdown 8 | aiounittest -------------------------------------------------------------------------------- /myla/webui/statics/welcome.md: -------------------------------------------------------------------------------- 1 | ## Muyu Local Assistant 🚀 2 | 3 | 4 | * [Docs](/api/docs) 5 | * [API Debugging](/api/swagger) -------------------------------------------------------------------------------- /myla/webui/statics/images/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muyuworks/myla/HEAD/myla/webui/statics/images/screenshot.png -------------------------------------------------------------------------------- /myla/projects.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from . 
import _models 4 | 5 | 6 | class Project(_models.DBModel, table=True): 7 | name: Optional[str] = None 8 | -------------------------------------------------------------------------------- /js/build.sh: -------------------------------------------------------------------------------- 1 | npm run build 2 | 3 | mkdir -p ../myla/webui/statics/aify 4 | cp build/static/js/*.js ../myla/webui/statics/aify/aify.js 5 | cp build/static/css/*.css ../myla/webui/statics/aify/aify.css 6 | -------------------------------------------------------------------------------- /myla/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import VERSION 2 | __version__ = VERSION 3 | 4 | from ._api import api 5 | from ._entry import entry 6 | from ._logging import logger 7 | from ._run_scheduler import RunScheduler 8 | -------------------------------------------------------------------------------- /myla/_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _here = os.path.abspath(os.path.join(os.path.dirname(__file__))) 4 | 5 | 6 | def webui_dir(): 7 | "Return the path of the 'webui' directory next to this module, where web UI resources are stored." 8 | return os.path.join(_here, 'webui') 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", 'cython', "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pyright] 6 | venvPath = "."
7 | venv = ".venv" 8 | 9 | [tool.yaml] 10 | validate = false 11 | -------------------------------------------------------------------------------- /myla/vectorstores/loaders.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Iterator, Optional, Dict 3 | from ._base import Record 4 | 5 | 6 | class Loader(ABC): 7 | 8 | @abstractmethod 9 | def load(self, file, metadata: Optional[Dict] = None) -> Iterator[Record]: 10 | """Load data from a file""" 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/hhatto/autopep8 3 | rev: v2.0.4 4 | hooks: 5 | - id: autopep8 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | - repo: https://github.com/pycqa/flake8 11 | rev: 7.0.0 12 | hooks: 13 | - id: flake8 14 | -------------------------------------------------------------------------------- /myla/llms/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | 4 | def plain_messages(messages: List[Dict], model=None, roles=['user', 'assistant', 'system']): 5 | text = [] 6 | for m in messages: 7 | role = m['role'] 8 | if role in roles: 9 | text.append(f"{role}: {m['content']}") 10 | return '\n'.join(text) 11 | -------------------------------------------------------------------------------- /myla/webui/__init__.py: -------------------------------------------------------------------------------- 1 | from ._web_template import render 2 | from starlette.requests import Request 3 | 4 | 5 | async def assistant(request: Request): 6 | ctx = { 7 | "assistant_id": request.path_params.get("assistant_id", ''), 8 | "chat_mode": True 9 | } 10 | return await render('index.html', context=ctx)(request=request) 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # API Server 2 | uvicorn 3 | starlette 4 | fastapi 5 | python-dotenv 6 | 7 | # Web 8 | jinja2 9 | 10 | # HTTP Client 11 | aiohttp 12 | 13 | # Persistence 14 | sqlmodel 15 | 16 | # LLMs 17 | openai 18 | 19 | # File upload 20 | python-multipart 21 | 22 | pandas 23 | aiofiles 24 | 25 | sentence_transformers 26 | faiss-cpu 27 | 28 | Authlib -------------------------------------------------------------------------------- /examples/llm_sync.py: -------------------------------------------------------------------------------- 1 | from myla import llms 2 | import dotenv 3 | 4 | dotenv.load_dotenv(".env") 5 | 6 | llm = llms.get("gpt-3.5-turbo") 7 | 8 | resp = llm.sync_generate(instructions="hi") 9 | print(resp) 10 | 11 | resp = llm.sync_chat(messages=[{"role": "user", "content": "hi"}], stream=True) 12 | for r in resp: 13 | print(r, end='', flush=True) 14 | print("\n") 15 | -------------------------------------------------------------------------------- /tests/myla/utils_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from myla import utils 3 | 4 | 5 | class TestUtils(unittest.TestCase): 6 | def test_generate_id(self): 7 | print('uuid: ', utils.uuid().hex) 8 | print("sha1: ", utils.sha1(utils.uuid().bytes).hex()) 9 | print("Id: ", utils.random_id()) 10 | print("SecretKey: ", utils.random_key()) 11 | -------------------------------------------------------------------------------- /js/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /examples/upload_file.py: -------------------------------------------------------------------------------- 1 | import openai 2 | 3 | openai.api_key = "sk-" 4 | openai.base_url = "http://localhost:2000/api/v1/" 5 | 6 | openai.files.create( 7 | #file=open("./examples/upload_file.py", 'rb'), 8 | file=open("./data/myla_test_kb.json", 'rb'), 9 | purpose="assistants", 10 | extra_body={"embeddings": "category,query"} 11 | ) 12 | 13 | files = openai.files.list(purpose="assistants") 14 | 15 | print(files) 16 | -------------------------------------------------------------------------------- /examples/embeddings.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import numpy as np 3 | import faiss 4 | from myla.vectorstores.sentence_transformers_embeddings import SentenceTransformerEmbeddings 5 | 6 | embeddings = SentenceTransformerEmbeddings(model_name="/Users/shellc/Downloads/bge-large-zh-v1.5") 7 | 8 | v = asyncio.run(embeddings.aembed("hello")) 9 | 10 | v = np.array([v], dtype=np.float32) 11 | 12 | faiss.normalize_L2(v) 13 | 14 | n = np.linalg.norm(v, ord=2) 15 | print(n) 16 | -------------------------------------------------------------------------------- /tests/tools/sync_tool.py: -------------------------------------------------------------------------------- 1 | from myla.tools import Context, Tool 2 | import time 3 | import asyncio 4 | 5 | class SyncTool(Tool): 6 | def execute(self, context: Context) -> None: 7 | for i in range(10): 8 | print(f"SyncTool: {i}") 9 | time.sleep(3) 10 | 11 | class AsyncTool(Tool): 
12 | async def execute(self, context: Context) -> None: 13 | for i in range(10): 14 | print(f"AsyncTool: {i}") 15 | await asyncio.sleep(3) 16 | -------------------------------------------------------------------------------- /tests/myla/llms/mock_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla import llms 4 | from myla import utils 5 | 6 | 7 | class TestMockLLM(unittest.TestCase): 8 | def test_chat(self): 9 | llm = llms.get("mock@mock") 10 | self.assertIsNotNone(llm) 11 | 12 | g = utils.sync_call(llm.chat, messages=[{'role': 'user', 'content': 'hello'}]) 13 | self.assertEqual('hello', g) 14 | 15 | g = utils.sync_call(llm.generate, instructions="hello") 16 | self.assertEqual('hello', g) 17 | -------------------------------------------------------------------------------- /js/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Aify 8 | 9 | 10 | 11 | 12 |
13 | 14 | 19 | 20 | -------------------------------------------------------------------------------- /tests/myla/vectorstores/xinference_embeddings_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla.vectorstores.xinference_embeddings import XinferenceEmbeddings 4 | 5 | 6 | class XinferenceTests(unittest.TestCase): 7 | def test_connect(self): 8 | embed = XinferenceEmbeddings( 9 | base_url="http://localhost:9997", 10 | model_id="bge-small-zh", 11 | instruction="Represent the sentence for searching the most similar sentences from the corpus." 12 | ) 13 | embeds = embed.embed("你好") 14 | self.assertEqual(len(embeds), 512) 15 | -------------------------------------------------------------------------------- /examples/llm_debug.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from myla.llms.chatglm import ChatGLM 3 | 4 | 5 | async def main(): 6 | openai = ChatGLM() 7 | r = await openai.chat(messages=[{ 8 | "role": "system", 9 | "content": "你是谁" 10 | }], 11 | stream=True, 12 | model="/Users/shellc/Workspaces/chatglm.cpp/chatglm-ggml.bin" 13 | ) 14 | print(r) 15 | async for c in r: 16 | print(c, end='', flush=True) 17 | 18 | 19 | if __name__ == '__main__': 20 | import dotenv 21 | dotenv.load_dotenv(".env") 22 | 23 | asyncio.run(main=main()) 24 | -------------------------------------------------------------------------------- /myla/permissions.py: -------------------------------------------------------------------------------- 1 | #from ._logging import logger 2 | 3 | 4 | def check(resource_type, resource_id, org_id, project_id, owner_id, user_id, orgs, permission) -> bool: 5 | #logger.debug(f"Checking permission {permission} for {resource_type} {resource_id} in org {org_id} and project {project_id}, orgs={orgs}") 6 | 7 | role = None 8 | if org_id in orgs: 9 | role = orgs[org_id].role 10 | 11 | if permission == "read": 12 | if role == 
"reader" or role == "owner": 13 | return True 14 | elif permission == "write": 15 | if role == "owner": 16 | return True 17 | 18 | return False 19 | -------------------------------------------------------------------------------- /examples/docs_summary.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import myla.llms as llms 3 | 4 | DOCS = """ 5 | [] 6 | """ 7 | 8 | QUERY = "送给谢顶男票的礼物" 9 | 10 | INSTRUCTIONS = """ 11 | 你是专业的问答分析助手。下面是JSON格式的问答记录。 12 | <问答记录开始> 13 | {docs} 14 | <问答记录介绍> 15 | 16 | 请根据问答记录生成新问题的候选回答。 17 | 新问题: {query} 18 | 候选回答: 19 | """ 20 | 21 | 22 | async def main(): 23 | r = await llms.get().chat(messages=[{ 24 | "role": "system", 25 | "content": INSTRUCTIONS.format(docs=DOCS, query=QUERY) 26 | }], stream=False) 27 | print(r) 28 | 29 | 30 | if __name__ == '__main__': 31 | import dotenv 32 | dotenv.load_dotenv(".env") 33 | 34 | asyncio.run(main=main()) 35 | -------------------------------------------------------------------------------- /myla/llms/mock.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from .backend import LLM 3 | 4 | 5 | class MockLLM(LLM): 6 | def __init__(self) -> None: 7 | super().__init__() 8 | 9 | async def chat(self, messages: List[Dict], model=None, stream=False, **kwargs): 10 | last_message = messages[-1]['content'] 11 | 12 | if stream: 13 | async def iter(): 14 | for c in [last_message]: 15 | yield c 16 | return iter() 17 | else: 18 | return last_message 19 | 20 | async def generate(self, instructions: str, model=None, stream=False, **kwargs): 21 | return instructions 22 | -------------------------------------------------------------------------------- /js/src/index.css: -------------------------------------------------------------------------------- 1 | * { 2 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, 3 | 'Noto Sans', sans-serif, 'Apple Color 
Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 4 | 'Noto Color Emoji'; 5 | } 6 | a { 7 | text-decoration: none; 8 | } 9 | .thread-close, .anticon { 10 | transform: translate(-0.5px, -3px); 11 | } 12 | 13 | /* Hide scrollbar for Chrome, Safari and Opera */ 14 | .scrollbar-none::-webkit-scrollbar { 15 | display: none; 16 | } 17 | 18 | /* Hide scrollbar for IE, Edge and Firefox */ 19 | .scrollbar-none { 20 | -ms-overflow-style: none; /* IE and Edge */ 21 | scrollbar-width: none; /* Firefox */ 22 | } -------------------------------------------------------------------------------- /scripts/update_0.2.22.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE thread ADD tag TEXT; 2 | CREATE INDEX ix_thread_tag ON thread (tag); 3 | 4 | ALTER TABLE message ADD tag TEXT; 5 | CREATE INDEX ix_message_tag ON message (tag); 6 | 7 | ALTER TABLE assistant ADD tag TEXT; 8 | CREATE INDEX ix_assistant_tag ON assistant (tag); 9 | 10 | ALTER TABLE file ADD tag TEXT; 11 | CREATE INDEX ix_file_tag ON file (tag); 12 | 13 | ALTER TABLE organization ADD tag TEXT; 14 | CREATE INDEX ix_organization_tag ON organization (tag); 15 | 16 | ALTER TABLE run ADD tag TEXT; 17 | CREATE INDEX ix_run_tag ON run (tag); 18 | 19 | ALTER TABLE secretkey ADD tag TEXT; 20 | CREATE INDEX ix_secretkey_tag ON secretkey (tag); 21 | 22 | ALTER TABLE user ADD tag TEXT; 23 | CREATE INDEX ix_user_tag ON user (tag); 24 | -------------------------------------------------------------------------------- /js/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import { Aify } from './aify' 4 | import { getSecretKey, Login } from './user'; 5 | 6 | import 'bootstrap/dist/css/bootstrap.css' 7 | import './index.css' 8 | 9 | export const createAify = (elementId, chatMode = false, assistantId = null) => { 10 | const root =
ReactDOM.createRoot(document.getElementById(elementId)); 11 | root.render( 12 | 13 | {getSecretKey() ? ( 14 | 15 | ) : ( 16 | 17 | )} 18 | 19 | ); 20 | } 21 | window.createAify = createAify; 22 | //export default createAify; -------------------------------------------------------------------------------- /examples/lancedb_vs.py: -------------------------------------------------------------------------------- 1 | from myla.vectorstores.lancedb_vectorstore import LanceDB 2 | from myla.vectorstores.sentence_transformers_embeddings import SentenceTransformerEmbeddings 3 | from myla.vectorstores.pandas_loader import PandasLoader 4 | 5 | embeddings = SentenceTransformerEmbeddings(model_name="/Users/shellc/Downloads/bge-large-zh-v1.5") 6 | 7 | vs = LanceDB(db_uri="/tmp/lancedb", embeddings=embeddings) 8 | 9 | collection = "default" 10 | 11 | records = list(PandasLoader().load("./data/202101.csv")) 12 | 13 | vs.create_collection(collection=collection, schema=records[0], mode='overwrite') 14 | 15 | vs.add(collection=collection, records=records) 16 | 17 | print(vs.search(collection=collection, query="新疆")) 18 | 19 | #vs.delete(collection=collection, query="text = 'hello'") 20 | -------------------------------------------------------------------------------- /scripts/update_0.2.26.sql: -------------------------------------------------------------------------------- 1 | ALTER TABLE userorglink ADD "role" VARCHAR; 2 | UPDATE userorglink SET role='owner'; 3 | 4 | UPDATE organization SET user_id=(SELECT userorglink.user_id FROM userorglink WHERE organization.id=userorglink.org_id); 5 | 6 | UPDATE assistant SET org_id=(SELECT userorglink.org_id FROM userorglink WHERE assistant.user_id=userorglink.user_id); 7 | UPDATE thread SET org_id=(SELECT userorglink.org_id FROM userorglink WHERE thread.user_id=userorglink.user_id); 8 | UPDATE message SET org_id=(SELECT userorglink.org_id FROM userorglink WHERE message.user_id=userorglink.user_id); 9 | UPDATE run SET org_id=(SELECT 
userorglink.org_id FROM userorglink WHERE run.user_id=userorglink.user_id); 10 | UPDATE file SET org_id=(SELECT userorglink.org_id FROM userorglink WHERE file.user_id=userorglink.user_id); -------------------------------------------------------------------------------- /myla/vectorstores/pandas_loader.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from typing import Iterator, Optional, Dict 3 | from ._base import Record 4 | from .loaders import Loader 5 | 6 | 7 | class PandasLoader(Loader): 8 | def __init__(self, ftype="csv") -> None: 9 | super().__init__() 10 | self._ftype = ftype 11 | 12 | def load(self, file, metadata: Optional[Dict] = None) -> Iterator[Record]: 13 | if self._ftype == 'csv': 14 | df = pd.read_csv(file) 15 | elif self._ftype == 'xls' or self._ftype == 'xlsx': 16 | df = pd.read_excel(file) 17 | elif self._ftype == 'json': 18 | df = pd.read_json(file) 19 | else: 20 | raise ValueError(f"File type not supported: {self._ftype}") 21 | 22 | for _, r in df.iterrows(): 23 | yield r.to_dict() 24 | -------------------------------------------------------------------------------- /myla/_tools.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import json 3 | import os 4 | 5 | from ._logging import logger 6 | 7 | _tools = {} 8 | 9 | 10 | def load_tools(): 11 | tools = os.environ.get('TOOLS') 12 | try: 13 | if tools: 14 | tools = json.loads(tools) 15 | for tool in tools: 16 | impl = tool['impl'] 17 | ss = impl.split('.') 18 | module = importlib.import_module('.'.join(ss[:-1])) 19 | 20 | args = tool['args'] if 'args' in tool else {} 21 | instance = getattr(module, ss[-1])(**args) 22 | _tools[tool["name"]] = instance 23 | except Exception as e: 24 | logger.error(f"Load tools faild: {tools}", exc_info=e) 25 | 26 | 27 | def get_tool(name): 28 | return _tools.get(name) 29 | 30 | 31 | def get_tools(): 32 | return _tools 33 | 
-------------------------------------------------------------------------------- /myla/llms/backend.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | 4 | class LLM: 5 | def __init__(self, model=None) -> None: 6 | self.model = model 7 | 8 | async def chat(self, messages: List[Dict], model=None, stream=False, **kwargs): 9 | raise NotImplemented() 10 | 11 | async def generate(self, instructions: str, model=None, stream=False, **kwargs): 12 | raise NotImplemented() 13 | 14 | def sync_chat(self, messages: List[Dict], model=None, stream=False, **kwargs): 15 | raise NotImplemented() 16 | 17 | def sync_generate(self, instructions: str, model=None, stream=False, **kwargs): 18 | raise NotImplemented() 19 | 20 | 21 | class Usage: 22 | def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None: 23 | self.prompt_tokens = prompt_tokens 24 | self.completion_tokens = completion_tokens 25 | -------------------------------------------------------------------------------- /myla/vectorstores/_embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from abc import ABC, abstractmethod 3 | import asyncio 4 | 5 | 6 | class Embeddings(ABC): 7 | 8 | @abstractmethod 9 | def embed_batch(self, texts: List[str], **kwargs) -> List[List[float]]: 10 | """Embed text batch.""" 11 | 12 | def embed(self, text: str, **kwargs) -> List[float]: 13 | """Embed text.""" 14 | return self.embed_batch(texts=[text], **kwargs)[0] 15 | 16 | async def aembed(self, text: str, **kwargs) -> List[float]: 17 | """Asynchronous Embed text.""" 18 | return await asyncio.get_running_loop().run_in_executor( 19 | None, self.embed, text, **kwargs 20 | ) 21 | 22 | async def aembed_batch(self, texts: [str], **kwargs) -> List[List[float]]: 23 | """Asynchronous Embed text.""" 24 | return await asyncio.get_running_loop().run_in_executor( 25 | None, self.embed_batch, 
texts, **kwargs 26 | ) 27 | -------------------------------------------------------------------------------- /myla/webui/_web_template.py: -------------------------------------------------------------------------------- 1 | import os 2 | from starlette.templating import Jinja2Templates 3 | from jinja2.exceptions import TemplateNotFound 4 | from starlette.exceptions import HTTPException 5 | from .._env import webui_dir 6 | 7 | 8 | def get_templates(templates_dir=None): 9 | if not templates_dir: 10 | templates_dir = webui_dir() 11 | return Jinja2Templates(directory=os.path.join(templates_dir, 'templates')) 12 | 13 | 14 | _templates = get_templates() 15 | 16 | 17 | def render(template_name, context={}, templates=None): 18 | """Render a template.""" 19 | if not templates: 20 | templates = _templates 21 | 22 | async def _request(request): 23 | context['request'] = request 24 | try: 25 | return templates.TemplateResponse(template_name, context=context) 26 | except TemplateNotFound as e: 27 | raise HTTPException(status_code=404, detail=f"Template not found: {e}") 28 | return _request 29 | -------------------------------------------------------------------------------- /myla/webui/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Myla 10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 26 | -------------------------------------------------------------------------------- /tests/myla/vectorstores/record_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla.vectorstores import Record 4 | 5 | 6 | class TestRecord(unittest.TestCase): 7 | def test_values_to_text(self): 8 | r = { 9 | "category": "介绍/推荐", 10 | "query": "XX系列", 11 | "response": [ 12 | "XXX" 13 | ], 14 | "msg_id": 0, 15 | "source": "standard" 16 | } 17 | t = Record.values_to_text(r, props=['category']) 18 | self.assertEqual("介绍/推荐", t) 19 | 20 | t = Record.values_to_text(r, props=['category', 'query']) 21 | self.assertEqual("介绍/推荐\001XX系列", t) 22 | 23 | t = Record.values_to_text(r, props=['category', 'query'], separator='\t') 24 | self.assertEqual("介绍/推荐\tXX系列", t) 25 | 26 | try: 27 | Record.values_to_text(r, props='category') 28 | except Exception as e: 29 | self.assertTrue(isinstance(e, ValueError)) 30 | 31 | t = Record.values_to_text(r, separator='\t') 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 muyuworks 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/myla/threads_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla import threads, persistence 4 | 5 | 6 | class TestUsers(unittest.TestCase): 7 | 8 | def setUp(self) -> None: 9 | self.db = persistence.Persistence(database_url="sqlite://") 10 | self.db.initialize_database() 11 | self.session = self.db.create_session() 12 | 13 | def tearDown(self) -> None: 14 | self.session.close() 15 | 16 | def test_create(self): 17 | t_created = threads.create(thread=threads.ThreadCreate(metadata={'k': 'v'}), session=self.session) 18 | self.assertIsInstance(t_created, threads.ThreadRead) 19 | self.assertEqual(t_created.metadata, {'k': 'v'}) 20 | 21 | t_read = threads.get(id=t_created.id, session=self.session) 22 | self.assertEqual(t_read.metadata, t_created.metadata) 23 | 24 | def test_create_with_tag(self): 25 | t_created = threads.create(thread=threads.ThreadCreate(metadata={'k': 'v'}), tag="t_thread", session=self.session) 26 | self.assertIsInstance(t_created, threads.ThreadRead) 27 | thread_list = threads.list(tag="t_thread", session=self.session) 28 | self.assertEqual(len(thread_list.data), 1) 29 | self.assertIsNotNone(thread_list.data[0].id, t_created.id) 30 | -------------------------------------------------------------------------------- /js/package.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "js", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@microsoft/fetch-event-source": "^2.0.1", 7 | "@testing-library/jest-dom": "^5.17.0", 8 | "@testing-library/react": "^13.4.0", 9 | "@testing-library/user-event": "^13.5.0", 10 | "antd": "^5.11.1", 11 | "bootstrap": "^5.3.2", 12 | "fetch-event-source": "^1.0.0-alpha.2", 13 | "js-cookie": "^3.0.5", 14 | "react": "^18.2.0", 15 | "react-dom": "^18.2.0", 16 | "react-markdown": "^9.0.0", 17 | "react-scripts": "5.0.1", 18 | "remark-gfm": "^4.0.0", 19 | "web-vitals": "^2.1.4" 20 | }, 21 | "scripts": { 22 | "start": "react-scripts start", 23 | "build": "react-scripts build", 24 | "test": "react-scripts test", 25 | "eject": "react-scripts eject" 26 | }, 27 | "eslintConfig": { 28 | "extends": [ 29 | "react-app", 30 | "react-app/jest" 31 | ] 32 | }, 33 | "browserslist": { 34 | "production": [ 35 | ">0.2%", 36 | "not dead", 37 | "not op_mini all" 38 | ], 39 | "development": [ 40 | "last 1 chrome version", 41 | "last 1 firefox version", 42 | "last 1 safari version" 43 | ] 44 | }, 45 | "proxy": "http://127.0.0.1:2000" 46 | } 47 | -------------------------------------------------------------------------------- /myla/vectorstores/xinference_embeddings.py: -------------------------------------------------------------------------------- 1 | from random import randint 2 | from typing import List 3 | 4 | from xinference_client import RESTfulClient as Client 5 | 6 | from ._embeddings import Embeddings 7 | 8 | 9 | class XinferenceEmbeddings(Embeddings): 10 | def __init__( 11 | self, 12 | base_url, 13 | model_id, 14 | instruction=None 15 | ) -> None: 16 | self._base_url = base_url 17 | self._model_id = model_id 18 | self._instruction = instruction 19 | 20 | client = Client(self._base_url) 21 | #self._model = client.get_model(model_id) 22 | model_ids = model_id.split(",") 23 | self._models = [] 24 | for m_id in 
model_ids: 25 | self._models.append(client.get_model(m_id)) 26 | 27 | def embed_batch(self, texts: List[str], **kwargs) -> List[List[float]]: 28 | if self._instruction is not None: 29 | texts = [self._instruction + t for t in texts] 30 | 31 | model = self._get_model() 32 | 33 | embeds = model.create_embedding(texts) 34 | return [e["embedding"] for e in embeds["data"]] 35 | 36 | def _get_model(self): 37 | idx = randint(0, len(self._models) - 1) 38 | return self._models[idx] 39 | -------------------------------------------------------------------------------- /myla/iur.py: -------------------------------------------------------------------------------- 1 | from .tools import Tool, Context 2 | from . import llms, logger 3 | from .llms import utils 4 | 5 | INSTRUCTIONS_ZH = """ 6 | 你是专业的文本分析助手, 负责改写用户回复, 下面是AI助手和用户的对话记录, system 是AI助手的身份设定, user是用户, assistant是AI助手: 7 | -开始对话- 8 | {history} 9 | -结束对话- 10 | 用户最新回复: {last_user_message} 11 | 请你结合对话记录改写用户最新回复。 12 | 如果用户最新回复是好的、谢谢、你好等问候语或礼貌性回复,不要改写用户回复。 13 | 如果用户最新回复是提问,请你以用户的身份改写,使其表述更清晰并包含用户的完整意图, 易于AI助手理解。 14 | 15 | 请直接输出修改后的结果,不要包含前缀说明。 16 | 改写后的用户回复: 17 | """ 18 | 19 | 20 | class IURTool(Tool): 21 | async def execute(self, context: Context) -> None: 22 | """ 23 | 根据会话历史让 LLM 决定是否需要改写用户最后一条消息 24 | """ 25 | last_user_message = None 26 | if len(context.messages) > 0: 27 | if context.messages[-1]["role"] == "user": 28 | last_user_message = context.messages[-1]['content'] 29 | 30 | if not last_user_message: 31 | return 32 | 33 | history = utils.plain_messages(messages=context.messages) 34 | 35 | iur_query = await llms.get().generate(INSTRUCTIONS_ZH.format(history=history, last_user_message=last_user_message), temperature=0) 36 | 37 | logger.debug(f"Converstations: \n{history}\n IUR: {iur_query}") 38 | 39 | context.messages[-1]['content'] = iur_query 40 | 41 | context.message_metadata['iur'] = iur_query 42 | -------------------------------------------------------------------------------- /myla/vectorstores/pdf_loader.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator, Optional, Dict 3 | from myla.vectorstores._base import Record 4 | from .loaders import Loader 5 | 6 | 7 | class PDFLoader(Loader): 8 | def __init__(self, chunk_size=500, chunk_overlap=50) -> None: 9 | super().__init__() 10 | self._chunk_size = chunk_size 11 | self._chunk_overlap = chunk_overlap 12 | 13 | def load(self, file, metadata: Optional[Dict] = None) -> Iterator[Record]: 14 | try: 15 | import pypdf 16 | except ImportError as e: 17 | raise ImportError( 18 | "Could not import pypdf python package. " 19 | "Please install it with `pip install pypdf`." 20 | ) from e 21 | reader = pypdf.PdfReader(file) 22 | for page in reader.pages: 23 | text = page.extract_text() 24 | for s in self._split(text=text): 25 | yield {"text": s} 26 | 27 | def _split(self, text): 28 | for i in range(math.ceil(len(text)/self._chunk_size)): 29 | begin = i * self._chunk_size 30 | end = begin + self._chunk_size 31 | 32 | if i > 0: 33 | begin -= self._chunk_overlap 34 | if begin < 0: 35 | begin = 0 36 | if end > len(text): 37 | end = len(text) 38 | 39 | yield text[begin: end] 40 | -------------------------------------------------------------------------------- /env-example.txt: -------------------------------------------------------------------------------- 1 | # Persistence 2 | #DATABASE_URL=sqlite:///myla.db 3 | #DATABASE_CONNECT_ARGS={"check_same_thread": false} 4 | 5 | MYLA_DELETE_MODE=soft 6 | 7 | # LLMs 8 | 9 | #LLM_ENDPOINT=http://172.88.0.20:8888/v1/ 10 | #LLM_API_KEY=sk-xx 11 | #DEFAULT_LLM_MODEL_NAME=Qwen-14B-Chat-Int4 12 | 13 | # to use ChatGLM as the backend: pip install myla[chatglm] 14 | # the model name for ChatGLM like: 15 | #DEFAULT_LLM_MODEL_NAME=chatglm@/Users/shellc/Workspaces/chatglm.cpp/chatglm-ggml.bin 16 | 17 | 18 | # Ebeddings 19 | 20 | #EMBEDDINGS_IMPL=sentence_transformers 21 | 
#EMBEDDINGS_MODEL_NAME=/Users/shellc/Downloads/bge-large-zh-v1.5 22 | #EMBEDDINGS_DEVICE=cpu 23 | #EMBEDDINGS_INSTRUCTION= 24 | 25 | 26 | # Vectorstore 27 | # Default vecotrstore backend, options: faiss, lancedb 28 | # to use faiss as the backend: pip install myla[faiss-cpu] or myla[faiss-gpu] 29 | # to use LanceDB as the backend: pip install myla[lancedb] 30 | 31 | #VECTOR_STORE_IMPL=faiss 32 | 33 | 34 | # Tools 35 | # JSON format configurations 36 | 37 | #TOOLS=' 38 | #[ 39 | # { 40 | # "name": "retrieval", 41 | # "impl": "myla.retrieval.RetrievalTool" 42 | # } 43 | #] 44 | #' 45 | 46 | 47 | # Vectorstore Loaders 48 | # JSON format configurations 49 | 50 | #LOADERS=' 51 | #[ 52 | # { 53 | # "name": "my_loader", 54 | # "impl": "my_loaders.MyLoader" 55 | # } 56 | #] 57 | #' 58 | 59 | 60 | # Others 61 | 62 | # HuggingFace 63 | 64 | #TOKENIZERS_PARALLELISM=4 -------------------------------------------------------------------------------- /myla/extensions/tools/qa_summary.py: -------------------------------------------------------------------------------- 1 | from myla.tools import Tool, Context 2 | from myla import llms 3 | 4 | DOC_SUMMARY_INSTRUCTIONS_ZH = """ 5 | 你是专业的文本分析助手, 你负责为用户问题生成候选答案。 6 | 7 | 你要使用下面的JSON格式的数据, query字段是提问, response字段是回答。 8 | 9 | <数据开始> 10 | {docs} 11 | <数据结束> 12 | 用户提问: {query} 13 | 14 | 请为用户提问生成候选回答,如果用户提问不明确,请用户进一步说明, 生成结果不要包含问题。 15 | 16 | 候选回答: 17 | """ 18 | 19 | 20 | class QASummaryTool(Tool): 21 | async def execute(self, context: Context) -> None: 22 | if len(context.messages) == 0: 23 | return 24 | 25 | last_message = context.messages[-1]['content'] 26 | 27 | docs = None 28 | 29 | for msg in context.messages: 30 | if msg.get('type') == 'docs': 31 | docs = msg 32 | 33 | if docs: 34 | summary = await llms.get().chat(messages=[{ 35 | "role": "system", 36 | "content": DOC_SUMMARY_INSTRUCTIONS_ZH.format(docs=docs['content'], query=last_message) 37 | }], stream=False, temperature=0) 38 | if summary: 39 | docs['content'] = summary 40 | 41 | 
messages = [] # 删除对话历史, 只保留 system Message 和最后一条用户消息 42 | for msg in context.messages: 43 | if not (msg['role'] == 'user' or msg['role'] == 'assistant'): 44 | messages.append(msg) 45 | messages.append({"role": "user", "content": last_message}) 46 | 47 | context.messages = messages 48 | -------------------------------------------------------------------------------- /tests/myla/files_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla import files, persistence, utils 4 | 5 | 6 | class TestFiles(unittest.TestCase): 7 | 8 | def setUp(self) -> None: 9 | self.db = persistence.Persistence(database_url="sqlite://") 10 | self.db.initialize_database() 11 | self.session = self.db.create_session() 12 | 13 | def tearDown(self) -> None: 14 | self.session.close() 15 | 16 | def test_file_create(self): 17 | id = utils.random_id() 18 | file = files.FileUpload(purpose="assistant", metadata={"k1": "v1", "k2": "v2"}) 19 | file_created = files.create(id=id, file=file, filename="filename", bytes=0, session=self.session) 20 | 21 | self.assertEqual(id, file_created.id) 22 | self.assertEqual("assistant", file_created.purpose) 23 | self.assertEqual({"k1": "v1", "k2": "v2"}, file_created.metadata) 24 | 25 | file_read = files.get(id=id, session=self.session) 26 | self.assertEqual(id, file_read.id) 27 | self.assertEqual("assistant", file_read.purpose) 28 | self.assertEqual({"k1": "v1", "k2": "v2"}, file_read.metadata) 29 | 30 | status = files.delete(id=id, session=self.session) 31 | self.assertEqual(id, status.id) 32 | self.assertEqual("file.deleted", status.object) 33 | self.assertEqual(True, status.deleted) 34 | 35 | file_read = files.get(id=id, session=self.session) 36 | self.assertIsNone(file_read) 37 | -------------------------------------------------------------------------------- /js/src/org_selector.js: -------------------------------------------------------------------------------- 1 | import { Select } from 
"antd" 2 | import { useEffect, useState } from "react" 3 | 4 | export const OrgSelector = () => { 5 | const [orgs, setOrgs] = useState() 6 | const [defaultOrg, setDefaultOrg] = useState(localStorage.getItem("org_id")) 7 | 8 | const loadOrgs = () => { 9 | 10 | fetch("/api/v1/organizations") 11 | .then(res => res.json()) 12 | .then(data => { 13 | if (defaultOrg === null) { 14 | /*data.data.map(org => { 15 | if (org.is_primary) { 16 | console.log(org.id) 17 | setDefaultOrg(org.id) 18 | } 19 | })*/ 20 | 21 | setDefaultOrg(data.data[0].id) 22 | } 23 | 24 | setOrgs(data.data) 25 | }) 26 | } 27 | 28 | const changeOrg = (orgId) => { 29 | localStorage.setItem('org_id', orgId) 30 | document.location.reload() 31 | } 32 | 33 | useEffect(() => { 34 | loadOrgs() 35 | }, []) 36 | 37 | return ( 38 | 79 | 80 | 81 | 90 | 91 | 92 | 93 | 96 | 97 | 98 | ) 99 | } -------------------------------------------------------------------------------- /myla/files.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | from sqlmodel import Field, Session, select 3 | from . 
import _models 4 | 5 | 6 | class FileUpload(_models.MetadataModel): 7 | purpose: str 8 | 9 | 10 | class FileRead(_models.ReadModel): 11 | bytes: int 12 | filename: str 13 | purpose: str 14 | 15 | 16 | class FileList(_models.ListModel): 17 | data: List[FileRead] 18 | 19 | 20 | class File(_models.DBModel, table=True): 21 | bytes: int 22 | filename: Optional[str] = None 23 | purpose: str = Field(index=True) 24 | 25 | 26 | @_models.auto_session 27 | def create(id: str, file: FileUpload, bytes: int, filename: str, user_id: str = None, org_id: str = None, session: Optional[Session] = None) -> FileRead: 28 | db_model = File( 29 | purpose=file.purpose, 30 | bytes=bytes, 31 | filename=filename, 32 | metadata_=file.metadata 33 | ) 34 | 35 | dbo = _models.create(id=id, object="file", meta_model=file, user_id=user_id, org_id=org_id, db_model=db_model, session=session) 36 | 37 | return dbo.to_read(FileRead) 38 | 39 | 40 | @_models.auto_session 41 | def get(id: str, user_id: str = None, session: Optional[Session] = None) -> Union[FileRead, None]: 42 | """Retrieval File object. 
43 | 44 | Args: 45 | id(str): file id 46 | 47 | Returns: 48 | if file exists return FileRead object else return None 49 | """ 50 | return _models.get(db_cls=File, read_cls=FileRead, id=id, user_id=user_id, session=session) 51 | 52 | 53 | @_models.auto_session 54 | def delete(id: str, user_id: str = None, mode="soft", session: Optional[Session] = None) -> _models.DeletionStatus: 55 | return _models.delete(db_cls=File, id=id, user_id=user_id, mode=mode, session=session) 56 | 57 | 58 | @_models.auto_session 59 | def list(purpose: str = None, limit: int = 20, order: str = "desc", after: str = None, before: str = None, user_id: str = None, org_id: str = None, session: Optional[Session] = None) -> FileList: 60 | select_stmt = select(File) 61 | select_stmt = select_stmt.filter(File.is_deleted == False) 62 | 63 | if purpose: 64 | select_stmt = select_stmt.filter(File.purpose == purpose) 65 | 66 | select_stmt = select_stmt.order_by(-File.created_at if order == "desc" else File.created_at) 67 | if after: 68 | f = get(id=after, user_id=user_id, session=session) 69 | if f: 70 | select_stmt = select_stmt.filter(File.created_at > f.created_at) 71 | if before: 72 | f = get(id=before, user_id=user_id, session=session) 73 | if f: 74 | select_stmt = select_stmt.filter(File.created_at < f.created_at) 75 | 76 | if user_id: 77 | select_stmt = select_stmt.filter(File.user_id == user_id) 78 | if org_id: 79 | select_stmt = select_stmt.filter(File.org_id == org_id) 80 | 81 | select_stmt = select_stmt.limit(limit) 82 | 83 | dbos = session.exec(select_stmt).all() 84 | rs = [] 85 | for dbo in dbos: 86 | rs.append(dbo.to_read(FileRead)) 87 | r = FileList(data=rs) 88 | return r 89 | -------------------------------------------------------------------------------- /tests/myla/persistence_test.py: -------------------------------------------------------------------------------- 1 | 2 | import sqlite3 3 | import threading 4 | import unittest 5 | 6 | from myla import persistence, threads, utils 7 | 8 
| db_file = "./tests/data/myla-test.db" 9 | stop = False 10 | 11 | 12 | class TestPersistence(unittest.TestCase): 13 | def setUp(self) -> None: 14 | 15 | self.db = persistence.Persistence(database_url=f"sqlite:///{db_file}", connect_args={"timeout": 1}) 16 | self.db.initialize_database() 17 | #with self.db.create_session() as session: 18 | # session.exec(text("PRAGMA journal_mode = WAL")) 19 | 20 | def tearDown(self) -> None: 21 | #if os.path.exists(db_file): 22 | # os.remove(db_file) 23 | #self.session.close() 24 | pass 25 | 26 | def test_database_lock(self): 27 | 28 | def _read_thread(): 29 | global stop 30 | session = self.db.create_session() 31 | 32 | while not stop: 33 | try: 34 | tl = threads.list(limit=100, session=session) 35 | print(f"list: {len(tl.data)}") 36 | except Exception as e: 37 | #print(f"Read thread error: {e}") 38 | stop = True 39 | raise e 40 | 41 | def _write_thread(): 42 | global stop 43 | #db = persistence.Persistence(database_url=f"sqlite:///{db_file}") 44 | session = self.db.create_session() 45 | 46 | c = 0 47 | while not stop: 48 | try: 49 | threads.create(thread=threads.ThreadCreate(), session=session) 50 | 51 | c += 1 52 | if c % 1000 == 0: 53 | print(f"Create threads: {c} {stop}") 54 | except Exception as e: 55 | stop = True 56 | session.rollback() 57 | raise e 58 | 59 | def _native_write_thread(): 60 | global stop 61 | conn = sqlite3.connect(db_file, isolation_level=None) 62 | #conn = self.db.engine.connect().connection 63 | 64 | c = 0 65 | while not stop: 66 | try: 67 | conn.execute("INSERT INTO thread (id) VALUES (?)", (utils.random_id(),)) 68 | 69 | c += 1 70 | if c % 100 == 0: 71 | print(f"Native create threads: {c}") 72 | 73 | if c == 100000: 74 | stop = True 75 | print("stop") 76 | except Exception as e: 77 | stop = True 78 | raise e 79 | 80 | threading.Thread(target=_write_thread).start() 81 | threading.Thread(target=_write_thread).start() 82 | threading.Thread(target=_write_thread).start() 83 | 
#threading.Thread(target=_native_write_thread).start() 84 | threading.Thread(target=_read_thread).start() 85 | threading.Thread(target=_read_thread).start() 86 | threading.Thread(target=_read_thread).start() 87 | -------------------------------------------------------------------------------- /myla/assistants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from pydantic import BaseModel 5 | from sqlmodel import JSON, Column, Field, Session 6 | 7 | from . import _models 8 | 9 | 10 | class AssistantBase(BaseModel): 11 | name: Optional[str] = None 12 | description: Optional[str] = None 13 | model: str 14 | instructions: Optional[str] = None 15 | tools: Optional[List[Dict[str, Any]]] = Field(sa_type=JSON, default=None) 16 | file_ids: Optional[List[str]] = Field(sa_type=JSON, default=None) 17 | 18 | 19 | class AssistantEdit(_models.MetadataModel, AssistantBase): 20 | pass 21 | 22 | 23 | class AssistantCreate(AssistantEdit): 24 | pass 25 | 26 | 27 | class AssistantModify(AssistantEdit): 28 | pass 29 | 30 | 31 | class AssistantRead(_models.ReadModel, AssistantBase): 32 | pass 33 | 34 | 35 | class AssistantList(_models.ListModel): 36 | data: List[AssistantRead] = [] 37 | 38 | 39 | class Assistant(_models.DBModel, AssistantBase, table=True): 40 | """ 41 | Represents an assistant that can call the model and use tools. 
42 | """ 43 | 44 | 45 | @_models.auto_session 46 | def create(assistant: AssistantCreate, user_id: str = None, org_id: str = None, session: Session = None) -> AssistantRead: 47 | db_model = Assistant.model_validate(assistant) 48 | if not db_model.model or db_model == '': 49 | default_model = os.environ.get("DEFAULT_LLM_MODEL_NAME") 50 | if default_model: 51 | db_model.model = default_model 52 | 53 | dbo = _models.create(object="assistant", meta_model=assistant, db_model=db_model, user_id=user_id, org_id=org_id, session=session) 54 | return dbo.to_read(AssistantRead) 55 | 56 | 57 | @_models.auto_session 58 | def get(id: str, user_id: str = None, session: Session = None) -> Union[AssistantRead, None]: 59 | return _models.get(db_cls=Assistant, read_cls=AssistantRead, id=id, user_id=user_id, session=session) 60 | 61 | 62 | @_models.auto_session 63 | def modify(id: str, assistant: AssistantModify, user_id: str = None, session: Session = None) -> Union[AssistantRead, None]: 64 | return _models.modify(db_cls=Assistant, read_cls=AssistantRead, id=id, to_update=assistant.model_dump(exclude_unset=True), user_id=user_id, session=session) 65 | 66 | 67 | @_models.auto_session 68 | def delete(id: str, user_id: str = None, mode="soft", session: Optional[Session] = None) -> _models.DeletionStatus: 69 | return _models.delete(db_cls=Assistant, id=id, user_id=user_id, mode=mode, session=session) 70 | 71 | 72 | @_models.auto_session 73 | def list( 74 | limit: int = 20, 75 | order: str = "desc", 76 | after: str = None, 77 | before: str = None, 78 | user_id: str = None, 79 | org_id: str = None, 80 | session: Optional[Session] = None 81 | ) -> AssistantList: 82 | return _models.list( 83 | db_cls=Assistant, 84 | read_cls=AssistantRead, 85 | list_cls=AssistantList, 86 | limit=limit, 87 | order=order, 88 | after=after, 89 | before=before, 90 | user_id=user_id, 91 | org_id=org_id, 92 | session=session 93 | ) 94 | 
-------------------------------------------------------------------------------- /js/src/settings.js: -------------------------------------------------------------------------------- 1 | import { Button, Divider, Form, Input, message } from "antd" 2 | import { getUser, logout } from "./user" 3 | 4 | export const Settings = () => { 5 | const [changePasswordForm] = Form.useForm() 6 | const [msg, msgContext] = message.useMessage() 7 | 8 | const changePassword = () => { 9 | let user = getUser(); 10 | 11 | let password = changePasswordForm.getFieldValue('password'); 12 | let confirm_pwd = changePasswordForm.getFieldValue("confirm_pwd"); 13 | 14 | if (password !== confirm_pwd) { 15 | msg.error("Password conflict."); 16 | return 17 | } 18 | 19 | fetch(`/api/v1/users/${user.username}/password`, { 20 | method: 'PUT', 21 | headers: {"Content-Type": "application/json"}, 22 | body: JSON.stringify({"password": password}) 23 | }).then(r => { 24 | if (r.status === 200) { 25 | msg.success("OK"); 26 | changePasswordForm.resetFields(); 27 | } else { 28 | throw new Error("Status: " + r.status); 29 | } 30 | }).catch(err => { 31 | msg.error(err.message); 32 | }) 33 | } 34 | 35 | const changeProfile = () => { 36 | 37 | } 38 | 39 | return ( 40 |
41 | {msgContext} 42 |
Profile
43 |
44 | 45 | 49 | 50 |
51 | 52 | 53 |
Change password
54 |
60 | 67 | 68 | 69 | 76 | 77 | 78 | 79 | 82 | 83 |
84 | 85 | 86 |
87 | 90 |
91 |
92 | ) 93 | } -------------------------------------------------------------------------------- /myla/__main__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import argparse 5 | import uvicorn 6 | from uvicorn.config import LOGGING_CONFIG 7 | from ._entry import register_webui 8 | 9 | MYLA_LIB_DIR = os.path.abspath(os.path.join( 10 | os.path.dirname(__file__), os.pardir)) 11 | sys.path.insert(0, MYLA_LIB_DIR) 12 | 13 | 14 | def runserver(args): 15 | log_level = "DEBUG" if args.debug else "INFO" 16 | logger = {"handlers": ["default"], "level": log_level, "propagate": False} 17 | LOGGING_CONFIG['loggers'][''] = logger 18 | LOGGING_CONFIG['loggers']['myla'] = logger 19 | 20 | reload_dirs = [] 21 | ext_dir = None 22 | 23 | if args.extensions: 24 | ext_dir = os.path.abspath(args.extensions) 25 | os.environ['EXT_DIR'] = ext_dir 26 | 27 | if args.reload: 28 | reload_dirs.append(MYLA_LIB_DIR) 29 | if ext_dir: 30 | reload_dirs.append(ext_dir) 31 | if args.reload_dirs: 32 | reload_dirs.append(args.relead_dirs) 33 | 34 | if args.data: 35 | os.environ['DATA_DIR'] = args.data 36 | 37 | if args.vectorstore: 38 | os.environ['VECTORSTORE_DIR'] = args.vectorstore 39 | elif 'DATA_DIR' in os.environ: 40 | os.environ['VECTORSTORE_DIR'] = os.path.join(os.environ['DATA_DIR'], 'vectorstore') 41 | 42 | if args.webui: 43 | os.environ['WEBUI'] = args.webui 44 | register_webui(args.webui) 45 | 46 | uvicorn.run('myla:entry', 47 | host=args.host, 48 | port=args.port, 49 | workers=args.workers, 50 | reload=args.reload, 51 | log_config=LOGGING_CONFIG, 52 | reload_dirs=reload_dirs 53 | ) 54 | 55 | 56 | parser = argparse.ArgumentParser() 57 | 58 | parser.add_argument('-H', '--host', default='0.0.0.0', 59 | help="bind socket to this host. 
default: 0.0.0.0") 60 | parser.add_argument('-p', '--port', default=2000, 61 | type=int, help="bind socket to this port, default: 2000") 62 | parser.add_argument('-w', '--workers', default=1, type=int, 63 | help="number of worker processes, default: 1") 64 | parser.add_argument('-r', '--reload', default=False, 65 | action='store_true', help="enable auto-reload") 66 | parser.add_argument('--reload-dirs', default=None, 67 | help="set reload directories explicitly, default is applications directory") 68 | parser.add_argument('--env-file', default='.env', 69 | help="environment configuration file") 70 | parser.add_argument("--extensions", default=None, help="extensions directory") 71 | parser.add_argument("--vectorstore", default=None, 72 | help="vectorstore directory") 73 | parser.add_argument("--data", default='data', 74 | help="data directory") 75 | parser.add_argument("--debug", default=False, 76 | action='store_true', help="enable debug") 77 | parser.add_argument("--webui", default=None, help="webui directory") 78 | 79 | 80 | def main(): 81 | args = parser.parse_args(sys.argv[1:]) 82 | 83 | if os.path.exists(args.env_file): 84 | from dotenv import load_dotenv 85 | load_dotenv(args.env_file) 86 | 87 | runserver(args) 88 | 89 | 90 | main() 91 | -------------------------------------------------------------------------------- /myla/_run_scheduler.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import asyncio 3 | from ._llm import chat_complete 4 | from ._logging import logger 5 | 6 | 7 | class AsyncIterator: 8 | def __init__(self): 9 | self.queue = asyncio.Queue() 10 | self.created_at = datetime.now().timestamp() 11 | 12 | def __aiter__(self): 13 | return self 14 | 15 | async def __anext__(self): 16 | i = await self.queue.get() 17 | if i is None: 18 | raise StopAsyncIteration 19 | 20 | return i 21 | 22 | async def put(self, item): 23 | await self.queue.put(item) 24 | 25 | 26 | class 
RunScheduler: 27 | _instance = None 28 | 29 | def __init__(self) -> None: 30 | self._tasks = set() 31 | self._run_queue = asyncio.Queue() 32 | self._run_iters = dict() 33 | self._run_iters_lock = asyncio.Lock() 34 | self._last_clear_at = datetime.now().timestamp() 35 | 36 | @staticmethod 37 | def default(): 38 | if not RunScheduler._instance: 39 | RunScheduler._instance = RunScheduler() 40 | return RunScheduler._instance 41 | 42 | def submit_run(self, run): 43 | self._run_queue.put_nowait(run) 44 | 45 | async def _create_run_iter(self, run_id): 46 | async with self._run_iters_lock: 47 | self._run_iters[run_id] = AsyncIterator() 48 | logger.debug(f"Run iters: {self._run_iters.keys()}") 49 | return self._run_iters[run_id] 50 | 51 | async def get_run_iter(self, run_id): 52 | async with self._run_iters_lock: 53 | logger.debug(f"Run iters: {self._run_iters.keys()}") 54 | return self._run_iters.get(run_id) 55 | 56 | async def _clear_iters(self): 57 | expires = 60*10 58 | now = datetime.now().timestamp() 59 | if self._last_clear_at + expires > now: 60 | return 61 | async with self._run_iters_lock: 62 | expired = [] 63 | for run_id, iter in self._run_iters.items(): 64 | if iter.created_at + expires < now: 65 | expired.append(run_id) 66 | 67 | for run_id in expired: 68 | self._run_iters.pop(run_id) 69 | 70 | logger.info(f"Run iters cleared: {expired}") 71 | self.last_clear_at = now 72 | 73 | def start(self): 74 | async def _start(): 75 | while True: 76 | try: 77 | run = await self._run_queue.get() 78 | logger.debug(f"RunScheduler received new task, run_id={run.id}") 79 | iter = await self._create_run_iter(run.id) 80 | 81 | task = asyncio.create_task( 82 | chat_complete(run=run, iter=iter) 83 | ) 84 | self._tasks.add(task) 85 | 86 | await self._clear_iters() 87 | 88 | done = [] 89 | for t in self._tasks: 90 | if t.done(): 91 | done.append(t) 92 | for t in done: 93 | self._tasks.remove(t) 94 | except Exception as e: 95 | logger.error(f"RunScheduler error: {e}") 96 | return 
asyncio.create_task(_start()) 97 | -------------------------------------------------------------------------------- /tests/myla/assistants_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla import assistants, persistence 4 | 5 | 6 | class TestUsers(unittest.TestCase): 7 | 8 | def setUp(self) -> None: 9 | self.db = persistence.Persistence(database_url="sqlite://") 10 | self.db.initialize_database() 11 | self.session = self.db.create_session() 12 | 13 | def tearDown(self) -> None: 14 | self.session.close() 15 | 16 | def test_create_get_list_assistant(self, user_id=None): 17 | asst_created = assistants.create(assistant=assistants.AssistantCreate( 18 | name='name', 19 | description='desc', 20 | model='model', 21 | instructions='instruction', 22 | tools=[{'type': 'retrieval'}], 23 | file_ids=['1'], 24 | metadata={'k': 'v'} 25 | ), user_id=user_id, session=self.session) 26 | self.assertIsNotNone(asst_created.id) 27 | 28 | asst_read = assistants.get(id=asst_created.id, user_id=user_id, session=self.session) 29 | self.assertIsNotNone(asst_read) 30 | self.assertEqual(asst_created.id, asst_read.id) 31 | self.assertEqual(asst_read.name, 'name') 32 | self.assertEqual(asst_read.description, 'desc') 33 | self.assertEqual(asst_read.instructions, 'instruction') 34 | self.assertEqual(asst_read.tools, [{'type': 'retrieval'}]) 35 | self.assertEqual(asst_read.file_ids, ['1']) 36 | self.assertEqual(asst_read.model, 'model') 37 | self.assertEqual(asst_read.metadata, {'k': 'v'}) 38 | 39 | asst_list = assistants.list(user_id=user_id, session=self.session) 40 | self.assertEqual(len(asst_list.data), 1) 41 | asst_read = asst_list.data[0] 42 | self.assertEqual(asst_created.id, asst_read.id) 43 | self.assertEqual(asst_read.name, 'name') 44 | self.assertEqual(asst_read.description, 'desc') 45 | self.assertEqual(asst_read.instructions, 'instruction') 46 | self.assertEqual(asst_read.tools, [{'type': 'retrieval'}]) 47 | 
self.assertEqual(asst_read.file_ids, ['1']) 48 | self.assertEqual(asst_read.model, 'model') 49 | self.assertEqual(asst_read.metadata, {'k': 'v'}) 50 | 51 | asst_update = assistants.modify(id=asst_read.id, assistant=assistants.AssistantModify(name='newname', model='newmodel', metadata={'k': 'v1'}), user_id=user_id, session=self.session) 52 | self.assertIsNotNone(asst_update) 53 | self.assertEqual(asst_update.id, asst_created.id) 54 | self.assertEqual(asst_update.name, 'newname') 55 | asst_read = assistants.get(id=asst_created.id, user_id=user_id, session=self.session) 56 | self.assertEqual(asst_read.name, 'newname') 57 | self.assertEqual(asst_read.metadata, {'k': 'v1'}) 58 | self.assertEqual(asst_read.model, 'newmodel') 59 | self.assertEqual(asst_read.description, 'desc') 60 | self.assertEqual(asst_read.instructions, 'instruction') 61 | self.assertEqual(asst_read.tools, [{'type': 'retrieval'}]) 62 | self.assertEqual(asst_read.file_ids, ['1']) 63 | 64 | res = assistants.delete(id=asst_read.id, user_id=user_id, session=self.session) 65 | self.assertEqual(res.object, 'assistant.deleted') 66 | asst_read = assistants.get(id=asst_created.id, user_id=user_id, session=self.session) 67 | self.assertIsNone(asst_read) 68 | 69 | def test_create_get_list_assistant_with_user_id(self): 70 | self.test_create_get_list_assistant(user_id='shellc') 71 | -------------------------------------------------------------------------------- /myla/threads.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Optional, Union 3 | 4 | from sqlmodel import Session 5 | 6 | from . 
import _models 7 | from .messages import Message 8 | 9 | 10 | class ThreadEdit(_models.MetadataModel): 11 | # message list 12 | pass 13 | 14 | 15 | class ThreadCreate(ThreadEdit): 16 | pass 17 | 18 | 19 | class ThreadModify(ThreadEdit): 20 | pass 21 | 22 | 23 | class ThreadRead(_models.ReadModel): 24 | pass 25 | 26 | 27 | class ThreadList(_models.ListModel): 28 | data: List[ThreadRead] = [] 29 | 30 | 31 | class Thread(_models.DBModel, table=True): 32 | """ 33 | Represents an assistant that can call the model and use tools. 34 | """ 35 | 36 | 37 | @_models.auto_session 38 | def create( 39 | thread: ThreadCreate, 40 | tag: Optional[str] = None, 41 | user_id: Optional[str] = None, 42 | org_id: Optional[str] = None, 43 | session: Optional[Session] = None 44 | ) -> ThreadRead: 45 | db_model = Thread.model_validate(thread) 46 | 47 | dbo = _models.create( 48 | object="thread", 49 | meta_model=thread, 50 | db_model=db_model, 51 | tag=tag, 52 | user_id=user_id, 53 | org_id=org_id, 54 | session=session 55 | ) 56 | return dbo.to_read(ThreadRead) 57 | 58 | 59 | @_models.auto_session 60 | def get(id: str, user_id: str = None, session: Session = None) -> Union[ThreadRead, None]: 61 | return _models.get(db_cls=Thread, read_cls=ThreadRead, id=id, user_id=user_id, session=session) 62 | 63 | 64 | @_models.auto_session 65 | def modify(id: str, thread: ThreadEdit, user_id: str = None, session: Session = None): 66 | return _models.modify(db_cls=Thread, read_cls=ThreadRead, id=id, to_update=thread.model_dump(exclude_unset=True), user_id=user_id, session=session) 67 | 68 | 69 | @_models.auto_session 70 | def delete(id: str, user_id: str = None, mode="soft", session: Optional[Session] = None) -> _models.DeletionStatus: 71 | dbo = session.get(Thread, id) 72 | if dbo and (not user_id or user_id == dbo.user_id): 73 | deleted_at = int(datetime.now().timestamp()*1000) 74 | if mode is not None and mode == 'soft': 75 | dbo.is_deleted = True 76 | dbo.deleted_at = deleted_at 77 | session.add(dbo) 
78 | 79 | session.query(Message).where(Message.thread_id == id).update({Message.is_deleted: True, Message.deleted_at: deleted_at}) 80 | 81 | session.commit() 82 | session.refresh(dbo) 83 | else: 84 | session.delete(dbo) 85 | session.query(Message).where(Message.thread_id == id).delete() 86 | session.commit() 87 | return _models.DeletionStatus(id=id, object="thread.deleted", deleted=True) 88 | 89 | 90 | @_models.auto_session 91 | def list( 92 | limit: int = 20, 93 | order: str = "desc", 94 | after: Optional[str] = None, 95 | before: Optional[str] = None, 96 | tag: Optional[str] = None, 97 | user_id: Optional[str] = None, 98 | org_id: Optional[str] = None, 99 | session: Session = None 100 | ) -> ThreadList: 101 | return _models.list( 102 | db_cls=Thread, 103 | read_cls=ThreadRead, 104 | list_cls=ThreadList, 105 | limit=limit, 106 | order=order, 107 | after=after, 108 | before=before, 109 | tag=tag, 110 | user_id=user_id, 111 | org_id=org_id, 112 | session=session 113 | ) 114 | -------------------------------------------------------------------------------- /myla/vectorstores/chromadb_vectorstore.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, List, Optional, Dict 3 | from ._base import Record, VectorStore 4 | from ._embeddings import Embeddings 5 | from .. import utils 6 | 7 | 8 | class Chromadb(VectorStore): 9 | def __init__(self, path, embeddings: Embeddings = None) -> None: 10 | super().__init__() 11 | 12 | try: 13 | import chromadb 14 | except ImportError as exc: 15 | raise ImportError( 16 | "Could not import chromadb python package. " 17 | "Please install it with `pip install chromadb`." 
18 | ) from exc 19 | self._embeddings = embeddings 20 | self._db = chromadb.PersistentClient(path=path) 21 | 22 | def create_collection(self, collection: str, schema: Dict[str, type] = None, mode="create"): 23 | """Create a new collection""" 24 | self._db.create_collection(name=collection) 25 | 26 | def add( 27 | self, 28 | collection: str, 29 | records: List[Record], 30 | embeddings_columns: List[str] = None, 31 | vectors: List[List[float]] = None, 32 | **kwargs 33 | ): 34 | """Add record to the vectorsotre""" 35 | col = self._db.get_collection(name=collection) 36 | 37 | ids = [] 38 | text_to_embed = [] 39 | for r in records: 40 | ids.append(utils.random_id()) 41 | text_to_embed.append(Record.values_to_text(r, props=embeddings_columns)) 42 | 43 | if not vectors: 44 | vectors = self._embeddings.embed_batch(texts=text_to_embed) 45 | 46 | if len(vectors) != len(records): 47 | raise ValueError("The length of records must be the same as the length of vecotors.") 48 | 49 | batch_size = 40000 50 | batchs = math.ceil(len(records) / batch_size) 51 | for i in range(batchs): 52 | b = i*batch_size 53 | e = (i+1)*batch_size 54 | if len(ids[b:e]) == 0: 55 | break 56 | col.add(ids=ids[b:e], embeddings=vectors[b:e], documents=text_to_embed[b:e], metadatas=records[b:e]) 57 | 58 | def delete(self, collection: str, query: str): 59 | """Delete record from the vectorstore""" 60 | 61 | def search( 62 | self, 63 | collection: str = None, 64 | query: str = None, 65 | vector: List = None, 66 | filter: Any = None, 67 | limit: int = 20, 68 | columns: Optional[List[str]] = None, 69 | with_vector: bool = False, 70 | with_distance: bool = False, 71 | **kwargs 72 | ) -> Optional[List[Record]]: 73 | """Search records""" 74 | col = self._db.get_collection(name=collection) 75 | 76 | include = ['metadatas', 'documents'] 77 | if with_vector: 78 | include.append('embeddings') 79 | if with_distance: 80 | include.append('distances') 81 | 82 | res = col.query( 83 | query_embeddings=[vector] if vector 
else None, 84 | query_texts=[query] if query else None, 85 | n_results=limit, 86 | include=include 87 | ) 88 | 89 | result = [] 90 | i = 0 91 | for r in res['metadatas']: 92 | record = r[0] 93 | if with_vector: 94 | record['vecotor'] = res['embeddings'][i][0] 95 | if with_distance: 96 | record['_distantce'] = res['distances'][i][0] 97 | result.append(record) 98 | 99 | i += 1 100 | return result 101 | -------------------------------------------------------------------------------- /myla/_entry.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import importlib 4 | from starlette.applications import Starlette 5 | from starlette.responses import RedirectResponse 6 | from starlette.routing import Mount, Route 7 | from starlette.staticfiles import StaticFiles 8 | from starlette.middleware import Middleware 9 | from starlette.middleware.authentication import AuthenticationMiddleware 10 | from contextlib import asynccontextmanager 11 | 12 | from .persistence import Persistence 13 | from ._run_scheduler import RunScheduler 14 | from . import _tools 15 | from . import _env 16 | from ._api import api 17 | from .webui._web_template import render, get_templates 18 | from ._logging import logger 19 | from .vectorstores import load_loaders 20 | from . import webui 21 | from . import users 22 | from . import _auth 23 | from . 
import _version 24 | 25 | 26 | def import_extensions(): 27 | ext_dir = os.environ.get("EXT_DIR") 28 | if ext_dir: 29 | sys.path.append(ext_dir) 30 | 31 | entry_py = os.path.join(ext_dir, 'entry.py') 32 | if os.path.exists(entry_py): 33 | importlib.import_module('entry') 34 | 35 | 36 | @asynccontextmanager 37 | async def lifespan(app: Starlette): 38 | try: 39 | # Load extensions 40 | import_extensions() 41 | 42 | # Load tools 43 | _tools.load_tools() 44 | 45 | # Load loaders 46 | load_loaders() 47 | 48 | # Initialize database 49 | Persistence.default().initialize_database() 50 | 51 | # Create default super admin user 52 | sa = users.create_default_superadmin() 53 | if sa: 54 | logger.warn(f"Super admin user created: {sa.username}") 55 | 56 | # Start RunScheduler 57 | RunScheduler.default().start() 58 | 59 | except Exception as e: 60 | logger.error(f"Lifespan error: {e}", exc_info=e) 61 | finally: 62 | yield 63 | # on shutdown 64 | pass 65 | 66 | # Routes 67 | routes = [ 68 | Mount( 69 | '/api', 70 | name='api', 71 | app=api 72 | ), 73 | Mount( 74 | '/webui/statics', 75 | name='statics', 76 | app=StaticFiles(directory=os.path.join(_env.webui_dir(), 'statics'), check_dir=False), 77 | ), 78 | Route( 79 | '/webui/', 80 | name='webui', 81 | endpoint=render('index.html', context={'version': _version.VERSION}) 82 | ), 83 | Route( 84 | '/', 85 | name='home', 86 | endpoint=lambda r: RedirectResponse('/webui') 87 | ), 88 | Route( 89 | '/assistants/{assistant_id}', 90 | name='assistant', 91 | endpoint=webui.assistant 92 | ) 93 | ] 94 | 95 | # Middlewares 96 | middleware = [ 97 | Middleware(AuthenticationMiddleware, backend=_auth.BasicAuthBackend()) 98 | ] 99 | 100 | entry = Starlette(debug=False, routes=routes, middleware=middleware, lifespan=lifespan) 101 | 102 | 103 | def register(path, name, endpoint): 104 | entry.routes.insert(0, Route(path=path, name=name, endpoint=endpoint)) 105 | 106 | 107 | def register_webui(webui_dir): 108 | webui_dir = os.path.abspath(webui_dir) 109 
| templates = get_templates(webui_dir) 110 | 111 | entry.routes.insert(0, Mount( 112 | '/static', 113 | name='static', 114 | app=StaticFiles(directory=os.path.join(webui_dir, 'static'), check_dir=False), 115 | )) 116 | 117 | entry.routes.insert(0, Route( 118 | '/', 119 | name='home', 120 | endpoint=render('index.html', templates=templates) 121 | )) 122 | 123 | 124 | if os.environ.get("WEBUI"): 125 | webui_dir = os.environ.get("WEBUI") 126 | register_webui(webui_dir=webui_dir) 127 | -------------------------------------------------------------------------------- /myla/retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from .vectorstores import get_default_vectorstore 3 | from .tools import Tool, Context 4 | from ._logging import logger 5 | from . import llms 6 | 7 | RETRIEVAL_INSTRUCTIONS_EN = """ 8 | Refer to the retrievals to generate your answer. 9 | """ 10 | 11 | RETRIEVAL_INSTRUCTIONS_ZH = """ 12 | 参考 Retrievals 信息生成你的回答。 13 | """ 14 | 15 | 16 | class RetrievalTool(Tool): 17 | def __init__(self) -> None: 18 | super().__init__() 19 | self._vs = get_default_vectorstore() 20 | 21 | async def execute(self, context: Context) -> None: 22 | if len(context.messages) == 0: 23 | logger.debug("History is empty, skip retrieval") 24 | return 25 | 26 | collections = context.file_ids if context.file_ids else [] 27 | 28 | if 'retrieval_collection' in context.run_metadata: 29 | collections.append(context.run_metadata["retrieval_collection"]) 30 | 31 | if len(collections) == 0: 32 | logger.debug( 33 | "no retrieval collections specified, skip retrieval") 34 | return 35 | 36 | args = {"limit": 20, "with_distance": True} 37 | if "retrieval_limit" in context.run_metadata: 38 | args["limit"] = context.run_metadata["retrieval_limit"] 39 | distance = 1 40 | if "retrieval_distance" in context.run_metadata: 41 | distance = context.run_metadata['retrieval_distance'] 42 | 43 | query = context.messages[-1]["content"] 44 | 45 | 
docs = [] 46 | for c in collections: 47 | r_docs = await self._vs.asearch(collection=c, query=query, **args) 48 | for doc in r_docs: 49 | if doc['_distance'] < distance: 50 | docs.append(doc) 51 | 52 | logger.debug("Retrieval docs:" + json.dumps(docs, ensure_ascii=False)) 53 | if docs and len(docs) > 0: 54 | messages = context.messages 55 | last_message = messages[-1] 56 | messages = messages[:-1] 57 | 58 | messages.append({ 59 | "role": "system", 60 | "content": RETRIEVAL_INSTRUCTIONS_ZH, 61 | }) 62 | messages.append({ 63 | "role": "system", 64 | "content": "" 65 | }) 66 | messages.append({ 67 | "role": "system", 68 | "content": json.dumps(docs, ensure_ascii=False), 69 | "type": "docs" 70 | }) 71 | messages.append({ 72 | "role": "system", 73 | "content": "" 74 | }) 75 | messages.append(last_message) 76 | context.messages = messages 77 | 78 | 79 | DOC_SUMMARY_INSTRUCTIONS_ZH = """ 80 | 你是专业的问答分析助手。下面是JSON格式的问答记录。 81 | <问答记录开始> 82 | {docs} 83 | <问答记录介绍> 84 | 85 | 请根据问答记录生成新问题的候选回答。 86 | 新问题: {query} 87 | 候选回答: 88 | """ 89 | 90 | 91 | class DocSummaryTool(Tool): 92 | async def execute(self, context: Context) -> None: 93 | if len(context.messages) == 0: 94 | return 95 | 96 | last_message = context.messages[-1]['content'] 97 | 98 | docs = None 99 | 100 | for msg in context.messages: 101 | if msg.get('type') == 'docs': 102 | docs = msg 103 | 104 | if docs: 105 | summary = await llms.get().chat(messages=[{ 106 | "role": "system", 107 | "content": DOC_SUMMARY_INSTRUCTIONS_ZH.format(docs=docs['content'], query=last_message) 108 | }], stream=False, temperature=0) 109 | if summary: 110 | docs['content'] = summary 111 | -------------------------------------------------------------------------------- /tests/myla/users_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from myla import users, persistence 4 | 5 | 6 | class TestUsers(unittest.TestCase): 7 | 8 | def setUp(self) -> None: 9 | self.db = 
persistence.Persistence(database_url="sqlite://") 10 | self.db.initialize_database() 11 | self.session = self.db.create_session() 12 | 13 | def tearDown(self) -> None: 14 | self.session.close() 15 | 16 | def test_list_sa_users(self): 17 | sa_users = users.list_sa_users(session=self.session) 18 | self.assertIsInstance(sa_users, users.UserList) 19 | self.assertCountEqual(sa_users.data, []) 20 | 21 | def test_create_organization(self): 22 | org = users.create_organization(users.OrganizationCreate(), session=self.session) 23 | self.assertIsInstance(org, users.OrganizationRead) 24 | self.assertIsNotNone(org.id) 25 | 26 | org_read = users.get_organization(id=org.id, session=self.session) 27 | self.assertEqual(org_read.id, org.id) 28 | self.assertFalse(org_read.is_primary) 29 | 30 | def test_create_organization_with_metadata(self): 31 | org = users.create_organization(users.OrganizationCreate(display_name='org', metadata={'k': 'v'}), session=self.session) 32 | self.assertIsInstance(org, users.OrganizationRead) 33 | self.assertEqual(org.display_name, 'org') 34 | self.assertEqual(org.metadata, {'k': 'v'}) 35 | 36 | def test_create_organization_without_auto_commit(self): 37 | org = users.create_organization(users.OrganizationCreate(), session=self.session, auto_commit=False) 38 | self.assertIsInstance(org, users.OrganizationRead) 39 | self.assertIsNotNone(org.id) 40 | 41 | org_read = users.get_organization(id=org.id, session=self.db.create_session()) 42 | self.assertIsNone(org_read) 43 | 44 | def test_create_user(self): 45 | user = users.create_user(user=users.UserCreate(username='shellc', password='shellc'), session=self.session) 46 | 47 | self.assertIsInstance(user, users.UserRead) 48 | 49 | user_read = users.get_user(id=user.id, session=self.session) 50 | self.assertEqual(user.id, user_read.id) 51 | 52 | user_dbo = users.get_user_dbo(id=user.id, session=self.session) 53 | self.assertEqual(user_dbo.password, users.generate_password('shellc', user_dbo.salt)) 54 | 55 | orgs 
= users.list_orgs(user_id=user_read.id, session=self.session) 56 | self.assertIsInstance(orgs, users.OrganizationList) 57 | self.assertEqual(len(orgs.data), 1) 58 | self.assertEqual(orgs.data[0].display_name, user.display_name) 59 | self.assertEqual(orgs.data[0].is_primary, True) 60 | #self.assertEqual(orgs.data[0].user_id, user.id) 61 | 62 | def test_create_secret_key(self): 63 | sk = users.create_secret_key(key=users.SecrectKeyCreate(tag='web'), user_id='shellc', session=self.session) 64 | self.assertIsInstance(sk, users.SecretKeyRead) 65 | self.assertIsNotNone(sk.id) 66 | self.assertEqual(sk.tag, 'web') 67 | self.assertEqual(sk.user_id, 'shellc') 68 | 69 | sk = users.get_secret_key(id=sk.id, session=self.session) 70 | self.assertIsInstance(sk, users.SecretKeyRead) 71 | self.assertIsNotNone(sk.id) 72 | self.assertEqual(sk.tag, 'web') 73 | self.assertEqual(sk.user_id, 'shellc') 74 | 75 | sks = users.list_secret_keys(user_id='shellc', session=self.session) 76 | self.assertEqual(len(sks.data), 1) 77 | sk = sks.data[0] 78 | self.assertIsInstance(sk, users.SecretKeyRead) 79 | self.assertIsNotNone(sk.id) 80 | self.assertEqual(sk.tag, 'web') 81 | self.assertEqual(sk.user_id, 'shellc') 82 | 83 | def create_default_superadmin(self): 84 | sa = users.create_default_superadmin(session=self.session) 85 | self.assertEqual(sa.username, 'admin') 86 | 87 | sa = users.create_default_superadmin(session=self.session) 88 | self.assertIsNone(sa) 89 | -------------------------------------------------------------------------------- /js/src/secret_key.js: -------------------------------------------------------------------------------- 1 | import React, { useEffect } from 'react'; 2 | import { useState } from 'react' 3 | import { Alert, Button, Input, Table } from 'antd'; 4 | import { DeleteOutlined, PlusOutlined } from '@ant-design/icons'; 5 | import Cookies from 'js-cookie'; 6 | 7 | export const SecretKeySettings = (props) => { 8 | const [keys, setKeys] = useState(); 9 | const 
[createdKey, setCreatedKey] = useState(); 10 | 11 | const loadSecretKeys = () => { 12 | fetch('/api/v1/secret_keys') 13 | .then(r => { 14 | if (r.status === 200) { 15 | return r.json(); 16 | } else if (r.status === 403) { 17 | Cookies.remove('secret_key'); 18 | window.location.href = '/'; 19 | } 20 | }) 21 | .then(data => { 22 | let sks = data.data; 23 | for (let i = 0; i < sks.length; i ++) { 24 | sks[i].key = i; 25 | } 26 | setKeys(sks); 27 | }) 28 | } 29 | 30 | const onDelete = (id) => { 31 | fetch(`/api/v1/secret_keys/${id}`, { 32 | method: 'DELETE' 33 | }).then(r => { 34 | if (r.status === 200) { 35 | loadSecretKeys(); 36 | } else { 37 | throw new Error("Status: " +r.status) 38 | } 39 | }) 40 | } 41 | 42 | const onCreate = () => { 43 | fetch('/api/v1/secret_keys', { 44 | method: 'POST', 45 | headers: {"Content-Type": "application/json"} 46 | }).then(r => r.json()) 47 | .then(data => { 48 | setCreatedKey(data.id); 49 | }) 50 | } 51 | 52 | const onCloseCreatedResult = () => { 53 | setCreatedKey(null); 54 | loadSecretKeys(); 55 | } 56 | 57 | useEffect(() => { 58 | loadSecretKeys(); 59 | }, []); 60 | 61 | return ( 62 |
63 | 71 | {createdKey ? ( 72 | 76 |

Please save this secret key somewhere safe and accessible. For security reasons, you won't be able to view it again through your account. If you lose this secret key, you'll need to generate a new one.

77 | 78 |
} 79 | type="warning" showIcon closable style={{marginBottom: 10}}/> 80 | ) : null} 81 | 82 | ( 104 | onDelete(r.id)} /> 105 | ) 106 | } 107 | ]} 108 | locale={{ emptyText: ' ' }} 109 | dataSource={keys} 110 | pagination={false} 111 | /> 112 | 113 | 114 | ) 115 | } -------------------------------------------------------------------------------- /tests/myla/vectorstores/faiss_group_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import unittest 4 | from myla.vectorstores.faiss_group import FAISSGroup 5 | from myla.utils import random_id, sha256 6 | 7 | here = os.path.abspath(os.path.dirname(__file__)) 8 | 9 | 10 | class FAISSGroupTests(unittest.TestCase): 11 | def setUp(self) -> None: 12 | self._vectors = [ 13 | [0, 0], 14 | [1, 1], 15 | [2, 2], 16 | [3, 3], 17 | [4, 4] 18 | ] 19 | 20 | self._records = [ 21 | { 22 | 'id': 0, 23 | 'gid': 'g0', 24 | }, 25 | { 26 | 'id': 1, 27 | 'gid': 'g0', 28 | }, 29 | { 30 | 'id': 2, 31 | 'gid': 'g2', 32 | }, 33 | { 34 | 'id': 3, 35 | 'gid': 'g3', 36 | }, 37 | { 38 | 'id': 4, 39 | } 40 | ] 41 | 42 | self._data = os.path.abspath(os.path.join(here, os.pardir, os.pardir, 'data', random_id())) 43 | 44 | def tearDown(self) -> None: 45 | if os.path.exists(self._data): 46 | shutil.rmtree(self._data) 47 | pass 48 | 49 | def test_create_collection(self): 50 | vs = FAISSGroup(path=self._data) 51 | vs.create_collection(collection='col') 52 | 53 | def test_add(self): 54 | vs = FAISSGroup(path=self._data) 55 | vs.create_collection(collection='col') 56 | 57 | vs.add(collection='col', records=self._records, vectors=self._vectors, group_by='gid') 58 | self.assertIsNotNone(vs._data.get('col')) 59 | self.assertEqual(vs._data.get('col'), self._records) 60 | self.assertIsNotNone(vs._indexes.get('col')) 61 | self.assertIsNotNone(vs._ids.get('col')) 62 | 63 | self.assertEqual(len(vs._indexes.get('col')), 4) 64 | self.assertEqual(len(vs._ids.get('col')), 4) 65 | 66 | 
self.assertEqual(vs._indexes.get('col').keys(), vs._ids.get('col').keys()) 67 | 68 | gids = list(vs._indexes.get('col').keys()) 69 | gids.sort() 70 | gids_1 = [] 71 | for r in self._records: 72 | gids_1.append(sha256(r.get('gid', '').encode()).hex()) 73 | gids_1 = list(set(gids_1)) 74 | gids_1.sort() 75 | 76 | self.assertEqual(gids, gids_1) 77 | 78 | self.assertEqual(vs._ids.get('col')[vs._group_id("")], [4]) 79 | self.assertEqual(vs._ids.get('col')[vs._group_id("g0")], [0, 1]) 80 | self.assertEqual(vs._ids.get('col')[vs._group_id("g2")], [2]) 81 | self.assertEqual(vs._ids.get('col')[vs._group_id("g3")], [3]) 82 | 83 | vs.add(collection='col', records=[{'gid': 'g2'}, {'gid': 'g3', 'id': 6}], vectors=[[5, 5], [6, 6]], group_by='gid') 84 | self.assertEqual(vs._ids.get('col')[vs._group_id("g2")], [2, 5]) 85 | self.assertEqual(vs._ids.get('col')[vs._group_id("g3")], [3, 6]) 86 | 87 | self.assertEqual(vs._data.get('col')[6]['id'], 6) 88 | 89 | def test_load(self): 90 | vs = FAISSGroup(path=self._data) 91 | vs.create_collection(collection='col') 92 | 93 | vs.add(collection='col', records=self._records, vectors=self._vectors, group_by='gid') 94 | 95 | vs._unload(collection='col') 96 | self.assertIsNone(vs._data.get('col')) 97 | 98 | vs._load(collection='col') 99 | self.assertIsNotNone(vs._data.get('col')) 100 | self.assertEqual(vs._data.get('col'), self._records) 101 | 102 | def test_search(self): 103 | vs = FAISSGroup(path=self._data) 104 | vs.create_collection(collection='col') 105 | 106 | vs.add(collection='col', records=self._records, vectors=self._vectors, group_by='gid') 107 | 108 | records = vs.search(collection='col', vector=self._vectors[0], group_ids=['g0']) 109 | self.assertEqual(records[0]['id'], 0) 110 | self.assertEqual(records[0]['_distance'], 0.0) 111 | 112 | records = vs.search(collection='col', vector=self._vectors[1], group_ids=['g0']) 113 | self.assertEqual(records[0]['id'], 1) 114 | self.assertEqual(records[0]['_distance'], 0.0) 115 | 116 | records = 
vs.search(collection='col', vector=self._vectors[0], group_ids=['g2']) 117 | self.assertEqual(records[0]['id'], 2) 118 | self.assertGreaterEqual(records[0]['_distance'], 0.5) 119 | 120 | records = vs.search(collection='col', vector=self._vectors[0], group_ids=None) 121 | self.assertEqual(records[0]['id'], 4) 122 | self.assertGreaterEqual(records[0]['_distance'], 0.5) 123 | 124 | records = vs.search(collection='col', vector=self._vectors[1], group_ids=['g0', 'g2', 'g0', None]) 125 | self.assertEqual(records[0]['id'], 1) 126 | self.assertEqual(records[0]['_distance'], 0.0) 127 | 128 | def test_group_id(self): 129 | vs = FAISSGroup(path=self._data) 130 | print(vs._group_id()) -------------------------------------------------------------------------------- /myla/llms/openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | import openai 5 | 6 | from .. import utils 7 | from .backend import LLM 8 | 9 | 10 | class OpenAI(LLM): 11 | def __init__(self, model=None, api_key=None, base_url=None) -> None: 12 | super().__init__(model) 13 | self.api_key = api_key 14 | self.base_url = base_url 15 | 16 | async def chat(self, messages: List[Dict], model=None, stream=False, **kwargs): 17 | if "api_key" not in kwargs: 18 | kwargs["api_key"] = self.api_key 19 | if "base_url" not in kwargs: 20 | kwargs["base_url"] = self.base_url 21 | 22 | return await chat( 23 | messages=messages, 24 | model=model if model else self.model, 25 | stream=stream, 26 | **kwargs 27 | ) 28 | 29 | async def generate(self, instructions: str, model=None, stream=False, **kwargs): 30 | if not model: 31 | model = self.model 32 | return await generate(instructions=instructions, model=model, stream=stream, **kwargs) 33 | 34 | def sync_chat(self, messages: List[Dict], model=None, stream=False, **kwargs): 35 | if "api_key" not in kwargs: 36 | kwargs["api_key"] = self.api_key 37 | if "base_url" not in kwargs: 38 | 
kwargs["base_url"] = self.base_url 39 | 40 | return sync_chat( 41 | messages=messages, 42 | model=model if model else self.model, 43 | stream=stream, 44 | **kwargs 45 | ) 46 | 47 | def sync_generate(self, instructions: str, model=None, stream=False, **kwargs): 48 | if not model: 49 | model = self.model 50 | return sync_generate(instructions=instructions, model=model, stream=stream, **kwargs) 51 | 52 | 53 | async def chat(messages: List[Dict], model=None, stream=False, api_key=None, base_url=None, **kwargs): 54 | if not api_key: 55 | api_key = os.environ.get("LLM_API_KEY") 56 | if not base_url: 57 | base_url = os.environ.get("LLM_ENDPOINT") 58 | if not model: 59 | model = os.environ.get("DEFAULT_LLM_MODEL_NAME") 60 | 61 | llm = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) 62 | 63 | usage = None 64 | if "usage" in kwargs: 65 | usage = kwargs.pop("usage") 66 | 67 | @utils.retry 68 | async def _call(): 69 | resp = await llm.chat.completions.create( 70 | model=model, 71 | messages=messages, 72 | stream=stream, 73 | **kwargs 74 | ) 75 | if stream: 76 | async def iter(): 77 | async for r in resp: 78 | yield r.choices[0].delta.content if r.choices else 'Unexpected LLM error, possibly due to context being too long.' 
79 | return iter() 80 | else: 81 | genereated = resp.choices[0].message.content 82 | if usage: 83 | usage.prompt_tokens = resp.usage.prompt_tokens 84 | usage.completion_tokens = resp.usage.completion_tokens 85 | return genereated 86 | 87 | return await _call() 88 | 89 | 90 | async def generate(instructions: str, model=None, stream=False, **kwargs): 91 | r = await chat(messages=[{ 92 | "role": "system", 93 | "content": instructions 94 | }], model=model, stream=stream, **kwargs) 95 | return r 96 | 97 | 98 | def sync_chat(messages: List[Dict], model=None, stream=False, api_key=None, base_url=None, **kwargs): 99 | if not api_key: 100 | api_key = os.environ.get("LLM_API_KEY") 101 | if not base_url: 102 | base_url = os.environ.get("LLM_ENDPOINT") 103 | if not model: 104 | model = os.environ.get("DEFAULT_LLM_MODEL_NAME") 105 | 106 | llm = openai.OpenAI(api_key=api_key, base_url=base_url) 107 | 108 | usage = None 109 | if "usage" in kwargs: 110 | usage = kwargs.pop("usage") 111 | 112 | @utils.retry 113 | def _call(): 114 | resp = llm.chat.completions.create( 115 | model=model, 116 | messages=messages, 117 | stream=stream, 118 | **kwargs 119 | ) 120 | if stream: 121 | def iter(): 122 | for r in resp: 123 | yield r.choices[0].delta.content if r.choices else 'Unexpected LLM error, possibly due to context being too long.' 
124 | return iter() 125 | else: 126 | genereated = resp.choices[0].message.content 127 | if usage: 128 | usage.prompt_tokens = resp.usage.prompt_tokens 129 | usage.completion_tokens = resp.usage.completion_tokens 130 | return genereated 131 | 132 | return _call() 133 | 134 | 135 | def sync_generate(instructions: str, model=None, stream=False, **kwargs): 136 | r = sync_chat(messages=[{ 137 | "role": "system", 138 | "content": instructions 139 | }], model=model, stream=stream, **kwargs) 140 | return r 141 | -------------------------------------------------------------------------------- /myla/vectorstores/lancedb_vectorstore.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Dict 2 | from ._base import Record, VectorStore 3 | from ._embeddings import Embeddings 4 | 5 | VECTOR_COLUMN_NAME = "_vector" 6 | 7 | 8 | class LanceDB(VectorStore): 9 | def __init__(self, db_uri, embeddings: Embeddings = None) -> None: 10 | super().__init__() 11 | 12 | try: 13 | import pyarrow as pa 14 | pa.__version__ 15 | except ImportError as exc: 16 | raise ImportError( 17 | "Could not import pyarrow python package. " 18 | "Please install it with `pip install pyarrow`." 19 | ) from exc 20 | 21 | try: 22 | import lancedb as lancedb 23 | 24 | # disable diagnostics 25 | lancedb.utils.CONFIG['diagnostics'] = False 26 | except ImportError as exc: 27 | raise ImportError( 28 | "Could not import lancedb python package. " 29 | "Please install it with `pip install lancedb`." 
30 | ) from exc 31 | 32 | self._db_uri = db_uri 33 | self._embeddings = embeddings 34 | 35 | self._db = lancedb.connect(self._db_uri) 36 | 37 | self._tables = {} 38 | 39 | def create_collection(self, collection: str, schema: Dict[str, type] = None, mode="create"): 40 | if schema is None: 41 | raise ValueError("Invalid schema to create LanceDB table.") 42 | 43 | s = self._convert_schema(schema=schema) 44 | 45 | self._db.create_table(collection, schema=s, mode=mode) 46 | 47 | def add( 48 | self, 49 | collection: str, 50 | records: List[Record], 51 | embeddings_columns: List[str] = None, 52 | vectors: List[List[float]] = None, 53 | **kwargs 54 | ): 55 | tbl = self._db.open_table(collection) 56 | 57 | if not vectors: 58 | text_to_embed = [] 59 | for r in records: 60 | text_to_embed.append(Record.values_to_text(r, props=embeddings_columns)) 61 | 62 | vectors = self._embeddings.embed_batch(texts=text_to_embed) 63 | 64 | if len(vectors) != len(records): 65 | raise ValueError("The length of records must be the same as the length of vecotors.") 66 | 67 | for i in range(len(records)): 68 | records[i][VECTOR_COLUMN_NAME] = vectors[i] 69 | 70 | tbl.add(records) 71 | 72 | def delete(self, collection: str, query: str): 73 | tbl = self._db.open_table(collection) 74 | tbl.delete(query) 75 | 76 | def search( 77 | self, 78 | collection: str = None, 79 | query: str = None, 80 | vector: List = None, 81 | filter: Any = None, 82 | limit: int = 20, 83 | columns: Optional[List[str]] = None, 84 | with_vector: bool = False, 85 | with_distance: bool = False, 86 | **kwargs 87 | ) -> List[Record]: 88 | if not query and not vector: 89 | raise ValueError("LanceDB search must provide query or vector.") 90 | 91 | if query and not vector and self._embeddings: 92 | vector = self._embeddings.embed(text=query) 93 | if not vector: 94 | raise ValueError( 95 | "LanceDB search must provide Embeddings function.") 96 | 97 | tbl = self._db.open_table(collection) 98 | 99 | query = tbl.search(vector, 
vector_column_name=VECTOR_COLUMN_NAME) 100 | if filter: 101 | query = query.where(filter) 102 | 103 | if columns: 104 | query = query.select(columns=columns) 105 | 106 | results = query.limit(limit=limit).to_list() 107 | for v in results: 108 | if not with_vector: 109 | del v[VECTOR_COLUMN_NAME] 110 | if not with_distance: 111 | del v['_distance'] 112 | 113 | return results 114 | 115 | def _convert_schema(self, schema: Dict[str, type]): 116 | try: 117 | import pyarrow as pa 118 | except ImportError as exc: 119 | raise ImportError( 120 | "Could not import pyarrow python package. " 121 | "Please install it with `pip install pyarrow`." 122 | ) from exc 123 | 124 | dims = len(self._embeddings.embed("")) 125 | columns = [ 126 | pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), dims)), 127 | ] 128 | 129 | for k, v in schema.items(): 130 | t = pa.string() 131 | 132 | if isinstance(v, float): 133 | t = pa.float64() 134 | if isinstance(v, int): 135 | t = pa.int64() 136 | if isinstance(v, bool): 137 | t = pa.bool_() 138 | 139 | columns.append( 140 | pa.field(k, t) 141 | ) 142 | 143 | s = pa.schema(columns) 144 | 145 | return s 146 | -------------------------------------------------------------------------------- /tests/perf/vs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import pickle 4 | import pandas as pd 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from myla import utils 8 | from myla.vectorstores import get_default_embeddings, get_default_vectorstore, pandas_loader 9 | 10 | here = os.path.abspath(os.path.dirname(__file__)) 11 | 12 | os.environ['EMBEDDINGS_IMPL'] = 'sentence_transformers' 13 | os.environ['EMBEDDINGS_MODEL_NAME'] = '/Users/shellc/Downloads/bge-large-zh-v1.5' 14 | os.environ['EMBEDDINGS_DEVICE'] = 'mps' 15 | 16 | os.environ['VECTORSTORE_DIR'] = os.path.join(here, 'vs') 17 | os.environ['VECTOR_STORE_IMPL'] = 'faissg' 18 | 19 | data_input = 
os.path.join(here, 'bq1.csv') 20 | embeds_output = os.path.join(here, 'embeds.pkl') 21 | 22 | embeddings = get_default_embeddings() 23 | vs = get_default_vectorstore() 24 | 25 | 26 | def load_records(): 27 | return list(pandas_loader.PandasLoader().load(data_input)) 28 | 29 | 30 | def records_stat(): 31 | df = pd.read_csv(data_input) 32 | df['len'] = df.apply(lambda x : len(x['sentence1']) + len(x['sentence2']), axis=1) 33 | df['len1'] = df['sentence1'].apply(lambda x : len(x)) 34 | s = pd.concat([df['len'].describe(), df['len1'].describe()], axis=1) 35 | print(s) 36 | 37 | 38 | def embed(): 39 | embedings = get_default_embeddings() 40 | records = load_records() 41 | 42 | texts = [r['sentence1'] for r in records] 43 | 44 | embeds = embedings.embed_batch(texts=texts) 45 | with open(embeds_output, 'wb') as f: 46 | pickle.dump(embeds, f) 47 | 48 | 49 | def test_embed(): 50 | embed() 51 | 52 | 53 | def create_col(records=None, vectors=None): 54 | if not records: 55 | records = load_records() 56 | if not vectors: 57 | vectors = pickle.load(open(embeds_output, 'rb')) 58 | 59 | print(f"length: record={len(records)}, vectors={len(vectors)}") 60 | 61 | col_name = utils.random_id() 62 | vs.create_collection(collection=col_name, schema=records[0]) 63 | 64 | return col_name 65 | 66 | 67 | def test_add_batch(records=None, vectors=None): 68 | records = load_records() 69 | vectors = pickle.load(open(embeds_output, 'rb')) 70 | 71 | col_name = create_col(records=records, vectors=vectors) 72 | 73 | begin = datetime.now().timestamp() 74 | vs.add(collection=col_name, records=records, vectors=vectors) 75 | end = datetime.now().timestamp() 76 | 77 | print(os.environ['VECTOR_STORE_IMPL'], "elapsed", end-begin) 78 | 79 | 80 | def test_add(): 81 | records = load_records() 82 | vectors = pickle.load(open(embeds_output, 'rb')) 83 | 84 | col_name = create_col(records=records, vectors=vectors) 85 | 86 | vs.add(collection=col_name, records=records, vectors=vectors) 87 | 88 | begin = 
datetime.now().timestamp() 89 | for i in range(1000): #range(len(records)): 90 | vs.add(collection=col_name, records=[records[i]], vectors=[vectors[i]]) 91 | end = datetime.now().timestamp() 92 | 93 | print(os.environ['VECTOR_STORE_IMPL'], "elapsed", end-begin) 94 | 95 | 96 | def test_search(): 97 | records = load_records() 98 | vectors = pickle.load(open(embeds_output, 'rb')) 99 | 100 | col_name = create_col(records=records, vectors=vectors) 101 | 102 | vs.add(collection=col_name, records=records, vectors=vectors, group_by='label') 103 | 104 | begin = datetime.now().timestamp() 105 | for i in range(1000): 106 | vs.search(collection=col_name, vector=vectors[i], group_ids=[0]) 107 | end = datetime.now().timestamp() 108 | 109 | print(os.environ['VECTOR_STORE_IMPL'], "elapsed", end-begin) 110 | 111 | 112 | def test_search_multi_trehads(): 113 | records = load_records() 114 | vectors = pickle.load(open(embeds_output, 'rb')) 115 | 116 | col_name = create_col(records=records, vectors=vectors) 117 | 118 | vs.add(collection=col_name, records=records, vectors=vectors) 119 | 120 | def _search(i): 121 | vs.search(collection=col_name, vector=vectors[i]) 122 | 123 | executor = ThreadPoolExecutor(max_workers=10) 124 | futures = [] 125 | 126 | begin = datetime.now().timestamp() 127 | for i in range(1000): 128 | futures.append(executor.submit(_search, i=i)) 129 | 130 | for f in futures: 131 | f.done() 132 | f.result() 133 | 134 | end = datetime.now().timestamp() 135 | 136 | print(os.environ['VECTOR_STORE_IMPL'], "elapsed", end-begin) 137 | 138 | 139 | def test_multi_vs(): 140 | records = load_records() 141 | vectors = pickle.load(open(embeds_output, 'rb')) 142 | 143 | cols = [] 144 | for i in range(100): 145 | rs = records[:100].copy() 146 | col_name = create_col(records=rs, vectors=vectors) 147 | vs.add(collection=col_name, records=rs, vectors=vectors[:100]) 148 | cols.append(col_name) 149 | 150 | import time 151 | time.sleep(100) 152 | 153 | 154 | if __name__ == '__main__': 155 
| #records = load_records() 156 | #vectors = pickle.load(open(embeds_output, 'rb')) 157 | 158 | #records_stat() 159 | 160 | #test_embed() 161 | 162 | #test_add_batch() 163 | #test_add() 164 | 165 | #test_search() 166 | #test_search_multi_trehads() 167 | 168 | #test_multi_vs() 169 | 170 | import time 171 | time.sleep(100000) 172 | -------------------------------------------------------------------------------- /myla/vectorstores/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Optional 4 | 5 | from .._logging import logger 6 | from ..utils import create_instance 7 | from . import pandas_loader, pdf_loader 8 | from ._base import Record, VectorStore 9 | from ._embeddings import Embeddings 10 | from .chromadb_vectorstore import Chromadb 11 | from .faiss_group import FAISSGroup 12 | from .faiss_vectorstore import FAISS 13 | from .lancedb_vectorstore import LanceDB 14 | 15 | _default_embeddings = None 16 | 17 | 18 | def get_default_embeddings(): 19 | global _default_embeddings 20 | 21 | if _default_embeddings is None: 22 | impl = os.environ.get("EMBEDDINGS_IMPL") 23 | model_name = os.environ.get("EMBEDDINGS_MODEL_NAME") 24 | device = os.environ.get("EMBEDDINGS_DEVICE") 25 | instruction = os.environ.get("EMBEDDINGS_INSTRUCTION") 26 | multi_process = os.environ.get("EMBEDDINGS_MULTI_PROCESS") 27 | multi_process_devices = os.environ.get("EMBEDDINGS_MULTI_PROCESS_DEVICES") 28 | 29 | if multi_process is not None and multi_process.lower() == "true": 30 | multi_process = True 31 | else: 32 | multi_process = False 33 | 34 | if multi_process_devices is not None: 35 | multi_process_devices = multi_process_devices.split(",") 36 | model_kwargs = {'device': device if device else "cpu"} 37 | 38 | if not impl or impl == 'sentence_transformers': 39 | from .sentence_transformers_embeddings import \ 40 | SentenceTransformerEmbeddings 41 | _default_embeddings = SentenceTransformerEmbeddings( 
42 | model_name=model_name, 43 | model_kwargs=model_kwargs, 44 | instruction=instruction, 45 | multi_process=multi_process, 46 | multi_process_devices=multi_process_devices 47 | ) 48 | elif impl == 'xinference': 49 | from .xinference_embeddings import XinferenceEmbeddings 50 | 51 | base_url = os.environ.get("XINFERENCE_BASE_URL") 52 | model_id = os.environ.get("XINFERENCE_MODEL_ID") 53 | 54 | _default_embeddings = XinferenceEmbeddings( 55 | base_url=base_url, 56 | model_id=model_id, 57 | instruction=instruction 58 | ) 59 | else: 60 | raise ValueError(f"Embedding implement not supported: {impl}") 61 | return _default_embeddings 62 | 63 | 64 | _default_vs = {} 65 | 66 | 67 | def get_default_vectorstore(): 68 | impl = os.environ.get("VECTOR_STORE_IMPL") 69 | if not impl: 70 | raise ValueError("VECTOR_STORE_IMPL is required.") 71 | 72 | vs = _default_vs.get(impl) 73 | if not vs: 74 | vs_dir = os.environ.get("VECTORSTORE_DIR") 75 | if not vs_dir: 76 | raise ValueError("VECTORSTORE_DIR is required.") 77 | 78 | embeddings = get_default_embeddings() 79 | 80 | if impl == 'faiss': 81 | vs = FAISS(db_path=vs_dir, embeddings=embeddings) 82 | elif impl == 'lancedb': 83 | vs = LanceDB(db_uri=vs_dir, embeddings=embeddings) 84 | elif impl == 'chromadb': 85 | vs = Chromadb(path=vs_dir, embeddings=embeddings) 86 | elif impl == 'faissg': 87 | vs = FAISSGroup(path=vs_dir, embeddings=embeddings) 88 | else: 89 | raise ValueError(f"VectorStore not suported: {impl}") 90 | _default_vs[impl] = vs 91 | return vs 92 | 93 | 94 | _loaders = {} 95 | 96 | 97 | def load_loaders(): 98 | """Load configured Loaders.""" 99 | loaders_cfg = os.environ.get("LOADERS") 100 | if not loaders_cfg: 101 | return 102 | 103 | loaders = json.loads(loaders_cfg) 104 | for c in loaders: 105 | try: 106 | name = c.get("name") 107 | impl = c.get("impl") 108 | if not name or not impl: 109 | logger.warn(f"Invalid Loader config: name={name}, impl={impl}") 110 | continue 111 | instance = create_instance(impl) 112 | 
_loaders[name] = instance 113 | except Exception as e: 114 | logger.warn(f"Create Loader failed: {e}", exc_info=e) 115 | 116 | 117 | def get_loader_instance(name: str): 118 | """Get configured Loader.""" 119 | return _loaders.get(name) 120 | 121 | 122 | def load_vectorstore_from_file(collection: str, fname: str, ftype: str, embeddings_columns=None, loader: Optional[str] = None, **kwargs): 123 | vs = get_default_vectorstore() 124 | 125 | if loader: 126 | loader_ = get_loader_instance(loader) 127 | elif ftype in ['csv', 'xls', 'xlsx', 'json']: 128 | loader_ = pandas_loader.PandasLoader(ftype=ftype) 129 | elif ftype == 'pdf': 130 | loader_ = pdf_loader.PDFLoader() 131 | else: 132 | raise ValueError("Invalid file type.") 133 | 134 | if not loader_: 135 | raise RuntimeError(f"Loader not found: {loader}") 136 | 137 | records = list(loader_.load(file=fname, metadata=kwargs.get("metadata"))) 138 | 139 | if len(records) == 0: 140 | return 141 | 142 | vs.create_collection(collection=collection, schema=records[0], mode='overwrite') 143 | 144 | vs.add( 145 | collection=collection, 146 | records=records, 147 | embeddings_columns=embeddings_columns, 148 | group_by=kwargs.get('group_by'), 149 | instruction=kwargs.get('instruction') 150 | ) 151 | -------------------------------------------------------------------------------- /js/src/user_admin.js: -------------------------------------------------------------------------------- 1 | import { DeleteOutlined, PlusOutlined } from "@ant-design/icons" 2 | import { Button, Form, Input, Space, Table, message } from "antd" 3 | import Link from "antd/es/typography/Link" 4 | import { useEffect, useState } from "react" 5 | import { getUser } from "./user" 6 | 7 | export const UserAdmin = () => { 8 | const [users, setUsers] = useState() 9 | const [msg, msgContext] = message.useMessage() 10 | const [formView, setFormView] = useState(false) 11 | const [createUserForm] = Form.useForm() 12 | 13 | let user = getUser() 14 | 15 | const loadUsers = () 
=> { 16 | fetch('/api/v1/users').then(r => { 17 | if (r.status === 200) { 18 | return r.json(); 19 | } else { 20 | throw new Error("Status: " + r.status) 21 | } 22 | }).then(data => { 23 | setUsers(data.data); 24 | }).catch(err => { 25 | msg.error(err.message); 26 | }) 27 | } 28 | 29 | const onDelete = (username) => { 30 | fetch(`/api/v1/users/${username}`, { 31 | method: 'DELETE' 32 | }).then(r => { 33 | if (r.status === 200) { 34 | msg.success("OK"); 35 | loadUsers(); 36 | } else { 37 | throw new Error("Status: " + r.status) 38 | } 39 | }).catch(err => { 40 | msg.error(err.message); 41 | }) 42 | } 43 | 44 | const onCreate = () => { 45 | let username = createUserForm.getFieldValue('username'); 46 | let password = createUserForm.getFieldValue('password'); 47 | 48 | fetch('/api/v1/users', { 49 | method: 'POST', 50 | headers: {'Content-Type': 'application/json'}, 51 | body: JSON.stringify({ 52 | username: username, 53 | password: password 54 | }) 55 | }).then(r => { 56 | if (r.status === 200) { 57 | onCancel(); 58 | loadUsers(); 59 | } else { 60 | throw new Error('Stauts: ' + r.status); 61 | } 62 | }).catch(err => { 63 | msg.error(err.message) 64 | }) 65 | } 66 | 67 | const onCancel = () => { 68 | setFormView(false); 69 | createUserForm.resetFields(); 70 | } 71 | 72 | useEffect(() => { 73 | loadUsers(); 74 | }, []) 75 | 76 | return ( 77 |
78 | {msgContext} 79 | {!formView ? ( 80 |
81 | 89 |
( 101 | user.username !== r.username ? onDelete(r.username)} /> : null 102 | ) 103 | } 104 | ]} 105 | locale={{ emptyText: ' ' }} 106 | dataSource={users} 107 | pagination={false} 108 | /> 109 | 110 | ) : ( 111 |
112 | < Back 113 |
118 | 125 | 128 | 129 | 130 | 137 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 |
148 | )} 149 | 150 | ) 151 | } -------------------------------------------------------------------------------- /myla/vectorstores/faiss_vectorstore.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Dict, Any 3 | from ._base import Record, VectorStore 4 | from ._embeddings import Embeddings 5 | from .._logging import logger 6 | 7 | 8 | def _import_langchain_vectorstores(): 9 | try: 10 | import langchain.vectorstores as vectorstores 11 | except ImportError as exc: 12 | raise ImportError( 13 | "Could not import langchain.vectorstores python package. " 14 | "Please install it with `pip install langchain`." 15 | ) from exc 16 | return vectorstores 17 | 18 | 19 | class FAISS(VectorStore): 20 | def __init__(self, db_path: str = None, embeddings: Embeddings = None) -> None: 21 | self._db_path = db_path 22 | self._embeddings = embeddings 23 | self._collections = {} 24 | 25 | def create_collection(self, collection: str, schema: Dict[str, type] = None, mode="create"): 26 | vectorstores = _import_langchain_vectorstores() 27 | vs = vectorstores.FAISS.from_texts(texts=[''], embedding=self.get_embeddings(), normalize_L2=True) 28 | vs.save_local(os.path.join(self._db_path, collection)) 29 | 30 | def add( 31 | self, 32 | collection: str, 33 | records: List[Record], 34 | embeddings_columns: List[str] = None, 35 | vectors: List[List[float]] = None, 36 | **kwargs 37 | ): 38 | vs = self._get_vectorstore(collection) 39 | 40 | text_to_embed = [] 41 | for r in records: 42 | text_to_embed.append(Record.values_to_text(r, props=embeddings_columns)) 43 | 44 | if vectors is None: 45 | vectors = self._embeddings.embed_batch(texts=text_to_embed, instruction=kwargs.get('instruction')) 46 | 47 | if len(vectors) != len(text_to_embed): 48 | raise ValueError("The length of records must be the same as the length of vecotors.") 49 | 50 | text_embeddings = [] 51 | for i in range(len(text_to_embed)): 52 | 
text_embeddings.append((text_to_embed[i], vectors[i])) 53 | vs.add_embeddings(text_embeddings=text_embeddings, metadatas=records) 54 | 55 | vs.save_local(os.path.join(self._db_path, collection)) 56 | 57 | def delete(self, collection: str, query: str): 58 | raise RuntimeError("Not implemented.") 59 | 60 | def search( 61 | self, 62 | collection: str = None, 63 | query: str = None, 64 | vector: List = None, 65 | filter: Any = None, 66 | limit: int = 20, 67 | columns: List[str] = None, 68 | with_vector: bool = False, 69 | with_distance: bool = False, 70 | **kwargs 71 | ) -> List[Record]: 72 | fetch_k = kwargs['fetch_k'] if 'fetch_k' in kwargs else None 73 | if not fetch_k: 74 | fetch_k = limit * 10 75 | 76 | if vector is None: 77 | vector = self._embeddings.embed(text=query, instruction=kwargs.get('instruction')) 78 | 79 | return self._faiss_search(collection_name=collection, query=query, vector=vector, filter=filter, k=limit, fetch_k=fetch_k, **kwargs) 80 | 81 | def _faiss_search( 82 | self, 83 | collection_name, 84 | query: str = None, 85 | vector=None, 86 | k: int = 4, 87 | filter: Optional[Dict[str, Any]] = None, 88 | fetch_k: int = 20, 89 | **kwargs: Any 90 | ) -> Dict: 91 | vs = self._get_vectorstore(name=collection_name) 92 | 93 | if vector: 94 | docs = vs.similarity_search_with_score_by_vector( 95 | embedding=vector, 96 | k=k, 97 | filter=filter, 98 | fetch_k=fetch_k, 99 | **kwargs 100 | ) 101 | else: 102 | docs = vs.similarity_search_with_score( 103 | query=query, 104 | k=k, 105 | filter=filter, 106 | fetch_k=fetch_k, 107 | **kwargs 108 | ) 109 | 110 | d = [] 111 | for doc in docs: 112 | v = doc[0].metadata 113 | v['_distance'] = float(doc[1]) 114 | d.append(v) 115 | return d 116 | 117 | def _get_vectorstore_path(self, name): 118 | if not self._db_path: 119 | logger.warn("db_path required") 120 | return None 121 | 122 | fname = os.path.join(self._db_path, name) 123 | return fname 124 | 125 | def _get_vectorstore(self, name): 126 | if name not in 
self._collections: 127 | vectorstores = _import_langchain_vectorstores() 128 | 129 | vs_path = self._get_vectorstore_path(name=name) 130 | vs = vectorstores.FAISS.load_local( 131 | vs_path, self.get_embeddings(), normalize_L2=True) 132 | self._collections[name] = vs 133 | return self._collections[name] 134 | 135 | def get_embeddings(self): 136 | if not self._embeddings: 137 | raise ValueError("No default embeddings found.") 138 | 139 | from langchain.schema.embeddings import Embeddings as Embeddings_ 140 | class LCEmbeddings(Embeddings_): 141 | def __init__(self, embed) -> None: 142 | self.embed = embed 143 | 144 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 145 | return self.embed.embed_batch(texts) 146 | def embed_query(self, text: str) -> List[float]: 147 | return self.embed.embed(text) 148 | 149 | return LCEmbeddings(embed=self._embeddings) 150 | -------------------------------------------------------------------------------- /myla/messages.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | from pydantic import BaseModel 4 | from sqlmodel import JSON, Field, Session, select 5 | 6 | from . 
import _models 7 | 8 | 9 | class MessageText(BaseModel): 10 | value: str 11 | 12 | 13 | class MessageContent(BaseModel): 14 | type: str 15 | text: Optional[List[MessageText]] 16 | 17 | 18 | class MessageCreate(_models.MetadataModel): 19 | role: str 20 | content: str 21 | file_ids: Optional[List[str]] = None 22 | 23 | 24 | class MessageModify(_models.MetadataModel): 25 | pass 26 | 27 | 28 | class MessageRead(_models.ReadModel): 29 | thread_id: str 30 | assistant_id: Optional[str] = None 31 | run_id: Optional[str] = None 32 | role: Optional[str] = None 33 | content: Optional[List[MessageContent]] = None 34 | file_ids: Optional[List[str]] = [] 35 | 36 | 37 | class MessageList(_models.ListModel): 38 | data: List[MessageRead] 39 | 40 | 41 | class Message(_models.DBModel, table=True): 42 | """ 43 | Represents an assistant that can call the model and use tools. 44 | """ 45 | thread_id: Optional[str] = Field(index=True, default=None) 46 | assistant_id: Optional[str] = Field(index=True, nullable=True, default=None) 47 | run_id: Optional[str] = Field(index=True, nullable=True, default=None) 48 | role: str 49 | content: List[Dict] = Field(sa_type=JSON) 50 | file_ids: Optional[List[str]] = Field(sa_type=JSON, default=None) 51 | 52 | 53 | @_models.auto_session 54 | def create( 55 | thread_id: str, 56 | message: MessageCreate, 57 | assistant_id: Optional[str] = None, 58 | run_id: Optional[str] = None, 59 | tag: Optional[str] = None, 60 | user_id: Optional[str] = None, 61 | org_id: Optional[str] = None, 62 | session: Session = None 63 | ) -> MessageRead: 64 | db_model = Message( 65 | thread_id=thread_id, 66 | role=message.role, 67 | content=[ 68 | MessageContent(type="text", text=[MessageText(value=message.content)]).model_dump() 69 | ], 70 | assistant_id=assistant_id, 71 | run_id=run_id 72 | ) 73 | 74 | dbo = _models.create( 75 | object="thread.message", 76 | meta_model=message, 77 | db_model=db_model, 78 | tag=tag, 79 | user_id=user_id, 80 | org_id=org_id, 81 | session=session 
82 | ) 83 | return dbo.to_read(MessageRead) 84 | 85 | 86 | @_models.auto_session 87 | def get(id: str, thread_id: str = None, user_id: str = None, session: Session = None) -> Union[MessageRead, None]: 88 | r = _models.get(db_cls=Message, read_cls=MessageRead, id=id, user_id=None, session=session) 89 | if not r: 90 | return None 91 | 92 | if thread_id is not None and thread_id != r.thread_id: 93 | return None 94 | 95 | #if thread_id is not None: 96 | # thread = threads.get(id=thread_id, user_id=user_id, session=session) 97 | # if not thread: 98 | # return None 99 | 100 | return r 101 | 102 | 103 | @_models.auto_session 104 | def modify(id: str, message: MessageModify, thread_id: str = None, user_id: str = None, session: Session = None) -> Union[MessageRead, None]: 105 | if thread_id is not None: 106 | msg = get(id=id, thread_id=thread_id, user_id=user_id, session=session) 107 | if not msg: 108 | return None 109 | return _models.modify(db_cls=Message, read_cls=MessageRead, id=id, to_update=message.model_dump(exclude_unset=True), user_id=user_id, session=session) 110 | 111 | 112 | @_models.auto_session 113 | def delete(id: str, thread_id: str = None, user_id: str = None, mode="soft", session: Optional[Session] = None) -> _models.DeletionStatus: 114 | if thread_id is not None: 115 | msg = get(id=id, thread_id=thread_id, user_id=user_id, session=session) 116 | if not msg: 117 | return None 118 | return _models.delete(db_cls=Message, id=id, user_id=user_id, mode=mode, session=session) 119 | 120 | 121 | @_models.auto_session 122 | def list( 123 | thread_id: str, 124 | limit: Optional[int] = 20, 125 | order: Optional[str] = "desc", 126 | after: Optional[str] = None, 127 | before: Optional[str] = None, 128 | tag: Optional[str] = None, 129 | user_id: Optional[str] = None, 130 | session: Session = None 131 | ) -> MessageList: 132 | #if user_id is not None: 133 | # thread = threads.get(id=thread_id, user_id=user_id, session=session) 134 | # if not thread: 135 | # return 
MessageList(data=[]) 136 | 137 | select_stmt = select(Message) 138 | select_stmt = select_stmt.filter(Message.is_deleted == False) 139 | select_stmt = select_stmt.where(Message.thread_id == thread_id) 140 | 141 | select_stmt = select_stmt.order_by(-Message.created_at if order == "desc" else Message.created_at) 142 | if after: 143 | m = get(id=after, thread_id=thread_id, user_id=user_id, session=session) 144 | if m: 145 | select_stmt = select_stmt.filter(Message.created_at > m.created_at) 146 | if before: 147 | m = get(id=before, thread_id=thread_id, user_id=user_id, session=session) 148 | if m: 149 | select_stmt = select_stmt.filter(Message.created_at < m.created_at) 150 | 151 | if tag: 152 | select_stmt = select_stmt.filter(Message.tag == tag) 153 | 154 | select_stmt = select_stmt.limit(limit) 155 | 156 | dbos = session.exec(select_stmt).all() 157 | rs = [] 158 | for dbo in dbos: 159 | rs.append(dbo.to_read(MessageRead)) 160 | r = MessageList(data=rs, first_id=rs[0].id if len(rs) > 0 else None, last_id=rs[-1].id if len(rs) > 0 else None) 161 | return r 162 | -------------------------------------------------------------------------------- /myla/runs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | from pydantic import BaseModel 4 | from sqlmodel import JSON, Field, Session, select 5 | 6 | from . 
import _models 7 | 8 | 9 | class RunEdit(_models.MetadataModel): 10 | pass 11 | 12 | 13 | class RunCreate(RunEdit): 14 | assistant_id: str 15 | model: Optional[str] = None 16 | instructions: Optional[str] = None 17 | tools: Optional[List[Dict[str, Any]]] = [] 18 | 19 | 20 | class RunModify(RunEdit): 21 | pass 22 | 23 | 24 | class RunBase(BaseModel): 25 | thread_id: Optional[str] = Field(index=True, default=None) 26 | assistant_id: str = Field(index=True) 27 | model: Optional[str] = None 28 | instructions: Optional[str] = None 29 | tools: Optional[List[Dict]] = Field(sa_type=JSON, default=None) 30 | status: Optional[str] = Field(index=True, nullable=True, default=None) 31 | required_action: Optional[Dict] = Field(sa_type=JSON, default=None) 32 | last_error: Optional[Dict] = Field(sa_type=JSON, default=None) 33 | expires_at: Optional[int] = None 34 | started_at: Optional[int] = None 35 | failed_at: Optional[int] = None 36 | completed_at: Optional[int] = None 37 | 38 | 39 | class RunRead(_models.ReadModel, RunBase): 40 | file_ids: Optional[List[str]] = None 41 | 42 | 43 | class RunList(_models.ListModel): 44 | data: List[RunRead] = [] 45 | 46 | 47 | class Run(_models.DBModel, RunBase, table=True): 48 | """ 49 | Represents an assistant that can call the model and use tools. 
50 | """ 51 | 52 | 53 | class ThreadRunCreate(_models.MetadataModel): 54 | assistant_id: str 55 | thread: Optional[Dict] 56 | model: Optional[str] 57 | instructions: Optional[str] 58 | tools: Optional[List[Dict]] 59 | file_ids: Optional[List[str]] 60 | 61 | 62 | class RunStep(_models.DBModel, _models.MetadataModel): 63 | assistant_id: str = Field(index=True) 64 | thread_id: str = Field(index=True) 65 | run_id: str = Field(index=True) 66 | type: str 67 | status: str 68 | step_details: Dict 69 | last_error: Optional[Dict] 70 | expired_at: Optional[int] 71 | cancelled_at: Optional[int] 72 | failed_at: Optional[int] 73 | completed_at: Optional[int] 74 | 75 | 76 | @_models.auto_session 77 | def create(thread_id: str, run: RunCreate, user_id: str = None, org_id: str = None, session: Session = None) -> Union[RunRead, None]: 78 | db_model = Run.model_validate(run) 79 | db_model.thread_id = thread_id 80 | db_model.status = "queued" 81 | 82 | dbo = _models.create(object="thread.run", meta_model=run, db_model=db_model, user_id=user_id, org_id=org_id, session=session) 83 | 84 | return dbo.to_read(RunRead) 85 | 86 | 87 | @_models.auto_session 88 | def get(thread_id: str, run_id: str, user_id: str = None, session: Session = None) -> Union[RunRead, None]: 89 | r = _models.get(db_cls=Run, read_cls=RunRead, id=run_id, user_id=user_id, session=session) 90 | if r.thread_id != thread_id: 91 | return None 92 | return r 93 | 94 | 95 | @_models.auto_session 96 | def modify(id: str, run: RunModify, user_id: str = None, session: Session = None) -> Union[RunRead, None]: 97 | return _models.modify(db_cls=Run, read_cls=RunRead, id=id, to_update=run.model_dump(exclude_unset=True), user_id=user_id, session=session) 98 | 99 | 100 | @_models.auto_session 101 | def delete(id: str, user_id: str = None, mode="soft", session: Optional[Session] = None) -> _models.DeletionStatus: 102 | return _models.delete(db_cls=Run, id=id, user_id=user_id, mode=mode, session=session) 103 | 104 | 105 | 
@_models.auto_session 106 | def list( 107 | thread_id: str, 108 | limit: int = 20, 109 | order: str = "desc", 110 | after: str = None, 111 | before: str = None, 112 | user_id: str = None, 113 | org_id: str = None, 114 | session: Optional[Session] = None 115 | ) -> RunList: 116 | select_stmt = select(Run) 117 | select_stmt = select_stmt.filter(Run.is_deleted == False) 118 | 119 | if thread_id: 120 | select_stmt = select_stmt.filter(Run.thread_id == thread_id) 121 | 122 | select_stmt = select_stmt.order_by(-Run.created_at if order == "desc" else Run.created_at) 123 | 124 | if after: 125 | r = get(run_id=after, thread_id=thread_id, user_id=user_id, session=session) 126 | if r: 127 | select_stmt = select_stmt.filter(Run.created_at > r.created_at) 128 | if before: 129 | r = get(run_id=before, thread_id=thread_id, user_id=user_id, session=session) 130 | if r: 131 | select_stmt = select_stmt.filter(Run.created_at < r.created_at) 132 | 133 | if user_id: 134 | select_stmt = select_stmt.filter(Run.user_id == user_id) 135 | if org_id: 136 | select_stmt = select_stmt.filter(Run.org_id == org_id) 137 | 138 | select_stmt = select_stmt.limit(limit) 139 | 140 | dbos = session.exec(select_stmt).all() 141 | 142 | rs = [] 143 | for dbo in dbos: 144 | rs.append(dbo.to_read(RunRead)) 145 | return RunList(data=rs, first_id=rs[0].id if len(rs) > 0 else None, last_id=rs[-1].id if len(rs) > 0 else None) 146 | 147 | 148 | def cancel(thread_id: str, run_id: str, session: Session = None) -> Union[RunRead, None]: 149 | return 150 | 151 | 152 | def create_thread_and_run(thread_run: ThreadRunCreate, session: Session = None) -> Union[RunRead, None]: 153 | return None 154 | 155 | 156 | def create_step(run_id: str, step: RunStep, session: Session = None) -> Union[RunStep, None]: 157 | 158 | return None 159 | 160 | 161 | def list_steps(thread_id: str, run_id: str, session: Session = None) -> _models.ListModel: 162 | return _models.ListModel(object="thread.run.step", data=[]) 163 | 164 | 165 | def 
get_step(thread_id: str, run_id: str, step_id: str, session: Session = None) -> Union[RunStep, None]: 166 | return None 167 | 168 | 169 | @_models.auto_session 170 | def update(id: str, session: Session = None, **kwargs): 171 | dbo = session.get(Run, id) 172 | if dbo: 173 | for k, v in kwargs.items(): 174 | 175 | if k == 'metadata': 176 | dbo.metadata_ = v 177 | else: 178 | setattr(dbo, k, v) 179 | 180 | session.add(dbo) 181 | session.commit() 182 | session.refresh(dbo) 183 | -------------------------------------------------------------------------------- /js/src/members.js: -------------------------------------------------------------------------------- 1 | import { DeleteOutlined, PlusOutlined } from "@ant-design/icons" 2 | import { Button, Form, Input, Select, Space, Table, message } from "antd" 3 | import Link from "antd/es/typography/Link" 4 | import { useEffect, useState } from "react" 5 | 6 | export const Members = () => { 7 | const [members, setMembers] = useState() 8 | const [msg, msgContext] = message.useMessage() 9 | const [formView, setFormView] = useState(false) 10 | const [inviteMemberForm] = Form.useForm() 11 | 12 | const invite = () => { 13 | let username = inviteMemberForm.getFieldValue('username'); 14 | } 15 | 16 | const loadMembers = () => { 17 | fetch(`/api/v1/organizations/${localStorage.getItem('org_id')}/members`, { 18 | method: 'GET', 19 | headers: { 20 | 'Content-Type': 'application/json', 21 | 'OpenAI-Organization': `${localStorage.getItem('org_id')}` 22 | } 23 | }).then(r => r.json()).then(members => { 24 | setMembers(members.data); 25 | }) 26 | } 27 | 28 | const onCancel = () => { 29 | setFormView(false); 30 | inviteMemberForm.resetFields(); 31 | } 32 | 33 | 34 | const onInvite = () => { 35 | let username = inviteMemberForm.getFieldValue('username'); 36 | let role = inviteMemberForm.getFieldValue('role'); 37 | 38 | fetch(`/api/v1/organizations/${localStorage.getItem('org_id')}/members`, { 39 | method: 'POST', 40 | headers: { 41 | 
'Content-Type': 'application/json', 42 | 'OpenAI-Organization': `${localStorage.getItem('org_id')}` 43 | }, 44 | body: JSON.stringify({ 45 | username: username, 46 | role: role || "reader" 47 | }) 48 | }).then(r => { 49 | if (r.status === 200) { 50 | onCancel(); 51 | loadMembers(); 52 | } else { 53 | throw new Error('Stauts: ' + r.status); 54 | } 55 | }).catch(err => { 56 | msg.error(err.message) 57 | }) 58 | } 59 | 60 | useEffect(() => { 61 | loadMembers(); 62 | }, []) 63 | 64 | return ( 65 |
66 | {msgContext} 67 | {!formView ? ( 68 |
69 | 77 |
r.user.username 85 | }, 86 | { 87 | title: 'Display Name', 88 | key: 'display_name', 89 | dataIndex: 'display_name', 90 | render: (text, r) => r.user.display_name 91 | }, 92 | { 93 | title: 'Role', 94 | key: 'role', 95 | dataIndex: 'role', 96 | render: (text, r) => r.role 97 | }, 98 | { 99 | title: '', 100 | key: 'actions', 101 | render: (_, r) => ( 102 | 103 | ) 104 | } 105 | ]} 106 | locale={{ emptyText: ' ' }} 107 | dataSource={members} 108 | pagination={false} 109 | /> 110 | 111 | ) : ( 112 |
113 | < Back 114 |
118 | 125 | 129 | 130 | 131 | 138 |