├── datasets ├── __init__.py └── DatasetPairwiseTriplets.py ├── figures └── teaser.png ├── .gitignore ├── util ├── read_hdf5_data.py ├── convert_mat_files.py ├── convert_brown_store.py ├── convert_datasets.py ├── warmup_scheduler.py ├── encoder_heatmaps.py └── utils.py ├── layers └── spp_layer.py ├── README.md ├── networks ├── BackboneCNN.py ├── MultiscaleTransformerEncoder.py ├── positional_encodings.py ├── losses.py ├── transformer.py └── transforms.py ├── requirements.txt ├── train.py └── license.md /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeJjang/multiscale-attention-patch-matching/HEAD/figures/teaser.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .gitattributes 2 | .git/ 3 | models/ 4 | models*/ 5 | best models/ 6 | logs*/ 7 | *.pyc 8 | *.h5 9 | data/ 10 | detection/ 11 | *.rar 12 | .idea/ 13 | artifacts/ -------------------------------------------------------------------------------- /util/read_hdf5_data.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | 5 | def read_hdf5_data(fname): 6 | with h5py.File(fname, 'r') as f: 7 | 8 | keys = list(f.keys()) 9 | 10 | if len(keys) == 1: 11 | data = f[keys[0]] 12 | res = np.squeeze(np.array(data[()])) 13 | else: 14 | i = 0 15 | res = dict() 16 | for v in keys: 17 | res[v] = np.array(f[keys[i]]) 18 | i += 1 19 | 20 | return res 21 | -------------------------------------------------------------------------------- /layers/spp_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def spatial_pyramid_pool(previous_conv, num_sample, previous_conv_size, out_pool_size): 8 | for i in range(len(out_pool_size)): 9 | h_wid = int(math.ceil(previous_conv_size[0] / out_pool_size[i])) 10 | w_wid = int(math.ceil(previous_conv_size[1] / out_pool_size[i])) 11 | h_pad = int((h_wid * out_pool_size[i] - previous_conv_size[0] + 1) / 2) 12 | w_pad = int((w_wid * out_pool_size[i] - previous_conv_size[1] + 1) / 2) 13 | maxpool = nn.MaxPool2d((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad)) 14 | 15 | x = maxpool(previous_conv) 16 | if i == 0: 17 | spp = x.view(num_sample, -1) 18 | else: 19 | spp = torch.cat((spp, x.view(num_sample, -1)), 1) 20 | return spp 21 | -------------------------------------------------------------------------------- /util/convert_mat_files.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | train_mat_file_path = 'D:\\multisensor\\datasets\\Vis-Nir_grid\\Vis-Nir_grid_Train.mat' 5 | new_train_file_path = 'D:\\multisensor\\datasets\\Vis-Nir_grid\\train.hdf5' 6 | 7 | test_mat_file_path = 'D:\\multisensor\\datasets\\Vis-Nir_grid\\Vis-Nir_grid_Test.mat' 8 | new_test_file_path = 'D:\\multisensor\\datasets\\Vis-Nir_grid\\test.hdf5' 9 | 10 | with h5py.File(train_mat_file_path, 'r') as f: 11 | train_data = np.array(f.get('images/data')) 12 | train_labels = np.logical_not(np.array(f.get('images/labels')) - 1).astype(np.float64) 13 | train_set = 
np.array(f.get('images/set')) 14 | train_data = train_data.transpose(0, 2, 3, 1).reshape(train_data.shape[0], 1, train_data.shape[2], 15 | train_data.shape[3], train_data.shape[1]) 16 | with h5py.File(new_train_file_path, 'w') as f: 17 | f.create_dataset('Data', data=train_data) 18 | f.create_dataset('Labels', data=train_labels) 19 | f.create_dataset('Set', data=train_set) 20 | 21 | with h5py.File(test_mat_file_path, 'r') as f: 22 | test_data = np.array(f.get('testData')) 23 | test_labels = np.logical_not(np.array(f.get('testLabels')) - 1).astype(np.float64) 24 | test_data = test_data.transpose(0, 2, 3, 1).reshape(test_data.shape[0], 1, test_data.shape[2], test_data.shape[3], 25 | test_data.shape[1]) 26 | with h5py.File(new_test_file_path, 'w') as f: 27 | f.create_dataset('Data', data=test_data) 28 | f.create_dataset('Labels', data=test_labels) 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Paying Attention to Multiscale Feature Maps in Multimodal Image Matching 2 | 3 | ![teaser architecture fig](figures/teaser.png) 4 | 5 | We propose an attention-based approach for multimodal image patch matching using a Transformer encoder attending to the feature maps of a multiscale Siamese CNN. Our encoder is shown to efficiently aggregate multiscale image embeddings while emphasizing task-specific appearance-invariant image cues. We also introduce an attention-residual architecture, using a residual connection bypassing the encoder. This additional learning signal facilitates end-to-end training from scratch. 6 | 7 | ## System requirements 8 | * Code was developed and tested on Windows 10. 9 | * 64-bit Python 3.8.5. 10 | * PyTorch 1.7.1 or newer. 11 | * One or more NVIDIA GPUs with at least 11 GB of memory. We used three GeForce GTX 1080 Ti. 12 | * NVIDIA driver 460.89 or newer, CUDA toolkit 11.2 or newer. 13 | 14 | ## Setup 15 | Install Python dependencies using: 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Datasets 21 | The following datasets are already preprocessed for efficient training: 22 | 1. [VisNir](https://biu365-my.sharepoint.com/:f:/g/personal/kellery1_biu_ac_il/EmMhyd3UdKNIr0O8IpAUcJwBgMEYXRJS7lA2WCSk9X15vg?e=yVNVu3). 23 | 2. [En etal](https://biu365-my.sharepoint.com/:f:/g/personal/kellery1_biu_ac_il/EuQ70WqSFLNNk7JjYof0kPIBpblOCiJprqUBzSaT5Rhd8A?e=iBvv7z). 24 | 3. [UBC](https://biu365-my.sharepoint.com/:u:/g/personal/kellery1_biu_ac_il/EfcO4wg0jIFAlJKu5TRhyMYBP-Mpb6buYube1or_zV0guA?e=tCsnTt). 25 | 26 | 27 | For the UBC data, re-arrange the test folders so that each contains the appropriate test data: test_yos_not, test_lib_yos and test_lib_not hold the test sets used when training on liberty, notredame and yosemite, respectively. 28 | ## Training 29 | Run the following command: 30 | ``` 31 | python train.py --dataset-name=visnir --dataset-path= 32 | ``` 33 | 34 | For further configuration options, run the above command with `-h`. 35 | You may need to set `num_workers` to 0 in the train dataloader when training on UBC, depending on your hardware.
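## Inference example
The trained network takes a pair of single-channel patch batches and returns L2-normalized embeddings; the squared distance between them is the matching score (smaller means more similar), as used for the FPR95 evaluation. The snippet below is only a minimal sketch: the dropout value and the random tensors are placeholders for real, normalized 64x64 patches, and in practice you would first restore trained weights (see `util.utils.load_model`).
```
import torch
from networks.MultiscaleTransformerEncoder import MultiscaleTransformerEncoder

# Minimal sketch with a freshly initialized model.
# The dropout value and the random inputs are placeholders, not values used for the reported results.
net = MultiscaleTransformerEncoder(dropout=0.5)
net.eval()

x1 = torch.randn(4, 1, 64, 64)  # stands in for normalized patches from the first modality
x2 = torch.randn(4, 1, 64, 64)  # stands in for the corresponding patches from the second modality

with torch.no_grad():
    out = net(x1, x2)  # dict with L2-normalized embeddings 'Emb1' and 'Emb2'

dist = (out['Emb1'] - out['Emb2']).pow(2).sum(dim=1)  # squared L2 distance per patch pair
print(dist)
```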
36 | -------------------------------------------------------------------------------- /util/convert_brown_store.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | datasets = ['liberty', 'notredame', 'yosemite'] 5 | test_datasets_path = 'D:\\multisensor\\datasets\\brown\\patchdata\\evaluate_%s_64x64.h5' 6 | train_datasets_path = 'D:\\multisensor\\datasets\\brown\\patchdata\\full_evaluate_%s_64x64.h5' 7 | new_test_file_path = 'D:\\multisensor\\datasets\\brown\\patchdata\\%s_test_for_multisensor.hdf5' 8 | new_train_file_path = 'D:\\multisensor\\datasets\\brown\\patchdata\\%s_full_for_multisensor.hdf5' 9 | new_full_train_file_path = 'D:\\multisensor\\datasets\\brown\\patchdata\\full_for_multisensor.hdf5' 10 | 11 | 12 | def transform_dimensions(arr): 13 | samples = arr.shape[0] 14 | arr = arr.reshape(samples, 1, 64, 64).reshape(int(samples / 2), 2, 64, 64).transpose(0, 2, 3, 1) 15 | arr = np.expand_dims(arr, 1) 16 | assert arr.shape == (int(samples / 2), 1, 64, 64, 2) 17 | return arr 18 | 19 | 20 | def save_results(fpath, data, labels, set_labels): 21 | with h5py.File(fpath, 'w') as f: 22 | f.create_dataset('Data', data=data) 23 | f.create_dataset('Labels', data=labels) 24 | f.create_dataset('Set', data=set_labels) 25 | 26 | 27 | def convert_single_ds(): 28 | convert_train = True 29 | convert_test = False 30 | for dataset in datasets: 31 | if convert_test: 32 | with h5py.File(test_datasets_path % dataset, 'r') as f: 33 | pos = transform_dimensions(np.array(f.get('50000/match'))) 34 | neg = transform_dimensions(np.array(f.get('50000/non-match'))) 35 | data = np.concatenate((pos, neg)) 36 | labels = np.concatenate((np.full(pos.shape[0], 1), np.full(neg.shape[0], 0))) 37 | set_labels = np.full(labels.shape, 1) 38 | save_results(new_test_file_path % dataset, data, labels, set_labels) 39 | if convert_train: 40 | with h5py.File(train_datasets_path % dataset, 'r') as f: 41 | pos = transform_dimensions(np.array(f.get('250000/match'))) 42 | neg = transform_dimensions(np.array(f.get('250000/non-match'))) 43 | data = np.concatenate((pos, neg)) 44 | labels = np.concatenate((np.full(pos.shape[0], 1), np.full(neg.shape[0], 0))) 45 | set_labels = np.full(labels.shape, 1) 46 | save_results(new_train_file_path % dataset, data, labels, set_labels) 47 | 48 | 49 | def convert_all_ds_train(): 50 | ds_size = 500000 51 | with h5py.File(new_full_train_file_path, 'w') as f: 52 | data = f.create_dataset('Data', (ds_size * 3, 1, 64, 64, 2), maxshape=(None, 1, 64, 64, 2)) 53 | labels = f.create_dataset('Labels', (ds_size * 3,), maxshape=(None,)) 54 | set = f.create_dataset('Set', (ds_size * 3,), maxshape=(None,)) 55 | for i, ds in enumerate(datasets): 56 | with h5py.File(new_train_file_path % ds, 'r') as dsf: 57 | data[i * ds_size: (i + 1) * ds_size] = np.array(dsf.get('Data'), np.float32) 58 | labels[i * ds_size: (i + 1) * ds_size] = np.array(dsf.get('Labels'), np.float32) 59 | # set[i * ds_size: (i + 1) * ds_size] = dsf.get('Set') 60 | 61 | 62 | def main(): 63 | # convert_single_ds() 64 | convert_all_ds_train() 65 | 66 | 67 | main() 68 | -------------------------------------------------------------------------------- /networks/BackboneCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.checkpoint 5 | from torch.nn import Dropout 6 | 7 | from layers.spp_layer import spatial_pyramid_pool 8 | 9 | 10 | class 
BackboneCNN(nn.Module): 11 | 12 | def __init__(self, dropout, output_feat_map=False): 13 | super(BackboneCNN, self).__init__() 14 | 15 | self.pre_block = nn.Sequential( 16 | nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False), 17 | nn.BatchNorm2d(32, affine=False), 18 | nn.ReLU() 19 | ) 20 | 21 | self.block = nn.Sequential( 22 | 23 | nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False), 24 | nn.BatchNorm2d(32, affine=False, momentum=0.1 ** 0.5), 25 | nn.ReLU(), 26 | 27 | nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, dilation=2, bias=False), 28 | nn.BatchNorm2d(64, affine=False, momentum=0.1 ** 0.5), 29 | nn.ReLU(), 30 | 31 | nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False), 32 | nn.BatchNorm2d(64, affine=False, momentum=0.1 ** 0.5), 33 | nn.ReLU(), 34 | 35 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, dilation=2, bias=False), 36 | nn.BatchNorm2d(128, affine=False, momentum=0.1 ** 0.5), 37 | nn.ReLU(), 38 | 39 | nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), 40 | nn.BatchNorm2d(128, affine=False, momentum=0.1 ** 0.5), 41 | nn.ReLU(), 42 | 43 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False), 44 | nn.BatchNorm2d(128, affine=False, momentum=0.1 ** 0.5), 45 | nn.ReLU(), 46 | 47 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False), 48 | nn.BatchNorm2d(128, affine=False, momentum=0.1 ** 0.5), 49 | ) 50 | 51 | self.output_feat_map = output_feat_map 52 | if not output_feat_map: 53 | self.output_num = [8, 4, 2, 1] 54 | self.fc1 = nn.Sequential( 55 | nn.Linear(10880, 128) 56 | ) 57 | 58 | self.dropout = dropout 59 | 60 | return 61 | 62 | def input_norm(self, x): 63 | flat = x.reshape(x.size(0), -1) 64 | mp = torch.mean(flat, dim=1) 65 | sp = torch.std(flat, dim=1) + 1e-7 66 | return (x - mp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.detach().unsqueeze( 67 | -1).unsqueeze(-1).unsqueeze(1).expand_as(x) 68 | 69 | def forward(self, x, mode='Normalized'): 70 | batch_size = x.size(0) 71 | # conv_feats = self.block(self.input_norm(x)) 72 | conv_feats = self.pre_block(self.input_norm(x)) 73 | conv_feats = torch.utils.checkpoint.checkpoint(self.block, conv_feats) 74 | if self.output_feat_map: 75 | return conv_feats 76 | 77 | spp = spatial_pyramid_pool(conv_feats, batch_size, [int(conv_feats.size(2)), int(conv_feats.size(3))], 78 | self.output_num) 79 | 80 | spp = Dropout(self.dropout)(spp) 81 | 82 | feature_a = self.fc1(spp).reshape(batch_size, -1) 83 | 84 | if mode == 'Normalized': 85 | return F.normalize(feature_a, dim=1, p=2) 86 | else: 87 | return feature_a 88 | -------------------------------------------------------------------------------- /util/convert_datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from multiprocessing import freeze_support 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | 9 | from util import read_hdf5_data 10 | 11 | if __name__ == '__main__': 12 | freeze_support() 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | print(device) 16 | 17 | test_dir = './data/Vis-Nir_grid/' 18 | load_all_test_sets = True 19 | train_file = './data/brown/patchdata_64x64.h5' 20 | 21 | convert_train_files = True 22 | convert_test_files = True 23 | convert_patch_files = True 24 | 25 | if convert_patch_files: 26 | data = read_hdf5_data(train_file) 27 | data['liberty'] = np.reshape(data['liberty'], (data['liberty'].shape[0], 1, 64, 64), order='F') 28 | data['notredame'] = 
np.reshape(data['notredame'], (data['notredame'].shape[0], 1, 64, 64), order='F') 29 | data['yosemite'] = np.reshape(data['yosemite'], (data['yosemite'].shape[0], 1, 64, 64), order='F') 30 | 31 | with h5py.File('patchdata1' + '.h5', 'w') as f: 32 | f.create_dataset('liberty', data=data['liberty']) 33 | f.create_dataset('notredame', data=data['notredame']) 34 | f.create_dataset('yosemite', data=data['yosemite']) 35 | 36 | if convert_train_files: 37 | path, dataset_name = os.path.split(train_file) 38 | dataset_name = os.path.splitext(train_file)[0] 39 | 40 | data = read_hdf5_data(train_file) 41 | training_set_data = np.transpose(data['Data'], (0, 3, 2, 1)) 42 | training_set_labels = np.squeeze(data['Labels']) 43 | training_set_splits = np.squeeze(data['Set']) 44 | 45 | training_set_data = np.reshape(training_set_data, ( 46 | training_set_data.shape[0], 1, training_set_data.shape[1], training_set_data.shape[2], 47 | training_set_data.shape[3]), order='F') 48 | training_set_labels = 2 - training_set_labels 49 | 50 | with h5py.File(dataset_name + '.hdf5', 'w') as f: 51 | f.create_dataset('Data', data=training_set_data, compression='gzip', compression_opts=9) 52 | f.create_dataset('Labels', data=training_set_labels, compression='gzip', compression_opts=9) 53 | f.create_dataset('Set', data=training_set_splits, compression='gzip', compression_opts=9) 54 | 55 | if convert_test_files: 56 | 57 | # Load all datasets 58 | file_list = glob.glob(test_dir + "*.mat") 59 | 60 | if load_all_test_sets == False: 61 | file_list = [file_list[0]] 62 | 63 | file_list = ['./data/Vis-Nir_grid/Vis-Nir_grid_Test.mat'] 64 | 65 | test_data = dict() 66 | for f in file_list: 67 | path, dataset_name = os.path.split(f) 68 | dataset_name = os.path.splitext(dataset_name)[0] 69 | 70 | print(f) 71 | data = read_hdf5_data(f) 72 | 73 | x = np.transpose(data['testData'], (0, 3, 2, 1)) 74 | TestLabels = torch.from_numpy(2 - data['testLabels']) 75 | 76 | x = np.reshape(x, (x.shape[0], 1, x.shape[1], x.shape[2], x.shape[3]), order='F') 77 | with h5py.File(path + '/' + dataset_name[:-5] + '.hdf5', 'w') as f: 78 | f.create_dataset('Data', data=x, compression='gzip', compression_opts=9) 79 | f.create_dataset('Labels', data=TestLabels, compression='gzip', compression_opts=9) 80 | -------------------------------------------------------------------------------- /datasets/DatasetPairwiseTriplets.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import cv2 3 | import numpy as np 4 | from skimage.transform import resize 5 | from torch.utils.data import Dataset 6 | 7 | from util.utils import normalize_image 8 | 9 | 10 | class DatasetPairwiseTriplets(Dataset): 11 | 12 | def __init__(self, data, labels, batch_size, augmentations, mode, negative_mining_mode='Random'): 13 | self.pos_indices = np.squeeze(np.asarray(np.where(labels == 1))) 14 | self.neg_indices = np.squeeze(np.asarray(np.where(labels == 0))) 15 | 16 | self.pos_amount = len(self.pos_indices) 17 | self.neg_amount = len(self.neg_indices) 18 | 19 | self.data = data 20 | self.labels = labels 21 | 22 | self.batch_size = batch_size 23 | self.augmentations = augmentations 24 | 25 | self.mode = mode 26 | self.negative_mining_mode = negative_mining_mode 27 | 28 | self.channel_mean1 = data[:, :, :, 0].mean() 29 | self.channel_mean2 = data[:, :, :, 1].mean() 30 | 31 | self.data_height = data.shape[1] 32 | self.data_width = data.shape[2] 33 | 34 | self.transform = A.ReplayCompose([ 35 | A.Rotate(limit=5, 
interpolation=cv2.INTER_CUBIC, border_mode=cv2.BORDER_REFLECT_101, always_apply=False, 36 | p=0.5), 37 | A.HorizontalFlip(always_apply=False, p=0.5), 38 | A.VerticalFlip(always_apply=False, p=0.5), 39 | ]) 40 | 41 | def __len__(self): 42 | return self.data.shape[0] 43 | 44 | def __getitem__(self, index): 45 | # Select pos pairs 46 | pos_idx = np.random.randint(self.pos_amount, size=self.batch_size) 47 | 48 | pos_idx = self.pos_indices[pos_idx] 49 | pos_images = self.data[pos_idx, :, :, :].astype(np.float32) 50 | 51 | pos1 = pos_images[:, :, :, 0] 52 | pos2 = pos_images[:, :, :, 1] 53 | 54 | for i in range(0, pos_images.shape[0]): 55 | 56 | # flip LR 57 | if (np.random.uniform(0, 1) > 0.5) and self.augmentations.get("HorizontalFlip"): 58 | pos1[i,] = np.fliplr(pos1[i,]) 59 | pos2[i,] = np.fliplr(pos2[i,]) 60 | 61 | # flip UD 62 | if (np.random.uniform(0, 1) > 0.5) and self.augmentations.get("VerticalFlip"): 63 | pos1[i,] = np.flipud(pos1[i,]) 64 | pos2[i,] = np.flipud(pos2[i,]) 65 | 66 | # test augmentations 67 | if self.augmentations.get("Test"): 68 | data = self.transform(image=pos1[i, :, :]) 69 | pos1[i,] = data['image'] 70 | pos2[i,] = A.ReplayCompose.replay(data['replay'], image=pos2[i, :, :])['image'] 71 | 72 | # rotate 73 | if self.augmentations.get("Rotate90"): 74 | idx = np.random.randint(low=0, high=4, size=1)[0] # choose rotation 75 | pos1[i,] = np.rot90(pos1[i,], idx) 76 | pos2[i,] = np.rot90(pos2[i,], idx) 77 | 78 | # random crop 79 | if (np.random.uniform(0, 1) > 0.5) & self.augmentations.get("RandomCrop", {}).get('Do'): 80 | dx = np.random.uniform(self.augmentations.get("RandomCrop", {}).get('MinDx'), 81 | self.augmentations.get("RandomCrop", {}).get('MaxDx')) 82 | dy = np.random.uniform(self.augmentations.get("RandomCrop", {}).get('MinDy'), 83 | self.augmentations.get("RandomCrop", {}).get('MaxDy')) 84 | 85 | dx = dy 86 | 87 | x0 = int(dx * self.data_width) 88 | y0 = int(dy * self.data_height) 89 | 90 | pos1[i,] = resize(pos1[i, y0:, x0:], (self.data_height, self.data_width)) 91 | 92 | pos2[i,] = resize(pos2[i, y0:, x0:], (self.data_height, self.data_width)) 93 | 94 | res = dict() 95 | 96 | pos1 -= self.channel_mean1 97 | pos2 -= self.channel_mean2 98 | 99 | res['pos1'] = normalize_image(pos1) 100 | res['pos2'] = normalize_image(pos2) 101 | 102 | return res 103 | -------------------------------------------------------------------------------- /util/warmup_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import ReduceLROnPlateau 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_last_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | res = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | log = '\n LR: ' 37 | for x in res: 38 | log += repr(x) + ' ' 39 | print(log) 40 | return res 41 | else: 42 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in 43 | self.base_lrs] 44 | 45 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 46 | if epoch is None: 47 | epoch = self.last_epoch + 1 48 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 49 | if self.last_epoch <= self.total_epoch: 50 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in 51 | self.base_lrs] 52 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 53 | param_group['lr'] = lr 54 | else: 55 | if epoch is None: 56 | self.after_scheduler.step(metrics, None) 57 | else: 58 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 59 | 60 | def step(self, epoch=None, metrics=None): 61 | if type(self.after_scheduler) != ReduceLROnPlateau: 62 | if self.finished and self.after_scheduler: 63 | if epoch is None: 64 | self.after_scheduler.step(None) 65 | else: 66 | self.after_scheduler.step(epoch - self.total_epoch) 67 | self._last_lr = self.after_scheduler.get_last_lr() 68 | else: 69 | return super(GradualWarmupScheduler, self).step(epoch) 70 | else: 71 | self.step_ReduceLROnPlateau(metrics, epoch) 72 | 73 | 74 | class GradualWarmupSchedulerV2(GradualWarmupScheduler): 75 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 76 | super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler) 77 | 78 | def get_lr(self): 79 | if self.last_epoch > self.total_epoch: 80 | if self.after_scheduler: 81 | if not self.finished: 82 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 83 | self.finished = True 84 | 85 | return self.after_scheduler.get_lr() 86 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 87 | if self.multiplier == 1.0: 88 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 89 | else: 90 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in 91 | self.base_lrs] 92 | -------------------------------------------------------------------------------- /networks/MultiscaleTransformerEncoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from networks.BackboneCNN import BackboneCNN 9 | from networks.positional_encodings import prepare_2d_pos_encodings 10 | from networks.transformer import TransformerEncoderLayer, TransformerEncoder 11 | 12 | 13 | class MultiscaleTransformerEncoder(nn.Module): 14 | 15 | def __init__(self, dropout, encoder_dim=128, pos_encoding_dim=20, output_attention_weights=False): 16 | super(MultiscaleTransformerEncoder, self).__init__() 17 | 18 | self.backbone_cnn = BackboneCNN(output_feat_map=True, dropout=dropout) 19 | 20 | self.query = nn.Parameter(torch.randn(1, encoder_dim)) 21 | self.query_pos_encoding = nn.Parameter(torch.randn(1, encoder_dim)) 22 | 23 | self.pos_encoding_x = nn.Parameter(torch.randn(pos_encoding_dim, int(encoder_dim / 2))) 24 | self.pos_encoding_y = nn.Parameter(torch.randn(pos_encoding_dim, int(encoder_dim / 2))) 25 | 26 | # spp levels; first feature map will be passed in a residual connection to the output 27 | self.spp_levels = [8, 8, 4, 2, 1] 28 | # self.spp_levels = [8, 4, 2, 1] 29 | 30 | encoder_layers = 2 31 | encoder_heads = 2 32 | 33 | self.encoder_layer = TransformerEncoderLayer(d_model=encoder_dim, nhead=encoder_heads, 34 | dim_feedforward=int(encoder_dim), 35 | dropout=0.1, activation="relu", normalize_before=False) 36 | self.encoder = TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=encoder_layers) 37 | 38 | # self.SPP_FC = nn.Linear(8576, encoder_dim) # for [8,4,2,1] SPP 39 | self.SPP_FC = nn.Linear(8704, encoder_dim) # for [8,8,4,2,1] SPP 40 | self.output_attention_weights = output_attention_weights 41 | 42 | def encoder_spp(self, previous_conv, num_sample, previous_conv_size): 43 | attention_weights = [] 44 | for i in range(len(self.spp_levels)): 45 | 46 | # Pooling support 47 | h_wid = int(math.ceil(previous_conv_size[0] / self.spp_levels[i])) 48 | w_wid = int(math.ceil(previous_conv_size[1] / self.spp_levels[i])) 49 | 50 | # Padding to retain orthogonal dimensions 51 | h_pad = int((h_wid * self.spp_levels[i] - previous_conv_size[0] + 1) / 2) 52 | w_pad = int((w_wid * self.spp_levels[i] - previous_conv_size[1] + 1) / 2) 53 | 54 | # apply pooling 55 | maxpool = nn.MaxPool2d((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad)) 56 | 57 | y = maxpool(previous_conv) 58 | 59 | if i == 0: 60 | spp = y.reshape(num_sample, -1) 61 | else: 62 | pos_encoding_2d = prepare_2d_pos_encodings(self.pos_encoding_x, 63 | self.pos_encoding_y, 64 | y.shape[2], y.shape[3]) 65 | 66 | pos_encoding = pos_encoding_2d.permute(2, 0, 1) 67 | pos_encoding = pos_encoding[:, 0:y.shape[2], 0:y.shape[3]] 68 | pos_encoding = pos_encoding.reshape( 69 | (pos_encoding.shape[0], pos_encoding.shape[1] * pos_encoding.shape[2])) 70 | pos_encoding = pos_encoding.permute(1, 0).unsqueeze(1) 71 | pos_encoding = torch.cat((self.query_pos_encoding.unsqueeze(0), pos_encoding), 0) 72 | 73 | seq = y.reshape((y.shape[0], y.shape[1], y.shape[2] * y.shape[3])) 74 | seq = seq.permute(2, 0, 1) 75 | 76 | query = self.query.repeat(1, seq.shape[1], 1) 77 | seq = torch.cat((query, seq), 0) 78 | 79 | enc_output = torch.utils.checkpoint.checkpoint(self.encoder, seq, None, None, pos_encoding) 80 | 81 | cls_token 
= enc_output[0,] 82 | if self.output_attention_weights and i == 1: 83 | attention_weights = enc_output[1:].transpose(1, 0) 84 | 85 | spp = torch.cat((spp, cls_token.reshape(num_sample, -1)), 1) 86 | 87 | if self.output_attention_weights: 88 | return spp, attention_weights 89 | return spp 90 | 91 | def forward_one(self, x): 92 | 93 | activ_map = self.backbone_cnn(x) 94 | 95 | spp_result = self.encoder_spp(activ_map, x.size(0), 96 | [int(activ_map.size(2)), int(activ_map.size(3))]) 97 | if self.output_attention_weights: 98 | spp_activations = spp_result[0] 99 | attention_weights = spp_result[1] 100 | else: 101 | spp_activations = spp_result 102 | 103 | res = self.SPP_FC(spp_activations) 104 | res = F.normalize(res, dim=1, p=2) 105 | 106 | if self.output_attention_weights: 107 | return res, attention_weights 108 | return res 109 | 110 | def forward(self, x1, x2): 111 | res = dict() 112 | if not self.output_attention_weights: 113 | res['Emb1'] = self.forward_one(x1) 114 | res['Emb2'] = self.forward_one(x2) 115 | else: 116 | res['Emb1'], res['Emb1Attention'] = self.forward_one(x1) 117 | res['Emb2'], res['Emb2Attention'] = self.forward_one(x2) 118 | return res 119 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | alabaster==0.7.12 3 | albumentations==0.5.2 4 | anaconda-client==1.7.2 5 | anaconda-navigator==1.10.0 6 | anaconda-project==0.8.3 7 | argh==0.26.2 8 | argon2-cffi==20.1.0 9 | asn1crypto==1.4.0 10 | astroid==2.4.2 11 | astropy==4.0.2 12 | async-generator==1.10 13 | atomicwrites==1.4.0 14 | attrs==20.3.0 15 | autopep8==1.5.4 16 | Babel==2.8.1 17 | backcall==0.2.0 18 | backports.functools-lru-cache==1.6.1 19 | backports.shutil-get-terminal-size==1.0.0 20 | backports.tempfile==1.0 21 | backports.weakref==1.0.post1 22 | bcrypt==3.2.0 23 | beautifulsoup4==4.9.3 24 | bitarray==1.6.1 25 | bkcharts==0.2 26 | bleach==3.2.1 27 | bokeh==2.2.3 28 | boto==2.49.0 29 | Bottleneck==1.3.2 30 | brotlipy==0.7.0 31 | cachetools==4.2.0 32 | certifi==2020.6.20 33 | cffi==1.14.3 34 | chardet==3.0.4 35 | click==7.1.2 36 | cloudpickle==1.6.0 37 | clyent==1.2.2 38 | colorama==0.4.4 39 | comtypes==1.1.7 40 | conda==4.9.2 41 | conda-build==3.20.5 42 | conda-package-handling==1.7.2 43 | conda-verify==3.4.2 44 | contextlib2==0.6.0.post1 45 | cryptography==3.1.1 46 | cycler==0.10.0 47 | Cython==0.29.21 48 | cytoolz==0.11.0 49 | dask==2.30.0 50 | decorator==4.4.2 51 | defusedxml==0.6.0 52 | diff-match-patch==20200713 53 | dill==0.3.3 54 | distributed==2.30.1 55 | docutils==0.16 56 | entrypoints==0.3 57 | et-xmlfile==1.0.1 58 | fastcache==1.1.0 59 | filelock==3.0.12 60 | flake8==3.8.4 61 | Flask==1.1.2 62 | fsspec==0.8.3 63 | future==0.18.2 64 | gevent==20.9.0 65 | glob2==0.7 66 | google-auth==1.24.0 67 | google-auth-oauthlib==0.4.2 68 | GPUtil==1.4.0 69 | greenlet==0.4.17 70 | grpcio==1.34.0 71 | h5py==3.1.0 72 | HeapDict==1.0.1 73 | html5lib==1.1 74 | idna==2.10 75 | imageio==2.9.0 76 | imagesize==1.2.0 77 | imgaug==0.4.0 78 | importlib-metadata==2.0.0 79 | iniconfig==1.1.1 80 | intervaltree==3.1.0 81 | ipykernel==5.3.4 82 | ipython==7.19.0 83 | ipython-genutils==0.2.0 84 | ipywidgets==7.5.1 85 | isort==5.6.4 86 | itsdangerous==1.1.0 87 | jdcal==1.4.1 88 | jedi==0.17.1 89 | Jinja2==2.11.2 90 | joblib==1.0.0 91 | json5==0.9.5 92 | jsonschema==3.2.0 93 | jupyter==1.0.0 94 | jupyter-client==6.1.7 95 | jupyter-console==6.2.0 96 | jupyter-core==4.6.3 97 | 
jupyterlab==2.2.6 98 | jupyterlab-pygments==0.1.2 99 | jupyterlab-server==1.2.0 100 | keyring==21.4.0 101 | kiwisolver==1.3.0 102 | lazy-object-proxy==1.4.3 103 | libarchive-c==2.9 104 | llvmlite==0.34.0 105 | locket==0.2.0 106 | lxml==4.6.1 107 | Markdown==3.3.3 108 | MarkupSafe==1.1.1 109 | matplotlib==3.3.3 110 | mccabe==0.6.1 111 | menuinst==1.4.16 112 | mistune==0.8.4 113 | mkl-fft==1.2.0 114 | mkl-random==1.1.1 115 | mkl-service==2.3.0 116 | mock==4.0.2 117 | more-itertools==8.6.0 118 | mpmath==1.1.0 119 | msgpack==1.0.0 120 | multipledispatch==0.6.0 121 | navigator-updater==0.2.1 122 | nbclient==0.5.1 123 | nbconvert==6.0.7 124 | nbformat==5.0.8 125 | nest-asyncio==1.4.2 126 | networkx==2.5 127 | nltk==3.5 128 | nose==1.3.7 129 | notebook==6.1.4 130 | numba==0.51.2 131 | numexpr==2.7.1 132 | numpy==1.19.2 133 | numpydoc==1.1.0 134 | oauthlib==3.1.0 135 | olefile==0.46 136 | opencv-python==4.4.0.46 137 | opencv-python-headless==4.4.0.46 138 | openpyxl==3.0.5 139 | packaging==20.4 140 | pandas==1.2.0 141 | pandocfilters==1.4.3 142 | paramiko==2.7.2 143 | parso==0.7.0 144 | partd==1.1.0 145 | path==15.0.0 146 | pathlib2==2.3.5 147 | pathtools==0.1.2 148 | patsy==0.5.1 149 | pep8==1.7.1 150 | pexpect==4.8.0 151 | pickleshare==0.7.5 152 | Pillow==8.0.1 153 | pip==20.2.4 154 | pkginfo==1.6.1 155 | pluggy==0.13.1 156 | ply==3.11 157 | prometheus-client==0.8.0 158 | prompt-toolkit==3.0.8 159 | protobuf==3.14.0 160 | psutil==5.7.2 161 | py==1.9.0 162 | pyasn1==0.4.8 163 | pyasn1-modules==0.2.8 164 | pycodestyle==2.6.0 165 | pycosat==0.6.3 166 | pycparser==2.20 167 | pycurl==7.43.0.6 168 | pydocstyle==5.1.1 169 | pyflakes==2.2.0 170 | Pygments==2.7.2 171 | pylint==2.6.0 172 | PyNaCl==1.4.0 173 | pyodbc==4.0.0-unsupported 174 | pyOpenSSL==19.1.0 175 | pyparsing==2.4.7 176 | pyreadline==2.1 177 | pyrsistent==0.17.3 178 | PySocks==1.7.1 179 | pytest==0.0.0 180 | python-dateutil==2.8.1 181 | python-jsonrpc-server==0.4.0 182 | python-language-server==0.35.1 183 | pytz==2020.5 184 | PyWavelets==1.1.1 185 | pywin32==227 186 | pywin32-ctypes==0.2.0 187 | pywinpty==0.5.7 188 | PyYAML==5.3.1 189 | pyzmq==19.0.2 190 | QDarkStyle==2.8.1 191 | QtAwesome==1.0.1 192 | qtconsole==4.7.7 193 | QtPy==1.9.0 194 | regex==2020.10.15 195 | requests==2.24.0 196 | requests-oauthlib==1.3.0 197 | rope==0.18.0 198 | rsa==4.6 199 | Rtree==0.9.4 200 | ruamel-yaml==0.15.87 201 | scikit-image==0.18.1 202 | scikit-learn==0.23.2 203 | scipy==1.5.4 204 | seaborn==0.11.0 205 | Send2Trash==1.5.0 206 | setuptools==50.3.1.post20201107 207 | Shapely==1.7.1 208 | simplegeneric==0.8.1 209 | singledispatch==3.4.0.3 210 | sip==4.19.13 211 | six==1.15.0 212 | snowballstemmer==2.0.0 213 | sortedcollections==1.2.1 214 | sortedcontainers==2.2.2 215 | soupsieve==2.0.1 216 | Sphinx==3.2.1 217 | sphinxcontrib-applehelp==1.0.2 218 | sphinxcontrib-devhelp==1.0.2 219 | sphinxcontrib-htmlhelp==1.0.3 220 | sphinxcontrib-jsmath==1.0.1 221 | sphinxcontrib-qthelp==1.0.3 222 | sphinxcontrib-serializinghtml==1.1.4 223 | sphinxcontrib-websupport==1.2.4 224 | spyder==4.1.5 225 | spyder-kernels==1.9.4 226 | SQLAlchemy==1.3.20 227 | statsmodels==0.12.0 228 | sympy==1.6.2 229 | tables==3.6.1 230 | tabulate==0.8.7 231 | tblib==1.7.0 232 | tensorboard==2.4.0 233 | tensorboard-plugin-wit==1.7.0 234 | tensorboardX==2.1 235 | termcolor==1.1.0 236 | terminado==0.9.1 237 | testpath==0.4.4 238 | threadpoolctl==2.1.0 239 | tifffile==2020.10.1 240 | toml==0.10.1 241 | toolz==0.11.1 242 | torch==1.7.1 243 | torchaudio==0.7.2 244 | torchsummary==1.5.1 245 | 
torchvision==0.8.2 246 | tornado==6.0.4 247 | tqdm==4.56.0 248 | traitlets==5.0.5 249 | typing-extensions==3.7.4.3 250 | ujson==4.0.1 251 | unicodecsv==0.14.1 252 | urllib3==1.25.11 253 | watchdog==0.10.3 254 | wcwidth==0.2.5 255 | webencodings==0.5.1 256 | Werkzeug==1.0.1 257 | wheel==0.35.1 258 | widgetsnbextension==3.5.1 259 | win-inet-pton==1.1.0 260 | win-unicode-console==0.5 261 | wincertstore==0.2 262 | wrapt==1.11.2 263 | xlrd==1.2.0 264 | XlsxWriter==1.3.7 265 | xlwings==0.20.8 266 | xlwt==1.3.0 267 | xmltodict==0.12.0 268 | yapf==0.30.0 269 | zict==2.0.0 270 | zipp==3.4.0 271 | zope.event==4.5.0 272 | zope.interface==5.1.2 273 | -------------------------------------------------------------------------------- /networks/positional_encodings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | """ 6 | This is an implemenation of 1D, 2D, and 3D sinusodal positional encoding, 7 | being able to encode on tensors of the form 8 | (batchsize, x, ch), (batchsize, x, y, ch), and (batchsize, x, y, z, ch), where 9 | the positional encodings will be added to the ch dimension. The Attention is All You Need allowed for positional encoding in only one dimension, however, this works to extend this to 2 and 3 dimensions. 10 | """ 11 | 12 | 13 | # (batchsize, x, ch) 14 | class PositionalEncoding1D(nn.Module): 15 | def __init__(self, channels): 16 | """ 17 | :param channels: The last dimension of the tensor you want to apply pos emb to. 18 | """ 19 | super(PositionalEncoding1D, self).__init__() 20 | self.channels = channels 21 | inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 22 | self.register_buffer('inv_freq', inv_freq) 23 | 24 | def forward(self, tensor): 25 | """ 26 | :param tensor: A 3d tensor of size (batch_size, x, ch) 27 | :return: Positional Encoding Matrix of size (batch_size, x, ch) 28 | """ 29 | if len(tensor.shape) != 3: 30 | raise RuntimeError("The input tensor has to be 3d!") 31 | _, x, orig_ch = tensor.shape 32 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 33 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 34 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) 35 | emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type()) 36 | emb[:, :self.channels] = emb_x 37 | 38 | return emb[None, :, :orig_ch] 39 | 40 | 41 | # (batchsize, x, y, ch) 42 | class PositionalEncoding2D(nn.Module): 43 | def __init__(self, channels): 44 | """ 45 | :param channels: The last dimension of the tensor you want to apply pos emb to. 46 | """ 47 | super(PositionalEncoding2D, self).__init__() 48 | channels = int(np.ceil(channels / 2)) 49 | self.channels = channels 50 | inv_freq = 1. 
/ (10000 ** (torch.arange(0, channels, 2).float() / channels)) 51 | self.register_buffer('inv_freq', inv_freq) 52 | 53 | def forward(self, tensor): 54 | """ 55 | :param tensor: A 4d tensor of size (batch_size, x, y, ch) 56 | :return: Positional Encoding Matrix of size (batch_size, x, y, ch) 57 | """ 58 | if len(tensor.shape) != 4: 59 | raise RuntimeError("The input tensor has to be 4d!") 60 | _, x, y, orig_ch = tensor.shape 61 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 62 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 63 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 64 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 65 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1) 66 | emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1) 67 | emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(tensor.type()) 68 | emb[:, :, :self.channels] = emb_x 69 | emb[:, :, self.channels:2 * self.channels] = emb_y 70 | 71 | return emb[None, :, :, :orig_ch] 72 | 73 | 74 | # (batchsize, x, y, z, ch) 75 | class PositionalEncoding3D(nn.Module): 76 | def __init__(self, channels): 77 | """ 78 | :param channels: The last dimension of the tensor you want to apply pos emb to. 79 | """ 80 | super(PositionalEncoding3D, self).__init__() 81 | channels = int(np.ceil(channels / 3)) 82 | if channels % 2: 83 | channels += 1 84 | self.channels = channels 85 | inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 86 | self.register_buffer('inv_freq', inv_freq) 87 | 88 | def forward(self, tensor): 89 | """ 90 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) 91 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) 92 | """ 93 | if len(tensor.shape) != 5: 94 | raise RuntimeError("The input tensor has to be 5d!") 95 | _, x, y, z, orig_ch = tensor.shape 96 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 97 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 98 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) 99 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 100 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 101 | sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) 102 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1).unsqueeze(1) 103 | emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1) 104 | emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1) 105 | emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type()) 106 | emb[:, :, :, :self.channels] = emb_x 107 | emb[:, :, :, self.channels:2 * self.channels] = emb_y 108 | emb[:, :, :, 2 * self.channels:] = emb_z 109 | 110 | return emb[None, :, :, :, :orig_ch] 111 | 112 | 113 | def prepare_2d_pos_encodings(pos_enc_x, pos_enc_y, row_num, col_num): 114 | pos_enc_x = pos_enc_x[0:col_num].unsqueeze(0) # x=[1,..,20] 115 | pos_enc_y = pos_enc_y[0:row_num] 116 | 117 | for i in range(row_num): 118 | 119 | curr_y = pos_enc_y[i, :].unsqueeze(0).unsqueeze(0).repeat(1, col_num, 1) 120 | 121 | if i == 0: 122 | pos_encoding_2d = torch.cat((pos_enc_x, curr_y), 2) 123 | else: 124 | curr_pos_encoding_2d = torch.cat((pos_enc_x, curr_y), 2) 125 | 126 | pos_encoding_2d = torch.cat((pos_encoding_2d, curr_pos_encoding_2d), 0) 127 | 128 | return pos_encoding_2d 129 | 130 | 131 | def prepare_2d_pos_encodings_seq(pos_enc_x, 
pos_enc_y, row_num, col_num): 132 | pos_encoding_2d = prepare_2d_pos_encodings(pos_enc_x, pos_enc_y, row_num, col_num) 133 | 134 | pos_enc_2d_seq = pos_encoding_2d.permute(2, 0, 1) 135 | pos_enc_2d_seq = pos_enc_2d_seq[:, 0:row_num, 0:col_num] 136 | pos_enc_2d_seq = pos_enc_2d_seq.reshape( 137 | (pos_enc_2d_seq.shape[0], pos_enc_2d_seq.shape[1] * pos_enc_2d_seq.shape[2])) 138 | pos_enc_2d_seq = pos_enc_2d_seq.permute(1, 0).unsqueeze(1) 139 | 140 | return pos_enc_2d_seq 141 | -------------------------------------------------------------------------------- /util/encoder_heatmaps.py: -------------------------------------------------------------------------------- 1 | import ntpath 2 | import os 3 | import pathlib 4 | import warnings 5 | 6 | import GPUtil 7 | import cv2 8 | import matplotlib.image as mpimg 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | 14 | from networks.MultiscaleTransformerEncoder import MultiscaleTransformerEncoder 15 | from util.utils import load_model, normalize_image, evaluate_network 16 | 17 | warnings.filterwarnings("ignore", message="UserWarning: albumentations.augmentations.transforms.RandomResizedCrop") 18 | 19 | 20 | def display(rgb, attn1, nir_orig, attn2): 21 | fig, ax = plt.subplots(nrows=2, ncols=2) 22 | ax[0, 0].imshow(rgb) 23 | ax[0, 0].axis('off') 24 | ax[0, 0].set_title('Input Image') 25 | 26 | ax[0, 1].imshow(rgb) 27 | ax[0, 1].imshow(attn1, alpha=0.25, cmap='jet') 28 | ax[0, 1].axis('off') 29 | ax[0, 1].set_title('Attention') 30 | 31 | ax[1, 0].imshow(nir_orig, cmap="gray") 32 | ax[1, 0].axis('off') 33 | ax[1, 0].set_title('Input Image') 34 | 35 | ax[1, 1].imshow(nir_orig, cmap="gray") 36 | ax[1, 1].imshow(attn2, alpha=0.25, cmap='jet') 37 | ax[1, 1].axis('off') 38 | ax[1, 1].set_title('Attention') 39 | 40 | plt.show() 41 | 42 | 43 | def get_file_name(path): 44 | head, tail = ntpath.split(path) 45 | fname = tail or ntpath.basename(head) 46 | return fname.split('.')[0] 47 | 48 | 49 | def save_image(img, fname, out_folder): 50 | fname += '.png' 51 | mpimg.imsave(os.path.join(out_folder, fname), img) 52 | 53 | 54 | def generate_attn_heatmaps(net, imgs, outpath, device, disp=True): 55 | net.eval() 56 | rgb_attentions = [] 57 | nir_attentions = [] 58 | for rgb_path, nir_path in imgs: 59 | print('Working on:', rgb_path) 60 | 61 | dir = rgb_path.split('\\')[-2] 62 | curr_outpath = os.path.join(outpath, dir) 63 | 64 | rgb = mpimg.imread(rgb_path) 65 | rgb_gray_orig = cv2.imread(rgb_path) 66 | rgb_gray_orig = cv2.cvtColor(rgb_gray_orig, cv2.COLOR_BGR2GRAY) 67 | nir_orig = mpimg.imread(nir_path) 68 | pathlib.Path(curr_outpath).mkdir(parents=True, exist_ok=True) 69 | 70 | rgb_gray = rgb_gray_orig.copy().reshape(1, 1, rgb_gray_orig.shape[0], rgb_gray_orig.shape[1]) 71 | nir = nir_orig.copy().reshape(1, 1, nir_orig.shape[0], nir_orig.shape[1]) 72 | rgb_gray = torch.from_numpy(normalize_image(rgb_gray.astype(np.float32))) 73 | nir = torch.from_numpy(normalize_image(nir.astype(np.float32))) 74 | 75 | emb = evaluate_network(net, rgb_gray, nir, device, 800) 76 | emb1_attn = np.array(emb['Emb1Attention']).squeeze() 77 | emb2_attn = np.array(emb['Emb2Attention']).squeeze() 78 | 79 | _, emb1_attn = emb1_attn[0], emb1_attn[1:] 80 | 81 | emb1_attn = np.mean(emb1_attn.reshape(8 * 8, 128), axis=1) 82 | indices = emb1_attn.argsort()[:int(-0.9 * emb1_attn.shape[0])] 83 | emb1_attn[indices] = 0 84 | emb1_attn = emb1_attn.reshape(8, 8) 85 | emb1_attn = 255 * (emb1_attn - emb1_attn.min()) / (emb1_attn.max() - 
emb1_attn.min()) 86 | emb1_attn = np.uint8(emb1_attn) 87 | emb1_attn = cv2.resize(emb1_attn, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_CUBIC) 88 | 89 | _, emb2_attn = emb2_attn[0], emb2_attn[1:] 90 | 91 | emb2_attn = np.mean(emb2_attn.reshape(8 * 8, 128), axis=1) 92 | indices = emb2_attn.argsort()[:int(-0.9 * emb2_attn.shape[0])] 93 | emb2_attn[indices] = 0 94 | emb2_attn = emb2_attn.reshape(8, 8) 95 | emb2_attn = 255 * (emb2_attn - emb2_attn.min()) / (emb2_attn.max() - emb2_attn.min()) 96 | emb2_attn = np.uint8(emb2_attn) 97 | emb2_attn = cv2.resize(emb2_attn, (nir_orig.shape[1], nir_orig.shape[0]), interpolation=cv2.INTER_CUBIC) 98 | 99 | rgb_attentions.append(emb1_attn.copy()) 100 | nir_attentions.append(emb2_attn.copy()) 101 | if disp: 102 | display(rgb, emb1_attn, nir_orig, emb2_attn) 103 | plt.close() 104 | 105 | max_h = max([emb.shape[0] for emb in rgb_attentions]) 106 | max_w = max([emb.shape[1] for emb in rgb_attentions]) 107 | padded_rgb = np.zeros((len(rgb_attentions), max_h, max_w)) 108 | for i, emb in enumerate(rgb_attentions): 109 | padded_rgb[i, :emb.shape[0], :emb.shape[1]] = emb 110 | padded_nir = np.zeros((len(nir_attentions), max_h, max_w)) 111 | for i, emb in enumerate(nir_attentions): 112 | padded_nir[i, :emb.shape[0], :emb.shape[1]] = emb 113 | rgb_attentions = np.array(padded_rgb).max(axis=0) 114 | nir_attentions = np.array(padded_nir).max(axis=0) 115 | 116 | dpi = 80 117 | figsize = rgb.shape[1] / dpi, rgb.shape[0] / dpi 118 | fig = plt.figure(figsize=figsize, dpi=dpi) 119 | ax = fig.add_axes([0, 0, 1, 1]) 120 | ax.imshow(rgb_attentions, alpha=1, cmap='jet') 121 | plt.axis('off') 122 | plt.savefig(os.path.join(outpath, 'avg_rgb' + '_attention' + '.png')) 123 | plt.close() 124 | figsize = nir_orig.shape[1] / dpi, nir_orig.shape[0] / dpi 125 | fig = plt.figure(figsize=figsize, dpi=dpi) 126 | ax = fig.add_axes([0, 0, 1, 1]) 127 | ax.imshow(nir_attentions, alpha=1, cmap='jet') 128 | plt.axis('off') 129 | plt.savefig(os.path.join(outpath, 'avg_nir' + '_attention' + '.png')) 130 | plt.close() 131 | 132 | 133 | def read_data_fnames(data_path): 134 | fnames = [] 135 | for _, subdirs, _ in os.walk(data_path): 136 | for subdir in subdirs: 137 | subdir_fpath = os.path.join(data_path, subdir) 138 | for _, _, files in os.walk(subdir_fpath): 139 | for f in files: 140 | if not '_rgb' in f: 141 | continue 142 | fpath_rgb = os.path.join(subdir_fpath, f) 143 | fpath_nir = fpath_rgb.replace('_rgb', '_nir') 144 | fnames.append((fpath_rgb, fpath_nir)) 145 | 146 | return fnames 147 | 148 | 149 | def main(): 150 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # "cuda:0" 151 | num_gpus = torch.cuda.device_count() 152 | torch.cuda.empty_cache() 153 | GPUtil.showUtilization() 154 | 155 | print(device) 156 | 157 | models_dir_name = '../artifacts/symmetric_enc_transformer_visnir_10/models/' 158 | best_fname = 'best_model' 159 | 160 | output_attention_weights = True 161 | net = MultiscaleTransformerEncoder(output_attention_weights) 162 | 163 | net, optimizer, LowestError, StartEpoch, scheduler, LodedNegativeMiningMode = load_model(net, True, 164 | models_dir_name, 165 | best_fname, 166 | True, device) 167 | if num_gpus > 1: 168 | print("Let's use", torch.cuda.device_count(), "GPUs!") 169 | net = nn.DataParallel(net) 170 | net.to(device) 171 | 172 | outpath = "D:\\multisensor\\attentions\\" 173 | 174 | imgs = read_data_fnames("D:\\multisensor\\datasets\\Vis-Nir\\data") 175 | generate_attn_heatmaps(net, imgs, outpath, device, disp=False) 176 | 177 | 178 | if 
__name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | 9 | from util.read_hdf5_data import read_hdf5_data 10 | 11 | 12 | def load_model(net, start_best_model, models_dirname, best_filename, use_best_score, device, load_epoch=None): 13 | scheduler = None 14 | optimizer = None 15 | 16 | lowest_err = 1e5 17 | 18 | negative_mining_mode = 'Random' 19 | 20 | if start_best_model: 21 | flist = glob.glob(models_dirname + best_filename + '.pth') 22 | else: 23 | flist = glob.glob(models_dirname + "model*") 24 | 25 | if flist: 26 | flist.sort(key=os.path.getmtime) 27 | 28 | if load_epoch is not None: 29 | model_path = models_dirname + 'model_epoch_%s.pth' % load_epoch 30 | print('%s loaded' % model_path) 31 | checkpoint = torch.load(model_path) 32 | else: 33 | print(flist[-1] + ' loaded') 34 | checkpoint = torch.load(flist[-1]) 35 | 36 | if ('lowest_err' in checkpoint.keys()) and use_best_score: 37 | lowest_err = checkpoint['lowest_err'] 38 | 39 | if 'negative_mining_mode' in checkpoint.keys(): 40 | negative_mining_mode = checkpoint['negative_mining_mode'] 41 | 42 | net_dict = net.state_dict() 43 | checkpoint['state_dict'] = {k: v for k, v in checkpoint['state_dict'].items() if 44 | (k in net_dict) and (net_dict[k].shape == checkpoint['state_dict'][k].shape)} 45 | 46 | net.load_state_dict(checkpoint['state_dict'], strict=False) 47 | 48 | if 'optimizer_name' in checkpoint.keys(): 49 | optimizer = torch.optim.Adam(net.parameters()) 50 | try: 51 | optimizer = checkpoint['optimizer'] 52 | for state in optimizer.state.values(): 53 | for k, v in state.items(): 54 | if isinstance(v, torch.Tensor): 55 | state[k] = v.cuda(device) 56 | except Exception as e: 57 | print(e) 58 | print('Optimizer loading error') 59 | 60 | if ('scheduler_name' in checkpoint.keys()) and (optimizer != None): 61 | 62 | try: 63 | if checkpoint['scheduler_name'] == 'ReduceLROnPlateau': 64 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=6, verbose=True) 65 | 66 | scheduler = checkpoint['scheduler'] 67 | except Exception as e: 68 | print(e) 69 | print('Optimizer loading error') 70 | 71 | start_epoch = checkpoint['epoch'] + 1 72 | else: 73 | print('Weights file not loaded') 74 | optimizer = None 75 | start_epoch = 0 76 | 77 | print('lowest_err: ' + repr(lowest_err)[0:6]) 78 | 79 | return net, optimizer, lowest_err, start_epoch, scheduler, negative_mining_mode 80 | 81 | 82 | class MultiEpochsDataLoader(torch.utils.data.DataLoader): 83 | 84 | def __init__(self, *args, **kwargs): 85 | super().__init__(*args, **kwargs) 86 | self._DataLoader__initialized = False 87 | self.batch_sampler = _RepeatSampler(self.batch_sampler) 88 | self._DataLoader__initialized = True 89 | self.iterator = super().__iter__() 90 | 91 | def __len__(self): 92 | return len(self.batch_sampler.sampler) 93 | 94 | def __iter__(self): 95 | for i in range(len(self)): 96 | yield next(self.iterator) 97 | 98 | 99 | class _RepeatSampler(object): 100 | """ Sampler that repeats forever. 
101 | Args: 102 | sampler (Sampler) 103 | """ 104 | 105 | def __init__(self, sampler): 106 | self.sampler = sampler 107 | 108 | def __iter__(self): 109 | while True: 110 | yield from iter(self.sampler) 111 | 112 | 113 | class MyGradScaler: 114 | def __init__(self): 115 | pass 116 | 117 | def scale(self, loss): 118 | return loss 119 | 120 | def unscale_(self, optimizer): 121 | pass 122 | 123 | def step(self, optimizer): 124 | optimizer.step() 125 | 126 | def update(self): 127 | pass 128 | 129 | 130 | def save_best_model_stats(dir, epoch, test_err, test_data): 131 | content = { 132 | 'Test error': test_err, 133 | 'Epoch': epoch 134 | } 135 | for test_set in test_data: 136 | if isinstance(test_data[test_set], dict): 137 | content[f'Test set {test_set} error'] = test_data[test_set]['TestError'] 138 | fpath = os.path.join(dir, 'visnir_best_model_stats.json') 139 | with open(fpath, 'w', encoding='utf-8') as f: 140 | json.dump(content, f, ensure_ascii=False, indent=4) 141 | 142 | 143 | def FPR95Accuracy(dist_mat, labels): 144 | pos_indices = np.squeeze(np.asarray(np.where(labels == 1))) 145 | neg_indices = np.squeeze(np.asarray(np.where(labels == 0))) 146 | 147 | neg_dists = dist_mat[neg_indices] 148 | pos_dists = np.sort(dist_mat[pos_indices]) 149 | 150 | thresh = pos_dists[int(0.95 * pos_dists.shape[0])] 151 | 152 | fp = sum(neg_dists < thresh) 153 | 154 | return fp / float(neg_dists.shape[0]) 155 | 156 | 157 | def FPR95Threshold(PosDist): 158 | PosDist = PosDist.sort(dim=-1, descending=False)[0] 159 | Val = PosDist[int(0.95 * PosDist.shape[0])] 160 | 161 | return Val 162 | 163 | 164 | def normalize_image(x): 165 | return x / (255.0 / 2) 166 | 167 | 168 | def evaluate_network(net, data1, data2, device, step_size=800): 169 | with torch.no_grad(): 170 | 171 | for k in range(0, data1.shape[0], step_size): 172 | 173 | a = data1[k:(k + step_size), :, :, :] 174 | b = data2[k:(k + step_size), :, :, :] 175 | 176 | a, b = a.to(device), b.to(device) 177 | x = net(a, b) 178 | 179 | if k == 0: 180 | keys = list(x.keys()) 181 | emb = dict() 182 | for key in keys: 183 | emb[key] = np.zeros(tuple([data1.shape[0]]) + tuple(x[key].shape[1:]), dtype=np.float32) 184 | 185 | for key in keys: 186 | emb[key][k:(k + step_size)] = x[key].cpu() 187 | 188 | return emb 189 | 190 | 191 | def load_test_datasets(test_dir): 192 | file_list = glob.glob(test_dir + "*.hdf5") 193 | test_data = dict() 194 | for f in file_list: 195 | path, dataset_name = os.path.split(f) 196 | dataset_name = os.path.splitext(dataset_name)[0] 197 | 198 | data = read_hdf5_data(f) 199 | 200 | x = data['Data'].astype(np.float32) 201 | test_labels = torch.from_numpy(np.squeeze(data['Labels'])) 202 | del data 203 | 204 | x[:, :, :, :, 0] -= x[:, :, :, :, 0].mean() 205 | x[:, :, :, :, 1] -= x[:, :, :, :, 1].mean() 206 | 207 | x = normalize_image(x) 208 | x = torch.from_numpy(x) 209 | 210 | test_data[dataset_name] = dict() 211 | test_data[dataset_name]['Data'] = x 212 | test_data[dataset_name]['Labels'] = test_labels 213 | del x 214 | return test_data 215 | 216 | 217 | def load_validation_set(train_data, train_split, train_labels): 218 | val_indices = np.squeeze(np.asarray(np.where(train_split == 3))) 219 | 220 | # VALIDATION data 221 | val_labels = torch.from_numpy(train_labels[val_indices]) 222 | 223 | val_data = train_data[val_indices, :, :, :].astype(np.float32) 224 | val_data[:, :, :, :, 0] -= val_data[:, :, :, :, 0].mean() 225 | val_data[:, :, :, :, 1] -= val_data[:, :, :, :, 1].mean() 226 | val_data = torch.from_numpy(normalize_image(val_data)) 227 
| 228 | return val_data, val_labels 229 | 230 | 231 | def evaluate_test(net, test_data, device, step_size=800): 232 | samples_amount = 0 233 | total_test_err = 0 234 | for dataset_name in test_data: 235 | dataset = test_data[dataset_name] 236 | emb = evaluate_network(net, dataset['Data'][:, :, :, :, 0], dataset['Data'][:, :, :, :, 1], device, step_size) 237 | 238 | dist = np.power(emb['Emb1'] - emb['Emb2'], 2).sum(1) 239 | dataset['TestError'] = FPR95Accuracy(dist, dataset['Labels']) * 100 240 | total_test_err += dataset['TestError'] * dataset['Data'].shape[0] 241 | samples_amount += dataset['Data'].shape[0] 242 | total_test_err /= samples_amount 243 | 244 | del emb 245 | return total_test_err 246 | 247 | 248 | def evaluate_validation(net, val_data, val_labels, device): 249 | val_emb = evaluate_network(net, val_data[:, :, :, :, 0], val_data[:, :, :, :, 1], device) 250 | 251 | dist = np.power(val_emb['Emb1'] - val_emb['Emb2'], 2).sum(1) 252 | val_err = FPR95Accuracy(dist, val_labels) * 100 253 | return val_err 254 | -------------------------------------------------------------------------------- /networks/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def sim_matrix(a, b, eps=1e-8): 7 | """ 8 | added eps for numerical stability 9 | """ 10 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] 11 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 12 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 13 | sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) 14 | return sim_mt 15 | 16 | 17 | class ContrastiveLoss(nn.Module): 18 | """ 19 | Contrastive loss 20 | Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise 21 | """ 22 | 23 | def __init__(self, margin): 24 | super(ContrastiveLoss, self).__init__() 25 | self.margin = margin 26 | self.eps = 1e-9 27 | 28 | def forward(self, output1, output2, target, size_average=True): 29 | distances = (output2 - output1).pow(2).sum(1) # squared distances 30 | losses = 0.5 * (target.float() * distances + 31 | (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2)) 32 | return losses.mean() if size_average else losses.sum() 33 | 34 | 35 | class TripletLoss(nn.Module): 36 | """ 37 | Triplet loss 38 | Takes embeddings of an anchor sample, a positive sample and a negative sample 39 | """ 40 | 41 | def __init__(self, margin): 42 | super(TripletLoss, self).__init__() 43 | self.margin = margin 44 | 45 | def forward(self, anchor, positive, negative): 46 | distance_positive = (anchor - positive).pow(2).sum(1) 47 | distance_negative = (anchor - negative).pow(2).sum(1) 48 | 49 | losses = F.relu(distance_positive - distance_negative + self.margin) 50 | 51 | return losses.mean() 52 | 53 | 54 | class PairwiseLoss(nn.Module): 55 | def __init__(self): 56 | super(PairwiseLoss, self).__init__() 57 | 58 | self.mode = 'FPR' 59 | 60 | @staticmethod 61 | def forward(pos1, pos2): 62 | if (pos1.nelement() == 0) | (pos2.nelement() == 0): 63 | return 0 64 | 65 | losses = (pos1 - pos2).pow(2).sum(1).pow(.5) 66 | 67 | return losses.mean() 68 | 69 | 70 | class OnlineContrastiveLoss(nn.Module): 71 | """ 72 | Online Contrastive loss 73 | Takes a batch of embeddings and corresponding labels. 
74 | Pairs are generated using pair_selector object that take embeddings and targets and return indices of positive 75 | and negative pairs 76 | """ 77 | 78 | def __init__(self, margin, pair_selector): 79 | super(OnlineContrastiveLoss, self).__init__() 80 | self.margin = margin 81 | self.pair_selector = pair_selector 82 | 83 | def forward(self, embeddings, target): 84 | positive_pairs, negative_pairs = self.pair_selector.get_pairs(embeddings, target) 85 | if embeddings.is_cuda: 86 | positive_pairs = positive_pairs.cuda() 87 | negative_pairs = negative_pairs.cuda() 88 | positive_loss = (embeddings[positive_pairs[:, 0]] - embeddings[positive_pairs[:, 1]]).pow(2).sum(1) 89 | negative_loss = F.relu( 90 | self.margin - (embeddings[negative_pairs[:, 0]] - embeddings[negative_pairs[:, 1]]).pow(2).sum( 91 | 1).sqrt()).pow(2) 92 | loss = torch.cat([positive_loss, negative_loss], dim=0) 93 | return loss.mean() 94 | 95 | 96 | class OnlineTripletLoss(nn.Module): 97 | """ 98 | Online Triplets loss 99 | Takes a batch of embeddings and corresponding labels. 100 | Triplets are generated using triplet_selector object that take embeddings and targets and return indices of 101 | triplets 102 | """ 103 | 104 | def __init__(self, margin, triplet_selector): 105 | super(OnlineTripletLoss, self).__init__() 106 | self.margin = margin 107 | self.triplet_selector = triplet_selector 108 | 109 | def forward(self, embeddings, target): 110 | triplets = self.triplet_selector.get_triplets(embeddings, target) 111 | 112 | if embeddings.is_cuda: 113 | triplets = triplets.cuda() 114 | 115 | ap_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 1]]).pow(2).sum(1) 116 | an_distances = (embeddings[triplets[:, 0]] - embeddings[triplets[:, 2]]).pow(2).sum(1) 117 | losses = F.relu(ap_distances - an_distances + self.margin) 118 | 119 | return losses.mean(), len(triplets) 120 | 121 | 122 | class OnlineHardNegativeMiningTripletLoss(nn.Module): 123 | """ 124 | Online Triplets loss 125 | Takes a batch of embeddings and corresponding labels. 
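For the online variants above, the selector object only needs to expose `get_pairs(embeddings, target)` and return index pairs, as the docstring notes. Below is a toy, assumption-only selector (every same-label pair is positive, every cross-label pair is negative) used to drive `OnlineContrastiveLoss`; it is a sketch for illustration, not the project's actual mining strategy.

```python
import torch

from networks.losses import OnlineContrastiveLoss


class AllPairsSelector:
    """Toy selector: same-label pairs are positives, cross-label pairs are negatives."""

    def get_pairs(self, embeddings, target):
        n = target.shape[0]
        idx = torch.combinations(torch.arange(n), r=2)  # all index pairs (i, j), i < j
        same = target[idx[:, 0]] == target[idx[:, 1]]
        return idx[same], idx[~same]


embeddings = torch.randn(16, 128)  # illustrative batch of embeddings
labels = torch.randint(0, 2, (16,))

criterion = OnlineContrastiveLoss(margin=1.0, pair_selector=AllPairsSelector())
loss = criterion(embeddings, labels)
print(loss.item())
```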
126 | Triplets are generated using triplet_selector object that take embeddings and targets and return indices of 127 | triplets 128 | """ 129 | 130 | def __init__(self, margin, mode, margin_ratio=1, pos_ratio=1, neg_pow=1, pos_pow=1, device=None): 131 | super(OnlineHardNegativeMiningTripletLoss, self).__init__() 132 | self.margin = margin 133 | self.mode = mode 134 | self.margin_ratio = margin_ratio 135 | self.pos_ratio = pos_ratio 136 | self.pos_pow = pos_pow 137 | self.neg_pow = neg_pow 138 | self.device = device 139 | 140 | def forward(self, emb1, emb2): 141 | 142 | if self.mode == 'Random': 143 | neg_idx = torch.randint(high=emb1.shape[0], size=(emb1.shape[0],), device=self.device) 144 | ap_distances = (emb1 - emb2).pow(2).sum(1) 145 | an_distances = (emb1 - emb2[neg_idx, :]).pow(2).sum(1) 146 | margin = ap_distances - an_distances 147 | 148 | if (self.mode == 'Hardest') | (self.mode == 'HardPos'): 149 | sim_matrix = torch.mm(emb1, emb2.transpose(0, 1)) 150 | sim_matrix -= 1000000000 * torch.eye(n=sim_matrix.shape[0], m=sim_matrix.shape[1], device=self.device) 151 | neg_idx = torch.argmax(sim_matrix, axis=1) # find negative with highest similarity 152 | 153 | if self.mode == 'Hardest': 154 | ap_distances = (emb1 - emb2).pow(2).sum(1) 155 | an_distances = (emb1 - emb2[neg_idx, :]).pow(2).sum(1) 156 | 157 | margin = ap_distances - an_distances 158 | 159 | if self.mode == 'HardPos': 160 | ap_distances = (emb1 - emb2).pow(2).sum(1) 161 | an_distances = (emb1 - emb2[neg_idx, :]).pow(2).sum(1) 162 | 163 | # get LARGEST positive distances 164 | pos_idx = ap_distances.argsort(dim=-1, descending=True) # sort positive distances 165 | pos_idx = pos_idx[0:int(self.pos_ratio * pos_idx.shape[0])] # retain only self.pos_ratio of the positives 166 | 167 | margin = ap_distances[pos_idx] - an_distances[pos_idx] 168 | 169 | # hard examples first: sort margin 170 | idx = margin.argsort(dim=-1, descending=True) 171 | 172 | # retain a subset of hard examples 173 | idx = idx[0:int(self.margin_ratio * idx.shape[0])] # retain some of the examples 174 | 175 | margin = margin[idx] 176 | 177 | losses = F.relu(margin + self.margin) 178 | idx = torch.where(losses > 0)[0] 179 | 180 | if idx.size()[0] > 0: 181 | losses = losses[idx].mean() 182 | 183 | if torch.isnan(losses): 184 | print('Found nan in loss ') 185 | else: 186 | losses = 0 187 | 188 | return losses 189 | 190 | 191 | class InnerProduct(nn.Module): 192 | 193 | def __init__(self): 194 | super(InnerProduct, self).__init__() 195 | 196 | @staticmethod 197 | def forward(emb1, emb2): 198 | loss = (emb1 * emb2).abs().sum(1) 199 | return loss.mean() 200 | 201 | 202 | def find_fpr_training_set(emb1, emb2, FprValPos, FprValNeg): 203 | with torch.no_grad(): 204 | 205 | sim_matrix = torch.mm(emb1, emb2.transpose(0, 1)) 206 | sim_matrix -= 1000000000 * torch.eye(n=sim_matrix.shape[0], m=sim_matrix.shape[1]) 207 | neg_idx = torch.argmax(sim_matrix, axis=1) 208 | 209 | # compute DISTANCES 210 | ap_distances = (emb1 - emb2).pow(2).sum(1) 211 | an_distances = (emb1 - emb2[neg_idx, :]).pow(2).sum(1) 212 | 213 | # get positive distances ABOVE fpr 214 | pos_idx = torch.squeeze(torch.where(ap_distances > FprValPos)[0]) 215 | 216 | # sort array: LARGEST distances first 217 | pos_idx = pos_idx[ap_distances[pos_idx].argsort(dim=-1, descending=True)] 218 | 219 | # get negative distances BELOW fpr 220 | neg_idx1 = torch.squeeze(torch.where(an_distances < FprValNeg)[0]) 221 | 222 | if (neg_idx1.nelement() > 1): 223 | neg_idx1 = neg_idx1[an_distances[neg_idx1].argsort(dim=-1, 
descending=False)] 224 | 225 | neg_idx2 = neg_idx[neg_idx1] 226 | 227 | res = dict() 228 | res['pos_idx'] = pos_idx 229 | res['NegIdxA1'] = neg_idx1 230 | res['NegIdxA2'] = neg_idx2 231 | 232 | neg_idx = torch.argmax(sim_matrix, axis=0) 233 | an_distances = (emb1[neg_idx, :] - emb2).pow(2).sum(1) 234 | 235 | neg_idx2 = torch.squeeze(torch.where(an_distances < FprValNeg)[0]) 236 | if neg_idx2.nelement() > 1: 237 | neg_idx2 = neg_idx2[an_distances[neg_idx2].argsort(dim=-1, descending=False)] 238 | neg_idx1 = neg_idx[neg_idx2] 239 | res['NegIdxB1'] = neg_idx1 240 | res['NegIdxB2'] = neg_idx2 241 | 242 | return res 243 | 244 | 245 | class FPRLoss(nn.Module): 246 | 247 | def __init__(self, ): 248 | super(FPRLoss, self).__init__() 249 | 250 | @staticmethod 251 | def forward(anchor, positive, negative): 252 | distance_positive = (anchor - positive).pow(2).sum(1) 253 | distance_negative = (anchor - negative).pow(2).sum(1) 254 | 255 | losses = distance_positive.mean() - distance_negative.mean() 256 | 257 | return losses 258 | -------------------------------------------------------------------------------- /networks/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | DETR Transformer class. 3 | 4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | """ 9 | import copy 10 | from typing import Optional 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn, Tensor 15 | 16 | 17 | class Transformer(nn.Module): 18 | 19 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 20 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 21 | activation="relu", normalize_before=False, 22 | return_intermediate_dec=False): 23 | super().__init__() 24 | 25 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 26 | dropout, activation, normalize_before) 27 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 28 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 29 | 30 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 31 | dropout, activation, normalize_before) 32 | decoder_norm = nn.LayerNorm(d_model) 33 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 34 | return_intermediate=return_intermediate_dec) 35 | 36 | self._reset_parameters() 37 | 38 | self.d_model = d_model 39 | self.nhead = nhead 40 | 41 | def _reset_parameters(self): 42 | for p in self.parameters(): 43 | if p.dim() > 1: 44 | nn.init.xavier_uniform_(p) 45 | 46 | def forward(self, src, mask, query_embed, pos_embed, query_pos=None): 47 | # flatten NxCxHxW to HWxNxC 48 | bs = src.shape[1] 49 | 50 | if pos_embed is not None: 51 | pos_embed = pos_embed.flatten(2) 52 | 53 | # reshape Embeddings in QUERY to standard Transformer input 54 | if query_embed is not None: 55 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 56 | 57 | if mask is not None: 58 | mask = mask.flatten(1) 59 | 60 | tgt = query_embed 61 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 62 | 63 | # apply Transformer 64 | hs = self.decoder(tgt, memory, memory_key_padding_mask=None, pos=None, query_pos=None) 65 | 66 | # return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 67 | return hs.squeeze(), memory 68 | 69 | 70 | class 
TransformerEncoder(nn.Module): 71 | 72 | def __init__(self, encoder_layer, num_layers, norm=None): 73 | super().__init__() 74 | self.layers = _get_clones(encoder_layer, num_layers) 75 | self.num_layers = num_layers 76 | self.norm = norm 77 | 78 | def forward(self, src, 79 | mask: Optional[Tensor] = None, 80 | src_key_padding_mask: Optional[Tensor] = None, 81 | pos: Optional[Tensor] = None): 82 | output = src 83 | 84 | for layer in self.layers: 85 | output = layer(output, src_mask=mask, 86 | src_key_padding_mask=src_key_padding_mask, pos=pos) 87 | 88 | if self.norm is not None: 89 | output = self.norm(output) 90 | 91 | return output 92 | 93 | 94 | class TransformerDecoder(nn.Module): 95 | 96 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 97 | super().__init__() 98 | self.layers = _get_clones(decoder_layer, num_layers) 99 | self.num_layers = num_layers 100 | self.norm = norm 101 | self.return_intermediate = return_intermediate 102 | 103 | def forward(self, tgt, memory, 104 | tgt_mask: Optional[Tensor] = None, 105 | memory_mask: Optional[Tensor] = None, 106 | tgt_key_padding_mask: Optional[Tensor] = None, 107 | memory_key_padding_mask: Optional[Tensor] = None, 108 | pos: Optional[Tensor] = None, 109 | query_pos: Optional[Tensor] = None): 110 | output = tgt 111 | 112 | intermediate = [] 113 | 114 | for layer in self.layers: 115 | output = layer(output, memory, tgt_mask=tgt_mask, 116 | memory_mask=memory_mask, 117 | tgt_key_padding_mask=tgt_key_padding_mask, 118 | memory_key_padding_mask=memory_key_padding_mask, 119 | pos=pos, query_pos=query_pos) 120 | if self.return_intermediate: 121 | intermediate.append(self.norm(output)) 122 | 123 | if self.norm is not None: 124 | output = self.norm(output) 125 | if self.return_intermediate: 126 | intermediate.pop() 127 | intermediate.append(output) 128 | 129 | if self.return_intermediate: 130 | return torch.stack(intermediate) 131 | 132 | return output.unsqueeze(0) 133 | 134 | 135 | class TransformerEncoderLayer(nn.Module): 136 | 137 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 138 | activation="relu", normalize_before=False): 139 | super().__init__() 140 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 141 | # Implementation of Feedforward model 142 | self.linear1 = nn.Linear(d_model, dim_feedforward) 143 | self.dropout = nn.Dropout(dropout) 144 | self.linear2 = nn.Linear(dim_feedforward, d_model) 145 | 146 | self.norm1 = nn.LayerNorm(d_model) 147 | self.norm2 = nn.LayerNorm(d_model) 148 | self.dropout1 = nn.Dropout(dropout) 149 | self.dropout2 = nn.Dropout(dropout) 150 | 151 | self.activation = _get_activation_fn(activation) 152 | self.normalize_before = normalize_before 153 | 154 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 155 | return tensor if pos is None else tensor + pos 156 | 157 | def forward_post(self, 158 | src, 159 | src_mask: Optional[Tensor] = None, 160 | src_key_padding_mask: Optional[Tensor] = None, 161 | pos: Optional[Tensor] = None): 162 | q = k = self.with_pos_embed(src, pos) 163 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 164 | key_padding_mask=src_key_padding_mask)[0] 165 | src = src + self.dropout1(src2) 166 | src = self.norm1(src) 167 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 168 | src = src + self.dropout2(src2) 169 | src = self.norm2(src) 170 | return src 171 | 172 | def forward_pre(self, src, 173 | src_mask: Optional[Tensor] = None, 174 | src_key_padding_mask: Optional[Tensor] = 
None, 175 | pos: Optional[Tensor] = None): 176 | src2 = self.norm1(src) 177 | q = k = self.with_pos_embed(src2, pos) 178 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 179 | key_padding_mask=src_key_padding_mask)[0] 180 | src = src + self.dropout1(src2) 181 | src2 = self.norm2(src) 182 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 183 | src = src + self.dropout2(src2) 184 | return src 185 | 186 | def forward(self, src, 187 | src_mask: Optional[Tensor] = None, 188 | src_key_padding_mask: Optional[Tensor] = None, 189 | pos: Optional[Tensor] = None): 190 | if self.normalize_before: 191 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 192 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 193 | 194 | 195 | class TransformerDecoderLayer(nn.Module): 196 | 197 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 198 | activation="relu", normalize_before=False): 199 | super().__init__() 200 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 201 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 202 | # Implementation of Feedforward model 203 | self.linear1 = nn.Linear(d_model, dim_feedforward) 204 | self.dropout = nn.Dropout(dropout) 205 | self.linear2 = nn.Linear(dim_feedforward, d_model) 206 | 207 | self.norm1 = nn.LayerNorm(d_model) 208 | self.norm2 = nn.LayerNorm(d_model) 209 | self.norm3 = nn.LayerNorm(d_model) 210 | self.dropout1 = nn.Dropout(dropout) 211 | self.dropout2 = nn.Dropout(dropout) 212 | self.dropout3 = nn.Dropout(dropout) 213 | 214 | self.activation = _get_activation_fn(activation) 215 | self.normalize_before = normalize_before 216 | 217 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 218 | return tensor if pos is None else tensor + pos 219 | 220 | def forward_post(self, tgt, memory, 221 | tgt_mask: Optional[Tensor] = None, 222 | memory_mask: Optional[Tensor] = None, 223 | tgt_key_padding_mask: Optional[Tensor] = None, 224 | memory_key_padding_mask: Optional[Tensor] = None, 225 | pos: Optional[Tensor] = None, 226 | query_pos: Optional[Tensor] = None): 227 | q = k = self.with_pos_embed(tgt, query_pos) 228 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 229 | key_padding_mask=tgt_key_padding_mask)[0] 230 | tgt = tgt + self.dropout1(tgt2) 231 | tgt = self.norm1(tgt) 232 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 233 | key=self.with_pos_embed(memory, pos), 234 | value=memory, attn_mask=memory_mask, 235 | key_padding_mask=memory_key_padding_mask)[0] 236 | tgt = tgt + self.dropout2(tgt2) 237 | tgt = self.norm2(tgt) 238 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 239 | tgt = tgt + self.dropout3(tgt2) 240 | tgt = self.norm3(tgt) 241 | return tgt 242 | 243 | def forward_pre(self, tgt, memory, 244 | tgt_mask: Optional[Tensor] = None, 245 | memory_mask: Optional[Tensor] = None, 246 | tgt_key_padding_mask: Optional[Tensor] = None, 247 | memory_key_padding_mask: Optional[Tensor] = None, 248 | pos: Optional[Tensor] = None, 249 | query_pos: Optional[Tensor] = None): 250 | tgt2 = self.norm1(tgt) 251 | q = k = self.with_pos_embed(tgt2, query_pos) 252 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 253 | key_padding_mask=tgt_key_padding_mask)[0] 254 | tgt = tgt + self.dropout1(tgt2) 255 | tgt2 = self.norm2(tgt) 256 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 257 | key=self.with_pos_embed(memory, pos), 258 | value=memory, 
attn_mask=memory_mask, 259 | key_padding_mask=memory_key_padding_mask)[0] 260 | tgt = tgt + self.dropout2(tgt2) 261 | tgt2 = self.norm3(tgt) 262 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 263 | tgt = tgt + self.dropout3(tgt2) 264 | return tgt 265 | 266 | def forward(self, tgt, memory, 267 | tgt_mask: Optional[Tensor] = None, 268 | memory_mask: Optional[Tensor] = None, 269 | tgt_key_padding_mask: Optional[Tensor] = None, 270 | memory_key_padding_mask: Optional[Tensor] = None, 271 | pos: Optional[Tensor] = None, 272 | query_pos: Optional[Tensor] = None): 273 | if self.normalize_before: 274 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 275 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 276 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 277 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 278 | 279 | 280 | def _get_clones(module, N): 281 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 282 | 283 | 284 | def build_transformer(args): 285 | return Transformer( 286 | d_model=args.hidden_dim, 287 | dropout=args.dropout, 288 | nhead=args.nheads, 289 | dim_feedforward=args.dim_feedforward, 290 | num_encoder_layers=args.enc_layers, 291 | num_decoder_layers=args.dec_layers, 292 | normalize_before=args.pre_norm, 293 | return_intermediate_dec=True, 294 | ) 295 | 296 | 297 | def _get_activation_fn(activation): 298 | """Return an activation function given a string""" 299 | if activation == "relu": 300 | return F.relu 301 | if activation == "gelu": 302 | return F.gelu 303 | if activation == "glu": 304 | return F.glu 305 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 306 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import warnings 5 | from pathlib import Path 6 | 7 | import GPUtil 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from tensorboardX import SummaryWriter 12 | from termcolor import colored 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 14 | from tqdm import tqdm 15 | 16 | from datasets.DatasetPairwiseTriplets import DatasetPairwiseTriplets 17 | from networks.MultiscaleTransformerEncoder import MultiscaleTransformerEncoder 18 | from networks.losses import OnlineHardNegativeMiningTripletLoss 19 | from util.read_hdf5_data import read_hdf5_data 20 | from util.utils import load_model, MultiEpochsDataLoader, MyGradScaler, save_best_model_stats, evaluate_test, \ 21 | load_test_datasets, evaluate_validation, load_validation_set 22 | from util.warmup_scheduler import GradualWarmupSchedulerV2 23 | 24 | warnings.filterwarnings("ignore", message="UserWarning: albumentations.augmentations.transforms.RandomResizedCrop") 25 | 26 | 27 | def assert_dir(dir_path): 28 | Path(dir_path).mkdir(parents=True, exist_ok=True) 29 | 30 | 31 | def load_datasets_paths(ds_name, ds_path): 32 | if ds_name == 'visnir': 33 | test_dir = os.path.join(ds_path, 'test\\') 34 | train_file = os.path.join(ds_path, 'train\\train.hdf5') 35 | elif ds_name == 'cuhk': 36 | test_dir = os.path.join(ds_path, 'en_etal\\cuhk\\test\\') 37 | train_file = os.path.join(ds_path, 'en_etal\\cuhk\\train.hdf5') 38 | elif ds_name == 'vedai': 39 | test_dir = os.path.join(ds_path, 'en_etal\\vedai\\test\\') 40 | train_file = os.path.join(ds_path, 'en_etal\\vedai\\train.hdf5') 41 | elif ds_name == 
'visnir-grid': 42 | test_dir = os.path.join(ds_path, 'en_etal\\visnir\\test\\') 43 | train_file = os.path.join(ds_path, 'en_etal\\visnir\\train.hdf5') 44 | elif ds_name == 'ubc-liberty': 45 | test_dir = os.path.join(ds_path, 'ubc\\test_yos_not\\') 46 | train_file = os.path.join(ds_path, 'ubc\\liberty_train_full.hdf5') 47 | elif ds_name == 'ubc-notredame': 48 | test_dir = os.path.join(ds_path, 'ubc\\test_lib_yos\\') 49 | train_file = os.path.join(ds_path, 'ubc\\notredame_train_full.hdf5') 50 | elif ds_name == 'ubc-yosemite': 51 | test_dir = os.path.join(ds_path, 'ubc\\test_lib_not\\') 52 | train_file = os.path.join(ds_path, 'ubc\\yosemite_train_full.hdf5') 53 | return train_file, test_dir 54 | 55 | 56 | def create_optimizer(net, lr_rate, weight_decay): 57 | return torch.optim.Adam( 58 | [{'params': filter(lambda p: p.requires_grad == True, net.parameters()), 'lr': lr_rate, 59 | 'weight_decay': weight_decay}, 60 | {'params': filter(lambda p: p.requires_grad == False, net.parameters()), 'lr': 0, 61 | 'weight_decay': 0}], 62 | lr=0, weight_decay=0) 63 | 64 | 65 | def train(net, train_dataloader, start_epoch, device, warmup_epochs, generator_mode, lr_rate, weight_decay, 66 | writer, evaluate_net_steps, models_dir, best_file_name, outer_batch_size, inner_batch_size, 67 | optimizer, scheduler, scheduler_warmup, criterion, lowest_err, arch_desc, test_data, val_data, val_labels, 68 | epochs, scheduler_patience): 69 | scaler = MyGradScaler() 70 | 71 | for epoch in range(start_epoch, epochs): 72 | optimizer.zero_grad() 73 | is_warmup_phase = epoch - start_epoch < warmup_epochs 74 | 75 | if is_warmup_phase: 76 | print('\n', colored('Warmup step #' + repr(epoch - start_epoch), 'green', attrs=['reverse', 'blink'])) 77 | scheduler_warmup.step() 78 | else: 79 | if epoch > start_epoch: 80 | if type(scheduler).__name__ == 'StepLR': 81 | scheduler.step() 82 | 83 | if type(scheduler).__name__ == 'ReduceLROnPlateau': 84 | scheduler.step(total_test_err) 85 | running_loss = 0 86 | 87 | log = 'LR: ' 88 | for param_group in optimizer.param_groups: 89 | log += repr(param_group['lr']) + ' ' 90 | print('\n', colored(log, 'blue', attrs=['reverse', 'blink'])) 91 | 92 | print('negative_mining_mode: ' + criterion.mode) 93 | print('generator_mode: ' + generator_mode) 94 | 95 | should_mine_hard_negatives = criterion.mode == 'Random' and \ 96 | optimizer.param_groups[0]['lr'] <= (lr_rate / 1e3 + 1e-8) and \ 97 | not is_warmup_phase 98 | if should_mine_hard_negatives: 99 | print(colored('Switching Random->Hardest', 'green', attrs=['reverse', 'blink'])) 100 | criterion = OnlineHardNegativeMiningTripletLoss(margin=1, mode='Hardest', device=device) 101 | 102 | optimizer = create_optimizer(net, lr_rate, weight_decay) 103 | 104 | scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=1, total_epoch=warmup_epochs) 105 | start_epoch = epoch 106 | 107 | if type(scheduler).__name__ == 'StepLR': 108 | scheduler = StepLR(optimizer, step_size=10, gamma=0.1) 109 | 110 | if type(scheduler).__name__ == 'ReduceLROnPlateau': 111 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=scheduler_patience, 112 | verbose=True) 113 | 114 | bar = tqdm(train_dataloader, 0, leave=False, total=math.ceil((len(train_dataloader) - 1) / inner_batch_size)) 115 | for batch_num, data in enumerate(bar): 116 | 117 | # zero the parameter gradients 118 | optimizer.zero_grad() 119 | 120 | net = net.train() 121 | 122 | # get the inputs 123 | pos1 = data['pos1'] 124 | pos2 = data['pos2'] 125 | 126 | pos1 = np.reshape(pos1, 
(pos1.shape[0] * pos1.shape[1], 1, pos1.shape[2], pos1.shape[3]), order='F') 127 | pos2 = np.reshape(pos2, (pos2.shape[0] * pos2.shape[1], 1, pos2.shape[2], pos2.shape[3]), order='F') 128 | 129 | pos1, pos2 = pos1.to(device), pos2.to(device) 130 | 131 | emb = net(pos1, pos2) 132 | 133 | loss = criterion(emb['Emb1'], emb['Emb2']) + criterion(emb['Emb2'], emb['Emb1']) 134 | 135 | scaler.scale(loss).backward() 136 | scaler.step(optimizer) 137 | scaler.update() 138 | 139 | running_loss += loss.item() 140 | if epoch >= 50: 141 | evaluate_net_steps = 20 142 | if (batch_num % evaluate_net_steps == 0 or batch_num * inner_batch_size >= len(train_dataloader) - 1) and \ 143 | batch_num > 0: 144 | 145 | if batch_num > 0: 146 | running_loss /= batch_num 147 | 148 | net.eval() 149 | val_err = 0 150 | if len(val_data) > 0: 151 | val_err = evaluate_validation(net, val_data, val_labels, device) 152 | 153 | # test accuracy 154 | total_test_err = evaluate_test(net, test_data, device) 155 | 156 | state = {'epoch': epoch, 157 | 'state_dict': net.module.state_dict(), 158 | 'optimizer_name': type(optimizer).__name__, 159 | 'optimizer': optimizer, 160 | 'scheduler_name': type(scheduler).__name__, 161 | 'scheduler': scheduler, 162 | 'arch_desc': arch_desc, 163 | 'lowest_err': lowest_err, 164 | 'outer_batch_size': outer_batch_size, 165 | 'inner_batch_size': inner_batch_size, 166 | 'negative_mining_mode': criterion.mode, 167 | 'generator_mode': generator_mode, 168 | 'loss': criterion.mode} 169 | 170 | if total_test_err < lowest_err: 171 | lowest_err = total_test_err 172 | 173 | print('\n', colored('Best error found and saved: ' + repr(total_test_err)[0:5], 'red', 174 | attrs=['reverse', 'blink'])) 175 | filepath = os.path.join(models_dir, best_file_name + '.pth') 176 | # torch.save(state, filepath) 177 | save_best_model_stats(models_dir, epoch, total_test_err, test_data) 178 | 179 | log = '[%d, %5d] Loss: %.3f' % (epoch, batch_num, 100 * running_loss) + ' Val Error: ' + repr(val_err)[ 180 | 0:6] 181 | log += ' Test Error: ' + repr(total_test_err)[0:6] 182 | print(log) 183 | 184 | writer.add_scalar('Val Error', val_err, epoch * len(train_dataloader) + batch_num) 185 | writer.add_scalar('Test Error', total_test_err, epoch * len(train_dataloader) + batch_num) 186 | writer.add_scalar('Loss', 100 * running_loss, epoch * len(train_dataloader) + batch_num) 187 | writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], 188 | epoch * len(train_dataloader) + batch_num) 189 | writer.add_text('Log', log) 190 | writer.close() 191 | 192 | # save epoch 193 | filepath = models_dir + 'model_epoch_' + repr(epoch) + '.pth' 194 | # torch.save(state, filepath) 195 | 196 | if (batch_num * inner_batch_size) > (len(train_dataloader) - 1): 197 | bar.clear() 198 | bar.close() 199 | break 200 | 201 | print('Finished Training') 202 | 203 | 204 | def parse_args(): 205 | parser = argparse.ArgumentParser(description='Train models for multimodal patch matching.') 206 | parser.add_argument('--epochs', type=int, default=90, help='epochs') 207 | parser.add_argument('--artifacts', default='./artifacts', help='artifacts path') 208 | parser.add_argument('--exp-name', default='symmetric_enc_transformer_test_4', help='experiment name') 209 | parser.add_argument('--evaluate-every', type=int, default=100, help='evaluate network and print steps') 210 | parser.add_argument('--skip-validation', type=bool, const=True, default=False, 211 | help='whether to skip validation evaluation', nargs='?') 212 | parser.add_argument('--skip-test', type=bool, 
const=True, default=False, 213 | help='whether to skip test evaluation', nargs='?') 214 | parser.add_argument('--continue-from-checkpoint', type=bool, const=True, default=False, 215 | nargs='?', help='whether to continue training from checkpoint') 216 | parser.add_argument('--continue-from-best-score', type=bool, const=True, default=False, 217 | nargs='?', help='whether to use best score when continuing training') 218 | parser.add_argument('--continue-from-best-model', type=bool, const=True, default=True, 219 | nargs='?', help='whether to continue training using best model') 220 | parser.add_argument('--batch-size', type=int, default=48, help='batch size') 221 | parser.add_argument('--inner-batch-size', type=int, default=24, help='inner batch size of positive pairs') 222 | parser.add_argument('--lr', type=float, default=1e-1, help='learning rate') 223 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout') 224 | parser.add_argument('--weight-decay', type=float, default=0, help='weight decay') 225 | parser.add_argument('--dataset-name', default='visnir', help='dataset name') 226 | parser.add_argument('--dataset-path', default='visnir', help='dataset name') 227 | parser.add_argument('--warmup-epochs', type=int, default=14, help='warmup epochs') 228 | parser.add_argument('--scheduler-patience', type=int, default=6, help='scheduler patience epochs') 229 | return parser.parse_args() 230 | 231 | 232 | def main(): 233 | args = parse_args() 234 | np.random.seed(0) 235 | torch.manual_seed(0) 236 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 237 | gpus_num = torch.cuda.device_count() 238 | torch.cuda.empty_cache() 239 | GPUtil.showUtilization() 240 | print('Using', device) 241 | 242 | models_dir = os.path.join(args.artifacts, args.exp_name, 'models') 243 | logs_dirname = os.path.join(args.artifacts, args.exp_name, 'logs') 244 | arch_desc = 'Symmetric CNN with Triplet loss and transformer encoder' 245 | best_file_name = 'best_model' 246 | 247 | train_file, test_dir = load_datasets_paths(args.dataset_name, args.dataset_path) 248 | 249 | assert_dir(models_dir) 250 | assert_dir(logs_dirname) 251 | 252 | start_best_model = args.continue_from_best_model 253 | use_best_score = args.continue_from_best_score 254 | writer = SummaryWriter(logs_dirname) 255 | generator_mode = 'Pairwise' 256 | negative_mining_mode = 'Random' 257 | skip_validation = args.skip_validation 258 | skip_test = args.skip_test 259 | lr_rate = args.lr 260 | weight_decay = args.weight_decay 261 | dropout = args.dropout 262 | outer_batch_size = args.batch_size 263 | inner_batch_size = args.inner_batch_size 264 | epochs = args.epochs 265 | scheduler_patience = args.scheduler_patience 266 | augmentations = { 267 | "Test": False, 268 | "HorizontalFlip": True, 269 | "Rotate90": True, 270 | "VerticalFlip": False, 271 | "RandomCrop": {'Do': False} 272 | } 273 | evaluate_net_steps = args.evaluate_every 274 | 275 | data = read_hdf5_data(train_file) 276 | train_data = data['Data'] 277 | train_labels = np.squeeze(data['Labels']) 278 | train_split = np.squeeze(data['Set']) 279 | del data 280 | 281 | val_data = [] 282 | val_labels = [] 283 | train_indices = np.squeeze(np.asarray(np.where(train_split == 1))) 284 | if not skip_validation: 285 | val_data, val_labels = load_validation_set(train_data, train_split, train_labels) 286 | 287 | train_data = np.squeeze(train_data[train_indices,]) 288 | train_labels = train_labels[train_indices] 289 | 290 | train_dataset = DatasetPairwiseTriplets(train_data, 
train_labels, inner_batch_size, augmentations, 291 | generator_mode) 292 | train_dataloader = MultiEpochsDataLoader(train_dataset, batch_size=outer_batch_size, shuffle=True, 293 | num_workers=8, pin_memory=True) 294 | 295 | test_data = None 296 | if not skip_test: 297 | test_data = load_test_datasets(test_dir) 298 | 299 | net = MultiscaleTransformerEncoder(dropout) 300 | optimizer = create_optimizer(net, lr_rate, weight_decay) 301 | start_epoch = 0 302 | lowest_err = 1e10 303 | if args.continue_from_checkpoint: 304 | net, optimizer, lowest_err, start_epoch, scheduler, loaded_negative_mining_mode = load_model(net, 305 | start_best_model, 306 | models_dir, 307 | best_file_name, 308 | use_best_score, 309 | device) 310 | 311 | if gpus_num > 1: 312 | print("Using", torch.cuda.device_count(), "GPUs") 313 | net = nn.DataParallel(net) 314 | net.to(device) 315 | 316 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=scheduler_patience, verbose=True) 317 | 318 | warmup_epochs = args.warmup_epochs 319 | scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=1, total_epoch=warmup_epochs, 320 | after_scheduler=StepLR(optimizer, step_size=3, gamma=0.1)) 321 | 322 | criterion = OnlineHardNegativeMiningTripletLoss(margin=1, mode=negative_mining_mode, device=device) 323 | 324 | train(net, train_dataloader, start_epoch, device, warmup_epochs, generator_mode, lr_rate, weight_decay, 325 | writer, evaluate_net_steps, models_dir, best_file_name, outer_batch_size, inner_batch_size, 326 | optimizer, scheduler, scheduler_warmup, criterion, lowest_err, arch_desc, test_data, val_data, val_labels, 327 | epochs, scheduler_patience) 328 | 329 | 330 | if __name__ == '__main__': 331 | main() 332 | -------------------------------------------------------------------------------- /license.md: -------------------------------------------------------------------------------- 1 | # Attribution-NonCommercial-NoDerivatives 4.0 International 2 | 3 | > *Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.* 4 | > 5 | > ### Using Creative Commons Public Licenses 6 | > 7 | > Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 8 | > 9 | > * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. 
Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 10 | > 11 | > * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 12 | 13 | ## Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License 14 | 15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 16 | 17 | ### Section 1 – Definitions. 18 | 19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 20 | 21 | b. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 22 | 23 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 24 | 25 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 26 | 27 | h. 
__Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 28 | 29 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 30 | 31 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 32 | 33 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 34 | 35 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 36 | 37 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 38 | 39 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 40 | 41 | ### Section 2 – Scope. 42 | 43 | a. ___License grant.___ 44 | 45 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 46 | 47 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 48 | 49 | B. produce and reproduce, but not Share, Adapted Material for NonCommercial purposes only. 50 | 51 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 52 | 53 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 54 | 55 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 56 | 57 | 5. __Downstream recipients.__ 58 | 59 | A. 
__Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 60 | 61 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 62 | 63 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 64 | 65 | b. ___Other rights.___ 66 | 67 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 68 | 69 | 2. Patent and trademark rights are not licensed under this Public License. 70 | 71 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 72 | 73 | ### Section 3 – License Conditions. 74 | 75 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 76 | 77 | a. ___Attribution.___ 78 | 79 | 1. If You Share the Licensed Material, You must: 80 | 81 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 82 | 83 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 84 | 85 | ii. a copyright notice; 86 | 87 | iii. a notice that refers to this Public License; 88 | 89 | iv. a notice that refers to the disclaimer of warranties; 90 | 91 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 92 | 93 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 94 | 95 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 96 | 97 | For the avoidance of doubt, You do not have permission under this Public License to Share Adapted Material. 98 | 99 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 100 | 101 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 102 | 103 | ### Section 4 – Sui Generis Database Rights. 
104 | 105 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 106 | 107 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only and provided You do not Share Adapted Material; 108 | 109 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 110 | 111 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 112 | 113 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 114 | 115 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 116 | 117 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 118 | 119 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 120 | 121 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 122 | 123 | ### Section 6 – Term and Termination. 124 | 125 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 126 | 127 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 128 | 129 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 130 | 131 | 2. upon express reinstatement by the Licensor. 132 | 133 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 134 | 135 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 136 | 137 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 138 | 139 | ### Section 7 – Other Terms and Conditions. 140 | 141 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 142 | 143 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 144 | 145 | ### Section 8 – Interpretation. 146 | 147 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 148 | 149 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 150 | 151 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 152 | 153 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 154 | 155 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 156 | > 157 | > Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org). 
158 | -------------------------------------------------------------------------------- /networks/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | from PIL import Image 6 | 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import numpy as np 12 | import numbers 13 | from collections.abc import Sequence, Iterable 14 | import warnings 15 | 16 | import torch.nn.functional as F 17 | 18 | __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", 19 | "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", 20 | "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", 21 | "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", 22 | "RandomPerspective", "RandomErasing"] 23 | 24 | _pil_interpolation_to_str = { 25 | Image.NEAREST: 'PIL.Image.NEAREST', 26 | Image.BILINEAR: 'PIL.Image.BILINEAR', 27 | Image.BICUBIC: 'PIL.Image.BICUBIC', 28 | Image.LANCZOS: 'PIL.Image.LANCZOS', 29 | Image.HAMMING: 'PIL.Image.HAMMING', 30 | Image.BOX: 'PIL.Image.BOX', 31 | } 32 | 33 | 34 | def _get_image_size(img): 35 | if F._is_pil_image(img): 36 | return img.size 37 | elif isinstance(img, torch.Tensor) and img.dim() > 2: 38 | return img.shape[-2:][::-1] 39 | else: 40 | raise TypeError("Unexpected type {}".format(type(img))) 41 | 42 | 43 | class Compose(object): 44 | """Composes several transforms together. 45 | 46 | Args: 47 | transforms (list of ``Transform`` objects): list of transforms to compose. 48 | 49 | Example: 50 | >>> transforms.Compose([ 51 | >>> transforms.CenterCrop(10), 52 | >>> transforms.ToTensor(), 53 | >>> ]) 54 | """ 55 | 56 | def __init__(self, transforms): 57 | self.transforms = transforms 58 | 59 | def __call__(self, img): 60 | for t in self.transforms: 61 | img = t(img) 62 | 63 | if img.ndims == 3: 64 | aa = 9 65 | 66 | return img 67 | 68 | def __repr__(self): 69 | format_string = self.__class__.__name__ + '(' 70 | for t in self.transforms: 71 | format_string += '\n' 72 | format_string += ' {0}'.format(t) 73 | format_string += '\n)' 74 | return format_string 75 | 76 | 77 | class ToTensor(object): 78 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 79 | 80 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 81 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 82 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 83 | or if the numpy.ndarray has dtype = np.uint8 84 | 85 | In the other cases, tensors are returned without scaling. 86 | """ 87 | 88 | def __call__(self, pic): 89 | """ 90 | Args: 91 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 92 | 93 | Returns: 94 | Tensor: Converted image. 95 | """ 96 | return F.to_tensor(pic) 97 | 98 | def __repr__(self): 99 | return self.__class__.__name__ + '()' 100 | 101 | 102 | class PILToTensor(object): 103 | """Convert a ``PIL Image`` to a tensor of the same type. 104 | 105 | Converts a PIL Image (H x W x C) to a torch.Tensor of shape (C x H x W). 106 | """ 107 | 108 | def __call__(self, pic): 109 | """ 110 | Args: 111 | pic (PIL Image): Image to be converted to tensor. 112 | 113 | Returns: 114 | Tensor: Converted image. 
115 | """ 116 | return F.pil_to_tensor(pic) 117 | 118 | def __repr__(self): 119 | return self.__class__.__name__ + '()' 120 | 121 | 122 | class ConvertImageDtype(object): 123 | """Convert a tensor image to the given ``dtype`` and scale the values accordingly 124 | 125 | Args: 126 | dtype (torch.dtype): Desired data type of the output 127 | 128 | .. note:: 129 | 130 | When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. 131 | If converted back and forth, this mismatch has no effect. 132 | 133 | Raises: 134 | RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as 135 | well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to 136 | overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range 137 | of the integer ``dtype``. 138 | """ 139 | 140 | def __init__(self, dtype: torch.dtype) -> None: 141 | self.dtype = dtype 142 | 143 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 144 | return F.convert_image_dtype(image, self.dtype) 145 | 146 | 147 | class ToPILImage(object): 148 | """Convert a tensor or an ndarray to PIL Image. 149 | 150 | Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape 151 | H x W x C to a PIL Image while preserving the value range. 152 | 153 | Args: 154 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 155 | If ``mode`` is ``None`` (default) there are some assumptions made about the input data: 156 | - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 157 | - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. 158 | - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. 159 | - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, 160 | ``short``). 161 | 162 | .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes 163 | """ 164 | 165 | def __init__(self, mode=None): 166 | self.mode = mode 167 | 168 | def __call__(self, pic): 169 | """ 170 | Args: 171 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 172 | 173 | Returns: 174 | PIL Image: Image converted to PIL Image. 175 | 176 | """ 177 | return F.to_pil_image(pic, self.mode) 178 | 179 | def __repr__(self): 180 | format_string = self.__class__.__name__ + '(' 181 | if self.mode is not None: 182 | format_string += 'mode={0}'.format(self.mode) 183 | format_string += ')' 184 | return format_string 185 | 186 | 187 | class Normalize(object): 188 | """Normalize a tensor image with mean and standard deviation. 189 | Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` 190 | channels, this transform will normalize each channel of the input 191 | ``torch.*Tensor`` i.e., 192 | ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` 193 | 194 | .. note:: 195 | This transform acts out of place, i.e., it does not mutate the input tensor. 196 | 197 | Args: 198 | mean (sequence): Sequence of means for each channel. 199 | std (sequence): Sequence of standard deviations for each channel. 200 | inplace(bool,optional): Bool to make this operation in-place. 
201 | 202 | """ 203 | 204 | def __init__(self, mean, std, inplace=False): 205 | self.mean = mean 206 | self.std = std 207 | self.inplace = inplace 208 | 209 | def __call__(self, tensor): 210 | """ 211 | Args: 212 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 213 | 214 | Returns: 215 | Tensor: Normalized Tensor image. 216 | """ 217 | return F.normalize(tensor, self.mean, self.std, self.inplace) 218 | 219 | def __repr__(self): 220 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 221 | 222 | 223 | class Resize(object): 224 | """Resize the input PIL Image to the given size. 225 | 226 | Args: 227 | size (sequence or int): Desired output size. If size is a sequence like 228 | (h, w), output size will be matched to this. If size is an int, 229 | smaller edge of the image will be matched to this number. 230 | i.e, if height > width, then image will be rescaled to 231 | (size * height / width, size) 232 | interpolation (int, optional): Desired interpolation. Default is 233 | ``PIL.Image.BILINEAR`` 234 | """ 235 | 236 | def __init__(self, size, interpolation=Image.BILINEAR): 237 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 238 | self.size = size 239 | self.interpolation = interpolation 240 | 241 | def __call__(self, img): 242 | """ 243 | Args: 244 | img (PIL Image): Image to be scaled. 245 | 246 | Returns: 247 | PIL Image: Rescaled image. 248 | """ 249 | return F.resize(img, self.size, self.interpolation) 250 | 251 | def __repr__(self): 252 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 253 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 254 | 255 | 256 | class Scale(Resize): 257 | """ 258 | Note: This transform is deprecated in favor of Resize. 259 | """ 260 | 261 | def __init__(self, *args, **kwargs): 262 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 263 | "please use transforms.Resize instead.") 264 | super(Scale, self).__init__(*args, **kwargs) 265 | 266 | 267 | class CenterCrop(object): 268 | """Crops the given PIL Image at the center. 269 | 270 | Args: 271 | size (sequence or int): Desired output size of the crop. If size is an 272 | int instead of sequence like (h, w), a square crop (size, size) is 273 | made. 274 | """ 275 | 276 | def __init__(self, size): 277 | if isinstance(size, numbers.Number): 278 | self.size = (int(size), int(size)) 279 | else: 280 | self.size = size 281 | 282 | def __call__(self, img): 283 | """ 284 | Args: 285 | img (PIL Image): Image to be cropped. 286 | 287 | Returns: 288 | PIL Image: Cropped image. 289 | """ 290 | return F.center_crop(img, self.size) 291 | 292 | def __repr__(self): 293 | return self.__class__.__name__ + '(size={0})'.format(self.size) 294 | 295 | 296 | class Pad(object): 297 | """Pad the given PIL Image on all sides with the given "pad" value. 298 | 299 | Args: 300 | padding (int or tuple): Padding on each border. If a single int is provided this 301 | is used to pad all borders. If tuple of length 2 is provided this is the padding 302 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 303 | this is the padding for the left, top, right and bottom borders 304 | respectively. 305 | fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of 306 | length 3, it is used to fill R, G, B channels respectively. 
307 | This value is only used when the padding_mode is constant 308 | padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. 309 | Default is constant. 310 | 311 | - constant: pads with a constant value, this value is specified with fill 312 | 313 | - edge: pads with the last value at the edge of the image 314 | 315 | - reflect: pads with reflection of image without repeating the last value on the edge 316 | 317 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 318 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 319 | 320 | - symmetric: pads with reflection of image repeating the last value on the edge 321 | 322 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 323 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 324 | """ 325 | 326 | def __init__(self, padding, fill=0, padding_mode='constant'): 327 | assert isinstance(padding, (numbers.Number, tuple)) 328 | assert isinstance(fill, (numbers.Number, str, tuple)) 329 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 330 | if isinstance(padding, Sequence) and len(padding) not in [2, 4]: 331 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 332 | "{} element tuple".format(len(padding))) 333 | 334 | self.padding = padding 335 | self.fill = fill 336 | self.padding_mode = padding_mode 337 | 338 | def __call__(self, img): 339 | """ 340 | Args: 341 | img (PIL Image): Image to be padded. 342 | 343 | Returns: 344 | PIL Image: Padded image. 345 | """ 346 | return F.pad(img, self.padding, self.fill, self.padding_mode) 347 | 348 | def __repr__(self): 349 | return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'. \ 350 | format(self.padding, self.fill, self.padding_mode) 351 | 352 | 353 | class Lambda(object): 354 | """Apply a user-defined lambda as a transform. 355 | 356 | Args: 357 | lambd (function): Lambda/function to be used for transform. 
358 | """ 359 | 360 | def __init__(self, lambd): 361 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 362 | self.lambd = lambd 363 | 364 | def __call__(self, img): 365 | return self.lambd(img) 366 | 367 | def __repr__(self): 368 | return self.__class__.__name__ + '()' 369 | 370 | 371 | class RandomTransforms(object): 372 | """Base class for a list of transformations with randomness 373 | 374 | Args: 375 | transforms (list or tuple): list of transformations 376 | """ 377 | 378 | def __init__(self, transforms): 379 | assert isinstance(transforms, (list, tuple)) 380 | self.transforms = transforms 381 | 382 | def __call__(self, *args, **kwargs): 383 | raise NotImplementedError() 384 | 385 | def __repr__(self): 386 | format_string = self.__class__.__name__ + '(' 387 | for t in self.transforms: 388 | format_string += '\n' 389 | format_string += ' {0}'.format(t) 390 | format_string += '\n)' 391 | return format_string 392 | 393 | 394 | class RandomApply(RandomTransforms): 395 | """Apply randomly a list of transformations with a given probability 396 | 397 | Args: 398 | transforms (list or tuple): list of transformations 399 | p (float): probability 400 | """ 401 | 402 | def __init__(self, transforms, p=0.5): 403 | super(RandomApply, self).__init__(transforms) 404 | self.p = p 405 | 406 | def __call__(self, img): 407 | if self.p < random.random(): 408 | return img 409 | for t in self.transforms: 410 | img = t(img) 411 | return img 412 | 413 | def __repr__(self): 414 | format_string = self.__class__.__name__ + '(' 415 | format_string += '\n p={}'.format(self.p) 416 | for t in self.transforms: 417 | format_string += '\n' 418 | format_string += ' {0}'.format(t) 419 | format_string += '\n)' 420 | return format_string 421 | 422 | 423 | class RandomOrder(RandomTransforms): 424 | """Apply a list of transformations in a random order 425 | """ 426 | 427 | def __call__(self, img): 428 | order = list(range(len(self.transforms))) 429 | random.shuffle(order) 430 | for i in order: 431 | img = self.transforms[i](img) 432 | return img 433 | 434 | 435 | class RandomChoice(RandomTransforms): 436 | """Apply single transformation randomly picked from a list 437 | """ 438 | 439 | def __call__(self, img): 440 | t = random.choice(self.transforms) 441 | return t(img) 442 | 443 | 444 | class RandomCrop(object): 445 | """Crop the given PIL Image at a random location. 446 | 447 | Args: 448 | size (sequence or int): Desired output size of the crop. If size is an 449 | int instead of sequence like (h, w), a square crop (size, size) is 450 | made. 451 | padding (int or sequence, optional): Optional padding on each border 452 | of the image. Default is None, i.e no padding. If a sequence of length 453 | 4 is provided, it is used to pad left, top, right, bottom borders 454 | respectively. If a sequence of length 2 is provided, it is used to 455 | pad left/right, top/bottom borders, respectively. 456 | pad_if_needed (boolean): It will pad the image if smaller than the 457 | desired size to avoid raising an exception. Since cropping is done 458 | after padding, the padding seems to be done at a random offset. 459 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 460 | length 3, it is used to fill R, G, B channels respectively. 461 | This value is only used when the padding_mode is constant 462 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 
463 | 464 | - constant: pads with a constant value, this value is specified with fill 465 | 466 | - edge: pads with the last value on the edge of the image 467 | 468 | - reflect: pads with reflection of image (without repeating the last value on the edge) 469 | 470 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 471 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 472 | 473 | - symmetric: pads with reflection of image (repeating the last value on the edge) 474 | 475 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 476 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 477 | 478 | """ 479 | 480 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 481 | if isinstance(size, numbers.Number): 482 | self.size = (int(size), int(size)) 483 | else: 484 | self.size = size 485 | self.padding = padding 486 | self.pad_if_needed = pad_if_needed 487 | self.fill = fill 488 | self.padding_mode = padding_mode 489 | 490 | @staticmethod 491 | def get_params(img, output_size): 492 | """Get parameters for ``crop`` for a random crop. 493 | 494 | Args: 495 | img (PIL Image): Image to be cropped. 496 | output_size (tuple): Expected output size of the crop. 497 | 498 | Returns: 499 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 500 | """ 501 | w, h = _get_image_size(img) 502 | th, tw = output_size 503 | if w == tw and h == th: 504 | return 0, 0, h, w 505 | 506 | i = random.randint(0, h - th) 507 | j = random.randint(0, w - tw) 508 | return i, j, th, tw 509 | 510 | def __call__(self, img): 511 | """ 512 | Args: 513 | img (PIL Image): Image to be cropped. 514 | 515 | Returns: 516 | PIL Image: Cropped image. 517 | """ 518 | if self.padding is not None: 519 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 520 | 521 | # pad the width if needed 522 | if self.pad_if_needed and img.size[0] < self.size[1]: 523 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 524 | # pad the height if needed 525 | if self.pad_if_needed and img.size[1] < self.size[0]: 526 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 527 | 528 | i, j, h, w = self.get_params(img, self.size) 529 | 530 | return F.crop(img, i, j, h, w) 531 | 532 | def __repr__(self): 533 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 534 | 535 | 536 | class RandomHorizontalFlip(torch.nn.Module): 537 | """Horizontally flip the given image randomly with a given probability. 538 | The image can be a PIL Image or a torch Tensor, in which case it is expected 539 | to have [..., H, W] shape, where ... means an arbitrary number of leading 540 | dimensions 541 | 542 | Args: 543 | p (float): probability of the image being flipped. Default value is 0.5 544 | """ 545 | 546 | def __init__(self, p=0.5): 547 | super().__init__() 548 | self.p = p 549 | 550 | def forward(self, img): 551 | """ 552 | Args: 553 | img (PIL Image or Tensor): Image to be flipped. 554 | 555 | Returns: 556 | PIL Image or Tensor: Randomly flipped image. 557 | """ 558 | if torch.rand(1) < self.p: 559 | return F.hflip(img) 560 | return img 561 | 562 | def __repr__(self): 563 | return self.__class__.__name__ + '(p={})'.format(self.p) 564 | 565 | 566 | class RandomVerticalFlip(torch.nn.Module): 567 | """Vertically flip the given PIL Image randomly with a given probability. 568 | The image can be a PIL Image or a torch Tensor, in which case it is expected 569 | to have [..., H, W] shape, where ... 
means an arbitrary number of leading 570 | dimensions 571 | 572 | Args: 573 | p (float): probability of the image being flipped. Default value is 0.5 574 | """ 575 | 576 | def __init__(self, p=0.5): 577 | super().__init__() 578 | self.p = p 579 | 580 | def forward(self, img): 581 | """ 582 | Args: 583 | img (PIL Image or Tensor): Image to be flipped. 584 | 585 | Returns: 586 | PIL Image or Tensor: Randomly flipped image. 587 | """ 588 | if torch.rand(1) < self.p: 589 | return F.vflip(img) 590 | return img 591 | 592 | def __repr__(self): 593 | return self.__class__.__name__ + '(p={})'.format(self.p) 594 | 595 | 596 | class RandomPerspective(object): 597 | """Performs Perspective transformation of the given PIL Image randomly with a given probability. 598 | 599 | Args: 600 | interpolation : Default- Image.BICUBIC 601 | 602 | p (float): probability of the image being perspectively transformed. Default value is 0.5 603 | 604 | distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5. 605 | 606 | fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. 607 | If int, it is used for all channels respectively. Default value is 0. 608 | """ 609 | 610 | def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0): 611 | self.p = p 612 | self.interpolation = interpolation 613 | self.distortion_scale = distortion_scale 614 | self.fill = fill 615 | 616 | def __call__(self, img): 617 | """ 618 | Args: 619 | img (PIL Image): Image to be Perspectively transformed. 620 | 621 | Returns: 622 | PIL Image: Random perspectivley transformed image. 623 | """ 624 | if not F._is_pil_image(img): 625 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 626 | 627 | if random.random() < self.p: 628 | width, height = img.size 629 | startpoints, endpoints = self.get_params(width, height, self.distortion_scale) 630 | return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) 631 | return img 632 | 633 | @staticmethod 634 | def get_params(width, height, distortion_scale): 635 | """Get parameters for ``perspective`` for a random perspective transform. 636 | 637 | Args: 638 | width : width of the image. 639 | height : height of the image. 640 | 641 | Returns: 642 | List containing [top-left, top-right, bottom-right, bottom-left] of the original image, 643 | List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. 
644 | """ 645 | half_height = int(height / 2) 646 | half_width = int(width / 2) 647 | topleft = (random.randint(0, int(distortion_scale * half_width)), 648 | random.randint(0, int(distortion_scale * half_height))) 649 | topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), 650 | random.randint(0, int(distortion_scale * half_height))) 651 | botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), 652 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) 653 | botleft = (random.randint(0, int(distortion_scale * half_width)), 654 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) 655 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] 656 | endpoints = [topleft, topright, botright, botleft] 657 | return startpoints, endpoints 658 | 659 | def __repr__(self): 660 | return self.__class__.__name__ + '(p={})'.format(self.p) 661 | 662 | 663 | class RandomResizedCrop(object): 664 | """Crop the given PIL Image to random size and aspect ratio. 665 | 666 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 667 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 668 | is finally resized to given size. 669 | This is popularly used to train the Inception networks. 670 | 671 | Args: 672 | size: expected output size of each edge 673 | scale: range of size of the origin size cropped 674 | ratio: range of aspect ratio of the origin aspect ratio cropped 675 | interpolation: Default: PIL.Image.BILINEAR 676 | """ 677 | 678 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 679 | if isinstance(size, (tuple, list)): 680 | self.size = size 681 | else: 682 | self.size = (size, size) 683 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 684 | warnings.warn("range should be of kind (min, max)") 685 | 686 | self.interpolation = interpolation 687 | self.scale = scale 688 | self.ratio = ratio 689 | 690 | @staticmethod 691 | def get_params(img, scale, ratio): 692 | """Get parameters for ``crop`` for a random sized crop. 693 | 694 | Args: 695 | img (PIL Image): Image to be cropped. 696 | scale (tuple): range of size of the origin size cropped 697 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 698 | 699 | Returns: 700 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 701 | sized crop. 
702 | """ 703 | width, height = _get_image_size(img) 704 | area = height * width 705 | 706 | for _ in range(10): 707 | target_area = random.uniform(*scale) * area 708 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 709 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 710 | 711 | w = int(round(math.sqrt(target_area * aspect_ratio))) 712 | h = int(round(math.sqrt(target_area / aspect_ratio))) 713 | 714 | if 0 < w <= width and 0 < h <= height: 715 | i = random.randint(0, height - h) 716 | j = random.randint(0, width - w) 717 | return i, j, h, w 718 | 719 | # Fallback to central crop 720 | in_ratio = float(width) / float(height) 721 | if (in_ratio < min(ratio)): 722 | w = width 723 | h = int(round(w / min(ratio))) 724 | elif (in_ratio > max(ratio)): 725 | h = height 726 | w = int(round(h * max(ratio))) 727 | else: # whole image 728 | w = width 729 | h = height 730 | i = (height - h) // 2 731 | j = (width - w) // 2 732 | return i, j, h, w 733 | 734 | def __call__(self, img): 735 | """ 736 | Args: 737 | img (PIL Image): Image to be cropped and resized. 738 | 739 | Returns: 740 | PIL Image: Randomly cropped and resized image. 741 | """ 742 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 743 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 744 | 745 | def __repr__(self): 746 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 747 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 748 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 749 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 750 | format_string += ', interpolation={0})'.format(interpolate_str) 751 | return format_string 752 | 753 | 754 | class RandomSizedCrop(RandomResizedCrop): 755 | """ 756 | Note: This transform is deprecated in favor of RandomResizedCrop. 757 | """ 758 | 759 | def __init__(self, *args, **kwargs): 760 | warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + 761 | "please use transforms.RandomResizedCrop instead.") 762 | super(RandomSizedCrop, self).__init__(*args, **kwargs) 763 | 764 | 765 | class FiveCrop(object): 766 | """Crop the given PIL Image into four corners and the central crop 767 | 768 | .. Note:: 769 | This transform returns a tuple of images and there may be a mismatch in the number of 770 | inputs and targets your Dataset returns. See below for an example of how to deal with 771 | this. 772 | 773 | Args: 774 | size (sequence or int): Desired output size of the crop. If size is an ``int`` 775 | instead of sequence like (h, w), a square crop of size (size, size) is made. 776 | 777 | Example: 778 | >>> transform = Compose([ 779 | >>> FiveCrop(size), # this is a list of PIL Images 780 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 781 | >>> ]) 782 | >>> #In your test loop you can do the following: 783 | >>> input, target = batch # input is a 5d tensor, target is 2d 784 | >>> bs, ncrops, c, h, w = input.size() 785 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 786 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 787 | """ 788 | 789 | def __init__(self, size): 790 | self.size = size 791 | if isinstance(size, numbers.Number): 792 | self.size = (int(size), int(size)) 793 | else: 794 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
795 | self.size = size 796 | 797 | def __call__(self, img): 798 | return F.five_crop(img, self.size) 799 | 800 | def __repr__(self): 801 | return self.__class__.__name__ + '(size={0})'.format(self.size) 802 | 803 | 804 | class TenCrop(object): 805 | """Crop the given PIL Image into four corners and the central crop plus the flipped version of 806 | these (horizontal flipping is used by default) 807 | 808 | .. Note:: 809 | This transform returns a tuple of images and there may be a mismatch in the number of 810 | inputs and targets your Dataset returns. See below for an example of how to deal with 811 | this. 812 | 813 | Args: 814 | size (sequence or int): Desired output size of the crop. If size is an 815 | int instead of sequence like (h, w), a square crop (size, size) is 816 | made. 817 | vertical_flip (bool): Use vertical flipping instead of horizontal 818 | 819 | Example: 820 | >>> transform = Compose([ 821 | >>> TenCrop(size), # this is a list of PIL Images 822 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 823 | >>> ]) 824 | >>> #In your test loop you can do the following: 825 | >>> input, target = batch # input is a 5d tensor, target is 2d 826 | >>> bs, ncrops, c, h, w = input.size() 827 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 828 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 829 | """ 830 | 831 | def __init__(self, size, vertical_flip=False): 832 | self.size = size 833 | if isinstance(size, numbers.Number): 834 | self.size = (int(size), int(size)) 835 | else: 836 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 837 | self.size = size 838 | self.vertical_flip = vertical_flip 839 | 840 | def __call__(self, img): 841 | return F.ten_crop(img, self.size, self.vertical_flip) 842 | 843 | def __repr__(self): 844 | return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) 845 | 846 | 847 | class LinearTransformation(object): 848 | """Transform a tensor image with a square transformation matrix and a mean_vector computed 849 | offline. 850 | Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and 851 | subtract mean_vector from it which is then followed by computing the dot 852 | product with the transformation matrix and then reshaping the tensor to its 853 | original shape. 854 | 855 | Applications: 856 | whitening transformation: Suppose X is a column vector zero-centered data. 857 | Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), 858 | perform SVD on this matrix and pass it as transformation_matrix. 859 | 860 | Args: 861 | transformation_matrix (Tensor): tensor [D x D], D = C x H x W 862 | mean_vector (Tensor): tensor [D], D = C x H x W 863 | """ 864 | 865 | def __init__(self, transformation_matrix, mean_vector): 866 | if transformation_matrix.size(0) != transformation_matrix.size(1): 867 | raise ValueError("transformation_matrix should be square. 
Got " + 868 | "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) 869 | 870 | if mean_vector.size(0) != transformation_matrix.size(0): 871 | raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + 872 | " as any one of the dimensions of the transformation_matrix [{}]" 873 | .format(tuple(transformation_matrix.size()))) 874 | 875 | self.transformation_matrix = transformation_matrix 876 | self.mean_vector = mean_vector 877 | 878 | def __call__(self, tensor): 879 | """ 880 | Args: 881 | tensor (Tensor): Tensor image of size (C, H, W) to be whitened. 882 | 883 | Returns: 884 | Tensor: Transformed image. 885 | """ 886 | if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): 887 | raise ValueError("tensor and transformation matrix have incompatible shape." + 888 | "[{} x {} x {}] != ".format(*tensor.size()) + 889 | "{}".format(self.transformation_matrix.size(0))) 890 | flat_tensor = tensor.view(1, -1) - self.mean_vector 891 | transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) 892 | tensor = transformed_tensor.view(tensor.size()) 893 | return tensor 894 | 895 | def __repr__(self): 896 | format_string = self.__class__.__name__ + '(transformation_matrix=' 897 | format_string += (str(self.transformation_matrix.tolist()) + ')') 898 | format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') 899 | return format_string 900 | 901 | 902 | class ColorJitter(torch.nn.Module): 903 | """Randomly change the brightness, contrast and saturation of an image. 904 | 905 | Args: 906 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 907 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 908 | or the given [min, max]. Should be non negative numbers. 909 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 910 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 911 | or the given [min, max]. Should be non negative numbers. 912 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 913 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 914 | or the given [min, max]. Should be non negative numbers. 915 | hue (float or tuple of float (min, max)): How much to jitter hue. 916 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 917 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
918 | """ 919 | 920 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 921 | super().__init__() 922 | self.brightness = self._check_input(brightness, 'brightness') 923 | self.contrast = self._check_input(contrast, 'contrast') 924 | self.saturation = self._check_input(saturation, 'saturation') 925 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 926 | clip_first_on_zero=False) 927 | 928 | @torch.jit.unused 929 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 930 | if isinstance(value, numbers.Number): 931 | if value < 0: 932 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 933 | value = [center - float(value), center + float(value)] 934 | if clip_first_on_zero: 935 | value[0] = max(value[0], 0.0) 936 | elif isinstance(value, (tuple, list)) and len(value) == 2: 937 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 938 | raise ValueError("{} values should be between {}".format(name, bound)) 939 | else: 940 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 941 | 942 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 943 | # or (0., 0.) for hue, do nothing 944 | if value[0] == value[1] == center: 945 | value = None 946 | return value 947 | 948 | @staticmethod 949 | @torch.jit.unused 950 | def get_params(brightness, contrast, saturation, hue): 951 | """Get a randomized transform to be applied on image. 952 | 953 | Arguments are same as that of __init__. 954 | 955 | Returns: 956 | Transform which randomly adjusts brightness, contrast and 957 | saturation in a random order. 958 | """ 959 | transforms = [] 960 | 961 | if brightness is not None: 962 | brightness_factor = random.uniform(brightness[0], brightness[1]) 963 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 964 | 965 | if contrast is not None: 966 | contrast_factor = random.uniform(contrast[0], contrast[1]) 967 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 968 | 969 | if saturation is not None: 970 | saturation_factor = random.uniform(saturation[0], saturation[1]) 971 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 972 | 973 | if hue is not None: 974 | hue_factor = random.uniform(hue[0], hue[1]) 975 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 976 | 977 | random.shuffle(transforms) 978 | transform = Compose(transforms) 979 | 980 | return transform 981 | 982 | def forward(self, img): 983 | """ 984 | Args: 985 | img (PIL Image or Tensor): Input image. 986 | 987 | Returns: 988 | PIL Image or Tensor: Color jittered image. 
989 | """ 990 | fn_idx = torch.randperm(4) 991 | for fn_id in fn_idx: 992 | if fn_id == 0 and self.brightness is not None: 993 | brightness = self.brightness 994 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 995 | img = F.adjust_brightness(img, brightness_factor) 996 | 997 | if fn_id == 1 and self.contrast is not None: 998 | contrast = self.contrast 999 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 1000 | img = F.adjust_contrast(img, contrast_factor) 1001 | 1002 | if fn_id == 2 and self.saturation is not None: 1003 | saturation = self.saturation 1004 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 1005 | img = F.adjust_saturation(img, saturation_factor) 1006 | 1007 | if fn_id == 3 and self.hue is not None: 1008 | hue = self.hue 1009 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 1010 | img = F.adjust_hue(img, hue_factor) 1011 | 1012 | return img 1013 | 1014 | def __repr__(self): 1015 | format_string = self.__class__.__name__ + '(' 1016 | format_string += 'brightness={0}'.format(self.brightness) 1017 | format_string += ', contrast={0}'.format(self.contrast) 1018 | format_string += ', saturation={0}'.format(self.saturation) 1019 | format_string += ', hue={0})'.format(self.hue) 1020 | return format_string 1021 | 1022 | 1023 | class RandomRotation(object): 1024 | """Rotate the image by angle. 1025 | 1026 | Args: 1027 | degrees (sequence or float or int): Range of degrees to select from. 1028 | If degrees is a number instead of sequence like (min, max), the range of degrees 1029 | will be (-degrees, +degrees). 1030 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 1031 | An optional resampling filter. See `filters`_ for more information. 1032 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 1033 | expand (bool, optional): Optional expansion flag. 1034 | If true, expands the output to make it large enough to hold the entire rotated image. 1035 | If false or omitted, make the output image the same size as the input image. 1036 | Note that the expand flag assumes rotation around the center and no translation. 1037 | center (2-tuple, optional): Optional center of rotation. 1038 | Origin is the upper left corner. 1039 | Default is the center of the image. 1040 | fill (n-tuple or int or float): Pixel fill value for area outside the rotated 1041 | image. If int or float, the value is used for all bands respectively. 1042 | Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. 1043 | 1044 | .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters 1045 | 1046 | """ 1047 | 1048 | def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): 1049 | if isinstance(degrees, numbers.Number): 1050 | if degrees < 0: 1051 | raise ValueError("If degrees is a single number, it must be positive.") 1052 | self.degrees = (-degrees, degrees) 1053 | else: 1054 | if len(degrees) != 2: 1055 | raise ValueError("If degrees is a sequence, it must be of len 2.") 1056 | self.degrees = degrees 1057 | 1058 | self.resample = resample 1059 | self.expand = expand 1060 | self.center = center 1061 | self.fill = fill 1062 | 1063 | @staticmethod 1064 | def get_params(degrees): 1065 | """Get parameters for ``rotate`` for a random rotation. 1066 | 1067 | Returns: 1068 | sequence: params to be passed to ``rotate`` for random rotation. 
1069 | """ 1070 | angle = random.uniform(degrees[0], degrees[1]) 1071 | 1072 | return angle 1073 | 1074 | def __call__(self, img): 1075 | """ 1076 | Args: 1077 | img (PIL Image): Image to be rotated. 1078 | 1079 | Returns: 1080 | PIL Image: Rotated image. 1081 | """ 1082 | 1083 | angle = self.get_params(self.degrees) 1084 | 1085 | return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) 1086 | 1087 | def __repr__(self): 1088 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) 1089 | format_string += ', resample={0}'.format(self.resample) 1090 | format_string += ', expand={0}'.format(self.expand) 1091 | if self.center is not None: 1092 | format_string += ', center={0}'.format(self.center) 1093 | if self.fill is not None: 1094 | format_string += ', fill={0}'.format(self.fill) 1095 | format_string += ')' 1096 | return format_string 1097 | 1098 | 1099 | class RandomAffine(object): 1100 | """Random affine transformation of the image keeping center invariant 1101 | 1102 | Args: 1103 | degrees (sequence or float or int): Range of degrees to select from. 1104 | If degrees is a number instead of sequence like (min, max), the range of degrees 1105 | will be (-degrees, +degrees). Set to 0 to deactivate rotations. 1106 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 1107 | and vertical translations. For example translate=(a, b), then horizontal shift 1108 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 1109 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 1110 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 1111 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 1112 | shear (sequence or float or int, optional): Range of degrees to select from. 1113 | If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) 1114 | will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the 1115 | range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, 1116 | a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. 1117 | Will not apply shear by default 1118 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 1119 | An optional resampling filter. See `filters`_ for more information. 1120 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 1121 | fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area 1122 | outside the transform in the output image.(Pillow>=5.0.0) 1123 | 1124 | .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters 1125 | 1126 | """ 1127 | 1128 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 1129 | if isinstance(degrees, numbers.Number): 1130 | if degrees < 0: 1131 | raise ValueError("If degrees is a single number, it must be positive.") 1132 | self.degrees = (-degrees, degrees) 1133 | else: 1134 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 1135 | "degrees should be a list or tuple and it must be of length 2." 
1136 | self.degrees = degrees 1137 | 1138 | if translate is not None: 1139 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 1140 | "translate should be a list or tuple and it must be of length 2." 1141 | for t in translate: 1142 | if not (0.0 <= t <= 1.0): 1143 | raise ValueError("translation values should be between 0 and 1") 1144 | self.translate = translate 1145 | 1146 | if scale is not None: 1147 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 1148 | "scale should be a list or tuple and it must be of length 2." 1149 | for s in scale: 1150 | if s <= 0: 1151 | raise ValueError("scale values should be positive") 1152 | self.scale = scale 1153 | 1154 | if shear is not None: 1155 | if isinstance(shear, numbers.Number): 1156 | if shear < 0: 1157 | raise ValueError("If shear is a single number, it must be positive.") 1158 | self.shear = (-shear, shear) 1159 | else: 1160 | assert isinstance(shear, (tuple, list)) and \ 1161 | (len(shear) == 2 or len(shear) == 4), \ 1162 | "shear should be a list or tuple and it must be of length 2 or 4." 1163 | # X-Axis shear with [min, max] 1164 | if len(shear) == 2: 1165 | self.shear = [shear[0], shear[1], 0., 0.] 1166 | elif len(shear) == 4: 1167 | self.shear = [s for s in shear] 1168 | else: 1169 | self.shear = shear 1170 | 1171 | self.resample = resample 1172 | self.fillcolor = fillcolor 1173 | 1174 | @staticmethod 1175 | def get_params(degrees, translate, scale_ranges, shears, img_size): 1176 | """Get parameters for affine transformation 1177 | 1178 | Returns: 1179 | sequence: params to be passed to the affine transformation 1180 | """ 1181 | angle = random.uniform(degrees[0], degrees[1]) 1182 | if translate is not None: 1183 | max_dx = translate[0] * img_size[0] 1184 | max_dy = translate[1] * img_size[1] 1185 | translations = (np.round(random.uniform(-max_dx, max_dx)), 1186 | np.round(random.uniform(-max_dy, max_dy))) 1187 | else: 1188 | translations = (0, 0) 1189 | 1190 | if scale_ranges is not None: 1191 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 1192 | else: 1193 | scale = 1.0 1194 | 1195 | if shears is not None: 1196 | if len(shears) == 2: 1197 | shear = [random.uniform(shears[0], shears[1]), 0.] 1198 | elif len(shears) == 4: 1199 | shear = [random.uniform(shears[0], shears[1]), 1200 | random.uniform(shears[2], shears[3])] 1201 | else: 1202 | shear = 0.0 1203 | 1204 | return angle, translations, scale, shear 1205 | 1206 | def __call__(self, img): 1207 | """ 1208 | img (PIL Image): Image to be transformed. 1209 | 1210 | Returns: 1211 | PIL Image: Affine transformed image. 1212 | """ 1213 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 1214 | return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) 1215 | 1216 | def __repr__(self): 1217 | s = '{name}(degrees={degrees}' 1218 | if self.translate is not None: 1219 | s += ', translate={translate}' 1220 | if self.scale is not None: 1221 | s += ', scale={scale}' 1222 | if self.shear is not None: 1223 | s += ', shear={shear}' 1224 | if self.resample > 0: 1225 | s += ', resample={resample}' 1226 | if self.fillcolor != 0: 1227 | s += ', fillcolor={fillcolor}' 1228 | s += ')' 1229 | d = dict(self.__dict__) 1230 | d['resample'] = _pil_interpolation_to_str[d['resample']] 1231 | return s.format(name=self.__class__.__name__, **d) 1232 | 1233 | 1234 | class Grayscale(object): 1235 | """Convert image to grayscale. 
1236 | 1237 | Args: 1238 | num_output_channels (int): (1 or 3) number of channels desired for output image 1239 | 1240 | Returns: 1241 | PIL Image: Grayscale version of the input. 1242 | - If ``num_output_channels == 1`` : returned image is single channel 1243 | - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b 1244 | 1245 | """ 1246 | 1247 | def __init__(self, num_output_channels=1): 1248 | self.num_output_channels = num_output_channels 1249 | 1250 | def __call__(self, img): 1251 | """ 1252 | Args: 1253 | img (PIL Image): Image to be converted to grayscale. 1254 | 1255 | Returns: 1256 | PIL Image: Randomly grayscaled image. 1257 | """ 1258 | return F.to_grayscale(img, num_output_channels=self.num_output_channels) 1259 | 1260 | def __repr__(self): 1261 | return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) 1262 | 1263 | 1264 | class RandomGrayscale(object): 1265 | """Randomly convert image to grayscale with a probability of p (default 0.1). 1266 | 1267 | Args: 1268 | p (float): probability that image should be converted to grayscale. 1269 | 1270 | Returns: 1271 | PIL Image: Grayscale version of the input image with probability p and unchanged 1272 | with probability (1-p). 1273 | - If input image is 1 channel: grayscale version is 1 channel 1274 | - If input image is 3 channel: grayscale version is 3 channel with r == g == b 1275 | 1276 | """ 1277 | 1278 | def __init__(self, p=0.1): 1279 | self.p = p 1280 | 1281 | def __call__(self, img): 1282 | """ 1283 | Args: 1284 | img (PIL Image): Image to be converted to grayscale. 1285 | 1286 | Returns: 1287 | PIL Image: Randomly grayscaled image. 1288 | """ 1289 | num_output_channels = 1 if img.mode == 'L' else 3 1290 | if random.random() < self.p: 1291 | return F.to_grayscale(img, num_output_channels=num_output_channels) 1292 | return img 1293 | 1294 | def __repr__(self): 1295 | return self.__class__.__name__ + '(p={0})'.format(self.p) 1296 | 1297 | 1298 | class RandomErasing(object): 1299 | """ Randomly selects a rectangle region in an image and erases its pixels. 1300 | 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf 1301 | 1302 | Args: 1303 | p: probability that the random erasing operation will be performed. 1304 | scale: range of proportion of erased area against input image. 1305 | ratio: range of aspect ratio of erased area. 1306 | value: erasing value. Default is 0. If a single int, it is used to 1307 | erase all pixels. If a tuple of length 3, it is used to erase 1308 | R, G, B channels respectively. 1309 | If a str of 'random', erasing each pixel with random values. 1310 | inplace: boolean to make this transform inplace. Default set to False. 1311 | 1312 | Returns: 1313 | Erased Image. 
1314 | 1315 | # Examples: 1316 | >>> transform = transforms.Compose([ 1317 | >>> transforms.RandomHorizontalFlip(), 1318 | >>> transforms.ToTensor(), 1319 | >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 1320 | >>> transforms.RandomErasing(), 1321 | >>> ]) 1322 | """ 1323 | 1324 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 1325 | assert isinstance(value, (numbers.Number, str, tuple, list)) 1326 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 1327 | warnings.warn("range should be of kind (min, max)") 1328 | if scale[0] < 0 or scale[1] > 1: 1329 | raise ValueError("range of scale should be between 0 and 1") 1330 | if p < 0 or p > 1: 1331 | raise ValueError("range of random erasing probability should be between 0 and 1") 1332 | 1333 | self.p = p 1334 | self.scale = scale 1335 | self.ratio = ratio 1336 | self.value = value 1337 | self.inplace = inplace 1338 | 1339 | @staticmethod 1340 | def get_params(img, scale, ratio, value=0): 1341 | """Get parameters for ``erase`` for a random erasing. 1342 | 1343 | Args: 1344 | img (Tensor): Tensor image of size (C, H, W) to be erased. 1345 | scale: range of proportion of erased area against input image. 1346 | ratio: range of aspect ratio of erased area. 1347 | 1348 | Returns: 1349 | tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. 1350 | """ 1351 | img_c, img_h, img_w = img.shape 1352 | area = img_h * img_w 1353 | 1354 | for _ in range(10): 1355 | erase_area = random.uniform(scale[0], scale[1]) * area 1356 | aspect_ratio = random.uniform(ratio[0], ratio[1]) 1357 | 1358 | h = int(round(math.sqrt(erase_area * aspect_ratio))) 1359 | w = int(round(math.sqrt(erase_area / aspect_ratio))) 1360 | 1361 | if h < img_h and w < img_w: 1362 | i = random.randint(0, img_h - h) 1363 | j = random.randint(0, img_w - w) 1364 | if isinstance(value, numbers.Number): 1365 | v = value 1366 | elif isinstance(value, torch._six.string_classes): 1367 | v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() 1368 | elif isinstance(value, (list, tuple)): 1369 | v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w) 1370 | return i, j, h, w, v 1371 | 1372 | # Return original image 1373 | return 0, 0, img_h, img_w, img 1374 | 1375 | def __call__(self, img): 1376 | """ 1377 | Args: 1378 | img (Tensor): Tensor image of size (C, H, W) to be erased. 1379 | 1380 | Returns: 1381 | img (Tensor): Erased Tensor image. 1382 | """ 1383 | if random.uniform(0, 1) < self.p: 1384 | x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value) 1385 | return F.erase(img, x, y, h, w, v, self.inplace) 1386 | return img 1387 | --------------------------------------------------------------------------------
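The classes in `networks/transforms.py` mirror the torchvision transform API, so they can be chained with `Compose` for on-the-fly patch augmentation. The following is a minimal sketch, assuming single-channel 64x64 patches; the flip/rotation settings and normalization statistics are illustrative placeholders, not the repository's actual training configuration.

```python
# Minimal sketch: chaining the transforms defined in networks/transforms.py
# for single-channel patch augmentation. Patch size, augmentation parameters
# and normalization statistics below are assumptions for illustration only.
import numpy as np
import torch

from networks.transforms import (Compose, ToPILImage, RandomHorizontalFlip,
                                 RandomVerticalFlip, RandomRotation, ToTensor,
                                 Normalize)

augment = Compose([
    ToPILImage(),                    # uint8 ndarray (H x W) -> PIL image, mode 'L'
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
    RandomRotation(degrees=10),      # rotation sampled uniformly from (-10, 10)
    ToTensor(),                      # PIL image -> float tensor in [0, 1]
    Normalize(mean=[0.5], std=[0.5]),
])

patch = np.random.randint(0, 256, size=(64, 64), dtype=np.uint8)  # dummy patch
tensor = augment(patch)
print(tensor.shape)  # torch.Size([1, 64, 64])
```

Because `ToTensor` scales `uint8` input into `[0, 1]`, the `Normalize(mean=[0.5], std=[0.5])` step maps the augmented patch roughly into `[-1, 1]`.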
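The `LinearTransformation` docstring describes its whitening use case only in prose; the sketch below spells out that recipe, assuming flattened 32x32 single-channel patches and an illustrative epsilon for numerical stability.

```python
# Sketch: building a ZCA-whitening LinearTransformation from a batch of
# flattened patches, following the recipe in the LinearTransformation
# docstring. The batch, patch size and epsilon are illustrative assumptions.
import torch

from networks.transforms import LinearTransformation

X = torch.randn(1000, 1 * 32 * 32)            # N x D matrix of flattened patches
mean_vector = X.mean(dim=0)                    # per-dimension mean, length D
X_centered = X - mean_vector
cov = torch.mm(X_centered.t(), X_centered) / (X.size(0) - 1)   # D x D covariance
U, S, V = torch.svd(cov)
eps = 1e-5                                     # guards against tiny singular values
whitening_matrix = torch.mm(torch.mm(U, torch.diag(1.0 / torch.sqrt(S + eps))), U.t())

whiten = LinearTransformation(whitening_matrix, mean_vector)
patch = torch.rand(1, 32, 32)                  # C x H x W tensor
whitened = whiten(patch)                       # same shape, decorrelated dimensions
```

Since the ZCA matrix is symmetric, applying it to the row vector inside `LinearTransformation.__call__` is equivalent to the usual left multiplication of the centered, flattened patch.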