├── 45000_0.527_model.pt ├── README.md ├── __pycache__ ├── dataset.cpython-36.pyc ├── loss.cpython-36.pyc ├── model.cpython-36.pyc └── util.cpython-36.pyc ├── connectivity.cpython-36m-x86_64-linux-gnu.so ├── dataset.py ├── log ├── events.out.tfevents.1578044794.node02 └── events.out.tfevents.1578044942.node02 ├── loss.py ├── model.py ├── results ├── pix │ ├── 100007_bdry_.jpg │ ├── 100039_bdry_.jpg │ ├── 100099_bdry_.jpg │ ├── 10081_bdry_.jpg │ ├── 101027_bdry_.jpg │ ├── 101084_bdry_.jpg │ ├── 102062_bdry_.jpg │ ├── 103006_bdry_.jpg │ ├── 103029_bdry_.jpg │ ├── 103078_bdry_.jpg │ ├── 104010_bdry_.jpg │ ├── 104055_bdry_.jpg │ └── 105027_bdry_.jpg └── ssn │ ├── 100007_bdry_.jpg │ ├── 100039_bdry_.jpg │ ├── 100099_bdry_.jpg │ ├── 10081_bdry_.jpg │ ├── 101027_bdry_.jpg │ ├── 101084_bdry_.jpg │ ├── 102062_bdry_.jpg │ ├── 103006_bdry_.jpg │ ├── 103029_bdry_.jpg │ ├── 103078_bdry_.jpg │ ├── 104010_bdry_.jpg │ ├── 104055_bdry_.jpg │ ├── 105027_bdry_.jpg │ └── 106005_bdry_.jpg ├── test.py ├── train.py └── util.py /45000_0.527_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/45000_0.527_model.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_ssn 2 | A PyTorch version of SSN (Superpixel Sampling Networks). 3 | The data preparation is the same as in https://github.com/NVlabs/ssn_superpixels.git. 4 | To enforce connectivity in superpixels, the Cython script is taken from the official code. 5 | 6 | To simplify the implementation, each initial superpixel has the same number of pixels during training. 
# Repository digest entries that shared these lines with dataset.py
# (binary artifacts, kept for reference only):
# /__pycache__/dataset.cpython-36.pyc: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/__pycache__/dataset.cpython-36.pyc
# /__pycache__/loss.cpython-36.pyc: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/__pycache__/loss.cpython-36.pyc
# /__pycache__/model.cpython-36.pyc: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/__pycache__/model.cpython-36.pyc
# /__pycache__/util.cpython-36.pyc: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/__pycache__/util.cpython-36.pyc
# /connectivity.cpython-36m-x86_64-linux-gnu.so: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/connectivity.cpython-36m-x86_64-linux-gnu.so
#
# /dataset.py:
"""BSDS500 loading utilities: training (`Dataset`) and test (`Dataset_T`) sets.

Both datasets read image names from ``<root>/<dtype>.txt`` and load images and
ground-truth segmentations from ``<root>/BSR/BSDS500/data``.
"""
from torch.utils import data
import os
import scipy
import scipy.ndimage  # `import scipy` alone does not guarantee the submodule is loaded
from scipy.io import loadmat
from skimage.color import rgb2lab
from skimage.util import img_as_float
from skimage import io
import numpy as np
from random import Random
from scipy import interpolate
from util import convert_index

RAND_SEED = 2356
# Seeded generator used only for the crop offsets in Dataset.__getitem__.
myrandom = Random(RAND_SEED)


def convert_label(label, num=50):
    """Turn an integer label map into a compact label map and one-hot planes.

    :param label: 2-D array (H, W) of segment ids.
    :param num: maximum number of distinct ids that get their own channel.
    :return: ``(label2, problabel)`` where ``problabel`` is a float32 array of
        shape (1, num, H, W) holding one binary plane per id (ids in sorted
        order) and ``label2`` is the (H, W) argmax over those planes.

    Ids beyond the first ``num`` get no plane; their pixels are all-zero in
    ``problabel`` and therefore map to 0 in ``label2``.
    """
    classes = np.unique(label)
    problabel = np.zeros((1, num, label.shape[0], label.shape[1]), dtype=np.float32)

    if classes.size > num:
        # Same diagnostic the original printed when the channel budget ran out.
        print(classes.shape)

    for channel, value in enumerate(classes[:num].tolist()):
        problabel[:, channel, :, :] = (label == value)

    label2 = np.squeeze(np.argmax(problabel, axis=1))
    return label2, problabel


def transform_and_get_image(im, max_spixels, out_size):
    """Pad ``im`` (H, W, C) on the bottom/right to ``out_size`` and add a batch axis.

    :param im: image array of shape (H, W, C).
    :param max_spixels: unused; kept for interface compatibility.
    :param out_size: ``[out_height, out_width]``; must be >= the image size.
    :return: array of shape (1, out_height, out_width, C); padding is -10.
    """
    pad_height = out_size[0] - im.shape[0]
    pad_width = out_size[1] - im.shape[1]
    padded = np.pad(im, ((0, pad_height), (0, pad_width), (0, 0)), 'constant',
                    constant_values=-10)
    return np.expand_dims(padded, axis=0)


def get_spixel_init(num_spixels, img_width, img_height):
    """Build the initial regular-grid superpixel assignment.

    :param num_spixels: requested number of superpixels.
    :param img_width: width of the map to build.
    :param img_height: height of the map to build.
    :return: ``[spixel_initmap, feat_spixel_initmap, k_w, k_h]``; both maps are
        the same (img_height, img_width) array telling which superpixel each
        pixel initially belongs to, ``k_w`` / ``k_h`` are the grid dimensions.
    """
    k = num_spixels
    k_w = int(np.floor(np.sqrt(k * img_width / img_height)))
    k_h = int(np.floor(np.sqrt(k * img_height / img_width)))

    spixel_height = img_height / (1. * k_h)
    spixel_width = img_width / (1. * k_w)

    # Grid of cell centres, extended by one cell on every side so that
    # nearest-neighbour lookup is defined for every pixel of the image.
    h_coords = np.arange(-spixel_height / 2. - 1, img_height + spixel_height - 1,
                         spixel_height)
    w_coords = np.arange(-spixel_width / 2. - 1, img_width + spixel_width - 1,
                         spixel_width)
    spix_values = np.int32(np.arange(0, k_w * k_h).reshape((k_h, k_w)))
    spix_values = np.pad(spix_values, 1, 'symmetric')
    f = interpolate.RegularGridInterpolator((h_coords, w_coords), spix_values,
                                            method='nearest')

    all_h_coords = np.arange(0, img_height, 1)
    all_w_coords = np.arange(0, img_width, 1)
    all_grid = np.array(np.meshgrid(all_h_coords, all_w_coords, indexing='ij'))
    all_points = np.reshape(all_grid, (2, img_width * img_height)).transpose()

    spixel_initmap = f(all_points).reshape((img_height, img_width))
    feat_spixel_initmap = spixel_initmap
    return [spixel_initmap, feat_spixel_initmap, k_w, k_h]


