├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── databonsai
├── CONTRIBUTING.MD
├── __init__.py
├── categorize
│ ├── __init__.py
│ ├── base_categorizer.py
│ └── multi_categorizer.py
├── examples
│ └── categorize_news.ipynb
├── llm_providers
│ ├── __init__.py
│ ├── anthropic_provider.py
│ ├── llm_provider.py
│ ├── ollama_provider.py
│ └── openai_provider.py
├── transform
│ ├── __init__.py
│ ├── base_transformer.py
│ └── extract_transformer.py
└── utils
│ ├── __init__.py
│ ├── apply.py
│ └── logs.py
├── docs
├── AnthropicProvider.md
├── BaseCategorizer.md
├── BaseTransformer.md
├── ExtractTransformer.md
├── MultiCategorizer.md
├── OllamaProvider.md
├── OpenAIProvider.md
├── Utils.md
└── source
│ ├── conf.py
│ ├── databonsai.categorize.rst
│ ├── databonsai.llm_providers.rst
│ ├── databonsai.rst
│ ├── databonsai.transform.rst
│ ├── databonsai.utils.rst
│ ├── index.rst
│ └── modules.rst
├── pyproject.toml
├── setup.py
└── tests
├── __init__.py
└── test_categorization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | .*pyc
3 | __pycache__/
4 | venv/
5 | dist/
6 | databonsai.egg-info/
7 | build/
8 | docs/source/
9 | docs/build/
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | version: 2
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.11"
7 |
8 | sphinx:
9 | configuration: docs/source/conf.py
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Alvin Ryanputra
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # databonsai
2 |
3 | [](https://badge.fury.io/py/databonsai)
4 | [](https://opensource.org/licenses/MIT)
5 | [](https://pypi.org/project/databonsai/)
6 | [](https://github.com/psf/black)
7 |
8 | ## Clean & curate your data with LLMs
9 |
10 | databonsai is a Python library that uses LLMs to perform data cleaning tasks.
11 |
12 | ## Features
13 |
14 | - Suite of tools for data processing using LLMs including categorization,
15 | transformation, and extraction
16 | - Validation of LLM outputs
17 | - Batch processing for token savings
18 | - Retry logic with exponential backoff for handling rate limits and transient
19 | errors
20 |
21 | ## Installation
22 |
23 | ```bash
24 | pip install databonsai
25 | ```
26 |
27 | Store your API keys on an .env file in the root of your project, or specify it
28 | as an argument when initializing the provider.
29 |
30 | ```bash
31 | OPENAI_API_KEY=xxx # if you use OpenAIProvider
32 | ANTHROPIC_API_KEY=xxx # If you use AnthropicProvider
33 | ```
34 |
35 | ## Quickstart
36 |
37 | ### Categorization
38 |
39 | Set up the LLM provider and categories (as a dictionary).
40 |
41 | ```python
42 | from databonsai.categorize import MultiCategorizer, BaseCategorizer
43 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
44 |
45 | provider = OpenAIProvider() # Or AnthropicProvider(). Highly recommend using Haiku, which is the default AnthropicProvider() model, as it is cheap and effective for these tasks
46 | categories = {
47 | "Weather": "Insights and remarks about weather conditions.",
48 | "Sports": "Observations and comments on sports events.",
49 | "Politics": "Political events related to governments, nations, or geopolitical issues.",
50 | "Celebrities": "Celebrity sightings and gossip",
51 | "Others": "Comments do not fit into any of the above categories",
52 | "Anomaly": "Data that does not look like comments or natural language",
53 | }
54 | few_shot_examples = [
55 | {"example": "Big stormy skies over city", "response": "Weather"},
56 | {"example": "The team won the championship", "response": "Sports"},
57 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
58 | ]
59 | ```
60 |
61 | Categorize your data:
62 |
63 | ```python
64 | categorizer = BaseCategorizer(
65 | categories=categories,
66 | llm_provider=provider,
67 | examples = few_shot_examples,
68 | #strict = False # Default true, set to False to allow for categories not in the provided dict
69 | )
70 | category = categorizer.categorize("It's been raining outside all day")
71 | print(category)
72 | ```
73 |
74 | Output:
75 |
76 | ```python
77 | Weather
78 | ```
79 |
80 | Use categorize_batch to categorize a batch. This saves tokens as it only sends
81 | the schema and few shot examples once! (Works best for better models. Ideally,
82 | use at least 3 few shot examples.)
83 |
84 | ```python
85 | categories = categorizer.categorize_batch([
86 | "Massive Blizzard Hits the Northeast, Thousands Without Power",
87 | "Local High School Basketball Team Wins State Championship After Dramatic Final",
88 | "Celebrated Actor Launches New Environmental Awareness Campaign",
89 | ])
90 | print(categories)
91 | ```
92 |
93 | Output:
94 |
95 | ```python
96 | ['Weather', 'Sports', 'Celebrities']
97 | ```
98 |
99 | ### AutoBatch for Larger datasets
100 |
101 | If you have a pandas dataframe or list, use `apply_to_column_autobatch`
102 |
103 | - Batching data for LLM api calls saves tokens by not sending the prompt for
104 | every row. However, too large a batch size / complex tasks can lead to
105 | errors. Naturally, the better the LLM model, the larger the batch size you
106 | can use.
107 |
108 | - This batching is handled adaptively (i.e., it will increase the batch size
109 | if the response is valid and reduce it if it's not, with a decay factor)
110 |
111 | Other features:
112 |
113 | - Progress bar
114 | - Returns the last successful index so you can resume from there, in case it
115 | exceeds max_retries
116 | - Modifies your output list in place, so you don't lose any progress
117 |
118 | Retry Logic:
119 |
120 | - LLM providers have retry logic built in for API related errors. This can be
121 | configured in the provider.
122 | - The retry logic in the apply_to_column_autobatch is for handling invalid
123 | responses (e.g. unexpected category, different number of outputs, etc.)
124 |
125 | ```python
126 | from databonsai.utils import apply_to_column_batch, apply_to_column, apply_to_column_autobatch
127 | import pandas as pd
128 |
129 | headlines = [
130 | "Massive Blizzard Hits the Northeast, Thousands Without Power",
131 | "Local High School Basketball Team Wins State Championship After Dramatic Final",
132 | "Celebrated Actor Launches New Environmental Awareness Campaign",
133 | "President Announces Comprehensive Plan to Combat Cybersecurity Threats",
134 | "Tech Giant Unveils Revolutionary Quantum Computer",
135 | "Tropical Storm Alina Strengthens to Hurricane as It Approaches the Coast",
136 | "Olympic Gold Medalist Announces Retirement, Plans Coaching Career",
137 | "Film Industry Legends Team Up for Blockbuster Biopic",
138 | "Government Proposes Sweeping Reforms in Public Health Sector",
139 | "Startup Develops App That Predicts Traffic Patterns Using AI",
140 | ]
141 | df = pd.DataFrame(headlines, columns=["Headline"])
142 | df["Category"] = None # Initialize it if it doesn't exist, as we modify it in place
143 | success_idx = apply_to_column_autobatch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=3, start_idx=0)
144 | ```
145 |
146 | There are many more options available for autobatch, such as setting a
147 | max_retries, decay factor, and more. Check [Utils](./docs/Utils.md) for more
148 | details
149 |
150 | If it fails midway (even after exponential backoff), you can resume from the
151 | last successful index + 1.
152 |
153 | ```python
154 | success_idx = apply_to_column_autobatch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=10, start_idx=success_idx+1)
155 | ```
156 |
157 | This also works for regular python lists.
158 |
159 | Note that the better the LLM model, the greater the batch_size you can use
160 | (depending on the length of your inputs). If you're getting errors, reduce the
161 | batch_size, or use a better LLM model.
162 |
163 | To use it with batching, but with a fixed batch size:
164 |
165 | ```python
166 | success_idx = apply_to_column_batch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=3, start_idx=0)
167 | ```
168 |
169 | To use it without batching:
170 |
171 | ```python
172 | success_idx = apply_to_column( df["Headline"], df["Category"], categorizer.categorize)
173 | ```
174 |
175 | ### View System Prompt
176 |
177 | ```python
178 | print(categorizer.system_message)
179 | print(categorizer.system_message_batch)
180 | ```
181 |
182 | ### View token usage
183 |
184 | Token usage is recorded for OpenAI and Anthropic. Use these to estimate your
185 | costs!
186 |
187 | ```python
188 | print(provider.input_tokens)
189 | print(provider.output_tokens)
190 | ```
191 |
192 | ## [Docs](./docs/)
193 |
194 | ### Tools (Check out the docs for usage examples and details)
195 |
196 | - [BaseCategorizer](./docs/BaseCategorizer.md) - categorize data into a
197 | category
198 | - [MultiCategorizer](./docs/MultiCategorizer.md) - categorize data into
199 | multiple categories
200 | - [BaseTransformer](./docs/BaseTransformer.md) - transform data with a prompt
201 | - [ExtractTransformer](./docs/ExtractTransformer.md) - Extract data into a
202 | structured format based on a schema
203 | - .. more coming soon!
204 |
205 | ### LLM Providers
206 |
207 | - [OpenAIProvider](./docs/OpenAIProvider.md) - OpenAI
208 | - [AnthropicProvider](./docs/AnthropicProvider.md) - Anthropic
209 | - [OllamaProvider](./docs/OllamaProvider.md) - Ollama
210 | - .. more coming soon!
211 |
212 | ### Examples
213 |
214 | - [Examples](./databonsai/examples/) (TBD)
215 |
216 | ### Acknowledgements
217 |
218 | Bonsai icon from icons8 https://icons8.com/icon/74uBtdDr5yFq/bonsai
219 |
--------------------------------------------------------------------------------
/databonsai/CONTRIBUTING.MD:
--------------------------------------------------------------------------------
1 | ## Contributing to DataBonsai
2 |
3 | - Reporting bugs and issues
4 | - Suggesting new features or enhancements
5 | - Submitting pull requests with bug fixes or new features
6 | - Improving documentation
7 | - Providing feedback and suggestions
8 |
9 | ## Getting Started
10 |
11 | To get started with contributing to DataBonsai, follow these steps:
12 |
13 | Fork the DataBonsai repository on GitHub. Clone your forked repository to your
14 | local machine. Create a new branch for your contribution:
15 |
16 | ```
17 | git checkout -b feature/your-feature-name
18 | ```
19 |
20 | - Make your changes and commit them with descriptive commit messages.
21 |
22 | - Push your changes to your forked repository.
23 | - Open a pull request against the main DataBonsai repository.
24 |
25 | ## Reporting Bugs and Issues
26 |
27 | If you encounter a bug or issue while using DataBonsai, please report it by
28 | opening a new issue on the GitHub repository, or email databonsai.ai@gmail.com
29 |
--------------------------------------------------------------------------------
/databonsai/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alvin-r/databonsai/3f2b7c58d0aa172251b5c3c321dea93038853e32/databonsai/__init__.py
--------------------------------------------------------------------------------
/databonsai/categorize/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_categorizer import BaseCategorizer
2 | from .multi_categorizer import MultiCategorizer
3 |
--------------------------------------------------------------------------------
/databonsai/categorize/base_categorizer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from pydantic import BaseModel, field_validator, model_validator, computed_field
3 | from databonsai.llm_providers import OpenAIProvider, LLMProvider
4 | from databonsai.utils.logs import logger
5 |
6 |
class BaseCategorizer(BaseModel):
    """
    A base class for categorizing input data using a specified LLM provider.

    Categories are presented to the LLM as numbered descriptions; the LLM
    replies with a number, which is mapped back to the category name.

    Attributes:
        categories (Dict[str, str]): A dictionary mapping category names to their descriptions.
        llm_provider (LLMProvider): An instance of an LLM provider to be used for categorization.
        examples (Optional[List[Dict[str, str]]]): Few-shot examples, each a dict
            with "example" (input text) and "response" (a category key) entries.
        strict (bool): If True (default), raise when the LLM predicts a category
            outside `categories`; if False, log a warning and pass the raw
            prediction through instead.
    """

    categories: Dict[str, str]
    llm_provider: LLMProvider
    examples: Optional[List[Dict[str, str]]] = []
    strict: bool = True

    class Config:
        # LLMProvider is not a pydantic model, so arbitrary types must be allowed.
        arbitrary_types_allowed = True

    @field_validator("categories")
    def validate_categories(cls, v):
        """
        Validates the categories dictionary.

        Args:
            v (Dict[str, str]): The categories dictionary to be validated.

        Raises:
            ValueError: If the categories dictionary is empty or has less than two key-value pairs.

        Returns:
            Dict[str, str]: The validated categories dictionary.
        """
        if not v:
            raise ValueError("Categories dictionary cannot be empty.")
        if len(v) < 2:
            raise ValueError(
                "Categories dictionary must have more than one key-value pair."
            )
        return v

    @field_validator("examples")
    def validate_examples(cls, v):
        """
        Validates the examples list.

        Args:
            v (List[Dict[str, str]]): The examples list to be validated.

        Raises:
            ValueError: If the examples list is not a list of dictionaries or if any
                dictionary is missing the "example" or "response" key.

        Returns:
            List[Dict[str, str]]: The validated examples list.
        """
        if not isinstance(v, list):
            raise ValueError("Examples must be a list of dictionaries.")
        for example in v:
            if not isinstance(example, dict):
                raise ValueError("Each example must be a dictionary.")
            if "example" not in example or "response" not in example:
                raise ValueError(
                    "Each example dictionary must have 'example' and 'response' keys."
                )
        return v

    @model_validator(mode="after")
    def validate_examples_responses(self):
        """
        Validates that the "response" values in the examples are within the categories keys.
        """
        if self.examples:
            category_keys = set(self.categories.keys())
            for example in self.examples:
                response = example.get("response")
                if response not in category_keys:
                    raise ValueError(
                        f"Example response '{response}' is not one of the provided categories, {str(list(self.categories.keys()))}."
                    )
        return self

    @computed_field
    @property
    def category_mapping(self) -> Dict[int, str]:
        # Number -> category name. Stable because dicts preserve insertion order.
        return {i: category for i, category in enumerate(self.categories.keys())}

    @computed_field
    @property
    def inverse_category_mapping(self) -> Dict[str, int]:
        # Category name -> number; used to encode few-shot example responses.
        return {category: i for i, category in self.category_mapping.items()}

    @computed_field
    @property
    def system_message(self) -> str:
        """System prompt for single-input categorization."""
        categories_with_numbers = "\n".join(
            [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
        )
        system_message = f"""
You are a categorization expert, specializing in classifying text data into predefined categories. You cannot use any other categories except those provided.
Each category is formatted as :
{categories_with_numbers}
Classify the given text snippet into one of the following categories:
{str(list(range(len(self.categories))))}
Do not use any other categories.
Only reply with the category number. Do not make any other conversation.
"""
        # Add in fewshot examples
        if self.examples:
            for example in self.examples:
                system_message += f"\nEXAMPLE: {example['example']} RESPONSE: {self.inverse_category_mapping[example['response']]}"
        return system_message

    @computed_field
    @property
    def system_message_batch(self) -> str:
        """System prompt for batch categorization ('||'-separated replies)."""
        categories_with_numbers = "\n".join(
            [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
        )
        system_message = f"""
You are a categorization expert, specializing in classifying text data into predefined categories. You cannot use any other categories except those provided.
Each category is formatted as :
{categories_with_numbers}
Classify each given text snippet into one of the following categories:
{str(list(range(len(self.categories))))}.
Do not use any other categories. If there are multiple snippets, separate each category number with ||.
EXAMPLE: RESPONSE:
EXAMPLE: || RESPONSE: ||
Choose one category for each text snippet.
Only reply with the category numbers. Do not make any other conversation.
"""
        # Add in fewshot examples
        if self.examples:
            system_message += "\n EXAMPLE:"
            system_message += (
                f"{'||'.join([example['example'] for example in self.examples])}"
            )
            system_message += f"\n RESPONSE: {'||'.join([str(self.inverse_category_mapping[example['response']]) for example in self.examples])}"

        return system_message

    def categorize(self, input_data: str) -> str:
        """
        Categorizes the input data using the specified LLM provider.

        Args:
            input_data (str): The text data to be categorized.

        Returns:
            str: The predicted category for the input data. In non-strict mode,
                an out-of-range prediction is returned as the raw LLM reply.

        Raises:
            ValueError: If the LLM reply cannot be parsed as an integer, or (in
                strict mode) if the predicted category is not one of the
                provided categories.
        """
        # Call the LLM provider to get the predicted category number
        response = self.llm_provider.generate(self.system_message, input_data)
        try:
            predicted_category_number = int(response.strip())
        except ValueError:
            # Surface a clearer error than int()'s default when the LLM
            # replies with something other than a bare number.
            raise ValueError(
                f"LLM response '{response}' could not be parsed as a category number."
            )

        # Validate that the predicted category number is within the valid range
        if predicted_category_number not in self.category_mapping:
            if self.strict:
                raise ValueError(
                    f"Predicted category number '{predicted_category_number}' is not one of the provided categories. Use 'strict=False' when instantiating the categorizer to allow categories not in the categories dict."
                )
            logger.warning(
                f"Predicted category number '{predicted_category_number}' is not one of the provided categories. Use 'strict=True' when instantiating the categorizer to raise an error."
            )
            # Bug fix: previously execution fell through to a KeyError lookup,
            # so non-strict mode still crashed. Return the raw reply instead.
            return response.strip()

        # Convert the category number back to the category key
        return self.category_mapping[predicted_category_number]

    def categorize_batch(self, input_data: List[str]) -> List[str]:
        """
        Categorizes a batch of input data using the specified LLM provider. For less
        advanced LLMs, call this method on batches of 3-5 inputs (depending on the
        length of the input data).

        Args:
            input_data (List[str]): A list of text data to be categorized.

        Returns:
            List[str]: A list of predicted categories for the input data.

        Raises:
            ValueError: If the reply cannot be parsed, if the number of predictions
                does not match the number of inputs, or (in strict mode) if the
                predicted categories are not a subset of the provided categories.
        """
        # If there is only one input, call the categorize method
        if len(input_data) == 1:
            return self.validate_predicted_categories(
                [self.categorize(next(iter(input_data)))]
            )

        input_data_prompt = "||".join(input_data)
        # Call the LLM provider to get the predicted category numbers
        response = self.llm_provider.generate(
            self.system_message_batch, input_data_prompt
        )
        try:
            predicted_category_numbers = [
                int(category.strip()) for category in response.split("||")
            ]
        except ValueError:
            raise ValueError(
                f"LLM response '{response}' could not be parsed as '||'-separated category numbers."
            )
        if len(predicted_category_numbers) != len(input_data):
            raise ValueError(
                f"Number of predicted categories ({len(predicted_category_numbers)}) does not match the number of input data ({len(input_data)})."
            )
        # Convert the category numbers back to category keys. Unknown numbers
        # are kept as their string form so validate_predicted_categories can
        # apply the strict/non-strict policy instead of raising a KeyError.
        predicted_categories = [
            self.category_mapping.get(number, str(number))
            for number in predicted_category_numbers
        ]
        return self.validate_predicted_categories(predicted_categories)

    def validate_predicted_categories(
        self, predicted_categories: List[str]
    ) -> List[str]:
        """Drop empty predictions and enforce the strict/non-strict policy."""
        # Filter out empty strings from the predicted categories
        filtered_categories = [
            category for category in predicted_categories if category
        ]

        # Validate each category in the filtered list
        for predicted_category in filtered_categories:
            if predicted_category not in self.categories:
                if self.strict:
                    raise ValueError(
                        f"Predicted category '{predicted_category}' is not one of the provided categories. Use 'strict=False' when instantiating the categorizer to allow categories not in the categories dict."
                    )
                else:
                    # Warn the user if the predicted category is not one of the provided categories
                    logger.warning(
                        f"Predicted category '{predicted_category}' is not one of the provided categories. Use 'strict=True' when instantiating the categorizer to raise an error."
                    )
        return filtered_categories
--------------------------------------------------------------------------------
/databonsai/categorize/multi_categorizer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | from databonsai.categorize.base_categorizer import BaseCategorizer
3 | from pydantic import model_validator, computed_field
4 |
5 |
class MultiCategorizer(BaseCategorizer):
    """
    A class for categorizing input data into multiple categories using a specified LLM provider.

    This class extends the BaseCategorizer class and overrides the categorize method to allow
    for multiple category predictions. Multiple categories in an example "response"
    are comma-separated; the LLM protocol separates categories with '||' and batch
    items with '##'.
    """

    class Config:
        # Inherited LLMProvider field is not a pydantic model.
        arbitrary_types_allowed = True

    @model_validator(mode="after")
    def validate_examples_responses(self):
        """
        Validates that the "response" values in the examples are within the categories
        keys. If there are multiple categories, they should be separated by commas.
        """
        if self.examples:
            category_keys = set(self.categories.keys())
            for example in self.examples:
                response = example.get("response")
                # Bug fix: strip whitespace around each comma-separated category,
                # matching how system_message/system_message_batch encode them.
                # Without this, a valid example like "Weather, Sports" was rejected.
                response_categories = [
                    category.strip() for category in response.split(",")
                ]
                for response_category in response_categories:
                    if response_category not in category_keys:
                        raise ValueError(
                            f"Example response '{response}' is not one of the provided categories, {str(list(self.categories.keys()))}."
                        )

        return self

    @computed_field
    @property
    def system_message(self) -> str:
        """System prompt for single-input, multi-category classification."""
        categories_with_numbers = "\n".join(
            [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
        )
        system_message = f"""
Each category is formatted as :
{categories_with_numbers}
Classify the given text snippet into one or more of the following categories:
{str(list(range(len(self.categories))))}
Do not use any other categories.
Assign multiple categories to one content snippet by separating the categories with ||. Do not make any other conversation.
"""

        # Add in fewshot examples
        if self.examples:
            for example in self.examples:
                response_numbers = [
                    str(self.inverse_category_mapping[category.strip()])
                    for category in example["response"].split(",")
                ]
                system_message += f"\nEXAMPLE: {example['example']} RESPONSE: {'||'.join(response_numbers)}"
        return system_message

    @computed_field
    @property
    def system_message_batch(self) -> str:
        """System prompt for batch, multi-category classification."""
        categories_with_numbers = "\n".join(
            [f"{i}: {desc}" for i, desc in enumerate(self.categories.values())]
        )
        system_message = f"""
Each category is formatted as :
{categories_with_numbers}
Classify the given text snippet into one or more of the following categories:
{str(list(range(len(self.categories))))}
Do not use any other categories.
Assign multiple categories to one content snippet by separating the categories with ||. Differentiate between each content snippet using ##. EXAMPLE: ## \n RESPONSE: ||## Do not make any other conversation.
"""

        # Add in fewshot examples
        if self.examples:
            system_message += "\nEXAMPLE: "
            system_message += (
                f"{'##'.join([example['example'] for example in self.examples])}"
            )
            response_numbers_list = []
            for example in self.examples:
                response_numbers = [
                    str(self.inverse_category_mapping[category.strip()])
                    for category in example["response"].split(",")
                ]
                response_numbers_list.append("||".join(response_numbers))
            system_message += f"\nRESPONSE: {'##'.join(response_numbers_list)}"
        return system_message

    def categorize(self, input_data: str) -> str:
        """
        Categorizes the input data into multiple categories using the specified LLM provider.

        Args:
            input_data (str): The text data to be categorized.

        Returns:
            str: A string of categories for the input data, separated by commas.

        Raises:
            ValueError: If the reply cannot be parsed, or (in strict mode) if the
                predicted categories are not a subset of the provided categories.
        """
        # Call the LLM provider to get the predicted category numbers
        response = self.llm_provider.generate(self.system_message, input_data)
        try:
            predicted_category_numbers = [
                int(category.strip()) for category in response.split("||")
            ]
        except ValueError:
            raise ValueError(
                f"LLM response '{response}' could not be parsed as '||'-separated category numbers."
            )

        # Convert the category numbers back to category keys. Unknown numbers are
        # kept as their string form so validate_predicted_categories can apply the
        # strict/non-strict policy instead of raising a KeyError.
        predicted_categories = [
            self.category_mapping.get(number, str(number))
            for number in predicted_category_numbers
        ]
        return ",".join(self.validate_predicted_categories(predicted_categories))

    def categorize_batch(self, input_data: List[str]) -> List[str]:
        """
        Categorizes a batch of input data into multiple categories using the specified LLM provider.

        Args:
            input_data (List[str]): The list of text data to be categorized.

        Returns:
            List[str]: A list of predicted categories for the input data. If there are
                multiple categories for one input, they will be separated by commas.

        Raises:
            ValueError: If the reply cannot be parsed, if the number of prediction
                sets does not match the number of inputs, or (in strict mode) if the
                predicted categories are not a subset of the provided categories.
        """
        if len(input_data) == 1:
            return [self.categorize(next(iter(input_data)))]

        input_data_prompt = "##".join(input_data)
        # Call the LLM provider to get the predicted category numbers
        response = self.llm_provider.generate(
            self.system_message_batch, input_data_prompt
        )
        # Split the response into category number sets for each input data
        category_number_sets = response.split("##")

        if len(category_number_sets) != len(input_data):
            raise ValueError(
                f"Number of predicted category sets ({len(category_number_sets)}) does not match the number of input data ({len(input_data)})."
            )

        predicted_categories_list = []
        for category_number_set in category_number_sets:
            try:
                predicted_category_numbers = [
                    int(category.strip())
                    for category in category_number_set.split("||")
                ]
            except ValueError:
                raise ValueError(
                    f"LLM response segment '{category_number_set}' could not be parsed as '||'-separated category numbers."
                )
            # As in categorize(): defer unknown-number handling to
            # validate_predicted_categories rather than raising a KeyError.
            predicted_categories = [
                self.category_mapping.get(number, str(number))
                for number in predicted_category_numbers
            ]
            predicted_categories_str = ",".join(
                self.validate_predicted_categories(predicted_categories)
            )
            predicted_categories_list.append(predicted_categories_str)

        return predicted_categories_list
--------------------------------------------------------------------------------
/databonsai/examples/categorize_news.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from databonsai.categorize import MultiCategorizer, BaseCategorizer\n",
10 | "from databonsai.transform import BaseTransformer\n",
11 | "from databonsai.llm_providers import OpenAIProvider, AnthropicProvider, OllamaProvider\n",
12 | "from databonsai.utils import (\n",
13 | " apply_to_column,\n",
14 | " apply_to_column_batch,\n",
15 | " apply_to_column_autobatch,\n",
16 | ")\n",
17 | "import pandas as pd"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 3,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "provider = OpenAIProvider(model=\"gpt-3.5-turbo\") # Or AnthropicProvider()"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 4,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "categories = {\n",
36 | " \"Weather\": \"Insights and remarks about weather conditions.\",\n",
37 | " \"Sports\": \"Observations and comments on sports events.\",\n",
38 | " \"Politics\": \"Political events related to governments, nations, or geopolitical issues.\",\n",
39 | " \"Celebrities\": \"Celebrity sightings and gossip\",\n",
40 | " \"Tech\": \"News and updates about technology and tech companies.\",\n",
41 | " \"Others\": \"Comments do not fit into any of the above categories\", # Best practice in case it can't be categorized easily\n",
42 | " \"Anomaly\": \"Data that does not look like comments or natural language\", # Helps to flag unclean/problematic data\n",
43 | "}\n",
44 | "categorizer = BaseCategorizer(\n",
45 | " categories=categories,\n",
46 | " llm_provider=provider,\n",
47 | " examples=[\n",
48 | " {\"example\": \"Big stormy skies over city\", \"response\": \"Weather\"},\n",
49 | " {\"example\": \"The team won the championship\", \"response\": \"Sports\"},\n",
50 | " {\"example\": \"I saw a famous rapper at the mall\", \"response\": \"Celebrities\"},\n",
51 | " ],\n",
52 | ")"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 5,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "headlines2 = [\n",
62 | " \"Local Fire Department Honored with National Award for Bravery\",\n",
63 | " \"Breakthrough Research Promises New Treatment for Alzheimer’s Disease\",\n",
64 | " \"Major Airline Announces Expansion of International Routes\",\n",
65 | " \"Historic Peace Agreement Signed Between Rival Nations\",\n",
66 | " \"City Council Votes to Increase Funding for Public Libraries\",\n",
67 | " \"Renowned Chef Opens Vegan Restaurant in Downtown\",\n",
68 | " \"Veteran Astronaut Set to Lead Next Moon Mission\",\n",
69 | " \"Global Music Festival Raises Funds for Refugee Relief\",\n",
70 | " \"Innovative Urban Farming Techniques Revolutionize City Life\",\n",
71 | " \"Climate Summit Sets Ambitious Goals for Carbon Reduction\",\n",
72 | " \"Documentary Film Exposing Corruption Premieres to Critical Acclaim\",\n",
73 | " \"New Legislation Aims to Boost Small Businesses\",\n",
74 | " \"Ancient Shipwreck Discovered Off the Coast of Sicily\",\n",
75 | " \"World Health Organization Declares New Strain of Virus Contained\",\n",
76 | " \"International Art Theft Ring Busted by Joint Task Force\",\n",
77 | " \"Leading Economists Predict Global Recession in Next Year\",\n",
78 | " \"Celebrity Fashion Designer Debuts Eco-Friendly Line\",\n",
79 | " \"Major Breakthrough in Quantum Encryption Technology Announced\",\n",
80 | " \"Wildlife Conservation Efforts Successfully Increase Tiger Population\",\n",
81 | " \"Rare Astronomical Event Visible This Weekend\",\n",
82 | " \"Nationwide Protests Demand Action on Climate Change\",\n",
83 | " \"Revolutionary New Battery Design Could Transform Renewable Energy Storage\",\n",
84 | " \"Record-Breaking Heatwave Strikes Southern Europe\",\n",
85 | " \"Underground Water Reserves Discovered Beneath Sahara Desert\",\n",
86 | " \"Virtual Reality Platform Takes Online Education to New Heights\",\n",
87 | " \"Controversial New Policy Sparks Debate Over Internet Privacy\",\n",
88 | " \"Youngest Nobel Laureate Awarded for Work in Peace Building\",\n",
89 | " \"Sports League Implements New Rules to Protect Players from Concussions\",\n",
90 | " \"Historic Church Undergoes Restoration to Preserve Cultural Heritage\",\n",
91 | " \"Pioneering Surgery Gives New Hope to Heart Disease Patients\",\n",
92 | " \"Wildfires Rage Across California, Thousands Evacuated\",\n",
93 | " \"Tech Start-Up Revolutionizes Mobile Payment Systems\",\n",
94 | " \"Pharmaceutical Company Faces Lawsuit Over Drug Side Effects\",\n",
95 | " \"Renewable Energy Now Powers Entire Small Nation\",\n",
96 | " \"Central Bank Raises Interest Rates in Surprise Move\",\n",
97 | " \"Marine Biologists Discover New Species in the Deep Ocean\",\n",
98 | " \"Global Conference on Women's Rights Concludes with Action Plan\",\n",
99 | " \"Country Music Star Reveals Struggle with Mental Health in New Album\",\n",
100 | " \"Massive Oil Spill Threatens Wildlife Along the Coast\",\n",
101 | " \"Protests Erupt as Government Cuts Healthcare Funding\",\n",
102 | " \"Archaeologists Uncover New Evidence of Ancient Civilization in Thailand\",\n",
103 | " \"Fashion Week Highlights Sustainability in New Collections\",\n",
104 | " \"New Strain of Wheat Could Increase Crop Yields Substantially\",\n",
105 | " \"Scientists Link Air Pollution to Decline in Urban Wildlife\",\n",
106 | " \"Innovative Community Program Cuts Urban Crime Rate\",\n",
107 | " \"Next Generation of Smartphones Features Advanced AI Capabilities\",\n",
108 | " \"Historical Drama Film Set to Break Box Office Records\",\n",
109 | " \"Study Shows Increase in Cyber Attacks on Financial Institutions\",\n",
110 | " \"New Yoga Trend Combines Traditional Practices with Modern Technology\",\n",
111 | " \"Local Community Garden Doubles as Educational Facility for Schools\",\n",
112 | "]"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 6,
118 | "metadata": {},
119 | "outputs": [
120 | {
121 | "name": "stderr",
122 | "output_type": "stream",
123 | "text": [
124 | "Categorizing: 0%| | 0/50 [00:00, ?row/s]"
125 | ]
126 | },
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "Retrying with smaller batch size: 4\n"
132 | ]
133 | },
134 | {
135 | "name": "stderr",
136 | "output_type": "stream",
137 | "text": [
138 | "Categorizing: 100%|██████████| 50/50 [00:04<00:00, 11.98row/s]\n"
139 | ]
140 | }
141 | ],
142 | "source": [
143 | "categories = []\n",
144 | "idx = apply_to_column_autobatch(\n",
145 | " headlines2,\n",
146 | " categories,\n",
147 | " categorizer.categorize_batch,\n",
148 | " max_retries=3,\n",
149 | " batch_size=5,\n",
150 | " ramp_factor=1.7,\n",
151 | " max_batch_size=20,\n",
152 | " ramp_factor_decay=0.98,\n",
153 | " reduce_factor=0.7,\n",
154 | " reduce_factor_decay=0.9,\n",
155 | " start_idx=0,\n",
156 | ")"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 7,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "name": "stdout",
166 | "output_type": "stream",
167 | "text": [
168 | "['Politics', 'Tech', 'Others', 'Politics', 'Politics', 'Celebrities', 'Others', 'Others', 'Tech', 'Politics', 'Celebrities', 'Politics', 'Others', 'Others', 'Others', 'Others', 'Celebrities', 'Tech', 'Others', 'Weather', 'Politics', 'Tech', 'Weather', 'Tech', 'Tech', 'Politics', 'Celebrities', 'Sports', 'Others', 'Tech', 'Weather', 'Tech', 'Politics', 'Weather', 'Politics', 'Others', 'Celebrities', 'Weather', 'Politics', 'Others', 'Celebrities', 'Others', 'Others', 'Politics', 'Others', 'Tech', 'Celebrities', 'Tech', 'Others', 'Others']\n"
169 | ]
170 | }
171 | ],
172 | "source": [
173 | "print(categories)"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": []
182 | }
183 | ],
184 | "metadata": {
185 | "kernelspec": {
186 | "display_name": "venv",
187 | "language": "python",
188 | "name": "python3"
189 | },
190 | "language_info": {
191 | "codemirror_mode": {
192 | "name": "ipython",
193 | "version": 3
194 | },
195 | "file_extension": ".py",
196 | "mimetype": "text/x-python",
197 | "name": "python",
198 | "nbconvert_exporter": "python",
199 | "pygments_lexer": "ipython3",
200 | "version": "3.12.2"
201 | }
202 | },
203 | "nbformat": 4,
204 | "nbformat_minor": 2
205 | }
206 |
--------------------------------------------------------------------------------
/databonsai/llm_providers/__init__.py:
--------------------------------------------------------------------------------
1 | # llm_providers/__init__.py
2 | from .llm_provider import LLMProvider
3 | from .openai_provider import OpenAIProvider
4 | from .anthropic_provider import AnthropicProvider
5 | from .ollama_provider import OllamaProvider
6 |
--------------------------------------------------------------------------------
/databonsai/llm_providers/anthropic_provider.py:
--------------------------------------------------------------------------------
1 | import anthropic
2 | from .llm_provider import LLMProvider
3 | import os
4 | from functools import wraps
5 | from tenacity import retry, wait_exponential, stop_after_attempt
6 | from dotenv import load_dotenv
7 | from databonsai.utils.logs import logger
8 |
9 | load_dotenv()
10 |
11 |
class AnthropicProvider(LLMProvider):
    """
    A provider class to interact with Anthropic's Claude API.
    Supports exponential backoff retries, since we'll often deal with large datasets.
    """

    def __init__(
        self,
        api_key: str = None,
        multiplier: int = 1,
        min_wait: int = 1,
        max_wait: int = 30,
        max_tries: int = 5,
        model: str = "claude-3-haiku-20240307",
        temperature: float = 0,
    ):
        """
        Initializes the AnthropicProvider with an API key and retry parameters.

        Parameters:
            api_key (str): Anthropic API key. Falls back to the ANTHROPIC_API_KEY
                environment variable (loaded via dotenv at import time) when omitted.
            multiplier (int): The multiplier for the exponential backoff in retries.
            min_wait (int): The minimum wait time between retries.
            max_wait (int): The maximum wait time between retries.
            max_tries (int): The maximum number of attempts before giving up.
            model (str): The default model to use for text generation.
            temperature (float): The temperature parameter for text generation.

        Raises:
            ValueError: If no API key is given and ANTHROPIC_API_KEY is not set.
        """
        super().__init__()

        # Provider related configs
        if api_key:
            self.api_key = api_key
        else:
            self.api_key = os.getenv("ANTHROPIC_API_KEY")
            if not self.api_key:
                raise ValueError("Anthropic API key not provided.")
        self.model = model
        self.client = anthropic.Anthropic(api_key=self.api_key)
        self.temperature = temperature
        # Cumulative token usage across all generate() calls, for cost tracking.
        self.input_tokens = 0
        self.output_tokens = 0

        # Retry related configs
        self.multiplier = multiplier
        self.min_wait = min_wait
        self.max_wait = max_wait
        self.max_tries = max_tries

    def retry_with_exponential_backoff(method):
        """
        Decorator to apply retry logic with exponential backoff to an instance method.
        It captures the 'self' context to access instance attributes for retry configuration.

        NOTE: this is applied at class-definition time (it decorates `generate`
        below); it is deliberately not an instance method, so `method` — not
        `self` — is its only parameter.
        """

        @wraps(method)
        def wrapper(self, *args, **kwargs):
            # Build the tenacity decorator at call time so per-instance retry
            # settings (multiplier/min_wait/max_wait/max_tries) are honoured.
            retry_decorator = retry(
                wait=wait_exponential(
                    multiplier=self.multiplier, min=self.min_wait, max=self.max_wait
                ),
                stop=stop_after_attempt(self.max_tries),
            )
            return retry_decorator(method)(self, *args, **kwargs)

        return wrapper

    @retry_with_exponential_backoff
    def generate(self, system_prompt: str, user_prompt: str, max_tokens=1000, json: bool = False) -> str:
        """
        Generates a text completion using Anthropic's Claude API, with a given system and user prompt.
        This method is decorated with retry logic to handle temporary failures.

        Parameters:
            system_prompt (str): The system prompt to provide context or instructions for the generation.
            user_prompt (str): The user's prompt, based on which the text completion is generated.
            max_tokens (int): The maximum number of tokens to generate in the response.
            json (bool): Accepted for signature compatibility with other providers
                (e.g. OpenAIProvider); it is never referenced in this method, so
                the flag is currently ignored for Anthropic.

        Returns:
            str: The generated text completion.

        Raises:
            ValueError: If either prompt is empty.

        NOTE(review): the prompt validation below sits inside the retried body,
        so an empty prompt is retried max_tries times (with backoff) before the
        ValueError finally surfaces.
        """
        try:
            if not system_prompt:
                raise ValueError("System prompt is required.")
            if not user_prompt:
                raise ValueError("User prompt is required.")
            response = self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                temperature=self.temperature,
                system=f"{system_prompt}",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": user_prompt,
                            }
                        ],
                    }
                ],
            )
            # Accumulate usage reported by the API for cost tracking.
            self.input_tokens += response.usage.input_tokens
            self.output_tokens += response.usage.output_tokens
            return response.content[0].text
        except Exception as e:
            logger.warning(f"Error occurred during generation: {str(e)}")
            raise
