├── .gitignore
├── README.md
├── coseg_test.py
├── data_processed.py
├── loss.py
├── main.py
├── model.py
├── pic
│   ├── Internet.png
│   ├── MSRC.png
│   ├── PASCALVOC.png
│   ├── iCoseg.png
│   ├── network.png
│   ├── seen.png
│   └── unseen.png
├── requirement.txt
├── tools.py
├── train.py
└── val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AAAI20] Deep Object Co-segmentation via Spatial-Semantic Network Modulation (Oral paper)
2 | ## Authors: [Kaihua Zhang](http://kaihuazhang.net/), [Jin Chen](https://github.com/cj4L), [Bo Liu](https://scholar.google.com/citations?user=2Fe93n8AAAAJ&hl=en), [Qingshan Liu](https://scholar.google.com/citations?user=2Pyf20IAAAAJ&hl=zh-CN)
3 | * PDF: [arXiv](https://arxiv.org/abs/1911.12950) or [AAAI20](https://aaai.org/ojs/index.php/AAAI/article/view/6977)
4 |
5 | ## Abstract
6 | Object co-segmentation aims to segment the objects shared across multiple relevant images, and it has numerous applications in computer vision. This paper presents a spatial- and semantic-modulated deep network framework for object co-segmentation. A backbone network is adopted to extract multi-resolution image features. With the multi-resolution features of the relevant images as input, we design a spatial modulator to learn a mask for each image. The spatial modulator captures the correlations of image feature descriptors via unsupervised learning. The learned mask can roughly localize the shared foreground object while suppressing the background. We model the semantic modulator as a supervised image classification task and propose a hierarchical second-order pooling module to transform the image features for classification use. The outputs of the two modulators manipulate the multi-resolution features by a shift-and-scale operation so that the features focus on segmenting co-object regions. The proposed model is trained end-to-end without any intricate post-processing. Extensive experiments on four image co-segmentation benchmark datasets demonstrate the superior accuracy of the proposed method compared to state-of-the-art methods.
7 |
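Below is a minimal, illustrative sketch of the shift-and-scale modulation and the second-order pooling step described above, mirroring the `newp[k] * g_k + spa_k` and `torch.bmm(x, x.transpose(1, 2))` steps in model.py. The function and tensor names here are placeholders, not part of the released code.

```python
import torch
import torch.nn.functional as F

def shift_and_scale(feature, semantic_vec, spatial_mask):
    # feature:      (N, C, H, W) multi-resolution backbone features of an image group
    # semantic_vec: (N, C) output of the semantic modulator (classification branch)
    # spatial_mask: (N, 1, h, w) mask learned by the unsupervised spatial modulator
    scale = semantic_vec.view(semantic_vec.size(0), semantic_vec.size(1), 1, 1).expand_as(feature)
    shift = F.interpolate(spatial_mask, size=feature.shape[2:], mode='bilinear').expand_as(feature)
    return feature * scale + shift  # features are scaled semantically and shifted spatially

def second_order_pool(x):
    # (N, C, H, W) -> (N, C*C) second-order statistics, as used by the hsp branch for classification
    v = x.flatten(2)                                    # (N, C, H*W)
    return torch.bmm(v, v.transpose(1, 2)).flatten(1)   # (N, C*C)

# toy usage with a group of 5 images
f, g, s = torch.randn(5, 512, 28, 28), torch.rand(5, 512), torch.rand(5, 1, 7, 7)
print(shift_and_scale(f, g, s).shape)  # torch.Size([5, 512, 28, 28])
print(second_order_pool(f).shape)      # torch.Size([5, 262144])
```
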
8 | ## Examples
9 |
10 |
11 |
12 |
13 | ## Overview of our method
14 | ![network](./pic/network.png)
15 |
16 | ## Datasets
17 | To compare fairly with the deep learning methods of recent years, we conduct extensive evaluations on four widely used benchmark datasets: a sub-set of [MSRC](https://www.microsoft.com/en-us/research/project/image-understanding/?from=http%3A%2F%2Fresearch.microsoft.com%2Fen-us%2Fprojects%2Fobjectclassrecognition%2F), [Internet](http://people.csail.mit.edu/mrub/ObjectDiscovery/), a sub-set of [iCoseg](http://chenlab.ece.cornell.edu/projects/touch-coseg/), and [PASCAL-VOC](http://host.robots.ox.ac.uk/pascal/VOC/). Among them:
18 | * The sub-set of MSRC includes 7 classes (bird, car, cat, cow, dog, plane, sheep), and each class contains 10 images.
19 | * The Internet dataset has 3 categories: airplane, car, and horse. Each class has 100 images, including some with noisy labels.
20 | * The sub-set of iCoseg contains 8 categories, and each has a different number of images.
21 | * PASCAL-VOC is the most challenging dataset, with 1037 images of 20 categories selected from the PASCAL-VOC 2010 dataset.
22 |
23 | ## Results download
24 | * VGG16-backbone: [Google Drive](https://drive.google.com/file/d/14h2XdIB0GR1Zb_0X59T0URgbuogJpR5Z/view?usp=sharing).
25 | * HRNet-backbone: [Google Drive](https://drive.google.com/file/d/1r8piQHHVosecDJD6DmDZVriUzEfQxCeB/view?usp=sharing).
26 |
27 | ## Environment
28 | * Ubuntu 16.04, Nvidia RTX 2080Ti
29 | * Python 3
30 | * PyTorch>=1.0, TorchVision>=0.2.2
31 | * Numpy==1.16.2, Pillow, pycocotools
32 |
33 | ## Test
34 | * Download the datasets we have processed from [Google Drive](https://drive.google.com/file/d/1bo5zE64bQwLUbCUGKDLcRjHei9FBmhfi/view?usp=sharing).
35 | * Download the VGG16-backbone pretrained model from [Google Drive](https://drive.google.com/file/d/1Vvir1CeuCNQY7GU_Ygh593U5I-KXZWff/view?usp=sharing).
36 | * Modify the path config in coseg_test.py (see the excerpt below) and run it.
37 |
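For reference, the values that need editing sit at the bottom of coseg_test.py; a sketch with placeholder paths:

```python
# placeholder paths; point these at the unpacked datasets and the downloaded model
model_path = './models/SSNM-Coseg_best.pth'
val_datapath = ['./cosegdatasets/iCoseg8', './cosegdatasets/Internet_Datasets300',
                './cosegdatasets/MSRC7', './cosegdatasets/PASCAL_VOC']
save_root_path = ['./cosegresults/iCoseg8', './cosegresults/Internet_Datasets300',
                  './cosegresults/MSRC7', './cosegresults/PASCAL_VOC']
```
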
38 | ## Train
39 | * Download the COCO 2017 dataset for training the whole network.
40 | * Download the test datasets for the validation and test phases.
41 | * Download the VGG16 pretrained weights from [Google Drive](https://drive.google.com/file/d/1KIWIspVxLRwv8bzOuMn6lY8kStoedToV/view?usp=sharing). They are actually the official PyTorch model weights, except that the last several layers are removed.
42 | * Download dict.npy from [Google Drive](https://drive.google.com/file/d/1p15hGN3YwqWMRN4xx5mDIK04OhimpY2z/view?usp=sharing).
43 | * Modify the path config in main.py (see the excerpt below) and run it.
44 |
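Similarly, the training configuration lives in the `__main__` block of main.py; a sketch with placeholder paths (the downloaded dict.npy is assumed to be the file loaded as `npy`):

```python
# placeholder paths; point these at your local COCO 2017 copy and the downloaded files
annotation_file = '/path/to/COCO2017/annotations/instances_train2017.json'
train_datapath = '/path/to/COCO2017/train2017/'
vgg16_path = './weights/vgg16_bn_feat.pth'    # pretrained VGG16 weights (last layers removed)
npy = './utils/new_cat2imgid_dict4000.npy'    # assumed to be the downloaded dict.npy, renamed
img_size, lr, batch_size, group_size = 224, 1e-5, 4, 5
```
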
45 | ### Notes
46 | * Following the suggestion of the AAAI20 reviewers, we will not release the trained HRNet-backbone model, so that comparisons with other methods remain fair.
47 | * There are some slight differences in the 'Fusion' part of the model, but they have little impact.
48 | * There is an incorrect value in Table 2: our HRNet J-index for 'Car' on the Internet dataset should be 73.9 instead of 82.5.
49 | * There is a problem with the BaiduPan share link; contact me if you need it.
50 |
51 | #### Schedule
52 | - [x] Create github repo (2019.11.18)
53 | - [x] Release arXiv pdf (2019.12.2)
54 | - [x] Release AAAI20 pdf (2020.7.3)
55 | - [x] All results (2020.7.3)
56 | - [x] Test and Train code (2021.6.4)
57 |
--------------------------------------------------------------------------------
/coseg_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch
4 | from torchvision import transforms
5 | from model import build_model
6 |
7 | def test(gpu_id, model_path, datapath, save_root_path, group_size, img_size, img_dir_name):
8 | net = build_model(device).to(device)  # 'device' is the module-level torch.device created in __main__ below
9 | net.load_state_dict(torch.load(model_path, map_location=gpu_id))
10 | net.eval()
11 |
12 | img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
13 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
14 | img_transform_gray = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
15 | transforms.Normalize(mean=[0.449], std=[0.226])])
16 | with torch.no_grad():
17 | for p in range(len(datapath)):
18 | all_class = os.listdir(os.path.join(datapath[p], img_dir_name))
19 | image_list, save_list = list(), list()
20 | for s in range(len(all_class)):
21 | image_path = os.listdir(os.path.join(datapath[p], img_dir_name, all_class[s]))
22 | image_list.append(list(map(lambda x: os.path.join(datapath[p], img_dir_name, all_class[s], x), image_path)))
23 | save_list.append(list(map(lambda x: os.path.join(save_root_path[p], all_class[s], x[:-4]+'.png'), image_path)))
24 | for i in range(len(image_list)):
25 | cur_class_all_image = image_list[i]
26 | cur_class_rgb = torch.zeros(len(cur_class_all_image), 3, img_size, img_size)
27 | for m in range(len(cur_class_all_image)):
28 | rgb_ = Image.open(cur_class_all_image[m])
29 | if rgb_.mode == 'RGB':
30 | rgb_ = img_transform(rgb_)
31 | else:
32 | rgb_ = img_transform_gray(rgb_)
33 | cur_class_rgb[m, :, :, :] = rgb_
34 |
35 | cur_class_mask = torch.zeros(len(cur_class_all_image), img_size, img_size)
36 | divided = len(cur_class_all_image) // group_size
37 | rested = len(cur_class_all_image) % group_size
38 | if divided != 0:
39 | for k in range(divided):
40 | group_rgb = cur_class_rgb[(k * group_size): ((k + 1) * group_size)]
41 | group_rgb = group_rgb.to(device)
42 | _, pred_mask = net(group_rgb)
43 | cur_class_mask[(k * group_size): ((k + 1) * group_size)] = pred_mask
44 | if rested != 0:  # pad the last incomplete group by wrapping around to the first images of the class
45 | group_rgb_tmp_l = cur_class_rgb[-rested:]
46 | group_rgb_tmp_r = cur_class_rgb[:group_size - rested]
47 | group_rgb = torch.cat((group_rgb_tmp_l, group_rgb_tmp_r), dim=0)
48 | group_rgb = group_rgb.to(device)
49 | _, pred_mask = net(group_rgb)
50 | cur_class_mask[(divided * group_size):] = pred_mask[:rested]
51 |
52 |
53 | class_save_path = os.path.join(save_root_path[p], all_class[i])
54 | if not os.path.exists(class_save_path):
55 | os.makedirs(class_save_path)
56 |
57 | for j in range(len(cur_class_all_image)):
58 | exact_save_path = save_list[i][j]
59 | result = cur_class_mask[j, :, :].numpy()
60 | result = Image.fromarray(result * 255)
61 | w, h = Image.open(image_list[i][j]).size
62 | result = result.resize((w, h), Image.BILINEAR)
63 | result.convert('L').save(exact_save_path)
64 |
65 | print('done')
66 |
67 |
68 |
69 |
70 | if __name__ == '__main__':
71 | gpu_id = 'cuda:0'
72 | device = torch.device(gpu_id)
73 | model_path = './models/SSNM-Coseg_best.pth'
74 |
75 | val_datapath = ['./cosegdatasets/iCoseg8',
76 | './cosegdatasets/Internet_Datasets300',
77 | './cosegdatasets/MSRC7',
78 | './cosegdatasets/PASCAL_VOC']
79 |
80 | save_root_path = ['./cosegresults/iCoseg8',
81 | './cosegresults/Internet_Datasets300',
82 | './cosegresults/MSRC7',
83 | './cosegresults/PASCAL_VOC']
84 |
85 | test(gpu_id, model_path, val_datapath, save_root_path, 5, 224, 'image')
86 |
--------------------------------------------------------------------------------
/data_processed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import copy
4 | import random
5 | import os
6 | import numpy as np
7 | from PIL import Image
8 |
9 |
10 | def filt_small_instance(coco_item, pixthreshold=4000,imgNthreshold=5):
11 | list_dict = coco_item.catToImgs
12 | for catid in list_dict:
13 | list_dict[catid] = list(set( list_dict[catid] ))
14 | new_dict = copy.deepcopy(list_dict)
15 | for catid in list_dict:
16 | imgids = list_dict[catid]
17 | for n in range(len(imgids)):
18 | imgid = imgids[n]
19 | anns = coco_item.imgToAnns[imgid]
20 | has_large_instance = False
21 | for ann in anns:
22 | if (ann['category_id'] == catid) and (ann['iscrowd'] == 0) and (ann['area'] > pixthreshold):
23 | has_large_instance = True
24 | if has_large_instance is False:
25 | new_dict[catid].remove(imgid)
26 | imgN = len(new_dict[catid])
27 | if imgN < imgNthreshold:
55 | if batch_size > len(list_dict):
56 | remainN = batch_size - len(list_dict)
57 | batch_catid = random.sample(list_dict.keys(), remainN) + random.sample(list_dict, len(list_dict))
58 | else:
59 | batch_catid = random.sample(list_dict.keys(), batch_size)
60 | group_n = 0
61 | img_n = 0
62 | for catid in batch_catid:
63 | imgids = random.sample(list_dict[catid], group_size)
64 | co_catids = []
65 | anns = coco_item.imgToAnns[imgids[0]]
66 | for ann in anns:
67 | if (ann['iscrowd'] == 0) and (ann['area'] > 4000):
68 | co_catids.append(ann['category_id'])
69 | co_catids_backup = copy.deepcopy(co_catids)
70 | for imgid in imgids[1:]:
71 | img_catids = []
72 | anns = coco_item.imgToAnns[imgid]
73 | for ann in anns:
74 | if (ann['iscrowd'] == 0) and (ann['area'] > 4000):
75 | img_catids.append(ann['category_id'])
76 | for co_catid in co_catids_backup:
77 | if co_catid not in img_catids:
78 | co_catids.remove(co_catid)
79 | co_catids_backup = copy.deepcopy(co_catids)
80 | for co_catid in co_catids:
81 | cls_labels[group_n, catid2label[co_catid]] = 1
82 | for imgid in imgids:
83 | path = datapath + '%012d.jpg'%imgid
84 | img = Image.open(path)
85 | if img.mode == 'RGB':
86 | img = img_transform(img)
87 | else:
88 | img = img_transform_gray(img)
89 | anns = coco_item.imgToAnns[imgid]
90 | mask = None
91 | for ann in anns:
92 | if ann['category_id'] in co_catids:
93 | if mask is None:
94 | mask = coco_item.annToMask(ann)
95 | else:
96 | mask = mask + coco_item.annToMask(ann)
97 | mask[mask > 0] = 255
98 | mask = Image.fromarray(mask)
99 | mask = gt_transform(mask)
100 | mask[mask > 0.5] = 1
101 | mask[mask <= 0.5] = 0
102 | rgb[img_n,:,:,:] = copy.deepcopy(img)
103 | mask_labels[img_n,:,:] = copy.deepcopy(mask)
104 | img_n = img_n + 1
105 | group_n = group_n + 1
106 | idx = mask_labels[:, :, :] > 1
107 | mask_labels[idx] = 1
108 | q.put([rgb, cls_labels, mask_labels])
109 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | class Weighed_Bce_Loss(nn.Module):
7 | def __init__(self):
8 | super(Weighed_Bce_Loss, self).__init__()
9 |
10 | def forward(self, x, label):
11 | x = x.view(-1, 1, x.shape[1], x.shape[2])
12 | label = label.view(-1, 1, label.shape[1], label.shape[2])
13 | label_t = (label == 1).float()
14 | label_f = (label == 0).float()
15 | p = torch.sum(label_t) / (torch.sum(label_t) + torch.sum(label_f))
16 | w = torch.zeros_like(label)
17 | w[label == 1] = p
18 | w[label == 0] = 1 - p
19 | loss = F.binary_cross_entropy(x, label, weight=w)
20 | return loss
21 |
22 |
23 | class Cls_Loss(nn.Module):
24 | def __init__(self):
25 | super(Cls_Loss, self).__init__()
26 |
27 | def forward(self, x, label):
28 | loss = F.binary_cross_entropy(x, label)
29 | return loss
30 |
31 | class S_Loss(nn.Module):
32 | def __init__(self):
33 | super(S_Loss, self).__init__()
34 |
35 | def forward(self, x, label):
36 | loss = F.smooth_l1_loss(x, label)
37 | return loss
38 |
39 | class Loss(nn.Module):
40 | def __init__(self):
41 | super(Loss, self).__init__()
42 | self.loss_wbce = Weighed_Bce_Loss()
43 | self.loss_cls = Cls_Loss()
44 | self.loss_s = S_Loss()
45 | self.w_wbce = 1
46 | self.w_cls = 1
47 | self.w_smooth = 1
48 |
49 | def forward(self, x, label, x_cls, label_cls):
50 | m_loss = self.loss_wbce(x, label) * self.w_wbce
51 | c_loss = self.loss_cls(x_cls, label_cls) * self.w_cls
52 | s_loss = self.loss_s(x, label) * self.w_smooth
53 | loss = m_loss + c_loss + s_loss
54 |
55 | return loss, m_loss, c_loss, s_loss
56 |
57 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from pycocotools import coco
4 | import queue
5 | import threading
6 | from model import build_model, weights_init
7 | from tools import custom_print
8 | from data_processed import train_data_producer
9 | from train import train
10 | import time
11 | torch.backends.cudnn.benchmark = True
12 |
13 |
14 | if __name__ == '__main__':
15 | # train_val_config
16 | annotation_file = '/home/chenjin/dataset/COCO/COCO2017/annotations/instances_train2017.json'
17 | coco_item = coco.COCO(annotation_file=annotation_file)
18 |
19 | train_datapath = '/home/chenjin/dataset/COCO/COCO2017/train2017/'
20 |
21 | val_datapath = ['./cosegdatasets/iCoseg8',
22 | './cosegdatasets/MSRC7',
23 | './cosegdatasets/Internet_Datasets300',
24 | './cosegdatasets/PASCAL_VOC']
25 |
26 | vgg16_path = './weights/vgg16_bn_feat.pth'
27 | npy = './utils/new_cat2imgid_dict4000.npy'
28 |
29 | # project config
30 | project_name = 'SSNM-Coseg'
31 | device = torch.device('cuda:0')
32 | img_size = 224
33 | lr = 1e-5
34 | lr_de = 20000
35 | epochs = 100000
36 | batch_size = 4
37 | group_size = 5
38 | log_interval = 100
39 | val_interval = 1000
40 |
41 | # create log dir
42 | log_root = './logs'
43 | if not os.path.exists(log_root):
44 | os.makedirs(log_root)
45 |
46 | # create log txt
47 | log_txt_file = os.path.join(log_root, project_name + '_log.txt')
48 | custom_print(project_name, log_txt_file, 'w')
49 |
50 | # create model save dir
51 | models_root = './models'
52 | if not os.path.exists(models_root):
53 | os.makedirs(models_root)
54 |
55 | models_train_last = os.path.join(models_root, project_name + '_last.pth')
56 | models_train_best = os.path.join(models_root, project_name + '_best.pth')
57 |
58 | net = build_model(device).to(device)
59 | net.train()
60 | net.apply(weights_init)
61 | net.base.load_state_dict(torch.load(vgg16_path))
62 |
63 | # to resume training, load the last checkpoint:
64 | # net.load_state_dict(torch.load('./models/SSNM-Coseg_last.pth', map_location='cuda:0'))
65 |
66 | q = queue.Queue(maxsize=40)
67 |
68 | p1 = threading.Thread(target=train_data_producer, args=(coco_item, train_datapath, npy, q, batch_size, group_size, img_size))
69 | p2 = threading.Thread(target=train_data_producer, args=(coco_item, train_datapath, npy, q, batch_size, group_size, img_size))
70 | p3 = threading.Thread(target=train_data_producer, args=(coco_item, train_datapath, npy, q, batch_size, group_size, img_size))
71 | p1.start()
72 | p2.start()
73 | p3.start()
74 | time.sleep(2)
75 |
76 | train(net, device, q, log_txt_file, val_datapath, models_train_best, models_train_last, lr, lr_de, epochs, log_interval, val_interval)
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from torch.optim import Adam
6 |
7 | # vgg choice
8 | base = {'vgg': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
9 |
10 | # vgg16
11 | def vgg(cfg, i=3, batch_norm=True):
12 | layers = []
13 | in_channels = i
14 | for v in cfg:
15 | if v == 'M':
16 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
17 | else:
18 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
19 | if batch_norm:
20 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
21 | else:
22 | layers += [conv2d, nn.ReLU(inplace=True)]
23 | in_channels = v
24 | return layers
25 |
26 |
27 | def hsp(in_channel, out_channel):
28 | layers = nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, 1),
29 | nn.ReLU())
30 | return layers
31 |
32 | def cls_modulation_branch(in_channel, hiden_channel):
33 | layers = nn.Sequential(nn.Linear(in_channel, hiden_channel),
34 | nn.ReLU())
35 | return layers
36 |
37 | def cls_branch(hiden_channel, class_num):
38 | layers = nn.Sequential(nn.Linear(hiden_channel, class_num),
39 | nn.Sigmoid())
40 | return layers
41 |
42 | def concat_r():
43 | layers = []
44 | layers += [nn.Conv2d(512, 512, 1, 1)]
45 | layers += [nn.ReLU()]
46 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
47 | layers += [nn.ReLU()]
48 | layers += [nn.ConvTranspose2d(512, 512, 4, 2, 1)]
49 | return layers
50 |
51 | def concat_1():
52 | layers = []
53 | layers += [nn.Conv2d(512, 512, 1, 1)]
54 | layers += [nn.ReLU()]
55 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
56 | layers += [nn.ReLU()]
57 | return layers
58 |
59 | def mask_branch():
60 | layers = []
61 | layers += [nn.Conv2d(512, 2, 3, 1, 1)]
62 | layers += [nn.ConvTranspose2d(2, 2, 8, 4, 2)]
63 | layers += [nn.Softmax2d()]
64 | return layers
65 |
66 | def incr_channel():
67 | layers = []
68 | layers += [nn.Conv2d(128, 512, 3, 1, 1)]
69 | layers += [nn.Conv2d(256, 512, 3, 1, 1)]
70 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
71 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
72 | return layers
73 |
74 | def incr_channel2():
75 | layers = []
76 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
77 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
78 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
79 | layers += [nn.Conv2d(512, 512, 3, 1, 1)]
80 | layers += [nn.ReLU()]
81 | return layers
82 |
83 | def norm(x, dim):
84 | squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
85 | normed = x / torch.sqrt(squared_norm)
86 | return normed
87 |
88 | def fuse_hsp(x, p):
89 | group_size = 5
90 | t = torch.zeros(group_size, x.size(1))
91 | for i in range(x.size(0)):
92 | tmp = x[i, :]
93 | if i == 0:
94 | nx = tmp.expand_as(t)
95 | else:
96 | nx = torch.cat(([nx, tmp.expand_as(t)]), dim=0)
97 | nx = nx.view(x.size(0)*group_size, x.size(1), 1, 1)
98 | y = nx.expand_as(p)
99 | return y
100 |
101 |
102 | class Model(nn.Module):
103 | def __init__(self, device, base, incr_channel, incr_channel2, hsp1, hsp2, cls_m, cls, concat_r, concat_1, mask_branch):
104 | super(Model, self).__init__()
105 | self.base = nn.ModuleList(base)
106 | self.sp1 = hsp1
107 | self.sp2 = hsp2
108 | self.cls_m = cls_m
109 | self.cls = cls
110 | self.incr_channel1 = nn.ModuleList(incr_channel)
111 | self.incr_channel2 = nn.ModuleList(incr_channel2)
112 | self.concat4 = nn.ModuleList(concat_r)
113 | self.concat3 = nn.ModuleList(concat_r)
114 | self.concat2 = nn.ModuleList(concat_r)
115 | self.concat1 = nn.ModuleList(concat_1)
116 | self.mask = nn.ModuleList(mask_branch)
117 | self.extract = [13, 23, 33, 43]
118 | self.device = device
119 | self.group_size = 5
120 |
121 | def forward(self, x):
122 | # backbone, p is the pool2, 3, 4, 5
123 | p = list()
124 | for k in range(len(self.base)):
125 | x = self.base[k](x)
126 | if k in self.extract:
127 | p.append(x)
128 |
129 | # increase the channel
130 | newp = list()
131 | for k in range(len(p)):
132 | np = self.incr_channel1[k](p[k])
133 | np = self.incr_channel2[k](np)
134 | newp.append(self.incr_channel2[4](np))
135 |
136 | # spatial modulator
137 | spa_mask = spatial_optimize(newp[3], self.group_size).to(self.device)
138 |
139 | # hsp
140 | x = newp[3]
141 | x = self.sp1(x)
142 | x = x.view(-1, x.size(1), x.size(2) * x.size(3))
143 | x = torch.bmm(x, x.transpose(1, 2))
144 | x = x.view(-1, x.size(1) * x.size(2))
145 | x = x.view(x.size(0) // 5, x.size(1), -1, 1)
146 | x = self.sp2(x)
147 | x = x.view(-1, x.size(1), x.size(2) * x.size(3))
148 | x = torch.bmm(x, x.transpose(1, 2))
149 | x = x.view(-1, x.size(1) * x.size(2))
150 |
151 | #cls pred
152 | cls_modulated_vector = self.cls_m(x)
153 | cls_pred = self.cls(cls_modulated_vector)
154 |
155 | #semantic and spatial modulator
156 | g1 = fuse_hsp(cls_modulated_vector, newp[0])
157 | g2 = fuse_hsp(cls_modulated_vector, newp[1])
158 | g3 = fuse_hsp(cls_modulated_vector, newp[2])
159 | g4 = fuse_hsp(cls_modulated_vector, newp[3])
160 |
161 | spa_1 = F.interpolate(spa_mask, size=[g1.size(2), g1.size(3)], mode='bilinear')
162 | spa_1 = spa_1.expand_as(g1)
163 | spa_2 = F.interpolate(spa_mask, size=[g2.size(2), g2.size(3)], mode='bilinear')
164 | spa_2 = spa_2.expand_as(g2)
165 | spa_3 = F.interpolate(spa_mask, size=[g3.size(2), g3.size(3)], mode='bilinear')
166 | spa_3 = spa_3.expand_as(g3)
167 | spa_4 = F.interpolate(spa_mask, size=[g4.size(2), g4.size(3)], mode='bilinear')
168 | spa_4 = spa_4.expand_as(g4)
169 |
170 | y4 = newp[3] * g4 + spa_4
171 | for k in range(len(self.concat4)):
172 | y4 = self.concat4[k](y4)
173 | y3 = newp[2] * g3 + spa_3
174 |
175 | for k in range(len(self.concat3)):
176 | y3 = self.concat3[k](y3)
177 | if k == 1:
178 | y3 = y3 + y4
179 | y2 = newp[1] * g2 + spa_2
180 |
181 | for k in range(len(self.concat2)):
182 | y2 = self.concat2[k](y2)
183 | if k == 1:
184 | y2 = y2 + y3
185 | y1 = newp[0] * g1 + spa_1
186 |
187 | for k in range(len(self.concat1)):
188 | y1 = self.concat1[k](y1)
189 | if k == 1:
190 | y1 = y1 + y2
191 | y = y1
192 |
193 | # decoder
194 | for k in range(len(self.mask)):
195 | y = self.mask[k](y)
196 | mask_pred = y[:, 0, :, :]
197 |
198 | return cls_pred, mask_pred
199 |
200 |
201 |
202 | # build the whole network
203 | def build_model(device):
204 | return Model(device,
205 | vgg(base['vgg']),
206 | incr_channel(),
207 | incr_channel2(),
208 | hsp(512, 64),
209 | hsp(64**2, 32),
210 | cls_modulation_branch(32**2, 512),
211 | cls_branch(512, 78),
212 | concat_r(),
213 | concat_1(),
214 | mask_branch())
215 |
216 | # weight init
217 | def xavier(param):
218 | init.xavier_uniform_(param)
219 |
220 | def weights_init(m):
221 | if isinstance(m, nn.Conv2d):
222 | xavier(m.weight.data)
223 | elif isinstance(m, nn.BatchNorm2d):
224 | init.constant_(m.weight, 1)
225 | init.constant_(m.bias, 0)
226 |
227 |
228 | def spatial_optimize(fmap, group_size):
229 | fmap_split = torch.split(fmap, group_size, dim=0)
230 | for i in range(len(fmap_split)):
231 | cur_fmap = fmap_split[i]
232 | with torch.no_grad():
233 | spatial_x = cur_fmap.permute(0, 2, 3, 1).contiguous().view(-1, cur_fmap.size(1)).transpose(1, 0)
234 | spatial_x = norm(spatial_x, dim=0)
235 | spatial_x_t = spatial_x.transpose(1, 0)
236 | G = spatial_x_t @ spatial_x - 1  # pairwise similarities of all spatial features in the group, shifted by -1
237 | G = G.detach().cpu()
238 |
239 | with torch.enable_grad():
240 | spatial_s = nn.Parameter(torch.sqrt(245 * torch.ones((245, 1))) / 245, requires_grad=True)  # 245 = group_size(5) * 7 * 7 positions, initialized with unit L2 norm
241 | spatial_s_t = spatial_s.transpose(1, 0)
242 | spatial_s_optimizer = Adam([spatial_s], 0.01)
243 |
244 | for iter in range(200):
245 | f_spa_loss = -1 * torch.sum(spatial_s_t @ G @ spatial_s)  # maximize s^T G s over the group
246 | spatial_s_d = torch.sqrt(torch.sum(spatial_s ** 2))  # L2 norm of s; d_loss below keeps it close to 1
247 | if spatial_s_d >= 1:
248 | d_loss = -1 * torch.log(2 - spatial_s_d)
249 | else:
250 | d_loss = -1 * torch.log(spatial_s_d)
251 |
252 | all_loss = 50 * d_loss + f_spa_loss
253 |
254 | spatial_s_optimizer.zero_grad()
255 | all_loss.backward()
256 | spatial_s_optimizer.step()
257 |
258 | result_map = spatial_s.data.view(5, 1, 7, 7)  # reshape the optimized weights into a 7x7 mask for each image in the group
259 |
260 | if i == 0:
261 | spa_mask = result_map
262 | else:
263 | spa_mask = torch.cat(([spa_mask, result_map]), dim=0)
264 |
265 | return spa_mask
266 |
267 |
268 |
269 |
--------------------------------------------------------------------------------
/pic/Internet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/Internet.png
--------------------------------------------------------------------------------
/pic/MSRC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/MSRC.png
--------------------------------------------------------------------------------
/pic/PASCALVOC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/PASCALVOC.png
--------------------------------------------------------------------------------
/pic/iCoseg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/iCoseg.png
--------------------------------------------------------------------------------
/pic/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/network.png
--------------------------------------------------------------------------------
/pic/seen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/seen.png
--------------------------------------------------------------------------------
/pic/unseen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cj4L/SSNM-Coseg/e172652a07407e8d1ff58fac208cad4592349d32/pic/unseen.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy==1.16.2
2 | torch>=1.0
3 | torchvision==0.2.2
4 | pycocotools
5 | pillow
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def custom_print(context, log_file, mode):
4 | #custom print and log out function
5 | if mode == 'w':
6 | fp = open(log_file, mode)
7 | fp.write(context + '\n')
8 | fp.close()
9 | elif mode == 'a+':
10 | print(context)
11 | fp = open(log_file, mode)
12 | print(context, file=fp)
13 | fp.close()
14 | else:
15 | raise Exception('other file operation is unimplemented !')
16 |
17 |
18 | def generate_binary_map(pred, type):
19 | if type == '2mean':
20 | threshold = np.mean(pred) * 2
21 | if threshold > 0.8:
22 | threshold = 0.8
23 | binary_map = pred > threshold
24 | return binary_map.astype(np.float32)
25 |
26 | if type == 'mean+std':
27 | threshold = np.mean(pred) + np.std(pred)
28 | if threshold > 0.8:
29 | threshold = 0.8
30 | binary_map = pred > threshold
31 | return binary_map.astype(np.float32)
32 |
33 |
34 |
35 | def calc_precision_and_jaccard(pred, gt):
36 | bin_pred = generate_binary_map(pred, 'mean+std')
37 | tp = (bin_pred == gt).sum()  # pixels labeled correctly (foreground and background)
38 | precision = tp / (pred.size)  # precision here means the fraction of correctly labeled pixels
39 |
40 | i = (bin_pred * gt).sum()
41 | u = bin_pred.sum() + gt.sum() - i
42 | jaccard = i / (u + 1e-10)
43 |
44 | return precision, jaccard
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from loss import Loss
2 | from torch.optim import Adam
3 | from tools import custom_print
4 | import datetime
5 | import torch
6 | from val import validation
7 |
8 | def train(net, device, q, log_txt_file, val_datapath, models_train_best, models_train_last, lr=1e-4, lr_de_epoch=25000,
9 | epochs=100000, log_interval=100, val_interval=1000):
10 | optimizer = Adam(net.parameters(), lr, weight_decay=1e-6)
11 | loss = Loss().to(device)
12 | best_p, best_j = 0, 1
13 | ave_loss, ave_m_loss, ave_c_loss, ave_s_loss = 0, 0, 0, 0
14 | for epoch in range(1, epochs+1):
15 | img, cls_gt, mask_gt = q.get()
16 | net.zero_grad()
17 | img, cls_gt, mask_gt = img.to(device), cls_gt.to(device), mask_gt.to(device)
18 | pred_cls, pred_mask = net(img)
19 | all_loss, m_loss, c_loss, s_loss = loss(pred_mask, mask_gt, pred_cls, cls_gt)
20 | all_loss.backward()
21 | epoch_loss = all_loss.item()
22 | m_l = m_loss.item()
23 | c_l = c_loss.item()
24 | s_l = s_loss.item()
25 | ave_loss += epoch_loss
26 | ave_m_loss += m_l
27 | ave_c_loss += c_l
28 | ave_s_loss += s_l
29 | optimizer.step()
30 |
31 | if epoch % log_interval == 0:
32 | ave_loss = ave_loss / log_interval
33 | ave_m_loss = ave_m_loss / log_interval
34 | ave_c_loss = ave_c_loss / log_interval
35 | ave_s_loss = ave_s_loss / log_interval
36 | custom_print(datetime.datetime.now().strftime('%F %T') +
37 | ' lr: %e, epoch: [%d/%d], all_loss: [%.4f], m_loss: [%.4f], c_loss: [%.4f], s_loss: [%.4f]' %
38 | (lr, epoch, epochs, ave_loss, ave_m_loss, ave_c_loss, ave_s_loss), log_txt_file, 'a+')
39 | ave_loss, ave_m_loss, ave_c_loss, ave_s_loss = 0, 0, 0, 0
40 |
41 | if epoch % val_interval == 0:
42 | net.eval()
43 | with torch.no_grad():
44 | custom_print(datetime.datetime.now().strftime('%F %T') +
45 | ' now is evaluating the coseg dataset', log_txt_file, 'a+')
46 | ave_p, ave_j = validation(net, val_datapath, device, group_size=5, img_size=224, img_dir_name='image', gt_dir_name='groundtruth',
47 | img_ext=['.jpg', '.jpg', '.jpg', '.jpg'], gt_ext=['.png', '.bmp', '.jpg', '.png'])
48 | if ave_p[3] > best_p:
49 | # adapt this model-saving condition to your own needs
50 | best_p = ave_p[3]
51 | best_j = ave_j[0]
52 | torch.save(net.state_dict(), models_train_best)
53 | torch.save(net.state_dict(), models_train_last)
54 | custom_print('-' * 100, log_txt_file, 'a+')
55 | custom_print(datetime.datetime.now().strftime('%F %T') + ' iCoseg8 p: [%.4f], j: [%.4f]' %
56 | (ave_p[0], ave_j[0]), log_txt_file, 'a+')
57 | custom_print(datetime.datetime.now().strftime('%F %T') + ' MSRC7 p: [%.4f], j: [%.4f]' %
58 | (ave_p[1], ave_j[1]), log_txt_file, 'a+')
59 | custom_print(datetime.datetime.now().strftime('%F %T') + ' Int_300 p: [%.4f], j: [%.4f]' %
60 | (ave_p[2], ave_j[2]), log_txt_file, 'a+')
61 | custom_print(datetime.datetime.now().strftime('%F %T') + ' PAS_VOC p: [%.4f], j: [%.4f]' %
62 | (ave_p[3], ave_j[3]), log_txt_file, 'a+')
63 | custom_print('-' * 100, log_txt_file, 'a+')
64 | net.train()
65 |
66 | if epoch % lr_de_epoch == 0:
67 | optimizer = Adam(net.parameters(), lr/2, weight_decay=1e-6)
68 | lr = lr / 2
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | import torch
3 | import os
4 | from PIL import Image
5 | from tools import *
6 |
7 |
8 | def validation(net, datapath, device, group_size=5, img_size=224, img_dir_name='image', gt_dir_name='groundtruth',
9 | img_ext=['.jpg', '.jpg', '.jpg', '.jpg'], gt_ext=['.png', '.bmp', '.jpg', '.png']):
10 | img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
11 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
12 | gt_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor()])
13 | img_transform_gray = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
14 | transforms.Normalize(mean=[0.449], std=[0.226])])
15 | net.eval()
16 | with torch.no_grad():
17 | ave_p, ave_j = [], []
18 | for p in range(len(datapath)):
19 | all_p, all_j = [], []
20 | all_class = os.listdir(os.path.join(datapath[p], img_dir_name))
21 | image_list, gt_list = list(), list()
22 | for s in range(len(all_class)):
23 | image_path = os.listdir(os.path.join(datapath[p], img_dir_name, all_class[s]))
24 | image_list.append(list(map(lambda x: os.path.join(datapath[p], img_dir_name, all_class[s], x), image_path)))
25 | gt_list.append(list(map(lambda x: os.path.join(datapath[p], gt_dir_name, all_class[s], x.replace(img_ext[p], gt_ext[p])), image_path)))
26 | for i in range(len(image_list)):
27 | cur_class_all_image = image_list[i]
28 | cur_class_all_gt = gt_list[i]
29 |
30 | cur_class_gt = torch.zeros(len(cur_class_all_gt), img_size, img_size)
31 | for g in range(len(cur_class_all_gt)):
32 | gt_ = Image.open(cur_class_all_gt[g]).convert('L')
33 | gt_ = gt_transform(gt_)
34 | gt_[gt_ > 0.5] = 1
35 | gt_[gt_ <= 0.5] = 0
36 | cur_class_gt[g, :, :] = gt_
37 |
38 | cur_class_rgb = torch.zeros(len(cur_class_all_image), 3, img_size, img_size)
39 | for m in range(len(cur_class_all_image)):
40 | rgb_ = Image.open(cur_class_all_image[m])
41 | if rgb_.mode == 'RGB':
42 | rgb_ = img_transform(rgb_)
43 | else:
44 | rgb_ = img_transform_gray(rgb_)
45 | cur_class_rgb[m, :, :, :] = rgb_
46 |
47 | cur_class_mask = torch.zeros(len(cur_class_all_image), img_size, img_size)
48 | divided = len(cur_class_all_image) // group_size
49 | rested = len(cur_class_all_image) % group_size
50 | if divided != 0:
51 | for k in range(divided):
52 | group_rgb = cur_class_rgb[(k * group_size): ((k + 1) * group_size)]
53 | group_rgb = group_rgb.to(device)
54 | _, pred_mask = net(group_rgb)
55 | cur_class_mask[(k * group_size): ((k + 1) * group_size)] = pred_mask
56 | if rested != 0:
57 | group_rgb_tmp_l = cur_class_rgb[-rested:]
58 | group_rgb_tmp_r = cur_class_rgb[:group_size-rested]
59 | group_rgb = torch.cat((group_rgb_tmp_l, group_rgb_tmp_r), dim=0)
60 | group_rgb = group_rgb.to(device)
61 | _, pred_mask = net(group_rgb)
62 | cur_class_mask[(divided * group_size): ] = pred_mask[:rested]
63 |
64 | for q in range(cur_class_mask.size(0)):
65 | single_p, single_j = calc_precision_and_jaccard(cur_class_mask[q, :, :].numpy(), cur_class_gt[q, :, :].numpy())
66 | all_p.append(single_p)
67 | all_j.append(single_j)
68 |
69 | dataset_p = np.mean(all_p)
70 | dataset_j = np.mean(all_j)
71 |
72 | ave_p.append(dataset_p)
73 | ave_j.append(dataset_j)
74 |
75 | return ave_p, ave_j
--------------------------------------------------------------------------------