├── .gitignore
├── common_constants.py
├── dataset_helpers.py
├── environment.yml
├── experiment_logger.py
├── get_dataset.py
├── models.py
├── network_helpers.py
├── pirl_loss.py
├── pirl_stl_train_test.py
├── random_seed_setter.py
├── readme.md
├── stl10_data_load.py
├── submit_job.sh
├── train_stl_after_ssl.py
└── train_test_helper.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | cifar_data/
3 | __pycache__
4 | .DS_Store
5 | weights/
6 | parse_raw_output.py
7 | observations_dir/
8 | e1_test_observations.csv
9 | raw_stl10/
10 | stl10_data/
11 |
--------------------------------------------------------------------------------
/common_constants.py:
--------------------------------------------------------------------------------
# Directory where trained model weights are stored.
PAR_WEIGHTS_DIR: str = './weights'
# Directory for per-experiment CSV observation logs (used by experiment_logger).
PAR_OBSERVATIONS_DIR: str = './observations_dir'
# Directory intended for saved activations (no user of it is visible in this file).
PAR_ACTIVATIONS_DIR: str = './activations_dir'
4 |
--------------------------------------------------------------------------------
/dataset_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torchvision.transforms import transforms
4 |
# Per-channel normalisation statistics shared by every pipeline in this module.
# NOTE(review): these look like the commonly quoted CIFAR-10 statistics rather than STL-10's;
# kept unchanged so previously trained weights stay compatible.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)


def _normalized_pipeline(*pil_augmentations):
    """Compose the given PIL-level augmentations followed by ToTensor and Normalize."""
    steps = list(pil_augmentations)
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(_NORM_MEAN, _NORM_STD))
    return transforms.Compose(steps)


# Pipelines without any augmentation.
def_train_transform_stl = _normalized_pipeline()
def_test_transform = _normalized_pipeline()

# Deterministic / single-augmentation pipelines.
hflip_data_transform = _normalized_pipeline(transforms.RandomHorizontalFlip(p=1.0))
darkness_jitter_transform = _normalized_pipeline(transforms.ColorJitter(brightness=[0.5, 0.9]))
lightness_jitter_transform = _normalized_pipeline(transforms.ColorJitter(brightness=[1.1, 1.5]))
rotations_transform = _normalized_pipeline(transforms.RandomRotation(degrees=25))

# Combined augmentation: random flip, brightness jitter and rotation.
all_in_transform = _normalized_pipeline(
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=[0.5, 1.5]),
    transforms.RandomRotation(degrees=25),
)

# PIRL: transform for the full image, and for each individual jigsaw patch.
pirl_full_img_transform = _normalized_pipeline(transforms.RandomHorizontalFlip())
pirl_stl10_jigsaw_patch_transform = _normalized_pipeline(
    transforms.RandomCrop(30, padding=1),
    transforms.ColorJitter(brightness=[0.5, 1.5]),
)
59 |
60 |
def get_file_paths_n_labels(par_images_dir):
    """
    Collect image file paths and integer labels from a one-directory-per-label layout.

    ``par_images_dir`` is expected to contain sub-directories whose names are integers;
    each holds the ``.jpeg`` images of that label. Non-directory entries are ignored.
    Results are sorted (by directory name, then file name) so that the output, and any
    seeded shuffling applied to it later, is reproducible across machines.

    :param par_images_dir: Parent directory holding the per-label sub-directories.
    :return: Tuple ``(file_paths, labels)`` of equal length; each label is the directory
             name converted to int and shifted down by one.
    :raises ValueError: If a sub-directory name is not an integer.
    """
    file_paths = []
    labels = []

    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    for label_name in sorted(os.listdir(par_images_dir)):
        label_dir = os.path.join(par_images_dir, label_name)
        if not os.path.isdir(label_dir):
            continue

        file_names = sorted(name for name in os.listdir(label_dir) if name.endswith('.jpeg'))
        file_paths.extend(os.path.join(label_dir, name) for name in file_names)
        # Directory names are shifted by -1 to give the class index.
        labels.extend([int(label_name) - 1] * len(file_names))

    return file_paths, labels
81 |
82 |
def get_nine_crops(pil_image):
    """
    Split an image into a 3x3 grid and return the nine crops in row-major order
    (top-left, top-middle, top-right, ..., bottom-right).

    Width and height are divided independently, so for a square image every crop is a
    square of side ``w // 3`` (identical to the previous behaviour) while non-square images
    yield rectangular crops. Remainder pixels on the right/bottom edges are dropped.

    :param pil_image: Pillow image (any object exposing ``size`` as ``(w, h)`` and ``crop(box)``).
    :return: List of the nine cropped images.
    """
    w, h = pil_image.size
    patch_w = w // 3
    patch_h = h // 3

    # crop() takes a (left, top, right, bottom) box.
    return [
        pil_image.crop((col * patch_w, row * patch_h, (col + 1) * patch_w, (row + 1) * patch_h))
        for row in range(3)
        for col in range(3)
    ]
109 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: pirl-cifar
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - imageio=2.6.1
7 | - pandas=1.0.1
8 | - python=3.7
9 | - pillow=6.1
10 | - pip
11 | - pytorch=1.1
12 | - scipy=1.2.0
13 | - torchvision
14 | - opencv
15 | - matplotlib
16 | - jupyterlab
17 | - nb_conda_kernels
18 | - pip:
19 | - torchtext
20 | - torchviz
21 |
--------------------------------------------------------------------------------
/experiment_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 |
4 | from common_constants import PAR_OBSERVATIONS_DIR
5 |
6 |
def log_experiment(exp_name, n_epochs, train_losses, val_losses, train_accs, val_accs):
    """
    Save an experiment's per-epoch metrics as ``<PAR_OBSERVATIONS_DIR>/<exp_name>_observations.csv``.

    Columns are 'epoch count' (1..n_epochs), 'train loss', 'val loss', 'train acc' and 'val acc';
    the DataFrame index is written as well.

    :param exp_name: Experiment name, used as the CSV file-name prefix.
    :param n_epochs: Number of epochs; every metric sequence must have exactly this many entries.
    :param train_losses: Per-epoch training losses.
    :param val_losses: Per-epoch validation losses.
    :param train_accs: Per-epoch training accuracies.
    :param val_accs: Per-epoch validation accuracies.
    :return: Path of the written CSV file.
    :raises ValueError: If a metric sequence's length does not match ``n_epochs`` (raised by pandas).
    """
    observations_df = pd.DataFrame({
        'epoch count': list(range(1, n_epochs + 1)),
        'train loss': train_losses,
        'val loss': val_losses,
        'train acc': train_accs,
        'val acc': val_accs,
    })

    # The observations directory is git-ignored, so it may not exist yet; to_csv would fail without it.
    os.makedirs(PAR_OBSERVATIONS_DIR, exist_ok=True)
    observations_file_path = os.path.join(PAR_OBSERVATIONS_DIR, exp_name + '_observations.csv')
    observations_df.to_csv(observations_file_path)
    return observations_file_path
16 |
--------------------------------------------------------------------------------
/get_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 |
6 | from torch.utils.data import Dataset
7 |
8 | from dataset_helpers import get_nine_crops, pirl_full_img_transform, pirl_stl10_jigsaw_patch_transform
9 |
10 |
class GetSTL10Data(Dataset):
    """Map-style dataset of labelled images stored as individual files on disk."""

    def __init__(self, file_paths, labels, transform):
        """
        :param file_paths: Sequence of image file paths.
        :param labels: Sequence of labels aligned index-for-index with ``file_paths``.
        :param transform: Callable applied to each loaded PIL image (e.g. a torchvision transform).
        """
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Load the image at ``index``, transform it and return ``(image_tensor, label)``."""
        # Image.open is lazy and keeps the file handle open; convert() forces the decode while the
        # file is open and the context manager then releases the handle (avoids fd leaks in workers).
        with Image.open(self.file_paths[index]) as image:
            image = image.convert('RGB')

        image_tensor = self.transform(image)
        label = self.labels[index]
        return image_tensor, label
32 |
33 |
class GetSTL10DataForPIRL(Dataset):
    """
    Dataset for PIRL training: for each image it yields the transformed full image plus its
    nine jigsaw patches in a random order, each patch transformed independently.
    """

    def __init__(self, file_paths):
        """
        :param file_paths: Sequence of image file paths.
        """
        self.file_paths = file_paths

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """
        :return: ``([image_tensor, tensor_patches], index)``. ``tensor_patches`` stacks the nine
                 transformed patches along a new leading dimension. ``index`` is returned so the
                 caller can identify the image (presumably to address its memory-bank entry).
        """
        # Open in a context manager so the file handle is released right after decoding.
        with Image.open(self.file_paths[index]) as image:
            image = image.convert('RGB')

        image_tensor = pirl_full_img_transform(image)
        nine_crops = get_nine_crops(image)

        # Use torch's RNG for the jigsaw order: DataLoader gives every worker a distinct torch seed,
        # whereas numpy's global RNG state can be duplicated across forked workers (identical
        # permutations in every worker), notably on older PyTorch releases such as the pinned 1.1.
        permuted_order = torch.randperm(len(nine_crops)).tolist()

        # Crop k is placed at position permuted_order[k].
        permuted_patches = [None] * len(nine_crops)
        for patch_pos, patch in zip(permuted_order, nine_crops):
            permuted_patches[patch_pos] = patch

        # Stack instead of pre-allocating a fixed-size buffer so the shape follows the patch transform.
        tensor_patches = torch.stack(
            [pirl_stl10_jigsaw_patch_transform(patch) for patch in permuted_patches]
        )

        return [image_tensor, tensor_patches], index
72 |
73 |
74 |
if __name__ == '__main__':
    # Smoke test: build the PIRL dataset over the unlabelled STL10 JPEGs and print the
    # shapes of the first collated batch.
    print("Testing for GetSTL10DataForPIRL")
    base_images_dir = 'stl10_data/unlabelled'
    file_paths_list = [
        os.path.join(base_images_dir, name)
        for name in os.listdir(base_images_dir)
        if name[-4:] == 'jpeg'
    ]
    ssl_loader = torch.utils.data.DataLoader(
        GetSTL10DataForPIRL(file_paths_list), batch_size=128, num_workers=8
    )

    (full_images, jigsaw_patches), indices = next(iter(ssl_loader))
    for batch_part in (full_images, jigsaw_patches, indices):
        print(batch_part.size())
92 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models import resnet18, resnet34, resnet50
5 |
6 |
class ClassificationResNet(nn.Module):
    """Classifier: a headless ResNet backbone followed by a linear layer and log-softmax."""

    def __init__(self, resnet_module, num_classes, in_features=512):
        """
        :param resnet_module: Backbone whose output, flattened from dim 1, has ``in_features`` elements.
        :param num_classes: Number of output classes.
        :param in_features: Flattened backbone feature size (512 for ResNet-18/34, 2048 for ResNet-50).
        """
        super(ClassificationResNet, self).__init__()
        self.resnet_module = resnet_module
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, input_batch):
        """
        :param input_batch: Image batch of shape (batch_size, no_channels, h, w).
        :return: Log-probabilities of shape (batch_size, num_classes).
        """
        resnet_feat_vectors = self.resnet_module(input_batch)
        final_feat_vectors = torch.flatten(resnet_feat_vectors, 1)
        # Explicit dim=1 (class axis): calling log_softmax without dim is deprecated and ambiguous.
        return F.log_softmax(self.fc(final_feat_vectors), dim=1)
22 |
23 |
def get_base_resnet_module(model_type):
    """
    Return a randomly initialised torchvision ResNet with its final fully connected layer removed.

    :param model_type: One of {'res18', 'res34', 'res50'}.
    :return: ``nn.Sequential`` of all ResNet children except the last (the ``fc`` layer).
    :raises ValueError: If ``model_type`` is not a supported name (previously any unknown
                        value silently produced a ResNet-50).
    """
    constructors = {'res18': resnet18, 'res34': resnet34, 'res50': resnet50}
    if model_type not in constructors:
        raise ValueError("Unsupported model_type {!r}; expected one of {}".format(
            model_type, sorted(constructors)))

    original_model = constructors[model_type](pretrained=False)
    # Drop the classification head (the last child), keeping the feature extractor.
    return nn.Sequential(*list(original_model.children())[:-1])
39 |
40 |
def classifier_resnet(model_type, num_classes):
    """
    Return a classification network with a backbone from the ResNet family.

    :param model_type: Which ResNet to employ. One of {res18, res34, res50}.
    :param num_classes: Number of classes the network classifies its inputs into.
    :return: ClassificationResNet whose linear head matches the backbone's feature size.
    """
    base_resnet_module = get_base_resnet_module(model_type)
    model = ClassificationResNet(base_resnet_module, num_classes)

    # Backbone features are 512-d for res18/res34 and 2048-d otherwise (res50). ClassificationResNet
    # builds a 512-input head, so rebuild it when it does not match; otherwise forward() would fail.
    feat_dim = 512 if model_type in ('res18', 'res34') else 2048
    if model.fc.in_features != feat_dim:
        model.fc = nn.Linear(feat_dim, num_classes)

    return model
51 |
52 |
class PIRLResnet(nn.Module):
    """
    PIRL network: one shared ResNet backbone with projection heads for the full image and
    for the concatenation of its jigsaw patches.
    """

    def __init__(self, resnet_module, non_linear_head=False, in_features=512, num_patches=9):
        """
        :param resnet_module: Backbone whose output, flattened from dim 1, has ``in_features`` elements.
        :param non_linear_head: If True, both outputs additionally go through ReLU and a 128->128 layer.
        :param in_features: Flattened backbone feature size (512 for ResNet-18/34, 2048 for ResNet-50).
        :param num_patches: Number of jigsaw patches per image.
        """
        super(PIRLResnet, self).__init__()
        self.resnet_module = resnet_module
        self.num_patches = num_patches
        self.lin_project_1 = nn.Linear(in_features, 128)
        self.lin_project_2 = nn.Linear(128 * num_patches, 128)
        if non_linear_head:
            self.lin_project_3 = nn.Linear(128, 128)  # Only used when non_linear_head is True
        self.non_linear_head = non_linear_head

    def forward(self, i_batch, i_t_patches_batch):
        """
        :param i_batch: Batch of images, shape (batch, C, H, W).
        :param i_t_patches_batch: Batch of jigsaw-transformed patches, shape (batch, num_patches, C, h, w).
        :return: Tuple ``(vi_batch, vi_t_batch)``, each of shape (batch, 128).
        :raises ValueError: If dim 1 of ``i_t_patches_batch`` differs from ``num_patches``.
        """
        if i_t_patches_batch.size(1) != self.num_patches:
            raise ValueError("Expected {} patches per image, got {}".format(
                self.num_patches, i_t_patches_batch.size(1)))

        # Image branch: backbone -> flatten -> lin_project_1.
        vi_batch = self.lin_project_1(torch.flatten(self.resnet_module(i_batch), 1))

        # Patch branch: each patch goes through the backbone separately (same calls, in the same
        # order, as before, so BatchNorm statistics are unchanged), then through lin_project_1.
        patch_projections = [
            self.lin_project_1(torch.flatten(self.resnet_module(i_t_patches_batch[:, patch_ind]), 1))
            for patch_ind in range(self.num_patches)
        ]

        # Concatenate the patch projections and map them to a single 128-d vector.
        vi_t_batch = self.lin_project_2(torch.cat(patch_projections, 1))

        if self.non_linear_head:
            vi_batch = self.lin_project_3(F.relu(vi_batch))
            vi_t_batch = self.lin_project_3(F.relu(vi_t_batch))

        return vi_batch, vi_t_batch
94 |
95 |
def pirl_resnet(model_type, non_linear_head=False):
    """
    Return a network supporting Pretext-Invariant Representation Learning with a ResNet backbone.

    :param model_type: Which ResNet to employ. One of {res18, res34, res50}.
    :param non_linear_head: If True, apply a non-linearity (ReLU + linear layer) to the outputs
                            of the projection heads applied to the ResNet representations.
    :return: PIRLResnet whose first projection matches the backbone's feature size.
    """
    base_resnet_module = get_base_resnet_module(model_type)
    model = PIRLResnet(base_resnet_module, non_linear_head)

    # Backbone features are 512-d for res18/res34 and 2048-d otherwise (res50), while PIRLResnet
    # builds lin_project_1 for 512-d input. Rebuild it when it does not match the backbone.
    feat_dim = 512 if model_type in ('res18', 'res34') else 2048
    if model.lin_project_1.in_features != feat_dim:
        model.lin_project_1 = nn.Linear(feat_dim, model.lin_project_1.out_features)

    return model
108 |
109 |
if __name__ == '__main__':
    # Smoke test: run random tensors through a ResNet-18 PIRL model and print both output shapes.
    # non_linear_head may be either True or False.
    pirl_net = pirl_resnet('res18', non_linear_head=True)
    images = torch.randn(32, 3, 64, 64)
    jigsaw_patches = torch.randn(32, 9, 3, 32, 32)

    outputs = pirl_net.forward(images, jigsaw_patches)
    for output in outputs:
        print(output.size())
119 |
--------------------------------------------------------------------------------
/network_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models import pirl_resnet, classifier_resnet
4 |
5 |
def test_copy_weights(m1, m2):
    """
    Check that every state-dict entry of ``m1`` that also exists in ``m2`` has equal values in ``m2``.

    Entries present in only one model are ignored. A message naming each mismatching entry is
    printed, and 'All is well' is printed when there are none.

    :param m1: Source model.
    :param m2: Destination model.
    :return: 1 if all shared entries match, 0 otherwise (previously 1 was returned unconditionally).
    """
    m1_state_dict = m1.state_dict()
    m2_state_dict = m2.state_dict()
    weight_copy_flag = 1
    for name, param in m1_state_dict.items():
        if name not in m2_state_dict:
            continue
        other = m2_state_dict[name]
        # Compare shapes first: element-wise comparison of incompatible shapes would raise.
        # .cpu() lets models living on different devices be compared.
        if param.shape != other.shape or not torch.equal(param.cpu(), other.cpu()):
            print("Something is incorrect for layer {} in 2nd model".format(name))
            weight_copy_flag = 0

    if weight_copy_flag:
        print('All is well')

    return weight_copy_flag


# The ``test_`` prefix makes pytest treat this helper as a test case; opt it out of collection.
test_copy_weights.__test__ = False
23 |
24 |
def copy_weights_between_models(m1, m2):
    """
    Copy the values of every state-dict entry shared by ``m1`` and ``m2`` from ``m1`` into ``m2``.

    Entries that exist in only one of the models are left untouched.

    :param m1: Source model.
    :param m2: Destination model, updated in place through ``load_state_dict``.
    :return: Whatever ``test_copy_weights(m1, m2)`` returns.
    """
    source_state = m1.state_dict()
    target_state = m2.state_dict()

    shared_entries = {key: value.data for key, value in source_state.items() if key in target_state}
    target_state.update(shared_entries)
    m2.load_state_dict(target_state)

    # Verify that m2 really holds the copied values.
    return test_copy_weights(m1, m2)
45 |
46 |
if __name__ == '__main__':
    # Smoke test: copy the shared backbone weights of a fresh PIRL ResNet-18 into a 10-class classifier.
    pirl_model = pirl_resnet('res18')
    classifier_model = classifier_resnet('res18', num_classes=10)
    copy_success = copy_weights_between_models(pirl_model, classifier_model)
53 |
54 |
--------------------------------------------------------------------------------
/pirl_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
def get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, temp_parameter):
    """
    Return, for every row k, the probability that ``vi_batch[k]`` and ``vi_t_batch[k]`` form a
    positive pair rather than ``vi_t_batch[k]`` matching one of the memory-bank negatives:

        p_k = exp(s(v_k, t_k)/T) / (exp(s(v_k, t_k)/T) + sum_j exp(s(t_k, m_j)/T) + eps)

    where s is cosine similarity and T the temperature.

    :param vi_batch: (B, D) representations of images I.
    :param vi_t_batch: (B, D) representations paired row-wise with ``vi_batch``.
    :param mn_arr: (N, D) memory-bank representations of negatives for the current batch.
    :param temp_parameter: Temperature T.
    :return: (B,) tensor of probabilities.
    """
    # Small constant so a zero-norm representation (or a zero denominator) cannot produce NaN/inf.
    eps = 1e-6

    # L2-normalise rows so dot products become cosine similarities.
    vi_batch = vi_batch / (torch.norm(vi_batch, dim=1, keepdim=True) + eps)
    vi_t_batch = vi_t_batch / (torch.norm(vi_t_batch, dim=1, keepdim=True) + eps)
    mn_arr = mn_arr / (torch.norm(mn_arr, dim=1, keepdim=True) + eps)

    # Row-wise dot product == diagonal of vi @ vi_t.T, without materialising the (B, B) matrix.
    sim_vi_vi_t_arr = torch.sum(vi_batch * vi_t_batch, dim=1)
    sim_vi_t_mn_mat = vi_t_batch @ mn_arr.t()

    exp_sim_vi_vi_t_arr = torch.exp(sim_vi_vi_t_arr / temp_parameter)
    # Sum over all negatives for each transformed image.
    sum_exp_sim_vi_t_mn_arr = torch.sum(torch.exp(sim_vi_t_mn_mat / temp_parameter), dim=1)

    return exp_sim_vi_vi_t_arr / (exp_sim_vi_vi_t_arr + sum_exp_sim_vi_t_mn_arr + eps)
41 |
42 |
def loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr):
    """
    Return 0.5 * (mean(-log(img_pair_probs_arr)) + mean(-log(img_mem_rep_probs_arr))).

    :param img_pair_probs_arr: Probabilities that images I and their transforms I_t belong together.
    :param img_mem_rep_probs_arr: Probabilities that I and the memory-bank representation of I belong together.
    :return: Scalar loss tensor.
    """
    # Batch-average negative log-likelihood of each term.
    loss_i_i_t = torch.mean(-torch.log(img_pair_probs_arr))
    loss_i_mem_i = torch.mean(-torch.log(img_mem_rep_probs_arr))

    return (loss_i_i_t + loss_i_mem_i) / 2
61 |
62 |
if __name__ == '__main__':
    # Smoke test with random representations: a 256-image batch against 6400 memory-bank negatives.
    batch_size, feat_dim, n_negatives = 256, 128, 6400
    vi_batch = torch.randn(batch_size, feat_dim)
    vi_t_batch = torch.randn(batch_size, feat_dim)
    mn_arr = torch.randn(n_negatives, feat_dim)
    mem_rep_of_batch_imgs = torch.randn(batch_size, feat_dim)
    temp_parameter = 1.5

    # Probabilities for (I, I_t) and for (I, memory-bank representation of I).
    img_pair_probs_arr = get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, temp_parameter)
    print(img_pair_probs_arr.shape)
    img_mem_rep_probs_arr = get_img_pair_probs(vi_batch, mem_rep_of_batch_imgs, mn_arr, temp_parameter)
    print(img_mem_rep_probs_arr.shape)

    print(loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr))
83 |
84 |
85 |
--------------------------------------------------------------------------------
/pirl_stl_train_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import torchvision
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from torch import optim
11 | from torch.optim.lr_scheduler import CosineAnnealingLR
12 | from torch.utils.data import SubsetRandomSampler
13 |
14 | from common_constants import PAR_WEIGHTS_DIR
15 | from experiment_logger import log_experiment
16 | from get_dataset import GetSTL10DataForPIRL
17 | from models import pirl_resnet
18 | from random_seed_setter import set_random_generators_seed
19 | from train_test_helper import PIRLModelTrainTest
20 |
21 |
def unpickle(file):
    """
    Load and return the object stored in the pickle file ``file``.

    ``encoding='bytes'`` allows reading pickles written by Python 2.
    SECURITY NOTE: unpickling can execute arbitrary code — only use this on trusted files.
    """
    import pickle
    with open(file, 'rb') as fo:
        # Return directly instead of binding to a local named ``dict`` (which shadowed the builtin).
        return pickle.load(fo, encoding='bytes')
27 |
28 |
def str2bool(value):
    """
    argparse ``type=`` converter for boolean options.

    ``type=bool`` is a trap: bool('False') is True, so any explicit value would enable the flag.

    :param value: Command-line string (or an already-boolean default).
    :return: The parsed boolean.
    :raises argparse.ArgumentTypeError: If the string is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


if __name__ == '__main__':

    # Training arguments
    parser = argparse.ArgumentParser(description='STL10 Train test script for PIRL task')
    parser.add_argument('--model-type', type=str, default='res18', help='The network architecture to employ as backbone')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='Weight decay constant (default: 5e-4)')
    parser.add_argument('--tmax-for-cos-decay', type=int, default=70)
    # Boolean options use str2bool so that e.g. '--warm-start False' really means False.
    parser.add_argument('--warm-start', type=str2bool, default=False)
    parser.add_argument('--count-negatives', type=int, default=6400,
                        help='No of samples in memory bank of negatives')
    parser.add_argument('--beta', type=float, default=0.5, help='Exponential running average constant '
                                                                'in memory bank update')
    parser.add_argument('--only-train', type=str2bool, default=False,
                        help='If true utilize the entire unannotated STL10 dataset for training.')
    parser.add_argument('--non-linear-head', type=str2bool, default=False,
                        help='If true apply non-linearity to the output of function heads '
                             'applied to resnet image representations')
    parser.add_argument('--temp-parameter', type=float, default=0.07, help='Temperature parameter in NCE probability')
    parser.add_argument('--cont-epoch', type=int, default=1, help='Epoch to start the training from, helpful when using '
                                                                  'warm start')
    parser.add_argument('--experiment-name', type=str, default='e1_resnet18_')
    args = parser.parse_args()

    # Set random number generation seed for all packages that generate random numbers
    set_random_generators_seed()

    # Identify device for holding tensors and carrying out computations
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # File where the trained model is saved (and loaded from on warm start).
    # NOTE(review): the '_epoch_100' suffix is fixed regardless of --epochs; kept unchanged because
    # other scripts presumably look the weights up under this name.
    model_file_path = os.path.join(PAR_WEIGHTS_DIR, args.experiment_name + '_epoch_100')
    # weights/ is git-ignored, so it may be missing on a fresh checkout and saving into it would fail.
    os.makedirs(PAR_WEIGHTS_DIR, exist_ok=True)

    # Unlabelled STL10 image paths. Sorted because os.listdir order is filesystem-dependent, which
    # would otherwise make the seeded train/val split below differ between machines.
    base_images_dir = 'stl10_data/unlabelled'
    file_paths_list = [os.path.join(base_images_dir, file_name)
                       for file_name in sorted(os.listdir(base_images_dir))
                       if file_name.endswith('.jpeg')]

    # Define train_set, val_set objects
    train_set = GetSTL10DataForPIRL(file_paths_list)
    val_set = GetSTL10DataForPIRL(file_paths_list)

    # Random train/val split of the same image set
    len_train_val_set = len(train_set)
    train_val_indices = list(range(len_train_val_set))
    np.random.shuffle(train_val_indices)

    # 70k images for training by default; with --only-train, up to 100k (STL10 ships 100k unlabelled images).
    count_train = 100000 if args.only_train else 70000

    train_indices = train_val_indices[:count_train]
    # NOTE(review): with --only-train this is normally empty; confirm PIRLModelTrainTest handles
    # no_val_samples == 0.
    val_indices = train_val_indices[count_train:]

    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, sampler=train_sampler,
                                               num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, sampler=val_sampler,
                                             num_workers=8)

    # Print the shapes of a sample batch returned by the train_data_loader
    dataiter = iter(train_loader)
    X, y = next(dataiter)
    print(X[0].size())
    print(X[1].size())
    print(y.size())

    # Train required model using data loaders defined above
    epochs = args.epochs
    lr = args.lr
    weight_decay_const = args.weight_decay

    # Backbone selected by --model-type
    model_to_train = pirl_resnet(args.model_type, args.non_linear_head)

    # Set device on which training is done. Plus optimizer to use.
    model_to_train.to(device)
    sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay_const)
    scheduler = CosineAnnealingLR(sgd_optimizer, args.tmax_for_cos_decay, eta_min=1e-4, last_epoch=-1)

    # Initialize model weights with a previously trained model if using warm start
    if args.warm_start and os.path.exists(model_file_path):
        model_to_train.load_state_dict(torch.load(model_file_path, map_location=device))

    # Memory bank: one randomly initialised 128-d representation per image
    all_images_mem = np.random.randn(len_train_val_set, 128)
    model_train_test_obj = PIRLModelTrainTest(
        model_to_train, device, model_file_path, all_images_mem, train_indices, val_indices, args.count_negatives,
        args.temp_parameter, args.beta, args.only_train
    )
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    for epoch_no in range(args.cont_epoch, args.cont_epoch + epochs):
        train_loss, train_acc, val_loss, val_acc = model_train_test_obj.train(
            sgd_optimizer, epoch_no, params_max_norm=4,
            train_data_loader=train_loader, val_data_loader=val_loader,
            no_train_samples=len(train_indices), no_val_samples=len(val_indices)
        )
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        scheduler.step()

    # Log train-test results
    log_experiment(args.experiment_name, args.epochs, train_losses, val_losses, train_accs, val_accs)
144 |
--------------------------------------------------------------------------------
/random_seed_setter.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 |
def set_random_generators_seed(seed=1008):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and, when available, every CUDA device),
    and put cuDNN into deterministic mode, so experiments are reproducible.

    :param seed: Seed value. Defaults to 1008, the fixed seed this project's experiments use
                 (the original code insisted it must not be changed); keep the default to
                 reproduce those results.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
20 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Pytorch Implementation of Pre-text Invariant Representation Learning
2 | This repository contains a PyTorch implementation of the Pretext-Invariant Representation Learning (PIRL)
3 | algorithm on the STL10 dataset. PIRL was introduced by Misra et al.; the paper can be found [here](https://arxiv.org/abs/1912.01991).
4 |
5 | ## What is PIRL and why is it useful
6 | Pretext-invariant representation learning (PIRL) is a self-supervised learning algorithm that uses contrastive
7 | learning to learn visual representations such that the original and a transformed version of the same image have similar
8 | representations, while being different from those of other images, thus achieving invariance to the transformation.
9 |
10 | In the paper, the authors primarily focus on the jigsaw-puzzle transformation.
11 |
12 | ## Loss Function and slight modification
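13 | The CNN used for representation learning is trained using the NCE (Noise Contrastive Estimation) technique.
14 | NCE models the probability of the event that (I, I_t) (the original and the transformed image) originate from the same
15 | data distribution, i.e.
16 | $$h(v_I, v_{I_t}) = \frac{\exp\left(s(v_I, v_{I_t})/\tau\right)}{\exp\left(s(v_I, v_{I_t})/\tau\right) + \sum_{I' \in \mathcal{D}_N} \exp\left(s(v_{I_t}, v_{I'})/\tau\right)}$$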
17 |
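18 | where s(·, ·) is the cosine similarity, v_I and v_{I_t} are the deep representations of the original and transformed image respectively, τ is a temperature parameter and D_N is a set of negative (other) images.
19 | The final NCE loss is given as:
20 | $$L_{NCE}(I, I_t) = -\log\left[h\left(f(v_I), g(v_{I_t})\right)\right] - \sum_{I' \in \mathcal{D}_N} \log\left[1 - h\left(g(v_{I_t}), f(v_{I'})\right)\right]$$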
21 | where f(.) and g(.) are linear function heads.
22 |
23 | ## Slight Modification
24 | Instead of using the NCE loss, this implementation directly minimizes
25 | the negative log of the probability given by the first equation above (with inputs f(v_I) and g(v_{I_t})).
26 |
27 | ## Dataset Used
28 | The implementation uses STL10 dataset, which can be downloaded from [here](http://ai.stanford.edu/~acoates/stl10/)
29 | #### Dataset setup steps
30 | ```
31 | 1. Download raw data from above link to ./raw_stl10/
32 | 2. Run stl10_data_load.py. This will save three directories train, test and unlabelled in ./stl10_data/
33 | ```
34 |
35 | ## Training and evaluation steps
36 | 1. Run pirl_stl_train_test.py for self-supervised (unsupervised) pre-training, for example:
37 | ```
38 | python pirl_stl_train_test.py --model-type res18 --batch-size 128 --lr 0.1 --experiment-name exp
39 | ```
40 | 2. Run train_stl_after_ssl.py to fine-tune the model parameters obtained from self-supervised learning, for example:
41 | ```
42 | python train_stl_after_ssl.py --model-type res18 --batch-size 128 --lr 0.1 --patience-for-lr-decay 4 --full-fine-tune True --pirl-model-name <pirl-model-name>
43 | ```
44 |
45 | ## Results
46 | After training the CNN model with PIRL, the following experiments were performed to evaluate how well the learnt model weights
47 | transfer to a classification problem in a limited-data scenario.
48 |
49 | Fine tuning strategy | Val Classification Accuracy
50 | --- | ---
51 | Only softmax layer is fine tuned | 50.50
52 | Full model is fine tuned | 67.87
53 |
54 | # References
55 | 1. PIRL paper: https://arxiv.org/abs/1912.01991
56 | 2. STL 10 dataset: http://ai.stanford.edu/~acoates/stl10/
57 | 3. Data loading code for STL 10: https://github.com/mttk/STL10
58 |
--------------------------------------------------------------------------------
/stl10_data_load.py:
--------------------------------------------------------------------------------
1 | """
2 | Script obtained from: https://github.com/mttk/STL10/blob/master/stl10_input.py
3 | """
4 |
5 | from __future__ import print_function
6 |
7 | import os, sys, tarfile, errno
8 | import numpy as np
9 | from PIL import Image
10 |
# Bind ``urllib`` to the module that provides urlretrieve on both Python 2 and Python 3.
if sys.version_info < (3, 0, 0):
    import urllib
else:
    import urllib.request as urllib  # ugly but works

print(sys.version_info)
17 |
# Geometry of one STL-10 image: 96x96 pixels, 3 colour channels.
HEIGHT: int = 96
WIDTH: int = 96
DEPTH: int = 3

# Size of a single image in bytes.
SIZE: int = HEIGHT * WIDTH * DEPTH

# Raw-archive location and output directories for the decoded images.
DATA_DIR: str = './raw_stl10/'
STL10_TRAIN_IMG_DIR: str = './stl10_data/train/'
STL10_TEST_IMG_DIR: str = './stl10_data/test/'
STL10_UNLABELLED_IMG_DIR: str = './stl10_data/unlabelled/'

# URL of the binary dataset archive.
DATA_URL: str = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'

# Binary image files inside the extracted archive.
TRAIN_DATA_PATH: str = os.path.join(DATA_DIR, 'stl10_binary/train_X.bin')
TEST_DATA_PATH: str = os.path.join(DATA_DIR, 'stl10_binary/test_X.bin')
UNLABELLED_DATA_PATH: str = os.path.join(DATA_DIR, 'stl10_binary/unlabeled_X.bin')

# Binary label files (only the train and test splits have labels).
TRAIN_LABEL_PATH: str = os.path.join(DATA_DIR, 'stl10_binary/train_y.bin')
TEST_LABEL_PATH: str = os.path.join(DATA_DIR, 'stl10_binary/test_y.bin')
43 |
44 |
def read_labels(path_to_labels):
    """
    Read an STL-10 label file.

    :param path_to_labels: Path to the binary file holding one uint8 label per image.
    :return: 1-D numpy uint8 array of labels.
    """
    with open(path_to_labels, 'rb') as label_file:
        return np.fromfile(label_file, dtype=np.uint8)
53 |
54 |
def read_all_images(path_to_data):
    """
    Read every image from an STL-10 binary image file.

    Each image is stored as 3 * 96 * 96 uint8 values: the full red plane, then green,
    then blue, each plane in column-major order.

    :param path_to_data: Path to the binary image file.
    :return: uint8 array of shape (n_images, 96, 96, 3) — (N, row, column, channel), the
             layout expected by e.g. matplotlib.imshow.
    """
    with open(path_to_data, 'rb') as data_file:
        raw_bytes = np.fromfile(data_file, dtype=np.uint8)

    # (N, channel, column, row): numpy infers N; because planes are column-major the
    # fastest-varying (last) axis is the row.
    planes = raw_bytes.reshape(-1, 3, 96, 96)

    # Reorder to (N, row, column, channel).
    return planes.transpose(0, 3, 2, 1)
82 |
83 |
def save_image(image, name):
    """
    Write an image array to disk as JPEG.

    :param image: array accepted by ``PIL.Image.fromarray``
    :param name: output path without extension; '.jpeg' is appended
    """
    pil_image = Image.fromarray(image)
    pil_image.save(name + '.jpeg')
87 |
88 |
def download_and_extract():
    """
    Download the STL-10 binary archive into DATA_DIR if it is missing, then extract it there.

    The download is written to '<archive>.part' and renamed only after it completes, so an
    interrupted run is not mistaken for a finished archive on the next call. Extraction runs
    on every call.
    :return: None
    """
    dest_directory = DATA_DIR
    os.makedirs(dest_directory, exist_ok=True)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # urlretrieve passes total_size <= 0 when the size is unknown; skip the
            # percentage then instead of printing a negative value.
            if total_size > 0:
                percent = min(float(count * block_size) / float(total_size) * 100.0, 100.0)
                sys.stdout.write('\rDownloading %s %.2f%%' % (filename, percent))
                sys.stdout.flush()

        partial_path = filepath + '.part'
        urllib.urlretrieve(DATA_URL, partial_path, reporthook=_progress)
        os.replace(partial_path, filepath)
        print()  # terminate the carriage-return progress line
        print('Downloaded', filename)

    # Close the archive deterministically; use the 'data' extraction filter when this Python
    # provides it, so archive members cannot escape dest_directory.
    with tarfile.open(filepath, 'r:gz') as archive:
        if hasattr(tarfile, 'data_filter'):
            archive.extractall(dest_directory, filter='data')
        else:
            archive.extractall(dest_directory)
108 |
109 |
def save_images(images_dir, images, labels):
    """
    Save each image as '<images_dir>[/<label>]/<index>.jpeg'.

    :param images_dir: root output directory
    :param images: iterable of image arrays; the position in the iteration is used as file name
    :param labels: sequence indexed in step with ``images`` whose values become sub-directory
                   names, or None to write every image directly into ``images_dir``
    """
    print("Saving images to disk")
    created_dirs = set()  # directories already ensured, so makedirs runs once per directory
    for i, image in enumerate(images):
        if labels is None:
            directory = images_dir
        else:
            directory = os.path.join(images_dir, str(labels[i]))

        if directory not in created_dirs:
            # exist_ok tolerates an existing directory; any other OSError (permissions,
            # a file in the way, ...) now propagates instead of being silently swallowed.
            os.makedirs(directory, exist_ok=True)
            created_dirs.add(directory)

        save_image(image, os.path.join(directory, str(i)))
129 |
130 |
if __name__ == "__main__":

    # Download the archive if missing, then extract it.
    download_and_extract()

    # (image file, label file or None, output directory) for each split, processed in order.
    split_specs = (
        (TRAIN_DATA_PATH, TRAIN_LABEL_PATH, STL10_TRAIN_IMG_DIR),
        (TEST_DATA_PATH, TEST_LABEL_PATH, STL10_TEST_IMG_DIR),
        (UNLABELLED_DATA_PATH, None, STL10_UNLABELLED_IMG_DIR),
    )

    for data_path, label_path, out_dir in split_specs:
        images = read_all_images(data_path)
        print(images.shape)

        if label_path is None:
            # The unlabelled split has no label file; images go straight into out_dir.
            labels = None
        else:
            labels = read_labels(label_path)
            print(labels.shape)

        save_images(out_dir, images, labels)
160 |
--------------------------------------------------------------------------------
/submit_job.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --time=8:00:00
#SBATCH --gres=gpu:p40:1
#SBATCH --mem=14000
#SBATCH --job-name=train_non_linear_pirl_on_stl
#SBATCH --mail-type=END
#SBATCH --mail-user=ab8700@nyu.edu
#SBATCH --output=slurm_%j.out

# Launch PIRL training on STL-10. Flag semantics are defined in pirl_stl_train_test.py.
# NOTE(review): if the boolean flags there use argparse type=bool, any value (even False)
# parses as True -- confirm before passing False.
python pirl_stl_train_test.py \
    --model-type 'res18' \
    --lr 0.1 \
    --tmax-for-cos-decay 50 \
    --warm-start True \
    --only-train True \
    --non-linear-head True \
    --cont-epoch 101 \
    --experiment-name e8_stl_pirl
14 |
15 |
--------------------------------------------------------------------------------
/train_stl_after_ssl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import torchvision
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from torch import optim
11 | from torch.optim.lr_scheduler import ReduceLROnPlateau
12 | from torch.utils.data import DataLoader, ConcatDataset
13 | from torch.utils.data import SubsetRandomSampler
14 |
15 | from common_constants import PAR_WEIGHTS_DIR
16 | from dataset_helpers import def_train_transform_stl, def_test_transform, get_file_paths_n_labels, hflip_data_transform, \
17 | darkness_jitter_transform, lightness_jitter_transform, rotations_transform, all_in_transform
18 | from experiment_logger import log_experiment
19 | from get_dataset import GetSTL10Data
20 | from models import classifier_resnet, pirl_resnet
21 | from network_helpers import copy_weights_between_models
22 | from random_seed_setter import set_random_generators_seed
23 | from train_test_helper import ModelTrainTest
24 |
if __name__ == '__main__':

    def _str2bool(value):
        """Parse a boolean command-line value.

        argparse's ``type=bool`` calls ``bool(str)``, so any non-empty string -- including
        'False' -- would become True; this parser accepts the usual true/false spellings.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))

    # Training arguments
    parser = argparse.ArgumentParser(description='STL10 train test script on top of a PIRL pre-trained model')
    parser.add_argument('--model-type', type=str, default='res18',
                        help='The network architecture to employ as backbone')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='Weight decay constant (default: 5e-4)')
    parser.add_argument('--patience-for-lr-decay', type=int, default=10)
    parser.add_argument('--full-fine-tune', type=_str2bool, default=False)
    parser.add_argument('--experiment-name', type=str, default='e1_pirl_sup_')
    # Required: the pre-trained weights path is built from it (os.path.join fails on None).
    parser.add_argument('--pirl-model-name', type=str, required=True)
    args = parser.parse_args()

    # Set random number generation seed for all packages that generate random numbers
    set_random_generators_seed()

    # Identify device for holding tensors and carrying out computations
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Path of the trained SSL model, and path where the classification model will be saved
    pirl_file_path = os.path.join(PAR_WEIGHTS_DIR, args.pirl_model_name)
    model_file_path = os.path.join(PAR_WEIGHTS_DIR, args.experiment_name)

    # Get train-val file paths and labels for STL10
    par_train_val_images_dir = './stl10_data/train'
    train_val_file_paths, train_val_labels = get_file_paths_n_labels(par_train_val_images_dir)
    print('Train val file paths count', len(train_val_file_paths))
    print('Train val labels count', len(train_val_labels))

    # Random split: the first count_train shuffled indices train, the rest validate.
    len_train_val_set = len(train_val_file_paths)
    train_val_indices = list(range(len_train_val_set))
    np.random.shuffle(train_val_indices)

    count_train = 4200

    train_indices = train_val_indices[:count_train]
    val_indices = train_val_indices[count_train:]

    train_val_file_paths = np.array(train_val_file_paths)
    train_val_labels = np.array(train_val_labels)
    train_file_paths, train_labels = train_val_file_paths[train_indices], train_val_labels[train_indices]
    val_file_paths, val_labels = train_val_file_paths[val_indices], train_val_labels[val_indices]

    # Training set: the same training images under six transforms, concatenated.
    train_set = ConcatDataset(
        [GetSTL10Data(train_file_paths, train_labels, def_train_transform_stl),
         GetSTL10Data(train_file_paths, train_labels, hflip_data_transform),
         GetSTL10Data(train_file_paths, train_labels, darkness_jitter_transform),
         GetSTL10Data(train_file_paths, train_labels, lightness_jitter_transform),
         GetSTL10Data(train_file_paths, train_labels, rotations_transform),
         GetSTL10Data(train_file_paths, train_labels, all_in_transform)]
    )
    val_set = GetSTL10Data(val_file_paths, val_labels, def_test_transform)

    # shuffle=True: without it every epoch walks the concatenated datasets in the same order,
    # so consecutive batches all come from a single transform.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=100, num_workers=8)

    # Print the shapes of one sample batch from the train loader
    dataiter = iter(train_loader)
    X, y = next(dataiter)
    print(X.size())
    print(y.size())

    num_outputs = 10
    epochs = args.epochs
    lr = args.lr
    weight_decay_const = args.weight_decay

    # Sample counts used to turn correct-prediction counts into accuracies. The train set
    # holds each training image once per transform, so use len(train_set), not len(train_indices).
    no_train_samples = len(train_set)
    no_val_samples = len(val_indices)

    def run_training(model_train_test_obj, optimizer, lr_scheduler):
        """Train for `epochs` epochs, stepping the plateau scheduler on validation loss.

        Returns the per-epoch (train_losses, val_losses, train_accs, val_accs) lists.
        """
        curves = ([], [], [], [])
        for epoch_no in range(epochs):
            epoch_stats = model_train_test_obj.train(
                optimizer, epoch_no, params_max_norm=4,
                train_data_loader=train_loader, val_data_loader=val_loader,
                no_train_samples=no_train_samples, no_val_samples=no_val_samples
            )
            train_loss, train_acc, val_loss, val_acc = epoch_stats
            for curve, value in zip(curves, (train_loss, val_loss, train_acc, val_acc)):
                curve.append(value)
            lr_scheduler.step(val_loss)
        return curves

    # Define model_to_train and inherit weights from pre-trained SSL model
    model_to_train = classifier_resnet(args.model_type, num_classes=num_outputs)
    pirl_model = pirl_resnet(args.model_type)
    pirl_model.load_state_dict(torch.load(pirl_file_path, map_location=device))
    weight_copy_success = copy_weights_between_models(pirl_model, model_to_train)

    if not weight_copy_success:
        # Non-zero exit status so job schedulers see the failure (plain exit() returned 0).
        raise SystemExit('Weight copy between SSL and classification net failed. Pls check !!')

    # Freeze parameters whose names start with 'resnet_' (assumed to be the backbone --
    # depends on naming in classifier_resnet) so only the remaining layers train.
    for name, param in model_to_train.named_parameters():
        if name.startswith('resnet_'):
            param.requires_grad = False

    # Set device on which training is done. Plus optimizer to use.
    model_to_train.to(device)
    sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay_const)
    scheduler = ReduceLROnPlateau(sgd_optimizer, 'min', patience=args.patience_for_lr_decay,
                                  verbose=True, min_lr=1e-5)

    # Train the unfrozen layers and log the results
    train_losses, val_losses, train_accs, val_accs = run_training(
        ModelTrainTest(model_to_train, device, model_file_path), sgd_optimizer, scheduler)
    log_experiment(args.experiment_name + '_lin_clf', args.epochs, train_losses, val_losses, train_accs, val_accs)

    # Optionally unfreeze everything and fine-tune the whole network
    if args.full_fine_tune:
        for param in model_to_train.parameters():
            param.requires_grad = True

        # Fresh optimizer and scheduler; the fine-tuning learning rate is fixed at 0.01.
        sgd_optimizer = optim.SGD(model_to_train.parameters(), lr=0.01, momentum=0.9, weight_decay=weight_decay_const)
        scheduler = ReduceLROnPlateau(sgd_optimizer, 'min', patience=args.patience_for_lr_decay,
                                      verbose=True, min_lr=1e-5)

        train_losses, val_losses, train_accs, val_accs = run_training(
            ModelTrainTest(model_to_train, device, model_file_path), sgd_optimizer, scheduler)
        log_experiment(args.experiment_name + '_full_ft', args.epochs, train_losses, val_losses, train_accs, val_accs)
176 |
--------------------------------------------------------------------------------
/train_test_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 |
6 | from torch.nn.utils import clip_grad_norm_
7 |
8 | from pirl_loss import loss_pirl, get_img_pair_probs
9 |
10 |
def get_count_correct_preds(network_output, target):
    """Count rows of ``network_output`` whose arg-max along dim 1 equals ``target``.

    Returned as a float tensor so that a later accuracy division is not rounded to int.
    """
    predicted_classes = torch.max(network_output, 1).indices
    matches = predicted_classes == target
    return matches.sum().float()
17 |
18 |
def get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr):
    """
    Count pretext-task "correct" predictions: items whose mean of the two probabilities is >= 0.5.

    :param img_pair_probs_arr: Prob vector of batch of images I and I_t to belong to same data distribution.
    :param img_mem_rep_probs_arr: Prob vector of batch of I and mem_bank_rep of I to belong to same data distribution
    :return: the count as a Python float
    """
    mean_probs = (img_pair_probs_arr + img_mem_rep_probs_arr) / 2
    hits = mean_probs >= 0.5
    return hits.sum().float().item()
30 |
31 |
class PIRLModelTrainTest():
    """Train / validate a PIRL network against a memory bank of per-image representations.

    ``all_images_mem`` holds one representation row per image. For each batch, randomly
    chosen rows of *other* images from the same split act as negatives for
    ``get_img_pair_probs``. After the batch, the batch images' rows are updated as
    ``beta * old + (1 - beta) * vi_batch``.
    """

    def __init__(self, network, device, model_file_path, all_images_mem, train_image_indices,
                 val_image_indices, count_negatives, temp_parameter, beta, only_train=False, threshold=1e-4):
        super(PIRLModelTrainTest, self).__init__()
        self.network = network
        self.device = device
        self.model_file_path = model_file_path
        self.threshold = threshold  # minimum val-loss improvement that triggers a checkpoint
        self.train_loss = 1e9
        self.val_loss = 1e9  # best validation loss seen so far
        self.all_images_mem = torch.tensor(all_images_mem, dtype=torch.float).to(device)
        self.train_image_indices = train_image_indices.copy()
        self.val_image_indices = val_image_indices.copy()
        self.count_negatives = count_negatives
        self.temp_parameter = temp_parameter
        self.beta = beta
        self.only_train = only_train

    def _sample_negatives(self, pool_indices, batch_img_indices):
        """Return up to ``count_negatives`` shuffled indices from ``pool_indices`` not in the batch."""
        # Convert batch indices to plain numpy ints: a set of tensor elements hashes by object
        # identity, so a set difference against it would exclude nothing.
        if torch.is_tensor(batch_img_indices):
            batch_ids = batch_img_indices.detach().cpu().numpy()
        else:
            batch_ids = np.asarray(batch_img_indices)
        candidates = np.setdiff1d(np.asarray(pool_indices), batch_ids)
        np.random.shuffle(candidates)
        return candidates[:self.count_negatives]

    def _update_memory_bank(self, batch_img_indices, vi_batch):
        """In-place update: mem[idx] = beta * mem[idx] + (1 - beta) * vi_batch (no autograd history)."""
        with torch.no_grad():
            self.all_images_mem[batch_img_indices] = (
                self.beta * self.all_images_mem[batch_img_indices]
                + (1 - self.beta) * vi_batch.detach()
            )

    def _forward_batch(self, data_batch, batch_img_indices, pool_indices):
        """Run the network on one batch and compute both PIRL probability vectors.

        ``data_batch[0]`` is the image batch I and ``data_batch[1]`` the transformed batch I_t
        (jigsaw patches). Returns (batch_img_indices on device, vi_batch, img_pair_probs_arr,
        img_mem_rep_probs_arr).
        """
        i_batch = data_batch[0].to(self.device)
        i_t_patches_batch = data_batch[1].to(self.device)
        mn_indices = self._sample_negatives(pool_indices, batch_img_indices)
        batch_img_indices = batch_img_indices.to(self.device)

        vi_batch, vi_t_batch = self.network(i_batch, i_t_patches_batch)

        # Memory-bank rows of the negatives and of the batch images themselves
        # (read before this batch's memory update).
        mn_arr = self.all_images_mem[mn_indices]
        mem_rep_of_batch_imgs = self.all_images_mem[batch_img_indices]

        # Prob for I and I_t to belong to the same data distribution
        img_pair_probs_arr = get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, self.temp_parameter)
        # Prob for I and the memory-bank representation of I to belong to the same data distribution
        img_mem_rep_probs_arr = get_img_pair_probs(vi_batch, mem_rep_of_batch_imgs, mn_arr, self.temp_parameter)

        return batch_img_indices, vi_batch, img_pair_probs_arr, img_mem_rep_probs_arr

    def train(self, optimizer, epoch, params_max_norm, train_data_loader, val_data_loader,
              no_train_samples, no_val_samples):
        """Run one training epoch, checkpointing and (unless only_train) validating.

        A snapshot is saved every 10th epoch. With validation enabled, the model is also saved
        to ``model_file_path`` when the val loss improves by more than ``threshold``.
        :return: (train_loss, train_acc, val_loss, val_acc); val values are 0.0 when only_train.
        """
        self.network.train()
        train_loss, correct, cnt_batches = 0, 0, 0

        for data_batch, batch_img_indices in train_data_loader:
            optimizer.zero_grad()
            batch_img_indices, vi_batch, img_pair_probs_arr, img_mem_rep_probs_arr = self._forward_batch(
                data_batch, batch_img_indices, self.train_image_indices)

            # Compute loss => back-prop gradients => Update weights
            loss = loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr)
            loss.backward()

            clip_grad_norm_(self.network.parameters(), params_max_norm)
            optimizer.step()

            # Update running loss and no of pseudo correct predictions for epoch
            correct += get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr)
            train_loss += loss.item()
            cnt_batches += 1

            self._update_memory_bank(batch_img_indices, vi_batch)

            del vi_batch, img_pair_probs_arr, img_mem_rep_probs_arr, loss

        train_loss /= max(cnt_batches, 1)

        if epoch % 10 == 0:
            torch.save(self.network.state_dict(), self.model_file_path + '_epoch_{}'.format(epoch))

        if self.only_train is False:
            val_loss, val_acc = self.test(epoch, val_data_loader, no_val_samples)

            if val_loss < self.val_loss - self.threshold:
                self.val_loss = val_loss
                torch.save(self.network.state_dict(), self.model_file_path)

        else:
            val_loss, val_acc = 0.0, 0.0

        train_acc = correct / no_train_samples

        print('\nAfter epoch {} - Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, train_loss, correct, no_train_samples, 100. * correct / no_train_samples))

        return train_loss, train_acc, val_loss, val_acc

    def test(self, epoch, test_data_loader, no_test_samples):
        """Evaluate the pretext loss/accuracy on a loader without building autograd graphs.

        The memory bank rows of the evaluated images are still updated.
        :return: (mean per-batch loss, pseudo accuracy)
        """
        self.network.eval()
        test_loss, correct, cnt_batches = 0, 0, 0

        with torch.no_grad():
            for data_batch, batch_img_indices in test_data_loader:
                batch_img_indices, vi_batch, img_pair_probs_arr, img_mem_rep_probs_arr = self._forward_batch(
                    data_batch, batch_img_indices, self.val_image_indices)

                loss = loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr)

                correct += get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr)
                test_loss += loss.item()
                cnt_batches += 1

                self._update_memory_bank(batch_img_indices, vi_batch)

                del vi_batch, img_pair_probs_arr, img_mem_rep_probs_arr, loss

        test_loss /= max(cnt_batches, 1)
        test_acc = correct / no_test_samples
        print('\nAfter epoch {} - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, test_loss, correct, no_test_samples, 100. * correct / no_test_samples))

        return test_loss, test_acc
186 |
187 |
class ModelTrainTest():
    """Supervised train/validate loop for a classifier scored with ``F.nll_loss``.

    The network output is fed directly to ``F.nll_loss`` (presumably log-probabilities --
    confirm against the model). After each training epoch the network is evaluated, and its
    ``state_dict`` is saved to ``model_file_path`` when the validation loss improves on the
    best seen so far by more than ``threshold``.
    """

    def __init__(self, network, device, model_file_path, threshold=1e-4):
        super(ModelTrainTest, self).__init__()
        self.network = network
        self.device = device
        self.model_file_path = model_file_path
        self.threshold = threshold
        self.train_loss = 1e9
        self.val_loss = 1e9  # best validation loss seen so far

    def train(self, optimizer, epoch, params_max_norm, train_data_loader, val_data_loader,
              no_train_samples, no_val_samples):
        """Run one training epoch, then validate.

        :return: (train_loss, train_acc, val_loss, val_acc). train_loss is the mean per-batch
                 loss; the accuracies are Python floats, correct counts divided by the given
                 sample counts.
        """
        self.network.train()
        train_loss, correct, cnt_batches = 0.0, 0.0, 0

        for data, target in train_data_loader:
            data, target = data.to(self.device), target.to(self.device)

            optimizer.zero_grad()
            output = self.network(data)

            loss = F.nll_loss(output, target)
            loss.backward()

            clip_grad_norm_(self.network.parameters(), params_max_norm)
            optimizer.step()

            # .item() keeps the running count a Python float rather than a device tensor,
            # so the returned accuracy is a plain number.
            correct += get_count_correct_preds(output, target).item()
            train_loss += loss.item()
            cnt_batches += 1

            del data, target, output, loss

        train_loss /= max(cnt_batches, 1)
        val_loss, val_acc = self.test(epoch, val_data_loader, no_val_samples)

        if val_loss < self.val_loss - self.threshold:
            self.val_loss = val_loss
            torch.save(self.network.state_dict(), self.model_file_path)

        train_acc = correct / no_train_samples

        print('\nAfter epoch {} - Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, train_loss, correct, no_train_samples, 100. * correct / no_train_samples))

        return train_loss, train_acc, val_loss, val_acc

    def test(self, epoch, test_data_loader, no_test_samples):
        """Evaluate on a loader without tracking gradients.

        :return: (per-sample mean loss, accuracy as a Python float)
        """
        self.network.eval()
        test_loss = 0.0
        correct = 0.0

        with torch.no_grad():
            for data, target in test_data_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.network(data)
                # reduction='sum' replaces the deprecated size_average=False (sum up batch loss)
                test_loss += F.nll_loss(output, target, reduction='sum').item()

                correct += get_count_correct_preds(output, target).item()

                del data, target, output

        test_loss /= no_test_samples
        test_acc = correct / no_test_samples
        print('\nAfter epoch {} - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, test_loss, correct, no_test_samples, 100. * correct / no_test_samples))

        return test_loss, test_acc
256 |
if __name__ == '__main__':
    # Smoke check: count "correct" pretext predictions for 256 random (not normalised) scores.
    pair_scores = torch.randn((256,))
    mem_rep_scores = torch.randn((256,))
    print(get_count_correct_preds_pretext(pair_scores, mem_rep_scores))
--------------------------------------------------------------------------------