├── assets
│   └── images
│       └── sergey.png
├── logos_shift_client
│   ├── __init__.py
│   ├── bohita.py
│   ├── router.py
│   └── logos_shift.py
├── .bumpversion.cfg
├── Makefile
├── .github
│   └── workflows
│       ├── test_and_build.yml
│       └── release.yml
├── LICENSE
├── setup.py
├── tests
│   ├── test_router.py
│   └── test_logos_shift.py
├── .gitignore
└── README.md

/assets/images/sergey.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/virevolai/logos-shift-client/HEAD/assets/images/sergey.png
--------------------------------------------------------------------------------
/logos_shift_client/__init__.py:
--------------------------------------------------------------------------------
1 | from .logos_shift import LogosShift # noqa
2 | from .router import APIRouter # noqa
3 | 
--------------------------------------------------------------------------------
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.11.0
3 | commit = True
4 | tag = True
5 | tag_name = v{new_version}
6 | 
7 | [bumpversion:file:setup.py]
8 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: install install-dev test lint lintgit clean build upload
2 | 
3 | install:
4 | 	pip install .
5 | 
6 | install-dev:
7 | 	pip install .[dev]
8 | 
9 | test:
10 | 	pytest
11 | 
12 | lint:
13 | 	ruff check logos_shift_client
14 | 	ruff format logos_shift_client
15 | 
16 | lintgit:
17 | 	ruff check --output-format=github logos_shift_client
18 | 
19 | clean:
20 | 	rm -rf dist build *.egg-info
21 | 
22 | build: clean
23 | 	python setup.py sdist bdist_wheel
24 | 
25 | # Upload the package to PyPI
26 | upload:
27 | 	twine upload dist/*
28 | 
29 | all: install-dev lint lintgit test build
30 | 
--------------------------------------------------------------------------------
/.github/workflows/test_and_build.yml:
--------------------------------------------------------------------------------
1 | name: Lint and Test
2 | 
3 | on:
4 |   push:
5 |     branches:
6 |       - '**'
7 |     paths-ignore:
8 |       - '**.md'
9 | 
10 | jobs:
11 |   test_and_build:
12 |     runs-on: ubuntu-latest
13 |     strategy:
14 |       fail-fast: false
15 |       matrix:
16 |         python-version: ["3.9", "3.10", "3.11"]
17 | 
18 |     steps:
19 |       - name: Checkout code
20 |         uses: actions/checkout@v4
21 | 
22 |       - name: Set up Python ${{ matrix.python-version }}
23 |         uses: actions/setup-python@v4
24 |         with:
25 |           python-version: ${{ matrix.python-version }}
26 | 
27 |       - name: Install dependencies
28 |         run: |
29 |           pip install .
30 | pip install .[dev] 31 | pip install setuptools wheel twine 32 | 33 | - name: Lint 34 | run: make lintgit 35 | 36 | - name: Test 37 | run: make test 38 | 39 | - name: Build 40 | run: make build 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 virevol AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | paths-ignore: 8 | - '**.md' 9 | 10 | jobs: 11 | release: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: 3 21 | 22 | - name: Install dependencies 23 | run: | 24 | pip install . 
25 |           pip install .[dev]
26 |           pip install setuptools wheel twine
27 |           pip install bump2version
28 | 
29 |       - name: Configure Git user
30 |         run: |
31 |           git config user.email "actions@github.com"
32 |           git config user.name "GitHub Actions"
33 | 
34 |       - name: Bump version
35 |         run: bump2version minor
36 | 
37 |       - name: Build
38 |         run: make build
39 | 
40 |       - name: Push changes
41 |         run: |
42 |           git push origin HEAD:main
43 |           git push origin --tags
44 |         env:
45 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
46 | 
47 |       - name: Publish package
48 |         uses: pypa/gh-action-pypi-publish@release/v1
49 |         with:
50 |           user: __token__
51 |           password: ${{ secrets.PYPI_API_TOKEN }}
52 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | from pathlib import Path
3 | this_directory = Path(__file__).parent
4 | long_description = (this_directory / "README.md").read_text()
5 | 
6 | setup(
7 |     name="logos-shift-client",
8 |     version="0.11.0",
9 |     author="Saurabh Bhatnagar",
10 |     author_email="saurabh@virevol.com",
11 |     description="Swap your current LLM for a finetuned one automatically, with no additional latency",
12 |     url="https://api.bohita.com",
13 |     long_description=long_description,
14 |     long_description_content_type='text/markdown',
15 |     packages=find_packages(),
16 |     install_requires=[
17 |         "requests",
18 |         "tenacity",
19 |         "httpx",
20 |     ],
21 |     extras_require={"dev": ["pytest", "ruff>=0.1.2", "bump2version==1.0.1"]},
22 |     python_requires=">=3.9",
23 |     classifiers=[
24 |         "Programming Language :: Python :: 3",
25 |         "License :: OSI Approved :: MIT License",
26 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
27 |         "Topic :: Text Processing :: Linguistic",
28 |         "Topic :: Software Development :: Libraries :: Application Frameworks",
29 |         "Framework :: AsyncIO",
30 |     ],
31 |     project_urls={
32 |         'Homepage': 'https://api.bohita.com',
33 |         'Repository': 'https://github.com/virevolai/logos-shift-client',
34 |     }
35 | )
36 | 
--------------------------------------------------------------------------------
/tests/test_router.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from logos_shift_client import APIRouter
4 | 
5 | # Mocks for old and new APIs to capture calls
6 | old_api_called = False
7 | new_api_called = False
8 | 
9 | 
10 | def mock_old_api(**kwargs):
11 |     global old_api_called
12 |     old_api_called = True
13 |     return "old_api_response"
14 | 
15 | 
16 | def mock_new_api(**kwargs):
17 |     global new_api_called
18 |     new_api_called = True
19 |     return "new_api_response"
20 | 
21 | 
22 | @pytest.fixture
23 | def setup_router():
24 |     router = APIRouter(threshold=0.5, mode="random")
25 |     router.call_old_api = mock_old_api
26 |     router.call_new_api = mock_new_api
27 |     return router
28 | 
29 | 
30 | def test_random_routing(setup_router):
31 |     global old_api_called, new_api_called
32 | 
33 |     # Resetting the mock flags
34 |     old_api_called, new_api_called = False, False
35 | 
36 |     func_to_call = setup_router.get_api_to_call(mock_old_api)
37 |     result = func_to_call()
38 | 
39 |     assert result in ["old_api_response", "new_api_response"]
40 |     assert old_api_called or new_api_called, "Neither old nor new API was called"
41 | 
42 | 
43 | def test_user_based_routing(setup_router):
44 |     global old_api_called, new_api_called
45 | 
46 |     # Changing the mode to user_based
47 |     setup_router.mode = "user_based"
48 | 
49 |     # Resetting the
mock flags 50 | old_api_called, new_api_called = False, False 51 | 52 | func_to_call = setup_router.get_api_to_call(mock_old_api, user_id="test_user") 53 | result = func_to_call() 54 | 55 | assert result in ["old_api_response", "new_api_response"] 56 | assert old_api_called or new_api_called, "Neither old nor new API was called" 57 | 58 | 59 | def test_async_routing(setup_router): 60 | global old_api_called, new_api_called 61 | 62 | # Mocking async versions of the APIs 63 | async def async_mock_old_api(**kwargs): 64 | global old_api_called 65 | old_api_called = True 66 | return "old_api_response" 67 | 68 | async def async_mock_new_api(**kwargs): 69 | global new_api_called 70 | new_api_called = True 71 | return "new_api_response" 72 | 73 | setup_router.call_old_api = async_mock_old_api 74 | setup_router.call_new_api = async_mock_new_api 75 | 76 | # Resetting the mock flags 77 | old_api_called, new_api_called = False, False 78 | 79 | async def run_test(): 80 | func_to_call = setup_router.get_api_to_call(async_mock_old_api) 81 | result = await func_to_call() 82 | assert result in ["old_api_response", "new_api_response"] 83 | assert old_api_called or new_api_called, "Neither old nor new API was called" 84 | 85 | # Run the async test 86 | import asyncio 87 | 88 | asyncio.run(run_test()) 89 | -------------------------------------------------------------------------------- /tests/test_logos_shift.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import pytest 5 | 6 | from logos_shift_client import LogosShift 7 | 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(logging.DEBUG) 10 | 11 | ch = logging.StreamHandler() 12 | ch.setLevel(logging.DEBUG) 13 | 14 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 15 | ch.setFormatter(formatter) 16 | logger.addHandler(ch) 17 | 18 | 19 | # Mock for the send_data function to capture data 20 | mock_data_buffer = [] 21 | 22 | 23 | def mock_send_data(data, dataset="default"): 24 | print(f"Sending ({data}, {dataset})") 25 | mock_data_buffer.append((data, dataset)) 26 | 27 | 28 | def wait_for_data(buffer, timeout=20): 29 | start_time = time.time() 30 | while time.time() - start_time < timeout: 31 | if buffer: 32 | return True 33 | time.sleep(0.1) # Check every 100 milliseconds 34 | return False 35 | 36 | 37 | @pytest.fixture 38 | def setup_logos_shift(): 39 | logos_shift = LogosShift(api_key="YOUR_API_KEY", max_entries=1, check_seconds=0.5) 40 | 41 | # Override the actual send_data method with our mock for testing 42 | logos_shift.buffer_manager.send_data = mock_send_data 43 | 44 | print( 45 | "config: ", 46 | logos_shift.max_entries, 47 | logos_shift.buffer_manager.check_seconds, 48 | logos_shift.buffer_manager.send_data, 49 | ) 50 | return logos_shift 51 | 52 | 53 | def test_basic_function_call(setup_logos_shift): 54 | @setup_logos_shift() 55 | def add(x, y): 56 | return x + y 57 | 58 | mock_data_buffer.clear() 59 | result = add(1, 2) 60 | assert result == 3 61 | 62 | time.sleep(1) 63 | print(f"mock_data_buffer: {mock_data_buffer}") 64 | assert wait_for_data(mock_data_buffer), "Timeout waiting for data" 65 | 66 | # Check if logos_shift captured the correct data 67 | expected_data = { 68 | "input": ((1, 2), {}), 69 | "output": 3, 70 | "dataset": "default", 71 | "metadata": {"function": "add"}, 72 | } 73 | assert any( 74 | item[0] == expected_data for item in mock_data_buffer 75 | ), "Expected data not found in mock_data_buffer" 76 | 77 | 
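# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original suite): one way the call-time
# `logos_shift_metadata` keyword documented in the README could be exercised.
# It reuses the existing fixture and helpers; the captured field names follow
# LogosShift.handle_data(). The test name itself is new.
def test_metadata_capture(setup_logos_shift):
    @setup_logos_shift()
    def multiply(x, y):
        return x * y

    mock_data_buffer.clear()
    # The metadata kwarg is popped by the wrapper and never reaches multiply()
    result = multiply(2, 3, logos_shift_metadata={"user_id": "12345"})
    assert result == 6

    assert wait_for_data(mock_data_buffer), "Timeout waiting for data"
    assert any(
        item[0].get("metadata", {}).get("user_id") == "12345"
        for item in mock_data_buffer
    ), "Expected user_id not found in captured metadata"
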
78 | def test_dataset_parameter(setup_logos_shift): 79 | @setup_logos_shift(dataset="test_dataset") 80 | def subtract(x, y): 81 | return x - y 82 | 83 | mock_data_buffer.clear() 84 | result = subtract(5, 3) 85 | assert result == 2 86 | 87 | time.sleep(1) 88 | print(f"mock_data_buffer: {mock_data_buffer}") 89 | assert wait_for_data(mock_data_buffer), "Timeout waiting for data" 90 | 91 | assert any( 92 | item[1] == "test_dataset" for item in mock_data_buffer 93 | ), "Expected dataset not found in mock_data_buffer" 94 | -------------------------------------------------------------------------------- /logos_shift_client/bohita.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import httpx 3 | import logging 4 | 5 | BASE_URL = "https://logos-shift-sink-6kso2cgttq-uc.a.run.app" 6 | TIMEOUT = 10 # seconds 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class BohitaClient: 11 | def __init__(self, api_key: str): 12 | if api_key is None: 13 | logging.warning( 14 | "No API KEY provided. No data will be sent to Bohita and automatic routing will not happen" 15 | ) 16 | self.headers = None 17 | else: 18 | self.headers = { 19 | "Content-Type": "application/json", 20 | "Bohita-Auth": f"Bearer {api_key}", 21 | } 22 | self.async_client = httpx.AsyncClient(headers=self.headers, timeout=TIMEOUT) 23 | 24 | def post_instrumentation_data(self, data, dataset): 25 | if not self.headers: 26 | return 27 | try: 28 | response = requests.post( 29 | f"{BASE_URL}/instrumentation/", 30 | headers=self.headers, 31 | json={**data, "dataset": dataset}, 32 | timeout=TIMEOUT, 33 | ) 34 | response.raise_for_status() 35 | except requests.RequestException as e: 36 | logger.error("Failed to post instrumentation data: %s", str(e)) 37 | 38 | async def post_instrumentation_data_async(self, data, dataset): 39 | if not self.headers: 40 | return 41 | try: 42 | response = await self.async_client.post( 43 | f"{BASE_URL}/instrumentation/", json={**data, "dataset": dataset} 44 | ) 45 | response.raise_for_status() 46 | except httpx.RequestError as e: 47 | logger.error("Failed to post instrumentation data: %s", str(e)) 48 | 49 | def get_config(self): 50 | if not self.headers: 51 | return {} 52 | try: 53 | response = requests.get( 54 | f"{BASE_URL}/config", headers=self.headers, timeout=TIMEOUT 55 | ) 56 | response.raise_for_status() 57 | return response.json() 58 | except requests.RequestException as e: 59 | logger.error("Failed to get configuration: %s", str(e)) 60 | return {} 61 | 62 | async def get_config_async(self): 63 | if not self.headers: 64 | return {} 65 | try: 66 | response = await self.async_client.get(f"{BASE_URL}/config") 67 | response.raise_for_status() 68 | return response.json() 69 | except httpx.RequestError as e: 70 | logger.error("Failed to get configuration: %s", str(e)) 71 | return {} 72 | 73 | def predict(self, **kwargs): 74 | if not self.headers: 75 | return 76 | try: 77 | response = requests.post( 78 | f"{BASE_URL}/predict", 79 | headers=self.headers, 80 | json=kwargs, 81 | timeout=TIMEOUT, 82 | ) 83 | response.raise_for_status() 84 | return response.json() 85 | except requests.RequestException as e: 86 | logger.error("Failed to make prediction: %s", str(e)) 87 | 88 | async def predict_async(self, **kwargs): 89 | if not self.headers: 90 | return 91 | try: 92 | response = await self.async_client.post(f"{BASE_URL}/predict", json=kwargs) 93 | response.raise_for_status() 94 | return response.json() 95 | except httpx.RequestError as e: 96 | 
logger.error("Failed to make prediction: %s", str(e)) 97 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # vim 163 | *.swp 164 | -------------------------------------------------------------------------------- /logos_shift_client/router.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import logging 3 | import random 4 | import asyncio 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class APIRouter: 10 | """ 11 | APIRouter is responsible for routing API calls based on the provided configuration. 12 | 13 | It supports three modes: 14 | - "never": Always use the old API. 15 | - "random": Randomly choose between the old and new API based on a threshold. 16 | - "user_based": Decide based on a hash of the user ID. 17 | 18 | Attributes: 19 | bohita_client (BohitaClient): The client used to communicate with the Bohita platform. 20 | threshold (float): The percentage of requests to route to the new API. Default is 0.1 (10%). 21 | mode (str): The routing mode. Can be "never", "random", or "user_based". Default is "never". 22 | call_count (int): The number of API calls made. 23 | conf_frequency (int): How frequently to fetch configuration updates from the server. 24 | 25 | Examples: 26 | >>> router = APIRouter(bohita_client, threshold=0.2, mode="random") 27 | >>> api_to_call = router.get_api_to_call(old_api_func) 28 | """ 29 | 30 | def __init__(self, bohita_client=None, threshold=0.1, mode="never"): 31 | """ 32 | Initializes a new instance of APIRouter. 33 | 34 | Args: 35 | bohita_client (Optional[BohitaClient]): An instance of BohitaClient used to communicate with the Bohita platform. 36 | threshold (float): The percentage of requests to route to the new API. Default is 0.1 (10%). 37 | mode (str): The routing mode. Can be "never", "random", or "user_based". Default is "never". 
38 | """ 39 | self.bohita_client = bohita_client 40 | if not 0 <= threshold <= 1: 41 | raise ValueError("Threshold must be between 0 and 1") 42 | self.threshold = threshold # precentage of requests to new API 43 | self.mode = mode # "never", "random" or "user_based" 44 | self.call_count, self.conf_frequency = ( 45 | 0, 46 | 1_000, 47 | ) # How frequently to fetch config 48 | logger.info(f"Initialized {mode} router") 49 | # Fails in async context, disable for now 50 | # self._get_configuration() 51 | 52 | async def _get_configuration_common(self, is_async): 53 | """ 54 | Fetches the routing configuration from the Bohita platform and updates the router's settings. 55 | 56 | This method is called periodically based on the conf_frequency setting. 57 | """ 58 | try: 59 | logger.info("Checking for config updates") 60 | if is_async: 61 | config = await self.bohita_client.get_config_async() 62 | else: 63 | config = self.bohita_client.get_config() 64 | self.threshold = config.get("threshold", self.threshold) 65 | self.mode = config.get("mode", self.mode) 66 | self.conf_frequency = config.get("frequency", self.conf_frequency) 67 | logger.info("Configuration updated successfully") 68 | except Exception as e: 69 | logger.warning("Could not get configuration from server: %s", str(e)) 70 | logger.warning("If the problem persists, this instance might be stale") 71 | 72 | def _get_configuration(self): 73 | asyncio.run(self._get_configuration_common(False)) 74 | 75 | async def _get_configuration_async(self): 76 | await self._get_configuration_common(True) 77 | 78 | def _get_user_hash(self, user_id): 79 | return int(hashlib.md5(str(user_id).encode()).hexdigest(), 16) 80 | 81 | def should_route_to_new_api(self, user_id=None): 82 | """ 83 | Determines whether the next API call should be routed to the new API based on the current mode and threshold. 84 | 85 | Args: 86 | user_id (Optional[str]): The user ID for user-based routing. Required if mode is "user_based". 87 | 88 | Returns: 89 | bool: True if the call should be routed to the new API, False otherwise. 90 | """ 91 | if self.mode == "random": 92 | return random.random() < self.threshold 93 | elif self.mode == "user_based": 94 | if user_id: 95 | return self._get_user_hash(user_id) % 100 < self.threshold * 100 96 | return False 97 | 98 | def get_api_to_call(self, old_api_func, user_id=None): 99 | """ 100 | Determines which API function to call based on the routing configuration. 101 | 102 | Args: 103 | old_api_func (callable): The old API function. 104 | user_id (Optional[str]): The user ID for user-based routing. 105 | 106 | Returns: 107 | callable: The API function to call. 108 | """ 109 | self.call_count += 1 110 | if self.call_count % self.conf_frequency == 0: 111 | self._get_configuration() 112 | if self.should_route_to_new_api(user_id): 113 | return self.call_new_api 114 | return old_api_func 115 | 116 | async def get_api_to_call_async(self, old_api_func, user_id=None): 117 | """ 118 | Determines which API function to call based on the routing configuration. 119 | 120 | Args: 121 | old_api_func (callable): The old API function. 122 | user_id (Optional[str]): The user ID for user-based routing. 123 | 124 | Returns: 125 | callable: The API function to call. 
126 | """ 127 | self.call_count += 1 128 | if self.call_count % self.conf_frequency == 0: 129 | await self._get_configuration_async() 130 | if self.should_route_to_new_api(user_id): 131 | return self.call_new_api_async 132 | return old_api_func 133 | 134 | async def call_new_api_async(self, **kwargs): 135 | await self.bohita_client.predict_async(**kwargs) 136 | 137 | def call_new_api(self, **kwargs): 138 | self.bohita_client.predict(**kwargs) 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Tests](https://github.com/virevolai/logos-shift-client/actions/workflows/test_and_build.yml/badge.svg) 2 | 3 | --- 4 | 5 | # Logos Shift 6 | 7 | **Replace expensive GPT/Claude calls with smaller, faster finetuned Llama/Mistral automatically** 8 | 9 | Integrating Large Language Models (LLMs) into production systems can be a convoluted process, with myriad challenges to overcome. While several tools offer call instrumentation, **Logos Shift** sets itself apart with its game-changing feature: **automated rollout of your fine-tuned model**. Just integrate with a single line of code and let us notify you when your fine-tuned model is ready for deployment. 10 | 11 | You can also do this yourself for free. LogosShift is the simplest and best way to do this for hackers. 12 | 13 | Pssst: Can do this for any API, not just LLMs 14 | 15 | 16 | ## Key Feature 17 | 18 | - **Effortless A/B Rollout**: Once your fine-tuned model achieves readiness, it's rolled out as an A/B test. No manual intervention, no complex configurations. Just simplicity. 19 | 20 | ## Other Features 21 | 22 | - **No Proxying**: Direct calls, eliminating the latency of proxying. 23 | - **Retain Your OpenAI Key**: Your OpenAI key remains yours, safeguarding confidentiality. No leaked keys. 24 | - **Feedback-Driven Finetuning**: Refine model performance with feedback based on unique result IDs. 25 | - **Open Source**: Flexibility to modify and adapt as needed. 26 | - **Dynamic Configuration**: Synchronize with server configurations on-the-fly. 27 | - Simplicity: Simple is beautiful. Minimal dependancies. 28 | - No lock-in: Can optionally save data to local drive if you want to train model yourself. 29 | - **Upcoming**: 30 | - Error fallback mechanisms 31 | 32 | ## Why 33 | 34 | ![sergey](assets/images/sergey.png) 35 | 36 | At [Bohita](https://bohita.com), our pioneering efforts in deploying Large Language Models (LLMs) in production environments have brought forth unique challenges, especially concerning cost management, latency reduction, and optimization. The solutions available in the market weren't adequate for our needs, prompting us to develop and subsequently open-source some of our bespoke tools. 37 | 38 | On the subject of proxying: We prioritize the reliability and uptime of our services. By introducing an additional domain as a dependency, we'd inherently be reducing our uptime. Specifically, the probability of combined uptime would be \(1 - (P_A\_up\_B\_down + P_B\_up\_A\_down + P_both\_down)\), which is inherently less than the uptime of either individual service. Given the inherent unpredictability of APIs in today's landscape, compromising our reliability in this manner is not a trade-off we're willing to make. 39 | 40 | ## Getting Started 41 | 42 | ### Prerequisites 43 | 44 | - Obtain an API key from [Bohita Logos Shift Portal](https://api.bohita.com). 
45 | 
46 | ### Installation
47 | 
48 | ```bash
49 | pip install logos_shift_client
50 | ```
51 | 
52 | ### Basic Usage
53 | 
54 | ```python
55 | from logos_shift_client import LogosShift
56 | 
57 | # Initialize with your API key (or with api_key=None if you just want a local copy)
58 | logos_shift = LogosShift(api_key="YOUR_API_KEY")
59 | 
60 | # Instrument your function
61 | @logos_shift()
62 | def add(x, y):
63 |     return x + y
64 | 
65 | result = add(1, 2)
66 | 
67 | # Optionally, provide feedback (the ID is attached when your function returns a dict)
68 | logos_shift.provide_feedback(result['bohita_logos_shift_id'], "success")
69 | ```
70 | 
71 | ## How It Works
72 | 
73 | Here's a high-level overview:
74 | 
75 | ```mermaid
76 | graph LR
77 |     A["Client Application"]
78 |     B["Logos Shift Client"]
79 |     C["Buffer Manager Thread"]
80 |     D["Logos Server"]
81 |     E["Expensive LLM Client"]
82 |     F["Cheap LLM Client"]
83 |     G["API Router A/B Test Rollout"]
84 | 
85 |     A -->|"Function Call"| B
86 |     B -->|"Capture & Buffer Data"| C
87 |     C -->|"Send Data"| D
88 |     B -->|"Route API Call"| G
89 |     G -->|"Expensive API Route"| E
90 |     G -->|"Cheap API Route"| F
91 | 
92 |     classDef mainClass fill:#0077b6,stroke:#004c8c,stroke-width:2px,color:#fff;
93 |     classDef api fill:#90e0ef,stroke:#0077b6,stroke-width:2px,color:#333;
94 |     classDef buffer fill:#48cae4,stroke:#0077b6,stroke-width:2px,color:#333;
95 |     classDef expensive fill:#d00000,stroke:#9d0208,stroke-width:2px,color:#fff;
96 |     classDef cheap fill:#52b788,stroke:#0077b6,stroke-width:2px,color:#333;
97 | 
98 |     class A,B,D,G mainClass
99 |     class E expensive
100 |     class F cheap
101 |     class C buffer
102 | ```
103 | 
104 | ## Dataset
105 | 
106 | All function calls are grouped into datasets. Think of a dataset as the use case those calls serve.
107 | If you are instrumenting just one call, you don't need to do anything; the default dataset is 'default'.
108 | 
109 | If you have different use cases in your application (e.g., a chatbot for sales vs. a chatbot for help), you should separate them out.
110 | 
111 | ```python
112 | @logos_shift(dataset="sales")
113 | def add_sales(x, y):
114 |     return x + y
115 | 
116 | @logos_shift(dataset="help")
117 | def add_help(x, y):
118 |     return x + y
119 | ```
120 | 
121 | This lets you track them separately and also finetune a separate model for each use case.
122 | 
123 | ## Metadata
124 | 
125 | You can provide additional metadata by passing a `logos_shift_metadata` keyword argument at call time, including a `user_id`, which can be used for routing decisions based on user-specific details.
126 | 
127 | ```python
128 | @logos_shift()
129 | def multiply(x, y):
130 |     return x * y  # call as: multiply(2, 3, logos_shift_metadata={"user_id": "12345"})
131 | ```
132 | 
133 | ## Feedback
134 | 
135 | Using feedback, you get better models that are cheaper and more effective.
136 | 
137 | If you don't provide feedback, training is simply auto-regressive as usual.
138 | 
139 | ## Configuration Retrieval
140 | 
141 | The library also retrieves its routing configuration periodically, ensuring Logos Shift adapts to dynamic environments.
142 | 
143 | ## Local Copy
144 | 
145 | Initialize with a filename to keep a local copy. You can also run it without Bohita; just set `api_key` to `None`.
146 | 
147 | ```python
148 | logos_shift = LogosShift(api_key="YOUR_API_KEY", filename="api_calls.log")
149 | 
150 | # Can also disable sending data to Bohita. However, you will lose the automatic routing.
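# (Each record is written to the local file with str(data), one record per line, so the log is easy to grep or parse later.)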
151 | logos_shift = LogosShift(api_key=None, filename="api_calls.log") 152 | ``` 153 | 154 | 155 | ## Best Practices 156 | 157 | When using Logos Shift to integrate Large Language Models (LLMs) into your applications, it’s crucial to tailor the integration to the specific outputs and outcomes that are most relevant to your use case. Below are some best practices to help you maximize the effectiveness of Logos Shift. 158 | 159 | ### Focus on Relevant Outputs 160 | 161 | When instrumenting your functions with Logos Shift, aim to return the specific parts of the output that are most pertinent to your application’s needs. 162 | 163 | #### Not Recommended: 164 | 165 | ```python 166 | @logos_shift(dataset="story_raw") 167 | def get_story(model, messages): 168 | """Generates a story""" 169 | completion = openai.ChatCompletion.create(model=model, messages=messages) 170 | return completion 171 | ``` 172 | 173 | In the above example, the entire completion object is returned, which might include a lot of information that is not directly relevant to your application. 174 | 175 | #### Recommended: 176 | 177 | ```python 178 | @logos_shift(dataset="story") 179 | def get_story(model, messages): 180 | """Generates a story""" 181 | completion = openai.ChatCompletion.create(model=model, messages=messages) 182 | result = completion.to_dict_recursive() 183 | story_d = {'story': result['choices'][0]['message']} 184 | return story_d 185 | ``` 186 | 187 | In this improved version, the function returns a dictionary with just the story part of the completion. This approach ensures that the data captured by Logos Shift is more concise and directly related to your application's primary function. 188 | 189 | ### Providing Feedback 190 | 191 | Providing feedback on specific outcomes is crucial for fine-tuning your models and ensuring accurate A/B test rollouts. 192 | 193 | ```python 194 | story_d = get_story() 195 | # ... your application logic ... 196 | logos_shift.provide_feedback(story_d['bohita_logos_shift_id'], "success") 197 | ``` 198 | 199 | In this example, `provide_feedback` is called with the result ID and an outcome string ("success" in this case). This helps in two ways: 200 | 201 | 1. **Accurate A/B Test Measurements**: The feedback ensures that the A/B test rollouts are based on actual outcomes, providing a true measure of the model’s performance in real-world scenarios. 202 | 2. **Targeted Fine-Tuning**: By providing feedback on specific outcomes, you help in fine-tuning the model to better suit your application’s needs, leading to more effective and efficient model performance over time. 203 | 204 | Adopting these best practices will help you leverage the full potential of Logos Shift, ensuring that your integration with LLMs is not just seamless but also highly effective and _outcome-driven_. 205 | 206 | ## Contribute 207 | 208 | Feel free to fork, open issues, and submit PRs. For major changes, please open an issue first to discuss what you'd like to change. 209 | 210 | ## License 211 | 212 | This project is licensed under the MIT License. 
213 | -------------------------------------------------------------------------------- /logos_shift_client/logos_shift.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import threading 4 | import time 5 | import uuid 6 | from pathlib import Path 7 | from collections import deque 8 | from typing import Optional, Union 9 | 10 | from tenacity import retry, wait_fixed 11 | 12 | from .bohita import BohitaClient 13 | from .router import APIRouter 14 | 15 | logger = logging.getLogger(__name__) 16 | MAX_ENTRIES = 10 17 | CHECK_SECONDS = 5 18 | 19 | 20 | class SingletonMeta(type): 21 | _instances = {} 22 | _lock = threading.Lock() 23 | 24 | def __call__(cls, *args, **kwargs): 25 | with cls._lock: 26 | if cls not in cls._instances: 27 | instance = super().__call__(*args, **kwargs) 28 | cls._instances[cls] = instance 29 | return cls._instances[cls] 30 | 31 | 32 | class BufferManager(metaclass=SingletonMeta): 33 | """ 34 | A singleton class responsible for managing data buffers and sending data to a remote server. 35 | 36 | Attributes: 37 | bohita_client: An instance of BohitaClient used to send data to the remote server. 38 | check_seconds: The interval in seconds between checks to send data from the buffers. 39 | filepath: The file path for local data storage. If None, data is not stored locally. 40 | buffers: A list of data buffers. 41 | thread: The thread responsible for sending data from the buffers. 42 | """ 43 | 44 | _instance = None 45 | lock = threading.Lock() 46 | 47 | def __init__( 48 | self, 49 | bohita_client: BohitaClient, 50 | check_seconds: int = CHECK_SECONDS, 51 | filename: Optional[Union[str, Path]] = None, 52 | ): 53 | self.bohita_client = bohita_client 54 | self.check_seconds = check_seconds 55 | self.open_handle(filename) 56 | self.buffers = [] 57 | self.thread = threading.Thread(target=self.send_data_from_buffers, daemon=True) 58 | self.thread.start() 59 | logger.info("BufferManager: Initialized and sending thread started.") 60 | 61 | def open_handle(self, filename: str): 62 | if filename: 63 | filepath = Path(filename) 64 | logdir = filepath.parent 65 | if not logdir.exists(): 66 | raise Exception(f"Directory {logdir} does not exist!") 67 | self.file_handle = open(filepath, "a", buffering=1) 68 | logger.debug(f"Buffered file handler opened for local file {filename}") 69 | else: 70 | self.file_handle = None 71 | 72 | def __del__(self): 73 | if self.file_handle: 74 | self.file_handle.close() 75 | logger.debug("Buffered file handle closed") 76 | 77 | def _write_to_local(self, data): 78 | try: 79 | if self.file_handle: 80 | self.file_handle.write(str(data) + "\n") 81 | except Exception as e: 82 | logger.error( 83 | "Could not save to local file. This might happen because local file format is simple. Local does str(data)" 84 | ) 85 | logger.exception(e) 86 | 87 | @retry(wait=wait_fixed(3)) 88 | def send_data(self, data, dataset="default"): 89 | logger.info(f"BufferManager: Sending data to dataset {dataset}. 
Data: {data}") 90 | self.bohita_client.post_instrumentation_data(data, dataset) 91 | self._write_to_local(data) 92 | 93 | def send_data_from_buffers(self): 94 | while True: 95 | time.sleep(self.check_seconds) 96 | for buffer in self.buffers: 97 | with buffer["lock"]: 98 | if buffer["data"]: 99 | data_to_send = list(buffer["data"]) 100 | buffer["data"].clear() 101 | for item in data_to_send: 102 | logger.debug(f"Sending {item}") 103 | self.send_data(item, dataset=item["dataset"]) 104 | 105 | def register_buffer(self, buffer, lock): 106 | self.buffers.append({"data": buffer, "lock": lock}) 107 | 108 | 109 | class LogosShift: 110 | """ 111 | LogosShift is a tool for capturing, logging, and optionally sending function call data to a remote server using rollouts. 112 | 113 | It allows developers to easily instrument their functions, capturing input arguments, output results, metadata, and optionally sending this data to the Bohita platform for further analysis. Data can also be stored locally. 114 | 115 | It supports both synchronous and asynchronous functions. For asynchronous functions, it automatically detects and wraps them accordingly. 116 | 117 | Attributes: 118 | bohita_client (BohitaClient): The client used to send data to the Bohita platform. 119 | max_entries (int): The maximum number of entries to store in a buffer before switching to the next buffer. 120 | buffer_A (collections.deque): The first data buffer. 121 | buffer_B (collections.deque): The second data buffer. 122 | active_buffer (collections.deque): The currently active data buffer. 123 | lock (threading.Lock): A lock to ensure thread-safety when modifying the buffers. 124 | buffer_manager (BufferManager): The manager for handling data buffers and sending data. 125 | router (APIRouter): The router for determining which API to call based on the function and user. 126 | 127 | Examples: 128 | >>> logos_shift = LogosShift(api_key="YOUR_API_KEY") 129 | >>> @logos_shift() 130 | ... def add(x, y): 131 | ... return x + y 132 | ... 133 | >>> result = add(1, 2) 134 | 135 | Asynchronous function: 136 | >>> @logos_shift() 137 | ... async def add_async(x, y): 138 | ... return x + y 139 | ... 140 | >>> result = await add_async(1, 2) 141 | 142 | To provide feedback: 143 | >>> logos_shift.provide_feedback(result['bohita_logos_shift_id'], "success") 144 | 145 | To specify a dataset: 146 | >>> @logos_shift(dataset="sales") 147 | ... def add_sales(x, y): 148 | ... return x + y 149 | 150 | Using metadata: 151 | >>> @logos_shift() 152 | ... def multiply(x, y, logos_shift_metadata={"user_id": "12345"}): 153 | ... return x * y 154 | 155 | To store data locally: 156 | >>> logos_shift = LogosShift(api_key="YOUR_API_KEY", filename="api_calls.log") 157 | 158 | To disable sending data to Bohita: 159 | >>> logos_shift = LogosShift(api_key=None, filename="api_calls.log") 160 | """ 161 | 162 | def __init__( 163 | self, 164 | api_key, 165 | bohita_client=None, 166 | router=None, 167 | max_entries=MAX_ENTRIES, 168 | check_seconds=CHECK_SECONDS, 169 | filename=None, 170 | ): 171 | """ 172 | Initializes a new instance of LogosShift. 173 | 174 | Args: 175 | api_key (str): Your API key for the Bohita platform. 176 | bohita_client (Optional[BohitaClient]): An optional instance of BohitaClient. If not provided, a new instance will be created. 177 | router (Optional[APIRouter]): An optional instance of APIRouter. If not provided, a new instance will be created. 
178 | max_entries (int): The maximum number of entries to store in a buffer before switching to the next buffer. Default is 10. 179 | check_seconds (int): The interval in seconds between checks to send data from the buffers. Default is 5. 180 | filename (Optional[Union[str, Path]]): The file path for local data storage. If None, data is not stored locally. 181 | 182 | Examples: 183 | >>> logos_shift = LogosShift(api_key="YOUR_API_KEY") 184 | >>> logos_shift = LogosShift(api_key="YOUR_API_KEY", filename="api_calls.log") 185 | """ 186 | self.max_entries = max_entries 187 | self.bohita_client = ( 188 | bohita_client if bohita_client else BohitaClient(api_key=api_key) 189 | ) 190 | self.buffer_A, self.buffer_B = deque(), deque() 191 | self.active_buffer = self.buffer_A 192 | self.lock = threading.Lock() 193 | self.buffer_manager = BufferManager( 194 | bohita_client=self.bohita_client, 195 | check_seconds=check_seconds, 196 | filename=filename, 197 | ) 198 | self.buffer_manager.register_buffer(self.buffer_A, self.lock) 199 | self.buffer_manager.register_buffer(self.buffer_B, self.lock) 200 | self.router = router if router else APIRouter(bohita_client=self.bohita_client) 201 | logger.info("LogosShift: Initialized.") 202 | 203 | def handle_data(self, result, dataset, args, kwargs, metadata): 204 | if isinstance(result, dict): 205 | result["bohita_logos_shift_id"] = str(uuid.uuid4()) 206 | data = { 207 | "input": (args, kwargs), 208 | "output": result, 209 | "dataset": dataset, 210 | "metadata": metadata, 211 | } 212 | with self.lock: 213 | # Switch buffers if necessary 214 | if len(self.active_buffer) >= self.max_entries: 215 | logger.debug("Switching buffer") 216 | if self.active_buffer is self.buffer_A: 217 | self.active_buffer = self.buffer_B 218 | else: 219 | self.active_buffer = self.buffer_A 220 | self.active_buffer.append(data) 221 | logger.debug("Added data to active buffer") 222 | return result 223 | 224 | def _wrap_common_sync(self, func, dataset, *args, **kwargs): 225 | logger.debug( 226 | f"LogosShift: Wrapping function {func.__name__}. 
Args: {args}, Kwargs: {kwargs}"
227 |         )
228 |         metadata = kwargs.pop("logos_shift_metadata", {})
229 |         metadata["function"] = func.__name__
230 | 
231 |         if self.router:
232 |             func_to_call = self.router.get_api_to_call(
233 |                 func, metadata.get("user_id", None)
234 |             )
235 |         else:
236 |             func_to_call = func
237 | 
238 |         return func_to_call, args, kwargs, metadata
239 | 
240 |     def wrap_function(self, func, dataset, *args, **kwargs):
241 |         func_to_call, args, kwargs, metadata = self._wrap_common_sync(
242 |             func, dataset, *args, **kwargs
243 |         )
244 |         result = func_to_call(*args, **kwargs)
245 |         return self.handle_data(result, dataset, args, kwargs, metadata)
246 | 
247 |     def __call__(self, dataset="default"):
248 |         def wrapper(func):
249 |             async def async_inner(*args, **kwargs):
250 |                 return await self._wrap_function_async(func, dataset, *args, **kwargs)
251 | 
252 |             def sync_inner(*args, **kwargs):
253 |                 return self.wrap_function(func, dataset, *args, **kwargs)
254 | 
255 |             if asyncio.iscoroutinefunction(func):
256 |                 return async_inner
257 |             else:
258 |                 return sync_inner
259 | 
260 |         return wrapper
261 | 
262 |     async def _handle_data_async(self, result, dataset, args, kwargs, metadata):
263 |         return self.handle_data(result, dataset, args, kwargs, metadata)
264 | 
265 |     async def _wrap_function_async(self, func, dataset, *args, **kwargs):
266 |         # Same bookkeeping as the sync path, but route through the async router API
267 |         metadata = kwargs.pop("logos_shift_metadata", {})
268 |         metadata["function"] = func.__name__
269 |         if self.router:
270 |             func_to_call = await self.router.get_api_to_call_async(
271 |                 func, metadata.get("user_id", None)
272 |             )
273 |         else:
274 |             func_to_call = func
275 |         result = await func_to_call(*args, **kwargs)
276 |         return await self._handle_data_async(result, dataset, args, kwargs, metadata)
277 | 
278 |     def provide_feedback(self, bohita_logos_shift_id, feedback):
279 |         """
280 |         Provides feedback for a specific function call.
281 | 
282 |         Args:
283 |             bohita_logos_shift_id (str): The unique identifier for the function call.
284 |             feedback (str): The feedback string.
285 | 
286 |         Examples:
287 |             >>> logos_shift.provide_feedback("unique_id_123", "success")
288 |         """
289 |         feedback_data = {
290 |             "bohita_logos_shift_id": bohita_logos_shift_id,
291 |             "feedback": feedback,
292 |             "dataset": "unknown",
293 |         }
294 |         with self.lock:
295 |             self.active_buffer.append(feedback_data)
296 | 
--------------------------------------------------------------------------------