121 |
--------------------------------------------------------------------------------
/databonsai/llm_providers/llm_provider.py:
--------------------------------------------------------------------------------
1 | # llm_providers/base_provider.py
2 | from abc import ABC, abstractmethod
3 | from typing import Optional
4 |
5 |
class LLMProvider(ABC):
    """
    Abstract base class for LLM providers (OpenAI, Anthropic, Ollama, ...).

    Concrete providers must implement __init__ and generate. The generate
    contract includes a ``json`` flag because callers such as the transformers
    pass ``json=`` uniformly to whatever provider they are given; providers
    without a JSON response mode may ignore the flag.
    """

    @abstractmethod
    def __init__(
        self,
        model: str = "",
        temperature: float = 0,
    ):
        """
        Initializes the LLMProvider.

        Parameters:
            model (str): The default model to use for text generation.
            temperature (float): The temperature parameter for text generation.
        """
        pass

    @abstractmethod
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1000,
        json: bool = False,
    ) -> str:
        """
        Generates a text completion using the provider's API, with a given system and user prompt.
        This method should be decorated with retry logic to handle temporary failures.

        Parameters:
            system_prompt (str): The system prompt to provide context or instructions for the generation.
            user_prompt (str): The user's prompt, based on which the text completion is generated.
            max_tokens (int): The maximum number of tokens to generate in the response. Defaults to 1000,
                matching the concrete implementations.
            json (bool): Whether to request a JSON-formatted response. Providers that
                do not support a JSON mode may accept and ignore this flag. Defaults to False.

        Returns:
            str: The generated text completion.
        """
        pass
