├── demo_imgs
│   ├── 012.jpg
│   ├── 018.jpg
│   ├── 238.jpg
│   ├── 38092.jpg
│   ├── 41029.jpg
│   ├── 42049.jpg
│   ├── 107793.jpg
│   ├── 178279.jpg
│   ├── 192566.jpg
│   ├── 2007_001185.jpg
│   ├── 2007_007151.jpg
│   └── 2008_004532.jpg
├── pics
│   ├── results.png
│   ├── strategy.png
│   └── structures.png
├── libs
│   ├── EvalSPModule.so
│   ├── __pycache__
│   │   ├── model.cpython-37.pyc
│   │   ├── test.cpython-37.pyc
│   │   ├── train.cpython-37.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── update.cpython-37.pyc
│   │   └── assignment.cpython-37.pyc
│   ├── layers
│   │   ├── __pycache__
│   │   │   ├── grm.cpython-37.pyc
│   │   │   ├── agrl.cpython-37.pyc
│   │   │   ├── seed.cpython-37.pyc
│   │   │   └── embedder.cpython-37.pyc
│   │   ├── embedder.py
│   │   ├── grm.py
│   │   └── seed.py
│   ├── test.py
│   ├── model.py
│   └── utils.py
├── lnsnet_BSDS_checkpoint.pth
├── runDemo.sh
├── README.md
└── demo.py
/demo_imgs/012.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/012.jpg
--------------------------------------------------------------------------------
/demo_imgs/018.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/018.jpg
--------------------------------------------------------------------------------
/demo_imgs/238.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/238.jpg
--------------------------------------------------------------------------------
/pics/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/pics/results.png
--------------------------------------------------------------------------------
/pics/strategy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/pics/strategy.png
--------------------------------------------------------------------------------
/demo_imgs/38092.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/38092.jpg
--------------------------------------------------------------------------------
/demo_imgs/41029.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/41029.jpg
--------------------------------------------------------------------------------
/demo_imgs/42049.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/42049.jpg
--------------------------------------------------------------------------------
/pics/structures.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/pics/structures.png
--------------------------------------------------------------------------------
/demo_imgs/107793.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/107793.jpg
--------------------------------------------------------------------------------
/demo_imgs/178279.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/178279.jpg
--------------------------------------------------------------------------------
/demo_imgs/192566.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/192566.jpg
--------------------------------------------------------------------------------
/libs/EvalSPModule.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/EvalSPModule.so
--------------------------------------------------------------------------------
/demo_imgs/2007_001185.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/2007_001185.jpg
--------------------------------------------------------------------------------
/demo_imgs/2007_007151.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/2007_007151.jpg
--------------------------------------------------------------------------------
/demo_imgs/2008_004532.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/demo_imgs/2008_004532.jpg
--------------------------------------------------------------------------------
/lnsnet_BSDS_checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/lnsnet_BSDS_checkpoint.pth
--------------------------------------------------------------------------------
/libs/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/test.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/update.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/__pycache__/assignment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/__pycache__/assignment.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/grm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/grm.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/agrl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/agrl.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/seed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/seed.cpython-37.pyc
--------------------------------------------------------------------------------
/libs/layers/__pycache__/embedder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh460045050/LNSNet/HEAD/libs/layers/__pycache__/embedder.cpython-37.pyc
--------------------------------------------------------------------------------
/runDemo.sh:
--------------------------------------------------------------------------------
1 | python demo.py --n_spix 100 \
2 | --img_path "./demo_imgs/107793.jpg" \
3 | --check_path "./lnsnet_BSDS_checkpoint.pth" \
4 | --seed_strategy 'network'
5 |
--------------------------------------------------------------------------------
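Note: runDemo.sh segments a single image, and each run of demo.py writes its boundary overlay to result_<image name> in the working directory. A small batch-helper sketch (not part of the repository; it simply repeats the runDemo.sh flags for every demo image):

import os
import subprocess

# Run demo.py once per demo image, with the same flags as runDemo.sh.
for name in sorted(os.listdir("demo_imgs")):
    if name.endswith(".jpg"):
        subprocess.run(
            ["python", "demo.py",
             "--n_spix", "100",
             "--img_path", os.path.join("demo_imgs", name),
             "--check_path", "./lnsnet_BSDS_checkpoint.pth",
             "--seed_strategy", "network"],
            check=True,
        )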
/libs/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import scipy.io as io
6 | import cv2
7 | import time
8 |
9 | import os
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from skimage.segmentation import mark_boundaries
14 |
15 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
16 |
17 | from libs.model import *
18 | from libs.utils import *
19 |
20 | import random
21 | random.seed(1)
22 | torch.manual_seed(1)
23 |
24 | def assignment_test(f, input, cx, cy, alpha=1):
25 |     # Soft-assign every pixel embedding in f to the seeds located at
26 |     # (cx, cy), using a Student's t-kernel:
27 |     #   q_ik = (1 + ||f_i - c_k||^2 / alpha)^(-(alpha+1)/2), normalized over k.
28 | 
29 |     b, _, h, w = input.size()
30 |     p = input[:, 3:, :, :]          # the two coordinate channels
31 | 
32 |     p = p.view(b, 2, -1)
33 |     cind = cx * w + cy              # flatten (row, col) seed positions
34 |     cind = cind.long()
35 |     c_p = p[:, :, cind]             # seed coordinates
36 |     c_f = f[:, :, cind]             # seed embeddings, (b, c, k)
37 | 
38 |     _, c, k = c_f.size()
39 | 
40 |     N = h*w
41 | 
42 |     # squared Euclidean distance from every pixel to every seed
43 |     dis = torch.zeros(b, k, N)
44 |     for i in range(0, k):
45 |         cur_c_f = c_f[:, :, i].unsqueeze(-1).expand(b, c, N)
46 |         cur_p_ij = cur_c_f - f
47 |         cur_p_ij = torch.pow(cur_p_ij, 2)
48 |         cur_p_ij = torch.sum(cur_p_ij, dim=1)
49 |         dis[:, i, :] = cur_p_ij
50 |     # Student's t-kernel, then normalize over the k seeds
51 |     dis = dis / alpha
52 |     dis = torch.pow((1 + dis), -(alpha + 1) / 2)
53 |     dis = dis.view(b, k, N).permute(0, 2, 1).contiguous()  # (b, N, k)
54 |     dis = dis / torch.sum(dis, dim=2).unsqueeze(-1)
55 | 
56 |     return dis
57 | 
--------------------------------------------------------------------------------
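The loop in assignment_test computes a Student's t-kernel soft assignment, q_ik = (1 + ||f_i - c_k||^2 / alpha)^(-(alpha + 1) / 2), normalized over the k seeds. A minimal equivalent sketch (not part of the repository), vectorized with torch.cdist:

import torch

def soft_assignment_vectorized(f, c_f, alpha=1.0):
    """f: (b, c, N) pixel embeddings; c_f: (b, c, k) seed embeddings."""
    # Squared Euclidean distance from every seed to every pixel: (b, k, N).
    dis = torch.cdist(c_f.permute(0, 2, 1), f.permute(0, 2, 1)) ** 2
    dis = torch.pow(1 + dis / alpha, -(alpha + 1) / 2)  # Student's t-kernel
    dis = dis.permute(0, 2, 1)                          # (b, N, k)
    return dis / dis.sum(dim=2, keepdim=True)           # normalize over seeds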
/libs/layers/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import scipy.io as io
6 | import cv2
7 |
8 | import os
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | from skimage.segmentation import mark_boundaries
13 |
14 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
15 |
16 | from libs.utils import *
17 |
18 | import random
19 | random.seed(1)
20 | torch.manual_seed(1)
21 |
22 |
23 | class Embedder(nn.Module):
24 |
25 | def __init__(self, is_dilation=True):
26 | super().__init__()
27 |
28 | self.is_dilation = is_dilation
29 |
30 |
31 | self.rpad_1 = nn.ReflectionPad2d(1)
32 | if self.is_dilation:
33 | self.c1_1 = nn.Conv2d(5, 10, 3, padding=0)
34 | self.c1_2 = nn.Conv2d(5, 10, 3, padding=0, dilation=1)
35 | self.c1_3 = nn.Conv2d(5, 10, 3, padding=1, dilation=2)
36 | self.c1_4 = nn.Sequential(nn.InstanceNorm2d(35, affine=True), nn.ReLU(), nn.ReflectionPad2d(1), nn.Conv2d(35, 10, 3, padding=0))
37 | else:
38 | self.c1 = nn.Conv2d(5, 10, 3, padding=0)
39 | self.inorm_1 = nn.InstanceNorm2d(10, affine=True)
40 | #self.inorm_1 = nn.BatchNorm2d(10, affine=True)
41 |
42 | self.rpad_2 = nn.ReflectionPad2d(1)
43 | self.c2 = nn.Conv2d(15, 20, 3, padding=0)
44 | #self.inorm_2 = nn.BatchNorm2d(20, affine=True)
45 | self.inorm_2 = nn.InstanceNorm2d(20, affine=True)
46 |
47 | self.relu = nn.ReLU()
48 |
49 |
50 | def forward(self, x):
51 |
52 | spix = self.rpad_1(x)
53 | if self.is_dilation:
54 | spix_1 = self.c1_1(spix)
55 | spix_2 = self.c1_2(spix)
56 | spix_3 = self.c1_3(spix)
57 | spix = torch.cat([x, spix_1, spix_2, spix_3], dim=1)
58 | spix = self.c1_4(spix)
59 | else:
60 | spix = self.c1(spix)
61 |
62 | spix = self.inorm_1(spix)
63 | spix = self.relu(spix)
64 |
65 | spix = torch.cat((spix, x), dim=1)
66 |
67 | spix = self.rpad_2(spix)
68 | spix = self.c2(spix)
69 | spix = self.inorm_2(spix)
70 | spix = self.relu(spix)
71 |
72 | return spix
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
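With is_dilation=True the embedder concatenates the 5-channel input with three 10-channel convolution branches (5 + 3 x 10 = 35 channels) and projects back to 10 channels, then concatenates the input again (15 channels) before the final 20-channel convolution; reflection padding keeps the spatial size unchanged throughout. A quick shape check (a sketch, assuming the repository root is on the import path):

import torch
from libs.layers.embedder import Embedder

emb = Embedder(is_dilation=True)
x = torch.randn(1, 5, 64, 96)  # (b, 5, h, w): 3 Lab channels + 2 coordinate channels
f = emb(x)
print(f.shape)                 # torch.Size([1, 20, 64, 96])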
/libs/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import scipy.io as io
6 | import cv2
7 |
8 | import os
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | from skimage.segmentation import mark_boundaries
13 |
14 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
15 |
16 | from libs.layers.grm import *
17 | from libs.layers.embedder import *
18 | from libs.layers.seed import *
19 |
20 | import random
21 | random.seed(1)
22 | torch.manual_seed(1)
23 |
24 |
25 | class LNSN(nn.Module):
26 |
27 | def __init__(self, n_spix, args):
28 | super().__init__()
29 | self.n_spix = n_spix
30 | self.sp_num = n_spix
31 | self.is_dilation = args.is_dilation
32 | self.device = args.device
33 |
34 | self.seed_strategy = args.seed_strategy
35 |
36 |
37 |         self.is_train = True  # boolean flag; a 'train' attribute would shadow nn.Module.train()
38 | self.kn = args.kn
39 | ###########Optimizer Parameter##########
40 |
41 | self.embedder = Embedder(self.is_dilation)
42 |
43 | self.generater = SeedGenerater(self.n_spix, self.device, seed_strategy=self.seed_strategy)
44 |
45 | self.grm = GRM(args)
46 |
47 |
48 | #############init#######################
49 |
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv2d):
52 | nn.init.xavier_normal_(m.weight)
53 | if m.bias is not None:
54 | nn.init.constant_(m.bias, 0)
55 | elif isinstance(m, nn.InstanceNorm2d):
56 | nn.init.constant_(m.weight, 1)
57 | nn.init.constant_(m.bias, 0)
58 | elif isinstance(m, nn.BatchNorm2d):
59 | nn.init.constant_(m.weight, 1)
60 | nn.init.constant_(m.bias, 0)
61 |
62 | def forward(self, x, bd):
63 |
64 |
65 | b, _, h, w = x.size()
66 | ##########Feature Extracting###########
67 | f = self.embedder(x)
68 |
69 | #########Seed Generate#######
70 | cx, cy, probs = self.generater(f)
71 |
72 |
73 |         if self.is_train:
74 |
75 | f, recons = self.grm(f, bd)
76 |
77 | f = f.view(b, -1, h*w)
78 |
79 | return recons, cx, cy, f, probs
80 | else:
81 | f = f.view(b, -1, h*w)
82 | return cx, cy, f, probs
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
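A minimal forward-pass sketch (not part of the repository; the argparse.Namespace fields mirror demo.py's defaults, and the 'grid' seed strategy is used so no trained weights are required):

import argparse
import torch
from libs.model import LNSN

args = argparse.Namespace(device='cpu', is_dilation=True, seed_strategy='grid',
                          kn=16, use_gal=True, use_gbl=True)
model = LNSN(100, args)
x = torch.randn(1, 5, 64, 96)                   # preprocessed input (see libs/utils.py)
recons, cx, cy, f, probs = model(x, torch.zeros(64, 96))
print(f.shape)                                  # torch.Size([1, 20, 6144]): flattened pixel embeddings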
/README.md:
--------------------------------------------------------------------------------
1 | # LNSNet
2 | ## Overview
3 |
4 | Official implementation of [Learning the Superpixel in a Non-iterative and Lifelong Manner][arxiv] (CVPR'21)
5 |
6 | ### Learning Strategy
7 |
8 | The proposed LNSNet views the superpixel segmentation of each image as **an independent pixel-level clustering task** and uses a **lifelong learning strategy** to train the superpixel segmentation network over a series of images.
9 |
10 |
11 | ![strategy](pics/strategy.png)
12 |
13 |
14 |
15 | ### Model Structure
16 |
17 | The structure of the proposed LNS-Net, shown in Fig. 3, contains three parts:
18 |
19 | 1) **Feature Embedder Module (FEM)** that embeds the original feature into a cluster-friendly space;
20 | 2) **Non-iteratively Clustering Module (NCM)** that assigns labels to pixels with the help of a seed estimation module, which automatically estimates
21 | the indexes of the seed nodes;
22 | 3) **Gradient Rescaling Module (GRM)** that adaptively rescales the gradient of each weight parameter based on the channel and spatial context, to avoid catastrophic forgetting
23 | during sequential learning.
24 |
25 |
26 |
27 | ![structures](pics/structures.png)
28 |
29 |
30 |
31 | ## Getting Started
32 |
33 | Here we only release the model trained on the BSDS dataset and the corresponding code to use it for superpixel segmentation. The full training code is coming soon.
34 |
35 | To use the given model to generate superpixels:
36 |
37 | `git clone https://github.com/zh460045050/LNSNet`
38 |
39 | `cd LNSNet`
40 |
41 | `sh runDemo.sh`
42 |
43 | or
44 |
45 | `python demo.py --n_spix $num_superpixel --img_path $input_img_path --check_path lnsnet_BSDS_checkpoint.pth`
46 |
47 | The performance and complexity of methods generating 100 superpixels on the BSDS test set (image size 481×321):
48 |
49 |
50 |
51 | ![results](pics/results.png)
52 |
53 |
54 |
55 | ## Citation
56 |
57 | If you find our work useful in your research, please cite:
58 |
59 |     @InProceedings{Lei_2021_CVPR,
60 |         title     = {Learning the Superpixel in a Non-iterative and Lifelong Manner},
61 |         author    = {Zhu, Lei and She, Qi and Zhang, Bin and Lu, Yanye and Lu, Zhilin and Li, Duo and Hu, Jie},
62 |         booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
63 |         month     = {June},
64 |         year      = {2021}
65 |     }
66 |
67 |
68 | [arxiv]: https://arxiv.org/abs/2103.10681
69 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import scipy.io as io
6 | import cv2
7 |
8 | 
9 | import os
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from skimage.segmentation import mark_boundaries
14 |
15 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
16 |
17 | from libs.model import *
18 | from libs.test import *
19 |
20 | import time
21 |
22 | import random
23 | random.seed(1)
24 | torch.manual_seed(1)
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | ######Data Setting#######
29 | parser.add_argument('--img_path', default='demo_imgs/012.jpg', help='The path of the source image')
30 | ######Model Setting#######
31 | parser.add_argument('--device', default='cpu', help='use cuda or cpu')
32 | parser.add_argument('--n_spix', type=int, default=100, help='The number of superpixels')
33 | parser.add_argument('--kn', type=int, default=16, help='The number of dis limit')
34 | parser.add_argument('--seed_strategy', type=str, default='network', help='network/grid')
35 | #####Optimizing Setting####
36 | parser.add_argument('--lr', type=float, default=3e-4, help='learning rate')
37 | parser.add_argument('--use_gal', default=True, help='Enable the GAL gradient rescaling (see libs/layers/grm.py)')
38 | parser.add_argument('--use_gbl', default=True, help='Enable the GBL boundary-aware gradient rescaling (see libs/layers/grm.py)')
39 | parser.add_argument('--is_dilation', default=True, help='Using dilation convolution')
40 | parser.add_argument('--check_path', type=str, default='./lnsnet_BSDS_checkpoint.pth')
41 |
42 | ################
43 |
44 | args = parser.parse_args()
45 |
46 |
47 | model = LNSN(args.n_spix, args)
48 |
49 | model.load_state_dict(torch.load(args.check_path, map_location=args.device))
50 |
51 | img = plt.imread(args.img_path)
52 | input = preprocess(img, args.device)
53 |
54 |
55 | with torch.no_grad():
56 |
57 | b, _, h, w = input.size()
58 | recons, cx, cy, f, probs = model.forward(input, torch.zeros(h, w))
59 | spix = assignment_test(f, input, cx, cy)
60 |
61 | spix = spix.permute(0, 2, 1).contiguous().view(b, -1, h, w)
62 | spix = spix.argmax(1).squeeze().to("cpu").detach().numpy()
63 |
64 |
65 | segment_size = spix.size / args.n_spix
66 | min_size = int(0.06 * segment_size)
67 | max_size = int(3.0 * segment_size)
68 | spix = _enforce_label_connectivity_cython(spix[None], min_size, max_size)[0]
69 |
70 | if img.shape[:2] != spix.shape[-2:]:
71 | spix = spix.transpose(1, 0)
72 |
73 | write_img = mark_boundaries(img, spix, color=(1, 0, 0))
74 |
75 | plt.imsave("result_" + args.img_path.split('/')[-1], write_img)
76 |
77 |
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
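For scale, the connectivity post-processing above works out as follows for a 481×321 BSDS image with the default n_spix=100 (the numbers follow directly from the code):

segment_size = 481 * 321 / 100       # 1544.01 pixels per superpixel on average
min_size = int(0.06 * segment_size)  # 92: smaller segments get merged away
max_size = int(3.0 * segment_size)   # 4632: upper bound applied while merging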
/libs/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import scipy.io as io
7 | import cv2
8 |
9 | import os
10 | import argparse
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from skimage.segmentation import mark_boundaries
14 |
15 | import sys
16 | import errno
17 | import shutil
18 | import json
19 | import os.path as osp
20 | from skimage import color
21 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
22 |
23 | import random
24 | random.seed(1)
25 | torch.manual_seed(1)
26 |
27 |
28 | def preprocess(image, device="cuda"):
29 | #image = torch.from_numpy(image).permute(2, 0, 1).float()[None]
30 | #h, w = image.shape[-2:]
31 | #coord = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w))).float()[None]
32 |
33 | #coord[:, 0, : , :] = coord[:, 0, : , :] / np.float(h)
34 | #coord[:, 1, : , :] = coord[:, 1, : , :] / np.float(w)
35 | #image = image / 255.0
36 |
37 |
38 | #input = torch.cat([image, coord], 1).to(device)
39 | #input = (input - input.mean((2, 3), keepdim=True)) / input.std((2, 3), keepdim=True)
40 |
41 | image = color.rgb2lab(image)
42 |
43 |     image[:, :, 0] = image[:, :, 0] / 128.0
44 |     image[:, :, 1] = image[:, :, 1] / 256.0
45 |     image[:, :, 2] = image[:, :, 2] / 256.0
46 |
47 |
48 | image = torch.from_numpy(image).permute(2, 0, 1).float()[None]
49 | h, w = image.shape[-2:]
50 | #print(h, w)
51 | if h > w:
52 | image = image.permute(0, 1, 3, 2)
53 | h, w = image.shape[-2:]
54 | #print(h, w)
55 | coord = torch.stack(torch.meshgrid(torch.arange( h ), torch.arange(w))).float()[None]
56 | #coord = coord / img.shape[-2:
57 |     coord[:, 0, : , :] = coord[:, 0, : , :] / float(h) - 0.5
58 |     coord[:, 1, : , :] = coord[:, 1, : , :] / float(w) - 0.5
59 | #print(coord)
60 | #print(image.shape)
61 | input = torch.cat([image, coord], 1).to(device)
62 | input = (input - input.mean((2, 3), keepdim=True)) / input.std((2, 3), keepdim=True)
63 |
64 | return input
65 |
66 |
67 |
68 |
69 | def drawCenter(image, cxs, cys):
70 |
71 | for cx, cy in zip(cxs, cys):
72 | cv2.circle(image, (cy, cx), 2, (0, 0, 1.0), 4)
73 | return image
74 |
75 |
76 |
77 | def read_list(listPath):
78 | images = []
79 | with open(listPath, 'r') as file_to_read:
80 | while True:
81 |             lines = file_to_read.readline()  # read one full line
82 | if not lines:
83 | break
84 |             # lines[:-1] below drops the trailing newline
85 | path = lines[:-1]
86 | images.append(path)
87 |
88 | return images
89 |
90 |
91 | class Logger(object):
92 | """
93 | Write console output to external text file.
94 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
95 | """
96 | def __init__(self, fpath=None):
97 | self.console = sys.stdout
98 | self.file = None
99 | if fpath is not None:
100 | mkdir_if_missing(os.path.dirname(fpath))
101 | self.file = open(fpath, 'w')
102 |
103 | def __del__(self):
104 | self.close()
105 |
106 | def __enter__(self):
107 | pass
108 |
109 | def __exit__(self, *args):
110 | self.close()
111 |
112 | def write(self, msg):
113 | self.console.write(msg)
114 | if self.file is not None:
115 | self.file.write(msg)
116 |
117 | def flush(self):
118 | self.console.flush()
119 | if self.file is not None:
120 | self.file.flush()
121 | os.fsync(self.file.fileno())
122 |
123 | def close(self):
124 | self.console.close()
125 | if self.file is not None:
126 | self.file.close()
127 |
128 |
129 | def mkdir_if_missing(directory):
130 | if not osp.exists(directory):
131 | try:
132 | os.makedirs(directory)
133 | except OSError as e:
134 | if e.errno != errno.EEXIST:
135 | raise
136 |
137 |
--------------------------------------------------------------------------------
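preprocess() converts the image to Lab, rescales the channels, transposes portrait images (h > w) to landscape, stacks a normalized (x, y) coordinate grid, and standardizes each channel. A quick output check (a sketch, assuming the repository root is on the import path):

import numpy as np
from libs.utils import preprocess

img = (np.random.rand(48, 64, 3) * 255).astype(np.uint8)
inp = preprocess(img, device='cpu')
print(inp.shape)         # torch.Size([1, 5, 48, 64]): 3 Lab + 2 coordinate channels
print(inp.mean((2, 3)))  # ~0 per channel after the final standardization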
/libs/layers/grm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import scipy.io as io
6 | import cv2
7 |
8 | import os
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | from skimage.segmentation import mark_boundaries
13 |
14 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
15 |
16 | from libs.utils import *
17 |
18 | import random
19 | random.seed(1)
20 | torch.manual_seed(1)
21 |
22 |
23 | class GALFunction(torch.autograd.Function):
24 |
25 | @staticmethod
26 | def forward(ctx, input_, Lambda, w):
27 | ctx.save_for_backward(input_, Lambda, w)
28 | output = input_
29 | return output
30 |
31 | @staticmethod
32 | def backward(ctx, grad_output):
33 |
34 |         inputs, Lambda, w = ctx.saved_tensors
35 | grad_input = None
36 | if ctx.needs_input_grad[0]:
37 |
38 | #print(grad_output.shape)
39 | w_cp = w.squeeze()
40 | F_in, F_out = w_cp.size()
41 | w_c = torch.abs(w_cp[:3, :])
42 | w_s = torch.abs(w_cp[3:, :])
43 | w_c = torch.mean(w_c, 0)
44 | w_s = torch.mean(w_s, 0)
45 | #red = dw
46 | dw = w_s * w_c
47 | red = dw / (Lambda + dw)
48 | #print(1 - red)
49 |             Lambda = 0.5 * dw + 0.5 * Lambda  # rebinds the local name only; the stored Lambda tensor is unchanged
50 |
51 | grad_input = grad_output * (1 - red.unsqueeze(-1).unsqueeze(-1).unsqueeze(0))
52 | return grad_input, None, None
53 |
54 |
55 | class GBLFunction(torch.autograd.Function):
56 |
57 | @staticmethod
58 | def forward(ctx, input_, bd_map):
59 | ctx.save_for_backward(input_, bd_map)
60 | output = input_
61 | return output
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 |
66 |         inputs, bd_map = ctx.saved_tensors
67 | grad_input = None
68 | if ctx.needs_input_grad[0]:
69 | bd_map = 1 - bd_map
70 | #rate = 2 / (1 + np.exp(-10 * bd_map)) - 1
71 | lamda = - 1 * bd_map#rate #* torch.mean(bd_map, [0,1])
72 | lamda[bd_map < 0.7] = 1
73 | #lamda[bd_map >= 0.5] = -rate[bd_map >= 0.5]
74 | #print(torch.mean(bd_map, [0,1]))
75 |             grad_input = grad_output.clone()  # clone so grad_output is not modified in place
76 | grad_input[:, 3, :, :] = grad_input[:, 3, :, :] * lamda
77 | grad_input[:, 4, :, :] = grad_input[:, 4, :, :] * lamda
78 | return grad_input, None
79 |
80 |
81 | class GBLayer(nn.Module):
82 | def __init__(self):
83 |         """
84 |         A gradient boundary layer: identity in the forward pass; in the
85 |         backward pass it rescales the gradients of the two coordinate
86 |         channels according to the boundary map.
87 |         """
88 |
89 | super().__init__()
90 |
91 | def forward(self, f, bd):
92 |
93 | return GBLFunction.apply(f, bd)
94 |
95 | class GALayer(nn.Module):
96 | def __init__(self):
97 |         """
98 |         A gradient adaptive layer: identity in the forward pass; in the
99 |         backward pass it adaptively rescales the feature-channel
100 |         gradients based on the reconstruction weights.
101 |         """
102 |
103 | super().__init__()
104 |
105 | def forward(self, f, Lambda, w):
106 |
107 | return GALFunction.apply(f, Lambda, w)
108 |
109 |
110 | class GRM(nn.Module):
111 |
112 | def __init__(self, args):
113 | super().__init__()
114 |
115 | self.use_gal = args.use_gal
116 | self.use_gbl = args.use_gbl
117 | if args.use_gal:
118 |             self.Lambda = torch.zeros(20, dtype=torch.float32)  # running channel-importance estimate, kept outside autograd
119 |
120 | self.recons = nn.Conv2d(20, 5, 1)
121 |
122 | if args.use_gbl:
123 | self.gbl = GBLayer()
124 |
125 | if args.use_gal:
126 | self.gal = GALayer()
127 |
128 |
129 | def forward(self, f, bd):
130 |
131 | recons = self.recons(f)
132 |
133 | if self.use_gbl:
134 | recons = self.gbl(recons, bd)
135 |
136 | if self.use_gal:
137 | f = self.gal(f, self.Lambda, self.recons.weight)
138 |
139 | return f, recons
140 |
141 |
142 |
143 |
144 |
145 |
146 |
--------------------------------------------------------------------------------
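In GALFunction.backward, the gradient of each of the 20 feature channels is scaled by 1 - red_j = Lambda_j / (Lambda_j + dw_j), where dw_j multiplies the mean absolute reconstruction weights of the color rows (w_c) and the spatial rows (w_s); channels that matter more for reconstruction receive smaller gradients. A numeric sketch of that factor (Lambda is set to 0.1 here purely for illustration; the module initializes it to zeros, which makes the factor exactly zero):

import torch

w = torch.randn(5, 20)             # recons.weight squeezed: (5 out, 20 in)
Lambda = torch.full((20,), 0.1)
w_c = torch.abs(w[:3, :]).mean(0)  # mean |weight| over the 3 color output rows
w_s = torch.abs(w[3:, :]).mean(0)  # mean |weight| over the 2 spatial output rows
dw = w_s * w_c
scale = 1 - dw / (Lambda + dw)     # = Lambda / (Lambda + dw), in (0, 1]
print(scale)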
/libs/layers/seed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import scipy.io as io
6 | import cv2
7 |
8 | import os
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | from skimage.segmentation import mark_boundaries
13 |
14 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
15 |
16 | from libs.utils import *
17 |
18 | import random
19 | random.seed(1)
20 | torch.manual_seed(1)
21 |
22 |
23 | class SeedGenerater(nn.Module):
24 |
25 | def __init__(self, n_spix, device, seed_strategy='grid'):
26 | super().__init__()
27 |
28 | self.c3_inorm_1 = nn.InstanceNorm2d(20, affine=True)
29 | self.c3_seeds_1 = nn.Conv2d(20, 20, 3, padding=1)
30 | self.c3_seeds_2 = nn.Conv2d(20, 3, 1)
31 |
32 | self.relu = nn.ReLU()
33 |
34 |
35 | self.sp_num = n_spix
36 | self.device = device
37 | self.seed_strategy = seed_strategy
38 |
39 |
40 | def seed_generate(self, spix):
41 |
42 | b, _, h, w = spix.size()
43 |
44 | S = h * w / self.sp_num
45 |         sp_h = np.int32(np.floor(np.sqrt(S) / (w / float(h))))
46 | sp_w = np.int32(np.floor(S / np.floor(sp_h)))
47 |
48 |
49 |
50 | spix = nn.AdaptiveAvgPool2d((np.int32(np.ceil(h / sp_h)), np.int32(np.ceil(w / sp_w))))(spix)
51 | spix = self.c3_seeds_1(spix)
52 | spix = self.c3_inorm_1(spix)
53 | spix = self.relu(spix)
54 | spix = self.c3_seeds_2(spix)
55 |
56 | prob = spix[:, 0].view(b, -1) #probability for seed
57 | prob = torch.sigmoid(prob)
58 | dx = spix[:, 1].view(b, -1) #x shift for seed
59 | dy = spix[:, 2].view(b, -1) #y shift for seed
60 | dx = torch.sigmoid(dx) - 0.5
61 | dy = torch.sigmoid(dy) - 0.5
62 |
63 |
64 |
65 | prob = prob.view(b, -1)
66 | ####Choosing the max prob in Grid as Seed#####
67 | sp_c = []
68 |
69 | for i in range(0, h, sp_h):
70 | for j in range(0, w, sp_w):
71 | start_x = i
72 | end_x = min(i + sp_h, h) - 1
73 | len_x = end_x - start_x + 1
74 | start_y = j
75 | end_y = min(j + sp_w, w) - 1
76 | len_y = end_y - start_y + 1
77 |
78 | x = (end_x + start_x) / 2.0
79 | y = (end_y + start_y) / 2.0
80 |
81 | ind = x*w + y
82 | sp_c.append(ind)
83 |
84 | sp_c = torch.from_numpy(np.array(sp_c)).long()
85 |
86 | o_cind = sp_c
87 | o_cx = torch.floor(o_cind / float(w))
88 | o_cy = torch.floor(o_cind - o_cx * w)
89 | if self.device == 'cuda':
90 | o_cx = o_cx.cuda()
91 | o_cy = o_cy.cuda()
92 | cx = torch.floor(o_cx + dx.view(-1) * sp_h * 2)
93 | cy = torch.floor(o_cy + dy.view(-1) * sp_w * 2)
94 |
95 |
96 |
97 | #print(dx[:, sp_c].view(-1))
98 | #print(dy[:, sp_c].view(-1))
99 |
100 | cx = cx.clamp(0, h-1)
101 | cy = cy.clamp(0, w-1)
102 |
103 |
104 | return cx, cy, prob
105 |
106 | def grid_seed(self, spix):
107 |
108 | b, _, h, w = spix.size()
109 |
110 | S = h * w / self.sp_num
111 |         sp_h = np.int32(np.floor(np.sqrt(S) / (w / float(h))))
112 | sp_w = np.int32(np.floor(S / np.floor(sp_h)))
113 |
114 | ####Choosing the max prob in Grid as Seed#####
115 | sp_c = []
116 | for i in range(0, h, sp_h):
117 | for j in range(0, w, sp_w):
118 | start_x = i
119 | end_x = min(i + sp_h, h) - 1
120 | len_x = end_x - start_x + 1
121 | start_y = j
122 | end_y = min(j + sp_w, w) - 1
123 | len_y = end_y - start_y + 1
124 |
125 | x = (end_x + start_x) / 2.0
126 | y = (end_y + start_y) / 2.0
127 |
128 |
129 | ind = x*w + y
130 | sp_c.append(ind)
131 |
132 | sp_c = torch.from_numpy(np.array(sp_c)).long()
133 |
134 | o_cind = sp_c
135 | o_cx = torch.floor(o_cind / float(w))
136 | o_cy = torch.floor(o_cind - o_cx * w)
137 | if self.device == 'cuda':
138 | o_cx = o_cx.cuda()
139 | o_cy = o_cy.cuda()
140 |
141 | cx = o_cx.clamp(0, h-1)
142 | cy = o_cy.clamp(0, w-1)
143 |
144 |
145 | return cx, cy, torch.ones(b, h*w)
146 |
147 | def forward(self, x):
148 |
149 | if self.seed_strategy == 'network':
150 | #seed_dis = self.c3_seeds(x)
151 | cx, cy, probs = self.seed_generate(x)
152 |         elif self.seed_strategy == 'grid':
153 |             cx, cy, probs = self.grid_seed(x)
154 |         else:
155 |             raise ValueError("seed_strategy must be 'network' or 'grid'")
156 | 
157 |         return cx, cy, probs
158 | 
--------------------------------------------------------------------------------
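For a BSDS-sized input (h=321, w=481 after preprocess) with n_spix=100, the grid geometry in seed_generate/grid_seed works out as follows; note that the cell grid yields slightly more candidate seeds than requested:

import numpy as np

h, w, sp_num = 321, 481, 100
S = h * w / sp_num                                 # 1544.01 pixels per superpixel
sp_h = int(np.floor(np.sqrt(S) / (w / float(h))))  # 26: grid cell height
sp_w = int(np.floor(S / sp_h))                     # 59: grid cell width
n_cells = len(range(0, h, sp_h)) * len(range(0, w, sp_w))
print(n_cells)                                     # 117 candidate seed positions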