├── LICENSE
├── README.md
├── output
│   ├── bottleneck_visualization.png
│   ├── error_distribution.png
│   └── performance_comparison.png
├── requirements.txt
├── run_demo.py
├── src
│   ├── bottleneck_identifier.py
│   ├── config_manager.py
│   ├── evaluator.py
│   ├── hook_manager.py
│   ├── logger.py
│   ├── model_parser.py
│   ├── model_transformer.py
│   ├── quantization_tester.py
│   └── visualizer.py
└── unit_test
    ├── net.py
    ├── test_bottleneck_identifier.py
    ├── test_config_manager.py
    ├── test_evaluator.py
    ├── test_hook_manager.py
    ├── test_logger.py
    ├── test_model_parser.py
    ├── test_model_transformer.py
    └── test_quantization_tester.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 TorchQuant Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TorchQuant

TorchQuant is a PyTorch-based toolkit for model quantization that helps reduce model size and improve inference performance while maintaining accuracy. It provides tools for identifying and handling bottleneck modules in neural networks during the quantization process.

## Features

- **Intelligent Model Quantization**: Convert PyTorch models from float32 to float16 precision while preserving critical layers
- **Bottleneck Identification**: Automatically identify bottleneck modules where quantization would cause a significant accuracy drop
- **Quantization Testing**: Evaluate the impact of quantization on individual modules
- **Error Analysis**: Measure and analyze the error introduced by quantization
- **Visualization Tools**: Visualize error distributions, bottleneck modules, and performance comparisons
- **Configurable Pipeline**: Customize the quantization process via configuration files, environment variables, or command-line arguments

## Installation

### Prerequisites

- Python 3.6+
- PyTorch 1.7+
- NumPy
- NetworkX
- Matplotlib
- Seaborn

### Installation Steps

```bash
# Clone the repository
git clone https://github.com/yourusername/torch_quant.git
cd torch_quant

# Install dependencies
pip install -r requirements.txt
```

## Quick Start

```python
import torch
from torch_quant.src.hook_manager import HookManager
from torch_quant.src.model_parser import ModelParser
from torch_quant.src.quantization_tester import QuantizationTester
from torch_quant.src.bottleneck_identifier import BottleneckIdentifier
from torch_quant.src.model_transformer import ModelTransformer
from torch_quant.src.visualizer import Visualizer

# Load your model and prepare a representative input batch
model = YourModel()
sample_input = torch.rand(8, 3, 32, 32)  # example input shape; adjust to your model

# Capture each module's inputs/outputs with forward hooks
hook_manager = HookManager()
hook_manager.register_hooks(model)
_ = model(sample_input)
module_io = hook_manager.get_module_io()
hook_manager.remove_hooks()

# Measure the error introduced by converting each module to float16
parser = ModelParser(model)
module_names = parser.get_module_names()
tester = QuantizationTester(error_metric='mse')
errors = tester.test_modules(parser.get_all_modules(), module_io, module_names)

# Identify bottleneck modules (e.g., the three most sensitive ones)
identifier = BottleneckIdentifier()
bottleneck_modules = identifier.identify_bottlenecks(error_data=errors, top_n=3)

# Quantize the model, keeping bottleneck modules in float32
transformer = ModelTransformer()
quantized_model = transformer.quantize_model(model, bottleneck_modules, module_names)

# Visualize the results
visualizer = Visualizer(output_dir='./output')
visualizer.plot_error_distribution(errors)
visualizer.visualize_bottlenecks(model, bottleneck_modules, module_names)

# Run inference with the quantized model
output = quantized_model(sample_input)
```

For a complete, end-to-end example, see the `run_demo.py` file.
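
You can also compare the original and quantized models on an evaluation set with the `Evaluator` class. The snippet below is a minimal sketch; it assumes `model` and `quantized_model` from the Quick Start above and a standard PyTorch `DataLoader` named `dataloader` over your evaluation data:

```python
import torch
import torch.nn as nn
from torch_quant.src.evaluator import Evaluator

evaluator = Evaluator(device=torch.device('cpu'))
comparison = evaluator.compare_models(
    model, quantized_model, dataloader,
    criterion=nn.CrossEntropyLoss()
)
print(comparison['model1'])  # metrics for the original model
print(comparison['model2'])  # metrics for the quantized model
```

Each result dictionary contains `loss`, `accuracy`, `avg_inference_time`, `throughput`, and `total_samples`.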

## API Overview

### ModelTransformer

The core class that handles model quantization:

```python
from torch_quant.src.model_transformer import ModelTransformer

transformer = ModelTransformer()
quantized_model = transformer.quantize_model(model, bottleneck_modules, module_names)
```

### BottleneckIdentifier

Identifies the modules that are most sensitive to quantization from per-module error data, using an error threshold, a top-N count, or an error ratio:

```python
from torch_quant.src.bottleneck_identifier import BottleneckIdentifier

identifier = BottleneckIdentifier()
bottleneck_modules = identifier.identify_bottlenecks(error_data=errors, top_n=3)
```

### QuantizationTester

Measures the error that float16 conversion introduces in each module, using the inputs and outputs captured by `HookManager`:

```python
from torch_quant.src.quantization_tester import QuantizationTester

tester = QuantizationTester(error_metric='mse')
errors = tester.test_modules(modules, module_io, module_names)
```

### Visualizer

Provides visualization tools for analysis:

```python
from torch_quant.src.visualizer import Visualizer

visualizer = Visualizer(output_dir='./output')
visualizer.plot_error_distribution(errors)
visualizer.visualize_bottlenecks(model, bottleneck_modules, module_names)
visualizer.plot_performance_comparison(original_results, quantized_results)
```

## Configuration

TorchQuant can be configured in code, through a YAML configuration file, via environment variables, or with command-line arguments. See `src/config_manager.py` for the available options and their defaults.

## Testing

Unit tests are available in the `unit_test/` directory. Run them with pytest:

```bash
cd torch_quant
pytest
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. 
Open a Pull Request 164 | 165 | ## Acknowledgments 166 | 167 | - PyTorch team for providing the foundational deep learning framework 168 | - The open-source community for various libraries used in this project -------------------------------------------------------------------------------- /output/bottleneck_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pentilm/torch_quant/f16753d7b476b9587c5b9c7b46a6f2a8a033919a/output/bottleneck_visualization.png -------------------------------------------------------------------------------- /output/error_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pentilm/torch_quant/f16753d7b476b9587c5b9c7b46a6f2a8a033919a/output/error_distribution.png -------------------------------------------------------------------------------- /output/performance_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pentilm/torch_quant/f16753d7b476b9587c5b9c7b46a6f2a8a033919a/output/performance_comparison.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | numpy>=1.19.0 3 | matplotlib>=3.3.0 4 | seaborn>=0.11.0 5 | networkx>=2.5 6 | pytest>=6.0.0 -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader, TensorDataset 4 | import numpy as np 5 | import os 6 | 7 | # Import the modules from your codebase 8 | # (Assuming they are all in the same folder or installed in your environment) 9 | from unit_test.net import CNNWithAttention 10 | from src.hook_manager import HookManager 11 | from src.quantization_tester import QuantizationTester 12 | from src.model_parser import ModelParser 13 | from src.bottleneck_identifier import BottleneckIdentifier 14 | from src.model_transformer import ModelTransformer 15 | from src.evaluator import Evaluator 16 | from src.visualizer import Visualizer # Import the Visualizer 17 | 18 | def main(): 19 | # Create output directory for visualizations 20 | output_dir = './output' 21 | os.makedirs(output_dir, exist_ok=True) 22 | 23 | # 1. Create a small dummy dataset 24 | # Here we make random tensors mimicking an image dataset (like CIFAR10: shape [batch, 3, 32, 32]) 25 | # with random integer labels in 0..9. 26 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 27 | 28 | num_samples = 128 29 | num_classes = 10 30 | images = torch.rand(num_samples, 3, 32, 32).to(device) 31 | labels = torch.randint(0, num_classes, (num_samples,)).to(device) 32 | 33 | dataset = TensorDataset(images, labels) 34 | dataloader = DataLoader(dataset, batch_size=16, shuffle=False) 35 | 36 | # 2. Instantiate the model (from net.py) 37 | model = CNNWithAttention(num_classes=num_classes).to(device) 38 | 39 | # 3. Register hooks to capture module inputs/outputs 40 | hook_manager = HookManager() 41 | hook_manager.register_hooks(model) 42 | 43 | # 4. Feed a few batches through the model to record module I/O 44 | # We'll just run the entire dataset for simplicity. 
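    # Note: HookManager keeps a single {'input', 'output'} entry per module and overwrites
    # it on every forward pass, so after this loop module_io reflects only the final batch.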
45 | for batch_data, _ in dataloader: 46 | _ = model(batch_data) 47 | 48 | # Retrieve the captured I/O from the hook manager 49 | module_io = hook_manager.get_module_io() 50 | 51 | # 5. Use QuantizationTester to compare float32 vs float16 outputs at each module 52 | # We first need a ModelParser to map modules to names 53 | parser = ModelParser(model) 54 | module_names = parser.get_module_names() 55 | 56 | # Generate a dictionary {module_name: module} for convenience 57 | all_modules = parser.get_all_modules() 58 | 59 | # Build a dictionary that aligns with the quantization_tester's expected input: 60 | # { module_name : module_instance } 61 | modules_dict = {name: mod for name, mod in all_modules.items()} 62 | 63 | # Display model structure 64 | parser.print_model_structure() 65 | 66 | tester = QuantizationTester(error_metric='mse') 67 | errors = tester.test_modules(modules_dict, module_io, module_names) 68 | 69 | # Initialize the visualizer 70 | visualizer = Visualizer(output_dir=output_dir) 71 | 72 | # Visualize error distribution across modules 73 | print("\nGenerating error distribution visualization...") 74 | visualizer.plot_error_distribution(errors) 75 | 76 | # 6. Analyze dependencies between modules to identify bottleneck layers more accurately 77 | # Analyze module dependencies based on model structure 78 | module_dependencies = {} 79 | for name, module in model.named_modules(): 80 | if name == '': # Skip root module 81 | continue 82 | 83 | # Get all child modules of this module 84 | children = [] 85 | for child_name, _ in module.named_children(): 86 | full_child_name = f"{name}.{child_name}" if name else child_name 87 | children.append(full_child_name) 88 | 89 | # Create dependency relationship: for each child module, the current module depends on it 90 | for child in children: 91 | if child not in module_dependencies: 92 | module_dependencies[child] = [] 93 | 94 | # Current module depends on this child module 95 | module_dependencies[child].append(name) 96 | 97 | # 6. Identify bottleneck layers using BottleneckIdentifier with dependency information 98 | # Example: consider everything above an error threshold as a bottleneck 99 | bottleneck_finder = BottleneckIdentifier() 100 | # Use the new API that provides module dependency information 101 | bottleneck_modules = bottleneck_finder.identify_bottlenecks( 102 | error_data=errors, 103 | top_n=3, 104 | module_dependencies=module_dependencies 105 | ) 106 | 107 | # Visualize bottleneck modules in the model structure 108 | print("\nGenerating bottleneck visualization...") 109 | visualizer.visualize_bottlenecks(model, bottleneck_modules, module_names) 110 | 111 | # 7. Transform (quantize) the model using ModelTransformer, preserving bottleneck modules in float32 112 | transformer = ModelTransformer() 113 | quantized_model = transformer.quantize_model( 114 | model, 115 | bottleneck_modules=bottleneck_modules, 116 | module_names=module_names 117 | ) 118 | 119 | # 8. 
(Optional) Evaluate both models 120 | evaluator = Evaluator(device=device) 121 | 122 | # Simple cross-entropy loss for demonstration 123 | criterion = nn.CrossEntropyLoss() 124 | 125 | # Evaluate original model 126 | results_original = evaluator.evaluate_model( 127 | model=model, 128 | dataloader=dataloader, 129 | criterion=criterion 130 | ) 131 | 132 | # Evaluate quantized model 133 | results_quantized = evaluator.evaluate_model( 134 | model=quantized_model, 135 | dataloader=dataloader, 136 | criterion=criterion 137 | ) 138 | 139 | print("\n=== Original Model Results ===") 140 | for k, v in results_original.items(): 141 | print(f"{k}: {v}") 142 | 143 | print("\n=== Quantized Model Results ===") 144 | for k, v in results_quantized.items(): 145 | print(f"{k}: {v}") 146 | 147 | # 9. Display performance comparison before and after quantization 148 | print("\n=== Performance Comparison ===") 149 | speedup = results_original['avg_inference_time'] / results_quantized['avg_inference_time'] 150 | accuracy_loss = results_original['accuracy'] - results_quantized['accuracy'] 151 | 152 | print(f"Speedup: {speedup:.2f}x") 153 | print(f"Accuracy Loss: {accuracy_loss:.4f} ({accuracy_loss*100:.2f}%)") 154 | print(f"Memory Savings: Approximately 50% for parameters (FP32 -> FP16)") 155 | 156 | # Visualize performance comparison 157 | print("\nGenerating performance comparison visualization...") 158 | visualizer.plot_performance_comparison(results_original, results_quantized) 159 | 160 | print(f"\nAll visualizations have been saved to {output_dir}") 161 | 162 | hook_manager.remove_hooks() # Clean up hooks at the end 163 | 164 | if __name__ == "__main__": 165 | main() -------------------------------------------------------------------------------- /src/bottleneck_identifier.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | class BottleneckIdentifier: 4 | def __init__(self): 5 | """ 6 | Initialize the BottleneckIdentifier. 7 | """ 8 | self.bottleneck_modules = [] # List of bottleneck module names 9 | 10 | def identify_bottlenecks( 11 | self, 12 | error_data: Dict[str, float], 13 | threshold: float = None, 14 | top_n: int = None, 15 | error_ratio: float = None, 16 | module_dependencies: Dict[str, List[str]] = None 17 | ) -> List[str]: 18 | """ 19 | Identify bottleneck modules based on error data, considering module dependencies. 20 | 21 | Args: 22 | error_data (Dict[str, float]): Dictionary of module names and their errors. 23 | threshold (float, optional): Error threshold to consider a module as bottleneck. 24 | top_n (int, optional): Number of top modules to consider as bottlenecks. 25 | error_ratio (float, optional): Percentage of maximum error to consider. 26 | module_dependencies (Dict[str, List[str]], optional): Dictionary mapping module names 27 | to their dependent modules. Used to identify error propagation. 28 | 29 | Returns: 30 | List[str]: List of bottleneck module names. 31 | """ 32 | if not error_data: 33 | raise ValueError("Error data is empty. 
Cannot identify bottlenecks.") 34 | 35 | # Sort modules by error in descending order 36 | sorted_errors = sorted(error_data.items(), key=lambda item: item[1], reverse=True) 37 | 38 | # Initial candidate bottleneck modules set 39 | candidate_bottlenecks = [] 40 | 41 | # Strategy 1: Use error threshold 42 | if threshold is not None: 43 | candidate_bottlenecks = [ 44 | module_name for module_name, error in sorted_errors if error > threshold 45 | ] 46 | # Strategy 2: Select top N modules with highest error 47 | elif top_n is not None: 48 | candidate_bottlenecks = [module_name for module_name, _ in sorted_errors[:top_n]] 49 | # Strategy 3: Use error ratio 50 | elif error_ratio is not None: 51 | max_error = sorted_errors[0][1] 52 | candidate_bottlenecks = [ 53 | module_name for module_name, error in sorted_errors if error > max_error * error_ratio 54 | ] 55 | else: 56 | raise ValueError("At least one identification criterion must be provided.") 57 | 58 | # If module dependencies are provided, consider error propagation 59 | if module_dependencies: 60 | # Calculate impact scores for each module, considering both the module's error and its impact on downstream modules 61 | impact_scores = self._calculate_impact_scores(error_data, module_dependencies) 62 | 63 | # Sort modules by impact score from high to low 64 | sorted_by_impact = sorted(impact_scores.items(), key=lambda item: item[1], reverse=True) 65 | 66 | # Include modules with top impact scores in consideration 67 | impact_bottlenecks = [module_name for module_name, _ in sorted_by_impact[:min(top_n, len(sorted_by_impact)) if top_n else 3]] 68 | 69 | # Merge bottleneck modules identified by both methods 70 | combined_bottlenecks = list(set(candidate_bottlenecks) | set(impact_bottlenecks)) 71 | 72 | # Output comparison of results 73 | print("Bottleneck modules identified based on error:", candidate_bottlenecks) 74 | print("Bottleneck modules identified based on impact score:", impact_bottlenecks) 75 | print("Combined bottleneck modules:", combined_bottlenecks) 76 | 77 | self.bottleneck_modules = combined_bottlenecks 78 | else: 79 | self.bottleneck_modules = candidate_bottlenecks 80 | 81 | print("Final identified bottleneck modules:") 82 | for module_name in self.bottleneck_modules: 83 | error = error_data[module_name] 84 | print(f"{module_name}: Error = {error}") 85 | 86 | return self.bottleneck_modules 87 | 88 | def _calculate_impact_scores( 89 | self, 90 | error_data: Dict[str, float], 91 | module_dependencies: Dict[str, List[str]] 92 | ) -> Dict[str, float]: 93 | """ 94 | Calculate impact scores for each module, considering its influence on downstream modules. 
95 | 96 | Args: 97 | error_data (Dict[str, float]): Mapping from module names to errors 98 | module_dependencies (Dict[str, List[str]]): Mapping from module names to lists of modules that depend on them 99 | 100 | Returns: 101 | Dict[str, float]: Mapping from module names to impact scores 102 | """ 103 | impact_scores = {} 104 | 105 | for module_name, error in error_data.items(): 106 | # Initial score is the module's own error 107 | impact = error 108 | 109 | # If other modules depend on this module, increase the impact score 110 | if module_name in module_dependencies: 111 | for dependent_module in module_dependencies[module_name]: 112 | if dependent_module in error_data: 113 | # Consider the influence on dependent modules: weighted sum of all dependent module errors 114 | impact += 0.5 * error_data[dependent_module] 115 | 116 | impact_scores[module_name] = impact 117 | 118 | return impact_scores 119 | 120 | def get_bottleneck_modules(self) -> List[str]: 121 | """ 122 | Get the list of identified bottleneck module names. 123 | 124 | Returns: 125 | List[str]: List of bottleneck module names. 126 | """ 127 | return self.bottleneck_modules -------------------------------------------------------------------------------- /src/config_manager.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import os 4 | import torch 5 | from typing import Any, Dict 6 | 7 | class ConfigManager: 8 | def __init__(self, config_file: str = None): 9 | """ 10 | Initialize the ConfigManager. 11 | 12 | Args: 13 | config_file (str, optional): Path to the configuration file. 14 | """ 15 | self.config = {} 16 | self.load_default_config() 17 | if config_file: 18 | self.load_config_from_file(config_file) 19 | self.load_config_from_env() 20 | self.load_config_from_args() 21 | 22 | def load_default_config(self): 23 | """ 24 | Load default configuration parameters. 25 | """ 26 | self.config = { 27 | 'error_threshold': 1e-4, 28 | 'log_file': None, 29 | 'log_level': 'INFO', 30 | 'output_dir': './output', 31 | 'device': 'cuda' if torch.cuda.is_available() else 'cpu', 32 | 'batch_size': 32, 33 | 'num_workers': 4, 34 | 'dataset': { 35 | 'name': 'CIFAR10', 36 | 'root': './data', 37 | 'download': True 38 | }, 39 | 'model': { 40 | 'name': 'resnet18', 41 | 'pretrained': True 42 | }, 43 | 'quantization': { 44 | 'error_metric': 'mse', 45 | 'strategy': 'threshold', # Options: 'threshold', 'top_n', 'error_ratio' 46 | 'top_n': None, 47 | 'error_ratio': None 48 | } 49 | } 50 | 51 | def load_config_from_file(self, config_file: str): 52 | """ 53 | Load configuration parameters from a YAML file. 54 | 55 | Args: 56 | config_file (str): Path to the configuration file. 57 | """ 58 | with open(config_file, 'r') as f: 59 | file_config = yaml.safe_load(f) 60 | self.merge_config(file_config) 61 | 62 | def load_config_from_env(self): 63 | """ 64 | Load configuration parameters from environment variables. 65 | """ 66 | # Example: Override device from environment variable 67 | device = os.environ.get('DEVICE') 68 | if device: 69 | self.config['device'] = device 70 | 71 | def load_config_from_args(self): 72 | """ 73 | Load configuration parameters from command-line arguments. 
74 | """ 75 | parser = argparse.ArgumentParser(description='Quantization Tool Configuration') 76 | parser.add_argument('--config_file', type=str, help='Path to the configuration file') 77 | parser.add_argument('--error_threshold', type=float, help='Error threshold for bottleneck identification') 78 | parser.add_argument('--log_file', type=str, help='Path to the log file') 79 | parser.add_argument('--log_level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], help='Logging level') 80 | parser.add_argument('--output_dir', type=str, help='Directory to save outputs and visualizations') 81 | parser.add_argument('--device', type=str, help='Device to use for computation (e.g., "cpu" or "cuda")') 82 | parser.add_argument('--batch_size', type=int, help='Batch size for data loaders') 83 | parser.add_argument('--num_workers', type=int, help='Number of workers for data loaders') 84 | 85 | # Parse known arguments; ignore unknown ones 86 | args, _ = parser.parse_known_args() 87 | 88 | # Update configurations based on arguments 89 | args_dict = vars(args) 90 | for key, value in args_dict.items(): 91 | if value is not None: 92 | self.set_config(key, value) 93 | 94 | def merge_config(self, new_config: Dict[str, Any]): 95 | """ 96 | Merge a new configuration dictionary into the existing configuration. 97 | 98 | Args: 99 | new_config (Dict[str, Any]): New configuration parameters. 100 | """ 101 | self.config = self._recursive_merge(self.config, new_config) 102 | 103 | def _recursive_merge(self, default: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: 104 | """ 105 | Recursively merge two dictionaries. 106 | 107 | Args: 108 | default (Dict[str, Any]): The default dictionary. 109 | override (Dict[str, Any]): The overriding dictionary. 110 | 111 | Returns: 112 | Dict[str, Any]: Merged dictionary. 113 | """ 114 | for key, value in override.items(): 115 | if key in default and isinstance(default[key], dict) and isinstance(value, dict): 116 | default[key] = self._recursive_merge(default[key], value) 117 | else: 118 | default[key] = value 119 | return default 120 | 121 | def set_config(self, key: str, value: Any): 122 | """ 123 | Set a configuration parameter. 124 | 125 | Args: 126 | key (str): The key of the configuration parameter. 127 | value (Any): The value to set. 128 | """ 129 | keys = key.split('.') 130 | config_section = self.config 131 | for k in keys[:-1]: 132 | if k in config_section: 133 | config_section = config_section[k] 134 | else: 135 | config_section[k] = {} 136 | config_section = config_section[k] 137 | config_section[keys[-1]] = value 138 | 139 | def get_config(self, key: str = None) -> Any: 140 | """ 141 | Get a configuration parameter. 142 | 143 | Args: 144 | key (str, optional): The key of the configuration parameter. 145 | 146 | Returns: 147 | Any: The value of the configuration parameter. 148 | """ 149 | if key is None: 150 | return self.config 151 | keys = key.split('.') 152 | config_section = self.config 153 | for k in keys: 154 | if k in config_section: 155 | config_section = config_section[k] 156 | else: 157 | return None 158 | return config_section 159 | 160 | def print_config(self): 161 | """ 162 | Print the current configuration. 
163 | """ 164 | print("Current Configuration:") 165 | print(yaml.dump(self.config, default_flow_style=False)) 166 | 167 | -------------------------------------------------------------------------------- /src/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from typing import Dict, Callable, Any 5 | import time 6 | 7 | class Evaluator: 8 | def __init__(self, device: torch.device = None): 9 | """ 10 | Initialize the Evaluator. 11 | 12 | Args: 13 | device (torch.device, optional): The device to run evaluations on. 14 | Defaults to CPU if not specified. 15 | """ 16 | self.device = device if device is not None else torch.device('cpu') 17 | 18 | def evaluate_model( 19 | self, 20 | model: nn.Module, 21 | dataloader: DataLoader, 22 | criterion: nn.Module = None, 23 | metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = None 24 | ) -> Dict[str, float]: 25 | """ 26 | Evaluate the model on the given dataset. 27 | 28 | Args: 29 | model (nn.Module): The model to evaluate. 30 | dataloader (DataLoader): DataLoader providing the evaluation data. 31 | criterion (nn.Module, optional): Loss function to compute loss. 32 | metrics (Dict[str, Callable], optional): Additional metrics to compute. 33 | 34 | Returns: 35 | Dict[str, float]: Dictionary containing evaluation results. 36 | """ 37 | model.eval() 38 | model.to(self.device) 39 | 40 | total_loss = 0.0 41 | total_correct = 0 42 | total_samples = 0 43 | inference_times = [] 44 | 45 | # Initialize metric accumulators 46 | metric_sums = {metric_name: 0.0 for metric_name in metrics} if metrics else {} 47 | 48 | with torch.no_grad(): 49 | for batch in dataloader: 50 | inputs, targets = batch 51 | inputs = inputs.to(self.device) 52 | targets = targets.to(self.device) 53 | 54 | # Measure inference time 55 | start_time = time.time() 56 | outputs = model(inputs) 57 | end_time = time.time() 58 | inference_time = end_time - start_time 59 | inference_times.append(inference_time) 60 | 61 | # Compute loss if criterion is provided 62 | if criterion is not None: 63 | loss = criterion(outputs, targets) 64 | total_loss += loss.item() * inputs.size(0) 65 | 66 | # Compute accuracy 67 | _, preds = torch.max(outputs, 1) 68 | total_correct += torch.sum(preds == targets).item() 69 | total_samples += inputs.size(0) 70 | 71 | # Compute additional metrics 72 | if metrics: 73 | for metric_name, metric_fn in metrics.items(): 74 | metric_value = metric_fn(outputs, targets) 75 | metric_sums[metric_name] += metric_value * inputs.size(0) 76 | 77 | # Calculate average metrics 78 | avg_loss = total_loss / total_samples if criterion is not None else None 79 | accuracy = total_correct / total_samples 80 | avg_inference_time = sum(inference_times) / len(inference_times) 81 | # Calculate throughput (samples per second) 82 | total_inference_time = sum(inference_times) 83 | throughput = total_samples / total_inference_time if total_inference_time > 0 else float('inf') 84 | 85 | # Prepare results 86 | results = { 87 | 'loss': avg_loss, 88 | 'accuracy': accuracy, 89 | 'avg_inference_time': avg_inference_time, 90 | 'throughput': throughput, 91 | 'total_samples': total_samples 92 | } 93 | 94 | # Add additional metrics 95 | if metrics: 96 | for metric_name, total_metric_value in metric_sums.items(): 97 | results[metric_name] = total_metric_value / total_samples 98 | 99 | return results 100 | 101 | def compare_models( 102 | self, 103 | model1: nn.Module, 
104 | model2: nn.Module, 105 | dataloader: DataLoader, 106 | criterion: nn.Module = None, 107 | metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = None 108 | ) -> Dict[str, Dict[str, float]]: 109 | """ 110 | Compare two models on the given dataset. 111 | 112 | Args: 113 | model1 (nn.Module): The first model to evaluate. 114 | model2 (nn.Module): The second model to evaluate. 115 | dataloader (DataLoader): DataLoader providing the evaluation data. 116 | criterion (nn.Module, optional): Loss function to compute loss. 117 | metrics (Dict[str, Callable], optional): Additional metrics to compute. 118 | 119 | Returns: 120 | Dict[str, Dict[str, float]]: Dictionary containing evaluation results for both models. 121 | """ 122 | print("Evaluating Model 1...") 123 | results1 = self.evaluate_model(model1, dataloader, criterion, metrics) 124 | print("Evaluating Model 2...") 125 | results2 = self.evaluate_model(model2, dataloader, criterion, metrics) 126 | 127 | # Compare results 128 | comparison = { 129 | 'model1': results1, 130 | 'model2': results2 131 | } 132 | 133 | return comparison 134 | 135 | -------------------------------------------------------------------------------- /src/hook_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, Tuple, List 4 | 5 | class HookManager: 6 | def __init__(self): 7 | """ 8 | Initialize the HookManager. 9 | """ 10 | self.hooks = [] 11 | self.module_io = {} # Stores input and output tensors for each module 12 | 13 | def register_hooks(self, model: nn.Module): 14 | """ 15 | Register forward hooks on all modules in the model. 16 | 17 | Args: 18 | model (nn.Module): The PyTorch model instance. 19 | """ 20 | # Define the hook function 21 | def hook_fn(module, input, output): 22 | # Store input and output tensors for the module 23 | self.module_io[module] = {'input': input, 'output': output} 24 | 25 | # Register the hook on all modules 26 | for name, module in model.named_modules(): 27 | hook = module.register_forward_hook(hook_fn) 28 | self.hooks.append(hook) 29 | 30 | def remove_hooks(self): 31 | """ 32 | Remove all registered hooks. 33 | """ 34 | for hook in self.hooks: 35 | hook.remove() 36 | self.hooks = [] 37 | 38 | def get_module_io(self) -> Dict[nn.Module, Dict[str, Tuple[torch.Tensor]]]: 39 | """ 40 | Get the stored input and output tensors for each module. 41 | 42 | Returns: 43 | Dict[nn.Module, Dict[str, Tuple[torch.Tensor]]]: A dictionary mapping modules to their input/output tensors. 44 | """ 45 | return self.module_io 46 | 47 | def __enter__(self): 48 | """ 49 | Enter the context manager. 50 | """ 51 | return self 52 | 53 | def __exit__(self, exc_type, exc_value, traceback): 54 | """ 55 | Exit the context manager, remove all hooks. 56 | """ 57 | self.remove_hooks() -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | class Logger: 5 | def __init__(self, log_file: str = None, log_level: int = logging.INFO): 6 | """ 7 | Initialize the Logger. 8 | 9 | Args: 10 | log_file (str, optional): The file path to save the log. If None, logs to console. 11 | log_level (int, optional): Logging level (e.g., logging.INFO, logging.DEBUG). 
12 | """ 13 | self.logger = logging.getLogger('QuantizationLogger') 14 | self.logger.setLevel(log_level) 15 | 16 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 17 | 18 | # Clear existing handlers 19 | if self.logger.hasHandlers(): 20 | self.logger.handlers.clear() 21 | 22 | if log_file: 23 | # Log to file 24 | file_handler = logging.FileHandler(log_file) 25 | file_handler.setFormatter(formatter) 26 | self.logger.addHandler(file_handler) 27 | else: 28 | # Log to console 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(formatter) 31 | self.logger.addHandler(console_handler) 32 | 33 | def log_info(self, message: str): 34 | """ 35 | Log an informational message. 36 | 37 | Args: 38 | message (str): The message to log. 39 | """ 40 | self.logger.info(message) 41 | 42 | def log_debug(self, message: str): 43 | """ 44 | Log a debug message. 45 | 46 | Args: 47 | message (str): The message to log. 48 | """ 49 | self.logger.debug(message) 50 | 51 | def log_error(self, message: str): 52 | """ 53 | Log an error message. 54 | 55 | Args: 56 | message (str): The message to log. 57 | """ 58 | self.logger.error(message) -------------------------------------------------------------------------------- /src/model_parser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, List, Tuple 4 | 5 | class ModelParser: 6 | def __init__(self, model: nn.Module): 7 | """ 8 | Initialize the ModelParser. 9 | 10 | Parameters: 11 | model (nn.Module): The PyTorch model instance to be parsed. 12 | """ 13 | self.model = model 14 | # Get all modules in the model, mapping names to module instances 15 | self.module_dict = dict(model.named_modules()) 16 | # Create a mapping from module instances to names for later use 17 | self.module_names = {module: name for name, module in self.module_dict.items()} 18 | # Get all leaf modules (modules without submodules) 19 | self.leaf_modules = self._get_leaf_modules() 20 | 21 | def get_all_modules(self) -> Dict[str, nn.Module]: 22 | """ 23 | Get a dictionary of all modules in the model. 24 | 25 | Returns: 26 | Dict[str, nn.Module]: A mapping dictionary from module names to module instances. 27 | """ 28 | return self.module_dict 29 | 30 | def get_leaf_modules(self) -> List[Tuple[str, nn.Module]]: 31 | """ 32 | Get all leaf modules in the model. 33 | 34 | Returns: 35 | List[Tuple[str, nn.Module]]: A list containing module names and module instances. 36 | """ 37 | return self.leaf_modules 38 | 39 | def _get_leaf_modules(self) -> List[Tuple[str, nn.Module]]: 40 | """ 41 | Internal method to get all leaf modules. 42 | 43 | Returns: 44 | List[Tuple[str, nn.Module]]: A list containing module names and module instances. 45 | """ 46 | leaf_modules = [] 47 | for name, module in self.module_dict.items(): 48 | # If the module has no submodules, it is considered a leaf module 49 | if len(list(module.children())) == 0: 50 | leaf_modules.append((name, module)) 51 | return leaf_modules 52 | 53 | def get_module_by_name(self, name: str) -> nn.Module: 54 | """ 55 | Get a module instance by its name. 56 | 57 | Parameters: 58 | name (str): The name of the module. 59 | 60 | Returns: 61 | nn.Module: The corresponding module instance, or None if it does not exist. 62 | """ 63 | return self.module_dict.get(name, None) 64 | 65 | def get_module_names(self) -> Dict[nn.Module, str]: 66 | """ 67 | Get a mapping from module instances to names. 
68 | 69 | Returns: 70 | Dict[nn.Module, str]: A dictionary mapping module instances to their names. 71 | """ 72 | return self.module_names 73 | 74 | def print_model_structure(self): 75 | """ 76 | Print the model's structure, showing each module's name and type. 77 | """ 78 | print("Model Structure:") 79 | for name, module in self.module_dict.items(): 80 | print(f"{name}: {module.__class__.__name__}") -------------------------------------------------------------------------------- /src/model_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, List, Set 4 | import copy 5 | 6 | class ModelTransformer: 7 | def __init__(self): 8 | """ 9 | Initialize the ModelTransformer. 10 | """ 11 | pass 12 | 13 | def quantize_model( 14 | self, 15 | model: nn.Module, 16 | bottleneck_modules: List[str], 17 | module_names: Dict[nn.Module, str] 18 | ) -> nn.Module: 19 | """ 20 | Quantize the model based on the identified bottleneck modules. 21 | 22 | Args: 23 | model (nn.Module): The original PyTorch model. 24 | bottleneck_modules (List[str]): List of bottleneck module names to keep in float32. 25 | module_names (Dict[nn.Module, str]): Mapping of module instances to names. 26 | 27 | Returns: 28 | nn.Module: The quantized model. 29 | """ 30 | # Create a copy of the model to avoid modifying the original 31 | model_copy = self._copy_model(model) 32 | 33 | # Save the original model structure information for later validation 34 | original_structure = self._get_model_structure(model) 35 | 36 | # Convert modules to float16, except for bottleneck modules 37 | for (module_name, module) in model_copy.named_modules(): 38 | if module_name in bottleneck_modules: 39 | # Keep bottleneck modules in float32 40 | self._wrap_bottleneck_module(module) 41 | else: 42 | # First check if it's a special layer, if so, handle it specially 43 | if self._is_special_layer(module): 44 | self._handle_special_layer(module) 45 | else: 46 | # Normal layers are converted to half precision 47 | self._convert_module_to_half(module) 48 | 49 | # Verify whether the model structure after quantization is consistent with the original 50 | quantized_structure = self._get_model_structure(model_copy) 51 | if not self._validate_model_structure(original_structure, quantized_structure): 52 | print("Warning: The structure of the quantized model is inconsistent with the original model!") 53 | else: 54 | print("Validation passed: Model structure remains consistent before and after quantization.") 55 | 56 | return model_copy 57 | 58 | def _copy_model(self, model: nn.Module) -> nn.Module: 59 | """ 60 | Create a deep copy of the model. 61 | 62 | Args: 63 | model (nn.Module): The original model. 64 | 65 | Returns: 66 | nn.Module: A deep copy of the model. 67 | """ 68 | return copy.deepcopy(model) 69 | 70 | def _convert_module_to_half(self, module: nn.Module): 71 | """ 72 | Recursively convert module parameters and buffers to float16. 73 | 74 | Args: 75 | module (nn.Module): The module to convert. 
76 | """ 77 | # First recursively process child modules to avoid incorrect handling after the forward method is overridden 78 | for child in module.children(): 79 | self._convert_module_to_half(child) 80 | 81 | # Convert parameters 82 | for param in module.parameters(recurse=False): 83 | param.data = param.data.half() 84 | 85 | # Convert buffers 86 | for buffer in module.buffers(recurse=False): 87 | buffer.data = buffer.data.half() 88 | 89 | original_forward = module.forward 90 | 91 | def new_forward(*args, **kwargs): 92 | # Convert inputs to float16 93 | args = tuple(arg.half() if isinstance(arg, torch.Tensor) else arg for arg in args) 94 | kwargs = {k: v.half() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} 95 | 96 | # Run the original forward method 97 | output = original_forward(*args, **kwargs) 98 | 99 | # Convert outputs back to float32 to match the original model's interface 100 | if isinstance(output, torch.Tensor): 101 | output = output.float() 102 | elif isinstance(output, (tuple, list)): 103 | output = tuple(out.float() if isinstance(out, torch.Tensor) else out for out in output) 104 | 105 | return output 106 | 107 | module.forward = new_forward 108 | 109 | def _wrap_bottleneck_module(self, module: nn.Module): 110 | """ 111 | Wrap the bottleneck module to ensure data type consistency. 112 | 113 | Args: 114 | module (nn.Module): The bottleneck module to wrap. 115 | """ 116 | original_forward = module.forward 117 | 118 | def new_forward(*args, **kwargs): 119 | # Convert inputs to float32 120 | args = tuple(arg.float() if isinstance(arg, torch.Tensor) else arg for arg in args) 121 | kwargs = {k: v.float() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} 122 | 123 | # Run the original forward method 124 | output = original_forward(*args, **kwargs) 125 | 126 | # Convert outputs back to float16 if necessary 127 | if isinstance(output, torch.Tensor): 128 | output = output.float() 129 | elif isinstance(output, (tuple, list)): 130 | output = tuple(out.float() if isinstance(out, torch.Tensor) else out for out in output) 131 | 132 | return output 133 | 134 | # Replace the forward method of the module 135 | module.forward = new_forward 136 | 137 | def _is_special_layer(self, module: nn.Module) -> bool: 138 | """ 139 | Check if a module needs special handling. 140 | 141 | Args: 142 | module (nn.Module): The module to check 143 | 144 | Returns: 145 | bool: Whether special handling is needed 146 | """ 147 | # Check if it's a BatchNorm layer 148 | return isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) 149 | 150 | def _handle_special_layer(self, module: nn.Module): 151 | """ 152 | Apply special handling to special layers. 
153 | 154 | Args: 155 | module (nn.Module): The special module to process 156 | """ 157 | # For BatchNorm layers, keep running_mean and running_var in float32 precision 158 | # but weights and biases can be converted to half precision 159 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 160 | # Convert weights and biases to half precision 161 | if module.weight is not None: 162 | module.weight.data = module.weight.data.half() 163 | if module.bias is not None: 164 | module.bias.data = module.bias.data.half() 165 | 166 | # BatchNorm layer's running_mean and running_var need to remain in float32 precision 167 | # (Converting to half would cause numerical instability) 168 | 169 | # Copy the original forward method 170 | original_forward = module.forward 171 | 172 | # Create a new forward method 173 | def new_forward(x): 174 | x_half = x.half() if isinstance(x, torch.Tensor) else x 175 | result = original_forward(x_half) 176 | return result.float() # Return float32 result 177 | 178 | # Replace the forward method 179 | module.forward = new_forward 180 | 181 | def save_quantized_model(self, model: nn.Module, file_path: str): 182 | """ 183 | Save the quantized model to a file. 184 | 185 | Args: 186 | model (nn.Module): The quantized model. 187 | file_path (str): The file path to save the model. 188 | """ 189 | torch.save(model.state_dict(), file_path) 190 | 191 | def load_quantized_model(self, model_class, file_path: str) -> nn.Module: 192 | """ 193 | Load the quantized model from a file. 194 | 195 | Args: 196 | model_class: The class of the model to instantiate. 197 | file_path (str): The file path of the saved model. 198 | 199 | Returns: 200 | nn.Module: The loaded quantized model. 201 | """ 202 | model = model_class() 203 | model.load_state_dict(torch.load(file_path)) 204 | return model 205 | 206 | def _get_model_structure(self, model: nn.Module) -> dict: 207 | """ 208 | Get the structure information of the model. 209 | 210 | Args: 211 | model (nn.Module): The model to analyze 212 | 213 | Returns: 214 | dict: Model structure information 215 | """ 216 | structure = {} 217 | 218 | # Record the types and submodule structures of all modules 219 | for name, module in model.named_modules(): 220 | if name == '': # Skip the root module 221 | continue 222 | 223 | # Record module type 224 | module_info = { 225 | 'type': module.__class__.__name__, 226 | 'parameters': {n: p.shape for n, p in module.named_parameters(recurse=False)}, 227 | 'buffers': {n: b.shape for n, b in module.named_buffers(recurse=False)}, 228 | 'children': [] 229 | } 230 | 231 | # Record direct child modules 232 | for child_name, _ in module.named_children(): 233 | full_child_name = f"{name}.{child_name}" if name else child_name 234 | module_info['children'].append(full_child_name) 235 | 236 | structure[name] = module_info 237 | 238 | return structure 239 | 240 | def _validate_model_structure(self, original: dict, quantized: dict) -> bool: 241 | """ 242 | Validate whether the model structure is consistent before and after quantization. 
243 | 244 | Args: 245 | original (dict): Original model structure 246 | quantized (dict): Structure of the quantized model 247 | 248 | Returns: 249 | bool: Whether the structures are consistent 250 | """ 251 | # Check if the number of modules is consistent 252 | if len(original) != len(quantized): 253 | print(f"Module count mismatch: Original model {len(original)}, Quantized model {len(quantized)}") 254 | return False 255 | 256 | # Check the type and submodule structure of each module 257 | for name, orig_info in original.items(): 258 | if name not in quantized: 259 | print(f"Module missing: {name}") 260 | return False 261 | 262 | quant_info = quantized[name] 263 | 264 | # Check module type 265 | if orig_info['type'] != quant_info['type']: 266 | print(f"Module type mismatch: {name}, Original {orig_info['type']}, Quantized {quant_info['type']}") 267 | return False 268 | 269 | # Check parameter shapes 270 | for param_name, param_shape in orig_info['parameters'].items(): 271 | if param_name not in quant_info['parameters']: 272 | print(f"Parameter missing: {name}.{param_name}") 273 | return False 274 | 275 | if param_shape != quant_info['parameters'][param_name]: 276 | print(f"Parameter shape mismatch: {name}.{param_name}") 277 | return False 278 | 279 | # Check buffer shapes 280 | for buffer_name, buffer_shape in orig_info['buffers'].items(): 281 | if buffer_name not in quant_info['buffers']: 282 | print(f"Buffer missing: {name}.{buffer_name}") 283 | return False 284 | 285 | if buffer_shape != quant_info['buffers'][buffer_name]: 286 | print(f"Buffer shape mismatch: {name}.{buffer_name}") 287 | return False 288 | 289 | # Check submodule structure 290 | if set(orig_info['children']) != set(quant_info['children']): 291 | print(f"Child module structure mismatch: {name}") 292 | return False 293 | 294 | return True -------------------------------------------------------------------------------- /src/quantization_tester.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | from typing import Dict, List, Tuple, Union 5 | 6 | class QuantizationTester: 7 | def __init__(self, error_metric: str = 'mse'): 8 | """ 9 | Initialize the QuantizationTester. 10 | 11 | Args: 12 | error_metric (str): The error metric to use ('mse', 'mae', etc.). 13 | """ 14 | self.error_metric = error_metric 15 | self.module_errors = {} # Stores error data for each module 16 | 17 | def test_modules( 18 | self, 19 | modules: Dict[str, nn.Module], 20 | module_io: Dict[nn.Module, Dict[str, Tuple[torch.Tensor]]], 21 | module_names: Dict[nn.Module, str] 22 | ) -> Dict[str, float]: 23 | """ 24 | Perform quantization testing on each module. 25 | 26 | Args: 27 | modules (Dict[str, nn.Module]): Dictionary of module names to modules. 28 | module_io (Dict[nn.Module, Dict[str, Tuple[torch.Tensor]]]): Captured input/output tensors. 29 | module_names (Dict[nn.Module, str]): Mapping of module instances to names. 30 | 31 | Returns: 32 | Dict[str, float]: Error data for each module. 
33 | """ 34 | for module in modules.values(): 35 | module_name = module_names.get(module, 'Unnamed Module') 36 | 37 | # Skip modules without captured data 38 | if module not in module_io: 39 | continue 40 | 41 | # Get the original input and output tensors 42 | input_tensors = module_io[module]['input'] 43 | output_float32 = module_io[module]['output'] 44 | 45 | # Create a copy of the module to avoid modifying the original 46 | module_copy = copy.deepcopy(module) 47 | 48 | # Use the same quantization strategy as the model transformer 49 | # 1. First recursively quantize the parameters of child modules 50 | self._convert_module_children_to_half(module_copy) 51 | 52 | # 2. Then quantize the parameters of the current module 53 | for param in module_copy.parameters(recurse=False): 54 | param.data = param.data.half() 55 | 56 | for buffer in module_copy.buffers(recurse=False): 57 | buffer.data = buffer.data.half() 58 | 59 | # Convert input tensors to float16 60 | input_tensors_half = tuple(inp.half() for inp in input_tensors) 61 | 62 | # Run the module with float16 inputs 63 | with torch.no_grad(): 64 | output_half = module_copy(*input_tensors_half) 65 | 66 | # Convert output back to float32 for comparison 67 | output_half_float32 = self._to_float32(output_half) 68 | 69 | # Compute the error between float32 and float16 outputs 70 | error = self._compute_error(output_float32, output_half_float32) 71 | 72 | # Store the error with the module name 73 | self.module_errors[module_name] = error 74 | 75 | print(f"Module: {module_name}, Error ({self.error_metric}): {error}") 76 | 77 | return self.module_errors 78 | 79 | def _convert_module_children_to_half(self, module: nn.Module): 80 | """ 81 | Recursively convert child module parameters and buffers to float16, 82 | ensuring we follow the same pattern as the model_transformer. 83 | 84 | Args: 85 | module (nn.Module): The module whose children to convert. 86 | """ 87 | for child in module.children(): 88 | self._convert_module_children_to_half(child) 89 | 90 | # Convert parameters of child 91 | for param in child.parameters(recurse=False): 92 | param.data = param.data.half() 93 | 94 | # Convert buffers of child 95 | for buffer in child.buffers(recurse=False): 96 | buffer.data = buffer.data.half() 97 | 98 | def _convert_module_to_half(self, module: nn.Module): 99 | """ 100 | Convert module parameters and buffers to float16. 101 | This method is used by the unit tests to test conversion to half precision. 102 | 103 | Args: 104 | module (nn.Module): The module to convert. 105 | """ 106 | # First convert child modules 107 | self._convert_module_children_to_half(module) 108 | 109 | # Then convert parameters of the current module 110 | for param in module.parameters(recurse=False): 111 | param.data = param.data.half() 112 | 113 | # And convert buffers of the current module 114 | for buffer in module.buffers(recurse=False): 115 | buffer.data = buffer.data.half() 116 | 117 | # Note: Unlike ModelTransformer, we don't modify the forward method here 118 | # as this is just for testing parameter conversion 119 | 120 | def _compute_error( 121 | self, 122 | output1: Union[torch.Tensor, Tuple[torch.Tensor]], 123 | output2: Union[torch.Tensor, Tuple[torch.Tensor]] 124 | ) -> float: 125 | """ 126 | Compute the error between two outputs. 127 | 128 | Args: 129 | output1: The original output (float32). 130 | output2: The quantized output (float16 converted back to float32). 131 | 132 | Returns: 133 | float: The computed error. 
134 | """ 135 | if isinstance(output1, torch.Tensor): 136 | error = self._compute_tensor_error(output1, output2) 137 | elif isinstance(output1, (tuple, list)): 138 | error = sum( 139 | self._compute_tensor_error(o1, o2) for o1, o2 in zip(output1, output2) 140 | ) / len(output1) 141 | else: 142 | error = float('inf') # Unsupported output type 143 | return error 144 | 145 | def _compute_tensor_error( 146 | self, tensor1: torch.Tensor, tensor2: torch.Tensor 147 | ) -> float: 148 | """ 149 | Compute the error between two tensors. 150 | 151 | Args: 152 | tensor1: The original tensor (float32). 153 | tensor2: The quantized tensor (float32). 154 | 155 | Returns: 156 | float: The computed error. 157 | """ 158 | if self.error_metric == 'mse': 159 | return torch.nn.functional.mse_loss(tensor1, tensor2).item() 160 | elif self.error_metric == 'mae': 161 | return torch.nn.functional.l1_loss(tensor1, tensor2).item() 162 | else: 163 | raise ValueError(f"Unsupported error metric: {self.error_metric}") 164 | 165 | def _to_float32( 166 | self, output: Union[torch.Tensor, Tuple[torch.Tensor]] 167 | ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: 168 | """ 169 | Convert output tensors to float32. 170 | 171 | Args: 172 | output: The output tensor(s) to convert. 173 | 174 | Returns: 175 | The converted output tensor(s). 176 | """ 177 | if isinstance(output, torch.Tensor): 178 | return output.float() 179 | elif isinstance(output, (tuple, list)): 180 | return tuple(out.float() for out in output) 181 | else: 182 | return output # Unsupported type, return as is 183 | 184 | def get_module_errors(self) -> Dict[str, float]: 185 | """ 186 | Get the error data for each module. 187 | 188 | Returns: 189 | Dict[str, float]: Error data for each module. 190 | """ 191 | return self.module_errors 192 | 193 | def sort_modules_by_error(self, descending: bool = True) -> List[Tuple[str, float]]: 194 | """ 195 | Sort modules by their error values. 196 | 197 | Args: 198 | descending (bool): Sort in descending order if True. 199 | 200 | Returns: 201 | List[Tuple[str, float]]: Sorted list of module names and errors. 202 | """ 203 | return sorted( 204 | self.module_errors.items(), 205 | key=lambda item: item[1], 206 | reverse=descending 207 | ) -------------------------------------------------------------------------------- /src/visualizer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import os 4 | import torch.nn as nn 5 | from typing import Dict, List 6 | import networkx as nx 7 | 8 | class Visualizer: 9 | def __init__(self, output_dir: str = './visualizations'): 10 | """ 11 | Initialize the Visualizer. 12 | 13 | Args: 14 | output_dir (str): Directory to save visualization outputs. 15 | """ 16 | self.output_dir = output_dir 17 | os.makedirs(self.output_dir, exist_ok=True) 18 | 19 | def plot_error_distribution(self, error_data: Dict[str, float], save_fig: bool = True): 20 | """ 21 | Plot the distribution of errors across modules. 22 | 23 | Args: 24 | error_data (Dict[str, float]): Dictionary of module names and their errors. 25 | save_fig (bool): Whether to save the figure to a file. 
26 | """ 27 | errors = list(error_data.values()) 28 | module_names = list(error_data.keys()) 29 | 30 | plt.figure(figsize=(12, 6)) 31 | sns.barplot(x=module_names, y=errors) 32 | plt.xticks(rotation=90) 33 | plt.xlabel('Module Name') 34 | plt.ylabel('Error') 35 | plt.title('Module-wise Quantization Error') 36 | plt.tight_layout() 37 | 38 | if save_fig: 39 | fig_path = os.path.join(self.output_dir, 'error_distribution.png') 40 | plt.savefig(fig_path) 41 | plt.close() 42 | print(f"Error distribution plot saved to {fig_path}") 43 | else: 44 | plt.show() 45 | 46 | def visualize_bottlenecks(self, model: nn.Module, bottleneck_modules: List[str], module_names: Dict[nn.Module, str]): 47 | """ 48 | Visualize bottleneck layers in the model's structure. 49 | 50 | Args: 51 | model (nn.Module): The PyTorch model. 52 | bottleneck_modules (List[str]): List of bottleneck module names. 53 | module_names (Dict[nn.Module, str]): Mapping of module instances to names. 54 | """ 55 | # Create a directed graph to represent the model 56 | G = nx.DiGraph() 57 | 58 | # Build the graph by traversing the model modules 59 | def add_edges(module: nn.Module, parent_name: str = ''): 60 | for name, child in module.named_children(): 61 | full_name = f"{parent_name}.{name}" if parent_name else name 62 | G.add_node(full_name) 63 | G.add_edge(parent_name, full_name) if parent_name else None 64 | add_edges(child, full_name) 65 | 66 | add_edges(model) 67 | 68 | # Assign colors to bottleneck modules 69 | color_map = [] 70 | node_sizes = [] 71 | for node in G.nodes(): 72 | if node in bottleneck_modules: 73 | color_map.append('red') # Bottleneck modules in red 74 | node_sizes.append(500) # Bigger size for bottleneck modules 75 | else: 76 | color_map.append('lightblue') # Other modules in light blue 77 | node_sizes.append(300) # Normal size for other modules 78 | 79 | plt.figure(figsize=(12, 8)) 80 | 81 | # Use a layout algorithm that doesn't depend on Graphviz 82 | try: 83 | # Try using spring_layout (force-directed algorithm) 84 | pos = nx.spring_layout(G, seed=42) 85 | 86 | # Draw the graph 87 | nx.draw(G, pos, with_labels=True, node_color=color_map, 88 | node_size=node_sizes, arrows=True, font_size=8, 89 | edge_color='gray', width=1, alpha=0.7) 90 | 91 | # Add legend 92 | plt.legend(handles=[ 93 | plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=10, label='Bottleneck Module'), 94 | plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightblue', markersize=10, label='Regular Module') 95 | ], loc='best') 96 | 97 | plt.title('Model Structure with Bottleneck Modules Highlighted') 98 | plt.tight_layout() 99 | 100 | fig_path = os.path.join(self.output_dir, 'bottleneck_visualization.png') 101 | plt.savefig(fig_path) 102 | plt.close() 103 | print(f"Bottleneck visualization saved to {fig_path}") 104 | 105 | except Exception as e: 106 | print(f"Warning: Could not generate bottleneck visualization. 
Error: {e}") 107 | 108 | # As a fallback option, create a simple bar chart showing bottleneck modules 109 | plt.figure(figsize=(10, 6)) 110 | bottleneck_names = bottleneck_modules 111 | 112 | # Create a simple chart that directly displays the names of bottleneck modules 113 | y_pos = range(len(bottleneck_names)) 114 | plt.barh(y_pos, [1] * len(bottleneck_names), color='red') 115 | plt.yticks(y_pos, bottleneck_names) 116 | plt.xlabel('Identified as Bottleneck') 117 | plt.title('Bottleneck Modules') 118 | plt.tight_layout() 119 | 120 | fig_path = os.path.join(self.output_dir, 'bottleneck_modules.png') 121 | plt.savefig(fig_path) 122 | plt.close() 123 | print(f"Simple bottleneck list saved to {fig_path}") 124 | 125 | def plot_performance_comparison(self, original_results: Dict[str, float], quantized_results: Dict[str, float], save_fig: bool = True): 126 | """ 127 | Plot the performance comparison between the original and quantized models. 128 | 129 | Args: 130 | original_results (Dict[str, float]): Evaluation results of the original model. 131 | quantized_results (Dict[str, float]): Evaluation results of the quantized model. 132 | save_fig (bool): Whether to save the figure to a file. 133 | """ 134 | metrics = ['accuracy', 'loss', 'avg_inference_time', 'throughput'] 135 | original_values = [original_results.get(metric) for metric in metrics] 136 | quantized_values = [quantized_results.get(metric) for metric in metrics] 137 | 138 | x = range(len(metrics)) 139 | 140 | plt.figure(figsize=(10, 6)) 141 | plt.bar(x, original_values, width=0.4, label='Original Model', align='center') 142 | plt.bar([i + 0.4 for i in x], quantized_values, width=0.4, label='Quantized Model', align='center') 143 | plt.xticks([i + 0.2 for i in x], metrics) 144 | plt.ylabel('Value') 145 | plt.title('Performance Comparison') 146 | plt.legend() 147 | plt.tight_layout() 148 | 149 | if save_fig: 150 | fig_path = os.path.join(self.output_dir, 'performance_comparison.png') 151 | plt.savefig(fig_path) 152 | plt.close() 153 | print(f"Performance comparison plot saved to {fig_path}") 154 | else: 155 | plt.show() 156 | -------------------------------------------------------------------------------- /unit_test/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CNNWithAttention(nn.Module): 6 | def __init__(self, num_classes: int=10): 7 | super(CNNWithAttention, self).__init__() 8 | 9 | # CNN part 10 | # input: (batch_size, 3, 32, 32) 11 | # after the below CNN, output: (batch_size, 64, 8, 8) 12 | self.conv1 = nn.Sequential( 13 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1), 14 | nn.BatchNorm2d(32), 15 | nn.ReLU() 16 | ) 17 | self.conv2 = nn.Sequential( 18 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 19 | nn.BatchNorm2d(64), 20 | nn.ReLU() 21 | ) 22 | 23 | # Attention part 24 | # MultiheadAttention input: (sequence_length, batch_size, embed_dim) 25 | # So CNN output format must be in reshape / permute 26 | # Let embed_dim=64, num_heads=4 27 | self.embed_dim = 64 28 | self.attn = nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=4) 29 | 30 | # MLP Part 31 | # Pooling the output of Attention, and then classify 32 | self.fc = nn.Sequential( 33 | nn.Linear(self.embed_dim, 128), 34 | nn.ReLU(), 35 | nn.Linear(128, num_classes) 36 | ) 37 | 38 | def forward(self, x: torch.Tensor): 39 | """ 40 | x: (batch_size, 3, 32, 32) 41 | return: (batch_size, 10) 42 | """ 43 | # 
Extract features by CNN 44 | x = self.conv1(x) # (batch_size, 32, 16, 16) 45 | x = self.conv2(x) # (batch_size, 64, 8, 8) 46 | 47 | # Flatten features and keep the batch_size and channel dims 48 | # flatten spatial dims: 8 * 8 = 64, giving (batch_size, 64, 64) 49 | batch_size, channels, height, width = x.shape 50 | x = x.view(batch_size, channels, height * width) # (batch_size, 64, 64) 51 | 52 | # Note that MultiheadAttention needs shape (sequence_length, batch_size, embed_dim) 53 | # So permute to (64, batch_size, 64) 54 | x = x.permute(2, 0, 1) # (seq_len=64, batch_size, embed_dim=64) 55 | 56 | # Let query, key, and value all come from the same x (self-attention) 57 | attn_output, _ = self.attn(x, x, x) # (seq_len=64, batch_size, 64) 58 | 59 | # Average over the sequence dim to get (batch_size, 64) 60 | attn_output = attn_output.mean(dim=0) # (batch_size, 64) 61 | 62 | # MLP layer 63 | out = self.fc(attn_output) # (batch_size, num_classes) 64 | 65 | return out 66 | 67 | 68 | if __name__ == '__main__': 69 | # Test the network 70 | model = CNNWithAttention(num_classes=10) 71 | print(model) 72 | 73 | x = torch.randn(4, 3, 32, 32) # batch_size=4 74 | y = model(x) 75 | print(f"input x shape: {x.shape}") 76 | print(f"output y shape: {y.shape}") -------------------------------------------------------------------------------- /unit_test/test_bottleneck_identifier.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from bottleneck_identifier import BottleneckIdentifier 3 | 4 | 5 | # case 1: get bottlenecks with threshold 6 | def test_identify_bottlenecks_with_threshold(): 7 | identifier = BottleneckIdentifier() 8 | error_data = {"module1": 0.8, "module2": 0.5, "module3": 0.3} 9 | threshold = 0.6 10 | bottlenecks = identifier.identify_bottlenecks(error_data, threshold=threshold) 11 | assert bottlenecks == ["module1"] 12 | 13 | # case 2: get bottlenecks with top_n 14 | def test_identify_bottlenecks_with_top_n(): 15 | identifier = BottleneckIdentifier() 16 | error_data = {"module1": 0.8, "module2": 0.5, "module3": 0.3} 17 | top_n = 2 18 | bottlenecks = identifier.identify_bottlenecks(error_data, top_n=top_n) 19 | assert bottlenecks == ["module1", "module2"] 20 | 21 | # case 3: get bottlenecks with ratio 22 | def test_identify_bottlenecks_with_error_ratio(): 23 | identifier = BottleneckIdentifier() 24 | error_data = {"module1": 0.8, "module2": 0.5, "module3": 0.3} 25 | error_ratio = 0.7 26 | bottlenecks = identifier.identify_bottlenecks(error_data, error_ratio=error_ratio) 27 | assert bottlenecks == ["module1"] 28 | 29 | # case 4: get bottlenecks with empty input 30 | def test_identify_bottlenecks_empty_error_data(): 31 | identifier = BottleneckIdentifier() 32 | with pytest.raises(ValueError): 33 | identifier.identify_bottlenecks({}) 34 | 35 | # case 5: get bottlenecks without option 36 | def test_identify_bottlenecks_no_criteria(): 37 | identifier = BottleneckIdentifier() 38 | error_data = {"module1": 0.8, "module2": 0.5, "module3": 0.3} 39 | with pytest.raises(ValueError): 40 | identifier.identify_bottlenecks(error_data) 41 | 42 | # case 6: get bottleneck modules when none have been identified 43 | def test_get_bottleneck_modules_empty(): 44 | identifier = BottleneckIdentifier() 45 | expected = [] 46 | actual = identifier.get_bottleneck_modules() 47 | assert expected == actual 48 | 49 | # case 7: get bottleneck modules with preset modules 50 | def test_get_bottleneck_modules_with_modules(): 51 | identifier = BottleneckIdentifier() 52 | identifier.bottleneck_modules = ["module1", "module2"] 53 | 
expected = ["module1", "module2"] 54 | actual = identifier.get_bottleneck_modules() 55 | assert expected == actual -------------------------------------------------------------------------------- /unit_test/test_config_manager.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import yaml 3 | import argparse 4 | import torch 5 | import os 6 | from config_manager import ConfigManager 7 | 8 | # case 1: load default config 9 | def test_load_default_config(): 10 | config_manager = ConfigManager() 11 | default_config = config_manager.config 12 | 13 | assert default_config!= {} 14 | assert default_config['error_threshold'] == 1e-4 15 | assert default_config['log_file'] is None 16 | assert default_config['log_level'] == 'INFO' 17 | assert default_config['output_dir'] == './output' 18 | assert default_config['device'] == 'cuda' if torch.cuda.is_available() else 'cpu' 19 | assert default_config['batch_size'] == 32 20 | assert default_config['num_workers'] == 4 21 | assert default_config['dataset']['name'] == 'CIFAR10' 22 | assert default_config['dataset']['root'] == './data' 23 | assert default_config['dataset']['download'] is True 24 | assert default_config['model']['name'] == 'resnet18' 25 | assert default_config['model']['pretrained'] is True 26 | assert default_config['quantization']['error_metric'] =='mse' 27 | assert default_config['quantization']['strategy'] == 'threshold' 28 | assert default_config['quantization']['top_n'] is None 29 | assert default_config['quantization']['error_ratio'] is None 30 | 31 | # case 2: load config from file 32 | def test_load_config_from_file(tmp_path): 33 | config_file = tmp_path / "config.yaml" 34 | config_data = {"key1": "value1", "key2": "value2"} 35 | with open(config_file, "w") as f: 36 | yaml.dump(config_data, f) 37 | 38 | config_manager = ConfigManager(config_file=str(config_file)) 39 | assert config_manager.config.get("key1") == "value1" 40 | assert config_manager.config.get("key2") == "value2" 41 | 42 | # case 3: load config from environment variable 43 | def test_load_config_from_env(monkeypatch): 44 | monkeypatch.setenv("CONFIG_KEY1", "env_value1") 45 | config_manager = ConfigManager() 46 | assert config_manager.config.get("CONFIG_KEY1") == "env_value1" 47 | 48 | # case 4: load config from args 49 | def test_load_config_from_args(): 50 | config_manager = ConfigManager() 51 | config_manager.load_config_from_args() 52 | 53 | assert config_manager.config.get('config_file') == None 54 | assert config_manager.config.get('log_file') == None 55 | 56 | 57 | 58 | # case 5: load config from nonexistent file 59 | def test_load_nonexistent_config_file(): 60 | config_manager = ConfigManager() 61 | with pytest.raises(FileNotFoundError): 62 | config_manager.load_config_from_file('nonexistent_config.yaml') 63 | 64 | 65 | # case 6: load config from environment variable 66 | def test_load_config_from_env(): 67 | config_manager = ConfigManager() 68 | os.environ['DEVICE'] = 'test_device' 69 | config_manager.load_config_from_env() 70 | 71 | assert config_manager.config['device'] == 'test_device' 72 | del os.environ['DEVICE'] 73 | 74 | # case 7: load config from empty environment variable 75 | def test_load_config_from_env_no_variable(): 76 | config_manager = ConfigManager() 77 | config_manager.load_config_from_env() 78 | assert 'config_file' not in config_manager.config 79 | 80 | 81 | # case 8: merge config with empty config 82 | def test_merge_empty_config(): 83 | config_manager = ConfigManager() 84 | origin_size 
= len(config_manager.config) 85 | new_config = {} 86 | config_manager.merge_config(new_config) 87 | assert len(config_manager.config) == origin_size 88 | 89 | # case 9: merge config with non-empty config 90 | def test_merge_non_empty_config(): 91 | config_manager = ConfigManager() 92 | new_config = { 93 | "key1": "value1", 94 | "key2": "value2" 95 | } 96 | origin_size = len(config_manager.config) 97 | 98 | config_manager.merge_config(new_config) 99 | assert len(config_manager.config) == origin_size + len(new_config) 100 | 101 | # case 10: merge config with existing config 102 | def test_merge_override_existing_key(): 103 | config_manager = ConfigManager() 104 | config_manager.config = { 105 | "key1": "old_value" 106 | } 107 | new_config = { 108 | "key1": "new_value" 109 | } 110 | config_manager.merge_config(new_config) 111 | assert config_manager.config["key1"] == "new_value" 112 | 113 | # case 11: merge config with new key added 114 | def test_merge_add_new_key(): 115 | config_manager = ConfigManager() 116 | config_manager.config = { 117 | "key1": "value1" 118 | } 119 | new_config = { 120 | "key2": "value2" 121 | } 122 | config_manager.merge_config(new_config) 123 | assert "key2" in config_manager.config 124 | assert config_manager.config["key2"] == "value2" 125 | 126 | # case 12: merge config with nested dict 127 | def test_merge_nested_dict(): 128 | config_manager = ConfigManager() 129 | config_manager.config = { 130 | "key1": { 131 | "sub_key1": "value1" 132 | } 133 | } 134 | new_config = { 135 | "key1": { 136 | "sub_key2": "value2" 137 | } 138 | } 139 | config_manager.merge_config(new_config) 140 | assert "sub_key2" in config_manager.config["key1"] 141 | assert config_manager.config["key1"]["sub_key2"] == "value2" 142 | 143 | # case 13: merge config with dict 144 | def test_recursive_merge_normal(): 145 | config_manager = ConfigManager() 146 | default = {'a': 1, 'b': {'c': 2}} 147 | override = {'b': {'d': 3}} 148 | expected = {'a': 1, 'b': {'c': 2, 'd': 3}} 149 | assert config_manager._recursive_merge(default=default, override=override) == expected 150 | 151 | # case 14: merge config with dict overwrite 152 | def test_recursive_merge_override_value(): 153 | config_manager = ConfigManager() 154 | default = {'a': 1, 'b': 2} 155 | override = {'b': 3} 156 | expected = {'a': 1, 'b': 3} 157 | assert config_manager._recursive_merge(default, override) == expected 158 | 159 | # case 15: merge config with empty dict 160 | def test_recursive_merge_empty_override(): 161 | config_manager = ConfigManager() 162 | default = {'a': 1} 163 | override = {} 164 | expected = {'a': 1} 165 | assert config_manager._recursive_merge(default, override) == expected 166 | 167 | # case 16: merge config with empty dict 168 | def test_recursive_merge_empty_default(): 169 | config_manager = ConfigManager() 170 | default = {} 171 | override = {'a': 1} 172 | expected = {'a': 1} 173 | assert config_manager._recursive_merge(default, override) == expected 174 | 175 | 176 | # case 17: create config manager 177 | @pytest.fixture 178 | def config_manager(): 179 | return ConfigManager() 180 | 181 | 182 | # case 18: set config single key 183 | def test_set_config_single_key(config_manager): 184 | config_manager.set_config('key1', 'value1') 185 | assert config_manager.config['key1'] == 'value1' 186 | 187 | 188 | # case 19: set config nested key 189 | def test_set_config_nested_key(config_manager): 190 | config_manager.set_config('section1.key2', 'value2') 191 | assert config_manager.config['section1']['key2'] == 'value2' 192 | 
193 | 194 | # case 20: set config nonexistent key 195 | def test_set_config_nonexistent_key(config_manager): 196 | config_manager.set_config('new_section.key3', 'value3') 197 | assert config_manager.config['new_section']['key3'] == 'value3' 198 | 199 | 200 | # case 21: get all config 201 | def test_get_config_all(): 202 | config_manager = ConfigManager() 203 | config = config_manager.get_config() 204 | assert config == config_manager.config 205 | 206 | 207 | # case 22: get existing key 208 | def test_get_config_existing_key(): 209 | config_manager = ConfigManager() 210 | config_manager.config['test'] = 'value' 211 | value = config_manager.get_config('test') 212 | assert value == 'value' 213 | 214 | # case 23: get nonexistent key 215 | def test_get_config_non_existing_key(): 216 | config_manager = ConfigManager() 217 | value = config_manager.get_config('non_existing_key') 218 | assert value is None 219 | 220 | 221 | # case 24: get nested key 222 | def test_get_config_nested_key(): 223 | config_manager = ConfigManager() 224 | config_manager.config['nested'] = {'key': 'value'} 225 | value = config_manager.get_config('nested.key') 226 | assert value == 'value' 227 | 228 | 229 | # case 25: print config 230 | def test_print_config_normal(): 231 | config_manager = ConfigManager() 232 | config_manager.print_config() 233 | 234 | 235 | # case 26: print empty config 236 | def test_print_config_empty(): 237 | config_manager = ConfigManager() 238 | config_manager.config = {} 239 | config_manager.print_config() 240 | 241 | 242 | # case 27: print complex config 243 | def test_print_config_complex(): 244 | config_manager = ConfigManager() 245 | config_manager.config = { 246 | "key1": "value1", 247 | "key2": [1, 2, 3], 248 | "key3": { 249 | "sub_key1": "sub_value1", 250 | "sub_key2": "sub_value2" 251 | } 252 | } 253 | config_manager.print_config() 254 | -------------------------------------------------------------------------------- /unit_test/test_evaluator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader, TensorDataset 5 | from evaluator import Evaluator 6 | from net import CNNWithAttention 7 | 8 | 9 | # test dataloader 10 | def mock_dataloader(): 11 | num_samples = 128 12 | num_classes = 10 13 | images = torch.rand(num_samples, 3, 32, 32) 14 | labels = torch.randint(0, num_classes, (num_samples,)) 15 | 16 | dataset = TensorDataset(images, labels) 17 | return DataLoader(dataset, batch_size=16, shuffle=False) 18 | 19 | 20 | @pytest.fixture 21 | def evaluator(): 22 | return Evaluator() 23 | 24 | 25 | # case 1: test default device 26 | def test_default_device(): 27 | evaluator = Evaluator() 28 | assert evaluator.device == torch.device('cpu') 29 | 30 | # case 2: test specified device 31 | def test_specified_device(): 32 | device = torch.device('cuda:0') 33 | evaluator = Evaluator(device=device) 34 | assert evaluator.device == device 35 | 36 | # case 3: test evaluate_model 37 | def test_evaluate_model(evaluator): 38 | metrics = {'accuracy': lambda x, y:(x.argmax(dim=1) == y).float().mean()} 39 | dataloader = mock_dataloader() 40 | 41 | # Evaluate original model 42 | results = evaluator.evaluate_model( 43 | model = CNNWithAttention(num_classes=10), 44 | dataloader = dataloader, 45 | criterion = nn.CrossEntropyLoss(), 46 | metrics = metrics 47 | ) 48 | 49 | 50 | assert 'loss' in results 51 | assert 'accuracy' in results 52 | assert 'avg_inference_time' in results 53 | 
assert 'throughput' in results 54 | assert 'total_samples' in results 55 | 56 | assert results['loss'] is not None 57 | assert 0 <= results['accuracy'] <= 1 58 | assert results['avg_inference_time'] > 0 59 | assert results['throughput'] > 0 60 | assert results['total_samples'] == 128 61 | 62 | assert'accuracy' in results 63 | assert 0 <= results['accuracy'] <= 1 64 | 65 | 66 | # case 4: test compare_models method 67 | def test_compare_models(): 68 | evaluator = Evaluator() 69 | 70 | model1 = CNNWithAttention(num_classes=10) 71 | model2 = CNNWithAttention(num_classes=10) 72 | 73 | dataloader = mock_dataloader() 74 | 75 | # test evaluator without loss function or metrics 76 | comparison = evaluator.compare_models(model1, model2, dataloader) 77 | assert 'model1' in comparison 78 | assert 'model2' in comparison 79 | 80 | # test evaluator with loss function 81 | criterion = nn.CrossEntropyLoss() 82 | comparison = evaluator.compare_models(model1, model2, dataloader, criterion=criterion) 83 | assert 'loss' in comparison['model1'] 84 | assert 'loss' in comparison['model2'] 85 | 86 | # test evaluator with metrics 87 | metrics = {'accuracy': lambda x, y:(x.argmax(dim=1) == y).float().mean()} 88 | comparison = evaluator.compare_models(model1, model2, dataloader, metrics=metrics) 89 | assert 'accuracy' in comparison['model1'] 90 | assert 'accuracy' in comparison['model2'] -------------------------------------------------------------------------------- /unit_test/test_hook_manager.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from hook_manager import HookManager 5 | 6 | # test model 7 | class MockModel(torch.nn.Module): 8 | def __init__(self): 9 | super(MockModel, self).__init__() 10 | 11 | def forward(self, x): 12 | return x + 1 13 | 14 | @pytest.fixture 15 | def model(): 16 | return MockModel() 17 | 18 | @pytest.fixture 19 | def hook_manager(): 20 | return HookManager() 21 | 22 | # case 1: init test 23 | def test_hook_manager_init(): 24 | hook_manager = HookManager() 25 | assert len(hook_manager.hooks) == 0 26 | assert len(hook_manager.module_io) == 0 27 | 28 | # case 2: register test 29 | def test_register_hooks(hook_manager, model): 30 | hook_manager.register_hooks(model) 31 | 32 | assert len(hook_manager.hooks) == 1 33 | 34 | x = torch.randn(2, 3, 4) 35 | output = model(x) 36 | assert torch.equal(x+1, output) 37 | 38 | for module, io in hook_manager.module_io.items(): 39 | assert io['input'] is not None 40 | print(io['input']) 41 | assert torch.equal(x, io['input'][0]) 42 | assert io['output'] is not None 43 | assert torch.equal(output, io['output']) 44 | 45 | 46 | # case 3: test remove_hooks method 47 | def test_remove_hooks(hook_manager, model): 48 | hook_manager.register_hooks(model) 49 | assert len(hook_manager.hooks) == 1 50 | 51 | hook_manager.remove_hooks() 52 | assert len(hook_manager.hooks) == 0 53 | 54 | 55 | # case 3: test get_module_io method 56 | def test_get_module_io(): 57 | hook_manager = HookManager() 58 | 59 | module1 = nn.Linear(10, 20) 60 | input1 = torch.randn(1, 10) 61 | output1 = module1(input1) 62 | hook_manager.module_io[module1] = {'input': input1, 'output': output1} 63 | 64 | module2 = nn.Conv2d(3, 16, 3) 65 | input2 = torch.randn(1, 3, 28, 28) 66 | output2 = module2(input2) 67 | hook_manager.module_io[module2] = {'input': input2, 'output': output2} 68 | 69 | module_io = hook_manager.get_module_io() 70 | 71 | assert torch.equal(module_io[module1]['input'], input1) 72 | 
assert torch.equal(module_io[module1]['output'], output1) 73 | 74 | assert torch.equal(module_io[module2]['input'], input2) 75 | assert torch.equal(module_io[module2]['output'], output2) 76 | 77 | 78 | -------------------------------------------------------------------------------- /unit_test/test_logger.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | from logger import Logger 4 | 5 | 6 | # case 1: init logger with file and level specified 7 | def test_init_with_file_and_level(): 8 | logger = Logger(log_file='test.log', log_level=logging.DEBUG) 9 | assert logger.logger.level == logging.DEBUG 10 | 11 | # case 2: init logger without file specified 12 | def test_init_with_level_only(): 13 | logger = Logger(log_level=logging.INFO) 14 | assert logger.logger.level == logging.INFO 15 | 16 | # case 3: init logger without paramater 17 | def test_init_default(): 18 | logger = Logger() 19 | assert logger.logger.level == logging.INFO 20 | 21 | # case 4: test log written to file 22 | def test_log_info_to_file(): 23 | logger = Logger(log_file='test.log') 24 | logger.log_info('this is a log from file') 25 | 26 | # case 5: test log shown to console 27 | def test_log_info_to_console(): 28 | logger = Logger() 29 | logger.log_info('this is a log from console') 30 | 31 | # case 6: test debug log written to file 32 | def test_log_debug_with_file(): 33 | logger = Logger(log_file='test.log') 34 | logger.log_debug('this is a log from file with debug information') 35 | 36 | # case 7: test debug log shown to console 37 | def test_log_debug_without_file(): 38 | logger = Logger() 39 | logger.log_debug('this is a log from console with debug information') 40 | 41 | # case 8: test error log written to file 42 | def test_log_error_with_file(): 43 | logger = Logger(log_file='test.log') 44 | logger.log_error('this is a log from file with error information') 45 | 46 | # case 9: test error log shown to console 47 | def test_log_error_without_file(): 48 | logger = Logger() 49 | logger.log_error('this is a log from console with error information') 50 | -------------------------------------------------------------------------------- /unit_test/test_model_parser.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from model_parser import ModelParser 5 | 6 | # a simple model for test 7 | class SimpleModel(nn.Module): 8 | def __init__(self): 9 | super(SimpleModel, self).__init__() 10 | self.layer1 = nn.Linear(10, 20) 11 | self.layer2 = nn.Linear(20, 30) 12 | 13 | # a simple model with one leaf 14 | class SingleLeafModel(nn.Module): 15 | def __init__(self): 16 | super(SingleLeafModel, self).__init__() 17 | self.leaf_layer = nn.Linear(5, 10) 18 | 19 | @pytest.fixture 20 | def model(): 21 | return SimpleModel() 22 | 23 | # case 1: test get_all_modules method 24 | def test_get_all_modules(model): 25 | parser = ModelParser(model) 26 | all_modules = parser.get_all_modules() 27 | 28 | assert 'layer1' in all_modules 29 | assert 'layer2' in all_modules 30 | 31 | assert isinstance(all_modules[''], SimpleModel) 32 | assert isinstance(all_modules['layer1'], nn.Linear) 33 | assert isinstance(all_modules['layer2'], nn.Linear) 34 | 35 | # case 2: test get_leaf_modules method 36 | def test_get_leaf_modules(): 37 | model = SimpleModel() 38 | parser = ModelParser(model) 39 | 40 | leaf_modules = parser.get_leaf_modules() 41 | 42 | assert len(leaf_modules) == 2 43 | 44 | assert 
(leaf_modules[0][0], type(leaf_modules[0][1])) == ('layer1', nn.Linear) 45 | assert (leaf_modules[1][0], type(leaf_modules[1][1])) == ('layer2', nn.Linear) 46 | 47 | 48 | 49 | # case 3: test _get_leaf_modules method 50 | def test_get_leaf_modules_private(): 51 | model = SimpleModel() 52 | parser = ModelParser(model) 53 | 54 | leaf_modules = parser._get_leaf_modules() 55 | 56 | assert len(leaf_modules) == 2 57 | assert leaf_modules[0][0] == 'layer1' 58 | assert leaf_modules[1][0] == 'layer2' 59 | 60 | # case 4: test empty model 61 | def test_empty_model(): 62 | emptymodel = nn.Module() 63 | parser = ModelParser(emptymodel) 64 | 65 | leaf_modules = parser._get_leaf_modules() 66 | assert len(leaf_modules) == 1 67 | 68 | 69 | # case 5: test model with one leaf 70 | def test_single_leaf_module(): 71 | model = SingleLeafModel() 72 | parser = ModelParser(model) 73 | 74 | leaf_modules = parser._get_leaf_modules() 75 | assert len(leaf_modules) == 1 76 | assert leaf_modules[0][0] == 'leaf_layer' 77 | 78 | 79 | # case 6: test get_module_by_name method 80 | def test_get_module_by_name_existing(model): 81 | model_parser = ModelParser(model) 82 | module = model_parser.get_module_by_name('layer1') 83 | assert module is not None 84 | 85 | # case 7: test get_module_by_name method 86 | def test_get_module_by_name_nonexistent(model): 87 | model_parser = ModelParser(model) 88 | module = model_parser.get_module_by_name('nonexistent_module') 89 | assert module is None -------------------------------------------------------------------------------- /unit_test/test_model_transformer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from model_transformer import ModelTransformer 5 | 6 | # case 1: init test 7 | def test_init_success(): 8 | transformer = ModelTransformer() 9 | assert transformer is not None 10 | 11 | 12 | # case 2: deep copy test 13 | def test_copy_model_normal(): 14 | model_transformer = ModelTransformer() 15 | original_model = nn.Linear(10, 20) 16 | copied_model = model_transformer._copy_model(original_model) 17 | assert copied_model is not original_model 18 | assert copied_model.weight.data.equal(original_model.weight.data) 19 | 20 | 21 | 22 | # a mock network 23 | class MockModel(nn.Module): 24 | def __init__(self): 25 | super(MockModel, self).__init__() 26 | self.layer1 = nn.Linear(10, 20) 27 | self.layer2 = nn.Linear(20, 30) 28 | 29 | def forward(self, x): 30 | x = self.layer1(x) 31 | x = self.layer2(x) 32 | return x 33 | 34 | 35 | # case 3: test quantize_model method 36 | def test_quantize_model(): 37 | model = MockModel() 38 | bottleneck_modules = ["layer2"] 39 | module_names = {model.layer1: "layer1", model.layer2: "layer2"} 40 | transformer = ModelTransformer() 41 | quantized_model = transformer.quantize_model(model, bottleneck_modules, module_names) 42 | 43 | assert quantized_model.layer1.weight.dtype == torch.float16 44 | assert quantized_model.layer2.weight.dtype == torch.float16 45 | 46 | 47 | 48 | 49 | # a mock module 50 | class SimpleModule(nn.Module): 51 | def __init__(self): 52 | super(SimpleModule, self).__init__() 53 | self.param = nn.Parameter(torch.randn(2, 3)) 54 | self.buffer = nn.Parameter(torch.randn(2, 3)) 55 | 56 | def forward(self, x): 57 | return x + self.param + self.buffer 58 | 59 | 60 | # case 3: test _convert_module_to_half method 61 | def test_convert_module_to_half(): 62 | model = SimpleModule() 63 | 64 | transformer = ModelTransformer() 65 | 66 | 
transformer._convert_module_to_half(model) 67 | 68 | assert model.param.dtype == torch.float16 69 | assert model.buffer.dtype == torch.float16 70 | 71 | x = torch.randn(2, 3) 72 | output = model(x) 73 | assert output.dtype == torch.float32 74 | 75 | for child in model.children(): 76 | assert child.param.dtype == torch.float16 77 | assert child.buffer.dtype == torch.float16 78 | 79 | 80 | # mock bottleneck 81 | class MockBottleneckModule(nn.Module): 82 | def __init__(self): 83 | super(MockBottleneckModule, self).__init__() 84 | 85 | def forward(self, inputs): 86 | if isinstance(inputs, tuple): 87 | x1, x2 = inputs 88 | elif isinstance(inputs, dict): 89 | x1 = inputs['input1'] 90 | x2 = inputs['input2'] 91 | else: 92 | x1 = inputs 93 | x2 = torch.zeros_like(x1) 94 | 95 | return x1 + x2 96 | 97 | # case 4: test _wrap_bottleneck_module method 98 | def test_wrap_bottleneck_module(): 99 | model_transformer = ModelTransformer() 100 | 101 | bottleneck_module = MockBottleneckModule() 102 | 103 | model_transformer._wrap_bottleneck_module(bottleneck_module) 104 | 105 | input_tensor1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) 106 | input_tensor2 = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32) 107 | output = bottleneck_module(input_tensor1) 108 | assert output.dtype == torch.float32 109 | 110 | output = bottleneck_module((input_tensor1, input_tensor2)) 111 | assert output.dtype == torch.float32 112 | 113 | output = bottleneck_module({"input1":input_tensor1, "input2":input_tensor2}) 114 | assert output.dtype == torch.float32 -------------------------------------------------------------------------------- /unit_test/test_quantization_tester.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from quantization_tester import QuantizationTester 6 | 7 | # Define a simple test module 8 | class MockModule(nn.Module): 9 | def __init__(self): 10 | super(MockModule, self).__init__() 11 | self.linear = nn.Linear(10, 20) 12 | 13 | def forward(self, x): 14 | return self.linear(x) 15 | 16 | @pytest.fixture 17 | def tester(): 18 | return QuantizationTester() 19 | 20 | @pytest.fixture 21 | def test_module(): 22 | return MockModule() 23 | 24 | # case 1: init 25 | def test_init_normal(): 26 | tester = QuantizationTester(error_metric='mse') 27 | assert tester.error_metric == 'mse' 28 | assert not tester.module_errors 29 | 30 | # case 2: test test_modules method 31 | def test_test_modules(): 32 | module = MockModule() 33 | 34 | tester = QuantizationTester() 35 | 36 | input_tensor = torch.randn(1, 10) 37 | output_tensor = module(input_tensor) 38 | 39 | modules = {'MockModule': module} 40 | module_io = {module: {'input': (input_tensor,), 'output': output_tensor}} 41 | module_names = {module: 'MockModule'} 42 | 43 | errors = tester.test_modules(modules, module_io, module_names) 44 | assert errors 45 | assert 'MockModule' in errors 46 | 47 | 48 | # case 3: test convert_module_to_half method 49 | def test_convert_module_to_half(tester, test_module): 50 | original_params = {} 51 | original_buffers = {} 52 | 53 | for name, param in test_module.named_parameters(): 54 | original_params[name] = param.data.clone() 55 | for name, buffer in test_module.named_buffers(): 56 | original_buffers[name] = buffer.data.clone() 57 | 58 | tester._convert_module_to_half(test_module) 59 | 60 | for name, param in test_module.named_parameters(): 61 | assert param.dtype == torch.float16 62 | 63 | for name, buffer in 
test_module.named_buffers(): 64 | assert buffer.dtype == torch.float16 65 | 66 | for name, param in test_module.named_parameters(): 67 | param.data = original_params[name] 68 | for name, buffer in test_module.named_buffers(): 69 | buffer.data = original_buffers[name] 70 | 71 | 72 | # case 4: test _compute_error method for single tensor 73 | def test_compute_tensor_error_single_tensor(): 74 | tester = QuantizationTester() 75 | output1 = torch.randn(2, 3) 76 | output2 = torch.randn(2, 3) 77 | error = tester._compute_error(output1, output2) 78 | assert isinstance(error, float) 79 | 80 | # case 5: test _compute_error method for tuple 81 | def test_compute_tensor_error_tuple_output(): 82 | tester = QuantizationTester() 83 | output1 = (torch.randn(2, 3), torch.randn(2, 3)) 84 | output2 = (torch.randn(2, 3), torch.randn(2, 3)) 85 | error = tester._compute_error(output1, output2) 86 | assert isinstance(error, float) 87 | 88 | # case 6: test _compute_error method for unsupported type 89 | def test_compute_tensor_error_unsupported_type(): 90 | tester = QuantizationTester() 91 | output1 = "unsupported_type" 92 | output2 = "unsupported_type" 93 | error = tester._compute_error(output1, output2) 94 | assert math.isinf(error) 95 | 96 | # case 7: test _compute_tensor_error method with MSE loss 97 | def test_compute_tensor_error_mse(): 98 | tester = QuantizationTester(error_metric='mse') 99 | tensor1 = torch.tensor([1., 2., 3.]) 100 | tensor2 = torch.tensor([2., 3., 4.]) 101 | error = tester._compute_tensor_error(tensor1, tensor2) 102 | assert error == pytest.approx(1.) 103 | 104 | 105 | # case 8: test _compute_tensor_error method with MAE loss 106 | def test_compute_tensor_error_mae(): 107 | tester = QuantizationTester(error_metric='mae') 108 | tensor1 = torch.tensor([1.0, 2.0, 3.0]) 109 | tensor2 = torch.tensor([1.1, 2.1, 2.9]) 110 | error = tester._compute_tensor_error(tensor1, tensor2) 111 | assert error == pytest.approx(0.1) 112 | 113 | # case 9: test _compute_tensor_error method with unsupported loss 114 | def test_compute_tensor_error_invalid_metric(): 115 | tester = QuantizationTester(error_metric='invalid_metric') 116 | tensor1 = torch.tensor([1.0, 2.0, 3.0]) 117 | tensor2 = torch.tensor([1.1, 2.1, 2.9]) 118 | with pytest.raises(ValueError): 119 | tester._compute_tensor_error(tensor1, tensor2) 120 | 121 | 122 | # case 10: test _to_float32 method with single tensor 123 | def test_to_float32_single_tensor(): 124 | tester = QuantizationTester() 125 | input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) 126 | expected_output = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) 127 | output = tester._to_float32(input_tensor) 128 | assert torch.allclose(output, expected_output) 129 | 130 | # case 11: test _to_float32 method with tuple 131 | def test_to_float32_tuple_tensors(): 132 | tester = QuantizationTester() 133 | input_tuple = (torch.tensor([1.0, 2.0], dtype=torch.float64), torch.tensor([3.0, 4.0], dtype=torch.float64)) 134 | expected_output_tuple = (torch.tensor([1.0, 2.0], dtype=torch.float32), torch.tensor([3.0, 4.0], dtype=torch.float32)) 135 | output_tuple = tester._to_float32(input_tuple) 136 | for output, expected_output in zip(output_tuple, expected_output_tuple): 137 | assert torch.allclose(output, expected_output) 138 | 139 | # case 12: test _to_float32 method with unsupported type 140 | def test_to_float32_other_type(): 141 | tester = QuantizationTester() 142 | input_other = "not a tensor" 143 | output = tester._to_float32(input_other) 144 | assert output == input_other 145 
| 146 | 147 | # case 13: test sort_modules_by_error method 148 | def test_quantization_tester_sort_modules_by_error(): 149 | tester = QuantizationTester() 150 | 151 | tester.module_errors = { 152 | 'module1': 0.1, 153 | 'module2': 0.3, 154 | 'module3': 0.2 155 | } 156 | 157 | # descending order 158 | sorted_modules_desc = tester.sort_modules_by_error(descending=True) 159 | assert sorted_modules_desc == [('module2', 0.3), ('module3', 0.2), ('module1', 0.1)] 160 | 161 | # ascending order 162 | sorted_modules_asc = tester.sort_modules_by_error(descending=False) 163 | assert sorted_modules_asc == [('module1', 0.1), ('module3', 0.2), ('module2', 0.3)] 164 | 165 | # empty dictionary 166 | tester.module_errors = {} 167 | sorted_modules_empty = tester.sort_modules_by_error() 168 | assert sorted_modules_empty == [] 169 | --------------------------------------------------------------------------------