38 |
--------------------------------------------------------------------------------
/databonsai/llm_providers/ollama_provider.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from ollama import Client, chat
3 | from .llm_provider import LLMProvider
4 | from databonsai.utils.logs import logger
5 |
6 |
class OllamaProvider(LLMProvider):
    """
    A provider class to interact with Ollama's API, either via the module-level
    default client or a dedicated client pointed at a custom host.

    NOTE: unlike OpenAIProvider/AnthropicProvider, this provider performs no
    retry logic; Ollama typically runs locally, where transient API failures
    are far less common.
    """

    def __init__(
        self,
        model: str = "llama3",
        temperature: float = 0,
        host: Optional[str] = None,
    ):
        """
        Initializes the OllamaProvider with an optional host.

        Parameters:
            model (str): The default model to use for text generation.
            temperature (float): The temperature parameter for text generation.
            host (str): The host URL for the Ollama API. When omitted, the
                module-level default client is used.
        """
        # Provider related configs
        self.model = model
        self.temperature = temperature

        # A dedicated client is only needed for a custom host; otherwise we
        # fall back to the module-level `ollama.chat` in _chat().
        self.client = Client(host=host) if host else None

    def _chat(self, messages, options):
        """Dispatch the chat call to the host-specific client, or the default one."""
        if self.client:
            return self.client.chat(
                model=self.model, messages=messages, options=options
            )
        return chat(model=self.model, messages=messages, options=options)

    def generate(
        self, system_prompt: str, user_prompt: str, max_tokens=1000, json: bool = False
    ) -> str:
        """
        Generates a text completion using Ollama's API, with a given system and user prompt.

        Parameters:
            system_prompt (str): The system prompt to provide context or instructions for the generation.
            user_prompt (str): The user's prompt, based on which the text completion is generated.
            max_tokens (int): The maximum number of tokens to generate in the response
                (mapped to Ollama's ``num_predict`` option).
            json (bool): Accepted for interface compatibility with the other
                providers — callers such as BaseTransformer pass ``json=``, which
                previously raised TypeError here. Ollama's JSON mode is not
                enabled by this flag; it is currently ignored.

        Returns:
            str: The generated text completion.

        Raises:
            ValueError: If either prompt is empty.
        """
        if not system_prompt:
            raise ValueError("System prompt is required.")
        if not user_prompt:
            raise ValueError("User prompt is required.")
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]

            response = self._chat(
                messages,
                options={"temperature": self.temperature, "num_predict": max_tokens},
            )
            return response["message"]["content"]
        except Exception as e:
            logger.warning(f"Error occurred during generation: {str(e)}")
            raise
75 |
--------------------------------------------------------------------------------
/databonsai/llm_providers/openai_provider.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | from .llm_provider import LLMProvider
3 | import os
4 | from functools import wraps
5 | from tenacity import retry, wait_exponential, stop_after_attempt
6 | from dotenv import load_dotenv
7 | from databonsai.utils.logs import logger
8 |
9 | load_dotenv()
10 |
11 |
class OpenAIProvider(LLMProvider):
    """
    A provider class to interact with OpenAI's API.
    Supports exponential backoff retries, since we'll often deal with large datasets.
    """

    def __init__(
        self,
        api_key: str = None,
        multiplier: int = 1,
        min_wait: int = 1,
        max_wait: int = 30,
        max_tries: int = 5,
        model: str = "gpt-4-turbo",
        temperature: float = 0,
    ):
        """
        Initializes the OpenAIProvider with an API key and retry parameters.

        Parameters:
            api_key (str): OpenAI API key. Falls back to the OPENAI_API_KEY
                environment variable (loaded via dotenv at import time) when omitted.
            multiplier (int): The multiplier for the exponential backoff in retries.
            min_wait (int): The minimum wait time between retries.
            max_wait (int): The maximum wait time between retries.
            max_tries (int): The maximum number of attempts before giving up.
            model (str): The default model to use for text generation.
            temperature (float): The temperature parameter for text generation.

        Raises:
            ValueError: If no API key is available, or if the model cannot be
                retrieved from the OpenAI API (e.g. a typo in the model name).
        """
        super().__init__()

        # Provider related configs
        if api_key:
            self.api_key = api_key
        else:
            self.api_key = os.getenv("OPENAI_API_KEY")
            if not self.api_key:
                raise ValueError("OpenAI API key not provided.")
        self.model = model
        self.client = OpenAI(api_key=self.api_key)
        try:
            # Fail fast on an invalid model name before any generation calls.
            self.client.models.retrieve(model)
        except Exception as e:
            # Log the exception itself: not every failure carries an HTTP
            # response (connection errors, timeouts), so unconditionally
            # reading `e.response.status_code` would raise AttributeError
            # and mask the original error.
            logger.warning(str(e))
            raise ValueError(f"Invalid OpenAI model: {model}") from e
        self.temperature = temperature
        # Cumulative token usage across all generate() calls, for cost tracking.
        self.input_tokens = 0
        self.output_tokens = 0

        # Retry related configs
        self.multiplier = multiplier
        self.min_wait = min_wait
        self.max_wait = max_wait
        self.max_tries = max_tries

    def retry_with_exponential_backoff(method):
        """
        Decorator to apply retry logic with exponential backoff to an instance method.
        It captures the 'self' context to access instance attributes for retry configuration.

        NOTE: applied at class-definition time (it decorates `generate` below);
        it is deliberately not an instance method.
        """

        @wraps(method)
        def wrapper(self, *args, **kwargs):
            # Build the tenacity decorator at call time so per-instance retry
            # settings (multiplier/min_wait/max_wait/max_tries) are honoured.
            retry_decorator = retry(
                wait=wait_exponential(
                    multiplier=self.multiplier, min=self.min_wait, max=self.max_wait
                ),
                stop=stop_after_attempt(self.max_tries),
            )
            return retry_decorator(method)(self, *args, **kwargs)

        return wrapper

    @retry_with_exponential_backoff
    def generate(
        self, system_prompt: str, user_prompt: str, max_tokens=1000, json=False
    ) -> str:
        """
        Generates a text completion using OpenAI's API, with a given system and user prompt.
        This method is decorated with retry logic to handle temporary failures.

        Parameters:
            system_prompt (str): The system prompt to provide context or instructions for the generation.
            user_prompt (str): The user's prompt, based on which the text completion is generated.
            max_tokens (int): The maximum number of tokens to generate in the response.
            json (bool): Whether to use OpenAI's JSON response format.

        Returns:
            str: The generated text completion.

        Raises:
            ValueError: If either prompt is empty. NOTE(review): these are raised
                inside the retried wrapper, so empty prompts are retried
                max_tries times (with backoff) before the error surfaces.
        """
        if not system_prompt:
            raise ValueError("System prompt is required.")
        if not user_prompt:
            raise ValueError("User prompt is required.")
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"{user_prompt}"},
                ],
                temperature=self.temperature,
                max_tokens=max_tokens,
                frequency_penalty=0,
                presence_penalty=0,
                response_format={"type": "json_object"} if json else {"type": "text"},
            )
            # Accumulate usage reported by the API for cost tracking.
            self.input_tokens += response.usage.prompt_tokens
            self.output_tokens += response.usage.completion_tokens
            return response.choices[0].message.content
        except Exception as e:
            logger.warning(f"Error occurred during generation: {str(e)}")
            raise
