├── LICENSE
├── README.md
├── SSN_deep_result.png
├── SSN_pix_result.png
├── inference.py
├── lib
├── __init__.py
├── dataset
│ ├── __init__.py
│ ├── augmentation.py
│ └── bsds.py
├── ssn
│ ├── __init__.py
│ ├── pair_wise_distance.py
│ ├── pair_wise_distance_cuda_source.py
│ ├── ssn.py
│ └── test.py
└── utils
│ ├── __init__.py
│ ├── loss.py
│ ├── meter.py
│ └── sparse_utils.py
├── model.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 teppei suzuki
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Superpixel Sampling Networks
2 | PyTorch implementation of Superpixel Sampling Networks
3 | paper: https://arxiv.org/abs/1807.10174
4 | original code: https://github.com/NVlabs/ssn_superpixels
5 |
6 | ### Note
7 | A pure PyTorch implementation of the core component, differentiable SLIC, is available [here](https://github.com/perrying/diffSLIC) (note that it implements the similarity function as the cosine similarity instead of the negative Euclidean distance).
8 |
9 | # Requirements
10 | - PyTorch >= 1.4
11 | - scikit-image
12 | - matplotlib
13 |
14 | # Usage
15 | ## inference
16 | SSN_pix
17 | ```
18 | python inference.py --image /path/to/image
19 | ```
20 | SSN_deep
21 | ```
22 | python inference.py --image /path/to/image --weight /path/to/pretrained_weight
23 | ```
24 |
25 | ## training
26 | ```
27 | python train.py --root /path/to/BSDS500
28 | ```
29 |
30 | # Results
31 | SSN_pix
32 |
33 |
34 | SSN_deep
35 |
36 |
--------------------------------------------------------------------------------
/SSN_deep_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/SSN_deep_result.png
--------------------------------------------------------------------------------
/SSN_pix_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/SSN_pix_result.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 |
5 | from skimage.color import rgb2lab
6 | from skimage.segmentation._slic import _enforce_label_connectivity_cython
7 |
8 | from lib.ssn.ssn import sparse_ssn_iter
9 |
10 |
@torch.no_grad()
def inference(image, nspix, n_iter, fdim=None, color_scale=0.26, pos_scale=2.5, weight=None, enforce_connectivity=True):
    """
    Segment an RGB image into superpixels.

    Without ``weight`` the iterations run directly on the scaled Lab + xy
    features (``sparse_ssn_iter``); with ``weight`` an ``SSNModel`` is built,
    loaded from that path and used instead. All tensors are created on "cuda".

    Args:
        image: numpy.ndarray
            An array of shape (h, w, c)
        nspix: int
            number of superpixels
        n_iter: int
            number of iterations
        fdim (optional): int
            feature dimension, only used when ``weight`` is given
        color_scale: float
            factor applied to the Lab channels
        pos_scale: float
            factor applied to the pixel coordinates (further rescaled by the
            superpixel grid density, see below)
        weight: str
            path to a pretrained state_dict (passed to ``torch.load``)
        enforce_connectivity: bool
            if True, enforce superpixel connectivity in postprocessing

    Return:
        labels: numpy.ndarray
            An array of shape (h, w)
    """
    if weight is None:
        # unsupervised variant: plain differentiable SLIC on the raw features
        def model(data):
            return sparse_ssn_iter(data, nspix, n_iter)
    else:
        from model import SSNModel
        model = SSNModel(fdim, nspix, n_iter).to("cuda")
        model.load_state_dict(torch.load(weight))
        model.eval()

    height, width = image.shape[:2]

    # normalise the positional weight by the number of superpixels per axis
    grid_side = int(math.sqrt(nspix))
    pos_scale = pos_scale * max(grid_side / height, grid_side / width)

    rows = torch.arange(height, device="cuda")
    cols = torch.arange(width, device="cuda")
    coords = torch.stack(torch.meshgrid(rows, cols), 0)
    coords = coords[None].float()

    lab = rgb2lab(image)
    lab = torch.from_numpy(lab).permute(2, 0, 1)[None].to("cuda").float()

    # 5-channel input: scaled Lab followed by scaled (row, col) coordinates
    inputs = torch.cat([color_scale * lab, pos_scale * coords], 1)

    _, hard_labels, _ = model(inputs)

    labels = hard_labels.reshape(height, width).to("cpu").detach().numpy()

    if not enforce_connectivity:
        return labels

    # size bounds are expressed relative to the average superpixel area
    segment_size = height * width / nspix
    min_size = int(0.06 * segment_size)
    max_size = int(3.0 * segment_size)
    labels = _enforce_label_connectivity_cython(
        labels[None], min_size, max_size)[0]

    return labels
71 |
72 |
if __name__ == "__main__":
    # Command-line entry point: segment one image and save the boundary overlay
    # to "results.png".
    import time
    import argparse
    import matplotlib.pyplot as plt
    from skimage.segmentation import mark_boundaries

    parser = argparse.ArgumentParser()
    # the image path is mandatory; declaring it required gives a clear usage
    # error instead of a failure inside plt.imread(None)
    parser.add_argument("--image", type=str, required=True, help="/path/to/image")
    parser.add_argument("--weight", default=None, type=str, help="/path/to/pretrained_weight")
    parser.add_argument("--fdim", default=20, type=int, help="embedding dimension")
    parser.add_argument("--niter", default=10, type=int, help="number of iterations for differentiable SLIC")
    parser.add_argument("--nspix", default=100, type=int, help="number of superpixels")
    parser.add_argument("--color_scale", default=0.26, type=float)
    parser.add_argument("--pos_scale", default=2.5, type=float)
    args = parser.parse_args()

    image = plt.imread(args.image)
    # PNG files can be decoded as RGBA; rgb2lab only accepts 3 channels,
    # so the alpha channel is dropped before segmentation
    if image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]

    s = time.time()
    label = inference(image, args.nspix, args.niter, args.fdim, args.color_scale, args.pos_scale, args.weight)
    print(f"time {time.time() - s}sec")
    plt.imsave("results.png", mark_boundaries(image, label))
