├── tests
│   ├── __init__.py
│   ├── unit
│   │   └── __init__.py
│   ├── integration
│   │   └── __init__.py
│   ├── test_setup_validation.py
│   └── conftest.py
├── examples
│   ├── __init__.py
│   ├── embedding
│   │   ├── __init__.py
│   │   └── huggingface_tei_example.py
│   └── flask
│       ├── __init__.py
│       ├── llms_cache
│       │   ├── __init__.py
│       │   ├── register.py
│       │   ├── data_query.py
│       │   ├── data_insert.py
│       │   └── data_query_long.py
│       └── multi_cache
│           ├── __init__.py
│           ├── register.py
│           ├── remove.py
│           ├── data_query.py
│           └── data_insert.py
├── model
│   ├── .gitignore
│   ├── clip_zh
│   │   └── __init__.py
│   └── text2vec-base-chinese
│       ├── sentence_bert_config.json
│       ├── special_tokens_map.json
│       ├── modules.json
│       ├── tokenizer_config.json
│       ├── logs.txt
│       └── config.json
├── modelcache
│   ├── adapter
│   │   ├── __init__.py
│   │   ├── adapter_register.py
│   │   ├── adapter_remove.py
│   │   ├── adapter.py
│   │   └── adapter_insert.py
│   ├── manager
│   │   ├── __init__.py
│   │   ├── eviction
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── memory_cache.py
│   │   │   ├── arc_cache.py
│   │   │   └── wtinylfu_cache.py
│   │   ├── scalar_data
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── vector_data
│   │   │   ├── __init__.py
│   │   │   ├── faiss.py
│   │   │   ├── chroma.py
│   │   │   ├── redis.py
│   │   │   └── base.py
│   │   ├── object_data
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   └── eviction_manager.py
│   ├── utils
│   │   ├── env_config.py
│   │   ├── index_util.py
│   │   ├── log.py
│   │   ├── time.py
│   │   ├── model_filter.py
│   │   ├── dependency_control.py
│   │   ├── lazy_import.py
│   │   ├── error.py
│   │   └── __init__.py
│   ├── processor
│   │   ├── __init__.py
│   │   ├── post.py
│   │   └── pre.py
│   ├── config
│   │   ├── chromadb_config.ini
│   │   ├── redis_config.ini
│   │   ├── milvus_config.ini
│   │   ├── elasticsearch_config.ini
│   │   └── mysql_config.ini
│   ├── similarity_evaluation
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── exact_match.py
│   │   └── distance.py
│   ├── __init__.py
│   ├── embedding
│   │   ├── __init__.py
│   │   ├── fasttext.py
│   │   ├── huggingface_tei.py
│   │   ├── bge_m3.py
│   │   ├── huggingface.py
│   │   ├── llmEmb.py
│   │   ├── paddlenlp.py
│   │   ├── onnx.py
│   │   ├── timm_embedding.py
│   │   ├── embedding_dispatcher.py
│   │   ├── base.py
│   │   └── data2vec.py
│   └── report.py
├── modelcache_mm
│   ├── adapter
│   │   ├── __init__.py
│   │   ├── adapter_register.py
│   │   ├── adapter_remove.py
│   │   ├── adapter.py
│   │   └── adapter_insert.py
│   ├── processor
│   │   ├── __init__.py
│   │   ├── post.py
│   │   └── pre.py
│   ├── utils
│   │   ├── env_config.py
│   │   ├── cache_func.py
│   │   ├── log.py
│   │   ├── time.py
│   │   ├── dependency_control.py
│   │   ├── lazy_import.py
│   │   ├── index_util.py
│   │   ├── error.py
│   │   └── __init__.py
│   ├── config
│   │   ├── chromadb_config.ini
│   │   ├── milvus_config.ini
│   │   ├── redis_config.ini
│   │   ├── elasticsearch_config.ini
│   │   └── mysql_config.ini
│   ├── embedding
│   │   ├── string.py
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── clip.py
│   │   └── timm.py
│   ├── __init__.py
│   ├── manager
│   │   ├── __init__.py
│   │   ├── object_data
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── scalar_data
│   │   │   ├── __init__.py
│   │   │   ├── manager.py
│   │   │   └── base.py
│   │   ├── eviction
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── vector_data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── faiss.py
│   │   │   ├── chroma.py
│   │   │   ├── manager.py
│   │   │   └── redis.py
│   │   ├── factory.py
│   │   └── eviction_manager.py
│   ├── similarity_evaluation
│   │   ├── __init__.py
│   │   ├── similarity_evaluation.py
│   │   ├── exact_match.py
│   │   └── distance.py
│   ├── config.py
│   ├── report.py
│   └── core.py
├── data
│   ├── milvus
│   │   ├── user.yaml
│   │   └── embedEtcd.yaml
│   └── mysql
│       ├── my.conf
│       └── init
│           └── init.sql
├── docs
│   ├── codefuse-LOGO.png
│   ├── time-cost-comparison.webp
│   ├── modelcache_modules_20231114.png
│   ├── modelcache_modules_20240409.png
│   ├── cache-service-cost-time-distribution.webp
│   ├── script
│   │   └── get_input_embedding_script.py
│   ├── 4.create-cache.md
│   ├── 2.model-cache-features.md
│   ├── 
3.model-cache-quick-start.md └── 1.what-is-model-cache.md ├── Dockerfile ├── requirements.txt ├── flask4modelcache.py ├── flask4modelcache_demo.py ├── fastapi4modelcache.py ├── fastapi4modelcache_demo.py ├── reference_doc └── create_table.sql ├── docker-compose.yaml ├── websocket4modelcache.py ├── websocket4modelcache_demo.py ├── pyproject.toml ├── .gitignore └── mulicache-readme-cn.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /model/.gitignore: -------------------------------------------------------------------------------- 1 | *.tflite 2 | text2vec-base-chinese/* -------------------------------------------------------------------------------- /model/clip_zh/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/flask/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/manager/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/utils/env_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/processor/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache_mm/processor/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache_mm/utils/env_config.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /data/milvus/user.yaml: -------------------------------------------------------------------------------- 1 | # Extra config to override default milvus.yaml -------------------------------------------------------------------------------- /examples/flask/llms_cache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /modelcache/manager/scalar_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/config/chromadb_config.ini: -------------------------------------------------------------------------------- 1 | [chromadb] 2 | persist_directory='' 3 | -------------------------------------------------------------------------------- /modelcache/manager/object_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /modelcache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /modelcache_mm/config/chromadb_config.ini: -------------------------------------------------------------------------------- 1 | [chromadb] 2 | persist_directory=./chromadb 3 | -------------------------------------------------------------------------------- /docs/codefuse-LOGO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/main/docs/codefuse-LOGO.png -------------------------------------------------------------------------------- /modelcache/config/redis_config.ini: -------------------------------------------------------------------------------- 1 | [redis] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' 6 | -------------------------------------------------------------------------------- /modelcache_mm/config/milvus_config.ini: -------------------------------------------------------------------------------- 1 | [milvus] 2 | host = '' 3 | 
port = '' 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache_mm/utils/cache_func.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | def cache_all(*_, **__): 3 | return True -------------------------------------------------------------------------------- /modelcache/config/milvus_config.ini: -------------------------------------------------------------------------------- 1 | [milvus] 2 | host = localhost 3 | port = 19530 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache_mm/config/redis_config.ini: -------------------------------------------------------------------------------- 1 | [redis] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' 6 | -------------------------------------------------------------------------------- /docs/time-cost-comparison.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/main/docs/time-cost-comparison.webp -------------------------------------------------------------------------------- /modelcache/config/elasticsearch_config.ini: -------------------------------------------------------------------------------- 1 | [elasticsearch] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /model/text2vec-base-chinese/sentence_bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_seq_length": 128, 3 | "do_lower_case": false 4 | } 5 | -------------------------------------------------------------------------------- /modelcache_mm/config/elasticsearch_config.ini: -------------------------------------------------------------------------------- 1 | [elasticsearch] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache_mm/embedding/string.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def to_embeddings(data, **_): 5 | return data 6 | -------------------------------------------------------------------------------- /docs/modelcache_modules_20231114.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/main/docs/modelcache_modules_20231114.png -------------------------------------------------------------------------------- /docs/modelcache_modules_20240409.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/main/docs/modelcache_modules_20240409.png -------------------------------------------------------------------------------- /modelcache_mm/config/mysql_config.ini: -------------------------------------------------------------------------------- 1 | [mysql] 2 | host = '' 3 | port = '' 4 | username = '' 5 | password = '' 6 | database = '' 7 | -------------------------------------------------------------------------------- /modelcache/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.embedding.base import EmbeddingModel, MetricType, BaseEmbedding 
-------------------------------------------------------------------------------- /docs/cache-service-cost-time-distribution.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/main/docs/cache-service-cost-time-distribution.webp -------------------------------------------------------------------------------- /model/text2vec-base-chinese/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- /modelcache/config/mysql_config.ini: -------------------------------------------------------------------------------- 1 | [mysql] 2 | host = localhost 3 | port = 3306 4 | username = modelcache 5 | password = modelcache 6 | database = modelcache 7 | -------------------------------------------------------------------------------- /data/mysql/my.conf: -------------------------------------------------------------------------------- 1 | [mysqld] 2 | character-set-server=utf8mb4 3 | collation-server=utf8mb4_unicode_ci 4 | 5 | [client] 6 | default-character-set=utf8mb4 7 | 8 | [mysql] 9 | default-character-set=utf8mb4 -------------------------------------------------------------------------------- /modelcache_mm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.core import Cache 3 | from modelcache_mm.core import cache 4 | from modelcache_mm.config import Config 5 | import modelcache_mm.adapter 6 | -------------------------------------------------------------------------------- /data/milvus/embedEtcd.yaml: -------------------------------------------------------------------------------- 1 | listen-client-urls: http://0.0.0.0:2379 2 | advertise-client-urls: http://0.0.0.0:2379 3 | quota-backend-bytes: 4294967296 4 | auto-compaction-mode: revision 5 | auto-compaction-retention: '1000' 6 | -------------------------------------------------------------------------------- /modelcache/utils/index_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def get_index_name(model): 5 | return 'modelcache' + '_' + model 6 | 7 | 8 | def get_index_prefix(model): 9 | return 'prefix' + '_' + model 10 | -------------------------------------------------------------------------------- /modelcache/utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | FORMAT = '%(asctime)s - %(thread)d - %(filename)s-%(module)s:%(lineno)s - %(levelname)s: %(message)s' 5 | logging.basicConfig(format=FORMAT) 6 | 7 | modelcache_log = logging.getLogger('modelcache') 8 | -------------------------------------------------------------------------------- /modelcache_mm/utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | FORMAT = '%(asctime)s - %(thread)d - %(filename)s-%(module)s:%(lineno)s - %(levelname)s: %(message)s' 5 | logging.basicConfig(format=FORMAT) 6 | 7 | modelcache_log = logging.getLogger('modelcache') 8 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | clip = LazyImport("clip", globals(), "modelcache_mm.embedding.clip") 4 | 5 | 6 | def Clip2Vec(model="damo/multi-modal_clip-vit-base-patch16_zh"): 7 | return clip.ClipAudio(model) 8 | -------------------------------------------------------------------------------- /modelcache_mm/manager/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.manager.scalar_data import CacheBase 3 | from modelcache_mm.manager.vector_data import VectorBase 4 | from modelcache_mm.manager.object_data import ObjectBase 5 | from modelcache_mm.manager.factory import get_data_manager 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim-bookworm 2 | 3 | WORKDIR /home/user 4 | 5 | COPY ./requirements.txt /home/user/docker_requirements.txt 6 | 7 | RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple 8 | RUN pip install -r /home/user/docker_requirements.txt --retries 5 --timeout 120 9 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/modules.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "idx": 0, 4 | "name": "0", 5 | "path": "", 6 | "type": "sentence_transformers.models.Transformer" 7 | }, 8 | { 9 | "idx": 1, 10 | "name": "1", 11 | "path": "1_Pooling", 12 | "type": "sentence_transformers.models.Pooling" 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /modelcache_mm/manager/object_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | object_manager = LazyImport( 4 | "object_manager", globals(), "modelcache.manager.object_data.manager" 5 | ) 6 | 7 | 8 | def ObjectBase(name: str, **kwargs): 9 | return object_manager.ObjectBase.get(name, **kwargs) 10 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "name_or_path": "hfl/chinese-macbert-base", "tokenizer_class": "BertTokenizer"} 2 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils.lazy_import import LazyImport 3 | scalar_manager = LazyImport( 4 | "scalar_manager", globals(), "modelcache_mm.manager.scalar_data.manager" 5 | ) 6 | 7 | 8 | def CacheBase(name: str, **kwargs): 9 | return scalar_manager.CacheBase.get(name, **kwargs) 10 | -------------------------------------------------------------------------------- /modelcache/processor/post.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | import random 3 | from typing import List, Any 4 | 5 | 6 | def random_one(messages: List[Any]) -> Any: 7 | return random.choice(messages) 8 | 9 | 10 | def first(messages: List[Any]) -> Any: 11 | return messages[0] 12 | 13 | 14 | def nop(messages: List[Any]) -> Any: 15 | return messages 16 | -------------------------------------------------------------------------------- /modelcache_mm/manager/eviction/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | 4 | eviction_manager = LazyImport( 5 | "eviction_manager", globals(), "modelcache.manager.eviction.manager" 6 | ) 7 | 8 | 9 | def EvictionBase(name: str, **kwargs): 10 | return eviction_manager.EvictionBase.get(name, **kwargs) 11 | -------------------------------------------------------------------------------- /modelcache_mm/processor/post.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | from typing import List, Any 4 | 5 | 6 | def random_one(messages: List[Any]) -> Any: 7 | return random.choice(messages) 8 | 9 | 10 | def first(messages: List[Any]) -> Any: 11 | return messages[0] 12 | 13 | 14 | def nop(messages: List[Any]) -> Any: 15 | return messages 16 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils.lazy_import import LazyImport 3 | 4 | vector_manager = LazyImport( 5 | "vector_manager", globals(), "modelcache_mm.manager.vector_data.manager" 6 | ) 7 | 8 | 9 | def VectorBase(name: str, **kwargs): 10 | return vector_manager.VectorBase.get(name, **kwargs) 11 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseEmbedding(metaclass=ABCMeta): 6 | """ 7 | _Embedding base. 
8 | """ 9 | 10 | @abstractmethod 11 | def to_embeddings(self, data, **kwargs): 12 | pass 13 | 14 | @property 15 | @abstractmethod 16 | def dimension(self) -> int: 17 | return 0 18 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation 3 | from modelcache.utils.lazy_import import LazyImport 4 | 5 | exact_match = LazyImport( 6 | "exact_match", globals(), "modelcache.similarity_evaluation.exact_match" 7 | ) 8 | 9 | 10 | def ExactMatchEvaluation(): 11 | return exact_match.ExactMatchEvaluation() 12 | -------------------------------------------------------------------------------- /docs/script/get_input_embedding_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | from transformers import AutoModelForCausalLM 5 | 6 | 7 | model_path = '' 8 | device = torch.device('cuda') 9 | model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device) 10 | embedding_weights = model.get_input_embeddings().weight.to('cpu').detach().numpy() 11 | np.save('gpt-neox-embedding.npy', embedding_weights) 12 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter_register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | 4 | 5 | async def adapt_register(*args, **kwargs): 6 | chat_cache = kwargs.pop("cache_obj") 7 | model = kwargs.pop("model", None) 8 | if model is None or len(model) == 0: 9 | return ValueError('') 10 | 11 | register_resp = await asyncio.to_thread( 12 | chat_cache.data_manager.create_index, 13 | model 14 | ) 15 | 16 | return register_resp 17 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Tuple, Dict, Any 4 | 5 | 6 | class SimilarityEvaluation(metaclass=ABCMeta): 7 | @abstractmethod 8 | def evaluation( 9 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **kwargs 10 | ) -> float: 11 | pass 12 | 13 | @abstractmethod 14 | def range(self) -> Tuple[float, float]: 15 | pass 16 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter_register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm import cache 3 | 4 | 5 | def adapt_register(*args, **kwargs): 6 | chat_cache = kwargs.pop("cache_obj", cache) 7 | model = kwargs.pop("model", None) 8 | type = kwargs.pop("type", None) 9 | if model is None or len(model) == 0: 10 | return ValueError('') 11 | 12 | register_resp = chat_cache.data_manager.create_index(model, type) 13 | return register_resp 14 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/similarity_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import 
Tuple, Dict, Any 4 | 5 | 6 | class SimilarityEvaluation(metaclass=ABCMeta): 7 | @abstractmethod 8 | def evaluation( 9 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **kwargs 10 | ) -> float: 11 | pass 12 | 13 | @abstractmethod 14 | def range(self) -> Tuple[float, float]: 15 | pass 16 | -------------------------------------------------------------------------------- /modelcache_mm/manager/eviction/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class EvictionBase(metaclass=ABCMeta): 7 | """ 8 | Eviction base. 9 | """ 10 | 11 | @abstractmethod 12 | def put(self, objs: List[Any]): 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: Any): 17 | pass 18 | 19 | @property 20 | @abstractmethod 21 | def policy(self) -> str: 22 | pass 23 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class EvictionBase(metaclass=ABCMeta): 7 | """ 8 | Eviction base. 9 | """ 10 | 11 | @abstractmethod 12 | def put(self, objs: List[Any], model:str): 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: Any, model:str): 17 | pass 18 | 19 | @property 20 | @abstractmethod 21 | def policy(self) -> str: 22 | pass 23 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | register index for redis 4 | """ 5 | import json 6 | import requests 7 | 8 | 9 | def run(): 10 | url = 'http://127.0.0.1:5000/modelcache' 11 | type = 'register' 12 | scope = {"model": "CODEGPT-1117"} 13 | data = {'type': type, 'scope': scope} 14 | headers = {"Content-Type": "application/json"} 15 | res = requests.post(url, headers=headers, json=json.dumps(data)) 16 | res_text = res.text 17 | 18 | 19 | if __name__ == '__main__': 20 | run() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools==5.3.1 2 | DBUtils==1.4 3 | Flask==3.1.1 4 | numpy==2.2.6 5 | onnxruntime==1.22.0 6 | openai==0.28.1 7 | pymilvus==2.5.9 8 | PyMySQL==1.1.1 9 | Requests==2.32.3 10 | torch==2.7.0 11 | transformers==4.52.4 12 | faiss-cpu==1.11.0 13 | redis==5.0.1 14 | modelscope==1.26.0 15 | fastapi==0.115.9 16 | uvicorn==0.34.3 17 | chromadb==1.0.12 18 | elasticsearch==7.10.0 19 | snowflake-id==1.0.2 20 | flagembedding==1.3.5 21 | cryptography==45.0.2 22 | sentence-transformers==4.1.0 23 | pytest==8.0 24 | readerwriterlock==1.0.9 25 | -------------------------------------------------------------------------------- /examples/embedding/huggingface_tei_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | sys.path.append(".") 4 | from modelcache.embedding.huggingface_tei import HuggingfaceTEI 5 | 6 | ''' 7 | run tei server: 8 | text-embeddings-router --model-id BAAI/bge-large-zh-v1.5 --port 8080 9 | ''' 10 | 11 | def run(): 12 | tei_instance = HuggingfaceTEI('http://127.0.0.1:8080/v1/embeddings', 'BAAI/bge-large-zh-v1.5') 13 | 
print('dimension', tei_instance.dimension) 14 | print('embedding', tei_instance.to_embeddings('hello')) 15 | 16 | if __name__ == '__main__': 17 | run() -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/exact_match.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation.base import SimilarityEvaluation 4 | 5 | 6 | class ExactMatchEvaluation(SimilarityEvaluation): 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def evaluation( 12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 13 | ) -> float: 14 | return 1 if cache_dict["question"] == src_dict["question"] else 0 15 | 16 | def range(self) -> Tuple[float, float]: 17 | return 0, 1 18 | -------------------------------------------------------------------------------- /modelcache/manager/object_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class ObjectBase(ABC): 7 | """ 8 | Object storage base. 9 | """ 10 | 11 | @abstractmethod 12 | def put(self, obj: Any) -> str: 13 | pass 14 | 15 | @abstractmethod 16 | def get_access_link(self, obj: str) -> str: 17 | pass 18 | 19 | @abstractmethod 20 | def delete(self, to_delete: List[str]): 21 | pass 22 | 23 | @staticmethod 24 | def get(name: str) -> Any: 25 | pass 26 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/exact_match.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation 4 | 5 | 6 | class ExactMatchEvaluation(SimilarityEvaluation): 7 | def __init__(self): 8 | pass 9 | 10 | def evaluation( 11 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 12 | ) -> float: 13 | return 1 if cache_dict["question"] == src_dict["question"] else 0 14 | 15 | def range(self) -> Tuple[float, float]: 16 | return 0, 1 17 | -------------------------------------------------------------------------------- /modelcache_mm/manager/object_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class ObjectBase(ABC): 7 | """ 8 | Object storage base. 
9 | """ 10 | 11 | @abstractmethod 12 | def put(self, obj: Any) -> str: 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: str) -> Any: 17 | pass 18 | 19 | @abstractmethod 20 | def get_access_link(self, obj: str) -> str: 21 | pass 22 | 23 | @abstractmethod 24 | def delete(self, to_delete: List[str]): 25 | pass 26 | -------------------------------------------------------------------------------- /modelcache/utils/time.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | 4 | def time_cal(func, func_name=None, report_func=None, **kwargs): 5 | cache = kwargs.pop("cache_obj") 6 | def inner(*args, **kwargs): 7 | time_start = time.time() 8 | res = func(*args, **kwargs) 9 | delta_time = time.time() - time_start 10 | if cache.log_time_func: 11 | cache.log_time_func( 12 | func.__name__ if func_name is None else func_name, delta_time 13 | ) 14 | if report_func is not None: 15 | report_func(delta_time) 16 | return res 17 | 18 | return inner 19 | 20 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | register index for redis 4 | """ 5 | import json 6 | import requests 7 | 8 | 9 | def run(): 10 | url = 'http://127.0.0.1:5000/multicache' 11 | request_type = 'register' 12 | scope = {"model": "multimodal_test"} 13 | type = 'IMG_TEXT' 14 | data = {'request_type': request_type, 'scope': scope, 'type': type} 15 | headers = {"Content-Type": "application/json"} 16 | res = requests.post(url, headers=headers, json=json.dumps(data)) 17 | res_text = res.text 18 | print('res_text: {}'.format(res_text)) 19 | 20 | 21 | if __name__ == '__main__': 22 | run() -------------------------------------------------------------------------------- /model/text2vec-base-chinese/logs.txt: -------------------------------------------------------------------------------- 1 | Epoch:0 Valid| corr: 0.794410 2 | Epoch:0 Valid| corr: 0.691819 3 | Epoch:1 Valid| corr: 0.722749 4 | Epoch:2 Valid| corr: 0.735054 5 | Epoch:3 Valid| corr: 0.738295 6 | Epoch:4 Valid| corr: 0.739411 7 | Test | corr: 0.679971 8 | Epoch:0 Valid| corr: 0.817416 9 | Epoch:1 Valid| corr: 0.832376 10 | Epoch:2 Valid| corr: 0.842308 11 | Epoch:3 Valid| corr: 0.843520 12 | Epoch:4 Valid| corr: 0.841837 13 | Test | corr: 0.793495 14 | Epoch:0 Valid| corr: 0.814648 15 | Epoch:1 Valid| corr: 0.831609 16 | Epoch:2 Valid| corr: 0.841678 17 | Epoch:3 Valid| corr: 0.842387 18 | Epoch:4 Valid| corr: 0.841435 19 | Test | corr: 0.794840 20 | -------------------------------------------------------------------------------- /modelcache_mm/utils/time.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | from modelcache import cache 4 | 5 | 6 | def time_cal(func, func_name=None, report_func=None): 7 | def inner(*args, **kwargs): 8 | time_start = time.time() 9 | res = func(*args, **kwargs) 10 | delta_time = time.time() - time_start 11 | if cache.config.log_time_func: 12 | cache.config.log_time_func( 13 | func.__name__ if func_name is None else func_name, delta_time 14 | ) 15 | if report_func is not None: 16 | report_func(delta_time) 17 | return res 18 | 19 | return inner 20 | 21 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/data_query.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | 5 | 6 | def run(): 7 | url = 'http://127.0.0.1:5000/modelcache' 8 | type = 'query' 9 | scope = {"model": "CODEGPT-1117"} 10 | query = [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}] 11 | data = {'type': type, 'scope': scope, 'query': query} 12 | 13 | headers = {"Content-Type": "application/json"} 14 | res = requests.post(url, headers=headers, json=json.dumps(data)) 15 | res_text = res.text 16 | 17 | print("data_query:", res.status_code, res_text) 18 | 19 | if __name__ == '__main__': 20 | run() 21 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/remove.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | register index for redis 4 | """ 5 | import json 6 | import requests 7 | 8 | 9 | def run(): 10 | url = 'http://127.0.0.1:5000/multicache' 11 | request_type = 'remove' 12 | scope = {"model": "multimodal_test"} 13 | remove_type = 'truncate_by_model' 14 | data = {'request_type': request_type, 'scope': scope, 'remove_type': remove_type} 15 | 16 | headers = {"Content-Type": "application/json"} 17 | res = requests.post(url, headers=headers, json=json.dumps(data)) 18 | res_text = res.text 19 | print('res_text: {}'.format(res_text)) 20 | 21 | 22 | if __name__ == '__main__': 23 | run() 24 | -------------------------------------------------------------------------------- /modelcache/utils/model_filter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | def model_blacklist_filter(model, request_type): 3 | black_list = ['DI_COPILOT_SECOND', 'DI_COPILOT_LAB', 'DI_COPILOT_THIRD'] 4 | result = None 5 | if model in black_list: 6 | if request_type == 'query': 7 | result = {"errorCode": 105, 8 | "errorDesc": "model: {} in blacklist".format(model), 9 | "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 10 | elif request_type == 'insert': 11 | result = {"errorCode": 305, "errorDesc": "model: {} in blacklist".format(model), "writeStatus": ""} 12 | 13 | return result 14 | 15 | 16 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/data_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | 5 | 6 | def run(): 7 | url = 'http://127.0.0.1:5000/modelcache' 8 | type = 'insert' 9 | scope = {"model": "CODEGPT-1117"} 10 | chat_info = [{"query": [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}], 11 | "answer": "你好,我是智能助手,请问有什么能帮您!"}] 12 | data = {'type': type, 'scope': scope, 'chat_info': chat_info} 13 | headers = {"Content-Type": "application/json"} 14 | res = requests.post(url, headers=headers, json=json.dumps(data)) 15 | res_text = res.text 16 | 17 | print("data_insert:", res.status_code, res_text) 18 | 19 | if __name__ == '__main__': 20 | run() 21 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/data_query_long.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | 5 | 6 | def run(): 7 | url = 'http://127.0.0.1:5000/modelcache' 8 | type = 'query' 9 | scope = {"model": 
"CODEGPT-1109"} 10 | system_conten = """ 11 | """ 12 | user_content = """ 13 | """ 14 | 15 | query = [{"role": "system", "content": system_conten}, {"role": "user", "content": user_content}] 16 | data = {'type': type, 'scope': scope, 'query': query} 17 | 18 | headers = {"Content-Type": "application/json"} 19 | res = requests.post(url, headers=headers, json=json.dumps(data)) 20 | res_text = res.text 21 | 22 | print("data_query_long:", res.status_code, res_text) 23 | 24 | if __name__ == '__main__': 25 | run() 26 | -------------------------------------------------------------------------------- /modelcache/utils/dependency_control.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import subprocess 3 | from modelcache.utils.error import PipInstallError 4 | from modelcache.utils.log import modelcache_log 5 | 6 | 7 | def prompt_install(package: str, warn: bool = False): # pragma: no cover 8 | """ 9 | Function used to prompt user to install a package. 10 | """ 11 | cmd = f"pip install {package}" 12 | try: 13 | if warn and input(f"Install {package}? Y/n: ") != "Y": 14 | raise ModuleNotFoundError(f"No module named {package}") 15 | subprocess.check_call(cmd, shell=True) 16 | modelcache_log.info("%s installed successfully!", package) 17 | except subprocess.CalledProcessError as e: 18 | raise PipInstallError(package) from e 19 | -------------------------------------------------------------------------------- /modelcache_mm/utils/dependency_control.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import subprocess 3 | from modelcache.utils.error import PipInstallError 4 | from modelcache.utils.log import modelcache_log 5 | 6 | 7 | def prompt_install(package: str, warn: bool = False): # pragma: no cover 8 | """ 9 | Function used to prompt user to install a package. 10 | """ 11 | cmd = f"pip install {package}" 12 | try: 13 | if warn and input(f"Install {package}? Y/n: ") != "Y": 14 | raise ModuleNotFoundError(f"No module named {package}") 15 | subprocess.check_call(cmd, shell=True) 16 | modelcache_log.info("%s installed successfully!", package) 17 | except subprocess.CalledProcessError as e: 18 | raise PipInstallError(package) from e 19 | -------------------------------------------------------------------------------- /modelcache_mm/utils/lazy_import.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib 3 | from types import ModuleType 4 | 5 | 6 | class LazyImport(ModuleType): 7 | """ 8 | Lazily import a module. 
9 | """ 10 | def __init__(self, local_name, parent_module_globals, name): 11 | self._local_name = local_name 12 | self._parent_module_globals = parent_module_globals 13 | super().__init__(name) 14 | 15 | def _load(self): 16 | module = importlib.import_module(self.__name__) 17 | self._parent_module_globals[self._local_name] = module 18 | self.__dict__.update(module.__dict__) 19 | return module 20 | 21 | def __getattr__(self, item): 22 | module = self._load() 23 | return getattr(module, item) 24 | 25 | def __dir__(self): 26 | module = self._load() 27 | return dir(module) 28 | -------------------------------------------------------------------------------- /modelcache/utils/lazy_import.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib 3 | from types import ModuleType 4 | 5 | 6 | class LazyImport(ModuleType): 7 | """ 8 | Lazily import a module. 9 | """ 10 | 11 | def __init__(self, local_name, parent_module_globals, name): 12 | self._local_name = local_name 13 | self._parent_module_globals = parent_module_globals 14 | super().__init__(name) 15 | 16 | def _load(self): 17 | module = importlib.import_module(self.__name__) 18 | self._parent_module_globals[self._local_name] = module 19 | self.__dict__.update(module.__dict__) 20 | return module 21 | 22 | def __getattr__(self, item): 23 | module = self._load() 24 | return getattr(module, item) 25 | 26 | def __dir__(self): 27 | module = self._load() 28 | return dir(module) 29 | -------------------------------------------------------------------------------- /modelcache_mm/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Optional, Callable, List 3 | from modelcache.utils.error import CacheError 4 | 5 | 6 | class Config: 7 | 8 | def __init__( 9 | self, 10 | log_time_func: Optional[Callable[[str, float], None]] = None, 11 | similarity_threshold: float = 0.95, 12 | similarity_threshold_long: float = 0.95, 13 | prompts: Optional[List[str]] = None 14 | ): 15 | if similarity_threshold < 0 or similarity_threshold > 1: 16 | raise CacheError( 17 | "Invalid the similarity threshold param, reasonable range: 0-1" 18 | ) 19 | self.log_time_func = log_time_func 20 | self.similarity_threshold = similarity_threshold 21 | self.similarity_threshold_long = similarity_threshold_long 22 | self.prompts = prompts 23 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/distance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation.base import SimilarityEvaluation 4 | 5 | class SearchDistanceEvaluation(SimilarityEvaluation): 6 | def __init__(self, max_distance=4.0, positive=False): 7 | self.max_distance = max_distance 8 | self.positive = positive 9 | 10 | def evaluation( 11 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 12 | ) -> float: 13 | distance, _ = cache_dict["search_result"] 14 | if distance < 0: 15 | distance = 0 16 | elif distance > self.max_distance: 17 | distance = self.max_distance 18 | if self.positive: 19 | return distance 20 | return self.max_distance - distance 21 | 22 | def range(self) -> Tuple[float, float]: 23 | return 0.0, self.max_distance 24 | -------------------------------------------------------------------------------- 
/modelcache_mm/similarity_evaluation/distance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation import SimilarityEvaluation 4 | 5 | 6 | class SearchDistanceEvaluation(SimilarityEvaluation): 7 | def __init__(self, max_distance=4.0, positive=False): 8 | self.max_distance = max_distance 9 | self.positive = positive 10 | 11 | def evaluation( 12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 13 | ) -> float: 14 | distance, _ = cache_dict["search_result"] 15 | if distance < 0: 16 | distance = 0 17 | elif distance > self.max_distance: 18 | distance = self.max_distance 19 | if self.positive: 20 | return distance 21 | return self.max_distance - distance 22 | 23 | def range(self) -> Tuple[float, float]: 24 | return 0.0, self.max_distance 25 | -------------------------------------------------------------------------------- /modelcache/embedding/fasttext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | from modelcache.utils import import_fasttext 5 | from modelcache.embedding.base import BaseEmbedding 6 | import_fasttext() 7 | import fasttext.util 8 | 9 | 10 | class FastText(BaseEmbedding): 11 | def __init__(self, model: str = "en", dim: int = None): 12 | self.model_path = os.path.abspath(fasttext.util.download_model(model)) 13 | self.ft = fasttext.load_model(self.model_path) 14 | 15 | if dim: 16 | fasttext.util.reduce_model(self.ft, dim) 17 | self.__dimension = self.ft.get_dimension() 18 | 19 | def to_embeddings(self, data, **_): 20 | assert isinstance(data, str), "Only allow string as input." 21 | emb = self.ft.get_sentence_vector(data) 22 | return np.array(emb).astype("float32") 23 | 24 | @property 25 | def dimension(self): 26 | return self.__dimension 27 | 28 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/data_query.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | import uuid 5 | import time 6 | 7 | 8 | def run(): 9 | url = 'http://127.0.0.1:5000/multicache' 10 | request_type = 'query' 11 | UUID = str(uuid.uuid1()) + "==>" + str(time.time()) 12 | scope = {"model": "multimodal_test"} 13 | img_data = "https://img0.baidu.com/it/u=1436460262,4166266890&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=282" 14 | query = {'text': ['父母带着孩子来这个地方可能会有什么顾虑'], 15 | 'imageRaw': '', 16 | 'imageUrl': img_data, 17 | 'multiType': 'IMG_TEXT'} 18 | 19 | data = {'request_type': request_type, 'scope': scope, 'query': query, 'UUID': UUID} 20 | 21 | headers = {"Content-Type": "application/json"} 22 | res = requests.post(url, headers=headers, json=json.dumps(data)) 23 | res_text = res.text 24 | print('res_text: {}'.format(res_text)) 25 | 26 | 27 | if __name__ == '__main__': 28 | run() 29 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter_remove.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm import cache 3 | from modelcache_mm.utils.error import NotInitError 4 | 5 | 6 | def adapt_remove(*args, **kwargs): 7 | chat_cache = kwargs.pop("cache_obj", cache) 8 | model = kwargs.pop("model", None) 9 | remove_type = kwargs.pop("remove_type", None) 10 | 
require_object_store = kwargs.pop("require_object_store", False) 11 | if require_object_store: 12 | assert chat_cache.data_manager.o, "Object store is required for adapter." 13 | if not chat_cache.has_init: 14 | raise NotInitError() 15 | 16 | # delete data 17 | if remove_type == 'delete_by_id': 18 | id_list = kwargs.pop("id_list", []) 19 | resp = chat_cache.data_manager.delete(id_list, model=model) 20 | elif remove_type == 'truncate_by_model': 21 | resp = chat_cache.data_manager.truncate(model) 22 | else: 23 | resp = "remove_type_error" 24 | return resp 25 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "hfl/chinese-macbert-base", 3 | "architectures": [ 4 | "BertModel" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "classifier_dropout": null, 8 | "directionality": "bidi", 9 | "gradient_checkpointing": false, 10 | "hidden_act": "gelu", 11 | "hidden_dropout_prob": 0.1, 12 | "hidden_size": 768, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-12, 16 | "max_position_embeddings": 512, 17 | "model_type": "bert", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 0, 21 | "pooler_fc_size": 768, 22 | "pooler_num_attention_heads": 12, 23 | "pooler_num_fc_layers": 3, 24 | "pooler_size_per_head": 128, 25 | "pooler_type": "first_token_transform", 26 | "position_embedding_type": "absolute", 27 | "torch_dtype": "float32", 28 | "transformers_version": "4.12.3", 29 | "type_vocab_size": 2, 30 | "use_cache": true, 31 | "vocab_size": 21128 32 | } 33 | -------------------------------------------------------------------------------- /modelcache/utils/error.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class CacheError(Exception): 3 | """ModelCache base error""" 4 | 5 | 6 | class NotInitError(CacheError): 7 | """Raise when the cache has been used before it's inited""" 8 | def __init__(self): 9 | super().__init__("The cache should be inited before using") 10 | 11 | 12 | class RemoveError(CacheError): 13 | """Raise when the cache has been used before it's inited""" 14 | def __init__(self): 15 | super().__init__("The cache remove error") 16 | 17 | class NotFoundError(CacheError): 18 | """Raise when getting an unsupported store.""" 19 | def __init__(self, store_type, current_type_name): 20 | super().__init__(f"Unsupported ${store_type}: {current_type_name}") 21 | 22 | 23 | class ParamError(CacheError): 24 | """Raise when receiving an invalid param.""" 25 | 26 | 27 | class PipInstallError(CacheError): 28 | """Raise when failed to install package.""" 29 | def __init__(self, package): 30 | super().__init__(f"Ran into error installing {package}.") 31 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter_remove.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | 4 | from modelcache.utils.error import RemoveError 5 | 6 | 7 | async def adapt_remove(*args, **kwargs): 8 | chat_cache = kwargs.pop("cache_obj") 9 | model = kwargs.pop("model", None) 10 | remove_type = kwargs.pop("remove_type", None) 11 | require_object_store = kwargs.pop("require_object_store", False) 12 | if require_object_store: 13 | assert chat_cache.data_manager.o, "Object store is required 
for adapter." 14 | 15 | # delete data 16 | if remove_type == 'delete_by_id': 17 | id_list = kwargs.pop("id_list", []) 18 | resp = await asyncio.to_thread( 19 | chat_cache.data_manager.delete, 20 | id_list, model=model 21 | ) 22 | elif remove_type == 'truncate_by_model': 23 | resp = await asyncio.to_thread( 24 | chat_cache.data_manager.truncate, 25 | model 26 | ) 27 | else: 28 | # resp = "remove_type_error" 29 | raise RemoveError() 30 | return resp 31 | 32 | -------------------------------------------------------------------------------- /modelcache_mm/utils/index_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def get_index_name(model): 5 | return 'multicache' + '_' + model 6 | 7 | 8 | def get_index_prefix(model): 9 | return 'prefix' + '_' + model 10 | 11 | 12 | def get_mm_index_name(model, mm_type): 13 | if mm_type not in ['IMG_TEXT', 'mm', 'IMG', 'image', 'TEXT', 'text']: 14 | raise ValueError('mm_type is not normal!') 15 | if mm_type == 'IMG_TEXT': 16 | mm_type = 'mm' 17 | elif mm_type == 'IMG': 18 | mm_type = 'image' 19 | elif mm_type == 'TEXT': 20 | mm_type = 'text' 21 | return 'multi_cache' + '_' + model + '_' + mm_type 22 | 23 | 24 | def get_mm_index_prefix(model, mm_type): 25 | if mm_type not in ['IMG_TEXT', 'mm', 'IMG', 'image', 'TEXT', 'text']: 26 | raise ValueError('mm_type is not normal!') 27 | if mm_type == 'IMG_TEXT': 28 | mm_type = 'mm' 29 | elif mm_type == 'IMG': 30 | mm_type = 'image' 31 | elif mm_type == 'TEXT': 32 | mm_type = 'text' 33 | return 'prefix' + '_' + model + '_' + mm_type 34 | -------------------------------------------------------------------------------- /modelcache/embedding/huggingface_tei.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import requests 3 | import numpy as np 4 | from modelcache.embedding.base import BaseEmbedding 5 | 6 | class HuggingfaceTEI(BaseEmbedding): 7 | def __init__(self, base_url: str, model: str): 8 | self.base_url = base_url 9 | self.model = model 10 | self.headers = { 11 | 'accept': 'application/json', 12 | 'Content-Type': 'application/json', 13 | } 14 | self.__dimension = self.to_embeddings('test').shape[0] 15 | 16 | def to_embeddings(self, data, **_): 17 | json_data = { 18 | 'input': data, 19 | 'model': self.model, 20 | } 21 | 22 | response = requests.post(self.base_url, headers=self.headers, json=json_data) 23 | embedding = response.json()['data'][0]['embedding'] 24 | return np.array(embedding) 25 | 26 | @property 27 | def dimension(self): 28 | """Embedding dimension. 
29 | 30 | :return: embedding dimension 31 | """ 32 | return self.__dimension 33 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/data_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import json 4 | import uuid 5 | import requests 6 | 7 | 8 | def run(): 9 | url = 'http://127.0.0.1:5000/multicache' 10 | 11 | request_type = 'insert' 12 | scope = {"model": "multimodal_test"} 13 | # UUID = "820b0052-d9d8-11ee-95f1-52775e3e6fd1" + "==>" + str(time.time()) 14 | UUID = str(uuid.uuid1()) + "==>" + str(time.time()) 15 | img_data = "https://img0.baidu.com/it/u=1436460262,4166266890&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=282" 16 | query = {'text': ['父母带着孩子来这个地方可能会有什么顾虑'], 17 | 'imageRaw': '', 18 | 'imageUrl': img_data, 19 | 'imageId': 'ccc'} 20 | answer = "应该注意小孩不要跑到铁轨上" 21 | chat_info = [{"query": query, "answer": answer}] 22 | data_dict = {'request_type': request_type, 'scope': scope, 'chat_info': chat_info, 'UUID': UUID} 23 | 24 | headers = {"Content-Type": "application/json"} 25 | res = requests.post(url, headers=headers, json=json.dumps(data_dict)) 26 | res_text = res.text 27 | print('res_text: {}'.format(res_text)) 28 | 29 | 30 | if __name__ == '__main__': 31 | run() 32 | -------------------------------------------------------------------------------- /modelcache_mm/manager/factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Union, Callable 3 | from modelcache_mm.manager import CacheBase, VectorBase, ObjectBase 4 | from modelcache_mm.manager.data_manager import SSDataManager, MapDataManager 5 | 6 | 7 | def get_data_manager( 8 | cache_base: Union[CacheBase, str] = None, 9 | vector_base: Union[VectorBase, str] = None, 10 | object_base: Union[ObjectBase, str] = None, 11 | max_size: int = 1000, 12 | clean_size: int = None, 13 | eviction: str = "LRU", 14 | data_path: str = "data_map.txt", 15 | get_data_container: Callable = None, 16 | ): 17 | if not cache_base and not vector_base: 18 | return MapDataManager(data_path, max_size, get_data_container) 19 | if isinstance(cache_base, str): 20 | cache_base = CacheBase(name=cache_base) 21 | if isinstance(vector_base, str): 22 | vector_base = VectorBase(name=vector_base) 23 | if isinstance(object_base, str): 24 | object_base = ObjectBase(name=object_base) 25 | assert cache_base and vector_base 26 | return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction) 27 | -------------------------------------------------------------------------------- /modelcache/embedding/bge_m3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from modelcache.embedding.base import BaseEmbedding 4 | from transformers import AutoTokenizer, AutoModel 5 | from FlagEmbedding import BGEM3FlagModel 6 | 7 | class BgeM3Embedding(BaseEmbedding): 8 | def __init__(self, model_path: str = "model/bge-m3"): 9 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 10 | self.model = AutoModel.from_pretrained(model_path) 11 | 12 | self.bge_model = BGEM3FlagModel(model_name_or_path=model_path, 13 | model=self.model, 14 | tokenizer=self.tokenizer, 15 | use_fp16=False) 16 | 17 | self.__dimension = 768 18 | 19 | def to_embeddings(self, data, **_): 20 | if not isinstance(data, list): 21 | data = [data] 22 | 23 | embeddings = 
self.bge_model.encode(data, batch_size=12, max_length=8192)['dense_vecs'] 24 | return np.array(embeddings).astype("float32") 25 | 26 | @property 27 | def dimension(self): 28 | return self.__dimension -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | from typing import List 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class VectorData: 10 | id: int 11 | data: np.ndarray 12 | 13 | 14 | class VectorBase(ABC): 15 | """VectorBase: base vector store interface""" 16 | 17 | @abstractmethod 18 | def add(self, datas: List[VectorData], model=None, mm_type=None): 19 | pass 20 | 21 | # @abstractmethod 22 | # def search(self, data: np.ndarray, top_k: int, model): 23 | # pass 24 | 25 | @abstractmethod 26 | def search(self, data: np.ndarray, top_k: int, model, mm_type): 27 | pass 28 | 29 | @abstractmethod 30 | def create(self, model=None, mm_type=None): 31 | pass 32 | 33 | @abstractmethod 34 | def rebuild(self, ids=None) -> bool: 35 | pass 36 | 37 | @abstractmethod 38 | def delete(self, ids) -> bool: 39 | pass 40 | 41 | @abstractmethod 42 | def rebuild_idx(self, model): 43 | pass 44 | 45 | def flush(self): 46 | pass 47 | 48 | def close(self): 49 | pass 50 | -------------------------------------------------------------------------------- /modelcache_mm/utils/error.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class CacheError(Exception): 3 | """ModelCache base error""" 4 | 5 | 6 | class NotInitError(CacheError): 7 | """Raise when the cache has been used before it's inited""" 8 | def __init__(self): 9 | super().__init__("The cache should be inited before using") 10 | 11 | 12 | class RemoveError(CacheError): 13 | """Raise when removing data from the cache fails.""" 14 | def __init__(self): 15 | super().__init__("Failed to remove data from the cache") 16 | 17 | class NotFoundError(CacheError): 18 | """Raise when getting an unsupported store.""" 19 | def __init__(self, store_type, current_type_name): 20 | super().__init__(f"Unsupported {store_type}: {current_type_name}") 21 | 22 | 23 | class ParamError(CacheError): 24 | """Raise when receiving an invalid param.""" 25 | 26 | 27 | class PipInstallError(CacheError): 28 | """Raise when failed to install package.""" 29 | def __init__(self, package): 30 | super().__init__(f"Ran into error installing {package}.") 31 | 32 | 33 | class MultiTypeError(CacheError): 34 | def __init__(self): 35 | super().__init__("Unsupported multi-modal type, please check the mm_type parameter") 36 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils import import_sql_client 3 | from modelcache_mm.utils.error import NotFoundError 4 | 5 | SQL_URL = {"sqlite": "./sqlite.db"} 6 | 7 | 8 | class CacheBase: 9 | """ 10 | CacheBase to manage the cache storage. 11 | """ 12 | 13 | def __init__(self): 14 | raise EnvironmentError( 15 | "CacheBase is not designed to be instantiated directly, please use `CacheBase.get(name)`." 
16 | ) 17 | 18 | @staticmethod 19 | def get(name, **kwargs): 20 | 21 | if name in ["mysql", "oceanbase"]: 22 | from modelcache_mm.manager.scalar_data.sql_storage import SQLStorage 23 | config = kwargs.get("config") 24 | import_sql_client(name) 25 | cache_base = SQLStorage(db_type=name, config=config) 26 | elif name == 'sqlite': 27 | from modelcache_mm.manager.scalar_data.sql_storage_sqlite import SQLStorage 28 | sql_url = kwargs.get("sql_url", SQL_URL[name]) 29 | cache_base = SQLStorage(db_type=name, url=sql_url) 30 | else: 31 | raise NotFoundError("cache store", name) 32 | return cache_base 33 | -------------------------------------------------------------------------------- /modelcache/embedding/huggingface.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.embedding.base import BaseEmbedding 3 | from sentence_transformers import SentenceTransformer 4 | 5 | class Huggingface(BaseEmbedding): 6 | def __init__(self, model: str): 7 | self.model = SentenceTransformer(model,tokenizer_kwargs={ 8 | "clean_up_tokenization_spaces":False 9 | }) 10 | try: 11 | self.__dimension = self.model.config.hidden_size 12 | except Exception: 13 | from transformers import AutoConfig 14 | 15 | config = AutoConfig.from_pretrained(model) 16 | self.__dimension = config.hidden_size 17 | 18 | def to_embeddings(self, data: str, **_): 19 | """Generate embedding given text input 20 | 21 | :param data: text in string. 22 | :type data: str 23 | 24 | :return: a text embedding in shape of (dim,). 25 | """ 26 | 27 | if not data: 28 | raise ValueError("No data provided for embedding.") 29 | embeddings = self.model.encode(data) 30 | return embeddings[0] if len(data) == 1 else embeddings 31 | 32 | @property 33 | def dimension(self): 34 | """Embedding dimension. 35 | 36 | :return: embedding dimension 37 | """ 38 | return self.__dimension 39 | -------------------------------------------------------------------------------- /modelcache/report.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class Report: 3 | def __init__(self): 4 | self.embedding_all_time = 0 5 | self.embedding_count = 0 6 | self.search_all_time = 0 7 | self.search_count = 0 8 | self.hint_cache_count = 0 9 | 10 | def embedding(self, delta_time): 11 | """Embedding counts and time. 12 | 13 | :param delta_time: additional runtime. 14 | """ 15 | self.embedding_all_time += delta_time 16 | self.embedding_count += 1 17 | 18 | def search(self, delta_time): 19 | """Search counts and time. 20 | 21 | :param delta_time: additional runtime. 
22 | """ 23 | self.search_all_time += delta_time 24 | self.search_count += 1 25 | 26 | def average_embedding_time(self): 27 | """Average embedding time.""" 28 | return round( 29 | self.embedding_all_time / self.embedding_count 30 | if self.embedding_count != 0 31 | else 0, 32 | 4, 33 | ) 34 | 35 | def average_search_time(self): 36 | return round( 37 | self.search_all_time / self.search_count 38 | if self.embedding_count != 0 39 | else 0, 40 | 4, 41 | ) 42 | 43 | def hint_cache(self): 44 | self.hint_cache_count += 1 45 | -------------------------------------------------------------------------------- /modelcache_mm/report.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class Report: 3 | def __init__(self): 4 | self.embedding_all_time = 0 5 | self.embedding_count = 0 6 | self.search_all_time = 0 7 | self.search_count = 0 8 | self.hint_cache_count = 0 9 | 10 | def embedding(self, delta_time): 11 | """Embedding counts and time. 12 | 13 | :param delta_time: additional runtime. 14 | """ 15 | self.embedding_all_time += delta_time 16 | self.embedding_count += 1 17 | 18 | def search(self, delta_time): 19 | """Search counts and time. 20 | 21 | :param delta_time: additional runtime. 22 | """ 23 | self.search_all_time += delta_time 24 | self.search_count += 1 25 | 26 | def average_embedding_time(self): 27 | """Average embedding time.""" 28 | return round( 29 | self.embedding_all_time / self.embedding_count 30 | if self.embedding_count != 0 31 | else 0, 32 | 4, 33 | ) 34 | 35 | def average_search_time(self): 36 | return round( 37 | self.search_all_time / self.search_count 38 | if self.embedding_count != 0 39 | else 0, 40 | 4, 41 | ) 42 | 43 | def hint_cache(self): 44 | self.hint_cache_count += 1 45 | -------------------------------------------------------------------------------- /modelcache/embedding/llmEmb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from modelcache.embedding.base import BaseEmbedding 4 | from transformers import AutoTokenizer 5 | from transformers import AutoConfig 6 | 7 | 8 | class LlmEmb2Vec(BaseEmbedding): 9 | def __init__(self): 10 | 11 | self.model_name = '' # 13b-mft-embedding.npy 12 | model_path = '' # .npy file storage path 13 | model_file = model_path + self.model_name # .npy file 14 | config = AutoConfig.from_pretrained(model_path) 15 | dimension = config.hidden_size 16 | self.__dimension = dimension 17 | self.model = np.load(model_file) 18 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True) 19 | 20 | def to_embeddings(self, data, **_): 21 | """Generate embedding given text input 22 | 23 | :param data: text in string. 24 | :return: a text embedding in shape of (dim,). 25 | """ 26 | input_ids = self.tokenizer.encode(data, add_special_tokens=True) 27 | embedding_array = self.model[input_ids].mean(axis=0) 28 | return embedding_array 29 | 30 | def post_proc(self, token_embeddings, inputs): 31 | pass 32 | 33 | @property 34 | def dimension(self): 35 | """Embedding dimension. 
36 | :return: embedding dimension 37 | """ 38 | return self.__dimension 39 | -------------------------------------------------------------------------------- /modelcache_mm/manager/eviction_manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class EvictionManager: 3 | MAX_MARK_COUNT = 5000 4 | MAX_MARK_RATE = 0.1 5 | BATCH_SIZE = 100000 6 | REBUILD_CONDITION = 5 7 | 8 | def __init__(self, scalar_storage, vector_base): 9 | self._scalar_storage = scalar_storage 10 | self._vector_base = vector_base 11 | self.delete_count = 0 12 | 13 | def check_evict(self): 14 | mark_count = self._scalar_storage.count(state=-1) 15 | all_count = self._scalar_storage.count(is_all=True) 16 | if ( 17 | mark_count > self.MAX_MARK_COUNT 18 | or mark_count / all_count > self.MAX_MARK_RATE 19 | ): 20 | return True 21 | return False 22 | 23 | def delete(self): 24 | mark_ids = self._scalar_storage.get_ids(deleted=True) 25 | self._scalar_storage.clear_deleted_data() 26 | self._vector_base.delete(mark_ids) 27 | self.delete_count += 1 28 | if self.delete_count >= self.REBUILD_CONDITION: 29 | self.rebuild() 30 | 31 | def rebuild(self): 32 | self._scalar_storage.clear_deleted_data() 33 | ids = self._scalar_storage.get_ids(deleted=False) 34 | self._vector_base.rebuild(ids) 35 | self.delete_count = 0 36 | 37 | def soft_evict(self, marked_keys): 38 | self._scalar_storage.mark_deleted(marked_keys) 39 | -------------------------------------------------------------------------------- /flask4modelcache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | 4 | from flask import Flask, request, jsonify 5 | from modelcache.cache import Cache 6 | from modelcache.embedding import EmbeddingModel 7 | 8 | 9 | async def main(): 10 | 11 | # 创建一个Flask实例 12 | app = Flask(__name__) 13 | 14 | cache,loop = await Cache.init( 15 | sql_storage="mysql", 16 | vector_storage="milvus", 17 | embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2, 18 | embedding_workers_num=2 19 | ) 20 | 21 | @app.route('/welcome') 22 | def first_flask(): # 视图函数 23 | return 'hello, modelcache!' 
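# Editorial note (added commentary): app.run() is started in a separate thread via
# asyncio.to_thread at the end of main(), so the Flask request handlers below run outside
# the cache's asyncio event loop (the `loop` returned by Cache.init). The /modelcache
# handler therefore submits its coroutine to that loop with asyncio.run_coroutine_threadsafe
# and blocks on .result() until the cached or freshly computed response is ready.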
24 | 25 | 26 | @app.post('/modelcache') 27 | def user_backend(): 28 | try: 29 | param_dict = request.json 30 | except Exception: 31 | result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''} 32 | return jsonify(result), 400 33 | 34 | try: 35 | result = asyncio.run_coroutine_threadsafe( 36 | cache.handle_request(param_dict), loop 37 | ).result() 38 | return jsonify(result), 200 39 | except Exception as e: 40 | result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''} 41 | cache.save_query_resp(result, model='', query='', delta_time=0) 42 | return jsonify(result), 500 43 | 44 | await asyncio.to_thread(app.run, host='0.0.0.0', port=5000) 45 | 46 | 47 | if __name__ == '__main__': 48 | asyncio.run(main()) 49 | -------------------------------------------------------------------------------- /flask4modelcache_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | 4 | from flask import Flask, request, jsonify 5 | from modelcache.cache import Cache 6 | from modelcache.embedding import EmbeddingModel 7 | 8 | 9 | async def main(): 10 | 11 | # 创建一个Flask实例 12 | app = Flask(__name__) 13 | 14 | cache,loop = await Cache.init( 15 | sql_storage="sqlite", 16 | vector_storage="faiss", 17 | embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2, 18 | embedding_workers_num=2 19 | ) 20 | 21 | @app.route('/welcome') 22 | def first_flask(): # 视图函数 23 | return 'hello, modelcache!' 24 | 25 | 26 | @app.post('/modelcache') 27 | def user_backend(): 28 | try: 29 | param_dict = request.json 30 | except Exception: 31 | result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''} 32 | return jsonify(result), 400 33 | 34 | try: 35 | result = asyncio.run_coroutine_threadsafe( 36 | cache.handle_request(param_dict), loop 37 | ).result() 38 | return jsonify(result), 200 39 | except Exception as e: 40 | result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''} 41 | cache.save_query_resp(result, model='', query='', delta_time=0) 42 | return jsonify(result), 500 43 | 44 | await asyncio.to_thread(app.run, host='0.0.0.0', port=5000) 45 | 46 | 47 | if __name__ == '__main__': 48 | asyncio.run(main()) 49 | -------------------------------------------------------------------------------- /fastapi4modelcache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | from contextlib import asynccontextmanager 4 | import uvicorn 5 | import json 6 | from fastapi.responses import JSONResponse 7 | from fastapi import FastAPI, Request 8 | from modelcache.cache import Cache 9 | from modelcache.embedding import EmbeddingModel 10 | 11 | @asynccontextmanager 12 | async def lifespan(app: FastAPI): 13 | global cache 14 | cache, _ = await Cache.init( 15 | sql_storage="mysql", 16 | vector_storage="milvus", 17 | embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2, 18 | embedding_workers_num=2 19 | ) 20 | yield 21 | 22 | app = FastAPI(lifespan=lifespan) 23 | cache: Cache = None 24 | 25 | @app.get("/welcome") 26 | async def first_fastapi(): 27 | return "hello, modelcache!" 
28 | 29 | @app.post("/modelcache") 30 | async def user_backend(request: Request): 31 | 32 | try: 33 | request_data = await request.json() 34 | except Exception: 35 | result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 36 | return JSONResponse(status_code=400, content=result) 37 | 38 | try: 39 | return await cache.handle_request(request_data) 40 | except Exception as e: 41 | result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 42 | cache.save_query_resp(result, model='', query='', delta_time=0) 43 | return JSONResponse(status_code=500, content=result) 44 | 45 | if __name__ == '__main__': 46 | uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools") 47 | -------------------------------------------------------------------------------- /fastapi4modelcache_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | from contextlib import asynccontextmanager 4 | import uvicorn 5 | import json 6 | from fastapi.responses import JSONResponse 7 | from fastapi import FastAPI, Request 8 | from modelcache.cache import Cache 9 | from modelcache.embedding import EmbeddingModel 10 | 11 | @asynccontextmanager 12 | async def lifespan(app: FastAPI): 13 | global cache 14 | cache, _ = await Cache.init( 15 | sql_storage="sqlite", 16 | vector_storage="faiss", 17 | embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2, 18 | embedding_workers_num=2 19 | ) 20 | yield 21 | 22 | app = FastAPI(lifespan=lifespan) 23 | cache: Cache = None 24 | 25 | @app.get("/welcome") 26 | async def first_fastapi(): 27 | return "hello, modelcache!" 28 | 29 | @app.post("/modelcache") 30 | async def user_backend(request: Request): 31 | 32 | try: 33 | request_data = await request.json() 34 | except Exception: 35 | result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 36 | return JSONResponse(status_code=400, content=result) 37 | 38 | try: 39 | return await cache.handle_request(request_data) 40 | except Exception as e: 41 | result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 42 | cache.save_query_resp(result, model='', query='', delta_time=0) 43 | return JSONResponse(status_code=500, content=result) 44 | 45 | if __name__ == '__main__': 46 | uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools") 47 | -------------------------------------------------------------------------------- /modelcache/manager/eviction_manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class EvictionManager: 3 | """ 4 | EvictionManager to manager the eviction policy. 5 | 6 | :param scalar_storage: CacheStorage to manager the scalar data. 7 | :type scalar_storage: :class:`CacheStorage` 8 | :param vector_base: VectorBase to manager the vector data. 
9 | :type vector_base: :class:`VectorBase` 10 | """ 11 | 12 | MAX_MARK_COUNT = 5000 13 | MAX_MARK_RATE = 0.1 14 | BATCH_SIZE = 100000 15 | REBUILD_CONDITION = 5 16 | 17 | def __init__(self, scalar_storage, vector_base): 18 | self._scalar_storage = scalar_storage 19 | self._vector_base = vector_base 20 | self.delete_count = 0 21 | 22 | def check_evict(self): 23 | mark_count = self._scalar_storage.count(state=-1) 24 | all_count = self._scalar_storage.count(is_all=True) 25 | if ( 26 | mark_count > self.MAX_MARK_COUNT 27 | or mark_count / all_count > self.MAX_MARK_RATE 28 | ): 29 | return True 30 | return False 31 | 32 | def delete(self,model): 33 | mark_ids = self._scalar_storage.get_ids(deleted=True) 34 | self._scalar_storage.clear_deleted_data() 35 | self._vector_base.delete(mark_ids,model) 36 | self.delete_count += 1 37 | if self.delete_count >= self.REBUILD_CONDITION: 38 | self.rebuild() 39 | 40 | def rebuild(self): 41 | self._scalar_storage.clear_deleted_data() 42 | ids = self._scalar_storage.get_ids(deleted=False) 43 | self._vector_base.rebuild(ids) 44 | self.delete_count = 0 45 | 46 | def soft_evict(self, marked_keys): 47 | self._scalar_storage.mark_deleted(marked_keys) 48 | -------------------------------------------------------------------------------- /data/mysql/init/init.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE IF NOT EXISTS `modelcache`; 2 | 3 | USE `modelcache`; 4 | 5 | CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` ( 6 | `id` CHAR(36) comment '主键', 7 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 8 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 9 | `question` text NOT NULL comment 'question', 10 | `answer` text NOT NULL comment 'answer', 11 | `answer_type` int(11) NOT NULL comment 'answer_type', 12 | `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count', 13 | `model` varchar(1000) NOT NULL comment 'model', 14 | `embedding_data` blob NOT NULL comment 'embedding_data', 15 | `is_deleted` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'delete state(0 Not deleted,-1 deleted)', 16 | PRIMARY KEY(`id`) 17 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'cache_codegpt_answer'; 18 | 19 | CREATE TABLE IF NOT EXISTS `modelcache_query_log` ( 20 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键', 21 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 22 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 23 | `error_code` int(11) NOT NULL comment 'errorCode', 24 | `error_desc` varchar(1000) NOT NULL comment 'errorDesc', 25 | `cache_hit` varchar(100) NOT NULL comment 'cacheHit', 26 | `delta_time` float NOT NULL comment 'delta_time', 27 | `model` varchar(1000) NOT NULL comment 'model', 28 | `query` text NOT NULL comment 'query', 29 | `hit_query` text NOT NULL comment 'hitQuery', 30 | `answer` text NOT NULL comment 'answer', 31 | PRIMARY KEY(`id`) 32 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log'; 33 | -------------------------------------------------------------------------------- /reference_doc/create_table.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE IF NOT EXISTS `modelcache`; 2 | 3 | USE `modelcache`; 4 | 5 | CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` ( 6 | `id` CHAR(36) comment '主键', 7 | `gmt_create` timestamp NOT NULL DEFAULT 
CURRENT_TIMESTAMP comment '创建时间', 8 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 9 | `question` text NOT NULL comment 'question', 10 | `answer` text NOT NULL comment 'answer', 11 | `answer_type` int(11) NOT NULL comment 'answer_type', 12 | `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count', 13 | `model` varchar(1000) NOT NULL comment 'model', 14 | `embedding_data` blob NOT NULL comment 'embedding_data', 15 | `is_deleted` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'delete state(0 Not deleted,-1 deleted)', 16 | PRIMARY KEY(`id`) 17 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'cache_codegpt_answer'; 18 | 19 | CREATE TABLE IF NOT EXISTS `modelcache_query_log` ( 20 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键', 21 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 22 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 23 | `error_code` int(11) NOT NULL comment 'errorCode', 24 | `error_desc` varchar(1000) NOT NULL comment 'errorDesc', 25 | `cache_hit` varchar(100) NOT NULL comment 'cacheHit', 26 | `delta_time` float NOT NULL comment 'delta_time', 27 | `model` varchar(1000) NOT NULL comment 'model', 28 | `query` text NOT NULL comment 'query', 29 | `hit_query` text NOT NULL comment 'hitQuery', 30 | `answer` text NOT NULL comment 'answer', 31 | PRIMARY KEY(`id`) 32 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log'; 33 | -------------------------------------------------------------------------------- /modelcache/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib.util 3 | from typing import Optional 4 | from modelcache.utils.dependency_control import prompt_install 5 | 6 | 7 | def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None): 8 | is_avail = False 9 | if importlib.util.find_spec(libname): 10 | is_avail = True 11 | if not is_avail and prompt: 12 | prompt_install(package if package else libname) 13 | return is_avail 14 | 15 | 16 | def import_onnxruntime(): 17 | _check_library("onnxruntime") 18 | 19 | 20 | def import_huggingface(): 21 | _check_library("transformers") 22 | 23 | 24 | def import_huggingface_hub(): 25 | _check_library("huggingface_hub", package="huggingface-hub") 26 | 27 | 28 | def import_pymysql(): 29 | _check_library("pymysql") 30 | 31 | 32 | def import_sql_client(db_name): 33 | if db_name in ["mysql"]: 34 | import_pymysql() 35 | 36 | 37 | def import_pymilvus(): 38 | _check_library("pymilvus") 39 | 40 | 41 | def import_milvus_lite(): 42 | _check_library("milvus") 43 | 44 | 45 | def import_faiss(): 46 | _check_library("faiss", package="faiss-cpu") 47 | 48 | 49 | def import_torch(): 50 | _check_library("torch") 51 | 52 | 53 | def import_fasttext(): 54 | _check_library("fasttext") 55 | 56 | 57 | def import_paddle(): 58 | prompt_install("protobuf==3.20.0") 59 | _check_library("paddlepaddle") 60 | 61 | 62 | def import_paddlenlp(): 63 | _check_library("paddlenlp") 64 | 65 | 66 | def import_timm(): 67 | _check_library("timm", package="timm") 68 | 69 | 70 | def import_pillow(): 71 | _check_library("PIL", package="pillow") 72 | 73 | 74 | def import_redis(): 75 | _check_library("redis") 76 | 77 | 78 | def import_chromadb(): 79 | _check_library("chromadb", package="chromadb") -------------------------------------------------------------------------------- 
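The `import_*` helpers above act as lazy dependency guards: a backend module calls the matching helper (which prompts a `pip install` of the missing package) before importing the optional library itself, exactly as the FAISS vector store further below does with `import_faiss()`. A minimal usage sketch of that pattern, assuming `faiss-cpu` can be installed in the current environment:

```py
from modelcache.utils import import_faiss

import_faiss()   # prompts installation of faiss-cpu if the library is missing
import faiss     # safe to import after the guard has run

# build the same kind of index the FAISS vector store uses
index = faiss.index_factory(768, "IDMap,Flat", faiss.METRIC_L2)
print(index.ntotal)  # 0 for a freshly created index
```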
/modelcache_mm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib.util 3 | from typing import Optional 4 | from modelcache.utils.dependency_control import prompt_install 5 | 6 | 7 | def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None): 8 | is_avail = False 9 | if importlib.util.find_spec(libname): 10 | is_avail = True 11 | if not is_avail and prompt: 12 | prompt_install(package if package else libname) 13 | return is_avail 14 | 15 | 16 | def import_onnxruntime(): 17 | _check_library("onnxruntime") 18 | 19 | 20 | def import_huggingface(): 21 | _check_library("transformers") 22 | 23 | 24 | def import_huggingface_hub(): 25 | _check_library("huggingface_hub", package="huggingface-hub") 26 | 27 | 28 | def import_pymysql(): 29 | _check_library("pymysql") 30 | 31 | 32 | def import_sql_client(db_name): 33 | if db_name in ["mysql"]: 34 | import_pymysql() 35 | 36 | 37 | def import_pymilvus(): 38 | _check_library("pymilvus") 39 | 40 | 41 | def import_milvus_lite(): 42 | _check_library("milvus") 43 | 44 | 45 | def import_faiss(): 46 | _check_library("faiss", package="faiss-cpu") 47 | 48 | 49 | def import_torch(): 50 | _check_library("torch") 51 | 52 | 53 | def import_fasttext(): 54 | _check_library("fasttext") 55 | 56 | 57 | def import_paddle(): 58 | prompt_install("protobuf==3.20.0") 59 | _check_library("paddlepaddle") 60 | 61 | 62 | def import_paddlenlp(): 63 | _check_library("paddlenlp") 64 | 65 | 66 | def import_timm(): 67 | _check_library("timm", package="timm") 68 | 69 | 70 | def import_pillow(): 71 | _check_library("PIL", package="pillow") 72 | 73 | 74 | def import_redis(): 75 | _check_library("redis") 76 | 77 | 78 | def import_chromadb(): 79 | _check_library("chromadb", package="chromadb") -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | from modelcache_mm.adapter.adapter_query import adapt_query 5 | from modelcache_mm.adapter.adapter_insert import adapt_insert 6 | from modelcache_mm.adapter.adapter_remove import adapt_remove 7 | from modelcache_mm.adapter.adapter_register import adapt_register 8 | 9 | 10 | class ChatCompletion(object): 11 | """Openai ChatCompletion Wrapper""" 12 | @classmethod 13 | def create_query(cls, *args, **kwargs): 14 | def cache_data_convert(cache_data, cache_query): 15 | return construct_resp_from_cache(cache_data, cache_query) 16 | try: 17 | return adapt_query( 18 | cache_data_convert, 19 | *args, 20 | **kwargs 21 | ) 22 | except Exception as e: 23 | # return str(e) 24 | raise e 25 | 26 | @classmethod 27 | def create_insert(cls, *args, **kwargs): 28 | try: 29 | return adapt_insert( 30 | *args, 31 | **kwargs 32 | ) 33 | except Exception as e: 34 | # return str(e) 35 | raise e 36 | 37 | @classmethod 38 | def create_remove(cls, *args, **kwargs): 39 | try: 40 | return adapt_remove( 41 | *args, 42 | **kwargs 43 | ) 44 | except Exception as e: 45 | raise e 46 | 47 | @classmethod 48 | def create_register(cls, *args, **kwargs): 49 | try: 50 | return adapt_register( 51 | *args, 52 | **kwargs 53 | ) 54 | except Exception as e: 55 | raise e 56 | 57 | 58 | def construct_resp_from_cache(return_message, return_query): 59 | return { 60 | "modelcache": True, 61 | "hitQuery": return_query, 62 | "data": return_message, 63 | "errorCode": 0 64 | } 65 | 
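For orientation, a small sketch of the payload shape produced by `construct_resp_from_cache` above on a cache hit; the argument values are illustrative only, not taken from the project:

```py
from modelcache_mm.adapter.adapter import construct_resp_from_cache

resp = construct_resp_from_cache(
    return_message="cached answer text",        # illustrative value
    return_query="previously stored question",  # illustrative value
)
assert resp == {
    "modelcache": True,
    "hitQuery": "previously stored question",
    "data": "cached answer text",
    "errorCode": 0,
}
```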
-------------------------------------------------------------------------------- /modelcache/adapter/adapter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from modelcache.adapter.adapter_query import adapt_query 4 | from modelcache.adapter.adapter_insert import adapt_insert 5 | from modelcache.adapter.adapter_remove import adapt_remove 6 | from modelcache.adapter.adapter_register import adapt_register 7 | 8 | 9 | class ChatCompletion(object): 10 | """Openai ChatCompletion Wrapper""" 11 | 12 | @classmethod 13 | async def create_query(cls, *args, **kwargs): 14 | def cache_data_convert(cache_data, cache_query): 15 | return construct_resp_from_cache(cache_data, cache_query) 16 | try: 17 | return await adapt_query( 18 | cache_data_convert, 19 | *args, 20 | **kwargs 21 | ) 22 | except Exception as e: 23 | logging.error(e) 24 | return str(e) 25 | 26 | @classmethod 27 | async def create_insert(cls, *args, **kwargs): 28 | try: 29 | return await adapt_insert( 30 | *args, 31 | **kwargs 32 | ) 33 | except Exception as e: 34 | logging.error(e) 35 | return str(e) 36 | 37 | @classmethod 38 | async def create_remove(cls, *args, **kwargs): 39 | try: 40 | return await adapt_remove( 41 | *args, 42 | **kwargs 43 | ) 44 | except Exception as e: 45 | logging.error(e) 46 | return str(e) 47 | 48 | @classmethod 49 | async def create_register(cls, *args, **kwargs): 50 | try: 51 | return await adapt_register( 52 | *args, 53 | **kwargs 54 | ) 55 | except Exception as e: 56 | logging.error(e) 57 | return str(e) 58 | 59 | 60 | def construct_resp_from_cache(return_message, return_query): 61 | return { 62 | "modelcache": True, 63 | "hitQuery": return_query, 64 | "data": return_message, 65 | "errorCode": 0 66 | } 67 | -------------------------------------------------------------------------------- /docs/4.create-cache.md: -------------------------------------------------------------------------------- 1 | # Create cache 2 | 3 | This topic describes how to create a cache. 4 | 5 | ## Default cache interface 6 | 7 | ```py 8 | class Cache: 9 | # ModelCache calls it when you start the cache system 10 | def __init__(self): 11 | self.has_init = False 12 | self.cache_enable_func = None 13 | self.embedding_func = None 14 | self.post_process_messages_func = None 15 | self.config = Config() 16 | ``` 17 | 18 | The embedding function embeds text into dense vectors for context similarity search. ModelCache supports these embedding methods: Hugging Face, ONNX, and SentenceTransformers. The default is the Hugging Face text2vec model because it performs better for Chinese. Simply initialize your embedding function as `text2vec.to_embeddings`. 19 | 20 | ```py 21 | data_manager = get_data_manager(CacheBase("mysql", config=mysql_config), 22 | VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config)) 23 | 24 | cache.init( 25 | embedding_func=data2vec.to_embeddings, 26 | data_manager=data_manager, 27 | similarity_evaluation=SearchDistanceEvaluation(), 28 | query_pre_embedding_func=query_multi_splicing, 29 | insert_pre_embedding_func=insert_multi_splicing, 30 | ) 31 | ``` 32 | 33 | data_manager CacheBase stores all scalar data, such as original questions, prompts, answers, and access times. ModelCache supports multiple cache storages like SQLite, MySQL, and OceanBase. NoSQL databases will be supported in the future. 34 | 35 | data_manager VectorBase stores and searches all embedding vectors to find semantically similar results. 
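A sketch of how the two stores are typically wired together with eviction settings, mirroring the keyword arguments of the `get_data_manager` factory in `modelcache_mm/manager/factory.py` shown earlier in this listing; it is assumed here that the text-only factory accepts the same eviction-related arguments, and the snippet relies on the names (`CacheBase`, `VectorBase`, `data2vec`) introduced in the examples above:

```py
data_manager = get_data_manager(
    cache_base=CacheBase("sqlite"),                                  # scalar store
    vector_base=VectorBase("faiss", dimension=data2vec.dimension),   # vector store
    max_size=1000,    # capacity before eviction is considered
    clean_size=100,   # assumed: number of entries evicted per cleanup pass
    eviction="LRU",   # see modelcache/manager/eviction for the available policies
)
```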
ModelCache supports using vector search libraries like FAISS or vector databases like Milvus. More vector database and cloud service will be supported in the future. 36 | 37 | ## Examples 38 | 39 | ```py 40 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension)) 41 | data_manager = get_data_manager(CacheBase("oceanbase"), VectorBase("milvus", dimension=data2vec.dimension)) 42 | ``` 43 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/clip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from modelcache.embedding.base import BaseEmbedding 4 | from modelscope.utils.constant import Tasks 5 | from modelscope.pipelines import pipeline 6 | from modelscope.preprocessors.image import load_image 7 | 8 | 9 | class ClipAudio(BaseEmbedding): 10 | def __init__(self, model: str = 'damo/multi-modal_clip-vit-base-patch16_zh'): 11 | self.model = model 12 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | self.clip_pipeline = pipeline(task=Tasks.multi_modal_embedding, 14 | model=model, model_revision='v1.0.1') 15 | self.__dimension = 1024 16 | 17 | def to_embeddings(self, data_dict, **_): 18 | text_list = data_dict['text'] 19 | image_data = data_dict['image'] 20 | 21 | # img_data = None 22 | # txt_data = None 23 | 24 | if image_data: 25 | input_img = load_image(image_data) 26 | img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding'].tolist()[0] if input_img else [] 27 | else: 28 | raise ValueError('image_data is None, please check!') 29 | 30 | if text_list and len(text_list) > 0: 31 | text_embedding = self.clip_pipeline.forward({'text': text_list})['text_embedding'].tolist()[0] if text_list else [] 32 | else: 33 | raise ValueError('text_list is None, please check!') 34 | 35 | return {'image_embedding': img_embedding, 'text_embeddings': text_embedding} 36 | 37 | def post_proc(self, token_embeddings, inputs): 38 | attention_mask = inputs["attention_mask"] 39 | input_mask_expanded = ( 40 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 41 | ) 42 | sentence_embs = torch.sum( 43 | token_embeddings * input_mask_expanded, 1 44 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 45 | return sentence_embs 46 | 47 | @property 48 | def dimension(self): 49 | return self.__dimension 50 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/faiss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from typing import List 4 | import numpy as np 5 | from modelcache.manager.vector_data.base import VectorStorage, VectorData 6 | from modelcache.utils import import_faiss 7 | import_faiss() 8 | import faiss # pylint: disable=C0413 9 | 10 | 11 | class Faiss(VectorStorage): 12 | def __init__(self, index_file_path, dimension, top_k): 13 | self._index_file_path = index_file_path 14 | self._dimension = dimension 15 | self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2) 16 | self._top_k = top_k 17 | if os.path.isfile(index_file_path): 18 | self._index = faiss.read_index(index_file_path) 19 | 20 | def mul_add(self, datas: List[VectorData], model=None): 21 | data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas))) 22 | np_data = np.array(data_array).astype("float32") 23 | ids = 
np.array(id_array) 24 | self._index.add_with_ids(np_data, ids) 25 | 26 | def search(self, data: np.ndarray, top_k: int = -1, model=None): 27 | if self._index.ntotal == 0: 28 | return None 29 | if top_k == -1: 30 | top_k = self._top_k 31 | np_data = np.array(data).astype("float32").reshape(1, -1) 32 | dist, ids = self._index.search(np_data, top_k) 33 | ids = [int(i) for i in ids[0]] 34 | return list(zip(dist[0], ids)) 35 | 36 | def rebuild_col(self, ids=None): 37 | try: 38 | self._index.reset() 39 | except Exception as e: 40 | return f"An error occurred during index rebuild: {e}" 41 | 42 | def rebuild(self, ids=None): 43 | return True 44 | 45 | def delete(self, ids): 46 | ids_to_remove = np.array(ids) 47 | self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove))) 48 | 49 | def flush(self): 50 | faiss.write_index(self._index, self._index_file_path) 51 | 52 | def close(self): 53 | self.flush() 54 | 55 | def count(self): 56 | return self._index.ntotal 57 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | name: "modelcache" 2 | services: 3 | mysql: 4 | image: mysql:8.0.23 5 | container_name: mysql 6 | environment: 7 | MYSQL_ROOT_PASSWORD: 'root' 8 | MYSQL_DATABASE: 'modelcache' 9 | MYSQL_USER: 'modelcache' 10 | MYSQL_PASSWORD: 'modelcache' 11 | ports: 12 | - 3306:3306 13 | volumes: 14 | - ./data/mysql/db:/var/lib/mysql 15 | - ./data/mysql/my.cnf:/etc/mysql/conf.d/my.cnf 16 | - ./data/mysql/init:/docker-entrypoint-initdb.d 17 | # restart: on-failure 18 | networks: 19 | - modelcache 20 | 21 | milvus: 22 | image: milvusdb/milvus:v2.5.10 23 | container_name: milvus 24 | security_opt: 25 | - seccomp:unconfined 26 | environment: 27 | ETCD_USE_EMBED: true 28 | ETCD_DATA_DIR: /var/lib/milvus/etcd 29 | ETCD_CONFIG_PATH: /milvus/configs/embedEtcd.yaml 30 | COMMON_STORAGETYPE: local 31 | volumes: 32 | - ./data/milvus/db:/var/lib/milvus 33 | - ./data/milvus/embedEtcd.yaml:/milvus/configs/embedEtcd.yaml 34 | - ./data/milvus/user.yaml:/milvus/configs/user.yaml 35 | ports: 36 | - 19530:19530 37 | - 9091:9091 38 | - 2379:2379 39 | # healthcheck: 40 | # test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] 41 | # interval: 30s 42 | # start_period: 90s 43 | # timeout: 20s 44 | # retries: 3 45 | networks: 46 | - modelcache 47 | # restart: on-failure 48 | command: milvus run standalone 49 | 50 | # modelcache: 51 | # build: 52 | # context: . 
53 | # dockerfile: Dockerfile 54 | # container_name: modelcache 55 | # image: modelcache:0.1.0 56 | # ports: 57 | # - 5000:5000 58 | # volumes: 59 | # - ./model:/home/user/model 60 | # - ./modelcache:/home/user/modelcache 61 | # - ./modelcache_mm:/home/user/modelcache_mm 62 | # - ./fastapi4modelcache.py:/home/user/fastapi4modelcache.py 63 | # networks: 64 | # - modelcache 65 | # restart: on-failure 66 | # command: sh -c "uvicorn fastapi4modelcache:app --reload --reload-dir /home/user --port=5000 --host=0.0.0.0" 67 | 68 | networks: 69 | modelcache: 70 | driver: bridge -------------------------------------------------------------------------------- /modelcache/embedding/paddlenlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.embedding.base import BaseEmbedding 5 | from modelcache.utils import import_paddlenlp, import_paddle 6 | 7 | import_paddle() 8 | import_paddlenlp() 9 | 10 | 11 | import paddle # pylint: disable=C0413 12 | from paddlenlp.transformers import AutoModel, AutoTokenizer # pylint: disable=C0413 13 | 14 | 15 | class PaddleNLP(BaseEmbedding): 16 | def __init__(self, model: str = "ernie-3.0-medium-zh"): 17 | self.model = AutoModel.from_pretrained(model) 18 | self.model.eval() 19 | 20 | self.tokenizer = AutoTokenizer.from_pretrained(model) 21 | if not self.tokenizer.pad_token: 22 | self.tokenizer.pad_token = "" 23 | self.__dimension = None 24 | 25 | def to_embeddings(self, data, **_): 26 | """Generate embedding given text input 27 | 28 | :param data: text in string. 29 | :type data: str 30 | 31 | :return: a text embedding in shape of (dim,). 32 | """ 33 | if not isinstance(data, list): 34 | data = [data] 35 | inputs = self.tokenizer( 36 | data, padding=True, truncation=True, return_tensors="pd" 37 | ) 38 | outs = self.model(**inputs)[0] 39 | emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy() 40 | return np.array(emb).astype("float32") 41 | 42 | def post_proc(self, token_embeddings, inputs): 43 | attention_mask = paddle.ones(inputs["token_type_ids"].shape) 44 | input_mask_expanded = ( 45 | attention_mask.unsqueeze(-1).expand(token_embeddings.shape).astype("float32") 46 | ) 47 | sentence_embs = paddle.sum( 48 | token_embeddings * input_mask_expanded, 1 49 | ) / paddle.clip(input_mask_expanded.sum(1), min=1e-9) 50 | return sentence_embs 51 | 52 | @property 53 | def dimension(self): 54 | """Embedding dimension. 
55 | 56 | :return: embedding dimension 57 | """ 58 | if not self.__dimension: 59 | self.__dimension = len(self.to_embeddings("foo")) 60 | return self.__dimension 61 | -------------------------------------------------------------------------------- /websocket4modelcache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from contextlib import asynccontextmanager 3 | import uvicorn 4 | import json 5 | import asyncio 6 | from fastapi import FastAPI, WebSocket 7 | from starlette.websockets import WebSocketDisconnect 8 | from modelcache.cache import Cache 9 | from modelcache.embedding import EmbeddingModel 10 | 11 | @asynccontextmanager 12 | async def lifespan(app: FastAPI): 13 | global cache 14 | cache, _ = await Cache.init( 15 | sql_storage="mysql", 16 | vector_storage="milvus", 17 | embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2, 18 | embedding_workers_num=8 19 | ) 20 | yield 21 | 22 | app = FastAPI(lifespan=lifespan) 23 | cache: Cache = None 24 | 25 | @app.websocket("/modelcache") 26 | async def user_backend(websocket: WebSocket): 27 | await websocket.accept() 28 | try: 29 | while True: 30 | data = await websocket.receive_text() 31 | asyncio.create_task(handle_message(websocket, data)) 32 | except WebSocketDisconnect as e: 33 | print(e) 34 | 35 | 36 | async def handle_message(websocket,message): 37 | try: 38 | param_dict = json.loads(message) 39 | except Exception: 40 | await websocket.send_json({"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}) 41 | return 42 | 43 | request_id = param_dict.get("requestId") 44 | request_payload = param_dict.get("payload") 45 | if not request_id or not request_payload: 46 | await websocket.send_json({"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}) 47 | return 48 | try: 49 | result = await cache.handle_request(request_payload) 50 | await websocket.send_json({"requestId": request_id,"result": result}) 51 | except Exception as e: 52 | error_result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 53 | cache.save_query_resp(error_result, model='', query='', delta_time=0) 54 | await websocket.send_json(error_result) 55 | 56 | 57 | if __name__ == '__main__': 58 | uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools") 59 | -------------------------------------------------------------------------------- /websocket4modelcache_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from contextlib import asynccontextmanager 3 | import uvicorn 4 | import json 5 | import asyncio 6 | from fastapi import FastAPI, WebSocket 7 | from starlette.websockets import WebSocketDisconnect 8 | from modelcache.cache import Cache 9 | from modelcache.embedding import EmbeddingModel 10 | 11 | @asynccontextmanager 12 | async def lifespan(app: FastAPI): 13 | global cache 14 | cache, _ = await Cache.init( 15 | sql_storage="sqlite", 16 | vector_storage="faiss", 17 | embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2, 18 | embedding_workers_num=2 19 | ) 20 | yield 21 | 22 | app = FastAPI(lifespan=lifespan) 23 | cache: Cache = None 24 | 25 | @app.websocket("/modelcache") 26 | async def user_backend(websocket: WebSocket): 27 | await websocket.accept() 28 | try: 29 | while True: 30 | data = await websocket.receive_text() 31 | 
asyncio.create_task(handle_message(websocket, data)) 32 | except WebSocketDisconnect as e: 33 | print(e) 34 | 35 | 36 | async def handle_message(websocket, message): 37 | try: 38 | param_dict = json.loads(message) 39 | except Exception: 40 | await websocket.send_json({"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}) 41 | return 42 | 43 | request_id = param_dict.get("requestId") 44 | request_payload = param_dict.get("payload") 45 | if not request_id or not request_payload: 46 | await websocket.send_json({"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}) 47 | return 48 | try: 49 | result = await cache.handle_request(request_payload) 50 | await websocket.send_json({"requestId": request_id, "result": result}) 51 | except Exception as e: 52 | error_result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 53 | cache.save_query_resp(error_result, model='', query='', delta_time=0) 54 | await websocket.send_json(error_result) 55 | 56 | 57 | if __name__ == '__main__': 58 | uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools") 59 | -------------------------------------------------------------------------------- /docs/2.model-cache-features.md: -------------------------------------------------------------------------------- 1 | # ModelCache features 2 | 3 | This topic describes ModelCache features. In ModelCache, we incorporated the core principles of GPTCache. ModelCache has four modules: adapter, embedding, similarity, and data_manager. 4 | 5 | - The adapter module orchestrates the business logic for various tasks, integrating the embedding, similarity, and data_manager modules. 6 | - The embedding module converts text into semantic vector representations, and transforms user queries into vectors. 7 | - The rank module ranks and evaluates the similarity of recalled vectors. 8 | - The data_manager module manages the databases. 9 | 10 | To make ModelCache more suitable for industrial use, we made several improvements to its architecture and functionality: 11 | 12 | - [x] Architectural adjustment (lightweight integration): 13 | - Embedded into LLM products using a Redis-like caching mode. 14 | - Provided semantic caching without interfering with LLM calls, security audits, and other functions. 15 | - Compatible with all LLM services. 16 | - [x] Multiple model loading: 17 | - Supported local embedding model loading and resolved Hugging Face network connectivity issues. 18 | - Supported loading embedding layers from various pre-trained models. 19 | - [x] Data isolation: 20 | - Environment isolation: Read different database configurations based on the environment. Isolate development, staging, and production environments. 21 | - Multi-tenant data isolation: Dynamically create collections based on models for data isolation, addressing data separation issues in multi-model/service scenarios within large language model products. 22 | - [x] Supported system instruction: Adopted a concatenation approach to resolve issues with system instructions in the prompt paradigm. 23 | - [x] Long and short text differentiation: Long texts bring more challenges for similarity assessment. Added differentiation between long and short texts, allowing for separate threshold configurations. 24 | - [x] Milvus performance optimization: Adjusted Milvus consistency level to "Session" level for better performance. 
25 | - [x] Data management: 26 | - One-click cache clearing to enable easy data management after model upgrades. 27 | - Recall of hit queries for subsequent data analysis and model iteration reference. 28 | - Asynchronous log write-back for data analysis and statistics. 29 | - Added model field and data statistics field to enhance features. 30 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | 4 | from modelcache.utils.error import NotInitError 5 | from modelcache.utils.time import time_cal 6 | 7 | 8 | async def adapt_insert(*args, **kwargs): 9 | chat_cache = kwargs.pop("cache_obj") 10 | model = kwargs.pop("model", None) 11 | require_object_store = kwargs.pop("require_object_store", False) 12 | 13 | # Validate object store availability if required 14 | if require_object_store: 15 | assert chat_cache.data_manager.o, "Object store is required for adapter." 16 | 17 | context = kwargs.pop("cache_context", {}) 18 | chat_info = kwargs.pop("chat_info", []) 19 | 20 | # Initialize collections for parallel processing 21 | pre_embedding_data_list = [] # Preprocessed data ready for embedding 22 | embedding_futures_list = [] # Async embedding generation tasks 23 | llm_data_list = [] # Extracted LLM response data 24 | 25 | # Process each chat entry and prepare for parallel embedding generation 26 | for row in chat_info: 27 | # Preprocess chat data using configured preprocessing function 28 | pre_embedding_data = chat_cache.insert_pre_embedding_func( 29 | row, 30 | extra_param=context.get("pre_embedding_func", None), 31 | prompts=chat_cache.prompts, 32 | ) 33 | pre_embedding_data_list.append(pre_embedding_data) 34 | llm_data_list.append(row['answer']) # Extract answer text for storage 35 | 36 | # Create async embedding generation task with performance monitoring 37 | embedding_future = time_cal( 38 | chat_cache.embedding_func, 39 | func_name="embedding", 40 | report_func=chat_cache.report.embedding, 41 | cache_obj=chat_cache 42 | )(pre_embedding_data) 43 | embedding_futures_list.append(embedding_future) 44 | 45 | # Wait for all embedding generation tasks to complete in parallel 46 | embedding_data_list = await asyncio.gather(*embedding_futures_list) 47 | 48 | # Save all processed data to the data manager asynchronously 49 | await asyncio.to_thread( 50 | chat_cache.data_manager.save, 51 | pre_embedding_data_list, 52 | llm_data_list, 53 | embedding_data_list, 54 | model=model, 55 | extra_param=context.get("save_func", None) 56 | ) 57 | return 'success' 58 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/memory_cache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Callable, List, Tuple 3 | import cachetools 4 | 5 | from modelcache.manager.eviction.base import EvictionBase 6 | from .arc_cache import ARC 7 | from .wtinylfu_cache import W2TinyLFU 8 | 9 | 10 | def popitem_wrapper(func, wrapper_func, clean_size): 11 | def wrapper(*args, **kwargs): 12 | keys = [] 13 | try: 14 | keys = [func(*args, **kwargs)[0] for _ in range(clean_size)] 15 | except KeyError: 16 | pass 17 | wrapper_func(keys) 18 | return wrapper 19 | 20 | 21 | class MemoryCacheEviction(EvictionBase): 22 | def __init__(self, policy: str, maxsize: int, clean_size: int, **kwargs): 23 | 
self._policy = policy.upper() 24 | self.model_to_cache = dict() 25 | self.maxsize = maxsize 26 | self.clean_size = clean_size 27 | self.kwargs = kwargs 28 | 29 | def create_cache(self, model: str): 30 | if self._policy == "LRU": 31 | cache = cachetools.LRUCache(maxsize=self.maxsize, **self.kwargs) 32 | elif self._policy == "LFU": 33 | cache = cachetools.LFUCache(maxsize=self.maxsize, **self.kwargs) 34 | elif self._policy == "FIFO": 35 | cache = cachetools.FIFOCache(maxsize=self.maxsize, **self.kwargs) 36 | elif self._policy == "RR": 37 | cache = cachetools.RRCache(maxsize=self.maxsize, **self.kwargs) 38 | elif self._policy == "WTINYLFU": 39 | cache = W2TinyLFU(maxsize=self.maxsize) 40 | elif self._policy == "ARC": 41 | cache = ARC(maxsize=self.maxsize) 42 | else: 43 | raise ValueError(f"Unknown policy {self.policy}") 44 | return cache 45 | 46 | def put(self, objs: List[Tuple[Any, Any]], model: str): 47 | cache = self.get_cache(model) 48 | for key, value in objs: 49 | cache[key] = value 50 | 51 | 52 | def get(self, obj: Any, model: str): 53 | cache = self.get_cache(model) 54 | return cache.get(obj) 55 | 56 | 57 | def clear(self, model: str): 58 | self.model_to_cache.pop(model, None) 59 | 60 | 61 | def get_cache(self, model: str): 62 | if not model in self.model_to_cache: 63 | self.model_to_cache[model] = self.create_cache(model) 64 | return self.model_to_cache[model] 65 | 66 | 67 | @property 68 | def policy(self) -> str: 69 | return self._policy 70 | 71 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "modelcache" 3 | version = "0.1.0" 4 | description = "A caching framework for machine learning models" 5 | authors = ["ModelCache Team"] 6 | readme = "README.md" 7 | packages = [ 8 | { include = "modelcache" }, 9 | { include = "modelcache_mm" } 10 | ] 11 | 12 | [tool.poetry.dependencies] 13 | python = "^3.8" 14 | cachetools = "5.3.1" 15 | DBUtils = "1.4" 16 | Flask = "3.0.0" 17 | numpy = "1.24.4" 18 | onnxruntime = "1.16.1" 19 | openai = "0.28.1" 20 | pymilvus = "2.3.1" 21 | PyMySQL = "1.1.0" 22 | Requests = "2.31.0" 23 | torch = "2.1.1" 24 | transformers = "4.38.2" 25 | faiss-cpu = "1.7.4" 26 | redis = "5.0.1" 27 | modelscope = "1.14.0" 28 | fastapi = "0.115.5" 29 | uvicorn = "0.32.0" 30 | chromadb = "0.5.23" 31 | elasticsearch = "7.10.0" 32 | snowflake-id = "1.0.2" 33 | 34 | [tool.poetry.group.dev.dependencies] 35 | pytest = "^8.0.0" 36 | pytest-cov = "^5.0.0" 37 | pytest-mock = "^3.14.0" 38 | 39 | [tool.poetry.scripts] 40 | test = "pytest:main" 41 | tests = "pytest:main" 42 | 43 | [tool.pytest.ini_options] 44 | minversion = "8.0" 45 | testpaths = ["tests"] 46 | python_files = ["test_*.py", "*_test.py"] 47 | python_classes = ["Test*"] 48 | python_functions = ["test_*"] 49 | addopts = [ 50 | "-ra", 51 | "--strict-markers", 52 | "--cov=modelcache", 53 | "--cov=modelcache_mm", 54 | "--cov-branch", 55 | "--cov-report=term-missing:skip-covered", 56 | "--cov-report=html", 57 | "--cov-report=xml", 58 | "--cov-fail-under=80", 59 | "-v" 60 | ] 61 | markers = [ 62 | "unit: Unit tests", 63 | "integration: Integration tests", 64 | "slow: Slow running tests" 65 | ] 66 | 67 | [tool.coverage.run] 68 | source = ["modelcache", "modelcache_mm"] 69 | omit = [ 70 | "*/tests/*", 71 | "*/test_*", 72 | "*/__pycache__/*", 73 | "*/site-packages/*", 74 | "*/distutils/*", 75 | "*/venv/*", 76 | "*/.venv/*" 77 | ] 78 | 79 | [tool.coverage.report] 80 | precision 
= 2 81 | show_missing = true 82 | skip_covered = false 83 | exclude_lines = [ 84 | "pragma: no cover", 85 | "def __repr__", 86 | "if __name__ == .__main__.:", 87 | "raise AssertionError", 88 | "raise NotImplementedError", 89 | "if TYPE_CHECKING:", 90 | "if typing.TYPE_CHECKING:" 91 | ] 92 | 93 | [tool.coverage.html] 94 | directory = "htmlcov" 95 | 96 | [tool.coverage.xml] 97 | output = "coverage.xml" 98 | 99 | [build-system] 100 | requires = ["poetry-core"] 101 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /modelcache/embedding/onnx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.embedding.base import BaseEmbedding 5 | from modelcache.utils import ( 6 | import_onnxruntime, 7 | import_huggingface_hub, 8 | import_huggingface, 9 | ) 10 | 11 | import_huggingface() 12 | import_onnxruntime() 13 | import_huggingface_hub() 14 | 15 | from transformers import AutoTokenizer, AutoConfig # pylint: disable=C0413 16 | import onnxruntime 17 | from modelcache.utils.env_config import get_onnx_tokenizer_path, get_onnx_model 18 | 19 | 20 | class Onnx(BaseEmbedding): 21 | 22 | def __init__(self, model="modelcache_open/paraphrase-albert-onnx"): 23 | # 本地加载 24 | onnx_tokenizer = get_onnx_tokenizer_path() 25 | self.tokenizer = AutoTokenizer.from_pretrained(onnx_tokenizer, local_files_only=True) 26 | # 本地加载 27 | onnx_model = get_onnx_model() 28 | self.ort_session = onnxruntime.InferenceSession(onnx_model) 29 | 30 | config = AutoConfig.from_pretrained(onnx_tokenizer, local_files_only=True) 31 | self.__dimension = config.hidden_size 32 | 33 | def to_embeddings(self, data, **_): 34 | """Generate embedding given text input. 35 | 36 | :param data: text in string. 37 | :type data: str 38 | 39 | :return: a text embedding in shape of (dim,). 40 | """ 41 | encoded_text = self.tokenizer.encode_plus(data, padding="max_length") 42 | ort_inputs = { 43 | "input_ids": np.array(encoded_text["input_ids"]).reshape(1, -1), 44 | "attention_mask": np.array(encoded_text["attention_mask"]).reshape(1, -1), 45 | "token_type_ids": np.array(encoded_text["token_type_ids"]).reshape(1, -1), 46 | } 47 | 48 | ort_outputs = self.ort_session.run(None, ort_inputs) 49 | ort_feat = ort_outputs[0] 50 | emb = self.post_proc(ort_feat, ort_inputs["attention_mask"]) 51 | return emb.flatten() 52 | 53 | def post_proc(self, token_embeddings, attention_mask): 54 | input_mask_expanded = ( 55 | np.expand_dims(attention_mask, -1) 56 | .repeat(token_embeddings.shape[-1], -1) 57 | .astype(float) 58 | ) 59 | sentence_embs = np.sum(token_embeddings * input_mask_expanded, 1) / np.maximum( 60 | input_mask_expanded.sum(1), 1e-9 61 | ) 62 | return sentence_embs 63 | 64 | @property 65 | def dimension(self): 66 | """Embedding dimension. 
67 | 68 | :return: embedding dimension 69 | """ 70 | return self.__dimension 71 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/timm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.utils import import_timm, import_torch, import_pillow 5 | from modelcache.embedding.base import BaseEmbedding 6 | 7 | import_torch() 8 | import_timm() 9 | import_pillow() 10 | 11 | import torch 12 | from timm.models import create_model 13 | from timm.data import create_transform, resolve_data_config 14 | from PIL import Image 15 | 16 | 17 | class Timm(BaseEmbedding): 18 | def __init__(self, model: str = "resnet18", device: str = "default"): 19 | if device == "default": 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | else: 22 | self.device = device 23 | self.model_name = model 24 | self.model = create_model(model_name=model, pretrained=True) 25 | self.model.eval() 26 | 27 | try: 28 | self.__dimension = self.model.embed_dim 29 | except Exception: 30 | self.__dimension = None 31 | 32 | def to_embeddings(self, data, skip_preprocess: bool = False, **_): 33 | if not skip_preprocess: 34 | data = self.preprocess(data) 35 | if data.dim() == 3: 36 | data = data.unsqueeze(0) 37 | feats = self.model.forward_features(data) 38 | emb = self.post_proc(feats).squeeze(0).detach().numpy() 39 | 40 | return np.array(emb).astype("float32") 41 | 42 | def post_proc(self, features): 43 | features = features.to("cpu") 44 | if features.dim() == 3: 45 | features = features[:, 0] 46 | if features.dim() == 4: 47 | global_pool = torch.nn.AdaptiveAvgPool2d(1) 48 | features = global_pool(features) 49 | features = features.flatten(1) 50 | assert features.dim() == 2, f"Invalid output dim {features.dim()}" 51 | return features 52 | 53 | def preprocess(self, image_path): 54 | data_cfg = resolve_data_config(self.model.pretrained_cfg) 55 | transform = create_transform(**data_cfg) 56 | 57 | image = Image.open(image_path).convert("RGB") 58 | image_tensor = transform(image) 59 | return image_tensor 60 | 61 | @property 62 | def dimension(self): 63 | """Embedding dimension. 
64 | :return: embedding dimension 65 | """ 66 | if not self.__dimension: 67 | input_size = self.model.pretrained_cfg["input_size"] 68 | dummy_input = torch.rand((1,) + input_size) 69 | feats = self.to_embeddings(dummy_input, skip_preprocess=True) 70 | self.__dimension = feats.shape[0] 71 | return self.__dimension 72 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/faiss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from typing import List 4 | import numpy as np 5 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData 6 | from modelcache_mm.utils import import_faiss 7 | import_faiss() 8 | import faiss # pylint: disable=C0413 9 | 10 | 11 | class Faiss(VectorBase): 12 | def __init__(self, 13 | index_file_path, 14 | dimension: int = 0, 15 | top_k: int = 1 16 | ): 17 | self._dimension = dimension 18 | self._index_file_path = index_file_path 19 | self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2) 20 | self._top_k = top_k 21 | if os.path.isfile(index_file_path): 22 | self._index = faiss.read_index(index_file_path) 23 | 24 | def add(self, datas: List[VectorData], model=None, mm_type=None): 25 | data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas))) 26 | np_data = np.array(data_array).astype("float32") 27 | ids = np.array(id_array) 28 | self._index.add_with_ids(np_data, ids) 29 | 30 | def search(self, data: np.ndarray, top_k: int, model, mm_type='mm'): 31 | if self._index.ntotal == 0: 32 | return None 33 | if top_k == -1: 34 | top_k = self._top_k 35 | np_data = np.array(data).astype("float32").reshape(1, -1) 36 | dist, ids = self._index.search(np_data, top_k) 37 | ids = [int(i) for i in ids[0]] 38 | return list(zip(dist[0], ids)) 39 | 40 | def rebuild_col(self, ids=None): 41 | try: 42 | self._index.reset() 43 | except Exception as e: 44 | return f"An error occurred during index rebuild: {e}" 45 | 46 | def rebuild(self, ids=None): 47 | return True 48 | 49 | def delete(self, ids): 50 | ids_to_remove = np.array(ids) 51 | self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove))) 52 | 53 | def create(self, model=None, mm_type=None): 54 | pass 55 | # collection_name_model = get_mm_index_name(model, mm_type) 56 | # try: 57 | # index_prefix = get_mm_index_prefix(model, mm_type) 58 | # self.create_index(collection_name_model, mm_type, index_prefix) 59 | # except Exception as e: 60 | # raise ValueError(str(e)) 61 | # return 'success' 62 | 63 | def flush(self): 64 | faiss.write_index(self._index, self._index_file_path) 65 | 66 | def close(self): 67 | self.flush() 68 | 69 | def rebuild_idx(self, model): 70 | pass 71 | 72 | def count(self): 73 | return self._index.ntotal 74 | -------------------------------------------------------------------------------- /modelcache/embedding/timm_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.utils import import_timm, import_torch, import_pillow 5 | from modelcache.embedding.base import BaseEmbedding 6 | 7 | import_torch() 8 | import_timm() 9 | import_pillow() 10 | 11 | import torch # pylint: disable=C0413 12 | from timm.models import create_model # pylint: disable=C0413 13 | from timm.data import create_transform, resolve_data_config # pylint: disable=C0413 14 | from PIL import Image # 
pylint: disable=C0413 15 | 16 | 17 | class Timm(BaseEmbedding): 18 | def __init__(self, model: str = "resnet18", device: str = "default"): 19 | if device == "default": 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | else: 22 | self.device = device 23 | self.model_name = model 24 | self.model = create_model(model_name=model, pretrained=True) 25 | self.model.eval() 26 | 27 | try: 28 | self.__dimension = self.model.embed_dim 29 | except Exception: # pylint: disable=W0703 30 | self.__dimension = None 31 | 32 | def to_embeddings(self, data, skip_preprocess: bool = False, **_): 33 | if not skip_preprocess: 34 | data = self.preprocess(data) 35 | if data.dim() == 3: 36 | data = data.unsqueeze(0) 37 | feats = self.model.forward_features(data) 38 | emb = self.post_proc(feats).squeeze(0).detach().numpy() 39 | 40 | return np.array(emb).astype("float32") 41 | 42 | def post_proc(self, features): 43 | features = features.to("cpu") 44 | if features.dim() == 3: 45 | features = features[:, 0] 46 | if features.dim() == 4: 47 | global_pool = torch.nn.AdaptiveAvgPool2d(1) 48 | features = global_pool(features) 49 | features = features.flatten(1) 50 | assert features.dim() == 2, f"Invalid output dim {features.dim()}" 51 | return features 52 | 53 | def preprocess(self, image_path): 54 | data_cfg = resolve_data_config(self.model.pretrained_cfg) 55 | transform = create_transform(**data_cfg) 56 | 57 | image = Image.open(image_path).convert("RGB") 58 | image_tensor = transform(image) 59 | return image_tensor 60 | 61 | @property 62 | def dimension(self): 63 | """Embedding dimension. 64 | 65 | :return: embedding dimension 66 | """ 67 | if not self.__dimension: 68 | input_size = self.model.pretrained_cfg["input_size"] 69 | dummy_input = torch.rand((1,) + input_size) 70 | feats = self.to_embeddings(dummy_input, skip_preprocess=True) 71 | self.__dimension = feats.shape[0] 72 | return self.__dimension 73 | -------------------------------------------------------------------------------- /modelcache_mm/core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import atexit 3 | from typing import Optional, List, Any 4 | from modelcache_mm.processor.post import first 5 | from modelcache_mm.similarity_evaluation import ExactMatchEvaluation 6 | from modelcache_mm.similarity_evaluation import SimilarityEvaluation 7 | from modelcache_mm.embedding.string import to_embeddings as string_embedding 8 | from modelcache_mm.report import Report 9 | from modelcache_mm.config import Config 10 | from modelcache_mm.utils.cache_func import cache_all 11 | from modelcache_mm.utils.log import modelcache_log 12 | from modelcache_mm.manager import get_data_manager 13 | from modelcache_mm.manager.data_manager import DataManager 14 | 15 | 16 | class Cache: 17 | def __init__(self): 18 | self.has_init = False 19 | self.cache_enable_func = None 20 | self.query_pre_embedding_func = None 21 | self.insert_pre_embedding_func = None 22 | self.embedding_func = None 23 | self.data_manager: Optional[DataManager] = None 24 | self.similarity_evaluation: Optional[SimilarityEvaluation] = None 25 | self.post_process_messages_func = None 26 | self.config = Config() 27 | self.report = Report() 28 | self.next_cache = None 29 | 30 | def init( 31 | self, 32 | cache_enable_func=cache_all, 33 | query_pre_embedding_func=None, 34 | insert_pre_embedding_func=None, 35 | embedding_func=string_embedding, 36 | data_manager: DataManager = get_data_manager(), 37 | 
similarity_evaluation=ExactMatchEvaluation(), 38 | post_process_messages_func=first, 39 | config=Config(), 40 | next_cache=None, 41 | ): 42 | self.has_init = True 43 | self.cache_enable_func = cache_enable_func 44 | self.query_pre_embedding_func = query_pre_embedding_func 45 | self.insert_pre_embedding_func = insert_pre_embedding_func 46 | self.embedding_func = embedding_func 47 | self.data_manager: DataManager = data_manager 48 | self.similarity_evaluation = similarity_evaluation 49 | self.post_process_messages_func = post_process_messages_func 50 | self.config = config 51 | self.next_cache = next_cache 52 | 53 | @atexit.register 54 | def close(): 55 | try: 56 | self.data_manager.close() 57 | except Exception as e: 58 | modelcache_log.error(e) 59 | 60 | def import_data(self, questions: List[Any], answers: List[Any]) -> None: 61 | self.data_manager.import_data( 62 | questions=questions, 63 | answers=answers, 64 | embedding_datas=[self.embedding_func(question) for question in questions], 65 | ) 66 | 67 | def flush(self): 68 | self.data_manager.flush() 69 | if self.next_cache: 70 | self.next_cache.data_manager.flush() 71 | 72 | 73 | cache = Cache() 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | *.DS_Store 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .nox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | *.py,cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | db.sqlite3-journal 60 | *.db 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | __pypackages__/ 86 | 87 | # Celery stuff 88 | celerybeat-schedule 89 | celerybeat.pid 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv* 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | 121 | .idea 122 | **/data_map**.txt 123 | **/faiss**.index 124 | **/sqlite**.db 125 | **/**.db 126 | **/example.py 127 | **/example.db 128 | **/.chroma 129 | 130 | /fuhui_dev 131 | *.index 132 | *model.onnx 133 | 134 | /data_analyse 135 | /embedding_npy 136 | /flask_server 137 | *.bin 138 | **/maya_embedding_service 139 | 140 | *.ini 141 | 142 | 
**/multicache_serving.py 143 | **/modelcache_serving.py 144 | 145 | **/model/text2vec-base-chinese 146 | 147 | /data/milvus/db 148 | /data/mysql/db 149 | 150 | # Testing 151 | .pytest_cache/ 152 | .coverage 153 | .coverage.* 154 | htmlcov/ 155 | coverage.xml 156 | *.py,cover 157 | .hypothesis/ 158 | pytest_cache/ 159 | test-results/ 160 | .tox/ 161 | .nox/ 162 | 163 | # Claude 164 | .claude/* 165 | 166 | # Poetry 167 | dist/ 168 | 169 | # Virtual environments 170 | .venv/ 171 | venv/ 172 | ENV/ 173 | env/ 174 | .env 175 | 176 | # IDE 177 | .vscode/ 178 | .idea/ 179 | *.swp 180 | *.swo 181 | *~ 182 | 183 | # OS 184 | .DS_Store 185 | Thumbs.db 186 | 187 | # Temporary files 188 | *.tmp 189 | *.bak 190 | *.orig 191 | tmp/ 192 | temp/ -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Union, Dict, List, Optional, Any 5 | from enum import IntEnum 6 | import numpy as np 7 | 8 | 9 | class DataType(IntEnum): 10 | STR = 0 11 | IMAGE_BASE64 = 1 12 | IMAGE_URL = 2 13 | 14 | 15 | @dataclass 16 | class QuestionDep: 17 | """ 18 | QuestionDep 19 | """ 20 | 21 | name: str 22 | data: str 23 | dep_type: int = DataType.STR 24 | 25 | @classmethod 26 | def from_dict(cls, d: Dict): 27 | return cls( 28 | name=d["name"], 29 | data=d["data"], 30 | dep_type=d["dep_type"] 31 | ) 32 | 33 | 34 | @dataclass 35 | class Question: 36 | """ 37 | Question 38 | """ 39 | 40 | content: str 41 | deps: Optional[List[QuestionDep]] = None 42 | 43 | @classmethod 44 | def from_dict(cls, d: Dict): 45 | deps = [] 46 | for dep in d["deps"]: 47 | deps.append(QuestionDep.from_dict(dep)) 48 | return cls(d["content"], deps) 49 | 50 | 51 | @dataclass 52 | class Answer: 53 | """ 54 | data_type: 55 | 0: str 56 | 1: base64 image 57 | """ 58 | 59 | answer: Any 60 | answer_type: int = DataType.STR 61 | 62 | 63 | @dataclass 64 | class CacheData: 65 | """ 66 | CacheData 67 | """ 68 | 69 | question: Union[str, Question] 70 | answers: List[Answer] 71 | embedding_data: Optional[np.ndarray] = None 72 | 73 | def __init__(self, question, answers, embedding_data=None): 74 | self.question = question 75 | self.answers = [] 76 | if isinstance(answers, (str, Answer)): 77 | answers = [answers] 78 | for data in answers: 79 | if isinstance(data, (list, tuple)): 80 | self.answers.append(Answer(*data)) 81 | elif isinstance(data, Answer): 82 | self.answers.append(data) 83 | else: 84 | self.answers.append(Answer(answer=data)) 85 | self.embedding_data = embedding_data 86 | 87 | 88 | class CacheStorage(metaclass=ABCMeta): 89 | """ 90 | BaseStorage for scalar data. 
91 | """ 92 | 93 | @abstractmethod 94 | def create(self): 95 | pass 96 | 97 | @abstractmethod 98 | def batch_insert(self, all_data: List[CacheData]): 99 | pass 100 | 101 | @abstractmethod 102 | def insert_query_resp(self, query_resp, **kwargs): 103 | pass 104 | 105 | @abstractmethod 106 | def get_data_by_id(self, key): 107 | pass 108 | 109 | @abstractmethod 110 | def mark_deleted(self, keys): 111 | pass 112 | 113 | @abstractmethod 114 | def model_deleted(self, model_name): 115 | pass 116 | 117 | @abstractmethod 118 | def clear_deleted_data(self): 119 | pass 120 | 121 | @abstractmethod 122 | def get_ids(self, deleted=True): 123 | pass 124 | 125 | @abstractmethod 126 | def count(self): 127 | pass 128 | 129 | def flush(self): 130 | pass 131 | 132 | @abstractmethod 133 | def close(self): 134 | pass 135 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/chroma.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import logging 5 | from modelcache.manager.vector_data.base import VectorStorage, VectorData 6 | from modelcache.utils import import_chromadb, import_torch 7 | 8 | import_torch() 9 | import_chromadb() 10 | 11 | import chromadb 12 | 13 | 14 | class Chromadb(VectorStorage): 15 | 16 | def __init__( 17 | self, 18 | persist_directory="./chromadb", 19 | top_k: int = 1, 20 | ): 21 | self.collection_name = "modelcache" 22 | self.top_k = top_k 23 | 24 | self._client = chromadb.PersistentClient(path=persist_directory) 25 | self._collection = None 26 | 27 | def mul_add(self, datas: List[VectorData], model=None): 28 | collection_name_model = self.collection_name + '_' + model 29 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 30 | 31 | data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) 32 | self._collection.add(embeddings=data_array, ids=id_array) 33 | 34 | def search(self, data: np.ndarray, top_k: int = -1, model=None): 35 | collection_name_model = self.collection_name + '_' + model 36 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 37 | 38 | if self._collection.count() == 0: 39 | return [] 40 | if top_k == -1: 41 | top_k = self.top_k 42 | results = self._collection.query( 43 | query_embeddings=[data.tolist()], 44 | n_results=top_k, 45 | include=["distances"], 46 | ) 47 | return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]])) 48 | 49 | def rebuild(self, ids=None): 50 | pass 51 | 52 | def delete(self, ids, model=None): 53 | try: 54 | collection_name_model = self.collection_name + '_' + model 55 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 56 | # 查询集合中实际存在的 ID 57 | ids_str = [str(x) for x in ids] 58 | existing_ids = set(self._collection.get(ids=ids_str).ids) 59 | 60 | # 删除存在的 ID 61 | if existing_ids: 62 | self._collection.delete(list(existing_ids)) 63 | 64 | # 返回实际删除的条目数量 65 | return len(existing_ids) 66 | 67 | except Exception as e: 68 | logging.error('Error during deletion: {}'.format(e)) 69 | raise ValueError(str(e)) 70 | 71 | def rebuild_col(self, model): 72 | collection_name_model = self.collection_name + '_' + model 73 | 74 | # 检查集合是否存在,如果存在则删除 75 | collections = self._client.list_collections() 76 | if any(col.name == collection_name_model for col in collections): 77 | self._client.delete_collection(collection_name_model) 78 | else: 79 | return 'model 
collection not found, please check!' 80 | 81 | try: 82 | self._client.create_collection(collection_name_model) 83 | except Exception as e: 84 | logging.info(f'rebuild_collection: {e}') 85 | raise ValueError(str(e)) 86 | 87 | def flush(self): 88 | # chroma无flush方法 89 | pass 90 | 91 | def close(self): 92 | pass 93 | -------------------------------------------------------------------------------- /docs/3.model-cache-quick-start.md: -------------------------------------------------------------------------------- 1 | # Quick start 2 | 3 | This topic describes how to set up and use ModelCache. 4 | 5 | You can find the start script in `flask4modelcache.py` and `flask4modelcache_demo.py`. 6 | 7 | - `flask4modelcache_demo.py`: A quick test service that embeds SQLite and FAISS. No database configuration required. 8 | - `flask4modelcache.py`: The standard service that requires MySQL and Milvus configuration. 9 | 10 | ## Dependencies 11 | 12 | - Python: V3.8 or above 13 | - Package installation 14 | 15 | ```shell 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Start service 20 | 21 | ### Start demo 22 | 23 | 1. Download the embedding model bin file from [Hugging Face](https://huggingface.co/shibing624/text2vec-base-chinese/tree/main). Place it in the `model/text2vec-base-chinese` folder. 24 | 2. Start the backend service: 25 | 26 | ```shell 27 | cd CodeFuse-ModelCache 28 | ``` 29 | 30 | ```shell 31 | python flask4modelcache_demo.py 32 | ``` 33 | 34 | ### Start standard service 35 | 36 | Before you start standard service, do these steps: 37 | 38 | 1. Install MySQL and import the SQL file from `reference_doc/create_table.sql`. 39 | 2. Install vector database Milvus. 40 | 3. Configure database access in: 41 | - `modelcache/config/milvus_config.ini` 42 | - `modelcache/config/mysql_config.ini` 43 | 4. Download the embedding model bin file from [Hugging Face](https://huggingface.co/shibing624/text2vec-base-chinese/tree/main). Put it in `model/text2vec-base-chinese`. 44 | 5. Start the backend service: 45 | 46 | ```bash 47 | python flask4modelcache.py 48 | ``` 49 | 50 | ## Visit the service 51 | 52 | The service provides three core RESTful API functionalities: Cache-Writing, Cache-Querying, and Cache-Clearing. 53 | 54 | ### Write cache 55 | 56 | ```python 57 | import json 58 | import requests 59 | url = 'http://127.0.0.1:5000/modelcache' 60 | type = 'insert' 61 | scope = {"model": "CODEGPT-1008"} 62 | chat_info = [{"query": [{"role": "system", "content": "You are an AI code assistant and you must provide neutral and harmless answers to help users solve code-related problems."}, {"role": "user", "content": "你是谁?"}], 63 | "answer": "Hello, I am an intelligent assistant. 
How can I assist you?"}] 64 | data = {'type': type, 'scope': scope, 'chat_info': chat_info} 65 | headers = {"Content-Type": "application/json"} 66 | res = requests.post(url, headers=headers, json=json.dumps(data)) 67 | ``` 68 | 69 | ### Query cache 70 | 71 | ```python 72 | import json 73 | import requests 74 | url = 'http://127.0.0.1:5000/modelcache' 75 | type = 'query' 76 | scope = {"model": "CODEGPT-1008"} 77 | query = [{"role": "system", "content": "You are an AI code assistant and you must provide neutral and harmless answers to help users solve code-related problems."}, {"role": "user", "content": "Who are you?"}] 78 | data = {'type': type, 'scope': scope, 'query': query} 79 | 80 | headers = {"Content-Type": "application/json"} 81 | res = requests.post(url, headers=headers, json=json.dumps(data)) 82 | ``` 83 | 84 | ### Clear cache 85 | 86 | ```python 87 | import json 88 | import requests 89 | url = 'http://127.0.0.1:5000/modelcache' 90 | type = 'remove' 91 | scope = {"model": "CODEGPT-1008"} 92 | remove_type = 'truncate_by_model' 93 | data = {'type': type, 'scope': scope, 'remove_type': remove_type} 94 | 95 | headers = {"Content-Type": "application/json"} 96 | res = requests.post(url, headers=headers, json=json.dumps(data)) 97 | ``` 98 | -------------------------------------------------------------------------------- /modelcache/embedding/embedding_dispatcher.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import threading 3 | import uuid 4 | import asyncio 5 | import psutil 6 | from asyncio import Future, AbstractEventLoop 7 | 8 | from modelcache.embedding import EmbeddingModel 9 | from modelcache.embedding.base import BaseEmbedding 10 | 11 | 12 | def worker_func(embedding_model: EmbeddingModel, model_path, task_queue, result_queue, worker_id): 13 | """Worker function that runs in separate processes to generate embeddings.""" 14 | base_embedding = BaseEmbedding.get(embedding_model, model_path=model_path) 15 | print(f"Embedding worker {worker_id} started.") 16 | try: 17 | while True: 18 | job_id, data = task_queue.get() # Get task from queue 19 | try: 20 | result = base_embedding.to_embeddings(data) # Generate embedding 21 | except Exception as e: 22 | result = e 23 | result_queue.put((job_id, result)) # Send result back 24 | except KeyboardInterrupt: 25 | print(f"Embedding worker {worker_id} stopped.") 26 | except Exception as e: 27 | print(f"Embedding worker {worker_id} encountered an error: {e}") 28 | 29 | 30 | class EmbeddingDispatcher: 31 | """Manages a pool of worker processes for parallel embedding generation.""" 32 | 33 | def __init__( 34 | self, 35 | embedding_model: EmbeddingModel, 36 | model_path: str, 37 | event_loop: AbstractEventLoop, 38 | num_workers: int 39 | ): 40 | """Initialize the dispatcher with worker processes.""" 41 | if num_workers <= 0: 42 | raise ValueError("Number of workers must be greater than 0.") 43 | 44 | self.task_queue = multiprocessing.Queue() # Tasks to workers 45 | self.result_queue = multiprocessing.Queue() # Results from workers 46 | self.futures: dict[str, asyncio.Future] = {} # Pending futures 47 | self.event_loop = event_loop 48 | self._start_result_collector_thread() # Start result collection thread 49 | 50 | # Start worker processes 51 | self.workers = [] 52 | for i in range(num_workers): 53 | p = multiprocessing.Process( 54 | target=worker_func, 55 | args=(embedding_model, model_path, self.task_queue, self.result_queue, i) 56 | ) 57 | p.daemon = True 58 | p.start() 59 | 
psutil.Process(p.pid).nice(psutil.HIGH_PRIORITY_CLASS) 60 | self.workers.append(p) 61 | 62 | def _start_result_collector_thread(self): 63 | """Start a thread to collect results from worker processes.""" 64 | def collect(): 65 | while True: 66 | job_id, result = self.result_queue.get() # Get result from queue 67 | future = self.futures.pop(job_id, None) # Retrieve future 68 | if future: 69 | self.event_loop.call_soon_threadsafe( 70 | future.set_exception if isinstance(result, Exception) else future.set_result, 71 | result 72 | ) 73 | 74 | t = threading.Thread(target=collect, daemon=True) 75 | t.start() 76 | 77 | def embed(self, data: str) -> Future: 78 | """Submit a task for embedding generation.""" 79 | job_id = str(uuid.uuid4()) # Generate unique job ID 80 | future = asyncio.get_running_loop().create_future() # Create future 81 | self.futures[job_id] = future # Store future 82 | self.task_queue.put((job_id, data)) # Add task to queue 83 | return future 84 | 85 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from modelcache_mm import cache 4 | from modelcache_mm.utils.error import NotInitError 5 | from modelcache_mm.utils.time import time_cal 6 | 7 | 8 | def adapt_insert(*args, **kwargs): 9 | chat_cache = kwargs.pop("cache_obj", cache) 10 | model = kwargs.pop("model", None) 11 | require_object_store = kwargs.pop("require_object_store", False) 12 | if require_object_store: 13 | assert chat_cache.data_manager.o, "Object store is required for adapter." 14 | if not chat_cache.has_init: 15 | raise NotInitError() 16 | cache_enable = chat_cache.cache_enable_func(*args, **kwargs) 17 | context = kwargs.pop("cache_context", {}) 18 | embedding_data = None 19 | pre_embedding_data_dict = chat_cache.insert_pre_embedding_func( 20 | kwargs, 21 | extra_param=context.get("pre_embedding_func", None), 22 | prompts=chat_cache.config.prompts, 23 | ) 24 | 25 | chat_info = kwargs.pop("chat_info", []) 26 | llm_data = chat_info[-1]['answer'] 27 | 28 | pre_embedding_text = '###'.join(pre_embedding_data_dict['text']) 29 | pre_embedding_image_url = pre_embedding_data_dict['imageUrl'] 30 | pre_embedding_image_raw = pre_embedding_data_dict['imageRaw'] 31 | pre_embedding_image_id = pre_embedding_data_dict.get('imageId', None) 32 | 33 | if pre_embedding_image_url and pre_embedding_image_raw: 34 | raise ValueError("Both pre_embedding_image_url and pre_embedding_image_raw cannot be non-empty at the same time.") 35 | 36 | if pre_embedding_image_url: 37 | pre_embedding_image = pre_embedding_image_url 38 | elif pre_embedding_image_raw: 39 | pre_embedding_image = pre_embedding_image_raw 40 | else: 41 | pre_embedding_image = None 42 | if not pre_embedding_text: 43 | raise ValueError( 44 | "Both pre_embedding_image_url and pre_embedding_image_raw are empty. 
Please provide at least one.") 45 | 46 | data_dict = {'text': [pre_embedding_text], 'image': pre_embedding_image} 47 | embedding_data = None 48 | mm_type = None 49 | 50 | if cache_enable: 51 | embedding_data_resp = time_cal( 52 | chat_cache.embedding_func, 53 | func_name="image_embedding", 54 | report_func=chat_cache.report.embedding, 55 | )(data_dict) 56 | 57 | image_embeddings = embedding_data_resp['image_embedding'] 58 | text_embeddings = embedding_data_resp['text_embeddings'] 59 | 60 | if len(image_embeddings) > 0 and len(image_embeddings) > 0: 61 | # image_embedding = np.array(image_embeddings[0]) 62 | # text_embedding = text_embeddings[0] 63 | embedding_data = np.concatenate((image_embeddings, text_embeddings)) 64 | mm_type = 'mm' 65 | elif len(image_embeddings) > 0: 66 | image_embedding = np.array(image_embeddings[0]) 67 | embedding_data = image_embedding 68 | mm_type = 'image' 69 | elif len(text_embeddings) > 0: 70 | text_embedding = np.array(text_embeddings[0]) 71 | embedding_data = text_embedding 72 | mm_type = 'text' 73 | else: 74 | raise ValueError('maya embedding service return both empty list, please check!') 75 | chat_cache.data_manager.save( 76 | pre_embedding_text, 77 | pre_embedding_image_url, 78 | pre_embedding_image_id, 79 | llm_data, 80 | embedding_data, 81 | model=model, 82 | mm_type=mm_type, 83 | extra_param=context.get("mm_save_func", None) 84 | ) 85 | return 'success' 86 | -------------------------------------------------------------------------------- /modelcache/processor/pre.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | from typing import Dict, Any 4 | 5 | 6 | def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 7 | return data.get("query")[-1]["content"] 8 | 9 | 10 | def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 11 | return data.get("query")[-1]["content"] 12 | 13 | 14 | def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) -> Any: 15 | last_content_str = data.get("messages")[-1]["content"] 16 | prompts = params.get("prompts", []) 17 | if prompts is None: 18 | return last_content_str 19 | pattern = "|".join(prompts) 20 | new_content_str = re.sub(pattern, "", last_content_str) 21 | return new_content_str 22 | 23 | 24 | def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 25 | s = "" 26 | messages = data.get("messages") 27 | for i, message in enumerate(messages): 28 | if i == len(messages) - 1: 29 | s += message["content"] 30 | else: 31 | s += message["content"] + "\n" 32 | return s 33 | 34 | 35 | def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 36 | return data 37 | 38 | 39 | def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 40 | return data.get("prompt") 41 | 42 | 43 | def get_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 44 | return data.get("file").name 45 | 46 | 47 | def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes: 48 | return data.get("file").peek() 49 | 50 | 51 | def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 52 | input_data = data.get("input") 53 | return str(input_data["image"].peek()) + input_data["question"] 54 | 55 | 56 | def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 57 | input_data = data.get("input") 58 | return input_data["image"].name 59 | 60 | 61 | def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 62 | query_list = 
data.get("query") 63 | return multi_splicing(query_list) 64 | 65 | 66 | def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 67 | insert_query_list = data['query'] 68 | return multi_splicing(insert_query_list) 69 | 70 | def query_with_role(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 71 | query = data["query"][-1] 72 | content = query["content"] 73 | role = query["role"] 74 | return role+": "+content 75 | 76 | def multi_splicing(data_list) -> Any: 77 | result_str = "" 78 | for d in data_list: 79 | role = d.get('role', '') 80 | content = d.get('content', '') 81 | result_str += role + "###" + content + "|||" 82 | 83 | # 去掉最后一个"|||" 84 | result_str = result_str[:-3] 85 | 86 | return result_str 87 | 88 | 89 | def multi_analysis(dialog_str): 90 | sub_strings = dialog_str.split('|||') 91 | 92 | dict_list = [] 93 | for s in sub_strings: 94 | parts = s.split('###') 95 | 96 | if len(parts) == 2: 97 | role = parts[0] 98 | content = parts[1] 99 | elif len(parts) > 2: 100 | role = parts[0] 101 | content = '###'.join(parts[1:]) 102 | else: 103 | content = 'exception' 104 | 105 | if content == '': 106 | d = {"role": role} 107 | else: 108 | d = {"role": role, "content": content} 109 | dict_list.append(d) 110 | 111 | # 3. 将每个字典添加到一个列表中,得到最终的列表 112 | result_list = dict_list 113 | 114 | # 输出结果 115 | return result_list 116 | -------------------------------------------------------------------------------- /mulicache-readme-cn.md: -------------------------------------------------------------------------------- 1 | # MultiModal Cache 2 | 3 | 为满足多模态的性能要求,我们在 LLModel Cache 的基础上,开发了 MultiModal Cache 系统。MultiModal Cache 增强了 ModelCache 功能,架优化架构,适应多种应用场景。 4 | 5 | - [MultiModal Cache](#multimodal-cache) 6 | - [最新动态](#最新动态) 7 | - [特性](#特性) 8 | - [性能](#性能) 9 | - [效果评估](#效果评估) 10 | - [参与贡献](#参与贡献) 11 | 12 | ## 最新动态 13 | 14 | - [2024.12.12] MultiModal Cache 系统正式发布。 15 | 16 | ## 特性 17 | 18 | | 场景 | 数据类型 | 图像格式 | 数据隔离 | 19 | |------|----------|----------|----------| 20 | | 文本对话 | 文本 | 不适用 | 支持 | 21 | | 图文理解 | 文本+图像 | image_url/image_base64 | 支持 | 22 | 23 | - **兼容性**:支持文本和图片链接(image_url)和图片 Base64 编码三种数据格式及其组合。 24 | - **数据隔离**:支持多模型数据隔离,允许不同数据模型在同一系统中独立运行。 25 | - **模态隔离**:支持同一模型下不同模态数据(如文本和图像)的隔离处理。 26 | 27 | ## 性能 28 | 29 | 我们在生产环境中使用企业级数据库对 MultiModal Cache 进行了全面的性能评估。以下是详细的性能数据: 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 |
| 请求类型 | Cache Hit | 总耗时范围 | 组件 | 组件耗时 |
| --- | --- | --- | --- | --- |
| Text | Hit | 420ms-520ms | Multi-Encoder (Text) | ~300ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 关系存储检索 | 60-70ms |
| Text | Not Hit | 300ms+N(s) | Multi-Encoder (Text) | ~300ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 大模型调用 | N (s) |
| IMG_TEXT | Hit | 600ms-800ms | Multi-Encoder (image+text) | ~600ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 关系存储检索 | 60-70ms |
| IMG_TEXT | Not Hit | 600ms+N(s) | Multi-Encoder (image+text) | ~600ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 大模型调用 | N (s) |
98 | 99 | 根据目前的评估结果,Embedding 的推理时间存在较大的优化空间。 100 | **说明**:使用嵌入式数据库可能会进一步提升性能。 101 | 102 | ## 效果评估 103 | 104 | 为全面评估 Cache 对模型服务的影响,我们进行了端到端的性能测试,ua 比较了有 Cache 和无 Cache 两种服务配置。我们使用了 5000 个测试用例的数据集进行自动化测试。 105 | 106 | - 有 Cache 的预发模型服务:观察其响应时间,预期 Cache 的引入能够显著提升服务的性能,降低延迟。 107 | - 无 Cache 的线上模型服务,以获取其原始性能指标和输出结果。这些数据将作为对比基准。 108 | 109 | 为了确保 Cache 引入后的数据准确性和一致性,我们比较了两个服务返回的结果,验证了 Cache 机制是否会影响最终用户收到的回复内容。 110 | 111 | 与原始的直接模型调用方式相比,Cache Service 的调用耗时数据呈现出稳定的分布特征,性能上并不会随着模型参数规模的增加而受到影响。在传统情况下,随着模型参数规模的扩大,模型调用的耗时往往会上升,这是因为更大规模的模型需要更多的计算资源。Cache 服务通过存储经常访问的数据来避免重复的计算,从而一定程度上解耦了耗时与模型复杂性之间的关联。 112 | 113 | ![cache-service-cost-time-distribution](docs/cache-service-cost-time-distribution.webp) 114 | 115 | 我们对缓存命中的耗时与实际调用模型的耗时进行了对比分析。实验数据表明,在集成 Cache Service之后,基于 llama7B 模型,缓存命中所带来的性能提升超过了 40%。预计随着模型的持续迭代与优化,性能提升的幅度将会有更进一步的增长。 116 | 117 | ![time-cost-comparison](docs/time-cost-comparison.webp) 118 | 119 | ## 参与贡献 120 | 121 | MultiModal Cache 是一个充满潜力的开源项目,我们欢迎各种形式的贡献: 122 | 123 | - 提交问题和建议 124 | - 参与代码编写 125 | - 完善文档和示例 126 | 127 | 无论您是经验丰富的开发者还是新手,您的参与都将使这个项目更加出色,同时为开源社区做出贡献。 128 | -------------------------------------------------------------------------------- /modelcache_mm/processor/pre.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | from typing import Dict, Any 4 | 5 | 6 | def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 7 | return data.get("chat_info")[-1]["query"] 8 | 9 | 10 | def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 11 | return data.get("query")[-1]["content"] 12 | 13 | 14 | def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) -> Any: 15 | last_content_str = data.get("messages")[-1]["content"] 16 | prompts = params.get("prompts", []) 17 | if prompts is None: 18 | return last_content_str 19 | pattern = "|".join(prompts) 20 | new_content_str = re.sub(pattern, "", last_content_str) 21 | return new_content_str 22 | 23 | 24 | def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 25 | s = "" 26 | messages = data.get("messages") 27 | for i, message in enumerate(messages): 28 | if i == len(messages) - 1: 29 | s += message["content"] 30 | else: 31 | s += message["content"] + "\n" 32 | return s 33 | 34 | 35 | def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 36 | return data 37 | 38 | 39 | def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 40 | return data.get("prompt") 41 | 42 | 43 | def get_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 44 | return data.get("file").name 45 | 46 | 47 | def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes: 48 | return data.get("file").peek() 49 | 50 | 51 | def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 52 | input_data = data.get("input") 53 | return str(input_data["image"].peek()) + input_data["question"] 54 | 55 | 56 | def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 57 | input_data = data.get("input") 58 | return input_data["image"].name 59 | 60 | 61 | def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 62 | query_list = data.get("query") 63 | return multi_splicing(query_list) 64 | 65 | 66 | def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 67 | insert_query_list = data.get("chat_info")[-1]['query'] 68 | return multi_splicing(insert_query_list) 69 | 70 | 71 | def multi_splicing(data_list) -> Any: 72 | result_str = "" 73 | for d 
in data_list: 74 | role = d.get('role', '') 75 | content = d.get('content', '') 76 | result_str += role + "###" + content + "|||" 77 | 78 | # 去掉最后一个"|||" 79 | result_str = result_str[:-3] 80 | 81 | return result_str 82 | 83 | 84 | def multi_analysis(dialog_str): 85 | sub_strings = dialog_str.split('|||') 86 | dict_list = [] 87 | for s in sub_strings: 88 | parts = s.split('###') 89 | if len(parts) == 2: 90 | role = parts[0] 91 | content = parts[1] 92 | elif len(parts) > 2: 93 | role = parts[0] 94 | content = '###'.join(parts[1:]) 95 | else: 96 | content = 'exception' 97 | if content == '': 98 | d = {"role": role} 99 | else: 100 | d = {"role": role, "content": content} 101 | dict_list.append(d) 102 | result_list = dict_list 103 | 104 | # 输出结果 105 | return result_list 106 | 107 | 108 | def mm_insert_dict(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 109 | query_dict = data.get("chat_info")[-1]['query'] 110 | return query_dict 111 | 112 | 113 | def mm_query_dict(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 114 | query_dict = data.get("query") 115 | return query_dict 116 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/chroma.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import logging 5 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData 6 | from modelcache_mm.utils import import_chromadb, import_torch 7 | from modelcache_mm.utils.index_util import get_mm_index_name 8 | 9 | import_torch() 10 | import_chromadb() 11 | 12 | import chromadb 13 | 14 | 15 | class Chromadb(VectorBase): 16 | 17 | def __init__( 18 | self, 19 | persist_directory="./chromadb", 20 | top_k: int = 1, 21 | ): 22 | self.top_k = top_k 23 | 24 | self._client = chromadb.PersistentClient(path=persist_directory) 25 | self._collection = None 26 | 27 | def create(self, model=None, mm_type=None): 28 | try: 29 | collection_name_model = get_mm_index_name(model, mm_type) 30 | # collection_name_model = self.collection_name + '_' + model 31 | self._client.get_or_create_collection(name=collection_name_model) 32 | except Exception as e: 33 | raise ValueError(str(e)) 34 | 35 | def add(self, datas: List[VectorData], model=None, mm_type=None): 36 | collection_name_model = get_mm_index_name(model, mm_type) 37 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 38 | 39 | data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) 40 | self._collection.add(embeddings=data_array, ids=id_array) 41 | 42 | def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type='mm'): 43 | collection_name_model = get_mm_index_name(model, mm_type) 44 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 45 | 46 | if self._collection.count() == 0: 47 | return [] 48 | if top_k == -1: 49 | top_k = self.top_k 50 | results = self._collection.query( 51 | query_embeddings=[data.tolist()], 52 | n_results=top_k, 53 | include=["distances"], 54 | ) 55 | return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]])) 56 | 57 | def delete(self, ids, model=None, mm_type=None): 58 | try: 59 | collection_name_model = get_mm_index_name(model, mm_type) 60 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 61 | # 查询集合中实际存在的 ID 62 | ids_str = [str(x) for x in ids] 63 | existing_ids = 
set(self._collection.get(ids=ids_str).ids) 64 | 65 | # 删除存在的 ID 66 | if existing_ids: 67 | self._collection.delete(list(existing_ids)) 68 | 69 | # 返回实际删除的条目数量 70 | return len(existing_ids) 71 | 72 | except Exception as e: 73 | logging.error('Error during deletion: {}'.format(e)) 74 | raise ValueError(str(e)) 75 | 76 | def rebuild_idx(self, model, mm_type=None): 77 | collection_name_model = get_mm_index_name(model, mm_type) 78 | 79 | # 检查集合是否存在,如果存在则删除 80 | collections = self._client.list_collections() 81 | if any(col.name == collection_name_model for col in collections): 82 | self._client.delete_collection(collection_name_model) 83 | else: 84 | return 'model collection not found, please check!' 85 | 86 | try: 87 | self._client.create_collection(collection_name_model) 88 | except Exception as e: 89 | logging.info(f'rebuild_collection: {e}') 90 | raise ValueError(str(e)) 91 | 92 | def rebuild(self, ids=None): 93 | pass 94 | 95 | def flush(self): 96 | pass 97 | 98 | def close(self): 99 | pass 100 | -------------------------------------------------------------------------------- /modelcache/manager/scalar_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Union, Dict, List, Optional, Any 5 | from enum import IntEnum 6 | import numpy as np 7 | 8 | from modelcache.utils import import_sql_client 9 | from modelcache.utils.error import NotFoundError 10 | 11 | 12 | class DataType(IntEnum): 13 | STR = 0 14 | IMAGE_BASE64 = 1 15 | IMAGE_URL = 2 16 | 17 | 18 | @dataclass 19 | class QuestionDep: 20 | """ 21 | QuestionDep 22 | """ 23 | 24 | name: str 25 | data: str 26 | dep_type: int = DataType.STR 27 | 28 | @classmethod 29 | def from_dict(cls, d: Dict): 30 | return cls( 31 | name=d["name"], 32 | data=d["data"], 33 | dep_type=d["dep_type"] 34 | ) 35 | 36 | 37 | @dataclass 38 | class Question: 39 | """ 40 | Question 41 | """ 42 | 43 | content: str 44 | deps: Optional[List[QuestionDep]] = None 45 | 46 | @classmethod 47 | def from_dict(cls, d: Dict): 48 | deps = [] 49 | for dep in d["deps"]: 50 | deps.append(QuestionDep.from_dict(dep)) 51 | return cls(d["content"], deps) 52 | 53 | 54 | @dataclass 55 | class Answer: 56 | """ 57 | data_type: 58 | 0: str 59 | 1: base64 image 60 | """ 61 | 62 | answer: Any 63 | answer_type: int = DataType.STR 64 | 65 | 66 | @dataclass 67 | class CacheData: 68 | """ 69 | CacheData 70 | """ 71 | 72 | question: Union[str, Question] 73 | answers: List[Answer] 74 | embedding_data: Optional[np.ndarray] = None 75 | 76 | def __init__(self, question, answers, embedding_data=None): 77 | self.question = question 78 | self.answers = [] 79 | if isinstance(answers, (str, Answer)): 80 | answers = [answers] 81 | for data in answers: 82 | if isinstance(data, (list, tuple)): 83 | self.answers.append(Answer(*data)) 84 | elif isinstance(data, Answer): 85 | self.answers.append(data) 86 | else: 87 | self.answers.append(Answer(answer=data)) 88 | self.embedding_data = embedding_data 89 | 90 | 91 | class CacheStorage(metaclass=ABCMeta): 92 | """ 93 | BaseStorage for scalar data. 
94 | """ 95 | 96 | @abstractmethod 97 | def create(self): 98 | pass 99 | 100 | @abstractmethod 101 | def insert_query_resp(self, query_resp, **kwargs): 102 | pass 103 | 104 | @abstractmethod 105 | def get_data_by_id(self, key): 106 | pass 107 | 108 | @abstractmethod 109 | def mark_deleted(self, keys): 110 | pass 111 | 112 | @abstractmethod 113 | def model_deleted(self, model): 114 | pass 115 | 116 | @abstractmethod 117 | def clear_deleted_data(self): 118 | pass 119 | 120 | @abstractmethod 121 | def get_ids(self, deleted=True): 122 | pass 123 | 124 | @abstractmethod 125 | def count(self): 126 | pass 127 | 128 | @abstractmethod 129 | def flush(self): 130 | pass 131 | 132 | @abstractmethod 133 | def close(self): 134 | pass 135 | 136 | @abstractmethod 137 | def batch_insert(self, all_data: List[CacheData]): 138 | pass 139 | 140 | @abstractmethod 141 | def update_hit_count_by_id(self, primary_id): 142 | pass 143 | 144 | @staticmethod 145 | def get(name, **kwargs): 146 | if name in ["mysql", "oceanbase"]: 147 | from modelcache.manager.scalar_data.sql_storage import SQLStorage 148 | config = kwargs.get("config") 149 | import_sql_client(name) 150 | cache_base = SQLStorage(db_type=name, config=config) 151 | elif name == 'sqlite': 152 | SQL_URL = {"sqlite": "./sqlite.db"} 153 | from modelcache.manager.scalar_data.sql_storage_sqlite import SQLStorage 154 | sql_url = kwargs.get("sql_url", SQL_URL[name]) 155 | cache_base = SQLStorage(db_type=name, url=sql_url) 156 | elif name == 'elasticsearch': 157 | from modelcache.manager.scalar_data.sql_storage_es import SQLStorage 158 | config = kwargs.get("config") 159 | cache_base = SQLStorage(db_type=name, config=config) 160 | else: 161 | raise NotFoundError("cache store", name) 162 | return cache_base 163 | 164 | -------------------------------------------------------------------------------- /docs/1.what-is-model-cache.md: -------------------------------------------------------------------------------- 1 | # What is ModelCache 2 | 3 | In ModelCache, we adopted the main idea of GPTCache, includes core modules: adapter, embedding, similarity, and data_manager. The adapter module is responsible for handling the business logic of various tasks and can connect the embedding, similarity, and data_manager modules. The embedding module is mainly responsible for converting text into semantic vector representations, it transforms user queries into vector form.The rank module is used for sorting and evaluating the similarity of the recalled vectors. The data_manager module is primarily used for managing the database. In order to better facilitate industrial applications, we have made architectural and functional upgrades as follows: 4 | 5 | ## Architecture 6 | 7 | ![modelcache modules](modelcache_modules_20240409.png) 8 | 9 | ## Function comparison 10 | 11 | We've implemented several key updates to our repository. We've resolved network issues with Hugging Face and improved inference speed by introducing local embedding capabilities. Due to limitations in SqlAlchemy, we've redesigned our relational database interaction module for more flexible operations. We've added multi-tenancy support to ModelCache, recognizing the need for multiple users and models in LLM products. Lastly, we've made initial adjustments for better compatibility with system commands and multi-turn dialogues. 
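In practice, the multi-tenancy and data-isolation items compared below all surface through the `model` field carried in each request's scope: writes, queries, and cache clears are routed to the collection registered for that model name, so different models (or tenants) do not see each other's entries. A minimal sketch of a model-scoped query, mirroring the request format shown in the quick-start guide (the endpoint and model name are illustrative):

```python
import json
import requests

# Illustrative endpoint and model name; see docs/3.model-cache-quick-start.md
url = 'http://127.0.0.1:5000/modelcache'

data = {
    'type': 'query',
    # The "model" value scopes this request to one tenant's cache entries
    'scope': {"model": "CODEGPT-1008"},
    'query': [{"role": "user", "content": "Who are you?"}],
}

headers = {"Content-Type": "application/json"}
res = requests.post(url, headers=headers, json=json.dumps(data))
print(res.text)
```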
12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 |
| Module | Function | ModelCache | GPTCache |
| --- | --- | --- | --- |
| Basic Interface | Data query interface | | |
| | Data writing interface | | |
| Embedding | Embedding model configuration | | |
| | Large model embedding layer | | |
| | BERT model long text processing | | |
| Large model invocation | Decoupling from large models | | |
| | Local loading of embedding model | | |
| Data isolation | Model data isolation | | |
| | Hyperparameter isolation | | |
| Databases | MySQL | | |
| | Milvus | | |
| | OceanBase | | |
| Session management | Single-turn dialogue | | |
| | System commands | | |
| | Multi-turn dialogue | | |
| Data management | Data persistence | | |
| | One-click cache clearance | | |
| Tenant management | Support for multi-tenancy | | |
| | Milvus multi-collection capability | | |
| Other | Long-short dialogue distinction | | |
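The session-management rows above (system commands, multi-turn dialogue) rely on a pre-processing step that flattens a message list into a single string before it is embedded and cached. A minimal sketch of that splicing step, following `multi_splicing` in `modelcache/processor/pre.py` (the sample dialogue is illustrative):

```python
def multi_splicing(messages) -> str:
    # Each turn becomes "role###content"; turns are joined with "|||",
    # the same format that multi_analysis later splits back into messages.
    return "|||".join(f"{m.get('role', '')}###{m.get('content', '')}" for m in messages)


dialogue = [
    {"role": "system", "content": "You are an AI code assistant."},
    {"role": "user", "content": "Who are you?"},
]
print(multi_splicing(dialogue))
# system###You are an AI code assistant.|||user###Who are you?
```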
133 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils.error import NotFoundError, ParamError 3 | 4 | TOP_K = 1 5 | FAISS_INDEX_PATH = "mm_faiss.index" 6 | DIMENSION = 0 7 | MILVUS_HOST = "localhost" 8 | MILVUS_PORT = 19530 9 | MILVUS_USER = "" 10 | MILVUS_PSW = "" 11 | MILVUS_SECURE = False 12 | MILVUS_INDEX_PARAMS = { 13 | "metric_type": "L2", 14 | "index_type": "HNSW", 15 | "params": {"M": 8, "efConstruction": 64}, 16 | } 17 | 18 | COLLECTION_NAME = "modelcache" 19 | 20 | 21 | class VectorBase: 22 | """ 23 | VectorBase to manager the vector base. 24 | """ 25 | 26 | def __init__(self): 27 | raise EnvironmentError( 28 | "VectorBase is designed to be instantiated, please using the `VectorBase.get(name)`." 29 | ) 30 | 31 | @staticmethod 32 | def check_dimension(dimension): 33 | if dimension <= 0: 34 | raise ParamError( 35 | f"the dimension should be greater than zero, current value: {dimension}." 36 | ) 37 | 38 | @staticmethod 39 | def get(name, **kwargs): 40 | top_k = kwargs.get("top_k", TOP_K) 41 | if name == "milvus": 42 | from modelcache.manager.vector_data.milvus import Milvus 43 | milvus_config = kwargs.get("milvus_config") 44 | dimension = kwargs.get("dimension", DIMENSION) 45 | VectorBase.check_dimension(dimension) 46 | host = milvus_config.get('milvus', 'host') 47 | port = milvus_config.get('milvus', 'port') 48 | user = milvus_config.get('milvus', 'user') 49 | password = milvus_config.get('milvus', 'password') 50 | 51 | secure = kwargs.get("secure", MILVUS_SECURE) 52 | collection_name = kwargs.get("collection_name", COLLECTION_NAME) 53 | index_params = kwargs.get("index_params", MILVUS_INDEX_PARAMS) 54 | search_params = kwargs.get("search_params", None) 55 | local_mode = kwargs.get("local_mode", False) 56 | local_data = kwargs.get("local_data", "./milvus_data") 57 | vector_base = Milvus( 58 | host=host, 59 | port=port, 60 | user=user, 61 | password=password, 62 | secure=secure, 63 | collection_name=collection_name, 64 | dimension=dimension, 65 | top_k=top_k, 66 | index_params=index_params, 67 | search_params=search_params, 68 | local_mode=local_mode, 69 | local_data=local_data 70 | ) 71 | elif name == "redis": 72 | from modelcache_mm.manager.vector_data.redis import RedisVectorStore 73 | redis_config = kwargs.get("redis_config") 74 | 75 | mm_dimension = kwargs.get("mm_dimension", DIMENSION) 76 | i_dimension = kwargs.get("i_dimension", DIMENSION) 77 | t_dimension = kwargs.get("t_dimension", DIMENSION) 78 | VectorBase.check_dimension(mm_dimension) 79 | VectorBase.check_dimension(i_dimension) 80 | VectorBase.check_dimension(t_dimension) 81 | 82 | host = redis_config.get('redis', 'host') 83 | port = redis_config.get('redis', 'port') 84 | user = redis_config.get('redis', 'user') 85 | password = redis_config.get('redis', 'password') 86 | namespace = kwargs.get("namespace", "") 87 | # collection_name = kwargs.get("collection_name", COLLECTION_NAME) 88 | 89 | vector_base = RedisVectorStore( 90 | host=host, 91 | port=port, 92 | username=user, 93 | password=password, 94 | namespace=namespace, 95 | top_k=top_k, 96 | mm_dimension=mm_dimension, 97 | i_dimension=i_dimension, 98 | t_dimension=t_dimension, 99 | ) 100 | elif name == "faiss": 101 | from modelcache_mm.manager.vector_data.faiss import Faiss 102 | dimension = kwargs.get("dimension", DIMENSION) 103 | 
VectorBase.check_dimension(dimension) 104 | 105 | index_path = kwargs.pop("index_path", FAISS_INDEX_PATH) 106 | vector_base = Faiss( 107 | index_file_path=index_path, 108 | dimension=dimension, 109 | top_k=top_k 110 | ) 111 | elif name == "chromadb": 112 | from modelcache_mm.manager.vector_data.chroma import Chromadb 113 | 114 | chromadb_config = kwargs.get("chromadb_config", None) 115 | persist_directory = chromadb_config.get('chromadb', 'persist_directory') 116 | vector_base = Chromadb( 117 | persist_directory=persist_directory, 118 | top_k=top_k, 119 | ) 120 | else: 121 | raise NotFoundError("vector store", name) 122 | return vector_base 123 | -------------------------------------------------------------------------------- /modelcache/embedding/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import abstractmethod, ABCMeta 3 | 4 | from modelcache.utils.error import CacheError 5 | from modelcache.utils.lazy_import import LazyImport 6 | from enum import Enum 7 | 8 | from modelcache.utils.log import modelcache_log 9 | 10 | huggingface = LazyImport("huggingface", globals(), "modelcache.embedding.huggingface") 11 | data2vec = LazyImport("data2vec", globals(), "modelcache.embedding.data2vec") 12 | llmEmb = LazyImport("llmEmb", globals(), "modelcache.embedding.llmEmb") 13 | fasttext = LazyImport("fasttext", globals(), "modelcache.embedding.fasttext") 14 | paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp") 15 | timm = LazyImport("timm", globals(), "modelcache.embedding.timm") 16 | huggingface_tei = LazyImport("huggingface_tei", globals(), "modelcache.embedding.huggingface_tei") 17 | bge_m3 = LazyImport("bge_m3", globals(), "modelcache.embedding.bge_m3") 18 | 19 | # define the embedding model enum 20 | class EmbeddingModel(Enum): 21 | """ 22 | Enum for different embedding models. 23 | """ 24 | # todo: fill in the dimension and model_path for each embedding model as needed 25 | HUGGINGFACE_ALL_MPNET_BASE_V2 = {"dimension":768, "model_path":"sentence-transformers/all-mpnet-base-v2"} 26 | HUGGINGFACE_ALL_MINILM_L6_V2 = {"dimension":384, "model_path":"sentence-transformers/all-MiniLM-L6-v2"} 27 | HUGGINGFACE_ALL_MINILM_L12_V2 = {"dimension":384, "model_path":"sentence-transformers/all-MiniLM-L12-v2"} 28 | DATA2VEC_AUDIO = {"dimension":768, "model_path":"model/text2vec-base-chinese/"} 29 | LLM_EMB2VEC_AUDIO = {"dimension":None, "model_path":None} 30 | FASTTEXT = {"dimension":None, "model_path":None} 31 | PADDLE_NLP = {"dimension":None, "model_path":None} 32 | TIMM = {"dimension":None, "model_path":None} 33 | HUGGINGFACE_TEI = {"dimension":None, "model_path":None} 34 | BGE_M3 = {"dimension":None, "model_path":None} 35 | 36 | 37 | class MetricType(Enum): 38 | """ 39 | Enum for different metric types used in similarity evaluation. 40 | Different models may require different metrics for optimal performance. 41 | """ 42 | COSINE = "COSINE" 43 | L2 = "L2" 44 | 45 | 46 | class BaseEmbedding(metaclass=ABCMeta): 47 | """ 48 | _Embedding base. 49 | """ 50 | 51 | @abstractmethod 52 | def to_embeddings(self, data, **kwargs): 53 | pass 54 | 55 | @property 56 | @abstractmethod 57 | def dimension(self) -> int: 58 | return 0 59 | 60 | @staticmethod 61 | def get(model:EmbeddingModel, **kwargs): 62 | """ 63 | Get the embedding model instance based on the specified model type. 64 | :param model: The embedding model type. 65 | :type model: EmbeddingModel 66 | :param kwargs: Additional parameters for the model. 
67 | :return: An instance of the specified embedding model. 68 | :rtype: BaseEmbedding 69 | :raises CacheError: If the specified model type is not supported. 70 | """ 71 | if model == EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2: 72 | model_path = kwargs.pop("model_path","sentence-transformers/all-mpnet-base-v2") 73 | return huggingface.Huggingface(model_path) 74 | 75 | elif model == EmbeddingModel.HUGGINGFACE_ALL_MINILM_L6_V2: 76 | model_path = kwargs.pop("model_path","sentence-transformers/all-MiniLM-L6-v2") 77 | return huggingface.Huggingface(model_path) 78 | 79 | elif model == EmbeddingModel.HUGGINGFACE_ALL_MINILM_L12_V2: 80 | model_path = kwargs.pop("model_path","sentence-transformers/all-MiniLM-L12-v2") 81 | return huggingface.Huggingface(model_path) 82 | 83 | elif model == EmbeddingModel.DATA2VEC_AUDIO: 84 | model_path = kwargs.pop("model_path","model/text2vec-base-chinese/") 85 | return data2vec.Data2VecAudio(model_path) 86 | 87 | elif model == EmbeddingModel.LLM_EMB2VEC_AUDIO: 88 | return llmEmb.LlmEmb2Vec() 89 | 90 | elif model == EmbeddingModel.FASTTEXT: 91 | model_path = kwargs.pop("model_path","en") 92 | dim = kwargs.pop("dim", None) 93 | return fasttext.FastText(model_path, dim) 94 | 95 | elif model == EmbeddingModel.PADDLE_NLP: 96 | model_path = kwargs.pop("model_path", "ernie-3.0-medium-zh") 97 | return paddlenlp.PaddleNLP(model_path) 98 | 99 | elif model == EmbeddingModel.TIMM: 100 | model_path = kwargs.pop("model_path", "resnet50") 101 | device = kwargs.pop("device", "default") 102 | return timm.Timm(model_path, device) 103 | 104 | elif model == EmbeddingModel.HUGGINGFACE_TEI: 105 | base_url = kwargs.pop("base_url") 106 | model_path = kwargs.pop("model_path") 107 | return huggingface_tei.HuggingfaceTEI(base_url, model_path) 108 | 109 | elif model == EmbeddingModel.BGE_M3: 110 | model_path = kwargs.pop("model_path","model/bge-m3") 111 | return bge_m3.BgeM3Embedding(model_path) 112 | 113 | else: 114 | modelcache_log.error(f"Please add configuration for {model} in modelcache/embedding/base.py.") 115 | raise CacheError(f"Please add configuration for {model} in modelcache/embedding/base.py.") 116 | 117 | 118 | -------------------------------------------------------------------------------- /modelcache/embedding/data2vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import numpy as np 5 | import torch 6 | from transformers import BertTokenizer, BertModel 7 | from modelcache.embedding.base import BaseEmbedding 8 | 9 | 10 | def mean_pooling(model_output, attention_mask): 11 | token_embeddings = model_output[0] 12 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 13 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 14 | 15 | 16 | class Data2VecAudio(BaseEmbedding): 17 | def __init__(self, model): 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) 19 | parent_dir = os.path.dirname(current_dir) 20 | model_dir = os.path.dirname(parent_dir) 21 | model_path = os.path.join(model_dir, model) 22 | 23 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 24 | self.tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True) 25 | self.model = BertModel.from_pretrained(model_path, local_files_only=True).to(self.device)  # keep the model on the same device as the encoded inputs 26 | 27 | try: 28 | self.__dimension = self.model.config.hidden_size 29 | except Exception: 30 | from transformers import AutoConfig 31 | config =
AutoConfig.from_pretrained(model) 32 | self.__dimension = config.hidden_size 33 | 34 | def to_embeddings(self, data, **_): 35 | encoded_input = self.tokenizer(data, padding=True, truncation=True, return_tensors='pt') 36 | num_tokens = sum(map(len, encoded_input['input_ids'])) 37 | 38 | if num_tokens <= 512: 39 | with torch.no_grad(): 40 | encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()} 41 | model_output = self.model(**encoded_input) 42 | sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) 43 | sentence_embeddings = sentence_embeddings.squeeze(0).detach().cpu().numpy() 44 | embedding_array = np.array(sentence_embeddings).astype("float32") 45 | return embedding_array 46 | else: 47 | window_size = 510 48 | start = 0 49 | input_ids = encoded_input['input_ids'] 50 | input_ids = input_ids[:, 1:-1] 51 | start_token = self.tokenizer.cls_token 52 | end_token = self.tokenizer.sep_token 53 | start_token_id = self.tokenizer.convert_tokens_to_ids(start_token) 54 | end_token_id = self.tokenizer.convert_tokens_to_ids(end_token) 55 | begin_element = torch.tensor([[start_token_id]]) 56 | end_element = torch.tensor([[end_token_id]]) 57 | 58 | embedding_array_list = list() 59 | while start < num_tokens: 60 | # Calculate the ending position of the sliding window. 61 | end = start + window_size 62 | # If the ending position exceeds the length, adjust it to the length. 63 | if end > num_tokens: 64 | end = num_tokens 65 | # Retrieve the data within the sliding window. 66 | input_ids_window = input_ids[:, start:end] 67 | # Insert a new element at position 0. 68 | input_ids_window = torch.cat([begin_element, input_ids_window[:, 0:]], dim=1) 69 | # Insert a new element at the last position. 70 | input_ids_window = torch.cat([input_ids_window, end_element], dim=1) 71 | input_ids_window_length = sum(map(len, input_ids_window)) 72 | token_type_ids = torch.tensor([[0] * input_ids_window_length]) 73 | attention_mask = torch.tensor([[1] * input_ids_window_length]) 74 | 75 | # Concatenate new input_ids 76 | encoded_input_window = {'input_ids': input_ids_window, 'token_type_ids': token_type_ids, 77 | 'attention_mask': attention_mask} 78 | with torch.no_grad(): 79 | encoded_input_window = {k: v.to(self.device) for k, v in encoded_input_window.items()} 80 | model_output_window = self.model(**encoded_input_window) 81 | 82 | sentence_embeddings_window = mean_pooling(model_output_window, encoded_input_window['attention_mask']) 83 | sentence_embeddings_window = sentence_embeddings_window.squeeze(0).detach().cpu().numpy() 84 | embedding_array_window = np.array(sentence_embeddings_window).astype("float32") 85 | embedding_array_list.append(embedding_array_window) 86 | start = end 87 | 88 | embedding_array = np.mean(embedding_array_list, axis=0) 89 | return embedding_array 90 | 91 | def post_proc(self, token_embeddings, inputs): 92 | attention_mask = inputs["attention_mask"] 93 | input_mask_expanded = ( 94 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 95 | ) 96 | sentence_embs = torch.sum( 97 | token_embeddings * input_mask_expanded, 1 98 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 99 | return sentence_embs 100 | 101 | @property 102 | def dimension(self): 103 | """Embedding dimension. 
104 | 105 | :return: embedding dimension 106 | """ 107 | return self.__dimension 108 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/arc_cache.py: -------------------------------------------------------------------------------- 1 | from cachetools import Cache 2 | from collections import OrderedDict 3 | from readerwriterlock import rwlock 4 | 5 | _sentinel = object() 6 | 7 | class ARC(Cache): 8 | """ 9 | Adaptive Replacement Cache (ARC) implementation. 10 | 11 | ARC maintains four lists (T1, T2, B1, B2) to adaptively balance 12 | between LRU and LFU eviction strategies based on access patterns. 13 | """ 14 | 15 | def __init__(self, maxsize, getsizeof=None): 16 | """Initialize ARC cache with maximum size.""" 17 | super().__init__(maxsize, getsizeof) 18 | self.t1 = OrderedDict() # Recent items 19 | self.t2 = OrderedDict() # Frequent items 20 | self.b1 = OrderedDict() # Ghost entries for T1 21 | self.b2 = OrderedDict() # Ghost entries for T2 22 | self.p = 0 # Adaptive parameter 23 | self._rw_lock = rwlock.RWLockWrite() # Thread safety 24 | 25 | def __len__(self): 26 | """Return total number of cached items.""" 27 | return len(self.t1) + len(self.t2) 28 | 29 | def __contains__(self, key): 30 | """Check if key exists in cache.""" 31 | return key in self.t1 or key in self.t2 32 | 33 | def _evict_internal(self): 34 | """Internal method to evict items when cache is full.""" 35 | # Evict from cache lists to ghost lists 36 | while len(self.t1) + len(self.t2) > self.maxsize: 37 | if len(self.t1) > self.p or (len(self.t1) == 0 and len(self.t2) > 0): 38 | key, value = self.t1.popitem(last=False) 39 | self.b1[key] = value 40 | else: 41 | key, value = self.t2.popitem(last=False) 42 | self.b2[key] = value 43 | 44 | # Maintain ghost list sizes 45 | while len(self.b1) > (self.maxsize - self.p): 46 | self.b1.popitem(last=False) 47 | while len(self.b2) > self.p: 48 | self.b2.popitem(last=False) 49 | 50 | def __setitem__(self, key, value): 51 | """Insert or update a cache entry.""" 52 | with self._rw_lock.gen_wlock(): 53 | # Remove key from all lists first 54 | for l in (self.t1, self.t2, self.b1, self.b2): 55 | l.pop(key, None) 56 | # Add to recent list (T1) 57 | self.t1[key] = value 58 | self.t1.move_to_end(key) 59 | self._evict_internal() 60 | 61 | def __getitem__(self, key): 62 | """Retrieve a cache entry and update access pattern.""" 63 | with self._rw_lock.gen_wlock(): 64 | if key in self.t1: 65 | # Move from recent to frequent list 66 | value = self.t1.pop(key) 67 | self.t2[key] = value 68 | self.t2.move_to_end(key) 69 | self.p = max(0, self.p - 1) # Adjust adaptive parameter 70 | self._evict_internal() 71 | return value 72 | if key in self.t2: 73 | # Access frequent list 74 | value = self.t2.pop(key) 75 | self.t2[key] = value 76 | self.t2.move_to_end(key) 77 | self.p = min(self.maxsize, self.p + 1) # Adjust adaptive parameter 78 | self._evict_internal() 79 | return value 80 | if key in self.b1: 81 | # Promote from ghost list B1 to frequent list T2 82 | self.b1.pop(key) 83 | self.p = min(self.maxsize, self.p + 1) # Adjust adaptive parameter 84 | self._evict_internal() 85 | value = super().__missing__(key) 86 | self.t2[key] = value 87 | self.t2.move_to_end(key) 88 | return value 89 | if key in self.b2: 90 | # Promote from ghost list B2 to frequent list T2 91 | self.b2.pop(key) 92 | self.p = max(0, self.p - 1) # Adjust adaptive parameter 93 | self._evict_internal() 94 | value = super().__missing__(key) 95 | self.t2[key] = value 96 | 
self.t2.move_to_end(key) 97 | return value 98 | return super().__getitem__(key) 99 | 100 | def __missing__(self, key): 101 | """Handle missing keys.""" 102 | raise KeyError(key) 103 | 104 | def pop(self, key, default=_sentinel): 105 | """Remove a cache entry.""" 106 | with self._rw_lock.gen_wlock(): 107 | for l in (self.t1, self.t2, self.b1, self.b2): 108 | if key in l: 109 | return l.pop(key) 110 | if default is _sentinel: 111 | raise KeyError(key) 112 | return default 113 | 114 | def clear(self): 115 | """Clear all cache entries.""" 116 | with self._rw_lock.gen_wlock(): 117 | self.t1.clear() 118 | self.t2.clear() 119 | self.b1.clear() 120 | self.b2.clear() 121 | self.p = 0 122 | super().clear() 123 | 124 | def __iter__(self): 125 | """Iterate over cache keys.""" 126 | yield from self.t1 127 | yield from self.t2 128 | 129 | def __repr__(self): 130 | """Return string representation of the cache.""" 131 | return (f"ARC(maxsize={self.maxsize}, p={self.p}, len={len(self)}, " 132 | f"t1_len={len(self.t1)}, t2_len={len(self.t2)}, " 133 | f"b1_len={len(self.b1)}, b2_len={len(self.b2)})") 134 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/redis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | import numpy as np 4 | from redis.commands.search.indexDefinition import IndexDefinition, IndexType 5 | from redis.commands.search.query import Query 6 | from redis.commands.search.field import TagField, VectorField, NumericField 7 | from redis.client import Redis 8 | 9 | from modelcache.manager.vector_data.base import VectorStorage, VectorData 10 | from modelcache.utils import import_redis 11 | from modelcache.utils.log import modelcache_log 12 | from modelcache.utils.index_util import get_index_name 13 | from modelcache.utils.index_util import get_index_prefix 14 | import_redis() 15 | 16 | 17 | class RedisVectorStore(VectorStorage): 18 | def __init__( 19 | self, 20 | host: str = "localhost", 21 | port: str = "6379", 22 | username: str = "", 23 | password: str = "", 24 | dimension: int = 0, 25 | top_k: int = 1, 26 | namespace: str = "", 27 | ): 28 | if dimension <= 0: 29 | raise ValueError( 30 | f"invalid `dim` param: {dimension} in the Redis vector store." 
31 | ) 32 | self._client = Redis( 33 | host=host, port=int(port), username=username, password=password 34 | ) 35 | self.top_k = top_k 36 | self.dimension = dimension 37 | self.namespace = namespace 38 | self.doc_prefix = f"{self.namespace}doc:" 39 | 40 | def _check_index_exists(self, index_name: str) -> bool: 41 | """Check if Redis index exists.""" 42 | try: 43 | self._client.ft(index_name).info() 44 | except: 45 | modelcache_log.info("Index does not exist") 46 | return False 47 | modelcache_log.info("Index already exists") 48 | return True 49 | 50 | def create_index(self, index_name, index_prefix): 51 | dimension = self.dimension 52 | if self._check_index_exists(index_name): 53 | modelcache_log.info( 54 | "The %s already exists, and it will be used directly", index_name 55 | ) 56 | return 'already_exists' 57 | else: 58 | id_field_name = "data_id" 59 | embedding_field_name = "data_vector" 60 | 61 | id = NumericField(name=id_field_name) 62 | embedding = VectorField(embedding_field_name, 63 | "HNSW", { 64 | "TYPE": "FLOAT32", 65 | "DIM": dimension, 66 | "DISTANCE_METRIC": "L2", 67 | "INITIAL_CAP": 1000, 68 | } 69 | ) 70 | fields = [id, embedding] 71 | definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH) 72 | 73 | # create Index 74 | self._client.ft(index_name).create_index( 75 | fields=fields, definition=definition 76 | ) 77 | return 'create_success' 78 | 79 | def mul_add(self, datas: List[VectorData], model=None): 80 | # pipe = self._client.pipeline() 81 | for data in datas: 82 | id: int = data.id 83 | embedding = data.data.astype(np.float32).tobytes() 84 | id_field_name = "data_id" 85 | embedding_field_name = "data_vector" 86 | obj = {id_field_name: id, embedding_field_name: embedding} 87 | index_prefix = get_index_prefix(model) 88 | self._client.hset(f"{index_prefix}{id}", mapping=obj) 89 | 90 | def search(self, data: np.ndarray, top_k: int = -1, model=None): 91 | index_name = get_index_name(model) 92 | id_field_name = "data_id" 93 | embedding_field_name = "data_vector" 94 | base_query = f'*=>[KNN 2 @{embedding_field_name} $vector AS distance]' 95 | query = ( 96 | Query(base_query) 97 | .sort_by("distance") 98 | .return_fields(id_field_name, "distance") 99 | .dialect(2) 100 | ) 101 | 102 | query_params = {"vector": data.astype(np.float32).tobytes()} 103 | results = ( 104 | self._client.ft(index_name) 105 | .search(query, query_params=query_params) 106 | .docs 107 | ) 108 | return [(float(result.distance), int(getattr(result, id_field_name))) for result in results] 109 | 110 | def rebuild(self, ids=None) -> bool: 111 | pass 112 | 113 | def rebuild_col(self, model): 114 | index_name_model = get_index_name(model) 115 | if self._check_index_exists(index_name_model): 116 | try: 117 | self._client.ft(index_name_model).dropindex(delete_documents=True) 118 | except Exception as e: 119 | raise ValueError(str(e)) 120 | try: 121 | index_prefix = get_index_prefix(model) 122 | self.create_index(index_name_model, index_prefix) 123 | except Exception as e: 124 | raise ValueError(str(e)) 125 | # return 'rebuild success' 126 | 127 | def delete(self, ids) -> None: 128 | pipe = self._client.pipeline() 129 | for data_id in ids: 130 | pipe.delete(f"{self.doc_prefix}{data_id}") 131 | pipe.execute() 132 | 133 | def create(self, model=None): 134 | index_name = get_index_name(model) 135 | index_prefix = get_index_prefix(model) 136 | return self.create_index(index_name, index_prefix) 137 | 138 | def get_index_by_name(self, index_name): 139 | pass 140 | 
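# Minimal usage sketch for the store above (illustrative only): it assumes a local Redis Stack
# instance with the RediSearch module loaded, a 768-dimensional embedding, and a hypothetical
# model name "demo_model"; it is not part of the module's documented API surface.
if __name__ == "__main__":
    from modelcache.manager.vector_data.base import VectorData

    store = RedisVectorStore(host="localhost", port="6379", dimension=768, top_k=1)
    store.create(model="demo_model")                                   # build the HNSW index if missing
    vec = np.random.rand(768).astype(np.float32)
    store.mul_add([VectorData(id=1, data=vec)], model="demo_model")    # write one vector under the index prefix
    print(store.search(vec, model="demo_model"))                       # -> [(distance, data_id), ...]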
-------------------------------------------------------------------------------- /modelcache/manager/vector_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | from typing import List 5 | from dataclasses import dataclass 6 | 7 | from modelcache.embedding import MetricType 8 | from modelcache.utils.error import ParamError, NotFoundError 9 | 10 | TOP_K = 1 11 | FAISS_INDEX_PATH = "faiss.index" 12 | DIMENSION = 0 13 | MILVUS_HOST = "localhost" 14 | MILVUS_PORT = 19530 15 | MILVUS_USER = "" 16 | MILVUS_PSW = "" 17 | MILVUS_SECURE = False 18 | 19 | COLLECTION_NAME = "modelcache" 20 | 21 | @dataclass 22 | class VectorData: 23 | id: int 24 | data: np.ndarray 25 | 26 | 27 | class VectorStorage(ABC): 28 | 29 | @abstractmethod 30 | def mul_add(self, datas: List[VectorData], model=None): 31 | pass 32 | 33 | @abstractmethod 34 | def search(self, data: np.ndarray, top_k: int, model): 35 | pass 36 | 37 | @abstractmethod 38 | def rebuild(self, ids=None) -> bool: 39 | pass 40 | 41 | @abstractmethod 42 | def delete(self, ids) -> bool: 43 | pass 44 | 45 | @abstractmethod 46 | def rebuild_col(self, model): 47 | pass 48 | 49 | @abstractmethod 50 | def flush(self): 51 | pass 52 | 53 | @abstractmethod 54 | def close(self): 55 | pass 56 | 57 | @staticmethod 58 | def get(name, **kwargs): 59 | top_k = kwargs.get("top_k", TOP_K) 60 | if name == "milvus": 61 | from modelcache.manager.vector_data.milvus import Milvus 62 | dimension = kwargs.get("dimension", DIMENSION) 63 | milvus_config = kwargs.get("config") 64 | check_dimension(dimension) 65 | host = milvus_config.get('milvus', 'host') 66 | port = milvus_config.get('milvus', 'port') 67 | user = milvus_config.get('milvus', 'user') 68 | password = milvus_config.get('milvus', 'password') 69 | 70 | metric_type = kwargs.get("metric_type",MetricType.COSINE) 71 | secure = kwargs.get("secure", MILVUS_SECURE) 72 | collection_name = kwargs.get("collection_name", COLLECTION_NAME) 73 | index_params = kwargs.get("index_params", None) 74 | search_params = kwargs.get("search_params", None) 75 | local_mode = kwargs.get("local_mode", False) 76 | local_data = kwargs.get("local_data", "./milvus_data") 77 | vector_base = Milvus( 78 | host=host, 79 | port=port, 80 | user=user, 81 | password=password, 82 | secure=secure, 83 | collection_name=collection_name, 84 | dimension=dimension, 85 | top_k=top_k, 86 | index_params=index_params, 87 | search_params=search_params, 88 | local_mode=local_mode, 89 | local_data=local_data, 90 | metric_type=metric_type 91 | ) 92 | elif name == "redis": 93 | from modelcache.manager.vector_data.redis import RedisVectorStore 94 | dimension = kwargs.get("dimension", DIMENSION) 95 | check_dimension(dimension) 96 | 97 | redis_config = kwargs.get("config") 98 | host = redis_config.get('redis', 'host') 99 | port = redis_config.get('redis', 'port') 100 | user = redis_config.get('redis', 'user') 101 | password = redis_config.get('redis', 'password') 102 | namespace = kwargs.get("namespace", "") 103 | # collection_name = kwargs.get("collection_name", COLLECTION_NAME) 104 | 105 | vector_base = RedisVectorStore( 106 | host=host, 107 | port=port, 108 | username=user, 109 | password=password, 110 | namespace=namespace, 111 | top_k=top_k, 112 | dimension=dimension, 113 | ) 114 | elif name == "faiss": 115 | from modelcache.manager.vector_data.faiss import Faiss 116 | 117 | dimension = kwargs.get("dimension", DIMENSION) 118 | index_path 
= kwargs.pop("index_path", FAISS_INDEX_PATH) 119 | check_dimension(dimension) 120 | vector_base = Faiss( 121 | index_file_path=index_path, dimension=dimension, top_k=top_k 122 | ) 123 | elif name == "chromadb": 124 | from modelcache.manager.vector_data.chroma import Chromadb 125 | 126 | chromadb_config = kwargs.get("config", None) 127 | persist_directory = chromadb_config.get('chromadb','persist_directory') 128 | 129 | vector_base = Chromadb( 130 | persist_directory=persist_directory, 131 | top_k=top_k, 132 | ) 133 | elif name == "hnswlib": 134 | from modelcache.manager.vector_data.hnswlib_store import Hnswlib 135 | 136 | dimension = kwargs.get("dimension", DIMENSION) 137 | index_path = kwargs.pop("index_path", "./hnswlib_index.bin") 138 | max_elements = kwargs.pop("max_elements", 100000) 139 | VectorStorage.check_dimension(dimension) 140 | vector_base = Hnswlib( 141 | index_file_path=index_path, dimension=dimension, 142 | top_k=top_k, max_elements=max_elements 143 | ) 144 | else: 145 | raise NotFoundError("vector store", name) 146 | return vector_base 147 | 148 | 149 | def check_dimension(dimension): 150 | if dimension <= 0: 151 | raise ParamError(f"the dimension should be greater than zero, current value: {dimension}.") 152 | 153 | 154 | -------------------------------------------------------------------------------- /tests/test_setup_validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Validation tests to ensure the testing infrastructure is set up correctly. 3 | """ 4 | import pytest 5 | import sys 6 | from pathlib import Path 7 | 8 | 9 | class TestSetupValidation: 10 | """Test class to validate the testing infrastructure setup.""" 11 | 12 | def test_pytest_installed(self): 13 | """Verify pytest is installed and importable.""" 14 | import pytest 15 | assert pytest.__version__ 16 | 17 | def test_pytest_cov_installed(self): 18 | """Verify pytest-cov is installed.""" 19 | import pytest_cov 20 | assert pytest_cov 21 | 22 | def test_pytest_mock_installed(self): 23 | """Verify pytest-mock is installed.""" 24 | import pytest_mock 25 | assert pytest_mock 26 | 27 | def test_modelcache_importable(self): 28 | """Verify the main modelcache package can be imported.""" 29 | import modelcache 30 | assert modelcache 31 | 32 | def test_project_structure(self): 33 | """Verify the expected project structure exists.""" 34 | project_root = Path(__file__).parent.parent 35 | 36 | # Check main directories 37 | assert (project_root / "modelcache").exists() 38 | assert (project_root / "modelcache_mm").exists() 39 | assert (project_root / "tests").exists() 40 | assert (project_root / "tests" / "unit").exists() 41 | assert (project_root / "tests" / "integration").exists() 42 | 43 | # Check configuration files 44 | assert (project_root / "pyproject.toml").exists() 45 | 46 | @pytest.mark.unit 47 | def test_unit_marker(self): 48 | """Test that unit marker works correctly.""" 49 | assert True 50 | 51 | @pytest.mark.integration 52 | def test_integration_marker(self): 53 | """Test that integration marker works correctly.""" 54 | assert True 55 | 56 | @pytest.mark.slow 57 | def test_slow_marker(self): 58 | """Test that slow marker works correctly.""" 59 | assert True 60 | 61 | def test_fixtures_available(self, temp_dir, mock_config, mock_embedding): 62 | """Test that custom fixtures are available and working.""" 63 | # Test temp_dir fixture 64 | assert temp_dir.exists() 65 | assert temp_dir.is_dir() 66 | 67 | # Test mock_config fixture 68 | assert 
isinstance(mock_config, dict) 69 | assert "cache_dir" in mock_config 70 | assert "embedding_model" in mock_config 71 | 72 | # Test mock_embedding fixture 73 | assert hasattr(mock_embedding, "embed") 74 | assert hasattr(mock_embedding, "dimension") 75 | 76 | def test_sample_data_fixtures(self, sample_vector_data, sample_text_data): 77 | """Test that sample data fixtures provide expected data.""" 78 | # Test vector data 79 | assert isinstance(sample_vector_data, dict) 80 | assert "id" in sample_vector_data 81 | assert "vector" in sample_vector_data 82 | assert len(sample_vector_data["vector"]) == 768 83 | 84 | # Test text data 85 | assert isinstance(sample_text_data, list) 86 | assert len(sample_text_data) > 0 87 | assert all(isinstance(text, str) for text in sample_text_data) 88 | 89 | def test_mock_fixtures(self, mock_redis_client, mock_milvus_client, mock_cache_manager): 90 | """Test that mock fixtures are properly configured.""" 91 | # Test Redis mock 92 | assert mock_redis_client.get("test") is None 93 | assert mock_redis_client.set("test", "value") is True 94 | 95 | # Test Milvus mock 96 | assert hasattr(mock_milvus_client, "search") 97 | assert hasattr(mock_milvus_client, "insert") 98 | 99 | # Test cache manager mock 100 | assert mock_cache_manager.get("test") is None 101 | assert mock_cache_manager.set("test", "value") is True 102 | 103 | def test_environment_reset(self): 104 | """Test that environment is properly set for testing.""" 105 | import os 106 | assert os.environ.get("MODELCACHE_ENV") == "test" 107 | assert os.environ.get("MODELCACHE_LOG_LEVEL") == "DEBUG" 108 | 109 | def test_coverage_configured(self): 110 | """Test that coverage is properly configured.""" 111 | # This test will be meaningful when running with coverage 112 | # For now, just ensure the test runs 113 | assert True 114 | 115 | 116 | @pytest.mark.unit 117 | class TestUnitTestValidation: 118 | """Validate unit test setup.""" 119 | 120 | def test_unit_tests_discoverable(self): 121 | """Ensure unit tests can be discovered and run.""" 122 | assert True 123 | 124 | def test_unit_test_isolation(self, temp_dir): 125 | """Ensure unit tests have proper isolation with temp directories.""" 126 | test_file = temp_dir / "test.txt" 127 | test_file.write_text("test content") 128 | assert test_file.exists() 129 | assert test_file.read_text() == "test content" 130 | 131 | 132 | @pytest.mark.integration 133 | class TestIntegrationTestValidation: 134 | """Validate integration test setup.""" 135 | 136 | def test_integration_tests_discoverable(self): 137 | """Ensure integration tests can be discovered and run.""" 138 | assert True 139 | 140 | def test_integration_mock_available(self, mock_http_response): 141 | """Ensure integration tests have access to HTTP mocks.""" 142 | assert mock_http_response.status_code == 200 143 | assert mock_http_response.json() == {"status": "success", "data": {}} -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shared pytest fixtures and configuration for modelcache tests. 3 | """ 4 | import os 5 | import tempfile 6 | import shutil 7 | from pathlib import Path 8 | from typing import Iterator, Dict, Any 9 | import pytest 10 | from unittest.mock import MagicMock 11 | 12 | 13 | @pytest.fixture 14 | def temp_dir() -> Iterator[Path]: 15 | """ 16 | Create a temporary directory for test files. 
17 | 18 | Yields: 19 | Path: Path to the temporary directory 20 | """ 21 | temp_path = Path(tempfile.mkdtemp()) 22 | yield temp_path 23 | # Cleanup after test 24 | if temp_path.exists(): 25 | shutil.rmtree(temp_path) 26 | 27 | 28 | @pytest.fixture 29 | def mock_config() -> Dict[str, Any]: 30 | """ 31 | Provide a mock configuration dictionary for testing. 32 | 33 | Returns: 34 | Dict[str, Any]: Mock configuration 35 | """ 36 | return { 37 | "cache_dir": "/tmp/test_cache", 38 | "max_cache_size": 1000, 39 | "ttl": 3600, 40 | "embedding_model": "test-model", 41 | "similarity_threshold": 0.8, 42 | "vector_dimension": 768, 43 | "batch_size": 32, 44 | "database": { 45 | "type": "memory", 46 | "host": "localhost", 47 | "port": 6379, 48 | "password": None 49 | } 50 | } 51 | 52 | 53 | @pytest.fixture 54 | def mock_embedding(): 55 | """ 56 | Mock embedding object for testing. 57 | 58 | Returns: 59 | MagicMock: Mock embedding with common methods 60 | """ 61 | mock = MagicMock() 62 | mock.embed.return_value = [0.1] * 768 # Default 768-dim embedding 63 | mock.embed_batch.return_value = [[0.1] * 768] * 10 64 | mock.dimension = 768 65 | mock.model_name = "test-embedding-model" 66 | return mock 67 | 68 | 69 | @pytest.fixture 70 | def mock_cache_manager(): 71 | """ 72 | Mock cache manager for testing. 73 | 74 | Returns: 75 | MagicMock: Mock cache manager with common methods 76 | """ 77 | mock = MagicMock() 78 | mock.get.return_value = None 79 | mock.set.return_value = True 80 | mock.delete.return_value = True 81 | mock.clear.return_value = True 82 | mock.size.return_value = 0 83 | return mock 84 | 85 | 86 | @pytest.fixture 87 | def sample_vector_data(): 88 | """ 89 | Sample vector data for testing vector operations. 90 | 91 | Returns: 92 | Dict[str, Any]: Sample vector data 93 | """ 94 | return { 95 | "id": "test_vector_001", 96 | "vector": [0.1, 0.2, 0.3, 0.4, 0.5] * 153 + [0.6, 0.7, 0.8], # 768 dimensions 97 | "metadata": { 98 | "source": "test", 99 | "timestamp": 1234567890, 100 | "model": "test-model" 101 | } 102 | } 103 | 104 | 105 | @pytest.fixture 106 | def mock_redis_client(): 107 | """ 108 | Mock Redis client for testing Redis-based operations. 109 | 110 | Returns: 111 | MagicMock: Mock Redis client 112 | """ 113 | mock = MagicMock() 114 | mock.get.return_value = None 115 | mock.set.return_value = True 116 | mock.delete.return_value = 1 117 | mock.exists.return_value = 0 118 | mock.expire.return_value = True 119 | mock.ttl.return_value = -2 120 | return mock 121 | 122 | 123 | @pytest.fixture 124 | def mock_milvus_client(): 125 | """ 126 | Mock Milvus client for testing vector database operations. 127 | 128 | Returns: 129 | MagicMock: Mock Milvus client 130 | """ 131 | mock = MagicMock() 132 | mock.create_collection.return_value = True 133 | mock.insert.return_value = MagicMock(primary_keys=[1, 2, 3]) 134 | mock.search.return_value = [[]] 135 | mock.query.return_value = [] 136 | mock.delete.return_value = MagicMock(delete_count=1) 137 | return mock 138 | 139 | 140 | @pytest.fixture(autouse=True) 141 | def reset_environment(): 142 | """ 143 | Reset environment variables before each test. 
144 | """ 145 | # Store original env vars 146 | original_env = os.environ.copy() 147 | 148 | # Set test environment variables 149 | os.environ["MODELCACHE_ENV"] = "test" 150 | os.environ["MODELCACHE_LOG_LEVEL"] = "DEBUG" 151 | 152 | yield 153 | 154 | # Restore original env vars 155 | os.environ.clear() 156 | os.environ.update(original_env) 157 | 158 | 159 | @pytest.fixture 160 | def sample_text_data(): 161 | """ 162 | Sample text data for testing text processing. 163 | 164 | Returns: 165 | List[str]: List of sample texts 166 | """ 167 | return [ 168 | "This is a test sentence for modelcache.", 169 | "Machine learning models need efficient caching.", 170 | "Vector embeddings help with semantic search.", 171 | "Testing is important for code quality.", 172 | "PyTest makes testing in Python easier." 173 | ] 174 | 175 | 176 | @pytest.fixture 177 | def mock_http_response(): 178 | """ 179 | Mock HTTP response for testing API calls. 180 | 181 | Returns: 182 | MagicMock: Mock response object 183 | """ 184 | mock = MagicMock() 185 | mock.status_code = 200 186 | mock.json.return_value = {"status": "success", "data": {}} 187 | mock.text = '{"status": "success", "data": {}}' 188 | mock.headers = {"Content-Type": "application/json"} 189 | return mock 190 | 191 | 192 | # Pytest configuration hooks 193 | def pytest_configure(config): 194 | """ 195 | Configure pytest with custom settings. 196 | """ 197 | # Add custom markers description 198 | config.addinivalue_line( 199 | "markers", "unit: mark test as a unit test" 200 | ) 201 | config.addinivalue_line( 202 | "markers", "integration: mark test as an integration test" 203 | ) 204 | config.addinivalue_line( 205 | "markers", "slow: mark test as slow running" 206 | ) 207 | 208 | 209 | def pytest_collection_modifyitems(config, items): 210 | """ 211 | Modify test collection to add markers based on test location. 212 | """ 213 | for item in items: 214 | # Auto-mark tests based on their location 215 | if "unit" in str(item.fspath): 216 | item.add_marker(pytest.mark.unit) 217 | elif "integration" in str(item.fspath): 218 | item.add_marker(pytest.mark.integration) -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/redis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | import numpy as np 4 | from redis.commands.search.indexDefinition import IndexDefinition, IndexType 5 | from redis.commands.search.query import Query 6 | from redis.commands.search.field import VectorField, NumericField 7 | from redis.client import Redis 8 | 9 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData 10 | from modelcache_mm.utils import import_redis 11 | from modelcache_mm.utils.log import modelcache_log 12 | from modelcache_mm.utils.index_util import get_mm_index_name 13 | from modelcache_mm.utils.index_util import get_mm_index_prefix 14 | import_redis() 15 | 16 | 17 | class RedisVectorStore(VectorBase): 18 | def __init__( 19 | self, 20 | host: str = "localhost", 21 | port: str = "6379", 22 | username: str = "", 23 | password: str = "", 24 | mm_dimension: int = 0, 25 | i_dimension: int = 0, 26 | t_dimension: int = 0, 27 | top_k: int = 1, 28 | namespace: str = "", 29 | ): 30 | if mm_dimension <= 0: 31 | raise ValueError( 32 | f"invalid `dim` param: {mm_dimension} in the Milvus vector store." 
33 | ) 34 | self._client = Redis( 35 | host=host, port=int(port), username=username, password=password 36 | ) 37 | self.top_k = top_k 38 | self.mm_dimension = mm_dimension 39 | self.i_dimension = i_dimension 40 | self.t_dimension = t_dimension 41 | self.namespace = namespace 42 | self.doc_prefix = f"{self.namespace}doc:" 43 | 44 | def _check_index_exists(self, index_name: str) -> bool: 45 | """Check if Redis index exists.""" 46 | try: 47 | self._client.ft(index_name).info() 48 | except: 49 | modelcache_log.info("Index does not exist") 50 | return False 51 | modelcache_log.info("Index already exists") 52 | return True 53 | 54 | def create_index(self, index_name, type, index_prefix): 55 | # dimension = self.dimension 56 | if type == 'IMG_TEXT': 57 | dimension = self.mm_dimension 58 | elif type == 'IMG': 59 | dimension = self.i_dimension 60 | elif type == 'TEXT': 61 | dimension = self.t_dimension 62 | else: 63 | raise ValueError('dimension type exception') 64 | if self._check_index_exists(index_name): 65 | modelcache_log.info( 66 | "The %s already exists, and it will be used directly", index_name 67 | ) 68 | return 'already_exists' 69 | else: 70 | id_field_name = "data_id" 71 | embedding_field_name = "data_vector" 72 | 73 | id = NumericField(name=id_field_name) 74 | embedding = VectorField(embedding_field_name, 75 | "HNSW", { 76 | "TYPE": "FLOAT32", 77 | "DIM": dimension, 78 | "DISTANCE_METRIC": "L2", 79 | "INITIAL_CAP": 1000, 80 | } 81 | ) 82 | fields = [id, embedding] 83 | definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH) 84 | 85 | # create Index 86 | self._client.ft(index_name).create_index( 87 | fields=fields, definition=definition 88 | ) 89 | return 'create_success' 90 | 91 | def add(self, datas: List[VectorData], model=None, mm_type=None): 92 | for data in datas: 93 | id: int = data.id 94 | embedding = data.data.astype(np.float32).tobytes() 95 | index_prefix = get_mm_index_prefix(model, mm_type) 96 | id_field_name = "data_id" 97 | embedding_field_name = "data_vector" 98 | obj = {id_field_name: id, embedding_field_name: embedding} 99 | self._client.hset(f"{index_prefix}{id}", mapping=obj) 100 | 101 | def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type=None): 102 | index_name = get_mm_index_name(model, mm_type) 103 | id_field_name = "data_id" 104 | embedding_field_name = "data_vector" 105 | 106 | base_query = f'*=>[KNN 2 @{embedding_field_name} $vector AS distance]' 107 | query = ( 108 | Query(base_query) 109 | .sort_by("distance") 110 | .return_fields(id_field_name, "distance") 111 | .dialect(2) 112 | ) 113 | query_params = {"vector": data.astype(np.float32).tobytes()} 114 | results = ( 115 | self._client.ft(index_name) 116 | .search(query, query_params=query_params) 117 | .docs 118 | ) 119 | return [(float(result.distance), int(getattr(result, id_field_name))) for result in results] 120 | 121 | def create(self, model=None, mm_type=None): 122 | collection_name_model = get_mm_index_name(model, mm_type) 123 | try: 124 | index_prefix = get_mm_index_prefix(model, mm_type) 125 | self.create_index(collection_name_model, mm_type, index_prefix) 126 | except Exception as e: 127 | raise ValueError(str(e)) 128 | return 'success' 129 | 130 | def rebuild(self, ids=None) -> bool: 131 | pass 132 | 133 | def rebuild_idx(self, model, mm_type=None): 134 | for mm_type in ['IMG_TEXT', 'TEXT']: 135 | index_name = get_mm_index_name(model, mm_type) 136 | if self._check_index_exists(index_name): 137 | try: 138 | 
self._client.ft(index_name).dropindex(delete_documents=True) 139 | except Exception as e: 140 | raise ValueError(str(e)) 141 | try: 142 | index_prefix = get_mm_index_prefix(model, mm_type) 143 | self.create_index(index_name, mm_type, index_prefix) 144 | except Exception as e: 145 | raise ValueError(str(e)) 146 | 147 | def delete(self, ids) -> None: 148 | pipe = self._client.pipeline() 149 | for data_id in ids: 150 | pipe.delete(f"{self.doc_prefix}{data_id}") 151 | pipe.execute() 152 | 153 | def create(self, model=None, type=None): 154 | index_name = get_mm_index_name(model, type) 155 | index_prefix = get_mm_index_prefix(model, type) 156 | return self.create_index(index_name, type, index_prefix) 157 | 158 | def get_index_by_name(self, index_name): 159 | pass 160 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/wtinylfu_cache.py: -------------------------------------------------------------------------------- 1 | from cachetools import LRUCache, Cache, LFUCache 2 | from readerwriterlock import rwlock 3 | import random 4 | 5 | class CountMinSketch: 6 | def __init__(self, width=1024, depth=4, decay_interval=10000): 7 | """Initialize Count-Min Sketch with specified dimensions.""" 8 | self.width = width 9 | self.depth = depth 10 | self.tables = [[0]*width for _ in range(depth)] # Hash tables 11 | self.seeds = [random.randrange(1<<30) for _ in range(depth)] # Hash seeds 12 | self.ops = 0 # Operation counter for decay trigger 13 | self.decay_interval = decay_interval 14 | 15 | def _hash(self, x, seed): 16 | """Hash function for mapping items to table positions.""" 17 | return hash((x, seed)) % self.width 18 | 19 | def add(self, x): 20 | """Add an item and increment its frequency estimate.""" 21 | self.ops += 1 22 | est = self.estimate(x) # Get current estimate 23 | # Update all hash tables 24 | for i, seed in enumerate(self.seeds): 25 | idx = self._hash(x, seed) 26 | if self.tables[i][idx] <= est: 27 | self.tables[i][idx] += 1 28 | 29 | # Periodic decay to handle changing patterns 30 | if self.ops >= self.decay_interval: 31 | self.decay() 32 | self.ops = 0 33 | 34 | def estimate(self, x): 35 | """Estimate frequency of an item (minimum across all tables).""" 36 | return min(self.tables[i][self._hash(x, seed)] 37 | for i, seed in enumerate(self.seeds)) 38 | 39 | def decay(self): 40 | """Decay all frequency counts by half.""" 41 | for table in self.tables: 42 | for i in range(len(table)): 43 | table[i] >>= 1 # Right shift (divide by 2) 44 | 45 | class W2TinyLFU(Cache): 46 | """ 47 | Window Tiny LFU cache implementation. 48 | 49 | Combines a small LRU window cache with a main cache divided into 50 | probation and protected segments, using frequency estimation for 51 | admission control. 52 | """ 53 | 54 | def __init__(self, maxsize, window_pct=0.01): 55 | """ 56 | Initialize W-TinyLFU cache. 
57 | 58 | Args: 59 | maxsize: Maximum size of the cache 60 | window_pct: Percentage of cache size for the window (default 1%) 61 | """ 62 | super().__init__(maxsize) 63 | self.window_size = max(1, int(maxsize * window_pct)) 64 | rest = maxsize - self.window_size 65 | self.probation_size = rest // 2 66 | self.protected_size = rest - self.probation_size 67 | 68 | # Three cache segments 69 | self.window = LRUCache(maxsize=self.window_size) # Recent items 70 | self.probation = LFUCache(maxsize=self.probation_size) # New main cache items 71 | self.protected = LFUCache(maxsize=self.protected_size) # Frequently accessed items 72 | 73 | self.cms = CountMinSketch() # Frequency estimator 74 | self.data = {} # Cache data storage 75 | self._rw_lock = rwlock.RWLockWrite() # Read-write lock for thread safety 76 | 77 | def __setitem__(self, key, value): 78 | """Add or update an item in the cache.""" 79 | with self._rw_lock.gen_wlock(): 80 | self.data[key] = value 81 | self._put(key) 82 | 83 | def __getitem__(self, key): 84 | """Retrieve an item from the cache.""" 85 | val = self.get(key, default=None) 86 | if val is None: 87 | raise KeyError(key) 88 | return val 89 | 90 | def __contains__(self, key): 91 | """Check if an item exists in the cache.""" 92 | return key in self.window or key in self.probation or key in self.protected 93 | 94 | def __delitem__(self, key): 95 | """Remove an item from the cache.""" 96 | with self._rw_lock.gen_wlock(): 97 | self.data.pop(key, None) 98 | self.window.pop(key, None) 99 | self.probation.pop(key, None) 100 | self.protected.pop(key, None) 101 | 102 | def get(self, key, default=None): 103 | """ 104 | Retrieve an item from the cache, updating its position 105 | in the cache hierarchy if necessary. 106 | """ 107 | if key in self.window: 108 | self.window[key] = True 109 | return self.data.get(key, default) 110 | if key in self.protected: 111 | self.protected[key] = True 112 | return self.data.get(key, default) 113 | if key in self.probation: 114 | self.probation.pop(key) 115 | if len(self.protected) >= self.protected_size: 116 | demoted = next(iter(self.protected)) 117 | self.protected.pop(demoted) 118 | self.probation[demoted] = True 119 | self.protected[key] = True 120 | return self.data.get(key, default) 121 | return default 122 | 123 | def _put(self, key): 124 | """ 125 | Add an item to the cache, using frequency-based admission 126 | control and eviction policies. 127 | """ 128 | self.cms.add(key) 129 | if key in self: 130 | return 131 | 132 | if len(self.window) < self.window_size: 133 | self.window[key] = True 134 | return 135 | 136 | victim = next(iter(self.window)) 137 | self.window.pop(victim) 138 | 139 | if self.cms.estimate(key) >= self.cms.estimate(victim): 140 | self._admit_to_main(victim) 141 | self._admit_to_main(key) 142 | else: 143 | self._admit_to_main(victim) 144 | self.data.pop(key, None) 145 | 146 | def _admit_to_main(self, key): 147 | """ 148 | Admit an item to the main cache (probation or protected segment). 
149 | """ 150 | if key in self.protected or key in self.probation: 151 | return 152 | if self.probation_size == 0: 153 | self.data.pop(key, None) 154 | return 155 | if len(self.probation) < self.probation_size: 156 | self.probation[key] = True 157 | elif self.probation: 158 | evicted = next(iter(self.probation)) 159 | self.probation.pop(evicted) 160 | self.probation[key] = True 161 | self.data.pop(evicted, None) 162 | else: 163 | self.data.pop(key, None) 164 | 165 | def clear(self): 166 | """Clear all items from the cache.""" 167 | with self._rw_lock.gen_wlock(): 168 | self.window.clear() 169 | self.probation.clear() 170 | self.protected.clear() 171 | self.data.clear() 172 | 173 | --------------------------------------------------------------------------------