124 |
--------------------------------------------------------------------------------
/databonsai/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_transformer import BaseTransformer
2 | from .extract_transformer import ExtractTransformer
3 |
--------------------------------------------------------------------------------
/databonsai/transform/base_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict
2 | from pydantic import BaseModel, field_validator, model_validator, computed_field
3 | from databonsai.llm_providers import LLMProvider
4 |
5 |
class BaseTransformer(BaseModel):
    """
    A base class for transforming input data using a specified LLM provider.

    Attributes:
        prompt (str): The prompt used to guide the transformation process.
        llm_provider (LLMProvider): An instance of an LLM provider to be used for transformation.
        examples (Optional[List[Dict[str, str]]]): A list of example inputs and their corresponding transformed outputs.

    """

    prompt: str
    llm_provider: LLMProvider
    examples: Optional[List[Dict[str, str]]] = []

    class Config:
        # LLMProvider instances are not pydantic models, so arbitrary types
        # must be allowed for the llm_provider field.
        arbitrary_types_allowed = True

    @field_validator("prompt")
    def validate_prompt(cls, v):
        """
        Validates the prompt.

        Args:
            v (str): The prompt to be validated.

        Raises:
            ValueError: If the prompt is empty.

        Returns:
            str: The validated prompt.
        """
        if not v:
            raise ValueError("Prompt cannot be empty.")
        return v

    @field_validator("examples")
    def validate_examples(cls, v):
        """
        Validates the examples list.

        Args:
            v (List[Dict[str, str]]): The examples list to be validated.

        Raises:
            ValueError: If the examples list is not a list of dictionaries or if any dictionary is missing the "example" or "response" key.

        Returns:
            List[Dict[str, str]]: The validated examples list.
        """
        if not isinstance(v, list):
            raise ValueError("Examples must be a list of dictionaries.")
        for example in v:
            if not isinstance(example, dict):
                raise ValueError("Each example must be a dictionary.")
            if "example" not in example or "response" not in example:
                raise ValueError(
                    "Each example dictionary must have 'example' and 'response' keys."
                )
        return v

    @computed_field
    @property
    def system_message(self) -> str:
        """System prompt for single-input transformation, with few-shot examples appended."""
        system_message = f"""
        Use the following prompt to transform the input data:
        Prompt: {self.prompt}
        """

        # Add in fewshot examples
        if self.examples:
            for example in self.examples:
                system_message += (
                    f"\nEXAMPLE: {example['example']} RESPONSE: {example['response']}"
                )

        return system_message

    @computed_field
    @property
    def system_message_batch(self) -> str:
        """System prompt for batched transformation: outputs are "||"-separated."""
        system_message = f"""
        Use the following prompt to transform each input data:
        Prompt: {self.prompt}
        Respond with the transformed data for each input, separated by ||. Do not make any other conversation.
        Example: Content 1: , Content 2: \n Response: ||
        """

        # Add in fewshot examples
        if self.examples:
            system_message += "\nExample: "
            for idx, example in enumerate(self.examples):
                system_message += f"Content {str(idx+1)}: {example['example']}, "
            system_message += "\nResponse: "
            for example in self.examples:
                system_message += f"{example['response']}||"

        return system_message

    def transform(self, input_data: str, max_tokens=1000, json: bool = False) -> str:
        """
        Transforms the input data using the specified LLM provider.

        Args:
            input_data (str): The text data to be transformed.
            max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 1000.
            json (bool, optional): Whether to request a JSON-formatted response from the provider. Defaults to False.

        Returns:
            str: The transformed data, stripped of surrounding whitespace.
        """
        # Call the LLM provider to perform the transformation
        response = self.llm_provider.generate(
            self.system_message, input_data, max_tokens=max_tokens, json=json
        )
        return response.strip()

    def transform_batch(
        self, input_data: List[str], max_tokens=1000, json: bool = False
    ) -> List[str]:
        """
        Transforms a batch of input data using the specified LLM provider.

        Args:
            input_data (List[str]): A list of text data to be transformed.
            max_tokens (int, optional): The maximum number of tokens to generate in each response. Defaults to 1000.
            json (bool, optional): Whether to request a JSON-formatted response from the provider. Defaults to False.

        Returns:
            List[str]: A list of transformed data, where each element corresponds to the transformed version of the respective input data.

        Raises:
            ValueError: If the number of outputs parsed from the LLM response
                does not match the number of inputs.
        """
        # An empty batch needs no LLM call.
        if not input_data:
            return []
        if len(input_data) == 1:
            # Single-item batches go through the plain transform path; forward
            # max_tokens and json so the caller's settings are not silently
            # dropped (they previously were on this path).
            return [self.transform(input_data[0], max_tokens=max_tokens, json=json)]
        # NOTE: "||" is the batch separator; inputs that themselves contain
        # "||" will corrupt batching and trip the length check below.
        input_data_prompt = "||".join(input_data)
        response = self.llm_provider.generate(
            self.system_message_batch,
            input_data_prompt,
            max_tokens=max_tokens,
            json=json,
        )

        # Split the response into individual outputs and trim whitespace.
        transformed_data_list = [data.strip() for data in response.split("||")]

        if len(transformed_data_list) != len(input_data):
            raise ValueError(
                f"Length of output list ({len(transformed_data_list)}) does not match the length of input list ({len(input_data)})."
            )

        return transformed_data_list
154 |
--------------------------------------------------------------------------------
/databonsai/transform/extract_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from pydantic import field_validator, model_validator, computed_field
3 | from databonsai.transform.base_transformer import BaseTransformer
4 |
5 |
class ExtractTransformer(BaseTransformer):
    """
    This class extends the BaseTransformer class and overrides the transform method to extract a given schema from the input data into a list of dictionaries.

    Attributes:
        output_schema (Dict[str, str]): A dictionary representing the schema of the output dictionaries.
        examples (Optional[List[Dict[str, str]]]): A list of example inputs and their corresponding extracted outputs.

    Raises:
        ValueError: If the output schema dictionary is empty, or if the transformed data does not match the expected format or schema.
    """

    output_schema: Dict[str, str]
    examples: Optional[List[Dict[str, str]]] = []

    @field_validator("output_schema")
    def validate_schema(cls, v):
        """
        Validates the output schema.

        Args:
            v (Dict[str, str]): The output schema to be validated.

        Raises:
            ValueError: If the output schema dictionary is empty.

        Returns:
            Dict[str, str]: The validated output schema.
        """
        if not v:
            raise ValueError("Schema dictionary cannot be empty.")
        return v

    @field_validator("examples")
    def validate_examples(cls, v):
        """
        Validates the examples list.

        Args:
            v (List[Dict[str, str]]): The examples list to be validated.

        Raises:
            ValueError: If the examples list is not a list of dictionaries or if any dictionary is missing the "example" or "response" key.

        Returns:
            List[Dict[str, str]]: The validated examples list.
        """
        if not isinstance(v, list):
            raise ValueError("Examples must be a list of dictionaries.")
        for example in v:
            if not isinstance(example, dict):
                raise ValueError("Each example must be a dictionary.")
            if "example" not in example or "response" not in example:
                raise ValueError(
                    "Each example dictionary must have 'example' and 'response' keys."
                )
        return v

    @model_validator(mode="after")
    def validate_examples_responses(self):
        """
        Validates that the "response" values in the examples are valid JSON-formatted lists of dictionaries that match the output schema.
        """
        if self.examples:
            for example in self.examples:
                response = example.get("response")
                try:
                    # SECURITY: eval() executes arbitrary Python. Examples are
                    # developer-supplied here, but ast.literal_eval or
                    # json.loads would be safer — TODO(review): harden.
                    response_data = eval(response)
                except (SyntaxError, NameError, TypeError, ZeroDivisionError):
                    raise ValueError(
                        f"Invalid format in the example response: {response}"
                    )
                if not isinstance(response_data, list):
                    raise ValueError(
                        f"Example response must be a JSON-formatted list: {response}"
                    )
                for item in response_data:
                    if not isinstance(item, dict):
                        raise ValueError(
                            f"Each item in the example response must be a dictionary: {response}"
                        )
                    # Keys must match the schema exactly — no extras, no omissions.
                    if set(item.keys()) != set(self.output_schema.keys()):
                        raise ValueError(
                            f"The keys in the example response do not match the output schema: {response}"
                        )
        return self

    @computed_field
    @property
    def system_message(self) -> str:
        """System prompt instructing the LLM to emit a JSON list matching output_schema."""
        # NOTE(review): the label below says "Input Data" but interpolates the
        # prompt — likely meant "Prompt". Changing it would alter the text sent
        # to the LLM, so it is left as-is here.
        system_message = f"""
        Use the following prompt to transform the input data:
        Input Data: {self.prompt}
        The transformed data should be a list of dictionaries, where each dictionary has the following schema:
        {self.output_schema}
        Reply with a JSON-formatted list of dictionaries. Do not make any conversation.
        """

        # Add in few-shot examples
        if self.examples:
            for example in self.examples:
                system_message += (
                    f"\nEXAMPLE: {example['example']} RESPONSE: {example['response']}"
                )

        return system_message


    def transform(self, input_data: str, max_tokens=1000, json: bool =True) -> List[Dict[str, str]]:
        """
        Transforms the input data into a list of dictionaries using the specified LLM provider.

        Args:
            input_data (str): The text data to be transformed.
            max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 1000.
            json (bool, optional): Whether to request the provider's JSON response mode. Defaults to True.

        Returns:
            List[Dict[str, str]]: The transformed data as a list of dictionaries.

        Raises:
            ValueError: If the transformed data does not match the expected format or schema.
        """
        # Call the LLM provider to perform the transformation
        response = self.llm_provider.generate(
            self.system_message, input_data, max_tokens=max_tokens, json=json
        )

        try:
            # SECURITY: eval() executes arbitrary Python returned by the LLM —
            # an untrusted source. ast.literal_eval or json.loads would be
            # safer — TODO(review): harden before production use.
            transformed_data = eval(response)
        except (SyntaxError, NameError, TypeError, ZeroDivisionError):
            raise ValueError("Invalid format in the transformed data.")

        # Validate the transformed data
        if not isinstance(transformed_data, list):
            raise ValueError("Transformed data must be a list.")
        for item in transformed_data:
            if not isinstance(item, dict):
                raise ValueError(
                    "Each item in the transformed data must be a dictionary."
                )
            # Keys must match the schema exactly — no extras, no omissions.
            if set(item.keys()) != set(self.output_schema.keys()):
                raise ValueError(
                    "The keys in the transformed data do not match the schema."
                )

        return transformed_data

    # NOTE(review): dead code kept for reference — a batch variant that relies
    # on a provider `generate_batch` method which no provider currently defines.
    # def transform_batch(self, input_data: List[str], max_tokens=1000) -> List[List[Dict[str, str]]]:
    #     """
    #     Transforms a batch of input data into lists of dictionaries using the specified LLM provider.

    #     Args:
    #         input_data (List[str]): A list of text data to be transformed.
    #         max_tokens (int, optional): The maximum number of tokens to generate in each response. Defaults to 1000.

    #     Returns:
    #         List[List[Dict[str, str]]]: A list of transformed data, where each element is a list of dictionaries corresponding to the respective input data.

    #     Raises:
    #         ValueError: If the transformed data does not match the expected format or schema.
    #     """
    #     # Call the LLM provider to perform the batch transformation
    #     response = self.llm_provider.generate_batch(
    #         self.system_message_batch, input_data, max_tokens=max_tokens
    #     )

    #     # Split the response into individual transformed data
    #     transformed_data_list = response.split("||")

    #     # Strip any leading/trailing whitespace from each transformed data
    #     transformed_data_list = [data.strip() for data in transformed_data_list]

    #     if len(transformed_data_list) != len(input_data):
    #         raise ValueError(
    #             f"Length of output list ({len(transformed_data_list)}) does not match the length of input list ({len(input_data)})."
    #         )

    #     # Evaluate and validate each transformed data
    #     result = []
    #     for data in transformed_data_list:
    #         try:
    #             transformed_data = eval(data)
    #         except (SyntaxError, NameError, TypeError, ZeroDivisionError):
    #             raise ValueError(f"Invalid format in the transformed data: {data}")

    #         # Validate the transformed data
    #         if not isinstance(transformed_data, list):
    #             raise ValueError(f"Transformed data must be a list: {data}")
    #         for item in transformed_data:
    #             if not isinstance(item, dict):
    #                 raise ValueError(
    #                     f"Each item in the transformed data must be a dictionary: {data}"
    #                 )
    #             if set(item.keys()) != set(self.output_schema.keys()):
    #                 raise ValueError(
    #                     f"The keys in the transformed data do not match the schema: {data}"
    #                 )

    #         result.append(transformed_data)

    #     return result
