├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── databonsai ├── CONTRIBUTING.MD ├── __init__.py ├── categorize │ ├── __init__.py │ ├── base_categorizer.py │ └── multi_categorizer.py ├── examples │ └── categorize_news.ipynb ├── llm_providers │ ├── __init__.py │ ├── anthropic_provider.py │ ├── llm_provider.py │ ├── ollama_provider.py │ └── openai_provider.py ├── transform │ ├── __init__.py │ ├── base_transformer.py │ └── extract_transformer.py └── utils │ ├── __init__.py │ ├── apply.py │ └── logs.py ├── docs ├── AnthropicProvider.md ├── BaseCategorizer.md ├── BaseTransformer.md ├── ExtractTransformer.md ├── MultiCategorizer.md ├── OllamaProvider.md ├── OpenAIProvider.md ├── Utils.md └── source │ ├── conf.py │ ├── databonsai.categorize.rst │ ├── databonsai.llm_providers.rst │ ├── databonsai.rst │ ├── databonsai.transform.rst │ ├── databonsai.utils.rst │ ├── index.rst │ └── modules.rst ├── pyproject.toml ├── setup.py └── tests ├── __init__.py └── test_categorization.py /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .*pyc 3 | __pycache__/ 4 | venv/ 5 | dist/ 6 | databonsai.egg-info/ 7 | build/ 8 | docs/source/ 9 | docs/build/ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | version: 2 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.11" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Alvin Ryanputra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
 22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # databonsai
 2 | 
 3 | [![PyPI version](https://badge.fury.io/py/databonsai.svg)](https://badge.fury.io/py/databonsai)
 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 5 | [![Python Version](https://img.shields.io/pypi/pyversions/databonsai.svg)](https://pypi.org/project/databonsai/)
 6 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 7 | 
 8 | ## Clean & curate your data with LLMs
 9 | 
10 | databonsai is a Python library that uses LLMs to perform data cleaning tasks.
11 | 
12 | ## Features
13 | 
14 | - Suite of tools for data processing using LLMs, including categorization,
15 |   transformation, and extraction
16 | - Validation of LLM outputs
17 | - Batch processing for token savings
18 | - Retry logic with exponential backoff for handling rate limits and transient
19 |   errors
20 | 
21 | ## Installation
22 | 
23 | ```bash
24 | pip install databonsai
25 | ```
26 | 
27 | Store your API keys in a `.env` file in the root of your project, or specify
28 | them as arguments when initializing the provider.
29 | 
30 | ```bash
31 | OPENAI_API_KEY=xxx # if you use OpenAIProvider
32 | ANTHROPIC_API_KEY=xxx # if you use AnthropicProvider
33 | ```
34 | 
35 | ## Quickstart
36 | 
37 | ### Categorization
38 | 
39 | Set up the LLM provider and categories (as a dictionary):
40 | 
41 | ```python
42 | from databonsai.categorize import MultiCategorizer, BaseCategorizer
43 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
44 | 
45 | provider = OpenAIProvider()  # Or AnthropicProvider(). Haiku, the default AnthropicProvider() model, is highly recommended, as it is cheap and effective for these tasks
46 | categories = {
47 |     "Weather": "Insights and remarks about weather conditions.",
48 |     "Sports": "Observations and comments on sports events.",
49 |     "Politics": "Political events related to governments, nations, or geopolitical issues.",
50 |     "Celebrities": "Celebrity sightings and gossip",
51 |     "Others": "Comments that do not fit into any of the above categories",
52 |     "Anomaly": "Data that does not look like comments or natural language",
53 | }
54 | few_shot_examples = [
55 |     {"example": "Big stormy skies over city", "response": "Weather"},
56 |     {"example": "The team won the championship", "response": "Sports"},
57 |     {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
58 | ]
59 | ```
60 | 
61 | Categorize your data:
62 | 
63 | ```python
64 | categorizer = BaseCategorizer(
65 |     categories=categories,
66 |     llm_provider=provider,
67 |     examples=few_shot_examples,
68 |     # strict=False  # Defaults to True; set to False to allow categories not in the provided dict
69 | )
70 | category = categorizer.categorize("It's been raining outside all day")
71 | print(category)
72 | ```
73 | 
74 | Output:
75 | 
76 | ```python
77 | Weather
78 | ```
79 | 
80 | Use `categorize_batch` to categorize a batch of inputs. This saves tokens, as
81 | the schema and few-shot examples are only sent once! (Works best with more
82 | capable models. Ideally, use at least 3 few-shot examples.)
 83 | 
 84 | ```python
 85 | categories = categorizer.categorize_batch([
 86 |     "Massive Blizzard Hits the Northeast, Thousands Without Power",
 87 |     "Local High School Basketball Team Wins State Championship After Dramatic Final",
 88 |     "Celebrated Actor Launches New Environmental Awareness Campaign",
 89 | ])
 90 | print(categories)
 91 | ```
 92 | 
 93 | Output:
 94 | 
 95 | ```python
 96 | ['Weather', 'Sports', 'Celebrities']
 97 | ```
 98 | 
 99 | ### AutoBatch for Larger datasets
100 | 
101 | If you have a pandas DataFrame or a list, use `apply_to_column_autobatch`.
102 | 
103 | - Batching data for LLM API calls saves tokens by not sending the prompt for
104 |   every row. However, too large a batch size, or too complex a task, can lead to
105 |   errors. Naturally, the better the LLM model, the larger the batch size you
106 |   can use.
107 | 
108 | - This batching is handled adaptively (i.e., the batch size is increased while
109 |   responses are valid and reduced, with a decay factor, when they are not).
110 | 
111 | Other features:
112 | 
113 | - Progress bar
114 | - Returns the last successful index, so you can resume from there if processing
115 |   stops after exceeding max_retries
116 | - Modifies your output list in place, so you don't lose any progress
117 | 
118 | Retry Logic:
119 | 
120 | - LLM providers have retry logic built in for API-related errors. This can be
121 |   configured in the provider.
122 | - The retry logic in `apply_to_column_autobatch` handles invalid responses
123 |   (e.g. an unexpected category, or a different number of outputs than inputs)
124 | 
125 | ```python
126 | from databonsai.utils import apply_to_column_batch, apply_to_column, apply_to_column_autobatch
127 | import pandas as pd
128 | 
129 | headlines = [
130 |     "Massive Blizzard Hits the Northeast, Thousands Without Power",
131 |     "Local High School Basketball Team Wins State Championship After Dramatic Final",
132 |     "Celebrated Actor Launches New Environmental Awareness Campaign",
133 |     "President Announces Comprehensive Plan to Combat Cybersecurity Threats",
134 |     "Tech Giant Unveils Revolutionary Quantum Computer",
135 |     "Tropical Storm Alina Strengthens to Hurricane as It Approaches the Coast",
136 |     "Olympic Gold Medalist Announces Retirement, Plans Coaching Career",
137 |     "Film Industry Legends Team Up for Blockbuster Biopic",
138 |     "Government Proposes Sweeping Reforms in Public Health Sector",
139 |     "Startup Develops App That Predicts Traffic Patterns Using AI",
140 | ]
141 | df = pd.DataFrame(headlines, columns=["Headline"])
142 | df["Category"] = None  # Initialize it if it doesn't exist, as we modify it in place
143 | success_idx = apply_to_column_autobatch(df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=3, start_idx=0)
144 | ```
145 | 
146 | There are many more options available for autobatch, such as setting
147 | max_retries, the decay factor, and more. Check [Utils](./docs/Utils.md) for more
148 | details.
149 | 
150 | If it fails midway (even after exponential backoff), you can resume from the
151 | last successful index + 1.
152 | 
153 | ```python
154 | success_idx = apply_to_column_autobatch(df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=10, start_idx=success_idx+1)
155 | ```
156 | 
157 | This also works for regular Python lists.
158 | 
159 | Note that the better the LLM model, the greater the batch_size you can use
160 | (depending on the length of your inputs). If you're getting errors, reduce the
161 | batch_size, or use a better LLM model. A minimal resume loop is sketched below.
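
Putting this together, a resume loop might look like the following sketch. The cap of 3 restart attempts is an illustrative choice, and the loop assumes, as in the example above, that the returned index points at the last successfully processed row.

```python
start_idx = 0
for _ in range(3):  # illustrative cap on restart attempts
    success_idx = apply_to_column_autobatch(
        df["Headline"], df["Category"], categorizer.categorize_batch,
        batch_size=3, start_idx=start_idx,
    )
    start_idx = success_idx + 1  # resume after the last successful row
    if start_idx >= len(df):
        break  # every row has been processed
```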
162 | 
163 | To use it with batching, but with a fixed batch size:
164 | 
165 | ```python
166 | success_idx = apply_to_column_batch(df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=3, start_idx=0)
167 | ```
168 | 
169 | To use it without batching:
170 | 
171 | ```python
172 | success_idx = apply_to_column(df["Headline"], df["Category"], categorizer.categorize)
173 | ```
174 | 
175 | ### View System Prompt
176 | 
177 | ```python
178 | print(categorizer.system_message)
179 | print(categorizer.system_message_batch)
180 | ```
181 | 
182 | ### View token usage
183 | 
184 | Token usage is recorded for OpenAI and Anthropic. Use these counts to estimate
185 | your costs!
186 | 
187 | ```python
188 | print(provider.input_tokens)
189 | print(provider.output_tokens)
190 | ```
191 | 
192 | ## [Docs](./docs/)
193 | 
194 | ### Tools (Check out the docs for usage examples and details)
195 | 
196 | - [BaseCategorizer](./docs/BaseCategorizer.md) - categorize data into a
197 |   category
198 | - [MultiCategorizer](./docs/MultiCategorizer.md) - categorize data into
199 |   multiple categories
200 | - [BaseTransformer](./docs/BaseTransformer.md) - transform data with a prompt
201 | - [ExtractTransformer](./docs/ExtractTransformer.md) - extract data into a
202 |   structured format based on a schema
203 | - .. more coming soon!
204 | 
205 | ### LLM Providers
206 | 
207 | - [OpenAIProvider](./docs/OpenAIProvider.md) - OpenAI
208 | - [AnthropicProvider](./docs/AnthropicProvider.md) - Anthropic
209 | - [OllamaProvider](./docs/OllamaProvider.md) - Ollama
210 | - .. more coming soon!
211 | 
212 | ### Examples
213 | 
214 | - [Examples](./databonsai/examples/) (TBD)
215 | 
216 | ### Acknowledgements
217 | 
218 | Bonsai icon from icons8: https://icons8.com/icon/74uBtdDr5yFq/bonsai
219 | 
--------------------------------------------------------------------------------
/databonsai/CONTRIBUTING.MD:
--------------------------------------------------------------------------------
 1 | ## Contributing to DataBonsai
 2 | There are many ways to contribute, including:
 3 | - Reporting bugs and issues
 4 | - Suggesting new features or enhancements
 5 | - Submitting pull requests with bug fixes or new features
 6 | - Improving documentation
 7 | - Providing feedback and suggestions
 8 | 
 9 | ## Getting Started
10 | 
11 | To get started with contributing to DataBonsai, follow these steps:
12 | 
13 | Fork the DataBonsai repository on GitHub, clone your forked repository to your
14 | local machine, and create a new branch for your contribution:
15 | 
16 | ```
17 | git checkout -b feature/your-feature-name
18 | ```
19 | 
20 | - Make your changes and commit them with descriptive commit messages.
21 | 
22 | - Push your changes to your forked repository.
23 | - Open a pull request against the main DataBonsai repository.
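
For reference, committing and pushing your work typically looks like this (standard git commands, reusing the branch created above):

```
git add .
git commit -m "Describe your change"
git push origin feature/your-feature-name
```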
24 | 25 | ## Reporting Bugs and Issues 26 | 27 | If you encounter a bug or issue while using DataBonsai, please report it by 28 | opening a new issue on the GitHub repository, or email databonsai.ai@gmail.com 29 | -------------------------------------------------------------------------------- /databonsai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvin-r/databonsai/3f2b7c58d0aa172251b5c3c321dea93038853e32/databonsai/__init__.py -------------------------------------------------------------------------------- /databonsai/categorize/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_categorizer import BaseCategorizer 2 | from .multi_categorizer import MultiCategorizer 3 | -------------------------------------------------------------------------------- /databonsai/categorize/base_categorizer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | from pydantic import BaseModel, field_validator, model_validator, computed_field 3 | from databonsai.llm_providers import OpenAIProvider, LLMProvider 4 | from databonsai.utils.logs import logger 5 | 6 | 7 | class BaseCategorizer(BaseModel): 8 | """ 9 | A base class for categorizing input data using a specified LLM provider. 10 | 11 | Attributes: 12 | categories (Dict[str, str]): A dictionary mapping category names to their descriptions. 13 | llm_provider (LLMProvider): An instance of an LLM provider to be used for categorization. 14 | 15 | """ 16 | 17 | categories: Dict[str, str] 18 | llm_provider: LLMProvider 19 | examples: Optional[List[Dict[str, str]]] = [] 20 | strict: bool = True 21 | 22 | class Config: 23 | arbitrary_types_allowed = True 24 | 25 | @field_validator("categories") 26 | def validate_categories(cls, v): 27 | """ 28 | Validates the categories dictionary. 29 | 30 | Args: 31 | v (Dict[str, str]): The categories dictionary to be validated. 32 | 33 | Raises: 34 | ValueError: If the categories dictionary is empty or has less than two key-value pairs. 35 | 36 | Returns: 37 | Dict[str, str]: The validated categories dictionary. 38 | """ 39 | if not v: 40 | raise ValueError("Categories dictionary cannot be empty.") 41 | if len(v) < 2: 42 | raise ValueError( 43 | "Categories dictionary must have more than one key-value pair." 44 | ) 45 | return v 46 | 47 | @field_validator("examples") 48 | def validate_examples(cls, v): 49 | """ 50 | Validates the examples list. 51 | 52 | Args: 53 | v (List[Dict[str, str]]): The examples list to be validated. 54 | 55 | Raises: 56 | ValueError: If the examples list is not a list of dictionaries or if any dictionary is missing the "example" or "response" key. 57 | 58 | Returns: 59 | List[Dict[str, str]]: The validated examples list. 60 | """ 61 | if not isinstance(v, list): 62 | raise ValueError("Examples must be a list of dictionaries.") 63 | for example in v: 64 | if not isinstance(example, dict): 65 | raise ValueError("Each example must be a dictionary.") 66 | if "example" not in example or "response" not in example: 67 | raise ValueError( 68 | "Each example dictionary must have 'example' and 'response' keys." 69 | ) 70 | return v 71 | 72 | @model_validator(mode="after") 73 | def validate_examples_responses(self): 74 | """ 75 | Validates that the "response" values in the examples are within the categories keys. 
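        For instance, with the quickstart categories, an example response of "Weather"
        is valid, while an unlisted value like "Forecast" raises a ValueError.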
 76 |         """
 77 | 
 78 |         if self.examples:
 79 |             category_keys = set(self.categories.keys())
 80 |             for example in self.examples:
 81 |                 response = example.get("response")
 82 |                 if response not in category_keys:
 83 |                     raise ValueError(
 84 |                         f"Example response '{response}' is not one of the provided categories, {str(list(self.categories.keys()))}."
 85 |                     )
 86 | 
 87 |         return self
 88 | 
 89 |     @computed_field
 90 |     @property
 91 |     def category_mapping(self) -> Dict[int, str]:
 92 |         return {i: category for i, category in enumerate(self.categories.keys())}
 93 | 
 94 |     @computed_field
 95 |     @property
 96 |     def inverse_category_mapping(self) -> Dict[str, int]:
 97 |         return {category: i for i, category in self.category_mapping.items()}
 98 | 
 99 |     @computed_field
100 |     @property
101 |     def system_message(self) -> str:
102 |         categories_with_numbers = "\n".join(
103 |             [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
104 |         )
105 |         system_message = f"""
106 |         You are a categorization expert, specializing in classifying text data into predefined categories. You cannot use any other categories except those provided.
107 |         Each category is formatted as <category number>: <category description>
108 |         {categories_with_numbers}
109 |         Classify the given text snippet into one of the following categories:
110 |         {str(list(range(len(self.categories))))}
111 |         Do not use any other categories.
112 |         Only reply with the category number. Do not make any other conversation.
113 |         """
114 |         # Add in fewshot examples
115 |         if self.examples:
116 |             for example in self.examples:
117 |                 system_message += f"\nEXAMPLE: {example['example']} RESPONSE: {self.inverse_category_mapping[example['response']]}"
118 |         return system_message
119 | 
120 |     @computed_field
121 |     @property
122 |     def system_message_batch(self) -> str:
123 |         categories_with_numbers = "\n".join(
124 |             [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
125 |         )
126 |         system_message = f"""
127 |         You are a categorization expert, specializing in classifying text data into predefined categories. You cannot use any other categories except those provided.
128 |         Each category is formatted as <category number>: <category description>
129 |         {categories_with_numbers}
130 |         Classify each given text snippet into one of the following categories:
131 |         {str(list(range(len(self.categories))))}.
132 |         Do not use any other categories. If there are multiple snippets, separate each category number with ||.
133 |         EXAMPLE: <text snippet> RESPONSE: <category number>
134 |         EXAMPLE: <snippet 1>||<snippet 2> RESPONSE: <category number>||<category number>
135 |         Choose one category for each text snippet.
136 |         Only reply with the category numbers. Do not make any other conversation.
137 |         """
138 |         # Add in fewshot examples
139 |         if self.examples:
140 |             system_message += "\n EXAMPLE:"
141 |             system_message += (
142 |                 f"{'||'.join([example['example'] for example in self.examples])}"
143 |             )
144 |             system_message += f"\n RESPONSE: {'||'.join([str(self.inverse_category_mapping[example['response']]) for example in self.examples])}"
145 | 
146 |         return system_message
147 | 
148 |     def categorize(self, input_data: str) -> str:
149 |         """
150 |         Categorizes the input data using the specified LLM provider.
151 | 
152 |         Args:
153 |             input_data (str): The text data to be categorized.
154 | 
155 |         Returns:
156 |             str: The predicted category for the input data.
157 | 
158 |         Raises:
159 |             ValueError: If the predicted category is not one of the provided categories.
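        Example (mirroring the README quickstart; the exact output depends on the LLM):
            >>> categorizer.categorize("It's been raining outside all day")
            'Weather'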
160 | """ 161 | # Call the LLM provider to get the predicted category number 162 | response = self.llm_provider.generate(self.system_message, input_data) 163 | predicted_category_number = int(response.strip()) 164 | 165 | # Validate that the predicted category number is within the valid range 166 | if predicted_category_number not in self.category_mapping: 167 | if self.strict: 168 | raise ValueError( 169 | f"Predicted category number '{predicted_category_number}' is not one of the provided categories. Use 'strict=False' when instantiating the categorizer to allow categories not in the categories dict." 170 | ) 171 | else: 172 | logger.warning( 173 | f"Predicted category number '{predicted_category_number}' is not one of the provided categories. Use 'strict=True' when instantiating the categorizer to raise an error." 174 | ) 175 | 176 | # Convert the category number back to the category key 177 | predicted_category = self.category_mapping[predicted_category_number] 178 | return predicted_category 179 | 180 | def categorize_batch(self, input_data: List[str]) -> List[str]: 181 | """ 182 | Categorizes a batch of input data using the specified LLM provider. For less advanced LLMs, call this method on batches of 3-5 inputs (depending on the length of the input data). 183 | 184 | Args: 185 | input_data (List[str]): A list of text data to be categorized. 186 | 187 | Returns: 188 | List[str]: A list of predicted categories for the input data. 189 | 190 | Raises: 191 | ValueError: If the predicted categories are not a subset of the provided categories. 192 | """ 193 | # If there is only one input, call the categorize method 194 | if len(input_data) == 1: 195 | return self.validate_predicted_categories( 196 | [self.categorize(next(iter(input_data)))] 197 | ) 198 | 199 | input_data_prompt = "||".join(input_data) 200 | # Call the LLM provider to get the predicted category numbers 201 | response = self.llm_provider.generate( 202 | self.system_message_batch, input_data_prompt 203 | ) 204 | predicted_category_numbers = [ 205 | int(category.strip()) for category in response.split("||") 206 | ] 207 | if len(predicted_category_numbers) != len(input_data): 208 | raise ValueError( 209 | f"Number of predicted categories ({len(predicted_category_numbers)}) does not match the number of input data ({len(input_data)})." 210 | ) 211 | # Convert the category numbers back to category keys 212 | predicted_categories = [ 213 | self.category_mapping[number] for number in predicted_category_numbers 214 | ] 215 | return self.validate_predicted_categories(predicted_categories) 216 | 217 | def validate_predicted_categories( 218 | self, predicted_categories: List[str] 219 | ) -> List[str]: 220 | # Filter out empty strings from the predicted categories 221 | filtered_categories = [ 222 | category for category in predicted_categories if category 223 | ] 224 | 225 | # Validate each category in the filtered list 226 | for predicted_category in filtered_categories: 227 | if predicted_category not in self.categories: 228 | if self.strict: 229 | raise ValueError( 230 | f"Predicted category '{predicted_category}' is not one of the provided categories. Use 'strict=False' when instantiating the categorizer to allow categories not in the categories dict." 231 | ) 232 | else: 233 | # Warn the user if the predicted category is not one of the provided categories 234 | logger.warning( 235 | f"Predicted category '{predicted_category}' is not one of the provided categories. 
Use 'strict=True' when instantiating the categorizer to raise an error."
236 |                 )
237 |         return filtered_categories
238 | 
--------------------------------------------------------------------------------
/databonsai/categorize/multi_categorizer.py:
--------------------------------------------------------------------------------
 1 | from typing import List, Dict
 2 | from databonsai.categorize.base_categorizer import BaseCategorizer
 3 | from pydantic import model_validator, computed_field
 4 | 
 5 | 
 6 | class MultiCategorizer(BaseCategorizer):
 7 |     """
 8 |     A class for categorizing input data into multiple categories using a specified LLM provider.
 9 | 
10 |     This class extends the BaseCategorizer class and overrides the categorize method to allow
11 |     for multiple category predictions.
12 |     """
13 | 
14 |     class Config:
15 |         arbitrary_types_allowed = True
16 | 
17 |     @model_validator(mode="after")
18 |     def validate_examples_responses(self):
19 |         """
20 |         Validates that the "response" values in the examples are within the categories keys. If there are multiple categories, they should be separated by commas.
21 |         """
22 | 
23 |         if self.examples:
24 |             category_keys = set(self.categories.keys())
25 |             for example in self.examples:
26 |                 response = example.get("response")
27 |                 response_categories = response.split(",")
28 |                 for response_category in response_categories:
29 |                     if response_category not in category_keys:
30 |                         raise ValueError(
31 |                             f"Example response '{response}' is not one of the provided categories, {str(list(self.categories.keys()))}."
32 |                         )
33 | 
34 |         return self
35 | 
36 |     @computed_field
37 |     @property
38 |     def system_message(self) -> str:
39 |         categories_with_numbers = "\n".join(
40 |             [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
41 |         )
42 |         system_message = f"""
43 |         Each category is formatted as <category number>: <category description>
44 |         {categories_with_numbers}
45 |         Classify the given text snippet into one or more of the following categories:
46 |         {str(list(range(len(self.categories))))}
47 |         Do not use any other categories.
48 |         Assign multiple categories to one content snippet by separating the categories with ||. Do not make any other conversation.
49 |         """
50 | 
51 |         # Add in fewshot examples
52 |         if self.examples:
53 |             for example in self.examples:
54 |                 response_numbers = [
55 |                     str(self.inverse_category_mapping[category.strip()])
56 |                     for category in example["response"].split(",")
57 |                 ]
58 |                 system_message += f"\nEXAMPLE: {example['example']} RESPONSE: {'||'.join(response_numbers)}"
59 |         return system_message
60 | 
61 |     @computed_field
62 |     @property
63 |     def system_message_batch(self) -> str:
64 |         categories_with_numbers = "\n".join(
65 |             [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
66 |         )
67 |         system_message = f"""
68 |         Each category is formatted as <category number>: <category description>
69 |         {categories_with_numbers}
70 |         Classify the given text snippet into one or more of the following categories:
71 |         {str(list(range(len(self.categories))))}
72 |         Do not use any other categories.
73 |         Assign multiple categories to one content snippet by separating the categories with ||. Differentiate between each content snippet using ##. EXAMPLE: <snippet 1>##<snippet 2> \n RESPONSE: <category number>||<category number>##<category number> Do not make any other conversation.
74 | """ 75 | 76 | # Add in fewshot examples 77 | if self.examples: 78 | system_message += "\nEXAMPLE: " 79 | system_message += ( 80 | f"{'##'.join([example['example'] for example in self.examples])}" 81 | ) 82 | response_numbers_list = [] 83 | for example in self.examples: 84 | response_numbers = [ 85 | str(self.inverse_category_mapping[category.strip()]) 86 | for category in example["response"].split(",") 87 | ] 88 | response_numbers_list.append("||".join(response_numbers)) 89 | system_message += f"\nRESPONSE: {'##'.join(response_numbers_list)}" 90 | return system_message 91 | 92 | def categorize(self, input_data: str) -> str: 93 | """ 94 | Categorizes the input data into multiple categories using the specified LLM provider. 95 | 96 | Args: 97 | input_data (str): The text data to be categorized. 98 | 99 | Returns: 100 | str: A string of categories for the input data, separated by commas. 101 | 102 | Raises: 103 | ValueError: If the predicted categories are not a subset of the provided categories. 104 | """ 105 | 106 | # Call the LLM provider to get the predicted category numbers 107 | response = self.llm_provider.generate(self.system_message, input_data) 108 | predicted_category_numbers = [ 109 | int(category.strip()) for category in response.split("||") 110 | ] 111 | 112 | # Convert the category numbers back to category keys 113 | predicted_categories = [ 114 | self.category_mapping[number] for number in predicted_category_numbers 115 | ] 116 | return ",".join(self.validate_predicted_categories(predicted_categories)) 117 | 118 | def categorize_batch(self, input_data: List[str]) -> List[str]: 119 | """ 120 | Categorizes the input data into multiple categories using the specified LLM provider. 121 | 122 | Args: 123 | input_data (str): The text data to be categorized. 124 | 125 | Returns: 126 | List[str]: A list of predicted categories for the input data. If there are multiple categories, they will be separated by commas. 127 | 128 | Raises: 129 | ValueError: If the predicted categories are not a subset of the provided categories. 130 | """ 131 | if len(input_data) == 1: 132 | return [self.categorize(next(iter(input_data)))] 133 | 134 | input_data_prompt = "##".join(input_data) 135 | # Call the LLM provider to get the predicted category numbers 136 | response = self.llm_provider.generate( 137 | self.system_message_batch, input_data_prompt 138 | ) 139 | # Split the response into category number sets for each input data 140 | category_number_sets = response.split("##") 141 | 142 | if len(category_number_sets) != len(input_data): 143 | raise ValueError( 144 | f"Number of predicted category sets ({len(category_number_sets)}) does not match the number of input data ({len(input_data)})." 
145 | ) 146 | 147 | predicted_categories_list = [] 148 | for category_number_set in category_number_sets: 149 | predicted_category_numbers = [ 150 | int(category.strip()) for category in category_number_set.split("||") 151 | ] 152 | predicted_categories = [ 153 | self.category_mapping[number] for number in predicted_category_numbers 154 | ] 155 | predicted_categories_str = ",".join( 156 | self.validate_predicted_categories(predicted_categories) 157 | ) 158 | predicted_categories_list.append(predicted_categories_str) 159 | 160 | return predicted_categories_list 161 | -------------------------------------------------------------------------------- /databonsai/examples/categorize_news.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from databonsai.categorize import MultiCategorizer, BaseCategorizer\n", 10 | "from databonsai.transform import BaseTransformer\n", 11 | "from databonsai.llm_providers import OpenAIProvider, AnthropicProvider, OllamaProvider\n", 12 | "from databonsai.utils import (\n", 13 | " apply_to_column,\n", 14 | " apply_to_column_batch,\n", 15 | " apply_to_column_autobatch,\n", 16 | ")\n", 17 | "import pandas as pd" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 3, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "provider = OpenAIProvider(model=\"gpt-3.5-turbo\") # Or AnthropicProvider()" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 4, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "categories = {\n", 36 | " \"Weather\": \"Insights and remarks about weather conditions.\",\n", 37 | " \"Sports\": \"Observations and comments on sports events.\",\n", 38 | " \"Politics\": \"Political events related to governments, nations, or geopolitical issues.\",\n", 39 | " \"Celebrities\": \"Celebrity sightings and gossip\",\n", 40 | " \"Tech\": \"News and updates about technology and tech companies.\",\n", 41 | " \"Others\": \"Comments do not fit into any of the above categories\", # Best practice in case it can't be categorized easily\n", 42 | " \"Anomaly\": \"Data that does not look like comments or natural language\", # Helps to flag unclean/problematic data\n", 43 | "}\n", 44 | "categorizer = BaseCategorizer(\n", 45 | " categories=categories,\n", 46 | " llm_provider=provider,\n", 47 | " examples=[\n", 48 | " {\"example\": \"Big stormy skies over city\", \"response\": \"Weather\"},\n", 49 | " {\"example\": \"The team won the championship\", \"response\": \"Sports\"},\n", 50 | " {\"example\": \"I saw a famous rapper at the mall\", \"response\": \"Celebrities\"},\n", 51 | " ],\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "headlines2 = [\n", 62 | " \"Local Fire Department Honored with National Award for Bravery\",\n", 63 | " \"Breakthrough Research Promises New Treatment for Alzheimer’s Disease\",\n", 64 | " \"Major Airline Announces Expansion of International Routes\",\n", 65 | " \"Historic Peace Agreement Signed Between Rival Nations\",\n", 66 | " \"City Council Votes to Increase Funding for Public Libraries\",\n", 67 | " \"Renowned Chef Opens Vegan Restaurant in Downtown\",\n", 68 | " \"Veteran Astronaut Set to Lead Next Moon Mission\",\n", 69 | " \"Global Music Festival Raises Funds for Refugee Relief\",\n", 70 | " \"Innovative 
Urban Farming Techniques Revolutionize City Life\",\n", 71 | " \"Climate Summit Sets Ambitious Goals for Carbon Reduction\",\n", 72 | " \"Documentary Film Exposing Corruption Premieres to Critical Acclaim\",\n", 73 | " \"New Legislation Aims to Boost Small Businesses\",\n", 74 | " \"Ancient Shipwreck Discovered Off the Coast of Sicily\",\n", 75 | " \"World Health Organization Declares New Strain of Virus Contained\",\n", 76 | " \"International Art Theft Ring Busted by Joint Task Force\",\n", 77 | " \"Leading Economists Predict Global Recession in Next Year\",\n", 78 | " \"Celebrity Fashion Designer Debuts Eco-Friendly Line\",\n", 79 | " \"Major Breakthrough in Quantum Encryption Technology Announced\",\n", 80 | " \"Wildlife Conservation Efforts Successfully Increase Tiger Population\",\n", 81 | " \"Rare Astronomical Event Visible This Weekend\",\n", 82 | " \"Nationwide Protests Demand Action on Climate Change\",\n", 83 | " \"Revolutionary New Battery Design Could Transform Renewable Energy Storage\",\n", 84 | " \"Record-Breaking Heatwave Strikes Southern Europe\",\n", 85 | " \"Underground Water Reserves Discovered Beneath Sahara Desert\",\n", 86 | " \"Virtual Reality Platform Takes Online Education to New Heights\",\n", 87 | " \"Controversial New Policy Sparks Debate Over Internet Privacy\",\n", 88 | " \"Youngest Nobel Laureate Awarded for Work in Peace Building\",\n", 89 | " \"Sports League Implements New Rules to Protect Players from Concussions\",\n", 90 | " \"Historic Church Undergoes Restoration to Preserve Cultural Heritage\",\n", 91 | " \"Pioneering Surgery Gives New Hope to Heart Disease Patients\",\n", 92 | " \"Wildfires Rage Across California, Thousands Evacuated\",\n", 93 | " \"Tech Start-Up Revolutionizes Mobile Payment Systems\",\n", 94 | " \"Pharmaceutical Company Faces Lawsuit Over Drug Side Effects\",\n", 95 | " \"Renewable Energy Now Powers Entire Small Nation\",\n", 96 | " \"Central Bank Raises Interest Rates in Surprise Move\",\n", 97 | " \"Marine Biologists Discover New Species in the Deep Ocean\",\n", 98 | " \"Global Conference on Women's Rights Concludes with Action Plan\",\n", 99 | " \"Country Music Star Reveals Struggle with Mental Health in New Album\",\n", 100 | " \"Massive Oil Spill Threatens Wildlife Along the Coast\",\n", 101 | " \"Protests Erupt as Government Cuts Healthcare Funding\",\n", 102 | " \"Archaeologists Uncover New Evidence of Ancient Civilization in Thailand\",\n", 103 | " \"Fashion Week Highlights Sustainability in New Collections\",\n", 104 | " \"New Strain of Wheat Could Increase Crop Yields Substantially\",\n", 105 | " \"Scientists Link Air Pollution to Decline in Urban Wildlife\",\n", 106 | " \"Innovative Community Program Cuts Urban Crime Rate\",\n", 107 | " \"Next Generation of Smartphones Features Advanced AI Capabilities\",\n", 108 | " \"Historical Drama Film Set to Break Box Office Records\",\n", 109 | " \"Study Shows Increase in Cyber Attacks on Financial Institutions\",\n", 110 | " \"New Yoga Trend Combines Traditional Practices with Modern Technology\",\n", 111 | " \"Local Community Garden Doubles as Educational Facility for Schools\",\n", 112 | "]" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stderr", 122 | "output_type": "stream", 123 | "text": [ 124 | "Categorizing: 0%| | 0/50 [00:00 str: 81 | """ 82 | Generates a text completion using Anthropic's Claude API, with a given system and user prompt. 
83 | This method is decorated with retry logic to handle temporary failures. 84 | 85 | Parameters: 86 | system_prompt (str): The system prompt to provide context or instructions for the generation. 87 | user_prompt (str): The user's prompt, based on which the text completion is generated. 88 | max_tokens (int): The maximum number of tokens to generate in the response. 89 | 90 | Returns: 91 | str: The generated text completion. 92 | """ 93 | try: 94 | if not system_prompt: 95 | raise ValueError("System prompt is required.") 96 | if not user_prompt: 97 | raise ValueError("User prompt is required.") 98 | response = self.client.messages.create( 99 | model=self.model, 100 | max_tokens=max_tokens, 101 | temperature=self.temperature, 102 | system=f"{system_prompt}", 103 | messages=[ 104 | { 105 | "role": "user", 106 | "content": [ 107 | { 108 | "type": "text", 109 | "text": user_prompt, 110 | } 111 | ], 112 | } 113 | ], 114 | ) 115 | self.input_tokens += response.usage.input_tokens 116 | self.output_tokens += response.usage.output_tokens 117 | return response.content[0].text 118 | except Exception as e: 119 | logger.warning(f"Error occurred during generation: {str(e)}") 120 | raise 121 | -------------------------------------------------------------------------------- /databonsai/llm_providers/llm_provider.py: -------------------------------------------------------------------------------- 1 | # llm_providers/base_provider.py 2 | from abc import ABC, abstractmethod 3 | from typing import Optional 4 | 5 | 6 | class LLMProvider(ABC): 7 | @abstractmethod 8 | def __init__( 9 | self, 10 | model: str = "", 11 | temperature: float = 0, 12 | ): 13 | """ 14 | Initializes the LLMProvider with an API key and retry parameters. 15 | 16 | Parameters: 17 | 18 | model (str): The default model to use for text generation. 19 | temperature (float): The temperature parameter for text generation. 20 | """ 21 | pass 22 | 23 | @abstractmethod 24 | def generate(self, system_prompt: str, user_prompt: str, max_tokens: int) -> str: 25 | """ 26 | Generates a text completion using the provider's API, with a given system and user prompt. 27 | This method should be decorated with retry logic to handle temporary failures. 28 | 29 | Parameters: 30 | system_prompt (str): The system prompt to provide context or instructions for the generation. 31 | user_prompt (str): The user's prompt, based on which the text completion is generated. 32 | max_tokens (int): The maximum number of tokens to generate in the response. 33 | 34 | Returns: 35 | str: The generated text completion. 36 | """ 37 | pass 38 | -------------------------------------------------------------------------------- /databonsai/llm_providers/ollama_provider.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from ollama import Client, chat 3 | from .llm_provider import LLMProvider 4 | from databonsai.utils.logs import logger 5 | 6 | 7 | class OllamaProvider(LLMProvider): 8 | """ 9 | A provider class to interact with Ollama's API. 10 | Supports exponential backoff retries, since we'll often deal with large datasets. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | model: str = "llama3", 16 | temperature: float = 0, 17 | host: Optional[str] = None, 18 | ): 19 | """ 20 | Initializes the OllamaProvider with an optional Ollama client or host, and retry parameters. 21 | 22 | Parameters: 23 | model (str): The default model to use for text generation. 
24 | temperature (float): The temperature parameter for text generation. 25 | host (str): The host URL for the Ollama API. 26 | """ 27 | # Provider related configs 28 | self.model = model 29 | self.temperature = temperature 30 | 31 | if host: 32 | self.client = Client(host=host) 33 | else: 34 | self.client = None 35 | 36 | def _chat(self, messages, options): 37 | if self.client: 38 | return self.client.chat( 39 | model=self.model, messages=messages, options=options 40 | ) 41 | else: 42 | return chat(model=self.model, messages=messages, options=options) 43 | 44 | def generate(self, system_prompt: str, user_prompt: str, max_tokens=1000) -> str: 45 | """ 46 | Generates a text completion using Ollama's API, with a given system and user prompt. 47 | This method is decorated with retry logic to handle temporary failures. 48 | Parameters: 49 | system_prompt (str): The system prompt to provide context or instructions for the generation. 50 | user_prompt (str): The user's prompt, based on which the text completion is generated. 51 | max_tokens (int): The maximum number of tokens to generate in the response. 52 | Returns: 53 | str: The generated text completion. 54 | """ 55 | if not system_prompt: 56 | raise ValueError("System prompt is required.") 57 | if not user_prompt: 58 | raise ValueError("User prompt is required.") 59 | try: 60 | messages = [ 61 | {"role": "system", "content": system_prompt}, 62 | {"role": "user", "content": user_prompt}, 63 | ] 64 | 65 | response = self._chat( 66 | messages, 67 | options={"temperature": self.temperature, "num_predict": max_tokens}, 68 | ) 69 | completion = response["message"]["content"] 70 | 71 | return completion 72 | except Exception as e: 73 | logger.warning(f"Error occurred during generation: {str(e)}") 74 | raise 75 | -------------------------------------------------------------------------------- /databonsai/llm_providers/openai_provider.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from .llm_provider import LLMProvider 3 | import os 4 | from functools import wraps 5 | from tenacity import retry, wait_exponential, stop_after_attempt 6 | from dotenv import load_dotenv 7 | from databonsai.utils.logs import logger 8 | 9 | load_dotenv() 10 | 11 | 12 | class OpenAIProvider(LLMProvider): 13 | """ 14 | A provider class to interact with OpenAI's API. 15 | Supports exponential backoff retries, since we'll often deal with large datasets. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | api_key: str = None, 21 | multiplier: int = 1, 22 | min_wait: int = 1, 23 | max_wait: int = 30, 24 | max_tries: int = 5, 25 | model: str = "gpt-4-turbo", 26 | temperature: float = 0, 27 | ): 28 | """ 29 | Initializes the OpenAIProvider with an API key and retry parameters. 30 | 31 | Parameters: 32 | api_key (str): OpenAI API key. 33 | multiplier (int): The multiplier for the exponential backoff in retries. 34 | min_wait (int): The minimum wait time between retries. 35 | max_wait (int): The maximum wait time between retries. 36 | max_tries (int): The maximum number of attempts before giving up. 37 | model (str): The default model to use for text generation. 38 | temperature (float): The temperature parameter for text generation. 
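        Example (illustrative values; assumes OPENAI_API_KEY is set in the environment):
            >>> provider = OpenAIProvider(model="gpt-4-turbo", max_tries=3, min_wait=2, max_wait=20)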
39 | """ 40 | super().__init__() 41 | 42 | # Provider related configs 43 | if api_key: 44 | self.api_key = api_key 45 | else: 46 | self.api_key = os.getenv("OPENAI_API_KEY") 47 | if not self.api_key: 48 | raise ValueError("OpenAI API key not provided.") 49 | self.model = model 50 | self.client = OpenAI(api_key=self.api_key) 51 | try: 52 | self.client.models.retrieve(model) 53 | except Exception as e: 54 | logger.warning(e.response.status_code) 55 | raise ValueError(f"Invalid OpenAI model: {model}") from e 56 | self.temperature = temperature 57 | self.input_tokens = 0 58 | self.output_tokens = 0 59 | 60 | # Retry related configs 61 | self.multiplier = multiplier 62 | self.min_wait = min_wait 63 | self.max_wait = max_wait 64 | self.max_tries = max_tries 65 | 66 | def retry_with_exponential_backoff(method): 67 | """ 68 | Decorator to apply retry logic with exponential backoff to an instance method. 69 | It captures the 'self' context to access instance attributes for retry configuration. 70 | """ 71 | 72 | @wraps(method) 73 | def wrapper(self, *args, **kwargs): 74 | retry_decorator = retry( 75 | wait=wait_exponential( 76 | multiplier=self.multiplier, min=self.min_wait, max=self.max_wait 77 | ), 78 | stop=stop_after_attempt(self.max_tries), 79 | ) 80 | return retry_decorator(method)(self, *args, **kwargs) 81 | 82 | return wrapper 83 | 84 | @retry_with_exponential_backoff 85 | def generate( 86 | self, system_prompt: str, user_prompt: str, max_tokens=1000, json=False 87 | ) -> str: 88 | """ 89 | Generates a text completion using OpenAI's API, with a given system and user prompt. 90 | This method is decorated with retry logic to handle temporary failures. 91 | 92 | Parameters: 93 | system_prompt (str): The system prompt to provide context or instructions for the generation. 94 | user_prompt (str): The user's prompt, based on which the text completion is generated. 95 | max_tokens (int): The maximum number of tokens to generate in the response. 96 | json (bool): Whether to use OpenAI's JSON response format. 97 | 98 | Returns: 99 | str: The generated text completion. 
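        Example (a sketch; the reply shown is illustrative and depends on the model):
            >>> provider.generate("Reply with one word only.", "What is the capital of France?", max_tokens=5)
            'Paris'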
100 | """ 101 | if not system_prompt: 102 | raise ValueError("System prompt is required.") 103 | if not user_prompt: 104 | raise ValueError("User prompt is required.") 105 | try: 106 | response = self.client.chat.completions.create( 107 | model=self.model, 108 | messages=[ 109 | {"role": "system", "content": system_prompt}, 110 | {"role": "user", "content": f"{user_prompt}"}, 111 | ], 112 | temperature=self.temperature, 113 | max_tokens=max_tokens, 114 | frequency_penalty=0, 115 | presence_penalty=0, 116 | response_format={"type": "json_object"} if json else {"type": "text"}, 117 | ) 118 | self.input_tokens += response.usage.prompt_tokens 119 | self.output_tokens += response.usage.completion_tokens 120 | return response.choices[0].message.content 121 | except Exception as e: 122 | logger.warning(f"Error occurred during generation: {str(e)}") 123 | raise 124 | -------------------------------------------------------------------------------- /databonsai/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_transformer import BaseTransformer 2 | from .extract_transformer import ExtractTransformer 3 | -------------------------------------------------------------------------------- /databonsai/transform/base_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | from pydantic import BaseModel, field_validator, model_validator, computed_field 3 | from databonsai.llm_providers import LLMProvider 4 | 5 | 6 | class BaseTransformer(BaseModel): 7 | """ 8 | A base class for transforming input data using a specified LLM provider. 9 | 10 | Attributes: 11 | prompt (str): The prompt used to guide the transformation process. 12 | llm_provider (LLMProvider): An instance of an LLM provider to be used for transformation. 13 | examples (Optional[List[Dict[str, str]]]): A list of example inputs and their corresponding transformed outputs. 14 | 15 | """ 16 | 17 | prompt: str 18 | llm_provider: LLMProvider 19 | examples: Optional[List[Dict[str, str]]] = [] 20 | 21 | class Config: 22 | arbitrary_types_allowed = True 23 | 24 | @field_validator("prompt") 25 | def validate_prompt(cls, v): 26 | """ 27 | Validates the prompt. 28 | 29 | Args: 30 | v (str): The prompt to be validated. 31 | 32 | Raises: 33 | ValueError: If the prompt is empty. 34 | 35 | Returns: 36 | str: The validated prompt. 37 | """ 38 | if not v: 39 | raise ValueError("Prompt cannot be empty.") 40 | return v 41 | 42 | @field_validator("examples") 43 | def validate_examples(cls, v): 44 | """ 45 | Validates the examples list. 46 | 47 | Args: 48 | v (List[Dict[str, str]]): The examples list to be validated. 49 | 50 | Raises: 51 | ValueError: If the examples list is not a list of dictionaries or if any dictionary is missing the "example" or "response" key. 52 | 53 | Returns: 54 | List[Dict[str, str]]: The validated examples list. 55 | """ 56 | if not isinstance(v, list): 57 | raise ValueError("Examples must be a list of dictionaries.") 58 | for example in v: 59 | if not isinstance(example, dict): 60 | raise ValueError("Each example must be a dictionary.") 61 | if "example" not in example or "response" not in example: 62 | raise ValueError( 63 | "Each example dictionary must have 'example' and 'response' keys." 
 64 |                 )
 65 |         return v
 66 | 
 67 |     @computed_field
 68 |     @property
 69 |     def system_message(self) -> str:
 70 |         system_message = f"""
 71 |         Use the following prompt to transform the input data:
 72 |         Prompt: {self.prompt}
 73 |         """
 74 | 
 75 |         # Add in fewshot examples
 76 |         if self.examples:
 77 |             for example in self.examples:
 78 |                 system_message += (
 79 |                     f"\nEXAMPLE: {example['example']} RESPONSE: {example['response']}"
 80 |                 )
 81 | 
 82 |         return system_message
 83 | 
 84 |     @computed_field
 85 |     @property
 86 |     def system_message_batch(self) -> str:
 87 |         system_message = f"""
 88 |         Use the following prompt to transform each input data:
 89 |         Prompt: {self.prompt}
 90 |         Respond with the transformed data for each input, separated by ||. Do not make any other conversation.
 91 |         Example: Content 1: <content>, Content 2: <content> \n Response: <transformed content 1>||<transformed content 2>
 92 |         """
 93 | 
 94 |         # Add in fewshot examples
 95 |         if self.examples:
 96 |             system_message += "\nExample: "
 97 |             for idx, example in enumerate(self.examples):
 98 |                 system_message += f"Content {str(idx+1)}: {example['example']}, "
 99 |             system_message += f"\nResponse: "
100 |             for example in self.examples:
101 |                 system_message += f"{example['response']}||"
102 | 
103 |         return system_message
104 | 
105 |     def transform(self, input_data: str, max_tokens=1000, json: bool = False) -> str:
106 |         """
107 |         Transforms the input data using the specified LLM provider.
108 | 
109 |         Args:
110 |             input_data (str): The text data to be transformed.
111 |             max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 1000.
112 |             json (bool, optional): Whether to request a JSON-formatted response from the provider. Defaults to False.
113 |         Returns:
114 |             str: The transformed data.
115 |         """
116 |         # Call the LLM provider to perform the transformation
117 |         response = self.llm_provider.generate(
118 |             self.system_message, input_data, max_tokens=max_tokens, json=json
119 |         )
120 |         transformed_data = response.strip()
121 |         return transformed_data
122 | 
123 |     def transform_batch(self, input_data: List[str], max_tokens=1000, json: bool = False) -> List[str]:
124 |         """
125 |         Transforms a batch of input data using the specified LLM provider.
126 | 
127 |         Args:
128 |             input_data (List[str]): A list of text data to be transformed.
129 |             max_tokens (int, optional): The maximum number of tokens to generate in each response. Defaults to 1000.
130 |             json (bool, optional): Whether to request a JSON-formatted response from the provider. Defaults to False.
131 |         Returns:
132 |             List[str]: A list of transformed data, where each element corresponds to the transformed version of the respective input data.
133 |         """
134 |         if len(input_data) == 1:
135 |             return [self.transform(next(iter(input_data)), max_tokens=max_tokens, json=json)]  # forward the generation options
136 |         # Call the LLM provider to perform the batch transformation
137 |         input_data_prompt = "||".join(input_data)
138 |         response = self.llm_provider.generate(
139 |             self.system_message_batch, input_data_prompt, max_tokens=max_tokens, json=json
140 |         )
141 | 
142 |         # Split the response into individual transformed data
143 |         transformed_data_list = response.split("||")
144 | 
145 |         # Strip any leading/trailing whitespace from each transformed data
146 |         transformed_data_list = [data.strip() for data in transformed_data_list]
147 | 
148 |         if len(transformed_data_list) != len(input_data):
149 |             raise ValueError(
150 |                 f"Length of output list ({len(transformed_data_list)}) does not match the length of input list ({len(input_data)})."
151 | ) 152 | 153 | return transformed_data_list 154 | -------------------------------------------------------------------------------- /databonsai/transform/extract_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | from pydantic import field_validator, model_validator, computed_field 3 | from databonsai.transform.base_transformer import BaseTransformer 4 | 5 | 6 | class ExtractTransformer(BaseTransformer): 7 | """ 8 | This class extends the BaseTransformer class and overrides the transform method to extract a given schema from the input data into a list of dictionaries. 9 | 10 | Attributes: 11 | output_schema (Dict[str, str]): A dictionary representing the schema of the output dictionaries. 12 | examples (Optional[List[Dict[str, str]]]): A list of example inputs and their corresponding extracted outputs. 13 | 14 | Raises: 15 | ValueError: If the output schema dictionary is empty, or if the transformed data does not match the expected format or schema. 16 | """ 17 | 18 | output_schema: Dict[str, str] 19 | examples: Optional[List[Dict[str, str]]] = [] 20 | 21 | @field_validator("output_schema") 22 | def validate_schema(cls, v): 23 | """ 24 | Validates the output schema. 25 | 26 | Args: 27 | v (Dict[str, str]): The output schema to be validated. 28 | 29 | Raises: 30 | ValueError: If the output schema dictionary is empty. 31 | 32 | Returns: 33 | Dict[str, str]: The validated output schema. 34 | """ 35 | if not v: 36 | raise ValueError("Schema dictionary cannot be empty.") 37 | return v 38 | 39 | @field_validator("examples") 40 | def validate_examples(cls, v): 41 | """ 42 | Validates the examples list. 43 | 44 | Args: 45 | v (List[Dict[str, str]]): The examples list to be validated. 46 | 47 | Raises: 48 | ValueError: If the examples list is not a list of dictionaries or if any dictionary is missing the "example" or "response" key. 49 | 50 | Returns: 51 | List[Dict[str, str]]: The validated examples list. 52 | """ 53 | if not isinstance(v, list): 54 | raise ValueError("Examples must be a list of dictionaries.") 55 | for example in v: 56 | if not isinstance(example, dict): 57 | raise ValueError("Each example must be a dictionary.") 58 | if "example" not in example or "response" not in example: 59 | raise ValueError( 60 | "Each example dictionary must have 'example' and 'response' keys." 61 | ) 62 | return v 63 | 64 | @model_validator(mode="after") 65 | def validate_examples_responses(self): 66 | """ 67 | Validates that the "response" values in the examples are valid JSON-formatted lists of dictionaries that match the output schema. 
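        For instance, with output_schema = {"name": "Name of the person"}, a valid
        example response is '[{"name": "Alice"}]' (a list of dictionaries whose keys
        match the schema).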
68 | """ 69 | if self.examples: 70 | for example in self.examples: 71 | response = example.get("response") 72 | try: 73 | response_data = eval(response) 74 | except (SyntaxError, NameError, TypeError, ZeroDivisionError): 75 | raise ValueError( 76 | f"Invalid format in the example response: {response}" 77 | ) 78 | if not isinstance(response_data, list): 79 | raise ValueError( 80 | f"Example response must be a JSON-formatted list: {response}" 81 | ) 82 | for item in response_data: 83 | if not isinstance(item, dict): 84 | raise ValueError( 85 | f"Each item in the example response must be a dictionary: {response}" 86 | ) 87 | if set(item.keys()) != set(self.output_schema.keys()): 88 | raise ValueError( 89 | f"The keys in the example response do not match the output schema: {response}" 90 | ) 91 | return self 92 | 93 | @computed_field 94 | @property 95 | def system_message(self) -> str: 96 | system_message = f""" 97 | Use the following prompt to transform the input data: 98 | Input Data: {self.prompt} 99 | The transformed data should be a list of dictionaries, where each dictionary has the following schema: 100 | {self.output_schema} 101 | Reply with a JSON-formatted list of dictionaries. Do not make any conversation. 102 | """ 103 | 104 | # Add in few-shot examples 105 | if self.examples: 106 | for example in self.examples: 107 | system_message += ( 108 | f"\nEXAMPLE: {example['example']} RESPONSE: {example['response']}" 109 | ) 110 | 111 | return system_message 112 | 113 | 114 | def transform(self, input_data: str, max_tokens=1000, json: bool =True) -> List[Dict[str, str]]: 115 | """ 116 | Transforms the input data into a list of dictionaries using the specified LLM provider. 117 | 118 | Args: 119 | input_data (str): The text data to be transformed. 120 | max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 1000. 121 | 122 | Returns: 123 | List[Dict[str, str]]: The transformed data as a list of dictionaries. 124 | 125 | Raises: 126 | ValueError: If the transformed data does not match the expected format or schema. 127 | """ 128 | # Call the LLM provider to perform the transformation 129 | response = self.llm_provider.generate( 130 | self.system_message, input_data, max_tokens=max_tokens, json=json 131 | ) 132 | 133 | try: 134 | transformed_data = eval(response) 135 | except (SyntaxError, NameError, TypeError, ZeroDivisionError): 136 | raise ValueError("Invalid format in the transformed data.") 137 | 138 | # Validate the transformed data 139 | if not isinstance(transformed_data, list): 140 | raise ValueError("Transformed data must be a list.") 141 | for item in transformed_data: 142 | if not isinstance(item, dict): 143 | raise ValueError( 144 | "Each item in the transformed data must be a dictionary." 145 | ) 146 | if set(item.keys()) != set(self.output_schema.keys()): 147 | raise ValueError( 148 | "The keys in the transformed data do not match the schema." 149 | ) 150 | 151 | return transformed_data 152 | 153 | # def transform_batch(self, input_data: List[str], max_tokens=1000) -> List[List[Dict[str, str]]]: 154 | # """ 155 | # Transforms a batch of input data into lists of dictionaries using the specified LLM provider. 156 | 157 | # Args: 158 | # input_data (List[str]): A list of text data to be transformed. 159 | # max_tokens (int, optional): The maximum number of tokens to generate in each response. Defaults to 1000. 
160 | 161 | # Returns: 162 | # List[List[Dict[str, str]]]: A list of transformed data, where each element is a list of dictionaries corresponding to the respective input data. 163 | 164 | # Raises: 165 | # ValueError: If the transformed data does not match the expected format or schema. 166 | # """ 167 | # # Call the LLM provider to perform the batch transformation 168 | # response = self.llm_provider.generate_batch( 169 | # self.system_message_batch, input_data, max_tokens=max_tokens 170 | # ) 171 | 172 | # # Split the response into individual transformed data 173 | # transformed_data_list = response.split("||") 174 | 175 | # # Strip any leading/trailing whitespace from each transformed data 176 | # transformed_data_list = [data.strip() for data in transformed_data_list] 177 | 178 | # if len(transformed_data_list) != len(input_data): 179 | # raise ValueError( 180 | # f"Length of output list ({len(transformed_data_list)}) does not match the length of input list ({len(input_data)})." 181 | # ) 182 | 183 | # # Evaluate and validate each transformed data 184 | # result = [] 185 | # for data in transformed_data_list: 186 | # try: 187 | # transformed_data = eval(data) 188 | # except (SyntaxError, NameError, TypeError, ZeroDivisionError): 189 | # raise ValueError(f"Invalid format in the transformed data: {data}") 190 | 191 | # # Validate the transformed data 192 | # if not isinstance(transformed_data, list): 193 | # raise ValueError(f"Transformed data must be a list: {data}") 194 | # for item in transformed_data: 195 | # if not isinstance(item, dict): 196 | # raise ValueError( 197 | # f"Each item in the transformed data must be a dictionary: {data}" 198 | # ) 199 | # if set(item.keys()) != set(self.output_schema.keys()): 200 | # raise ValueError( 201 | # f"The keys in the transformed data do not match the schema: {data}" 202 | # ) 203 | 204 | # result.append(transformed_data) 205 | 206 | # return result 207 | -------------------------------------------------------------------------------- /databonsai/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .apply import apply_to_column, apply_to_column_batch, apply_to_column_autobatch 2 | -------------------------------------------------------------------------------- /databonsai/utils/apply.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from typing import List, Callable, Union, get_origin 3 | import inspect 4 | 5 | 6 | def apply_to_column( 7 | input_column: List, 8 | output_column: List, 9 | func: Callable, 10 | start_idx: int = 0, 11 | ) -> int: 12 | """ 13 | Apply a function to each value in a column of a DataFrame or a normal Python list, starting from a specified index. 14 | 15 | Parameters: 16 | input_column (List): The column of the DataFrame or a normal Python list to apply the function to. 17 | output_column (List): A list to store the processed values. The function will mutate this list in-place. 18 | func (callable): The function to apply to each value in the column. 19 | The function should take a single value as input and return a single value. 20 | start_idx (int, optional): The index from which to start applying the function. Default is 0. 21 | 22 | Returns: 23 | int: The index of the last successfully processed value. 
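    Example (illustrative; any single-value function works):
        >>> results = []
        >>> _ = apply_to_column(["a", "b"], results, str.upper)
        >>> results
        ['A', 'B']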
24 | 
25 |     """
26 |     if len(input_column) == 0:
27 |         raise ValueError("Input input_column is empty.")
28 | 
29 |     if start_idx >= len(input_column):
30 |         raise ValueError(
31 |             f"start_idx ({start_idx}) is greater than or equal to the length of the input_column ({len(input_column)})."
32 |         )
33 | 
34 |     if len(output_column) > len(input_column):
35 |         raise ValueError(
36 |             f"The length of the output_column ({len(output_column)}) is greater than the length of the input_column ({len(input_column)})."
37 |         )
38 | 
39 |     success_idx = start_idx
40 | 
41 |     try:
42 |         for idx, value in enumerate(
43 |             tqdm(input_column[start_idx:], desc="Processing data..", unit="row"),
44 |             start=start_idx,
45 |         ):
46 |             result = func(value)
47 | 
48 |             if idx >= len(output_column):
49 |                 output_column.append(result)
50 |             else:
51 |                 output_column[idx] = result
52 | 
53 |             success_idx = idx + 1
54 |     except Exception as e:
55 |         print(f"Error occurred at index {success_idx}: {str(e)}")
56 |         print(f"Processing stopped at index {success_idx - 1}")
57 |         return success_idx
58 | 
59 |     return success_idx
60 | 
61 | 
62 | def apply_to_column_batch(
63 |     input_column: List,
64 |     output_column: List,
65 |     func: Callable,
66 |     batch_size: int = 5,
67 |     start_idx: int = 0,
68 | ) -> int:
69 |     """
70 |     Apply a function to each batch of values in a column of a DataFrame or a normal Python list, starting from a specified index.
71 | 
72 |     Parameters:
73 |         input_column (List): The column of the DataFrame or a normal Python list to apply the function to.
74 |         output_column (List): A list to store the processed values. The function will mutate this list in-place.
75 |         func (callable): The batch function to apply to each batch of values in the column.
76 |             The function should take a list of values as input and return a list of processed values.
77 |         batch_size (int, optional): The size of each batch. Default is 5.
78 |         start_idx (int, optional): The index from which to start applying the function. Default is 0.
79 | 
80 |     Returns:
81 |         int: The index one past the last successfully processed item, i.e. the
82 |             index to resume from if processing stopped early.
83 | 
84 |     """
85 |     if len(input_column) == 0:
86 |         raise ValueError("Input input_column is empty.")
87 | 
88 |     if start_idx >= len(input_column):
89 |         raise ValueError(
90 |             f"start_idx ({start_idx}) is greater than or equal to the length of the input_column ({len(input_column)})."
91 |         )
92 | 
93 |     if len(output_column) > len(input_column):
94 |         raise ValueError(
95 |             f"The length of the output_column list ({len(output_column)}) is greater than the length of the input_column ({len(input_column)})."
96 | ) 97 | 98 | check_func(func) 99 | success_idx = start_idx 100 | num_items = len(input_column) 101 | try: 102 | with tqdm(total=num_items, desc="Processing data..", unit="item") as pbar: 103 | for i in range(start_idx, num_items, batch_size): 104 | batch_end = min(i + batch_size, num_items) 105 | batch = input_column[i:batch_end] 106 | batch_result = func(batch) 107 | 108 | # Update output column 109 | if i >= len(output_column): 110 | output_column.extend(batch_result) 111 | else: 112 | output_column[i : i + len(batch_result)] = batch_result 113 | 114 | # Update progress bar by the number of items processed in this batch 115 | pbar.update(len(batch_result)) 116 | 117 | success_idx = batch_end 118 | except Exception as e: 119 | 120 | print(f"Error occurred at batch starting at index {success_idx}: {str(e)}") 121 | print(f"Processing stopped at batch ending at index {success_idx - 1}") 122 | return success_idx 123 | 124 | return min(success_idx, len(input_column)) 125 | 126 | 127 | def apply_to_column_autobatch( 128 | input_column: List, 129 | output_column: List, 130 | func: Callable, 131 | max_retries: int = 3, 132 | max_batch_size: int = 5, 133 | batch_size: int = 2, 134 | ramp_factor: float = 1.5, 135 | ramp_factor_decay: float = 0.8, 136 | reduce_factor: float = 0.5, 137 | reduce_factor_decay: float = 0.8, 138 | start_idx: int = 0, 139 | ) -> int: 140 | """ 141 | Apply a function to the input column using adaptive batch processing. 142 | 143 | This function applies a batch processing function to the input column, starting from the 144 | specified index. It adaptively adjusts the batch size based on the success or failure of 145 | each batch processing attempt. The function retries failed batches with a reduced batch 146 | size and gradually decreases the rate of batch size adjustment over time. 147 | 148 | Parameters: 149 | input_column (List): The input column to be processed. 150 | output_column (List): The list to store the processed results. 151 | func (callable): The batch function to apply to each batch of values in the column. 152 | The function should take a list of values as input and return a list of processed values. 153 | max_retries (int): The maximum number of retries for failed batches. 154 | max_batch_size (int): The maximum allowed batch size. 155 | batch_size (int): The initial batch size. 156 | ramp_factor (float): The factor by which the batch size is increased after a successful batch. 157 | ramp_factor_decay (float): The decay rate for the ramp factor after each successful batch. 158 | reduce_factor (float): The factor by which the batch size is reduced after a failed batch. 159 | reduce_factor_decay (float): The decay rate for the reduce factor after each failed batch. 160 | start_idx (int): The index from which to start processing the input column. 161 | 162 | Returns: 163 | int: The index of the last successfully processed item in the input column. 164 | 165 | """ 166 | if len(input_column) == 0: 167 | raise ValueError("Input input_column is empty.") 168 | 169 | if start_idx >= len(input_column): 170 | raise ValueError( 171 | f"start_idx ({start_idx}) is greater than or equal to the length of the input_column ({len(input_column)})." 172 | ) 173 | 174 | if len(output_column) > len(input_column): 175 | raise ValueError( 176 | f"The length of the output_column list ({len(output_column)}) is greater than the length of the input_column ({len(input_column)})." 
177 | ) 178 | 179 | check_func(func) 180 | success_idx = start_idx 181 | ramp_factor = ramp_factor 182 | reduce_factor = reduce_factor 183 | 184 | try: 185 | remaining_data = input_column[start_idx:] 186 | batch_size = batch_size 187 | retry_count = 0 188 | 189 | with tqdm( 190 | total=len(remaining_data), desc="Processing data..", unit="row" 191 | ) as pbar: 192 | while len(remaining_data) > 0: 193 | try: 194 | batch_size = min(batch_size, len(remaining_data)) 195 | batch = remaining_data[:batch_size] 196 | batch_results = func(batch) 197 | 198 | # Update output_column in place 199 | output_column[success_idx : success_idx + batch_size] = ( 200 | batch_results 201 | ) 202 | 203 | remaining_data = remaining_data[batch_size:] 204 | retry_count = 0 205 | success_idx += batch_size 206 | 207 | # Update progress bar 208 | pbar.update(batch_size) 209 | 210 | # Increase the batch size using the decayed ramp factor 211 | batch_size = min(round(batch_size * ramp_factor), max_batch_size) 212 | ramp_factor = max(ramp_factor * ramp_factor_decay, 1.0) 213 | 214 | except Exception as e: 215 | if retry_count >= max_retries: 216 | raise ValueError( 217 | f"Processing failed after {max_retries} retries. Error: {str(e)}" 218 | ) 219 | retry_count += 1 220 | # Decrease the batch size using the decayed reduce factor 221 | batch_size = max(round(batch_size * reduce_factor), 1) 222 | print(f"Retrying with smaller batch size: {batch_size}") 223 | reduce_factor *= reduce_factor_decay 224 | 225 | except Exception as e: 226 | print(f"Error occurred at batch starting at index {success_idx}: {str(e)}") 227 | print(f"Processing stopped at batch ending at index {success_idx - 1}") 228 | return success_idx 229 | 230 | return min(success_idx, len(input_column)) 231 | 232 | 233 | def check_func(func): 234 | if not inspect.signature(func).parameters: 235 | raise TypeError("The provided function does not take any arguments.") 236 | 237 | # Ensure func is a batch function that takes a list 238 | first_param = list(inspect.signature(func).parameters.values())[0] 239 | param_annotation = first_param.annotation 240 | origin = get_origin(param_annotation) 241 | 242 | if origin is not get_origin(List): 243 | raise TypeError( 244 | "The provided function does not take a list or pandas.Series as input." 245 | ) 246 | -------------------------------------------------------------------------------- /databonsai/utils/logs.py: -------------------------------------------------------------------------------- 1 | # utils/logger.py 2 | import logging 3 | 4 | # Configure the root logger 5 | logging.basicConfig( 6 | level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" 7 | ) 8 | 9 | # Create a global logger instance 10 | logger = logging.getLogger(__name__) 11 | -------------------------------------------------------------------------------- /docs/AnthropicProvider.md: -------------------------------------------------------------------------------- 1 | # AnthropicProvider 2 | 3 | The `AnthropicProvider` class is a provider class that interacts with 4 | Anthropic's Claude API for generating text completions. It supports exponential 5 | backoff retries (from tenacity's library) to handle temporary failures, which is 6 | particularly useful when dealing with large datasets. 7 | 8 | ## Initialization 9 | 10 | The `__init__` method initializes the `AnthropicProvider` with an API key and 11 | retry parameters. 12 | 13 | ### Parameters 14 | 15 | - `api_key (str)`: Anthropic API key. 
16 | - `multiplier (int)`: The multiplier for the exponential backoff in retries 17 | (default: 1). 18 | - `min_wait (int)`: The minimum wait time between retries (default: 1). 19 | - `max_wait (int)`: The maximum wait time between retries (default: 60). 20 | - `max_tries (int)`: The maximum number of attempts before giving up (default: 21 | 10). 22 | - `model (str)`: The default model to use for text generation (default: 23 | "claude-3-haiku-20240307"). 24 | - `temperature (float)`: The temperature parameter for text generation 25 | (default: 0). 26 | 27 | ## Methods 28 | 29 | ### `generate` 30 | 31 | The `generate` method generates a text completion using Anthropic's Claude API, 32 | given a system prompt and a user prompt. It is decorated with retry logic to 33 | handle temporary failures. 34 | 35 | #### Parameters 36 | 37 | - `system_prompt (str)`: The system prompt to provide context or instructions 38 | for the generation. 39 | - `user_prompt (str)`: The user's prompt, based on which the text completion 40 | is generated. 41 | - `max_tokens (int)`: The maximum number of tokens to generate in the response 42 | (default: 1000). 43 | 44 | #### Returns 45 | 46 | - `str`: The generated text completion. 47 | 48 | ## Retry Decorator 49 | 50 | The `retry_with_exponential_backoff` decorator is used to apply retry logic with 51 | exponential backoff to instance methods. It captures the `self` context to 52 | access instance attributes for retry configuration. 53 | 54 | ## Usage 55 | 56 | If your ANTHROPIC_API_KEY is defined in .env: 57 | 58 | ```python 59 | from databonsai.llm_providers import AnthropicProvider 60 | 61 | provider = AnthropicProvider() 62 | ``` 63 | 64 | Or, provide the api key as an argument: 65 | 66 | ```python 67 | provider = AnthropicProvider(api_key="your_Anthropic_api_key") 68 | ``` 69 | 70 | Other parameters, for example: 71 | 72 | ```python 73 | provider = AnthropicProvider(model="claude-3-opus-20240229", max_tries=5, max_wait=120) 74 | ``` 75 | -------------------------------------------------------------------------------- /docs/BaseCategorizer.md: -------------------------------------------------------------------------------- 1 | # BaseCategorizer 2 | 3 | The `BaseCategorizer` class provides functionality for categorizing input data 4 | utilizing a specified LLM provider. This class serves as a base for implementing 5 | categorization tasks where inputs are classified into predefined categories. 6 | 7 | ## Features 8 | 9 | - **Custom Categories**: Define your own categories for data classification. 10 | - **Input Validation**: Ensures the integrity of categories, examples, and 11 | input data for reliable categorization. 12 | - **Batch Categorization**: Categorizes multiple inputs simultaneously for 13 | token savings 14 | 15 | ## Attributes 16 | 17 | - `categories` (Dict[str, str]): A dictionary mapping category names to their 18 | descriptions. This structure allows for a clear definition of possible 19 | categories for classification. 20 | - `llm_provider` (LLMProvider): An instance of an LLM provider to be used for 21 | categorization. 22 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and 23 | their corresponding categories to improve categorization accuracy. 24 | - `strict` (bool): If True, raises an error when the predicted category is not 25 | one of the provided categories. 
26 | 
27 | ## Computed Fields
28 | 
29 | - `system_message` (str): A system message used for single input
30 |   categorization based on the provided categories and examples.
31 | - `system_message_batch` (str): A system message used for batch input categorization based on the provided categories and examples.
32 | - `category_mapping` (Dict[int, str]): Mapping of category index to category
33 |   name
34 | - `inverse_category_mapping` (Dict[str, int]): Mapping of category name to
35 |   index
36 | 
37 | ## Methods
38 | 
39 | ### `categorize`
40 | 
41 | Categorizes the input data using the specified LLM provider.
42 | 
43 | #### Arguments
44 | 
45 | - `input_data` (str): The text data to be categorized.
46 | 
47 | #### Returns
48 | 
49 | - `str`: The predicted category for the input data.
50 | 
51 | #### Raises
52 | 
53 | - `ValueError`: If the predicted category is not one of the provided
54 |   categories.
55 | 
56 | ### `categorize_batch`
57 | 
58 | Categorizes a batch of input data using the specified LLM provider. For less
59 | advanced LLMs, call this method on batches of 3-5 inputs (depending on the
60 | length of the input data).
61 | 
62 | #### Arguments
63 | 
64 | - `input_data` (List[str]): A list of text data to be categorized.
65 | 
66 | #### Returns
67 | 
68 | - `List[str]`: A list of predicted categories for the input data.
69 | 
70 | #### Raises
71 | 
72 | - `ValueError`: If the predicted categories are not a subset of the provided
73 |   categories or if the number of predicted categories does not match the
74 |   number of input data.
75 | 
76 | ## Usage
77 | 
78 | Setup the LLM provider and categories (as a dictionary):
79 | 
80 | ```python
81 | from databonsai.categorize import BaseCategorizer
82 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
83 | 
84 | provider = OpenAIProvider() # Or AnthropicProvider()
85 | categories = {
86 |     "Weather": "Insights and remarks about weather conditions.",
87 |     "Sports": "Observations and comments on sports events.",
88 |     "Celebrities": "Celebrity sightings and gossip",
89 |     "Others": "Comments do not fit into any of the above categories",
90 |     "Anomaly": "Data that does not look like comments or natural language",
91 | }
92 | ```
93 | 
94 | Categorize your data:
95 | 
96 | ```python
97 | categorizer = BaseCategorizer(
98 |     categories=categories,
99 |     llm_provider=provider,
100 |     # strict=False, # Default True; set to False to allow categories not in the provided dict
101 | )
102 | category = categorizer.categorize("It's been raining outside all day")
103 | print(category)
104 | ```
105 | 
106 | Output:
107 | 
108 | ```python
109 | Weather
110 | ```
111 | 
112 | Categorize a few inputs:
113 | 
114 | ```python
115 | categories = categorizer.categorize_batch([
116 |     "Storm Delays Government Budget Meeting, Weather and Politics Clash",
117 |     "Olympic Star's Controversial Tweets Ignite Political Debate, Sports Meets Politics",
118 |     "Local Football Hero Opens New Gym, Sports and Business Combine"])
119 | print(categories)
120 | ```
121 | 
122 | Output:
123 | 
124 | ```python
125 | ['Weather', 'Sports', 'Celebrities']
126 | ```
127 | 
128 | Categorize a list of inputs (Use shorter lists for weaker LLMs):
129 | 
130 | ```python
131 | headlines = [
132 |     "Massive Blizzard Hits the Northeast, Thousands Without Power",
133 |     "Local High School Basketball Team Wins State Championship After Dramatic Final",
134 |     "Celebrated Actor Launches New Environmental Awareness Campaign",
135 |     "President Announces Comprehensive Plan to Combat Cybersecurity Threats",
136 |     "Tech Giant Unveils Revolutionary Quantum Computer",
137 |     "Tropical
Storm Alina Strengthens to Hurricane as It Approaches the Coast",
138 |     "Olympic Gold Medalist Announces Retirement, Plans Coaching Career",
139 |     "Film Industry Legends Team Up for Blockbuster Biopic",
140 |     "Government Proposes Sweeping Reforms in Public Health Sector",
141 |     "Startup Develops App That Predicts Traffic Patterns Using AI",
142 | ]
143 | categories = categorizer.categorize_batch(headlines)
144 | print(categories)
145 | ```
146 | 
147 | Output:
148 | 
149 | ```python
150 | ['Weather', 'Sports', 'Celebrities', 'Others', 'Others', 'Weather', 'Sports', 'Celebrities', 'Others', 'Others']
151 | ```
152 | 
153 | Categorize a long list of inputs, or a dataframe column:
154 | 
155 | ```python
156 | from databonsai.utils import apply_to_column_batch, apply_to_column
157 | 
158 | categories = []
159 | success_idx = apply_to_column_batch(headlines, categories, categorizer.categorize_batch, 3, 0)
160 | ```
161 | 
162 | Without batching:
163 | 
164 | ```python
165 | categories = []
166 | success_idx = apply_to_column(headlines, categories, categorizer.categorize)
167 | ```
--------------------------------------------------------------------------------
/docs/BaseTransformer.md:
--------------------------------------------------------------------------------
1 | # BaseTransformer
2 | 
3 | The `BaseTransformer` class is a base class for transforming input data using a
4 | specified LLM provider. It provides a foundation for implementing data
5 | transformation tasks using language models.
6 | 
7 | ## Features
8 | 
9 | - **Transformation Prompts**: Define your own prompts to guide the
10 |   transformation process.
11 | - **Input Validation**: Ensures the integrity of prompts, examples, and input
12 |   data for reliable transformation.
13 | - **Few-Shot Learning**: Supports providing example inputs and responses to
14 |   improve transformation accuracy.
15 | - **Batch Transformation**: Transforms multiple inputs simultaneously for
16 |   token savings.
17 | 
18 | ## Attributes
19 | 
20 | - `prompt` (str): The prompt used to guide the transformation process. It
21 |   provides instructions or context for the LLM provider to perform the desired
22 |   transformation.
23 | - `llm_provider` (LLMProvider): An instance of an LLM provider to be used for
24 |   transformation. The LLM provider is responsible for generating the
25 |   transformed data based on the input data and the provided prompt.
26 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and
27 |   their corresponding transformed outputs to improve transformation accuracy.
28 | 
29 | ## Computed Fields
30 | 
31 | - `system_message` (str): A system message used for single input
32 |   transformation based on the provided prompt and examples.
33 | - `system_message_batch` (str): A system message used for batch input
34 |   transformation based on the provided prompt and examples.
35 | 
36 | ## Methods
37 | 
38 | ### `transform`
39 | 
40 | Transforms the input data using the specified LLM provider.
41 | 
42 | #### Arguments
43 | 
44 | - `input_data` (str): The text data to be transformed.
45 | - `max_tokens` (int, optional): The maximum number of tokens to generate in
46 |   the response. Defaults to 1000.
47 | - `json` (bool): Whether to turn on JSON mode. Only works with OpenAIProvider for now.
48 | 
49 | #### Returns
50 | 
51 | - `str`: The transformed data.
52 | 
53 | ### `transform_batch`
54 | 
55 | Transforms a batch of input data using the specified LLM provider.
56 | 
57 | #### Arguments
58 | 
59 | - `input_data` (List[str]): A list of text data to be transformed.
60 | - `max_tokens` (int, optional): The maximum number of tokens to generate in
61 |   each response. Defaults to 1000.
62 | 
63 | #### Returns
64 | 
65 | - `List[str]`: A list of transformed data, where each element corresponds to
66 |   the transformed version of the respective input data.
67 | 
68 | #### Raises
69 | 
70 | - `ValueError`: If the length of the output list does not match the length of
71 |   the input list.
72 | 
73 | ## Usage
74 | 
75 | Prepare the transformer:
76 | 
77 | ```python
78 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
79 | from databonsai.transform import BaseTransformer
80 | 
81 | pii_remover = BaseTransformer(
82 |     prompt="Replace any Personal Identity Identifiers (PII) in the given text with <PII>. PII includes any information that can be used to identify an individual, such as names, addresses, phone numbers, email addresses, social security numbers, etc.",
83 |     llm_provider=AnthropicProvider(),
84 |     examples=[
85 |         {
86 |             "example": "My name is John Doe and my phone number is (555) 123-4567.",
87 |             "response": "My name is <PII> and my phone number is <PII>.",
88 |         },
89 |         {
90 |             "example": "My email address is johndoe@gmail.com.",
91 |             "response": "My email address is <PII>.",
92 |         },
93 |     ],
94 | )
95 | ```
96 | 
97 | Run the transformation:
98 | 
99 | ```python
100 | print(
101 |     pii_remover.transform(
102 |         "John Doe, residing at 1234 Maple Street, Anytown, CA, 90210, recently contacted customer support to report an issue. He provided his phone number, (555) 123-4567, and email address, johndoe@email.com, for follow-up communication."
103 |     )
104 | )
105 | ```
106 | 
107 | Output:
108 | 
109 | ```python
110 | <PII>, residing at
<PII>, <PII>, <PII>, <PII>, recently contacted customer support to report an issue. They provided their phone number, <PII>, and email address, <PII>, for follow-up communication.
111 | ```
112 | 
113 | Transform a list of data (use apply_to_column_batch for large datasets):
114 | 
115 | ```python
116 | pii_texts = [
117 |     "Just confirmed the reservation for Amanda Clark, Passport No. B2345678, traveling to Tokyo on May 15, 2024.",
118 |     "Received payment from Michael Thompson, Credit Card ending in 4547, Billing Address: 45 Westview Lane, Springfield, IL.",
119 |     "Application received from Julia Martinez, DOB 03/19/1994, SSN 210-98-7654, applying for the marketing position.",
120 |     "Lease agreement finalized for Henry Wilson, Tenant ID: WILH12345, property at 89 Riverside Drive, Brooklyn, NY.",
121 |     "Registration details for Lucy Davis, Student ID 20231004, enrolled in Advanced Chemistry, Fall semester.",
122 |     "David Lee called about his insurance claim, Policy #9988776655, regarding the accident on April 5th.",
123 |     "Booking confirmation for Sarah H. Richards, Flight AC202, Seat 14C, Frequent Flyer #GH5554321, departing June 12, 2024.",
124 |     "Kevin Brown's gym membership has been renewed, Member ID: 654321, Phone: (555) 987-6543, Email: kbrown@example.com.",
125 |     "Prescription ready for Emma Thomas, Health ID 567890123, prescribed by Dr. Susan Hill on April 10th, 2024.",
126 |     "Alice Johnson requested a copy of her employment contract, Employee No. 112233, hired on August 1st, 2023.",
127 | ]
128 | 
129 | cleaned_texts = pii_remover.transform_batch(pii_texts)
130 | print(cleaned_texts)
131 | ```
132 | 
133 | Output:
134 | 
135 | ```python
136 | ['Just confirmed the reservation for <PII>, Passport No. <PII>, traveling to Tokyo on May 15, 2024.', 'Received payment from <PII>, Credit Card ending in <PII>, Billing Address: 45 Westview Lane, Springfield, IL.', 'Application received from <PII>, DOB 03/19/1994, SSN <PII>, applying for the marketing position.', 'Lease agreement finalized for <PII>, Tenant ID: WILH12345, property at 89 Riverside Drive, Brooklyn, NY.', 'Registration details for <PII>, Student ID 20231004, enrolled in Advanced Chemistry, Fall semester.', '<PII> called about his insurance claim, Policy #9988776655, regarding the accident on April 5th.', 'Booking confirmation for <PII>, Flight AC202, Seat 14C, Frequent Flyer #GH5554321, departing June 12, 2024.', "Kevin Brown's gym membership has been renewed, Member ID: 654321, Phone: (555) 987-6543, Email: kbrown@example.com.", 'Prescription ready for <PII>, Health ID 567890123, prescribed by Dr. Susan Hill on April 10th, 2024.', 'Alice Johnson requested a copy of her employment contract, Employee No.
112233, hired on August 1st, 2023.']
137 | ```
138 | 
139 | Transform a long list of data or a dataframe column (with batching):
140 | 
141 | ```python
142 | from databonsai.utils import apply_to_column_batch, apply_to_column
143 | 
144 | cleaned_texts = []
145 | success_idx = apply_to_column_batch(
146 |     pii_texts, cleaned_texts, pii_remover.transform_batch, 4, 0
147 | )
148 | ```
149 | 
150 | Without batching:
151 | 
152 | ```python
153 | cleaned_texts = []
154 | success_idx = apply_to_column(
155 |     pii_texts, cleaned_texts, pii_remover.transform
156 | )
157 | ```
--------------------------------------------------------------------------------
/docs/ExtractTransformer.md:
--------------------------------------------------------------------------------
1 | # ExtractTransformer
2 | 
3 | The `ExtractTransformer` class extends the `BaseTransformer` class and overrides
4 | the `transform` method to extract a given schema from the input data into a list
5 | of dictionaries. It allows for transforming input data into a structured format
6 | according to a specified schema.
7 | 
8 | ## Features
9 | 
10 | - **Custom Output Schema**: Define your own output schema to structure the
11 |   transformed data.
12 | - **Input Validation**: Ensures the integrity of the output schema, examples,
13 |   and input data for reliable transformation.
14 | - **Few-Shot Learning**: Supports providing example inputs and responses to
15 |   improve transformation accuracy.
16 | 
17 | ## Attributes
18 | 
19 | - `output_schema` (Dict[str, str]): A dictionary representing the schema of
20 |   the output dictionaries. It defines the expected keys and their
21 |   corresponding value types in the transformed data.
22 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and
23 |   their corresponding extracted outputs.
24 | 
25 | ## Computed Fields
26 | 
27 | - `system_message` (str): A system message used for single input
28 |   transformation based on the provided prompt, output schema, and examples.
29 | 
30 | ## Methods
31 | 
32 | ### `transform`
33 | 
34 | Transforms the input data into a list of dictionaries using the specified LLM
35 | provider.
36 | 
37 | #### Arguments
38 | 
39 | - `input_data` (str): The text data to be transformed.
40 | - `max_tokens` (int, optional): The maximum number of tokens to generate in
41 |   the response. Defaults to 1000.
42 | - `json` (bool): Whether to turn on JSON mode. Only works with OpenAIProvider for now.
43 | 
44 | #### Returns
45 | 
46 | - `List[Dict[str, str]]`: The transformed data as a list of dictionaries.
47 | 
48 | #### Raises
49 | 
50 | - `ValueError`: If the transformed data does not match the expected format or
51 |   schema.
52 | 
53 | ## Usage
54 | 
55 | Prepare an ExtractTransformer with a prompt and output schema:
56 | 
57 | ```python
58 | from databonsai.llm_providers import OpenAIProvider
59 | from databonsai.transform import ExtractTransformer
60 | 
61 | output_schema = {
62 |     "question": "generated question about given information",
63 |     "answer": "answer to the question, only using information from the given data",
64 | }
65 | 
66 | qna = ExtractTransformer(
67 |     prompt="Your goal is to create a set of questions and answers to help a person memorise every single detail of a document.",
68 |     output_schema=output_schema,
69 |     llm_provider=OpenAIProvider(),
70 |     examples=[
71 |         {
72 |             "example": "Bananas are naturally radioactive due to their potassium content.
They contain potassium-40, a radioactive isotope of potassium, which contributes to a tiny amount of radiation in every banana.", 73 | "response": str( 74 | [ 75 | { 76 | "question": "Why are bananas naturally radioactive?", 77 | "answer": "Bananas are naturally radioactive due to their potassium content.", 78 | }, 79 | { 80 | "question": "What is the radioactive isotope of potassium in bananas?", 81 | "answer": "The radioactive isotope of potassium in bananas is potassium-40.", 82 | }, 83 | ] 84 | ), 85 | } 86 | ], 87 | ) 88 | ``` 89 | 90 | Here's the text we want to extract questions and answers from: 91 | 92 | ```python 93 | text = """ Sky-gazers across North America are in for a treat on April 8 when a total solar eclipse will pass over Mexico, the United States and Canada. 94 | 95 | The event will be visible to millions — including 32 million people in the US alone — who live along the route the moon’s shadow will travel during the eclipse, known as the path of totality. For those in the areas experiencing totality, the moon will appear to completely cover the sun. Those along the very center line of the path will see an eclipse that lasts between 3½ and 4 minutes, according to NASA. 96 | 97 | The next total solar eclipse won’t be visible across the contiguous United States again until August 2044. (It’s been nearly seven years since the “Great American Eclipse” of 2017.) And an annular eclipse won’t appear across this part of the world again until 2046.""" 98 | 99 | print(qna.transform(text)) 100 | ``` 101 | 102 | Output: 103 | 104 | ```python 105 | [ 106 | { 107 | "question": "When will the total solar eclipse pass over Mexico, the United States, and Canada?", 108 | "answer": "The total solar eclipse will pass over Mexico, the United States, and Canada on April 8.", 109 | }, 110 | { 111 | "question": "What is the path of totality?", 112 | "answer": "The path of totality is the route the moon's shadow will travel during the eclipse where the moon will appear to completely cover the sun.", 113 | }, 114 | { 115 | "question": "How long will the eclipse last for those along the very center line of the path of totality?", 116 | "answer": "For those along the very center line of the path of totality, the eclipse will last between 3½ and 4 minutes.", 117 | }, 118 | { 119 | "question": "When will the next total solar eclipse be visible across the contiguous United States?", 120 | "answer": "The next total solar eclipse visible across the contiguous United States will be in August 2044.", 121 | }, 122 | { 123 | "question": "When will an annular eclipse next appear across the contiguous United States?", 124 | "answer": "An annular eclipse won't appear across the contiguous United States again until 2046.", 125 | }, 126 | ] 127 | ``` 128 | 129 | Batching is not supported for ExtractTransformer yet. 130 | -------------------------------------------------------------------------------- /docs/MultiCategorizer.md: -------------------------------------------------------------------------------- 1 | # MultiCategorizer 2 | 3 | The `MultiCategorizer` class is an extension of the `BaseCategorizer` class, 4 | providing functionality for categorizing input data into multiple categories 5 | using a specified LLM provider. This class overrides the `categorize` and 6 | `categorize_batch` methods to enable the prediction of multiple categories for a 7 | given input. 8 | 9 | ## Features 10 | 11 | - **Multiple Category Prediction**: Categorize input data into one or more 12 | predefined categories. 
13 | - **Subset Validation**: Ensures that the predicted categories are a subset of 14 | the provided categories. 15 | - **Batch Categorization**: Categorizes multiple inputs simultaneously for 16 | token savings. 17 | 18 | ## Attributes 19 | 20 | - `categories` (Dict[str, str]): A dictionary mapping category names to their 21 | descriptions. This structure allows for a clear definition of possible 22 | categories for classification. 23 | - `llm_provider` (LLMProvider): An instance of an LLM provider to be used for 24 | categorization. 25 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and 26 | their corresponding categories to improve categorization accuracy. If there 27 | are multiple categories for an example, they should be separated by commas. 28 | - `strict` (bool): If True, raises an error when the predicted category is not 29 | one of the provided categories. 30 | 31 | ## Computed Fields 32 | 33 | - `system_message` (str): A system message used for single input 34 | categorization based on the provided categories and examples. 35 | - `system_message_batch` (str): A system message used for batch input 36 | categorization based on the provided categories and examples. 37 | - `category_mapping` (Dict[int, str]): Mapping of category index to category 38 | name 39 | - `inverse_category_mapping` (Dict[str, int]): Mapping of category name to 40 | index 41 | 42 | ## Methods 43 | 44 | ### `categorize` 45 | 46 | Categorizes the input data into multiple categories using the specified LLM 47 | provider. 48 | 49 | #### Arguments 50 | 51 | - `input_data` (str): The text data to be categorized. 52 | 53 | #### Returns 54 | 55 | - `str`: A string of categories for the input data, separated by commas. 56 | 57 | #### Raises 58 | 59 | - `ValueError`: If the predicted categories are not a subset of the provided 60 | categories. 61 | 62 | ### `categorize_batch` 63 | 64 | Categorizes a batch of input data into multiple categories using the specified 65 | LLM provider. For less advanced LLMs, call this method on batches of 3-5 inputs 66 | (depending on the length of the input data). 67 | 68 | #### Arguments 69 | 70 | - `input_data` (List[str]): A list of text data to be categorized. 71 | 72 | #### Returns 73 | 74 | - `List[str]`: A list of predicted categories for each input data. If there 75 | are multiple categories for an input, they will be separated by commas. 76 | 77 | #### Raises 78 | 79 | - `ValueError`: If the predicted categories are not a subset of the provided 80 | categories or if the number of predicted category sets does not match the 81 | number of input data. 
82 | 
83 | ## Usage
84 | 
85 | Setup the LLM provider and categories (as a dictionary):
86 | 
87 | ```python
88 | from databonsai.categorize import MultiCategorizer
89 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
90 | 
91 | provider = OpenAIProvider() # Or AnthropicProvider()
92 | categories = {
93 |     "Weather": "Insights and remarks about weather conditions.",
94 |     "Sports": "Observations and comments on sports events.",
95 |     "Celebrities": "Celebrity sightings and gossip",
96 |     "Others": "Comments do not fit into any of the above categories",
97 |     "Anomaly": "Data that does not look like comments or natural language",
98 | }
99 | 
100 | tagger = MultiCategorizer(
101 |     categories=categories,
102 |     llm_provider=provider,
103 |     examples=[
104 |         {
105 |             "example": "Big stormy skies over city causes league football game to be cancelled",
106 |             "response": "Weather,Sports",
107 |         },
108 |         {
109 |             "example": "Elon musk likes to play golf",
110 |             "response": "Sports,Celebrities",
111 |         },
112 |     ],
113 |     # strict=False, # Default True; set to False to allow categories not in the provided dict
114 | )
115 | categories = tagger.categorize(
116 |     "It's been raining outside all day, and I saw Elon Musk. 13rewfdsacw10289u(#!*@)"
117 | )
118 | print(categories)
119 | ```
120 | 
121 | Output:
122 | 
123 | ```python
124 | Weather,Celebrities,Anomaly
125 | ```
126 | 
127 | Categorize a list of data:
128 | 
129 | ```python
130 | mixed_headlines = [
131 |     "Storm Delays Government Budget Meeting, Weather and Politics Clash",
132 |     "Olympic Star's Controversial Tweets Ignite Political Debate, Sports Meets Politics",
133 |     "Local Football Hero Opens New Gym, Sports and Business Combine",
134 |     "Tech CEO's Groundbreaking Climate Initiative, Technology and Environment at Forefront",
135 |     "Celebrity Chef Fights for Food Security Legislation, Culinary Meets Politics",
136 |     "Hollywood Biopic of Legendary Athlete Set to Premiere, Blending Sports and Cinema",
137 |     "Massive Flooding Disrupts Local Elections, Intersection of Weather and Politics",
138 |     "Tech Billionaire Invests in Sports Teams, Merging Business with Athletics",
139 |     "Pop Star's Concert Raises Funds for Disaster Relief, Combining Music with Charity",
140 |     "Film Festival Highlights Environmental Documentaries, Merging Cinema and Green Activism",
141 | ]
142 | categories = tagger.categorize_batch(mixed_headlines)
143 | ```
144 | 
145 | Output:
146 | 
147 | ```python
148 | ['Weather,Others', 'Sports,Others', 'Sports,Others', 'Others', 'Others,Celebrities', 'Sports,Celebrities', 'Weather,Others', 'Others,Sports', 'Celebrities,Others', 'Celebrities,Others']
149 | ```
150 | 
151 | Categorize a long list of data, or a dataframe column (with batching):
152 | 
153 | ```python
154 | from databonsai.utils import apply_to_column_batch, apply_to_column
155 | 
156 | categories = []
157 | success_idx = apply_to_column_batch(
158 |     input_column=mixed_headlines,
159 |     output_column=categories,
160 |     func=tagger.categorize_batch,
161 |     batch_size=3,
162 |     start_idx=0
163 | )
164 | ```
165 | 
166 | Without batching:
167 | 
168 | ```python
169 | categories = []
170 | success_idx = apply_to_column(
171 |     input_column=mixed_headlines,
172 |     output_column=categories,
173 |     func=tagger.categorize
174 | )
175 | ```
--------------------------------------------------------------------------------
/docs/OllamaProvider.md:
--------------------------------------------------------------------------------
1 | # OllamaProvider
2 | 
3 | The
`OllamaProvider` class is a provider class that interacts with Ollama's API
4 | for generating text completions. Note that tokens are not counted for Ollama,
5 | and there is no retry logic (since it's not needed).
6 | 
7 | ## Initialization
8 | 
9 | The `__init__` method initializes the `OllamaProvider` with an optional Ollama
10 | client or host, and generation parameters.
11 | 
12 | ### Parameters
13 | 
14 | - `model (str)`: The default model to use for text generation (default:
15 |   "llama3").
16 | - `temperature (float)`: The temperature parameter for text generation
17 |   (default: 0).
18 | - `host (str)`: The host URL for the Ollama API (optional).
19 | 
20 | ## Methods
21 | 
22 | ### `generate`
23 | 
24 | The `generate` method generates a text completion using Ollama's API, given a
25 | system prompt and a user prompt.
26 | 
27 | #### Parameters
28 | 
29 | - `system_prompt (str)`: The system prompt to provide context or instructions
30 |   for the generation.
31 | - `user_prompt (str)`: The user's prompt, based on which the text completion
32 |   is generated.
33 | - `max_tokens (int)`: The maximum number of tokens to generate in the response
34 |   (default: 1000).
35 | 
36 | #### Returns
37 | 
38 | - `str`: The generated text completion.
39 | 
40 | ## Usage
41 | 
42 | To use the default local Ollama instance:
43 | 
44 | ```python
45 | from databonsai.llm_providers import OllamaProvider
46 | 
47 | provider = OllamaProvider(model="llama3")
48 | ```
49 | 
50 | Or, provide a host:
51 | 
52 | ```python
53 | provider = OllamaProvider(host="http://localhost:11434", model="llama3")
54 | ```
55 | 
56 | This uses Ollama's Python library under the hood.
57 | 
--------------------------------------------------------------------------------
/docs/OpenAIProvider.md:
--------------------------------------------------------------------------------
1 | # OpenAIProvider
2 | 
3 | The `OpenAIProvider` class is a provider class that interacts with OpenAI's API
4 | for generating text completions. It supports exponential backoff retries (from
5 | tenacity's library) to handle temporary failures, which is particularly useful
6 | when dealing with large datasets.
7 | 
8 | ## Initialization
9 | 
10 | The `__init__` method initializes the `OpenAIProvider` with an API key and retry
11 | parameters.
12 | 
13 | ### Parameters
14 | 
15 | - `api_key (str)`: OpenAI API key.
16 | - `multiplier (int)`: The multiplier for the exponential backoff in retries
17 |   (default: 1).
18 | - `min_wait (int)`: The minimum wait time between retries (default: 1).
19 | - `max_wait (int)`: The maximum wait time between retries (default: 60).
20 | - `max_tries (int)`: The maximum number of attempts before giving up (default:
21 |   10).
22 | - `model (str)`: The default model to use for text generation (default:
23 |   "gpt-4-turbo").
24 | - `temperature (float)`: The temperature parameter for text generation
25 |   (default: 0).
26 | 
27 | ## Methods
28 | 
29 | ### `generate`
30 | 
31 | The `generate` method generates a text completion using OpenAI's API, given a
32 | system prompt and a user prompt. It is decorated with retry logic to handle
33 | temporary failures.
34 | 
35 | #### Parameters
36 | 
37 | - `system_prompt (str)`: The system prompt to provide context or instructions
38 |   for the generation.
39 | - `user_prompt (str)`: The user's prompt, based on which the text completion
40 |   is generated.
41 | - `max_tokens (int)`: The maximum number of tokens to generate in the response
42 |   (default: 1000).
43 | - `json (bool)`: Whether to use OpenAI's JSON response format (default: 44 | False). 45 | 46 | #### Returns 47 | 48 | - `str`: The generated text completion. 49 | 50 | ## Retry Decorator 51 | 52 | The `retry_with_exponential_backoff` decorator is used to apply retry logic with 53 | exponential backoff to instance methods. It captures the `self` context to 54 | access instance attributes for retry configuration. 55 | 56 | ## Usage 57 | 58 | If your OPENAI_API_KEY is defined in .env: 59 | 60 | ```python 61 | from databonsai.llm_providers import OpenAIProvider 62 | 63 | provider = OpenAIProvider() 64 | 65 | ``` 66 | 67 | Or, provide the api key as an argument: 68 | 69 | ```python 70 | provider = OpenAIProvider(api_key="your_openai_api_key") 71 | ``` 72 | 73 | Other parameters, for example: 74 | 75 | ```python 76 | provider = OpenAIProvider(model="gpt-4-turbo", max_tries=5, max_wait=120) 77 | ``` 78 | -------------------------------------------------------------------------------- /docs/Utils.md: -------------------------------------------------------------------------------- 1 | ## Util Methods 2 | 3 | ### `apply_to_column` 4 | 5 | Applies a function to each value in a column of a DataFrame or a normal Python 6 | list, starting from a specified index. 7 | 8 | #### Arguments 9 | 10 | - `input_column` (List): The column of the DataFrame or a normal Python list 11 | to which the function will be applied. 12 | - `output_column` (List): A list where the processed values will be stored. 13 | The function will mutate this list in-place. 14 | - `func` (Callable): The function to apply to each value in the column. It 15 | should take a single value as input and return a single value. 16 | - `start_idx` (int, optional): The index from which to start applying the 17 | function. Default is 0. 18 | 19 | #### Returns 20 | 21 | - `int`: The index of the last successfully processed value. 22 | 23 | #### Raises 24 | 25 | - `ValueError`: If the input or output column conditions are not met or if the 26 | starting index is out of bounds. 27 | 28 | ### `apply_to_column_batch` 29 | 30 | Applies a function to batches of values in a column of a DataFrame or a normal 31 | Python list, starting from a specified index. 32 | 33 | #### Arguments 34 | 35 | - `input_column` (List): The column of the DataFrame or a normal Python list 36 | to which the function will be applied. 37 | - `output_column` (List): A list where the processed values will be stored. 38 | The function will mutate this list in-place. 39 | - `func` (Callable): The batch function to apply to each batch of values in 40 | the column. It should take a list of values as input and return a list of 41 | processed values. 42 | - `batch_size` (int, optional): The size of each batch. Default is 5. 43 | - `start_idx` (int, optional): The index from which to start applying the 44 | function. Default is 0. 45 | 46 | #### Returns 47 | 48 | - `int`: The index of the last successfully processed batch. 49 | 50 | #### Raises 51 | 52 | - `ValueError`: If the input or output column conditions are not met or if the 53 | starting index or batch sizes are out of bounds. 54 | 55 | ### `apply_to_column_autobatch` 56 | 57 | Applies a function to the input column using adaptive batch processing, starting 58 | from a specified index and adjusting batch sizes based on success or failure. 59 | 60 | #### Arguments 61 | 62 | - `input_column` (List): The input column to be processed. 63 | - `output_column` (List): The list where the processed results will be stored. 
64 | - `func` (Callable): The batch function used for processing. 65 | - `max_retries` (int): The maximum number of retries for failed batches. 66 | - `max_batch_size` (int): The maximum allowed batch size. 67 | - `batch_size` (int): The initial batch size. 68 | - `ramp_factor` (float): The factor by which the batch size is increased after 69 | a successful batch. 70 | - `ramp_factor_decay` (float): The decay rate for the ramp factor after each 71 | successful batch. 72 | - `reduce_factor` (float): The factor by which the batch size is reduced after 73 | a failed batch. 74 | - `reduce_factor_decay` (float): The decay rate for the reduce factor after 75 | each failed batch. 76 | - `start_idx` (int): The index from which to start processing the input 77 | column. 78 | 79 | #### Returns 80 | 81 | - `int`: The index of the last successfully processed item in the input 82 | column. 83 | 84 | #### Raises 85 | 86 | - `ValueError`: If the input or output column conditions are not met or if 87 | processing fails despite retries. 88 | 89 | ## Usage: 90 | 91 | ### AutoBatch for Larger datasets 92 | 93 | If you have a pandas dataframe or list, use `apply_to_column_autobatch` 94 | 95 | - Batching data for LLM api calls saves tokens by not sending the prompt for 96 | every row. However, too large a batch size / complex tasks can lead to 97 | errors. Naturally, the better the LLM model, the larger the batch size you 98 | can use. 99 | 100 | - This batching is handled adaptively (i.e., it will increase the batch size 101 | if the response is valid and reduce it if it's not, with a decay factor) 102 | 103 | Other features: 104 | 105 | - progress bar 106 | - returns the last successful index so you can resume from there, in case it 107 | exceeds max_retries 108 | - modifies your output list in place, so you don't lose any progress 109 | 110 | Retry Logic: 111 | 112 | - LLM providers have retry logic built in for API related errors. This can be 113 | configured in the provider. 114 | - The retry logic in the apply_to_column_autobatch is for handling invalid 115 | responses (e.g. unexpected category, different number of outputs, etc.) 
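Before the full example, here is a minimal, LLM-free sketch of how the adaptive
schedule behaves. The `simulate_schedule` helper below is hypothetical (written
only for this illustration), but it replays the same batch-size update rules
that `apply_to_column_autobatch` uses: grow on success with a ramp factor that
decays towards 1.0, shrink on failure with a reduce factor that also decays.

```python
def simulate_schedule(outcomes, batch_size=2, max_batch_size=10,
                      ramp_factor=1.5, ramp_factor_decay=0.8,
                      reduce_factor=0.5, reduce_factor_decay=0.8):
    """Replay the autobatch batch-size updates for a scripted run (illustration only)."""
    sizes = []
    for ok in outcomes:
        sizes.append(batch_size)
        if ok:
            # Successful batch: grow, with the ramp factor decaying towards 1.0
            batch_size = min(round(batch_size * ramp_factor), max_batch_size)
            ramp_factor = max(ramp_factor * ramp_factor_decay, 1.0)
        else:
            # Failed batch: shrink (never below 1) and decay the reduce factor
            batch_size = max(round(batch_size * reduce_factor), 1)
            reduce_factor *= reduce_factor_decay
    return sizes

print(simulate_schedule([True, True, False, True, True]))
# [2, 3, 4, 2, 2] — ramps up, halves on the failure, then holds once the
# ramp factor has decayed to 1.0
```

The full end-to-end usage looks like this: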
116 | 
117 | ```python
118 | from databonsai.utils import (
119 |     apply_to_column_batch,
120 |     apply_to_column,
121 |     apply_to_column_autobatch,
122 | )
123 | import pandas as pd
124 | 
125 | from databonsai.categorize import MultiCategorizer, BaseCategorizer
126 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
127 | 
128 | provider = OpenAIProvider()
129 | categories = {
130 |     "Weather": "Insights and remarks about weather conditions.",
131 |     "Sports": "Observations and comments on sports events.",
132 |     "Politics": "Political events related to governments, nations, or geopolitical issues.",
133 |     "Celebrities": "Celebrity sightings and gossip",
134 |     "Others": "Comments do not fit into any of the above categories",
135 |     "Anomaly": "Data that does not look like comments or natural language",
136 | }
137 | few_shot_examples = [
138 |     {"example": "Big stormy skies over city", "response": "Weather"},
139 |     {"example": "The team won the championship", "response": "Sports"},
140 |     {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
141 | ]
142 | categorizer = BaseCategorizer(
143 |     categories=categories, llm_provider=provider, examples=few_shot_examples
144 | )
145 | category = categorizer.categorize("It's been raining outside all day")
146 | headlines = [
147 |     "Massive Blizzard Hits the Northeast, Thousands Without Power",
148 |     "Local High School Basketball Team Wins State Championship After Dramatic Final",
149 |     "Celebrated Actor Launches New Environmental Awareness Campaign",
150 |     "President Announces Comprehensive Plan to Combat Cybersecurity Threats",
151 |     "Tech Giant Unveils Revolutionary Quantum Computer",
152 |     "Tropical Storm Alina Strengthens to Hurricane as It Approaches the Coast",
153 |     "Olympic Gold Medalist Announces Retirement, Plans Coaching Career",
154 |     "Film Industry Legends Team Up for Blockbuster Biopic",
155 |     "Government Proposes Sweeping Reforms in Public Health Sector",
156 |     "Startup Develops App That Predicts Traffic Patterns Using AI",
157 | ]
158 | df = pd.DataFrame(headlines, columns=["Headline"])
159 | df["Category"] = None # Initialize it if it doesn't exist, as we modify it in place
160 | success_idx = apply_to_column_autobatch(
161 |     df["Headline"],
162 |     df["Category"],
163 |     categorizer.categorize_batch,
164 |     batch_size=3,
165 |     start_idx=0,
166 | )
167 | ```
168 | 
169 | There are many more options available for autobatch, such as setting a
170 | max_retries, decay factor, and more. Check the docs for more details.
171 | 
172 | If it fails midway (even after exponential backoff), you can resume from the
173 | returned index, which points at the first row that was not processed.
174 | 
175 | ```python
176 | success_idx = apply_to_column_autobatch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=10, start_idx=success_idx)
177 | ```
178 | 
179 | This also works for regular python lists.
180 | 
181 | Note that the better the LLM model, the greater the batch_size you can use
182 | (depending on the length of your inputs). If you're getting errors, reduce the
183 | batch_size, or use a better LLM model.
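Since the returned index is always a safe resume point, one possible pattern (a
sketch reusing `df`, `categorizer`, and `success_idx` from the example above,
not a built-in helper) is a small outer loop that keeps resuming until the whole
column is processed, capped at a few outer attempts:

```python
# Sketch: resume until done, assuming df/categorizer/success_idx from above.
outer_attempts = 0
while success_idx < len(df["Headline"]) and outer_attempts < 3:
    success_idx = apply_to_column_autobatch(
        df["Headline"],
        df["Category"],
        categorizer.categorize_batch,
        batch_size=3,
        start_idx=success_idx,
    )
    outer_attempts += 1
```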
184 | 185 | To use it with batching, but with a fixed batch size: 186 | 187 | ```python 188 | success_idx = apply_to_column_batch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=3, start_idx=0) 189 | ``` 190 | 191 | To use it without batching: 192 | 193 | ```python 194 | success_idx = apply_to_column( df["Headline"], df["Category"], categorizer.categorize) 195 | ``` 196 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | import os 9 | import sys 10 | import sphinx_rtd_theme 11 | 12 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 13 | html_theme_options = { 14 | "analytics_id": "UA-XXXXXXX-1", # Provided by Google Analytics 15 | "analytics_anonymize_ip": False, 16 | "logo_only": False, 17 | "display_version": True, 18 | "prev_next_buttons_location": "bottom", 19 | "style_external_links": False, 20 | "vcs_pageview_mode": "", 21 | "style_nav_header_background": "white", 22 | # Other options... 23 | } 24 | sys.path.insert(0, os.path.abspath("../../")) # Add project root 25 | sys.path.insert(0, os.path.abspath("../../databonsai")) 26 | 27 | project = "databonsai" 28 | copyright = "2024, Data Bonsai" 29 | author = "Data Bonsai" 30 | release = "0.2.0" 31 | 32 | # -- General configuration --------------------------------------------------- 33 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 34 | 35 | extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme"] 36 | 37 | templates_path = ["_templates"] 38 | exclude_patterns = [] 39 | 40 | 41 | # -- Options for HTML output ------------------------------------------------- 42 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 43 | 44 | html_theme = "sphinx_rtd_theme" 45 | html_static_path = ["_static"] 46 | -------------------------------------------------------------------------------- /docs/source/databonsai.categorize.rst: -------------------------------------------------------------------------------- 1 | databonsai.categorize package 2 | ============================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | databonsai.categorize.base\_categorizer module 8 | ---------------------------------------------- 9 | 10 | .. automodule:: databonsai.categorize.base_categorizer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | databonsai.categorize.multi\_categorizer module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: databonsai.categorize.multi_categorizer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: databonsai.categorize 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/databonsai.llm_providers.rst: -------------------------------------------------------------------------------- 1 | databonsai.llm\_providers package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | databonsai.llm\_providers.anthropic\_provider module 8 | ---------------------------------------------------- 9 | 10 | .. automodule:: databonsai.llm_providers.anthropic_provider 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | databonsai.llm\_providers.llm\_provider module 16 | ---------------------------------------------- 17 | 18 | .. automodule:: databonsai.llm_providers.llm_provider 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | databonsai.llm\_providers.openai\_provider module 24 | ------------------------------------------------- 25 | 26 | .. automodule:: databonsai.llm_providers.openai_provider 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: databonsai.llm_providers 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/databonsai.rst: -------------------------------------------------------------------------------- 1 | databonsai package 2 | ================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | databonsai.categorize 11 | databonsai.llm_providers 12 | databonsai.transform 13 | databonsai.utils 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: databonsai 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/databonsai.transform.rst: -------------------------------------------------------------------------------- 1 | databonsai.transform package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | databonsai.transform.base\_transformer module 8 | --------------------------------------------- 9 | 10 | .. automodule:: databonsai.transform.base_transformer 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | databonsai.transform.decompose\_transformer module 16 | -------------------------------------------------- 17 | 18 | .. automodule:: databonsai.transform.decompose_transformer 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: databonsai.transform 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/databonsai.utils.rst: -------------------------------------------------------------------------------- 1 | databonsai.utils package 2 | ======================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: databonsai.utils 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. databonsai documentation master file, created by 2 | sphinx-quickstart on Tue Apr 9 16:43:03 2024. 
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 | 
6 | Welcome to databonsai's documentation!
7 | ======================================
8 | 
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 | 
13 | databonsai
14 | 
15 | .. Indices and tables
16 | .. ==================
17 | 
18 | .. * :ref:`genindex`
19 | .. * :ref:`modindex`
20 | .. * :ref:`search`
21 | 
--------------------------------------------------------------------------------
/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | databonsai
2 | ==========
3 | 
4 | .. toctree::
5 | :maxdepth: 4
6 | 
7 | databonsai
8 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "databonsai"
3 | version = "0.8.0"
4 | description = "A Python package to clean and curate your data with LLMs"
5 | authors = ["Alvin Ryanputra <databonsai.ai@gmail.com>"]
6 | 
7 | [tool.poetry.dependencies]
8 | python = "^3.8"
9 | 
10 | openai = "^1.16.2"
11 | anthropic = "^0.23.1"
12 | tenacity = "^8.2.3"
13 | python-dotenv = "^1.0.1"
14 | pydantic = "^2.6.4"
15 | pydantic_core = "^2.16.3"
16 | ollama = "^0.1.0"
17 | 
18 | [tool.poetry.dev-dependencies]
19 | # Add development dependencies here (if any)
20 | pytest = "^7.2.2"
21 | 
22 | [build-system]
23 | requires = ["poetry-core"]
24 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 | name="databonsai",
5 | version="0.8.0",
6 | description="A package for cleaning and curating data with LLMs",
7 | long_description=open("README.md", encoding="utf-8").read(),
8 | long_description_content_type="text/markdown",
9 | author="Alvin Ryanputra",
10 | author_email="databonsai.ai@gmail.com",
11 | url="https://github.com/databonsai/databonsai",
12 | packages=find_packages(),
13 | install_requires=[
14 | "openai",
15 | "anthropic",
16 | "tenacity",
17 | "python-dotenv",
18 | "pydantic",
19 | "ollama",
20 | ],
21 | classifiers=[
22 | "Development Status :: 3 - Alpha",
23 | "Intended Audience :: Developers",
24 | "License :: OSI Approved :: MIT License",
25 | "Programming Language :: Python :: 3.8",
26 | "Programming Language :: Python :: 3.9",
27 | "Programming Language :: Python :: 3.10",
28 | "Programming Language :: Python :: 3.11",
29 | "Programming Language :: Python :: 3.12",
30 | ],
31 | )
32 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alvin-r/databonsai/3f2b7c58d0aa172251b5c3c321dea93038853e32/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_categorization.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from databonsai.categorize import BaseCategorizer, MultiCategorizer
3 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
4 | from databonsai.utils import apply_to_column, apply_to_column_batch
5 | import pandas as pd
6 | 
7 | 
8 | @pytest.fixture
9 | def sample_categories():
10 | return {
11 | "Weather":
"Insights and remarks about weather conditions.", 12 | "Sports": "Observations and comments on sports events.", 13 | "Celebrities": "Celebrity sightings and gossip", 14 | "Others": "Comments do not fit into any of the above categories", 15 | "Anomaly": "Data that does not look like comments or natural language", 16 | } 17 | 18 | 19 | @pytest.fixture( 20 | params=[ 21 | OpenAIProvider(model="gpt-4-turbo", max_tries=1), 22 | AnthropicProvider(max_tries=1), 23 | ] 24 | ) 25 | def sample_provider(request): 26 | return request.param 27 | 28 | 29 | @pytest.fixture 30 | def sample_dataframe(): 31 | return pd.DataFrame( 32 | { 33 | "text": [ 34 | "Massive Blizzard Hits the Northeast, Thousands Without Power", 35 | "Local High School Basketball Team Wins State Championship After Dramatic Final", 36 | "Celebrated Actor Launches New Environmental Awareness Campaign", 37 | "Startup Develops App That Predicts Traffic Patterns Using AI", 38 | "asdfoinasedf'awesdf", 39 | ] 40 | } 41 | ) 42 | 43 | 44 | @pytest.fixture 45 | def sample_list(): 46 | return [ 47 | "Massive Blizzard Hits the Northeast, Thousands Without Power", 48 | "Local High School Basketball Team Wins State Championship After Dramatic Final", 49 | "Celebrated Actor Launches New Environmental Awareness Campaign", 50 | "Startup Develops App That Predicts Traffic Patterns Using AI", 51 | "asdfoinasedf'awesdf", 52 | ] 53 | 54 | 55 | def test_base_categorizer(sample_categories, sample_provider): 56 | """ 57 | Test the BaseCategorizer class. 58 | """ 59 | categorizer = BaseCategorizer( 60 | categories=sample_categories, 61 | llm_provider=sample_provider, 62 | examples=[ 63 | {"example": "Big stormy skies over city", "response": "Weather"}, 64 | {"example": "The team won the championship", "response": "Sports"}, 65 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"}, 66 | ], 67 | ) 68 | 69 | assert categorizer.categorize("It's raining heavily today.") == "Weather" 70 | assert categorizer.categorize("The football match was exciting!") == "Sports" 71 | assert categorizer.categorize("I saw Emma Watson at the mall.") == "Celebrities" 72 | assert categorizer.categorize("This is a random comment.") == "Others" 73 | assert categorizer.categorize("1234567890!@#$%^&*()") == "Anomaly" 74 | 75 | 76 | def test_base_categorizer_batch(sample_categories, sample_provider, sample_list): 77 | """ 78 | Test the BaseCategorizer class with a batch of examples. 79 | """ 80 | categorizer = BaseCategorizer( 81 | categories=sample_categories, 82 | llm_provider=sample_provider, 83 | examples=[ 84 | {"example": "Big stormy skies over city", "response": "Weather"}, 85 | {"example": "The team won the championship", "response": "Sports"}, 86 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"}, 87 | ], 88 | ) 89 | print(categorizer.system_message_batch) 90 | assert categorizer.categorize_batch(sample_list) == [ 91 | "Weather", 92 | "Sports", 93 | "Celebrities", 94 | "Others", 95 | "Anomaly", 96 | ] 97 | 98 | 99 | def test_apply_to_column(sample_categories, sample_provider, sample_dataframe): 100 | """ 101 | Test the apply_to_column function. 
102 | """ 103 | categorizer = BaseCategorizer( 104 | categories=sample_categories, 105 | llm_provider=sample_provider, 106 | examples=[ 107 | {"example": "Big stormy skies over city", "response": "Weather"}, 108 | {"example": "The team won the championship", "response": "Sports"}, 109 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"}, 110 | ], 111 | ) 112 | 113 | df = sample_dataframe.copy() 114 | df["category"] = None 115 | 116 | success_idx = apply_to_column(df["text"], df["category"], categorizer.categorize) 117 | 118 | assert success_idx == 5 119 | assert df["category"].tolist() == [ 120 | "Weather", 121 | "Sports", 122 | "Celebrities", 123 | "Others", 124 | "Anomaly", 125 | ] 126 | 127 | 128 | def test_apply_to_column_batch(sample_categories, sample_provider, sample_dataframe): 129 | """ 130 | Test the apply_to_column_batch function. 131 | """ 132 | categorizer = BaseCategorizer( 133 | categories=sample_categories, 134 | llm_provider=sample_provider, 135 | examples=[ 136 | {"example": "Big stormy skies over city", "response": "Weather"}, 137 | {"example": "The team won the championship", "response": "Sports"}, 138 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"}, 139 | ], 140 | ) 141 | 142 | df = sample_dataframe.copy() 143 | df["category"] = None 144 | 145 | success_idx = apply_to_column_batch( 146 | df["text"], df["category"], categorizer.categorize_batch, batch_size=2 147 | ) 148 | 149 | assert success_idx == 5 150 | assert df["category"].tolist() == [ 151 | "Weather", 152 | "Sports", 153 | "Celebrities", 154 | "Others", 155 | "Anomaly", 156 | ] 157 | 158 | 159 | def test_apply_to_column_batch_start_idx( 160 | sample_categories, sample_provider, sample_dataframe 161 | ): 162 | """ 163 | Test the apply_to_column_batch function with a start index. 164 | """ 165 | categorizer = BaseCategorizer( 166 | categories=sample_categories, 167 | llm_provider=sample_provider, 168 | examples=[ 169 | {"example": "Big stormy skies over city", "response": "Weather"}, 170 | {"example": "The team won the championship", "response": "Sports"}, 171 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"}, 172 | ], 173 | ) 174 | 175 | df = sample_dataframe.copy() 176 | df["category"] = None 177 | 178 | success_idx = apply_to_column_batch( 179 | df["text"], 180 | df["category"], 181 | categorizer.categorize_batch, 182 | batch_size=2, 183 | start_idx=1, 184 | ) 185 | 186 | assert success_idx == 5 187 | assert df["category"].tolist() == [ 188 | None, 189 | "Sports", 190 | "Celebrities", 191 | "Others", 192 | "Anomaly", 193 | ] 194 | 195 | 196 | def test_multi_categorizer(sample_categories, sample_provider): 197 | """ 198 | Test the MultiCategorizer class. 
199 | """ 200 | categorizer = MultiCategorizer( 201 | categories=sample_categories, 202 | llm_provider=sample_provider, 203 | examples=[ 204 | { 205 | "example": "Big stormy skies over city causes league football game to be cancelled", 206 | "response": "Weather,Sports", 207 | }, 208 | { 209 | "example": "Elon musk likes to play golf", 210 | "response": "Sports,Celebrities", 211 | }, 212 | ], 213 | ) 214 | 215 | assert set( 216 | categorizer.categorize("It's raining and I saw Emma Watson.").split(",") 217 | ) == { 218 | "Weather", 219 | "Celebrities", 220 | } 221 | assert set( 222 | categorizer.categorize("The football match was exciting and it's sunny!").split( 223 | "," 224 | ) 225 | ) == {"Sports", "Weather"} 226 | 227 | 228 | def test_multi_categorizer_batch(sample_categories, sample_provider): 229 | """ 230 | Test the MultiCategorizer class with a batch of examples. 231 | """ 232 | categorizer = MultiCategorizer( 233 | categories=sample_categories, 234 | llm_provider=sample_provider, 235 | examples=[ 236 | { 237 | "example": "Big stormy skies over city causes league football game to be cancelled", 238 | "response": "Weather,Sports", 239 | }, 240 | { 241 | "example": "Elon musk likes to play golf", 242 | "response": "Sports,Celebrities", 243 | }, 244 | ], 245 | ) 246 | 247 | examples = [ 248 | "Thunderstorms cause major delays in baseball tournament", 249 | "Famous actor spotted at local charity basketball game", 250 | "Heavy rainfall leads to postponement of soccer match", 251 | ] 252 | 253 | expected_output = [ 254 | ["Weather", "Sports"], 255 | ["Celebrities", "Sports"], 256 | ["Weather", "Sports"], 257 | ] 258 | 259 | actual_output = categorizer.categorize_batch(examples) 260 | 261 | # Convert the actual output to sets for order-insensitive comparison 262 | actual_output_sets = [set(categories.split(",")) for categories in actual_output] 263 | 264 | # Convert the expected output to sets for order-insensitive comparison 265 | expected_output_sets = [set(categories) for categories in expected_output] 266 | 267 | assert len(actual_output_sets) == len(expected_output_sets) 268 | for actual_set, expected_set in zip(actual_output_sets, expected_output_sets): 269 | assert actual_set == expected_set 270 | --------------------------------------------------------------------------------