94 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/__init__.py
--------------------------------------------------------------------------------
/lib/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/dataset/__init__.py
--------------------------------------------------------------------------------
/lib/dataset/augmentation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 |
5 |
class Compose:
    """Chain several augmentations; each one receives the previous one's output."""

    def __init__(self, augmentations):
        # iterable of callables, applied in the given order
        self.augmentations = augmentations

    def __call__(self, data):
        result = data
        for transform in self.augmentations:
            result = transform(result)
        return result
14 |
15 |
class RandomHorizontalFlip:
    """Mirror every array in ``data`` along axis 1 with probability ``prob``."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, data):
        if random.random() >= self.prob:
            # no flip this time: hand the input back untouched
            return data
        flipped = []
        for d in data:
            # copy() removes the negative stride that torch.from_numpy rejects
            flipped.append(d[:, ::-1].copy())
        return flipped
25 |
26 |
class RandomScale:
    """
    Resize every array in ``data`` by one shared random factor.

    The factor is drawn from a normal distribution (mean 1, std 0.75) and then
    clamped to ``scale_range``.
    """

    def __init__(self, scale_range=(0.75, 3.0)):
        # (lower bound, upper bound) used to clamp the sampled factor
        self.scale_range = scale_range

    def __call__(self, data):
        drawn = np.random.normal(1, 0.75)
        capped = np.min((self.scale_range[1], drawn))
        scale = np.max((self.scale_range[0], capped))

        resized = []
        for d in data:
            # float32 arrays are interpolated linearly; every other dtype
            # uses nearest neighbour
            if d.dtype == np.float32:
                mode = cv2.INTER_LINEAR
            else:
                mode = cv2.INTER_NEAREST
            resized.append(cv2.resize(d, None, fx=scale, fy=scale, interpolation=mode))
        return resized
40 |
41 |
class RandomCrop:
    """Cut the same randomly placed window out of every array in ``data``."""

    def __init__(self, crop_size=(200, 200)):
        # (crop height, crop width)
        self.crop_size = crop_size

    def __call__(self, data):
        # the first array defines the size all crops are checked against
        img_h, img_w = data[0].shape[:2]
        crop_h, crop_w = self.crop_size
        assert img_h >= crop_h and img_w >= crop_w, f"({img_h}, {img_w}) v.s. ({crop_h}, {crop_w})"

        x0 = random.randint(0, img_w - crop_w)
        y0 = random.randint(0, img_h - crop_h)
        rows = slice(y0, y0 + crop_h)
        cols = slice(x0, x0 + crop_w)
        return [d[rows, cols] for d in data]
54 |
--------------------------------------------------------------------------------
/lib/dataset/bsds.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import torch
3 | import numpy as np
4 | import scipy.io
5 | from skimage.color import rgb2lab
6 | import matplotlib.pyplot as plt
7 |
8 |
def convert_label(label, max_labels=50):
    """
    Convert an integer label map into a one-hot float32 array.

    Args:
        label: numpy.ndarray
            A 2-D array of shape (H, W) holding segment ids
        max_labels: int
            number of one-hot channels; only the ``max_labels`` smallest
            distinct ids get a channel, pixels of any further id stay all-zero

    Return:
        onehot: numpy.ndarray
            A float32 array of shape (1, max_labels, H, W); channel ``i``
            marks the pixels equal to the i-th smallest distinct id
    """
    onehot = np.zeros((1, max_labels, label.shape[0], label.shape[1]), dtype=np.float32)

    # np.unique returns the ids sorted, so channel order is deterministic
    for channel, value in enumerate(np.unique(label)[:max_labels]):
        onehot[0, channel] = (label == value)

    return onehot
22 |
23 |
class BSDS:
    """
    Dataset over ``<root>/BSDS500/data/{groundTruth,images}/<split>``.

    Each item is a Lab image tensor and a one-hot ground-truth tensor built
    from one randomly chosen annotation of that image.
    """

    def __init__(self, root, split="train", color_transforms=None, geo_transforms=None):
        self.gt_dir = os.path.join(root, "BSDS500/data/groundTruth", split)
        self.img_dir = os.path.join(root, "BSDS500/data/images", split)

        # keep only ground-truth .mat files (stray files in the directory would
        # break loading) and sort them so indices are reproducible across runs
        self.index = sorted(
            name for name in os.listdir(self.gt_dir) if name.endswith(".mat"))

        self.color_transforms = color_transforms
        self.geo_transforms = geo_transforms

    def __getitem__(self, idx):
        """Return ``(img, gt)``: img is (C, H, W) float32, gt is (50, H * W) float."""
        sample_id = os.path.splitext(self.index[idx])[0]
        gt = scipy.io.loadmat(os.path.join(self.gt_dir, sample_id + ".mat"))
        # every image has several annotations; pick one at random
        t = np.random.randint(0, len(gt['groundTruth'][0]))
        gt = gt['groundTruth'][0][t][0][0][0]

        img = rgb2lab(plt.imread(os.path.join(self.img_dir, sample_id + ".jpg")))

        gt = gt.astype(np.int64)
        img = img.astype(np.float32)

        if self.color_transforms is not None:
            img = self.color_transforms(img)

        if self.geo_transforms is not None:
            # geometric transforms must move image and labels together
            img, gt = self.geo_transforms([img, gt])

        gt = convert_label(gt)
        gt = torch.from_numpy(gt)
        img = torch.from_numpy(img)
        img = img.permute(2, 0, 1)

        return img, gt.reshape(50, -1).float()

    def __len__(self):
        return len(self.index)
62 |
--------------------------------------------------------------------------------
/lib/ssn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/ssn/__init__.py
--------------------------------------------------------------------------------
/lib/ssn/pair_wise_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.cpp_extension import load_inline
3 | from .pair_wise_distance_cuda_source import source
4 |
5 |
# The CUDA extension is built with load_inline when this module is imported;
# the surrounding prints report progress on stdout.
print("compile cuda source of 'pair_wise_distance' function...")
print("NOTE: if you avoid this process, you make .cu file and compile it following https://pytorch.org/tutorials/advanced/cpp_extension.html")
# The resulting module is used below through its `forward` and `backward` functions.
pair_wise_distance_cuda = load_inline(
    "pair_wise_distance", cpp_sources="", cuda_sources=source
)
print("done")
12 |
13 |
class PairwiseDistFunction(torch.autograd.Function):
    """
    Autograd wrapper around the compiled ``pair_wise_distance`` CUDA kernels.

    Produces, for every pixel, 9 distance values (one per slot of the 3x3
    superpixel neighbourhood around the pixel's initial superpixel).
    """

    @staticmethod
    def forward(ctx, pixel_features, spixel_features, init_spixel_indices, num_spixels_width, num_spixels_height):
        # grid dimensions are plain ints, so they are kept as ctx attributes
        ctx.num_spixels_width = num_spixels_width
        ctx.num_spixels_height = num_spixels_height
        ctx.save_for_backward(pixel_features, spixel_features, init_spixel_indices)

        # zero-filled output buffer of shape (B, 9, n_pixels) handed to the kernel
        batch, n_pixels = pixel_features.shape[0], pixel_features.shape[-1]
        dist_buffer = pixel_features.new_zeros(batch, 9, n_pixels)

        return pair_wise_distance_cuda.forward(
            pixel_features.contiguous(),
            spixel_features.contiguous(),
            init_spixel_indices.contiguous(),
            dist_buffer,
            ctx.num_spixels_width,
            ctx.num_spixels_height)

    @staticmethod
    def backward(ctx, dist_matrix_grad):
        pixel_features, spixel_features, init_spixel_indices = ctx.saved_tensors

        # zero-initialised gradient buffers passed to the backward kernel
        grad_pix = torch.zeros_like(pixel_features)
        grad_spix = torch.zeros_like(spixel_features)

        grad_pix, grad_spix = pair_wise_distance_cuda.backward(
            dist_matrix_grad.contiguous(),
            pixel_features.contiguous(),
            spixel_features.contiguous(),
            init_spixel_indices.contiguous(),
            grad_pix,
            grad_spix,
            ctx.num_spixels_width,
            ctx.num_spixels_height,
        )
        # no gradient for the index map or the two integer grid sizes
        return grad_pix, grad_spix, None, None, None
41 |
42 |
--------------------------------------------------------------------------------
/lib/ssn/pair_wise_distance_cuda_source.py:
--------------------------------------------------------------------------------
1 | source = '''
2 | #include
3 | #include
4 | #include
5 | #include
6 |
7 | #define CUDA_NUM_THREADS 256
8 |
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | #include
15 | #include
16 | #include
17 |
18 | template
19 | __global__ void forward_kernel(
20 | const scalar_t* __restrict__ pixel_features,
21 | const scalar_t* __restrict__ spixel_features,
22 | const scalar_t* __restrict__ spixel_indices,
23 | scalar_t* __restrict__ dist_matrix,
24 | int batchsize, int channels, int num_pixels, int num_spixels,
25 | int num_spixels_w, int num_spixels_h
26 | ){
27 | int index = blockIdx.x * blockDim.x + threadIdx.x;
28 | if (index >= batchsize * num_pixels * 9) return;
29 |
30 | int cp = channels * num_pixels;
31 | int cs = channels * num_spixels;
32 |
33 | int b = index % batchsize;
34 | int spixel_offset = (index / batchsize) % 9;
35 | int p = (index / (batchsize * 9)) % num_pixels;
36 |
37 | int init_spix_index = spixel_indices[b * num_pixels + p];
38 |
39 | int x_index = init_spix_index % num_spixels_w;
40 | int spixel_offset_x = (spixel_offset % 3 - 1);
41 |
42 | int y_index = init_spix_index / num_spixels_w;
43 | int spixel_offset_y = (spixel_offset / 3 - 1);
44 |
45 | if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) {
46 | dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16;
47 | }
48 | else if (y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) {
49 | dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16;
50 | }
51 | else {
52 | int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y;
53 |
54 | scalar_t sum_squared_diff = 0;
55 | for (int c=0; c<<< block, CUDA_NUM_THREADS >>>(
80 | pixel_features.data(),
81 | spixel_features.data(),
82 | spixel_indices.data(),
83 | dist_matrix.data(),
84 | batchsize, channels, num_pixels,
85 | num_spixels, num_spixels_w, num_spixels_h
86 | );
87 | }));
88 |
89 | return dist_matrix;
90 | }
91 |
92 | template
93 | __global__ void backward_kernel(
94 | const scalar_t* __restrict__ dist_matrix_grad,
95 | const scalar_t* __restrict__ pixel_features,
96 | const scalar_t* __restrict__ spixel_features,
97 | const scalar_t* __restrict__ spixel_indices,
98 | scalar_t* __restrict__ pixel_feature_grad,
99 | scalar_t* __restrict__ spixel_feature_grad,
100 | int batchsize, int channels, int num_pixels, int num_spixels,
101 | int num_spixels_w, int num_spixels_h
102 | ){
103 | int index = blockIdx.x * blockDim.x + threadIdx.x;
104 | if (index >= batchsize * num_pixels * 9) return;
105 |
106 | int cp = channels * num_pixels;
107 | int cs = channels * num_spixels;
108 |
109 | int b = index % batchsize;
110 | int spixel_offset = (index / batchsize) % 9;
111 | int p = (index / (batchsize * 9)) % num_pixels;
112 |
113 | int init_spix_index = spixel_indices[b * num_pixels + p];
114 |
115 | int x_index = init_spix_index % num_spixels_w;
116 | int spixel_offset_x = (spixel_offset % 3 - 1);
117 |
118 | int y_index = init_spix_index / num_spixels_w;
119 | int spixel_offset_y = (spixel_offset / 3 - 1);
120 |
121 | if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) return;
122 | else if (y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) return;
123 | else {
124 | int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y;
125 |
126 | scalar_t dist_matrix_grad_val = dist_matrix_grad[b * (9 * num_pixels) + spixel_offset * num_pixels + p];
127 |
128 | for (int c=0; c backward_cuda(
141 | const torch::Tensor dist_matrix_grad,
142 | const torch::Tensor pixel_features,
143 | const torch::Tensor spixel_features,
144 | const torch::Tensor spixel_indices,
145 | torch::Tensor pixel_features_grad,
146 | torch::Tensor spixel_features_grad,
147 | int num_spixels_w, int num_spixels_h
148 | ){
149 | int batchsize = pixel_features.size(0);
150 | int channels = pixel_features.size(1);
151 | int num_pixels = pixel_features.size(2);
152 | int num_spixels = spixel_features.size(2);
153 |
154 |
155 | dim3 block((batchsize * 9 * num_pixels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
156 |
157 | AT_DISPATCH_FLOATING_TYPES(pixel_features_grad.type(), "backward_kernel", ([&] {
158 | backward_kernel<<< block, CUDA_NUM_THREADS >>>(
159 | dist_matrix_grad.data(),
160 | pixel_features.data(),
161 | spixel_features.data(),
162 | spixel_indices.data(),
163 | pixel_features_grad.data(),
164 | spixel_features_grad.data(),
165 | batchsize, channels, num_pixels,
166 | num_spixels, num_spixels_w, num_spixels_h
167 | );
168 | }));
169 |
170 | return {pixel_features_grad, spixel_features_grad};
171 | }
172 |
173 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
174 | m.def("forward", &forward_cuda, "pair_wise_distance forward");
175 | m.def("backward", &backward_cuda, "pair_wise_distance backward");
176 | }
177 | '''
--------------------------------------------------------------------------------
/lib/ssn/ssn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from .pair_wise_distance import PairwiseDistFunction
5 | from ..utils.sparse_utils import naive_sparse_bmm
6 |
7 |
def calc_init_centroid(images, num_spixels_width, num_spixels_height):
    """
    Initialise superpixel centroids on a regular grid.

    Centroids are the adaptive-average-pooled features of each grid cell; the
    label map assigns every pixel the id of its cell (nearest-neighbour
    upsampling of the cell ids).

    Args:
        images: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels_width: int
            number of superpixels along the width
        num_spixels_height: int
            number of superpixels along the height

    Return:
        centroids: torch.Tensor
            A Tensor of shape (B, C, num_spixels_height * num_spixels_width)
        init_label_map: torch.Tensor
            A Tensor of shape (B, H * W), same dtype as ``centroids``
    """
    batchsize, channels, height, width = images.shape
    pool = torch.nn.functional.adaptive_avg_pool2d
    upsample = torch.nn.functional.interpolate

    centroids = pool(images, (num_spixels_height, num_spixels_width))

    with torch.no_grad():
        n_cells = num_spixels_width * num_spixels_height
        cell_ids = torch.arange(n_cells, device=images.device)
        # lay the ids out on the grid, then blow the grid up to image size
        cell_ids = cell_ids.reshape(1, 1, *centroids.shape[-2:]).type_as(centroids)
        init_label_map = upsample(cell_ids, size=(height, width), mode="nearest")
        init_label_map = init_label_map.repeat(batchsize, 1, 1, 1)

    flat_labels = init_label_map.reshape(batchsize, -1)
    flat_centroids = centroids.reshape(batchsize, channels, -1)

    return flat_centroids, flat_labels
45 |
46 |
@torch.no_grad()
def get_abs_indices(init_label_map, num_spixels_width):
    """
    Build (batch, superpixel, pixel) index triples for the 9 neighbour slots.

    Args:
        init_label_map: torch.Tensor
            A Tensor of shape (B, n_pixels) with each pixel's initial superpixel id
        num_spixels_width: int
            number of superpixels per grid row

    Return:
        A long Tensor of shape (3, B * 9 * n_pixels); rows are batch index,
        absolute superpixel index and pixel index. Superpixel indices may fall
        outside the valid range here; no filtering is done in this function.
    """
    batch, n_pixel = init_label_map.shape
    device = init_label_map.device

    # offsets of the 3x3 neighbourhood in flattened superpixel ids
    row = torch.arange(-1, 2.0, device=device)
    neighbour_offsets = torch.cat(
        [row - num_spixels_width, row, row + num_spixels_width], 0)

    pix_idx = torch.arange(n_pixel, device=device)
    pix_idx = pix_idx.expand(batch, 9, n_pixel).reshape(-1).long()

    spix_idx = init_label_map[:, None] + neighbour_offsets[None, :, None]
    spix_idx = spix_idx.reshape(-1).long()

    batch_idx = torch.arange(batch, device=device).reshape(batch, 1, 1)
    batch_idx = batch_idx.expand(batch, 9, n_pixel).reshape(-1).long()

    return torch.stack([batch_idx, spix_idx, pix_idx], 0)
59 |
60 |
@torch.no_grad()
def get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width):
    """
    Turn soft neighbour affinities into absolute superpixel labels.

    For each pixel the neighbour slot with the highest affinity (dim 1 of
    ``affinity_matrix``) is picked and its flattened-id offset is added to the
    pixel's initial label. Returns a long tensor shaped like ``init_label_map``.
    """
    best_slot = torch.max(affinity_matrix, 1)[1]

    # slot -> offset in flattened superpixel ids (3x3 neighbourhood)
    row = torch.arange(-1, 2.0, device=affinity_matrix.device)
    slot_offsets = torch.cat(
        [row - num_spixels_width, row, row + num_spixels_width], 0)

    absolute = init_label_map + slot_offsets[best_slot]
    return absolute.long()
68 |
69 |
@torch.no_grad()
def sparse_ssn_iter(pixel_features, num_spixels, n_iter):
    """
    Differentiable-SLIC assignment iterations using a sparse affinity matrix
    (Algorithm 1, line 2 - 6 of the paper).
    NOTE: runs under torch.no_grad(), so it does NOT support backward.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A number of superpixels
        n_iter: int
            A number of iterations; must be at least 1, otherwise no affinity
            is ever computed and the function fails on an unbound variable

    Return:
        (sparse_abs_affinity, hard_labels, spixel_features): the sparse
        pixel-superpixel affinity of the last iteration, the hard label of
        every pixel, and the final superpixel features
    """
    height, width = pixel_features.shape[-2:]
    # grid that keeps the image aspect ratio
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    spixel_features, init_label_map = calc_init_centroid(
        pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # neighbour slots pointing outside [0, num_spixels) are dropped
    in_range = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels)
    kept_indices = abs_indices[:, in_range]

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    pixels_last = pixel_features.permute(0, 2, 1)

    for _ in range(n_iter):
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)

        affinity_matrix = (-dist_matrix).softmax(1)
        flat_affinity = affinity_matrix.reshape(-1)

        sparse_abs_affinity = torch.sparse_coo_tensor(kept_indices, flat_affinity[in_range])

        # affinity-weighted mean of the pixel features per superpixel
        weighted_sum = naive_sparse_bmm(sparse_abs_affinity, pixels_last)
        total_weight = torch.sparse.sum(sparse_abs_affinity, 2).to_dense()[..., None] + 1e-16
        spixel_features = (weighted_sum / total_weight).permute(0, 2, 1)

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)

    return sparse_abs_affinity, hard_labels, spixel_features
115 |
116 |
def ssn_iter(pixel_features, num_spixels, n_iter):
    """
    Differentiable-SLIC assignment iterations with a dense affinity matrix
    (Algorithm 1, line 2 - 6 of the paper).

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A number of superpixels
        n_iter: int
            A number of iterations; must be at least 1, otherwise no affinity
            is ever computed and the function fails on an unbound variable

    Return:
        (abs_affinity, hard_labels, spixel_features): the dense
        pixel-superpixel affinity of the last iteration, the hard label of
        every pixel, and the final superpixel features
    """
    height, width = pixel_features.shape[-2:]
    # grid that keeps the image aspect ratio
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    spixel_features, init_label_map = calc_init_centroid(
        pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # neighbour slots pointing outside [0, num_spixels) are dropped
    in_range = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels)
    kept_indices = abs_indices[:, in_range]

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    pixels_last = pixel_features.permute(0, 2, 1).contiguous()

    for _ in range(n_iter):
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)

        affinity_matrix = (-dist_matrix).softmax(1)
        flat_affinity = affinity_matrix.reshape(-1)

        # scatter the 9-slot affinities into a (B, n_spixels, n_pixels) matrix
        sparse_abs_affinity = torch.sparse_coo_tensor(kept_indices, flat_affinity[in_range])
        abs_affinity = sparse_abs_affinity.to_dense().contiguous()

        # affinity-weighted mean of the pixel features per superpixel
        weighted_sum = torch.bmm(abs_affinity, pixels_last)
        total_weight = abs_affinity.sum(2, keepdim=True) + 1e-16
        spixel_features = (weighted_sum / total_weight).permute(0, 2, 1).contiguous()

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)

    return abs_affinity, hard_labels, spixel_features
163 |
--------------------------------------------------------------------------------
/lib/ssn/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .pair_wise_distance import PairwiseDistFunction
3 |
4 |
5 | # naive implementation for debug
def naive_pair_wise_dist(pix, spix, idx, n_spix_w, n_spix_h):
    """
    Reference (loop-based) pixel-to-neighbouring-superpixel squared distance.

    Args:
        pix: Tensor of shape (B, C, n_pixels)
        spix: Tensor of shape (B, C, n_spixels)
        idx: Tensor of shape (B, n_pixels), initial superpixel id per pixel
        n_spix_w, n_spix_h: superpixel grid width / height

    Return:
        Tensor of shape (B, 9, n_pixels). Neighbour slots that fall outside
        the superpixel grid are filled with 0.
    """
    ba, ch, pi = pix.shape
    per_batch = []
    for b in range(ba):
        per_pixel = []
        for p in range(pi):
            feature = pix[b, :, p]
            center = idx[b, p]
            col = center % n_spix_w
            row = center // n_spix_w
            dists = []
            for slot in range(9):
                dx, dy = slot % 3, slot // 3
                # slot points left/right/above/below the grid border
                off_grid = (
                    (col == 0 and dx == 0)
                    or (col == (n_spix_w - 1) and dx == 2)
                    or (row == 0 and dy == 0)
                    or (row == (n_spix_h - 1) and dy == 2)
                )
                if off_grid:
                    dists.append(pix.new(1).fill_(0)[0])
                    continue
                offset_x = dx - 1
                offset_y = (dy - 1) * n_spix_w
                s = int(center + offset_y + offset_x)
                dists.append((feature - spix[b, :, s]).pow(2).sum())
            per_pixel.append(torch.stack(dists))
        per_batch.append(torch.stack(per_pixel, 1))
    return torch.stack(per_batch, 0)
39 |
40 |
def test(eps=1e-4):
    """
    Sanity-check the CUDA pairwise distance against the naive implementation.

    Prints the gradcheck result and the mean absolute difference of outputs
    and gradients between both implementations. Requires a CUDA device.

    Args:
        eps: float
            finite-difference step passed to torch.autograd.gradcheck
    """
    func = PairwiseDistFunction.apply

    pix = torch.randn(2, 20, 81).double().to("cuda")
    spix = torch.randn(2, 20, 9).double().to("cuda")
    idx = torch.randint(0, 9, (2, 81)).double().to("cuda")
    wid = 3
    hei = 3

    pix.requires_grad = True
    spix.requires_grad = True

    res = torch.autograd.gradcheck(func, (pix, spix, idx, wid, hei), eps=eps, raise_exception=False)
    print(res)

    o = func(pix, spix, idx, wid, hei)
    o.sum().backward()

    # clone: the .grad tensors are zeroed and then re-accumulated IN PLACE by
    # the naive backward below; without a copy both sides of the comparison
    # would be the very same tensor and the reported diff would always be 0
    cuda_p_grad = pix.grad.clone()
    cuda_sp_grad = spix.grad.clone()

    pix.grad.zero_()
    spix.grad.zero_()

    naive_o = naive_pair_wise_dist(pix, spix, idx, wid, hei)
    naive_o.sum().backward()

    print("output diff between GPU and naive", torch.abs(o - naive_o).mean())
    print("pix grad diff between GPU and naive", torch.abs(cuda_p_grad - pix.grad).mean())
    print("spix grad diff between GPU and naive", torch.abs(cuda_sp_grad - spix.grad).mean())
71 |
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/utils/__init__.py
--------------------------------------------------------------------------------
/lib/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .sparse_utils import naive_sparse_bmm, sparse_permute
3 |
4 |
def sparse_reconstruction(assignment, labels, hard_assignment=None):
    """
    Reconstruct per-pixel labels through the superpixels, sparse version.
    NOTE: not used in this project because it may not return correct gradients.

    Args:
        assignment: torch.sparse_coo_tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels); when given, each pixel simply
            copies the mean of its assigned superpixel

    Return:
        A contiguous Tensor of shape (B, C, n_pixels)
    """
    labels = labels.permute(0, 2, 1).contiguous()

    # (n_spixels, n_pixels) x (n_pixels, channels), normalised per superpixel
    weighted_sum = naive_sparse_bmm(assignment, labels)
    total_weight = torch.sparse.sum(assignment, 2).to_dense()[..., None] + 1e-16
    spixel_mean = weighted_sum / total_weight

    if hard_assignment is not None:
        # index sampling
        picked = [mean[ha, :] for mean, ha in zip(spixel_mean, hard_assignment)]
        rebuilt = torch.stack(picked, 0)
    else:
        # (B, n_spixels, n_pixels) -> (B, n_pixels, n_spixels), then spread
        # the superpixel means back onto the pixels
        pixel_major = sparse_permute(assignment, (0, 2, 1))
        rebuilt = naive_sparse_bmm(pixel_major, spixel_mean)
    return rebuilt.permute(0, 2, 1).contiguous()
31 |
32 |
def reconstruction(assignment, labels, hard_assignment=None):
    """
    Reconstruct per-pixel labels through the superpixels.

    Labels are averaged per superpixel (weighted by ``assignment``) and then
    mapped back to the pixels, either softly or via ``hard_assignment``.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels); when given, each pixel simply
            copies the mean of its assigned superpixel

    Return:
        A contiguous Tensor of shape (B, C, n_pixels)
    """
    labels = labels.permute(0, 2, 1).contiguous()

    # (n_spixels, n_pixels) x (n_pixels, channels), normalised per superpixel
    total_weight = assignment.sum(2, keepdim=True) + 1e-16
    spixel_mean = torch.bmm(assignment, labels) / total_weight

    if hard_assignment is not None:
        # index sampling
        picked = [mean[ha, :] for mean, ha in zip(spixel_mean, hard_assignment)]
        rebuilt = torch.stack(picked, 0)
    else:
        # (n_pixels, n_spixels) x (n_spixels, channels)
        pixel_major = assignment.permute(0, 2, 1).contiguous()
        rebuilt = torch.bmm(pixel_major, spixel_mean)
    return rebuilt.permute(0, 2, 1).contiguous()
58 |
59 |
def reconstruct_loss_with_cross_etnropy(assignment, labels, hard_assignment=None):
    """
    Cross-entropy between the reconstructed labels and the given labels.

    The reconstruction is normalised over the channel dimension and the
    negative log is averaged over the positions where ``labels > 0``.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)
    """
    rebuilt = reconstruction(assignment, labels, hard_assignment)
    channel_total = rebuilt.sum(1, keepdim=True)
    rebuilt = rebuilt / (1e-16 + channel_total)
    positive = labels > 0
    # the small constant keeps log() finite for zero entries
    log_likelihood = (rebuilt[positive] + 1e-16).log()
    return -log_likelihood.mean()
76 |
77 |
def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None):
    """
    Mean squared error between the reconstructed labels and the given labels.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)
    """
    mse = torch.nn.functional.mse_loss
    rebuilt = reconstruction(assignment, labels, hard_assignment)
    return mse(rebuilt, labels)
92 |
--------------------------------------------------------------------------------
/lib/utils/meter.py:
--------------------------------------------------------------------------------
class Meter:
    """Track exponential moving averages of named values."""

    def __init__(self, ema_coef=0.9):
        # weight given to the running value on each update
        self.ema_coef = ema_coef
        self.params = {}

    def add(self, params: dict, ignores=()):
        """
        Fold new values into the running averages.

        Args:
            params: mapping of name -> new value; a name seen for the first
                time is stored as-is
            ignores: names in ``params`` to leave out
        """
        for k, v in params.items():
            if k in ignores:
                continue
            if k not in self.params:
                self.params[k] = v
            else:
                # rebind instead of `-=`: an in-place update would modify the
                # caller's first value when it is an array or tensor
                current = self.params[k]
                self.params[k] = current - (1 - self.ema_coef) * (current - v)

    def state(self, header="", footer=""):
        """Return ``header``, then `` name value |`` per entry, then ``footer``."""
        entries = "".join(f" {k} {v:.6g} |" for k, v in self.params.items())
        return header + entries + " " + footer

    def reset(self):
        """Forget every tracked value."""
        self.params = {}
--------------------------------------------------------------------------------
/lib/utils/sparse_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def naive_sparse_bmm(sparse_mat, dense_mat, transpose=False):
    """
    Batched sparse @ dense product computed one batch element at a time.

    ``sparse_mat`` and ``dense_mat`` are iterated in lockstep; each pair is
    multiplied with ``torch.sparse.mm`` and the results are stacked along a
    new leading dimension.

    Args:
        sparse_mat: iterable of 2-D sparse matrices
        dense_mat: iterable of 2-D dense matrices
        transpose: when True, each dense matrix is transposed before the
            product

    Returns:
        The stacked products as a single tensor.
    """
    products = []
    for s_mat, d_mat in zip(sparse_mat, dense_mat):
        rhs = d_mat.t() if transpose else d_mat
        products.append(torch.sparse.mm(s_mat, rhs))
    return torch.stack(products, 0)
9 |
def sparse_permute(sparse_mat, order):
    """
    Permute the sparse dimensions of a sparse COO tensor.

    Args:
        sparse_mat: sparse COO tensor
        order: sequence of sparse-dimension indices giving the new order

    Returns:
        A new sparse COO tensor whose i-th sparse dimension is dimension
        ``order[i]`` of ``sparse_mat``. The shape is carried over from the
        input instead of being re-inferred from the largest stored index,
        so trailing all-zero rows/columns are preserved.
    """
    # coalesce once; indices() and values() both need the coalesced form
    coalesced = sparse_mat.coalesce()
    indices = coalesced.indices()
    values = coalesced.values()

    permuted = torch.stack([indices[o] for o in order], 0).contiguous()

    # indices has one row per sparse dimension; any remaining dimensions
    # of values are dense and keep their place at the end of the shape
    sparse_shape = sparse_mat.shape[:indices.shape[0]]
    new_shape = [sparse_shape[o] for o in order] + list(values.shape[1:])
    return torch.sparse_coo_tensor(permuted, values, new_shape)
15 |
16 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from lib.ssn.ssn import ssn_iter, sparse_ssn_iter
5 |
6 |
def conv_bn_relu(in_c, out_c):
    """
    Build a 3x3 convolution -> batch norm -> in-place ReLU stack.

    Args:
        in_c: number of input channels of the convolution
        out_c: number of output channels of the convolution and the
            feature count of the batch norm

    Returns:
        nn.Sequential holding the three layers in that order. The
        convolution uses padding 1 and has no bias term.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
