├── nauta_pipnet_cpvr.png ├── used_arguments ├── CARS_arguments.txt ├── PETS_arguments.txt └── CUB_arguments.txt ├── util ├── func.py ├── log.py ├── preprocess_cub.py ├── visualize_prediction.py ├── args.py ├── vis_pipnet.py ├── eval_cub_csv.py └── data.py ├── features ├── convnext_features.py └── resnet_features.py ├── pipnet ├── pipnet.py ├── train.py └── test.py ├── README.md └── main.py /nauta_pipnet_cpvr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/M-Nauta/PIPNet/HEAD/nauta_pipnet_cpvr.png -------------------------------------------------------------------------------- /used_arguments/CARS_arguments.txt: -------------------------------------------------------------------------------- 1 | dataset: 'CARS' 2 | validation_size: 0.0 3 | net: 'convnext_tiny_26' 4 | batch_size: 64 5 | batch_size_pretrain: 128 6 | epochs: 60 7 | optimizer: 'Adam' 8 | lr: 0.05 9 | lr_block: 0.0005 10 | lr_net: 0.0005 11 | weight_decay: 0.0 12 | disable_cuda: False 13 | log_dir: './runs/pipnet_cars_cnext26' 14 | num_features: 0 15 | image_size: 224 16 | state_dict_dir_net: '' 17 | freeze_epochs: 10 18 | dir_for_saving_images: 'Visualization_results' 19 | disable_pretrained: False 20 | epochs_pretrain: 10 21 | weighted_loss: False 22 | seed: 1 23 | gpu_ids: '' 24 | num_workers: 8 25 | bias: False 26 | -------------------------------------------------------------------------------- /used_arguments/PETS_arguments.txt: -------------------------------------------------------------------------------- 1 | dataset: 'pets' 2 | validation_size: 0.0 3 | net: 'convnext_tiny_26' 4 | batch_size: 64 5 | batch_size_pretrain: 128 6 | epochs: 60 7 | optimizer: 'Adam' 8 | lr: 0.05 9 | lr_block: 0.0001 10 | lr_net: 0.0001 11 | weight_decay: 0.0 12 | disable_cuda: False 13 | log_dir: './runs/pipnet_pets_cnext26' 14 | num_features: 0 15 | image_size: 224 16 | state_dict_dir_net: '' 17 | freeze_epochs: 10 18 | dir_for_saving_images: 'Visualization_results' 19 | disable_pretrained: False 20 | epochs_pretrain: 10 21 | weighted_loss: False 22 | seed: 1 23 | gpu_ids: '' 24 | num_workers: 8 25 | bias: False 26 | -------------------------------------------------------------------------------- /used_arguments/CUB_arguments.txt: -------------------------------------------------------------------------------- 1 | dataset: 'CUB-200-2011' 2 | validation_size: 0.0 3 | net: 'convnext_tiny_26' 4 | batch_size: 64 5 | batch_size_pretrain: 128 6 | epochs: 60 7 | optimizer: 'Adam' 8 | lr: 0.05 9 | lr_block: 0.0005 10 | lr_net: 0.0005 11 | weight_decay: 0.0 12 | disable_cuda: False 13 | log_dir: './runs/pipnet_cub_cnext26' 14 | num_features: 0 15 | image_size: 224 16 | state_dict_dir_net: '' 17 | freeze_epochs: 10 18 | dir_for_saving_images: 'Visualization_results' 19 | disable_pretrained: False 20 | epochs_pretrain: 10 21 | weighted_loss: False 22 | seed: 1 23 | gpu_ids: '' 24 | num_workers: 8 25 | bias: False 26 | -------------------------------------------------------------------------------- /util/func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_patch_size(args): 4 | patchsize = 32 5 | skip = round((args.image_size - patchsize) / (args.wshape-1)) 6 | return patchsize, skip 7 | 8 | def init_weights_xavier(m): 9 | if type(m) == torch.nn.Conv2d: 10 | torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain('sigmoid')) 11 | 12 | # 
https://gist.github.com/weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b?permalink_comment_id=3662215#gistcomment-3662215 13 | def topk_accuracy(output, target, topk=[1,]): 14 | """ 15 | Computes the accuracy over the k top predictions for the specified values of k 16 | """ 17 | with torch.no_grad(): 18 | topk2 = [x for x in topk if x <= output.shape[1]] #ensures that k is not larger than number of classes 19 | maxk = max(topk2) 20 | 21 | _, pred = output.topk(maxk, 1, True, True) 22 | pred = pred.t() 23 | correct = (pred == target.unsqueeze(dim=0)).expand_as(pred) 24 | 25 | res = [] 26 | for k in topk: 27 | if k in topk2: 28 | correct_k = correct[:k].reshape(-1).float() 29 | res.append(correct_k) 30 | else: 31 | res.append(torch.zeros_like(target)) 32 | return res -------------------------------------------------------------------------------- /features/convnext_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | def replace_convlayers_convnext(model, threshold): 6 | for n, module in model.named_children(): 7 | if len(list(module.children())) > 0: 8 | replace_convlayers_convnext(module, threshold) 9 | if isinstance(module, nn.Conv2d): 10 | if module.stride[0] == 2: 11 | if module.in_channels > threshold: #replace bigger strides to reduce receptive field, skip some 2x2 layers. >100 gives output size (26, 26). >300 gives (13, 13) 12 | module.stride = tuple(s//2 for s in module.stride) 13 | 14 | return model 15 | 16 | def convnext_tiny_26_features(pretrained=False, **kwargs): 17 | model = models.convnext_tiny(pretrained=pretrained, weights=models.ConvNeXt_Tiny_Weights.DEFAULT) 18 | with torch.no_grad(): 19 | model.avgpool = nn.Identity() 20 | model.classifier = nn.Identity() 21 | model = replace_convlayers_convnext(model, 100) 22 | 23 | return model 24 | 25 | def convnext_tiny_13_features(pretrained=False, **kwargs): 26 | model = models.convnext_tiny(pretrained=pretrained, weights=models.ConvNeXt_Tiny_Weights.DEFAULT) 27 | with torch.no_grad(): 28 | model.avgpool = nn.Identity() 29 | model.classifier = nn.Identity() 30 | model = replace_convlayers_convnext(model, 300) 31 | 32 | return model 33 | -------------------------------------------------------------------------------- /util/log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from util.args import save_args 5 | 6 | class Log: 7 | 8 | """ 9 | Object for managing the log directory 10 | """ 11 | 12 | def __init__(self, log_dir: str): # Store log in log_dir 13 | 14 | self._log_dir = log_dir 15 | self._logs = dict() 16 | 17 | # Ensure the directories exist 18 | if not os.path.isdir(self.log_dir): 19 | os.mkdir(self.log_dir) 20 | if not os.path.isdir(self.metadata_dir): 21 | os.mkdir(self.metadata_dir) 22 | if not os.path.isdir(self.checkpoint_dir): 23 | os.mkdir(self.checkpoint_dir) 24 | 25 | 26 | @property 27 | def log_dir(self): 28 | return self._log_dir 29 | 30 | @property 31 | def checkpoint_dir(self): 32 | return self._log_dir + '/checkpoints' 33 | 34 | @property 35 | def metadata_dir(self): 36 | return self._log_dir + '/metadata' 37 | 38 | def log_message(self, msg: str): 39 | """ 40 | Write a message to the log file 41 | :param msg: the message string to be written to the log file 42 | """ 43 | if not os.path.isfile(self.log_dir + '/log.txt'): 44 | open(self.log_dir + '/log.txt', 'w').close() #make log file empty if it already exists 
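        # note: the line above only creates the log file when it does not exist yet;
        # existing log content is kept and messages are appended below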
45 | with open(self.log_dir + '/log.txt', 'a') as f: 46 | f.write(msg+"\n") 47 | 48 | def create_log(self, log_name: str, key_name: str, *value_names): 49 | """ 50 | Create a csv for logging information 51 | :param log_name: The name of the log. The log filename will be .csv. 52 | :param key_name: The name of the attribute that is used as key (e.g. epoch number) 53 | :param value_names: The names of the attributes that are logged 54 | """ 55 | if log_name in self._logs.keys(): 56 | raise Exception('Log already exists!') 57 | # Add to existing logs 58 | self._logs[log_name] = (key_name, value_names) 59 | # Create log file. Create columns 60 | with open(self.log_dir + f'/{log_name}.csv', 'w') as f: 61 | f.write(','.join((key_name,) + value_names) + '\n') 62 | 63 | def log_values(self, log_name, key, *values): 64 | """ 65 | Log values in an existent log file 66 | :param log_name: The name of the log file 67 | :param key: The key attribute for logging these values 68 | :param values: value attributes that will be stored in the log 69 | """ 70 | if log_name not in self._logs.keys(): 71 | raise Exception('Log not existent!') 72 | if len(values) != len(self._logs[log_name][1]): 73 | raise Exception('Not all required values are logged!') 74 | # Write a new line with the given values 75 | with open(self.log_dir + f'/{log_name}.csv', 'a') as f: 76 | f.write(','.join(str(v) for v in (key,) + values) + '\n') 77 | 78 | def log_args(self, args: argparse.Namespace): 79 | save_args(args, self._log_dir) 80 | 81 | -------------------------------------------------------------------------------- /util/preprocess_cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import time 5 | from PIL import Image 6 | 7 | path = './data/CUB_200_2011/' 8 | 9 | time_start = time.time() 10 | 11 | path_images = os.path.join(path,'images.txt') 12 | path_split = os.path.join(path,'train_test_split.txt') 13 | train_save_path = os.path.join(path,'dataset/train_crop/') 14 | test_save_path = os.path.join(path,'dataset/test_crop/') 15 | bbox_path = os.path.join(path, 'bounding_boxes.txt') 16 | 17 | images = [] 18 | with open(path_images,'r') as f: 19 | for line in f: 20 | images.append(list(line.strip('\n').split(','))) 21 | print("Images: ", images) 22 | split = [] 23 | with open(path_split, 'r') as f_: 24 | for line in f_: 25 | split.append(list(line.strip('\n').split(','))) 26 | 27 | bboxes = dict() 28 | with open(bbox_path, 'r') as bf: 29 | for line in bf: 30 | id, x, y, w, h = tuple(map(float, line.split(' '))) 31 | bboxes[int(id)]=(x, y, w, h) 32 | 33 | num = len(images) 34 | for k in range(num): 35 | id, fn = images[k][0].split(' ') 36 | id = int(id) 37 | file_name = fn.split('/')[0] 38 | if int(split[k][0][-1]) == 1: 39 | 40 | if not os.path.isdir(train_save_path + file_name): 41 | os.makedirs(os.path.join(train_save_path, file_name)) 42 | img = Image.open(os.path.join(os.path.join(path, 'images'),images[k][0].split(' ')[1])).convert('RGB') 43 | x, y, w, h = bboxes[id] 44 | cropped_img = img.crop((x, y, x+w, y+h)) 45 | cropped_img.save(os.path.join(os.path.join(train_save_path,file_name),images[k][0].split(' ')[1].split('/')[1])) 46 | print('%s' % images[k][0].split(' ')[1].split('/')[1]) 47 | else: 48 | if not os.path.isdir(test_save_path + file_name): 49 | os.makedirs(os.path.join(test_save_path,file_name)) 50 | img = Image.open(os.path.join(os.path.join(path, 'images'),images[k][0].split(' ')[1])).convert('RGB') 51 | x, y, w, h 
= bboxes[id] 52 | cropped_img = img.crop((x, y, x+w, y+h)) 53 | cropped_img.save(os.path.join(os.path.join(test_save_path,file_name),images[k][0].split(' ')[1].split('/')[1])) 54 | print('%s' % images[k][0].split(' ')[1].split('/')[1]) 55 | 56 | 57 | train_save_path = os.path.join(path,'dataset/train/') 58 | test_save_path = os.path.join(path,'dataset/test_full/') 59 | 60 | num = len(images) 61 | for k in range(num): 62 | id, fn = images[k][0].split(' ') 63 | id = int(id) 64 | file_name = fn.split('/')[0] 65 | if int(split[k][0][-1]) == 1: 66 | 67 | if not os.path.isdir(train_save_path + file_name): 68 | os.makedirs(os.path.join(train_save_path, file_name)) 69 | img = Image.open(os.path.join(os.path.join(path, 'images'),images[k][0].split(' ')[1])).convert('RGB') 70 | width, height = img.size 71 | 72 | img.save(os.path.join(os.path.join(train_save_path,file_name),images[k][0].split(' ')[1].split('/')[1])) 73 | 74 | print('%s' % images[k][0].split(' ')[1].split('/')[1]) 75 | else: 76 | if not os.path.isdir(test_save_path + file_name): 77 | os.makedirs(os.path.join(test_save_path,file_name)) 78 | shutil.copy(path + 'images/' + images[k][0].split(' ')[1], os.path.join(os.path.join(test_save_path,file_name),images[k][0].split(' ')[1].split('/')[1])) 79 | print('%s' % images[k][0].split(' ')[1].split('/')[1]) 80 | time_end = time.time() 81 | print('CUB200, %s!' % (time_end - time_start)) 82 | -------------------------------------------------------------------------------- /pipnet/pipnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from features.resnet_features import resnet18_features, resnet34_features, resnet50_features, resnet50_features_inat, resnet101_features, resnet152_features 6 | from features.convnext_features import convnext_tiny_26_features, convnext_tiny_13_features 7 | import torch 8 | from torch import Tensor 9 | 10 | class PIPNet(nn.Module): 11 | def __init__(self, 12 | num_classes: int, 13 | num_prototypes: int, 14 | feature_net: nn.Module, 15 | args: argparse.Namespace, 16 | add_on_layers: nn.Module, 17 | pool_layer: nn.Module, 18 | classification_layer: nn.Module 19 | ): 20 | super().__init__() 21 | assert num_classes > 0 22 | self._num_features = args.num_features 23 | self._num_classes = num_classes 24 | self._num_prototypes = num_prototypes 25 | self._net = feature_net 26 | self._add_on = add_on_layers 27 | self._pool = pool_layer 28 | self._classification = classification_layer 29 | self._multiplier = classification_layer.normalization_multiplier 30 | 31 | def forward(self, xs, inference=False): 32 | features = self._net(xs) 33 | proto_features = self._add_on(features) 34 | pooled = self._pool(proto_features) 35 | if inference: 36 | clamped_pooled = torch.where(pooled < 0.1, 0., pooled) #during inference, ignore all prototypes that have 0.1 similarity or lower 37 | out = self._classification(clamped_pooled) #shape (bs*2, num_classes) 38 | return proto_features, clamped_pooled, out 39 | else: 40 | out = self._classification(pooled) #shape (bs*2, num_classes) 41 | return proto_features, pooled, out 42 | 43 | 44 | base_architecture_to_features = {'resnet18': resnet18_features, 45 | 'resnet34': resnet34_features, 46 | 'resnet50': resnet50_features, 47 | 'resnet50_inat': resnet50_features_inat, 48 | 'resnet101': resnet101_features, 49 | 'resnet152': resnet152_features, 50 | 'convnext_tiny_26': convnext_tiny_26_features, 51 | 
'convnext_tiny_13': convnext_tiny_13_features} 52 | 53 | # adapted from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear 54 | class NonNegLinear(nn.Module): 55 | """Applies a linear transformation to the incoming data with non-negative weights` 56 | """ 57 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 58 | device=None, dtype=None) -> None: 59 | factory_kwargs = {'device': device, 'dtype': dtype} 60 | super(NonNegLinear, self).__init__() 61 | self.in_features = in_features 62 | self.out_features = out_features 63 | self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) 64 | self.normalization_multiplier = nn.Parameter(torch.ones((1,),requires_grad=True)) 65 | if bias: 66 | self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) 67 | else: 68 | self.register_parameter('bias', None) 69 | 70 | def forward(self, input: Tensor) -> Tensor: 71 | return F.linear(input,torch.relu(self.weight), self.bias) 72 | 73 | 74 | def get_network(num_classes: int, args: argparse.Namespace): 75 | features = base_architecture_to_features[args.net](pretrained=not args.disable_pretrained) 76 | features_name = str(features).upper() 77 | if 'next' in args.net: 78 | features_name = str(args.net).upper() 79 | if features_name.startswith('RES') or features_name.startswith('CONVNEXT'): 80 | first_add_on_layer_in_channels = \ 81 | [i for i in features.modules() if isinstance(i, nn.Conv2d)][-1].out_channels 82 | else: 83 | raise Exception('other base architecture NOT implemented') 84 | 85 | 86 | if args.num_features == 0: 87 | num_prototypes = first_add_on_layer_in_channels 88 | print("Number of prototypes: ", num_prototypes, flush=True) 89 | add_on_layers = nn.Sequential( 90 | nn.Softmax(dim=1), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1 91 | ) 92 | else: 93 | num_prototypes = args.num_features 94 | print("Number of prototypes set from", first_add_on_layer_in_channels, "to", num_prototypes,". Extra 1x1 conv layer added. 
Not recommended.", flush=True) 95 | add_on_layers = nn.Sequential( 96 | nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=num_prototypes, kernel_size=1, stride = 1, padding=0, bias=True), 97 | nn.Softmax(dim=1), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1 98 | ) 99 | pool_layer = nn.Sequential( 100 | nn.AdaptiveMaxPool2d(output_size=(1,1)), #outputs (bs, ps,1,1) 101 | nn.Flatten() #outputs (bs, ps) 102 | ) 103 | 104 | if args.bias: 105 | classification_layer = NonNegLinear(num_prototypes, num_classes, bias=True) 106 | else: 107 | classification_layer = NonNegLinear(num_prototypes, num_classes, bias=False) 108 | 109 | return features, add_on_layers, pool_layer, classification_layer, num_prototypes 110 | 111 | 112 | -------------------------------------------------------------------------------- /pipnet/train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.optim 5 | import torch.utils.data 6 | import math 7 | 8 | def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch'): 9 | 10 | # Make sure the model is in train mode 11 | net.train() 12 | 13 | if pretrain: 14 | # Disable training of classification layer 15 | net.module._classification.requires_grad = False 16 | progress_prefix = 'Pretrain Epoch' 17 | else: 18 | # Enable training of classification layer (disabled in case of pretraining) 19 | net.module._classification.requires_grad = True 20 | 21 | # Store info about the procedure 22 | train_info = dict() 23 | total_loss = 0. 24 | total_acc = 0. 25 | 26 | iters = len(train_loader) 27 | # Show progress on progress bar. 28 | train_iter = tqdm(enumerate(train_loader), 29 | total=len(train_loader), 30 | desc=progress_prefix+'%s'%epoch, 31 | mininterval=2., 32 | ncols=0) 33 | 34 | count_param=0 35 | for name, param in net.named_parameters(): 36 | if param.requires_grad: 37 | count_param+=1 38 | print("Number of parameters that require gradient: ", count_param, flush=True) 39 | 40 | if pretrain: 41 | align_pf_weight = (epoch/nr_epochs)*1. 42 | unif_weight = 0.5 #ignored 43 | t_weight = 5. 44 | cl_weight = 0. 45 | else: 46 | align_pf_weight = 5. 47 | t_weight = 2. 48 | unif_weight = 0. 49 | cl_weight = 2. 
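    # Loss-weight schedule used by calculate_loss below: during pretraining the
    # alignment weight ramps linearly from ~0 to 1 over the pretraining epochs
    # (tanh weight 5) while the classification loss is disabled; in the second
    # training stage fixed weights are used (align 5, tanh 2, classification 2).
    # unif_weight is only relevant if the optional uniformity loss is enabled.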
50 | 51 | 52 | print("Align weight: ", align_pf_weight, ", U_tanh weight: ", t_weight, "Class weight:", cl_weight, flush=True) 53 | print("Pretrain?", pretrain, "Finetune?", finetune, flush=True) 54 | 55 | lrs_net = [] 56 | lrs_class = [] 57 | # Iterate through the data set to update leaves, prototypes and network 58 | for i, (xs1, xs2, ys) in train_iter: 59 | 60 | xs1, xs2, ys = xs1.to(device), xs2.to(device), ys.to(device) 61 | 62 | # Reset the gradients 63 | optimizer_classifier.zero_grad(set_to_none=True) 64 | optimizer_net.zero_grad(set_to_none=True) 65 | 66 | # Perform a forward pass through the network 67 | proto_features, pooled, out = net(torch.cat([xs1, xs2])) 68 | loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-8) 69 | 70 | # Compute the gradient 71 | loss.backward() 72 | 73 | if not pretrain: 74 | optimizer_classifier.step() 75 | scheduler_classifier.step(epoch - 1 + (i/iters)) 76 | lrs_class.append(scheduler_classifier.get_last_lr()[0]) 77 | 78 | if not finetune: 79 | optimizer_net.step() 80 | scheduler_net.step() 81 | lrs_net.append(scheduler_net.get_last_lr()[0]) 82 | else: 83 | lrs_net.append(0.) 84 | 85 | with torch.no_grad(): 86 | total_acc+=acc 87 | total_loss+=loss.item() 88 | 89 | if not pretrain: 90 | with torch.no_grad(): 91 | net.module._classification.weight.copy_(torch.clamp(net.module._classification.weight.data - 1e-3, min=0.)) #set weights in classification layer < 1e-3 to zero 92 | net.module._classification.normalization_multiplier.copy_(torch.clamp(net.module._classification.normalization_multiplier.data, min=1.0)) 93 | if net.module._classification.bias is not None: 94 | net.module._classification.bias.copy_(torch.clamp(net.module._classification.bias.data, min=0.)) 95 | train_info['train_accuracy'] = total_acc/float(i+1) 96 | train_info['loss'] = total_loss/float(i+1) 97 | train_info['lrs_net'] = lrs_net 98 | train_info['lrs_class'] = lrs_class 99 | 100 | return train_info 101 | 102 | def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-10): 103 | ys = torch.cat([ys1,ys1]) 104 | pooled1, pooled2 = pooled.chunk(2) 105 | pf1, pf2 = proto_features.chunk(2) 106 | 107 | embv2 = pf2.flatten(start_dim=2).permute(0,2,1).flatten(end_dim=1) 108 | embv1 = pf1.flatten(start_dim=2).permute(0,2,1).flatten(end_dim=1) 109 | 110 | a_loss_pf = (align_loss(embv1, embv2.detach())+ align_loss(embv2, embv1.detach()))/2. 111 | tanh_loss = -(torch.log(torch.tanh(torch.sum(pooled1,dim=0))+EPS).mean() + torch.log(torch.tanh(torch.sum(pooled2,dim=0))+EPS).mean())/2. 112 | 113 | if not finetune: 114 | loss = align_pf_weight*a_loss_pf 115 | loss += t_weight * tanh_loss 116 | 117 | if not pretrain: 118 | softmax_inputs = torch.log1p(out**net_normalization_multiplier) 119 | class_loss = criterion(F.log_softmax((softmax_inputs),dim=1),ys) 120 | 121 | if finetune: 122 | loss= cl_weight * class_loss 123 | else: 124 | loss+= cl_weight * class_loss 125 | # Our tanh-loss optimizes for uniformity and was sufficient for our experiments. 
However, if pretraining of the prototypes is not working well for your dataset, you may try to add another uniformity loss from https://www.tongzhouwang.info/hypersphere/ Just uncomment the following three lines 126 | # else: 127 | # uni_loss = (uniform_loss(F.normalize(pooled1+EPS,dim=1)) + uniform_loss(F.normalize(pooled2+EPS,dim=1)))/2. 128 | # loss += unif_weight * uni_loss 129 | 130 | acc=0. 131 | if not pretrain: 132 | ys_pred_max = torch.argmax(out, dim=1) 133 | correct = torch.sum(torch.eq(ys_pred_max, ys)) 134 | acc = correct.item() / float(len(ys)) 135 | if print: 136 | with torch.no_grad(): 137 | if pretrain: 138 | train_iter.set_postfix_str( 139 | f'L: {loss.item():.3f}, LA:{a_loss_pf.item():.2f}, LT:{tanh_loss.item():.3f}, num_scores>0.1:{torch.count_nonzero(torch.relu(pooled-0.1),dim=1).float().mean().item():.1f}',refresh=False) 140 | else: 141 | if finetune: 142 | train_iter.set_postfix_str( 143 | f'L:{loss.item():.3f},LC:{class_loss.item():.3f}, LA:{a_loss_pf.item():.2f}, LT:{tanh_loss.item():.3f}, num_scores>0.1:{torch.count_nonzero(torch.relu(pooled-0.1),dim=1).float().mean().item():.1f}, Ac:{acc:.3f}',refresh=False) 144 | else: 145 | train_iter.set_postfix_str( 146 | f'L:{loss.item():.3f},LC:{class_loss.item():.3f}, LA:{a_loss_pf.item():.2f}, LT:{tanh_loss.item():.3f}, num_scores>0.1:{torch.count_nonzero(torch.relu(pooled-0.1),dim=1).float().mean().item():.1f}, Ac:{acc:.3f}',refresh=False) 147 | return loss, acc 148 | 149 | # Extra uniform loss from https://www.tongzhouwang.info/hypersphere/. Currently not used but you could try adding it if you want. 150 | def uniform_loss(x, t=2): 151 | # print("sum elements: ", torch.sum(torch.pow(x,2), dim=1).shape, torch.sum(torch.pow(x,2), dim=1)) #--> should be ones 152 | loss = (torch.pdist(x, p=2).pow(2).mul(-t).exp().mean() + 1e-10).log() 153 | return loss 154 | 155 | # from https://gitlab.com/mipl/carl/-/blob/main/losses.py 156 | def align_loss(inputs, targets, EPS=1e-12): 157 | assert inputs.shape == targets.shape 158 | assert targets.requires_grad == False 159 | 160 | loss = torch.einsum("nc,nc->n", [inputs, targets]) 161 | loss = -torch.log(loss + EPS).mean() 162 | return loss -------------------------------------------------------------------------------- /util/visualize_prediction.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | import argparse 3 | from PIL import Image, ImageDraw as D 4 | import torchvision 5 | from util.func import get_patch_size 6 | from torchvision import transforms 7 | import torch 8 | from util.vis_pipnet import get_img_coordinates 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | try: 13 | import cv2 14 | use_opencv = True 15 | except ImportError: 16 | use_opencv = False 17 | print("Heatmaps showing where a prototype is found will not be generated because OpenCV is not installed.", flush=True) 18 | 19 | def vis_pred(net, vis_test_dir, classes, device, args: argparse.Namespace): 20 | # Make sure the model is in evaluation mode 21 | net.eval() 22 | 23 | save_dir = os.path.join(args.log_dir, args.dir_for_saving_images) 24 | if os.path.exists(save_dir): 25 | shutil.rmtree(save_dir) 26 | 27 | patchsize, skip = get_patch_size(args) 28 | 29 | num_workers = args.num_workers 30 | 31 | mean = (0.485, 0.456, 0.406) 32 | std = (0.229, 0.224, 0.225) 33 | normalize = transforms.Normalize(mean=mean,std=std) 34 | transform_no_augment = transforms.Compose([ 35 | transforms.Resize(size=(args.image_size, args.image_size)), 36 | 
transforms.ToTensor(), 37 | normalize]) 38 | 39 | vis_test_set = torchvision.datasets.ImageFolder(vis_test_dir, transform=transform_no_augment) 40 | vis_test_loader = torch.utils.data.DataLoader(vis_test_set, batch_size = 1, 41 | shuffle=False, pin_memory=not args.disable_cuda and torch.cuda.is_available(), 42 | num_workers=num_workers) 43 | imgs = vis_test_set.imgs 44 | 45 | last_y = -1 46 | for k, (xs, ys) in enumerate(vis_test_loader): #shuffle is false so should lead to same order as in imgs 47 | if ys[0] != last_y: 48 | last_y = ys[0] 49 | count_per_y = 0 50 | else: 51 | count_per_y +=1 52 | if count_per_y>5: #show max 5 imgs per class to speed up the process 53 | continue 54 | xs, ys = xs.to(device), ys.to(device) 55 | img = imgs[k][0] 56 | img_name = os.path.splitext(os.path.basename(img))[0] 57 | dir = os.path.join(save_dir,img_name) 58 | if not os.path.exists(dir): 59 | os.makedirs(dir) 60 | shutil.copy(img, dir) 61 | 62 | with torch.no_grad(): 63 | softmaxes, pooled, out = net(xs, inference=True) #softmaxes has shape (bs, num_prototypes, W, H), pooled has shape (bs, num_prototypes), out has shape (bs, num_classes) 64 | sorted_out, sorted_out_indices = torch.sort(out.squeeze(0), descending=True) 65 | for pred_class_idx in sorted_out_indices[:3]: 66 | pred_class = classes[pred_class_idx] 67 | save_path = os.path.join(dir, pred_class+"_"+str(f"{out[0,pred_class_idx].item():.3f}")) 68 | if not os.path.exists(save_path): 69 | os.makedirs(save_path) 70 | sorted_pooled, sorted_pooled_indices = torch.sort(pooled.squeeze(0), descending=True) 71 | simweights = [] 72 | for prototype_idx in sorted_pooled_indices: 73 | simweight = pooled[0,prototype_idx].item() * net.module._classification.weight[pred_class_idx, prototype_idx].item() 74 | simweights.append(simweight) 75 | if abs(simweight) > 0.01: 76 | max_h, max_idx_h = torch.max(softmaxes[0, prototype_idx, :, :], dim=0) 77 | max_w, max_idx_w = torch.max(max_h, dim=0) 78 | max_idx_h = max_idx_h[max_idx_w].item() 79 | max_idx_w = max_idx_w.item() 80 | image = transforms.Resize(size=(args.image_size, args.image_size))(Image.open(img)) 81 | img_tensor = transforms.ToTensor()(image).unsqueeze_(0) #shape (1, 3, h, w) 82 | h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, max_idx_h, max_idx_w) 83 | img_tensor_patch = img_tensor[0, :, h_coor_min:h_coor_max, w_coor_min:w_coor_max] 84 | img_patch = transforms.ToPILImage()(img_tensor_patch) 85 | img_patch.save(os.path.join(save_path, 'mul%s_p%s_sim%s_w%s_patch.png'%(str(f"{simweight:.3f}"),str(prototype_idx.item()),str(f"{pooled[0,prototype_idx].item():.3f}"),str(f"{net.module._classification.weight[pred_class_idx, prototype_idx].item():.3f}")))) 86 | draw = D.Draw(image) 87 | draw.rectangle([(max_idx_w*skip,max_idx_h*skip), (min(args.image_size, max_idx_w*skip+patchsize), min(args.image_size, max_idx_h*skip+patchsize))], outline='yellow', width=2) 88 | image.save(os.path.join(save_path, 'mul%s_p%s_sim%s_w%s_rect.png'%(str(f"{simweight:.3f}"),str(prototype_idx.item()),str(f"{pooled[0,prototype_idx].item():.3f}"),str(f"{net.module._classification.weight[pred_class_idx, prototype_idx].item():.3f}")))) 89 | 90 | # visualise softmaxes as heatmap 91 | if use_opencv: 92 | softmaxes_resized = transforms.ToPILImage()(softmaxes[0, prototype_idx, :, :]) 93 | softmaxes_resized = softmaxes_resized.resize((args.image_size, args.image_size),Image.BICUBIC) 94 | softmaxes_np = (transforms.ToTensor()(softmaxes_resized)).squeeze().numpy() 95 | 96 | 
heatmap = cv2.applyColorMap(np.uint8(255*softmaxes_np), cv2.COLORMAP_JET) 97 | heatmap = np.float32(heatmap)/255 98 | heatmap = heatmap[...,::-1] # OpenCV's BGR to RGB 99 | heatmap_img = 0.2 * np.float32(heatmap) + 0.6 * np.float32(img_tensor.squeeze().numpy().transpose(1,2,0)) 100 | plt.imsave(fname=os.path.join(save_path, 'heatmap_p%s.png'%str(prototype_idx.item())),arr=heatmap_img,vmin=0.0,vmax=1.0) 101 | 102 | def vis_pred_experiments(net, imgs_dir, classes, device, args: argparse.Namespace): 103 | # Make sure the model is in evaluation mode 104 | net.eval() 105 | 106 | save_dir = os.path.join(os.path.join(args.log_dir, args.dir_for_saving_images),"Experiments") 107 | if os.path.exists(save_dir): 108 | shutil.rmtree(save_dir) 109 | 110 | patchsize, skip = get_patch_size(args) 111 | 112 | num_workers = args.num_workers 113 | 114 | mean = (0.485, 0.456, 0.406) 115 | std = (0.229, 0.224, 0.225) 116 | normalize = transforms.Normalize(mean=mean,std=std) 117 | transform_no_augment = transforms.Compose([ 118 | transforms.Resize(size=(args.image_size, args.image_size)), 119 | transforms.ToTensor(), 120 | normalize]) 121 | 122 | vis_test_set = torchvision.datasets.ImageFolder(imgs_dir, transform=transform_no_augment) 123 | vis_test_loader = torch.utils.data.DataLoader(vis_test_set, batch_size = 1, 124 | shuffle=False, pin_memory=not args.disable_cuda and torch.cuda.is_available(), 125 | num_workers=num_workers) 126 | imgs = vis_test_set.imgs 127 | for k, (xs, ys) in enumerate(vis_test_loader): #shuffle is false so should lead to same order as in imgs 128 | 129 | xs, ys = xs.to(device), ys.to(device) 130 | img = imgs[k][0] 131 | img_name = os.path.splitext(os.path.basename(img))[0] 132 | dir = os.path.join(save_dir,img_name) 133 | if not os.path.exists(dir): 134 | os.makedirs(dir) 135 | shutil.copy(img, dir) 136 | 137 | with torch.no_grad(): 138 | softmaxes, pooled, out = net(xs, inference=True) #softmaxes has shape (bs, num_prototypes, W, H), pooled has shape (bs, num_prototypes), out has shape (bs, num_classes) 139 | sorted_out, sorted_out_indices = torch.sort(out.squeeze(0), descending=True) 140 | 141 | for pred_class_idx in sorted_out_indices: 142 | pred_class = classes[pred_class_idx] 143 | save_path = os.path.join(dir, str(f"{out[0,pred_class_idx].item():.3f}")+"_"+pred_class) 144 | if not os.path.exists(save_path): 145 | os.makedirs(save_path) 146 | 147 | sorted_pooled, sorted_pooled_indices = torch.sort(pooled.squeeze(0), descending=True) 148 | 149 | simweights = [] 150 | for prototype_idx in sorted_pooled_indices: 151 | simweight = pooled[0,prototype_idx].item() * net.module._classification.weight[pred_class_idx, prototype_idx].item() 152 | 153 | simweights.append(simweight) 154 | if abs(simweight) > 0.01: 155 | max_h, max_idx_h = torch.max(softmaxes[0, prototype_idx, :, :], dim=0) 156 | max_w, max_idx_w = torch.max(max_h, dim=0) 157 | max_idx_h = max_idx_h[max_idx_w].item() 158 | max_idx_w = max_idx_w.item() 159 | 160 | image = transforms.Resize(size=(args.image_size, args.image_size))(Image.open(img).convert("RGB")) 161 | img_tensor = transforms.ToTensor()(image).unsqueeze_(0) #shape (1, 3, h, w) 162 | h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, max_idx_h, max_idx_w) 163 | img_tensor_patch = img_tensor[0, :, h_coor_min:h_coor_max, w_coor_min:w_coor_max] 164 | img_patch = transforms.ToPILImage()(img_tensor_patch) 165 | img_patch.save(os.path.join(save_path, 
'mul%s_p%s_sim%s_w%s_patch.png'%(str(f"{simweight:.3f}"),str(prototype_idx.item()),str(f"{pooled[0,prototype_idx].item():.3f}"),str(f"{net.module._classification.weight[pred_class_idx, prototype_idx].item():.3f}")))) 166 | draw = D.Draw(image) 167 | draw.rectangle([(max_idx_w*skip,max_idx_h*skip), (min(args.image_size, max_idx_w*skip+patchsize), min(args.image_size, max_idx_h*skip+patchsize))], outline='yellow', width=2) 168 | image.save(os.path.join(save_path, 'mul%s_p%s_sim%s_w%s_rect.png'%(str(f"{simweight:.3f}"),str(prototype_idx.item()),str(f"{pooled[0,prototype_idx].item():.3f}"),str(f"{net.module._classification.weight[pred_class_idx, prototype_idx].item():.3f}")))) 169 | 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PIP-Net: Patch-Based Intuitive Prototypes for Interpretable Image Classification 2 | This repository presents the PyTorch code for PIP-Net (Patch-based Intuitive Prototypes Network). 3 | 4 | **Main Paper at CVPR**: ["PIP-Net: Patch-Based Intuitive Prototypes for Interpretable Image Classification"](https://openaccess.thecvf.com/content/CVPR2023/papers/Nauta_PIP-Net_Patch-Based_Intuitive_Prototypes_for_Interpretable_Image_Classification_CVPR_2023_paper.pdf) introduces PIP-Net for natural images.\ 5 | **Medical applications, data quality inspection and manual corrections**: [Interpreting and Correcting Medical Image Classification with PIP-Net](https://link.springer.com/chapter/10.1007/978-3-031-50396-2_11), applies PIP-Net to X-rays and skin lesion images where biases can be fixed by (manually) disabling prototypes. \ 6 | **Evaluation of part-prototype models like PIP-Net**: [The Co-12 Recipe for Evaluating Interpretable Part-Prototype Image Classifiers](https://arxiv.org/abs/2307.14517), presented at the [XAI World Conference](https://xaiworldconference.com/) in July 2023. 7 | 8 | 9 | PIP-Net is an interpretable and intuitive deep learning method for image classification. PIP-Net learns prototypical parts: interpretable concepts visualized as image patches. PIP-Net classifies an image with a sparse scoring sheet where the presence of a prototypical part in an image adds evidence for a class. PIP-Net is globally interpretable since the set of learned prototypes shows the entire reasoning of the model. A smaller local explanation locates the relevant prototypes in a test image. The model can also abstain from a decision for out-of-distribution data by saying “I haven’t seen this before”. The model only uses image-level labels and does not rely on any part annotations. 10 | 11 | ![Overview of PIP-Net](https://github.com/M-Nauta/PIPNet/blob/main/nauta_pipnet_cpvr.png) 12 | 13 | ### Required Python Packages: 14 | * [PyTorch](https://pytorch.org/get-started/locally/) (incl torchvision, tested with PyTorch 1.13) 15 | * [tqdm](https://tqdm.github.io/) 16 | * scikit-learn 17 | * openCV (optional, used to generate heatmaps) 18 | * pandas 19 | * matplotlib 20 | 21 | ### Training PIP-Net 22 | PIP-Net can be trained by running `main.py` with arguments. Run `main.py --help` to see all the argument options. Recommended parameters per dataset are present in the `used_arguments.txt` file (usually corresponds to the default options). 23 | 24 | #### Training PIP-Net on your own data 25 | Want to train PIP-Net on another dataset? 
Add your dataset in ``util/data.py`` by creating a function ``get_yourdata`` with the desired data augmentation (that captures human perception of similarity), add it to the existing ``get_data`` function in ``util/data.py`` and give your dataset a name. Use ``--dataset your_dataset_name`` as argument to run PIP-Net on your dataset. 26 | 27 | Other relevant arguments are for example ``--weighted_loss`` which is useful when your data is imbalanced. In case of a 2-class task with presence/absence reasoning, you could consider using ``--bias`` to include a traininable bias term in the linear classification layer (which could decrease the OoD abilities) such that PIP-Net does not necessarily need to find evidence for the absence-class. 28 | 29 | Check your `--log_dir` to keep track of the training progress. This directory contains `log_epoch_overview.csv` which prints statistics per epoch. File `tqdm.txt` prints updates per iteration and potential errors. File `out.txt` includes all print statements such as additional info. See the **Interpreting the Results** section for further details. 30 | 31 | Visualizations of prototypes are included in your `--log_dir` / `--dir_for_saving_images`. 32 | 33 | #### Trained checkpoints 34 | Various trained versions of PIP-Net are made available: 35 | 36 | - PIP-Net with the ConvNext backbone (recommended) trained on the birds CUB-200-2011 dataset is available [for download here](https://drive.google.com/file/d/1G8iiXgZ5gENYicwS8nLIg2Gf43A49kKm/view) (320MB). Download the CUB dataset (see instructions in this README) and run the following command to generate the prototypes and evaluate the model: 37 | ``python3 main.py --dataset CUB-200-2011 --epochs_pretrain 0 --batch_size 64 --freeze_epochs 10 --epochs 0 --log_dir ./runs/pipnet_cub --state_dict_dir_net ./pipnet_cub_trained``. Update the path of ``--state_dict_dir_net`` to the checkpoint if needed. 38 | - PIP-Net with the ResNet50 backbone trained on the birds CUB-200-2011 dataset is available [for download here](https://drive.google.com/file/d/1zI1bcEXDsp8eN20msSiySo6UHD9y_bgw/view) (280MB). Use ``--net resnet50``. 39 | - PIP-Net with the ConvNext backbone (recommended) trained on the CARS dataset is available [for download here](https://drive.google.com/file/d/1JQNbhzw6s7yJsd_3--hCAReGkbT9PRlP/view) (320MB). Use ``--dataset CARS``. 40 | - PIP-Net with the ResNet50 backbone trained on the CARS dataset is available [for download here](https://drive.google.com/file/d/15t_nIjqR6m-dRFljqi-ntyv-gr4TIF7m/view) (280MB). 41 | 42 | ### Data 43 | The code can be applied to any imaging classification data set, structured according to the [Imagefolder format](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder): 44 | 45 | >root/class1/xxx.png
root/class1/xxy.png
root/class2/xyy.png
root/class2/yyy.png 46 | 47 | Add or update the paths to your dataset in ``util/data.py``. 48 | 49 | For preparing [CUB-200-2011]([http://www.vision.caltech.edu/visipedia/CUB-200-2011.html](https://www.vision.caltech.edu/datasets/cub_200_2011/)) with 200 bird species, use `util/preprocess_cub.py`. For [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) with 196 car types, use the [Instructions of ProtoTree](https://github.com/M-Nauta/ProtoTree/blob/main/README.md#preprocessing-cub). 50 | 51 | ### Interpreting the Results 52 | During training, various files will be created in your ``--log_dir``: 53 | 54 | - **``log_epoch_overview.csv``** keeps track of the training progress per epoch. It contains accuracies, the number of prototypes, loss values etc. In case of a 2-class task, the third value is F1-score, otherwise this is top5-accuracy. 55 | - **``out.txt``** collects the standard output from print statements. Its most relevant content is: 56 | - More performance metrics are printed, such as sparsity ratio. In case of a 2-class task, it also shows the sensitivity, specificity, confusion matrix, etc. 57 | - At the end of the file, after training, the relevant prototypes per class are printed. E.g., ``Class 0 has 5 relevant prototypes: [(prototype_id, class weight), ...]''. This information thus shows the learned scoring sheet of PIP-Net. 58 | - **``tqdm.txt``** contains the progress via progress bar package [tqdm](https://tqdm.github.io/). Useful to see how long one epoch will take, and how the losses evolve. Errors are also printed here. 59 | - **``metadata``** folder logs the provided arguments. 60 | - **``checkpoints``** folder contains state_dicts of the saved models. 61 | - **Prototype visualisations** After training, various folders are created to visualise the learned reasoning of PIP-Net. 62 | - ``visualised_pretrained_prototypes_topk`` visualises the top-10 most similar image patches per prototype after the pretraining phase. Each row in ``grid_topk_all`` corresponds to one prototype. The number corresponds with the index of the prototype node, starting at 0. 63 | - ``visualised_prototypes_topk`` visualises the top-10 most similar image patches after the full (first and second stage) training. Prototypes that are not relevant to any class (all weights are zero) are excluded. 64 | - ``visualised_prototypes`` is a more extensive visualisation of the prototypes learned after training PIP-Net. The ``grid_xxx.png`` images show all image patches that are similar to prototype with index ``xxx``. The number of image patches (or the size of the png file) already gives an indication how often this prototype is found in the training set. If you want to know where these image patches come from (to see some more context), you can open the corresponding folder ``prototype_xxx``. Each image contains a yellow square indicating where prototype ``xxx`` was found, coresponding with an image patch in ``grid_xxx.png``. The file name is ``pxxx_imageid_similarityscore_imagename_rect.png``. 65 | - ``visualization_results`` (or other ``--dir_for_saving_images``) contains predictions including local explanations for test images. A subfolder corresponding to a test image contains the test image itself, and folders with predicted classes: ``classname_outputscore``. In such a class folder, it is visualised where which prototypes are detected: ``muliplicationofsimilarityandweight_prototypeindex_similarityscore_classweight_rect_or_patch.png``. 
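
As an illustration, the scoring sheet printed at the end of ``out.txt`` (the relevant prototypes per class with their class weights) can also be reconstructed directly from the non-negative classification layer of a loaded PIP-Net. The snippet below is a minimal sketch and not part of this repository: it assumes a ``net`` wrapped in ``torch.nn.DataParallel`` (as in the training code) and a list ``classes`` of class names; ``print_scoring_sheet`` and ``threshold`` are hypothetical names, and the 1e-3 threshold mirrors the clamping applied to the classification weights during training.

```
import torch

@torch.no_grad()
def print_scoring_sheet(net, classes, threshold=1e-3):
    # weight matrix of the classification layer: shape (num_classes, num_prototypes),
    # kept non-negative during training; weights below ~1e-3 are clamped to zero
    weights = net.module._classification.weight.detach().cpu()
    for class_idx, class_name in enumerate(classes):
        relevant = [(proto_idx, round(w.item(), 3))
                    for proto_idx, w in enumerate(weights[class_idx])
                    if w.item() > threshold]
        relevant.sort(key=lambda t: t[1], reverse=True)
        print(f"Class {class_idx} ({class_name}): {len(relevant)} relevant prototypes", relevant)
```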
66 | 67 | ### Hyperparameter FAQ 68 | * **What is the best number of epochs for my dataset?** 69 | The right number of epochs (`--epochs` and `--epochs_pretrain`) will depend on the data set size and difficulty of the classification task. Hence, tuning the parameters might require some trial-and-error. You can start with the default values. For datasets of different sizes, we recommend to set the number of epochs such that the number of iterations (i.e., weight updates) during the second training state is around 10,000 (rule of thumb). Hence, epochs = 10000 / (num_images_in_trainingset / batch_size). The number of iterations for one epoch is easily found in ``tqdm.txt``. Similarly, the number of pretraining epochs `--epochs_pretrain` can be set such that there are 2000 weight updates. 70 | 71 | * **I have CUDA memory issues, what can I do?** PIP-Net is designed to fit onto one GPU. If your GPU has less CUDA memory, you have the following options: 1) reduce your batch size `--batch_size` or `--batch_size_pretrain`. Set it as large as possible to still fit in CUDA memory. 2) freeze more layers of the CNN backbone. Rather than optimizing the whole CNN backbone from `--freeze_epochs` onwards, you could keep the first layers frozen during the whole training process. Adapt the code around line 200 in `util/args.py` as indicated in the comments there. Alternatively, set `--freeze_epochs` equal to `--epochs`. 3) Use ``--net convnext_tiny_13`` instead of the default ``convnext_tiny_26`` to make training faster and more efficient. The potential downside is that the latent output grid is less fine-grained and could therefore impact prototype localization, but the impact will depend on your data and classification task. 72 | 73 | ### Reference and Citation 74 | Please refer to our work when using or discussing PIP-Net: 75 | 76 | ``` 77 | Meike Nauta, Jörg Schlötterer, Maurice van Keulen, Christin Seifert (2023). “PIP-Net: Patch-Based Intuitive Prototypes for Interpretable Image Classification.” IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). 78 | ``` 79 | 80 | BibTex citation: 81 | ``` 82 | @article{nauta2023pipnet, 83 | title={PIP-Net: Patch-Based Intuitive Prototypes for Interpretable Image Classification}, 84 | author={Nauta, Meike and Schlötterer, Jörg and van Keulen, Maurice and Seifert, Christin}, 85 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 86 | year={2023}, 87 | } 88 | ``` 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /util/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | import numpy as np 5 | import random 6 | import torch 7 | import torch.optim 8 | 9 | """ 10 | Utility functions for handling parsed arguments 11 | 12 | """ 13 | def get_args() -> argparse.Namespace: 14 | 15 | parser = argparse.ArgumentParser('Train a PIP-Net') 16 | parser.add_argument('--dataset', 17 | type=str, 18 | default='CUB-200-2011', 19 | help='Data set on PIP-Net should be trained') 20 | parser.add_argument('--validation_size', 21 | type=float, 22 | default=0., 23 | help='Split between training and validation set. Can be zero when there is a separate test or validation directory. Should be between 0 and 1. Used for partimagenet (e.g. 0.2)') 24 | parser.add_argument('--net', 25 | type=str, 26 | default='convnext_tiny_26', 27 | help='Base network used as backbone of PIP-Net. 
Default is convnext_tiny_26 with adapted strides to output 26x26 latent representations. Other option is convnext_tiny_13 that outputs 13x13 (smaller and faster to train, less fine-grained). Pretrained network on iNaturalist is only available for resnet50_inat. Options are: resnet18, resnet34, resnet50, resnet50_inat, resnet101, resnet152, convnext_tiny_26 and convnext_tiny_13.') 28 | parser.add_argument('--batch_size', 29 | type=int, 30 | default=64, 31 | help='Batch size when training the model using minibatch gradient descent. Batch size is multiplied with number of available GPUs') 32 | parser.add_argument('--batch_size_pretrain', 33 | type=int, 34 | default=128, 35 | help='Batch size when pretraining the prototypes (first training stage)') 36 | parser.add_argument('--epochs', 37 | type=int, 38 | default=60, 39 | help='The number of epochs PIP-Net should be trained (second training stage)') 40 | parser.add_argument('--epochs_pretrain', 41 | type=int, 42 | default = 10, 43 | help='Number of epochs to pre-train the prototypes (first training stage). Recommended to train at least until the align loss < 1' 44 | ) 45 | parser.add_argument('--optimizer', 46 | type=str, 47 | default='Adam', 48 | help='The optimizer that should be used when training PIP-Net') 49 | parser.add_argument('--lr', 50 | type=float, 51 | default=0.05, 52 | help='The optimizer learning rate for training the weights from prototypes to classes') 53 | parser.add_argument('--lr_block', 54 | type=float, 55 | default=0.0005, 56 | help='The optimizer learning rate for training the last conv layers of the backbone') 57 | parser.add_argument('--lr_net', 58 | type=float, 59 | default=0.0005, 60 | help='The optimizer learning rate for the backbone. Usually similar as lr_block.') 61 | parser.add_argument('--weight_decay', 62 | type=float, 63 | default=0.0, 64 | help='Weight decay used in the optimizer') 65 | parser.add_argument('--disable_cuda', 66 | action='store_true', 67 | help='Flag that disables GPU usage if set') 68 | parser.add_argument('--log_dir', 69 | type=str, 70 | default='./runs/run_pipnet', 71 | help='The directory in which train progress should be logged') 72 | parser.add_argument('--num_features', 73 | type=int, 74 | default = 0, 75 | help='Number of prototypes. When zero (default) the number of prototypes is the number of output channels of backbone. If this value is set, then a 1x1 conv layer will be added. Recommended to keep 0, but can be increased when number of classes > num output channels in backbone.') 76 | parser.add_argument('--image_size', 77 | type=int, 78 | default=224, 79 | help='Input images will be resized to --image_size x --image_size (square). Code only tested with 224x224, so no guarantees that it works for different sizes.') 80 | parser.add_argument('--state_dict_dir_net', 81 | type=str, 82 | default='', 83 | help='The directory containing a state dict with a pretrained PIP-Net. 
E.g., ./runs/run_pipnet/checkpoints/net_pretrained') 84 | parser.add_argument('--freeze_epochs', 85 | type=int, 86 | default = 10, 87 | help='Number of epochs where pretrained features_net will be frozen while training classification layer (and last layer(s) of backbone)' 88 | ) 89 | parser.add_argument('--dir_for_saving_images', 90 | type=str, 91 | default='visualization_results', 92 | help='Directoy for saving the prototypes and explanations') 93 | parser.add_argument('--disable_pretrained', 94 | action='store_true', 95 | help='When set, the backbone network is initialized with random weights instead of being pretrained on another dataset).' 96 | ) 97 | parser.add_argument('--weighted_loss', 98 | action='store_true', 99 | help='Flag that weights the loss based on the class balance of the dataset. Recommended to use when data is imbalanced. ') 100 | parser.add_argument('--seed', 101 | type=int, 102 | default=1, 103 | help='Random seed. Note that there will still be differences between runs due to nondeterminism. See https://pytorch.org/docs/stable/notes/randomness.html') 104 | parser.add_argument('--gpu_ids', 105 | type=str, 106 | default='', 107 | help='ID of gpu. Can be separated with comma') 108 | parser.add_argument('--num_workers', 109 | type=int, 110 | default=8, 111 | help='Num workers in dataloaders.') 112 | parser.add_argument('--bias', 113 | action='store_true', 114 | help='Flag that indicates whether to include a trainable bias in the linear classification layer.' 115 | ) 116 | parser.add_argument('--extra_test_image_folder', 117 | type=str, 118 | default='./experiments', 119 | help='Folder with images that PIP-Net will predict and explain, that are not in the training or test set. E.g. images with 2 objects or OOD image. Images should be in subfolder. E.g. 
images in ./experiments/images/, and argument --./experiments') 120 | 121 | args = parser.parse_args() 122 | if len(args.log_dir.split('/'))>2: 123 | if not os.path.exists(args.log_dir): 124 | os.makedirs(args.log_dir) 125 | 126 | 127 | return args 128 | 129 | 130 | def save_args(args: argparse.Namespace, directory_path: str) -> None: 131 | """ 132 | Save the arguments in the specified directory as 133 | - a text file called 'args.txt' 134 | - a pickle file called 'args.pickle' 135 | :param args: The arguments to be saved 136 | :param directory_path: The path to the directory where the arguments should be saved 137 | """ 138 | # If the specified directory does not exists, create it 139 | if not os.path.isdir(directory_path): 140 | os.mkdir(directory_path) 141 | # Save the args in a text file 142 | with open(directory_path + '/args.txt', 'w') as f: 143 | for arg in vars(args): 144 | val = getattr(args, arg) 145 | if isinstance(val, str): # Add quotation marks to indicate that the argument is of string type 146 | val = f"'{val}'" 147 | f.write('{}: {}\n'.format(arg, val)) 148 | # Pickle the args for possible reuse 149 | with open(directory_path + '/args.pickle', 'wb') as f: 150 | pickle.dump(args, f) 151 | 152 | def get_optimizer_nn(net, args: argparse.Namespace) -> torch.optim.Optimizer: 153 | torch.manual_seed(args.seed) 154 | torch.cuda.manual_seed_all(args.seed) 155 | random.seed(args.seed) 156 | np.random.seed(args.seed) 157 | 158 | #create parameter groups 159 | params_to_freeze = [] 160 | params_to_train = [] 161 | params_backbone = [] 162 | # set up optimizer 163 | if 'resnet50' in args.net: 164 | # freeze resnet50 except last convolutional layer 165 | for name,param in net.module._net.named_parameters(): 166 | if 'layer4.2' in name: 167 | params_to_train.append(param) 168 | elif 'layer4' in name or 'layer3' in name: 169 | params_to_freeze.append(param) 170 | elif 'layer2' in name: 171 | params_backbone.append(param) 172 | else: #such that model training fits on one gpu. 173 | param.requires_grad = False 174 | # params_backbone.append(param) 175 | 176 | elif 'convnext' in args.net: 177 | print("chosen network is convnext", flush=True) 178 | for name,param in net.module._net.named_parameters(): 179 | if 'features.7.2' in name: 180 | params_to_train.append(param) 181 | elif 'features.7' in name or 'features.6' in name: 182 | params_to_freeze.append(param) 183 | # CUDA MEMORY ISSUES? 
COMMENT LINE 202-203 AND USE THE FOLLOWING LINES INSTEAD 184 | # elif 'features.5' in name or 'features.4' in name: 185 | # params_backbone.append(param) 186 | # else: 187 | # param.requires_grad = False 188 | else: 189 | params_backbone.append(param) 190 | else: 191 | print("Network is not ResNet or ConvNext.", flush=True) 192 | classification_weight = [] 193 | classification_bias = [] 194 | for name, param in net.module._classification.named_parameters(): 195 | if 'weight' in name: 196 | classification_weight.append(param) 197 | elif 'multiplier' in name: 198 | param.requires_grad = False 199 | else: 200 | if args.bias: 201 | classification_bias.append(param) 202 | 203 | paramlist_net = [ 204 | {"params": params_backbone, "lr": args.lr_net, "weight_decay_rate": args.weight_decay}, 205 | {"params": params_to_freeze, "lr": args.lr_block, "weight_decay_rate": args.weight_decay}, 206 | {"params": params_to_train, "lr": args.lr_block, "weight_decay_rate": args.weight_decay}, 207 | {"params": net.module._add_on.parameters(), "lr": args.lr_block*10., "weight_decay_rate": args.weight_decay}] 208 | 209 | paramlist_classifier = [ 210 | {"params": classification_weight, "lr": args.lr, "weight_decay_rate": args.weight_decay}, 211 | {"params": classification_bias, "lr": args.lr, "weight_decay_rate": 0}, 212 | ] 213 | 214 | if args.optimizer == 'Adam': 215 | optimizer_net = torch.optim.AdamW(paramlist_net,lr=args.lr,weight_decay=args.weight_decay) 216 | optimizer_classifier = torch.optim.AdamW(paramlist_classifier,lr=args.lr,weight_decay=args.weight_decay) 217 | return optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone 218 | else: 219 | raise ValueError("this optimizer type is not implemented") 220 | 221 | -------------------------------------------------------------------------------- /features/resnet_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import os 5 | import copy 6 | 7 | model_urls = { 8 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 13 | } 14 | 15 | model_dir = './pretrained_models' 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | def conv3x3_nopad(in_planes, out_planes, stride=1): 23 | """3x3 convolution without padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=0, bias=False) 26 | 27 | def conv1x1(in_planes, out_planes, stride=1): 28 | """1x1 convolution""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | # class attribute 34 | expansion = 1 35 | num_layers = 2 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None): 38 | super(BasicBlock, self).__init__() 39 | # only conv with possibly not 1 stride 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 
| self.bn2 = nn.BatchNorm2d(planes) 45 | 46 | # if stride is not 1 then self.downsample cannot be None 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | # the residual connection 64 | out += identity 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | def block_conv_info(self): 70 | block_kernel_sizes = [3, 3] 71 | block_strides = [self.stride, 1] 72 | block_paddings = [1, 1] 73 | 74 | return block_kernel_sizes, block_strides, block_paddings 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | # class attribute 79 | expansion = 4 80 | num_layers = 3 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None): 83 | super(Bottleneck, self).__init__() 84 | self.conv1 = conv1x1(inplanes, planes) 85 | self.bn1 = nn.BatchNorm2d(planes) 86 | # only conv with possibly not 1 stride 87 | self.conv2 = conv3x3(planes, planes, stride) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = conv1x1(planes, planes * self.expansion) 90 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | 93 | # if stride is not 1 then self.downsample cannot be None 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | identity = self.downsample(x) 113 | 114 | out += identity 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | def block_conv_info(self): 120 | block_kernel_sizes = [1, 3, 1] 121 | block_strides = [1, self.stride, 1] 122 | block_paddings = [0, 1, 0] 123 | 124 | return block_kernel_sizes, block_strides, block_paddings 125 | 126 | class ResNet_features(nn.Module): 127 | ''' 128 | the convolutional layers of ResNet 129 | the average pooling and final fully convolutional layer is removed 130 | ''' 131 | 132 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 133 | super(ResNet_features, self).__init__() 134 | 135 | self.inplanes = 64 136 | 137 | # the first convolutional layer before the structured sequence of blocks 138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 139 | bias=False) 140 | self.bn1 = nn.BatchNorm2d(64) 141 | self.relu = nn.ReLU(inplace=True) 142 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 143 | # comes from the first conv and the following max pool 144 | self.kernel_sizes = [7, 3] 145 | self.strides = [2, 2] 146 | self.paddings = [3, 1] 147 | 148 | # the following layers, each layer is a sequence of blocks 149 | self.block = block 150 | self.layers = layers 151 | self.layer1 = self._make_layer(block=block, planes=64, num_blocks=self.layers[0]) 152 | self.layer2 = self._make_layer(block=block, planes=128, num_blocks=self.layers[1], stride=2) 153 | self.layer3 = self._make_layer(block=block, planes=256, num_blocks=self.layers[2], stride=1) 154 | self.layer4 = self._make_layer(block=block, planes=512, num_blocks=self.layers[3], stride=1) 155 | 156 | # initialize the parameters 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 160 | elif isinstance(m, nn.BatchNorm2d): 161 | nn.init.constant_(m.weight, 1) 162 | nn.init.constant_(m.bias, 0) 163 | 164 | # Zero-initialize the last BN in each residual branch, 165 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 166 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 167 | if zero_init_residual: 168 | for m in self.modules(): 169 | if isinstance(m, Bottleneck): 170 | nn.init.constant_(m.bn3.weight, 0) 171 | elif isinstance(m, BasicBlock): 172 | nn.init.constant_(m.bn2.weight, 0) 173 | 174 | def _make_layer(self, block, planes, num_blocks, stride=1): 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | conv1x1(self.inplanes, planes * block.expansion, stride), 179 | nn.BatchNorm2d(planes * block.expansion), 180 | ) 181 | 182 | layers = [] 183 | # only the first block has downsample that is possibly not None 184 | layers.append(block(self.inplanes, planes, stride, downsample)) 185 | 186 | self.inplanes = planes * block.expansion 187 | for _ in range(1, num_blocks): 188 | layers.append(block(self.inplanes, planes)) 189 | 190 | # keep track of every block's conv size, stride size, and padding size 191 | for each_block in layers: 192 | block_kernel_sizes, block_strides, block_paddings = each_block.block_conv_info() 193 | self.kernel_sizes.extend(block_kernel_sizes) 194 | self.strides.extend(block_strides) 195 | self.paddings.extend(block_paddings) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | x = self.conv1(x) 201 | x = self.bn1(x) 202 | x = self.relu(x) 203 | x = self.maxpool(x) 204 | 205 | x = self.layer1(x) 206 | x = self.layer2(x) 207 | x = self.layer3(x) 208 | x = self.layer4(x) 209 | 210 | return x 211 | 212 | def conv_info(self): 213 | return self.kernel_sizes, self.strides, self.paddings 214 | 215 | def num_layers(self): 216 | ''' 217 | the number of conv layers in the network, not counting the number 218 | of bypass layers 219 | ''' 220 | 221 | return (self.block.num_layers * self.layers[0] 222 | + self.block.num_layers * self.layers[1] 223 | + self.block.num_layers * self.layers[2] 224 | + self.block.num_layers * self.layers[3] 225 | + 1) 226 | 227 | def __repr__(self): 228 | template = 'resnet{}_features' 229 | return template.format(self.num_layers() + 1) 230 | 231 | def resnet18_features(pretrained=False, **kwargs): 232 | """Constructs a ResNet-18 model. 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | """ 236 | model = ResNet_features(BasicBlock, [2, 2, 2, 2], **kwargs) 237 | if pretrained: 238 | my_dict = model_zoo.load_url(model_urls['resnet18'], model_dir=model_dir) 239 | my_dict.pop('fc.weight') 240 | my_dict.pop('fc.bias') 241 | model.load_state_dict(my_dict, strict=False) 242 | return model 243 | 244 | 245 | def resnet34_features(pretrained=False, **kwargs): 246 | """Constructs a ResNet-34 model. 
247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | """ 250 | model = ResNet_features(BasicBlock, [3, 4, 6, 3], **kwargs) 251 | if pretrained: 252 | my_dict = model_zoo.load_url(model_urls['resnet34'], model_dir=model_dir) 253 | my_dict.pop('fc.weight') 254 | my_dict.pop('fc.bias') 255 | model.load_state_dict(my_dict, strict=False) 256 | return model 257 | 258 | def resnet50_features(pretrained=False, **kwargs): 259 | """Constructs a ResNet-50 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | """ 263 | 264 | model = ResNet_features(Bottleneck, [3, 4, 6, 3], **kwargs) 265 | if pretrained: 266 | my_dict = model_zoo.load_url(model_urls['resnet50'], model_dir=model_dir) 267 | my_dict.pop('fc.weight') 268 | my_dict.pop('fc.bias') 269 | model.load_state_dict(my_dict, strict=False) 270 | 271 | return model 272 | 273 | def resnet50_features_inat(pretrained=False, **kwargs): 274 | """Constructs a ResNet-50 model. 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on Inaturalist2017 277 | """ 278 | model = ResNet_features(Bottleneck, [3, 4, 6, 3], **kwargs) 279 | if pretrained: 280 | #use BBN pretrained weights of the conventional learning branch (from BBN.iNaturalist2017.res50.180epoch.best_model.pth) 281 | #https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhou_BBN_Bilateral-Branch_Network_With_Cumulative_Learning_for_Long-Tailed_Visual_Recognition_CVPR_2020_paper.pdf 282 | if not os.path.exists(os.path.join(os.path.join('features', 'state_dicts'), 'BBN.iNaturalist2017.res50.180epoch.best_model.pth')): 283 | print("To use Resnet50 pretrained on iNaturalist, create a folder called state_dicts in the folder features, and download BBN.iNaturalist2017.res50.180epoch.best_model.pth to there from https://drive.google.com/drive/folders/1yHme1iFQy-Lz_11yZJPlNd9bO_YPKlEU.", flush=True) 284 | model_dict = torch.load(os.path.join(os.path.join('features', 'state_dicts'), 'BBN.iNaturalist2017.res50.180epoch.best_model.pth')) 285 | # rename last residual block from cb_block to layer4.2 286 | new_model = copy.deepcopy(model_dict) 287 | for k in model_dict.keys(): 288 | if k.startswith('module.backbone.cb_block'): 289 | splitted = k.split('cb_block') 290 | new_model['layer4.2'+splitted[-1]]=model_dict[k] 291 | del new_model[k] 292 | elif k.startswith('module.backbone.rb_block'): 293 | del new_model[k] 294 | elif k.startswith('module.backbone.'): 295 | splitted = k.split('backbone.') 296 | new_model[splitted[-1]]=model_dict[k] 297 | del new_model[k] 298 | elif k.startswith('module.classifier'): 299 | del new_model[k] 300 | model.load_state_dict(new_model, strict=True) 301 | return model 302 | 303 | def resnet101_features(pretrained=False, **kwargs): 304 | """Constructs a ResNet-101 model. 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | """ 308 | model = ResNet_features(Bottleneck, [3, 4, 23, 3], **kwargs) 309 | if pretrained: 310 | my_dict = model_zoo.load_url(model_urls['resnet101'], model_dir=model_dir) 311 | my_dict.pop('fc.weight') 312 | my_dict.pop('fc.bias') 313 | model.load_state_dict(my_dict, strict=False) 314 | return model 315 | 316 | 317 | def resnet152_features(pretrained=False, **kwargs): 318 | """Constructs a ResNet-152 model. 
319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | """ 322 | model = ResNet_features(Bottleneck, [3, 8, 36, 3], **kwargs) 323 | if pretrained: 324 | my_dict = model_zoo.load_url(model_urls['resnet152'], model_dir=model_dir) 325 | my_dict.pop('fc.weight') 326 | my_dict.pop('fc.bias') 327 | model.load_state_dict(my_dict, strict=False) 328 | return model 329 | -------------------------------------------------------------------------------- /pipnet/test.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import torch 4 | import torch.optim 5 | from torch.utils.data import DataLoader 6 | import torch.nn.functional as F 7 | from util.log import Log 8 | from util.func import topk_accuracy 9 | from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score, f1_score 10 | 11 | @torch.no_grad() 12 | def eval_pipnet(net, 13 | test_loader: DataLoader, 14 | epoch, 15 | device, 16 | log: Log = None, 17 | progress_prefix: str = 'Eval Epoch' 18 | ) -> dict: 19 | 20 | net = net.to(device) 21 | # Make sure the model is in evaluation mode 22 | net.eval() 23 | # Keep an info dict about the procedure 24 | info = dict() 25 | # Build a confusion matrix 26 | cm = np.zeros((net.module._num_classes, net.module._num_classes), dtype=int) 27 | 28 | global_top1acc = 0. 29 | global_top5acc = 0. 30 | global_sim_anz = 0. 31 | global_anz = 0. 32 | local_size_total = 0. 33 | y_trues = [] 34 | y_preds = [] 35 | y_preds_classes = [] 36 | abstained = 0 37 | # Show progress on progress bar 38 | test_iter = tqdm(enumerate(test_loader), 39 | total=len(test_loader), 40 | desc=progress_prefix+' %s'%epoch, 41 | mininterval=5., 42 | ncols=0) 43 | (xs, ys) = next(iter(test_loader)) 44 | # Iterate through the test set 45 | for i, (xs, ys) in test_iter: 46 | xs, ys = xs.to(device), ys.to(device) 47 | 48 | with torch.no_grad(): 49 | net.module._classification.weight.copy_(torch.clamp(net.module._classification.weight.data - 1e-3, min=0.)) 50 | # Use the model to classify this batch of input data 51 | _, pooled, out = net(xs, inference=True) 52 | max_out_score, ys_pred = torch.max(out, dim=1) 53 | ys_pred_scores = torch.amax(F.softmax((torch.log1p(out**net.module._classification.normalization_multiplier)),dim=1),dim=1) 54 | abstained += (max_out_score.shape[0] - torch.count_nonzero(max_out_score)) 55 | repeated_weight = net.module._classification.weight.unsqueeze(1).repeat(1,pooled.shape[0],1) 56 | sim_scores_anz = torch.count_nonzero(torch.gt(torch.abs(pooled*repeated_weight), 1e-3).float(),dim=2).float() 57 | local_size = torch.count_nonzero(torch.gt(torch.relu((pooled*repeated_weight)-1e-3).sum(dim=1), 0.).float(),dim=1).float() 58 | local_size_total += local_size.sum().item() 59 | 60 | 61 | correct_class_sim_scores_anz = torch.diagonal(torch.index_select(sim_scores_anz, dim=0, index=ys_pred),0) 62 | global_sim_anz += correct_class_sim_scores_anz.sum().item() 63 | 64 | almost_nz = torch.count_nonzero(torch.gt(torch.abs(pooled), 1e-3).float(),dim=1).float() 65 | global_anz += almost_nz.sum().item() 66 | 67 | # Update the confusion matrix 68 | cm_batch = np.zeros((net.module._num_classes, net.module._num_classes), dtype=int) 69 | for y_pred, y_true in zip(ys_pred, ys): 70 | cm[y_true][y_pred] += 1 71 | cm_batch[y_true][y_pred] += 1 72 | acc = acc_from_cm(cm_batch) 73 | test_iter.set_postfix_str( 74 | f'SimANZCC: {correct_class_sim_scores_anz.mean().item():.2f}, ANZ: 
{almost_nz.mean().item():.1f}, LocS: {local_size.mean().item():.1f}, Acc: {acc:.3f}', refresh=False 75 | ) 76 | 77 | (top1accs, top5accs) = topk_accuracy(out, ys, topk=[1,5]) 78 | 79 | global_top1acc+=torch.sum(top1accs).item() 80 | global_top5acc+=torch.sum(top5accs).item() 81 | y_preds += ys_pred_scores.detach().tolist() 82 | y_trues += ys.detach().tolist() 83 | y_preds_classes += ys_pred.detach().tolist() 84 | 85 | del out 86 | del pooled 87 | del ys_pred 88 | 89 | print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True) 90 | info['num non-zero prototypes'] = torch.gt(net.module._classification.weight,1e-3).any(dim=0).sum().item() 91 | print("sparsity ratio: ", (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight), flush=True) 92 | info['confusion_matrix'] = cm 93 | info['test_accuracy'] = acc_from_cm(cm) 94 | info['top1_accuracy'] = global_top1acc/len(test_loader.dataset) 95 | info['top5_accuracy'] = global_top5acc/len(test_loader.dataset) 96 | info['almost_sim_nonzeros'] = global_sim_anz/len(test_loader.dataset) 97 | info['local_size_all_classes'] = local_size_total / len(test_loader.dataset) 98 | info['almost_nonzeros'] = global_anz/len(test_loader.dataset) 99 | 100 | if net.module._num_classes == 2: 101 | tp = cm[0][0] 102 | fn = cm[0][1] 103 | fp = cm[1][0] 104 | tn = cm[1][1] 105 | print("TP: ", tp, "FN: ",fn, "FP:", fp, "TN:", tn, flush=True) 106 | sensitivity = tp/(tp+fn) 107 | specificity = tn/(tn+fp) 108 | print("\n Epoch",epoch, flush=True) 109 | print("Confusion matrix: ", cm, flush=True) 110 | try: 111 | for classname, classidx in test_loader.dataset.class_to_idx.items(): 112 | if classidx == 0: 113 | print("Accuracy positive class (", classname, classidx,") (TPR, Sensitivity):", tp/(tp+fn)) 114 | elif classidx == 1: 115 | print("Accuracy negative class (", classname, classidx,") (TNR, Specificity):", tn/(tn+fp)) 116 | except ValueError: 117 | pass 118 | print("Balanced accuracy: ", balanced_accuracy_score(y_trues, y_preds_classes),flush=True) 119 | print("Sensitivity: ", sensitivity, "Specificity: ", specificity,flush=True) 120 | info['top5_accuracy'] = f1_score(y_trues, y_preds_classes) 121 | try: 122 | print("AUC macro: ", roc_auc_score(y_trues, y_preds, average='macro'), flush=True) 123 | print("AUC weighted: ", roc_auc_score(y_trues, y_preds, average='weighted'), flush=True) 124 | except ValueError: 125 | pass 126 | else: 127 | info['top5_accuracy'] = global_top5acc/len(test_loader.dataset) 128 | 129 | return info 130 | 131 | def acc_from_cm(cm: np.ndarray) -> float: 132 | """ 133 | Compute the accuracy from the confusion matrix 134 | :param cm: confusion matrix 135 | :return: the accuracy score 136 | """ 137 | assert len(cm.shape) == 2 and cm.shape[0] == cm.shape[1] 138 | 139 | correct = 0 140 | for i in range(len(cm)): 141 | correct += cm[i, i] 142 | 143 | total = np.sum(cm) 144 | if total == 0: 145 | return 1 146 | else: 147 | return correct / total 148 | 149 | 150 | @torch.no_grad() 151 | # Calculates class-specific threshold for the FPR@X metric. 
Also calculates threshold for images with correct prediction (currently not used, but can be insightful) 152 | def get_thresholds(net, 153 | test_loader: DataLoader, 154 | epoch, 155 | device, 156 | percentile:float = 95., 157 | log: Log = None, 158 | log_prefix: str = 'log_eval_epochs', 159 | progress_prefix: str = 'Get Thresholds Epoch' 160 | ) -> dict: 161 | 162 | net = net.to(device) 163 | # Make sure the model is in evaluation mode 164 | net.eval() 165 | 166 | outputs_per_class = dict() 167 | outputs_per_correct_class = dict() 168 | for c in range(net.module._num_classes): 169 | outputs_per_class[c] = [] 170 | outputs_per_correct_class[c] = [] 171 | # Show progress on progress bar 172 | test_iter = tqdm(enumerate(test_loader), 173 | total=len(test_loader), 174 | desc=progress_prefix+' %s Perc %s'%(epoch,percentile), 175 | mininterval=5., 176 | ncols=0) 177 | (xs, ys) = next(iter(test_loader)) 178 | # Iterate through the test set 179 | for i, (xs, ys) in test_iter: 180 | xs, ys = xs.to(device), ys.to(device) 181 | 182 | with torch.no_grad(): 183 | # Use the model to classify this batch of input data 184 | _, pooled, out = net(xs) 185 | 186 | ys_pred = torch.argmax(out, dim=1) 187 | for pred in range(len(ys_pred)): 188 | outputs_per_class[ys_pred[pred].item()].append(out[pred,:].max().item()) 189 | if ys_pred[pred].item()==ys[pred].item(): 190 | outputs_per_correct_class[ys_pred[pred].item()].append(out[pred,:].max().item()) 191 | 192 | del out 193 | del pooled 194 | del ys_pred 195 | 196 | class_thresholds = dict() 197 | correct_class_thresholds = dict() 198 | all_outputs = [] 199 | all_correct_outputs = [] 200 | for c in range(net.module._num_classes): 201 | if len(outputs_per_class[c])>0: 202 | outputs_c = outputs_per_class[c] 203 | all_outputs += outputs_c 204 | class_thresholds[c] = np.percentile(outputs_c,100-percentile) 205 | 206 | if len(outputs_per_correct_class[c])>0: 207 | correct_outputs_c = outputs_per_correct_class[c] 208 | all_correct_outputs += correct_outputs_c 209 | correct_class_thresholds[c] = np.percentile(correct_outputs_c,100-percentile) 210 | 211 | overall_threshold = np.percentile(all_outputs,100-percentile) 212 | overall_correct_threshold = np.percentile(all_correct_outputs,100-percentile) 213 | # if class is not predicted there is no threshold. we set it as the minimum value for any other class 214 | mean_ct = np.mean(list(class_thresholds.values())) 215 | mean_cct = np.mean(list(correct_class_thresholds.values())) 216 | for c in range(net.module._num_classes): 217 | if c not in class_thresholds.keys(): 218 | print(c,"not in class thresholds. 
Setting to mean threshold", flush=True) 219 | class_thresholds[c] = mean_ct 220 | if c not in correct_class_thresholds.keys(): 221 | correct_class_thresholds[c] = mean_cct 222 | 223 | calculated_percentile = 0 224 | correctly_classified = 0 225 | total = 0 226 | for c in range(net.module._num_classes): 227 | correctly_classified+=sum(i>class_thresholds[c] for i in outputs_per_class[c]) 228 | total += len(outputs_per_class[c]) 229 | calculated_percentile = correctly_classified/total 230 | 231 | if percentile<100: 232 | while calculated_percentile < (percentile/100.): 233 | class_thresholds.update((x, y*0.999) for x, y in class_thresholds.items()) 234 | correctly_classified = 0 235 | for c in range(net.module._num_classes): 236 | correctly_classified+=sum(i>=class_thresholds[c] for i in outputs_per_class[c]) 237 | calculated_percentile = correctly_classified/total 238 | 239 | return overall_correct_threshold, overall_threshold, correct_class_thresholds, class_thresholds 240 | 241 | @torch.no_grad() 242 | def eval_ood(net, 243 | test_loader: DataLoader, 244 | epoch, 245 | device, 246 | threshold, #class specific threshold or overall threshold. single float is overall, list or dict is class specific 247 | progress_prefix: str = 'Get Thresholds Epoch' 248 | ) -> dict: 249 | 250 | net = net.to(device) 251 | # Make sure the model is in evaluation mode 252 | net.eval() 253 | 254 | predicted_as_id = 0 255 | seen = 0. 256 | abstained = 0 257 | # Show progress on progress bar 258 | test_iter = tqdm(enumerate(test_loader), 259 | total=len(test_loader), 260 | desc=progress_prefix+' %s'%epoch, 261 | mininterval=5., 262 | ncols=0) 263 | (xs, ys) = next(iter(test_loader)) 264 | # Iterate through the test set 265 | for i, (xs, ys) in test_iter: 266 | xs, ys = xs.to(device), ys.to(device) 267 | 268 | with torch.no_grad(): 269 | # Use the model to classify this batch of input data 270 | _, pooled, out = net(xs) 271 | max_out_score, ys_pred = torch.max(out, dim=1) 272 | ys_pred = torch.argmax(out, dim=1) 273 | abstained += (max_out_score.shape[0] - torch.count_nonzero(max_out_score)) 274 | for j in range(len(ys_pred)): 275 | seen+=1. 
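# `threshold` is either one overall float or a dict mapping a predicted class index to its class-specific
# threshold, as produced by get_thresholds() above. A minimal usage sketch, mirroring the calls in main.py
# (loader and logger names are taken from there; the local names overall_thr/class_thrs/id_fraction/ood_fraction
# are illustrative only):
#   _, overall_thr, _, class_thrs = get_thresholds(net, testloader, epoch, device, 95., log)
#   id_fraction  = eval_ood(net, testloader, epoch, device, class_thrs)      # in-distribution data -> TPR-like rate
#   ood_fraction = eval_ood(net, ood_testloader, epoch, device, class_thrs)  # OOD data -> FPR-like rate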
276 | if isinstance(threshold, dict): 277 | thresholdj = threshold[ys_pred[j].item()] 278 | elif isinstance(threshold, float): #overall threshold 279 | thresholdj = threshold 280 | else: 281 | raise ValueError("provided threshold should be float or dict", type(threshold)) 282 | sample_out = out[j,:] 283 | 284 | if sample_out.max().item() >= thresholdj: 285 | predicted_as_id += 1 286 | 287 | del out 288 | del pooled 289 | del ys_pred 290 | print("Samples seen:", seen, "of which predicted as In-Distribution:", predicted_as_id, flush=True) 291 | print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True) 292 | return predicted_as_id/seen 293 | -------------------------------------------------------------------------------- /util/vis_pipnet.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import argparse 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.utils.data 6 | import os 7 | from PIL import Image, ImageDraw as D 8 | import torchvision.transforms as transforms 9 | import torchvision 10 | from util.func import get_patch_size 11 | import random 12 | 13 | @torch.no_grad() 14 | def visualize_topk(net, projectloader, num_classes, device, foldername, args: argparse.Namespace, k=10): 15 | print("Visualizing prototypes for topk...", flush=True) 16 | dir = os.path.join(args.log_dir, foldername) 17 | if not os.path.exists(dir): 18 | os.makedirs(dir) 19 | 20 | near_imgs_dirs = dict() 21 | seen_max = dict() 22 | saved = dict() 23 | saved_ys = dict() 24 | tensors_per_prototype = dict() 25 | 26 | for p in range(net.module._num_prototypes): 27 | near_imgs_dir = os.path.join(dir, str(p)) 28 | near_imgs_dirs[p]=near_imgs_dir 29 | seen_max[p]=0. 30 | saved[p]=0 31 | saved_ys[p]=[] 32 | tensors_per_prototype[p]=[] 33 | 34 | patchsize, skip = get_patch_size(args) 35 | 36 | imgs = projectloader.dataset.imgs 37 | 38 | # Make sure the model is in evaluation mode 39 | net.eval() 40 | classification_weights = net.module._classification.weight 41 | 42 | # Show progress on progress bar 43 | img_iter = tqdm(enumerate(projectloader), 44 | total=len(projectloader), 45 | mininterval=50., 46 | desc='Collecting topk', 47 | ncols=0) 48 | 49 | # Iterate through the data 50 | images_seen = 0 51 | topks = dict() 52 | # Iterate through the training set 53 | for i, (xs, ys) in img_iter: 54 | images_seen+=1 55 | xs, ys = xs.to(device), ys.to(device) 56 | 57 | with torch.no_grad(): 58 | # Use the model to classify this batch of input data 59 | pfs, pooled, _ = net(xs, inference=True) 60 | pooled = pooled.squeeze(0) 61 | pfs = pfs.squeeze(0) 62 | 63 | for p in range(pooled.shape[0]): 64 | c_weight = torch.max(classification_weights[:,p]) 65 | if c_weight > 1e-3:#ignore prototypes that are not relevant to any class 66 | if p not in topks.keys(): 67 | topks[p] = [] 68 | 69 | if len(topks[p]) < k: 70 | topks[p].append((i, pooled[p].item())) 71 | else: 72 | topks[p] = sorted(topks[p], key=lambda tup: tup[1], reverse=True) 73 | if topks[p][-1][1] < pooled[p].item(): 74 | topks[p][-1] = (i, pooled[p].item()) 75 | if topks[p][-1][1] == pooled[p].item(): 76 | # equal scores. randomly chose one (since dataset is not shuffled so latter images with same scores can now also get in topk). 
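# Bookkeeping sketch: topks[p] holds at most k (image_index, pooled_score) tuples for prototype p; once the list
# is full it is sorted by score (descending), a new image replaces the weakest entry when it scores strictly higher,
# and an exact tie at the k-th position is resolved by the coin flip below.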
77 | replace_choice = random.choice([0, 1]) 78 | if replace_choice > 0: 79 | topks[p][-1] = (i, pooled[p].item()) 80 | 81 | alli = [] 82 | prototypes_not_used = [] 83 | for p in topks.keys(): 84 | found = False 85 | for idx, score in topks[p]: 86 | alli.append(idx) 87 | if score > 0.1: #in case prototypes have fewer than k well-related patches 88 | found = True 89 | if not found: 90 | prototypes_not_used.append(p) 91 | 92 | print(len(prototypes_not_used), "prototypes do not have any similarity score > 0.1. Will be ignored in visualisation.") 93 | abstained = 0 94 | # Show progress on progress bar 95 | img_iter = tqdm(enumerate(projectloader), 96 | total=len(projectloader), 97 | mininterval=50., 98 | desc='Visualizing topk', 99 | ncols=0) 100 | for i, (xs, ys) in img_iter: #shuffle is false so should lead to same order as in imgs 101 | if i in alli: 102 | xs, ys = xs.to(device), ys.to(device) 103 | for p in topks.keys(): 104 | if p not in prototypes_not_used: 105 | for idx, score in topks[p]: 106 | if idx == i: 107 | # Use the model to classify this batch of input data 108 | with torch.no_grad(): 109 | softmaxes, pooled, out = net(xs, inference=True) #softmaxes has shape (1, num_prototypes, W, H) 110 | outmax = torch.amax(out,dim=1)[0] #shape ([1]) because batch size of projectloader is 1 111 | if outmax.item() == 0.: 112 | abstained+=1 113 | 114 | # Take the max per prototype. 115 | max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0) 116 | max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1) 117 | max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes) 118 | 119 | c_weight = torch.max(classification_weights[:,p]) #ignore prototypes that are not relevant to any class 120 | if (c_weight > 1e-10) or ('pretrain' in foldername): 121 | 122 | h_idx = max_idx_per_prototype_h[p, max_idx_per_prototype_w[p]] 123 | w_idx = max_idx_per_prototype_w[p] 124 | 125 | img_to_open = imgs[i] 126 | if isinstance(img_to_open, tuple) or isinstance(img_to_open, list): #dataset contains tuples of (img,label) 127 | img_to_open = img_to_open[0] 128 | 129 | image = transforms.Resize(size=(args.image_size, args.image_size))(Image.open(img_to_open)) 130 | img_tensor = transforms.ToTensor()(image).unsqueeze_(0) #shape (1, 3, h, w) 131 | h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx) 132 | img_tensor_patch = img_tensor[0, :, h_coor_min:h_coor_max, w_coor_min:w_coor_max] 133 | 134 | saved[p]+=1 135 | tensors_per_prototype[p].append(img_tensor_patch) 136 | 137 | print("Abstained: ", abstained, flush=True) 138 | all_tensors = [] 139 | for p in range(net.module._num_prototypes): 140 | if saved[p]>0: 141 | # add text next to each topk-grid, to easily see which prototype it is 142 | text = "P "+str(p) 143 | txtimage = Image.new("RGB", (img_tensor_patch.shape[1],img_tensor_patch.shape[2]), (0, 0, 0)) 144 | draw = D.Draw(txtimage) 145 | draw.text((img_tensor_patch.shape[0]//2, img_tensor_patch.shape[1]//2), text, anchor='mm', fill="white") 146 | txttensor = transforms.ToTensor()(txtimage) 147 | tensors_per_prototype[p].append(txttensor) 148 | # save top-k image patches in grid 149 | try: 150 | grid = torchvision.utils.make_grid(tensors_per_prototype[p], nrow=k+1, padding=1) 151 | torchvision.utils.save_image(grid,os.path.join(dir,"grid_topk_%s.png"%(str(p)))) 152 | if saved[p]>=k: 153 | all_tensors+=tensors_per_prototype[p] 154 | except: 155 | 
pass 156 | if len(all_tensors)>0: 157 | grid = torchvision.utils.make_grid(all_tensors, nrow=k+1, padding=1) 158 | torchvision.utils.save_image(grid,os.path.join(dir,"grid_topk_all.png")) 159 | else: 160 | print("Pretrained prototypes not visualized. Try to pretrain longer.", flush=True) 161 | return topks 162 | 163 | 164 | def visualize(net, projectloader, num_classes, device, foldername, args: argparse.Namespace): 165 | print("Visualizing prototypes...", flush=True) 166 | dir = os.path.join(args.log_dir, foldername) 167 | if not os.path.exists(dir): 168 | os.makedirs(dir) 169 | 170 | near_imgs_dirs = dict() 171 | seen_max = dict() 172 | saved = dict() 173 | saved_ys = dict() 174 | tensors_per_prototype = dict() 175 | abstainedimgs = set() 176 | notabstainedimgs = set() 177 | 178 | for p in range(net.module._num_prototypes): 179 | near_imgs_dir = os.path.join(dir, str(p)) 180 | near_imgs_dirs[p]=near_imgs_dir 181 | seen_max[p]=0. 182 | saved[p]=0 183 | saved_ys[p]=[] 184 | tensors_per_prototype[p]=[] 185 | 186 | patchsize, skip = get_patch_size(args) 187 | 188 | imgs = projectloader.dataset.imgs 189 | 190 | # skip some images for visualisation to speed up the process 191 | if len(imgs)/num_classes <10: 192 | skip_img=10 193 | elif len(imgs)/num_classes < 50: 194 | skip_img=5 195 | else: 196 | skip_img = 2 197 | print("Every", skip_img, "is skipped in order to speed up the visualisation process", flush=True) 198 | 199 | # Make sure the model is in evaluation mode 200 | net.eval() 201 | classification_weights = net.module._classification.weight 202 | # Show progress on progress bar 203 | img_iter = tqdm(enumerate(projectloader), 204 | total=len(projectloader), 205 | mininterval=100., 206 | desc='Visualizing', 207 | ncols=0) 208 | 209 | # Iterate through the data 210 | images_seen_before = 0 211 | for i, (xs, ys) in img_iter: #shuffle is false so should lead to same order as in imgs 212 | if i % skip_img == 0: 213 | images_seen_before+=xs.shape[0] 214 | continue 215 | 216 | xs, ys = xs.to(device), ys.to(device) 217 | # Use the model to classify this batch of input data 218 | with torch.no_grad(): 219 | softmaxes, _, out = net(xs, inference=True) 220 | 221 | max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0) 222 | # In PyTorch, images are represented as [channels, height, width] 223 | max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1) 224 | max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) 225 | for p in range(0, net.module._num_prototypes): 226 | c_weight = torch.max(classification_weights[:,p]) #ignore prototypes that are not relevant to any class 227 | if c_weight>0: 228 | h_idx = max_idx_per_prototype_h[p, max_idx_per_prototype_w[p]] 229 | w_idx = max_idx_per_prototype_w[p] 230 | idx_to_select = max_idx_per_prototype[p,h_idx, w_idx].item() 231 | found_max = max_per_prototype[p,h_idx, w_idx].item() 232 | 233 | imgname = imgs[images_seen_before+idx_to_select] 234 | if out.max() < 1e-8: 235 | abstainedimgs.add(imgname) 236 | else: 237 | notabstainedimgs.add(imgname) 238 | 239 | if found_max > seen_max[p]: 240 | seen_max[p]=found_max 241 | 242 | if found_max > 0.5: 243 | img_to_open = imgs[images_seen_before+idx_to_select] 244 | if isinstance(img_to_open, tuple) or isinstance(img_to_open, list): #dataset contains tuples of (img,label) 245 | imglabel = img_to_open[1] 246 | img_to_open = img_to_open[0] 247 | 248 | image = transforms.Resize(size=(args.image_size, 
args.image_size))(Image.open(img_to_open).convert("RGB")) 249 | img_tensor = transforms.ToTensor()(image).unsqueeze_(0) #shape (1, 3, h, w) 250 | h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx) 251 | img_tensor_patch = img_tensor[0, :, h_coor_min:h_coor_max, w_coor_min:w_coor_max] 252 | saved[p]+=1 253 | tensors_per_prototype[p].append((img_tensor_patch, found_max)) 254 | 255 | save_path = os.path.join(dir, "prototype_%s")%str(p) 256 | if not os.path.exists(save_path): 257 | os.makedirs(save_path) 258 | draw = D.Draw(image) 259 | draw.rectangle([(w_coor_min,h_coor_min), (w_coor_max, h_coor_max)], outline='yellow', width=2) 260 | image.save(os.path.join(save_path, 'p%s_%s_%s_%s_rect.png'%(str(p),str(imglabel),str(round(found_max, 2)),str(img_to_open.split('/')[-1].split('.jpg')[0])))) 261 | 262 | 263 | images_seen_before+=len(ys) 264 | 265 | print("num images abstained: ", len(abstainedimgs), flush=True) 266 | print("num images not abstained: ", len(notabstainedimgs), flush=True) 267 | for p in range(net.module._num_prototypes): 268 | if saved[p]>0: 269 | try: 270 | sorted_by_second = sorted(tensors_per_prototype[p], key=lambda tup: tup[1], reverse=True) 271 | sorted_ps = [i[0] for i in sorted_by_second] 272 | grid = torchvision.utils.make_grid(sorted_ps, nrow=16, padding=1) 273 | torchvision.utils.save_image(grid,os.path.join(dir,"grid_%s.png"%(str(p)))) 274 | except RuntimeError: 275 | pass 276 | 277 | # convert latent location to coordinates of image patch 278 | def get_img_coordinates(img_size, softmaxes_shape, patchsize, skip, h_idx, w_idx): 279 | # in case latent output size is 26x26. For convnext with smaller strides. 280 | if softmaxes_shape[1] == 26 and softmaxes_shape[2] == 26: 281 | #Since the outer latent patches have a smaller receptive field, skip size is set to 4 for the first and last patch. 8 for rest. 282 | h_coor_min = max(0,(h_idx-1)*skip+4) 283 | if h_idx < softmaxes_shape[-1]-1: 284 | h_coor_max = h_coor_min + patchsize 285 | else: 286 | h_coor_min -= 4 287 | h_coor_max = h_coor_min + patchsize 288 | w_coor_min = max(0,(w_idx-1)*skip+4) 289 | if w_idx < softmaxes_shape[-1]-1: 290 | w_coor_max = w_coor_min + patchsize 291 | else: 292 | w_coor_min -= 4 293 | w_coor_max = w_coor_min + patchsize 294 | else: 295 | h_coor_min = h_idx*skip 296 | h_coor_max = min(img_size, h_idx*skip+patchsize) 297 | w_coor_min = w_idx*skip 298 | w_coor_max = min(img_size, w_idx*skip+patchsize) 299 | 300 | if h_idx == softmaxes_shape[1]-1: 301 | h_coor_max = img_size 302 | if w_idx == softmaxes_shape[2] -1: 303 | w_coor_max = img_size 304 | if h_coor_max == img_size: 305 | h_coor_min = img_size-patchsize 306 | if w_coor_max == img_size: 307 | w_coor_min = img_size-patchsize 308 | 309 | return h_coor_min, h_coor_max, w_coor_min, w_coor_max 310 | -------------------------------------------------------------------------------- /util/eval_cub_csv.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import torch 8 | from util.func import get_patch_size 9 | import csv 10 | import torchvision.transforms as transforms 11 | import torchvision 12 | from util.vis_pipnet import get_img_coordinates 13 | 14 | # Evaluates purity of CUB prototypes from csv file. 
General method that can be used for other part-prototype methods as well 15 | # Assumes that coordinates in csv file apply to a 224x224 image! 16 | def eval_prototypes_cub_parts_csv(csvfile, parts_loc_path, parts_name_path, imgs_id_path, epoch, args, log): 17 | patchsize, _ = get_patch_size(args) 18 | imgresize = float(args.image_size) 19 | path_to_id = dict() 20 | id_to_path = dict() 21 | with open(imgs_id_path) as f: 22 | for line in f: 23 | id, path = line.split('\n')[0].split(' ') 24 | path_to_id[path]=id 25 | id_to_path[id]=path 26 | 27 | img_to_part_xy_vis = dict() 28 | with open(parts_loc_path) as f: 29 | for line in f: 30 | img, partid, x, y, vis = line.split('\n')[0].split(' ') 31 | x =float(x) 32 | y =float(y) 33 | if img not in img_to_part_xy_vis.keys(): 34 | img_to_part_xy_vis[img]=dict() 35 | if vis == '1': 36 | img_to_part_xy_vis[img][partid]=(x,y) 37 | 38 | parts_id_to_name = dict() 39 | parts_name_to_id = dict() 40 | with open (parts_name_path) as f: 41 | for line in f: 42 | id, name = line.split('\n')[0].split(' ',1) 43 | parts_id_to_name[id]=name 44 | parts_name_to_id[name]=id 45 | print(parts_id_to_name) 46 | 47 | # merge left and right cub parts 48 | duplicate_part_ids = [] 49 | with open (parts_name_path) as f: 50 | for line in f: 51 | id, name = line.split('\n')[0].split(' ',1) 52 | if 'left' in name: 53 | new_name = name.replace('left', 'right') 54 | 55 | duplicate_part_ids.append((id, parts_name_to_id[new_name])) 56 | 57 | proto_parts_presences = dict() 58 | 59 | with open (csvfile, newline='') as f: 60 | filereader = csv.reader(f, delimiter=',') 61 | next(filereader) #skip header 62 | for (prototype, imgname, h_min_224, h_max_224, w_min_224, w_max_224) in filereader: 63 | 64 | if prototype not in proto_parts_presences.keys(): 65 | proto_parts_presences[prototype]=dict() 66 | p = prototype 67 | img = Image.open(imgname) 68 | imgname = imgname.replace('\\', '/') 69 | imgnamec, imgnamef = imgname.split('/')[-2:] 70 | if 'normal_' in imgnamef: 71 | imgnamef = imgnamef.split('normal_')[-1] 72 | imgname = imgnamec+'/'+imgnamef 73 | img_id = path_to_id[imgname] 74 | img_orig_width, img_orig_height = img.size 75 | h_min_224, h_max_224, w_min_224, w_max_224 = float(h_min_224), float(h_max_224), float(w_min_224), float(w_max_224) 76 | 77 | diffh = h_max_224 - h_min_224 78 | diffw = w_max_224 - w_min_224 79 | if diffh > patchsize: #patch size too big, we take the center. otherwise the bigger the patch, the higher the purity. 80 | correction = diffh-patchsize 81 | h_min_224 = h_min_224 + correction//2. 82 | h_max_224 = h_max_224 - correction//2. 83 | if diffw > patchsize: 84 | correction = diffw-patchsize 85 | w_min_224 = w_min_224 + correction//2. 86 | w_max_224 = w_max_224 - correction//2. 
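# The csv coordinates live in the resized frame of args.image_size x args.image_size (224x224 by default); the
# lines below rescale them to the original resolution with a factor (original size / 224) so they can be compared
# against the CUB part annotations, which are given in original-image pixels. Worked example with hypothetical
# numbers: h_min_224 = 100 on an image of original height 448 maps to (448/224) * 100 = 200 in the original image.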
87 | 88 | orig_img_location_h_min = (img_orig_height/imgresize) * h_min_224 89 | orig_img_location_h_max = (img_orig_height/imgresize) * h_max_224 90 | orig_img_location_w_min = (img_orig_width/imgresize) * w_min_224 91 | orig_img_location_w_max = (img_orig_width/imgresize) * w_max_224 92 | 93 | part_dict_img = img_to_part_xy_vis[img_id] 94 | for part in part_dict_img.keys(): 95 | x,y = part_dict_img[part] 96 | part_in_patch = 0 97 | if y >= orig_img_location_h_min and y <= orig_img_location_h_max: 98 | if x >= orig_img_location_w_min and x <= orig_img_location_w_max: 99 | part_in_patch = 1 100 | if part not in proto_parts_presences[p].keys(): 101 | proto_parts_presences[p][part]=[] 102 | proto_parts_presences[p][part].append(part_in_patch) 103 | 104 | for pair in duplicate_part_ids: 105 | if pair[0] in part_dict_img.keys(): 106 | if pair[1] in part_dict_img.keys(): 107 | presence0 = proto_parts_presences[p][pair[0]][-1] 108 | presence1 = proto_parts_presences[p][pair[1]][-1] 109 | if presence0 > presence1: 110 | proto_parts_presences[p][pair[1]][-1] = presence0 111 | 112 | del proto_parts_presences[p][pair[0]] 113 | else: 114 | 115 | if pair[1] not in proto_parts_presences[p].keys(): 116 | proto_parts_presences[p][pair[1]]=[] 117 | proto_parts_presences[p][pair[1]].append(proto_parts_presences[p][pair[0]][-1]) 118 | del proto_parts_presences[p][pair[0]] 119 | 120 | print("\n Eval CUB Parts - Epoch: ", epoch, flush=True) 121 | print("Number of prototypes in parts_presences: ", len(proto_parts_presences.keys()), flush=True) 122 | 123 | prototypes_part_related = 0 124 | max_presence_purity = dict() 125 | max_presence_purity_part = dict() 126 | max_presence_purity_sum = dict() 127 | 128 | most_often_present_purity = dict() 129 | part_most_present = dict() 130 | 131 | for proto in proto_parts_presences.keys(): 132 | 133 | max_presence_purity[proto]= 0. 134 | part_most_present[proto] = ('0',0) 135 | most_often_present_purity[proto] = 0. 136 | 137 | # CUB parts 7,8 and 9 are duplicate (right and left). additional check that these should not occur (already fixed earlier in this function) 138 | if ('7' in proto_parts_presences[proto].keys() or '8' in proto_parts_presences[proto].keys() or '9' in proto_parts_presences[proto].keys()): 139 | print("unused part in keys! 
", proto, proto_parts_presences[proto].keys(), proto_parts_presences[proto], flush=True) 140 | raise ValueError() 141 | 142 | for part in proto_parts_presences[proto].keys(): 143 | presence_purity = np.mean(proto_parts_presences[proto][part]) 144 | sum_occurs = np.array(proto_parts_presences[proto][part]).sum() 145 | 146 | # evaluate whether the purity of this prototype for this part is higher than for other parts 147 | if presence_purity > max_presence_purity[proto]: 148 | max_presence_purity[proto]=presence_purity 149 | max_presence_purity_part[proto]=parts_id_to_name[part] 150 | max_presence_purity_sum[proto] = sum_occurs 151 | elif presence_purity == max_presence_purity[proto]: 152 | if presence_purity == 0.: 153 | max_presence_purity[proto]=presence_purity 154 | max_presence_purity_part[proto]=parts_id_to_name[part] 155 | max_presence_purity_sum[proto] = sum_occurs 156 | elif sum_occurs > max_presence_purity_sum[proto]: 157 | max_presence_purity[proto]=presence_purity 158 | max_presence_purity_part[proto]=parts_id_to_name[part] 159 | max_presence_purity_sum[proto] = sum_occurs 160 | 161 | if sum_occurs > part_most_present[proto][1]: 162 | part_most_present[proto] = (part, sum_occurs) 163 | most_often_present_purity[proto]=presence_purity 164 | if max_presence_purity[proto] > 0.5: 165 | prototypes_part_related += 1 166 | 167 | print("Number of part-related prototypes (purity>0.5): ", prototypes_part_related, flush=True) 168 | 169 | print("Mean purity of prototypes (corresponding to purest part): ", np.mean(list(max_presence_purity.values())), "std: ", np.std(list(max_presence_purity.values())), flush=True) 170 | print("Prototypes with highest-purity part (no contraints): ", max_presence_purity_part, flush=True) 171 | print("Prototype with part that has most often overlap with prototype: ", part_most_present, flush=True) 172 | 173 | log.log_values('log_epoch_overview', "p_cub_"+str(epoch), "mean purity (averaged over all prototypes, corresponding to purest part)", "std purity", "mean purity (averaged over all prototypes, corresponding to part with most often overlap)", "std purity", "# prototypes in csv", "#part-related prototypes (purity > 0.5)","","") 174 | 175 | log.log_values('log_epoch_overview', "p_cub_"+str(epoch), np.mean(list(max_presence_purity.values())), np.std(list(max_presence_purity.values())), np.mean(list(most_often_present_purity.values())), np.std(list(most_often_present_purity.values())), len(list(proto_parts_presences.keys())), prototypes_part_related, "", "") 176 | 177 | # Writes coordinates of image patches per prototype to csv file (image resized to 224x224) 178 | def get_proto_patches_cub(net, projectloader, epoch, device, args, threshold=0.5): 179 | # Make sure the model is in evaluation mode 180 | net.eval() 181 | 182 | imgs = projectloader.dataset.imgs 183 | classification_weights = net.module._classification.weight 184 | patchsize, skip = get_patch_size(args) 185 | 186 | proto_img_coordinates = [] 187 | 188 | csvfilepath = os.path.join(args.log_dir, str(epoch)+'_pipnet_prototypes_cub_all.csv') 189 | columns = ["prototype", "img name", "h_min_224", "h_max_224", "w_min_224", "w_max_224"] 190 | with open(csvfilepath, "w", newline='') as csvfile: 191 | print("Collecting Prototype Image Patches for Evaluating CUB part purity. 
Writing CSV file with image patche coordinates..", flush=True) 192 | writer = csv.writer(csvfile, delimiter=',') 193 | writer.writerow(columns) 194 | # Iterate through the prototypes and projection set 195 | img_iter = tqdm(enumerate(range(len(imgs))), total=len(imgs),mininterval=50.,ncols=0,desc='Collecting patch coordinates CUB') 196 | for _, imgid in img_iter: 197 | imgname = imgs[imgid][0] 198 | imgtensor = projectloader.dataset[imgid][0].unsqueeze(0) 199 | with torch.no_grad(): 200 | # Use the model to classify this input image 201 | pfs, pooled, _ = net(imgtensor) 202 | pooled = pooled.squeeze(0) 203 | pfs = pfs.squeeze(0) 204 | 205 | for prototype in range(net.module._num_prototypes): 206 | c_weight = torch.max(classification_weights[:,prototype]) 207 | if c_weight > 1e-5:#ignore prototypes that are not relevant to any class 208 | if pooled[prototype].item()>threshold: #similarity score > threshold 209 | location_h, location_h_idx = torch.max(pfs[prototype,:,:], dim=0) 210 | _, location_w_idx = torch.max(location_h, dim=0) 211 | h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, pfs.shape, patchsize, skip, location_h_idx[location_w_idx].item(), location_w_idx.item()) 212 | proto_img_coordinates.append([prototype, imgname, h_coor_min, h_coor_max, w_coor_min, w_coor_max]) 213 | 214 | writer.writerows(proto_img_coordinates) 215 | return csvfilepath 216 | 217 | # Writes coordinates of top-k image patches per prototype to csv file (image resized to 224x224) 218 | def get_topk_cub(net, projectloader, k, epoch, device, args): 219 | # Make sure the model is in evaluation mode 220 | net.eval() 221 | 222 | # Show progress on progress bar 223 | project_iter = tqdm(enumerate(projectloader), 224 | total=len(projectloader), 225 | desc='Collecting top-k Prototypes CUB parts', 226 | mininterval=50., 227 | ncols=0) 228 | imgs = projectloader.dataset.imgs 229 | classification_weights = net.module._classification.weight 230 | patchsize, skip = get_patch_size(args) 231 | scores_per_prototype = dict() 232 | 233 | # Iterate through the projection set 234 | for i, (xs, ys) in project_iter: 235 | xs, ys = xs.to(device), ys.to(device) 236 | 237 | with torch.no_grad(): 238 | # Use the model to classify this batch of input data 239 | pfs, pooled, _ = net(xs) 240 | pooled = pooled.squeeze(0) 241 | pfs = pfs.squeeze(0) 242 | for p in range(pooled.shape[0]): 243 | c_weight = torch.max(classification_weights[:,p]) 244 | if c_weight > 1e-5:#ignore prototypes that are not relevant to any class 245 | if p not in scores_per_prototype: 246 | scores_per_prototype[p] = [] 247 | scores_per_prototype[p].append((i, pooled[p].item())) 248 | 249 | proto_img_coordinates = [] 250 | csvfilepath = os.path.join(args.log_dir, str(epoch)+'_pipnet_prototypes_cub_topk.csv') 251 | too_small = set() 252 | protoype_iter = tqdm(enumerate(scores_per_prototype.keys()), total=len(list(scores_per_prototype.keys())),mininterval=5.,ncols=0,desc='Collecting top-k patch coordinates CUB') 253 | with open(csvfilepath, "w", newline='') as csvfile: 254 | print("Writing CSV file with top k image patches..", flush=True) 255 | writer = csv.writer(csvfile, delimiter=',') 256 | writer.writerow(["prototype", "img name", "h_min_224", "h_max_224", "w_min_224", "w_max_224"]) 257 | for _, prototype in protoype_iter: 258 | df = pd.DataFrame(scores_per_prototype[prototype], columns=['img_id', 'scores']) 259 | topk = df.nlargest(k, 'scores') 260 | for index, row in topk.iterrows(): 261 | imgid = int(row['img_id']) 262 | imgname 
= imgs[imgid][0] 263 | imgtensor = projectloader.dataset[imgid][0].unsqueeze(0) 264 | with torch.no_grad(): 265 | # Use the model to classify this batch of input data 266 | pfs, pooled, _ = net(imgtensor) 267 | pfs = pfs.squeeze(0) 268 | pooled = pooled.squeeze(0) 269 | if pooled[prototype].item() < 0.1: 270 | too_small.add(prototype) 271 | location_h, location_h_idx = torch.max(pfs[prototype,:,:], dim=0) 272 | _, location_w_idx = torch.max(location_h, dim=0) 273 | location = (location_h_idx[location_w_idx].item(), location_w_idx.item()) 274 | h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, pfs.shape, patchsize, skip, location[0], location[1]) 275 | proto_img_coordinates.append([prototype, imgname, h_coor_min, h_coor_max, w_coor_min, w_coor_max]) 276 | # write intermediate results in case of large dataset 277 | if len(proto_img_coordinates)> 10000: 278 | writer.writerows(proto_img_coordinates) 279 | proto_img_coordinates = [] 280 | print("Warning: image patches included in topk, but similarity < 0.1! This might unfairly reduce the purity metric because the prototype has fewer than k similar image patches. You could consider reducing k for prototypes", too_small, flush=True) 281 | 282 | writer.writerows(proto_img_coordinates) 283 | return csvfilepath 284 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pipnet.pipnet import PIPNet, get_network 2 | from util.log import Log 3 | import torch.nn as nn 4 | from util.args import get_args, save_args, get_optimizer_nn 5 | from util.data import get_dataloaders 6 | from util.func import init_weights_xavier 7 | from pipnet.train import train_pipnet 8 | from pipnet.test import eval_pipnet, get_thresholds, eval_ood 9 | from util.eval_cub_csv import eval_prototypes_cub_parts_csv, get_topk_cub, get_proto_patches_cub 10 | import torch 11 | from util.vis_pipnet import visualize, visualize_topk 12 | from util.visualize_prediction import vis_pred, vis_pred_experiments 13 | import sys, os 14 | import random 15 | import numpy as np 16 | from shutil import copy 17 | import matplotlib.pyplot as plt 18 | from copy import deepcopy 19 | 20 | def run_pipnet(args=None): 21 | 22 | args = args or get_args() 23 | torch.manual_seed(args.seed) 24 | torch.cuda.manual_seed_all(args.seed) 25 | random.seed(args.seed) 26 | np.random.seed(args.seed) 27 | 28 | assert args.batch_size > 1 29 | 30 | # Create a logger 31 | log = Log(args.log_dir) 32 | print("Log dir: ", args.log_dir, flush=True) 33 | # Log the run arguments 34 | save_args(args, log.metadata_dir) 35 | 36 | gpu_list = args.gpu_ids.split(',') 37 | device_ids = [] 38 | if args.gpu_ids!='': 39 | for m in range(len(gpu_list)): 40 | device_ids.append(int(gpu_list[m])) 41 | 42 | global device 43 | if not args.disable_cuda and torch.cuda.is_available(): 44 | if len(device_ids)==1: 45 | device = torch.device('cuda:{}'.format(args.gpu_ids)) 46 | elif len(device_ids)==0: 47 | device = torch.device('cuda') 48 | print("CUDA device set without id specification", flush=True) 49 | device_ids.append(torch.cuda.current_device()) 50 | else: 51 | print("This code should work with multiple GPUs but we didn't test that, so we recommend using only 1 GPU.", flush=True) 52 | device_str = '' 53 | for d in device_ids: 54 | device_str+=str(d) 55 | device_str+="," 56 | device = torch.device('cuda:'+str(device_ids[0])) 57 | else: 58 | device = torch.device('cpu') 59 | 60 | # Log which device was
actually used 61 | print("Device used: ", device, "with id", device_ids, flush=True) 62 | 63 | # Obtain the dataset and dataloaders 64 | trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes = get_dataloaders(args, device) 65 | if len(classes)<=20: 66 | if args.validation_size == 0.: 67 | print("Classes: ", testloader.dataset.class_to_idx, flush=True) 68 | else: 69 | print("Classes: ", str(classes), flush=True) 70 | 71 | # Create a convolutional network based on arguments and add 1x1 conv layer 72 | feature_net, add_on_layers, pool_layer, classification_layer, num_prototypes = get_network(len(classes), args) 73 | 74 | # Create a PIP-Net 75 | net = PIPNet(num_classes=len(classes), 76 | num_prototypes=num_prototypes, 77 | feature_net = feature_net, 78 | args = args, 79 | add_on_layers = add_on_layers, 80 | pool_layer = pool_layer, 81 | classification_layer = classification_layer 82 | ) 83 | net = net.to(device=device) 84 | net = nn.DataParallel(net, device_ids = device_ids) 85 | 86 | optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args) 87 | 88 | # Initialize or load model 89 | with torch.no_grad(): 90 | if args.state_dict_dir_net != '': 91 | epoch = 0 92 | checkpoint = torch.load(args.state_dict_dir_net,map_location=device) 93 | net.load_state_dict(checkpoint['model_state_dict'],strict=True) 94 | print("Pretrained network loaded", flush=True) 95 | net.module._multiplier.requires_grad = False 96 | try: 97 | optimizer_net.load_state_dict(checkpoint['optimizer_net_state_dict']) 98 | except: 99 | pass 100 | if torch.mean(net.module._classification.weight).item() > 1.0 and torch.mean(net.module._classification.weight).item() < 3.0 and torch.count_nonzero(torch.relu(net.module._classification.weight-1e-5)).float().item() > 0.8*(num_prototypes*len(classes)): #assume that the linear classification layer is not yet trained (e.g. when loading a pretrained backbone only) 101 | print("We assume that the classification layer is not yet trained. We re-initialize it...", flush=True) 102 | torch.nn.init.normal_(net.module._classification.weight, mean=1.0,std=0.1) 103 | torch.nn.init.constant_(net.module._multiplier, val=2.) 104 | print("Classification layer initialized with mean", torch.mean(net.module._classification.weight).item(), flush=True) 105 | if args.bias: 106 | torch.nn.init.constant_(net.module._classification.bias, val=0.) 107 | # else: #uncomment these lines if you want to load the optimizer too 108 | # if 'optimizer_classifier_state_dict' in checkpoint.keys(): 109 | # optimizer_classifier.load_state_dict(checkpoint['optimizer_classifier_state_dict']) 110 | 111 | else: 112 | net.module._add_on.apply(init_weights_xavier) 113 | torch.nn.init.normal_(net.module._classification.weight, mean=1.0,std=0.1) 114 | if args.bias: 115 | torch.nn.init.constant_(net.module._classification.bias, val=0.) 116 | torch.nn.init.constant_(net.module._multiplier, val=2.) 
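# Fresh-start initialisation: the 1x1 add-on layer gets Xavier weights, the non-negative classification layer is
# drawn around 1.0 (std 0.1) so every prototype initially carries some weight for every class, and the multiplier is
# pinned to 2 and, directly below, excluded from training. To take the checkpoint-loading branch above instead, a
# checkpoint saved later by this script can be passed back in, e.g. (illustrative path following the CUB log_dir,
# assuming the usual argparse flags matching the argument names):
#   python main.py --dataset CUB-200-2011 --state_dict_dir_net ./runs/pipnet_cub_cnext26/checkpoints/net_trained_last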
117 | net.module._multiplier.requires_grad = False 118 | 119 | print("Classification layer initialized with mean", torch.mean(net.module._classification.weight).item(), flush=True) 120 | 121 | # Define classification loss function and scheduler 122 | criterion = nn.NLLLoss(reduction='mean').to(device) 123 | scheduler_net = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, T_max=len(trainloader_pretraining)*args.epochs_pretrain, eta_min=args.lr_block/100., last_epoch=-1) 124 | 125 | # Forward one batch through the backbone to get the latent output size 126 | with torch.no_grad(): 127 | xs1, _, _ = next(iter(trainloader)) 128 | xs1 = xs1.to(device) 129 | proto_features, _, _ = net(xs1) 130 | wshape = proto_features.shape[-1] 131 | args.wshape = wshape #needed for calculating image patch size 132 | print("Output shape: ", proto_features.shape, flush=True) 133 | 134 | if net.module._num_classes == 2: 135 | # Create a csv log for storing the test accuracy, F1-score, mean train accuracy and mean loss for each epoch 136 | log.create_log('log_epoch_overview', 'epoch', 'test_top1_acc', 'test_f1', 'almost_sim_nonzeros', 'local_size_all_classes','almost_nonzeros_pooled', 'num_nonzero_prototypes', 'mean_train_acc', 'mean_train_loss_during_epoch') 137 | print("Your dataset only has two classes. Is the number of samples per class similar? If the data is imbalanced, we recommend to use the --weighted_loss flag to account for the imbalance.", flush=True) 138 | else: 139 | # Create a csv log for storing the test accuracy (top 1 and top 5), mean train accuracy and mean loss for each epoch 140 | log.create_log('log_epoch_overview', 'epoch', 'test_top1_acc', 'test_top5_acc', 'almost_sim_nonzeros', 'local_size_all_classes','almost_nonzeros_pooled', 'num_nonzero_prototypes', 'mean_train_acc', 'mean_train_loss_during_epoch') 141 | 142 | 143 | lrs_pretrain_net = [] 144 | # PRETRAINING PROTOTYPES PHASE 145 | for epoch in range(1, args.epochs_pretrain+1): 146 | for param in params_to_train: 147 | param.requires_grad = True 148 | for param in net.module._add_on.parameters(): 149 | param.requires_grad = True 150 | for param in net.module._classification.parameters(): 151 | param.requires_grad = False 152 | for param in params_to_freeze: 153 | param.requires_grad = True # can be set to False when you want to freeze more layers 154 | for param in params_backbone: 155 | param.requires_grad = False #can be set to True when you want to train whole backbone (e.g. 
if dataset is very different from ImageNet) 156 | 157 | print("\nPretrain Epoch", epoch, "with batch size", trainloader_pretraining.batch_size, flush=True) 158 | 159 | # Pretrain prototypes 160 | train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, device, pretrain=True, finetune=False) 161 | lrs_pretrain_net+=train_info['lrs_net'] 162 | plt.clf() 163 | plt.plot(lrs_pretrain_net) 164 | plt.savefig(os.path.join(args.log_dir,'lr_pretrain_net.png')) 165 | log.log_values('log_epoch_overview', epoch, "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", train_info['loss']) 166 | 167 | if args.state_dict_dir_net == '': 168 | net.eval() 169 | torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_pretrained')) 170 | net.train() 171 | with torch.no_grad(): 172 | if 'convnext' in args.net and args.epochs_pretrain > 0: 173 | topks = visualize_topk(net, projectloader, len(classes), device, 'visualised_pretrained_prototypes_topk', args) 174 | 175 | # SECOND TRAINING PHASE 176 | # re-initialize optimizers and schedulers for second training phase 177 | optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args) 178 | scheduler_net = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_net, T_max=len(trainloader)*args.epochs, eta_min=args.lr_net/100.) 179 | # scheduler for the classification layer is with restarts, such that the model can re-active zeroed-out prototypes. Hence an intuitive choice. 180 | if args.epochs<=30: 181 | scheduler_classifier = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_classifier, T_0=5, eta_min=0.001, T_mult=1, verbose=False) 182 | else: 183 | scheduler_classifier = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_classifier, T_0=10, eta_min=0.001, T_mult=1, verbose=False) 184 | for param in net.module.parameters(): 185 | param.requires_grad = False 186 | for param in net.module._classification.parameters(): 187 | param.requires_grad = True 188 | 189 | frozen = True 190 | lrs_net = [] 191 | lrs_classifier = [] 192 | 193 | for epoch in range(1, args.epochs + 1): 194 | epochs_to_finetune = 3 #during finetuning, only train classification layer and freeze rest. 
usually done for a few epochs (at least 1, more depends on size of dataset) 195 | if epoch <= epochs_to_finetune and (args.epochs_pretrain > 0 or args.state_dict_dir_net != ''): 196 | for param in net.module._add_on.parameters(): 197 | param.requires_grad = False 198 | for param in params_to_train: 199 | param.requires_grad = False 200 | for param in params_to_freeze: 201 | param.requires_grad = False 202 | for param in params_backbone: 203 | param.requires_grad = False 204 | finetune = True 205 | 206 | else: 207 | finetune=False 208 | if frozen: 209 | # unfreeze backbone 210 | if epoch>(args.freeze_epochs): 211 | for param in net.module._add_on.parameters(): 212 | param.requires_grad = True 213 | for param in params_to_freeze: 214 | param.requires_grad = True 215 | for param in params_to_train: 216 | param.requires_grad = True 217 | for param in params_backbone: 218 | param.requires_grad = True 219 | frozen = False 220 | # freeze first layers of backbone, train rest 221 | else: 222 | for param in params_to_freeze: 223 | param.requires_grad = True #Can be set to False if you want to train fewer layers of backbone 224 | for param in net.module._add_on.parameters(): 225 | param.requires_grad = True 226 | for param in params_to_train: 227 | param.requires_grad = True 228 | for param in params_backbone: 229 | param.requires_grad = False 230 | 231 | print("\n Epoch", epoch, "frozen:", frozen, flush=True) 232 | if (epoch==args.epochs or epoch%30==0) and args.epochs>1: 233 | # SET SMALL WEIGHTS TO ZERO 234 | with torch.no_grad(): 235 | torch.set_printoptions(profile="full") 236 | net.module._classification.weight.copy_(torch.clamp(net.module._classification.weight.data - 0.001, min=0.)) 237 | print("Classifier weights: ", net.module._classification.weight[net.module._classification.weight.nonzero(as_tuple=True)], (net.module._classification.weight[net.module._classification.weight.nonzero(as_tuple=True)]).shape, flush=True) 238 | if args.bias: 239 | print("Classifier bias: ", net.module._classification.bias, flush=True) 240 | torch.set_printoptions(profile="default") 241 | 242 | train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, device, pretrain=False, finetune=finetune) 243 | lrs_net+=train_info['lrs_net'] 244 | lrs_classifier+=train_info['lrs_class'] 245 | # Evaluate model 246 | eval_info = eval_pipnet(net, testloader, epoch, device, log) 247 | log.log_values('log_epoch_overview', epoch, eval_info['top1_accuracy'], eval_info['top5_accuracy'], eval_info['almost_sim_nonzeros'], eval_info['local_size_all_classes'], eval_info['almost_nonzeros'], eval_info['num non-zero prototypes'], train_info['train_accuracy'], train_info['loss']) 248 | 249 | with torch.no_grad(): 250 | net.eval() 251 | torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict(), 'optimizer_classifier_state_dict': optimizer_classifier.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_trained')) 252 | 253 | if epoch%30 == 0: 254 | net.eval() 255 | torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict(), 'optimizer_classifier_state_dict': optimizer_classifier.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_trained_%s'%str(epoch))) 256 | 257 | # save learning rate in figure 258 | plt.clf() 259 | plt.plot(lrs_net) 260 | plt.savefig(os.path.join(args.log_dir,'lr_net.png')) 261 | plt.clf() 262 | 
plt.plot(lrs_classifier) 263 | plt.savefig(os.path.join(args.log_dir,'lr_class.png')) 264 | 265 | net.eval() 266 | torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict(), 'optimizer_classifier_state_dict': optimizer_classifier.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_trained_last')) 267 | 268 | topks = visualize_topk(net, projectloader, len(classes), device, 'visualised_prototypes_topk', args) 269 | # set weights of prototypes that are never really found in projection set to 0 270 | set_to_zero = [] 271 | if topks: 272 | for prot in topks.keys(): 273 | found = False 274 | for (i_id, score) in topks[prot]: 275 | if score > 0.1: 276 | found = True 277 | if not found: 278 | torch.nn.init.zeros_(net.module._classification.weight[:,prot]) 279 | set_to_zero.append(prot) 280 | print("Weights of prototypes", set_to_zero, "are set to zero because it is never detected with similarity>0.1 in the training set", flush=True) 281 | eval_info = eval_pipnet(net, testloader, "notused"+str(args.epochs), device, log) 282 | log.log_values('log_epoch_overview', "notused"+str(args.epochs), eval_info['top1_accuracy'], eval_info['top5_accuracy'], eval_info['almost_sim_nonzeros'], eval_info['local_size_all_classes'], eval_info['almost_nonzeros'], eval_info['num non-zero prototypes'], "n.a.", "n.a.") 283 | 284 | print("classifier weights: ", net.module._classification.weight, flush=True) 285 | print("Classifier weights nonzero: ", net.module._classification.weight[net.module._classification.weight.nonzero(as_tuple=True)], (net.module._classification.weight[net.module._classification.weight.nonzero(as_tuple=True)]).shape, flush=True) 286 | print("Classifier bias: ", net.module._classification.bias, flush=True) 287 | # Print weights and relevant prototypes per class 288 | for c in range(net.module._classification.weight.shape[0]): 289 | relevant_ps = [] 290 | proto_weights = net.module._classification.weight[c,:] 291 | for p in range(net.module._classification.weight.shape[1]): 292 | if proto_weights[p]> 1e-3: 293 | relevant_ps.append((p, proto_weights[p].item())) 294 | if args.validation_size == 0.: 295 | print("Class", c, "(", list(testloader.dataset.class_to_idx.keys())[list(testloader.dataset.class_to_idx.values()).index(c)],"):","has", len(relevant_ps),"relevant prototypes: ", relevant_ps, flush=True) 296 | 297 | # Evaluate prototype purity 298 | if args.dataset == 'CUB-200-2011': 299 | projectset_img0_path = projectloader.dataset.samples[0][0] 300 | project_path = os.path.split(os.path.split(projectset_img0_path)[0])[0].split("dataset")[0] 301 | parts_loc_path = os.path.join(project_path, "parts/part_locs.txt") 302 | parts_name_path = os.path.join(project_path, "parts/parts.txt") 303 | imgs_id_path = os.path.join(project_path, "images.txt") 304 | cubthreshold = 0.5 305 | 306 | net.eval() 307 | print("\n\nEvaluating cub prototypes for training set", flush=True) 308 | csvfile_topk = get_topk_cub(net, projectloader, 10, 'train_'+str(epoch), device, args) 309 | eval_prototypes_cub_parts_csv(csvfile_topk, parts_loc_path, parts_name_path, imgs_id_path, 'train_topk_'+str(epoch), args, log) 310 | 311 | csvfile_all = get_proto_patches_cub(net, projectloader, 'train_all_'+str(epoch), device, args, threshold=cubthreshold) 312 | eval_prototypes_cub_parts_csv(csvfile_all, parts_loc_path, parts_name_path, imgs_id_path, 'train_all_thres'+str(cubthreshold)+'_'+str(epoch), args, log) 313 | 314 | print("\n\nEvaluating cub prototypes for test 
set", flush=True) 315 | csvfile_topk = get_topk_cub(net, test_projectloader, 10, 'test_'+str(epoch), device, args) 316 | eval_prototypes_cub_parts_csv(csvfile_topk, parts_loc_path, parts_name_path, imgs_id_path, 'test_topk_'+str(epoch), args, log) 317 | cubthreshold = 0.5 318 | csvfile_all = get_proto_patches_cub(net, test_projectloader, 'test_'+str(epoch), device, args, threshold=cubthreshold) 319 | eval_prototypes_cub_parts_csv(csvfile_all, parts_loc_path, parts_name_path, imgs_id_path, 'test_all_thres'+str(cubthreshold)+'_'+str(epoch), args, log) 320 | 321 | # visualize predictions 322 | visualize(net, projectloader, len(classes), device, 'visualised_prototypes', args) 323 | testset_img0_path = test_projectloader.dataset.samples[0][0] 324 | test_path = os.path.split(os.path.split(testset_img0_path)[0])[0] 325 | vis_pred(net, test_path, classes, device, args) 326 | if args.extra_test_image_folder != '': 327 | if os.path.exists(args.extra_test_image_folder): 328 | vis_pred_experiments(net, args.extra_test_image_folder, classes, device, args) 329 | 330 | 331 | # EVALUATE OOD DETECTION 332 | ood_datasets = ["CARS", "CUB-200-2011", "pets"] 333 | for percent in [95.]: 334 | print("\nOOD Evaluation for epoch", epoch,"with percent of", percent, flush=True) 335 | _, _, _, class_thresholds = get_thresholds(net, testloader, epoch, device, percent, log) 336 | print("Thresholds:", class_thresholds, flush=True) 337 | # Evaluate with in-distribution data 338 | id_fraction = eval_ood(net, testloader, epoch, device, class_thresholds) 339 | print("ID class threshold ID fraction (TPR) with percent",percent,":", id_fraction, flush=True) 340 | 341 | # Evaluate with out-of-distribution data 342 | for ood_dataset in ood_datasets: 343 | if ood_dataset != args.dataset: 344 | print("\n OOD dataset: ", ood_dataset,flush=True) 345 | ood_args = deepcopy(args) 346 | ood_args.dataset = ood_dataset 347 | _, _, _, _, _,ood_testloader, _, _ = get_dataloaders(ood_args, device) 348 | 349 | id_fraction = eval_ood(net, ood_testloader, epoch, device, class_thresholds) 350 | print(args.dataset, "- OOD", ood_dataset, "class threshold ID fraction (FPR) with percent",percent,":", id_fraction, flush=True) 351 | 352 | print("Done!", flush=True) 353 | 354 | if __name__ == '__main__': 355 | args = get_args() 356 | torch.manual_seed(args.seed) 357 | torch.cuda.manual_seed_all(args.seed) 358 | random.seed(args.seed) 359 | np.random.seed(args.seed) 360 | print_dir = os.path.join(args.log_dir,'out.txt') 361 | tqdm_dir = os.path.join(args.log_dir,'tqdm.txt') 362 | if not os.path.isdir(args.log_dir): 363 | os.mkdir(args.log_dir) 364 | 365 | sys.stdout.close() 366 | sys.stderr.close() 367 | sys.stdout = open(print_dir, 'w') 368 | sys.stderr = open(tqdm_dir, 'w') 369 | run_pipnet(args) 370 | 371 | sys.stdout.close() 372 | sys.stderr.close() 373 | -------------------------------------------------------------------------------- /util/data.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import argparse 4 | import torch 5 | import torch.optim 6 | import torch.utils.data 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | from typing import Tuple, Dict 10 | from torch import Tensor 11 | import random 12 | from sklearn.model_selection import train_test_split 13 | 14 | def get_data(args: argparse.Namespace): 15 | """ 16 | Load the proper dataset based on the parsed arguments 17 | """ 18 | torch.manual_seed(args.seed) 19 | random.seed(args.seed) 20 | 
np.random.seed(args.seed) 21 | if args.dataset =='CUB-200-2011': 22 | return get_birds(True, './data/CUB_200_2011/dataset/train_crop', './data/CUB_200_2011/dataset/train', './data/CUB_200_2011/dataset/test_crop', args.image_size, args.seed, args.validation_size, './data/CUB_200_2011/dataset/train', './data/CUB_200_2011/dataset/test_full') 23 | if args.dataset == 'pets': 24 | return get_pets(True, './data/PETS/dataset/train','./data/PETS/dataset/train','./data/PETS/dataset/test', args.image_size, args.seed, args.validation_size) 25 | if args.dataset == 'partimagenet': #use --validation_size of 0.2 26 | return get_partimagenet(True, './data/partimagenet/dataset/all', './data/partimagenet/dataset/all', None, args.image_size, args.seed, args.validation_size) 27 | if args.dataset == 'CARS': 28 | return get_cars(True, './data/cars/dataset/train', './data/cars/dataset/train', './data/cars/dataset/test', args.image_size, args.seed, args.validation_size) 29 | if args.dataset == 'grayscale_example': 30 | return get_grayscale(True, './data/train', './data/train', './data/test', args.image_size, args.seed, args.validation_size) 31 | raise Exception(f'Could not load data set, data set "{args.dataset}" not found!') 32 | 33 | def get_dataloaders(args: argparse.Namespace, device): 34 | """ 35 | Get data loaders 36 | """ 37 | # Obtain the dataset 38 | trainset, trainset_pretraining, trainset_normal, trainset_normal_augment, projectset, testset, testset_projection, classes, num_channels, train_indices, targets = get_data(args) 39 | 40 | # Determine if GPU should be used 41 | cuda = not args.disable_cuda and torch.cuda.is_available() 42 | to_shuffle = True 43 | sampler = None 44 | 45 | num_workers = args.num_workers 46 | 47 | if args.weighted_loss: 48 | if targets is None: 49 | raise ValueError("Weighted loss not implemented for this dataset. Targets should be restructured") 50 | # https://discuss.pytorch.org/t/dataloader-using-subsetrandomsampler-and-weightedrandomsampler-at-the-same-time/29907 51 | class_sample_count = torch.tensor([(targets[train_indices] == t).sum() for t in torch.unique(targets, sorted=True)]) 52 | weight = 1. 
/ class_sample_count.float() 53 | print("Weights for weighted sampler: ", weight, flush=True) 54 | samples_weight = torch.tensor([weight[t] for t in targets[train_indices]]) 55 | # Create sampler, dataset, loader 56 | sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight),replacement=True) 57 | to_shuffle = False 58 | 59 | pretrain_batchsize = args.batch_size_pretrain 60 | 61 | 62 | trainloader = torch.utils.data.DataLoader(trainset, 63 | batch_size=args.batch_size, 64 | shuffle=to_shuffle, 65 | sampler=sampler, 66 | pin_memory=cuda, 67 | num_workers=num_workers, 68 | worker_init_fn=np.random.seed(args.seed), 69 | drop_last=True 70 | ) 71 | if trainset_pretraining is not None: 72 | trainloader_pretraining = torch.utils.data.DataLoader(trainset_pretraining, 73 | batch_size=pretrain_batchsize, 74 | shuffle=to_shuffle, 75 | sampler=sampler, 76 | pin_memory=cuda, 77 | num_workers=num_workers, 78 | worker_init_fn=np.random.seed(args.seed), 79 | drop_last=True 80 | ) 81 | 82 | else: 83 | trainloader_pretraining = torch.utils.data.DataLoader(trainset, 84 | batch_size=pretrain_batchsize, 85 | shuffle=to_shuffle, 86 | sampler=sampler, 87 | pin_memory=cuda, 88 | num_workers=num_workers, 89 | worker_init_fn=np.random.seed(args.seed), 90 | drop_last=True 91 | ) 92 | 93 | trainloader_normal = torch.utils.data.DataLoader(trainset_normal, 94 | batch_size=args.batch_size, 95 | shuffle=to_shuffle, 96 | sampler=sampler, 97 | pin_memory=cuda, 98 | num_workers=num_workers, 99 | worker_init_fn=np.random.seed(args.seed), 100 | drop_last=True 101 | ) 102 | trainloader_normal_augment = torch.utils.data.DataLoader(trainset_normal_augment, 103 | batch_size=args.batch_size, 104 | shuffle=to_shuffle, 105 | sampler=sampler, 106 | pin_memory=cuda, 107 | num_workers=num_workers, 108 | worker_init_fn=np.random.seed(args.seed), 109 | drop_last=True 110 | ) 111 | 112 | projectloader = torch.utils.data.DataLoader(projectset, 113 | batch_size = 1, 114 | shuffle=False, 115 | pin_memory=cuda, 116 | num_workers=num_workers, 117 | worker_init_fn=np.random.seed(args.seed), 118 | drop_last=False 119 | ) 120 | testloader = torch.utils.data.DataLoader(testset, 121 | batch_size=args.batch_size, 122 | shuffle=True, 123 | pin_memory=cuda, 124 | num_workers=num_workers, 125 | worker_init_fn=np.random.seed(args.seed), 126 | drop_last=False 127 | ) 128 | test_projectloader = torch.utils.data.DataLoader(testset_projection, 129 | batch_size=1, 130 | shuffle=False, 131 | pin_memory=cuda, 132 | num_workers=num_workers, 133 | worker_init_fn=np.random.seed(args.seed), 134 | drop_last=False 135 | ) 136 | print("Num classes (k) = ", len(classes), classes[:5], "etc.", flush=True) 137 | return trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes 138 | 139 | def create_datasets(transform1, transform2, transform_no_augment, num_channels:int, train_dir:str, project_dir: str, test_dir:str, seed:int, validation_size:float, train_dir_pretrain = None, test_dir_projection = None, transform1p=None): 140 | 141 | trainvalset = torchvision.datasets.ImageFolder(train_dir) 142 | classes = trainvalset.classes 143 | targets = trainvalset.targets 144 | indices = list(range(len(trainvalset))) 145 | 146 | train_indices = indices 147 | 148 | if test_dir is None: 149 | if validation_size <= 0.: 150 | raise ValueError("There is no test set directory, so validation size should be > 0 such that training set can be split.") 151 | subset_targets = 
list(np.array(targets)[train_indices]) 152 | train_indices, test_indices = train_test_split(train_indices,test_size=validation_size,stratify=subset_targets, random_state=seed) 153 | testset = torch.utils.data.Subset(torchvision.datasets.ImageFolder(train_dir, transform=transform_no_augment), indices=test_indices) 154 | print("Samples in trainset:", len(indices), "of which",len(train_indices),"for training and ", len(test_indices),"for testing.", flush=True) 155 | else: 156 | testset = torchvision.datasets.ImageFolder(test_dir, transform=transform_no_augment) 157 | 158 | trainset = torch.utils.data.Subset(TwoAugSupervisedDataset(trainvalset, transform1=transform1, transform2=transform2), indices=train_indices) 159 | trainset_normal = torch.utils.data.Subset(torchvision.datasets.ImageFolder(train_dir, transform=transform_no_augment), indices=train_indices) 160 | trainset_normal_augment = torch.utils.data.Subset(torchvision.datasets.ImageFolder(train_dir, transform=transforms.Compose([transform1, transform2])), indices=train_indices) 161 | projectset = torchvision.datasets.ImageFolder(project_dir, transform=transform_no_augment) 162 | 163 | if test_dir_projection is not None: 164 | testset_projection = torchvision.datasets.ImageFolder(test_dir_projection, transform=transform_no_augment) 165 | else: 166 | testset_projection = testset 167 | if train_dir_pretrain is not None: 168 | trainvalset_pr = torchvision.datasets.ImageFolder(train_dir_pretrain) 169 | targets_pr = trainvalset_pr.targets 170 | indices_pr = list(range(len(trainvalset_pr))) 171 | train_indices_pr = indices_pr 172 | if test_dir is None: 173 | subset_targets_pr = list(np.array(targets_pr)[indices_pr]) 174 | train_indices_pr, test_indices_pr = train_test_split(indices_pr,test_size=validation_size,stratify=subset_targets_pr, random_state=seed) 175 | 176 | trainset_pretraining = torch.utils.data.Subset(TwoAugSupervisedDataset(trainvalset_pr, transform1=transform1p, transform2=transform2), indices=train_indices_pr) 177 | else: 178 | trainset_pretraining = None 179 | 180 | return trainset, trainset_pretraining, trainset_normal, trainset_normal_augment, projectset, testset, testset_projection, classes, num_channels, train_indices, torch.LongTensor(targets) 181 | 182 | def get_pets(augment:bool, train_dir:str, project_dir: str, test_dir:str, img_size: int, seed:int, validation_size:float): 183 | mean = (0.485, 0.456, 0.406) 184 | std = (0.229, 0.224, 0.225) 185 | normalize = transforms.Normalize(mean=mean,std=std) 186 | transform_no_augment = transforms.Compose([ 187 | transforms.Resize(size=(img_size, img_size)), 188 | transforms.ToTensor(), 189 | normalize 190 | ]) 191 | 192 | if augment: 193 | transform1 = transforms.Compose([ 194 | transforms.Resize(size=(img_size+48, img_size+48)), 195 | TrivialAugmentWideNoColor(), 196 | transforms.RandomHorizontalFlip(), 197 | transforms.RandomResizedCrop(img_size+8, scale=(0.95, 1.)) 198 | ]) 199 | 200 | transform2 = transforms.Compose([ 201 | TrivialAugmentWideNoShape(), 202 | transforms.RandomCrop(size=(img_size, img_size)), #includes crop 203 | transforms.ToTensor(), 204 | normalize 205 | ]) 206 | else: 207 | transform1 = transform_no_augment 208 | transform2 = transform_no_augment 209 | 210 | return create_datasets(transform1, transform2, transform_no_augment, 3, train_dir, project_dir, test_dir, seed, validation_size) 211 | 212 | def get_partimagenet(augment:bool, train_dir:str, project_dir: str, test_dir:str, img_size: int, seed:int, validation_size:float): 213 | # Validation size was set 
to 0.2, such that 80% of the data is used for training 214 | mean = (0.485, 0.456, 0.406) 215 | std = (0.229, 0.224, 0.225) 216 | normalize = transforms.Normalize(mean=mean,std=std) 217 | transform_no_augment = transforms.Compose([ 218 | transforms.Resize(size=(img_size, img_size)), 219 | transforms.ToTensor(), 220 | normalize 221 | ]) 222 | 223 | if augment: 224 | transform1 = transforms.Compose([ 225 | transforms.Resize(size=(img_size+48, img_size+48)), 226 | TrivialAugmentWideNoColor(), 227 | transforms.RandomHorizontalFlip(), 228 | transforms.RandomResizedCrop(img_size+8, scale=(0.95, 1.)) 229 | ]) 230 | transform2 = transforms.Compose([ 231 | TrivialAugmentWideNoShape(), 232 | transforms.RandomCrop(size=(img_size, img_size)), #includes crop 233 | transforms.ToTensor(), 234 | normalize 235 | ]) 236 | else: 237 | transform1 = transform_no_augment 238 | transform2 = transform_no_augment 239 | 240 | return create_datasets(transform1, transform2, transform_no_augment, 3, train_dir, project_dir, test_dir, seed, validation_size) 241 | 242 | def get_birds(augment: bool, train_dir:str, project_dir: str, test_dir:str, img_size: int, seed:int, validation_size:float, train_dir_pretrain = None, test_dir_projection = None): 243 | shape = (3, img_size, img_size) 244 | mean = (0.485, 0.456, 0.406) 245 | std = (0.229, 0.224, 0.225) 246 | normalize = transforms.Normalize(mean=mean,std=std) 247 | transform_no_augment = transforms.Compose([ 248 | transforms.Resize(size=(img_size, img_size)), 249 | transforms.ToTensor(), 250 | normalize 251 | ]) 252 | transform1p = None 253 | if augment: 254 | transform1 = transforms.Compose([ 255 | transforms.Resize(size=(img_size+8, img_size+8)), 256 | TrivialAugmentWideNoColor(), 257 | transforms.RandomHorizontalFlip(), 258 | transforms.RandomResizedCrop(img_size+4, scale=(0.95, 1.)) 259 | ]) 260 | transform1p = transforms.Compose([ 261 | transforms.Resize(size=(img_size+32, img_size+32)), #for pretraining, crop can be bigger since it doesn't matter when bird is not fully visible 262 | TrivialAugmentWideNoColor(), 263 | transforms.RandomHorizontalFlip(), 264 | transforms.RandomResizedCrop(img_size+4, scale=(0.95, 1.)) 265 | ]) 266 | transform2 = transforms.Compose([ 267 | TrivialAugmentWideNoShape(), 268 | transforms.RandomCrop(size=(img_size, img_size)), #includes crop 269 | transforms.ToTensor(), 270 | normalize 271 | ]) 272 | else: 273 | transform1 = transform_no_augment 274 | transform2 = transform_no_augment 275 | 276 | return create_datasets(transform1, transform2, transform_no_augment, 3, train_dir, project_dir, test_dir, seed, validation_size, train_dir_pretrain, test_dir_projection, transform1p) 277 | 278 | def get_cars(augment: bool, train_dir:str, project_dir: str, test_dir:str, img_size: int, seed:int, validation_size:float): 279 | shape = (3, img_size, img_size) 280 | mean = (0.485, 0.456, 0.406) 281 | std = (0.229, 0.224, 0.225) 282 | 283 | normalize = transforms.Normalize(mean=mean,std=std) 284 | transform_no_augment = transforms.Compose([ 285 | transforms.Resize(size=(img_size, img_size)), 286 | transforms.ToTensor(), 287 | normalize 288 | ]) 289 | 290 | if augment: 291 | transform1 = transforms.Compose([ 292 | transforms.Resize(size=(img_size+32, img_size+32)), 293 | TrivialAugmentWideNoColor(), 294 | transforms.RandomHorizontalFlip(), 295 | transforms.RandomResizedCrop(img_size+4, scale=(0.95, 1.)) 296 | ]) 297 | 298 | transform2 = transforms.Compose([ 299 | TrivialAugmentWideNoShapeWithColor(), 300 | transforms.RandomCrop(size=(img_size, img_size)), 
#includes crop 301 | transforms.ToTensor(), 302 | normalize 303 | ]) 304 | 305 | else: 306 | transform1 = transform_no_augment 307 | transform2 = transform_no_augment 308 | 309 | return create_datasets(transform1, transform2, transform_no_augment, 3, train_dir, project_dir, test_dir, seed, validation_size) 310 | 311 | def get_grayscale(augment:bool, train_dir:str, project_dir: str, test_dir:str, img_size: int, seed:int, validation_size:float, train_dir_pretrain = None): 312 | mean = (0.485, 0.456, 0.406) 313 | std = (0.229, 0.224, 0.225) 314 | normalize = transforms.Normalize(mean=mean,std=std) 315 | transform_no_augment = transforms.Compose([ 316 | transforms.Resize(size=(img_size, img_size)), 317 | transforms.Grayscale(3), #convert to grayscale with three channels 318 | transforms.ToTensor(), 319 | normalize 320 | ]) 321 | 322 | if augment: 323 | transform1 = transforms.Compose([ 324 | transforms.Resize(size=(img_size+32, img_size+32)), 325 | TrivialAugmentWideNoColor(), 326 | transforms.RandomHorizontalFlip(), 327 | transforms.RandomResizedCrop(224+8, scale=(0.95, 1.)) 328 | ]) 329 | transform2 = transforms.Compose([ 330 | TrivialAugmentWideNoShape(), 331 | transforms.RandomCrop(size=(img_size, img_size)), #includes crop 332 | transforms.Grayscale(3),#convert to grayscale with three channels 333 | transforms.ToTensor(), 334 | normalize 335 | ]) 336 | else: 337 | transform1 = transform_no_augment 338 | transform2 = transform_no_augment 339 | 340 | return create_datasets(transform1, transform2, transform_no_augment, 3, train_dir, project_dir, test_dir, seed, validation_size) 341 | 342 | class TwoAugSupervisedDataset(torch.utils.data.Dataset): 343 | r"""Returns two augmentation and no labels.""" 344 | def __init__(self, dataset, transform1, transform2): 345 | self.dataset = dataset 346 | self.classes = dataset.classes 347 | if type(dataset) == torchvision.datasets.folder.ImageFolder: 348 | self.imgs = dataset.imgs 349 | self.targets = dataset.targets 350 | else: 351 | self.targets = dataset._labels 352 | self.imgs = list(zip(dataset._image_files, dataset._labels)) 353 | self.transform1 = transform1 354 | self.transform2 = transform2 355 | 356 | 357 | def __getitem__(self, index): 358 | image, target = self.dataset[index] 359 | image = self.transform1(image) 360 | return self.transform2(image), self.transform2(image), target 361 | 362 | def __len__(self): 363 | return len(self.dataset) 364 | 365 | # function copied from https://pytorch.org/vision/stable/_modules/torchvision/transforms/autoaugment.html#TrivialAugmentWide (v0.12) and adapted 366 | class TrivialAugmentWideNoColor(transforms.TrivialAugmentWide): 367 | def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: 368 | return { 369 | "Identity": (torch.tensor(0.0), False), 370 | "ShearX": (torch.linspace(0.0, 0.5, num_bins), True), 371 | "ShearY": (torch.linspace(0.0, 0.5, num_bins), True), 372 | "TranslateX": (torch.linspace(0.0, 16.0, num_bins), True), 373 | "TranslateY": (torch.linspace(0.0, 16.0, num_bins), True), 374 | "Rotate": (torch.linspace(0.0, 60.0, num_bins), True), 375 | } 376 | 377 | class TrivialAugmentWideNoShapeWithColor(transforms.TrivialAugmentWide): 378 | def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: 379 | return { 380 | "Identity": (torch.tensor(0.0), False), 381 | "Brightness": (torch.linspace(0.0, 0.5, num_bins), True), 382 | "Color": (torch.linspace(0.0, 0.5, num_bins), True), 383 | "Contrast": (torch.linspace(0.0, 0.5, num_bins), True), 384 | 
"Sharpness": (torch.linspace(0.0, 0.5, num_bins), True), 385 | "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), 386 | "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), 387 | "AutoContrast": (torch.tensor(0.0), False), 388 | "Equalize": (torch.tensor(0.0), False), 389 | } 390 | 391 | class TrivialAugmentWideNoShape(transforms.TrivialAugmentWide): 392 | def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: 393 | return { 394 | 395 | "Identity": (torch.tensor(0.0), False), 396 | "Brightness": (torch.linspace(0.0, 0.5, num_bins), True), 397 | "Color": (torch.linspace(0.0, 0.02, num_bins), True), 398 | "Contrast": (torch.linspace(0.0, 0.5, num_bins), True), 399 | "Sharpness": (torch.linspace(0.0, 0.5, num_bins), True), 400 | "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), 401 | "AutoContrast": (torch.tensor(0.0), False), 402 | "Equalize": (torch.tensor(0.0), False), 403 | } 404 | 405 | --------------------------------------------------------------------------------