├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── README_CN.md ├── data ├── milvus │ ├── embedEtcd.yaml │ └── user.yaml └── mysql │ ├── init │ └── init.sql │ └── my.conf ├── docker-compose.yaml ├── docs ├── 1.what-is-model-cache.md ├── 2.model-cache-features.md ├── 3.model-cache-quick-start.md ├── 4.create-cache.md ├── cache-service-cost-time-distribution.webp ├── codefuse-LOGO.png ├── modelcache_modules_20231114.png ├── modelcache_modules_20240409.png ├── script │ └── get_input_embedding_script.py └── time-cost-comparison.webp ├── examples ├── __init__.py ├── embedding │ ├── __init__.py │ └── huggingface_tei_example.py └── flask │ ├── __init__.py │ ├── llms_cache │ ├── __init__.py │ ├── data_insert.py │ ├── data_query.py │ ├── data_query_long.py │ └── register.py │ └── multi_cache │ ├── __init__.py │ ├── data_insert.py │ ├── data_query.py │ ├── register.py │ └── remove.py ├── fastapi4modelcache.py ├── fastapi4modelcache_demo.py ├── flask4modelcache.py ├── flask4modelcache_demo.py ├── flask4multicache.py ├── flask4multicache_demo.py ├── model ├── clip_zh │ └── __init__.py └── text2vec-base-chinese │ ├── config.json │ ├── logs.txt │ ├── modules.json │ ├── sentence_bert_config.json │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.txt ├── modelcache ├── __init__.py ├── adapter │ ├── __init__.py │ ├── adapter.py │ ├── adapter_insert.py │ ├── adapter_query.py │ ├── adapter_register.py │ └── adapter_remove.py ├── config.py ├── config │ ├── chromadb_config.ini │ ├── elasticsearch_config.ini │ ├── milvus_config.ini │ ├── mysql_config.ini │ └── redis_config.ini ├── core.py ├── embedding │ ├── __init__.py │ ├── base.py │ ├── bge_m3.py │ ├── data2vec.py │ ├── fasttext.py │ ├── huggingface.py │ ├── huggingface_tei.py │ ├── llmEmb.py │ ├── onnx.py │ ├── paddlenlp.py │ ├── string_text.py │ └── timm_embedding.py ├── manager │ ├── __init__.py │ ├── data_manager.py │ ├── eviction │ │ ├── __init__.py │ │ ├── base.py │ │ ├── manager.py │ │ └── memory_cache.py │ ├── eviction_manager.py │ ├── factory.py │ ├── object_data │ │ ├── __init__.py │ │ └── base.py │ ├── scalar_data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── manager.py │ │ ├── sql_storage.py │ │ ├── sql_storage_es.py │ │ └── sql_storage_sqlite.py │ └── vector_data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── chroma.py │ │ ├── faiss.py │ │ ├── manager.py │ │ ├── milvus.py │ │ └── redis.py ├── processor │ ├── __init__.py │ ├── post.py │ └── pre.py ├── report.py ├── similarity_evaluation │ ├── __init__.py │ ├── distance.py │ ├── exact_match.py │ └── similarity_evaluation.py └── utils │ ├── __init__.py │ ├── cache_func.py │ ├── dependency_control.py │ ├── env_config.py │ ├── error.py │ ├── index_util.py │ ├── lazy_import.py │ ├── log.py │ ├── model_filter.py │ └── time.py ├── modelcache_mm ├── __init__.py ├── adapter │ ├── __init__.py │ ├── adapter.py │ ├── adapter_insert.py │ ├── adapter_query.py │ ├── adapter_register.py │ └── adapter_remove.py ├── config.py ├── config │ ├── chromadb_config.ini │ ├── elasticsearch_config.ini │ ├── milvus_config.ini │ ├── mysql_config.ini │ └── redis_config.ini ├── core.py ├── embedding │ ├── __init__.py │ ├── base.py │ ├── clip.py │ ├── string.py │ └── timm.py ├── manager │ ├── __init__.py │ ├── data_manager.py │ ├── eviction │ │ ├── __init__.py │ │ └── base.py │ ├── eviction_manager.py │ ├── factory.py │ ├── object_data │ │ ├── __init__.py │ │ └── base.py │ ├── scalar_data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── manager.py │ │ ├── sql_storage.py │ │ ├── sql_storage_es.py │ │ └── 
sql_storage_sqlite.py │ └── vector_data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── chroma.py │ │ ├── faiss.py │ │ ├── manager.py │ │ └── redis.py ├── processor │ ├── __init__.py │ ├── post.py │ └── pre.py ├── report.py ├── similarity_evaluation │ ├── __init__.py │ ├── distance.py │ ├── exact_match.py │ └── similarity_evaluation.py └── utils │ ├── __init__.py │ ├── cache_func.py │ ├── dependency_control.py │ ├── env_config.py │ ├── error.py │ ├── index_util.py │ ├── lazy_import.py │ ├── log.py │ └── time.py ├── mulicache-readme-cn.md ├── reference_doc └── create_table.sql └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | *.DS_Store 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .nox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | *.py,cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | db.sqlite3-journal 60 | *.db 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | __pypackages__/ 86 | 87 | # Celery stuff 88 | celerybeat-schedule 89 | celerybeat.pid 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | 121 | .idea 122 | **/data_map**.txt 123 | **/faiss**.index 124 | **/sqlite**.db 125 | **/**.db 126 | **/example.py 127 | **/example.db 128 | **/.chroma 129 | 130 | /fuhui_dev 131 | *.index 132 | *model.onnx 133 | 134 | /data_analyse 135 | /embedding_npy 136 | /flask_server 137 | *.bin 138 | **/maya_embedding_service 139 | 140 | *.ini 141 | 142 | **/multicache_serving.py 143 | **/modelcache_serving.py 144 | 145 | **/model/ 146 | 147 | /data/milvus/db 148 | /data/mysql/db -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim-bookworm 2 | 3 | WORKDIR /home/user 4 | 5 | COPY ./requirements.txt /home/user/docker_requirements.txt 6 | 7 | RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple 8 | RUN pip install -r /home/user/docker_requirements.txt --retries 5 --timeout 120 9 | 
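# Note: the application source is not baked into this image. docker-compose.yaml mounts
# ./model, ./modelcache, ./modelcache_mm, and fastapi4modelcache.py into /home/user and
# supplies the uvicorn start command, so this Dockerfile only installs dependencies and
# defines no COPY of the source tree and no CMD.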
-------------------------------------------------------------------------------- /data/milvus/embedEtcd.yaml: -------------------------------------------------------------------------------- 1 | listen-client-urls: http://0.0.0.0:2379 2 | advertise-client-urls: http://0.0.0.0:2379 3 | quota-backend-bytes: 4294967296 4 | auto-compaction-mode: revision 5 | auto-compaction-retention: '1000' 6 | -------------------------------------------------------------------------------- /data/milvus/user.yaml: -------------------------------------------------------------------------------- 1 | # Extra config to override default milvus.yaml -------------------------------------------------------------------------------- /data/mysql/init/init.sql: -------------------------------------------------------------------------------- 1 | CREATE DATABASE IF NOT EXISTS `modelcache`; 2 | 3 | USE `modelcache`; 4 | 5 | CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` ( 6 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键', 7 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 8 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 9 | `question` text NOT NULL comment 'question', 10 | `answer` text NOT NULL comment 'answer', 11 | `answer_type` int(11) NOT NULL comment 'answer_type', 12 | `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count', 13 | `model` varchar(1000) NOT NULL comment 'model', 14 | `embedding_data` blob NOT NULL comment 'embedding_data', 15 | `is_deleted` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'delete state(0 Not deleted,-1 deleted)', 16 | PRIMARY KEY(`id`) 17 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'cache_codegpt_answer'; 18 | 19 | CREATE TABLE IF NOT EXISTS `modelcache_query_log` ( 20 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键', 21 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 22 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 23 | `error_code` int(11) NOT NULL comment 'errorCode', 24 | `error_desc` varchar(1000) NOT NULL comment 'errorDesc', 25 | `cache_hit` varchar(100) NOT NULL comment 'cacheHit', 26 | `delta_time` float NOT NULL comment 'delta_time', 27 | `model` varchar(1000) NOT NULL comment 'model', 28 | `query` text NOT NULL comment 'query', 29 | `hit_query` text NOT NULL comment 'hitQuery', 30 | `answer` text NOT NULL comment 'answer', 31 | PRIMARY KEY(`id`) 32 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log'; 33 | -------------------------------------------------------------------------------- /data/mysql/my.conf: -------------------------------------------------------------------------------- 1 | [mysqld] 2 | character-set-server=utf8mb4 3 | collation-server=utf8mb4_unicode_ci 4 | 5 | [client] 6 | default-character-set=utf8mb4 7 | 8 | [mysql] 9 | default-character-set=utf8mb4 -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: 'Beta' 2 | services: 3 | mysql: 4 | image: mysql:8.0.23 5 | container_name: mysql 6 | environment: 7 | MYSQL_ROOT_PASSWORD: 'root' 8 | MYSQL_DATABASE: 'modelcache' 9 | MYSQL_USER: 'modelcache' 10 | MYSQL_PASSWORD: 'modelcache' 11 | ports: 12 | - 3306:3306 13 | volumes: 14 | - ./data/mysql/db:/var/lib/mysql 15 | - ./data/mysql/my.cnf:/etc/mysql/conf.d/my.cnf 16 | - 
./data/mysql/init:/docker-entrypoint-initdb.d 17 | restart: on-failure 18 | networks: 19 | - modelcache 20 | 21 | milvus: 22 | image: milvusdb/milvus:v2.5.0-beta 23 | container_name: milvus 24 | security_opt: 25 | - seccomp:unconfined 26 | environment: 27 | ETCD_USE_EMBED: true 28 | ETCD_DATA_DIR: /var/lib/milvus/etcd 29 | ETCD_CONFIG_PATH: /milvus/configs/embedEtcd.yaml 30 | COMMON_STORAGETYPE: local 31 | volumes: 32 | - ./data/milvus/db:/var/lib/milvus 33 | - ./data/milvus/embedEtcd.yaml:/milvus/configs/embedEtcd.yaml 34 | - ./data/milvus/user.yaml:/milvus/configs/user.yaml 35 | ports: 36 | - 19530:19530 37 | - 9091:9091 38 | - 2379:2379 39 | healthcheck: 40 | test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] 41 | interval: 30s 42 | start_period: 90s 43 | timeout: 20s 44 | retries: 3 45 | networks: 46 | - modelcache 47 | restart: on-failure 48 | command: milvus run standalone 49 | 50 | modelcache: 51 | build: 52 | context: . 53 | dockerfile: Dockerfile 54 | container_name: modelcache 55 | image: modelcache:0.1.0 56 | ports: 57 | - 5000:5000 58 | volumes: 59 | - ./model:/home/user/model 60 | - ./modelcache:/home/user/modelcache 61 | - ./modelcache_mm:/home/user/modelcache_mm 62 | - ./fastapi4modelcache.py:/home/user/fastapi4modelcache.py 63 | networks: 64 | - modelcache 65 | restart: on-failure 66 | command: sh -c "uvicorn fastapi4modelcache:app --reload --reload-dir /home/user --port=5000 --host=0.0.0.0" 67 | 68 | networks: 69 | modelcache: 70 | external: true -------------------------------------------------------------------------------- /docs/1.what-is-model-cache.md: -------------------------------------------------------------------------------- 1 | # What is ModelCache 2 | 3 | In ModelCache, we adopted the main idea of GPTCache, which includes the core modules: adapter, embedding, similarity, and data_manager. The adapter module is responsible for handling the business logic of various tasks and connects the embedding, similarity, and data_manager modules. The embedding module is mainly responsible for converting text into semantic vector representations; it transforms user queries into vector form. The rank module is used for sorting and evaluating the similarity of the recalled vectors. The data_manager module is primarily used for managing the database. To better facilitate industrial applications, we have made the following architectural and functional upgrades: 4 | 5 | ## Architecture 6 | 7 | ![modelcache modules](modelcache_modules_20240409.png) 8 | 9 | ## Function comparison 10 | 11 | We've implemented several key updates to our repository. We've resolved network issues with Hugging Face and improved inference speed by introducing local embedding capabilities. Due to limitations in SQLAlchemy, we've redesigned our relational database interaction module for more flexible operations. We've added multi-tenancy support to ModelCache, recognizing the need for multiple users and models in LLM products. Lastly, we've made initial adjustments for better compatibility with system commands and multi-turn dialogues.
| Module | Function |
| --- | --- |
| Basic Interface | Data query interface |
| | Data writing interface |
| Embedding | Embedding model configuration |
| | Large model embedding layer |
| | BERT model long text processing |
| Large model invocation | Decoupling from large models |
| | Local loading of embedding model |
| Data isolation | Model data isolation |
| | Hyperparameter isolation |
| Databases | MySQL |
| | Milvus |
| | OceanBase |
| Session management | Single-turn dialogue |
| | System commands |
| | Multi-turn dialogue |
| Data management | Data persistence |
| | One-click cache clearance |
| Tenant management | Support for multi-tenancy |
| | Milvus multi-collection capability |
| Other | Long-short dialogue distinction |
133 | -------------------------------------------------------------------------------- /docs/2.model-cache-features.md: -------------------------------------------------------------------------------- 1 | # ModelCache features 2 | 3 | This topic describes ModelCache features. In ModelCache, we incorporated the core principles of GPTCache. ModelCache has four modules: adapter, embedding, similarity, and data_manager. 4 | 5 | - The adapter module orchestrates the business logic for various tasks and integrates the embedding, similarity, and data_manager modules. 6 | - The embedding module converts text into semantic vector representations, and transforms user queries into vectors. 7 | - The rank module ranks and evaluates the similarity of recalled vectors. 8 | - The data_manager module manages the databases. 9 | 10 | To make ModelCache more suitable for industrial use, we made several improvements to its architecture and functionality: 11 | 12 | - [x] Architectural adjustment (lightweight integration): 13 | - Embedded into LLM products using a Redis-like caching mode. 14 | - Provided semantic caching without interfering with LLM calls, security audits, and other functions. 15 | - Compatible with all LLM services. 16 | - [x] Multiple model loading: 17 | - Supported local embedding model loading, and resolved Hugging Face network connectivity issues. 18 | - Supported loading embedding layers from various pre-trained models. 19 | - [x] Data isolation: 20 | - Environment isolation: Read different database configurations based on the environment. Isolate development, staging, and production environments. 21 | - Multi-tenant data isolation: Dynamically create collections based on models for data isolation, addressing data separation issues in multi-model/service scenarios within large language model products. 22 | - [x] Supported system instruction: Adopted a concatenation approach to resolve issues with system instructions in the prompt paradigm. 23 | - [x] Long and short text differentiation: Long texts bring more challenges for similarity assessment. Added differentiation between long and short texts, allowing for separate threshold configurations. 24 | - [x] Milvus performance optimization: Adjusted Milvus consistency level to "Session" level for better performance. 25 | - [x] Data management: 26 | - One-click cache clearing to enable easy data management after model upgrades. 27 | - Recall of hit queries for subsequent data analysis and model iteration reference. 28 | - Asynchronous log write-back for data analysis and statistics. 29 | - Added model field and data statistics field to enhance features. 30 | -------------------------------------------------------------------------------- /docs/3.model-cache-quick-start.md: -------------------------------------------------------------------------------- 1 | # Quick start 2 | 3 | This topic describes how to set up and use ModelCache. 4 | 5 | You can find the start scripts in `flask4modelcache.py` and `flask4modelcache_demo.py`. 6 | 7 | - `flask4modelcache_demo.py`: A quick test service with embedded SQLite and FAISS. No database configuration is required. 8 | - `flask4modelcache.py`: The standard service that requires MySQL and Milvus configuration. 9 | 10 | ## Dependencies 11 | 12 | - Python: V3.8 or above 13 | - Package installation 14 | 15 | ```shell 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Start service 20 | 21 | ### Start demo 22 | 23 | 1.
Download the embedding model bin file from [Hugging Face](https://huggingface.co/shibing624/text2vec-base-chinese/tree/main). Place it in the `model/text2vec-base-chinese` folder. 24 | 2. Start the backend service: 25 | 26 | ```shell 27 | cd CodeFuse-ModelCache 28 | ``` 29 | 30 | ```shell 31 | python flask4modelcache_demo.py 32 | ``` 33 | 34 | ### Start standard service 35 | 36 | Before you start standard service, do these steps: 37 | 38 | 1. Install MySQL and import the SQL file from `reference_doc/create_table.sql`. 39 | 2. Install vector database Milvus. 40 | 3. Configure database access in: 41 | - `modelcache/config/milvus_config.ini` 42 | - `modelcache/config/mysql_config.ini` 43 | 4. Download the embedding model bin file from [Hugging Face](https://huggingface.co/shibing624/text2vec-base-chinese/tree/main). Put it in `model/text2vec-base-chinese`. 44 | 5. Start the backend service: 45 | 46 | ```bash 47 | python flask4modelcache.py 48 | ``` 49 | 50 | ## Visit the service 51 | 52 | The service provides three core RESTful API functionalities: Cache-Writing, Cache-Querying, and Cache-Clearing. 53 | 54 | ### Write cache 55 | 56 | ```python 57 | import json 58 | import requests 59 | url = 'http://127.0.0.1:5000/modelcache' 60 | type = 'insert' 61 | scope = {"model": "CODEGPT-1008"} 62 | chat_info = [{"query": [{"role": "system", "content": "You are an AI code assistant and you must provide neutral and harmless answers to help users solve code-related problems."}, {"role": "user", "content": "你是谁?"}], 63 | "answer": "Hello, I am an intelligent assistant. How can I assist you?"}] 64 | data = {'type': type, 'scope': scope, 'chat_info': chat_info} 65 | headers = {"Content-Type": "application/json"} 66 | res = requests.post(url, headers=headers, json=json.dumps(data)) 67 | ``` 68 | 69 | ### Query cache 70 | 71 | ```python 72 | import json 73 | import requests 74 | url = 'http://127.0.0.1:5000/modelcache' 75 | type = 'query' 76 | scope = {"model": "CODEGPT-1008"} 77 | query = [{"role": "system", "content": "You are an AI code assistant and you must provide neutral and harmless answers to help users solve code-related problems."}, {"role": "user", "content": "Who are you?"}] 78 | data = {'type': type, 'scope': scope, 'query': query} 79 | 80 | headers = {"Content-Type": "application/json"} 81 | res = requests.post(url, headers=headers, json=json.dumps(data)) 82 | ``` 83 | 84 | ### Clear cache 85 | 86 | ```python 87 | import json 88 | import requests 89 | url = 'http://127.0.0.1:5000/modelcache' 90 | type = 'remove' 91 | scope = {"model": "CODEGPT-1008"} 92 | remove_type = 'truncate_by_model' 93 | data = {'type': type, 'scope': scope, 'remove_type': remove_type} 94 | 95 | headers = {"Content-Type": "application/json"} 96 | res = requests.post(url, headers=headers, json=json.dumps(data)) 97 | ``` 98 | -------------------------------------------------------------------------------- /docs/4.create-cache.md: -------------------------------------------------------------------------------- 1 | # Create cache 2 | 3 | This topic describes how to create cache. 4 | 5 | ## Default cache interface 6 | 7 | ```py 8 | class Cache: 9 | # ModelCache calls it whe you start the cache system 10 | def __init__(self): 11 | self.has_init = False 12 | self.cache_enable_func = None 13 | self.embedding_func = None 14 | self.post_process_messages_func = None 15 | self.config = Config() 16 | ``` 17 | 18 | This function embeds text into dense vectors for context similarity search. 
ModelCache supports these embedding methods: Huggingface, ONNX, and SentenceTransformers. The default is the Hugging Face text2vec model because it performs better for Chinese. Simply initialize your embedding function with `data2vec.to_embeddings`, as shown below. 19 | 20 | ```py 21 | data_manager = get_data_manager(CacheBase("mysql", config=mysql_config), 22 | VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config)) 23 | 24 | cache.init( 25 | embedding_func=data2vec.to_embeddings, 26 | data_manager=data_manager, 27 | similarity_evaluation=SearchDistanceEvaluation(), 28 | query_pre_embedding_func=query_multi_splicing, 29 | insert_pre_embedding_func=insert_multi_splicing, 30 | ) 31 | ``` 32 | 33 | The data_manager CacheBase stores all scalar data, such as original questions, prompts, answers, and access times. ModelCache supports multiple cache storages like SQLite, MySQL, and OceanBase. NoSQL databases will be supported in the future. 34 | 35 | The data_manager VectorBase stores and searches all embedding vectors to find semantically similar results. ModelCache supports using vector search libraries like FAISS or vector databases like Milvus. More vector databases and cloud services will be supported in the future. 36 | 37 | ## Examples 38 | 39 | ```py 40 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension)) 41 | data_manager = get_data_manager(CacheBase("oceanbase"), VectorBase("milvus", dimension=data2vec.dimension)) 42 | ``` 43 | -------------------------------------------------------------------------------- /docs/cache-service-cost-time-distribution.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/cache-service-cost-time-distribution.webp -------------------------------------------------------------------------------- /docs/codefuse-LOGO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/codefuse-LOGO.png -------------------------------------------------------------------------------- /docs/modelcache_modules_20231114.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/modelcache_modules_20231114.png -------------------------------------------------------------------------------- /docs/modelcache_modules_20240409.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/modelcache_modules_20240409.png -------------------------------------------------------------------------------- /docs/script/get_input_embedding_script.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | from transformers import AutoModelForCausalLM 5 | 6 | 7 | model_path = '' 8 | device = torch.device('cuda') 9 | model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device) 10 | embedding_weights = model.get_input_embeddings().weight.to('cpu').detach().numpy() 11 | np.save('gpt-neox-embedding.npy', embedding_weights) 12 |
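# Optional sanity check (a sketch; assumes the .npy file saved above is in the working directory):
loaded = np.load('gpt-neox-embedding.npy')
print('embedding matrix shape:', loaded.shape)  # rows = vocab size, cols = hidden size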
-------------------------------------------------------------------------------- /docs/time-cost-comparison.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/time-cost-comparison.webp -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/embedding/huggingface_tei_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | sys.path.append(".") 4 | from modelcache.embedding.huggingface_tei import HuggingfaceTEI 5 | 6 | ''' 7 | run tei server: 8 | text-embeddings-router --model-id BAAI/bge-large-zh-v1.5 --port 8080 9 | ''' 10 | 11 | def run(): 12 | tei_instance = HuggingfaceTEI('http://127.0.0.1:8080/v1/embeddings', 'BAAI/bge-large-zh-v1.5') 13 | print('dimenson', tei_instance.dimension) 14 | print('embedding', tei_instance.to_embeddings('hello')) 15 | 16 | if __name__ == '__main__': 17 | run() -------------------------------------------------------------------------------- /examples/flask/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/data_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | 5 | 6 | def run(): 7 | url = 'http://127.0.0.1:5000/modelcache' 8 | type = 'insert' 9 | scope = {"model": "CODEGPT-1117"} 10 | chat_info = [{"query": [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}], 11 | "answer": "你好,我是智能助手,请问有什么能帮您!"}] 12 | data = {'type': type, 'scope': scope, 'chat_info': chat_info} 13 | headers = {"Content-Type": "application/json"} 14 | res = requests.post(url, headers=headers, json=json.dumps(data)) 15 | res_text = res.text 16 | 17 | print("data_insert:", res.status_code, res_text) 18 | 19 | if __name__ == '__main__': 20 | run() 21 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/data_query.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | 5 | 6 | def run(): 7 | url = 'http://127.0.0.1:5000/modelcache' 8 | type = 'query' 9 | scope = {"model": "CODEGPT-1117"} 10 | query = [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}] 11 | data = {'type': type, 'scope': scope, 'query': query} 12 | 13 | headers = {"Content-Type": "application/json"} 14 | res = requests.post(url, headers=headers, json=json.dumps(data)) 15 | res_text = res.text 16 | 17 | 
print("data_query:", res.status_code, res_text) 18 | 19 | if __name__ == '__main__': 20 | run() 21 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/data_query_long.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | 5 | 6 | def run(): 7 | url = 'http://127.0.0.1:5000/modelcache' 8 | type = 'query' 9 | scope = {"model": "CODEGPT-1109"} 10 | system_conten = """ 11 | """ 12 | user_content = """ 13 | """ 14 | 15 | query = [{"role": "system", "content": system_conten}, {"role": "user", "content": user_content}] 16 | data = {'type': type, 'scope': scope, 'query': query} 17 | 18 | headers = {"Content-Type": "application/json"} 19 | res = requests.post(url, headers=headers, json=json.dumps(data)) 20 | res_text = res.text 21 | 22 | print("data_query_long:", res.status_code, res_text) 23 | 24 | if __name__ == '__main__': 25 | run() 26 | -------------------------------------------------------------------------------- /examples/flask/llms_cache/register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | register index for redis 4 | """ 5 | import json 6 | import requests 7 | 8 | 9 | def run(): 10 | url = 'http://127.0.0.1:5000/modelcache' 11 | type = 'register' 12 | scope = {"model": "CODEGPT-1117"} 13 | data = {'type': type, 'scope': scope} 14 | headers = {"Content-Type": "application/json"} 15 | res = requests.post(url, headers=headers, json=json.dumps(data)) 16 | res_text = res.text 17 | 18 | 19 | if __name__ == '__main__': 20 | run() -------------------------------------------------------------------------------- /examples/flask/multi_cache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/data_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import json 4 | import uuid 5 | import requests 6 | 7 | 8 | def run(): 9 | url = 'http://127.0.0.1:5000/multicache' 10 | 11 | request_type = 'insert' 12 | scope = {"model": "multimodal_test"} 13 | # UUID = "820b0052-d9d8-11ee-95f1-52775e3e6fd1" + "==>" + str(time.time()) 14 | UUID = str(uuid.uuid1()) + "==>" + str(time.time()) 15 | img_data = "https://img0.baidu.com/it/u=1436460262,4166266890&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=282" 16 | query = {'text': ['父母带着孩子来这个地方可能会有什么顾虑'], 17 | 'imageRaw': '', 18 | 'imageUrl': img_data, 19 | 'imageId': 'ccc'} 20 | answer = "应该注意小孩不要跑到铁轨上" 21 | chat_info = [{"query": query, "answer": answer}] 22 | data_dict = {'request_type': request_type, 'scope': scope, 'chat_info': chat_info, 'UUID': UUID} 23 | 24 | headers = {"Content-Type": "application/json"} 25 | res = requests.post(url, headers=headers, json=json.dumps(data_dict)) 26 | res_text = res.text 27 | print('res_text: {}'.format(res_text)) 28 | 29 | 30 | if __name__ == '__main__': 31 | run() 32 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/data_query.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import requests 4 | import uuid 5 | import time 6 | 7 | 8 | def run(): 9 | url = 'http://127.0.0.1:5000/multicache' 10 | 
request_type = 'query' 11 | UUID = str(uuid.uuid1()) + "==>" + str(time.time()) 12 | scope = {"model": "multimodal_test"} 13 | img_data = "https://img0.baidu.com/it/u=1436460262,4166266890&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=282" 14 | query = {'text': ['父母带着孩子来这个地方可能会有什么顾虑'], 15 | 'imageRaw': '', 16 | 'imageUrl': img_data, 17 | 'multiType': 'IMG_TEXT'} 18 | 19 | data = {'request_type': request_type, 'scope': scope, 'query': query, 'UUID': UUID} 20 | 21 | headers = {"Content-Type": "application/json"} 22 | res = requests.post(url, headers=headers, json=json.dumps(data)) 23 | res_text = res.text 24 | print('res_text: {}'.format(res_text)) 25 | 26 | 27 | if __name__ == '__main__': 28 | run() 29 | -------------------------------------------------------------------------------- /examples/flask/multi_cache/register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | register index for redis 4 | """ 5 | import json 6 | import requests 7 | 8 | 9 | def run(): 10 | url = 'http://127.0.0.1:5000/multicache' 11 | request_type = 'register' 12 | scope = {"model": "multimodal_test"} 13 | type = 'IMG_TEXT' 14 | data = {'request_type': request_type, 'scope': scope, 'type': type} 15 | headers = {"Content-Type": "application/json"} 16 | res = requests.post(url, headers=headers, json=json.dumps(data)) 17 | res_text = res.text 18 | print('res_text: {}'.format(res_text)) 19 | 20 | 21 | if __name__ == '__main__': 22 | run() -------------------------------------------------------------------------------- /examples/flask/multi_cache/remove.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | register index for redis 4 | """ 5 | import json 6 | import requests 7 | 8 | 9 | def run(): 10 | url = 'http://127.0.0.1:5000/multicache' 11 | request_type = 'remove' 12 | scope = {"model": "multimodal_test"} 13 | remove_type = 'truncate_by_model' 14 | data = {'request_type': request_type, 'scope': scope, 'remove_type': remove_type} 15 | 16 | headers = {"Content-Type": "application/json"} 17 | res = requests.post(url, headers=headers, json=json.dumps(data)) 18 | res_text = res.text 19 | print('res_text: {}'.format(res_text)) 20 | 21 | 22 | if __name__ == '__main__': 23 | run() 24 | -------------------------------------------------------------------------------- /fastapi4modelcache_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | import uvicorn 4 | import asyncio 5 | import logging 6 | # import configparser 7 | import json 8 | from fastapi import FastAPI, Request, HTTPException 9 | from pydantic import BaseModel 10 | from concurrent.futures import ThreadPoolExecutor 11 | from starlette.responses import PlainTextResponse 12 | import functools 13 | 14 | from modelcache import cache 15 | from modelcache.adapter import adapter 16 | from modelcache.manager import CacheBase, VectorBase, get_data_manager 17 | from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation 18 | from modelcache.processor.pre import query_multi_splicing 19 | from modelcache.processor.pre import insert_multi_splicing 20 | from modelcache.utils.model_filter import model_blacklist_filter 21 | from modelcache.embedding import Data2VecAudio 22 | 23 | # 创建一个FastAPI实例 24 | app = FastAPI() 25 | 26 | class RequestData(BaseModel): 27 | type: str 28 | scope: dict = None 29 | query: str = None 30 | chat_info: list = 
None 31 | remove_type: str = None 32 | id_list: list = [] 33 | 34 | data2vec = Data2VecAudio() 35 | 36 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension)) 37 | 38 | cache.init( 39 | embedding_func=data2vec.to_embeddings, 40 | data_manager=data_manager, 41 | similarity_evaluation=SearchDistanceEvaluation(), 42 | query_pre_embedding_func=query_multi_splicing, 43 | insert_pre_embedding_func=insert_multi_splicing, 44 | ) 45 | 46 | executor = ThreadPoolExecutor(max_workers=6) 47 | 48 | # 异步保存查询信息 49 | async def save_query_info_fastapi(result, model, query, delta_time_log): 50 | loop = asyncio.get_running_loop() 51 | func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log) 52 | await loop.run_in_executor(None, func) 53 | 54 | 55 | 56 | @app.get("/welcome", response_class=PlainTextResponse) 57 | async def first_fastapi(): 58 | return "hello, modelcache!" 59 | 60 | @app.post("/modelcache") 61 | async def user_backend(request: Request): 62 | try: 63 | raw_body = await request.body() 64 | # 解析字符串为JSON对象 65 | if isinstance(raw_body, bytes): 66 | raw_body = raw_body.decode("utf-8") 67 | if isinstance(raw_body, str): 68 | try: 69 | # 尝试将字符串解析为JSON对象 70 | request_data = json.loads(raw_body) 71 | except json.JSONDecodeError: 72 | # 如果无法解析,返回格式错误 73 | raise HTTPException(status_code=400, detail="Invalid JSON format") 74 | else: 75 | request_data = raw_body 76 | 77 | # 确保request_data是字典对象 78 | if isinstance(request_data, str): 79 | try: 80 | request_data = json.loads(request_data) 81 | except json.JSONDecodeError: 82 | raise HTTPException(status_code=400, detail="Invalid JSON format") 83 | 84 | request_type = request_data.get('type') 85 | model = None 86 | if 'scope' in request_data: 87 | model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_') 88 | query = request_data.get('query') 89 | chat_info = request_data.get('chat_info') 90 | 91 | if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']: 92 | raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']") 93 | 94 | except Exception as e: 95 | request_data = raw_body if 'raw_body' in locals() else None 96 | result = { 97 | "errorCode": 103, 98 | "errorDesc": str(e), 99 | "cacheHit": False, 100 | "delta_time": 0, 101 | "hit_query": '', 102 | "answer": '', 103 | "para_dict": request_data 104 | } 105 | return result 106 | 107 | 108 | # model filter 109 | filter_resp = model_blacklist_filter(model, request_type) 110 | if isinstance(filter_resp, dict): 111 | return filter_resp 112 | 113 | if request_type == 'query': 114 | try: 115 | start_time = time.time() 116 | response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query) 117 | delta_time = f"{round(time.time() - start_time, 2)}s" 118 | 119 | if response is None: 120 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''} 121 | elif response in ['adapt_query_exception']: 122 | # elif isinstance(response, str): 123 | result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time, 124 | "hit_query": '', "answer": ''} 125 | else: 126 | answer = response['data'] 127 | hit_query = response['hitQuery'] 128 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer} 129 
| 130 | delta_time_log = round(time.time() - start_time, 2) 131 | asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log)) 132 | return result 133 | except Exception as e: 134 | result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, 135 | "hit_query": '', "answer": ''} 136 | logging.info(f'result: {str(result)}') 137 | return result 138 | 139 | if request_type == 'insert': 140 | try: 141 | response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info) 142 | if response == 'success': 143 | return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"} 144 | else: 145 | return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"} 146 | except Exception as e: 147 | return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"} 148 | 149 | if request_type == 'remove': 150 | response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list")) 151 | if not isinstance(response, dict): 152 | return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"} 153 | 154 | state = response.get('status') 155 | if state == 'success': 156 | return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"} 157 | else: 158 | return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"} 159 | 160 | # TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动 161 | if __name__ == '__main__': 162 | uvicorn.run(app, host='0.0.0.0', port=5000) -------------------------------------------------------------------------------- /flask4modelcache_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | from flask import Flask, request 4 | import logging 5 | import json 6 | from modelcache import cache 7 | from modelcache.adapter import adapter 8 | from modelcache.manager import CacheBase, VectorBase, get_data_manager 9 | from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation 10 | from modelcache.processor.pre import query_multi_splicing 11 | from modelcache.processor.pre import insert_multi_splicing 12 | from concurrent.futures import ThreadPoolExecutor 13 | from modelcache.utils.model_filter import model_blacklist_filter 14 | from modelcache.embedding import Data2VecAudio 15 | 16 | # 创建一个Flask实例 17 | app = Flask(__name__) 18 | 19 | 20 | def response_text(cache_resp): 21 | return cache_resp['data'] 22 | 23 | 24 | def save_query_info(result, model, query, delta_time_log): 25 | cache.data_manager.save_query_resp(result, model=model, query=json.dumps(query, ensure_ascii=False), 26 | delta_time=delta_time_log) 27 | 28 | 29 | def response_hitquery(cache_resp): 30 | return cache_resp['hitQuery'] 31 | 32 | 33 | data2vec = Data2VecAudio() 34 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension)) 35 | 36 | 37 | cache.init( 38 | embedding_func=data2vec.to_embeddings, 39 | data_manager=data_manager, 40 | similarity_evaluation=SearchDistanceEvaluation(), 41 | query_pre_embedding_func=query_multi_splicing, 42 | insert_pre_embedding_func=insert_multi_splicing, 43 | ) 44 | 45 | # cache.set_openai_key() 46 | global executor 47 | executor = ThreadPoolExecutor(max_workers=6) 48 | 49 | 50 | @app.route('/welcome') 51 | def first_flask(): # 视图函数 52 | return 'hello, modelcache!' 
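# The /modelcache endpoint below dispatches on the request's 'type' field ('query', 'insert', 'remove'),
# normalizes the model name by replacing '-' and '.' with '_', filters blacklisted models via
# model_blacklist_filter, and writes query logs asynchronously through the thread-pool executor.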
53 | 54 | 55 | @app.route('/modelcache', methods=['GET', 'POST']) 56 | def user_backend(): 57 | try: 58 | if request.method == 'POST': 59 | request_data = request.json 60 | elif request.method == 'GET': 61 | request_data = request.args 62 | param_dict = json.loads(request_data) 63 | except Exception as e: 64 | result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', 65 | "answer": ''} 66 | cache.data_manager.save_query_resp(result, model='', query='', delta_time=0) 67 | return json.dumps(result) 68 | 69 | # param parsing 70 | try: 71 | request_type = param_dict.get("type") 72 | scope = param_dict.get("scope") 73 | if scope is not None: 74 | model = scope.get('model') 75 | model = model.replace('-', '_') 76 | model = model.replace('.', '_') 77 | query = param_dict.get("query") 78 | chat_info = param_dict.get("chat_info") 79 | if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']: 80 | result = {"errorCode": 102, 81 | "errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']", 82 | "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 83 | cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0) 84 | return json.dumps(result) 85 | except Exception as e: 86 | result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', 87 | "answer": ''} 88 | return json.dumps(result) 89 | 90 | # model filter 91 | filter_resp = model_blacklist_filter(model, request_type) 92 | if isinstance(filter_resp, dict): 93 | return json.dumps(filter_resp) 94 | 95 | if request_type == 'query': 96 | try: 97 | start_time = time.time() 98 | response = adapter.ChatCompletion.create_query( 99 | scope={"model": model}, 100 | query=query 101 | ) 102 | delta_time = '{}s'.format(round(time.time() - start_time, 2)) 103 | if response is None: 104 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', 105 | "answer": ''} 106 | elif response in ['adapt_query_exception']: 107 | result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time, 108 | "hit_query": '', "answer": ''} 109 | else: 110 | answer = response_text(response) 111 | hit_query = response_hitquery(response) 112 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, 113 | "hit_query": hit_query, "answer": answer} 114 | delta_time_log = round(time.time() - start_time, 2) 115 | future = executor.submit(save_query_info, result, model, query, delta_time_log) 116 | except Exception as e: 117 | result = {"errorCode": 202, "errorDesc": e, "cacheHit": False, "delta_time": 0, 118 | "hit_query": '', "answer": ''} 119 | logging.info('result: {}'.format(result)) 120 | return json.dumps(result, ensure_ascii=False) 121 | 122 | if request_type == 'insert': 123 | try: 124 | try: 125 | response = adapter.ChatCompletion.create_insert( 126 | model=model, 127 | chat_info=chat_info 128 | ) 129 | except Exception as e: 130 | result = {"errorCode": 303, "errorDesc": e, "writeStatus": "exception"} 131 | return json.dumps(result, ensure_ascii=False) 132 | 133 | if response in ['adapt_insert_exception']: 134 | result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"} 135 | elif response == 'success': 136 | result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"} 137 | else: 138 | result = {"errorCode": 302, "errorDesc": response, 139 | "writeStatus": "exception"} 140 | 
return json.dumps(result, ensure_ascii=False) 141 | except Exception as e: 142 | result = {"errorCode": 304, "errorDesc": e, "writeStatus": "exception"} 143 | return json.dumps(result, ensure_ascii=False) 144 | 145 | if request_type == 'remove': 146 | remove_type = param_dict.get("remove_type") 147 | id_list = param_dict.get("id_list", []) 148 | 149 | response = adapter.ChatCompletion.create_remove( 150 | model=model, 151 | remove_type=remove_type, 152 | id_list=id_list 153 | ) 154 | 155 | if not isinstance(response, dict): 156 | result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"} 157 | return json.dumps(result) 158 | 159 | state = response.get('status') 160 | if state == 'success': 161 | result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"} 162 | else: 163 | result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"} 164 | return json.dumps(result) 165 | 166 | 167 | if __name__ == '__main__': 168 | # app.run(host='0.0.0.0', port=5000, debug=True) 169 | app.run(host='0.0.0.0', port=5000) 170 | -------------------------------------------------------------------------------- /model/clip_zh/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "hfl/chinese-macbert-base", 3 | "architectures": [ 4 | "BertModel" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "classifier_dropout": null, 8 | "directionality": "bidi", 9 | "gradient_checkpointing": false, 10 | "hidden_act": "gelu", 11 | "hidden_dropout_prob": 0.1, 12 | "hidden_size": 768, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-12, 16 | "max_position_embeddings": 512, 17 | "model_type": "bert", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 0, 21 | "pooler_fc_size": 768, 22 | "pooler_num_attention_heads": 12, 23 | "pooler_num_fc_layers": 3, 24 | "pooler_size_per_head": 128, 25 | "pooler_type": "first_token_transform", 26 | "position_embedding_type": "absolute", 27 | "torch_dtype": "float32", 28 | "transformers_version": "4.12.3", 29 | "type_vocab_size": 2, 30 | "use_cache": true, 31 | "vocab_size": 21128 32 | } 33 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/logs.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/modules.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "idx": 0, 4 | "name": "0", 5 | "path": "", 6 | "type": "sentence_transformers.models.Transformer" 7 | }, 8 | { 9 | "idx": 1, 10 | "name": "1", 11 | "path": "1_Pooling", 12 | "type": "sentence_transformers.models.Pooling" 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /model/text2vec-base-chinese/sentence_bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_seq_length": 128, 3 | "do_lower_case": false 4 | } 5 | -------------------------------------------------------------------------------- 
/model/text2vec-base-chinese/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"} -------------------------------------------------------------------------------- /model/text2vec-base-chinese/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "name_or_path": "hfl/chinese-macbert-base", "tokenizer_class": "BertTokenizer"} 2 | -------------------------------------------------------------------------------- /modelcache/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.core import Cache 3 | from modelcache.core import cache 4 | from modelcache.config import Config 5 | import modelcache.adapter 6 | -------------------------------------------------------------------------------- /modelcache/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | from modelcache.adapter.adapter_query import adapt_query 4 | from modelcache.adapter.adapter_insert import adapt_insert 5 | from modelcache.adapter.adapter_remove import adapt_remove 6 | from modelcache.adapter.adapter_register import adapt_register 7 | 8 | 9 | class ChatCompletion(object): 10 | """Openai ChatCompletion Wrapper""" 11 | 12 | @classmethod 13 | def create_query(cls, *args, **kwargs): 14 | def cache_data_convert(cache_data, cache_query): 15 | return construct_resp_from_cache(cache_data, cache_query) 16 | try: 17 | return adapt_query( 18 | cache_data_convert, 19 | *args, 20 | **kwargs 21 | ) 22 | except Exception as e: 23 | return str(e) 24 | 25 | @classmethod 26 | def create_insert(cls, *args, **kwargs): 27 | try: 28 | return adapt_insert( 29 | *args, 30 | **kwargs 31 | ) 32 | except Exception as e: 33 | return str(e) 34 | 35 | @classmethod 36 | def create_remove(cls, *args, **kwargs): 37 | try: 38 | return adapt_remove( 39 | *args, 40 | **kwargs 41 | ) 42 | except Exception as e: 43 | logging.info('adapt_remove_e: {}'.format(e)) 44 | return str(e) 45 | 46 | @classmethod 47 | def create_register(cls, *args, **kwargs): 48 | try: 49 | return adapt_register( 50 | *args, 51 | **kwargs 52 | ) 53 | except Exception as e: 54 | return str(e) 55 | 56 | 57 | def construct_resp_from_cache(return_message, return_query): 58 | return { 59 | "modelcache": True, 60 | "hitQuery": return_query, 61 | "data": return_message, 62 | "errorCode": 0 63 | } 64 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache import cache 3 | from modelcache.utils.error import NotInitError 4 | from modelcache.utils.time import time_cal 5 | 6 | 7 | def adapt_insert(*args, **kwargs): 8 | chat_cache = kwargs.pop("cache_obj", 
cache) 9 | model = kwargs.pop("model", None) 10 | require_object_store = kwargs.pop("require_object_store", False) 11 | if require_object_store: 12 | assert chat_cache.data_manager.o, "Object store is required for adapter." 13 | if not chat_cache.has_init: 14 | raise NotInitError() 15 | cache_enable = chat_cache.cache_enable_func(*args, **kwargs) 16 | context = kwargs.pop("cache_context", {}) 17 | embedding_data = None 18 | pre_embedding_data = chat_cache.insert_pre_embedding_func( 19 | kwargs, 20 | extra_param=context.get("pre_embedding_func", None), 21 | prompts=chat_cache.config.prompts, 22 | ) 23 | chat_info = kwargs.pop("chat_info", []) 24 | llm_data = chat_info[-1]['answer'] 25 | 26 | if cache_enable: 27 | embedding_data = time_cal( 28 | chat_cache.embedding_func, 29 | func_name="embedding", 30 | report_func=chat_cache.report.embedding, 31 | )(pre_embedding_data) 32 | 33 | chat_cache.data_manager.save( 34 | pre_embedding_data, 35 | llm_data, 36 | embedding_data, 37 | model=model, 38 | extra_param=context.get("save_func", None) 39 | ) 40 | return 'success' 41 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter_register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache import cache 3 | 4 | 5 | def adapt_register(*args, **kwargs): 6 | chat_cache = kwargs.pop("cache_obj", cache) 7 | model = kwargs.pop("model", None) 8 | if model is None or len(model) == 0: 9 | return ValueError('') 10 | 11 | register_resp = chat_cache.data_manager.create_index(model) 12 | return register_resp 13 | -------------------------------------------------------------------------------- /modelcache/adapter/adapter_remove.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache import cache 3 | from modelcache.utils.error import NotInitError, RemoveError 4 | 5 | 6 | def adapt_remove(*args, **kwargs): 7 | chat_cache = kwargs.pop("cache_obj", cache) 8 | model = kwargs.pop("model", None) 9 | remove_type = kwargs.pop("remove_type", None) 10 | require_object_store = kwargs.pop("require_object_store", False) 11 | if require_object_store: 12 | assert chat_cache.data_manager.o, "Object store is required for adapter." 
13 | if not chat_cache.has_init: 14 | raise NotInitError() 15 | 16 | # delete data 17 | if remove_type == 'delete_by_id': 18 | id_list = kwargs.pop("id_list", []) 19 | resp = chat_cache.data_manager.delete(id_list, model=model) 20 | elif remove_type == 'truncate_by_model': 21 | resp = chat_cache.data_manager.truncate(model) 22 | else: 23 | # resp = "remove_type_error" 24 | raise RemoveError() 25 | return resp 26 | 27 | -------------------------------------------------------------------------------- /modelcache/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Optional, Callable, List 3 | from modelcache.utils.error import CacheError 4 | 5 | 6 | class Config: 7 | 8 | def __init__( 9 | self, 10 | log_time_func: Optional[Callable[[str, float], None]] = None, 11 | similarity_threshold: float = 0.95, 12 | similarity_threshold_long: float = 0.95, 13 | prompts: Optional[List[str]] = None 14 | ): 15 | if similarity_threshold < 0 or similarity_threshold > 1: 16 | raise CacheError( 17 | "Invalid the similarity threshold param, reasonable range: 0-1" 18 | ) 19 | self.log_time_func = log_time_func 20 | self.similarity_threshold = similarity_threshold 21 | self.similarity_threshold_long = similarity_threshold_long 22 | self.prompts = prompts 23 | -------------------------------------------------------------------------------- /modelcache/config/chromadb_config.ini: -------------------------------------------------------------------------------- 1 | [chromadb] 2 | persist_directory='' 3 | -------------------------------------------------------------------------------- /modelcache/config/elasticsearch_config.ini: -------------------------------------------------------------------------------- 1 | [elasticsearch] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache/config/milvus_config.ini: -------------------------------------------------------------------------------- 1 | [milvus] 2 | host = milvus 3 | port = 19530 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache/config/mysql_config.ini: -------------------------------------------------------------------------------- 1 | [mysql] 2 | host = mysql 3 | port = 3306 4 | username = modelcache 5 | password = modelcache 6 | database = modelcache 7 | -------------------------------------------------------------------------------- /modelcache/config/redis_config.ini: -------------------------------------------------------------------------------- 1 | [redis] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' 6 | -------------------------------------------------------------------------------- /modelcache/core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import atexit 3 | from typing import Optional, List, Any 4 | from modelcache.processor.post import first 5 | from modelcache.similarity_evaluation import ExactMatchEvaluation 6 | from modelcache.similarity_evaluation import SimilarityEvaluation 7 | from modelcache.embedding.string_text import to_embeddings as string_embedding 8 | from modelcache.report import Report 9 | from modelcache.config import Config 10 | from modelcache.utils.cache_func import cache_all 11 | from modelcache.utils.log import modelcache_log 12 | from modelcache.manager import get_data_manager 13 | 
from modelcache.manager.data_manager import DataManager 14 | 15 | 16 | class Cache: 17 | def __init__(self): 18 | self.has_init = False 19 | self.cache_enable_func = None 20 | self.query_pre_embedding_func = None 21 | self.insert_pre_embedding_func = None 22 | self.mm_query_pre_embedding_func = None 23 | self.mm_insert_pre_embedding_func = None 24 | self.embedding_func = None 25 | self.embedding_concurrent_func = None 26 | self.data_manager: Optional[DataManager] = None 27 | self.similarity_evaluation: Optional[SimilarityEvaluation] = None 28 | self.post_process_messages_func = None 29 | self.config = Config() 30 | self.report = Report() 31 | self.next_cache = None 32 | 33 | def init( 34 | self, 35 | cache_enable_func=cache_all, 36 | query_pre_embedding_func=None, 37 | insert_pre_embedding_func=None, 38 | embedding_func=string_embedding, 39 | data_manager: DataManager = get_data_manager(), 40 | similarity_evaluation=ExactMatchEvaluation(), 41 | post_process_messages_func=first, 42 | config=Config(), 43 | next_cache=None, 44 | ): 45 | self.has_init = True 46 | self.cache_enable_func = cache_enable_func 47 | self.query_pre_embedding_func = query_pre_embedding_func 48 | self.insert_pre_embedding_func = insert_pre_embedding_func 49 | self.embedding_func = embedding_func 50 | self.data_manager: DataManager = data_manager 51 | self.similarity_evaluation = similarity_evaluation 52 | self.post_process_messages_func = post_process_messages_func 53 | self.config = config 54 | self.next_cache = next_cache 55 | 56 | @atexit.register 57 | def close(): 58 | try: 59 | self.data_manager.close() 60 | except Exception as e: 61 | modelcache_log.error(e) 62 | 63 | def import_data(self, questions: List[Any], answers: List[Any]) -> None: 64 | self.data_manager.import_data( 65 | questions=questions, 66 | answers=answers, 67 | embedding_datas=[self.embedding_func(question) for question in questions], 68 | ) 69 | 70 | def flush(self): 71 | self.data_manager.flush() 72 | if self.next_cache: 73 | self.next_cache.data_manager.flush() 74 | 75 | 76 | cache = Cache() 77 | -------------------------------------------------------------------------------- /modelcache/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | huggingface = LazyImport("huggingface", globals(), "modelcache.embedding.huggingface") 4 | data2vec = LazyImport("data2vec", globals(), "modelcache.embedding.data2vec") 5 | llmEmb = LazyImport("llmEmb", globals(), "modelcache.embedding.llmEmb") 6 | fasttext = LazyImport("fasttext", globals(), "modelcache.embedding.fasttext") 7 | paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp") 8 | timm = LazyImport("timm", globals(), "modelcache.embedding.timm") 9 | huggingface_tei = LazyImport("huggingface_tei", globals(), "modelcache.embedding.huggingface_tei") 10 | bge_m3 = LazyImport("bge_m3", globals(), "modelcache.embedding.bge_m3") 11 | 12 | 13 | def Huggingface(model="sentence-transformers/all-mpnet-base-v2"): 14 | return huggingface.Huggingface(model) 15 | 16 | 17 | def Data2VecAudio(model="model/text2vec-base-chinese/"): 18 | return data2vec.Data2VecAudio(model) 19 | 20 | 21 | def LlmEmb2vecAudio(): 22 | return llmEmb.LlmEmb2Vec() 23 | 24 | 25 | def FastText(model="en", dim=None): 26 | return fasttext.FastText(model, dim) 27 | 28 | 29 | def PaddleNLP(model="ernie-3.0-medium-zh"): 30 | return paddlenlp.PaddleNLP(model) 31 | 32 | 33 | def 
Timm(model="resnet50", device="default"): 34 | return timm.Timm(model, device) 35 | 36 | def HuggingfaceTEI(base_url, model): 37 | return huggingface_tei.HuggingfaceTEI(base_url, model) 38 | 39 | def BgeM3Embedding(model_path="model/bge-m3"): 40 | return bge_m3.BgeM3Embedding(model_path) -------------------------------------------------------------------------------- /modelcache/embedding/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseEmbedding(metaclass=ABCMeta): 6 | """ 7 | _Embedding base. 8 | """ 9 | 10 | @abstractmethod 11 | def to_embeddings(self, data, **kwargs): 12 | pass 13 | 14 | @property 15 | @abstractmethod 16 | def dimension(self) -> int: 17 | return 0 18 | -------------------------------------------------------------------------------- /modelcache/embedding/bge_m3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from modelcache.embedding.base import BaseEmbedding 4 | from transformers import AutoTokenizer, AutoModel 5 | from FlagEmbedding import BGEM3FlagModel 6 | 7 | class BgeM3Embedding(BaseEmbedding): 8 | def __init__(self, model_path: str = "model/bge-m3"): 9 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 10 | self.model = AutoModel.from_pretrained(model_path) 11 | 12 | self.bge_model = BGEM3FlagModel(model_name_or_path=model_path, 13 | model=self.model, 14 | tokenizer=self.tokenizer, 15 | use_fp16=False) 16 | 17 | self.__dimension = 768 18 | 19 | def to_embeddings(self, data, **_): 20 | if not isinstance(data, list): 21 | data = [data] 22 | 23 | embeddings = self.bge_model.encode(data, batch_size=12, max_length=8192)['dense_vecs'] 24 | return np.array(embeddings).astype("float32") 25 | 26 | @property 27 | def dimension(self): 28 | return self.__dimension -------------------------------------------------------------------------------- /modelcache/embedding/data2vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import numpy as np 5 | import torch 6 | from transformers import BertTokenizer, BertModel 7 | from modelcache.embedding.base import BaseEmbedding 8 | 9 | 10 | def mean_pooling(model_output, attention_mask): 11 | token_embeddings = model_output[0] 12 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 13 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 14 | 15 | 16 | class Data2VecAudio(BaseEmbedding): 17 | def __init__(self, model): 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) 19 | parent_dir = os.path.dirname(current_dir) 20 | model_dir = os.path.dirname(parent_dir) 21 | model_path = os.path.join(model_dir, model) 22 | 23 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 24 | self.tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True) 25 | self.model = BertModel.from_pretrained(model_path, local_files_only=True) 26 | 27 | try: 28 | self.__dimension = self.model.config.hidden_size 29 | except Exception: 30 | from transformers import AutoConfig 31 | config = AutoConfig.from_pretrained(model) 32 | self.__dimension = config.hidden_size 33 | 34 | def to_embeddings(self, data, **_): 35 | encoded_input = self.tokenizer(data, padding=True, truncation=True, return_tensors='pt') 36 
| num_tokens = sum(map(len, encoded_input['input_ids'])) 37 | 38 | if num_tokens <= 512: 39 | with torch.no_grad(): 40 | encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()} 41 | model_output = self.model(**encoded_input) 42 | sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) 43 | sentence_embeddings = sentence_embeddings.squeeze(0).detach().cpu().numpy() 44 | embedding_array = np.array(sentence_embeddings).astype("float32") 45 | return embedding_array 46 | else: 47 | window_size = 510 48 | start = 0 49 | input_ids = encoded_input['input_ids'] 50 | input_ids = input_ids[:, 1:-1] 51 | start_token = self.tokenizer.cls_token 52 | end_token = self.tokenizer.sep_token 53 | start_token_id = self.tokenizer.convert_tokens_to_ids(start_token) 54 | end_token_id = self.tokenizer.convert_tokens_to_ids(end_token) 55 | begin_element = torch.tensor([[start_token_id]]) 56 | end_element = torch.tensor([[end_token_id]]) 57 | 58 | embedding_array_list = list() 59 | while start < num_tokens: 60 | # Calculate the ending position of the sliding window. 61 | end = start + window_size 62 | # If the ending position exceeds the length, adjust it to the length. 63 | if end > num_tokens: 64 | end = num_tokens 65 | # Retrieve the data within the sliding window. 66 | input_ids_window = input_ids[:, start:end] 67 | # Insert a new element at position 0. 68 | input_ids_window = torch.cat([begin_element, input_ids_window[:, 0:]], dim=1) 69 | # Insert a new element at the last position. 70 | input_ids_window = torch.cat([input_ids_window, end_element], dim=1) 71 | input_ids_window_length = sum(map(len, input_ids_window)) 72 | token_type_ids = torch.tensor([[0] * input_ids_window_length]) 73 | attention_mask = torch.tensor([[1] * input_ids_window_length]) 74 | 75 | # Concatenate new input_ids 76 | encoded_input_window = {'input_ids': input_ids_window, 'token_type_ids': token_type_ids, 77 | 'attention_mask': attention_mask} 78 | with torch.no_grad(): 79 | encoded_input_window = {k: v.to(self.device) for k, v in encoded_input_window.items()} 80 | model_output_window = self.model(**encoded_input_window) 81 | 82 | sentence_embeddings_window = mean_pooling(model_output_window, encoded_input_window['attention_mask']) 83 | sentence_embeddings_window = sentence_embeddings_window.squeeze(0).detach().cpu().numpy() 84 | embedding_array_window = np.array(sentence_embeddings_window).astype("float32") 85 | embedding_array_list.append(embedding_array_window) 86 | start = end 87 | 88 | embedding_array = np.mean(embedding_array_list, axis=0) 89 | return embedding_array 90 | 91 | def post_proc(self, token_embeddings, inputs): 92 | attention_mask = inputs["attention_mask"] 93 | input_mask_expanded = ( 94 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 95 | ) 96 | sentence_embs = torch.sum( 97 | token_embeddings * input_mask_expanded, 1 98 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 99 | return sentence_embs 100 | 101 | @property 102 | def dimension(self): 103 | """Embedding dimension. 
104 | 105 | :return: embedding dimension 106 | """ 107 | return self.__dimension 108 | -------------------------------------------------------------------------------- /modelcache/embedding/fasttext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | from modelcache.utils import import_fasttext 5 | from modelcache.embedding.base import BaseEmbedding 6 | import_fasttext() 7 | import fasttext.util 8 | 9 | 10 | class FastText(BaseEmbedding): 11 | def __init__(self, model: str = "en", dim: int = None): 12 | self.model_path = os.path.abspath(fasttext.util.download_model(model)) 13 | self.ft = fasttext.load_model(self.model_path) 14 | 15 | if dim: 16 | fasttext.util.reduce_model(self.ft, dim) 17 | self.__dimension = self.ft.get_dimension() 18 | 19 | def to_embeddings(self, data, **_): 20 | assert isinstance(data, str), "Only allow string as input." 21 | emb = self.ft.get_sentence_vector(data) 22 | return np.array(emb).astype("float32") 23 | 24 | @property 25 | def dimension(self): 26 | return self.__dimension 27 | 28 | -------------------------------------------------------------------------------- /modelcache/embedding/huggingface.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.utils import import_huggingface, import_torch 5 | from modelcache.embedding.base import BaseEmbedding 6 | 7 | import_torch() 8 | import_huggingface() 9 | 10 | import torch # pylint: disable=C0413 11 | from transformers import AutoTokenizer, AutoModel # pylint: disable=C0413 12 | 13 | 14 | class Huggingface(BaseEmbedding): 15 | def __init__(self, model: str = "sentence-transformers/all-MiniLM-L6-v2"): 16 | self.model = AutoModel.from_pretrained(model, local_files_only=True) 17 | self.model.eval() 18 | 19 | # self.tokenizer = AutoTokenizer.from_pretrained(model) 20 | self.tokenizer = AutoTokenizer.from_pretrained(model, local_files_only=True) 21 | if not self.tokenizer.pad_token: 22 | self.tokenizer.pad_token = "[PAD]" 23 | try: 24 | self.__dimension = self.model.config.hidden_size 25 | except Exception: # pylint: disable=W0703 26 | from transformers import AutoConfig # pylint: disable=C0415 27 | 28 | config = AutoConfig.from_pretrained(model) 29 | self.__dimension = config.hidden_size 30 | 31 | def to_embeddings(self, data, **_): 32 | """Generate embedding given text input 33 | 34 | :param data: text in string. 35 | :type data: str 36 | 37 | :return: a text embedding in shape of (dim,). 38 | """ 39 | if not isinstance(data, list): 40 | data = [data] 41 | inputs = self.tokenizer( 42 | data, padding=True, truncation=True, return_tensors="pt" 43 | ) 44 | outs = self.model(**inputs).last_hidden_state 45 | emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy() 46 | return np.array(emb).astype("float32") 47 | 48 | def post_proc(self, token_embeddings, inputs): 49 | attention_mask = inputs["attention_mask"] 50 | input_mask_expanded = ( 51 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 52 | ) 53 | sentence_embs = torch.sum( 54 | token_embeddings * input_mask_expanded, 1 55 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 56 | return sentence_embs 57 | 58 | @property 59 | def dimension(self): 60 | """Embedding dimension. 
61 | 62 | :return: embedding dimension 63 | """ 64 | return self.__dimension 65 | -------------------------------------------------------------------------------- /modelcache/embedding/huggingface_tei.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import requests 3 | import numpy as np 4 | from modelcache.embedding.base import BaseEmbedding 5 | 6 | class HuggingfaceTEI(BaseEmbedding): 7 | def __init__(self, base_url: str, model: str): 8 | self.base_url = base_url 9 | self.model = model 10 | self.headers = { 11 | 'accept': 'application/json', 12 | 'Content-Type': 'application/json', 13 | } 14 | self.__dimension = self.to_embeddings('test').shape[0] 15 | 16 | def to_embeddings(self, data, **_): 17 | json_data = { 18 | 'input': data, 19 | 'model': self.model, 20 | } 21 | 22 | response = requests.post(self.base_url, headers=self.headers, json=json_data) 23 | embedding = response.json()['data'][0]['embedding'] 24 | return np.array(embedding) 25 | 26 | @property 27 | def dimension(self): 28 | """Embedding dimension. 29 | 30 | :return: embedding dimension 31 | """ 32 | return self.__dimension 33 | -------------------------------------------------------------------------------- /modelcache/embedding/llmEmb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from modelcache.embedding.base import BaseEmbedding 4 | from transformers import AutoTokenizer 5 | from transformers import AutoConfig 6 | 7 | 8 | class LlmEmb2Vec(BaseEmbedding): 9 | def __init__(self): 10 | 11 | self.model_name = '' # 13b-mft-embedding.npy 12 | model_path = '' # .npy file storage path 13 | model_file = model_path + self.model_name # .npy file 14 | config = AutoConfig.from_pretrained(model_path) 15 | dimension = config.hidden_size 16 | self.__dimension = dimension 17 | self.model = np.load(model_file) 18 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True) 19 | 20 | def to_embeddings(self, data, **_): 21 | """Generate embedding given text input 22 | 23 | :param data: text in string. 24 | :return: a text embedding in shape of (dim,). 25 | """ 26 | input_ids = self.tokenizer.encode(data, add_special_tokens=True) 27 | embedding_array = self.model[input_ids].mean(axis=0) 28 | return embedding_array 29 | 30 | def post_proc(self, token_embeddings, inputs): 31 | pass 32 | 33 | @property 34 | def dimension(self): 35 | """Embedding dimension. 
36 | :return: embedding dimension 37 | """ 38 | return self.__dimension 39 | -------------------------------------------------------------------------------- /modelcache/embedding/onnx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.embedding.base import BaseEmbedding 5 | from modelcache.utils import ( 6 | import_onnxruntime, 7 | import_huggingface_hub, 8 | import_huggingface, 9 | ) 10 | 11 | import_huggingface() 12 | import_onnxruntime() 13 | import_huggingface_hub() 14 | 15 | from transformers import AutoTokenizer, AutoConfig # pylint: disable=C0413 16 | import onnxruntime 17 | from modelcache.utils.env_config import get_onnx_tokenizer_path, get_onnx_model 18 | 19 | 20 | class Onnx(BaseEmbedding): 21 | 22 | def __init__(self, model="modelcache_open/paraphrase-albert-onnx"): 23 | # Load the tokenizer from local files 24 | onnx_tokenizer = get_onnx_tokenizer_path() 25 | self.tokenizer = AutoTokenizer.from_pretrained(onnx_tokenizer, local_files_only=True) 26 | # Load the ONNX model from local files 27 | onnx_model = get_onnx_model() 28 | self.ort_session = onnxruntime.InferenceSession(onnx_model) 29 | 30 | config = AutoConfig.from_pretrained(onnx_tokenizer, local_files_only=True) 31 | self.__dimension = config.hidden_size 32 | 33 | def to_embeddings(self, data, **_): 34 | """Generate embedding given text input. 35 | 36 | :param data: text in string. 37 | :type data: str 38 | 39 | :return: a text embedding in shape of (dim,). 40 | """ 41 | encoded_text = self.tokenizer.encode_plus(data, padding="max_length") 42 | ort_inputs = { 43 | "input_ids": np.array(encoded_text["input_ids"]).reshape(1, -1), 44 | "attention_mask": np.array(encoded_text["attention_mask"]).reshape(1, -1), 45 | "token_type_ids": np.array(encoded_text["token_type_ids"]).reshape(1, -1), 46 | } 47 | 48 | ort_outputs = self.ort_session.run(None, ort_inputs) 49 | ort_feat = ort_outputs[0] 50 | emb = self.post_proc(ort_feat, ort_inputs["attention_mask"]) 51 | return emb.flatten() 52 | 53 | def post_proc(self, token_embeddings, attention_mask): 54 | input_mask_expanded = ( 55 | np.expand_dims(attention_mask, -1) 56 | .repeat(token_embeddings.shape[-1], -1) 57 | .astype(float) 58 | ) 59 | sentence_embs = np.sum(token_embeddings * input_mask_expanded, 1) / np.maximum( 60 | input_mask_expanded.sum(1), 1e-9 61 | ) 62 | return sentence_embs 63 | 64 | @property 65 | def dimension(self): 66 | """Embedding dimension.
67 | 68 | :return: embedding dimension 69 | """ 70 | return self.__dimension 71 | -------------------------------------------------------------------------------- /modelcache/embedding/paddlenlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.embedding.base import BaseEmbedding 5 | from modelcache.utils import import_paddlenlp, import_paddle 6 | 7 | import_paddle() 8 | import_paddlenlp() 9 | 10 | 11 | import paddle # pylint: disable=C0413 12 | from paddlenlp.transformers import AutoModel, AutoTokenizer # pylint: disable=C0413 13 | 14 | 15 | class PaddleNLP(BaseEmbedding): 16 | def __init__(self, model: str = "ernie-3.0-medium-zh"): 17 | self.model = AutoModel.from_pretrained(model) 18 | self.model.eval() 19 | 20 | self.tokenizer = AutoTokenizer.from_pretrained(model) 21 | if not self.tokenizer.pad_token: 22 | self.tokenizer.pad_token = "" 23 | self.__dimension = None 24 | 25 | def to_embeddings(self, data, **_): 26 | """Generate embedding given text input 27 | 28 | :param data: text in string. 29 | :type data: str 30 | 31 | :return: a text embedding in shape of (dim,). 32 | """ 33 | if not isinstance(data, list): 34 | data = [data] 35 | inputs = self.tokenizer( 36 | data, padding=True, truncation=True, return_tensors="pd" 37 | ) 38 | outs = self.model(**inputs)[0] 39 | emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy() 40 | return np.array(emb).astype("float32") 41 | 42 | def post_proc(self, token_embeddings, inputs): 43 | attention_mask = paddle.ones(inputs["token_type_ids"].shape) 44 | input_mask_expanded = ( 45 | attention_mask.unsqueeze(-1).expand(token_embeddings.shape).astype("float32") 46 | ) 47 | sentence_embs = paddle.sum( 48 | token_embeddings * input_mask_expanded, 1 49 | ) / paddle.clip(input_mask_expanded.sum(1), min=1e-9) 50 | return sentence_embs 51 | 52 | @property 53 | def dimension(self): 54 | """Embedding dimension. 
55 | 56 | :return: embedding dimension 57 | """ 58 | if not self.__dimension: 59 | self.__dimension = len(self.to_embeddings("foo")) 60 | return self.__dimension 61 | -------------------------------------------------------------------------------- /modelcache/embedding/string_text.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def to_embeddings(data, **_): 5 | return data 6 | -------------------------------------------------------------------------------- /modelcache/embedding/timm_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.utils import import_timm, import_torch, import_pillow 5 | from modelcache.embedding.base import BaseEmbedding 6 | 7 | import_torch() 8 | import_timm() 9 | import_pillow() 10 | 11 | import torch # pylint: disable=C0413 12 | from timm.models import create_model # pylint: disable=C0413 13 | from timm.data import create_transform, resolve_data_config # pylint: disable=C0413 14 | from PIL import Image # pylint: disable=C0413 15 | 16 | 17 | class Timm(BaseEmbedding): 18 | def __init__(self, model: str = "resnet18", device: str = "default"): 19 | if device == "default": 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | else: 22 | self.device = device 23 | self.model_name = model 24 | self.model = create_model(model_name=model, pretrained=True) 25 | self.model.eval() 26 | 27 | try: 28 | self.__dimension = self.model.embed_dim 29 | except Exception: # pylint: disable=W0703 30 | self.__dimension = None 31 | 32 | def to_embeddings(self, data, skip_preprocess: bool = False, **_): 33 | if not skip_preprocess: 34 | data = self.preprocess(data) 35 | if data.dim() == 3: 36 | data = data.unsqueeze(0) 37 | feats = self.model.forward_features(data) 38 | emb = self.post_proc(feats).squeeze(0).detach().numpy() 39 | 40 | return np.array(emb).astype("float32") 41 | 42 | def post_proc(self, features): 43 | features = features.to("cpu") 44 | if features.dim() == 3: 45 | features = features[:, 0] 46 | if features.dim() == 4: 47 | global_pool = torch.nn.AdaptiveAvgPool2d(1) 48 | features = global_pool(features) 49 | features = features.flatten(1) 50 | assert features.dim() == 2, f"Invalid output dim {features.dim()}" 51 | return features 52 | 53 | def preprocess(self, image_path): 54 | data_cfg = resolve_data_config(self.model.pretrained_cfg) 55 | transform = create_transform(**data_cfg) 56 | 57 | image = Image.open(image_path).convert("RGB") 58 | image_tensor = transform(image) 59 | return image_tensor 60 | 61 | @property 62 | def dimension(self): 63 | """Embedding dimension. 
64 | 65 | :return: embedding dimension 66 | """ 67 | if not self.__dimension: 68 | input_size = self.model.pretrained_cfg["input_size"] 69 | dummy_input = torch.rand((1,) + input_size) 70 | feats = self.to_embeddings(dummy_input, skip_preprocess=True) 71 | self.__dimension = feats.shape[0] 72 | return self.__dimension 73 | -------------------------------------------------------------------------------- /modelcache/manager/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.manager.scalar_data import CacheBase 3 | from modelcache.manager.vector_data import VectorBase 4 | from modelcache.manager.object_data import ObjectBase 5 | from modelcache.manager.factory import get_data_manager 6 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | 4 | eviction_manager = LazyImport( 5 | "eviction_manager", globals(), "modelcache.manager.eviction.manager" 6 | ) 7 | 8 | 9 | def EvictionBase(name: str, **kwargs): 10 | return eviction_manager.EvictionBase.get(name, **kwargs) 11 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class EvictionBase(metaclass=ABCMeta): 7 | """ 8 | Eviction base. 9 | """ 10 | 11 | @abstractmethod 12 | def put(self, objs: List[Any]): 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: Any): 17 | pass 18 | 19 | @property 20 | @abstractmethod 21 | def policy(self) -> str: 22 | pass 23 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Callable, List, Any 3 | from modelcache.utils.error import NotFoundError 4 | 5 | 6 | class EvictionBase: 7 | """ 8 | EvictionBase to evict the cache data. 9 | """ 10 | 11 | def __init__(self): 12 | raise EnvironmentError( 13 | "EvictionBase is not designed to be instantiated, " 14 | "please use `EvictionBase.get(name, policy, maxsize, clean_size)`."
15 | ) 16 | 17 | @staticmethod 18 | def get(name: str, policy: str, maxsize: int, clean_size: int, on_evict: Callable[[List[Any]], None], **kwargs): 19 | if name in "memory": 20 | from modelcache.manager.eviction.memory_cache import MemoryCacheEviction 21 | 22 | eviction_base = MemoryCacheEviction(policy, maxsize, clean_size, on_evict, **kwargs) 23 | else: 24 | raise NotFoundError("eviction base", name) 25 | return eviction_base 26 | -------------------------------------------------------------------------------- /modelcache/manager/eviction/memory_cache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Callable, List 3 | import cachetools 4 | 5 | from modelcache.manager.eviction.base import EvictionBase 6 | 7 | 8 | def popitem_wrapper(func, wrapper_func, clean_size): 9 | def wrapper(*args, **kwargs): 10 | keys = [] 11 | try: 12 | keys = [func(*args, **kwargs)[0] for _ in range(clean_size)] 13 | except KeyError: 14 | pass 15 | wrapper_func(keys) 16 | return wrapper 17 | 18 | 19 | class MemoryCacheEviction(EvictionBase): 20 | def __init__(self, policy: str, maxsize: int, clean_size: int, on_evict: Callable[[List[Any]], None], **kwargs): 21 | self._policy = policy.upper() 22 | if self._policy == "LRU": 23 | self._cache = cachetools.LRUCache(maxsize=maxsize, **kwargs) 24 | elif self._policy == "LFU": 25 | self._cache = cachetools.LFUCache(maxsize=maxsize, **kwargs) 26 | elif self._policy == "FIFO": 27 | self._cache = cachetools.FIFOCache(maxsize=maxsize, **kwargs) 28 | elif self._policy == "RR": 29 | self._cache = cachetools.RRCache(maxsize=maxsize, **kwargs) 30 | else: 31 | raise ValueError(f"Unknown policy {policy}") 32 | 33 | self._cache.popitem = popitem_wrapper(self._cache.popitem, on_evict, clean_size) 34 | 35 | def put(self, objs: List[Any]): 36 | for obj in objs: 37 | self._cache[obj] = True 38 | 39 | def get(self, obj: Any): 40 | return self._cache.get(obj) 41 | 42 | @property 43 | def policy(self) -> str: 44 | return self._policy 45 | -------------------------------------------------------------------------------- /modelcache/manager/eviction_manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class EvictionManager: 3 | """ 4 | EvictionManager to manager the eviction policy. 5 | 6 | :param scalar_storage: CacheStorage to manager the scalar data. 7 | :type scalar_storage: :class:`CacheStorage` 8 | :param vector_base: VectorBase to manager the vector data. 
9 | :type vector_base: :class:`VectorBase` 10 | """ 11 | 12 | MAX_MARK_COUNT = 5000 13 | MAX_MARK_RATE = 0.1 14 | BATCH_SIZE = 100000 15 | REBUILD_CONDITION = 5 16 | 17 | def __init__(self, scalar_storage, vector_base): 18 | self._scalar_storage = scalar_storage 19 | self._vector_base = vector_base 20 | self.delete_count = 0 21 | 22 | def check_evict(self): 23 | mark_count = self._scalar_storage.count(state=-1) 24 | all_count = self._scalar_storage.count(is_all=True) 25 | if ( 26 | mark_count > self.MAX_MARK_COUNT 27 | or mark_count / all_count > self.MAX_MARK_RATE 28 | ): 29 | return True 30 | return False 31 | 32 | def delete(self,model): 33 | mark_ids = self._scalar_storage.get_ids(deleted=True) 34 | self._scalar_storage.clear_deleted_data() 35 | self._vector_base.delete(mark_ids,model) 36 | self.delete_count += 1 37 | if self.delete_count >= self.REBUILD_CONDITION: 38 | self.rebuild() 39 | 40 | def rebuild(self): 41 | self._scalar_storage.clear_deleted_data() 42 | ids = self._scalar_storage.get_ids(deleted=False) 43 | self._vector_base.rebuild(ids) 44 | self.delete_count = 0 45 | 46 | def soft_evict(self, marked_keys): 47 | self._scalar_storage.mark_deleted(marked_keys) 48 | -------------------------------------------------------------------------------- /modelcache/manager/factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Union, Callable 3 | from modelcache.manager import CacheBase, VectorBase, ObjectBase 4 | from modelcache.manager.data_manager import SSDataManager, MapDataManager 5 | 6 | 7 | def get_data_manager( 8 | cache_base: Union[CacheBase, str] = None, 9 | vector_base: Union[VectorBase, str] = None, 10 | object_base: Union[ObjectBase, str] = None, 11 | max_size: int = 1000, 12 | clean_size: int = None, 13 | eviction: str = "LRU", 14 | data_path: str = "data_map.txt", 15 | get_data_container: Callable = None, 16 | ): 17 | if not cache_base and not vector_base: 18 | return MapDataManager(data_path, max_size, get_data_container) 19 | 20 | if isinstance(cache_base, str): 21 | cache_base = CacheBase(name=cache_base) 22 | if isinstance(vector_base, str): 23 | vector_base = VectorBase(name=vector_base) 24 | if isinstance(object_base, str): 25 | object_base = ObjectBase(name=object_base) 26 | assert cache_base and vector_base 27 | return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction) 28 | -------------------------------------------------------------------------------- /modelcache/manager/object_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | object_manager = LazyImport( 4 | "object_manager", globals(), "modelcache.manager.object_data.manager" 5 | ) 6 | 7 | 8 | def ObjectBase(name: str, **kwargs): 9 | return object_manager.ObjectBase.get(name, **kwargs) 10 | -------------------------------------------------------------------------------- /modelcache/manager/object_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class ObjectBase(ABC): 7 | """ 8 | Object storage base. 
9 | """ 10 | 11 | @abstractmethod 12 | def put(self, obj: Any) -> str: 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: str) -> Any: 17 | pass 18 | 19 | @abstractmethod 20 | def get_access_link(self, obj: str) -> str: 21 | pass 22 | 23 | @abstractmethod 24 | def delete(self, to_delete: List[str]): 25 | pass 26 | -------------------------------------------------------------------------------- /modelcache/manager/scalar_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | scalar_manager = LazyImport( 4 | "scalar_manager", globals(), "modelcache.manager.scalar_data.manager" 5 | ) 6 | 7 | 8 | def CacheBase(name: str, **kwargs): 9 | return scalar_manager.CacheBase.get(name, **kwargs) 10 | -------------------------------------------------------------------------------- /modelcache/manager/scalar_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Union, Dict, List, Optional, Any 5 | from enum import IntEnum 6 | import numpy as np 7 | 8 | 9 | class DataType(IntEnum): 10 | STR = 0 11 | IMAGE_BASE64 = 1 12 | IMAGE_URL = 2 13 | 14 | 15 | @dataclass 16 | class QuestionDep: 17 | """ 18 | QuestionDep 19 | """ 20 | 21 | name: str 22 | data: str 23 | dep_type: int = DataType.STR 24 | 25 | @classmethod 26 | def from_dict(cls, d: Dict): 27 | return cls( 28 | name=d["name"], 29 | data=d["data"], 30 | dep_type=d["dep_type"] 31 | ) 32 | 33 | 34 | @dataclass 35 | class Question: 36 | """ 37 | Question 38 | """ 39 | 40 | content: str 41 | deps: Optional[List[QuestionDep]] = None 42 | 43 | @classmethod 44 | def from_dict(cls, d: Dict): 45 | deps = [] 46 | for dep in d["deps"]: 47 | deps.append(QuestionDep.from_dict(dep)) 48 | return cls(d["content"], deps) 49 | 50 | 51 | @dataclass 52 | class Answer: 53 | """ 54 | data_type: 55 | 0: str 56 | 1: base64 image 57 | """ 58 | 59 | answer: Any 60 | answer_type: int = DataType.STR 61 | 62 | 63 | @dataclass 64 | class CacheData: 65 | """ 66 | CacheData 67 | """ 68 | 69 | question: Union[str, Question] 70 | answers: List[Answer] 71 | embedding_data: Optional[np.ndarray] = None 72 | 73 | def __init__(self, question, answers, embedding_data=None): 74 | self.question = question 75 | self.answers = [] 76 | if isinstance(answers, (str, Answer)): 77 | answers = [answers] 78 | for data in answers: 79 | if isinstance(data, (list, tuple)): 80 | self.answers.append(Answer(*data)) 81 | elif isinstance(data, Answer): 82 | self.answers.append(data) 83 | else: 84 | self.answers.append(Answer(answer=data)) 85 | self.embedding_data = embedding_data 86 | 87 | 88 | class CacheStorage(metaclass=ABCMeta): 89 | """ 90 | BaseStorage for scalar data. 
91 | """ 92 | 93 | @abstractmethod 94 | def create(self): 95 | pass 96 | 97 | @abstractmethod 98 | def insert_query_resp(self, query_resp, **kwargs): 99 | pass 100 | 101 | @abstractmethod 102 | def get_data_by_id(self, key): 103 | pass 104 | 105 | @abstractmethod 106 | def mark_deleted(self, keys): 107 | pass 108 | 109 | @abstractmethod 110 | def model_deleted(self, model_name): 111 | pass 112 | 113 | @abstractmethod 114 | def clear_deleted_data(self): 115 | pass 116 | 117 | @abstractmethod 118 | def get_ids(self, deleted=True): 119 | pass 120 | 121 | @abstractmethod 122 | def count(self): 123 | pass 124 | 125 | def flush(self): 126 | pass 127 | 128 | @abstractmethod 129 | def close(self): 130 | pass 131 | -------------------------------------------------------------------------------- /modelcache/manager/scalar_data/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils import import_sql_client 3 | from modelcache.utils.error import NotFoundError 4 | 5 | SQL_URL = {"sqlite": "./sqlite.db"} 6 | 7 | 8 | class CacheBase: 9 | """ 10 | CacheBase to manager the cache storage. 11 | """ 12 | 13 | def __init__(self): 14 | raise EnvironmentError( 15 | "CacheBase is designed to be instantiated, please using the `CacheBase.get(name)`." 16 | ) 17 | 18 | @staticmethod 19 | def get(name, **kwargs): 20 | 21 | if name in ["mysql", "oceanbase"]: 22 | from modelcache.manager.scalar_data.sql_storage import SQLStorage 23 | config = kwargs.get("config") 24 | import_sql_client(name) 25 | cache_base = SQLStorage(db_type=name, config=config) 26 | elif name == 'sqlite': 27 | from modelcache.manager.scalar_data.sql_storage_sqlite import SQLStorage 28 | sql_url = kwargs.get("sql_url", SQL_URL[name]) 29 | cache_base = SQLStorage(db_type=name, url=sql_url) 30 | elif name == 'elasticsearch': 31 | from modelcache.manager.scalar_data.sql_storage_es import SQLStorage 32 | config = kwargs.get("config") 33 | cache_base = SQLStorage(db_type=name, config=config) 34 | else: 35 | raise NotFoundError("cache store", name) 36 | return cache_base 37 | -------------------------------------------------------------------------------- /modelcache/manager/scalar_data/sql_storage_sqlite.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | from typing import List 4 | from modelcache.manager.scalar_data.base import CacheStorage, CacheData 5 | import sqlite3 6 | 7 | 8 | class SQLStorage(CacheStorage): 9 | def __init__( 10 | self, 11 | db_type: str = "mysql", 12 | config=None, 13 | url="./sqlite.db" 14 | ): 15 | self._url = url 16 | # self._engine = sqlite3.connect(url) 17 | self.create() 18 | 19 | def create(self): 20 | answer_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_llm_answer ( 21 | id INTEGER PRIMARY KEY AUTOINCREMENT, 22 | gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 23 | gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 24 | question TEXT NOT NULL, 25 | answer TEXT NOT NULL, 26 | answer_type INTEGER NOT NULL, 27 | hit_count INTEGER NOT NULL DEFAULT 0, 28 | model VARCHAR(1000) NOT NULL, 29 | embedding_data BLOB NOT NULL 30 | ); 31 | """ 32 | 33 | log_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_query_log ( 34 | id INTEGER PRIMARY KEY AUTOINCREMENT, 35 | gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 36 | gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 37 | error_code INTEGER NOT NULL, 38 | error_desc 
VARCHAR(1000) NOT NULL, 39 | cache_hit VARCHAR(100) NOT NULL, 40 | delta_time REAL NOT NULL, 41 | model VARCHAR(1000) NOT NULL, 42 | query TEXT NOT NULL, 43 | hit_query TEXT NOT NULL, 44 | answer TEXT NOT NULL 45 | ); 46 | """ 47 | 48 | conn = sqlite3.connect(self._url) 49 | try: 50 | cursor = conn.cursor() 51 | cursor.execute(answer_table_sql) 52 | cursor.execute(log_table_sql) 53 | conn.commit() 54 | cursor.close() 55 | conn.close() 56 | finally: 57 | conn.close() 58 | 59 | def _insert(self, data: List): 60 | answer = data[0] 61 | question = data[1] 62 | embedding_data = data[2] 63 | model = data[3] 64 | answer_type = 0 65 | embedding_data = embedding_data.tobytes() 66 | 67 | table_name = "modelcache_llm_answer" 68 | insert_sql = "INSERT INTO {} (question, answer, answer_type, model, embedding_data) VALUES (?, ?, ?, ?, ?)".format(table_name) 69 | 70 | conn = sqlite3.connect(self._url) 71 | try: 72 | cursor = conn.cursor() 73 | values = (question, answer, answer_type, model, embedding_data) 74 | cursor.execute(insert_sql, values) 75 | conn.commit() 76 | id = cursor.lastrowid 77 | cursor.close() 78 | conn.close() 79 | finally: 80 | conn.close() 81 | return id 82 | 83 | def batch_insert(self, all_data: List[CacheData]): 84 | ids = [] 85 | for data in all_data: 86 | ids.append(self._insert(data)) 87 | return ids 88 | 89 | def insert_query_resp(self, query_resp, **kwargs): 90 | error_code = query_resp.get('errorCode') 91 | error_desc = query_resp.get('errorDesc') 92 | cache_hit = query_resp.get('cacheHit') 93 | model = kwargs.get('model') 94 | query = kwargs.get('query') 95 | delta_time = kwargs.get('delta_time') 96 | hit_query = query_resp.get('hit_query') 97 | answer = query_resp.get('answer') 98 | 99 | if isinstance(hit_query, list): 100 | hit_query = json.dumps(hit_query, ensure_ascii=False) 101 | 102 | table_name = "modelcache_query_log" 103 | insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?)".format(table_name) 104 | conn = sqlite3.connect(self._url) 105 | try: 106 | cursor = conn.cursor() 107 | values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) 108 | cursor.execute(insert_sql, values) 109 | conn.commit() 110 | cursor.close() 111 | conn.close() 112 | finally: 113 | conn.close() 114 | 115 | def get_data_by_id(self, key: int): 116 | table_name = "modelcache_llm_answer" 117 | query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key) 118 | conn = sqlite3.connect(self._url) 119 | try: 120 | cursor = conn.cursor() 121 | cursor.execute(query_sql) 122 | resp = cursor.fetchone() 123 | conn.commit() 124 | cursor.close() 125 | conn.close() 126 | finally: 127 | conn.close() 128 | 129 | if resp is not None and len(resp) == 4: 130 | return resp 131 | else: 132 | return None 133 | 134 | def update_hit_count_by_id(self, primary_id: int): 135 | table_name = "modelcache_llm_answer" 136 | update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id) 137 | 138 | conn = sqlite3.connect(self._url) 139 | try: 140 | cursor = conn.cursor() 141 | cursor.execute(update_sql) 142 | conn.commit() 143 | cursor.close() 144 | conn.close() 145 | finally: 146 | # 关闭连接,将连接返回给连接池 147 | conn.close() 148 | 149 | def get_ids(self, deleted=True): 150 | pass 151 | 152 | def mark_deleted(self, keys): 153 | table_name = "modelcache_llm_answer" 154 | delete_sql = "Delete from {} WHERE id in ({})".format(table_name, 
",".join([str(i) for i in keys])) 155 | conn = sqlite3.connect(self._url) 156 | try: 157 | cursor = conn.cursor() 158 | cursor.execute(delete_sql) 159 | delete_count = cursor.rowcount 160 | conn.commit() 161 | cursor.close() 162 | conn.close() 163 | finally: 164 | conn.close() 165 | return delete_count 166 | 167 | def model_deleted(self, model_name): 168 | table_name = "modelcache_llm_answer" 169 | delete_sql = "Delete from {} WHERE model=?".format(table_name) 170 | 171 | table_log_name = "modelcache_query_log" 172 | delete_log_sql = "Delete from {} WHERE model=?".format(table_log_name) 173 | conn = sqlite3.connect(self._url) 174 | try: 175 | cursor = conn.cursor() 176 | cursor.execute(delete_sql, (model_name,)) 177 | conn.commit() 178 | # get delete rows 179 | deleted_rows_count = cursor.rowcount 180 | 181 | cursor.execute(delete_log_sql, (model_name,)) 182 | conn.commit() 183 | cursor.close() 184 | except sqlite3.Error as e: 185 | print(f"SQLite error: {e}") 186 | deleted_rows_count = 0 # if except, return 0 187 | finally: 188 | conn.close() 189 | return deleted_rows_count 190 | 191 | def clear_deleted_data(self): 192 | pass 193 | 194 | def count(self, state: int = 0, is_all: bool = False): 195 | pass 196 | 197 | def close(self): 198 | pass 199 | 200 | def count_answers(self): 201 | pass 202 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | 4 | vector_manager = LazyImport( 5 | "vector_manager", globals(), "modelcache.manager.vector_data.manager" 6 | ) 7 | 8 | 9 | def VectorBase(name: str, **kwargs): 10 | return vector_manager.VectorBase.get(name, **kwargs) 11 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | from typing import List 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class VectorData: 10 | id: int 11 | data: np.ndarray 12 | 13 | 14 | class VectorBase(ABC): 15 | """VectorBase: base vector store interface""" 16 | 17 | @abstractmethod 18 | def mul_add(self, datas: List[VectorData], model=None): 19 | pass 20 | 21 | @abstractmethod 22 | def search(self, data: np.ndarray, top_k: int, model): 23 | pass 24 | 25 | @abstractmethod 26 | def rebuild(self, ids=None) -> bool: 27 | pass 28 | 29 | @abstractmethod 30 | def delete(self, ids) -> bool: 31 | pass 32 | 33 | @abstractmethod 34 | def rebuild_col(self, model): 35 | pass 36 | 37 | def flush(self): 38 | pass 39 | 40 | def close(self): 41 | pass 42 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/chroma.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import logging 5 | from modelcache.manager.vector_data.base import VectorBase, VectorData 6 | from modelcache.utils import import_chromadb, import_torch 7 | 8 | import_torch() 9 | import_chromadb() 10 | 11 | import chromadb 12 | 13 | 14 | class Chromadb(VectorBase): 15 | 16 | def __init__( 17 | self, 18 | persist_directory="./chromadb", 19 | top_k: int = 1, 20 | ): 21 | self.collection_name = "modelcache" 22 | self.top_k = 
top_k 23 | 24 | self._client = chromadb.PersistentClient(path=persist_directory) 25 | self._collection = None 26 | 27 | def mul_add(self, datas: List[VectorData], model=None): 28 | collection_name_model = self.collection_name + '_' + model 29 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 30 | 31 | data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) 32 | self._collection.add(embeddings=data_array, ids=id_array) 33 | 34 | def search(self, data: np.ndarray, top_k: int = -1, model=None): 35 | collection_name_model = self.collection_name + '_' + model 36 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 37 | 38 | if self._collection.count() == 0: 39 | return [] 40 | if top_k == -1: 41 | top_k = self.top_k 42 | results = self._collection.query( 43 | query_embeddings=[data.tolist()], 44 | n_results=top_k, 45 | include=["distances"], 46 | ) 47 | return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]])) 48 | 49 | def rebuild(self, ids=None): 50 | pass 51 | 52 | def delete(self, ids, model=None): 53 | try: 54 | collection_name_model = self.collection_name + '_' + model 55 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 56 | # Query the IDs that actually exist in the collection 57 | ids_str = [str(x) for x in ids] 58 | existing_ids = set(self._collection.get(ids=ids_str).ids) 59 | 60 | # Delete the IDs that exist 61 | if existing_ids: 62 | self._collection.delete(list(existing_ids)) 63 | 64 | # Return the number of entries actually deleted 65 | return len(existing_ids) 66 | 67 | except Exception as e: 68 | logging.error('Error during deletion: {}'.format(e)) 69 | raise ValueError(str(e)) 70 | 71 | def rebuild_col(self, model): 72 | collection_name_model = self.collection_name + '_' + model 73 | 74 | # Check whether the collection exists; delete it if it does 75 | collections = self._client.list_collections() 76 | if any(col.name == collection_name_model for col in collections): 77 | self._client.delete_collection(collection_name_model) 78 | else: 79 | return 'model collection not found, please check!'
80 | 81 | try: 82 | self._client.create_collection(collection_name_model) 83 | except Exception as e: 84 | logging.info(f'rebuild_collection: {e}') 85 | raise ValueError(str(e)) 86 | 87 | def flush(self): 88 | # chroma无flush方法 89 | pass 90 | 91 | def close(self): 92 | pass 93 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/faiss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from typing import List 4 | import numpy as np 5 | from modelcache.manager.vector_data.base import VectorBase, VectorData 6 | from modelcache.utils import import_faiss 7 | import_faiss() 8 | import faiss # pylint: disable=C0413 9 | 10 | 11 | class Faiss(VectorBase): 12 | def __init__(self, index_file_path, dimension, top_k): 13 | self._index_file_path = index_file_path 14 | self._dimension = dimension 15 | self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2) 16 | self._top_k = top_k 17 | if os.path.isfile(index_file_path): 18 | self._index = faiss.read_index(index_file_path) 19 | 20 | def mul_add(self, datas: List[VectorData], model=None): 21 | data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas))) 22 | np_data = np.array(data_array).astype("float32") 23 | ids = np.array(id_array) 24 | self._index.add_with_ids(np_data, ids) 25 | 26 | def search(self, data: np.ndarray, top_k: int = -1, model=None): 27 | if self._index.ntotal == 0: 28 | return None 29 | if top_k == -1: 30 | top_k = self._top_k 31 | np_data = np.array(data).astype("float32").reshape(1, -1) 32 | dist, ids = self._index.search(np_data, top_k) 33 | ids = [int(i) for i in ids[0]] 34 | return list(zip(dist[0], ids)) 35 | 36 | def rebuild_col(self, ids=None): 37 | try: 38 | self._index.reset() 39 | except Exception as e: 40 | return f"An error occurred during index rebuild: {e}" 41 | 42 | def rebuild(self, ids=None): 43 | return True 44 | 45 | def delete(self, ids): 46 | ids_to_remove = np.array(ids) 47 | self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove))) 48 | 49 | def flush(self): 50 | faiss.write_index(self._index, self._index_file_path) 51 | 52 | def close(self): 53 | self.flush() 54 | 55 | def count(self): 56 | return self._index.ntotal 57 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.error import NotFoundError, ParamError 3 | 4 | TOP_K = 1 5 | FAISS_INDEX_PATH = "faiss.index" 6 | DIMENSION = 0 7 | MILVUS_HOST = "localhost" 8 | MILVUS_PORT = 19530 9 | MILVUS_USER = "" 10 | MILVUS_PSW = "" 11 | MILVUS_SECURE = False 12 | MILVUS_INDEX_PARAMS = { 13 | "metric_type": "L2", 14 | "index_type": "HNSW", 15 | "params": {"M": 8, "efConstruction": 64}, 16 | } 17 | 18 | COLLECTION_NAME = "modelcache" 19 | 20 | 21 | class VectorBase: 22 | """ 23 | VectorBase to manager the vector base. 24 | """ 25 | 26 | def __init__(self): 27 | raise EnvironmentError( 28 | "VectorBase is designed to be instantiated, please using the `VectorBase.get(name)`." 29 | ) 30 | 31 | @staticmethod 32 | def check_dimension(dimension): 33 | if dimension <= 0: 34 | raise ParamError( 35 | f"the dimension should be greater than zero, current value: {dimension}." 
36 | ) 37 | 38 | @staticmethod 39 | def get(name, **kwargs): 40 | top_k = kwargs.get("top_k", TOP_K) 41 | if name == "milvus": 42 | from modelcache.manager.vector_data.milvus import Milvus 43 | dimension = kwargs.get("dimension", DIMENSION) 44 | milvus_config = kwargs.get("milvus_config") 45 | VectorBase.check_dimension(dimension) 46 | host = milvus_config.get('milvus', 'host') 47 | port = milvus_config.get('milvus', 'port') 48 | user = milvus_config.get('milvus', 'user') 49 | password = milvus_config.get('milvus', 'password') 50 | 51 | secure = kwargs.get("secure", MILVUS_SECURE) 52 | collection_name = kwargs.get("collection_name", COLLECTION_NAME) 53 | index_params = kwargs.get("index_params", MILVUS_INDEX_PARAMS) 54 | search_params = kwargs.get("search_params", None) 55 | local_mode = kwargs.get("local_mode", False) 56 | local_data = kwargs.get("local_data", "./milvus_data") 57 | vector_base = Milvus( 58 | host=host, 59 | port=port, 60 | user=user, 61 | password=password, 62 | secure=secure, 63 | collection_name=collection_name, 64 | dimension=dimension, 65 | top_k=top_k, 66 | index_params=index_params, 67 | search_params=search_params, 68 | local_mode=local_mode, 69 | local_data=local_data 70 | ) 71 | elif name == "redis": 72 | from modelcache.manager.vector_data.redis import RedisVectorStore 73 | dimension = kwargs.get("dimension", DIMENSION) 74 | VectorBase.check_dimension(dimension) 75 | 76 | redis_config = kwargs.get("redis_config") 77 | host = redis_config.get('redis', 'host') 78 | port = redis_config.get('redis', 'port') 79 | user = redis_config.get('redis', 'user') 80 | password = redis_config.get('redis', 'password') 81 | namespace = kwargs.get("namespace", "") 82 | # collection_name = kwargs.get("collection_name", COLLECTION_NAME) 83 | 84 | vector_base = RedisVectorStore( 85 | host=host, 86 | port=port, 87 | username=user, 88 | password=password, 89 | namespace=namespace, 90 | top_k=top_k, 91 | dimension=dimension, 92 | ) 93 | elif name == "faiss": 94 | from modelcache.manager.vector_data.faiss import Faiss 95 | 96 | dimension = kwargs.get("dimension", DIMENSION) 97 | index_path = kwargs.pop("index_path", FAISS_INDEX_PATH) 98 | VectorBase.check_dimension(dimension) 99 | vector_base = Faiss( 100 | index_file_path=index_path, dimension=dimension, top_k=top_k 101 | ) 102 | elif name == "chromadb": 103 | from modelcache.manager.vector_data.chroma import Chromadb 104 | 105 | chromadb_config = kwargs.get("chromadb_config", None) 106 | persist_directory = chromadb_config.get('chromadb','persist_directory') 107 | 108 | vector_base = Chromadb( 109 | persist_directory=persist_directory, 110 | top_k=top_k, 111 | ) 112 | elif name == "hnswlib": 113 | from modelcache.manager.vector_data.hnswlib_store import Hnswlib 114 | 115 | dimension = kwargs.get("dimension", DIMENSION) 116 | index_path = kwargs.pop("index_path", "./hnswlib_index.bin") 117 | max_elements = kwargs.pop("max_elements", 100000) 118 | VectorBase.check_dimension(dimension) 119 | vector_base = Hnswlib( 120 | index_file_path=index_path, dimension=dimension, 121 | top_k=top_k, max_elements=max_elements 122 | ) 123 | else: 124 | raise NotFoundError("vector store", name) 125 | return vector_base 126 | -------------------------------------------------------------------------------- /modelcache/manager/vector_data/redis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | import numpy as np 4 | from redis.commands.search.indexDefinition 
import IndexDefinition, IndexType 5 | from redis.commands.search.query import Query 6 | from redis.commands.search.field import TagField, VectorField, NumericField 7 | from redis.client import Redis 8 | 9 | from modelcache.manager.vector_data.base import VectorBase, VectorData 10 | from modelcache.utils import import_redis 11 | from modelcache.utils.log import modelcache_log 12 | from modelcache.utils.index_util import get_index_name 13 | from modelcache.utils.index_util import get_index_prefix 14 | import_redis() 15 | 16 | 17 | class RedisVectorStore(VectorBase): 18 | def __init__( 19 | self, 20 | host: str = "localhost", 21 | port: str = "6379", 22 | username: str = "", 23 | password: str = "", 24 | dimension: int = 0, 25 | top_k: int = 1, 26 | namespace: str = "", 27 | ): 28 | if dimension <= 0: 29 | raise ValueError( 30 | f"invalid `dim` param: {dimension} in the Redis vector store." 31 | ) 32 | self._client = Redis( 33 | host=host, port=int(port), username=username, password=password 34 | ) 35 | self.top_k = top_k 36 | self.dimension = dimension 37 | self.namespace = namespace 38 | self.doc_prefix = f"{self.namespace}doc:" 39 | 40 | def _check_index_exists(self, index_name: str) -> bool: 41 | """Check if Redis index exists.""" 42 | try: 43 | self._client.ft(index_name).info() 44 | except: 45 | modelcache_log.info("Index does not exist") 46 | return False 47 | modelcache_log.info("Index already exists") 48 | return True 49 | 50 | def create_index(self, index_name, index_prefix): 51 | dimension = self.dimension 52 | if self._check_index_exists(index_name): 53 | modelcache_log.info( 54 | "The %s already exists, and it will be used directly", index_name 55 | ) 56 | return 'already_exists' 57 | else: 58 | id_field_name = "data_id" 59 | embedding_field_name = "data_vector" 60 | 61 | id = NumericField(name=id_field_name) 62 | embedding = VectorField(embedding_field_name, 63 | "HNSW", { 64 | "TYPE": "FLOAT32", 65 | "DIM": dimension, 66 | "DISTANCE_METRIC": "L2", 67 | "INITIAL_CAP": 1000, 68 | } 69 | ) 70 | fields = [id, embedding] 71 | definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH) 72 | 73 | # create Index 74 | self._client.ft(index_name).create_index( 75 | fields=fields, definition=definition 76 | ) 77 | return 'create_success' 78 | 79 | def mul_add(self, datas: List[VectorData], model=None): 80 | # pipe = self._client.pipeline() 81 | for data in datas: 82 | id: int = data.id 83 | embedding = data.data.astype(np.float32).tobytes() 84 | id_field_name = "data_id" 85 | embedding_field_name = "data_vector" 86 | obj = {id_field_name: id, embedding_field_name: embedding} 87 | index_prefix = get_index_prefix(model) 88 | self._client.hset(f"{index_prefix}{id}", mapping=obj) 89 | 90 | def search(self, data: np.ndarray, top_k: int = -1, model=None): 91 | index_name = get_index_name(model) 92 | id_field_name = "data_id" 93 | embedding_field_name = "data_vector" 94 | base_query = f'*=>[KNN 2 @{embedding_field_name} $vector AS distance]' 95 | query = ( 96 | Query(base_query) 97 | .sort_by("distance") 98 | .return_fields(id_field_name, "distance") 99 | .dialect(2) 100 | ) 101 | 102 | query_params = {"vector": data.astype(np.float32).tobytes()} 103 | results = ( 104 | self._client.ft(index_name) 105 | .search(query, query_params=query_params) 106 | .docs 107 | ) 108 | return [(float(result.distance), int(getattr(result, id_field_name))) for result in results] 109 | 110 | def rebuild(self, ids=None) -> bool: 111 | pass 112 | 113 | def rebuild_col(self, model): 114 | 
index_name_model = get_index_name(model) 115 | if self._check_index_exists(index_name_model): 116 | try: 117 | self._client.ft(index_name_model).dropindex(delete_documents=True) 118 | except Exception as e: 119 | raise ValueError(str(e)) 120 | try: 121 | index_prefix = get_index_prefix(model) 122 | self.create_index(index_name_model, index_prefix) 123 | except Exception as e: 124 | raise ValueError(str(e)) 125 | # return 'rebuild success' 126 | 127 | def delete(self, ids) -> None: 128 | pipe = self._client.pipeline() 129 | for data_id in ids: 130 | pipe.delete(f"{self.doc_prefix}{data_id}") 131 | pipe.execute() 132 | 133 | def create(self, model=None): 134 | index_name = get_index_name(model) 135 | index_prefix = get_index_prefix(model) 136 | return self.create_index(index_name, index_prefix) 137 | 138 | def get_index_by_name(self, index_name): 139 | pass 140 | -------------------------------------------------------------------------------- /modelcache/processor/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/processor/post.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | from typing import List, Any 4 | 5 | 6 | def random_one(messages: List[Any]) -> Any: 7 | return random.choice(messages) 8 | 9 | 10 | def first(messages: List[Any]) -> Any: 11 | return messages[0] 12 | 13 | 14 | def nop(messages: List[Any]) -> Any: 15 | return messages 16 | -------------------------------------------------------------------------------- /modelcache/processor/pre.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | from typing import Dict, Any 4 | 5 | 6 | def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 7 | return data.get("chat_info")[-1]["query"] 8 | 9 | 10 | def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 11 | return data.get("query")[-1]["content"] 12 | 13 | 14 | def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) -> Any: 15 | last_content_str = data.get("messages")[-1]["content"] 16 | prompts = params.get("prompts", []) 17 | if prompts is None: 18 | return last_content_str 19 | pattern = "|".join(prompts) 20 | new_content_str = re.sub(pattern, "", last_content_str) 21 | return new_content_str 22 | 23 | 24 | def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 25 | s = "" 26 | messages = data.get("messages") 27 | for i, message in enumerate(messages): 28 | if i == len(messages) - 1: 29 | s += message["content"] 30 | else: 31 | s += message["content"] + "\n" 32 | return s 33 | 34 | 35 | def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 36 | return data 37 | 38 | 39 | def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 40 | return data.get("prompt") 41 | 42 | 43 | def get_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 44 | return data.get("file").name 45 | 46 | 47 | def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes: 48 | return data.get("file").peek() 49 | 50 | 51 | def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 52 | input_data = data.get("input") 53 | return str(input_data["image"].peek()) + input_data["question"] 54 | 55 | 56 | def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> 
str: 57 | input_data = data.get("input") 58 | return input_data["image"].name 59 | 60 | 61 | def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 62 | query_list = data.get("query") 63 | return multi_splicing(query_list) 64 | 65 | 66 | def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 67 | insert_query_list = data.get("chat_info")[-1]['query'] 68 | return multi_splicing(insert_query_list) 69 | 70 | 71 | def multi_splicing(data_list) -> Any: 72 | result_str = "" 73 | for d in data_list: 74 | role = d.get('role', '') 75 | content = d.get('content', '') 76 | result_str += role + "###" + content + "|||" 77 | 78 | # strip the trailing "|||" separator 79 | result_str = result_str[:-3] 80 | 81 | return result_str 82 | 83 | 84 | def multi_analysis(dialog_str): 85 | sub_strings = dialog_str.split('|||') 86 | 87 | dict_list = [] 88 | for s in sub_strings: 89 | parts = s.split('###') 90 | 91 | if len(parts) == 2: 92 | role = parts[0] 93 | content = parts[1] 94 | elif len(parts) > 2: 95 | role = parts[0] 96 | content = '###'.join(parts[1:]) 97 | else: 98 | role, content = parts[0], 'exception' 99 | 100 | if content == '': 101 | d = {"role": role} 102 | else: 103 | d = {"role": role, "content": content} 104 | dict_list.append(d) 105 | 106 | # collect the parsed dictionaries into the final list 107 | result_list = dict_list 108 | 109 | # return the reconstructed dialog 110 | return result_list 111 | -------------------------------------------------------------------------------- /modelcache/report.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class Report: 3 | def __init__(self): 4 | self.embedding_all_time = 0 5 | self.embedding_count = 0 6 | self.search_all_time = 0 7 | self.search_count = 0 8 | self.hint_cache_count = 0 9 | 10 | def embedding(self, delta_time): 11 | """Embedding counts and time. 12 | 13 | :param delta_time: additional runtime. 14 | """ 15 | self.embedding_all_time += delta_time 16 | self.embedding_count += 1 17 | 18 | def search(self, delta_time): 19 | """Search counts and time. 20 | 21 | :param delta_time: additional runtime.
22 | """ 23 | self.search_all_time += delta_time 24 | self.search_count += 1 25 | 26 | def average_embedding_time(self): 27 | """Average embedding time.""" 28 | return round( 29 | self.embedding_all_time / self.embedding_count 30 | if self.embedding_count != 0 31 | else 0, 32 | 4, 33 | ) 34 | 35 | def average_search_time(self): 36 | return round( 37 | self.search_all_time / self.search_count 38 | if self.search_count != 0 39 | else 0, 40 | 4, 41 | ) 42 | 43 | def hint_cache(self): 44 | self.hint_cache_count += 1 45 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation 3 | from modelcache.utils.lazy_import import LazyImport 4 | 5 | exact_match = LazyImport( 6 | "exact_match", globals(), "modelcache.similarity_evaluation.exact_match" 7 | ) 8 | 9 | 10 | def ExactMatchEvaluation(): 11 | return exact_match.ExactMatchEvaluation() 12 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/distance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation import SimilarityEvaluation 4 | 5 | 6 | class SearchDistanceEvaluation(SimilarityEvaluation): 7 | def __init__(self, max_distance=4.0, positive=False): 8 | self.max_distance = max_distance 9 | self.positive = positive 10 | 11 | def evaluation( 12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 13 | ) -> float: 14 | distance, _ = cache_dict["search_result"] 15 | if distance < 0: 16 | distance = 0 17 | elif distance > self.max_distance: 18 | distance = self.max_distance 19 | if self.positive: 20 | return distance 21 | return self.max_distance - distance 22 | 23 | def range(self) -> Tuple[float, float]: 24 | return 0.0, self.max_distance 25 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/exact_match.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation 4 | 5 | 6 | class ExactMatchEvaluation(SimilarityEvaluation): 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def evaluation( 12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 13 | ) -> float: 14 | return 1 if cache_dict["question"] == src_dict["question"] else 0 15 | 16 | def range(self) -> Tuple[float, float]: 17 | return 0, 1 18 | -------------------------------------------------------------------------------- /modelcache/similarity_evaluation/similarity_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Tuple, Dict, Any 4 | 5 | 6 | class SimilarityEvaluation(metaclass=ABCMeta): 7 | @abstractmethod 8 | def evaluation( 9 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **kwargs 10 | ) -> float: 11 | pass 12 | 13 | @abstractmethod 14 | def range(self) -> Tuple[float, float]: 15 | pass 16 | --------------------------------------------------------------------------------
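The two evaluators above are easiest to understand side by side. The following minimal usage sketch is not part of the repository; the dictionaries and numbers are invented purely for illustration, and it assumes the modules are importable from a ModelCache checkout:

from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.similarity_evaluation.exact_match import ExactMatchEvaluation

# Distance-based scoring: cache_dict carries a (distance, id) pair such as those
# returned by the vector stores; lower distance means a better hit, so the score
# reported is max_distance - distance unless positive=True.
dist_eval = SearchDistanceEvaluation(max_distance=4.0)
print(dist_eval.evaluation({}, {"search_result": (1.5, 42)}))  # 2.5
print(dist_eval.range())                                       # (0.0, 4.0)

# Exact matching: returns 1 only when the cached question equals the incoming one.
exact_eval = ExactMatchEvaluation()
print(exact_eval.evaluation({"question": "hello"}, {"question": "hello"}))  # 1
print(exact_eval.evaluation({"question": "hello"}, {"question": "hi"}))     # 0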
/modelcache/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib.util 3 | from typing import Optional 4 | from modelcache.utils.dependency_control import prompt_install 5 | 6 | 7 | def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None): 8 | is_avail = False 9 | if importlib.util.find_spec(libname): 10 | is_avail = True 11 | if not is_avail and prompt: 12 | prompt_install(package if package else libname) 13 | return is_avail 14 | 15 | 16 | def import_onnxruntime(): 17 | _check_library("onnxruntime") 18 | 19 | 20 | def import_huggingface(): 21 | _check_library("transformers") 22 | 23 | 24 | def import_huggingface_hub(): 25 | _check_library("huggingface_hub", package="huggingface-hub") 26 | 27 | 28 | def import_pymysql(): 29 | _check_library("pymysql") 30 | 31 | 32 | def import_sql_client(db_name): 33 | if db_name in ["mysql"]: 34 | import_pymysql() 35 | 36 | 37 | def import_pymilvus(): 38 | _check_library("pymilvus") 39 | 40 | 41 | def import_milvus_lite(): 42 | _check_library("milvus") 43 | 44 | 45 | def import_faiss(): 46 | _check_library("faiss", package="faiss-cpu") 47 | 48 | 49 | def import_torch(): 50 | _check_library("torch") 51 | 52 | 53 | def import_fasttext(): 54 | _check_library("fasttext") 55 | 56 | 57 | def import_paddle(): 58 | prompt_install("protobuf==3.20.0") 59 | _check_library("paddlepaddle") 60 | 61 | 62 | def import_paddlenlp(): 63 | _check_library("paddlenlp") 64 | 65 | 66 | def import_timm(): 67 | _check_library("timm", package="timm") 68 | 69 | 70 | def import_pillow(): 71 | _check_library("PIL", package="pillow") 72 | 73 | 74 | def import_redis(): 75 | _check_library("redis") 76 | 77 | 78 | def import_chromadb(): 79 | _check_library("chromadb", package="chromadb") -------------------------------------------------------------------------------- /modelcache/utils/cache_func.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | def cache_all(*_, **__): 3 | return True -------------------------------------------------------------------------------- /modelcache/utils/dependency_control.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import subprocess 3 | from modelcache.utils.error import PipInstallError 4 | from modelcache.utils.log import modelcache_log 5 | 6 | 7 | def prompt_install(package: str, warn: bool = False): # pragma: no cover 8 | """ 9 | Function used to prompt user to install a package. 10 | """ 11 | cmd = f"pip install {package}" 12 | try: 13 | if warn and input(f"Install {package}? 
Y/n: ") != "Y": 14 | raise ModuleNotFoundError(f"No module named {package}") 15 | subprocess.check_call(cmd, shell=True) 16 | modelcache_log.info("%s installed successfully!", package) 17 | except subprocess.CalledProcessError as e: 18 | raise PipInstallError(package) from e 19 | -------------------------------------------------------------------------------- /modelcache/utils/env_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache/utils/error.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class CacheError(Exception): 3 | """ModelCache base error""" 4 | 5 | 6 | class NotInitError(CacheError): 7 | """Raise when the cache has been used before it's inited""" 8 | def __init__(self): 9 | super().__init__("The cache should be inited before using") 10 | 11 | 12 | class RemoveError(CacheError): 13 | """Raise when the cache has been used before it's inited""" 14 | def __init__(self): 15 | super().__init__("The cache remove error") 16 | 17 | class NotFoundError(CacheError): 18 | """Raise when getting an unsupported store.""" 19 | def __init__(self, store_type, current_type_name): 20 | super().__init__(f"Unsupported ${store_type}: {current_type_name}") 21 | 22 | 23 | class ParamError(CacheError): 24 | """Raise when receiving an invalid param.""" 25 | 26 | 27 | class PipInstallError(CacheError): 28 | """Raise when failed to install package.""" 29 | def __init__(self, package): 30 | super().__init__(f"Ran into error installing {package}.") 31 | -------------------------------------------------------------------------------- /modelcache/utils/index_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def get_index_name(model): 5 | return 'modelcache' + '_' + model 6 | 7 | 8 | def get_index_prefix(model): 9 | return 'prefix' + '_' + model 10 | -------------------------------------------------------------------------------- /modelcache/utils/lazy_import.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib 3 | from types import ModuleType 4 | 5 | 6 | class LazyImport(ModuleType): 7 | """ 8 | Lazily import a module. 
9 | """ 10 | 11 | def __init__(self, local_name, parent_module_globals, name): 12 | self._local_name = local_name 13 | self._parent_module_globals = parent_module_globals 14 | super().__init__(name) 15 | 16 | def _load(self): 17 | module = importlib.import_module(self.__name__) 18 | self._parent_module_globals[self._local_name] = module 19 | self.__dict__.update(module.__dict__) 20 | return module 21 | 22 | def __getattr__(self, item): 23 | module = self._load() 24 | return getattr(module, item) 25 | 26 | def __dir__(self): 27 | module = self._load() 28 | return dir(module) 29 | -------------------------------------------------------------------------------- /modelcache/utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | FORMAT = '%(asctime)s - %(thread)d - %(filename)s-%(module)s:%(lineno)s - %(levelname)s: %(message)s' 5 | logging.basicConfig(format=FORMAT) 6 | 7 | modelcache_log = logging.getLogger('modelcache') 8 | -------------------------------------------------------------------------------- /modelcache/utils/model_filter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | def model_blacklist_filter(model, request_type): 3 | black_list = ['DI_COPILOT_SECOND', 'DI_COPILOT_LAB', 'DI_COPILOT_THIRD'] 4 | result = None 5 | if model in black_list: 6 | if request_type == 'query': 7 | result = {"errorCode": 105, 8 | "errorDesc": "model: {} in blacklist".format(model), 9 | "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''} 10 | elif request_type == 'insert': 11 | result = {"errorCode": 305, "errorDesc": "model: {} in blacklist".format(model), "writeStatus": ""} 12 | 13 | return result 14 | 15 | 16 | -------------------------------------------------------------------------------- /modelcache/utils/time.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | from modelcache import cache 4 | 5 | 6 | def time_cal(func, func_name=None, report_func=None): 7 | def inner(*args, **kwargs): 8 | time_start = time.time() 9 | res = func(*args, **kwargs) 10 | delta_time = time.time() - time_start 11 | if cache.config.log_time_func: 12 | cache.config.log_time_func( 13 | func.__name__ if func_name is None else func_name, delta_time 14 | ) 15 | if report_func is not None: 16 | report_func(delta_time) 17 | return res 18 | 19 | return inner 20 | 21 | -------------------------------------------------------------------------------- /modelcache_mm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.core import Cache 3 | from modelcache_mm.core import cache 4 | from modelcache_mm.config import Config 5 | import modelcache_mm.adapter 6 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | from modelcache_mm.adapter.adapter_query import adapt_query 5 | from modelcache_mm.adapter.adapter_insert import adapt_insert 6 | from 
modelcache_mm.adapter.adapter_remove import adapt_remove 7 | from modelcache_mm.adapter.adapter_register import adapt_register 8 | 9 | 10 | class ChatCompletion(object): 11 | """Openai ChatCompletion Wrapper""" 12 | @classmethod 13 | def create_query(cls, *args, **kwargs): 14 | def cache_data_convert(cache_data, cache_query): 15 | return construct_resp_from_cache(cache_data, cache_query) 16 | try: 17 | return adapt_query( 18 | cache_data_convert, 19 | *args, 20 | **kwargs 21 | ) 22 | except Exception as e: 23 | # return str(e) 24 | raise e 25 | 26 | @classmethod 27 | def create_insert(cls, *args, **kwargs): 28 | try: 29 | return adapt_insert( 30 | *args, 31 | **kwargs 32 | ) 33 | except Exception as e: 34 | # return str(e) 35 | raise e 36 | 37 | @classmethod 38 | def create_remove(cls, *args, **kwargs): 39 | try: 40 | return adapt_remove( 41 | *args, 42 | **kwargs 43 | ) 44 | except Exception as e: 45 | raise e 46 | 47 | @classmethod 48 | def create_register(cls, *args, **kwargs): 49 | try: 50 | return adapt_register( 51 | *args, 52 | **kwargs 53 | ) 54 | except Exception as e: 55 | raise e 56 | 57 | 58 | def construct_resp_from_cache(return_message, return_query): 59 | return { 60 | "modelcache": True, 61 | "hitQuery": return_query, 62 | "data": return_message, 63 | "errorCode": 0 64 | } 65 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter_insert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from modelcache_mm import cache 4 | from modelcache_mm.utils.error import NotInitError 5 | from modelcache_mm.utils.time import time_cal 6 | 7 | 8 | def adapt_insert(*args, **kwargs): 9 | chat_cache = kwargs.pop("cache_obj", cache) 10 | model = kwargs.pop("model", None) 11 | require_object_store = kwargs.pop("require_object_store", False) 12 | if require_object_store: 13 | assert chat_cache.data_manager.o, "Object store is required for adapter." 14 | if not chat_cache.has_init: 15 | raise NotInitError() 16 | cache_enable = chat_cache.cache_enable_func(*args, **kwargs) 17 | context = kwargs.pop("cache_context", {}) 18 | embedding_data = None 19 | pre_embedding_data_dict = chat_cache.insert_pre_embedding_func( 20 | kwargs, 21 | extra_param=context.get("pre_embedding_func", None), 22 | prompts=chat_cache.config.prompts, 23 | ) 24 | 25 | chat_info = kwargs.pop("chat_info", []) 26 | llm_data = chat_info[-1]['answer'] 27 | 28 | pre_embedding_text = '###'.join(pre_embedding_data_dict['text']) 29 | pre_embedding_image_url = pre_embedding_data_dict['imageUrl'] 30 | pre_embedding_image_raw = pre_embedding_data_dict['imageRaw'] 31 | pre_embedding_image_id = pre_embedding_data_dict.get('imageId', None) 32 | 33 | if pre_embedding_image_url and pre_embedding_image_raw: 34 | raise ValueError("Both pre_embedding_image_url and pre_embedding_image_raw cannot be non-empty at the same time.") 35 | 36 | if pre_embedding_image_url: 37 | pre_embedding_image = pre_embedding_image_url 38 | elif pre_embedding_image_raw: 39 | pre_embedding_image = pre_embedding_image_raw 40 | else: 41 | pre_embedding_image = None 42 | if not pre_embedding_text: 43 | raise ValueError( 44 | "Both pre_embedding_image_url and pre_embedding_image_raw are empty. 
Please provide at least one.") 45 | 46 | data_dict = {'text': [pre_embedding_text], 'image': pre_embedding_image} 47 | embedding_data = None 48 | mm_type = None 49 | 50 | if cache_enable: 51 | embedding_data_resp = time_cal( 52 | chat_cache.embedding_func, 53 | func_name="image_embedding", 54 | report_func=chat_cache.report.embedding, 55 | )(data_dict) 56 | 57 | image_embeddings = embedding_data_resp['image_embedding'] 58 | text_embeddings = embedding_data_resp['text_embeddings'] 59 | 60 | if len(image_embeddings) > 0 and len(text_embeddings) > 0: 61 | # image_embedding = np.array(image_embeddings[0]) 62 | # text_embedding = text_embeddings[0] 63 | embedding_data = np.concatenate((image_embeddings, text_embeddings)) 64 | mm_type = 'mm' 65 | elif len(image_embeddings) > 0: 66 | image_embedding = np.array(image_embeddings[0]) 67 | embedding_data = image_embedding 68 | mm_type = 'image' 69 | elif len(text_embeddings) > 0: 70 | text_embedding = np.array(text_embeddings[0]) 71 | embedding_data = text_embedding 72 | mm_type = 'text' 73 | else: 74 | raise ValueError('maya embedding service returned empty lists for both image and text, please check!') 75 | chat_cache.data_manager.save( 76 | pre_embedding_text, 77 | pre_embedding_image_url, 78 | pre_embedding_image_id, 79 | llm_data, 80 | embedding_data, 81 | model=model, 82 | mm_type=mm_type, 83 | extra_param=context.get("mm_save_func", None) 84 | ) 85 | return 'success' 86 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter_register.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm import cache 3 | 4 | 5 | def adapt_register(*args, **kwargs): 6 | chat_cache = kwargs.pop("cache_obj", cache) 7 | model = kwargs.pop("model", None) 8 | type = kwargs.pop("type", None) 9 | if model is None or len(model) == 0: 10 | raise ValueError('model must be a non-empty string') 11 | 12 | register_resp = chat_cache.data_manager.create_index(model, type) 13 | return register_resp 14 | -------------------------------------------------------------------------------- /modelcache_mm/adapter/adapter_remove.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm import cache 3 | from modelcache_mm.utils.error import NotInitError 4 | 5 | 6 | def adapt_remove(*args, **kwargs): 7 | chat_cache = kwargs.pop("cache_obj", cache) 8 | model = kwargs.pop("model", None) 9 | remove_type = kwargs.pop("remove_type", None) 10 | require_object_store = kwargs.pop("require_object_store", False) 11 | if require_object_store: 12 | assert chat_cache.data_manager.o, "Object store is required for adapter."
13 | if not chat_cache.has_init: 14 | raise NotInitError() 15 | 16 | # delete data 17 | if remove_type == 'delete_by_id': 18 | id_list = kwargs.pop("id_list", []) 19 | resp = chat_cache.data_manager.delete(id_list, model=model) 20 | elif remove_type == 'truncate_by_model': 21 | resp = chat_cache.data_manager.truncate(model) 22 | else: 23 | resp = "remove_type_error" 24 | return resp 25 | -------------------------------------------------------------------------------- /modelcache_mm/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Optional, Callable, List 3 | from modelcache.utils.error import CacheError 4 | 5 | 6 | class Config: 7 | 8 | def __init__( 9 | self, 10 | log_time_func: Optional[Callable[[str, float], None]] = None, 11 | similarity_threshold: float = 0.95, 12 | similarity_threshold_long: float = 0.95, 13 | prompts: Optional[List[str]] = None 14 | ): 15 | if similarity_threshold < 0 or similarity_threshold > 1: 16 | raise CacheError( 17 | "Invalid the similarity threshold param, reasonable range: 0-1" 18 | ) 19 | self.log_time_func = log_time_func 20 | self.similarity_threshold = similarity_threshold 21 | self.similarity_threshold_long = similarity_threshold_long 22 | self.prompts = prompts 23 | -------------------------------------------------------------------------------- /modelcache_mm/config/chromadb_config.ini: -------------------------------------------------------------------------------- 1 | [chromadb] 2 | persist_directory=./chromadb 3 | -------------------------------------------------------------------------------- /modelcache_mm/config/elasticsearch_config.ini: -------------------------------------------------------------------------------- 1 | [elasticsearch] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache_mm/config/milvus_config.ini: -------------------------------------------------------------------------------- 1 | [milvus] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' -------------------------------------------------------------------------------- /modelcache_mm/config/mysql_config.ini: -------------------------------------------------------------------------------- 1 | [mysql] 2 | host = '' 3 | port = '' 4 | username = '' 5 | password = '' 6 | database = '' 7 | -------------------------------------------------------------------------------- /modelcache_mm/config/redis_config.ini: -------------------------------------------------------------------------------- 1 | [redis] 2 | host = '' 3 | port = '' 4 | user = '' 5 | password = '' 6 | -------------------------------------------------------------------------------- /modelcache_mm/core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import atexit 3 | from typing import Optional, List, Any 4 | from modelcache_mm.processor.post import first 5 | from modelcache_mm.similarity_evaluation import ExactMatchEvaluation 6 | from modelcache_mm.similarity_evaluation import SimilarityEvaluation 7 | from modelcache_mm.embedding.string import to_embeddings as string_embedding 8 | from modelcache_mm.report import Report 9 | from modelcache_mm.config import Config 10 | from modelcache_mm.utils.cache_func import cache_all 11 | from modelcache_mm.utils.log import modelcache_log 12 | from modelcache_mm.manager import get_data_manager 13 | from 
modelcache_mm.manager.data_manager import DataManager 14 | 15 | 16 | class Cache: 17 | def __init__(self): 18 | self.has_init = False 19 | self.cache_enable_func = None 20 | self.query_pre_embedding_func = None 21 | self.insert_pre_embedding_func = None 22 | self.embedding_func = None 23 | self.data_manager: Optional[DataManager] = None 24 | self.similarity_evaluation: Optional[SimilarityEvaluation] = None 25 | self.post_process_messages_func = None 26 | self.config = Config() 27 | self.report = Report() 28 | self.next_cache = None 29 | 30 | def init( 31 | self, 32 | cache_enable_func=cache_all, 33 | query_pre_embedding_func=None, 34 | insert_pre_embedding_func=None, 35 | embedding_func=string_embedding, 36 | data_manager: DataManager = get_data_manager(), 37 | similarity_evaluation=ExactMatchEvaluation(), 38 | post_process_messages_func=first, 39 | config=Config(), 40 | next_cache=None, 41 | ): 42 | self.has_init = True 43 | self.cache_enable_func = cache_enable_func 44 | self.query_pre_embedding_func = query_pre_embedding_func 45 | self.insert_pre_embedding_func = insert_pre_embedding_func 46 | self.embedding_func = embedding_func 47 | self.data_manager: DataManager = data_manager 48 | self.similarity_evaluation = similarity_evaluation 49 | self.post_process_messages_func = post_process_messages_func 50 | self.config = config 51 | self.next_cache = next_cache 52 | 53 | @atexit.register 54 | def close(): 55 | try: 56 | self.data_manager.close() 57 | except Exception as e: 58 | modelcache_log.error(e) 59 | 60 | def import_data(self, questions: List[Any], answers: List[Any]) -> None: 61 | self.data_manager.import_data( 62 | questions=questions, 63 | answers=answers, 64 | embedding_datas=[self.embedding_func(question) for question in questions], 65 | ) 66 | 67 | def flush(self): 68 | self.data_manager.flush() 69 | if self.next_cache: 70 | self.next_cache.data_manager.flush() 71 | 72 | 73 | cache = Cache() 74 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | clip = LazyImport("clip", globals(), "modelcache_mm.embedding.clip") 4 | 5 | 6 | def Clip2Vec(model="damo/multi-modal_clip-vit-base-patch16_zh"): 7 | return clip.ClipAudio(model) 8 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseEmbedding(metaclass=ABCMeta): 6 | """ 7 | _Embedding base. 
8 | """ 9 | 10 | @abstractmethod 11 | def to_embeddings(self, data, **kwargs): 12 | pass 13 | 14 | @property 15 | @abstractmethod 16 | def dimension(self) -> int: 17 | return 0 18 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/clip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from modelcache.embedding.base import BaseEmbedding 4 | from modelscope.utils.constant import Tasks 5 | from modelscope.pipelines import pipeline 6 | from modelscope.preprocessors.image import load_image 7 | 8 | 9 | class ClipAudio(BaseEmbedding): 10 | def __init__(self, model: str = 'damo/multi-modal_clip-vit-base-patch16_zh'): 11 | self.model = model 12 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | self.clip_pipeline = pipeline(task=Tasks.multi_modal_embedding, 14 | model=model, model_revision='v1.0.1') 15 | self.__dimension = 1024 16 | 17 | def to_embeddings(self, data_dict, **_): 18 | text_list = data_dict['text'] 19 | image_data = data_dict['image'] 20 | 21 | # img_data = None 22 | # txt_data = None 23 | 24 | if image_data: 25 | input_img = load_image(image_data) 26 | img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding'].tolist()[0] if input_img else [] 27 | else: 28 | raise ValueError('image_data is None, please check!') 29 | 30 | if text_list and len(text_list) > 0: 31 | text_embedding = self.clip_pipeline.forward({'text': text_list})['text_embedding'].tolist()[0] if text_list else [] 32 | else: 33 | raise ValueError('text_list is None, please check!') 34 | 35 | return {'image_embedding': img_embedding, 'text_embeddings': text_embedding} 36 | 37 | def post_proc(self, token_embeddings, inputs): 38 | attention_mask = inputs["attention_mask"] 39 | input_mask_expanded = ( 40 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 41 | ) 42 | sentence_embs = torch.sum( 43 | token_embeddings * input_mask_expanded, 1 44 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 45 | return sentence_embs 46 | 47 | @property 48 | def dimension(self): 49 | return self.__dimension 50 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/string.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def to_embeddings(data, **_): 5 | return data 6 | -------------------------------------------------------------------------------- /modelcache_mm/embedding/timm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from modelcache.utils import import_timm, import_torch, import_pillow 5 | from modelcache.embedding.base import BaseEmbedding 6 | 7 | import_torch() 8 | import_timm() 9 | import_pillow() 10 | 11 | import torch 12 | from timm.models import create_model 13 | from timm.data import create_transform, resolve_data_config 14 | from PIL import Image 15 | 16 | 17 | class Timm(BaseEmbedding): 18 | def __init__(self, model: str = "resnet18", device: str = "default"): 19 | if device == "default": 20 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 21 | else: 22 | self.device = device 23 | self.model_name = model 24 | self.model = create_model(model_name=model, pretrained=True) 25 | self.model.eval() 26 | 27 | try: 28 | self.__dimension = self.model.embed_dim 29 | except Exception: 30 | 
self.__dimension = None 31 | 32 | def to_embeddings(self, data, skip_preprocess: bool = False, **_): 33 | if not skip_preprocess: 34 | data = self.preprocess(data) 35 | if data.dim() == 3: 36 | data = data.unsqueeze(0) 37 | feats = self.model.forward_features(data) 38 | emb = self.post_proc(feats).squeeze(0).detach().numpy() 39 | 40 | return np.array(emb).astype("float32") 41 | 42 | def post_proc(self, features): 43 | features = features.to("cpu") 44 | if features.dim() == 3: 45 | features = features[:, 0] 46 | if features.dim() == 4: 47 | global_pool = torch.nn.AdaptiveAvgPool2d(1) 48 | features = global_pool(features) 49 | features = features.flatten(1) 50 | assert features.dim() == 2, f"Invalid output dim {features.dim()}" 51 | return features 52 | 53 | def preprocess(self, image_path): 54 | data_cfg = resolve_data_config(self.model.pretrained_cfg) 55 | transform = create_transform(**data_cfg) 56 | 57 | image = Image.open(image_path).convert("RGB") 58 | image_tensor = transform(image) 59 | return image_tensor 60 | 61 | @property 62 | def dimension(self): 63 | """Embedding dimension. 64 | :return: embedding dimension 65 | """ 66 | if not self.__dimension: 67 | input_size = self.model.pretrained_cfg["input_size"] 68 | dummy_input = torch.rand((1,) + input_size) 69 | feats = self.to_embeddings(dummy_input, skip_preprocess=True) 70 | self.__dimension = feats.shape[0] 71 | return self.__dimension 72 | -------------------------------------------------------------------------------- /modelcache_mm/manager/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.manager.scalar_data import CacheBase 3 | from modelcache_mm.manager.vector_data import VectorBase 4 | from modelcache_mm.manager.object_data import ObjectBase 5 | from modelcache_mm.manager.factory import get_data_manager 6 | -------------------------------------------------------------------------------- /modelcache_mm/manager/eviction/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | 4 | eviction_manager = LazyImport( 5 | "eviction_manager", globals(), "modelcache.manager.eviction.manager" 6 | ) 7 | 8 | 9 | def EvictionBase(name: str, **kwargs): 10 | return eviction_manager.EvictionBase.get(name, **kwargs) 11 | -------------------------------------------------------------------------------- /modelcache_mm/manager/eviction/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class EvictionBase(metaclass=ABCMeta): 7 | """ 8 | Eviction base. 
9 | """ 10 | 11 | @abstractmethod 12 | def put(self, objs: List[Any]): 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: Any): 17 | pass 18 | 19 | @property 20 | @abstractmethod 21 | def policy(self) -> str: 22 | pass 23 | -------------------------------------------------------------------------------- /modelcache_mm/manager/eviction_manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class EvictionManager: 3 | MAX_MARK_COUNT = 5000 4 | MAX_MARK_RATE = 0.1 5 | BATCH_SIZE = 100000 6 | REBUILD_CONDITION = 5 7 | 8 | def __init__(self, scalar_storage, vector_base): 9 | self._scalar_storage = scalar_storage 10 | self._vector_base = vector_base 11 | self.delete_count = 0 12 | 13 | def check_evict(self): 14 | mark_count = self._scalar_storage.count(state=-1) 15 | all_count = self._scalar_storage.count(is_all=True) 16 | if ( 17 | mark_count > self.MAX_MARK_COUNT 18 | or mark_count / all_count > self.MAX_MARK_RATE 19 | ): 20 | return True 21 | return False 22 | 23 | def delete(self): 24 | mark_ids = self._scalar_storage.get_ids(deleted=True) 25 | self._scalar_storage.clear_deleted_data() 26 | self._vector_base.delete(mark_ids) 27 | self.delete_count += 1 28 | if self.delete_count >= self.REBUILD_CONDITION: 29 | self.rebuild() 30 | 31 | def rebuild(self): 32 | self._scalar_storage.clear_deleted_data() 33 | ids = self._scalar_storage.get_ids(deleted=False) 34 | self._vector_base.rebuild(ids) 35 | self.delete_count = 0 36 | 37 | def soft_evict(self, marked_keys): 38 | self._scalar_storage.mark_deleted(marked_keys) 39 | -------------------------------------------------------------------------------- /modelcache_mm/manager/factory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Union, Callable 3 | from modelcache_mm.manager import CacheBase, VectorBase, ObjectBase 4 | from modelcache_mm.manager.data_manager import SSDataManager, MapDataManager 5 | 6 | 7 | def get_data_manager( 8 | cache_base: Union[CacheBase, str] = None, 9 | vector_base: Union[VectorBase, str] = None, 10 | object_base: Union[ObjectBase, str] = None, 11 | max_size: int = 1000, 12 | clean_size: int = None, 13 | eviction: str = "LRU", 14 | data_path: str = "data_map.txt", 15 | get_data_container: Callable = None, 16 | ): 17 | if not cache_base and not vector_base: 18 | return MapDataManager(data_path, max_size, get_data_container) 19 | if isinstance(cache_base, str): 20 | cache_base = CacheBase(name=cache_base) 21 | if isinstance(vector_base, str): 22 | vector_base = VectorBase(name=vector_base) 23 | if isinstance(object_base, str): 24 | object_base = ObjectBase(name=object_base) 25 | assert cache_base and vector_base 26 | return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction) 27 | -------------------------------------------------------------------------------- /modelcache_mm/manager/object_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.utils.lazy_import import LazyImport 3 | object_manager = LazyImport( 4 | "object_manager", globals(), "modelcache.manager.object_data.manager" 5 | ) 6 | 7 | 8 | def ObjectBase(name: str, **kwargs): 9 | return object_manager.ObjectBase.get(name, **kwargs) 10 | -------------------------------------------------------------------------------- /modelcache_mm/manager/object_data/base.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | from typing import Any, List 4 | 5 | 6 | class ObjectBase(ABC): 7 | """ 8 | Object storage base. 9 | """ 10 | 11 | @abstractmethod 12 | def put(self, obj: Any) -> str: 13 | pass 14 | 15 | @abstractmethod 16 | def get(self, obj: str) -> Any: 17 | pass 18 | 19 | @abstractmethod 20 | def get_access_link(self, obj: str) -> str: 21 | pass 22 | 23 | @abstractmethod 24 | def delete(self, to_delete: List[str]): 25 | pass 26 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils.lazy_import import LazyImport 3 | scalar_manager = LazyImport( 4 | "scalar_manager", globals(), "modelcache_mm.manager.scalar_data.manager" 5 | ) 6 | 7 | 8 | def CacheBase(name: str, **kwargs): 9 | return scalar_manager.CacheBase.get(name, **kwargs) 10 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Union, Dict, List, Optional, Any 5 | from enum import IntEnum 6 | import numpy as np 7 | 8 | 9 | class DataType(IntEnum): 10 | STR = 0 11 | IMAGE_BASE64 = 1 12 | IMAGE_URL = 2 13 | 14 | 15 | @dataclass 16 | class QuestionDep: 17 | """ 18 | QuestionDep 19 | """ 20 | 21 | name: str 22 | data: str 23 | dep_type: int = DataType.STR 24 | 25 | @classmethod 26 | def from_dict(cls, d: Dict): 27 | return cls( 28 | name=d["name"], 29 | data=d["data"], 30 | dep_type=d["dep_type"] 31 | ) 32 | 33 | 34 | @dataclass 35 | class Question: 36 | """ 37 | Question 38 | """ 39 | 40 | content: str 41 | deps: Optional[List[QuestionDep]] = None 42 | 43 | @classmethod 44 | def from_dict(cls, d: Dict): 45 | deps = [] 46 | for dep in d["deps"]: 47 | deps.append(QuestionDep.from_dict(dep)) 48 | return cls(d["content"], deps) 49 | 50 | 51 | @dataclass 52 | class Answer: 53 | """ 54 | data_type: 55 | 0: str 56 | 1: base64 image 57 | """ 58 | 59 | answer: Any 60 | answer_type: int = DataType.STR 61 | 62 | 63 | @dataclass 64 | class CacheData: 65 | """ 66 | CacheData 67 | """ 68 | 69 | question: Union[str, Question] 70 | answers: List[Answer] 71 | embedding_data: Optional[np.ndarray] = None 72 | 73 | def __init__(self, question, answers, embedding_data=None): 74 | self.question = question 75 | self.answers = [] 76 | if isinstance(answers, (str, Answer)): 77 | answers = [answers] 78 | for data in answers: 79 | if isinstance(data, (list, tuple)): 80 | self.answers.append(Answer(*data)) 81 | elif isinstance(data, Answer): 82 | self.answers.append(data) 83 | else: 84 | self.answers.append(Answer(answer=data)) 85 | self.embedding_data = embedding_data 86 | 87 | 88 | class CacheStorage(metaclass=ABCMeta): 89 | """ 90 | BaseStorage for scalar data. 
91 | """ 92 | 93 | @abstractmethod 94 | def create(self): 95 | pass 96 | 97 | @abstractmethod 98 | def batch_insert(self, all_data: List[CacheData]): 99 | pass 100 | 101 | @abstractmethod 102 | def insert_query_resp(self, query_resp, **kwargs): 103 | pass 104 | 105 | @abstractmethod 106 | def get_data_by_id(self, key): 107 | pass 108 | 109 | @abstractmethod 110 | def mark_deleted(self, keys): 111 | pass 112 | 113 | @abstractmethod 114 | def model_deleted(self, model_name): 115 | pass 116 | 117 | @abstractmethod 118 | def clear_deleted_data(self): 119 | pass 120 | 121 | @abstractmethod 122 | def get_ids(self, deleted=True): 123 | pass 124 | 125 | @abstractmethod 126 | def count(self): 127 | pass 128 | 129 | def flush(self): 130 | pass 131 | 132 | @abstractmethod 133 | def close(self): 134 | pass 135 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils import import_sql_client 3 | from modelcache_mm.utils.error import NotFoundError 4 | 5 | SQL_URL = {"sqlite": "./sqlite.db"} 6 | 7 | 8 | class CacheBase: 9 | """ 10 | CacheBase to manager the cache storage. 11 | """ 12 | 13 | def __init__(self): 14 | raise EnvironmentError( 15 | "CacheBase is designed to be instantiated, please using the `CacheBase.get(name)`." 16 | ) 17 | 18 | @staticmethod 19 | def get(name, **kwargs): 20 | 21 | if name in ["mysql", "oceanbase"]: 22 | from modelcache_mm.manager.scalar_data.sql_storage import SQLStorage 23 | config = kwargs.get("config") 24 | import_sql_client(name) 25 | cache_base = SQLStorage(db_type=name, config=config) 26 | elif name == 'sqlite': 27 | from modelcache_mm.manager.scalar_data.sql_storage_sqlite import SQLStorage 28 | sql_url = kwargs.get("sql_url", SQL_URL[name]) 29 | cache_base = SQLStorage(db_type=name, url=sql_url) 30 | else: 31 | raise NotFoundError("cache store", name) 32 | return cache_base 33 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/sql_storage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | 5 | import pymysql 6 | import json 7 | import base64 8 | from typing import List 9 | from modelcache_mm.manager.scalar_data.base import CacheStorage, CacheData 10 | from DBUtils.PooledDB import PooledDB 11 | 12 | 13 | class SQLStorage(CacheStorage): 14 | def __init__( 15 | self, 16 | db_type: str = "mysql", 17 | config=None 18 | ): 19 | self.host = config.get('mysql', 'host') 20 | self.port = int(config.get('mysql', 'port')) 21 | self.username = config.get('mysql', 'username') 22 | self.password = config.get('mysql', 'password') 23 | self.database = config.get('mysql', 'database') 24 | self.pool = PooledDB( 25 | creator=pymysql, 26 | host=self.host, 27 | user=self.username, 28 | password=self.password, 29 | port=self.port, 30 | database=self.database 31 | ) 32 | 33 | def create(self): 34 | pass 35 | 36 | # def _insert(self, data: List): 37 | # answer = data[0] 38 | # text = data[1] 39 | # image_url = data[2] 40 | # image_id = data[3] 41 | # model = data[4] 42 | # answer_type = 0 43 | # 44 | # table_name = "multimodal_answer" 45 | # insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (%s, %s, %s, %s, %s, %s)".format(table_name) 46 | # conn = self.pool.connection() 47 
| # try: 48 | # with conn.cursor() as cursor: 49 | # # data insert operation 50 | # values = (text, image_url, image_id, answer, answer_type, model) 51 | # cursor.execute(insert_sql, values) 52 | # conn.commit() 53 | # id = cursor.lastrowid 54 | # finally: 55 | # # Close the connection and return it back to the connection pool 56 | # conn.close() 57 | # return id 58 | 59 | def _insert(self, data: List): 60 | answer = data[0] 61 | text = data[1] 62 | image_url = data[2] 63 | image_id = data[3] 64 | model = data[4] 65 | answer_type = 0 66 | 67 | table_name = "open_cache_mm_answer" 68 | insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (%s, %s, %s, %s, %s, %s)".format(table_name) 69 | 70 | conn = self.pool.connection() 71 | try: 72 | with conn.cursor() as cursor: 73 | # insert data operation 74 | values = (text, image_url, image_id, answer, answer_type, model) 75 | cursor.execute(insert_sql, values) 76 | conn.commit() 77 | id = cursor.lastrowid 78 | finally: 79 | # Close the connection and return it to the connection pool. 80 | conn.close() 81 | return id 82 | 83 | def batch_insert(self, all_data: List[CacheData]): 84 | ids = [] 85 | for data in all_data: 86 | ids.append(self._insert(data)) 87 | return ids 88 | 89 | def insert_query_resp(self, query_resp, **kwargs): 90 | error_code = query_resp.get('errorCode') 91 | error_desc = query_resp.get('errorDesc') 92 | cache_hit = query_resp.get('cacheHit') 93 | model = kwargs.get('model') 94 | query = kwargs.get('query') 95 | delta_time = kwargs.get('delta_time') 96 | hit_query = query_resp.get('hit_query') 97 | answer = query_resp.get('answer') 98 | 99 | if isinstance(hit_query, list): 100 | hit_query = json.dumps(hit_query, ensure_ascii=False) 101 | 102 | table_name = "open_cache_mm_query_log" 103 | insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name) 104 | conn = self.pool.connection() 105 | try: 106 | with conn.cursor() as cursor: 107 | # 执行插入数据操作 108 | values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) 109 | cursor.execute(insert_sql, values) 110 | conn.commit() 111 | finally: 112 | # 关闭连接,将连接返回给连接池 113 | conn.close() 114 | 115 | # def get_data_by_id(self, key: int): 116 | # table_name = "cache_codegpt_answer" 117 | # query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key) 118 | # conn_start = time.time() 119 | # conn = self.pool.connection() 120 | # 121 | # search_start = time.time() 122 | # try: 123 | # with conn.cursor() as cursor: 124 | # # 执行数据库操作 125 | # cursor.execute(query_sql) 126 | # resp = cursor.fetchone() 127 | # finally: 128 | # # 关闭连接,将连接返回给连接池 129 | # conn.close() 130 | # 131 | # if resp is not None and len(resp) == 4: 132 | # return resp 133 | # else: 134 | # return None 135 | 136 | def get_data_by_id(self, key: int): 137 | table_name = "open_cache_mm_answer" 138 | query_sql = "select question_text, image_url, image_id, answer, model from {} where id={}".format(table_name, key) 139 | conn = self.pool.connection() 140 | try: 141 | with conn.cursor() as cursor: 142 | cursor.execute(query_sql) 143 | resp = cursor.fetchone() 144 | finally: 145 | conn.close() 146 | 147 | if resp is not None and len(resp) == 5: 148 | return resp 149 | else: 150 | return None 151 | 152 | def update_hit_count_by_id(self, primary_id: int): 153 | table_name = "open_cache_mm_answer" 154 | update_sql = 
"UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id) 155 | conn = self.pool.connection() 156 | 157 | try: 158 | with conn.cursor() as cursor: 159 | cursor.execute(update_sql) 160 | conn.commit() 161 | finally: 162 | conn.close() 163 | 164 | def get_ids(self, deleted=True): 165 | pass 166 | 167 | def mark_deleted(self, keys): 168 | table_name = "open_cache_mm_answer" 169 | delete_sql = "Delete from {} WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys])) 170 | 171 | conn = self.pool.connection() 172 | try: 173 | with conn.cursor() as cursor: 174 | cursor.execute(delete_sql) 175 | delete_count = cursor.rowcount 176 | conn.commit() 177 | finally: 178 | conn.close() 179 | return delete_count 180 | 181 | def model_deleted(self, model_name): 182 | table_name = "open_cache_mm_answer" 183 | delete_sql = "Delete from {} WHERE model='{}'".format(table_name, model_name) 184 | conn = self.pool.connection() 185 | try: 186 | with conn.cursor() as cursor: 187 | resp = cursor.execute(delete_sql) 188 | conn.commit() 189 | finally: 190 | conn.close() 191 | return resp 192 | 193 | def clear_deleted_data(self): 194 | pass 195 | 196 | def count(self, state: int = 0, is_all: bool = False): 197 | pass 198 | 199 | def close(self): 200 | pass 201 | 202 | def count_answers(self): 203 | pass 204 | -------------------------------------------------------------------------------- /modelcache_mm/manager/scalar_data/sql_storage_sqlite.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | from typing import List 4 | from modelcache.manager.scalar_data.base import CacheStorage, CacheData 5 | import sqlite3 6 | 7 | 8 | class SQLStorage(CacheStorage): 9 | def __init__( 10 | self, 11 | db_type: str = "mysql", 12 | config=None, 13 | url="./sqlite.db" 14 | ): 15 | self._url = url 16 | self.create() 17 | 18 | def create(self): 19 | # answer_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_llm_answer ( 20 | # id INTEGER PRIMARY KEY AUTOINCREMENT, 21 | # gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 22 | # gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 23 | # question TEXT NOT NULL, 24 | # answer TEXT NOT NULL, 25 | # answer_type INTEGER NOT NULL, 26 | # hit_count INTEGER NOT NULL DEFAULT 0, 27 | # model VARCHAR(1000) NOT NULL, 28 | # embedding_data BLOB NOT NULL 29 | # ); 30 | # """ 31 | 32 | answer_table_sql = """CREATE TABLE IF NOT EXISTS `open_cache_mm_answer` ( 33 | `id` INTEGER PRIMARY KEY AUTOINCREMENT, 34 | `gmt_create` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 35 | `gmt_modified` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 36 | `question_text` TEXT NOT NULL, 37 | `image_url` VARCHAR(2048) NOT NULL, 38 | `answer` TEXT NOT NULL, 39 | `answer_type` INTEGER NOT NULL, 40 | `hit_count` INTEGER NOT NULL DEFAULT 0, 41 | `model` VARCHAR(1000) NOT NULL, 42 | `image_raw` BLOB DEFAULT NULL, 43 | `image_id` VARCHAR(1000) DEFAULT NULL 44 | ); 45 | """ 46 | 47 | log_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_query_log ( 48 | id INTEGER PRIMARY KEY AUTOINCREMENT, 49 | gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 50 | gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, 51 | error_code INTEGER NOT NULL, 52 | error_desc VARCHAR(1000) NOT NULL, 53 | cache_hit VARCHAR(100) NOT NULL, 54 | delta_time REAL NOT NULL, 55 | model VARCHAR(1000) NOT NULL, 56 | query TEXT NOT NULL, 57 | hit_query TEXT NOT NULL, 58 | answer TEXT NOT NULL 59 | ); 60 | """ 61 | 62 | 
conn = sqlite3.connect(self._url) 63 | try: 64 | cursor = conn.cursor() 65 | cursor.execute(answer_table_sql) 66 | cursor.execute(log_table_sql) 67 | conn.commit() 68 | cursor.close() 69 | conn.close() 70 | finally: 71 | conn.close() 72 | 73 | def _insert(self, data: List): 74 | answer = data[0] 75 | text = data[1] 76 | image_url = data[2] 77 | image_id = data[3] 78 | model = data[4] 79 | answer_type = 0 80 | 81 | table_name = "open_cache_mm_answer" 82 | insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (?, ?, ?, ?, ?, ?)".format(table_name) 83 | 84 | conn = sqlite3.connect(self._url) 85 | try: 86 | cursor = conn.cursor() 87 | values = (text, image_url, image_id, answer, answer_type, model) 88 | cursor.execute(insert_sql, values) 89 | conn.commit() 90 | id = cursor.lastrowid 91 | cursor.close() 92 | conn.close() 93 | finally: 94 | conn.close() 95 | return id 96 | 97 | def batch_insert(self, all_data: List[CacheData]): 98 | ids = [] 99 | for data in all_data: 100 | ids.append(self._insert(data)) 101 | return ids 102 | 103 | def insert_query_resp(self, query_resp, **kwargs): 104 | error_code = query_resp.get('errorCode') 105 | error_desc = query_resp.get('errorDesc') 106 | cache_hit = query_resp.get('cacheHit') 107 | model = kwargs.get('model') 108 | query = kwargs.get('query') 109 | delta_time = kwargs.get('delta_time') 110 | hit_query = query_resp.get('hit_query') 111 | answer = query_resp.get('answer') 112 | 113 | if isinstance(hit_query, list): 114 | hit_query = json.dumps(hit_query, ensure_ascii=False) 115 | 116 | table_name = "modelcache_query_log" 117 | insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name) 118 | 119 | conn = sqlite3.connect(self._url) 120 | try: 121 | cursor = conn.cursor() 122 | values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) 123 | cursor.execute(insert_sql, values) 124 | conn.commit() 125 | cursor.close() 126 | conn.close() 127 | finally: 128 | conn.close() 129 | 130 | def get_data_by_id(self, key: int): 131 | table_name = "open_cache_mm_answer" 132 | query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key) 133 | conn = sqlite3.connect(self._url) 134 | try: 135 | cursor = conn.cursor() 136 | cursor.execute(query_sql) 137 | resp = cursor.fetchone() 138 | conn.commit() 139 | cursor.close() 140 | conn.close() 141 | finally: 142 | conn.close() 143 | 144 | if resp is not None and len(resp) == 4: 145 | return resp 146 | else: 147 | return None 148 | 149 | def update_hit_count_by_id(self, primary_id: int): 150 | table_name = "open_cache_mm_answer" 151 | update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id) 152 | 153 | conn = sqlite3.connect(self._url) 154 | try: 155 | cursor = conn.cursor() 156 | cursor.execute(update_sql) 157 | conn.commit() 158 | cursor.close() 159 | conn.close() 160 | finally: 161 | # 关闭连接,将连接返回给连接池 162 | conn.close() 163 | 164 | def get_ids(self, deleted=True): 165 | pass 166 | 167 | def mark_deleted(self, keys): 168 | table_name = "open_cache_mm_answer" 169 | delete_sql = "Delete from {} WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys])) 170 | conn = sqlite3.connect(self._url) 171 | try: 172 | cursor = conn.cursor() 173 | cursor.execute(delete_sql) 174 | delete_count = cursor.rowcount 175 | conn.commit() 176 | cursor.close() 177 | 
conn.close() 178 | finally: 179 | conn.close() 180 | return delete_count 181 | 182 | def model_deleted(self, model_name): 183 | table_name = "open_cache_mm_answer" 184 | delete_sql = "Delete from {} WHERE model='{}'".format(table_name, model_name) 185 | conn = sqlite3.connect(self._url) 186 | try: 187 | cursor = conn.cursor() 188 | resp = cursor.execute(delete_sql) 189 | conn.commit() 190 | cursor.close() 191 | conn.close() 192 | finally: 193 | conn.close() 194 | return resp 195 | 196 | def clear_deleted_data(self): 197 | pass 198 | 199 | def count(self, state: int = 0, is_all: bool = False): 200 | pass 201 | 202 | def close(self): 203 | pass 204 | 205 | def count_answers(self): 206 | pass 207 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils.lazy_import import LazyImport 3 | 4 | vector_manager = LazyImport( 5 | "vector_manager", globals(), "modelcache_mm.manager.vector_data.manager" 6 | ) 7 | 8 | 9 | def VectorBase(name: str, **kwargs): 10 | return vector_manager.VectorBase.get(name, **kwargs) 11 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | from typing import List 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass 9 | class VectorData: 10 | id: int 11 | data: np.ndarray 12 | 13 | 14 | class VectorBase(ABC): 15 | """VectorBase: base vector store interface""" 16 | 17 | @abstractmethod 18 | def add(self, datas: List[VectorData], model=None, mm_type=None): 19 | pass 20 | 21 | # @abstractmethod 22 | # def search(self, data: np.ndarray, top_k: int, model): 23 | # pass 24 | 25 | @abstractmethod 26 | def search(self, data: np.ndarray, top_k: int, model, mm_type): 27 | pass 28 | 29 | @abstractmethod 30 | def create(self, model=None, mm_type=None): 31 | pass 32 | 33 | @abstractmethod 34 | def rebuild(self, ids=None) -> bool: 35 | pass 36 | 37 | @abstractmethod 38 | def delete(self, ids) -> bool: 39 | pass 40 | 41 | @abstractmethod 42 | def rebuild_idx(self, model): 43 | pass 44 | 45 | def flush(self): 46 | pass 47 | 48 | def close(self): 49 | pass 50 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/chroma.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import logging 5 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData 6 | from modelcache_mm.utils import import_chromadb, import_torch 7 | from modelcache_mm.utils.index_util import get_mm_index_name 8 | 9 | import_torch() 10 | import_chromadb() 11 | 12 | import chromadb 13 | 14 | 15 | class Chromadb(VectorBase): 16 | 17 | def __init__( 18 | self, 19 | persist_directory="./chromadb", 20 | top_k: int = 1, 21 | ): 22 | self.top_k = top_k 23 | 24 | self._client = chromadb.PersistentClient(path=persist_directory) 25 | self._collection = None 26 | 27 | def create(self, model=None, mm_type=None): 28 | try: 29 | collection_name_model = get_mm_index_name(model, mm_type) 30 | # collection_name_model = self.collection_name + '_' + model 31 | 
self._client.get_or_create_collection(name=collection_name_model) 32 | except Exception as e: 33 | raise ValueError(str(e)) 34 | 35 | def add(self, datas: List[VectorData], model=None, mm_type=None): 36 | collection_name_model = get_mm_index_name(model, mm_type) 37 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 38 | 39 | data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) 40 | self._collection.add(embeddings=data_array, ids=id_array) 41 | 42 | def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type='mm'): 43 | collection_name_model = get_mm_index_name(model, mm_type) 44 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 45 | 46 | if self._collection.count() == 0: 47 | return [] 48 | if top_k == -1: 49 | top_k = self.top_k 50 | results = self._collection.query( 51 | query_embeddings=[data.tolist()], 52 | n_results=top_k, 53 | include=["distances"], 54 | ) 55 | return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]])) 56 | 57 | def delete(self, ids, model=None, mm_type=None): 58 | try: 59 | collection_name_model = get_mm_index_name(model, mm_type) 60 | self._collection = self._client.get_or_create_collection(name=collection_name_model) 61 | # 查询集合中实际存在的 ID 62 | ids_str = [str(x) for x in ids] 63 | existing_ids = set(self._collection.get(ids=ids_str).ids) 64 | 65 | # 删除存在的 ID 66 | if existing_ids: 67 | self._collection.delete(list(existing_ids)) 68 | 69 | # 返回实际删除的条目数量 70 | return len(existing_ids) 71 | 72 | except Exception as e: 73 | logging.error('Error during deletion: {}'.format(e)) 74 | raise ValueError(str(e)) 75 | 76 | def rebuild_idx(self, model, mm_type=None): 77 | collection_name_model = get_mm_index_name(model, mm_type) 78 | 79 | # 检查集合是否存在,如果存在则删除 80 | collections = self._client.list_collections() 81 | if any(col.name == collection_name_model for col in collections): 82 | self._client.delete_collection(collection_name_model) 83 | else: 84 | return 'model collection not found, please check!' 
85 | 86 | try: 87 | self._client.create_collection(collection_name_model) 88 | except Exception as e: 89 | logging.info(f'rebuild_collection: {e}') 90 | raise ValueError(str(e)) 91 | 92 | def rebuild(self, ids=None): 93 | pass 94 | 95 | def flush(self): 96 | pass 97 | 98 | def close(self): 99 | pass 100 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/faiss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from typing import List 4 | import numpy as np 5 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData 6 | from modelcache_mm.utils import import_faiss 7 | import_faiss() 8 | import faiss # pylint: disable=C0413 9 | 10 | 11 | class Faiss(VectorBase): 12 | def __init__(self, 13 | index_file_path, 14 | dimension: int = 0, 15 | top_k: int = 1 16 | ): 17 | self._dimension = dimension 18 | self._index_file_path = index_file_path 19 | self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2) 20 | self._top_k = top_k 21 | if os.path.isfile(index_file_path): 22 | self._index = faiss.read_index(index_file_path) 23 | 24 | def add(self, datas: List[VectorData], model=None, mm_type=None): 25 | data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas))) 26 | np_data = np.array(data_array).astype("float32") 27 | ids = np.array(id_array) 28 | self._index.add_with_ids(np_data, ids) 29 | 30 | def search(self, data: np.ndarray, top_k: int, model, mm_type='mm'): 31 | if self._index.ntotal == 0: 32 | return None 33 | if top_k == -1: 34 | top_k = self._top_k 35 | np_data = np.array(data).astype("float32").reshape(1, -1) 36 | dist, ids = self._index.search(np_data, top_k) 37 | ids = [int(i) for i in ids[0]] 38 | return list(zip(dist[0], ids)) 39 | 40 | def rebuild_col(self, ids=None): 41 | try: 42 | self._index.reset() 43 | except Exception as e: 44 | return f"An error occurred during index rebuild: {e}" 45 | 46 | def rebuild(self, ids=None): 47 | return True 48 | 49 | def delete(self, ids): 50 | ids_to_remove = np.array(ids) 51 | self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove))) 52 | 53 | def create(self, model=None, mm_type=None): 54 | pass 55 | # collection_name_model = get_mm_index_name(model, mm_type) 56 | # try: 57 | # index_prefix = get_mm_index_prefix(model, mm_type) 58 | # self.create_index(collection_name_model, mm_type, index_prefix) 59 | # except Exception as e: 60 | # raise ValueError(str(e)) 61 | # return 'success' 62 | 63 | def flush(self): 64 | faiss.write_index(self._index, self._index_file_path) 65 | 66 | def close(self): 67 | self.flush() 68 | 69 | def rebuild_idx(self, model): 70 | pass 71 | 72 | def count(self): 73 | return self._index.ntotal 74 | -------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/manager.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache_mm.utils.error import NotFoundError, ParamError 3 | 4 | TOP_K = 1 5 | FAISS_INDEX_PATH = "mm_faiss.index" 6 | DIMENSION = 0 7 | MILVUS_HOST = "localhost" 8 | MILVUS_PORT = 19530 9 | MILVUS_USER = "" 10 | MILVUS_PSW = "" 11 | MILVUS_SECURE = False 12 | MILVUS_INDEX_PARAMS = { 13 | "metric_type": "L2", 14 | "index_type": "HNSW", 15 | "params": {"M": 8, "efConstruction": 64}, 16 | } 17 | 18 | COLLECTION_NAME = "modelcache" 19 | 20 | 21 | 
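# Note (illustrative, not from the original source): the constants above are only fall-back
# defaults. VectorBase.get() below reads most of them from kwargs, so a caller could, for
# example, replace the default HNSW index with an IVF_FLAT one by passing
#     index_params={"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 128}}
# when requesting a Milvus store; the concrete values shown here are assumptions for illustration.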
class VectorBase: 22 | """ 23 | VectorBase to manager the vector base. 24 | """ 25 | 26 | def __init__(self): 27 | raise EnvironmentError( 28 | "VectorBase is designed to be instantiated, please using the `VectorBase.get(name)`." 29 | ) 30 | 31 | @staticmethod 32 | def check_dimension(dimension): 33 | if dimension <= 0: 34 | raise ParamError( 35 | f"the dimension should be greater than zero, current value: {dimension}." 36 | ) 37 | 38 | @staticmethod 39 | def get(name, **kwargs): 40 | top_k = kwargs.get("top_k", TOP_K) 41 | if name == "milvus": 42 | from modelcache.manager.vector_data.milvus import Milvus 43 | milvus_config = kwargs.get("milvus_config") 44 | dimension = kwargs.get("dimension", DIMENSION) 45 | VectorBase.check_dimension(dimension) 46 | host = milvus_config.get('milvus', 'host') 47 | port = milvus_config.get('milvus', 'port') 48 | user = milvus_config.get('milvus', 'user') 49 | password = milvus_config.get('milvus', 'password') 50 | 51 | secure = kwargs.get("secure", MILVUS_SECURE) 52 | collection_name = kwargs.get("collection_name", COLLECTION_NAME) 53 | index_params = kwargs.get("index_params", MILVUS_INDEX_PARAMS) 54 | search_params = kwargs.get("search_params", None) 55 | local_mode = kwargs.get("local_mode", False) 56 | local_data = kwargs.get("local_data", "./milvus_data") 57 | vector_base = Milvus( 58 | host=host, 59 | port=port, 60 | user=user, 61 | password=password, 62 | secure=secure, 63 | collection_name=collection_name, 64 | dimension=dimension, 65 | top_k=top_k, 66 | index_params=index_params, 67 | search_params=search_params, 68 | local_mode=local_mode, 69 | local_data=local_data 70 | ) 71 | elif name == "redis": 72 | from modelcache_mm.manager.vector_data.redis import RedisVectorStore 73 | redis_config = kwargs.get("redis_config") 74 | 75 | mm_dimension = kwargs.get("mm_dimension", DIMENSION) 76 | i_dimension = kwargs.get("i_dimension", DIMENSION) 77 | t_dimension = kwargs.get("t_dimension", DIMENSION) 78 | VectorBase.check_dimension(mm_dimension) 79 | VectorBase.check_dimension(i_dimension) 80 | VectorBase.check_dimension(t_dimension) 81 | 82 | host = redis_config.get('redis', 'host') 83 | port = redis_config.get('redis', 'port') 84 | user = redis_config.get('redis', 'user') 85 | password = redis_config.get('redis', 'password') 86 | namespace = kwargs.get("namespace", "") 87 | # collection_name = kwargs.get("collection_name", COLLECTION_NAME) 88 | 89 | vector_base = RedisVectorStore( 90 | host=host, 91 | port=port, 92 | username=user, 93 | password=password, 94 | namespace=namespace, 95 | top_k=top_k, 96 | mm_dimension=mm_dimension, 97 | i_dimension=i_dimension, 98 | t_dimension=t_dimension, 99 | ) 100 | elif name == "faiss": 101 | from modelcache_mm.manager.vector_data.faiss import Faiss 102 | dimension = kwargs.get("dimension", DIMENSION) 103 | VectorBase.check_dimension(dimension) 104 | 105 | index_path = kwargs.pop("index_path", FAISS_INDEX_PATH) 106 | vector_base = Faiss( 107 | index_file_path=index_path, 108 | dimension=dimension, 109 | top_k=top_k 110 | ) 111 | elif name == "chromadb": 112 | from modelcache_mm.manager.vector_data.chroma import Chromadb 113 | 114 | chromadb_config = kwargs.get("chromadb_config", None) 115 | persist_directory = chromadb_config.get('chromadb', 'persist_directory') 116 | vector_base = Chromadb( 117 | persist_directory=persist_directory, 118 | top_k=top_k, 119 | ) 120 | else: 121 | raise NotFoundError("vector store", name) 122 | return vector_base 123 | 
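# Usage sketch (illustrative, not part of this module): driving the factory above with the
# FAISS backend, which needs no external service. The 768-dimension embedding and the demo
# index path are assumptions for the example and must match the embedding model actually used.
#
#   import numpy as np
#   from modelcache_mm.manager.vector_data.manager import VectorBase
#   from modelcache_mm.manager.vector_data.base import VectorData
#
#   store = VectorBase.get("faiss", dimension=768, index_path="mm_faiss_demo.index", top_k=1)
#   store.add([VectorData(id=1, data=np.random.rand(768))], model="demo_model", mm_type="mm")
#   hits = store.search(np.random.rand(768), top_k=1, model="demo_model", mm_type="mm")
#   # hits is a list of (distance, id) tuples; store.flush() persists the index to disk.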
-------------------------------------------------------------------------------- /modelcache_mm/manager/vector_data/redis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | import numpy as np 4 | from redis.commands.search.indexDefinition import IndexDefinition, IndexType 5 | from redis.commands.search.query import Query 6 | from redis.commands.search.field import VectorField, NumericField 7 | from redis.client import Redis 8 | 9 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData 10 | from modelcache_mm.utils import import_redis 11 | from modelcache_mm.utils.log import modelcache_log 12 | from modelcache_mm.utils.index_util import get_mm_index_name 13 | from modelcache_mm.utils.index_util import get_mm_index_prefix 14 | import_redis() 15 | 16 | 17 | class RedisVectorStore(VectorBase): 18 | def __init__( 19 | self, 20 | host: str = "localhost", 21 | port: str = "6379", 22 | username: str = "", 23 | password: str = "", 24 | mm_dimension: int = 0, 25 | i_dimension: int = 0, 26 | t_dimension: int = 0, 27 | top_k: int = 1, 28 | namespace: str = "", 29 | ): 30 | if mm_dimension <= 0: 31 | raise ValueError( 32 | f"invalid `dim` param: {mm_dimension} in the Milvus vector store." 33 | ) 34 | self._client = Redis( 35 | host=host, port=int(port), username=username, password=password 36 | ) 37 | self.top_k = top_k 38 | self.mm_dimension = mm_dimension 39 | self.i_dimension = i_dimension 40 | self.t_dimension = t_dimension 41 | self.namespace = namespace 42 | self.doc_prefix = f"{self.namespace}doc:" 43 | 44 | def _check_index_exists(self, index_name: str) -> bool: 45 | """Check if Redis index exists.""" 46 | try: 47 | self._client.ft(index_name).info() 48 | except: 49 | modelcache_log.info("Index does not exist") 50 | return False 51 | modelcache_log.info("Index already exists") 52 | return True 53 | 54 | def create_index(self, index_name, type, index_prefix): 55 | # dimension = self.dimension 56 | if type == 'IMG_TEXT': 57 | dimension = self.mm_dimension 58 | elif type == 'IMG': 59 | dimension = self.i_dimension 60 | elif type == 'TEXT': 61 | dimension = self.t_dimension 62 | else: 63 | raise ValueError('dimension type exception') 64 | if self._check_index_exists(index_name): 65 | modelcache_log.info( 66 | "The %s already exists, and it will be used directly", index_name 67 | ) 68 | return 'already_exists' 69 | else: 70 | id_field_name = "data_id" 71 | embedding_field_name = "data_vector" 72 | 73 | id = NumericField(name=id_field_name) 74 | embedding = VectorField(embedding_field_name, 75 | "HNSW", { 76 | "TYPE": "FLOAT32", 77 | "DIM": dimension, 78 | "DISTANCE_METRIC": "L2", 79 | "INITIAL_CAP": 1000, 80 | } 81 | ) 82 | fields = [id, embedding] 83 | definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH) 84 | 85 | # create Index 86 | self._client.ft(index_name).create_index( 87 | fields=fields, definition=definition 88 | ) 89 | return 'create_success' 90 | 91 | def add(self, datas: List[VectorData], model=None, mm_type=None): 92 | for data in datas: 93 | id: int = data.id 94 | embedding = data.data.astype(np.float32).tobytes() 95 | index_prefix = get_mm_index_prefix(model, mm_type) 96 | id_field_name = "data_id" 97 | embedding_field_name = "data_vector" 98 | obj = {id_field_name: id, embedding_field_name: embedding} 99 | self._client.hset(f"{index_prefix}{id}", mapping=obj) 100 | 101 | def search(self, data: np.ndarray, top_k: int = -1, model=None, 
mm_type=None): 102 | index_name = get_mm_index_name(model, mm_type) 103 | id_field_name = "data_id" 104 | embedding_field_name = "data_vector" 105 | 106 | base_query = f'*=>[KNN 2 @{embedding_field_name} $vector AS distance]' 107 | query = ( 108 | Query(base_query) 109 | .sort_by("distance") 110 | .return_fields(id_field_name, "distance") 111 | .dialect(2) 112 | ) 113 | query_params = {"vector": data.astype(np.float32).tobytes()} 114 | results = ( 115 | self._client.ft(index_name) 116 | .search(query, query_params=query_params) 117 | .docs 118 | ) 119 | return [(float(result.distance), int(getattr(result, id_field_name))) for result in results] 120 | 121 | def create(self, model=None, mm_type=None): 122 | collection_name_model = get_mm_index_name(model, mm_type) 123 | try: 124 | index_prefix = get_mm_index_prefix(model, mm_type) 125 | self.create_index(collection_name_model, mm_type, index_prefix) 126 | except Exception as e: 127 | raise ValueError(str(e)) 128 | return 'success' 129 | 130 | def rebuild(self, ids=None) -> bool: 131 | pass 132 | 133 | def rebuild_idx(self, model, mm_type=None): 134 | for mm_type in ['IMG_TEXT', 'TEXT']: 135 | index_name = get_mm_index_name(model, mm_type) 136 | if self._check_index_exists(index_name): 137 | try: 138 | self._client.ft(index_name).dropindex(delete_documents=True) 139 | except Exception as e: 140 | raise ValueError(str(e)) 141 | try: 142 | index_prefix = get_mm_index_prefix(model, mm_type) 143 | self.create_index(index_name, mm_type, index_prefix) 144 | except Exception as e: 145 | raise ValueError(str(e)) 146 | 147 | def delete(self, ids) -> None: 148 | pipe = self._client.pipeline() 149 | for data_id in ids: 150 | pipe.delete(f"{self.doc_prefix}{data_id}") 151 | pipe.execute() 152 | 153 | def create(self, model=None, type=None): 154 | index_name = get_mm_index_name(model, type) 155 | index_prefix = get_mm_index_prefix(model, type) 156 | return self.create_index(index_name, type, index_prefix) 157 | 158 | def get_index_by_name(self, index_name): 159 | pass 160 | -------------------------------------------------------------------------------- /modelcache_mm/processor/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache_mm/processor/post.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | from typing import List, Any 4 | 5 | 6 | def random_one(messages: List[Any]) -> Any: 7 | return random.choice(messages) 8 | 9 | 10 | def first(messages: List[Any]) -> Any: 11 | return messages[0] 12 | 13 | 14 | def nop(messages: List[Any]) -> Any: 15 | return messages 16 | -------------------------------------------------------------------------------- /modelcache_mm/processor/pre.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | from typing import Dict, Any 4 | 5 | 6 | def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 7 | return data.get("chat_info")[-1]["query"] 8 | 9 | 10 | def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 11 | return data.get("query")[-1]["content"] 12 | 13 | 14 | def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) -> Any: 15 | last_content_str = data.get("messages")[-1]["content"] 16 | prompts = params.get("prompts", []) 17 | if prompts is 
None: 18 | return last_content_str 19 | pattern = "|".join(prompts) 20 | new_content_str = re.sub(pattern, "", last_content_str) 21 | return new_content_str 22 | 23 | 24 | def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 25 | s = "" 26 | messages = data.get("messages") 27 | for i, message in enumerate(messages): 28 | if i == len(messages) - 1: 29 | s += message["content"] 30 | else: 31 | s += message["content"] + "\n" 32 | return s 33 | 34 | 35 | def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 36 | return data 37 | 38 | 39 | def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 40 | return data.get("prompt") 41 | 42 | 43 | def get_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 44 | return data.get("file").name 45 | 46 | 47 | def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes: 48 | return data.get("file").peek() 49 | 50 | 51 | def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 52 | input_data = data.get("input") 53 | return str(input_data["image"].peek()) + input_data["question"] 54 | 55 | 56 | def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str: 57 | input_data = data.get("input") 58 | return input_data["image"].name 59 | 60 | 61 | def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 62 | query_list = data.get("query") 63 | return multi_splicing(query_list) 64 | 65 | 66 | def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 67 | insert_query_list = data.get("chat_info")[-1]['query'] 68 | return multi_splicing(insert_query_list) 69 | 70 | 71 | def multi_splicing(data_list) -> Any: 72 | result_str = "" 73 | for d in data_list: 74 | role = d.get('role', '') 75 | content = d.get('content', '') 76 | result_str += role + "###" + content + "|||" 77 | 78 | # 去掉最后一个"|||" 79 | result_str = result_str[:-3] 80 | 81 | return result_str 82 | 83 | 84 | def multi_analysis(dialog_str): 85 | sub_strings = dialog_str.split('|||') 86 | dict_list = [] 87 | for s in sub_strings: 88 | parts = s.split('###') 89 | if len(parts) == 2: 90 | role = parts[0] 91 | content = parts[1] 92 | elif len(parts) > 2: 93 | role = parts[0] 94 | content = '###'.join(parts[1:]) 95 | else: 96 | content = 'exception' 97 | if content == '': 98 | d = {"role": role} 99 | else: 100 | d = {"role": role, "content": content} 101 | dict_list.append(d) 102 | result_list = dict_list 103 | 104 | # 输出结果 105 | return result_list 106 | 107 | 108 | def mm_insert_dict(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 109 | query_dict = data.get("chat_info")[-1]['query'] 110 | return query_dict 111 | 112 | 113 | def mm_query_dict(data: Dict[str, Any], **_: Dict[str, Any]) -> Any: 114 | query_dict = data.get("query") 115 | return query_dict 116 | -------------------------------------------------------------------------------- /modelcache_mm/report.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class Report: 3 | def __init__(self): 4 | self.embedding_all_time = 0 5 | self.embedding_count = 0 6 | self.search_all_time = 0 7 | self.search_count = 0 8 | self.hint_cache_count = 0 9 | 10 | def embedding(self, delta_time): 11 | """Embedding counts and time. 12 | 13 | :param delta_time: additional runtime. 14 | """ 15 | self.embedding_all_time += delta_time 16 | self.embedding_count += 1 17 | 18 | def search(self, delta_time): 19 | """Search counts and time. 20 | 21 | :param delta_time: additional runtime. 
22 | """ 23 | self.search_all_time += delta_time 24 | self.search_count += 1 25 | 26 | def average_embedding_time(self): 27 | """Average embedding time.""" 28 | return round( 29 | self.embedding_all_time / self.embedding_count 30 | if self.embedding_count != 0 31 | else 0, 32 | 4, 33 | ) 34 | 35 | def average_search_time(self): 36 | return round( 37 | self.search_all_time / self.search_count 38 | if self.embedding_count != 0 39 | else 0, 40 | 4, 41 | ) 42 | 43 | def hint_cache(self): 44 | self.hint_cache_count += 1 45 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation 3 | from modelcache.utils.lazy_import import LazyImport 4 | 5 | exact_match = LazyImport( 6 | "exact_match", globals(), "modelcache.similarity_evaluation.exact_match" 7 | ) 8 | 9 | 10 | def ExactMatchEvaluation(): 11 | return exact_match.ExactMatchEvaluation() 12 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/distance.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation import SimilarityEvaluation 4 | 5 | 6 | class SearchDistanceEvaluation(SimilarityEvaluation): 7 | def __init__(self, max_distance=4.0, positive=False): 8 | self.max_distance = max_distance 9 | self.positive = positive 10 | 11 | def evaluation( 12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 13 | ) -> float: 14 | distance, _ = cache_dict["search_result"] 15 | if distance < 0: 16 | distance = 0 17 | elif distance > self.max_distance: 18 | distance = self.max_distance 19 | if self.positive: 20 | return distance 21 | return self.max_distance - distance 22 | 23 | def range(self) -> Tuple[float, float]: 24 | return 0.0, self.max_distance 25 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/exact_match.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Tuple, Dict, Any 3 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation 4 | 5 | 6 | class ExactMatchEvaluation(SimilarityEvaluation): 7 | def __init__(self): 8 | pass 9 | 10 | def evaluation( 11 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_ 12 | ) -> float: 13 | return 1 if cache_dict["question"] == src_dict["question"] else 0 14 | 15 | def range(self) -> Tuple[float, float]: 16 | return 0, 1 17 | -------------------------------------------------------------------------------- /modelcache_mm/similarity_evaluation/similarity_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Tuple, Dict, Any 4 | 5 | 6 | class SimilarityEvaluation(metaclass=ABCMeta): 7 | @abstractmethod 8 | def evaluation( 9 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **kwargs 10 | ) -> float: 11 | pass 12 | 13 | @abstractmethod 14 | def range(self) -> Tuple[float, float]: 15 | pass 16 | -------------------------------------------------------------------------------- 
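# Usage sketch (illustrative): turning a raw vector-store hit into a similarity score with the
# evaluators above. SearchDistanceEvaluation inverts an L2 distance (smaller distance -> larger
# score, clipped to max_distance); ExactMatchEvaluation is a strict equality check on "question".
# The dictionaries below carry only the keys the two evaluation() methods actually read.
#
#   from modelcache_mm.similarity_evaluation.distance import SearchDistanceEvaluation
#   from modelcache_mm.similarity_evaluation.exact_match import ExactMatchEvaluation
#
#   dist_eval = SearchDistanceEvaluation(max_distance=4.0)
#   dist_eval.evaluation({}, {"search_result": (1.5, 42)})   # -> 2.5 (higher = more similar)
#   dist_eval.range()                                         # -> (0.0, 4.0)
#
#   ExactMatchEvaluation().evaluation({"question": "hi"}, {"question": "hi"})  # -> 1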
/modelcache_mm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib.util 3 | from typing import Optional 4 | from modelcache.utils.dependency_control import prompt_install 5 | 6 | 7 | def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None): 8 | is_avail = False 9 | if importlib.util.find_spec(libname): 10 | is_avail = True 11 | if not is_avail and prompt: 12 | prompt_install(package if package else libname) 13 | return is_avail 14 | 15 | 16 | def import_onnxruntime(): 17 | _check_library("onnxruntime") 18 | 19 | 20 | def import_huggingface(): 21 | _check_library("transformers") 22 | 23 | 24 | def import_huggingface_hub(): 25 | _check_library("huggingface_hub", package="huggingface-hub") 26 | 27 | 28 | def import_pymysql(): 29 | _check_library("pymysql") 30 | 31 | 32 | def import_sql_client(db_name): 33 | if db_name in ["mysql"]: 34 | import_pymysql() 35 | 36 | 37 | def import_pymilvus(): 38 | _check_library("pymilvus") 39 | 40 | 41 | def import_milvus_lite(): 42 | _check_library("milvus") 43 | 44 | 45 | def import_faiss(): 46 | _check_library("faiss", package="faiss-cpu") 47 | 48 | 49 | def import_torch(): 50 | _check_library("torch") 51 | 52 | 53 | def import_fasttext(): 54 | _check_library("fasttext") 55 | 56 | 57 | def import_paddle(): 58 | prompt_install("protobuf==3.20.0") 59 | _check_library("paddlepaddle") 60 | 61 | 62 | def import_paddlenlp(): 63 | _check_library("paddlenlp") 64 | 65 | 66 | def import_timm(): 67 | _check_library("timm", package="timm") 68 | 69 | 70 | def import_pillow(): 71 | _check_library("PIL", package="pillow") 72 | 73 | 74 | def import_redis(): 75 | _check_library("redis") 76 | 77 | 78 | def import_chromadb(): 79 | _check_library("chromadb", package="chromadb") -------------------------------------------------------------------------------- /modelcache_mm/utils/cache_func.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | def cache_all(*_, **__): 3 | return True -------------------------------------------------------------------------------- /modelcache_mm/utils/dependency_control.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import subprocess 3 | from modelcache.utils.error import PipInstallError 4 | from modelcache.utils.log import modelcache_log 5 | 6 | 7 | def prompt_install(package: str, warn: bool = False): # pragma: no cover 8 | """ 9 | Function used to prompt user to install a package. 10 | """ 11 | cmd = f"pip install {package}" 12 | try: 13 | if warn and input(f"Install {package}? 
Y/n: ") != "Y": 14 | raise ModuleNotFoundError(f"No module named {package}") 15 | subprocess.check_call(cmd, shell=True) 16 | modelcache_log.info("%s installed successfully!", package) 17 | except subprocess.CalledProcessError as e: 18 | raise PipInstallError(package) from e 19 | -------------------------------------------------------------------------------- /modelcache_mm/utils/env_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /modelcache_mm/utils/error.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class CacheError(Exception): 3 | """ModelCache base error""" 4 | 5 | 6 | class NotInitError(CacheError): 7 | """Raise when the cache has been used before it's inited""" 8 | def __init__(self): 9 | super().__init__("The cache should be inited before using") 10 | 11 | 12 | class RemoveError(CacheError): 13 | """Raise when the cache has been used before it's inited""" 14 | def __init__(self): 15 | super().__init__("The cache remove error") 16 | 17 | class NotFoundError(CacheError): 18 | """Raise when getting an unsupported store.""" 19 | def __init__(self, store_type, current_type_name): 20 | super().__init__(f"Unsupported ${store_type}: {current_type_name}") 21 | 22 | 23 | class ParamError(CacheError): 24 | """Raise when receiving an invalid param.""" 25 | 26 | 27 | class PipInstallError(CacheError): 28 | """Raise when failed to install package.""" 29 | def __init__(self, package): 30 | super().__init__(f"Ran into error installing {package}.") 31 | 32 | 33 | class MultiTypeError(CacheError): 34 | def __init__(self): 35 | super().__init__("multichat type error, please check") 36 | -------------------------------------------------------------------------------- /modelcache_mm/utils/index_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def get_index_name(model): 5 | return 'multicache' + '_' + model 6 | 7 | 8 | def get_index_prefix(model): 9 | return 'prefix' + '_' + model 10 | 11 | 12 | def get_mm_index_name(model, mm_type): 13 | if mm_type not in ['IMG_TEXT', 'mm', 'IMG', 'image', 'TEXT', 'text']: 14 | raise ValueError('mm_type is not normal!') 15 | if mm_type == 'IMG_TEXT': 16 | mm_type = 'mm' 17 | elif mm_type == 'IMG': 18 | mm_type = 'image' 19 | elif mm_type == 'TEXT': 20 | mm_type = 'text' 21 | return 'multi_cache' + '_' + model + '_' + mm_type 22 | 23 | 24 | def get_mm_index_prefix(model, mm_type): 25 | if mm_type not in ['IMG_TEXT', 'mm', 'IMG', 'image', 'TEXT', 'text']: 26 | raise ValueError('mm_type is not normal!') 27 | if mm_type == 'IMG_TEXT': 28 | mm_type = 'mm' 29 | elif mm_type == 'IMG': 30 | mm_type = 'image' 31 | elif mm_type == 'TEXT': 32 | mm_type = 'text' 33 | return 'prefix' + '_' + model + '_' + mm_type 34 | -------------------------------------------------------------------------------- /modelcache_mm/utils/lazy_import.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib 3 | from types import ModuleType 4 | 5 | 6 | class LazyImport(ModuleType): 7 | """ 8 | Lazily import a module. 
9 | """ 10 | def __init__(self, local_name, parent_module_globals, name): 11 | self._local_name = local_name 12 | self._parent_module_globals = parent_module_globals 13 | super().__init__(name) 14 | 15 | def _load(self): 16 | module = importlib.import_module(self.__name__) 17 | self._parent_module_globals[self._local_name] = module 18 | self.__dict__.update(module.__dict__) 19 | return module 20 | 21 | def __getattr__(self, item): 22 | module = self._load() 23 | return getattr(module, item) 24 | 25 | def __dir__(self): 26 | module = self._load() 27 | return dir(module) 28 | -------------------------------------------------------------------------------- /modelcache_mm/utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | FORMAT = '%(asctime)s - %(thread)d - %(filename)s-%(module)s:%(lineno)s - %(levelname)s: %(message)s' 5 | logging.basicConfig(format=FORMAT) 6 | 7 | modelcache_log = logging.getLogger('modelcache') 8 | -------------------------------------------------------------------------------- /modelcache_mm/utils/time.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | from modelcache import cache 4 | 5 | 6 | def time_cal(func, func_name=None, report_func=None): 7 | def inner(*args, **kwargs): 8 | time_start = time.time() 9 | res = func(*args, **kwargs) 10 | delta_time = time.time() - time_start 11 | if cache.config.log_time_func: 12 | cache.config.log_time_func( 13 | func.__name__ if func_name is None else func_name, delta_time 14 | ) 15 | if report_func is not None: 16 | report_func(delta_time) 17 | return res 18 | 19 | return inner 20 | 21 | -------------------------------------------------------------------------------- /mulicache-readme-cn.md: -------------------------------------------------------------------------------- 1 | # MultiModal Cache 2 | 3 | 为满足多模态的性能要求,我们在 LLModel Cache 的基础上,开发了 MultiModal Cache 系统。MultiModal Cache 增强了 ModelCache 功能,架优化架构,适应多种应用场景。 4 | 5 | - [MultiModal Cache](#multimodal-cache) 6 | - [最新动态](#最新动态) 7 | - [特性](#特性) 8 | - [性能](#性能) 9 | - [效果评估](#效果评估) 10 | - [参与贡献](#参与贡献) 11 | 12 | ## 最新动态 13 | 14 | - [2024.12.12] MultiModal Cache 系统正式发布。 15 | 16 | ## 特性 17 | 18 | | 场景 | 数据类型 | 图像格式 | 数据隔离 | 19 | |------|----------|----------|----------| 20 | | 文本对话 | 文本 | 不适用 | 支持 | 21 | | 图文理解 | 文本+图像 | image_url/image_base64 | 支持 | 22 | 23 | - **兼容性**:支持文本和图片链接(image_url)和图片 Base64 编码三种数据格式及其组合。 24 | - **数据隔离**:支持多模型数据隔离,允许不同数据模型在同一系统中独立运行。 25 | - **模态隔离**:支持同一模型下不同模态数据(如文本和图像)的隔离处理。 26 | 27 | ## 性能 28 | 29 | 我们在生产环境中使用企业级数据库对 MultiModal Cache 进行了全面的性能评估。以下是详细的性能数据: 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 |
| 请求类型 | Cache Hit | 总耗时范围 | 组件 | 组件耗时 |
|----------|-----------|-------------|------|----------|
| Text | Hit | 420ms-520ms | Multi-Encoder (Text) | ~300ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 关系存储检索 | 60-70ms |
| Text | Not Hit | 300ms+N(s) | Multi-Encoder (Text) | ~300ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 大模型调用 | N (s) |
| IMG_TEXT | Hit | 600ms-800ms | Multi-Encoder (image+text) | ~600ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 关系存储检索 | 60-70ms |
| IMG_TEXT | Not Hit | 600ms+N(s) | Multi-Encoder (image+text) | ~600ms |
| | | | 向量存储检索 | 40-50ms |
| | | | 大模型调用 | N (s) |
98 | 99 | 根据目前的评估结果,Embedding 的推理时间存在较大的优化空间。 100 | **说明**:使用嵌入式数据库可能会进一步提升性能。 101 | 102 | ## 效果评估 103 | 104 | 为全面评估 Cache 对模型服务的影响,我们进行了端到端的性能测试,ua 比较了有 Cache 和无 Cache 两种服务配置。我们使用了 5000 个测试用例的数据集进行自动化测试。 105 | 106 | - 有 Cache 的预发模型服务:观察其响应时间,预期 Cache 的引入能够显著提升服务的性能,降低延迟。 107 | - 无 Cache 的线上模型服务,以获取其原始性能指标和输出结果。这些数据将作为对比基准。 108 | 109 | 为了确保 Cache 引入后的数据准确性和一致性,我们比较了两个服务返回的结果,验证了 Cache 机制是否会影响最终用户收到的回复内容。 110 | 111 | 与原始的直接模型调用方式相比,Cache Service 的调用耗时数据呈现出稳定的分布特征,性能上并不会随着模型参数规模的增加而受到影响。在传统情况下,随着模型参数规模的扩大,模型调用的耗时往往会上升,这是因为更大规模的模型需要更多的计算资源。Cache 服务通过存储经常访问的数据来避免重复的计算,从而一定程度上解耦了耗时与模型复杂性之间的关联。 112 | 113 | ![cache-service-cost-time-distribution](docs/cache-service-cost-time-distribution.webp) 114 | 115 | 我们对缓存命中的耗时与实际调用模型的耗时进行了对比分析。实验数据表明,在集成 Cache Service之后,基于 llama7B 模型,缓存命中所带来的性能提升超过了 40%。预计随着模型的持续迭代与优化,性能提升的幅度将会有更进一步的增长。 116 | 117 | ![time-cost-comparison](docs/time-cost-comparison.webp) 118 | 119 | ## 参与贡献 120 | 121 | MultiModal Cache 是一个充满潜力的开源项目,我们欢迎各种形式的贡献: 122 | 123 | - 提交问题和建议 124 | - 参与代码编写 125 | - 完善文档和示例 126 | 127 | 无论您是经验丰富的开发者还是新手,您的参与都将使这个项目更加出色,同时为开源社区做出贡献。 128 | -------------------------------------------------------------------------------- /reference_doc/create_table.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE `modelcache_llm_answer` ( 2 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键', 3 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 4 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 5 | `question` text NOT NULL comment 'question', 6 | `answer` text NOT NULL comment 'answer', 7 | `answer_type` int(11) NOT NULL comment 'answer_type', 8 | `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count', 9 | `model` varchar(1000) NOT NULL comment 'model', 10 | `embedding_data` blob NOT NULL comment 'embedding_data', 11 | `is_deleted` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'delete state(0 Not deleted,-1 deleted)', 12 | PRIMARY KEY(`id`) 13 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'cache_codegpt_answer'; 14 | 15 | 16 | CREATE TABLE `modelcache_query_log` ( 17 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键', 18 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间', 19 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间', 20 | `error_code` int(11) NOT NULL comment 'errorCode', 21 | `error_desc` varchar(1000) NOT NULL comment 'errorDesc', 22 | `cache_hit` varchar(100) NOT NULL comment 'cacheHit', 23 | `delta_time` float NOT NULL comment 'delta_time', 24 | `model` varchar(1000) NOT NULL comment 'model', 25 | `query` text NOT NULL comment 'query', 26 | `hit_query` text NOT NULL comment 'hitQuery', 27 | `answer` text NOT NULL comment 'answer', 28 | PRIMARY KEY(`id`) 29 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log'; 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools==5.3.1 2 | DBUtils==1.4 3 | Flask==3.0.0 4 | numpy==1.24.4 5 | onnxruntime==1.16.1 6 | openai==0.28.1 7 | pymilvus==2.3.1 8 | PyMySQL==1.1.0 9 | Requests==2.31.0 10 | torch==2.1.1 11 | transformers==4.38.2 12 | faiss-cpu==1.7.4 13 | redis==5.0.1 14 | modelscope==1.14.0 15 | fastapi==0.115.5 16 | uvicorn==0.32.0 17 | chromadb==0.5.23 18 | elasticsearch==7.10.0 19 | snowflake-id==1.0.2 
20 | --------------------------------------------------------------------------------