13 |
class SSNModel(nn.Module):
    """
    Convolutional feature extractor followed by the SSN iteration.

    Args:
        feature_dim: number of channels of the per-pixel feature map that
            is handed to the SSN iteration; ``feature_dim - 5`` of them are
            learned, the remaining 5 are the raw input channels
        nspix: number of superpixels, forwarded to the SSN iteration
        n_iter: number of iterations, forwarded to the SSN iteration
    """

    def __init__(self, feature_dim, nspix, n_iter=10):
        super().__init__()
        self.nspix = nspix
        self.n_iter = n_iter

        # three stages; stages 2 and 3 each start with a stride-2 max pool
        self.scale1 = nn.Sequential(
            conv_bn_relu(5, 64),
            conv_bn_relu(64, 64),
        )
        self.scale2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            conv_bn_relu(64, 64),
            conv_bn_relu(64, 64),
        )
        self.scale3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            conv_bn_relu(64, 64),
            conv_bn_relu(64, 64),
        )

        # consumes the 5 input channels plus 64 channels from each stage
        self.output_conv = nn.Sequential(
            nn.Conv2d(64 * 3 + 5, feature_dim - 5, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # every convolution starts from N(0, 0.001) weights and zero bias
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.normal_(conv.weight, mean=0, std=0.001)
            if conv.bias is None:
                continue
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """
        Extract per-pixel features and run the SSN iteration on them.

        ``ssn_iter`` is used in training mode and ``sparse_ssn_iter``
        otherwise; the return value is whatever the chosen function
        returns.
        """
        pixel_f = self.feature_extract(x)
        solver = ssn_iter if self.training else sparse_ssn_iter
        return solver(pixel_f, self.nspix, self.n_iter)

    def feature_extract(self, x):
        """
        Compute the per-pixel feature map.

        The outputs of stages 2 and 3 are bilinearly resized to the spatial
        size of stage 1, concatenated with ``x`` and stage 1, and reduced
        by ``output_conv``. The input ``x`` is then appended as the
        trailing channels of the result.
        """
        full_res = self.scale1(x)
        half_res = self.scale2(full_res)
        quarter_res = self.scale3(half_res)

        target_size = full_res.shape[-2:]
        upsampled = [
            nn.functional.interpolate(f, size=target_size, mode="bilinear", align_corners=False)
            for f in (half_res, quarter_res)
        ]

        stacked = torch.cat([x, full_res, *upsampled], 1)
        learned = self.output_conv(stacked)

        return torch.cat([learned, x], 1)
68 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | import numpy as np
3 | import time
4 | import torch
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 |
8 | from lib.utils.meter import Meter
9 | from model import SSNModel
10 | from lib.dataset import bsds, augmentation
11 | from lib.utils.loss import reconstruct_loss_with_cross_etnropy, reconstruct_loss_with_mse
12 |
13 |
@torch.no_grad()
def eval(model, loader, color_scale, pos_scale, device):
    """
    Evaluate ``model`` on ``loader`` and return the mean Achievable
    Segmentation Accuracy (ASA) over its batches.

    The model is switched to eval mode for the duration of the loop and
    switched back to train mode before returning.

    Args:
        model: callable returning a ``(Q, H, feat)`` triple and exposing
            ``nspix``, ``eval()`` and ``train()``
        loader: sized iterable of ``(inputs, labels)`` pairs; the
            per-sample reshape below only succeeds for a batch size of 1
        color_scale: multiplier applied to the image channels
        pos_scale: base multiplier applied to the pixel coordinates; it is
            rescaled per batch by the superpixel grid density
        device: device the tensors are moved to

    Returns:
        Sum of the per-batch ASA values divided by ``len(loader)``.
    """
    def achievable_segmentation_accuracy(superpixel, label):
        r"""
        Function to calculate Achievable Segmentation Accuracy:
            ASA(S,G) = sum_j max_i |s_j \cap g_i| / sum_i |g_i|

        Args:
            superpixel: superpixel index image (H, W)
            label: ground-truth label image (H, W)
        """
        TP = 0
        for uid in np.unique(superpixel):
            members = label[superpixel == uid]
            # exact per-label counts; np.histogram's default 10 bins would
            # merge distinct label ids and overestimate the overlap
            _, counts = np.unique(members, return_counts=True)
            TP += counts.max()
        return TP / label.size

    model.eval()
    nspix_per_axis = int(math.sqrt(model.nspix))
    sum_asa = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        height, width = inputs.shape[-2:]

        # keep the rescaled value in its own name: writing it back to
        # pos_scale would compound the factor on every batch
        scaled_pos = pos_scale * max(nspix_per_axis / height, nspix_per_axis / width)

        coords = torch.stack(torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device)), 0)
        coords = coords[None].repeat(inputs.shape[0], 1, 1, 1).float()

        inputs = torch.cat([color_scale * inputs, scaled_pos * coords], 1)

        _, H, _ = model(inputs)

        H = H.reshape(height, width)
        labels = labels.argmax(1).reshape(height, width)

        asa = achievable_segmentation_accuracy(H.to("cpu").detach().numpy(), labels.to("cpu").numpy())
        sum_asa += asa
    model.train()
    return sum_asa / len(loader)