def transform_and_get_spixel_init(max_spixels, out_size):
    """Return the initial superpixel maps for ``out_size`` with batch/channel axes.

    :param max_spixels: requested number of superpixels.
    :param out_size: ``[out_height, out_width]``.
    :return: ``(spixel_init, feat_spixel_initmap, k_h, k_w)``; the maps have
        shape (1, 1, out_height, out_width). Note the (k_h, k_w) order, which
        is the reverse of what ``get_spixel_init`` returns.
    """
    out_height, out_width = out_size[0], out_size[1]
    spixel_init, feat_spixel_initmap, k_w, k_h = \
        get_spixel_init(max_spixels, out_width, out_height)
    return (spixel_init[None, None, :, :],
            feat_spixel_initmap[None, None, :, :],
            k_h, k_w)


def get_rand_scale_factor():
    """Draw a scale factor from N(1, 0.75) and clamp it to [0.75, 3.0]."""
    rand_factor = np.random.normal(1, 0.75)
    return np.clip(rand_factor, 0.75, 3.0)


def scale_image(im, s_factor):
    """Resize an (H, W, C) image spatially by ``s_factor`` with linear interpolation."""
    return scipy.ndimage.zoom(im, (s_factor, s_factor, 1), order=1)


def scale_label(label, s_factor):
    """Resize an (H, W) label map by ``s_factor`` with nearest-neighbour lookup."""
    return scipy.ndimage.zoom(label, (s_factor, s_factor), order=0)


def PixelFeature(img, color_scale=None, pos_scale=None, type=None):
    """Scale colour channels and optionally prepend scaled (y, x) coordinates.

    :param img: array of shape (B, H, W, C).
    :param color_scale: multiplier applied to every channel of ``img``.
    :param pos_scale: multiplier applied to the pixel coordinates.
    :param type: ``'RGB_AND_POSITION'`` to prepend the two coordinate channels;
        any other value returns only the scaled colours.
    :return: array of shape (B, H, W, C) or (B, H, W, C + 2), coordinates first.
    """
    b, h, w, c = img.shape
    feat = img * color_scale
    # Compare by value: identity (`is`) on a string only works by interning luck.
    if type == 'RGB_AND_POSITION':  # channel order: y, x, colour...
        x_mesh, y_mesh = np.meshgrid(np.arange(0, w, 1), np.arange(0, h, 1))
        yx_scaled = np.stack([y_mesh, x_mesh], axis=-1) * pos_scale
        yx_scaled = np.repeat(yx_scaled[np.newaxis], b, axis=0)
        feat = np.concatenate([yx_scaled, feat], axis=-1)
    return feat


def _read_split_list(path):
    """Return the lines of the split file at ``path`` (file is closed afterwards)."""
    with open(path) as handle:
        return handle.readlines()


def _collect_inputs(out_types, out_img, spixel_init, feat_spixel_init, label, problabel):
    """Assemble the per-sample input dict, keeping only the keys in ``out_types``."""
    candidates = {
        'img': np.transpose(out_img[0], [2, 0, 1]).astype(np.float32),
        'spixel_init': spixel_init[0].astype(np.float32),
        'feat_spixel_init': feat_spixel_init[0].astype(np.float32),
        'label': label[np.newaxis],  # (H, W) -> (1, H, W)
        'problabel': problabel[0],
    }
    return {name: candidates[name] for name in out_types if name in candidates}


class Dataset(data.Dataset):
    """Training set: random scale, random horizontal flip, random fixed-size crop."""

    def __init__(self, num_spixel, root='', patch_size=None, dtype='train'):
        """
        :param num_spixel: requested number of superpixels per patch.
        :param root: directory holding ``<dtype>.txt`` and the ``BSR`` tree
            (defaults to the working directory, like ``Dataset_T``).
        :param patch_size: ``[height, width]`` of the training crops (required).
        :param dtype: split name, used for the list file and sub-directories.
        :raises ValueError: if ``patch_size`` is not given.
        """
        if patch_size is None:
            raise ValueError('not define the output size')
        self.patch_size = patch_size
        self.num_spixel = num_spixel
        self.out_types = ['img', 'spixel_init', 'feat_spixel_init', 'label', 'problabel']

        self.root = root
        self.dtype = dtype
        self.data_dir = os.path.join(self.root, 'BSR', 'BSDS500', 'data')

        self.split_list = _read_split_list(os.path.join(root, dtype + '.txt'))
        self.img_dir = os.path.join(self.data_dir, 'images', self.dtype)
        self.gt_dir = os.path.join(self.data_dir, 'groundTruth', self.dtype)

        # The patch size is fixed, so the pixel<->superpixel index tables are
        # computed once here and returned with every sample.
        self.out_spixel_init, self.feat_spixel_init, self.spixels_h, self.spixels_w = \
            transform_and_get_spixel_init(self.num_spixel, [patch_size[0], patch_size[1]])
        self.init, self.cir, self.p2sp_index_, self.invisible = convert_index(
            self.spixels_w, self.spixels_w * self.spixels_h, self.feat_spixel_init)
        # np.float (alias of the builtin float, i.e. float64) no longer exists in NumPy.
        self.invisible = self.invisible.astype(np.float64)

    def __getitem__(self, item):
        """Load, augment and crop sample ``item``.

        :return: ``(inputs, spixels_h, spixels_w, init, cir, p2sp_index_, invisible)``
            where ``inputs`` maps the names in ``self.out_types`` to arrays.
        :raises ValueError: if the patch is larger than the (scaled) image.
        """
        img_name = self.split_list[item].rstrip('\n')
        image = img_as_float(io.imread(os.path.join(self.img_dir, img_name + '.jpg')))
        s_factor = get_rand_scale_factor()
        image = scale_image(image, s_factor)
        im = rgb2lab(image)
        h, w, _ = im.shape

        # Pick one of the available human annotations at random.
        gtseg_all = loadmat(os.path.join(self.gt_dir, img_name + '.mat'))
        t = np.random.randint(0, len(gtseg_all['groundTruth'][0]))
        gtseg = gtseg_all['groundTruth'][0][t][0][0][0]
        gtseg = scale_label(gtseg, s_factor)

        if np.random.uniform(0, 1) > 0.5:
            im = im[:, ::-1, ...]
            gtseg = gtseg[:, ::-1]

        if self.patch_size is None:
            raise ValueError('not define the output size')
        out_height = self.patch_size[0]
        out_width = self.patch_size[1]
        if out_height > h or out_width > w:
            raise ValueError("Patch size is greater than image size")

        start_row = myrandom.randint(0, h - out_height)
        start_col = myrandom.randint(0, w - out_width)
        im_cropped = im[start_row: start_row + out_height,
                        start_col: start_col + out_width, :]
        out_img = transform_and_get_image(im_cropped, self.num_spixel,
                                          [out_height, out_width])
        # add xy information
        out_img = PixelFeature(out_img, color_scale=0.26, pos_scale=0.125,
                               type='RGB_AND_POSITION')

        gtseg_cropped = gtseg[start_row: start_row + out_height,
                              start_col: start_col + out_width]
        label_cropped, problabel_cropped = convert_label(gtseg_cropped)

        inputs = _collect_inputs(self.out_types, out_img, self.out_spixel_init,
                                 self.feat_spixel_init, label_cropped, problabel_cropped)
        return (inputs, self.spixels_h, self.spixels_w, self.init, self.cir,
                self.p2sp_index_, self.invisible)

    def __len__(self):
        return len(self.split_list)


