├── src ├── __init__.py ├── imclaslib │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── video_predict_dataset.py │ │ ├── images_predict_dataset.py │ │ ├── datasetutils.py │ │ └── image_dataset.py │ ├── files │ │ ├── __init__.py │ │ ├── pathutils.py │ │ ├── modelloadingutils.py │ │ └── imageutils.py │ ├── logging │ │ ├── __init__.py │ │ └── loggerfactory.py │ ├── metrics │ │ ├── __init__.py │ │ └── metricutils.py │ ├── models │ │ ├── __init__.py │ │ ├── multilabel_dice_loss.py │ │ ├── modelutils.py │ │ ├── multilabel_classifier.py │ │ ├── multilabel_focal_loss.py │ │ ├── multilabel_embeddinglayer_model.py │ │ ├── model_layers.py │ │ ├── gcn_classifier.py │ │ ├── ensemble_classifier.py │ │ └── modelfactory.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── modelevaluator.py │ │ └── test_model.py │ ├── tensorboard │ │ ├── __init__.py │ │ └── tensorboardwriter.py │ ├── training │ │ ├── __init__.py │ │ ├── train_model.py │ │ ├── distill_model.py │ │ └── modeltrainer.py │ ├── wandb │ │ └── wandb_writer.py │ └── config.py ├── train.py ├── test.py ├── train_many_models.py ├── default_config-example.yml ├── test_many_models.py ├── distill.py ├── computemean.py ├── generate_edge_indexes.py └── inference.py ├── .gitattributes ├── environment.yml ├── .gitignore ├── Dataset └── analyzeData.py ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/files/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/logging/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/tensorboard/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/imclaslib/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: multilabelimage_model_env 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.11 9 | - pytorch 10 | - torchvision 11 | - torchaudio 12 | - pytorch-cuda=12.1 13 | - opencv 14 | - pandas 15 | - scikit-learn=1.4.0 16 | - wandb 17 | - matplotlib 18 | - tqdm 19 | - pillow 20 | - numpy 21 | - scipy 22 | - pyyaml 23 | - pip 24 | - pip: 25 | - torch-summary 26 | - tensorboard 27 | - torch-tb-profiler 28 | - torch-geometric 29 | - timm 30 | -------------------------------------------------------------------------------- /src/imclaslib/models/multilabel_dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DiceLoss(nn.Module): 6 | def __init__(self, smooth=1.0): 7 | super(DiceLoss, self).__init__() 8 | self.smooth = smooth 9 | 10 | def forward(self, inputs, targets): 11 | # Apply sigmoid to convert logits into predicted probabilities 12 | inputs = torch.sigmoid(inputs) 13 | 14 | # Calculate intersection and union 15 | intersection = (inputs * targets).sum(dim=1) 16 | union = inputs.sum(dim=1) + targets.sum(dim=1) 17 | 18 | # Dice coefficient 19 | dice = (2. * intersection + self.smooth) / (union + self.smooth) 20 | 21 | # Dice loss 22 | dice_loss = 1 - dice 23 | 24 | # Average Dice loss over the batch 25 | return dice_loss.mean()
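A quick numeric check of the soft-Dice computation above, using the `DiceLoss` class as defined (the tensor values are illustrative):

```python
import torch

criterion = DiceLoss(smooth=1.0)
logits = torch.tensor([[4.0, -4.0, 4.0]])    # sigmoid ~ [0.982, 0.018, 0.982]
targets = torch.tensor([[1.0, 0.0, 1.0]])

# intersection ~ 1.964, union ~ 3.982
# dice ~ (2 * 1.964 + 1.0) / (3.982 + 1.0) ~ 0.989, so loss ~ 0.011
print(criterion(logits, targets))  # tensor(0.0108)
```

Note that the sums run over dim=1, so the Dice coefficient is computed per sample across all labels, then averaged over the batch.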
-------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from imclaslib.config import Config 2 | import imclaslib.files.pathutils as pathutils 3 | from imclaslib.training.train_model import train_model 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | from imclaslib.wandb.wandb_writer import WandbWriter 6 | 7 | config = Config("default_config.yml") 8 | # Set up logging for the training process 9 | logger = LoggerFactory.setup_logging("logger", log_file=pathutils.combine_path(config, 10 | pathutils.get_log_dir_path(config), 11 | f"{config.model_name}_{config.model_image_size}_{config.model_weights}", 12 | f"train__{pathutils.get_datetime()}.log"), config=config) 13 | 14 | def main(): 15 | # Call the train_model function with the configuration object 16 | with WandbWriter(config) as wandb_writer: 17 | train_model(config, wandbWriter=wandb_writer) 18 | 19 | if __name__ == '__main__': 20 | main() -------------------------------------------------------------------------------- /src/imclaslib/models/modelutils.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import ReduceLROnPlateau 2 | from imclaslib.logging.loggerfactory import LoggerFactory 3 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 4 | 5 | def get_learningRate_scheduler(optimizer, config): 6 | """ 7 | Creates a learning rate scheduler that reduces the learning rate during training. 8 | 9 | Parameters: 10 | optimizer (torch.optim.Optimizer): The optimizer whose learning rate will be scheduled. 11 | config (Config): Configuration object containing the learning rate reducer parameters. 12 | 13 | Returns: 14 | torch.optim.lr_scheduler.ReduceLROnPlateau: The learning rate reducer 15 | """ 16 | return ReduceLROnPlateau(optimizer, mode='max', factor=config.train_learningrate_reducer_factor, patience=config.train_learningrate_reducer_patience, threshold=config.train_learningrate_reducer_threshold, min_lr=config.train_learningrate_reducer_min_lr) -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | from imclaslib.config import Config 2 | import imclaslib.files.pathutils as pathutils 3 | from imclaslib.evaluation.test_model import evaluate_model 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | from imclaslib.wandb.wandb_writer import WandbWriter 6 | # Set up logging for the evaluation process 7 | config = Config("default_config.yml") 8 | logger = LoggerFactory.setup_logging("logger", log_file=pathutils.combine_path(config, 9 | pathutils.get_log_dir_path(config), 10 | f"{config.model_name}_{config.model_image_size}_{config.model_weights}", 11 | f"test__{pathutils.get_datetime()}.log"), config=config) 12 | 13 | 14 | def main(): 15 | # Call the evaluate_model function with the configuration object 16 | with WandbWriter(config) as wandb_writer: 17 | evaluate_model(config, wandbWriter=wandb_writer) 18 | 19 | if __name__ == '__main__': 20 | main() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # This .gitignore file was automatically created by Microsoft(R) Visual Studio. 
3 | ################################################################################ 4 | 5 | /.vs 6 | /inference_inputs 7 | /inference_inputs2 8 | /outputs 9 | /tensorboard_logs 10 | /ven 11 | /env 12 | /logs 13 | /MultiClassImageClassification.zip 14 | /src/__pycache__ 15 | /src/wandb/ 16 | __pycache__/ 17 | Dataset/dataset*.csv 18 | /inference_outputs 19 | Dataset/graph_commons.csv 20 | tags*.txt 21 | src/train_many_models*.json 22 | src/test_many_models*.json 23 | default_config*.yml 24 | src/default_config.yml 25 | Dataset/analyzeData.py 26 | Dataset/removeRemoved.py 27 | src/train_many_models.yml 28 | src/test_many_models*.yml 29 | environment_linux.yml 30 | annotationresults*.csv 31 | settings.json 32 | inference.yml 33 | src/distill_models.yml 34 | -------------------------------------------------------------------------------- /src/imclaslib/models/multilabel_classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class MultiLabelClassifier(nn.Module): 5 | def __init__(self, base_model, num_classes, dropout_prob): 6 | super().__init__() 7 | self.base_model = base_model 8 | self.num_classes = num_classes 9 | 10 | # Assuming base_model outputs features of size (batch_size, feature_dim) 11 | self.classifier = nn.Sequential( 12 | nn.Dropout(dropout_prob), 13 | nn.Linear(base_model.output_dim, num_classes) 14 | ) 15 | self.output_dim = num_classes # Set the output dimension 16 | 17 | def forward(self, x): 18 | # Get the image features from the base model 19 | image_features = self.base_model(x) # [batch_size, feature_dim] 20 | 21 | # Pass the image features through the classifier 22 | logits = self.classifier(image_features) # [batch_size, num_classes] 23 | 24 | return logits -------------------------------------------------------------------------------- /Dataset/analyzeData.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | def analyze_csv(csv_file_path): 4 | # Initialize dictionaries to store annotation counts and file counts 5 | annotation_counts = {} 6 | file_counts = {'with_annotations': 0, 'without_annotations': 0} 7 | 8 | with open(csv_file_path, 'r', newline='') as csvfile: 9 | reader = csv.DictReader(csvfile) 10 | 11 | # Iterate through each row in the CSV 12 | for row in reader: 13 | file_name = row['filepath'] 14 | 15 | # Count files without any annotations 16 | if all(value == '0' for key, value in row.items() if key != 'filepath'): 17 | file_counts['without_annotations'] += 1 18 | else: 19 | file_counts['with_annotations'] += 1 20 | 21 | # Count the usage of each annotation 22 | for annotation_name, annotation_value in row.items(): 23 | if annotation_name != 'filepath': 24 | annotation_counts[annotation_name] = annotation_counts.get(annotation_name, 0) + int(annotation_value) 25 | 26 | return annotation_counts, file_counts 27 | 28 | # Replace 'your_file.csv' with the actual path to your CSV file 29 | csv_file_path = 'Dataset/datasetV2.2.csv' 30 | annotation_counts, file_counts = analyze_csv(csv_file_path) 31 | 32 | # Print results 33 | print("Annotation Counts:") 34 | for annotation, count in annotation_counts.items(): 35 | print(f"{annotation}: {count}") 36 | 37 | print("\nFile Counts:") 38 | print(f"Files with annotations: {file_counts['with_annotations']}") 39 | print(f"Files without annotations: {file_counts['without_annotations']}") 40 | 
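For reference, a minimal sketch of the CSV layout `analyze_csv` expects and the counts it produces (the tag columns `cat` and `dog` are illustrative):

```python
# Same counting logic as analyze_csv, run on two in-memory rows.
rows = [
    {"filepath": "/imgs/a.jpg", "cat": "1", "dog": "0"},
    {"filepath": "/imgs/b.jpg", "cat": "0", "dog": "0"},
]
annotation_counts = {}
file_counts = {"with_annotations": 0, "without_annotations": 0}
for row in rows:
    if all(value == "0" for key, value in row.items() if key != "filepath"):
        file_counts["without_annotations"] += 1
    else:
        file_counts["with_annotations"] += 1
    for name, value in row.items():
        if name != "filepath":
            annotation_counts[name] = annotation_counts.get(name, 0) + int(value)

print(annotation_counts)  # {'cat': 1, 'dog': 0}
print(file_counts)        # {'with_annotations': 1, 'without_annotations': 1}
```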
-------------------------------------------------------------------------------- /src/train_many_models.py: -------------------------------------------------------------------------------- 1 | from imclaslib.config import Config 2 | import imclaslib.files.pathutils as pathutils 3 | from imclaslib.training.train_model import train_model 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | from imclaslib.wandb.wandb_writer import WandbWriter 6 | 7 | # Set up logging for the training process 8 | config = Config("default_config.yml") 9 | #TODO: CLEAN UP THIS. ITS REUSED IN A TON OF PLACES 10 | logger = LoggerFactory.setup_logging("logger", config, log_file=pathutils.combine_path(config, 11 | pathutils.get_log_dir_path(config), 12 | f"{config.model_name}_{config.model_image_size}_{config.model_weights}", 13 | f"train__{pathutils.get_datetime()}.log")) 14 | 15 | 16 | def main(json_file_path): 17 | """ 18 | Train multiple models as per configurations provided in the JSON file. 19 | 20 | Parameters: 21 | - json_file_path: str, the path to the JSON file containing the configurations 22 | """ 23 | configs = Config.load_configs_from_file(json_file_path, config) 24 | for config_instance in configs: 25 | logger.info(f"Starting training for model: {config_instance.model_name}, image size: {config_instance.model_image_size}, dropout: {config_instance.train_dropout_prob}, weights: {config_instance.model_weights}, l2: {config_instance.train_l2_enabled}, fp16: {config_instance.model_fp16}, dataset version: {config_instance.dataset_version}") 26 | with WandbWriter(config_instance) as wandb_writer: 27 | train_model(config_instance, wandb_writer) 28 | 29 | if __name__ == '__main__': 30 | # Get the path to the JSON file containing the model configurations 31 | json_file_path = pathutils.get_train_many_models_file(config) 32 | main(json_file_path) -------------------------------------------------------------------------------- /src/default_config-example.yml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'regnet_y_16gf' 3 | weights: 'IMAGENET1K_SWAG_E2E_V1' 4 | image_size: 400 5 | num_classes: 36 6 | #weights: 'DEFAULT' 7 | folder: '' #specify your path here 8 | tags_path: '' #specify your path here 9 | name_to_load: 'best_model' 10 | fp16: true 11 | dataset: 12 | path: '' #specify your path here 13 | version: 1.0 14 | augmentation_level: 0 15 | normalization_mean: [0.5712, 0.4717, 0.4267] #use the values from computemean.py 16 | normalization_std: [0.2684, 0.2562, 0.2569] #use the values from computemean.py 17 | train_percentage: 80 18 | valid_percentage: 10 19 | test_percentage: 10 20 | preprocess_to_RAM: false 21 | train: 22 | batch_size: 24 23 | dropout_prob: 0 24 | learning_rate: 1.0e-4 25 | num_epochs: 50 26 | continue_training: false #if this is set then make sure best_model is the same model type or set another value for model_name_to_load 27 | requires_grad: true 28 | store_gradients_epoch_interval: 5 29 | check_test_loss_epoch_interval: 10 30 | many_models_path: '' #specify your path here 31 | model_to_load_raw_weights: '' 32 | early_stopping: 33 | patience: 6 34 | threshold: 4.0e-3 35 | learningrate_reducer: 36 | patience: 2 37 | threshold: 2.0e-3 38 | factor: 0.1 39 | min_lr: 1.0e-7 40 | l2: 41 | enabled: true 42 | lambda: 0.0001 43 | label_smoothing: 0.1 44 | test: 45 | batch_size: 128 46 | many_models_path: '' #specify your path here 47 | 48 | logs: 49 | level: 'DEBUG' 50 | folder: '' #specify your path 
here 51 | project_name: '' #your wandb project name 52 | using_wsl: false #set to true if you're running in WSL and want to use windows paths from inside -------------------------------------------------------------------------------- /src/test_many_models.py: -------------------------------------------------------------------------------- 1 | from imclaslib.config import Config 2 | import imclaslib.files.pathutils as pathutils 3 | from imclaslib.evaluation.test_model import evaluate_model 4 | from imclaslib.files.modelloadingutils import update_config_from_model_file 5 | from imclaslib.logging.loggerfactory import LoggerFactory 6 | from imclaslib.wandb.wandb_writer import WandbWriter 7 | 8 | # Set up logging for the evaluation process 9 | config = Config("default_config.yml") 10 | logger = LoggerFactory.setup_logging("logger", config, log_file=pathutils.combine_path(config, 11 | pathutils.get_log_dir_path(config), 12 | f"{config.model_name}_{config.model_image_size}_{config.model_weights}", 13 | f"test__{pathutils.get_datetime()}.log")) 14 | 15 | 16 | def main(json_file_path): 17 | """ 18 | Evaluate multiple models as per configurations provided in the JSON file. 19 | 20 | Parameters: 21 | - json_file_path: str, the path to the JSON file containing the configurations 22 | """ 23 | configs = Config.load_configs_from_file(json_file_path, config) 24 | for config_instance in configs: 25 | if not config_instance.model_ensemble_model_configs: 26 | update_config_from_model_file(config_instance) 27 | logger.info(f"Starting evaluation for model: {config_instance.model_name}, image size: {config_instance.model_image_size}, weights: {config_instance.model_weights}") 28 | try: 29 | with WandbWriter(config_instance) as wandb_writer: 30 | evaluate_model(config_instance, wandbWriter=wandb_writer) 31 | except Exception as e: 32 | logger.error(f"Failed testing for model: {config_instance.model_name}, image size: {config_instance.model_image_size}, weights: {config_instance.model_weights} Inner:{e}") 33 | 34 | if __name__ == '__main__': 35 | # Get the path to the JSON file containing the model configurations 36 | json_file_path = pathutils.get_test_many_models_file(config) 37 | main(json_file_path)
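The many-models JSON files themselves are gitignored, and their exact schema is defined by `Config.load_configs_from_file` (not shown in this listing). As a rough sketch, each entry supplies per-run overrides of the base `default_config.yml`; all keys and values below are hypothetical:

```json
[
  { "model": { "name": "resnet50", "image_size": 224, "weights": "DEFAULT" } },
  { "model": { "name": "regnet_y_16gf", "image_size": 400, "weights": "IMAGENET1K_SWAG_E2E_V1" } }
]
```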
-------------------------------------------------------------------------------- /src/distill.py: -------------------------------------------------------------------------------- 1 | from imclaslib.config import Config 2 | import imclaslib.files.pathutils as pathutils 3 | from imclaslib.training.distill_model import distill_model 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | from imclaslib.wandb.wandb_writer import WandbWriter 6 | 7 | # Set up logging for the distillation process 8 | config = Config("default_config.yml") 9 | #TODO: CLEAN UP THIS. ITS REUSED IN A TON OF PLACES 10 | logger = LoggerFactory.setup_logging("logger", config, log_file=pathutils.combine_path(config, 11 | pathutils.get_log_dir_path(config), 12 | f"{config.model_name}_{config.model_image_size}_{config.model_weights}", 13 | f"distill__{pathutils.get_datetime()}.log")) 14 | 15 | 16 | def main(json_file_path): 17 | """ 18 | Distill a student model from a teacher model, as per the configurations provided in the JSON file. 19 | 20 | Parameters: 21 | - json_file_path: str, the path to the JSON file containing the configurations 22 | """ 23 | configs = Config.load_configs_from_file(json_file_path, config) 24 | teacher_config = configs[0] 25 | student_config = configs[1] 26 | logger.info(teacher_config.model_name) 27 | logger.info(student_config.model_name) 28 | logger.info(f"Starting distillation for model: {student_config.model_name}, image size: {student_config.model_image_size}, dropout: {student_config.train_dropout_prob}, weights: {student_config.model_weights}, l2: {student_config.train_l2_enabled}, fp16: {student_config.model_fp16}, dataset version: {student_config.dataset_version}") 29 | with WandbWriter(student_config) as wandb_writer: 30 | try: 31 | distill_model(teacher_config, student_config, wandb_writer) 32 | except Exception as e: 33 | logger.error('Error at %s', 'distill', exc_info=e) 34 | raise e 35 | 36 | if __name__ == '__main__': 37 | # Get the path to the JSON file containing the model configurations 38 | json_file_path = pathutils.get_distill_models_file(config) 39 | main(json_file_path) -------------------------------------------------------------------------------- /src/imclaslib/dataset/video_predict_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torchvision.transforms as transforms 3 | import cv2 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 6 | 7 | class VideoDatasetPredict(Dataset): 8 | """Custom dataset for sampling frames from a video at a fixed time interval.""" 9 | def __init__(self, video_path, time_interval, config): 10 | self.cap = cv2.VideoCapture(video_path) 11 | self.fps = self.cap.get(cv2.CAP_PROP_FPS) 12 | self.frame_interval = int(self.fps * time_interval) 13 | self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 14 | self.image_paths = video_path 15 | self.config = config 16 | self.transform = VideoDatasetPredict.test_transforms(self.config) 17 | 18 | def __len__(self): 19 | return self.total_frames // self.frame_interval 20 | 21 | def __del__(self): 22 | # Release the video capture object 23 | self.cap.release() 24 | 25 | @staticmethod 26 | def test_transforms(config): 27 | return transforms.Compose([ 28 | transforms.ToPILImage(), 29 | transforms.Resize((config.model_image_size, config.model_image_size)), 30 | transforms.ToTensor(), 31 | transforms.Normalize(mean=config.dataset_normalization_mean, std=config.dataset_normalization_std), 32 | ]) 33 | 34 | def __getitem__(self, idx): 35 | # Calculate the actual frame index 36 | frame_idx = idx * self.frame_interval 37 | 38 | # Set the video capture to the correct frame 39 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) 40 | ret, frame = self.cap.read() 41 | 42 | # Check if the frame was read correctly 43 | if not ret: 44 | logger.warning(f"Frame at index {frame_idx} could not be read") 45 | return None 46 | 47 | # Convert the frame from BGR to RGB 48 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 49 | 50 | frame = self.transform(frame) 51 | 52 | return { 53 | 'image': frame, 54 | 'frame_count': frame_idx 55 | } 56 | 57 | 58 | 
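A hedged usage sketch for `VideoDatasetPredict` (the clip path and `config` are placeholders; `config` must provide `model_image_size` and the normalization stats):

```python
from torch.utils.data import DataLoader

# With fps = 30 and time_interval = 2.0, frame_interval = int(30 * 2.0) = 60,
# so a 600-frame clip yields 600 // 60 = 10 samples.
dataset = VideoDatasetPredict("clip.mp4", time_interval=2.0, config=config)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for batch in loader:
    frames = batch['image']           # [B, 3, image_size, image_size]
    frame_ids = batch['frame_count']  # original frame indices in the video
```

Note that `__getitem__` returns `None` for unreadable frames, which the default collate function cannot stack; a custom `collate_fn` that filters out `None` entries is needed if the source video may contain corrupt frames.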
-------------------------------------------------------------------------------- /src/computemean.py: -------------------------------------------------------------------------------- 1 | import imclaslib.files.pathutils as pathutils 2 | import imclaslib.dataset.datasetutils as datasetutils 3 | import torch 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | from imclaslib.config import Config 6 | 7 | config = Config("default_config.yml") 8 | # Set up logging for the dataset statistics computation 9 | logger = LoggerFactory.setup_logging("logger", config, log_file=pathutils.combine_path(config, 10 | pathutils.get_log_dir_path(config), 11 | f"CalculateDatasetMeanStd", 12 | f"{pathutils.get_datetime()}.log")) 13 | 14 | def compute_mean_std(dataloader): 15 | channels_sum, channels_squared_sum, total_images = 0, 0, 0 16 | 17 | for data in dataloader: 18 | images = data.get('image') 19 | 20 | if images is None: 21 | # Skip corrupted or missing images 22 | continue 23 | 24 | if not isinstance(images, torch.Tensor): 25 | raise TypeError(f"Expected images to be a torch.Tensor but got {type(images)}") 26 | if not images.is_floating_point(): 27 | images = images.float() # Convert images to float if they're not already 28 | 29 | # Rearrange batch to be the shape of [B, C, W * H] 30 | images = images.view(images.size(0), images.size(1), -1) 31 | 32 | # Update total sum and squared sum 33 | channels_sum += images.mean(dim=[0, 2]) * images.size(0) 34 | channels_squared_sum += (images ** 2).mean(dim=[0, 2]) * images.size(0) 35 | total_images += images.size(0) 36 | 37 | # Compute mean and std via std = sqrt(E[x^2] - E[x]^2) 38 | mean = channels_sum / total_images 39 | std = (channels_squared_sum / total_images - mean ** 2) ** 0.5 40 | 41 | return mean, std 42 | if __name__ == '__main__': 43 | # DataLoader for your dataset 44 | config.dataset_normalization_mean = None 45 | config.dataset_normalization_std = None 46 | dataloader = datasetutils.get_data_loader_by_name('all', config, shuffle=True) 47 | 48 | try: 49 | # Calculate mean and std 50 | mean, std = compute_mean_std(dataloader) 51 | logger.info(f'Mean: {mean}') 52 | logger.info(f'Std: {std}') 53 | except Exception as e: 54 | logger.error(f'An error occurred during computation: {e}') -------------------------------------------------------------------------------- /src/imclaslib/dataset/images_predict_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from torchvision.transforms import v2 as V2 4 | import cv2 5 | from imclaslib.logging.loggerfactory import LoggerFactory 6 | from PIL import Image 7 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 8 | 9 | class ImageDatasetPredict(Dataset): 10 | """Custom dataset for loading images from a list of image paths.""" 11 | def __init__(self, image_paths, config): 12 | self.image_paths = image_paths 13 | self.config = config 14 | self.preprocess_fn = V2.Compose([ 15 | V2.ToImage(), 16 | V2.Resize((config.model_image_size, config.model_image_size)), 17 | V2.ToDtype(torch.float32, scale=True), 18 | V2.Normalize(mean=config.dataset_normalization_mean, std=config.dataset_normalization_std), 19 | ]) 20 | 21 | def __len__(self): 22 | return len(self.image_paths) 23 | 24 | def preprocess_single_image(self, image_path): 25 | try: 26 | image = Image.open(image_path).convert('RGB') 27 | except Exception: 28 | image = None 29 | if image is None: 30 | logger.warning(f"Image not found or corrupted at path: {image_path}") 31 | return None 32 | # apply image transforms 33 | image = self.preprocess_fn(image) 34 | return image 35 | 36 | def __getitem__(self, idx): 37 | image_path = self.image_paths[idx] 38 | image = self.preprocess_single_image(image_path) 39 | 40 | if image is None: 41 | # Log that we're using a placeholder for a specific image 42 | 
#logger.warning(f"Using placeholder for missing or corrupted image at path: {image_path}") 43 | 44 | # Create a placeholder tensor (e.g., a tensor of zeros) 45 | # The shape should match your model's input size, e.g., (C, H, W) 46 | C, H, W = 3, self.config.model_image_size, self.config.model_image_size 47 | image = torch.zeros((C, H, W), dtype=torch.float32) 48 | image_path = "INVALID:" + image_path 49 | 50 | return { 51 | 'image': image, 52 | 'image_path': image_path 53 | } 54 | 55 | -------------------------------------------------------------------------------- /src/imclaslib/logging/loggerfactory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | import os 4 | 5 | class LoggerFactory: 6 | DEFAULT_LOG_LEVEL = logging.INFO 7 | LOG_FILE_MAX_BYTES = 10 * 1024 * 1024 # 10 MB 8 | LOG_FILE_BACKUP_COUNT = 5 # Keep 5 backup files 9 | LONG_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 10 | SHORT_LOG_FORMAT = "%(levelname)s: %(message)s" 11 | DATE_FORMAT = "%Y-%m-%d %H:%M:%S" 12 | 13 | @staticmethod 14 | def setup_logging(loggername, config, log_file=None, level=None): 15 | """ 16 | Set up logging configuration for a logger with the specified name. 17 | 18 | Parameters: 19 | loggername (str): The name of the logger to set up. 20 | config (Config): The configuration object with a 'logs_level' attribute. 21 | log_file (str): The path to the log file. If None, logs to stdout only. 22 | level (int): The logging level. If None, defaults to the level specified in config. 23 | 24 | Returns: 25 | logging.Logger: Configured logger instance. 26 | """ 27 | if level is None: 28 | level = getattr(logging, config.logs_level, LoggerFactory.DEFAULT_LOG_LEVEL) 29 | 30 | # Since we are setting up handlers individually, we don't use basicConfig 31 | logger = logging.getLogger(loggername) 32 | logger.setLevel(level) 33 | 34 | console_handler = logging.StreamHandler() 35 | console_handler.setFormatter(logging.Formatter(LoggerFactory.SHORT_LOG_FORMAT)) 36 | logger.addHandler(console_handler) 37 | 38 | if log_file is not None: 39 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 40 | file_handler = logging.handlers.RotatingFileHandler( 41 | log_file, maxBytes=LoggerFactory.LOG_FILE_MAX_BYTES, backupCount=LoggerFactory.LOG_FILE_BACKUP_COUNT) 42 | file_handler.setFormatter(logging.Formatter(LoggerFactory.LONG_LOG_FORMAT, LoggerFactory.DATE_FORMAT)) 43 | logger.addHandler(file_handler) 44 | 45 | return logger 46 | 47 | @staticmethod 48 | def get_logger(name): 49 | """ 50 | Get a logger with the specified name. 51 | 52 | Parameters: 53 | name (str): The name of the logger to retrieve. 54 | 55 | Returns: 56 | logging.Logger: The logger instance with the given name. 57 | """ 58 | return logging.getLogger(name) -------------------------------------------------------------------------------- /src/imclaslib/models/multilabel_focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MultiLabelFocalLoss(nn.Module): 6 | def __init__(self, alpha=0.7, gamma=1.2, reduction='mean'): 7 | """ 8 | Focal loss for multilabel classification. 9 | Args: 10 | alpha (float, optional): Weighting factor for the positive (rare) class. Defaults to 0.7. 11 | gamma (float, optional): Focusing parameter that down-weights easy examples. Defaults to 1.2. 
12 | reduction (str, optional): Specifies the reduction type: 'none' | 'mean' | 'sum'. Defaults to 'mean'. 13 | """ 14 | super(MultiLabelFocalLoss, self).__init__() 15 | self.alpha = alpha 16 | self.gamma = gamma 17 | self.reduction = reduction 18 | 19 | def forward(self, inputs, targets): 20 | """ 21 | Compute the focal loss given the model output (logits) and the ground truth labels. 22 | Args: 23 | inputs (torch.Tensor): Logits output by the model (before sigmoid). 24 | targets (torch.Tensor): Ground truth binary labels (same shape as inputs). 25 | 26 | Returns: 27 | torch.Tensor: Computed focal loss. 28 | """ 29 | # Ensure the inputs and targets are the same size 30 | if inputs.size() != targets.size(): 31 | raise ValueError(f"Target size ({targets.size()}) must be the same as input size ({inputs.size()})") 32 | 33 | # Calculate the binary cross-entropy loss without reduction 34 | bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 35 | 36 | # Calculate probabilities 37 | probs = torch.sigmoid(inputs) 38 | # Calculate the modulating factor. For the positive class (targets == 1), this is (1 - p_t)**gamma. 39 | # For the negative class (targets == 0), this is (p_t)**gamma. 40 | modulating_factor = (1 - targets) * probs.pow(self.gamma) + targets * (1 - probs).pow(self.gamma) 41 | 42 | # Apply the alpha weighting 43 | alpha_weight = targets * self.alpha + (1 - targets) * (1 - self.alpha) 44 | 45 | # Compute the focal loss 46 | focal_loss = alpha_weight * modulating_factor * bce_loss 47 | 48 | # Apply the desired reduction 49 | if self.reduction == 'mean': 50 | return torch.mean(focal_loss) 51 | elif self.reduction == 'sum': 52 | return torch.sum(focal_loss) 53 | elif self.reduction == 'none': 54 | return focal_loss 55 | else: 56 | raise ValueError(f"Invalid reduction type '{self.reduction}'. 
Expected 'none', 'mean', or 'sum'.") -------------------------------------------------------------------------------- /src/imclaslib/models/multilabel_embeddinglayer_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from imclaslib.models.model_layers import Attention 4 | 5 | class MultiLabelClassifier_LabelEmbeddings(nn.Module): 6 | def __init__(self, base_model, num_classes, embedding_dim, dropout_prob): 7 | super().__init__() 8 | self.base_model = base_model 9 | self.num_classes = num_classes 10 | self.embedding_dim = embedding_dim 11 | 12 | # Embedding layer for labels 13 | self.label_embedding = nn.Embedding(num_classes, embedding_dim) 14 | 15 | # Assuming base_model outputs features of size (batch_size, feature_dim) 16 | self.feature_transform = nn.Linear(base_model.output_dim, embedding_dim) 17 | 18 | # Batch normalization layer after feature transformation 19 | #self.batch_norm = nn.BatchNorm1d(embedding_dim) 20 | 21 | # Classifier head, which maps the concatenated embeddings to the output space 22 | self.classifier = nn.Linear(embedding_dim, num_classes) 23 | 24 | # Dropout layer 25 | self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0.0 else nn.Identity() 26 | 27 | # Attention layer 28 | self.attention = Attention(embedding_dim, embedding_dim) 29 | 30 | def forward(self, x, labels=None): 31 | # Get the image features from the base model 32 | image_features = self.base_model(x) # [batch_size, feature_dim] 33 | 34 | # Transform image features to the same dimension as label embeddings 35 | transformed_image_features = self.feature_transform(image_features) # [batch_size, embedding_dim] 36 | 37 | # Apply batch normalization 38 | #transformed_image_features = self.batch_norm(transformed_image_features) 39 | transformed_image_features = self.dropout(transformed_image_features) 40 | 41 | if labels is not None: 42 | # During training, use the one-hot encoded labels to compute the label embeddings 43 | label_embeddings = torch.matmul(labels, self.label_embedding.weight) # [batch_size, embedding_dim] 44 | else: 45 | # We don't need to unsqueeze and squeeze since Attention now expects 2D tensors 46 | attention_output = self.attention(transformed_image_features, self.label_embedding.weight) 47 | label_embeddings = attention_output # This is now [batch_size, embedding_dim] 48 | 49 | # Combine the image features with the label embeddings 50 | combined_features = transformed_image_features + label_embeddings 51 | 52 | # Pass the combined features through the classifier 53 | logits = self.classifier(combined_features) # [batch_size, num_classes] 54 | return logits
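To make the two forward paths above concrete, here is a minimal sketch with a stand-in backbone (`TinyBackbone` and all dimensions are hypothetical; the wrapper only requires a base model that exposes an `output_dim` attribute):

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Stand-in feature extractor exposing the output_dim attribute the wrapper expects."""
    def __init__(self, output_dim=16):
        super().__init__()
        self.output_dim = output_dim
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(3, output_dim)

    def forward(self, x):                          # x: [B, 3, H, W]
        return self.proj(self.pool(x).flatten(1))  # [B, output_dim]

model = MultiLabelClassifier_LabelEmbeddings(TinyBackbone(), num_classes=5, embedding_dim=8, dropout_prob=0.1)
x = torch.randn(2, 3, 32, 32)
labels = torch.randint(0, 2, (2, 5)).float()       # multi-hot ground-truth labels

print(model(x, labels).shape)  # training path: label embeddings from ground truth -> [2, 5]
print(model(x).shape)          # inference path: attention over label embeddings -> [2, 5]
```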
-------------------------------------------------------------------------------- /src/generate_edge_indexes.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import torch 3 | from imclaslib.config import Config 4 | import imclaslib.files.pathutils as pathutils 5 | import imclaslib.dataset.datasetutils as datasetutils 6 | # Define edge weights for different types of edges 7 | EDGE_WEIGHTS_DICT = { 8 | 'Usually mutually exclusive': -1.0, 9 | 'Somewhat common Together': 0.5, 10 | 'Very common together': 0.8, 11 | 'Parent': 1.0, 12 | 'Somewhat Mutually Exclusive': -0.5, 13 | } 14 | config = Config("default_config.yml") 15 | # Function to read the CSV and generate edge indexes and edge weights 16 | def generate_graph_edges(csv_filename): 17 | label_id_dict = datasetutils.get_tag_to_index_mapping(config) 18 | # Create a mapping from label IDs to numerical indices 19 | label_indices = {label_id: idx for idx, label_id in enumerate(label_id_dict.values())} 20 | 21 | source_nodes = [] 22 | target_nodes = [] 23 | edge_weights_list = [] 24 | 25 | with open(csv_filename, mode='r', encoding='utf-8') as csvfile: 26 | csv_reader = csv.DictReader(csvfile, delimiter=',') 27 | 28 | for row in csv_reader: 29 | # Get the indices from the label IDs 30 | from_idx = label_indices.get(label_id_dict.get(row['From Name'])) 31 | to_idx = label_indices.get(label_id_dict.get(row['To Name'])) 32 | 33 | # Ignore if the label is not found in the dictionary 34 | if from_idx is None or to_idx is None: 35 | continue 36 | 37 | # Add the edge to the lists 38 | source_nodes.append(from_idx) 39 | target_nodes.append(to_idx) 40 | 41 | # Get the edge weight based on edge type 42 | weight = EDGE_WEIGHTS_DICT.get(row['Edge Type'], 1.0) # Default weight is 1.0 if not found 43 | edge_weights_list.append(weight) 44 | 45 | # If the edge is not of type 'Parent' and is not directional, add the reverse edge 46 | if row['Edge Type'] != 'Parent': 47 | source_nodes.append(to_idx) 48 | target_nodes.append(from_idx) 49 | edge_weights_list.append(weight) 50 | 51 | # Convert lists to PyTorch tensors 52 | edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long) 53 | edge_weights = torch.tensor(edge_weights_list, dtype=torch.float32) 54 | 55 | return edge_index, edge_weights 56 | 57 | # Replace 'graph_commons.csv' with the actual path to your CSV file 58 | edge_index, edge_weights = generate_graph_edges(pathutils.get_graph_path(config)) 59 | 60 | # Print the edge index and weights in the desired format 61 | print('edge_index =', edge_index) 62 | print('edge_weights =', edge_weights) -------------------------------------------------------------------------------- /src/imclaslib/models/model_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MultiHeadAttention(nn.Module): 6 | def __init__(self, feature_dim, embedding_dim, num_heads): 7 | super().__init__() 8 | self.num_heads = num_heads 9 | self.embedding_dim = embedding_dim 10 | self.head_dim = embedding_dim // num_heads 11 | 12 | assert self.head_dim * num_heads == self.embedding_dim, "embedding_dim must be divisible by num_heads" 13 | 14 | self.query = nn.Linear(feature_dim, embedding_dim) 15 | self.key = nn.Linear(embedding_dim, embedding_dim) 16 | self.value = nn.Linear(embedding_dim, embedding_dim) 17 | 18 | self.out = nn.Linear(embedding_dim, embedding_dim) 19 | 20 | def forward(self, features, label_embeddings): 21 | batch_size = features.shape[0] 22 | 23 | # Linear projections, split into heads: [num_heads, num_classes, head_dim] 24 | query = self.query(features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 25 | key = self.key(label_embeddings).view(-1, self.num_heads, self.head_dim).transpose(0, 1) 26 | value = self.value(label_embeddings).view(-1, self.num_heads, self.head_dim).transpose(0, 1) 27 | 28 | # Attention scores and softmax 29 | attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) 30 | attention_distribution = F.softmax(attention_scores, dim=-1) 31 | 32 | # Concatenate heads and put through final linear layer 33 | attention_output = torch.matmul(attention_distribution, value).transpose(1, 2).contiguous().view(batch_size, -1, self.embedding_dim) 34 | output = self.out(attention_output) 35 | 36 | return 
output.squeeze(1) # Ensure output is [batch_size, embedding_dim] 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, image_features_dim, label_embedding_dim): 40 | super().__init__() 41 | self.image_to_query = nn.Linear(image_features_dim, label_embedding_dim) 42 | self.softmax = nn.Softmax(dim=-1) # normalize over the class dimension 43 | 44 | def forward(self, features, label_embeddings): 45 | query = self.image_to_query(features).unsqueeze(1) # [batch_size, 1, label_embedding_dim] 46 | # Use label_embeddings as key and value 47 | key_value = label_embeddings # [num_classes, label_embedding_dim] 48 | 49 | # Compute attention scores 50 | attention_scores = torch.matmul(query, key_value.transpose(0, 1)) # [batch_size, 1, num_classes] 51 | 52 | # Apply softmax to get probabilities 53 | attention_weights = self.softmax(attention_scores) 54 | 55 | # Apply attention weights to key_value 56 | context_vector = torch.matmul(attention_weights, key_value).squeeze(1) # [batch_size, label_embedding_dim] 57 | 58 | return context_vector -------------------------------------------------------------------------------- /src/imclaslib/models/gcn_classifier.py: -------------------------------------------------------------------------------- 1 | import torch_geometric.nn as GCN 2 | import torch.nn as nn 3 | import torch 4 | 5 | from imclaslib.models.model_layers import Attention, MultiHeadAttention 6 | 7 | class GCNClassifier(nn.Module): 8 | def __init__(self, base_model, num_classes, model_gcn_model_name, dropout_prob, gcn_model_params, edge_index, edge_weight=None, use_multihead_attention=True, num_heads=4): 9 | super().__init__() 10 | self.base_model = base_model 11 | self.num_classes = num_classes 12 | 13 | # Instantiate the pre-made GCN model from PyTorch Geometric 14 | self.gcn = getattr(GCN, model_gcn_model_name)(**gcn_model_params) 15 | 16 | # Store the graph structure 17 | self.edge_index = edge_index 18 | self.edge_weight = edge_weight 19 | 20 | # Dropout layer 21 | self.dropout = nn.Dropout(dropout_prob) 22 | 23 | # Final classifier layer 24 | self.classifier = nn.Linear(base_model.output_dim + gcn_model_params['out_channels'], num_classes) 25 | 26 | # Initialize a placeholder for label embeddings, which will be learned during training 27 | self.label_embeddings = nn.Parameter(torch.Tensor(num_classes, gcn_model_params['in_channels'])) 28 | nn.init.xavier_uniform_(self.label_embeddings) 29 | 30 | # Initialize the appropriate attention mechanism 31 | if use_multihead_attention: 32 | self.attention = MultiHeadAttention(base_model.output_dim, gcn_model_params['out_channels'], num_heads) 33 | else: 34 | self.attention = Attention(base_model.output_dim, gcn_model_params['out_channels']) 35 | 36 | def forward(self, x, labels=None): 37 | # Get the image features from the base model 38 | image_features = self.base_model(x) # [batch_size, feature_dim] 39 | 40 | # Update label_embeddings using the GCN and the graph structure 41 | label_embeddings_updated = self.gcn(self.label_embeddings, self.edge_index, self.edge_weight) 42 | 43 | if self.training and labels is not None: 44 | # Use the provided labels to select the relevant embeddings for each example in the batch 45 | batch_label_embeddings = torch.matmul(labels.float(), label_embeddings_updated) 46 | else: 47 | # During inference or if labels are not provided, use the attention mechanism 48 | batch_label_embeddings = self.attention(image_features, label_embeddings_updated) 49 | 50 | # Combine the image features with the label embeddings 51 | combined_features = 
torch.cat((image_features, batch_label_embeddings), dim=1) 52 | combined_features = self.dropout(combined_features) 53 | 54 | # Pass the combined features through the classifier 55 | logits = self.classifier(combined_features) # [batch_size, num_classes] 56 | return logits -------------------------------------------------------------------------------- /src/imclaslib/models/ensemble_classifier.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from imclaslib.files import pathutils 4 | from imclaslib.files import modelloadingutils 5 | from imclaslib.metrics import metricutils 6 | import imclaslib.models.modelfactory as modelfactory 7 | 8 | class EnsembleClassifier(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | mode = config.model_ensemble_combiner 12 | self.mode = mode 13 | # We need to use ModuleList so that the models are properly registered as submodules of the ensemble 14 | self.models = nn.ModuleList() 15 | self.temperatures = [] 16 | for modelconfig in config.model_ensemble_model_configs: 17 | modelloadingutils.update_config_from_model_file(modelconfig) 18 | model = modelfactory.create_model(modelconfig) 19 | modelToLoadPath = pathutils.get_model_to_load_path(modelconfig) 20 | modelData = modelloadingutils.load_model(modelToLoadPath, modelconfig) 21 | model.load_state_dict(modelData['model_state_dict']) 22 | for param in model.parameters(): 23 | param.requires_grad = False # Freeze the model parameters 24 | self.models.append(model) 25 | self.temperatures.append(modelconfig.model_temperature) 26 | 27 | num_models = len(config.model_ensemble_model_configs) 28 | num_classes = config.model_num_classes 29 | 30 | if mode == 'einsum': 31 | self.meta_weights = nn.Parameter(torch.ones(num_models, num_classes)) 32 | elif mode == 'linear': 33 | self.combining_layer = nn.Linear(num_models * num_classes, num_classes) 34 | 35 | def forward(self, x): 36 | logits_list = [] 37 | for i, model in enumerate(self.models): 38 | model_logits = model(x) 39 | if self.temperatures[i] is not None: 40 | model_logits = metricutils.temperature_scale(model_logits, self.temperatures[i]) 41 | logits_list.append(model_logits) 42 | 43 | if self.mode == 'einsum': 44 | stacked_logits = torch.stack(logits_list, dim=0) 45 | weighted_logits = torch.einsum('mnc,mc->mnc', stacked_logits, self.meta_weights) 46 | logits = torch.mean(weighted_logits, dim=0) 47 | elif self.mode == 'linear': 48 | concatenated_logits = torch.cat(logits_list, dim=-1) 49 | logits = self.combining_layer(concatenated_logits) 50 | elif self.mode == 'mean': 51 | stacked_logits = torch.stack(logits_list, dim=0) 52 | logits = torch.mean(stacked_logits, dim=0) 53 | elif self.mode == 'max': 54 | stacked_logits = torch.stack(logits_list, dim=0) 55 | logits, _ = torch.max(stacked_logits, dim=0) 56 | else: 57 | raise ValueError(f"Unsupported mode: {self.mode}") 58 | 59 | return logits -------------------------------------------------------------------------------- /src/imclaslib/wandb/wandb_writer.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | 4 | class WandbWriter(): 5 | """ 6 | Initializes the Wandb Writer with a given configuration. 7 | 8 | Parameters: 9 | config (module): Configuration module with necessary attributes. 
10 | """ 11 | def __init__(self, config): 12 | wandb.init( 13 | # set the wandb project where this run will be logged 14 | project=config.project_name, 15 | config={ 16 | 'model_name': config.model_name, 17 | 'requires_grad': config.train_requires_grad, 18 | 'model_num_classes': config.model_num_classes, 19 | 'dropout': config.train_dropout_prob, 20 | 'embedding_layer': config.model_embedding_layer_enabled, 21 | 'model_gcn_enabled': config.model_gcn_enabled, 22 | 'train_batch_size': config.train_batch_size, 23 | 'optimizer': 'Adam', 24 | 'loss_function': 'BCEWithLogitsLoss', 25 | 'image_size': config.model_image_size, 26 | 'model_gcn_model_name': config.model_gcn_model_name, 27 | 'model_gcn_out_channels': config.model_gcn_out_channels, 28 | 'model_gcn_layers': config.model_gcn_layers, 29 | 'model_attention_layer_num_heads': config.model_attention_layer_num_heads, 30 | 'model_embedding_layer_dimension': config.model_embedding_layer_dimension, 31 | 'dataset_version': config.dataset_version, 32 | 'l2': config.train_l2_enabled, 33 | 'l2_lambda': config.train_l2_lambda, 34 | 'label_smoothing': config.train_label_smoothing, 35 | 'dataset_normalization_mean': config.dataset_normalization_mean, 36 | 'dataset_normalization_std': config.dataset_normalization_std, 37 | } 38 | ) 39 | 40 | def log(self, *args, step=None): 41 | if step is not None: 42 | wandb.log(*args, step=step) 43 | else: 44 | wandb.log(*args) 45 | 46 | def log_table(self, table_name, columnNames, columnData, step=None): 47 | if step is not None: 48 | wandb.log({table_name: wandb.Table(columns=columnNames, data=columnData)}, step=step) 49 | else: 50 | wandb.log({table_name: wandb.Table(columns=columnNames, data=columnData)}) 51 | 52 | def watch(self, model): 53 | wandb.watch(model) 54 | def __enter__(self): 55 | """ 56 | Enter the runtime context for the WandbWriter object. 57 | Allows the WandbWriter to be used with the 'with' statement, ensuring resources are managed properly. 58 | 59 | Returns: 60 | WandbWriter: The instance with which the context was entered. 61 | """ 62 | return self 63 | 64 | def __exit__(self, exc_type, exc_value, traceback): 65 | """ 66 | Exit the runtime context for the WandbWriter object. 67 | This method is called after the 'with' block is executed, and it ensures that the wandb run is finished. 68 | 69 | Parameters: 70 | exc_type: Exception type, if any exception was raised within the 'with' block. 71 | exc_value: Exception value, the exception instance raised. 72 | traceback: Traceback object with details of where the exception occurred. 73 | """ 74 | wandb.finish() -------------------------------------------------------------------------------- /src/imclaslib/training/train_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | from imclaslib.logging.loggerfactory import LoggerFactory 3 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 4 | 5 | import torch 6 | import imclaslib.dataset.datasetutils as datasetutils 7 | from imclaslib.training.modeltrainer import ModelTrainer 8 | from imclaslib.evaluation.modelevaluator import ModelEvaluator 9 | from imclaslib.evaluation.test_model import evaluate_model 10 | 11 | def train_model(config, wandbWriter=None): 12 | """ 13 | Train a model based on the provided configuration. 14 | 15 | Parameters: 16 | config: Configuration module with necessary attributes. 
17 | """ 18 | 19 | # Initialize the computation device 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | # Get train, validation, and test dataset loaders 22 | train_loader, valid_loader, test_loader = datasetutils.get_train_valid_test_loaders(config=config) 23 | try: 24 | # Initialize the model trainer 25 | with ModelTrainer(device, train_loader, valid_loader, test_loader, config=config, wandbWriter=wandbWriter) as modelTrainer, ModelEvaluator.from_trainer(modelTrainer) as modelEvaluator: 26 | # Start the training and validation 27 | try: 28 | if config.using_wsl and config.train_compile: 29 | modelTrainer.compile() 30 | if config.using_wsl and config.test_compile: 31 | modelEvaluator.compile() 32 | for epoch in range(modelTrainer.start_epoch, modelTrainer.epochs): 33 | logger.info(f"Epoch {epoch+1} of {modelTrainer.epochs}") 34 | 35 | # Training and validation steps 36 | train_start_time = time.time() 37 | modelTrainer.train() 38 | train_end_time = time.time() 39 | logger.info(f"Finished Training Epoch in {train_end_time - train_start_time} seconds.") 40 | modelTrainer.validate(modelEvaluator) 41 | 42 | # Check for early stopping 43 | if modelTrainer.check_early_stopping(): 44 | break 45 | 46 | # Update learning rate based on validation loss 47 | modelTrainer.learningRateScheduler_check() 48 | 49 | # Evaluate test results at specified intervals 50 | if epoch % config.train_check_test_loss_epoch_interval == 0 and epoch != 0: 51 | logger.info("Evaluating Test Results") 52 | test_loss, test_f1, precision, recall = modelEvaluator.evaluate(test_loader, epoch, "Test", "Train") 53 | logger.info(f'Test Loss: {test_loss:.4f}') 54 | logger.info(f'Test Precision: {precision:.4f}, Recall:{recall:.4f}') 55 | logger.info(f'Test F1 Score: {test_f1:.4f}') 56 | except KeyboardInterrupt: 57 | logger.warning("\nTraining interrupted by user.") 58 | except Exception as e: 59 | raise e 60 | finally: 61 | if modelTrainer.best_model_state: 62 | modelTrainer.save_final_model() 63 | except Exception as e: 64 | raise e 65 | finally: 66 | config.model_name_to_load = "best_model" 67 | evaluate_model(config, valid_loader, test_loader, wandbWriter=wandbWriter) 68 | return train_loader, valid_loader, test_loader -------------------------------------------------------------------------------- /src/imclaslib/training/distill_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | from imclaslib.logging.loggerfactory import LoggerFactory 3 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 4 | 5 | import torch 6 | import imclaslib.dataset.datasetutils as datasetutils 7 | from imclaslib.training.modeltrainer import ModelTrainer 8 | from imclaslib.evaluation.modelevaluator import ModelEvaluator 9 | from imclaslib.evaluation.test_model import evaluate_model 10 | 11 | def distill_model(teacher_config, student_config, wandbWriter=None): 12 | """ 13 | Distill a student model from a teacher model, based on the provided configurations. 14 | 15 | Parameters: 16 | teacher_config (Config): Configuration for the trained teacher model. student_config (Config): Configuration for the student model being trained. 
17 | """ 18 | logger.info(f"Teacher image size:{teacher_config.model_image_size}, Student image size:{student_config.model_image_size}") 19 | # Initialize the computation device 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | # Get train, validation, and test dataset loaders 22 | teacher_train_loader = datasetutils.get_data_loader_by_name('train', config=teacher_config) 23 | train_loader, valid_loader, test_loader = datasetutils.get_train_valid_test_loaders(config=student_config) 24 | try: 25 | # Initialize the model trainer 26 | with ModelTrainer(device, train_loader, valid_loader, test_loader, config=student_config, wandbWriter=wandbWriter) as modelTrainer, ModelEvaluator.from_trainer(modelTrainer) as modelEvaluator, ModelEvaluator.from_file(device, teacher_config) as teacherEvaluator: 27 | # Start the training and validation 28 | try: 29 | for epoch in range(modelTrainer.start_epoch, modelTrainer.epochs): 30 | logger.info(f"Epoch {epoch+1} of {modelTrainer.epochs}") 31 | 32 | # Training and validation steps 33 | train_start_time = time.time() 34 | modelTrainer.distill(teacherEvaluator.model, teacher_train_loader) 35 | train_end_time = time.time() 36 | logger.info(f"Finished Training Epoch in {train_end_time - train_start_time} seconds.") 37 | modelTrainer.validate(modelEvaluator) 38 | 39 | # Check for early stopping 40 | if modelTrainer.check_early_stopping(): 41 | break 42 | 43 | # Update learning rate based on validation loss 44 | modelTrainer.learningRateScheduler_check() 45 | 46 | # Evaluate test results at specified intervals 47 | if epoch % student_config.train_check_test_loss_epoch_interval == 0 and epoch != 0: 48 | logger.info("Evaluating Test Results") 49 | test_loss, test_f1, precision, recall = modelEvaluator.evaluate(test_loader, epoch, "Test", "Train") 50 | logger.info(f'Test Loss: {test_loss:.4f}') 51 | logger.info(f'Test Precision: {precision:.4f}, Recall:{recall:.4f}') 52 | logger.info(f'Test F1 Score: {test_f1:.4f}') 53 | except KeyboardInterrupt: 54 | logger.warning("\nTraining interrupted by user.") 55 | except Exception as e: 56 | logger.error('Error at %s', 'distill training loop', exc_info=e) 57 | raise e 58 | finally: 59 | if modelTrainer.best_model_state: 60 | modelTrainer.save_final_model() 61 | except Exception as e: 62 | raise e 63 | finally: 64 | student_config.model_name_to_load = "best_model" 65 | evaluate_model(student_config, valid_loader, test_loader, wandbWriter=wandbWriter) 66 | return train_loader, valid_loader, test_loader -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Label Image Classification using Pytorch 2 | 3 | MultiLabelImageClassificationPytorch is a robust and flexible library designed to simplify the process of multilabel image classification with a dataset of images. This library provides a suite of scripts and modules to load various models, fine-tune hyperparameters, and even train multiple models sequentially with ease. The project boasts in-depth visualization support through TensorBoard, enabling users to compare performance across models and optimize their specific use cases efficiently. 
4 | 5 | ## Table of Contents 6 | - [Installation](#installation) 7 | - [Usage](#usage) 8 | - [Understanding Results](#understanding-results) 9 | - [Project Structure](#project-structure) 10 | - [Contributing](#contributing) 11 | - [License](#license) 12 | 13 | ## Installation 14 | 15 | To set up the environment to run the code, follow these steps: 16 | 17 | ### Prerequisites 18 | 19 | - [Anaconda](https://www.anaconda.com/products/individual) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) 20 | 21 | ### Environment Setup 22 | 23 | 1. Clone the repository to your local machine: 24 | ```sh 25 | git clone https://github.com/SkierProjects/MultiLabelImageClassificationPytorch.git 26 | cd MultiLabelImageClassificationPytorch 27 | ``` 28 | 29 | 2. Create and activate the conda environment from the `environment.yml` file: 30 | ```sh 31 | conda env create -f environment.yml 32 | conda activate multilabelimage_model_env 33 | ``` 34 | 35 | ## Usage 36 | 37 | ### Dataset Preparation 38 | 39 | Place your `dataset.csv` file in the `Dataset` directory. The CSV file should have the following format: 40 | ``` 41 | filepath,classname0,classname1,... 42 | /path/to/image1.jpg,0,1,... 43 | /path/to/image2.jpg,1,0,... 44 | ``` 45 | Run `analyzeData.py` to get insights on the class balance in the dataset: 46 | ```sh 47 | python Dataset/analyzeData.py 48 | ``` 49 | 50 | Run `computemean.py` to calculate the mean and standard deviation of your dataset: 51 | ```sh 52 | python src/computemean.py 53 | ``` 54 | Take the outputs for the mean and standard deviation and place them in your `default_config.yml` under `dataset.normalization_mean` and `dataset.normalization_std`. These will be used to normalize the images used for training. 55 | 56 | ### Training a Model 57 | 58 | To train a model, use the `train.py` script: 59 | ```sh 60 | python src/train.py 61 | ``` 62 | 63 | For training multiple models, use the `train_many_models.py` script. Modify `train_many_models.json` to include the models you want to train; its entries override values from your `default_config.yml`: 64 | ```sh 65 | python src/train_many_models.py 66 | ``` 67 | 68 | ### Evaluating a Model 69 | 70 | To evaluate the performance of a model, you can use the `test.py` script: 71 | ```sh 72 | python src/test.py 73 | ``` 74 | 75 | ### Running Inference 76 | 77 | To run a trained model and get predictions at runtime, you can use the `inference.py` script: 78 | ```sh 79 | python src/inference.py 80 | ``` 81 | 82 | ### TensorBoard 83 | 84 | To view TensorBoard logs, run: 85 | ```sh 86 | tensorboard --logdir=tensorboard_logs 87 | ``` 88 | 89 | ## Understanding Results 90 | 91 | The results of the model training and evaluation will be stored in the following directories: 92 | - `logs`: Contains log files with detailed information about the training process. 93 | - `outputs`: Contains saved models in `.pth` format. 94 | - `tensorboard_logs`: Contains TensorBoard logs for visualizing training progress and metrics. 95 | 96 | ## Project Structure 97 | 98 | - `Dataset/`: Contains the dataset and scripts for analyzing the dataset. 99 | - `logs/`: Where training logs are stored. 100 | - `outputs/`: Where trained model weights and checkpoints are saved. 101 | - `tensorboard_logs/`: Where TensorBoard logs are output. 102 | - `src/`: Contains all source code for the project. 103 | - `environment.yml`: Conda environment file with all required dependencies. 104 | 105 | ## Contributing 106 | 107 | To contribute, please submit a pull request to the repository. 
56 | ### Training a Model 57 | 58 | To train a model, use the `train.py` script: 59 | ```sh 60 | python src/train.py 61 | ``` 62 | 63 | For training multiple models, use the `train_many_models.py` script. Modify `train_many_models.json` to list the models you want to train; each entry overrides the default configuration values: 64 | ```sh 65 | python src/train_many_models.py 66 | ``` 67 | 68 | ### Evaluating a Model 69 | 70 | To evaluate the performance of a model, you can use the `test.py` script: 71 | ```sh 72 | python src/test.py 73 | ``` 74 | 75 | ### Running Inference 76 | 77 | To run a model on new images or videos and get predictions at runtime, use the `inference.py` script: 78 | ```sh 79 | python src/inference.py 80 | ``` 81 | 82 | ### TensorBoard 83 | 84 | To view TensorBoard logs, run: 85 | ```sh 86 | tensorboard --logdir=tensorboard_logs 87 | ``` 88 | 89 | ## Understanding Results 90 | 91 | The results of model training and evaluation are stored in the following directories: 92 | - `logs`: Contains log files with detailed information about the training process. 93 | - `outputs`: Contains saved models in `.pth` format. 94 | - `tensorboard_logs`: Contains TensorBoard logs for visualizing training progress and metrics. 95 | 96 | ## Project Structure 97 | 98 | - `Dataset/`: Contains the dataset and scripts for analyzing the dataset. 99 | - `logs/`: Where training logs are stored. 100 | - `outputs/`: Where trained model weights and checkpoints are saved. 101 | - `tensorboard_logs/`: Where TensorBoard logs are output. 102 | - `src/`: Contains all source code for the project. 103 | - `environment.yml`: Conda environment file with all required dependencies. 104 | 105 | ## Contributing 106 | 107 | To contribute, please submit a pull request to the repository. Your contributions will be reviewed and considered for merging. 108 | 109 | ## License 110 | 111 | This project is licensed under the Apache License 2.0. See the LICENSE file for more details. -------------------------------------------------------------------------------- /src/imclaslib/tensorboard/tensorboardwriter.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import imclaslib.files.pathutils as pathutils 3 | import imclaslib.files.imageutils as imageutils 4 | import imclaslib.dataset.datasetutils as datasetutils 5 | 6 | class TensorBoardWriter(): 7 | def __init__(self, config): 8 | """ 9 | Initializes the TensorBoardWriter with a given configuration. 10 | 11 | Parameters: 12 | config (Config): Configuration object with necessary attributes. 13 | """ 14 | self.config = config 15 | 16 | modelAddons = "" 17 | if self.config.model_embedding_layer_enabled: 18 | modelAddons = f"_EmbeddingLayer_{config.model_embedding_layer_dimension}" 19 | elif self.config.model_gcn_enabled: 20 | modelAddons = f"_GCN_{config.model_embedding_layer_dimension}_{config.model_gcn_out_channels}_{config.model_gcn_layers}_{config.model_attention_layer_num_heads}" 21 | log_dir = pathutils.combine_path(config, 22 | pathutils.get_tensorboard_log_dir_path(config), 23 | f'{config.model_name}_{config.model_weights}_{config.model_image_size}_{config.train_dropout_prob}_{config.dataset_version}{modelAddons}' 24 | ) 25 | self.writer = SummaryWriter(log_dir) 26 | 27 | def add_scalar(self, tag, scalar_value, step): 28 | """ 29 | Writes a scalar value to TensorBoard. 30 | 31 | Parameters: 32 | tag (str): The tag associated with the scalar. 33 | scalar_value (float): The scalar value to write. 34 | step (int): The global step value to record. 35 | """ 36 | self.writer.add_scalar(tag, scalar_value, step) 37 | 38 | def add_scalars_from_dict(self, input_dict, step): 39 | """ 40 | Writes each scalar value in a dictionary to TensorBoard. 41 | 42 | Parameters: 43 | input_dict (dict): Dictionary mapping tag names to scalar values. 44 | step (int): The global step value to record. 45 | """ 46 | for key, value in input_dict.items(): 47 | self.add_scalar(key, value, step) 48 | 49 | def write_image_test_results(self, images, true_labels, predictions, step, runmode, dataSubset): 50 | """ 51 | Writes image test results with overlays to TensorBoard. 52 | 53 | Parameters: 54 | images (Tensor): Batch of images. 55 | true_labels (Tensor): True labels for the images. 56 | predictions (Tensor): Predicted labels for the images. 57 | step (int): The global step value to record. 58 | runmode (str): The mode of the run (e.g., 'Train', 'Test'). 59 | dataSubset (str): The subset of data (e.g., 'Validation').
60 | """ 61 | denormalized_images = imageutils.denormalize_images(images, self.config) 62 | pil_images = imageutils.convert_to_PIL(denormalized_images) 63 | overlaid_images = imageutils.overlay_predictions_batch(pil_images, predictions.cpu().tolist(), datasetutils.get_index_to_tag_mapping(self.config), true_labels.cpu().tolist()) 64 | tensor_overlaid_images = imageutils.convert_PIL_to_tensors(overlaid_images) 65 | self.add_images(f'{runmode}/{dataSubset}/Images', denormalized_images, step) 66 | self.add_images(f'{runmode}/{dataSubset}/True Labels', imageutils.convert_labels_to_color(true_labels.cpu(), self.config.model_num_classes), step) 67 | self.add_images(f'{runmode}/{dataSubset}/Predictions', imageutils.convert_labels_to_color(predictions.cpu(), self.config.model_num_classes), step) 68 | self.add_images(f'{runmode}/{dataSubset}/OverlayPredictions', tensor_overlaid_images, step) 69 | 70 | def add_histogram(self, tag, param, step): 71 | """ 72 | Writes a histogram of values to TensorBoard. 73 | 74 | Parameters: 75 | tag (str): The tag associated with the histogram. 76 | param (Tensor): Values to build the histogram from. 77 | step (int): The global step value to record. 78 | """ 79 | self.writer.add_histogram(tag, param, step) 80 | 81 | def add_images(self, tag, images, step): 82 | """ 83 | Writes a batch of images to TensorBoard. 84 | 85 | Parameters: 86 | tag (str): The tag associated with the images. 87 | images (Tensor): Batch of images to write. 88 | step (int): The global step value to record. 89 | """ 90 | self.writer.add_images(tag, images, step) 91 | 92 | def close_writer(self): 93 | """ 94 | Closes the TensorBoard writer and cleans up resources. 95 | """ 96 | if self.writer: 97 | self.writer.close() 98 | self.writer = None 99 | 100 | def add_hparams(self, hparams, metrics): 101 | """ 102 | Writes hyperparameters and their associated metrics to TensorBoard. 103 | 104 | Parameters: 105 | hparams (dict): Dictionary of hyperparameters. 106 | metrics (dict): Dictionary of metrics associated with the hyperparameters. 107 | """ 108 | self.writer.add_hparams(hparam_dict=hparams, metric_dict=metrics) --------------------------------------------------------------------------------
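A minimal usage sketch for the writer above (assuming an already-populated `Config` instance; the scalar values are placeholders):

```python
writer = TensorBoardWriter(config)
for epoch, loss in enumerate([0.9, 0.7, 0.6]):
    # One scalar per epoch shows up as a curve in the TensorBoard UI.
    writer.add_scalar("Train/Loss", loss, epoch)
writer.add_scalars_from_dict({"Valid/F1": 0.81, "Valid/Loss": 0.65}, step=2)
writer.close_writer()
```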
/src/imclaslib/models/modelfactory.py: -------------------------------------------------------------------------------- 1 | from torchvision import models as models 2 | import torch.nn as nn 3 | import torch 4 | import timm 5 | from imclaslib.models.ensemble_classifier import EnsembleClassifier 6 | from imclaslib.models.gcn_classifier import GCNClassifier 7 | from imclaslib.models.multilabel_classifier import MultiLabelClassifier 8 | from imclaslib.models.multilabel_embeddinglayer_model import MultiLabelClassifier_LabelEmbeddings 9 | 10 | def create_model(config): 11 | """ 12 | Creates a PyTorch multi-label image classification model with optional label embedding. 13 | 14 | Parameters: 15 | config (Config): The config used to create the model 16 | 17 | Returns: 18 | nn.Module: The PyTorch model with the configured classifier head. 19 | """ 20 | if config.model_ensemble_model_configs: 21 | return EnsembleClassifier(config) 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | num_features = None 25 | # Try to load the model from torchvision.models 26 | try: 27 | # Use the 'models' module from torchvision 28 | model = getattr(models, config.model_name)(weights=config.model_weights) 29 | except AttributeError: 30 | # If the model is not available in torchvision, try loading it from timm 31 | if timm.is_model(config.model_name): 32 | # Use timm to create the model without the classifier (head) 33 | model = timm.create_model( 34 | config.model_name, 35 | pretrained=True, 36 | num_classes=0 # Setting num_classes=0 removes the classifier 37 | ) 38 | num_features = model.num_features # Get the number of features after pooling 39 | else: 40 | raise ValueError(f"The model '{config.model_name}' is not available in torchvision or timm.") 41 | model = model.to(device) 42 | 43 | # Freeze or unfreeze the model parameters based on requires_grad 44 | for param in model.parameters(): 45 | param.requires_grad = config.train_requires_grad 46 | 47 | # Replace the appropriate classifier head with a new one 48 | if num_features is None: 49 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Sequential): 50 | num_features = model.classifier[-1].in_features 51 | #print(f"Number of input features to the classifier: {num_features}") 52 | model.classifier[-1] = nn.Identity() 53 | elif hasattr(model, 'fc'): 54 | num_features = model.fc.in_features 55 | #print(f"Number of input features to the fc: {num_features}") 56 | model.fc = nn.Identity() 57 | elif hasattr(model, 'head'): 58 | num_features = model.head.in_features 59 | #print(f"Number of input features to the head: {num_features}") 60 | model.head = nn.Identity() 61 | elif hasattr(model, 'heads') and isinstance(model.heads, nn.Sequential): 62 | num_features = model.heads[-1].in_features 63 | #print(f"Number of input features to the heads: {num_features}") 64 | model.heads[-1] = nn.Identity() 65 | else: 66 | raise AttributeError(f"The model '{config.model_name}' does not have a recognized classifier head.") 67 | model.output_dim = num_features 68 | 69 | # If add_embedding_layer is True, wrap the base model with the MultiLabelClassifier_LabelEmbeddings 70 | if config.model_embedding_layer_enabled: 71 | model = MultiLabelClassifier_LabelEmbeddings(model, config.model_num_classes, config.model_embedding_layer_dimension, config.train_dropout_prob) 72 | elif config.model_gcn_enabled: 73 | if config.model_gcn_model_name is None: 74 | raise ValueError("GCN model name must be provided when use_gcn is True.") 75 | gcn_model_params = { 76 | 'in_channels': config.model_embedding_layer_dimension, 77 | 'out_channels': config.model_gcn_out_channels, # Output dimension size (should match base model output dimension for concatenation) 78 | 'dropout': config.train_dropout_prob / 100, 79 | 'hidden_channels': config.model_embedding_layer_dimension, 80 | 'num_layers': config.model_gcn_layers 81 | } 82 | config.model_gcn_edge_index = config.model_gcn_edge_index.to(device) 83 | if config.model_gcn_edge_weights is not None: 84 | config.model_gcn_edge_weights = config.model_gcn_edge_weights.to(device) 85 | model = GCNClassifier( 86 | base_model=model, 87 | num_classes=config.model_num_classes, 88 | model_gcn_model_name=config.model_gcn_model_name, 89 | dropout_prob=config.train_dropout_prob / 100, 90 | gcn_model_params=gcn_model_params, 91 |
edge_index=config.model_gcn_edge_index, 92 | edge_weight=config.model_gcn_edge_weights, 93 | num_heads=config.model_attention_layer_num_heads 94 | ) 95 | else: 96 | # If not using the embedding layer, create a MultiLabelClassifier with dropout 97 | model = MultiLabelClassifier(model, config.model_num_classes, config.train_dropout_prob / 100) 98 | #print(model) 99 | return model -------------------------------------------------------------------------------- /src/imclaslib/files/pathutils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from datetime import datetime 3 | import sys 4 | 5 | def get_best_model_path(config): 6 | """ 7 | Gets the path for the best model's checkpoint file. 8 | 9 | Returns: 10 | Path: The path for the best model's checkpoint file. 11 | """ 12 | return combine_path(config, get_output_dir_path(config), "best_model.pth") 13 | 14 | def get_model_to_load_path(config): 15 | """ 16 | Gets the path for the model to load (defined in the config) 17 | 18 | Parameters: 19 | config (object): Configuration object containing the model to load name. 20 | 21 | Returns: 22 | Path: The path to the checkpoint file for the model to load. 23 | """ 24 | return combine_path(config, get_output_dir_path(config), f"{config.model_name_to_load}.pth") 25 | 26 | def get_log_dir_path(config): 27 | """ 28 | Gets the path for the project's log directory. 29 | 30 | Returns: 31 | Path: The log directory path. 32 | """ 33 | return combine_path(config, config.logs_folder) 34 | 35 | 36 | def get_tensorboard_log_dir_path(config): 37 | """ 38 | Gets the path for the TensorBoard log directory. 39 | 40 | Returns: 41 | Path: The TensorBoard log directory path. 42 | """ 43 | return combine_path(config, config.logs_tensorboard_folder) 44 | 45 | def get_output_dir_path(config): 46 | """ 47 | Gets the path for the project's output directory. 48 | 49 | Returns: 50 | Path: The output directory path. 51 | """ 52 | return combine_path(config, config.model_folder) 53 | 54 | def combine_path(config, *args): 55 | """ 56 | Combines multiple path components into a single Path object. 57 | 58 | Parameters: 59 | *args: A variable number of path components. 60 | 61 | Returns: 62 | Path: The combined path. 63 | """ 64 | raw_path = Path(*args) 65 | if config.using_wsl: 66 | raw_path = convert_windows_path_to_wsl(raw_path) 67 | return raw_path 68 | 69 | def get_datetime(): 70 | """ 71 | Gets the current date and time in the format YYYYMMDD_HHMMSS. 72 | 73 | Returns: 74 | str: The current date and time as a string. 75 | """ 76 | return datetime.now().strftime("%Y%m%d_%H%M%S") 77 | 78 | def convert_windows_path_to_wsl(windows_path): 79 | """ 80 | Convert a Windows file path to a WSL file path. 81 | """ 82 | input_is_path = isinstance(windows_path, Path) 83 | windows_path_str = str(windows_path) 84 | 85 | # Replace the drive letter with '/mnt/' and lower the case of the drive letter 86 | if windows_path_str[1:3] == ':\\': 87 | wsl_path = '/mnt/' + windows_path_str[0].lower() + windows_path_str[2:] 88 | else: 89 | wsl_path = windows_path_str # If the path is already in the correct format 90 | 91 | # Replace backslashes with forward slashes 92 | wsl_path = wsl_path.replace('\\', '/') 93 | 94 | return Path(wsl_path) if input_is_path else wsl_path 95 | 96 | def convert_wsl_path_to_windows(wsl_path): 97 | """ 98 | Convert a WSL file path to a Windows file path.
99 | """ 100 | input_is_path = isinstance(wsl_path, Path) 101 | wsl_path_str = str(wsl_path) 102 | 103 | # Check if the path starts with '/mnt/' (case-insensitive) 104 | if wsl_path_str.lower().startswith('/mnt/'): 105 | # Extract the drive letter and construct the Windows path 106 | drive_letter = wsl_path_str[5] 107 | windows_path = f"{drive_letter.upper()}:{wsl_path_str[6:]}" 108 | else: 109 | windows_path = wsl_path_str # If the path is already in the correct format 110 | 111 | # Replace forward slashes with backslashes 112 | windows_path = windows_path.replace('/', '\\') 113 | 114 | return Path(windows_path) if input_is_path else windows_path 115 | 116 | def get_dataset_path(config): 117 | """ 118 | Gets the path for the dataset file based on the configuration. 119 | 120 | Parameters: 121 | config (object): Configuration object containing the dataset file name. 122 | 123 | Returns: 124 | Path: The dataset file path. 125 | """ 126 | return combine_path(config, config.dataset_path) 127 | 128 | def get_train_many_models_file(config): 129 | """ 130 | Gets the path for the JSON file containing configurations for training many models. 131 | 132 | Returns: 133 | Path: The path to the 'train_many_models.json' file. 134 | """ 135 | return combine_path(config, config.train_many_models_path) 136 | 137 | def get_distill_models_file(config): 138 | """ 139 | Gets the path for the JSON file containing configurations for distillation runs. 140 | 141 | Returns: 142 | Path: The path to the distill-models configuration file. 143 | """ 144 | return combine_path(config, config.train_distill_models_path) 145 | 146 | def get_test_many_models_file(config): 147 | """ 148 | Gets the path for the JSON file containing configurations for testing many models. 149 | 150 | Returns: 151 | Path: The path to the test-many-models configuration file. 152 | """ 153 | return combine_path(config, config.test_many_models_path) 154 | 155 | def get_tags_path(config): 156 | """ 157 | Gets the path for the tags file containing all possible tags. 158 | 159 | Returns: 160 | Path: The path to the tags file. 161 | """ 162 | return combine_path(config, config.model_tags_path) 163 | 164 | def get_graph_path(config): 165 | """Gets the path to the GCN graph definition file.""" 166 | return combine_path(config, config.model_gcn_graph_path) --------------------------------------------------------------------------------
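A quick illustration of the two path converters above (the paths are hypothetical):

```python
convert_windows_path_to_wsl("C:\\data\\img.jpg")    # -> '/mnt/c/data/img.jpg'
convert_wsl_path_to_windows("/mnt/c/data/img.jpg")  # -> 'C:\\data\\img.jpg'
```

Both helpers preserve the input type: passing a `pathlib.Path` yields a `Path` back, while a plain string yields a string.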
/src/imclaslib/dataset/datasetutils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import pandas as pd 3 | from imclaslib.dataset.image_dataset import ImageDataset 4 | from torch.utils.data import DataLoader 5 | import imclaslib.files.pathutils as pathutils 6 | 7 | # Global variable to cache the dataset CSV after being read for the first time. 8 | dataset_csv = None 9 | 10 | def get_train_valid_test_loaders(config): 11 | """ 12 | Creates and returns DataLoaders for the training, validation, and test sets. 13 | 14 | Parameters: 15 | - config: An immutable configuration object with necessary parameters. 16 | 17 | Returns: 18 | - Tuple of DataLoaders: (train_loader, valid_loader, test_loader) 19 | """ 20 | global dataset_csv 21 | dataset_csv = __get_dataset_csv(config) 22 | train_data = ImageDataset(dataset_csv, mode='train', config=config) 23 | valid_data = ImageDataset(dataset_csv, mode='valid', config=config) 24 | test_data = ImageDataset(dataset_csv, mode='test', config=config) 25 | 26 | train_loader = DataLoader(train_data, batch_size=config.train_batch_size, shuffle=True, num_workers=6, persistent_workers=True, pin_memory=False) 27 | valid_loader = DataLoader(valid_data, batch_size=config.train_batch_size, shuffle=False, num_workers=0, persistent_workers=False, pin_memory=False) 28 | test_loader = DataLoader(test_data, batch_size=config.train_batch_size, shuffle=False, num_workers=0, persistent_workers=False, pin_memory=False) 29 | 30 | return train_loader, valid_loader, test_loader 31 | 32 | def get_data_loader_by_name(mode, config, shuffle=False, num_workers=1): 33 | """ 34 | Creates and returns a DataLoader for the specified mode. 35 | 36 | Parameters: 37 | - mode: A string indicating the mode ('train', 'valid', 'test', or 'all'). 38 | - config: An immutable configuration object with necessary parameters. 39 | - shuffle: A boolean indicating whether to shuffle the dataset. 40 | 41 | Returns: 42 | - DataLoader for the specified mode. 43 | """ 44 | global dataset_csv 45 | dataset_csv = __get_dataset_csv(config) 46 | data = ImageDataset(dataset_csv, mode=mode, config=config) 47 | loader = DataLoader(data, batch_size=config.test_batch_size, shuffle=shuffle, pin_memory=False, persistent_workers=False, num_workers=num_workers) 48 | return loader 49 | 50 | def get_dataset_tag_mappings(config): 51 | """ 52 | Retrieves a mapping from index to tag names from the dataset CSV. 53 | 54 | Parameters: 55 | - config: An immutable configuration object with necessary parameters. 56 | 57 | Returns: 58 | - A dictionary mapping indices to tag names. 59 | """ 60 | global dataset_csv 61 | dataset_csv = __get_dataset_csv(config) 62 | return __get_index_to_tag_mapping(dataset_csv) 63 | 64 | def get_tag_to_index_mapping(config): 65 | """ 66 | Retrieves a mapping from tag names to indices by reading from a text file. 67 | 68 | Parameters: 69 | - config: Configuration object whose model_tags_path locates the tags text file (one tag per line). 70 | 71 | Returns: 72 | - A dictionary mapping tag names to indices. 73 | """ 74 | tags_txt_path = pathutils.get_tags_path(config) 75 | tag_to_index = {} 76 | with open(tags_txt_path, 'r', encoding='utf-8') as file: 77 | for index, tag in enumerate(file): 78 | tag_to_index[tag.strip()] = index # Remove any leading/trailing whitespace 79 | return tag_to_index 80 | 81 | def get_index_to_tag_mapping(config): 82 | """ 83 | Retrieves a mapping from indices to tag names by reading from a text file. 84 | 85 | Parameters: 86 | - config: Configuration object whose model_tags_path locates the tags text file (one tag per line). 87 | 88 | Returns: 89 | - A dictionary mapping indices to tag names.
90 | """ 91 | tags_txt_path = pathutils.get_tags_path(config) 92 | index_to_tag = {} 93 | with open(tags_txt_path, 'r', encoding='utf-8') as file: 94 | for index, tag in enumerate(file): 95 | index_to_tag[index] = tag.strip() # Remove any leading/trailing whitespace 96 | return index_to_tag 97 | 98 | def analyze_csv(config): 99 | """Counts how often each annotation is used, and how many files have or lack annotations.""" 100 | csv_file_path = pathutils.get_dataset_path(config) 101 | # Initialize dictionaries to store annotation counts and file counts 102 | annotation_counts = {} 103 | file_counts = {'with_annotations': 0, 'without_annotations': 0} 104 | 105 | with open(csv_file_path, 'r', newline='') as csvfile: 106 | reader = csv.DictReader(csvfile) 107 | 108 | # Iterate through each row in the CSV 109 | for row in reader: 110 | file_name = row['filepath'] 111 | 112 | # Count files without any annotations 113 | if all(value == '0' for key, value in row.items() if key != 'filepath'): 114 | file_counts['without_annotations'] += 1 115 | else: 116 | file_counts['with_annotations'] += 1 117 | 118 | # Count the usage of each annotation 119 | for annotation_name, annotation_value in row.items(): 120 | if annotation_name != 'filepath': 121 | annotation_counts[annotation_name] = annotation_counts.get(annotation_name, 0) + int(annotation_value) 122 | 123 | return annotation_counts, file_counts 124 | 125 | def __get_index_to_tag_mapping(csv): 126 | """ 127 | Helper function to create a mapping from column index to tag name. 128 | 129 | Parameters: 130 | - csv: The dataset CSV DataFrame. 131 | 132 | Returns: 133 | - A dictionary mapping indices to tag names. 134 | """ 135 | tag_columns = csv.columns[1:] 136 | index_to_tag = {index: tag for index, tag in enumerate(tag_columns)} 137 | return index_to_tag 138 | 139 | def __get_dataset_csv(config): 140 | """ 141 | Retrieves the dataset CSV, reading it from file if not already cached. 142 | 143 | Parameters: 144 | - config: An immutable configuration object with necessary parameters. 145 | 146 | Returns: 147 | - The dataset CSV DataFrame. 148 | """ 149 | global dataset_csv 150 | if dataset_csv is None: 151 | dataset_csv = pd.read_csv(pathutils.get_dataset_path(config)) 152 | return dataset_csv 153 | --------------------------------------------------------------------------------
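The tags file that the mapping helpers above read (via `model_tags_path`) is plain text with one tag per line. For instance, a hypothetical file containing:

```
cat
dog
person
```

would make `get_tag_to_index_mapping` return `{'cat': 0, 'dog': 1, 'person': 2}`, and `get_index_to_tag_mapping` the inverse `{0: 'cat', 1: 'dog', 2: 'person'}`.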
/src/imclaslib/files/modelloadingutils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from imclaslib.logging.loggerfactory import LoggerFactory 3 | import imclaslib.files.pathutils as pathutils 4 | import os 5 | import re 6 | 7 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 8 | 9 | def save_best_model(model_state, config): 10 | """ 11 | Saves the best model state to the predetermined best model path. 12 | 13 | Parameters: 14 | model_state (dict): State dictionary of the model to be saved. 15 | """ 16 | torch.save(model_state, pathutils.get_best_model_path(config)) 17 | 18 | def save_final_model(model_state, f1_score, config): 19 | """ 20 | Saves the final model state to a filename that includes model details such as name, image size, and F1 score. 21 | 22 | Parameters: 23 | model_state (dict): State dictionary of the model to be saved. 24 | f1_score (float): The F1 score of the model. 25 | config (object): Configuration object containing model_name and image_size. 26 | """ 27 | modelAddons = "" 28 | if config.model_embedding_layer_enabled: 29 | modelAddons = "_EmbeddingLayer" 30 | elif config.model_gcn_enabled: 31 | modelAddons = "_GCN" 32 | final_model_path_template = os.path.join(str(pathutils.get_output_dir_path(config)), '{model_name}_{image_size}_{dataset_version}_{f1_score:.4f}{modelAddons}.pth') 33 | final_model_path = final_model_path_template.format( 34 | model_name=config.model_name, 35 | image_size=config.model_image_size, 36 | f1_score=f1_score, 37 | modelAddons=modelAddons, 38 | dataset_version=config.dataset_version 39 | ) 40 | torch.save(model_state, final_model_path) 41 | logger.info(f"Final model saved as {final_model_path}") 42 | 43 | def load_model(model_path, config): 44 | """ 45 | Loads a model and its optimizer state from a checkpoint file. 46 | 47 | Parameters: 48 | model_path (str): Path to the checkpoint file. 49 | config (object): Configuration object. 50 | 51 | Returns: 52 | model_data (dict): The model data from the file. 53 | """ 54 | checkpoint = torch.load(model_path) 55 | 56 | model_data = add_model_data(checkpoint, config) 57 | return model_data 58 | 59 | def add_model_data(checkpoint, config): 60 | model_data = {} 61 | model_data["epoch"] = checkpoint.get('epoch', 0) 62 | model_data["model_state_dict"] = checkpoint.get('model_state_dict', -1) 63 | model_data["optimizer_state_dict"] = checkpoint.get('optimizer_state_dict', -1) 64 | model_data["loss"] = checkpoint.get('loss', -1) 65 | model_data["f1_score"] = checkpoint.get('f1_score', -1) 66 | model_data["model_name"] = checkpoint.get('model_name', config.model_name) 67 | model_data["requires_grad"] = checkpoint.get('requires_grad', True) 68 | model_data["model_num_classes"] = checkpoint.get('model_num_classes', config.model_num_classes) 69 | model_data["dropout"] = checkpoint.get('dropout', 0) 70 | model_data["embedding_layer"] = checkpoint.get('embedding_layer', config.model_embedding_layer_enabled) 71 | model_data["model_gcn_enabled"] = checkpoint.get('model_gcn_enabled', config.model_gcn_enabled) 72 | model_data["train_batch_size"] = checkpoint.get('train_batch_size', config.train_batch_size) 73 | model_data["optimizer"] = checkpoint.get('optimizer', 'Adam') 74 | model_data["loss_function"] = checkpoint.get('loss_function', 'BCEWithLogitsLoss') 75 | model_data["image_size"] = checkpoint.get('image_size', config.model_image_size) 76 | model_data["model_gcn_model_name"] = checkpoint.get('model_gcn_model_name', config.model_gcn_model_name) 77 | model_data["model_gcn_out_channels"] = checkpoint.get('model_gcn_out_channels', config.model_gcn_out_channels) 78 | model_data["model_gcn_layers"] = checkpoint.get('model_gcn_layers', config.model_gcn_layers) 79 | model_data["model_attention_layer_num_heads"] = checkpoint.get('model_attention_layer_num_heads', config.model_attention_layer_num_heads) 80 | model_data["model_embedding_layer_dimension"] = checkpoint.get('model_embedding_layer_dimension', config.model_embedding_layer_dimension) 81 | model_data["dataset_version"] = checkpoint.get('dataset_version', config.dataset_version) 82 | model_data["train_loss"] = checkpoint.get('train_loss', 0) 83 | 84 | 85 | return model_data 86 | 87 | def update_config_from_model_file(config): 88 | pattern = r"(.+?)_(\d+)_(\d+(?:\.\d+)?)_\d\.\d{4}" # matches '{model_name}_{image_size}_{dataset_version}_{f1:.4f}' filenames 89 | file_name = config.model_name_to_load 90 | if not file_name: 91 | return 92 | match = re.match(pattern, file_name) 93 | if match: 94 | # If there's a match, pull the model name, image size, and dataset version from the filename 95 | model_name = match.group(1) # The first capture group (model name)
96 | model_image_size = match.group(2) # The second capture group (image size) 97 | dataset_version = match.group(3) # The third capture group (dataset version) 98 | config.model_name = model_name 99 | config.model_image_size = int(model_image_size) 100 | config.dataset_version = float(dataset_version) 101 | return 102 | else: 103 | model_file_path = pathutils.get_model_to_load_path(config) 104 | checkpoint = torch.load(model_file_path) 105 | model_name = checkpoint.get('model_name', None) 106 | model_image_size = checkpoint.get('image_size', None) 107 | if model_name is not None: 108 | config.model_name = model_name 109 | if model_image_size is not None: 110 | config.model_image_size = model_image_size 111 | return 112 | 113 | def load_pretrained_weights_exclude_classifier(new_model, config, freeze_base_model=False): 114 | pretrained_model_path = pathutils.combine_path(config, pathutils.get_output_dir_path(config), f"{config.train_model_to_load_raw_weights}.pth") 115 | path = str(pretrained_model_path) 116 | # Load the state dictionary of the pretrained model 117 | pretrained_state_dict = torch.load(path) 118 | model_data = add_model_data(pretrained_state_dict, config) 119 | # Remove the weights for the final classifier layer from the pretrained state_dict 120 | classifier_keys = [key for key in pretrained_state_dict if key.startswith('classifier.') or key.startswith('fc.') or key.startswith('head.') or key.startswith('heads.') ] 121 | for key in classifier_keys: 122 | pretrained_state_dict.pop(key) 123 | 124 | # Load the remaining weights into the new model's base model 125 | # This will exclude the final classifier layer 126 | new_model.base_model.load_state_dict(pretrained_state_dict, strict=False) 127 | 128 | # Freeze the parameters of the base model, if required 129 | if freeze_base_model: 130 | for param in new_model.base_model.parameters(): 131 | param.requires_grad = False 132 | 133 | return new_model, model_data -------------------------------------------------------------------------------- /src/imclaslib/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import yaml 4 | import copy 5 | 6 | class Config: 7 | def __init__(self, default_config_path=None): 8 | """ 9 | Initialize a new Config instance, optionally loading values from a default JSON/YAML file. 10 | 11 | Parameters: 12 | - default_config_path: str (optional), path to a JSON/YAML file with default configuration values.
13 | """ 14 | 15 | # model - needed to define a model 16 | self.model_name = 'efficientnet_v2_m' 17 | self.model_image_size = 400 18 | self.model_num_classes = 31 19 | self.model_weights = 'DEFAULT' 20 | self.model_folder = "" 21 | self.model_tags_path = "" 22 | self.model_name_to_load = "best_model" 23 | self.model_attention_layer_num_heads = 8 24 | self.model_ensemble_combiner = "mean" 25 | self.model_ensemble_model_configs = None 26 | self.model_fp16 = False 27 | self.model_temperature = None 28 | 29 | #model - embedding layer 30 | self.model_embedding_layer_enabled = False 31 | self.model_embedding_layer_dimension = 512 32 | 33 | #model - gcn 34 | self.model_gcn_enabled = False 35 | self.model_gcn_model_name = "GAT" 36 | self.model_gcn_out_channels = 512 37 | self.model_gcn_layers = 4 38 | self.model_gcn_graph_path = "" 39 | self.model_gcn_edge_index = None 40 | self.model_gcn_edge_weights = None 41 | 42 | # dataset 43 | self.dataset_path = "" 44 | self.dataset_augmentation_level = 0 45 | self.dataset_normalization_mean = None 46 | self.dataset_normalization_std = None 47 | self.dataset_train_percentage = 80 48 | self.dataset_valid_percentage = 10 49 | self.dataset_test_percentage = 10 50 | self.dataset_version = 1.0 51 | self.dataset_tags_mapping_dict = {} 52 | self.dataset_preprocess_to_RAM = False 53 | 54 | # training 55 | self.train_batch_size = 24 56 | self.train_dropout_prob = 0 57 | self.train_learning_rate = 1e-4 58 | self.train_num_epochs = 50 59 | self.train_continue_training = False 60 | self.train_requires_grad = True 61 | self.train_store_gradients_epoch_interval = 5 62 | self.train_check_test_loss_epoch_interval = 10 63 | self.train_many_models_path = "" 64 | self.train_distill_models_path = "" 65 | self.train_model_to_load_raw_weights = "" 66 | self.train_l2_enabled = False 67 | self.train_l2_lambda = 0.01 68 | self.train_label_smoothing = 0.0 69 | self.train_compile = False 70 | 71 | # training - early stopping 72 | self.train_early_stopping_patience = 6 73 | self.train_early_stopping_threshold = 4e-3 74 | 75 | #training - learning rate reducer 76 | self.train_learningrate_reducer_patience = 2 77 | self.train_learningrate_reducer_threshold = 2e-3 78 | self.train_learningrate_reducer_factor = 0.1 79 | self.train_learningrate_reducer_min_lr = 1e-7 80 | 81 | #test 82 | self.test_batch_size = 72 83 | self.test_many_models_path = "" 84 | self.test_compile = False 85 | 86 | #logs 87 | self.logs_level = "DEBUG" 88 | self.logs_folder = "" 89 | self.logs_tensorboard_folder = "" 90 | self.project_name = "" 91 | 92 | self.using_wsl = False 93 | 94 | if default_config_path: 95 | self.load_config(default_config_path) 96 | 97 | 98 | def load_config(self, config_path): 99 | """ 100 | Load configuration data from a file (JSON or YAML) based on its extension. 101 | 102 | Parameters: 103 | - config_path: str, path to the JSON/YAML file with configuration values. 
104 | """ 105 | extension = config_path.split('.')[-1].lower() 106 | if extension == 'json': 107 | with open(config_path, 'r') as f: 108 | config_data = json.load(f) 109 | elif extension in ['yaml', 'yml']: 110 | with open(config_path, 'r') as f: 111 | config_data = yaml.safe_load(f) 112 | else: 113 | raise ValueError(f"Unsupported configuration file format: {extension}") 114 | 115 | self.update_config(config_data, self) 116 | 117 | def update_config(self, new_config, default_config): 118 | self.__update_config(new_config) 119 | if self.model_gcn_edge_index is not None and not isinstance(self.model_gcn_edge_index, torch.Tensor): 120 | self.model_gcn_edge_index = torch.tensor(self.model_gcn_edge_index) 121 | 122 | if self.model_gcn_edge_weights is not None and not isinstance(self.model_gcn_edge_weights, torch.Tensor): 123 | self.model_gcn_edge_weights = torch.tensor(self.model_gcn_edge_weights) 124 | 125 | if self.model_ensemble_model_configs is not None: 126 | self.model_ensemble_model_configs = [Config.from_dict(ensemble_config_data, default_config) for ensemble_config_data in self.model_ensemble_model_configs] 127 | 128 | def __update_config(self, new_config, prefix=''): 129 | """ 130 | Update the configuration instance with new values. 131 | 132 | Parameters: 133 | - new_config: dict, new configuration values to update with. 134 | - prefix: str, prefix for nested attributes to maintain hierarchy. 135 | """ 136 | if new_config: 137 | for key, value in new_config.items(): 138 | if isinstance(value, dict): # It's a subsection 139 | new_prefix = f"{prefix}{key}_" 140 | self.__update_config(value, new_prefix) 141 | else: 142 | config_key = f"{prefix}{key}" 143 | if hasattr(self, config_key): 144 | setattr(self, config_key, value) 145 | 146 | def __getattr__(self, name): 147 | """ 148 | Allow dynamic access to configuration values. 149 | """ 150 | if name in self.__dict__: 151 | return self.__dict__[name] 152 | else: 153 | raise AttributeError(f"'Config' object has no attribute '{name}'") 154 | 155 | @classmethod 156 | def from_dict(cls, config_dict, default_config=None): 157 | if default_config is None: 158 | new_instance = cls() 159 | else: 160 | new_instance = copy.deepcopy(default_config) 161 | 162 | new_instance.update_config(config_dict, default_config) 163 | return new_instance 164 | 165 | @staticmethod 166 | def load_configs_from_file(file_path, default_config): 167 | """ 168 | Load model configurations from a JSON or YAML file. 169 | 170 | Parameters: 171 | - file_path: str, the path to the JSON or YAML file containing the configurations 172 | 173 | Returns: 174 | - list, the list of configuration objects 175 | """ 176 | extension = str(file_path).split('.')[-1].lower() 177 | if extension == 'json': 178 | with open(file_path, 'r') as f: 179 | config_data = json.load(f) 180 | elif extension in ['yaml', 'yml']: 181 | with open(file_path, 'r') as f: 182 | config_data = yaml.safe_load(f) 183 | else: 184 | raise ValueError(f"Unsupported configuration file format: {extension}") 185 | return [Config.from_dict(config, default_config) for config in config_data] 186 | 187 | @staticmethod 188 | def load_config_from_file(file_path, default_config): 189 | """ 190 | Load a single model configuration from a JSON or YAML file.
191 | 192 | Parameters: 193 | - file_path: str, the path to the JSON or YAML file containing the configuration 194 | 195 | Returns: 196 | - Config, the loaded configuration object 197 | """ 198 | extension = str(file_path).split('.')[-1].lower() 199 | if extension == 'json': 200 | with open(file_path, 'r') as f: 201 | config_data = json.load(f) 202 | elif extension in ['yaml', 'yml']: 203 | with open(file_path, 'r') as f: 204 | config_data = yaml.safe_load(f) 205 | else: 206 | raise ValueError(f"Unsupported configuration file format: {extension}") 207 | return Config.from_dict(config_data, default_config) -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import time 3 | from imclaslib.config import Config 4 | import imclaslib.files.pathutils as pathutils 5 | import argparse 6 | import os 7 | import torch 8 | from imclaslib.evaluation.modelevaluator import ModelEvaluator 9 | from imclaslib.dataset.video_predict_dataset import VideoDatasetPredict 10 | from imclaslib.dataset.images_predict_dataset import ImageDatasetPredict 11 | from imclaslib.dataset import datasetutils 12 | from imclaslib.files import imageutils 13 | from torch.utils.data import DataLoader 14 | from pathlib import Path 15 | from PIL import Image 16 | from imclaslib.logging.loggerfactory import LoggerFactory 17 | from imclaslib.metrics import metricutils 18 | thisconfig = Config("default_config.yml") 19 | logger = LoggerFactory.setup_logging("logger", log_file=pathutils.combine_path(thisconfig, 20 | pathutils.get_log_dir_path(thisconfig), 21 | f"{thisconfig.model_name}_{thisconfig.model_image_size}_{thisconfig.model_weights}", 22 | f"inference__{pathutils.get_datetime()}.log"), config=thisconfig) 23 | thisconfig = Config.load_config_from_file("inference.yml", thisconfig) 24 | 25 | def main(args): 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | input_path = Path(args.input_path) 28 | output_folder = args.output_folder or input_path.parent.joinpath('inference_outputs') 29 | 30 | os.makedirs(output_folder, exist_ok=True) 31 | if thisconfig.model_ensemble_model_configs is not None: 32 | modelEvaluator = ModelEvaluator.from_ensemble(device, thisconfig=thisconfig) 33 | else: 34 | modelEvaluator = ModelEvaluator.from_file(device, thisconfig=thisconfig) 35 | if thisconfig.using_wsl and thisconfig.test_compile: 36 | logger.info("Compiling model") 37 | modelEvaluator.compile() 38 | 39 | if input_path.is_dir(): 40 | output_csv_path = "annotationresults.csv" 41 | image_paths = [] 42 | for root, dirs, files in os.walk(input_path): 43 | for img in files: 44 | if img.lower().endswith(('.png', '.jpg', '.jpeg')): 45 | image_paths.append(os.path.join(root, img)) 46 | image_dataset = ImageDatasetPredict(image_paths, config=thisconfig) 47 | dataset_loader = DataLoader(image_dataset, batch_size=thisconfig.test_batch_size, shuffle=False, num_workers=6, persistent_workers=True, pin_memory=False) 48 | optimalTemp = thisconfig.model_temperature 49 | predict_start_time = time.time() 50 | predictionResults = modelEvaluator.predict(dataset_loader, False) 51 | predict_end_time = time.time() 52 | 53 | seconds = predict_end_time - predict_start_time 54 | image_count = len(image_dataset) 55 | logger.info(f"Took {seconds} seconds to predict for {image_count} images with an average of {image_count/seconds} images per second")
56 | logits = torch.Tensor(predictionResults['predictions']).to(device) 57 | scaled_logits = metricutils.temperature_scale(logits, optimalTemp) 58 | prediction_confidences = metricutils.getConfidences(scaled_logits) 59 | predictions = metricutils.getpredictions_with_threshold(scaled_logits, device, threshold=0.5) 60 | image_paths = predictionResults['image_paths'] 61 | flattened_image_paths = [path for sublist in image_paths for path in sublist] 62 | uncertainty_metrics = metricutils.uncertainty_metrics(prediction_confidences) 63 | cumul_uncertainties = uncertainty_metrics['cumulative_uncertainties'] 64 | max_uncertainties = uncertainty_metrics['max_uncertainties'] 65 | mean_uncertainties = uncertainty_metrics['mean_uncertainties'] 66 | mean_entropies = uncertainty_metrics['mean_entropies'] 67 | 68 | # Prepare CSV data 69 | csv_data = [] 70 | for image_path, pred, cumul_uncert, max_uncert, mean_uncert, mean_entrop in zip(flattened_image_paths, predictions, cumul_uncertainties, max_uncertainties, mean_uncertainties, mean_entropies): 71 | # Make sure pred is a list and uncertainty is a scalar 72 | pred_list = pred.tolist() if isinstance(pred, torch.Tensor) else list(pred) 73 | uncertainty_scalar = float(cumul_uncert) # Convert to a Python float 74 | # Create a row with image file name, uncertainty metrics, and one-hot encoded predictions 75 | if thisconfig.using_wsl: 76 | image_path = pathutils.convert_wsl_path_to_windows(image_path) 77 | row = [image_path, uncertainty_scalar, max_uncert[0], max_uncert[1], mean_uncert, mean_entrop] + pred_list 78 | csv_data.append(row) 79 | 80 | while True: 81 | try: 82 | # Write the CSV data to a file 83 | with open(output_csv_path, 'w', newline='') as csv_file: 84 | csv_writer = csv.writer(csv_file) 85 | # Write the header 86 | tagmappings = datasetutils.get_index_to_tag_mapping(thisconfig) 87 | header = ['file_name', 'cumulative_uncertainty', 'Max Uncertainty', 'Max Uncertainty Tag', 'Mean Uncertainty', 'Mean Entropy'] + [f'{tagmappings[i]}' for i in range(len(predictions[0]))] 88 | csv_writer.writerow(header) 89 | # Write the rows 90 | csv_writer.writerows(csv_data) 91 | break # Exit the loop if file writing was successful 92 | except PermissionError: 93 | print("The file is currently open and cannot be written to. Please close the file and press Enter to retry.")
94 | input() # Wait for user to indicate they've closed the file 95 | except Exception as e: 96 | print(f"An unexpected error occurred: {e}") 97 | break # Exit the loop if an unexpected error occurs 98 | 99 | 100 | # Save the images with overlaid predictions 101 | for image_path, pred in zip(flattened_image_paths, predictions): 102 | original_image = Image.open(image_path) 103 | annotated_image = imageutils.overlay_predictions(original_image, pred, datasetutils.get_index_to_tag_mapping(thisconfig)) 104 | save_path = os.path.join(output_folder, os.path.basename(image_path)) 105 | annotated_image.save(save_path) 106 | elif input_path.is_file(): 107 | if str(input_path).lower().endswith(('.png', '.jpg', '.jpeg')): 108 | preprocessed_img = ImageDatasetPredict.preprocess_single_image(str(input_path), thisconfig) 109 | predicted_labels = modelEvaluator.single_image_prediction(preprocessed_img, 0.5) 110 | original_image = Image.open(input_path) 111 | annotated_image = imageutils.overlay_predictions(original_image, predicted_labels, datasetutils.get_index_to_tag_mapping(thisconfig)) 112 | save_path = os.path.join(output_folder, os.path.basename(input_path)) 113 | annotated_image.save(save_path) 114 | 115 | elif str(input_path).lower().endswith(('.mp4', '.avi', '.mov')): 116 | input_path = str(input_path) 117 | video_dataset = VideoDatasetPredict(input_path, args.time_interval, config=thisconfig) 118 | dataset_loader = DataLoader(video_dataset, batch_size=thisconfig.test_batch_size, shuffle=False) 119 | predictionResults = modelEvaluator.predict(dataset_loader, False, 0.5) 120 | predictions = predictionResults['predictions'] 121 | frame_counts = predictionResults['frame_counts'] 122 | 123 | flattened_frame_counts = [frame_count for sublist in frame_counts for frame_count in sublist] 124 | 125 | save_path = os.path.join(output_folder, os.path.basename(input_path)) 126 | imageutils.overlay_predictions_video(input_path, predictions, flattened_frame_counts, datasetutils.get_index_to_tag_mapping(thisconfig), save_path) 127 | else: 128 | print(f"Unsupported file type for input: {input_path}") 129 | else: 130 | print(f"Invalid input path: {input_path}") 131 | 132 | # Entry point: parse the CLI arguments and run inference 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser(description='Run inference on images or videos.') 135 | parser.add_argument('input_path', type=str, help='Path to an input image, directory of images, or video file.') 136 | parser.add_argument('--output_folder', type=str, help='Path to save the output predictions.', default=None) 137 | parser.add_argument('--time_interval', type=float, help='Interval in seconds of how frequently to process frames from a video.', default=0.5) 138 | 139 | args = parser.parse_args() 140 | main(args) 141 | --------------------------------------------------------------------------------
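For example, based on the argument parser above, a video run might look like this (the paths are illustrative):

```sh
python src/inference.py /path/to/video.mp4 --output_folder inference_outputs --time_interval 1.0
```

Omitting `--output_folder` writes results to an `inference_outputs` folder next to the input; `--time_interval` controls how many seconds apart video frames are sampled (default 0.5).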
/src/imclaslib/files/imageutils.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torch 3 | from PIL import ImageDraw, Image, ImageFont 4 | import matplotlib.pyplot as plt 5 | from torchvision.transforms import Compose, Resize, Normalize, ToTensor 6 | import cv2 7 | from torchvision.transforms.functional import to_pil_image 8 | import concurrent.futures 9 | import numpy as np 10 | 11 | from imclaslib.metrics import metricutils 12 | 13 | def denormalize_images(images, config): 14 | """ 15 | Denormalizes a batch of images using the specified mean and standard deviation. 16 | 17 | Parameters: 18 | images (torch.Tensor): Batch of images to denormalize. 19 | config (object): Configuration object with mean and std for denormalization. 20 | 21 | Returns: 22 | torch.Tensor: Batch of denormalized images. 23 | """ 24 | return torch.stack([denormalize(img.cpu(), config.dataset_normalization_mean, config.dataset_normalization_std) for img in images]) 25 | 26 | def denormalize(tensor, mean, std): 27 | """De-normalizes a tensor image with mean and standard deviation.""" 28 | # Clone the tensor so we don't change the original 29 | tensor = tensor.clone() 30 | for t, m, s in zip(tensor, mean, std): 31 | t.mul_(s).add_(m) # De-normalize 32 | return torch.clamp(tensor, 0, 1) 33 | 34 | def overlay_predictions_batch(images, predictions, index_to_tag, true_labels=None): 35 | """ 36 | Overlays prediction and (optionally) ground truth labels on a batch of images. 37 | 38 | Parameters: 39 | images (list or torch.Tensor): Batch of images to annotate. 40 | predictions (torch.Tensor): Predicted labels for each image. 41 | index_to_tag (dict): Mapping from label indices to tag names. 42 | true_labels (torch.Tensor, optional): True labels for each image. 43 | 44 | Returns: 45 | list: List of annotated images. 46 | """ 47 | annotated_images = [] 48 | for img, true_label_vec, pred_label_vec in zip(images, true_labels, predictions): 49 | annotated_images.append(overlay_predictions(img, pred_label_vec, index_to_tag, true_label_vec)) 50 | 51 | return annotated_images 52 | 53 | def overlay_predictions(image, predictions, index_to_tag, true_labels=None): 54 | """ 55 | Overlays prediction and (optionally) ground truth labels on an image. 56 | 57 | Parameters: 58 | image (PIL.Image or torch.Tensor): Single image to annotate. 59 | predictions (torch.Tensor or list): Predicted labels for the image. 60 | index_to_tag (dict): Mapping from label indices to tag names. 61 | true_labels (torch.Tensor or list, optional): True labels for the image. 62 | 63 | Returns: 64 | PIL.Image: Annotated image.
65 | """ 66 | # If the image is a tensor, convert it to a PIL Image first 67 | if isinstance(image, torch.Tensor): 68 | image = transforms.ToPILImage()(image) 69 | 70 | # Get the size of the image 71 | width, height = image.size 72 | 73 | # Set the font size to be proportional to the width of the image 74 | font_size = int(width * 0.01) # You can adjust the 0.01 factor as needed 75 | #font = ImageFont.truetype("arial.ttf", font_size) # You can choose a different font if you like 76 | font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf", font_size) 77 | 78 | draw = ImageDraw.Draw(image) 79 | 80 | pred_label_text = metricutils.convert_labels_to_string(predictions, index_to_tag) 81 | if true_labels is not None: 82 | true_label_text = metricutils.convert_labels_to_string(true_labels, index_to_tag) 83 | # Prepare text to be overlayed on the image 84 | text = f"True: {true_label_text}\nPred: {pred_label_text}" 85 | else: 86 | text = f"Pred: {pred_label_text}" 87 | 88 | # Set text position to be proportional to the size of the image 89 | text_x = width * 0.01 # You can adjust the 0.01 factor as needed 90 | text_y = height * 0.01 # You can adjust the 0.01 factor as needed 91 | 92 | # Draw the text on the image with the proportional font size 93 | draw.text((text_x, text_y), text, (57, 255, 20), font=font) # Green text, top-left corner 94 | 95 | return image 96 | 97 | def process_frame(frame, last_prediction, index_to_tag): 98 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 99 | pil_image = Image.fromarray(frame_rgb) 100 | annotated_image = overlay_predictions(pil_image, last_prediction, index_to_tag) 101 | frame = cv2.cvtColor(np.array(annotated_image), cv2.COLOR_RGB2BGR) 102 | return frame 103 | 104 | def overlay_predictions_video(video_path, predictions, frame_counts, index_to_tag, output_path): 105 | video_capture = cv2.VideoCapture(str(video_path)) 106 | fps = video_capture.get(cv2.CAP_PROP_FPS) 107 | width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)) 108 | height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 109 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 110 | video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) 111 | 112 | last_prediction = None 113 | frame_idx = 0 114 | pred_idx = 0 115 | frames_buffer = [] 116 | 117 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: 118 | while True: 119 | ret, frame = video_capture.read() 120 | if not ret: 121 | break 122 | if pred_idx < len(frame_counts) and frame_idx == frame_counts[pred_idx]: 123 | last_prediction = predictions[pred_idx] 124 | pred_idx += 1 125 | 126 | if last_prediction is not None: 127 | frames_buffer.append(frame) 128 | 129 | # If the buffer size reaches a threshold, process frames in parallel 130 | if len(frames_buffer) >= 10: # Example buffer size 131 | # Process frames in parallel 132 | processed_frames = list(executor.map(lambda f: process_frame(f, last_prediction, index_to_tag), frames_buffer)) 133 | # Write processed frames to the output video 134 | for processed_frame in processed_frames: 135 | video_writer.write(processed_frame) 136 | frames_buffer = [] 137 | 138 | frame_idx += 1 139 | 140 | # Make sure to process the remaining frames in the buffer 141 | if frames_buffer: 142 | processed_frames = list(executor.map(lambda f: process_frame(f, last_prediction, index_to_tag), frames_buffer)) 143 | for processed_frame in processed_frames: 144 | video_writer.write(processed_frame) 145 | 146 | video_capture.release() 147 |
video_writer.release() 148 | 149 | 150 | def convert_labels_to_color(labels, num_classes, height=10, width=10): 151 | """ 152 | Converts labels to a color representation using a colormap. 153 | 154 | Parameters: 155 | labels (torch.Tensor): Tensor of labels, either in class index form or one-hot encoded. 156 | num_classes (int): Number of classes. 157 | height (int): Height of the color image representation. 158 | width (int): Width of the color image representation. 159 | 160 | Returns: 161 | torch.Tensor: Color representation of labels in the shape [batch_size, channels, height, width]. 162 | """ 163 | # Generate a colormap 164 | cmap = plt.get_cmap('viridis', num_classes) # Get the colormap 165 | 166 | # Convert labels to indices if they are one-hot encoded 167 | if labels.ndim > 1 and labels.size(1) == num_classes: 168 | labels = labels.argmax(dim=1) 169 | else: 170 | # If labels are not one-hot encoded, ensure they are integer class indices 171 | labels = labels.long() 172 | 173 | # Normalize label indices to be between 0 and 1 174 | labels_normalized = labels.float() / (num_classes - 1) 175 | 176 | # Map normalized indices to colors using the colormap 177 | colors = cmap(labels_normalized.numpy())[:, :3] # Get the RGB values and exclude the alpha channel 178 | 179 | # Convert colors to a PyTorch tensor and reshape to [batch_size, 1, 1, channels] 180 | colors_tensor = torch.tensor(colors, dtype=torch.float32).view(-1, 1, 1, 3) 181 | 182 | # Repeat colors across the desired image dimensions to create a full image representation for each label 183 | colors_tensor = colors_tensor.repeat(1, height, width, 1) 184 | 185 | # Permute the tensor to match the [batch_size, channels, height, width] format 186 | colors_tensor = colors_tensor.permute(0, 3, 1, 2) 187 | 188 | return colors_tensor 189 | 190 | def convert_PIL_to_tensors(pil_images): 191 | """ 192 | Converts a list of PIL images to PyTorch tensors. 193 | 194 | Parameters: 195 | pil_images (list of PIL.Image): List of PIL images to convert. 196 | 197 | Returns: 198 | torch.Tensor: Batch of images as PyTorch tensors. 199 | """ 200 | return torch.stack([transforms.ToTensor()(img) for img in pil_images]) 201 | 202 | def convert_to_PIL(images): 203 | """ 204 | Converts a batch of PyTorch tensors to PIL images. 205 | 206 | Parameters: 207 | images (torch.Tensor or list of torch.Tensor): Batch of images to convert. 208 | 209 | Returns: 210 | list: List of PIL images. 211 | """ 212 | to_pil = transforms.ToPILImage() 213 | return [to_pil(image) for image in images] 214 | 215 | def preprocess_image(image_path, config): 216 | """Preprocess an image file to be suitable for model input. 217 | 218 | Parameters: 219 | image_path (str): Path of the image to process. 220 | config (object): Configuration object with the desired image size and normalization parameters.
221 | 222 | """ 223 | transforms = Compose([ 224 | Resize(config.model_image_size), # Resize to the input size expected by the model 225 | ToTensor(), # Convert to PyTorch Tensor 226 | Normalize(config.dataset_normalization_mean, config.dataset_normalization_std) # Normalize with the same values used in training 227 | ]) 228 | image = Image.open(image_path) 229 | return transforms(image).unsqueeze(0) # Add batch dimension -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/imclaslib/metrics/metricutils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import f1_score as sklearnf1, precision_score, recall_score 2 | import numpy as np 3 | import torch 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | 6 | from torch.optim import LBFGS 7 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 8 | 9 | def f1_score(targets, predictions, average='micro'): 10 | """ 11 | Compute F1 score for binary predictions. 12 | 13 | Parameters: 14 | - targets: array-like, true binary labels 15 | - predictions: array-like, binary predictions 16 | - average: string, [None, 'micro' (default), 'macro', 'samples', 'weighted'] 17 | 18 | Returns: 19 | - f1: float, computed F1 score 20 | """ 21 | return sklearnf1(targets, predictions, average=average, zero_division=1.0) 22 | 23 | def compute_metrics(targets, outputs, average='micro'): 24 | """ 25 | Compute precision, recall, and F1 score for each class. 26 | 27 | Parameters: 28 | - targets: array-like, true binary labels 29 | - outputs: array-like, binary predictions from the classifier 30 | - average: string, [None, 'micro' (default), 'macro', 'samples', 'weighted'] 31 | This parameter is required for multilabel targets. 32 | 33 | 34 | Returns: 35 | - precision: float or ndarray, precision score (per class when average=None) 36 | - recall: float or ndarray, recall score (per class when average=None) 37 | - f1: float or ndarray, F1 score (per class when average=None) 38 | """ 39 | 40 | precision = precision_score(targets, outputs, average=average, zero_division=1.0) 41 | recall = recall_score(targets, outputs, average=average, zero_division=1.0) 42 | f1 = sklearnf1(targets, outputs, average=average, zero_division=1.0) 43 | return precision, recall, f1 44 | 45 | def getConfidences(outputs): 46 | return torch.sigmoid(outputs).cpu().numpy() 47 | 48 | def getpredictions_with_threshold(outputs, device, threshold=0.5): 49 | """ 50 | Convert raw output scores to binary predictions by applying a threshold.
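The threshold may be a scalar applied to every class, or a 1-D numpy array with one entry per class (for example, the output of find_best_thresholds_per_class below).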
51 | 52 | Parameters: 53 | - outputs: numpy.ndarray, raw output scores from the classifier 54 | - threshold: float, threshold for converting raw scores to binary predictions 55 | 56 | Returns: 57 | - predictions: numpy.ndarray, binary predictions 58 | """ 59 | 60 | # Apply sigmoid to the outputs to get probabilities 61 | outputs = torch.Tensor(outputs).to(device) 62 | probabilities = getConfidences(outputs) 63 | 64 | if threshold is None: 65 | threshold = 0.5 66 | 67 | # Apply threshold to the probabilities to get binary predictions 68 | if np.isscalar(threshold): 69 | predictions = (probabilities > threshold).astype(int) 70 | else: 71 | # Ensure threshold is a numpy array and has the same number of elements as the number of classes 72 | if threshold.size != probabilities.shape[1]: 73 | raise ValueError("Threshold must have the same number of elements as the number of classes.") 74 | 75 | thresholds_reshaped = threshold.reshape(1, -1) 76 | # Repeat the threshold for each sample and compare 77 | predictions = (probabilities > thresholds_reshaped).astype(int) 78 | return predictions 79 | 80 | def convert_labels_to_strings(labels, index_to_tag): 81 | labels = labels.tolist() if isinstance(labels, torch.Tensor) else labels 82 | return (index_to_tag[i] for i, label in enumerate(labels) if label == 1) 83 | 84 | def convert_labels_to_string(labels, index_to_tag): 85 | labelstrings = convert_labels_to_strings(labels, index_to_tag) 86 | return ','.join(labelstrings) 87 | 88 | def cumulative_uncertainty(probabilities, certainty_window=0.03): 89 | return np.sum(np.where((probabilities > certainty_window) & (probabilities < (1-certainty_window)), 90 | np.minimum(probabilities - 0.0, 1.0 - probabilities), 0), axis=1) 91 | 92 | def uncertainty_metrics(probabilities, certainty_window=0.03): 93 | # Calculate cumulative uncertainty with the existing function logic 94 | cumulative_uncertainties = cumulative_uncertainty(probabilities, certainty_window) 95 | 96 | # Initialize lists to store the additional metrics 97 | max_uncertainties = [] 98 | mean_uncertainties = [] 99 | mean_entropies = [] 100 | 101 | # Calculate additional metrics for each image 102 | for image_probs in probabilities: 103 | # Calculate the uncertainty for each class probability in the image 104 | uncertainties = np.where( 105 | (image_probs > certainty_window) & (image_probs < (1 - certainty_window)), 106 | np.minimum(image_probs, 1.0 - image_probs), 107 | 0 108 | ) 109 | 110 | # Calculate the entropy for each class probability in the image 111 | eps = np.finfo(float).eps # Avoid division by zero in log 112 | image_entropies = -( 113 | image_probs * np.log(np.clip(image_probs, eps, 1)) + 114 | (1 - image_probs) * np.log(np.clip(1 - image_probs, eps, 1)) 115 | ) / np.log(2) # Normalize to log base 2 116 | 117 | # Calculate the mean entropy for the image 118 | mean_entropy = np.mean(image_entropies) 119 | 120 | # Find the max uncertainty (which is the max distance from certainty) for the image 121 | max_uncertainty_index = np.argmax(uncertainties) 122 | max_uncertainty = uncertainties[max_uncertainty_index] 123 | 124 | # Store the max uncertainty and its corresponding class number 125 | max_uncertainties.append((max_uncertainty, max_uncertainty_index)) 126 | 127 | # Calculate the mean uncertainty for the image 128 | mean_uncertainty = np.mean(uncertainties) 129 | 130 | # Append the mean entropy and mean uncertainty to their respective lists 131 | mean_entropies.append(mean_entropy) 132 | mean_uncertainties.append(mean_uncertainty) 133 | 
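# For intuition: a label probability of 0.5 yields the maximal normalized entropy of 1.0,
# while a confident probability of 0.03 yields -(0.03*log2(0.03) + 0.97*log2(0.97)) ≈ 0.19,
# so mean_entropies stays low only when every label is predicted near 0 or 1.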
134 | # Return the results as a dictionary 135 | return { 136 | 'cumulative_uncertainties': cumulative_uncertainties, 137 | 'max_uncertainties': max_uncertainties, 138 | 'mean_uncertainties': mean_uncertainties, 139 | 'mean_entropies': mean_entropies 140 | } 141 | 142 | def find_best_threshold(prediction_outputs, true_labels, device, metric='f1', num_thresholds=100, average='micro'): 143 | """ 144 | Find the best threshold for binary predictions to optimize the given metric. 145 | 146 | Parameters: 147 | prediction_outputs (numpy.ndarray): Raw model outputs. 148 | true_labels (numpy.ndarray): Corresponding true labels. 149 | device (torch.device): Device used when converting outputs to probabilities. 150 | metric (str): The metric to optimize ('f1', 'precision', or 'recall'). 151 | num_thresholds (int): The number of threshold values to consider between 0 and 1. 152 | average (str): The type of averaging performed when computing metrics. 153 | 154 | Returns: 155 | best_threshold (float): The threshold value that optimizes the given metric. 156 | best_f1, best_precision, best_recall (float): F1, precision, and recall at the best threshold. 157 | """ 158 | best_threshold = None 159 | best_metric_value = 0 160 | 161 | best_f1, best_precision, best_recall = 0, 0, 0 162 | 163 | for threshold in np.linspace(0, 1, num_thresholds): 164 | # Get binary predictions based on the current threshold 165 | predictions_binary = getpredictions_with_threshold(prediction_outputs, device, threshold) 166 | 167 | # Compute evaluation metrics 168 | precision, recall, f1 = compute_metrics(true_labels, predictions_binary, average=average) 169 | 170 | # Select the current metric 171 | if metric == 'f1': 172 | current_metric_value = f1 173 | elif metric == 'precision': 174 | current_metric_value = precision 175 | elif metric == 'recall': 176 | current_metric_value = recall 177 | else: 178 | raise ValueError("Invalid metric specified. Choose 'f1', 'precision', or 'recall'.") 179 | 180 | # Update the best threshold and metric value if the current one is better 181 | if current_metric_value > best_metric_value: 182 | best_metric_value = current_metric_value 183 | best_threshold = threshold 184 | best_f1, best_precision, best_recall = f1, precision, recall 185 | 186 | return best_threshold, best_f1, best_precision, best_recall 187 | 188 | def find_best_thresholds_per_class(probabilities, true_labels, metric='f1', num_thresholds=100): 189 | """ 190 | Find the best threshold for binary predictions for each class to optimize the given metric. 191 | 192 | Parameters: 193 | probabilities (numpy.ndarray): Predicted probabilities (post-sigmoid), shape (num_samples, num_classes). 194 | true_labels (numpy.ndarray): Corresponding true labels. 195 | metric (str): The metric to optimize ('f1', 'precision', or 'recall'). 196 | num_thresholds (int): The number of threshold values to consider between 0 and 1. 197 | 198 | Returns: 199 | best_thresholds (numpy.ndarray): The threshold value that optimizes the given metric for each class.
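Example (a minimal sketch with made-up arrays, not project data):

    probs = np.array([[0.9, 0.2], [0.4, 0.7], [0.8, 0.6]])  # post-sigmoid probabilities
    labels = np.array([[1, 0], [0, 1], [1, 1]])
    per_class = find_best_thresholds_per_class(probs, labels, metric='f1')
    # per_class has shape (2,): one threshold per class, usable as the
    # 'threshold' argument of getpredictions_with_threshold above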
200 | """ 201 | num_classes = probabilities.shape[1] 202 | best_thresholds = np.zeros(num_classes) 203 | 204 | for class_idx in range(num_classes): 205 | best_metric_value_for_class = 0 206 | best_threshold_for_class = 0.5 # Default threshold 207 | 208 | #logger.debug(f"Optimizing threshold for class {class_idx}.") 209 | for threshold in np.linspace(0, 1, num_thresholds): 210 | # Get binary predictions based on the current threshold for this class 211 | predictions_binary = (probabilities[:, class_idx] > threshold).astype(int) 212 | 213 | # Select the current metric 214 | if metric == 'f1': 215 | current_metric_value = sklearnf1(true_labels[:, class_idx], predictions_binary, zero_division=1) 216 | elif metric == 'precision': 217 | current_metric_value = precision_score(true_labels[:, class_idx], predictions_binary, zero_division=1) 218 | elif metric == 'recall': 219 | current_metric_value = recall_score(true_labels[:, class_idx], predictions_binary, zero_division=1) 220 | else: 221 | raise ValueError("Invalid metric specified. Choose 'f1', 'precision', or 'recall'.") 222 | #logger.debug(f"Class {class_idx}: Threshold {threshold:.4f}, {metric.capitalize()} {current_metric_value:.4f}") 223 | 224 | # Update the best threshold and metric value if the current one is better 225 | if current_metric_value > best_metric_value_for_class: 226 | best_metric_value_for_class = current_metric_value 227 | best_threshold_for_class = threshold 228 | 229 | best_thresholds[class_idx] = best_threshold_for_class 230 | return best_thresholds 231 | 232 | def filter_dict_for_hparams(input_dict): 233 | """ 234 | Filters out types that arent allowed in hparams for tensorboard 235 | """ 236 | # Define the allowed types 237 | allowed_types = (int, float, str, bool, torch.Tensor) 238 | 239 | # Create a new dictionary to store the filtered key-value pairs 240 | filtered_dict = {} 241 | 242 | # Iterate over the items in the original dictionary 243 | for key, value in input_dict.items(): 244 | # If the value is of an allowed type, add it to the new dictionary 245 | if isinstance(value, allowed_types): 246 | filtered_dict[key] = value 247 | 248 | return filtered_dict 249 | 250 | def multi_label_brier_score(y_true, y_pred, average='macro'): 251 | """ 252 | Calculate the Brier score for multi-label classification with various averaging methods. 253 | 254 | :param y_true: A NumPy array of ground truth labels with shape (num_samples, num_classes). 255 | :param y_pred: A NumPy array of predicted logits with shape (num_samples, num_classes). 256 | :param average: The averaging method to use ('macro', 'micro', 'weighted', 'samples'). 257 | :return: The Brier score. 
258 | """ 259 | y_pred = getConfidences(y_pred) 260 | num_classes = y_true.shape[1] 261 | 262 | if average == 'macro': 263 | # Calculate Brier score for each label and then average 264 | brier_scores = [np.mean((y_true[:, i] - y_pred[:, i]) ** 2) for i in range(y_true.shape[1])] 265 | return np.mean(brier_scores) 266 | 267 | elif average == 'micro': 268 | # Calculate the mean squared difference between all true labels and predictions 269 | return np.mean((y_true - y_pred) ** 2) 270 | 271 | elif average == 'weighted': 272 | # Calculate Brier score for each label, weighted by support (the number of true instances for each label) 273 | supports = np.sum(y_true, axis=0) 274 | brier_scores = [np.mean((y_true[:, i] - y_pred[:, i]) ** 2) for i in range(y_true.shape[1])] 275 | return np.average(brier_scores, weights=supports) 276 | 277 | elif average == 'samples': 278 | # Calculate Brier score for each individual label within each sample and average these 279 | brier_scores = np.mean((y_true - y_pred) ** 2, axis=1) 280 | return np.mean(brier_scores) 281 | 282 | else: 283 | raise ValueError("The 'average' parameter should be one of 'macro', 'micro', 'weighted', or 'samples'.") 284 | 285 | def temperature_scale(logits, temperature): 286 | """ 287 | Scale the logits by the temperature. 288 | """ 289 | return logits / temperature 290 | 291 | def find_optimal_temperature(valid_logits, valid_labels, device): 292 | """ 293 | Find the optimal temperature for multilabel classification using the validation set. 294 | """ 295 | # Initial temperature 296 | temperature = torch.nn.Parameter(torch.ones(1, device=device)) 297 | 298 | # Define the loss function and optimizer 299 | bce_with_logits_loss = torch.nn.BCEWithLogitsLoss().to(device) 300 | optimizer = LBFGS([temperature], lr=0.01, max_iter=50) 301 | 302 | def eval(): 303 | loss = bce_with_logits_loss(temperature_scale(valid_logits, temperature), valid_labels) 304 | loss.backward() 305 | return loss 306 | 307 | # Find the optimal temperature 308 | optimizer.step(eval) 309 | 310 | return temperature.item() -------------------------------------------------------------------------------- /src/imclaslib/evaluation/modelevaluator.py: -------------------------------------------------------------------------------- 1 | # evaluator.py 2 | import torch 3 | import wandb 4 | from imclaslib.logging.loggerfactory import LoggerFactory 5 | import imclaslib.files.pathutils as pathutils 6 | import imclaslib.models.modelfactory as modelfactory 7 | import imclaslib.metrics.metricutils as metricutils 8 | import imclaslib.files.modelloadingutils as modelloadingutils 9 | from torch.cuda.amp import autocast 10 | from tqdm import tqdm 11 | import numpy as np 12 | import torch.nn as nn 13 | import random 14 | import os 15 | import itertools 16 | import gc 17 | 18 | # Initialize logger for this module. 19 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 20 | 21 | class ModelEvaluator: 22 | def __init__(self, model, criterion, device, config, wandbWriter=None, model_data=None): 23 | """ 24 | Initializes the ModelEvaluator with a given model, loss criterion, device, 25 | optional TensorBoard writer, and configuration. 26 | 27 | Parameters: 28 | model (torch.nn.Module): The model to evaluate. 29 | criterion (function): The loss function. 30 | device (torch.device): The device to run evaluation on (CPU or GPU). 31 | wandbWriter (WandbWriter, optional): Writer for Wandb logging. 32 | config (object): An immutable configuration object with necessary parameters. 
33 | """ 34 | self.config = config 35 | self.model = model 36 | self.criterion = criterion 37 | self.device = device 38 | self.num_classes = config.model_num_classes 39 | self.wandbWriter = wandbWriter 40 | self.model_data = model_data 41 | self.metrics_enabled = (wandbWriter != None) 42 | 43 | def __enter__(self): 44 | """ 45 | Context management method to use with 'with' statements. 46 | """ 47 | return self 48 | 49 | def __exit__(self, exc_type, exc_value, traceback): 50 | """ 51 | Context management method to close the TensorBoard writer upon exiting the 'with' block. 52 | """ 53 | del self.model 54 | torch.cuda.empty_cache() 55 | gc.collect() 56 | 57 | @classmethod 58 | def from_trainer(cls, model_trainer): 59 | """ 60 | Creates a ModelEvaluator instance from a ModelTrainer instance by extracting 61 | the relevant attributes. 62 | 63 | Parameters: 64 | model_trainer (ModelTrainer): The trainer instance to extract attributes from. 65 | 66 | Returns: 67 | ModelEvaluator: A new instance of ModelEvaluator. 68 | """ 69 | return cls( 70 | model=model_trainer.model, 71 | criterion=model_trainer.criterion, 72 | device=model_trainer.device, 73 | config=model_trainer.config, 74 | wandbWriter=model_trainer.wandbWriter, 75 | model_data=model_trainer.best_model_state 76 | ) 77 | 78 | @classmethod 79 | def from_file(cls, device, thisconfig, wandbWriter=None): 80 | """ 81 | Creates a ModelEvaluator instance from a model file by loading in the model and preparing it 82 | to be run. 83 | 84 | Parameters: 85 | device (torch.device): The device to run evaluation on (CPU or GPU). 86 | wandbWriter (WandbWriter, optional): Writer for Wandb logging. 87 | config (object): An immutable configuration object with necessary parameters. 88 | """ 89 | 90 | model = modelfactory.create_model(thisconfig).to(device) 91 | criterion = nn.BCEWithLogitsLoss() 92 | 93 | modelToLoadPath = pathutils.get_model_to_load_path(thisconfig) 94 | if os.path.exists(modelToLoadPath): 95 | logger.info("Loading the best model...") 96 | modelData = modelloadingutils.load_model(modelToLoadPath, thisconfig) 97 | model.load_state_dict(modelData['model_state_dict']) 98 | logger.info("Loaded the best model in the Evaluator") 99 | else: 100 | logger.error(f"Could not find a model at path: {modelToLoadPath}") 101 | raise ValueError(f"Could not find a model at path: {modelToLoadPath}. Check to ensure the config/json value for model_name_to_load is correct!") 102 | 103 | return cls( 104 | model=model, 105 | criterion=criterion, 106 | device=device, 107 | config=thisconfig, 108 | wandbWriter=wandbWriter, 109 | model_data=modelData 110 | ) 111 | 112 | @classmethod 113 | def from_ensemble(cls, device, thisconfig, wandbWriter=None, loadFromFile=False): 114 | """ 115 | Creates a ModelEvaluator instance from a model file by loading in the model and preparing it 116 | to be run. 117 | 118 | Parameters: 119 | device (torch.device): The device to run evaluation on (CPU or GPU). 120 | wandbWriter (WandbWriter, optional): Writer for TensorBoard logging. 121 | config (object): An immutable configuration object with necessary parameters. 
122 | """ 123 | 124 | model = modelfactory.create_model(thisconfig).to(device) 125 | criterion = nn.BCEWithLogitsLoss() 126 | 127 | if loadFromFile: 128 | modelToLoadPath = pathutils.get_model_to_load_path(thisconfig) 129 | modelData = modelloadingutils.load_model(modelToLoadPath, thisconfig) 130 | model.load_state_dict(modelData['model_state_dict']) 131 | 132 | model_data = { 133 | "epoch": 1, 134 | "train_loss": 0, 135 | } 136 | 137 | return cls( 138 | model=model, 139 | criterion=criterion, 140 | device=device, 141 | config=thisconfig, 142 | wandbWriter=wandbWriter, 143 | model_data=model_data 144 | ) 145 | 146 | def single_image_prediction(self, preprocessed_image, threshold=None): 147 | """Run a prediction for a single preprocessed image.""" 148 | self.model.eval() # Set the model to evaluation mode 149 | 150 | # Move the preprocessed image to the same device as the model 151 | preprocessed_image = preprocessed_image.to(self.device) 152 | 153 | with torch.no_grad(): 154 | # Add a batch dimension to the image tensor 155 | image_batch = preprocessed_image.unsqueeze(0) 156 | outputs = self.model(image_batch) 157 | if threshold is not None: 158 | outputs_np = metricutils.getpredictions_with_threshold(outputs_np, self.device, threshold) 159 | # Wrap the NumPy array back into a PyTorch tensor if necessary 160 | outputs = torch.from_numpy(outputs_np) 161 | # Remove the batch dimension from the outputs before returning 162 | outputs = outputs.squeeze(0) 163 | return outputs 164 | 165 | def compile(self): 166 | self.model = torch.compile(self.model, mode="max-autotune") 167 | def predict(self, data_loader, return_true_labels=True, threshold=None): 168 | """ 169 | Perform inference on the given data_loader and return raw predictions. 170 | 171 | Parameters: 172 | data_loader (DataLoader): DataLoader for inference. 173 | return_true_labels (bool): If true, return true labels. Otherwise, skip label processing. 174 | 175 | Returns: 176 | prediction_labels (numpy.ndarray): Raw model outputs. 177 | true_labels (numpy.ndarray, optional): Corresponding true labels, if available and requested. 178 | avg_loss (float, optional): Average loss over dataset, if labels are available. 
179 | """ 180 | model = self.model 181 | model.eval() # Set the model to evaluation mode 182 | prediction_outputs = [] # List to store all raw model outputs 183 | true_labels = [] # List to store all labels if they are available 184 | image_paths = [] # List to store all image paths if they are available 185 | frame_counts = [] # List to store all frame counts if they are available 186 | total_loss = 0.0 # Initialize total loss 187 | with torch.no_grad(): # Disable gradient calculation for efficiency 188 | for batch in tqdm(data_loader, total=len(data_loader)): 189 | images = batch['image'].to(self.device) 190 | if self.config.model_fp16: 191 | images = images.half() 192 | with autocast(enabled=self.config.model_fp16): 193 | outputs = model(images) 194 | prediction_outputs.append(outputs.cpu().numpy()) # Store raw model outputs 195 | 196 | # Process labels if they are available and requested 197 | if return_true_labels and 'label' in batch: 198 | labels = batch['label'].to(self.device) 199 | loss = self.criterion(outputs, labels.float()) # Calculate loss 200 | total_loss += loss.item() # Accumulate loss 201 | true_labels.append(labels.cpu().numpy()) # Store labels 202 | elif not return_true_labels and 'image_path' in batch: 203 | image_paths.append(batch['image_path']) 204 | elif not return_true_labels and 'frame_count' in batch: 205 | frame_counts.append(batch['frame_count']) 206 | # Concatenate all raw outputs and optionally labels from all batches 207 | prediction_outputs = np.vstack(prediction_outputs) 208 | results = {'predictions': prediction_outputs} 209 | 210 | if return_true_labels and true_labels: 211 | true_labels = np.vstack(true_labels) 212 | avg_loss = total_loss / len(data_loader.dataset) 213 | results['true_labels'] = true_labels 214 | results['avg_loss'] = avg_loss 215 | 216 | if image_paths: 217 | results['image_paths'] = image_paths 218 | 219 | if frame_counts: 220 | results['frame_counts'] = frame_counts 221 | 222 | if threshold != None: 223 | predictions_binary = metricutils.getpredictions_with_threshold(prediction_outputs, self.device, threshold) 224 | results['predictions'] = predictions_binary 225 | 226 | return results 227 | 228 | def evaluate_predictions(self, data_loader, prediction_outputs, true_labels, epoch, average, datasetSubset=None, metricMode=None, threshold=None): 229 | """ 230 | Evaluate the model on the given data_loader. 231 | 232 | Parameters: 233 | data_loader (DataLoader): DataLoader for evaluation. 234 | prediction_outputs (numpy.ndarray): Raw model outputs. 235 | true_labels (numpy.ndarray): Corresponding true labels. 236 | epoch (int): The current epoch number, used for TensorBoard logging. 237 | datasetSubset (str): Indicates the subset of data evaluated (e.g., 'test', 'validation'). 238 | average (str): Indicates the type of averaging to perform when computing metrics. Use None to get per-class metrics. 239 | metricMode (str, optional): Indicates from where this is being evaluated from (e.g., 'Train', 'Test'). 240 | threshold (float, optional): The threshold value for binary predictions. 241 | 242 | Returns: 243 | f1_score (float): The F1 score of the model on the dataset. 244 | precision (float): The precision of the model on the dataset. 245 | recall (float): The recall of the model on the dataset. 
246 | """ 247 | predictions_binary = metricutils.getpredictions_with_threshold(prediction_outputs, self.device, threshold) 248 | # Compute evaluation metrics 249 | precision, recall, f1 = metricutils.compute_metrics(true_labels, predictions_binary, average=average) 250 | #if f1 >= 0.9: 251 | #something is wrong 252 | # i = 1 253 | # Log images with predictions to TensorBoard for a random batch, if configured 254 | if metricMode is not None and self.wandbWriter is not None and datasetSubset is not None: 255 | random_batch_index = random.randint(0, len(data_loader) - 1) 256 | batch_dict = next(itertools.islice(data_loader, random_batch_index, None)) 257 | images = batch_dict['image'] # Assuming the device transfer happens elsewhere if needed 258 | labels = batch_dict['label'] 259 | 260 | start_index = random_batch_index * data_loader.batch_size 261 | end_index = min((random_batch_index + 1) * data_loader.batch_size, len(predictions_binary)) 262 | 263 | selected_predictions = predictions_binary[start_index:end_index] 264 | selected_predictions_tensor = torch.tensor(selected_predictions, device=self.device, dtype=torch.float32) 265 | #self.tensorBoardWriter.write_image_test_results(images, labels, selected_predictions_tensor, epoch, metricMode, datasetSubset) 266 | # Return the average loss and computed metrics 267 | return f1, precision, recall 268 | 269 | def evaluate(self, data_loader, epoch, datasetSubset, metricMode=None, average='micro', threshold=None): 270 | """ 271 | Evaluate the model on the given data_loader. 272 | 273 | Parameters: 274 | data_loader (DataLoader): DataLoader for evaluation. 275 | epoch (int): The current epoch number, used for TensorBoard logging. 276 | datasetSubset (str): Indicates the subset of data being evaluated (e.g., 'test', 'validation'). 277 | average (str): Indicates the type of averaging to perform when computing metrics. Use None to get per-class metrics. 278 | metricMode (str, optional): Indicates from where this is being evaluated from (e.g., 'Train', 'Test'). 279 | threshold (float, optional): The threshold value for binary predictions. 280 | 281 | Returns: 282 | avg_loss (float): The average loss over the dataset. 283 | f1_score (float): The F1 score of the model on the dataset. 284 | precision (float): The precision of the model on the dataset. 285 | recall (float): The recall of the model on the dataset. 
286 | """ 287 | # Perform inference and get raw outputs 288 | prediction_results = self.predict(data_loader) 289 | all_outputs, all_labels, avg_loss = prediction_results['predictions'], prediction_results['true_labels'], prediction_results['avg_loss'] 290 | 291 | f1, precision, recall = self.evaluate_predictions(data_loader, all_outputs, all_labels, epoch, average, datasetSubset, metricMode, threshold) 292 | 293 | # Return the average loss and computed metrics 294 | return avg_loss, f1, precision, recall 295 | -------------------------------------------------------------------------------- /src/imclaslib/evaluation/test_model.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from imclaslib.evaluation.modelevaluator import ModelEvaluator 3 | from imclaslib.logging.loggerfactory import LoggerFactory 4 | from imclaslib.metrics import metricutils 5 | import torch 6 | import imclaslib.dataset.datasetutils as datasetutils 7 | import time 8 | import numpy as np 9 | from scipy.special import expit 10 | 11 | # Set up logging for the training process 12 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 13 | 14 | def evaluate_model(this_config, valid_loader=None, test_loader=None, wandbWriter=None): 15 | # initialize the computation device 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | if test_loader == None: 19 | test_loader = datasetutils.get_data_loader_by_name("test", config=this_config, num_workers=0) 20 | if valid_loader == None: 21 | valid_loader = datasetutils.get_data_loader_by_name("valid", config=this_config, num_workers=0) 22 | 23 | # intialize the model 24 | with get_model_evaluator(this_config, device, wandbWriter=wandbWriter) as modelEvaluator: 25 | epochs = modelEvaluator.model_data["epoch"] 26 | if this_config.using_wsl and this_config.test_compile: 27 | modelEvaluator.compile() 28 | valid_start_time = time.time() 29 | valid_results = modelEvaluator.predict(valid_loader) 30 | valid_end_time = time.time() 31 | 32 | test_start_time = time.time() 33 | test_results = modelEvaluator.predict(test_loader) 34 | test_end_time = time.time() 35 | 36 | valid_predictions, valid_correct_labels, valid_loss = valid_results['predictions'], valid_results['true_labels'], valid_results['avg_loss'] 37 | test_predictions, test_correct_labels, test_loss = test_results['predictions'], test_results['true_labels'], test_results['avg_loss'] 38 | 39 | # Assuming device is already defined as in your code snippet 40 | # Perform temperature scaling on validation logits 41 | valid_logits = torch.Tensor(valid_predictions).to(device) 42 | valid_labels = torch.Tensor(valid_correct_labels).to(device) # Assuming valid_correct_labels are floats in [0, 1] 43 | optimal_temperature = metricutils.find_optimal_temperature(valid_logits, valid_labels, device) 44 | 45 | valid_predictions = metricutils.temperature_scale(valid_logits, optimal_temperature) 46 | # Apply temperature scaling to test logits 47 | test_logits = torch.Tensor(test_predictions).to(device) 48 | test_predictions = metricutils.temperature_scale(test_logits, optimal_temperature) 49 | 50 | confidence_thresholds = (0.01, 0.02, 0.05, 0.1, 0.2) # Adjust these thresholds to suit your needs 51 | test_proabilities = metricutils.getConfidences(test_predictions) #torch.sigmoid(test_predictions).cpu().numpy() 52 | test_confidence_categories = categorize_predictions(test_proabilities, confidence_thresholds) 53 | 54 | valid_elapsed_time = valid_end_time - valid_start_time 
55 | test_elapsed_time = test_end_time - test_start_time 56 | 57 | valid_num_images = len(valid_loader.dataset) 58 | test_num_images = len(test_loader.dataset) 59 | 60 | valid_images_per_second = valid_num_images / valid_elapsed_time 61 | test_images_per_second = test_num_images / test_elapsed_time 62 | 63 | avg_images_per_second = (valid_num_images + test_num_images) / (valid_elapsed_time + test_elapsed_time) 64 | 65 | logger.info(f"Validation Img/sec: {valid_images_per_second}") 66 | logger.info(f"Test Img/sec: {test_images_per_second}") 67 | 68 | logger.info(f"Avg Img/sec: {avg_images_per_second}") 69 | 70 | logger.info(f"Validation Loss: {valid_loss}") 71 | logger.info(f"Test Loss: {test_loss}") 72 | 73 | val_f1_default_micro, val_precision_default_micro, val_recall_default_micro = modelEvaluator.evaluate_predictions(valid_loader, valid_predictions, valid_correct_labels, epochs, threshold=0.5, average="micro") 74 | test_f1_default_micro, test_precision_default_micro, test_recall_default_micro = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=0.5, average="micro") 75 | 76 | val_f1_default_macro, val_precision_default_macro, val_recall_default_macro = modelEvaluator.evaluate_predictions(valid_loader, valid_predictions, valid_correct_labels, epochs, threshold=0.5, average="macro") 77 | test_f1_default_macro, test_precision_default_macro, test_recall_default_macro = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=0.5, average="macro") 78 | 79 | val_f1_default_weighted, val_precision_default_weighted, val_recall_default_weighted = modelEvaluator.evaluate_predictions(valid_loader, valid_predictions, valid_correct_labels, epochs, threshold=0.5, average="weighted") 80 | test_f1_default_weighted, test_precision_default_weighted, test_recall_default_weighted = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=0.5, average="weighted") 81 | 82 | val_f1_default_samples, val_precision_default_samples, val_recall_default_samples = modelEvaluator.evaluate_predictions(valid_loader, valid_predictions, valid_correct_labels, epochs, threshold=0.5, average="samples") 83 | test_f1_default_samples, test_precision_default_samples, test_recall_default_samples = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=0.5, average="samples") 84 | 85 | valid_brier_micro = metricutils.multi_label_brier_score(valid_correct_labels, valid_predictions, 'micro') 86 | valid_brier_macro = metricutils.multi_label_brier_score(valid_correct_labels, valid_predictions, 'macro') 87 | valid_brier_weighted = metricutils.multi_label_brier_score(valid_correct_labels, valid_predictions, 'weighted') 88 | valid_brier_samples = metricutils.multi_label_brier_score(valid_correct_labels, valid_predictions, 'samples') 89 | logger.info(f"Brier Score for Validation: Macro:{valid_brier_macro}, Micro:{valid_brier_micro}, Weighted:{valid_brier_weighted}, Samples:{valid_brier_samples}") 90 | 91 | test_brier_micro = metricutils.multi_label_brier_score(test_correct_labels, test_predictions, 'micro') 92 | test_brier_macro = metricutils.multi_label_brier_score(test_correct_labels, test_predictions, 'macro') 93 | test_brier_weighted = metricutils.multi_label_brier_score(test_correct_labels, test_predictions, 'weighted') 94 | test_brier_samples = metricutils.multi_label_brier_score(test_correct_labels, test_predictions, 'samples') 
95 | logger.info(f"Brier Score for Test: Macro:{test_brier_macro}, Micro:{test_brier_micro}, Weighted:{test_brier_weighted}, Samples:{test_brier_samples}") 96 | 97 | logger.info(f"Validation Default F1: F1: {val_f1_default_micro}, Precision: {val_precision_default_micro}, Recall: {val_recall_default_micro} at Threshold: 0.5") 98 | logger.info(f"Test Default F1: F1: {test_f1_default_micro}, Precision: {test_precision_default_micro}, Recall: {test_recall_default_micro} at Threshold: 0.5") 99 | 100 | val_best_f1_threshold, val_f1_valoptimized, val_precision_valoptimized, val_recall_valoptimized = metricutils.find_best_threshold(valid_predictions, valid_correct_labels, device, "f1") 101 | logger.info(f"Validation Best F1: F1: {val_f1_valoptimized}, Precision: {val_precision_valoptimized}, Recall: {val_recall_valoptimized} at Threshold:{val_best_f1_threshold}") 102 | test_f1_valoptimized, test_precision_valoptimized, test_recall_valoptimized = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=val_best_f1_threshold, average="micro", datasetSubset="Test", metricMode="Test") 103 | test_f1_valoptimized_no_temp, test_precision_valoptimized_no_temp, test_recall_valoptimized_no_temp = modelEvaluator.evaluate_predictions(test_loader, test_logits, test_correct_labels, epochs, threshold=val_best_f1_threshold, average="micro", datasetSubset="Test", metricMode="Test") 104 | 105 | test_f1_valoptimized_macro, _, _ = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=val_best_f1_threshold, average="macro", datasetSubset="Test", metricMode="Test") 106 | test_f1_valoptimized_weighted, _, _ = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=val_best_f1_threshold, average="weighted", datasetSubset="Test", metricMode="Test") 107 | test_f1_valoptimized_samples, _, _ = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=val_best_f1_threshold, average="samples", datasetSubset="Test", metricMode="Test") 108 | logger.info(f"Test Best F1 (measured from Val): F1: {test_f1_valoptimized}, Precision: {test_precision_valoptimized}, Recall: {test_recall_valoptimized} at Threshold:{val_best_f1_threshold}") 109 | logger.info(f"Test Best F1 (measured from Val): F1 Macro: {test_f1_valoptimized_macro}, F1 Weighted: {test_f1_valoptimized_weighted}, F1 Samples: {test_f1_valoptimized_samples} at Threshold:{val_best_f1_threshold}") 110 | 111 | #best_f1_thresholds_per_class = metricutils.find_best_thresholds_per_class(metricutils.getConfidences(valid_predictions), valid_correct_labels) 112 | #test_f1_valoptimizedperclass, test_precision_valoptimizedperclass, test_recall_valoptimizedperclass = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=best_f1_thresholds_per_class, average="micro") 113 | #logger.info(f"Test Best F1 Per Class (Val Optimized): F1: {test_f1_valoptimizedperclass}, Precision: {test_precision_valoptimizedperclass}, Recall: {test_recall_valoptimizedperclass} at Threshold:{best_f1_thresholds_per_class}") 114 | 115 | hparams = metricutils.filter_dict_for_hparams(modelEvaluator.model_data) 116 | final_metrics = { 117 | 'F1/Default/Validation/Micro': val_f1_default_micro, 118 | 'F1/Default/Validation/Macro': val_f1_default_macro, 119 | 'F1/Default/Validation/Weighted': val_f1_default_weighted, 120 | 'F1/Default/Validation/Samples': 
val_f1_default_samples, 121 | 'F1/Default/Test/Micro': test_f1_default_micro, 122 | 'F1/Default/Test/Macro': test_f1_default_macro, 123 | 'F1/Default/Test/Weighted': test_f1_default_weighted, 124 | 'F1/Default/Test/Samples': test_f1_default_samples, 125 | 'F1/ValOptimizedThreshold/Validation': val_f1_valoptimized, 126 | 'F1/ValOptimizedThreshold/Test/Micro': test_f1_valoptimized, 127 | 'F1/ValOptimizedThreshold/Test/Raw/Micro': test_f1_valoptimized_no_temp, 128 | 'F1/ValOptimizedThreshold/Test/Macro': test_f1_valoptimized_macro, 129 | 'F1/ValOptimizedThreshold/Test/Weighted': test_f1_valoptimized_weighted, 130 | 'F1/ValOptimizedThreshold/Test/Samples': test_f1_valoptimized_samples, 131 | 'Precision/Default/Validation': val_precision_default_micro, 132 | 'Precision/Default/Test': test_precision_default_micro, 133 | 'Precision/ValOptimizedThreshold/Validation': val_precision_valoptimized, 134 | 'Precision/ValOptimizedThreshold/Test': test_precision_valoptimized, 135 | 'Recall/Default/Validation': val_recall_default_micro, 136 | 'Recall/Default/Test': test_recall_default_micro, 137 | 'Recall/ValOptimizedThreshold/Validation': val_recall_valoptimized, 138 | 'Recall/ValOptimizedThreshold/Test': test_recall_valoptimized, 139 | 'ImagesPerSecond/Validation': valid_images_per_second, 140 | 'ImagesPerSecond/Test': test_images_per_second, 141 | 'ImagesPerSecond/Average': avg_images_per_second, 142 | 'Loss/TrainOverTestRatio': modelEvaluator.model_data["train_loss"] / test_loss, 143 | 'Brier/Validation/Micro': valid_brier_micro, 144 | 'Brier/Validation/Macro': valid_brier_macro, 145 | 'Brier/Validation/Weighted': valid_brier_weighted, 146 | 'Brier/Validation/Samples': valid_brier_samples, 147 | 'Brier/Test/Micro': test_brier_micro, 148 | 'Brier/Test/Macro': test_brier_macro, 149 | 'Brier/Test/Weighted': test_brier_weighted, 150 | 'Brier/Test/Samples': test_brier_samples, 151 | 'Temperature/ValOptimized': optimal_temperature 152 | } 153 | wandbWriter.log(final_metrics) 154 | 155 | test_f1s_per_class, test_precision_per_class, test_recall_per_class = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=val_best_f1_threshold, average=None) 156 | tagmappings = datasetutils.get_index_to_tag_mapping(this_config) 157 | testF1s = [] 158 | 159 | annotationCounts, fileCounts = datasetutils.analyze_csv(this_config) 160 | for class_index in range(this_config.model_num_classes): 161 | testF1s.append([tagmappings[class_index], test_f1s_per_class[class_index], test_precision_per_class[class_index], test_recall_per_class[class_index], annotationCounts[tagmappings[class_index]]]) 162 | wandbWriter.log_table("F1_Scores_by_Class", ["ClassName", "ClassF1", "ClassPrecision", "ClassRecall", "ClassDatasetCount"], testF1s) 163 | 164 | testF1s = [] 165 | test_f1s_per_class_default, test_precision_per_class_default, test_recall_per_class_default = modelEvaluator.evaluate_predictions(test_loader, test_predictions, test_correct_labels, epochs, threshold=0.5, average=None) 166 | for class_index in range(this_config.model_num_classes): 167 | testF1s.append([tagmappings[class_index], test_f1s_per_class_default[class_index], test_precision_per_class_default[class_index], test_recall_per_class_default[class_index], annotationCounts[tagmappings[class_index]]]) 168 | wandbWriter.log_table("F1_Scores_by_Class_Default", ["ClassName", "ClassF1", "ClassPrecision", "ClassRecall", "ClassDatasetCount"], testF1s) 169 | 170 | wandbWriter.log({"Dataset/Stats": fileCounts}) 171 | # Prepare to 
store results by category 172 | f1_scores_by_category = [] 173 | samples_by_category = [] 174 | 175 | data = [] 176 | confidence_thresholds_len = len(confidence_thresholds) 177 | # Calculate F1 scores for each category 178 | for category in range(confidence_thresholds_len+1): 179 | category_mask = (test_confidence_categories == category) 180 | category_predictions = test_probabilities[category_mask] 181 | category_true_labels = test_correct_labels[category_mask] 182 | 183 | # Ensure there are samples in the category before calculating F1 184 | if category_predictions.shape[0] > 0: 185 | binary_predictions = (category_predictions > val_best_f1_threshold).astype(int) 186 | category_f1 = metricutils.f1_score(category_true_labels, binary_predictions) 187 | else: 188 | category_f1 = None 189 | 190 | f1_scores_by_category.append(category_f1) 191 | samples_by_category.append(np.sum(category_mask)) 192 | 193 | # Log results 194 | logger.info(f"Confidence Category {category} - Images: {samples_by_category[-1]}, F1 Score: {f1_scores_by_category[-1]}") 195 | 196 | data.append([confidence_thresholds[category] if category < confidence_thresholds_len else 1000, f1_scores_by_category[-1] if category_f1 is not None else 0, samples_by_category[-1] if category_f1 is not None else 0]) # 1000 is a sentinel for the open-ended top bucket; 'is not None' keeps a legitimate 0.0 F1 and its sample count from being zeroed out 197 | assert np.sum(samples_by_category) == test_num_images 198 | wandbWriter.log_table("Data by Categories of Confidence", ["Confidence Threshold", "F1 Score", "Sample Count"], data) 199 | 200 | def get_model_evaluator(config, device, wandbWriter): 201 | if config.model_ensemble_model_configs: 202 | return ModelEvaluator.from_ensemble(device, config, wandbWriter=wandbWriter) 203 | else: 204 | return ModelEvaluator.from_file(device, config, wandbWriter=wandbWriter) 205 | 206 | def categorize_predictions(probabilities, thresholds, certainty_window=0.03): 207 | """ 208 | Categorize images into confidence levels based on the mean per-label prediction entropy, a proxy for distance from the decision boundary.
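Internally, each image's mean per-label entropy (from metricutils.uncertainty_metrics) is binned with np.digitize, so category 0 holds the most confident images and the final category holds images whose mean entropy exceeds every threshold.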
209 | 210 | Parameters: 211 | - probabilities: numpy.ndarray, the predicted probabilities for each label of each image 212 | - thresholds: tuple, ascending entropy thresholds used to bin images into confidence categories 213 | - certainty_window: float, probabilities within this distance of 0 or 1 are treated as fully certain 214 | 215 | Returns: 216 | - categories: numpy.ndarray, the confidence category for each image 217 | """ 218 | 219 | mean_entropies = metricutils.uncertainty_metrics(probabilities, certainty_window)["mean_entropies"] 220 | 221 | # Log the mean and standard deviation of the per-image mean entropies for debugging 222 | mean_uncertainty = np.mean(mean_entropies) 223 | std_uncertainty = np.std(mean_entropies) 224 | logger.debug(f"Mean entropy: {mean_uncertainty}") 225 | logger.debug(f"Standard deviation of entropy: {std_uncertainty}") 226 | 227 | # Bin each image's mean entropy into a category index (0 = most confident) 228 | categories = np.digitize(mean_entropies, thresholds) 229 | 230 | return categories -------------------------------------------------------------------------------- /src/imclaslib/dataset/image_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import hashlib 4 | import cv2 5 | import numpy as np 6 | from torchvision.transforms import v2 as transforms 7 | from torchvision.transforms.v2.functional import resize 8 | from torch.utils.data import Dataset 9 | from imclaslib.files import pathutils 10 | from imclaslib.logging.loggerfactory import LoggerFactory 11 | import imclaslib.dataset.datasetutils as datasetutils 12 | from torchvision.transforms.functional import InterpolationMode 13 | from PIL import Image 14 | import pandas as pd 15 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 16 | 17 | class ImageDataset(Dataset): 18 | """ 19 | A dataset class for loading and transforming images for model training and evaluation. 20 | """ 21 | 22 | def __init__(self, csv, mode, config, random_state=42): 23 | """ 24 | Initializes the dataset with images and labels based on the provided CSV file and mode. 25 | 26 | Parameters: 27 | - csv: pandas.DataFrame, contains file paths to images and associated labels. 28 | - mode: str, one of 'train', 'valid', 'test', 'valid+test', or 'all' to determine dataset usage. 29 | - config: Config, configuration object containing dataset parameters. 30 | - random_state: int, random state for reproducible train-test splits.
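Example (a minimal sketch; 'dataset.csv' is a hypothetical file with a 'filepath' column plus one 0/1 column per tag):

    df = pd.read_csv('dataset.csv')
    train_ds = ImageDataset(df, mode='train', config=config)
    sample = train_ds[0]  # a dict with 'image', 'label', and 'image_path'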
31 | """ 32 | if mode not in ['train', 'valid', 'test', 'valid+test', 'all']: 33 | raise ValueError("Mode must be 'train', 'valid', 'test', 'valid+test', or 'all'.") 34 | 35 | self.label_mapping = config.dataset_tags_mapping_dict 36 | self.class_to_idx = datasetutils.get_tag_to_index_mapping(config) 37 | self.csv = csv 38 | self.config = config 39 | self.mode = mode 40 | #self.csv['identifier'] = self.csv['filepath'] 41 | if config.using_wsl: 42 | self.csv['filepath'] = self.csv['filepath'].apply(pathutils.convert_windows_path_to_wsl) 43 | 44 | self.csv['identifier'] = self.csv['filepath'].apply(lambda x: os.path.basename(x)) 45 | self.all_image_names = self.csv[:]['filepath'] 46 | 47 | # Assuming all columns other than 'filepath' and 'identifier' are labels and should be numeric 48 | label_columns = self.csv.columns.drop(['filepath', 'identifier']) 49 | 50 | # Convert label columns to a numeric type (e.g., float32) and handle NaNs 51 | self.csv[label_columns] = self.csv[label_columns].apply(pd.to_numeric, errors='coerce').fillna(0.0).astype(np.float32) 52 | 53 | # Now create the all_labels array with a uniform dtype 54 | self.all_labels = np.array(self.csv[label_columns]) 55 | 56 | self.image_size = self.config.model_image_size 57 | train_size = config.dataset_train_percentage 58 | valid_size = config.dataset_valid_percentage 59 | test_size = config.dataset_test_percentage 60 | total_size = train_size + valid_size + test_size 61 | if total_size > 100: 62 | raise ValueError("The sum of train, valid, and test percentages should be <= 100.") 63 | 64 | # Convert self.all_image_names to a list if it's a pandas Series 65 | self.all_image_names = self.all_image_names.tolist() 66 | 67 | # Perform a stable split 68 | train_data, valid_data, test_data = self.stable_split( 69 | self.csv, train_size, valid_size, test_size, random_state=random_state 70 | ) 71 | 72 | # Map back to the original data format 73 | train_names = train_data['filepath'].tolist() 74 | train_labels = self.map_labels(train_data) 75 | 76 | valid_names = valid_data['filepath'].tolist() 77 | valid_labels = self.map_labels(valid_data) 78 | 79 | test_names = test_data['filepath'].tolist() 80 | test_labels = self.map_labels(test_data) 81 | 82 | # Concatenate validation and test sets to create valid+test set 83 | valid_test_names = np.concatenate((valid_names, test_names)) 84 | valid_test_labels = np.vstack((valid_labels, test_labels)) 85 | 86 | # Assign data based on mode 87 | if self.mode == 'train': 88 | self.image_names = train_names 89 | self.labels = train_labels 90 | self.transform = self.train_transforms() 91 | elif self.mode == 'valid': 92 | self.image_names = valid_names 93 | self.labels = valid_labels 94 | self.transform = self.valid_transforms() 95 | elif self.mode == 'test': 96 | self.image_names = test_names 97 | self.labels = test_labels 98 | self.transform = self.test_transforms() 99 | elif self.mode == 'valid+test': 100 | # Combine valid and test sets for the valid+test mode 101 | self.image_names = valid_test_names 102 | self.labels = valid_test_labels 103 | self.transform = self.test_transforms() 104 | elif self.mode == 'all': 105 | # Combine valid and test sets for the valid+test mode 106 | self.image_names = self.all_image_names 107 | self.labels = self.all_labels 108 | self.transform = self.test_transforms() 109 | else: 110 | raise ValueError("Mode must be 'train', 'valid', 'test', or 'valid+test'.") 111 | 112 | if self.config.dataset_preprocess_to_RAM: 113 | self.data = [] 114 | for index, file_path in 
enumerate(self.image_names): 115 | label = self.labels[index] 116 | 117 | image = Image.open(file_path).convert('RGB') 118 | if image is None: # defensive check; PIL normally raises on unreadable files 119 | logger.warning(f"Warning: Image not found or corrupted at path: {file_path}") 120 | continue # skip this sample instead of returning from __init__, which would silently truncate preprocessing 121 | image = resize(image, (self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC) 122 | item = { 123 | 'image': image, 124 | 'label': torch.tensor(label, dtype=torch.float32), 125 | 'image_path': file_path 126 | } 127 | self.data.append(item) 128 | 129 | # Apply the label mapping to each subset after splitting 130 | def map_labels(self, data): 131 | # Initialize a label matrix for the given subset of data 132 | label_matrix = np.zeros((len(data), len(self.class_to_idx)), dtype=float) 133 | 134 | # Map the old and new labels using the dictionary 135 | for col in self.csv.columns: 136 | if col in self.class_to_idx: 137 | # This label is directly in the class_to_idx, so use it as is 138 | class_idx = self.class_to_idx[col] 139 | label_matrix[:, class_idx] = data[col].values 140 | if col in self.label_mapping: 141 | # This label should be mapped to another label 142 | mapped_label = self.label_mapping[col] 143 | class_idx = self.class_to_idx[mapped_label] 144 | # Set the broader category label to true if this or any previously mapped label is true 145 | label_matrix[:, class_idx] = np.logical_or(label_matrix[:, class_idx], data[col].values) 146 | return label_matrix 147 | 148 | # Define a function to scale the augmentation parameters based on input level (0-10) 149 | def scale_parameter(self, min_val, max_val, level): 150 | """Scales the parameter based on the augmentation level (0-10).""" 151 | return min_val + (max_val - min_val) * level / 10 152 | 153 | def train_transforms(self): 154 | augmentation_level = self.config.dataset_augmentation_level 155 | assert 0 <= augmentation_level <= 10, "Augmentation level must be between 0 and 10" 156 | 157 | # Define the augmentation parameters scaled by the augmentation_level 158 | horizontal_flip_prob = self.scale_parameter(0, 0.5, augmentation_level) 159 | color_jitter_brightness = self.scale_parameter(0, 0.5, augmentation_level) 160 | color_jitter_contrast = self.scale_parameter(0, 0.5, augmentation_level) 161 | color_jitter_saturation = self.scale_parameter(0, 0.5, augmentation_level) 162 | rotation_degrees = self.scale_parameter(0, 45, augmentation_level) 163 | affine_transform_degrees = self.scale_parameter(0, 10, augmentation_level) 164 | affine_transform_translate = self.scale_parameter(0, 0.05, augmentation_level) 165 | affine_transform_scale_min = self.scale_parameter(1, 0.95, augmentation_level) 166 | affine_transform_scale_max = self.scale_parameter(1, 1.05, augmentation_level) 167 | perspective_distortion_scale = self.scale_parameter(0, 0.2, augmentation_level) 168 | gaussian_blur_sigma = self.scale_parameter(0.1, 2, augmentation_level) 169 | random_erasing_prob = self.scale_parameter(0, 0.3, augmentation_level) 170 | 171 | # Now, create the list of transforms with the scaled parameters 172 | transforms_list = [] 173 | 174 | 175 | if self.config.dataset_normalization_mean is None: 176 | transforms_list = [ 177 | transforms.ToImage(), 178 | transforms.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC), 179 | transforms.RandomHorizontalFlip(p=horizontal_flip_prob) if augmentation_level > 0 else None, 180 | transforms.ColorJitter(brightness=color_jitter_brightness, contrast=color_jitter_contrast, saturation=color_jitter_saturation) if
181 |                 transforms.RandomRotation(degrees=rotation_degrees) if augmentation_level > 0 else None,
182 |                 transforms.RandomAffine(degrees=affine_transform_degrees, translate=(affine_transform_translate, affine_transform_translate),
183 |                                         scale=(affine_transform_scale_min, affine_transform_scale_max)) if augmentation_level > 0 else None,
184 |                 transforms.RandomPerspective(distortion_scale=perspective_distortion_scale, p=0.5) if augmentation_level > 0 else None,
185 |                 transforms.GaussianBlur(kernel_size=(5, 9), sigma=gaussian_blur_sigma) if augmentation_level > 0 else None,
186 |                 transforms.ToDtype(torch.float32, scale=True),
187 |                 transforms.RandomErasing(p=random_erasing_prob, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0) if augmentation_level > 0 else None,
188 |             ]
189 |         elif self.config.dataset_preprocess_to_RAM:
190 |             transforms_list = [
191 |                 transforms.ToImage(),
192 |                 transforms.RandomHorizontalFlip(p=horizontal_flip_prob) if augmentation_level > 0 else None,
193 |                 transforms.ColorJitter(brightness=color_jitter_brightness, contrast=color_jitter_contrast, saturation=color_jitter_saturation) if augmentation_level > 0 else None,
194 |                 transforms.RandomRotation(degrees=rotation_degrees) if augmentation_level > 0 else None,
195 |                 transforms.RandomAffine(degrees=affine_transform_degrees, translate=(affine_transform_translate, affine_transform_translate),
196 |                                         scale=(affine_transform_scale_min, affine_transform_scale_max)) if augmentation_level > 0 else None,
197 |                 transforms.RandomPerspective(distortion_scale=perspective_distortion_scale, p=0.5) if augmentation_level > 0 else None,
198 |                 transforms.GaussianBlur(kernel_size=(5, 9), sigma=gaussian_blur_sigma) if augmentation_level > 0 else None,
199 |                 transforms.ToDtype(torch.float32, scale=True),
200 |                 transforms.RandomErasing(p=random_erasing_prob, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0) if augmentation_level > 0 else None,
201 |                 transforms.Normalize(mean=self.config.dataset_normalization_mean, std=self.config.dataset_normalization_std),
202 |             ]
203 |         else:
204 |             transforms_list = [
205 |                 transforms.ToImage(),
206 |                 transforms.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
207 |                 transforms.RandomHorizontalFlip(p=horizontal_flip_prob) if augmentation_level > 0 else None,
208 |                 transforms.ColorJitter(brightness=color_jitter_brightness, contrast=color_jitter_contrast, saturation=color_jitter_saturation) if augmentation_level > 0 else None,
209 |                 transforms.RandomRotation(degrees=rotation_degrees) if augmentation_level > 0 else None,
210 |                 transforms.RandomAffine(degrees=affine_transform_degrees, translate=(affine_transform_translate, affine_transform_translate),
211 |                                         scale=(affine_transform_scale_min, affine_transform_scale_max)) if augmentation_level > 0 else None,
212 |                 transforms.RandomPerspective(distortion_scale=perspective_distortion_scale, p=0.5) if augmentation_level > 0 else None,
213 |                 transforms.GaussianBlur(kernel_size=(5, 9), sigma=gaussian_blur_sigma) if augmentation_level > 0 else None,
214 |                 transforms.ToDtype(torch.float32, scale=True),
215 |                 transforms.RandomErasing(p=random_erasing_prob, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0) if augmentation_level > 0 else None,
216 |                 transforms.Normalize(mean=self.config.dataset_normalization_mean, std=self.config.dataset_normalization_std),
217 |             ]
218 | 
219 |         # Filter out None transforms (i.e., when augmentation_level is 0)
220 |         transforms_list = [t for t in transforms_list if t is not None]
221 | 
222 |         return transforms.Compose(transforms_list)
223 | 
224 |     def valid_transforms(self):
225 |         transforms_list = []
226 |         if self.config.dataset_normalization_mean is None:
227 |             transforms_list = [
228 |                 transforms.ToImage(),
229 |                 transforms.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
230 |                 transforms.ToDtype(torch.float32, scale=True),
231 |             ]
232 |         elif self.config.dataset_preprocess_to_RAM:
233 |             transforms_list = [
234 |                 transforms.ToImage(),
235 |                 transforms.ToDtype(torch.float32, scale=True),
236 |                 transforms.Normalize(mean=self.config.dataset_normalization_mean, std=self.config.dataset_normalization_std),
237 |             ]
238 |         else:
239 |             transforms_list = [
240 |                 transforms.ToImage(),
241 |                 transforms.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
242 |                 transforms.ToDtype(torch.float32, scale=True),
243 |                 transforms.Normalize(mean=self.config.dataset_normalization_mean, std=self.config.dataset_normalization_std),
244 |             ]
245 |         return transforms.Compose(transforms_list)
246 | 
247 |     def test_transforms(self):
248 |         transforms_list = []
249 |         if self.config.dataset_normalization_mean is None:
250 |             transforms_list = [
251 |                 transforms.ToImage(),
252 |                 transforms.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
253 |                 transforms.ToDtype(torch.float32, scale=True),
254 |             ]
255 |         elif self.config.dataset_preprocess_to_RAM:
256 |             transforms_list = [
257 |                 transforms.ToImage(),
258 |                 transforms.ToDtype(torch.float32, scale=True),
259 |                 transforms.Normalize(mean=self.config.dataset_normalization_mean, std=self.config.dataset_normalization_std),
260 |             ]
261 |         else:
262 |             transforms_list = [
263 |                 transforms.ToImage(),
264 |                 transforms.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
265 |                 transforms.ToDtype(torch.float32, scale=True),
266 |                 transforms.Normalize(mean=self.config.dataset_normalization_mean, std=self.config.dataset_normalization_std),
267 |             ]
268 | 
269 |         return transforms.Compose(transforms_list)
270 | 
271 |     def __len__(self):
272 |         """
273 |         Returns the number of items in the dataset.
274 |         """
275 |         return len(self.image_names)
276 | 
277 |     def __getitem__(self, index):
278 |         """
279 |         Retrieves an image and its labels at the given index, applying appropriate transforms.
280 | """ 281 | if self.config.dataset_preprocess_to_RAM: 282 | return { 283 | 'image': self.transform((self.data[index])['image']), 284 | 'label': (self.data[index])['label'], 285 | 'image_path': (self.data[index])['image_path'] 286 | } 287 | image_path = self.image_names[index] 288 | image = Image.open(image_path).convert('RGB') 289 | if image is None: 290 | logger.warning(f"Warning: Image not found or corrupted at path: {image_path}") 291 | return None 292 | # apply image transforms 293 | image = self.transform(image) 294 | targets = self.labels[index] 295 | 296 | return { 297 | 'image': image, 298 | 'label': torch.tensor(targets, dtype=torch.float32), 299 | 'image_path': image_path 300 | } 301 | 302 | def stable_hash(self, x): 303 | # Use a large prime number to take the modulus of the hash 304 | large_prime = 2**61 - 1 305 | return int(hashlib.sha256(x.encode('utf-8')).hexdigest(), 16) % large_prime 306 | 307 | def stable_split(self, data, train_percent, valid_percent, test_percent, random_state=None): 308 | # Ensure that the sum of the sizes is <= 1 309 | if train_percent + valid_percent + test_percent > 100: 310 | raise ValueError("The sum of train, valid, and test sizes should be <= 100.") 311 | 312 | if random_state is not None: 313 | np.random.seed(random_state) # Set random seed for reproducibility 314 | 315 | # Assign a unique number to each element based on a hash of its identifier 316 | hashed_ids = data['identifier'].apply(lambda x: self.stable_hash(video_frame_group(x))) 317 | 318 | # Calculate the split thresholds 319 | train_threshold = np.percentile(hashed_ids, train_percent) 320 | valid_threshold = np.percentile(hashed_ids, (train_percent + valid_percent)) 321 | 322 | # Determine the subset for each element based on its hashed ID 323 | train_mask = hashed_ids < train_threshold 324 | valid_mask = (hashed_ids >= train_threshold) & (hashed_ids < valid_threshold) 325 | test_mask = hashed_ids >= valid_threshold 326 | 327 | train_data = data[train_mask] 328 | valid_data = data[valid_mask] 329 | test_data = data[test_mask] 330 | 331 | return train_data, valid_data, test_data 332 | 333 | def is_video_frame(identifier): 334 | return 'video' in identifier and 'frame' in identifier and 'studio' in identifier 335 | 336 | def video_frame_group(identifier): 337 | if is_video_frame(identifier): 338 | splits = identifier.split('-', maxsplit=1) 339 | return splits[0] + '-' + splits[1] # Returns 'studio_-video_' 340 | return identifier 341 | 342 | -------------------------------------------------------------------------------- /src/imclaslib/training/modeltrainer.py: -------------------------------------------------------------------------------- 1 | import imclaslib.models.modelfactory as modelfactory 2 | import torch.nn as nn 3 | from tqdm import tqdm 4 | import torch.optim as optim 5 | import torch 6 | import gc 7 | import os 8 | import imclaslib.files.pathutils as pathutils 9 | from imclaslib.logging.loggerfactory import LoggerFactory 10 | import imclaslib.models.modelutils as modelutils 11 | import imclaslib.files.modelloadingutils as modelloadingutils 12 | from torch.cuda.amp import autocast, GradScaler 13 | import copy 14 | import random 15 | import torch.nn.functional as F 16 | from imclaslib.metrics import metricutils 17 | from timm.loss import AsymmetricLossMultiLabel 18 | from imclaslib.models.multilabel_focal_loss import MultiLabelFocalLoss 19 | from imclaslib.models.multilabel_dice_loss import DiceLoss 20 | logger = LoggerFactory.get_logger(f"logger.{__name__}") 21 | 22 | 
22 | class ModelTrainer():
23 |     def __init__(self, device, trainloader, validloader, testloader, config, wandbWriter=None):
24 |         """
25 |         Initializes the ModelTrainer with the given datasets, device, and configuration.
26 | 
27 |         Parameters:
28 |         device (torch.device): The device on which to train the model.
29 |         trainloader (DataLoader): DataLoader for the training dataset.
30 |         validloader (DataLoader): DataLoader for the validation dataset.
31 |         testloader (DataLoader): DataLoader for the test dataset.
32 |         config (Config): Configuration object with necessary attributes.
33 |         """
34 |         self.config = config
35 |         self.wandbWriter = wandbWriter
36 |         self.metrics_enabled = (wandbWriter is not None)
37 |         self.device = device
38 |         self.trainloader = trainloader
39 |         self.validloader = validloader
40 |         self.testloader = testloader
41 |         self.model = modelfactory.create_model(self.config).to(device)
42 |         if self.config.train_l2_enabled:
43 |             self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.train_learning_rate, weight_decay=self.config.train_l2_lambda)
44 |         else:
45 |             self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.train_learning_rate)
46 | 
47 |         # Compute label frequencies and create weights for the loss function
48 |         #self.label_freqs = self.compute_label_frequencies()
49 |         #self.pos_weight = self.compute_loss_weights(self.label_freqs).to(device)
50 |         self.criterion = nn.BCEWithLogitsLoss() #MultiLabelFocalLoss()
51 |         self.epochs = self.config.train_num_epochs
52 |         self.lr_scheduler = modelutils.get_learningRate_scheduler(self.optimizer, config)
53 |         self.last_train_loss = 10000
54 |         self.last_valid_loss = 10000
55 |         self.last_valid_f1 = 0
56 |         self.current_lr = self.config.train_learning_rate
57 | 
58 |         modelToLoadPath = pathutils.get_model_to_load_path(self.config)
59 |         if self.config.train_continue_training and os.path.exists(modelToLoadPath):
60 |             logger.info("Loading the best model...")
61 |             if (self.config.model_embedding_layer_enabled or self.config.model_gcn_enabled) and self.config.train_model_to_load_raw_weights != "":
62 |                 self.model, modelData = modelloadingutils.load_pretrained_weights_exclude_classifier(self.model, self.config, True)
63 |             else:
64 |                 modelData = modelloadingutils.load_model(modelToLoadPath, self.config)
65 |                 self.model.load_state_dict(modelData['model_state_dict'])
66 |             logger.info("Loaded the best model.")
67 |             #self.optimizer.load_state_dict(modelData['optimizer_state_dict'])
68 | 
69 |             self.start_epoch = modelData["epoch"] + 1
70 |             self.epochs = self.epochs + self.start_epoch
71 |             self.best_f1_score = 0.0
72 |             self.__set_best_model_state(modelData["epoch"])
73 | 
74 | 
75 |         else:
76 |             self.best_f1_score = 0.0
77 |             self.start_epoch = 0
78 |             self.best_model_state = None
79 |             self.__set_best_model_state(self.start_epoch)
80 |         self.current_epoch = self.start_epoch - 1
81 |         self.best_f1_score_at_last_reset = 0
82 |         self.patience_counter = 0
83 |         self.patience = self.config.train_early_stopping_patience
84 | 
85 |         if config.using_wsl and config.train_compile:
86 |             self.compile()
87 |         if self.metrics_enabled:
88 |             self.wandbWriter.watch(self.model)
89 | 
90 |     def __enter__(self):
91 |         """
92 |         Enter the runtime context for the ModelTrainer object.
93 |         Allows the ModelTrainer to be used with the 'with' statement, ensuring resources are managed properly.
94 | 
95 |         Returns:
96 |         ModelTrainer: The instance with which the context was entered.
97 | """ 98 | return self 99 | 100 | def __exit__(self, exc_type, exc_value, traceback): 101 | """ 102 | Exit the runtime context for the ModelTrainer object. 103 | This method is called after the 'with' block is executed, and it ensures that the TensorBoard writer is closed. 104 | 105 | Parameters: 106 | exc_type: Exception type, if any exception was raised within the 'with' block. 107 | exc_value: Exception value, the exception instance raised. 108 | traceback: Traceback object with details of where the exception occurred. 109 | """ 110 | del self.model 111 | del self.optimizer 112 | torch.cuda.empty_cache() 113 | gc.collect() 114 | 115 | def smooth_labels(self, labels): 116 | """ 117 | Applies label smoothing. Turning the vector of 0s and 1s into a vector of 118 | `smoothing / num_classes` and `1 - smoothing + (smoothing / num_classes)`. 119 | Args: 120 | labels: The binary labels (0 or 1). 121 | smoothing: The degree of smoothing (0 means no smoothing). 122 | Returns: 123 | The smoothed labels. 124 | """ 125 | smoothing = self.config.train_label_smoothing 126 | with torch.no_grad(): 127 | num_classes = labels.size(1) 128 | # Create a tensor of `smoothing / num_classes` for each label 129 | smooth_value = smoothing / num_classes 130 | # Subtract smoothing from the 1s, add it to the 0s 131 | labels = labels * (1 - smoothing) + (1 - labels) * smooth_value 132 | return labels 133 | 134 | def compile(self): 135 | self.model = torch.compile(self.model) 136 | 137 | def train(self): 138 | """ 139 | Train the model for one epoch using the provided training dataset. 140 | :return: The average training loss for the epoch. 141 | """ 142 | self.current_epoch += 1 143 | logger.info('Training') 144 | self.model.train() 145 | train_running_loss = 0.0 146 | 147 | # Initialize the gradient scaler for mixed precision 148 | scaler = GradScaler(enabled=self.config.model_fp16) 149 | 150 | for data in tqdm(self.trainloader, total=len(self.trainloader)): 151 | images, targets = data['image'].to(self.device), data['label'].to(self.device).float() 152 | self.optimizer.zero_grad() 153 | 154 | # Cast operations to mixed precision 155 | with autocast(enabled=self.config.model_fp16): 156 | if (self.config.model_embedding_layer_enabled or self.config.model_gcn_enabled): 157 | label_dropout_rate = 0.9 158 | use_labels = random.random() > label_dropout_rate 159 | if use_labels: 160 | outputs = self.model(images, targets) 161 | else: 162 | outputs = self.model(images) 163 | else: 164 | outputs = self.model(images) 165 | 166 | # Verify that outputs and targets have the same shape 167 | if outputs.shape != targets.shape: 168 | logger.error(f"Mismatched shapes detected: Outputs shape: {outputs.shape}, Targets shape: {targets.shape}") 169 | # Here you could also raise an exception or handle the error in some way 170 | loss = self.criterion(outputs, self.smooth_labels(targets)) 171 | 172 | # Scale the loss and call backward() to create scaled gradients 173 | scaler.scale(loss).backward() 174 | 175 | # Step optimizer and update the scale for next iteration 176 | scaler.step(self.optimizer) 177 | scaler.update() 178 | 179 | train_running_loss += loss.item() 180 | 181 | train_loss = train_running_loss / len(self.trainloader.dataset) 182 | self.last_train_loss = train_loss 183 | return train_loss 184 | 185 | def distillation_criterion(self, student_outputs, teacher_outputs, student_targets, alpha=0.5, temperature=1.5): 186 | """ 187 | Calculate the distillation loss for multilabel classification. 
188 |         :param student_outputs: Logits from the student model.
189 |         :param teacher_outputs: Logits from the teacher model.
190 |         :param student_targets: Ground truth labels.
191 |         :param alpha: Weight for combining the soft and hard loss components.
192 |         :param temperature: Temperature scaling factor for softening probabilities.
193 |         :return: The combined distillation loss.
194 |         """
195 |         # Apply the sigmoid function to get the probabilities since we are dealing with multilabel classification
196 |         teacher_probs = torch.sigmoid(teacher_outputs / temperature)
197 | 
198 |         # Calculate the binary cross-entropy loss between the soft targets and the student outputs
199 |         soft_loss = F.binary_cross_entropy_with_logits(student_outputs / temperature, teacher_probs, reduction='mean')
200 | 
201 |         # Calculate the binary cross-entropy loss between the true labels and the student outputs
202 |         hard_loss = F.binary_cross_entropy_with_logits(student_outputs, student_targets, reduction='mean')
203 | 
204 |         # Combine the soft and hard losses
205 |         loss = alpha * soft_loss * (temperature ** 2) + (1 - alpha) * hard_loss
206 |         return loss
207 |     def distill(self, teacher_model, teacher_trainloader):
208 |         """
209 |         Distill knowledge from the teacher model to the student model for one epoch using the provided training dataset.
210 |         :param teacher_model: The pre-trained teacher model from which knowledge will be transferred.
211 |         :return: The average training loss for the epoch.
212 |         """
213 |         self.current_epoch += 1
214 |         logger.info('Distillation')
215 |         self.model.train()
216 |         teacher_model.eval()
217 |         train_running_loss = 0.0
218 | 
219 |         # Initialize the gradient scaler for mixed precision
220 |         scaler = GradScaler(enabled=self.config.model_fp16)
221 | 
222 |         # Ensure the student's trainloader and the teacher's trainloader have the same length
223 |         assert len(self.trainloader) == len(teacher_trainloader), "The student and teacher trainloaders must have the same number of batches."
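        # Note on the loss scaling used below: with the defaults alpha=0.5 and
        # temperature=1.5 in distillation_criterion, the combined objective is
        #     loss = 0.5 * soft_loss * 1.5**2 + 0.5 * hard_loss
        #          = 1.125 * soft_loss + 0.5 * hard_loss
        # The T**2 factor is the standard knowledge-distillation correction for
        # dividing both sets of logits by T before the soft loss is computed.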
224 | 
225 |         for (student_data, teacher_data) in tqdm(zip(self.trainloader, teacher_trainloader), total=len(self.trainloader)):
226 |             student_images, student_targets = student_data['image'].to(self.device), student_data['label'].to(self.device).float()
227 |             teacher_images = teacher_data['image'].to(self.device)
228 |             self.optimizer.zero_grad()
229 | 
230 |             # Forward pass of the teacher model to obtain soft labels
231 |             with torch.no_grad():
232 |                 teacher_outputs = teacher_model(teacher_images)
233 | 
234 |             # Cast operations to mixed precision
235 |             with autocast(enabled=self.config.model_fp16):
236 |                 student_outputs = self.model(student_images)
237 | 
238 |                 # Verify that outputs and targets have the same shape
239 |                 if student_outputs.shape != teacher_outputs.shape:
240 |                     logger.error(f"Mismatched shapes detected: Student outputs shape: {student_outputs.shape}, Teacher outputs shape: {teacher_outputs.shape}")
241 | 
242 |                 # Calculate the distillation loss using the teacher's outputs as soft targets
243 |                 # For distillation, you might want to use a different criterion or adjust `self.criterion` to handle teacher-student loss
244 |                 # This might involve a combination of soft targets from the teacher and true labels, depending on your approach
245 |                 if self.config.model_temperature is not None:
246 |                     teacher_outputs = metricutils.temperature_scale(teacher_outputs, self.config.model_temperature)
247 |                 loss = self.distillation_criterion(student_outputs, teacher_outputs, student_targets)
248 | 
249 |             # Scale the loss and call backward() to create scaled gradients
250 |             scaler.scale(loss).backward()
251 | 
252 |             # Step optimizer and update the scale for next iteration
253 |             scaler.step(self.optimizer)
254 |             scaler.update()
255 | 
256 |             train_running_loss += loss.item()
257 | 
258 |         train_loss = train_running_loss / len(self.trainloader)  # average per-batch loss; the criterion already averages within each batch
259 |         self.last_train_loss = train_loss
260 |         logger.info(f"Epoch {self.current_epoch} - Distillation Loss: {train_loss:.4f}")
261 |         return train_loss
262 | 
263 |     def validate(self, modelEvaluator, threshold=None):
264 |         """
265 |         Validate the model on the validation dataset using a model evaluator.
266 | 
267 |         Parameters:
268 |         modelEvaluator: An instance of the model evaluator class with an 'evaluate' method.
269 |         threshold (Optional[float]): Threshold value for converting probabilities to class labels.
270 | 
271 |         Returns:
272 |         tuple: A tuple containing the average validation loss and the micro-averaged F1 score.
273 |         """
274 |         logger.info("Validating")
275 |         valid_loss, valid_f1, _, _ = modelEvaluator.evaluate(self.validloader, self.current_epoch, "Validation", threshold=threshold)
276 |         self.last_valid_loss = valid_loss
277 |         self.last_valid_f1 = valid_f1
278 |         self.log_train_validation_results()
279 |         return valid_loss, valid_f1
280 | 
281 |     def learningRateScheduler_check(self):
282 |         """
283 |         Check and update the learning rate based on the validation F1 score. Log the updated learning rate to Weights & Biases.
284 |         """
285 |         self.lr_scheduler.step(self.last_valid_f1)
286 |         oldlr = self.current_lr
287 |         self.current_lr = self.optimizer.param_groups[0]['lr']
288 |         if self.metrics_enabled:
289 |             self.wandbWriter.log({"Train/Learning_Rate": self.current_lr}, step=self.current_epoch)
290 |         if oldlr != self.current_lr:
291 |             logger.info(f"Reducing learning rate from {oldlr} to {self.current_lr}")
292 | 
293 |     def log_train_validation_results(self):
294 |         """
295 |         Log training and validation results to the logger and Weights & Biases.
296 |         Includes the train loss, validation loss, and validation F1 score for the current epoch.
297 |         """
298 |         logger.info(f"Train Loss: {self.last_train_loss:.4f}")
299 |         logger.info(f'Validation Loss: {self.last_valid_loss:.4f}')
300 |         logger.info(f'Validation F1 Score: {self.last_valid_f1:.4f}')
301 | 
302 |         if self.metrics_enabled:
303 |             self.wandbWriter.log({"Loss/Train": self.last_train_loss, "Loss/Validation": self.last_valid_loss, "F1/Validation": self.last_valid_f1}, step=self.current_epoch)
304 | 
305 |     def check_early_stopping(self):
306 |         """
307 |         Check if early stopping criteria are met based on the validation F1 score.
308 |         If the score has not improved by a certain proportion over the patience window,
309 |         trigger early stopping.
310 | 
311 |         Returns:
312 |         bool: True if early stopping is triggered, False otherwise.
313 |         """
314 |         improvement_threshold = self.config.train_early_stopping_threshold
315 |         significant_improvement = False
316 |         if self.last_valid_f1 > self.best_f1_score:
317 |             logger.info(f"Validation F1 Score improved from {self.best_f1_score:.4f} to {self.last_valid_f1:.4f}")
318 |             self.best_f1_score = self.last_valid_f1
319 |             self.__set_best_model_state(self.current_epoch)
320 | 
321 |             modelloadingutils.save_best_model(self.best_model_state, self.config)
322 | 
323 |         # Check for significant improvement since the last reset of the patience counter
324 |         if self.last_valid_f1 - self.best_f1_score_at_last_reset >= improvement_threshold:
325 |             logger.info(f"Significant cumulative improvement of {self.last_valid_f1 - self.best_f1_score_at_last_reset:.4f} has been achieved since the last reset.")
326 |             significant_improvement = True
327 |             self.best_f1_score_at_last_reset = self.last_valid_f1
328 |             self.patience_counter = 0
329 | 
330 |         # Increment patience counter if no significant improvement
331 |         if not significant_improvement:
332 |             self.patience_counter += 1
333 | 
334 |         # If there hasn't been significant improvement over the patience window, trigger early stopping
335 |         if self.patience_counter >= self.patience:
336 |             logger.info(f"Early stopping triggered after {self.patience} epochs without significant cumulative improvement.")
337 |             return True
338 |         return False
339 | 
340 |     def save_final_model(self):
341 |         """
342 |         Save the state of the model that achieved the best validation F1 score during training.
343 |         The model state is saved to a file defined by the configuration.
344 |         """
345 |         state_to_save = copy.deepcopy(self.best_model_state)
346 |         modelloadingutils.save_final_model(self.best_model_state, self.best_f1_score, self.config)
347 |         self.model.load_state_dict(state_to_save['model_state_dict'])
348 | 
349 |     def compute_label_frequencies(self):
350 |         """
351 |         Computes the frequency of each label in the dataset.
352 | 
353 |         Returns:
354 |         label_freqs (torch.Tensor): Tensor containing the frequency of each label.
355 |         """
356 |         # Initialize a tensor to hold the frequency of each label.
357 |         # This assumes that the number of labels is known and stored in `self.config.model_num_classes`.
358 |         label_freqs = torch.zeros(self.config.model_num_classes, dtype=torch.float)
359 | 
360 |         # Iterate over the dataset and sum the one-hot encoded labels.
361 |         for batch in tqdm(self.trainloader, total=len(self.trainloader)):
362 |             labels = batch["label"]
363 |             label_freqs += labels.sum(dim=0)  # Sum along the batch dimension.
364 | 
365 |         # Ensure that there's at least one count for each label to avoid division by zero.
366 |         label_freqs = label_freqs.clamp(min=1)
367 |         return label_freqs
368 | 
369 |     def compute_loss_weights(self, label_freqs):
370 |         """
371 |         Computes the weights for each label to be used in the loss function.
372 | 
373 |         Parameters:
374 |         label_freqs (torch.Tensor): Tensor containing the frequency of each label.
375 | 
376 |         Returns:
377 |         weights (torch.Tensor): Tensor containing the weight for each label.
378 |         """
379 |         # Compute the inverse frequency weights
380 |         total_counts = label_freqs.sum()
381 |         weights = total_counts / label_freqs
382 | 
383 |         # Normalize weights to prevent them from scaling the loss too much
384 |         weights = weights / weights.mean()
385 | 
386 |         #weights = weights.view(-1) # Ensure it is a 1D tensor with shape [num_classes]
387 |         #assert weights.shape[0] == self.config.num_classes, "pos_weight must have the same size as num_classes"
388 | 
389 |         return weights
390 | 
391 |     def __set_best_model_state(self, epoch):
392 |         self.best_model_state = {
393 |             'epoch': epoch,
394 |             'model_state_dict': self.model.state_dict(),
395 |             'optimizer_state_dict': self.optimizer.state_dict(),
396 |             'loss': self.criterion,
397 |             'f1_score': self.best_f1_score,
398 |             'model_name': self.config.model_name,
399 |             'requires_grad': self.config.train_requires_grad,
400 |             'model_num_classes': self.config.model_num_classes,
401 |             'dropout': self.config.train_dropout_prob,
402 |             'embedding_layer': self.config.model_embedding_layer_enabled,
403 |             'model_gcn_enabled': self.config.model_gcn_enabled,
404 |             'train_batch_size': self.config.train_batch_size,
405 |             'optimizer': 'Adam',
406 |             'loss_function': 'BCEWithLogitsLoss',
407 |             'image_size': self.config.model_image_size,
408 |             'model_gcn_model_name': self.config.model_gcn_model_name,
409 |             'model_gcn_out_channels': self.config.model_gcn_out_channels,
410 |             'model_gcn_layers': self.config.model_gcn_layers,
411 |             'model_attention_layer_num_heads': self.config.model_attention_layer_num_heads,
412 |             'model_embedding_layer_dimension': self.config.model_embedding_layer_dimension,
413 |             'train_loss': self.last_train_loss,
414 |             'dataset_version': self.config.dataset_version
415 |         }
--------------------------------------------------------------------------------
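A minimal, self-contained sketch of the hash-based stable split implemented by the dataset class above (stable_hash, stable_split, video_frame_group). The toy DataFrame, its 'identifier' and 'label_a' columns, and the studio-video-frame file names are illustrative stand-ins for the real CSV schema; only the grouping, hashing, and percentile-threshold logic mirrors the source.

import hashlib

import numpy as np
import pandas as pd

def stable_hash(x):
    # SHA-256 of the identifier, reduced modulo a large Mersenne prime,
    # matching the dataset class's stable_hash.
    return int(hashlib.sha256(x.encode('utf-8')).hexdigest(), 16) % (2**61 - 1)

def video_frame_group(identifier):
    # Frames named 'studioX-videoY-frameZZZ...' collapse to 'studioX-videoY',
    # so every frame of a video hashes to the same value and lands in one split.
    if 'video' in identifier and 'frame' in identifier and 'studio' in identifier:
        parts = identifier.split('-', maxsplit=2)
        return parts[0] + '-' + parts[1]
    return identifier

def stable_split(data, train_percent, valid_percent):
    # Percentile thresholds over the hashed group keys partition the rows.
    hashed = data['identifier'].apply(lambda x: stable_hash(video_frame_group(x)))
    train_t = np.percentile(hashed, train_percent)
    valid_t = np.percentile(hashed, train_percent + valid_percent)
    return (data[hashed < train_t],
            data[(hashed >= train_t) & (hashed < valid_t)],
            data[hashed >= valid_t])

if __name__ == '__main__':
    df = pd.DataFrame({
        'identifier': ['studio1-video1-frame001.jpg', 'studio1-video1-frame002.jpg',
                       'cat.jpg', 'dog.jpg', 'bird.jpg'],
        'label_a': [1, 1, 0, 1, 0],
    })
    train, valid, test = stable_split(df, 60, 20)
    # The two video frames always land in the same subset because they share a
    # group key; re-running the script reproduces the identical split.
    print(len(train), len(valid), len(test))
    print(train['identifier'].tolist())

Because the split is keyed on a content hash rather than on a random permutation, appending new rows to the CSV never reshuffles existing rows between the train, valid, and test subsets.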