├── .gitignore ├── common_constants.py ├── dataset_helpers.py ├── environment.yml ├── experiment_logger.py ├── get_dataset.py ├── models.py ├── network_helpers.py ├── pirl_loss.py ├── pirl_stl_train_test.py ├── random_seed_setter.py ├── readme.md ├── stl10_data_load.py ├── submit_job.sh ├── train_stl_after_ssl.py └── train_test_helper.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | cifar_data/ 3 | __pycache__ 4 | .DS_Store 5 | weights/ 6 | parse_raw_output.py 7 | observations_dir/ 8 | e1_test_observations.csv 9 | raw_stl10/ 10 | stl10_data/ 11 | -------------------------------------------------------------------------------- /common_constants.py: -------------------------------------------------------------------------------- 1 | PAR_WEIGHTS_DIR = './weights' 2 | PAR_OBSERVATIONS_DIR = './observations_dir' 3 | PAR_ACTIVATIONS_DIR = './activations_dir' 4 | -------------------------------------------------------------------------------- /dataset_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchvision.transforms import transforms 4 | 5 | def_train_transform_stl = transforms.Compose([ 6 | transforms.ToTensor(), 7 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 8 | ]) 9 | 10 | hflip_data_transform = transforms.Compose([ 11 | transforms.RandomHorizontalFlip(p=1.0), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 14 | ]) 15 | 16 | darkness_jitter_transform = transforms.Compose([ 17 | transforms.ColorJitter(brightness=[0.5, 0.9]), 18 | transforms.ToTensor(), 19 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 20 | ]) 21 | 22 | lightness_jitter_transform = transforms.Compose([ 23 | transforms.ColorJitter(brightness=[1.1, 1.5]), 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 
0.2010)), 26 | ]) 27 | 28 | rotations_transform = transforms.Compose([ 29 | transforms.RandomRotation(degrees=25), 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 32 | ]) 33 | 34 | all_in_transform = transforms.Compose([ 35 | transforms.RandomHorizontalFlip(p=0.5), 36 | transforms.ColorJitter(brightness=[0.5, 1.5]), 37 | transforms.RandomRotation(degrees=25), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 40 | ]) 41 | 42 | def_test_transform = transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 45 | ]) 46 | 47 | pirl_full_img_transform = transforms.Compose([ 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 51 | ]) 52 | 53 | pirl_stl10_jigsaw_patch_transform = transforms.Compose([ 54 | transforms.RandomCrop(30, padding=1), 55 | transforms.ColorJitter(brightness=[0.5, 1.5]), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 58 | ]) 59 | 60 | 61 | def get_file_paths_n_labels(par_images_dir): 62 | """ 63 | Returns all file paths for images in a directory (par_images_dir). 
64 | The par_images_dir is supposed to only have sub directories with each sub-directory representing a label 65 | And in turn each sub-directory holding images pertaining to the label it represents 66 | """ 67 | 68 | label_names = [dir_name for dir_name in os.listdir(par_images_dir) 69 | if os.path.isdir(os.path.join(par_images_dir, dir_name))] 70 | label_dir_paths = [os.path.join(par_images_dir, dir_name) for dir_name in label_names] 71 | 72 | file_paths = [] 73 | labels = [] 74 | 75 | for label_name, label_dir in zip(label_names, label_dir_paths): 76 | file_names = [file_name for file_name in os.listdir(label_dir) if file_name[-4:]=='jpeg'] 77 | file_paths += [os.path.join(label_dir, file_name) for file_name in file_names] 78 | labels += [int(label_name) - 1] * len(file_names) 79 | 80 | return file_paths, labels 81 | 82 | 83 | def get_nine_crops(pil_image): 84 | """ 85 | Get nine crops for a square pillow image. That is height and width of the image should be same. 86 | :param pil_image: pillow image 87 | :return: List of pillow images. 
The nine crops 88 | """ 89 | w, h = pil_image.size 90 | diff = int(w/3) 91 | 92 | r_vals = [0, diff, 2 * diff] 93 | c_vals = [0, diff, 2 * diff] 94 | 95 | list_patches = [] 96 | 97 | for r in r_vals: 98 | for c in c_vals: 99 | 100 | left = c 101 | top = r 102 | right = c + diff 103 | bottom = r + diff 104 | 105 | patch = pil_image.crop((left, top, right, bottom)) 106 | list_patches.append(patch) 107 | 108 | return list_patches 109 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pirl-cifar 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - imageio=2.6.1 7 | - pandas=1.0.1 8 | - python=3.7 9 | - pillow=6.1 10 | - pip 11 | - pytorch=1.1 12 | - scipy=1.2.0 13 | - torchvision 14 | - opencv 15 | - matplotlib 16 | - jupyterlab 17 | - nb_conda_kernels 18 | - pip: 19 | - torchtext 20 | - torchviz 21 | -------------------------------------------------------------------------------- /experiment_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | from common_constants import PAR_OBSERVATIONS_DIR 5 | 6 | 7 | def log_experiment(exp_name, n_epochs, train_losses, val_losses, train_accs, val_accs): 8 | observations_df = pd.DataFrame() 9 | observations_df['epoch count'] = [i for i in range(1, n_epochs + 1)] 10 | observations_df['train loss'] = train_losses 11 | observations_df['val loss'] = val_losses 12 | observations_df['train acc'] = train_accs 13 | observations_df['val acc'] = val_accs 14 | observations_file_path = os.path.join(PAR_OBSERVATIONS_DIR, exp_name + '_observations.csv') 15 | observations_df.to_csv(observations_file_path) 16 | -------------------------------------------------------------------------------- /get_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | from torch.utils.data import Dataset 7 | 8 | from dataset_helpers import get_nine_crops, pirl_full_img_transform, pirl_stl10_jigsaw_patch_transform 9 | 10 | 11 | class GetSTL10Data(Dataset): 12 | 'Characterizes PyTorch Dataset object' 13 | def __init__(self, file_paths, labels, transform): 14 | 'Initialization' 15 | self.file_paths = file_paths 16 | self.labels = labels 17 | self.transform = transform 18 | 19 | def __len__(self): 20 | 'Denotes the total number of samples' 21 | return len(self.file_paths) 22 | 23 | def __getitem__(self, index): 24 | 'Generates one sample of data' 25 | 26 | # Select one file_path and convert to tensor object 27 | image = Image.open(self.file_paths[index]) 28 | image_tensor = self.transform(image) 29 | label = self.labels[index] 30 | 31 | return image_tensor, label 32 | 33 | 34 | class GetSTL10DataForPIRL(Dataset): 35 | 'Characterizes PyTorch Dataset object' 36 | def __init__(self, file_paths): 37 | 'Initialization' 38 | self.file_paths = file_paths 39 | 40 | def __len__(self): 41 | 'Denotes the total number of samples' 42 | return len(self.file_paths) 43 | 44 | def __getitem__(self, index): 45 | 'Generates one sample of data' 46 | 47 | # Select one file_path and convert to tensor object 48 | image = Image.open(self.file_paths[index]) 49 | image_tensor = pirl_full_img_transform(image) 50 | 51 | # Get nine crops for the image 52 | nine_crops = get_nine_crops(image) 53 | 54 | # Form the jigsaw order for this image 55 | original_order = np.arange(9) 56 | permuted_order = np.copy(original_order) 57 | np.random.shuffle(permuted_order) 58 | 59 | # Permut the 9 patches obtained from the image 60 | permuted_patches_arr = [None] * 9 61 | for patch_pos, patch in zip(permuted_order, nine_crops): 62 | permuted_patches_arr[patch_pos] = patch 63 | 64 | # Apply data transforms 65 | # TODO: Remove hard coded values from here 66 | tensor_patches = torch.zeros(9, 3, 30, 30) 67 | for 
ind, patch in enumerate(permuted_patches_arr): 68 | patch_tensor = pirl_stl10_jigsaw_patch_transform(patch) 69 | tensor_patches[ind] = patch_tensor 70 | 71 | return [image_tensor, tensor_patches], index 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | # Lets test the GetSTL10DataForPIRL class 78 | print("Testing for GetSTL10DataForPIRL") 79 | base_images_dir = 'stl10_data/unlabelled' 80 | file_names_list = os.listdir(base_images_dir) 81 | file_names_list = [file_name for file_name in file_names_list if file_name[-4:] == 'jpeg'] 82 | 83 | file_paths_list = [os.path.join(base_images_dir, file_name) for file_name in file_names_list] 84 | ssl_dataset = GetSTL10DataForPIRL(file_paths_list) 85 | ssl_loader = torch.utils.data.DataLoader(ssl_dataset, batch_size=128, num_workers=8) 86 | 87 | dataiter = iter(ssl_loader) 88 | X, y = dataiter.__next__() 89 | print(X[0].size()) 90 | print(X[1].size()) 91 | print(y.size()) 92 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet18, resnet34, resnet50 5 | 6 | 7 | class ClassificationResNet(nn.Module): 8 | 9 | def __init__(self, resnet_module, num_classes): 10 | super(ClassificationResNet, self).__init__() 11 | self.resnet_module = resnet_module 12 | self.fc = nn.Linear(512, num_classes) 13 | 14 | def forward(self, input_batch): 15 | 16 | # Data returned by data loaders is of the shape (batch_size, no_channels, h_patch, w_patch) 17 | resnet_feat_vectors = self.resnet_module(input_batch) 18 | final_feat_vectors = torch.flatten(resnet_feat_vectors, 1) 19 | x = F.log_softmax(self.fc(final_feat_vectors)) 20 | 21 | return x 22 | 23 | 24 | def get_base_resnet_module(model_type): 25 | """ 26 | Returns the backbone network for required resnet architecture, specified as model_type 27 | :param 
model_type: Can be either of {res18, res34, res50} 28 | """ 29 | 30 | if model_type == 'res18': 31 | original_model = resnet18(pretrained=False) 32 | elif model_type == 'res34': 33 | original_model = resnet34(pretrained=False) 34 | else: 35 | original_model = resnet50(pretrained=False) 36 | base_resnet_module = nn.Sequential(*list(original_model.children())[:-1]) 37 | 38 | return base_resnet_module 39 | 40 | 41 | def classifier_resnet(model_type, num_classes): 42 | """ 43 | Returns a classification network with backbone belonging to the family of ResNets 44 | :param model_type: Specifies which resnet network to employ. Can be one of {res18, res34, res50} 45 | :param num_classes: The number of classes that the final network classifies it inputs into. 46 | """ 47 | 48 | base_resnet_module = get_base_resnet_module(model_type) 49 | 50 | return ClassificationResNet(base_resnet_module, num_classes) 51 | 52 | 53 | class PIRLResnet(nn.Module): 54 | def __init__(self, resnet_module, non_linear_head=False): 55 | super(PIRLResnet, self).__init__() 56 | self.resnet_module = resnet_module 57 | self.lin_project_1 = nn.Linear(512, 128) 58 | self.lin_project_2 = nn.Linear(128 * 9, 128) 59 | if non_linear_head: 60 | self.lin_project_3 = nn.Linear(128, 128) # Will only be used if non_linear_head is True 61 | self.non_linear_head = non_linear_head 62 | 63 | def forward(self, i_batch, i_t_patches_batch): 64 | """ 65 | :param i_batch: Batch of images 66 | :param i_t_patches_batch: Batch of transformed image patches (jigsaw transformation) 67 | """ 68 | 69 | # Run I and I_t through resnet 70 | vi_batch = self.resnet_module(i_batch) 71 | vi_batch = torch.flatten(vi_batch, 1) 72 | vi_t_patches_batch = [self.resnet_module(i_t_patches_batch[:, patch_ind, :, :, :]) 73 | for patch_ind in range(9)] 74 | vi_t_patches_batch = [torch.flatten(vi_t_patches_batch[patch_ind], 1) 75 | for patch_ind in range(9)] 76 | 77 | # Run resnet features for I and I_t via lin_project_1 layer 78 | vi_batch = 
self.lin_project_1(vi_batch) 79 | vi_t_patches_batch = [self.lin_project_1(vi_t_patches_batch[patch_ind]) 80 | for patch_ind in range(9)] 81 | 82 | # Concatenate together lin_project_1 outputs for patches of I_t 83 | vi_t_patches_concatenated = torch.cat(vi_t_patches_batch, 1) 84 | 85 | # Run concatenated feature vector for I_t through lin_project_2 layer 86 | vi_t_batch = self.lin_project_2(vi_t_patches_concatenated) 87 | 88 | # Run final feature vectors obtained for I and I_t through non-linearity (if specified) 89 | if self.non_linear_head: 90 | vi_batch = self.lin_project_3(F.relu(vi_batch)) 91 | vi_t_batch = self.lin_project_3(F.relu(vi_t_batch)) 92 | 93 | return vi_batch, vi_t_batch 94 | 95 | 96 | def pirl_resnet(model_type, non_linear_head=False): 97 | """ 98 | Returns a network which supports Pre-text invariant representation learning 99 | with backbone belonging to the family of ResNets 100 | :param model_type: Specifies which resnet network to employ. Can be one of {res18, res34, res50} 101 | :param non_linear_head: If true apply non-linearity to the output of function heads 102 | applied to resnet image representations 103 | """ 104 | 105 | base_resnet_module = get_base_resnet_module(model_type) 106 | 107 | return PIRLResnet(base_resnet_module, non_linear_head) 108 | 109 | 110 | if __name__ == '__main__': 111 | pr = pirl_resnet('res18', non_linear_head=True) # non_linear_head can be True or False either. 
112 | image_batch = torch.randn(32, 3, 64, 64) 113 | tr_img_patch_batch = torch.randn(32, 9, 3, 32, 32) 114 | 115 | result1, result2 = pr.forward(image_batch, tr_img_patch_batch) 116 | 117 | print (result1.size()) 118 | print (result2.size()) 119 | -------------------------------------------------------------------------------- /network_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models import pirl_resnet, classifier_resnet 4 | 5 | 6 | def test_copy_weights(m1, m2): 7 | """ 8 | Tests that weights copied from m1 into m2, are actually refected in m2 9 | """ 10 | m1_state_dict = m1.state_dict() 11 | m2_state_dict = m2.state_dict() 12 | weight_copy_flag = 1 13 | for name, param in m1_state_dict.items(): 14 | if name in m2_state_dict: 15 | if not torch.all(torch.eq(param.data, m2_state_dict[name].data)): 16 | print("Something is incorrect for layer {} in 2nd model", name) 17 | weight_copy_flag = 0 18 | 19 | if weight_copy_flag: 20 | print('All is well') 21 | 22 | return 1 23 | 24 | 25 | def copy_weights_between_models(m1, m2): 26 | """ 27 | Copy weights for layers common between m1 and m2. 
28 | From m1 => m2 29 | """ 30 | 31 | # Load state dictionaries for m1 model and m2 model 32 | m1_state_dict = m1.state_dict() 33 | m2_state_dict = m2.state_dict() 34 | 35 | # Set the m2 model's weights with trained m1 model weights 36 | for name, param in m1_state_dict.items(): 37 | if name not in m2_state_dict: 38 | continue 39 | else: 40 | m2_state_dict[name] = param.data 41 | m2.load_state_dict(m2_state_dict) 42 | 43 | # Test that model m2 **really** has got updated weights 44 | return test_copy_weights(m1, m2) 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | pr = pirl_resnet('res18') 50 | cr = classifier_resnet('res18', num_classes=10) 51 | 52 | copy_success = copy_weights_between_models(pr, cr) 53 | 54 | -------------------------------------------------------------------------------- /pirl_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, temp_parameter): 6 | """ 7 | Returns the probability that feature representation for image I and I_t belong to same distribution. 8 | :param vi_batch: Feature representation for batch of images I 9 | :param vi_t_batch: Feature representation for batch containing transformed versions of I. 
10 | :param mn_arr: Memory bank of feature representations for negative images for current batch 11 | :param temp_parameter: The temperature parameter 12 | """ 13 | 14 | # Define constant eps to ensure training is not impacted if norm of any image rep is zero 15 | eps = 1e-6 16 | 17 | # L2 normalize vi, vi_t and memory bank representations 18 | vi_norm_arr = torch.norm(vi_batch, dim=1, keepdim=True) 19 | vi_t_norm_arr = torch.norm(vi_t_batch, dim=1, keepdim=True) 20 | mn_norm_arr = torch.norm(mn_arr, dim=1, keepdim=True) 21 | 22 | vi_batch = vi_batch / (vi_norm_arr + eps) 23 | vi_t_batch = vi_t_batch/ (vi_t_norm_arr + eps) 24 | mn_arr = mn_arr / (mn_norm_arr + eps) 25 | 26 | # Find cosine similarities 27 | sim_vi_vi_t_arr = (vi_batch @ vi_t_batch.t()).diagonal() 28 | sim_vi_t_mn_mat = (vi_t_batch @ mn_arr.t()) 29 | 30 | # Fine exponentiation of similarity arrays 31 | exp_sim_vi_vi_t_arr = torch.exp(sim_vi_vi_t_arr / temp_parameter) 32 | exp_sim_vi_t_mn_mat = torch.exp(sim_vi_t_mn_mat / temp_parameter) 33 | 34 | # Sum exponential similarities of I_t with different images from memory bank of negatives 35 | sum_exp_sim_vi_t_mn_arr = torch.sum(exp_sim_vi_t_mn_mat, 1) 36 | 37 | # Find batch probabilities arr 38 | batch_prob_arr = exp_sim_vi_vi_t_arr / (exp_sim_vi_vi_t_arr + sum_exp_sim_vi_t_mn_arr + eps) 39 | 40 | return batch_prob_arr 41 | 42 | 43 | def loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr): 44 | """ 45 | Returns the average of [-log(prob(img_pair_probs_arr)) - log(prob(img_mem_rep_probs_arr))] 46 | :param img_pair_probs_arr: Prob vector of batch of images I and I_t to belong to same data distribution. 
47 | :param img_mem_rep_probs_arr: Prob vector of batch of I and mem_bank_rep of I to belong to same data distribution 48 | """ 49 | 50 | # Get 1st term of loss 51 | neg_log_img_pair_probs = -1 * torch.log(img_pair_probs_arr) 52 | loss_i_i_t = torch.sum(neg_log_img_pair_probs) / neg_log_img_pair_probs.size()[0] 53 | 54 | # Get 2nd term of loss 55 | neg_log_img_mem_rep_probs_arr = -1 * torch.log(img_mem_rep_probs_arr) 56 | loss_i_mem_i = torch.sum(neg_log_img_mem_rep_probs_arr) / neg_log_img_mem_rep_probs_arr.size()[0] 57 | 58 | loss = (loss_i_i_t + loss_i_mem_i) / 2 59 | 60 | return loss 61 | 62 | 63 | if __name__ == '__main__': 64 | # Test get_img_pair_probs function 65 | vi_batch = torch.randn(256, 128) 66 | vi_t_batch = torch.randn(256, 128) 67 | mn_arr = torch.randn(6400, 128) 68 | mem_rep_of_batch_imgs = torch.randn(256, 128) 69 | temp_parameter = 1.5 70 | 71 | # Prob vector between I and I_t 72 | img_pair_probs_arr = get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, temp_parameter) 73 | print (img_pair_probs_arr.shape) 74 | 75 | # Prob vector between I and mem bank representation of I 76 | img_mem_rep_probs_arr = get_img_pair_probs(vi_batch, mem_rep_of_batch_imgs, mn_arr, temp_parameter) 77 | print (img_mem_rep_probs_arr.shape) 78 | 79 | # Final loss 80 | loss_val = loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr) 81 | 82 | print (loss_val) 83 | 84 | 85 | -------------------------------------------------------------------------------- /pirl_stl_train_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torchvision 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from torch import optim 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | from torch.utils.data import SubsetRandomSampler 13 | 14 | from common_constants import PAR_WEIGHTS_DIR 15 | from experiment_logger import log_experiment 16 | from get_dataset import 
GetSTL10DataForPIRL 17 | from models import pirl_resnet 18 | from random_seed_setter import set_random_generators_seed 19 | from train_test_helper import PIRLModelTrainTest 20 | 21 | 22 | def unpickle(file): 23 | import pickle 24 | with open(file, 'rb') as fo: 25 | dict = pickle.load(fo, encoding='bytes') 26 | return dict 27 | 28 | 29 | if __name__ == '__main__': 30 | 31 | # Training arguments 32 | parser = argparse.ArgumentParser(description='STL10 Train test script for PIRL task') 33 | parser.add_argument('--model-type', type=str, default='res18', help='The network architecture to employ as backbone') 34 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 35 | help='input batch size for training (default: 128)') 36 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 37 | help='number of epochs to train (default: 100)') 38 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 39 | help='learning rate (default: 0.01)') 40 | parser.add_argument('--weight-decay', type=float, default=5e-4, 41 | help='Weight decay constant (default: 5e-4)') 42 | parser.add_argument('--tmax-for-cos-decay', type=int, default=70) 43 | parser.add_argument('--warm-start', type=bool, default=False) 44 | parser.add_argument('--count-negatives', type=int, default=6400, 45 | help='No of samples in memory bank of negatives') 46 | parser.add_argument('--beta', type=float, default=0.5, help='Exponential running average constant' 47 | 'in memory bank update') 48 | parser.add_argument('--only-train', type=bool, default=False, 49 | help='If true utilize the entire unannotated STL10 dataset for training.') 50 | parser.add_argument('--non-linear-head', type=bool, default=False, 51 | help='If true apply non-linearity to the output of function heads ' 52 | 'applied to resnet image representations') 53 | parser.add_argument('--temp-parameter', type=float, default=0.07, help='Temperature parameter in NCE probability') 54 | 
parser.add_argument('--cont-epoch', type=int, default=1, help='Epoch to start the training from, helpful when using' 55 | 'warm start') 56 | parser.add_argument('--experiment-name', type=str, default='e1_resnet18_') 57 | args = parser.parse_args() 58 | 59 | # Set random number generation seed for all packages that generate random numbers 60 | set_random_generators_seed() 61 | 62 | # Identify device for holding tensors and carrying out computations 63 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 64 | 65 | # Define the file_path where trained model will be saved 66 | model_file_path = os.path.join(PAR_WEIGHTS_DIR, args.experiment_name + '_epoch_100') 67 | 68 | # Get train_val image file_paths 69 | base_images_dir = 'stl10_data/unlabelled' 70 | file_names_list = os.listdir(base_images_dir) 71 | file_names_list = [file_name for file_name in file_names_list if file_name[-4:] == 'jpeg'] 72 | file_paths_list = [os.path.join(base_images_dir, file_name) for file_name in file_names_list] 73 | 74 | # Define train_set, val_set objects 75 | train_set = GetSTL10DataForPIRL(file_paths_list) 76 | val_set = GetSTL10DataForPIRL(file_paths_list) 77 | 78 | # Define train and validation data loaders 79 | len_train_val_set = len(train_set) 80 | train_val_indices = list(range(len_train_val_set)) 81 | np.random.shuffle(train_val_indices) 82 | 83 | if args.only_train is False: 84 | count_train = 70000 85 | else: 86 | count_train = 100000 87 | 88 | train_indices = train_val_indices[:count_train] 89 | val_indices = train_val_indices[count_train:] 90 | 91 | train_sampler = SubsetRandomSampler(train_indices) 92 | val_sampler = SubsetRandomSampler(val_indices) 93 | 94 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, sampler=train_sampler, 95 | num_workers=8) 96 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, sampler=val_sampler, 97 | num_workers=8) 98 | 99 | # Print sample batches that would be 
returned by the train_data_loader 100 | dataiter = iter(train_loader) 101 | X, y = dataiter.__next__() 102 | print (X[0].size()) 103 | print (X[1].size()) 104 | print (y.size()) 105 | 106 | # Train required model using data loaders defined above 107 | epochs = args.epochs 108 | lr = args.lr 109 | weight_decay_const = args.weight_decay 110 | 111 | # If using Resnet18 112 | model_to_train = pirl_resnet(args.model_type, args.non_linear_head) 113 | 114 | # Set device on which training is done. Plus optimizer to use. 115 | model_to_train.to(device) 116 | sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay_const) 117 | scheduler = CosineAnnealingLR(sgd_optimizer, args.tmax_for_cos_decay, eta_min=1e-4, last_epoch=-1) 118 | 119 | # Initialize model weights with a previously trained model if using warm start 120 | if args.warm_start and os.path.exists(model_file_path): 121 | model_to_train.load_state_dict(torch.load(model_file_path, map_location=device)) 122 | 123 | # Start training 124 | all_images_mem = np.random.randn(len_train_val_set, 128) 125 | model_train_test_obj = PIRLModelTrainTest( 126 | model_to_train, device, model_file_path, all_images_mem, train_indices, val_indices, args.count_negatives, 127 | args.temp_parameter, args.beta, args.only_train 128 | ) 129 | train_losses, val_losses, train_accs, val_accs = [], [], [], [] 130 | for epoch_no in range(args.cont_epoch, args.cont_epoch + epochs): 131 | train_loss, train_acc, val_loss, val_acc = model_train_test_obj.train( 132 | sgd_optimizer, epoch_no, params_max_norm=4, 133 | train_data_loader=train_loader, val_data_loader=val_loader, 134 | no_train_samples=len(train_indices), no_val_samples=len(val_indices) 135 | ) 136 | train_losses.append(train_loss) 137 | val_losses.append(val_loss) 138 | train_accs.append(train_acc) 139 | val_accs.append(val_acc) 140 | scheduler.step() 141 | 142 | # Log train-test results 143 | log_experiment(args.experiment_name, args.epochs, 
train_losses, val_losses, train_accs, val_accs) 144 | -------------------------------------------------------------------------------- /random_seed_setter.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def set_random_generators_seed(): 7 | 8 | # Set up random seed to 1008. Do not change the random seed. 9 | # Yes, these are all necessary when you run experiments! 10 | seed = 1008 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | cuda = torch.cuda.is_available() 15 | if cuda: 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | torch.backends.cudnn.benchmark = False 19 | torch.backends.cudnn.deterministic = True 20 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Pretext-Invariant Representation Learning 2 | This repository contains a PyTorch implementation of the Pretext-Invariant Representation Learning (PIRL) 3 | algorithm on the STL10 dataset. PIRL was originally introduced by Misra et al.; the publication can be found [here](https://arxiv.org/abs/1912.01991). 4 | 5 | ## What is PIRL and why is it useful 6 | Pretext-Invariant Representation Learning (PIRL) is a self-supervised learning algorithm that exploits contrastive 7 | learning to learn visual representations such that the original and transformed versions of the same image have similar 8 | representations, while being different from those of other images, thus achieving invariance to the transformation. 9 | 10 | In their paper, the authors primarily focus on the jigsaw puzzle transformation.
11 | 12 | ## Loss Function and slight modification 13 | The CNN used for representation learning is trained using the NCE (Noise Contrastive Estimation) technique. 14 | NCE models the probability of the event that (I, I_t) (original and transformed image) originate from the same 15 | data distribution, i.e. 16 | ![alt text](https://docs.google.com/drawings/d/e/2PACX-1vQIBzisD1g6le_VQlfj7oeJVr98inlrBsvTzssW35MO1nxilwXa2MhkUukLli1U1Orb50_kC_XY3XCL/pub?w=480&h=96 "probability function") 17 |
18 | Where s(., .) is the cosine similarity between v_i and v_i_t, the deep representations of the original and transformed image, respectively. 19 | The final NCE loss is given as: 20 | ![alt text](https://docs.google.com/drawings/d/e/2PACX-1vRh2RjlYsPaSyGDORVN3zDl3sZ1r1g48jxW-fT8ajrGFx1rbHqyRnlepbZ63wr1K0oOCfjfndUhKA4S/pub?w=960&h=720 "L_nce") 21 | where f(.) and g(.) are linear function heads. 22 | 23 | ## Slight Modification 24 | Instead of using the NCE loss, this implementation directly minimizes 25 | the negative log of the probability described in the first equation above (with f(v_i) and g(v_i_t) as inputs). 26 | 27 | ## Dataset Used 28 | The implementation uses the STL10 dataset, which can be downloaded from [here](http://ai.stanford.edu/~acoates/stl10/). 29 | #### Dataset setup steps 30 | ``` 31 | 1. Download the raw data from the link above to ./raw_stl10/ 32 | 2. Run stl10_data_load.py. This will create three directories (train, test and unlabelled) in ./stl10_data/ 33 | ``` 34 | 35 | ## Training and evaluation steps 36 | 1. Run the script pirl_stl_train_test.py for self-supervised (unsupervised) pre-training, for example: 37 | ``` 38 | python pirl_stl_train_test.py --model-type res18 --batch-size 128 --lr 0.1 --experiment-name exp 39 | ``` 40 | 2. Run the script train_stl_after_ssl.py to fine-tune the model parameters obtained from self-supervised learning, for example: 41 | ``` 42 | python train_stl_after_ssl.py --model-type res18 --batch-size 128 --lr 0.1 --patience-for-lr-decay 4 --full-fine-tune True --pirl-model-name <pirl_model_name> 43 | ``` 44 | 45 | ## Results 46 | After training the CNN model in the PIRL manner, the following experiments were performed to evaluate how well the learnt model weights transfer to a classification 47 | problem in a limited-data scenario. 48 | 49 | Fine tuning strategy | Val Classification Accuracy 50 | --- | --- 51 | Only softmax layer is fine tuned | 50.50 52 | Full model is fine tuned | 67.87 53 | 54 | # References 55 | 1.
PIRL paper: https://arxiv.org/abs/1912.01991 56 | 2. STL 10 dataset: http://ai.stanford.edu/~acoates/stl10/ 57 | 3. Data loading code for STL 10: https://github.com/mttk/STL10 58 | -------------------------------------------------------------------------------- /stl10_data_load.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script obtained from: https://github.com/mttk/STL10/blob/master/stl10_input.py 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import os, sys, tarfile, errno 8 | import numpy as np 9 | from PIL import Image 10 | 11 | if sys.version_info >= (3, 0, 0): 12 | import urllib.request as urllib # ugly but works 13 | else: 14 | import urllib 15 | 16 | print(sys.version_info) 17 | 18 | # image shape 19 | HEIGHT = 96 20 | WIDTH = 96 21 | DEPTH = 3 22 | 23 | # size of a single image in bytes 24 | SIZE = HEIGHT * WIDTH * DEPTH 25 | 26 | # path to the directory with the data 27 | DATA_DIR = './raw_stl10/' 28 | STL10_TRAIN_IMG_DIR = './stl10_data/train/' 29 | STL10_TEST_IMG_DIR = './stl10_data/test/' 30 | STL10_UNLABELLED_IMG_DIR = './stl10_data/unlabelled/' 31 | 32 | # url of the binary data 33 | DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' 34 | 35 | # path to the binary files with image data 36 | TRAIN_DATA_PATH = os.path.join(DATA_DIR ,'stl10_binary/train_X.bin') 37 | TEST_DATA_PATH = os.path.join(DATA_DIR , 'stl10_binary/test_X.bin') 38 | UNLABELLED_DATA_PATH = os.path.join(DATA_DIR, 'stl10_binary/unlabeled_X.bin') 39 | 40 | # path to the binary files with labels 41 | TRAIN_LABEL_PATH = os.path.join(DATA_DIR, 'stl10_binary/train_y.bin') 42 | TEST_LABEL_PATH = os.path.join(DATA_DIR, 'stl10_binary/test_y.bin') 43 | 44 | 45 | def read_labels(path_to_labels): 46 | """ 47 | :param path_to_labels: path to the binary file containing labels from the STL-10 dataset 48 | :return: an array containing the labels 49 | """ 50 | with open(path_to_labels, 'rb') as f: 51 | labels = 
np.fromfile(f, dtype=np.uint8) 52 | return labels 53 | 54 | 55 | def read_all_images(path_to_data): 56 | """ 57 | :param path_to_data: the file containing the binary images from the STL-10 dataset 58 | :return: an array containing all the images 59 | """ 60 | 61 | with open(path_to_data, 'rb') as f: 62 | # read whole file in uint8 chunks 63 | everything = np.fromfile(f, dtype=np.uint8) 64 | 65 | # We force the data into 3x96x96 chunks, since the 66 | # images are stored in "column-major order", meaning 67 | # that "the first 96*96 values are the red channel, 68 | # the next 96*96 are green, and the last are blue." 69 | # The -1 is since the size of the pictures depends 70 | # on the input file, and this way numpy determines 71 | # the size on its own. 72 | 73 | images = np.reshape(everything, (-1, 3, 96, 96)) 74 | 75 | # Now transpose the images into a standard image format 76 | # readable by, for example, matplotlib.imshow 77 | # You might want to comment this line or reverse the shuffle 78 | # if you will use a learning algorithm like CNN, since they like 79 | # their channels separated. 
80 | images = np.transpose(images, (0, 3, 2, 1)) 81 | return images 82 | 83 | 84 | def save_image(image, name): 85 | image = Image.fromarray(image) 86 | image.save(name + '.jpeg') 87 | 88 | 89 | def download_and_extract(): 90 | """ 91 | Download and extract the STL-10 dataset 92 | :return: None 93 | """ 94 | dest_directory = DATA_DIR 95 | if not os.path.exists(dest_directory): 96 | os.makedirs(dest_directory) 97 | filename = DATA_URL.split('/')[-1] 98 | filepath = os.path.join(dest_directory, filename) 99 | if not os.path.exists(filepath): 100 | def _progress(count, block_size, total_size): 101 | sys.stdout.write('\rDownloading %s %.2f%%' % (filename, 102 | float(count * block_size) / float(total_size) * 100.0)) 103 | sys.stdout.flush() 104 | 105 | filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) 106 | print('Downloaded', filename) 107 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 108 | 109 | 110 | def save_images(images_dir, images, labels): 111 | print("Saving images to disk") 112 | i = 0 113 | for image in images: 114 | if labels is not None: 115 | directory = os.path.join(images_dir, str(labels[i])) 116 | else: 117 | directory = images_dir 118 | 119 | try: 120 | os.makedirs(directory, exist_ok=True) 121 | except OSError as exc: 122 | if exc.errno == errno.EEXIST: 123 | pass 124 | 125 | filename = os.path.join(directory, str(i)) 126 | # print(filename) 127 | save_image(image, filename) 128 | i = i + 1 129 | 130 | 131 | if __name__ == "__main__": 132 | 133 | # download data if needed 134 | download_and_extract() 135 | 136 | # Read train-images and labels and save to disk 137 | images = read_all_images(TRAIN_DATA_PATH) 138 | print(images.shape) 139 | 140 | labels = read_labels(TRAIN_LABEL_PATH) 141 | print(labels.shape) 142 | 143 | save_images(STL10_TRAIN_IMG_DIR, images, labels) 144 | 145 | # Read test-images and labels and save to disk 146 | images = read_all_images(TEST_DATA_PATH) 147 | print(images.shape) 148 | 149 | 
labels = read_labels(TEST_LABEL_PATH) 150 | print(labels.shape) 151 | 152 | save_images(STL10_TEST_IMG_DIR, images, labels) 153 | 154 | # Read unlabelled images and save to disk 155 | images = read_all_images(UNLABELLED_DATA_PATH) 156 | labels = None 157 | print (images.shape) 158 | 159 | save_images(STL10_UNLABELLED_IMG_DIR, images, labels) 160 | -------------------------------------------------------------------------------- /submit_job.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | #SBATCH --nodes=1 3 | #SBATCH --ntasks-per-node=1 4 | #SBATCH --cpus-per-task=4 5 | #SBATCH --time=8:00:00 6 | #SBATCH --gres=gpu:p40:1 7 | #SBATCH --mem=14000 8 | #SBATCH --job-name=train_non_linear_pirl_on_stl 9 | #SBATCH --mail-type=END 10 | #SBATCH --mail-user=ab8700@nyu.edu 11 | #SBATCH --output=slurm_%j.out 12 | 13 | python pirl_stl_train_test.py --model-type 'res18' --lr 0.1 --tmax-for-cos-decay 50 --warm-start True --only-train True --non-linear-head True --cont-epoch 101 --experiment-name e8_stl_pirl 14 | 15 | -------------------------------------------------------------------------------- /train_stl_after_ssl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torchvision 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from torch import optim 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | from torch.utils.data import DataLoader, ConcatDataset 13 | from torch.utils.data import SubsetRandomSampler 14 | 15 | from common_constants import PAR_WEIGHTS_DIR 16 | from dataset_helpers import def_train_transform_stl, def_test_transform, get_file_paths_n_labels, hflip_data_transform, \ 17 | darkness_jitter_transform, lightness_jitter_transform, rotations_transform, all_in_transform 18 | from experiment_logger import log_experiment 19 | from get_dataset import GetSTL10Data 20 | from models import 
classifier_resnet, pirl_resnet 21 | from network_helpers import copy_weights_between_models 22 | from random_seed_setter import set_random_generators_seed 23 | from train_test_helper import ModelTrainTest 24 | 25 | if __name__ == '__main__': 26 | 27 | # Training arguments 28 | parser = argparse.ArgumentParser(description='CIFAR10 Train test script') 29 | parser.add_argument('--model-type', type=str, default='res18', 30 | help='The network architecture to employ as backbone') 31 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 32 | help='input batch size for training (default: 128)') 33 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 34 | help='number of epochs to train (default: 100)') 35 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 36 | help='learning rate (default: 0.01)') 37 | parser.add_argument('--weight-decay', type=float, default=5e-4, 38 | help='Weight decay constant (default: 5e-4)') 39 | parser.add_argument('--patience-for-lr-decay', type=int, default=10) 40 | parser.add_argument('--full-fine-tune', type=bool, default=False) 41 | parser.add_argument('--experiment-name', type=str, default='e1_pirl_sup_') 42 | parser.add_argument('--pirl-model-name', type=str) 43 | args = parser.parse_args() 44 | 45 | # Set random number generation seed for all packages that generate random numbers 46 | set_random_generators_seed() 47 | 48 | # Identify device for holding tensors and carrying out computations 49 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 50 | 51 | # Define file path with trained SSL model and file_path where trained classification model 52 | # will be saved 53 | pirl_file_path = os.path.join(PAR_WEIGHTS_DIR, args.pirl_model_name) 54 | model_file_path = os.path.join(PAR_WEIGHTS_DIR, args.experiment_name) 55 | 56 | # Get train-val file paths and labels for STL10 57 | par_train_val_images_dir = './stl10_data/train' 58 | train_val_file_paths, train_val_labels = 
get_file_paths_n_labels(par_train_val_images_dir) 59 | print ('Train val file paths count', len(train_val_file_paths)) 60 | print ('Train val labels count', len(train_val_labels)) 61 | 62 | # Split file paths into train and val file paths 63 | len_train_val_set = len(train_val_file_paths) 64 | train_val_indices = list(range(len_train_val_set)) 65 | np.random.shuffle(train_val_indices) 66 | 67 | count_train = 4200 68 | 69 | train_indices = train_val_indices[:count_train] 70 | val_indices = train_val_indices[count_train:] 71 | 72 | train_val_file_paths = np.array(train_val_file_paths) 73 | train_val_labels = np.array(train_val_labels) 74 | train_file_paths, train_labels = train_val_file_paths[train_indices], train_val_labels[train_indices] 75 | val_file_paths, val_labels = train_val_file_paths[val_indices], train_val_labels[val_indices] 76 | 77 | # Define train_set, and val_set objects 78 | train_set = ConcatDataset( 79 | [GetSTL10Data(train_file_paths, train_labels, def_train_transform_stl), 80 | GetSTL10Data(train_file_paths, train_labels, hflip_data_transform), 81 | GetSTL10Data(train_file_paths, train_labels, darkness_jitter_transform), 82 | GetSTL10Data(train_file_paths, train_labels, lightness_jitter_transform), 83 | GetSTL10Data(train_file_paths, train_labels, rotations_transform), 84 | GetSTL10Data(train_file_paths, train_labels, all_in_transform)] 85 | ) 86 | 87 | # train_set = GetSTL10Data(train_val_file_paths, train_val_labels, all_in_transform) 88 | val_set = GetSTL10Data(val_file_paths, val_labels, def_test_transform) 89 | 90 | # Define train, validation and test data loaders 91 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, num_workers=8) 92 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=100, num_workers=8) 93 | 94 | # Print sample batches that would be returned by the train_data_loader 95 | dataiter = iter(train_loader) 96 | X, y = dataiter.__next__() 97 | print (X.size()) 98 | print (y.size()) 99 | 100 | # 
Train required model using data loaders defined above 101 | num_outputs = 10 102 | epochs = args.epochs 103 | lr = args.lr 104 | weight_decay_const = args.weight_decay 105 | 106 | # Define model_to_train and inherit weights from pre-trained SSL model 107 | model_to_train = classifier_resnet(args.model_type, num_classes=num_outputs) 108 | pirl_model = pirl_resnet(args.model_type) 109 | pirl_model.load_state_dict(torch.load(pirl_file_path, map_location=device)) 110 | weight_copy_success = copy_weights_between_models(pirl_model, model_to_train) 111 | 112 | if not weight_copy_success: 113 | print ('Weight copy between SSL and classification net failed. Pls check !!') 114 | exit() 115 | 116 | # Freeze all layers except fully connected in classification net 117 | for name, param in model_to_train.named_parameters(): 118 | if name[:7] == 'resnet_': 119 | param.requires_grad = False 120 | 121 | # # To test what is trainable status of each layer 122 | # for name, param in model_to_train.named_parameters(): 123 | # print (name, param.requires_grad) 124 | 125 | # Set device on which training is done. Plus optimizer to use. 
126 | model_to_train.to(device) 127 | sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay_const) 128 | scheduler = ReduceLROnPlateau(sgd_optimizer, 'min', patience=args.patience_for_lr_decay, 129 | verbose=True, min_lr=1e-5) 130 | 131 | # Start training 132 | model_train_test_obj = ModelTrainTest(model_to_train, device, model_file_path) 133 | train_losses, val_losses, train_accs, val_accs = [], [], [], [] 134 | for epoch_no in range(epochs): 135 | train_loss, train_acc, val_loss, val_acc = model_train_test_obj.train( 136 | sgd_optimizer, epoch_no, params_max_norm=4, 137 | train_data_loader=train_loader, val_data_loader=val_loader, 138 | no_train_samples=len(train_indices), no_val_samples=len(val_indices) 139 | ) 140 | train_losses.append(train_loss) 141 | val_losses.append(val_loss) 142 | train_accs.append(train_acc) 143 | val_accs.append(val_acc) 144 | scheduler.step(val_loss) 145 | 146 | # Log train-test results 147 | log_experiment(args.experiment_name + '_lin_clf', args.epochs, train_losses, val_losses, train_accs, val_accs) 148 | 149 | # Check if layers beyond last fully connected are to be fine tuned 150 | if args.full_fine_tune: 151 | for name, param in model_to_train.named_parameters(): 152 | param.requires_grad = True 153 | 154 | # Reset optimizer and learning rate scheduler 155 | sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=0.01, momentum=0.9, weight_decay=weight_decay_const) 156 | scheduler = ReduceLROnPlateau(sgd_optimizer, 'min', patience=args.patience_for_lr_decay, 157 | verbose=True, min_lr=1e-5) 158 | 159 | # Re-start training 160 | model_train_test_obj = ModelTrainTest(model_to_train, device, model_file_path) 161 | train_losses, val_losses, train_accs, val_accs = [], [], [], [] 162 | for epoch_no in range(epochs): 163 | train_loss, train_acc, val_loss, val_acc = model_train_test_obj.train( 164 | sgd_optimizer, epoch_no, params_max_norm=4, 165 | train_data_loader=train_loader, 
val_data_loader=val_loader, 166 | no_train_samples=len(train_indices), no_val_samples=len(val_indices) 167 | ) 168 | train_losses.append(train_loss) 169 | val_losses.append(val_loss) 170 | train_accs.append(train_acc) 171 | val_accs.append(val_acc) 172 | scheduler.step(val_loss) 173 | 174 | # Log train-test results 175 | log_experiment(args.experiment_name + '_full_ft', args.epochs, train_losses, val_losses, train_accs, val_accs) 176 | -------------------------------------------------------------------------------- /train_test_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import numpy as np 5 | 6 | from torch.nn.utils import clip_grad_norm_ 7 | 8 | from pirl_loss import loss_pirl, get_img_pair_probs 9 | 10 | 11 | def get_count_correct_preds(network_output, target): 12 | 13 | score, predicted = torch.max(network_output, 1) # Returns max score and the index where max score was recorded 14 | count_correct = (target == predicted).sum().float() # So that when accuracy is computed, it is not rounded to int 15 | 16 | return count_correct 17 | 18 | 19 | def get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr): 20 | """ 21 | Get count of correct predictions for pre-text task 22 | :param img_pair_probs_arr: Prob vector of batch of images I and I_t to belong to same data distribution. 
23 | :param img_mem_rep_probs_arr: Prob vector of batch of I and mem_bank_rep of I to belong to same data distribution 24 | """ 25 | 26 | avg_probs_arr = (1/2) * (img_pair_probs_arr + img_mem_rep_probs_arr) 27 | count_correct = (avg_probs_arr >= 0.5).sum().float() # So that when accuracy is computed, it is not rounded to int 28 | 29 | return count_correct.item() 30 | 31 | 32 | class PIRLModelTrainTest(): 33 | 34 | def __init__(self, network, device, model_file_path, all_images_mem, train_image_indices, 35 | val_image_indices, count_negatives, temp_parameter, beta, only_train=False, threshold=1e-4): 36 | super(PIRLModelTrainTest, self).__init__() 37 | self.network = network 38 | self.device = device 39 | self.model_file_path = model_file_path 40 | self.threshold = threshold 41 | self.train_loss = 1e9 42 | self.val_loss = 1e9 43 | self.all_images_mem = torch.tensor(all_images_mem, dtype=torch.float).to(device) 44 | self.train_image_indices = train_image_indices.copy() 45 | self.val_image_indices = val_image_indices.copy() 46 | self.count_negatives = count_negatives 47 | self.temp_parameter = temp_parameter 48 | self.beta = beta 49 | self.only_train = only_train 50 | 51 | def train(self, optimizer, epoch, params_max_norm, train_data_loader, val_data_loader, 52 | no_train_samples, no_val_samples): 53 | self.network.train() 54 | train_loss, correct, cnt_batches = 0, 0, 0 55 | 56 | for batch_idx, (data_batch, batch_img_indices) in enumerate(train_data_loader): 57 | 58 | # Separate input image I batch and transformed image I_t batch (jigsaw patches) from data_batch 59 | i_batch, i_t_patches_batch = data_batch[0], data_batch[1] 60 | 61 | # Set device for i_batch, i_t_patches_batch and batch_img_indices 62 | i_batch, i_t_patches_batch = i_batch.to(self.device), i_t_patches_batch.to(self.device) 63 | batch_img_indices = batch_img_indices.to(self.device) 64 | 65 | # Forward pass through the network 66 | optimizer.zero_grad() 67 | vi_batch, vi_t_batch = self.network(i_batch, 
i_t_patches_batch) 68 | 69 | # Prepare memory bank of negatives for current batch 70 | np.random.shuffle(self.train_image_indices) 71 | mn_indices_all = np.array(list(set(self.train_image_indices) - set(batch_img_indices))) 72 | np.random.shuffle(mn_indices_all) 73 | mn_indices = mn_indices_all[:self.count_negatives] 74 | mn_arr = self.all_images_mem[mn_indices] 75 | 76 | # Get memory bank representation for current batch images 77 | mem_rep_of_batch_imgs = self.all_images_mem[batch_img_indices] 78 | 79 | # Get prob for I, I_t to belong to same data distribution. 80 | img_pair_probs_arr = get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, self.temp_parameter) 81 | 82 | # Get prob for I and mem_bank_rep of I to belong to same data distribution 83 | img_mem_rep_probs_arr = get_img_pair_probs(vi_batch, mem_rep_of_batch_imgs, mn_arr, self.temp_parameter) 84 | 85 | # Compute loss => back-prop gradients => Update weights 86 | loss = loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr) 87 | loss.backward() 88 | 89 | clip_grad_norm_(self.network.parameters(), params_max_norm) 90 | optimizer.step() 91 | 92 | # Update running loss and no of pseudo correct predictions for epoch 93 | correct += get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr) 94 | train_loss += loss.item() 95 | cnt_batches += 1 96 | 97 | # Update memory bank representation for images from current batch 98 | all_images_mem_new = self.all_images_mem.clone().detach() 99 | all_images_mem_new[batch_img_indices] = (self.beta * all_images_mem_new[batch_img_indices]) + \ 100 | ((1 - self.beta) * vi_batch) 101 | self.all_images_mem = all_images_mem_new.clone().detach() 102 | 103 | del i_batch, i_t_patches_batch, vi_batch, vi_t_batch, mn_arr, mem_rep_of_batch_imgs 104 | del img_mem_rep_probs_arr, img_pair_probs_arr 105 | 106 | train_loss /= cnt_batches 107 | 108 | if epoch % 10 == 0: 109 | torch.save(self.network.state_dict(), self.model_file_path + '_epoch_{}'.format(epoch)) 110 | 111 | if 
self.only_train is False: 112 | val_loss, val_acc = self.test(epoch, val_data_loader, no_val_samples) 113 | 114 | if val_loss < self.val_loss - self.threshold: 115 | self.val_loss = val_loss 116 | torch.save(self.network.state_dict(), self.model_file_path) 117 | 118 | else: 119 | val_loss, val_acc = 0.0, 0.0 120 | 121 | train_acc = correct / no_train_samples 122 | 123 | print('\nAfter epoch {} - Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 124 | epoch, train_loss, correct, no_train_samples, 100. * correct / no_train_samples)) 125 | 126 | return train_loss, train_acc, val_loss, val_acc 127 | 128 | def test(self, epoch, test_data_loader, no_test_samples): 129 | 130 | self.network.eval() 131 | test_loss, correct, cnt_batches = 0, 0, 0 132 | 133 | for batch_idx, (data_batch, batch_img_indices) in enumerate(test_data_loader): 134 | 135 | # Separate input image I batch and transformed image I_t batch (jigsaw patches) from data_batch 136 | i_batch, i_t_patches_batch = data_batch[0], data_batch[1] 137 | 138 | # Set device for i_batch, i_t_patches_batch and batch_img_indices 139 | i_batch, i_t_patches_batch = i_batch.to(self.device), i_t_patches_batch.to(self.device) 140 | batch_img_indices = batch_img_indices.to(self.device) 141 | 142 | # Forward pass through the network 143 | vi_batch, vi_t_batch = self.network(i_batch, i_t_patches_batch) 144 | 145 | # Prepare memory bank of negatives for current batch 146 | np.random.shuffle(self.val_image_indices) 147 | 148 | mn_indices_all = np.array(list(set(self.val_image_indices) - set(batch_img_indices))) 149 | np.random.shuffle(mn_indices_all) 150 | mn_indices = mn_indices_all[:self.count_negatives] 151 | mn_arr = self.all_images_mem[mn_indices] 152 | 153 | # Get memory bank representation for current batch images 154 | mem_rep_of_batch_imgs = self.all_images_mem[batch_img_indices] 155 | 156 | # Get prob for I, I_t to belong to same data distribution. 
157 | img_pair_probs_arr = get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, self.temp_parameter) 158 | 159 | # Get prob for I and mem_bank_rep of I to belong to same data distribution 160 | img_mem_rep_probs_arr = get_img_pair_probs(vi_batch, mem_rep_of_batch_imgs, mn_arr, self.temp_parameter) 161 | 162 | # Compute loss 163 | loss = loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr) 164 | 165 | # Update running loss and no of pseudo correct predictions for epoch 166 | correct += get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr) 167 | test_loss += loss.item() 168 | cnt_batches += 1 169 | 170 | # Update memory bank representation for images from current batch 171 | all_images_mem_new = self.all_images_mem.clone().detach() 172 | all_images_mem_new[batch_img_indices] = (self.beta * all_images_mem_new[batch_img_indices]) + \ 173 | ((1 - self.beta) * vi_batch) 174 | self.all_images_mem = all_images_mem_new.clone().detach() 175 | 176 | 177 | del i_batch, i_t_patches_batch, vi_batch, vi_t_batch, mn_arr, mem_rep_of_batch_imgs 178 | del img_mem_rep_probs_arr, img_pair_probs_arr 179 | 180 | test_loss /= cnt_batches 181 | test_acc = correct / no_test_samples 182 | print('\nAfter epoch {} - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 183 | epoch, test_loss, correct, no_test_samples, 100. 
* correct / no_test_samples)) 184 | 185 | return test_loss, test_acc 186 | 187 | 188 | class ModelTrainTest(): 189 | 190 | def __init__(self, network, device, model_file_path, threshold=1e-4): 191 | super(ModelTrainTest, self).__init__() 192 | self.network = network 193 | self.device = device 194 | self.model_file_path = model_file_path 195 | self.threshold = threshold 196 | self.train_loss = 1e9 197 | self.val_loss = 1e9 198 | 199 | def train(self, optimizer, epoch, params_max_norm, train_data_loader, val_data_loader, 200 | no_train_samples, no_val_samples): 201 | self.network.train() 202 | train_loss, correct, cnt_batches = 0, 0, 0 203 | 204 | for batch_idx, (data, target) in enumerate(train_data_loader): 205 | data, target = data.to(self.device), target.to(self.device) 206 | 207 | optimizer.zero_grad() 208 | output = self.network(data) 209 | 210 | loss = F.nll_loss(output, target) 211 | loss.backward() 212 | 213 | clip_grad_norm_(self.network.parameters(), params_max_norm) 214 | optimizer.step() 215 | 216 | correct += get_count_correct_preds(output, target) 217 | train_loss += loss.item() 218 | cnt_batches += 1 219 | 220 | del data, target, output 221 | 222 | train_loss /= cnt_batches 223 | val_loss, val_acc = self.test(epoch, val_data_loader, no_val_samples) 224 | 225 | if val_loss < self.val_loss - self.threshold: 226 | self.val_loss = val_loss 227 | torch.save(self.network.state_dict(), self.model_file_path) 228 | 229 | train_acc = correct / no_train_samples 230 | 231 | print('\nAfter epoch {} - Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 232 | epoch, train_loss, correct, no_train_samples, 100. 
* correct / no_train_samples)) 233 | 234 | return train_loss, train_acc, val_loss, val_acc 235 | 236 | def test(self, epoch, test_data_loader, no_test_samples): 237 | self.network.eval() 238 | test_loss = 0 239 | correct = 0 240 | 241 | for batch_idx, (data, target) in enumerate(test_data_loader): 242 | data, target = data.to(self.device), target.to(self.device) 243 | output = self.network(data) 244 | test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss 245 | 246 | correct += get_count_correct_preds(output, target) 247 | 248 | del data, target, output 249 | 250 | test_loss /= no_test_samples 251 | test_acc = correct / no_test_samples 252 | print('\nAfter epoch {} - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 253 | epoch, test_loss, correct, no_test_samples, 100. * correct / no_test_samples)) 254 | 255 | return test_loss, test_acc 256 | 257 | if __name__ == '__main__': 258 | img_pair_probs_arr = torch.randn((256,)) 259 | img_mem_rep_probs_arr = torch.randn((256,)) 260 | print (get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr)) --------------------------------------------------------------------------------