├── demo_imgs
│   ├── 012.jpg
│   ├── 018.jpg
│   ├── 238.jpg
│   ├── 38092.jpg
│   ├── 41029.jpg
│   ├── 42049.jpg
│   ├── 107793.jpg
│   ├── 178279.jpg
│   ├── 192566.jpg
│   ├── 2007_001185.jpg
│   ├── 2007_007151.jpg
│   └── 2008_004532.jpg
├── pics
│   ├── results.png
│   ├── strategy.png
│   └── structures.png
├── libs
│   ├── EvalSPModule.so
│   ├── __pycache__
│   │   ├── model.cpython-37.pyc
│   │   ├── test.cpython-37.pyc
│   │   ├── train.cpython-37.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── update.cpython-37.pyc
│   │   └── assignment.cpython-37.pyc
│   ├── layers
│   │   ├── __pycache__
│   │   │   ├── grm.cpython-37.pyc
│   │   │   ├── agrl.cpython-37.pyc
│   │   │   ├── seed.cpython-37.pyc
│   │   │   └── embedder.cpython-37.pyc
│   │   ├── embedder.py
│   │   ├── grm.py
│   │   └── seed.py
│   ├── test.py
│   ├── model.py
│   └── utils.py
├── lnsnet_BSDS_checkpoint.pth
├── runDemo.sh
├── README.md
└── demo.py

/demo_imgs/012.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/012.jpg
--------------------------------------------------------------------------------
/demo_imgs/018.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/018.jpg
--------------------------------------------------------------------------------
/demo_imgs/238.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/238.jpg
--------------------------------------------------------------------------------
/pics/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/pics/results.png
--------------------------------------------------------------------------------
/pics/strategy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/pics/strategy.png
--------------------------------------------------------------------------------
/demo_imgs/38092.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/38092.jpg
--------------------------------------------------------------------------------
/demo_imgs/41029.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/41029.jpg
--------------------------------------------------------------------------------
/demo_imgs/42049.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/42049.jpg
--------------------------------------------------------------------------------
/pics/structures.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/pics/structures.png
--------------------------------------------------------------------------------
/demo_imgs/107793.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/107793.jpg
--------------------------------------------------------------------------------
/demo_imgs/178279.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/178279.jpg
--------------------------------------------------------------------------------
/demo_imgs/192566.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/192566.jpg
--------------------------------------------------------------------------------
/libs/EvalSPModule.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/EvalSPModule.so
--------------------------------------------------------------------------------
/demo_imgs/2007_001185.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/2007_001185.jpg
--------------------------------------------------------------------------------
/demo_imgs/2007_007151.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/2007_007151.jpg
--------------------------------------------------------------------------------
/demo_imgs/2008_004532.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/2008_004532.jpg
--------------------------------------------------------------------------------
/lnsnet_BSDS_checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/lnsnet_BSDS_checkpoint.pth
--------------------------------------------------------------------------------
/libs/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/test.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/update.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/assignment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/assignment.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/grm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/grm.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/agrl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/agrl.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/seed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/seed.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/embedder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/embedder.cpython-37.pyc
--------------------------------------------------------------------------------
/runDemo.sh:
--------------------------------------------------------------------------------
python demo.py --n_spix 100 \
    --img_path "./demo_imgs/107793.jpg" \
    --check_path "./lnsnet_BSDS_checkpoint.pth" \
    --seed_strategy 'network'
--------------------------------------------------------------------------------
/libs/test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2
import time

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.model import *
from libs.utils import *

import random
random.seed(1)
torch.manual_seed(1)


def assignment_test(f, input, cx, cy, alpha=1):
    # Soft assignment of every pixel to the k seeds: squared distance in the
    # embedding space is passed through a Student's t-kernel of degree alpha.

    b, _, h, w = input.size()
    p = input[:, 3:, :, :]  # the two positional channels

    p = p.view(b, 2, -1)
    cind = cx * w + cy  # linear (row-major) indices of the seed pixels
    cind = cind.long()
    c_p = p[:, :, cind]
    c_f = f[:, :, cind]  # seed features, (b, c, k)

    _, c, k = c_f.size()

    N = h * w

    dis = torch.zeros(b, k, N)
    for i in range(0, k):
        cur_c_f = c_f[:, :, i].unsqueeze(-1).expand(b, c, N)
        cur_p_ij = cur_c_f - f
        cur_p_ij = torch.pow(cur_p_ij, 2)
        cur_p_ij = torch.sum(cur_p_ij, dim=1)
        dis[:, i, :] = cur_p_ij
    dis = dis / alpha
    dis = torch.pow((1 + dis), -(alpha + 1) / 2)
    dis = dis.view(b, k, N).permute(0, 2, 1).contiguous()  # b, N, k
    dis = dis / torch.sum(dis, dim=2).unsqueeze(-1)  # normalize so each pixel's assignments sum to 1

    return dis
--------------------------------------------------------------------------------
/libs/layers/embedder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.utils import *

import random
random.seed(1)
torch.manual_seed(1)


class Embedder(nn.Module):

    def __init__(self, is_dilation=True):
        super().__init__()

        self.is_dilation = is_dilation

        self.rpad_1 = nn.ReflectionPad2d(1)
        if self.is_dilation:
            # three parallel 3x3 branches fused with the raw input by a 3x3 head
            self.c1_1 = nn.Conv2d(5, 10, 3, padding=0)
            self.c1_2 = nn.Conv2d(5, 10, 3, padding=0, dilation=1)
            self.c1_3 = nn.Conv2d(5, 10, 3, padding=1, dilation=2)
            self.c1_4 = nn.Sequential(nn.InstanceNorm2d(35, affine=True), nn.ReLU(), nn.ReflectionPad2d(1), nn.Conv2d(35, 10, 3, padding=0))
        else:
            self.c1 = nn.Conv2d(5, 10, 3, padding=0)
        self.inorm_1 = nn.InstanceNorm2d(10, affine=True)
        #self.inorm_1 = nn.BatchNorm2d(10, affine=True)

        self.rpad_2 = nn.ReflectionPad2d(1)
        self.c2 = nn.Conv2d(15, 20, 3, padding=0)
        #self.inorm_2 = nn.BatchNorm2d(20, affine=True)
        self.inorm_2 = nn.InstanceNorm2d(20, affine=True)

        self.relu = nn.ReLU()

    def forward(self, x):

        spix = self.rpad_1(x)
        if self.is_dilation:
            spix_1 = self.c1_1(spix)
            spix_2 = self.c1_2(spix)
            spix_3 = self.c1_3(spix)
            spix = torch.cat([x, spix_1, spix_2, spix_3], dim=1)  # 5 + 10 + 10 + 10 = 35 channels
            spix = self.c1_4(spix)
        else:
            spix = self.c1(spix)

        spix = self.inorm_1(spix)
        spix = self.relu(spix)

        spix = torch.cat((spix, x), dim=1)  # skip connection with the raw 5-channel input

        spix = self.rpad_2(spix)
        spix = self.c2(spix)
        spix = self.inorm_2(spix)
        spix = self.relu(spix)

        return spix
--------------------------------------------------------------------------------
/libs/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.layers.grm import *
from libs.layers.embedder import *
from libs.layers.seed import *

import random
random.seed(1)
torch.manual_seed(1)


class LNSN(nn.Module):

    def __init__(self, n_spix, args):
        super().__init__()
        self.n_spix = n_spix
        self.sp_num = n_spix
        self.is_dilation = args.is_dilation
        self.device = args.device

        self.seed_strategy = args.seed_strategy

        self.train = True  # flag: when True, forward() also runs the GRM branch and returns the reconstruction
        self.kn = args.kn
        ###########Optimizer Parameter##########

        self.embedder = Embedder(self.is_dilation)
        self.generater = SeedGenerater(self.n_spix, self.device, seed_strategy=self.seed_strategy)

        self.grm = GRM(args)

        #############init#######################
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.InstanceNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, bd):

        b, _, h, w = x.size()
        ##########Feature Extracting###########
        f = self.embedder(x)

        #########Seed Generate#######
        cx, cy, probs = self.generater(f)

        if self.train:
            # training branch: the GRM rescales gradients and returns a reconstruction
            f, recons = self.grm(f, bd)
            f = f.view(b, -1, h * w)
            return recons, cx, cy, f, probs
        else:
            f = f.view(b, -1, h * w)
            return cx, cy, f, probs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LNSNet
## Overview

Official implementation of [Learning the Superpixel in a Non-iterative and Lifelong Manner][arxiv] (CVPR'21)

### Learning Strategy

The proposed LNSNet views the superpixel segmentation of each image as **an independent pixel-level clustering task** and uses a **lifelong learning strategy** to train the superpixel segmentation network over a series of images.

![Learning strategy](pics/strategy.png)
14 | 15 | ### Model Structure 16 | 17 | The structure of proposed LNS-Net shown in Fig. 3 contains three parts: 18 | 19 | 1) **Feature Embedder Module (FEM)** that embeds the original feature into a cluster-friendly space; 20 | 2) **Non-iteratively Clustering Module (NCM)** that assigns the label for pixels with the help of a seed estimation module, which automatically estimates 21 | the indexes of seed nodes; 22 | 3) **Gradient Rescaling Module (GRM)** that adaptively rescales the gradient for each weight parameter based on the channel and spatial context to avoid catastrophic forgetting 23 | for the sequential learning. 24 | 25 |
26 | 27 | 28 |
29 | 30 | 31 | ## Getting Started 32 | 33 | Here we only release the model trained on BSDS dataset and corresponding code to utilizes it for superpixel segmentation. The whole training code will be coming soon. 34 | 35 | To uese the given model for generate superpixel: 36 | 37 | `git clone https://github.com/zh460045050/LNSNet` 38 | 39 | `cd LNSNet` 40 | 41 | `sh runDemo.sh` 42 | 43 | or 44 | 45 | `python demo.py --n_spix $num_superpixel --img_path $input_img_path --check_path lnsnet_BSDS_checkpoint.pth` 46 | 47 | The performance and complexity of methods for generating 100 superpixel on BSDS test dataset with image size 481*321: 48 | 49 |
![Results](pics/results.png)
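Equivalently, the demo can be driven from Python. The following condensed sketch follows `demo.py`; the `Namespace` fields mirror its argparse defaults, and the image/checkpoint paths are placeholders for illustration:

```python
import torch
import matplotlib.pyplot as plt
from argparse import Namespace
from skimage.segmentation import mark_boundaries
from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.model import LNSN
from libs.test import assignment_test
from libs.utils import preprocess

# mirrors demo.py's argparse defaults
args = Namespace(n_spix=100, kn=16, device='cpu', seed_strategy='network',
                 lr=3e-4, use_gal=True, use_gbl=True, is_dilation=True,
                 img_path='demo_imgs/012.jpg',
                 check_path='./lnsnet_BSDS_checkpoint.pth')

model = LNSN(args.n_spix, args)
model.load_state_dict(torch.load(args.check_path))

img = plt.imread(args.img_path)
inp = preprocess(img, args.device)

with torch.no_grad():
    b, _, h, w = inp.size()
    recons, cx, cy, f, probs = model.forward(inp, torch.zeros(h, w))
    soft = assignment_test(f, inp, cx, cy)                      # (b, N, k) soft assignment
    spix = soft.permute(0, 2, 1).contiguous().view(b, -1, h, w)
    spix = spix.argmax(1).squeeze().numpy()                     # hard labels, (h, w)

# merge tiny fragments so each superpixel is one connected region
seg = spix.size / args.n_spix
spix = _enforce_label_connectivity_cython(spix[None], int(0.06 * seg), int(3.0 * seg))[0]
if img.shape[:2] != spix.shape[-2:]:  # undo the portrait transpose done by preprocess()
    spix = spix.transpose(1, 0)
plt.imsave("result.jpg", mark_boundaries(img, spix, color=(1, 0, 0)))
```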
53 | 54 | 55 | ## Citation 56 | 57 | If you find our work useful in your research, please cite: 58 | 59 | @InProceedings\{Lei_2021_CVPR,
        title = {Learning the Superpixel in a Non-iterative and Lifelong Manner},
        author = {Zhu, Lei and She, Qi and Zhang, Bin and Lu, Yanye and Lu, Zhilin and Li, Duo and Hu, Jie},
        booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month = {June},
        year = {2021}
    }

[arxiv]: https://arxiv.org/abs/2103.10681
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.model import *
from libs.test import *

import time

import random
random.seed(1)
torch.manual_seed(1)

parser = argparse.ArgumentParser()

######Data Setting#######
parser.add_argument('--img_path', default='demo_imgs/012.jpg', help='The path of the source image')
######Model Setting#######
parser.add_argument('--device', default='cpu', help='use cuda or cpu')
parser.add_argument('--n_spix', type=int, default=100, help='The number of superpixels')
parser.add_argument('--kn', type=int, default=16, help='The number of dis limit')
parser.add_argument('--seed_strategy', type=str, default='network', help='network/grid')
#####Optimizing Setting####
parser.add_argument('--lr', type=float, default=3e-4, help='learning rate')
parser.add_argument('--use_gal', default=True, help='Enable the GAL layer of the GRM')
parser.add_argument('--use_gbl', default=True, help='Enable the GBL layer of the GRM')
parser.add_argument('--is_dilation', default=True, help='Use dilated convolution')
parser.add_argument('--check_path', type=str, default='./lnsnet_BSDS_checkpoint.pth')

################

args = parser.parse_args()

model = LNSN(args.n_spix, args)
model.load_state_dict(torch.load(args.check_path))

img = plt.imread(args.img_path)
input = preprocess(img, args.device)

with torch.no_grad():
    b, _, h, w = input.size()
    recons, cx, cy, f, probs = model.forward(input, torch.zeros(h, w))
    spix = assignment_test(f, input, cx, cy)

    spix = spix.permute(0, 2, 1).contiguous().view(b, -1, h, w)
    spix = spix.argmax(1).squeeze().to("cpu").detach().numpy()

# merge fragments so every label forms one connected region
segment_size = spix.size / args.n_spix
min_size = int(0.06 * segment_size)
max_size = int(3.0 * segment_size)
spix = _enforce_label_connectivity_cython(spix[None], min_size, max_size)[0]

# preprocess() transposes portrait images; undo that for visualization
if img.shape[:2] != spix.shape[-2:]:
    spix = spix.transpose(1, 0)

write_img = mark_boundaries(img, spix, color=(1, 0, 0))

plt.imsave("result_" + args.img_path.split('/')[-1], write_img)
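As a concrete check of the post-processing above, the numbers below just instantiate the `min_size`/`max_size` formulas from `demo.py` for the 481×321 BSDS setting mentioned in the README:

```python
# For a 481x321 BSDS image and n_spix = 100:
segment_size = 481 * 321 / 100        # ~1544 px: average superpixel area
min_size = int(0.06 * segment_size)   # 92: fragments smaller than this are merged away
max_size = int(3.0 * segment_size)    # 4632: upper size bound passed to the connectivity pass
```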
--------------------------------------------------------------------------------
/libs/utils.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

import sys
import errno
import shutil
import json
import os.path as osp
from skimage import color
from skimage.segmentation._slic import _enforce_label_connectivity_cython

import random
random.seed(1)
torch.manual_seed(1)


def preprocess(image, device="cuda"):
    # convert to LAB and scale the channels to roughly unit range
    image = color.rgb2lab(image)

    image[:, :, 0] = image[:, :, 0] / float(128.0)
    image[:, :, 1] = image[:, :, 1] / float(256.0)
    image[:, :, 2] = image[:, :, 2] / float(256.0)

    image = torch.from_numpy(image).permute(2, 0, 1).float()[None]
    h, w = image.shape[-2:]
    if h > w:
        # the model expects landscape orientation; transpose portrait images
        image = image.permute(0, 1, 3, 2)
        h, w = image.shape[-2:]
    # append normalized (x, y) coordinates as two extra channels
    coord = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w))).float()[None]
    coord[:, 0, :, :] = coord[:, 0, :, :] / float(h) - 0.5
    coord[:, 1, :, :] = coord[:, 1, :, :] / float(w) - 0.5
    input = torch.cat([image, coord], 1).to(device)
    # channel-wise standardization
    input = (input - input.mean((2, 3), keepdim=True)) / input.std((2, 3), keepdim=True)

    return input


def drawCenter(image, cxs, cys):
    for cx, cy in zip(cxs, cys):
        cv2.circle(image, (cy, cx), 2, (0, 0, 1.0), 4)
    return image


def read_list(listPath):
    images = []
    with open(listPath, 'r') as file_to_read:
        while True:
            lines = file_to_read.readline()  # read one full line at a time
            if not lines:
                break
            path = lines[:-1]
            images.append(path)
    return images

class Logger(object):
    """
    Write console output to external text file.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()


def mkdir_if_missing(directory):
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
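A quick sanity check of `preprocess` (the demo image name is taken from this repo; any RGB image works): it returns a 5-channel tensor, LAB color plus normalized coordinates, standardized per channel:

```python
import matplotlib.pyplot as plt
from libs.utils import preprocess

img = plt.imread('demo_imgs/012.jpg')  # RGB image, e.g. (321, 481, 3)
inp = preprocess(img, device='cpu')
print(inp.shape)         # torch.Size([1, 5, 321, 481]): 3 LAB channels + 2 coordinate channels
print(inp.mean((2, 3)))  # approximately zero per channel after standardization
```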
--------------------------------------------------------------------------------
/libs/layers/grm.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.utils import *

import random
random.seed(1)
torch.manual_seed(1)


class GALFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, Lambda, w):
        # identity in the forward pass; the rescaling happens in backward()
        ctx.save_for_backward(input_, Lambda, w)
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):
        inputs, Lambda, w = ctx.saved_tensors  # pragma: no cover
        grad_input = None
        if ctx.needs_input_grad[0]:
            # per-channel importance from the reconstruction weights:
            # color part (first 3 output rows) x spatial part (last 2 rows)
            w_cp = w.squeeze()
            F_in, F_out = w_cp.size()
            w_c = torch.abs(w_cp[:3, :])
            w_s = torch.abs(w_cp[3:, :])
            w_c = torch.mean(w_c, 0)
            w_s = torch.mean(w_s, 0)
            dw = w_s * w_c
            red = dw / (Lambda + dw)
            Lambda = 0.5 * dw + 0.5 * Lambda
            # shrink the gradient of important channels to protect them
            grad_input = grad_output * (1 - red.unsqueeze(-1).unsqueeze(-1).unsqueeze(0))
        return grad_input, None, None


class GBLFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_, bd_map):
        ctx.save_for_backward(input_, bd_map)
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):
        inputs, bd_map = ctx.saved_tensors  # pragma: no cover
        grad_input = None
        if ctx.needs_input_grad[0]:
            bd_map = 1 - bd_map
            # keep the positional gradient near annotated boundaries (lamda = 1)
            # and reverse it elsewhere (lamda = -bd_map)
            lamda = -1 * bd_map
            lamda[bd_map < 0.7] = 1
            grad_input = grad_output.clone()  # avoid modifying grad_output in place
            grad_input[:, 3, :, :] = grad_input[:, 3, :, :] * lamda
            grad_input[:, 4, :, :] = grad_input[:, 4, :, :] * lamda
        return grad_input, None


class GBLayer(nn.Module):
    def __init__(self):
        """
        Gradient bi-direction layer (GBL).
        This layer has no parameters and is the identity in the forward pass;
        in the backward pass it rescales the gradient of the two positional
        channels according to the boundary map.
        """
        super().__init__()

    def forward(self, f, bd):
        return GBLFunction.apply(f, bd)


class GALayer(nn.Module):
    def __init__(self):
        """
        Gradient adaptive layer (GAL).
        This layer has no parameters and is the identity in the forward pass;
        in the backward pass it adaptively shrinks each channel's gradient
        based on the reconstruction weights.
        """
        super().__init__()

    def forward(self, f, Lambda, w):
        return GALFunction.apply(f, Lambda, w)


class GRM(nn.Module):

    def __init__(self, args):
        super().__init__()

        self.use_gal = args.use_gal
        self.use_gbl = args.use_gbl
        if args.use_gal:
            # running channel-importance estimate; a plain non-trainable tensor
            # (torch.autograd.Variable(..., volatile=True) is deprecated)
            self.Lambda = torch.zeros(20 * 1 * 1, dtype=torch.float32)

        self.recons = nn.Conv2d(20, 5, 1)

        if args.use_gbl:
            self.gbl = GBLayer()

        if args.use_gal:
            self.gal = GALayer()

    def forward(self, f, bd):
        recons = self.recons(f)

        if self.use_gbl:
            recons = self.gbl(recons, bd)

        if self.use_gal:
            f = self.gal(f, self.Lambda, self.recons.weight)

        return f, recons
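To see what the GBL actually does to gradients, here is a small self-contained probe (not part of the repo) that runs one backward pass through `GBLayer`:

```python
import torch

layer = GBLayer()
x = torch.randn(1, 5, 4, 4, requires_grad=True)
bd = torch.zeros(4, 4)
bd[:, 2] = 1.0                 # pretend column 2 is an annotated boundary

layer(x, bd).sum().backward()  # upstream gradient is all ones

print(x.grad[0, 0])  # color channel: gradient passes through unchanged (all ones)
print(x.grad[0, 3])  # positional channel: +1 on the boundary column, -1 elsewhere
```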
--------------------------------------------------------------------------------
/libs/layers/seed.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.io as io
import cv2

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

from skimage.segmentation._slic import _enforce_label_connectivity_cython

from libs.utils import *

import random
random.seed(1)
torch.manual_seed(1)


class SeedGenerater(nn.Module):

    def __init__(self, n_spix, device, seed_strategy='grid'):
        super().__init__()

        self.c3_inorm_1 = nn.InstanceNorm2d(20, affine=True)
        self.c3_seeds_1 = nn.Conv2d(20, 20, 3, padding=1)
        self.c3_seeds_2 = nn.Conv2d(20, 3, 1)

        self.relu = nn.ReLU()

        self.sp_num = n_spix
        self.device = device
        self.seed_strategy = seed_strategy

    def seed_generate(self, spix):

        b, _, h, w = spix.size()

        # grid cell size so that roughly sp_num cells tile the image
        S = h * w / self.sp_num
        sp_h = np.int32(np.floor(np.sqrt(S) / (w / float(h))))
        sp_w = np.int32(np.floor(S / np.floor(sp_h)))

        # one 3-channel prediction per grid cell: seed probability + (dx, dy) shift
        spix = nn.AdaptiveAvgPool2d((np.int32(np.ceil(h / sp_h)), np.int32(np.ceil(w / sp_w))))(spix)
        spix = self.c3_seeds_1(spix)
        spix = self.c3_inorm_1(spix)
        spix = self.relu(spix)
        spix = self.c3_seeds_2(spix)

        prob = spix[:, 0].view(b, -1)  # probability for seed
        prob = torch.sigmoid(prob)
        dx = spix[:, 1].view(b, -1)    # x shift for seed
        dy = spix[:, 2].view(b, -1)    # y shift for seed
        dx = torch.sigmoid(dx) - 0.5
        dy = torch.sigmoid(dy) - 0.5

        prob = prob.view(b, -1)
        ####Take the center of each grid cell as the base seed#####
        sp_c = []
        for i in range(0, h, sp_h):
            for j in range(0, w, sp_w):
                start_x = i
                end_x = min(i + sp_h, h) - 1
                len_x = end_x - start_x + 1
                start_y = j
                end_y = min(j + sp_w, w) - 1
                len_y = end_y - start_y + 1

                x = (end_x + start_x) / 2.0
                y = (end_y + start_y) / 2.0

                ind = x * w + y  # flatten (row, col) to a linear index
                sp_c.append(ind)

        sp_c = torch.from_numpy(np.array(sp_c)).long()

        o_cind = sp_c
        o_cx = torch.floor(o_cind / float(w))
        o_cy = torch.floor(o_cind - o_cx * w)
        if self.device == 'cuda':
            o_cx = o_cx.cuda()
            o_cy = o_cy.cuda()
        # move each base seed by the predicted shift (at most one cell in each direction)
        cx = torch.floor(o_cx + dx.view(-1) * sp_h * 2)
        cy = torch.floor(o_cy + dy.view(-1) * sp_w * 2)

        cx = cx.clamp(0, h - 1)
        cy = cy.clamp(0, w - 1)

        return cx, cy, prob

    def grid_seed(self, spix):

        b, _, h, w = spix.size()

        S = h * w / self.sp_num
        sp_h = np.int32(np.floor(np.sqrt(S) / (w / float(h))))
        sp_w = np.int32(np.floor(S / np.floor(sp_h)))

        ####Take the center of each grid cell as the seed#####
        sp_c = []
        for i in range(0, h, sp_h):
            for j in range(0, w, sp_w):
                start_x = i
                end_x = min(i + sp_h, h) - 1
                len_x = end_x - start_x + 1
                start_y = j
                end_y = min(j + sp_w, w) - 1
                len_y = end_y - start_y + 1

                x = (end_x + start_x) / 2.0
                y = (end_y + start_y) / 2.0

                ind = x * w + y
                sp_c.append(ind)

        sp_c = torch.from_numpy(np.array(sp_c)).long()

        o_cind = sp_c
        o_cx = torch.floor(o_cind / float(w))
        o_cy = torch.floor(o_cind - o_cx * w)
        if self.device == 'cuda':
            o_cx = o_cx.cuda()
            o_cy = o_cy.cuda()

        cx = o_cx.clamp(0, h - 1)
        cy = o_cy.clamp(0, w - 1)

        return cx, cy, torch.ones(b, h * w)

    def forward(self, x):
        if self.seed_strategy == 'network':
            cx, cy, probs = self.seed_generate(x)
        elif self.seed_strategy == 'grid':
            cx, cy, probs = self.grid_seed(x)
        return cx, cy, probs
--------------------------------------------------------------------------------
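As a worked example of the grid geometry above, the numbers below simply instantiate the formulas from `seed_generate` for the 321×481 BSDS setting used elsewhere in this repo (481×321 images are transposed to landscape by `preprocess`):

```python
import numpy as np

h, w, sp_num = 321, 481, 100
S = h * w / sp_num                                      # ~1544 px per superpixel
sp_h = np.int32(np.floor(np.sqrt(S) / (w / float(h))))  # 26: grid step along rows
sp_w = np.int32(np.floor(S / np.floor(sp_h)))           # 59: grid step along columns

rows = len(range(0, h, sp_h))  # 13
cols = len(range(0, w, sp_w))  # 9
print(rows * cols)             # 117 base seeds before the network shifts them
```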