├── sodeep_master
│   ├── __init__.py
│   ├── utils.py
│   ├── dataset.py
│   ├── sodeep.py
│   ├── train.py
│   └── model.py
├── __init__.py
├── spearman.py
├── eval.py
├── accloss.py
├── README.md
├── img_ensamble.py
├── submission.py
├── fresunet.py
├── siamunet_conc.py
├── siamunet_diff.py
├── siamunet_conc_extrahead.py
└── main.py
/sodeep_master/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import --------------------------------------------------------------------------------
/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import sodeep_master 3 | --------------------------------------------------------------------------------
/spearman.py: -------------------------------------------------------------------------------- 1 | from sodeep_master.sodeep import load_sorter, SpearmanLoss 2 | 3 | 4 | def Spear(sorter_checkpoint_path, device="cuda"): 5 | criterion = SpearmanLoss(*load_sorter(sorter_checkpoint_path)) 6 | criterion.to(device) 7 | return criterion 8 | --------------------------------------------------------------------------------
/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | 4 | 5 | def eval(predicted, gt): 6 | r_s = np.abs(stats.spearmanr(predicted, gt))[0] 7 | 8 | z = np.polyfit(predicted, gt, 3) 9 | fit_func = np.poly1d(z) 10 | fitted_MOS = fit_func(predicted) 11 | r_p = np.abs(stats.pearsonr(fitted_MOS, gt))[0] 12 | 13 | return round(r_s + r_p, 4), round(r_s, 4), round(r_p, 4) 14 | --------------------------------------------------------------------------------
/accloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def m_pearsonr(output, target): 5 | x = output 6 | y = target 7 | 8 | vx = x - torch.mean(x) 9 | vy = y - torch.mean(y) 10 | 11 | pr = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2))) 12 | 13 | return pr 14 | 15 | 16 | def accloss(output, target): 17 | pr = m_pearsonr(output, target) 18 | return pr 19 | 20 | 21 | if __name__ == '__main__': 22 | output = torch.rand(1, 5) 23 | target = torch.rand(1, 5) 24 | pytt = accloss(output, target) 25 | --------------------------------------------------------------------------------
/README.md: --------------------------------------------------------------------------------
# ASNA_MACS_IQA (PyTorch)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white)

TensorFlow version: https://github.com/smehdia/NTIRE2021-IQA-MACS

### **Image Quality Assessment (IQA) Challenge - NTIRE 2021**

------

- ***Challenge Paper*** : [NTIRE 2021 Challenge on Perceptual Image Quality Assessment](https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Gu_NTIRE_2021_Challenge_on_Perceptual_Image_Quality_Assessment_CVPRW_2021_paper.pdf)

- ***Our Paper*** : [(ASNA) An Attention-Based Siamese-Difference Neural Network With Surrogate Ranking Loss Function for Perceptual Image Quality Assessment](https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Ayyoubzadeh_ASNA_An_Attention-Based_Siamese-Difference_Neural_Network_With_Surrogate_Ranking_Loss_CVPRW_2021_paper.pdf)

- ***Supp. Materials*** : [Supplementary Materials](https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/supplemental/Ayyoubzadeh_ASNA_An_Attention-Based_CVPRW_2021_supplemental.pdf)
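### Usage (sketch)

A minimal, hypothetical training step showing how the pieces fit together: the Siamese network predicts one quality score per distorted/reference pair, and the SoDeep-based Spearman surrogate loss from `spearman.py` is added to a plain MSE term. This is a sketch, not code from this repo: the checkpoint path, patch size (288x288, PIPAL-style), loss weighting, and optimizer settings are illustrative assumptions, and the batch size must equal the sequence length the sorter was trained with (`-slen`, default 30).

```python
import torch
from spearman import Spear
from siamunet_conc import SiamUnet_conc

net = SiamUnet_conc(3, 1).cuda()                       # Siamese U-Net quality regressor
criterion_rank = Spear("weights/best_sorter.pth.tar")  # hypothetical sorter checkpoint
criterion_mse = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)

dis = torch.rand(30, 3, 288, 288).cuda()  # distorted patches (batch = sorter seq_len)
ref = torch.rand(30, 3, 288, 288).cuda()  # reference patches
mos = torch.rand(30).cuda()               # normalised MOS targets

pred = net(dis, ref).view(-1)             # one score per pair
loss = criterion_mse(pred, mos) + criterion_rank(pred, mos)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```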
We use the [SoDeep](https://github.com/technicolor-research/sodeep) code for the ranking-loss section.
--------------------------------------------------------------------------------
/img_ensamble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | device = "cuda" 5 | 6 | 7 | def reshape_for_torch(I): 8 | """Transpose image for PyTorch coordinates.""" 9 | # out = np.swapaxes(I,1,2) 10 | # out = np.swapaxes(out,0,1) 11 | # out = out[np.newaxis,:] 12 | out = I.transpose((2, 0, 1)) 13 | out = np.expand_dims(out, axis=0) 14 | return torch.from_numpy(1.0 * out) 15 | 16 | 17 | def image_ensamble(I1_, I2_, net): 18 | all_predicted_score = [] 19 | 20 | for n in [0, 1, 2, 3]: 21 | for direction in ["none", "v", "h"]: 22 | I1, I2 = RandomFlip(I1_, I2_, direction) 23 | I1, I2 = RandomRot(I1, I2, n) 24 | predicted_score = net(I1, I2) 25 | predicted_score_ = torch.squeeze(predicted_score).detach().cpu().numpy() 26 | all_predicted_score.append(predicted_score_) 27 | avg_score = np.array(all_predicted_score).mean() 28 | return avg_score 29 | 30 | 31 | def RandomRot(I1_, I2_, n): 32 | I1 = I1_[0, ...] 33 | I2 = I2_[0, ...] 34 | 35 | I1 = I1.cpu().numpy().copy() 36 | I1 = np.rot90(I1, n, axes=(1, 2)).copy() 37 | I1 = np.expand_dims(I1, axis=0) 38 | I1 = torch.from_numpy(I1).to(device) 39 | 40 | I2 = I2.cpu().numpy().copy() 41 | I2 = np.rot90(I2, n, axes=(1, 2)).copy() 42 | I2 = np.expand_dims(I2, axis=0) 43 | I2 = torch.from_numpy(I2).to(device) 44 | 45 | return I1, I2 46 | 47 | 48 | def RandomFlip(I1_, I2_, direction): 49 | I1 = I1_ 50 | I2 = I2_ 51 | if direction == "none": 52 | return I1_, I2_ 53 | elif direction == "v": 54 | I1 = I1.cpu().numpy()[:, :, :, ::-1].copy() 55 | I1 = torch.from_numpy(I1).to(device) 56 | I2 = I2.cpu().numpy()[:, :, :, ::-1].copy() 57 | I2 = torch.from_numpy(I2).to(device) 58 | 59 | elif direction == "h": 60 | I1 = I1.cpu().numpy()[:, :, ::-1, :].copy() 61 | I1 = torch.from_numpy(I1).to(device) 62 | I2 = I2.cpu().numpy()[:, :, ::-1, :].copy() 63 | I2 = torch.from_numpy(I2).to(device) 64 | 65 | return I1, I2 66 | --------------------------------------------------------------------------------
/sodeep_master/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | ****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** 3 | Copyright (c) 2019 [Thomson Licensing] 4 | All Rights Reserved 5 | This program contains proprietary information which is a trade secret/business \ 6 | secret of [Thomson Licensing] and is protected, even if unpublished, under \ 7 | applicable Copyright laws (including French droit d'auteur) and/or may be \ 8 | subject to one or more patent(s). 9 | Recipient is to retain this program in confidence and is not permitted to use \ 10 | or make copies thereof other than as permitted in a written agreement with \ 11 | [Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ 12 | by [Thomson Licensing] under express agreement.
13 | Thomson Licensing is a company of the group TECHNICOLOR 14 | ******************************************************************************* 15 | This scripts permits one to reproduce training and experiments of: 16 | Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2019, June). 17 | SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates. 18 | In Proceedings of CVPR 19 | 20 | Author: Martin Engilberge 21 | """ 22 | 23 | import torch 24 | 25 | 26 | def get_rank(batch_score, dim=0): 27 | rank = torch.argsort(batch_score, dim=dim) 28 | rank = torch.argsort(rank, dim=dim) 29 | rank = (rank * -1) + batch_score.size(dim) 30 | rank = rank.float() 31 | rank = rank / batch_score.size(dim) 32 | 33 | return rank 34 | 35 | 36 | class AverageMeter(object): 37 | """Computes and stores the average and current value""" 38 | 39 | def __init__(self): 40 | self.reset() 41 | 42 | def reset(self): 43 | self.val = 0 44 | self.avg = 0 45 | self.sum = 0 46 | self.count = 0 47 | 48 | def update(self, val, n=1): 49 | self.val = val 50 | self.sum += val * n 51 | self.count += n 52 | self.avg = self.sum / self.count 53 | 54 | 55 | def build_vocab(sentences): 56 | vocab = {} 57 | for sentence in sentences: 58 | for word in sentence: 59 | try: 60 | vocab[word] += 1 61 | except KeyError: 62 | vocab[word] = 1 63 | return vocab 64 | 65 | 66 | def save_checkpoint(state, is_best, model_name, epoch): 67 | if is_best: 68 | torch.save(state, './weights/best_' + model_name + ".pth.tar") 69 | 70 | 71 | def count_parameters(model): 72 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 73 | 74 | 75 | def log_epoch(logger, epoch, train_loss, val_loss, lr, batch_train, batch_val, data_train, data_val): 76 | logger.add_scalar('Loss/Train', train_loss, epoch) 77 | logger.add_scalar('Loss/Val', val_loss, epoch) 78 | logger.add_scalar('Learning/Rate', lr, epoch) 79 | logger.add_scalar('Learning/Overfitting', val_loss / train_loss, epoch) 80 | logger.add_scalar('Time/Train/Batch Processing', batch_train, epoch) 81 | logger.add_scalar('Time/Val/Batch Processing', batch_val, epoch) 82 | logger.add_scalar('Time/Train/Data loading', data_train, epoch) 83 | logger.add_scalar('Time/Val/Data loading', data_val, epoch) 84 | 85 | 86 | def flatten(l): 87 | return [item for sublist in l for item in sublist] 88 | -------------------------------------------------------------------------------- /submission.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | from skimage import io 5 | import random 6 | from tqdm import tqdm 7 | from torch.autograd import Variable 8 | from siamunet_diff import SiamUnet_diff 9 | from siamunet_conc import SiamUnet_conc 10 | from checkpoint.cinavad_sever.siamunet_conc_extrahead import SiamUnet_conc 11 | import torch 12 | from img_ensamble import image_ensamble 13 | 14 | # conf_type = r"img(mean_std_norm)_label(mean_std_norm)_maxpool_1e-5_type_covcat_pretrain_epoch-180_loss-0.42687_mse1_pr0.5_spr0.5" 15 | conf_type = r"D:\NTIRE Workshop and Challenges @ CVPR 2021\results\60\img(mean_std_norm)_label(mean_std_norm)_maxpool_1e-5_type_covcat_pretrain_epoch-172_loss-0.575_mse1_pr1_spr1" 16 | NORMALISE_IMGS = True 17 | device = "cuda" 18 | TYPE = "new" 19 | 20 | 21 | def apply_img_to_net(img): 22 | return random.uniform(1260.125454, 1590.555546) 23 | 24 | 25 | def reshape_for_torch(I): 26 | """Transpose image for PyTorch coordinates.""" 27 | # out = np.swapaxes(I,1,2) 28 | # out = np.swapaxes(out,0,1) 29 | 
# out = out[np.newaxis,:] 30 | out = I.transpose((2, 0, 1)) 31 | out = np.expand_dims(out, axis=0) 32 | return torch.from_numpy(1.0 * out) 33 | 34 | 35 | def count_parameters(model): 36 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | 38 | 39 | if __name__ == '__main__': 40 | if TYPE == 2: 41 | net = SiamUnet_conc(3, 1) 42 | elif TYPE == "new": 43 | net = SiamUnet_conc(3, 32) 44 | elif TYPE == 3: 45 | net = SiamUnet_diff(3, 1) 46 | if device == "cuda": 47 | net.cuda() 48 | # checkpoint_path = torch.load(rf'./checkpoint/{conf_type}/ch_net-best_epoch-37_accu-1.5132.pth.tar') 49 | checkpoint = torch.load(rf'{conf_type}\ch_net-best_epoch-389_loss-2.5346076488494873.pth.tar') 50 | net.load_state_dict(checkpoint['model_state_dict']) 51 | 52 | val_path = r"D:\NTIRE Workshop and Challenges @ CVPR 2021\dataset\Dis" 53 | save_path = r"D:\NTIRE Workshop and Challenges @ CVPR 2021\results" 54 | 55 | with torch.no_grad(): 56 | net.eval() 57 | for img in tqdm(glob.glob(os.path.join(val_path, "*.bmp"))): 58 | ref = img.replace("Dis", "Ref")[:-10] + ".bmp" 59 | 60 | I1_ = io.imread(img) 61 | I2_ = io.imread(ref) 62 | 63 | if NORMALISE_IMGS: 64 | I1_m = (I1_ - I1_.mean()) / I1_.std() 65 | I2_m = (I2_ - I2_.mean()) / I2_.std() 66 | else: 67 | I1_m = I1_ 68 | I2_m = I2_ 69 | I1 = Variable(reshape_for_torch(I1_m).float().to(device)) 70 | I2 = Variable(reshape_for_torch(I2_m).float().to(device)) 71 | 72 | # predicted_score = net(I1, I2) 73 | avg_predicted_score = image_ensamble(I1, I2, net) 74 | std = 121.7751 75 | mean = 1448.9539 76 | # predicted_score = (torch.squeeze(predicted_score).detach().cpu().numpy())*std + mean 77 | predicted_score = (avg_predicted_score * std) + mean 78 | # predicted_score = avg_predicted_score + mean 79 | # predicted_score = avg_predicted_score 80 | with open(os.path.join(save_path, "output.txt"), "a") as out_file: 81 | out_file.write(img.split("\\")[-1] + "," + f"{predicted_score:.4f}" + "\n") 82 | -------------------------------------------------------------------------------- /sodeep_master/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | ****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** 3 | Copyright (c) 2019 [Thomson Licensing] 4 | All Rights Reserved 5 | This program contains proprietary information which is a trade secret/business \ 6 | secret of [Thomson Licensing] and is protected, even if unpublished, under \ 7 | applicable Copyright laws (including French droit d'auteur) and/or may be \ 8 | subject to one or more patent(s). 9 | Recipient is to retain this program in confidence and is not permitted to use \ 10 | or make copies thereof other than as permitted in a written agreement with \ 11 | [Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ 12 | by [Thomson Licensing] under express agreement. 13 | Thomson Licensing is a company of the group TECHNICOLOR 14 | ******************************************************************************* 15 | This scripts permits one to reproduce training and experiments of: 16 | Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2019, June). 17 | SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates. 
18 | In Proceedings of CVPR 19 | 20 | Author: Martin Engilberge 21 | """ 22 | 23 | import numpy as np 24 | import torch 25 | 26 | from random import randint 27 | from torch.utils.data import Dataset 28 | 29 | 30 | def get_rand_seq(seq_len, ind=None): 31 | if ind is None: 32 | type_rand = randint(0, 9) 33 | else: 34 | type_rand = int(ind) 35 | 36 | if type_rand == 0: 37 | rand_seq = np.random.rand(seq_len) * 2.0 - 1 38 | elif type_rand == 1: 39 | rand_seq = np.random.uniform(-1, 1, seq_len) 40 | elif type_rand == 2: 41 | rand_seq = np.random.standard_normal(seq_len) 42 | elif type_rand == 3: 43 | a = np.random.rand() 44 | b = np.random.rand() 45 | rand_seq = np.arange(a, b, (b - a) / seq_len) 46 | elif type_rand == 4: 47 | a = np.random.rand() 48 | b = np.random.rand() 49 | rand_seq = np.arange(a, b, (b - a) / seq_len) 50 | np.random.shuffle(rand_seq) 51 | elif type_rand == 5: 52 | split = randint(1, seq_len) 53 | rand_seq = np.concatenate( 54 | [np.random.rand(split) * 2.0 - 1, np.random.standard_normal(seq_len - split)]) 55 | elif type_rand == 6: 56 | split = randint(1, seq_len) 57 | rand_seq = np.concatenate( 58 | [np.random.uniform(-1, 1, split), np.random.standard_normal(seq_len - split)]) 59 | elif type_rand == 7: 60 | split = randint(1, seq_len) 61 | rand_seq = np.concatenate( 62 | [np.random.rand(split) * 2.0 - 1, np.random.uniform(-1, 1, seq_len - split)]) 63 | elif type_rand == 8: 64 | split = randint(1, seq_len) 65 | a = np.random.rand() 66 | b = np.random.rand() 67 | rand_seq = np.arange(a, b, (b - a) / split) 68 | np.random.shuffle(rand_seq) 69 | rand_seq = np.concatenate( 70 | [rand_seq, np.random.rand(seq_len - split) * 2.0 - 1]) 71 | elif type_rand == 9: 72 | a = -1.0 73 | b = 1.0 74 | rand_seq = np.arange(a, b, (b - a) / seq_len) 75 | elif type_rand == 10: 76 | rand_seq = np.random.rand(seq_len) * np.random.rand() - np.random.rand() 77 | 78 | return rand_seq[:seq_len] 79 | 80 | 81 | class SeqDataset(Dataset): 82 | 83 | def __init__(self, seq_len, nb_sample=400000, dist=None): 84 | self.seq_len = seq_len 85 | self.nb_sample = nb_sample 86 | 87 | self.dist = dist 88 | 89 | def __getitem__(self, index): 90 | rand_seq = get_rand_seq(self.seq_len, self.dist) 91 | zipp_sort_ind = zip(np.argsort(rand_seq)[::-1], range(self.seq_len)) 92 | 93 | ranks = [((y[1] + 1) / float(self.seq_len)) for y in sorted(zipp_sort_ind, key=lambda x: x[0])] 94 | 95 | return torch.FloatTensor(rand_seq), torch.FloatTensor(ranks) 96 | 97 | def __len__(self): 98 | return self.nb_sample 99 | 100 | 101 | def get_rank_single(batch_score): 102 | rank = torch.argsort(batch_score, dim=0) 103 | rank = torch.argsort(rank, dim=0) 104 | rank = (rank * -1) + batch_score.size(0) 105 | rank = rank.float() 106 | rank = rank / batch_score.size(0) 107 | 108 | return rank 109 | -------------------------------------------------------------------------------- /sodeep_master/sodeep.py: -------------------------------------------------------------------------------- 1 | """ 2 | ****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** 3 | Copyright (c) 2019 [Thomson Licensing] 4 | All Rights Reserved 5 | This program contains proprietary information which is a trade secret/business \ 6 | secret of [Thomson Licensing] and is protected, even if unpublished, under \ 7 | applicable Copyright laws (including French droit d'auteur) and/or may be \ 8 | subject to one or more patent(s). 
9 | Recipient is to retain this program in confidence and is not permitted to use \ 10 | or make copies thereof other than as permitted in a written agreement with \ 11 | [Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ 12 | by [Thomson Licensing] under express agreement. 13 | Thomson Licensing is a company of the group TECHNICOLOR 14 | ******************************************************************************* 15 | This scripts permits one to reproduce training and experiments of: 16 | Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2019, June). 17 | SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates. 18 | In Proceedings of CVPR 19 | 20 | Author: Martin Engilberge 21 | """ 22 | 23 | import torch 24 | 25 | from .model import model_loader 26 | from .utils import get_rank 27 | 28 | 29 | def load_sorter(checkpoint_path): 30 | sorter_checkpoint = torch.load(checkpoint_path) 31 | 32 | model_type = sorter_checkpoint["args_dict"].model_type 33 | seq_len = sorter_checkpoint["args_dict"].seq_len 34 | state_dict = sorter_checkpoint["state_dict"] 35 | 36 | return model_type, seq_len, state_dict 37 | 38 | 39 | class RankHardLoss(torch.nn.Module): 40 | """ Loss function inspired by hard negative triplet loss, directly applied in the rank domain """ 41 | def __init__(self, sorter_type, seq_len=None, sorter_state_dict=None, margin=0.2, nmax=1): 42 | super(RankHardLoss, self).__init__() 43 | self.nmax = nmax 44 | self.margin = margin 45 | 46 | self.sorter = model_loader(sorter_type, seq_len, sorter_state_dict) 47 | 48 | def hc_loss(self, scores): 49 | rank = self.sorter(scores) 50 | 51 | diag = rank.diag() 52 | 53 | rank = rank + torch.diag(torch.ones(rank.diag().size(), device=rank.device) * 50.0) 54 | 55 | sorted_rank, _ = torch.sort(rank, 1, descending=False) 56 | 57 | hard_neg_rank = sorted_rank[:, :self.nmax] 58 | 59 | loss = torch.sum(torch.clamp(-hard_neg_rank + (1.0 / (scores.size(1)) + diag).view(-1, 1).expand_as(hard_neg_rank), min=0)) 60 | 61 | return loss 62 | 63 | def forward(self, scores): 64 | """ Expect a score matrix with scores of the positive pairs are on the diagonal """ 65 | caption_loss = self.hc_loss(scores) 66 | image_loss = self.hc_loss(scores.t()) 67 | 68 | image_caption_loss = caption_loss + image_loss 69 | 70 | return image_caption_loss 71 | 72 | 73 | class RankLoss(torch.nn.Module): 74 | """ Loss function inspired by recall """ 75 | def __init__(self, sorter_type, seq_len=None, sorter_state_dict=None,): 76 | super(RankLoss, self).__init__() 77 | self.sorter = model_loader(sorter_type, seq_len, sorter_state_dict) 78 | 79 | def forward(self, scores): 80 | """ Expect a score matrix with scores of the positive pairs are on the diagonal """ 81 | caption_rank = self.sorter(scores) 82 | image_rank = self.sorter(scores.t()) 83 | 84 | image_caption_loss = torch.sum(caption_rank.diag()) + torch.sum(image_rank.diag()) 85 | 86 | return image_caption_loss 87 | 88 | 89 | class MapRankingLoss(torch.nn.Module): 90 | """ Loss function inspired by mean Average Precision """ 91 | def __init__(self, sorter_type, seq_len=None, sorter_state_dict=None): 92 | super(MapRankingLoss, self).__init__() 93 | 94 | self.sorter = model_loader(sorter_type, seq_len, sorter_state_dict) 95 | 96 | def forward(self, output, target): 97 | # Compute map for each classes 98 | map_tot = 0 99 | for c in range(target.size(1)): 100 | gt_c = target[:, c] 101 | 102 | if torch.sum(gt_c) == 0: 103 | continue 104 | rank_pred = self.sorter(output[:, c].unsqueeze(0)).view(-1) 105 
| rank_pos = rank_pred * gt_c 106 | 107 | map_tot += torch.sum(rank_pos) 108 | 109 | return map_tot 110 | 111 | 112 | class SpearmanLoss(torch.nn.Module): 113 | """ Loss function inspired by the Spearman correlation. 114 | Requires the trained sorter to have a good initialization. 115 | 116 | Set lbd to 1 for a few epochs to help with the initialization. 117 | """ 118 | def __init__(self, sorter_type, seq_len=None, sorter_state_dict=None, lbd=0): 119 | super(SpearmanLoss, self).__init__() 120 | self.sorter = model_loader(sorter_type, seq_len, sorter_state_dict) 121 | 122 | self.criterion_mse = torch.nn.MSELoss() 123 | self.criterionl1 = torch.nn.L1Loss() 124 | 125 | self.lbd = lbd 126 | 127 | def forward(self, mem_pred, mem_gt, pr=False): 128 | rank_gt = get_rank(mem_gt) 129 | 130 | rank_pred = self.sorter(mem_pred.unsqueeze(0)).view(-1) 131 | 132 | return self.criterion_mse(rank_pred, rank_gt) + self.lbd * self.criterionl1(mem_pred, mem_gt) 133 | --------------------------------------------------------------------------------
/fresunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.padding import ReplicationPad2d 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | "3x3 convolution with padding" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) 9 | 10 | 11 | class BasicBlock_ss(nn.Module): 12 | 13 | def __init__(self, inplanes, planes=None, subsamp=1): 14 | super(BasicBlock_ss, self).__init__() 15 | if planes is None: 16 | planes = inplanes * subsamp 17 | self.conv1 = conv3x3(inplanes, planes) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu = nn.ReLU(inplace=True) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.subsamp = subsamp 23 | self.doit = planes != inplanes 24 | if self.doit: 25 | self.couple = nn.Conv2d(inplanes, planes, kernel_size=1) 26 | self.bnc = nn.BatchNorm2d(planes) 27 | 28 | def forward(self, x): 29 | if self.doit: 30 | residual = self.couple(x) 31 | residual = self.bnc(residual) 32 | else: 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | if self.subsamp > 1: 40 | out = F.max_pool2d(out, kernel_size=self.subsamp, stride=self.subsamp) 41 | residual = F.max_pool2d(residual, kernel_size=self.subsamp, stride=self.subsamp) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | 52 | 53 | class BasicBlock_us(nn.Module): 54 | 55 | def __init__(self, inplanes, upsamp=1): 56 | super(BasicBlock_us, self).__init__() 57 | planes = int(inplanes / upsamp) # assumes integer result, fix later 58 | self.conv1 = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.conv2 = conv3x3(planes, planes) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.upsamp = upsamp 64 | self.couple = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1) 65 | self.bnc = nn.BatchNorm2d(planes) 66 | 67 | def forward(self, x): 68 | residual = self.couple(x) 69 | residual = self.bnc(residual) 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | 76 | out = self.conv2(out) 77 | out = self.bn2(out) 78 | 79 | out += residual 80 | out = 
self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class FresUNet(nn.Module): 86 | """FresUNet segmentation network.""" 87 | 88 | def __init__(self, input_nbr, label_nbr): 89 | """Init FresUNet fields.""" 90 | super(FresUNet, self).__init__() 91 | 92 | self.input_nbr = input_nbr 93 | 94 | cur_depth = input_nbr 95 | 96 | base_depth = 8 97 | 98 | # Encoding stage 1 99 | self.encres1_1 = BasicBlock_ss(cur_depth, planes = base_depth) 100 | cur_depth = base_depth 101 | d1 = base_depth 102 | self.encres1_2 = BasicBlock_ss(cur_depth, subsamp=2) 103 | cur_depth *= 2 104 | 105 | # Encoding stage 2 106 | self.encres2_1 = BasicBlock_ss(cur_depth) 107 | d2 = cur_depth 108 | self.encres2_2 = BasicBlock_ss(cur_depth, subsamp=2) 109 | cur_depth *= 2 110 | 111 | # Encoding stage 3 112 | self.encres3_1 = BasicBlock_ss(cur_depth) 113 | d3 = cur_depth 114 | self.encres3_2 = BasicBlock_ss(cur_depth, subsamp=2) 115 | cur_depth *= 2 116 | 117 | # Encoding stage 4 118 | self.encres4_1 = BasicBlock_ss(cur_depth) 119 | d4 = cur_depth 120 | self.encres4_2 = BasicBlock_ss(cur_depth, subsamp=2) 121 | cur_depth *= 2 122 | 123 | # Decoding stage 4 124 | self.decres4_1 = BasicBlock_ss(cur_depth) 125 | self.decres4_2 = BasicBlock_us(cur_depth, upsamp=2) 126 | cur_depth = int(cur_depth/2) 127 | 128 | # Decoding stage 3 129 | self.decres3_1 = BasicBlock_ss(cur_depth + d4, planes = cur_depth) 130 | self.decres3_2 = BasicBlock_us(cur_depth, upsamp=2) 131 | cur_depth = int(cur_depth/2) 132 | 133 | # Decoding stage 2 134 | self.decres2_1 = BasicBlock_ss(cur_depth + d3, planes = cur_depth) 135 | self.decres2_2 = BasicBlock_us(cur_depth, upsamp=2) 136 | cur_depth = int(cur_depth/2) 137 | 138 | # Decoding stage 1 139 | self.decres1_1 = BasicBlock_ss(cur_depth + d2, planes = cur_depth) 140 | self.decres1_2 = BasicBlock_us(cur_depth, upsamp=2) 141 | cur_depth = int(cur_depth/2) 142 | 143 | # Output 144 | self.coupling = nn.Conv2d(cur_depth + d1, label_nbr, kernel_size=1) 145 | self.sm = nn.LogSoftmax(dim=1) 146 | 147 | def forward(self, x1, x2): 148 | 149 | x = torch.cat((x1, x2), 1) 150 | 151 | # pad5 = ReplicationPad2d((0, x53.size(3) - x5d.size(3), 0, x53.size(2) - x5d.size(2))) 152 | 153 | s1_1 = x.size() 154 | x1 = self.encres1_1(x) 155 | x = self.encres1_2(x1) 156 | 157 | s2_1 = x.size() 158 | x2 = self.encres2_1(x) 159 | x = self.encres2_2(x2) 160 | 161 | s3_1 = x.size() 162 | x3 = self.encres3_1(x) 163 | x = self.encres3_2(x3) 164 | 165 | s4_1 = x.size() 166 | x4 = self.encres4_1(x) 167 | x = self.encres4_2(x4) 168 | 169 | x = self.decres4_1(x) 170 | x = self.decres4_2(x) 171 | s4_2 = x.size() 172 | pad4 = ReplicationPad2d((0, s4_1[3] - s4_2[3], 0, s4_1[2] - s4_2[2])) 173 | x = pad4(x) 174 | 175 | # x = self.decres3_1(x) 176 | x = self.decres3_1(torch.cat((x, x4), 1)) 177 | x = self.decres3_2(x) 178 | s3_2 = x.size() 179 | pad3 = ReplicationPad2d((0, s3_1[3] - s3_2[3], 0, s3_1[2] - s3_2[2])) 180 | x = pad3(x) 181 | 182 | x = self.decres2_1(torch.cat((x, x3), 1)) 183 | x = self.decres2_2(x) 184 | s2_2 = x.size() 185 | pad2 = ReplicationPad2d((0, s2_1[3] - s2_2[3], 0, s2_1[2] - s2_2[2])) 186 | x = pad2(x) 187 | 188 | x = self.decres1_1(torch.cat((x, x2), 1)) 189 | x = self.decres1_2(x) 190 | s1_2 = x.size() 191 | pad1 = ReplicationPad2d((0, s1_1[3] - s1_2[3], 0, s1_1[2] - s1_2[2])) 192 | x = pad1(x) 193 | 194 | x = self.coupling(torch.cat((x, x1), 1)) 195 | x = self.sm(x) 196 | 197 | return x -------------------------------------------------------------------------------- /sodeep_master/train.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | ****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** 3 | Copyright (c) 2019 [Thomson Licensing] 4 | All Rights Reserved 5 | This program contains proprietary information which is a trade secret/business \ 6 | secret of [Thomson Licensing] and is protected, even if unpublished, under \ 7 | applicable Copyright laws (including French droit d'auteur) and/or may be \ 8 | subject to one or more patent(s). 9 | Recipient is to retain this program in confidence and is not permitted to use \ 10 | or make copies thereof other than as permitted in a written agreement with \ 11 | [Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ 12 | by [Thomson Licensing] under express agreement. 13 | Thomson Licensing is a company of the group TECHNICOLOR 14 | ******************************************************************************* 15 | This scripts permits one to reproduce training and experiments of: 16 | Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2019, June). 17 | SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates. 18 | In Proceedings of CVPR 19 | 20 | Author: Martin Engilberge 21 | """ 22 | 23 | import argparse 24 | 25 | import os 26 | import time 27 | import torch 28 | import torch.nn as nn 29 | 30 | from .dataset import SeqDataset 31 | from .model import model_loader 32 | from torch.utils.data import DataLoader, SubsetRandomSampler 33 | from torch.optim.lr_scheduler import StepLR 34 | from tensorboardX import SummaryWriter 35 | from .utils import AverageMeter, save_checkpoint, log_epoch, count_parameters 36 | 37 | device = torch.device("cuda") 38 | # device = torch.device("cpu") 39 | 40 | 41 | 42 | 43 | def train(train_loader, model, criterion, optimizer, epoch, print_freq=1): 44 | model.train() 45 | 46 | batch_time = AverageMeter() 47 | data_time = AverageMeter() 48 | losses = AverageMeter() 49 | 50 | end = time.time() 51 | for i, (s, r) in enumerate(train_loader): 52 | 53 | seq_in, rank_in = s.float().to(device, non_blocking=True), r.float().to(device, non_blocking=True) 54 | data_time.update(time.time() - end) 55 | 56 | optimizer.zero_grad() 57 | rank_hat = model(seq_in) 58 | loss = criterion(rank_hat, rank_in) 59 | 60 | loss.backward() 61 | optimizer.step() 62 | 63 | losses.update(loss.item(), seq_in.size(0)) 64 | 65 | batch_time.update(time.time() - end) 66 | end = time.time() 67 | if i % print_freq == 0: 68 | print('Train: [{0}][{1}/{2}]\t' 69 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 70 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 71 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 72 | epoch, i + 1, len(train_loader), batch_time=batch_time, 73 | data_time=data_time, loss=losses)) 74 | 75 | print('Train: [{0}][{1}/{2}]\t' 76 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 77 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 78 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 79 | epoch, i + 1, len(train_loader), batch_time=batch_time, 80 | data_time=data_time, loss=losses), end="\n") 81 | 82 | return losses.avg, batch_time.avg, data_time.avg 83 | 84 | 85 | def validate(val_loader, model, criterion, print_freq=1): 86 | model.eval() 87 | 88 | batch_time = AverageMeter() 89 | data_time = AverageMeter() 90 | losses = AverageMeter() 91 | 92 | end = time.time() 93 | for i, (s, r) in enumerate(val_loader): 94 | 95 | seq_in, rank_in = s.float().to(device, non_blocking=True), r.float().to(device, 
non_blocking=True) 96 | data_time.update(time.time() - end) 97 | 98 | with torch.set_grad_enabled(False): 99 | rank_hat = model(seq_in) 100 | loss = criterion(rank_hat, rank_in) 101 | 102 | losses.update(loss.item(), seq_in.size(0)) 103 | batch_time.update(time.time() - end) 104 | end = time.time() 105 | 106 | if i % print_freq == 0: 107 | print('Val: [{0}/{1}]\t' 108 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 109 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 110 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 111 | i + 1, len(val_loader), batch_time=batch_time, 112 | data_time=data_time, loss=losses)) 113 | 114 | print('Val: [{0}/{1}]\t' 115 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 116 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 117 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 118 | i + 1, len(val_loader), batch_time=batch_time, 119 | data_time=data_time, loss=losses), end="\n") 120 | 121 | return losses.avg, batch_time.avg, data_time.avg 122 | 123 | 124 | if __name__ == '__main__': 125 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 126 | 127 | parser = argparse.ArgumentParser(description='Train a SoDeep sorter model.') 128 | 129 | parser.add_argument("-n", '--name', default="model", help='Name of the model') 130 | parser.add_argument("-bs", "--batch_size", help="The size of the batches", type=int, default=256) 131 | parser.add_argument("-lr", dest="lr", help="Initial value of the learning rate", type=float, default=0.001) 132 | parser.add_argument("-lrs", dest="lr_steps", help="Number of epochs to step down LR", type=int, default=70) 133 | parser.add_argument("-mepoch", dest="mepoch", help="Max epoch", type=int, default=400) 134 | parser.add_argument("-pf", dest="print_frequency", help="Number of elements processed between prints", type=int, 135 | default=1) 136 | parser.add_argument("-slen", dest="seq_len", help="length of the sequence processed by the ranker", type=int, 137 | default=30) 138 | parser.add_argument("-d", dest="dist", 139 | help="index of a single distribution for the dataset; if None, all distributions will be used.", 140 | default=None) 141 | parser.add_argument('-m', dest="model_type", 142 | help="Specify which model to use. 
(lstm, grus, gruc, grup, exa, lstmla, lstme, mlp, cnn) ", 143 | default='lstmla') 144 | 145 | args = parser.parse_args() 146 | 147 | # print("Using GPUs: ", os.environ['CUDA_VISIBLE_DEVICES']) 148 | 149 | writer = SummaryWriter(os.path.join("./logs/", args.name)) 150 | 151 | dset = SeqDataset(args.seq_len, dist=args.dist) 152 | 153 | train_loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=8, 154 | sampler=SubsetRandomSampler(range(int(len(dset) * 0.1), len(dset)))) 155 | val_loader = DataLoader(dset, batch_size=args.batch_size, shuffle=False, num_workers=8, 156 | sampler=SubsetRandomSampler(range(int(len(dset) * 0.1)))) 157 | 158 | model = model_loader(args.model_type, args.seq_len) 159 | model.to(device) 160 | 161 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 162 | lr_scheduler = StepLR(optimizer, args.lr_steps, 0.5) 163 | 164 | criterion = nn.L1Loss() 165 | 166 | print("Nb parameters:", count_parameters(model)) 167 | 168 | start_epoch = 0 169 | best_rec = 10000 170 | for epoch in range(start_epoch, args.mepoch): 171 | is_best = False 172 | lr_scheduler.step() 173 | train_loss, batch_train, data_train = train(train_loader, model, criterion, optimizer, epoch, 174 | print_freq=args.print_frequency) 175 | 176 | val_loss, batch_val, data_val = validate(val_loader, model, criterion, print_freq=args.print_frequency) 177 | 178 | if (val_loss < best_rec): 179 | best_rec = val_loss 180 | is_best = True 181 | 182 | state = { 183 | 'epoch': epoch, 184 | 'state_dict': model.state_dict(), 185 | 'best_rec': best_rec, 186 | 'args_dict': args 187 | } 188 | 189 | log_epoch(writer, epoch, train_loss, val_loss, optimizer.param_groups[0]['lr'], batch_train, batch_val, 190 | data_train, data_val) 191 | save_checkpoint(state, is_best, args.name, epoch) 192 | 193 | print('Finished Training') 194 | print(best_rec) 195 | -------------------------------------------------------------------------------- /siamunet_conc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.padding import ReplicationPad2d 5 | 6 | 7 | class SiamUnet_conc(nn.Module): 8 | """SiamUnet_conc segmentation network.""" 9 | 10 | def __init__(self, input_nbr, label_nbr): 11 | super(SiamUnet_conc, self).__init__() 12 | 13 | self.input_nbr = input_nbr 14 | 15 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 16 | self.bn11 = nn.BatchNorm2d(16) 17 | self.do11 = nn.Dropout2d(p=0.2) 18 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 19 | self.bn12 = nn.BatchNorm2d(16) 20 | self.do12 = nn.Dropout2d(p=0.2) 21 | 22 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 23 | self.bn21 = nn.BatchNorm2d(32) 24 | self.do21 = nn.Dropout2d(p=0.2) 25 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 26 | self.bn22 = nn.BatchNorm2d(32) 27 | self.do22 = nn.Dropout2d(p=0.2) 28 | 29 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 30 | self.bn31 = nn.BatchNorm2d(64) 31 | self.do31 = nn.Dropout2d(p=0.2) 32 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 33 | self.bn32 = nn.BatchNorm2d(64) 34 | self.do32 = nn.Dropout2d(p=0.2) 35 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn33 = nn.BatchNorm2d(64) 37 | self.do33 = nn.Dropout2d(p=0.2) 38 | 39 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 40 | self.bn41 = nn.BatchNorm2d(128) 41 | self.do41 = nn.Dropout2d(p=0.2) 42 | self.conv42 = 
nn.Conv2d(128, 128, kernel_size=3, padding=1) 43 | self.bn42 = nn.BatchNorm2d(128) 44 | self.do42 = nn.Dropout2d(p=0.2) 45 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn43 = nn.BatchNorm2d(128) 47 | self.do43 = nn.Dropout2d(p=0.2) 48 | 49 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 50 | 51 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 52 | self.bn43d = nn.BatchNorm2d(128) 53 | self.do43d = nn.Dropout2d(p=0.2) 54 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 55 | self.bn42d = nn.BatchNorm2d(128) 56 | self.do42d = nn.Dropout2d(p=0.2) 57 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 58 | self.bn41d = nn.BatchNorm2d(64) 59 | self.do41d = nn.Dropout2d(p=0.2) 60 | 61 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 62 | 63 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 64 | self.bn33d = nn.BatchNorm2d(64) 65 | self.do33d = nn.Dropout2d(p=0.2) 66 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 67 | self.bn32d = nn.BatchNorm2d(64) 68 | self.do32d = nn.Dropout2d(p=0.2) 69 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 70 | self.bn31d = nn.BatchNorm2d(32) 71 | self.do31d = nn.Dropout2d(p=0.2) 72 | 73 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 74 | 75 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 76 | self.bn22d = nn.BatchNorm2d(32) 77 | self.do22d = nn.Dropout2d(p=0.2) 78 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 79 | self.bn21d = nn.BatchNorm2d(16) 80 | self.do21d = nn.Dropout2d(p=0.2) 81 | 82 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 83 | 84 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 85 | self.bn12d = nn.BatchNorm2d(16) 86 | self.do12d = nn.Dropout2d(p=0.2) 87 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 88 | 89 | # self.sm = nn.LogSoftmax(dim=1) 90 | self.fc1 = nn.Linear(8649, 1) 91 | 92 | def forward(self, x1, x2): 93 | """Forward method.""" 94 | # Stage 1 95 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 96 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 97 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 98 | 99 | # Stage 2 100 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 101 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 102 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 103 | 104 | # Stage 3 105 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 106 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 107 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 108 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 109 | 110 | # Stage 4 111 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 112 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 113 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 114 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 115 | 116 | #################################################### 117 | # Stage 1 118 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 119 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 120 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 121 | 122 | # Stage 2 123 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 124 | x22_2 = 
self.do22(F.relu(self.bn22(self.conv22(x21)))) 125 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 126 | 127 | # Stage 3 128 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 129 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 130 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 131 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 132 | 133 | # Stage 4 134 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 135 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 136 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 137 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 138 | 139 | #################################################### 140 | # Stage 4d 141 | x4d = self.upconv4(x4p) 142 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 143 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 144 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 145 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 146 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 147 | 148 | # Stage 3d 149 | x3d = self.upconv3(x41d) 150 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 151 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 152 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 153 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 154 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 155 | 156 | # Stage 2d 157 | x2d = self.upconv2(x31d) 158 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 159 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 160 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 161 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 162 | 163 | # Stage 1d 164 | x1d = self.upconv1(x21d) 165 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 166 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 167 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 168 | x11d = self.conv11d(x12d) 169 | 170 | nfs = [8, 16, 32, 64, 64, 64, 171 | 32, 16, 8, 32] 172 | kss = [5, 5, 3, 3] 173 | 174 | alpha = 0.5 175 | dense_num = 32 176 | 177 | # return self.sm(x11d) 178 | last_act = F.relu(x11d) 179 | pool_m = F.max_pool2d(last_act, kernel_size=12, stride=3) 180 | # pool_m = F.avg_pool2d(last_act, kernel_size=12, stride=3) 181 | flat = torch.flatten(pool_m, 1) 182 | out = self.fc1(flat) 183 | 184 | return out 185 | 186 | # return self.sm(x11d) 187 | -------------------------------------------------------------------------------- /siamunet_diff.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 
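# Added note (summary of the class below): both inputs go through one shared
# encoder (the same conv/bn/dropout blocks are applied to x1 and x2); at every
# decoder stage the upsampled features are concatenated with the absolute
# difference |skip_1 - skip_2| of the matching encoder skips; and the original
# LogSoftmax segmentation head is replaced by ReLU -> max-pool -> flatten ->
# linear, so the network regresses a single quality score per pair.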
4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | 11 | class SiamUnet_diff(nn.Module): 12 | """SiamUnet_diff segmentation network.""" 13 | 14 | def __init__(self, input_nbr, label_nbr): 15 | super(SiamUnet_diff, self).__init__() 16 | 17 | self.input_nbr = input_nbr 18 | 19 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 20 | self.bn11 = nn.BatchNorm2d(16) 21 | self.do11 = nn.Dropout2d(p=0.2) 22 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 23 | self.bn12 = nn.BatchNorm2d(16) 24 | self.do12 = nn.Dropout2d(p=0.2) 25 | 26 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 27 | self.bn21 = nn.BatchNorm2d(32) 28 | self.do21 = nn.Dropout2d(p=0.2) 29 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 30 | self.bn22 = nn.BatchNorm2d(32) 31 | self.do22 = nn.Dropout2d(p=0.2) 32 | 33 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 34 | self.bn31 = nn.BatchNorm2d(64) 35 | self.do31 = nn.Dropout2d(p=0.2) 36 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 37 | self.bn32 = nn.BatchNorm2d(64) 38 | self.do32 = nn.Dropout2d(p=0.2) 39 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 40 | self.bn33 = nn.BatchNorm2d(64) 41 | self.do33 = nn.Dropout2d(p=0.2) 42 | 43 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 44 | self.bn41 = nn.BatchNorm2d(128) 45 | self.do41 = nn.Dropout2d(p=0.2) 46 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 47 | self.bn42 = nn.BatchNorm2d(128) 48 | self.do42 = nn.Dropout2d(p=0.2) 49 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 50 | self.bn43 = nn.BatchNorm2d(128) 51 | self.do43 = nn.Dropout2d(p=0.2) 52 | 53 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 54 | 55 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 56 | self.bn43d = nn.BatchNorm2d(128) 57 | self.do43d = nn.Dropout2d(p=0.2) 58 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 59 | self.bn42d = nn.BatchNorm2d(128) 60 | self.do42d = nn.Dropout2d(p=0.2) 61 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 62 | self.bn41d = nn.BatchNorm2d(64) 63 | self.do41d = nn.Dropout2d(p=0.2) 64 | 65 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 66 | 67 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 68 | self.bn33d = nn.BatchNorm2d(64) 69 | self.do33d = nn.Dropout2d(p=0.2) 70 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 71 | self.bn32d = nn.BatchNorm2d(64) 72 | self.do32d = nn.Dropout2d(p=0.2) 73 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 74 | self.bn31d = nn.BatchNorm2d(32) 75 | self.do31d = nn.Dropout2d(p=0.2) 76 | 77 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 78 | 79 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 80 | self.bn22d = nn.BatchNorm2d(32) 81 | self.do22d = nn.Dropout2d(p=0.2) 82 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 83 | self.bn21d = nn.BatchNorm2d(16) 84 | self.do21d = nn.Dropout2d(p=0.2) 85 | 86 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 87 | 88 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 89 | self.bn12d = nn.BatchNorm2d(16) 90 | self.do12d = nn.Dropout2d(p=0.2) 91 | 
self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 92 | 93 | self.fc1 = nn.Linear(8649, 1) 94 | # self.sm = nn.LogSoftmax(dim=1) 95 | 96 | def forward(self, x1, x2): 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 119 | 120 | #################################################### 121 | # Stage 1 122 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 123 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 124 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 125 | 126 | # Stage 2 127 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 128 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 129 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 130 | 131 | # Stage 3 132 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 133 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 134 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 135 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 136 | 137 | # Stage 4 138 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 139 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 140 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 141 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 142 | 143 | # Stage 4d 144 | x4d = self.upconv4(x4p) 145 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 146 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) 147 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 148 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 149 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 150 | 151 | # Stage 3d 152 | x3d = self.upconv3(x41d) 153 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 154 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) 155 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 156 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 157 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 158 | 159 | # Stage 2d 160 | x2d = self.upconv2(x31d) 161 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 162 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) 163 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 164 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 165 | 166 | # Stage 1d 167 | x1d = self.upconv1(x21d) 168 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 169 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) 170 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 171 | x11d = self.conv11d(x12d) 
172 | 173 | # return self.sm(x11d) 174 | last_act = F.relu(x11d) 175 | pool_m = F.max_pool2d(last_act, kernel_size=12, stride=3) 176 | # pool_m = F.avg_pool2d(last_act, kernel_size=12, stride=3) 177 | flat = torch.flatten(pool_m, 1) 178 | out = self.fc1(flat) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /siamunet_conc_extrahead.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.padding import ReplicationPad2d 5 | 6 | 7 | class SiamUnet_conc(nn.Module): 8 | """SiamUnet_conc segmentation network.""" 9 | 10 | def __init__(self, input_nbr, label_nbr): 11 | super(SiamUnet_conc, self).__init__() 12 | 13 | self.input_nbr = input_nbr 14 | 15 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 16 | self.bn11 = nn.BatchNorm2d(16) 17 | self.do11 = nn.Dropout2d(p=0.2) 18 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 19 | self.bn12 = nn.BatchNorm2d(16) 20 | self.do12 = nn.Dropout2d(p=0.2) 21 | 22 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 23 | self.bn21 = nn.BatchNorm2d(32) 24 | self.do21 = nn.Dropout2d(p=0.2) 25 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 26 | self.bn22 = nn.BatchNorm2d(32) 27 | self.do22 = nn.Dropout2d(p=0.2) 28 | 29 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 30 | self.bn31 = nn.BatchNorm2d(64) 31 | self.do31 = nn.Dropout2d(p=0.2) 32 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 33 | self.bn32 = nn.BatchNorm2d(64) 34 | self.do32 = nn.Dropout2d(p=0.2) 35 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn33 = nn.BatchNorm2d(64) 37 | self.do33 = nn.Dropout2d(p=0.2) 38 | 39 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 40 | self.bn41 = nn.BatchNorm2d(128) 41 | self.do41 = nn.Dropout2d(p=0.2) 42 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 43 | self.bn42 = nn.BatchNorm2d(128) 44 | self.do42 = nn.Dropout2d(p=0.2) 45 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn43 = nn.BatchNorm2d(128) 47 | self.do43 = nn.Dropout2d(p=0.2) 48 | 49 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 50 | 51 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 52 | self.bn43d = nn.BatchNorm2d(128) 53 | self.do43d = nn.Dropout2d(p=0.2) 54 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 55 | self.bn42d = nn.BatchNorm2d(128) 56 | self.do42d = nn.Dropout2d(p=0.2) 57 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 58 | self.bn41d = nn.BatchNorm2d(64) 59 | self.do41d = nn.Dropout2d(p=0.2) 60 | 61 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 62 | 63 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 64 | self.bn33d = nn.BatchNorm2d(64) 65 | self.do33d = nn.Dropout2d(p=0.2) 66 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 67 | self.bn32d = nn.BatchNorm2d(64) 68 | self.do32d = nn.Dropout2d(p=0.2) 69 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 70 | self.bn31d = nn.BatchNorm2d(32) 71 | self.do31d = nn.Dropout2d(p=0.2) 72 | 73 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 74 | 75 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 76 | self.bn22d = 
nn.BatchNorm2d(32) 77 | self.do22d = nn.Dropout2d(p=0.2) 78 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 79 | self.bn21d = nn.BatchNorm2d(16) 80 | self.do21d = nn.Dropout2d(p=0.2) 81 | 82 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 83 | 84 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 85 | self.bn12d = nn.BatchNorm2d(16) 86 | self.do12d = nn.Dropout2d(p=0.2) 87 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 88 | 89 | self.cov_h1 = nn.Conv2d(label_nbr,16, kernel_size=3, padding=1, stride=4) 90 | self.cov_h2 = nn.Conv2d(16,8, kernel_size=3, padding=1, stride=2) 91 | self.fc_h1 = nn.Linear(10368 , 32) 92 | self.do_h1 = nn.Dropout(p=0.1) 93 | self.fc_h2 = nn.Linear(32, 1) 94 | self.fc_h3 = nn.Linear(1, 1) 95 | 96 | # self.sm = nn.LogSoftmax(dim=1) 97 | # self.fc1 = nn.Linear(8649, 1) 98 | 99 | def forward(self, x1, x2): 100 | """Forward method.""" 101 | # Stage 1 102 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 103 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 104 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 105 | 106 | # Stage 2 107 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 108 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 109 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 110 | 111 | # Stage 3 112 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 113 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 114 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 115 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 116 | 117 | # Stage 4 118 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 119 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 120 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 121 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 122 | 123 | #################################################### 124 | # Stage 1 125 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 126 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 127 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 128 | 129 | # Stage 2 130 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 131 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 132 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 133 | 134 | # Stage 3 135 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 136 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 137 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 138 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 139 | 140 | # Stage 4 141 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 142 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 143 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 144 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 145 | 146 | #################################################### 147 | # Stage 4d 148 | x4d = self.upconv4(x4p) 149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 150 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 154 | 155 | # Stage 3d 156 | x3d = self.upconv3(x41d) 157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 158 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 159 | x33d = 
self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 162 | 163 | # Stage 2d 164 | x2d = self.upconv2(x31d) 165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 166 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 169 | 170 | # Stage 1d 171 | x1d = self.upconv1(x21d) 172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 173 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 175 | x11d = self.conv11d(x12d) 176 | 177 | # return self.sm(x11d) 178 | branch = self.cov_h1(x11d) 179 | branch = nn.LeakyReLU(0.5)(branch) 180 | branch = self.cov_h2(branch) 181 | branch = nn.LeakyReLU(0.5)(branch) 182 | branch = torch.flatten(branch,1) 183 | branch = F.relu(self.fc_h1(branch)) 184 | branch = self.do_h1(branch) 185 | branch = nn.Sigmoid()(self.fc_h2(branch)) 186 | out = self.fc_h3(branch) 187 | 188 | 189 | 190 | return out 191 | 192 | # return self.sm(x11d) 193 | -------------------------------------------------------------------------------- /sodeep_master/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | ****************** COPYRIGHT AND CONFIDENTIALITY INFORMATION ****************** 3 | Copyright (c) 2019 [Thomson Licensing] 4 | All Rights Reserved 5 | This program contains proprietary information which is a trade secret/business \ 6 | secret of [Thomson Licensing] and is protected, even if unpublished, under \ 7 | applicable Copyright laws (including French droit d'auteur) and/or may be \ 8 | subject to one or more patent(s). 9 | Recipient is to retain this program in confidence and is not permitted to use \ 10 | or make copies thereof other than as permitted in a written agreement with \ 11 | [Thomson Licensing] unless otherwise expressly allowed by applicable laws or \ 12 | by [Thomson Licensing] under express agreement. 13 | Thomson Licensing is a company of the group TECHNICOLOR 14 | ******************************************************************************* 15 | This scripts permits one to reproduce training and experiments of: 16 | Engilberge, M., Chevallier, L., Pérez, P., & Cord, M. (2019, June). 17 | SoDeep: A Sorting Deep Net to Learn Ranking Loss Surrogates. 
18 | In Proceedings of CVPR 19 | 20 | Author: Martin Engilberge 21 | """ 22 | 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | from .utils import get_rank 28 | 29 | 30 | def model_loader(model_type, seq_len, pretrained_state_dict=None): 31 | 32 | if model_type == "lstm": 33 | model = lstm_baseline(seq_len) 34 | elif model_type == "grus": 35 | model = gru_sum(seq_len) 36 | elif model_type == "gruc": 37 | model = gru_constrained(seq_len) 38 | elif model_type == "grup": 39 | model = gru_proj(seq_len) 40 | elif model_type == "exa": 41 | model = sorter_exact() 42 | elif model_type == "lstmla": 43 | model = lstm_large(seq_len) 44 | elif model_type == "lstme": 45 | model = lstm_end(seq_len) 46 | elif model_type == "mlp": 47 | model = mlp(seq_len) 48 | elif model_type == "cnn": 49 | return cnn(seq_len) 50 | else: 51 | raise Exception("Model type unknown", model_type) 52 | 53 | if pretrained_state_dict is not None: 54 | model.load_state_dict(pretrained_state_dict) 55 | 56 | return model 57 | 58 | 59 | class UpdatingWrapper(nn.Module): 60 | """ Wrapper to store the data forwarded through the sorter and use them later to finetune the sorter on real data. 61 | Once enough data have been collected, a call to the method update_sorter will perform the finetuning of the sorter. 62 | """ 63 | def __init__(self, sorter, lr_sorter=0.00001): 64 | super(UpdatingWrapper, self).__init__() 65 | self.sorter = sorter 66 | 67 | self.opti = torch.optim.Adam(self.sorter.parameters(), lr=lr_sorter, betas=(0.9, 0.999)) 68 | self.criterion = nn.L1Loss() 69 | 70 | self.average_loss = list() 71 | 72 | self.collected_data = list() 73 | 74 | self.nb_update = 10 75 | 76 | def forward(self, input_): 77 | out = self.sorter(input_) 78 | 79 | self.collected_data.append(input_.detach().cpu()) 80 | return out 81 | 82 | def update_sorter(self): 83 | 84 | for input_opti in self.collected_data: 85 | self.opti.zero_grad() 86 | 87 | input_opti = input_opti.cuda() 88 | input_opti.requires_grad = True 89 | 90 | rank_gt = get_rank(input_opti) 91 | 92 | out_opti = self.sorter(input_opti) 93 | 94 | loss = self.criterion(out_opti, rank_gt) 95 | loss.backward() 96 | self.opti.step() 97 | 98 | self.average_loss.append(loss.item()) 99 | 100 | # Empty collected data 101 | self.collected_data = list() 102 | 103 | def save_data(self, path): 104 | torch.save(self.collected_data, path) 105 | 106 | def get_loss_average(self, windows=50): 107 | return sum(self.average_loss[-windows:]) / min(len(self.average_loss), windows) 108 | 109 | 110 | class lstm_baseline(nn.Module): 111 | def __init__(self, seq_len): 112 | super(lstm_baseline, self).__init__() 113 | self.lstm = nn.LSTM(1, 128, 2, batch_first=True, bidirectional=True) 114 | self.conv1 = nn.Conv1d(seq_len, seq_len, 256) 115 | 116 | def forward(self, input_): 117 | input_ = input_.reshape(input_.size(0), -1, 1) 118 | out, _ = self.lstm(input_) 119 | out = self.conv1(out) 120 | 121 | return out.view(input_.size(0), -1) 122 | 123 | 124 | class gru_constrained(nn.Module): 125 | def __init__(self, seq_len): 126 | super(gru_constrained, self).__init__() 127 | self.rnn = nn.GRU(1, 32, 6, batch_first=True, bidirectional=True) 128 | 129 | self.sig = torch.nn.Sigmoid() 130 | 131 | def forward(self, input_): 132 | input_ = (input_.reshape(input_.size(0), -1, 1) / 2.0) + 1 133 | input_ = self.sig(input_) 134 | 135 | x, hn = self.rnn(input_) 136 | out = x.sum(dim=2) 137 | 138 | out = self.sig(out) 139 | 140 | return out.view(input_.size(0), -1) 141 | 142 | 143 | class gru_proj(nn.Module): 144 | 145 | 
def __init__(self, seq_len): 146 | super(gru_proj, self).__init__() 147 | self.rnn = nn.GRU(1, 128, 6, batch_first=True, bidirectional=True) 148 | self.conv1 = nn.Conv1d(seq_len, seq_len, 256) 149 | 150 | self.sig = torch.nn.Sigmoid() 151 | 152 | def forward(self, input_): 153 | input_ = (input_.reshape(input_.size(0), -1, 1) / 2.0) + 1 154 | 155 | input_ = self.sig(input_) 156 | 157 | out, _ = self.rnn(input_) 158 | out = self.conv1(out) 159 | 160 | out = self.sig(out) 161 | 162 | return out.view(input_.size(0), -1) 163 | 164 | 165 | class cnn(nn.Module): 166 | def __init__(self, seq_len): 167 | super(cnn, self).__init__() 168 | self.layer1 = nn.Sequential( 169 | nn.Conv1d(1, 8, 2), 170 | nn.PReLU()) 171 | self.layer2 = nn.Sequential( 172 | nn.Conv1d(8, 16, 3), 173 | nn.BatchNorm1d(16), 174 | nn.PReLU()) 175 | self.layer3 = nn.Sequential( 176 | nn.Conv1d(16, 32, 5), 177 | nn.PReLU()) 178 | self.layer4 = nn.Sequential( 179 | nn.Conv1d(32, 64, 7), 180 | nn.BatchNorm1d(64), 181 | nn.PReLU()) 182 | self.layer5 = nn.Sequential( 183 | nn.Conv1d(64, 96, 10), 184 | nn.PReLU()) 185 | self.layer6 = nn.Sequential( 186 | nn.Conv1d(96, 128, 7), 187 | nn.BatchNorm1d(128), 188 | nn.PReLU()) 189 | self.layer7 = nn.Sequential( 190 | nn.Conv1d(128, 256, 5), 191 | nn.PReLU()) 192 | self.layer8 = nn.Sequential( 193 | nn.Conv1d(256, 256, 3), 194 | nn.BatchNorm1d(256), 195 | nn.PReLU()) 196 | self.layer9 = nn.Sequential( 197 | nn.Conv1d(256, 128, 3), 198 | nn.PReLU()) 199 | self.layer10 = nn.Conv1d(128, seq_len, 64) 200 | 201 | def forward(self, input_): 202 | out = input_.unsqueeze(1) 203 | out = self.layer1(out) 204 | out = self.layer2(out) 205 | out = self.layer3(out) 206 | out = self.layer4(out) 207 | out = self.layer5(out) 208 | out = self.layer6(out) 209 | out = self.layer7(out) 210 | out = self.layer8(out) 211 | out = self.layer9(out) 212 | out = self.layer10(out).view(input_.size(0), -1) 213 | out = torch.sigmoid(out) 214 | 215 | out = out 216 | return out 217 | 218 | 219 | class mlp(nn.Module): 220 | def __init__(self, seq_len): 221 | super(mlp, self).__init__() 222 | self.lin1 = nn.Linear(seq_len, 2048) 223 | self.lin2 = nn.Linear(2048, 2048) 224 | self.lin3 = nn.Linear(2048, seq_len) 225 | 226 | self.relu = nn.ReLU() 227 | 228 | def forward(self, input_): 229 | input_ = input_.reshape(input_.size(0), -1) 230 | out = self.lin1(input_) 231 | out = self.lin2(self.relu(out)) 232 | out = self.lin3(self.relu(out)) 233 | 234 | return out.view(input_.size(0), -1) 235 | 236 | 237 | class gru_sum(nn.Module): 238 | def __init__(self, seq_len): 239 | super(gru_sum, self).__init__() 240 | self.lstm = nn.GRU(1, 4, 1, batch_first=True, bidirectional=True) 241 | 242 | def forward(self, input_): 243 | input_ = input_.reshape(input_.size(0), -1, 1) 244 | out, _ = self.lstm(input_) 245 | out = out.sum(dim=2) 246 | 247 | return out.view(input_.size(0), -1) 248 | 249 | 250 | class lstm_end(nn.Module): 251 | def __init__(self, seq_len): 252 | super(lstm_end, self).__init__() 253 | self.seq_len = seq_len 254 | self.lstm = nn.GRU(self.seq_len, 5 * self.seq_len, batch_first=True, bidirectional=False) 255 | 256 | def forward(self, input_): 257 | input_ = input_.reshape(input_.size(0), -1, 1).repeat(1, input_.size(1), 1).view(input_.size(0), input_.size(1), -1) 258 | _, out = self.lstm(input_) 259 | 260 | out = out.view(input_.size(0), self.seq_len, -1) # .view(input_.size(0), -1)[:,:self.seq_len] 261 | out = out.sum(dim=2) 262 | 263 | return out 264 | 265 | 266 | class sorter_exact(nn.Module): 267 | 268 | def __init__(self): 
269 | super(sorter_exact, self).__init__() 270 | 271 | def comp(self, inpu): 272 | in_mat1 = torch.triu(inpu.repeat(inpu.size(0), 1), diagonal=1) 273 | in_mat2 = torch.triu(inpu.repeat(inpu.size(0), 1).t(), diagonal=1) 274 | 275 | comp_first = (in_mat1 - in_mat2) 276 | comp_second = (in_mat2 - in_mat1) 277 | 278 | std1 = torch.std(comp_first).item() 279 | std2 = torch.std(comp_second).item() 280 | 281 | comp_first = torch.sigmoid(comp_first * (6.8 / std1)) 282 | comp_second = torch.sigmoid(comp_second * (6.8 / std2)) 283 | 284 | comp_first = torch.triu(comp_first, diagonal=1) 285 | comp_second = torch.triu(comp_second, diagonal=1) 286 | 287 | return (torch.sum(comp_first, 1) + torch.sum(comp_second, 0) + 1) / inpu.size(0) 288 | 289 | def forward(self, input_): 290 | out = [self.comp(input_[d]) for d in range(input_.size(0))] 291 | out = torch.stack(out) 292 | 293 | return out.view(input_.size(0), -1) 294 | 295 | 296 | class lstm_large(nn.Module): 297 | 298 | def __init__(self, seq_len): 299 | super(lstm_large, self).__init__() 300 | self.lstm = nn.LSTM(1, 512, 2, batch_first=True, bidirectional=True) 301 | self.conv1 = nn.Conv1d(seq_len, seq_len, 1024) 302 | 303 | def forward(self, input_): 304 | input_ = input_.reshape(input_.size(0), -1, 1) 305 | out, _ = self.lstm(input_) 306 | out = self.conv1(out) 307 | 308 | return out.view(input_.size(0), -1) 309 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.autograd import Variable 5 | import torchvision.transforms as tr 6 | import torch.nn.functional as F 7 | from torch.nn import Sequential 8 | # Models 9 | from unet import Unet 10 | from siamunet_conc import SiamUnet_conc 11 | # from SiamUnet_conc import SiamUnet_conc  # removed: duplicate import of a module name that does not exist in this repo 12 | from siamunet_diff import SiamUnet_diff 13 | 14 | from fresunet import FresUNet 15 | 16 | # Other 17 | import os 18 | import numpy as np 19 | import random 20 | from skimage import io 21 | from scipy.ndimage import zoom 22 | import matplotlib.pyplot as plt 23 | from tqdm import tqdm 24 | from IPython import display 25 | from eval import eval 26 | from accloss import accloss 27 | from spearman import Spear 28 | 29 | import time 30 | import warnings 31 | 32 | print('IMPORTS OK') 33 | 34 | # Global Variables' Definitions 35 | 36 | PATH_TO_DATASET = r'D:\NTIRE Workshop and Challenges @ CVPR 2021\dataset' 37 | 38 | conf_type = "img(mean_std_norm)_label(mean_std_norm)_maxpool_1e-5_type_covcat_pretrain_epoch-134_loss-0.432_mse1_pr0.5_spr0.5" 39 | 40 | sorter_checkpoint_path = r"D:\NTIRE Workshop and Challenges @ CVPR 2021\codes\FC-Siam-diff\fully_convolutional_change_detection-master\best_model0.00463445740044117.pth.tar" 41 | 42 | BATCH_SIZE = 30 43 | 44 | NUM_WORKER = 4 45 | 46 | scale_co_test = 1 47 | 48 | epoch_start_ = 21 49 | 50 | N_EPOCHS = 200 51 | 52 | NORMALISE_IMGS = True 53 | 54 | NORMALISE_LABELS = True 55 | 56 | TYPE = 2 # 0-RGB | 1-RGBIr | 2-All bands s.t.
resolution <= 20m | 3-All bands 57 | 58 | apply_spearman = True 59 | 60 | LOAD_TRAINED = True 61 | conf_type_pretrain = r"2080server/11" 62 | if LOAD_TRAINED: 63 | ch_path = rf'./checkpoint/{conf_type}/ch_net-best_epoch-134_loss-0.4324578046798706.pth.tar' 64 | # ch_path = rf'./checkpoint/{conf_type}/ch_net-best_epoch-52_loss-1.37092924118042.pth.tar' 65 | # ch_path = rf'./checkpoint/{conf_type_pretrain}/ch_net-best_epoch-70_loss-0.6888124346733093.pth.tar' 66 | 67 | DATA_AUG = True 68 | 69 | print('DEFINITIONS OK') 70 | 71 | 72 | def reshape_for_torch(I): 73 | """Transpose image for PyTorch coordinates.""" 74 | # out = np.swapaxes(I,1,2) 75 | # out = np.swapaxes(out,0,1) 76 | # out = out[np.newaxis,:] 77 | out = I.transpose((2, 0, 1)) 78 | return torch.from_numpy(1.0 * out) 79 | 80 | 81 | class NTIR(Dataset): 82 | """NTIRE IQA dataset class, used for both training and test data.""" 83 | 84 | def __init__(self, path, train=True, transform=None): 85 | 86 | self.transform = transform 87 | self.path = path 88 | self.names = [[], [], []] 89 | self.train_m = train 90 | 91 | if self.train_m: 92 | img_12 = 'path_img_train.txt' 93 | label = 'label_data_train.txt' 94 | 95 | with open(os.path.join(self.path, img_12), "r") as img_file: 96 | all_data = img_file.read().split("\n")[:-1] 97 | self.names[0] = [img.split(",")[0] for img in all_data] 98 | self.names[1] = [img.split(",")[1] for img in all_data] 99 | 100 | with open(os.path.join(self.path, label), "r") as gt_file: 101 | all_scores = gt_file.read().split("\n")[:-1] 102 | self.names[2] = [float(score) for score in all_scores] 103 | 104 | # self.names = [it[:200] for it in self.names] 105 | 106 | 107 | else: 108 | img_12 = 'path_img_test.txt' 109 | label = 'label_data_test.txt' 110 | 111 | with open(os.path.join(self.path, img_12), "r") as img_file: 112 | all_data = img_file.read().split("\n")[:-1] 113 | self.names[0] = [img.split(",")[0] for img in all_data] 114 | self.names[1] = [img.split(",")[1] for img in all_data] 115 | 116 | with open(os.path.join(self.path, label), "r") as gt_file: 117 | all_scores = gt_file.read().split("\n")[:-1] 118 | self.names[2] = [float(score) for score in all_scores] 119 | 120 | # self.names = [it[:200] for it in self.names] 121 | 122 | def __len__(self): 123 | return len(self.names[0]) 124 | 125 | def __getitem__(self, idx): 126 | 127 | I1_path = self.names[0][idx] 128 | I2_path = self.names[1][idx] 129 | 130 | I1_ = io.imread(I1_path) 131 | I2_ = io.imread(I2_path) 132 | 133 | if NORMALISE_IMGS: 134 | I1_m = (I1_ - I1_.mean()) / I1_.std() 135 | I2_m = (I2_ - I2_.mean()) / I2_.std() 136 | else: 137 | I1_m = I1_ 138 | I2_m = I2_ 139 | I1 = reshape_for_torch(I1_m) 140 | I2 = reshape_for_torch(I2_m) 141 | 142 | label = np.array([self.names[2][idx]]) 143 | if NORMALISE_LABELS: 144 | # label_ = label - np.array(self.names[2]).mean() 145 | label_ = (label - np.array(self.names[2]).mean()) / np.array(self.names[2]).std() 146 | else: 147 | label_ = label 148 | 149 | label = torch.from_numpy(1.0 * label_).float() 150 | 151 | sample = {'I1': I1, 'I2': I2, 'label': label} 152 | 153 | if self.transform: 154 | sample = self.transform(sample) 155 | 156 | return sample 157 | 158 | 159 | class RandomFlip(object): 160 | """Randomly flip the images in a sample.""" 161 | 162 | # def __init__(self): 163 | # return 164 | 165 | def __call__(self, sample): 166 | I1, I2, label = sample['I1'], sample['I2'], sample['label'] 167 | 168 | if random.random() > 0.5: 169 | I1 = I1.numpy()[:, :, ::-1].copy() 170 | I1 = 
torch.from_numpy(I1) 171 | I2 = I2.numpy()[:, :, ::-1].copy() 172 | I2 = torch.from_numpy(I2) 173 | 174 | return {'I1': I1, 'I2': I2, 'label': label} 175 | 176 | 177 | class RandomRot(object): 178 | """Randomly rotate the images in a sample.""" 179 | 180 | # def __init__(self): 181 | # return 182 | 183 | def __call__(self, sample): 184 | I1, I2, label = sample['I1'], sample['I2'], sample['label'] 185 | 186 | n = random.randint(0, 3) 187 | if n: 188 | I1 = sample['I1'].numpy() 189 | I1 = np.rot90(I1, n, axes=(1, 2)).copy() 190 | I1 = torch.from_numpy(I1) 191 | I2 = sample['I2'].numpy() 192 | I2 = np.rot90(I2, n, axes=(1, 2)).copy() 193 | I2 = torch.from_numpy(I2) 194 | 195 | return {'I1': I1, 'I2': I2, 'label': label} 196 | 197 | 198 | print('UTILS OK') 199 | 200 | 201 | def count_parameters(model): 202 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 203 | 204 | 205 | # net.load_state_dict(torch.load('net-best_epoch-1_fm-0.7394933126157746.pth.tar')) 206 | 207 | 208 | def train(epoch_start, best_lss, best_acc, n_epochs=N_EPOCHS, save=True): 209 | t = np.linspace(1, n_epochs, n_epochs) 210 | 211 | epoch_train_loss = 0 * t 212 | epoch_train_accuracy = 0 * t 213 | 214 | epoch_test_loss = 0 * t 215 | epoch_test_accuracy = 0 * t 216 | 217 | plt.figure(num=1) 218 | plt.figure(num=2) 219 | 220 | try: 221 | epoch_train_loss[:epoch_start] = Train_loss_curve[1] 222 | epoch_test_loss[:epoch_start] = Test_loss_curve[1] 223 | epoch_train_accuracy[:epoch_start] = Train_accuracy_curve[1] 224 | epoch_test_accuracy[:epoch_start] = Test_accuracy_curve[1] 225 | except Exception: 226 | epoch_train_loss[:epoch_start] = 0 * np.array(list(range(epoch_start))) 227 | epoch_test_loss[:epoch_start] = 0 * np.array(list(range(epoch_start))) 228 | epoch_train_accuracy[:epoch_start] = 0 * np.array(list(range(epoch_start))) 229 | epoch_test_accuracy[:epoch_start] = 0 * np.array(list(range(epoch_start))) 230 | 231 | for epoch_index in tqdm(range(epoch_start, n_epochs)): 232 | net.train() 233 | print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(N_EPOCHS)) 234 | 235 | for index, batch in enumerate(train_loader): 236 | # im1 = batch['I1'][2, :, :, :].numpy().transpose(1, 2, 0).astype("uint8") 237 | # im2 = batch['I2'][2, :, :, :].numpy().transpose(1, 2, 0).astype("uint8") 238 | I1 = Variable(batch['I1'].float().cuda()) 239 | I2 = Variable(batch['I2'].float().cuda()) 240 | label = Variable(batch['label'].cuda()) 241 | 242 | optimizer.zero_grad() 243 | output = net(I1, I2) 244 | 245 | # loss = coef_loss_mae * criterion_mae(output, label) + coef_loss_pr * ( 246 | # 1 - (criterion_pr(output, label)) ** 2) 247 | 248 | # zipp_sort_ind = zip(np.argsort(batch['label'].numpy())[::-1], range(BATCH_SIZE)) 249 | # ranks = [((y[1] + 1) / float(BATCH_SIZE)) for y in sorted(zipp_sort_ind, key=lambda x: x[0])] 250 | # label_spr = torch.FloatTensor(ranks).cuda() 251 | 252 | loss = (coef_loss_mse * criterion_mse(output, label) 253 | + coef_loss_pr * (1 - criterion_pr(output, label)) 254 | + coef_loss_spr * criterion_spr(output, label) + coef_loss_mae * criterion_mae(output, label)) 255 | with torch.no_grad(): 256 | print("losses -> mse:", criterion_mse(output, label), "| 1 - pearson:", 1 - criterion_pr(output, label), "| spearman:", criterion_spr(output, label)) 257 | 258 | # loss = criterion_mse(output, label) 259 | 260 | loss.backward() 261 | optimizer.step() 262 | print( 263 | "\ntrain : " + f"epoch : {epoch_index + 1} -- " + f"{index + 1}" + " / " + f"{len(train_loader)}" + " -----> " + "loss : " 264 | + f"{loss.detach().cpu().numpy():.04f}") 265 | 266 | 
scheduler.step() 267 | with torch.no_grad(): 268 | 269 | epoch_train_loss[epoch_index], epoch_train_accuracy[epoch_index] = test(train_loader_val) 270 | 271 | epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index] = test(test_loader) 272 | 273 | plt.figure(num=1) 274 | plt.clf() 275 | 276 | l1_1, = plt.plot(t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1], 277 | label='Train loss') 278 | l1_2, = plt.plot(t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1], 279 | label='Test loss') 280 | plt.legend(handles=[l1_1, l1_2]) 281 | plt.grid() 282 | 283 | plt.gcf().gca().set_xlim(left=0) 284 | plt.title('Loss') 285 | display.clear_output(wait=True) 286 | display.display(plt.gcf()) 287 | 288 | plt.figure(num=2) 289 | plt.clf() 290 | 291 | l2_1, = plt.plot(t[:epoch_index + 1], epoch_train_accuracy[:epoch_index + 1], 292 | label='Train accuracy') 293 | l2_2, = plt.plot(t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1], 294 | label='Test accuracy') 295 | plt.legend(handles=[l2_1, l2_2]) 296 | plt.grid() 297 | 298 | plt.gcf().gca().set_xlim(left=0) 299 | plt.title('Accuracy') 300 | display.clear_output(wait=True) 301 | display.display(plt.gcf()) 302 | 303 | # lss = epoch_train_loss[epoch_index] 304 | # accu = epoch_train_accuracy[epoch_index] 305 | 306 | lss = epoch_test_loss[epoch_index] 307 | accu = epoch_test_accuracy[epoch_index] 308 | 309 | if accu > best_acc: 310 | best_acc = accu 311 | save_str = fr'./checkpoint/{conf_type}/ch_net-best_epoch-' + str(epoch_index + 1) + '_accu-' + str( 312 | accu) + '.pth.tar' 313 | # torch.save(net.state_dict(), save_str) 314 | 315 | torch.save({ 316 | 'epoch': epoch_index, 317 | 'model_state_dict': net.state_dict(), 318 | 'model_state_dict_head': criterion_spr.state_dict(), 319 | 'optimizer_state_dict': optimizer.state_dict(), 320 | 'scheduler_state_dict': scheduler.state_dict(), 321 | "Train loss": [t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1]], 322 | "Test loss": [t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1]], 323 | "Train accuracy": [t[:epoch_index + 1], 324 | epoch_train_accuracy[:epoch_index + 1]], 325 | "Test accuracy": [t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1]], 326 | 'loss': lss, 327 | 'acc': accu 328 | }, save_str) 329 | 330 | if lss < best_lss: 331 | best_lss = lss 332 | save_str = rf'./checkpoint/{conf_type}/ch_net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str( 333 | lss) + '.pth.tar' 334 | torch.save({ 335 | 'epoch': epoch_index, 336 | 'model_state_dict': net.state_dict(), 337 | 'model_state_dict_head': criterion_spr.state_dict(), 338 | 'optimizer_state_dict': optimizer.state_dict(), 339 | 'scheduler_state_dict': scheduler.state_dict(), 340 | "Train loss": [t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1]], 341 | "Test loss": [t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1]], 342 | "Train accuracy": [t[:epoch_index + 1], 343 | epoch_train_accuracy[:epoch_index + 1]], 344 | "Test accuracy": [t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1]], 345 | 'loss': lss, 346 | 'acc': accu 347 | }, save_str) 348 | print( 349 | f"\n ################## \n epoch : {epoch_index + 1} \n avg_loss_test : {lss} \n avg_acc_test : {accu} \n ################## \n") 350 | 351 | # accu_val = epoch_test_accuracy[epoch_index]  # this block duplicated the test metrics printed just above; kept commented for reference 352 | # lss_val = epoch_test_loss[epoch_index] 353 | # print( 354 | #     f"\n ################## \n epoch : {epoch_index + 1} \n avg_loss_test : {lss_val} \n avg_acc_test : {accu_val} \n ################## \n") 355 | 356 | if save: 357 | im_format = 'png' 358 | # 
im_format = 'eps' 359 | 360 | plt.figure(num=1) 361 | plt.savefig(net_name + '-01-loss.' + im_format) 362 | 363 | plt.figure(num=2) 364 | plt.savefig(net_name + '-02-accuracy.' + im_format) 365 | 366 | out = {'train_loss': epoch_train_loss[-1], 367 | 'train_accuracy': epoch_train_accuracy[-1], 368 | 'test_loss': epoch_test_loss[-1], 369 | 'test_accuracy': epoch_test_accuracy[-1]} 370 | 371 | return out 372 | 373 | 374 | L = 1024 375 | N = 2 376 | 377 | 378 | def test(dset): 379 | net.eval() 380 | tot_loss = 0 381 | tot_count = 0 382 | all_predicted = [] 383 | all_gt = [] 384 | 385 | for index, batch in enumerate(dset): 386 | I1 = Variable(batch['I1'].float().cuda()) 387 | I2 = Variable(batch['I2'].float().cuda()) 388 | cm = Variable(batch['label'].cuda()) 389 | 390 | output = net(I1, I2) 391 | # loss = coef_loss_mae * criterion_mae(output, label) + coef_loss_pr * ( 392 | # 1 - (criterion_pr(output, label)) ** 2) 393 | 394 | # zipp_sort_ind = zip(np.argsort(batch['label'].numpy())[::-1], range(BATCH_SIZE)) 395 | # ranks = [((y[1] + 1) / float(BATCH_SIZE)) for y in sorted(zipp_sort_ind, key=lambda x: x[0])] 396 | # label_spr_cm = torch.FloatTensor(ranks).cuda() 397 | 398 | loss = coef_loss_mse * criterion_mse(output, cm) + coef_loss_pr * (1-criterion_pr(output, 399 | cm)) + coef_loss_spr * criterion_spr( 400 | output, cm) + coef_loss_mae * criterion_mae(output, cm) 401 | # loss = criterion_mse(output, cm) 402 | 403 | print( 404 | "\n val : " + f"{index + 1}" + " / " + f"{len(dset)}" + " ----->" + "loss : " 405 | + f"{loss.detach().cpu().numpy():.04f}") 406 | tot_loss += loss.data * np.prod(cm.size()) 407 | tot_count += np.prod(cm.size()) 408 | all_predicted.extend(list(torch.squeeze(output).detach().cpu().numpy())) 409 | all_gt.extend(list(torch.squeeze(cm).detach().cpu().numpy())) 410 | 411 | net_loss = tot_loss / tot_count 412 | accuracy, _, _ = eval(np.array(all_predicted), np.array(all_gt)) 413 | 414 | return net_loss, accuracy 415 | 416 | 417 | def save_test_results(dset): 418 | for name in tqdm(dset.names): 419 | with warnings.catch_warnings(): 420 | I1, I2, cm = dset.get_img(name) 421 | I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda() 422 | I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda() 423 | out = net(I1, I2) 424 | _, predicted = torch.max(out.data, 1) 425 | I = np.stack((255 * cm, 255 * np.squeeze(predicted.cpu().numpy()), 255 * cm), 2) 426 | io.imsave(f'{net_name}-{name}.png', I) 427 | 428 | 429 | if __name__ == '__main__': 430 | 431 | if DATA_AUG: 432 | data_transform = tr.Compose([RandomFlip(), RandomRot()]) 433 | else: 434 | data_transform = None 435 | 436 | train_dataset = NTIR(PATH_TO_DATASET, train=True, transform=data_transform) 437 | 438 | train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER, 439 | drop_last=True) 440 | 441 | train_loader_val = DataLoader(train_dataset, batch_size=int(BATCH_SIZE * scale_co_test), shuffle=True, 442 | num_workers=NUM_WORKER, 443 | drop_last=True) 444 | 445 | test_dataset = NTIR(PATH_TO_DATASET, train=False) 446 | test_loader = DataLoader(test_dataset, batch_size=int(BATCH_SIZE * scale_co_test), shuffle=True, 447 | num_workers=NUM_WORKER, 448 | drop_last=True) 449 | 450 | print('DATASETS OK') 451 | 452 | if TYPE == 0: 453 | # net, net_name = Unet(2*3, 2), 'FC-EF' 454 | # net, net_name = SiamUnet_conc(3, 2), 'FC-Siam-conc' 455 | # net, net_name = SiamUnet_diff(3, 2), 'FC-Siam-diff' 456 | net, net_name = FresUNet(2 * 3, 2), 'FresUNet' 457 | elif TYPE == 1: 458 | # net, net_name = 
Unet(2*4, 2), 'FC-EF' 459 | # net, net_name = SiamUnet_conc(4, 2), 'FC-Siam-conc' 460 | # net, net_name = SiamUnet_diff(4, 2), 'FC-Siam-diff' 461 | net, net_name = FresUNet(2 * 4, 2), 'FresUNet' 462 | elif TYPE == 2: 463 | # net, net_name = Unet(2*10, 2), 'FC-EF' 464 | net, net_name = SiamUnet_conc(3, 1), rf'./checkpoint/{conf_type}/FC-Siam-diff' 465 | # net, net_name = SiamUnet_diff(10, 2), 'FC-Siam-diff' 466 | # net, net_name = FresUNet(2 * 10, 2), 'FresUNet' 467 | elif TYPE == 3: 468 | # net, net_name = Unet(2*13, 2), 'FC-EF' 469 | # net, net_name = SiamUnet_conc(13, 2), 'FC-Siam-conc' 470 | net, net_name = SiamUnet_diff(3, 1), rf'./checkpoint/{conf_type}/FC-Siam-diff' 471 | # net, net_name = FresUNet(2 * 13, 2), 'FresUNet' 472 | 473 | net.cuda() 474 | 475 | criterion_mse = F.mse_loss 476 | criterion_mae = F.l1_loss 477 | criterion_pr = accloss 478 | 479 | 480 | coef_loss_mse = 1 481 | coef_loss_mae = 0 482 | coef_loss_pr = 0.5 483 | coef_loss_spr = 0.5 484 | 485 | 486 | 487 | print('Number of trainable parameters:', count_parameters(net)) 488 | 489 | print('NETWORK OK') 490 | 491 | # optimizer = torch.optim.Adam(net.parameters(), lr=0.0005) 492 | 493 | if LOAD_TRAINED: 494 | 495 | checkpoint = torch.load(ch_path) 496 | # checkpoint['optimizer_state_dict']['param_groups'][0]["lr"] = 1e-5 497 | try: 498 | net.load_state_dict(checkpoint['model_state_dict']) 499 | criterion_spr = Spear(sorter_checkpoint_path) 500 | optimizer = torch.optim.Adam(list(net.parameters()) + list(criterion_spr.parameters()), lr=1e-5, 501 | weight_decay=1e-4) 502 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 503 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95) 504 | # scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 505 | # Train_loss_curve = checkpoint["Train loss"] 506 | # Test_loss_curve = checkpoint["Test loss"] 507 | # Train_accuracy_curve = checkpoint["Train accuracy"] 508 | # Test_accuracy_curve = checkpoint["Test accuracy"] 509 | # epoch_input = checkpoint['epoch'] + 1 510 | # best_acc_ = checkpoint['epoch'] 511 | # best_lss_ = checkpoint['epoch'] 512 | epoch_input = 0 513 | best_acc_ = 0 514 | best_lss_ = 1000 515 | 516 | except Exception: 517 | epoch_input = epoch_start_ 518 | net.load_state_dict(checkpoint) 519 | # criterion_spr and optimizer were referenced below but never created in this branch (NameError); build them as in the branch above 520 | criterion_spr = Spear(sorter_checkpoint_path) 521 | optimizer = torch.optim.Adam(list(net.parameters()) + list(criterion_spr.parameters()), lr=1e-5, weight_decay=1e-4) 522 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95) 523 | best_acc_ = 0 524 | best_lss_ = 1000 525 | else: 526 | epoch_input = 0 527 | # training from scratch also needs the ranking loss and an optimizer; the learning rate follows the commented-out line above 528 | criterion_spr = Spear(sorter_checkpoint_path) 529 | optimizer = torch.optim.Adam(list(net.parameters()) + list(criterion_spr.parameters()), lr=0.0005) 530 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95) 531 | best_acc_ = 0 532 | best_lss_ = 1000 533 | 534 | print('LOAD OK') 535 | 536 | t_start = time.time() 537 | out_dic = train(epoch_input, best_lss_, best_acc_) 538 | t_end = time.time() 539 | print(out_dic) 540 | print('Elapsed time:') 541 | print(t_end - t_start) 542 | 543 | if not LOAD_TRAINED: 544 | torch.save(net.state_dict(), rf'./checkpoint/{conf_type}/net_final.pth.tar') 545 | print('SAVE OK') 546 | # t_start = time.time() 547 | # # save_test_results(train_dataset) 548 | # save_test_results(test_dataset) 549 | # t_end = time.time() 550 | # print('Elapsed time: {}'.format(t_end - t_start)) 551 | --------------------------------------------------------------------------------
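
Added note (not part of the original sources): the SpearmanLoss used as criterion_spr in main.py wraps a sorter network, and sorter_exact in sodeep_master/model.py is its exact, non-learned variant. Every pair of scores is soft-compared through a sigmoid, and the normalized comparison counts approximate the rank positions while staying differentiable. Below is a minimal usage sketch, assuming it is run from the repository root (so that sodeep_master is importable) with PyTorch installed:

import torch

from sodeep_master.model import sorter_exact

# One batch of five quality scores; forward() applies comp() to each batch element.
scores = torch.tensor([[0.9, 0.1, 0.4, 0.7, 0.2]])

sorter = sorter_exact()
soft_ranks = sorter(scores)

# Each output entry approximates (1 + number of strictly larger scores) / n,
# so the largest score maps to roughly 1/5 = 0.2 and the smallest to roughly 1.0.
print(soft_ranks)  # approximately tensor([[0.2, 1.0, 0.6, 0.4, 0.8]])

Because the comparisons go through sigmoids rather than a hard sort, gradients flow from the rank estimates back to the raw scores, which is what lets the Spearman surrogate act as a training loss term in main.py.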