class Dataset_T(data.Dataset):
    """Test set: whole images, padded so the size is a multiple of the grid."""

    def __init__(self, num_spixel, root='', patch_size=None, dtype='test'):
        """
        :param num_spixel: requested number of superpixels per image.
        :param root: directory holding ``<dtype>.txt`` and the ``BSR`` tree.
        :param patch_size: stored but not used by this class.
        :param dtype: split name, used for the list file and sub-directories.
        """
        self.patch_size = patch_size
        self.num_spixel = num_spixel
        self.out_types = ['img', 'spixel_init', 'feat_spixel_init', 'label', 'problabel']

        self.root = root
        self.dtype = dtype
        self.data_dir = os.path.join(self.root, 'BSR', 'BSDS500', 'data')

        self.split_list = _read_split_list(os.path.join(root, dtype + '.txt'))
        self.img_dir = os.path.join(self.data_dir, 'images', self.dtype)
        self.gt_dir = os.path.join(self.data_dir, 'groundTruth', self.dtype)

    def __getitem__(self, item):
        """Load sample ``item`` without augmentation.

        :return: ``(inputs, spixels_h, spixels_w, init, cir, p2sp_index_,
            invisible, image_path)``.
        """
        img_name = self.split_list[item].rstrip('\n')
        img_path = os.path.join(self.img_dir, img_name + '.jpg')
        image = img_as_float(io.imread(img_path))

        im = rgb2lab(image)
        h, w, _ = im.shape

        # Always use the first annotation so evaluation is deterministic.
        gtseg_all = loadmat(os.path.join(self.gt_dir, img_name + '.mat'))
        gtseg = gtseg_all['groundTruth'][0][0][0][0][0]

        k = self.num_spixel
        k_w = int(np.floor(np.sqrt(k * w / h)))
        k_h = int(np.floor(np.sqrt(k * h / w)))
        spixel_height = h / (1. * k_h)
        spixel_width = w / (1. * k_w)

        # Round the cell size up so the padded image is an exact grid multiple.
        out_height = int(np.ceil(spixel_height) * k_h)
        out_width = int(np.ceil(spixel_width) * k_w)

        out_img = transform_and_get_image(im, self.num_spixel, [out_height, out_width])
        # add xy information
        pos_scale = 2.5 * max(k_h / out_height, k_w / out_width)
        out_img = PixelFeature(out_img, color_scale=0.26, pos_scale=pos_scale,
                               type='RGB_AND_POSITION')

        # Padding pixels get the label value 49.
        gtseg_ = np.ones_like(out_img[0, :, :, 0]) * 49
        gtseg_[:h, :w] = gtseg
        label_cropped, problabel_cropped = convert_label(gtseg_)

        # Image sizes vary, so the index tables are rebuilt for every sample.
        self.out_spixel_init, self.feat_spixel_init, self.spixels_h, self.spixels_w = \
            transform_and_get_spixel_init(self.num_spixel, [out_height, out_width])
        self.init, self.cir, self.p2sp_index_, self.invisible = convert_index(
            self.spixels_w, self.spixels_w * self.spixels_h, self.feat_spixel_init)
        self.invisible = self.invisible.astype(np.float64)

        inputs = _collect_inputs(self.out_types, out_img, self.out_spixel_init,
                                 self.feat_spixel_init, label_cropped, problabel_cropped)
        return (inputs, self.spixels_h, self.spixels_w, self.init, self.cir,
                self.p2sp_index_, self.invisible, img_path)

    def __len__(self):
        return len(self.split_list)


if __name__ == '__main__':
    # Smoke test: iterate once over the training set found in the working directory.
    # (Named `dataset`, not `data`, so the torch.utils.data import is not shadowed.)
    dataset = Dataset(100, patch_size=[200, 200])
    for _ in dataset:
        pass
# Repository digest entries that shared these lines with the code
# (binary artifacts, kept for reference only):
# /log/events.out.tfevents.1578044794.node02: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/log/events.out.tfevents.1578044794.node02
# /log/events.out.tfevents.1578044942.node02: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/log/events.out.tfevents.1578044942.node02

# --------------------------------------------------------------------------------
# /loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import pdb


class position_color_loss(nn.Module):
    """Weighted MSE between reconstructed and original pixel features."""

    def __init__(self, pos_weight=0.4, col_weight=0.):
        """
        :param pos_weight: weight of the loss on the first two (position) channels
        :param col_weight: weight of the loss on the remaining (colour) channels
        """
        super(position_color_loss, self).__init__()
        self.pos_weight = pos_weight
        self.col_weight = col_weight
        self.mse_loss = nn.MSELoss()

    def forward(self, recon_feat, pixel_features):
        """
        :param recon_feat: B*C*H*W reconstructed pixel features (position channels first)
        :param pixel_features: B*C*H*W original pixel features
        :return: scalar ``pos_weight * MSE(position) + col_weight * MSE(colour)``
        """
        pos_loss = self.mse_loss(recon_feat[:, :2, :, :], pixel_features[:, :2, :, :])
        color_loss = self.mse_loss(recon_feat[:, 2:, :, :], pixel_features[:, 2:, :, :])
        return pos_loss * self.pos_weight + color_loss * self.col_weight


class LossWithoutSoftmax(nn.Module):
    """Negative log-likelihood on inputs that are already probabilities."""

    def __init__(self, loss_weight=1.0, ignore_label=255):
        """
        :param loss_weight: multiplier applied to the mean loss
        :param ignore_label: label value excluded from the loss when no
            explicit ``invisible_p`` mask is passed to ``forward``
        """
        super(LossWithoutSoftmax, self).__init__()
        self.loss_weight = loss_weight
        self.ignore_label = ignore_label
        self.NLLloss = nn.NLLLoss(reduction='none')

    def forward(self, recon_label3, label, invisible_p=None):
        """
        :param recon_label3: B*C*H*W reconstructed label probabilities
        :param label: B*1*H*W ground-truth label indices
        :param invisible_p: optional B*H*W mask, 1 where the pixel is ignored
        :return: scalar ``loss_weight * mean(-log p[label])`` over kept pixels
        :raises IOError: if neither ``invisible_p`` nor ``ignore_label`` is set
        """
        label = label[:, 0, ...]

        # add ignore region
        if invisible_p is not None:
            ignore = invisible_p == 1.
        elif self.ignore_label is not None:
            ignore = label == self.ignore_label
        else:
            raise IOError

        # Ignored pixels need *some* valid class index for NLLLoss. Fill
        # out-of-place: `label` is a view of the caller's tensor, and writing
        # into it would corrupt the ground truth for later use.
        label = label.masked_fill(ignore, 0)

        # NLLLoss without log-softmax returns -p[label]; negate to recover p.
        neg_prob = self.NLLloss(recon_label3, label)  # B*H*W
        # `1 - ignore` is rejected for bool masks by current PyTorch; use `~`.
        prob = -1 * neg_prob[~ignore]
        loss = -1 * torch.log(prob)
        return loss.mean() * self.loss_weight


