├── .gitignore ├── README.md ├── coseg_test.py ├── data_processed.py ├── loss.py ├── main.py ├── model.py ├── pic ├── Internet.png ├── MSRC.png ├── PASCALVOC.png ├── iCoseg.png ├── network.png ├── seen.png └── unseen.png ├── requirement.txt ├── tools.py ├── train.py └── val.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI20] Deep Object Co-segmentation via Spatial-Semantic Network Modulation(Oral paper) 2 | ## Authors: [Kaihua Zhang](http://kaihuazhang.net/), [Jin Chen](https://github.com/cj4L), [Bo Liu](https://scholar.google.com/citations?user=2Fe93n8AAAAJ&hl=en), [Qingshan Liu](https://scholar.google.com/citations?user=2Pyf20IAAAAJ&hl=zh-CN) 3 | * PDF: [arXiv](https://arxiv.org/abs/1911.12950) or [AAAI20](https://aaai.org/ojs/index.php/AAAI/article/view/6977) 4 | 5 | ## Abstract 6 |  Object co-segmentation is to segment the shared objects in multiple relevant images, which has numerous applications in computer vision. This paper presents a spatial and semantic modulated deep network framework for object co-segmentation. A backbone network is adopted to extract multi-resolution image features. With the multi-resolution features of the relevant images as input, we design a spatial modulator to learn a mask for each image. The spatial modulator captures the correlations of image feature descriptors via unsupervised learning. The learned mask can roughly localize the shared foreground object while suppressing the background. For the semantic modulator, we model it as a supervised image classification task. 
We propose a hierarchical second-order pooling module to transform the image features for classification use. The outputs of the two modulators manipulate the multi-resolution features by a shift-and-scale operation so that the features focus on segmenting co-object regions. The proposed model is trained end-to-end without any intricate post-processing. Extensive experiments on four image co-segmentation benchmark datasets demonstrate the superior accuracy of the proposed method compared to state-of-the-art methods. 7 | 8 | ## Examples 9 |

![](https://github.com/cj4L/SSNM-Coseg/raw/master/pic/seen.png) 10 | 11 | ![](https://github.com/cj4L/SSNM-Coseg/raw/master/pic/unseen.png)

12 | 13 | ## Overview of our method 14 | ![](https://github.com/cj4L/SSNM-Coseg/raw/master/pic/network.png) 15 | 16 | ## Datasets 17 |  To compare fairly with the deep learning methods of recent years, we conduct extensive evaluations on four widely used benchmark datasets: a sub-set of [MSRC](https://www.microsoft.com/en-us/research/project/image-understanding/?from=http%3A%2F%2Fresearch.microsoft.com%2Fen-us%2Fprojects%2Fobjectclassrecognition%2F), [Internet](http://people.csail.mit.edu/mrub/ObjectDiscovery/), a sub-set of [iCoseg](http://chenlab.ece.cornell.edu/projects/touch-coseg/), and [PASCAL-VOC](http://host.robots.ox.ac.uk/pascal/VOC/). Among them: 18 | * The sub-set of MSRC includes 7 classes (bird, car, cat, cow, dog, plane, sheep), and each class contains 10 images. 19 | * Internet has 3 categories (airplane, car, horse). Each class has 100 images, including some with noisy labels. 20 | * The sub-set of iCoseg contains 8 categories, and each has a different number of images. 21 | * PASCAL-VOC is the most challenging dataset, with 1037 images of 20 categories selected from the PASCAL-VOC 2010 dataset. 22 | 23 | ## Results download 24 | * VGG16-backbone: [Google Drive](https://drive.google.com/file/d/14h2XdIB0GR1Zb_0X59T0URgbuogJpR5Z/view?usp=sharing). 25 | * HRNet-backbone: [Google Drive](https://drive.google.com/file/d/1r8piQHHVosecDJD6DmDZVriUzEfQxCeB/view?usp=sharing). 26 | 27 | ## Environment 28 | * Ubuntu 16.04, Nvidia RTX 2080Ti 29 | * Python 3 30 | * PyTorch>=1.0, TorchVision>=0.2.2 31 | * Numpy==1.16.2, Pillow, pycocotools 32 | 33 | ## Test 34 | * Download the datasets we have already processed from [Google Drive](https://drive.google.com/file/d/1bo5zE64bQwLUbCUGKDLcRjHei9FBmhfi/view?usp=sharing). 35 | * Download the VGG16-backbone pretrained model from [Google Drive](https://drive.google.com/file/d/1Vvir1CeuCNQY7GU_Ygh593U5I-KXZWff/view?usp=sharing). 36 | * Modify the path config in coseg_test.py and run it (a configuration sketch is given after the Notes below). 37 | 38 | ## Train 39 | * Get the COCO2017 dataset for training the whole network. 40 | * Get the test datasets above for the validation and test phases. 41 | * Download the VGG16 pretrained weights from [Google Drive](https://drive.google.com/file/d/1KIWIspVxLRwv8bzOuMn6lY8kStoedToV/view?usp=sharing). They are actually the official PyTorch model weights, except that the last several layers are removed. 42 | * Download dict.npy from [Google Drive](https://drive.google.com/file/d/1p15hGN3YwqWMRN4xx5mDIK04OhimpY2z/view?usp=sharing). 43 | * Modify the path config in main.py and run it. 44 | 45 | ### Notes 46 | * Following the suggestion of the AAAI20 reviewers, we do not release the HRNet-backbone trained model, to keep comparisons with other methods fair. 47 | * There are some slight differences in the 'Fusion' part of the model, but they have little impact. 48 | * There is a mistaken value in Table 2: our HRNet J-index of 82.5 for 'Car' on the Internet dataset should be corrected to 73.9. 49 | * Something is wrong with the BaiduPan share link; contact me if you need it.
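For reference, a minimal sketch of the path configuration at the bottom of coseg_test.py. The values below simply mirror the defaults already present in that file; the dataset and result directories are placeholders for wherever the processed datasets and the downloaded checkpoint live on your machine.

```python
# Hypothetical configuration for coseg_test.py; adjust the paths to your local layout.
gpu_id = 'cuda:0'                              # 'cpu' should also work, just much slower
device = torch.device(gpu_id)                  # test() reads this module-level variable
model_path = './models/SSNM-Coseg_best.pth'    # the downloaded VGG16-backbone checkpoint
val_datapath = ['./cosegdatasets/MSRC7']       # any subset of the four benchmark folders
save_root_path = ['./cosegresults/MSRC7']      # predicted masks are written here as .png
test(gpu_id, model_path, val_datapath, save_root_path, 5, 224, 'image')
```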
50 | 51 | #### Schedule 52 | - [x] Create github repo (2019.11.18) 53 | - [x] Release arXiv pdf (2019.12.2) 54 | - [x] Release AAAI20 pdf (2020.7.3) 55 | - [x] All results (2020.7.3) 56 | - [x] Test and Train code (2021.6.4) 57 | -------------------------------------------------------------------------------- /coseg_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | from torchvision import transforms 5 | from model import build_model 6 | 7 | def test(gpu_id, model_path, datapath, save_root_path, group_size, img_size, img_dir_name): 8 | net = build_model(device).to(device) 9 | net.load_state_dict(torch.load(model_path, map_location=gpu_id)) 10 | net.eval() 11 | 12 | img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), 13 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 14 | img_transform_gray = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), 15 | transforms.Normalize(mean=[0.449], std=[0.226])]) 16 | with torch.no_grad(): 17 | for p in range(len(datapath)): 18 | all_class = os.listdir(os.path.join(datapath[p], img_dir_name)) 19 | image_list, save_list = list(), list() 20 | for s in range(len(all_class)): 21 | image_path = os.listdir(os.path.join(datapath[p], img_dir_name, all_class[s])) 22 | image_list.append(list(map(lambda x: os.path.join(datapath[p], img_dir_name, all_class[s], x), image_path))) 23 | save_list.append(list(map(lambda x: os.path.join(save_root_path[p], all_class[s], x[:-4]+'.png'), image_path))) 24 | for i in range(len(image_list)): 25 | cur_class_all_image = image_list[i] 26 | cur_class_rgb = torch.zeros(len(cur_class_all_image), 3, img_size, img_size) 27 | for m in range(len(cur_class_all_image)): 28 | rgb_ = Image.open(cur_class_all_image[m]) 29 | if rgb_.mode == 'RGB': 30 | rgb_ = img_transform(rgb_) 31 | else: 32 | rgb_ = img_transform_gray(rgb_) 33 | cur_class_rgb[m, :, :, :] = rgb_ 34 | 35 | cur_class_mask = torch.zeros(len(cur_class_all_image), img_size, img_size) 36 | divided = len(cur_class_all_image) // group_size 37 | rested = len(cur_class_all_image) % group_size 38 | if divided != 0: 39 | for k in range(divided): 40 | group_rgb = cur_class_rgb[(k * group_size): ((k + 1) * group_size)] 41 | group_rgb = group_rgb.to(device) 42 | _, pred_mask = net(group_rgb) 43 | cur_class_mask[(k * group_size): ((k + 1) * group_size)] = pred_mask 44 | if rested != 0: 45 | group_rgb_tmp_l = cur_class_rgb[-rested:] 46 | group_rgb_tmp_r = cur_class_rgb[:group_size - rested] 47 | group_rgb = torch.cat((group_rgb_tmp_l, group_rgb_tmp_r), dim=0) 48 | group_rgb = group_rgb.to(device) 49 | _, pred_mask = net(group_rgb) 50 | cur_class_mask[(divided * group_size):] = pred_mask[:rested] 51 | 52 | 53 | class_save_path = os.path.join(save_root_path[p], all_class[i]) 54 | if not os.path.exists(class_save_path): 55 | os.makedirs(class_save_path) 56 | 57 | for j in range(len(cur_class_all_image)): 58 | exact_save_path = save_list[i][j] 59 | result = cur_class_mask[j, :, :].numpy() 60 | result = Image.fromarray(result * 255) 61 | w, h = Image.open(image_list[i][j]).size 62 | result = result.resize((w, h), Image.BILINEAR) 63 | result.convert('L').save(exact_save_path) 64 | 65 | print('done') 66 | 67 | 68 | 69 | 70 | if __name__ == '__main__': 71 | gpu_id = 'cuda:0' 72 | device = torch.device(gpu_id) 73 | model_path = './models/SSNM-Coseg_best.pth' 74 | 75 | val_datapath = 
['./cosegdatasets/iCoseg8', 76 | './cosegdatasets/Internet_Datasets300', 77 | './cosegdatasets/MSRC7', 78 | './cosegdatasets/PASCAL_VOC'] 79 | 80 | save_root_path = ['./cosegresults/iCoseg8', 81 | './cosegresults/Internet_Datasets300', 82 | './cosegresults/MSRC7', 83 | './cosegresults/PASCAL_VOC'] 84 | 85 | test(gpu_id, model_path, val_datapath, save_root_path, 5, 224, 'image') 86 | -------------------------------------------------------------------------------- /data_processed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import copy 4 | import random 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | 9 | 10 | def filt_small_instance(coco_item, pixthreshold=4000,imgNthreshold=5): 11 | list_dict = coco_item.catToImgs 12 | for catid in list_dict: 13 | list_dict[catid] = list(set( list_dict[catid] )) 14 | new_dict = copy.deepcopy(list_dict) 15 | for catid in list_dict: 16 | imgids = list_dict[catid] 17 | for n in range(len(imgids)): 18 | imgid = imgids[n] 19 | anns = coco_item.imgToAnns[imgid] 20 | has_large_instance = False 21 | for ann in anns: 22 | if (ann['category_id'] == catid) and (ann['iscrowd'] == 0) and (ann['area'] > pixthreshold): 23 | has_large_instance = True 24 | if has_large_instance is False: 25 | new_dict[catid].remove(imgid) 26 | imgN = len(new_dict[catid]) 27 | if imgN len(list_dict): 56 | remainN = batch_size - len(list_dict) 57 | batch_catid = random.sample(list_dict.keys(), remainN) + random.sample(list_dict, len(list_dict)) 58 | else: 59 | batch_catid = random.sample(list_dict.keys(), batch_size) 60 | group_n = 0 61 | img_n = 0 62 | for catid in batch_catid: 63 | imgids = random.sample(list_dict[catid], group_size) 64 | co_catids = [] 65 | anns = coco_item.imgToAnns[imgids[0]] 66 | for ann in anns: 67 | if (ann['iscrowd'] == 0) and (ann['area'] > 4000): 68 | co_catids.append(ann['category_id']) 69 | co_catids_backup = copy.deepcopy(co_catids) 70 | for imgid in imgids[1:]: 71 | img_catids = [] 72 | anns = coco_item.imgToAnns[imgid] 73 | for ann in anns: 74 | if (ann['iscrowd'] == 0) and (ann['area'] > 4000): 75 | img_catids.append(ann['category_id']) 76 | for co_catid in co_catids_backup: 77 | if co_catid not in img_catids: 78 | co_catids.remove(co_catid) 79 | co_catids_backup = copy.deepcopy(co_catids) 80 | for co_catid in co_catids: 81 | cls_labels[group_n, catid2label[co_catid]] = 1 82 | for imgid in imgids: 83 | path = datapath + '%012d.jpg'%imgid 84 | img = Image.open(path) 85 | if img.mode == 'RGB': 86 | img = img_transform(img) 87 | else: 88 | img = img_transform_gray(img) 89 | anns = coco_item.imgToAnns[imgid] 90 | mask = None 91 | for ann in anns: 92 | if ann['category_id'] in co_catids: 93 | if mask is None: 94 | mask = coco_item.annToMask(ann) 95 | else: 96 | mask = mask + coco_item.annToMask(ann) 97 | mask[mask > 0] = 255 98 | mask = Image.fromarray(mask) 99 | mask = gt_transform(mask) 100 | mask[mask > 0.5] = 1 101 | mask[mask <= 0.5] = 0 102 | rgb[img_n,:,:,:] = copy.deepcopy(img) 103 | mask_labels[img_n,:,:] = copy.deepcopy(mask) 104 | img_n = img_n + 1 105 | group_n = group_n + 1 106 | idx = mask_labels[:, :, :] > 1 107 | mask_labels[idx] = 1 108 | q.put([rgb, cls_labels, mask_labels]) 109 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | 
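# Overview of the criteria defined below (descriptive comments added for clarity):
#   Weighed_Bce_Loss -- binary cross-entropy on the predicted co-object masks, with per-pixel
#                       weights derived from the foreground/background ratio of the ground truth.
#   Cls_Loss         -- multi-label binary cross-entropy on the group-level class prediction.
#   S_Loss           -- smooth L1 between the predicted masks and the ground truth.
#   Loss             -- sum of the three terms (all weights are set to 1 below).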
5 | 6 | class Weighed_Bce_Loss(nn.Module): 7 | def __init__(self): 8 | super(Weighed_Bce_Loss, self).__init__() 9 | 10 | def forward(self, x, label): 11 | x = x.view(-1, 1, x.shape[1], x.shape[2]) 12 | label = label.view(-1, 1, label.shape[1], label.shape[2]) 13 | label_t = (label == 1).float() 14 | label_f = (label == 0).float() 15 | p = torch.sum(label_t) / (torch.sum(label_t) + torch.sum(label_f)) 16 | w = torch.zeros_like(label) 17 | w[label == 1] = p 18 | w[label == 0] = 1 - p 19 | loss = F.binary_cross_entropy(x, label, weight=w) 20 | return loss 21 | 22 | 23 | class Cls_Loss(nn.Module): 24 | def __init__(self): 25 | super(Cls_Loss, self).__init__() 26 | 27 | def forward(self, x, label): 28 | loss = F.binary_cross_entropy(x, label) 29 | return loss 30 | 31 | class S_Loss(nn.Module): 32 | def __init__(self): 33 | super(S_Loss, self).__init__() 34 | 35 | def forward(self, x, label): 36 | loss = F.smooth_l1_loss(x, label) 37 | return loss 38 | 39 | class Loss(nn.Module): 40 | def __init__(self): 41 | super(Loss, self).__init__() 42 | self.loss_wbce = Weighed_Bce_Loss() 43 | self.loss_cls = Cls_Loss() 44 | self.loss_s = S_Loss() 45 | self.w_wbce = 1 46 | self.w_cls = 1 47 | self.w_smooth = 1 48 | 49 | def forward(self, x, label, x_cls, label_cls): 50 | m_loss = self.loss_wbce(x, label) * self.w_wbce 51 | c_loss = self.loss_cls(x_cls, label_cls) * self.w_cls 52 | s_loss = self.loss_s(x, label) * self.w_smooth 53 | loss = m_loss + c_loss + s_loss 54 | 55 | return loss, m_loss, c_loss, s_loss 56 | 57 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pycocotools import coco 4 | import queue 5 | import threading 6 | from model import build_model, weights_init 7 | from tools import custom_print 8 | from data_processed import train_data_producer 9 | from train import train 10 | import time 11 | torch.backends.cudnn.benchmark = True 12 | 13 | 14 | if __name__ == '__main__': 15 | # train_val_config 16 | annotation_file = '/home/chenjin/dataset/COCO/COCO2017/annotations/instances_train2017.json' 17 | coco_item = coco.COCO(annotation_file=annotation_file) 18 | 19 | train_datapath = '/home/chenjin/dataset/COCO/COCO2017/train2017/' 20 | 21 | val_datapath = ['./cosegdatasets/iCoseg8', 22 | './cosegdatasets/MSRC7', 23 | './cosegdatasets/Internet_Datasets300', 24 | './cosegdatasets/PASCAL_VOC'] 25 | 26 | vgg16_path = './weights/vgg16_bn_feat.pth' 27 | npy = './utils/new_cat2imgid_dict4000.npy' 28 | 29 | # project config 30 | project_name = 'SSNM-Coseg' 31 | device = torch.device('cuda:0') 32 | img_size = 224 33 | lr = 1e-5 34 | lr_de = 20000 35 | epochs = 100000 36 | batch_size = 4 37 | group_size = 5 38 | log_interval = 100 39 | val_interval = 1000 40 | 41 | # create log dir 42 | log_root = './logs' 43 | if not os.path.exists(log_root): 44 | os.makedirs(log_root) 45 | 46 | # create log txt 47 | log_txt_file = os.path.join(log_root, project_name + '_log.txt') 48 | custom_print(project_name, log_txt_file, 'w') 49 | 50 | # create model save dir 51 | models_root = './models' 52 | if not os.path.exists(models_root): 53 | os.makedirs(models_root) 54 | 55 | models_train_last = os.path.join(models_root, project_name + '_last.pth') 56 | models_train_best = os.path.join(models_root, project_name + '_best.pth') 57 | 58 | net = build_model(device).to(device) 59 | net.train() 60 | net.apply(weights_init) 61 | 
net.base.load_state_dict(torch.load(vgg16_path)) 62 | 63 | # continute load checkpoint 64 | # net.load_state_dict(torch.load('./models/SSNM-Coseg_last.pth', map_location='cuda:0')) 65 | 66 | q = queue.Queue(maxsize=40) 67 | 68 | p1 = threading.Thread(target=train_data_producer, args=(coco_item, train_datapath, npy, q, batch_size, group_size, img_size)) 69 | p2 = threading.Thread(target=train_data_producer, args=(coco_item, train_datapath, npy, q, batch_size, group_size, img_size)) 70 | p3 = threading.Thread(target=train_data_producer, args=(coco_item, train_datapath, npy, q, batch_size, group_size, img_size)) 71 | p1.start() 72 | p2.start() 73 | p3.start() 74 | time.sleep(2) 75 | 76 | train(net, device, q, log_txt_file, val_datapath, models_train_best, models_train_last, lr, lr_de, epochs, log_interval, val_interval) 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.optim import Adam 6 | 7 | # vgg choice 8 | base = {'vgg': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']} 9 | 10 | # vgg16 11 | def vgg(cfg, i=3, batch_norm=True): 12 | layers = [] 13 | in_channels = i 14 | for v in cfg: 15 | if v == 'M': 16 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 17 | else: 18 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 19 | if batch_norm: 20 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 21 | else: 22 | layers += [conv2d, nn.ReLU(inplace=True)] 23 | in_channels = v 24 | return layers 25 | 26 | 27 | def hsp(in_channel, out_channel): 28 | layers = nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, 1), 29 | nn.ReLU()) 30 | return layers 31 | 32 | def cls_modulation_branch(in_channel, hiden_channel): 33 | layers = nn.Sequential(nn.Linear(in_channel, hiden_channel), 34 | nn.ReLU()) 35 | return layers 36 | 37 | def cls_branch(hiden_channel, class_num): 38 | layers = nn.Sequential(nn.Linear(hiden_channel, class_num), 39 | nn.Sigmoid()) 40 | return layers 41 | 42 | def concat_r(): 43 | layers = [] 44 | layers += [nn.Conv2d(512, 512, 1, 1)] 45 | layers += [nn.ReLU()] 46 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 47 | layers += [nn.ReLU()] 48 | layers += [nn.ConvTranspose2d(512, 512, 4, 2, 1)] 49 | return layers 50 | 51 | def concat_1(): 52 | layers = [] 53 | layers += [nn.Conv2d(512, 512, 1, 1)] 54 | layers += [nn.ReLU()] 55 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 56 | layers += [nn.ReLU()] 57 | return layers 58 | 59 | def mask_branch(): 60 | layers = [] 61 | layers += [nn.Conv2d(512, 2, 3, 1, 1)] 62 | layers += [nn.ConvTranspose2d(2, 2, 8, 4, 2)] 63 | layers += [nn.Softmax2d()] 64 | return layers 65 | 66 | def incr_channel(): 67 | layers = [] 68 | layers += [nn.Conv2d(128, 512, 3, 1, 1)] 69 | layers += [nn.Conv2d(256, 512, 3, 1, 1)] 70 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 71 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 72 | return layers 73 | 74 | def incr_channel2(): 75 | layers = [] 76 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 77 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 78 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 79 | layers += [nn.Conv2d(512, 512, 3, 1, 1)] 80 | layers += [nn.ReLU()] 81 | return layers 82 | 83 | def norm(x, dim): 84 | squared_norm = (x ** 2).sum(dim=dim, keepdim=True) 85 | normed = x / torch.sqrt(squared_norm) 86 | return normed 87 | 88 
| def fuse_hsp(x, p): 89 | group_size = 5 90 | t = torch.zeros(group_size, x.size(1)) 91 | for i in range(x.size(0)): 92 | tmp = x[i, :] 93 | if i == 0: 94 | nx = tmp.expand_as(t) 95 | else: 96 | nx = torch.cat(([nx, tmp.expand_as(t)]), dim=0) 97 | nx = nx.view(x.size(0)*group_size, x.size(1), 1, 1) 98 | y = nx.expand_as(p) 99 | return y 100 | 101 | 102 | class Model(nn.Module): 103 | def __init__(self, device, base, incr_channel, incr_channel2, hsp1, hsp2, cls_m, cls, concat_r, concat_1, mask_branch): 104 | super(Model, self).__init__() 105 | self.base = nn.ModuleList(base) 106 | self.sp1 = hsp1 107 | self.sp2 = hsp2 108 | self.cls_m = cls_m 109 | self.cls = cls 110 | self.incr_channel1 = nn.ModuleList(incr_channel) 111 | self.incr_channel2 = nn.ModuleList(incr_channel2) 112 | self.concat4 = nn.ModuleList(concat_r) 113 | self.concat3 = nn.ModuleList(concat_r) 114 | self.concat2 = nn.ModuleList(concat_r) 115 | self.concat1 = nn.ModuleList(concat_1) 116 | self.mask = nn.ModuleList(mask_branch) 117 | self.extract = [13, 23, 33, 43] 118 | self.device = device 119 | self.group_size = 5 120 | 121 | def forward(self, x): 122 | # backbone, p is the pool2, 3, 4, 5 123 | p = list() 124 | for k in range(len(self.base)): 125 | x = self.base[k](x) 126 | if k in self.extract: 127 | p.append(x) 128 | 129 | # increase the channel 130 | newp = list() 131 | for k in range(len(p)): 132 | np = self.incr_channel1[k](p[k]) 133 | np = self.incr_channel2[k](np) 134 | newp.append(self.incr_channel2[4](np)) 135 | 136 | # spatial modulator 137 | spa_mask = spatial_optimize(newp[3], self.group_size).to(self.device) 138 | 139 | # hsp 140 | x = newp[3] 141 | x = self.sp1(x) 142 | x = x.view(-1, x.size(1), x.size(2) * x.size(3)) 143 | x = torch.bmm(x, x.transpose(1, 2)) 144 | x = x.view(-1, x.size(1) * x.size(2)) 145 | x = x.view(x.size(0) // 5, x.size(1), -1, 1) 146 | x = self.sp2(x) 147 | x = x.view(-1, x.size(1), x.size(2) * x.size(3)) 148 | x = torch.bmm(x, x.transpose(1, 2)) 149 | x = x.view(-1, x.size(1) * x.size(2)) 150 | 151 | #cls pred 152 | cls_modulated_vector = self.cls_m(x) 153 | cls_pred = self.cls(cls_modulated_vector) 154 | 155 | #semantic and spatial modulator 156 | g1 = fuse_hsp(cls_modulated_vector, newp[0]) 157 | g2 = fuse_hsp(cls_modulated_vector, newp[1]) 158 | g3 = fuse_hsp(cls_modulated_vector, newp[2]) 159 | g4 = fuse_hsp(cls_modulated_vector, newp[3]) 160 | 161 | spa_1 = F.interpolate(spa_mask, size=[g1.size(2), g1.size(3)], mode='bilinear') 162 | spa_1 = spa_1.expand_as(g1) 163 | spa_2 = F.interpolate(spa_mask, size=[g2.size(2), g2.size(3)], mode='bilinear') 164 | spa_2 = spa_2.expand_as(g2) 165 | spa_3 = F.interpolate(spa_mask, size=[g3.size(2), g3.size(3)], mode='bilinear') 166 | spa_3 = spa_3.expand_as(g3) 167 | spa_4 = F.interpolate(spa_mask, size=[g4.size(2), g4.size(3)], mode='bilinear') 168 | spa_4 = spa_4.expand_as(g4) 169 | 170 | y4 = newp[3] * g4 + spa_4 171 | for k in range(len(self.concat4)): 172 | y4 = self.concat4[k](y4) 173 | y3 = newp[2] * g3 + spa_3 174 | 175 | for k in range(len(self.concat3)): 176 | y3 = self.concat3[k](y3) 177 | if k == 1: 178 | y3 = y3 + y4 179 | y2 = newp[1] * g2 + spa_2 180 | 181 | for k in range(len(self.concat2)): 182 | y2 = self.concat2[k](y2) 183 | if k == 1: 184 | y2 = y2 + y3 185 | y1 = newp[0] * g1 + spa_1 186 | 187 | for k in range(len(self.concat1)): 188 | y1 = self.concat1[k](y1) 189 | if k == 1: 190 | y1 = y1 + y2 191 | y = y1 192 | 193 | # decoder 194 | for k in range(len(self.mask)): 195 | y = self.mask[k](y) 196 | mask_pred = y[:, 0, :, :] 
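        # mask_pred above is channel 0 of the two-channel Softmax2d output, i.e. the per-pixel
        # co-object probability, already upsampled back to the input resolution (224x224 for the
        # default img_size).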
197 | 198 | return cls_pred, mask_pred 199 | 200 | 201 | 202 | # build the whole network 203 | def build_model(device): 204 | return Model(device, 205 | vgg(base['vgg']), 206 | incr_channel(), 207 | incr_channel2(), 208 | hsp(512, 64), 209 | hsp(64**2, 32), 210 | cls_modulation_branch(32**2, 512), 211 | cls_branch(512, 78), 212 | concat_r(), 213 | concat_1(), 214 | mask_branch()) 215 | 216 | # weight init 217 | def xavier(param): 218 | init.xavier_uniform_(param) 219 | 220 | def weights_init(m): 221 | if isinstance(m, nn.Conv2d): 222 | xavier(m.weight.data) 223 | elif isinstance(m, nn.BatchNorm2d): 224 | init.constant_(m.weight, 1) 225 | init.constant_(m.bias, 0) 226 | 227 | 228 | def spatial_optimize(fmap, group_size): 229 | fmap_split = torch.split(fmap, group_size, dim=0) 230 | for i in range(len(fmap_split)): 231 | cur_fmap = fmap_split[i] 232 | with torch.no_grad(): 233 | spatial_x = cur_fmap.permute(0, 2, 3, 1).contiguous().view(-1, cur_fmap.size(1)).transpose(1, 0) 234 | spatial_x = norm(spatial_x, dim=0) 235 | spatial_x_t = spatial_x.transpose(1, 0) 236 | G = spatial_x_t @ spatial_x - 1 237 | G = G.detach().cpu() 238 | 239 | with torch.enable_grad(): 240 | spatial_s = nn.Parameter(torch.sqrt(245 * torch.ones((245, 1))) / 245, requires_grad=True) 241 | spatial_s_t = spatial_s.transpose(1, 0) 242 | spatial_s_optimizer = Adam([spatial_s], 0.01) 243 | 244 | for iter in range(200): 245 | f_spa_loss = -1 * torch.sum(spatial_s_t @ G @ spatial_s) 246 | spatial_s_d = torch.sqrt(torch.sum(spatial_s ** 2)) 247 | if spatial_s_d >= 1: 248 | d_loss = -1 * torch.log(2 - spatial_s_d) 249 | else: 250 | d_loss = -1 * torch.log(spatial_s_d) 251 | 252 | all_loss = 50 * d_loss + f_spa_loss 253 | 254 | spatial_s_optimizer.zero_grad() 255 | all_loss.backward() 256 | spatial_s_optimizer.step() 257 | 258 | result_map = spatial_s.data.view(5, 1, 7, 7) 259 | 260 | if i == 0: 261 | spa_mask = result_map 262 | else: 263 | spa_mask = torch.cat(([spa_mask, result_map]), dim=0) 264 | 265 | return spa_mask 266 | 267 | 268 | 269 | -------------------------------------------------------------------------------- /pic/Internet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/Internet.png -------------------------------------------------------------------------------- /pic/MSRC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/MSRC.png -------------------------------------------------------------------------------- /pic/PASCALVOC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/PASCALVOC.png -------------------------------------------------------------------------------- /pic/iCoseg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/iCoseg.png -------------------------------------------------------------------------------- /pic/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/network.png -------------------------------------------------------------------------------- 
/pic/seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/seen.png -------------------------------------------------------------------------------- /pic/unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/unseen.png -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | ==1.16.2 3 | torch>=1.0 4 | 5 | torchvision==0.2.2 6 | 7 | pycocotools 8 | 9 | pillow -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def custom_print(context, log_file, mode): 4 | #custom print and log out function 5 | if mode == 'w': 6 | fp = open(log_file, mode) 7 | fp.write(context + '\n') 8 | fp.close() 9 | elif mode == 'a+': 10 | print(context) 11 | fp = open(log_file, mode) 12 | print(context, file=fp) 13 | fp.close() 14 | else: 15 | raise Exception('other file operation is unimplemented !') 16 | 17 | 18 | def generate_binary_map(pred, type): 19 | if type == '2mean': 20 | threshold = np.mean(pred) * 2 21 | if threshold > 0.8: 22 | threshold = 0.8 23 | binary_map = pred > threshold 24 | return binary_map.astype(np.float32) 25 | 26 | if type == 'mean+std': 27 | threshold = np.mean(pred) + np.std(pred) 28 | if threshold > 0.8: 29 | threshold = 0.8 30 | binary_map = pred > threshold 31 | return binary_map.astype(np.float32) 32 | 33 | 34 | 35 | def calc_precision_and_jaccard(pred, gt): 36 | bin_pred = generate_binary_map(pred, 'mean+std') 37 | tp = (bin_pred == gt).sum() 38 | precision = tp / (pred.size) 39 | 40 | i = (bin_pred * gt).sum() 41 | u = bin_pred.sum() + gt.sum() - i 42 | jaccard = i / (u + 1e-10) 43 | 44 | return precision, jaccard -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from loss import Loss 2 | from torch.optim import Adam 3 | from tools import custom_print 4 | import datetime 5 | import torch 6 | from val import validation 7 | 8 | def train(net, device, q, log_txt_file, val_datapath, models_train_best, models_train_last, lr=1e-4, lr_de_epoch=25000, 9 | epochs=100000, log_interval=100, val_interval=1000): 10 | optimizer = Adam(net.parameters(), lr, weight_decay=1e-6) 11 | loss = Loss().to(device) 12 | best_p, best_j = 0, 1 13 | ave_loss, ave_m_loss, ave_c_loss, ave_s_loss = 0, 0, 0, 0 14 | for epoch in range(1, epochs+1): 15 | img, cls_gt, mask_gt = q.get() 16 | net.zero_grad() 17 | img, cls_gt, mask_gt = img.to(device), cls_gt.to(device), mask_gt.to(device) 18 | pred_cls, pred_mask = net(img) 19 | all_loss, m_loss, c_loss, s_loss = loss(pred_mask, mask_gt, pred_cls, cls_gt) 20 | all_loss.backward() 21 | epoch_loss = all_loss.item() 22 | m_l = m_loss.item() 23 | c_l = c_loss.item() 24 | s_l = s_loss.item() 25 | ave_loss += epoch_loss 26 | ave_m_loss += m_l 27 | ave_c_loss += c_l 28 | ave_s_loss += s_l 29 | optimizer.step() 30 | 31 | if epoch % log_interval == 0: 32 | ave_loss = ave_loss / log_interval 33 | ave_m_loss = ave_m_loss / log_interval 34 | ave_c_loss = ave_c_loss / log_interval 35 | ave_s_loss = ave_s_loss / 
log_interval 36 | custom_print(datetime.datetime.now().strftime('%F %T') + 37 | ' lr: %e, epoch: [%d/%d], all_loss: [%.4f], m_loss: [%.4f], c_loss: [%.4f], s_loss: [%.4f]' % 38 | (lr, epoch, epochs, ave_loss, ave_m_loss, ave_c_loss, ave_s_loss), log_txt_file, 'a+') 39 | ave_loss, ave_m_loss, ave_c_loss, ave_s_loss = 0, 0, 0, 0 40 | 41 | if epoch % val_interval == 0: 42 | net.eval() 43 | with torch.no_grad(): 44 | custom_print(datetime.datetime.now().strftime('%F %T') + 45 | ' now is evaluating the coseg dataset', log_txt_file, 'a+') 46 | ave_p, ave_j = validation(net, val_datapath, device, group_size=5, img_size=224, img_dir_name='image', gt_dir_name='groundtruth', 47 | img_ext=['.jpg', '.jpg', '.jpg', '.jpg'], gt_ext=['.png', '.bmp', '.jpg', '.png']) 48 | if ave_p[3] > best_p: 49 | # follow yourself save condition 50 | best_p = ave_p[3] 51 | best_j = ave_j[0] 52 | torch.save(net.state_dict(), models_train_best) 53 | torch.save(net.state_dict(), models_train_last) 54 | custom_print('-' * 100, log_txt_file, 'a+') 55 | custom_print(datetime.datetime.now().strftime('%F %T') + ' iCoseg8 p: [%.4f], j: [%.4f]' % 56 | (ave_p[0], ave_j[0]), log_txt_file, 'a+') 57 | custom_print(datetime.datetime.now().strftime('%F %T') + ' MSRC7 p: [%.4f], j: [%.4f]' % 58 | (ave_p[1], ave_j[1]), log_txt_file, 'a+') 59 | custom_print(datetime.datetime.now().strftime('%F %T') + ' Int_300 p: [%.4f], j: [%.4f]' % 60 | (ave_p[2], ave_j[2]), log_txt_file, 'a+') 61 | custom_print(datetime.datetime.now().strftime('%F %T') + ' PAS_VOC p: [%.4f], j: [%.4f]' % 62 | (ave_p[3], ave_j[3]), log_txt_file, 'a+') 63 | custom_print('-' * 100, log_txt_file, 'a+') 64 | net.train() 65 | 66 | if epoch % lr_de_epoch == 0: 67 | optimizer = Adam(net.parameters(), lr/2, weight_decay=1e-6) 68 | lr = lr / 2 -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import torch 3 | import os 4 | from PIL import Image 5 | from tools import * 6 | 7 | 8 | def validation(net, datapath, device, group_size=5, img_size=224, img_dir_name='image', gt_dir_name='groundtruth', 9 | img_ext=['.jpg', '.jpg', '.jpg', '.jpg'], gt_ext=['.png', '.bmp', '.jpg', '.png']): 10 | img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), 11 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 12 | gt_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor()]) 13 | img_transform_gray = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(), 14 | transforms.Normalize(mean=[0.449], std=[0.226])]) 15 | net.eval() 16 | with torch.no_grad(): 17 | ave_p, ave_j = [], [] 18 | for p in range(len(datapath)): 19 | all_p, all_j = [], [] 20 | all_class = os.listdir(os.path.join(datapath[p], img_dir_name)) 21 | image_list, gt_list = list(), list() 22 | for s in range(len(all_class)): 23 | image_path = os.listdir(os.path.join(datapath[p], img_dir_name, all_class[s])) 24 | image_list.append(list(map(lambda x: os.path.join(datapath[p], img_dir_name, all_class[s], x), image_path))) 25 | gt_list.append(list(map(lambda x: os.path.join(datapath[p], gt_dir_name, all_class[s], x.replace(img_ext[p], gt_ext[p])), image_path))) 26 | for i in range(len(image_list)): 27 | cur_class_all_image = image_list[i] 28 | cur_class_all_gt = gt_list[i] 29 | 30 | cur_class_gt = torch.zeros(len(cur_class_all_gt), img_size, 
img_size) 31 | for g in range(len(cur_class_all_gt)): 32 | gt_ = Image.open(cur_class_all_gt[g]).convert('L') 33 | gt_ = gt_transform(gt_) 34 | gt_[gt_ > 0.5] = 1 35 | gt_[gt_ <= 0.5] = 0 36 | cur_class_gt[g, :, :] = gt_ 37 | 38 | cur_class_rgb = torch.zeros(len(cur_class_all_image), 3, img_size, img_size) 39 | for m in range(len(cur_class_all_image)): 40 | rgb_ = Image.open(cur_class_all_image[m]) 41 | if rgb_.mode == 'RGB': 42 | rgb_ = img_transform(rgb_) 43 | else: 44 | rgb_ = img_transform_gray(rgb_) 45 | cur_class_rgb[m, :, :, :] = rgb_ 46 | 47 | cur_class_mask = torch.zeros(len(cur_class_all_image), img_size, img_size) 48 | divided = len(cur_class_all_image) // group_size 49 | rested = len(cur_class_all_image) % group_size 50 | if divided != 0: 51 | for k in range(divided): 52 | group_rgb = cur_class_rgb[(k * group_size): ((k + 1) * group_size)] 53 | group_rgb = group_rgb.to(device) 54 | _, pred_mask = net(group_rgb) 55 | cur_class_mask[(k * group_size): ((k + 1) * group_size)] = pred_mask 56 | if rested != 0: 57 | group_rgb_tmp_l = cur_class_rgb[-rested:] 58 | group_rgb_tmp_r = cur_class_rgb[:group_size-rested] 59 | group_rgb = torch.cat((group_rgb_tmp_l, group_rgb_tmp_r), dim=0) 60 | group_rgb = group_rgb.to(device) 61 | _, pred_mask = net(group_rgb) 62 | cur_class_mask[(divided * group_size): ] = pred_mask[:rested] 63 | 64 | for q in range(cur_class_mask.size(0)): 65 | single_p, single_j = calc_precision_and_jaccard(cur_class_mask[q, :, :].numpy(), cur_class_gt[q, :, :].numpy()) 66 | all_p.append(single_p) 67 | all_j.append(single_j) 68 | 69 | dataset_p = np.mean(all_p) 70 | dataset_j = np.mean(all_j) 71 | 72 | ave_p.append(dataset_p) 73 | ave_j.append(dataset_j) 74 | 75 | return ave_p, ave_j --------------------------------------------------------------------------------
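A minimal smoke-test sketch of how one group of images flows through the network, assuming the default group size of 5 and image size of 224 used throughout this repository. The random tensor is only a stand-in for a real group of relevant images, so the outputs are meaningful for checking shapes only.

# Hypothetical smoke test of the model on one group (shapes follow model.py / coseg_test.py).
import torch
from model import build_model

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = build_model(device).to(device)
net.eval()

with torch.no_grad():
    group = torch.randn(5, 3, 224, 224, device=device)  # one group of 5 relevant images
    cls_pred, mask_pred = net(group)

print(cls_pred.shape)   # torch.Size([1, 78])       -- one multi-label class vector per group
print(mask_pred.shape)  # torch.Size([5, 224, 224]) -- one co-object probability map per image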