├── opendatagen ├── __init__.py ├── .DS_Store ├── files │ ├── scorer.txt │ ├── ner.txt │ ├── basic.txt │ ├── breath.txt │ ├── concretizing.txt │ ├── deep.txt │ ├── step_reasoning.txt │ └── template.json ├── examples │ ├── faq_wikipedia.json │ └── more_agents_is_all_you_need.json ├── anonymizer.py ├── utils.py ├── template.py ├── agent.py ├── data_generator.py └── model.py ├── .DS_Store ├── .gitignore ├── .vscode └── launch.json ├── LICENCE.md ├── .github └── FUNDING.yml ├── LICENSE ├── setup.py └── README.md /opendatagen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/open-datagen/HEAD/.DS_Store -------------------------------------------------------------------------------- /opendatagen/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thoddnn/open-datagen/HEAD/opendatagen/.DS_Store -------------------------------------------------------------------------------- /opendatagen/files/scorer.txt: -------------------------------------------------------------------------------- 1 | We would like you to evaluate and rate the difficulty and complexity of the following question. 2 | You should give an overall score on a scale of 1 to 10, where a higher score indicates higher difficulty and complexity. 3 | You must give only the score, without any reasoning. 4 | 5 | ## Question: 6 | {prompt} 7 | ## Score: -------------------------------------------------------------------------------- /opendatagen/files/ner.txt: -------------------------------------------------------------------------------- 1 | Given the context of the provided prompt, rewrite the string between curly brackets with a better common-sense value, if needed. 2 | 3 | Examples: 4 | 5 | ###Input### 6 | My first name is {bank_account}. 7 | ###Output### 8 | My first name is {first_name}. 9 | 10 | ###Input### 11 | '{person}' is a restaurant. 12 | ###Output### 13 | '{restaurant_name}' is a restaurant. -------------------------------------------------------------------------------- /opendatagen/files/basic.txt: -------------------------------------------------------------------------------- 1 | You are asked to come up with a diverse prompt variation of the following prompt: 2 | """ 3 | {prompt} 4 | """ 5 | 6 | The generated prompt will be given to a GPT model and we will evaluate the GPT model on completing it. 7 | Here are the requirements: 8 | 1. Try not to repeat the verb for each sub-topic to maximize diversity. 9 | 2. The language used for the generated prompt should also be diverse. 10 | 3. The type of prompt should be diverse. 11 | 4. The prompt should be 1 to 2 sentences long.
12 | 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | *.txt 4 | template.json 5 | *.jsonl 6 | main.py 7 | *.vscode 8 | *.in 9 | *.egg-info/ 10 | dist/ 11 | build/ 12 | 13 | # OS generated files 14 | .DS_Store 15 | Thumbs.db 16 | 17 | # local testing 18 | exp_opendatagen.py 19 | 20 | # Virtual environment 21 | venv/ 22 | *.venv 23 | .venv/ 24 | 25 | # Environment variables 26 | .env 27 | 28 | # Exclude all CSV, IPYNB, and JSON files 29 | *.csv 30 | *.ipynb 31 | *.json 32 | *.mp3 33 | *.jpeg 34 | *.png 35 | *.wav 36 | 37 | *agent_ui.py 38 | 39 | !/opendatagen/files/*.txt 40 | 41 | !/opendatagen/examples/*.json -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | 8 | { 9 | "name": "Python: Current File", 10 | "type": "python", 11 | "request": "launch", 12 | "program": "${workspaceFolder}/opendatagen/main.py", 13 | "console": "integratedTerminal", 14 | "justMyCode": true 15 | } 16 | ] 17 | } -------------------------------------------------------------------------------- /opendatagen/files/breath.txt: -------------------------------------------------------------------------------- 1 | 2 | I want you act as a Prompt Creator. 3 | 4 | Your goal is to draw inspiration from the #Given Prompt# to create a brand new prompt. 5 | This new prompt should belong to the same domain as the #Given Prompt# but be even more rare. 6 | 7 | The LENGTH and difficulty level of the #Created Prompt# should be similar to that of the #Given Prompt#. The #Created Prompt# must be reasonable and must be understood and responded by humans. 8 | ‘#Given Prompt#’, ‘#Created Prompt#’, ‘given prompt’ and ‘created prompt’ are not allowed to appear in #Created Prompt#. 9 | 10 | #Given Prompt#: 11 | {prompt} 12 | 13 | #Created Prompt#: 14 | -------------------------------------------------------------------------------- /LICENCE.md: -------------------------------------------------------------------------------- 1 | # LICENSE 2 | 3 | ## GNU GENERAL PUBLIC LICENSE 4 | 5 | Version 3, 29 June 2007 6 | 7 | Copyright (C) 2023 MY LITTLE PLANET 8 | 9 | This program is free software: you can redistribute it and/or modify 10 | it under the terms of the GNU General Public License as published by 11 | the Free Software Foundation, either version 3 of the License, or 12 | (at your option) any later version. 13 | 14 | This program is distributed in the hope that it will be useful, 15 | but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | GNU General Public License for more details. 18 | 19 | You should have received a copy of the GNU General Public License 20 | along with this program. If not, see . 
21 | 22 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [thoddnn] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /opendatagen/files/concretizing.txt: -------------------------------------------------------------------------------- 1 | 2 | I want you act as a Prompt Rewriter. 3 | Your objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., ChatGPT and GPT4) a bit harder to handle. 4 | But the rewritten prompt must be reasonable and must be understood and responded by humans. 5 | Your rewriting cannot omit the non-text parts such as the table and code in #Given Prompt#:. Also, please do not omit the input in #Given Prompt#. 6 | You SHOULD complicate the given prompt using the following method: 7 | Please replace general concepts with more specific concepts. or 8 | You should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #Given Prompt#. 9 | ‘#Given Prompt#’, ‘#Rewritten Prompt#’, ‘given prompt’ and ‘rewritten prompt’ are not allowed to appear in #Rewritten Prompt# 10 | #Given Prompt#: 11 | {prompt} 12 | #Rewritten Prompt#: -------------------------------------------------------------------------------- /opendatagen/files/deep.txt: -------------------------------------------------------------------------------- 1 | I want you act as a Prompt Rewriter. 2 | Your objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., ChatGPT and GPT4) a bit harder to handle. 3 | But the rewritten prompt must be reasonable and must be understood and responded by humans. 4 | Your rewriting cannot omit the non-text parts such as the table and code in #Given Prompt#:. Also, please do not omit the input in #Given Prompt#. 5 | You SHOULD complicate the given prompt using the following method: 6 | If #Given Prompt# contains inquiries about certain issues, the depth and breadth of the inquiry can be increased. or 7 | You should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #Given Prompt#. 
8 | ‘#Given Prompt#’, ‘#Rewritten Prompt#’, ‘given prompt’ and ‘rewritten prompt’ are not allowed to appear in #Rewritten Prompt# 9 | #Given Prompt#: 10 | {prompt} 11 | #Rewritten Prompt#: 12 | -------------------------------------------------------------------------------- /opendatagen/files/step_reasoning.txt: -------------------------------------------------------------------------------- 1 | 2 | I want you act as a Prompt Rewriter. 3 | Your objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., ChatGPT and GPT4) a bit harder to handle. 4 | But the rewritten prompt must be reasonable and must be understood and responded by humans. 5 | Your rewriting cannot omit the non-text parts such as the table and code in #Given Prompt#:. Also, please do not omit the input in #Given Prompt#. 6 | You SHOULD complicate the given prompt using the following method: 7 | If #Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning. 8 | You should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #Given Prompt#. 9 | ‘#Given Prompt#’, ‘#Rewritten Prompt#’, ‘given prompt’ and ‘rewritten prompt’ are not allowed to appear in #Rewritten Prompt# 10 | #Given Prompt#: 11 | {prompt} 12 | #Rewritten Prompt#: -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MY LITTLE PLANET, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | 7 | setup( 8 | name='opendatagen', 9 | author="Thomas DORDONNE", 10 | author_email="dordonne.thomas@gmail.com", 11 | description="Synthetic data generation to improve AI and humans", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/thoddnn/open-datagen", 15 | version='0.0.4', 16 | packages=find_packages(), 17 | include_package_data=True, 18 | install_requires=[ 19 | 'openai==1.5.0', 20 | 'python-dotenv>=0.17.1', 21 | 'numpy>=1.23.4', 22 | 'trafilatura>=0.9.1', 23 | 'requests>=2.29.0', 24 | 'tenacity>=8.2.2', 25 | 'pydantic>=2', 26 | 'spacy>=3', 27 | 'tiktoken>=0.5', 28 | 'PyPDF2>=3', 29 | 'pandas>=2', 30 | 'datasets>=2', 31 | 'mistralai', 32 | 'jsonschema', 33 | 'llama-cpp-python>=0.2.24', 34 | 'openai-whisper', 35 | 'elevenlabs==0.3.0b0', 36 | 'Pillow', 37 | 'torch>=2.2.0', 38 | 'audiocraft', 39 | 'anthropic', 40 | 'bark @ git+https://github.com/suno-ai/bark.git' 41 | ] 42 | ) 43 | -------------------------------------------------------------------------------- /opendatagen/examples/faq_wikipedia.json: -------------------------------------------------------------------------------- 1 | { 2 | "opendatagen": { 3 | "name":"Wikipedia FAQ", 4 | "description": "", 5 | "prompt": "#Text:\n'''\n{{text}}\n'''\nQuestion:\n'''\n{{question}}\n'''", 6 | "prompt_variation_number": 0, 7 | "variables": { 8 | "text": { 9 | "name": "Wikipedia content", 10 | "generation_number": 10, 11 | "get_value_from_huggingface": { 12 | "dataset_name": "20220301.en", 13 | "dataset_path": "wikipedia", 14 | "column_name": "text", 15 | "max_tokens": 1024 16 | } 17 | }, 18 | "question": { 19 | "name": "Question", 20 | "generation_number": 2, 21 | "independent_values":false, 22 | "ensure_model_diversity": true, 23 | "models": [ 24 | { 25 | "openai_chat_model": { 26 | "name": "gpt-3.5-turbo-0125", 27 | "user_prompt": [ 28 | {"role":"system", "content":"You are QuestionGPT and must write high quality question about the given text."}, 29 | {"role":"user", "content": "Write a question about the following text:\n'''\n{{text}}\n'''\n"} 30 | ], 31 | "temperature": [0], 32 | "max_tokens": 128 33 | } 34 | } 35 | ] 36 | 37 | } 38 | 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /opendatagen/files/template.json: -------------------------------------------------------------------------------- 1 | { 2 | "factuality": { 3 | "description": "Factuality", 4 | "prompt": "Given the following text:\n\n'''{wikipedia_content}'''\n\nAnswer to this factually checkable question:\n'''{question}'''.", 5 | "completion": "Answer: '''{answer}'''. 
Rate the answer out of 10: {score}", 6 | "prompt_variation_number": 0, 7 | "variables": { 8 | "wikipedia_content": { 9 | "name": "Wikipedia content", 10 | "generation_number": 1, 11 | "get_value_from_huggingface": { 12 | "dataset_name": "20220301.en", 13 | "dataset_path": "wikipedia", 14 | "column_name": "text", 15 | "max_tokens": 512 16 | } 17 | }, 18 | "question": { 19 | "name": "Factually checkable question", 20 | "generation_number": 3, 21 | "models": [ 22 | { 23 | "openai_chat_model": { 24 | "name": "gpt-3.5-turbo-1106", 25 | "temperature": 0, 26 | "max_tokens": 128 27 | } 28 | } 29 | ] 30 | }, 31 | "answer": { 32 | "name": "Short answer to the question", 33 | "generation_number": 1, 34 | "models": [ 35 | { 36 | "openai_instruct_model": { 37 | "name": "gpt-3.5-turbo-instruct", 38 | "temperature": 0, 39 | "max_tokens": 128, 40 | "start_with": ["Answer:"] 41 | } 42 | } 43 | ] 44 | 45 | }, 46 | "score": { 47 | "name": "Score", 48 | "generation_number": 1, 49 | "note": ["You must answer with an integer."], 50 | "models": [ 51 | { 52 | "openai_chat_model": { 53 | "name": "gpt-3.5-turbo-1106", 54 | "temperature": 0, 55 | "max_tokens": 5 56 | } 57 | } 58 | ] 59 | } 60 | 61 | } 62 | } 63 | } -------------------------------------------------------------------------------- /opendatagen/examples/more_agents_is_all_you_need.json: -------------------------------------------------------------------------------- 1 | { 2 | "opendatagen": { 3 | "name":"More Agents is All you need", 4 | "description": "https://arxiv.org/abs/2402.05120", 5 | "prompt": "Question:\n'''\n{{question}}\n'''\n\nAnswer:\n'''\n{{answer}}\n'''", 6 | "prompt_variation_number": 0, 7 | "variables": { 8 | "question": { 9 | "name": "Question", 10 | "generation_number": 1, 11 | "independent_values":false, 12 | "ensure_model_diversity": true, 13 | "models": [ 14 | { 15 | "openai_chat_model": { 16 | "name": "gpt-3.5-turbo-0125", 17 | "user_prompt": [ 18 | {"role":"system", "content":"Write a high quality but complex question about astronomy. No verbose."} 19 | ], 20 | "temperature": [1], 21 | "max_tokens": 128 22 | } 23 | } 24 | ] 25 | 26 | }, 27 | "answer": { 28 | "name": "Answer", 29 | "generation_number": 5, 30 | "independent_values":true, 31 | "ensure_model_diversity": true, 32 | "models": [ 33 | { 34 | "openai_chat_model": { 35 | "name": "gpt-3.5-turbo-0125", 36 | "user_prompt": [ 37 | {"role":"system", "content":"Answer to the following question. 
Don't be too verbose"}, 38 | {"role":"assistant", "content":"{{question}}"} 39 | ], 40 | "temperature": [0, 0.5, 1, 1.2], 41 | "max_tokens": 256 42 | } 43 | } 44 | ], 45 | "junction":{ 46 | "delete_branch":true, 47 | "model": { 48 | "openai_chat_model": { 49 | "name": "gpt-3.5-turbo-0125", 50 | "user_prompt": [ 51 | {"role":"system", "content":"Given the values provided, rewrite the best answer."} 52 | ], 53 | "temperature": [1], 54 | "max_tokens": 256 55 | } 56 | } 57 | 58 | } 59 | } 60 | 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /opendatagen/anonymizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import spacy 3 | from opendatagen.model import OpenAIChatModel, ModelName 4 | from opendatagen.utils import load_file 5 | 6 | class Anonymizer: 7 | 8 | NER_PLACEHOLDER = { 9 | "PERSON": "{person}", 10 | "ORG": "{organization}", 11 | "GPE": "{location}", 12 | "DATE": "{date}", 13 | "TIME": "{time}", 14 | "NORP": "{group}", 15 | "FAC": "{facility}", 16 | "LOC": "{location}", 17 | "PRODUCT": "{product}", 18 | "EVENT": "{event}", 19 | "WORK_OF_ART": "{artwork}", 20 | "LAW": "{law}", 21 | "LANGUAGE": "{language}", 22 | "MONEY": "{money}", 23 | "PERCENT": "{percentage}", 24 | "ORDINAL": "{ordinal}", 25 | "CARDINAL": "{number}", 26 | # Add more if needed 27 | } 28 | 29 | REGEX_PATTERN = { 30 | "{phone_number}": r"\+?\d{1,4}?[-.\s]?\(?\d{1,3}?\)?[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}", 31 | "{email}": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", 32 | "{credit_card_pattern}": r"\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}", 33 | "{address_pattern}": r"\d{1,5}\s\w+(\s\w+)*,\s\w+,\s\w+(\s\w+)*", 34 | "{date_pattern}": r"(\d{4}[-/]\d{1,2}[-/]\d{1,2})|(\d{1,2}[-/]\d{1,2}[-/]\d{4})", 35 | "{time_pattern}": r"(?:[01]\d|2[0-3]):[0-5]\d", 36 | "{ipv4_pattern}": r"\b(?:\d{1,3}\.){3}\d{1,3}\b", 37 | "{url_pattern}": r"https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)", 38 | "{ssn_pattern}": r"\d{3}-\d{2}-\d{4}", 39 | "{license_plate_pattern}": r"[A-Z0-9]{2,}-[A-Z0-9]{2,}", 40 | "{zip_code_pattern}": r"\d{5}(-\d{4})?", 41 | "{vin_pattern}": r"[A-HJ-NPR-Z0-9]{17}", 42 | "{iban_pattern}": r"[A-Z]{2}\d{2}[A-Z0-9]{1,30}", 43 | "{driver_license_pattern}": r"[A-Z]{1,2}-\d{4,9}" 44 | } 45 | 46 | 47 | 48 | def __init__(self, completion_model:OpenAIChatModel): 49 | 50 | self.nlp = spacy.load("en_core_web_sm") 51 | self.ner_prompt = load_file("files/ner.txt") 52 | self.completion_model = completion_model 53 | 54 | def regex_anonymization(self, text: str) -> str: 55 | 56 | for replacement, pattern in self.REGEX_PATTERN.items(): 57 | text = re.sub(pattern, replacement, text) 58 | 59 | return text 60 | 61 | def ner_anonymization(self, text: str) -> str: 62 | doc = self.nlp(text) 63 | for entity in doc.ents: 64 | placeholder = self.NER_PLACEHOLDER.get(entity.label_) 65 | if placeholder: 66 | text = text.replace(entity.text, placeholder) 67 | return text 68 | 69 | def llm_anonymization(self, text: str) -> str: 70 | 71 | completion = self.completion_model.ask( 72 | system_prompt=self.ner_prompt, 73 | user_prompt=text, 74 | max_tokens=126, 75 | temperature=0 76 | ) 77 | 78 | return completion 79 | 80 | def anonymize(self, text: str) -> str: 81 | 82 | text = self.regex_anonymization(text) 83 | text = self.ner_anonymization(text) 84 | return self.llm_anonymization(text) 85 | 86 | 87 | 
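# --- Illustrative usage sketch (not part of the original repository files) ---
# A minimal example of how the Anonymizer above might be wired up. The exact constructor of
# OpenAIChatModel lives in opendatagen/model.py (not shown here), so the `name` argument below
# is an assumption based on the JSON templates; it also assumes OPENAI_API_KEY is set and that
# the spaCy model en_core_web_sm has been downloaded (python -m spacy download en_core_web_sm).
#
# from opendatagen.anonymizer import Anonymizer
# from opendatagen.model import OpenAIChatModel
#
# model = OpenAIChatModel(name="gpt-3.5-turbo-0125")  # `name` field assumed from the example templates
# anonymizer = Anonymizer(completion_model=model)
# clean = anonymizer.anonymize("Email John Smith at john.smith@example.com or +1 555-123-4567.")
# print(clean)  # placeholders such as {email} and {phone_number} replace the sensitive values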
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⬜️ Open Datagen ⬜️ 2 | 3 | **Open Datagen** is a Data Preparation Tool designed to build Controllable AI Systems. 4 | 5 | It offers improvements for: 6 | 7 | **RAG**: Generate large Q&A datasets to improve your Retrieval strategies. 8 | 9 | **Evals**: Create unique, “unseen” datasets to robustly test your models and avoid overfitting. 10 | 11 | **Fine-Tuning**: Produce large, low-bias, and high-quality datasets to get better models after the fine-tuning process. 12 | 13 | **Guardrails**: Generate red-teaming datasets to strengthen the security and robustness of your Generative AI applications against attacks. 14 | 15 | ## Additional Features 16 | 17 | - Use external sources to generate high-quality synthetic data (local files, Hugging Face datasets, and the Internet) 18 | 19 | - Data anonymization 20 | 21 | - Open-source model support + local inference 22 | 23 | - Decontamination 24 | 25 | - Tree of thought 26 | 27 | - Multimodality (Text, Audio and Image) 28 | 29 | - No-code dataset generation with the Open DataGen UI ⤵️⤵️⤵️ 30 | 31 | [![Watch the video](https://img.youtube.com/vi/Dp6jvJBuUA0/0.jpg)](https://www.youtube.com/watch?v=Dp6jvJBuUA0) 32 | 33 | 34 | ## Installation 35 | ```bash 36 | conda create -n opendataenv python=3.9.6 37 | ``` 38 | 39 | ```bash 40 | pip install --upgrade opendatagen 41 | ``` 42 | 43 | ### Setting up your API keys 44 | 45 | ```bash 46 | export OPENAI_API_KEY='your_openai_api_key' #(using openai>=1.2) 47 | export MISTRAL_API_KEY='your_mistral_api_key' 48 | export TOGETHER_API_KEY='your_together_api_key' 49 | export ANYSCALE_API_KEY='your_anyscale_api_key' 50 | export ELEVENLABS_API_KEY='your_elevenlabs_api_key' 51 | export SERPLY_API_KEY='your_serply_api_key' #Google Search API 52 | ``` 53 | 54 | ## Usage 55 | 56 | Example: Generate a low-biased FAQ dataset based on Wikipedia content 57 | 58 | ```python 59 | from opendatagen.template import TemplateManager 60 | from opendatagen.data_generator import DataGenerator 61 | 62 | output_path = "opendatagen.csv" 63 | template_name = "opendatagen" 64 | manager = TemplateManager(template_file_path="faq_wikipedia.json") 65 | template = manager.get_template(template_name=template_name) 66 | 67 | if template: 68 | 69 | generator = DataGenerator(template=template) 70 | 71 | data, data_decontaminated = generator.generate_data(output_path=output_path, output_decontaminated_path=None) 72 | 73 | ``` 74 | 75 | where `faq_wikipedia.json` can be found [here](opendatagen/examples/faq_wikipedia.json) 76 | 77 | ## Contribution 78 | 79 | We welcome contributions to Open Datagen! Whether you're looking to fix bugs, add templates or new features, or improve documentation, your help is greatly appreciated.
80 | 81 | ## Acknowledgements 82 | 83 | We would like to express our gratitude to the following open source projects and individuals that have inspired and helped us: 84 | 85 | - **Textbooks are all you need** ([Read the paper](https://arxiv.org/abs/2306.11644)) 86 | 87 | - **Evol-Instruct Paper** ([Read the paper](https://arxiv.org/abs/2306.08568)) by [WizardLM_AI](https://twitter.com/WizardLM_AI) 88 | 89 | - **Textbook Generation** by [VikParuchuri](https://github.com/VikParuchuri/textbook_quality) 90 | 91 | ## Connect 92 | 93 | If you need help for your Generative AI strategy, implementation, and infrastructure, reach us on 94 | 95 | Linkedin: [@Thomas](https://linkedin.com/in/thomasdordonne). 96 | Twitter: [@thoddnn](https://twitter.com/thoddnn). 97 | -------------------------------------------------------------------------------- /opendatagen/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import trafilatura 4 | import re 5 | import requests 6 | from urllib.parse import quote_plus 7 | import json 8 | import importlib 9 | import tiktoken 10 | from datasets import Dataset 11 | import random 12 | import math 13 | import numpy as np 14 | import openai 15 | import inspect 16 | from typing import List, Optional 17 | from pydantic import BaseModel 18 | import base64 19 | 20 | def dict_to_string(d): 21 | result = [] 22 | for key, value in d.items(): 23 | result.append(f'#{key}#:\n"""') 24 | result.append(f'{value}') 25 | result.append('"""') 26 | return '\n'.join(result) 27 | 28 | def load_file(path:str): 29 | # Adjust the path based on this module's location 30 | absolute_path = os.path.join(os.path.dirname(__file__), path) 31 | 32 | with open(absolute_path, 'r') as file: 33 | content = file.read() 34 | 35 | return content 36 | 37 | def write_to_csv(rows, filename): 38 | 39 | if not rows: # Check if rows is empty or None 40 | raise ValueError("The 'rows' argument cannot be empty.") 41 | 42 | # Use the current working directory instead of the script's directory 43 | base_path = os.getcwd() 44 | 45 | if os.path.isabs(filename): 46 | path = filename 47 | else: 48 | path = os.path.join(base_path, filename) 49 | 50 | # Open the file and write the rows 51 | with open(path, 'w', newline='') as file: 52 | writer = csv.DictWriter(file, fieldnames=rows[0].keys()) 53 | writer.writeheader() # Writing the headers 54 | writer.writerows(rows) # Writing the rows 55 | 56 | def generate_context_from_json(data, stop_field=None): 57 | if stop_field and list(data.keys())[0] == stop_field: 58 | return "" 59 | 60 | output = "Given these values\n" 61 | 62 | for key, value in data.items(): 63 | if key == stop_field: 64 | break 65 | output += f"#{key} value#\n'''{value}\n'''\n" 66 | 67 | return output 68 | 69 | 70 | def extract_website_details(url): 71 | downloaded = trafilatura.fetch_url(url) 72 | metadata = trafilatura.metadata.extract_metadata(downloaded) 73 | 74 | title = metadata['title'] if metadata and 'title' in metadata else None 75 | description = metadata['description'] if metadata and 'description' in metadata else None 76 | 77 | content = trafilatura.extract(downloaded) 78 | 79 | response = { 80 | "title": title, 81 | "description": description, 82 | "content": content 83 | } 84 | 85 | return response 86 | 87 | def create_type_message(comp_type, min_value, max_value): 88 | """Helper function to create the type message based on the given constraints.""" 89 | type_msg = f"The answer must be a {comp_type}" if comp_type else "" 90 | 
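# Note: the truthiness checks below treat a bound of 0 or None as 'no bound', and only the 'int' type gets an explicit range appended to the message.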
91 | if comp_type == "int": 92 | if min_value and max_value: 93 | type_msg += f" between {min_value} and {max_value}" 94 | elif max_value: 95 | type_msg += f" lower than {max_value}" 96 | elif min_value: 97 | type_msg += f" greater than {min_value}" 98 | 99 | return type_msg 100 | 101 | def find_strings_in_brackets(text): 102 | # This pattern matches text enclosed in { and } 103 | pattern = r"\{(.*?)\}" 104 | # Find all matches 105 | matches = re.findall(pattern, text) 106 | return matches 107 | 108 | def find_strings_in_double_brackets(text): 109 | # This pattern matches text enclosed in double { and } 110 | pattern = r"\{\{(.*?)\}\}" 111 | # Find all matches 112 | matches = re.findall(pattern, text) 113 | return matches 114 | 115 | def replace_with_dict(text, data): 116 | for key, value in data.items(): 117 | placeholder = "{{" + key + "}}" 118 | text = text.replace(placeholder, value) 119 | return text 120 | 121 | def snake_case_to_title_case(snake_str): 122 | # Split the string at underscores 123 | words = snake_str.split('_') 124 | # Capitalize the first letter of each word and join them with a space 125 | title_case_str = ' '.join(word.capitalize() for word in words) 126 | return title_case_str 127 | 128 | def title_case_to_snake_case(title_str): 129 | # First, split the string by spaces 130 | words = title_str.split(' ') 131 | # Convert all the words to lowercase and join them with underscores 132 | snake_case_str = '_'.join(word.lower() for word in words) 133 | return snake_case_str 134 | 135 | def image_to_base64_data_uri(file_path): 136 | with open(file_path, "rb") as img_file: 137 | base64_data = base64.b64encode(img_file.read()).decode('utf-8') 138 | return f"data:image/png;base64,{base64_data}" 139 | 140 | 141 | 142 | def word_counter(input_string): 143 | # Split the string into words based on whitespace 144 | words = input_string.split() 145 | 146 | # Count the number of words 147 | number_of_words = len(words) 148 | 149 | return number_of_words 150 | 151 | def get_google_search_result(keyword:dict, maximum_number_of_link:int = None): 152 | 153 | encoded_keyword = quote_plus(keyword) 154 | 155 | url = f"https://api.serply.io/v1/search/q={encoded_keyword}" 156 | 157 | headers = { 158 | "Content-Type": "application/json", 159 | "X-User-Agent": "", 160 | "X-Proxy-Location": "", 161 | "X-Api-Key": os.environ.get("SERPLY_API_KEY"), 162 | "X-Proxy-Location": "US" 163 | } 164 | 165 | response = requests.request("GET", url, headers=headers) 166 | 167 | response_json = json.loads(response.text)["results"] 168 | 169 | result = [] 170 | 171 | for element in response_json: 172 | 173 | link = element['link'] 174 | result.append(link) 175 | 176 | if maximum_number_of_link: 177 | return result[:maximum_number_of_link] 178 | 179 | return result 180 | 181 | def get_content_from_url(link:str): 182 | 183 | downloaded = trafilatura.fetch_url(link) 184 | content = trafilatura.extract(downloaded) 185 | 186 | return content 187 | 188 | def extract_content_from_internet(keyword:str): 189 | 190 | print(f"Browsing for the keyword {keyword}...") 191 | 192 | result = "" 193 | 194 | urls = get_google_search_result(keyword) 195 | 196 | for url in urls: 197 | 198 | content = get_content_from_url(url) 199 | 200 | if content and word_counter(content) > 500: 201 | 202 | print(url) 203 | 204 | result = result + "\n" + content 205 | 206 | print("Finish browsing...") 207 | 208 | return result 209 | 210 | 211 | def load_user_function(full_function_name:str, from_notebook:bool): 212 | if from_notebook: 213 | try: 
214 | from IPython import get_ipython 215 | ipython_namespace = get_ipython().user_ns 216 | except ImportError: 217 | raise EnvironmentError("IPython environment not detected for notebook mode.") 218 | 219 | if full_function_name in ipython_namespace: 220 | func = ipython_namespace[full_function_name] 221 | if callable(func): 222 | return func 223 | else: 224 | raise TypeError(f"The object '{full_function_name}' in the IPython namespace is not callable.") 225 | else: 226 | raise ValueError(f"Function '{full_function_name}' not found in the IPython namespace.") 227 | else: 228 | try: 229 | module_name, function_name = full_function_name.rsplit('.', 1) 230 | module = importlib.import_module(module_name) 231 | func = getattr(module, function_name) 232 | except ValueError: 233 | raise ValueError(f"Invalid format for function name '{full_function_name}'. Expected 'module.function_name'.") 234 | except ImportError: 235 | raise ImportError(f"Module '{module_name}' could not be found.") 236 | except AttributeError: 237 | raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.") 238 | 239 | if not callable(func): 240 | raise TypeError(f"The object '{function_name}' found in module '{module_name}' is not callable.") 241 | 242 | return func 243 | 244 | 245 | def function_to_call(function_name, from_notebook, *args): 246 | 247 | user_function = load_user_function(function_name, from_notebook) 248 | 249 | sig = inspect.signature(user_function) 250 | 251 | params = sig.parameters 252 | 253 | if params: 254 | return user_function(*args) 255 | else: 256 | return user_function() 257 | 258 | def is_retryable_answer(result): 259 | if "i can't fulfill that request" in result.lower(): 260 | return True 261 | else: 262 | return False 263 | 264 | def num_tokens_from_string(string: str, encoding_name: str) -> int: 265 | """Returns the number of tokens in a text string.""" 266 | encoding = tiktoken.get_encoding(encoding_name) 267 | num_tokens = len(encoding.encode(string)) 268 | return num_tokens 269 | 270 | def get_first_n_tokens(text: str, encoding_name: str, n: int, cut_last_sentence: bool = False) -> str: 271 | """Returns the first n tokens of a string, with an option to cut the last sentence.""" 272 | # Encode the string into tokens 273 | encoding = tiktoken.get_encoding(encoding_name) 274 | tokens = encoding.encode(text) 275 | 276 | # Retrieve the first n tokens 277 | tokens = tokens[:n] 278 | 279 | # Cut the last sentence if required 280 | if cut_last_sentence: 281 | for i in range(len(tokens) - 1, -1, -1): 282 | # Assuming '.' 
represents the end of a sentence 283 | if encoding.decode([tokens[i]]) == '.': 284 | tokens = tokens[:i+1] 285 | break 286 | 287 | # Decode the tokens back to string 288 | return encoding.decode(tokens) 289 | 290 | 291 | def clean_string(original_string:str): 292 | 293 | cleaned_string = re.sub(r'\n+', '\n\n', original_string).strip() 294 | 295 | return cleaned_string 296 | 297 | 298 | def cosine_similarity(vec1, vec2): 299 | 300 | # Dot product of the vectors; this equals cosine similarity only when the embeddings are already L2-normalized 301 | dot_product = np.dot(vec1, vec2) 302 | 303 | return dot_product 304 | 305 | 306 | def get_prompt_prefix_and_stop_words(model_name:str): 307 | 308 | if "teknium" in model_name.lower().strip(): 309 | 310 | start_prompt = "<|im_start|>" 311 | end_prompt = "" 312 | stop_words = ["<|im_end|>","<|im_start|>"] 313 | 314 | elif "nousresearch" in model_name.lower().strip(): 315 | 316 | start_prompt = "" 317 | end_prompt = "" 318 | stop_words = ["<|im_end|>","<|im_start|>"] 319 | 320 | elif "mistral" in model_name.lower().strip(): 321 | 322 | start_prompt = "[INST]" 323 | end_prompt = "[/INST]" 324 | stop_words = ["[/INST]", ""] 325 | 326 | return start_prompt, end_prompt, stop_words 327 | 328 | else: 329 | 330 | start_prompt = "" 331 | end_prompt = "" 332 | stop_words = [""] 333 | 334 | return start_prompt, end_prompt, stop_words 335 | 336 | def pydantic_list_to_dict(lst: List[BaseModel], fields: Optional[List[str]] = None) -> List[dict]: 337 | if fields: 338 | # If fields are specified, only include those fields in the output 339 | return [{field: item.model_dump().get(field) for field in fields} for item in lst] 340 | else: 341 | # Otherwise, include the entire object 342 | return [item.model_dump() for item in lst] 343 | 344 | -------------------------------------------------------------------------------- /opendatagen/template.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, validator, ValidationError, ConfigDict 2 | from typing import Optional, List, Dict, Union, Any, Callable 3 | from enum import Enum 4 | import os 5 | import json 6 | from opendatagen.utils import load_file 7 | from opendatagen.model import OpenAIChatModel, OpenAIInstructModel, LlamaCPPModel, Model, EmbeddingModel, MistralChatModel, AnyscaleChatModel, TogetherChatModel, TogetherInstructModel, UserMessage 8 | from mistralai.models.chat_completion import ChatMessage 9 | from urllib.parse import quote_plus 10 | import requests 11 | import trafilatura 12 | from PyPDF2 import PdfReader 13 | import pandas as pd 14 | from datasets import load_dataset, Dataset 15 | from opendatagen.utils import get_first_n_tokens, num_tokens_from_string, cosine_similarity, find_strings_in_brackets 16 | import random 17 | import uuid 18 | import re 19 | import pandas as pd 20 | import numpy as np 21 | 22 | class DeleteMode(Enum): 23 | HIGHEST = 'highest' 24 | LOWEST = 'lowest' 25 | 26 | 27 | class DelimiterSplitter(BaseModel): 28 | 29 | delimiter:str = "\n" 30 | 31 | class Config: 32 | extra = "forbid" 33 | 34 | def perform_chunk(self, text:str) -> List[str]: 35 | 36 | chunks = text.split(self.delimiter) 37 | 38 | # Strip leading and trailing whitespaces from each chunk and filter out empty chunks 39 | non_empty_chunks = [chunk.strip() for chunk in chunks if chunk.strip() != ''] 40 | 41 | return non_empty_chunks 42 | 43 | class CharacterSplitter(BaseModel): 44 | 45 | chunk_size:Optional[int] = None 46 | chunk_overlap:Optional[int] = None 47 | cut_within_sentence:Optional[bool] = True 48 | 49 | class
Config: 50 | extra = "forbid" 51 | 52 | def perform_chunk(self, text: str) -> List[str]: 53 | 54 | # Validate the input parameters 55 | if not self.chunk_size or self.chunk_size <= 0: 56 | raise ValueError("Chunk size must be a positive integer") 57 | if self.chunk_overlap is not None and (self.chunk_overlap < 0 or self.chunk_overlap >= self.chunk_size): 58 | raise ValueError("Chunk overlap must be non-negative and less than chunk size") 59 | 60 | # Initialize variables 61 | chunks = [] 62 | start_index = 0 63 | 64 | # Calculate the actual chunk size considering the overlap 65 | actual_chunk_size = self.chunk_size 66 | if self.chunk_overlap: 67 | actual_chunk_size -= self.chunk_overlap 68 | 69 | # Generate the chunks 70 | while start_index < len(text): 71 | # Determine the end index of the current chunk 72 | end_index = start_index + self.chunk_size 73 | if not self.cut_within_sentence: 74 | # Find the last sentence end before reaching the chunk size limit 75 | sentence_end_index = text.rfind('. ', start_index, end_index) 76 | if sentence_end_index != -1 and sentence_end_index - start_index > actual_chunk_size // 2: 77 | # Only cut at the sentence end if the chunk is at least half the size of the actual_chunk_size 78 | end_index = sentence_end_index + 1 79 | 80 | chunk = text[start_index:end_index] 81 | chunks.append(chunk) 82 | 83 | # Move the start index for the next chunk 84 | start_index += actual_chunk_size 85 | 86 | return chunks 87 | 88 | 89 | class RAGHuggingFace(BaseModel): 90 | 91 | dataset_path:str 92 | dataset_name:Optional[str] = None 93 | data_dir:Optional[str] = None 94 | column_name:str 95 | specific_row:Optional[int] = None 96 | streaming:bool = True 97 | min_tokens:Optional[int] = 0 98 | max_tokens:Optional[int] = None 99 | subset_size:Optional[int] = 10000 100 | subset:Optional[List[str]] = None 101 | dst:Optional[Any] = None 102 | chunking:Optional[Union[CharacterSplitter, DelimiterSplitter]] = None 103 | 104 | class Config: 105 | extra = "forbid" 106 | 107 | def get_random_value_from_dataset(self): 108 | 109 | if self.subset == None: 110 | 111 | param = {} 112 | 113 | if self.dataset_path: 114 | param["path"] = self.dataset_path 115 | 116 | if self.data_dir: 117 | param["data_dir"] = self.data_dir 118 | 119 | if self.dataset_name: 120 | param["name"] = self.dataset_name 121 | 122 | param["streaming"] = self.streaming 123 | 124 | self.dst = load_dataset(**param) 125 | 126 | self.subset = [sample[self.column_name] for _, sample in zip(range(self.subset_size), self.dst["train"])] 127 | 128 | self.dst = None 129 | 130 | max_attempts = 50 131 | 132 | if self.specific_row: 133 | max_attempts = 1 134 | 135 | count = 0 136 | 137 | while count < max_attempts: 138 | 139 | if self.specific_row: 140 | index = self.specific_row 141 | else: 142 | index = random.randint(0, len(self.subset) - 1) 143 | 144 | text = self.subset[index] 145 | 146 | num_tokens = num_tokens_from_string(text, encoding_name="cl100k_base") 147 | 148 | if num_tokens >= self.min_tokens: 149 | 150 | if self.max_tokens: 151 | 152 | text = self.subset[index] 153 | 154 | result = get_first_n_tokens(n=self.max_tokens, text=text, encoding_name="cl100k_base", cut_last_sentence=False) 155 | 156 | return result 157 | 158 | else: 159 | 160 | result = self.subset[index] 161 | 162 | return result 163 | 164 | count = count + 1 165 | 166 | 167 | class RAGLocalPath(BaseModel): 168 | 169 | localPath:Optional[str] = None 170 | directoryPath:Optional[str] = None 171 | content:Optional[str] = None 172 | randomize:Optional[bool] = 
False 173 | sample_size: Optional[float] = 0.1 174 | 175 | class Config: 176 | extra = "forbid" 177 | 178 | def get_random_csv_chunk(self, df: pd.DataFrame): 179 | # Randomly sample a fraction of the dataframe rows 180 | return df.sample(frac=self.sample_size) 181 | 182 | def get_random_text_chunk(self, text): 183 | 184 | sentences = re.split(r'(?<=[.!?])\s+', text) 185 | sample_size = max(1, int(len(sentences) * self.sample_size)) 186 | selected_sentences = random.sample(sentences, sample_size) 187 | result = ' '.join(selected_sentences) 188 | return result 189 | 190 | def get_content_from_file(self): 191 | 192 | file_content = '' 193 | 194 | if self.localPath.endswith('.csv'): 195 | df = pd.read_csv(self.localPath) 196 | df = df.astype(str) 197 | if self.randomize: 198 | df = self.get_random_csv_chunk(df) 199 | file_content = df.to_string(header=True, index=False, max_rows=None) 200 | elif self.localPath.endswith('.txt'): 201 | with open(self.localPath, 'r') as file: 202 | file_content = file.read() 203 | if self.randomize: 204 | file_content = self.get_random_text_chunk(file_content) 205 | elif self.localPath.endswith('.pdf'): 206 | reader = PdfReader(self.localPath) 207 | text = '' 208 | for page in reader.pages: 209 | text += page.extract_text() + '\n' 210 | if self.randomize: 211 | file_content = self.get_random_text_chunk(text) 212 | else: 213 | file_content = text 214 | else: 215 | raise ValueError("Unsupported file format") 216 | 217 | self.content = file_content 218 | return file_content 219 | 220 | 221 | 222 | def get_content_from_directory(self): 223 | """ 224 | Iterates over files in the directory, reads their content, 225 | and concatenates it into a single string. 226 | """ 227 | concatenated_content = '' 228 | for filename in os.listdir(self.directoryPath): 229 | filepath = os.path.join(self.directoryPath, filename) 230 | if filepath.endswith(('.csv', '.txt', '.pdf')): 231 | self.localPath = filepath # Temporarily update the localPath 232 | file_content = self.get_content_from_file() 233 | concatenated_content += file_content + '\n' 234 | 235 | self.content = concatenated_content # Store concatenated content 236 | return concatenated_content 237 | 238 | 239 | class RAGInternet(BaseModel): 240 | 241 | keywords:List[str] 242 | return_chunks: Optional[bool] = False 243 | minimum_number_of_words_by_article: Optional[int] = 500 244 | maximum_number_of_words_by_article: Optional[int] = 50000 245 | content: Optional[str] = None 246 | 247 | def word_counter(self, input_string): 248 | # Split the string into words based on whitespace 249 | words = input_string.split() 250 | 251 | # Count the number of words 252 | number_of_words = len(words) 253 | 254 | return number_of_words 255 | 256 | def get_google_search_result(self, keyword:dict, maximum_number_of_link:int = None): 257 | 258 | encoded_keyword = quote_plus(keyword) 259 | 260 | url = f"https://api.serply.io/v1/search/q={encoded_keyword}" 261 | 262 | headers = { 263 | "Content-Type": "application/json", 264 | "X-User-Agent": "", 265 | "X-Proxy-Location": "", 266 | "X-Api-Key": os.environ.get("SERPLY_API_KEY"), 267 | "X-Proxy-Location": "US" 268 | } 269 | 270 | response = requests.request("GET", url, headers=headers) 271 | 272 | response_json = json.loads(response.text)["results"] 273 | 274 | result = [] 275 | 276 | for element in response_json: 277 | 278 | link = element['link'] 279 | result.append(link) 280 | 281 | if maximum_number_of_link: 282 | return result[:maximum_number_of_link] 283 | 284 | return result 285 | 286 | def 
get_content_from_url(self, link:str): 287 | 288 | downloaded = trafilatura.fetch_url(link) 289 | content = trafilatura.extract(downloaded) 290 | 291 | return content 292 | 293 | def extract_content_from_internet(self): 294 | 295 | print(f"Browsing...") 296 | 297 | for keyword in self.keywords: 298 | 299 | result = "" 300 | 301 | urls = self.get_google_search_result(keyword) 302 | 303 | for url in urls: 304 | 305 | content = self.get_content_from_url(url) 306 | 307 | if content and self.word_counter(content) > self.minimum_number_of_words_by_article and self.word_counter(content) < self.maximum_number_of_words_by_article: 308 | 309 | print(url) 310 | 311 | result = result + "\n" + content 312 | 313 | print("Finish browsing...") 314 | self.content = result 315 | return result 316 | 317 | class Validator(BaseModel): 318 | 319 | function_name:str 320 | additional_parameters:Optional[List[str]] = None 321 | from_notebook:bool = False 322 | retry_number:Optional[int] = 3 323 | 324 | 325 | class Variations(BaseModel): 326 | 327 | id:str 328 | parent_id:Optional[str] = None 329 | initial_value:Optional[str] = None 330 | value:str 331 | path:Optional[str] = None 332 | confidence_score:Optional[float] = None 333 | error_message:str = None 334 | model_used:str = None 335 | 336 | class Config: 337 | extra = "forbid" # This will raise an error for extra fields 338 | 339 | 340 | class Decontomination(BaseModel): 341 | 342 | embedding_model:EmbeddingModel 343 | threshold: Optional[float] = 0.99 344 | column_name: Optional[str] = None 345 | delete_column: Optional[str] = None 346 | delete_mode: Optional[DeleteMode] = DeleteMode.HIGHEST 347 | 348 | def decontaminate_variable(self, variations:Dict[str, Variations]): 349 | 350 | model = self.embedding_model.get_model() 351 | 352 | data: list[Variations] = list(variations.values()) 353 | 354 | embeddings = [model.create_embedding(row.value) for row in data] 355 | embeddings_np = np.array(embeddings) 356 | 357 | sim_matrix = np.zeros((len(embeddings_np), len(embeddings_np))) 358 | 359 | for i in range(len(embeddings_np)): 360 | for j in range(len(embeddings_np)): 361 | sim_matrix[i][j] = cosine_similarity(embeddings_np[i], embeddings_np[j]) 362 | 363 | exclude_indices = set() 364 | 365 | for i in range(len(data)): 366 | for j in range(i + 1, len(data)): 367 | if sim_matrix[i][j] > self.threshold: 368 | exclude_indices.add(j) 369 | 370 | decontamined_variations = {key: variations[key] for i, key in enumerate(variations) if i not in exclude_indices} 371 | 372 | return decontamined_variations 373 | 374 | 375 | def decontaminate(self, data: List[Dict]): 376 | 377 | model = self.embedding_model.get_model() 378 | 379 | embeddings = [model.create_embedding(row[self.column_name]) for row in data] 380 | embeddings_np = np.array(embeddings) 381 | 382 | # Calculate cosine similarity matrix manually 383 | sim_matrix = np.zeros((len(embeddings_np), len(embeddings_np))) 384 | 385 | for i in range(len(embeddings_np)): 386 | for j in range(len(embeddings_np)): 387 | sim_matrix[i][j] = cosine_similarity(embeddings_np[i], embeddings_np[j]) 388 | 389 | # Identify rows to keep (those that don't have too high similarity with any other row) 390 | rows_to_keep = [] 391 | # Mark rows for exclusion based on similarity threshold 392 | exclude_indices = set() 393 | 394 | for i in range(len(data)): 395 | for j in range(i + 1, len(data)): # Compare each pair once 396 | if sim_matrix[i][j] > self.threshold: 397 | # Mark the row with a lower value in delete_column for exclusion 398 | if 
self.delete_column: 399 | if self.delete_mode == DeleteMode.HIGHEST and data[i][self.delete_column] < data[j][self.delete_column]: 400 | exclude_indices.add(i) 401 | elif self.delete_mode == DeleteMode.LOWEST and data[i][self.delete_column] > data[j][self.delete_column]: 402 | exclude_indices.add(i) 403 | else: 404 | exclude_indices.add(j) 405 | else: 406 | # If no delete_column is specified, exclude one of the rows arbitrarily 407 | exclude_indices.add(j) 408 | 409 | # Identify rows to keep (those not marked for exclusion) 410 | rows_to_keep = [row for idx, row in enumerate(data) if idx not in exclude_indices] 411 | 412 | return rows_to_keep 413 | 414 | 415 | class Junction(BaseModel): 416 | 417 | value:Optional[str] = None 418 | model:Model = None 419 | delete_branch:Optional[bool] = False 420 | 421 | class Config: 422 | extra = "forbid" 423 | 424 | def generate(self, data:List[str]): 425 | 426 | current_model = self.model.get_model() 427 | 428 | prompt = "" 429 | 430 | for val in data: 431 | prompt += f"\n'''\n{val}\n'''\n" 432 | 433 | if isinstance(current_model.user_prompt, list): 434 | 435 | current_model.user_prompt.append(UserMessage(role="user", content=prompt)) 436 | 437 | elif isinstance(current_model.user_prompt, str): 438 | 439 | current_model.user_prompt = f"{current_model.user_prompt}\n\nAssistant:{prompt}" 440 | 441 | else: 442 | 443 | raise ValueError("Error") 444 | 445 | generated_value = current_model.ask() 446 | 447 | self.value = generated_value 448 | 449 | return generated_value 450 | 451 | class RAGVariable(BaseModel): 452 | 453 | variable_name:str = None 454 | get_initial_value:Optional[bool] = True 455 | 456 | class Config: 457 | extra = "forbid" 458 | 459 | class Variable(BaseModel): 460 | 461 | name: str 462 | models:Optional[List[Model]] = None 463 | independent_values:Optional[bool] = False 464 | ensure_model_diversity:Optional[bool] = False 465 | generation_number: int = 1 466 | source_variable:Optional[RAGVariable] = None 467 | source_internet: Optional[RAGInternet] = None 468 | source_localfile: Optional[RAGLocalPath] = None 469 | source_localdirectory: Optional[RAGLocalPath] = None 470 | source_huggingface:Optional[RAGHuggingFace] = None 471 | get_value_from_huggingface:Optional[RAGHuggingFace] = None 472 | get_value_from_localfile:Optional[RAGLocalPath] = None 473 | get_value_from_custom_functions:Optional[Validator] = None 474 | transform_value:Optional[Validator] = None 475 | note: Optional[List[str]] = None 476 | rag_content: Optional[str] = None 477 | validator:Optional[Validator] = None 478 | decontamination:Optional[Decontomination] = None 479 | values:Optional[Dict[str, Variations]] = {} 480 | junction:Optional[Junction] = None 481 | 482 | model_config = ConfigDict( 483 | protected_namespaces=('protect_me_', 'also_protect_'), 484 | extra = "forbid" 485 | ) 486 | 487 | 488 | def load_internet_source(self): 489 | 490 | if self.source_internet is not None: 491 | self.rag_content = self.source_internet.extract_content_from_internet() 492 | 493 | def load_local_file(self): 494 | 495 | if self.source_localfile is not None and self.source_localfile.localPath is not None: 496 | self.rag_content = self.source_localfile.get_content_from_file() 497 | 498 | def load_local_directory(self): 499 | 500 | if self.source_localfile is not None and self.source_localfile.directoryPath is not None: 501 | self.rag_content = self.source_localfile.get_content_from_directory() 502 | 503 | def load_huggingface_dataset(self): 504 | 505 | if self.source_huggingface is not None: 
506 | self.rag_content = self.source_huggingface.get_random_value_from_dataset() 507 | 508 | def load_value(self): 509 | 510 | if self.get_value_from_huggingface: 511 | self.value = self.get_value_from_huggingface.get_random_value_from_dataset(max_token=self.max_tokens) 512 | 513 | 514 | class Template(BaseModel): 515 | 516 | name:str 517 | description: str 518 | prompt: str 519 | prompt_variation_number: Optional[int] = 1 520 | variables: Optional[Dict[str, Variable]] = None 521 | source_internet: Optional[RAGInternet] = None 522 | source_localfile: Optional[RAGLocalPath] = None 523 | rag_content: Optional[str] = None 524 | value:Optional[List[str]] = None 525 | decontamination: Optional[Decontomination] = None 526 | 527 | class Config: 528 | extra = "forbid" 529 | 530 | def load_internet_source(self): 531 | 532 | if self.source_internet is not None: 533 | self.rag_content = self.source_internet.extract_content_from_internet() 534 | 535 | def load_local_file(self): 536 | 537 | if self.source_localfile is not None and self.source_localfile.localPath is not None: 538 | self.rag_content = self.source_localfile.get_content_from_file() 539 | 540 | def load_local_directory(self): 541 | 542 | if self.source_localfile is not None and self.source_localfile.directoryPath is not None: 543 | self.rag_content = self.source_localfile.get_content_from_directory() 544 | 545 | class TemplateName(Enum): 546 | PRODUCT_REVIEW = "product-review" 547 | CHUNK = "chunk" 548 | CHUNK2 = "chunk2" 549 | HALLUCINATION = "hallucination" 550 | 551 | 552 | class TemplateManager: 553 | 554 | def __init__(self, template_file_path:str): 555 | self.template_file_path = self.get_template_file_path(template_file_path) 556 | self.templates = self.load_templates() 557 | 558 | def get_template_file_path(self, filename: str) -> str: 559 | base_path = os.getcwd() 560 | 561 | if os.path.isabs(filename): 562 | return filename 563 | else: 564 | return os.path.join(base_path, filename) 565 | 566 | def load_templates(self) -> Dict[str, Template]: 567 | with open(self.template_file_path, 'r') as file: 568 | raw_data = json.load(file) 569 | 570 | templates = {} 571 | for key, data in raw_data.items(): 572 | 573 | template_name = key 574 | template =Template(**data) 575 | templates[template_name] = template 576 | 577 | return templates 578 | 579 | def get_template(self, template_name: str) -> Template: 580 | 581 | template = self.templates.get(template_name) 582 | 583 | if template: 584 | 585 | template.load_internet_source() 586 | template.load_local_file() 587 | template.load_local_directory() 588 | 589 | return template 590 | 591 | def create_variable_from_name(model:OpenAIChatModel, variable_name:str) -> Variable: 592 | 593 | prompt = load_file(path="files/variable_generation.txt") 594 | 595 | prompt = prompt.format(variable_name=variable_name) 596 | 597 | completion = model.ask_instruct_gpt(prompt=prompt, temperature=0, max_tokens=30) 598 | 599 | return Variable(**completion) 600 | -------------------------------------------------------------------------------- /opendatagen/agent.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import os 3 | import json 4 | import pandas as pd 5 | import copy 6 | 7 | class DataAgent: 8 | 9 | initial_df:pd.DataFrame = None 10 | data_frame:pd.DataFrame = None 11 | data_to_correct:str = None 12 | initial_issue:str = None 13 | current_row_to_correct:int = None 14 | columns_to_analyse:str = None 15 | column_to_modify:str = None 16 | 17 
| current_correction:str = None 18 | successful_conversation_list:list = None 19 | good_examples:list = None 20 | good_example_column:str = None 21 | 22 | start_line_to_analyse:int = None 23 | last_line_to_analyse:int = None 24 | specific_lines_to_analyse:list = None 25 | 26 | 27 | system_prompt = """ 28 | You are CSVGPT, a GPT specialized in evaluating and correcting CSV files. Your key functions involve: 29 | 1. Requesting the CSV file path. (use ask_user_for_file_path then load_csv function) 30 | 2. Asking users about the specific evaluations or corrections they need. (use ask_user_for_evaluation_criteria) 31 | 3. Identifying issues. (use identify_issue) 32 | 4. Confirming detected issues and proposed corrections with users. No verbose. (use confirm_and_propose_corrections) 33 | 5. Applying corrections across similar lines after obtaining user consent. (use apply_corrections) 34 | 6. Focusing on data accuracy and maintaining user control in the process. 35 | 36 | Process step by step. 37 | 38 | If you need precision please ask. (use ask_user_for_precision) 39 | 40 | You communicate in a professional yet approachable tone, making technical concepts accessible without being overly casual. 41 | This balance ensures clarity and fosters a positive user experience. 42 | In situations of ambiguity, you actively seek clarifications to provide precise assistance. 43 | Your interactions are always consent-driven, emphasizing clarity and user preferences. 44 | """ 45 | 46 | messages = [ 47 | {"role":"system", "content":system_prompt}, 48 | ] 49 | 50 | functions = [ 51 | { 52 | 53 | "name": "ask_user_for_precision", 54 | "description": "Ask the user for precision", 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "message": {"type": "string", "description": "Simple sentence to ask for precision from the user"} 59 | }, 60 | "required": ["message"] 61 | } 62 | 63 | }, 64 | 65 | { 66 | 67 | "name": "ask_user_for_file_path", 68 | "description": "Ask the user for the CSV file path", 69 | "parameters": { 70 | "type": "object", 71 | "properties": {}, 72 | "required": [] 73 | } 74 | 75 | }, 76 | { 77 | 78 | "name": "load_csv", 79 | "description": "Load a CSV file into a DataFrame from the file path provided by the user. If the user provide only one line to process, you must return a value for start_line and end_line", 80 | "parameters": { 81 | "type": "object", 82 | "properties": { 83 | "file_path": {"type": "string", "description": "The CSV file path"}, 84 | "start_line": {"type": "integer", "description": "The first CSV line where the user want to start the process"}, 85 | "end_line": {"type": "integer", "description": "The last CSV line where the user want to end the process."}, 86 | "delimiter": {"type": "string", "description": "The delimiter for the CSV file. 
Default is ','", "enum": [";", ","]}, 87 | "specific_lines": { 88 | "type": "array", 89 | "items": { 90 | "type": "integer" 91 | }, 92 | "description": "Must be an array of the CSV line index to process, if specified", 93 | } 94 | }, 95 | "required": ["file_path"] 96 | } 97 | 98 | }, 99 | { 100 | "name": "ask_user_for_evaluation_criteria", 101 | "description": "Ask the user about the specific evaluations or corrections they need.", 102 | "parameters": { 103 | "type": "object", 104 | "properties": {}, 105 | "required": [] 106 | } 107 | }, 108 | { 109 | 110 | "name": "identify_issue", 111 | "description": "Identify issues in the CSV data.", 112 | "parameters": { 113 | "type": "object", 114 | "properties": { 115 | "columns_name": { 116 | "type": "array", 117 | "items": { 118 | "type": "string" 119 | }, 120 | "description": "Must be an array of the column's name to analyse.", 121 | }, 122 | "column_to_correct": {"type": "string", "description": "Must be the column name to correct"}, 123 | "good_examples": { 124 | "type": "array", 125 | "items": { 126 | "type": "integer" 127 | }, 128 | "description": "Must be an array of the CSV line index where the issue is correctly handled.", 129 | }, 130 | "good_example_column": {"type": "string", "description": "Must be the column name where the good examples are"}, 131 | }, 132 | "required": ["columns_name", "column_to_correct"] 133 | } 134 | 135 | }, 136 | { 137 | 138 | "name": "confirm_and_propose_corrections", 139 | "description": "Correct issues in the CSV data.", 140 | "parameters": { 141 | "type": "object", 142 | "properties": {}, 143 | "required": [] 144 | } 145 | 146 | }, 147 | { 148 | 149 | "name": "apply_corrections", 150 | "description": "Apply corrections to the DataFrame based on user consent.", 151 | "parameters": { 152 | "type": "object", 153 | "properties": {}, 154 | "required": [] 155 | } 156 | 157 | }, 158 | 159 | ] 160 | 161 | client = OpenAI() 162 | 163 | def __init__(self, model_name="gpt-4-turbo-2024-04-09"): 164 | self.model_name = model_name 165 | self.client.api_key = os.environ.get("OPENAI_API_KEY") 166 | 167 | def load_csv(self, params:dict): 168 | """ 169 | Load a CSV file into a DataFrame. 170 | """ 171 | file_path_str = params.get("file_path") 172 | self.delimiter = params.get("delimiter", ",") 173 | 174 | #self.initial_df = pd.read_csv(file_path_str, delimiter=self.delimiter, header=None) 175 | 176 | try: 177 | 178 | self.data_frame = pd.read_csv(file_path_str, delimiter=self.delimiter) 179 | 180 | self.start_line_to_analyse = params.get("start_line", None) 181 | 182 | self.last_line_to_analyse = min(params.get("end_line", len(self.data_frame.index) + 1) , len(self.data_frame.index) + 1) 183 | 184 | self.specific_lines_to_analyse = params.get("specific_lines", None) 185 | 186 | self.csv_path = file_path_str 187 | 188 | return "CSV successfully loaded. Now let's ask the user about the specific evaluations or corrections they need." 189 | 190 | except Exception as e: 191 | 192 | print(e) 193 | return f"Error loading file: {e}. Please re-ask for the filepath." 194 | 195 | def ask_user_for_file_path(self): 196 | """ 197 | Ask the user about the CSV file path. 198 | """ 199 | 200 | file_path = input("Please specify the file path of your CSV: ") 201 | 202 | delimiter_input = input("Please specify the delimiter (optional): ") 203 | 204 | process_detail = input("Please specify the lines you want to process. 
Line 1 is the header (optional): ") 205 | 206 | if delimiter_input.strip() == "": 207 | delimiter_input = "The delimiter to use is ','." 208 | else: 209 | delimiter_input = f"The delimiter to use is '{delimiter_input}'." 210 | 211 | if process_detail: 212 | 213 | m = f""" Here is the CSV file path: '{file_path}' 214 | And the CSV lines to process: {process_detail}. 215 | {delimiter_input} 216 | """ 217 | 218 | return m 219 | 220 | else: 221 | 222 | m = f""" Here is the CSV file path: '{file_path}'. 223 | {delimiter_input} 224 | """ 225 | 226 | return m 227 | 228 | 229 | def ask_user_for_precision(self, message): 230 | """ 231 | Ask the user for precision 232 | """ 233 | user_input = input(message) 234 | 235 | return user_input 236 | 237 | def ask_user_for_evaluation_criteria(self): 238 | """ 239 | Ask the user about the specific evaluations or corrections they need. 240 | """ 241 | 242 | user_input = input("Please specify the evaluations or corrections needed: ") 243 | 244 | self.initial_issue = user_input 245 | 246 | m = "" 247 | 248 | while True: 249 | 250 | is_good_examples = input("Are there any rows where this issue is correctly handled? (y/n) ") 251 | 252 | if is_good_examples.lower() in ['y', 'n']: 253 | 254 | if is_good_examples.lower() == 'y': 255 | 256 | good_examples_input = input("Please provide a line where the issue is correctly handled: ") 257 | 258 | good_example_column_input = input("Please provide the column where the issue is correctly handled: ") 259 | 260 | m = f"""Here is the evaluation and correction needed: 261 | '{user_input}' 262 | 263 | Here are the lines where the issue is correctly handled: 264 | '{good_examples_input}' 265 | 266 | Here is the column where good examples are: 267 | '{good_example_column_input}' 268 | 269 | Now let's identify issues with the function identify_issue 270 | """ 271 | 272 | break 273 | 274 | else: 275 | 276 | m = f"""Here is the evaluation and correction needed: 277 | '{user_input}' 278 | 279 | There are no lines where the issue is correctly handled. 280 | 281 | Now let's identify issues with the function identify_issue 282 | """ 283 | 284 | break 285 | 286 | 287 | return m 288 | 289 | def identify_issue(self, issue:dict): 290 | """ 291 | Identify issues in the CSV data. 292 | """ 293 | 294 | #issue_string = issue["issue_string"] 295 | self.columns_to_analyse = issue.get("columns_name", None) 296 | self.column_to_modify = issue.get("column_to_correct", None) 297 | self.good_examples = issue.get("good_examples", None) 298 | self.good_example_column = issue.get("good_example_column", None) 299 | 300 | if self.good_examples: 301 | 302 | example_line = int(self.good_examples[0]) - 2 303 | 304 | example = self.data_frame.loc[example_line, self.good_example_column] 305 | 306 | issue_system_prompt = f""" 307 | The user has provided a general issue to correct in a CSV file. 308 | Given this issue, your job is to detect if the issue occurs for the given CSV Data provided. 309 | Answer by explaining why you detect an issue or not and finish your explanation by 'issue detected' or 'issue not detected'. 310 | Here is one example where the issue is correctly handled: 311 | {example} 312 | """ 313 | 314 | else: 315 | 316 | issue_system_prompt = """ 317 | The user has provided a general issue to correct in a CSV file. 318 | Given this issue, your job is to detect if the issue occurs for the given CSV Data provided. 319 | Answer by explaining why you detect an issue or not and finish your explanation by 'issue detected' or 'issue not detected'. 
320 | """ 321 | 322 | indices = self.get_indices_to_analyze(start_line=self.start_line_to_analyse, last_line=self.last_line_to_analyse, specific_lines=self.specific_lines_to_analyse) 323 | 324 | # Iterate over each row in the DataFrame 325 | for index, row in self.data_frame.iterrows(): 326 | 327 | adjusted_index = index + 2 328 | 329 | if adjusted_index in indices: 330 | 331 | csv_data_str = "\n\n".join([f"{col} value:\n'''\n{row[col]}\n'''" for col in self.columns_to_analyse if col in row]) 332 | 333 | issue_user_prompt = f""" 334 | Issue: 335 | '''{self.initial_issue}''' 336 | 337 | CSV Data: 338 | {csv_data_str} 339 | 340 | """ 341 | 342 | messages = [ 343 | 344 | {"role":"system", "content":issue_system_prompt}, 345 | {"role":"user","content": issue_user_prompt} 346 | 347 | ] 348 | 349 | completion = self.askgpt(messages, max_tokens=2024, functions=None) 350 | 351 | answer = completion.choices[0].message.content 352 | 353 | if "issue detected" in answer.lower(): 354 | 355 | self.data_to_correct = csv_data_str 356 | 357 | self.current_row_to_correct = index 358 | 359 | log_to_print = f"I have detected an issue at line {adjusted_index} of the CSV. Now proposing a correction." 360 | 361 | print(log_to_print) 362 | 363 | # Creating a copy of self.messages for recursive processing 364 | messages_copy = copy.deepcopy(messages) 365 | messages_copy.append({"role": "assistant", "content": log_to_print}) 366 | 367 | # Recursive call to run method with the copied messages 368 | self.run(messages=messages_copy) 369 | 370 | # After returning from the recursive call, continue processing the next lines 371 | continue 372 | 373 | else: 374 | 375 | log_to_print = f"No issue detected for the line {adjusted_index}" 376 | 377 | print(log_to_print) 378 | 379 | return "All the CSV row has been processed. End the conversation." 380 | 381 | def confirm_and_propose_corrections(self): 382 | """ 383 | Confirm detected issues with the user and propose corrections. 384 | """ 385 | 386 | """ 387 | confirm_and_correct_system_prompt = 388 | You are DataCorrectorGPT. DataCorrectorGPT is equipped to handle a wide range of data types and errors, as specified by the user. 389 | It is adept at analyzing various data formats and understanding different kinds of errors that can occur in data sets, such as formatting mistakes, missing values, statistical inaccuracies, or inconsistencies. 390 | The GPT will rely on user input to identify the specific type of data and the nature of the error, then use its expertise to suggest appropriate corrections or modifications. 391 | This broad capability allows it to be versatile and responsive to diverse user needs in data correction. 392 | 393 | """ 394 | 395 | confirm_and_correct_system_prompt = """ 396 | Correct the issue declared by the user. 397 | """ 398 | 399 | confirm_and_correct_user_prompt = f""" 400 | Issue detected and confirmed by the user: 401 | {self.initial_issue} 402 | 403 | Data to correct: ''' 404 | {self.data_to_correct} 405 | ''' 406 | 407 | """ 408 | 409 | 410 | messages = [ 411 | 412 | {"role":"system", "content":confirm_and_correct_system_prompt}, 413 | {"role":"user","content": confirm_and_correct_user_prompt} 414 | 415 | ] 416 | 417 | completion = self.askgpt(messages) 418 | 419 | answer = completion.choices[0].message.content 420 | 421 | while True: 422 | 423 | user_input = input(answer + "\n\n Do you confirm the correction ? 
(y/n)").strip().lower() 424 | if user_input in ['y', 'n']: 425 | 426 | if user_input == "y": 427 | 428 | non_verbose_answer = self.extract_answer_from_verbose_answer(verbose_answer=answer) 429 | 430 | self.current_correction = non_verbose_answer 431 | 432 | message_to_return = f""" 433 | Here are the correction you have provided: 434 | ''' 435 | {non_verbose_answer} 436 | ''' 437 | 438 | That's great! The answer is correct now let's apply change to the corresponding row. 439 | """ 440 | 441 | return message_to_return 442 | 443 | else: 444 | 445 | debug_input = input("Please provide precision on why the answer is not good.") 446 | 447 | return f"No the answer provided is not correct '''{answer}''' because '{debug_input}'. Please submit another correction." 448 | 449 | else: 450 | 451 | print("Please answer with 'y' for yes or 'n' for no.") 452 | 453 | 454 | 455 | def extract_answer_from_verbose_answer(self, verbose_answer:str): 456 | 457 | old_answer = self.data_frame.loc[self.current_row_to_correct, self.column_to_modify] 458 | 459 | extract_answer_system_prompt = """ 460 | Given the old answer and the new answer provided by the user. 461 | You must rewrite the new answer with the same format as the old. 462 | 463 | Example: 464 | 465 | Old answer: 466 | ''' 467 | {"sentence":"Hello I'm Bryan"} 468 | ''' 469 | 470 | Verbose answer: 471 | ''' 472 | Based on the issue you have detected the answer is 473 | {"sentence":"Hello I'm Thomas"} 474 | ''' 475 | 476 | New answer: 477 | {"sentence":"Hello I'm Thomas"} 478 | 479 | """ 480 | 481 | extract_answer_user_prompt = f""" 482 | 483 | Old answer: 484 | ''' 485 | {old_answer} 486 | ''' 487 | 488 | Verbose answer: 489 | ''' 490 | {verbose_answer} 491 | ''' 492 | 493 | New answer: 494 | 495 | """ 496 | 497 | messages = [ 498 | 499 | {"role":"system", "content":extract_answer_system_prompt}, 500 | {"role":"user","content": extract_answer_user_prompt} 501 | 502 | ] 503 | 504 | completion = self.askgpt(messages) 505 | 506 | answer = completion.choices[0].message.content 507 | 508 | return answer 509 | 510 | 511 | def apply_corrections(self): 512 | 513 | """ 514 | Apply corrections to the DataFrame based on user consent. 515 | """ 516 | 517 | if self.current_correction is not None and self.column_to_modify is not None and self.current_row_to_correct is not None: 518 | 519 | self.data_frame.loc[self.current_row_to_correct, self.column_to_modify] = self.current_correction 520 | self.data_frame.to_csv(self.csv_path, index=False, sep=self.delimiter) 521 | 522 | return "The data is corrected in the file. Great work. End of the conversation." 523 | 524 | else: 525 | 526 | return "An error occured, please debug the process. End of the conversation." 
527 | 528 | 529 | 530 | def askgpt(self, messages:list,functions:list = None, temperature:int = 0, max_tokens:int=512): 531 | 532 | param = { 533 | "model": self.model_name, 534 | "messages": messages, 535 | "temperature": temperature 536 | } 537 | 538 | if functions: 539 | param["functions"] = functions 540 | 541 | completion = self.client.chat.completions.create(**param) 542 | 543 | return completion 544 | 545 | def function_to_call(self, parameter_to_pass, function_to_call): 546 | 547 | #parameter_to_pass = json.loads(completion.choices[0].message.function_call.arguments) 548 | #function_to_call = completion.choices[0].message.function_call.name 549 | 550 | func = getattr(self, function_to_call) 551 | 552 | if len(parameter_to_pass) == 0: 553 | result = func() 554 | else: 555 | result = func(parameter_to_pass) 556 | 557 | return result 558 | 559 | def run(self, messages=None): 560 | 561 | if messages is None: 562 | messages = self.messages 563 | 564 | #RUN AGENT 565 | while True: 566 | 567 | completion = self.askgpt(messages, functions=self.functions) 568 | 569 | if completion.choices[0].finish_reason == "function_call": 570 | 571 | #answer = completion.choices[0].message.content 572 | 573 | parameters = json.loads(completion.choices[0].message.function_call.arguments) 574 | function_name = completion.choices[0].message.function_call.name 575 | 576 | result = self.function_to_call(parameters, function_name) 577 | 578 | messages.append({"role":"assistant", "content":result}) 579 | 580 | return messages 581 | 582 | elif completion.choices[0].finish_reason == "stop": 583 | 584 | result = completion.choices[0].message.content 585 | 586 | messages.append({"role":"assistant", "content":result}) 587 | 588 | return messages 589 | 590 | else: 591 | 592 | result = completion.choices[0].message.content 593 | 594 | messages.append({"role":"assistant", "content":result}) 595 | 596 | return messages 597 | 598 | def indices_to_ignore(self, dataframe, rows_to_keep): 599 | """ 600 | Returns a list of indices to ignore for a given pandas DataFrame. 601 | 602 | :param dataframe: pandas DataFrame from which indices are derived. 603 | :param rows_to_keep: List of integer indices of rows that should be kept. 604 | :return: List of integer indices of rows to ignore. 
605 | """ 606 | all_indices = set(range(len(dataframe))) # Generate a set of all row indices 607 | indices_to_keep = set(rows_to_keep) # Convert rows_to_keep to a set for efficient removal 608 | return list(all_indices - indices_to_keep) # Return the difference as a list 609 | 610 | def get_indices_to_analyze(self, start_line, last_line, specific_lines): 611 | indices_to_analyze = set() 612 | 613 | # Adjust for the offset in the loop 614 | offset = 0 615 | 616 | # Check if all parameters are None 617 | if start_line is None and last_line is None and specific_lines is None: 618 | # Add all indices adjusted by the offset 619 | indices_to_analyze.update(range(offset, len(self.data_frame) + offset)) 620 | 621 | else: 622 | # Prioritize specific lines if specified 623 | if specific_lines is not None: 624 | adjusted_specific_lines = [line + offset for line in specific_lines] 625 | indices_to_analyze.update(adjusted_specific_lines) 626 | 627 | # If only start_line is specified 628 | elif start_line is not None and last_line is None: 629 | indices_to_analyze.add(start_line + offset) 630 | 631 | # If only last_line is specified 632 | elif start_line is None and last_line is not None: 633 | indices_to_analyze.update(range(offset, last_line + 1 + offset)) 634 | 635 | # If both start_line and last_line are specified 636 | elif start_line is not None and last_line is not None: 637 | # Only add range if specific_lines is None 638 | if specific_lines is None: 639 | indices_to_analyze.update(range(start_line + offset, last_line + 1 + offset)) 640 | 641 | else: 642 | 643 | return indices_to_analyze 644 | 645 | return indices_to_analyze 646 | -------------------------------------------------------------------------------- /opendatagen/data_generator.py: -------------------------------------------------------------------------------- 1 | 2 | from dotenv import load_dotenv 3 | import numpy as np 4 | import time 5 | import random 6 | import re 7 | import json 8 | import requests 9 | from urllib.parse import quote 10 | from re import findall 11 | from typing import Dict, List, Union 12 | from opendatagen.utils import dict_to_string, load_file, write_to_csv, find_strings_in_brackets, find_strings_in_double_brackets 13 | from opendatagen.utils import pydantic_list_to_dict, replace_with_dict 14 | from opendatagen.anonymizer import Anonymizer 15 | from opendatagen.model import OpenAIChatModel, OpenAIInstructModel, OpenAIEmbeddingModel, ModelName, MistralChatModel, LlamaCPPModel, TogetherChatModel, AnyscaleChatModel, UserMessage 16 | from opendatagen.template import Template, Variable, Variations, create_variable_from_name 17 | from opendatagen.utils import function_to_call 18 | from mistralai.client import MistralClient 19 | from mistralai.models.chat_completion import ChatMessage 20 | import uuid 21 | import copy 22 | 23 | load_dotenv() 24 | 25 | class DataGenerator: 26 | 27 | output_array = [] 28 | 29 | def __init__(self, template:Template): 30 | 31 | self.template = template 32 | 33 | def extract_variable_from_string(self, text:str): 34 | return findall(r'\{\{(.*?)\}\}', text) 35 | 36 | def extract_variable_dict_from_string(self, text:str): 37 | 38 | list_of_variables = findall(r'\{\{(.*?)\}\}', text) 39 | 40 | result = {} 41 | 42 | for variable_id, variable in self.template.variables.items(): 43 | 44 | if variable_id in list_of_variables: 45 | result[variable_id] = variable 46 | 47 | return result 48 | 49 | def anonymize_text(self, text_to_anonymize): 50 | 51 | # Example usage: 52 | anonymizer = Anonymizer() 53 
| 54 | anonymized_text = anonymizer.anonymize(text_to_anonymize) 55 | 56 | return anonymized_text 57 | 58 | def contextual_generation(self, variables:list, current_variation_dict:dict, fixed_variables: Dict[str, Variable], completion:str=None, parent_id:str=None): 59 | 60 | # This will be the list to collect all dictionaries 61 | result = [] 62 | 63 | if not variables: 64 | # No more variables to process, generate final variation 65 | return [current_variation_dict.copy()] 66 | 67 | # Get the next variable 68 | next_var = variables[0] 69 | remaining_variables = variables[1:] 70 | 71 | variable = fixed_variables[next_var] 72 | 73 | variations = self.generate_variable(current_variable=variable, 74 | variable_id_string=next_var, 75 | parent_id=parent_id) 76 | 77 | 78 | 79 | if variable.junction: 80 | 81 | data:List[str] = [variable.value for variable in variations.values()] 82 | junction_value = variable.junction.generate(data=data) 83 | 84 | for key, old_value in variations.items(): 85 | 86 | temp_value = variations[key] 87 | temp_value.value = junction_value 88 | variations[key] = temp_value 89 | 90 | if variable.junction.delete_branch: 91 | 92 | last_key, last_value = list(variations.items())[-1] 93 | variations.clear() 94 | variations[last_key] = last_value 95 | 96 | variations[key] = temp_value 97 | 98 | for id, variation in variations.items(): 99 | # Update the current variations dictionary with the new variation 100 | updated_variation_dict = current_variation_dict.copy() 101 | 102 | updated_variation_dict[next_var] = variation 103 | 104 | # Recursively process the remaining variables 105 | # and extend the all_variation_dicts list with the results 106 | result.extend(self.contextual_generation( 107 | completion=completion, 108 | variables=remaining_variables, 109 | current_variation_dict=updated_variation_dict, 110 | fixed_variables=fixed_variables, 111 | parent_id=id 112 | )) 113 | 114 | # Return the list of all variation dictionaries generated 115 | return result 116 | 117 | def transform_generated_value(self, current_variable:Variable, value:str, parent_id): 118 | 119 | function_name = current_variable.transform_value.function_name 120 | from_notebook = current_variable.transform_value.from_notebook 121 | additional_parameters = current_variable.transform_value.additional_parameters 122 | 123 | param_dict = {} 124 | 125 | if additional_parameters: 126 | 127 | for param in additional_parameters: 128 | 129 | param_dict[param] = self.template.variables[param].values[parent_id] 130 | 131 | param_dict["value"] = value 132 | 133 | generated_value = function_to_call(function_name, from_notebook, param_dict) 134 | 135 | return generated_value 136 | 137 | 138 | def add_variation_value(self, variations:dict, variable_id_string:str, current_variable:Variable, generated_value:str, initial_value:str=None, parent_id:str=None, id:str=None): 139 | 140 | if parent_id: 141 | 142 | if id: 143 | new_id = id 144 | else: 145 | new_id = str(uuid.uuid4()) 146 | 147 | new_value = Variations(id=new_id, parent_id=parent_id, value=generated_value, initial_value=initial_value) 148 | 149 | current_variable.values[new_id] = new_value 150 | 151 | variations[new_id] = new_value 152 | 153 | self.template.variables[variable_id_string].values[new_id] = new_value 154 | 155 | else: 156 | 157 | if id: 158 | id_loop = id 159 | else: 160 | id_loop = str(uuid.uuid4()) 161 | 162 | new_value = Variations(id=id_loop, parent_id=id_loop, value=generated_value, initial_value=initial_value) 163 | 164 | 
current_variable.values[id_loop] = new_value 165 | 166 | variations[id_loop] = new_value 167 | 168 | self.template.variables[variable_id_string].values[id_loop] = new_value 169 | 170 | 171 | 172 | def handle_user_prompt(self, variables_to_get, message, variable_id_string, parent_id): 173 | 174 | if len(variables_to_get) > 0: 175 | 176 | temp = {} 177 | replace_dict = {} 178 | 179 | for target_variable_name in variables_to_get: 180 | 181 | #Manage .value and .initial_value 182 | split_result = target_variable_name.split(".") 183 | target_name = split_result[0] 184 | 185 | replace_dict[target_variable_name] = target_name 186 | 187 | try: 188 | get_initial_value = split_result[1] 189 | if get_initial_value.lower() == "value": 190 | get_initial_value = False 191 | else: 192 | get_initial_value = True 193 | 194 | except IndexError: 195 | 196 | get_initial_value = False 197 | 198 | value = self.retrieve_value(target_key=target_name, 199 | current_variable_name=variable_id_string, 200 | parent_id=parent_id, 201 | get_initial_value=get_initial_value) 202 | 203 | temp[target_name] = value 204 | 205 | for old, new in replace_dict.items(): 206 | 207 | if isinstance(message.content, str): 208 | 209 | message.content = message.content.replace(old, new) 210 | message.content = replace_with_dict(message.content, temp) 211 | 212 | if isinstance(message.content, list): 213 | 214 | for content in message.content: 215 | 216 | if content.type == "image_url": 217 | 218 | content.image_url.url = content.image_url.url.replace(old, new) 219 | content.image_url.url = replace_with_dict(content.image_url.url, temp) 220 | 221 | elif content.type == "text": 222 | 223 | content.text = content.text.replace(old, new) 224 | content.text = replace_with_dict(content.text, temp) 225 | 226 | 227 | if message.rephraser: 228 | message.rephrase() 229 | 230 | 231 | def retrieve_value(self, target_key, current_variable_name, parent_id, get_initial_value): 232 | 233 | # Get keys in reverse order 234 | keys_in_reverse = list(self.template.variables.keys())[::-1] 235 | 236 | # Find the starting index 237 | start_index = keys_in_reverse.index(current_variable_name) + 1 if current_variable_name in keys_in_reverse else len(keys_in_reverse) 238 | 239 | def find_value(current_id, keys): 240 | for key in keys: 241 | value = self.template.variables[key].values[current_id] 242 | if value.id == current_id: 243 | if key == target_key: 244 | if get_initial_value: 245 | return value.initial_value 246 | else: 247 | return value.value 248 | return find_value(value.parent_id, keys[start_index:]) 249 | return None 250 | 251 | # Start the lookup process from the key that comes before the current_variable_name 252 | return find_value(parent_id, keys_in_reverse[start_index:]) 253 | 254 | 255 | 256 | def generate_variable(self, current_variable:Variable, variable_id_string:str, parent_id:str=None): 257 | 258 | generation_number = current_variable.generation_number 259 | 260 | variations = {} 261 | 262 | if current_variable.get_value_from_custom_functions: 263 | 264 | for _ in range(generation_number): 265 | 266 | function_name = current_variable.get_value_from_custom_functions.function_name 267 | from_notebook = current_variable.get_value_from_custom_functions.from_notebook 268 | additional_parameters = current_variable.get_value_from_custom_functions.additional_parameters 269 | 270 | param_dict = {} 271 | 272 | if additional_parameters: 273 | 274 | for param in additional_parameters: 275 | 276 | param_dict[param] = 
self.template.variables[param].values[parent_id] 277 | 278 | generated_value = function_to_call(function_name, from_notebook, param_dict) 279 | 280 | if current_variable.transform_value : 281 | generated_value = self.transform_generated_value(current_variable=current_variable, value=generated_value, parent_id=parent_id) 282 | 283 | if current_variable.get_value_from_custom_functions.chunking: 284 | 285 | chunks = current_variable.get_value_from_custom_functions.chunking.perform_chunk(text=generated_value) 286 | 287 | for chunk in chunks: 288 | 289 | #add values to variations 290 | self.add_variation_value(variations=variations, 291 | variable_id_string=variable_id_string, 292 | current_variable=current_variable, 293 | generated_value=chunk, 294 | initial_value=generated_value, 295 | parent_id=parent_id) 296 | 297 | else: 298 | 299 | #add values to variations 300 | self.add_variation_value(variations=variations, 301 | variable_id_string=variable_id_string, 302 | current_variable=current_variable, 303 | generated_value=generated_value, 304 | initial_value=generated_value, 305 | parent_id=parent_id) 306 | 307 | 308 | if current_variable.decontamination: 309 | variations = current_variable.decontamination.decontaminate_variable(variations) 310 | 311 | return variations 312 | 313 | if current_variable.get_value_from_localfile: 314 | 315 | for _ in range(generation_number): 316 | 317 | generated_value = current_variable.get_value_from_localfile.get_content_from_file() 318 | 319 | if current_variable.transform_value : 320 | generated_value = self.transform_generated_value(current_variable=current_variable, value=generated_value, parent_id=parent_id) 321 | 322 | if current_variable.get_value_from_localfile.chunking: 323 | chunks = current_variable.get_value_from_localfile.chunking.perform_chunk(text=generated_value) 324 | 325 | for chunk in chunks: 326 | 327 | #add values to variations 328 | self.add_variation_value(variations=variations, 329 | variable_id_string=variable_id_string, 330 | current_variable=current_variable, 331 | generated_value=chunk, 332 | initial_value=generated_value, 333 | parent_id=parent_id) 334 | 335 | else: 336 | 337 | #add values to variations 338 | self.add_variation_value(variations=variations, 339 | variable_id_string=variable_id_string, 340 | current_variable=current_variable, 341 | generated_value=generated_value, 342 | initial_value=generated_value, 343 | parent_id=parent_id) 344 | 345 | 346 | if current_variable.decontamination: 347 | variations = current_variable.decontamination.decontaminate_variable(variations) 348 | 349 | return variations 350 | 351 | if current_variable.get_value_from_huggingface: 352 | 353 | for _ in range(generation_number): 354 | 355 | generated_value = current_variable.get_value_from_huggingface.get_random_value_from_dataset() 356 | 357 | if current_variable.transform_value : 358 | generated_value = self.transform_generated_value(current_variable=current_variable, value=generated_value, parent_id=parent_id) 359 | 360 | if current_variable.get_value_from_huggingface.chunking: 361 | chunks = current_variable.get_value_from_huggingface.chunking.perform_chunk(text=generated_value) 362 | 363 | for chunk in chunks: 364 | 365 | #add values to variations 366 | self.add_variation_value(variations=variations, 367 | variable_id_string=variable_id_string, 368 | current_variable=current_variable, 369 | generated_value=chunk, 370 | initial_value=generated_value, 371 | parent_id=parent_id) 372 | 373 | else: 374 | 375 | #add values to variations 376 | 
self.add_variation_value( variations=variations, 377 | variable_id_string=variable_id_string, 378 | current_variable=current_variable, 379 | generated_value=generated_value, 380 | initial_value=generated_value, 381 | parent_id=parent_id) 382 | 383 | if current_variable.decontamination: 384 | variations = current_variable.decontamination.decontaminate_variable(variations) 385 | 386 | return variations 387 | 388 | rag_content = "" 389 | chosen_models = [] 390 | independent_messages = None 391 | 392 | current_variable = self.template.variables[variable_id_string] 393 | 394 | if current_variable.source_localfile: 395 | current_variable.load_local_file() 396 | elif current_variable.source_localdirectory: 397 | current_variable.load_local_directory() 398 | elif current_variable.source_internet: 399 | current_variable.load_internet_source() 400 | elif current_variable.source_huggingface: 401 | current_variable.load_huggingface_dataset() 402 | 403 | for _ in range(generation_number): 404 | 405 | if current_variable.ensure_model_diversity: 406 | 407 | available_models = [model.get_model() for model in current_variable.models if model.get_model() not in chosen_models] 408 | 409 | if available_models: 410 | current_model = random.choice(available_models) 411 | else: 412 | current_model = random.choice(current_variable.models).get_model() 413 | 414 | else: 415 | 416 | current_model = random.choice(current_variable.models).get_model() 417 | 418 | chosen_models.append(current_model) 419 | 420 | #Get the variables value in user_prompt from the model object 421 | if hasattr(current_model, 'user_prompt') and isinstance(current_model.user_prompt, list): 422 | 423 | initial_messages = pydantic_list_to_dict(lst=current_model.user_prompt, fields=['role', 'content']) 424 | copy_messages = copy.deepcopy(initial_messages) 425 | copy_messages_obj = copy.deepcopy(current_model.user_prompt) 426 | 427 | for message in current_model.user_prompt: 428 | 429 | if isinstance(message, str): 430 | 431 | variables_to_get = find_strings_in_double_brackets(text=message.content) 432 | 433 | elif isinstance(message, UserMessage): 434 | 435 | if isinstance(message.content, str): 436 | 437 | variables_to_get = find_strings_in_double_brackets(text=message.content) 438 | 439 | elif isinstance(message.content, list): 440 | 441 | for content in message.content: 442 | 443 | if content.type == "image_url": 444 | 445 | variables_to_get = find_strings_in_double_brackets(text=content.image_url.url) 446 | 447 | self.handle_user_prompt(variables_to_get=variables_to_get, 448 | message=message, 449 | variable_id_string=variable_id_string, 450 | parent_id=parent_id) 451 | 452 | elif content.type == "text": 453 | 454 | variables_to_get = find_strings_in_double_brackets(text=content.text) 455 | 456 | self.handle_user_prompt(variables_to_get=variables_to_get, 457 | message=message, 458 | variable_id_string=variable_id_string, 459 | parent_id=parent_id) 460 | 461 | else: 462 | 463 | raise ValueError("Error") 464 | 465 | else: 466 | 467 | raise ValueError("Error") 468 | 469 | self.handle_user_prompt(variables_to_get=variables_to_get, 470 | message=message, 471 | variable_id_string=variable_id_string, 472 | parent_id=parent_id) 473 | 474 | 475 | if current_variable.rag_content: 476 | 477 | rag_content = f"Here is some context that will help you:\n'''{current_variable.rag_content}\n'''" 478 | current_model.user_prompt.append(UserMessage(role="user", content=rag_content)) 479 | 480 | elif hasattr(current_model, 'user_prompt') and 
isinstance(current_model.user_prompt, str): 481 | 482 | copy_messages_obj = copy.deepcopy(current_model.user_prompt) 483 | 484 | variables_to_get = find_strings_in_double_brackets(text=current_model.user_prompt) 485 | 486 | if len(variables_to_get) > 0: 487 | 488 | temp = {} 489 | 490 | for target_variable_name in variables_to_get: 491 | 492 | value = self.retrieve_value(target_key=target_variable_name, 493 | current_variable_name=variable_id_string, 494 | parent_id=parent_id, 495 | get_initial_value=True) 496 | 497 | temp[target_variable_name] = value 498 | 499 | current_model.user_prompt = replace_with_dict(current_model.user_prompt, temp) 500 | 501 | #if message.rephraser: 502 | # message.rephrase() 503 | 504 | 505 | if current_variable.rag_content: 506 | 507 | rag_content = f"Here is some context that will help you:\n'''{current_variable.rag_content}\n'''" 508 | current_model.user_prompt.append(UserMessage(role="user", content=rag_content)) 509 | 510 | elif hasattr(current_model, 'path') and isinstance(current_model.path, str) : 511 | 512 | copy_messages_obj = copy.deepcopy(current_model.path) 513 | 514 | variables_to_get = find_strings_in_double_brackets(text=current_model.path) 515 | 516 | if len(variables_to_get) > 0: 517 | 518 | temp = {} 519 | 520 | for target_variable_name in variables_to_get: 521 | 522 | value = self.retrieve_value(target_key=target_variable_name, 523 | current_variable_name=variable_id_string, 524 | parent_id=parent_id, 525 | get_initial_value=True) 526 | 527 | temp[target_variable_name] = value 528 | 529 | 530 | current_model.path = replace_with_dict(current_model.path, temp) 531 | 532 | #if message.rephraser: 533 | # message.rephrase() 534 | 535 | 536 | if current_variable.rag_content: 537 | 538 | rag_content = f"Here is some context that will help you:\n'''{current_variable.rag_content}\n'''" 539 | current_model.path.append(UserMessage(role="user", content=rag_content)) 540 | 541 | 542 | else: 543 | 544 | raise ValueError("User prompt is badly formatted") 545 | 546 | 547 | variation_id = str(uuid.uuid4()) 548 | 549 | if current_variable.transform_value : 550 | generated_value = self.transform_generated_value(current_variable=current_variable, value=generated_value, parent_id=parent_id) 551 | 552 | new_value = Variations(id=variation_id, 553 | parent_id=parent_id, 554 | value=generated_value, 555 | initial_value=generated_value, 556 | confidence_score=current_model.confidence_score, 557 | model_used=current_model.name) 558 | 559 | current_variable.values[variation_id] = new_value 560 | 561 | 562 | 563 | if current_variable.validator: 564 | 565 | count = 1 566 | 567 | while True: 568 | 569 | if count > current_variable.validator.retry_number: 570 | 571 | new_value = Variations(id=variation_id, 572 | parent_id=parent_id, 573 | value=generated_value, 574 | initial_value=generated_value, 575 | error_message=new_message, 576 | confidence_score=current_confidence_score, 577 | model_used=current_model.name) 578 | 579 | current_variable.values[variation_id] = new_value 580 | break 581 | 582 | generated_value = current_model.ask() 583 | 584 | if isinstance(current_model, OpenAIChatModel): 585 | current_confidence_score = current_model.confidence_score 586 | else: 587 | current_confidence_score = {} 588 | 589 | self.template.variables[variable_id_string].values[parent_id] = Variations(id=variation_id, 590 | parent_id=parent_id, 591 | value=generated_value, 592 | initial_value=generated_value, 593 | confidence_score=current_confidence_score, 594 | 
model_used=current_model.name) 595 | 596 | function_name = current_variable.validator.function_name 597 | from_notebook = current_variable.validator.from_notebook 598 | additional_parameters = current_variable.validator.additional_parameters 599 | 600 | param_dict = {} 601 | 602 | if additional_parameters: 603 | 604 | for param in additional_parameters: 605 | 606 | param_dict[param] = self.template.variables[param].values[parent_id] 607 | 608 | isValid, new_message = function_to_call(function_name, from_notebook, param_dict) 609 | 610 | if isValid: 611 | 612 | new_value = Variations(id=variation_id, 613 | parent_id=parent_id, 614 | value=generated_value, 615 | initial_value=generated_value, 616 | model_used=current_model.name) 617 | 618 | current_variable.values[variation_id] = new_value 619 | 620 | break 621 | 622 | else: 623 | 624 | if isinstance(current_model.user_prompt, list): 625 | 626 | current_model.user_prompt.append(UserMessage(role= "assistant", content = generated_value)) 627 | current_model.user_prompt.append(UserMessage(role= "user", content = new_message)) 628 | 629 | elif isinstance(current_model.user_prompt, str): 630 | 631 | current_model.user_prompt = f"{current_model.user_prompt}\n\nAssistant:{generated_value}\n\nUser:{new_message}" 632 | 633 | else: 634 | raise ValueError("Unknow type of model") 635 | 636 | 637 | current_model.ask() 638 | 639 | count = count + 1 640 | 641 | else: 642 | 643 | generated_value = current_model.ask() 644 | 645 | if current_variable.independent_values == False: 646 | 647 | independent_string = "You must generate a new value for the initial prompt that is not similar to the last values. No verbose." 648 | 649 | if isinstance(current_model.user_prompt, list): 650 | 651 | current_model.user_prompt.append(UserMessage(role="assistant", content=generated_value)) 652 | current_model.user_prompt.append(UserMessage(role="user", content=independent_string)) 653 | 654 | elif isinstance(current_model.user_prompt, str): 655 | 656 | current_model.user_prompt = f"{current_model.user_prompt}\n\nAssistant:{generated_value}\n\nAssistant:{independent_string}" 657 | 658 | else: 659 | 660 | raise ValueError("Unknow type of model") 661 | 662 | 663 | 664 | #add values to variations 665 | self.add_variation_value(id=variation_id, 666 | variations=variations, 667 | variable_id_string=variable_id_string, 668 | current_variable=current_variable, 669 | generated_value=generated_value, 670 | initial_value=generated_value, 671 | parent_id=parent_id) 672 | 673 | 674 | variations[variation_id] = Variations(id=variation_id, 675 | parent_id=parent_id, 676 | value=generated_value, 677 | initial_value=generated_value, 678 | model_used=current_model.name) 679 | 680 | 681 | if current_variable.independent_values == True: 682 | 683 | #Reinitialize user_prompt value after generation 684 | if hasattr(current_model, 'user_prompt'): 685 | current_model.user_prompt = copy_messages_obj 686 | 687 | 688 | return variations 689 | 690 | 691 | 692 | def generate_evol_instruct_prompt(self, initial_prompt:str): 693 | 694 | evol_prompt_template = load_file(path="files/evol_instruct.txt") 695 | 696 | evol_instruct_prompt = evol_prompt_template.format(number_of_prompts=str(self.template.prompt_variation_number), prompt=initial_prompt) 697 | 698 | start_messages = [ 699 | {"role": "system", "content": "Answer as a valid JSON like {\"prompts\": [\"XXXX\", \"YYYY\"]}"}, 700 | {"role": "user", "content": evol_instruct_prompt}, 701 | ] 702 | 703 | evol_instruct_model = 
OpenAIChatModel(model_name=ModelName.GPT_35_TURBO_CHAT.value) 704 | 705 | diversified_prompt_list = evol_instruct_model.ask(max_tokens=512, 706 | temperature=1, 707 | messages=start_messages, 708 | json_mode=True) 709 | 710 | evol_instruct_generated_prompt_list = json.loads(diversified_prompt_list)["prompts"] 711 | 712 | return evol_instruct_generated_prompt_list 713 | 714 | 715 | def get_completion_error_message(self, params:Dict[str, Variable]): 716 | 717 | error_str = "" 718 | 719 | for id, param in params.items(): 720 | 721 | if param.error_message: 722 | error_str = f"{error_str}\n{param.error_message}" 723 | 724 | return error_str.strip() 725 | 726 | def get_prompt_error_message(self, params:dict): 727 | 728 | error_str = "" 729 | 730 | for param in params: 731 | error_message = self.template.variables[param].error_message 732 | 733 | if error_message: 734 | error_str = f"{error_str}\n{error_message}" 735 | 736 | return error_str 737 | 738 | def generate_data(self): 739 | # Extracting structures and variables from the template 740 | prompt = self.template.prompt 741 | prompt_variables = self.extract_variable_from_string(prompt) 742 | prompt_fixed_variables = self.extract_variable_dict_from_string(text=prompt) 743 | 744 | result = [] 745 | 746 | if prompt_variables: 747 | # Start the recursive generation process with an empty dictionary for current variations 748 | prompts_parameters = self.contextual_generation( 749 | variables=prompt_variables, 750 | current_variation_dict={}, 751 | fixed_variables=prompt_fixed_variables 752 | ) 753 | 754 | # Iterate over each set of parameters with indices for coordinates 755 | for x, p_param in enumerate(prompts_parameters): 756 | prompt_param = {} 757 | 758 | for y, (variable_id_string, prompt_variation) in enumerate(p_param.items()): 759 | prompt_param[variable_id_string] = { 760 | "value": prompt_variation.value, 761 | "errorMessage": prompt_variation.error_message, 762 | "modelUsed": str(prompt_variation.model_used), 763 | "confidenceScore": prompt_variation.confidence_score, 764 | "coordinates": {"x": x, "y": y} # Add coordinates here 765 | } 766 | 767 | result.append(prompt_param) 768 | 769 | # Optionally, uncomment and modify the path in the next line to write to a JSON file. 
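        # (Note: output_path is not defined anywhere in this method; the commented-out block
        # below is only a sketch that assumes the caller supplies it, e.g. output_path = "output.json".)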
770 | # with open(output_path, 'w') as json_file: 771 | # json.dump(result, json_file, indent=2) 772 | 773 | return result 774 | 775 | 776 | -------------------------------------------------------------------------------- /opendatagen/model.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result, retry_if_exception_type 3 | from openai import OpenAI 4 | import numpy as np 5 | import os 6 | import json 7 | from opendatagen.utils import is_retryable_answer, pydantic_list_to_dict, load_file, image_to_base64_data_uri 8 | import requests 9 | from pydantic import BaseModel, validator, ValidationError, ConfigDict, Extra 10 | from typing import Optional, List, Dict, Union, Type 11 | import random 12 | from mistralai.client import MistralClient, ChatMessage 13 | import math 14 | import tiktoken 15 | from llama_cpp import Llama 16 | from llama_cpp.llama_chat_format import Llava15ChatHandler 17 | import whisper 18 | from elevenlabs.client import ElevenLabs 19 | from elevenlabs import Voice, VoiceSettings, generate, play 20 | import uuid 21 | from PIL import Image 22 | from PIL.PngImagePlugin import PngInfo 23 | import io 24 | from audiocraft.models import AudioGen 25 | from audiocraft.models import MusicGen 26 | from audiocraft.data.audio import audio_write 27 | import torchaudio 28 | from pydub import AudioSegment 29 | from typing_extensions import TypedDict, NotRequired, Literal 30 | import anthropic 31 | from transformers import AutoProcessor, BarkModel 32 | import scipy 33 | 34 | N_RETRIES = 2 35 | 36 | class EvolMethod(Enum): 37 | 38 | deep = "deep" 39 | concretizing = "concretizing" 40 | step_reasoning = "step_reasoning" 41 | breath = "breath" 42 | basic = "basic" 43 | 44 | 45 | class TextContent(BaseModel): 46 | type: Literal['text'] 47 | text: str 48 | 49 | class ImageUrlContent(BaseModel): 50 | url: str 51 | 52 | class ImageUrl(BaseModel): 53 | type: Literal['image_url'] 54 | image_url: ImageUrlContent 55 | 56 | # Define a Union type for the different content types 57 | Content = Union[TextContent, ImageUrl] 58 | 59 | class UserMessage(BaseModel): 60 | 61 | role: Literal['user', 'assistant', 'system'] 62 | content: Union[List[Content], str] 63 | rephraser:Optional[List[EvolMethod]] = None 64 | 65 | def rephrase(self): 66 | 67 | rephraser_name = random.choice(self.rephraser) 68 | 69 | prompt = load_file(path=f"files/{rephraser_name}.txt") 70 | 71 | d = {"prompt": self.content} 72 | 73 | prompt = prompt.format(**d) 74 | 75 | model = OpenAIInstructModel(user_prompt=prompt, temperature=[1]) 76 | 77 | rephrased_prompt = model.ask() 78 | 79 | self.content = rephrased_prompt 80 | 81 | 82 | class Config: 83 | extra = 'forbid' 84 | 85 | 86 | 87 | class ModelName(Enum): 88 | GPT_35_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct" 89 | TEXT_DAVINCI_INSTRUCT = "text-davinci-003" 90 | GPT_35_TURBO_CHAT = "gpt-3.5-turbo-1106" 91 | GPT_35_TURBO_16K_CHAT = "gpt-3.5-turbo-16k" 92 | GPT_4_CHAT = "gpt-4" 93 | GPT_4_TURBO_CHAT = "gpt-4-1106-preview" 94 | TEXT_EMBEDDING_ADA = "text-embedding-ada-002" 95 | SMARTCHUNK = "SmartChunk-0.1-Mistral-7B" 96 | MISTRAL_7B = "Mistral-7B-v0.1" 97 | LLAMA_7B = "Llama-2-7b-chat-hf" 98 | LLAMA_13B = "Llama-2-13b-chat-hf" 99 | LLAMA_70B = "Llama-2-70b-chat-hf" 100 | 101 | 102 | 103 | 104 | class WhisperModel(BaseModel): 105 | 106 | path:str 107 | name:Optional[str] = "base" 108 | 109 | class Config: 110 | extra = 'forbid' 111 | 112 | def ask(self) 
-> str: 113 | 114 | model = whisper.load_model(self.name) 115 | result = model.transcribe(self.path) 116 | 117 | text = result["text"] 118 | 119 | return text 120 | 121 | 122 | class BarkTTSModel(BaseModel): 123 | 124 | name:Optional[str] = "suno/bark" 125 | speaker:Optional[str] = "v2/en_speaker_6" 126 | user_prompt:str 127 | 128 | def ask(self) -> str: 129 | 130 | processor = AutoProcessor.from_pretrained(self.name) 131 | model = BarkModel.from_pretrained(self.name) 132 | 133 | voice_preset = self.speaker 134 | 135 | inputs = processor(self.user_prompt, voice_preset=voice_preset) 136 | 137 | audio_array = model.generate(**inputs) 138 | audio_array = audio_array.cpu().numpy().squeeze() 139 | 140 | sample_rate = model.generation_config.sample_rate 141 | 142 | # Generate a random UUID and create a filename 143 | filename = f'audio_{uuid.uuid4()}.mp3' 144 | 145 | scipy.io.wavfile.write(filename, rate=sample_rate, data=audio_array) 146 | 147 | return filename 148 | 149 | 150 | class ElevenLabsTTSModel(BaseModel): 151 | 152 | name:Optional[str] = "21m00Tcm4TlvDq8ikWAM" 153 | user_prompt:str 154 | 155 | def ask(self) -> str: 156 | 157 | client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY")) 158 | 159 | audio = generate( 160 | text=self.user_prompt, 161 | voice=Voice( 162 | voice_id=self.name, 163 | settings=VoiceSettings(stability=0.71, similarity_boost=0.5, style=0.0, use_speaker_boost=True) 164 | ) 165 | ) 166 | 167 | # Generate a random UUID and create a filename 168 | filename = f'audio_{uuid.uuid4()}.mp3' 169 | 170 | # Save the audio data to a file with the random filename 171 | with open(filename, 'wb') as audio_file: 172 | audio_file.write(audio) 173 | 174 | return filename 175 | 176 | class Config: 177 | extra = 'forbid' 178 | 179 | 180 | class LlamaCPPITTModel(BaseModel): 181 | 182 | path:str 183 | clip_model_path:str 184 | name:Optional[str] = None 185 | user_prompt:List[UserMessage] 186 | 187 | def ask(self) -> str: 188 | 189 | chat_handler = Llava15ChatHandler(clip_model_path=self.clip_model_path) 190 | 191 | llm = Llama( 192 | model_path=self.path, 193 | chat_handler=chat_handler, 194 | n_ctx=2048, # n_ctx should be increased to accomodate the image embedding 195 | logits_all=True,# needed to make llava work 196 | n_gpu_layers=-1 197 | ) 198 | 199 | messages = pydantic_list_to_dict(lst = self.user_prompt, fields=['role', 'content']) 200 | 201 | output = llm.create_chat_completion(messages = messages) 202 | 203 | return output["choices"][0]["message"]["content"] 204 | 205 | 206 | def __init__(self, **data): 207 | 208 | super().__init__(**data) 209 | self.name = self.path.split('/')[-1] 210 | 211 | class Config: 212 | extra = 'forbid' 213 | 214 | class LlamaCPPChatModel(BaseModel): 215 | 216 | path:str 217 | name:Optional[str] = None 218 | user_prompt:Optional[List[UserMessage]] = None 219 | use_gpu:Optional[bool] = True 220 | max_tokens:Optional[int] = 256 221 | temperature:Optional[List[float]] = [1] 222 | json_mode:Optional[bool] = False 223 | tools:Optional[list] = None 224 | tool_choice:Optional[str] = None 225 | stop:Optional[List[str]] = None 226 | top_p:Optional[float] = 0.95 227 | min_p:Optional[float] = 0.05 228 | chat_format:Optional[str]= "chatml" 229 | 230 | class Config: 231 | extra = 'forbid' 232 | 233 | #@retry(retry=retry_if_result(is_retryable_answer), stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 234 | def ask(self): 235 | 236 | param_llm = { 237 | "verbose": False, 238 | "n_ctx": self.max_tokens 239 | } 240 | 241 | 
if self.use_gpu: 242 | param_llm["n_gpu_layers"] = -1 243 | 244 | if self.chat_format: 245 | param_llm["chat_format"] = self.chat_format 246 | 247 | llm = Llama(model_path=self.path, **param_llm) 248 | 249 | param_completion = { 250 | "messages": [message.model_dump() for message in self.user_prompt], 251 | "max_tokens": self.max_tokens, 252 | "temperature": random.choice(self.temperature), 253 | } 254 | 255 | if self.stop: 256 | param_completion["stop"] = self.stop 257 | 258 | if self.top_p: 259 | param_completion["top_p"] = self.top_p 260 | 261 | if self.min_p: 262 | param_completion["min_p"] = self.min_p 263 | 264 | if self.json_mode: 265 | param_completion["response_format"] = {"type": "json_object"} 266 | 267 | if self.tools: 268 | param_completion["functions"] = self.tools 269 | 270 | output = self.llm.create_chat_completion_openai_v1(**param_completion) 271 | 272 | return output 273 | 274 | class LlamaCPPModel(BaseModel): 275 | 276 | path:str 277 | name:Optional[str] = None 278 | user_prompt:Optional[str] = None 279 | temperature:Optional[List[float]] = [0.8] 280 | use_gpu:Optional[bool] = False 281 | handle_prompt_format:Optional[bool] = False 282 | stop:Optional[List[str]] = None 283 | max_tokens:Optional[int] = 256 284 | top_p:Optional[float] = 0.95 285 | min_p:Optional[float] = 0.05 286 | echo:Optional[bool] = False 287 | start_with:Optional[List[str]] = None 288 | confidence_score:Optional[float] = None 289 | 290 | def ask(self) -> str: 291 | 292 | param_llm = { 293 | "verbose": False 294 | } 295 | 296 | if self.use_gpu: 297 | param_llm["n_gpu_layers"] = -1 298 | 299 | #llm = Llama(model_path=self.path, verbose=False, n_gpu_layers=-1) 300 | llm = Llama(model_path=self.path, **param_llm) 301 | 302 | param_completion = { 303 | "prompt": f"{self.user_prompt}", 304 | "max_tokens": self.max_tokens, 305 | "echo": self.echo, 306 | "temperature": random.choice(self.temperature) 307 | } 308 | 309 | if self.stop: 310 | param_completion["stop"] = self.stop 311 | 312 | if self.top_p: 313 | param_completion["top_p"] = self.top_p 314 | 315 | if self.min_p: 316 | param_completion["min_p"] = self.min_p 317 | 318 | output = llm(**param_completion) 319 | 320 | return output["choices"][0]["text"] 321 | 322 | def __init__(self, **data): 323 | 324 | super().__init__(**data) 325 | self.name = self.path.split('/')[-1] 326 | 327 | class Config: 328 | extra = 'forbid' 329 | 330 | 331 | 332 | class AnyscaleChatModel(BaseModel): 333 | 334 | name:str = "mistralai/Mixtral-8x7B-Instruct-v0.1" 335 | user_prompt:Optional[List[UserMessage]] = None 336 | max_tokens:Optional[int] = 256 337 | temperature:Optional[List[float]] = [1] 338 | handle_prompt_format:Optional[bool] = False 339 | json_mode:Optional[bool] = False 340 | json_schema:Optional[Dict] = None 341 | seed:Optional[int] = None 342 | tools:Optional[list] = None 343 | top_p:Optional[float] = 1 344 | stop:Optional[List[str]] = ["", "[/INST]"] 345 | presence_penalty: Optional[float] = 0 346 | frequency_penalty: Optional[float] = 0 347 | apikey:Optional[str] = None 348 | logprobs:Optional[bool] = False 349 | confidence_score:Optional[float] = None 350 | 351 | class Config: 352 | extra = 'forbid' 353 | 354 | 355 | @retry(retry=retry_if_result(is_retryable_answer), stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 356 | def ask(self) -> str: 357 | 358 | api_key = self.apikey if self.apikey else os.getenv("ANYSCALE_API_KEY") 359 | 360 | client = OpenAI(api_key=api_key, base_url='https://api.endpoints.anyscale.com/v1') 
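        # Anyscale Endpoints exposes an OpenAI-compatible API, so the standard OpenAI client is
        # simply pointed at the Anyscale base_url and the chat request below is assembled exactly
        # like a regular OpenAI chat completion call.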
361 | 362 | messages = pydantic_list_to_dict(lst = self.user_prompt, fields=['role', 'content']) 363 | 364 | param = { 365 | 366 | "model":self.name, 367 | "temperature": random.choice(self.temperature), 368 | "messages": messages 369 | 370 | } 371 | 372 | if self.stop: 373 | param["stop"] = self.stop 374 | 375 | if self.top_p: 376 | param["top_p"] = self.top_p 377 | 378 | if self.max_tokens: 379 | param["max_tokens"] = self.max_tokens 380 | 381 | if self.json_mode and self.json_schema: 382 | param["response_format"] = {"type": "json_object", "schema": self.json_schema} 383 | 384 | completion = client.chat.completions.create(**param) 385 | 386 | if self.logprobs: 387 | self.confidence_score = get_confidence_score(completion=completion) 388 | 389 | answer = completion.choices[0].message.content 390 | 391 | return answer 392 | 393 | class TogetherChatModel(BaseModel): 394 | 395 | name:str = "mistralai/Mixtral-8x7B-Instruct-v0.1" 396 | user_prompt:Optional[List[UserMessage]] = None 397 | max_tokens:Optional[int] = 256 398 | temperature:Optional[List[float]] = [1] 399 | json_mode:Optional[bool] = False 400 | seed:Optional[int] = None 401 | tools:Optional[list] = None 402 | top_p:Optional[float] = 1 403 | stop:Optional[List[str]] = ["", "[/INST]"] 404 | presence_penalty: Optional[float] = 0 405 | frequency_penalty: Optional[float] = 0 406 | logprobs:Optional[bool] = False 407 | confidence_score:Optional[float] = None 408 | 409 | def __init__(self, **data): 410 | 411 | super().__init__(**data) 412 | 413 | 414 | class Config: 415 | extra = 'forbid' 416 | 417 | 418 | @retry(retry=retry_if_result(is_retryable_answer), stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 419 | def ask(self) -> str: 420 | 421 | api_key = self.apikey if self.apikey else os.getenv("TOGETHER_API_KEY") 422 | 423 | client = OpenAI(api_key=api_key, base_url='https://api.together.xyz') 424 | 425 | messages = pydantic_list_to_dict(lst = self.user_prompt, fields=['role', 'content']) 426 | 427 | param = { 428 | 429 | "model":self.name, 430 | "temperature": random.choice(self.temperature), 431 | "messages": messages 432 | 433 | } 434 | 435 | if self.stop: 436 | param["stop"] = self.stop 437 | 438 | if self.top_p: 439 | param["top_p"] = self.top_p 440 | 441 | if self.max_tokens: 442 | param["max_tokens"] = self.max_tokens 443 | 444 | if self.json_mode: 445 | param["response_format"] = {"type": "json_object"} 446 | 447 | completion = client.chat.completions.create(**param) 448 | 449 | if self.logprobs: 450 | self.confidence_score = get_confidence_score(completion=completion) 451 | 452 | answer = completion.choices[0].message.content 453 | 454 | return answer 455 | 456 | class TogetherInstructModel(BaseModel): 457 | 458 | name:str = "mistralai/Mixtral-8x7B-Instruct-v0.1" 459 | max_tokens:Optional[int] = 256 460 | temperature:Optional[List[float]] = [1] 461 | handle_prompt_format:Optional[bool] = False 462 | user_prompt:Optional[str] = None 463 | seed:Optional[int] = None 464 | tools:Optional[List[str]] = None 465 | start_with:Optional[List[str]] = None 466 | top_p:Optional[float] = 1 467 | stop:Optional[List[str]] = None 468 | presence_penalty: Optional[float] = 0 469 | frequency_penalty: Optional[float] = 0 470 | confidence_score:Optional[float] = None 471 | 472 | class Config: 473 | extra = 'forbid' 474 | 475 | 476 | @retry(stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 477 | def ask(self) -> str: 478 | 479 | 480 | api_key = self.apikey if self.apikey 
else os.getenv("TOGETHER_API_KEY") 481 | 482 | client = OpenAI(api_key=api_key, base_url='https://api.together.xyz') 483 | 484 | param = { 485 | 486 | "model":self.name, 487 | "temperature": random.choice(self.temperature), 488 | "prompt": f"{self.user_prompt}" 489 | 490 | } 491 | 492 | if self.stop: 493 | param["stop"] = self.stop 494 | 495 | if self.top_p: 496 | param["top_p"] = self.top_p 497 | 498 | if self.max_tokens: 499 | param["max_tokens"] = self.max_tokens 500 | 501 | 502 | completion = client.completions.create(**param) 503 | 504 | answer = completion.choices[0].text 505 | 506 | return answer 507 | 508 | 509 | class MistralChatModel(BaseModel): 510 | 511 | name:str = "mistral-tiny" 512 | max_tokens:Optional[int] = 256 513 | temperature:Optional[List[float]] = [0.7] 514 | user_prompt:Optional[List[UserMessage]] = None 515 | random_seed:Optional[int] = None 516 | top_p:Optional[float] = 1 517 | safe_mode:Optional[bool] = False 518 | confidence_score:Optional[float] = None 519 | apikey:Optional[str] = None 520 | 521 | class Config: 522 | extra = 'forbid' 523 | 524 | @retry(stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 525 | def ask(self) -> str: 526 | 527 | api_key = self.apikey if self.apikey else os.getenv("MISTRAL_API_KEY") 528 | 529 | client = MistralClient(api_key=api_key) 530 | 531 | messages = [ChatMessage(role=msg.role, content=msg.content) for msg in self.user_prompt] 532 | 533 | param = { 534 | 535 | "model":self.name, 536 | "temperature": random.choice(self.temperature), 537 | "messages": messages 538 | 539 | } 540 | 541 | if self.max_tokens: 542 | param["max_tokens"] = self.max_tokens 543 | 544 | if self.top_p: 545 | param["top_p"] = self.top_p 546 | 547 | if self.random_seed: 548 | param["random_seed"] = self.random_seed 549 | 550 | chat_response = client.chat(**param) 551 | 552 | answer = chat_response.choices[0].message.content 553 | 554 | return answer 555 | 556 | 557 | class MusicGenModel(BaseModel): 558 | 559 | name:str = "facebook/musicgen-melody" 560 | duration:int = 4 561 | user_prompt:str 562 | audio:Optional[str] = None 563 | 564 | class Config: 565 | extra = 'forbid' 566 | 567 | def ask(self): 568 | 569 | model = MusicGen.get_pretrained(self.name) 570 | model.set_generation_params(duration=self.duration) 571 | descriptions = [self.user_prompt] 572 | 573 | if self.audio: 574 | melody, sr = torchaudio.load(self.audio) 575 | wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr) 576 | else: 577 | wav = model.generate(descriptions) 578 | 579 | for one_wav in wav: 580 | 581 | filename = f'music_{uuid.uuid4()}' 582 | 583 | audio_write(filename, one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) 584 | 585 | sound = AudioSegment.from_wav(f'{filename}.wav') 586 | 587 | #save to mp3 588 | sound.export(f'{filename}.mp3', format="mp3") 589 | 590 | return f'{filename}.mp3' 591 | 592 | class AudioGenModel(BaseModel): 593 | 594 | name:str = "facebook/audiogen-medium" 595 | duration:int = 4 596 | user_prompt:str 597 | 598 | class Config: 599 | extra = 'forbid' 600 | 601 | def ask(self): 602 | 603 | model = AudioGen.get_pretrained(self.name) 604 | model.set_generation_params(duration=self.duration) 605 | descriptions = [self.user_prompt] 606 | wav = model.generate(descriptions) 607 | 608 | for one_wav in wav: 609 | 610 | filename = f'audio_{uuid.uuid4()}' 611 | 612 | audio_write(filename, one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) 613 | 
614 | sound = AudioSegment.from_wav(f'{filename}.wav') 615 | 616 | #save to mp3 617 | sound.export(f'{filename}.mp3', format="mp3") 618 | 619 | return f'{filename}.mp3' 620 | 621 | 622 | class ClaudeChatModel(BaseModel): 623 | 624 | name:str = "claude-3-sonnet-20240229" 625 | user_prompt:Optional[List[UserMessage]] = None 626 | max_tokens:Optional[int] = 256 627 | temperature:Optional[List[float]] = [1] 628 | json_mode:Optional[bool] = False 629 | seed:Optional[int] = None 630 | tools:Optional[list] = None 631 | top_p:Optional[float] = 1 632 | note:Optional[List[str]] = None 633 | stop:Optional[List[str]] = None 634 | presence_penalty: Optional[float] = 0 635 | frequency_penalty: Optional[float] = 0 636 | logprobs:Optional[bool] = False 637 | confidence_score:Optional[float] = None 638 | apikey:Optional[str] = None 639 | 640 | class Config: 641 | extra = 'forbid' 642 | 643 | @retry(retry=retry_if_result(is_retryable_answer), stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 644 | def ask(self) -> str: 645 | 646 | # The Anthropic SDK exposes no OpenAI-style chat.completions endpoint, so the Messages API is used here. 647 | # The key falls back to ANTHROPIC_API_KEY (not OPENAI_API_KEY) when no explicit key is provided. 648 | api_key = self.apikey if self.apikey else os.getenv("ANTHROPIC_API_KEY") 649 | 650 | client = anthropic.Anthropic(api_key=api_key) 651 | 652 | messages = pydantic_list_to_dict(lst = self.user_prompt, fields=['role', 'content']) 653 | 654 | param = { 655 | 656 | "model":self.name, 657 | "temperature": random.choice(self.temperature), 658 | "messages": messages, 659 | 660 | } 661 | 662 | if self.stop: 663 | param["stop_sequences"] = self.stop 664 | 665 | if self.top_p: 666 | param["top_p"] = self.top_p 667 | 668 | if self.tools: 669 | param["tools"] = self.tools 670 | 671 | # max_tokens is required by the Anthropic Messages API, so it is always sent (the field defaults to 256). 672 | param["max_tokens"] = self.max_tokens if self.max_tokens else 256 673 | 674 | # json_mode, seed and logprobs are OpenAI-specific options; the Anthropic Messages API 675 | # does not support them, so they are intentionally not forwarded here. 676 | 677 | completion = client.messages.create(**param) 678 | 679 | # The response holds a list of content blocks; the generated text is in the first one. 680 | answer = completion.content[0].text 681 | 682 | return answer 683 | 684 | 685 | 686 | 687 | 688 | 689 | 690 | 691 | 692 | 693 | 694 | class OpenAIChatModel(BaseModel): 695 | 696 | name:str = "gpt-3.5-turbo-1106" 697 | user_prompt:Optional[List[UserMessage]] = None 698 | max_tokens:Optional[int] = 256 699 | temperature:Optional[List[float]] = [1] 700 | json_mode:Optional[bool] = False 701 | seed:Optional[int] = None 702 | tools:Optional[list] = None 703 | top_p:Optional[float] = 1 704 | note:Optional[List[str]] = None 705 | stop:Optional[List[str]] = None 706 | presence_penalty: Optional[float] = 0 707 | frequency_penalty: Optional[float] = 0 708 | logprobs:Optional[bool] = False 709 | confidence_score:Optional[float] = None 710 | apikey:Optional[str] = None 711 | 712 | class Config: 713 | extra = 'forbid' 714 | 715 | @retry(retry=retry_if_result(is_retryable_answer), stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 716 | def ask(self) -> str: 717 | 718 | client = OpenAI() 719 | 720 | if self.apikey: 721 | client.api_key = self.apikey 722 | else: 723 | client.api_key = os.getenv("OPENAI_API_KEY") 724 | 725 | messages = pydantic_list_to_dict(lst = self.user_prompt, fields=['role', 'content']) 726 | 727 | param = { 728 | 729 | "model":self.name, 730 | "temperature": random.choice(self.temperature), 731 | "messages": messages, 732 | "logprobs": self.logprobs 733 | 734 | } 735 | 736 | if self.stop: 737 |
param["stop"] = self.stop 738 | 739 | if self.top_p: 740 | param["top_p"] = self.top_p 741 | 742 | if self.tools: 743 | param["functions"] = self.tools 744 | 745 | if self.max_tokens: 746 | param["max_tokens"] = self.max_tokens 747 | 748 | if self.max_tokens: 749 | param["max_tokens"] = self.max_tokens 750 | 751 | if self.json_mode: 752 | param["response_format"] = {"type": "json_object"} 753 | 754 | if self.seed: 755 | param["seed"] = self.seed 756 | 757 | completion = client.chat.completions.create(**param) 758 | 759 | if self.logprobs: 760 | self.confidence_score = get_confidence_score(completion=completion) 761 | 762 | answer = completion.choices[0].message.content 763 | 764 | return answer 765 | 766 | 767 | 768 | class OpenAIITImageModel(BaseModel): 769 | 770 | name:str = "dall-e-2" 771 | image_path:str 772 | mask_path:str 773 | user_prompt:str 774 | size:Optional[str] = "1024x1024" 775 | number_of_images:Optional[int] = 1 776 | apikey:Optional[str] = None 777 | 778 | class Config: 779 | extra = 'forbid' 780 | 781 | #@retry(stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 782 | def ask(self) -> str: 783 | 784 | client = OpenAI() 785 | 786 | if self.apikey: 787 | client.api_key = self.apikey 788 | else: 789 | client.api_key = os.getenv("OPENAI_API_KEY") 790 | 791 | param = { 792 | "model":self.name, 793 | "image":open(self.image_path, "rb"), 794 | "mask":open(self.mask_path, "rb"), 795 | "size":self.size, 796 | "prompt":self.user_prompt, 797 | "n":self.number_of_images 798 | } 799 | 800 | completion = client.images.edit(**param) 801 | 802 | image_url = completion.data[0].url 803 | 804 | # Generate a random UUID and create a filename 805 | filename = f'image_{uuid.uuid4()}.png' 806 | 807 | response = requests.get(image_url) 808 | response.raise_for_status() # Raises a HTTPError if the response status code is 4XX/5XX 809 | 810 | # Create metadata 811 | metadata = PngInfo() 812 | metadata.add_text("image_url", image_url) 813 | 814 | # Since we're directly using the response content, convert it to a bytes stream 815 | image_bytes = io.BytesIO(response.content) 816 | 817 | # Open the image using Pillow 818 | with Image.open(image_bytes) as img: 819 | # Save the image with metadata 820 | img.save(filename, "PNG", pnginfo=metadata) 821 | 822 | uri = image_to_base64_data_uri(file_path=filename) 823 | 824 | return uri 825 | 826 | 827 | 828 | class OpenAITTImageModel(BaseModel): 829 | 830 | name:str = "dall-e-3" 831 | user_prompt:Optional[str] = None 832 | size:Optional[str] = "1024x1024" 833 | quality:Optional[str] = "standard" 834 | number_of_images:Optional[int] = 1 835 | apikey:Optional[str] = None 836 | 837 | class Config: 838 | extra = 'forbid' 839 | 840 | @retry(stop=stop_after_attempt(N_RETRIES), 841 | wait=wait_exponential(multiplier=1, min=4, max=60), 842 | retry=retry_if_exception_type(requests.exceptions.RequestException), 843 | reraise=True) 844 | 845 | def ask(self) -> str: 846 | 847 | client = OpenAI() 848 | 849 | if self.apikey: 850 | client.api_key = self.apikey 851 | else: 852 | client.api_key = os.getenv("OPENAI_API_KEY") 853 | 854 | param = { 855 | "model": self.name, 856 | "prompt": f"{self.user_prompt}", 857 | "size": self.size, 858 | "quality": self.quality, 859 | "n": self.number_of_images 860 | } 861 | 862 | try: 863 | 864 | completion = client.images.generate(**param) 865 | image_url = completion.data[0].url 866 | # Generate a random UUID and create a filename 867 | filename = f'image_{uuid.uuid4()}.png' 868 | response = 
requests.get(image_url) 869 | response.raise_for_status() # Raises a HTTPError if the response status code is 4XX/5XX 870 | # Create metadata 871 | metadata = PngInfo() 872 | metadata.add_text("image_url", image_url) 873 | # Since we're directly using the response content, convert it to a bytes stream 874 | image_bytes = io.BytesIO(response.content) 875 | # Open the image using Pillow 876 | with Image.open(image_bytes) as img: 877 | # Save the image with metadata 878 | img.save(filename, "PNG", pnginfo=metadata) 879 | uri = image_to_base64_data_uri(file_path=filename) 880 | return uri 881 | 882 | except requests.exceptions.RequestException as e: 883 | 884 | print(f"Error occurred: {str(e)}") 885 | raise 886 | 887 | 888 | 889 | class OpenAIInstructModel(BaseModel): 890 | 891 | name:str = "gpt-3.5-turbo-instruct" 892 | max_tokens:Optional[int] = 256 893 | temperature:Optional[List[float]] = [1] 894 | user_prompt:Optional[str] = None 895 | seed:Optional[int] = None 896 | tools:Optional[List[str]] = None 897 | start_with:Optional[List[str]] = None 898 | top_p:Optional[float] = 1 899 | stop:Optional[List[str]] = None 900 | presence_penalty: Optional[float] = 0 901 | frequency_penalty: Optional[float] = 0 902 | confidence_score:Optional[float] = None 903 | apikey:Optional[str] = None 904 | 905 | class Config: 906 | extra = 'forbid' 907 | 908 | @retry(stop=stop_after_attempt(N_RETRIES), wait=wait_exponential(multiplier=1, min=4, max=60)) 909 | def ask(self) -> str: 910 | 911 | client = OpenAI() 912 | 913 | if self.apikey: 914 | client.api_key = self.apikey 915 | else: 916 | client.api_key = os.getenv("OPENAI_API_KEY") 917 | 918 | param = { 919 | 920 | "model":self.name, 921 | "temperature": random.choice(self.temperature), 922 | "prompt": f"{self.user_prompt}" 923 | 924 | } 925 | 926 | if self.stop: 927 | param["stop"] = self.stop 928 | 929 | if self.top_p: 930 | param["top_p"] = self.top_p 931 | 932 | if self.tools: 933 | param["functions"] = self.tools 934 | 935 | if self.max_tokens: 936 | param["max_tokens"] = self.max_tokens 937 | 938 | if self.seed: 939 | param["seed"] = self.seed 940 | 941 | completion = client.completions.create(**param) 942 | 943 | answer = completion.choices[0].text 944 | 945 | return answer 946 | 947 | 948 | class OpenAIEmbeddingModel(BaseModel): 949 | 950 | name:str = "text-embedding-ada-002" 951 | 952 | class Config: 953 | extra = 'forbid' 954 | 955 | def create_embedding(self, prompt:str): 956 | 957 | client = OpenAI() 958 | 959 | if self.apikey: 960 | client.api_key = self.apikey 961 | else: 962 | client.api_key = os.getenv("OPENAI_API_KEY") 963 | 964 | embedding = client.embeddings.create( 965 | model=self.name, 966 | input=prompt 967 | ) 968 | 969 | return embedding.data[0].embedding 970 | 971 | class EmbeddingModel(BaseModel): 972 | 973 | openai_embedding_model:Optional[OpenAIEmbeddingModel] = None 974 | 975 | def get_model(self): 976 | if self.openai_embedding_model is not None: 977 | return self.openai_embedding_model 978 | else: 979 | return None 980 | 981 | 982 | class Model(BaseModel): 983 | 984 | openai_chat_model: Optional[OpenAIChatModel] = None 985 | claude_chat_model:Optional[ClaudeChatModel] = None 986 | suno_tts_model:Optional[BarkTTSModel] = None 987 | openai_instruct_model: Optional[OpenAIInstructModel] = None 988 | openai_tti_model: Optional[OpenAITTImageModel] = None 989 | openai_iti_model: Optional[OpenAIITImageModel] = None 990 | llamacpp_itt_model: Optional[LlamaCPPITTModel] = None 991 | llamacpp_instruct_model: Optional[LlamaCPPModel] = None 
992 | llamacpp_chat_model:Optional[LlamaCPPChatModel] = None 993 | mistral_chat_model: Optional[MistralChatModel] = None 994 | together_chat_model: Optional[TogetherChatModel] = None 995 | anyscale_chat_model: Optional[AnyscaleChatModel] = None 996 | whisper_model: Optional[WhisperModel] = None 997 | elevenlabs_tts_model: Optional[ElevenLabsTTSModel] = None 998 | musicgen: Optional[MusicGenModel] = None 999 | audiogen: Optional[AudioGenModel] = None 1000 | 1001 | def get_model(self): 1002 | 1003 | model_attributes = [ 1004 | "openai_chat_model", "openai_instruct_model", "mistral_chat_model", 1005 | "openai_iti_model", "llamacpp_instruct_model", "llamacpp_itt_model", 1006 | "together_chat_model", "anyscale_chat_model", "whisper_model", 1007 | "elevenlabs_tts_model", "openai_tti_model", "musicgen", "audiogen", 1008 | "claude_chat_model", "suno_tts_model", "llamacpp_chat_model" 1009 | ] 1010 | 1011 | for attr in model_attributes: 1012 | model = getattr(self, attr, None) 1013 | if model is not None: 1014 | return model 1015 | 1016 | return None 1017 | 1018 | 1019 | 1020 | def convert_openailogprobs_to_dict(completion): 1021 | 1022 | result = {} 1023 | 1024 | for logp in completion.choices[0].logprobs.content: 1025 | 1026 | result[logp.token] = math.exp(logp.logprob) 1027 | 1028 | return result 1029 | 1030 | 1031 | def extract_keyword_from_text(text:str): 1032 | 1033 | schema = { 1034 | "type": "object", 1035 | "properties": { 1036 | "keywords": { 1037 | "type": "array", 1038 | "items": { 1039 | "type": "string" 1040 | } 1041 | } 1042 | }, 1043 | "required": ["keywords"] 1044 | } 1045 | 1046 | system_prompt = f"Identify and extract all the important keyphrases from the given text return a valid JSON complying with this schema:\n{str(schema)}" 1047 | 1048 | user_prompt = f"Text:\n'''{text}'''" 1049 | 1050 | messages = [ 1051 | {"role":"system", "content":system_prompt}, 1052 | {"role":"user", "content":user_prompt} 1053 | ] 1054 | 1055 | model = OpenAIChatModel(user_prompt=messages, temperature=[0], json_mode=True) 1056 | 1057 | answer = model.ask() 1058 | 1059 | return answer 1060 | 1061 | def get_confidence_score(completion): 1062 | 1063 | confidence_score = {} 1064 | 1065 | logp_dict = convert_openailogprobs_to_dict(completion=completion) 1066 | 1067 | extract = extract_keyword_from_text(text=completion.choices[0].message.content) 1068 | 1069 | extract_json = json.loads(extract) 1070 | 1071 | if "keywords" in extract_json: 1072 | keywords = extract_json["keywords"] 1073 | else: 1074 | keywords = [] 1075 | 1076 | if len(keywords) == 0: 1077 | keywords = list(logp_dict.keys()) 1078 | 1079 | for keyword in keywords: 1080 | 1081 | try: 1082 | 1083 | encoding = tiktoken.get_encoding("cl100k_base") 1084 | 1085 | list_of_tokens_integers = encoding.encode(keyword) 1086 | 1087 | tokens = [encoding.decode_single_token_bytes(token).decode('utf-8') for token in list_of_tokens_integers] 1088 | # Initialize the minimum probability as 1 (maximum possible probability) 1089 | min_probability = 1 1090 | 1091 | for token in tokens: 1092 | # Check if token is in the dictionary and update the minimum probability 1093 | if token in logp_dict and logp_dict[token] < min_probability: 1094 | min_probability = logp_dict[token] 1095 | 1096 | # Store the minimum probability as the confidence level for the keyword 1097 | confidence_score[keyword] = min_probability 1098 | 1099 | except UnicodeDecodeError as e: 1100 | 1101 | print(f"Error decoding token {token}: {e}") 1102 | 1103 | 1104 | min_confidence_score = 
min(confidence_score.values()) if confidence_score else None 1105 | 1106 | return min_confidence_score 1107 | --------------------------------------------------------------------------------
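For reference, below is a minimal usage sketch for the model wrappers defined in opendatagen/model.py. It is a sketch only: it assumes the package is importable, that OPENAI_API_KEY is set, and that UserMessage is the pydantic message type used throughout this file.

    from opendatagen.model import Model, OpenAIChatModel, UserMessage

    # A single-turn chat prompt; ask() samples the temperature from the supplied list on every call.
    messages = [UserMessage(role="user", content="Write a one-sentence fun fact about synthetic data.")]

    chat_model = OpenAIChatModel(
        name="gpt-3.5-turbo-1106",
        user_prompt=messages,
        temperature=[0.7],
        max_tokens=64,
        logprobs=True,  # also populates confidence_score via get_confidence_score()
    )

    answer = chat_model.ask()
    print(answer)
    print(chat_model.confidence_score)

    # The Model wrapper simply returns whichever backend field was set.
    wrapper = Model(openai_chat_model=chat_model)
    assert wrapper.get_model() is chat_model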