# --------------------------------------------------------------------------------
# /model.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from util import *
from loss import *


class conv_bn_relu(nn.Module):
    """3x3 convolution (padding 1), optional batch norm, then ReLU."""

    def __init__(self, in_channels, channels, bn=True):
        super(conv_bn_relu, self).__init__()
        self.BN_ = bn
        self.conv = nn.Conv2d(in_channels, channels, 3, padding=1)
        if self.BN_:
            self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.BN_:
            x = self.bn(x)
        return self.relu(x)


class cnn_module(nn.Module):
    """Feature extractor: three conv stages, multi-scale concat, input appended.

    Takes a 5-channel input and returns ``5 + out_channel`` channels (the
    input concatenated with the learned features).
    """

    def __init__(self, out_channel=15):
        super(cnn_module, self).__init__()
        self.conv1 = conv_bn_relu(5, 64)
        self.conv2 = conv_bn_relu(64, 64)
        self.pool1 = nn.MaxPool2d(3, 2, 1)

        self.conv3 = conv_bn_relu(64, 64)
        self.conv4 = conv_bn_relu(64, 64)
        self.pool2 = nn.MaxPool2d(3, 2, 1)

        self.conv5 = conv_bn_relu(64, 64)
        self.conv6 = conv_bn_relu(64, 64)

        self.conv6_up = nn.Upsample(scale_factor=4)
        self.conv4_up = nn.Upsample(scale_factor=2)

        # 197 = 5 (input) + 3 * 64 (conv2, upsampled conv4, upsampled conv6)
        self.conv7 = conv_bn_relu(197, out_channel, False)

    def forward(self, x):
        conv2 = self.conv2(self.conv1(x))
        conv4 = self.conv4(self.conv3(self.pool1(conv2)))
        conv6 = self.conv6(self.conv5(self.pool2(conv4)))

        conv_concat = torch.cat(
            (x, conv2, self.conv4_up(conv4), self.conv6_up(conv6)), 1)
        conv7 = self.conv7(conv_concat)
        return torch.cat((x, conv7), 1)


class create_ssn_net(nn.Module):
    """Superpixel sampling network: optional CNN features + iterative clustering."""

    def __init__(self, num_spixels, num_iter, num_spixels_h, num_spixels_w, dtype='train', ssn=1):
        """
        :param num_spixels: accepted for interface compatibility; the stored
            count is always ``num_spixels_h * num_spixels_w``
        :param num_iter: number of clustering iterations run in ``forward``
        :param num_spixels_h: initial superpixel grid height
        :param num_spixels_w: initial superpixel grid width
        :param dtype: ``'train'`` or ``'test'``; selects what ``forward`` returns
        :param ssn: truthy to run the CNN on the input, falsy to cluster the
            raw input features directly
        """
        super(create_ssn_net, self).__init__()
        self.trans_features = cnn_module()
        self.num_iter = num_iter
        self.num_spixels_h = num_spixels_h
        self.num_spixels_w = num_spixels_w
        self.num_spixels = num_spixels_h * num_spixels_w
        self.dtype = dtype
        self.ssn = ssn

    def forward(self, x, p2sp_index, invisible, init_index, cir_index, problabel, spixel_h, spixel_w, device):
        """Cluster the pixels of ``x`` and reconstruct features and labels.

        The index arguments are the tables produced by the dataset; they are
        passed through unchanged to the helpers imported from ``util``.

        :return: ``(recon_feat2, recon_label)`` when ``dtype == 'train'``;
            ``(recon_feat2, recon_label, new_spix_indices)`` when
            ``dtype == 'test'``
        :raises ValueError: for any other ``dtype`` (previously the whole
            pipeline ran and ``None`` was returned silently)
        """
        if self.dtype not in ('train', 'test'):
            raise ValueError("dtype must be 'train' or 'test', got %r" % (self.dtype,))

        trans_features = self.trans_features(x) if self.ssn else x

        # The grid size comes from the batch (first sample), not from __init__.
        self.num_spixels_h = spixel_h[0]
        self.num_spixels_w = spixel_w[0]
        self.num_spixels = spixel_h[0] * spixel_w[0]
        self.device = device
        grid_h, grid_w = self.num_spixels_h, self.num_spixels_w

        # init spixel feature
        spixel_feature = SpixelFeature(trans_features, init_index, max_spixels=self.num_spixels)

        for _step in range(self.num_iter):
            spixel_feature, _assoc = exec_iter(spixel_feature, trans_features, cir_index, p2sp_index,
                                               invisible, grid_h, grid_w, self.device)

        final_pixel_assoc = compute_assignments(spixel_feature, trans_features, p2sp_index,
                                                invisible, device)  # out of memory

        new_spixel_feat = SpixelFeature2(x, final_pixel_assoc, cir_index, invisible,
                                         grid_h, grid_w)
        new_spix_indices = compute_final_spixel_labels(final_pixel_assoc, p2sp_index,
                                                       grid_h, grid_w)
        recon_feat2 = Semar(new_spixel_feat, new_spix_indices)
        spixel_label = SpixelFeature2(problabel, final_pixel_assoc, cir_index, invisible,
                                      grid_h, grid_w)
        # NOTE(review): 50 presumably is the number of label channels in
        # `problabel` - confirm against util.decode_features.
        recon_label = decode_features(final_pixel_assoc, spixel_label, p2sp_index,
                                      grid_h, grid_w, self.num_spixels, 50)

        if self.dtype == 'train':
            return recon_feat2, recon_label
        return recon_feat2, recon_label, new_spix_indices


class Loss(nn.Module):
    """Sum of the feature reconstruction loss and the label reconstruction loss."""

    def __init__(self):
        super(Loss, self).__init__()
        self.loss1 = position_color_loss()
        self.loss2 = LossWithoutSoftmax()

    def forward(self, recon_feat2, pixel_feature, recon_label, label):
        """Return ``(total, feature_loss, label_loss)``."""
        loss1 = self.loss1(recon_feat2, pixel_feature)
        loss2 = self.loss2(recon_label, label)
        return loss1 + loss2, loss1, loss2


