├── requirements.txt
├── run_all_tests.sh
├── model_config.json
├── questions.json
├── TEST.md
├── README.md
├── visualization.py
└── evaluate.py

/requirements.txt:
--------------------------------------------------------------------------------
transformers>=4.49.0
torch>=2.1.0
torchvision>=0.16.0
torchaudio>=2.1.0
tensorboard>=2.14.0
pandas>=1.3.5
numpy>=1.21.0
matplotlib>=3.5.0
seaborn>=0.11.2
--------------------------------------------------------------------------------
/run_all_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Run the main evaluation
echo "Running model evaluation..."
python evaluate.py

# Generate visualizations
echo "Generating visualizations..."
python visualization.py

echo "All tests completed!"
echo "Check the 'results' directory for output files and visualizations."
--------------------------------------------------------------------------------
/model_config.json:
--------------------------------------------------------------------------------
{
  "models": [
    {
      "id": "titan-transformer",
      "name": "Titan Transformer",
      "path": "rajveer43/titan-transformer",
      "description": "Base implementation of the Titan architecture",
      "parameters": {
        "max_length": 150,
        "temperature": 0.7,
        "top_p": 0.9
      }
    }
  ],
  "evaluation": {
    "metrics": ["response_time", "token_count", "response_length"],
    "output_format": "json",
    "save_individual_results": true,
    "compare_models": true
  }
}
--------------------------------------------------------------------------------
/questions.json:
--------------------------------------------------------------------------------
{
  "Facts": [
    "What is the capital of France?",
    "Who wrote \"To Kill a Mockingbird\"?",
    "What is the largest mammal in the world?",
    "What is the boiling point of water?",
    "Who painted the Mona Lisa?",
    "What is the speed of light?",
    "What is the chemical formula for table salt?",
    "Who discovered penicillin?",
    "What is the currency of Japan?",
    "What is the tallest mountain in the world?",
    "Who is the author of \"1984\"?",
    "What is the distance from the Earth to the Moon?"
  ],
  "Thinking": [
    "If a train leaves the station at 3 PM and travels at 60 mph, what time will it arrive at a station 180 miles away?",
    "If a car travels 300 miles in 5 hours, what is its average speed?",
    "If a recipe calls for 2 cups of flour and you want to make half the recipe, how much flour do you need?",
    "If a rectangle has a length of 10 cm and a width of 5 cm, what is its area?",
    "If a triangle has a base of 6 cm and a height of 4 cm, what is its area?"
  ],
  "Code": [
    "Write a Python function to calculate the factorial of a number.",
    "Write a Python function to check if a string is a palindrome.",
    "Write a Python function to find the maximum value in a list.",
    "Write a Python function to sort a list of numbers in ascending order.",
    "Write a Python function to reverse a string."
  ],
  "Art": [
    "Write a poem about the beauty of nature.",
    "Write a short story about a time traveler.",
    "Write a haiku about the changing seasons.",
    "Write a song about friendship."
  ]
}
--------------------------------------------------------------------------------
/TEST.md:
--------------------------------------------------------------------------------
## List of questions used to test the different models

### Facts

- What is the capital of France?
- Who wrote "To Kill a Mockingbird"?
- What is the largest mammal in the world?
- What is the boiling point of water?
- Who painted the Mona Lisa?
- What is the speed of light?
- What is the chemical formula for table salt?
- Who discovered penicillin?
- What is the currency of Japan?
- What is the tallest mountain in the world?
- Who is the author of "1984"?
- What is the distance from the Earth to the Moon?

### Thinking

- If a train leaves the station at 3 PM and travels at 60 mph, what time will it arrive at a station 180 miles away?
- If a car travels 300 miles in 5 hours, what is its average speed?
- If a recipe calls for 2 cups of flour and you want to make half the recipe, how much flour do you need?
- If a rectangle has a length of 10 cm and a width of 5 cm, what is its area?
- If a triangle has a base of 6 cm and a height of 4 cm, what is its area?

### Code

- Write a Python function to calculate the factorial of a number.
- Write a Python function to check if a string is a palindrome.
- Write a Python function to find the maximum value in a list.
- Write a Python function to sort a list of numbers in ascending order.
- Write a Python function to reverse a string.

### Art

- Write a poem about the beauty of nature.
- Write a short story about a time traveler.
- Write a haiku about the changing seasons.
- Write a song about friendship.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Testing State-of-the-Art Custom Models Based on the Google Titans Architecture

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/)
[![Hugging Face](https://img.shields.io/badge/🤗%20Transformers-Enabled-yellow.svg)](https://huggingface.co/transformers/)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
[![Experimental](https://img.shields.io/badge/Status-Experimental-orange.svg)]()

## Overview

This repository contains experimental code and evaluations for testing various open-source models that claim to implement or be inspired by Google's Titans architecture. Titans is a sequence-modeling architecture proposed by Google Research that augments attention with a neural long-term memory module, with the goal of handling very long contexts efficiently.

## Purpose

The goal of this project is to:

1. Test the performance of different open-source models that claim to be based on or inspired by the Google Titans architecture.
2. Evaluate their capabilities in various tasks such as text generation, summarization, and question answering.
3. Collect metrics to compare and analyze model performance across different scenarios.
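
Concretely, for every question, `evaluate.py` stores the generated response together with the three metrics named in `model_config.json`. Inside a per-model results file, each category under the top-level `"results"` key maps questions to entries shaped like the following (the numeric values here are purely illustrative):

```json
{
  "What is the capital of France?": {
    "response": "generated text ...",
    "metrics": {
      "response_time": 1.42,
      "token_count": 58,
      "response_length": 213
    }
  }
}
```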

## Features

- **Comprehensive Testing Framework**: Evaluate models across different categories of questions
- **Performance Metrics**: Measure response time, token count, and other quantitative metrics
- **Comparison Tools**: Compare results between different models with visual representations
- **Configurable Testing**: Easily add new models and configure parameters

## Getting Started

### Prerequisites

```bash
pip install -r requirements.txt
```

### Running Tests

To evaluate models:

```bash
python evaluate.py
```

To run the evaluation and then generate the comparison plots in one step, use `./run_all_tests.sh`, which calls `evaluate.py` followed by `visualization.py`.

Results will be stored in a timestamped subdirectory of the `results/` directory; `visualization.py` writes its plots to a `visualizations/` folder inside it.

## Configuration

### Model Configuration

Models can be added and configured in the `model_config.json` file:

```json
{
  "models": [
    {
      "id": "model-id",
      "name": "Model Name",
      "path": "huggingface/model-path",
      "description": "Description of the model",
      "parameters": {
        "max_length": 150,
        "temperature": 0.7,
        "top_p": 0.9
      }
    }
  ]
}
```

### Test Questions

Questions are organized by category in `questions.json`. You can modify or extend these questions to test different aspects of the models.

## Models Tested

- [Titan Transformer](https://huggingface.co/rajveer43/titan-transformer) - Base implementation of the Titan architecture

## Test Categories

- **Facts**: Testing knowledge retrieval and factual accuracy
- **Thinking**: Testing reasoning and problem-solving capabilities
- **Code**: Testing code generation abilities
- **Art**: Testing creative writing capabilities
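
## Working with the Results

`visualization.py` builds its plots from `model_comparison.json`, which `evaluate.py` only writes when more than one model is configured. The per-model `<model-id>_results.json` files can always be inspected directly; the snippet below is a minimal sketch (the timestamped path is illustrative, and the model id is the one from `model_config.json`):

```python
import json

# Illustrative path: evaluate.py names each run directory with a timestamp.
with open("results/20250101_120000/titan-transformer_results.json", encoding="utf-8") as f:
    data = json.load(f)

# "results" maps each category to {question: {"response": ..., "metrics": {...}}}.
# Entries that failed during generation contain an "error" key instead of "metrics".
for category, answers in data["results"].items():
    times = [a["metrics"]["response_time"] for a in answers.values() if "metrics" in a]
    if times:
        print(f"{category}: {sum(times) / len(times):.2f}s average over {len(times)} questions")
```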
21 | """ 22 | if results_dir is None: 23 | # Find the most recent results directory 24 | results_base = 'results' 25 | if not os.path.exists(results_base): 26 | raise ValueError(f"Results directory '{results_base}' not found") 27 | 28 | subdirs = [os.path.join(results_base, d) for d in os.listdir(results_base) 29 | if os.path.isdir(os.path.join(results_base, d))] 30 | 31 | if not subdirs: 32 | raise ValueError("No results directories found") 33 | 34 | self.results_dir = max(subdirs, key=os.path.getmtime) 35 | else: 36 | self.results_dir = results_dir 37 | 38 | print(f"Visualizing results from: {self.results_dir}") 39 | 40 | # Load comparison data 41 | self.comparison_file = os.path.join(self.results_dir, "model_comparison.json") 42 | if os.path.exists(self.comparison_file): 43 | with open(self.comparison_file, 'r') as f: 44 | self.comparison_data = json.load(f) 45 | else: 46 | raise FileNotFoundError(f"Comparison file not found: {self.comparison_file}") 47 | 48 | def create_visualizations(self, output_dir: Optional[str] = None) -> None: 49 | """ 50 | Generate visualizations from the evaluation results. 51 | 52 | Args: 53 | output_dir: Directory to save visualizations. Defaults to results_dir/visualizations 54 | """ 55 | if output_dir is None: 56 | output_dir = os.path.join(self.results_dir, 'visualizations') 57 | 58 | os.makedirs(output_dir, exist_ok=True) 59 | 60 | # Set the style 61 | sns.set_theme(style="whitegrid") 62 | plt.rcParams.update({'font.size': 12}) 63 | 64 | # Generate overall metrics comparison 65 | self._plot_overall_metrics(output_dir) 66 | 67 | # Generate category performance comparison 68 | self._plot_category_performance(output_dir) 69 | 70 | print(f"Visualizations saved to: {output_dir}") 71 | 72 | def _plot_overall_metrics(self, output_dir: str) -> None: 73 | """Plot overall metrics comparison.""" 74 | # Extract metrics summary 75 | metrics_summary = self.comparison_data.get("metrics_summary", {}) 76 | if not metrics_summary: 77 | print("No metrics summary data found") 78 | return 79 | 80 | # Convert to dataframe 81 | metrics_df = pd.DataFrame.from_dict(metrics_summary, orient='index') 82 | 83 | # Plot average time per question 84 | plt.figure(figsize=(12, 6)) 85 | ax = sns.barplot(x=metrics_df.index, y='average_time', data=metrics_df) 86 | plt.title('Average Response Time per Question') 87 | plt.ylabel('Time (seconds)') 88 | plt.xlabel('Model') 89 | plt.xticks(rotation=45) 90 | plt.tight_layout() 91 | 92 | # Add value labels 93 | for i, v in enumerate(metrics_df['average_time']): 94 | ax.text(i, v + 0.1, f"{v:.2f}s", ha='center') 95 | 96 | plt.savefig(os.path.join(output_dir, 'average_response_time.png'), dpi=300) 97 | plt.close() 98 | 99 | # Plot total tokens 100 | plt.figure(figsize=(12, 6)) 101 | ax = sns.barplot(x=metrics_df.index, y='total_tokens', data=metrics_df) 102 | plt.title('Total Tokens Generated') 103 | plt.ylabel('Number of Tokens') 104 | plt.xlabel('Model') 105 | plt.xticks(rotation=45) 106 | plt.tight_layout() 107 | 108 | # Add value labels 109 | for i, v in enumerate(metrics_df['total_tokens']): 110 | ax.text(i, v + 0.1, f"{int(v)}", ha='center') 111 | 112 | plt.savefig(os.path.join(output_dir, 'total_tokens.png'), dpi=300) 113 | plt.close() 114 | 115 | def _plot_category_performance(self, output_dir: str) -> None: 116 | """Plot category performance comparison.""" 117 | # Extract category performance 118 | category_performance = self.comparison_data.get("category_performance", {}) 119 | if not category_performance: 120 | print("No category 
performance data found") 121 | return 122 | 123 | # Create a dataframe for plotting 124 | plot_data = [] 125 | for category, models_data in category_performance.items(): 126 | for model, metrics in models_data.items(): 127 | plot_data.append({ 128 | 'Category': category, 129 | 'Model': model, 130 | 'Average Time': metrics.get('average_time', 0) 131 | }) 132 | 133 | performance_df = pd.DataFrame(plot_data) 134 | 135 | # Plot category performance 136 | plt.figure(figsize=(14, 8)) 137 | ax = sns.barplot(x='Category', y='Average Time', hue='Model', data=performance_df) 138 | plt.title('Average Response Time by Category') 139 | plt.ylabel('Time (seconds)') 140 | plt.xlabel('Category') 141 | plt.legend(title='Model') 142 | plt.tight_layout() 143 | 144 | plt.savefig(os.path.join(output_dir, 'category_performance.png'), dpi=300) 145 | plt.close() 146 | 147 | # Create individual category plots 148 | for category in category_performance.keys(): 149 | category_data = performance_df[performance_df['Category'] == category] 150 | 151 | plt.figure(figsize=(10, 6)) 152 | ax = sns.barplot(x='Model', y='Average Time', data=category_data) 153 | plt.title(f'{category} - Average Response Time') 154 | plt.ylabel('Time (seconds)') 155 | plt.xlabel('Model') 156 | plt.xticks(rotation=45) 157 | plt.tight_layout() 158 | 159 | # Add value labels 160 | for i, v in enumerate(category_data['Average Time']): 161 | ax.text(i, v + 0.1, f"{v:.2f}s", ha='center') 162 | 163 | plt.savefig(os.path.join(output_dir, f'{category.lower()}_performance.png'), dpi=300) 164 | plt.close() 165 | 166 | def main(): 167 | """Main function to generate visualizations.""" 168 | try: 169 | visualizer = ResultVisualizer() 170 | visualizer.create_visualizations() 171 | print("✅ Visualization generation completed successfully!") 172 | except Exception as e: 173 | print(f"❌ Error generating visualizations: {e}") 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import os 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | from typing import Dict, List, Any, Optional, Tuple 8 | from datetime import datetime 9 | from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM 10 | 11 | class ModelEvaluator: 12 | """ 13 | Evaluates language models on various tasks and collects performance metrics. 14 | """ 15 | 16 | def __init__(self, model_config_path: str = 'model_config.json', questions_path: str = 'questions.json'): 17 | """ 18 | Initialize the evaluator with configuration files. 

        Args:
            model_config_path: Path to the model configuration file
            questions_path: Path to the questions file
        """
        self.load_config(model_config_path)
        self.load_questions(questions_path)
        self.results_dir = os.path.join('results', datetime.now().strftime('%Y%m%d_%H%M%S'))
        os.makedirs(self.results_dir, exist_ok=True)

    def load_config(self, config_path: str) -> None:
        """Load model configurations from JSON file."""
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)
            print(f"✓ Loaded {len(self.config['models'])} model configurations")
        except Exception as e:
            print(f"Error loading model config: {e}")
            exit(1)

    def load_questions(self, questions_path: str) -> None:
        """Load evaluation questions from JSON file."""
        try:
            with open(questions_path, 'r', encoding='utf-8') as f:
                self.questions = json.load(f)
            total_questions = sum(len(qs) for qs in self.questions.values())
            print(f"✓ Loaded {total_questions} questions across {len(self.questions)} categories")
        except Exception as e:
            print(f"Error loading questions: {e}")
            exit(1)

    def load_model(self, model_config: Dict[str, Any]) -> Tuple[Any, Any]:
        """
        Load a model and tokenizer based on configuration.

        Args:
            model_config: Dictionary containing model configuration

        Returns:
            Tuple of (model, tokenizer)
        """
        print(f"\n📚 Loading model: {model_config['name']} ({model_config['path']})")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_config['path'])
            model = AutoModelForCausalLM.from_pretrained(model_config['path'])

            # Report the device; the text-generation pipeline handles model placement
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f" Using device: {device}")

            return model, tokenizer
        except Exception as e:
            print(f"❌ Error loading model {model_config['name']}: {e}")
            return None, None

    def run_evaluation(self) -> Dict[str, Any]:
        """
        Run evaluation for all configured models.

        Returns:
            Dictionary containing evaluation results
        """
        all_results = {}

        for model_config in self.config['models']:
            model_id = model_config['id']
            print(f"\n🔍 Evaluating model: {model_config['name']}")

            model, tokenizer = self.load_model(model_config)
            if model is None or tokenizer is None:
                all_results[model_id] = {"error": "Failed to load model"}
                continue

            # Create text generation pipeline
            text_generator = pipeline(
                'text-generation',
                model=model,
                tokenizer=tokenizer,
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )

            # Get parameters for this model
            params = model_config.get('parameters', {})

            # Process questions and collect metrics
            model_results = self.process_questions(text_generator, tokenizer, params)

            # Save individual results if configured
            if self.config['evaluation'].get('save_individual_results', True):
                self.save_results(model_results, f"{model_id}_results.json")

            all_results[model_id] = model_results

        # Generate comparison if multiple models were evaluated
        if len(all_results) > 1 and self.config['evaluation'].get('compare_models', True):
            self.generate_comparison(all_results)

        return all_results

    def process_questions(self,
                          text_generator: Any,
                          tokenizer: Any,
                          params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process all questions and collect metrics.

        Args:
            text_generator: The text generation pipeline
            tokenizer: The model tokenizer
            params: Generation parameters

        Returns:
            Dictionary with results and metrics
        """
        results = {}
        metrics = {
            "total_time": 0,
            "total_tokens": 0,
            "average_time_per_question": 0,
            "category_metrics": {}
        }

        total_questions = sum(len(qs) for qs in self.questions.values())
        processed_count = 0

        # Process each category
        for category, category_questions in self.questions.items():
            results[category] = {}
            metrics["category_metrics"][category] = {
                "total_time": 0,
                "total_tokens": 0,
                "average_time": 0
            }

            print(f" Processing {len(category_questions)} {category} questions...")

            # Process each question
            for question in category_questions:
                processed_count += 1
                print(f" [{processed_count}/{total_questions}] {question[:50]}{'...' if len(question) > 50 else ''}")

                try:
                    # Measure response time
                    start_time = time.time()
                    output = text_generator(
                        question,
                        max_length=params.get('max_length', 100),
                        temperature=params.get('temperature', 0.7),
                        top_p=params.get('top_p', 0.9),
                        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
                        num_return_sequences=1,
                        pad_token_id=tokenizer.eos_token_id
                    )
                    end_time = time.time()
                    response_time = end_time - start_time

                    # Get the generated text
                    generated_text = output[0]['generated_text']

                    # Count tokens (the generated text includes the prompt)
                    tokens = tokenizer(generated_text, return_tensors="pt")
                    token_count = len(tokens.input_ids[0])

                    # Store results and metrics
                    results[category][question] = {
                        "response": generated_text,
                        "metrics": {
                            "response_time": response_time,
                            "token_count": token_count,
                            "response_length": len(generated_text)
                        }
                    }

                    # Update metrics
                    metrics["total_time"] += response_time
                    metrics["total_tokens"] += token_count
                    metrics["category_metrics"][category]["total_time"] += response_time
                    metrics["category_metrics"][category]["total_tokens"] += token_count

                    print(f" ✓ Response time: {response_time:.2f}s, Tokens: {token_count}")

                except Exception as e:
                    print(f" ❌ Error: {e}")
                    results[category][question] = {"error": str(e)}

            # Calculate category averages
            cat_count = len(category_questions)
            if cat_count > 0:
                metrics["category_metrics"][category]["average_time"] = (
                    metrics["category_metrics"][category]["total_time"] / cat_count
                )

        # Calculate overall average
        if total_questions > 0:
            metrics["average_time_per_question"] = metrics["total_time"] / total_questions

        return {
            "results": results,
            "metrics": metrics
        }

    def save_results(self, results: Dict[str, Any], filename: str) -> None:
        """Save results to a JSON file."""
        file_path = os.path.join(self.results_dir, filename)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"✓ Results saved to {file_path}")

    def generate_comparison(self, all_results: Dict[str, Dict[str, Any]]) -> None:
        """
        Generate comparison report between models.

        Args:
            all_results: Dictionary of results keyed by model_id
        """
        print("\n📊 Generating model comparison report...")

        # Extract metrics for comparison
        comparison = {
            "metrics_summary": {},
            "category_performance": {}
        }

        for model_id, model_data in all_results.items():
            if "error" in model_data:
                continue

            metrics = model_data.get("metrics", {})
            comparison["metrics_summary"][model_id] = {
                "total_time": metrics.get("total_time", 0),
                "average_time": metrics.get("average_time_per_question", 0),
                "total_tokens": metrics.get("total_tokens", 0)
            }

            # Category performance
            for category, cat_metrics in metrics.get("category_metrics", {}).items():
                if category not in comparison["category_performance"]:
                    comparison["category_performance"][category] = {}

                comparison["category_performance"][category][model_id] = {
                    "average_time": cat_metrics.get("average_time", 0)
                }

        # Save comparison to file
        self.save_results(comparison, "model_comparison.json")

        # Generate tables for easier viewing
        self._create_comparison_tables(comparison)

    def _create_comparison_tables(self, comparison: Dict[str, Dict]) -> None:
        """Create comparison tables and save as CSV files."""
        # Overall metrics table
        metrics_df = pd.DataFrame(comparison["metrics_summary"]).T
        metrics_df = metrics_df.sort_values("average_time")
        metrics_df.to_csv(os.path.join(self.results_dir, "metrics_comparison.csv"))

        # Category performance tables
        for category, models_data in comparison["category_performance"].items():
            cat_df = pd.DataFrame(models_data).T
            cat_df.to_csv(os.path.join(self.results_dir, f"{category}_comparison.csv"))

def main():
    """Main function to run the evaluation."""
    print("🚀 Starting model evaluation process...")
    evaluator = ModelEvaluator()
    results = evaluator.run_evaluation()
    print("\n✅ Evaluation completed successfully!")
    print(f" Results saved in the {evaluator.results_dir} directory")

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------