61 |
62 |
def update_param(data, model, optimizer, compactness, color_scale, pos_scale, device):
    """
    Run one optimisation step on a single batch.

    The image is scaled by ``color_scale`` and concatenated with the pixel
    coordinates scaled by ``pos_scale`` (itself rescaled by the superpixel
    grid density). The loss is the cross-entropy reconstruction loss of the
    labels plus ``compactness`` times the MSE reconstruction loss of the
    unscaled coordinates.

    Args:
        data: ``(inputs, labels)`` pair
        model: network returning a ``(Q, H, feat)`` triple and exposing
            ``nspix``
        optimizer: optimizer that is zeroed and stepped once
        compactness: weight of the compactness term
        color_scale: multiplier applied to the image channels
        pos_scale: base multiplier applied to the pixel coordinates
        device: device the tensors are moved to

    Returns:
        dict with the Python floats "loss", "reconstruction" and "compact".
    """
    images, targets = data
    images, targets = images.to(device), targets.to(device)

    batch = images.shape[0]
    height, width = images.shape[-2:]

    grid_side = int(math.sqrt(model.nspix))
    pos_scale = pos_scale * max(grid_side / height, grid_side / width)

    rows = torch.arange(height, device=device)
    cols = torch.arange(width, device=device)
    coords = torch.stack(torch.meshgrid(rows, cols), 0)
    coords = coords[None].repeat(batch, 1, 1, 1).float()

    net_input = torch.cat([color_scale * images, pos_scale * coords], 1)
    Q, H, feat = model(net_input)

    # coordinates flattened over the spatial dimensions for the MSE term
    flat_coords = coords.reshape(*coords.shape[:2], -1)
    recons_loss = reconstruct_loss_with_cross_etnropy(Q, targets)
    compact_loss = reconstruct_loss_with_mse(Q, flat_coords, H)
    total = recons_loss + compactness * compact_loss

    optimizer.zero_grad()
    total.backward()
    optimizer.step()

    return {
        "loss": total.item(),
        "reconstruction": recons_loss.item(),
        "compact": compact_loss.item(),
    }