207 |
--------------------------------------------------------------------------------
/databonsai/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .apply import apply_to_column, apply_to_column_batch, apply_to_column_autobatch
2 |
--------------------------------------------------------------------------------
/databonsai/utils/apply.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from typing import List, Callable, Union, get_origin
3 | import inspect
4 |
5 |
def apply_to_column(
    input_column: List,
    output_column: List,
    func: Callable,
    start_idx: int = 0,
) -> int:
    """
    Apply ``func`` to each value in a column (a DataFrame column or a plain
    Python list), beginning at ``start_idx``, writing results into
    ``output_column`` in-place.

    If a value fails to process, the error is printed and the index of the
    first unprocessed value is returned, so the call can be resumed later by
    passing that index back as ``start_idx``.

    Parameters:
        input_column (List): The values to process.
        output_column (List): Receives the processed values; mutated in-place.
            Existing entries are overwritten, missing ones are appended.
        func (callable): Maps a single input value to a single output value.
        start_idx (int, optional): Index at which processing begins. Default is 0.

    Returns:
        int: One past the index of the last successfully processed value.

    Raises:
        ValueError: If ``input_column`` is empty, ``start_idx`` is out of
            range, or ``output_column`` is longer than ``input_column``.
    """
    # len(...) == 0 (rather than truthiness) stays safe for pandas Series.
    if len(input_column) == 0:
        raise ValueError("Input input_column is empty.")

    if start_idx >= len(input_column):
        raise ValueError(
            f"start_idx ({start_idx}) is greater than or equal to the length of the input_column ({len(input_column)})."
        )

    if len(output_column) > len(input_column):
        raise ValueError(
            f"The length of the output_column ({len(output_column)}) is greater than the length of the input_column ({len(input_column)})."
        )

    success_idx = start_idx

    try:
        progress = tqdm(input_column[start_idx:], desc="Processing data..", unit="row")
        for offset, value in enumerate(progress):
            idx = start_idx + offset
            processed = func(value)

            # Overwrite an existing slot when resuming; append otherwise.
            if idx < len(output_column):
                output_column[idx] = processed
            else:
                output_column.append(processed)

            success_idx = idx + 1
    except Exception as e:
        # Best-effort semantics: report the failure and hand back the resume index.
        print(f"Error occurred at index {success_idx}: {str(e)}")
        print(f"Processing stopped at index {success_idx - 1}")
        return success_idx

    return success_idx
60 |
61 |
def apply_to_column_batch(
    input_column: List,
    output_column: List,
    func: Callable,
    batch_size: int = 5,
    start_idx: int = 0,
) -> int:
    """
    Apply a batch function to each batch of values in a column of a DataFrame
    or a normal Python list, starting from a specified index.

    On failure, the error is printed and the start index of the failed batch
    is returned, so the call can be resumed by passing it back as ``start_idx``.

    Parameters:
        input_column (List): The column of the DataFrame or a normal Python list to apply the function to.
        output_column (List): A list to store the processed values. The function will mutate this list in-place.
        func (callable): The batch function to apply to each batch of values in the column.
            The function should take a list of values as input and return a list of processed values.
        batch_size (int, optional): The size of each batch. Default is 5.
        start_idx (int, optional): The index from which to start applying the function. Default is 0.

    Returns:
        int: One past the index of the last successfully processed item
            (the value to reuse as ``start_idx`` when resuming).

    Raises:
        ValueError: If ``input_column`` is empty, ``start_idx`` is out of
            range, or ``output_column`` is longer than ``input_column``.
    """
    if len(input_column) == 0:
        raise ValueError("Input input_column is empty.")

    if start_idx >= len(input_column):
        raise ValueError(
            f"start_idx ({start_idx}) is greater than or equal to the length of the input_column ({len(input_column)})."
        )

    if len(output_column) > len(input_column):
        raise ValueError(
            f"The length of the output_column list ({len(output_column)}) is greater than the length of the input_column ({len(input_column)})."
        )

    # Reject per-item functions early: func must accept a list of values.
    check_func(func)
    success_idx = start_idx
    num_items = len(input_column)
    try:
        with tqdm(total=num_items, desc="Processing data..", unit="item") as pbar:
            for i in range(start_idx, num_items, batch_size):
                batch_end = min(i + batch_size, num_items)
                batch = input_column[i:batch_end]
                batch_result = func(batch)

                # Update output column: extend past the end, overwrite otherwise.
                if i >= len(output_column):
                    output_column.extend(batch_result)
                else:
                    output_column[i : i + len(batch_result)] = batch_result

                # Update progress bar by the number of items processed in this batch
                pbar.update(len(batch_result))

                success_idx = batch_end
    except Exception as e:
        print(f"Error occurred at batch starting at index {success_idx}: {str(e)}")
        print(f"Processing stopped at batch ending at index {success_idx - 1}")
        return success_idx

    return min(success_idx, len(input_column))
125 |
126 |
def apply_to_column_autobatch(
    input_column: List,
    output_column: List,
    func: Callable,
    max_retries: int = 3,
    max_batch_size: int = 5,
    batch_size: int = 2,
    ramp_factor: float = 1.5,
    ramp_factor_decay: float = 0.8,
    reduce_factor: float = 0.5,
    reduce_factor_decay: float = 0.8,
    start_idx: int = 0,
) -> int:
    """
    Apply a function to the input column using adaptive batch processing.

    This function applies a batch processing function to the input column, starting from the
    specified index. It adaptively adjusts the batch size based on the success or failure of
    each batch processing attempt: the batch size grows (by ``ramp_factor``) after a success
    and shrinks (by ``reduce_factor``) after a failure, with both factors decaying toward
    neutral over time so the batch size stabilises.

    Parameters:
        input_column (List): The input column to be processed.
        output_column (List): The list to store the processed results; mutated in-place.
        func (callable): The batch function to apply to each batch of values in the column.
            The function should take a list of values as input and return a list of processed values.
        max_retries (int): The maximum number of retries for failed batches.
        max_batch_size (int): The maximum allowed batch size.
        batch_size (int): The initial batch size.
        ramp_factor (float): The factor by which the batch size is increased after a successful batch.
        ramp_factor_decay (float): The decay rate for the ramp factor after each successful batch.
        reduce_factor (float): The factor by which the batch size is reduced after a failed batch.
        reduce_factor_decay (float): The decay rate for the reduce factor after each failed batch.
        start_idx (int): The index from which to start processing the input column.

    Returns:
        int: One past the index of the last successfully processed item
            (the value to reuse as ``start_idx`` when resuming).

    Raises:
        ValueError: If ``input_column`` is empty, ``start_idx`` is out of range,
            or ``output_column`` is longer than ``input_column``.
    """
    if len(input_column) == 0:
        raise ValueError("Input input_column is empty.")

    if start_idx >= len(input_column):
        raise ValueError(
            f"start_idx ({start_idx}) is greater than or equal to the length of the input_column ({len(input_column)})."
        )

    if len(output_column) > len(input_column):
        raise ValueError(
            f"The length of the output_column list ({len(output_column)}) is greater than the length of the input_column ({len(input_column)})."
        )

    # Reject per-item functions early: func must accept a list of values.
    check_func(func)
    success_idx = start_idx

    try:
        remaining_data = input_column[start_idx:]
        retry_count = 0

        with tqdm(
            total=len(remaining_data), desc="Processing data..", unit="row"
        ) as pbar:
            while len(remaining_data) > 0:
                try:
                    # Never request more items than remain.
                    batch_size = min(batch_size, len(remaining_data))
                    batch = remaining_data[:batch_size]
                    batch_results = func(batch)

                    # Update output_column in place (slice assignment extends
                    # the list when it is shorter than the target range).
                    output_column[success_idx : success_idx + batch_size] = (
                        batch_results
                    )

                    remaining_data = remaining_data[batch_size:]
                    retry_count = 0
                    success_idx += batch_size

                    # Update progress bar
                    pbar.update(batch_size)

                    # Increase the batch size using the decayed ramp factor
                    batch_size = min(round(batch_size * ramp_factor), max_batch_size)
                    ramp_factor = max(ramp_factor * ramp_factor_decay, 1.0)

                except Exception as e:
                    if retry_count >= max_retries:
                        raise ValueError(
                            f"Processing failed after {max_retries} retries. Error: {str(e)}"
                        )
                    retry_count += 1
                    # Decrease the batch size using the decayed reduce factor
                    batch_size = max(round(batch_size * reduce_factor), 1)
                    print(f"Retrying with smaller batch size: {batch_size}")
                    reduce_factor *= reduce_factor_decay

    except Exception as e:
        print(f"Error occurred at batch starting at index {success_idx}: {str(e)}")
        print(f"Processing stopped at batch ending at index {success_idx - 1}")
        return success_idx

    return min(success_idx, len(input_column))
231 |
232 |
def check_func(func):
    """
    Validate that ``func`` is a batch function.

    A batch function must accept at least one argument, and its first
    parameter must be annotated as a list type (e.g. ``List[str]``,
    ``list[str]``, or a bare ``list``), since it will be called with a
    batch (list) of values rather than a single value.

    Parameters:
        func (callable): The function to validate.

    Raises:
        TypeError: If ``func`` takes no arguments, or its first parameter
            is unannotated or not annotated as a list type.
    """
    parameters = inspect.signature(func).parameters
    if not parameters:
        raise TypeError("The provided function does not take any arguments.")

    # Ensure func is a batch function that takes a list
    first_param = next(iter(parameters.values()))
    annotation = first_param.annotation

    # Accept both a bare `list` annotation and parameterized forms such as
    # List[str] / list[str], whose typing origin is `list`. The original
    # check rejected bare `list` because get_origin(list) is None.
    if annotation is not list and get_origin(annotation) is not get_origin(List):
        raise TypeError(
            "The provided function does not take a list or pandas.Series as input."
        )
246 |
--------------------------------------------------------------------------------
/databonsai/utils/logs.py:
--------------------------------------------------------------------------------
# utils/logger.py
#
# Shared logging setup for databonsai: one module-level logger with a
# single, consistent record format.
import logging