# Repository digest entries that followed model.py on these lines
# (result images, kept for reference only):
# /results/pix/100007_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/100007_bdry_.jpg
# /results/pix/100039_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/100039_bdry_.jpg
# /results/pix/100099_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/100099_bdry_.jpg
# /results/pix/10081_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/10081_bdry_.jpg
# /results/pix/101027_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/101027_bdry_.jpg
# /results/pix/101084_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/101084_bdry_.jpg
-------------------------------------------------------------------------------- /results/pix/102062_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/102062_bdry_.jpg -------------------------------------------------------------------------------- /results/pix/103006_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/103006_bdry_.jpg -------------------------------------------------------------------------------- /results/pix/103029_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/103029_bdry_.jpg -------------------------------------------------------------------------------- /results/pix/103078_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/103078_bdry_.jpg -------------------------------------------------------------------------------- /results/pix/104010_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/104010_bdry_.jpg -------------------------------------------------------------------------------- /results/pix/104055_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/104055_bdry_.jpg -------------------------------------------------------------------------------- 
/results/pix/105027_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/pix/105027_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/100007_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/100007_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/100039_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/100039_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/100099_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/100099_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/10081_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/10081_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/101027_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/101027_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/101084_bdry_.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/101084_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/102062_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/102062_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/103006_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/103006_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/103029_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/103029_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/103078_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/103078_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/104010_bdry_.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/104010_bdry_.jpg -------------------------------------------------------------------------------- /results/ssn/104055_bdry_.jpg: -------------------------------------------------------------------------------- 
# Repository digest entries that shared these lines with the code
# (result images, kept for reference only):
# https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/104055_bdry_.jpg
# /results/ssn/105027_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/105027_bdry_.jpg
# /results/ssn/106005_bdry_.jpg: https://raw.githubusercontent.com/CYang0515/pytorch_ssn/289ba3132525e1312a018abdcc082b912f7f9021/results/ssn/106005_bdry_.jpg

# --------------------------------------------------------------------------------
# /test.py:
# --------------------------------------------------------------------------------
import os
import torch
from torch.utils import data
from dataset import Dataset_T
from model import create_ssn_net, Loss
from PIL import Image
import scipy
from util import get_spixel_image
import sys
import numpy as np
import argparse
import imageio
import scipy.io as scio
# sys.path.append('')
from connectivity import enforce_connectivity
os.environ['CUDA_VISIBLE_DEVICES']='0'


def compute_spixels(num_spixel, num_steps, pre_model, out_folder):
    """Segment every test image into superpixels and save boundary overlays.

    :param num_spixel: requested number of superpixels per image
    :param num_steps: number of clustering iterations
    :param pre_model: path of the checkpoint to load (required)
    :param out_folder: directory that receives ``<image>_bdry_.jpg`` files
    :raises ValueError: if ``pre_model`` is None
    """
    # Fail before building the dataset and model; the original raised a plain
    # string here, which is itself a TypeError.
    if pre_model is None:
        raise ValueError('no model')

    os.makedirs(out_folder, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = 'test'
    dataloader = data.DataLoader(Dataset_T(num_spixel=num_spixel),
                                 batch_size=1, shuffle=False, num_workers=1)
    model = create_ssn_net(num_spixels=num_spixel, num_iter=num_steps, num_spixels_h=10,
                           num_spixels_w=10, dtype=dtype, ssn=0)
    # Wrapped before loading: load_state_dict is applied to the DataParallel wrapper.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(pre_model, map_location=device))
    model.to(device)
    model.eval()  # inference: batch-norm layers must use their running statistics

    for step, (inputs, num_h, num_w, init_index, cir_index, p2sp_index_, invisible,
               file_name) in enumerate(dataloader):
        with torch.no_grad():
            img = inputs['img'].to(device)
            problabel = inputs['problabel'].to(device)
            num_h = num_h.to(device)
            num_w = num_w.to(device)
            init_index = [x.to(device) for x in init_index]
            cir_index = [x.to(device) for x in cir_index]
            p2sp_index_ = p2sp_index_.to(device)
            invisible = invisible.to(device)
            recon_feat2, recon_label, new_spix_indices = model(
                img, p2sp_index_, invisible, init_index, cir_index, problabel,
                num_h, num_w, device)

        # The network ran on a padded image; crop the labels back to the
        # size of the file on disk.
        given_img = np.asarray(Image.open(file_name[0]))
        h, w = given_img.shape[0], given_img.shape[1]
        new_spix_indices = new_spix_indices[:, :h, :w].contiguous()
        spix_index = new_spix_indices.cpu().numpy()[0].astype(int)

        # Always enforce connectivity (the original `if enforce_connectivity:`
        # tested the imported function object, which is always true).
        segment_size = (h * w) / (int(num_h * num_w) * 1.0)
        min_size = int(0.06 * segment_size)
        max_size = int(3 * segment_size)
        spix_index = enforce_connectivity(spix_index[np.newaxis, :, :], min_size, max_size)[0]

        spixel_image = get_spixel_image(given_img, spix_index)

        imgname = os.path.splitext(os.path.basename(file_name[0]))[0]
        # os.path.join also works when out_folder has no trailing separator.
        out_img_file = os.path.join(out_folder, imgname + '_bdry_.jpg')
        imageio.imwrite(out_img_file, spixel_image)
        print(step)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--n_spixels', type=int, default=100)
    parser.add_argument('--num_steps', type=int, default=10)
    parser.add_argument('--result_dir', type=str, default='./results/pix/')
    parser.add_argument('--pre_dir', type=str, default='./45000_0.527_model.pt')

    var_args = parser.parse_args()
    compute_spixels(var_args.n_spixels, var_args.num_steps,
                    var_args.pre_dir, var_args.result_dir)


# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
import numpy as np
import argparse
import sys
from torch.utils import data
import torchvision.transforms as tf
from dataset import Dataset
from model import create_ssn_net, Loss
import torch
import os
import time
import logging
from tensorboardX import SummaryWriter
os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'

# NOTE(review): train.py continues past the end of the visible source, so
# this class is reproduced without logic changes.
class loss_logger():
    def __init__(self):
        self.loss = 0
        self.loss1 = 0
        self.loss2 = 0
        self.count = 0
    def add(self, l, l1, l2):
        self.loss += l
        self.loss1 +=l1
        self.loss2 +=l2
        self.count +=1
    def ave(self):
        self.loss /= self.count
        self.loss1 /= self.count
        self.loss2 /= self.count
class loss_logger():
    """Accumulate the three loss terms over an epoch and average them."""

    def __init__(self):
        self.loss = 0
        self.loss1 = 0
        self.loss2 = 0
        self.count = 0

    def add(self, l, l1, l2):
        """Add one mini-batch worth of (total, first, second) loss values."""
        self.loss += l
        self.loss1 += l1
        self.loss2 += l2
        self.count += 1

    def ave(self):
        """Turn the running sums into means; a no-op when nothing was added."""
        # An empty epoch used to raise ZeroDivisionError here.
        if self.count == 0:
            return
        self.loss /= self.count
        self.loss1 /= self.count
        self.loss2 /= self.count

    def clear(self):
        """Reset all sums and the sample counter."""
        self.__init__()


def train_net(args, writer, dtype='train'):
    """
    Train the SSN network and log per-epoch mean losses.

    :param args: parsed arguments providing ``root_dir``, ``num_steps`` and ``l_rate``
    :param writer: summary writer exposing ``add_scalar(tag, value, step)``
    :param dtype: ``'train'`` or ``'test'``; any other value builds the objects and returns
    """
    is_shuffle = dtype == 'train'
    dataloader = data.DataLoader(
        Dataset(num_spixel=100, patch_size=[200, 200], root=args.root_dir, dtype=dtype),
        batch_size=16, shuffle=is_shuffle, num_workers=4)

    # build model
    model = create_ssn_net(num_spixels=100, num_iter=args.num_steps,
                           num_spixels_h=10, num_spixels_w=10, dtype=dtype)
    # loss function
    criten = Loss()

    device = torch.device('cpu')
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model)
        model.cuda()
        device = torch.device('cuda')
    optim = torch.optim.Adam(model.parameters(), lr=args.l_rate)

    if dtype not in ('train', 'test'):
        return
    if dtype == 'train':
        model.train()

    # torch.save does not create missing directories.
    checkpoint_dir = './checkpoints/checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    logger = loss_logger()
    for epoch in range(100000):
        logger.clear()
        # Anomaly detection is a debugging aid that slows every backward pass,
        # so it is no longer enabled around each iteration.
        for step, (inputs, num_h, num_w, init_index, cir_index,
                   p2sp_index_, invisible) in enumerate(dataloader):
            t0 = time.time()
            img = inputs['img'].to(device)
            label = inputs['label'].to(device)
            problabel = inputs['problabel'].to(device)
            num_h = num_h.to(device)
            num_w = num_w.to(device)
            init_index = [x.to(device) for x in init_index]
            cir_index = [x.to(device) for x in cir_index]
            p2sp_index_ = p2sp_index_.to(device)
            invisible = invisible.to(device)

            t1 = time.time()
            recon_feat2, recon_label = model(img, p2sp_index_, invisible, init_index,
                                             cir_index, problabel, num_h, num_w, device)
            loss, loss_1, loss_2 = criten(recon_feat2, img, recon_label, label)
            t2 = time.time()

            # optimizer
            optim.zero_grad()
            loss.backward()
            optim.step()
            t3 = time.time()
            print(f'epoch:{epoch}, iter:{step}, total_loss:{loss}, pos_loss:{loss_1}, rec_loss:{loss_2}')
            print(f'forward time:{t2-t1:.3f}, backward time:{t3-t2:.3f}, total time:{t3-t0:.3f}')
            # Accumulate plain floats so the logger holds no tensors.
            logger.add(loss.item(), loss_1.item(), loss_2.item())

        logger.ave()
        writer.add_scalar('train/total_loss', logger.loss, epoch)
        writer.add_scalar('train/pos_loss', logger.loss1, epoch)
        writer.add_scalar('train/rec_loss', logger.loss2, epoch)

        if epoch % 100 == 0 and epoch != 0:
            # Name the checkpoint after the epoch-mean loss: it is always
            # defined, unlike the last mini-batch loss used before.
            torch.save(model.state_dict(),
                       f'{checkpoint_dir}/{epoch}_{logger.loss:.3f}_model.pt')


