├── .env.example
├── .gitignore
├── README.md
├── groqcloud_darkmode.png
├── main.py
├── output.md
└── requirements.txt

/.env.example:
--------------------------------------------------------------------------------
GROQ_API_KEY=

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
venv/*
.env
*.csv
*.md
*.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CrewAI Machine Learning Assistant

## Overview

The [CrewAI](https://docs.crewai.com/) Machine Learning Assistant is a command-line application designed to kickstart your machine learning projects. It leverages a team of AI agents to guide you through the initial steps of defining, assessing, and solving machine learning problems.

## Features

- **Agents**: Utilizes specialized agents to perform tasks such as problem definition, data assessment, technique recommendation, code generation, cross-validation, and visualization.

- **CrewAI Framework**: Integrates multiple agents into a single crew that runs their tasks in order and collects each task's output.

- **LangChain Integration**: Uses the `ChatGroq` chat model from `langchain_groq` to connect the agents to a Groq-hosted Llama 3 model.


## Usage

You will need a valid Groq API key to proceed with this example: store it as a secret on Replit, or as `GROQ_API_KEY` in a `.env` file (see `.env.example`) when running locally. You can generate one for free [here](https://console.groq.com/keys).

You can [fork and run this application on Replit](https://replit.com/@GroqCloud/CrewAI-Machine-Learning-Assistant) or run it on the command line with `python main.py`.
You can upload a sample .csv to the same directory as `main.py` to give the application a head start on your ML problem. The application will output a Markdown file (`output.md`) including Python code for your ML use case to the same directory as `main.py`.

--------------------------------------------------------------------------------
/groqcloud_darkmode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mustafa-Esoofally/ML-assistant-Crew/7707baa5224e7d69d50ade351b8a8cbb6889680e/groqcloud_darkmode.png

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import pandas as pd
import os
from crewai import Agent, Task, Crew
from langchain_groq import ChatGroq
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env.

def write_output_files(result):
    all_output = str(result)
    print(f"Length of all_output: {len(all_output)}")
    print("All output content:")
    print(all_output)

    with open('output.md', "w") as file:
        file.write("# Machine Learning Project Summary\n\n")
        file.write(all_output)

        # Python Library Dependencies
        file.write("\n\n## Python Library Dependencies\n\n")
        libraries = [
            "pandas",
            "numpy",
            "scikit-learn",
            "matplotlib",
            "seaborn",
            "scipy",
            "category_encoders"
        ]
        for lib in libraries:
            file.write(f"- {lib}\n")

    print("Complete output has been exported to output.md")

    # Print the content of the file after writing
    with open('output.md', 'r') as file:
        print(f"Content of output.md:\n{file.read()}")

def main():
    """
    Main function to initialize and run the CrewAI Machine Learning Assistant.

    This function sets up a machine learning assistant using the Llama 3 model with the ChatGroq API.
    It runs a crew of specialized AI agents that define, assess, and solve a machine learning problem,
    then writes a summary of each agent's output to a markdown file and prints that file to the console.

    Steps:
    1. Initialize the ChatGroq API with the specified model and API key.
    2. Display introductory text about the CrewAI Machine Learning Assistant.
    3. Create and configure six AI agents:
       - Problem_Definition_Agent: Defines the machine learning problem clearly and concisely,
         identifying the type of problem (e.g., classification, regression) and any specific requirements.
       - Data_Assessment_Agent: Assesses the provided data and suggests preprocessing and
         feature engineering steps to prepare it for modeling.
       - AI_Technique_Recommendation_Agent: Recommends suitable AI techniques (machine learning,
         deep learning, and other approaches), with a rationale for each.
       - Code_Generator_Agent: Generates Python code for the entire pipeline, including data loading,
         preprocessing, model implementation, training, evaluation, and tuning suggestions.
       - Cross_Validation_Agent: Decides whether k-fold cross-validation is appropriate and implements it.
       - Visualization_Agent: Generates matplotlib/seaborn visualizations comparing the recommended models.
    4. Set the problem description (currently hard-coded; the input() prompt is commented out).
    5. Check if a .csv file is available in the current directory and try to read it as a DataFrame.
    6. Define tasks for the agents based on the problem description and the dataset summary.
    7. Create a Crew instance with the agents and tasks, and run the tasks.
    8. Write the results to an output markdown file and print its contents.
    """

    # model = 'llama3-8b-8192'
    model = "llama3-70b-8192"

    llm = ChatGroq(
        temperature=0, groq_api_key=os.getenv("GROQ_API_KEY"), model_name=model
    )

    print("CrewAI Machine Learning Assistant")
    multiline_text = """
    The CrewAI Machine Learning Assistant is designed to guide users through the process of defining, assessing, and solving machine learning problems. It leverages a team of AI agents, each with a specific role, to clarify the problem, evaluate the data, recommend suitable models, and generate starter Python code. Whether you're a seasoned data scientist or a beginner, this application provides valuable insights and a head start in your machine learning projects.
    """

    print(multiline_text)

    Problem_Definition_Agent = Agent(
        role="Problem_Definition_Agent",
        goal="Define the machine learning problem clearly and concisely.",
        backstory="You are an expert in understanding and defining machine learning problems.",
        verbose=True,
        allow_delegation=False,
        llm=llm,
    )

    # Check if there is a .csv file in the current directory
    csv_files = [file for file in os.listdir() if file.endswith(".csv")]
    if csv_files:
        sample_fp = csv_files[0]
        try:
            # Attempt to read the uploaded file as a DataFrame
            df = pd.read_csv(sample_fp)
            data_info = f"""
            Dataset Information:
            - Filename: {sample_fp}
            - Number of rows: {df.shape[0]}
            - Number of columns: {df.shape[1]}
            - Columns: {', '.join(df.columns)}
            - Data types:
            {df.dtypes.to_string()}

            First 5 rows of the dataset:
            {df.head().to_string()}
            """
            print("Data successfully loaded:")
            print(data_info)
        except Exception as e:
            print(f"Error reading the file: {e}")
            data_info = "No valid CSV file found or error reading the file."
    else:
        data_info = "No CSV file found in the current directory."

    Data_Assessment_Agent = Agent(
        role="Data_Assessment_Agent",
        goal="Assess and preprocess the data for the AI problem.",
        backstory=f"You are a data scientist specializing in data assessment, exploration, and preprocessing. Here's the data you're working with:\n{data_info}",
        verbose=True,
        allow_delegation=False,
        llm=llm,
    )

    AI_Technique_Recommendation_Agent = Agent(
        role="AI_Technique_Recommendation_Agent",
        goal="Recommend suitable AI techniques, including machine learning, deep learning, and other approaches.",
        backstory=f"You are an expert in various AI techniques and their applications. Here's the data you're working with:\n{data_info}",
        verbose=True,
        allow_delegation=False,
        llm=llm,
    )

    Code_Generator_Agent = Agent(
        role="Code_Generator_Agent",
        goal="Generate comprehensive Python code for the entire AI pipeline.",
        backstory="You are a skilled AI engineer proficient in writing clean, efficient code for various AI techniques.",
        verbose=True,
        allow_delegation=False,
        llm=llm,
    )

    Cross_Validation_Agent = Agent(
        role="Cross_Validation_Agent",
        goal="Determine if k-fold cross-validation is appropriate for the given problem and dataset, and implement it when suitable.",
        backstory="You are an expert in model validation techniques. Your task is to assess whether k-fold cross-validation is necessary based on the problem type, dataset size, and other relevant factors. If appropriate, you implement and explain the cross-validation process.",
        verbose=True,
        allow_delegation=False,
        llm=llm,
    )

    Visualization_Agent = Agent(
        role='Visualization_Agent',
        goal="Generate visualizations comparing the performance of the recommended models using matplotlib and seaborn.",
        backstory="You are a data visualization expert specializing in machine learning model comparisons. Your task is to create clear, informative visualizations that help users understand and compare the performance of the specific machine learning models recommended for their problem. You use matplotlib and seaborn to create appropriate visualizations based on the problem type and the recommended models.",
        verbose=True,
        allow_delegation=False,
        llm=llm,
    )

    # user_question = input("Describe your ML problem: ")
    user_question = "Develop a model to predict the price of houses using the given data"
    data_upload = False
    # Check if there is a .csv file in the current directory
    if any(file.endswith(".csv") for file in os.listdir()):
        sample_fp = [file for file in os.listdir() if file.endswith(".csv")][0]
        try:
            # Attempt to read the uploaded file as a DataFrame
            df = pd.read_csv(sample_fp).head(5)

            # If successful, set 'data_upload' to True
            data_upload = True

            # Display the DataFrame in the app
            print("Data successfully uploaded and read as DataFrame:")
            print(df)
        except Exception as e:
            print(f"Error reading the file: {e}")

    if user_question:

        # Include the user's question in the task description; otherwise
        # no agent ever sees the problem statement.
        task_define_problem = Task(
            description=f"""Define the machine learning problem based on the user's input. Include:
            1. Clear problem statement
            2. Type of machine learning problem (e.g., classification, regression)
            3. Specific requirements or constraints
            4. Potential challenges

            User's problem:
            {user_question}""",
            agent=Problem_Definition_Agent,
            expected_output="A comprehensive problem definition with the elements listed above.",
        )

        task_assess_data = Task(
            description=f"""Assess the data and suggest preprocessing steps. Include:
            1. Data collection and exploration insights
            2. Data quality assessment
            3. Necessary preprocessing steps
            4. Feature engineering suggestions
            Provide code snippets for data loading, exploration, and preprocessing.

            Data Information:
            {data_info}""",
            agent=Data_Assessment_Agent,
            expected_output="A detailed data assessment report with code snippets for preprocessing and feature engineering.",
        )

        task_recommend_technique = Task(
            description=f"""Recommend suitable AI techniques. Include:
            1. List of recommended techniques (machine learning, deep learning, and other AI approaches) with rationale
            2. Pros and cons of each technique
            3. Suggestions for technique selection criteria
            4. Any ensemble or hybrid methods to consider

            Base your recommendations on this data:
            {data_info}""",
            agent=AI_Technique_Recommendation_Agent,
            expected_output="A comprehensive list of recommended AI techniques with detailed explanations.",
        )

        task_generate_code = Task(
            description="""Generate Python code for the entire AI pipeline. Include:
            1. Data loading and preprocessing
            2. Feature engineering (if applicable)
            3. Model/technique implementation (for all recommended approaches)
            4. Training and evaluation
            5. Cross-validation or other validation methods (if applicable)
            6. Hyperparameter tuning suggestions
            7. Performance comparison visualizations
            Ensure the code is well-commented and follows best practices.""",
            agent=Code_Generator_Agent,
            expected_output="Complete, well-structured Python code for the entire AI pipeline.",
        )

        task_cross_validation = Task(
            description="""Implement k-fold cross-validation for the recommended models. Include:
            1. Explanation of cross-validation importance
            2. Code for implementing k-fold cross-validation
            3. Guidelines for interpreting cross-validation results""",
            agent=Cross_Validation_Agent,
            expected_output="Detailed explanation and code for k-fold cross-validation implementation.",
        )

        task_visualize_comparison = Task(
            description="""Create visualizations to compare the performance of recommended models. Include:
            1. Appropriate visualization types for the problem (e.g., ROC curves, confusion matrices)
            2. Code snippets for generating visualizations
            3. Guidelines for interpreting the visualizations""",
            agent=Visualization_Agent,
            expected_output="Code snippets and explanations for model comparison visualizations.",
        )

        crew = Crew(
            agents=[
                Problem_Definition_Agent,
                Data_Assessment_Agent,
                AI_Technique_Recommendation_Agent,
                Code_Generator_Agent,
                Cross_Validation_Agent,
                Visualization_Agent,
            ],
            tasks=[
                task_define_problem,
                task_assess_data,
                task_recommend_technique,
                task_generate_code,
                task_cross_validation,
                task_visualize_comparison,
            ],
            verbose=False,
        )

        results = crew.kickoff()

        # Write the output to output.md
        with open('output.md', "w") as file:
            file.write("# AI Project Summary\n\n")
            file.write("## Problem Statement\n")
            file.write(f"{user_question}\n\n")

            for task in crew.tasks:
                if task.agent.role != "Problem_Definition_Agent":
                    file.write(f"## {task.agent.role.replace('_', ' ')}\n\n")
                    file.write(f"**Task Summary:** {task.output.summary}\n\n")
                    file.write("**Key Recommendations:**\n")
                    key_points = extract_key_points(task.output.raw)
                    if key_points:
                        for point in key_points:
                            file.write(f"{point}\n")
                    else:
                        file.write("No specific key points extracted. Please refer to the task summary.\n")
                    file.write("\n")

                    # Extract and write code snippets
                    code_snippets = extract_code_snippets(task.output.raw)
                    if code_snippets:
                        file.write("**Code Snippets:**\n")
                        for snippet in code_snippets:
                            file.write(f"```python\n{snippet}\n```\n\n")

            file.write("## Next Steps\n")
            file.write("Based on the analysis, consider the following next steps:\n")
            next_steps = extract_key_points(crew.tasks[-1].output.raw)
            if next_steps:
                for step in next_steps:
                    file.write(f"- {step}\n")
            else:
                file.write("- Implement the suggested preprocessing steps\n")
                file.write("- Train and evaluate the recommended AI techniques\n")
                file.write("- Fine-tune the best performing approach\n")
                file.write("- Deploy the model/system and monitor its performance\n")

        print("Complete output has been exported to output.md")

        # Print the content of the file after writing
        with open('output.md', 'r') as file:
            print(f"Content of output.md:\n{file.read()}")

def extract_key_points(raw_output):
    """Return lines that look like bullets or 'label: value' pairs.

    This is a line-by-line heuristic: it does not skip fenced code blocks,
    so code lines containing ': ' are returned as well.
    """
    key_points = []
    lines = raw_output.split('\n')
    for line in lines:
        line = line.strip()
        if line.startswith('-') or line.startswith('*') or (': ' in line and not line.startswith('**')):
            key_points.append(line)
    return key_points

def extract_code_snippets(raw_output):
    """Return the bodies of all ```python fenced blocks found in raw_output."""
    code_snippets = []
    lines = raw_output.split('\n')
    in_code_block = False
    current_snippet = []

    for line in lines:
        if line.strip().startswith('```python'):
            in_code_block = True
            current_snippet = []
        elif line.strip() == '```' and in_code_block:
            in_code_block = False
            code_snippets.append('\n'.join(current_snippet))
        elif in_code_block:
            current_snippet.append(line)

    return code_snippets

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/output.md:
--------------------------------------------------------------------------------
# AI Project Summary

## Problem Statement
Develop a model to predict the price of houses using the given data

## Data Assessment Agent

**Task Summary:** Assess the data and suggest preprocessing steps. Include:
...

**Key Recommendations:**
**Data Assessment Report**
**Data Collection and Exploration Insights**
**Data Quality Assessment**
1. **Missing Values**: The `Price` column contains NaN values, which need to be handled.
2. **Object Data Types**: The `Suburb`, `Address`, `Type`, `Method`, `SellerG`, `CouncilArea`, and `Regionname` columns are of object data type, which may require additional preprocessing steps.
3. **Date Column**: The `Date` column is of object data type, which may need to be converted to a datetime format for further analysis.
**Necessary Preprocessing Steps**
1. **Handle Missing Values**: Impute the missing values in the `Price` column using a suitable imputation method, such as mean or median imputation.
2. **Encode Object Data Types**: Convert the object data types to numerical or categorical data types using techniques such as one-hot encoding or label encoding.
3. **Convert Date Column**: Convert the `Date` column to a datetime format using the `pd.to_datetime()` function.
**Feature Engineering Suggestions**
1. **Extract Date Features**: Extract relevant date features, such as year, month, and day, from the `Date` column.
2. **Create New Features**: Create new features, such as the distance to the city center or the proximity to public transportation, using the `Distance` and `Postcode` columns.
3. **Aggregate Features**: Aggregate features, such as the average `Price` or `Landsize` per `Suburb`, to capture neighborhood-level patterns.
**Code Snippets for Data Loading, Exploration, and Preprocessing**

## AI Technique Recommendation Agent

**Task Summary:** Recommend suitable AI techniques. Include:
...

**Key Recommendations:**
**Recommended Techniques:**
1. **Linear Regression**: A simple and interpretable model that can capture the linear relationships between the features and the target variable (Price).
2. **Decision Trees**: A tree-based model that can handle both numerical and categorical features, and can capture non-linear relationships.
3. **Random Forest**: An ensemble method that combines multiple decision trees to improve the accuracy and robustness of the model.
4. **Gradient Boosting**: Another ensemble method that combines multiple weak models to create a strong predictor.
5. **Neural Networks**: A deep learning approach that can capture complex non-linear relationships between the features and the target variable.
6. **K-Nearest Neighbors (KNN)**: A simple and interpretable model that can capture local patterns in the data.
7. **Support Vector Machines (SVM)**: A model that can capture non-linear relationships between the features and the target variable, and can handle high-dimensional data.
**Rationale:**
* Linear Regression is a simple and interpretable model that can capture the linear relationships between the features and the target variable.
* Decision Trees and Random Forest can handle both numerical and categorical features, and can capture non-linear relationships.
* Gradient Boosting and Neural Networks can capture complex non-linear relationships between the features and the target variable.
* KNN and SVM can capture local patterns in the data and handle high-dimensional data, respectively.
**Pros and Cons of Each Technique:**
* Pros: Simple, interpretable, and fast to train.
* Cons: Assumes linear relationships, may not capture non-linear relationships.
* Pros: Can handle both numerical and categorical features, can capture non-linear relationships.
* Cons: May overfit the data, sensitive to feature scaling.
* Pros: Improves the accuracy and robustness of decision trees, can handle high-dimensional data.
* Cons: May be computationally expensive, difficult to interpret.
* Pros: Can capture complex non-linear relationships, can handle high-dimensional data.
* Cons: May overfit the data, computationally expensive.
* Pros: Can capture complex non-linear relationships, can handle high-dimensional data.
* Cons: May overfit the data, computationally expensive, difficult to interpret.
* Pros: Simple, interpretable, and fast to train.
* Cons: May not capture global patterns in the data, sensitive to feature scaling.
* Pros: Can handle high-dimensional data, can capture non-linear relationships.
* Cons: May be computationally expensive, difficult to interpret.
**Technique Selection Criteria:**
1. **Data Complexity**: If the data is complex and has non-linear relationships, techniques like Gradient Boosting, Neural Networks, and SVM may be more suitable.
2. **Feature Importance**: If feature importance is crucial, techniques like Linear Regression, Decision Trees, and Random Forest may be more suitable.
3. **Interpretability**: If interpretability is crucial, techniques like Linear Regression, Decision Trees, and KNN may be more suitable.
4. **Computational Resources**: If computational resources are limited, techniques like Linear Regression, Decision Trees, and KNN may be more suitable.
**Ensemble or Hybrid Methods to Consider:**
1. **Stacking**: Combine the predictions of multiple models to improve the overall accuracy.
2. **Bagging**: Combine multiple instances of the same model to improve the overall accuracy.
3. **Boosting**: Combine multiple weak models to create a strong predictor.

## Code Generator Agent

**Task Summary:** Generate Python code for the entire AI pipeline. Include:
...

**Key Recommendations:**
'Linear Regression': LinearRegression(),
'Decision Tree': DecisionTreeRegressor(),
'Random Forest': RandomForestRegressor(),
'Gradient Boosting': GradientBoostingRegressor(),
'Neural Network': MLPRegressor(),
'K-Nearest Neighbors': KNeighborsRegressor(),
'Support Vector Machine': SVR()
results[name] = {'MSE': mse, 'R2': r2}
'Random Forest': {'n_estimators': [100, 200, 300], 'max_depth': [None, 5, 10]},
'Gradient Boosting': {'n_estimators': [100, 200, 300], 'learning_rate': [0.1, 0.5, 1]},
'Neural Network': {'hidden_layer_sizes': [(50, 50), (100, 100), (200, 200)]}
tuned_results[name] = {'MSE': mse, 'R2': r2}

**Code Snippets:**
```python
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns

# Load the Melbourne housing dataset
df = pd.read_csv('melbourne_housing.csv')

# Preprocess the data
X = df.drop(['Price'], axis=1)
y = df['Price']

# Handle missing values
X.fillna(X.mean(), inplace=True)

# Encode categorical variables
le = LabelEncoder()
X['Suburb'] = le.fit_transform(X['Suburb'])

# Scale the data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# Define the models
models = {
    'Linear Regression': LinearRegression(),
    'Decision Tree': DecisionTreeRegressor(),
    'Random Forest': RandomForestRegressor(),
    'Gradient Boosting': GradientBoostingRegressor(),
    'Neural Network': MLPRegressor(),
    'K-Nearest Neighbors': KNeighborsRegressor(),
    'Support Vector Machine': SVR()
}

# Train and evaluate each model
results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)
    results[name] = {'MSE': mse, 'R2': r2}

# Perform cross-validation for each model
cv_results = {}
for name, model in models.items():
    scores = cross_val_score(model, X_train, y_train, cv=5, scoring='neg_mean_squared_error')
    cv_results[name] = np.mean(scores)

# Hyperparameter tuning suggestions
hyperparameter_tuning = {
    'Random Forest': {'n_estimators': [100, 200, 300], 'max_depth': [None, 5, 10]},
    'Gradient Boosting': {'n_estimators': [100, 200, 300], 'learning_rate': [0.1, 0.5, 1]},
    'Neural Network': {'hidden_layer_sizes': [(50, 50), (100, 100), (200, 200)]}
}

# Perform hyperparameter tuning using GridSearchCV
tuned_models = {}
for name, params in hyperparameter_tuning.items():
    model = models[name]
    grid_search = GridSearchCV(model, params, cv=5, scoring='neg_mean_squared_error')
    grid_search.fit(X_train, y_train)
    tuned_models[name] = grid_search.best_estimator_

# Evaluate the tuned models
tuned_results = {}
for name, model in tuned_models.items():
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)
    tuned_results[name] = {'MSE': mse, 'R2': r2}

# Visualize the performance comparison
plt.figure(figsize=(10, 6))
sns.barplot(x=list(results.keys()), y=[result['MSE'] for result in results.values()])
plt.xlabel('Model')
plt.ylabel('Mean Squared Error')
plt.title('Performance Comparison')
plt.show()

plt.figure(figsize=(10, 6))
sns.barplot(x=list(cv_results.keys()), y=list(cv_results.values()))
plt.xlabel('Model')
plt.ylabel('Cross-Validation Score')
plt.title('Cross-Validation Results')
plt.show()

plt.figure(figsize=(10, 6))
sns.barplot(x=list(tuned_results.keys()), y=[result['MSE'] for result in tuned_results.values()])
plt.xlabel('Model')
plt.ylabel('Mean Squared Error')
plt.title('Tuned Model Performance')
plt.show()
```

## Cross Validation Agent

**Task Summary:** Implement k-fold cross-validation for the recommended models. Include:
...

**Key Recommendations:**
**Importance of Cross-Validation**
**Implementing k-Fold Cross-Validation**
**Guidelines for Interpreting Cross-Validation Results**
* A high cross-validation score indicates that the model is generalizing well to new, unseen data.
* A low cross-validation score indicates that the model may be overfitting or underfitting the training data.
* Compare the cross-validation scores across different models to determine which model is performing best.
* Use the cross-validation results to tune hyperparameters and improve the model's performance.

**Code Snippets:**
```python
from sklearn.model_selection import cross_val_score

# Perform k-fold cross-validation for each model
cv_results = {}
for name, model in models.items():
    scores = cross_val_score(model, X_train, y_train, cv=5, scoring='neg_mean_squared_error')
    cv_results[name] = np.mean(scores)
```

## Visualization Agent

**Task Summary:** Create visualizations to compare the performance of recommended models. Include:
...

**Key Recommendations:**
**1. Cross-Validation Scores**
Visualization Type: Bar Chart
* The bar chart shows the cross-validation scores for each model.
* A higher score indicates better performance.
* Compare the scores across models to determine which one performs best.
**2. ROC Curves**
Visualization Type: ROC Curve
* The ROC curve shows the trade-off between true positive rate and false positive rate for each model.
* A higher true positive rate and a lower false positive rate indicate better performance.
* Compare the ROC curves across models to determine which one has the best trade-off.
**3. Confusion Matrices**
Visualization Type: Heatmap
* The confusion matrix shows the number of true positives, false positives, true negatives, and false negatives for each model.
* A higher true positive rate and a lower false positive rate indicate better performance.
* Compare the confusion matrices across models to determine which one has the best performance.

**Code Snippets:**
```python
import matplotlib.pyplot as plt

# Create a bar chart of cross-validation scores
plt.bar(cv_results.keys(), cv_results.values())
plt.xlabel('Model Name')
plt.ylabel('Cross-Validation Score')
plt.title('Cross-Validation Scores')
plt.show()
```

```python
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt

# Create ROC curves for each model
for name, model in models.items():
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
    plt.plot(fpr, tpr, label=name)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves')
plt.legend()
plt.show()
```

```python
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Create confusion matrices for each model
for name, model in models.items():
    y_pred = model.predict(X_test)
    cm = confusion_matrix(y_test, y_pred)
    sns.heatmap(cm, annot=True, cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.title(f'Confusion Matrix - {name}')
    plt.show()
```

## Next Steps
Based on the analysis, consider the following next steps:
- **1. Cross-Validation Scores**
- Visualization Type: Bar Chart
- * The bar chart shows the cross-validation scores for each model.
- * A higher score indicates better performance.
- * Compare the scores across models to determine which one performs best.
- **2. ROC Curves**
- Visualization Type: ROC Curve
- * The ROC curve shows the trade-off between true positive rate and false positive rate for each model.
- * A higher true positive rate and a lower false positive rate indicate better performance.
- * Compare the ROC curves across models to determine which one has the best trade-off.
- **3. Confusion Matrices**
- Visualization Type: Heatmap
- * The confusion matrix shows the number of true positives, false positives, true negatives, and false negatives for each model.
- * A higher true positive rate and a lower false positive rate indicate better performance.
- * Compare the confusion matrices across models to determine which one has the best performance.

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
crewai
langchain_groq
pandas
nbformat
python-dotenv

--------------------------------------------------------------------------------
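
Note on key-point extraction: `extract_key_points` in main.py scans every line of an agent's raw output, so lines inside fenced code that contain `': '` are also written under "Key Recommendations" (the dictionary entries listed in the Code Generator Agent section of output.md are an example). The sketch below shows one possible fence-aware variant. The function name `extract_key_points_outside_code` and the sample text are hypothetical and not part of the repository; the matching rule is the same heuristic main.py uses.

```python
FENCE = "`" * 3  # a Markdown code-fence marker, built here to keep this example easy to embed


def extract_key_points_outside_code(raw_output):
    """Hypothetical variant of extract_key_points that ignores fenced code blocks."""
    key_points = []
    in_code_block = False
    for line in raw_output.split("\n"):
        stripped = line.strip()
        # Any fence line (with or without a language tag) toggles code mode
        if stripped.startswith(FENCE):
            in_code_block = not in_code_block
            continue
        if in_code_block:
            continue
        # Same heuristic as main.py: bullets, or "label: value" lines that are not bold headings
        if (
            stripped.startswith("-")
            or stripped.startswith("*")
            or (": " in stripped and not stripped.startswith("**"))
        ):
            key_points.append(stripped)
    return key_points


# Made-up agent output for illustration only
sample = "\n".join([
    "**Recommended Techniques:**",
    "1. **Random Forest**: An ensemble of decision trees.",
    FENCE + "python",
    "models = {'Random Forest': RandomForestRegressor()}",
    FENCE,
    "- Tune n_estimators first",
])

key_points = extract_key_points_outside_code(sample)
print(key_points)
```

With this variant the dictionary line inside the fenced block is dropped, while the heading, the numbered item, and the bullet are kept. Swapping it in for `extract_key_points` would change what appears under "Key Recommendations" and "Next Steps" in output.md.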