├── tests
│   ├── __init__.py
│   └── test_batcher.py
├── gpt_batch
│   ├── __init__.py
│   └── batcher.py
├── setup.py
├── .github
│   └── workflows
│       └── workflow.yml
├── .gitignore
└── README.md
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/gpt_batch/__init__.py:
--------------------------------------------------------------------------------
# Package entry point: GPTBatcher is the only public name.
from .batcher import GPTBatcher


__all__ = ['GPTBatcher']
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

# Distribution metadata for the gpt_batch package. Published to PyPI by the
# GitHub release workflow (.github/workflows/workflow.yml).
setup(
    name='gpt_batch',
    version='0.1.9',
    packages=find_packages(),
    install_requires=[
        'openai', 'tqdm','anthropic'
    ],
    author='Ted Yu',
    author_email='liddlerain@gmail.com',
    description='A package for batch processing with OpenAI API.',
    long_description=open('README.md').read(),
    long_description_content_type='text/markdown',
    url='https://github.com/fengsxy/gpt_batch',
)
--------------------------------------------------------------------------------
/.github/workflows/workflow.yml:
--------------------------------------------------------------------------------
# Build an sdist/wheel and upload it to PyPI whenever a GitHub release is
# created. Requires the PYPI_API_TOKEN repository secret.
name: Publish Python 🐍 distribution 📦 to PyPI

on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2

    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.9'

    - name: List files in the repository
      run: ls -lha




    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine

    # Uploads with twine using an API token (username must be __token__).
    - name: Build and publish
      env:
        TWINE_USERNAME: __token__
        TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*

    - name: Clean up build artifacts
      run: rm -rf build dist *.egg-info
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a Python script from a template
#  before PyInstaller builds the executable, when PyInstaller is instructed to do so.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# GPT Batcher

A simple tool to batch process messages using OpenAI's GPT models. `GPTBatcher` allows for efficient handling of multiple requests simultaneously, ensuring quick responses and robust error management.

## Installation

To get started with `GPTBatcher`, clone this repository to your local machine.
Navigate to the repository directory and install the required dependencies (if any) by running:

```bash
pip install gpt_batch
```

## Quick Start

To use `GPTBatcher`, you need to instantiate it with your OpenAI API key and the model name you wish to use. Here's a quick guide:

### Handling Message Lists

This example demonstrates how to send a list of questions and receive answers:

```python
from gpt_batch.batcher import GPTBatcher

# Initialize the batcher
batcher = GPTBatcher(api_key='your_key_here', model_name='gpt-3.5-turbo-1106')

# Send a list of messages and receive answers
result = batcher.handle_message_list(['question_1', 'question_2', 'question_3', 'question_4'])
print(result)
# Expected output: ["answer_1", "answer_2", "answer_3", "answer_4"]
```

### Handling Embedding Lists

This example shows how to get embeddings for a list of strings:

```python
from gpt_batch.batcher import GPTBatcher

# Reinitialize the batcher for embeddings
batcher = GPTBatcher(api_key='your_key_here', model_name='text-embedding-3-small')

# Send a list of strings and get their embeddings
result = batcher.handle_embedding_list(['question_1', 'question_2', 'question_3', 'question_4'])
print(result)
# Expected output: ["embedding_1", "embedding_2", "embedding_3", "embedding_4"]
```

### Handling Message Lists with a Different API

This example demonstrates how to send a list of questions and receive answers through an OpenAI-compatible third-party endpoint:

```python
from gpt_batch.batcher import GPTBatcher

# Initialize the batcher with a custom base URL
batcher = GPTBatcher(api_key='sk-', model_name='deepseek-chat', api_base_url="https://api.deepseek.com/v1")

# Send a list of messages and receive answers
result = batcher.handle_message_list(['question_1', 'question_2', 'question_3', 'question_4'])
print(result)
# Expected output: ["answer_1", "answer_2", "answer_3", "answer_4"]
```
## Configuration

The `GPTBatcher` class can be customized with several parameters to adjust its performance and behavior:

- **api_key** (str): Your OpenAI API key.
- **model_name** (str): Identifier for the GPT model version you want to use, default is 'gpt-3.5-turbo-1106'.
- **system_prompt** (str): Initial text or question to seed the model, default is empty.
- **temperature** (float): Adjusts the creativity of the responses, default is 1.
- **num_workers** (int): Number of parallel workers for request handling, default is 64.
- **timeout_duration** (int): Timeout for API responses in seconds, default is 60.
- **retry_attempts** (int): How many times to retry a failed request, default is 2.
- **miss_index** (list): Tracks indices of requests that failed to process correctly.