def main():
    """Parse the command line, create the summary writer and start training."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--l_rate', type=float, default=0.0001)
    parser.add_argument('--num_steps', type=int, default=5)
    parser.add_argument('--root_dir', type=str, default='/home/yc/ssn_superpixels/data')

    var_args = parser.parse_args()
    writer = SummaryWriter('log')
    try:
        train_net(var_args, writer)
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()


if __name__ == '__main__':
    main()
def _group_by_label(labels, max_spixels):
    """Return (coordinate arrays grouped by label, cumulative group sizes) for labels in [0, max_spixels)."""
    labels = np.asarray(labels)
    flat = labels.ravel()
    # A stable sort keeps row-major order inside each label, which is the
    # order np.where(labels == i) would produce for every i in turn.
    order = np.argsort(flat, kind='stable')
    sorted_labels = flat[order]
    keep = (sorted_labels >= 0) & (sorted_labels < max_spixels)
    order = order[keep]
    counts = np.bincount(sorted_labels[keep].astype(np.int64), minlength=max_spixels)
    coords = np.unravel_index(order, labels.shape)
    return coords, np.cumsum(counts)


def convert_index(num_spixel_w=10, max_spixels=100, feat_spixel_init=None):
    '''
    Build the index tables used to pool pixels into superpixels.

    :param num_spixel_w: the number of spixels in a row
    :param max_spixels: the number of spixels
    :param feat_spixel_init: 1*1*H*W, each pixel with its initial (integer) spixel id
    :return: [init_x, init_y, init_cum]: pixel coordinates grouped by spixel id and cumulative counts;
             [cir_x, cir_y, cir_z, cir_cum]: (x, y, neighbour slot) entries of the H*W*9 table grouped by spixel id;
             p2sp_index_, invisible: the outputs of Passoc_Nspixel
    :raises ValueError: if feat_spixel_init is None
    '''
    if feat_spixel_init is None:
        raise ValueError('feat_spixel_init is required')

    label_map = np.asarray(feat_spixel_init)[0, 0]

    # One sort instead of one full-array scan per superpixel id.
    (init_y, init_x), init_cum = _group_by_label(label_map, max_spixels)

    p2sp_index_, invisible = Passoc_Nspixel(label_map, num_spixel_w, max_spixels)  # H*W*9, H*W*9
    (cir_y, cir_x, cir_z), cir_cum = _group_by_label(p2sp_index_, max_spixels)

    return [init_x, init_y, init_cum], [cir_x, cir_y, cir_z, cir_cum], p2sp_index_, invisible


def SpixelFeature(feat, init_index, max_spixels=50, invisible_p=None):
    """
    Initialise superpixel features by averaging the pixels assigned to each superpixel.

    Every superpixel is assumed to own the same number of pixels; that number is
    read from the first entry of the cumulative counts in ``init_index``.

    :param feat: input feature of shape (B,C,H,W)
    :param init_index: [init_x, init_y, init_cum]; 1-D numpy arrays or tensors with a leading batch dimension
    :param max_spixels: number of superpixels
    :param invisible_p: optional (B,H,W) mask, non-zero / True for pixels to ignore
    :return: ave_feat of shape (B,C,max_spixels)
    """
    b, c = feat.shape[0], feat.shape[1]
    init_x, init_y, init_l = init_index  # B*n B*n B*D
    if len(init_x.shape) == 1:
        init_x = torch.from_numpy(init_x).unsqueeze(0)
        init_y = torch.from_numpy(init_y).unsqueeze(0)
        init_l = torch.from_numpy(init_l).unsqueeze(0)

    per_spixel = int(init_l[0, 0])
    ys, xs = init_y[0], init_x[0]

    grouped = feat[:, :, ys, xs].reshape(b, c, max_spixels, per_spixel)  # B*C*D*(n/D)
    if invisible_p is None:
        return grouped.sum(dim=3) / float(per_spixel)

    # Cast the mask first: ``1 - mask`` is rejected for boolean tensors.
    hidden = invisible_p[:, ys, xs].reshape(b, 1, max_spixels, per_spixel).to(feat.dtype)
    visible = 1.0 - hidden
    return (grouped * visible).sum(dim=3) / (visible.sum(dim=3) + 1e-5)
def Passoc_Nspixel(spixel_init, num_spixels_w, num_spixs):
    """
    For every pixel, list the ids of the 3x3 block of superpixels around its own
    superpixel, and flag the slots that fall outside the superpixel grid.

    Slot order is: up-left, up, up-right, left, center, right, down-left, down, down-right.
    An out-of-grid slot is filled with a valid substitute id and marked invisible.

    :param spixel_init: (H,W) superpixel id of each pixel
    :param num_spixels_w: the number of superpixels in one row of the grid
    :param num_spixs: the total number of superpixels
    :return: p2sp_index_: neighbouring superpixel ids, H*W*9
             invisible: True where the neighbouring superpixel does not exist, H*W*9
    """
    center = np.asarray(spixel_init)

    up = center - num_spixels_w
    down = center + num_spixels_w

    up_out = up <= -1
    down_out = down >= num_spixs
    right_out = (center + 1) % num_spixels_w == 0
    left_out = center % num_spixels_w == 0

    # Straight neighbours fall back to the center id when outside the grid.
    up = np.where(up_out, center, up)
    down = np.where(down_out, center, down)
    right = np.where(right_out, center, center + 1)
    left = np.where(left_out, center, center - 1)

    # Diagonal neighbours fall back to the (already clamped) vertical neighbour.
    up_left_out = left_out | up_out
    up_right_out = right_out | up_out
    down_left_out = left_out | down_out
    down_right_out = right_out | down_out

    up_left = np.where(up_left_out, up, up - 1)
    up_right = np.where(up_right_out, up, up + 1)
    down_left = np.where(down_left_out, down, down - 1)
    down_right = np.where(down_right_out, down, down + 1)

    p2sp_index_ = np.stack([up_left, up, up_right,
                            left, center, right,
                            down_left, down, down_right], axis=-1)  # H*W*9

    center_out = np.zeros_like(left_out)  # the center slot is always visible
    invisible = np.stack([up_left_out, up_out, up_right_out,
                          left_out, center_out, right_out,
                          down_left_out, down_out, down_right_out], axis=-1)

    return p2sp_index_, invisible


def Passoc(pixel_features, spixel_feat, p2sp_index_, invisible_, device, scale_value=-1):
    '''
    Compute the scaled squared distance between each pixel and its 9 surrounding superpixels.

    :param pixel_features: (B,C,H,W)
    :param spixel_feat: (B,C,D), D is the number of superpixels
    :param p2sp_index_: B*H*W*9 tensor, or an H*W*9 array shared by the whole batch
    :param invisible_: same layout as p2sp_index_; non-zero marks a slot to be pushed far away
    :param device: kept for backward compatibility; the computation runs on pixel_features' device
    :param scale_value: factor applied to the distances (default -1)
    :return: B*9*H*W scaled distances; invisible slots carry 10000 * scale_value
    '''
    b = pixel_features.shape[0]
    work_device = pixel_features.device

    if len(p2sp_index_.shape) == 3:
        p2sp_index_ = torch.as_tensor(p2sp_index_).unsqueeze(0)
        invisible_ = torch.as_tensor(invisible_).unsqueeze(0)

    # Keep every index on the feature device so array inputs also work there.
    p2sp_index = p2sp_index_.long().to(work_device)
    invisible = invisible_.to(work_device).permute(0, 3, 1, 2).float()  # B*9*H*W

    # A (B,1,1,1) batch index broadcasts against the neighbour table.
    batch_index = torch.arange(b, device=work_device).reshape(b, 1, 1, 1)
    spixel_feat = spixel_feat.permute(0, 2, 1)  # B*C*D -> B*D*C
    p2sp_feat = spixel_feat[batch_index, p2sp_index]  # B*H*W*9*C
    p2sp_feat = p2sp_feat.permute(3, 0, 4, 1, 2)  # 9*B*C*H*W

    distance = torch.pow(p2sp_feat - pixel_features, 2.0)  # 9*B*C*H*W
    distance = distance.sum(2).permute(1, 0, 2, 3)  # B*9*H*W

    distance = distance * (1 - invisible) + 10000.0 * invisible
    return distance * scale_value  # B*9*H*W
def SpixelFeature2(pixel_features, pixel_assoc, cir_index, invisible, num_spixels_h, num_spixels_w):
    '''
    Update superpixel features as the association-weighted mean of the pixels around them.

    Every superpixel is assumed to own the same number of (pixel, slot) entries;
    that number is read from the first entry of the cumulative counts in ``cir_index``.

    :param pixel_features: B*C*H*W
    :param pixel_assoc: B*9*H*W association of each pixel with its 9 neighbouring spixels
    :param cir_index: [cir_x, cir_y, cir_z, cir_cum]; 1-D numpy arrays or tensors with a leading batch dimension
    :param invisible: B*H*W*9 (H*W*9 array when cir_index holds 1-D arrays); non-zero marks ignored slots
    :param num_spixels_h: number of spixels per column of the grid
    :param num_spixels_w: number of spixels per row of the grid
    :return: B*C*D superpixel features, zeroed where the total weight is at most 0.001
    '''
    b, c = pixel_features.shape[0], pixel_features.shape[1]
    num_spixels = num_spixels_w * num_spixels_h
    cir_x, cir_y, cir_z, cir_l = cir_index
    if len(cir_x.shape) == 1:
        cir_x = torch.as_tensor(cir_x).unsqueeze(0)
        cir_y = torch.as_tensor(cir_y).unsqueeze(0)
        cir_z = torch.as_tensor(cir_z).unsqueeze(0)
        cir_l = torch.as_tensor(cir_l).unsqueeze(0)
        invisible = torch.as_tensor(invisible).unsqueeze(0)

    per_spixel = int(cir_l[0, 0])
    xs, ys, zs = cir_x[0], cir_y[0], cir_z[0]

    feat = pixel_features[:, :, ys, xs]  # B*C*n
    assoc = pixel_assoc[:, zs, ys, xs].unsqueeze(1)  # B*1*n
    hidden = invisible[:, ys, xs, zs].unsqueeze(1)  # B*1*n
    # The mask must be on the feature device before it is multiplied in.
    visible = 1.0 - hidden.to(pixel_features.device).float()

    weight = assoc * visible  # B*1*n
    s_feat = feat * weight  # B*C*n

    s_feat = s_feat.reshape(b, c, num_spixels, per_spixel).sum(3)  # B*C*D
    weight = weight.reshape(b, 1, num_spixels, per_spixel).sum(3)  # B*1*D

    S_feat = s_feat / (weight + 1e-5)
    # Superpixels that received (almost) no weight get a zero feature.
    S_feat = S_feat * (weight > 0.001).float()

    return S_feat


def compute_assignments(spixel_feat, pixel_features,
                        p2sp_index_, invisible, device):
    """
    Soft-assign every pixel to its 9 neighbouring superpixels.

    :param spixel_feat: B*C*D superpixel features
    :param pixel_features: B*C*H*W pixel features
    :param p2sp_index_: neighbouring spixel ids, forwarded to Passoc
    :param invisible: invisible-slot mask, forwarded to Passoc
    :param device: forwarded to Passoc
    :return: B*9*H*W associations; softmax of the Passoc output over the 9 slots
    """
    pixel_spixel_neg_dist = Passoc(pixel_features, spixel_feat, p2sp_index_, invisible, device)
    # torch.softmax performs the same max-shifted exp/normalise done by hand before.
    return torch.softmax(pixel_spixel_neg_dist, dim=1)
def exec_iter(spixel_feat, trans_features, cir_index, p2sp_index_, invisible, num_spixels_h, num_spixels_w, device):
    """
    Run one iteration: recompute the associations, then the superpixel features.

    :param spixel_feat: B*C*D current superpixel features
    :param trans_features: B*C*H*W pixel features
    :param cir_index: index lists forwarded to SpixelFeature2
    :param p2sp_index_: neighbouring spixel ids, forwarded to compute_assignments
    :param invisible: invisible-slot mask, forwarded to both steps
    :param num_spixels_h: forwarded to SpixelFeature2
    :param num_spixels_w: forwarded to SpixelFeature2
    :param device: forwarded to compute_assignments
    :return: (updated superpixel features, pixel associations)
    """
    # Compute pixel-superpixel assignments
    pixel_assoc = compute_assignments(spixel_feat, trans_features, p2sp_index_, invisible, device)
    spixel_feat1 = SpixelFeature2(trans_features, pixel_assoc, cir_index, invisible,
                                  num_spixels_h, num_spixels_w)

    return spixel_feat1, pixel_assoc


def compute_final_spixel_labels(final_pixel_assoc, p2sp_index, num_spixels_h, num_spixels_w):
    """
    Convert soft associations into a hard superpixel id per pixel.

    :param final_pixel_assoc: B*9*H*W
    :param p2sp_index: B*H*W*9 tensor, or an H*W*9 array shared by the whole batch
    :param num_spixels_h: unused, kept for interface compatibility
    :param num_spixels_w: unused, kept for interface compatibility
    :return: new_spix_indices: B*H*W, the superpixel id in each pixel's best-scoring slot
    """
    rel_label = torch.argmax(final_pixel_assoc, 1)  # B*H*W, slot (0-8) of the best spixel
    b = rel_label.shape[0]

    if len(p2sp_index.shape) == 3:
        p2sp_index = torch.as_tensor(p2sp_index).unsqueeze(0)
    # Keep the lookup table on the same device as the labels.
    p2sp_index = p2sp_index.to(rel_label.device)

    # Use each sample's own neighbour table when one is supplied per sample;
    # otherwise share the first table across the batch, as before.
    if p2sp_index.shape[0] != b:
        p2sp_index = p2sp_index[:1].expand(b, -1, -1, -1)

    # Pick, for every pixel, the id stored in its winning slot.
    return torch.gather(p2sp_index, 3, rel_label.unsqueeze(-1)).squeeze(-1)


def Semar(new_spixel_feat, new_spix_indices):
    """
    Spread superpixel features back onto pixels using the hard assignment.

    :param new_spixel_feat: iteration result of size B*C*D
    :param new_spix_indices: B*H*W superpixel id of each pixel (hard decision)
    :return: B*C*H*W, each pixel carrying the feature of its superpixel
    """
    b = new_spix_indices.shape[0]
    feat_bdc = new_spixel_feat.permute(0, 2, 1)  # B*D*C
    spix = new_spix_indices.long().to(new_spixel_feat.device)
    # Create the batch index on the feature device rather than always on CPU.
    batch_index = torch.arange(b, device=new_spixel_feat.device).reshape(b, 1, 1)
    feat = feat_bdc[batch_index, spix]  # B*H*W*C
    return feat.permute(0, 3, 1, 2).contiguous()
def decode_features(pixel_spixel_assoc, spixel_feat, p2sp_index,
                    num_spixels_h, num_spixels_w, num_spixels, num_channels):
    """
    Reconstruct per-pixel features from superpixel features and soft associations,
    then normalise them over the channel dimension.

    :param pixel_spixel_assoc: B*9*H*W association of each pixel with its nine surrounding spixels
    :param spixel_feat: B*C*D spixel feature (expected non-negative)
    :param p2sp_index: B*H*W*9 tensor; only the first sample's table is used
    :param num_spixels_h: unused, kept for interface compatibility
    :param num_spixels_w: unused, kept for interface compatibility
    :param num_spixels: unused, kept for interface compatibility
    :param num_channels: unused, kept for interface compatibility
    :return: B*C*H*W reconstruction whose channels sum to 1 at every pixel
    :raises ValueError: if the reconstruction contains negative or NaN values
    """
    neighbour_feat = spixel_feat[:, :, p2sp_index[0].long()]  # B*C*H*W*9
    neighbour_feat = neighbour_feat.permute(0, 1, 4, 2, 3)  # B*C*9*H*W

    # Weight each neighbour by its association and sum over the 9 slots;
    # broadcasting replaces the former repeat/reshape tiling.
    recon_feat = (neighbour_feat * pixel_spixel_assoc.unsqueeze(1)).sum(2) + 1e-10  # B*C*H*W

    # Fail loudly instead of opening an interactive debugger, which would stall
    # an unattended run. ``not (min >= 0)`` also catches NaN.
    if not (recon_feat.min() >= 0.):
        raise ValueError('fails')

    # norm
    return recon_feat / recon_feat.sum(1, keepdim=True)


def get_spixel_image(given_img, spix_index):
    """
    Draw superpixel boundaries in red on top of an image.

    :param given_img: image array, divided by 255 before drawing
    :param spix_index: superpixel id of each pixel (cast to int)
    :return: the output of ``mark_boundaries``
    """
    spixel_image = mark_boundaries(given_img / 255., spix_index.astype(int), color=(1, 0, 0))
    return spixel_image
if __name__ == '__main__':
    # Smoke test for SpixelFeature on a regular superpixel grid.
    # SpixelFeature needs the [x, y, cum] index lists built by convert_index and
    # equally sized superpixels; the previous version passed a random label map
    # directly and failed when unpacking it.
    grid_h, grid_w, cell = 5, 5, 10
    height, width = grid_h * cell, grid_w * cell

    feat = torch.rand((2, 5, height, width))
    rows = np.arange(height) // cell
    cols = np.arange(width) // cell
    feat_spixel_init = (rows[:, None] * grid_w + cols[None, :]).reshape(1, 1, height, width)

    init_index, _, _, _ = convert_index(num_spixel_w=grid_w, max_spixels=grid_h * grid_w,
                                        feat_spixel_init=feat_spixel_init)
    p = SpixelFeature(feat, init_index, max_spixels=grid_h * grid_w)
    print(p.shape)