├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── README_CN.md
├── data
│   ├── milvus
│   │   ├── embedEtcd.yaml
│   │   └── user.yaml
│   └── mysql
│       ├── init
│       │   └── init.sql
│       └── my.conf
├── docker-compose.yaml
├── docs
│   ├── 1.what-is-model-cache.md
│   ├── 2.model-cache-features.md
│   ├── 3.model-cache-quick-start.md
│   ├── 4.create-cache.md
│   ├── cache-service-cost-time-distribution.webp
│   ├── codefuse-LOGO.png
│   ├── modelcache_modules_20231114.png
│   ├── modelcache_modules_20240409.png
│   ├── script
│   │   └── get_input_embedding_script.py
│   └── time-cost-comparison.webp
├── examples
│   ├── __init__.py
│   ├── embedding
│   │   ├── __init__.py
│   │   └── huggingface_tei_example.py
│   └── flask
│       ├── __init__.py
│       ├── llms_cache
│       │   ├── __init__.py
│       │   ├── data_insert.py
│       │   ├── data_query.py
│       │   ├── data_query_long.py
│       │   └── register.py
│       └── multi_cache
│           ├── __init__.py
│           ├── data_insert.py
│           ├── data_query.py
│           ├── register.py
│           └── remove.py
├── fastapi4modelcache.py
├── fastapi4modelcache_demo.py
├── flask4modelcache.py
├── flask4modelcache_demo.py
├── flask4multicache.py
├── flask4multicache_demo.py
├── model
│   ├── clip_zh
│   │   └── __init__.py
│   └── text2vec-base-chinese
│       ├── config.json
│       ├── logs.txt
│       ├── modules.json
│       ├── sentence_bert_config.json
│       ├── special_tokens_map.json
│       ├── tokenizer_config.json
│       └── vocab.txt
├── modelcache
│   ├── __init__.py
│   ├── adapter
│   │   ├── __init__.py
│   │   ├── adapter.py
│   │   ├── adapter_insert.py
│   │   ├── adapter_query.py
│   │   ├── adapter_register.py
│   │   └── adapter_remove.py
│   ├── config.py
│   ├── config
│   │   ├── chromadb_config.ini
│   │   ├── elasticsearch_config.ini
│   │   ├── milvus_config.ini
│   │   ├── mysql_config.ini
│   │   └── redis_config.ini
│   ├── core.py
│   ├── embedding
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── bge_m3.py
│   │   ├── data2vec.py
│   │   ├── fasttext.py
│   │   ├── huggingface.py
│   │   ├── huggingface_tei.py
│   │   ├── llmEmb.py
│   │   ├── onnx.py
│   │   ├── paddlenlp.py
│   │   ├── string_text.py
│   │   └── timm_embedding.py
│   ├── manager
│   │   ├── __init__.py
│   │   ├── data_manager.py
│   │   ├── eviction
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── manager.py
│   │   │   └── memory_cache.py
│   │   ├── eviction_manager.py
│   │   ├── factory.py
│   │   ├── object_data
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── scalar_data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── manager.py
│   │   │   ├── sql_storage.py
│   │   │   ├── sql_storage_es.py
│   │   │   └── sql_storage_sqlite.py
│   │   └── vector_data
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── chroma.py
│   │       ├── faiss.py
│   │       ├── manager.py
│   │       ├── milvus.py
│   │       └── redis.py
│   ├── processor
│   │   ├── __init__.py
│   │   ├── post.py
│   │   └── pre.py
│   ├── report.py
│   ├── similarity_evaluation
│   │   ├── __init__.py
│   │   ├── distance.py
│   │   ├── exact_match.py
│   │   └── similarity_evaluation.py
│   └── utils
│       ├── __init__.py
│       ├── cache_func.py
│       ├── dependency_control.py
│       ├── env_config.py
│       ├── error.py
│       ├── index_util.py
│       ├── lazy_import.py
│       ├── log.py
│       ├── model_filter.py
│       └── time.py
├── modelcache_mm
│   ├── __init__.py
│   ├── adapter
│   │   ├── __init__.py
│   │   ├── adapter.py
│   │   ├── adapter_insert.py
│   │   ├── adapter_query.py
│   │   ├── adapter_register.py
│   │   └── adapter_remove.py
│   ├── config.py
│   ├── config
│   │   ├── chromadb_config.ini
│   │   ├── elasticsearch_config.ini
│   │   ├── milvus_config.ini
│   │   ├── mysql_config.ini
│   │   └── redis_config.ini
│   ├── core.py
│   ├── embedding
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── clip.py
│   │   ├── string.py
│   │   └── timm.py
│   ├── manager
│   │   ├── __init__.py
│   │   ├── data_manager.py
│   │   ├── eviction
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── eviction_manager.py
│   │   ├── factory.py
│   │   ├── object_data
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── scalar_data
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── manager.py
│   │   │   ├── sql_storage.py
│   │   │   ├── sql_storage_es.py
│   │   │   └── sql_storage_sqlite.py
│   │   └── vector_data
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── chroma.py
│   │       ├── faiss.py
│   │       ├── manager.py
│   │       └── redis.py
│   ├── processor
│   │   ├── __init__.py
│   │   ├── post.py
│   │   └── pre.py
│   ├── report.py
│   ├── similarity_evaluation
│   │   ├── __init__.py
│   │   ├── distance.py
│   │   ├── exact_match.py
│   │   └── similarity_evaluation.py
│   └── utils
│       ├── __init__.py
│       ├── cache_func.py
│       ├── dependency_control.py
│       ├── env_config.py
│       ├── error.py
│       ├── index_util.py
│       ├── lazy_import.py
│       ├── log.py
│       └── time.py
├── mulicache-readme-cn.md
├── reference_doc
│   └── create_table.sql
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | *.DS_Store
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .nox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | *.py,cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | db.sqlite3-journal
60 | *.db
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | __pypackages__/
86 |
87 | # Celery stuff
88 | celerybeat-schedule
89 | celerybeat.pid
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Environments
95 | .env
96 | .venv
97 | env/
98 | venv/
99 | ENV/
100 | env.bak/
101 | venv.bak/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 | .dmypy.json
116 | dmypy.json
117 |
118 | # Pyre type checker
119 | .pyre/
120 |
121 | .idea
122 | **/data_map**.txt
123 | **/faiss**.index
124 | **/sqlite**.db
125 | **/**.db
126 | **/example.py
127 | **/example.db
128 | **/.chroma
129 |
130 | /fuhui_dev
131 | *.index
132 | *model.onnx
133 |
134 | /data_analyse
135 | /embedding_npy
136 | /flask_server
137 | *.bin
138 | **/maya_embedding_service
139 |
140 | *.ini
141 |
142 | **/multicache_serving.py
143 | **/modelcache_serving.py
144 |
145 | **/model/
146 |
147 | /data/milvus/db
148 | /data/mysql/db
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.9-slim-bookworm
2 |
3 | WORKDIR /home/user
4 |
5 | COPY ./requirements.txt /home/user/docker_requirements.txt
6 |
7 | RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
8 | RUN pip install -r /home/user/docker_requirements.txt --retries 5 --timeout 120
9 |
--------------------------------------------------------------------------------
/data/milvus/embedEtcd.yaml:
--------------------------------------------------------------------------------
1 | listen-client-urls: http://0.0.0.0:2379
2 | advertise-client-urls: http://0.0.0.0:2379
3 | quota-backend-bytes: 4294967296
4 | auto-compaction-mode: revision
5 | auto-compaction-retention: '1000'
6 |
--------------------------------------------------------------------------------
/data/milvus/user.yaml:
--------------------------------------------------------------------------------
1 | # Extra config to override default milvus.yaml
--------------------------------------------------------------------------------
/data/mysql/init/init.sql:
--------------------------------------------------------------------------------
1 | CREATE DATABASE IF NOT EXISTS `modelcache`;
2 |
3 | USE `modelcache`;
4 |
5 | CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` (
6 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment 'primary key',
7 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment 'create time',
8 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment 'modify time',
9 | `question` text NOT NULL comment 'question',
10 | `answer` text NOT NULL comment 'answer',
11 | `answer_type` int(11) NOT NULL comment 'answer_type',
12 | `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count',
13 | `model` varchar(1000) NOT NULL comment 'model',
14 | `embedding_data` blob NOT NULL comment 'embedding_data',
15 | `is_deleted` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'delete state(0 Not deleted,-1 deleted)',
16 | PRIMARY KEY(`id`)
17 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'cache_codegpt_answer';
18 |
19 | CREATE TABLE IF NOT EXISTS `modelcache_query_log` (
20 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment 'primary key',
21 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment 'create time',
22 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment 'modify time',
23 | `error_code` int(11) NOT NULL comment 'errorCode',
24 | `error_desc` varchar(1000) NOT NULL comment 'errorDesc',
25 | `cache_hit` varchar(100) NOT NULL comment 'cacheHit',
26 | `delta_time` float NOT NULL comment 'delta_time',
27 | `model` varchar(1000) NOT NULL comment 'model',
28 | `query` text NOT NULL comment 'query',
29 | `hit_query` text NOT NULL comment 'hitQuery',
30 | `answer` text NOT NULL comment 'answer',
31 | PRIMARY KEY(`id`)
32 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log';
33 |
--------------------------------------------------------------------------------
/data/mysql/my.conf:
--------------------------------------------------------------------------------
1 | [mysqld]
2 | character-set-server=utf8mb4
3 | collation-server=utf8mb4_unicode_ci
4 |
5 | [client]
6 | default-character-set=utf8mb4
7 |
8 | [mysql]
9 | default-character-set=utf8mb4
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: 'Beta'
2 | services:
3 | mysql:
4 | image: mysql:8.0.23
5 | container_name: mysql
6 | environment:
7 | MYSQL_ROOT_PASSWORD: 'root'
8 | MYSQL_DATABASE: 'modelcache'
9 | MYSQL_USER: 'modelcache'
10 | MYSQL_PASSWORD: 'modelcache'
11 | ports:
12 | - 3306:3306
13 | volumes:
14 | - ./data/mysql/db:/var/lib/mysql
15 | - ./data/mysql/my.cnf:/etc/mysql/conf.d/my.cnf
16 | - ./data/mysql/init:/docker-entrypoint-initdb.d
17 | restart: on-failure
18 | networks:
19 | - modelcache
20 |
21 | milvus:
22 | image: milvusdb/milvus:v2.5.0-beta
23 | container_name: milvus
24 | security_opt:
25 | - seccomp:unconfined
26 | environment:
27 | ETCD_USE_EMBED: true
28 | ETCD_DATA_DIR: /var/lib/milvus/etcd
29 | ETCD_CONFIG_PATH: /milvus/configs/embedEtcd.yaml
30 | COMMON_STORAGETYPE: local
31 | volumes:
32 | - ./data/milvus/db:/var/lib/milvus
33 | - ./data/milvus/embedEtcd.yaml:/milvus/configs/embedEtcd.yaml
34 | - ./data/milvus/user.yaml:/milvus/configs/user.yaml
35 | ports:
36 | - 19530:19530
37 | - 9091:9091
38 | - 2379:2379
39 | healthcheck:
40 | test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
41 | interval: 30s
42 | start_period: 90s
43 | timeout: 20s
44 | retries: 3
45 | networks:
46 | - modelcache
47 | restart: on-failure
48 | command: milvus run standalone
49 |
50 | modelcache:
51 | build:
52 | context: .
53 | dockerfile: Dockerfile
54 | container_name: modelcache
55 | image: modelcache:0.1.0
56 | ports:
57 | - 5000:5000
58 | volumes:
59 | - ./model:/home/user/model
60 | - ./modelcache:/home/user/modelcache
61 | - ./modelcache_mm:/home/user/modelcache_mm
62 | - ./fastapi4modelcache.py:/home/user/fastapi4modelcache.py
63 | networks:
64 | - modelcache
65 | restart: on-failure
66 | command: sh -c "uvicorn fastapi4modelcache:app --reload --reload-dir /home/user --port=5000 --host=0.0.0.0"
67 |
68 | networks:
69 | modelcache:
70 | external: true
--------------------------------------------------------------------------------
/docs/1.what-is-model-cache.md:
--------------------------------------------------------------------------------
1 | # What is ModelCache
2 |
3 | In ModelCache, we adopted the main ideas of GPTCache, which include the core modules: adapter, embedding, similarity, and data_manager. The adapter module handles the business logic of various tasks and connects the embedding, similarity, and data_manager modules. The embedding module converts text into semantic vector representations, transforming user queries into vector form. The rank module sorts and evaluates the similarity of the recalled vectors. The data_manager module manages the database. To better facilitate industrial applications, we have made the following architectural and functional upgrades:
4 |
5 | ## Architecture
6 |
7 | ![ModelCache architecture](modelcache_modules_20240409.png)
8 |
9 | ## Function comparison
10 |
11 | We've implemented several key updates to our repository. We've resolved network issues with Hugging Face and improved inference speed by introducing local embedding capabilities. Due to limitations in SQLAlchemy, we've redesigned our relational database interaction module for more flexible operations. We've added multi-tenancy support to ModelCache, recognizing the need for multiple users and models in LLM products. Lastly, we've made initial adjustments for better compatibility with system commands and multi-turn dialogues.
12 |
13 |
14 |
15 | | Module | Function | ModelCache | GPTCache |
16 | | --- | --- | --- | --- |
17 | | Basic Interface | Data query interface | ☑ | ☑ |
18 | | Basic Interface | Data writing interface | ☑ | ☑ |
19 | | Embedding | Embedding model configuration | ☑ | ☑ |
20 | | Embedding | Large model embedding layer | ☑ | |
21 | | Embedding | BERT model long text processing | ☑ | |
22 | | Large model invocation | Decoupling from large models | ☑ | |
23 | | Large model invocation | Local loading of embedding model | ☑ | |
24 | | Data isolation | Model data isolation | ☑ | ☑ |
25 | | Data isolation | Hyperparameter isolation | | |
26 | | Databases | MySQL | ☑ | ☑ |
27 | | Databases | Milvus | ☑ | ☑ |
28 | | Databases | OceanBase | ☑ | |
29 | | Session management | Single-turn dialogue | ☑ | ☑ |
30 | | Session management | System commands | ☑ | |
31 | | Session management | Multi-turn dialogue | ☑ | |
32 | | Data management | Data persistence | ☑ | ☑ |
33 | | Data management | One-click cache clearance | ☑ | |
34 | | Tenant management | Support for multi-tenancy | ☑ | |
35 | | Tenant management | Milvus multi-collection capability | ☑ | |
36 | | Other | Long-short dialogue distinction | ☑ | |
37 | 
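38 | ## Module wiring example
39 | 
40 | The snippet below shows how the modules described above are connected through `cache.init`. It is taken from the bundled demo service (`flask4modelcache_demo.py`), which uses SQLite and FAISS; the standard service swaps in MySQL and Milvus instead.
41 | 
42 | ```python
43 | from modelcache import cache
44 | from modelcache.embedding import Data2VecAudio
45 | from modelcache.manager import CacheBase, VectorBase, get_data_manager
46 | from modelcache.processor.pre import query_multi_splicing, insert_multi_splicing
47 | from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
48 | 
49 | data2vec = Data2VecAudio()            # embedding module: converts text into vectors
50 | data_manager = get_data_manager(      # data_manager module: scalar + vector storage
51 |     CacheBase("sqlite"),
52 |     VectorBase("faiss", dimension=data2vec.dimension),
53 | )
54 | 
55 | cache.init(
56 |     embedding_func=data2vec.to_embeddings,
57 |     data_manager=data_manager,
58 |     similarity_evaluation=SearchDistanceEvaluation(),   # rank / similarity module
59 |     query_pre_embedding_func=query_multi_splicing,
60 |     insert_pre_embedding_func=insert_multi_splicing,
61 | )
62 | ```
63 | 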
--------------------------------------------------------------------------------
/docs/2.model-cache-features.md:
--------------------------------------------------------------------------------
1 | # ModelCache features
2 |
3 | This topic describes ModelCache features. In ModelCache, we incorporated the core principles of GPTCache. ModelCache has four modules: adapter, embedding, similarity, and data_manager.
4 |
5 | - The adapter module orchestrates the business logic for various tasks, integrating the embedding, similarity, and data_manager modules.
6 | - The embedding module converts text into semantic vector representations, and transforms user queries into vectors.
7 | - The rank module ranks and evaluates the similarity of recalled vectors.
8 | - The data_manager module manages the databases.
9 |
10 | To make ModelCache more suitable for industrial use, we made several improvements to its architecture and functionality:
11 |
12 | - [x] Architectural adjustment (lightweight integration):
13 | - Embedded into LLM products using a Redis-like caching mode.
14 | - Provided semantic caching without interfering with LLM calls, security audits, and other functions.
15 | - Compatible with all LLM services.
16 | - [x] Multiple model loading:
17 | - Supported local embedding model loading, and resolved Hugging Face network connectivity issues.
18 | - Supported loading embedding layers from various pre-trained models.
19 | - [x] Data isolation
20 | - Environment isolation: Read different database configurations based on the environment. Isolate development, staging, and production environments.
21 | - Multi-tenant data isolation: Dynamically create collections based on models for data isolation, addressing data separation issues in multi-model/service scenarios within large language model products.
22 | - [x] Supported system instruction: Adopted a concatenation approach to resolve issues with system instructions in the prompt paradigm.
23 | - [x] Long and short text differentiation: Long texts bring more challenges for similarity assessment. Added differentiation between long and short texts, allowing for separate threshold configurations (see the sketch at the end of this topic).
24 | - [x] Milvus performance optimization: Adjusted Milvus consistency level to "Session" level for better performance.
25 | - [x] Data management:
26 | - One-click cache clearing to enable easy data management after model upgrades.
27 | - Recall of hit queries for subsequent data analysis and model iteration reference.
28 | - Asynchronous log write-back for data analysis and statistics.
29 | - Added model field and data statistics field to enhance features.
30 |
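31 | The long and short text thresholds mentioned above are plain constructor arguments on `modelcache.config.Config`; the `Cache` object holds a `Config` instance by default. A minimal sketch (the values below are the library defaults, not tuned recommendations):
32 | 
33 | ```python
34 | from modelcache.config import Config
35 | 
36 | config = Config(
37 |     similarity_threshold=0.95,       # threshold applied to short queries
38 |     similarity_threshold_long=0.95,  # separate threshold applied to long queries
39 | )
40 | ```
41 | 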
--------------------------------------------------------------------------------
/docs/3.model-cache-quick-start.md:
--------------------------------------------------------------------------------
1 | # Quick start
2 |
3 | This topic describes how to set up and use ModelCache.
4 |
5 | You can find the start script in `flask4modelcache.py` and `flask4modelcache_demo.py`.
6 |
7 | - `flask4modelcache_demo.py`: A quick-test service that uses embedded SQLite and FAISS. No database configuration is required.
8 | - `flask4modelcache.py`: The standard service that requires MySQL and Milvus configuration.
9 |
10 | ## Dependencies
11 |
12 | - Python: V3.8 or above
13 | - Package installation
14 |
15 | ```shell
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ## Start service
20 |
21 | ### Start demo
22 |
23 | 1. Download the embedding model bin file from [Hugging Face](https://huggingface.co/shibing624/text2vec-base-chinese/tree/main). Place it in the `model/text2vec-base-chinese` folder.
24 | 2. Start the backend service:
25 |
26 | ```shell
27 | cd CodeFuse-ModelCache
28 | ```
29 |
30 | ```shell
31 | python flask4modelcache_demo.py
32 | ```
33 |
34 | ### Start standard service
35 |
36 | Before you start standard service, do these steps:
37 |
38 | 1. Install MySQL and import the SQL file from `reference_doc/create_table.sql`.
39 | 2. Install vector database Milvus.
40 | 3. Configure database access in:
41 | - `modelcache/config/milvus_config.ini`
42 | - `modelcache/config/mysql_config.ini`
43 | 4. Download the embedding model bin file from [Hugging Face](https://huggingface.co/shibing624/text2vec-base-chinese/tree/main). Put it in `model/text2vec-base-chinese`.
44 | 5. Start the backend service:
45 |
46 | ```bash
47 | python flask4modelcache.py
48 | ```
49 |
50 | ## Visit the service
51 |
52 | The service provides three core RESTful API functionalities: Cache-Writing, Cache-Querying, and Cache-Clearing.
53 |
54 | ### Write cache
55 |
56 | ```python
57 | import json
58 | import requests
59 | url = 'http://127.0.0.1:5000/modelcache'
60 | type = 'insert'
61 | scope = {"model": "CODEGPT-1008"}
62 | chat_info = [{"query": [{"role": "system", "content": "You are an AI code assistant and you must provide neutral and harmless answers to help users solve code-related problems."}, {"role": "user", "content": "Who are you?"}],
63 | "answer": "Hello, I am an intelligent assistant. How can I assist you?"}]
64 | data = {'type': type, 'scope': scope, 'chat_info': chat_info}
65 | headers = {"Content-Type": "application/json"}
66 | res = requests.post(url, headers=headers, json=json.dumps(data))
67 | ```
68 |
69 | ### Query cache
70 |
71 | ```python
72 | import json
73 | import requests
74 | url = 'http://127.0.0.1:5000/modelcache'
75 | type = 'query'
76 | scope = {"model": "CODEGPT-1008"}
77 | query = [{"role": "system", "content": "You are an AI code assistant and you must provide neutral and harmless answers to help users solve code-related problems."}, {"role": "user", "content": "Who are you?"}]
78 | data = {'type': type, 'scope': scope, 'query': query}
79 |
80 | headers = {"Content-Type": "application/json"}
81 | res = requests.post(url, headers=headers, json=json.dumps(data))
82 | ```
83 |
84 | ### Clear cache
85 |
86 | ```python
87 | import json
88 | import requests
89 | url = 'http://127.0.0.1:5000/modelcache'
90 | type = 'remove'
91 | scope = {"model": "CODEGPT-1008"}
92 | remove_type = 'truncate_by_model'
93 | data = {'type': type, 'scope': scope, 'remove_type': remove_type}
94 |
95 | headers = {"Content-Type": "application/json"}
96 | res = requests.post(url, headers=headers, json=json.dumps(data))
97 | ```
98 |
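99 | ### Parse the response
100 | 
101 | Both the demo and the standard services return a JSON body whose top-level fields follow the bundled server implementations (`errorCode`, `errorDesc`, `cacheHit`, `delta_time`, `hit_query`, and `answer` for queries). The snippet below is a minimal sketch of reading the query response from the example above.
102 | 
103 | ```python
104 | import json
105 | 
106 | result = json.loads(res.text)  # `res` is the response from the query example
107 | if result.get("cacheHit"):
108 |     print("cache hit:", result["answer"], "| hit query:", result["hit_query"])
109 | else:
110 |     print("cache miss or error:", result["errorCode"], result.get("errorDesc", ""))
111 | ```
112 | 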
--------------------------------------------------------------------------------
/docs/4.create-cache.md:
--------------------------------------------------------------------------------
1 | # Create cache
2 |
3 | This topic describes how to create cache.
4 |
5 | ## Default cache interface
6 |
7 | ```py
8 | class Cache:
9 | # ModelCache calls it when you start the cache system
10 | def __init__(self):
11 | self.has_init = False
12 | self.cache_enable_func = None
13 | self.embedding_func = None
14 | self.post_process_messages_func = None
15 | self.config = Config()
16 | ```
17 |
18 | The embedding function embeds text into dense vectors for contextual similarity search. ModelCache supports these embedding methods: Hugging Face, ONNX, and SentenceTransformers. The default model is text2vec from Hugging Face because it performs better on Chinese. Simply initialize your embedding function with `text2vec.to_embeddings`.
19 |
20 | ```py
21 | data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
22 | VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
23 |
24 | cache.init(
25 | embedding_func=data2vec.to_embeddings,
26 | data_manager=data_manager,
27 | similarity_evaluation=SearchDistanceEvaluation(),
28 | query_pre_embedding_func=query_multi_splicing,
29 | insert_pre_embedding_func=insert_multi_splicing,
30 | )
31 | ```
32 |
33 | The data_manager `CacheBase` stores all scalar data, such as original questions, prompts, answers, and access times. ModelCache supports multiple cache storage backends, such as SQLite, MySQL, and OceanBase. NoSQL databases will be supported in the future.
34 | 
35 | The data_manager `VectorBase` stores and searches all embedding vectors to find semantically similar results. ModelCache supports vector search libraries such as FAISS and vector databases such as Milvus. More vector databases and cloud services will be supported in the future.
36 |
37 | ## Examples
38 |
39 | ```py
40 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))
41 | data_manager = get_data_manager(CacheBase("oceanbase"), VectorBase("milvus", dimension=data2vec.dimension))
42 | ```
43 |
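44 | Once the cache is initialized as shown above, reads and writes go through the adapter. The following is a minimal sketch that mirrors the bundled demo service (`flask4modelcache_demo.py`); the model name and messages are illustrative placeholders.
45 | 
46 | ```py
47 | from modelcache.adapter import adapter
48 | 
49 | # Assumes cache.init(...) has already been called, as in the examples above.
50 | messages = [{"role": "user", "content": "hello"}]
51 | adapter.ChatCompletion.create_insert(
52 |     model="DEMO_MODEL",  # placeholder model name
53 |     chat_info=[{"query": messages, "answer": "Hi, how can I help you?"}],
54 | )
55 | print(adapter.ChatCompletion.create_query(scope={"model": "DEMO_MODEL"}, query=messages))
56 | ```
57 | 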
--------------------------------------------------------------------------------
/docs/cache-service-cost-time-distribution.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/cache-service-cost-time-distribution.webp
--------------------------------------------------------------------------------
/docs/codefuse-LOGO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/codefuse-LOGO.png
--------------------------------------------------------------------------------
/docs/modelcache_modules_20231114.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/modelcache_modules_20231114.png
--------------------------------------------------------------------------------
/docs/modelcache_modules_20240409.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/modelcache_modules_20240409.png
--------------------------------------------------------------------------------
/docs/script/get_input_embedding_script.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | from transformers import AutoModelForCausalLM
5 |
6 |
7 | model_path = ''  # path to a local causal LM checkpoint; set this before running
8 | device = torch.device('cuda')
9 | model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True).to(device)
10 | embedding_weights = model.get_input_embeddings().weight.to('cpu').detach().numpy()
11 | np.save('gpt-neox-embedding.npy', embedding_weights)
12 |
--------------------------------------------------------------------------------
/docs/time-cost-comparison.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/ModelCache/e053e0d57b532d4ad9378d2f31bb85a009b77d64/docs/time-cost-comparison.webp
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/examples/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/examples/embedding/huggingface_tei_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | sys.path.append(".")
4 | from modelcache.embedding.huggingface_tei import HuggingfaceTEI
5 |
6 | '''
7 | run tei server:
8 | text-embeddings-router --model-id BAAI/bge-large-zh-v1.5 --port 8080
9 | '''
10 |
11 | def run():
12 | tei_instance = HuggingfaceTEI('http://127.0.0.1:8080/v1/embeddings', 'BAAI/bge-large-zh-v1.5')
13 | print('dimension', tei_instance.dimension)
14 | print('embedding', tei_instance.to_embeddings('hello'))
15 |
16 | if __name__ == '__main__':
17 | run()
--------------------------------------------------------------------------------
/examples/flask/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/examples/flask/llms_cache/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/examples/flask/llms_cache/data_insert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import requests
4 |
5 |
6 | def run():
7 | url = 'http://127.0.0.1:5000/modelcache'
8 | type = 'insert'
9 | scope = {"model": "CODEGPT-1117"}
10 | chat_info = [{"query": [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}],
11 | "answer": "你好,我是智能助手,请问有什么能帮您!"}]
12 | data = {'type': type, 'scope': scope, 'chat_info': chat_info}
13 | headers = {"Content-Type": "application/json"}
14 | res = requests.post(url, headers=headers, json=json.dumps(data))
15 | res_text = res.text
16 |
17 | print("data_insert:", res.status_code, res_text)
18 |
19 | if __name__ == '__main__':
20 | run()
21 |
--------------------------------------------------------------------------------
/examples/flask/llms_cache/data_query.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import requests
4 |
5 |
6 | def run():
7 | url = 'http://127.0.0.1:5000/modelcache'
8 | type = 'query'
9 | scope = {"model": "CODEGPT-1117"}
10 | query = [{"role": "system", "content": "你是一个python助手"}, {"role": "user", "content": "hello"}]
11 | data = {'type': type, 'scope': scope, 'query': query}
12 |
13 | headers = {"Content-Type": "application/json"}
14 | res = requests.post(url, headers=headers, json=json.dumps(data))
15 | res_text = res.text
16 |
17 | print("data_query:", res.status_code, res_text)
18 |
19 | if __name__ == '__main__':
20 | run()
21 |
--------------------------------------------------------------------------------
/examples/flask/llms_cache/data_query_long.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import requests
4 |
5 |
6 | def run():
7 | url = 'http://127.0.0.1:5000/modelcache'
8 | type = 'query'
9 | scope = {"model": "CODEGPT-1109"}
10 | system_content = """
11 | """
12 | user_content = """
13 | """
14 |
15 | query = [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}]
16 | data = {'type': type, 'scope': scope, 'query': query}
17 |
18 | headers = {"Content-Type": "application/json"}
19 | res = requests.post(url, headers=headers, json=json.dumps(data))
20 | res_text = res.text
21 |
22 | print("data_query_long:", res.status_code, res_text)
23 |
24 | if __name__ == '__main__':
25 | run()
26 |
--------------------------------------------------------------------------------
/examples/flask/llms_cache/register.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | register index for redis
4 | """
5 | import json
6 | import requests
7 |
8 |
9 | def run():
10 | url = 'http://127.0.0.1:5000/modelcache'
11 | type = 'register'
12 | scope = {"model": "CODEGPT-1117"}
13 | data = {'type': type, 'scope': scope}
14 | headers = {"Content-Type": "application/json"}
15 | res = requests.post(url, headers=headers, json=json.dumps(data))
16 | res_text = res.text
17 |
18 |
19 | if __name__ == '__main__':
20 | run()
--------------------------------------------------------------------------------
/examples/flask/multi_cache/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/examples/flask/multi_cache/data_insert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | import json
4 | import uuid
5 | import requests
6 |
7 |
8 | def run():
9 | url = 'http://127.0.0.1:5000/multicache'
10 |
11 | request_type = 'insert'
12 | scope = {"model": "multimodal_test"}
13 | # UUID = "820b0052-d9d8-11ee-95f1-52775e3e6fd1" + "==>" + str(time.time())
14 | UUID = str(uuid.uuid1()) + "==>" + str(time.time())
15 | img_data = "https://img0.baidu.com/it/u=1436460262,4166266890&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=282"
16 | query = {'text': ['父母带着孩子来这个地方可能会有什么顾虑'],
17 | 'imageRaw': '',
18 | 'imageUrl': img_data,
19 | 'imageId': 'ccc'}
20 | answer = "应该注意小孩不要跑到铁轨上"
21 | chat_info = [{"query": query, "answer": answer}]
22 | data_dict = {'request_type': request_type, 'scope': scope, 'chat_info': chat_info, 'UUID': UUID}
23 |
24 | headers = {"Content-Type": "application/json"}
25 | res = requests.post(url, headers=headers, json=json.dumps(data_dict))
26 | res_text = res.text
27 | print('res_text: {}'.format(res_text))
28 |
29 |
30 | if __name__ == '__main__':
31 | run()
32 |
--------------------------------------------------------------------------------
/examples/flask/multi_cache/data_query.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import requests
4 | import uuid
5 | import time
6 |
7 |
8 | def run():
9 | url = 'http://127.0.0.1:5000/multicache'
10 | request_type = 'query'
11 | UUID = str(uuid.uuid1()) + "==>" + str(time.time())
12 | scope = {"model": "multimodal_test"}
13 | img_data = "https://img0.baidu.com/it/u=1436460262,4166266890&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=282"
14 | query = {'text': ['父母带着孩子来这个地方可能会有什么顾虑'],
15 | 'imageRaw': '',
16 | 'imageUrl': img_data,
17 | 'multiType': 'IMG_TEXT'}
18 |
19 | data = {'request_type': request_type, 'scope': scope, 'query': query, 'UUID': UUID}
20 |
21 | headers = {"Content-Type": "application/json"}
22 | res = requests.post(url, headers=headers, json=json.dumps(data))
23 | res_text = res.text
24 | print('res_text: {}'.format(res_text))
25 |
26 |
27 | if __name__ == '__main__':
28 | run()
29 |
--------------------------------------------------------------------------------
/examples/flask/multi_cache/register.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | register index for redis
4 | """
5 | import json
6 | import requests
7 |
8 |
9 | def run():
10 | url = 'http://127.0.0.1:5000/multicache'
11 | request_type = 'register'
12 | scope = {"model": "multimodal_test"}
13 | type = 'IMG_TEXT'
14 | data = {'request_type': request_type, 'scope': scope, 'type': type}
15 | headers = {"Content-Type": "application/json"}
16 | res = requests.post(url, headers=headers, json=json.dumps(data))
17 | res_text = res.text
18 | print('res_text: {}'.format(res_text))
19 |
20 |
21 | if __name__ == '__main__':
22 | run()
--------------------------------------------------------------------------------
/examples/flask/multi_cache/remove.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | remove cached data for a model
4 | """
5 | import json
6 | import requests
7 |
8 |
9 | def run():
10 | url = 'http://127.0.0.1:5000/multicache'
11 | request_type = 'remove'
12 | scope = {"model": "multimodal_test"}
13 | remove_type = 'truncate_by_model'
14 | data = {'request_type': request_type, 'scope': scope, 'remove_type': remove_type}
15 |
16 | headers = {"Content-Type": "application/json"}
17 | res = requests.post(url, headers=headers, json=json.dumps(data))
18 | res_text = res.text
19 | print('res_text: {}'.format(res_text))
20 |
21 |
22 | if __name__ == '__main__':
23 | run()
24 |
--------------------------------------------------------------------------------
/fastapi4modelcache_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | import uvicorn
4 | import asyncio
5 | import logging
6 | # import configparser
7 | import json
8 | from fastapi import FastAPI, Request, HTTPException
9 | from pydantic import BaseModel
10 | from concurrent.futures import ThreadPoolExecutor
11 | from starlette.responses import PlainTextResponse
12 | import functools
13 |
14 | from modelcache import cache
15 | from modelcache.adapter import adapter
16 | from modelcache.manager import CacheBase, VectorBase, get_data_manager
17 | from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
18 | from modelcache.processor.pre import query_multi_splicing
19 | from modelcache.processor.pre import insert_multi_splicing
20 | from modelcache.utils.model_filter import model_blacklist_filter
21 | from modelcache.embedding import Data2VecAudio
22 |
23 | # Create a FastAPI instance
24 | app = FastAPI()
25 |
26 | class RequestData(BaseModel):
27 | type: str
28 | scope: dict = None
29 | query: str = None
30 | chat_info: list = None
31 | remove_type: str = None
32 | id_list: list = []
33 |
34 | data2vec = Data2VecAudio()
35 |
36 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))
37 |
38 | cache.init(
39 | embedding_func=data2vec.to_embeddings,
40 | data_manager=data_manager,
41 | similarity_evaluation=SearchDistanceEvaluation(),
42 | query_pre_embedding_func=query_multi_splicing,
43 | insert_pre_embedding_func=insert_multi_splicing,
44 | )
45 |
46 | executor = ThreadPoolExecutor(max_workers=6)
47 |
48 | # Save query info asynchronously
49 | async def save_query_info_fastapi(result, model, query, delta_time_log):
50 | loop = asyncio.get_running_loop()
51 | func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
52 | await loop.run_in_executor(None, func)
53 |
54 |
55 |
56 | @app.get("/welcome", response_class=PlainTextResponse)
57 | async def first_fastapi():
58 | return "hello, modelcache!"
59 |
60 | @app.post("/modelcache")
61 | async def user_backend(request: Request):
62 | try:
63 | raw_body = await request.body()
64 | # Parse the body string into a JSON object
65 | if isinstance(raw_body, bytes):
66 | raw_body = raw_body.decode("utf-8")
67 | if isinstance(raw_body, str):
68 | try:
69 | # Try to parse the string as a JSON object
70 | request_data = json.loads(raw_body)
71 | except json.JSONDecodeError:
72 | # If parsing fails, return a format error
73 | raise HTTPException(status_code=400, detail="Invalid JSON format")
74 | else:
75 | request_data = raw_body
76 |
77 | # Make sure request_data is a dict
78 | if isinstance(request_data, str):
79 | try:
80 | request_data = json.loads(request_data)
81 | except json.JSONDecodeError:
82 | raise HTTPException(status_code=400, detail="Invalid JSON format")
83 |
84 | request_type = request_data.get('type')
85 | model = None
86 | if 'scope' in request_data:
87 | model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
88 | query = request_data.get('query')
89 | chat_info = request_data.get('chat_info')
90 |
91 | if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']:
92 | raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']")
93 |
94 | except Exception as e:
95 | request_data = raw_body if 'raw_body' in locals() else None
96 | result = {
97 | "errorCode": 103,
98 | "errorDesc": str(e),
99 | "cacheHit": False,
100 | "delta_time": 0,
101 | "hit_query": '',
102 | "answer": '',
103 | "para_dict": request_data
104 | }
105 | return result
106 |
107 |
108 | # model filter
109 | filter_resp = model_blacklist_filter(model, request_type)
110 | if isinstance(filter_resp, dict):
111 | return filter_resp
112 |
113 | if request_type == 'query':
114 | try:
115 | start_time = time.time()
116 | response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
117 | delta_time = f"{round(time.time() - start_time, 2)}s"
118 |
119 | if response is None:
120 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
121 | elif response in ['adapt_query_exception']:
122 | # elif isinstance(response, str):
123 | result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
124 | "hit_query": '', "answer": ''}
125 | else:
126 | answer = response['data']
127 | hit_query = response['hitQuery']
128 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}
129 |
130 | delta_time_log = round(time.time() - start_time, 2)
131 | asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log))
132 | return result
133 | except Exception as e:
134 | result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
135 | "hit_query": '', "answer": ''}
136 | logging.info(f'result: {str(result)}')
137 | return result
138 |
139 | if request_type == 'insert':
140 | try:
141 | response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
142 | if response == 'success':
143 | return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
144 | else:
145 | return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
146 | except Exception as e:
147 | return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
148 |
149 | if request_type == 'remove':
150 | response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
151 | if not isinstance(response, dict):
152 | return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
153 |
154 | state = response.get('status')
155 | if state == 'success':
156 | return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
157 | else:
158 | return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
159 |
160 | # TODO: alternatively, start from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
161 | if __name__ == '__main__':
162 | uvicorn.run(app, host='0.0.0.0', port=5000)
--------------------------------------------------------------------------------
/flask4modelcache_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | from flask import Flask, request
4 | import logging
5 | import json
6 | from modelcache import cache
7 | from modelcache.adapter import adapter
8 | from modelcache.manager import CacheBase, VectorBase, get_data_manager
9 | from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
10 | from modelcache.processor.pre import query_multi_splicing
11 | from modelcache.processor.pre import insert_multi_splicing
12 | from concurrent.futures import ThreadPoolExecutor
13 | from modelcache.utils.model_filter import model_blacklist_filter
14 | from modelcache.embedding import Data2VecAudio
15 |
16 | # Create a Flask instance
17 | app = Flask(__name__)
18 |
19 |
20 | def response_text(cache_resp):
21 | return cache_resp['data']
22 |
23 |
24 | def save_query_info(result, model, query, delta_time_log):
25 | cache.data_manager.save_query_resp(result, model=model, query=json.dumps(query, ensure_ascii=False),
26 | delta_time=delta_time_log)
27 |
28 |
29 | def response_hitquery(cache_resp):
30 | return cache_resp['hitQuery']
31 |
32 |
33 | data2vec = Data2VecAudio()
34 | data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))
35 |
36 |
37 | cache.init(
38 | embedding_func=data2vec.to_embeddings,
39 | data_manager=data_manager,
40 | similarity_evaluation=SearchDistanceEvaluation(),
41 | query_pre_embedding_func=query_multi_splicing,
42 | insert_pre_embedding_func=insert_multi_splicing,
43 | )
44 |
45 | # cache.set_openai_key()
46 | global executor
47 | executor = ThreadPoolExecutor(max_workers=6)
48 |
49 |
50 | @app.route('/welcome')
51 | def first_flask():  # view function
52 | return 'hello, modelcache!'
53 |
54 |
55 | @app.route('/modelcache', methods=['GET', 'POST'])
56 | def user_backend():
57 | try:
58 | if request.method == 'POST':
59 | request_data = request.json
60 | elif request.method == 'GET':
61 | request_data = request.args
62 | param_dict = json.loads(request_data)
63 | except Exception as e:
64 | result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
65 | "answer": ''}
66 | cache.data_manager.save_query_resp(result, model='', query='', delta_time=0)
67 | return json.dumps(result)
68 |
69 | # param parsing
70 | try:
71 | request_type = param_dict.get("type")
72 | scope = param_dict.get("scope")
73 | if scope is not None:
74 | model = scope.get('model')
75 | model = model.replace('-', '_')
76 | model = model.replace('.', '_')
77 | query = param_dict.get("query")
78 | chat_info = param_dict.get("chat_info")
79 | if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
80 | result = {"errorCode": 102,
81 | "errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
82 | "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
83 | cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
84 | return json.dumps(result)
85 | except Exception as e:
86 | result = {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
87 | "answer": ''}
88 | return json.dumps(result)
89 |
90 | # model filter
91 | filter_resp = model_blacklist_filter(model, request_type)
92 | if isinstance(filter_resp, dict):
93 | return json.dumps(filter_resp)
94 |
95 | if request_type == 'query':
96 | try:
97 | start_time = time.time()
98 | response = adapter.ChatCompletion.create_query(
99 | scope={"model": model},
100 | query=query
101 | )
102 | delta_time = '{}s'.format(round(time.time() - start_time, 2))
103 | if response is None:
104 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '',
105 | "answer": ''}
106 | elif response in ['adapt_query_exception']:
107 | result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
108 | "hit_query": '', "answer": ''}
109 | else:
110 | answer = response_text(response)
111 | hit_query = response_hitquery(response)
112 | result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
113 | "hit_query": hit_query, "answer": answer}
114 | delta_time_log = round(time.time() - start_time, 2)
115 | future = executor.submit(save_query_info, result, model, query, delta_time_log)
116 | except Exception as e:
117 | result = {"errorCode": 202, "errorDesc": e, "cacheHit": False, "delta_time": 0,
118 | "hit_query": '', "answer": ''}
119 | logging.info('result: {}'.format(result))
120 | return json.dumps(result, ensure_ascii=False)
121 |
122 | if request_type == 'insert':
123 | try:
124 | try:
125 | response = adapter.ChatCompletion.create_insert(
126 | model=model,
127 | chat_info=chat_info
128 | )
129 | except Exception as e:
130 | result = {"errorCode": 303, "errorDesc": e, "writeStatus": "exception"}
131 | return json.dumps(result, ensure_ascii=False)
132 |
133 | if response in ['adapt_insert_exception']:
134 | result = {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
135 | elif response == 'success':
136 | result = {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
137 | else:
138 | result = {"errorCode": 302, "errorDesc": response,
139 | "writeStatus": "exception"}
140 | return json.dumps(result, ensure_ascii=False)
141 | except Exception as e:
142 | result = {"errorCode": 304, "errorDesc": e, "writeStatus": "exception"}
143 | return json.dumps(result, ensure_ascii=False)
144 |
145 | if request_type == 'remove':
146 | remove_type = param_dict.get("remove_type")
147 | id_list = param_dict.get("id_list", [])
148 |
149 | response = adapter.ChatCompletion.create_remove(
150 | model=model,
151 | remove_type=remove_type,
152 | id_list=id_list
153 | )
154 |
155 | if not isinstance(response, dict):
156 | result = {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
157 | return json.dumps(result)
158 |
159 | state = response.get('status')
160 | if state == 'success':
161 | result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
162 | else:
163 | result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
164 | return json.dumps(result)
165 |
166 |
167 | if __name__ == '__main__':
168 | # app.run(host='0.0.0.0', port=5000, debug=True)
169 | app.run(host='0.0.0.0', port=5000)
170 |
--------------------------------------------------------------------------------
/model/clip_zh/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/model/text2vec-base-chinese/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "hfl/chinese-macbert-base",
3 | "architectures": [
4 | "BertModel"
5 | ],
6 | "attention_probs_dropout_prob": 0.1,
7 | "classifier_dropout": null,
8 | "directionality": "bidi",
9 | "gradient_checkpointing": false,
10 | "hidden_act": "gelu",
11 | "hidden_dropout_prob": 0.1,
12 | "hidden_size": 768,
13 | "initializer_range": 0.02,
14 | "intermediate_size": 3072,
15 | "layer_norm_eps": 1e-12,
16 | "max_position_embeddings": 512,
17 | "model_type": "bert",
18 | "num_attention_heads": 12,
19 | "num_hidden_layers": 12,
20 | "pad_token_id": 0,
21 | "pooler_fc_size": 768,
22 | "pooler_num_attention_heads": 12,
23 | "pooler_num_fc_layers": 3,
24 | "pooler_size_per_head": 128,
25 | "pooler_type": "first_token_transform",
26 | "position_embedding_type": "absolute",
27 | "torch_dtype": "float32",
28 | "transformers_version": "4.12.3",
29 | "type_vocab_size": 2,
30 | "use_cache": true,
31 | "vocab_size": 21128
32 | }
33 |
--------------------------------------------------------------------------------
/model/text2vec-base-chinese/logs.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/text2vec-base-chinese/modules.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "idx": 0,
4 | "name": "0",
5 | "path": "",
6 | "type": "sentence_transformers.models.Transformer"
7 | },
8 | {
9 | "idx": 1,
10 | "name": "1",
11 | "path": "1_Pooling",
12 | "type": "sentence_transformers.models.Pooling"
13 | }
14 | ]
15 |
--------------------------------------------------------------------------------
/model/text2vec-base-chinese/sentence_bert_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_seq_length": 128,
3 | "do_lower_case": false
4 | }
5 |
--------------------------------------------------------------------------------
/model/text2vec-base-chinese/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
--------------------------------------------------------------------------------
/model/text2vec-base-chinese/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "name_or_path": "hfl/chinese-macbert-base", "tokenizer_class": "BertTokenizer"}
2 |
--------------------------------------------------------------------------------
/modelcache/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.core import Cache
3 | from modelcache.core import cache
4 | from modelcache.config import Config
5 | import modelcache.adapter
6 |
--------------------------------------------------------------------------------
/modelcache/adapter/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modelcache/adapter/adapter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | from modelcache.adapter.adapter_query import adapt_query
4 | from modelcache.adapter.adapter_insert import adapt_insert
5 | from modelcache.adapter.adapter_remove import adapt_remove
6 | from modelcache.adapter.adapter_register import adapt_register
7 |
8 |
9 | class ChatCompletion(object):
10 | """Openai ChatCompletion Wrapper"""
11 |
12 | @classmethod
13 | def create_query(cls, *args, **kwargs):
14 | def cache_data_convert(cache_data, cache_query):
15 | return construct_resp_from_cache(cache_data, cache_query)
16 | try:
17 | return adapt_query(
18 | cache_data_convert,
19 | *args,
20 | **kwargs
21 | )
22 | except Exception as e:
23 | return str(e)
24 |
25 | @classmethod
26 | def create_insert(cls, *args, **kwargs):
27 | try:
28 | return adapt_insert(
29 | *args,
30 | **kwargs
31 | )
32 | except Exception as e:
33 | return str(e)
34 |
35 | @classmethod
36 | def create_remove(cls, *args, **kwargs):
37 | try:
38 | return adapt_remove(
39 | *args,
40 | **kwargs
41 | )
42 | except Exception as e:
43 | logging.info('adapt_remove_e: {}'.format(e))
44 | return str(e)
45 |
46 | @classmethod
47 | def create_register(cls, *args, **kwargs):
48 | try:
49 | return adapt_register(
50 | *args,
51 | **kwargs
52 | )
53 | except Exception as e:
54 | return str(e)
55 |
56 |
57 | def construct_resp_from_cache(return_message, return_query):
58 | return {
59 | "modelcache": True,
60 | "hitQuery": return_query,
61 | "data": return_message,
62 | "errorCode": 0
63 | }
64 |
--------------------------------------------------------------------------------
/modelcache/adapter/adapter_insert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache import cache
3 | from modelcache.utils.error import NotInitError
4 | from modelcache.utils.time import time_cal
5 |
6 |
7 | def adapt_insert(*args, **kwargs):
8 | chat_cache = kwargs.pop("cache_obj", cache)
9 | model = kwargs.pop("model", None)
10 | require_object_store = kwargs.pop("require_object_store", False)
11 | if require_object_store:
12 | assert chat_cache.data_manager.o, "Object store is required for adapter."
13 | if not chat_cache.has_init:
14 | raise NotInitError()
15 | cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
16 | context = kwargs.pop("cache_context", {})
17 | embedding_data = None
18 | pre_embedding_data = chat_cache.insert_pre_embedding_func(
19 | kwargs,
20 | extra_param=context.get("pre_embedding_func", None),
21 | prompts=chat_cache.config.prompts,
22 | )
23 | chat_info = kwargs.pop("chat_info", [])
24 | llm_data = chat_info[-1]['answer']
25 |
26 | if cache_enable:
27 | embedding_data = time_cal(
28 | chat_cache.embedding_func,
29 | func_name="embedding",
30 | report_func=chat_cache.report.embedding,
31 | )(pre_embedding_data)
32 |
33 | chat_cache.data_manager.save(
34 | pre_embedding_data,
35 | llm_data,
36 | embedding_data,
37 | model=model,
38 | extra_param=context.get("save_func", None)
39 | )
40 | return 'success'
41 |
--------------------------------------------------------------------------------
/modelcache/adapter/adapter_register.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache import cache
3 |
4 |
5 | def adapt_register(*args, **kwargs):
6 | chat_cache = kwargs.pop("cache_obj", cache)
7 | model = kwargs.pop("model", None)
8 | if model is None or len(model) == 0:
9 | raise ValueError('model must be a non-empty string')
10 |
11 | register_resp = chat_cache.data_manager.create_index(model)
12 | return register_resp
13 |
--------------------------------------------------------------------------------
/modelcache/adapter/adapter_remove.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache import cache
3 | from modelcache.utils.error import NotInitError, RemoveError
4 |
5 |
6 | def adapt_remove(*args, **kwargs):
7 | chat_cache = kwargs.pop("cache_obj", cache)
8 | model = kwargs.pop("model", None)
9 | remove_type = kwargs.pop("remove_type", None)
10 | require_object_store = kwargs.pop("require_object_store", False)
11 | if require_object_store:
12 | assert chat_cache.data_manager.o, "Object store is required for adapter."
13 | if not chat_cache.has_init:
14 | raise NotInitError()
15 |
16 | # delete data
17 | if remove_type == 'delete_by_id':
18 | id_list = kwargs.pop("id_list", [])
19 | resp = chat_cache.data_manager.delete(id_list, model=model)
20 | elif remove_type == 'truncate_by_model':
21 | resp = chat_cache.data_manager.truncate(model)
22 | else:
23 | # resp = "remove_type_error"
24 | raise RemoveError()
25 | return resp
26 |
27 |
--------------------------------------------------------------------------------
/modelcache/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Optional, Callable, List
3 | from modelcache.utils.error import CacheError
4 |
5 |
6 | class Config:
7 |
8 | def __init__(
9 | self,
10 | log_time_func: Optional[Callable[[str, float], None]] = None,
11 | similarity_threshold: float = 0.95,
12 | similarity_threshold_long: float = 0.95,
13 | prompts: Optional[List[str]] = None
14 | ):
15 | if similarity_threshold < 0 or similarity_threshold > 1:
16 | raise CacheError(
17 | "Invalid the similarity threshold param, reasonable range: 0-1"
18 | )
19 | self.log_time_func = log_time_func
20 | self.similarity_threshold = similarity_threshold
21 | self.similarity_threshold_long = similarity_threshold_long
22 | self.prompts = prompts
23 |
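A short construction sketch; values outside [0, 1] for similarity_threshold raise CacheError (the concrete numbers below are illustrative):

    from modelcache.config import Config

    config = Config(similarity_threshold=0.9, similarity_threshold_long=0.85)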
--------------------------------------------------------------------------------
/modelcache/config/chromadb_config.ini:
--------------------------------------------------------------------------------
1 | [chromadb]
2 | persist_directory=''
3 |
--------------------------------------------------------------------------------
/modelcache/config/elasticsearch_config.ini:
--------------------------------------------------------------------------------
1 | [elasticsearch]
2 | host = ''
3 | port = ''
4 | user = ''
5 | password = ''
--------------------------------------------------------------------------------
/modelcache/config/milvus_config.ini:
--------------------------------------------------------------------------------
1 | [milvus]
2 | host = milvus
3 | port = 19530
4 | user = ''
5 | password = ''
--------------------------------------------------------------------------------
/modelcache/config/mysql_config.ini:
--------------------------------------------------------------------------------
1 | [mysql]
2 | host = mysql
3 | port = 3306
4 | username = modelcache
5 | password = modelcache
6 | database = modelcache
7 |
--------------------------------------------------------------------------------
/modelcache/config/redis_config.ini:
--------------------------------------------------------------------------------
1 | [redis]
2 | host = ''
3 | port = ''
4 | user = ''
5 | password = ''
6 |
--------------------------------------------------------------------------------
/modelcache/core.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import atexit
3 | from typing import Optional, List, Any
4 | from modelcache.processor.post import first
5 | from modelcache.similarity_evaluation import ExactMatchEvaluation
6 | from modelcache.similarity_evaluation import SimilarityEvaluation
7 | from modelcache.embedding.string_text import to_embeddings as string_embedding
8 | from modelcache.report import Report
9 | from modelcache.config import Config
10 | from modelcache.utils.cache_func import cache_all
11 | from modelcache.utils.log import modelcache_log
12 | from modelcache.manager import get_data_manager
13 | from modelcache.manager.data_manager import DataManager
14 |
15 |
16 | class Cache:
17 | def __init__(self):
18 | self.has_init = False
19 | self.cache_enable_func = None
20 | self.query_pre_embedding_func = None
21 | self.insert_pre_embedding_func = None
22 | self.mm_query_pre_embedding_func = None
23 | self.mm_insert_pre_embedding_func = None
24 | self.embedding_func = None
25 | self.embedding_concurrent_func = None
26 | self.data_manager: Optional[DataManager] = None
27 | self.similarity_evaluation: Optional[SimilarityEvaluation] = None
28 | self.post_process_messages_func = None
29 | self.config = Config()
30 | self.report = Report()
31 | self.next_cache = None
32 |
33 | def init(
34 | self,
35 | cache_enable_func=cache_all,
36 | query_pre_embedding_func=None,
37 | insert_pre_embedding_func=None,
38 | embedding_func=string_embedding,
39 | data_manager: DataManager = get_data_manager(),
40 | similarity_evaluation=ExactMatchEvaluation(),
41 | post_process_messages_func=first,
42 | config=Config(),
43 | next_cache=None,
44 | ):
45 | self.has_init = True
46 | self.cache_enable_func = cache_enable_func
47 | self.query_pre_embedding_func = query_pre_embedding_func
48 | self.insert_pre_embedding_func = insert_pre_embedding_func
49 | self.embedding_func = embedding_func
50 | self.data_manager: DataManager = data_manager
51 | self.similarity_evaluation = similarity_evaluation
52 | self.post_process_messages_func = post_process_messages_func
53 | self.config = config
54 | self.next_cache = next_cache
55 |
56 | @atexit.register
57 | def close():
58 | try:
59 | self.data_manager.close()
60 | except Exception as e:
61 | modelcache_log.error(e)
62 |
63 | def import_data(self, questions: List[Any], answers: List[Any]) -> None:
64 | self.data_manager.import_data(
65 | questions=questions,
66 | answers=answers,
67 | embedding_datas=[self.embedding_func(question) for question in questions],
68 | )
69 |
70 | def flush(self):
71 | self.data_manager.flush()
72 | if self.next_cache:
73 | self.next_cache.data_manager.flush()
74 |
75 |
76 | cache = Cache()
77 |
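A minimal initialization sketch, loosely following the demo servers; the choice of Data2VecAudio plus SQLite/FAISS is illustrative, not the only supported wiring:

    from modelcache import cache
    from modelcache.embedding import Data2VecAudio
    from modelcache.manager import CacheBase, VectorBase, get_data_manager

    data2vec = Data2VecAudio()
    cache.init(
        embedding_func=data2vec.to_embeddings,
        data_manager=get_data_manager(
            CacheBase("sqlite"),
            VectorBase("faiss", dimension=data2vec.dimension),
        ),
    )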
--------------------------------------------------------------------------------
/modelcache/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 | huggingface = LazyImport("huggingface", globals(), "modelcache.embedding.huggingface")
4 | data2vec = LazyImport("data2vec", globals(), "modelcache.embedding.data2vec")
5 | llmEmb = LazyImport("llmEmb", globals(), "modelcache.embedding.llmEmb")
6 | fasttext = LazyImport("fasttext", globals(), "modelcache.embedding.fasttext")
7 | paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp")
8 | timm = LazyImport("timm", globals(), "modelcache.embedding.timm")
9 | huggingface_tei = LazyImport("huggingface_tei", globals(), "modelcache.embedding.huggingface_tei")
10 | bge_m3 = LazyImport("bge_m3", globals(), "modelcache.embedding.bge_m3")
11 |
12 |
13 | def Huggingface(model="sentence-transformers/all-mpnet-base-v2"):
14 | return huggingface.Huggingface(model)
15 |
16 |
17 | def Data2VecAudio(model="model/text2vec-base-chinese/"):
18 | return data2vec.Data2VecAudio(model)
19 |
20 |
21 | def LlmEmb2vecAudio():
22 | return llmEmb.LlmEmb2Vec()
23 |
24 |
25 | def FastText(model="en", dim=None):
26 | return fasttext.FastText(model, dim)
27 |
28 |
29 | def PaddleNLP(model="ernie-3.0-medium-zh"):
30 | return paddlenlp.PaddleNLP(model)
31 |
32 |
33 | def Timm(model="resnet50", device="default"):
34 | return timm.Timm(model, device)
35 |
36 | def HuggingfaceTEI(base_url, model):
37 | return huggingface_tei.HuggingfaceTEI(base_url, model)
38 |
39 | def BgeM3Embedding(model_path="model/bge-m3"):
40 | return bge_m3.BgeM3Embedding(model_path)
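These helpers defer the heavy backend imports until one is actually chosen. A usage sketch, assuming the bundled model files under model/text2vec-base-chinese/ are present:

    from modelcache.embedding import Data2VecAudio

    emb = Data2VecAudio()                    # defaults to model/text2vec-base-chinese/
    vec = emb.to_embeddings("hello world")   # float32 vector of length emb.dimension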
--------------------------------------------------------------------------------
/modelcache/embedding/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseEmbedding(metaclass=ABCMeta):
6 | """
7 | _Embedding base.
8 | """
9 |
10 | @abstractmethod
11 | def to_embeddings(self, data, **kwargs):
12 | pass
13 |
14 | @property
15 | @abstractmethod
16 | def dimension(self) -> int:
17 | return 0
18 |
--------------------------------------------------------------------------------
/modelcache/embedding/bge_m3.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from modelcache.embedding.base import BaseEmbedding
4 | from transformers import AutoTokenizer, AutoModel
5 | from FlagEmbedding import BGEM3FlagModel
6 |
7 | class BgeM3Embedding(BaseEmbedding):
8 | def __init__(self, model_path: str = "model/bge-m3"):
9 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
10 | self.model = AutoModel.from_pretrained(model_path)
11 |
12 | self.bge_model = BGEM3FlagModel(model_name_or_path=model_path,
13 | model=self.model,
14 | tokenizer=self.tokenizer,
15 | use_fp16=False)
16 |
17 | self.__dimension = 768
18 |
19 | def to_embeddings(self, data, **_):
20 | if not isinstance(data, list):
21 | data = [data]
22 |
23 | embeddings = self.bge_model.encode(data, batch_size=12, max_length=8192)['dense_vecs']
24 | return np.array(embeddings).astype("float32")
25 |
26 | @property
27 | def dimension(self):
28 | return self.__dimension
--------------------------------------------------------------------------------
/modelcache/embedding/data2vec.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | from transformers import BertTokenizer, BertModel
7 | from modelcache.embedding.base import BaseEmbedding
8 |
9 |
10 | def mean_pooling(model_output, attention_mask):
11 | token_embeddings = model_output[0]
12 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
13 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
14 |
15 |
16 | class Data2VecAudio(BaseEmbedding):
17 | def __init__(self, model):
18 | current_dir = os.path.dirname(os.path.abspath(__file__))
19 | parent_dir = os.path.dirname(current_dir)
20 | model_dir = os.path.dirname(parent_dir)
21 | model_path = os.path.join(model_dir, model)
22 |
23 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
24 | self.tokenizer = BertTokenizer.from_pretrained(model_path, local_files_only=True)
25 | self.model = BertModel.from_pretrained(model_path, local_files_only=True)
26 |
27 | try:
28 | self.__dimension = self.model.config.hidden_size
29 | except Exception:
30 | from transformers import AutoConfig
31 | config = AutoConfig.from_pretrained(model)
32 | self.__dimension = config.hidden_size
33 |
34 | def to_embeddings(self, data, **_):
35 | encoded_input = self.tokenizer(data, padding=True, truncation=True, return_tensors='pt')
36 | num_tokens = sum(map(len, encoded_input['input_ids']))
37 |
38 | if num_tokens <= 512:
39 | with torch.no_grad():
40 | encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
41 | model_output = self.model(**encoded_input)
42 | sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
43 | sentence_embeddings = sentence_embeddings.squeeze(0).detach().cpu().numpy()
44 | embedding_array = np.array(sentence_embeddings).astype("float32")
45 | return embedding_array
46 | else:
47 | window_size = 510
48 | start = 0
49 | input_ids = encoded_input['input_ids']
50 | input_ids = input_ids[:, 1:-1]
51 | start_token = self.tokenizer.cls_token
52 | end_token = self.tokenizer.sep_token
53 | start_token_id = self.tokenizer.convert_tokens_to_ids(start_token)
54 | end_token_id = self.tokenizer.convert_tokens_to_ids(end_token)
55 | begin_element = torch.tensor([[start_token_id]])
56 | end_element = torch.tensor([[end_token_id]])
57 |
58 | embedding_array_list = list()
59 | while start < num_tokens:
60 | # Calculate the ending position of the sliding window.
61 | end = start + window_size
62 | # If the ending position exceeds the length, adjust it to the length.
63 | if end > num_tokens:
64 | end = num_tokens
65 | # Retrieve the data within the sliding window.
66 | input_ids_window = input_ids[:, start:end]
67 | # Insert a new element at position 0.
68 | input_ids_window = torch.cat([begin_element, input_ids_window[:, 0:]], dim=1)
69 | # Insert a new element at the last position.
70 | input_ids_window = torch.cat([input_ids_window, end_element], dim=1)
71 | input_ids_window_length = sum(map(len, input_ids_window))
72 | token_type_ids = torch.tensor([[0] * input_ids_window_length])
73 | attention_mask = torch.tensor([[1] * input_ids_window_length])
74 |
75 | # Concatenate new input_ids
76 | encoded_input_window = {'input_ids': input_ids_window, 'token_type_ids': token_type_ids,
77 | 'attention_mask': attention_mask}
78 | with torch.no_grad():
79 | encoded_input_window = {k: v.to(self.device) for k, v in encoded_input_window.items()}
80 | model_output_window = self.model(**encoded_input_window)
81 |
82 | sentence_embeddings_window = mean_pooling(model_output_window, encoded_input_window['attention_mask'])
83 | sentence_embeddings_window = sentence_embeddings_window.squeeze(0).detach().cpu().numpy()
84 | embedding_array_window = np.array(sentence_embeddings_window).astype("float32")
85 | embedding_array_list.append(embedding_array_window)
86 | start = end
87 |
88 | embedding_array = np.mean(embedding_array_list, axis=0)
89 | return embedding_array
90 |
91 | def post_proc(self, token_embeddings, inputs):
92 | attention_mask = inputs["attention_mask"]
93 | input_mask_expanded = (
94 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
95 | )
96 | sentence_embs = torch.sum(
97 | token_embeddings * input_mask_expanded, 1
98 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
99 | return sentence_embs
100 |
101 | @property
102 | def dimension(self):
103 | """Embedding dimension.
104 |
105 | :return: embedding dimension
106 | """
107 | return self.__dimension
108 |
--------------------------------------------------------------------------------
/modelcache/embedding/fasttext.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import os
4 | from modelcache.utils import import_fasttext
5 | from modelcache.embedding.base import BaseEmbedding
6 | import_fasttext()
7 | import fasttext.util
8 |
9 |
10 | class FastText(BaseEmbedding):
11 | def __init__(self, model: str = "en", dim: int = None):
12 | self.model_path = os.path.abspath(fasttext.util.download_model(model))
13 | self.ft = fasttext.load_model(self.model_path)
14 |
15 | if dim:
16 | fasttext.util.reduce_model(self.ft, dim)
17 | self.__dimension = self.ft.get_dimension()
18 |
19 | def to_embeddings(self, data, **_):
20 | assert isinstance(data, str), "Only allow string as input."
21 | emb = self.ft.get_sentence_vector(data)
22 | return np.array(emb).astype("float32")
23 |
24 | @property
25 | def dimension(self):
26 | return self.__dimension
27 |
28 |
--------------------------------------------------------------------------------
/modelcache/embedding/huggingface.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from modelcache.utils import import_huggingface, import_torch
5 | from modelcache.embedding.base import BaseEmbedding
6 |
7 | import_torch()
8 | import_huggingface()
9 |
10 | import torch # pylint: disable=C0413
11 | from transformers import AutoTokenizer, AutoModel # pylint: disable=C0413
12 |
13 |
14 | class Huggingface(BaseEmbedding):
15 | def __init__(self, model: str = "sentence-transformers/all-MiniLM-L6-v2"):
16 | self.model = AutoModel.from_pretrained(model, local_files_only=True)
17 | self.model.eval()
18 |
19 | # self.tokenizer = AutoTokenizer.from_pretrained(model)
20 | self.tokenizer = AutoTokenizer.from_pretrained(model, local_files_only=True)
21 | if not self.tokenizer.pad_token:
22 | self.tokenizer.pad_token = "[PAD]"
23 | try:
24 | self.__dimension = self.model.config.hidden_size
25 | except Exception: # pylint: disable=W0703
26 | from transformers import AutoConfig # pylint: disable=C0415
27 |
28 | config = AutoConfig.from_pretrained(model)
29 | self.__dimension = config.hidden_size
30 |
31 | def to_embeddings(self, data, **_):
32 | """Generate embedding given text input
33 |
34 | :param data: text in string.
35 | :type data: str
36 |
37 | :return: a text embedding in shape of (dim,).
38 | """
39 | if not isinstance(data, list):
40 | data = [data]
41 | inputs = self.tokenizer(
42 | data, padding=True, truncation=True, return_tensors="pt"
43 | )
44 | outs = self.model(**inputs).last_hidden_state
45 | emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy()
46 | return np.array(emb).astype("float32")
47 |
48 | def post_proc(self, token_embeddings, inputs):
49 | attention_mask = inputs["attention_mask"]
50 | input_mask_expanded = (
51 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
52 | )
53 | sentence_embs = torch.sum(
54 | token_embeddings * input_mask_expanded, 1
55 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
56 | return sentence_embs
57 |
58 | @property
59 | def dimension(self):
60 | """Embedding dimension.
61 |
62 | :return: embedding dimension
63 | """
64 | return self.__dimension
65 |
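Because of local_files_only=True, the model must already be available on disk (or in the local Hugging Face cache). A usage sketch under that assumption:

    from modelcache.embedding import Huggingface

    hf = Huggingface("sentence-transformers/all-MiniLM-L6-v2")
    vec = hf.to_embeddings("hello")          # shape: (hf.dimension,)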
--------------------------------------------------------------------------------
/modelcache/embedding/huggingface_tei.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import requests
3 | import numpy as np
4 | from modelcache.embedding.base import BaseEmbedding
5 |
6 | class HuggingfaceTEI(BaseEmbedding):
7 | def __init__(self, base_url: str, model: str):
8 | self.base_url = base_url
9 | self.model = model
10 | self.headers = {
11 | 'accept': 'application/json',
12 | 'Content-Type': 'application/json',
13 | }
14 | self.__dimension = self.to_embeddings('test').shape[0]
15 |
16 | def to_embeddings(self, data, **_):
17 | json_data = {
18 | 'input': data,
19 | 'model': self.model,
20 | }
21 |
22 | response = requests.post(self.base_url, headers=self.headers, json=json_data)
23 | embedding = response.json()['data'][0]['embedding']
24 | return np.array(embedding)
25 |
26 | @property
27 | def dimension(self):
28 | """Embedding dimension.
29 |
30 | :return: embedding dimension
31 | """
32 | return self.__dimension
33 |
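The class posts an OpenAI-style embeddings payload, so base_url should point at a Text Embeddings Inference (or compatible) endpoint. A sketch with an assumed local server URL and model id:

    from modelcache.embedding import HuggingfaceTEI

    tei = HuggingfaceTEI(
        base_url="http://127.0.0.1:8080/v1/embeddings",  # assumed local TEI endpoint
        model="BAAI/bge-large-zh-v1.5",                  # assumed model id
    )
    vec = tei.to_embeddings("hello")                     # length == tei.dimension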
--------------------------------------------------------------------------------
/modelcache/embedding/llmEmb.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from modelcache.embedding.base import BaseEmbedding
4 | from transformers import AutoTokenizer
5 | from transformers import AutoConfig
6 |
7 |
8 | class LlmEmb2Vec(BaseEmbedding):
9 | def __init__(self):
10 |
11 | self.model_name = '' # 13b-mft-embedding.npy
12 | model_path = '' # .npy file storage path
13 | model_file = model_path + self.model_name # .npy file
14 | config = AutoConfig.from_pretrained(model_path)
15 | dimension = config.hidden_size
16 | self.__dimension = dimension
17 | self.model = np.load(model_file)
18 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
19 |
20 | def to_embeddings(self, data, **_):
21 | """Generate embedding given text input
22 |
23 | :param data: text in string.
24 | :return: a text embedding in shape of (dim,).
25 | """
26 | input_ids = self.tokenizer.encode(data, add_special_tokens=True)
27 | embedding_array = self.model[input_ids].mean(axis=0)
28 | return embedding_array
29 |
30 | def post_proc(self, token_embeddings, inputs):
31 | pass
32 |
33 | @property
34 | def dimension(self):
35 | """Embedding dimension.
36 | :return: embedding dimension
37 | """
38 | return self.__dimension
39 |
--------------------------------------------------------------------------------
/modelcache/embedding/onnx.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from modelcache.embedding.base import BaseEmbedding
5 | from modelcache.utils import (
6 | import_onnxruntime,
7 | import_huggingface_hub,
8 | import_huggingface,
9 | )
10 |
11 | import_huggingface()
12 | import_onnxruntime()
13 | import_huggingface_hub()
14 |
15 | from transformers import AutoTokenizer, AutoConfig # pylint: disable=C0413
16 | import onnxruntime
17 | from modelcache.utils.env_config import get_onnx_tokenizer_path, get_onnx_model
18 |
19 |
20 | class Onnx(BaseEmbedding):
21 |
22 | def __init__(self, model="modelcache_open/paraphrase-albert-onnx"):
23 | # load the tokenizer from the local path
24 | onnx_tokenizer = get_onnx_tokenizer_path()
25 | self.tokenizer = AutoTokenizer.from_pretrained(onnx_tokenizer, local_files_only=True)
26 | # load the ONNX model from the local path
27 | onnx_model = get_onnx_model()
28 | self.ort_session = onnxruntime.InferenceSession(onnx_model)
29 |
30 | config = AutoConfig.from_pretrained(onnx_tokenizer, local_files_only=True)
31 | self.__dimension = config.hidden_size
32 |
33 | def to_embeddings(self, data, **_):
34 | """Generate embedding given text input.
35 |
36 | :param data: text in string.
37 | :type data: str
38 |
39 | :return: a text embedding in shape of (dim,).
40 | """
41 | encoded_text = self.tokenizer.encode_plus(data, padding="max_length")
42 | ort_inputs = {
43 | "input_ids": np.array(encoded_text["input_ids"]).reshape(1, -1),
44 | "attention_mask": np.array(encoded_text["attention_mask"]).reshape(1, -1),
45 | "token_type_ids": np.array(encoded_text["token_type_ids"]).reshape(1, -1),
46 | }
47 |
48 | ort_outputs = self.ort_session.run(None, ort_inputs)
49 | ort_feat = ort_outputs[0]
50 | emb = self.post_proc(ort_feat, ort_inputs["attention_mask"])
51 | return emb.flatten()
52 |
53 | def post_proc(self, token_embeddings, attention_mask):
54 | input_mask_expanded = (
55 | np.expand_dims(attention_mask, -1)
56 | .repeat(token_embeddings.shape[-1], -1)
57 | .astype(float)
58 | )
59 | sentence_embs = np.sum(token_embeddings * input_mask_expanded, 1) / np.maximum(
60 | input_mask_expanded.sum(1), 1e-9
61 | )
62 | return sentence_embs
63 |
64 | @property
65 | def dimension(self):
66 | """Embedding dimension.
67 |
68 | :return: embedding dimension
69 | """
70 | return self.__dimension
71 |
--------------------------------------------------------------------------------
/modelcache/embedding/paddlenlp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from modelcache.embedding.base import BaseEmbedding
5 | from modelcache.utils import import_paddlenlp, import_paddle
6 |
7 | import_paddle()
8 | import_paddlenlp()
9 |
10 |
11 | import paddle # pylint: disable=C0413
12 | from paddlenlp.transformers import AutoModel, AutoTokenizer # pylint: disable=C0413
13 |
14 |
15 | class PaddleNLP(BaseEmbedding):
16 | def __init__(self, model: str = "ernie-3.0-medium-zh"):
17 | self.model = AutoModel.from_pretrained(model)
18 | self.model.eval()
19 |
20 | self.tokenizer = AutoTokenizer.from_pretrained(model)
21 | if not self.tokenizer.pad_token:
22 | self.tokenizer.pad_token = ""
23 | self.__dimension = None
24 |
25 | def to_embeddings(self, data, **_):
26 | """Generate embedding given text input
27 |
28 | :param data: text in string.
29 | :type data: str
30 |
31 | :return: a text embedding in shape of (dim,).
32 | """
33 | if not isinstance(data, list):
34 | data = [data]
35 | inputs = self.tokenizer(
36 | data, padding=True, truncation=True, return_tensors="pd"
37 | )
38 | outs = self.model(**inputs)[0]
39 | emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy()
40 | return np.array(emb).astype("float32")
41 |
42 | def post_proc(self, token_embeddings, inputs):
43 | attention_mask = paddle.ones(inputs["token_type_ids"].shape)
44 | input_mask_expanded = (
45 | attention_mask.unsqueeze(-1).expand(token_embeddings.shape).astype("float32")
46 | )
47 | sentence_embs = paddle.sum(
48 | token_embeddings * input_mask_expanded, 1
49 | ) / paddle.clip(input_mask_expanded.sum(1), min=1e-9)
50 | return sentence_embs
51 |
52 | @property
53 | def dimension(self):
54 | """Embedding dimension.
55 |
56 | :return: embedding dimension
57 | """
58 | if not self.__dimension:
59 | self.__dimension = len(self.to_embeddings("foo"))
60 | return self.__dimension
61 |
--------------------------------------------------------------------------------
/modelcache/embedding/string_text.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | def to_embeddings(data, **_):
5 | return data
6 |
--------------------------------------------------------------------------------
/modelcache/embedding/timm_embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from modelcache.utils import import_timm, import_torch, import_pillow
5 | from modelcache.embedding.base import BaseEmbedding
6 |
7 | import_torch()
8 | import_timm()
9 | import_pillow()
10 |
11 | import torch # pylint: disable=C0413
12 | from timm.models import create_model # pylint: disable=C0413
13 | from timm.data import create_transform, resolve_data_config # pylint: disable=C0413
14 | from PIL import Image # pylint: disable=C0413
15 |
16 |
17 | class Timm(BaseEmbedding):
18 | def __init__(self, model: str = "resnet18", device: str = "default"):
19 | if device == "default":
20 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
21 | else:
22 | self.device = device
23 | self.model_name = model
24 | self.model = create_model(model_name=model, pretrained=True)
25 | self.model.eval()
26 |
27 | try:
28 | self.__dimension = self.model.embed_dim
29 | except Exception: # pylint: disable=W0703
30 | self.__dimension = None
31 |
32 | def to_embeddings(self, data, skip_preprocess: bool = False, **_):
33 | if not skip_preprocess:
34 | data = self.preprocess(data)
35 | if data.dim() == 3:
36 | data = data.unsqueeze(0)
37 | feats = self.model.forward_features(data)
38 | emb = self.post_proc(feats).squeeze(0).detach().numpy()
39 |
40 | return np.array(emb).astype("float32")
41 |
42 | def post_proc(self, features):
43 | features = features.to("cpu")
44 | if features.dim() == 3:
45 | features = features[:, 0]
46 | if features.dim() == 4:
47 | global_pool = torch.nn.AdaptiveAvgPool2d(1)
48 | features = global_pool(features)
49 | features = features.flatten(1)
50 | assert features.dim() == 2, f"Invalid output dim {features.dim()}"
51 | return features
52 |
53 | def preprocess(self, image_path):
54 | data_cfg = resolve_data_config(self.model.pretrained_cfg)
55 | transform = create_transform(**data_cfg)
56 |
57 | image = Image.open(image_path).convert("RGB")
58 | image_tensor = transform(image)
59 | return image_tensor
60 |
61 | @property
62 | def dimension(self):
63 | """Embedding dimension.
64 |
65 | :return: embedding dimension
66 | """
67 | if not self.__dimension:
68 | input_size = self.model.pretrained_cfg["input_size"]
69 | dummy_input = torch.rand((1,) + input_size)
70 | feats = self.to_embeddings(dummy_input, skip_preprocess=True)
71 | self.__dimension = feats.shape[0]
72 | return self.__dimension
73 |
--------------------------------------------------------------------------------
/modelcache/manager/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.manager.scalar_data import CacheBase
3 | from modelcache.manager.vector_data import VectorBase
4 | from modelcache.manager.object_data import ObjectBase
5 | from modelcache.manager.factory import get_data_manager
6 |
--------------------------------------------------------------------------------
/modelcache/manager/eviction/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 |
4 | eviction_manager = LazyImport(
5 | "eviction_manager", globals(), "modelcache.manager.eviction.manager"
6 | )
7 |
8 |
9 | def EvictionBase(name: str, **kwargs):
10 | return eviction_manager.EvictionBase.get(name, **kwargs)
11 |
--------------------------------------------------------------------------------
/modelcache/manager/eviction/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Any, List
4 |
5 |
6 | class EvictionBase(metaclass=ABCMeta):
7 | """
8 | Eviction base.
9 | """
10 |
11 | @abstractmethod
12 | def put(self, objs: List[Any]):
13 | pass
14 |
15 | @abstractmethod
16 | def get(self, obj: Any):
17 | pass
18 |
19 | @property
20 | @abstractmethod
21 | def policy(self) -> str:
22 | pass
23 |
--------------------------------------------------------------------------------
/modelcache/manager/eviction/manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Callable, List, Any
3 | from modelcache.utils.error import NotFoundError
4 |
5 |
6 | class EvictionBase:
7 | """
8 | EvictionBase to evict the cache data.
9 | """
10 |
11 | def __init__(self):
12 | raise EnvironmentError(
13 | "EvictionBase is designed to be instantiated, "
14 | "please using the `EvictionBase.get(name, policy, maxsize, clean_size)`."
15 | )
16 |
17 | @staticmethod
18 | def get(name: str, policy: str, maxsize: int, clean_size: int, on_evict: Callable[[List[Any]], None], **kwargs):
19 | if name == "memory":
20 | from modelcache.manager.eviction.memory_cache import MemoryCacheEviction
21 |
22 | eviction_base = MemoryCacheEviction(policy, maxsize, clean_size, on_evict, **kwargs)
23 | else:
24 | raise NotFoundError("eviction base", name)
25 | return eviction_base
26 |
--------------------------------------------------------------------------------
/modelcache/manager/eviction/memory_cache.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any, Callable, List
3 | import cachetools
4 |
5 | from modelcache.manager.eviction.base import EvictionBase
6 |
7 |
8 | def popitem_wrapper(func, wrapper_func, clean_size):
9 | def wrapper(*args, **kwargs):
10 | keys = []
11 | try:
12 | keys = [func(*args, **kwargs)[0] for _ in range(clean_size)]
13 | except KeyError:
14 | pass
15 | wrapper_func(keys)
16 | return wrapper
17 |
18 |
19 | class MemoryCacheEviction(EvictionBase):
20 | def __init__(self, policy: str, maxsize: int, clean_size: int, on_evict: Callable[[List[Any]], None], **kwargs):
21 | self._policy = policy.upper()
22 | if self._policy == "LRU":
23 | self._cache = cachetools.LRUCache(maxsize=maxsize, **kwargs)
24 | elif self._policy == "LFU":
25 | self._cache = cachetools.LFUCache(maxsize=maxsize, **kwargs)
26 | elif self._policy == "FIFO":
27 | self._cache = cachetools.FIFOCache(maxsize=maxsize, **kwargs)
28 | elif self._policy == "RR":
29 | self._cache = cachetools.RRCache(maxsize=maxsize, **kwargs)
30 | else:
31 | raise ValueError(f"Unknown policy {policy}")
32 |
33 | self._cache.popitem = popitem_wrapper(self._cache.popitem, on_evict, clean_size)
34 |
35 | def put(self, objs: List[Any]):
36 | for obj in objs:
37 | self._cache[obj] = True
38 |
39 | def get(self, obj: Any):
40 | return self._cache.get(obj)
41 |
42 | @property
43 | def policy(self) -> str:
44 | return self._policy
45 |
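A small sketch of the eviction flow: once maxsize is exceeded, the wrapped popitem removes clean_size entries at a time and hands the evicted keys to on_evict (values below are illustrative):

    from modelcache.manager.eviction import EvictionBase

    evicted = []
    eviction = EvictionBase(
        "memory", policy="LRU", maxsize=3, clean_size=2,
        on_evict=evicted.append,   # receives the list of evicted keys
    )
    eviction.put([1, 2, 3, 4])     # the 4th key triggers an eviction of two keys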
--------------------------------------------------------------------------------
/modelcache/manager/eviction_manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | class EvictionManager:
3 | """
4 | EvictionManager to manage the eviction policy.
5 |
6 | :param scalar_storage: CacheStorage to manage the scalar data.
7 | :type scalar_storage: :class:`CacheStorage`
8 | :param vector_base: VectorBase to manage the vector data.
9 | :type vector_base: :class:`VectorBase`
10 | """
11 |
12 | MAX_MARK_COUNT = 5000
13 | MAX_MARK_RATE = 0.1
14 | BATCH_SIZE = 100000
15 | REBUILD_CONDITION = 5
16 |
17 | def __init__(self, scalar_storage, vector_base):
18 | self._scalar_storage = scalar_storage
19 | self._vector_base = vector_base
20 | self.delete_count = 0
21 |
22 | def check_evict(self):
23 | mark_count = self._scalar_storage.count(state=-1)
24 | all_count = self._scalar_storage.count(is_all=True)
25 | if (
26 | mark_count > self.MAX_MARK_COUNT
27 | or mark_count / all_count > self.MAX_MARK_RATE
28 | ):
29 | return True
30 | return False
31 |
32 | def delete(self,model):
33 | mark_ids = self._scalar_storage.get_ids(deleted=True)
34 | self._scalar_storage.clear_deleted_data()
35 | self._vector_base.delete(mark_ids,model)
36 | self.delete_count += 1
37 | if self.delete_count >= self.REBUILD_CONDITION:
38 | self.rebuild()
39 |
40 | def rebuild(self):
41 | self._scalar_storage.clear_deleted_data()
42 | ids = self._scalar_storage.get_ids(deleted=False)
43 | self._vector_base.rebuild(ids)
44 | self.delete_count = 0
45 |
46 | def soft_evict(self, marked_keys):
47 | self._scalar_storage.mark_deleted(marked_keys)
48 |
--------------------------------------------------------------------------------
/modelcache/manager/factory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Union, Callable
3 | from modelcache.manager import CacheBase, VectorBase, ObjectBase
4 | from modelcache.manager.data_manager import SSDataManager, MapDataManager
5 |
6 |
7 | def get_data_manager(
8 | cache_base: Union[CacheBase, str] = None,
9 | vector_base: Union[VectorBase, str] = None,
10 | object_base: Union[ObjectBase, str] = None,
11 | max_size: int = 1000,
12 | clean_size: int = None,
13 | eviction: str = "LRU",
14 | data_path: str = "data_map.txt",
15 | get_data_container: Callable = None,
16 | ):
17 | if not cache_base and not vector_base:
18 | return MapDataManager(data_path, max_size, get_data_container)
19 |
20 | if isinstance(cache_base, str):
21 | cache_base = CacheBase(name=cache_base)
22 | if isinstance(vector_base, str):
23 | vector_base = VectorBase(name=vector_base)
24 | if isinstance(object_base, str):
25 | object_base = ObjectBase(name=object_base)
26 | assert cache_base and vector_base
27 | return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction)
28 |
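A sketch of the two paths: with no backends the factory falls back to a simple on-disk map, otherwise it wires scalar plus vector storage into an SSDataManager (the dimension value is illustrative):

    from modelcache.manager import CacheBase, VectorBase, get_data_manager

    map_dm = get_data_manager(data_path="data_map.txt")   # -> MapDataManager
    ss_dm = get_data_manager(                              # -> SSDataManager
        CacheBase("sqlite"),
        VectorBase("faiss", dimension=768),                # dimension is illustrative
    )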
--------------------------------------------------------------------------------
/modelcache/manager/object_data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 | object_manager = LazyImport(
4 | "object_manager", globals(), "modelcache.manager.object_data.manager"
5 | )
6 |
7 |
8 | def ObjectBase(name: str, **kwargs):
9 | return object_manager.ObjectBase.get(name, **kwargs)
10 |
--------------------------------------------------------------------------------
/modelcache/manager/object_data/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABC, abstractmethod
3 | from typing import Any, List
4 |
5 |
6 | class ObjectBase(ABC):
7 | """
8 | Object storage base.
9 | """
10 |
11 | @abstractmethod
12 | def put(self, obj: Any) -> str:
13 | pass
14 |
15 | @abstractmethod
16 | def get(self, obj: str) -> Any:
17 | pass
18 |
19 | @abstractmethod
20 | def get_access_link(self, obj: str) -> str:
21 | pass
22 |
23 | @abstractmethod
24 | def delete(self, to_delete: List[str]):
25 | pass
26 |
--------------------------------------------------------------------------------
/modelcache/manager/scalar_data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 | scalar_manager = LazyImport(
4 | "scalar_manager", globals(), "modelcache.manager.scalar_data.manager"
5 | )
6 |
7 |
8 | def CacheBase(name: str, **kwargs):
9 | return scalar_manager.CacheBase.get(name, **kwargs)
10 |
--------------------------------------------------------------------------------
/modelcache/manager/scalar_data/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 | from dataclasses import dataclass
4 | from typing import Union, Dict, List, Optional, Any
5 | from enum import IntEnum
6 | import numpy as np
7 |
8 |
9 | class DataType(IntEnum):
10 | STR = 0
11 | IMAGE_BASE64 = 1
12 | IMAGE_URL = 2
13 |
14 |
15 | @dataclass
16 | class QuestionDep:
17 | """
18 | QuestionDep
19 | """
20 |
21 | name: str
22 | data: str
23 | dep_type: int = DataType.STR
24 |
25 | @classmethod
26 | def from_dict(cls, d: Dict):
27 | return cls(
28 | name=d["name"],
29 | data=d["data"],
30 | dep_type=d["dep_type"]
31 | )
32 |
33 |
34 | @dataclass
35 | class Question:
36 | """
37 | Question
38 | """
39 |
40 | content: str
41 | deps: Optional[List[QuestionDep]] = None
42 |
43 | @classmethod
44 | def from_dict(cls, d: Dict):
45 | deps = []
46 | for dep in d["deps"]:
47 | deps.append(QuestionDep.from_dict(dep))
48 | return cls(d["content"], deps)
49 |
50 |
51 | @dataclass
52 | class Answer:
53 | """
54 | data_type:
55 | 0: str
56 | 1: base64 image
57 | """
58 |
59 | answer: Any
60 | answer_type: int = DataType.STR
61 |
62 |
63 | @dataclass
64 | class CacheData:
65 | """
66 | CacheData
67 | """
68 |
69 | question: Union[str, Question]
70 | answers: List[Answer]
71 | embedding_data: Optional[np.ndarray] = None
72 |
73 | def __init__(self, question, answers, embedding_data=None):
74 | self.question = question
75 | self.answers = []
76 | if isinstance(answers, (str, Answer)):
77 | answers = [answers]
78 | for data in answers:
79 | if isinstance(data, (list, tuple)):
80 | self.answers.append(Answer(*data))
81 | elif isinstance(data, Answer):
82 | self.answers.append(data)
83 | else:
84 | self.answers.append(Answer(answer=data))
85 | self.embedding_data = embedding_data
86 |
87 |
88 | class CacheStorage(metaclass=ABCMeta):
89 | """
90 | BaseStorage for scalar data.
91 | """
92 |
93 | @abstractmethod
94 | def create(self):
95 | pass
96 |
97 | @abstractmethod
98 | def insert_query_resp(self, query_resp, **kwargs):
99 | pass
100 |
101 | @abstractmethod
102 | def get_data_by_id(self, key):
103 | pass
104 |
105 | @abstractmethod
106 | def mark_deleted(self, keys):
107 | pass
108 |
109 | @abstractmethod
110 | def model_deleted(self, model_name):
111 | pass
112 |
113 | @abstractmethod
114 | def clear_deleted_data(self):
115 | pass
116 |
117 | @abstractmethod
118 | def get_ids(self, deleted=True):
119 | pass
120 |
121 | @abstractmethod
122 | def count(self):
123 | pass
124 |
125 | def flush(self):
126 | pass
127 |
128 | @abstractmethod
129 | def close(self):
130 | pass
131 |
--------------------------------------------------------------------------------
/modelcache/manager/scalar_data/manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils import import_sql_client
3 | from modelcache.utils.error import NotFoundError
4 |
5 | SQL_URL = {"sqlite": "./sqlite.db"}
6 |
7 |
8 | class CacheBase:
9 | """
10 | CacheBase to manage the cache storage.
11 | """
12 |
13 | def __init__(self):
14 | raise EnvironmentError(
15 | "CacheBase is designed to be instantiated, please using the `CacheBase.get(name)`."
16 | )
17 |
18 | @staticmethod
19 | def get(name, **kwargs):
20 |
21 | if name in ["mysql", "oceanbase"]:
22 | from modelcache.manager.scalar_data.sql_storage import SQLStorage
23 | config = kwargs.get("config")
24 | import_sql_client(name)
25 | cache_base = SQLStorage(db_type=name, config=config)
26 | elif name == 'sqlite':
27 | from modelcache.manager.scalar_data.sql_storage_sqlite import SQLStorage
28 | sql_url = kwargs.get("sql_url", SQL_URL[name])
29 | cache_base = SQLStorage(db_type=name, url=sql_url)
30 | elif name == 'elasticsearch':
31 | from modelcache.manager.scalar_data.sql_storage_es import SQLStorage
32 | config = kwargs.get("config")
33 | cache_base = SQLStorage(db_type=name, config=config)
34 | else:
35 | raise NotFoundError("cache store", name)
36 | return cache_base
37 |
--------------------------------------------------------------------------------
/modelcache/manager/scalar_data/sql_storage_sqlite.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | from typing import List
4 | from modelcache.manager.scalar_data.base import CacheStorage, CacheData
5 | import sqlite3
6 |
7 |
8 | class SQLStorage(CacheStorage):
9 | def __init__(
10 | self,
11 | db_type: str = "mysql",
12 | config=None,
13 | url="./sqlite.db"
14 | ):
15 | self._url = url
16 | # self._engine = sqlite3.connect(url)
17 | self.create()
18 |
19 | def create(self):
20 | answer_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_llm_answer (
21 | id INTEGER PRIMARY KEY AUTOINCREMENT,
22 | gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
23 | gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
24 | question TEXT NOT NULL,
25 | answer TEXT NOT NULL,
26 | answer_type INTEGER NOT NULL,
27 | hit_count INTEGER NOT NULL DEFAULT 0,
28 | model VARCHAR(1000) NOT NULL,
29 | embedding_data BLOB NOT NULL
30 | );
31 | """
32 |
33 | log_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_query_log (
34 | id INTEGER PRIMARY KEY AUTOINCREMENT,
35 | gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
36 | gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
37 | error_code INTEGER NOT NULL,
38 | error_desc VARCHAR(1000) NOT NULL,
39 | cache_hit VARCHAR(100) NOT NULL,
40 | delta_time REAL NOT NULL,
41 | model VARCHAR(1000) NOT NULL,
42 | query TEXT NOT NULL,
43 | hit_query TEXT NOT NULL,
44 | answer TEXT NOT NULL
45 | );
46 | """
47 |
48 | conn = sqlite3.connect(self._url)
49 | try:
50 | cursor = conn.cursor()
51 | cursor.execute(answer_table_sql)
52 | cursor.execute(log_table_sql)
53 | conn.commit()
54 | cursor.close()
55 | conn.close()
56 | finally:
57 | conn.close()
58 |
59 | def _insert(self, data: List):
60 | answer = data[0]
61 | question = data[1]
62 | embedding_data = data[2]
63 | model = data[3]
64 | answer_type = 0
65 | embedding_data = embedding_data.tobytes()
66 |
67 | table_name = "modelcache_llm_answer"
68 | insert_sql = "INSERT INTO {} (question, answer, answer_type, model, embedding_data) VALUES (?, ?, ?, ?, ?)".format(table_name)
69 |
70 | conn = sqlite3.connect(self._url)
71 | try:
72 | cursor = conn.cursor()
73 | values = (question, answer, answer_type, model, embedding_data)
74 | cursor.execute(insert_sql, values)
75 | conn.commit()
76 | id = cursor.lastrowid
77 | cursor.close()
78 | conn.close()
79 | finally:
80 | conn.close()
81 | return id
82 |
83 | def batch_insert(self, all_data: List[CacheData]):
84 | ids = []
85 | for data in all_data:
86 | ids.append(self._insert(data))
87 | return ids
88 |
89 | def insert_query_resp(self, query_resp, **kwargs):
90 | error_code = query_resp.get('errorCode')
91 | error_desc = query_resp.get('errorDesc')
92 | cache_hit = query_resp.get('cacheHit')
93 | model = kwargs.get('model')
94 | query = kwargs.get('query')
95 | delta_time = kwargs.get('delta_time')
96 | hit_query = query_resp.get('hit_query')
97 | answer = query_resp.get('answer')
98 |
99 | if isinstance(hit_query, list):
100 | hit_query = json.dumps(hit_query, ensure_ascii=False)
101 |
102 | table_name = "modelcache_query_log"
103 | insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?)".format(table_name)
104 | conn = sqlite3.connect(self._url)
105 | try:
106 | cursor = conn.cursor()
107 | values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer)
108 | cursor.execute(insert_sql, values)
109 | conn.commit()
110 | cursor.close()
111 | conn.close()
112 | finally:
113 | conn.close()
114 |
115 | def get_data_by_id(self, key: int):
116 | table_name = "modelcache_llm_answer"
117 | query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key)
118 | conn = sqlite3.connect(self._url)
119 | try:
120 | cursor = conn.cursor()
121 | cursor.execute(query_sql)
122 | resp = cursor.fetchone()
123 | conn.commit()
124 | cursor.close()
125 | conn.close()
126 | finally:
127 | conn.close()
128 |
129 | if resp is not None and len(resp) == 4:
130 | return resp
131 | else:
132 | return None
133 |
134 | def update_hit_count_by_id(self, primary_id: int):
135 | table_name = "modelcache_llm_answer"
136 | update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id)
137 |
138 | conn = sqlite3.connect(self._url)
139 | try:
140 | cursor = conn.cursor()
141 | cursor.execute(update_sql)
142 | conn.commit()
143 | cursor.close()
144 | conn.close()
145 | finally:
146 | # close the connection (return it to the connection pool)
147 | conn.close()
148 |
149 | def get_ids(self, deleted=True):
150 | pass
151 |
152 | def mark_deleted(self, keys):
153 | table_name = "modelcache_llm_answer"
154 | delete_sql = "Delete from {} WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys]))
155 | conn = sqlite3.connect(self._url)
156 | try:
157 | cursor = conn.cursor()
158 | cursor.execute(delete_sql)
159 | delete_count = cursor.rowcount
160 | conn.commit()
161 | cursor.close()
162 | conn.close()
163 | finally:
164 | conn.close()
165 | return delete_count
166 |
167 | def model_deleted(self, model_name):
168 | table_name = "modelcache_llm_answer"
169 | delete_sql = "Delete from {} WHERE model=?".format(table_name)
170 |
171 | table_log_name = "modelcache_query_log"
172 | delete_log_sql = "Delete from {} WHERE model=?".format(table_log_name)
173 | conn = sqlite3.connect(self._url)
174 | try:
175 | cursor = conn.cursor()
176 | cursor.execute(delete_sql, (model_name,))
177 | conn.commit()
178 | # get delete rows
179 | deleted_rows_count = cursor.rowcount
180 |
181 | cursor.execute(delete_log_sql, (model_name,))
182 | conn.commit()
183 | cursor.close()
184 | except sqlite3.Error as e:
185 | print(f"SQLite error: {e}")
186 | deleted_rows_count = 0 # if except, return 0
187 | finally:
188 | conn.close()
189 | return deleted_rows_count
190 |
191 | def clear_deleted_data(self):
192 | pass
193 |
194 | def count(self, state: int = 0, is_all: bool = False):
195 | pass
196 |
197 | def close(self):
198 | pass
199 |
200 | def count_answers(self):
201 | pass
202 |
--------------------------------------------------------------------------------
/modelcache/manager/vector_data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 |
4 | vector_manager = LazyImport(
5 | "vector_manager", globals(), "modelcache.manager.vector_data.manager"
6 | )
7 |
8 |
9 | def VectorBase(name: str, **kwargs):
10 | return vector_manager.VectorBase.get(name, **kwargs)
11 |
--------------------------------------------------------------------------------
/modelcache/manager/vector_data/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABC, abstractmethod
3 | import numpy as np
4 | from typing import List
5 | from dataclasses import dataclass
6 |
7 |
8 | @dataclass
9 | class VectorData:
10 | id: int
11 | data: np.ndarray
12 |
13 |
14 | class VectorBase(ABC):
15 | """VectorBase: base vector store interface"""
16 |
17 | @abstractmethod
18 | def mul_add(self, datas: List[VectorData], model=None):
19 | pass
20 |
21 | @abstractmethod
22 | def search(self, data: np.ndarray, top_k: int, model):
23 | pass
24 |
25 | @abstractmethod
26 | def rebuild(self, ids=None) -> bool:
27 | pass
28 |
29 | @abstractmethod
30 | def delete(self, ids) -> bool:
31 | pass
32 |
33 | @abstractmethod
34 | def rebuild_col(self, model):
35 | pass
36 |
37 | def flush(self):
38 | pass
39 |
40 | def close(self):
41 | pass
42 |
--------------------------------------------------------------------------------
/modelcache/manager/vector_data/chroma.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import logging
5 | from modelcache.manager.vector_data.base import VectorBase, VectorData
6 | from modelcache.utils import import_chromadb, import_torch
7 |
8 | import_torch()
9 | import_chromadb()
10 |
11 | import chromadb
12 |
13 |
14 | class Chromadb(VectorBase):
15 |
16 | def __init__(
17 | self,
18 | persist_directory="./chromadb",
19 | top_k: int = 1,
20 | ):
21 | self.collection_name = "modelcache"
22 | self.top_k = top_k
23 |
24 | self._client = chromadb.PersistentClient(path=persist_directory)
25 | self._collection = None
26 |
27 | def mul_add(self, datas: List[VectorData], model=None):
28 | collection_name_model = self.collection_name + '_' + model
29 | self._collection = self._client.get_or_create_collection(name=collection_name_model)
30 |
31 | data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
32 | self._collection.add(embeddings=data_array, ids=id_array)
33 |
34 | def search(self, data: np.ndarray, top_k: int = -1, model=None):
35 | collection_name_model = self.collection_name + '_' + model
36 | self._collection = self._client.get_or_create_collection(name=collection_name_model)
37 |
38 | if self._collection.count() == 0:
39 | return []
40 | if top_k == -1:
41 | top_k = self.top_k
42 | results = self._collection.query(
43 | query_embeddings=[data.tolist()],
44 | n_results=top_k,
45 | include=["distances"],
46 | )
47 | return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))
48 |
49 | def rebuild(self, ids=None):
50 | pass
51 |
52 | def delete(self, ids, model=None):
53 | try:
54 | collection_name_model = self.collection_name + '_' + model
55 | self._collection = self._client.get_or_create_collection(name=collection_name_model)
56 | # look up which of the requested IDs actually exist in the collection
57 | ids_str = [str(x) for x in ids]
58 | existing_ids = set(self._collection.get(ids=ids_str)["ids"])
59 |
60 | # delete only the IDs that exist
61 | if existing_ids:
62 | self._collection.delete(list(existing_ids))
63 |
64 | # return the number of entries actually deleted
65 | return len(existing_ids)
66 |
67 | except Exception as e:
68 | logging.error('Error during deletion: {}'.format(e))
69 | raise ValueError(str(e))
70 |
71 | def rebuild_col(self, model):
72 | collection_name_model = self.collection_name + '_' + model
73 |
74 | # check whether the collection exists; if it does, delete it
75 | collections = self._client.list_collections()
76 | if any(col.name == collection_name_model for col in collections):
77 | self._client.delete_collection(collection_name_model)
78 | else:
79 | return 'model collection not found, please check!'
80 |
81 | try:
82 | self._client.create_collection(collection_name_model)
83 | except Exception as e:
84 | logging.info(f'rebuild_collection: {e}')
85 | raise ValueError(str(e))
86 |
87 | def flush(self):
88 | # chroma has no flush method
89 | pass
90 |
91 | def close(self):
92 | pass
93 |
--------------------------------------------------------------------------------
/modelcache/manager/vector_data/faiss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from typing import List
4 | import numpy as np
5 | from modelcache.manager.vector_data.base import VectorBase, VectorData
6 | from modelcache.utils import import_faiss
7 | import_faiss()
8 | import faiss # pylint: disable=C0413
9 |
10 |
11 | class Faiss(VectorBase):
12 | def __init__(self, index_file_path, dimension, top_k):
13 | self._index_file_path = index_file_path
14 | self._dimension = dimension
15 | self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2)
16 | self._top_k = top_k
17 | if os.path.isfile(index_file_path):
18 | self._index = faiss.read_index(index_file_path)
19 |
20 | def mul_add(self, datas: List[VectorData], model=None):
21 | data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
22 | np_data = np.array(data_array).astype("float32")
23 | ids = np.array(id_array)
24 | self._index.add_with_ids(np_data, ids)
25 |
26 | def search(self, data: np.ndarray, top_k: int = -1, model=None):
27 | if self._index.ntotal == 0:
28 | return None
29 | if top_k == -1:
30 | top_k = self._top_k
31 | np_data = np.array(data).astype("float32").reshape(1, -1)
32 | dist, ids = self._index.search(np_data, top_k)
33 | ids = [int(i) for i in ids[0]]
34 | return list(zip(dist[0], ids))
35 |
36 | def rebuild_col(self, ids=None):
37 | try:
38 | self._index.reset()
39 | except Exception as e:
40 | return f"An error occurred during index rebuild: {e}"
41 |
42 | def rebuild(self, ids=None):
43 | return True
44 |
45 | def delete(self, ids):
46 | ids_to_remove = np.array(ids)
47 | self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove)))
48 |
49 | def flush(self):
50 | faiss.write_index(self._index, self._index_file_path)
51 |
52 | def close(self):
53 | self.flush()
54 |
55 | def count(self):
56 | return self._index.ntotal
57 |
--------------------------------------------------------------------------------
/modelcache/manager/vector_data/manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.error import NotFoundError, ParamError
3 |
4 | TOP_K = 1
5 | FAISS_INDEX_PATH = "faiss.index"
6 | DIMENSION = 0
7 | MILVUS_HOST = "localhost"
8 | MILVUS_PORT = 19530
9 | MILVUS_USER = ""
10 | MILVUS_PSW = ""
11 | MILVUS_SECURE = False
12 | MILVUS_INDEX_PARAMS = {
13 | "metric_type": "L2",
14 | "index_type": "HNSW",
15 | "params": {"M": 8, "efConstruction": 64},
16 | }
17 |
18 | COLLECTION_NAME = "modelcache"
19 |
20 |
21 | class VectorBase:
22 | """
23 | VectorBase to manage the vector base.
24 | """
25 |
26 | def __init__(self):
27 | raise EnvironmentError(
28 | "VectorBase is designed to be instantiated, please using the `VectorBase.get(name)`."
29 | )
30 |
31 | @staticmethod
32 | def check_dimension(dimension):
33 | if dimension <= 0:
34 | raise ParamError(
35 | f"the dimension should be greater than zero, current value: {dimension}."
36 | )
37 |
38 | @staticmethod
39 | def get(name, **kwargs):
40 | top_k = kwargs.get("top_k", TOP_K)
41 | if name == "milvus":
42 | from modelcache.manager.vector_data.milvus import Milvus
43 | dimension = kwargs.get("dimension", DIMENSION)
44 | milvus_config = kwargs.get("milvus_config")
45 | VectorBase.check_dimension(dimension)
46 | host = milvus_config.get('milvus', 'host')
47 | port = milvus_config.get('milvus', 'port')
48 | user = milvus_config.get('milvus', 'user')
49 | password = milvus_config.get('milvus', 'password')
50 |
51 | secure = kwargs.get("secure", MILVUS_SECURE)
52 | collection_name = kwargs.get("collection_name", COLLECTION_NAME)
53 | index_params = kwargs.get("index_params", MILVUS_INDEX_PARAMS)
54 | search_params = kwargs.get("search_params", None)
55 | local_mode = kwargs.get("local_mode", False)
56 | local_data = kwargs.get("local_data", "./milvus_data")
57 | vector_base = Milvus(
58 | host=host,
59 | port=port,
60 | user=user,
61 | password=password,
62 | secure=secure,
63 | collection_name=collection_name,
64 | dimension=dimension,
65 | top_k=top_k,
66 | index_params=index_params,
67 | search_params=search_params,
68 | local_mode=local_mode,
69 | local_data=local_data
70 | )
71 | elif name == "redis":
72 | from modelcache.manager.vector_data.redis import RedisVectorStore
73 | dimension = kwargs.get("dimension", DIMENSION)
74 | VectorBase.check_dimension(dimension)
75 |
76 | redis_config = kwargs.get("redis_config")
77 | host = redis_config.get('redis', 'host')
78 | port = redis_config.get('redis', 'port')
79 | user = redis_config.get('redis', 'user')
80 | password = redis_config.get('redis', 'password')
81 | namespace = kwargs.get("namespace", "")
82 | # collection_name = kwargs.get("collection_name", COLLECTION_NAME)
83 |
84 | vector_base = RedisVectorStore(
85 | host=host,
86 | port=port,
87 | username=user,
88 | password=password,
89 | namespace=namespace,
90 | top_k=top_k,
91 | dimension=dimension,
92 | )
93 | elif name == "faiss":
94 | from modelcache.manager.vector_data.faiss import Faiss
95 |
96 | dimension = kwargs.get("dimension", DIMENSION)
97 | index_path = kwargs.pop("index_path", FAISS_INDEX_PATH)
98 | VectorBase.check_dimension(dimension)
99 | vector_base = Faiss(
100 | index_file_path=index_path, dimension=dimension, top_k=top_k
101 | )
102 | elif name == "chromadb":
103 | from modelcache.manager.vector_data.chroma import Chromadb
104 |
105 | chromadb_config = kwargs.get("chromadb_config", None)
106 | persist_directory = chromadb_config.get('chromadb','persist_directory')
107 |
108 | vector_base = Chromadb(
109 | persist_directory=persist_directory,
110 | top_k=top_k,
111 | )
112 | elif name == "hnswlib":
113 | from modelcache.manager.vector_data.hnswlib_store import Hnswlib
114 |
115 | dimension = kwargs.get("dimension", DIMENSION)
116 | index_path = kwargs.pop("index_path", "./hnswlib_index.bin")
117 | max_elements = kwargs.pop("max_elements", 100000)
118 | VectorBase.check_dimension(dimension)
119 | vector_base = Hnswlib(
120 | index_file_path=index_path, dimension=dimension,
121 | top_k=top_k, max_elements=max_elements
122 | )
123 | else:
124 | raise NotFoundError("vector store", name)
125 | return vector_base
126 |
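For orientation, a minimal usage sketch of the factory above: the `faiss` branch needs only a dimension, an index path and `top_k`. The concrete values below are illustrative assumptions, not repository defaults.

from modelcache.manager.vector_data.manager import VectorBase

# Illustrative values; the dimension must match the embedding model in use.
faiss_store = VectorBase.get("faiss", dimension=768, index_path="./faiss.index", top_k=5)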
--------------------------------------------------------------------------------
/modelcache/manager/vector_data/redis.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import List
3 | import numpy as np
4 | from redis.commands.search.indexDefinition import IndexDefinition, IndexType
5 | from redis.commands.search.query import Query
6 | from redis.commands.search.field import TagField, VectorField, NumericField
7 | from redis.client import Redis
8 |
9 | from modelcache.manager.vector_data.base import VectorBase, VectorData
10 | from modelcache.utils import import_redis
11 | from modelcache.utils.log import modelcache_log
12 | from modelcache.utils.index_util import get_index_name
13 | from modelcache.utils.index_util import get_index_prefix
14 | import_redis()
15 |
16 |
17 | class RedisVectorStore(VectorBase):
18 | def __init__(
19 | self,
20 | host: str = "localhost",
21 | port: str = "6379",
22 | username: str = "",
23 | password: str = "",
24 | dimension: int = 0,
25 | top_k: int = 1,
26 | namespace: str = "",
27 | ):
28 | if dimension <= 0:
29 | raise ValueError(
30 |                 f"invalid `dimension` param: {dimension} in the Redis vector store."
31 | )
32 | self._client = Redis(
33 | host=host, port=int(port), username=username, password=password
34 | )
35 | self.top_k = top_k
36 | self.dimension = dimension
37 | self.namespace = namespace
38 | self.doc_prefix = f"{self.namespace}doc:"
39 |
40 | def _check_index_exists(self, index_name: str) -> bool:
41 | """Check if Redis index exists."""
42 | try:
43 | self._client.ft(index_name).info()
44 |         except Exception:
45 | modelcache_log.info("Index does not exist")
46 | return False
47 | modelcache_log.info("Index already exists")
48 | return True
49 |
50 | def create_index(self, index_name, index_prefix):
51 | dimension = self.dimension
52 | if self._check_index_exists(index_name):
53 | modelcache_log.info(
54 | "The %s already exists, and it will be used directly", index_name
55 | )
56 | return 'already_exists'
57 | else:
58 | id_field_name = "data_id"
59 | embedding_field_name = "data_vector"
60 |
61 | id = NumericField(name=id_field_name)
62 | embedding = VectorField(embedding_field_name,
63 | "HNSW", {
64 | "TYPE": "FLOAT32",
65 | "DIM": dimension,
66 | "DISTANCE_METRIC": "L2",
67 | "INITIAL_CAP": 1000,
68 | }
69 | )
70 | fields = [id, embedding]
71 | definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)
72 |
73 | # create Index
74 | self._client.ft(index_name).create_index(
75 | fields=fields, definition=definition
76 | )
77 | return 'create_success'
78 |
79 | def mul_add(self, datas: List[VectorData], model=None):
80 | # pipe = self._client.pipeline()
81 | for data in datas:
82 | id: int = data.id
83 | embedding = data.data.astype(np.float32).tobytes()
84 | id_field_name = "data_id"
85 | embedding_field_name = "data_vector"
86 | obj = {id_field_name: id, embedding_field_name: embedding}
87 | index_prefix = get_index_prefix(model)
88 | self._client.hset(f"{index_prefix}{id}", mapping=obj)
89 |
90 | def search(self, data: np.ndarray, top_k: int = -1, model=None):
91 | index_name = get_index_name(model)
92 | id_field_name = "data_id"
93 | embedding_field_name = "data_vector"
94 |         base_query = f'*=>[KNN {top_k if top_k > 0 else self.top_k} @{embedding_field_name} $vector AS distance]'
95 | query = (
96 | Query(base_query)
97 | .sort_by("distance")
98 | .return_fields(id_field_name, "distance")
99 | .dialect(2)
100 | )
101 |
102 | query_params = {"vector": data.astype(np.float32).tobytes()}
103 | results = (
104 | self._client.ft(index_name)
105 | .search(query, query_params=query_params)
106 | .docs
107 | )
108 | return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
109 |
110 | def rebuild(self, ids=None) -> bool:
111 | pass
112 |
113 | def rebuild_col(self, model):
114 | index_name_model = get_index_name(model)
115 | if self._check_index_exists(index_name_model):
116 | try:
117 | self._client.ft(index_name_model).dropindex(delete_documents=True)
118 | except Exception as e:
119 | raise ValueError(str(e))
120 | try:
121 | index_prefix = get_index_prefix(model)
122 | self.create_index(index_name_model, index_prefix)
123 | except Exception as e:
124 | raise ValueError(str(e))
125 | # return 'rebuild success'
126 |
127 | def delete(self, ids) -> None:
128 | pipe = self._client.pipeline()
129 | for data_id in ids:
130 | pipe.delete(f"{self.doc_prefix}{data_id}")
131 | pipe.execute()
132 |
133 | def create(self, model=None):
134 | index_name = get_index_name(model)
135 | index_prefix = get_index_prefix(model)
136 | return self.create_index(index_name, index_prefix)
137 |
138 | def get_index_by_name(self, index_name):
139 | pass
140 |
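A hedged round-trip sketch for the store above, assuming `VectorData` is the simple `(id, data)` container imported from `base` and that a local Redis instance with the search module (Redis Stack) is reachable; all values are illustrative.

import numpy as np
from modelcache.manager.vector_data.base import VectorData
from modelcache.manager.vector_data.redis import RedisVectorStore

store = RedisVectorStore(host="localhost", port="6379", dimension=768, top_k=3)
store.create(model="demo_model")                       # builds the per-model HNSW index
vec = np.random.rand(768).astype(np.float32)
store.mul_add([VectorData(id=1, data=vec)], model="demo_model")
print(store.search(vec, top_k=3, model="demo_model"))  # [(distance, data_id), ...]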
--------------------------------------------------------------------------------
/modelcache/processor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modelcache/processor/post.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import random
3 | from typing import List, Any
4 |
5 |
6 | def random_one(messages: List[Any]) -> Any:
7 | return random.choice(messages)
8 |
9 |
10 | def first(messages: List[Any]) -> Any:
11 | return messages[0]
12 |
13 |
14 | def nop(messages: List[Any]) -> Any:
15 | return messages
16 |
--------------------------------------------------------------------------------
/modelcache/processor/pre.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import re
3 | from typing import Dict, Any
4 |
5 |
6 | def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
7 | return data.get("chat_info")[-1]["query"]
8 |
9 |
10 | def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
11 | return data.get("query")[-1]["content"]
12 |
13 |
14 | def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) -> Any:
15 | last_content_str = data.get("messages")[-1]["content"]
16 | prompts = params.get("prompts", [])
17 | if prompts is None:
18 | return last_content_str
19 | pattern = "|".join(prompts)
20 | new_content_str = re.sub(pattern, "", last_content_str)
21 | return new_content_str
22 |
23 |
24 | def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
25 | s = ""
26 | messages = data.get("messages")
27 | for i, message in enumerate(messages):
28 | if i == len(messages) - 1:
29 | s += message["content"]
30 | else:
31 | s += message["content"] + "\n"
32 | return s
33 |
34 |
35 | def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
36 | return data
37 |
38 |
39 | def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
40 | return data.get("prompt")
41 |
42 |
43 | def get_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
44 | return data.get("file").name
45 |
46 |
47 | def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes:
48 | return data.get("file").peek()
49 |
50 |
51 | def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
52 | input_data = data.get("input")
53 | return str(input_data["image"].peek()) + input_data["question"]
54 |
55 |
56 | def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
57 | input_data = data.get("input")
58 | return input_data["image"].name
59 |
60 |
61 | def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
62 | query_list = data.get("query")
63 | return multi_splicing(query_list)
64 |
65 |
66 | def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
67 | insert_query_list = data.get("chat_info")[-1]['query']
68 | return multi_splicing(insert_query_list)
69 |
70 |
71 | def multi_splicing(data_list) -> Any:
72 | result_str = ""
73 | for d in data_list:
74 | role = d.get('role', '')
75 | content = d.get('content', '')
76 | result_str += role + "###" + content + "|||"
77 |
78 |     # Strip the trailing "|||"
79 | result_str = result_str[:-3]
80 |
81 | return result_str
82 |
83 |
84 | def multi_analysis(dialog_str):
85 | sub_strings = dialog_str.split('|||')
86 |
87 | dict_list = []
88 | for s in sub_strings:
89 | parts = s.split('###')
90 |
91 | if len(parts) == 2:
92 | role = parts[0]
93 | content = parts[1]
94 | elif len(parts) > 2:
95 | role = parts[0]
96 | content = '###'.join(parts[1:])
97 | else:
98 |                 role, content = (parts[0] if parts else ''), 'exception'
99 |
100 | if content == '':
101 | d = {"role": role}
102 | else:
103 | d = {"role": role, "content": content}
104 | dict_list.append(d)
105 |
106 |     # Collect the dictionaries into the final result list
107 | result_list = dict_list
108 |
109 |     # Return the result
110 | return result_list
111 |
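The two splicing helpers above serialize a multi-turn dialog into a single string (role and content joined by `###`, turns joined by `|||`) and parse it back into dictionaries; a small round-trip illustration:

dialog = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi there"},
]
flat = multi_splicing(dialog)    # 'user###hello|||assistant###hi there'
restored = multi_analysis(flat)  # [{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'hi there'}]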
--------------------------------------------------------------------------------
/modelcache/report.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | class Report:
3 | def __init__(self):
4 | self.embedding_all_time = 0
5 | self.embedding_count = 0
6 | self.search_all_time = 0
7 | self.search_count = 0
8 | self.hint_cache_count = 0
9 |
10 | def embedding(self, delta_time):
11 | """Embedding counts and time.
12 |
13 | :param delta_time: additional runtime.
14 | """
15 | self.embedding_all_time += delta_time
16 | self.embedding_count += 1
17 |
18 | def search(self, delta_time):
19 | """Search counts and time.
20 |
21 | :param delta_time: additional runtime.
22 | """
23 | self.search_all_time += delta_time
24 | self.search_count += 1
25 |
26 | def average_embedding_time(self):
27 | """Average embedding time."""
28 | return round(
29 | self.embedding_all_time / self.embedding_count
30 | if self.embedding_count != 0
31 | else 0,
32 | 4,
33 | )
34 |
35 | def average_search_time(self):
36 | return round(
37 | self.search_all_time / self.search_count
38 |             if self.search_count != 0
39 | else 0,
40 | 4,
41 | )
42 |
43 | def hint_cache(self):
44 | self.hint_cache_count += 1
45 |
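A quick sketch of how these counters are fed, using made-up timings; `time_cal` in `utils/time.py` forwards measured deltas to `report.embedding` / `report.search` in exactly this way.

report = Report()
report.embedding(0.12)
report.embedding(0.08)
report.search(0.03)
print(report.average_embedding_time())  # 0.1
print(report.average_search_time())     # 0.03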
--------------------------------------------------------------------------------
/modelcache/similarity_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation
3 | from modelcache.utils.lazy_import import LazyImport
4 |
5 | exact_match = LazyImport(
6 | "exact_match", globals(), "modelcache.similarity_evaluation.exact_match"
7 | )
8 |
9 |
10 | def ExactMatchEvaluation():
11 | return exact_match.ExactMatchEvaluation()
12 |
--------------------------------------------------------------------------------
/modelcache/similarity_evaluation/distance.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Tuple, Dict, Any
3 | from modelcache.similarity_evaluation import SimilarityEvaluation
4 |
5 |
6 | class SearchDistanceEvaluation(SimilarityEvaluation):
7 | def __init__(self, max_distance=4.0, positive=False):
8 | self.max_distance = max_distance
9 | self.positive = positive
10 |
11 | def evaluation(
12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_
13 | ) -> float:
14 | distance, _ = cache_dict["search_result"]
15 | if distance < 0:
16 | distance = 0
17 | elif distance > self.max_distance:
18 | distance = self.max_distance
19 | if self.positive:
20 | return distance
21 | return self.max_distance - distance
22 |
23 | def range(self) -> Tuple[float, float]:
24 | return 0.0, self.max_distance
25 |
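SearchDistanceEvaluation converts the raw vector distance returned by the store into a similarity-like score: the distance is clamped to `[0, max_distance]` and, unless `positive=True`, inverted so that closer vectors score higher. A small illustration:

evaluation = SearchDistanceEvaluation(max_distance=4.0)
# cache_dict carries the (distance, id) pair produced by the vector store search
score = evaluation.evaluation({}, {"search_result": (1.5, 42)})
print(score)  # 2.5 -- a distance of 1.5 maps to a score of 4.0 - 1.5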
--------------------------------------------------------------------------------
/modelcache/similarity_evaluation/exact_match.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Tuple, Dict, Any
3 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation
4 |
5 |
6 | class ExactMatchEvaluation(SimilarityEvaluation):
7 |
8 | def __init__(self):
9 | pass
10 |
11 | def evaluation(
12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_
13 | ) -> float:
14 | return 1 if cache_dict["question"] == src_dict["question"] else 0
15 |
16 | def range(self) -> Tuple[float, float]:
17 | return 0, 1
18 |
--------------------------------------------------------------------------------
/modelcache/similarity_evaluation/similarity_evaluation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Tuple, Dict, Any
4 |
5 |
6 | class SimilarityEvaluation(metaclass=ABCMeta):
7 | @abstractmethod
8 | def evaluation(
9 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **kwargs
10 | ) -> float:
11 | pass
12 |
13 | @abstractmethod
14 | def range(self) -> Tuple[float, float]:
15 | pass
16 |
--------------------------------------------------------------------------------
/modelcache/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import importlib.util
3 | from typing import Optional
4 | from modelcache.utils.dependency_control import prompt_install
5 |
6 |
7 | def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None):
8 | is_avail = False
9 | if importlib.util.find_spec(libname):
10 | is_avail = True
11 | if not is_avail and prompt:
12 | prompt_install(package if package else libname)
13 | return is_avail
14 |
15 |
16 | def import_onnxruntime():
17 | _check_library("onnxruntime")
18 |
19 |
20 | def import_huggingface():
21 | _check_library("transformers")
22 |
23 |
24 | def import_huggingface_hub():
25 | _check_library("huggingface_hub", package="huggingface-hub")
26 |
27 |
28 | def import_pymysql():
29 | _check_library("pymysql")
30 |
31 |
32 | def import_sql_client(db_name):
33 | if db_name in ["mysql"]:
34 | import_pymysql()
35 |
36 |
37 | def import_pymilvus():
38 | _check_library("pymilvus")
39 |
40 |
41 | def import_milvus_lite():
42 | _check_library("milvus")
43 |
44 |
45 | def import_faiss():
46 | _check_library("faiss", package="faiss-cpu")
47 |
48 |
49 | def import_torch():
50 | _check_library("torch")
51 |
52 |
53 | def import_fasttext():
54 | _check_library("fasttext")
55 |
56 |
57 | def import_paddle():
58 | prompt_install("protobuf==3.20.0")
59 | _check_library("paddlepaddle")
60 |
61 |
62 | def import_paddlenlp():
63 | _check_library("paddlenlp")
64 |
65 |
66 | def import_timm():
67 | _check_library("timm", package="timm")
68 |
69 |
70 | def import_pillow():
71 | _check_library("PIL", package="pillow")
72 |
73 |
74 | def import_redis():
75 | _check_library("redis")
76 |
77 |
78 | def import_chromadb():
79 | _check_library("chromadb", package="chromadb")
--------------------------------------------------------------------------------
/modelcache/utils/cache_func.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | def cache_all(*_, **__):
3 | return True
--------------------------------------------------------------------------------
/modelcache/utils/dependency_control.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import subprocess
3 | from modelcache.utils.error import PipInstallError
4 | from modelcache.utils.log import modelcache_log
5 |
6 |
7 | def prompt_install(package: str, warn: bool = False): # pragma: no cover
8 | """
9 | Function used to prompt user to install a package.
10 | """
11 | cmd = f"pip install {package}"
12 | try:
13 | if warn and input(f"Install {package}? Y/n: ") != "Y":
14 | raise ModuleNotFoundError(f"No module named {package}")
15 | subprocess.check_call(cmd, shell=True)
16 | modelcache_log.info("%s installed successfully!", package)
17 | except subprocess.CalledProcessError as e:
18 | raise PipInstallError(package) from e
19 |
--------------------------------------------------------------------------------
/modelcache/utils/env_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modelcache/utils/error.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | class CacheError(Exception):
3 | """ModelCache base error"""
4 |
5 |
6 | class NotInitError(CacheError):
7 | """Raise when the cache has been used before it's inited"""
8 | def __init__(self):
9 | super().__init__("The cache should be inited before using")
10 |
11 |
12 | class RemoveError(CacheError):
13 |     """Raise when removing data from the cache fails"""
14 | def __init__(self):
15 | super().__init__("The cache remove error")
16 |
17 | class NotFoundError(CacheError):
18 | """Raise when getting an unsupported store."""
19 | def __init__(self, store_type, current_type_name):
20 |         super().__init__(f"Unsupported {store_type}: {current_type_name}")
21 |
22 |
23 | class ParamError(CacheError):
24 | """Raise when receiving an invalid param."""
25 |
26 |
27 | class PipInstallError(CacheError):
28 | """Raise when failed to install package."""
29 | def __init__(self, package):
30 | super().__init__(f"Ran into error installing {package}.")
31 |
--------------------------------------------------------------------------------
/modelcache/utils/index_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | def get_index_name(model):
5 | return 'modelcache' + '_' + model
6 |
7 |
8 | def get_index_prefix(model):
9 | return 'prefix' + '_' + model
10 |
--------------------------------------------------------------------------------
/modelcache/utils/lazy_import.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import importlib
3 | from types import ModuleType
4 |
5 |
6 | class LazyImport(ModuleType):
7 | """
8 | Lazily import a module.
9 | """
10 |
11 | def __init__(self, local_name, parent_module_globals, name):
12 | self._local_name = local_name
13 | self._parent_module_globals = parent_module_globals
14 | super().__init__(name)
15 |
16 | def _load(self):
17 | module = importlib.import_module(self.__name__)
18 | self._parent_module_globals[self._local_name] = module
19 | self.__dict__.update(module.__dict__)
20 | return module
21 |
22 | def __getattr__(self, item):
23 | module = self._load()
24 | return getattr(module, item)
25 |
26 | def __dir__(self):
27 | module = self._load()
28 | return dir(module)
29 |
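LazyImport defers the real `importlib.import_module` call until the first attribute access, which is how the embedding and manager packages keep optional heavy dependencies out of import time. A tiny sketch using numpy as a stand-in module:

from modelcache.utils.lazy_import import LazyImport

np_lazy = LazyImport("np_lazy", globals(), "numpy")
arr = np_lazy.zeros(3)  # numpy is imported here, on first attribute access, not at definition time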
--------------------------------------------------------------------------------
/modelcache/utils/log.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | FORMAT = '%(asctime)s - %(thread)d - %(filename)s-%(module)s:%(lineno)s - %(levelname)s: %(message)s'
5 | logging.basicConfig(format=FORMAT)
6 |
7 | modelcache_log = logging.getLogger('modelcache')
8 |
--------------------------------------------------------------------------------
/modelcache/utils/model_filter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | def model_blacklist_filter(model, request_type):
3 | black_list = ['DI_COPILOT_SECOND', 'DI_COPILOT_LAB', 'DI_COPILOT_THIRD']
4 | result = None
5 | if model in black_list:
6 | if request_type == 'query':
7 | result = {"errorCode": 105,
8 | "errorDesc": "model: {} in blacklist".format(model),
9 | "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
10 | elif request_type == 'insert':
11 | result = {"errorCode": 305, "errorDesc": "model: {} in blacklist".format(model), "writeStatus": ""}
12 |
13 | return result
14 |
15 |
16 |
--------------------------------------------------------------------------------
/modelcache/utils/time.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | from modelcache import cache
4 |
5 |
6 | def time_cal(func, func_name=None, report_func=None):
7 | def inner(*args, **kwargs):
8 | time_start = time.time()
9 | res = func(*args, **kwargs)
10 | delta_time = time.time() - time_start
11 | if cache.config.log_time_func:
12 | cache.config.log_time_func(
13 | func.__name__ if func_name is None else func_name, delta_time
14 | )
15 | if report_func is not None:
16 | report_func(delta_time)
17 | return res
18 |
19 | return inner
20 |
21 |
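A hedged sketch of wrapping a callable with `time_cal`: the measured delta is forwarded to the report callback, and to `cache.config.log_time_func` when one is configured. The embedding function below is a stand-in, not a real model.

from modelcache import cache
from modelcache.utils.time import time_cal


def embed(text):
    return [0.0] * 768  # stand-in embedding


timed_embed = time_cal(embed, func_name="embedding", report_func=cache.report.embedding)
vec = timed_embed("hello")  # runtime is accumulated in cache.report.embedding_all_time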
--------------------------------------------------------------------------------
/modelcache_mm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm.core import Cache
3 | from modelcache_mm.core import cache
4 | from modelcache_mm.config import Config
5 | import modelcache_mm.adapter
6 |
--------------------------------------------------------------------------------
/modelcache_mm/adapter/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modelcache_mm/adapter/adapter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | from modelcache_mm.adapter.adapter_query import adapt_query
5 | from modelcache_mm.adapter.adapter_insert import adapt_insert
6 | from modelcache_mm.adapter.adapter_remove import adapt_remove
7 | from modelcache_mm.adapter.adapter_register import adapt_register
8 |
9 |
10 | class ChatCompletion(object):
11 |     """OpenAI ChatCompletion wrapper"""
12 | @classmethod
13 | def create_query(cls, *args, **kwargs):
14 | def cache_data_convert(cache_data, cache_query):
15 | return construct_resp_from_cache(cache_data, cache_query)
16 | try:
17 | return adapt_query(
18 | cache_data_convert,
19 | *args,
20 | **kwargs
21 | )
22 | except Exception as e:
23 | # return str(e)
24 | raise e
25 |
26 | @classmethod
27 | def create_insert(cls, *args, **kwargs):
28 | try:
29 | return adapt_insert(
30 | *args,
31 | **kwargs
32 | )
33 | except Exception as e:
34 | # return str(e)
35 | raise e
36 |
37 | @classmethod
38 | def create_remove(cls, *args, **kwargs):
39 | try:
40 | return adapt_remove(
41 | *args,
42 | **kwargs
43 | )
44 | except Exception as e:
45 | raise e
46 |
47 | @classmethod
48 | def create_register(cls, *args, **kwargs):
49 | try:
50 | return adapt_register(
51 | *args,
52 | **kwargs
53 | )
54 | except Exception as e:
55 | raise e
56 |
57 |
58 | def construct_resp_from_cache(return_message, return_query):
59 | return {
60 | "modelcache": True,
61 | "hitQuery": return_query,
62 | "data": return_message,
63 | "errorCode": 0
64 | }
65 |
--------------------------------------------------------------------------------
/modelcache_mm/adapter/adapter_insert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from modelcache_mm import cache
4 | from modelcache_mm.utils.error import NotInitError
5 | from modelcache_mm.utils.time import time_cal
6 |
7 |
8 | def adapt_insert(*args, **kwargs):
9 | chat_cache = kwargs.pop("cache_obj", cache)
10 | model = kwargs.pop("model", None)
11 | require_object_store = kwargs.pop("require_object_store", False)
12 | if require_object_store:
13 | assert chat_cache.data_manager.o, "Object store is required for adapter."
14 | if not chat_cache.has_init:
15 | raise NotInitError()
16 | cache_enable = chat_cache.cache_enable_func(*args, **kwargs)
17 | context = kwargs.pop("cache_context", {})
18 | embedding_data = None
19 | pre_embedding_data_dict = chat_cache.insert_pre_embedding_func(
20 | kwargs,
21 | extra_param=context.get("pre_embedding_func", None),
22 | prompts=chat_cache.config.prompts,
23 | )
24 |
25 | chat_info = kwargs.pop("chat_info", [])
26 | llm_data = chat_info[-1]['answer']
27 |
28 | pre_embedding_text = '###'.join(pre_embedding_data_dict['text'])
29 | pre_embedding_image_url = pre_embedding_data_dict['imageUrl']
30 | pre_embedding_image_raw = pre_embedding_data_dict['imageRaw']
31 | pre_embedding_image_id = pre_embedding_data_dict.get('imageId', None)
32 |
33 | if pre_embedding_image_url and pre_embedding_image_raw:
34 | raise ValueError("Both pre_embedding_image_url and pre_embedding_image_raw cannot be non-empty at the same time.")
35 |
36 | if pre_embedding_image_url:
37 | pre_embedding_image = pre_embedding_image_url
38 | elif pre_embedding_image_raw:
39 | pre_embedding_image = pre_embedding_image_raw
40 | else:
41 | pre_embedding_image = None
42 | if not pre_embedding_text:
43 | raise ValueError(
44 | "Both pre_embedding_image_url and pre_embedding_image_raw are empty. Please provide at least one.")
45 |
46 | data_dict = {'text': [pre_embedding_text], 'image': pre_embedding_image}
47 | embedding_data = None
48 | mm_type = None
49 |
50 | if cache_enable:
51 | embedding_data_resp = time_cal(
52 | chat_cache.embedding_func,
53 | func_name="image_embedding",
54 | report_func=chat_cache.report.embedding,
55 | )(data_dict)
56 |
57 | image_embeddings = embedding_data_resp['image_embedding']
58 | text_embeddings = embedding_data_resp['text_embeddings']
59 |
60 |         if len(image_embeddings) > 0 and len(text_embeddings) > 0:
61 | # image_embedding = np.array(image_embeddings[0])
62 | # text_embedding = text_embeddings[0]
63 | embedding_data = np.concatenate((image_embeddings, text_embeddings))
64 | mm_type = 'mm'
65 | elif len(image_embeddings) > 0:
66 | image_embedding = np.array(image_embeddings[0])
67 | embedding_data = image_embedding
68 | mm_type = 'image'
69 | elif len(text_embeddings) > 0:
70 | text_embedding = np.array(text_embeddings[0])
71 | embedding_data = text_embedding
72 | mm_type = 'text'
73 | else:
74 |             raise ValueError('maya embedding service returned empty lists for both image and text, please check!')
75 | chat_cache.data_manager.save(
76 | pre_embedding_text,
77 | pre_embedding_image_url,
78 | pre_embedding_image_id,
79 | llm_data,
80 | embedding_data,
81 | model=model,
82 | mm_type=mm_type,
83 | extra_param=context.get("mm_save_func", None)
84 | )
85 | return 'success'
86 |
--------------------------------------------------------------------------------
/modelcache_mm/adapter/adapter_register.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm import cache
3 |
4 |
5 | def adapt_register(*args, **kwargs):
6 | chat_cache = kwargs.pop("cache_obj", cache)
7 | model = kwargs.pop("model", None)
8 | type = kwargs.pop("type", None)
9 | if model is None or len(model) == 0:
10 |         raise ValueError('model name is required and cannot be empty')
11 |
12 | register_resp = chat_cache.data_manager.create_index(model, type)
13 | return register_resp
14 |
--------------------------------------------------------------------------------
/modelcache_mm/adapter/adapter_remove.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm import cache
3 | from modelcache_mm.utils.error import NotInitError
4 |
5 |
6 | def adapt_remove(*args, **kwargs):
7 | chat_cache = kwargs.pop("cache_obj", cache)
8 | model = kwargs.pop("model", None)
9 | remove_type = kwargs.pop("remove_type", None)
10 | require_object_store = kwargs.pop("require_object_store", False)
11 | if require_object_store:
12 | assert chat_cache.data_manager.o, "Object store is required for adapter."
13 | if not chat_cache.has_init:
14 | raise NotInitError()
15 |
16 | # delete data
17 | if remove_type == 'delete_by_id':
18 | id_list = kwargs.pop("id_list", [])
19 | resp = chat_cache.data_manager.delete(id_list, model=model)
20 | elif remove_type == 'truncate_by_model':
21 | resp = chat_cache.data_manager.truncate(model)
22 | else:
23 | resp = "remove_type_error"
24 | return resp
25 |
--------------------------------------------------------------------------------
/modelcache_mm/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Optional, Callable, List
3 | from modelcache.utils.error import CacheError
4 |
5 |
6 | class Config:
7 |
8 | def __init__(
9 | self,
10 | log_time_func: Optional[Callable[[str, float], None]] = None,
11 | similarity_threshold: float = 0.95,
12 | similarity_threshold_long: float = 0.95,
13 | prompts: Optional[List[str]] = None
14 | ):
15 | if similarity_threshold < 0 or similarity_threshold > 1:
16 | raise CacheError(
17 |                 "Invalid similarity threshold param, valid range: 0-1"
18 | )
19 | self.log_time_func = log_time_func
20 | self.similarity_threshold = similarity_threshold
21 | self.similarity_threshold_long = similarity_threshold_long
22 | self.prompts = prompts
23 |
--------------------------------------------------------------------------------
/modelcache_mm/config/chromadb_config.ini:
--------------------------------------------------------------------------------
1 | [chromadb]
2 | persist_directory=./chromadb
3 |
--------------------------------------------------------------------------------
/modelcache_mm/config/elasticsearch_config.ini:
--------------------------------------------------------------------------------
1 | [elasticsearch]
2 | host = ''
3 | port = ''
4 | user = ''
5 | password = ''
--------------------------------------------------------------------------------
/modelcache_mm/config/milvus_config.ini:
--------------------------------------------------------------------------------
1 | [milvus]
2 | host = ''
3 | port = ''
4 | user = ''
5 | password = ''
--------------------------------------------------------------------------------
/modelcache_mm/config/mysql_config.ini:
--------------------------------------------------------------------------------
1 | [mysql]
2 | host = ''
3 | port = ''
4 | username = ''
5 | password = ''
6 | database = ''
7 |
--------------------------------------------------------------------------------
/modelcache_mm/config/redis_config.ini:
--------------------------------------------------------------------------------
1 | [redis]
2 | host = ''
3 | port = ''
4 | user = ''
5 | password = ''
6 |
--------------------------------------------------------------------------------
/modelcache_mm/core.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import atexit
3 | from typing import Optional, List, Any
4 | from modelcache_mm.processor.post import first
5 | from modelcache_mm.similarity_evaluation import ExactMatchEvaluation
6 | from modelcache_mm.similarity_evaluation import SimilarityEvaluation
7 | from modelcache_mm.embedding.string import to_embeddings as string_embedding
8 | from modelcache_mm.report import Report
9 | from modelcache_mm.config import Config
10 | from modelcache_mm.utils.cache_func import cache_all
11 | from modelcache_mm.utils.log import modelcache_log
12 | from modelcache_mm.manager import get_data_manager
13 | from modelcache_mm.manager.data_manager import DataManager
14 |
15 |
16 | class Cache:
17 | def __init__(self):
18 | self.has_init = False
19 | self.cache_enable_func = None
20 | self.query_pre_embedding_func = None
21 | self.insert_pre_embedding_func = None
22 | self.embedding_func = None
23 | self.data_manager: Optional[DataManager] = None
24 | self.similarity_evaluation: Optional[SimilarityEvaluation] = None
25 | self.post_process_messages_func = None
26 | self.config = Config()
27 | self.report = Report()
28 | self.next_cache = None
29 |
30 | def init(
31 | self,
32 | cache_enable_func=cache_all,
33 | query_pre_embedding_func=None,
34 | insert_pre_embedding_func=None,
35 | embedding_func=string_embedding,
36 | data_manager: DataManager = get_data_manager(),
37 | similarity_evaluation=ExactMatchEvaluation(),
38 | post_process_messages_func=first,
39 | config=Config(),
40 | next_cache=None,
41 | ):
42 | self.has_init = True
43 | self.cache_enable_func = cache_enable_func
44 | self.query_pre_embedding_func = query_pre_embedding_func
45 | self.insert_pre_embedding_func = insert_pre_embedding_func
46 | self.embedding_func = embedding_func
47 | self.data_manager: DataManager = data_manager
48 | self.similarity_evaluation = similarity_evaluation
49 | self.post_process_messages_func = post_process_messages_func
50 | self.config = config
51 | self.next_cache = next_cache
52 |
53 | @atexit.register
54 | def close():
55 | try:
56 | self.data_manager.close()
57 | except Exception as e:
58 | modelcache_log.error(e)
59 |
60 | def import_data(self, questions: List[Any], answers: List[Any]) -> None:
61 | self.data_manager.import_data(
62 | questions=questions,
63 | answers=answers,
64 | embedding_datas=[self.embedding_func(question) for question in questions],
65 | )
66 |
67 | def flush(self):
68 | self.data_manager.flush()
69 | if self.next_cache:
70 | self.next_cache.data_manager.flush()
71 |
72 |
73 | cache = Cache()
74 |
--------------------------------------------------------------------------------
/modelcache_mm/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 | clip = LazyImport("clip", globals(), "modelcache_mm.embedding.clip")
4 |
5 |
6 | def Clip2Vec(model="damo/multi-modal_clip-vit-base-patch16_zh"):
7 | return clip.ClipAudio(model)
8 |
--------------------------------------------------------------------------------
/modelcache_mm/embedding/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseEmbedding(metaclass=ABCMeta):
6 | """
7 |     Embedding base.
8 | """
9 |
10 | @abstractmethod
11 | def to_embeddings(self, data, **kwargs):
12 | pass
13 |
14 | @property
15 | @abstractmethod
16 | def dimension(self) -> int:
17 | return 0
18 |
--------------------------------------------------------------------------------
/modelcache_mm/embedding/clip.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from modelcache.embedding.base import BaseEmbedding
4 | from modelscope.utils.constant import Tasks
5 | from modelscope.pipelines import pipeline
6 | from modelscope.preprocessors.image import load_image
7 |
8 |
9 | class ClipAudio(BaseEmbedding):
10 | def __init__(self, model: str = 'damo/multi-modal_clip-vit-base-patch16_zh'):
11 | self.model = model
12 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
13 | self.clip_pipeline = pipeline(task=Tasks.multi_modal_embedding,
14 | model=model, model_revision='v1.0.1')
15 | self.__dimension = 1024
16 |
17 | def to_embeddings(self, data_dict, **_):
18 | text_list = data_dict['text']
19 | image_data = data_dict['image']
20 |
21 | # img_data = None
22 | # txt_data = None
23 |
24 | if image_data:
25 | input_img = load_image(image_data)
26 | img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding'].tolist()[0] if input_img else []
27 | else:
28 | raise ValueError('image_data is None, please check!')
29 |
30 | if text_list and len(text_list) > 0:
31 | text_embedding = self.clip_pipeline.forward({'text': text_list})['text_embedding'].tolist()[0] if text_list else []
32 | else:
33 | raise ValueError('text_list is None, please check!')
34 |
35 | return {'image_embedding': img_embedding, 'text_embeddings': text_embedding}
36 |
37 | def post_proc(self, token_embeddings, inputs):
38 | attention_mask = inputs["attention_mask"]
39 | input_mask_expanded = (
40 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
41 | )
42 | sentence_embs = torch.sum(
43 | token_embeddings * input_mask_expanded, 1
44 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
45 | return sentence_embs
46 |
47 | @property
48 | def dimension(self):
49 | return self.__dimension
50 |
--------------------------------------------------------------------------------
/modelcache_mm/embedding/string.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | def to_embeddings(data, **_):
5 | return data
6 |
--------------------------------------------------------------------------------
/modelcache_mm/embedding/timm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from modelcache.utils import import_timm, import_torch, import_pillow
5 | from modelcache.embedding.base import BaseEmbedding
6 |
7 | import_torch()
8 | import_timm()
9 | import_pillow()
10 |
11 | import torch
12 | from timm.models import create_model
13 | from timm.data import create_transform, resolve_data_config
14 | from PIL import Image
15 |
16 |
17 | class Timm(BaseEmbedding):
18 | def __init__(self, model: str = "resnet18", device: str = "default"):
19 | if device == "default":
20 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
21 | else:
22 | self.device = device
23 | self.model_name = model
24 | self.model = create_model(model_name=model, pretrained=True)
25 | self.model.eval()
26 |
27 | try:
28 | self.__dimension = self.model.embed_dim
29 | except Exception:
30 | self.__dimension = None
31 |
32 | def to_embeddings(self, data, skip_preprocess: bool = False, **_):
33 | if not skip_preprocess:
34 | data = self.preprocess(data)
35 | if data.dim() == 3:
36 | data = data.unsqueeze(0)
37 | feats = self.model.forward_features(data)
38 | emb = self.post_proc(feats).squeeze(0).detach().numpy()
39 |
40 | return np.array(emb).astype("float32")
41 |
42 | def post_proc(self, features):
43 | features = features.to("cpu")
44 | if features.dim() == 3:
45 | features = features[:, 0]
46 | if features.dim() == 4:
47 | global_pool = torch.nn.AdaptiveAvgPool2d(1)
48 | features = global_pool(features)
49 | features = features.flatten(1)
50 | assert features.dim() == 2, f"Invalid output dim {features.dim()}"
51 | return features
52 |
53 | def preprocess(self, image_path):
54 | data_cfg = resolve_data_config(self.model.pretrained_cfg)
55 | transform = create_transform(**data_cfg)
56 |
57 | image = Image.open(image_path).convert("RGB")
58 | image_tensor = transform(image)
59 | return image_tensor
60 |
61 | @property
62 | def dimension(self):
63 | """Embedding dimension.
64 | :return: embedding dimension
65 | """
66 | if not self.__dimension:
67 | input_size = self.model.pretrained_cfg["input_size"]
68 | dummy_input = torch.rand((1,) + input_size)
69 | feats = self.to_embeddings(dummy_input, skip_preprocess=True)
70 | self.__dimension = feats.shape[0]
71 | return self.__dimension
72 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm.manager.scalar_data import CacheBase
3 | from modelcache_mm.manager.vector_data import VectorBase
4 | from modelcache_mm.manager.object_data import ObjectBase
5 | from modelcache_mm.manager.factory import get_data_manager
6 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/eviction/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 |
4 | eviction_manager = LazyImport(
5 | "eviction_manager", globals(), "modelcache.manager.eviction.manager"
6 | )
7 |
8 |
9 | def EvictionBase(name: str, **kwargs):
10 | return eviction_manager.EvictionBase.get(name, **kwargs)
11 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/eviction/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Any, List
4 |
5 |
6 | class EvictionBase(metaclass=ABCMeta):
7 | """
8 | Eviction base.
9 | """
10 |
11 | @abstractmethod
12 | def put(self, objs: List[Any]):
13 | pass
14 |
15 | @abstractmethod
16 | def get(self, obj: Any):
17 | pass
18 |
19 | @property
20 | @abstractmethod
21 | def policy(self) -> str:
22 | pass
23 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/eviction_manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | class EvictionManager:
3 | MAX_MARK_COUNT = 5000
4 | MAX_MARK_RATE = 0.1
5 | BATCH_SIZE = 100000
6 | REBUILD_CONDITION = 5
7 |
8 | def __init__(self, scalar_storage, vector_base):
9 | self._scalar_storage = scalar_storage
10 | self._vector_base = vector_base
11 | self.delete_count = 0
12 |
13 | def check_evict(self):
14 | mark_count = self._scalar_storage.count(state=-1)
15 | all_count = self._scalar_storage.count(is_all=True)
16 | if (
17 | mark_count > self.MAX_MARK_COUNT
18 | or mark_count / all_count > self.MAX_MARK_RATE
19 | ):
20 | return True
21 | return False
22 |
23 | def delete(self):
24 | mark_ids = self._scalar_storage.get_ids(deleted=True)
25 | self._scalar_storage.clear_deleted_data()
26 | self._vector_base.delete(mark_ids)
27 | self.delete_count += 1
28 | if self.delete_count >= self.REBUILD_CONDITION:
29 | self.rebuild()
30 |
31 | def rebuild(self):
32 | self._scalar_storage.clear_deleted_data()
33 | ids = self._scalar_storage.get_ids(deleted=False)
34 | self._vector_base.rebuild(ids)
35 | self.delete_count = 0
36 |
37 | def soft_evict(self, marked_keys):
38 | self._scalar_storage.mark_deleted(marked_keys)
39 |
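A sketch of the intended eviction flow, with `scalar_storage` and `vector_base` standing in for already-constructed stores: entries are first soft-deleted, and a hard purge runs once the marked rows pass the count or ratio thresholds above.

manager = EvictionManager(scalar_storage, vector_base)  # placeholders for configured stores
manager.soft_evict([3, 7, 11])  # mark these cache ids as deleted in the scalar store
if manager.check_evict():       # too many soft-deleted rows, by absolute count or ratio
    manager.delete()            # purge them, drop the vectors, and rebuild every 5th round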
--------------------------------------------------------------------------------
/modelcache_mm/manager/factory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Union, Callable
3 | from modelcache_mm.manager import CacheBase, VectorBase, ObjectBase
4 | from modelcache_mm.manager.data_manager import SSDataManager, MapDataManager
5 |
6 |
7 | def get_data_manager(
8 | cache_base: Union[CacheBase, str] = None,
9 | vector_base: Union[VectorBase, str] = None,
10 | object_base: Union[ObjectBase, str] = None,
11 | max_size: int = 1000,
12 | clean_size: int = None,
13 | eviction: str = "LRU",
14 | data_path: str = "data_map.txt",
15 | get_data_container: Callable = None,
16 | ):
17 | if not cache_base and not vector_base:
18 | return MapDataManager(data_path, max_size, get_data_container)
19 | if isinstance(cache_base, str):
20 | cache_base = CacheBase(name=cache_base)
21 | if isinstance(vector_base, str):
22 | vector_base = VectorBase(name=vector_base)
23 | if isinstance(object_base, str):
24 | object_base = ObjectBase(name=object_base)
25 | assert cache_base and vector_base
26 | return SSDataManager(cache_base, vector_base, object_base, max_size, clean_size, eviction)
27 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/object_data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.utils.lazy_import import LazyImport
3 | object_manager = LazyImport(
4 | "object_manager", globals(), "modelcache.manager.object_data.manager"
5 | )
6 |
7 |
8 | def ObjectBase(name: str, **kwargs):
9 | return object_manager.ObjectBase.get(name, **kwargs)
10 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/object_data/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABC, abstractmethod
3 | from typing import Any, List
4 |
5 |
6 | class ObjectBase(ABC):
7 | """
8 | Object storage base.
9 | """
10 |
11 | @abstractmethod
12 | def put(self, obj: Any) -> str:
13 | pass
14 |
15 | @abstractmethod
16 | def get(self, obj: str) -> Any:
17 | pass
18 |
19 | @abstractmethod
20 | def get_access_link(self, obj: str) -> str:
21 | pass
22 |
23 | @abstractmethod
24 | def delete(self, to_delete: List[str]):
25 | pass
26 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/scalar_data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm.utils.lazy_import import LazyImport
3 | scalar_manager = LazyImport(
4 | "scalar_manager", globals(), "modelcache_mm.manager.scalar_data.manager"
5 | )
6 |
7 |
8 | def CacheBase(name: str, **kwargs):
9 | return scalar_manager.CacheBase.get(name, **kwargs)
10 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/scalar_data/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 | from dataclasses import dataclass
4 | from typing import Union, Dict, List, Optional, Any
5 | from enum import IntEnum
6 | import numpy as np
7 |
8 |
9 | class DataType(IntEnum):
10 | STR = 0
11 | IMAGE_BASE64 = 1
12 | IMAGE_URL = 2
13 |
14 |
15 | @dataclass
16 | class QuestionDep:
17 | """
18 | QuestionDep
19 | """
20 |
21 | name: str
22 | data: str
23 | dep_type: int = DataType.STR
24 |
25 | @classmethod
26 | def from_dict(cls, d: Dict):
27 | return cls(
28 | name=d["name"],
29 | data=d["data"],
30 | dep_type=d["dep_type"]
31 | )
32 |
33 |
34 | @dataclass
35 | class Question:
36 | """
37 | Question
38 | """
39 |
40 | content: str
41 | deps: Optional[List[QuestionDep]] = None
42 |
43 | @classmethod
44 | def from_dict(cls, d: Dict):
45 | deps = []
46 | for dep in d["deps"]:
47 | deps.append(QuestionDep.from_dict(dep))
48 | return cls(d["content"], deps)
49 |
50 |
51 | @dataclass
52 | class Answer:
53 | """
54 | data_type:
55 | 0: str
56 | 1: base64 image
57 | """
58 |
59 | answer: Any
60 | answer_type: int = DataType.STR
61 |
62 |
63 | @dataclass
64 | class CacheData:
65 | """
66 | CacheData
67 | """
68 |
69 | question: Union[str, Question]
70 | answers: List[Answer]
71 | embedding_data: Optional[np.ndarray] = None
72 |
73 | def __init__(self, question, answers, embedding_data=None):
74 | self.question = question
75 | self.answers = []
76 | if isinstance(answers, (str, Answer)):
77 | answers = [answers]
78 | for data in answers:
79 | if isinstance(data, (list, tuple)):
80 | self.answers.append(Answer(*data))
81 | elif isinstance(data, Answer):
82 | self.answers.append(data)
83 | else:
84 | self.answers.append(Answer(answer=data))
85 | self.embedding_data = embedding_data
86 |
87 |
88 | class CacheStorage(metaclass=ABCMeta):
89 | """
90 | BaseStorage for scalar data.
91 | """
92 |
93 | @abstractmethod
94 | def create(self):
95 | pass
96 |
97 | @abstractmethod
98 | def batch_insert(self, all_data: List[CacheData]):
99 | pass
100 |
101 | @abstractmethod
102 | def insert_query_resp(self, query_resp, **kwargs):
103 | pass
104 |
105 | @abstractmethod
106 | def get_data_by_id(self, key):
107 | pass
108 |
109 | @abstractmethod
110 | def mark_deleted(self, keys):
111 | pass
112 |
113 | @abstractmethod
114 | def model_deleted(self, model_name):
115 | pass
116 |
117 | @abstractmethod
118 | def clear_deleted_data(self):
119 | pass
120 |
121 | @abstractmethod
122 | def get_ids(self, deleted=True):
123 | pass
124 |
125 | @abstractmethod
126 | def count(self):
127 | pass
128 |
129 | def flush(self):
130 | pass
131 |
132 | @abstractmethod
133 | def close(self):
134 | pass
135 |
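The `CacheData` constructor above normalizes whatever it is given into a list of `Answer` objects: a bare string becomes one `Answer`, tuples are unpacked into `Answer(*data)`, and existing `Answer` instances pass through unchanged. A short illustration:

from modelcache_mm.manager.scalar_data.base import Answer, CacheData, DataType

single = CacheData("what is modelcache?", "a semantic cache for LLM serving")
print(len(single.answers), int(single.answers[0].answer_type))  # 1 0

mixed = CacheData("q", [Answer("plain text"), ("base64-bytes...", DataType.IMAGE_BASE64)])
print([int(a.answer_type) for a in mixed.answers])              # [0, 1]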
--------------------------------------------------------------------------------
/modelcache_mm/manager/scalar_data/manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm.utils import import_sql_client
3 | from modelcache_mm.utils.error import NotFoundError
4 |
5 | SQL_URL = {"sqlite": "./sqlite.db"}
6 |
7 |
8 | class CacheBase:
9 | """
10 | CacheBase to manager the cache storage.
11 | """
12 |
13 | def __init__(self):
14 | raise EnvironmentError(
15 |             "CacheBase is not meant to be instantiated directly, please use `CacheBase.get(name)`."
16 | )
17 |
18 | @staticmethod
19 | def get(name, **kwargs):
20 |
21 | if name in ["mysql", "oceanbase"]:
22 | from modelcache_mm.manager.scalar_data.sql_storage import SQLStorage
23 | config = kwargs.get("config")
24 | import_sql_client(name)
25 | cache_base = SQLStorage(db_type=name, config=config)
26 | elif name == 'sqlite':
27 | from modelcache_mm.manager.scalar_data.sql_storage_sqlite import SQLStorage
28 | sql_url = kwargs.get("sql_url", SQL_URL[name])
29 | cache_base = SQLStorage(db_type=name, url=sql_url)
30 | else:
31 | raise NotFoundError("cache store", name)
32 | return cache_base
33 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/scalar_data/sql_storage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 |
5 | import pymysql
6 | import json
7 | import base64
8 | from typing import List
9 | from modelcache_mm.manager.scalar_data.base import CacheStorage, CacheData
10 | from DBUtils.PooledDB import PooledDB
11 |
12 |
13 | class SQLStorage(CacheStorage):
14 | def __init__(
15 | self,
16 | db_type: str = "mysql",
17 | config=None
18 | ):
19 | self.host = config.get('mysql', 'host')
20 | self.port = int(config.get('mysql', 'port'))
21 | self.username = config.get('mysql', 'username')
22 | self.password = config.get('mysql', 'password')
23 | self.database = config.get('mysql', 'database')
24 | self.pool = PooledDB(
25 | creator=pymysql,
26 | host=self.host,
27 | user=self.username,
28 | password=self.password,
29 | port=self.port,
30 | database=self.database
31 | )
32 |
33 | def create(self):
34 | pass
35 |
36 | # def _insert(self, data: List):
37 | # answer = data[0]
38 | # text = data[1]
39 | # image_url = data[2]
40 | # image_id = data[3]
41 | # model = data[4]
42 | # answer_type = 0
43 | #
44 | # table_name = "multimodal_answer"
45 | # insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (%s, %s, %s, %s, %s, %s)".format(table_name)
46 | # conn = self.pool.connection()
47 | # try:
48 | # with conn.cursor() as cursor:
49 | # # data insert operation
50 | # values = (text, image_url, image_id, answer, answer_type, model)
51 | # cursor.execute(insert_sql, values)
52 | # conn.commit()
53 | # id = cursor.lastrowid
54 | # finally:
55 | # # Close the connection and return it back to the connection pool
56 | # conn.close()
57 | # return id
58 |
59 | def _insert(self, data: List):
60 | answer = data[0]
61 | text = data[1]
62 | image_url = data[2]
63 | image_id = data[3]
64 | model = data[4]
65 | answer_type = 0
66 |
67 | table_name = "open_cache_mm_answer"
68 | insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (%s, %s, %s, %s, %s, %s)".format(table_name)
69 |
70 | conn = self.pool.connection()
71 | try:
72 | with conn.cursor() as cursor:
73 | # insert data operation
74 | values = (text, image_url, image_id, answer, answer_type, model)
75 | cursor.execute(insert_sql, values)
76 | conn.commit()
77 | id = cursor.lastrowid
78 | finally:
79 | # Close the connection and return it to the connection pool.
80 | conn.close()
81 | return id
82 |
83 | def batch_insert(self, all_data: List[CacheData]):
84 | ids = []
85 | for data in all_data:
86 | ids.append(self._insert(data))
87 | return ids
88 |
89 | def insert_query_resp(self, query_resp, **kwargs):
90 | error_code = query_resp.get('errorCode')
91 | error_desc = query_resp.get('errorDesc')
92 | cache_hit = query_resp.get('cacheHit')
93 | model = kwargs.get('model')
94 | query = kwargs.get('query')
95 | delta_time = kwargs.get('delta_time')
96 | hit_query = query_resp.get('hit_query')
97 | answer = query_resp.get('answer')
98 |
99 | if isinstance(hit_query, list):
100 | hit_query = json.dumps(hit_query, ensure_ascii=False)
101 |
102 | table_name = "open_cache_mm_query_log"
103 | insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)
104 | conn = self.pool.connection()
105 | try:
106 | with conn.cursor() as cursor:
107 |                 # Execute the insert
108 | values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer)
109 | cursor.execute(insert_sql, values)
110 | conn.commit()
111 | finally:
112 |             # Close the connection and return it to the connection pool
113 | conn.close()
114 |
115 | # def get_data_by_id(self, key: int):
116 | # table_name = "cache_codegpt_answer"
117 | # query_sql = "select question, answer, embedding_data, model from {} where id={}".format(table_name, key)
118 | # conn_start = time.time()
119 | # conn = self.pool.connection()
120 | #
121 | # search_start = time.time()
122 | # try:
123 | # with conn.cursor() as cursor:
124 | #             # Execute the query
125 | # cursor.execute(query_sql)
126 | # resp = cursor.fetchone()
127 | # finally:
128 | #             # Close the connection and return it to the connection pool
129 | # conn.close()
130 | #
131 | # if resp is not None and len(resp) == 4:
132 | # return resp
133 | # else:
134 | # return None
135 |
136 | def get_data_by_id(self, key: int):
137 | table_name = "open_cache_mm_answer"
138 | query_sql = "select question_text, image_url, image_id, answer, model from {} where id={}".format(table_name, key)
139 | conn = self.pool.connection()
140 | try:
141 | with conn.cursor() as cursor:
142 | cursor.execute(query_sql)
143 | resp = cursor.fetchone()
144 | finally:
145 | conn.close()
146 |
147 | if resp is not None and len(resp) == 5:
148 | return resp
149 | else:
150 | return None
151 |
152 | def update_hit_count_by_id(self, primary_id: int):
153 | table_name = "open_cache_mm_answer"
154 | update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id)
155 | conn = self.pool.connection()
156 |
157 | try:
158 | with conn.cursor() as cursor:
159 | cursor.execute(update_sql)
160 | conn.commit()
161 | finally:
162 | conn.close()
163 |
164 | def get_ids(self, deleted=True):
165 | pass
166 |
167 | def mark_deleted(self, keys):
168 | table_name = "open_cache_mm_answer"
169 | delete_sql = "Delete from {} WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys]))
170 |
171 | conn = self.pool.connection()
172 | try:
173 | with conn.cursor() as cursor:
174 | cursor.execute(delete_sql)
175 | delete_count = cursor.rowcount
176 | conn.commit()
177 | finally:
178 | conn.close()
179 | return delete_count
180 |
181 | def model_deleted(self, model_name):
182 | table_name = "open_cache_mm_answer"
183 | delete_sql = "Delete from {} WHERE model='{}'".format(table_name, model_name)
184 | conn = self.pool.connection()
185 | try:
186 | with conn.cursor() as cursor:
187 | resp = cursor.execute(delete_sql)
188 | conn.commit()
189 | finally:
190 | conn.close()
191 | return resp
192 |
193 | def clear_deleted_data(self):
194 | pass
195 |
196 | def count(self, state: int = 0, is_all: bool = False):
197 | pass
198 |
199 | def close(self):
200 | pass
201 |
202 | def count_answers(self):
203 | pass
204 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/scalar_data/sql_storage_sqlite.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | from typing import List
4 | from modelcache.manager.scalar_data.base import CacheStorage, CacheData
5 | import sqlite3
6 |
7 |
8 | class SQLStorage(CacheStorage):
9 | def __init__(
10 | self,
11 | db_type: str = "mysql",
12 | config=None,
13 | url="./sqlite.db"
14 | ):
15 | self._url = url
16 | self.create()
17 |
18 | def create(self):
19 | # answer_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_llm_answer (
20 | # id INTEGER PRIMARY KEY AUTOINCREMENT,
21 | # gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
22 | # gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
23 | # question TEXT NOT NULL,
24 | # answer TEXT NOT NULL,
25 | # answer_type INTEGER NOT NULL,
26 | # hit_count INTEGER NOT NULL DEFAULT 0,
27 | # model VARCHAR(1000) NOT NULL,
28 | # embedding_data BLOB NOT NULL
29 | # );
30 | # """
31 |
32 | answer_table_sql = """CREATE TABLE IF NOT EXISTS `open_cache_mm_answer` (
33 | `id` INTEGER PRIMARY KEY AUTOINCREMENT,
34 | `gmt_create` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
35 | `gmt_modified` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
36 | `question_text` TEXT NOT NULL,
37 | `image_url` VARCHAR(2048) NOT NULL,
38 | `answer` TEXT NOT NULL,
39 | `answer_type` INTEGER NOT NULL,
40 | `hit_count` INTEGER NOT NULL DEFAULT 0,
41 | `model` VARCHAR(1000) NOT NULL,
42 | `image_raw` BLOB DEFAULT NULL,
43 | `image_id` VARCHAR(1000) DEFAULT NULL
44 | );
45 | """
46 |
47 | log_table_sql = """CREATE TABLE IF NOT EXISTS modelcache_query_log (
48 | id INTEGER PRIMARY KEY AUTOINCREMENT,
49 | gmt_create TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
50 | gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
51 | error_code INTEGER NOT NULL,
52 | error_desc VARCHAR(1000) NOT NULL,
53 | cache_hit VARCHAR(100) NOT NULL,
54 | delta_time REAL NOT NULL,
55 | model VARCHAR(1000) NOT NULL,
56 | query TEXT NOT NULL,
57 | hit_query TEXT NOT NULL,
58 | answer TEXT NOT NULL
59 | );
60 | """
61 |
62 | conn = sqlite3.connect(self._url)
63 | try:
64 | cursor = conn.cursor()
65 | cursor.execute(answer_table_sql)
66 | cursor.execute(log_table_sql)
67 | conn.commit()
68 | cursor.close()
69 | conn.close()
70 | finally:
71 | conn.close()
72 |
73 | def _insert(self, data: List):
74 | answer = data[0]
75 | text = data[1]
76 | image_url = data[2]
77 | image_id = data[3]
78 | model = data[4]
79 | answer_type = 0
80 |
81 | table_name = "open_cache_mm_answer"
82 | insert_sql = "INSERT INTO {} (question_text, image_url, image_id, answer, answer_type, model) VALUES (?, ?, ?, ?, ?, ?)".format(table_name)
83 |
84 | conn = sqlite3.connect(self._url)
85 | try:
86 | cursor = conn.cursor()
87 | values = (text, image_url, image_id, answer, answer_type, model)
88 | cursor.execute(insert_sql, values)
89 | conn.commit()
90 | id = cursor.lastrowid
91 | cursor.close()
92 | conn.close()
93 | finally:
94 | conn.close()
95 | return id
96 |
97 | def batch_insert(self, all_data: List[CacheData]):
98 | ids = []
99 | for data in all_data:
100 | ids.append(self._insert(data))
101 | return ids
102 |
103 | def insert_query_resp(self, query_resp, **kwargs):
104 | error_code = query_resp.get('errorCode')
105 | error_desc = query_resp.get('errorDesc')
106 | cache_hit = query_resp.get('cacheHit')
107 | model = kwargs.get('model')
108 | query = kwargs.get('query')
109 | delta_time = kwargs.get('delta_time')
110 | hit_query = query_resp.get('hit_query')
111 | answer = query_resp.get('answer')
112 |
113 | if isinstance(hit_query, list):
114 | hit_query = json.dumps(hit_query, ensure_ascii=False)
115 |
116 | table_name = "modelcache_query_log"
117 | insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?)".format(table_name)
118 |
119 | conn = sqlite3.connect(self._url)
120 | try:
121 | cursor = conn.cursor()
122 | values = (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer)
123 | cursor.execute(insert_sql, values)
124 | conn.commit()
125 | cursor.close()
126 | conn.close()
127 | finally:
128 | conn.close()
129 |
130 | def get_data_by_id(self, key: int):
131 | table_name = "open_cache_mm_answer"
132 | query_sql = "select question_text, image_url, image_id, answer, model from {} where id={}".format(table_name, key)
133 | conn = sqlite3.connect(self._url)
134 | try:
135 | cursor = conn.cursor()
136 | cursor.execute(query_sql)
137 | resp = cursor.fetchone()
138 | conn.commit()
139 | cursor.close()
140 | conn.close()
141 | finally:
142 | conn.close()
143 |
144 | if resp is not None and len(resp) == 5:
145 | return resp
146 | else:
147 | return None
148 |
149 | def update_hit_count_by_id(self, primary_id: int):
150 | table_name = "open_cache_mm_answer"
151 | update_sql = "UPDATE {} SET hit_count = hit_count+1 WHERE id={}".format(table_name, primary_id)
152 |
153 | conn = sqlite3.connect(self._url)
154 | try:
155 | cursor = conn.cursor()
156 | cursor.execute(update_sql)
157 | conn.commit()
158 | cursor.close()
159 | conn.close()
160 | finally:
161 | # close the sqlite connection
162 | conn.close()
163 |
164 | def get_ids(self, deleted=True):
165 | pass
166 |
167 | def mark_deleted(self, keys):
168 | table_name = "open_cache_mm_answer"
169 | delete_sql = "Delete from {} WHERE id in ({})".format(table_name, ",".join([str(i) for i in keys]))
170 | conn = sqlite3.connect(self._url)
171 | try:
172 | cursor = conn.cursor()
173 | cursor.execute(delete_sql)
174 | delete_count = cursor.rowcount
175 | conn.commit()
176 | cursor.close()
177 | conn.close()
178 | finally:
179 | conn.close()
180 | return delete_count
181 |
182 | def model_deleted(self, model_name):
183 | table_name = "open_cache_mm_answer"
184 | delete_sql = "Delete from {} WHERE model='{}'".format(table_name, model_name)
185 | conn = sqlite3.connect(self._url)
186 | try:
187 | cursor = conn.cursor()
188 | resp = cursor.execute(delete_sql)
189 | conn.commit()
190 | cursor.close()
191 | conn.close()
192 | finally:
193 | conn.close()
194 | return resp
195 |
196 | def clear_deleted_data(self):
197 | pass
198 |
199 | def count(self, state: int = 0, is_all: bool = False):
200 | pass
201 |
202 | def close(self):
203 | pass
204 |
205 | def count_answers(self):
206 | pass
207 |
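
A minimal local-testing sketch of the SQLite-backed scalar store above; the database path, model name, and sample values are placeholders, and `_insert` is the internal helper that adapter-side code is expected to drive:

    from modelcache_mm.manager.scalar_data.sql_storage_sqlite import SQLStorage

    store = SQLStorage(url="./mm_cache_demo.db")          # creates the tables on first use
    new_id = store._insert(["an answer", "a question", "http://example.com/cat.png", "img-001", "demo_model"])
    row = store.get_data_by_id(new_id)                    # (question_text, image_url, image_id, answer, model)
    store.update_hit_count_by_id(new_id)
    deleted = store.mark_deleted([new_id])                # number of rows removed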
--------------------------------------------------------------------------------
/modelcache_mm/manager/vector_data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm.utils.lazy_import import LazyImport
3 |
4 | vector_manager = LazyImport(
5 | "vector_manager", globals(), "modelcache_mm.manager.vector_data.manager"
6 | )
7 |
8 |
9 | def VectorBase(name: str, **kwargs):
10 | return vector_manager.VectorBase.get(name, **kwargs)
11 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/vector_data/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABC, abstractmethod
3 | import numpy as np
4 | from typing import List
5 | from dataclasses import dataclass
6 |
7 |
8 | @dataclass
9 | class VectorData:
10 | id: int
11 | data: np.ndarray
12 |
13 |
14 | class VectorBase(ABC):
15 | """VectorBase: base vector store interface"""
16 |
17 | @abstractmethod
18 | def add(self, datas: List[VectorData], model=None, mm_type=None):
19 | pass
20 |
21 | # @abstractmethod
22 | # def search(self, data: np.ndarray, top_k: int, model):
23 | # pass
24 |
25 | @abstractmethod
26 | def search(self, data: np.ndarray, top_k: int, model, mm_type):
27 | pass
28 |
29 | @abstractmethod
30 | def create(self, model=None, mm_type=None):
31 | pass
32 |
33 | @abstractmethod
34 | def rebuild(self, ids=None) -> bool:
35 | pass
36 |
37 | @abstractmethod
38 | def delete(self, ids) -> bool:
39 | pass
40 |
41 | @abstractmethod
42 | def rebuild_idx(self, model):
43 | pass
44 |
45 | def flush(self):
46 | pass
47 |
48 | def close(self):
49 | pass
50 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/vector_data/chroma.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import logging
5 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
6 | from modelcache_mm.utils import import_chromadb, import_torch
7 | from modelcache_mm.utils.index_util import get_mm_index_name
8 |
9 | import_torch()
10 | import_chromadb()
11 |
12 | import chromadb
13 |
14 |
15 | class Chromadb(VectorBase):
16 |
17 | def __init__(
18 | self,
19 | persist_directory="./chromadb",
20 | top_k: int = 1,
21 | ):
22 | self.top_k = top_k
23 |
24 | self._client = chromadb.PersistentClient(path=persist_directory)
25 | self._collection = None
26 |
27 | def create(self, model=None, mm_type=None):
28 | try:
29 | collection_name_model = get_mm_index_name(model, mm_type)
30 | # collection_name_model = self.collection_name + '_' + model
31 | self._client.get_or_create_collection(name=collection_name_model)
32 | except Exception as e:
33 | raise ValueError(str(e))
34 |
35 | def add(self, datas: List[VectorData], model=None, mm_type=None):
36 | collection_name_model = get_mm_index_name(model, mm_type)
37 | self._collection = self._client.get_or_create_collection(name=collection_name_model)
38 |
39 | data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
40 | self._collection.add(embeddings=data_array, ids=id_array)
41 |
42 | def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type='mm'):
43 | collection_name_model = get_mm_index_name(model, mm_type)
44 | self._collection = self._client.get_or_create_collection(name=collection_name_model)
45 |
46 | if self._collection.count() == 0:
47 | return []
48 | if top_k == -1:
49 | top_k = self.top_k
50 | results = self._collection.query(
51 | query_embeddings=[data.tolist()],
52 | n_results=top_k,
53 | include=["distances"],
54 | )
55 | return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))
56 |
57 | def delete(self, ids, model=None, mm_type=None):
58 | try:
59 | collection_name_model = get_mm_index_name(model, mm_type)
60 | self._collection = self._client.get_or_create_collection(name=collection_name_model)
61 | # look up which of the requested IDs actually exist in the collection
62 | ids_str = [str(x) for x in ids]
63 | existing_ids = set(self._collection.get(ids=ids_str)['ids'])
64 |
65 | # delete the IDs that do exist
66 | if existing_ids:
67 | self._collection.delete(list(existing_ids))
68 |
69 | # return the number of entries actually deleted
70 | return len(existing_ids)
71 |
72 | except Exception as e:
73 | logging.error('Error during deletion: {}'.format(e))
74 | raise ValueError(str(e))
75 |
76 | def rebuild_idx(self, model, mm_type=None):
77 | collection_name_model = get_mm_index_name(model, mm_type)
78 |
79 | # check whether the collection exists; drop it if it does
80 | collections = self._client.list_collections()
81 | if any(col.name == collection_name_model for col in collections):
82 | self._client.delete_collection(collection_name_model)
83 | else:
84 | return 'model collection not found, please check!'
85 |
86 | try:
87 | self._client.create_collection(collection_name_model)
88 | except Exception as e:
89 | logging.info(f'rebuild_collection: {e}')
90 | raise ValueError(str(e))
91 |
92 | def rebuild(self, ids=None):
93 | pass
94 |
95 | def flush(self):
96 | pass
97 |
98 | def close(self):
99 | pass
100 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/vector_data/faiss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from typing import List
4 | import numpy as np
5 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
6 | from modelcache_mm.utils import import_faiss
7 | import_faiss()
8 | import faiss # pylint: disable=C0413
9 |
10 |
11 | class Faiss(VectorBase):
12 | def __init__(self,
13 | index_file_path,
14 | dimension: int = 0,
15 | top_k: int = 1
16 | ):
17 | self._dimension = dimension
18 | self._index_file_path = index_file_path
19 | self._index = faiss.index_factory(self._dimension, "IDMap,Flat", faiss.METRIC_L2)
20 | self._top_k = top_k
21 | if os.path.isfile(index_file_path):
22 | self._index = faiss.read_index(index_file_path)
23 |
24 | def add(self, datas: List[VectorData], model=None, mm_type=None):
25 | data_array, id_array = map(list, zip(*((data.data, data.id) for data in datas)))
26 | np_data = np.array(data_array).astype("float32")
27 | ids = np.array(id_array)
28 | self._index.add_with_ids(np_data, ids)
29 |
30 | def search(self, data: np.ndarray, top_k: int, model, mm_type='mm'):
31 | if self._index.ntotal == 0:
32 | return None
33 | if top_k == -1:
34 | top_k = self._top_k
35 | np_data = np.array(data).astype("float32").reshape(1, -1)
36 | dist, ids = self._index.search(np_data, top_k)
37 | ids = [int(i) for i in ids[0]]
38 | return list(zip(dist[0], ids))
39 |
40 | def rebuild_col(self, ids=None):
41 | try:
42 | self._index.reset()
43 | except Exception as e:
44 | return f"An error occurred during index rebuild: {e}"
45 |
46 | def rebuild(self, ids=None):
47 | return True
48 |
49 | def delete(self, ids):
50 | ids_to_remove = np.array(ids)
51 | self._index.remove_ids(faiss.IDSelectorBatch(ids_to_remove.size, faiss.swig_ptr(ids_to_remove)))
52 |
53 | def create(self, model=None, mm_type=None):
54 | pass
55 | # collection_name_model = get_mm_index_name(model, mm_type)
56 | # try:
57 | # index_prefix = get_mm_index_prefix(model, mm_type)
58 | # self.create_index(collection_name_model, mm_type, index_prefix)
59 | # except Exception as e:
60 | # raise ValueError(str(e))
61 | # return 'success'
62 |
63 | def flush(self):
64 | faiss.write_index(self._index, self._index_file_path)
65 |
66 | def close(self):
67 | self.flush()
68 |
69 | def rebuild_idx(self, model):
70 | pass
71 |
72 | def count(self):
73 | return self._index.ntotal
74 |
--------------------------------------------------------------------------------
/modelcache_mm/manager/vector_data/manager.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache_mm.utils.error import NotFoundError, ParamError
3 |
4 | TOP_K = 1
5 | FAISS_INDEX_PATH = "mm_faiss.index"
6 | DIMENSION = 0
7 | MILVUS_HOST = "localhost"
8 | MILVUS_PORT = 19530
9 | MILVUS_USER = ""
10 | MILVUS_PSW = ""
11 | MILVUS_SECURE = False
12 | MILVUS_INDEX_PARAMS = {
13 | "metric_type": "L2",
14 | "index_type": "HNSW",
15 | "params": {"M": 8, "efConstruction": 64},
16 | }
17 |
18 | COLLECTION_NAME = "modelcache"
19 |
20 |
21 | class VectorBase:
22 | """
23 | VectorBase to manage the vector stores.
24 | """
25 |
26 | def __init__(self):
27 | raise EnvironmentError(
28 | "VectorBase is designed to be instantiated, please using the `VectorBase.get(name)`."
29 | )
30 |
31 | @staticmethod
32 | def check_dimension(dimension):
33 | if dimension <= 0:
34 | raise ParamError(
35 | f"the dimension should be greater than zero, current value: {dimension}."
36 | )
37 |
38 | @staticmethod
39 | def get(name, **kwargs):
40 | top_k = kwargs.get("top_k", TOP_K)
41 | if name == "milvus":
42 | from modelcache.manager.vector_data.milvus import Milvus
43 | milvus_config = kwargs.get("milvus_config")
44 | dimension = kwargs.get("dimension", DIMENSION)
45 | VectorBase.check_dimension(dimension)
46 | host = milvus_config.get('milvus', 'host')
47 | port = milvus_config.get('milvus', 'port')
48 | user = milvus_config.get('milvus', 'user')
49 | password = milvus_config.get('milvus', 'password')
50 |
51 | secure = kwargs.get("secure", MILVUS_SECURE)
52 | collection_name = kwargs.get("collection_name", COLLECTION_NAME)
53 | index_params = kwargs.get("index_params", MILVUS_INDEX_PARAMS)
54 | search_params = kwargs.get("search_params", None)
55 | local_mode = kwargs.get("local_mode", False)
56 | local_data = kwargs.get("local_data", "./milvus_data")
57 | vector_base = Milvus(
58 | host=host,
59 | port=port,
60 | user=user,
61 | password=password,
62 | secure=secure,
63 | collection_name=collection_name,
64 | dimension=dimension,
65 | top_k=top_k,
66 | index_params=index_params,
67 | search_params=search_params,
68 | local_mode=local_mode,
69 | local_data=local_data
70 | )
71 | elif name == "redis":
72 | from modelcache_mm.manager.vector_data.redis import RedisVectorStore
73 | redis_config = kwargs.get("redis_config")
74 |
75 | mm_dimension = kwargs.get("mm_dimension", DIMENSION)
76 | i_dimension = kwargs.get("i_dimension", DIMENSION)
77 | t_dimension = kwargs.get("t_dimension", DIMENSION)
78 | VectorBase.check_dimension(mm_dimension)
79 | VectorBase.check_dimension(i_dimension)
80 | VectorBase.check_dimension(t_dimension)
81 |
82 | host = redis_config.get('redis', 'host')
83 | port = redis_config.get('redis', 'port')
84 | user = redis_config.get('redis', 'user')
85 | password = redis_config.get('redis', 'password')
86 | namespace = kwargs.get("namespace", "")
87 | # collection_name = kwargs.get("collection_name", COLLECTION_NAME)
88 |
89 | vector_base = RedisVectorStore(
90 | host=host,
91 | port=port,
92 | username=user,
93 | password=password,
94 | namespace=namespace,
95 | top_k=top_k,
96 | mm_dimension=mm_dimension,
97 | i_dimension=i_dimension,
98 | t_dimension=t_dimension,
99 | )
100 | elif name == "faiss":
101 | from modelcache_mm.manager.vector_data.faiss import Faiss
102 | dimension = kwargs.get("dimension", DIMENSION)
103 | VectorBase.check_dimension(dimension)
104 |
105 | index_path = kwargs.pop("index_path", FAISS_INDEX_PATH)
106 | vector_base = Faiss(
107 | index_file_path=index_path,
108 | dimension=dimension,
109 | top_k=top_k
110 | )
111 | elif name == "chromadb":
112 | from modelcache_mm.manager.vector_data.chroma import Chromadb
113 |
114 | chromadb_config = kwargs.get("chromadb_config", None)
115 | persist_directory = chromadb_config.get('chromadb', 'persist_directory')
116 | vector_base = Chromadb(
117 | persist_directory=persist_directory,
118 | top_k=top_k,
119 | )
120 | else:
121 | raise NotFoundError("vector store", name)
122 | return vector_base
123 |
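
A minimal sketch of driving the factory above with the FAISS backend; the 768-dimensional embeddings and the model name `demo_model` are placeholder assumptions, and `faiss-cpu` must be installed:

    import numpy as np
    from modelcache_mm.manager.vector_data import VectorBase
    from modelcache_mm.manager.vector_data.base import VectorData

    store = VectorBase("faiss", index_path="mm_faiss.index", dimension=768, top_k=1)
    store.add([VectorData(id=1, data=np.random.rand(768).astype("float32"))])
    hits = store.search(np.random.rand(768).astype("float32"), top_k=1, model="demo_model", mm_type="mm")
    # hits is a list of (distance, id) pairs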
--------------------------------------------------------------------------------
/modelcache_mm/manager/vector_data/redis.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import List
3 | import numpy as np
4 | from redis.commands.search.indexDefinition import IndexDefinition, IndexType
5 | from redis.commands.search.query import Query
6 | from redis.commands.search.field import VectorField, NumericField
7 | from redis.client import Redis
8 |
9 | from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
10 | from modelcache_mm.utils import import_redis
11 | from modelcache_mm.utils.log import modelcache_log
12 | from modelcache_mm.utils.index_util import get_mm_index_name
13 | from modelcache_mm.utils.index_util import get_mm_index_prefix
14 | import_redis()
15 |
16 |
17 | class RedisVectorStore(VectorBase):
18 | def __init__(
19 | self,
20 | host: str = "localhost",
21 | port: str = "6379",
22 | username: str = "",
23 | password: str = "",
24 | mm_dimension: int = 0,
25 | i_dimension: int = 0,
26 | t_dimension: int = 0,
27 | top_k: int = 1,
28 | namespace: str = "",
29 | ):
30 | if mm_dimension <= 0:
31 | raise ValueError(
32 | f"invalid `dim` param: {mm_dimension} in the Milvus vector store."
33 | )
34 | self._client = Redis(
35 | host=host, port=int(port), username=username, password=password
36 | )
37 | self.top_k = top_k
38 | self.mm_dimension = mm_dimension
39 | self.i_dimension = i_dimension
40 | self.t_dimension = t_dimension
41 | self.namespace = namespace
42 | self.doc_prefix = f"{self.namespace}doc:"
43 |
44 | def _check_index_exists(self, index_name: str) -> bool:
45 | """Check if Redis index exists."""
46 | try:
47 | self._client.ft(index_name).info()
48 | except Exception:
49 | modelcache_log.info("Index does not exist")
50 | return False
51 | modelcache_log.info("Index already exists")
52 | return True
53 |
54 | def create_index(self, index_name, type, index_prefix):
55 | # dimension = self.dimension
56 | if type == 'IMG_TEXT':
57 | dimension = self.mm_dimension
58 | elif type == 'IMG':
59 | dimension = self.i_dimension
60 | elif type == 'TEXT':
61 | dimension = self.t_dimension
62 | else:
63 | raise ValueError('dimension type exception')
64 | if self._check_index_exists(index_name):
65 | modelcache_log.info(
66 | "The %s already exists, and it will be used directly", index_name
67 | )
68 | return 'already_exists'
69 | else:
70 | id_field_name = "data_id"
71 | embedding_field_name = "data_vector"
72 |
73 | id = NumericField(name=id_field_name)
74 | embedding = VectorField(embedding_field_name,
75 | "HNSW", {
76 | "TYPE": "FLOAT32",
77 | "DIM": dimension,
78 | "DISTANCE_METRIC": "L2",
79 | "INITIAL_CAP": 1000,
80 | }
81 | )
82 | fields = [id, embedding]
83 | definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)
84 |
85 | # create Index
86 | self._client.ft(index_name).create_index(
87 | fields=fields, definition=definition
88 | )
89 | return 'create_success'
90 |
91 | def add(self, datas: List[VectorData], model=None, mm_type=None):
92 | for data in datas:
93 | id: int = data.id
94 | embedding = data.data.astype(np.float32).tobytes()
95 | index_prefix = get_mm_index_prefix(model, mm_type)
96 | id_field_name = "data_id"
97 | embedding_field_name = "data_vector"
98 | obj = {id_field_name: id, embedding_field_name: embedding}
99 | self._client.hset(f"{index_prefix}{id}", mapping=obj)
100 |
101 | def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type=None):
102 | index_name = get_mm_index_name(model, mm_type)
103 | id_field_name = "data_id"
104 | embedding_field_name = "data_vector"
105 |
106 | base_query = f'*=>[KNN {top_k if top_k > 0 else self.top_k} @{embedding_field_name} $vector AS distance]'
107 | query = (
108 | Query(base_query)
109 | .sort_by("distance")
110 | .return_fields(id_field_name, "distance")
111 | .dialect(2)
112 | )
113 | query_params = {"vector": data.astype(np.float32).tobytes()}
114 | results = (
115 | self._client.ft(index_name)
116 | .search(query, query_params=query_params)
117 | .docs
118 | )
119 | return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
120 |
121 | def create(self, model=None, mm_type=None):
122 | collection_name_model = get_mm_index_name(model, mm_type)
123 | try:
124 | index_prefix = get_mm_index_prefix(model, mm_type)
125 | self.create_index(collection_name_model, mm_type, index_prefix)
126 | except Exception as e:
127 | raise ValueError(str(e))
128 | return 'success'
129 |
130 | def rebuild(self, ids=None) -> bool:
131 | pass
132 |
133 | def rebuild_idx(self, model, mm_type=None):
134 | for mm_type in ['IMG_TEXT', 'TEXT']:
135 | index_name = get_mm_index_name(model, mm_type)
136 | if self._check_index_exists(index_name):
137 | try:
138 | self._client.ft(index_name).dropindex(delete_documents=True)
139 | except Exception as e:
140 | raise ValueError(str(e))
141 | try:
142 | index_prefix = get_mm_index_prefix(model, mm_type)
143 | self.create_index(index_name, mm_type, index_prefix)
144 | except Exception as e:
145 | raise ValueError(str(e))
146 |
147 | def delete(self, ids) -> None:
148 | pipe = self._client.pipeline()
149 | for data_id in ids:
150 | pipe.delete(f"{self.doc_prefix}{data_id}")
151 | pipe.execute()
152 |
153 | def get_index_by_name(self, index_name):
154 | pass
155 |
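
A minimal sketch of exercising the Redis store above, assuming a local Redis Stack instance with the RediSearch module; the 8-dimensional vectors and the model name `demo_model` are placeholders:

    import numpy as np
    from modelcache_mm.manager.vector_data.base import VectorData
    from modelcache_mm.manager.vector_data.redis import RedisVectorStore

    store = RedisVectorStore(host="localhost", port="6379",
                             mm_dimension=8, i_dimension=8, t_dimension=8, top_k=1)
    store.create(model="demo_model", mm_type="IMG_TEXT")   # creates index multi_cache_demo_model_mm
    store.add([VectorData(id=1, data=np.random.rand(8).astype(np.float32))],
              model="demo_model", mm_type="IMG_TEXT")
    hits = store.search(np.random.rand(8).astype(np.float32),
                        top_k=1, model="demo_model", mm_type="IMG_TEXT")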
--------------------------------------------------------------------------------
/modelcache_mm/processor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modelcache_mm/processor/post.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import random
3 | from typing import List, Any
4 |
5 |
6 | def random_one(messages: List[Any]) -> Any:
7 | return random.choice(messages)
8 |
9 |
10 | def first(messages: List[Any]) -> Any:
11 | return messages[0]
12 |
13 |
14 | def nop(messages: List[Any]) -> Any:
15 | return messages
16 |
--------------------------------------------------------------------------------
/modelcache_mm/processor/pre.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import re
3 | from typing import Dict, Any
4 |
5 |
6 | def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
7 | return data.get("chat_info")[-1]["query"]
8 |
9 |
10 | def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
11 | return data.get("query")[-1]["content"]
12 |
13 |
14 | def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) -> Any:
15 | last_content_str = data.get("messages")[-1]["content"]
16 | prompts = params.get("prompts", [])
17 | if prompts is None:
18 | return last_content_str
19 | pattern = "|".join(prompts)
20 | new_content_str = re.sub(pattern, "", last_content_str)
21 | return new_content_str
22 |
23 |
24 | def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
25 | s = ""
26 | messages = data.get("messages")
27 | for i, message in enumerate(messages):
28 | if i == len(messages) - 1:
29 | s += message["content"]
30 | else:
31 | s += message["content"] + "\n"
32 | return s
33 |
34 |
35 | def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
36 | return data
37 |
38 |
39 | def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
40 | return data.get("prompt")
41 |
42 |
43 | def get_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
44 | return data.get("file").name
45 |
46 |
47 | def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes:
48 | return data.get("file").peek()
49 |
50 |
51 | def get_input_str(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
52 | input_data = data.get("input")
53 | return str(input_data["image"].peek()) + input_data["question"]
54 |
55 |
56 | def get_input_image_file_name(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
57 | input_data = data.get("input")
58 | return input_data["image"].name
59 |
60 |
61 | def query_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
62 | query_list = data.get("query")
63 | return multi_splicing(query_list)
64 |
65 |
66 | def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
67 | insert_query_list = data.get("chat_info")[-1]['query']
68 | return multi_splicing(insert_query_list)
69 |
70 |
71 | def multi_splicing(data_list) -> Any:
72 | result_str = ""
73 | for d in data_list:
74 | role = d.get('role', '')
75 | content = d.get('content', '')
76 | result_str += role + "###" + content + "|||"
77 |
78 | # strip the trailing "|||"
79 | result_str = result_str[:-3]
80 |
81 | return result_str
82 |
83 |
84 | def multi_analysis(dialog_str):
85 | sub_strings = dialog_str.split('|||')
86 | dict_list = []
87 | for s in sub_strings:
88 | parts = s.split('###')
89 | if len(parts) == 2:
90 | role = parts[0]
91 | content = parts[1]
92 | elif len(parts) > 2:
93 | role = parts[0]
94 | content = '###'.join(parts[1:])
95 | else:
96 | role, content = '', 'exception'
97 | if content == '':
98 | d = {"role": role}
99 | else:
100 | d = {"role": role, "content": content}
101 | dict_list.append(d)
102 | result_list = dict_list
103 |
104 | # return the parsed result
105 | return result_list
106 |
107 |
108 | def mm_insert_dict(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
109 | query_dict = data.get("chat_info")[-1]['query']
110 | return query_dict
111 |
112 |
113 | def mm_query_dict(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
114 | query_dict = data.get("query")
115 | return query_dict
116 |
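
A round-trip sketch of the `role###content|||` splicing used by the helpers above; the dialog content is an arbitrary example:

    from modelcache_mm.processor.pre import multi_splicing, multi_analysis

    dialog = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi, how can I help?"},
    ]
    spliced = multi_splicing(dialog)    # "user###hello|||assistant###hi, how can I help?"
    restored = multi_analysis(spliced)  # back to a list of {"role", "content"} dicts
    assert restored == dialog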
--------------------------------------------------------------------------------
/modelcache_mm/report.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | class Report:
3 | def __init__(self):
4 | self.embedding_all_time = 0
5 | self.embedding_count = 0
6 | self.search_all_time = 0
7 | self.search_count = 0
8 | self.hint_cache_count = 0
9 |
10 | def embedding(self, delta_time):
11 | """Embedding counts and time.
12 |
13 | :param delta_time: additional runtime.
14 | """
15 | self.embedding_all_time += delta_time
16 | self.embedding_count += 1
17 |
18 | def search(self, delta_time):
19 | """Search counts and time.
20 |
21 | :param delta_time: additional runtime.
22 | """
23 | self.search_all_time += delta_time
24 | self.search_count += 1
25 |
26 | def average_embedding_time(self):
27 | """Average embedding time."""
28 | return round(
29 | self.embedding_all_time / self.embedding_count
30 | if self.embedding_count != 0
31 | else 0,
32 | 4,
33 | )
34 |
35 | def average_search_time(self):
36 | return round(
37 | self.search_all_time / self.search_count
38 | if self.search_count != 0
39 | else 0,
40 | 4,
41 | )
42 |
43 | def hint_cache(self):
44 | self.hint_cache_count += 1
45 |
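
A small sketch of how these counters are typically driven; the timings are illustrative values, not measurements:

    from modelcache_mm.report import Report

    report = Report()
    report.embedding(0.30)   # one embedding call that took 0.30s
    report.search(0.05)      # one vector search that took 0.05s
    report.hint_cache()      # one cache hit
    print(report.average_embedding_time(), report.average_search_time())  # 0.3 0.05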
--------------------------------------------------------------------------------
/modelcache_mm/similarity_evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation
3 | from modelcache.utils.lazy_import import LazyImport
4 |
5 | exact_match = LazyImport(
6 | "exact_match", globals(), "modelcache.similarity_evaluation.exact_match"
7 | )
8 |
9 |
10 | def ExactMatchEvaluation():
11 | return exact_match.ExactMatchEvaluation()
12 |
--------------------------------------------------------------------------------
/modelcache_mm/similarity_evaluation/distance.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Tuple, Dict, Any
3 | from modelcache.similarity_evaluation import SimilarityEvaluation
4 |
5 |
6 | class SearchDistanceEvaluation(SimilarityEvaluation):
7 | def __init__(self, max_distance=4.0, positive=False):
8 | self.max_distance = max_distance
9 | self.positive = positive
10 |
11 | def evaluation(
12 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_
13 | ) -> float:
14 | distance, _ = cache_dict["search_result"]
15 | if distance < 0:
16 | distance = 0
17 | elif distance > self.max_distance:
18 | distance = self.max_distance
19 | if self.positive:
20 | return distance
21 | return self.max_distance - distance
22 |
23 | def range(self) -> Tuple[float, float]:
24 | return 0.0, self.max_distance
25 |
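
A minimal sketch of turning a raw L2 search distance into the score this evaluator produces; the distance value 1.5 and candidate id 42 are made-up inputs:

    from modelcache_mm.similarity_evaluation.distance import SearchDistanceEvaluation

    evaluator = SearchDistanceEvaluation(max_distance=4.0)
    score = evaluator.evaluation({}, {"search_result": (1.5, 42)})  # 4.0 - 1.5 = 2.5
    low, high = evaluator.range()                                   # (0.0, 4.0)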
--------------------------------------------------------------------------------
/modelcache_mm/similarity_evaluation/exact_match.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Tuple, Dict, Any
3 | from modelcache.similarity_evaluation.similarity_evaluation import SimilarityEvaluation
4 |
5 |
6 | class ExactMatchEvaluation(SimilarityEvaluation):
7 | def __init__(self):
8 | pass
9 |
10 | def evaluation(
11 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **_
12 | ) -> float:
13 | return 1 if cache_dict["question"] == src_dict["question"] else 0
14 |
15 | def range(self) -> Tuple[float, float]:
16 | return 0, 1
17 |
--------------------------------------------------------------------------------
/modelcache_mm/similarity_evaluation/similarity_evaluation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Tuple, Dict, Any
4 |
5 |
6 | class SimilarityEvaluation(metaclass=ABCMeta):
7 | @abstractmethod
8 | def evaluation(
9 | self, src_dict: Dict[str, Any], cache_dict: Dict[str, Any], **kwargs
10 | ) -> float:
11 | pass
12 |
13 | @abstractmethod
14 | def range(self) -> Tuple[float, float]:
15 | pass
16 |
--------------------------------------------------------------------------------
/modelcache_mm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import importlib.util
3 | from typing import Optional
4 | from modelcache.utils.dependency_control import prompt_install
5 |
6 |
7 | def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None):
8 | is_avail = False
9 | if importlib.util.find_spec(libname):
10 | is_avail = True
11 | if not is_avail and prompt:
12 | prompt_install(package if package else libname)
13 | return is_avail
14 |
15 |
16 | def import_onnxruntime():
17 | _check_library("onnxruntime")
18 |
19 |
20 | def import_huggingface():
21 | _check_library("transformers")
22 |
23 |
24 | def import_huggingface_hub():
25 | _check_library("huggingface_hub", package="huggingface-hub")
26 |
27 |
28 | def import_pymysql():
29 | _check_library("pymysql")
30 |
31 |
32 | def import_sql_client(db_name):
33 | if db_name in ["mysql"]:
34 | import_pymysql()
35 |
36 |
37 | def import_pymilvus():
38 | _check_library("pymilvus")
39 |
40 |
41 | def import_milvus_lite():
42 | _check_library("milvus")
43 |
44 |
45 | def import_faiss():
46 | _check_library("faiss", package="faiss-cpu")
47 |
48 |
49 | def import_torch():
50 | _check_library("torch")
51 |
52 |
53 | def import_fasttext():
54 | _check_library("fasttext")
55 |
56 |
57 | def import_paddle():
58 | prompt_install("protobuf==3.20.0")
59 | _check_library("paddlepaddle")
60 |
61 |
62 | def import_paddlenlp():
63 | _check_library("paddlenlp")
64 |
65 |
66 | def import_timm():
67 | _check_library("timm", package="timm")
68 |
69 |
70 | def import_pillow():
71 | _check_library("PIL", package="pillow")
72 |
73 |
74 | def import_redis():
75 | _check_library("redis")
76 |
77 |
78 | def import_chromadb():
79 | _check_library("chromadb", package="chromadb")
--------------------------------------------------------------------------------
/modelcache_mm/utils/cache_func.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | def cache_all(*_, **__):
3 | return True
--------------------------------------------------------------------------------
/modelcache_mm/utils/dependency_control.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import subprocess
3 | from modelcache.utils.error import PipInstallError
4 | from modelcache.utils.log import modelcache_log
5 |
6 |
7 | def prompt_install(package: str, warn: bool = False): # pragma: no cover
8 | """
9 | Function used to prompt user to install a package.
10 | """
11 | cmd = f"pip install {package}"
12 | try:
13 | if warn and input(f"Install {package}? Y/n: ") != "Y":
14 | raise ModuleNotFoundError(f"No module named {package}")
15 | subprocess.check_call(cmd, shell=True)
16 | modelcache_log.info("%s installed successfully!", package)
17 | except subprocess.CalledProcessError as e:
18 | raise PipInstallError(package) from e
19 |
--------------------------------------------------------------------------------
/modelcache_mm/utils/env_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/modelcache_mm/utils/error.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | class CacheError(Exception):
3 | """ModelCache base error"""
4 |
5 |
6 | class NotInitError(CacheError):
7 | """Raise when the cache has been used before it's inited"""
8 | def __init__(self):
9 | super().__init__("The cache should be inited before using")
10 |
11 |
12 | class RemoveError(CacheError):
13 | """Raise when the cache has been used before it's inited"""
14 | def __init__(self):
15 | super().__init__("The cache remove error")
16 |
17 | class NotFoundError(CacheError):
18 | """Raise when getting an unsupported store."""
19 | def __init__(self, store_type, current_type_name):
20 | super().__init__(f"Unsupported {store_type}: {current_type_name}")
21 |
22 |
23 | class ParamError(CacheError):
24 | """Raise when receiving an invalid param."""
25 |
26 |
27 | class PipInstallError(CacheError):
28 | """Raise when failed to install package."""
29 | def __init__(self, package):
30 | super().__init__(f"Ran into error installing {package}.")
31 |
32 |
33 | class MultiTypeError(CacheError):
34 | def __init__(self):
35 | super().__init__("multichat type error, please check")
36 |
--------------------------------------------------------------------------------
/modelcache_mm/utils/index_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | def get_index_name(model):
5 | return 'multicache' + '_' + model
6 |
7 |
8 | def get_index_prefix(model):
9 | return 'prefix' + '_' + model
10 |
11 |
12 | def get_mm_index_name(model, mm_type):
13 | if mm_type not in ['IMG_TEXT', 'mm', 'IMG', 'image', 'TEXT', 'text']:
14 | raise ValueError('unsupported mm_type: {}'.format(mm_type))
15 | if mm_type == 'IMG_TEXT':
16 | mm_type = 'mm'
17 | elif mm_type == 'IMG':
18 | mm_type = 'image'
19 | elif mm_type == 'TEXT':
20 | mm_type = 'text'
21 | return 'multi_cache' + '_' + model + '_' + mm_type
22 |
23 |
24 | def get_mm_index_prefix(model, mm_type):
25 | if mm_type not in ['IMG_TEXT', 'mm', 'IMG', 'image', 'TEXT', 'text']:
26 | raise ValueError('unsupported mm_type: {}'.format(mm_type))
27 | if mm_type == 'IMG_TEXT':
28 | mm_type = 'mm'
29 | elif mm_type == 'IMG':
30 | mm_type = 'image'
31 | elif mm_type == 'TEXT':
32 | mm_type = 'text'
33 | return 'prefix' + '_' + model + '_' + mm_type
34 |
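
A quick sketch of the naming convention these helpers encode; `clip_demo` is an arbitrary model name:

    from modelcache_mm.utils.index_util import get_mm_index_name, get_mm_index_prefix

    print(get_mm_index_name("clip_demo", "IMG_TEXT"))   # multi_cache_clip_demo_mm
    print(get_mm_index_prefix("clip_demo", "TEXT"))     # prefix_clip_demo_text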
--------------------------------------------------------------------------------
/modelcache_mm/utils/lazy_import.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import importlib
3 | from types import ModuleType
4 |
5 |
6 | class LazyImport(ModuleType):
7 | """
8 | Lazily import a module.
9 | """
10 | def __init__(self, local_name, parent_module_globals, name):
11 | self._local_name = local_name
12 | self._parent_module_globals = parent_module_globals
13 | super().__init__(name)
14 |
15 | def _load(self):
16 | module = importlib.import_module(self.__name__)
17 | self._parent_module_globals[self._local_name] = module
18 | self.__dict__.update(module.__dict__)
19 | return module
20 |
21 | def __getattr__(self, item):
22 | module = self._load()
23 | return getattr(module, item)
24 |
25 | def __dir__(self):
26 | module = self._load()
27 | return dir(module)
28 |
--------------------------------------------------------------------------------
/modelcache_mm/utils/log.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | FORMAT = '%(asctime)s - %(thread)d - %(filename)s-%(module)s:%(lineno)s - %(levelname)s: %(message)s'
5 | logging.basicConfig(format=FORMAT)
6 |
7 | modelcache_log = logging.getLogger('modelcache')
8 |
--------------------------------------------------------------------------------
/modelcache_mm/utils/time.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | from modelcache import cache
4 |
5 |
6 | def time_cal(func, func_name=None, report_func=None):
7 | def inner(*args, **kwargs):
8 | time_start = time.time()
9 | res = func(*args, **kwargs)
10 | delta_time = time.time() - time_start
11 | if cache.config.log_time_func:
12 | cache.config.log_time_func(
13 | func.__name__ if func_name is None else func_name, delta_time
14 | )
15 | if report_func is not None:
16 | report_func(delta_time)
17 | return res
18 |
19 | return inner
20 |
21 |
--------------------------------------------------------------------------------
/mulicache-readme-cn.md:
--------------------------------------------------------------------------------
1 | # MultiModal Cache
2 |
3 | To meet the performance requirements of multimodal workloads, we built the MultiModal Cache system on top of LLModel Cache. MultiModal Cache extends the ModelCache feature set, optimizes the architecture, and adapts to a wide range of application scenarios.
4 |
5 | - [MultiModal Cache](#multimodal-cache)
6 |   - [What's new](#whats-new)
7 |   - [Features](#features)
8 |   - [Performance](#performance)
9 |   - [Effect evaluation](#effect-evaluation)
10 |   - [Contributing](#contributing)
11 |
12 | ## What's new
13 |
14 | - [2024.12.12] The MultiModal Cache system was officially released.
15 |
16 | ## Features
17 |
18 | | Scenario | Data type | Image format | Data isolation |
19 | |------|----------|----------|----------|
20 | | Text conversation | Text | N/A | Supported |
21 | | Image-text understanding | Text + image | image_url / image_base64 | Supported |
22 |
23 | - **Compatibility**: supports text, image URL (image_url), and image Base64 data formats, as well as their combinations.
24 | - **Data isolation**: supports data isolation across models, allowing different models to run independently within the same system.
25 | - **Modality isolation**: supports isolated handling of different modalities (such as text and images) under the same model.
26 |
27 | ## Performance
28 |
29 | We ran a comprehensive performance evaluation of MultiModal Cache in a production environment with enterprise-grade databases. The detailed numbers are as follows:
30 |
31 | | Request type | Cache hit | Total latency | Component | Component latency |
32 | |------|------|------|------|------|
33 | | Text | Hit | 420ms-520ms | Multi-Encoder (text) | ~300ms |
34 | | | | | Vector store retrieval | 40-50ms |
35 | | | | | Relational store retrieval | 60-70ms |
36 | | Text | Not hit | 300ms + N(s) | Multi-Encoder (text) | ~300ms |
37 | | | | | Vector store retrieval | 40-50ms |
38 | | | | | LLM call | N(s) |
39 | | IMG_TEXT | Hit | 600ms-800ms | Multi-Encoder (image + text) | ~600ms |
40 | | | | | Vector store retrieval | 40-50ms |
41 | | | | | Relational store retrieval | 60-70ms |
42 | | IMG_TEXT | Not hit | 600ms + N(s) | Multi-Encoder (image + text) | ~600ms |
43 | | | | | Vector store retrieval | 40-50ms |
44 | | | | | LLM call | N(s) |
45 |
46 | Based on the current evaluation results, there is still considerable room to optimize embedding inference time.
47 | **Note**: using an embedded database may further improve performance.
48 |
49 | ## Effect evaluation
50 |
51 | To fully assess the impact of the cache on the model service, we ran end-to-end performance tests comparing two service configurations, with and without the cache, using an automated test suite of 5,000 test cases.
52 |
53 | - Pre-release model service with cache: we observed its response time; the cache is expected to significantly improve service performance and reduce latency.
54 | - Production model service without cache: we collected its raw performance metrics and outputs as the comparison baseline.
55 |
56 | To make sure the cache does not affect accuracy or consistency, we compared the results returned by the two services and verified whether the caching mechanism changes the responses that end users receive.
57 |
58 | Compared with calling the model directly, the latency of the cache service shows a stable distribution and does not degrade as the model parameter size grows. Normally, latency rises with model size because larger models require more compute; by serving frequently accessed data from storage instead of recomputing it, the cache service largely decouples latency from model complexity.
59 |
60 | 
61 |
62 | We also compared the latency of cache hits against actual model calls. The experiments show that, with the cache service integrated, cache hits deliver a performance improvement of more than 40% on the llama7B model. As models continue to iterate and improve, we expect this gain to grow further.
63 |
64 | 
65 |
66 | ## Contributing
67 |
68 | MultiModal Cache is an open-source project with great potential, and we welcome contributions of all kinds:
69 |
70 | - Submitting issues and suggestions
71 | - Contributing code
72 | - Improving documentation and examples
73 |
74 | Whether you are an experienced developer or a newcomer, your participation will make this project better and contribute to the open-source community.
75 |
--------------------------------------------------------------------------------
/reference_doc/create_table.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE `modelcache_llm_answer` (
2 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment 'primary key',
3 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment 'creation time',
4 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment 'modification time',
5 | `question` text NOT NULL comment 'question',
6 | `answer` text NOT NULL comment 'answer',
7 | `answer_type` int(11) NOT NULL comment 'answer_type',
8 | `hit_count` int(11) NOT NULL DEFAULT '0' comment 'hit_count',
9 | `model` varchar(1000) NOT NULL comment 'model',
10 | `embedding_data` blob NOT NULL comment 'embedding_data',
11 | `is_deleted` tinyint(1) NOT NULL DEFAULT '0' COMMENT 'delete state(0 Not deleted,-1 deleted)',
12 | PRIMARY KEY(`id`)
13 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'cache_codegpt_answer';
14 |
15 |
16 | CREATE TABLE `modelcache_query_log` (
17 | `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment 'primary key',
18 | `gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment 'creation time',
19 | `gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment 'modification time',
20 | `error_code` int(11) NOT NULL comment 'errorCode',
21 | `error_desc` varchar(1000) NOT NULL comment 'errorDesc',
22 | `cache_hit` varchar(100) NOT NULL comment 'cacheHit',
23 | `delta_time` float NOT NULL comment 'delta_time',
24 | `model` varchar(1000) NOT NULL comment 'model',
25 | `query` text NOT NULL comment 'query',
26 | `hit_query` text NOT NULL comment 'hitQuery',
27 | `answer` text NOT NULL comment 'answer',
28 | PRIMARY KEY(`id`)
29 | ) AUTO_INCREMENT = 1 DEFAULT CHARSET = utf8mb4 COMMENT = 'modelcache_query_log';
30 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cachetools==5.3.1
2 | DBUtils==1.4
3 | Flask==3.0.0
4 | numpy==1.24.4
5 | onnxruntime==1.16.1
6 | openai==0.28.1
7 | pymilvus==2.3.1
8 | PyMySQL==1.1.0
9 | Requests==2.31.0
10 | torch==2.1.1
11 | transformers==4.38.2
12 | faiss-cpu==1.7.4
13 | redis==5.0.1
14 | modelscope==1.14.0
15 | fastapi==0.115.5
16 | uvicorn==0.32.0
17 | chromadb==0.5.23
18 | elasticsearch==7.10.0
19 | snowflake-id==1.0.2
20 |
--------------------------------------------------------------------------------