For more detailed documentation on the parameters and methods, refer to the class docstring.

--------------------------------------------------------------------------------
/tests/test_batcher.py:
--------------------------------------------------------------------------------
# Integration tests for GPTBatcher. These call the live OpenAI / Anthropic
# APIs and therefore require real credentials in the environment:
#   TEST_KEY            - OpenAI API key
#   ANTHROPIC_API_KEY   - Anthropic API key (Claude test only)
import pytest
from gpt_batch import GPTBatcher
import os

def test_handle_message_list():
    # Initialize the GPTBatcher with credentials read from the environment.
    api_key = os.getenv('TEST_KEY')
    if not api_key:
        raise ValueError("API key must be set in the environment variables")
    batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.")
    message_list = ["I think privacy is important", "I don't think privacy is important"]

    # Call the method under test
    results = batcher.handle_message_list(message_list)

    # One answer per input message; each answer should be a non-trivial string.
    assert len(results) == 2, "There should be two results, one for each message"
    assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"


def test_json_format():
    import json
    # Initialize the GPTBatcher with credentials read from the environment.
    api_key = os.getenv('TEST_KEY')
    if not api_key:
        raise ValueError("API key must be set in the environment variables")
    # extra kwargs (response_format) are forwarded to the completion call
    batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.",response_format={ "type": "json_object" })
    message_list = ["return me a random json object", "return me a random json object"]

    # Call the method under test
    results = batcher.handle_message_list(message_list)
    # Each reply must parse as a JSON object because response_format was set.
    assert len(results) == 2, "There should be two results, one for each message"
    assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"
    assert all(isinstance(json.loads(result), dict) for result in results), "Each result should be a JSON object with 'json' key"



def test_handle_embedding_list():
    # Initialize the GPTBatcher with credentials read from the environment.
    api_key = os.getenv('TEST_KEY')
    if not api_key:
        raise ValueError("API key must be set in the environment variables")
    batcher = GPTBatcher(api_key=api_key, model_name='text-embedding-3-small')
    embedding_list = [ "I think privacy is important", "I don't think privacy is important"]
    results = batcher.handle_embedding_list(embedding_list)
    # One embedding vector per input string.
    assert len(results) == 2, "There should be two results, one for each message"
    assert all(len(result) >= 2 for result in results), "Each result should be at least two elements"

def test_base_url():
    # Initialize the GPTBatcher with credentials read from the environment.
    api_key = os.getenv('TEST_KEY')
    if not api_key:
        raise ValueError("API key must be set in the environment variables")
    # api_base_url should be passed straight through to the OpenAI client.
    batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', api_base_url="https://api.openai.com/v2/")
    assert batcher.client.base_url == "https://api.openai.com/v2/", "The base URL should be set to the provided value"

def test_get_miss_index():
    # Initialize the GPTBatcher with credentials read from the environment.
    api_key = os.getenv('TEST_KEY')
    if not api_key:
        raise ValueError("API key must be set in the environment variables")
    batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.")
    message_list = ["I think privacy is important", "I don't think privacy is important"]
    results = batcher.handle_message_list(message_list)
    # A fully successful run should report no missed indices.
    miss_index = batcher.get_miss_index()
    assert miss_index == [], "The miss index should be empty"
    # Optionally, you can add a test configuration if you have specific needs


