├── LICENSE ├── README.md ├── SSN_deep_result.png ├── SSN_pix_result.png ├── inference.py ├── lib ├── __init__.py ├── dataset │ ├── __init__.py │ ├── augmentation.py │ └── bsds.py ├── ssn │ ├── __init__.py │ ├── pair_wise_distance.py │ ├── pair_wise_distance_cuda_source.py │ ├── ssn.py │ └── test.py └── utils │ ├── __init__.py │ ├── loss.py │ ├── meter.py │ └── sparse_utils.py ├── model.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 teppei suzuki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Superpixel Sampling Networks 2 | PyTorch implementation of Superpixel Sampling Networks 3 | paper: https://arxiv.org/abs/1807.10174 4 | original code: https://github.com/NVlabs/ssn_superpixels 5 | 6 | ### Note 7 | A pure PyTorch implementation of the core component, differentiable SLIC, is available [here](https://github.com/perrying/diffSLIC) (note that it implements the similarity function as the cosine similarity instead of the negative Euclidean distance). 8 | 9 | # Requirements 10 | - PyTorch >= 1.4 with CUDA support (the pairwise-distance kernel is compiled at import time with `torch.utils.cpp_extension.load_inline`, and inference runs on a CUDA device) 11 | - scikit-image 12 | - matplotlib 13 | 14 | # Usage 15 | ## inference 16 | SSN_pix 17 | ``` 18 | python inference.py --image /path/to/image 19 | ``` 20 | SSN_deep 21 | ``` 22 | python inference.py --image /path/to/image --weight /path/to/pretrained_weight 23 | ``` 24 | 25 | ## training 26 | ``` 27 | python train.py --root /path/to/BSDS500 28 | ``` 29 | 30 | # Results 31 | SSN_pix 32 | ![SSN_pix result](SSN_pix_result.png) 33 | 34 | SSN_deep 35 | ![SSN_deep result](SSN_deep_result.png) 36 | -------------------------------------------------------------------------------- /SSN_deep_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/SSN_deep_result.png -------------------------------------------------------------------------------- /SSN_pix_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/SSN_pix_result.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | from skimage.color import rgb2lab 6 | from
skimage.segmentation._slic import _enforce_label_connectivity_cython 7 | 8 | from lib.ssn.ssn import sparse_ssn_iter 9 | 10 | 11 | @torch.no_grad() 12 | def inference(image, nspix, n_iter, fdim=None, color_scale=0.26, pos_scale=2.5, weight=None, enforce_connectivity=True): 13 | """ 14 | generate superpixels 15 | 16 | Args: 17 | image: numpy.ndarray 18 | An array of shape (h, w, c) 19 | nspix: int 20 | number of superpixels 21 | n_iter: int 22 | number of iterations 23 | fdim (optional): int 24 | feature dimension for supervised setting 25 | color_scale: float 26 | color channel factor 27 | pos_scale: float 28 | pixel coordinate factor 29 | weight: state_dict 30 | pretrained weight 31 | enforce_connectivity: bool 32 | if True, enforce superpixel connectivity in postprocessing 33 | 34 | Return: 35 | labels: numpy.ndarray 36 | An array of shape (h, w) 37 | """ 38 | if weight is not None: 39 | from model import SSNModel 40 | model = SSNModel(fdim, nspix, n_iter).to("cuda") 41 | model.load_state_dict(torch.load(weight)) 42 | model.eval() 43 | else: 44 | model = lambda data: sparse_ssn_iter(data, nspix, n_iter) 45 | 46 | height, width = image.shape[:2] 47 | 48 | nspix_per_axis = int(math.sqrt(nspix)) 49 | pos_scale = pos_scale * max(nspix_per_axis/height, nspix_per_axis/width) 50 | 51 | coords = torch.stack(torch.meshgrid(torch.arange(height, device="cuda"), torch.arange(width, device="cuda")), 0) 52 | coords = coords[None].float() 53 | 54 | image = rgb2lab(image) 55 | image = torch.from_numpy(image).permute(2, 0, 1)[None].to("cuda").float() 56 | 57 | inputs = torch.cat([color_scale*image, pos_scale*coords], 1) 58 | 59 | _, H, _ = model(inputs) 60 | 61 | labels = H.reshape(height, width).to("cpu").detach().numpy() 62 | 63 | if enforce_connectivity: 64 | segment_size = height * width / nspix 65 | min_size = int(0.06 * segment_size) 66 | max_size = int(3.0 * segment_size) 67 | labels = _enforce_label_connectivity_cython( 68 | labels[None], min_size, max_size)[0] 69 | 70 | 
return labels 71 | 72 | 73 | if __name__ == "__main__": 74 | import time 75 | import argparse 76 | import matplotlib.pyplot as plt 77 | from skimage.segmentation import mark_boundaries 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--image", type=str, help="/path/to/image") 80 | parser.add_argument("--weight", default=None, type=str, help="/path/to/pretrained_weight") 81 | parser.add_argument("--fdim", default=20, type=int, help="embedding dimension") 82 | parser.add_argument("--niter", default=10, type=int, help="number of iterations for differentiable SLIC") 83 | parser.add_argument("--nspix", default=100, type=int, help="number of superpixels") 84 | parser.add_argument("--color_scale", default=0.26, type=float) 85 | parser.add_argument("--pos_scale", default=2.5, type=float) 86 | args = parser.parse_args() 87 | 88 | image = plt.imread(args.image) 89 | 90 | s = time.time() 91 | label = inference(image, args.nspix, args.niter, args.fdim, args.color_scale, args.pos_scale, args.weight) 92 | print(f"time {time.time() - s}sec") 93 | plt.imsave("results.png", mark_boundaries(image, label)) 94 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/__init__.py -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/dataset/__init__.py -------------------------------------------------------------------------------- /lib/dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | 5 | 6 | class 
Compose: 7 | def __init__(self, augmentations): 8 | self.augmentations = augmentations 9 | 10 | def __call__(self, data): 11 | for aug in self.augmentations: 12 | data = aug(data) 13 | return data 14 | 15 | 16 | class RandomHorizontalFlip: 17 | def __init__(self, prob=0.5): 18 | self.prob = prob 19 | 20 | def __call__(self, data): 21 | if random.random() < self.prob: 22 | # call copy() to avoid negative stride error in torch.from_numpy 23 | data = [d[:, ::-1].copy() for d in data] 24 | return data 25 | 26 | 27 | class RandomScale: 28 | def __init__(self, scale_range=(0.75, 3.0)): 29 | self.scale_range = scale_range 30 | 31 | def __call__(self, data): 32 | rand_factor = np.random.normal(1, 0.75) 33 | scale = np.min((self.scale_range[1], rand_factor)) 34 | scale = np.max((self.scale_range[0], scale)) 35 | data = [ 36 | cv2.resize(d, None, fx=scale, fy=scale, 37 | interpolation=cv2.INTER_LINEAR if d.dtype == np.float32 else cv2.INTER_NEAREST) 38 | for d in data] 39 | return data 40 | 41 | 42 | class RandomCrop: 43 | def __init__(self, crop_size=(200, 200)): 44 | self.crop_size = crop_size 45 | 46 | def __call__(self, data): 47 | height, width = data[0].shape[:2] 48 | c_h, c_w = self.crop_size 49 | assert height >= c_h and width >= c_w, f"({height}, {width}) v.s. 
({c_h}, {c_w})" 50 | left = random.randint(0, width - c_w) 51 | top = random.randint(0, height - c_h) 52 | data = [d[top:top+c_h, left:left+c_w] for d in data] 53 | return data 54 | -------------------------------------------------------------------------------- /lib/dataset/bsds.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import torch 3 | import numpy as np 4 | import scipy.io 5 | from skimage.color import rgb2lab 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def convert_label(label): 10 | 11 | onehot = np.zeros((1, 50, label.shape[0], label.shape[1])).astype(np.float32) 12 | 13 | ct = 0 14 | for t in np.unique(label).tolist(): 15 | if ct >= 50: 16 | break 17 | else: 18 | onehot[:, ct, :, :] = (label == t) 19 | ct = ct + 1 20 | 21 | return onehot 22 | 23 | 24 | class BSDS: 25 | def __init__(self, root, split="train", color_transforms=None, geo_transforms=None): 26 | self.gt_dir = os.path.join(root, "BSDS500/data/groundTruth", split) 27 | self.img_dir = os.path.join(root, "BSDS500/data/images", split) 28 | 29 | self.index = os.listdir(self.gt_dir) 30 | 31 | self.color_transforms = color_transforms 32 | self.geo_transforms = geo_transforms 33 | 34 | 35 | def __getitem__(self, idx): 36 | idx = self.index[idx][:-4] 37 | gt = scipy.io.loadmat(os.path.join(self.gt_dir, idx+".mat")) 38 | t = np.random.randint(0, len(gt['groundTruth'][0])) 39 | gt = gt['groundTruth'][0][t][0][0][0] 40 | 41 | img = rgb2lab(plt.imread(os.path.join(self.img_dir, idx+".jpg"))) 42 | 43 | gt = gt.astype(np.int64) 44 | img = img.astype(np.float32) 45 | 46 | if self.color_transforms is not None: 47 | img = self.color_transforms(img) 48 | 49 | if self.geo_transforms is not None: 50 | img, gt = self.geo_transforms([img, gt]) 51 | 52 | gt = convert_label(gt) 53 | gt = torch.from_numpy(gt) 54 | img = torch.from_numpy(img) 55 | img = img.permute(2, 0, 1) 56 | 57 | return img, gt.reshape(50, -1).float() 58 | 59 | 60 | def __len__(self): 
61 | return len(self.index) 62 | -------------------------------------------------------------------------------- /lib/ssn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/ssn/__init__.py -------------------------------------------------------------------------------- /lib/ssn/pair_wise_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load_inline 3 | from .pair_wise_distance_cuda_source import source 4 | 5 | 6 | print("compile cuda source of 'pair_wise_distance' function...") 7 | print("NOTE: if you avoid this process, you make .cu file and compile it following https://pytorch.org/tutorials/advanced/cpp_extension.html") 8 | pair_wise_distance_cuda = load_inline( 9 | "pair_wise_distance", cpp_sources="", cuda_sources=source 10 | ) 11 | print("done") 12 | 13 | 14 | class PairwiseDistFunction(torch.autograd.Function): 15 | @staticmethod 16 | def forward(self, pixel_features, spixel_features, init_spixel_indices, num_spixels_width, num_spixels_height): 17 | self.num_spixels_width = num_spixels_width 18 | self.num_spixels_height = num_spixels_height 19 | output = pixel_features.new(pixel_features.shape[0], 9, pixel_features.shape[-1]).zero_() 20 | self.save_for_backward(pixel_features, spixel_features, init_spixel_indices) 21 | 22 | return pair_wise_distance_cuda.forward( 23 | pixel_features.contiguous(), spixel_features.contiguous(), 24 | init_spixel_indices.contiguous(), output, 25 | self.num_spixels_width, self.num_spixels_height) 26 | 27 | @staticmethod 28 | def backward(self, dist_matrix_grad): 29 | pixel_features, spixel_features, init_spixel_indices = self.saved_tensors 30 | 31 | pixel_features_grad = torch.zeros_like(pixel_features) 32 | spixel_features_grad = torch.zeros_like(spixel_features) 33 | 34 | 
pixel_features_grad, spixel_features_grad = pair_wise_distance_cuda.backward( 35 | dist_matrix_grad.contiguous(), pixel_features.contiguous(), 36 | spixel_features.contiguous(), init_spixel_indices.contiguous(), 37 | pixel_features_grad, spixel_features_grad, 38 | self.num_spixels_width, self.num_spixels_height 39 | ) 40 | return pixel_features_grad, spixel_features_grad, None, None, None 41 | 42 | -------------------------------------------------------------------------------- /lib/ssn/pair_wise_distance_cuda_source.py: -------------------------------------------------------------------------------- 1 | source = ''' 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #define CUDA_NUM_THREADS 256 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | 18 | template 19 | __global__ void forward_kernel( 20 | const scalar_t* __restrict__ pixel_features, 21 | const scalar_t* __restrict__ spixel_features, 22 | const scalar_t* __restrict__ spixel_indices, 23 | scalar_t* __restrict__ dist_matrix, 24 | int batchsize, int channels, int num_pixels, int num_spixels, 25 | int num_spixels_w, int num_spixels_h 26 | ){ 27 | int index = blockIdx.x * blockDim.x + threadIdx.x; 28 | if (index >= batchsize * num_pixels * 9) return; 29 | 30 | int cp = channels * num_pixels; 31 | int cs = channels * num_spixels; 32 | 33 | int b = index % batchsize; 34 | int spixel_offset = (index / batchsize) % 9; 35 | int p = (index / (batchsize * 9)) % num_pixels; 36 | 37 | int init_spix_index = spixel_indices[b * num_pixels + p]; 38 | 39 | int x_index = init_spix_index % num_spixels_w; 40 | int spixel_offset_x = (spixel_offset % 3 - 1); 41 | 42 | int y_index = init_spix_index / num_spixels_w; 43 | int spixel_offset_y = (spixel_offset / 3 - 1); 44 | 45 | if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) { 46 | dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16; 47 | } 48 | else if 
(y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) { 49 | dist_matrix[b * (9 * num_pixels) + spixel_offset * num_pixels + p] = 1e16; 50 | } 51 | else { 52 | int query_spixel_index = init_spix_index + spixel_offset_x + num_spixels_w * spixel_offset_y; 53 | 54 | scalar_t sum_squared_diff = 0; 55 | for (int c=0; c<<< block, CUDA_NUM_THREADS >>>( 80 | pixel_features.data(), 81 | spixel_features.data(), 82 | spixel_indices.data(), 83 | dist_matrix.data(), 84 | batchsize, channels, num_pixels, 85 | num_spixels, num_spixels_w, num_spixels_h 86 | ); 87 | })); 88 | 89 | return dist_matrix; 90 | } 91 | 92 | template 93 | __global__ void backward_kernel( 94 | const scalar_t* __restrict__ dist_matrix_grad, 95 | const scalar_t* __restrict__ pixel_features, 96 | const scalar_t* __restrict__ spixel_features, 97 | const scalar_t* __restrict__ spixel_indices, 98 | scalar_t* __restrict__ pixel_feature_grad, 99 | scalar_t* __restrict__ spixel_feature_grad, 100 | int batchsize, int channels, int num_pixels, int num_spixels, 101 | int num_spixels_w, int num_spixels_h 102 | ){ 103 | int index = blockIdx.x * blockDim.x + threadIdx.x; 104 | if (index >= batchsize * num_pixels * 9) return; 105 | 106 | int cp = channels * num_pixels; 107 | int cs = channels * num_spixels; 108 | 109 | int b = index % batchsize; 110 | int spixel_offset = (index / batchsize) % 9; 111 | int p = (index / (batchsize * 9)) % num_pixels; 112 | 113 | int init_spix_index = spixel_indices[b * num_pixels + p]; 114 | 115 | int x_index = init_spix_index % num_spixels_w; 116 | int spixel_offset_x = (spixel_offset % 3 - 1); 117 | 118 | int y_index = init_spix_index / num_spixels_w; 119 | int spixel_offset_y = (spixel_offset / 3 - 1); 120 | 121 | if (x_index + spixel_offset_x < 0 || x_index + spixel_offset_x >= num_spixels_w) return; 122 | else if (y_index + spixel_offset_y < 0 || y_index + spixel_offset_y >= num_spixels_h) return; 123 | else { 124 | int query_spixel_index = init_spix_index + 
spixel_offset_x + num_spixels_w * spixel_offset_y; 125 | 126 | scalar_t dist_matrix_grad_val = dist_matrix_grad[b * (9 * num_pixels) + spixel_offset * num_pixels + p]; 127 | 128 | for (int c=0; c backward_cuda( 141 | const torch::Tensor dist_matrix_grad, 142 | const torch::Tensor pixel_features, 143 | const torch::Tensor spixel_features, 144 | const torch::Tensor spixel_indices, 145 | torch::Tensor pixel_features_grad, 146 | torch::Tensor spixel_features_grad, 147 | int num_spixels_w, int num_spixels_h 148 | ){ 149 | int batchsize = pixel_features.size(0); 150 | int channels = pixel_features.size(1); 151 | int num_pixels = pixel_features.size(2); 152 | int num_spixels = spixel_features.size(2); 153 | 154 | 155 | dim3 block((batchsize * 9 * num_pixels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); 156 | 157 | AT_DISPATCH_FLOATING_TYPES(pixel_features_grad.type(), "backward_kernel", ([&] { 158 | backward_kernel<<< block, CUDA_NUM_THREADS >>>( 159 | dist_matrix_grad.data(), 160 | pixel_features.data(), 161 | spixel_features.data(), 162 | spixel_indices.data(), 163 | pixel_features_grad.data(), 164 | spixel_features_grad.data(), 165 | batchsize, channels, num_pixels, 166 | num_spixels, num_spixels_w, num_spixels_h 167 | ); 168 | })); 169 | 170 | return {pixel_features_grad, spixel_features_grad}; 171 | } 172 | 173 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 174 | m.def("forward", &forward_cuda, "pair_wise_distance forward"); 175 | m.def("backward", &backward_cuda, "pair_wise_distance backward"); 176 | } 177 | ''' -------------------------------------------------------------------------------- /lib/ssn/ssn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from .pair_wise_distance import PairwiseDistFunction 5 | from ..utils.sparse_utils import naive_sparse_bmm 6 | 7 | 8 | def calc_init_centroid(images, num_spixels_width, num_spixels_height): 9 | """ 10 | calculate initial superpixels 11 | 12 | 
Args: 13 | images: torch.Tensor 14 | A Tensor of shape (B, C, H, W) 15 | spixels_width: int 16 | initial superpixel width 17 | spixels_height: int 18 | initial superpixel height 19 | 20 | Return: 21 | centroids: torch.Tensor 22 | A Tensor of shape (B, C, H * W) 23 | init_label_map: torch.Tensor 24 | A Tensor of shape (B, H * W) 25 | num_spixels_width: int 26 | A number of superpixels in each column 27 | num_spixels_height: int 28 | A number of superpixels int each raw 29 | """ 30 | batchsize, channels, height, width = images.shape 31 | device = images.device 32 | 33 | centroids = torch.nn.functional.adaptive_avg_pool2d(images, (num_spixels_height, num_spixels_width)) 34 | 35 | with torch.no_grad(): 36 | num_spixels = num_spixels_width * num_spixels_height 37 | labels = torch.arange(num_spixels, device=device).reshape(1, 1, *centroids.shape[-2:]).type_as(centroids) 38 | init_label_map = torch.nn.functional.interpolate(labels, size=(height, width), mode="nearest") 39 | init_label_map = init_label_map.repeat(batchsize, 1, 1, 1) 40 | 41 | init_label_map = init_label_map.reshape(batchsize, -1) 42 | centroids = centroids.reshape(batchsize, channels, -1) 43 | 44 | return centroids, init_label_map 45 | 46 | 47 | @torch.no_grad() 48 | def get_abs_indices(init_label_map, num_spixels_width): 49 | b, n_pixel = init_label_map.shape 50 | device = init_label_map.device 51 | r = torch.arange(-1, 2.0, device=device) 52 | relative_spix_indices = torch.cat([r - num_spixels_width, r, r + num_spixels_width], 0) 53 | 54 | abs_pix_indices = torch.arange(n_pixel, device=device)[None, None].repeat(b, 9, 1).reshape(-1).long() 55 | abs_spix_indices = (init_label_map[:, None] + relative_spix_indices[None, :, None]).reshape(-1).long() 56 | abs_batch_indices = torch.arange(b, device=device)[:, None, None].repeat(1, 9, n_pixel).reshape(-1).long() 57 | 58 | return torch.stack([abs_batch_indices, abs_spix_indices, abs_pix_indices], 0) 59 | 60 | 61 | @torch.no_grad() 62 | def 
get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width): 63 | relative_label = affinity_matrix.max(1)[1] 64 | r = torch.arange(-1, 2.0, device=affinity_matrix.device) 65 | relative_spix_indices = torch.cat([r - num_spixels_width, r, r + num_spixels_width], 0) 66 | label = init_label_map + relative_spix_indices[relative_label] 67 | return label.long() 68 | 69 | 70 | @torch.no_grad() 71 | def sparse_ssn_iter(pixel_features, num_spixels, n_iter): 72 | """ 73 | computing assignment iterations with sparse matrix 74 | detailed process is in Algorithm 1, line 2 - 6 75 | NOTE: this function does NOT guarantee the backward computation. 76 | 77 | Args: 78 | pixel_features: torch.Tensor 79 | A Tensor of shape (B, C, H, W) 80 | num_spixels: int 81 | A number of superpixels 82 | n_iter: int 83 | A number of iterations 84 | return_hard_label: bool 85 | return hard assignment or not 86 | """ 87 | height, width = pixel_features.shape[-2:] 88 | num_spixels_width = int(math.sqrt(num_spixels * width / height)) 89 | num_spixels_height = int(math.sqrt(num_spixels * height / width)) 90 | 91 | spixel_features, init_label_map = \ 92 | calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height) 93 | abs_indices = get_abs_indices(init_label_map, num_spixels_width) 94 | 95 | pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1) 96 | permuted_pixel_features = pixel_features.permute(0, 2, 1) 97 | 98 | for _ in range(n_iter): 99 | dist_matrix = PairwiseDistFunction.apply( 100 | pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height) 101 | 102 | affinity_matrix = (-dist_matrix).softmax(1) 103 | reshaped_affinity_matrix = affinity_matrix.reshape(-1) 104 | 105 | mask = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels) 106 | sparse_abs_affinity = torch.sparse_coo_tensor(abs_indices[:, mask], reshaped_affinity_matrix[mask]) 107 | spixel_features = naive_sparse_bmm(sparse_abs_affinity, permuted_pixel_features) \ 108 | 
/ (torch.sparse.sum(sparse_abs_affinity, 2).to_dense()[..., None] + 1e-16) 109 | 110 | spixel_features = spixel_features.permute(0, 2, 1) 111 | 112 | hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width) 113 | 114 | return sparse_abs_affinity, hard_labels, spixel_features 115 | 116 | 117 | def ssn_iter(pixel_features, num_spixels, n_iter): 118 | """ 119 | computing assignment iterations 120 | detailed process is in Algorithm 1, line 2 - 6 121 | 122 | Args: 123 | pixel_features: torch.Tensor 124 | A Tensor of shape (B, C, H, W) 125 | num_spixels: int 126 | A number of superpixels 127 | n_iter: int 128 | A number of iterations 129 | return_hard_label: bool 130 | return hard assignment or not 131 | """ 132 | height, width = pixel_features.shape[-2:] 133 | num_spixels_width = int(math.sqrt(num_spixels * width / height)) 134 | num_spixels_height = int(math.sqrt(num_spixels * height / width)) 135 | 136 | spixel_features, init_label_map = \ 137 | calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height) 138 | abs_indices = get_abs_indices(init_label_map, num_spixels_width) 139 | 140 | pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1) 141 | permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous() 142 | 143 | for _ in range(n_iter): 144 | dist_matrix = PairwiseDistFunction.apply( 145 | pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height) 146 | 147 | affinity_matrix = (-dist_matrix).softmax(1) 148 | reshaped_affinity_matrix = affinity_matrix.reshape(-1) 149 | 150 | mask = (abs_indices[1] >= 0) * (abs_indices[1] < num_spixels) 151 | sparse_abs_affinity = torch.sparse_coo_tensor(abs_indices[:, mask], reshaped_affinity_matrix[mask]) 152 | 153 | abs_affinity = sparse_abs_affinity.to_dense().contiguous() 154 | spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \ 155 | / (abs_affinity.sum(2, keepdim=True) + 1e-16) 156 | 157 | spixel_features = 
spixel_features.permute(0, 2, 1).contiguous() 158 | 159 | 160 | hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width) 161 | 162 | return abs_affinity, hard_labels, spixel_features 163 | -------------------------------------------------------------------------------- /lib/ssn/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .pair_wise_distance import PairwiseDistFunction 3 | 4 | 5 | # naive implementation for debug 6 | def naive_pair_wise_dist(pix, spix, idx, n_spix_w, n_spix_h): 7 | device = pix.device 8 | ba, ch, pi = pix.shape 9 | outputs = [] 10 | for b in range(ba): 11 | batch_out = [] 12 | for p in range(pi): 13 | pix_out = [] 14 | pix_v = pix[b, :, p] 15 | sp_i = idx[b, p] 16 | sp_i_x = sp_i % n_spix_w 17 | sp_i_y = sp_i // n_spix_w 18 | for i in range(9): 19 | if sp_i_x == 0 and (i % 3) == 0: 20 | d_dist = pix.new(1).fill_(0) 21 | pix_out.append(d_dist[0]) 22 | elif sp_i_x == (n_spix_w - 1) and (i % 3) == 2: 23 | d_dist = pix.new(1).fill_(0) 24 | pix_out.append(d_dist[0]) 25 | elif sp_i_y == 0 and (i // 3) == 0: 26 | d_dist = pix.new(1).fill_(0) 27 | pix_out.append(d_dist[0]) 28 | elif sp_i_y == (n_spix_h - 1) and (i // 3) == 2: 29 | d_dist = pix.new(1).fill_(0) 30 | pix_out.append(d_dist[0]) 31 | else: 32 | offset_x = i % 3 - 1 33 | offset_y = (i // 3 - 1) * n_spix_w 34 | s = int(sp_i + offset_y + offset_x) 35 | pix_out.append((pix_v - spix[b, :, s]).pow(2).sum()) 36 | batch_out.append(torch.stack(pix_out)) 37 | outputs.append(torch.stack(batch_out, 1)) 38 | return torch.stack(outputs, 0) 39 | 40 | 41 | def test(eps=1e-4): 42 | func = PairwiseDistFunction.apply 43 | 44 | pix = torch.randn(2, 20, 81).double().to("cuda") 45 | spix = torch.randn(2, 20, 9).double().to("cuda") 46 | idx = torch.randint(0, 9, (2, 81)).double().to("cuda") 47 | wid = 3 48 | hei = 3 49 | 50 | pix.requires_grad = True 51 | spix.requires_grad = True 52 | 53 | res = 
torch.autograd.gradcheck(func, (pix, spix, idx, wid, hei), eps=eps, raise_exception=False) 54 | print(res) 55 | 56 | o = PairwiseDistFunction.apply(pix, spix, idx, wid, hei) 57 | o.sum().backward() 58 | 59 | cuda_p_grad = pix.grad 60 | cuda_sp_grad = spix.grad 61 | 62 | pix.grad.zero_() 63 | spix.grad.zero_() 64 | 65 | naive_o = naive_pair_wise_dist(pix, spix, idx, wid, hei) 66 | naive_o.sum().backward() 67 | 68 | print("output diff between GPU and naive", torch.abs(o - naive_o).mean()) 69 | print("pix grad diff between GPU and naive", torch.abs(cuda_p_grad - pix.grad).mean()) 70 | print("spix grad diff between GPU and naive", torch.abs(cuda_sp_grad - spix.grad).mean()) 71 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perrying/ssn-pytorch/3368840b1b72efcd8ea7ca61d1b08b2dfb846d47/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .sparse_utils import naive_sparse_bmm, sparse_permute 3 | 4 | 5 | def sparse_reconstruction(assignment, labels, hard_assignment=None): 6 | """ 7 | reconstruction loss with the sparse matrix 8 | NOTE: this function doesn't use it in this project, because may not return correct gradients 9 | 10 | Args: 11 | assignment: torch.sparse_coo_tensor 12 | A Tensor of shape (B, n_spixels, n_pixels) 13 | labels: torch.Tensor 14 | A Tensor of shape (B, C, n_pixels) 15 | hard_assignment: torch.Tensor 16 | A Tensor of shape (B, n_pixels) 17 | """ 18 | labels = labels.permute(0, 2, 1).contiguous() 19 | 20 | # matrix product between (n_spixels, n_pixels) and (n_pixels, channels) 21 | spixel_mean = naive_sparse_bmm(assignment, labels) / (torch.sparse.sum(assignment, 2).to_dense()[..., None] + 1e-16) 22 | if 
hard_assignment is None: 23 | # (B, n_spixels, n_pixels) -> (B, n_pixels, n_spixels) 24 | permuted_assignment = sparse_permute(assignment, (0, 2, 1)) 25 | # matrix product between (n_pixels, n_spixels) and (n_spixels, channels) 26 | reconstructed_labels = naive_sparse_bmm(permuted_assignment, spixel_mean) 27 | else: 28 | # index sampling 29 | reconstructed_labels = torch.stack([sm[ha, :] for sm, ha in zip(spixel_mean, hard_assignment)], 0) 30 | return reconstructed_labels.permute(0, 2, 1).contiguous() 31 | 32 | 33 | def reconstruction(assignment, labels, hard_assignment=None): 34 | """ 35 | reconstruction 36 | 37 | Args: 38 | assignment: torch.Tensor 39 | A Tensor of shape (B, n_spixels, n_pixels) 40 | labels: torch.Tensor 41 | A Tensor of shape (B, C, n_pixels) 42 | hard_assignment: torch.Tensor 43 | A Tensor of shape (B, n_pixels) 44 | """ 45 | labels = labels.permute(0, 2, 1).contiguous() 46 | 47 | # matrix product between (n_spixels, n_pixels) and (n_pixels, channels) 48 | spixel_mean = torch.bmm(assignment, labels) / (assignment.sum(2, keepdim=True) + 1e-16) 49 | if hard_assignment is None: 50 | # (B, n_spixels, n_pixels) -> (B, n_pixels, n_spixels) 51 | permuted_assignment = assignment.permute(0, 2, 1).contiguous() 52 | # matrix product between (n_pixels, n_spixels) and (n_spixels, channels) 53 | reconstructed_labels = torch.bmm(permuted_assignment, spixel_mean) 54 | else: 55 | # index sampling 56 | reconstructed_labels = torch.stack([sm[ha, :] for sm, ha in zip(spixel_mean, hard_assignment)], 0) 57 | return reconstructed_labels.permute(0, 2, 1).contiguous() 58 | 59 | 60 | def reconstruct_loss_with_cross_etnropy(assignment, labels, hard_assignment=None): 61 | """ 62 | reconstruction loss with cross entropy 63 | 64 | Args: 65 | assignment: torch.Tensor 66 | A Tensor of shape (B, n_spixels, n_pixels) 67 | labels: torch.Tensor 68 | A Tensor of shape (B, C, n_pixels) 69 | hard_assignment: torch.Tensor 70 | A Tensor of shape (B, n_pixels) 71 | """ 72 | 
reconstracted_labels = reconstruction(assignment, labels, hard_assignment) 73 | reconstracted_labels = reconstracted_labels / (1e-16 + reconstracted_labels.sum(1, keepdim=True)) 74 | mask = labels > 0 75 | return -(reconstracted_labels[mask] + 1e-16).log().mean() 76 | 77 | 78 | def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None): 79 | """ 80 | reconstruction loss with mse 81 | 82 | Args: 83 | assignment: torch.Tensor 84 | A Tensor of shape (B, n_spixels, n_pixels) 85 | labels: torch.Tensor 86 | A Tensor of shape (B, C, n_pixels) 87 | hard_assignment: torch.Tensor 88 | A Tensor of shape (B, n_pixels) 89 | """ 90 | reconstracted_labels = reconstruction(assignment, labels, hard_assignment) 91 | return torch.nn.functional.mse_loss(reconstracted_labels, labels) 92 | -------------------------------------------------------------------------------- /lib/utils/meter.py: -------------------------------------------------------------------------------- 1 | class Meter: 2 | def __init__(self, ema_coef=0.9): 3 | self.ema_coef = ema_coef 4 | self.params = {} 5 | 6 | def add(self, params:dict, ignores:list = []): 7 | for k, v in params.items(): 8 | if k in ignores: 9 | continue 10 | if not k in self.params.keys(): 11 | self.params[k] = v 12 | else: 13 | self.params[k] -= (1 - self.ema_coef) * (self.params[k] - v) 14 | 15 | def state(self, header="", footer=""): 16 | state = header 17 | for k, v in self.params.items(): 18 | state += f" {k} {v:.6g} |" 19 | return state + " " + footer 20 | 21 | def reset(self): 22 | self.params = {} -------------------------------------------------------------------------------- /lib/utils/sparse_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def naive_sparse_bmm(sparse_mat, dense_mat, transpose=False): 5 | if transpose: 6 | return torch.stack([torch.sparse.mm(s_mat, d_mat.t()) for s_mat, d_mat in zip(sparse_mat, dense_mat)], 0) 7 | else: 8 | return 
torch.stack([torch.sparse.mm(s_mat, d_mat) for s_mat, d_mat in zip(sparse_mat, dense_mat)], 0) 9 | 10 | def sparse_permute(sparse_mat, order): 11 | values = sparse_mat.coalesce().values() 12 | indices = sparse_mat.coalesce().indices() 13 | indices = torch.stack([indices[o] for o in order], 0).contiguous() 14 | return torch.sparse_coo_tensor(indices, values) 15 | 16 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from lib.ssn.ssn import ssn_iter, sparse_ssn_iter 5 | 6 | 7 | def conv_bn_relu(in_c, out_c): 8 | return nn.Sequential( 9 | nn.Conv2d(in_c, out_c, 3, padding=1, bias=False), 10 | nn.BatchNorm2d(out_c), 11 | nn.ReLU(True) 12 | ) 13 | 14 | class SSNModel(nn.Module): 15 | def __init__(self, feature_dim, nspix, n_iter=10): 16 | super().__init__() 17 | self.nspix = nspix 18 | self.n_iter = n_iter 19 | 20 | self.scale1 = nn.Sequential( 21 | conv_bn_relu(5, 64), 22 | conv_bn_relu(64, 64) 23 | ) 24 | self.scale2 = nn.Sequential( 25 | nn.MaxPool2d(3, 2, padding=1), 26 | conv_bn_relu(64, 64), 27 | conv_bn_relu(64, 64) 28 | ) 29 | self.scale3 = nn.Sequential( 30 | nn.MaxPool2d(3, 2, padding=1), 31 | conv_bn_relu(64, 64), 32 | conv_bn_relu(64, 64) 33 | ) 34 | 35 | self.output_conv = nn.Sequential( 36 | nn.Conv2d(64*3+5, feature_dim-5, 3, padding=1), 37 | nn.ReLU(True) 38 | ) 39 | 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | nn.init.normal_(m.weight, 0, 0.001) 43 | if m.bias is not None: 44 | nn.init.constant_(m.bias, 0) 45 | 46 | 47 | def forward(self, x): 48 | pixel_f = self.feature_extract(x) 49 | 50 | if self.training: 51 | return ssn_iter(pixel_f, self.nspix, self.n_iter) 52 | else: 53 | return sparse_ssn_iter(pixel_f, self.nspix, self.n_iter) 54 | 55 | 56 | def feature_extract(self, x): 57 | s1 = self.scale1(x) 58 | s2 = self.scale2(s1) 59 | s3 = self.scale3(s2) 60 | 
61 | s2 = nn.functional.interpolate(s2, size=s1.shape[-2:], mode="bilinear", align_corners=False) 62 | s3 = nn.functional.interpolate(s3, size=s1.shape[-2:], mode="bilinear", align_corners=False) 63 | 64 | cat_feat = torch.cat([x, s1, s2, s3], 1) 65 | feat = self.output_conv(cat_feat) 66 | 67 | return torch.cat([feat, x], 1) 68 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import numpy as np 3 | import time 4 | import torch 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | 8 | from lib.utils.meter import Meter 9 | from model import SSNModel 10 | from lib.dataset import bsds, augmentation 11 | from lib.utils.loss import reconstruct_loss_with_cross_etnropy, reconstruct_loss_with_mse 12 | 13 | 14 | @torch.no_grad() 15 | def eval(model, loader, color_scale, pos_scale, device): 16 | def achievable_segmentation_accuracy(superpixel, label): 17 | """ 18 | Function to calculate Achievable Segmentation Accuracy: 19 | ASA(S,G) = sum_j max_i |s_j \cap g_i| / sum_i |g_i| 20 | 21 | Args: 22 | input: superpixel image (H, W), 23 | output: ground-truth (H, W) 24 | """ 25 | TP = 0 26 | unique_id = np.unique(superpixel) 27 | for uid in unique_id: 28 | mask = superpixel == uid 29 | label_hist = np.histogram(label[mask]) 30 | maximum_regionsize = label_hist[0].max() 31 | TP += maximum_regionsize 32 | return TP / label.size 33 | 34 | model.eval() 35 | sum_asa = 0 36 | for data in loader: 37 | inputs, labels = data 38 | 39 | inputs = inputs.to(device) 40 | labels = labels.to(device) 41 | 42 | height, width = inputs.shape[-2:] 43 | 44 | nspix_per_axis = int(math.sqrt(model.nspix)) 45 | pos_scale = pos_scale * max(nspix_per_axis/height, nspix_per_axis/width) 46 | 47 | coords = torch.stack(torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device)), 0) 48 | coords = 
coords[None].repeat(inputs.shape[0], 1, 1, 1).float() 49 | 50 | inputs = torch.cat([color_scale*inputs, pos_scale*coords], 1) 51 | 52 | Q, H, feat = model(inputs) 53 | 54 | H = H.reshape(height, width) 55 | labels = labels.argmax(1).reshape(height, width) 56 | 57 | asa = achievable_segmentation_accuracy(H.to("cpu").detach().numpy(), labels.to("cpu").numpy()) 58 | sum_asa += asa 59 | model.train() 60 | return sum_asa / len(loader) 61 | 62 | 63 | def update_param(data, model, optimizer, compactness, color_scale, pos_scale, device): 64 | inputs, labels = data 65 | 66 | inputs = inputs.to(device) 67 | labels = labels.to(device) 68 | 69 | height, width = inputs.shape[-2:] 70 | 71 | nspix_per_axis = int(math.sqrt(model.nspix)) 72 | pos_scale = pos_scale * max(nspix_per_axis/height, nspix_per_axis/width) 73 | 74 | coords = torch.stack(torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device)), 0) 75 | coords = coords[None].repeat(inputs.shape[0], 1, 1, 1).float() 76 | 77 | inputs = torch.cat([color_scale*inputs, pos_scale*coords], 1) 78 | 79 | Q, H, feat = model(inputs) 80 | 81 | recons_loss = reconstruct_loss_with_cross_etnropy(Q, labels) 82 | compact_loss = reconstruct_loss_with_mse(Q, coords.reshape(*coords.shape[:2], -1), H) 83 | 84 | loss = recons_loss + compactness * compact_loss 85 | 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | 90 | return {"loss": loss.item(), "reconstruction": recons_loss.item(), "compact": compact_loss.item()} 91 | 92 | 93 | def train(cfg): 94 | if torch.cuda.is_available(): 95 | device = "cuda" 96 | else: 97 | device = "cpu" 98 | 99 | model = SSNModel(cfg.fdim, cfg.nspix, cfg.niter).to(device) 100 | 101 | optimizer = optim.Adam(model.parameters(), cfg.lr) 102 | 103 | augment = augmentation.Compose([augmentation.RandomHorizontalFlip(), augmentation.RandomScale(), augmentation.RandomCrop()]) 104 | train_dataset = bsds.BSDS(cfg.root, geo_transforms=augment) 105 | train_loader = 
DataLoader(train_dataset, cfg.batchsize, shuffle=True, drop_last=True, num_workers=cfg.nworkers) 106 | 107 | test_dataset = bsds.BSDS(cfg.root, split="val") 108 | test_loader = DataLoader(test_dataset, 1, shuffle=False, drop_last=False) 109 | 110 | meter = Meter() 111 | 112 | iterations = 0 113 | max_val_asa = 0 114 | while iterations < cfg.train_iter: 115 | for data in train_loader: 116 | iterations += 1 117 | metric = update_param(data, model, optimizer, cfg.compactness, cfg.color_scale, cfg.pos_scale, device) 118 | meter.add(metric) 119 | state = meter.state(f"[{iterations}/{cfg.train_iter}]") 120 | print(state) 121 | if (iterations % cfg.test_interval) == 0: 122 | asa = eval(model, test_loader, cfg.color_scale, cfg.pos_scale, device) 123 | print(f"validation asa {asa}") 124 | if asa > max_val_asa: 125 | max_val_asa = asa 126 | torch.save(model.state_dict(), os.path.join(cfg.out_dir, "bset_model.pth")) 127 | if iterations == cfg.train_iter: 128 | break 129 | 130 | unique_id = str(int(time.time())) 131 | torch.save(model.state_dict(), os.path.join(cfg.out_dir, "model"+unique_id+".pth")) 132 | 133 | 134 | if __name__ == "__main__": 135 | import argparse 136 | parser = argparse.ArgumentParser() 137 | 138 | parser.add_argument("--root", type=str, help="/path/to/BSR") 139 | parser.add_argument("--out_dir", default="./log", type=str, help="/path/to/output directory") 140 | parser.add_argument("--batchsize", default=6, type=int) 141 | parser.add_argument("--nworkers", default=4, type=int, help="number of threads for CPU parallel") 142 | parser.add_argument("--lr", default=1e-4, type=float, help="learning rate") 143 | parser.add_argument("--train_iter", default=500000, type=int) 144 | parser.add_argument("--fdim", default=20, type=int, help="embedding dimension") 145 | parser.add_argument("--niter", default=5, type=int, help="number of iterations for differentiable SLIC") 146 | parser.add_argument("--nspix", default=100, type=int, help="number of superpixels") 147 | 
parser.add_argument("--color_scale", default=0.26, type=float) 148 | parser.add_argument("--pos_scale", default=2.5, type=float) 149 | parser.add_argument("--compactness", default=1e-5, type=float) 150 | parser.add_argument("--test_interval", default=10000, type=int) 151 | 152 | args = parser.parse_args() 153 | 154 | os.makedirs(args.out_dir, exist_ok=True) 155 | 156 | train(args) 157 | --------------------------------------------------------------------------------