├── LICENSE.txt ├── README.md ├── client.py ├── requirements.txt ├── server.conf └── server.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Adam Swanda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vector-embedding-api 2 | `vector-embedding-api` provides a Flask API server and client to generate text embeddings using either [OpenAI's embedding model](https://platform.openai.com/docs/guides/embeddings) or the [SentenceTransformers](https://www.sbert.net/) library. The API server supports in-memory LRU caching for faster retrievals, batch processing for handling multiple texts at once, and a health status endpoint for monitoring the server status. 
3 | 4 | SentenceTransformers supports over 500 models via [HuggingFace Hub](https://huggingface.co/sentence-transformers). 5 | 6 | ## Features 🎯 7 | * POST endpoint to create text embeddings 8 | * sentence_transformers 9 | * OpenAI text-embedding-ada-002 10 | * In-memory LRU cache for quick retrieval of embeddings 11 | * Batch processing to handle multiple texts in a single request 12 | * Easy setup with configuration file 13 | * Health status endpoint 14 | * Python client utility for submitting text or files 15 | 16 | ### Installation 💻 17 | To run this server locally, follow the steps below: 18 | 19 | **Clone the repository:** 📦 20 | ```bash 21 | git clone https://github.com/deadbits/vector-embedding-api.git 22 | cd vector-embedding-api 23 | ``` 24 | 25 | **Set up a virtual environment (optional but recommended):** 🐍 26 | ```bash 27 | virtualenv -p /usr/bin/python3.10 venv 28 | source venv/bin/activate 29 | ``` 30 | 31 | **Install the required dependencies:** 🛠️ 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ### Usage 37 | 38 | **Modify the [server.conf](/server.conf) configuration file:** ⚙️ 39 | ```ini 40 | [main] 41 | openai_api_key = YOUR_OPENAI_API_KEY 42 | sent_transformers_model = sentence-transformers/all-MiniLM-L6-v2 43 | use_cache = true/false 44 | cache_max = 500 45 | ``` 46 | **Start the server:** 🚀 47 | ```bash 48 | python server.py -c server.conf 49 | ``` 50 | 51 | The server should now be running on http://127.0.0.1:5000/. 52 | 53 | ### API Endpoints 🌐 54 | #### Client Usage 55 | A small [Python client](/client.py) is provided to assist with submitting text strings or files. 56 | 57 | **Usage** 58 | `python3 client.py -t "Your text here" -m local` 59 | 60 | `python3 client.py -f /path/to/yourfile.txt -m openai` 61 | 62 | #### POST /submit 63 | Submits an individual text string or a list of text strings for embedding generation. 64 | 65 | **Request Parameters** 66 | 67 | * **text:** The text string or list of text strings to generate embeddings for. 
(Required) 68 | * **model:** Type of model to be used: `local` for SentenceTransformer models or `openai` for OpenAI's model. Default is `local`. 69 | 70 | **Response** 71 | 72 | * **embeddings:** A list of generated embedding arrays, one per input text. 73 | * **status:** Status of the request, either success or error. 74 | * **elapsed:** The elapsed time taken for generating the embeddings (in milliseconds). 75 | * **model:** The model used to generate the embeddings. 76 | * **cache:** Boolean indicating if the result was retrieved from cache. (Optional) 77 | * **message:** Error message if the status is error. (Optional) 78 | 79 | #### GET /health 80 | Checks the server's health status. 81 | 82 | **Response** 83 | 84 | * **cache.enabled:** Boolean indicating whether the cache is enabled 85 | * **cache.max_size:** Maximum cache size 86 | * **cache.size:** Current cache size 87 | * **models.openai:** Boolean indicating if OpenAI embeddings are enabled. (Optional) 88 | * **models.sentence-transformers:** Name of the sentence-transformers model in use. 
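The fields above can be consumed programmatically by a monitoring script. As a minimal sketch (the `health` payload below is hard-coded sample data in the documented shape, and `summarize_health` is a hypothetical helper, not part of this repo):

```python
# Summarize a /health response payload (sample data, not a live request).
health = {
    "cache": {"enabled": True, "max_size": 500, "size": 0},
    "models": {
        "openai": True,
        "sentence-transformers": "sentence-transformers/all-MiniLM-L6-v2",
    },
}

def summarize_health(payload: dict) -> str:
    cache = payload["cache"]
    models = payload["models"]
    # cache fields are documented as size/max_size/enabled
    cache_part = f"cache: {cache['size']}/{cache['max_size']}" if cache["enabled"] else "cache: off"
    openai_part = f"openai: {'on' if models.get('openai') else 'off'}"
    sbert_part = f"sbert: {models.get('sentence-transformers', 'disabled')}"
    return ", ".join([cache_part, openai_part, sbert_part])

print(summarize_health(health))
# cache: 0/500, openai: on, sbert: sentence-transformers/all-MiniLM-L6-v2
```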
89 | 90 | ```json 91 | { 92 | "cache": { 93 | "enabled": true, 94 | "max_size": 500, 95 | "size": 0 96 | }, 97 | "models": { 98 | "openai": true, 99 | "sentence-transformers": "sentence-transformers/all-MiniLM-L6-v2" 100 | } 101 | } 102 | ``` 103 | 104 | #### Example Usage 105 | Send a POST request to the /submit endpoint with a JSON payload: 106 | 107 | ```json 108 | { 109 | "text": "Your text here", 110 | "model": "local" 111 | } 112 | 113 | // multi text submission 114 | { 115 | "text": ["Text1 goes here", "Text2 goes here"], 116 | "model": "openai" 117 | } 118 | ``` 119 | 120 | You'll receive a response containing the embeddings and additional information: 121 | 122 | ```json 123 | [ 124 | { 125 | "embeddings": [[...]], 126 | "status": "success", 127 | "elapsed": 123, 128 | "model": "sentence-transformers/all-MiniLM-L6-v2" 129 | } 130 | ] 131 | 132 | // multi text response: one result object with one embedding per input 133 | [ 134 | { 135 | "embeddings": [[...], [...]], 136 | "status": "success", 137 | "elapsed": 123, 138 | "model": "text-embedding-ada-002" 139 | } 140 | ] 141 | ``` 142 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # github.com/deadbits/vector-embedding-api 3 | # client.py 4 | import os 5 | import sys 6 | import json 7 | import argparse 8 | import requests 9 | from pydantic import BaseModel 10 | from typing import List, Optional 11 | from datetime import datetime 12 | 13 | 14 | def timestamp_str(): 15 | return datetime.isoformat(datetime.utcnow()) 16 | 17 | 18 | class Embedding(BaseModel): 19 | text: str = '' 20 | embedding: List[float] = [] 21 | metadata: Optional[dict] = {} 22 | 23 | 24 | def send_request(text_batch, model_type='local'): 25 | url = 'http://127.0.0.1:5000/submit' 26 | headers = {'Content-Type': 'application/json'} 27 | payload = { 28 | 'text': text_batch, 29 | 
'model': model_type 30 | } 31 | 32 | try: 33 | response = requests.post( 34 | url, 35 | headers=headers, 36 | data=json.dumps(payload) 37 | ) 38 | 39 | response.raise_for_status() 40 | return response.json() 41 | except requests.RequestException as err: 42 | print(f'[error] exception sending http request: {err}') 43 | return None 44 | 45 | 46 | def process_batch(text_batch, model_type, embeddings_list, chunk_num, total_chunks): 47 | print(f'[status] {timestamp_str()} - Processing chunk {chunk_num} of {total_chunks}') 48 | result = send_request(text_batch, model_type) 49 | if result: 50 | if result[0]['status'] == 'error': 51 | print(f'[error] {timestamp_str()} - Received error: {result[0]["message"]}') 52 | return 53 | else: 54 | print(f'[status] {timestamp_str()} - Received embeddings: {len(result[0]["embeddings"])} ') 55 | for text, em in zip(text_batch, result[0]['embeddings']): 56 | metadata = { 57 | 'status': result[0]['status'], 58 | 'elapsed': result[0]['elapsed'], 59 | 'model': result[0]['model'] 60 | } 61 | embedding = Embedding(text=text, embedding=em, metadata=metadata) 62 | embeddings_list.append(embedding.dict()) 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | 68 | group = parser.add_mutually_exclusive_group(required=True) 69 | group.add_argument( 70 | '-t', '--text', 71 | help='text to embed' 72 | ) 73 | 74 | group.add_argument( 75 | '-f', '--file', 76 | type=argparse.FileType('r'), 77 | help='text file to embed (one text per line)' 78 | ) 79 | 80 | parser.add_argument( 81 | '-m', '--model', 82 | help='embedding model type', 83 | choices=['local', 'openai'], 84 | default='local' 85 | ) 86 | 87 | parser.add_argument( 88 | '-o', '--output', 89 | help='output file', 90 | default='embeddings.json' 91 | ) 92 | 93 | args = parser.parse_args() 94 | model_type = args.model 95 | output_file = args.output 96 | embeddings_list = [] 97 | 98 | if os.path.exists(output_file): 99 | print(f'[error] {timestamp_str()} - Output 
file already exists') 100 | sys.exit(1) 101 | 102 | if args.file: 103 | if not os.path.exists(args.file.name): 104 | print(f'[error] {timestamp_str()} - File does not exist') 105 | sys.exit(1) 106 | 107 | print(f'[status] {timestamp_str()} - Processing file: {args.file.name}') 108 | 109 | text_batch = [] 110 | chunk_size = 100 111 | total_lines = sum(1 for _ in args.file) 112 | args.file.seek(0) 113 | total_chunks = (total_lines + chunk_size - 1) // chunk_size 114 | 115 | print(f'[info] {timestamp_str()} - Total chunks: {total_chunks}') 116 | 117 | chunk_num = 1 118 | 119 | for line in args.file: 120 | text = line.strip() 121 | if not text: 122 | continue 123 | text_batch.append(text) 124 | if len(text_batch) == chunk_size: 125 | process_batch(text_batch, model_type, embeddings_list, chunk_num, total_chunks) 126 | text_batch = [] 127 | chunk_num += 1 128 | 129 | if text_batch: 130 | process_batch(text_batch, model_type, embeddings_list, chunk_num, total_chunks) 131 | else: 132 | print(f'[status] {timestamp_str()} - Processing text input') 133 | text = args.text 134 | result = send_request([text], model_type) 135 | if result: 136 | for res in result: 137 | metadata = {'status': res['status'], 'elapsed': res['elapsed'], 'model': res['model']} 138 | embedding = Embedding(text=text, embedding=res['embeddings'][0], metadata=metadata)  # server returns a list of embeddings 139 | embeddings_list.append(embedding.dict()) 140 | 141 | try: 142 | print(f'[status] {timestamp_str()} - Saving embeddings to {output_file}') 143 | with open(output_file, 'w') as f: 144 | json.dump(embeddings_list, f) 145 | 146 | print(f'[status] {timestamp_str()} - Embeddings saved to {output_file}') 147 | except Exception as err: 148 | print(f'[error] {timestamp_str()} - exception saving embeddings: {err}') 149 | sys.exit(1) 150 | 151 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==2.3.2 2 | openai==0.27.8 3 | 
sentence-transformers==2.2.2 4 | requests==2.31.0 5 | pydantic==1.10.13 6 | -------------------------------------------------------------------------------- /server.conf: -------------------------------------------------------------------------------- 1 | [main] 2 | openai_api_key = 3 | sent_transformers_model = sentence-transformers/all-MiniLM-L6-v2 4 | use_cache = true 5 | cache_max = 500 6 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # github.com/deadbits/vector-embedding-api 3 | # server.py 4 | import os 5 | import sys 6 | import time 7 | import argparse 8 | import hashlib 9 | import logging 10 | import configparser 11 | 12 | import openai 13 | 14 | from typing import Dict, List, Union, Optional 15 | from collections import OrderedDict 16 | from flask import Flask, request, jsonify, abort 17 | from sentence_transformers import SentenceTransformer 18 | 19 | 20 | app = Flask(__name__) 21 | 22 | logging.basicConfig(level=logging.INFO) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class Config: 27 | def __init__(self, config_file: str): 28 | self.config_file = config_file 29 | self.config = configparser.ConfigParser() 30 | if not os.path.exists(self.config_file): 31 | logger.error(f'Config file not found: {self.config_file}') 32 | sys.exit(1) 33 | 34 | logger.info(f'Loading config file: {self.config_file}') 35 | self.config.read(config_file) 36 | 37 | def get_val(self, section: str, key: str) -> Optional[str]: 38 | answer = None 39 | 40 | try: 41 | answer = self.config.get(section, key) 42 | except Exception as err: 43 | logger.error(f'Config file missing option: [{section}] {key} - {err}') 44 | 45 | return answer 46 | 47 | def get_bool(self, section: str, key: str, default: bool = False) -> bool: 48 | try: 49 | return self.config.getboolean(section, key) 50 | except Exception as err: 51 | logger.error(f'Failed to parse boolean - returning 
default "{default}": {section} - {err}') 52 | return default 53 | 54 | 55 | class EmbeddingCache: 56 | def __init__(self, max_size: int = 500): 57 | logger.info(f'Created in-memory cache; max size={max_size}') 58 | self.cache = OrderedDict() 59 | self.max_size = max_size 60 | 61 | def get_cache_key(self, text: str, model_type: str) -> str: 62 | return hashlib.sha256((text + model_type).encode()).hexdigest() 63 | 64 | def get(self, text: str, model_type: str): 65 | key = self.get_cache_key(text, model_type) 66 | if key not in self.cache: 67 | return None 68 | # mark as most recently used so eviction is true LRU 69 | self.cache.move_to_end(key) 70 | return self.cache[key] 71 | 72 | def set(self, text: str, model_type: str, embedding): 73 | key = self.get_cache_key(text, model_type) 74 | self.cache[key] = embedding 75 | self.cache.move_to_end(key) 76 | if len(self.cache) > self.max_size: 77 | self.cache.popitem(last=False) 78 | 79 | 80 | class EmbeddingGenerator: 81 | def __init__(self, sbert_model: Optional[str] = None, openai_key: Optional[str] = None): 82 | self.sbert_model = sbert_model 83 | self.openai_key = openai_key 84 | if self.sbert_model is not None: 85 | try: 86 | self.model = SentenceTransformer(self.sbert_model) 87 | logger.info(f'enabled model: {self.sbert_model}') 88 | except Exception as err: 89 | logger.error(f'Failed to load SentenceTransformer model "{self.sbert_model}": {err}') 90 | sys.exit(1) 91 | 92 | if openai_key is not None: 93 | openai.api_key = self.openai_key 94 | try: 95 | openai.Model.list() 96 | logger.info('enabled model: text-embedding-ada-002') 97 | except Exception as err: 98 | logger.error(f'Failed to connect to OpenAI API; disabling OpenAI model: {err}') 99 | self.openai_key = None 100 | 101 | def generate(self, text_batch: List[str], model_type: str) -> Dict[str, Union[str, float, list]]: 102 | start_time = time.time() 103 | result = { 104 | 'status': 'success', 105 | 'message': '', 106 | 'model': '', 107 | 'elapsed': 0, 108 | 'embeddings': [] 109 | } 110 | 111 | if model_type == 'openai': 112 | try: 113 | response = openai.Embedding.create(input=text_batch, model='text-embedding-ada-002') 114 | result['embeddings'] = [data['embedding'] for data in response['data']] 115 | result['model'] = 'text-embedding-ada-002' 116 | except Exception as err: 117 | logger.error(f'Failed to get OpenAI embeddings: {err}') 118 | result['status'] = 'error' 119 | result['message'] = str(err) 120 | 121 | else: 122 | try: 123 | # encode on the model's default device (GPU if one was available at load time) 124 | embedding = self.model.encode(text_batch, batch_size=len(text_batch)).tolist() 125 | result['embeddings'] = embedding 126 | result['model'] = self.sbert_model 127 | except Exception as err: 128 | logger.error(f'Failed to get sentence-transformers embeddings: {err}') 129 | result['status'] = 'error' 130 | result['message'] = str(err) 131 | 132 | result['elapsed'] = (time.time() - start_time) * 1000 133 | return result 134 | 135 | 136 | @app.route('/health', methods=['GET']) 137 | def health_check(): 138 | sbert_on = embedding_generator.sbert_model if embedding_generator.sbert_model else 'disabled' 139 | openai_on = embedding_generator.openai_key is not None 140 | 141 | health_status = { 142 | "models": { 143 | "openai": openai_on, 144 | "sentence-transformers": sbert_on 145 | }, 146 | "cache": { 147 | "enabled": embedding_cache is not None, 148 | "size": len(embedding_cache.cache) if embedding_cache else None, 149 | "max_size": embedding_cache.max_size if embedding_cache else None 150 | } 151 | } 152 | 153 | return jsonify(health_status) 154 | 155 | 156 | @app.route('/submit', methods=['POST']) 157 | def submit_text(): 158 | data = request.json 159 | 160 | text_data = data.get('text') 161 | model_type = data.get('model', 'local').lower() 162 | 163 | if text_data is None: 164 | abort(400, 'Missing text data to embed') 165 | 166 | if isinstance(text_data, str): 167 | text_data = [text_data] 168 | 169 | if not all(isinstance(text, str) for text in text_data): 170 | abort(400, 'all data must be text strings') 171 | 172 | results = [] 173 | 174 | if embedding_cache: 175 | cached = [embedding_cache.get(text, model_type) for text in text_data] 176 | if all(c is not None for c in cached): 177 | logger.info('returning cached embeddings') 178 | results.append({ 179 | 'status': 'success', 180 | 'message': '', 181 | 'model': 'text-embedding-ada-002' if model_type == 'openai' else embedding_generator.sbert_model, 182 | 'elapsed': 0, 183 | 'embeddings': cached, 184 | 'cache': True 185 | }) 186 | return jsonify(results) 187 | 188 | result = embedding_generator.generate(text_data, model_type) 189 | 190 | if embedding_cache and result['status'] == 'success': 191 | for text, embedding in zip(text_data, result['embeddings']): 192 | embedding_cache.set(text, model_type, embedding) 193 | logger.info('added embeddings to cache') 194 | 195 | results.append(result) 196 | 197 | return jsonify(results) 198 | 199 | 200 | if __name__ == '__main__': 201 | parser = argparse.ArgumentParser() 202 | 203 | parser.add_argument( 204 | '-c', '--config', 205 | help='config file', 206 | type=str, 207 | required=True 208 | ) 209 | 210 | args = parser.parse_args() 211 | 212 | conf = Config(args.config) 213 | openai_key = conf.get_val('main', 'openai_api_key') or None 214 | sbert_model = conf.get_val('main', 'sent_transformers_model') or None 215 | use_cache = conf.get_bool('main', 'use_cache', default=False) 216 | if use_cache: 217 | cache_max = conf.get_val('main', 'cache_max') 218 | max_cache_size = int(cache_max) if cache_max else 500 219 | 220 | if openai_key is None: 221 | logger.warning('No OpenAI API key set in configuration file: server.conf') 222 | 223 | if sbert_model is None: 224 | logger.warning('No transformer model set in configuration file: server.conf') 225 | 226 | if openai_key is None and sbert_model is None: 227 | logger.error('No sbert model set *and* no openAI key set; exiting') 228 | sys.exit(1) 229 | 230 | embedding_cache = EmbeddingCache(max_cache_size) if use_cache else None 231 | embedding_generator = EmbeddingGenerator(sbert_model, openai_key) 232 | 233 | app.run(debug=True) 234 | --------------------------------------------------------------------------------
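A note on the cache: server.py's `EmbeddingCache` is built on `collections.OrderedDict`. The eviction mechanics can be sketched in isolation with a toy class (`TinyLRU` below is illustrative only, not part of the repo):

```python
from collections import OrderedDict

class TinyLRU:
    """Minimal standalone sketch of OrderedDict-based LRU eviction."""
    def __init__(self, max_size=2):
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key):
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)  # mark as most recently used
        return self.cache[key]

    def set(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)  # evict least recently used

lru = TinyLRU(max_size=2)
lru.set('a', 1)
lru.set('b', 2)
lru.get('a')     # touch 'a' so 'b' becomes least recently used
lru.set('c', 3)  # evicts 'b'
print(sorted(lru.cache))  # ['a', 'c']
```

Note that `get` must re-insert the key (`move_to_end`) for the eviction order to track use rather than insertion order; without that touch, `popitem(last=False)` degrades to FIFO eviction.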