├── LICENSE ├── README.md ├── README_zh.md ├── assets ├── api_keys │ ├── ChatGLM │ │ ├── api_key1.png │ │ ├── api_key2.png │ │ ├── login_register1.png │ │ ├── login_register2.png │ │ └── resources.png │ ├── DeepSeek │ │ ├── api_key2.png │ │ ├── login_register1.png │ │ ├── login_register2.png │ │ └── resources.png │ └── InternLM │ │ ├── api_key1.png │ │ ├── login_register1.png │ │ ├── login_register2.png │ │ ├── login_register3.png │ │ └── login_register4.png └── welcome.png ├── build └── lib │ └── edg4llm │ ├── __init__.py │ ├── core │ ├── __init__.py │ ├── dataGenerators.py │ ├── interface.py │ └── pipeline.py │ ├── generators │ ├── __init__.py │ └── text_generators │ │ ├── __init__.py │ │ ├── answer_generator.py │ │ ├── base_generator.py │ │ ├── dialogue_generator.py │ │ └── question_generator.py │ ├── models │ ├── __init__.py │ ├── baseModel.py │ ├── chatglm.py │ ├── chatgpt.py │ ├── deepseek.py │ └── internlm.py │ ├── processor │ ├── __init__.py │ ├── postprocess.py │ └── preprocess.py │ └── utils │ ├── __init__.py │ ├── config.py │ ├── data_utils.py │ ├── exceptions.py │ ├── list_supported_models.py │ ├── logger.py │ └── template.py ├── demos ├── chatglm_demo_v1_0_1.ipynb ├── chatgpt_demo_v1_0_1.ipynb ├── deepseek_demo_v1_0_1.ipynb ├── internlm_demo_v1_0_1.ipynb ├── question.json └── readme.md ├── dist ├── edg4llm-1.0.18-py3-none-any.whl └── edg4llm-1.0.18.tar.gz ├── docs ├── api_keys │ ├── ChatGLM_apply_for_api_key.md │ ├── DeepSeek_apply_for_api_key.md │ └── InternLM_apply_for_api_key.md └── readme.md ├── edg4llm.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── entry_points.txt ├── not-zip-safe ├── requires.txt └── top_level.txt ├── edg4llm ├── __init__.py ├── core │ ├── __init__.py │ ├── dataGenerators.py │ ├── interface.py │ └── pipeline.py ├── generators │ ├── __init__.py │ └── text_generators │ │ ├── __init__.py │ │ ├── answer_generator.py │ │ ├── base_generator.py │ │ ├── dialogue_generator.py │ │ └── question_generator.py ├── models │ ├── __init__.py │ ├── baseModel.py │ ├── chatglm.py │ ├── chatgpt.py │ ├── deepseek.py │ └── internlm.py ├── processor │ ├── __init__.py │ ├── postprocess.py │ └── preprocess.py └── utils │ ├── __init__.py │ ├── config.py │ ├── data_utils.py │ ├── exceptions.py │ ├── list_supported_models.py │ ├── logger.py │ └── template.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Alannikos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /assets/api_keys/ChatGLM/api_key1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/ChatGLM/api_key1.png -------------------------------------------------------------------------------- /assets/api_keys/ChatGLM/api_key2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/ChatGLM/api_key2.png -------------------------------------------------------------------------------- /assets/api_keys/ChatGLM/login_register1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/ChatGLM/login_register1.png -------------------------------------------------------------------------------- /assets/api_keys/ChatGLM/login_register2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/ChatGLM/login_register2.png -------------------------------------------------------------------------------- /assets/api_keys/ChatGLM/resources.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/ChatGLM/resources.png -------------------------------------------------------------------------------- /assets/api_keys/DeepSeek/api_key2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/DeepSeek/api_key2.png -------------------------------------------------------------------------------- /assets/api_keys/DeepSeek/login_register1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/DeepSeek/login_register1.png -------------------------------------------------------------------------------- /assets/api_keys/DeepSeek/login_register2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/DeepSeek/login_register2.png -------------------------------------------------------------------------------- /assets/api_keys/DeepSeek/resources.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/DeepSeek/resources.png -------------------------------------------------------------------------------- /assets/api_keys/InternLM/api_key1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/InternLM/api_key1.png -------------------------------------------------------------------------------- /assets/api_keys/InternLM/login_register1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/InternLM/login_register1.png -------------------------------------------------------------------------------- /assets/api_keys/InternLM/login_register2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/InternLM/login_register2.png -------------------------------------------------------------------------------- /assets/api_keys/InternLM/login_register3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/InternLM/login_register3.png -------------------------------------------------------------------------------- /assets/api_keys/InternLM/login_register4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/api_keys/InternLM/login_register4.png -------------------------------------------------------------------------------- /assets/welcome.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/assets/welcome.png -------------------------------------------------------------------------------- /build/lib/edg4llm/__init__.py: -------------------------------------------------------------------------------- 1 | from edg4llm.core.interface import EDG4LLM 2 | 3 | __all__ = ["EDG4LLM"] 4 | 5 | __version__ = "1.0.18" 6 | __author__ = "Alannikos" 7 | __license__ = "MIT" 8 | -------------------------------------------------------------------------------- /build/lib/edg4llm/core/__init__.py: -------------------------------------------------------------------------------- 1 | from edg4llm.core.interface import EDG4LLM 2 | -------------------------------------------------------------------------------- /build/lib/edg4llm/core/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Tuple, Dict 3 | 4 | from edg4llm.utils.logger import custom_logger 5 | from edg4llm.core.dataGenerators import DataGenerator 6 | 7 | logger = custom_logger("DataPipeline") 8 | 9 | class DataPipeline: 10 | """ 11 | The DataPipeline class manages the entire process of generating data, designed to 12 | automatically create fine-tuning data for different task types such as question 13 | generation, answer generation, and dialogue generation. 14 | 15 | This class uses a DataGenerator object to handle the core logic of data generation 16 | and dynamically executes the corresponding task based on the provided configuration 17 | parameters. It provides a unified interface for users to easily invoke specific 18 | data generation methods with minimal configuration. 
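For illustration, a minimal invocation might look like the following sketch; the URL and API key are placeholder assumptions, not working credentials:

>>> pConfig = {
...     "model_provider": "chatglm",
...     "model_name": "glm-4-flash",
...     "base_url": "https://open.bigmodel.cn/api/paas/v4/chat/completions",
...     "api_key": "YOUR_API_KEY",
... }
>>> pipeline = DataPipeline(pConfig)
>>> data = pipeline.generate_data({"task_type": "dialogue", "num_samples": 1})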
19 | 20 | Attributes: 21 | ---------- 22 | data_generator (DataGenerator): An object that handles the specific data generation tasks. 23 | 24 | Methods: 25 | ---------- 26 | __init__(pConfig): Initializes the DataPipeline class and creates a DataGenerator 27 | object based on the configuration. 28 | generate_data(tConfig): Generates fine-tuning data based on the task configuration. 29 | Supported task types include question generation, answer generation, 30 | and dialogue generation. 31 | """ 32 | 33 | def __init__(self, pConfig): 34 | """ 35 | Initializes the data generation process. 36 | 37 | Parameters 38 | ---------- 39 | pConfig : dict 40 | Configuration for initializing the DataGenerator. Expected to contain: 41 | - model_provider: str 42 | The type of language model to use, by default "chatglm". 43 | - model_name: str 44 | The specific model to use within the model type, by default "glm-4-flash". 45 | - base_url : str 46 | The base URL of the LLM API. 47 | - api_key : str 48 | The API key for authentication. 49 | """ 50 | 51 | self.data_generator = DataGenerator(pConfig) 52 | 53 | def generate_data(self, tConfig) -> Dict: 54 | """ 55 | Generates data based on the provided configuration. 56 | 57 | Parameters 58 | ---------- 59 | tConfig : Dict 60 | Task configuration containing the following keys: 61 | - task_type : str 62 | Specifies the type of task ('question', 'answer', or 'dialogue'). 63 | - Other parameters required for data generation, specific to the task type. 64 | 65 | Returns 66 | ------- 67 | dict 68 | A dictionary containing the generated fine-tuning data. 69 | 70 | Raises 71 | ------ 72 | ValueError 73 | If the provided task type is unsupported. 74 | """ 75 | if tConfig["task_type"] == "question": 76 | logger.info("Generating data for task_type: 'question'") 77 | data = self.data_generator.generate_question(tConfig) 78 | elif tConfig["task_type"] == "answer": 79 | logger.info("Generating data for task_type: 'answer'") 80 | data = self.data_generator.generate_answer(tConfig) 81 | elif tConfig["task_type"] == "dialogue": 82 | logger.info("Generating data for task_type: 'dialogue'") 83 | data = self.data_generator.generate_dialogue(tConfig) 84 | else: 85 | logger.error("Unsupported task type: %s", tConfig["task_type"]) 86 | raise ValueError("Unsupported task type") 87 | 88 | return data 89 | -------------------------------------------------------------------------------- /build/lib/edg4llm/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/build/lib/edg4llm/generators/__init__.py -------------------------------------------------------------------------------- /build/lib/edg4llm/generators/text_generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/build/lib/edg4llm/generators/text_generators/__init__.py -------------------------------------------------------------------------------- /build/lib/edg4llm/generators/text_generators/answer_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | from typing import Dict, List, Any 5 | 6 | from edg4llm.utils.logger import custom_logger 7 | from edg4llm.generators.text_generators.base_generator import BaseGenerator 8 | 9 | logger = custom_logger("AnswerGenerator") 
10 | 11 | class AnswerGenerator(BaseGenerator): 12 | """ 13 | A class for generating answers based on user queries using a specified model. 14 | 15 | This class extends the `BaseGenerator` class and provides functionality to generate 16 | answers to user queries based on a given configuration. It interacts with the model's 17 | `execute_request` method to generate responses based on system-level and user-level prompts. 18 | It supports customization through parameters such as temperature, sampling strategies, 19 | and token limits. 20 | 21 | Attributes 22 | ---------- 23 | model : object 24 | The model interface used for generating answers. 25 | 26 | Methods 27 | ------- 28 | generate(tConfig: dict) -> list of dict: 29 | Generates answers based on the provided configuration. 30 | 31 | Notes 32 | ----- 33 | - The `generate` method ensures valid answers are returned, retrying if necessary. 34 | - It logs progress for each generated answer. 35 | """ 36 | 37 | def __init__(self, model): 38 | """ 39 | Initialize the AnswerGenerator. 40 | 41 | Parameters 42 | ---------- 43 | model : object 44 | The model interface used for generating answers. 45 | """ 46 | 47 | super().__init__(model) 48 | 49 | def generate(self, tConfig) -> List: 50 | """ 51 | Generate answers based on the provided configuration. 52 | 53 | This method generates one or more answers based on the parameters provided in 54 | the `tConfig` dictionary. It uses the model's `execute_request` method to generate 55 | answers based on the system and user prompts, with options to control randomness, 56 | output length, and sampling strategy. 57 | 58 | Parameters 59 | ---------- 60 | tConfig : dict 61 | A configuration dictionary containing the following key-value pairs: 62 | - "system_prompt" : str, optional 63 | A system-level prompt that provides context for generating the answer. Default is an empty string. 64 | - "user_prompt" : str 65 | A user-provided prompt (query) to generate the corresponding answer. 66 | - "question_path" : str 67 | Path to a JSON file containing the questions to answer, as a list of {"question": ...} objects. 68 | - "do_sample" : bool, optional 69 | Whether to use sampling strategies during answer generation. Default is True. 70 | - "temperature" : float, optional 71 | A sampling parameter to control the randomness of the output. Must be between 0.0 and 1.0. Default is 0.95. 72 | - "top_p" : float, optional 73 | Nucleus sampling parameter controlling the cumulative probability range for token selection. 74 | Must be between 0.0 and 1.0. Default is 0.7. 75 | - "max_tokens" : int, optional 76 | The maximum number of tokens to generate in the answer. Default is 4095. 77 | - "num_samples" : int, optional 78 | The number of answers to generate. Default is 1. Must match the number of questions in the file. 79 | 80 | Returns 81 | ------- 82 | list of dict 83 | A list of dictionaries containing the generated answers. Each dictionary 84 | includes the generated answer content and relevant metadata. 85 | 86 | Notes 87 | ----- 88 | - The method will retry generating answers if the model fails to provide a valid response. 89 | - Progress and debug information are logged for each generated answer. 
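Examples
--------
An illustrative sketch only; the file path and prompts below are placeholder assumptions, and `model` is any object implementing `execute_request`:

>>> generator = AnswerGenerator(model)
>>> answers = generator.generate({
...     "system_prompt": "You are a helpful assistant.",
...     "user_prompt": "Please answer this question: EDG4LLM",
...     "question_path": "question.json",
...     "num_samples": 1,
... })

The literal token "EDG4LLM" in `user_prompt` is replaced with each question loaded from `question_path`, and `num_samples` must equal the number of questions in that file.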
90 | """ 91 | 92 | # Extract configuration parameters 93 | system_prompt = tConfig.get("system_prompt", "") 94 | user_prompt = tConfig.get("user_prompt", "") 95 | do_sample = tConfig.get("do_sample", True) 96 | temperature = tConfig.get("temperature", 0.95) 97 | top_p = tConfig.get("top_p", 0.7) 98 | max_tokens = tConfig.get("max_tokens", 4095) 99 | num_samples = tConfig.get("num_samples", 1) # Default is to generate 1 sample 100 | question_path = tConfig.get("question_path", None) 101 | 102 | try: 103 | with open(question_path, "r", encoding="utf-8") as file: 104 | data = json.load(file) 105 | 106 | if isinstance(data, dict): # If it's a single dictionary, wrap it in a list 107 | data = [data] 108 | elif not isinstance(data, list): # Ensure it's a list of dictionaries 109 | raise ValueError("Invalid JSON structure. Expected a list or a dictionary.") 110 | 111 | # Extract questions 112 | questions = [item["question"] for item in data if "question" in item] 113 | except FileNotFoundError: 114 | logger.error("The file at path %s was not found.", question_path) 115 | return None 116 | except json.JSONDecodeError as e: 117 | logger.error("Error decoding JSON from file %s: %s", question_path, str(e)) 118 | return None 119 | except Exception as e: 120 | logger.error("Unexpected error: %s", str(e)) 121 | return None 122 | 123 | if len(questions) != num_samples: 124 | logger.error( 125 | "The number of questions (%d) does not match the expected number (%d). Please check your input.", 126 | len(questions), 127 | num_samples, 128 | ) 129 | 130 | sys.exit(1) # 非零退出码表示异常终止 131 | 132 | # List to store the generated dialogues 133 | dialogues = [] 134 | 135 | # Generate dialogues for the specified number of samples 136 | total_samples = num_samples # Total number of samples to generate 137 | logger.info("Starting the data generation process.") 138 | for _idx, question in enumerate(questions): 139 | retry_count = 0 # 初始化重试计数 140 | max_retries = 5 # 设置最大重试次数(根据需要调整) 141 | 142 | while True: # Keep trying until valid dialogue data is generated 143 | retry_count += 1 144 | 145 | generated_answer = self.model.execute_request( 146 | system_prompt=system_prompt, 147 | user_prompt=user_prompt.replace("EDG4LLM", question), 148 | do_sample=do_sample, 149 | temperature=temperature, 150 | top_p=top_p, 151 | max_tokens=max_tokens, 152 | ) 153 | 154 | if "error" in generated_answer: 155 | logger.warning( 156 | "Sample %d: Request failed with error: %s. Retrying (%d/%d)...", 157 | _idx + 1, 158 | generated_answer["error"], 159 | retry_count, 160 | max_retries, 161 | ) 162 | 163 | if retry_count >= max_retries: 164 | logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx + 1) 165 | break # 跳出当前样本,进入下一个 166 | continue # 继续当前样本的生成 167 | 168 | # Convert the generated dialogue to the desired format (e.g., Alpaca format) 169 | converted_generated_answer = self._convert_original_to_alpaca_answer(system_prompt, question, generated_answer) 170 | 171 | if converted_generated_answer is not None: 172 | # If the dialogue is valid, append it to the results and break the loop 173 | dialogues.append(converted_generated_answer) 174 | break 175 | else: 176 | logger.warning( 177 | "Sample %d: Generated answer is None. Retrying (%d/%d)...", 178 | _idx + 1, 179 | retry_count, 180 | max_retries, 181 | ) 182 | 183 | if retry_count >= max_retries: 184 | logger.error("Sample %d: Max retries reached. 
Skipping this sample.", _idx + 1) 185 | break # Skip this sample 186 | 187 | # Log the progress of dialogue generation 188 | progress = ((_idx+1) / total_samples) * 100 189 | logger.info("Data generation progress: %.2f%% (%d/%d samples completed)", progress, _idx+1, total_samples) 190 | 191 | return dialogues 192 | -------------------------------------------------------------------------------- /build/lib/edg4llm/generators/text_generators/base_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from typing import Dict 4 | 5 | from edg4llm.processor.postprocess import PostProcessor 6 | class BaseGenerator(ABC): 7 | """ 8 | Base class for all data generators, defining a common interface for generating data. 9 | 10 | This class serves as a foundation for different types of data generators, providing common functionality 11 | such as interaction with a model and post-processing of generated data. Specific generators should extend 12 | this class and implement their own `generate` method. 13 | 14 | Attributes 15 | ---------- 16 | model : object 17 | The model interface used for generating data. 18 | postprocessor : PostProcessor 19 | An instance of the PostProcessor class for handling post-processing of generated data. 20 | 21 | Methods 22 | ------- 23 | generate(prompt: str) -> str 24 | Abstract method to generate data based on a prompt. Must be implemented by subclasses. 25 | 26 | """ 27 | def __init__(self, model): 28 | """ 29 | Initialize the generator. 30 | 31 | Parameters 32 | ---------- 33 | model : object 34 | The model interface used for generating data. 35 | """ 36 | 37 | self.model = model 38 | self.postprocessor = PostProcessor() 39 | 40 | @abstractmethod 41 | def generate(self, prompt: str) -> str: 42 | """ 43 | Generate data based on the provided prompt. 44 | 45 | This abstract method must be implemented by subclasses to produce 46 | generated data (e.g., dialogues, questions, or answers) using the model. 47 | 48 | Parameters 49 | ---------- 50 | prompt : str 51 | The prompt that guides data generation. Concrete subclasses 52 | typically accept a task configuration dictionary here instead. 53 | 54 | Returns 55 | ------- 56 | str 57 | The generated data. 58 | 59 | """ 60 | pass 61 | 62 | def _convert_original_to_alpaca(self, system_prompt, single_data): 63 | """ 64 | Convert original data into Alpaca format. 65 | 66 | This method uses the PostProcessor to process conversation data and structure it 67 | in a format suitable for Alpaca-based models. 68 | 69 | Parameters 70 | ---------- 71 | system_prompt : str 72 | The system-level prompt for context in the Alpaca format. 73 | single_data : str 74 | The raw conversation data to be processed. 75 | 76 | Returns 77 | ------- 78 | dict 79 | The conversation data converted to Alpaca format. 80 | """ 81 | 82 | converted_data = self.postprocessor.dialogue_postprocessing(conversation_data=single_data, system_prompt=system_prompt) 83 | 84 | return converted_data 85 | 86 | def _convert_original_to_json(self, single_data): 87 | """ 88 | Convert original data into JSON format. 89 | 90 | This method uses the PostProcessor to process raw data into a JSON-compatible structure. 91 | 92 | Parameters 93 | ---------- 94 | single_data : str 95 | The raw question data to be processed. 96 | 97 | Returns 98 | ------- 99 | dict 100 | The data converted into JSON format. 
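Examples
--------
A hypothetical sketch of the expected shape; the exact structure is determined by `PostProcessor.question_postprocessing`:

>>> self._convert_original_to_json('[{"question": "What is edg4llm?"}]')
[{'question': 'What is edg4llm?'}]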
101 | """ 102 | 103 | converted_data = self.postprocessor.question_postprocessing(question_data=single_data) 104 | 105 | return converted_data 106 | 107 | def _convert_original_to_alpaca_answer(self, system_prompt, question, single_data): 108 | """ 109 | Convert original data into Alpaca answer format. 110 | 111 | This method uses the PostProcessor to process raw data into an answer format suitable for Alpaca-based models. 112 | 113 | Parameters 114 | ---------- 115 | system_prompt : str 116 | The system-level prompt for context in the Alpaca format. 117 | question : str 118 | The question text for which the answer is generated. 119 | single_data : str 120 | The raw answer data to be processed. 121 | 122 | Returns 123 | ------- 124 | dict 125 | The data converted into Alpaca format. 126 | """ 127 | 128 | converted_data = self.postprocessor.answer_postprocessing(question=question, answer=single_data, system_prompt=system_prompt) 129 | 130 | return converted_data 131 | -------------------------------------------------------------------------------- /build/lib/edg4llm/generators/text_generators/dialogue_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Any 3 | 4 | from edg4llm.utils.logger import custom_logger 5 | from edg4llm.generators.text_generators.base_generator import BaseGenerator 6 | 7 | logger = custom_logger("DialogueGenerator") 8 | 9 | class DialogueGenerator(BaseGenerator): 10 | """ 11 | Dialogue Generator class for generating dialogues using a specified model. 12 | 13 | This class extends the `BaseGenerator` and utilizes the given model to generate dialogues 14 | based on user input and system prompts. It provides flexibility to control generation parameters 15 | like sampling strategies, temperature, and output format. 16 | 17 | Parameters 18 | ---------- 19 | model : object 20 | The model interface used for generating dialogues. This model must have the 21 | `execute_request` method for generating dialogue based on the given parameters. 22 | """ 23 | 24 | def __init__(self, model): 25 | """ 26 | Initialize the Dialogue Generator. 27 | 28 | This constructor initializes the `DialogueGenerator` by calling the base class constructor 29 | with the provided model. It sets up the necessary components for generating dialogues. 30 | 31 | Parameters 32 | ---------- 33 | model : object 34 | The model interface to be used for generating dialogues. It should provide 35 | the `execute_request` method to generate data based on the parameters. 36 | 37 | Notes 38 | ----- 39 | The `model` should be capable of handling inputs like system prompts, user prompts, 40 | and additional parameters for controlling the text generation process. 41 | """ 42 | super().__init__(model) 43 | 44 | def generate(self, tConfig) -> List: 45 | """ 46 | Generate dialogues based on the provided configuration. 47 | 48 | This method generates one or more dialogues based on the parameters provided in 49 | the `tConfig` dictionary. The method interacts with the model's `execute_request` 50 | function to generate dialogue based on the system and user prompts. It also supports 51 | various options for controlling randomness, output length, and sampling strategy. 52 | 53 | Parameters 54 | ---------- 55 | tConfig : dict 56 | A configuration dictionary containing the following key-value pairs: 57 | - "system_prompt" : str, optional 58 | A system-level prompt that guides the dialogue generation. Default is an empty string. 
59 | - "user_prompt" : str, optional 60 | A user-provided prompt to initiate the dialogue generation. Default is an empty string. 61 | - "model" : str, optional 62 | The specific model to use for generation. Default is "glm-4-flash". 63 | - "do_sample" : bool, optional 64 | Whether to use sampling strategies during text generation. Default is True. 65 | - "temperature" : float, optional 66 | A sampling parameter to control the randomness of output. Must be between 0.0 and 1.0. Default is 0.95. 67 | - "top_p" : float, optional 68 | Nucleus sampling parameter controlling the cumulative probability range for token selection. 69 | Must be between 0.0 and 1.0. Default is 0.7. 70 | - "max_tokens" : int, optional 71 | The maximum number of tokens to generate. Default is 4095. 72 | - "num_samples" : int, optional 73 | The number of dialogue samples to generate. Default is 1. 74 | 75 | Returns 76 | ------- 77 | list of dict 78 | A list of dictionaries containing the generated dialogues. Each dictionary 79 | includes the generated dialogue content. 80 | 81 | Notes 82 | ----- 83 | - The method will attempt to generate dialogues until a valid response is generated. 84 | If the generated dialogue is `None`, it will retry. 85 | - Progress is logged for each sample generated. 86 | """ 87 | 88 | # Extract configuration parameters 89 | system_prompt = tConfig.get("system_prompt", "") 90 | user_prompt = tConfig.get("user_prompt", "") 91 | do_sample = tConfig.get("do_sample", True) 92 | temperature = tConfig.get("temperature", 0.95) 93 | top_p = tConfig.get("top_p", 0.7) 94 | max_tokens = tConfig.get("max_tokens", 4095) 95 | num_samples = tConfig.get("num_samples", 1) # Default is to generate 1 sample 96 | 97 | # List to store the generated dialogues 98 | dialogues = [] 99 | 100 | # Generate dialogues for the specified number of samples 101 | total_samples = num_samples # Total number of samples to generate 102 | logger.info("Starting the data generation process.") 103 | for _idx in range(1, num_samples + 1): 104 | retry_count = 0 # 初始化重试计数 105 | max_retries = 5 # 设置最大重试次数(根据需要调整) 106 | 107 | while True: # Keep trying until valid dialogue data is generated 108 | retry_count += 1 109 | 110 | generated_dialogue = self.model.execute_request( 111 | system_prompt=system_prompt, 112 | user_prompt=user_prompt, 113 | do_sample=do_sample, 114 | temperature=temperature, 115 | top_p=top_p, 116 | max_tokens=max_tokens, 117 | ) 118 | 119 | if "error" in generated_dialogue: 120 | logger.warning( 121 | "Sample %d: Request failed with error: %s. Retrying (%d/%d)...", 122 | _idx, 123 | generated_dialogue["error"], 124 | retry_count, 125 | max_retries, 126 | ) 127 | 128 | if retry_count >= max_retries: 129 | logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx) 130 | break # 跳出当前样本,进入下一个 131 | 132 | continue # 继续当前样本的生成 133 | 134 | 135 | # Convert the generated dialogue to the desired format (e.g., Alpaca format) 136 | converted_generated_dialogue = self._convert_original_to_alpaca(system_prompt, generated_dialogue) 137 | 138 | if converted_generated_dialogue is not None: 139 | # If the dialogue is valid, append it to the results and break the loop 140 | dialogues.append(converted_generated_dialogue) 141 | break 142 | else: 143 | logger.warning( 144 | "Sample %d: Generated dialogue is None. Retrying (%d/%d)...", 145 | _idx, 146 | retry_count, 147 | max_retries, 148 | ) 149 | 150 | if retry_count >= max_retries: 151 | logger.error("Sample %d: Max retries reached. 
Skipping this sample.", _idx) 152 | break # Skip this sample 153 | 154 | 155 | # Log the progress of dialogue generation 156 | progress = (_idx / total_samples) * 100 157 | logger.info("Data generation progress: %.2f%% (%d/%d samples completed)", progress, _idx, total_samples) 158 | 159 | return dialogues 160 | -------------------------------------------------------------------------------- /build/lib/edg4llm/generators/text_generators/question_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Any 3 | from edg4llm.utils.logger import custom_logger 4 | from edg4llm.generators.text_generators.base_generator import BaseGenerator 5 | 6 | logger = custom_logger("QuestionGenerator") 7 | 8 | class QuestionGenerator(BaseGenerator): 9 | """ 10 | A class for generating questions based on user prompts and configuration. 11 | 12 | This class extends the `BaseGenerator` class and provides functionality to generate 13 | questions using a specified model. It interacts with the model's `execute_request` 14 | method to create output based on user-defined parameters such as sampling strategies, 15 | temperature, and maximum tokens. 16 | 17 | Attributes 18 | ---------- 19 | model : object 20 | The model interface used for generating questions. 21 | 22 | Methods 23 | ------- 24 | generate(tConfig: dict) -> list of dict: 25 | Generates questions based on the provided configuration. 26 | 27 | Notes 28 | ----- 29 | - The `generate` method ensures valid responses are returned, retrying if necessary. 30 | - Logs progress for each generated question. 31 | """ 32 | 33 | def __init__(self, model): 34 | """ 35 | Initialize the QuestionGenerator. 36 | 37 | Parameters 38 | ---------- 39 | model : object 40 | The model interface used for generating questions. 41 | """ 42 | 43 | super().__init__(model) 44 | 45 | def generate(self, tConfig: Dict) -> List: 46 | """ 47 | Generate questions based on the provided configuration. 48 | 49 | This method generates one or more questions using the parameters specified 50 | in the `tConfig` dictionary. It interacts with the model's `execute_request` 51 | method to generate output based on user prompts and various sampling options. 52 | 53 | Parameters 54 | ---------- 55 | tConfig : dict 56 | A dictionary containing configuration options for question generation: 57 | - "system_prompt" : str, optional 58 | A system-level instruction to guide the question generation. Default is an empty string. 59 | - "user_prompt" : str, optional 60 | A user-provided input to guide the question generation. Default is an empty string. 61 | - "model" : str, optional 62 | Specifies the model for text generation. Default is "glm-4-flash". 63 | - "do_sample" : bool, optional 64 | Whether to use sampling during generation. Default is True. 65 | - "temperature" : float, optional 66 | Controls randomness in output. Value should be between 0.0 and 1.0. Default is 0.95. 67 | - "top_p" : float, optional 68 | Nucleus sampling parameter to limit token selection to a cumulative probability. Default is 0.7. 69 | - "max_tokens" : int, optional 70 | The maximum number of tokens for the output. Default is 4095. 71 | - "num_samples" : int, optional 72 | The number of question samples to generate. Default is 1. 73 | 74 | Returns 75 | ------- 76 | list of dict 77 | A list of dictionaries containing the generated questions. 78 | 79 | Notes 80 | ----- 81 | - The method retries generation until a valid response is obtained. 
82 | - Logs progress for each generated sample. 83 | """ 84 | 85 | # Extract parameters from the configuration 86 | system_prompt = tConfig.get("system_prompt", "") 87 | user_prompt = tConfig.get("user_prompt", "") 88 | do_sample = tConfig.get("do_sample", True) 89 | temperature = tConfig.get("temperature", 0.95) 90 | top_p = tConfig.get("top_p", 0.7) 91 | max_tokens = tConfig.get("max_tokens", 4095) 92 | num_samples = tConfig.get("num_samples", 1) 93 | 94 | # Initialize a list to store generated questions 95 | questions = [] 96 | cur_len = 0 97 | # Generate questions for the specified number of samples 98 | logger.info("Starting the data generation process.") 99 | for _idx in range(1, num_samples + 1): 100 | retry_count = 0 # Initialize the retry counter 101 | max_retries = 5 # Maximum number of retries (adjust as needed) 102 | 103 | while True: # Retry until a valid question is generated 104 | retry_count += 1 105 | 106 | generated_question = self.model.execute_request( 107 | system_prompt=system_prompt, 108 | user_prompt=user_prompt, 109 | do_sample=do_sample, 110 | temperature=temperature, 111 | top_p=top_p, 112 | max_tokens=max_tokens, 113 | ) 114 | 115 | if isinstance(generated_question, dict) and "error" in generated_question: # a dict response always signals a failed request 116 | logger.warning( 117 | "Sample %d: Request failed with error: %s. Retrying (%d/%d)...", 118 | _idx, 119 | generated_question["error"], 120 | retry_count, 121 | max_retries, 122 | ) 123 | 124 | if retry_count >= max_retries: 125 | logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx) 126 | break # Skip this sample 127 | continue # Retry generation for the current sample (previously this branch fell through and tried to convert the error dict) 128 | # Convert the raw output to a specific format 129 | converted_question = self._convert_original_to_json(generated_question) 130 | 131 | if converted_question is not None: 132 | cur_len = len(converted_question) 133 | questions.extend(converted_question) 134 | break 135 | else: 136 | logger.warning( 137 | "Sample %d: Generated question is None. Retrying (%d/%d)...", 138 | _idx, 139 | retry_count, 140 | max_retries, 141 | ) 142 | 143 | if retry_count >= max_retries: 144 | logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx) 145 | break # Skip this sample 146 | 147 | # Log progress for tracking generation completion 148 | progress = (_idx / num_samples) * 100 149 | logger.info("Generation progress: %.2f%% (%d samples generated, %d/%d epoch completed)", progress, cur_len, _idx, num_samples) 150 | 151 | return questions 152 | -------------------------------------------------------------------------------- /build/lib/edg4llm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/build/lib/edg4llm/models/__init__.py -------------------------------------------------------------------------------- /build/lib/edg4llm/models/baseModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for defining the base class of EDG models. 3 | 4 | This file contains the abstract base class `EDGBaseModel`, which serves as a foundation for implementing various 5 | machine learning models. The class defines key methods that must be implemented by any derived model class 6 | to handle requests, send HTTP requests, and interact with APIs. 7 | 8 | Classes 9 | ------- 10 | EDGBaseModel(ABC) 11 | Abstract base class for EDG models, providing a standard structure for derived model implementations. 
12 | 13 | Methods 14 | ------- 15 | __init__(api_key: str = None, base_url: str = None, model_name: str = None) 16 | Initializes the base model with API key, base URL, and model name. 17 | 18 | execute_request(system_prompt: str, user_prompt: str, **kwargs) -> str 19 | Abstract method to process user input and generate model responses. 20 | Must be implemented by derived classes. 21 | 22 | send_request(request: Dict[str, Any]) -> Dict[str, Any] 23 | Abstract method to send HTTP requests and handle server interactions. 24 | Must be implemented by derived classes. 25 | """ 26 | 27 | import requests 28 | from abc import ABC, abstractmethod 29 | from typing import Any, Dict 30 | 31 | from edg4llm.utils.logger import custom_logger 32 | 33 | logger = custom_logger('baseModel') 34 | 35 | 36 | class EDGBaseModel(ABC): 37 | """ 38 | Abstract base class for EDG models. 39 | 40 | This class defines the blueprint for machine learning model implementations. Derived classes must 41 | implement methods to process user prompts, interact with APIs, and handle HTTP requests. 42 | 43 | Attributes 44 | ---------- 45 | api_key : str 46 | The API key required for authenticating requests. 47 | 48 | base_url : str 49 | The base URL of the model API endpoint. 50 | 51 | model_name : str 52 | The name of the model, used to differentiate between various models. 53 | """ 54 | 55 | def __init__(self, api_key: str = None, base_url: str = None, model_name: str = None): 56 | """ 57 | Initializes the base model with API key, base URL, and model name. 58 | 59 | Parameters 60 | ---------- 61 | api_key : str, optional 62 | The API key for authenticating requests. Default is None. 63 | 64 | base_url : str, optional 65 | The base URL of the model API endpoint. Default is None. 66 | 67 | model_name : str, optional 68 | The name of the model, used for identifying different models. Default is None. 69 | """ 70 | self.api_key = api_key 71 | self.base_url = base_url 72 | self.model_name = model_name 73 | 74 | @abstractmethod 75 | def execute_request(self, system_prompt: str, user_prompt: str, **kwargs) -> str: 76 | """ 77 | Abstract method to process and execute a request. 78 | 79 | This method must be implemented by derived classes. It processes user input and generates 80 | responses based on a system prompt and additional parameters. 81 | 82 | Parameters 83 | ---------- 84 | system_prompt : str 85 | The system-level instruction or prompt defining the role or behavior of the model. 86 | 87 | user_prompt : str 88 | The user's input or query for the model. 89 | 90 | kwargs : dict 91 | Additional parameters for processing the request. 92 | 93 | Returns 94 | ------- 95 | str 96 | The response generated by the model. 97 | 98 | Notes 99 | ----- 100 | - Derived classes should implement this method to handle the specific logic for generating responses. 101 | """ 102 | pass 103 | 104 | @abstractmethod 105 | def send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 106 | """ 107 | Abstract method to send HTTP requests. 108 | 109 | This method must be implemented by derived classes to handle API interactions and perform 110 | error handling for HTTP requests. 111 | 112 | Parameters 113 | ---------- 114 | request : dict 115 | A dictionary containing all necessary information for the HTTP request. 116 | 117 | Returns 118 | ------- 119 | dict 120 | The server's response as a dictionary. 121 | 122 | Notes 123 | ----- 124 | - Derived classes should implement this method to handle API-specific logic and error handling. 
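Examples
--------
A sketch of the request dictionary that concrete implementations in this package assemble (the key and model name are placeholders; compare `EDGChatGLM._execute_request` below):

>>> request = {
...     "url": "https://open.bigmodel.cn/api/paas/v4/chat/completions",
...     "headers": {"Authorization": "Bearer YOUR_API_KEY",
...                 "Content-Type": "application/json"},
...     "json": {"model": "glm-4-flash", "messages": []},
... }
>>> response = model.send_request(request)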
125 | """ 126 | pass 127 | -------------------------------------------------------------------------------- /build/lib/edg4llm/models/chatglm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast 4 | 5 | from edg4llm.utils.logger import custom_logger 6 | from edg4llm.models.baseModel import EDGBaseModel 7 | from edg4llm.utils.exceptions import HttpClientError, InvalidPromptError 8 | 9 | logger = custom_logger('chatglm') 10 | 11 | class EDGChatGLM(EDGBaseModel): 12 | """ 13 | EDGChatGLM interface for interacting with the ChatGLM model to generate text based on given prompts. 14 | 15 | This class provides an interface to interact with the ChatGLM model for generating text 16 | based on a system and user prompt. It supports customizable parameters such as temperature, 17 | sampling strategies, and model selection. It also handles HTTP requests and error management. 18 | 19 | Parameters 20 | ---------- 21 | base_url : str, optional 22 | The base URL for the ChatGLM API. If not provided, defaults to None. 23 | api_key : str, optional 24 | The API key for authenticating with the ChatGLM API. If not provided, defaults to None. 25 | """ 26 | 27 | def __init__(self, base_url: str = None, api_key: str = None, model_name: str = 'glm-4-flash'): 28 | """ 29 | Initialize the ChatGLM model interface. 30 | 31 | This constructor initializes the `EDGChatGLM` class by calling the base class constructor 32 | and passing the API key, base URL, and model name ("ChatGLM"). It sets up the necessary 33 | configuration for interacting with the ChatGLM API. 34 | 35 | Parameters 36 | ---------- 37 | base_url : str, optional 38 | The base URL for the ChatGLM API. Default is None. 39 | api_key : str, optional 40 | The API key for authenticating with the ChatGLM API. Default is None. 41 | model_name: str, optional 42 | The specific model to use within the selected provider. Default is "glm-4-flash". 43 | Notes 44 | ----- 45 | The base URL and API key are required for successful communication with the ChatGLM API. 46 | """ 47 | super().__init__(api_key, base_url, model_name=model_name) 48 | 49 | def execute_request( 50 | self, 51 | system_prompt: str = None, 52 | user_prompt: str = None, 53 | do_sample: bool = True, 54 | temperature: float = 0.95, 55 | top_p: float = 0.7, 56 | max_tokens: int = 4095 57 | ) -> str: 58 | """ 59 | Generate text using the ChatGLM model based on the provided prompts and parameters. 60 | 61 | This method calls the internal request execution function and handles the text 62 | generation process using the specified system and user prompts. It allows controlling 63 | text generation via parameters such as temperature, sampling strategy, and token limits. 64 | 65 | Parameters 66 | ---------- 67 | system_prompt : str, optional 68 | The system-level prompt that sets the context for the conversation. Default is None. 69 | user_prompt : str, optional 70 | The user-provided prompt that initiates the conversation. Default is None. 71 | do_sample : bool, optional 72 | Whether to use sampling during text generation. Default is True. 73 | temperature : float, optional 74 | Sampling temperature to control randomness. Default is 0.95. 75 | top_p : float, optional 76 | Nucleus sampling parameter for controlling randomness. Default is 0.7. 77 | max_tokens : int, optional 78 | The maximum number of tokens to generate in the output. Default is 4095. 
79 | 80 | Returns 81 | ------- 82 | str 83 | The generated text content from the model. 84 | 85 | Raises 86 | ------ 87 | InvalidPromptError 88 | If both the system and user prompts are None. 89 | """ 90 | response = self._execute_request(system_prompt, user_prompt, self.model_name, do_sample, temperature, top_p, max_tokens) 91 | return response 92 | 93 | def send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 94 | """ 95 | Send an HTTP request to the ChatGLM API. 96 | 97 | This method sends a POST request to the ChatGLM API with the provided request data. 98 | It returns the response data as a dictionary. 99 | 100 | Parameters 101 | ---------- 102 | request : dict 103 | A dictionary containing the request data, including the URL, headers, and JSON body. 104 | 105 | Returns 106 | ------- 107 | dict 108 | The response from the API in the form of a dictionary. 109 | 110 | Raises 111 | ------ 112 | HttpClientError 113 | If any error occurs during the HTTP request process. 114 | """ 115 | response = self._send_request(request=request) 116 | return response 117 | 118 | def _send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 119 | """ 120 | Internal method to send a POST request to the ChatGLM API. 121 | 122 | This method handles the actual HTTP POST request to the ChatGLM API. It includes 123 | error handling for HTTP errors, connection issues, timeouts, and JSON decoding. 124 | 125 | Parameters 126 | ---------- 127 | request : dict 128 | A dictionary containing the request data, including the URL, headers, and JSON body. 129 | 130 | Returns 131 | ------- 132 | dict 133 | The JSON response from the API. 134 | 135 | Raises 136 | ------ 137 | HttpClientError 138 | If an error occurs during the request. 139 | """ 140 | url = request.get("url", "https://open.bigmodel.cn/api/paas/v4/chat/completions") 141 | headers = {**request.get("headers", {})} 142 | json = request.get("json", {}) 143 | try: 144 | response = requests.post( 145 | url=url, 146 | headers=headers, 147 | json=json, 148 | timeout=30, 149 | ) 150 | response.raise_for_status() 151 | return response.json()["choices"][0]["message"]["content"].strip() 152 | 153 | except requests.exceptions.HTTPError as e: 154 | # Handle HTTP error exceptions 155 | status_code = e.response.status_code 156 | logger.error( 157 | "HTTP error occurred. 
Status Code: %s, URL: %s, Message: %s", 158 | status_code, 159 | url, 160 | e, 161 | ) 162 | 163 | return {"error": "HTTP error", "status_code": status_code, "message": str(e)} 164 | 165 | 166 | except requests.exceptions.ConnectionError as e: 167 | # Handle connection errors 168 | logger.error("Connection error occurred while connecting to %s: %s", url, e) 169 | 170 | return {"error": "Connection error", "message": str(e)} 171 | 172 | except requests.exceptions.Timeout as e: 173 | # Handle timeout errors 174 | logger.error("Timeout occurred while sending request to %s: %s", url, e) 175 | 176 | return {"error": "Timeout", "message": str(e)} 177 | 178 | 179 | except requests.exceptions.RequestException as e: 180 | # Handle any generic request exceptions 181 | logger.error( 182 | "Request exception occurred while sending request to %s: %s", url, e 183 | ) 184 | 185 | return {"error": "Request exception", "message": str(e)} 186 | 187 | 188 | except ValueError as e: 189 | # Handle JSON decoding errors 190 | logger.error("JSON decoding error occurred: %s", e) 191 | 192 | return {"error": "JSON decoding error", "message": str(e)} 193 | 194 | except Exception as e: 195 | # Catch any unexpected errors 196 | logger.critical( 197 | "An unexpected error occurred while sending request to %s: %s", url, e 198 | ) 199 | 200 | return {"error": "Unexpected error", "message": str(e)} 201 | 202 | def _execute_request( 203 | self, 204 | system_prompt: str = None, 205 | user_prompt: str = None, 206 | model: str = "glm-4-flash", 207 | do_sample: bool = True, 208 | temperature: float = 0.95, 209 | top_p: float = 0.7, 210 | max_tokens: int = 4095 211 | ) -> str: 212 | """ 213 | Internal method to prepare the request data and execute the request for text generation. 214 | 215 | This method prepares the necessary data (including headers, JSON body) for the 216 | ChatGLM API request and then calls the `send_request` method to send the request 217 | and return the response. 218 | 219 | Parameters 220 | ---------- 221 | system_prompt : str, optional 222 | The system-level prompt that provides context for the dialogue generation. 223 | Default is None. 224 | user_prompt : str, optional 225 | The user-provided prompt that initiates the generation. 226 | Default is None. 227 | model : str, optional 228 | The model to use for the generation. Default is "glm-4-flash". 229 | do_sample : bool, optional 230 | Whether to use sampling during text generation. Default is True. 231 | temperature : float, optional 232 | Sampling temperature to control randomness. Default is 0.95. 233 | top_p : float, optional 234 | Nucleus sampling parameter for controlling randomness. Default is 0.7. 235 | max_tokens : int, optional 236 | The maximum number of tokens to generate. Default is 4095. 237 | 238 | Returns 239 | ------- 240 | str 241 | The generated text content from the model. 242 | 243 | Raises 244 | ------ 245 | InvalidPromptError 246 | If both the system and user prompts are None. 
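Examples
--------
The request payload assembled by this method mirrors the code below; the prompt contents shown here are placeholders:

>>> request_data = {
...     "url": self.base_url,
...     "headers": {"Authorization": "Bearer <api_key>",
...                 "Content-Type": "application/json"},
...     "json": {"model": "glm-4-flash",
...              "messages": [{"role": "system", "content": "..."},
...                           {"role": "user", "content": "..."}],
...              "do_sample": True, "temperature": 0.95,
...              "top_p": 0.7, "max_tokens": 4095},
... }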
247 | """ 248 | if (system_prompt is None and user_prompt is None): 249 | logger.error("Both prompts cannot be empty") 250 | raise InvalidPromptError("Both prompts cannot be empty") 251 | 252 | request_data = { 253 | "url": f"{self.base_url}", 254 | "headers": { 255 | "Authorization": f"Bearer {self.api_key}", 256 | "Content-Type": "application/json", 257 | }, 258 | "json": { 259 | "model": model, 260 | "messages": [ 261 | {"role": "system", "content": system_prompt}, 262 | {"role": "user", "content": user_prompt}, 263 | ], 264 | "do_sample": do_sample, 265 | "temperature": temperature, 266 | "top_p": top_p, 267 | "max_tokens": max_tokens, 268 | }, 269 | } 270 | 271 | response = self.send_request(request_data) 272 | 273 | return response 274 | -------------------------------------------------------------------------------- /build/lib/edg4llm/models/chatgpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast 4 | 5 | from edg4llm.utils.logger import custom_logger 6 | from edg4llm.models.baseModel import EDGBaseModel 7 | from edg4llm.utils.exceptions import HttpClientError, InvalidPromptError 8 | 9 | logger = custom_logger('chatgpt') 10 | 11 | class EDGChatGPT(EDGBaseModel): 12 | """ 13 | A class to interface with the ChatGPT model for text generation. 14 | 15 | This class extends the `EDGBaseModel` abstract base class to implement a specific interface 16 | for interacting with the ChatGPT API. It supports text generation using system-level and 17 | user-level prompts with customizable parameters such as temperature, sampling strategies, 18 | and token limits. The class also includes methods to handle HTTP requests and manage errors. 19 | 20 | Attributes 21 | ---------- 22 | base_url : str 23 | The base URL for the ChatGPT API endpoint. 24 | api_key : str 25 | The API key for authenticating with the ChatGPT API. 26 | model_name : str 27 | The specific model to use, defaulting to "gpt-4o-mini". 28 | 29 | Methods 30 | ------- 31 | execute_request(system_prompt: str, user_prompt: str, do_sample: bool, temperature: float, top_p: float, max_tokens: int) -> str: 32 | Generates text using the ChatGPT model based on the provided prompts and parameters. 33 | 34 | send_request(request: Dict[str, Any]) -> Dict[str, Any]: 35 | Sends an HTTP POST request to the ChatGPT API and returns the response as a dictionary. 36 | 37 | Notes 38 | ----- 39 | - The `base_url` and `api_key` are required for proper communication with the ChatGPT API. 40 | - Provides detailed error handling for HTTP, connection, timeout, and JSON decoding issues. 41 | - Supports customizable text generation parameters for flexibility in model behavior. 42 | """ 43 | 44 | def __init__(self, base_url:str = None, api_key: str = None, model_name: str = "gpt-4o-mini"): 45 | """ 46 | Initialize the ChatGPT model interface. 47 | 48 | Parameters 49 | ---------- 50 | base_url : str, optional 51 | The base URL for the ChatGPT API. Default is None. 52 | api_key : str, optional 53 | The API key for authenticating with the ChatGPT API. Default is None. 54 | model_name : str, optional 55 | The specific model to use, defaulting to "gpt-4o-mini". 
56 | """ 57 | 58 | super().__init__(api_key, base_url, model_name=model_name) 59 | 60 | def execute_request( 61 | self 62 | , system_prompt: str = None 63 | , user_prompt: str = None 64 | , do_sample: bool = True 65 | , temperature: float = 0.95 66 | , top_p: float = 0.7 67 | , max_tokens: int = 4095 68 | ) -> str: 69 | 70 | """ 71 | Generate text using the ChatGPT model based on the provided prompts and parameters. 72 | 73 | Parameters 74 | ---------- 75 | system_prompt : str, optional 76 | The system-level prompt providing context for the text generation. Default is None. 77 | user_prompt : str, optional 78 | The user-provided prompt initiating the text generation. Default is None. 79 | do_sample : bool, optional 80 | Whether to use sampling during text generation. Default is True. 81 | temperature : float, optional 82 | Sampling temperature to control randomness. Default is 0.95. 83 | top_p : float, optional 84 | Nucleus sampling parameter to control randomness. Default is 0.7. 85 | max_tokens : int, optional 86 | The maximum number of tokens to generate. Default is 4095. 87 | 88 | Returns 89 | ------- 90 | str 91 | The generated text content from the model. 92 | 93 | Raises 94 | ------ 95 | InvalidPromptError 96 | If both system and user prompts are None. 97 | """ 98 | 99 | response = self._execute_request(system_prompt, user_prompt, self.model_name, do_sample, temperature, top_p, max_tokens) 100 | return response 101 | 102 | def send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 103 | 104 | """ 105 | Send an HTTP request to the ChatGPT API. 106 | 107 | Parameters 108 | ---------- 109 | request : dict 110 | A dictionary containing the request data, including the URL, headers, and JSON body. 111 | 112 | Returns 113 | ------- 114 | dict 115 | The response from the API in the form of a dictionary. 116 | 117 | Raises 118 | ------ 119 | HttpClientError 120 | If any error occurs during the HTTP request process. 121 | """ 122 | 123 | response = self._send_request(request=request) 124 | return response 125 | 126 | def _send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 127 | 128 | """ 129 | Internal method to send an HTTP POST request to the ChatGPT API. 130 | 131 | This method handles the actual HTTP POST request and manages error handling 132 | for issues like connection failures, timeouts, and JSON decoding errors. 133 | 134 | Parameters 135 | ---------- 136 | request : dict 137 | A dictionary containing the request data, including the URL, headers, and JSON body. 138 | 139 | Returns 140 | ------- 141 | dict 142 | The JSON response from the API. 143 | 144 | Raises 145 | ------ 146 | HttpClientError 147 | If an error occurs during the HTTP request. 148 | """ 149 | 150 | url = request.get("url", "https://api.openai.com/v1/chat/completions") 151 | headers = {**request.get("headers", {})} 152 | json = request.get("json", {}) 153 | try: 154 | response = requests.post( 155 | url=url, 156 | headers=headers, 157 | json=json, 158 | timeout=30, 159 | ) 160 | 161 | response.raise_for_status() 162 | 163 | return response.json()["choices"][0]["message"]["content"].strip() 164 | 165 | except requests.exceptions.HTTPError as e: 166 | # Handle HTTP error exceptions 167 | status_code = e.response.status_code 168 | logger.error( 169 | "HTTP error occurred. 
Status Code: %s, URL: %s, Message: %s", 170 | status_code, 171 | url, 172 | e, 173 | ) 174 | 175 | return {"error": "HTTP error", "status_code": status_code, "message": str(e)} 176 | 177 | 178 | except requests.exceptions.ConnectionError as e: 179 | # Handle connection errors 180 | logger.error("Connection error occurred while connecting to %s: %s", url, e) 181 | 182 | return {"error": "Connection error", "message": str(e)} 183 | 184 | except requests.exceptions.Timeout as e: 185 | # Handle timeout errors 186 | logger.error("Timeout occurred while sending request to %s: %s", url, e) 187 | 188 | return {"error": "Timeout", "message": str(e)} 189 | 190 | 191 | except requests.exceptions.RequestException as e: 192 | # Handle any generic request exceptions 193 | logger.error( 194 | "Request exception occurred while sending request to %s: %s", url, e 195 | ) 196 | 197 | return {"error": "Request exception", "message": str(e)} 198 | 199 | 200 | except ValueError as e: 201 | # Handle JSON decoding errors 202 | logger.error("JSON decoding error occurred: %s", e) 203 | 204 | return {"error": "JSON decoding error", "message": str(e)} 205 | 206 | except Exception as e: 207 | # Catch any unexpected errors 208 | logger.critical( 209 | "An unexpected error occurred while sending request to %s: %s", url, e 210 | ) 211 | 212 | return {"error": "Unexpected error", "message": str(e)} 213 | 214 | 215 | def _execute_request( 216 | self 217 | , system_prompt: str = None 218 | , user_prompt: str = None 219 | , model: str = "gpt-4o-mini" 220 | , do_sample: bool = True 221 | , temperature: float = 0.95 222 | , top_p: float = 0.7 223 | , max_tokens: int = 4095 224 | ) -> str: 225 | 226 | """ 227 | Internal method to prepare and execute the API request for text generation. 228 | 229 | Parameters 230 | ---------- 231 | system_prompt : str, optional 232 | The system-level prompt providing context for the text generation. Default is None. 233 | user_prompt : str, optional 234 | The user-provided prompt initiating the text generation. Default is None. 235 | model : str, optional 236 | The specific model to use for text generation. Default is "gpt-4o-mini". 237 | do_sample : bool, optional 238 | Whether to use sampling during text generation. Default is True. 239 | temperature : float, optional 240 | Sampling temperature to control randomness. Default is 0.95. 241 | top_p : float, optional 242 | Nucleus sampling parameter to control randomness. Default is 0.7. 243 | max_tokens : int, optional 244 | The maximum number of tokens to generate. Default is 4095. 245 | 246 | Returns 247 | ------- 248 | str 249 | The generated text content from the model. 250 | 251 | Raises 252 | ------ 253 | InvalidPromptError 254 | If both system and user prompts are None. 
255 | """ 256 | 257 | if (system_prompt is None and user_prompt is None): 258 | logger.error("prompt不能同时为空") 259 | raise InvalidPromptError("prompt不能同时为空") 260 | 261 | request_data = { 262 | "url": f"{self.base_url}", 263 | "headers": { 264 | "Authorization": f"Bearer {self.api_key}", 265 | "Content-Type": "application/json", 266 | }, 267 | "json": { 268 | "model": model, 269 | "messages": [ 270 | { 271 | "role": "developer", 272 | "content": system_prompt, 273 | }, 274 | { 275 | "role": "user", 276 | "content": user_prompt, 277 | } 278 | ], 279 | "temperature": temperature, 280 | "top_p": top_p, 281 | "max_tokens": max_tokens 282 | }, 283 | } 284 | 285 | response = self.send_request(request_data) 286 | return response 287 | -------------------------------------------------------------------------------- /build/lib/edg4llm/processor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/build/lib/edg4llm/processor/__init__.py -------------------------------------------------------------------------------- /build/lib/edg4llm/processor/postprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List, Any 3 | 4 | from edg4llm.utils.logger import custom_logger 5 | 6 | logger = custom_logger("PostProcessor") 7 | 8 | class PostProcessor: 9 | """ 10 | A class for post-processing conversation and question data. 11 | 12 | This class provides methods to clean and structure raw data obtained from API responses or external sources. 13 | It handles the removal of unnecessary markdown formatting, parses the data into valid JSON format, and 14 | structures it for further use in applications such as chatbots or AI assistants. It can also incorporate 15 | an optional system prompt into the processed data for context. 16 | 17 | Methods 18 | ------- 19 | dialogue_postprocessing(conversation_data: Dict[str, str], system_prompt: str = None): 20 | Processes raw conversation data by cleaning, parsing, and adding an optional system prompt. 21 | 22 | question_postprocessing(question_data: str = None): 23 | Processes raw question data by cleaning and structuring it into a list of questions. 24 | 25 | answer_postprocessing(question: str, answer: str, system_prompt: str = None): 26 | Processes raw answer data by cleaning, parsing, and structuring it along with the question 27 | and an optional system prompt. 28 | """ 29 | 30 | def __init__(self): 31 | pass 32 | 33 | def dialogue_postprocessing(self, conversation_data: Dict[str, str], system_prompt: str = None): 34 | """ 35 | Post-process conversation data. 36 | 37 | This function processes raw conversation data by removing unnecessary formatting and parsing it 38 | into a valid JSON format. If a system-level prompt (system_prompt) is provided, it will be added 39 | as an "instruction" field to the first conversation entry. The processed data is returned as a 40 | dictionary with a "conversation" key. 41 | 42 | Parameters 43 | ---------- 44 | conversation_data : str 45 | The raw conversation data in string format, typically from an API response or an external source. 46 | It may contain markdown-style formatting such as "```json" or "```" that needs to be removed. 47 | 48 | system_prompt : str, optional 49 | An optional system-level prompt that will be added to the "instruction" field of the first 50 | conversation entry. 
--------------------------------------------------------------------------------
/build/lib/edg4llm/processor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/build/lib/edg4llm/processor/__init__.py
--------------------------------------------------------------------------------
/build/lib/edg4llm/processor/postprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Dict, List, Any
3 | 
4 | from edg4llm.utils.logger import custom_logger
5 | 
6 | logger = custom_logger("PostProcessor")
7 | 
8 | class PostProcessor:
9 |     """
10 |     A class for post-processing conversation and question data.
11 | 
12 |     This class provides methods to clean and structure raw data obtained from API responses or external sources.
13 |     It handles the removal of unnecessary markdown formatting, parses the data into valid JSON format, and
14 |     structures it for further use in applications such as chatbots or AI assistants. It can also incorporate
15 |     an optional system prompt into the processed data for context.
16 | 
17 |     Methods
18 |     -------
19 |     dialogue_postprocessing(conversation_data: str, system_prompt: str = None):
20 |         Processes raw conversation data by cleaning, parsing, and adding an optional system prompt.
21 | 
22 |     question_postprocessing(question_data: str = None):
23 |         Processes raw question data by cleaning and structuring it into a list of questions.
24 | 
25 |     answer_postprocessing(question: str, answer: str, system_prompt: str = None):
26 |         Processes raw answer data by cleaning, parsing, and structuring it along with the question
27 |         and an optional system prompt.
28 |     """
29 | 
30 |     def __init__(self):
31 |         pass
32 | 
33 |     def dialogue_postprocessing(self, conversation_data: str, system_prompt: str = None):
34 |         """
35 |         Post-process conversation data.
36 | 
37 |         This function processes raw conversation data by removing unnecessary formatting and parsing it
38 |         into a valid JSON format. If a system-level prompt (system_prompt) is provided, it will be added
39 |         as an "instruction" field to the first conversation entry. The processed data is returned as a
40 |         dictionary with a "conversation" key.
41 | 
42 |         Parameters
43 |         ----------
44 |         conversation_data : str
45 |             The raw conversation data in string format, typically from an API response or an external source.
46 |             It may contain markdown-style formatting such as "```json" or "```" that needs to be removed.
47 | 
48 |         system_prompt : str, optional
49 |             An optional system-level prompt that will be added to the "instruction" field of the first
50 |             conversation entry. If not provided, an empty string will be used. Default is None.
51 | 
52 |         Returns
53 |         -------
54 |         dict or None
55 |             Returns a dictionary containing the processed conversation data structured under the "conversation" key.
56 |             Each item in the list corresponds to a conversation entry. If an error occurs during JSON parsing,
57 |             the function logs the error and returns None.
58 | 
59 |         Examples
60 |         --------
61 |         >>> conversation_data = '''
62 |         [
63 |             {"input": "AAA", "output": "BBBB"},
64 |             {"input": "CCC", "output": "DDDD"}
65 |         ]
66 |         '''
67 |         >>> system_prompt = "You are a helpful assistant."
68 |         >>> processed_data = dialogue_postprocessing(conversation_data, system_prompt)
69 | 
70 |         >>> # Output:
71 |         >>> {
72 |             "conversation": [
73 |                 {"input": "AAA", "output": "BBBB", "instruction": "You are a helpful assistant."},
74 |                 {"input": "CCC", "output": "DDDD"}
75 |             ]
76 |         }
77 | 
78 |         Notes
79 |         -----
80 |         - The function removes any markdown formatting (like "```json" or "```") before parsing the data.
81 |         - If JSON parsing fails, an error is logged, and the function returns None.
82 |         """
83 |         try:
84 |             # Clean and parse the JSON conversation data
85 |             conversation_data = json.loads(conversation_data.replace("```json", "").replace("```", ""))
86 |         except Exception as exception:
87 |             logger.error("Error parsing JSON: %s", str(exception))
88 |             return None
89 | 
90 |         # Initialize the result dictionary with a "conversation" key
91 |         result = {"conversation": []}
92 | 
93 |         # Add the system prompt as an instruction to the first conversation entry if provided
94 |         for idx, data in enumerate(conversation_data):
95 |             if idx == 0:
96 |                 data["instruction"] = system_prompt if system_prompt is not None else ""
97 |             result["conversation"].append(data)
98 | 
99 |         return result
100 | 
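A hypothetical round-trip through `dialogue_postprocessing` (the raw string imitates an LLM response; any markdown code-fence markers around the JSON would be stripped by the `.replace` calls above):

```python
from edg4llm.processor.postprocess import PostProcessor

# Illustrative raw payload; in practice this string comes from a model response.
raw = '[{"input": "What is edg4llm?", "output": "A fine-tuning data generation toolkit."}]'

processed = PostProcessor().dialogue_postprocessing(raw, system_prompt="You are a helpful assistant.")
print(processed)
# {'conversation': [{'input': 'What is edg4llm?',
#                    'output': 'A fine-tuning data generation toolkit.',
#                    'instruction': 'You are a helpful assistant.'}]}
```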
133 | """ 134 | 135 | try: 136 | # Clean up and parse the JSON question data 137 | question_data = json.loads(question_data.replace("```json", "").replace("```", "")) 138 | except Exception as exception: 139 | logger.error("Error parsing JSON: %s", str(exception)) 140 | return None 141 | 142 | # Initialize the result with a "question" key 143 | result = [] 144 | 145 | # Extract the question and assign it to the result 146 | for _, data in enumerate(question_data): 147 | result.append(data) 148 | 149 | return result 150 | 151 | def answer_postprocessing(self, question: str, answer: str, system_prompt: str = None): 152 | """ 153 | Post-process conversation data. 154 | 155 | This function processes raw conversation data by parsing it into a valid JSON format and structuring 156 | it into a predefined format. It also adds an optional system prompt to each conversation entry 157 | under the "instruction" key. The processed data is returned as a dictionary wrapped in a list. 158 | 159 | Parameters 160 | ---------- 161 | question : str 162 | The input question or query from the user. 163 | 164 | answer : str 165 | The raw answer data in string format, typically containing JSON content. 166 | This string may contain markdown formatting (e.g., "```json" or "```") that needs to be removed. 167 | 168 | system_prompt : str, optional 169 | An optional system-level prompt to provide context or instructions. This will be added to 170 | each conversation entry under the "instruction" key. Default is None. 171 | 172 | Returns 173 | ------- 174 | list or None 175 | Returns a list containing a dictionary with the processed conversation data. 176 | The dictionary has a "conversation" key, which is a list of conversation entries. 177 | Each entry contains "input", "output", and "instruction" keys. 178 | If an error occurs during JSON parsing, the function logs the error and returns None. 179 | 180 | Examples 181 | -------- 182 | >>> # Input: 183 | >>> question = "What is AI?" 184 | >>> answer = ''' 185 | [ 186 | { 187 | "input": question, 188 | "output": "BBB" 189 | } 190 | ] 191 | ''' 192 | >>> system_prompt = "You are a helpful assistant." 193 | 194 | >>> # Function Call: 195 | >>> processed_data = answer_postprocessing(question, answer, system_prompt) 196 | 197 | >>> # Output: 198 | >>> [ 199 | { 200 | "conversation": [ 201 | { 202 | "input": "What is AI?", 203 | "output": "BBB", 204 | "instruction": "You are a helpful assistant." 205 | } 206 | ] 207 | } 208 | ] 209 | 210 | Notes 211 | ----- 212 | - The function removes any markdown formatting (like "```json" or "```") before parsing the data. 213 | - If JSON parsing fails, the function logs an error and returns None. 214 | - The output is wrapped in a list to allow for future extensibility. 
215 | """ 216 | 217 | try: 218 | # Clean up and parse the JSON conversation data 219 | conversation_data = json.loads(answer.replace("```json","").replace("```","")) 220 | except Exception as exception: 221 | logger.error("Error parsing JSON: %s", str(exception)) 222 | return None 223 | 224 | # Initialize the result with a conversation key 225 | result = {"conversation": []} 226 | conversation = {"instruction" : system_prompt, "input" : question} 227 | # Add the system prompt to the first conversation entry if provided 228 | for idx, data in enumerate(conversation_data): 229 | conversation['output'] = data["answer"] 230 | result["conversation"].append(conversation) 231 | return result 232 | -------------------------------------------------------------------------------- /build/lib/edg4llm/processor/preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import json 4 | 5 | from edg4llm.utils.logger import custom_logger 6 | from edg4llm.utils.data_utils import is_question_template_consistent 7 | from edg4llm.utils.data_utils import is_answer_template_consistent 8 | from edg4llm.utils.data_utils import is_dialogue_template_consistent 9 | 10 | from edg4llm.utils.template import Template 11 | 12 | logger = custom_logger("preprocess") 13 | 14 | class PreProcessor: 15 | """ 16 | A class for pre-processing user prompts before data generation. 17 | 18 | This class provides methods to validate and repair user prompts in different modes such as question, 19 | answer, and dialogue. If a user prompt does not match the expected template, the methods automatically 20 | append the corresponding format guidelines to ensure consistency. 21 | 22 | Methods 23 | ------- 24 | question_preprocess(user_prompt: str) -> str: 25 | Validates and repairs user prompts in question mode. 26 | 27 | answer_preprocess(user_prompt: str) -> str: 28 | Validates and repairs user prompts in answer mode. 29 | 30 | dialogue_preprocess(user_prompt: str) -> str: 31 | Validates and repairs user prompts in Q&A (dialogue) mode. 32 | """ 33 | def __init__(self): 34 | pass 35 | 36 | def question_preprocess(self, language: str, user_prompt: str) -> str: 37 | """ 38 | Validates and processes user prompts in question mode. 39 | 40 | Parameters 41 | ---------- 42 | language : str 43 | The language of data in data generation. Must be one of 'zh', 'en'. 44 | 45 | user_prompt : str 46 | The user's input prompt to be processed in question mode. 47 | 48 | Returns 49 | ------- 50 | str 51 | The validated and, if necessary, repaired user prompt. 52 | 53 | Notes 54 | ----- 55 | - If the user prompt matches the question template, it is returned unchanged. 56 | - If the user prompt does not match, format guidelines from `Template.question_template` 57 | are appended to the prompt. 58 | """ 59 | 60 | if is_question_template_consistent(user_prompt=user_prompt): 61 | logger.info("User prompt matches the question template. Proceeding with data generation.") 62 | return user_prompt 63 | else: 64 | logger.warning("User prompt does not match the question template. Automatically added format guidelines.") 65 | if language == "zh": 66 | repaired_user_prompt = user_prompt + '\n' + Template.question_zh_template 67 | else: 68 | repaired_user_prompt = user_prompt + '\n' + Template.question_en_template 69 | return repaired_user_prompt 70 | 71 | def answer_preprocess(self, language: str, user_prompt: str) -> str: 72 | """ 73 | Validates and processes user prompts in answer mode. 
--------------------------------------------------------------------------------
/build/lib/edg4llm/processor/preprocess.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import json
4 | 
5 | from edg4llm.utils.logger import custom_logger
6 | from edg4llm.utils.data_utils import is_question_template_consistent
7 | from edg4llm.utils.data_utils import is_answer_template_consistent
8 | from edg4llm.utils.data_utils import is_dialogue_template_consistent
9 | 
10 | from edg4llm.utils.template import Template
11 | 
12 | logger = custom_logger("preprocess")
13 | 
14 | class PreProcessor:
15 |     """
16 |     A class for pre-processing user prompts before data generation.
17 | 
18 |     This class provides methods to validate and repair user prompts in different modes such as question,
19 |     answer, and dialogue. If a user prompt does not match the expected template, the methods automatically
20 |     append the corresponding format guidelines to ensure consistency.
21 | 
22 |     Methods
23 |     -------
24 |     question_preprocess(language: str, user_prompt: str) -> str:
25 |         Validates and repairs user prompts in question mode.
26 | 
27 |     answer_preprocess(language: str, user_prompt: str) -> str:
28 |         Validates and repairs user prompts in answer mode.
29 | 
30 |     dialogue_preprocess(language: str, user_prompt: str) -> str:
31 |         Validates and repairs user prompts in Q&A (dialogue) mode.
32 |     """
33 |     def __init__(self):
34 |         pass
35 | 
36 |     def question_preprocess(self, language: str, user_prompt: str) -> str:
37 |         """
38 |         Validates and processes user prompts in question mode.
39 | 
40 |         Parameters
41 |         ----------
42 |         language : str
43 |             The language of data in data generation. Must be one of 'zh', 'en'.
44 | 
45 |         user_prompt : str
46 |             The user's input prompt to be processed in question mode.
47 | 
48 |         Returns
49 |         -------
50 |         str
51 |             The validated and, if necessary, repaired user prompt.
52 | 
53 |         Notes
54 |         -----
55 |         - If the user prompt matches the question template, it is returned unchanged.
56 |         - If the user prompt does not match, format guidelines from `Template.question_zh_template`
57 |           or `Template.question_en_template` (depending on `language`) are appended to the prompt.
58 |         """
59 | 
60 |         if is_question_template_consistent(user_prompt=user_prompt):
61 |             logger.info("User prompt matches the question template. Proceeding with data generation.")
62 |             return user_prompt
63 |         else:
64 |             logger.warning("User prompt does not match the question template. Automatically added format guidelines.")
65 |             if language == "zh":
66 |                 repaired_user_prompt = user_prompt + '\n' + Template.question_zh_template
67 |             else:
68 |                 repaired_user_prompt = user_prompt + '\n' + Template.question_en_template
69 |             return repaired_user_prompt
70 | 
71 |     def answer_preprocess(self, language: str, user_prompt: str) -> str:
72 |         """
73 |         Validates and processes user prompts in answer mode.
74 | 
75 |         Parameters
76 |         ----------
77 |         language : str
78 |             The language of data in data generation. Must be one of 'zh', 'en'.
79 | 
80 |         user_prompt : str
81 |             The user's input prompt to be processed in answer mode.
82 | 
83 |         Returns
84 |         -------
85 |         str
86 |             The validated and, if necessary, repaired user prompt.
87 | 
88 |         Notes
89 |         -----
90 |         - If the user prompt matches the answer template, it is returned unchanged.
91 |         - If the user prompt does not match, format guidelines from `Template.answer_zh_template`
92 |           or `Template.answer_en_template` (depending on `language`) are appended to the prompt.
93 |         """
94 | 
95 |         if is_answer_template_consistent(user_prompt=user_prompt):
96 |             logger.info("User prompt matches the answer template. Proceeding with data generation.")
97 |             return user_prompt
98 |         else:
99 |             logger.warning("User prompt does not match the answer template. Automatically added format guidelines.")
100 |             if language == "zh":
101 |                 repaired_user_prompt = user_prompt + '\n' + Template.answer_zh_template
102 |             else:
103 |                 repaired_user_prompt = user_prompt + '\n' + Template.answer_en_template
104 |             return repaired_user_prompt
105 | 
106 |     def dialogue_preprocess(self, language: str, user_prompt: str) -> str:
107 |         """
108 |         Validates and processes user prompts in Q&A (dialogue) mode.
109 | 
110 |         Parameters
111 |         ----------
112 |         language : str
113 |             The language of data in data generation. Must be one of 'zh', 'en'.
114 | 
115 |         user_prompt : str
116 |             The user's input prompt to be processed in Q&A mode.
117 | 
118 |         Returns
119 |         -------
120 |         str
121 |             The validated and, if necessary, repaired user prompt.
122 | 
123 |         Notes
124 |         -----
125 |         - If the user prompt matches the dialogue template, it is returned unchanged.
126 |         - If the user prompt does not match, format guidelines from `Template.dialogue_zh_template`
127 |           or `Template.dialogue_en_template` (depending on `language`) are appended to the prompt.
128 |         """
129 | 
130 |         if is_dialogue_template_consistent(user_prompt=user_prompt):
131 |             logger.info("User prompt matches the dialogue template. Proceeding with data generation.")
132 |             return user_prompt
133 |         else:
134 |             logger.warning("User prompt does not match the dialogue template. Automatically added format guidelines.")
135 |             if language == "zh":
136 |                 repaired_user_prompt = user_prompt + '\n' + Template.dialogue_zh_template
137 |             else:
138 |                 repaired_user_prompt = user_prompt + '\n' + Template.dialogue_en_template
139 |             return repaired_user_prompt
140 | 
--------------------------------------------------------------------------------
/build/lib/edg4llm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/build/lib/edg4llm/utils/__init__.py
--------------------------------------------------------------------------------
/build/lib/edg4llm/utils/config.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | 
3 | @dataclasses.dataclass
4 | class DefaultConfig:
5 |     """
6 |     A placeholder class for default configuration settings.
7 |     """
8 |     pass
9 | 
--------------------------------------------------------------------------------
/build/lib/edg4llm/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Dict, List, Any
4 | 
5 | def is_question_template_consistent(user_prompt: str) -> bool:
6 |     """
7 |     Check if the user prompt contains a consistent question JSON template.
8 | 9 | Parameters 10 | ---------- 11 | user_prompt : str 12 | The user-provided prompt to be validated. 13 | 14 | Returns 15 | ------- 16 | bool 17 | True if the user prompt contains a valid and consistent question JSON template, 18 | False otherwise. 19 | 20 | Notes 21 | ----- 22 | - The function uses a regular expression to extract the JSON template and compares it 23 | with the target template. 24 | - The target template is: 25 | [ 26 | { 27 | "question": "AAA" 28 | } 29 | ] 30 | - Returns False if the JSON extraction or comparison fails. 31 | """ 32 | target_template = [ 33 | { 34 | "question": "AAA" 35 | } 36 | ] 37 | 38 | # Regular expression to extract JSON template 39 | pattern = r"\[\s*{\s*\"question\"\s*:\s*\"AAA\"\s*}\s*\]" 40 | match = re.search(pattern, user_prompt) 41 | 42 | if match: 43 | try: 44 | extracted_template = json.loads(match.group(0)) 45 | except json.JSONDecodeError: 46 | return False 47 | return extracted_template == target_template 48 | return False 49 | 50 | def is_answer_template_consistent(user_prompt: str) -> bool: 51 | """ 52 | Check if the user prompt contains a consistent answer JSON template. 53 | 54 | Parameters 55 | ---------- 56 | user_prompt : str 57 | The user-provided prompt to be validated. 58 | 59 | Returns 60 | ------- 61 | bool 62 | True if the user prompt contains a valid and consistent answer JSON template, 63 | False otherwise. 64 | 65 | Notes 66 | ----- 67 | - The function uses a regular expression to extract the JSON template and compares it 68 | with the target template. 69 | - The target template is: 70 | [ 71 | { 72 | "answer": "AAA" 73 | } 74 | ] 75 | - Returns False if the JSON extraction or comparison fails. 76 | """ 77 | target_template = [ 78 | { 79 | "answer": "AAA" 80 | } 81 | ] 82 | 83 | # Regular expression to extract JSON template 84 | pattern = r"\[\s*{\s*\"answer\"\s*:\s*\"AAA\"\s*}\s*\]" 85 | match = re.search(pattern, user_prompt) 86 | 87 | if match: 88 | try: 89 | extracted_template = json.loads(match.group(0)) 90 | except json.JSONDecodeError: 91 | return False 92 | return extracted_template == target_template 93 | return False 94 | 95 | def is_dialogue_template_consistent(user_prompt: str) -> bool: 96 | """ 97 | Check if the user prompt contains a consistent dialogue JSON template. 98 | 99 | Parameters 100 | ---------- 101 | user_prompt : str 102 | The user-provided prompt to be validated. 103 | 104 | Returns 105 | ------- 106 | bool 107 | True if the user prompt contains a valid and consistent dialogue JSON template, 108 | False otherwise. 109 | 110 | Notes 111 | ----- 112 | - The function uses a regular expression to check for the dialogue JSON structure. 113 | - The expected template format is: 114 | [ 115 | { 116 | "input": "AAA", 117 | "output": "BBB" 118 | } 119 | ] 120 | """ 121 | 122 | pattern = r"\[\s*\{\{\s*\"input\"\s*:\s*\"AAA\"\s*,\s*\"output\"\s*:\s*\"BBB\"\s*\}\}\s*\]" 123 | match = re.search(pattern, user_prompt) 124 | return match is not None 125 | 126 | def save_data_to_json(data: List[Dict], output_path: str): 127 | """ 128 | Save a list of dictionaries to a JSON file. 129 | 130 | Parameters 131 | ---------- 132 | data : list of dict 133 | A list of dictionaries to be saved to a JSON file. Each dictionary should contain 134 | the data to be written. 135 | 136 | output_path : str 137 | The path (including the filename) where the JSON data will be saved. 138 | The file will be written in UTF-8 encoding. 139 | 140 | Returns 141 | ------- 142 | None 143 | This function does not return any value. 
It saves the data to the specified file. 144 | 145 | Examples 146 | -------- 147 | >>> data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] 148 | >>> save_data_to_json(data, "output.json") 149 | 150 | Notes 151 | ----- 152 | - The function uses `json.dump` to write the data to the file. 153 | - Non-ASCII characters are preserved with the `ensure_ascii=False` argument. 154 | - The file will be saved with an indentation of 4 spaces to make it human-readable. 155 | """ 156 | with open(output_path, 'w', encoding='utf-8') as f: 157 | json.dump(data, f, ensure_ascii=False, indent=4) 158 | -------------------------------------------------------------------------------- /build/lib/edg4llm/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class HttpClientError(Exception): 5 | """ 6 | Exception raised for errors encountered in the HTTP client. 7 | 8 | Parameters 9 | ---------- 10 | message : str 11 | A detailed error message describing the issue. 12 | status_code : Optional[int], optional 13 | The HTTP status code associated with the error, by default None. 14 | 15 | Attributes 16 | ---------- 17 | status_code : Optional[int] 18 | The HTTP status code associated with the error. 19 | """ 20 | 21 | def __init__(self, message: str, status_code: Optional[int] = None): 22 | super().__init__(message) 23 | self.status_code = status_code 24 | 25 | 26 | class InvalidPromptError(Exception): 27 | """ 28 | Custom exception raised when an invalid or empty prompt is encountered. 29 | 30 | Notes 31 | ----- 32 | This exception is intended to handle cases where a required prompt input 33 | is missing or invalid. 34 | """ 35 | pass 36 | -------------------------------------------------------------------------------- /build/lib/edg4llm/utils/list_supported_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from edg4llm.utils.logger import custom_logger 3 | 4 | class ModelManager: 5 | """ 6 | A class to manage supported model providers and their models. 7 | 8 | Attributes 9 | ---------- 10 | supported_models : dict 11 | A dictionary mapping provider names to their supported models. 12 | 13 | Methods 14 | ------- 15 | list_providers(): 16 | Returns a list of all supported providers. 17 | list_models_by_provider(provider_name): 18 | Returns a list of models supported by the given provider. 19 | """ 20 | def __init__(self): 21 | """ 22 | Initializes the ModelManager with a predefined list of supported models. 23 | """ 24 | self.supported_models = { 25 | "ChatGLM": ["glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flashx", "glm-4-flash"], 26 | "DeepSeek": ["deepseek-chat", "deepseek-reasoner"], 27 | "InternLM": ["internlm2.5-latest", "internlm3-latest"], 28 | "ChatGPT": ["gpt-3.5-turbo-16k", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", "gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o1-mini-2024-09-12", "o1-preview","o1-preview-2024-09-12"] 29 | } 30 | 31 | def list_providers(self): 32 | """ 33 | Lists all supported model providers. 34 | 35 | Returns 36 | ------- 37 | list 38 | A list of provider names. 39 | """ 40 | 41 | return list(self.supported_models.keys()) 42 | 43 | def list_models_by_provider(self, provider_name): 44 | """ 45 | Lists all models supported by a given provider. 46 | 47 | Parameters 48 | ---------- 49 | provider_name : str 50 | The name of the provider. 
51 | 52 | Returns 53 | ------- 54 | list or None 55 | A list of model names supported by the provider, 56 | or None if the provider does not exist. 57 | """ 58 | return self.supported_models.get(provider_name, None) 59 | 60 | def main(): 61 | """ 62 | Entry point of the script to display supported model providers 63 | and their corresponding models based on the user's input. 64 | """ 65 | parser = argparse.ArgumentParser(description="View the list of supported models.") 66 | parser.add_argument("--list-providers", action="store_true", help="List all supported providers.") 67 | parser.add_argument("--list-models", type=str, metavar="PROVIDER", help="View the list of models for a specific provider.") 68 | 69 | args = parser.parse_args() 70 | 71 | manager = ModelManager() 72 | 73 | if args.list_providers: 74 | providers = manager.list_providers() 75 | print("Supported model providers:") 76 | for provider in providers: 77 | print(f" - {provider}") 78 | elif args.list_models: 79 | models = manager.list_models_by_provider(args.list_models) 80 | if models: 81 | print(f"{args.list_models} supports the following models:") 82 | for model in models: 83 | print(f" - {model}") 84 | else: 85 | print(f"Provider '{args.list_models}' does not exist or is not supported.") 86 | else: 87 | parser.print_help() 88 | 89 | -------------------------------------------------------------------------------- /build/lib/edg4llm/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | 4 | __all__ = ['custom_logger'] 5 | 6 | # Define log level colors for terminal output 7 | LOG_COLORS = { 8 | 'DEBUG': '\033[96m', # Cyan 9 | 'INFO': '\033[92m', # Green 10 | 'WARNING': '\033[93m', # Yellow 11 | 'ERROR': '\033[91m', # Red 12 | 'CRITICAL': '\033[1;91m', # Bold Red 13 | 'RESET': '\033[0m', # Reset color 14 | } 15 | 16 | def custom_logger(name: str): 17 | """ 18 | Creates a custom logger with color-coded log levels and UTC+8 time formatting. 19 | 20 | Parameters 21 | ---------- 22 | name : str 23 | The name of the logger, typically the name of the module or application. 24 | 25 | Returns 26 | ------- 27 | logging.Logger 28 | A customized logger instance with color-coded levels and UTC+8 timezone support. 29 | 30 | Notes 31 | ----- 32 | - Log levels are color-coded for easier readability in terminal output. 33 | - Log messages use UTC+8 timezone formatting. 34 | - The logger prevents propagation to root loggers and clears existing handlers. 35 | - The logger uses a custom `StreamHandler` with color support. 36 | """ 37 | # Create a logger instance 38 | logger = logging.getLogger(name) 39 | logger.setLevel(logging.INFO) # Default log level 40 | logger.propagate = False # Disable propagation to root loggers 41 | logger.handlers = [] # Clear any existing handlers 42 | 43 | # Define a custom log message format 44 | formatter = logging.Formatter( 45 | '[%(asctime)s]-[%(name)s:%(levelname)s]:%(message)s' 46 | ) 47 | 48 | # Custom time converter to use UTC+8 49 | def _utc8_aera(timestamp): 50 | """ 51 | Convert a timestamp to a UTC+8 time tuple. 52 | 53 | Parameters 54 | ---------- 55 | timestamp : float 56 | The timestamp to convert. 57 | 58 | Returns 59 | ------- 60 | time.struct_time 61 | A time tuple in UTC+8 timezone. 
62 | """ 63 | now = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + datetime.timedelta(hours=8) 64 | return now.timetuple() 65 | 66 | # Set the custom time converter in the formatter 67 | formatter.converter = _utc8_aera 68 | 69 | # Define a custom StreamHandler with color-coded log levels 70 | class ColorStreamHandler(logging.StreamHandler): 71 | """ 72 | A custom logging stream handler that adds color coding to log messages. 73 | 74 | Methods 75 | ------- 76 | emit(record): 77 | Formats and outputs a log record with color coding based on log level. 78 | """ 79 | def emit(self, record): 80 | """ 81 | Format and emit a log record with color coding. 82 | 83 | Parameters 84 | ---------- 85 | record : logging.LogRecord 86 | The log record to process and output. 87 | """ 88 | try: 89 | msg = self.format(record) # Format the log record 90 | color = LOG_COLORS.get(record.levelname, LOG_COLORS['RESET']) # Get the color for the log level 91 | # Write the log message with color 92 | self.stream.write(f"{color}{msg}{LOG_COLORS['RESET']}\n") 93 | self.flush() # Flush the stream 94 | except Exception: 95 | self.handleError(record) # Handle any errors during logging 96 | 97 | # Create and configure the custom handler 98 | custom_handler = ColorStreamHandler() 99 | custom_handler.setFormatter(formatter) 100 | 101 | # Add the custom handler to the logger 102 | logger.addHandler(custom_handler) 103 | 104 | return logger 105 | -------------------------------------------------------------------------------- /build/lib/edg4llm/utils/template.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class Template: 5 | """ 6 | A class to define language-specific templates for user prompts, providing a strict JSON format 7 | to preprocess user input. If the user's prompt does not include format instructions, the 8 | appropriate template will be added to enforce the required structure. 9 | 10 | Attributes: 11 | ---------- 12 | question_zh_template : str 13 | A JSON format template for Chinese question prompts. Ensures that generated questions 14 | are returned in a JSON format with a "question" field. 15 | 16 | answer_zh_template : str 17 | A JSON format template for Chinese answer prompts. Ensures that generated answers 18 | are returned in a JSON format with an "answer" field. 19 | 20 | dialogue_zh_template : str 21 | A JSON format template for Chinese dialogue prompts. Ensures that the interaction is 22 | returned in a JSON format with "input" representing the question and "output" representing 23 | the response. 24 | 25 | question_en_template : str 26 | A JSON format template for English question prompts. Ensures that generated questions 27 | are returned in a JSON format with a "question" field. 28 | 29 | answer_en_template : str 30 | A JSON format template for English answer prompts. Ensures that generated answers 31 | are returned in a JSON format with an "answer" field. 32 | 33 | dialogue_en_template : str 34 | A JSON format template for English dialogue prompts. Ensures that the interaction is 35 | returned in a JSON format with "input" representing the question and "output" representing 36 | the response. 37 | 38 | Notes: 39 | ----- 40 | This class is designed for preprocessing user prompts. If a user's input does not include 41 | specific format instructions, the appropriate template (based on language) is appended to 42 | the user prompt to ensure compliance with the required JSON format. 
43 | """ 44 | 45 | question_zh_template = \ 46 | """ 47 | 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: 48 | [ 49 | { 50 | "question":"AAA" 51 | } 52 | ] 53 | 其中question字段表示生成的问题 54 | """ 55 | 56 | answer_zh_template = \ 57 | """ 58 | 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: 59 | [ 60 | { 61 | "answer":"AAA" 62 | } 63 | ] 64 | 其中answer字段表示生成的答案 65 | """ 66 | 67 | dialogue_zh_template = \ 68 | """ 69 | 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: 70 | [ 71 | {{ 72 | "input":"AAA","output":"BBB" 73 | }} 74 | ] 75 | 其中input字段表示问题, output字段回答 76 | """ 77 | 78 | question_en_template = \ 79 | """ 80 | Strictly follow the rules: Please return the generated data in the following format, 81 | only in JSON format. JSON template: 82 | [ 83 | { 84 | "question":"AAA" 85 | } 86 | ] 87 | The "question" field represents the generated question. 88 | """ 89 | 90 | answer_en_template = \ 91 | """ 92 | Strictly follow the rules: Please return the generated data in the following format, 93 | only in JSON format. JSON template: 94 | [ 95 | { 96 | "answer":"AAA" 97 | } 98 | ] 99 | The "answer" field represents the generated answer. 100 | """ 101 | 102 | dialogue_en_template = \ 103 | """ 104 | Strictly follow the rules: Please return the generated data in the following format, 105 | only in JSON format. JSON template: 106 | [ 107 | {{ 108 | "input":"AAA","output":"BBB" 109 | }} 110 | ] 111 | The "input" field represents the question, and the "output" field 112 | represents the answer. 113 | """ 114 | -------------------------------------------------------------------------------- /demos/chatglm_demo_v1_0_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import edg4llm" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "print(edg4llm.__version__)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from edg4llm import EDG4LLM\n", 28 | "\n", 29 | "api_key = \"xxx\"\n", 30 | "base_url = \"https://open.bigmodel.cn/api/paas/v4/chat/completions\"\n", 31 | "\n", 32 | "edg = EDG4LLM(model_provider='chatglm', model_name=\"glm-4-flash\", base_url=base_url, api_key=api_key)\n", 33 | "# 设置测试数据\n", 34 | "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n", 35 | "\n", 36 | "user_prompt = '''\n", 37 | " 目标: 1. 请生成过年为场景的连续多轮对话记录\n", 38 | " 2. 提出的问题要多样化。\n", 39 | " 3. 要符合人类的说话习惯。\n", 40 | " 4. 
严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n",
    "    [\n",
    "        {{\n",
    "            \"input\":\"AAA\",\"output\":\"BBB\" \n",
    "        }}\n",
    "    ]\n",
    "    其中input字段表示一个人的话语, output字段表示专家的话语\n",
    "'''\n",
    "num_samples = 1 # 只生成一个对话样本\n",
    "\n",
    "# 调用 generate 方法生成对话\n",
    "data_dialogue = edg.generate(\n",
    "    task_type=\"dialogue\",\n",
    "    system_prompt=system_prompt,\n",
    "    user_prompt=user_prompt,\n",
    "    num_samples=num_samples\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_dialogue)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edg4llm import EDG4LLM\n",
    "\n",
    "api_key = \"xxx\"\n",
    "base_url = \"https://open.bigmodel.cn/api/paas/v4/chat/completions\"\n",
    "\n",
    "edg = EDG4LLM(model_provider='chatglm', model_name=\"glm-4-flash\", base_url=base_url, api_key=api_key)\n",
    "# 设置测试数据\n",
    "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n",
    "\n",
    "user_prompt = '''\n",
    "    目标: 1. 请生成过年为场景的问题\n",
    "         2. 提出的问题要多样化。\n",
    "         3. 要符合人类的说话习惯。\n",
    "         4. 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n",
    "    [\n",
    "        {\n",
    "            \"question\": \"AAA\"\n",
    "        }\n",
    "    ]\n",
    "    其中question表示你生成的问题\n",
    "'''\n",
    "num_samples = 1 # 只生成一个对话样本\n",
    "\n",
    "# 调用 generate 方法生成对话\n",
    "data_question = edg.generate(\n",
    "    task_type=\"question\",\n",
    "    system_prompt=system_prompt,\n",
    "    user_prompt=user_prompt,\n",
    "    num_samples=num_samples\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edg4llm import EDG4LLM\n",
    "\n",
    "api_key = \"xxx\"\n",
    "base_url = \"https://open.bigmodel.cn/api/paas/v4/chat/completions\"\n",
    "\n",
    "edg = EDG4LLM(model_provider='chatglm', model_name=\"glm-4-flash\", base_url=base_url, api_key=api_key)\n",
    "# 设置测试数据\n",
    "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n",
    "\n",
    "user_prompt = '''\n",
    "    目标: 1. 请生成EDG4LLM为问题的回答\n",
    "         2. 要符合人类的说话习惯。\n",
    "         3. 
严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n", 134 | " [\n", 135 | " {\n", 136 | " \"answer\": \"AAA\"\n", 137 | " }\n", 138 | " ]\n", 139 | " 其中answer表示你的回答\n", 140 | "'''\n", 141 | "num_samples = 1 # 存在29个问题\n", 142 | "\n", 143 | "# 调用 generate 方法生成对话\n", 144 | "data_answer = edg.generate(\n", 145 | " task_type=\"answer\",\n", 146 | " system_prompt=system_prompt,\n", 147 | " user_prompt=user_prompt,\n", 148 | " num_samples=num_samples,\n", 149 | " question_path=\"./question.json\"\n", 150 | ")" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "print(data_answer)" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "edg4llm", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.10.0" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /demos/chatgpt_demo_v1_0_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import edg4llm" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "print(edg4llm.__version__)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from edg4llm import EDG4LLM\n", 28 | "\n", 29 | "api_key = \"xxx\"\n", 30 | "base_url = \"https://api.openai.com/v1/chat/completions\"\n", 31 | "\n", 32 | "edg = EDG4LLM(model_provider='chatgpt', model_name=\"gpt-3.5-turbo\", base_url=base_url, api_key=api_key)\n", 33 | "# 设置测试数据\n", 34 | "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n", 35 | "\n", 36 | "user_prompt = '''\n", 37 | " 目标: 1. 请生成过年为场景的连续多轮对话记录\n", 38 | " 2. 提出的问题要多样化。\n", 39 | " 3. 要符合人类的说话习惯。\n", 40 | " 4. 
严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n", 41 | " [\n", 42 | " {{\n", 43 | " \"input\":\"AAA\",\"output\":\"BBB\" \n", 44 | " }}\n", 45 | " ]\n", 46 | " 其中input字段表示一个人的话语, output字段表示专家的话语\n", 47 | "'''\n", 48 | "num_samples = 1 # 只生成一个对话样本\n", 49 | "\n", 50 | "# 调用 generate 方法生成对话\n", 51 | "data = edg.generate(\n", 52 | " task_type=\"dialogue\",\n", 53 | " system_prompt=system_prompt,\n", 54 | " user_prompt=user_prompt,\n", 55 | " num_samples=num_samples\n", 56 | ")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "print(data)" 66 | ] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "display_name": "edg4llm", 72 | "language": "python", 73 | "name": "python3" 74 | }, 75 | "language_info": { 76 | "codemirror_mode": { 77 | "name": "ipython", 78 | "version": 3 79 | }, 80 | "file_extension": ".py", 81 | "mimetype": "text/x-python", 82 | "name": "python", 83 | "nbconvert_exporter": "python", 84 | "pygments_lexer": "ipython3", 85 | "version": "3.10.0" 86 | } 87 | }, 88 | "nbformat": 4, 89 | "nbformat_minor": 2 90 | } 91 | -------------------------------------------------------------------------------- /demos/deepseek_demo_v1_0_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import edg4llm" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "0.1.0\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "print(edg4llm.__version__)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "\u001b[92m[2025-01-10 21:17:28,112]-[interface:INFO]:DataPipeline initialized successfully with the provided configuration.\u001b[0m\n", 39 | "\u001b[92m[2025-01-10 21:17:28,115]-[DataPipeline:INFO]:Generated data for task_type: 'dialogue'\u001b[0m\n", 40 | "\u001b[92m[2025-01-10 21:17:28,116]-[preprocess:INFO]:User prompt matches the dialogue template. Proceeding with data generation.\u001b[0m\n", 41 | "\u001b[92m[2025-01-10 21:17:28,116]-[DialogueGenerator:INFO]:Starting the data generation process.\u001b[0m\n", 42 | "\u001b[92m[2025-01-10 21:17:49,850]-[DialogueGenerator:INFO]:Data generation progress: 100.00% (1/1 samples completed)\u001b[0m\n", 43 | "\u001b[92m[2025-01-10 21:17:49,851]-[interface:INFO]:Data generation completed successfully for task_type: dialogue\u001b[0m\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "from edg4llm import EDG4LLM\n", 49 | "\n", 50 | "api_key = \"xxx\"\n", 51 | "base_url = \"https://api.deepseek.com/chat/completions\"\n", 52 | "\n", 53 | "edg = EDG4LLM(model_provider='deepseek', model_name=\"deepseek-chat\", base_url=base_url, api_key=api_key)\n", 54 | "# 设置测试数据\n", 55 | "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n", 56 | "\n", 57 | "user_prompt = '''\n", 58 | " 目标: 1. 请生成过年为场景的连续多轮对话记录\n", 59 | " 2. 提出的问题要多样化。\n", 60 | " 3. 要符合人类的说话习惯。\n", 61 | " 4. 
严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n", 62 | " [\n", 63 | " {{\n", 64 | " \"input\":\"AAA\",\"output\":\"BBB\" \n", 65 | " }}\n", 66 | " ]\n", 67 | " 其中input字段表示一个人的话语, output字段表示专家的话语\n", 68 | "'''\n", 69 | "num_samples = 1 # 只生成一个对话样本\n", 70 | "\n", 71 | "# 调用 generate 方法生成对话\n", 72 | "data = edg.generate(\n", 73 | " task_type=\"dialogue\",\n", 74 | " system_prompt=system_prompt,\n", 75 | " user_prompt=user_prompt,\n", 76 | " num_samples=num_samples\n", 77 | ")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "[{'conversation': [{'input': '过年了,家里准备怎么布置呢?', 'output': '过年时节,家中自然要张灯结彩,挂上红灯笼,贴上春联,以增添喜庆气氛。春联上常写有吉祥话语,如‘岁岁平安’、‘年年有余’,寓意着对新一年的美好祝愿。', 'instruction': '你是一个精通中国古代诗词的古文学大师'}, {'input': '那春联的内容有什么讲究吗?', 'output': '春联的内容讲究对仗工整,意义吉祥。上联与下联字数相等,结构对称,平仄相对,意义相关。例如,‘天增岁月人增寿,春满乾坤福满门’,既表达了时间的流转,又寄托了人们对幸福生活的向往。'}, {'input': '过年时,除了贴春联,还有哪些传统习俗呢?', 'output': '过年习俗丰富多彩,除了贴春联,还有放鞭炮、守岁、拜年、吃年夜饭等。放鞭炮是为了驱邪避凶,守岁则是全家团聚,共同迎接新年的到来。拜年是向长辈和亲朋好友表达新年祝福,而年夜饭则是家人团聚的重要时刻,象征着团圆和美满。'}, {'input': '年夜饭通常都吃些什么呢?', 'output': '年夜饭的菜肴丰富多样,各地风俗不同,菜品也有所差异。但通常都会有鱼,寓意‘年年有余’;还有饺子,象征着财富和团圆。此外,还有各种肉类、蔬菜和甜点,每道菜都承载着对新年的美好祝愿。'}, {'input': '过年期间,人们还会做些什么特别的活动吗?', 'output': '过年期间,人们会进行许多特别的活动,如舞龙舞狮、逛庙会、赏花灯等。舞龙舞狮是为了祈求风调雨顺、五谷丰登;逛庙会则是体验传统文化,购买年货;赏花灯则是欣赏精美的灯笼艺术,感受节日的欢乐气氛。'}, {'input': '过年时,孩子们最期待的是什么?', 'output': '孩子们最期待的莫过于收到压岁钱了。压岁钱是长辈给晚辈的祝福,寓意着驱邪避灾,保佑孩子平安健康。此外,孩子们还喜欢放鞭炮、看烟花,以及参与各种游戏和活动,享受节日的快乐。'}, {'input': '过年时,人们如何表达对亲朋好友的祝福?', 'output': '人们通过拜年、送贺卡、发短信或微信等方式,向亲朋好友表达新年的祝福。常见的祝福语有‘新年快乐’、‘万事如意’、‘身体健康’等,都是对对方的美好祝愿。'}, {'input': '过年时,有哪些诗词可以表达节日的喜庆和祝福?', 'output': '有许多诗词可以表达过年的喜庆和祝福,如王安石的《元日》:‘爆竹声中一岁除,春风送暖入屠苏。千门万户曈曈日,总把新桃换旧符。’这首诗描绘了春节的热闹景象,表达了人们对新年的美好期待。'}, {'input': '过年时,人们如何保持传统与现代的结合?', 'output': '在现代社会,人们既保留了传统的过年习俗,如贴春联、放鞭炮、吃年夜饭等,也融入了现代元素,如通过社交媒体发送祝福、观看春节联欢晚会等。这种结合既传承了文化,又适应了现代生活的节奏。'}, {'input': '过年时,有哪些地方特色的庆祝活动?', 'output': '各地过年的庆祝活动各具特色。例如,北方的庙会、南方的花市、西南的舞龙舞狮、东北的冰灯节等。这些活动不仅丰富了节日的内容,也展示了各地的文化特色和风土人情。'}]}]\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "print(data)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stderr", 104 | "output_type": "stream", 105 | "text": [ 106 | "\u001b[92m[2025-01-10 21:17:49,890]-[interface:INFO]:DataPipeline initialized successfully with the provided configuration.\u001b[0m\n", 107 | "\u001b[92m[2025-01-10 21:17:49,891]-[DataPipeline:INFO]:Generated data for task_type: 'question'\u001b[0m\n", 108 | "\u001b[92m[2025-01-10 21:17:49,892]-[preprocess:INFO]:User prompt matches the question template. 
Proceeding with data generation.\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:17:49,892]-[QuestionGenerator:INFO]:Starting the data generation process.\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:17:58,114]-[QuestionGenerator:INFO]:Generation progress: 100.00% (10 samples generated, 1/1 epoch completed)\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:17:58,115]-[interface:INFO]:Data generation completed successfully for task_type: question\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from edg4llm import EDG4LLM\n",
    "\n",
    "api_key = \"xxx\"\n",
    "base_url = \"https://api.deepseek.com/chat/completions\"\n",
    "\n",
    "edg = EDG4LLM(model_provider='deepseek', model_name=\"deepseek-chat\", base_url=base_url, api_key=api_key)\n",
    "# 设置测试数据\n",
    "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n",
    "\n",
    "user_prompt = '''\n",
    "    目标: 1. 请生成过年为场景的问题\n",
    "         2. 提出的问题要多样化。\n",
    "         3. 要符合人类的说话习惯。\n",
    "         4. 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n",
    "    [\n",
    "        {\n",
    "            \"question\": \"AAA\"\n",
    "        }\n",
    "    ]\n",
    "    其中question表示你生成的问题\n",
    "'''\n",
    "num_samples = 1 # 只生成一个对话样本\n",
    "\n",
    "# 调用 generate 方法生成对话\n",
    "data = edg.generate(\n",
    "    task_type=\"question\",\n",
    "    system_prompt=system_prompt,\n",
    "    user_prompt=user_prompt,\n",
    "    num_samples=num_samples\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'question': '在过年期间,家家户户都会贴春联,你能告诉我春联的起源和它在中国文化中的意义吗?'}, {'question': '过年时,人们常说‘年年有余’,这句话背后有什么深层的文化含义吗?'}, {'question': '在中国古代诗词中,有哪些著名的诗句是描写过年气氛的?'}, {'question': '过年时,人们会放鞭炮,这个习俗是怎么来的?它有什么特别的象征意义吗?'}, {'question': '在过年期间,家人团聚是非常重要的,你能分享一些古代诗词中关于家庭团聚的描写吗?'}, {'question': '过年时,人们会吃饺子,这个习俗有什么历史背景和文化意义?'}, {'question': '在中国古代,过年时皇帝会举行什么特别的仪式或活动?'}, {'question': '过年时,人们会给孩子压岁钱,这个传统是怎么来的?它有什么特别的寓意吗?'}, {'question': '在古代诗词中,有哪些诗句是描写过年时的喜庆和热闹气氛的?'}, {'question': '过年时,人们会挂红灯笼,这个习俗有什么特别的象征意义吗?'}]\n"
     ]
    }
   ],
   "source": [
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[92m[2025-01-10 21:17:58,141]-[interface:INFO]:DataPipeline initialized successfully with the provided configuration.\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:17:58,142]-[DataPipeline:INFO]:Generated data for task_type: 'answer'\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:17:58,143]-[preprocess:INFO]:User prompt matches the answer template. 
Proceeding with data generation.\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:17:58,145]-[AnswerGenerator:INFO]:Starting the data generation process.\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:18:01,111]-[AnswerGenerator:INFO]:Data generation progress: 100.00% (1/1 samples completed)\u001b[0m\n",
      "\u001b[92m[2025-01-10 21:18:01,112]-[interface:INFO]:Data generation completed successfully for task_type: answer\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from edg4llm import EDG4LLM\n",
    "\n",
    "api_key = \"xxx\"\n",
    "base_url = \"https://api.deepseek.com/chat/completions\"\n",
    "\n",
    "edg = EDG4LLM(model_provider='deepseek', model_name=\"deepseek-chat\", base_url=base_url, api_key=api_key)\n",
    "# 设置测试数据\n",
    "system_prompt = \"\"\"你是一个精通中国古代诗词的古文学大师\"\"\"\n",
    "\n",
    "user_prompt = '''\n",
    "    目标: 1. 请生成EDG4LLM为问题的回答\n",
    "         2. 要符合人类的说话习惯。\n",
    "         3. 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: \n",
    "    [\n",
    "        {\n",
    "            \"answer\": \"AAA\"\n",
    "        }\n",
    "    ]\n",
    "    其中answer表示你的回答\n",
    "'''\n",
    "num_samples = 1 # 只生成一个对话样本\n",
    "\n",
    "# 调用 generate 方法生成对话\n",
    "data = edg.generate(\n",
    "    task_type=\"answer\",\n",
    "    system_prompt=system_prompt,\n",
    "    user_prompt=user_prompt,\n",
    "    num_samples=num_samples,\n",
    "    question_path=\"./question.json\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "edg4llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

--------------------------------------------------------------------------------
/demos/question.json:
--------------------------------------------------------------------------------
1 | [
2 |     {"question": "在这个喜庆的春节,你最喜欢哪种传统习俗?"}
3 | ]
--------------------------------------------------------------------------------
/demos/readme.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/demos/readme.md
--------------------------------------------------------------------------------
/dist/edg4llm-1.0.18-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/dist/edg4llm-1.0.18-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/edg4llm-1.0.18.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/dist/edg4llm-1.0.18.tar.gz
--------------------------------------------------------------------------------
/docs/api_keys/ChatGLM_apply_for_api_key.md:
--------------------------------------------------------------------------------
1 | ## How to apply for a ChatGLM api_key?
2 | 
3 | ### 1. Register an account on the ChatGLM website
4 | First, register an account on the official site at https://bigmodel.cn/ and click `Login/Register` in the upper-right corner.
5 | ![register](/assets/api_keys/ChatGLM/login_register1.png)
6 | 
7 | You can then register and log in with a phone number.
8 | ![register](/assets/api_keys/ChatGLM/login_register2.png)
9 | 
10 | After that, you will see a complimentary resource pack of 20 million tokens in your account. It covers models such as `glm-4-plus` and `glm-4-air`, and `glm-4-flash` has unlimited usage, which means the `glm-4-flash` model is free to use.
11 | ![register](/assets/api_keys/ChatGLM/resources.png)
12 | 
13 | ### 2. Apply for an `api_key`
14 | Go to the personal center, where you will see the following page.
15 | ![api_key1](/assets/api_keys/ChatGLM/api_key1.png)
16 | Click the small key icon next to Billing in the upper-right corner to open the `api_key` page.
17 | ![api_key2](/assets/api_keys/ChatGLM/api_key2.png)
18 | 
19 | At this point you can use this api_key directly, or generate a new `api_key` yourself; either works.
20 | 
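Once issued, the key plugs into `edg4llm` as in the demo notebooks; a minimal sketch (the key value is a placeholder):

```python
from edg4llm import EDG4LLM

# Placeholder key: substitute the api_key generated in step 2 above.
edg = EDG4LLM(
    model_provider="chatglm",
    model_name="glm-4-flash",  # the free, unlimited-usage model
    base_url="https://open.bigmodel.cn/api/paas/v4/chat/completions",
    api_key="xxx",
)
# `edg.generate(...)` then consumes this configuration, as shown in demos/.
```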
--------------------------------------------------------------------------------
/docs/api_keys/DeepSeek_apply_for_api_key.md:
--------------------------------------------------------------------------------
1 | ## How to apply for a DeepSeek api_key?
2 | 
3 | ### 1. Register an account on the DeepSeek website
4 | First, register an account on the official site at https://www.deepseek.com/ and click `接入API` (API Platform).
5 | ![register](/assets/api_keys/DeepSeek/login_register1.png)
6 | 
7 | You can then register and log in with a phone number.
8 | ![register](/assets/api_keys/DeepSeek/login_register2.png)
9 | 
10 | After that, you will see a complimentary resource pack of 5 million tokens in your account.
11 | ![register](/assets/api_keys/DeepSeek/resources.png)
12 | 
13 | ### 2. Apply for an `api_key`
14 | 
15 | Click `API keys` in the upper-left corner to open the `api_key` creation page.
16 | ![api_key2](/assets/api_keys/DeepSeek/api_key2.png)
17 | 
18 | Then create an API key of your own and you are ready to use our `edg4llm`.
--------------------------------------------------------------------------------
/docs/api_keys/InternLM_apply_for_api_key.md:
--------------------------------------------------------------------------------
1 | ## How to apply for an InternLM api_key?
2 | 
3 | ### 1. Register an account on the InternLM website
4 | First, register an account on the official site at https://internlm.intern-ai.org.cn/ and click `API` at the top of the page.
5 | ![register](/assets/api_keys/InternLM/login_register1.png)
6 | 
7 | Then click `Register` in the upper-right corner.
8 | ![register](/assets/api_keys/InternLM/login_register2.png)
9 | 
10 | You can then register and log in with a phone number or an email address.
11 | ![register](/assets/api_keys/InternLM/login_register3.png)
12 | 
13 | After logging in again, click `Apply for higher rate limits` and you will see that 3 million tokens are currently available for everyone to use.
14 | ![register](/assets/api_keys/InternLM/login_register4.png)
15 | 
16 | ### 2. Apply for an `api_key`
17 | Click `Get personal key` and you will see the following page.
18 | ![api_key1](/assets/api_keys/InternLM/api_key1.png)
19 | 
20 | Then create an API key of your own and you are ready to use our `edg4llm`.
--------------------------------------------------------------------------------
/docs/readme.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/docs/readme.md
--------------------------------------------------------------------------------
/edg4llm.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | setup.py
4 | edg4llm/__init__.py
5 | edg4llm.egg-info/PKG-INFO
6 | edg4llm.egg-info/SOURCES.txt
7 | edg4llm.egg-info/dependency_links.txt
8 | edg4llm.egg-info/entry_points.txt
9 | edg4llm.egg-info/not-zip-safe
10 | edg4llm.egg-info/requires.txt
11 | edg4llm.egg-info/top_level.txt
12 | edg4llm/core/__init__.py
13 | edg4llm/core/dataGenerators.py
14 | edg4llm/core/interface.py
15 | edg4llm/core/pipeline.py
16 | edg4llm/generators/__init__.py
17 | edg4llm/generators/text_generators/__init__.py
18 | edg4llm/generators/text_generators/answer_generator.py
19 | edg4llm/generators/text_generators/base_generator.py
20 | edg4llm/generators/text_generators/dialogue_generator.py
21 | edg4llm/generators/text_generators/question_generator.py
22 | edg4llm/models/__init__.py
23 | edg4llm/models/baseModel.py
24 | edg4llm/models/chatglm.py
25 | edg4llm/models/chatgpt.py
26 | edg4llm/models/deepseek.py
27 | edg4llm/models/internlm.py
28 | edg4llm/processor/__init__.py
29 | edg4llm/processor/postprocess.py
30 | edg4llm/processor/preprocess.py
31 | edg4llm/utils/__init__.py
32 | edg4llm/utils/config.py
33 | edg4llm/utils/data_utils.py
34 | edg4llm/utils/exceptions.py
35 | edg4llm/utils/list_supported_models.py
36 | edg4llm/utils/logger.py
37 | edg4llm/utils/template.py
--------------------------------------------------------------------------------
/edg4llm.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/edg4llm.egg-info/entry_points.txt:
--------------------------------------------------------------------------------
1 | [console_scripts]
2 | edg4llm-cli = edg4llm.utils.list_supported_models:main
3 | 
--------------------------------------------------------------------------------
/edg4llm.egg-info/not-zip-safe:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/edg4llm.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | requests>=2.32.3
2 | 
--------------------------------------------------------------------------------
/edg4llm.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | edg4llm
2 | 
--------------------------------------------------------------------------------
/edg4llm/__init__.py:
--------------------------------------------------------------------------------
1 | from edg4llm.core.interface import EDG4LLM
2 | 
3 | __all__ = ["EDG4LLM"]
4 | 
5 | __version__ = "1.0.18"
6 | __author__ = "Alannikos"
7 | __license__ = "MIT"
8 | 
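The `entry_points.txt` above wires the `edg4llm-cli` console script to `list_supported_models:main`; a rough equivalent of what that command prints, calling the `ModelManager` class directly:

```python
from edg4llm.utils.list_supported_models import ModelManager

manager = ModelManager()
print(manager.list_providers())
# ['ChatGLM', 'DeepSeek', 'InternLM', 'ChatGPT']
print(manager.list_models_by_provider("DeepSeek"))
# ['deepseek-chat', 'deepseek-reasoner']
```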
-------------------------------------------------------------------------------- 1 | from edg4llm.core.interface import EDG4LLM 2 | -------------------------------------------------------------------------------- /edg4llm/core/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Tuple, Dict 3 | 4 | from edg4llm.utils.logger import custom_logger 5 | from edg4llm.core.dataGenerators import DataGenerator 6 | 7 | logger = custom_logger("DataPipeline") 8 | 9 | class DataPipeline: 10 | """ 11 | The DataPipeline class manages the entire process of generating data, designed to 12 | automatically create fine-tuning data for different task types such as question 13 | generation, answer generation, and dialogue generation. 14 | 15 | This class uses a DataGenerator object to handle the core logic of data generation 16 | and dynamically executes the corresponding task based on the provided configuration 17 | parameters. It provides a unified interface for users to easily invoke specific 18 | data generation methods with minimal configuration. 19 | 20 | Attributes: 21 | ---------- 22 | data_generator (DataGenerator): An object that handles the specific data generation tasks. 23 | 24 | Methods: 25 | ---------- 26 | __init__(pConfig): Initializes the DataPipeline class and creates a DataGenerator 27 | object based on the configuration. 28 | generate_data(tConfig): Generates fine-tuning data based on the task configuration. 29 | Supported task types include question generation, answer generation, 30 | and dialogue generation. 31 | """ 32 | 33 | def __init__(self, pConfig): 34 | """ 35 | Initializes the data generation process. 36 | 37 | Parameters 38 | ---------- 39 | pConfig : dict 40 | Configuration for initializing the DataGenerator. Expected to contain: 41 | - model_provider: str 42 | The type of language model to use, by default "chatglm". 43 | - model_name: str 44 | The specific model to use within the model type, by default "chatglm-4-flash". 45 | - base_url : str 46 | The base URL of the LLM API. 47 | - api_key : str 48 | The API key for authentication. 49 | """ 50 | 51 | self.data_generator = DataGenerator(pConfig) 52 | 53 | def generate_data(self, tConfig) -> Dict: 54 | """ 55 | Generates data based on the provided configuration. 56 | 57 | Parameters 58 | ---------- 59 | tConfig : Dict 60 | Task configuration containing the following keys: 61 | - task_type : str 62 | Specifies the type of task ('question', 'answer', or 'dialogue'). 63 | - Other parameters required for data generation, specific to the task type. 64 | 65 | Returns 66 | ------- 67 | dict 68 | A dictionary containing the generated fine-tuning data. 69 | 70 | Raises 71 | ------ 72 | ValueError 73 | If the provided task type is unsupported. 
74 | """ 75 | if tConfig["task_type"] == "question": 76 | logger.info("Generated data for task_type: 'question'") 77 | data = self.data_generator.generate_question(tConfig) 78 | elif tConfig["task_type"] == "answer": 79 | logger.info("Generated data for task_type: 'answer'") 80 | data = self.data_generator.generate_answer(tConfig) 81 | elif tConfig["task_type"] == "dialogue": 82 | logger.info("Generated data for task_type: 'dialogue'") 83 | data = self.data_generator.generate_dialogue(tConfig) 84 | else: 85 | logger.error("Unsupported task type: %s", tConfig["task_type"]) 86 | raise ValueError("Unsupported task type") 87 | 88 | return data 89 | -------------------------------------------------------------------------------- /edg4llm/generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/edg4llm/generators/__init__.py -------------------------------------------------------------------------------- /edg4llm/generators/text_generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/edg4llm/generators/text_generators/__init__.py -------------------------------------------------------------------------------- /edg4llm/generators/text_generators/answer_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | from typing import Dict, Any 5 | 6 | from edg4llm.utils.logger import custom_logger 7 | from edg4llm.generators.text_generators.base_generator import BaseGenerator 8 | 9 | logger = custom_logger("AnswerGenerator") 10 | 11 | class AnswerGenerator(BaseGenerator): 12 | """ 13 | A class for generating answers based on user queries using a specified model. 14 | 15 | This class extends the `BaseGenerator` class and provides functionality to generate 16 | answers to user queries based on a given configuration. It interacts with the model's 17 | `execute_request` method to generate responses based on system-level and user-level prompts. 18 | It supports customization through parameters such as temperature, sampling strategies, 19 | and token limits. 20 | 21 | Attributes 22 | ---------- 23 | model : object 24 | The model interface used for generating answers. 25 | 26 | Methods 27 | ------- 28 | generate(tConfig: dict) -> list of dict: 29 | Generates answers based on the provided configuration. 30 | 31 | Notes 32 | ----- 33 | - The `generate` method ensures valid answers are returned, retrying if necessary. 34 | - It logs progress for each generated answer. 35 | """ 36 | 37 | def __init__(self, model): 38 | """ 39 | Initialize the AnswerGenerator. 40 | 41 | Parameters 42 | ---------- 43 | model : object 44 | The model interface used for generating answers. 45 | """ 46 | 47 | super().__init__(model) 48 | 49 | def generate(self, tConfig) -> str: 50 | """ 51 | Generate answers based on the provided configuration. 52 | 53 | This method generates one or more answers based on the parameters provided in 54 | the `tConfig` dictionary. It uses the model's `execute_request` method to generate 55 | answers based on the system and user prompts, with options to control randomness, 56 | output length, and sampling strategy. 
57 | 58 | Parameters 59 | ---------- 60 | tConfig : dict 61 | A configuration dictionary containing the following key-value pairs: 62 | - "system_prompt" : str, optional 63 | A system-level prompt that provides context for generating the answer. Default is an empty string. 64 | - "user_prompt" : str 65 | A user-provided prompt (query) to generate the corresponding answer. 66 | - "model" : str, optional 67 | The specific model to use for answer generation. Default is "glm-4-flash". 68 | - "do_sample" : bool, optional 69 | Whether to use sampling strategies during answer generation. Default is True. 70 | - "temperature" : float, optional 71 | A sampling parameter to control the randomness of the output. Must be between 0.0 and 1.0. Default is 0.95. 72 | - "top_p" : float, optional 73 | Nucleus sampling parameter controlling the cumulative probability range for token selection. 74 | Must be between 0.0 and 1.0. Default is 0.7. 75 | - "max_tokens" : int, optional 76 | The maximum number of tokens to generate in the answer. Default is 4095. 77 | - "num_samples" : int, optional 78 | The number of answers to generate. Default is 1. 79 | 80 | Returns 81 | ------- 82 | list of dict 83 | A list of dictionaries containing the generated answers. Each dictionary 84 | includes the generated answer content and relevant metadata. 85 | 86 | Notes 87 | ----- 88 | - The method will retry generating answers if the model fails to provide a valid response. 89 | - Progress and debug information are logged for each generated answer. 90 | """ 91 | 92 | # Extract configuration parameters 93 | system_prompt = tConfig.get("system_prompt", "") 94 | user_prompt = tConfig.get("user_prompt", "") 95 | do_sample = tConfig.get("do_sample", True) 96 | temperature = tConfig.get("temperature", 0.95) 97 | top_p = tConfig.get("top_p", 0.7) 98 | max_tokens = tConfig.get("max_tokens", 4095) 99 | num_samples = tConfig.get("num_samples", 1) # Default is to generate 1 sample 100 | question_path = tConfig.get("question_path", None) 101 | 102 | try: 103 | with open(question_path, "r", encoding="utf-8") as file: 104 | data = json.load(file) 105 | 106 | if isinstance(data, dict): # If it's a single dictionary, wrap it in a list 107 | data = [data] 108 | elif not isinstance(data, list): # Ensure it's a list of dictionaries 109 | raise ValueError("Invalid JSON structure. Expected a list or a dictionary.") 110 | 111 | # Extract questions 112 | questions = [item["question"] for item in data if "question" in item] 113 | except FileNotFoundError: 114 | logger.error("The file at path %s was not found.", question_path) 115 | return None 116 | except json.JSONDecodeError as e: 117 | logger.error("Error decoding JSON from file %s: %s", question_path, str(e)) 118 | return None 119 | except Exception as e: 120 | logger.error("Unexpected error: %s", str(e)) 121 | return None 122 | 123 | if len(questions) != num_samples: 124 | logger.error( 125 | "The number of questions (%d) does not match the expected number (%d). 
Please check your input.",
 126 |                 len(questions),
 127 |                 num_samples,
 128 |             )
 129 | 
 130 |             sys.exit(1)  # Exit with a non-zero status to signal abnormal termination
 131 | 
 132 |         # List to store the generated answers
 133 |         answers = []
 134 | 
 135 |         # Generate answers for the specified number of samples
 136 |         total_samples = num_samples  # Total number of samples to generate
 137 |         logger.info("Starting the data generation process.")
 138 |         for _idx, question in enumerate(questions):
 139 |             retry_count = 0  # Initialize the retry counter
 140 |             max_retries = 5  # Maximum number of retries (adjust as needed)
 141 | 
 142 |             while True:  # Keep trying until valid answer data is generated
 143 |                 retry_count += 1
 144 | 
 145 |                 generated_answer = self.model.execute_request(
 146 |                     system_prompt=system_prompt,
 147 |                     user_prompt=user_prompt.replace("EDG4LLM", question),
 148 |                     do_sample=do_sample,
 149 |                     temperature=temperature,
 150 |                     top_p=top_p,
 151 |                     max_tokens=max_tokens,
 152 |                 )
 153 | 
 154 |                 if "error" in generated_answer:
 155 |                     logger.warning(
 156 |                         "Sample %d: Request failed with error: %s. Retrying (%d/%d)...",
 157 |                         _idx + 1,
 158 |                         generated_answer["error"],
 159 |                         retry_count,
 160 |                         max_retries,
 161 |                     )
 162 | 
 163 |                     if retry_count >= max_retries:
 164 |                         logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx + 1)
 165 |                         break  # Give up on this sample and move on to the next one
 166 |                     continue  # Retry generation for the current sample
 167 | 
 168 |                 # Convert the generated answer to the desired format (e.g., Alpaca format)
 169 |                 converted_generated_answer = self._convert_original_to_alpaca_answer(system_prompt, question, generated_answer)
 170 | 
 171 |                 if converted_generated_answer is not None:
 172 |                     # If the answer is valid, append it to the results and break the loop
 173 |                     answers.append(converted_generated_answer)
 174 |                     break
 175 |                 else:
 176 |                     logger.warning(
 177 |                         "Sample %d: Generated answer is None. Retrying (%d/%d)...",
 178 |                         _idx + 1,
 179 |                         retry_count,
 180 |                         max_retries,
 181 |                     )
 182 | 
 183 |                     if retry_count >= max_retries:
 184 |                         logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx + 1)
 185 |                         break  # Give up on this sample
 186 | 
 187 |             # Log the progress of answer generation
 188 |             progress = ((_idx+1) / total_samples) * 100
 189 |             logger.info("Data generation progress: %.2f%% (%d/%d samples completed)", progress, _idx+1, total_samples)
 190 | 
 191 |         return answers
 192 | 
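 193 | # Illustrative usage sketch (not part of the public API; `model` is any
 194 | # EDGBaseModel subclass that implements `execute_request`):
 195 | #
 196 | #   generator = AnswerGenerator(model)
 197 | #   answers = generator.generate({
 198 | #       "system_prompt": "You are a helpful assistant.",
 199 | #       "user_prompt": "Answer the question: EDG4LLM",  # "EDG4LLM" is replaced by each question
 200 | #       "question_path": "question.json",  # a JSON list of {"question": ...} objects
 201 | #       "num_samples": 2,  # must equal the number of questions in the file
 202 | #   })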
--------------------------------------------------------------------------------
/edg4llm/generators/text_generators/base_generator.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from abc import ABC, abstractmethod
 3 | from typing import Dict
 4 | 
 5 | from edg4llm.processor.postprocess import PostProcessor
 6 | class BaseGenerator(ABC):
 7 |     """
 8 |     Base class for all data generators, defining a common interface for generating data.
 9 | 
 10 |     This class serves as a foundation for different types of data generators, providing common functionality
 11 |     such as interaction with a model and post-processing of generated data. Specific generators should extend
 12 |     this class and implement their own `generate` method.
 13 | 
 14 |     Attributes
 15 |     ----------
 16 |     model : object
 17 |         The model interface used for generating data.
 18 |     postprocessor : PostProcessor
 19 |         An instance of the PostProcessor class for handling post-processing of generated data.
 20 | 
 21 |     Methods
 22 |     -------
 23 |     generate(prompt: str) -> str
 24 |         Abstract method to generate data based on a prompt. Must be implemented by subclasses.
 25 | 
 26 |     """
 27 |     def __init__(self, model):
 28 |         """
 29 |         Initialize the generator.
 30 | 
 31 |         Parameters
 32 |         ----------
 33 |         model : object
 34 |             The model interface used for generating data.
 35 |         """
 36 | 
 37 |         self.model = model
 38 |         self.postprocessor = PostProcessor()
 39 | 
 40 |     @abstractmethod
 41 |     def generate(self, prompt: str) -> str:
 42 |         """
 43 |         Generate data based on the given prompt or configuration.
 44 | 
 45 |         This abstract method must be implemented by subclasses, which define
 46 |         how the underlying model is invoked and how its raw output is
 47 |         post-processed into the desired format.
 48 | 
 49 |         Parameters
 50 |         ----------
 51 |         prompt : str
 52 |             The prompt or task configuration that drives data generation.
 53 | 
 54 |         Returns
 55 |         -------
 56 |         str
 57 |             The generated data.
 58 | 
 59 |         """
 60 |         pass
 61 | 
 62 |     def _convert_original_to_alpaca(self, system_prompt, single_data):
 63 |         """
 64 |         Convert original data into Alpaca format.
 65 | 
 66 |         This method uses the PostProcessor to process conversation data and structure it
 67 |         in a format suitable for Alpaca-based models.
 68 | 
 69 |         Parameters
 70 |         ----------
 71 |         system_prompt : str
 72 |             The system-level prompt for context in the Alpaca format.
 73 |         single_data : str
 74 |             The raw conversation data to be processed.
 75 | 
 76 |         Returns
 77 |         -------
 78 |         dict
 79 |             The conversation data converted to Alpaca format.
 80 |         """
 81 | 
 82 |         converted_data = self.postprocessor.dialogue_postprocessing(conversation_data=single_data, system_prompt=system_prompt)
 83 | 
 84 |         return converted_data
 85 | 
 86 |     def _convert_original_to_json(self, single_data):
 87 |         """
 88 |         Convert original data into JSON format.
 89 | 
 90 |         This method uses the PostProcessor to process raw data into a JSON-compatible structure.
 91 | 
 92 |         Parameters
 93 |         ----------
 94 |         single_data : str
 95 |             The raw question data to be processed.
 96 | 
 97 |         Returns
 98 |         -------
 99 |         dict
 100 |             The data converted into JSON format.
 101 |         """
 102 | 
 103 |         converted_data = self.postprocessor.question_postprocessing(question_data=single_data)
 104 | 
 105 |         return converted_data
 106 | 
 107 |     def _convert_original_to_alpaca_answer(self, system_prompt, question, single_data):
 108 |         """
 109 |         Convert original data into Alpaca answer format.
 110 | 
 111 |         This method uses the PostProcessor to process raw data into an answer format suitable for Alpaca-based models.
 112 | 
 113 |         Parameters
 114 |         ----------
 115 |         system_prompt : str
 116 |             The system-level prompt for context in the Alpaca format.
 117 |         question : str
 118 |             The question text for which the answer is generated.
 119 |         single_data : str
 120 |             The raw answer data to be processed.
 121 | 
 122 |         Returns
 123 |         -------
 124 |         dict
 125 |             The data converted into Alpaca format.
 126 |         """
 127 | 
 128 |         converted_data = self.postprocessor.answer_postprocessing(question=question, answer=single_data, system_prompt=system_prompt)
 129 | 
 130 |         return converted_data
 131 | 
--------------------------------------------------------------------------------
/edg4llm/generators/text_generators/dialogue_generator.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from typing import Dict, List, Any
 3 | 
 4 | from edg4llm.utils.logger import custom_logger
 5 | from edg4llm.generators.text_generators.base_generator import BaseGenerator
 6 | 
 7 | logger = custom_logger("DialogueGenerator")
 8 | 
 9 | class DialogueGenerator(BaseGenerator):
 10 |     """
 11 |     Dialogue Generator class for generating dialogues using a specified model.
12 | 13 | This class extends the `BaseGenerator` and utilizes the given model to generate dialogues 14 | based on user input and system prompts. It provides flexibility to control generation parameters 15 | like sampling strategies, temperature, and output format. 16 | 17 | Parameters 18 | ---------- 19 | model : object 20 | The model interface used for generating dialogues. This model must have the 21 | `execute_request` method for generating dialogue based on the given parameters. 22 | """ 23 | 24 | def __init__(self, model): 25 | """ 26 | Initialize the Dialogue Generator. 27 | 28 | This constructor initializes the `DialogueGenerator` by calling the base class constructor 29 | with the provided model. It sets up the necessary components for generating dialogues. 30 | 31 | Parameters 32 | ---------- 33 | model : object 34 | The model interface to be used for generating dialogues. It should provide 35 | the `execute_request` method to generate data based on the parameters. 36 | 37 | Notes 38 | ----- 39 | The `model` should be capable of handling inputs like system prompts, user prompts, 40 | and additional parameters for controlling the text generation process. 41 | """ 42 | super().__init__(model) 43 | 44 | def generate(self, tConfig) -> List: 45 | """ 46 | Generate dialogues based on the provided configuration. 47 | 48 | This method generates one or more dialogues based on the parameters provided in 49 | the `tConfig` dictionary. The method interacts with the model's `execute_request` 50 | function to generate dialogue based on the system and user prompts. It also supports 51 | various options for controlling randomness, output length, and sampling strategy. 52 | 53 | Parameters 54 | ---------- 55 | tConfig : dict 56 | A configuration dictionary containing the following key-value pairs: 57 | - "system_prompt" : str, optional 58 | A system-level prompt that guides the dialogue generation. Default is an empty string. 59 | - "user_prompt" : str, optional 60 | A user-provided prompt to initiate the dialogue generation. Default is an empty string. 61 | - "model" : str, optional 62 | The specific model to use for generation. Default is "glm-4-flash". 63 | - "do_sample" : bool, optional 64 | Whether to use sampling strategies during text generation. Default is True. 65 | - "temperature" : float, optional 66 | A sampling parameter to control the randomness of output. Must be between 0.0 and 1.0. Default is 0.95. 67 | - "top_p" : float, optional 68 | Nucleus sampling parameter controlling the cumulative probability range for token selection. 69 | Must be between 0.0 and 1.0. Default is 0.7. 70 | - "max_tokens" : int, optional 71 | The maximum number of tokens to generate. Default is 4095. 72 | - "num_samples" : int, optional 73 | The number of dialogue samples to generate. Default is 1. 74 | 75 | Returns 76 | ------- 77 | list of dict 78 | A list of dictionaries containing the generated dialogues. Each dictionary 79 | includes the generated dialogue content. 80 | 81 | Notes 82 | ----- 83 | - The method will attempt to generate dialogues until a valid response is generated. 84 | If the generated dialogue is `None`, it will retry. 85 | - Progress is logged for each sample generated. 
86 | """ 87 | 88 | # Extract configuration parameters 89 | system_prompt = tConfig.get("system_prompt", "") 90 | user_prompt = tConfig.get("user_prompt", "") 91 | do_sample = tConfig.get("do_sample", True) 92 | temperature = tConfig.get("temperature", 0.95) 93 | top_p = tConfig.get("top_p", 0.7) 94 | max_tokens = tConfig.get("max_tokens", 4095) 95 | num_samples = tConfig.get("num_samples", 1) # Default is to generate 1 sample 96 | 97 | # List to store the generated dialogues 98 | dialogues = [] 99 | 100 | # Generate dialogues for the specified number of samples 101 | total_samples = num_samples # Total number of samples to generate 102 | logger.info("Starting the data generation process.") 103 | for _idx in range(1, num_samples + 1): 104 | retry_count = 0 # 初始化重试计数 105 | max_retries = 5 # 设置最大重试次数(根据需要调整) 106 | 107 | while True: # Keep trying until valid dialogue data is generated 108 | retry_count += 1 109 | 110 | generated_dialogue = self.model.execute_request( 111 | system_prompt=system_prompt, 112 | user_prompt=user_prompt, 113 | do_sample=do_sample, 114 | temperature=temperature, 115 | top_p=top_p, 116 | max_tokens=max_tokens, 117 | ) 118 | 119 | if "error" in generated_dialogue: 120 | logger.warning( 121 | "Sample %d: Request failed with error: %s. Retrying (%d/%d)...", 122 | _idx, 123 | generated_dialogue["error"], 124 | retry_count, 125 | max_retries, 126 | ) 127 | 128 | if retry_count >= max_retries: 129 | logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx) 130 | break # 跳出当前样本,进入下一个 131 | 132 | continue # 继续当前样本的生成 133 | 134 | 135 | # Convert the generated dialogue to the desired format (e.g., Alpaca format) 136 | converted_generated_dialogue = self._convert_original_to_alpaca(system_prompt, generated_dialogue) 137 | 138 | if converted_generated_dialogue is not None: 139 | # If the dialogue is valid, append it to the results and break the loop 140 | dialogues.append(converted_generated_dialogue) 141 | break 142 | else: 143 | logger.warning( 144 | "Sample %d: Generated dialogue is None. Retrying (%d/%d)...", 145 | _idx, 146 | retry_count, 147 | max_retries, 148 | ) 149 | 150 | if retry_count >= max_retries: 151 | logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx) 152 | break # 跳出当前样本 153 | 154 | 155 | # Log the progress of dialogue generation 156 | progress = (_idx / total_samples) * 100 157 | logger.info("Data generation progress: %.2f%% (%d/%d samples completed)", progress, _idx, total_samples) 158 | 159 | return dialogues 160 | -------------------------------------------------------------------------------- /edg4llm/generators/text_generators/question_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Any 3 | from edg4llm.utils.logger import custom_logger 4 | from edg4llm.generators.text_generators.base_generator import BaseGenerator 5 | 6 | logger = custom_logger("QuestionGenerator") 7 | 8 | class QuestionGenerator(BaseGenerator): 9 | """ 10 | A class for generating questions based on user prompts and configuration. 11 | 12 | This class extends the `BaseGenerator` class and provides functionality to generate 13 | questions using a specified model. It interacts with the model's `execute_request` 14 | method to create output based on user-defined parameters such as sampling strategies, 15 | temperature, and maximum tokens. 16 | 17 | Attributes 18 | ---------- 19 | model : object 20 | The model interface used for generating questions. 
21 | 22 | Methods 23 | ------- 24 | generate(tConfig: dict) -> list of dict: 25 | Generates questions based on the provided configuration. 26 | 27 | Notes 28 | ----- 29 | - The `generate` method ensures valid responses are returned, retrying if necessary. 30 | - Logs progress for each generated question. 31 | """ 32 | 33 | def __init__(self, model): 34 | """ 35 | Initialize the QuestionGenerator. 36 | 37 | Parameters 38 | ---------- 39 | model : object 40 | The model interface used for generating questions. 41 | """ 42 | 43 | super().__init__(model) 44 | 45 | def generate(self, tConfig: Dict) -> List: 46 | """ 47 | Generate questions based on the provided configuration. 48 | 49 | This method generates one or more questions using the parameters specified 50 | in the `tConfig` dictionary. It interacts with the model's `execute_request` 51 | method to generate output based on user prompts and various sampling options. 52 | 53 | Parameters 54 | ---------- 55 | tConfig : dict 56 | A dictionary containing configuration options for question generation: 57 | - "system_prompt" : str, optional 58 | A system-level instruction to guide the question generation. Default is an empty string. 59 | - "user_prompt" : str, optional 60 | A user-provided input to guide the question generation. Default is an empty string. 61 | - "model" : str, optional 62 | Specifies the model for text generation. Default is "glm-4-flash". 63 | - "do_sample" : bool, optional 64 | Whether to use sampling during generation. Default is True. 65 | - "temperature" : float, optional 66 | Controls randomness in output. Value should be between 0.0 and 1.0. Default is 0.95. 67 | - "top_p" : float, optional 68 | Nucleus sampling parameter to limit token selection to a cumulative probability. Default is 0.7. 69 | - "max_tokens" : int, optional 70 | The maximum number of tokens for the output. Default is 4095. 71 | - "num_samples" : int, optional 72 | The number of question samples to generate. Default is 1. 73 | 74 | Returns 75 | ------- 76 | list of dict 77 | A list of dictionaries containing the generated questions. 78 | 79 | Notes 80 | ----- 81 | - The method retries generation until a valid response is obtained. 82 | - Logs progress for each generated sample. 83 | """ 84 | 85 | # Extract parameters from the configuration 86 | system_prompt = tConfig.get("system_prompt", "") 87 | user_prompt = tConfig.get("user_prompt", "") 88 | do_sample = tConfig.get("do_sample", True) 89 | temperature = tConfig.get("temperature", 0.95) 90 | top_p = tConfig.get("top_p", 0.7) 91 | max_tokens = tConfig.get("max_tokens", 4095) 92 | num_samples = tConfig.get("num_samples", 1) 93 | 94 | # Initialize a list to store generated questions 95 | questions = [] 96 | cur_len = 0 97 | # Generate questions for the specified number of samples 98 | logger.info("Starting the data generation process.") 99 | for _idx in range(1, num_samples + 1): 100 | retry_count = 0 # 初始化重试计数 101 | max_retries = 5 # 设置最大重试次数(根据需要调整) 102 | 103 | while True: # Retry until a valid question is generated 104 | retry_count += 1 105 | 106 | generated_question = self.model.execute_request( 107 | system_prompt=system_prompt, 108 | user_prompt=user_prompt, 109 | do_sample=do_sample, 110 | temperature=temperature, 111 | top_p=top_p, 112 | max_tokens=max_tokens, 113 | ) 114 | 115 | if "error" in generated_question: 116 | logger.warning( 117 | "Sample %d: Request failed with error: %s. 
Retrying (%d/%d)...",
 118 |                         _idx,
 119 |                         generated_question["error"],
 120 |                         retry_count,
 121 |                         max_retries,
 122 |                     )
 123 | 
 124 |                     if (retry_count >= max_retries):
 125 |                         logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx)
 126 |                         break  # Give up on this sample
 127 |                     continue  # Retry generation for the current sample
 128 |                 # Convert the raw output to a specific format
 129 |                 converted_question = self._convert_original_to_json(generated_question)
 130 | 
 131 |                 if converted_question is not None:
 132 |                     cur_len = len(converted_question)
 133 |                     questions.extend(converted_question)
 134 |                     break
 135 |                 else:
 136 |                     logger.warning(
 137 |                         "Sample %d: Generated question is None. Retrying (%d/%d)...",
 138 |                         _idx,
 139 |                         retry_count,
 140 |                         max_retries,
 141 |                     )
 142 | 
 143 |                     if retry_count >= max_retries:
 144 |                         logger.error("Sample %d: Max retries reached. Skipping this sample.", _idx)
 145 |                         break  # Give up on this sample
 146 | 
 147 |             # Log progress for tracking generation completion
 148 |             progress = (_idx / num_samples) * 100
 149 |             logger.info("Generation progress: %.2f%% (%d questions generated, %d/%d epochs completed)", progress, cur_len, _idx, num_samples)
 150 | 
 151 |         return questions
 152 | 
--------------------------------------------------------------------------------
/edg4llm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/edg4llm/models/__init__.py
--------------------------------------------------------------------------------
/edg4llm/models/baseModel.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Module for defining the base class of EDG models.
 3 | 
 4 | This file contains the abstract base class `EDGBaseModel`, which serves as a foundation for implementing various
 5 | machine learning models. The class defines key methods that must be implemented by any derived model class
 6 | to handle requests, send HTTP requests, and interact with APIs.
 7 | 
 8 | Classes
 9 | -------
 10 | EDGBaseModel(ABC)
 11 |     Abstract base class for EDG models, providing a standard structure for derived model implementations.
 12 | 
 13 | Methods
 14 | -------
 15 | __init__(api_key: str = None, base_url: str = None, model_name: str = None)
 16 |     Initializes the base model with API key, base URL, and model name.
 17 | 
 18 | execute_request(system_prompt: str, user_prompt: str, **kwargs) -> str
 19 |     Abstract method to process user input and generate model responses.
 20 |     Must be implemented by derived classes.
 21 | 
 22 | send_request(request: Dict[str, Any]) -> Dict[str, Any]
 23 |     Abstract method to send HTTP requests and handle server interactions.
 24 |     Must be implemented by derived classes.
 25 | """
 26 | 
 27 | import requests
 28 | from abc import ABC, abstractmethod
 29 | from typing import Any, Dict
 30 | 
 31 | from edg4llm.utils.logger import custom_logger
 32 | 
 33 | logger = custom_logger('baseModel')
 34 | 
 35 | 
 36 | class EDGBaseModel(ABC):
 37 |     """
 38 |     Abstract base class for EDG models.
 39 | 
 40 |     This class defines the blueprint for machine learning model implementations. Derived classes must
 41 |     implement methods to process user prompts, interact with APIs, and handle HTTP requests.
 42 | 
 43 |     Attributes
 44 |     ----------
 45 |     api_key : str
 46 |         The API key required for authenticating requests.
 47 | 
 48 |     base_url : str
 49 |         The base URL of the model API endpoint.
 50 | 
 51 |     model_name : str
 52 |         The name of the model, used to differentiate between various models.
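Concrete subclasses in this package (for example `EDGChatGLM` and `EDGChatGPT`) implement `execute_request` to build the provider-specific request payload and `send_request` to perform the actual HTTP call.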
53 | """ 54 | 55 | def __init__(self, api_key: str = None, base_url: str = None, model_name: str = None): 56 | """ 57 | Initializes the base model with API key, base URL, and model name. 58 | 59 | Parameters 60 | ---------- 61 | api_key : str, optional 62 | The API key for authenticating requests. Default is None. 63 | 64 | base_url : str, optional 65 | The base URL of the model API endpoint. Default is None. 66 | 67 | model_name : str, optional 68 | The name of the model, used for identifying different models. Default is None. 69 | """ 70 | self.api_key = api_key 71 | self.base_url = base_url 72 | self.model_name = model_name 73 | 74 | @abstractmethod 75 | def execute_request(self, system_prompt: str, user_prompt: str, **kwargs) -> str: 76 | """ 77 | Abstract method to process and execute a request. 78 | 79 | This method must be implemented by derived classes. It processes user input and generates 80 | responses based on a system prompt and additional parameters. 81 | 82 | Parameters 83 | ---------- 84 | system_prompt : str 85 | The system-level instruction or prompt defining the role or behavior of the model. 86 | 87 | user_prompt : str 88 | The user's input or query for the model. 89 | 90 | kwargs : dict 91 | Additional parameters for processing the request. 92 | 93 | Returns 94 | ------- 95 | str 96 | The response generated by the model. 97 | 98 | Notes 99 | ----- 100 | - Derived classes should implement this method to handle the specific logic for generating responses. 101 | """ 102 | pass 103 | 104 | @abstractmethod 105 | def send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 106 | """ 107 | Abstract method to send HTTP requests. 108 | 109 | This method must be implemented by derived classes to handle API interactions and perform 110 | error handling for HTTP requests. 111 | 112 | Parameters 113 | ---------- 114 | request : dict 115 | A dictionary containing all necessary information for the HTTP request. 116 | 117 | Returns 118 | ------- 119 | dict 120 | The server's response as a dictionary. 121 | 122 | Notes 123 | ----- 124 | - Derived classes should implement this method to handle API-specific logic and error handling. 125 | """ 126 | pass 127 | -------------------------------------------------------------------------------- /edg4llm/models/chatglm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast 4 | 5 | from edg4llm.utils.logger import custom_logger 6 | from edg4llm.models.baseModel import EDGBaseModel 7 | from edg4llm.utils.exceptions import HttpClientError, InvalidPromptError 8 | 9 | logger = custom_logger('chatglm') 10 | 11 | class EDGChatGLM(EDGBaseModel): 12 | """ 13 | EDGChatGLM interface for interacting with the ChatGLM model to generate text based on given prompts. 14 | 15 | This class provides an interface to interact with the ChatGLM model for generating text 16 | based on a system and user prompt. It supports customizable parameters such as temperature, 17 | sampling strategies, and model selection. It also handles HTTP requests and error management. 18 | 19 | Parameters 20 | ---------- 21 | base_url : str, optional 22 | The base URL for the ChatGLM API. If not provided, defaults to None. 23 | api_key : str, optional 24 | The API key for authenticating with the ChatGLM API. If not provided, defaults to None. 
25 | """ 26 | 27 | def __init__(self, base_url: str = None, api_key: str = None, model_name: str = 'glm-4-flash'): 28 | """ 29 | Initialize the ChatGLM model interface. 30 | 31 | This constructor initializes the `EDGChatGLM` class by calling the base class constructor 32 | and passing the API key, base URL, and model name ("ChatGLM"). It sets up the necessary 33 | configuration for interacting with the ChatGLM API. 34 | 35 | Parameters 36 | ---------- 37 | base_url : str, optional 38 | The base URL for the ChatGLM API. Default is None. 39 | api_key : str, optional 40 | The API key for authenticating with the ChatGLM API. Default is None. 41 | model_name: str, optional 42 | The specific model to use within the selected provider. Default is "glm-4-flash". 43 | Notes 44 | ----- 45 | The base URL and API key are required for successful communication with the ChatGLM API. 46 | """ 47 | super().__init__(api_key, base_url, model_name=model_name) 48 | 49 | def execute_request( 50 | self, 51 | system_prompt: str = None, 52 | user_prompt: str = None, 53 | do_sample: bool = True, 54 | temperature: float = 0.95, 55 | top_p: float = 0.7, 56 | max_tokens: int = 4095 57 | ) -> str: 58 | """ 59 | Generate text using the ChatGLM model based on the provided prompts and parameters. 60 | 61 | This method calls the internal request execution function and handles the text 62 | generation process using the specified system and user prompts. It allows controlling 63 | text generation via parameters such as temperature, sampling strategy, and token limits. 64 | 65 | Parameters 66 | ---------- 67 | system_prompt : str, optional 68 | The system-level prompt that sets the context for the conversation. Default is None. 69 | user_prompt : str, optional 70 | The user-provided prompt that initiates the conversation. Default is None. 71 | do_sample : bool, optional 72 | Whether to use sampling during text generation. Default is True. 73 | temperature : float, optional 74 | Sampling temperature to control randomness. Default is 0.95. 75 | top_p : float, optional 76 | Nucleus sampling parameter for controlling randomness. Default is 0.7. 77 | max_tokens : int, optional 78 | The maximum number of tokens to generate in the output. Default is 4095. 79 | 80 | Returns 81 | ------- 82 | str 83 | The generated text content from the model. 84 | 85 | Raises 86 | ------ 87 | InvalidPromptError 88 | If both the system and user prompts are None. 89 | """ 90 | response = self._execute_request(system_prompt, user_prompt, self.model_name, do_sample, temperature, top_p, max_tokens) 91 | return response 92 | 93 | def send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 94 | """ 95 | Send an HTTP request to the ChatGLM API. 96 | 97 | This method sends a POST request to the ChatGLM API with the provided request data. 98 | It returns the response data as a dictionary. 99 | 100 | Parameters 101 | ---------- 102 | request : dict 103 | A dictionary containing the request data, including the URL, headers, and JSON body. 104 | 105 | Returns 106 | ------- 107 | dict 108 | The response from the API in the form of a dictionary. 109 | 110 | Raises 111 | ------ 112 | HttpClientError 113 | If any error occurs during the HTTP request process. 114 | """ 115 | response = self._send_request(request=request) 116 | return response 117 | 118 | def _send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 119 | """ 120 | Internal method to send a POST request to the ChatGLM API. 
121 | 122 | This method handles the actual HTTP POST request to the ChatGLM API. It includes 123 | error handling for HTTP errors, connection issues, timeouts, and JSON decoding. 124 | 125 | Parameters 126 | ---------- 127 | request : dict 128 | A dictionary containing the request data, including the URL, headers, and JSON body. 129 | 130 | Returns 131 | ------- 132 | dict 133 | The JSON response from the API. 134 | 135 | Raises 136 | ------ 137 | HttpClientError 138 | If an error occurs during the request. 139 | """ 140 | url = request.get("url", "https://open.bigmodel.cn/api/paas/v4/chat/completions") 141 | headers = {**request.get("headers", {})} 142 | json = request.get("json", {}) 143 | try: 144 | response = requests.post( 145 | url=url, 146 | headers=headers, 147 | json=json, 148 | timeout=30, 149 | ) 150 | response.raise_for_status() 151 | return response.json()["choices"][0]["message"]["content"].strip() 152 | 153 | except requests.exceptions.HTTPError as e: 154 | # Handle HTTP error exceptions 155 | status_code = e.response.status_code 156 | logger.error( 157 | "HTTP error occurred. Status Code: %s, URL: %s, Message: %s", 158 | status_code, 159 | url, 160 | e, 161 | ) 162 | 163 | return {"error": "HTTP error", "status_code": status_code, "message": str(e)} 164 | 165 | 166 | except requests.exceptions.ConnectionError as e: 167 | # Handle connection errors 168 | logger.error("Connection error occurred while connecting to %s: %s", url, e) 169 | 170 | return {"error": "Connection error", "message": str(e)} 171 | 172 | except requests.exceptions.Timeout as e: 173 | # Handle timeout errors 174 | logger.error("Timeout occurred while sending request to %s: %s", url, e) 175 | 176 | return {"error": "Timeout", "message": str(e)} 177 | 178 | 179 | except requests.exceptions.RequestException as e: 180 | # Handle any generic request exceptions 181 | logger.error( 182 | "Request exception occurred while sending request to %s: %s", url, e 183 | ) 184 | 185 | return {"error": "Request exception", "message": str(e)} 186 | 187 | 188 | except ValueError as e: 189 | # Handle JSON decoding errors 190 | logger.error("JSON decoding error occurred: %s", e) 191 | 192 | return {"error": "JSON decoding error", "message": str(e)} 193 | 194 | except Exception as e: 195 | # Catch any unexpected errors 196 | logger.critical( 197 | "An unexpected error occurred while sending request to %s: %s", url, e 198 | ) 199 | 200 | return {"error": "Unexpected error", "message": str(e)} 201 | 202 | def _execute_request( 203 | self, 204 | system_prompt: str = None, 205 | user_prompt: str = None, 206 | model: str = "glm-4-flash", 207 | do_sample: bool = True, 208 | temperature: float = 0.95, 209 | top_p: float = 0.7, 210 | max_tokens: int = 4095 211 | ) -> str: 212 | """ 213 | Internal method to prepare the request data and execute the request for text generation. 214 | 215 | This method prepares the necessary data (including headers, JSON body) for the 216 | ChatGLM API request and then calls the `send_request` method to send the request 217 | and return the response. 218 | 219 | Parameters 220 | ---------- 221 | system_prompt : str, optional 222 | The system-level prompt that provides context for the dialogue generation. 223 | Default is None. 224 | user_prompt : str, optional 225 | The user-provided prompt that initiates the generation. 226 | Default is None. 227 | model : str, optional 228 | The model to use for the generation. Default is "glm-4-flash". 
229 | do_sample : bool, optional 230 | Whether to use sampling during text generation. Default is True. 231 | temperature : float, optional 232 | Sampling temperature to control randomness. Default is 0.95. 233 | top_p : float, optional 234 | Nucleus sampling parameter for controlling randomness. Default is 0.7. 235 | max_tokens : int, optional 236 | The maximum number of tokens to generate. Default is 4095. 237 | 238 | Returns 239 | ------- 240 | str 241 | The generated text content from the model. 242 | 243 | Raises 244 | ------ 245 | InvalidPromptError 246 | If both the system and user prompts are None. 247 | """ 248 | if (system_prompt is None and user_prompt is None): 249 | logger.error("Both prompts cannot be empty") 250 | raise InvalidPromptError("Both prompts cannot be empty") 251 | 252 | request_data = { 253 | "url": f"{self.base_url}", 254 | "headers": { 255 | "Authorization": f"Bearer {self.api_key}", 256 | "Content-Type": "application/json", 257 | }, 258 | "json": { 259 | "model": model, 260 | "messages": [ 261 | {"role": "system", "content": system_prompt}, 262 | {"role": "user", "content": user_prompt}, 263 | ], 264 | "do_sample": do_sample, 265 | "temperature": temperature, 266 | "top_p": top_p, 267 | "max_tokens": max_tokens, 268 | }, 269 | } 270 | 271 | response = self.send_request(request_data) 272 | 273 | return response 274 | -------------------------------------------------------------------------------- /edg4llm/models/chatgpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast 4 | 5 | from edg4llm.utils.logger import custom_logger 6 | from edg4llm.models.baseModel import EDGBaseModel 7 | from edg4llm.utils.exceptions import HttpClientError, InvalidPromptError 8 | 9 | logger = custom_logger('chatgpt') 10 | 11 | class EDGChatGPT(EDGBaseModel): 12 | """ 13 | A class to interface with the ChatGPT model for text generation. 14 | 15 | This class extends the `EDGBaseModel` abstract base class to implement a specific interface 16 | for interacting with the ChatGPT API. It supports text generation using system-level and 17 | user-level prompts with customizable parameters such as temperature, sampling strategies, 18 | and token limits. The class also includes methods to handle HTTP requests and manage errors. 19 | 20 | Attributes 21 | ---------- 22 | base_url : str 23 | The base URL for the ChatGPT API endpoint. 24 | api_key : str 25 | The API key for authenticating with the ChatGPT API. 26 | model_name : str 27 | The specific model to use, defaulting to "gpt-4o-mini". 28 | 29 | Methods 30 | ------- 31 | execute_request(system_prompt: str, user_prompt: str, do_sample: bool, temperature: float, top_p: float, max_tokens: int) -> str: 32 | Generates text using the ChatGPT model based on the provided prompts and parameters. 33 | 34 | send_request(request: Dict[str, Any]) -> Dict[str, Any]: 35 | Sends an HTTP POST request to the ChatGPT API and returns the response as a dictionary. 36 | 37 | Notes 38 | ----- 39 | - The `base_url` and `api_key` are required for proper communication with the ChatGPT API. 40 | - Provides detailed error handling for HTTP, connection, timeout, and JSON decoding issues. 41 | - Supports customizable text generation parameters for flexibility in model behavior. 
42 | """ 43 | 44 | def __init__(self, base_url:str = None, api_key: str = None, model_name: str = "gpt-4o-mini"): 45 | """ 46 | Initialize the ChatGPT model interface. 47 | 48 | Parameters 49 | ---------- 50 | base_url : str, optional 51 | The base URL for the ChatGPT API. Default is None. 52 | api_key : str, optional 53 | The API key for authenticating with the ChatGPT API. Default is None. 54 | model_name : str, optional 55 | The specific model to use, defaulting to "gpt-4o-mini". 56 | """ 57 | 58 | super().__init__(api_key, base_url, model_name=model_name) 59 | 60 | def execute_request( 61 | self 62 | , system_prompt: str = None 63 | , user_prompt: str = None 64 | , do_sample: bool = True 65 | , temperature: float = 0.95 66 | , top_p: float = 0.7 67 | , max_tokens: int = 4095 68 | ) -> str: 69 | 70 | """ 71 | Generate text using the ChatGPT model based on the provided prompts and parameters. 72 | 73 | Parameters 74 | ---------- 75 | system_prompt : str, optional 76 | The system-level prompt providing context for the text generation. Default is None. 77 | user_prompt : str, optional 78 | The user-provided prompt initiating the text generation. Default is None. 79 | do_sample : bool, optional 80 | Whether to use sampling during text generation. Default is True. 81 | temperature : float, optional 82 | Sampling temperature to control randomness. Default is 0.95. 83 | top_p : float, optional 84 | Nucleus sampling parameter to control randomness. Default is 0.7. 85 | max_tokens : int, optional 86 | The maximum number of tokens to generate. Default is 4095. 87 | 88 | Returns 89 | ------- 90 | str 91 | The generated text content from the model. 92 | 93 | Raises 94 | ------ 95 | InvalidPromptError 96 | If both system and user prompts are None. 97 | """ 98 | 99 | response = self._execute_request(system_prompt, user_prompt, self.model_name, do_sample, temperature, top_p, max_tokens) 100 | return response 101 | 102 | def send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 103 | 104 | """ 105 | Send an HTTP request to the ChatGPT API. 106 | 107 | Parameters 108 | ---------- 109 | request : dict 110 | A dictionary containing the request data, including the URL, headers, and JSON body. 111 | 112 | Returns 113 | ------- 114 | dict 115 | The response from the API in the form of a dictionary. 116 | 117 | Raises 118 | ------ 119 | HttpClientError 120 | If any error occurs during the HTTP request process. 121 | """ 122 | 123 | response = self._send_request(request=request) 124 | return response 125 | 126 | def _send_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 127 | 128 | """ 129 | Internal method to send an HTTP POST request to the ChatGPT API. 130 | 131 | This method handles the actual HTTP POST request and manages error handling 132 | for issues like connection failures, timeouts, and JSON decoding errors. 133 | 134 | Parameters 135 | ---------- 136 | request : dict 137 | A dictionary containing the request data, including the URL, headers, and JSON body. 138 | 139 | Returns 140 | ------- 141 | dict 142 | The JSON response from the API. 143 | 144 | Raises 145 | ------ 146 | HttpClientError 147 | If an error occurs during the HTTP request. 
148 | """ 149 | 150 | url = request.get("url", "https://api.openai.com/v1/chat/completions") 151 | headers = {**request.get("headers", {})} 152 | json = request.get("json", {}) 153 | try: 154 | response = requests.post( 155 | url=url, 156 | headers=headers, 157 | json=json, 158 | timeout=30, 159 | ) 160 | 161 | response.raise_for_status() 162 | 163 | return response.json()["choices"][0]["message"]["content"].strip() 164 | 165 | except requests.exceptions.HTTPError as e: 166 | # Handle HTTP error exceptions 167 | status_code = e.response.status_code 168 | logger.error( 169 | "HTTP error occurred. Status Code: %s, URL: %s, Message: %s", 170 | status_code, 171 | url, 172 | e, 173 | ) 174 | 175 | return {"error": "HTTP error", "status_code": status_code, "message": str(e)} 176 | 177 | 178 | except requests.exceptions.ConnectionError as e: 179 | # Handle connection errors 180 | logger.error("Connection error occurred while connecting to %s: %s", url, e) 181 | 182 | return {"error": "Connection error", "message": str(e)} 183 | 184 | except requests.exceptions.Timeout as e: 185 | # Handle timeout errors 186 | logger.error("Timeout occurred while sending request to %s: %s", url, e) 187 | 188 | return {"error": "Timeout", "message": str(e)} 189 | 190 | 191 | except requests.exceptions.RequestException as e: 192 | # Handle any generic request exceptions 193 | logger.error( 194 | "Request exception occurred while sending request to %s: %s", url, e 195 | ) 196 | 197 | return {"error": "Request exception", "message": str(e)} 198 | 199 | 200 | except ValueError as e: 201 | # Handle JSON decoding errors 202 | logger.error("JSON decoding error occurred: %s", e) 203 | 204 | return {"error": "JSON decoding error", "message": str(e)} 205 | 206 | except Exception as e: 207 | # Catch any unexpected errors 208 | logger.critical( 209 | "An unexpected error occurred while sending request to %s: %s", url, e 210 | ) 211 | 212 | return {"error": "Unexpected error", "message": str(e)} 213 | 214 | 215 | def _execute_request( 216 | self 217 | , system_prompt: str = None 218 | , user_prompt: str = None 219 | , model: str = "gpt-4o-mini" 220 | , do_sample: bool = True 221 | , temperature: float = 0.95 222 | , top_p: float = 0.7 223 | , max_tokens: int = 4095 224 | ) -> str: 225 | 226 | """ 227 | Internal method to prepare and execute the API request for text generation. 228 | 229 | Parameters 230 | ---------- 231 | system_prompt : str, optional 232 | The system-level prompt providing context for the text generation. Default is None. 233 | user_prompt : str, optional 234 | The user-provided prompt initiating the text generation. Default is None. 235 | model : str, optional 236 | The specific model to use for text generation. Default is "gpt-4o-mini". 237 | do_sample : bool, optional 238 | Whether to use sampling during text generation. Default is True. 239 | temperature : float, optional 240 | Sampling temperature to control randomness. Default is 0.95. 241 | top_p : float, optional 242 | Nucleus sampling parameter to control randomness. Default is 0.7. 243 | max_tokens : int, optional 244 | The maximum number of tokens to generate. Default is 4095. 245 | 246 | Returns 247 | ------- 248 | str 249 | The generated text content from the model. 250 | 251 | Raises 252 | ------ 253 | InvalidPromptError 254 | If both system and user prompts are None. 
255 | """ 256 | 257 | if (system_prompt is None and user_prompt is None): 258 | logger.error("prompt不能同时为空") 259 | raise InvalidPromptError("prompt不能同时为空") 260 | 261 | request_data = { 262 | "url": f"{self.base_url}", 263 | "headers": { 264 | "Authorization": f"Bearer {self.api_key}", 265 | "Content-Type": "application/json", 266 | }, 267 | "json": { 268 | "model": model, 269 | "messages": [ 270 | { 271 | "role": "developer", 272 | "content": system_prompt, 273 | }, 274 | { 275 | "role": "user", 276 | "content": user_prompt, 277 | } 278 | ], 279 | "temperature": temperature, 280 | "top_p": top_p, 281 | "max_tokens": max_tokens 282 | }, 283 | } 284 | 285 | response = self.send_request(request_data) 286 | return response 287 | -------------------------------------------------------------------------------- /edg4llm/processor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/edg4llm/processor/__init__.py -------------------------------------------------------------------------------- /edg4llm/processor/postprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List, Any 3 | 4 | from edg4llm.utils.logger import custom_logger 5 | 6 | logger = custom_logger("PostProcessor") 7 | 8 | class PostProcessor: 9 | """ 10 | A class for post-processing conversation and question data. 11 | 12 | This class provides methods to clean and structure raw data obtained from API responses or external sources. 13 | It handles the removal of unnecessary markdown formatting, parses the data into valid JSON format, and 14 | structures it for further use in applications such as chatbots or AI assistants. It can also incorporate 15 | an optional system prompt into the processed data for context. 16 | 17 | Methods 18 | ------- 19 | dialogue_postprocessing(conversation_data: Dict[str, str], system_prompt: str = None): 20 | Processes raw conversation data by cleaning, parsing, and adding an optional system prompt. 21 | 22 | question_postprocessing(question_data: str = None): 23 | Processes raw question data by cleaning and structuring it into a list of questions. 24 | 25 | answer_postprocessing(question: str, answer: str, system_prompt: str = None): 26 | Processes raw answer data by cleaning, parsing, and structuring it along with the question 27 | and an optional system prompt. 28 | """ 29 | 30 | def __init__(self): 31 | pass 32 | 33 | def dialogue_postprocessing(self, conversation_data: Dict[str, str], system_prompt: str = None): 34 | """ 35 | Post-process conversation data. 36 | 37 | This function processes raw conversation data by removing unnecessary formatting and parsing it 38 | into a valid JSON format. If a system-level prompt (system_prompt) is provided, it will be added 39 | as an "instruction" field to the first conversation entry. The processed data is returned as a 40 | dictionary with a "conversation" key. 41 | 42 | Parameters 43 | ---------- 44 | conversation_data : str 45 | The raw conversation data in string format, typically from an API response or an external source. 46 | It may contain markdown-style formatting such as "```json" or "```" that needs to be removed. 47 | 48 | system_prompt : str, optional 49 | An optional system-level prompt that will be added to the "instruction" field of the first 50 | conversation entry. If not provided, an empty string will be used. Default is None. 
--------------------------------------------------------------------------------
/edg4llm/processor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/edg4llm/processor/__init__.py
--------------------------------------------------------------------------------
/edg4llm/processor/postprocess.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | from typing import Dict, List, Any
 3 | 
 4 | from edg4llm.utils.logger import custom_logger
 5 | 
 6 | logger = custom_logger("PostProcessor")
 7 | 
 8 | class PostProcessor:
 9 |     """
 10 |     A class for post-processing conversation and question data.
 11 | 
 12 |     This class provides methods to clean and structure raw data obtained from API responses or external sources.
 13 |     It handles the removal of unnecessary markdown formatting, parses the data into valid JSON format, and
 14 |     structures it for further use in applications such as chatbots or AI assistants. It can also incorporate
 15 |     an optional system prompt into the processed data for context.
 16 | 
 17 |     Methods
 18 |     -------
 19 |     dialogue_postprocessing(conversation_data: str, system_prompt: str = None):
 20 |         Processes raw conversation data by cleaning, parsing, and adding an optional system prompt.
 21 | 
 22 |     question_postprocessing(question_data: str = None):
 23 |         Processes raw question data by cleaning and structuring it into a list of questions.
 24 | 
 25 |     answer_postprocessing(question: str, answer: str, system_prompt: str = None):
 26 |         Processes raw answer data by cleaning, parsing, and structuring it along with the question
 27 |         and an optional system prompt.
 28 |     """
 29 | 
 30 |     def __init__(self):
 31 |         pass
 32 | 
 33 |     def dialogue_postprocessing(self, conversation_data: str, system_prompt: str = None):
 34 |         """
 35 |         Post-process conversation data.
 36 | 
 37 |         This function processes raw conversation data by removing unnecessary formatting and parsing it
 38 |         into a valid JSON format. If a system-level prompt (system_prompt) is provided, it will be added
 39 |         as an "instruction" field to the first conversation entry. The processed data is returned as a
 40 |         dictionary with a "conversation" key.
 41 | 
 42 |         Parameters
 43 |         ----------
 44 |         conversation_data : str
 45 |             The raw conversation data in string format, typically from an API response or an external source.
 46 |             It may contain markdown-style formatting such as "```json" or "```" that needs to be removed.
 47 | 
 48 |         system_prompt : str, optional
 49 |             An optional system-level prompt that will be added to the "instruction" field of the first
 50 |             conversation entry. If not provided, an empty string will be used. Default is None.
 51 | 
 52 |         Returns
 53 |         -------
 54 |         dict or None
 55 |             Returns a dictionary containing the processed conversation data structured under the "conversation" key.
 56 |             Each item in the list corresponds to a conversation entry. If an error occurs during JSON parsing,
 57 |             the function logs the error and returns None.
 58 | 
 59 |         Examples
 60 |         --------
 61 |         >>> conversation_data = '''
 62 |         [
 63 |             {"input": "AAA", "output": "BBBB"},
 64 |             {"input": "CCC", "output": "DDDD"}
 65 |         ]
 66 |         '''
 67 |         >>> system_prompt = "You are a helpful assistant."
 68 |         >>> processed_data = dialogue_postprocessing(conversation_data, system_prompt)
 69 | 
 70 |         >>> # Output:
 71 |         >>> {
 72 |             "conversation": [
 73 |                 {"input": "AAA", "output": "BBBB", "instruction": "You are a helpful assistant."},
 74 |                 {"input": "CCC", "output": "DDDD"}
 75 |             ]
 76 |         }
 77 | 
 78 |         Notes
 79 |         -----
 80 |         - The function removes any markdown formatting (like "```json" or "```") before parsing the data.
 81 |         - If JSON parsing fails, an error is logged, and the function returns None.
 82 |         """
 83 |         try:
 84 |             # Clean and parse the JSON conversation data
 85 |             conversation_data = json.loads(conversation_data.replace("```json", "").replace("```", ""))
 86 |         except Exception as exception:
 87 |             logger.error("Error parsing JSON: %s", str(exception))
 88 |             return None
 89 | 
 90 |         # Initialize the result dictionary with a "conversation" key
 91 |         result = {"conversation": []}
 92 | 
 93 |         # Add the system prompt as an instruction to the first conversation entry if provided
 94 |         for idx, data in enumerate(conversation_data):
 95 |             if idx == 0:
 96 |                 data["instruction"] = system_prompt if system_prompt is not None else ""
 97 |             result["conversation"].append(data)
 98 | 
 99 |         return result
 100 | 
 101 | 
 102 |     def question_postprocessing(self, question_data: str = None):
 103 |         """
 104 |         Post-process the question data.
 105 | 
 106 |         This function processes raw question data by removing unnecessary formatting and ensuring
 107 |         it is in a valid JSON format. It converts each question into a structured dictionary with
 108 |         the key "question" holding the processed content.
 109 | 
 110 |         Parameters
 111 |         ----------
 112 |         question_data : str
 113 |             The raw question data in string format, typically from an API response or external source.
 114 |             The string may contain markdown-style formatting such as "```json" or "```" that should be removed.
 115 | 
 116 |         Returns
 117 |         -------
 118 |         list or None
 119 |             Returns a list of question dictionaries, each of the form {"question": <question content>}.
 120 |             If an error occurs during JSON parsing, it returns None.
 121 | 
 122 |         Examples
 123 |         --------
 124 |         >>> question_data = '[{"question": "What is your name?"}]'
 125 |         >>> processed_data = question_postprocessing(question_data)
 126 |         >>> print(processed_data)
 127 |         Output: [{'question': 'What is your name?'}]
 128 | 
 129 |         Notes
 130 |         -----
 131 |         - This function removes any markdown formatting (e.g., "```json" or "```") from the input string.
 132 |         - If an exception occurs during JSON parsing, an error message is logged, and the function returns None.
133 | """ 134 | 135 | try: 136 | # Clean up and parse the JSON question data 137 | question_data = json.loads(question_data.replace("```json", "").replace("```", "")) 138 | except Exception as exception: 139 | logger.error("Error parsing JSON: %s", str(exception)) 140 | return None 141 | 142 | # Initialize the result with a "question" key 143 | result = [] 144 | 145 | # Extract the question and assign it to the result 146 | for _, data in enumerate(question_data): 147 | result.append(data) 148 | 149 | return result 150 | 151 | def answer_postprocessing(self, question: str, answer: str, system_prompt: str = None): 152 | """ 153 | Post-process conversation data. 154 | 155 | This function processes raw conversation data by parsing it into a valid JSON format and structuring 156 | it into a predefined format. It also adds an optional system prompt to each conversation entry 157 | under the "instruction" key. The processed data is returned as a dictionary wrapped in a list. 158 | 159 | Parameters 160 | ---------- 161 | question : str 162 | The input question or query from the user. 163 | 164 | answer : str 165 | The raw answer data in string format, typically containing JSON content. 166 | This string may contain markdown formatting (e.g., "```json" or "```") that needs to be removed. 167 | 168 | system_prompt : str, optional 169 | An optional system-level prompt to provide context or instructions. This will be added to 170 | each conversation entry under the "instruction" key. Default is None. 171 | 172 | Returns 173 | ------- 174 | list or None 175 | Returns a list containing a dictionary with the processed conversation data. 176 | The dictionary has a "conversation" key, which is a list of conversation entries. 177 | Each entry contains "input", "output", and "instruction" keys. 178 | If an error occurs during JSON parsing, the function logs the error and returns None. 179 | 180 | Examples 181 | -------- 182 | >>> # Input: 183 | >>> question = "What is AI?" 184 | >>> answer = ''' 185 | [ 186 | { 187 | "input": question, 188 | "output": "BBB" 189 | } 190 | ] 191 | ''' 192 | >>> system_prompt = "You are a helpful assistant." 193 | 194 | >>> # Function Call: 195 | >>> processed_data = answer_postprocessing(question, answer, system_prompt) 196 | 197 | >>> # Output: 198 | >>> [ 199 | { 200 | "conversation": [ 201 | { 202 | "input": "What is AI?", 203 | "output": "BBB", 204 | "instruction": "You are a helpful assistant." 205 | } 206 | ] 207 | } 208 | ] 209 | 210 | Notes 211 | ----- 212 | - The function removes any markdown formatting (like "```json" or "```") before parsing the data. 213 | - If JSON parsing fails, the function logs an error and returns None. 214 | - The output is wrapped in a list to allow for future extensibility. 
215 | """ 216 | 217 | try: 218 | # Clean up and parse the JSON conversation data 219 | conversation_data = json.loads(answer.replace("```json","").replace("```","")) 220 | except Exception as exception: 221 | logger.error("Error parsing JSON: %s", str(exception)) 222 | return None 223 | 224 | # Initialize the result with a conversation key 225 | result = {"conversation": []} 226 | conversation = {"instruction" : system_prompt, "input" : question} 227 | # Add the system prompt to the first conversation entry if provided 228 | for idx, data in enumerate(conversation_data): 229 | conversation['output'] = data["answer"] 230 | result["conversation"].append(conversation) 231 | return result 232 | -------------------------------------------------------------------------------- /edg4llm/processor/preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import json 4 | 5 | from edg4llm.utils.logger import custom_logger 6 | from edg4llm.utils.data_utils import is_question_template_consistent 7 | from edg4llm.utils.data_utils import is_answer_template_consistent 8 | from edg4llm.utils.data_utils import is_dialogue_template_consistent 9 | 10 | from edg4llm.utils.template import Template 11 | 12 | logger = custom_logger("preprocess") 13 | 14 | class PreProcessor: 15 | """ 16 | A class for pre-processing user prompts before data generation. 17 | 18 | This class provides methods to validate and repair user prompts in different modes such as question, 19 | answer, and dialogue. If a user prompt does not match the expected template, the methods automatically 20 | append the corresponding format guidelines to ensure consistency. 21 | 22 | Methods 23 | ------- 24 | question_preprocess(user_prompt: str) -> str: 25 | Validates and repairs user prompts in question mode. 26 | 27 | answer_preprocess(user_prompt: str) -> str: 28 | Validates and repairs user prompts in answer mode. 29 | 30 | dialogue_preprocess(user_prompt: str) -> str: 31 | Validates and repairs user prompts in Q&A (dialogue) mode. 32 | """ 33 | def __init__(self): 34 | pass 35 | 36 | def question_preprocess(self, language: str, user_prompt: str) -> str: 37 | """ 38 | Validates and processes user prompts in question mode. 39 | 40 | Parameters 41 | ---------- 42 | language : str 43 | The language of data in data generation. Must be one of 'zh', 'en'. 44 | 45 | user_prompt : str 46 | The user's input prompt to be processed in question mode. 47 | 48 | Returns 49 | ------- 50 | str 51 | The validated and, if necessary, repaired user prompt. 52 | 53 | Notes 54 | ----- 55 | - If the user prompt matches the question template, it is returned unchanged. 56 | - If the user prompt does not match, format guidelines from `Template.question_template` 57 | are appended to the prompt. 58 | """ 59 | 60 | if is_question_template_consistent(user_prompt=user_prompt): 61 | logger.info("User prompt matches the question template. Proceeding with data generation.") 62 | return user_prompt 63 | else: 64 | logger.warning("User prompt does not match the question template. Automatically added format guidelines.") 65 | if language == "zh": 66 | repaired_user_prompt = user_prompt + '\n' + Template.question_zh_template 67 | else: 68 | repaired_user_prompt = user_prompt + '\n' + Template.question_en_template 69 | return repaired_user_prompt 70 | 71 | def answer_preprocess(self, language: str, user_prompt: str) -> str: 72 | """ 73 | Validates and processes user prompts in answer mode. 
/edg4llm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alannikos/edg4llm/73b996c36c750fec2d99c934bda01a4d4a57954a/edg4llm/utils/__init__.py
--------------------------------------------------------------------------------
/edg4llm/utils/config.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | 
3 | @dataclasses.dataclass
4 | class DefaultConfig:
5 |     """
6 |     A placeholder class for default configuration settings.
7 |     """
8 |     pass
9 | 
--------------------------------------------------------------------------------
/edg4llm/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Dict, List, Any
4 | 
5 | def is_question_template_consistent(user_prompt: str) -> bool:
6 |     """
7 |     Check if the user prompt contains a consistent question JSON template.
8 | 9 | Parameters 10 | ---------- 11 | user_prompt : str 12 | The user-provided prompt to be validated. 13 | 14 | Returns 15 | ------- 16 | bool 17 | True if the user prompt contains a valid and consistent question JSON template, 18 | False otherwise. 19 | 20 | Notes 21 | ----- 22 | - The function uses a regular expression to extract the JSON template and compares it 23 | with the target template. 24 | - The target template is: 25 | [ 26 | { 27 | "question": "AAA" 28 | } 29 | ] 30 | - Returns False if the JSON extraction or comparison fails. 31 | """ 32 | target_template = [ 33 | { 34 | "question": "AAA" 35 | } 36 | ] 37 | 38 | # Regular expression to extract JSON template 39 | pattern = r"\[\s*{\s*\"question\"\s*:\s*\"AAA\"\s*}\s*\]" 40 | match = re.search(pattern, user_prompt) 41 | 42 | if match: 43 | try: 44 | extracted_template = json.loads(match.group(0)) 45 | except json.JSONDecodeError: 46 | return False 47 | return extracted_template == target_template 48 | return False 49 | 50 | def is_answer_template_consistent(user_prompt: str) -> bool: 51 | """ 52 | Check if the user prompt contains a consistent answer JSON template. 53 | 54 | Parameters 55 | ---------- 56 | user_prompt : str 57 | The user-provided prompt to be validated. 58 | 59 | Returns 60 | ------- 61 | bool 62 | True if the user prompt contains a valid and consistent answer JSON template, 63 | False otherwise. 64 | 65 | Notes 66 | ----- 67 | - The function uses a regular expression to extract the JSON template and compares it 68 | with the target template. 69 | - The target template is: 70 | [ 71 | { 72 | "answer": "AAA" 73 | } 74 | ] 75 | - Returns False if the JSON extraction or comparison fails. 76 | """ 77 | target_template = [ 78 | { 79 | "answer": "AAA" 80 | } 81 | ] 82 | 83 | # Regular expression to extract JSON template 84 | pattern = r"\[\s*{\s*\"answer\"\s*:\s*\"AAA\"\s*}\s*\]" 85 | match = re.search(pattern, user_prompt) 86 | 87 | if match: 88 | try: 89 | extracted_template = json.loads(match.group(0)) 90 | except json.JSONDecodeError: 91 | return False 92 | return extracted_template == target_template 93 | return False 94 | 95 | def is_dialogue_template_consistent(user_prompt: str) -> bool: 96 | """ 97 | Check if the user prompt contains a consistent dialogue JSON template. 98 | 99 | Parameters 100 | ---------- 101 | user_prompt : str 102 | The user-provided prompt to be validated. 103 | 104 | Returns 105 | ------- 106 | bool 107 | True if the user prompt contains a valid and consistent dialogue JSON template, 108 | False otherwise. 109 | 110 | Notes 111 | ----- 112 | - The function uses a regular expression to check for the dialogue JSON structure. 113 | - The expected template format is: 114 | [ 115 | { 116 | "input": "AAA", 117 | "output": "BBB" 118 | } 119 | ] 120 | """ 121 | 122 | pattern = r"\[\s*\{\{\s*\"input\"\s*:\s*\"AAA\"\s*,\s*\"output\"\s*:\s*\"BBB\"\s*\}\}\s*\]" 123 | match = re.search(pattern, user_prompt) 124 | return match is not None 125 | 126 | def save_data_to_json(data: List[Dict], output_path: str): 127 | """ 128 | Save a list of dictionaries to a JSON file. 129 | 130 | Parameters 131 | ---------- 132 | data : list of dict 133 | A list of dictionaries to be saved to a JSON file. Each dictionary should contain 134 | the data to be written. 135 | 136 | output_path : str 137 | The path (including the filename) where the JSON data will be saved. 138 | The file will be written in UTF-8 encoding. 139 | 140 | Returns 141 | ------- 142 | None 143 | This function does not return any value. 
It saves the data to the specified file. 144 | 145 | Examples 146 | -------- 147 | >>> data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] 148 | >>> save_data_to_json(data, "output.json") 149 | 150 | Notes 151 | ----- 152 | - The function uses `json.dump` to write the data to the file. 153 | - Non-ASCII characters are preserved with the `ensure_ascii=False` argument. 154 | - The file will be saved with an indentation of 4 spaces to make it human-readable. 155 | """ 156 | with open(output_path, 'w', encoding='utf-8') as f: 157 | json.dump(data, f, ensure_ascii=False, indent=4) 158 | -------------------------------------------------------------------------------- /edg4llm/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class HttpClientError(Exception): 5 | """ 6 | Exception raised for errors encountered in the HTTP client. 7 | 8 | Parameters 9 | ---------- 10 | message : str 11 | A detailed error message describing the issue. 12 | status_code : Optional[int], optional 13 | The HTTP status code associated with the error, by default None. 14 | 15 | Attributes 16 | ---------- 17 | status_code : Optional[int] 18 | The HTTP status code associated with the error. 19 | """ 20 | 21 | def __init__(self, message: str, status_code: Optional[int] = None): 22 | super().__init__(message) 23 | self.status_code = status_code 24 | 25 | 26 | class InvalidPromptError(Exception): 27 | """ 28 | Custom exception raised when an invalid or empty prompt is encountered. 29 | 30 | Notes 31 | ----- 32 | This exception is intended to handle cases where a required prompt input 33 | is missing or invalid. 34 | """ 35 | pass 36 | -------------------------------------------------------------------------------- /edg4llm/utils/list_supported_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from edg4llm.utils.logger import custom_logger 3 | 4 | class ModelManager: 5 | """ 6 | A class to manage supported model providers and their models. 7 | 8 | Attributes 9 | ---------- 10 | supported_models : dict 11 | A dictionary mapping provider names to their supported models. 12 | 13 | Methods 14 | ------- 15 | list_providers(): 16 | Returns a list of all supported providers. 17 | list_models_by_provider(provider_name): 18 | Returns a list of models supported by the given provider. 19 | """ 20 | def __init__(self): 21 | """ 22 | Initializes the ModelManager with a predefined list of supported models. 23 | """ 24 | self.supported_models = { 25 | "ChatGLM": ["glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flashx", "glm-4-flash"], 26 | "DeepSeek": ["deepseek-chat", "deepseek-reasoner"], 27 | "InternLM": ["internlm2.5-latest", "internlm3-latest"], 28 | "ChatGPT": ["gpt-3.5-turbo-16k", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", "gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "o1-mini", "o1-mini-2024-09-12", "o1-preview","o1-preview-2024-09-12"] 29 | } 30 | 31 | def list_providers(self): 32 | """ 33 | Lists all supported model providers. 34 | 35 | Returns 36 | ------- 37 | list 38 | A list of provider names. 39 | """ 40 | 41 | return list(self.supported_models.keys()) 42 | 43 | def list_models_by_provider(self, provider_name): 44 | """ 45 | Lists all models supported by a given provider. 46 | 47 | Parameters 48 | ---------- 49 | provider_name : str 50 | The name of the provider. 
51 | 52 | Returns 53 | ------- 54 | list or None 55 | A list of model names supported by the provider, 56 | or None if the provider does not exist. 57 | """ 58 | return self.supported_models.get(provider_name, None) 59 | 60 | def main(): 61 | """ 62 | Entry point of the script to display supported model providers 63 | and their corresponding models based on the user's input. 64 | """ 65 | parser = argparse.ArgumentParser(description="View the list of supported models.") 66 | parser.add_argument("--list-providers", action="store_true", help="List all supported providers.") 67 | parser.add_argument("--list-models", type=str, metavar="PROVIDER", help="View the list of models for a specific provider.") 68 | 69 | args = parser.parse_args() 70 | 71 | manager = ModelManager() 72 | 73 | if args.list_providers: 74 | providers = manager.list_providers() 75 | print("Supported model providers:") 76 | for provider in providers: 77 | print(f" - {provider}") 78 | elif args.list_models: 79 | models = manager.list_models_by_provider(args.list_models) 80 | if models: 81 | print(f"{args.list_models} supports the following models:") 82 | for model in models: 83 | print(f" - {model}") 84 | else: 85 | print(f"Provider '{args.list_models}' does not exist or is not supported.") 86 | else: 87 | parser.print_help() 88 | 89 | -------------------------------------------------------------------------------- /edg4llm/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | 4 | __all__ = ['custom_logger'] 5 | 6 | # Define log level colors for terminal output 7 | LOG_COLORS = { 8 | 'DEBUG': '\033[96m', # Cyan 9 | 'INFO': '\033[92m', # Green 10 | 'WARNING': '\033[93m', # Yellow 11 | 'ERROR': '\033[91m', # Red 12 | 'CRITICAL': '\033[1;91m', # Bold Red 13 | 'RESET': '\033[0m', # Reset color 14 | } 15 | 16 | def custom_logger(name: str): 17 | """ 18 | Creates a custom logger with color-coded log levels and UTC+8 time formatting. 19 | 20 | Parameters 21 | ---------- 22 | name : str 23 | The name of the logger, typically the name of the module or application. 24 | 25 | Returns 26 | ------- 27 | logging.Logger 28 | A customized logger instance with color-coded levels and UTC+8 timezone support. 29 | 30 | Notes 31 | ----- 32 | - Log levels are color-coded for easier readability in terminal output. 33 | - Log messages use UTC+8 timezone formatting. 34 | - The logger prevents propagation to root loggers and clears existing handlers. 35 | - The logger uses a custom `StreamHandler` with color support. 36 | """ 37 | # Create a logger instance 38 | logger = logging.getLogger(name) 39 | logger.setLevel(logging.INFO) # Default log level 40 | logger.propagate = False # Disable propagation to root loggers 41 | logger.handlers = [] # Clear any existing handlers 42 | 43 | # Define a custom log message format 44 | formatter = logging.Formatter( 45 | '[%(asctime)s]-[%(name)s:%(levelname)s]:%(message)s' 46 | ) 47 | 48 | # Custom time converter to use UTC+8 49 | def _utc8_aera(timestamp): 50 | """ 51 | Convert a timestamp to a UTC+8 time tuple. 52 | 53 | Parameters 54 | ---------- 55 | timestamp : float 56 | The timestamp to convert. 57 | 58 | Returns 59 | ------- 60 | time.struct_time 61 | A time tuple in UTC+8 timezone. 
62 | """ 63 | now = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) + datetime.timedelta(hours=8) 64 | return now.timetuple() 65 | 66 | # Set the custom time converter in the formatter 67 | formatter.converter = _utc8_aera 68 | 69 | # Define a custom StreamHandler with color-coded log levels 70 | class ColorStreamHandler(logging.StreamHandler): 71 | """ 72 | A custom logging stream handler that adds color coding to log messages. 73 | 74 | Methods 75 | ------- 76 | emit(record): 77 | Formats and outputs a log record with color coding based on log level. 78 | """ 79 | def emit(self, record): 80 | """ 81 | Format and emit a log record with color coding. 82 | 83 | Parameters 84 | ---------- 85 | record : logging.LogRecord 86 | The log record to process and output. 87 | """ 88 | try: 89 | msg = self.format(record) # Format the log record 90 | color = LOG_COLORS.get(record.levelname, LOG_COLORS['RESET']) # Get the color for the log level 91 | # Write the log message with color 92 | self.stream.write(f"{color}{msg}{LOG_COLORS['RESET']}\n") 93 | self.flush() # Flush the stream 94 | except Exception: 95 | self.handleError(record) # Handle any errors during logging 96 | 97 | # Create and configure the custom handler 98 | custom_handler = ColorStreamHandler() 99 | custom_handler.setFormatter(formatter) 100 | 101 | # Add the custom handler to the logger 102 | logger.addHandler(custom_handler) 103 | 104 | return logger 105 | -------------------------------------------------------------------------------- /edg4llm/utils/template.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class Template: 5 | """ 6 | A class to define language-specific templates for user prompts, providing a strict JSON format 7 | to preprocess user input. If the user's prompt does not include format instructions, the 8 | appropriate template will be added to enforce the required structure. 9 | 10 | Attributes: 11 | ---------- 12 | question_zh_template : str 13 | A JSON format template for Chinese question prompts. Ensures that generated questions 14 | are returned in a JSON format with a "question" field. 15 | 16 | answer_zh_template : str 17 | A JSON format template for Chinese answer prompts. Ensures that generated answers 18 | are returned in a JSON format with an "answer" field. 19 | 20 | dialogue_zh_template : str 21 | A JSON format template for Chinese dialogue prompts. Ensures that the interaction is 22 | returned in a JSON format with "input" representing the question and "output" representing 23 | the response. 24 | 25 | question_en_template : str 26 | A JSON format template for English question prompts. Ensures that generated questions 27 | are returned in a JSON format with a "question" field. 28 | 29 | answer_en_template : str 30 | A JSON format template for English answer prompts. Ensures that generated answers 31 | are returned in a JSON format with an "answer" field. 32 | 33 | dialogue_en_template : str 34 | A JSON format template for English dialogue prompts. Ensures that the interaction is 35 | returned in a JSON format with "input" representing the question and "output" representing 36 | the response. 37 | 38 | Notes: 39 | ----- 40 | This class is designed for preprocessing user prompts. If a user's input does not include 41 | specific format instructions, the appropriate template (based on language) is appended to 42 | the user prompt to ensure compliance with the required JSON format. 
43 | """ 44 | 45 | question_zh_template = \ 46 | """ 47 | 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: 48 | [ 49 | { 50 | "question":"AAA" 51 | } 52 | ] 53 | 其中question字段表示生成的问题 54 | """ 55 | 56 | answer_zh_template = \ 57 | """ 58 | 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: 59 | [ 60 | { 61 | "answer":"AAA" 62 | } 63 | ] 64 | 其中answer字段表示生成的答案 65 | """ 66 | 67 | dialogue_zh_template = \ 68 | """ 69 | 严格遵循规则: 请以如下格式返回生成的数据, 只返回JSON格式,json模板: 70 | [ 71 | {{ 72 | "input":"AAA","output":"BBB" 73 | }} 74 | ] 75 | 其中input字段表示问题, output字段回答 76 | """ 77 | 78 | question_en_template = \ 79 | """ 80 | Strictly follow the rules: Please return the generated data in the following format, 81 | only in JSON format. JSON template: 82 | [ 83 | { 84 | "question":"AAA" 85 | } 86 | ] 87 | The "question" field represents the generated question. 88 | """ 89 | 90 | answer_en_template = \ 91 | """ 92 | Strictly follow the rules: Please return the generated data in the following format, 93 | only in JSON format. JSON template: 94 | [ 95 | { 96 | "answer":"AAA" 97 | } 98 | ] 99 | The "answer" field represents the generated answer. 100 | """ 101 | 102 | dialogue_en_template = \ 103 | """ 104 | Strictly follow the rules: Please return the generated data in the following format, 105 | only in JSON format. JSON template: 106 | [ 107 | {{ 108 | "input":"AAA","output":"BBB" 109 | }} 110 | ] 111 | The "input" field represents the question, and the "output" field 112 | represents the answer. 113 | """ 114 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name="edg4llm", # 项目名称 8 | version="1.0.18", # 项目版本 9 | author="Alannikos", # 作者姓名 10 | author_email="alannikos768@outlook.com", # 作者邮箱 11 | description="A unified tool to generate fine-tuning datasets for LLMs, including questions, answers, and dialogues.", # 简短描述 12 | long_description=long_description, # 长描述 13 | long_description_content_type="text/markdown", # 长描述格式 14 | url="https://github.com/alannikos/edg4llm", # 项目主页(GitHub 或其他) 15 | packages=find_packages(include=["edg4llm", "edg4llm.*"]), # 自动发现包含的模块 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", # 选择的许可证类型 19 | "Operating System :: OS Independent", 20 | "Intended Audience :: Developers", 21 | "Topic :: Software Development :: Libraries", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | ], 24 | python_requires=">=3.8", # Python 版本要求 25 | install_requires=[ 26 | "requests>=2.32.3" 27 | ], 28 | include_package_data=True, # 包含非代码文件,如配置文件 29 | zip_safe=False, # 是否以 zip 格式分发(通常为 False) 30 | keywords="LLM fine-tuning data-generation AI NLP", # 关键词 31 | entry_points={ 32 | "console_scripts": [ 33 | "edg4llm-cli=edg4llm.utils.list_supported_models:main" 34 | ] 35 | }, 36 | ) 37 | --------------------------------------------------------------------------------