91 |
92 |
def train(cfg):
    """
    Train an SSNModel on BSDS according to ``cfg``.

    Runs ``cfg.train_iter`` parameter updates, printing the smoothed
    metrics after each one. Every ``cfg.test_interval`` updates the model
    is evaluated on the "val" split and its weights are saved to
    "bset_model.pth" in ``cfg.out_dir`` whenever the validation ASA
    improves. The final weights are saved under a name containing the
    current Unix time.

    Args:
        cfg: object exposing the attributes read below (root, out_dir,
            batchsize, nworkers, lr, train_iter, fdim, niter, nspix,
            color_scale, pos_scale, compactness, test_interval)
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = SSNModel(cfg.fdim, cfg.nspix, cfg.niter).to(device)
    optimizer = optim.Adam(model.parameters(), cfg.lr)

    augment = augmentation.Compose([
        augmentation.RandomHorizontalFlip(),
        augmentation.RandomScale(),
        augmentation.RandomCrop(),
    ])
    train_dataset = bsds.BSDS(cfg.root, geo_transforms=augment)
    train_loader = DataLoader(train_dataset, cfg.batchsize, shuffle=True, drop_last=True, num_workers=cfg.nworkers)

    test_dataset = bsds.BSDS(cfg.root, split="val")
    test_loader = DataLoader(test_dataset, 1, shuffle=False, drop_last=False)

    meter = Meter()
    best_path = os.path.join(cfg.out_dir, "bset_model.pth")

    iterations = 0
    max_val_asa = 0
    # the outer loop restarts the loader until the update budget is spent
    while iterations < cfg.train_iter:
        for data in train_loader:
            iterations += 1
            metric = update_param(data, model, optimizer, cfg.compactness, cfg.color_scale, cfg.pos_scale, device)
            meter.add(metric)
            print(meter.state(f"[{iterations}/{cfg.train_iter}]"))

            if iterations % cfg.test_interval == 0:
                asa = eval(model, test_loader, cfg.color_scale, cfg.pos_scale, device)
                print(f"validation asa {asa}")
                if asa > max_val_asa:
                    max_val_asa = asa
                    torch.save(model.state_dict(), best_path)

            if iterations == cfg.train_iter:
                break

    final_name = "model" + str(int(time.time())) + ".pth"
    torch.save(model.state_dict(), os.path.join(cfg.out_dir, final_name))
132 |
133 |
if __name__ == "__main__":
    import argparse

    # (flag, add_argument options) in the order they are registered
    option_table = [
        ("--root", dict(type=str, help="/path/to/BSR")),
        ("--out_dir", dict(default="./log", type=str, help="/path/to/output directory")),
        ("--batchsize", dict(default=6, type=int)),
        ("--nworkers", dict(default=4, type=int, help="number of threads for CPU parallel")),
        ("--lr", dict(default=1e-4, type=float, help="learning rate")),
        ("--train_iter", dict(default=500000, type=int)),
        ("--fdim", dict(default=20, type=int, help="embedding dimension")),
        ("--niter", dict(default=5, type=int, help="number of iterations for differentiable SLIC")),
        ("--nspix", dict(default=100, type=int, help="number of superpixels")),
        ("--color_scale", dict(default=0.26, type=float)),
        ("--pos_scale", dict(default=2.5, type=float)),
        ("--compactness", dict(default=1e-5, type=float)),
        ("--test_interval", dict(default=10000, type=int)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # make sure the checkpoint directory exists before training starts
    os.makedirs(args.out_dir, exist_ok=True)

    train(args)
157 |
--------------------------------------------------------------------------------