├── README.md
├── config.py
├── dataset.py
├── evaluation
│   ├── __pycache__
│   │   ├── dataloader.cpython-37.pyc
│   │   └── evaluator.cpython-37.pyc
│   ├── dataloader.py
│   ├── evaluator.py
│   ├── hist_of_pixel_values.py
│   ├── main.py
│   ├── select_results.py
│   └── sort_results.py
├── loss.py
├── models
│   ├── __pycache__
│   │   ├── main.cpython-37.pyc
│   │   └── vgg.cpython-37.pyc
│   ├── main.py
│   └── vgg.py
├── requirements.txt
├── test.py
├── train.py
├── util.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# DCFM
The official repo of the paper `Democracy Does Matter: Comprehensive Feature Mining for Co-Salient Object Detection`.

## Environment Requirement
Create the environment and install the dependencies as follows:
`pip install -r requirements.txt`

## Data Format
Training set: CoCo-SEG

Test sets: CoCA, CoSOD3k, Cosal2015

Put the [CoCo-SEG](https://drive.google.com/file/d/1GbA_WKvJm04Z1tR8pTSzBdYVQ75avg4f/view), [CoCA](http://zhaozhang.net/coca.html), [CoSOD3k](http://dpfan.net/CoSOD3K/) and [Cosal2015](https://drive.google.com/u/0/uc?id=1mmYpGx17t8WocdPcw2WKeuFpz6VHoZ6K&export=download) datasets to `DCFM/data` with the following structure:
```
DCFM
├── other codes
├── ...
│
└── data
    ├── CoCo-SEG (CoCo-SEG's image files)
    ├── CoCA (CoCA's image files)
    ├── CoSOD3k (CoSOD3k's image files)
    └── Cosal2015 (Cosal2015's image files)
```

## Trained model
The trained model can be downloaded from [papermodel](https://drive.google.com/file/d/1cfuq4eJoCwvFR9W1XOJX7Y0ttd8TGjlp/view?usp=sharing).

Run `test.py` for inference.

For evaluation, please use the tool at: https://github.com/zzhanghub/eval-co-sod

## Usage
Download the pretrained backbone model [VGG](https://drive.google.com/file/d/1Z1aAYXMyJ6txQ1Z9N7gtxLOIai4dxrXd/view?usp=sharing).

Run `train.py` for training.

## Prediction results
The co-saliency maps of DCFM can be found at [preds](https://drive.google.com/file/d/1wGeNHXFWVSyqvmL4NIUmEFdlHDovEtQR/view?usp=sharing).

## Reproduction
My reproductions on a 2080 Ti can be found at [reproduction1](https://drive.google.com/file/d/1vovii0RtYR_EC0Y2zxjY_cTWKWM3WaxP/view?usp=sharing) and [reproduction2](https://drive.google.com/file/d/1YPOKZ5kBtmZrCDhHpP3-w1GMVR5BfDoU/view?usp=sharing).

My reproduction on a TITAN X can be found at [reproduction3](https://drive.google.com/file/d/1bnGFtRTYkVXqI2dcjeWFRDXnqqbUUBJr/view?usp=sharing).

## Others
The code is based on [GCoNet](https://github.com/fanq15/GCoNet).
I've added a validation part to help select the model for closer results. This validation part is based on [GCoNet_plus](https://github.com/ZhengPeng7/GCoNet_plus). You can try different evaluation metrics to select the model.
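As a quick sanity check that your `DCFM/data` layout is readable, you can exercise the dataloader in `dataset.py` directly. A minimal sketch (the paths, image size, and the `img`/`gt` sub-folder layout are placeholders -- adapt them to your setup):
```
# Minimal sketch: iterate one image group with the dataloader from dataset.py.
# All paths and sizes below are illustrative, not fixed by the repo.
from dataset import get_loader

loader = get_loader(
    img_root='./data/CoCo-SEG/img/',  # class folders containing images (assumed layout)
    gt_root='./data/CoCo-SEG/gt/',    # matching class folders containing .png masks
    img_size=224, batch_size=1, max_num=16,
    istrain=True, shuffle=True, num_workers=4, pin=True)

for imgs, gts, subpaths, ori_sizes, cls_ls in loader:
    # imgs: (1, N, 3, 224, 224) normalized images of one class (N <= max_num)
    # gts:  (1, N, 1, 224, 224) corresponding ground-truth masks
    print(imgs.shape, gts.shape)
    break
```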
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os


class Config():
    def __init__(self) -> None:

        # Performance of GCoNet
        self.val_measures = {
            'Emax': {'CoCA': 0.783, 'CoSOD3k': 0.874, 'CoSal2015': 0.892},
            'Smeasure': {'CoCA': 0.710, 'CoSOD3k': 0.810, 'CoSal2015': 0.838},
            'Fmax': {'CoCA': 0.598, 'CoSOD3k': 0.805, 'CoSal2015': 0.856},
        }

        # others
        # `val_last` is read as Config().val_last in evaluation/main.py but was
        # missing here; the value below is an assumed default (number of most
        # recent checkpoints to evaluate).
        self.val_last = 10

        self.validation = True
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
from PIL import Image, ImageOps, ImageFilter
import torch
import random
import numpy as np
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import functional as F
import numbers


class CoData(data.Dataset):
    def __init__(self, img_root, gt_root, img_size, transform, max_num, is_train):

        class_list = os.listdir(img_root)
        self.size = [img_size, img_size]
        self.img_dirs = list(
            map(lambda x: os.path.join(img_root, x), class_list))
        self.gt_dirs = list(
            map(lambda x: os.path.join(gt_root, x), class_list))
        self.transform = transform
        self.max_num = max_num
        self.is_train = is_train

    def __getitem__(self, item):
        names = os.listdir(self.img_dirs[item])
        num = len(names)
        img_paths = list(
            map(lambda x: os.path.join(self.img_dirs[item], x), names))
        gt_paths = list(
            map(lambda x: os.path.join(self.gt_dirs[item], x[:-4] + '.png'), names))

        if self.is_train:
            # Randomly sample at most max_num image/GT pairs from the group.
            final_num = min(num, self.max_num)
            sampled_list = random.sample(range(num), final_num)
            img_paths = [img_paths[i] for i in sampled_list]
            gt_paths = [gt_paths[i] for i in sampled_list]
        else:
            final_num = num

        imgs = torch.Tensor(final_num, 3, self.size[0], self.size[1])
        gts = torch.Tensor(final_num, 1, self.size[0], self.size[1])

        subpaths = []
        ori_sizes = []
        for idx in range(final_num):
            img = Image.open(img_paths[idx]).convert('RGB')
            gt = Image.open(gt_paths[idx]).convert('L')

            subpaths.append(os.path.join(img_paths[idx].split('/')[-2], img_paths[idx].split('/')[-1][:-4] + '.png'))
            ori_sizes.append((img.size[1], img.size[0]))

            [img, gt] = self.transform(img, gt)

            imgs[idx] = img
            gts[idx] = gt
        if self.is_train:
            cls_ls = [item] * int(final_num)
            return imgs, gts, subpaths, ori_sizes, cls_ls
        else:
            return imgs, gts, subpaths, ori_sizes

    def __len__(self):
        return len(self.img_dirs)


class FixedResize(object):
    def __init__(self, size):
        self.size = (size, size)  # size: (h, w)

    def __call__(self, img, gt):
        img = img.resize(self.size, Image.BILINEAR)
        gt = gt.resize(self.size, Image.NEAREST)

        return img, gt


class ToTensor(object):
    def __call__(self, img, gt):
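        # Convert the PIL image and mask to float tensors scaled to [0, 1];
        # the mask is not thresholded here (binarization happens later, in the
        # evaluator).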
return F.to_tensor(img), F.to_tensor(gt) 95 | 96 | 97 | class Normalize(object): 98 | """Normalize a tensor image with mean and standard deviation. 99 | Args: 100 | mean (tuple): means for each channel. 101 | std (tuple): standard deviations for each channel. 102 | """ 103 | 104 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 105 | self.mean = mean 106 | self.std = std 107 | 108 | def __call__(self, img, gt): 109 | img = F.normalize(img, self.mean, self.std) 110 | 111 | return img, gt 112 | 113 | 114 | class RandomHorizontalFlip(object): 115 | def __init__(self, p=0.5): 116 | self.p = p 117 | 118 | def __call__(self, img, gt): 119 | if random.random() < self.p: 120 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 121 | gt = gt.transpose(Image.FLIP_LEFT_RIGHT) 122 | 123 | return img, gt 124 | 125 | 126 | class RandomScaleCrop(object): 127 | def __init__(self, base_size, crop_size, fill=0): 128 | self.base_size = base_size 129 | self.crop_size = crop_size 130 | self.fill = fill 131 | 132 | def __call__(self, img, mask): 133 | # random scale (short edge) 134 | # img = img.numpy() 135 | # mask = mask.numpy() 136 | short_size = random.randint(int(self.base_size * 0.8), int(self.base_size * 1.2)) 137 | w, h = img.size 138 | if h > w: 139 | ow = short_size 140 | oh = int(1.0 * h * ow / w) 141 | else: 142 | oh = short_size 143 | ow = int(1.0 * w * oh / h) 144 | img = img.resize((ow, oh), Image.BILINEAR) 145 | mask = mask.resize((ow, oh), Image.NEAREST) 146 | # pad crop 147 | if short_size < self.crop_size: 148 | padh = self.crop_size - oh if oh < self.crop_size else 0 149 | padw = self.crop_size - ow if ow < self.crop_size else 0 150 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 151 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 152 | # random crop crop_size 153 | w, h = img.size 154 | x1 = random.randint(0, w - self.crop_size) 155 | y1 = random.randint(0, h - self.crop_size) 156 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 157 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 158 | 159 | return img, mask 160 | 161 | 162 | class RandomRotation(object): 163 | def __init__(self, degrees, resample=False, expand=False, center=None): 164 | if isinstance(degrees, numbers.Number): 165 | if degrees < 0: 166 | raise ValueError("If degrees is a single number, it must be positive.") 167 | self.degrees = (-degrees, degrees) 168 | else: 169 | if len(degrees) != 2: 170 | raise ValueError("If degrees is a sequence, it must be of len 2.") 171 | self.degrees = degrees 172 | 173 | self.resample = resample 174 | self.expand = expand 175 | self.center = center 176 | 177 | @staticmethod 178 | def get_params(degrees): 179 | angle = random.uniform(degrees[0], degrees[1]) 180 | 181 | return angle 182 | 183 | def __call__(self, img, gt): 184 | """ 185 | img (PIL Image): Image to be rotated. 186 | 187 | Returns: 188 | PIL Image: Rotated image. 
        """

        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), F.rotate(gt, angle, Image.NEAREST, self.expand, self.center)


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, gt):
        for t in self.transforms:
            img, gt = t(img, gt)
        return img, gt

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


# Build the dataloader. Augmentation (random scale-crop, flip, rotation) is
# applied only when istrain=True; the test transform just resizes and normalizes.
def get_loader(img_root, gt_root, img_size, batch_size, max_num=float('inf'), istrain=True, shuffle=False, num_workers=0, pin=False):
    if istrain:
        transform = Compose([
            RandomScaleCrop(img_size * 2, img_size * 2),
            FixedResize(img_size),
            RandomHorizontalFlip(),
            RandomRotation((-90, 90)),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = Compose([
            FixedResize(img_size),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    dataset = CoData(img_root, gt_root, img_size, transform, max_num, is_train=istrain)
    data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                                  pin_memory=pin)
    return data_loader


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img_root = './data/testtrain/img/'
    gt_root = './data/testtrain/gt/'
    loader = get_loader(img_root, gt_root, 20, 1, 16, istrain=False)
    for batch in loader:
        b, c, h, w = batch[0][0].shape
        for i in range(b):
            img = batch[0].squeeze(0)[i].permute(1, 2, 0).cpu().numpy() * std + mean
            image = img * 255
            mask = batch[1].squeeze(0)[i].squeeze().cpu().numpy()
            plt.subplot(121)
            plt.imshow(np.uint8(image))
            plt.subplot(122)
            plt.imshow(mask)
            plt.show(block=True)
--------------------------------------------------------------------------------
/evaluation/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyueyu/DCFM/05800b67ccf70f9ed55dd1f33ee1ff3b3503eb09/evaluation/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/__pycache__/evaluator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyueyu/DCFM/05800b67ccf70f9ed55dd1f33ee1ff3b3503eb09/evaluation/__pycache__/evaluator.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/dataloader.py:
--------------------------------------------------------------------------------
from torch.utils import data
import os
from PIL import Image, ImageFile


ImageFile.LOAD_TRUNCATED_IMAGES = True


class EvalDataset(data.Dataset):
    def __init__(self, pred_root, label_root, return_predpath=False, return_gtpath=False):
        self.return_predpath = return_predpath
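        # When these flags are set, __getitem__ additionally returns the
        # prediction / GT file paths (select_results.py relies on them to save
        # the selected maps).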
self.return_gtpath = return_gtpath 13 | pred_dirs = os.listdir(pred_root) 14 | label_dirs = os.listdir(label_root) 15 | 16 | dir_name_list = [] 17 | for idir in pred_dirs: 18 | if idir in label_dirs: 19 | pred_names = os.listdir(os.path.join(pred_root, idir)) 20 | label_names = os.listdir(os.path.join(label_root, idir)) 21 | for iname in pred_names: 22 | if iname in label_names: 23 | dir_name_list.append(os.path.join(idir, iname)) 24 | 25 | self.image_path = list( 26 | map(lambda x: os.path.join(pred_root, x), dir_name_list)) 27 | self.label_path = list( 28 | map(lambda x: os.path.join(label_root, x), dir_name_list)) 29 | 30 | self.labels = [] 31 | for p in self.label_path: 32 | self.labels.append(Image.open(p).convert('L')) 33 | 34 | 35 | def __getitem__(self, item): 36 | predpath = self.image_path[item] 37 | gtpath = self.label_path[item] 38 | pred = Image.open(predpath).convert('L') 39 | gt = self.labels[item] 40 | if pred.size != gt.size: 41 | pred = pred.resize(gt.size, Image.BILINEAR) 42 | returns = [pred, gt] 43 | if self.return_predpath: 44 | returns.append(predpath) 45 | if self.return_gtpath: 46 | returns.append(gtpath) 47 | return returns 48 | 49 | def __len__(self): 50 | return len(self.image_path) 51 | -------------------------------------------------------------------------------- /evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | 5 | import numpy as np 6 | from scipy.io import savemat 7 | import torch 8 | from torchvision import transforms 9 | 10 | from PIL import ImageFile 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | 14 | class Eval_thread(): 15 | def __init__(self, loader, method='', dataset='', output_dir='', epoch='', cuda=True): 16 | self.loader = loader 17 | self.method = method 18 | self.dataset = dataset 19 | self.cuda = cuda 20 | self.output_dir = output_dir 21 | self.epoch = epoch.split('ep')[-1] 22 | self.logfile = os.path.join(output_dir, 'result.txt') 23 | self.dataset2smeasure_bottom_bound = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845} # S_measures of GCoNet 24 | 25 | def run(self, AP=False, AUC=False, save_metrics=False, continue_eval=True): 26 | Res = {} 27 | start_time = time.time() 28 | 29 | if continue_eval: 30 | s = self.Eval_Smeasure() 31 | if s > self.dataset2smeasure_bottom_bound[self.dataset]: 32 | mae = self.Eval_mae() 33 | Em = self.Eval_Emeasure() 34 | max_e = Em.max().item() 35 | mean_e = Em.mean().item() 36 | Em = Em.cpu().numpy() 37 | Fm, prec, recall = self.Eval_fmeasure() 38 | max_f = Fm.max().item() 39 | mean_f = Fm.mean().item() 40 | Fm = Fm.cpu().numpy() 41 | else: 42 | mae = 1 43 | Em = torch.zeros(255).cpu().numpy() 44 | max_e = 0 45 | mean_e = 0 46 | Fm, prec, recall = 0, 0, 0 47 | max_f = 0 48 | mean_f = 0 49 | continue_eval = False 50 | else: 51 | s = 0 52 | mae = 1 53 | Em = torch.zeros(255).cpu().numpy() 54 | max_e = 0 55 | mean_e = 0 56 | Fm, prec, recall = 0, 0, 0 57 | max_f = 0 58 | mean_f = 0 59 | continue_eval = False 60 | 61 | 62 | if AP: 63 | prec = prec.cpu().numpy() 64 | recall = recall.cpu().numpy() 65 | avg_p = self.Eval_AP(prec, recall) 66 | 67 | if AUC: 68 | auc, TPR, FPR = self.Eval_auc() 69 | TPR = TPR.cpu().numpy() 70 | FPR = FPR.cpu().numpy() 71 | 72 | if save_metrics: 73 | os.makedirs(os.path.join(self.output_dir, self.method, self.epoch), exist_ok=True) 74 | Res['Sm'] = s 75 | if s > self.dataset2smeasure_bottom_bound[self.dataset]: 76 | Res['MAE'] = mae 77 | Res['MaxEm'] = max_e 78 | 
Res['MeanEm'] = mean_e 79 | Res['Em'] = Em 80 | Res['Fm'] = Fm 81 | else: 82 | Res['MAE'] = 1 83 | Res['MaxEm'] = 0 84 | Res['MeanEm'] = 0 85 | Res['Em'] = torch.zeros(255).cpu().numpy() 86 | Res['Fm'] = 0 87 | 88 | if AP: 89 | Res['MaxFm'] = max_f 90 | Res['MeanFm'] = mean_f 91 | Res['AP'] = avg_p 92 | Res['Prec'] = prec 93 | Res['Recall'] = recall 94 | 95 | if AUC: 96 | Res['AUC'] = auc 97 | Res['TPR'] = TPR 98 | Res['FPR'] = FPR 99 | 100 | os.makedirs(os.path.join(self.output_dir, self.method, self.epoch), exist_ok=True) 101 | savemat(os.path.join(self.output_dir, self.method, self.epoch, self.dataset + '.mat'), Res) 102 | 103 | info = '{} ({}): {:.4f} max-Emeasure || {:.4f} S-measure || {:.4f} max-fm || {:.4f} mae || {:.4f} mean-Emeasure || {:.4f} mean-fm'.format( 104 | self.dataset, self.method+'-ep{}'.format(self.epoch), max_e, s, max_f, mae, mean_e, mean_f 105 | ) 106 | if AP: 107 | info += ' || {:.4f} AP'.format(avg_p) 108 | if AUC: 109 | info += ' || {:.4f} AUC'.format(auc) 110 | info += '.' 111 | self.LOG(info + '\n') 112 | 113 | return '[cost:{:.4f}s] '.format(time.time() - start_time) + info, continue_eval 114 | 115 | def Eval_mae(self): 116 | if self.epoch: 117 | print('Evaluating MAE...') 118 | avg_mae, img_num = 0.0, 0.0 119 | with torch.no_grad(): 120 | trans = transforms.Compose([transforms.ToTensor()]) 121 | for pred, gt in self.loader: 122 | if self.cuda: 123 | pred = trans(pred).cuda() 124 | gt = trans(gt).cuda() 125 | else: 126 | pred = trans(pred) 127 | gt = trans(gt) 128 | mea = torch.abs(pred - gt).mean() 129 | if mea == mea: # for Nan 130 | avg_mae += mea 131 | img_num += 1.0 132 | avg_mae /= img_num 133 | return avg_mae.item() 134 | 135 | def Eval_fmeasure(self): 136 | print('Evaluating FMeasure...') 137 | beta2 = 0.3 138 | avg_f, avg_p, avg_r, img_num = 0.0, 0.0, 0.0, 0.0 139 | 140 | with torch.no_grad(): 141 | trans = transforms.Compose([transforms.ToTensor()]) 142 | for pred, gt in self.loader: 143 | if self.cuda: 144 | pred = trans(pred).cuda() 145 | gt = trans(gt).cuda() 146 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 147 | torch.min(pred) + 1e-20) 148 | else: 149 | pred = trans(pred) 150 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 151 | torch.min(pred) + 1e-20) 152 | gt = trans(gt) 153 | prec, recall = self._eval_pr(pred, gt, 255) 154 | f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall) 155 | f_score[f_score != f_score] = 0 # for Nan 156 | avg_f += f_score 157 | avg_p += prec 158 | avg_r += recall 159 | img_num += 1.0 160 | Fm = avg_f / img_num 161 | avg_p = avg_p / img_num 162 | avg_r = avg_r / img_num 163 | return Fm, avg_p, avg_r 164 | 165 | def Eval_auc(self): 166 | print('Evaluating AUC...') 167 | 168 | avg_tpr, avg_fpr, avg_auc, img_num = 0.0, 0.0, 0.0, 0.0 169 | 170 | with torch.no_grad(): 171 | trans = transforms.Compose([transforms.ToTensor()]) 172 | for pred, gt in self.loader: 173 | if self.cuda: 174 | pred = trans(pred).cuda() 175 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 176 | torch.min(pred) + 1e-20) 177 | gt = trans(gt).cuda() 178 | else: 179 | pred = trans(pred) 180 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 181 | torch.min(pred) + 1e-20) 182 | gt = trans(gt) 183 | TPR, FPR = self._eval_roc(pred, gt, 255) 184 | avg_tpr += TPR 185 | avg_fpr += FPR 186 | img_num += 1.0 187 | avg_tpr = avg_tpr / img_num 188 | avg_fpr = avg_fpr / img_num 189 | 190 | sorted_idxes = torch.argsort(avg_fpr) 191 | avg_tpr = avg_tpr[sorted_idxes] 192 | avg_fpr = avg_fpr[sorted_idxes] 193 | avg_auc = 
torch.trapz(avg_tpr, avg_fpr) 194 | 195 | return avg_auc.item(), avg_tpr, avg_fpr 196 | 197 | def Eval_Emeasure(self): 198 | print('Evaluating EMeasure...') 199 | avg_e, img_num = 0.0, 0.0 200 | with torch.no_grad(): 201 | trans = transforms.Compose([transforms.ToTensor()]) 202 | Em = torch.zeros(255) 203 | if self.cuda: 204 | Em = Em.cuda() 205 | for pred, gt in self.loader: 206 | if self.cuda: 207 | pred = trans(pred).cuda() 208 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 209 | torch.min(pred) + 1e-20) 210 | gt = trans(gt).cuda() 211 | else: 212 | pred = trans(pred) 213 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 214 | torch.min(pred) + 1e-20) 215 | gt = trans(gt) 216 | Em += self._eval_e(pred, gt, 255) 217 | img_num += 1.0 218 | 219 | Em /= img_num 220 | return Em 221 | 222 | def select_by_Smeasure(self, bar=0.9, loader_comp=None, bar_comp=0.1): 223 | print('Evaluating SMeasure...') 224 | good_ones = [] 225 | good_ones_comp = [] 226 | good_ones_gt = [] 227 | alpha, avg_q, img_num = 0.5, 0.0, 0.0 228 | with torch.no_grad(): 229 | trans = transforms.Compose([transforms.ToTensor()]) 230 | for (pred, gt, predpath, gtpath), (pred_comp, gt_comp, predpath_comp) in zip(self.loader, loader_comp): 231 | # pred X gt 232 | if self.cuda: 233 | pred = trans(pred).cuda() 234 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 235 | torch.min(pred) + 1e-20) 236 | gt = trans(gt).cuda() 237 | else: 238 | pred = trans(pred) 239 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 240 | torch.min(pred) + 1e-20) 241 | gt = trans(gt) 242 | y = gt.mean() 243 | if y == 0: 244 | x = pred.mean() 245 | Q = 1.0 - x 246 | elif y == 1: 247 | x = pred.mean() 248 | Q = x 249 | else: 250 | gt[gt >= 0.5] = 1 251 | gt[gt < 0.5] = 0 252 | Q = alpha * self._S_object( 253 | pred, gt) + (1 - alpha) * self._S_region(pred, gt) 254 | if Q.item() < 0: 255 | Q = torch.FloatTensor([0.0]) 256 | img_num += 1.0 257 | avg_q += Q.item() 258 | # pred_comp X gt 259 | if self.cuda: 260 | pred_comp = trans(pred_comp).cuda() 261 | pred_comp = (pred_comp - torch.min(pred_comp)) / (torch.max(pred_comp) - 262 | torch.min(pred_comp) + 1e-20) 263 | gt_comp = trans(gt_comp).cuda() 264 | else: 265 | pred_comp = trans(pred_comp) 266 | pred_comp = (pred_comp - torch.min(pred_comp)) / (torch.max(pred_comp) - 267 | torch.min(pred_comp) + 1e-20) 268 | gt_comp = trans(gt_comp) 269 | y = gt_comp.mean() 270 | if y == 0: 271 | x = pred_comp.mean() 272 | Q_comp = 1.0 - x 273 | elif y == 1: 274 | x = pred_comp.mean() 275 | Q_comp = x 276 | else: 277 | gt_comp[gt_comp >= 0.5] = 1 278 | gt_comp[gt_comp < 0.5] = 0 279 | Q_comp = alpha * self._S_object( 280 | pred_comp, gt_comp) + (1 - alpha) * self._S_region(pred_comp, gt_comp) 281 | if Q_comp.item() < 0: 282 | Q_comp = torch.FloatTensor([0.0]) 283 | if Q.item() > bar and (Q.item() - Q_comp.item()) > bar_comp: 284 | good_ones.append(predpath) 285 | good_ones_comp.append(predpath_comp) 286 | good_ones_gt.append(gtpath) 287 | avg_q /= img_num 288 | return avg_q, good_ones, good_ones_comp, good_ones_gt 289 | 290 | def Eval_Smeasure(self): 291 | print('Evaluating SMeasure...') 292 | alpha, avg_q, img_num = 0.5, 0.0, 0.0 293 | with torch.no_grad(): 294 | trans = transforms.Compose([transforms.ToTensor()]) 295 | for pred, gt in self.loader: 296 | if self.cuda: 297 | pred = trans(pred).cuda() 298 | pred = (pred - torch.min(pred)) / (torch.max(pred) - 299 | torch.min(pred) + 1e-20) 300 | gt = trans(gt).cuda() 301 | else: 302 | pred = trans(pred) 303 | pred = (pred - torch.min(pred)) / 
(torch.max(pred) - 304 | torch.min(pred) + 1e-20) 305 | gt = trans(gt) 306 | y = gt.mean() 307 | if y == 0: 308 | x = pred.mean() 309 | Q = 1.0 - x 310 | elif y == 1: 311 | x = pred.mean() 312 | Q = x 313 | else: 314 | gt[gt >= 0.5] = 1 315 | gt[gt < 0.5] = 0 316 | Q = alpha * self._S_object( 317 | pred, gt) + (1 - alpha) * self._S_region(pred, gt) 318 | if Q.item() < 0: 319 | Q = torch.FloatTensor([0.0]) 320 | img_num += 1.0 321 | avg_q += Q.item() 322 | avg_q /= img_num 323 | return avg_q 324 | 325 | def LOG(self, output): 326 | os.makedirs(self.output_dir, exist_ok=True) 327 | with open(self.logfile, 'a') as f: 328 | f.write(output) 329 | 330 | def _eval_e(self, y_pred, y, num): 331 | if self.cuda: 332 | score = torch.zeros(num).cuda() 333 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 334 | else: 335 | score = torch.zeros(num) 336 | thlist = torch.linspace(0, 1 - 1e-10, num) 337 | for i in range(num): 338 | y_pred_th = (y_pred >= thlist[i]).float() 339 | fm = y_pred_th - y_pred_th.mean() 340 | gt = y - y.mean() 341 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20) 342 | enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4 343 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20) 344 | return score 345 | 346 | def _eval_pr(self, y_pred, y, num): 347 | if self.cuda: 348 | prec, recall = torch.zeros(num).cuda(), torch.zeros(num).cuda() 349 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 350 | else: 351 | prec, recall = torch.zeros(num), torch.zeros(num) 352 | thlist = torch.linspace(0, 1 - 1e-10, num) 353 | for i in range(num): 354 | y_temp = (y_pred >= thlist[i]).float() 355 | tp = (y_temp * y).sum() 356 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) 357 | return prec, recall 358 | 359 | def _eval_roc(self, y_pred, y, num): 360 | if self.cuda: 361 | TPR, FPR = torch.zeros(num).cuda(), torch.zeros(num).cuda() 362 | thlist = torch.linspace(0, 1 - 1e-10, num).cuda() 363 | else: 364 | TPR, FPR = torch.zeros(num), torch.zeros(num) 365 | thlist = torch.linspace(0, 1 - 1e-10, num) 366 | for i in range(num): 367 | y_temp = (y_pred >= thlist[i]).float() 368 | tp = (y_temp * y).sum() 369 | fp = (y_temp * (1 - y)).sum() 370 | tn = ((1 - y_temp) * (1 - y)).sum() 371 | fn = ((1 - y_temp) * y).sum() 372 | 373 | TPR[i] = tp / (tp + fn + 1e-20) 374 | FPR[i] = fp / (fp + tn + 1e-20) 375 | 376 | return TPR, FPR 377 | 378 | def _S_object(self, pred, gt): 379 | fg = torch.where(gt == 0, torch.zeros_like(pred), pred) 380 | bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred) 381 | o_fg = self._object(fg, gt) 382 | o_bg = self._object(bg, 1 - gt) 383 | u = gt.mean() 384 | Q = u * o_fg + (1 - u) * o_bg 385 | return Q 386 | 387 | def _object(self, pred, gt): 388 | temp = pred[gt == 1] 389 | x = temp.mean() 390 | sigma_x = temp.std() 391 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20) 392 | 393 | return score 394 | 395 | def _S_region(self, pred, gt): 396 | X, Y = self._centroid(gt) 397 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divideGT(gt, X, Y) 398 | p1, p2, p3, p4 = self._dividePrediction(pred, X, Y) 399 | Q1 = self._ssim(p1, gt1) 400 | Q2 = self._ssim(p2, gt2) 401 | Q3 = self._ssim(p3, gt3) 402 | Q4 = self._ssim(p4, gt4) 403 | Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4 404 | return Q 405 | 406 | def _centroid(self, gt): 407 | rows, cols = gt.size()[-2:] 408 | gt = gt.view(rows, cols) 409 | if gt.sum() == 0: 410 | if self.cuda: 411 | X = torch.eye(1).cuda() * round(cols / 2) 412 | Y = torch.eye(1).cuda() * round(rows / 2) 413 | else: 414 | X 
= torch.eye(1) * round(cols / 2) 415 | Y = torch.eye(1) * round(rows / 2) 416 | else: 417 | total = gt.sum() 418 | if self.cuda: 419 | i = torch.from_numpy(np.arange(0, cols)).cuda().float() 420 | j = torch.from_numpy(np.arange(0, rows)).cuda().float() 421 | else: 422 | i = torch.from_numpy(np.arange(0, cols)).float() 423 | j = torch.from_numpy(np.arange(0, rows)).float() 424 | X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20) 425 | Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20) 426 | return X.long(), Y.long() 427 | 428 | def _divideGT(self, gt, X, Y): 429 | h, w = gt.size()[-2:] 430 | area = h * w 431 | gt = gt.view(h, w) 432 | LT = gt[:Y, :X] 433 | RT = gt[:Y, X:w] 434 | LB = gt[Y:h, :X] 435 | RB = gt[Y:h, X:w] 436 | X = X.float() 437 | Y = Y.float() 438 | w1 = X * Y / area 439 | w2 = (w - X) * Y / area 440 | w3 = X * (h - Y) / area 441 | w4 = 1 - w1 - w2 - w3 442 | return LT, RT, LB, RB, w1, w2, w3, w4 443 | 444 | def _dividePrediction(self, pred, X, Y): 445 | h, w = pred.size()[-2:] 446 | pred = pred.view(h, w) 447 | LT = pred[:Y, :X] 448 | RT = pred[:Y, X:w] 449 | LB = pred[Y:h, :X] 450 | RB = pred[Y:h, X:w] 451 | return LT, RT, LB, RB 452 | 453 | def _ssim(self, pred, gt): 454 | gt = gt.float() 455 | h, w = pred.size()[-2:] 456 | N = h * w 457 | x = pred.mean() 458 | y = gt.mean() 459 | sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20) 460 | sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20) 461 | sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20) 462 | 463 | aplha = 4 * x * y * sigma_xy 464 | beta = (x * x + y * y) * (sigma_x2 + sigma_y2) 465 | 466 | if aplha != 0: 467 | Q = aplha / (beta + 1e-20) 468 | elif aplha == 0 and beta == 0: 469 | Q = 1.0 470 | else: 471 | Q = 0 472 | return Q 473 | 474 | def Eval_AP(self, prec, recall): 475 | # Ref: 476 | # https://github.com/facebookresearch/Detectron/blob/05d04d3a024f0991339de45872d02f2f50669b3d/lib/datasets/voc_eval.py#L54 477 | print('Evaluating AP...') 478 | ap_r = np.concatenate(([0.], recall, [1.])) 479 | ap_p = np.concatenate(([0.], prec, [0.])) 480 | sorted_idxes = np.argsort(ap_r) 481 | ap_r = ap_r[sorted_idxes] 482 | ap_p = ap_p[sorted_idxes] 483 | count = ap_r.shape[0] 484 | 485 | for i in range(count - 1, 0, -1): 486 | ap_p[i - 1] = max(ap_p[i], ap_p[i - 1]) 487 | 488 | i = np.where(ap_r[1:] != ap_r[:-1])[0] 489 | ap = np.sum((ap_r[i + 1] - ap_r[i]) * ap_p[i + 1]) 490 | return ap 491 | -------------------------------------------------------------------------------- /evaluation/hist_of_pixel_values.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | root_dir = os.path.join([rd for rd in os.listdir('.') if 'gconet_' in rd][0], 'CoCA/Accordion') 8 | image_paths = [os.path.join(root_dir, p) for p in os.listdir(root_dir)] 9 | pixel_values = [] 10 | for image_path in image_paths: 11 | image = cv2.imread(image_path) 12 | pixel_value = image.flatten().squeeze().tolist() 13 | pixel_values += pixel_value 14 | 15 | pixel_values = np.array(pixel_values) 16 | 17 | non_zero_values = pixel_values[pixel_values >= 0] 18 | margin_values_percent = (np.sum(non_zero_values > 230) + np.sum(non_zero_values <= 0)) / non_zero_values.shape[0] * 100 19 | print('histing...') 20 | plt.hist(x=non_zero_values) 21 | plt.title('(0+>230)/all, {:.1f} % are margin values'.format(margin_values_percent)) 22 | plt.savefig('hist_(0+>230)|all.png') 23 | plt.show() 24 | 25 | non_zero_values = 
pixel_values[pixel_values >= 0]
margin_values_percent = (np.sum(non_zero_values > 230) + np.sum(non_zero_values < 0)) / non_zero_values.shape[0] * 100
print('histing...')
plt.figure()
plt.hist(x=non_zero_values)
plt.title('(230)/all, {:.1f} % are margin values'.format(margin_values_percent))
plt.savefig('hist_(230)|all.png')
plt.show()

non_zero_values = pixel_values[pixel_values > 0]
margin_values_percent = (np.sum(non_zero_values > 230) + np.sum(non_zero_values <= 0)) / non_zero_values.shape[0] * 100
print('histing...')
plt.figure()
plt.hist(x=non_zero_values)
plt.title('(0+>230)/(all-0), {:.1f} % are margin values'.format(margin_values_percent))
plt.savefig('hist_(0+>230)|(all-0).png')
plt.show()
--------------------------------------------------------------------------------
/evaluation/main.py:
--------------------------------------------------------------------------------
import os
import argparse

import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat

from evaluator import Eval_thread
from dataloader import EvalDataset

import sys
sys.path.append('..')
from config import Config


styles = ['.-r', '.--b', '.--g', '.--c', '.-m', '.-y', '.-k', '.-c']
lines = ['-', '--', '--', '--', '-', '-', '-', '-']
points = ['*', '.', '.', '.', '.', '.', '.', '.']
colors = ['r', 'b', 'g', 'c', 'm', 'orange', 'k', 'navy']


def main_plot(cfg):
    method_names = cfg.methods.split('+')
    dataset_names = cfg.datasets.split('+')
    os.makedirs(cfg.output_figure, exist_ok=True)
    # plt.style.use('seaborn-white')

    # Plot PR curve
    for dataset in dataset_names:
        plt.figure()
        idx_style = 0
        for method in method_names:
            iRes = loadmat(os.path.join(cfg.output_dir, method, 'final', dataset + '.mat'))
            imax = np.argmax(iRes['Fm'])
            plt.plot(
                iRes['Recall'],
                iRes['Prec'],
                # styles[idx_style],
                color=colors[idx_style],
                linestyle=lines[idx_style],
                marker=points[idx_style],
                markevery=[imax, imax],
                label=method)
            idx_style += 1

        plt.grid(True, zorder=-1)
        # plt.xlim(0, 1)
        # plt.ylim(0, 1.02)
        plt.ylabel('Precision', fontsize=25)
        plt.xlabel('Recall', fontsize=25)

        plt.legend(loc='lower left', prop={'size': 15})
        plt.savefig(os.path.join(cfg.output_figure, 'PR_' + dataset + '.png'),
                    dpi=600,
                    bbox_inches='tight')
        plt.close()

    # Plot Fm curve
    for dataset in dataset_names:
        plt.figure()
        idx_style = 0
        for method in method_names:
            iRes = loadmat(os.path.join(cfg.output_dir, method, 'final', dataset + '.mat'))
            imax = np.argmax(iRes['Fm'])
            plt.plot(
                np.arange(0, 255),
                iRes['Fm'],
                # styles[idx_style],
                color=colors[idx_style],
                linestyle=lines[idx_style],
                marker=points[idx_style],
                label=method,
                markevery=[imax, imax])
            idx_style += 1
        plt.grid(True, zorder=-1)
        # plt.ylim(0, 1)
        plt.ylabel('F-measure', fontsize=25)
        plt.xlabel('Threshold', fontsize=25)

        plt.legend(loc='lower left', prop={'size': 15})
        plt.savefig(os.path.join(cfg.output_figure, 'Fm_' + dataset + '.png'),
                    dpi=600,
                    bbox_inches='tight')
        plt.close()

    # Plot Em curve
    for dataset in dataset_names:
        plt.figure()
        idx_style = 0
        for method in method_names:
            iRes = loadmat(os.path.join(cfg.output_dir, method, 'final', dataset + '.mat'))
            imax = np.argmax(iRes['Em'])
            plt.plot(
                np.arange(0, 255),
                iRes['Em'],
                # styles[idx_style],
                color=colors[idx_style],
                linestyle=lines[idx_style],
                marker=points[idx_style],
                label=method,
                markevery=[imax, imax])
            idx_style += 1
        plt.grid(True, zorder=-1)
        plt.ylim(0, 1)
        plt.ylabel('E-measure', fontsize=16)
        plt.xlabel('Threshold', fontsize=16)

        plt.legend(loc='lower left', prop={'size': 15})
        plt.savefig(os.path.join(cfg.output_figure, 'Em_' + dataset + '.png'),
                    dpi=600,
                    bbox_inches='tight')
        plt.close()

    # Plot ROC curve
    for dataset in dataset_names:
        plt.figure()
        idx_style = 0
        for method in method_names:
            iRes = loadmat(os.path.join(cfg.output_dir, method, 'final', dataset + '.mat'))
            imax = np.argmax(iRes['Fm'])
            plt.plot(
                iRes['FPR'],
                iRes['TPR'],
                # styles[idx_style][1:],
                color=colors[idx_style],
                linestyle=lines[idx_style],
                label=method)
            idx_style += 1

        plt.grid(True, zorder=-1)
        plt.xlim(0, 1)
        plt.ylim(0, 1.02)
        plt.ylabel('TPR', fontsize=16)
        plt.xlabel('FPR', fontsize=16)

        plt.legend(loc='lower right')
        plt.savefig(os.path.join(cfg.output_figure, 'ROC_' + dataset + '.png'),
                    dpi=600,
                    bbox_inches='tight')
        plt.close()

    # Plot Sm-MAE
    for dataset in dataset_names:
        plt.figure()
        plt.gca().invert_xaxis()
        idx_style = 0
        for method in method_names:
            iRes = loadmat(os.path.join(cfg.output_dir, method, 'final', dataset + '.mat'))
            plt.scatter(iRes['MAE'],
                        iRes['Sm'],
                        marker=points[idx_style],
                        c=colors[idx_style],
                        s=120)
            plt.annotate(method,
                         xy=(iRes['MAE'], iRes['Sm']),
                         xytext=(iRes['MAE'] - 0.001, iRes['Sm'] - 0.001),
                         fontsize=14)
            idx_style += 1

        plt.grid(True, zorder=-1)
        # plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.ylabel('S-measure', fontsize=16)
        plt.xlabel('MAE', fontsize=16)
        plt.savefig(os.path.join(cfg.output_figure, 'Sm-MAE_' + dataset + '.png'),
                    bbox_inches='tight')
        plt.close()


def main(cfg):
    if cfg.methods is None:
        method_names = os.listdir(cfg.pred_dir)
    else:
        method_names = cfg.methods.split('+')
    if cfg.datasets is None:
        dataset_names = os.listdir(cfg.gt_dir)
    else:
        dataset_names = cfg.datasets.split('+')

    num_model_eval = Config().val_last
    threads = []
    # model -> ckpt -> dataset
    for method in method_names:
        epochs = os.listdir(os.path.join(cfg.pred_dir, method))[-num_model_eval:][::-1]
        for epoch in epochs:
            continue_eval = True
            for dataset in dataset_names:
                loader = EvalDataset(
                    os.path.join(cfg.pred_dir, method, epoch, dataset),  # preds
                    os.path.join(cfg.gt_dir, dataset)  # GT
                )
                print('Evaluating predictions from {}'.format(os.path.join(cfg.pred_dir, method, epoch, dataset)))
                thread = Eval_thread(loader, method, dataset, cfg.output_dir, epoch, cfg.cuda)
                info, continue_eval = thread.run(continue_eval=continue_eval)
                print(info)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--methods', type=str, default='GCoNet_ext')
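    # Multiple methods/datasets are '+'-separated, e.g.
    #   --methods DCFM+GCoNet --datasets CoCA+CoSal2015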
parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015') 204 | 205 | parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT') 206 | parser.add_argument('--pred_dir', type=str, default='/root/datasets/sod/preds', help='predictions') 207 | parser.add_argument('--output_dir', type=str, default='./output/details', help='saving measurements here.') 208 | parser.add_argument('--output_figure', type=str, default='./output/figures', help='saving figures here.') 209 | 210 | parser.add_argument('--cuda', type=bool, default=True) 211 | config = parser.parse_args() 212 | main(config) 213 | -------------------------------------------------------------------------------- /evaluation/select_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import cv2 5 | 6 | from evaluator import Eval_thread 7 | from dataloader import EvalDataset 8 | 9 | import sys 10 | sys.path.append('..') 11 | 12 | 13 | def main(cfg): 14 | dataset_names = cfg.datasets.split('+') 15 | root_dir_predictions = [dr for dr in os.listdir('.') if 'gconet_' in dr] 16 | root_dir_prediction_comp = cfg.gt_dir.replace('/gts', '/gconet') 17 | print('root_dir_predictions:', root_dir_predictions) 18 | root_dir_prediction = root_dir_predictions[0] 19 | root_dir_good_ones = 'good_ones' 20 | for dataset in dataset_names: 21 | dir_prediction = os.path.join(root_dir_prediction, dataset) 22 | dir_prediction_comp = os.path.join(root_dir_prediction_comp, dataset) 23 | dir_gt = os.path.join(cfg.gt_dir, dataset) 24 | loader = EvalDataset( 25 | dir_prediction, # preds 26 | dir_gt, # GT 27 | return_predpath=True, 28 | return_gtpath=True 29 | ) 30 | loader_comp = EvalDataset( 31 | dir_prediction_comp, # preds 32 | dir_gt, # GT 33 | return_predpath=True 34 | ) 35 | print('Selecting predictions from {}'.format(dir_prediction)) 36 | thread = Eval_thread(loader, cuda=cfg.cuda) 37 | s_measure, good_ones, good_ones_comp, good_ones_gt = thread.select_by_Smeasure(bar=0.95, loader_comp=loader_comp, bar_comp=0.2) 38 | dir_good_ones = os.path.join(root_dir_good_ones, dataset) 39 | os.makedirs(dir_good_ones, exist_ok=True) 40 | print('have good_ones {}'.format(len(good_ones))) 41 | for good_one, good_one_comp, good_one_gt in zip(good_ones, good_ones_comp, good_ones_gt): 42 | dir_category = os.path.join(dir_good_ones, good_one.split('/')[-2]) 43 | os.makedirs(dir_category, exist_ok=True) 44 | save_path = os.path.join(dir_category, good_one.split('/')[-1]) 45 | sal_map = cv2.imread(good_one) 46 | sal_map_gt = cv2.imread(good_one_gt) 47 | sal_map_comp = cv2.imread(good_one_comp) 48 | image_path = good_one_gt.replace('/gts', '/images').replace('.png', '.jpg') 49 | image = cv2.imread(image_path) 50 | cv2.imwrite(save_path, sal_map) 51 | split_line = np.zeros((sal_map.shape[0], 10, 3)).astype(sal_map.dtype) + 127 52 | comp = cv2.hconcat([image, split_line, sal_map_gt, split_line, sal_map, split_line, sal_map_comp]) 53 | save_path_comp = ''.join((save_path[:-4], '_comp', save_path[-4:])) 54 | cv2.imwrite(save_path_comp, comp) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--datasets', type=str, default='CoCA+CoSOD3k+CoSal2015') 60 | parser.add_argument('--gt_dir', type=str, default='/root/datasets/sod/gts', help='GT') 61 | parser.add_argument('--cuda', type=bool, default=True) 62 | config = parser.parse_args() 63 | main(config) 
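# Usage sketch (paths are the defaults above; adjust them to your layout):
#   python select_results.py --datasets CoCA --gt_dir /root/datasets/sod/gts
# This keeps predictions whose S-measure exceeds 0.95 AND beats the GCoNet map
# on the same image by more than 0.2, saving each one plus a side-by-side
# comparison (image | GT | ours | GCoNet) under ./good_ones/<dataset>/<class>/.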
--------------------------------------------------------------------------------
/evaluation/sort_results.py:
--------------------------------------------------------------------------------
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np


move_best_results_here = False

record = ['dataset', 'ckpt', 'Emax', 'Smeasure', 'Fmax', 'MAE', 'Emean', 'Fmean']
measurement = 'Emax'
score_idx = record.index(measurement)

with open('output/details/result.txt', 'r') as f:
    res = f.read()

res = res.replace('||', '').replace('(', '').replace(')', '')

score = []
for r in res.splitlines():
    ds = r.split()
    s = ds[:2]
    for idx_d, d in enumerate(ds[2:]):
        if idx_d % 2 == 0:
            s.append(float(d))
    score.append(s)

ss = sorted(score, key=lambda x: (x[record.index('dataset')], x[record.index('Emax')], x[record.index('Smeasure')], x[record.index('Fmax')], x[record.index('ckpt')]), reverse=True)
ss_ar = np.array(ss)
np.savetxt('score_sorted.txt', ss_ar, fmt='%s')
ckpt_coca = ss_ar[ss_ar[:, 0] == 'CoCA'][0][1]
ckpt_cosod = ss_ar[ss_ar[:, 0] == 'CoSOD3k'][0][1]
ckpt_cosal = ss_ar[ss_ar[:, 0] == 'CoSal2015'][0][1]

best_coca_scores = ss_ar[ss_ar[:, 1] == ckpt_coca]
best_cosod_scores = ss_ar[ss_ar[:, 1] == ckpt_cosod]
best_cosal_scores = ss_ar[ss_ar[:, 1] == ckpt_cosal]
print('Best (models may be different):')
print('CoCA:\n', best_coca_scores)
print('CoSOD3k:\n', best_cosod_scores)
print('CoSal2015:\n', best_cosal_scores)

# Overall relative Emax improvement on the three datasets
if measurement == 'Emax':
    gco_scores = {'CoCA': 0.760, 'CoSOD3k': 0.860, 'CoSal2015': 0.887}
    gco_scores_Smeasure = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
elif measurement == 'Smeasure':
    gco_scores = {'CoCA': 0.673, 'CoSOD3k': 0.802, 'CoSal2015': 0.845}
elif measurement == 'Fmax':
    gco_scores = {'CoCA': 0.544, 'CoSOD3k': 0.777, 'CoSal2015': 0.847}
elif measurement == 'Emean':
    gco_scores = {'CoCA': 0.1, 'CoSOD3k': 0.1, 'CoSal2015': 0.1}
elif measurement == 'Fmean':
    gco_scores = {'CoCA': 0.1, 'CoSOD3k': 0.1, 'CoSal2015': 0.1}
ckpts = list(set(ss_ar[:, 1].squeeze().tolist()))
improvements_mean = []
improvements_lst = []
improvements_mean_Smeasure = []
improvements_lst_Smeasure = []
for ckpt in ckpts:
    scores = ss_ar[ss_ar[:, 1] == ckpt]
    if scores.shape[0] != len(gco_scores):
        improvements_mean.append(-1)
        improvements_lst.append([-1, -1, 1])
        improvements_mean_Smeasure.append(-1)
        improvements_lst_Smeasure.append([-1, -1, 1])
        continue
    score_coca = float(scores[scores[:, 0] == 'CoCA'][0][score_idx])
    score_cosod = float(scores[scores[:, 0] == 'CoSOD3k'][0][score_idx])
    score_cosal = float(scores[scores[:, 0] == 'CoSal2015'][0][score_idx])
    # Relative gain over the GCoNet baseline: (score - baseline) / baseline
    improvements = [
        (score_coca - gco_scores['CoCA']) / gco_scores['CoCA'],
        (score_cosod - gco_scores['CoSOD3k']) / gco_scores['CoSOD3k'],
        (score_cosal - gco_scores['CoSal2015']) / gco_scores['CoSal2015']
    ]
    improvement_mean = np.mean(improvements)
    improvements_mean.append(improvement_mean)
    improvements_lst.append(improvements)

    # Smeasure
    score_coca = float(scores[scores[:, 0] == 'CoCA'][0][record.index('Smeasure')])
    score_cosod = float(scores[scores[:, 0] == 'CoSOD3k'][0][record.index('Smeasure')])
score_cosal = float(scores[scores[:, 0] == 'CoSal2015'][0][record.index('Smeasure')]) 83 | improvements_Smeasure = [ 84 | (score_coca - gco_scores_Smeasure['CoCA']) / gco_scores_Smeasure['CoCA'], 85 | (score_cosod - gco_scores_Smeasure['CoSOD3k']) / gco_scores_Smeasure['CoSOD3k'], 86 | (score_cosal - gco_scores_Smeasure['CoSal2015']) / gco_scores_Smeasure['CoSal2015'] 87 | ] 88 | improvement_mean_Smeasure = np.mean(improvements_Smeasure) 89 | improvements_mean_Smeasure.append(improvement_mean_Smeasure) 90 | improvements_lst_Smeasure.append(improvements_Smeasure) 91 | best_measurement = 'Emax' 92 | if best_measurement == 'Emax': 93 | best_improvement_index = np.argsort(improvements_mean).tolist()[-1] 94 | best_ckpt = ckpts[best_improvement_index] 95 | best_improvement_mean = improvements_mean[best_improvement_index] 96 | best_improvements = improvements_lst[best_improvement_index] 97 | 98 | best_improvement_mean_Smeasure = improvements_mean_Smeasure[best_improvement_index] 99 | best_improvements_Smeasure = improvements_lst_Smeasure[best_improvement_index] 100 | elif best_measurement == 'Smeasure': 101 | best_improvement_index = np.argsort(improvements_mean_Smeasure).tolist()[-1] 102 | best_ckpt = ckpts[best_improvement_index] 103 | best_improvement_mean_Smeasure = improvements_mean_Smeasure[best_improvement_index] 104 | best_improvements_Smeasure = improvements_lst_Smeasure[best_improvement_index] 105 | 106 | best_improvement_mean = improvements_mean[best_improvement_index] 107 | best_improvements = improvements_lst[best_improvement_index] 108 | 109 | print('The overall best one:') 110 | print(ss_ar[ss_ar[:, 1] == best_ckpt]) 111 | print('Got Emax improvements on CoCA-{:.3f}%, CoSOD3k-{:.3f}%, CoSal2015-{:.3f}%, mean_improvement: {:.3f}%.'.format( 112 | best_improvements[0]*100, best_improvements[1]*100, best_improvements[2]*100, best_improvement_mean*100 113 | )) 114 | print('Got Smes improvements on CoCA-{:.3f}%, CoSOD3k-{:.3f}%, CoSal2015-{:.3f}%, mean_improvement: {:.3f}%.'.format( 115 | best_improvements_Smeasure[0]*100, best_improvements_Smeasure[1]*100, best_improvements_Smeasure[2]*100, best_improvement_mean_Smeasure*100 116 | )) 117 | trial = int(best_ckpt.split('_')[-1].split('-')[0]) 118 | ep = int(best_ckpt.split('ep')[-1].split(':')[0]) 119 | if move_best_results_here: 120 | trial, ep = 'gconet_{}'.format(trial), 'ep{}'.format(ep) 121 | dr = os.path.join(trial, ep) 122 | dst = '-'.join((trial, ep)) 123 | shutil.move(os.path.join('/root/datasets/sod/preds', dr), dst) 124 | 125 | 126 | # model_indices = sorted([fname.split('_')[-1] for fname in os.listdir('output/details') if 'gconet_' in fname]) 127 | # emax = {} 128 | # for model_idx in model_indices: 129 | # m = 'gconet_{}-'.format(model_idx) 130 | # if m not in list(emax.keys()): 131 | # emax[m] = [] 132 | # for s in score: 133 | # if m in s[1]: 134 | # ep = int(s[1].split('ep')[-1].rstrip('):')) 135 | # emax[m].append([ep, s[2], s[0]]) 136 | 137 | # for m, e in emax.items(): 138 | # plot_name = m[:-1] 139 | # print('Saving {} ...'.format(plot_name)) 140 | # e = np.array(e) 141 | # e_coca = e[e[:, -1] == 'CoCA'] 142 | # e_cosod = e[e[:, -1] == 'CoSOD3k'] 143 | # e_cosal = e[e[:, -1] == 'CoSal2015'] 144 | # eps = sorted(list(set(e_coca[:, 0].astype(float)))) 145 | 146 | # e_coca = np.array(sorted(e_coca, key=lambda x: int(x[0])))[:, 1].astype(float) 147 | # e_cosod = np.array(sorted(e_cosod, key=lambda x: int(x[0])))[:, 1].astype(float) 148 | # e_cosal = np.array(sorted(e_cosal, key=lambda x: int(x[0])))[:, 1].astype(float) 

# plt.figure()
# plt.plot(eps, e_coca)
# plt.plot(eps, e_cosod)
# plt.plot(eps, e_cosal)
# plt.legend(['CoCA', 'CoSOD3k', 'CoSal2015'])
# plt.title(m)
# plt.savefig('{}.png'.format(plot_name))
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
from torch import nn
import torch


class IoU_loss(torch.nn.Module):
    def __init__(self):
        super(IoU_loss, self).__init__()

    def forward(self, pred, target):
        b = pred.shape[0]
        IoU = 0.0
        for i in range(0, b):
            # compute the IoU of the foreground
            Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
            Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
            IoU1 = Iand1 / (Ior1 + 1e-5)
            # IoU loss is (1 - IoU1)
            IoU = IoU + (1 - IoU1)

        return IoU / b


class Scale_IoU(nn.Module):
    def __init__(self):
        super(Scale_IoU, self).__init__()
        self.iou = IoU_loss()

    def forward(self, scaled_preds, gt):
        # Sum the foreground and background IoU losses over all side outputs.
        loss = 0
        for pred_lvl in scaled_preds:
            loss += self.iou(torch.sigmoid(pred_lvl), gt) + self.iou(1 - torch.sigmoid(pred_lvl), 1 - gt)
        return loss


def compute_cos_dis(x_sup, x_que):
    x_sup = x_sup.view(x_sup.size()[0], x_sup.size()[1], -1)
    x_que = x_que.view(x_que.size()[0], x_que.size()[1], -1)

    x_que_norm = torch.norm(x_que, p=2, dim=1, keepdim=True)
    x_sup_norm = torch.norm(x_sup, p=2, dim=1, keepdim=True)

    x_que_norm = x_que_norm.permute(0, 2, 1)
    x_qs_norm = torch.matmul(x_que_norm, x_sup_norm)

    x_que = x_que.permute(0, 2, 1)

    x_qs = torch.matmul(x_que, x_sup)
    x_qs = x_qs / (x_qs_norm + 1e-5)
    return x_qs
--------------------------------------------------------------------------------
/models/__pycache__/main.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyueyu/DCFM/05800b67ccf70f9ed55dd1f33ee1ff3b3503eb09/models/__pycache__/main.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vgg.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyueyu/DCFM/05800b67ccf70f9ed55dd1f33ee1ff3b3503eb09/models/__pycache__/vgg.cpython-37.pyc
--------------------------------------------------------------------------------
/models/main.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
from models.vgg import VGG_Backbone
from util import *


def weights_init(module):
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
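# Note: weights_init is intended for use with nn.Module.apply, e.g.
#   decoder = Decoder()
#   decoder.apply(weights_init)
# which recursively re-initializes every Conv2d / BatchNorm / Linear submodule.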
21 | 22 | 23 | class EnLayer(nn.Module): 24 | def __init__(self, in_channel=64): 25 | super(EnLayer, self).__init__() 26 | self.enlayer = nn.Sequential( 27 | nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 30 | ) 31 | 32 | def forward(self, x): 33 | x = self.enlayer(x) 34 | return x 35 | 36 | 37 | class LatLayer(nn.Module): 38 | def __init__(self, in_channel): 39 | super(LatLayer, self).__init__() 40 | self.convlayer = nn.Sequential( 41 | nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.convlayer(x) 48 | return x 49 | 50 | 51 | class DSLayer(nn.Module): 52 | def __init__(self, in_channel=64): 53 | super(DSLayer, self).__init__() 54 | self.enlayer = nn.Sequential( 55 | nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 58 | nn.ReLU(inplace=True), 59 | ) 60 | self.predlayer = nn.Sequential( 61 | nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0))#, nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | x = self.enlayer(x) 65 | x = self.predlayer(x) 66 | return x 67 | 68 | 69 | class half_DSLayer(nn.Module): 70 | def __init__(self, in_channel=512): 71 | super(half_DSLayer, self).__init__() 72 | self.enlayer = nn.Sequential( 73 | nn.Conv2d(in_channel, int(in_channel/4), kernel_size=3, stride=1, padding=1), 74 | nn.ReLU(inplace=True), 75 | ) 76 | self.predlayer = nn.Sequential( 77 | nn.Conv2d(int(in_channel/4), 1, kernel_size=1, stride=1, padding=0)) #, nn.Sigmoid()) 78 | 79 | def forward(self, x): 80 | x = self.enlayer(x) 81 | x = self.predlayer(x) 82 | return x 83 | 84 | 85 | class AugAttentionModule(nn.Module): 86 | def __init__(self, input_channels=512): 87 | super(AugAttentionModule, self).__init__() 88 | self.query_transform = nn.Sequential( 89 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 90 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 91 | ) 92 | self.key_transform = nn.Sequential( 93 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 94 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 95 | ) 96 | self.value_transform = nn.Sequential( 97 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 98 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 99 | ) 100 | self.scale = 1.0 / (input_channels ** 0.5) 101 | self.conv = nn.Sequential( 102 | nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0), 103 | nn.ReLU(inplace=True), 104 | ) 105 | 106 | def forward(self, x): 107 | B, C, H, W = x.size() 108 | x = self.conv(x) 109 | x_query = self.query_transform(x).view(B, C, -1).permute(0, 2, 1) # B,HW,C 110 | # x_key: C,BHW 111 | x_key = self.key_transform(x).view(B, C, -1) # B, C,HW 112 | # x_value: BHW, C 113 | x_value = self.value_transform(x).view(B, C, -1).permute(0, 2, 1) # B,HW,C 114 | attention_bmm = torch.bmm(x_query, x_key)*self.scale # B, HW, HW 115 | attention = F.softmax(attention_bmm, dim=-1) 116 | attention_sort = torch.sort(attention_bmm, dim=-1, descending=True)[1] 117 | attention_sort = torch.sort(attention_sort, dim=-1)[1] 118 | ##### 119 | attention_positive_num = torch.ones_like(attention).cuda() 120 | 
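        # Democratic re-weighting of the affinity matrix: the double argsort
        # above gives each entry its per-row rank (0 = strongest). Positive
        # entries are scaled by (rank+1)^3 below, boosting weaker-but-consistent
        # correspondences relative to a few dominant peaks; non-positive entries
        # keep a weight of 1.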
attention_positive_num[attention_bmm < 0] = 0 121 | att_pos_mask = attention_positive_num.clone() 122 | attention_positive_num = torch.sum(attention_positive_num, dim=-1, keepdim=True).expand_as(attention_sort) 123 | attention_sort_pos = attention_sort.float().clone() 124 | apn = attention_positive_num-1 125 | attention_sort_pos[attention_sort > apn] = 0 126 | attention_mask = ((attention_sort_pos+1)**3)*att_pos_mask + (1-att_pos_mask) 127 | out = torch.bmm(attention*attention_mask, x_value) 128 | out = out.view(B, H, W, C).permute(0, 3, 1, 2) 129 | return out+x 130 | 131 | 132 | class AttLayer(nn.Module): 133 | def __init__(self, input_channels=512): 134 | super(AttLayer, self).__init__() 135 | self.query_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0) 136 | self.key_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0) 137 | self.scale = 1.0 / (input_channels ** 0.5) 138 | self.conv = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0) 139 | 140 | def correlation(self, x5, seeds): 141 | B, C, H5, W5 = x5.size() 142 | if self.training: 143 | correlation_maps = F.conv2d(x5, weight=seeds) # B,B,H,W 144 | else: 145 | correlation_maps = torch.relu(F.conv2d(x5, weight=seeds)) # B,B,H,W 146 | correlation_maps = correlation_maps.mean(1).view(B, -1) 147 | min_value = torch.min(correlation_maps, dim=1, keepdim=True)[0] 148 | max_value = torch.max(correlation_maps, dim=1, keepdim=True)[0] 149 | correlation_maps = (correlation_maps - min_value) / (max_value - min_value + 1e-12) # shape=[B, HW] 150 | correlation_maps = correlation_maps.view(B, 1, H5, W5) # shape=[B, 1, H, W] 151 | return correlation_maps 152 | 153 | def forward(self, x5): 154 | # x: B,C,H,W 155 | x5 = self.conv(x5)+x5 156 | B, C, H5, W5 = x5.size() 157 | x_query = self.query_transform(x5).view(B, C, -1) 158 | # x_query: B,HW,C 159 | x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C) # BHW, C 160 | # x_key: B,C,HW 161 | x_key = self.key_transform(x5).view(B, C, -1) 162 | x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1) # C, BHW 163 | # W = Q^T K: B,HW,HW 164 | x_w1 = torch.matmul(x_query, x_key) * self.scale # BHW, BHW 165 | x_w = x_w1.view(B * H5 * W5, B, H5 * W5) 166 | x_w = torch.max(x_w, -1).values # BHW, B 167 | x_w = x_w.mean(-1) 168 | x_w = x_w.view(B, -1) # B, HW 169 | x_w = F.softmax(x_w, dim=-1) # B, HW 170 | ##### mine ###### 171 | # x_w_max = torch.max(x_w, -1) 172 | # max_indices0 = x_w_max.indices.unsqueeze(-1).unsqueeze(-1) 173 | norm0 = F.normalize(x5, dim=1) 174 | # norm = norm0.view(B, C, -1) 175 | # max_indices = max_indices0.expand(B, C, -1) 176 | # seeds = torch.gather(norm, 2, max_indices).unsqueeze(-1) 177 | x_w = x_w.unsqueeze(1) 178 | x_w_max = torch.max(x_w, -1).values.unsqueeze(2).expand_as(x_w) 179 | mask = torch.zeros_like(x_w).cuda() 180 | mask[x_w == x_w_max] = 1 181 | mask = mask.view(B, 1, H5, W5) 182 | seeds = norm0 * mask 183 | seeds = seeds.sum(3).sum(2).unsqueeze(2).unsqueeze(3) 184 | cormap = self.correlation(norm0, seeds) 185 | x51 = x5 * cormap 186 | proto1 = torch.mean(x51, (0, 2, 3), True) 187 | return x5, proto1, x5*proto1+x51, mask 188 | 189 | 190 | class Decoder(nn.Module): 191 | def __init__(self): 192 | super(Decoder, self).__init__() 193 | self.toplayer = nn.Sequential( 194 | nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0), 195 | nn.ReLU(inplace=True), 196 | nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)) 197 | self.latlayer4 = 
190 | class Decoder(nn.Module):
191 |     def __init__(self):
192 |         super(Decoder, self).__init__()
193 |         self.toplayer = nn.Sequential(
194 |             nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0),
195 |             nn.ReLU(inplace=True),
196 |             nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0))
197 |         self.latlayer4 = LatLayer(in_channel=512)
198 |         self.latlayer3 = LatLayer(in_channel=256)
199 |         self.latlayer2 = LatLayer(in_channel=128)
200 |         self.latlayer1 = LatLayer(in_channel=64)
201 | 
202 |         self.enlayer4 = EnLayer()
203 |         self.enlayer3 = EnLayer()
204 |         self.enlayer2 = EnLayer()
205 |         self.enlayer1 = EnLayer()
206 | 
207 |         self.dslayer4 = DSLayer()
208 |         self.dslayer3 = DSLayer()
209 |         self.dslayer2 = DSLayer()
210 |         self.dslayer1 = DSLayer()
211 | 
212 |     def _upsample_add(self, x, y):
213 |         [_, _, H, W] = y.size()
214 |         x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
215 |         return x + y
216 | 
217 |     def forward(self, weighted_x5, x4, x3, x2, x1, H, W):
218 |         preds = []
219 |         p5 = self.toplayer(weighted_x5)
220 |         p4 = self._upsample_add(p5, self.latlayer4(x4))
221 |         p4 = self.enlayer4(p4)
222 |         _pred = self.dslayer4(p4)
223 |         preds.append(
224 |             F.interpolate(_pred,
225 |                           size=(H, W),
226 |                           mode='bilinear', align_corners=False))
227 | 
228 |         p3 = self._upsample_add(p4, self.latlayer3(x3))
229 |         p3 = self.enlayer3(p3)
230 |         _pred = self.dslayer3(p3)
231 |         preds.append(
232 |             F.interpolate(_pred,
233 |                           size=(H, W),
234 |                           mode='bilinear', align_corners=False))
235 | 
236 |         p2 = self._upsample_add(p3, self.latlayer2(x2))
237 |         p2 = self.enlayer2(p2)
238 |         _pred = self.dslayer2(p2)
239 |         preds.append(
240 |             F.interpolate(_pred,
241 |                           size=(H, W),
242 |                           mode='bilinear', align_corners=False))
243 | 
244 |         p1 = self._upsample_add(p2, self.latlayer1(x1))
245 |         p1 = self.enlayer1(p1)
246 |         _pred = self.dslayer1(p1)
247 |         preds.append(
248 |             F.interpolate(_pred,
249 |                           size=(H, W),
250 |                           mode='bilinear', align_corners=False))
251 |         return preds
252 | 
253 | 
254 | class DCFMNet(nn.Module):
255 |     """DCFM network: VGG backbone, seed-based consensus fusion (AttLayer),
256 |     attention augmentation (AugAttentionModule) and an FPN-style decoder."""
257 |     def __init__(self, mode='train'):
258 |         super(DCFMNet, self).__init__()
259 |         self.gradients = None
260 |         self.backbone = VGG_Backbone()
261 |         self.mode = mode
262 |         self.aug = AugAttentionModule()
263 |         self.fusion = AttLayer(512)
264 |         self.decoder = Decoder()
265 | 
266 |     def set_mode(self, mode):
267 |         self.mode = mode
268 | 
269 |     def forward(self, x, gt):
270 |         if self.mode == 'train':
271 |             preds = self._forward(x, gt)
272 |         else:
273 |             with torch.no_grad():
274 |                 preds = self._forward(x, gt)
275 | 
276 |         return preds
277 | 
278 |     def featextract(self, x):
279 |         x1 = self.backbone.conv1(x)
280 |         x2 = self.backbone.conv2(x1)
281 |         x3 = self.backbone.conv3(x2)
282 |         x4 = self.backbone.conv4(x3)
283 |         x5 = self.backbone.conv5(x4)
284 |         return x5, x4, x3, x2, x1
285 | 
286 |     def _forward(self, x, gt):
287 |         [B, _, H, W] = x.size()
288 |         x5, x4, x3, x2, x1 = self.featextract(x)
289 |         feat, proto, weighted_x5, cormap = self.fusion(x5)
290 |         feataug = self.aug(weighted_x5)
291 |         preds = self.decoder(feataug, x4, x3, x2, x1, H, W)
292 |         if self.training:
293 |             gt = F.interpolate(gt, size=weighted_x5.size()[2:], mode='bilinear', align_corners=False)
294 |             feat_pos, proto_pos, weighted_x5_pos, cormap_pos = self.fusion(x5 * gt)        # foreground prototype
295 |             feat_neg, proto_neg, weighted_x5_neg, cormap_neg = self.fusion(x5 * (1 - gt))  # background prototype
296 |             return preds, proto, proto_pos, proto_neg
297 |         return preds
298 | 
299 | 
300 | class DCFM(nn.Module):
301 |     def __init__(self, mode='train'):
302 |         super(DCFM, self).__init__()
303 |         set_seed(123)
304 |         self.dcfmnet = DCFMNet()
305 |         self.mode = mode
306 | 
307 |     def set_mode(self, mode):
308 |         self.mode = mode
309 |         self.dcfmnet.set_mode(self.mode)
310 | 
311 |     def forward(self, x, gt):
312 |         ########## Co-SOD ############
313 |         preds = self.dcfmnet(x, gt)
314 |         return preds
315 | 
316 | 
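# Minimal smoke test (a sketch; assumes models/vgg16-397923af.pth has been downloaded
# as described in the README, and runs on CPU with random inputs):
if __name__ == '__main__':
    net = DCFM(mode='train')
    net.train()
    dummy_x = torch.rand(2, 3, 224, 224)                       # a group of 2 images
    dummy_gt = torch.randint(0, 2, (2, 1, 224, 224)).float()   # binary co-saliency masks
    preds, proto, proto_pos, proto_neg = net(dummy_x, dummy_gt)
    print([p.shape for p in preds])   # four side outputs, each (2, 1, 224, 224)
    print(proto.shape)                # consensus prototype, (1, 512, 1, 1)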
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import os
 4 | 
 5 | 
 6 | class VGG_Backbone(nn.Module):
 7 |     # VGG16 feature extractor; the max-pooling layer sits at the front of each block,
 8 |     # so conv1-conv5 outputs match the resolutions the decoder expects.
 9 |     def __init__(self):
10 |         super(VGG_Backbone, self).__init__()
11 |         conv1 = nn.Sequential()
12 |         conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
13 |         conv1.add_module('relu1_1', nn.ReLU(inplace=True))
14 |         conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
15 |         conv1.add_module('relu1_2', nn.ReLU(inplace=True))
16 |         self.conv1 = conv1
17 | 
18 |         conv2 = nn.Sequential()
19 |         conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
20 |         conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
21 |         conv2.add_module('relu2_1', nn.ReLU())
22 |         conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
23 |         conv2.add_module('relu2_2', nn.ReLU())
24 |         self.conv2 = conv2
25 | 
26 |         conv3 = nn.Sequential()
27 |         conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
28 |         conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
29 |         conv3.add_module('relu3_1', nn.ReLU())
30 |         conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
31 |         conv3.add_module('relu3_2', nn.ReLU())
32 |         conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
33 |         conv3.add_module('relu3_3', nn.ReLU())
34 |         self.conv3 = conv3
35 | 
36 |         conv4 = nn.Sequential()
37 |         conv4.add_module('pool3', nn.MaxPool2d(2, stride=2))
38 |         conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
39 |         conv4.add_module('relu4_1', nn.ReLU())
40 |         conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
41 |         conv4.add_module('relu4_2', nn.ReLU())
42 |         conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
43 |         conv4.add_module('relu4_3', nn.ReLU())
44 |         self.conv4 = conv4
45 | 
46 |         conv5 = nn.Sequential()
47 |         conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
48 |         conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
49 |         conv5.add_module('relu5_1', nn.ReLU())
50 |         conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
51 |         conv5.add_module('relu5_2', nn.ReLU())
52 |         conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
53 |         conv5.add_module('relu5_3', nn.ReLU())
54 |         self.conv5 = conv5
55 | 
56 |         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
57 |         self.classifier = nn.Sequential(
58 |             nn.Linear(512 * 7 * 7, 4096),
59 |             nn.ReLU(True),
60 |             nn.Dropout(),
61 |             nn.Linear(4096, 4096),
62 |             nn.ReLU(True),
63 |             nn.Dropout(),
64 |             nn.Linear(4096, 1000),
65 |         )
66 | 
67 |         pre_train = torch.load(os.path.dirname(__file__) + '/vgg16-397923af.pth')
68 |         self._initialize_weights(pre_train)
69 | 
70 |     def forward(self, x):
71 |         # Single-branch pass; DCFMNet.featextract() uses the conv blocks directly,
72 |         # so this forward only serves as a standalone VGG16 classifier pass.
73 |         x = self.conv1(x)
74 |         x = self.conv2(x)
75 |         x = self.conv3(x)
76 |         x = self.conv4(x)
77 |         x = self.conv5(x)
78 |         feat = self.avgpool(x)
79 |         _feat = feat.view(feat.size(0), -1)
80 |         pred_vector = self.classifier(_feat)
81 |         return x, pred_vector
82 | 
83 | 
84 |     def _initialize_weights(self, pre_train):
85 |         keys = list(pre_train.keys())
86 |         self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
87 |         self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
88 |         self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
89 |         self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
90 |         self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
91 |         self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
92 |         self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
93 |         self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
94 |         self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
95 |         self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
96 |         self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
97 |         self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
98 |         self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])
99 | 
100 |         self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
101 |         self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
102 |         self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
103 |         self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
104 |         self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
105 |         self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
106 |         self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
107 |         self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]])
108 |         self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]])
109 |         self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]])
110 |         self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]])
111 |         self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]])
112 |         self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]])
113 | 
114 |         self.classifier[0].weight.data.copy_(pre_train[keys[26]])
115 |         self.classifier[0].bias.data.copy_(pre_train[keys[27]])
116 |         self.classifier[3].weight.data.copy_(pre_train[keys[28]])
117 |         self.classifier[3].bias.data.copy_(pre_train[keys[29]])
118 |         self.classifier[6].weight.data.copy_(pre_train[keys[30]])
119 |         self.classifier[6].bias.data.copy_(pre_train[keys[31]])
120 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | matplotlib==3.4.1
 2 | numpy==1.19.2
 3 | opencv_python==4.5.1.48
 4 | pandas==1.2.4
 5 | Pillow==6.2.2  # dataset.py imports PIL.PILLOW_VERSION, which was removed in Pillow 7.0
 6 | pytorch_toolbelt==0.4.3
 7 | scikit_image==0.18.1
 8 | torch==1.7.1
 9 | torchvision==0.8.2  # the torchvision release paired with torch 1.7.1
10 | tqdm==4.60.0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
 1 | from PIL import Image
 2 | from dataset import get_loader
 3 | import torch
 4 | from torchvision import transforms
 5 | from util import save_tensor_img, Logger
 6 | from tqdm import tqdm
 7 | from torch import nn
 8 | import os
 9 | from models.main import *
10 | import argparse
11 | import numpy as np
12 | import cv2
13 | from skimage import img_as_ubyte
14 | 
15 | 
16 | def main(args):
17 |     # Init model
18 |     device = torch.device("cuda")
19 |     model = DCFM()
20 |     model = model.to(device)
21 |     try:
22 |         modelname = os.path.join(args.param_root, 'best_ep198_Smeasure0.7019.pth')
23 |         dcfmnet_dict = torch.load(modelname)
24 |         print('loaded', modelname)
25 |     except OSError:  # fall back to the default checkpoint name
26 |         dcfmnet_dict = torch.load(os.path.join(args.param_root, 'dcfm.pth'))
27 | 
28 |     model.dcfmnet.load_state_dict(dcfmnet_dict)
29 |     model.eval()
30 |     model.set_mode('test')
31 | 
32 |     tensor2pil = transforms.ToPILImage()
33 |     for testset in ['CoCA']:
34 |         if testset == 'CoCA':
35 |             test_img_path = './data/images/CoCA/'
36 |             test_gt_path = './data/gts/CoCA/'
37 |             saved_root = os.path.join(args.save_root, 'CoCA')
38 |         elif testset == 'CoSOD3k':
39 |             test_img_path = './data/images/CoSOD3k/'
40 |             test_gt_path = './data/gts/CoSOD3k/'
41 |             saved_root = os.path.join(args.save_root, 'CoSOD3k')
42 |         elif testset == 'CoSal2015':
43 |             test_img_path = './data/images/CoSal2015/'
44 |             test_gt_path = './data/gts/CoSal2015/'
45 |             saved_root = os.path.join(args.save_root, 'CoSal2015')
46 |         else:
47 |             print('Unknown test dataset')
48 |             print(testset)
49 | 
50 |         test_loader = get_loader(
51 |             test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)
52 | 
53 |         for batch in tqdm(test_loader):
54 |             inputs = batch[0].to(device).squeeze(0)
55 |             gts = batch[1].to(device).squeeze(0)
56 |             subpaths = batch[2]
57 |             ori_sizes = batch[3]
58 |             scaled_preds = model(inputs, gts)
59 |             scaled_preds = torch.sigmoid(scaled_preds[-1])  # last (finest) side output
60 |             os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)
61 |             num = gts.shape[0]
62 |             for inum in range(num):
63 |                 subpath = subpaths[inum][0]
64 |                 ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
65 |                 res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear', align_corners=True)
66 |                 save_tensor_img(res, os.path.join(saved_root, subpath))
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     # Parameter from command line
71 |     parser = argparse.ArgumentParser(description='')
72 |     parser.add_argument('--size', default=224, type=int, help='input size')
73 |     parser.add_argument('--param_root', default='/data1/dcfm/temp', type=str, help='model folder')
74 |     parser.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')
75 | 
76 |     args = parser.parse_args()
77 | 
78 |     main(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.optim as optim
 4 | from util import Logger, AverageMeter, save_checkpoint, save_tensor_img, set_seed
 5 | import os
 6 | import numpy as np
 7 | from matplotlib import pyplot as plt
 8 | import time
 9 | import argparse
10 | from tqdm import tqdm
11 | from dataset import get_loader
12 | from loss import *
13 | from config import Config
14 | from evaluation.dataloader import EvalDataset
15 | from evaluation.evaluator import Eval_thread
16 | 
17 | 
18 | from models.main import *
19 | 
20 | import torch.nn.functional as F
21 | import pytorch_toolbelt.losses as PTL
22 | 
23 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
24 | # Parameter from command line
25 | parser = argparse.ArgumentParser(description='')
26 | 
27 | parser.add_argument('--loss',
28 |                     default='Scale_IoU',
29 |                     type=str,
30 |                     help="Loss class from loss.py to use, e.g. 'Scale_IoU'")
31 | parser.add_argument('--bs', '--batch_size', default=1, type=int)
32 | parser.add_argument('--lr',
33 |                     '--learning_rate',
34 |                     default=1e-4,
35 |                     type=float,
36 |                     help='Initial learning rate')
37 | parser.add_argument('--resume',
38 |                     default=None,
39 |                     type=str,
40 |                     help='path to latest checkpoint')
41 | parser.add_argument('--epochs', default=200, type=int)
42 | parser.add_argument('--start_epoch',
43 |                     default=0,
44 |                     type=int,
45 |                     help='manual epoch number (useful on restarts)')
46 | parser.add_argument('--trainset',
47 |                     default='CoCo',
48 |                     type=str,
49 |                     help="Options: 'CoCo'")
50 | parser.add_argument('--testsets',
51 |                     default='CoCA',
52 |                     type=str,
53 |                     help="Options: 'CoCA','CoSal2015','CoSOD3k','iCoseg','MSRC'")
54 | parser.add_argument('--size',
55 |                     default=224,
56 |                     type=int,
57 |                     help='input size')
58 | parser.add_argument('--tmp', default='/data1/dcfm/temp', help='Temporary folder')
59 | parser.add_argument('--save_root', default='./CoSODmaps/pred', type=str, help='Output folder')
60 | 
61 | args = parser.parse_args()
62 | config = Config()
63 | 
64 | # Prepare dataset
65 | if args.trainset == 'CoCo':
66 |     train_img_path = './data/CoCo/img/'
67 |     train_gt_path = './data/CoCo/gt/'
68 |     train_loader = get_loader(train_img_path,
69 |                               train_gt_path,
70 |                               args.size,
71 |                               args.bs,
72 |                               max_num=16,  # 20,
73 |                               istrain=True,
74 |                               shuffle=False,
75 |                               num_workers=8,  # 4,
76 |                               pin=True)
77 | 
78 | else:
79 |     print('Unknown train dataset')
80 |     print(args.trainset)
81 | 
82 | for testset in ['CoCA']:
83 |     if testset == 'CoCA':
84 |         test_img_path = './data/images/CoCA/'
85 |         test_gt_path = './data/gts/CoCA/'
86 | 
87 |         saved_root = os.path.join(args.save_root, 'CoCA')
88 |     elif testset == 'CoSOD3k':
89 |         test_img_path = './data/images/CoSOD3k/'
90 |         test_gt_path = './data/gts/CoSOD3k/'
91 |         saved_root = os.path.join(args.save_root, 'CoSOD3k')
92 |     elif testset == 'CoSal2015':
93 |         test_img_path = './data/images/CoSal2015/'
94 |         test_gt_path = './data/gts/CoSal2015/'
95 |         saved_root = os.path.join(args.save_root, 'CoSal2015')
96 |     elif testset == 'CoCo':
97 |         test_img_path = './data/images/CoCo/'
98 |         test_gt_path = './data/gts/CoCo/'
99 |         saved_root = os.path.join(args.save_root, 'CoCo')
100 |     else:
101 |         print('Unknown test dataset')
102 |         print(testset)
103 | 
104 |     test_loader = get_loader(
105 |         test_img_path, test_gt_path, args.size, 1, istrain=False, shuffle=False, num_workers=8, pin=True)
106 | 
107 | # make dir for tmp
108 | os.makedirs(args.tmp, exist_ok=True)
109 | 
110 | # Init log file
111 | logger = Logger(os.path.join(args.tmp, "log.txt"))
112 | set_seed(123)
113 | 
114 | # Init model
115 | device = torch.device("cuda")
116 | 
117 | model = DCFM()
118 | model = model.to(device)
119 | model.apply(weights_init)
120 | 
121 | model.dcfmnet.backbone._initialize_weights(torch.load('./models/vgg16-397923af.pth'))
122 | 
123 | backbone_params = list(map(id, model.dcfmnet.backbone.parameters()))
124 | base_params = filter(lambda p: id(p) not in backbone_params,
125 |                      model.dcfmnet.parameters())
126 | 
127 | all_params = [{'params': base_params}, {'params': model.dcfmnet.backbone.parameters(), 'lr': args.lr * 0.1}]  # 0.1x lr on the backbone
128 | 
129 | # Setting optimizer
130 | optimizer = optim.Adam(params=all_params, lr=args.lr, weight_decay=1e-4, betas=[0.9, 0.99])
131 | 
132 | for key, value in model.named_parameters():  # freeze the backbone except conv5_3
133 |     if 'dcfmnet.backbone' in key and 'dcfmnet.backbone.conv5.conv5_3' not in key:
134 |         value.requires_grad = False
135 | 
136 | for key, value in model.named_parameters():
137 |     print(key, value.requires_grad)
138 | 
139 | # log model and optimizer pars
140 | logger.info("Model details:")
141 | logger.info(model)
142 | logger.info("Optimizer details:")
143 | logger.info(optimizer)
144 | logger.info("Scheduler details:")
145 | # logger.info(scheduler)
146 | logger.info("Other hyperparameters:")
147 | logger.info(args)
148 | 
149 | # Setting Loss
150 | exec('from loss import ' + args.loss)
151 | IOUloss = eval(args.loss + '()')
152 | 
153 | 
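# The exec/eval pair above instantiates whichever loss class --loss names. A sketch of an
# equivalent without exec/eval (assuming loss.py defines the class, e.g. Scale_IoU):
#     import loss as loss_zoo
#     IOUloss = getattr(loss_zoo, args.loss)()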
154 | def main():
155 |     val_measures = []
156 |     # Optionally resume from a checkpoint
157 |     if args.resume:
158 |         if os.path.isfile(args.resume):
159 |             logger.info("=> loading checkpoint '{}'".format(args.resume))
160 |             checkpoint = torch.load(args.resume)
161 |             args.start_epoch = checkpoint['epoch']
162 |             model.dcfmnet.load_state_dict(checkpoint['state_dict'])
163 |             optimizer.load_state_dict(checkpoint['optimizer'])
164 |             logger.info("=> loaded checkpoint '{}' (epoch {})".format(
165 |                 args.resume, checkpoint['epoch']))
166 |         else:
167 |             logger.info("=> no checkpoint found at '{}'".format(args.resume))
168 | 
169 |     print(args.epochs)
170 |     for epoch in range(args.start_epoch, args.epochs):
171 |         train_loss = train(epoch)
172 |         if config.validation:
173 |             measures = validate(model, test_loader, args.testsets)
174 |             val_measures.append(measures)
175 |             print(
176 |                 'Validation: S_measure on CoCA for epoch-{} is {:.4f}. Best epoch is epoch-{} with S_measure {:.4f}'.format(
177 |                     epoch, measures[0], np.argmax(np.array(val_measures)[:, 0].squeeze()),
178 |                     np.max(np.array(val_measures)[:, 0]))
179 |             )
180 |         # Save checkpoint
181 |         save_checkpoint(
182 |             {
183 |                 'epoch': epoch + 1,
184 |                 'state_dict': model.dcfmnet.state_dict(),
185 |                 # 'scheduler': scheduler.state_dict(),
186 |             },
187 |             path=args.tmp)
188 |         if config.validation:
189 |             if np.max(np.array(val_measures)[:, 0].squeeze()) == measures[0]:  # current epoch is the best so far
190 |                 best_weights_before = [os.path.join(args.tmp, weight_file) for weight_file in
191 |                                        os.listdir(args.tmp) if 'best_' in weight_file]
192 |                 for best_weight_before in best_weights_before:
193 |                     os.remove(best_weight_before)
194 |                 torch.save(model.dcfmnet.state_dict(),
195 |                            os.path.join(args.tmp, 'best_ep{}_Smeasure{:.4f}.pth'.format(epoch, measures[0])))
196 |         if (epoch + 1) % 10 == 0 or epoch == 0:
197 |             torch.save(model.dcfmnet.state_dict(), args.tmp + '/model-' + str(epoch + 1) + '.pt')
198 | 
199 |         if epoch > 188:  # also keep every checkpoint near the end of training
200 |             torch.save(model.dcfmnet.state_dict(), args.tmp + '/model-' + str(epoch + 1) + '.pt')
201 |     # dcfmnet_dict = model.dcfmnet.state_dict()
202 |     # torch.save(dcfmnet_dict, os.path.join(args.tmp, 'final.pth'))
203 | 
204 | def sclloss(x, xt, xb):
205 |     cosc = (1 + compute_cos_dis(x, xt)) * 0.5  # similarity to the foreground prototype, rescaled to [0, 1]
206 |     cosb = (1 + compute_cos_dis(x, xb)) * 0.5  # similarity to the background prototype, rescaled to [0, 1]
207 |     loss = -torch.log(cosc + 1e-5) - torch.log(1 - cosb + 1e-5)
208 |     return loss.sum()
209 | 
210 | def train(epoch):
211 |     # Switch to train mode
212 |     model.train()
213 |     model.set_mode('train')
214 |     loss_sum = 0.0
215 | 
216 |     for batch_idx, batch in enumerate(train_loader):
217 |         inputs = batch[0].to(device).squeeze(0)
218 |         gts = batch[1].to(device).squeeze(0)
219 |         pred, proto, protogt, protobg = model(inputs, gts)
220 |         loss_iou = IOUloss(pred, gts)
221 |         loss_scl = sclloss(proto, protogt, protobg)
222 |         loss = loss_iou + 0.1 * loss_scl
223 |         optimizer.zero_grad()
224 |         loss.backward()
225 |         optimizer.step()
226 |         loss_sum = loss_sum + loss_iou.detach().item()
227 | 
228 |         if batch_idx % 20 == 0:
229 |             logger.info('Epoch[{0}/{1}] Iter[{2}/{3}] '
230 |                         'Train Loss: loss_iou: {4:.3f}, loss_scl: {5:.3f} '.format(
231 |                             epoch,
232 |                             args.epochs,
233 |                             batch_idx,
234 |                             len(train_loader),
235 |                             loss_iou,
236 |                             loss_scl,
237 |                         ))
238 |     loss_mean = loss_sum / len(train_loader)
239 |     return loss_mean  # mean IoU loss over the epoch
240 | 
241 | 
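# A quick sanity check of sclloss above: cosc and cosb rescale cosine similarity to [0, 1].
# If the global prototype aligns with the foreground prototype (cos = 1) and is opposite to
# the background prototype (cos = -1), both log terms vanish and the loss is ~0; for
# uncorrelated prototypes (cos = 0) each term contributes -log(0.5) ~= 0.693.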
242 | def validate(model, test_loaders, testsets):
243 |     model.eval()
244 | 
245 |     testsets = testsets.split('+')
246 |     measures = []
247 |     for testset in testsets[:1]:
248 |         print('Validating {}...'.format(testset))
249 |         # test_loader = test_loaders[testset]
250 | 
251 |         saved_root = os.path.join(args.save_root, testset)
252 | 
253 |         for batch in test_loader:  # note: iterates the global test_loader; the test_loaders argument is unused
254 |             inputs = batch[0].to(device).squeeze(0)
255 |             gts = batch[1].to(device).squeeze(0)
256 |             subpaths = batch[2]
257 |             ori_sizes = batch[3]
258 |             with torch.no_grad():
259 |                 scaled_preds = model(inputs, gts)[-1].sigmoid()
260 | 
261 |             os.makedirs(os.path.join(saved_root, subpaths[0][0].split('/')[0]), exist_ok=True)
262 | 
263 |             num = len(scaled_preds)
264 |             for inum in range(num):
265 |                 subpath = subpaths[inum][0]
266 |                 ori_size = (ori_sizes[inum][0].item(), ori_sizes[inum][1].item())
267 |                 res = nn.functional.interpolate(scaled_preds[inum].unsqueeze(0), size=ori_size, mode='bilinear',
268 |                                                 align_corners=True)
269 |                 save_tensor_img(res, os.path.join(saved_root, subpath))
270 | 
271 |         eval_loader = EvalDataset(
272 |             saved_root,  # preds
273 |             os.path.join('./data/gts', testset)  # GT
274 |         )
275 |         evaler = Eval_thread(eval_loader, cuda=True)
276 |         # Use S_measure for validation
277 |         s_measure = evaler.Eval_Smeasure()
278 |         if s_measure > config.val_measures['Smeasure']['CoCA'] and 0:  # 'and 0' keeps this branch disabled; drop it to report Emax/Fmax too
279 |             # TODO: evaluate other measures if s_measure is very high.
280 |             e_max = evaler.Eval_Emeasure().max().item()
281 |             f_max = evaler.Eval_fmeasure().max().item()
282 |             print('Emax: {:.4f}, Fmax: {:.4f}'.format(e_max, f_max))
283 |         measures.append(s_measure)
284 | 
285 |     model.train()
286 |     return measures
287 | 
288 | if __name__ == '__main__':
289 |     main()
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import os
 3 | import torch
 4 | import shutil
 5 | from torchvision import transforms
 6 | import numpy as np
 7 | import random
 8 | import cv2
 9 | 
10 | 
11 | class Logger():
12 |     def __init__(self, path="log.txt"):
13 |         self.logger = logging.getLogger('DCFM')
14 |         self.file_handler = logging.FileHandler(path, "w")
15 |         self.stdout_handler = logging.StreamHandler()
16 |         self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
17 |         self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
18 |         self.logger.addHandler(self.file_handler)
19 |         self.logger.addHandler(self.stdout_handler)
20 |         self.logger.setLevel(logging.INFO)
21 |         self.logger.propagate = False
22 | 
23 |     def info(self, txt):
24 |         self.logger.info(txt)
25 | 
26 |     def close(self):
27 |         self.file_handler.close()
28 |         self.stdout_handler.close()
29 | 
30 | class AverageMeter(object):
31 |     """Computes and stores the average and current value"""
32 |     def __init__(self):
33 |         self.reset()
34 | 
35 |     def reset(self):
36 |         self.val = 0.0
37 |         self.avg = 0.0
38 |         self.sum = 0.0
39 |         self.count = 0.0
40 | 
41 |     def update(self, val, n=1):
42 |         self.val = val
43 |         self.sum += val * n
44 |         self.count += n
45 |         self.avg = self.sum / self.count
46 | 
47 | 
48 | def save_checkpoint(state, path, filename="checkpoint.pth"):
49 |     torch.save(state, os.path.join(path, filename))
50 | 
51 | 
52 | def save_tensor_img(tensor_im, path):
53 |     im = tensor_im.cpu().clone()
54 |     im = im.squeeze(0)
55 |     tensor2pil = transforms.ToPILImage()
56 |     im = tensor2pil(im)
57 |     im.save(path)
58 | 
59 | 
60 | def save_tensor_merge(tensor_im, tensor_mask, path, colormap='HOT'):
61 |     im = tensor_im.cpu().detach().clone()
62 |     im = im.squeeze(0).numpy()
63 |     im = ((im - np.min(im)) / (np.max(im) - np.min(im) + 1e-20)) * 255  # min-max normalize to [0, 255]
64 |     im = np.array(im, np.uint8)
65 |     mask = tensor_mask.cpu().detach().clone()
66 |     mask = mask.squeeze(0).numpy()
67 |     mask = ((mask - np.min(mask)) / (np.max(mask) - np.min(mask) + 1e-20)) * 255
68 |     mask = np.clip(mask, 0, 255)
69 |     mask = np.array(mask, np.uint8)
70 |     if colormap == 'HOT':
71 |         mask = cv2.applyColorMap(mask[0, :, :], cv2.COLORMAP_HOT)
72 |     elif colormap == 'PINK':
73 |         mask = cv2.applyColorMap(mask[0, :, :], cv2.COLORMAP_PINK)
74 |     elif colormap == 'BONE':
75 |         mask = cv2.applyColorMap(mask[0, :, :], cv2.COLORMAP_BONE)
76 |     # exec('cv2.applyColorMap(mask[0,:,:], cv2.COLORMAP_' + colormap + ')')
77 |     im = im.transpose((1, 2, 0))
78 |     im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
79 |     mix = cv2.addWeighted(im, 0.3, mask, 0.7, 0)  # blend image and colorized mask
80 |     cv2.imwrite(path, mix)
81 | 
82 | 
83 | def set_seed(seed):
84 |     torch.manual_seed(seed)
85 |     torch.cuda.manual_seed(seed)
86 |     torch.cuda.manual_seed_all(seed)
87 |     np.random.seed(seed)
88 |     random.seed(seed)
89 |     torch.backends.cudnn.deterministic = True
90 |     torch.backends.cudnn.benchmark = False
91 | 
92 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | import torch.nn as nn
 4 | 
 5 | import numpy as np
 6 | 
 7 | def adjust_rate_poly(cur_iter, max_iter, power=0.9):
 8 |     return (1.0 - 1.0 * cur_iter / max_iter) ** power
 9 | 
10 | def adjust_learning_rate_exp(lr, optimizer, iters, decay_rate=0.1, decay_step=25):
11 |     lr = lr * (decay_rate ** (iters // decay_step))
12 |     for param_group in optimizer.param_groups:
13 |         param_group['lr'] = lr * param_group['lr_mult']  # groups are expected to carry an 'lr_mult' (see set_param_groups below)
14 | 
15 | def adjust_learning_rate_RevGrad(lr, optimizer, max_iter, cur_iter,
16 |                                  alpha=10, beta=0.75):
17 |     p = 1.0 * cur_iter / (max_iter - 1)
18 |     lr = lr / pow(1.0 + alpha * p, beta)
19 |     for param_group in optimizer.param_groups:
20 |         param_group['lr'] = lr * param_group['lr_mult']
21 | 
22 | def adjust_learning_rate_inv(lr, optimizer, iters, alpha=0.001, beta=0.75):
23 |     lr = lr / pow(1.0 + alpha * iters, beta)
24 |     for param_group in optimizer.param_groups:
25 |         param_group['lr'] = lr * param_group['lr_mult']
26 | 
27 | def adjust_learning_rate_step(lr, optimizer, iters, steps, beta=0.1):
28 |     n = 0
29 |     for step in steps:
30 |         if iters < step:
31 |             break
32 |         n += 1
33 | 
34 |     lr = lr * (beta ** n)
35 |     for param_group in optimizer.param_groups:
36 |         param_group['lr'] = lr * param_group['lr_mult']
37 | 
38 | def adjust_learning_rate_poly(lr, optimizer, iters, max_iter, power=0.9):
39 |     lr = lr * (1.0 - 1.0 * iters / max_iter) ** power
40 |     for param_group in optimizer.param_groups:
41 |         param_group['lr'] = lr * param_group['lr_mult']
42 | 
43 | def set_param_groups(net, lr_mult_dict={}):
44 |     params = []
45 |     if hasattr(net, "module"):
46 |         net = net.module
47 | 
48 |     modules = net._modules
49 |     for name in modules:
50 |         module = modules[name]
51 |         if name in lr_mult_dict:
52 |             params += [{'params': module.parameters(), \
53 |                         'lr_mult': lr_mult_dict[name]}]
54 |         else:
55 |             params += [{'params': module.parameters(), 'lr_mult': 1.0}]
56 | 
57 |     return params
58 | 
59 | def LSR(x, dim=1, thres=10.0):
60 |     lsr = -1.0 * torch.mean(x, dim=dim)  # negative mean (log-)score along dim
61 |     if thres > 0.0:
62 |         return torch.mean((lsr / thres - 1.0).detach() * lsr)
63 |     else:
64 |         return torch.mean(lsr)
65 | 
66 | def crop(feats, preds, gt, h, w):
67 |     # Randomly subsample an h x w grid of spatial positions from feats/preds/gt.
68 |     H, W = feats.shape[-2:]
69 |     tmp_feats = []
70 |     tmp_preds = []
71 |     tmp_gt = []
72 |     N = feats.size(0)
73 |     for i in range(N):
74 |         inds_H = torch.randperm(H)[0:h]
75 |         inds_W = torch.randperm(W)[0:w]
76 |         tmp_feats += [feats[i, :, inds_H[:, None], inds_W]]
77 |         tmp_preds += [preds[i, :, inds_H[:, None], inds_W]]
78 |         tmp_gt += [gt[i, inds_H[:, None], inds_W]]
79 | 
80 |     new_feats = torch.stack(tmp_feats, dim=0)
81 |     new_gt = torch.stack(tmp_gt, dim=0)
82 |     new_preds = torch.stack(tmp_preds, dim=0)
83 |     probs = F.softmax(new_preds, dim=1)
84 |     return new_feats, probs, new_gt
--------------------------------------------------------------------------------
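A minimal sketch of how the learning-rate helpers in `utils.py` compose (the `backbone`/`head` module names below are illustrative, not from the repo): `set_param_groups` attaches an `lr_mult` to each top-level module, and each `adjust_learning_rate_*` scheduler then rescales every group's learning rate by that multiplier.

```python
import torch.nn as nn
import torch.optim as optim

from utils import set_param_groups, adjust_learning_rate_poly

# Toy two-part model standing in for backbone + decoder.
net = nn.Sequential()
net.add_module('backbone', nn.Linear(8, 8))
net.add_module('head', nn.Linear(8, 1))

# 0.1x learning rate on the backbone, 1.0x everywhere else.
params = set_param_groups(net, lr_mult_dict={'backbone': 0.1})
optimizer = optim.Adam(params, lr=1e-4)

max_iter = 100
for it in range(max_iter):
    # ... forward / backward / optimizer.step() would go here ...
    adjust_learning_rate_poly(1e-4, optimizer, it, max_iter, power=0.9)
```

This works because PyTorch optimizers preserve extra keys (here `lr_mult`) on each param group, so the schedulers can read them back on every call.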