# Configure the root logger. WARNING keeps library output quiet by default;
# the format shows timestamp, logger name, level, and message.
# NOTE(review): basicConfig runs at import time and affects the root logger
# of the embedding application — confirm this is intended for library use.
logging.basicConfig(
    level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Create a global logger instance, named after this module so log records
# show where they originated.
logger = logging.getLogger(__name__)
11 |
--------------------------------------------------------------------------------
/docs/AnthropicProvider.md:
--------------------------------------------------------------------------------
1 | # AnthropicProvider
2 |
3 | The `AnthropicProvider` class is a provider class that interacts with
4 | Anthropic's Claude API for generating text completions. It supports exponential
5 | backoff retries (from tenacity's library) to handle temporary failures, which is
6 | particularly useful when dealing with large datasets.
7 |
8 | ## Initialization
9 |
10 | The `__init__` method initializes the `AnthropicProvider` with an API key and
11 | retry parameters.
12 |
13 | ### Parameters
14 |
15 | - `api_key (str)`: Anthropic API key.
16 | - `multiplier (int)`: The multiplier for the exponential backoff in retries
17 | (default: 1).
18 | - `min_wait (int)`: The minimum wait time between retries (default: 1).
19 | - `max_wait (int)`: The maximum wait time between retries (default: 60).
20 | - `max_tries (int)`: The maximum number of attempts before giving up (default:
21 | 10).
22 | - `model (str)`: The default model to use for text generation (default:
23 | "claude-3-haiku-20240307").
24 | - `temperature (float)`: The temperature parameter for text generation
25 | (default: 0).
26 |
27 | ## Methods
28 |
29 | ### `generate`
30 |
31 | The `generate` method generates a text completion using Anthropic's Claude API,
32 | given a system prompt and a user prompt. It is decorated with retry logic to
33 | handle temporary failures.
34 |
35 | #### Parameters
36 |
37 | - `system_prompt (str)`: The system prompt to provide context or instructions
38 | for the generation.
39 | - `user_prompt (str)`: The user's prompt, based on which the text completion
40 | is generated.
41 | - `max_tokens (int)`: The maximum number of tokens to generate in the response
42 | (default: 1000).
43 |
44 | #### Returns
45 |
46 | - `str`: The generated text completion.
47 |
48 | ## Retry Decorator
49 |
50 | The `retry_with_exponential_backoff` decorator is used to apply retry logic with
51 | exponential backoff to instance methods. It captures the `self` context to
52 | access instance attributes for retry configuration.
53 |
54 | ## Usage
55 |
56 | If your ANTHROPIC_API_KEY is defined in .env:
57 |
58 | ```python
59 | from databonsai.llm_providers import AnthropicProvider
60 |
61 | provider = AnthropicProvider()
62 | ```
63 |
64 | Or, provide the api key as an argument:
65 |
66 | ```python
67 | provider = AnthropicProvider(api_key="your_Anthropic_api_key")
68 | ```
69 |
70 | Other parameters, for example:
71 |
72 | ```python
73 | provider = AnthropicProvider(model="claude-3-opus-20240229", max_tries=5, max_wait=120)
74 | ```
75 |
--------------------------------------------------------------------------------
/docs/BaseCategorizer.md:
--------------------------------------------------------------------------------
1 | # BaseCategorizer
2 |
3 | The `BaseCategorizer` class provides functionality for categorizing input data
4 | utilizing a specified LLM provider. This class serves as a base for implementing
5 | categorization tasks where inputs are classified into predefined categories.
6 |
7 | ## Features
8 |
9 | - **Custom Categories**: Define your own categories for data classification.
10 | - **Input Validation**: Ensures the integrity of categories, examples, and
11 | input data for reliable categorization.
- **Batch Categorization**: Categorizes multiple inputs simultaneously for
  token savings.
14 |
15 | ## Attributes
16 |
17 | - `categories` (Dict[str, str]): A dictionary mapping category names to their
18 | descriptions. This structure allows for a clear definition of possible
19 | categories for classification.
20 | - `llm_provider` (LLMProvider): An instance of an LLM provider to be used for
21 | categorization.
22 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and
23 | their corresponding categories to improve categorization accuracy.
24 | - `strict` (bool): If True, raises an error when the predicted category is not
25 | one of the provided categories.
26 |
27 | ## Computed Fields
28 |
29 | - `system_message` (str): A system message used for single input
30 | categorization based on the provided categories and examples.
- `system_message_batch` (str): A system message used for batch input
  categorization based on the provided categories and examples.
32 | - `category_mapping` (Dict[int, str]): Mapping of category index to category
33 | name
34 | - `inverse_category_mapping` (Dict[str, int]): Mapping of category name to
35 | index
36 |
37 | ## Methods
38 |
39 | ### `categorize`
40 |
41 | Categorizes the input data using the specified LLM provider.
42 |
43 | #### Arguments
44 |
45 | - `input_data` (str): The text data to be categorized.
46 |
47 | #### Returns
48 |
49 | - `str`: The predicted category for the input data.
50 |
51 | #### Raises
52 |
53 | - `ValueError`: If the predicted category is not one of the provided
54 | categories.
55 |
56 | ### `categorize_batch`
57 |
58 | Categorizes a batch of input data using the specified LLM provider. For less
59 | advanced LLMs, call this method on batches of 3-5 inputs (depending on the
60 | length of the input data).
61 |
62 | #### Arguments
63 |
64 | - `input_data` (List[str]): A list of text data to be categorized.
65 |
66 | #### Returns
67 |
68 | - `List[str]`: A list of predicted categories for the input data.
69 |
70 | #### Raises
71 |
72 | - `ValueError`: If the predicted categories are not a subset of the provided
73 | categories or if the number of predicted categories does not match the
74 | number of input data.
75 |
76 | ## Usage
77 |
78 | Setup the LLM provider and categories (as a dictionary):
79 |
80 | ```python
81 | from databonsai.categorize import BaseCategorizer
82 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
83 |
84 | provider = OpenAIProvider() # Or AnthropicProvider()
85 | categories = {
86 | "Weather": "Insights and remarks about weather conditions.",
87 | "Sports": "Observations and comments on sports events.",
88 | "Celebrities": "Celebrity sightings and gossip",
89 | "Others": "Comments do not fit into any of the above categories",
90 | "Anomaly": "Data that does not look like comments or natural language",
91 | }
92 | ```
93 |
94 | Categorize your data:
95 |
96 | ```python
97 | categorizer = BaseCategorizer(
98 | categories=categories,
99 | llm_provider=provider,
    # strict=False, # Default True; set to False to allow categories not in the provided list
101 | )
102 | category = categorizer.categorize("It's been raining outside all day")
103 | print(category)
104 | ```
105 |
106 | Output:
107 |
108 | ```python
109 | Weather
110 | ```
111 |
112 | Categorize a few inputs:
113 |
114 | ```python
categories = categorizer.categorize_batch([
116 | "Storm Delays Government Budget Meeting, Weather and Politics Clash",
117 | "Olympic Star's Controversial Tweets Ignite Political Debate, Sports Meets Politics",
118 | "Local Football Hero Opens New Gym, Sports and Business Combine"])
119 | print(categories)
120 | ```
121 |
122 | Output:
123 |
124 | ```python
125 | ['Weather', 'Sports', 'Celebrities']
126 | ```
127 |
128 | Categorize a list of inputs (Use shorter lists for weaker LLMs):
129 |
130 | ```python
131 | headlines = [
132 | "Massive Blizzard Hits the Northeast, Thousands Without Power",
133 | "Local High School Basketball Team Wins State Championship After Dramatic Final",
134 | "Celebrated Actor Launches New Environmental Awareness Campaign",
135 | "President Announces Comprehensive Plan to Combat Cybersecurity Threats",
136 | "Tech Giant Unveils Revolutionary Quantum Computer",
137 | "Tropical Storm Alina Strengthens to Hurricane as It Approaches the Coast",
138 | "Olympic Gold Medalist Announces Retirement, Plans Coaching Career",
139 | "Film Industry Legends Team Up for Blockbuster Biopic",
140 | "Government Proposes Sweeping Reforms in Public Health Sector",
141 | "Startup Develops App That Predicts Traffic Patterns Using AI",
142 | ]
143 | categories = categorizer.categorize_batch(headlines)
144 | print(categories)
145 | ```
146 |
147 | Output:
148 |
149 | ```python
150 | ['Weather', 'Sports', 'Celebrities', 'Politics', 'Tech', 'Weather', 'Sports', 'Celebrities', 'Politics', 'Tech']
151 | ```
152 |
153 | Categorize a long list of inputs, or a dataframe column:
154 |
155 | ```python
156 | from databonsai.utils import apply_to_column_batch, apply_to_column
157 |
158 | categories = []
159 | success_idx = apply_to_column_batch(headlines, categories, categorizer.categorize, 3, 0)
160 | ```
161 |
162 | Without batching:
163 |
164 | ```python
165 | categories = []
166 | success_idx = apply_to_column(headlines, categories, categorizer.categorize)
167 | ```
168 |
--------------------------------------------------------------------------------
/docs/BaseTransformer.md:
--------------------------------------------------------------------------------
1 | # BaseTransformer
2 |
3 | The `BaseTransformer` class is a base class for transforming input data using a
4 | specified LLM provider. It provides a foundation for implementing data
5 | transformation tasks using language models.
6 |
7 | ## Features
8 |
9 | - **Transformation Prompts**: Define your own prompts to guide the
10 | transformation process.
11 | - **Input Validation**: Ensures the integrity of prompts, examples, and input
12 | data for reliable transformation.
13 | - **Few-Shot Learning**: Supports providing example inputs and responses to
14 | improve transformation accuracy.
15 | - **Batch Transformation**: Transforms multiple inputs simultaneously for
16 | token savings.
17 |
18 | ## Attributes
19 |
20 | - `prompt` (str): The prompt used to guide the transformation process. It
21 | provides instructions or context for the LLM provider to perform the desired
22 | transformation.
23 | - `llm_provider` (LLMProvider): An instance of an LLM provider to be used for
24 | transformation. The LLM provider is responsible for generating the
25 | transformed data based on the input data and the provided prompt.
26 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and
27 | their corresponding transformed outputs to improve transformation accuracy.
28 |
29 | ## Computed Fields
30 |
31 | - `system_message` (str): A system message used for single input
32 | transformation based on the provided prompt and examples.
33 | - `system_message_batch` (str): A system message used for batch input
34 | transformation based on the provided prompt and examples.
35 |
36 | ## Methods
37 |
38 | ### `transform`
39 |
40 | Transforms the input data using the specified LLM provider.
41 |
42 | #### Arguments
43 |
44 | - `input_data` (str): The text data to be transformed.
45 | - `max_tokens` (int, optional): The maximum number of tokens to generate in
46 | the response. Defaults to 1000.
47 | - `json` (bool): Whether to turn on JSON mode. Only works with OpenAIProvider for now
48 |
49 | #### Returns
50 |
51 | - `str`: The transformed data.
52 |
53 | ### `transform_batch`
54 |
55 | Transforms a batch of input data using the specified LLM provider.
56 |
57 | #### Arguments
58 |
59 | - `input_data` (List[str]): A list of text data to be transformed.
60 | - `max_tokens` (int, optional): The maximum number of tokens to generate in
61 | each response. Defaults to 1000.
62 |
63 | #### Returns
64 |
65 | - `List[str]`: A list of transformed data, where each element corresponds to
66 | the transformed version of the respective input data.
67 |
68 | #### Raises
69 |
70 | - `ValueError`: If the length of the output list does not match the length of
71 | the input list.
72 |
73 | ## Usage
74 |
75 | Prepare the transformer:
76 |
77 | ```python
78 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
79 | from databonsai.transform import BaseTransformer
80 |
81 | pii_remover = BaseTransformer(
82 | prompt="Replace any Personal Identity Identifiers (PII) in the given text with . PII includes any information that can be used to identify an individual, such as names, addresses, phone numbers, email addresses, social security numbers, etc.",
83 | llm_provider=AnthropicProvider(),
84 | examples=[
85 | {
86 | "example": "My name is John Doe and my phone number is (555) 123-4567.",
87 | "response": "My name is and my phone number is .",
88 | },
89 | {
90 | "example": "My email address is johndoe@gmail.com.",
91 | "response": "My email address is .",
92 | },
93 | ],
94 | )
95 | ```
96 |
97 | Run the transformation:
98 |
99 | ```python
100 | print(
101 | pii_remover.transform(
102 | "John Doe, residing at 1234 Maple Street, Anytown, CA, 90210, recently contacted customer support to report an issue. He provided his phone number, (555) 123-4567, and email address, johndoe@email.com, for follow-up communication."
103 | )
104 | )
105 | ```
106 |
107 | Output:
108 |
109 | ```python
110 | , residing at , , , , recently contacted customer support to report an issue. They provided their phone number, , and email address, , for follow-up communication.
111 | ```
112 |
Transform a list of data (use `apply_to_column_batch` for large datasets):
114 |
115 | ```python
116 | pii_texts = [
117 | "Just confirmed the reservation for Amanda Clark, Passport No. B2345678, traveling to Tokyo on May 15, 2024.",
118 | "Received payment from Michael Thompson, Credit Card ending in 4547, Billing Address: 45 Westview Lane, Springfield, IL.",
119 | "Application received from Julia Martinez, DOB 03/19/1994, SSN 210-98-7654, applying for the marketing position.",
120 | "Lease agreement finalized for Henry Wilson, Tenant ID: WILH12345, property at 89 Riverside Drive, Brooklyn, NY.",
121 | "Registration details for Lucy Davis, Student ID 20231004, enrolled in Advanced Chemistry, Fall semester.",
122 | "David Lee called about his insurance claim, Policy #9988776655, regarding the accident on April 5th.",
123 | "Booking confirmation for Sarah H. Richards, Flight AC202, Seat 14C, Frequent Flyer #GH5554321, departing June 12, 2024.",
124 | "Kevin Brown's gym membership has been renewed, Member ID: 654321, Phone: (555) 987-6543, Email: kbrown@example.com.",
125 | "Prescription ready for Emma Thomas, Health ID 567890123, prescribed by Dr. Susan Hill on April 10th, 2024.",
126 | "Alice Johnson requested a copy of her employment contract, Employee No. 112233, hired on August 1st, 2023.",
127 | ]
128 |
129 | cleaned_texts = pii_remover.transform_batch(pii_texts)
130 | print(cleaned_texts)
131 | ```
132 |
133 | Output:
134 |
135 | ```python
136 | ['Just confirmed the reservation for , Passport No. , traveling to Tokyo on May 15, 2024.', 'Received payment from , Credit Card ending in , Billing Address: 45 Westview Lane, Springfield, IL.', 'Application received from , DOB 03/19/1994, SSN , applying for the marketing position.', 'Lease agreement finalized for , Tenant ID: WILH12345, property at 89 Riverside Drive, Brooklyn, NY.', 'Registration details for , Student ID 20231004, enrolled in Advanced Chemistry, Fall semester.', ' called about his insurance claim, Policy #9988776655, regarding the accident on April 5th.', 'Booking confirmation for , Flight AC202, Seat 14C, Frequent Flyer #GH5554321, departing June 12, 2024.', "Kevin Brown's gym membership has been renewed, Member ID: 654321, Phone: (555) 987-6543, Email: kbrown@example.com.", 'Prescription ready for , Health ID 567890123, prescribed by Dr. Susan Hill on April 10th, 2024.', 'Alice Johnson requested a copy of her employment contract, Employee No. 112233, hired on August 1st, 2023.']
137 | ```
138 |
139 | Transform a long list of data or a dataframe column (with batching):
140 |
141 | ```python
142 | from databonsai.utils import apply_to_column_batch, apply_to_column
143 |
144 | cleaned_texts = []
145 | success_idx = apply_to_column_batch(
146 | pii_texts, cleaned_texts, pii_remover.transform_batch, 4, 0
147 | )
148 | ```
149 |
150 | Without batching:
151 |
```python
153 | cleaned_texts = []
154 | success_idx = apply_to_column(
155 | pii_texts, cleaned_texts, pii_remover.transform
156 | )
157 | ```
158 |
--------------------------------------------------------------------------------
/docs/ExtractTransformer.md:
--------------------------------------------------------------------------------
1 | # ExtractTransformer
2 |
3 | The `ExtractTransformer` class extends the `BaseTransformer` class and overrides
4 | the `transform` method to extract a given schema from the input data into a list
5 | of dictionaries. It allows for transforming input data into a structured format
6 | according to a specified schema.
7 |
8 | ## Features
9 |
10 | - **Custom Output Schema**: Define your own output schema to structure the
11 | transformed data.
12 | - **Input Validation**: Ensures the integrity of the output schema, examples,
13 | and input data for reliable transformation.
14 | - **Few-Shot Learning**: Supports providing example inputs and responses to
15 | improve transformation accuracy.
16 |
17 | ## Attributes
18 |
19 | - `output_schema` (Dict[str, str]): A dictionary representing the schema of
20 | the output dictionaries. It defines the expected keys and their
21 | corresponding value types in the transformed data.
22 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and
23 | their corresponding extracted outputs.
24 |
25 | ## Computed Fields
26 |
27 | - `system_message` (str): A system message used for single input
28 | transformation based on the provided prompt, output schema, and examples.
29 |
30 | ## Methods
31 |
32 | ### `transform`
33 |
34 | Transforms the input data into a list of dictionaries using the specified LLM
35 | provider.
36 |
37 | #### Arguments
38 |
39 | - `input_data` (str): The text data to be transformed.
40 | - `max_tokens` (int, optional): The maximum number of tokens to generate in
41 | the response. Defaults to 1000.
42 | - `json` (bool): Whether to turn on JSON mode. Only works with OpenAIProvider for now
43 |
44 | #### Returns
45 |
46 | - `List[Dict[str, str]]`: The transformed data as a list of dictionaries.
47 |
48 | #### Raises
49 |
50 | - `ValueError`: If the transformed data does not match the expected format or
51 | schema.
52 |
53 | ## Usage
54 |
Prepare an `ExtractTransformer` with a prompt and output schema:
56 |
57 | ```python
58 | from databonsai.llm_providers import OpenAIProvider
59 | from databonsai.transform import ExtractTransformer
60 |
61 | output_schema = {
62 | "question": "generated question about given information",
63 | "answer": "answer to the question, only using information from the given data",
64 | }
65 |
66 | qna = ExtractTransformer(
67 | prompt="Your goal is to create a set of questions and answers to help a person memorise every single detail of a document.",
68 | output_schema=output_schema,
69 | llm_provider=OpenAIProvider(),
70 | examples=[
71 | {
72 | "example": "Bananas are naturally radioactive due to their potassium content. They contain potassium-40, a radioactive isotope of potassium, which contributes to a tiny amount of radiation in every banana.",
73 | "response": str(
74 | [
75 | {
76 | "question": "Why are bananas naturally radioactive?",
77 | "answer": "Bananas are naturally radioactive due to their potassium content.",
78 | },
79 | {
80 | "question": "What is the radioactive isotope of potassium in bananas?",
81 | "answer": "The radioactive isotope of potassium in bananas is potassium-40.",
82 | },
83 | ]
84 | ),
85 | }
86 | ],
87 | )
88 | ```
89 |
90 | Here's the text we want to extract questions and answers from:
91 |
92 | ```python
93 | text = """ Sky-gazers across North America are in for a treat on April 8 when a total solar eclipse will pass over Mexico, the United States and Canada.
94 |
95 | The event will be visible to millions — including 32 million people in the US alone — who live along the route the moon’s shadow will travel during the eclipse, known as the path of totality. For those in the areas experiencing totality, the moon will appear to completely cover the sun. Those along the very center line of the path will see an eclipse that lasts between 3½ and 4 minutes, according to NASA.
96 |
97 | The next total solar eclipse won’t be visible across the contiguous United States again until August 2044. (It’s been nearly seven years since the “Great American Eclipse” of 2017.) And an annular eclipse won’t appear across this part of the world again until 2046."""
98 |
99 | print(qna.transform(text))
100 | ```
101 |
102 | Output:
103 |
104 | ```python
105 | [
106 | {
107 | "question": "When will the total solar eclipse pass over Mexico, the United States, and Canada?",
108 | "answer": "The total solar eclipse will pass over Mexico, the United States, and Canada on April 8.",
109 | },
110 | {
111 | "question": "What is the path of totality?",
112 | "answer": "The path of totality is the route the moon's shadow will travel during the eclipse where the moon will appear to completely cover the sun.",
113 | },
114 | {
115 | "question": "How long will the eclipse last for those along the very center line of the path of totality?",
116 | "answer": "For those along the very center line of the path of totality, the eclipse will last between 3½ and 4 minutes.",
117 | },
118 | {
119 | "question": "When will the next total solar eclipse be visible across the contiguous United States?",
120 | "answer": "The next total solar eclipse visible across the contiguous United States will be in August 2044.",
121 | },
122 | {
123 | "question": "When will an annular eclipse next appear across the contiguous United States?",
124 | "answer": "An annular eclipse won't appear across the contiguous United States again until 2046.",
125 | },
126 | ]
127 | ```
128 |
129 | Batching is not supported for ExtractTransformer yet.
130 |
--------------------------------------------------------------------------------
/docs/MultiCategorizer.md:
--------------------------------------------------------------------------------
1 | # MultiCategorizer
2 |
3 | The `MultiCategorizer` class is an extension of the `BaseCategorizer` class,
4 | providing functionality for categorizing input data into multiple categories
5 | using a specified LLM provider. This class overrides the `categorize` and
6 | `categorize_batch` methods to enable the prediction of multiple categories for a
7 | given input.
8 |
9 | ## Features
10 |
11 | - **Multiple Category Prediction**: Categorize input data into one or more
12 | predefined categories.
13 | - **Subset Validation**: Ensures that the predicted categories are a subset of
14 | the provided categories.
15 | - **Batch Categorization**: Categorizes multiple inputs simultaneously for
16 | token savings.
17 |
18 | ## Attributes
19 |
20 | - `categories` (Dict[str, str]): A dictionary mapping category names to their
21 | descriptions. This structure allows for a clear definition of possible
22 | categories for classification.
23 | - `llm_provider` (LLMProvider): An instance of an LLM provider to be used for
24 | categorization.
25 | - `examples` (Optional[List[Dict[str, str]]]): A list of example inputs and
26 | their corresponding categories to improve categorization accuracy. If there
27 | are multiple categories for an example, they should be separated by commas.
28 | - `strict` (bool): If True, raises an error when the predicted category is not
29 | one of the provided categories.
30 |
31 | ## Computed Fields
32 |
33 | - `system_message` (str): A system message used for single input
34 | categorization based on the provided categories and examples.
35 | - `system_message_batch` (str): A system message used for batch input
36 | categorization based on the provided categories and examples.
37 | - `category_mapping` (Dict[int, str]): Mapping of category index to category
38 | name
39 | - `inverse_category_mapping` (Dict[str, int]): Mapping of category name to
40 | index
41 |
42 | ## Methods
43 |
44 | ### `categorize`
45 |
46 | Categorizes the input data into multiple categories using the specified LLM
47 | provider.
48 |
49 | #### Arguments
50 |
51 | - `input_data` (str): The text data to be categorized.
52 |
53 | #### Returns
54 |
55 | - `str`: A string of categories for the input data, separated by commas.
56 |
57 | #### Raises
58 |
59 | - `ValueError`: If the predicted categories are not a subset of the provided
60 | categories.
61 |
62 | ### `categorize_batch`
63 |
64 | Categorizes a batch of input data into multiple categories using the specified
65 | LLM provider. For less advanced LLMs, call this method on batches of 3-5 inputs
66 | (depending on the length of the input data).
67 |
68 | #### Arguments
69 |
70 | - `input_data` (List[str]): A list of text data to be categorized.
71 |
72 | #### Returns
73 |
74 | - `List[str]`: A list of predicted categories for each input data. If there
75 | are multiple categories for an input, they will be separated by commas.
76 |
77 | #### Raises
78 |
79 | - `ValueError`: If the predicted categories are not a subset of the provided
80 | categories or if the number of predicted category sets does not match the
81 | number of input data.
82 |
83 | ## Usage
84 |
85 | Setup the LLM provider and categories (as a dictionary):
86 |
87 | ```python
88 | from databonsai.categorize import MultiCategorizer
89 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
90 |
91 | provider = OpenAIProvider() # Or AnthropicProvider()
92 | categories = {
93 | "Weather": "Insights and remarks about weather conditions.",
94 | "Sports": "Observations and comments on sports events.",
95 | "Celebrities": "Celebrity sightings and gossip",
96 | "Others": "Comments do not fit into any of the above categories",
97 | "Anomaly": "Data that does not look like comments or natural language",
98 | }
99 |
100 | tagger = MultiCategorizer(
101 | categories=categories,
102 | llm_provider=provider,
103 | examples=[
104 | {
105 | "example": "Big stormy skies over city causes league football game to be cancelled",
106 | "response": "Weather,Sports",
107 | },
108 | {
109 | "example": "Elon musk likes to play golf",
110 | "response": "Sports,Celebrities",
111 |         },
112 |     ],
113 |     # strict=False,  # Default True; set to False to allow categories not in the provided set
114 | )
115 | categories = tagger.categorize(
116 | "It's been raining outside all day, and I saw Elon Musk. 13rewfdsacw10289u(#!*@)"
117 | )
118 | print(categories)
119 | ```
120 |
121 | Output:
122 |
123 | ```python
124 | ['Weather', 'Celebrities', 'Anomaly']
125 | ```
126 |
127 | Categorize a list of data:
128 |
129 | ```python
130 | mixed_headlines = [
131 | "Storm Delays Government Budget Meeting, Weather and Politics Clash",
132 | "Olympic Star's Controversial Tweets Ignite Political Debate, Sports Meets Politics",
133 | "Local Football Hero Opens New Gym, Sports and Business Combine",
134 | "Tech CEO's Groundbreaking Climate Initiative, Technology and Environment at Forefront",
135 | "Celebrity Chef Fights for Food Security Legislation, Culinary Meets Politics",
136 | "Hollywood Biopic of Legendary Athlete Set to Premiere, Blending Sports and Cinema",
137 | "Massive Flooding Disrupts Local Elections, Intersection of Weather and Politics",
138 | "Tech Billionaire Invests in Sports Teams, Merging Business with Athletics",
139 | "Pop Star's Concert Raises Funds for Disaster Relief, Combining Music with Charity",
140 | "Film Festival Highlights Environmental Documentaries, Merging Cinema and Green Activism",
141 | ]
142 | categories = tagger.categorize_batch(mixed_headlines)
143 | ```
144 |
145 | Output:
146 |
147 | ```python
148 | ['Weather,Politics', 'Sports,Politics', 'Sports,Others', 'Tech', 'Politics,Celebrities', 'Sports,Celebrities', 'Weather,Politics', 'Tech,Sports', 'Celebrities,Others', 'Celebrities,Others']
149 | ```
150 |
151 | Categorize a long list of data, or a dataframe column (with batching):
152 |
153 | ```python
154 | from databonsai.utils import apply_to_column_batch, apply_to_column
155 |
156 | categories = []
157 | success_idx = apply_to_column_batch(
158 | input_column=mixed_headlines,
159 | output_column=categories,
160 | function=tagger.categorize_batch,
161 | batch_size=3,
162 | start_idx=0
163 | )
164 | ```
165 |
166 | Without batching:
167 |
168 | ```python
169 | categories = []
170 | success_idx = apply_to_column(
171 | input_column=mixed_headlines,
172 | output_column=categories,
173 | function=tagger.categorize
174 | )
175 | ```
176 |
--------------------------------------------------------------------------------
/docs/OllamaProvider.md:
--------------------------------------------------------------------------------
1 | # OllamaProvider
2 |
3 | The `OllamaProvider` class is a provider class that interacts with Ollama's API
4 | for generating text completions. Note that tokens are not counted for Ollama,
5 | and there is no retry logic (since it's not needed).
6 |
7 | ## Initialization
8 |
9 | The `__init__` method initializes the `OllamaProvider` with a default model,
10 | temperature, and an optional host for the Ollama API.
11 |
12 | ### Parameters
13 |
14 | - `model (str)`: The default model to use for text generation (default:
15 | "llama3").
16 | - `temperature (float)`: The temperature parameter for text generation
17 | (default: 0).
18 | - `host (str)`: The host URL for the Ollama API (optional).
19 |
20 | ## Methods
21 |
22 | ### `generate`
23 |
24 | The `generate` method generates a text completion using Ollama's API, with a
25 | given system prompt and a user prompt.
26 |
27 | #### Parameters
28 |
29 | - `system_prompt (str)`: The system prompt to provide context or instructions
30 | for the generation.
31 | - `user_prompt (str)`: The user's prompt, based on which the text completion
32 | is generated.
33 | - `max_tokens (int)`: The maximum number of tokens to generate in the response
34 | (default: 1000).
35 |
36 | #### Returns
37 |
38 | - `str`: The generated text completion.
39 |
40 | ## Usage
41 |
42 | To use a local Ollama instance with the default host:
43 |
44 | ```python
45 | from databonsai.llm_providers import OllamaProvider
46 |
47 | provider = OllamaProvider(model="llama3")
48 | ```
49 |
50 | or, provide a host
51 |
52 | ```python
53 | provider = OllamaProvider(host="http://localhost:11434", model="llama3")
54 | ```
55 |
56 | This uses ollama's python library under the hood.
57 |
--------------------------------------------------------------------------------
/docs/OpenAIProvider.md:
--------------------------------------------------------------------------------
1 | # OpenAIProvider
2 |
3 | The `OpenAIProvider` class is a provider class that interacts with OpenAI's API
4 | for generating text completions. It supports exponential backoff retries (from
5 | tenacity's library) to handle temporary failures, which is particularly useful
6 | when dealing with large datasets.
7 |
8 | ## Initialization
9 |
10 | The `__init__` method initializes the `OpenAIProvider` with an API key and retry
11 | parameters.
12 |
13 | ### Parameters
14 |
15 | - `api_key (str)`: OpenAI API key.
16 | - `multiplier (int)`: The multiplier for the exponential backoff in retries
17 | (default: 1).
18 | - `min_wait (int)`: The minimum wait time between retries (default: 1).
19 | - `max_wait (int)`: The maximum wait time between retries (default: 60).
20 | - `max_tries (int)`: The maximum number of attempts before giving up (default:
21 | 10).
22 | - `model (str)`: The default model to use for text generation (default:
23 | "gpt-4-turbo").
24 | - `temperature (float)`: The temperature parameter for text generation
25 | (default: 0).
26 |
27 | ## Methods
28 |
29 | ### `generate`
30 |
31 | The `generate` method generates a text completion using OpenAI's API, given a
32 | system prompt and a user prompt. It is decorated with retry logic to handle
33 | temporary failures.
34 |
35 | #### Parameters
36 |
37 | - `system_prompt (str)`: The system prompt to provide context or instructions
38 | for the generation.
39 | - `user_prompt (str)`: The user's prompt, based on which the text completion
40 | is generated.
41 | - `max_tokens (int)`: The maximum number of tokens to generate in the response
42 | (default: 1000).
43 | - `json (bool)`: Whether to use OpenAI's JSON response format (default:
44 | False).
45 |
46 | #### Returns
47 |
48 | - `str`: The generated text completion.
49 |
50 | ## Retry Decorator
51 |
52 | The `retry_with_exponential_backoff` decorator is used to apply retry logic with
53 | exponential backoff to instance methods. It captures the `self` context to
54 | access instance attributes for retry configuration.
55 |
56 | ## Usage
57 |
58 | If your OPENAI_API_KEY is defined in .env:
59 |
60 | ```python
61 | from databonsai.llm_providers import OpenAIProvider
62 |
63 | provider = OpenAIProvider()
64 |
65 | ```
66 |
67 | Or, provide the api key as an argument:
68 |
69 | ```python
70 | provider = OpenAIProvider(api_key="your_openai_api_key")
71 | ```
72 |
73 | Other parameters, for example:
74 |
75 | ```python
76 | provider = OpenAIProvider(model="gpt-4-turbo", max_tries=5, max_wait=120)
77 | ```
78 |
--------------------------------------------------------------------------------
/docs/Utils.md:
--------------------------------------------------------------------------------
1 | ## Util Methods
2 |
3 | ### `apply_to_column`
4 |
5 | Applies a function to each value in a column of a DataFrame or a normal Python
6 | list, starting from a specified index.
7 |
8 | #### Arguments
9 |
10 | - `input_column` (List): The column of the DataFrame or a normal Python list
11 | to which the function will be applied.
12 | - `output_column` (List): A list where the processed values will be stored.
13 | The function will mutate this list in-place.
14 | - `func` (Callable): The function to apply to each value in the column. It
15 | should take a single value as input and return a single value.
16 | - `start_idx` (int, optional): The index from which to start applying the
17 | function. Default is 0.
18 |
19 | #### Returns
20 |
21 | - `int`: The index of the last successfully processed value.
22 |
23 | #### Raises
24 |
25 | - `ValueError`: If the input or output column conditions are not met or if the
26 | starting index is out of bounds.
27 |
28 | ### `apply_to_column_batch`
29 |
30 | Applies a function to batches of values in a column of a DataFrame or a normal
31 | Python list, starting from a specified index.
32 |
33 | #### Arguments
34 |
35 | - `input_column` (List): The column of the DataFrame or a normal Python list
36 | to which the function will be applied.
37 | - `output_column` (List): A list where the processed values will be stored.
38 | The function will mutate this list in-place.
39 | - `func` (Callable): The batch function to apply to each batch of values in
40 | the column. It should take a list of values as input and return a list of
41 | processed values.
42 | - `batch_size` (int, optional): The size of each batch. Default is 5.
43 | - `start_idx` (int, optional): The index from which to start applying the
44 | function. Default is 0.
45 |
46 | #### Returns
47 |
48 | - `int`: The index of the last successfully processed batch.
49 |
50 | #### Raises
51 |
52 | - `ValueError`: If the input or output column conditions are not met or if the
53 | starting index or batch sizes are out of bounds.
54 |
55 | ### `apply_to_column_autobatch`
56 |
57 | Applies a function to the input column using adaptive batch processing, starting
58 | from a specified index and adjusting batch sizes based on success or failure.
59 |
60 | #### Arguments
61 |
62 | - `input_column` (List): The input column to be processed.
63 | - `output_column` (List): The list where the processed results will be stored.
64 | - `func` (Callable): The batch function used for processing.
65 | - `max_retries` (int): The maximum number of retries for failed batches.
66 | - `max_batch_size` (int): The maximum allowed batch size.
67 | - `batch_size` (int): The initial batch size.
68 | - `ramp_factor` (float): The factor by which the batch size is increased after
69 | a successful batch.
70 | - `ramp_factor_decay` (float): The decay rate for the ramp factor after each
71 | successful batch.
72 | - `reduce_factor` (float): The factor by which the batch size is reduced after
73 | a failed batch.
74 | - `reduce_factor_decay` (float): The decay rate for the reduce factor after
75 | each failed batch.
76 | - `start_idx` (int): The index from which to start processing the input
77 | column.
78 |
79 | #### Returns
80 |
81 | - `int`: The index of the last successfully processed item in the input
82 | column.
83 |
84 | #### Raises
85 |
86 | - `ValueError`: If the input or output column conditions are not met or if
87 | processing fails despite retries.
88 |
89 | ## Usage:
90 |
91 | ### AutoBatch for Larger datasets
92 |
93 | If you have a pandas dataframe or list, use `apply_to_column_autobatch`
94 |
95 | - Batching data for LLM api calls saves tokens by not sending the prompt for
96 | every row. However, too large a batch size / complex tasks can lead to
97 | errors. Naturally, the better the LLM model, the larger the batch size you
98 | can use.
99 |
100 | - This batching is handled adaptively (i.e., it will increase the batch size
101 | if the response is valid and reduce it if it's not, with a decay factor)
102 |
103 | Other features:
104 |
105 | - progress bar
106 | - returns the last successful index so you can resume from there, in case it
107 | exceeds max_retries
108 | - modifies your output list in place, so you don't lose any progress
109 |
110 | Retry Logic:
111 |
112 | - LLM providers have retry logic built in for API related errors. This can be
113 | configured in the provider.
114 | - The retry logic in the apply_to_column_autobatch is for handling invalid
115 | responses (e.g. unexpected category, different number of outputs, etc.)
116 |
117 | ```python
118 | from databonsai.utils import (
119 | apply_to_column_batch,
120 | apply_to_column,
121 | apply_to_column_autobatch,
122 | )
123 | import pandas as pd
124 |
125 | from databonsai.categorize import MultiCategorizer, BaseCategorizer
126 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
127 |
128 | provider = OpenAIProvider()
129 | categories = {
130 | "Weather": "Insights and remarks about weather conditions.",
131 | "Sports": "Observations and comments on sports events.",
132 | "Politics": "Political events related to governments, nations, or geopolitical issues.",
133 | "Celebrities": "Celebrity sightings and gossip",
134 | "Others": "Comments do not fit into any of the above categories",
135 | "Anomaly": "Data that does not look like comments or natural language",
136 | }
137 | few_shot_examples = [
138 | {"example": "Big stormy skies over city", "response": "Weather"},
139 | {"example": "The team won the championship", "response": "Sports"},
140 | {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
141 | ]
142 | categorizer = BaseCategorizer(
143 | categories=categories, llm_provider=provider, examples=few_shot_examples
144 | )
145 | category = categorizer.categorize("It's been raining outside all day")
146 | headlines = [
147 | "Massive Blizzard Hits the Northeast, Thousands Without Power",
148 | "Local High School Basketball Team Wins State Championship After Dramatic Final",
149 | "Celebrated Actor Launches New Environmental Awareness Campaign",
150 | "President Announces Comprehensive Plan to Combat Cybersecurity Threats",
151 | "Tech Giant Unveils Revolutionary Quantum Computer",
152 | "Tropical Storm Alina Strengthens to Hurricane as It Approaches the Coast",
153 | "Olympic Gold Medalist Announces Retirement, Plans Coaching Career",
154 | "Film Industry Legends Team Up for Blockbuster Biopic",
155 | "Government Proposes Sweeping Reforms in Public Health Sector",
156 | "Startup Develops App That Predicts Traffic Patterns Using AI",
157 | ]
158 | df = pd.DataFrame(headlines, columns=["Headline"])
159 | df["Category"] = None # Initialize it if it doesn't exist, as we modify it in place
160 | success_idx = apply_to_column_autobatch(
161 | df["Headline"],
162 | df["Category"],
163 | categorizer.categorize_batch,
164 | batch_size=3,
165 | start_idx=0,
166 | )
167 | ```
168 |
169 | There are many more options available for autobatch, such as setting a
170 | max_retries, decay factor, and more. Check the docs for more details.
171 |
172 | If it fails midway (even after exponential backoff), you can resume from the
173 | last successful index + 1.
174 |
175 | ```python
176 | success_idx = apply_to_column_autobatch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=10, start_idx=success_idx+1)
177 | ```
178 |
179 | This also works for regular python lists.
180 |
181 | Note that the better the LLM model, the greater the batch_size you can use
182 | (depending on the length of your inputs). If you're getting errors, reduce the
183 | batch_size, or use a better LLM model.
184 |
185 | To use it with batching, but with a fixed batch size:
186 |
187 | ```python
188 | success_idx = apply_to_column_batch( df["Headline"], df["Category"], categorizer.categorize_batch, batch_size=3, start_idx=0)
189 | ```
190 |
191 | To use it without batching:
192 |
193 | ```python
194 | success_idx = apply_to_column( df["Headline"], df["Category"], categorizer.categorize)
195 | ```
196 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os
import sys

import sphinx_rtd_theme

# Point Sphinx at the theme assets shipped with sphinx_rtd_theme.
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_theme_options = {
    "analytics_id": "UA-XXXXXXX-1",  # Provided by Google Analytics (placeholder — replace before release)
    "analytics_anonymize_ip": False,
    "logo_only": False,
    "display_version": True,
    "prev_next_buttons_location": "bottom",
    "style_external_links": False,
    "vcs_pageview_mode": "",
    "style_nav_header_background": "white",
    # Other options...
}
# Make the package importable so sphinx.ext.autodoc can resolve modules.
sys.path.insert(0, os.path.abspath("../../"))  # Add project root
sys.path.insert(0, os.path.abspath("../../databonsai"))

project = "databonsai"
copyright = "2024, Data Bonsai"
author = "Data Bonsai"
# Keep in sync with the package version in pyproject.toml / setup.py
# (was "0.2.0", which had drifted behind the released 0.8.0).
release = "0.8.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme"]

templates_path = ["_templates"]
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
46 |
--------------------------------------------------------------------------------
/docs/source/databonsai.categorize.rst:
--------------------------------------------------------------------------------
1 | databonsai.categorize package
2 | =============================
3 |
4 | Submodules
5 | ----------
6 |
7 | databonsai.categorize.base\_categorizer module
8 | ----------------------------------------------
9 |
10 | .. automodule:: databonsai.categorize.base_categorizer
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | databonsai.categorize.multi\_categorizer module
16 | -----------------------------------------------
17 |
18 | .. automodule:: databonsai.categorize.multi_categorizer
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | Module contents
24 | ---------------
25 |
26 | .. automodule:: databonsai.categorize
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/docs/source/databonsai.llm_providers.rst:
--------------------------------------------------------------------------------
1 | databonsai.llm\_providers package
2 | =================================
3 |
4 | Submodules
5 | ----------
6 |
7 | databonsai.llm\_providers.anthropic\_provider module
8 | ----------------------------------------------------
9 |
10 | .. automodule:: databonsai.llm_providers.anthropic_provider
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | databonsai.llm\_providers.llm\_provider module
16 | ----------------------------------------------
17 |
18 | .. automodule:: databonsai.llm_providers.llm_provider
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | databonsai.llm\_providers.openai\_provider module
24 | -------------------------------------------------
25 |
26 | .. automodule:: databonsai.llm_providers.openai_provider
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | Module contents
32 | ---------------
33 |
34 | .. automodule:: databonsai.llm_providers
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
--------------------------------------------------------------------------------
/docs/source/databonsai.rst:
--------------------------------------------------------------------------------
1 | databonsai package
2 | ==================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | databonsai.categorize
11 | databonsai.llm_providers
12 | databonsai.transform
13 | databonsai.utils
14 |
15 | Module contents
16 | ---------------
17 |
18 | .. automodule:: databonsai
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
--------------------------------------------------------------------------------
/docs/source/databonsai.transform.rst:
--------------------------------------------------------------------------------
1 | databonsai.transform package
2 | ============================
3 |
4 | Submodules
5 | ----------
6 |
7 | databonsai.transform.base\_transformer module
8 | ---------------------------------------------
9 |
10 | .. automodule:: databonsai.transform.base_transformer
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | databonsai.transform.extract\_transformer module
16 | ------------------------------------------------
17 | 
18 | .. automodule:: databonsai.transform.extract_transformer
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | Module contents
24 | ---------------
25 |
26 | .. automodule:: databonsai.transform
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/docs/source/databonsai.utils.rst:
--------------------------------------------------------------------------------
1 | databonsai.utils package
2 | ========================
3 |
4 | Module contents
5 | ---------------
6 |
7 | .. automodule:: databonsai.utils
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. databonsai documentation master file, created by
2 | sphinx-quickstart on Tue Apr 9 16:43:03 2024.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to databonsai's documentation!
7 | ======================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | databonsai
14 |
15 | .. Indices and tables
16 | .. ==================
17 |
18 | .. * :ref:`genindex`
19 | .. * :ref:`modindex`
20 | .. * :ref:`search`
21 |
--------------------------------------------------------------------------------
/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | databonsai
2 | ==========
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | databonsai
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "databonsai"
3 | version = "0.8.0"
4 | description = "A Python package to clean and curate your data with LLMs"
5 | authors = ["Alvin Ryanputra "]
6 |
7 | [tool.poetry.dependencies]
8 | python = "^3.8"
9 |
10 | openai = "^1.16.2"
11 | anthropic = "^0.23.1"
12 | tenacity = "^8.2.3"
13 | python-dotenv = "^1.0.1"
14 | pydantic = "^2.6.4"
15 | pydantic_core = "^2.16.3"
16 | ollama = "^0.1.0"
17 |
18 | [tool.poetry.dev-dependencies]
19 | # Add development dependencies here (if any)
20 | pytest = "^7.2.2"
21 |
22 | [build-system]
23 | requires = ["poetry-core"]
24 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name="databonsai",
    version="0.8.0",
    description="A package for cleaning and curating data with LLMs",
    # README is UTF-8; read it explicitly so builds don't depend on locale.
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    author="Alvin Ryanputra",
    author_email="databonsai.ai@gmail.com",
    url="https://github.com/databonsai/databonsai",
    packages=find_packages(),
    # Runtime dependencies; "anthropic" was listed twice — duplicate removed.
    install_requires=[
        "openai",
        "anthropic",
        "tenacity",
        "python-dotenv",
        "pydantic",
        "ollama",
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
    ],
)
33 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alvin-r/databonsai/3f2b7c58d0aa172251b5c3c321dea93038853e32/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_categorization.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from databonsai.categorize import BaseCategorizer, MultiCategorizer
3 | from databonsai.llm_providers import OpenAIProvider, AnthropicProvider
4 | from databonsai.utils import apply_to_column, apply_to_column_batch
5 | import pandas as pd
6 |
7 |
@pytest.fixture
def sample_categories():
    """Category-name -> description mapping shared by the categorization tests."""
    return dict(
        Weather="Insights and remarks about weather conditions.",
        Sports="Observations and comments on sports events.",
        Celebrities="Celebrity sightings and gossip",
        Others="Comments do not fit into any of the above categories",
        Anomaly="Data that does not look like comments or natural language",
    )
17 |
18 |
@pytest.fixture(
    params=[
        OpenAIProvider(model="gpt-4-turbo", max_tries=1),
        AnthropicProvider(max_tries=1),
    ]
)
def sample_provider(request):
    """Parametrized fixture: each test runs once per configured LLM provider."""
    provider = request.param
    return provider
27 |
28 |
@pytest.fixture
def sample_dataframe():
    """Single-column DataFrame of headlines, one per expected category."""
    headlines = [
        "Massive Blizzard Hits the Northeast, Thousands Without Power",
        "Local High School Basketball Team Wins State Championship After Dramatic Final",
        "Celebrated Actor Launches New Environmental Awareness Campaign",
        "Startup Develops App That Predicts Traffic Patterns Using AI",
        "asdfoinasedf'awesdf",
    ]
    return pd.DataFrame({"text": headlines})
42 |
43 |
@pytest.fixture
def sample_list():
    """Plain Python list of headlines, one per expected category."""
    headlines = (
        "Massive Blizzard Hits the Northeast, Thousands Without Power",
        "Local High School Basketball Team Wins State Championship After Dramatic Final",
        "Celebrated Actor Launches New Environmental Awareness Campaign",
        "Startup Develops App That Predicts Traffic Patterns Using AI",
        "asdfoinasedf'awesdf",
    )
    return list(headlines)
53 |
54 |
def test_base_categorizer(sample_categories, sample_provider):
    """
    Test the BaseCategorizer class: each input maps to its expected category.
    """
    categorizer = BaseCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {"example": "Big stormy skies over city", "response": "Weather"},
            {"example": "The team won the championship", "response": "Sports"},
            {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
        ],
    )

    # One probe per category, including the catch-all and anomaly buckets.
    expectations = [
        ("It's raining heavily today.", "Weather"),
        ("The football match was exciting!", "Sports"),
        ("I saw Emma Watson at the mall.", "Celebrities"),
        ("This is a random comment.", "Others"),
        ("1234567890!@#$%^&*()", "Anomaly"),
    ]
    for text, expected in expectations:
        assert categorizer.categorize(text) == expected
74 |
75 |
def test_base_categorizer_batch(sample_categories, sample_provider, sample_list):
    """
    Test the BaseCategorizer class with a batch of examples.

    Categorizes every item of sample_list in one batch call and checks
    that the results line up element-wise with the expected categories.
    """
    categorizer = BaseCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {"example": "Big stormy skies over city", "response": "Weather"},
            {"example": "The team won the championship", "response": "Sports"},
            {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
        ],
    )
    # NOTE(review): removed leftover debug print of
    # categorizer.system_message_batch — it cluttered test output and
    # asserted nothing.
    assert categorizer.categorize_batch(sample_list) == [
        "Weather",
        "Sports",
        "Celebrities",
        "Others",
        "Anomaly",
    ]