def test_claude_handle_message_list():
    # Initialize the GPTBatcher with a Claude model (routes to Anthropic).
    api_key = os.getenv('ANTHROPIC_API_KEY')
    if not api_key:
        raise ValueError("Anthropic API key must be set in the environment variables as ANTHROPIC_API_KEY")

    batcher = GPTBatcher(
        api_key=api_key,
        model_name='claude-3-7-sonnet-20250219',
        system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons."
87 | ) 88 | message_list = ["I think privacy is important", "I don't think privacy is important"] 89 | 90 | # Call the method under test 91 | results = batcher.handle_message_list(message_list) 92 | 93 | # Assertions to verify the length of the results and the structure of each item 94 | assert len(results) == 2, "There should be two results, one for each message" 95 | assert all(isinstance(result, str) and len(result) > 0 for result in results if result is not None), "Each result should be a non-empty string if not None" 96 | assert batcher.is_claude, "Should recognize model as Claude" 97 | if __name__ == "__main__": 98 | pytest.main() 99 | -------------------------------------------------------------------------------- /gpt_batch/batcher.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import anthropic 3 | from concurrent.futures import ThreadPoolExecutor, wait 4 | from functools import partial 5 | from tqdm import tqdm 6 | import re 7 | 8 | class GPTBatcher: 9 | """ 10 | A class to handle batching and sending requests to the OpenAI GPT model and Anthropic Claude models efficiently. 11 | 12 | Attributes: 13 | client: The client instance to communicate with the API (OpenAI or Anthropic). 14 | is_claude (bool): Flag to indicate if using a Claude model. 15 | model_name (str): The name of the model to be used. Default is 'gpt-3.5-turbo-0125'. 16 | system_prompt (str): Initial prompt or context to be used with the model. Default is an empty string. 17 | temperature (float): Controls the randomness of the model's responses. Higher values lead to more diverse outputs. Default is 1. 18 | num_workers (int): Number of worker threads used for handling concurrent requests. Default is 64. 19 | timeout_duration (int): Maximum time (in seconds) to wait for a response from the API before timing out. Default is 60 seconds. 20 | retry_attempts (int): Number of retries if a request fails. Default is 2. 
21 | miss_index (list): Tracks the indices of requests that failed to process correctly. 22 | """ 23 | def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None,**kwargs): 24 | 25 | self.is_claude = bool(re.search(r'claude', model_name, re.IGNORECASE)) 26 | 27 | if self.is_claude: 28 | self.client = anthropic.Anthropic(api_key=api_key) 29 | # Anthropic doesn't support custom base URL the same way 30 | # If needed, this could be implemented differently 31 | else: 32 | self.client = OpenAI(api_key=api_key) 33 | if api_base_url: 34 | self.client.base_url = api_base_url 35 | 36 | self.model_name = model_name 37 | self.system_prompt = system_prompt 38 | self.temperature = temperature 39 | self.num_workers = num_workers 40 | self.timeout_duration = timeout_duration 41 | self.retry_attempts = retry_attempts 42 | self.miss_index = [] 43 | self.extra_params = kwargs 44 | 45 | def get_attitude(self, ask_text): 46 | index, ask_text = ask_text 47 | try: 48 | if self.is_claude: 49 | # Use the Anthropic Claude API 50 | message = self.client.messages.create( 51 | model=self.model_name, 52 | max_tokens=1024, # You can make this configurable if needed 53 | messages=[ 54 | {"role": "user", "content": ask_text} 55 | ], 56 | system=self.system_prompt if self.system_prompt else None, 57 | temperature=self.temperature, 58 | **self.extra_params 59 | ) 60 | return (index, message.content[0].text) 61 | else: 62 | # Use the OpenAI API as before 63 | completion = self.client.chat.completions.create( 64 | model=self.model_name, 65 | messages=[ 66 | {"role": "system", "content": self.system_prompt}, 67 | {"role": "user", "content": ask_text} 68 | ], 69 | temperature=self.temperature, 70 | **self.extra_params 71 | ) 72 | return (index, completion.choices[0].message.content) 73 | except Exception as e: 74 | print(f"Error occurred: {e}") 75 | self.miss_index.append(index) 76 | return (index, 
None) 77 | 78 | def process_attitude(self, message_list): 79 | new_list = [] 80 | num_workers = self.num_workers 81 | timeout_duration = self.timeout_duration 82 | retry_attempts = self.retry_attempts 83 | 84 | executor = ThreadPoolExecutor(max_workers=num_workers) 85 | message_chunks = list(self.chunk_list(message_list, num_workers)) 86 | try: 87 | for chunk in tqdm(message_chunks, desc="Processing messages"): 88 | future_to_message = {executor.submit(self.get_attitude, message): message for message in chunk} 89 | for _ in range(retry_attempts): 90 | done, not_done = wait(future_to_message.keys(), timeout=timeout_duration) 91 | for future in not_done: 92 | future.cancel() 93 | new_list.extend(future.result() for future in done if future.done()) 94 | if len(not_done) == 0: 95 | break 96 | future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future_to_message[future] for future in not_done} 97 | except Exception as e: 98 | print(f"Error occurred: {e}") 99 | finally: 100 | executor.shutdown(wait=False) 101 | return new_list 102 | 103 | def complete_attitude_list(self, attitude_list, max_length): 104 | completed_list = [] 105 | current_index = 0 106 | for item in attitude_list: 107 | index, value = item 108 | # Fill in missing indices 109 | while current_index < index: 110 | completed_list.append((current_index, None)) 111 | current_index += 1 112 | # Add the current element from the list 113 | completed_list.append(item) 114 | current_index = index + 1 115 | while current_index < max_length: 116 | print("Filling in missing index", current_index) 117 | self.miss_index.append(current_index) 118 | completed_list.append((current_index, None)) 119 | current_index += 1 120 | return completed_list 121 | 122 | def chunk_list(self, lst, n): 123 | """Yield successive n-sized chunks from lst.""" 124 | for i in range(0, len(lst), n): 125 | yield lst[i:i + n] 126 | 127 | def handle_message_list(self, message_list): 128 | indexed_list = [(index, 
data) for index, data in enumerate(message_list)] 129 | max_length = len(indexed_list) 130 | attitude_list = self.process_attitude(indexed_list) 131 | attitude_list.sort(key=lambda x: x[0]) 132 | attitude_list = self.complete_attitude_list(attitude_list, max_length) 133 | attitude_list = [x[1] for x in attitude_list] 134 | return attitude_list 135 | 136 | def get_embedding(self, text): 137 | index, text = text 138 | try: 139 | if self.is_claude: 140 | # Use Anthropic's embedding API if available 141 | # Note: As of March 2025, make sure to check Anthropic's latest API 142 | # for embeddings, as the format might have changed 143 | response = self.client.embeddings.create( 144 | model=self.model_name, 145 | input=text 146 | ) 147 | return (index, response.embedding) 148 | else: 149 | # Use OpenAI's embedding API 150 | response = self.client.embeddings.create( 151 | input=text, 152 | model=self.model_name 153 | ) 154 | return (index, response.data[0].embedding) 155 | except Exception as e: 156 | print(f"Error getting embedding: {e}") 157 | self.miss_index.append(index) 158 | return (index, None) 159 | 160 | def process_embedding(self, message_list): 161 | new_list = [] 162 | executor = ThreadPoolExecutor(max_workers=self.num_workers) 163 | # Split message_list into chunks 164 | message_chunks = list(self.chunk_list(message_list, self.num_workers)) 165 | fixed_get_embedding = partial(self.get_embedding) 166 | for chunk in tqdm(message_chunks, desc="Processing messages"): 167 | future_to_message = {executor.submit(fixed_get_embedding, message): message for message in chunk} 168 | for i in range(self.retry_attempts): 169 | done, not_done = wait(future_to_message.keys(), timeout=self.timeout_duration) 170 | for future in not_done: 171 | future.cancel() 172 | new_list.extend(future.result() for future in done if future.done()) 173 | if len(not_done) == 0: 174 | break 175 | future_to_message = {executor.submit(fixed_get_embedding, future_to_message[future]): 
future_to_message[future] for future in not_done} 176 | executor.shutdown(wait=False) 177 | return new_list 178 | 179 | def handle_embedding_list(self, message_list): 180 | indexed_list = [(index, data) for index, data in enumerate(message_list)] 181 | max_length = len(indexed_list) 182 | attitude_list = self.process_embedding(indexed_list) 183 | attitude_list.sort(key=lambda x: x[0]) 184 | attitude_list = self.complete_attitude_list(attitude_list, max_length) 185 | attitude_list = [x[1] for x in attitude_list] 186 | return attitude_list 187 | 188 | def get_miss_index(self): 189 | return self.miss_index 190 | 191 | # Example usage: 192 | if __name__ == "__main__": 193 | # For OpenAI 194 | openai_batcher = GPTBatcher( 195 | api_key="your_openai_api_key", 196 | model_name="gpt-4-turbo" 197 | ) 198 | 199 | # For Claude 200 | claude_batcher = GPTBatcher( 201 | api_key="your_anthropic_api_key", 202 | model_name="claude-3-7-sonnet-20250219" 203 | ) --------------------------------------------------------------------------------