97 |
98 |
def test_apply_to_column(sample_categories, sample_provider, sample_dataframe):
    """
    Test the apply_to_column function.

    Runs row-by-row categorization over the 'text' column and checks both
    the reported success index and the resulting 'category' column.
    """
    categorizer = BaseCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {"example": "Big stormy skies over city", "response": "Weather"},
            {"example": "The team won the championship", "response": "Sports"},
            {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
        ],
    )

    df = sample_dataframe.copy()
    df["category"] = None

    rows_done = apply_to_column(df["text"], df["category"], categorizer.categorize)

    expected = ["Weather", "Sports", "Celebrities", "Others", "Anomaly"]
    assert rows_done == 5
    assert df["category"].tolist() == expected
126 |
127 |
def test_apply_to_column_batch(sample_categories, sample_provider, sample_dataframe):
    """
    Test the apply_to_column_batch function.

    Categorizes the 'text' column in batches of two and checks both the
    reported success index and the resulting 'category' column.
    """
    categorizer = BaseCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {"example": "Big stormy skies over city", "response": "Weather"},
            {"example": "The team won the championship", "response": "Sports"},
            {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
        ],
    )

    df = sample_dataframe.copy()
    df["category"] = None

    rows_done = apply_to_column_batch(
        df["text"],
        df["category"],
        categorizer.categorize_batch,
        batch_size=2,
    )

    expected = ["Weather", "Sports", "Celebrities", "Others", "Anomaly"]
    assert rows_done == 5
    assert df["category"].tolist() == expected
157 |
158 |
def test_apply_to_column_batch_start_idx(
    sample_categories, sample_provider, sample_dataframe
):
    """
    Test the apply_to_column_batch function with a start index.

    With start_idx=1 the first row must be left untouched (still None)
    while the remaining rows are categorized in batches of two.
    """
    categorizer = BaseCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {"example": "Big stormy skies over city", "response": "Weather"},
            {"example": "The team won the championship", "response": "Sports"},
            {"example": "I saw a famous rapper at the mall", "response": "Celebrities"},
        ],
    )

    df = sample_dataframe.copy()
    df["category"] = None

    rows_done = apply_to_column_batch(
        df["text"],
        df["category"],
        categorizer.categorize_batch,
        batch_size=2,
        start_idx=1,
    )

    # Row 0 is skipped because processing starts at index 1.
    expected = [None, "Sports", "Celebrities", "Others", "Anomaly"]
    assert rows_done == 5
    assert df["category"].tolist() == expected
194 |
195 |
def test_multi_categorizer(sample_categories, sample_provider):
    """
    Test the MultiCategorizer class.

    The categorizer returns a comma-joined string of labels; assertions
    compare as sets so label order does not matter.
    """
    categorizer = MultiCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {
                "example": "Big stormy skies over city causes league football game to be cancelled",
                "response": "Weather,Sports",
            },
            {
                "example": "Elon musk likes to play golf",
                "response": "Sports,Celebrities",
            },
        ],
    )

    def tags_of(text):
        # Split the comma-joined prediction into an order-insensitive set.
        return set(categorizer.categorize(text).split(","))

    assert tags_of("It's raining and I saw Emma Watson.") == {"Weather", "Celebrities"}
    assert tags_of("The football match was exciting and it's sunny!") == {
        "Sports",
        "Weather",
    }
226 |
227 |
def test_multi_categorizer_batch(sample_categories, sample_provider):
    """
    Test the MultiCategorizer class with a batch of examples.

    Each prediction is a comma-joined string of labels; comparisons use
    sets so the order of labels within a prediction is irrelevant.
    """
    categorizer = MultiCategorizer(
        categories=sample_categories,
        llm_provider=sample_provider,
        examples=[
            {
                "example": "Big stormy skies over city causes league football game to be cancelled",
                "response": "Weather,Sports",
            },
            {
                "example": "Elon musk likes to play golf",
                "response": "Sports,Celebrities",
            },
        ],
    )

    inputs = [
        "Thunderstorms cause major delays in baseball tournament",
        "Famous actor spotted at local charity basketball game",
        "Heavy rainfall leads to postponement of soccer match",
    ]
    expected_sets = [
        {"Weather", "Sports"},
        {"Celebrities", "Sports"},
        {"Weather", "Sports"},
    ]

    predictions = categorizer.categorize_batch(inputs)
    predicted_sets = [set(prediction.split(",")) for prediction in predictions]

    assert len(predicted_sets) == len(expected_sets)
    for got, want in zip(predicted_sets, expected_sets):
        assert got == want
270 |
--------------------------------------------------------------------------------